From f58b21746e6ce7fb662f2e38099918da7287906a Mon Sep 17 00:00:00 2001 From: patriceac <48073125+patriceac@users.noreply.github.com> Date: Mon, 13 Feb 2023 17:42:36 -0800 Subject: [PATCH 01/27] Removing the 'None' option for face correction As per conversation : https://discord.com/channels/1014774730907209781/1014780368890630164/1074802779471757405 --- ui/media/js/main.js | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/ui/media/js/main.js b/ui/media/js/main.js index 633bb0d7..72e3901a 100644 --- a/ui/media/js/main.js +++ b/ui/media/js/main.js @@ -33,7 +33,7 @@ let promptStrengthField = document.querySelector('#prompt_strength') let samplerField = document.querySelector('#sampler_name') let samplerSelectionContainer = document.querySelector("#samplerSelection") let useFaceCorrectionField = document.querySelector("#use_face_correction") -let gfpganModelField = new ModelDropdown(document.querySelector("#gfpgan_model"), 'gfpgan', 'None') +let gfpganModelField = new ModelDropdown(document.querySelector("#gfpgan_model"), 'gfpgan') let useUpscalingField = document.querySelector("#use_upscale") let upscaleModelField = document.querySelector("#upscale_model") let upscaleAmountField = document.querySelector("#upscale_amount") From 9799309db9d9825a55e98a5742fd6a176e6ecba3 Mon Sep 17 00:00:00 2001 From: patriceac <48073125+patriceac@users.noreply.github.com> Date: Tue, 14 Feb 2023 02:31:13 -0800 Subject: [PATCH 02/27] Fix reloading of tasks with no file path In some conditions tasks may be reloaded with an empty file path (e.g. no face correction) --- ui/media/js/dnd.js | 4 ++++ 1 file changed, 4 insertions(+) diff --git a/ui/media/js/dnd.js b/ui/media/js/dnd.js index 69b24b96..4f4fc22a 100644 --- a/ui/media/js/dnd.js +++ b/ui/media/js/dnd.js @@ -375,6 +375,10 @@ function readUI() { } function getModelPath(filename, extensions) { + if (filename === null) { + return + } + let pathIdx if (filename.includes('/models/stable-diffusion/')) { pathIdx = filename.indexOf('/models/stable-diffusion/') + 25 // Linux, Mac paths From 2eb317c6b6c583c57a9653c0fae0604564906e12 Mon Sep 17 00:00:00 2001 From: cmdr2 Date: Tue, 14 Feb 2023 18:47:50 +0530 Subject: [PATCH 03/27] Format code, PEP8 using Black --- ui/easydiffusion/app.py | 138 +++++++------ ui/easydiffusion/device_manager.py | 134 +++++++----- ui/easydiffusion/model_manager.py | 159 ++++++++------ ui/easydiffusion/renderer.py | 80 +++++--- ui/easydiffusion/server.py | 235 ++++++++++++--------- ui/easydiffusion/task_manager.py | 296 +++++++++++++++++---------- ui/easydiffusion/types.py | 24 ++- ui/easydiffusion/utils/__init__.py | 4 +- ui/easydiffusion/utils/save_utils.py | 133 +++++++----- 9 files changed, 728 insertions(+), 475 deletions(-) diff --git a/ui/easydiffusion/app.py b/ui/easydiffusion/app.py index 4369a488..d556dd6f 100644 --- a/ui/easydiffusion/app.py +++ b/ui/easydiffusion/app.py @@ -7,7 +7,7 @@ import logging import shlex from rich.logging import RichHandler -from sdkit.utils import log as sdkit_log # hack, so we can overwrite the log config +from sdkit.utils import log as sdkit_log # hack, so we can overwrite the log config from easydiffusion import task_manager from easydiffusion.utils import log @@ -16,83 +16,86 @@ from easydiffusion.utils import log for handler in logging.root.handlers[:]: logging.root.removeHandler(handler) -LOG_FORMAT = '%(asctime)s.%(msecs)03d %(levelname)s %(threadName)s %(message)s' +LOG_FORMAT = "%(asctime)s.%(msecs)03d %(levelname)s %(threadName)s %(message)s" logging.basicConfig( - level=logging.INFO, - format=LOG_FORMAT, - datefmt="%X", - handlers=[RichHandler(markup=True, rich_tracebacks=False, show_time=False, show_level=False)], + level=logging.INFO, + format=LOG_FORMAT, + datefmt="%X", + handlers=[RichHandler(markup=True, rich_tracebacks=False, show_time=False, show_level=False)], ) SD_DIR = os.getcwd() -SD_UI_DIR = os.getenv('SD_UI_PATH', None) +SD_UI_DIR = os.getenv("SD_UI_PATH", None) sys.path.append(os.path.dirname(SD_UI_DIR)) -CONFIG_DIR = os.path.abspath(os.path.join(SD_UI_DIR, '..', 'scripts')) -MODELS_DIR = os.path.abspath(os.path.join(SD_DIR, '..', 'models')) +CONFIG_DIR = os.path.abspath(os.path.join(SD_UI_DIR, "..", "scripts")) +MODELS_DIR = os.path.abspath(os.path.join(SD_DIR, "..", "models")) -USER_UI_PLUGINS_DIR = os.path.abspath(os.path.join(SD_DIR, '..', 'plugins', 'ui')) -CORE_UI_PLUGINS_DIR = os.path.abspath(os.path.join(SD_UI_DIR, 'plugins', 'ui')) -UI_PLUGINS_SOURCES = ((CORE_UI_PLUGINS_DIR, 'core'), (USER_UI_PLUGINS_DIR, 'user')) +USER_UI_PLUGINS_DIR = os.path.abspath(os.path.join(SD_DIR, "..", "plugins", "ui")) +CORE_UI_PLUGINS_DIR = os.path.abspath(os.path.join(SD_UI_DIR, "plugins", "ui")) +UI_PLUGINS_SOURCES = ((CORE_UI_PLUGINS_DIR, "core"), (USER_UI_PLUGINS_DIR, "user")) -OUTPUT_DIRNAME = "Stable Diffusion UI" # in the user's home folder -PRESERVE_CONFIG_VARS = ['FORCE_FULL_PRECISION'] -TASK_TTL = 15 * 60 # Discard last session's task timeout +OUTPUT_DIRNAME = "Stable Diffusion UI" # in the user's home folder +PRESERVE_CONFIG_VARS = ["FORCE_FULL_PRECISION"] +TASK_TTL = 15 * 60 # Discard last session's task timeout APP_CONFIG_DEFAULTS = { # auto: selects the cuda device with the most free memory, cuda: use the currently active cuda device. - 'render_devices': 'auto', # valid entries: 'auto', 'cpu' or 'cuda:N' (where N is a GPU index) - 'update_branch': 'main', - 'ui': { - 'open_browser_on_start': True, - } + "render_devices": "auto", # valid entries: 'auto', 'cpu' or 'cuda:N' (where N is a GPU index) + "update_branch": "main", + "ui": { + "open_browser_on_start": True, + }, } + def init(): os.makedirs(USER_UI_PLUGINS_DIR, exist_ok=True) update_render_threads() + def getConfig(default_val=APP_CONFIG_DEFAULTS): try: - config_json_path = os.path.join(CONFIG_DIR, 'config.json') + config_json_path = os.path.join(CONFIG_DIR, "config.json") if not os.path.exists(config_json_path): config = default_val else: - with open(config_json_path, 'r', encoding='utf-8') as f: + with open(config_json_path, "r", encoding="utf-8") as f: config = json.load(f) - if 'net' not in config: - config['net'] = {} - if os.getenv('SD_UI_BIND_PORT') is not None: - config['net']['listen_port'] = int(os.getenv('SD_UI_BIND_PORT')) + if "net" not in config: + config["net"] = {} + if os.getenv("SD_UI_BIND_PORT") is not None: + config["net"]["listen_port"] = int(os.getenv("SD_UI_BIND_PORT")) else: - config['net']['listen_port'] = 9000 - if os.getenv('SD_UI_BIND_IP') is not None: - config['net']['listen_to_network'] = (os.getenv('SD_UI_BIND_IP') == '0.0.0.0') + config["net"]["listen_port"] = 9000 + if os.getenv("SD_UI_BIND_IP") is not None: + config["net"]["listen_to_network"] = os.getenv("SD_UI_BIND_IP") == "0.0.0.0" else: - config['net']['listen_to_network'] = True + config["net"]["listen_to_network"] = True return config except Exception as e: log.warn(traceback.format_exc()) return default_val + def setConfig(config): - try: # config.json - config_json_path = os.path.join(CONFIG_DIR, 'config.json') - with open(config_json_path, 'w', encoding='utf-8') as f: + try: # config.json + config_json_path = os.path.join(CONFIG_DIR, "config.json") + with open(config_json_path, "w", encoding="utf-8") as f: json.dump(config, f) except: log.error(traceback.format_exc()) - try: # config.bat - config_bat_path = os.path.join(CONFIG_DIR, 'config.bat') + try: # config.bat + config_bat_path = os.path.join(CONFIG_DIR, "config.bat") config_bat = [] - if 'update_branch' in config: + if "update_branch" in config: config_bat.append(f"@set update_branch={config['update_branch']}") config_bat.append(f"@set SD_UI_BIND_PORT={config['net']['listen_port']}") - bind_ip = '0.0.0.0' if config['net']['listen_to_network'] else '127.0.0.1' + bind_ip = "0.0.0.0" if config["net"]["listen_to_network"] else "127.0.0.1" config_bat.append(f"@set SD_UI_BIND_IP={bind_ip}") # Preserve these variables if they are set @@ -101,20 +104,20 @@ def setConfig(config): config_bat.append(f"@set {var}={os.getenv(var)}") if len(config_bat) > 0: - with open(config_bat_path, 'w', encoding='utf-8') as f: - f.write('\r\n'.join(config_bat)) + with open(config_bat_path, "w", encoding="utf-8") as f: + f.write("\r\n".join(config_bat)) except: log.error(traceback.format_exc()) - try: # config.sh - config_sh_path = os.path.join(CONFIG_DIR, 'config.sh') - config_sh = ['#!/bin/bash'] + try: # config.sh + config_sh_path = os.path.join(CONFIG_DIR, "config.sh") + config_sh = ["#!/bin/bash"] - if 'update_branch' in config: + if "update_branch" in config: config_sh.append(f"export update_branch={config['update_branch']}") config_sh.append(f"export SD_UI_BIND_PORT={config['net']['listen_port']}") - bind_ip = '0.0.0.0' if config['net']['listen_to_network'] else '127.0.0.1' + bind_ip = "0.0.0.0" if config["net"]["listen_to_network"] else "127.0.0.1" config_sh.append(f"export SD_UI_BIND_IP={bind_ip}") # Preserve these variables if they are set @@ -123,47 +126,51 @@ def setConfig(config): config_bat.append(f'export {var}="{shlex.quote(os.getenv(var))}"') if len(config_sh) > 1: - with open(config_sh_path, 'w', encoding='utf-8') as f: - f.write('\n'.join(config_sh)) + with open(config_sh_path, "w", encoding="utf-8") as f: + f.write("\n".join(config_sh)) except: log.error(traceback.format_exc()) + def save_to_config(ckpt_model_name, vae_model_name, hypernetwork_model_name, vram_usage_level): config = getConfig() - if 'model' not in config: - config['model'] = {} + if "model" not in config: + config["model"] = {} - config['model']['stable-diffusion'] = ckpt_model_name - config['model']['vae'] = vae_model_name - config['model']['hypernetwork'] = hypernetwork_model_name + config["model"]["stable-diffusion"] = ckpt_model_name + config["model"]["vae"] = vae_model_name + config["model"]["hypernetwork"] = hypernetwork_model_name if vae_model_name is None or vae_model_name == "": - del config['model']['vae'] + del config["model"]["vae"] if hypernetwork_model_name is None or hypernetwork_model_name == "": - del config['model']['hypernetwork'] + del config["model"]["hypernetwork"] - config['vram_usage_level'] = vram_usage_level + config["vram_usage_level"] = vram_usage_level setConfig(config) + def update_render_threads(): config = getConfig() - render_devices = config.get('render_devices', 'auto') - active_devices = task_manager.get_devices()['active'].keys() + render_devices = config.get("render_devices", "auto") + active_devices = task_manager.get_devices()["active"].keys() - log.debug(f'requesting for render_devices: {render_devices}') + log.debug(f"requesting for render_devices: {render_devices}") task_manager.update_render_threads(render_devices, active_devices) + def getUIPlugins(): plugins = [] for plugins_dir, dir_prefix in UI_PLUGINS_SOURCES: for file in os.listdir(plugins_dir): - if file.endswith('.plugin.js'): - plugins.append(f'/plugins/{dir_prefix}/{file}') + if file.endswith(".plugin.js"): + plugins.append(f"/plugins/{dir_prefix}/{file}") return plugins + def getIPConfig(): try: ips = socket.gethostbyname_ex(socket.gethostname()) @@ -173,10 +180,13 @@ def getIPConfig(): log.exception(e) return [] + def open_browser(): config = getConfig() - ui = config.get('ui', {}) - net = config.get('net', {'listen_port':9000}) - port = net.get('listen_port', 9000) - if ui.get('open_browser_on_start', True): - import webbrowser; webbrowser.open(f"http://localhost:{port}") + ui = config.get("ui", {}) + net = config.get("net", {"listen_port": 9000}) + port = net.get("listen_port", 9000) + if ui.get("open_browser_on_start", True): + import webbrowser + + webbrowser.open(f"http://localhost:{port}") diff --git a/ui/easydiffusion/device_manager.py b/ui/easydiffusion/device_manager.py index 786512ea..69086577 100644 --- a/ui/easydiffusion/device_manager.py +++ b/ui/easydiffusion/device_manager.py @@ -5,45 +5,54 @@ import re from easydiffusion.utils import log -''' +""" Set `FORCE_FULL_PRECISION` in the environment variables, or in `config.bat`/`config.sh` to set full precision (i.e. float32). Otherwise the models will load at half-precision (i.e. float16). Half-precision is fine most of the time. Full precision is only needed for working around GPU bugs (like NVIDIA 16xx GPUs). -''' +""" -COMPARABLE_GPU_PERCENTILE = 0.65 # if a GPU's free_mem is within this % of the GPU with the most free_mem, it will be picked +COMPARABLE_GPU_PERCENTILE = ( + 0.65 # if a GPU's free_mem is within this % of the GPU with the most free_mem, it will be picked +) mem_free_threshold = 0 + def get_device_delta(render_devices, active_devices): - ''' + """ render_devices: 'cpu', or 'auto' or ['cuda:N'...] active_devices: ['cpu', 'cuda:N'...] - ''' + """ - if render_devices in ('cpu', 'auto'): + if render_devices in ("cpu", "auto"): render_devices = [render_devices] elif render_devices is not None: if isinstance(render_devices, str): render_devices = [render_devices] if isinstance(render_devices, list) and len(render_devices) > 0: - render_devices = list(filter(lambda x: x.startswith('cuda:'), render_devices)) + render_devices = list(filter(lambda x: x.startswith("cuda:"), render_devices)) if len(render_devices) == 0: - raise Exception('Invalid render_devices value in config.json. Valid: {"render_devices": ["cuda:0", "cuda:1"...]}, or {"render_devices": "cpu"} or {"render_devices": "auto"}') + raise Exception( + 'Invalid render_devices value in config.json. Valid: {"render_devices": ["cuda:0", "cuda:1"...]}, or {"render_devices": "cpu"} or {"render_devices": "auto"}' + ) render_devices = list(filter(lambda x: is_device_compatible(x), render_devices)) if len(render_devices) == 0: - raise Exception('Sorry, none of the render_devices configured in config.json are compatible with Stable Diffusion') + raise Exception( + "Sorry, none of the render_devices configured in config.json are compatible with Stable Diffusion" + ) else: - raise Exception('Invalid render_devices value in config.json. Valid: {"render_devices": ["cuda:0", "cuda:1"...]}, or {"render_devices": "cpu"} or {"render_devices": "auto"}') + raise Exception( + 'Invalid render_devices value in config.json. Valid: {"render_devices": ["cuda:0", "cuda:1"...]}, or {"render_devices": "cpu"} or {"render_devices": "auto"}' + ) else: - render_devices = ['auto'] + render_devices = ["auto"] - if 'auto' in render_devices: + if "auto" in render_devices: render_devices = auto_pick_devices(active_devices) - if 'cpu' in render_devices: - log.warn('WARNING: Could not find a compatible GPU. Using the CPU, but this will be very slow!') + if "cpu" in render_devices: + log.warn("WARNING: Could not find a compatible GPU. Using the CPU, but this will be very slow!") active_devices = set(active_devices) render_devices = set(render_devices) @@ -53,19 +62,21 @@ def get_device_delta(render_devices, active_devices): return devices_to_start, devices_to_stop + def auto_pick_devices(currently_active_devices): global mem_free_threshold - if not torch.cuda.is_available(): return ['cpu'] + if not torch.cuda.is_available(): + return ["cpu"] device_count = torch.cuda.device_count() if device_count == 1: - return ['cuda:0'] if is_device_compatible('cuda:0') else ['cpu'] + return ["cuda:0"] if is_device_compatible("cuda:0") else ["cpu"] - log.debug('Autoselecting GPU. Using most free memory.') + log.debug("Autoselecting GPU. Using most free memory.") devices = [] for device in range(device_count): - device = f'cuda:{device}' + device = f"cuda:{device}" if not is_device_compatible(device): continue @@ -73,11 +84,13 @@ def auto_pick_devices(currently_active_devices): mem_free /= float(10**9) mem_total /= float(10**9) device_name = torch.cuda.get_device_name(device) - log.debug(f'{device} detected: {device_name} - Memory (free/total): {round(mem_free, 2)}Gb / {round(mem_total, 2)}Gb') - devices.append({'device': device, 'device_name': device_name, 'mem_free': mem_free}) + log.debug( + f"{device} detected: {device_name} - Memory (free/total): {round(mem_free, 2)}Gb / {round(mem_total, 2)}Gb" + ) + devices.append({"device": device, "device_name": device_name, "mem_free": mem_free}) - devices.sort(key=lambda x:x['mem_free'], reverse=True) - max_mem_free = devices[0]['mem_free'] + devices.sort(key=lambda x: x["mem_free"], reverse=True) + max_mem_free = devices[0]["mem_free"] curr_mem_free_threshold = COMPARABLE_GPU_PERCENTILE * max_mem_free mem_free_threshold = max(curr_mem_free_threshold, mem_free_threshold) @@ -87,23 +100,26 @@ def auto_pick_devices(currently_active_devices): # always be very low (since their VRAM contains the model). # These already-running devices probably aren't terrible, since they were picked in the past. # Worst case, the user can restart the program and that'll get rid of them. - devices = list(filter((lambda x: x['mem_free'] > mem_free_threshold or x['device'] in currently_active_devices), devices)) - devices = list(map(lambda x: x['device'], devices)) + devices = list( + filter((lambda x: x["mem_free"] > mem_free_threshold or x["device"] in currently_active_devices), devices) + ) + devices = list(map(lambda x: x["device"], devices)) return devices + def device_init(context, device): - ''' + """ This function assumes the 'device' has already been verified to be compatible. `get_device_delta()` has already filtered out incompatible devices. - ''' + """ - validate_device_id(device, log_prefix='device_init') + validate_device_id(device, log_prefix="device_init") - if device == 'cpu': - context.device = 'cpu' + if device == "cpu": + context.device = "cpu" context.device_name = get_processor_name() context.half_precision = False - log.debug(f'Render device CPU available as {context.device_name}') + log.debug(f"Render device CPU available as {context.device_name}") return context.device_name = torch.cuda.get_device_name(device) @@ -111,7 +127,7 @@ def device_init(context, device): # Force full precision on 1660 and 1650 NVIDIA cards to avoid creating green images if needs_to_force_full_precision(context): - log.warn(f'forcing full precision on this GPU, to avoid green images. GPU detected: {context.device_name}') + log.warn(f"forcing full precision on this GPU, to avoid green images. GPU detected: {context.device_name}") # Apply force_full_precision now before models are loaded. context.half_precision = False @@ -120,72 +136,90 @@ def device_init(context, device): return + def needs_to_force_full_precision(context): - if 'FORCE_FULL_PRECISION' in os.environ: + if "FORCE_FULL_PRECISION" in os.environ: return True device_name = context.device_name.lower() - return (('nvidia' in device_name or 'geforce' in device_name or 'quadro' in device_name) and (' 1660' in device_name or ' 1650' in device_name or ' t400' in device_name or ' t550' in device_name or ' t600' in device_name or ' t1000' in device_name or ' t1200' in device_name or ' t2000' in device_name)) + return ("nvidia" in device_name or "geforce" in device_name or "quadro" in device_name) and ( + " 1660" in device_name + or " 1650" in device_name + or " t400" in device_name + or " t550" in device_name + or " t600" in device_name + or " t1000" in device_name + or " t1200" in device_name + or " t2000" in device_name + ) + def get_max_vram_usage_level(device): - if device != 'cpu': + if device != "cpu": _, mem_total = torch.cuda.mem_get_info(device) mem_total /= float(10**9) if mem_total < 4.5: - return 'low' + return "low" elif mem_total < 6.5: - return 'balanced' + return "balanced" - return 'high' + return "high" -def validate_device_id(device, log_prefix=''): + +def validate_device_id(device, log_prefix=""): def is_valid(): if not isinstance(device, str): return False - if device == 'cpu': + if device == "cpu": return True - if not device.startswith('cuda:') or not device[5:].isnumeric(): + if not device.startswith("cuda:") or not device[5:].isnumeric(): return False return True if not is_valid(): - raise EnvironmentError(f"{log_prefix}: device id should be 'cpu', or 'cuda:N' (where N is an integer index for the GPU). Got: {device}") + raise EnvironmentError( + f"{log_prefix}: device id should be 'cpu', or 'cuda:N' (where N is an integer index for the GPU). Got: {device}" + ) + def is_device_compatible(device): - ''' + """ Returns True/False, and prints any compatibility errors - ''' - # static variable "history". - is_device_compatible.history = getattr(is_device_compatible, 'history', {}) + """ + # static variable "history". + is_device_compatible.history = getattr(is_device_compatible, "history", {}) try: - validate_device_id(device, log_prefix='is_device_compatible') + validate_device_id(device, log_prefix="is_device_compatible") except: log.error(str(e)) return False - if device == 'cpu': return True + if device == "cpu": + return True # Memory check try: _, mem_total = torch.cuda.mem_get_info(device) mem_total /= float(10**9) if mem_total < 3.0: if is_device_compatible.history.get(device) == None: - log.warn(f'GPU {device} with less than 3 GB of VRAM is not compatible with Stable Diffusion') - is_device_compatible.history[device] = 1 + log.warn(f"GPU {device} with less than 3 GB of VRAM is not compatible with Stable Diffusion") + is_device_compatible.history[device] = 1 return False except RuntimeError as e: log.error(str(e)) return False return True + def get_processor_name(): try: import platform, subprocess + if platform.system() == "Windows": return platform.processor() elif platform.system() == "Darwin": - os.environ['PATH'] = os.environ['PATH'] + os.pathsep + '/usr/sbin' + os.environ["PATH"] = os.environ["PATH"] + os.pathsep + "/usr/sbin" command = "sysctl -n machdep.cpu.brand_string" return subprocess.check_output(command).strip() elif platform.system() == "Linux": diff --git a/ui/easydiffusion/model_manager.py b/ui/easydiffusion/model_manager.py index 9dd928e1..fb380695 100644 --- a/ui/easydiffusion/model_manager.py +++ b/ui/easydiffusion/model_manager.py @@ -8,29 +8,31 @@ from sdkit import Context from sdkit.models import load_model, unload_model, get_model_info_from_db, scan_model from sdkit.utils import hash_file_quick -KNOWN_MODEL_TYPES = ['stable-diffusion', 'vae', 'hypernetwork', 'gfpgan', 'realesrgan'] +KNOWN_MODEL_TYPES = ["stable-diffusion", "vae", "hypernetwork", "gfpgan", "realesrgan"] MODEL_EXTENSIONS = { - 'stable-diffusion': ['.ckpt', '.safetensors'], - 'vae': ['.vae.pt', '.ckpt', '.safetensors'], - 'hypernetwork': ['.pt', '.safetensors'], - 'gfpgan': ['.pth'], - 'realesrgan': ['.pth'], + "stable-diffusion": [".ckpt", ".safetensors"], + "vae": [".vae.pt", ".ckpt", ".safetensors"], + "hypernetwork": [".pt", ".safetensors"], + "gfpgan": [".pth"], + "realesrgan": [".pth"], } DEFAULT_MODELS = { - 'stable-diffusion': [ # needed to support the legacy installations - 'custom-model', # only one custom model file was supported initially, creatively named 'custom-model' - 'sd-v1-4', # Default fallback. + "stable-diffusion": [ # needed to support the legacy installations + "custom-model", # only one custom model file was supported initially, creatively named 'custom-model' + "sd-v1-4", # Default fallback. ], - 'gfpgan': ['GFPGANv1.3'], - 'realesrgan': ['RealESRGAN_x4plus'], + "gfpgan": ["GFPGANv1.3"], + "realesrgan": ["RealESRGAN_x4plus"], } -MODELS_TO_LOAD_ON_START = ['stable-diffusion', 'vae', 'hypernetwork'] +MODELS_TO_LOAD_ON_START = ["stable-diffusion", "vae", "hypernetwork"] known_models = {} + def init(): make_model_folders() - getModels() # run this once, to cache the picklescan results + getModels() # run this once, to cache the picklescan results + def load_default_models(context: Context): set_vram_optimizations(context) @@ -39,27 +41,28 @@ def load_default_models(context: Context): for model_type in MODELS_TO_LOAD_ON_START: context.model_paths[model_type] = resolve_model_to_use(model_type=model_type) try: - load_model(context, model_type) + load_model(context, model_type) except Exception as e: - log.error(f'[red]Error while loading {model_type} model: {context.model_paths[model_type]}[/red]') - log.error(f'[red]Error: {e}[/red]') - log.error(f'[red]Consider removing the model from the model folder.[red]') + log.error(f"[red]Error while loading {model_type} model: {context.model_paths[model_type]}[/red]") + log.error(f"[red]Error: {e}[/red]") + log.error(f"[red]Consider removing the model from the model folder.[red]") def unload_all(context: Context): for model_type in KNOWN_MODEL_TYPES: unload_model(context, model_type) -def resolve_model_to_use(model_name:str=None, model_type:str=None): + +def resolve_model_to_use(model_name: str = None, model_type: str = None): model_extensions = MODEL_EXTENSIONS.get(model_type, []) default_models = DEFAULT_MODELS.get(model_type, []) config = app.getConfig() model_dirs = [os.path.join(app.MODELS_DIR, model_type), app.SD_DIR] - if not model_name: # When None try user configured model. + if not model_name: # When None try user configured model. # config = getConfig() - if 'model' in config and model_type in config['model']: - model_name = config['model'][model_type] + if "model" in config and model_type in config["model"]: + model_name = config["model"][model_type] if model_name: # Check models directory @@ -84,41 +87,54 @@ def resolve_model_to_use(model_name:str=None, model_type:str=None): for model_extension in model_extensions: if os.path.exists(default_model_path + model_extension): if model_name is not None: - log.warn(f'Could not find the configured custom model {model_name}{model_extension}. Using the default one: {default_model_path}{model_extension}') + log.warn( + f"Could not find the configured custom model {model_name}{model_extension}. Using the default one: {default_model_path}{model_extension}" + ) return default_model_path + model_extension return None + def reload_models_if_necessary(context: Context, task_data: TaskData): model_paths_in_req = { - 'stable-diffusion': task_data.use_stable_diffusion_model, - 'vae': task_data.use_vae_model, - 'hypernetwork': task_data.use_hypernetwork_model, - 'gfpgan': task_data.use_face_correction, - 'realesrgan': task_data.use_upscale, + "stable-diffusion": task_data.use_stable_diffusion_model, + "vae": task_data.use_vae_model, + "hypernetwork": task_data.use_hypernetwork_model, + "gfpgan": task_data.use_face_correction, + "realesrgan": task_data.use_upscale, + } + models_to_reload = { + model_type: path + for model_type, path in model_paths_in_req.items() + if context.model_paths.get(model_type) != path } - models_to_reload = {model_type: path for model_type, path in model_paths_in_req.items() if context.model_paths.get(model_type) != path} - if set_vram_optimizations(context): # reload SD - models_to_reload['stable-diffusion'] = model_paths_in_req['stable-diffusion'] + if set_vram_optimizations(context): # reload SD + models_to_reload["stable-diffusion"] = model_paths_in_req["stable-diffusion"] for model_type, model_path_in_req in models_to_reload.items(): context.model_paths[model_type] = model_path_in_req action_fn = unload_model if context.model_paths[model_type] is None else load_model - action_fn(context, model_type, scan_model=False) # we've scanned them already + action_fn(context, model_type, scan_model=False) # we've scanned them already + def resolve_model_paths(task_data: TaskData): - task_data.use_stable_diffusion_model = resolve_model_to_use(task_data.use_stable_diffusion_model, model_type='stable-diffusion') - task_data.use_vae_model = resolve_model_to_use(task_data.use_vae_model, model_type='vae') - task_data.use_hypernetwork_model = resolve_model_to_use(task_data.use_hypernetwork_model, model_type='hypernetwork') + task_data.use_stable_diffusion_model = resolve_model_to_use( + task_data.use_stable_diffusion_model, model_type="stable-diffusion" + ) + task_data.use_vae_model = resolve_model_to_use(task_data.use_vae_model, model_type="vae") + task_data.use_hypernetwork_model = resolve_model_to_use(task_data.use_hypernetwork_model, model_type="hypernetwork") + + if task_data.use_face_correction: + task_data.use_face_correction = resolve_model_to_use(task_data.use_face_correction, "gfpgan") + if task_data.use_upscale: + task_data.use_upscale = resolve_model_to_use(task_data.use_upscale, "realesrgan") - if task_data.use_face_correction: task_data.use_face_correction = resolve_model_to_use(task_data.use_face_correction, 'gfpgan') - if task_data.use_upscale: task_data.use_upscale = resolve_model_to_use(task_data.use_upscale, 'realesrgan') def set_vram_optimizations(context: Context): config = app.getConfig() - vram_usage_level = config.get('vram_usage_level', 'balanced') + vram_usage_level = config.get("vram_usage_level", "balanced") if vram_usage_level != context.vram_usage_level: context.vram_usage_level = vram_usage_level @@ -126,42 +142,51 @@ def set_vram_optimizations(context: Context): return False + def make_model_folders(): for model_type in KNOWN_MODEL_TYPES: model_dir_path = os.path.join(app.MODELS_DIR, model_type) os.makedirs(model_dir_path, exist_ok=True) - help_file_name = f'Place your {model_type} model files here.txt' + help_file_name = f"Place your {model_type} model files here.txt" help_file_contents = f'Supported extensions: {" or ".join(MODEL_EXTENSIONS.get(model_type))}' - with open(os.path.join(model_dir_path, help_file_name), 'w', encoding='utf-8') as f: + with open(os.path.join(model_dir_path, help_file_name), "w", encoding="utf-8") as f: f.write(help_file_contents) + def is_malicious_model(file_path): try: scan_result = scan_model(file_path) if scan_result.issues_count > 0 or scan_result.infected_files > 0: - log.warn(":warning: [bold red]Scan %s: %d scanned, %d issue, %d infected.[/bold red]" % (file_path, scan_result.scanned_files, scan_result.issues_count, scan_result.infected_files)) + log.warn( + ":warning: [bold red]Scan %s: %d scanned, %d issue, %d infected.[/bold red]" + % (file_path, scan_result.scanned_files, scan_result.issues_count, scan_result.infected_files) + ) return True else: - log.debug("Scan %s: [green]%d scanned, %d issue, %d infected.[/green]" % (file_path, scan_result.scanned_files, scan_result.issues_count, scan_result.infected_files)) + log.debug( + "Scan %s: [green]%d scanned, %d issue, %d infected.[/green]" + % (file_path, scan_result.scanned_files, scan_result.issues_count, scan_result.infected_files) + ) return False except Exception as e: - log.error(f'error while scanning: {file_path}, error: {e}') + log.error(f"error while scanning: {file_path}, error: {e}") return False + def getModels(): models = { - 'active': { - 'stable-diffusion': 'sd-v1-4', - 'vae': '', - 'hypernetwork': '', + "active": { + "stable-diffusion": "sd-v1-4", + "vae": "", + "hypernetwork": "", }, - 'options': { - 'stable-diffusion': ['sd-v1-4'], - 'vae': [], - 'hypernetwork': [], + "options": { + "stable-diffusion": ["sd-v1-4"], + "vae": [], + "hypernetwork": [], }, } @@ -171,13 +196,16 @@ def getModels(): "Raised when picklescan reports a problem with a model" pass - def scan_directory(directory, suffixes, directoriesFirst:bool=True): + def scan_directory(directory, suffixes, directoriesFirst: bool = True): nonlocal models_scanned tree = [] - for entry in sorted(os.scandir(directory), key = lambda entry: (entry.is_file() == directoriesFirst, entry.name.lower())): + for entry in sorted( + os.scandir(directory), key=lambda entry: (entry.is_file() == directoriesFirst, entry.name.lower()) + ): if entry.is_file(): matching_suffix = list(filter(lambda s: entry.name.endswith(s), suffixes)) - if len(matching_suffix) == 0: continue + if len(matching_suffix) == 0: + continue matching_suffix = matching_suffix[0] mtime = entry.stat().st_mtime @@ -187,12 +215,12 @@ def getModels(): if is_malicious_model(entry.path): raise MaliciousModelException(entry.path) known_models[entry.path] = mtime - tree.append(entry.name[:-len(matching_suffix)]) + tree.append(entry.name[: -len(matching_suffix)]) elif entry.is_dir(): - scan=scan_directory(entry.path, suffixes, directoriesFirst=False) + scan = scan_directory(entry.path, suffixes, directoriesFirst=False) if len(scan) != 0: - tree.append( (entry.name, scan ) ) + tree.append((entry.name, scan)) return tree def listModels(model_type): @@ -204,21 +232,22 @@ def getModels(): os.makedirs(models_dir) try: - models['options'][model_type] = scan_directory(models_dir, model_extensions) + models["options"][model_type] = scan_directory(models_dir, model_extensions) except MaliciousModelException as e: - models['scan-error'] = e + models["scan-error"] = e # custom models - listModels(model_type='stable-diffusion') - listModels(model_type='vae') - listModels(model_type='hypernetwork') - listModels(model_type='gfpgan') + listModels(model_type="stable-diffusion") + listModels(model_type="vae") + listModels(model_type="hypernetwork") + listModels(model_type="gfpgan") - if models_scanned > 0: log.info(f'[green]Scanned {models_scanned} models. Nothing infected[/]') + if models_scanned > 0: + log.info(f"[green]Scanned {models_scanned} models. Nothing infected[/]") # legacy - custom_weight_path = os.path.join(app.SD_DIR, 'custom-model.ckpt') + custom_weight_path = os.path.join(app.SD_DIR, "custom-model.ckpt") if os.path.exists(custom_weight_path): - models['options']['stable-diffusion'].append('custom-model') + models["options"]["stable-diffusion"].append("custom-model") return models diff --git a/ui/easydiffusion/renderer.py b/ui/easydiffusion/renderer.py index 672614ba..3b8adaa7 100644 --- a/ui/easydiffusion/renderer.py +++ b/ui/easydiffusion/renderer.py @@ -12,22 +12,26 @@ from sdkit.generate import generate_images from sdkit.filter import apply_filters from sdkit.utils import img_to_buffer, img_to_base64_str, latent_samples_to_images, gc -context = Context() # thread-local -''' +context = Context() # thread-local +""" runtime data (bound locally to this thread), for e.g. device, references to loaded models, optimization flags etc -''' +""" + def init(device): - ''' + """ Initializes the fields that will be bound to this runtime's context, and sets the current torch device - ''' + """ context.stop_processing = False context.temp_images = {} context.partial_x_samples = None device_manager.device_init(context, device) -def make_images(req: GenerateImageRequest, task_data: TaskData, data_queue: queue.Queue, task_temp_images: list, step_callback): + +def make_images( + req: GenerateImageRequest, task_data: TaskData, data_queue: queue.Queue, task_temp_images: list, step_callback +): context.stop_processing = False print_task_info(req, task_data) @@ -36,18 +40,24 @@ def make_images(req: GenerateImageRequest, task_data: TaskData, data_queue: queu res = Response(req, task_data, images=construct_response(images, seeds, task_data, base_seed=req.seed)) res = res.json() data_queue.put(json.dumps(res)) - log.info('Task completed') + log.info("Task completed") return res -def print_task_info(req: GenerateImageRequest, task_data: TaskData): - req_str = pprint.pformat(get_printable_request(req)).replace("[","\[") - task_str = pprint.pformat(task_data.dict()).replace("[","\[") - log.info(f'request: {req_str}') - log.info(f'task data: {task_str}') -def make_images_internal(req: GenerateImageRequest, task_data: TaskData, data_queue: queue.Queue, task_temp_images: list, step_callback): - images, user_stopped = generate_images_internal(req, task_data, data_queue, task_temp_images, step_callback, task_data.stream_image_progress) +def print_task_info(req: GenerateImageRequest, task_data: TaskData): + req_str = pprint.pformat(get_printable_request(req)).replace("[", "\[") + task_str = pprint.pformat(task_data.dict()).replace("[", "\[") + log.info(f"request: {req_str}") + log.info(f"task data: {task_str}") + + +def make_images_internal( + req: GenerateImageRequest, task_data: TaskData, data_queue: queue.Queue, task_temp_images: list, step_callback +): + images, user_stopped = generate_images_internal( + req, task_data, data_queue, task_temp_images, step_callback, task_data.stream_image_progress + ) filtered_images = filter_images(task_data, images, user_stopped) if task_data.save_to_disk_path is not None: @@ -59,13 +69,22 @@ def make_images_internal(req: GenerateImageRequest, task_data: TaskData, data_qu else: return images + filtered_images, seeds + seeds -def generate_images_internal(req: GenerateImageRequest, task_data: TaskData, data_queue: queue.Queue, task_temp_images: list, step_callback, stream_image_progress: bool): + +def generate_images_internal( + req: GenerateImageRequest, + task_data: TaskData, + data_queue: queue.Queue, + task_temp_images: list, + step_callback, + stream_image_progress: bool, +): context.temp_images.clear() callback = make_step_callback(req, task_data, data_queue, task_temp_images, step_callback, stream_image_progress) try: - if req.init_image is not None: req.sampler_name = 'ddim' + if req.init_image is not None: + req.sampler_name = "ddim" images = generate_images(context, callback=callback, **req.dict()) user_stopped = False @@ -75,31 +94,44 @@ def generate_images_internal(req: GenerateImageRequest, task_data: TaskData, dat if context.partial_x_samples is not None: images = latent_samples_to_images(context, context.partial_x_samples) finally: - if hasattr(context, 'partial_x_samples') and context.partial_x_samples is not None: + if hasattr(context, "partial_x_samples") and context.partial_x_samples is not None: del context.partial_x_samples context.partial_x_samples = None return images, user_stopped + def filter_images(task_data: TaskData, images: list, user_stopped): if user_stopped or (task_data.use_face_correction is None and task_data.use_upscale is None): return images filters_to_apply = [] - if task_data.use_face_correction and 'gfpgan' in task_data.use_face_correction.lower(): filters_to_apply.append('gfpgan') - if task_data.use_upscale and 'realesrgan' in task_data.use_upscale.lower(): filters_to_apply.append('realesrgan') + if task_data.use_face_correction and "gfpgan" in task_data.use_face_correction.lower(): + filters_to_apply.append("gfpgan") + if task_data.use_upscale and "realesrgan" in task_data.use_upscale.lower(): + filters_to_apply.append("realesrgan") return apply_filters(context, filters_to_apply, images, scale=task_data.upscale_amount) + def construct_response(images: list, seeds: list, task_data: TaskData, base_seed: int): return [ ResponseImage( data=img_to_base64_str(img, task_data.output_format, task_data.output_quality), seed=seed, - ) for img, seed in zip(images, seeds) + ) + for img, seed in zip(images, seeds) ] -def make_step_callback(req: GenerateImageRequest, task_data: TaskData, data_queue: queue.Queue, task_temp_images: list, step_callback, stream_image_progress: bool): + +def make_step_callback( + req: GenerateImageRequest, + task_data: TaskData, + data_queue: queue.Queue, + task_temp_images: list, + step_callback, + stream_image_progress: bool, +): n_steps = req.num_inference_steps if req.init_image is None else int(req.num_inference_steps * req.prompt_strength) last_callback_time = -1 @@ -107,11 +139,11 @@ def make_step_callback(req: GenerateImageRequest, task_data: TaskData, data_queu partial_images = [] images = latent_samples_to_images(context, x_samples) for i, img in enumerate(images): - buf = img_to_buffer(img, output_format='JPEG') + buf = img_to_buffer(img, output_format="JPEG") context.temp_images[f"{task_data.request_id}/{i}"] = buf task_temp_images[i] = buf - partial_images.append({'path': f"/image/tmp/{task_data.request_id}/{i}"}) + partial_images.append({"path": f"/image/tmp/{task_data.request_id}/{i}"}) del images return partial_images @@ -125,7 +157,7 @@ def make_step_callback(req: GenerateImageRequest, task_data: TaskData, data_queu progress = {"step": i, "step_time": step_time, "total_steps": n_steps} if stream_image_progress and i % 5 == 0: - progress['output'] = update_temp_img(x_samples, task_temp_images) + progress["output"] = update_temp_img(x_samples, task_temp_images) data_queue.put(json.dumps(progress)) diff --git a/ui/easydiffusion/server.py b/ui/easydiffusion/server.py index bbdca9c9..c04189c1 100644 --- a/ui/easydiffusion/server.py +++ b/ui/easydiffusion/server.py @@ -16,21 +16,25 @@ from easydiffusion import app, model_manager, task_manager from easydiffusion.types import TaskData, GenerateImageRequest, MergeRequest from easydiffusion.utils import log -log.info(f'started in {app.SD_DIR}') -log.info(f'started at {datetime.datetime.now():%x %X}') +log.info(f"started in {app.SD_DIR}") +log.info(f"started at {datetime.datetime.now():%x %X}") server_api = FastAPI() -NOCACHE_HEADERS={"Cache-Control": "no-cache, no-store, must-revalidate", "Pragma": "no-cache", "Expires": "0"} +NOCACHE_HEADERS = {"Cache-Control": "no-cache, no-store, must-revalidate", "Pragma": "no-cache", "Expires": "0"} + class NoCacheStaticFiles(StaticFiles): def is_not_modified(self, response_headers, request_headers) -> bool: - if 'content-type' in response_headers and ('javascript' in response_headers['content-type'] or 'css' in response_headers['content-type']): + if "content-type" in response_headers and ( + "javascript" in response_headers["content-type"] or "css" in response_headers["content-type"] + ): response_headers.update(NOCACHE_HEADERS) return False return super().is_not_modified(response_headers, request_headers) + class SetAppConfigRequest(BaseModel): update_branch: str = None render_devices: Union[List[str], List[int], str, int] = None @@ -39,130 +43,142 @@ class SetAppConfigRequest(BaseModel): listen_to_network: bool = None listen_port: int = None + def init(): - server_api.mount('/media', NoCacheStaticFiles(directory=os.path.join(app.SD_UI_DIR, 'media')), name="media") + server_api.mount("/media", NoCacheStaticFiles(directory=os.path.join(app.SD_UI_DIR, "media")), name="media") for plugins_dir, dir_prefix in app.UI_PLUGINS_SOURCES: - server_api.mount(f'/plugins/{dir_prefix}', NoCacheStaticFiles(directory=plugins_dir), name=f"plugins-{dir_prefix}") + server_api.mount( + f"/plugins/{dir_prefix}", NoCacheStaticFiles(directory=plugins_dir), name=f"plugins-{dir_prefix}" + ) - @server_api.post('/app_config') - async def set_app_config(req : SetAppConfigRequest): + @server_api.post("/app_config") + async def set_app_config(req: SetAppConfigRequest): return set_app_config_internal(req) - @server_api.get('/get/{key:path}') - def read_web_data(key:str=None): + @server_api.get("/get/{key:path}") + def read_web_data(key: str = None): return read_web_data_internal(key) - @server_api.get('/ping') # Get server and optionally session status. - def ping(session_id:str=None): + @server_api.get("/ping") # Get server and optionally session status. + def ping(session_id: str = None): return ping_internal(session_id) - @server_api.post('/render') + @server_api.post("/render") def render(req: dict): return render_internal(req) - @server_api.post('/model/merge') + @server_api.post("/model/merge") def model_merge(req: dict): print(req) return model_merge_internal(req) - @server_api.get('/image/stream/{task_id:int}') - def stream(task_id:int): + @server_api.get("/image/stream/{task_id:int}") + def stream(task_id: int): return stream_internal(task_id) - @server_api.get('/image/stop') + @server_api.get("/image/stop") def stop(task: int): return stop_internal(task) - @server_api.get('/image/tmp/{task_id:int}/{img_id:int}') + @server_api.get("/image/tmp/{task_id:int}/{img_id:int}") def get_image(task_id: int, img_id: int): return get_image_internal(task_id, img_id) - @server_api.get('/') + @server_api.get("/") def read_root(): - return FileResponse(os.path.join(app.SD_UI_DIR, 'index.html'), headers=NOCACHE_HEADERS) + return FileResponse(os.path.join(app.SD_UI_DIR, "index.html"), headers=NOCACHE_HEADERS) @server_api.on_event("shutdown") - def shutdown_event(): # Signal render thread to close on shutdown - task_manager.current_state_error = SystemExit('Application shutting down.') + def shutdown_event(): # Signal render thread to close on shutdown + task_manager.current_state_error = SystemExit("Application shutting down.") + # API implementations -def set_app_config_internal(req : SetAppConfigRequest): +def set_app_config_internal(req: SetAppConfigRequest): config = app.getConfig() if req.update_branch is not None: - config['update_branch'] = req.update_branch + config["update_branch"] = req.update_branch if req.render_devices is not None: update_render_devices_in_config(config, req.render_devices) if req.ui_open_browser_on_start is not None: - if 'ui' not in config: - config['ui'] = {} - config['ui']['open_browser_on_start'] = req.ui_open_browser_on_start + if "ui" not in config: + config["ui"] = {} + config["ui"]["open_browser_on_start"] = req.ui_open_browser_on_start if req.listen_to_network is not None: - if 'net' not in config: - config['net'] = {} - config['net']['listen_to_network'] = bool(req.listen_to_network) + if "net" not in config: + config["net"] = {} + config["net"]["listen_to_network"] = bool(req.listen_to_network) if req.listen_port is not None: - if 'net' not in config: - config['net'] = {} - config['net']['listen_port'] = int(req.listen_port) + if "net" not in config: + config["net"] = {} + config["net"]["listen_port"] = int(req.listen_port) try: app.setConfig(config) if req.render_devices: app.update_render_threads() - return JSONResponse({'status': 'OK'}, headers=NOCACHE_HEADERS) + return JSONResponse({"status": "OK"}, headers=NOCACHE_HEADERS) except Exception as e: log.error(traceback.format_exc()) raise HTTPException(status_code=500, detail=str(e)) + def update_render_devices_in_config(config, render_devices): - if render_devices not in ('cpu', 'auto') and not render_devices.startswith('cuda:'): - raise HTTPException(status_code=400, detail=f'Invalid render device requested: {render_devices}') + if render_devices not in ("cpu", "auto") and not render_devices.startswith("cuda:"): + raise HTTPException(status_code=400, detail=f"Invalid render device requested: {render_devices}") - if render_devices.startswith('cuda:'): - render_devices = render_devices.split(',') + if render_devices.startswith("cuda:"): + render_devices = render_devices.split(",") - config['render_devices'] = render_devices + config["render_devices"] = render_devices -def read_web_data_internal(key:str=None): - if not key: # /get without parameters, stable-diffusion easter egg. - raise HTTPException(status_code=418, detail="StableDiffusion is drawing a teapot!") # HTTP418 I'm a teapot - elif key == 'app_config': + +def read_web_data_internal(key: str = None): + if not key: # /get without parameters, stable-diffusion easter egg. + raise HTTPException(status_code=418, detail="StableDiffusion is drawing a teapot!") # HTTP418 I'm a teapot + elif key == "app_config": return JSONResponse(app.getConfig(), headers=NOCACHE_HEADERS) - elif key == 'system_info': + elif key == "system_info": config = app.getConfig() - output_dir = config.get('force_save_path', os.path.join(os.path.expanduser("~"), app.OUTPUT_DIRNAME)) + output_dir = config.get("force_save_path", os.path.join(os.path.expanduser("~"), app.OUTPUT_DIRNAME)) system_info = { - 'devices': task_manager.get_devices(), - 'hosts': app.getIPConfig(), - 'default_output_dir': output_dir, - 'enforce_output_dir': ('force_save_path' in config), + "devices": task_manager.get_devices(), + "hosts": app.getIPConfig(), + "default_output_dir": output_dir, + "enforce_output_dir": ("force_save_path" in config), } - system_info['devices']['config'] = config.get('render_devices', "auto") + system_info["devices"]["config"] = config.get("render_devices", "auto") return JSONResponse(system_info, headers=NOCACHE_HEADERS) - elif key == 'models': + elif key == "models": return JSONResponse(model_manager.getModels(), headers=NOCACHE_HEADERS) - elif key == 'modifiers': return FileResponse(os.path.join(app.SD_UI_DIR, 'modifiers.json'), headers=NOCACHE_HEADERS) - elif key == 'ui_plugins': return JSONResponse(app.getUIPlugins(), headers=NOCACHE_HEADERS) + elif key == "modifiers": + return FileResponse(os.path.join(app.SD_UI_DIR, "modifiers.json"), headers=NOCACHE_HEADERS) + elif key == "ui_plugins": + return JSONResponse(app.getUIPlugins(), headers=NOCACHE_HEADERS) else: - raise HTTPException(status_code=404, detail=f'Request for unknown {key}') # HTTP404 Not Found + raise HTTPException(status_code=404, detail=f"Request for unknown {key}") # HTTP404 Not Found -def ping_internal(session_id:str=None): - if task_manager.is_alive() <= 0: # Check that render threads are alive. - if task_manager.current_state_error: raise HTTPException(status_code=500, detail=str(task_manager.current_state_error)) - raise HTTPException(status_code=500, detail='Render thread is dead.') - if task_manager.current_state_error and not isinstance(task_manager.current_state_error, StopAsyncIteration): raise HTTPException(status_code=500, detail=str(task_manager.current_state_error)) + +def ping_internal(session_id: str = None): + if task_manager.is_alive() <= 0: # Check that render threads are alive. + if task_manager.current_state_error: + raise HTTPException(status_code=500, detail=str(task_manager.current_state_error)) + raise HTTPException(status_code=500, detail="Render thread is dead.") + if task_manager.current_state_error and not isinstance(task_manager.current_state_error, StopAsyncIteration): + raise HTTPException(status_code=500, detail=str(task_manager.current_state_error)) # Alive - response = {'status': str(task_manager.current_state)} + response = {"status": str(task_manager.current_state)} if session_id: session = task_manager.get_cached_session(session_id, update_ttl=True) - response['tasks'] = {id(t): t.status for t in session.tasks} - response['devices'] = task_manager.get_devices() + response["tasks"] = {id(t): t.status for t in session.tasks} + response["devices"] = task_manager.get_devices() return JSONResponse(response, headers=NOCACHE_HEADERS) + def render_internal(req: dict): try: # separate out the request data into rendering and task-specific data @@ -171,80 +187,99 @@ def render_internal(req: dict): # Overwrite user specified save path config = app.getConfig() - if 'force_save_path' in config: - task_data.save_to_disk_path = config['force_save_path'] + if "force_save_path" in config: + task_data.save_to_disk_path = config["force_save_path"] - render_req.init_image_mask = req.get('mask') # hack: will rename this in the HTTP API in a future revision + render_req.init_image_mask = req.get("mask") # hack: will rename this in the HTTP API in a future revision - app.save_to_config(task_data.use_stable_diffusion_model, task_data.use_vae_model, task_data.use_hypernetwork_model, task_data.vram_usage_level) + app.save_to_config( + task_data.use_stable_diffusion_model, + task_data.use_vae_model, + task_data.use_hypernetwork_model, + task_data.vram_usage_level, + ) # enqueue the task new_task = task_manager.render(render_req, task_data) response = { - 'status': str(task_manager.current_state), - 'queue': len(task_manager.tasks_queue), - 'stream': f'/image/stream/{id(new_task)}', - 'task': id(new_task) + "status": str(task_manager.current_state), + "queue": len(task_manager.tasks_queue), + "stream": f"/image/stream/{id(new_task)}", + "task": id(new_task), } return JSONResponse(response, headers=NOCACHE_HEADERS) - except ChildProcessError as e: # Render thread is dead - raise HTTPException(status_code=500, detail=f'Rendering thread has died.') # HTTP500 Internal Server Error - except ConnectionRefusedError as e: # Unstarted task pending limit reached, deny queueing too many. - raise HTTPException(status_code=503, detail=str(e)) # HTTP503 Service Unavailable + except ChildProcessError as e: # Render thread is dead + raise HTTPException(status_code=500, detail=f"Rendering thread has died.") # HTTP500 Internal Server Error + except ConnectionRefusedError as e: # Unstarted task pending limit reached, deny queueing too many. + raise HTTPException(status_code=503, detail=str(e)) # HTTP503 Service Unavailable except Exception as e: log.error(traceback.format_exc()) raise HTTPException(status_code=500, detail=str(e)) + def model_merge_internal(req: dict): try: from sdkit.train import merge_models from easydiffusion.utils.save_utils import filename_regex + mergeReq: MergeRequest = MergeRequest.parse_obj(req) - - merge_models(model_manager.resolve_model_to_use(mergeReq.model0,'stable-diffusion'), - model_manager.resolve_model_to_use(mergeReq.model1,'stable-diffusion'), - mergeReq.ratio, - os.path.join(app.MODELS_DIR, 'stable-diffusion', filename_regex.sub('_', mergeReq.out_path)), - mergeReq.use_fp16 + + merge_models( + model_manager.resolve_model_to_use(mergeReq.model0, "stable-diffusion"), + model_manager.resolve_model_to_use(mergeReq.model1, "stable-diffusion"), + mergeReq.ratio, + os.path.join(app.MODELS_DIR, "stable-diffusion", filename_regex.sub("_", mergeReq.out_path)), + mergeReq.use_fp16, ) - return JSONResponse({'status':'OK'}, headers=NOCACHE_HEADERS) + return JSONResponse({"status": "OK"}, headers=NOCACHE_HEADERS) except Exception as e: log.error(traceback.format_exc()) raise HTTPException(status_code=500, detail=str(e)) -def stream_internal(task_id:int): - #TODO Move to WebSockets ?? + +def stream_internal(task_id: int): + # TODO Move to WebSockets ?? task = task_manager.get_cached_task(task_id, update_ttl=True) - if not task: raise HTTPException(status_code=404, detail=f'Request {task_id} not found.') # HTTP404 NotFound - #if (id(task) != task_id): raise HTTPException(status_code=409, detail=f'Wrong task id received. Expected:{id(task)}, Received:{task_id}') # HTTP409 Conflict + if not task: + raise HTTPException(status_code=404, detail=f"Request {task_id} not found.") # HTTP404 NotFound + # if (id(task) != task_id): raise HTTPException(status_code=409, detail=f'Wrong task id received. Expected:{id(task)}, Received:{task_id}') # HTTP409 Conflict if task.buffer_queue.empty() and not task.lock.locked(): if task.response: - #log.info(f'Session {session_id} sending cached response') + # log.info(f'Session {session_id} sending cached response') return JSONResponse(task.response, headers=NOCACHE_HEADERS) - raise HTTPException(status_code=425, detail='Too Early, task not started yet.') # HTTP425 Too Early - #log.info(f'Session {session_id} opened live render stream {id(task.buffer_queue)}') - return StreamingResponse(task.read_buffer_generator(), media_type='application/json') + raise HTTPException(status_code=425, detail="Too Early, task not started yet.") # HTTP425 Too Early + # log.info(f'Session {session_id} opened live render stream {id(task.buffer_queue)}') + return StreamingResponse(task.read_buffer_generator(), media_type="application/json") + def stop_internal(task: int): if not task: - if task_manager.current_state == task_manager.ServerStates.Online or task_manager.current_state == task_manager.ServerStates.Unavailable: - raise HTTPException(status_code=409, detail='Not currently running any tasks.') # HTTP409 Conflict - task_manager.current_state_error = StopAsyncIteration('') - return {'OK'} + if ( + task_manager.current_state == task_manager.ServerStates.Online + or task_manager.current_state == task_manager.ServerStates.Unavailable + ): + raise HTTPException(status_code=409, detail="Not currently running any tasks.") # HTTP409 Conflict + task_manager.current_state_error = StopAsyncIteration("") + return {"OK"} task_id = task task = task_manager.get_cached_task(task_id, update_ttl=False) - if not task: raise HTTPException(status_code=404, detail=f'Task {task_id} was not found.') # HTTP404 Not Found - if isinstance(task.error, StopAsyncIteration): raise HTTPException(status_code=409, detail=f'Task {task_id} is already stopped.') # HTTP409 Conflict - task.error = StopAsyncIteration(f'Task {task_id} stop requested.') - return {'OK'} + if not task: + raise HTTPException(status_code=404, detail=f"Task {task_id} was not found.") # HTTP404 Not Found + if isinstance(task.error, StopAsyncIteration): + raise HTTPException(status_code=409, detail=f"Task {task_id} is already stopped.") # HTTP409 Conflict + task.error = StopAsyncIteration(f"Task {task_id} stop requested.") + return {"OK"} + def get_image_internal(task_id: int, img_id: int): task = task_manager.get_cached_task(task_id, update_ttl=True) - if not task: raise HTTPException(status_code=410, detail=f'Task {task_id} could not be found.') # HTTP404 NotFound - if not task.temp_images[img_id]: raise HTTPException(status_code=425, detail='Too Early, task data is not available yet.') # HTTP425 Too Early + if not task: + raise HTTPException(status_code=410, detail=f"Task {task_id} could not be found.") # HTTP404 NotFound + if not task.temp_images[img_id]: + raise HTTPException(status_code=425, detail="Too Early, task data is not available yet.") # HTTP425 Too Early try: img_data = task.temp_images[img_id] img_data.seek(0) - return StreamingResponse(img_data, media_type='image/jpeg') + return StreamingResponse(img_data, media_type="image/jpeg") except KeyError as e: raise HTTPException(status_code=500, detail=str(e)) diff --git a/ui/easydiffusion/task_manager.py b/ui/easydiffusion/task_manager.py index 213f5a96..28f84963 100644 --- a/ui/easydiffusion/task_manager.py +++ b/ui/easydiffusion/task_manager.py @@ -7,7 +7,7 @@ Notes: import json import traceback -TASK_TTL = 15 * 60 # seconds, Discard last session's task timeout +TASK_TTL = 15 * 60 # seconds, Discard last session's task timeout import torch import queue, threading, time, weakref @@ -19,71 +19,98 @@ from easydiffusion.utils import log from sdkit.utils import gc -THREAD_NAME_PREFIX = '' -ERR_LOCK_FAILED = ' failed to acquire lock within timeout.' -LOCK_TIMEOUT = 15 # Maximum locking time in seconds before failing a task. +THREAD_NAME_PREFIX = "" +ERR_LOCK_FAILED = " failed to acquire lock within timeout." +LOCK_TIMEOUT = 15 # Maximum locking time in seconds before failing a task. # It's better to get an exception than a deadlock... ALWAYS use timeout in critical paths. -DEVICE_START_TIMEOUT = 60 # seconds - Maximum time to wait for a render device to init. +DEVICE_START_TIMEOUT = 60 # seconds - Maximum time to wait for a render device to init. + + +class SymbolClass(type): # Print nicely formatted Symbol names. + def __repr__(self): + return self.__qualname__ + + def __str__(self): + return self.__name__ + + +class Symbol(metaclass=SymbolClass): + pass -class SymbolClass(type): # Print nicely formatted Symbol names. - def __repr__(self): return self.__qualname__ - def __str__(self): return self.__name__ -class Symbol(metaclass=SymbolClass): pass class ServerStates: - class Init(Symbol): pass - class LoadingModel(Symbol): pass - class Online(Symbol): pass - class Rendering(Symbol): pass - class Unavailable(Symbol): pass + class Init(Symbol): + pass -class RenderTask(): # Task with output queue and completion lock. + class LoadingModel(Symbol): + pass + + class Online(Symbol): + pass + + class Rendering(Symbol): + pass + + class Unavailable(Symbol): + pass + + +class RenderTask: # Task with output queue and completion lock. def __init__(self, req: GenerateImageRequest, task_data: TaskData): task_data.request_id = id(self) self.render_request: GenerateImageRequest = req # Initial Request self.task_data: TaskData = task_data - self.response: Any = None # Copy of the last reponse - self.render_device = None # Select the task affinity. (Not used to change active devices). - self.temp_images:list = [None] * req.num_outputs * (1 if task_data.show_only_filtered_image else 2) + self.response: Any = None # Copy of the last reponse + self.render_device = None # Select the task affinity. (Not used to change active devices). + self.temp_images: list = [None] * req.num_outputs * (1 if task_data.show_only_filtered_image else 2) self.error: Exception = None - self.lock: threading.Lock = threading.Lock() # Locks at task start and unlocks when task is completed - self.buffer_queue: queue.Queue = queue.Queue() # Queue of JSON string segments + self.lock: threading.Lock = threading.Lock() # Locks at task start and unlocks when task is completed + self.buffer_queue: queue.Queue = queue.Queue() # Queue of JSON string segments + async def read_buffer_generator(self): try: while not self.buffer_queue.empty(): res = self.buffer_queue.get(block=False) self.buffer_queue.task_done() yield res - except queue.Empty as e: yield + except queue.Empty as e: + yield + @property def status(self): if self.lock.locked(): - return 'running' + return "running" if isinstance(self.error, StopAsyncIteration): - return 'stopped' + return "stopped" if self.error: - return 'error' + return "error" if not self.buffer_queue.empty(): - return 'buffer' + return "buffer" if self.response: - return 'completed' - return 'pending' + return "completed" + return "pending" + @property def is_pending(self): return bool(not self.response and not self.error) + # Temporary cache to allow to query tasks results for a short time after they are completed. -class DataCache(): +class DataCache: def __init__(self): self._base = dict() self._lock: threading.Lock = threading.Lock() + def _get_ttl_time(self, ttl: int) -> int: return int(time.time()) + ttl + def _is_expired(self, timestamp: int) -> bool: return int(time.time()) >= timestamp + def clean(self) -> None: - if not self._lock.acquire(blocking=True, timeout=LOCK_TIMEOUT): raise Exception('DataCache.clean' + ERR_LOCK_FAILED) + if not self._lock.acquire(blocking=True, timeout=LOCK_TIMEOUT): + raise Exception("DataCache.clean" + ERR_LOCK_FAILED) try: # Create a list of expired keys to delete to_delete = [] @@ -95,20 +122,26 @@ class DataCache(): for key in to_delete: (_, val) = self._base[key] if isinstance(val, RenderTask): - log.debug(f'RenderTask {key} expired. Data removed.') + log.debug(f"RenderTask {key} expired. Data removed.") elif isinstance(val, SessionState): - log.debug(f'Session {key} expired. Data removed.') + log.debug(f"Session {key} expired. Data removed.") else: - log.debug(f'Key {key} expired. Data removed.') + log.debug(f"Key {key} expired. Data removed.") del self._base[key] finally: self._lock.release() + def clear(self) -> None: - if not self._lock.acquire(blocking=True, timeout=LOCK_TIMEOUT): raise Exception('DataCache.clear' + ERR_LOCK_FAILED) - try: self._base.clear() - finally: self._lock.release() + if not self._lock.acquire(blocking=True, timeout=LOCK_TIMEOUT): + raise Exception("DataCache.clear" + ERR_LOCK_FAILED) + try: + self._base.clear() + finally: + self._lock.release() + def delete(self, key: Hashable) -> bool: - if not self._lock.acquire(blocking=True, timeout=LOCK_TIMEOUT): raise Exception('DataCache.delete' + ERR_LOCK_FAILED) + if not self._lock.acquire(blocking=True, timeout=LOCK_TIMEOUT): + raise Exception("DataCache.delete" + ERR_LOCK_FAILED) try: if key not in self._base: return False @@ -116,8 +149,10 @@ class DataCache(): return True finally: self._lock.release() + def keep(self, key: Hashable, ttl: int) -> bool: - if not self._lock.acquire(blocking=True, timeout=LOCK_TIMEOUT): raise Exception('DataCache.keep' + ERR_LOCK_FAILED) + if not self._lock.acquire(blocking=True, timeout=LOCK_TIMEOUT): + raise Exception("DataCache.keep" + ERR_LOCK_FAILED) try: if key in self._base: _, value = self._base.get(key) @@ -126,12 +161,12 @@ class DataCache(): return False finally: self._lock.release() + def put(self, key: Hashable, value: Any, ttl: int) -> bool: - if not self._lock.acquire(blocking=True, timeout=LOCK_TIMEOUT): raise Exception('DataCache.put' + ERR_LOCK_FAILED) + if not self._lock.acquire(blocking=True, timeout=LOCK_TIMEOUT): + raise Exception("DataCache.put" + ERR_LOCK_FAILED) try: - self._base[key] = ( - self._get_ttl_time(ttl), value - ) + self._base[key] = (self._get_ttl_time(ttl), value) except Exception as e: log.error(traceback.format_exc()) return False @@ -139,35 +174,41 @@ class DataCache(): return True finally: self._lock.release() + def tryGet(self, key: Hashable) -> Any: - if not self._lock.acquire(blocking=True, timeout=LOCK_TIMEOUT): raise Exception('DataCache.tryGet' + ERR_LOCK_FAILED) + if not self._lock.acquire(blocking=True, timeout=LOCK_TIMEOUT): + raise Exception("DataCache.tryGet" + ERR_LOCK_FAILED) try: ttl, value = self._base.get(key, (None, None)) if ttl is not None and self._is_expired(ttl): - log.debug(f'Session {key} expired. Discarding data.') + log.debug(f"Session {key} expired. Discarding data.") del self._base[key] return None return value finally: self._lock.release() + manager_lock = threading.RLock() render_threads = [] current_state = ServerStates.Init -current_state_error:Exception = None +current_state_error: Exception = None tasks_queue = [] session_cache = DataCache() task_cache = DataCache() weak_thread_data = weakref.WeakKeyDictionary() idle_event: threading.Event = threading.Event() -class SessionState(): + +class SessionState: def __init__(self, id: str): self._id = id self._tasks_ids = [] + @property def id(self): return self._id + @property def tasks(self): tasks = [] @@ -176,6 +217,7 @@ class SessionState(): if task: tasks.append(task) return tasks + def put(self, task, ttl=TASK_TTL): task_id = id(task) self._tasks_ids.append(task_id) @@ -185,10 +227,12 @@ class SessionState(): self._tasks_ids.pop(0) return True + def thread_get_next_task(): from easydiffusion import renderer + if not manager_lock.acquire(blocking=True, timeout=LOCK_TIMEOUT): - log.warn(f'Render thread on device: {renderer.context.device} failed to acquire manager lock.') + log.warn(f"Render thread on device: {renderer.context.device} failed to acquire manager lock.") return None if len(tasks_queue) <= 0: manager_lock.release() @@ -202,10 +246,10 @@ def thread_get_next_task(): continue # requested device alive, skip current one. else: # Requested device is not active, return error to UI. - queued_task.error = Exception(queued_task.render_device + ' is not currently active.') + queued_task.error = Exception(queued_task.render_device + " is not currently active.") task = queued_task break - if not queued_task.render_device and renderer.context.device == 'cpu' and is_alive() > 1: + if not queued_task.render_device and renderer.context.device == "cpu" and is_alive() > 1: # not asking for any specific devices, cpu want to grab task but other render devices are alive. continue # Skip Tasks, don't run on CPU unless there is nothing else or user asked for it. task = queued_task @@ -216,17 +260,19 @@ def thread_get_next_task(): finally: manager_lock.release() + def thread_render(device): global current_state, current_state_error from easydiffusion import renderer, model_manager + try: renderer.init(device) weak_thread_data[threading.current_thread()] = { - 'device': renderer.context.device, - 'device_name': renderer.context.device_name, - 'alive': True + "device": renderer.context.device, + "device_name": renderer.context.device_name, + "alive": True, } current_state = ServerStates.LoadingModel @@ -235,17 +281,14 @@ def thread_render(device): current_state = ServerStates.Online except Exception as e: log.error(traceback.format_exc()) - weak_thread_data[threading.current_thread()] = { - 'error': e, - 'alive': False - } + weak_thread_data[threading.current_thread()] = {"error": e, "alive": False} return while True: session_cache.clean() task_cache.clean() - if not weak_thread_data[threading.current_thread()]['alive']: - log.info(f'Shutting down thread for device {renderer.context.device}') + if not weak_thread_data[threading.current_thread()]["alive"]: + log.info(f"Shutting down thread for device {renderer.context.device}") model_manager.unload_all(renderer.context) return if isinstance(current_state_error, SystemExit): @@ -258,39 +301,47 @@ def thread_render(device): continue if task.error is not None: log.error(task.error) - task.response = {"status": 'failed', "detail": str(task.error)} + task.response = {"status": "failed", "detail": str(task.error)} task.buffer_queue.put(json.dumps(task.response)) continue if current_state_error: task.error = current_state_error - task.response = {"status": 'failed', "detail": str(task.error)} + task.response = {"status": "failed", "detail": str(task.error)} task.buffer_queue.put(json.dumps(task.response)) continue - log.info(f'Session {task.task_data.session_id} starting task {id(task)} on {renderer.context.device_name}') - if not task.lock.acquire(blocking=False): raise Exception('Got locked task from queue.') + log.info(f"Session {task.task_data.session_id} starting task {id(task)} on {renderer.context.device_name}") + if not task.lock.acquire(blocking=False): + raise Exception("Got locked task from queue.") try: + def step_callback(): global current_state_error - if isinstance(current_state_error, SystemExit) or isinstance(current_state_error, StopAsyncIteration) or isinstance(task.error, StopAsyncIteration): + if ( + isinstance(current_state_error, SystemExit) + or isinstance(current_state_error, StopAsyncIteration) + or isinstance(task.error, StopAsyncIteration) + ): renderer.context.stop_processing = True if isinstance(current_state_error, StopAsyncIteration): task.error = current_state_error current_state_error = None - log.info(f'Session {task.task_data.session_id} sent cancel signal for task {id(task)}') + log.info(f"Session {task.task_data.session_id} sent cancel signal for task {id(task)}") current_state = ServerStates.LoadingModel model_manager.resolve_model_paths(task.task_data) model_manager.reload_models_if_necessary(renderer.context, task.task_data) current_state = ServerStates.Rendering - task.response = renderer.make_images(task.render_request, task.task_data, task.buffer_queue, task.temp_images, step_callback) + task.response = renderer.make_images( + task.render_request, task.task_data, task.buffer_queue, task.temp_images, step_callback + ) # Before looping back to the generator, mark cache as still alive. task_cache.keep(id(task), TASK_TTL) session_cache.keep(task.task_data.session_id, TASK_TTL) except Exception as e: task.error = str(e) - task.response = {"status": 'failed', "detail": str(task.error)} + task.response = {"status": "failed", "detail": str(task.error)} task.buffer_queue.put(json.dumps(task.response)) log.error(traceback.format_exc()) finally: @@ -299,21 +350,25 @@ def thread_render(device): task_cache.keep(id(task), TASK_TTL) session_cache.keep(task.task_data.session_id, TASK_TTL) if isinstance(task.error, StopAsyncIteration): - log.info(f'Session {task.task_data.session_id} task {id(task)} cancelled!') + log.info(f"Session {task.task_data.session_id} task {id(task)} cancelled!") elif task.error is not None: - log.info(f'Session {task.task_data.session_id} task {id(task)} failed!') + log.info(f"Session {task.task_data.session_id} task {id(task)} failed!") else: - log.info(f'Session {task.task_data.session_id} task {id(task)} completed by {renderer.context.device_name}.') + log.info( + f"Session {task.task_data.session_id} task {id(task)} completed by {renderer.context.device_name}." + ) current_state = ServerStates.Online -def get_cached_task(task_id:str, update_ttl:bool=False): + +def get_cached_task(task_id: str, update_ttl: bool = False): # By calling keep before tryGet, wont discard if was expired. if update_ttl and not task_cache.keep(task_id, TASK_TTL): # Failed to keep task, already gone. return None return task_cache.tryGet(task_id) -def get_cached_session(session_id:str, update_ttl:bool=False): + +def get_cached_session(session_id: str, update_ttl: bool = False): if update_ttl: session_cache.keep(session_id, TASK_TTL) session = session_cache.tryGet(session_id) @@ -322,64 +377,68 @@ def get_cached_session(session_id:str, update_ttl:bool=False): session_cache.put(session_id, session, TASK_TTL) return session + def get_devices(): devices = { - 'all': {}, - 'active': {}, + "all": {}, + "active": {}, } def get_device_info(device): - if device == 'cpu': - return {'name': device_manager.get_processor_name()} - + if device == "cpu": + return {"name": device_manager.get_processor_name()} + mem_free, mem_total = torch.cuda.mem_get_info(device) mem_free /= float(10**9) mem_total /= float(10**9) return { - 'name': torch.cuda.get_device_name(device), - 'mem_free': mem_free, - 'mem_total': mem_total, - 'max_vram_usage_level': device_manager.get_max_vram_usage_level(device), + "name": torch.cuda.get_device_name(device), + "mem_free": mem_free, + "mem_total": mem_total, + "max_vram_usage_level": device_manager.get_max_vram_usage_level(device), } # list the compatible devices gpu_count = torch.cuda.device_count() for device in range(gpu_count): - device = f'cuda:{device}' + device = f"cuda:{device}" if not device_manager.is_device_compatible(device): continue - devices['all'].update({device: get_device_info(device)}) + devices["all"].update({device: get_device_info(device)}) - devices['all'].update({'cpu': get_device_info('cpu')}) + devices["all"].update({"cpu": get_device_info("cpu")}) # list the activated devices - if not manager_lock.acquire(blocking=True, timeout=LOCK_TIMEOUT): raise Exception('get_devices' + ERR_LOCK_FAILED) + if not manager_lock.acquire(blocking=True, timeout=LOCK_TIMEOUT): + raise Exception("get_devices" + ERR_LOCK_FAILED) try: for rthread in render_threads: if not rthread.is_alive(): continue weak_data = weak_thread_data.get(rthread) - if not weak_data or not 'device' in weak_data or not 'device_name' in weak_data: + if not weak_data or not "device" in weak_data or not "device_name" in weak_data: continue - device = weak_data['device'] - devices['active'].update({device: get_device_info(device)}) + device = weak_data["device"] + devices["active"].update({device: get_device_info(device)}) finally: manager_lock.release() return devices + def is_alive(device=None): - if not manager_lock.acquire(blocking=True, timeout=LOCK_TIMEOUT): raise Exception('is_alive' + ERR_LOCK_FAILED) + if not manager_lock.acquire(blocking=True, timeout=LOCK_TIMEOUT): + raise Exception("is_alive" + ERR_LOCK_FAILED) nbr_alive = 0 try: for rthread in render_threads: if device is not None: weak_data = weak_thread_data.get(rthread) - if weak_data is None or not 'device' in weak_data or weak_data['device'] is None: + if weak_data is None or not "device" in weak_data or weak_data["device"] is None: continue - thread_device = weak_data['device'] + thread_device = weak_data["device"] if thread_device != device: continue if rthread.is_alive(): @@ -388,11 +447,13 @@ def is_alive(device=None): finally: manager_lock.release() + def start_render_thread(device): - if not manager_lock.acquire(blocking=True, timeout=LOCK_TIMEOUT): raise Exception('start_render_thread' + ERR_LOCK_FAILED) - log.info(f'Start new Rendering Thread on device: {device}') + if not manager_lock.acquire(blocking=True, timeout=LOCK_TIMEOUT): + raise Exception("start_render_thread" + ERR_LOCK_FAILED) + log.info(f"Start new Rendering Thread on device: {device}") try: - rthread = threading.Thread(target=thread_render, kwargs={'device': device}) + rthread = threading.Thread(target=thread_render, kwargs={"device": device}) rthread.daemon = True rthread.name = THREAD_NAME_PREFIX + device rthread.start() @@ -400,8 +461,8 @@ def start_render_thread(device): finally: manager_lock.release() timeout = DEVICE_START_TIMEOUT - while not rthread.is_alive() or not rthread in weak_thread_data or not 'device' in weak_thread_data[rthread]: - if rthread in weak_thread_data and 'error' in weak_thread_data[rthread]: + while not rthread.is_alive() or not rthread in weak_thread_data or not "device" in weak_thread_data[rthread]: + if rthread in weak_thread_data and "error" in weak_thread_data[rthread]: log.error(f"{rthread}, {device}, error: {weak_thread_data[rthread]['error']}") return False if timeout <= 0: @@ -410,25 +471,27 @@ def start_render_thread(device): time.sleep(1) return True + def stop_render_thread(device): try: - device_manager.validate_device_id(device, log_prefix='stop_render_thread') + device_manager.validate_device_id(device, log_prefix="stop_render_thread") except: log.error(traceback.format_exc()) return False - if not manager_lock.acquire(blocking=True, timeout=LOCK_TIMEOUT): raise Exception('stop_render_thread' + ERR_LOCK_FAILED) - log.info(f'Stopping Rendering Thread on device: {device}') + if not manager_lock.acquire(blocking=True, timeout=LOCK_TIMEOUT): + raise Exception("stop_render_thread" + ERR_LOCK_FAILED) + log.info(f"Stopping Rendering Thread on device: {device}") try: thread_to_remove = None for rthread in render_threads: weak_data = weak_thread_data.get(rthread) - if weak_data is None or not 'device' in weak_data or weak_data['device'] is None: + if weak_data is None or not "device" in weak_data or weak_data["device"] is None: continue - thread_device = weak_data['device'] + thread_device = weak_data["device"] if thread_device == device: - weak_data['alive'] = False + weak_data["alive"] = False thread_to_remove = rthread break if thread_to_remove is not None: @@ -439,44 +502,51 @@ def stop_render_thread(device): return False + def update_render_threads(render_devices, active_devices): devices_to_start, devices_to_stop = device_manager.get_device_delta(render_devices, active_devices) - log.debug(f'devices_to_start: {devices_to_start}') - log.debug(f'devices_to_stop: {devices_to_stop}') + log.debug(f"devices_to_start: {devices_to_start}") + log.debug(f"devices_to_stop: {devices_to_stop}") for device in devices_to_stop: if is_alive(device) <= 0: - log.debug(f'{device} is not alive') + log.debug(f"{device} is not alive") continue if not stop_render_thread(device): - log.warn(f'{device} could not stop render thread') + log.warn(f"{device} could not stop render thread") for device in devices_to_start: if is_alive(device) >= 1: - log.debug(f'{device} already registered.') + log.debug(f"{device} already registered.") continue if not start_render_thread(device): - log.warn(f'{device} failed to start.') + log.warn(f"{device} failed to start.") - if is_alive() <= 0: # No running devices, probably invalid user config. - raise EnvironmentError('ERROR: No active render devices! Please verify the "render_devices" value in config.json') + if is_alive() <= 0: # No running devices, probably invalid user config. + raise EnvironmentError( + 'ERROR: No active render devices! Please verify the "render_devices" value in config.json' + ) log.debug(f"active devices: {get_devices()['active']}") -def shutdown_event(): # Signal render thread to close on shutdown + +def shutdown_event(): # Signal render thread to close on shutdown global current_state_error - current_state_error = SystemExit('Application shutting down.') + current_state_error = SystemExit("Application shutting down.") + def render(render_req: GenerateImageRequest, task_data: TaskData): current_thread_count = is_alive() if current_thread_count <= 0: # Render thread is dead - raise ChildProcessError('Rendering thread has died.') + raise ChildProcessError("Rendering thread has died.") # Alive, check if task in cache session = get_cached_session(task_data.session_id, update_ttl=True) pending_tasks = list(filter(lambda t: t.is_pending, session.tasks)) if current_thread_count < len(pending_tasks): - raise ConnectionRefusedError(f'Session {task_data.session_id} already has {len(pending_tasks)} pending tasks out of {current_thread_count}.') + raise ConnectionRefusedError( + f"Session {task_data.session_id} already has {len(pending_tasks)} pending tasks out of {current_thread_count}." + ) new_task = RenderTask(render_req, task_data) if session.put(new_task, TASK_TTL): @@ -489,4 +559,4 @@ def render(render_req: GenerateImageRequest, task_data: TaskData): return new_task finally: manager_lock.release() - raise RuntimeError('Failed to add task to cache.') + raise RuntimeError("Failed to add task to cache.") diff --git a/ui/easydiffusion/types.py b/ui/easydiffusion/types.py index c70e125b..b45eafd2 100644 --- a/ui/easydiffusion/types.py +++ b/ui/easydiffusion/types.py @@ -1,6 +1,7 @@ from pydantic import BaseModel from typing import Any + class GenerateImageRequest(BaseModel): prompt: str = "" negative_prompt: str = "" @@ -18,29 +19,31 @@ class GenerateImageRequest(BaseModel): prompt_strength: float = 0.8 preserve_init_image_color_profile = False - sampler_name: str = None # "ddim", "plms", "heun", "euler", "euler_a", "dpm2", "dpm2_a", "lms" + sampler_name: str = None # "ddim", "plms", "heun", "euler", "euler_a", "dpm2", "dpm2_a", "lms" hypernetwork_strength: float = 0 + class TaskData(BaseModel): request_id: str = None session_id: str = "session" save_to_disk_path: str = None - vram_usage_level: str = "balanced" # or "low" or "medium" + vram_usage_level: str = "balanced" # or "low" or "medium" - use_face_correction: str = None # or "GFPGANv1.3" - use_upscale: str = None # or "RealESRGAN_x4plus" or "RealESRGAN_x4plus_anime_6B" - upscale_amount: int = 4 # or 2 + use_face_correction: str = None # or "GFPGANv1.3" + use_upscale: str = None # or "RealESRGAN_x4plus" or "RealESRGAN_x4plus_anime_6B" + upscale_amount: int = 4 # or 2 use_stable_diffusion_model: str = "sd-v1-4" # use_stable_diffusion_config: str = "v1-inference" use_vae_model: str = None use_hypernetwork_model: str = None show_only_filtered_image: bool = False - output_format: str = "jpeg" # or "png" + output_format: str = "jpeg" # or "png" output_quality: int = 75 - metadata_output_format: str = "txt" # or "json" + metadata_output_format: str = "txt" # or "json" stream_image_progress: bool = False + class MergeRequest(BaseModel): model0: str = None model1: str = None @@ -48,8 +51,9 @@ class MergeRequest(BaseModel): out_path: str = "mix" use_fp16 = True + class Image: - data: str # base64 + data: str # base64 seed: int is_nsfw: bool path_abs: str = None @@ -65,6 +69,7 @@ class Image: "path_abs": self.path_abs, } + class Response: render_request: GenerateImageRequest task_data: TaskData @@ -80,7 +85,7 @@ class Response: del self.render_request.init_image_mask res = { - "status": 'succeeded', + "status": "succeeded", "render_request": self.render_request.dict(), "task_data": self.task_data.dict(), "output": [], @@ -91,5 +96,6 @@ class Response: return res + class UserInitiatedStop(Exception): pass diff --git a/ui/easydiffusion/utils/__init__.py b/ui/easydiffusion/utils/__init__.py index 8be070b4..b9c5e21a 100644 --- a/ui/easydiffusion/utils/__init__.py +++ b/ui/easydiffusion/utils/__init__.py @@ -1,8 +1,8 @@ import logging -log = logging.getLogger('easydiffusion') +log = logging.getLogger("easydiffusion") from .save_utils import ( save_images_to_disk, get_printable_request, -) \ No newline at end of file +) diff --git a/ui/easydiffusion/utils/save_utils.py b/ui/easydiffusion/utils/save_utils.py index 1835b7e8..b4a85538 100644 --- a/ui/easydiffusion/utils/save_utils.py +++ b/ui/easydiffusion/utils/save_utils.py @@ -7,89 +7,126 @@ from easydiffusion.types import TaskData, GenerateImageRequest from sdkit.utils import save_images, save_dicts -filename_regex = re.compile('[^a-zA-Z0-9._-]') +filename_regex = re.compile("[^a-zA-Z0-9._-]") # keep in sync with `ui/media/js/dnd.js` TASK_TEXT_MAPPING = { - 'prompt': 'Prompt', - 'width': 'Width', - 'height': 'Height', - 'seed': 'Seed', - 'num_inference_steps': 'Steps', - 'guidance_scale': 'Guidance Scale', - 'prompt_strength': 'Prompt Strength', - 'use_face_correction': 'Use Face Correction', - 'use_upscale': 'Use Upscaling', - 'upscale_amount': 'Upscale By', - 'sampler_name': 'Sampler', - 'negative_prompt': 'Negative Prompt', - 'use_stable_diffusion_model': 'Stable Diffusion model', - 'use_vae_model': 'VAE model', - 'use_hypernetwork_model': 'Hypernetwork model', - 'hypernetwork_strength': 'Hypernetwork Strength' + "prompt": "Prompt", + "width": "Width", + "height": "Height", + "seed": "Seed", + "num_inference_steps": "Steps", + "guidance_scale": "Guidance Scale", + "prompt_strength": "Prompt Strength", + "use_face_correction": "Use Face Correction", + "use_upscale": "Use Upscaling", + "upscale_amount": "Upscale By", + "sampler_name": "Sampler", + "negative_prompt": "Negative Prompt", + "use_stable_diffusion_model": "Stable Diffusion model", + "use_vae_model": "VAE model", + "use_hypernetwork_model": "Hypernetwork model", + "hypernetwork_strength": "Hypernetwork Strength", } + def save_images_to_disk(images: list, filtered_images: list, req: GenerateImageRequest, task_data: TaskData): now = time.time() - save_dir_path = os.path.join(task_data.save_to_disk_path, filename_regex.sub('_', task_data.session_id)) + save_dir_path = os.path.join(task_data.save_to_disk_path, filename_regex.sub("_", task_data.session_id)) metadata_entries = get_metadata_entries_for_request(req, task_data) make_filename = make_filename_callback(req, now=now) if task_data.show_only_filtered_image or filtered_images is images: - save_images(filtered_images, save_dir_path, file_name=make_filename, output_format=task_data.output_format, output_quality=task_data.output_quality) - if task_data.metadata_output_format.lower() in ['json', 'txt', 'embed']: - save_dicts(metadata_entries, save_dir_path, file_name=make_filename, output_format=task_data.metadata_output_format, file_format=task_data.output_format) + save_images( + filtered_images, + save_dir_path, + file_name=make_filename, + output_format=task_data.output_format, + output_quality=task_data.output_quality, + ) + if task_data.metadata_output_format.lower() in ["json", "txt", "embed"]: + save_dicts( + metadata_entries, + save_dir_path, + file_name=make_filename, + output_format=task_data.metadata_output_format, + file_format=task_data.output_format, + ) else: - make_filter_filename = make_filename_callback(req, now=now, suffix='filtered') + make_filter_filename = make_filename_callback(req, now=now, suffix="filtered") + + save_images( + images, + save_dir_path, + file_name=make_filename, + output_format=task_data.output_format, + output_quality=task_data.output_quality, + ) + save_images( + filtered_images, + save_dir_path, + file_name=make_filter_filename, + output_format=task_data.output_format, + output_quality=task_data.output_quality, + ) + if task_data.metadata_output_format.lower() in ["json", "txt", "embed"]: + save_dicts( + metadata_entries, + save_dir_path, + file_name=make_filter_filename, + output_format=task_data.metadata_output_format, + file_format=task_data.output_format, + ) - save_images(images, save_dir_path, file_name=make_filename, output_format=task_data.output_format, output_quality=task_data.output_quality) - save_images(filtered_images, save_dir_path, file_name=make_filter_filename, output_format=task_data.output_format, output_quality=task_data.output_quality) - if task_data.metadata_output_format.lower() in ['json', 'txt', 'embed']: - save_dicts(metadata_entries, save_dir_path, file_name=make_filter_filename, output_format=task_data.metadata_output_format, file_format=task_data.output_format) def get_metadata_entries_for_request(req: GenerateImageRequest, task_data: TaskData): metadata = get_printable_request(req) - metadata.update({ - 'use_stable_diffusion_model': task_data.use_stable_diffusion_model, - 'use_vae_model': task_data.use_vae_model, - 'use_hypernetwork_model': task_data.use_hypernetwork_model, - 'use_face_correction': task_data.use_face_correction, - 'use_upscale': task_data.use_upscale, - }) - if metadata['use_upscale'] is not None: - metadata['upscale_amount'] = task_data.upscale_amount - if (task_data.use_hypernetwork_model is None): - del metadata['hypernetwork_strength'] + metadata.update( + { + "use_stable_diffusion_model": task_data.use_stable_diffusion_model, + "use_vae_model": task_data.use_vae_model, + "use_hypernetwork_model": task_data.use_hypernetwork_model, + "use_face_correction": task_data.use_face_correction, + "use_upscale": task_data.use_upscale, + } + ) + if metadata["use_upscale"] is not None: + metadata["upscale_amount"] = task_data.upscale_amount + if task_data.use_hypernetwork_model is None: + del metadata["hypernetwork_strength"] # if text, format it in the text format expected by the UI - is_txt_format = (task_data.metadata_output_format.lower() == 'txt') + is_txt_format = task_data.metadata_output_format.lower() == "txt" if is_txt_format: metadata = {TASK_TEXT_MAPPING[key]: val for key, val in metadata.items() if key in TASK_TEXT_MAPPING} entries = [metadata.copy() for _ in range(req.num_outputs)] for i, entry in enumerate(entries): - entry['Seed' if is_txt_format else 'seed'] = req.seed + i + entry["Seed" if is_txt_format else "seed"] = req.seed + i return entries + def get_printable_request(req: GenerateImageRequest): metadata = req.dict() - del metadata['init_image'] - del metadata['init_image_mask'] - if (req.init_image is None): - del metadata['prompt_strength'] + del metadata["init_image"] + del metadata["init_image_mask"] + if req.init_image is None: + del metadata["prompt_strength"] return metadata + def make_filename_callback(req: GenerateImageRequest, suffix=None, now=None): if now is None: now = time.time() - def make_filename(i): - img_id = base64.b64encode(int(now+i).to_bytes(8, 'big')).decode() # Generate unique ID based on time. - img_id = img_id.translate({43:None, 47:None, 61:None})[-8:] # Remove + / = and keep last 8 chars. - prompt_flattened = filename_regex.sub('_', req.prompt)[:50] + def make_filename(i): + img_id = base64.b64encode(int(now + i).to_bytes(8, "big")).decode() # Generate unique ID based on time. + img_id = img_id.translate({43: None, 47: None, 61: None})[-8:] # Remove + / = and keep last 8 chars. + + prompt_flattened = filename_regex.sub("_", req.prompt)[:50] name = f"{prompt_flattened}_{img_id}" - name = name if suffix is None else f'{name}_{suffix}' + name = name if suffix is None else f"{name}_{suffix}" return name return make_filename From 9d1dd09a07e195c31700ce39154c10f8d7249af4 Mon Sep 17 00:00:00 2001 From: JeLuF Date: Tue, 14 Feb 2023 15:03:25 +0100 Subject: [PATCH 04/27] 'Download all images' button (#765) * Use standard DOM function * Add 'download all images' button --------- Co-authored-by: cmdr2 --- ui/index.html | 2 ++ ui/media/css/main.css | 1 + ui/media/js/main.js | 16 ++++++++++++++++ ui/media/js/utils.js | 13 ------------- 4 files changed, 19 insertions(+), 13 deletions(-) diff --git a/ui/index.html b/ui/index.html index 1f7ae34e..05c4e68f 100644 --- a/ui/index.html +++ b/ui/index.html @@ -277,9 +277,11 @@ and selecting the desired modifiers.

Click "Image Settings" for additional settings like seed, image size, number of images to generate etc.

Enjoy! :) +
+
diff --git a/ui/media/css/main.css b/ui/media/css/main.css index f2c91f01..46a8a971 100644 --- a/ui/media/css/main.css +++ b/ui/media/css/main.css @@ -468,6 +468,7 @@ div.img-preview img { background: var(--accent-color); border: var(--primary-button-border); color: rgb(255, 221, 255); + padding: 3pt 6pt; } .secondaryButton { background: rgb(132, 8, 0); diff --git a/ui/media/js/main.js b/ui/media/js/main.js index 72e3901a..db913a2a 100644 --- a/ui/media/js/main.js +++ b/ui/media/js/main.js @@ -61,6 +61,7 @@ let promptStrengthContainer = document.querySelector('#prompt_strength_container let initialText = document.querySelector("#initial-text") let previewTools = document.querySelector("#preview-tools") let clearAllPreviewsBtn = document.querySelector("#clear-all-previews") +let saveAllImagesBtn = document.querySelector("#save-all-images") let maskSetting = document.querySelector('#enable_mask') @@ -1160,6 +1161,21 @@ clearAllPreviewsBtn.addEventListener('click', (e) => { shiftOrConfirm(e, "Clear taskEntries.forEach(removeTask) })}) +saveAllImagesBtn.addEventListener('click', (e) => { + document.querySelectorAll(".imageTaskContainer").forEach(container => { + let req = htmlTaskMap.get(container) + container.querySelectorAll(".imgContainer img").forEach(img => { + if (img.closest('.imgItem').style.display === 'none') { + // console.log('skipping hidden image', img) + return + } + + onDownloadImageClick(req, img) + // console.log(req) + }) + }) +}) + stopImageBtn.addEventListener('click', (e) => { shiftOrConfirm(e, "Stop all the tasks?", async function(e) { await stopAllTasks() })}) diff --git a/ui/media/js/utils.js b/ui/media/js/utils.js index 69801571..50f5f162 100644 --- a/ui/media/js/utils.js +++ b/ui/media/js/utils.js @@ -20,19 +20,6 @@ function getNextSibling(elem, selector) { } } -function findClosestAncestor(element, selector) { - if (!element || !element.parentNode) { - // reached the top of the DOM tree, return null - return null; - } else if (element.parentNode.matches(selector)) { - // found an ancestor that matches the selector, return it - return element.parentNode; - } else { - // continue searching upwards - return findClosestAncestor(element.parentNode, selector); - } -} - /* Panel Stuff */ From c59745d346be2fe4e4ceae25cbc22b30e8c0afb7 Mon Sep 17 00:00:00 2001 From: patriceac <48073125+patriceac@users.noreply.github.com> Date: Wed, 15 Feb 2023 00:10:02 -0800 Subject: [PATCH 05/27] Cleaning up event listener that's no longer needed The event listener instantiates two objects every time the user clicks on the Merge tab. This is no longer needed after AssassinJN's CSS fixes from yesterday. --- ui/plugins/ui/merge.plugin.js | 13 ++++--------- 1 file changed, 4 insertions(+), 9 deletions(-) diff --git a/ui/plugins/ui/merge.plugin.js b/ui/plugins/ui/merge.plugin.js index d5e728dd..87c62d8b 100644 --- a/ui/plugins/ui/merge.plugin.js +++ b/ui/plugins/ui/merge.plugin.js @@ -334,15 +334,10 @@ linkTabContents(tabSettingsSingle) linkTabContents(tabSettingsBatch) - /////////////////////// Event Listener - document.addEventListener('tabClick', (e) => { - if (e.detail.name == 'merge') { - console.log('Activate') - let mergeModelAField = new ModelDropdown(document.querySelector('#mergeModelA'), 'stable-diffusion') - let mergeModelBField = new ModelDropdown(document.querySelector('#mergeModelB'), 'stable-diffusion') - updateChart() - } - }) + console.log('Activate') + let mergeModelAField = new ModelDropdown(document.querySelector('#mergeModelA'), 'stable-diffusion') + let mergeModelBField = new ModelDropdown(document.querySelector('#mergeModelB'), 'stable-diffusion') + updateChart() // slider const singleMergeRatioField = document.querySelector('#single-merge-ratio') From 744c6e4725990d29fb70e3fd3598be5751b9cff6 Mon Sep 17 00:00:00 2001 From: cmdr2 Date: Wed, 15 Feb 2023 21:40:02 +0530 Subject: [PATCH 06/27] sdkit 1.0.37 --- scripts/on_sd_start.bat | 2 +- scripts/on_sd_start.sh | 2 +- 2 files changed, 2 insertions(+), 2 deletions(-) diff --git a/scripts/on_sd_start.bat b/scripts/on_sd_start.bat index 821e24aa..31b26074 100644 --- a/scripts/on_sd_start.bat +++ b/scripts/on_sd_start.bat @@ -92,7 +92,7 @@ if "%ERRORLEVEL%" EQU "0" ( set PYTHONNOUSERSITE=1 set PYTHONPATH=%INSTALL_ENV_DIR%\lib\site-packages - call python -m pip install --upgrade sdkit==1.0.36 -q || ( + call python -m pip install --upgrade sdkit==1.0.37 -q || ( echo "Error updating sdkit" ) ) diff --git a/scripts/on_sd_start.sh b/scripts/on_sd_start.sh index 11d62531..0bfa091d 100755 --- a/scripts/on_sd_start.sh +++ b/scripts/on_sd_start.sh @@ -81,7 +81,7 @@ if python ../scripts/check_modules.py sdkit sdkit.models ldm transformers numpy export PYTHONNOUSERSITE=1 export PYTHONPATH="$INSTALL_ENV_DIR/lib/python3.8/site-packages" - python -m pip install --upgrade sdkit==1.0.36 -q + python -m pip install --upgrade sdkit==1.0.37 -q fi else echo "Installing sdkit: https://pypi.org/project/sdkit/" From 5d3b59b94ea80ce4e7ba20abad5b203468928a63 Mon Sep 17 00:00:00 2001 From: JeLuF Date: Wed, 15 Feb 2023 21:15:55 +0100 Subject: [PATCH 07/27] No /proc/cpuinfo on MacOS Check whether /proc/cpuinfo exists before checking for AVX support --- scripts/bootstrap.sh | 4 +++- 1 file changed, 3 insertions(+), 1 deletion(-) diff --git a/scripts/bootstrap.sh b/scripts/bootstrap.sh index 375ab3ac..d2a29dec 100755 --- a/scripts/bootstrap.sh +++ b/scripts/bootstrap.sh @@ -29,7 +29,9 @@ if ! which tar; then fail "'tar' not found. Please install tar."; fi if ! which bzip2; then fail "'bzip2' not found. Please install bzip2."; fi if pwd | grep ' '; then fail "The installation directory's path contains a space character. Conda will fail to install. Please change the directory."; fi -if ! cat /proc/cpuinfo | grep avx | uniq; then fail "Your CPU doesn't support AVX."; fi +if [ -f /proc/cpuinfo ]; then + if ! cat /proc/cpuinfo | grep avx | uniq; then fail "Your CPU doesn't support AVX."; fi +fi # https://mamba.readthedocs.io/en/latest/installation.html if [ "$OS_NAME" == "linux" ] && [ "$OS_ARCH" == "arm64" ]; then OS_ARCH="aarch64"; fi From 9f5f213cd3fe15b99e99564f905e111caaf61114 Mon Sep 17 00:00:00 2001 From: AssassinJN Date: Thu, 16 Feb 2023 00:05:46 -0500 Subject: [PATCH 08/27] Fix for dropdown widths (#883) * Fix dropdown location * change width --- ui/index.html | 2 +- ui/media/css/main.css | 3 +-- 2 files changed, 2 insertions(+), 3 deletions(-) diff --git a/ui/index.html b/ui/index.html index 05c4e68f..44e21493 100644 --- a/ui/index.html +++ b/ui/index.html @@ -228,7 +228,7 @@
  • Render Settings
  • -
  • +
  • Click to learn more about samplers From e73e82023749f45aff945432c5aefeb615c6e520 Mon Sep 17 00:00:00 2001 From: cmdr2 Date: Fri, 17 Feb 2023 15:22:42 +0530 Subject: [PATCH 20/27] Support server-side plugins. Currently supports overriding the get_cond_and_uncond function --- ui/easydiffusion/app.py | 50 ++++++++++++++++++++++++++++++++++++++--- 1 file changed, 47 insertions(+), 3 deletions(-) diff --git a/ui/easydiffusion/app.py b/ui/easydiffusion/app.py index d556dd6f..d245bf39 100644 --- a/ui/easydiffusion/app.py +++ b/ui/easydiffusion/app.py @@ -27,15 +27,21 @@ logging.basicConfig( SD_DIR = os.getcwd() SD_UI_DIR = os.getenv("SD_UI_PATH", None) -sys.path.append(os.path.dirname(SD_UI_DIR)) CONFIG_DIR = os.path.abspath(os.path.join(SD_UI_DIR, "..", "scripts")) MODELS_DIR = os.path.abspath(os.path.join(SD_DIR, "..", "models")) -USER_UI_PLUGINS_DIR = os.path.abspath(os.path.join(SD_DIR, "..", "plugins", "ui")) -CORE_UI_PLUGINS_DIR = os.path.abspath(os.path.join(SD_UI_DIR, "plugins", "ui")) +USER_PLUGINS_DIR = os.path.abspath(os.path.join(SD_DIR, "..", "plugins")) +CORE_PLUGINS_DIR = os.path.abspath(os.path.join(SD_UI_DIR, "plugins")) + +USER_UI_PLUGINS_DIR = os.path.join(USER_PLUGINS_DIR, "ui") +CORE_UI_PLUGINS_DIR = os.path.join(CORE_PLUGINS_DIR, "ui") +USER_SERVER_PLUGINS_DIR = os.path.join(USER_PLUGINS_DIR, "server") UI_PLUGINS_SOURCES = ((CORE_UI_PLUGINS_DIR, "core"), (USER_UI_PLUGINS_DIR, "user")) +sys.path.append(os.path.dirname(SD_UI_DIR)) +sys.path.append(USER_SERVER_PLUGINS_DIR) + OUTPUT_DIRNAME = "Stable Diffusion UI" # in the user's home folder PRESERVE_CONFIG_VARS = ["FORCE_FULL_PRECISION"] TASK_TTL = 15 * 60 # Discard last session's task timeout @@ -51,6 +57,9 @@ APP_CONFIG_DEFAULTS = { def init(): os.makedirs(USER_UI_PLUGINS_DIR, exist_ok=True) + os.makedirs(USER_SERVER_PLUGINS_DIR, exist_ok=True) + + load_server_plugins() update_render_threads() @@ -171,6 +180,41 @@ def getUIPlugins(): return plugins +def load_server_plugins(): + if not os.path.exists(USER_SERVER_PLUGINS_DIR): + return + + import importlib + + def load_plugin(file): + mod_path = file.replace(".py", "") + return importlib.import_module(mod_path) + + def apply_plugin(file, plugin): + if hasattr(plugin, "get_cond_and_uncond"): + import sdkit.generate.image_generator + + sdkit.generate.image_generator.get_cond_and_uncond = plugin.get_cond_and_uncond + log.info(f"Overridden get_cond_and_uncond with the one in the server plugin: {file}") + + for file in os.listdir(USER_SERVER_PLUGINS_DIR): + file_path = os.path.join(USER_SERVER_PLUGINS_DIR, file) + if (not os.path.isdir(file_path) and not file_path.endswith("_plugin.py")) or ( + os.path.isdir(file_path) and not file_path.endswith("_plugin") + ): + continue + + try: + log.info(f"Loading server plugin: {file}") + mod = load_plugin(file) + + log.info(f"Applying server plugin: {file}") + apply_plugin(file, mod) + except: + log.warn(f"Error while loading a server plugin") + log.warn(traceback.format_exc()) + + def getIPConfig(): try: ips = socket.gethostbyname_ex(socket.gethostname()) From 23f9bcb38ba9414e262f65db9a67b03e8977d42f Mon Sep 17 00:00:00 2001 From: cmdr2 Date: Fri, 17 Feb 2023 15:22:59 +0530 Subject: [PATCH 21/27] Upgrade sdkit, moving the experimental parser into a plugin --- scripts/on_sd_start.bat | 2 +- scripts/on_sd_start.sh | 2 +- 2 files changed, 2 insertions(+), 2 deletions(-) diff --git a/scripts/on_sd_start.bat b/scripts/on_sd_start.bat index 0a5209ce..c0f824bb 100644 --- a/scripts/on_sd_start.bat +++ b/scripts/on_sd_start.bat @@ -92,7 +92,7 @@ if "%ERRORLEVEL%" EQU "0" ( set PYTHONNOUSERSITE=1 set PYTHONPATH=%INSTALL_ENV_DIR%\lib\site-packages - call python -m pip install --upgrade sdkit==1.0.39 -q || ( + call python -m pip install --upgrade sdkit==1.0.40 -q || ( echo "Error updating sdkit" ) ) diff --git a/scripts/on_sd_start.sh b/scripts/on_sd_start.sh index 3203d098..0e617e5e 100755 --- a/scripts/on_sd_start.sh +++ b/scripts/on_sd_start.sh @@ -81,7 +81,7 @@ if python ../scripts/check_modules.py sdkit sdkit.models ldm transformers numpy export PYTHONNOUSERSITE=1 export PYTHONPATH="$INSTALL_ENV_DIR/lib/python3.8/site-packages" - python -m pip install --upgrade sdkit==1.0.39 -q + python -m pip install --upgrade sdkit==1.0.40 -q fi else echo "Installing sdkit: https://pypi.org/project/sdkit/" From a36fb55b0574986f7720e95da12644f41a38a9e1 Mon Sep 17 00:00:00 2001 From: JeLuF Date: Fri, 17 Feb 2023 10:53:51 +0100 Subject: [PATCH 22/27] Remove superfluous CarriageReturn \r\n creates CR CR LF in python, which confuses the Windows batch processor. With only \n, adding the config line for FP32 works as expected: 10:50:43.659 WARNING cuda:0 forcing full precision on this GPU, to avoid green images. GPU detected: NVIDIA GeForce GTX 1060 6GB --- ui/easydiffusion/app.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/ui/easydiffusion/app.py b/ui/easydiffusion/app.py index d556dd6f..ebced964 100644 --- a/ui/easydiffusion/app.py +++ b/ui/easydiffusion/app.py @@ -105,7 +105,7 @@ def setConfig(config): if len(config_bat) > 0: with open(config_bat_path, "w", encoding="utf-8") as f: - f.write("\r\n".join(config_bat)) + f.write("\n".join(config_bat)) except: log.error(traceback.format_exc()) From 620f521e0c9bab432e3edbede1c6d1e6bc693268 Mon Sep 17 00:00:00 2001 From: cmdr2 Date: Fri, 17 Feb 2023 15:25:49 +0530 Subject: [PATCH 23/27] changelog --- CHANGES.md | 2 +- ui/index.html | 2 +- 2 files changed, 2 insertions(+), 2 deletions(-) diff --git a/CHANGES.md b/CHANGES.md index 1aa1613d..6b8be309 100644 --- a/CHANGES.md +++ b/CHANGES.md @@ -19,8 +19,8 @@ Our focus continues to remain on an easy installation experience, and an easy user-interface. While still remaining pretty powerful, in terms of features and speed. ### Detailed changelog +* 2.5.19 - 17 Feb 2023 - Initial support for server-side plugins. Currently supports overriding the `get_cond_and_uncond()` function. * 2.5.18 - 17 Feb 2023 - 5 new samplers! UniPC samplers, some of which produce images in less than 15 steps. Thanks @Schorny. -* 2.5.17 - 16 Feb 2023 - Experimental parser for prompts. Supports greater control over the weights assigned to prompt tokens. The experimental parser will be used only if the prompt starts with an exclamation mark, e.g. `!photo of an astronaut`. Thanks @madrang. * 2.5.16 - 13 Feb 2023 - Searchable dropdown for models. This is useful if you have a LOT of models. You can type part of the model name, to auto-search through your models. Thanks @patriceac for the feature, and @AssassinJN for help in UI tweaks! * 2.5.16 - 13 Feb 2023 - Lots of fixes and improvements to the installer. First round of changes to add Mac support. Thanks @JeLuf. * 2.5.16 - 13 Feb 2023 - UI bug fixes for the inpainter editor. Thanks @patriceac. diff --git a/ui/index.html b/ui/index.html index dbc31dbf..09b718f3 100644 --- a/ui/index.html +++ b/ui/index.html @@ -26,7 +26,7 @@
    From d8dec3e56a9049b69d21238a3b287ce8f538bdb3 Mon Sep 17 00:00:00 2001 From: patriceac <48073125+patriceac@users.noreply.github.com> Date: Fri, 17 Feb 2023 16:40:16 -0800 Subject: [PATCH 24/27] Fix the chevron enabled state upon refresh Fix for my previous PR. Apologies for this silly copy/paste mistake. https://discord.com/channels/1014774730907209781/1014780368890630164/1075782233970970704 --- ui/media/js/searchable-models.js | 6 +++--- 1 file changed, 3 insertions(+), 3 deletions(-) diff --git a/ui/media/js/searchable-models.js b/ui/media/js/searchable-models.js index 86c5c068..c9cf3ea9 100644 --- a/ui/media/js/searchable-models.js +++ b/ui/media/js/searchable-models.js @@ -529,9 +529,9 @@ class ModelDropdown ) this.modelFilter.classList.add('model-selector') this.modelFilterArrow = document.querySelector(`#${this.modelFilter.id}-model-filter-arrow`) - // if (this.modelFilterArrow) { - // this.modelFilterArrow.style.color = state ? 'dimgray' : '' - // } + if (this.modelFilterArrow) { + this.modelFilterArrow.style.color = this.modelFilter.disabled ? 'dimgray' : '' + } this.modelList = document.querySelector(`#${this.modelFilter.id}-model-list`) this.modelResult = document.querySelector(`#${this.modelFilter.id}-model-result`) this.modelNoResult = document.querySelector(`#${this.modelFilter.id}-model-no-result`) From 5fffb82b16f58345fe254e89db0335e2978f5441 Mon Sep 17 00:00:00 2001 From: cmdr2 Date: Sat, 18 Feb 2023 14:17:28 +0530 Subject: [PATCH 25/27] Pin the version of stable-diffusion-sdkit used, to avoid untested releases from getting used --- scripts/on_sd_start.bat | 2 +- scripts/on_sd_start.sh | 2 +- 2 files changed, 2 insertions(+), 2 deletions(-) diff --git a/scripts/on_sd_start.bat b/scripts/on_sd_start.bat index c0f824bb..ba1a1d97 100644 --- a/scripts/on_sd_start.bat +++ b/scripts/on_sd_start.bat @@ -113,7 +113,7 @@ if "%ERRORLEVEL%" EQU "0" ( call python -c "from importlib.metadata import version; print('sdkit version:', version('sdkit'))" @rem upgrade stable-diffusion-sdkit -call python -m pip install --upgrade stable-diffusion-sdkit -q || ( +call python -m pip install --upgrade stable-diffusion-sdkit==2.1.1 -q || ( echo "Error updating stable-diffusion-sdkit" ) call python -c "from importlib.metadata import version; print('stable-diffusion version:', version('stable-diffusion-sdkit'))" diff --git a/scripts/on_sd_start.sh b/scripts/on_sd_start.sh index 0e617e5e..4483dad9 100755 --- a/scripts/on_sd_start.sh +++ b/scripts/on_sd_start.sh @@ -99,7 +99,7 @@ fi python -c "from importlib.metadata import version; print('sdkit version:', version('sdkit'))" # upgrade stable-diffusion-sdkit -python -m pip install --upgrade stable-diffusion-sdkit -q +python -m pip install --upgrade stable-diffusion-sdkit==2.1.1 -q python -c "from importlib.metadata import version; print('stable-diffusion version:', version('stable-diffusion-sdkit'))" # install rich From b43f9fc4ee42b14b5b051eecfa3030c06a9198a9 Mon Sep 17 00:00:00 2001 From: cmdr2 Date: Sat, 18 Feb 2023 14:30:37 +0530 Subject: [PATCH 26/27] Upgrade stable-diffusion-sdkit to 2.1.3, to use transformers 4.26.1 --- scripts/on_sd_start.bat | 2 +- scripts/on_sd_start.sh | 2 +- 2 files changed, 2 insertions(+), 2 deletions(-) diff --git a/scripts/on_sd_start.bat b/scripts/on_sd_start.bat index ba1a1d97..2f4cad92 100644 --- a/scripts/on_sd_start.bat +++ b/scripts/on_sd_start.bat @@ -113,7 +113,7 @@ if "%ERRORLEVEL%" EQU "0" ( call python -c "from importlib.metadata import version; print('sdkit version:', version('sdkit'))" @rem upgrade stable-diffusion-sdkit -call python -m pip install --upgrade stable-diffusion-sdkit==2.1.1 -q || ( +call python -m pip install --upgrade stable-diffusion-sdkit==2.1.3 -q || ( echo "Error updating stable-diffusion-sdkit" ) call python -c "from importlib.metadata import version; print('stable-diffusion version:', version('stable-diffusion-sdkit'))" diff --git a/scripts/on_sd_start.sh b/scripts/on_sd_start.sh index 4483dad9..3afb19ba 100755 --- a/scripts/on_sd_start.sh +++ b/scripts/on_sd_start.sh @@ -99,7 +99,7 @@ fi python -c "from importlib.metadata import version; print('sdkit version:', version('sdkit'))" # upgrade stable-diffusion-sdkit -python -m pip install --upgrade stable-diffusion-sdkit==2.1.1 -q +python -m pip install --upgrade stable-diffusion-sdkit==2.1.3 -q python -c "from importlib.metadata import version; print('stable-diffusion version:', version('stable-diffusion-sdkit'))" # install rich From e7a2dfa57fd79ed338db7e559ea5d7ace9cf9a2c Mon Sep 17 00:00:00 2001 From: cmdr2 Date: Sat, 18 Feb 2023 14:31:39 +0530 Subject: [PATCH 27/27] changelog --- CHANGES.md | 1 + ui/index.html | 2 +- 2 files changed, 2 insertions(+), 1 deletion(-) diff --git a/CHANGES.md b/CHANGES.md index 6b8be309..b2951d97 100644 --- a/CHANGES.md +++ b/CHANGES.md @@ -19,6 +19,7 @@ Our focus continues to remain on an easy installation experience, and an easy user-interface. While still remaining pretty powerful, in terms of features and speed. ### Detailed changelog +* 2.5.20 - 18 Feb 2023 - Upgrade the version of transformers used. * 2.5.19 - 17 Feb 2023 - Initial support for server-side plugins. Currently supports overriding the `get_cond_and_uncond()` function. * 2.5.18 - 17 Feb 2023 - 5 new samplers! UniPC samplers, some of which produce images in less than 15 steps. Thanks @Schorny. * 2.5.16 - 13 Feb 2023 - Searchable dropdown for models. This is useful if you have a LOT of models. You can type part of the model name, to auto-search through your models. Thanks @patriceac for the feature, and @AssassinJN for help in UI tweaks! diff --git a/ui/index.html b/ui/index.html index 09b718f3..bff7b128 100644 --- a/ui/index.html +++ b/ui/index.html @@ -26,7 +26,7 @@