diff --git a/ui/easydiffusion/app.py b/ui/easydiffusion/app.py
index 4369a488..d556dd6f 100644
--- a/ui/easydiffusion/app.py
+++ b/ui/easydiffusion/app.py
@@ -7,7 +7,7 @@ import logging
 import shlex
 
 from rich.logging import RichHandler
-from sdkit.utils import log as sdkit_log # hack, so we can overwrite the log config
+from sdkit.utils import log as sdkit_log  # hack, so we can overwrite the log config
 
 from easydiffusion import task_manager
 from easydiffusion.utils import log
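The next hunk rewrites the logger bootstrap: it strips whatever handlers imported libraries attached to the root logger, then installs a single RichHandler. A minimal, self-contained sketch of that takeover pattern (assuming only that the `rich` package is installed; the snippet mirrors the diff but is not part of the patch):

import logging

from rich.logging import RichHandler

# Drop any handlers that an imported library may have registered on the root
# logger, so the basicConfig() call below is the only source of truth.
for handler in logging.root.handlers[:]:
    logging.root.removeHandler(handler)

logging.basicConfig(
    level=logging.INFO,
    format="%(asctime)s.%(msecs)03d %(levelname)s %(threadName)s %(message)s",
    datefmt="%X",
    handlers=[RichHandler(markup=True, rich_tracebacks=False, show_time=False, show_level=False)],
)

logging.getLogger().info("hello from [bold]rich[/bold] markup")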
@@ -16,83 +16,86 @@ from easydiffusion.utils import log
 for handler in logging.root.handlers[:]:
     logging.root.removeHandler(handler)
 
-LOG_FORMAT = '%(asctime)s.%(msecs)03d %(levelname)s %(threadName)s %(message)s'
+LOG_FORMAT = "%(asctime)s.%(msecs)03d %(levelname)s %(threadName)s %(message)s"
 logging.basicConfig(
-        level=logging.INFO,
-        format=LOG_FORMAT,
-        datefmt="%X",
-        handlers=[RichHandler(markup=True, rich_tracebacks=False, show_time=False, show_level=False)],
+    level=logging.INFO,
+    format=LOG_FORMAT,
+    datefmt="%X",
+    handlers=[RichHandler(markup=True, rich_tracebacks=False, show_time=False, show_level=False)],
 )
 
 SD_DIR = os.getcwd()
 
-SD_UI_DIR = os.getenv('SD_UI_PATH', None)
+SD_UI_DIR = os.getenv("SD_UI_PATH", None)
 sys.path.append(os.path.dirname(SD_UI_DIR))
 
-CONFIG_DIR = os.path.abspath(os.path.join(SD_UI_DIR, '..', 'scripts'))
-MODELS_DIR = os.path.abspath(os.path.join(SD_DIR, '..', 'models'))
+CONFIG_DIR = os.path.abspath(os.path.join(SD_UI_DIR, "..", "scripts"))
+MODELS_DIR = os.path.abspath(os.path.join(SD_DIR, "..", "models"))
 
-USER_UI_PLUGINS_DIR = os.path.abspath(os.path.join(SD_DIR, '..', 'plugins', 'ui'))
-CORE_UI_PLUGINS_DIR = os.path.abspath(os.path.join(SD_UI_DIR, 'plugins', 'ui'))
-UI_PLUGINS_SOURCES = ((CORE_UI_PLUGINS_DIR, 'core'), (USER_UI_PLUGINS_DIR, 'user'))
+USER_UI_PLUGINS_DIR = os.path.abspath(os.path.join(SD_DIR, "..", "plugins", "ui"))
+CORE_UI_PLUGINS_DIR = os.path.abspath(os.path.join(SD_UI_DIR, "plugins", "ui"))
+UI_PLUGINS_SOURCES = ((CORE_UI_PLUGINS_DIR, "core"), (USER_UI_PLUGINS_DIR, "user"))
 
-OUTPUT_DIRNAME = "Stable Diffusion UI" # in the user's home folder
-PRESERVE_CONFIG_VARS = ['FORCE_FULL_PRECISION']
-TASK_TTL = 15 * 60 # Discard last session's task timeout
+OUTPUT_DIRNAME = "Stable Diffusion UI"  # in the user's home folder
+PRESERVE_CONFIG_VARS = ["FORCE_FULL_PRECISION"]
+TASK_TTL = 15 * 60  # Discard last session's task timeout
 APP_CONFIG_DEFAULTS = {
     # auto: selects the cuda device with the most free memory, cuda: use the currently active cuda device.
-    'render_devices': 'auto', # valid entries: 'auto', 'cpu' or 'cuda:N' (where N is a GPU index)
-    'update_branch': 'main',
-    'ui': {
-        'open_browser_on_start': True,
-    }
+    "render_devices": "auto",  # valid entries: 'auto', 'cpu' or 'cuda:N' (where N is a GPU index)
+    "update_branch": "main",
+    "ui": {
+        "open_browser_on_start": True,
+    },
 }
+
 
 def init():
     os.makedirs(USER_UI_PLUGINS_DIR, exist_ok=True)
 
     update_render_threads()
+
 
 def getConfig(default_val=APP_CONFIG_DEFAULTS):
     try:
-        config_json_path = os.path.join(CONFIG_DIR, 'config.json')
+        config_json_path = os.path.join(CONFIG_DIR, "config.json")
         if not os.path.exists(config_json_path):
             config = default_val
         else:
-            with open(config_json_path, 'r', encoding='utf-8') as f:
+            with open(config_json_path, "r", encoding="utf-8") as f:
                 config = json.load(f)
-            if 'net' not in config:
-                config['net'] = {}
-            if os.getenv('SD_UI_BIND_PORT') is not None:
-                config['net']['listen_port'] = int(os.getenv('SD_UI_BIND_PORT'))
+            if "net" not in config:
+                config["net"] = {}
+            if os.getenv("SD_UI_BIND_PORT") is not None:
+                config["net"]["listen_port"] = int(os.getenv("SD_UI_BIND_PORT"))
             else:
-                config['net']['listen_port'] = 9000
-            if os.getenv('SD_UI_BIND_IP') is not None:
-                config['net']['listen_to_network'] = (os.getenv('SD_UI_BIND_IP') == '0.0.0.0')
+                config["net"]["listen_port"] = 9000
+            if os.getenv("SD_UI_BIND_IP") is not None:
+                config["net"]["listen_to_network"] = os.getenv("SD_UI_BIND_IP") == "0.0.0.0"
             else:
-                config['net']['listen_to_network'] = True
+                config["net"]["listen_to_network"] = True
         return config
     except Exception as e:
         log.warn(traceback.format_exc())
         return default_val
 
+
 def setConfig(config):
-    try: # config.json
-        config_json_path = os.path.join(CONFIG_DIR, 'config.json')
-        with open(config_json_path, 'w', encoding='utf-8') as f:
+    try:  # config.json
+        config_json_path = os.path.join(CONFIG_DIR, "config.json")
+        with open(config_json_path, "w", encoding="utf-8") as f:
             json.dump(config, f)
     except:
         log.error(traceback.format_exc())
 
-    try: # config.bat
-        config_bat_path = os.path.join(CONFIG_DIR, 'config.bat')
+    try:  # config.bat
+        config_bat_path = os.path.join(CONFIG_DIR, "config.bat")
         config_bat = []
 
-        if 'update_branch' in config:
+        if "update_branch" in config:
             config_bat.append(f"@set update_branch={config['update_branch']}")
 
         config_bat.append(f"@set SD_UI_BIND_PORT={config['net']['listen_port']}")
-        bind_ip = '0.0.0.0' if config['net']['listen_to_network'] else '127.0.0.1'
+        bind_ip = "0.0.0.0" if config["net"]["listen_to_network"] else "127.0.0.1"
         config_bat.append(f"@set SD_UI_BIND_IP={bind_ip}")
 
         # Preserve these variables if they are set
@@ -101,20 +104,20 @@ def setConfig(config):
                 config_bat.append(f"@set {var}={os.getenv(var)}")
 
         if len(config_bat) > 0:
-            with open(config_bat_path, 'w', encoding='utf-8') as f:
-                f.write('\r\n'.join(config_bat))
+            with open(config_bat_path, "w", encoding="utf-8") as f:
+                f.write("\r\n".join(config_bat))
     except:
         log.error(traceback.format_exc())
 
-    try: # config.sh
-        config_sh_path = os.path.join(CONFIG_DIR, 'config.sh')
-        config_sh = ['#!/bin/bash']
+    try:  # config.sh
+        config_sh_path = os.path.join(CONFIG_DIR, "config.sh")
+        config_sh = ["#!/bin/bash"]
 
-        if 'update_branch' in config:
+        if "update_branch" in config:
             config_sh.append(f"export update_branch={config['update_branch']}")
 
         config_sh.append(f"export SD_UI_BIND_PORT={config['net']['listen_port']}")
-        bind_ip = '0.0.0.0' if config['net']['listen_to_network'] else '127.0.0.1'
+        bind_ip = "0.0.0.0" if config["net"]["listen_to_network"] else "127.0.0.1"
else "127.0.0.1" config_sh.append(f"export SD_UI_BIND_IP={bind_ip}") # Preserve these variables if they are set @@ -123,47 +126,51 @@ def setConfig(config): config_bat.append(f'export {var}="{shlex.quote(os.getenv(var))}"') if len(config_sh) > 1: - with open(config_sh_path, 'w', encoding='utf-8') as f: - f.write('\n'.join(config_sh)) + with open(config_sh_path, "w", encoding="utf-8") as f: + f.write("\n".join(config_sh)) except: log.error(traceback.format_exc()) + def save_to_config(ckpt_model_name, vae_model_name, hypernetwork_model_name, vram_usage_level): config = getConfig() - if 'model' not in config: - config['model'] = {} + if "model" not in config: + config["model"] = {} - config['model']['stable-diffusion'] = ckpt_model_name - config['model']['vae'] = vae_model_name - config['model']['hypernetwork'] = hypernetwork_model_name + config["model"]["stable-diffusion"] = ckpt_model_name + config["model"]["vae"] = vae_model_name + config["model"]["hypernetwork"] = hypernetwork_model_name if vae_model_name is None or vae_model_name == "": - del config['model']['vae'] + del config["model"]["vae"] if hypernetwork_model_name is None or hypernetwork_model_name == "": - del config['model']['hypernetwork'] + del config["model"]["hypernetwork"] - config['vram_usage_level'] = vram_usage_level + config["vram_usage_level"] = vram_usage_level setConfig(config) + def update_render_threads(): config = getConfig() - render_devices = config.get('render_devices', 'auto') - active_devices = task_manager.get_devices()['active'].keys() + render_devices = config.get("render_devices", "auto") + active_devices = task_manager.get_devices()["active"].keys() - log.debug(f'requesting for render_devices: {render_devices}') + log.debug(f"requesting for render_devices: {render_devices}") task_manager.update_render_threads(render_devices, active_devices) + def getUIPlugins(): plugins = [] for plugins_dir, dir_prefix in UI_PLUGINS_SOURCES: for file in os.listdir(plugins_dir): - if file.endswith('.plugin.js'): - plugins.append(f'/plugins/{dir_prefix}/{file}') + if file.endswith(".plugin.js"): + plugins.append(f"/plugins/{dir_prefix}/{file}") return plugins + def getIPConfig(): try: ips = socket.gethostbyname_ex(socket.gethostname()) @@ -173,10 +180,13 @@ def getIPConfig(): log.exception(e) return [] + def open_browser(): config = getConfig() - ui = config.get('ui', {}) - net = config.get('net', {'listen_port':9000}) - port = net.get('listen_port', 9000) - if ui.get('open_browser_on_start', True): - import webbrowser; webbrowser.open(f"http://localhost:{port}") + ui = config.get("ui", {}) + net = config.get("net", {"listen_port": 9000}) + port = net.get("listen_port", 9000) + if ui.get("open_browser_on_start", True): + import webbrowser + + webbrowser.open(f"http://localhost:{port}") diff --git a/ui/easydiffusion/device_manager.py b/ui/easydiffusion/device_manager.py index 786512ea..69086577 100644 --- a/ui/easydiffusion/device_manager.py +++ b/ui/easydiffusion/device_manager.py @@ -5,45 +5,54 @@ import re from easydiffusion.utils import log -''' +""" Set `FORCE_FULL_PRECISION` in the environment variables, or in `config.bat`/`config.sh` to set full precision (i.e. float32). Otherwise the models will load at half-precision (i.e. float16). Half-precision is fine most of the time. Full precision is only needed for working around GPU bugs (like NVIDIA 16xx GPUs). 
diff --git a/ui/easydiffusion/device_manager.py b/ui/easydiffusion/device_manager.py
index 786512ea..69086577 100644
--- a/ui/easydiffusion/device_manager.py
+++ b/ui/easydiffusion/device_manager.py
@@ -5,45 +5,54 @@ import re
 
 from easydiffusion.utils import log
 
-'''
+"""
 Set `FORCE_FULL_PRECISION` in the environment variables, or in `config.bat`/`config.sh` to set full precision (i.e. float32).
 Otherwise the models will load at half-precision (i.e. float16).
 
 Half-precision is fine most of the time. Full precision is only needed for working around GPU bugs (like NVIDIA 16xx GPUs).
-'''
+"""
 
-COMPARABLE_GPU_PERCENTILE = 0.65 # if a GPU's free_mem is within this % of the GPU with the most free_mem, it will be picked
+COMPARABLE_GPU_PERCENTILE = (
+    0.65  # if a GPU's free_mem is within this % of the GPU with the most free_mem, it will be picked
+)
 
 mem_free_threshold = 0
 
+
 def get_device_delta(render_devices, active_devices):
-    '''
+    """
     render_devices: 'cpu', or 'auto' or ['cuda:N'...]
     active_devices: ['cpu', 'cuda:N'...]
-    '''
+    """
 
-    if render_devices in ('cpu', 'auto'):
+    if render_devices in ("cpu", "auto"):
         render_devices = [render_devices]
     elif render_devices is not None:
         if isinstance(render_devices, str):
             render_devices = [render_devices]
         if isinstance(render_devices, list) and len(render_devices) > 0:
-            render_devices = list(filter(lambda x: x.startswith('cuda:'), render_devices))
+            render_devices = list(filter(lambda x: x.startswith("cuda:"), render_devices))
             if len(render_devices) == 0:
-                raise Exception('Invalid render_devices value in config.json. Valid: {"render_devices": ["cuda:0", "cuda:1"...]}, or {"render_devices": "cpu"} or {"render_devices": "auto"}')
+                raise Exception(
+                    'Invalid render_devices value in config.json. Valid: {"render_devices": ["cuda:0", "cuda:1"...]}, or {"render_devices": "cpu"} or {"render_devices": "auto"}'
+                )
 
             render_devices = list(filter(lambda x: is_device_compatible(x), render_devices))
             if len(render_devices) == 0:
-                raise Exception('Sorry, none of the render_devices configured in config.json are compatible with Stable Diffusion')
+                raise Exception(
+                    "Sorry, none of the render_devices configured in config.json are compatible with Stable Diffusion"
+                )
         else:
-            raise Exception('Invalid render_devices value in config.json. Valid: {"render_devices": ["cuda:0", "cuda:1"...]}, or {"render_devices": "cpu"} or {"render_devices": "auto"}')
+            raise Exception(
+                'Invalid render_devices value in config.json. Valid: {"render_devices": ["cuda:0", "cuda:1"...]}, or {"render_devices": "cpu"} or {"render_devices": "auto"}'
+            )
     else:
-        render_devices = ['auto']
+        render_devices = ["auto"]
 
-    if 'auto' in render_devices:
+    if "auto" in render_devices:
         render_devices = auto_pick_devices(active_devices)
-        if 'cpu' in render_devices:
-            log.warn('WARNING: Could not find a compatible GPU. Using the CPU, but this will be very slow!')
+        if "cpu" in render_devices:
+            log.warn("WARNING: Could not find a compatible GPU. Using the CPU, but this will be very slow!")
 
     active_devices = set(active_devices)
     render_devices = set(render_devices)
@@ -53,19 +62,21 @@ def get_device_delta(render_devices, active_devices):
 
     return devices_to_start, devices_to_stop
 
+
 def auto_pick_devices(currently_active_devices):
     global mem_free_threshold
 
-    if not torch.cuda.is_available(): return ['cpu']
+    if not torch.cuda.is_available():
+        return ["cpu"]
 
     device_count = torch.cuda.device_count()
     if device_count == 1:
-        return ['cuda:0'] if is_device_compatible('cuda:0') else ['cpu']
+        return ["cuda:0"] if is_device_compatible("cuda:0") else ["cpu"]
 
-    log.debug('Autoselecting GPU. Using most free memory.')
+    log.debug("Autoselecting GPU. Using most free memory.")
     devices = []
     for device in range(device_count):
-        device = f'cuda:{device}'
+        device = f"cuda:{device}"
         if not is_device_compatible(device):
             continue
 
@@ -73,11 +84,13 @@ def auto_pick_devices(currently_active_devices):
         mem_free /= float(10**9)
         mem_total /= float(10**9)
         device_name = torch.cuda.get_device_name(device)
-        log.debug(f'{device} detected: {device_name} - Memory (free/total): {round(mem_free, 2)}Gb / {round(mem_total, 2)}Gb')
-        devices.append({'device': device, 'device_name': device_name, 'mem_free': mem_free})
+        log.debug(
+            f"{device} detected: {device_name} - Memory (free/total): {round(mem_free, 2)}Gb / {round(mem_total, 2)}Gb"
+        )
+        devices.append({"device": device, "device_name": device_name, "mem_free": mem_free})
 
-    devices.sort(key=lambda x:x['mem_free'], reverse=True)
-    max_mem_free = devices[0]['mem_free']
+    devices.sort(key=lambda x: x["mem_free"], reverse=True)
+    max_mem_free = devices[0]["mem_free"]
     curr_mem_free_threshold = COMPARABLE_GPU_PERCENTILE * max_mem_free
     mem_free_threshold = max(curr_mem_free_threshold, mem_free_threshold)
 
@@ -87,23 +100,26 @@ def auto_pick_devices(currently_active_devices):
     # always be very low (since their VRAM contains the model).
     # These already-running devices probably aren't terrible, since they were picked in the past.
     # Worst case, the user can restart the program and that'll get rid of them.
-    devices = list(filter((lambda x: x['mem_free'] > mem_free_threshold or x['device'] in currently_active_devices), devices))
-    devices = list(map(lambda x: x['device'], devices))
+    devices = list(
+        filter((lambda x: x["mem_free"] > mem_free_threshold or x["device"] in currently_active_devices), devices)
+    )
+    devices = list(map(lambda x: x["device"], devices))
     return devices
 
+
 def device_init(context, device):
-    '''
+    """
     This function assumes the 'device' has already been verified to be compatible.
     `get_device_delta()` has already filtered out incompatible devices.
-    '''
+    """
 
-    validate_device_id(device, log_prefix='device_init')
+    validate_device_id(device, log_prefix="device_init")
 
-    if device == 'cpu':
-        context.device = 'cpu'
+    if device == "cpu":
+        context.device = "cpu"
         context.device_name = get_processor_name()
         context.half_precision = False
-        log.debug(f'Render device CPU available as {context.device_name}')
+        log.debug(f"Render device CPU available as {context.device_name}")
         return
 
     context.device_name = torch.cuda.get_device_name(device)
@@ -111,7 +127,7 @@ def device_init(context, device):
 
     # Force full precision on 1660 and 1650 NVIDIA cards to avoid creating green images
     if needs_to_force_full_precision(context):
-        log.warn(f'forcing full precision on this GPU, to avoid green images. GPU detected: {context.device_name}')
+        log.warn(f"forcing full precision on this GPU, to avoid green images. GPU detected: {context.device_name}")
         # Apply force_full_precision now before models are loaded.
         context.half_precision = False
 
@@ -120,72 +136,90 @@ def device_init(context, device):
 
     return
 
+
 def needs_to_force_full_precision(context):
-    if 'FORCE_FULL_PRECISION' in os.environ:
+    if "FORCE_FULL_PRECISION" in os.environ:
         return True
 
     device_name = context.device_name.lower()
-    return (('nvidia' in device_name or 'geforce' in device_name or 'quadro' in device_name) and (' 1660' in device_name or ' 1650' in device_name or ' t400' in device_name or ' t550' in device_name or ' t600' in device_name or ' t1000' in device_name or ' t1200' in device_name or ' t2000' in device_name))
+    return ("nvidia" in device_name or "geforce" in device_name or "quadro" in device_name) and (
+        " 1660" in device_name
+        or " 1650" in device_name
+        or " t400" in device_name
+        or " t550" in device_name
+        or " t600" in device_name
+        or " t1000" in device_name
+        or " t1200" in device_name
+        or " t2000" in device_name
+    )
 
+
 def get_max_vram_usage_level(device):
-    if device != 'cpu':
+    if device != "cpu":
         _, mem_total = torch.cuda.mem_get_info(device)
         mem_total /= float(10**9)
 
         if mem_total < 4.5:
-            return 'low'
+            return "low"
         elif mem_total < 6.5:
-            return 'balanced'
+            return "balanced"
 
-    return 'high'
+    return "high"
 
-def validate_device_id(device, log_prefix=''):
+
+def validate_device_id(device, log_prefix=""):
     def is_valid():
         if not isinstance(device, str):
             return False
-        if device == 'cpu':
+        if device == "cpu":
             return True
-        if not device.startswith('cuda:') or not device[5:].isnumeric():
+        if not device.startswith("cuda:") or not device[5:].isnumeric():
             return False
         return True
 
     if not is_valid():
-        raise EnvironmentError(f"{log_prefix}: device id should be 'cpu', or 'cuda:N' (where N is an integer index for the GPU). Got: {device}")
+        raise EnvironmentError(
+            f"{log_prefix}: device id should be 'cpu', or 'cuda:N' (where N is an integer index for the GPU). Got: {device}"
+        )
 
+
 def is_device_compatible(device):
-    '''
+    """
     Returns True/False, and prints any compatibility errors
-    '''
-    # static variable "history".
-    is_device_compatible.history = getattr(is_device_compatible, 'history', {})
+    """
+    # static variable "history".
+    is_device_compatible.history = getattr(is_device_compatible, "history", {})
     try:
-        validate_device_id(device, log_prefix='is_device_compatible')
+        validate_device_id(device, log_prefix="is_device_compatible")
     except:
         log.error(str(e))
         return False
 
-    if device == 'cpu': return True
+    if device == "cpu":
+        return True
+
     # Memory check
     try:
         _, mem_total = torch.cuda.mem_get_info(device)
         mem_total /= float(10**9)
         if mem_total < 3.0:
             if is_device_compatible.history.get(device) == None:
-                log.warn(f'GPU {device} with less than 3 GB of VRAM is not compatible with Stable Diffusion')
-                is_device_compatible.history[device] = 1
+                log.warn(f"GPU {device} with less than 3 GB of VRAM is not compatible with Stable Diffusion")
+                is_device_compatible.history[device] = 1
             return False
     except RuntimeError as e:
         log.error(str(e))
         return False
     return True
 
+
 def get_processor_name():
     try:
         import platform, subprocess
+
         if platform.system() == "Windows":
             return platform.processor()
         elif platform.system() == "Darwin":
-            os.environ['PATH'] = os.environ['PATH'] + os.pathsep + '/usr/sbin'
+            os.environ["PATH"] = os.environ["PATH"] + os.pathsep + "/usr/sbin"
             command = "sysctl -n machdep.cpu.brand_string"
             return subprocess.check_output(command).strip()
         elif platform.system() == "Linux":
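auto_pick_devices() above keeps every GPU whose free memory is within COMPARABLE_GPU_PERCENTILE (65%) of the GPU with the most free memory. A dry-run of that filter on invented numbers, no CUDA required:

COMPARABLE_GPU_PERCENTILE = 0.65

# Hypothetical survey results, in GB of free VRAM per device.
devices = [
    {"device": "cuda:0", "mem_free": 10.2},
    {"device": "cuda:1", "mem_free": 7.9},
    {"device": "cuda:2", "mem_free": 2.1},
]
devices.sort(key=lambda x: x["mem_free"], reverse=True)
threshold = COMPARABLE_GPU_PERCENTILE * devices[0]["mem_free"]  # 6.63 GB

picked = [d["device"] for d in devices if d["mem_free"] > threshold]
print(picked)  # ['cuda:0', 'cuda:1'] - cuda:2 is too far behind the leader

In the real function the threshold is also kept as a global high-water mark, so already-busy GPUs (whose free memory is low only because they hold a model) are not evicted between calls.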
diff --git a/ui/easydiffusion/model_manager.py b/ui/easydiffusion/model_manager.py
index 9dd928e1..fb380695 100644
--- a/ui/easydiffusion/model_manager.py
+++ b/ui/easydiffusion/model_manager.py
@@ -8,29 +8,31 @@ from sdkit import Context
 from sdkit.models import load_model, unload_model, get_model_info_from_db, scan_model
 from sdkit.utils import hash_file_quick
 
-KNOWN_MODEL_TYPES = ['stable-diffusion', 'vae', 'hypernetwork', 'gfpgan', 'realesrgan']
+KNOWN_MODEL_TYPES = ["stable-diffusion", "vae", "hypernetwork", "gfpgan", "realesrgan"]
 MODEL_EXTENSIONS = {
-    'stable-diffusion': ['.ckpt', '.safetensors'],
-    'vae': ['.vae.pt', '.ckpt', '.safetensors'],
-    'hypernetwork': ['.pt', '.safetensors'],
-    'gfpgan': ['.pth'],
-    'realesrgan': ['.pth'],
+    "stable-diffusion": [".ckpt", ".safetensors"],
+    "vae": [".vae.pt", ".ckpt", ".safetensors"],
+    "hypernetwork": [".pt", ".safetensors"],
+    "gfpgan": [".pth"],
+    "realesrgan": [".pth"],
 }
 DEFAULT_MODELS = {
-    'stable-diffusion': [ # needed to support the legacy installations
-        'custom-model', # only one custom model file was supported initially, creatively named 'custom-model'
-        'sd-v1-4', # Default fallback.
+    "stable-diffusion": [  # needed to support the legacy installations
+        "custom-model",  # only one custom model file was supported initially, creatively named 'custom-model'
+        "sd-v1-4",  # Default fallback.
     ],
-    'gfpgan': ['GFPGANv1.3'],
-    'realesrgan': ['RealESRGAN_x4plus'],
+    "gfpgan": ["GFPGANv1.3"],
+    "realesrgan": ["RealESRGAN_x4plus"],
 }
-MODELS_TO_LOAD_ON_START = ['stable-diffusion', 'vae', 'hypernetwork']
+MODELS_TO_LOAD_ON_START = ["stable-diffusion", "vae", "hypernetwork"]
 
 known_models = {}
 
+
 def init():
     make_model_folders()
-    getModels() # run this once, to cache the picklescan results
+    getModels()  # run this once, to cache the picklescan results
+
 
 def load_default_models(context: Context):
     set_vram_optimizations(context)
@@ -39,27 +41,28 @@ for model_type in MODELS_TO_LOAD_ON_START:
         context.model_paths[model_type] = resolve_model_to_use(model_type=model_type)
         try:
-            load_model(context, model_type)
+            load_model(context, model_type)
         except Exception as e:
-            log.error(f'[red]Error while loading {model_type} model: {context.model_paths[model_type]}[/red]')
-            log.error(f'[red]Error: {e}[/red]')
-            log.error(f'[red]Consider removing the model from the model folder.[red]')
+            log.error(f"[red]Error while loading {model_type} model: {context.model_paths[model_type]}[/red]")
+            log.error(f"[red]Error: {e}[/red]")
+            log.error(f"[red]Consider removing the model from the model folder.[red]")
 
 
 def unload_all(context: Context):
     for model_type in KNOWN_MODEL_TYPES:
         unload_model(context, model_type)
 
-def resolve_model_to_use(model_name:str=None, model_type:str=None):
+
+def resolve_model_to_use(model_name: str = None, model_type: str = None):
     model_extensions = MODEL_EXTENSIONS.get(model_type, [])
     default_models = DEFAULT_MODELS.get(model_type, [])
     config = app.getConfig()
 
     model_dirs = [os.path.join(app.MODELS_DIR, model_type), app.SD_DIR]
-    if not model_name: # When None try user configured model.
+    if not model_name:  # When None try user configured model.
         # config = getConfig()
-        if 'model' in config and model_type in config['model']:
-            model_name = config['model'][model_type]
+        if "model" in config and model_type in config["model"]:
+            model_name = config["model"][model_type]
 
     if model_name:
         # Check models directory
@@ -84,41 +87,54 @@ def resolve_model_to_use(model_name:str=None, model_type:str=None):
         for model_extension in model_extensions:
             if os.path.exists(default_model_path + model_extension):
                 if model_name is not None:
-                    log.warn(f'Could not find the configured custom model {model_name}{model_extension}. Using the default one: {default_model_path}{model_extension}')
+                    log.warn(
+                        f"Could not find the configured custom model {model_name}{model_extension}. Using the default one: {default_model_path}{model_extension}"
+                    )
                 return default_model_path + model_extension
 
     return None
 
+
 def reload_models_if_necessary(context: Context, task_data: TaskData):
     model_paths_in_req = {
-        'stable-diffusion': task_data.use_stable_diffusion_model,
-        'vae': task_data.use_vae_model,
-        'hypernetwork': task_data.use_hypernetwork_model,
-        'gfpgan': task_data.use_face_correction,
-        'realesrgan': task_data.use_upscale,
+        "stable-diffusion": task_data.use_stable_diffusion_model,
+        "vae": task_data.use_vae_model,
+        "hypernetwork": task_data.use_hypernetwork_model,
+        "gfpgan": task_data.use_face_correction,
+        "realesrgan": task_data.use_upscale,
     }
-    models_to_reload = {model_type: path for model_type, path in model_paths_in_req.items() if context.model_paths.get(model_type) != path}
+    models_to_reload = {
+        model_type: path
+        for model_type, path in model_paths_in_req.items()
+        if context.model_paths.get(model_type) != path
+    }
 
-    if set_vram_optimizations(context): # reload SD
-        models_to_reload['stable-diffusion'] = model_paths_in_req['stable-diffusion']
+    if set_vram_optimizations(context):  # reload SD
+        models_to_reload["stable-diffusion"] = model_paths_in_req["stable-diffusion"]
 
     for model_type, model_path_in_req in models_to_reload.items():
         context.model_paths[model_type] = model_path_in_req
 
         action_fn = unload_model if context.model_paths[model_type] is None else load_model
-        action_fn(context, model_type, scan_model=False) # we've scanned them already
+        action_fn(context, model_type, scan_model=False)  # we've scanned them already
+
 
 def resolve_model_paths(task_data: TaskData):
-    task_data.use_stable_diffusion_model = resolve_model_to_use(task_data.use_stable_diffusion_model, model_type='stable-diffusion')
-    task_data.use_vae_model = resolve_model_to_use(task_data.use_vae_model, model_type='vae')
-    task_data.use_hypernetwork_model = resolve_model_to_use(task_data.use_hypernetwork_model, model_type='hypernetwork')
+    task_data.use_stable_diffusion_model = resolve_model_to_use(
+        task_data.use_stable_diffusion_model, model_type="stable-diffusion"
+    )
+    task_data.use_vae_model = resolve_model_to_use(task_data.use_vae_model, model_type="vae")
+    task_data.use_hypernetwork_model = resolve_model_to_use(task_data.use_hypernetwork_model, model_type="hypernetwork")
+
+    if task_data.use_face_correction:
+        task_data.use_face_correction = resolve_model_to_use(task_data.use_face_correction, "gfpgan")
+    if task_data.use_upscale:
+        task_data.use_upscale = resolve_model_to_use(task_data.use_upscale, "realesrgan")
 
-    if task_data.use_face_correction: task_data.use_face_correction = resolve_model_to_use(task_data.use_face_correction, 'gfpgan')
-    if task_data.use_upscale: task_data.use_upscale = resolve_model_to_use(task_data.use_upscale, 'realesrgan')
 
 def set_vram_optimizations(context: Context):
     config = app.getConfig()
-    vram_usage_level = config.get('vram_usage_level', 'balanced')
+    vram_usage_level = config.get("vram_usage_level", "balanced")
 
     if vram_usage_level != context.vram_usage_level:
         context.vram_usage_level = vram_usage_level
@@ -126,42 +142,51 @@ def set_vram_optimizations(context: Context):
 
     return False
 
+
 def make_model_folders():
     for model_type in KNOWN_MODEL_TYPES:
         model_dir_path = os.path.join(app.MODELS_DIR, model_type)
 
         os.makedirs(model_dir_path, exist_ok=True)
 
-        help_file_name = f'Place your {model_type} model files here.txt'
+        help_file_name = f"Place your {model_type} model files here.txt"
         help_file_contents = f'Supported extensions: {" or ".join(MODEL_EXTENSIONS.get(model_type))}'
".join(MODEL_EXTENSIONS.get(model_type))}' - with open(os.path.join(model_dir_path, help_file_name), 'w', encoding='utf-8') as f: + with open(os.path.join(model_dir_path, help_file_name), "w", encoding="utf-8") as f: f.write(help_file_contents) + def is_malicious_model(file_path): try: scan_result = scan_model(file_path) if scan_result.issues_count > 0 or scan_result.infected_files > 0: - log.warn(":warning: [bold red]Scan %s: %d scanned, %d issue, %d infected.[/bold red]" % (file_path, scan_result.scanned_files, scan_result.issues_count, scan_result.infected_files)) + log.warn( + ":warning: [bold red]Scan %s: %d scanned, %d issue, %d infected.[/bold red]" + % (file_path, scan_result.scanned_files, scan_result.issues_count, scan_result.infected_files) + ) return True else: - log.debug("Scan %s: [green]%d scanned, %d issue, %d infected.[/green]" % (file_path, scan_result.scanned_files, scan_result.issues_count, scan_result.infected_files)) + log.debug( + "Scan %s: [green]%d scanned, %d issue, %d infected.[/green]" + % (file_path, scan_result.scanned_files, scan_result.issues_count, scan_result.infected_files) + ) return False except Exception as e: - log.error(f'error while scanning: {file_path}, error: {e}') + log.error(f"error while scanning: {file_path}, error: {e}") return False + def getModels(): models = { - 'active': { - 'stable-diffusion': 'sd-v1-4', - 'vae': '', - 'hypernetwork': '', + "active": { + "stable-diffusion": "sd-v1-4", + "vae": "", + "hypernetwork": "", }, - 'options': { - 'stable-diffusion': ['sd-v1-4'], - 'vae': [], - 'hypernetwork': [], + "options": { + "stable-diffusion": ["sd-v1-4"], + "vae": [], + "hypernetwork": [], }, } @@ -171,13 +196,16 @@ def getModels(): "Raised when picklescan reports a problem with a model" pass - def scan_directory(directory, suffixes, directoriesFirst:bool=True): + def scan_directory(directory, suffixes, directoriesFirst: bool = True): nonlocal models_scanned tree = [] - for entry in sorted(os.scandir(directory), key = lambda entry: (entry.is_file() == directoriesFirst, entry.name.lower())): + for entry in sorted( + os.scandir(directory), key=lambda entry: (entry.is_file() == directoriesFirst, entry.name.lower()) + ): if entry.is_file(): matching_suffix = list(filter(lambda s: entry.name.endswith(s), suffixes)) - if len(matching_suffix) == 0: continue + if len(matching_suffix) == 0: + continue matching_suffix = matching_suffix[0] mtime = entry.stat().st_mtime @@ -187,12 +215,12 @@ def getModels(): if is_malicious_model(entry.path): raise MaliciousModelException(entry.path) known_models[entry.path] = mtime - tree.append(entry.name[:-len(matching_suffix)]) + tree.append(entry.name[: -len(matching_suffix)]) elif entry.is_dir(): - scan=scan_directory(entry.path, suffixes, directoriesFirst=False) + scan = scan_directory(entry.path, suffixes, directoriesFirst=False) if len(scan) != 0: - tree.append( (entry.name, scan ) ) + tree.append((entry.name, scan)) return tree def listModels(model_type): @@ -204,21 +232,22 @@ def getModels(): os.makedirs(models_dir) try: - models['options'][model_type] = scan_directory(models_dir, model_extensions) + models["options"][model_type] = scan_directory(models_dir, model_extensions) except MaliciousModelException as e: - models['scan-error'] = e + models["scan-error"] = e # custom models - listModels(model_type='stable-diffusion') - listModels(model_type='vae') - listModels(model_type='hypernetwork') - listModels(model_type='gfpgan') + listModels(model_type="stable-diffusion") + listModels(model_type="vae") + 
+    listModels(model_type="hypernetwork")
+    listModels(model_type="gfpgan")
 
-    if models_scanned > 0: log.info(f'[green]Scanned {models_scanned} models. Nothing infected[/]')
+    if models_scanned > 0:
+        log.info(f"[green]Scanned {models_scanned} models. Nothing infected[/]")
 
     # legacy
-    custom_weight_path = os.path.join(app.SD_DIR, 'custom-model.ckpt')
+    custom_weight_path = os.path.join(app.SD_DIR, "custom-model.ckpt")
     if os.path.exists(custom_weight_path):
-        models['options']['stable-diffusion'].append('custom-model')
+        models["options"]["stable-diffusion"].append("custom-model")
 
     return models
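scan_directory() above returns a mixed tree: bare strings for model files (with the matching suffix stripped) and (dirname, subtree) tuples for sub-folders. The sort key puts directories first at the top level and files first inside sub-folders. A sketch of the resulting shape for a hypothetical folder layout (file names are made up):

# models/stable-diffusion/
#     sd-v1-4.ckpt
#     anime/
#         waifu.safetensors
#
# scan_directory("models/stable-diffusion", [".ckpt", ".safetensors"])
# would yield roughly:
options = [("anime", ["waifu"]), "sd-v1-4"]

# Tuples are rendered by the UI as folders; the recursive call passes
# directoriesFirst=False, so files sort before directories inside them.
print(options)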
diff --git a/ui/easydiffusion/renderer.py b/ui/easydiffusion/renderer.py
index 672614ba..3b8adaa7 100644
--- a/ui/easydiffusion/renderer.py
+++ b/ui/easydiffusion/renderer.py
@@ -12,22 +12,26 @@ from sdkit.generate import generate_images
 from sdkit.filter import apply_filters
 from sdkit.utils import img_to_buffer, img_to_base64_str, latent_samples_to_images, gc
 
-context = Context() # thread-local
-'''
+context = Context()  # thread-local
+"""
 runtime data (bound locally to this thread), for e.g. device, references to loaded models, optimization flags etc
-'''
+"""
+
 
 def init(device):
-    '''
+    """
     Initializes the fields that will be bound to this runtime's context, and sets the current torch device
-    '''
+    """
     context.stop_processing = False
     context.temp_images = {}
     context.partial_x_samples = None
 
     device_manager.device_init(context, device)
 
-def make_images(req: GenerateImageRequest, task_data: TaskData, data_queue: queue.Queue, task_temp_images: list, step_callback):
+
+def make_images(
+    req: GenerateImageRequest, task_data: TaskData, data_queue: queue.Queue, task_temp_images: list, step_callback
+):
     context.stop_processing = False
     print_task_info(req, task_data)
 
@@ -36,18 +40,24 @@ def make_images(req: GenerateImageRequest, task_data: TaskData, data_queue: queu
     res = Response(req, task_data, images=construct_response(images, seeds, task_data, base_seed=req.seed))
     res = res.json()
     data_queue.put(json.dumps(res))
-    log.info('Task completed')
+    log.info("Task completed")
 
     return res
 
-def print_task_info(req: GenerateImageRequest, task_data: TaskData):
-    req_str = pprint.pformat(get_printable_request(req)).replace("[","\[")
-    task_str = pprint.pformat(task_data.dict()).replace("[","\[")
-    log.info(f'request: {req_str}')
-    log.info(f'task data: {task_str}')
 
-def make_images_internal(req: GenerateImageRequest, task_data: TaskData, data_queue: queue.Queue, task_temp_images: list, step_callback):
-    images, user_stopped = generate_images_internal(req, task_data, data_queue, task_temp_images, step_callback, task_data.stream_image_progress)
+def print_task_info(req: GenerateImageRequest, task_data: TaskData):
+    req_str = pprint.pformat(get_printable_request(req)).replace("[", "\[")
+    task_str = pprint.pformat(task_data.dict()).replace("[", "\[")
+    log.info(f"request: {req_str}")
+    log.info(f"task data: {task_str}")
+
+
+def make_images_internal(
+    req: GenerateImageRequest, task_data: TaskData, data_queue: queue.Queue, task_temp_images: list, step_callback
+):
+    images, user_stopped = generate_images_internal(
+        req, task_data, data_queue, task_temp_images, step_callback, task_data.stream_image_progress
+    )
     filtered_images = filter_images(task_data, images, user_stopped)
 
     if task_data.save_to_disk_path is not None:
@@ -59,13 +69,22 @@ def make_images_internal(req: GenerateImageRequest, task_data: TaskData, data_qu
     else:
         return images + filtered_images, seeds + seeds
 
-def generate_images_internal(req: GenerateImageRequest, task_data: TaskData, data_queue: queue.Queue, task_temp_images: list, step_callback, stream_image_progress: bool):
+
+def generate_images_internal(
+    req: GenerateImageRequest,
+    task_data: TaskData,
+    data_queue: queue.Queue,
+    task_temp_images: list,
+    step_callback,
+    stream_image_progress: bool,
+):
     context.temp_images.clear()
 
     callback = make_step_callback(req, task_data, data_queue, task_temp_images, step_callback, stream_image_progress)
 
     try:
-        if req.init_image is not None: req.sampler_name = 'ddim'
+        if req.init_image is not None:
+            req.sampler_name = "ddim"
 
         images = generate_images(context, callback=callback, **req.dict())
         user_stopped = False
@@ -75,31 +94,44 @@ def generate_images_internal(req: GenerateImageRequest, task_data: TaskData, dat
         if context.partial_x_samples is not None:
             images = latent_samples_to_images(context, context.partial_x_samples)
     finally:
-        if hasattr(context, 'partial_x_samples') and context.partial_x_samples is not None:
+        if hasattr(context, "partial_x_samples") and context.partial_x_samples is not None:
             del context.partial_x_samples
             context.partial_x_samples = None
 
     return images, user_stopped
 
+
 def filter_images(task_data: TaskData, images: list, user_stopped):
     if user_stopped or (task_data.use_face_correction is None and task_data.use_upscale is None):
         return images
 
     filters_to_apply = []
-    if task_data.use_face_correction and 'gfpgan' in task_data.use_face_correction.lower(): filters_to_apply.append('gfpgan')
-    if task_data.use_upscale and 'realesrgan' in task_data.use_upscale.lower(): filters_to_apply.append('realesrgan')
+    if task_data.use_face_correction and "gfpgan" in task_data.use_face_correction.lower():
+        filters_to_apply.append("gfpgan")
+    if task_data.use_upscale and "realesrgan" in task_data.use_upscale.lower():
+        filters_to_apply.append("realesrgan")
 
     return apply_filters(context, filters_to_apply, images, scale=task_data.upscale_amount)
 
+
 def construct_response(images: list, seeds: list, task_data: TaskData, base_seed: int):
     return [
         ResponseImage(
             data=img_to_base64_str(img, task_data.output_format, task_data.output_quality),
             seed=seed,
-        ) for img, seed in zip(images, seeds)
+        )
+        for img, seed in zip(images, seeds)
     ]
 
-def make_step_callback(req: GenerateImageRequest, task_data: TaskData, data_queue: queue.Queue, task_temp_images: list, step_callback, stream_image_progress: bool):
+
+def make_step_callback(
+    req: GenerateImageRequest,
+    task_data: TaskData,
+    data_queue: queue.Queue,
+    task_temp_images: list,
+    step_callback,
+    stream_image_progress: bool,
+):
     n_steps = req.num_inference_steps if req.init_image is None else int(req.num_inference_steps * req.prompt_strength)
     last_callback_time = -1
 
@@ -107,11 +139,11 @@ def make_step_callback(req: GenerateImageRequest, task_data: TaskData, data_queu
         partial_images = []
         images = latent_samples_to_images(context, x_samples)
         for i, img in enumerate(images):
-            buf = img_to_buffer(img, output_format='JPEG')
+            buf = img_to_buffer(img, output_format="JPEG")
 
             context.temp_images[f"{task_data.request_id}/{i}"] = buf
             task_temp_images[i] = buf
-            partial_images.append({'path': f"/image/tmp/{task_data.request_id}/{i}"})
+            partial_images.append({"path": f"/image/tmp/{task_data.request_id}/{i}"})
         del images
         return partial_images
 
@@ -125,7 +157,7 @@ def make_step_callback(req: GenerateImageRequest, task_data: TaskData, data_queu
                 progress = {"step": i, "step_time": step_time, "total_steps": n_steps}
 
                 if stream_image_progress and i % 5 == 0:
-                    progress['output'] = update_temp_img(x_samples, task_temp_images)
+                    progress["output"] = update_temp_img(x_samples, task_temp_images)
 
                 data_queue.put(json.dumps(progress))
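The step callback above pushes JSON fragments (progress dicts, then the final response) into the task's buffer queue, which the server later streams to the client back-to-back with no delimiters. A hedged client-side reader for that framing; json.JSONDecoder.raw_decode consumes one complete object at a time, and the URL, port and task id here are invented:

import json

import requests

decoder = json.JSONDecoder()
buffer = ""
resp = requests.get("http://localhost:9000/image/stream/140230948", stream=True)
for chunk in resp.iter_content(chunk_size=None, decode_unicode=True):
    buffer += chunk
    while buffer:
        try:
            progress, end = decoder.raw_decode(buffer)
        except json.JSONDecodeError:
            break  # fragment incomplete, wait for more bytes
        buffer = buffer[end:].lstrip()
        print(progress.get("step"), progress.get("total_steps"))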
diff --git a/ui/easydiffusion/server.py b/ui/easydiffusion/server.py
index bbdca9c9..c04189c1 100644
--- a/ui/easydiffusion/server.py
+++ b/ui/easydiffusion/server.py
@@ -16,21 +16,25 @@ from easydiffusion import app, model_manager, task_manager
 from easydiffusion.types import TaskData, GenerateImageRequest, MergeRequest
 from easydiffusion.utils import log
 
-log.info(f'started in {app.SD_DIR}')
-log.info(f'started at {datetime.datetime.now():%x %X}')
+log.info(f"started in {app.SD_DIR}")
+log.info(f"started at {datetime.datetime.now():%x %X}")
 
 server_api = FastAPI()
 
-NOCACHE_HEADERS={"Cache-Control": "no-cache, no-store, must-revalidate", "Pragma": "no-cache", "Expires": "0"}
+NOCACHE_HEADERS = {"Cache-Control": "no-cache, no-store, must-revalidate", "Pragma": "no-cache", "Expires": "0"}
+
 
 class NoCacheStaticFiles(StaticFiles):
     def is_not_modified(self, response_headers, request_headers) -> bool:
-        if 'content-type' in response_headers and ('javascript' in response_headers['content-type'] or 'css' in response_headers['content-type']):
+        if "content-type" in response_headers and (
+            "javascript" in response_headers["content-type"] or "css" in response_headers["content-type"]
+        ):
             response_headers.update(NOCACHE_HEADERS)
             return False
 
         return super().is_not_modified(response_headers, request_headers)
 
+
 class SetAppConfigRequest(BaseModel):
     update_branch: str = None
     render_devices: Union[List[str], List[int], str, int] = None
@@ -39,130 +43,142 @@ class SetAppConfigRequest(BaseModel):
     listen_to_network: bool = None
     listen_port: int = None
 
+
 def init():
-    server_api.mount('/media', NoCacheStaticFiles(directory=os.path.join(app.SD_UI_DIR, 'media')), name="media")
+    server_api.mount("/media", NoCacheStaticFiles(directory=os.path.join(app.SD_UI_DIR, "media")), name="media")
     for plugins_dir, dir_prefix in app.UI_PLUGINS_SOURCES:
-        server_api.mount(f'/plugins/{dir_prefix}', NoCacheStaticFiles(directory=plugins_dir), name=f"plugins-{dir_prefix}")
+        server_api.mount(
+            f"/plugins/{dir_prefix}", NoCacheStaticFiles(directory=plugins_dir), name=f"plugins-{dir_prefix}"
+        )
 
-    @server_api.post('/app_config')
-    async def set_app_config(req : SetAppConfigRequest):
+    @server_api.post("/app_config")
+    async def set_app_config(req: SetAppConfigRequest):
         return set_app_config_internal(req)
 
-    @server_api.get('/get/{key:path}')
-    def read_web_data(key:str=None):
+    @server_api.get("/get/{key:path}")
+    def read_web_data(key: str = None):
         return read_web_data_internal(key)
 
-    @server_api.get('/ping') # Get server and optionally session status.
-    def ping(session_id:str=None):
+    @server_api.get("/ping")  # Get server and optionally session status.
+    def ping(session_id: str = None):
         return ping_internal(session_id)
 
-    @server_api.post('/render')
+    @server_api.post("/render")
    def render(req: dict):
         return render_internal(req)
 
-    @server_api.post('/model/merge')
+    @server_api.post("/model/merge")
     def model_merge(req: dict):
         print(req)
         return model_merge_internal(req)
 
-    @server_api.get('/image/stream/{task_id:int}')
-    def stream(task_id:int):
+    @server_api.get("/image/stream/{task_id:int}")
+    def stream(task_id: int):
         return stream_internal(task_id)
 
-    @server_api.get('/image/stop')
+    @server_api.get("/image/stop")
     def stop(task: int):
         return stop_internal(task)
 
-    @server_api.get('/image/tmp/{task_id:int}/{img_id:int}')
+    @server_api.get("/image/tmp/{task_id:int}/{img_id:int}")
     def get_image(task_id: int, img_id: int):
         return get_image_internal(task_id, img_id)
 
-    @server_api.get('/')
+    @server_api.get("/")
     def read_root():
-        return FileResponse(os.path.join(app.SD_UI_DIR, 'index.html'), headers=NOCACHE_HEADERS)
+        return FileResponse(os.path.join(app.SD_UI_DIR, "index.html"), headers=NOCACHE_HEADERS)
 
     @server_api.on_event("shutdown")
-    def shutdown_event(): # Signal render thread to close on shutdown
-        task_manager.current_state_error = SystemExit('Application shutting down.')
+    def shutdown_event():  # Signal render thread to close on shutdown
+        task_manager.current_state_error = SystemExit("Application shutting down.")
+
 
 # API implementations
-def set_app_config_internal(req : SetAppConfigRequest):
+def set_app_config_internal(req: SetAppConfigRequest):
     config = app.getConfig()
     if req.update_branch is not None:
-        config['update_branch'] = req.update_branch
+        config["update_branch"] = req.update_branch
     if req.render_devices is not None:
         update_render_devices_in_config(config, req.render_devices)
     if req.ui_open_browser_on_start is not None:
-        if 'ui' not in config:
-            config['ui'] = {}
-        config['ui']['open_browser_on_start'] = req.ui_open_browser_on_start
+        if "ui" not in config:
+            config["ui"] = {}
+        config["ui"]["open_browser_on_start"] = req.ui_open_browser_on_start
     if req.listen_to_network is not None:
-        if 'net' not in config:
-            config['net'] = {}
-        config['net']['listen_to_network'] = bool(req.listen_to_network)
+        if "net" not in config:
+            config["net"] = {}
+        config["net"]["listen_to_network"] = bool(req.listen_to_network)
     if req.listen_port is not None:
-        if 'net' not in config:
-            config['net'] = {}
-        config['net']['listen_port'] = int(req.listen_port)
+        if "net" not in config:
+            config["net"] = {}
+        config["net"]["listen_port"] = int(req.listen_port)
     try:
         app.setConfig(config)
 
         if req.render_devices:
             app.update_render_threads()
 
-        return JSONResponse({'status': 'OK'}, headers=NOCACHE_HEADERS)
+        return JSONResponse({"status": "OK"}, headers=NOCACHE_HEADERS)
     except Exception as e:
         log.error(traceback.format_exc())
         raise HTTPException(status_code=500, detail=str(e))
 
+
 def update_render_devices_in_config(config, render_devices):
-    if render_devices not in ('cpu', 'auto') and not render_devices.startswith('cuda:'):
-        raise HTTPException(status_code=400, detail=f'Invalid render device requested: {render_devices}')
+    if render_devices not in ("cpu", "auto") and not render_devices.startswith("cuda:"):
+        raise HTTPException(status_code=400, detail=f"Invalid render device requested: {render_devices}")
 
-    if render_devices.startswith('cuda:'):
-        render_devices = render_devices.split(',')
+    if render_devices.startswith("cuda:"):
+        render_devices = render_devices.split(",")
 
-    config['render_devices'] = render_devices
+    config["render_devices"] = render_devices
 
-def read_web_data_internal(key:str=None):
-    if not key: # /get without parameters, stable-diffusion easter egg.
-        raise HTTPException(status_code=418, detail="StableDiffusion is drawing a teapot!") # HTTP418 I'm a teapot
-    elif key == 'app_config':
+
+def read_web_data_internal(key: str = None):
+    if not key:  # /get without parameters, stable-diffusion easter egg.
+        raise HTTPException(status_code=418, detail="StableDiffusion is drawing a teapot!")  # HTTP418 I'm a teapot
+    elif key == "app_config":
         return JSONResponse(app.getConfig(), headers=NOCACHE_HEADERS)
-    elif key == 'system_info':
+    elif key == "system_info":
         config = app.getConfig()
-        output_dir = config.get('force_save_path', os.path.join(os.path.expanduser("~"), app.OUTPUT_DIRNAME))
+        output_dir = config.get("force_save_path", os.path.join(os.path.expanduser("~"), app.OUTPUT_DIRNAME))
         system_info = {
-            'devices': task_manager.get_devices(),
-            'hosts': app.getIPConfig(),
-            'default_output_dir': output_dir,
-            'enforce_output_dir': ('force_save_path' in config),
+            "devices": task_manager.get_devices(),
+            "hosts": app.getIPConfig(),
+            "default_output_dir": output_dir,
+            "enforce_output_dir": ("force_save_path" in config),
         }
-        system_info['devices']['config'] = config.get('render_devices', "auto")
+        system_info["devices"]["config"] = config.get("render_devices", "auto")
         return JSONResponse(system_info, headers=NOCACHE_HEADERS)
-    elif key == 'models':
+    elif key == "models":
         return JSONResponse(model_manager.getModels(), headers=NOCACHE_HEADERS)
-    elif key == 'modifiers': return FileResponse(os.path.join(app.SD_UI_DIR, 'modifiers.json'), headers=NOCACHE_HEADERS)
-    elif key == 'ui_plugins': return JSONResponse(app.getUIPlugins(), headers=NOCACHE_HEADERS)
+    elif key == "modifiers":
+        return FileResponse(os.path.join(app.SD_UI_DIR, "modifiers.json"), headers=NOCACHE_HEADERS)
+    elif key == "ui_plugins":
+        return JSONResponse(app.getUIPlugins(), headers=NOCACHE_HEADERS)
     else:
-        raise HTTPException(status_code=404, detail=f'Request for unknown {key}') # HTTP404 Not Found
+        raise HTTPException(status_code=404, detail=f"Request for unknown {key}")  # HTTP404 Not Found
 
-def ping_internal(session_id:str=None):
-    if task_manager.is_alive() <= 0: # Check that render threads are alive.
-        if task_manager.current_state_error: raise HTTPException(status_code=500, detail=str(task_manager.current_state_error))
-        raise HTTPException(status_code=500, detail='Render thread is dead.')
-    if task_manager.current_state_error and not isinstance(task_manager.current_state_error, StopAsyncIteration): raise HTTPException(status_code=500, detail=str(task_manager.current_state_error))
+
+def ping_internal(session_id: str = None):
+    if task_manager.is_alive() <= 0:  # Check that render threads are alive.
+        if task_manager.current_state_error:
+            raise HTTPException(status_code=500, detail=str(task_manager.current_state_error))
+        raise HTTPException(status_code=500, detail="Render thread is dead.")
+    if task_manager.current_state_error and not isinstance(task_manager.current_state_error, StopAsyncIteration):
+        raise HTTPException(status_code=500, detail=str(task_manager.current_state_error))
 
     # Alive
-    response = {'status': str(task_manager.current_state)}
+    response = {"status": str(task_manager.current_state)}
 
     if session_id:
         session = task_manager.get_cached_session(session_id, update_ttl=True)
-        response['tasks'] = {id(t): t.status for t in session.tasks}
-        response['devices'] = task_manager.get_devices()
+        response["tasks"] = {id(t): t.status for t in session.tasks}
+        response["devices"] = task_manager.get_devices()
 
     return JSONResponse(response, headers=NOCACHE_HEADERS)
 
+
 def render_internal(req: dict):
     try:
         # separate out the request data into rendering and task-specific data
@@ -171,80 +187,99 @@ def render_internal(req: dict):
 
         # Overwrite user specified save path
         config = app.getConfig()
-        if 'force_save_path' in config:
-            task_data.save_to_disk_path = config['force_save_path']
+        if "force_save_path" in config:
+            task_data.save_to_disk_path = config["force_save_path"]
 
-        render_req.init_image_mask = req.get('mask') # hack: will rename this in the HTTP API in a future revision
+        render_req.init_image_mask = req.get("mask")  # hack: will rename this in the HTTP API in a future revision
 
-        app.save_to_config(task_data.use_stable_diffusion_model, task_data.use_vae_model, task_data.use_hypernetwork_model, task_data.vram_usage_level)
+        app.save_to_config(
+            task_data.use_stable_diffusion_model,
+            task_data.use_vae_model,
+            task_data.use_hypernetwork_model,
+            task_data.vram_usage_level,
+        )
 
         # enqueue the task
         new_task = task_manager.render(render_req, task_data)
         response = {
-            'status': str(task_manager.current_state),
-            'queue': len(task_manager.tasks_queue),
-            'stream': f'/image/stream/{id(new_task)}',
-            'task': id(new_task)
+            "status": str(task_manager.current_state),
+            "queue": len(task_manager.tasks_queue),
+            "stream": f"/image/stream/{id(new_task)}",
+            "task": id(new_task),
         }
         return JSONResponse(response, headers=NOCACHE_HEADERS)
-    except ChildProcessError as e: # Render thread is dead
-        raise HTTPException(status_code=500, detail=f'Rendering thread has died.') # HTTP500 Internal Server Error
-    except ConnectionRefusedError as e: # Unstarted task pending limit reached, deny queueing too many.
-        raise HTTPException(status_code=503, detail=str(e)) # HTTP503 Service Unavailable
+    except ChildProcessError as e:  # Render thread is dead
+        raise HTTPException(status_code=500, detail=f"Rendering thread has died.")  # HTTP500 Internal Server Error
+    except ConnectionRefusedError as e:  # Unstarted task pending limit reached, deny queueing too many.
+        raise HTTPException(status_code=503, detail=str(e))  # HTTP503 Service Unavailable
     except Exception as e:
         log.error(traceback.format_exc())
         raise HTTPException(status_code=500, detail=str(e))
 
+
 def model_merge_internal(req: dict):
     try:
         from sdkit.train import merge_models
         from easydiffusion.utils.save_utils import filename_regex
+
         mergeReq: MergeRequest = MergeRequest.parse_obj(req)
-
-        merge_models(model_manager.resolve_model_to_use(mergeReq.model0,'stable-diffusion'),
-                     model_manager.resolve_model_to_use(mergeReq.model1,'stable-diffusion'),
-                     mergeReq.ratio,
-                     os.path.join(app.MODELS_DIR, 'stable-diffusion', filename_regex.sub('_', mergeReq.out_path)),
-                     mergeReq.use_fp16
+
+        merge_models(
+            model_manager.resolve_model_to_use(mergeReq.model0, "stable-diffusion"),
+            model_manager.resolve_model_to_use(mergeReq.model1, "stable-diffusion"),
+            mergeReq.ratio,
+            os.path.join(app.MODELS_DIR, "stable-diffusion", filename_regex.sub("_", mergeReq.out_path)),
+            mergeReq.use_fp16,
         )
-        return JSONResponse({'status':'OK'}, headers=NOCACHE_HEADERS)
+        return JSONResponse({"status": "OK"}, headers=NOCACHE_HEADERS)
     except Exception as e:
         log.error(traceback.format_exc())
         raise HTTPException(status_code=500, detail=str(e))
 
-def stream_internal(task_id:int):
-    #TODO Move to WebSockets ??
+
+def stream_internal(task_id: int):
+    # TODO Move to WebSockets ??
     task = task_manager.get_cached_task(task_id, update_ttl=True)
-    if not task: raise HTTPException(status_code=404, detail=f'Request {task_id} not found.') # HTTP404 NotFound
-    #if (id(task) != task_id): raise HTTPException(status_code=409, detail=f'Wrong task id received. Expected:{id(task)}, Received:{task_id}') # HTTP409 Conflict
+    if not task:
+        raise HTTPException(status_code=404, detail=f"Request {task_id} not found.")  # HTTP404 NotFound
+    # if (id(task) != task_id): raise HTTPException(status_code=409, detail=f'Wrong task id received. Expected:{id(task)}, Received:{task_id}') # HTTP409 Conflict
     if task.buffer_queue.empty() and not task.lock.locked():
         if task.response:
-            #log.info(f'Session {session_id} sending cached response')
+            # log.info(f'Session {session_id} sending cached response')
             return JSONResponse(task.response, headers=NOCACHE_HEADERS)
-        raise HTTPException(status_code=425, detail='Too Early, task not started yet.') # HTTP425 Too Early
-    #log.info(f'Session {session_id} opened live render stream {id(task.buffer_queue)}')
-    return StreamingResponse(task.read_buffer_generator(), media_type='application/json')
+        raise HTTPException(status_code=425, detail="Too Early, task not started yet.")  # HTTP425 Too Early
+    # log.info(f'Session {session_id} opened live render stream {id(task.buffer_queue)}')
+    return StreamingResponse(task.read_buffer_generator(), media_type="application/json")
 
+
 def stop_internal(task: int):
     if not task:
-        if task_manager.current_state == task_manager.ServerStates.Online or task_manager.current_state == task_manager.ServerStates.Unavailable:
-            raise HTTPException(status_code=409, detail='Not currently running any tasks.') # HTTP409 Conflict
-        task_manager.current_state_error = StopAsyncIteration('')
-        return {'OK'}
+        if (
+            task_manager.current_state == task_manager.ServerStates.Online
+            or task_manager.current_state == task_manager.ServerStates.Unavailable
+        ):
+            raise HTTPException(status_code=409, detail="Not currently running any tasks.")  # HTTP409 Conflict
+        task_manager.current_state_error = StopAsyncIteration("")
+        return {"OK"}
     task_id = task
     task = task_manager.get_cached_task(task_id, update_ttl=False)
-    if not task: raise HTTPException(status_code=404, detail=f'Task {task_id} was not found.') # HTTP404 Not Found
-    if isinstance(task.error, StopAsyncIteration): raise HTTPException(status_code=409, detail=f'Task {task_id} is already stopped.') # HTTP409 Conflict
-    task.error = StopAsyncIteration(f'Task {task_id} stop requested.')
-    return {'OK'}
+    if not task:
+        raise HTTPException(status_code=404, detail=f"Task {task_id} was not found.")  # HTTP404 Not Found
+    if isinstance(task.error, StopAsyncIteration):
+        raise HTTPException(status_code=409, detail=f"Task {task_id} is already stopped.")  # HTTP409 Conflict
+    task.error = StopAsyncIteration(f"Task {task_id} stop requested.")
+    return {"OK"}
 
+
 def get_image_internal(task_id: int, img_id: int):
     task = task_manager.get_cached_task(task_id, update_ttl=True)
-    if not task: raise HTTPException(status_code=410, detail=f'Task {task_id} could not be found.') # HTTP404 NotFound
-    if not task.temp_images[img_id]: raise HTTPException(status_code=425, detail='Too Early, task data is not available yet.') # HTTP425 Too Early
+    if not task:
+        raise HTTPException(status_code=410, detail=f"Task {task_id} could not be found.")  # HTTP404 NotFound
+    if not task.temp_images[img_id]:
+        raise HTTPException(status_code=425, detail="Too Early, task data is not available yet.")  # HTTP425 Too Early
     try:
        img_data = task.temp_images[img_id]
        img_data.seek(0)
-        return StreamingResponse(img_data, media_type='image/jpeg')
+        return StreamingResponse(img_data, media_type="image/jpeg")
     except KeyError as e:
         raise HTTPException(status_code=500, detail=str(e))
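Putting the endpoints above together, the expected client flow is: enqueue via POST /render, then watch the stream URL the response hands back. A hedged sketch (the payload fields follow GenerateImageRequest/TaskData loosely; the prompt and the default port are assumptions):

import requests

payload = {
    "prompt": "a watercolor fox",
    "width": 512,
    "height": 512,
    "num_outputs": 1,
    "use_stable_diffusion_model": "sd-v1-4",
}
queued = requests.post("http://localhost:9000/render", json=payload).json()
print(queued["queue"], queued["stream"])  # queue position + /image/stream/{task}

# /ping reports overall server state; with a session_id it also lists the
# session's tasks and their statuses.
state = requests.get("http://localhost:9000/ping").json()
print(state["status"])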
diff --git a/ui/easydiffusion/task_manager.py b/ui/easydiffusion/task_manager.py
index 213f5a96..28f84963 100644
--- a/ui/easydiffusion/task_manager.py
+++ b/ui/easydiffusion/task_manager.py
@@ -7,7 +7,7 @@ Notes:
 import json
 import traceback
 
-TASK_TTL = 15 * 60 # seconds, Discard last session's task timeout
+TASK_TTL = 15 * 60  # seconds, Discard last session's task timeout
 
 import torch
 import queue, threading, time, weakref
@@ -19,71 +19,98 @@ from easydiffusion.utils import log
 
 from sdkit.utils import gc
 
-THREAD_NAME_PREFIX = ''
-ERR_LOCK_FAILED = ' failed to acquire lock within timeout.'
-LOCK_TIMEOUT = 15 # Maximum locking time in seconds before failing a task.
+THREAD_NAME_PREFIX = ""
+ERR_LOCK_FAILED = " failed to acquire lock within timeout."
+LOCK_TIMEOUT = 15  # Maximum locking time in seconds before failing a task.
 # It's better to get an exception than a deadlock... ALWAYS use timeout in critical paths.
 
-DEVICE_START_TIMEOUT = 60 # seconds - Maximum time to wait for a render device to init.
+DEVICE_START_TIMEOUT = 60  # seconds - Maximum time to wait for a render device to init.
+
+
+class SymbolClass(type):  # Print nicely formatted Symbol names.
+    def __repr__(self):
+        return self.__qualname__
+
+    def __str__(self):
+        return self.__name__
+
+
+class Symbol(metaclass=SymbolClass):
+    pass
 
-class SymbolClass(type): # Print nicely formatted Symbol names.
-    def __repr__(self): return self.__qualname__
-    def __str__(self): return self.__name__
-class Symbol(metaclass=SymbolClass): pass
 
 class ServerStates:
-    class Init(Symbol): pass
-    class LoadingModel(Symbol): pass
-    class Online(Symbol): pass
-    class Rendering(Symbol): pass
-    class Unavailable(Symbol): pass
+    class Init(Symbol):
+        pass
 
-class RenderTask(): # Task with output queue and completion lock.
+    class LoadingModel(Symbol):
+        pass
+
+    class Online(Symbol):
+        pass
+
+    class Rendering(Symbol):
+        pass
+
+    class Unavailable(Symbol):
+        pass
+
+
+class RenderTask:  # Task with output queue and completion lock.
     def __init__(self, req: GenerateImageRequest, task_data: TaskData):
         task_data.request_id = id(self)
         self.render_request: GenerateImageRequest = req  # Initial Request
         self.task_data: TaskData = task_data
-        self.response: Any = None # Copy of the last reponse
-        self.render_device = None # Select the task affinity. (Not used to change active devices).
-        self.temp_images:list = [None] * req.num_outputs * (1 if task_data.show_only_filtered_image else 2)
+        self.response: Any = None  # Copy of the last reponse
+        self.render_device = None  # Select the task affinity. (Not used to change active devices).
+        self.temp_images: list = [None] * req.num_outputs * (1 if task_data.show_only_filtered_image else 2)
         self.error: Exception = None
-        self.lock: threading.Lock = threading.Lock() # Locks at task start and unlocks when task is completed
-        self.buffer_queue: queue.Queue = queue.Queue() # Queue of JSON string segments
+        self.lock: threading.Lock = threading.Lock()  # Locks at task start and unlocks when task is completed
+        self.buffer_queue: queue.Queue = queue.Queue()  # Queue of JSON string segments
+
     async def read_buffer_generator(self):
         try:
             while not self.buffer_queue.empty():
                 res = self.buffer_queue.get(block=False)
                 self.buffer_queue.task_done()
                 yield res
-        except queue.Empty as e: yield
+        except queue.Empty as e:
+            yield
+
     @property
     def status(self):
         if self.lock.locked():
-            return 'running'
+            return "running"
         if isinstance(self.error, StopAsyncIteration):
-            return 'stopped'
+            return "stopped"
         if self.error:
-            return 'error'
+            return "error"
         if not self.buffer_queue.empty():
-            return 'buffer'
+            return "buffer"
         if self.response:
-            return 'completed'
-        return 'pending'
+            return "completed"
+        return "pending"
+
     @property
     def is_pending(self):
         return bool(not self.response and not self.error)
 
+
 # Temporary cache to allow to query tasks results for a short time after they are completed.
-class DataCache():
+class DataCache:
     def __init__(self):
         self._base = dict()
         self._lock: threading.Lock = threading.Lock()
+
     def _get_ttl_time(self, ttl: int) -> int:
         return int(time.time()) + ttl
+
     def _is_expired(self, timestamp: int) -> bool:
         return int(time.time()) >= timestamp
+
     def clean(self) -> None:
-        if not self._lock.acquire(blocking=True, timeout=LOCK_TIMEOUT): raise Exception('DataCache.clean' + ERR_LOCK_FAILED)
+        if not self._lock.acquire(blocking=True, timeout=LOCK_TIMEOUT):
+            raise Exception("DataCache.clean" + ERR_LOCK_FAILED)
         try:
             # Create a list of expired keys to delete
             to_delete = []
@@ -95,20 +122,26 @@ class DataCache():
             for key in to_delete:
                 (_, val) = self._base[key]
                 if isinstance(val, RenderTask):
-                    log.debug(f'RenderTask {key} expired. Data removed.')
+                    log.debug(f"RenderTask {key} expired. Data removed.")
                 elif isinstance(val, SessionState):
-                    log.debug(f'Session {key} expired. Data removed.')
+                    log.debug(f"Session {key} expired. Data removed.")
                 else:
-                    log.debug(f'Key {key} expired. Data removed.')
+                    log.debug(f"Key {key} expired. Data removed.")
@@ -116,8 +149,10 @@ class DataCache():
             return True
         finally:
             self._lock.release()
+
     def keep(self, key: Hashable, ttl: int) -> bool:
-        if not self._lock.acquire(blocking=True, timeout=LOCK_TIMEOUT): raise Exception('DataCache.keep' + ERR_LOCK_FAILED)
+        if not self._lock.acquire(blocking=True, timeout=LOCK_TIMEOUT):
+            raise Exception("DataCache.keep" + ERR_LOCK_FAILED)
         try:
             if key in self._base:
                 _, value = self._base.get(key)
@@ -126,12 +161,12 @@ class DataCache():
             return False
         finally:
             self._lock.release()
+
     def put(self, key: Hashable, value: Any, ttl: int) -> bool:
-        if not self._lock.acquire(blocking=True, timeout=LOCK_TIMEOUT): raise Exception('DataCache.put' + ERR_LOCK_FAILED)
+        if not self._lock.acquire(blocking=True, timeout=LOCK_TIMEOUT):
+            raise Exception("DataCache.put" + ERR_LOCK_FAILED)
         try:
-            self._base[key] = (
-                self._get_ttl_time(ttl), value
-            )
+            self._base[key] = (self._get_ttl_time(ttl), value)
         except Exception as e:
             log.error(traceback.format_exc())
             return False
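# The (expiry, value) tuples written by put() above drive everything else in
# this class. A self-contained miniature of the same TTL logic (illustrative
# only, locking omitted for brevity):
import time

_base = {}

def put(key, value, ttl):
    _base[key] = (int(time.time()) + ttl, value)  # store the expiry with the value

def try_get(key):
    expiry, value = _base.get(key, (None, None))
    if expiry is not None and int(time.time()) >= expiry:
        del _base[key]  # lazy eviction: expired entries die on first read
        return None
    return value

put("task-1", {"status": "pending"}, ttl=2)
print(try_get("task-1"))  # {'status': 'pending'}
time.sleep(3)
print(try_get("task-1"))  # None - the entry expired and was removed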
@@ -139,35 +174,41 @@
             return True
         finally:
             self._lock.release()
+
     def tryGet(self, key: Hashable) -> Any:
-        if not self._lock.acquire(blocking=True, timeout=LOCK_TIMEOUT): raise Exception('DataCache.tryGet' + ERR_LOCK_FAILED)
+        if not self._lock.acquire(blocking=True, timeout=LOCK_TIMEOUT):
+            raise Exception("DataCache.tryGet" + ERR_LOCK_FAILED)
         try:
             ttl, value = self._base.get(key, (None, None))
             if ttl is not None and self._is_expired(ttl):
-                log.debug(f'Session {key} expired. Discarding data.')
+                log.debug(f"Session {key} expired. Discarding data.")
                 del self._base[key]
                 return None
             return value
         finally:
             self._lock.release()
+

 manager_lock = threading.RLock()
 render_threads = []
 current_state = ServerStates.Init
-current_state_error:Exception = None
+current_state_error: Exception = None
 tasks_queue = []
 session_cache = DataCache()
 task_cache = DataCache()
 weak_thread_data = weakref.WeakKeyDictionary()
 idle_event: threading.Event = threading.Event()

-class SessionState():
+
+class SessionState:
     def __init__(self, id: str):
         self._id = id
         self._tasks_ids = []
+
     @property
     def id(self):
         return self._id
+
     @property
     def tasks(self):
         tasks = []
@@ -176,6 +217,7 @@ class SessionState():
             if task:
                 tasks.append(task)
         return tasks
+
     def put(self, task, ttl=TASK_TTL):
         task_id = id(task)
         self._tasks_ids.append(task_id)
@@ -185,10 +227,12 @@
             self._tasks_ids.pop(0)
         return True

+
 def thread_get_next_task():
     from easydiffusion import renderer
+
     if not manager_lock.acquire(blocking=True, timeout=LOCK_TIMEOUT):
-        log.warn(f'Render thread on device: {renderer.context.device} failed to acquire manager lock.')
+        log.warn(f"Render thread on device: {renderer.context.device} failed to acquire manager lock.")
         return None
     if len(tasks_queue) <= 0:
         manager_lock.release()
@@ -202,10 +246,10 @@ def thread_get_next_task():
                 continue  # requested device alive, skip current one.
             else:
                 # Requested device is not active, return error to UI.
-                queued_task.error = Exception(queued_task.render_device + ' is not currently active.')
+                queued_task.error = Exception(queued_task.render_device + " is not currently active.")
                 task = queued_task
                 break
-        if not queued_task.render_device and renderer.context.device == 'cpu' and is_alive() > 1:
+        if not queued_task.render_device and renderer.context.device == "cpu" and is_alive() > 1:
             # not asking for any specific devices, cpu want to grab task but other render devices are alive.
             continue  # Skip Tasks, don't run on CPU unless there is nothing else or user asked for it.
         task = queued_task
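# A note on weak_thread_data above: keying per-thread state by the Thread
# object in a WeakKeyDictionary means the bookkeeping disappears on its own
# once a dead thread object is garbage collected, so stopped render threads
# cannot leak entries. A small standalone illustration:
import threading, weakref

registry = weakref.WeakKeyDictionary()

def worker():
    registry[threading.current_thread()] = {"alive": True}

t = threading.Thread(target=worker)
t.start()
t.join()
print(registry.get(t))  # {'alive': True} - visible while we still reference t
del t  # once the Thread object becomes unreachable, its entry is dropped too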
@@ -216,17 +260,19 @@
     finally:
         manager_lock.release()

+
 def thread_render(device):
     global current_state, current_state_error
     from easydiffusion import renderer, model_manager
+
     try:
         renderer.init(device)
         weak_thread_data[threading.current_thread()] = {
-            'device': renderer.context.device,
-            'device_name': renderer.context.device_name,
-            'alive': True
+            "device": renderer.context.device,
+            "device_name": renderer.context.device_name,
+            "alive": True,
         }
         current_state = ServerStates.LoadingModel
@@ -235,17 +281,14 @@
         current_state = ServerStates.Online
     except Exception as e:
         log.error(traceback.format_exc())
-        weak_thread_data[threading.current_thread()] = {
-            'error': e,
-            'alive': False
-        }
+        weak_thread_data[threading.current_thread()] = {"error": e, "alive": False}
         return

     while True:
         session_cache.clean()
         task_cache.clean()
-        if not weak_thread_data[threading.current_thread()]['alive']:
-            log.info(f'Shutting down thread for device {renderer.context.device}')
+        if not weak_thread_data[threading.current_thread()]["alive"]:
+            log.info(f"Shutting down thread for device {renderer.context.device}")
             model_manager.unload_all(renderer.context)
             return
         if isinstance(current_state_error, SystemExit):
@@ -258,39 +301,47 @@
             continue
         if task.error is not None:
             log.error(task.error)
-            task.response = {"status": 'failed', "detail": str(task.error)}
+            task.response = {"status": "failed", "detail": str(task.error)}
             task.buffer_queue.put(json.dumps(task.response))
             continue
         if current_state_error:
             task.error = current_state_error
-            task.response = {"status": 'failed', "detail": str(task.error)}
+            task.response = {"status": "failed", "detail": str(task.error)}
             task.buffer_queue.put(json.dumps(task.response))
             continue
-        log.info(f'Session {task.task_data.session_id} starting task {id(task)} on {renderer.context.device_name}')
-        if not task.lock.acquire(blocking=False): raise Exception('Got locked task from queue.')
+        log.info(f"Session {task.task_data.session_id} starting task {id(task)} on {renderer.context.device_name}")
+        if not task.lock.acquire(blocking=False):
+            raise Exception("Got locked task from queue.")
         try:
+
             def step_callback():
                 global current_state_error
-                if isinstance(current_state_error, SystemExit) or isinstance(current_state_error, StopAsyncIteration) or isinstance(task.error, StopAsyncIteration):
+                if (
+                    isinstance(current_state_error, SystemExit)
+                    or isinstance(current_state_error, StopAsyncIteration)
+                    or isinstance(task.error, StopAsyncIteration)
+                ):
                     renderer.context.stop_processing = True
                     if isinstance(current_state_error, StopAsyncIteration):
                         task.error = current_state_error
                         current_state_error = None
-                        log.info(f'Session {task.task_data.session_id} sent cancel signal for task {id(task)}')
+                        log.info(f"Session {task.task_data.session_id} sent cancel signal for task {id(task)}")

             current_state = ServerStates.LoadingModel
             model_manager.resolve_model_paths(task.task_data)
             model_manager.reload_models_if_necessary(renderer.context, task.task_data)

             current_state = ServerStates.Rendering
-            task.response = renderer.make_images(task.render_request, task.task_data, task.buffer_queue, task.temp_images, step_callback)
+            task.response = renderer.make_images(
+                task.render_request, task.task_data, task.buffer_queue, task.temp_images, step_callback
+            )
             # Before looping back to the generator, mark cache as still alive.
             task_cache.keep(id(task), TASK_TTL)
             session_cache.keep(task.task_data.session_id, TASK_TTL)
         except Exception as e:
             task.error = str(e)
-            task.response = {"status": 'failed', "detail": str(task.error)}
+            task.response = {"status": "failed", "detail": str(task.error)}
             task.buffer_queue.put(json.dumps(task.response))
             log.error(traceback.format_exc())
         finally:
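# Cancellation above is cooperative rather than forceful: a cancel request
# plants a StopAsyncIteration in current_state_error, and the renderer calls
# step_callback between sampler steps, which flips stop_processing so the
# render loop winds down cleanly. The same handshake, stripped to its shape
# (all names here are illustrative; this is not the patch's API):
class FakeContext:
    stop_processing = False

current_error = None  # a cancel request would assign StopAsyncIteration() here
context = FakeContext()

def on_step(task: dict):
    global current_error
    if isinstance(current_error, (SystemExit, StopAsyncIteration)):
        context.stop_processing = True  # the sampler polls this flag each step
        if isinstance(current_error, StopAsyncIteration):
            task["error"] = current_error  # remember which task was cancelled
            current_error = None  # consume the one-shot signal

current_error = StopAsyncIteration()
task = {}
on_step(task)
print(context.stop_processing, type(task["error"]).__name__)  # True StopAsyncIteration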
@@ -299,21 +350,25 @@
             task_cache.keep(id(task), TASK_TTL)
             session_cache.keep(task.task_data.session_id, TASK_TTL)
             if isinstance(task.error, StopAsyncIteration):
-                log.info(f'Session {task.task_data.session_id} task {id(task)} cancelled!')
+                log.info(f"Session {task.task_data.session_id} task {id(task)} cancelled!")
             elif task.error is not None:
-                log.info(f'Session {task.task_data.session_id} task {id(task)} failed!')
+                log.info(f"Session {task.task_data.session_id} task {id(task)} failed!")
             else:
-                log.info(f'Session {task.task_data.session_id} task {id(task)} completed by {renderer.context.device_name}.')
+                log.info(
+                    f"Session {task.task_data.session_id} task {id(task)} completed by {renderer.context.device_name}."
+                )
         current_state = ServerStates.Online

-def get_cached_task(task_id:str, update_ttl:bool=False):
+
+def get_cached_task(task_id: str, update_ttl: bool = False):
     # By calling keep before tryGet, wont discard if was expired.
     if update_ttl and not task_cache.keep(task_id, TASK_TTL):
         # Failed to keep task, already gone.
         return None
     return task_cache.tryGet(task_id)

-def get_cached_session(session_id:str, update_ttl:bool=False):
+
+def get_cached_session(session_id: str, update_ttl: bool = False):
     if update_ttl:
         session_cache.keep(session_id, TASK_TTL)
     session = session_cache.tryGet(session_id)
@@ -322,64 +377,68 @@ def get_cached_session(session_id:str, update_ttl:bool=False):
         session_cache.put(session_id, session, TASK_TTL)
     return session

+
 def get_devices():
     devices = {
-        'all': {},
-        'active': {},
+        "all": {},
+        "active": {},
     }

     def get_device_info(device):
-        if device == 'cpu':
-            return {'name': device_manager.get_processor_name()}
-
+        if device == "cpu":
+            return {"name": device_manager.get_processor_name()}
+
         mem_free, mem_total = torch.cuda.mem_get_info(device)
         mem_free /= float(10**9)
         mem_total /= float(10**9)

         return {
-            'name': torch.cuda.get_device_name(device),
-            'mem_free': mem_free,
-            'mem_total': mem_total,
-            'max_vram_usage_level': device_manager.get_max_vram_usage_level(device),
+            "name": torch.cuda.get_device_name(device),
+            "mem_free": mem_free,
+            "mem_total": mem_total,
+            "max_vram_usage_level": device_manager.get_max_vram_usage_level(device),
         }

     # list the compatible devices
     gpu_count = torch.cuda.device_count()
     for device in range(gpu_count):
-        device = f'cuda:{device}'
+        device = f"cuda:{device}"
         if not device_manager.is_device_compatible(device):
             continue

-        devices['all'].update({device: get_device_info(device)})
+        devices["all"].update({device: get_device_info(device)})

-    devices['all'].update({'cpu': get_device_info('cpu')})
+    devices["all"].update({"cpu": get_device_info("cpu")})

     # list the activated devices
-    if not manager_lock.acquire(blocking=True, timeout=LOCK_TIMEOUT): raise Exception('get_devices' + ERR_LOCK_FAILED)
+    if not manager_lock.acquire(blocking=True, timeout=LOCK_TIMEOUT):
+        raise Exception("get_devices" + ERR_LOCK_FAILED)
     try:
         for rthread in render_threads:
             if not rthread.is_alive():
                 continue
             weak_data = weak_thread_data.get(rthread)
-            if not weak_data or not 'device' in weak_data or not 'device_name' in weak_data:
+            if not weak_data or not "device" in weak_data or not "device_name" in weak_data:
                 continue
-            device = weak_data['device']
-            devices['active'].update({device: get_device_info(device)})
+            device = weak_data["device"]
+            devices["active"].update({device: get_device_info(device)})
     finally:
         manager_lock.release()

     return devices
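# get_device_info() above converts torch.cuda.mem_get_info's byte counts into
# decimal gigabytes. A hedged standalone sketch of that conversion (requires a
# CUDA-enabled torch build and a visible GPU; device strings like "cuda:0" are
# accepted, as in the patch):
import torch

if torch.cuda.is_available():
    mem_free, mem_total = torch.cuda.mem_get_info("cuda:0")  # (free_bytes, total_bytes)
    name = torch.cuda.get_device_name("cuda:0")
    print(f"{name}: {mem_free / 10**9:.1f} / {mem_total / 10**9:.1f} GB free")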
+
 def is_alive(device=None):
-    if not manager_lock.acquire(blocking=True, timeout=LOCK_TIMEOUT): raise Exception('is_alive' + ERR_LOCK_FAILED)
+    if not manager_lock.acquire(blocking=True, timeout=LOCK_TIMEOUT):
+        raise Exception("is_alive" + ERR_LOCK_FAILED)
     nbr_alive = 0
     try:
         for rthread in render_threads:
             if device is not None:
                 weak_data = weak_thread_data.get(rthread)
-                if weak_data is None or not 'device' in weak_data or weak_data['device'] is None:
+                if weak_data is None or not "device" in weak_data or weak_data["device"] is None:
                     continue
-                thread_device = weak_data['device']
+                thread_device = weak_data["device"]
                 if thread_device != device:
                     continue
             if rthread.is_alive():
@@ -388,11 +447,13 @@
     finally:
         manager_lock.release()

+
 def start_render_thread(device):
-    if not manager_lock.acquire(blocking=True, timeout=LOCK_TIMEOUT): raise Exception('start_render_thread' + ERR_LOCK_FAILED)
-    log.info(f'Start new Rendering Thread on device: {device}')
+    if not manager_lock.acquire(blocking=True, timeout=LOCK_TIMEOUT):
+        raise Exception("start_render_thread" + ERR_LOCK_FAILED)
+    log.info(f"Start new Rendering Thread on device: {device}")
     try:
-        rthread = threading.Thread(target=thread_render, kwargs={'device': device})
+        rthread = threading.Thread(target=thread_render, kwargs={"device": device})
         rthread.daemon = True
         rthread.name = THREAD_NAME_PREFIX + device
         rthread.start()
@@ -400,8 +461,8 @@
     finally:
         manager_lock.release()
     timeout = DEVICE_START_TIMEOUT
-    while not rthread.is_alive() or not rthread in weak_thread_data or not 'device' in weak_thread_data[rthread]:
-        if rthread in weak_thread_data and 'error' in weak_thread_data[rthread]:
+    while not rthread.is_alive() or not rthread in weak_thread_data or not "device" in weak_thread_data[rthread]:
+        if rthread in weak_thread_data and "error" in weak_thread_data[rthread]:
             log.error(f"{rthread}, {device}, error: {weak_thread_data[rthread]['error']}")
             return False
         if timeout <= 0:
@@ -410,25 +471,27 @@
         time.sleep(1)
     return True

+
 def stop_render_thread(device):
     try:
-        device_manager.validate_device_id(device, log_prefix='stop_render_thread')
+        device_manager.validate_device_id(device, log_prefix="stop_render_thread")
     except:
         log.error(traceback.format_exc())
         return False

-    if not manager_lock.acquire(blocking=True, timeout=LOCK_TIMEOUT): raise Exception('stop_render_thread' + ERR_LOCK_FAILED)
-    log.info(f'Stopping Rendering Thread on device: {device}')
+    if not manager_lock.acquire(blocking=True, timeout=LOCK_TIMEOUT):
+        raise Exception("stop_render_thread" + ERR_LOCK_FAILED)
+    log.info(f"Stopping Rendering Thread on device: {device}")
+
     try:
         thread_to_remove = None
         for rthread in render_threads:
             weak_data = weak_thread_data.get(rthread)
-            if weak_data is None or not 'device' in weak_data or weak_data['device'] is None:
+            if weak_data is None or not "device" in weak_data or weak_data["device"] is None:
                 continue
-            thread_device = weak_data['device']
+            thread_device = weak_data["device"]
             if thread_device == device:
-                weak_data['alive'] = False
+                weak_data["alive"] = False
                 thread_to_remove = rthread
                 break
         if thread_to_remove is not None:
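# start_render_thread() above polls, once per second, until the worker thread
# publishes its "device" entry or DEVICE_START_TIMEOUT seconds elapse. The
# same wait loop in isolation (illustrative sketch; a plain dict stands in for
# weak_thread_data):
import threading, time

DEVICE_START_TIMEOUT = 60
thread_data = {}  # worker threads would write {"device": ...} or {"error": ...}

def wait_until_ready(rthread: threading.Thread, timeout: int = DEVICE_START_TIMEOUT) -> bool:
    while not rthread.is_alive() or rthread not in thread_data or "device" not in thread_data[rthread]:
        if rthread in thread_data and "error" in thread_data[rthread]:
            return False  # the worker initialised and reported a failure
        if timeout <= 0:
            return False  # gave up: the device never came up
        timeout -= 1
        time.sleep(1)
    return True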
@@ -439,44 +502,51 @@ def stop_render_thread(device):
         return False

+
 def update_render_threads(render_devices, active_devices):
     devices_to_start, devices_to_stop = device_manager.get_device_delta(render_devices, active_devices)
-    log.debug(f'devices_to_start: {devices_to_start}')
-    log.debug(f'devices_to_stop: {devices_to_stop}')
+    log.debug(f"devices_to_start: {devices_to_start}")
+    log.debug(f"devices_to_stop: {devices_to_stop}")

     for device in devices_to_stop:
         if is_alive(device) <= 0:
-            log.debug(f'{device} is not alive')
+            log.debug(f"{device} is not alive")
             continue
         if not stop_render_thread(device):
-            log.warn(f'{device} could not stop render thread')
+            log.warn(f"{device} could not stop render thread")

     for device in devices_to_start:
         if is_alive(device) >= 1:
-            log.debug(f'{device} already registered.')
+            log.debug(f"{device} already registered.")
             continue
         if not start_render_thread(device):
-            log.warn(f'{device} failed to start.')
+            log.warn(f"{device} failed to start.")

-    if is_alive() <= 0: # No running devices, probably invalid user config.
-        raise EnvironmentError('ERROR: No active render devices! Please verify the "render_devices" value in config.json')
+    if is_alive() <= 0:  # No running devices, probably invalid user config.
+        raise EnvironmentError(
+            'ERROR: No active render devices! Please verify the "render_devices" value in config.json'
+        )

     log.debug(f"active devices: {get_devices()['active']}")

-def shutdown_event(): # Signal render thread to close on shutdown
+
+def shutdown_event():  # Signal render thread to close on shutdown
     global current_state_error
-    current_state_error = SystemExit('Application shutting down.')
+    current_state_error = SystemExit("Application shutting down.")

+
 def render(render_req: GenerateImageRequest, task_data: TaskData):
     current_thread_count = is_alive()
     if current_thread_count <= 0:  # Render thread is dead
-        raise ChildProcessError('Rendering thread has died.')
+        raise ChildProcessError("Rendering thread has died.")

     # Alive, check if task in cache
     session = get_cached_session(task_data.session_id, update_ttl=True)
     pending_tasks = list(filter(lambda t: t.is_pending, session.tasks))
     if current_thread_count < len(pending_tasks):
-        raise ConnectionRefusedError(f'Session {task_data.session_id} already has {len(pending_tasks)} pending tasks out of {current_thread_count}.')
+        raise ConnectionRefusedError(
+            f"Session {task_data.session_id} already has {len(pending_tasks)} pending tasks out of {current_thread_count}."
+        )

     new_task = RenderTask(render_req, task_data)
     if session.put(new_task, TASK_TTL):
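# The guard above is the server's back-pressure rule in miniature: a session
# may keep at most as many pending tasks as there are live render threads,
# and a dead pool is reported as a distinct error. Restated standalone
# (illustrative names, not the patch's function):
def check_capacity(live_threads: int, pending_in_session: int, session_id: str) -> None:
    if live_threads <= 0:
        raise ChildProcessError("Rendering thread has died.")
    if live_threads < pending_in_session:
        raise ConnectionRefusedError(
            f"Session {session_id} already has {pending_in_session} pending tasks out of {live_threads}."
        )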
@@ -489,4 +559,4 @@ def render(render_req: GenerateImageRequest, task_data: TaskData):
             return new_task
     finally:
         manager_lock.release()
-    raise RuntimeError('Failed to add task to cache.')
+    raise RuntimeError("Failed to add task to cache.")
diff --git a/ui/easydiffusion/types.py b/ui/easydiffusion/types.py
index c70e125b..b45eafd2 100644
--- a/ui/easydiffusion/types.py
+++ b/ui/easydiffusion/types.py
@@ -1,6 +1,7 @@
 from pydantic import BaseModel
 from typing import Any

+
 class GenerateImageRequest(BaseModel):
     prompt: str = ""
     negative_prompt: str = ""
@@ -18,29 +19,31 @@ class GenerateImageRequest(BaseModel):
     prompt_strength: float = 0.8
     preserve_init_image_color_profile = False

-    sampler_name: str = None # "ddim", "plms", "heun", "euler", "euler_a", "dpm2", "dpm2_a", "lms"
+    sampler_name: str = None  # "ddim", "plms", "heun", "euler", "euler_a", "dpm2", "dpm2_a", "lms"
     hypernetwork_strength: float = 0

+
 class TaskData(BaseModel):
     request_id: str = None
     session_id: str = "session"
     save_to_disk_path: str = None
-    vram_usage_level: str = "balanced" # or "low" or "medium"
+    vram_usage_level: str = "balanced"  # or "low" or "medium"

-    use_face_correction: str = None # or "GFPGANv1.3"
-    use_upscale: str = None # or "RealESRGAN_x4plus" or "RealESRGAN_x4plus_anime_6B"
-    upscale_amount: int = 4 # or 2
+    use_face_correction: str = None  # or "GFPGANv1.3"
+    use_upscale: str = None  # or "RealESRGAN_x4plus" or "RealESRGAN_x4plus_anime_6B"
+    upscale_amount: int = 4  # or 2
     use_stable_diffusion_model: str = "sd-v1-4"
     # use_stable_diffusion_config: str = "v1-inference"
     use_vae_model: str = None
     use_hypernetwork_model: str = None

     show_only_filtered_image: bool = False
-    output_format: str = "jpeg" # or "png"
+    output_format: str = "jpeg"  # or "png"
     output_quality: int = 75
-    metadata_output_format: str = "txt" # or "json"
+    metadata_output_format: str = "txt"  # or "json"
     stream_image_progress: bool = False

+
 class MergeRequest(BaseModel):
     model0: str = None
     model1: str = None
@@ -48,8 +51,9 @@ class MergeRequest(BaseModel):
     out_path: str = "mix"
     use_fp16 = True

+
 class Image:
-    data: str # base64
+    data: str  # base64
     seed: int
     is_nsfw: bool
     path_abs: str = None
@@ -65,6 +69,7 @@ class Image:
             "path_abs": self.path_abs,
         }

+
 class Response:
     render_request: GenerateImageRequest
     task_data: TaskData
@@ -80,7 +85,7 @@ class Response:
             del self.render_request.init_image_mask

         res = {
-            "status": 'succeeded',
+            "status": "succeeded",
             "render_request": self.render_request.dict(),
             "task_data": self.task_data.dict(),
             "output": [],
@@ -91,5 +96,6 @@ class Response:

         return res

+
 class UserInitiatedStop(Exception):
     pass
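# The pydantic models above give the HTTP layer validation and defaulting for
# free: any field the client omits falls back to the default declared in the
# class body, and .dict() (the pydantic v1 API used throughout this patch)
# serialises the result for metadata. A trimmed sketch (two fields only; the
# second field name is invented for illustration):
from pydantic import BaseModel

class RequestSketch(BaseModel):
    prompt: str = ""
    steps_hint: int = 25  # illustrative field, not part of the real model

req = RequestSketch(prompt="a watercolor fox")
print(req.dict())  # {'prompt': 'a watercolor fox', 'steps_hint': 25}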
diff --git a/ui/easydiffusion/utils/__init__.py b/ui/easydiffusion/utils/__init__.py
index 8be070b4..b9c5e21a 100644
--- a/ui/easydiffusion/utils/__init__.py
+++ b/ui/easydiffusion/utils/__init__.py
@@ -1,8 +1,8 @@
 import logging

-log = logging.getLogger('easydiffusion')
+log = logging.getLogger("easydiffusion")

 from .save_utils import (
     save_images_to_disk,
     get_printable_request,
-)
\ No newline at end of file
+)
diff --git a/ui/easydiffusion/utils/save_utils.py b/ui/easydiffusion/utils/save_utils.py
index 1835b7e8..b4a85538 100644
--- a/ui/easydiffusion/utils/save_utils.py
+++ b/ui/easydiffusion/utils/save_utils.py
@@ -7,89 +7,126 @@ from easydiffusion.types import TaskData, GenerateImageRequest
 from sdkit.utils import save_images, save_dicts

-filename_regex = re.compile('[^a-zA-Z0-9._-]')
+filename_regex = re.compile("[^a-zA-Z0-9._-]")

 # keep in sync with `ui/media/js/dnd.js`
 TASK_TEXT_MAPPING = {
-    'prompt': 'Prompt',
-    'width': 'Width',
-    'height': 'Height',
-    'seed': 'Seed',
-    'num_inference_steps': 'Steps',
-    'guidance_scale': 'Guidance Scale',
-    'prompt_strength': 'Prompt Strength',
-    'use_face_correction': 'Use Face Correction',
-    'use_upscale': 'Use Upscaling',
-    'upscale_amount': 'Upscale By',
-    'sampler_name': 'Sampler',
-    'negative_prompt': 'Negative Prompt',
-    'use_stable_diffusion_model': 'Stable Diffusion model',
-    'use_vae_model': 'VAE model',
-    'use_hypernetwork_model': 'Hypernetwork model',
-    'hypernetwork_strength': 'Hypernetwork Strength'
+    "prompt": "Prompt",
+    "width": "Width",
+    "height": "Height",
+    "seed": "Seed",
+    "num_inference_steps": "Steps",
+    "guidance_scale": "Guidance Scale",
+    "prompt_strength": "Prompt Strength",
+    "use_face_correction": "Use Face Correction",
+    "use_upscale": "Use Upscaling",
+    "upscale_amount": "Upscale By",
+    "sampler_name": "Sampler",
+    "negative_prompt": "Negative Prompt",
+    "use_stable_diffusion_model": "Stable Diffusion model",
+    "use_vae_model": "VAE model",
+    "use_hypernetwork_model": "Hypernetwork model",
+    "hypernetwork_strength": "Hypernetwork Strength",
 }

+
 def save_images_to_disk(images: list, filtered_images: list, req: GenerateImageRequest, task_data: TaskData):
     now = time.time()
-    save_dir_path = os.path.join(task_data.save_to_disk_path, filename_regex.sub('_', task_data.session_id))
+    save_dir_path = os.path.join(task_data.save_to_disk_path, filename_regex.sub("_", task_data.session_id))
     metadata_entries = get_metadata_entries_for_request(req, task_data)
     make_filename = make_filename_callback(req, now=now)

     if task_data.show_only_filtered_image or filtered_images is images:
-        save_images(filtered_images, save_dir_path, file_name=make_filename, output_format=task_data.output_format, output_quality=task_data.output_quality)
-        if task_data.metadata_output_format.lower() in ['json', 'txt', 'embed']:
-            save_dicts(metadata_entries, save_dir_path, file_name=make_filename, output_format=task_data.metadata_output_format, file_format=task_data.output_format)
+        save_images(
+            filtered_images,
+            save_dir_path,
+            file_name=make_filename,
+            output_format=task_data.output_format,
+            output_quality=task_data.output_quality,
+        )
+        if task_data.metadata_output_format.lower() in ["json", "txt", "embed"]:
+            save_dicts(
+                metadata_entries,
+                save_dir_path,
+                file_name=make_filename,
+                output_format=task_data.metadata_output_format,
+                file_format=task_data.output_format,
+            )
     else:
-        make_filter_filename = make_filename_callback(req, now=now, suffix='filtered')
+        make_filter_filename = make_filename_callback(req, now=now, suffix="filtered")
+
+        save_images(
+            images,
+            save_dir_path,
+            file_name=make_filename,
+            output_format=task_data.output_format,
+            output_quality=task_data.output_quality,
+        )
+        save_images(
+            filtered_images,
+            save_dir_path,
+            file_name=make_filter_filename,
+            output_format=task_data.output_format,
+            output_quality=task_data.output_quality,
+        )
+        if task_data.metadata_output_format.lower() in ["json", "txt", "embed"]:
+            save_dicts(
+                metadata_entries,
+                save_dir_path,
+                file_name=make_filter_filename,
+                output_format=task_data.metadata_output_format,
+                file_format=task_data.output_format,
+            )

-        save_images(images, save_dir_path, file_name=make_filename, output_format=task_data.output_format, output_quality=task_data.output_quality)
-        save_images(filtered_images, save_dir_path, file_name=make_filter_filename, output_format=task_data.output_format, output_quality=task_data.output_quality)
-        if task_data.metadata_output_format.lower() in ['json', 'txt', 'embed']:
-            save_dicts(metadata_entries, save_dir_path, file_name=make_filter_filename, output_format=task_data.metadata_output_format, file_format=task_data.output_format)
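# What the TASK_TEXT_MAPPING table above is for: when metadata_output_format
# is "txt", internal request keys are rewritten into the human-readable labels
# that the UI's drag-and-drop import (ui/media/js/dnd.js) expects, and unmapped
# keys are dropped. The same dict comprehension as in the function below, run
# on an excerpt of the table:
TASK_TEXT_MAPPING = {"prompt": "Prompt", "seed": "Seed", "width": "Width"}  # excerpt

metadata = {"prompt": "a watercolor fox", "seed": 42, "width": 512, "internal_flag": True}
txt_metadata = {TASK_TEXT_MAPPING[k]: v for k, v in metadata.items() if k in TASK_TEXT_MAPPING}
print(txt_metadata)  # {'Prompt': 'a watercolor fox', 'Seed': 42, 'Width': 512}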

 def get_metadata_entries_for_request(req: GenerateImageRequest, task_data: TaskData):
     metadata = get_printable_request(req)
-    metadata.update({
-        'use_stable_diffusion_model': task_data.use_stable_diffusion_model,
-        'use_vae_model': task_data.use_vae_model,
-        'use_hypernetwork_model': task_data.use_hypernetwork_model,
-        'use_face_correction': task_data.use_face_correction,
-        'use_upscale': task_data.use_upscale,
-    })
-    if metadata['use_upscale'] is not None:
-        metadata['upscale_amount'] = task_data.upscale_amount
-    if (task_data.use_hypernetwork_model is None):
-        del metadata['hypernetwork_strength']
+    metadata.update(
+        {
+            "use_stable_diffusion_model": task_data.use_stable_diffusion_model,
+            "use_vae_model": task_data.use_vae_model,
+            "use_hypernetwork_model": task_data.use_hypernetwork_model,
+            "use_face_correction": task_data.use_face_correction,
+            "use_upscale": task_data.use_upscale,
+        }
+    )
+    if metadata["use_upscale"] is not None:
+        metadata["upscale_amount"] = task_data.upscale_amount
+    if task_data.use_hypernetwork_model is None:
+        del metadata["hypernetwork_strength"]

     # if text, format it in the text format expected by the UI
-    is_txt_format = (task_data.metadata_output_format.lower() == 'txt')
+    is_txt_format = task_data.metadata_output_format.lower() == "txt"
     if is_txt_format:
         metadata = {TASK_TEXT_MAPPING[key]: val for key, val in metadata.items() if key in TASK_TEXT_MAPPING}

     entries = [metadata.copy() for _ in range(req.num_outputs)]
     for i, entry in enumerate(entries):
-        entry['Seed' if is_txt_format else 'seed'] = req.seed + i
+        entry["Seed" if is_txt_format else "seed"] = req.seed + i
     return entries

+
 def get_printable_request(req: GenerateImageRequest):
     metadata = req.dict()
-    del metadata['init_image']
-    del metadata['init_image_mask']
-    if (req.init_image is None):
-        del metadata['prompt_strength']
+    del metadata["init_image"]
+    del metadata["init_image_mask"]
+    if req.init_image is None:
+        del metadata["prompt_strength"]
     return metadata

+
 def make_filename_callback(req: GenerateImageRequest, suffix=None, now=None):
     if now is None:
         now = time.time()
-    def make_filename(i):
-        img_id = base64.b64encode(int(now+i).to_bytes(8, 'big')).decode() # Generate unique ID based on time.
-        img_id = img_id.translate({43:None, 47:None, 61:None})[-8:] # Remove + / = and keep last 8 chars.
-        prompt_flattened = filename_regex.sub('_', req.prompt)[:50]
+    def make_filename(i):
+        img_id = base64.b64encode(int(now + i).to_bytes(8, "big")).decode()  # Generate unique ID based on time.
+        img_id = img_id.translate({43: None, 47: None, 61: None})[-8:]  # Remove + / = and keep last 8 chars.
+
+        prompt_flattened = filename_regex.sub("_", req.prompt)[:50]
         name = f"{prompt_flattened}_{img_id}"
-        name = name if suffix is None else f'{name}_{suffix}'
+        name = name if suffix is None else f"{name}_{suffix}"
         return name

     return make_filename
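# The filename scheme above, shown end to end (a standalone sketch using the
# same arithmetic as make_filename): the image id is the tail of a base64ed
# big-endian timestamp with "+", "/" and "=" stripped, so ids are
# filesystem-safe and sort roughly by creation time.
import base64, re, time

filename_regex = re.compile("[^a-zA-Z0-9._-]")

def sketch_filename(prompt: str, i: int = 0) -> str:
    img_id = base64.b64encode(int(time.time() + i).to_bytes(8, "big")).decode()
    img_id = img_id.translate({43: None, 47: None, 61: None})[-8:]  # drop + / =
    return f"{filename_regex.sub('_', prompt)[:50]}_{img_id}"

print(sketch_filename("a watercolor fox, misty morning"))
# -> a_watercolor_fox__misty_morning_<8-char id>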