From fb6a7e04f56b1f1ac089f68d196d84600e400ea3 Mon Sep 17 00:00:00 2001 From: cmdr2 Date: Wed, 7 Dec 2022 22:15:35 +0530 Subject: [PATCH 01/74] Work-in-progress refactor of the backend, to move most of the logic to diffusion-kit and keeping this as a UI around that engine. Does not work yet. --- ui/sd_internal/__init__.py | 1 - ui/sd_internal/app.py | 156 ++++++++++++ ui/sd_internal/device_manager.py | 23 +- ui/sd_internal/model_manager.py | 141 +++++++++++ ui/sd_internal/runtime2.py | 95 ++++++++ ui/sd_internal/task_manager.py | 89 ++----- ui/server.py | 396 +++++-------------------------- 7 files changed, 484 insertions(+), 417 deletions(-) create mode 100644 ui/sd_internal/app.py create mode 100644 ui/sd_internal/model_manager.py create mode 100644 ui/sd_internal/runtime2.py diff --git a/ui/sd_internal/__init__.py b/ui/sd_internal/__init__.py index a2abe294..9f71f249 100644 --- a/ui/sd_internal/__init__.py +++ b/ui/sd_internal/__init__.py @@ -40,7 +40,6 @@ class Request: "num_outputs": self.num_outputs, "num_inference_steps": self.num_inference_steps, "guidance_scale": self.guidance_scale, - "hypernetwork_strengtgh": self.guidance_scale, "width": self.width, "height": self.height, "seed": self.seed, diff --git a/ui/sd_internal/app.py b/ui/sd_internal/app.py new file mode 100644 index 00000000..09242319 --- /dev/null +++ b/ui/sd_internal/app.py @@ -0,0 +1,156 @@ +import os +import socket +import sys +import json +import traceback + +from sd_internal import task_manager + +SD_DIR = os.getcwd() + +SD_UI_DIR = os.getenv('SD_UI_PATH', None) +sys.path.append(os.path.dirname(SD_UI_DIR)) + +CONFIG_DIR = os.path.abspath(os.path.join(SD_UI_DIR, '..', 'scripts')) +MODELS_DIR = os.path.abspath(os.path.join(SD_DIR, '..', 'models')) + +USER_UI_PLUGINS_DIR = os.path.abspath(os.path.join(SD_DIR, '..', 'plugins', 'ui')) +CORE_UI_PLUGINS_DIR = os.path.abspath(os.path.join(SD_UI_DIR, 'plugins', 'ui')) +UI_PLUGINS_SOURCES = ((CORE_UI_PLUGINS_DIR, 'core'), (USER_UI_PLUGINS_DIR, 'user')) + +STABLE_DIFFUSION_MODEL_EXTENSIONS = ['.ckpt', '.safetensors'] +VAE_MODEL_EXTENSIONS = ['.vae.pt', '.ckpt'] +HYPERNETWORK_MODEL_EXTENSIONS = ['.pt'] + +OUTPUT_DIRNAME = "Stable Diffusion UI" # in the user's home folder +TASK_TTL = 15 * 60 # Discard last session's task timeout +APP_CONFIG_DEFAULTS = { + # auto: selects the cuda device with the most free memory, cuda: use the currently active cuda device. + 'render_devices': 'auto', # valid entries: 'auto', 'cpu' or 'cuda:N' (where N is a GPU index) + 'update_branch': 'main', + 'ui': { + 'open_browser_on_start': True, + }, +} +DEFAULT_MODELS = [ + # needed to support the legacy installations + 'custom-model', # Check if user has a custom model, use it first. + 'sd-v1-4', # Default fallback. 
+] + +def init(): + os.makedirs(USER_UI_PLUGINS_DIR, exist_ok=True) + + update_render_threads() + +def getConfig(default_val=APP_CONFIG_DEFAULTS): + try: + config_json_path = os.path.join(CONFIG_DIR, 'config.json') + if not os.path.exists(config_json_path): + return default_val + with open(config_json_path, 'r', encoding='utf-8') as f: + config = json.load(f) + if 'net' not in config: + config['net'] = {} + if os.getenv('SD_UI_BIND_PORT') is not None: + config['net']['listen_port'] = int(os.getenv('SD_UI_BIND_PORT')) + if os.getenv('SD_UI_BIND_IP') is not None: + config['net']['listen_to_network'] = (os.getenv('SD_UI_BIND_IP') == '0.0.0.0') + return config + except Exception as e: + print(str(e)) + print(traceback.format_exc()) + return default_val + +def setConfig(config): + try: # config.json + config_json_path = os.path.join(CONFIG_DIR, 'config.json') + with open(config_json_path, 'w', encoding='utf-8') as f: + json.dump(config, f) + except: + print(traceback.format_exc()) + + try: # config.bat + config_bat_path = os.path.join(CONFIG_DIR, 'config.bat') + config_bat = [] + + if 'update_branch' in config: + config_bat.append(f"@set update_branch={config['update_branch']}") + + config_bat.append(f"@set SD_UI_BIND_PORT={config['net']['listen_port']}") + bind_ip = '0.0.0.0' if config['net']['listen_to_network'] else '127.0.0.1' + config_bat.append(f"@set SD_UI_BIND_IP={bind_ip}") + + config_bat.append(f"@set test_sd2={'Y' if config.get('test_sd2', False) else 'N'}") + + if len(config_bat) > 0: + with open(config_bat_path, 'w', encoding='utf-8') as f: + f.write('\r\n'.join(config_bat)) + except: + print(traceback.format_exc()) + + try: # config.sh + config_sh_path = os.path.join(CONFIG_DIR, 'config.sh') + config_sh = ['#!/bin/bash'] + + if 'update_branch' in config: + config_sh.append(f"export update_branch={config['update_branch']}") + + config_sh.append(f"export SD_UI_BIND_PORT={config['net']['listen_port']}") + bind_ip = '0.0.0.0' if config['net']['listen_to_network'] else '127.0.0.1' + config_sh.append(f"export SD_UI_BIND_IP={bind_ip}") + + config_sh.append(f"export test_sd2=\"{'Y' if config.get('test_sd2', False) else 'N'}\"") + + if len(config_sh) > 1: + with open(config_sh_path, 'w', encoding='utf-8') as f: + f.write('\n'.join(config_sh)) + except: + print(traceback.format_exc()) + +def save_model_to_config(ckpt_model_name, vae_model_name, hypernetwork_model_name): + config = getConfig() + if 'model' not in config: + config['model'] = {} + + config['model']['stable-diffusion'] = ckpt_model_name + config['model']['vae'] = vae_model_name + config['model']['hypernetwork'] = hypernetwork_model_name + + if vae_model_name is None or vae_model_name == "": + del config['model']['vae'] + if hypernetwork_model_name is None or hypernetwork_model_name == "": + del config['model']['hypernetwork'] + + setConfig(config) + +def update_render_threads(): + config = getConfig() + render_devices = config.get('render_devices', 'auto') + active_devices = task_manager.get_devices()['active'].keys() + + print('requesting for render_devices', render_devices) + task_manager.update_render_threads(render_devices, active_devices) + +def getUIPlugins(): + plugins = [] + + for plugins_dir, dir_prefix in UI_PLUGINS_SOURCES: + for file in os.listdir(plugins_dir): + if file.endswith('.plugin.js'): + plugins.append(f'/plugins/{dir_prefix}/{file}') + + return plugins + +def getIPConfig(): + ips = socket.gethostbyname_ex(socket.gethostname()) + ips[2].append(ips[0]) + return ips[2] + +def open_browser(): + config = 
getConfig() + ui = config.get('ui', {}) + net = config.get('net', {'listen_port':9000}) + port = net.get('listen_port', 9000) + if ui.get('open_browser_on_start', True): + import webbrowser; webbrowser.open(f"http://localhost:{port}") diff --git a/ui/sd_internal/device_manager.py b/ui/sd_internal/device_manager.py index d2c6430b..3c1d8bf1 100644 --- a/ui/sd_internal/device_manager.py +++ b/ui/sd_internal/device_manager.py @@ -82,7 +82,7 @@ def auto_pick_devices(currently_active_devices): devices = list(map(lambda x: x['device'], devices)) return devices -def device_init(thread_data, device): +def device_init(context, device): ''' This function assumes the 'device' has already been verified to be compatible. `get_device_delta()` has already filtered out incompatible devices. @@ -91,21 +91,22 @@ def device_init(thread_data, device): validate_device_id(device, log_prefix='device_init') if device == 'cpu': - thread_data.device = 'cpu' - thread_data.device_name = get_processor_name() - print('Render device CPU available as', thread_data.device_name) + context.device = 'cpu' + context.device_name = get_processor_name() + context.precision = 'full' + print('Render device CPU available as', context.device_name) return - thread_data.device_name = torch.cuda.get_device_name(device) - thread_data.device = device + context.device_name = torch.cuda.get_device_name(device) + context.device = device # Force full precision on 1660 and 1650 NVIDIA cards to avoid creating green images - device_name = thread_data.device_name.lower() - thread_data.force_full_precision = (('nvidia' in device_name or 'geforce' in device_name) and (' 1660' in device_name or ' 1650' in device_name)) or ('Quadro T2000' in device_name) - if thread_data.force_full_precision: - print('forcing full precision on NVIDIA 16xx cards, to avoid green images. GPU detected: ', thread_data.device_name) + device_name = context.device_name.lower() + force_full_precision = (('nvidia' in device_name or 'geforce' in device_name) and (' 1660' in device_name or ' 1650' in device_name)) or ('Quadro T2000' in device_name) + if force_full_precision: + print('forcing full precision on NVIDIA 16xx cards, to avoid green images. GPU detected: ', context.device_name) # Apply force_full_precision now before models are loaded. - thread_data.precision = 'full' + context.precision = 'full' print(f'Setting {device} as active') torch.cuda.device(device) diff --git a/ui/sd_internal/model_manager.py b/ui/sd_internal/model_manager.py new file mode 100644 index 00000000..3941b130 --- /dev/null +++ b/ui/sd_internal/model_manager.py @@ -0,0 +1,141 @@ +import os + +from sd_internal import app +import picklescan.scanner +import rich + +default_model_to_load = None +default_vae_to_load = None +default_hypernetwork_to_load = None + +known_models = {} + +def init(): + global default_model_to_load, default_vae_to_load, default_hypernetwork_to_load + + default_model_to_load = resolve_ckpt_to_use() + default_vae_to_load = resolve_vae_to_use() + default_hypernetwork_to_load = resolve_hypernetwork_to_use() + + getModels() # run this once, to cache the picklescan results + +def resolve_model_to_use(model_name:str, model_type:str, model_dir:str, model_extensions:list, default_models=[]): + config = app.getConfig() + + model_dirs = [os.path.join(app.MODELS_DIR, model_dir), app.SD_DIR] + if not model_name: # When None try user configured model. 
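+        # Resolution order: an explicitly requested model name wins, then
+        # the name saved in config.json, then the default fallbacks below.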
+ # config = getConfig() + if 'model' in config and model_type in config['model']: + model_name = config['model'][model_type] + + if model_name: + is_sd2 = config.get('test_sd2', False) + if model_name.startswith('sd2_') and not is_sd2: # temp hack, until SD2 is unified with 1.4 + print('ERROR: Cannot use SD 2.0 models with SD 1.0 code. Using the sd-v1-4 model instead!') + model_name = 'sd-v1-4' + + # Check models directory + models_dir_path = os.path.join(app.MODELS_DIR, model_dir, model_name) + for model_extension in model_extensions: + if os.path.exists(models_dir_path + model_extension): + return models_dir_path + model_extension + if os.path.exists(model_name + model_extension): + return os.path.abspath(model_name + model_extension) + + # Default locations + if model_name in default_models: + default_model_path = os.path.join(app.SD_DIR, model_name) + for model_extension in model_extensions: + if os.path.exists(default_model_path + model_extension): + return default_model_path + model_extension + + # Can't find requested model, check the default paths. + for default_model in default_models: + for model_dir in model_dirs: + default_model_path = os.path.join(model_dir, default_model) + for model_extension in model_extensions: + if os.path.exists(default_model_path + model_extension): + if model_name is not None: + print(f'Could not find the configured custom model {model_name}{model_extension}. Using the default one: {default_model_path}{model_extension}') + return default_model_path + model_extension + + raise Exception('No valid models found.') + +def resolve_ckpt_to_use(model_name:str=None): + return resolve_model_to_use(model_name, model_type='stable-diffusion', model_dir='stable-diffusion', model_extensions=app.STABLE_DIFFUSION_MODEL_EXTENSIONS, default_models=app.APP_CONFIG_DEFAULT_MODELS) + +def resolve_vae_to_use(model_name:str=None): + try: + return resolve_model_to_use(model_name, model_type='vae', model_dir='vae', model_extensions=app.VAE_MODEL_EXTENSIONS, default_models=[]) + except: + return None + +def resolve_hypernetwork_to_use(model_name:str=None): + try: + return resolve_model_to_use(model_name, model_type='hypernetwork', model_dir='hypernetwork', model_extensions=app.HYPERNETWORK_MODEL_EXTENSIONS, default_models=[]) + except: + return None + +def is_malicious_model(file_path): + try: + scan_result = picklescan.scanner.scan_file_path(file_path) + if scan_result.issues_count > 0 or scan_result.infected_files > 0: + rich.print(":warning: [bold red]Scan %s: %d scanned, %d issue, %d infected.[/bold red]" % (file_path, scan_result.scanned_files, scan_result.issues_count, scan_result.infected_files)) + return True + else: + rich.print("Scan %s: [green]%d scanned, %d issue, %d infected.[/green]" % (file_path, scan_result.scanned_files, scan_result.issues_count, scan_result.infected_files)) + return False + except Exception as e: + print('error while scanning', file_path, 'error:', e) + return False + +def getModels(): + models = { + 'active': { + 'stable-diffusion': 'sd-v1-4', + 'vae': '', + 'hypernetwork': '', + }, + 'options': { + 'stable-diffusion': ['sd-v1-4'], + 'vae': [], + 'hypernetwork': [], + }, + } + + def listModels(models_dirname, model_type, model_extensions): + models_dir = os.path.join(app.MODELS_DIR, models_dirname) + if not os.path.exists(models_dir): + os.makedirs(models_dir) + + for file in os.listdir(models_dir): + for model_extension in model_extensions: + if not file.endswith(model_extension): + continue + + model_path = os.path.join(models_dir, file) + 
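+                # mtime-based scan cache: re-run the (slow) picklescan
+                # check only when a model file's modification time changes.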
mtime = os.path.getmtime(model_path) + mod_time = known_models[model_path] if model_path in known_models else -1 + if mod_time != mtime: + if is_malicious_model(model_path): + models['scan-error'] = file + return + known_models[model_path] = mtime + + model_name = file[:-len(model_extension)] + models['options'][model_type].append(model_name) + + models['options'][model_type] = [*set(models['options'][model_type])] # remove duplicates + models['options'][model_type].sort() + + # custom models + listModels(models_dirname='stable-diffusion', model_type='stable-diffusion', model_extensions=app.STABLE_DIFFUSION_MODEL_EXTENSIONS) + listModels(models_dirname='vae', model_type='vae', model_extensions=app.VAE_MODEL_EXTENSIONS) + listModels(models_dirname='hypernetwork', model_type='hypernetwork', model_extensions=app.HYPERNETWORK_MODEL_EXTENSIONS) + + # legacy + custom_weight_path = os.path.join(app.SD_DIR, 'custom-model.ckpt') + if os.path.exists(custom_weight_path): + models['options']['stable-diffusion'].append('custom-model') + + return models diff --git a/ui/sd_internal/runtime2.py b/ui/sd_internal/runtime2.py new file mode 100644 index 00000000..32befcde --- /dev/null +++ b/ui/sd_internal/runtime2.py @@ -0,0 +1,95 @@ +import threading +import queue + +from sd_internal import device_manager, Request, Response, Image as ResponseImage + +from modules import model_loader, image_generator, image_utils + +thread_data = threading.local() +''' +runtime data (bound locally to this thread), for e.g. device, references to loaded models, optimization flags etc +''' + +def init(device): + ''' + Initializes the fields that will be bound to this runtime's thread_data, and sets the current torch device + ''' + thread_data.stop_processing = False + thread_data.temp_images = {} + + thread_data.models = {} + thread_data.loaded_model_paths = {} + + thread_data.device = None + thread_data.device_name = None + thread_data.precision = 'autocast' + thread_data.vram_optimizations = ('TURBO', 'MOVE_MODELS') + + device_manager.device_init(thread_data, device) + + reload_models() + +def destroy(): + model_loader.unload_sd_model(thread_data) + model_loader.unload_gfpgan_model(thread_data) + model_loader.unload_realesrgan_model(thread_data) + +def reload_models(req: Request=None): + if is_hypernetwork_reload_necessary(task.request): + current_state = ServerStates.LoadingModel + runtime.reload_hypernetwork() + + if is_model_reload_necessary(task.request): + current_state = ServerStates.LoadingModel + runtime.reload_model() + +def load_models(): + if ckpt_file_path == None: + ckpt_file_path = default_model_to_load + if vae_file_path == None: + vae_file_path = default_vae_to_load + if hypernetwork_file_path == None: + hypernetwork_file_path = default_hypernetwork_to_load + if ckpt_file_path == current_model_path and vae_file_path == current_vae_path: + return + current_state = ServerStates.LoadingModel + try: + from sd_internal import runtime2 + runtime.thread_data.hypernetwork_file = hypernetwork_file_path + runtime.thread_data.ckpt_file = ckpt_file_path + runtime.thread_data.vae_file = vae_file_path + runtime.load_model_ckpt() + runtime.load_hypernetwork() + current_model_path = ckpt_file_path + current_vae_path = vae_file_path + current_hypernetwork_path = hypernetwork_file_path + current_state_error = None + current_state = ServerStates.Online + except Exception as e: + current_model_path = None + current_vae_path = None + current_state_error = e + current_state = ServerStates.Unavailable + print(traceback.format_exc()) 
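
The `thread_data = threading.local()` object near the top of runtime2.py is the
heart of this refactor: every render thread binds its own device handle, loaded
models and precision flags to the same module-level name, and each thread sees
only its own copy. A minimal, self-contained illustration of the pattern,
independent of this codebase:

    import threading

    thread_data = threading.local()

    def worker(device):
        # Each thread gets a private 'device' attribute on the shared name.
        thread_data.device = device
        print(threading.current_thread().name, '->', thread_data.device)

    threads = [threading.Thread(target=worker, args=(d,)) for d in ('cuda:0', 'cpu')]
    for t in threads:
        t.start()
    for t in threads:
        t.join()
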
+ +def make_image(req: Request, data_queue: queue.Queue, task_temp_images: list, step_callback): + try: + images = image_generator.make_image(context=thread_data, args=get_mk_img_args(req)) + except UserInitiatedStop: + pass + +def get_mk_img_args(req: Request): + args = req.json() + + if req.init_image is not None: + args['init_image'] = image_utils.base64_str_to_img(req.init_image) + + if req.mask is not None: + args['mask'] = image_utils.base64_str_to_img(req.mask) + + return args + +def on_image_step(x_samples, i): + pass + +image_generator.on_image_step = on_image_step diff --git a/ui/sd_internal/task_manager.py b/ui/sd_internal/task_manager.py index 41fc00f6..4c2b5c6c 100644 --- a/ui/sd_internal/task_manager.py +++ b/ui/sd_internal/task_manager.py @@ -177,50 +177,14 @@ manager_lock = threading.RLock() render_threads = [] current_state = ServerStates.Init current_state_error:Exception = None -current_model_path = None -current_vae_path = None -current_hypernetwork_path = None tasks_queue = [] task_cache = TaskCache() -default_model_to_load = None -default_vae_to_load = None -default_hypernetwork_to_load = None weak_thread_data = weakref.WeakKeyDictionary() -def preload_model(ckpt_file_path=None, vae_file_path=None, hypernetwork_file_path=None): - global current_state, current_state_error, current_model_path, current_vae_path, current_hypernetwork_path - if ckpt_file_path == None: - ckpt_file_path = default_model_to_load - if vae_file_path == None: - vae_file_path = default_vae_to_load - if hypernetwork_file_path == None: - hypernetwork_file_path = default_hypernetwork_to_load - if ckpt_file_path == current_model_path and vae_file_path == current_vae_path: - return - current_state = ServerStates.LoadingModel - try: - from . import runtime - runtime.thread_data.hypernetwork_file = hypernetwork_file_path - runtime.thread_data.ckpt_file = ckpt_file_path - runtime.thread_data.vae_file = vae_file_path - runtime.load_model_ckpt() - runtime.load_hypernetwork() - current_model_path = ckpt_file_path - current_vae_path = vae_file_path - current_hypernetwork_path = hypernetwork_file_path - current_state_error = None - current_state = ServerStates.Online - except Exception as e: - current_model_path = None - current_vae_path = None - current_state_error = e - current_state = ServerStates.Unavailable - print(traceback.format_exc()) - def thread_get_next_task(): - from . import runtime + from sd_internal import runtime2 if not manager_lock.acquire(blocking=True, timeout=LOCK_TIMEOUT): - print('Render thread on device', runtime.thread_data.device, 'failed to acquire manager lock.') + print('Render thread on device', runtime2.thread_data.device, 'failed to acquire manager lock.') return None if len(tasks_queue) <= 0: manager_lock.release() @@ -228,7 +192,7 @@ def thread_get_next_task(): task = None try: # Select a render task. for queued_task in tasks_queue: - if queued_task.render_device and runtime.thread_data.device != queued_task.render_device: + if queued_task.render_device and runtime2.thread_data.device != queued_task.render_device: # Is asking for a specific render device. if is_alive(queued_task.render_device) > 0: continue # requested device alive, skip current one. 
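
The scheduling rule in the hunk above (continued in the next hunk): a task
pinned to a specific render device is left for that device's thread while it is
alive, fails fast if the device is gone, and the CPU thread only takes unpinned
tasks when no other render thread survives. A condensed sketch of the same
policy; the names and the simplified signature are illustrative, not the
module's actual API:

    def pick_next_task(tasks_queue, my_device, alive_count, is_alive):
        for task in tasks_queue:
            if task.render_device and my_device != task.render_device:
                if is_alive(task.render_device) > 0:
                    continue  # requested device is alive; let it take the task
                # requested device is gone: fail fast instead of stalling forever
                task.error = Exception(task.render_device + ' is not currently active.')
                return task
            if not task.render_device and my_device == 'cpu' and alive_count > 1:
                continue  # GPUs are alive; keep unpinned work off the CPU
            return task
        return None
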
@@ -237,7 +201,7 @@ def thread_get_next_task(): queued_task.error = Exception(queued_task.render_device + ' is not currently active.') task = queued_task break - if not queued_task.render_device and runtime.thread_data.device == 'cpu' and is_alive() > 1: + if not queued_task.render_device and runtime2.thread_data.device == 'cpu' and is_alive() > 1: # not asking for any specific devices, cpu want to grab task but other render devices are alive. continue # Skip Tasks, don't run on CPU unless there is nothing else or user asked for it. task = queued_task @@ -249,30 +213,31 @@ def thread_get_next_task(): manager_lock.release() def thread_render(device): - global current_state, current_state_error, current_model_path, current_vae_path, current_hypernetwork_path - from . import runtime + global current_state, current_state_error + + from sd_internal import runtime2 try: - runtime.thread_init(device) + runtime2.init(device) except Exception as e: print(traceback.format_exc()) weak_thread_data[threading.current_thread()] = { 'error': e } return + weak_thread_data[threading.current_thread()] = { - 'device': runtime.thread_data.device, - 'device_name': runtime.thread_data.device_name, + 'device': runtime2.thread_data.device, + 'device_name': runtime2.thread_data.device_name, 'alive': True } - if runtime.thread_data.device != 'cpu' or is_alive() == 1: - preload_model() - current_state = ServerStates.Online + + current_state = ServerStates.Online + while True: task_cache.clean() if not weak_thread_data[threading.current_thread()]['alive']: - print(f'Shutting down thread for device {runtime.thread_data.device}') - runtime.unload_models() - runtime.unload_filters() + print(f'Shutting down thread for device {runtime2.thread_data.device}') + runtime2.destroy() return if isinstance(current_state_error, SystemExit): current_state = ServerStates.Unavailable @@ -291,24 +256,17 @@ def thread_render(device): task.response = {"status": 'failed', "detail": str(task.error)} task.buffer_queue.put(json.dumps(task.response)) continue - print(f'Session {task.request.session_id} starting task {id(task)} on {runtime.thread_data.device_name}') + print(f'Session {task.request.session_id} starting task {id(task)} on {runtime2.thread_data.device_name}') if not task.lock.acquire(blocking=False): raise Exception('Got locked task from queue.') try: - if runtime.is_hypernetwork_reload_necessary(task.request): - runtime.reload_hypernetwork() - current_hypernetwork_path = task.request.use_hypernetwork_model - - if runtime.is_model_reload_necessary(task.request): - current_state = ServerStates.LoadingModel - runtime.reload_model() - current_model_path = task.request.use_stable_diffusion_model - current_vae_path = task.request.use_vae_model + current_state = ServerStates.LoadingModel + runtime2.reload_models(task.request) def step_callback(): global current_state_error if isinstance(current_state_error, SystemExit) or isinstance(current_state_error, StopAsyncIteration) or isinstance(task.error, StopAsyncIteration): - runtime.thread_data.stop_processing = True + runtime2.thread_data.stop_processing = True if isinstance(current_state_error, StopAsyncIteration): task.error = current_state_error current_state_error = None @@ -317,7 +275,7 @@ def thread_render(device): task_cache.keep(task.request.session_id, TASK_TTL) current_state = ServerStates.Rendering - task.response = runtime.mk_img(task.request, task.buffer_queue, task.temp_images, step_callback) + task.response = runtime2.make_image(task.request, task.buffer_queue, 
task.temp_images, step_callback) except Exception as e: task.error = e print(traceback.format_exc()) @@ -331,7 +289,7 @@ def thread_render(device): elif task.error is not None: print(f'Session {task.request.session_id} task {id(task)} failed!') else: - print(f'Session {task.request.session_id} task {id(task)} completed by {runtime.thread_data.device_name}.') + print(f'Session {task.request.session_id} task {id(task)} completed by {runtime2.thread_data.device_name}.') current_state = ServerStates.Online def get_cached_task(session_id:str, update_ttl:bool=False): @@ -493,8 +451,7 @@ def render(req : ImageRequest): if task and not task.response and not task.error and not task.lock.locked(): # Unstarted task pending, deny queueing more than one. raise ConnectionRefusedError(f'Session {req.session_id} has an already pending task.') - # - from . import runtime + r = Request() r.session_id = req.session_id r.prompt = req.prompt diff --git a/ui/server.py b/ui/server.py index 804994a2..7922d8f3 100644 --- a/ui/server.py +++ b/ui/server.py @@ -2,64 +2,21 @@ Notes: async endpoints always run on the main thread. Without they run on the thread pool. """ -import json -import traceback - -import sys import os -import socket -import picklescan.scanner -import rich - -SD_DIR = os.getcwd() -print('started in ', SD_DIR) - -SD_UI_DIR = os.getenv('SD_UI_PATH', None) -sys.path.append(os.path.dirname(SD_UI_DIR)) - -CONFIG_DIR = os.path.abspath(os.path.join(SD_UI_DIR, '..', 'scripts')) -MODELS_DIR = os.path.abspath(os.path.join(SD_DIR, '..', 'models')) - -USER_UI_PLUGINS_DIR = os.path.abspath(os.path.join(SD_DIR, '..', 'plugins', 'ui')) -CORE_UI_PLUGINS_DIR = os.path.abspath(os.path.join(SD_UI_DIR, 'plugins', 'ui')) -UI_PLUGINS_SOURCES = ((CORE_UI_PLUGINS_DIR, 'core'), (USER_UI_PLUGINS_DIR, 'user')) - -STABLE_DIFFUSION_MODEL_EXTENSIONS = ['.ckpt', '.safetensors'] -VAE_MODEL_EXTENSIONS = ['.vae.pt', '.ckpt'] -HYPERNETWORK_MODEL_EXTENSIONS = ['.pt'] - -OUTPUT_DIRNAME = "Stable Diffusion UI" # in the user's home folder -TASK_TTL = 15 * 60 # Discard last session's task timeout -APP_CONFIG_DEFAULTS = { - # auto: selects the cuda device with the most free memory, cuda: use the currently active cuda device. - 'render_devices': 'auto', # valid entries: 'auto', 'cpu' or 'cuda:N' (where N is a GPU index) - 'update_branch': 'main', - 'ui': { - 'open_browser_on_start': True, - }, -} -APP_CONFIG_DEFAULT_MODELS = [ - # needed to support the legacy installations - 'custom-model', # Check if user has a custom model, use it first. - 'sd-v1-4', # Default fallback. 
-] +import traceback +import logging +from typing import List, Union from fastapi import FastAPI, HTTPException from fastapi.staticfiles import StaticFiles from starlette.responses import FileResponse, JSONResponse, StreamingResponse from pydantic import BaseModel -import logging -#import queue, threading, time -from typing import Any, Generator, Hashable, List, Optional, Union -from sd_internal import Request, Response, task_manager +from sd_internal import app, model_manager, task_manager -app = FastAPI() +print('started in ', app.SD_DIR) -modifiers_cache = None -outpath = os.path.join(os.path.expanduser("~"), OUTPUT_DIRNAME) - -os.makedirs(USER_UI_PLUGINS_DIR, exist_ok=True) +server_api = FastAPI() # don't show access log entries for URLs that start with the given prefix ACCESS_LOG_SUPPRESS_PATH_PREFIXES = ['/ping', '/image', '/modifier-thumbnails'] @@ -74,132 +31,6 @@ class NoCacheStaticFiles(StaticFiles): return super().is_not_modified(response_headers, request_headers) -app.mount('/media', NoCacheStaticFiles(directory=os.path.join(SD_UI_DIR, 'media')), name="media") - -for plugins_dir, dir_prefix in UI_PLUGINS_SOURCES: - app.mount(f'/plugins/{dir_prefix}', NoCacheStaticFiles(directory=plugins_dir), name=f"plugins-{dir_prefix}") - -def getConfig(default_val=APP_CONFIG_DEFAULTS): - try: - config_json_path = os.path.join(CONFIG_DIR, 'config.json') - if not os.path.exists(config_json_path): - return default_val - with open(config_json_path, 'r', encoding='utf-8') as f: - config = json.load(f) - if 'net' not in config: - config['net'] = {} - if os.getenv('SD_UI_BIND_PORT') is not None: - config['net']['listen_port'] = int(os.getenv('SD_UI_BIND_PORT')) - if os.getenv('SD_UI_BIND_IP') is not None: - config['net']['listen_to_network'] = ( os.getenv('SD_UI_BIND_IP') == '0.0.0.0' ) - return config - except Exception as e: - print(str(e)) - print(traceback.format_exc()) - return default_val - -def setConfig(config): - print( json.dumps(config) ) - try: # config.json - config_json_path = os.path.join(CONFIG_DIR, 'config.json') - with open(config_json_path, 'w', encoding='utf-8') as f: - json.dump(config, f) - except: - print(traceback.format_exc()) - - try: # config.bat - config_bat_path = os.path.join(CONFIG_DIR, 'config.bat') - config_bat = [] - - if 'update_branch' in config: - config_bat.append(f"@set update_branch={config['update_branch']}") - - config_bat.append(f"@set SD_UI_BIND_PORT={config['net']['listen_port']}") - bind_ip = '0.0.0.0' if config['net']['listen_to_network'] else '127.0.0.1' - config_bat.append(f"@set SD_UI_BIND_IP={bind_ip}") - - config_bat.append(f"@set test_sd2={'Y' if config.get('test_sd2', False) else 'N'}") - - if len(config_bat) > 0: - with open(config_bat_path, 'w', encoding='utf-8') as f: - f.write('\r\n'.join(config_bat)) - except: - print(traceback.format_exc()) - - try: # config.sh - config_sh_path = os.path.join(CONFIG_DIR, 'config.sh') - config_sh = ['#!/bin/bash'] - - if 'update_branch' in config: - config_sh.append(f"export update_branch={config['update_branch']}") - - config_sh.append(f"export SD_UI_BIND_PORT={config['net']['listen_port']}") - bind_ip = '0.0.0.0' if config['net']['listen_to_network'] else '127.0.0.1' - config_sh.append(f"export SD_UI_BIND_IP={bind_ip}") - - config_sh.append(f"export test_sd2=\"{'Y' if config.get('test_sd2', False) else 'N'}\"") - - if len(config_sh) > 1: - with open(config_sh_path, 'w', encoding='utf-8') as f: - f.write('\n'.join(config_sh)) - except: - print(traceback.format_exc()) - -def 
resolve_model_to_use(model_name:str, model_type:str, model_dir:str, model_extensions:list, default_models=[]): - config = getConfig() - - model_dirs = [os.path.join(MODELS_DIR, model_dir), SD_DIR] - if not model_name: # When None try user configured model. - # config = getConfig() - if 'model' in config and model_type in config['model']: - model_name = config['model'][model_type] - if model_name: - is_sd2 = config.get('test_sd2', False) - if model_name.startswith('sd2_') and not is_sd2: # temp hack, until SD2 is unified with 1.4 - print('ERROR: Cannot use SD 2.0 models with SD 1.0 code. Using the sd-v1-4 model instead!') - model_name = 'sd-v1-4' - - # Check models directory - models_dir_path = os.path.join(MODELS_DIR, model_dir, model_name) - for model_extension in model_extensions: - if os.path.exists(models_dir_path + model_extension): - return models_dir_path - if os.path.exists(model_name + model_extension): - # Direct Path to file - model_name = os.path.abspath(model_name) - return model_name - # Default locations - if model_name in default_models: - default_model_path = os.path.join(SD_DIR, model_name) - for model_extension in model_extensions: - if os.path.exists(default_model_path + model_extension): - return default_model_path - # Can't find requested model, check the default paths. - for default_model in default_models: - for model_dir in model_dirs: - default_model_path = os.path.join(model_dir, default_model) - for model_extension in model_extensions: - if os.path.exists(default_model_path + model_extension): - if model_name is not None: - print(f'Could not find the configured custom model {model_name}{model_extension}. Using the default one: {default_model_path}{model_extension}') - return default_model_path - raise Exception('No valid models found.') - -def resolve_ckpt_to_use(model_name:str=None): - return resolve_model_to_use(model_name, model_type='stable-diffusion', model_dir='stable-diffusion', model_extensions=STABLE_DIFFUSION_MODEL_EXTENSIONS, default_models=APP_CONFIG_DEFAULT_MODELS) - -def resolve_vae_to_use(model_name:str=None): - try: - return resolve_model_to_use(model_name, model_type='vae', model_dir='vae', model_extensions=VAE_MODEL_EXTENSIONS, default_models=[]) - except: - return None - -def resolve_hypernetwork_to_use(model_name:str=None): - try: - return resolve_model_to_use(model_name, model_type='hypernetwork', model_dir='hypernetwork', model_extensions=HYPERNETWORK_MODEL_EXTENSIONS, default_models=[]) - except: - return None - class SetAppConfigRequest(BaseModel): update_branch: str = None render_devices: Union[List[str], List[int], str, int] = None @@ -209,9 +40,25 @@ class SetAppConfigRequest(BaseModel): listen_port: int = None test_sd2: bool = None -@app.post('/app_config') +class LogSuppressFilter(logging.Filter): + def filter(self, record: logging.LogRecord) -> bool: + path = record.getMessage() + for prefix in ACCESS_LOG_SUPPRESS_PATH_PREFIXES: + if path.find(prefix) != -1: + return False + return True + +# don't log certain requests +logging.getLogger('uvicorn.access').addFilter(LogSuppressFilter()) + +server_api.mount('/media', NoCacheStaticFiles(directory=os.path.join(app.SD_UI_DIR, 'media')), name="media") + +for plugins_dir, dir_prefix in app.UI_PLUGINS_SOURCES: + app.mount(f'/plugins/{dir_prefix}', NoCacheStaticFiles(directory=plugins_dir), name=f"plugins-{dir_prefix}") + +@server_api.post('/app_config') async def setAppConfig(req : SetAppConfigRequest): - config = getConfig() + config = app.getConfig() if req.update_branch is not None: 
config['update_branch'] = req.update_branch if req.render_devices is not None: @@ -231,121 +78,48 @@ async def setAppConfig(req : SetAppConfigRequest): if req.test_sd2 is not None: config['test_sd2'] = req.test_sd2 try: - setConfig(config) + app.setConfig(config) if req.render_devices: - update_render_threads() + app.update_render_threads() return JSONResponse({'status': 'OK'}, headers=NOCACHE_HEADERS) except Exception as e: print(traceback.format_exc()) raise HTTPException(status_code=500, detail=str(e)) -def is_malicious_model(file_path): - try: - scan_result = picklescan.scanner.scan_file_path(file_path) - if scan_result.issues_count > 0 or scan_result.infected_files > 0: - rich.print(":warning: [bold red]Scan %s: %d scanned, %d issue, %d infected.[/bold red]" % (file_path, scan_result.scanned_files, scan_result.issues_count, scan_result.infected_files)) - return True - else: - rich.print("Scan %s: [green]%d scanned, %d issue, %d infected.[/green]" % (file_path, scan_result.scanned_files, scan_result.issues_count, scan_result.infected_files)) - return False - except Exception as e: - print('error while scanning', file_path, 'error:', e) - return False +def update_render_devices_in_config(config, render_devices): + if render_devices not in ('cpu', 'auto') and not render_devices.startswith('cuda:'): + raise HTTPException(status_code=400, detail=f'Invalid render device requested: {render_devices}') -known_models = {} -def getModels(): - models = { - 'active': { - 'stable-diffusion': 'sd-v1-4', - 'vae': '', - 'hypernetwork': '', - }, - 'options': { - 'stable-diffusion': ['sd-v1-4'], - 'vae': [], - 'hypernetwork': [], - }, - } + if render_devices.startswith('cuda:'): + render_devices = render_devices.split(',') - def listModels(models_dirname, model_type, model_extensions): - models_dir = os.path.join(MODELS_DIR, models_dirname) - if not os.path.exists(models_dir): - os.makedirs(models_dir) + config['render_devices'] = render_devices - for file in os.listdir(models_dir): - for model_extension in model_extensions: - if not file.endswith(model_extension): - continue - - model_path = os.path.join(models_dir, file) - mtime = os.path.getmtime(model_path) - mod_time = known_models[model_path] if model_path in known_models else -1 - if mod_time != mtime: - if is_malicious_model(model_path): - models['scan-error'] = file - return - known_models[model_path] = mtime - - model_name = file[:-len(model_extension)] - models['options'][model_type].append(model_name) - - models['options'][model_type] = [*set(models['options'][model_type])] # remove duplicates - models['options'][model_type].sort() - - # custom models - listModels(models_dirname='stable-diffusion', model_type='stable-diffusion', model_extensions=STABLE_DIFFUSION_MODEL_EXTENSIONS) - listModels(models_dirname='vae', model_type='vae', model_extensions=VAE_MODEL_EXTENSIONS) - listModels(models_dirname='hypernetwork', model_type='hypernetwork', model_extensions=HYPERNETWORK_MODEL_EXTENSIONS) - # legacy - custom_weight_path = os.path.join(SD_DIR, 'custom-model.ckpt') - if os.path.exists(custom_weight_path): - models['options']['stable-diffusion'].append('custom-model') - - return models - -def getUIPlugins(): - plugins = [] - - for plugins_dir, dir_prefix in UI_PLUGINS_SOURCES: - for file in os.listdir(plugins_dir): - if file.endswith('.plugin.js'): - plugins.append(f'/plugins/{dir_prefix}/{file}') - - return plugins - -def getIPConfig(): - ips = socket.gethostbyname_ex(socket.gethostname()) - ips[2].append(ips[0]) - return ips[2] - 
-@app.get('/get/{key:path}') +@server_api.get('/get/{key:path}') def read_web_data(key:str=None): if not key: # /get without parameters, stable-diffusion easter egg. raise HTTPException(status_code=418, detail="StableDiffusion is drawing a teapot!") # HTTP418 I'm a teapot elif key == 'app_config': - config = getConfig(default_val=None) - if config is None: - config = APP_CONFIG_DEFAULTS - return JSONResponse(config, headers=NOCACHE_HEADERS) + return JSONResponse(app.getConfig(), headers=NOCACHE_HEADERS) elif key == 'system_info': - config = getConfig() + config = app.getConfig() system_info = { 'devices': task_manager.get_devices(), - 'hosts': getIPConfig(), + 'hosts': app.getIPConfig(), + 'default_output_dir': os.path.join(os.path.expanduser("~"), app.OUTPUT_DIRNAME), } system_info['devices']['config'] = config.get('render_devices', "auto") return JSONResponse(system_info, headers=NOCACHE_HEADERS) elif key == 'models': - return JSONResponse(getModels(), headers=NOCACHE_HEADERS) - elif key == 'modifiers': return FileResponse(os.path.join(SD_UI_DIR, 'modifiers.json'), headers=NOCACHE_HEADERS) - elif key == 'output_dir': return JSONResponse({ 'output_dir': outpath }, headers=NOCACHE_HEADERS) - elif key == 'ui_plugins': return JSONResponse(getUIPlugins(), headers=NOCACHE_HEADERS) + return JSONResponse(model_manager.getModels(), headers=NOCACHE_HEADERS) + elif key == 'modifiers': return FileResponse(os.path.join(app.SD_UI_DIR, 'modifiers.json'), headers=NOCACHE_HEADERS) + elif key == 'ui_plugins': return JSONResponse(app.getUIPlugins(), headers=NOCACHE_HEADERS) else: raise HTTPException(status_code=404, detail=f'Request for unknown {key}') # HTTP404 Not Found -@app.get('/ping') # Get server and optionally session status. +@server_api.get('/ping') # Get server and optionally session status. def ping(session_id:str=None): if task_manager.is_alive() <= 0: # Check that render threads are alive. 
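        # is_alive() with no argument counts live render threads across all
        # devices; zero means no device thread survived startup.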
if task_manager.current_state_error: raise HTTPException(status_code=500, detail=str(task_manager.current_state_error)) @@ -372,38 +146,14 @@ def ping(session_id:str=None): response['devices'] = task_manager.get_devices() return JSONResponse(response, headers=NOCACHE_HEADERS) -def save_model_to_config(ckpt_model_name, vae_model_name, hypernetwork_model_name): - config = getConfig() - if 'model' not in config: - config['model'] = {} - - config['model']['stable-diffusion'] = ckpt_model_name - config['model']['vae'] = vae_model_name - config['model']['hypernetwork'] = hypernetwork_model_name - - if vae_model_name is None or vae_model_name == "": - del config['model']['vae'] - if hypernetwork_model_name is None or hypernetwork_model_name == "": - del config['model']['hypernetwork'] - - setConfig(config) - -def update_render_devices_in_config(config, render_devices): - if render_devices not in ('cpu', 'auto') and not render_devices.startswith('cuda:'): - raise HTTPException(status_code=400, detail=f'Invalid render device requested: {render_devices}') - - if render_devices.startswith('cuda:'): - render_devices = render_devices.split(',') - - config['render_devices'] = render_devices - -@app.post('/render') +@server_api.post('/render') def render(req : task_manager.ImageRequest): try: - save_model_to_config(req.use_stable_diffusion_model, req.use_vae_model, req.use_hypernetwork_model) - req.use_stable_diffusion_model = resolve_ckpt_to_use(req.use_stable_diffusion_model) - req.use_vae_model = resolve_vae_to_use(req.use_vae_model) - req.use_hypernetwork_model = resolve_hypernetwork_to_use(req.use_hypernetwork_model) + app.save_model_to_config(req.use_stable_diffusion_model, req.use_vae_model, req.use_hypernetwork_model) + req.use_stable_diffusion_model = model_manager.resolve_ckpt_to_use(req.use_stable_diffusion_model) + req.use_vae_model = model_manager.resolve_vae_to_use(req.use_vae_model) + req.use_hypernetwork_model = model_manager.resolve_hypernetwork_to_use(req.use_hypernetwork_model) + new_task = task_manager.render(req) response = { 'status': str(task_manager.current_state), @@ -419,7 +169,7 @@ def render(req : task_manager.ImageRequest): except Exception as e: raise HTTPException(status_code=500, detail=str(e)) -@app.get('/image/stream/{session_id:str}/{task_id:int}') +@server_api.get('/image/stream/{session_id:str}/{task_id:int}') def stream(session_id:str, task_id:int): #TODO Move to WebSockets ?? 
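    # Streams partial results: drains task.buffer_queue as incremental JSON
    # payloads through a StreamingResponse, so the UI can follow progress
    # without holding a WebSocket open.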
task = task_manager.get_cached_task(session_id, update_ttl=True) @@ -433,7 +183,7 @@ def stream(session_id:str, task_id:int): #print(f'Session {session_id} opened live render stream {id(task.buffer_queue)}') return StreamingResponse(task.read_buffer_generator(), media_type='application/json') -@app.get('/image/stop') +@server_api.get('/image/stop') def stop(session_id:str=None): if not session_id: if task_manager.current_state == task_manager.ServerStates.Online or task_manager.current_state == task_manager.ServerStates.Unavailable: @@ -446,7 +196,7 @@ def stop(session_id:str=None): task.error = StopAsyncIteration('') return {'OK'} -@app.get('/image/tmp/{session_id}/{img_id:int}') +@server_api.get('/image/tmp/{session_id}/{img_id:int}') def get_image(session_id, img_id): task = task_manager.get_cached_task(session_id, update_ttl=True) if not task: raise HTTPException(status_code=410, detail=f'Session {session_id} has not submitted a task.') # HTTP410 Gone @@ -458,49 +208,17 @@ def get_image(session_id, img_id): except KeyError as e: raise HTTPException(status_code=500, detail=str(e)) -@app.get('/') +@server_api.get('/') def read_root(): - return FileResponse(os.path.join(SD_UI_DIR, 'index.html'), headers=NOCACHE_HEADERS) + return FileResponse(os.path.join(app.SD_UI_DIR, 'index.html'), headers=NOCACHE_HEADERS) -@app.on_event("shutdown") +@server_api.on_event("shutdown") def shutdown_event(): # Signal render thread to close on shutdown task_manager.current_state_error = SystemExit('Application shutting down.') -# don't log certain requests -class LogSuppressFilter(logging.Filter): - def filter(self, record: logging.LogRecord) -> bool: - path = record.getMessage() - for prefix in ACCESS_LOG_SUPPRESS_PATH_PREFIXES: - if path.find(prefix) != -1: - return False - return True -logging.getLogger('uvicorn.access').addFilter(LogSuppressFilter()) - -# Check models and prepare cache for UI open -getModels() - -# Start the task_manager -task_manager.default_model_to_load = resolve_ckpt_to_use() -task_manager.default_vae_to_load = resolve_vae_to_use() -task_manager.default_hypernetwork_to_load = resolve_hypernetwork_to_use() - -def update_render_threads(): - config = getConfig() - render_devices = config.get('render_devices', 'auto') - active_devices = task_manager.get_devices()['active'].keys() - - print('requesting for render_devices', render_devices) - task_manager.update_render_threads(render_devices, active_devices) - -update_render_threads() +# Init the app +model_manager.init() +app.init() # start the browser ui -def open_browser(): - config = getConfig() - ui = config.get('ui', {}) - net = config.get('net', {'listen_port':9000}) - port = net.get('listen_port', 9000) - if ui.get('open_browser_on_start', True): - import webbrowser; webbrowser.open(f"http://localhost:{port}") - -open_browser() +app.open_browser() From bad89160cc224916fd4b367791d3a0442cc105ec Mon Sep 17 00:00:00 2001 From: cmdr2 Date: Thu, 8 Dec 2022 13:50:46 +0530 Subject: [PATCH 02/74] Work-in-progress model loading --- ui/sd_internal/app.py | 4 ---- ui/sd_internal/model_manager.py | 25 ++++++++++++------------- ui/sd_internal/runtime2.py | 29 +++++++++++++++++++++++++---- ui/sd_internal/task_manager.py | 6 +++--- 4 files changed, 40 insertions(+), 24 deletions(-) diff --git a/ui/sd_internal/app.py b/ui/sd_internal/app.py index 09242319..00d2e718 100644 --- a/ui/sd_internal/app.py +++ b/ui/sd_internal/app.py @@ -18,10 +18,6 @@ USER_UI_PLUGINS_DIR = os.path.abspath(os.path.join(SD_DIR, '..', 'plugins', 'ui' CORE_UI_PLUGINS_DIR = 
os.path.abspath(os.path.join(SD_UI_DIR, 'plugins', 'ui')) UI_PLUGINS_SOURCES = ((CORE_UI_PLUGINS_DIR, 'core'), (USER_UI_PLUGINS_DIR, 'user')) -STABLE_DIFFUSION_MODEL_EXTENSIONS = ['.ckpt', '.safetensors'] -VAE_MODEL_EXTENSIONS = ['.vae.pt', '.ckpt'] -HYPERNETWORK_MODEL_EXTENSIONS = ['.pt'] - OUTPUT_DIRNAME = "Stable Diffusion UI" # in the user's home folder TASK_TTL = 15 * 60 # Discard last session's task timeout APP_CONFIG_DEFAULTS = { diff --git a/ui/sd_internal/model_manager.py b/ui/sd_internal/model_manager.py index 3941b130..906038e1 100644 --- a/ui/sd_internal/model_manager.py +++ b/ui/sd_internal/model_manager.py @@ -4,6 +4,10 @@ from sd_internal import app import picklescan.scanner import rich +STABLE_DIFFUSION_MODEL_EXTENSIONS = ['.ckpt', '.safetensors'] +VAE_MODEL_EXTENSIONS = ['.vae.pt', '.ckpt'] +HYPERNETWORK_MODEL_EXTENSIONS = ['.pt'] + default_model_to_load = None default_vae_to_load = None default_hypernetwork_to_load = None @@ -59,22 +63,17 @@ def resolve_model_to_use(model_name:str, model_type:str, model_dir:str, model_ex print(f'Could not find the configured custom model {model_name}{model_extension}. Using the default one: {default_model_path}{model_extension}') return default_model_path + model_extension - raise Exception('No valid models found.') + print(f'No valid models found for model_name: {model_name}') + return None def resolve_ckpt_to_use(model_name:str=None): - return resolve_model_to_use(model_name, model_type='stable-diffusion', model_dir='stable-diffusion', model_extensions=app.STABLE_DIFFUSION_MODEL_EXTENSIONS, default_models=app.APP_CONFIG_DEFAULT_MODELS) + return resolve_model_to_use(model_name, model_type='stable-diffusion', model_dir='stable-diffusion', model_extensions=STABLE_DIFFUSION_MODEL_EXTENSIONS, default_models=app.APP_CONFIG_DEFAULT_MODELS) def resolve_vae_to_use(model_name:str=None): - try: - return resolve_model_to_use(model_name, model_type='vae', model_dir='vae', model_extensions=app.VAE_MODEL_EXTENSIONS, default_models=[]) - except: - return None + return resolve_model_to_use(model_name, model_type='vae', model_dir='vae', model_extensions=VAE_MODEL_EXTENSIONS, default_models=[]) def resolve_hypernetwork_to_use(model_name:str=None): - try: - return resolve_model_to_use(model_name, model_type='hypernetwork', model_dir='hypernetwork', model_extensions=app.HYPERNETWORK_MODEL_EXTENSIONS, default_models=[]) - except: - return None + return resolve_model_to_use(model_name, model_type='hypernetwork', model_dir='hypernetwork', model_extensions=HYPERNETWORK_MODEL_EXTENSIONS, default_models=[]) def is_malicious_model(file_path): try: @@ -129,9 +128,9 @@ def getModels(): models['options'][model_type].sort() # custom models - listModels(models_dirname='stable-diffusion', model_type='stable-diffusion', model_extensions=app.STABLE_DIFFUSION_MODEL_EXTENSIONS) - listModels(models_dirname='vae', model_type='vae', model_extensions=app.VAE_MODEL_EXTENSIONS) - listModels(models_dirname='hypernetwork', model_type='hypernetwork', model_extensions=app.HYPERNETWORK_MODEL_EXTENSIONS) + listModels(models_dirname='stable-diffusion', model_type='stable-diffusion', model_extensions=STABLE_DIFFUSION_MODEL_EXTENSIONS) + listModels(models_dirname='vae', model_type='vae', model_extensions=VAE_MODEL_EXTENSIONS) + listModels(models_dirname='hypernetwork', model_type='hypernetwork', model_extensions=HYPERNETWORK_MODEL_EXTENSIONS) # legacy custom_weight_path = os.path.join(app.SD_DIR, 'custom-model.ckpt') diff --git a/ui/sd_internal/runtime2.py b/ui/sd_internal/runtime2.py 
index 32befcde..fc8d944d 100644 --- a/ui/sd_internal/runtime2.py +++ b/ui/sd_internal/runtime2.py @@ -1,7 +1,8 @@ import threading import queue -from sd_internal import device_manager, Request, Response, Image as ResponseImage +from sd_internal import device_manager, model_manager +from sd_internal import Request, Response, Image as ResponseImage from modules import model_loader, image_generator, image_utils @@ -18,7 +19,7 @@ def init(device): thread_data.temp_images = {} thread_data.models = {} - thread_data.loaded_model_paths = {} + thread_data.model_paths = {} thread_data.device = None thread_data.device_name = None @@ -27,14 +28,34 @@ def init(device): device_manager.device_init(thread_data, device) - reload_models() + load_default_models() def destroy(): model_loader.unload_sd_model(thread_data) model_loader.unload_gfpgan_model(thread_data) model_loader.unload_realesrgan_model(thread_data) -def reload_models(req: Request=None): +def load_default_models(): + thread_data.model_paths['stable-diffusion'] = model_manager.default_model_to_load + thread_data.model_paths['vae'] = model_manager.default_vae_to_load + + model_loader.load_sd_model(thread_data) + +def reload_models_if_necessary(req: Request=None): + needs_model_reload = False + if 'stable-diffusion' not in thread_data.models or thread_data.ckpt_file != req.use_stable_diffusion_model or thread_data.vae_file != req.use_vae_model: + thread_data.ckpt_file = req.use_stable_diffusion_model + thread_data.vae_file = req.use_vae_model + needs_model_reload = True + + if thread_data.device != 'cpu': + if (thread_data.precision == 'autocast' and (req.use_full_precision or not thread_data.model_is_half)) or \ + (thread_data.precision == 'full' and not req.use_full_precision and not thread_data.force_full_precision): + thread_data.precision = 'full' if req.use_full_precision else 'autocast' + needs_model_reload = True + + return needs_model_reload + if is_hypernetwork_reload_necessary(task.request): current_state = ServerStates.LoadingModel runtime.reload_hypernetwork() diff --git a/ui/sd_internal/task_manager.py b/ui/sd_internal/task_manager.py index 7a58ac14..04cc9a69 100644 --- a/ui/sd_internal/task_manager.py +++ b/ui/sd_internal/task_manager.py @@ -310,9 +310,6 @@ def thread_render(device): print(f'Session {task.request.session_id} starting task {id(task)} on {runtime2.thread_data.device_name}') if not task.lock.acquire(blocking=False): raise Exception('Got locked task from queue.') try: - current_state = ServerStates.LoadingModel - runtime2.reload_models(task.request) - def step_callback(): global current_state_error @@ -323,6 +320,9 @@ def thread_render(device): current_state_error = None print(f'Session {task.request.session_id} sent cancel signal for task {id(task)}') + current_state = ServerStates.LoadingModel + runtime2.reload_models_if_necessary(task.request) + current_state = ServerStates.Rendering task.response = runtime2.make_image(task.request, task.buffer_queue, task.temp_images, step_callback) # Before looping back to the generator, mark cache as still alive. From f4a6910ab45181d1f530bdc7a2c9d6e52c6cc90b Mon Sep 17 00:00:00 2001 From: cmdr2 Date: Thu, 8 Dec 2022 21:39:09 +0530 Subject: [PATCH 03/74] Work-in-progress: refactored the end-to-end codebase. Missing: hypernetworks, turbo config, and SD 2. 
Not tested yet --- scripts/on_sd_start.bat | 5 - scripts/on_sd_start.sh | 5 - ui/index.html | 1 - ui/media/js/parameters.js | 19 +-- ui/sd_internal/__init__.py | 7 ++ ui/sd_internal/app.py | 5 - ui/sd_internal/device_manager.py | 10 +- ui/sd_internal/model_manager.py | 98 ++++++++++----- ui/sd_internal/runtime2.py | 209 ++++++++++++++++++++++--------- ui/sd_internal/task_manager.py | 2 +- ui/server.py | 6 +- 11 files changed, 239 insertions(+), 128 deletions(-) diff --git a/scripts/on_sd_start.bat b/scripts/on_sd_start.bat index 2bcf56d2..d8981226 100644 --- a/scripts/on_sd_start.bat +++ b/scripts/on_sd_start.bat @@ -199,12 +199,7 @@ call WHERE uvicorn > .tmp -if not exist "..\models\stable-diffusion" mkdir "..\models\stable-diffusion" if not exist "..\models\vae" mkdir "..\models\vae" -if not exist "..\models\hypernetwork" mkdir "..\models\hypernetwork" -echo. > "..\models\stable-diffusion\Put your custom ckpt files here.txt" -echo. > "..\models\vae\Put your VAE files here.txt" -echo. > "..\models\hypernetwork\Put your hypernetwork files here.txt" @if exist "sd-v1-4.ckpt" ( for %%I in ("sd-v1-4.ckpt") do if "%%~zI" EQU "4265380512" ( diff --git a/scripts/on_sd_start.sh b/scripts/on_sd_start.sh index 8682c5cc..3e1c9c58 100755 --- a/scripts/on_sd_start.sh +++ b/scripts/on_sd_start.sh @@ -159,12 +159,7 @@ fi -mkdir -p "../models/stable-diffusion" mkdir -p "../models/vae" -mkdir -p "../models/hypernetwork" -echo "" > "../models/stable-diffusion/Put your custom ckpt files here.txt" -echo "" > "../models/vae/Put your VAE files here.txt" -echo "" > "../models/hypernetwork/Put your hypernetwork files here.txt" if [ -f "sd-v1-4.ckpt" ]; then model_size=`find "sd-v1-4.ckpt" -printf "%s"` diff --git a/ui/index.html b/ui/index.html index d3fb6da3..0094201b 100644 --- a/ui/index.html +++ b/ui/index.html @@ -409,7 +409,6 @@ async function init() { await initSettings() await getModels() - await getDiskPath() await getAppConfig() await loadUIPlugins() await loadModifiers() diff --git a/ui/media/js/parameters.js b/ui/media/js/parameters.js index 5c522d7e..3643d1ec 100644 --- a/ui/media/js/parameters.js +++ b/ui/media/js/parameters.js @@ -327,20 +327,10 @@ autoPickGPUsField.addEventListener('click', function() { gpuSettingEntry.style.display = (this.checked ? 
'none' : '') }) -async function getDiskPath() { - try { - var diskPath = getSetting("diskPath") - if (diskPath == '' || diskPath == undefined || diskPath == "undefined") { - let res = await fetch('/get/output_dir') - if (res.status === 200) { - res = await res.json() - res = res.output_dir - - setSetting("diskPath", res) - } - } - } catch (e) { - console.log('error fetching output dir path', e) +async function setDiskPath(defaultDiskPath) { + var diskPath = getSetting("diskPath") + if (diskPath == '' || diskPath == undefined || diskPath == "undefined") { + setSetting("diskPath", defaultDiskPath) } } @@ -415,6 +405,7 @@ async function getSystemInfo() { setDeviceInfo(devices) setHostInfo(res['hosts']) + setDiskPath(res['default_output_dir']) } catch (e) { console.log('error fetching devices', e) } diff --git a/ui/sd_internal/__init__.py b/ui/sd_internal/__init__.py index bf16bb5e..b001d3f9 100644 --- a/ui/sd_internal/__init__.py +++ b/ui/sd_internal/__init__.py @@ -105,6 +105,10 @@ class Response: request: Request images: list + def __init__(self, request: Request, images: list): + self.request = request + self.images = images + def json(self): res = { "status": 'succeeded', @@ -116,3 +120,6 @@ class Response: res["output"].append(image.json()) return res + +class UserInitiatedStop(Exception): + pass diff --git a/ui/sd_internal/app.py b/ui/sd_internal/app.py index 00d2e718..d1cec46f 100644 --- a/ui/sd_internal/app.py +++ b/ui/sd_internal/app.py @@ -28,11 +28,6 @@ APP_CONFIG_DEFAULTS = { 'open_browser_on_start': True, }, } -DEFAULT_MODELS = [ - # needed to support the legacy installations - 'custom-model', # Check if user has a custom model, use it first. - 'sd-v1-4', # Default fallback. -] def init(): os.makedirs(USER_UI_PLUGINS_DIR, exist_ok=True) diff --git a/ui/sd_internal/device_manager.py b/ui/sd_internal/device_manager.py index 3c1d8bf1..4de6b265 100644 --- a/ui/sd_internal/device_manager.py +++ b/ui/sd_internal/device_manager.py @@ -101,10 +101,8 @@ def device_init(context, device): context.device = device # Force full precision on 1660 and 1650 NVIDIA cards to avoid creating green images - device_name = context.device_name.lower() - force_full_precision = (('nvidia' in device_name or 'geforce' in device_name) and (' 1660' in device_name or ' 1650' in device_name)) or ('Quadro T2000' in device_name) - if force_full_precision: - print('forcing full precision on NVIDIA 16xx cards, to avoid green images. GPU detected: ', context.device_name) + if needs_to_force_full_precision(context.device_name): + print(f'forcing full precision on this GPU, to avoid green images. GPU detected: {context.device_name}') # Apply force_full_precision now before models are loaded. 
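            # ('full' means fp32 throughout; the default 'autocast' path uses
            # fp16, which these cards are known to corrupt into green or
            # black output.)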
context.precision = 'full' @@ -113,6 +111,10 @@ def device_init(context, device): return +def needs_to_force_full_precision(context): + device_name = context.device_name.lower() + return (('nvidia' in device_name or 'geforce' in device_name) and (' 1660' in device_name or ' 1650' in device_name)) or ('Quadro T2000' in device_name) + def validate_device_id(device, log_prefix=''): def is_valid(): if not isinstance(device, str): diff --git a/ui/sd_internal/model_manager.py b/ui/sd_internal/model_manager.py index 906038e1..a8b249a2 100644 --- a/ui/sd_internal/model_manager.py +++ b/ui/sd_internal/model_manager.py @@ -1,32 +1,39 @@ import os -from sd_internal import app +from sd_internal import app, device_manager +from sd_internal import Request import picklescan.scanner import rich -STABLE_DIFFUSION_MODEL_EXTENSIONS = ['.ckpt', '.safetensors'] -VAE_MODEL_EXTENSIONS = ['.vae.pt', '.ckpt'] -HYPERNETWORK_MODEL_EXTENSIONS = ['.pt'] - -default_model_to_load = None -default_vae_to_load = None -default_hypernetwork_to_load = None +KNOWN_MODEL_TYPES = ['stable-diffusion', 'vae', 'hypernetwork', 'gfpgan', 'realesrgan'] +MODEL_EXTENSIONS = { + 'stable-diffusion': ['.ckpt', '.safetensors'], + 'vae': ['.vae.pt', '.ckpt'], + 'hypernetwork': ['.pt'], + 'gfpgan': ['.pth'], + 'realesrgan': ['.pth'], +} +DEFAULT_MODELS = { + 'stable-diffusion': [ # needed to support the legacy installations + 'custom-model', # only one custom model file was supported initially, creatively named 'custom-model' + 'sd-v1-4', # Default fallback. + ], + 'gfpgan': ['GFPGANv1.3'], + 'realesrgan': ['RealESRGAN_x4plus'], +} known_models = {} def init(): - global default_model_to_load, default_vae_to_load, default_hypernetwork_to_load - - default_model_to_load = resolve_ckpt_to_use() - default_vae_to_load = resolve_vae_to_use() - default_hypernetwork_to_load = resolve_hypernetwork_to_use() - + make_model_folders() getModels() # run this once, to cache the picklescan results -def resolve_model_to_use(model_name:str, model_type:str, model_dir:str, model_extensions:list, default_models=[]): +def resolve_model_to_use(model_name:str, model_type:str): + model_extensions = MODEL_EXTENSIONS.get(model_type, []) + default_models = DEFAULT_MODELS.get(model_type, []) config = app.getConfig() - model_dirs = [os.path.join(app.MODELS_DIR, model_dir), app.SD_DIR] + model_dirs = [os.path.join(app.MODELS_DIR, model_type), app.SD_DIR] if not model_name: # When None try user configured model. 
# config = getConfig() if 'model' in config and model_type in config['model']: @@ -39,7 +46,7 @@ def resolve_model_to_use(model_name:str, model_type:str, model_dir:str, model_ex model_name = 'sd-v1-4' # Check models directory - models_dir_path = os.path.join(app.MODELS_DIR, model_dir, model_name) + models_dir_path = os.path.join(app.MODELS_DIR, model_type, model_name) for model_extension in model_extensions: if os.path.exists(models_dir_path + model_extension): return models_dir_path + model_extension @@ -66,14 +73,32 @@ def resolve_model_to_use(model_name:str, model_type:str, model_dir:str, model_ex print(f'No valid models found for model_name: {model_name}') return None -def resolve_ckpt_to_use(model_name:str=None): - return resolve_model_to_use(model_name, model_type='stable-diffusion', model_dir='stable-diffusion', model_extensions=STABLE_DIFFUSION_MODEL_EXTENSIONS, default_models=app.APP_CONFIG_DEFAULT_MODELS) +def resolve_sd_model_to_use(model_name:str=None): + return resolve_model_to_use(model_name, model_type='stable-diffusion') -def resolve_vae_to_use(model_name:str=None): - return resolve_model_to_use(model_name, model_type='vae', model_dir='vae', model_extensions=VAE_MODEL_EXTENSIONS, default_models=[]) +def resolve_vae_model_to_use(model_name:str=None): + return resolve_model_to_use(model_name, model_type='vae') -def resolve_hypernetwork_to_use(model_name:str=None): - return resolve_model_to_use(model_name, model_type='hypernetwork', model_dir='hypernetwork', model_extensions=HYPERNETWORK_MODEL_EXTENSIONS, default_models=[]) +def resolve_hypernetwork_model_to_use(model_name:str=None): + return resolve_model_to_use(model_name, model_type='hypernetwork') + +def resolve_gfpgan_model_to_use(model_name:str=None): + return resolve_model_to_use(model_name, model_type='gfpgan') + +def resolve_realesrgan_model_to_use(model_name:str=None): + return resolve_model_to_use(model_name, model_type='realesrgan') + +def make_model_folders(): + for model_type in KNOWN_MODEL_TYPES: + model_dir_path = os.path.join(app.MODELS_DIR, model_type) + + os.makedirs(model_dir_path, exist_ok=True) + + help_file_name = f'Place your {model_type} model files here.txt' + help_file_contents = f'Supported extensions: {" or ".join(MODEL_EXTENSIONS.get(model_type))}' + + with open(os.path.join(model_dir_path, help_file_name)) as f: + f.write(help_file_contents) def is_malicious_model(file_path): try: @@ -102,8 +127,9 @@ def getModels(): }, } - def listModels(models_dirname, model_type, model_extensions): - models_dir = os.path.join(app.MODELS_DIR, models_dirname) + def listModels(model_type): + model_extensions = MODEL_EXTENSIONS.get(model_type, []) + models_dir = os.path.join(app.MODELS_DIR, model_type) if not os.path.exists(models_dir): os.makedirs(models_dir) @@ -128,9 +154,9 @@ def getModels(): models['options'][model_type].sort() # custom models - listModels(models_dirname='stable-diffusion', model_type='stable-diffusion', model_extensions=STABLE_DIFFUSION_MODEL_EXTENSIONS) - listModels(models_dirname='vae', model_type='vae', model_extensions=VAE_MODEL_EXTENSIONS) - listModels(models_dirname='hypernetwork', model_type='hypernetwork', model_extensions=HYPERNETWORK_MODEL_EXTENSIONS) + listModels(model_type='stable-diffusion') + listModels(model_type='vae') + listModels(model_type='hypernetwork') # legacy custom_weight_path = os.path.join(app.SD_DIR, 'custom-model.ckpt') @@ -138,3 +164,19 @@ def getModels(): models['options']['stable-diffusion'].append('custom-model') return models + +def 
is_sd_model_reload_necessary(thread_data, req: Request): + needs_model_reload = False + if 'stable-diffusion' not in thread_data.models or \ + thread_data.model_paths['stable-diffusion'] != req.use_stable_diffusion_model or \ + thread_data.model_paths['vae'] != req.use_vae_model: + + needs_model_reload = True + + if thread_data.device != 'cpu': + if (thread_data.precision == 'autocast' and req.use_full_precision) or \ + (thread_data.precision == 'full' and not req.use_full_precision and not device_manager.needs_to_force_full_precision(thread_data)): + thread_data.precision = 'full' if req.use_full_precision else 'autocast' + needs_model_reload = True + + return needs_model_reload diff --git a/ui/sd_internal/runtime2.py b/ui/sd_internal/runtime2.py index fc8d944d..840cec09 100644 --- a/ui/sd_internal/runtime2.py +++ b/ui/sd_internal/runtime2.py @@ -1,16 +1,23 @@ import threading import queue +import time +import json +import os +import base64 +import re from sd_internal import device_manager, model_manager -from sd_internal import Request, Response, Image as ResponseImage +from sd_internal import Request, Response, Image as ResponseImage, UserInitiatedStop -from modules import model_loader, image_generator, image_utils +from modules import model_loader, image_generator, image_utils, image_filters thread_data = threading.local() ''' runtime data (bound locally to this thread), for e.g. device, references to loaded models, optimization flags etc ''' +filename_regex = re.compile('[^a-zA-Z0-9]') + def init(device): ''' Initializes the fields that will be bound to this runtime's thread_data, and sets the current torch device @@ -28,89 +35,167 @@ def init(device): device_manager.device_init(thread_data, device) - load_default_models() + init_and_load_default_models() def destroy(): model_loader.unload_sd_model(thread_data) model_loader.unload_gfpgan_model(thread_data) model_loader.unload_realesrgan_model(thread_data) -def load_default_models(): - thread_data.model_paths['stable-diffusion'] = model_manager.default_model_to_load - thread_data.model_paths['vae'] = model_manager.default_vae_to_load +def init_and_load_default_models(): + # init default model paths + thread_data.model_paths['stable-diffusion'] = model_manager.resolve_sd_model_to_use() + thread_data.model_paths['vae'] = model_manager.resolve_vae_model_to_use() + thread_data.model_paths['hypernetwork'] = model_manager.resolve_hypernetwork_model_to_use() + thread_data.model_paths['gfpgan'] = model_manager.resolve_gfpgan_model_to_use() + thread_data.model_paths['realesrgan'] = model_manager.resolve_realesrgan_model_to_use() + # load mandatory models model_loader.load_sd_model(thread_data) -def reload_models_if_necessary(req: Request=None): - needs_model_reload = False - if 'stable-diffusion' not in thread_data.models or thread_data.ckpt_file != req.use_stable_diffusion_model or thread_data.vae_file != req.use_vae_model: - thread_data.ckpt_file = req.use_stable_diffusion_model - thread_data.vae_file = req.use_vae_model - needs_model_reload = True +def reload_models_if_necessary(req: Request): + if model_manager.is_sd_model_reload_necessary(thread_data, req): + thread_data.model_paths['stable-diffusion'] = req.use_stable_diffusion_model + thread_data.model_paths['vae'] = req.use_vae_model - if thread_data.device != 'cpu': - if (thread_data.precision == 'autocast' and (req.use_full_precision or not thread_data.model_is_half)) or \ - (thread_data.precision == 'full' and not req.use_full_precision and not thread_data.force_full_precision): - 
thread_data.precision = 'full' if req.use_full_precision else 'autocast' - needs_model_reload = True + model_loader.load_sd_model(thread_data) - return needs_model_reload + # if is_hypernetwork_reload_necessary(task.request): + # current_state = ServerStates.LoadingModel + # runtime.reload_hypernetwork() - if is_hypernetwork_reload_necessary(task.request): - current_state = ServerStates.LoadingModel - runtime.reload_hypernetwork() +def make_images(req: Request, data_queue: queue.Queue, task_temp_images: list, step_callback): + images, user_stopped = generate_images(req, data_queue, task_temp_images, step_callback) + images = apply_filters(req, images, user_stopped) - if is_model_reload_necessary(task.request): - current_state = ServerStates.LoadingModel - runtime.reload_model() + save_images(req, images) -def load_models(): - if ckpt_file_path == None: - ckpt_file_path = default_model_to_load - if vae_file_path == None: - vae_file_path = default_vae_to_load - if hypernetwork_file_path == None: - hypernetwork_file_path = default_hypernetwork_to_load - if ckpt_file_path == current_model_path and vae_file_path == current_vae_path: - return - current_state = ServerStates.LoadingModel - try: - from sd_internal import runtime2 - runtime.thread_data.hypernetwork_file = hypernetwork_file_path - runtime.thread_data.ckpt_file = ckpt_file_path - runtime.thread_data.vae_file = vae_file_path - runtime.load_model_ckpt() - runtime.load_hypernetwork() - current_model_path = ckpt_file_path - current_vae_path = vae_file_path - current_hypernetwork_path = hypernetwork_file_path - current_state_error = None - current_state = ServerStates.Online - except Exception as e: - current_model_path = None - current_vae_path = None - current_state_error = e - current_state = ServerStates.Unavailable - print(traceback.format_exc()) + return Response(req, images=construct_response(req, images)) + +def generate_images(req: Request, data_queue: queue.Queue, task_temp_images: list, step_callback): + thread_data.temp_images.clear() + + image_generator.on_image_step = make_step_callback(req, data_queue, task_temp_images, step_callback) -def make_image(req: Request, data_queue: queue.Queue, task_temp_images: list, step_callback): try: images = image_generator.make_image(context=thread_data, args=get_mk_img_args(req)) + user_stopped = False except UserInitiatedStop: - pass + images = [] + user_stopped = True + if not hasattr(thread_data, 'partial_x_samples') or thread_data.partial_x_samples is None: + return images + for i in range(req.num_outputs): + images[i] = image_utils.latent_to_img(thread_data, thread_data.partial_x_samples[i].unsqueeze(0)) + + del thread_data.partial_x_samples + finally: + model_loader.gc(thread_data) + + images = [(image, req.seed + i, False) for i, image in enumerate(images)] + + return images, user_stopped + +def apply_filters(req: Request, images: list, user_stopped): + if user_stopped or (req.use_face_correction is None and req.use_upscale is None): + return images + + filters = [] + if req.use_face_correction.startswith('GFPGAN'): filters.append((image_filters.apply_gfpgan, model_manager.resolve_gfpgan_model_to_use(req.use_face_correction))) + if req.use_upscale.startswith('RealESRGAN'): filters.append((image_filters.apply_realesrgan, model_manager.resolve_realesrgan_model_to_use(req.use_upscale))) + + filtered_images = [] + for img, seed, _ in images: + for filter_fn, filter_model_path in filters: + img = filter_fn(thread_data, img, filter_model_path) + + filtered_images.append((img, seed, True)) 
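# A minimal sketch of the filter stage above (not part of the patch): images
# travel through the pipeline as (image, seed, was_filtered) tuples, and each
# enabled filter is an (apply_fn, model_path) pair folded over every image.
def run_filter_chain(context, image, filters):
    # filters: e.g. [(apply_gfpgan, gfpgan_path), (apply_realesrgan, esrgan_path)]
    for apply_fn, model_path in filters:
        image = apply_fn(context, image, model_path)
    return image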
+ + if not req.show_only_filtered_image: + filtered_images = images + filtered_images + + return filtered_images + +def save_images(req: Request, images: list): + if req.save_to_disk_path is None: + return + + def get_image_id(i): + img_id = base64.b64encode(int(time.time()+i).to_bytes(8, 'big')).decode() # Generate unique ID based on time. + img_id = img_id.translate({43:None, 47:None, 61:None})[-8:] # Remove + / = and keep last 8 chars. + return img_id + + def get_image_basepath(i): + session_out_path = os.path.join(req.save_to_disk_path, filename_regex.sub('_', req.session_id)) + os.makedirs(session_out_path, exist_ok=True) + prompt_flattened = filename_regex.sub('_', req.prompt)[:50] + return os.path.join(session_out_path, f"{prompt_flattened}_{get_image_id(i)}") + + for i, img_data in enumerate(images): + img, seed, filtered = img_data + img_path = get_image_basepath(i) + + if not filtered or req.show_only_filtered_image: + img_metadata_path = img_path + '.txt' + metadata = req.json() + metadata['seed'] = seed + with open(img_metadata_path, 'w', encoding='utf-8') as f: + f.write(metadata) + + img_path += '_filtered' if filtered else '' + img_path += '.' + req.output_format + img.save(img_path, quality=req.output_quality) + +def construct_response(req: Request, images: list): + return [ + ResponseImage( + data=image_utils.img_to_base64_str(img, req.output_format, req.output_quality), + seed=seed + ) for img, seed, _ in images + ] def get_mk_img_args(req: Request): args = req.json() - if req.init_image is not None: - args['init_image'] = image_utils.base64_str_to_img(req.init_image) - - if req.mask is not None: - args['mask'] = image_utils.base64_str_to_img(req.mask) + args['init_image'] = image_utils.base64_str_to_img(req.init_image) if req.init_image is not None else None + args['mask'] = image_utils.base64_str_to_img(req.mask) if req.mask is not None else None return args -def on_image_step(x_samples, i): - pass +def make_step_callback(req: Request, data_queue: queue.Queue, task_temp_images: list, step_callback): + n_steps = req.num_inference_steps if req.init_image is None else int(req.num_inference_steps * req.prompt_strength) + last_callback_time = -1 -image_generator.on_image_step = on_image_step + def update_temp_img(req, x_samples, task_temp_images: list): + partial_images = [] + for i in range(req.num_outputs): + img = image_utils.latent_to_img(thread_data, x_samples[i].unsqueeze(0)) + buf = image_utils.img_to_buffer(img, output_format='JPEG') + + del img + + thread_data.temp_images[f'{req.request_id}/{i}'] = buf + task_temp_images[i] = buf + partial_images.append({'path': f'/image/tmp/{req.request_id}/{i}'}) + return partial_images + + def on_image_step(x_samples, i): + nonlocal last_callback_time + + thread_data.partial_x_samples = x_samples + step_time = time.time() - last_callback_time if last_callback_time != -1 else -1 + last_callback_time = time.time() + + progress = {"step": i, "step_time": step_time, "total_steps": n_steps} + + if req.stream_image_progress and i % 5 == 0: + progress['output'] = update_temp_img(req, x_samples, task_temp_images) + + data_queue.put(json.dumps(progress)) + + step_callback() + + if thread_data.stop_processing: + raise UserInitiatedStop("User requested that we stop processing") + + return on_image_step diff --git a/ui/sd_internal/task_manager.py b/ui/sd_internal/task_manager.py index 04cc9a69..c6ab9737 100644 --- a/ui/sd_internal/task_manager.py +++ b/ui/sd_internal/task_manager.py @@ -324,7 +324,7 @@ def thread_render(device): 
runtime2.reload_models_if_necessary(task.request) current_state = ServerStates.Rendering - task.response = runtime2.make_image(task.request, task.buffer_queue, task.temp_images, step_callback) + task.response = runtime2.make_images(task.request, task.buffer_queue, task.temp_images, step_callback) # Before looping back to the generator, mark cache as still alive. task_cache.keep(id(task), TASK_TTL) session_cache.keep(task.request.session_id, TASK_TTL) diff --git a/ui/server.py b/ui/server.py index 1cb4c910..fd2557e2 100644 --- a/ui/server.py +++ b/ui/server.py @@ -137,9 +137,9 @@ def ping(session_id:str=None): def render(req : task_manager.ImageRequest): try: app.save_model_to_config(req.use_stable_diffusion_model, req.use_vae_model, req.use_hypernetwork_model) - req.use_stable_diffusion_model = model_manager.resolve_ckpt_to_use(req.use_stable_diffusion_model) - req.use_vae_model = model_manager.resolve_vae_to_use(req.use_vae_model) - req.use_hypernetwork_model = model_manager.resolve_hypernetwork_to_use(req.use_hypernetwork_model) + req.use_stable_diffusion_model = model_manager.resolve_sd_model_to_use(req.use_stable_diffusion_model) + req.use_vae_model = model_manager.resolve_vae_model_to_use(req.use_vae_model) + req.use_hypernetwork_model = model_manager.resolve_hypernetwork_model_to_use(req.use_hypernetwork_model) new_task = task_manager.render(req) response = { From 27c61132871c21d3c7c30197c5f6bc14429ec82e Mon Sep 17 00:00:00 2001 From: cmdr2 Date: Fri, 9 Dec 2022 13:29:06 +0530 Subject: [PATCH 04/74] Support hypernetworks; moves the hypernetwork module to diffusion-kit --- ui/sd_internal/hypernetwork.py | 198 -------------------------------- ui/sd_internal/model_manager.py | 1 - ui/sd_internal/runtime2.py | 11 +- 3 files changed, 8 insertions(+), 202 deletions(-) delete mode 100644 ui/sd_internal/hypernetwork.py diff --git a/ui/sd_internal/hypernetwork.py b/ui/sd_internal/hypernetwork.py deleted file mode 100644 index 979a74f3..00000000 --- a/ui/sd_internal/hypernetwork.py +++ /dev/null @@ -1,198 +0,0 @@ -# this is basically a cut down version of https://github.com/AUTOMATIC1111/stable-diffusion-webui/blob/c9a2cfdf2a53d37c2de1908423e4f548088667ef/modules/hypernetworks/hypernetwork.py, mostly for feature parity -# I, c0bra5, don't really understand how deep learning works. I just know how to port stuff. - -import inspect -import torch -import optimizedSD.splitAttention -from . import runtime -from einops import rearrange - -optimizer_dict = {optim_name : cls_obj for optim_name, cls_obj in inspect.getmembers(torch.optim, inspect.isclass) if optim_name != "Optimizer"} - -loaded_hypernetwork = None - -class HypernetworkModule(torch.nn.Module): - multiplier = 0.5 - activation_dict = { - "linear": torch.nn.Identity, - "relu": torch.nn.ReLU, - "leakyrelu": torch.nn.LeakyReLU, - "elu": torch.nn.ELU, - "swish": torch.nn.Hardswish, - "tanh": torch.nn.Tanh, - "sigmoid": torch.nn.Sigmoid, - } - activation_dict.update({cls_name.lower(): cls_obj for cls_name, cls_obj in inspect.getmembers(torch.nn.modules.activation) if inspect.isclass(cls_obj) and cls_obj.__module__ == 'torch.nn.modules.activation'}) - - def __init__(self, dim, state_dict=None, layer_structure=None, activation_func=None, weight_init='Normal', - add_layer_norm=False, use_dropout=False, activate_output=False, last_layer_dropout=False): - super().__init__() - - assert layer_structure is not None, "layer_structure must not be None" - assert layer_structure[0] == 1, "Multiplier Sequence should start with size 1!" 
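# Worked example (not from the deleted file): layer_structure holds width
# multipliers, so [1, 2, 1] with dim=768 builds Linear(768, 1536) followed by
# Linear(1536, 768); the two asserts guarantee that the module's input and
# output widths both equal the attention context width.
dim, layer_structure = 768, [1, 2, 1]
sizes = [(int(dim * layer_structure[i]), int(dim * layer_structure[i + 1]))
         for i in range(len(layer_structure) - 1)]
assert sizes == [(768, 1536), (1536, 768)]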
- assert layer_structure[-1] == 1, "Multiplier Sequence should end with size 1!" - - linears = [] - for i in range(len(layer_structure) - 1): - - # Add a fully-connected layer - linears.append(torch.nn.Linear(int(dim * layer_structure[i]), int(dim * layer_structure[i+1]))) - - # Add an activation func except last layer - if activation_func == "linear" or activation_func is None or (i >= len(layer_structure) - 2 and not activate_output): - pass - elif activation_func in self.activation_dict: - linears.append(self.activation_dict[activation_func]()) - else: - raise RuntimeError(f'hypernetwork uses an unsupported activation function: {activation_func}') - - # Add layer normalization - if add_layer_norm: - linears.append(torch.nn.LayerNorm(int(dim * layer_structure[i+1]))) - - # Add dropout except last layer - if use_dropout and (i < len(layer_structure) - 3 or last_layer_dropout and i < len(layer_structure) - 2): - linears.append(torch.nn.Dropout(p=0.3)) - - self.linear = torch.nn.Sequential(*linears) - - self.fix_old_state_dict(state_dict) - self.load_state_dict(state_dict) - - self.to(runtime.thread_data.device) - - def fix_old_state_dict(self, state_dict): - changes = { - 'linear1.bias': 'linear.0.bias', - 'linear1.weight': 'linear.0.weight', - 'linear2.bias': 'linear.1.bias', - 'linear2.weight': 'linear.1.weight', - } - - for fr, to in changes.items(): - x = state_dict.get(fr, None) - if x is None: - continue - - del state_dict[fr] - state_dict[to] = x - - def forward(self, x: torch.Tensor): - return x + self.linear(x) * runtime.thread_data.hypernetwork_strength - -def apply_hypernetwork(hypernetwork, context, layer=None): - hypernetwork_layers = hypernetwork.get(context.shape[2], None) - - if hypernetwork_layers is None: - return context, context - - if layer is not None: - layer.hyper_k = hypernetwork_layers[0] - layer.hyper_v = hypernetwork_layers[1] - - context_k = hypernetwork_layers[0](context) - context_v = hypernetwork_layers[1](context) - return context_k, context_v - -def get_kv(context, hypernetwork): - if hypernetwork is None: - return context, context - else: - return apply_hypernetwork(runtime.thread_data.hypernetwork, context) - -# This might need updating as the optimisedSD code changes -# I think yall have a system for this (patch files in sd_internal) but idk how it works and no amount of searching gave me any clue -# just in case for attribution https://github.com/easydiffusion/diffusion-kit/blob/e8ea0cadd543056059cd951e76d4744de76327d2/optimizedSD/splitAttention.py#L171 -def new_cross_attention_forward(self, x, context=None, mask=None): - h = self.heads - - q = self.to_q(x) - # default context - context = context if context is not None else x() if inspect.isfunction(x) else x - # hypernetwork! 
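# A condensed sketch of the idea being deleted here (illustrative, not the
# exact module): per context width, the hypernetwork supplies two small
# residual MLPs that transform the context before the to_k/to_v projections.
import torch

class TinyHyperModule(torch.nn.Module):
    def __init__(self, dim):
        super().__init__()
        self.linear = torch.nn.Sequential(
            torch.nn.Linear(dim, dim * 2),
            torch.nn.Linear(dim * 2, dim),
        )

    def forward(self, x):
        # residual, as in HypernetworkModule.forward above (the real module
        # also scales the delta by a configurable strength)
        return x + self.linear(x)

def hyper_kv(context, layers_by_dim):
    # layers_by_dim maps a context width to a (k_module, v_module) pair
    mods = layers_by_dim.get(context.shape[2])
    if mods is None:
        return context, context  # no hypernetwork for this width: identity
    return mods[0](context), mods[1](context)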
- context_k, context_v = get_kv(context, runtime.thread_data.hypernetwork) - k = self.to_k(context_k) - v = self.to_v(context_v) - del context, x - - q, k, v = map(lambda t: rearrange(t, 'b n (h d) -> (b h) n d', h=h), (q, k, v)) - - - limit = k.shape[0] - att_step = self.att_step - q_chunks = list(torch.tensor_split(q, limit//att_step, dim=0)) - k_chunks = list(torch.tensor_split(k, limit//att_step, dim=0)) - v_chunks = list(torch.tensor_split(v, limit//att_step, dim=0)) - - q_chunks.reverse() - k_chunks.reverse() - v_chunks.reverse() - sim = torch.zeros(q.shape[0], q.shape[1], v.shape[2], device=q.device) - del k, q, v - for i in range (0, limit, att_step): - - q_buffer = q_chunks.pop() - k_buffer = k_chunks.pop() - v_buffer = v_chunks.pop() - sim_buffer = torch.einsum('b i d, b j d -> b i j', q_buffer, k_buffer) * self.scale - - del k_buffer, q_buffer - # attention, what we cannot get enough of, by chunks - - sim_buffer = sim_buffer.softmax(dim=-1) - - sim_buffer = torch.einsum('b i j, b j d -> b i d', sim_buffer, v_buffer) - del v_buffer - sim[i:i+att_step,:,:] = sim_buffer - - del sim_buffer - sim = rearrange(sim, '(b h) n d -> b n (h d)', h=h) - return self.to_out(sim) - - -def load_hypernetwork(path: str): - - state_dict = torch.load(path, map_location='cpu') - - layer_structure = state_dict.get('layer_structure', [1, 2, 1]) - activation_func = state_dict.get('activation_func', None) - weight_init = state_dict.get('weight_initialization', 'Normal') - add_layer_norm = state_dict.get('is_layer_norm', False) - use_dropout = state_dict.get('use_dropout', False) - activate_output = state_dict.get('activate_output', True) - last_layer_dropout = state_dict.get('last_layer_dropout', False) - # this is a bit verbose so leaving it commented out for the poor soul who ever has to debug this - # print(f"layer_structure: {layer_structure}") - # print(f"activation_func: {activation_func}") - # print(f"weight_init: {weight_init}") - # print(f"add_layer_norm: {add_layer_norm}") - # print(f"use_dropout: {use_dropout}") - # print(f"activate_output: {activate_output}") - # print(f"last_layer_dropout: {last_layer_dropout}") - - layers = {} - for size, sd in state_dict.items(): - if type(size) == int: - layers[size] = ( - HypernetworkModule(size, sd[0], layer_structure, activation_func, weight_init, add_layer_norm, - use_dropout, activate_output, last_layer_dropout=last_layer_dropout), - HypernetworkModule(size, sd[1], layer_structure, activation_func, weight_init, add_layer_norm, - use_dropout, activate_output, last_layer_dropout=last_layer_dropout), - ) - print(f"hypernetwork loaded") - return layers - - - -# overriding of original function -old_cross_attention_forward = optimizedSD.splitAttention.CrossAttention.forward -# hijacks the cross attention forward function to add hyper network support -def hijack_cross_attention(): - print("hypernetwork functionality added to cross attention") - optimizedSD.splitAttention.CrossAttention.forward = new_cross_attention_forward -# there was a cop on board -def unhijack_cross_attention_forward(): - print("hypernetwork functionality removed from cross attention") - optimizedSD.splitAttention.CrossAttention.forward = old_cross_attention_forward - -hijack_cross_attention() \ No newline at end of file diff --git a/ui/sd_internal/model_manager.py b/ui/sd_internal/model_manager.py index a8b249a2..21acb540 100644 --- a/ui/sd_internal/model_manager.py +++ b/ui/sd_internal/model_manager.py @@ -70,7 +70,6 @@ def resolve_model_to_use(model_name:str, model_type:str): 
print(f'Could not find the configured custom model {model_name}{model_extension}. Using the default one: {default_model_path}{model_extension}') return default_model_path + model_extension - print(f'No valid models found for model_name: {model_name}') return None def resolve_sd_model_to_use(model_name:str=None): diff --git a/ui/sd_internal/runtime2.py b/ui/sd_internal/runtime2.py index 840cec09..90a868a5 100644 --- a/ui/sd_internal/runtime2.py +++ b/ui/sd_internal/runtime2.py @@ -41,6 +41,7 @@ def destroy(): model_loader.unload_sd_model(thread_data) model_loader.unload_gfpgan_model(thread_data) model_loader.unload_realesrgan_model(thread_data) + model_loader.unload_hypernetwork_model(thread_data) def init_and_load_default_models(): # init default model paths @@ -60,9 +61,13 @@ def reload_models_if_necessary(req: Request): model_loader.load_sd_model(thread_data) - # if is_hypernetwork_reload_necessary(task.request): - # current_state = ServerStates.LoadingModel - # runtime.reload_hypernetwork() + if thread_data.model_paths.get('hypernetwork') != req.use_hypernetwork_model: + thread_data.model_paths['hypernetwork'] = req.use_hypernetwork_model + + if thread_data.model_paths['hypernetwork'] is not None: + model_loader.load_hypernetwork_model(thread_data) + else: + model_loader.unload_hypernetwork_model(thread_data) def make_images(req: Request, data_queue: queue.Queue, task_temp_images: list, step_callback): images, user_stopped = generate_images(req, data_queue, task_temp_images, step_callback) From 16410d90b81d9ceec1cc71e347224faae491d40c Mon Sep 17 00:00:00 2001 From: cmdr2 Date: Fri, 9 Dec 2022 15:21:49 +0530 Subject: [PATCH 05/74] Use the simplified model loading API in diffusion-kit; Catch and report exceptions while generating images --- ui/sd_internal/runtime2.py | 29 +++++++++++++++++++++-------- 1 file changed, 21 insertions(+), 8 deletions(-) diff --git a/ui/sd_internal/runtime2.py b/ui/sd_internal/runtime2.py index 90a868a5..0af0ead4 100644 --- a/ui/sd_internal/runtime2.py +++ b/ui/sd_internal/runtime2.py @@ -5,6 +5,7 @@ import json import os import base64 import re +import traceback from sd_internal import device_manager, model_manager from sd_internal import Request, Response, Image as ResponseImage, UserInitiatedStop @@ -38,10 +39,10 @@ def init(device): init_and_load_default_models() def destroy(): - model_loader.unload_sd_model(thread_data) - model_loader.unload_gfpgan_model(thread_data) - model_loader.unload_realesrgan_model(thread_data) - model_loader.unload_hypernetwork_model(thread_data) + model_loader.unload_model(thread_data, 'stable-diffusion') + model_loader.unload_model(thread_data, 'gfpgan') + model_loader.unload_model(thread_data, 'realesrgan') + model_loader.unload_model(thread_data, 'hypernetwork') def init_and_load_default_models(): # init default model paths @@ -52,24 +53,36 @@ def init_and_load_default_models(): thread_data.model_paths['realesrgan'] = model_manager.resolve_realesrgan_model_to_use() # load mandatory models - model_loader.load_sd_model(thread_data) + model_loader.load_model(thread_data, 'stable-diffusion') def reload_models_if_necessary(req: Request): if model_manager.is_sd_model_reload_necessary(thread_data, req): thread_data.model_paths['stable-diffusion'] = req.use_stable_diffusion_model thread_data.model_paths['vae'] = req.use_vae_model - model_loader.load_sd_model(thread_data) + model_loader.load_model(thread_data, 'stable-diffusion') if thread_data.model_paths.get('hypernetwork') != req.use_hypernetwork_model: 
thread_data.model_paths['hypernetwork'] = req.use_hypernetwork_model if thread_data.model_paths['hypernetwork'] is not None: - model_loader.load_hypernetwork_model(thread_data) + model_loader.load_model(thread_data, 'hypernetwork') else: - model_loader.unload_hypernetwork_model(thread_data) + model_loader.unload_model(thread_data, 'hypernetwork') def make_images(req: Request, data_queue: queue.Queue, task_temp_images: list, step_callback): + try: + return _make_images_internal(req, data_queue, task_temp_images, step_callback) + except Exception as e: + print(traceback.format_exc()) + + data_queue.put(json.dumps({ + "status": 'failed', + "detail": str(e) + })) + raise e + +def _make_images_internal(req: Request, data_queue: queue.Queue, task_temp_images: list, step_callback): images, user_stopped = generate_images(req, data_queue, task_temp_images, step_callback) images = apply_filters(req, images, user_stopped) From accfec9007a75214d781228b8eb91928c47b6f3d Mon Sep 17 00:00:00 2001 From: cmdr2 Date: Fri, 9 Dec 2022 15:22:56 +0530 Subject: [PATCH 06/74] Space --- ui/server.py | 1 - 1 file changed, 1 deletion(-) diff --git a/ui/server.py b/ui/server.py index fd2557e2..b3c54342 100644 --- a/ui/server.py +++ b/ui/server.py @@ -20,7 +20,6 @@ server_api = FastAPI() # don't show access log entries for URLs that start with the given prefix ACCESS_LOG_SUPPRESS_PATH_PREFIXES = ['/ping', '/image', '/modifier-thumbnails'] - NOCACHE_HEADERS={"Cache-Control": "no-cache, no-store, must-revalidate", "Pragma": "no-cache", "Expires": "0"} class NoCacheStaticFiles(StaticFiles): From aa59575df39b8fec43b2a09bddec573961c0d3dd Mon Sep 17 00:00:00 2001 From: cmdr2 Date: Fri, 9 Dec 2022 15:24:55 +0530 Subject: [PATCH 07/74] Remove unused patch files --- ui/sd_internal/ddim_callback.patch | 162 ------------------------- ui/sd_internal/ddim_callback_sd2.patch | 84 ------------- 2 files changed, 246 deletions(-) delete mode 100644 ui/sd_internal/ddim_callback.patch delete mode 100644 ui/sd_internal/ddim_callback_sd2.patch diff --git a/ui/sd_internal/ddim_callback.patch b/ui/sd_internal/ddim_callback.patch deleted file mode 100644 index e4dd69e0..00000000 --- a/ui/sd_internal/ddim_callback.patch +++ /dev/null @@ -1,162 +0,0 @@ -diff --git a/optimizedSD/ddpm.py b/optimizedSD/ddpm.py -index 79058bc..a473411 100644 ---- a/optimizedSD/ddpm.py -+++ b/optimizedSD/ddpm.py -@@ -564,12 +564,12 @@ class UNet(DDPM): - unconditional_guidance_scale=unconditional_guidance_scale, - callback=callback, img_callback=img_callback) - -+ yield from samples -+ - if(self.turbo): - self.model1.to("cpu") - self.model2.to("cpu") - -- return samples -- - @torch.no_grad() - def plms_sampling(self, cond,b, img, - ddim_use_original_steps=False, -@@ -608,10 +608,10 @@ class UNet(DDPM): - old_eps.append(e_t) - if len(old_eps) >= 4: - old_eps.pop(0) -- if callback: callback(i) -- if img_callback: img_callback(pred_x0, i) -+ if callback: yield from callback(i) -+ if img_callback: yield from img_callback(pred_x0, i) - -- return img -+ yield from img_callback(img, len(iterator)-1) - - @torch.no_grad() - def p_sample_plms(self, x, c, t, index, repeat_noise=False, use_original_steps=False, quantize_denoised=False, -@@ -740,13 +740,13 @@ class UNet(DDPM): - unconditional_guidance_scale=unconditional_guidance_scale, - unconditional_conditioning=unconditional_conditioning) - -- if callback: callback(i) -- if img_callback: img_callback(x_dec, i) -+ if callback: yield from callback(i) -+ if img_callback: yield from img_callback(x_dec, i) - - if mask is not 
None: -- return x0 * mask + (1. - mask) * x_dec -+ x_dec = x0 * mask + (1. - mask) * x_dec - -- return x_dec -+ yield from img_callback(x_dec, len(iterator)-1) - - - @torch.no_grad() -@@ -820,12 +820,12 @@ class UNet(DDPM): - - - d = to_d(x, sigma_hat, denoised) -- if callback: callback(i) -- if img_callback: img_callback(x, i) -+ if callback: yield from callback(i) -+ if img_callback: yield from img_callback(x, i) - dt = sigmas[i + 1] - sigma_hat - # Euler method - x = x + d * dt -- return x -+ yield from img_callback(x, len(sigmas)-1) - - @torch.no_grad() - def euler_ancestral_sampling(self,ac,x, S, cond, unconditional_conditioning = None, unconditional_guidance_scale = 1,extra_args=None, callback=None, disable=None, img_callback=None): -@@ -852,14 +852,14 @@ class UNet(DDPM): - denoised = e_t_uncond + unconditional_guidance_scale * (e_t - e_t_uncond) - - sigma_down, sigma_up = get_ancestral_step(sigmas[i], sigmas[i + 1]) -- if callback: callback(i) -- if img_callback: img_callback(x, i) -+ if callback: yield from callback(i) -+ if img_callback: yield from img_callback(x, i) - d = to_d(x, sigmas[i], denoised) - # Euler method - dt = sigma_down - sigmas[i] - x = x + d * dt - x = x + torch.randn_like(x) * sigma_up -- return x -+ yield from img_callback(x, len(sigmas)-1) - - - -@@ -892,8 +892,8 @@ class UNet(DDPM): - denoised = e_t_uncond + unconditional_guidance_scale * (e_t - e_t_uncond) - - d = to_d(x, sigma_hat, denoised) -- if callback: callback(i) -- if img_callback: img_callback(x, i) -+ if callback: yield from callback(i) -+ if img_callback: yield from img_callback(x, i) - dt = sigmas[i + 1] - sigma_hat - if sigmas[i + 1] == 0: - # Euler method -@@ -913,7 +913,7 @@ class UNet(DDPM): - d_2 = to_d(x_2, sigmas[i + 1], denoised_2) - d_prime = (d + d_2) / 2 - x = x + d_prime * dt -- return x -+ yield from img_callback(x, len(sigmas)-1) - - - @torch.no_grad() -@@ -944,8 +944,8 @@ class UNet(DDPM): - e_t_uncond, e_t = (x_in + eps * c_out).chunk(2) - denoised = e_t_uncond + unconditional_guidance_scale * (e_t - e_t_uncond) - -- if callback: callback(i) -- if img_callback: img_callback(x, i) -+ if callback: yield from callback(i) -+ if img_callback: yield from img_callback(x, i) - - d = to_d(x, sigma_hat, denoised) - # Midpoint method, where the midpoint is chosen according to a rho=3 Karras schedule -@@ -966,7 +966,7 @@ class UNet(DDPM): - - d_2 = to_d(x_2, sigma_mid, denoised_2) - x = x + d_2 * dt_2 -- return x -+ yield from img_callback(x, len(sigmas)-1) - - - @torch.no_grad() -@@ -994,8 +994,8 @@ class UNet(DDPM): - - - sigma_down, sigma_up = get_ancestral_step(sigmas[i], sigmas[i + 1]) -- if callback: callback(i) -- if img_callback: img_callback(x, i) -+ if callback: yield from callback(i) -+ if img_callback: yield from img_callback(x, i) - d = to_d(x, sigmas[i], denoised) - # Midpoint method, where the midpoint is chosen according to a rho=3 Karras schedule - sigma_mid = ((sigmas[i] ** (1 / 3) + sigma_down ** (1 / 3)) / 2) ** 3 -@@ -1016,7 +1016,7 @@ class UNet(DDPM): - d_2 = to_d(x_2, sigma_mid, denoised_2) - x = x + d_2 * dt_2 - x = x + torch.randn_like(x) * sigma_up -- return x -+ yield from img_callback(x, len(sigmas)-1) - - - @torch.no_grad() -@@ -1042,8 +1042,8 @@ class UNet(DDPM): - e_t_uncond, e_t = (x_in + eps * c_out).chunk(2) - denoised = e_t_uncond + unconditional_guidance_scale * (e_t - e_t_uncond) - -- if callback: callback(i) -- if img_callback: img_callback(x, i) -+ if callback: yield from callback(i) -+ if img_callback: yield from img_callback(x, i) - - d = to_d(x, 
sigmas[i], denoised) - ds.append(d) -@@ -1054,4 +1054,4 @@ class UNet(DDPM): - cur_order = min(i + 1, order) - coeffs = [linear_multistep_coeff(cur_order, sigmas.cpu(), i, j) for j in range(cur_order)] - x = x + sum(coeff * d for coeff, d in zip(coeffs, reversed(ds))) -- return x -+ yield from img_callback(x, len(sigmas)-1) diff --git a/ui/sd_internal/ddim_callback_sd2.patch b/ui/sd_internal/ddim_callback_sd2.patch deleted file mode 100644 index cadf81ca..00000000 --- a/ui/sd_internal/ddim_callback_sd2.patch +++ /dev/null @@ -1,84 +0,0 @@ -diff --git a/ldm/models/diffusion/ddim.py b/ldm/models/diffusion/ddim.py -index 27ead0e..6215939 100644 ---- a/ldm/models/diffusion/ddim.py -+++ b/ldm/models/diffusion/ddim.py -@@ -100,7 +100,7 @@ class DDIMSampler(object): - size = (batch_size, C, H, W) - print(f'Data shape for DDIM sampling is {size}, eta {eta}') - -- samples, intermediates = self.ddim_sampling(conditioning, size, -+ samples = self.ddim_sampling(conditioning, size, - callback=callback, - img_callback=img_callback, - quantize_denoised=quantize_x0, -@@ -117,7 +117,8 @@ class DDIMSampler(object): - dynamic_threshold=dynamic_threshold, - ucg_schedule=ucg_schedule - ) -- return samples, intermediates -+ # return samples, intermediates -+ yield from samples - - @torch.no_grad() - def ddim_sampling(self, cond, shape, -@@ -168,14 +169,15 @@ class DDIMSampler(object): - unconditional_conditioning=unconditional_conditioning, - dynamic_threshold=dynamic_threshold) - img, pred_x0 = outs -- if callback: callback(i) -- if img_callback: img_callback(pred_x0, i) -+ if callback: yield from callback(i) -+ if img_callback: yield from img_callback(pred_x0, i) - - if index % log_every_t == 0 or index == total_steps - 1: - intermediates['x_inter'].append(img) - intermediates['pred_x0'].append(pred_x0) - -- return img, intermediates -+ # return img, intermediates -+ yield from img_callback(pred_x0, len(iterator)-1) - - @torch.no_grad() - def p_sample_ddim(self, x, c, t, index, repeat_noise=False, use_original_steps=False, quantize_denoised=False, -diff --git a/ldm/models/diffusion/plms.py b/ldm/models/diffusion/plms.py -index 7002a36..0951f39 100644 ---- a/ldm/models/diffusion/plms.py -+++ b/ldm/models/diffusion/plms.py -@@ -96,7 +96,7 @@ class PLMSSampler(object): - size = (batch_size, C, H, W) - print(f'Data shape for PLMS sampling is {size}') - -- samples, intermediates = self.plms_sampling(conditioning, size, -+ samples = self.plms_sampling(conditioning, size, - callback=callback, - img_callback=img_callback, - quantize_denoised=quantize_x0, -@@ -112,7 +112,8 @@ class PLMSSampler(object): - unconditional_conditioning=unconditional_conditioning, - dynamic_threshold=dynamic_threshold, - ) -- return samples, intermediates -+ #return samples, intermediates -+ yield from samples - - @torch.no_grad() - def plms_sampling(self, cond, shape, -@@ -165,14 +166,15 @@ class PLMSSampler(object): - old_eps.append(e_t) - if len(old_eps) >= 4: - old_eps.pop(0) -- if callback: callback(i) -- if img_callback: img_callback(pred_x0, i) -+ if callback: yield from callback(i) -+ if img_callback: yield from img_callback(pred_x0, i) - - if index % log_every_t == 0 or index == total_steps - 1: - intermediates['x_inter'].append(img) - intermediates['pred_x0'].append(pred_x0) - -- return img, intermediates -+ # return img, intermediates -+ yield from img_callback(pred_x0, len(iterator)-1) - - @torch.no_grad() - def p_sample_plms(self, x, c, t, index, repeat_noise=False, use_original_steps=False, quantize_denoised=False, From 
b40fb3a422472fc0fb4e2d68edf9876e8bd881dc Mon Sep 17 00:00:00 2001 From: cmdr2 Date: Fri, 9 Dec 2022 15:27:40 +0530 Subject: [PATCH 08/74] Model readme file write flag --- ui/sd_internal/model_manager.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/ui/sd_internal/model_manager.py b/ui/sd_internal/model_manager.py index 21acb540..c7d22e1d 100644 --- a/ui/sd_internal/model_manager.py +++ b/ui/sd_internal/model_manager.py @@ -96,7 +96,7 @@ def make_model_folders(): help_file_name = f'Place your {model_type} model files here.txt' help_file_contents = f'Supported extensions: {" or ".join(MODEL_EXTENSIONS.get(model_type))}' - with open(os.path.join(model_dir_path, help_file_name)) as f: + with open(os.path.join(model_dir_path, help_file_name), 'w', encoding='utf-8') as f: f.write(help_file_contents) def is_malicious_model(file_path): From 882081400251184f70658af06d1b8dee75364535 Mon Sep 17 00:00:00 2001 From: cmdr2 Date: Fri, 9 Dec 2022 15:45:36 +0530 Subject: [PATCH 09/74] Simplify the API for resolving model paths; Code cleanup --- ui/sd_internal/model_manager.py | 17 +---------------- ui/sd_internal/runtime2.py | 17 ++++++----------- ui/sd_internal/task_manager.py | 4 ++-- ui/server.py | 6 +++--- 4 files changed, 12 insertions(+), 32 deletions(-) diff --git a/ui/sd_internal/model_manager.py b/ui/sd_internal/model_manager.py index c7d22e1d..6cf9428a 100644 --- a/ui/sd_internal/model_manager.py +++ b/ui/sd_internal/model_manager.py @@ -28,7 +28,7 @@ def init(): make_model_folders() getModels() # run this once, to cache the picklescan results -def resolve_model_to_use(model_name:str, model_type:str): +def resolve_model_to_use(model_name:str=None, model_type:str=None): model_extensions = MODEL_EXTENSIONS.get(model_type, []) default_models = DEFAULT_MODELS.get(model_type, []) config = app.getConfig() @@ -72,21 +72,6 @@ def resolve_model_to_use(model_name:str, model_type:str): return None -def resolve_sd_model_to_use(model_name:str=None): - return resolve_model_to_use(model_name, model_type='stable-diffusion') - -def resolve_vae_model_to_use(model_name:str=None): - return resolve_model_to_use(model_name, model_type='vae') - -def resolve_hypernetwork_model_to_use(model_name:str=None): - return resolve_model_to_use(model_name, model_type='hypernetwork') - -def resolve_gfpgan_model_to_use(model_name:str=None): - return resolve_model_to_use(model_name, model_type='gfpgan') - -def resolve_realesrgan_model_to_use(model_name:str=None): - return resolve_model_to_use(model_name, model_type='realesrgan') - def make_model_folders(): for model_type in KNOWN_MODEL_TYPES: model_dir_path = os.path.join(app.MODELS_DIR, model_type) diff --git a/ui/sd_internal/runtime2.py b/ui/sd_internal/runtime2.py index 0af0ead4..3d4ff8ff 100644 --- a/ui/sd_internal/runtime2.py +++ b/ui/sd_internal/runtime2.py @@ -39,18 +39,13 @@ def init(device): init_and_load_default_models() def destroy(): - model_loader.unload_model(thread_data, 'stable-diffusion') - model_loader.unload_model(thread_data, 'gfpgan') - model_loader.unload_model(thread_data, 'realesrgan') - model_loader.unload_model(thread_data, 'hypernetwork') + for model_type in ('stable-diffusion', 'hypernetwork', 'gfpgan', 'realesrgan'): + model_loader.unload_model(thread_data, model_type) def init_and_load_default_models(): # init default model paths - thread_data.model_paths['stable-diffusion'] = model_manager.resolve_sd_model_to_use() - thread_data.model_paths['vae'] = model_manager.resolve_vae_model_to_use() - thread_data.model_paths['hypernetwork'] = 
model_manager.resolve_hypernetwork_model_to_use() - thread_data.model_paths['gfpgan'] = model_manager.resolve_gfpgan_model_to_use() - thread_data.model_paths['realesrgan'] = model_manager.resolve_realesrgan_model_to_use() + for model_type in ('stable-diffusion', 'vae', 'hypernetwork', 'gfpgan', 'realesrgan'): + thread_data.model_paths[model_type] = model_manager.resolve_model_to_use(model_type=model_type) # load mandatory models model_loader.load_model(thread_data, 'stable-diffusion') @@ -119,8 +114,8 @@ def apply_filters(req: Request, images: list, user_stopped): return images filters = [] - if req.use_face_correction.startswith('GFPGAN'): filters.append((image_filters.apply_gfpgan, model_manager.resolve_gfpgan_model_to_use(req.use_face_correction))) - if req.use_upscale.startswith('RealESRGAN'): filters.append((image_filters.apply_realesrgan, model_manager.resolve_realesrgan_model_to_use(req.use_upscale))) + if req.use_face_correction.startswith('GFPGAN'): filters.append((image_filters.apply_gfpgan, model_manager.resolve_model_to_use(req.use_face_correction, model_type='gfpgan'))) + if req.use_upscale.startswith('RealESRGAN'): filters.append((image_filters.apply_realesrgan, model_manager.resolve_model_to_use(req.use_upscale, model_type='realesrgan'))) filtered_images = [] for img, seed, _ in images: diff --git a/ui/sd_internal/task_manager.py b/ui/sd_internal/task_manager.py index c6ab9737..aec79239 100644 --- a/ui/sd_internal/task_manager.py +++ b/ui/sd_internal/task_manager.py @@ -11,10 +11,10 @@ TASK_TTL = 15 * 60 # seconds, Discard last session's task timeout import torch import queue, threading, time, weakref -from typing import Any, Generator, Hashable, Optional, Union +from typing import Any, Hashable from pydantic import BaseModel -from sd_internal import Request, Response, runtime, device_manager +from sd_internal import Request, device_manager THREAD_NAME_PREFIX = 'Runtime-Render/' ERR_LOCK_FAILED = ' failed to acquire lock within timeout.' 
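A hypothetical usage sketch of the unified resolver introduced by this commit (the VAE name below is a placeholder; the model type selects the search folder, extensions and defaults):

    from sd_internal import model_manager

    # default stable-diffusion model, from config.json or the DEFAULT_MODELS list
    sd_path = model_manager.resolve_model_to_use(model_type='stable-diffusion')
    # a named VAE; 'some-vae' stands in for a file under models/vae/
    vae_path = model_manager.resolve_model_to_use('some-vae', model_type='vae')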
diff --git a/ui/server.py b/ui/server.py index b3c54342..ca5cac79 100644 --- a/ui/server.py +++ b/ui/server.py @@ -136,9 +136,9 @@ def ping(session_id:str=None): def render(req : task_manager.ImageRequest): try: app.save_model_to_config(req.use_stable_diffusion_model, req.use_vae_model, req.use_hypernetwork_model) - req.use_stable_diffusion_model = model_manager.resolve_sd_model_to_use(req.use_stable_diffusion_model) - req.use_vae_model = model_manager.resolve_vae_model_to_use(req.use_vae_model) - req.use_hypernetwork_model = model_manager.resolve_hypernetwork_model_to_use(req.use_hypernetwork_model) + req.use_stable_diffusion_model = model_manager.resolve_model_to_use(req.use_stable_diffusion_model, model_type='stable-diffusion') + req.use_vae_model = model_manager.resolve_model_to_use(req.use_vae_model, model_type='vae') + req.use_hypernetwork_model = model_manager.resolve_model_to_use(req.use_hypernetwork_model, model_type='hypernetwork') new_task = task_manager.render(req) response = { From 3fbb3f677384f002ecbd1b8501ef081de59fcb42 Mon Sep 17 00:00:00 2001 From: cmdr2 Date: Fri, 9 Dec 2022 16:09:10 +0530 Subject: [PATCH 10/74] Use const --- ui/sd_internal/runtime2.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/ui/sd_internal/runtime2.py b/ui/sd_internal/runtime2.py index 3d4ff8ff..57dd7150 100644 --- a/ui/sd_internal/runtime2.py +++ b/ui/sd_internal/runtime2.py @@ -44,7 +44,7 @@ def destroy(): def init_and_load_default_models(): # init default model paths - for model_type in ('stable-diffusion', 'vae', 'hypernetwork', 'gfpgan', 'realesrgan'): + for model_type in model_manager.KNOWN_MODEL_TYPES: thread_data.model_paths[model_type] = model_manager.resolve_model_to_use(model_type=model_type) # load mandatory models From 0f656dbf2fae7262ba9996303864675ebe2d1335 Mon Sep 17 00:00:00 2001 From: cmdr2 Date: Fri, 9 Dec 2022 16:11:08 +0530 Subject: [PATCH 11/74] Typo --- ui/server.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/ui/server.py b/ui/server.py index ca5cac79..458bc1aa 100644 --- a/ui/server.py +++ b/ui/server.py @@ -53,7 +53,7 @@ logging.getLogger('uvicorn.access').addFilter(LogSuppressFilter()) server_api.mount('/media', NoCacheStaticFiles(directory=os.path.join(app.SD_UI_DIR, 'media')), name="media") for plugins_dir, dir_prefix in app.UI_PLUGINS_SOURCES: - app.mount(f'/plugins/{dir_prefix}', NoCacheStaticFiles(directory=plugins_dir), name=f"plugins-{dir_prefix}") + server_api.mount(f'/plugins/{dir_prefix}', NoCacheStaticFiles(directory=plugins_dir), name=f"plugins-{dir_prefix}") @server_api.post('/app_config') async def setAppConfig(req : SetAppConfigRequest): From dbac2655f53961bb81bf57eab345e4070c9e934d Mon Sep 17 00:00:00 2001 From: cmdr2 Date: Fri, 9 Dec 2022 16:14:04 +0530 Subject: [PATCH 12/74] Typo --- ui/sd_internal/runtime2.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/ui/sd_internal/runtime2.py b/ui/sd_internal/runtime2.py index 57dd7150..080180a4 100644 --- a/ui/sd_internal/runtime2.py +++ b/ui/sd_internal/runtime2.py @@ -10,7 +10,7 @@ import traceback from sd_internal import device_manager, model_manager from sd_internal import Request, Response, Image as ResponseImage, UserInitiatedStop -from modules import model_loader, image_generator, image_utils, image_filters +from modules import model_loader, image_generator, image_utils, filters as image_filters thread_data = threading.local() ''' From f1de0be679cb84e7d7e804b65652a26d73b6c8f6 Mon Sep 17 00:00:00 2001 From: cmdr2 Date: Fri, 9 Dec 2022 17:50:33 +0530 
Subject: [PATCH 13/74] Fix integration issues after the refactor --- scripts/on_sd_start.bat | 2 +- scripts/on_sd_start.sh | 2 +- ui/sd_internal/device_manager.py | 2 +- ui/sd_internal/runtime2.py | 9 +++++++-- 4 files changed, 10 insertions(+), 5 deletions(-) diff --git a/scripts/on_sd_start.bat b/scripts/on_sd_start.bat index d8981226..9bac3757 100644 --- a/scripts/on_sd_start.bat +++ b/scripts/on_sd_start.bat @@ -393,7 +393,7 @@ call python --version @if NOT DEFINED SD_UI_BIND_PORT set SD_UI_BIND_PORT=9000 @if NOT DEFINED SD_UI_BIND_IP set SD_UI_BIND_IP=0.0.0.0 -@uvicorn server:app --app-dir "%SD_UI_PATH%" --port %SD_UI_BIND_PORT% --host %SD_UI_BIND_IP% +@uvicorn server:server_api --app-dir "%SD_UI_PATH%" --port %SD_UI_BIND_PORT% --host %SD_UI_BIND_IP% @pause diff --git a/scripts/on_sd_start.sh b/scripts/on_sd_start.sh index 3e1c9c58..afb95295 100755 --- a/scripts/on_sd_start.sh +++ b/scripts/on_sd_start.sh @@ -322,6 +322,6 @@ cd .. export SD_UI_PATH=`pwd`/ui cd stable-diffusion -uvicorn server:app --app-dir "$SD_UI_PATH" --port ${SD_UI_BIND_PORT:-9000} --host ${SD_UI_BIND_IP:-0.0.0.0} +uvicorn server:server_api --app-dir "$SD_UI_PATH" --port ${SD_UI_BIND_PORT:-9000} --host ${SD_UI_BIND_IP:-0.0.0.0} read -p "Press any key to continue" diff --git a/ui/sd_internal/device_manager.py b/ui/sd_internal/device_manager.py index 4de6b265..c490a0c6 100644 --- a/ui/sd_internal/device_manager.py +++ b/ui/sd_internal/device_manager.py @@ -101,7 +101,7 @@ def device_init(context, device): context.device = device # Force full precision on 1660 and 1650 NVIDIA cards to avoid creating green images - if needs_to_force_full_precision(context.device_name): + if needs_to_force_full_precision(context): print(f'forcing full precision on this GPU, to avoid green images. GPU detected: {context.device_name}') # Apply force_full_precision now before models are loaded. 
context.precision = 'full' diff --git a/ui/sd_internal/runtime2.py b/ui/sd_internal/runtime2.py index 080180a4..0fcf8eb7 100644 --- a/ui/sd_internal/runtime2.py +++ b/ui/sd_internal/runtime2.py @@ -83,7 +83,12 @@ def _make_images_internal(req: Request, data_queue: queue.Queue, task_temp_image save_images(req, images) - return Response(req, images=construct_response(req, images)) + res = Response(req, images=construct_response(req, images)) + res = res.json() + data_queue.put(json.dumps(res)) + print('Task completed') + + return res def generate_images(req: Request, data_queue: queue.Queue, task_temp_images: list, step_callback): thread_data.temp_images.clear() @@ -91,7 +96,7 @@ def generate_images(req: Request, data_queue: queue.Queue, task_temp_images: lis image_generator.on_image_step = make_step_callback(req, data_queue, task_temp_images, step_callback) try: - images = image_generator.make_image(context=thread_data, args=get_mk_img_args(req)) + images = image_generator.make_images(context=thread_data, args=get_mk_img_args(req)) user_stopped = False except UserInitiatedStop: images = [] From 79cc84b6117e4885e525821bc3eb078c3013f30c Mon Sep 17 00:00:00 2001 From: cmdr2 Date: Fri, 9 Dec 2022 19:39:56 +0530 Subject: [PATCH 14/74] Option to apply color correction (balances the histogram) during inpainting; Refactor the runtime to use a general-purpose dict --- ui/index.html | 1 + ui/media/js/auto-save.js | 3 +- ui/media/js/main.js | 8 ++++ ui/sd_internal/__init__.py | 4 ++ ui/sd_internal/runtime2.py | 83 ++++++++++++++++++++-------------- ui/sd_internal/task_manager.py | 2 + 6 files changed, 66 insertions(+), 35 deletions(-) diff --git a/ui/index.html b/ui/index.html index 0094201b..30648424 100644 --- a/ui/index.html +++ b/ui/index.html @@ -213,6 +213,7 @@
                        [index.html hunk garbled during extraction: it adds an "apply_color_correction" checkbox to the Render Settings list; the main.js changes below show and hide that control for img2img]
diff --git a/ui/media/js/auto-save.js b/ui/media/js/auto-save.js
index f503779a..9025f988 100644
--- a/ui/media/js/auto-save.js
+++ b/ui/media/js/auto-save.js
@@ -39,7 +39,8 @@ const SETTINGS_IDS_LIST = [
     "turbo",
     "use_full_precision",
     "confirm_dangerous_actions",
-    "auto_save_settings"
+    "auto_save_settings",
+    "apply_color_correction"
 ]
 
 const IGNORE_BY_DEFAULT = [
diff --git a/ui/media/js/main.js b/ui/media/js/main.js
index 67ff7a9a..6c08aef7 100644
--- a/ui/media/js/main.js
+++ b/ui/media/js/main.js
@@ -26,6 +26,8 @@ let initImagePreview = document.querySelector("#init_image_preview")
 let initImageSizeBox = document.querySelector("#init_image_size_box")
 let maskImageSelector = document.querySelector("#mask")
 let maskImagePreview = document.querySelector("#mask_preview")
+let applyColorCorrectionField = document.querySelector('#apply_color_correction')
+let colorCorrectionSetting = document.querySelector('#apply_color_correction_setting')
 let promptStrengthSlider = document.querySelector('#prompt_strength_slider')
 let promptStrengthField = document.querySelector('#prompt_strength')
 let samplerField = document.querySelector('#sampler')
@@ -759,6 +761,9 @@ function createTask(task) {
         taskConfig += `, Hypernetwork: ${task.reqBody.use_hypernetwork_model}`
         taskConfig += `, Hypernetwork Strength: ${task.reqBody.hypernetwork_strength}`
     }
+    if (task.reqBody.apply_color_correction) {
+        taskConfig += `, Color Correction: true`
+    }
 
     let taskEntry = document.createElement('div')
     taskEntry.id = `imageTaskContainer-${Date.now()}`
@@ -867,6 +872,7 @@ function getCurrentUserRequest() {
             if (maskSetting.checked) {
                 newTask.reqBody.mask = imageInpainter.getImg()
             }
+            newTask.reqBody.apply_color_correction = applyColorCorrectionField.checked
             newTask.reqBody.sampler = 'ddim'
         } else {
             newTask.reqBody.sampler = samplerField.value
@@ -1257,6 +1263,7 @@ function img2imgLoad() {
     promptStrengthContainer.style.display = 'table-row'
     samplerSelectionContainer.style.display = "none"
     initImagePreviewContainer.classList.add("has-image")
+    colorCorrectionSetting.style.display = ''
 
     initImageSizeBox.textContent = initImagePreview.naturalWidth + " x " + initImagePreview.naturalHeight
     imageEditor.setImage(this.src, initImagePreview.naturalWidth, initImagePreview.naturalHeight)
@@ -1271,6 +1278,7 @@ function img2imgUnload() {
     promptStrengthContainer.style.display = "none"
     samplerSelectionContainer.style.display = ""
     initImagePreviewContainer.classList.remove("has-image")
+    colorCorrectionSetting.style.display = 'none'
 
     imageEditor.setImage(null, parseInt(widthField.value), parseInt(heightField.value))
 }
diff --git a/ui/sd_internal/__init__.py b/ui/sd_internal/__init__.py
index b001d3f9..9dd4f066 100644
--- a/ui/sd_internal/__init__.py
+++ b/ui/sd_internal/__init__.py
@@ -7,6 +7,7 @@ class Request:
     negative_prompt: str = ""
     init_image: str = None # base64
     mask: str = None # base64
+    apply_color_correction = False
     num_outputs: int = 1
     num_inference_steps: int = 50
     guidance_scale: float = 7.5
@@ -35,6 +36,7 @@ def json(self):
         return {
+            "request_id": self.request_id,
             "session_id": self.session_id,
             "prompt": self.prompt,
             "negative_prompt": self.negative_prompt,
@@ -46,6 +48,7 @@
             "seed": self.seed,
             "prompt_strength": self.prompt_strength,
             "sampler": self.sampler,
+            "apply_color_correction": self.apply_color_correction,
             "use_face_correction": self.use_face_correction,
             "use_upscale": self.use_upscale,
             "use_stable_diffusion_model": self.use_stable_diffusion_model,
@@ -71,6 +74,7 @@
save_to_disk_path: {self.save_to_disk_path} turbo: {self.turbo} use_full_precision: {self.use_full_precision} + apply_color_correction: {self.apply_color_correction} use_face_correction: {self.use_face_correction} use_upscale: {self.use_upscale} use_stable_diffusion_model: {self.use_stable_diffusion_model} diff --git a/ui/sd_internal/runtime2.py b/ui/sd_internal/runtime2.py index 0fcf8eb7..f74acc12 100644 --- a/ui/sd_internal/runtime2.py +++ b/ui/sd_internal/runtime2.py @@ -78,10 +78,15 @@ def make_images(req: Request, data_queue: queue.Queue, task_temp_images: list, s raise e def _make_images_internal(req: Request, data_queue: queue.Queue, task_temp_images: list, step_callback): - images, user_stopped = generate_images(req, data_queue, task_temp_images, step_callback) - images = apply_filters(req, images, user_stopped) + args = req_to_args(req) - save_images(req, images) + images, user_stopped = generate_images(args, data_queue, task_temp_images, step_callback, req.stream_image_progress) + images = apply_color_correction(args, images, user_stopped) + images = apply_filters(args, images, user_stopped, req.show_only_filtered_image) + + if req.save_to_disk_path is not None: + out_path = os.path.join(req.save_to_disk_path, filename_regex.sub('_', req.session_id)) + save_images(images, out_path, metadata=req.json(), show_only_filtered_image=req.show_only_filtered_image) res = Response(req, images=construct_response(req, images)) res = res.json() @@ -90,37 +95,48 @@ def _make_images_internal(req: Request, data_queue: queue.Queue, task_temp_image return res -def generate_images(req: Request, data_queue: queue.Queue, task_temp_images: list, step_callback): +def generate_images(args: dict, data_queue: queue.Queue, task_temp_images: list, step_callback, stream_image_progress: bool): thread_data.temp_images.clear() - image_generator.on_image_step = make_step_callback(req, data_queue, task_temp_images, step_callback) + image_generator.on_image_step = make_step_callback(args, data_queue, task_temp_images, step_callback, stream_image_progress) try: - images = image_generator.make_images(context=thread_data, args=get_mk_img_args(req)) + images = image_generator.make_images(context=thread_data, args=args) user_stopped = False except UserInitiatedStop: images = [] user_stopped = True if not hasattr(thread_data, 'partial_x_samples') or thread_data.partial_x_samples is None: return images - for i in range(req.num_outputs): + for i in range(args['num_outputs']): images[i] = image_utils.latent_to_img(thread_data, thread_data.partial_x_samples[i].unsqueeze(0)) del thread_data.partial_x_samples finally: model_loader.gc(thread_data) - images = [(image, req.seed + i, False) for i, image in enumerate(images)] + images = [(image, args['seed'] + i, False) for i, image in enumerate(images)] return images, user_stopped -def apply_filters(req: Request, images: list, user_stopped): - if user_stopped or (req.use_face_correction is None and req.use_upscale is None): +def apply_color_correction(args: dict, images: list, user_stopped): + if user_stopped or args['init_image'] is None or not args['apply_color_correction']: + return images + + for i, img_info in enumerate(images): + img, seed, filtered = img_info + img = image_utils.apply_color_correction(orig_image=args['init_image'], image_to_correct=img) + images[i] = (img, seed, filtered) + + return images + +def apply_filters(args: dict, images: list, user_stopped, show_only_filtered_image): + if user_stopped or (args['use_face_correction'] is None and args['use_upscale'] 
is None):
         return images
 
     filters = []
-    if req.use_face_correction.startswith('GFPGAN'): filters.append((image_filters.apply_gfpgan, model_manager.resolve_model_to_use(req.use_face_correction, model_type='gfpgan')))
-    if req.use_upscale.startswith('RealESRGAN'): filters.append((image_filters.apply_realesrgan, model_manager.resolve_model_to_use(req.use_upscale, model_type='realesrgan')))
+    if args['use_face_correction'].startswith('GFPGAN'): filters.append((image_filters.apply_gfpgan, model_manager.resolve_model_to_use(args['use_face_correction'], model_type='gfpgan')))
+    if args['use_upscale'].startswith('RealESRGAN'): filters.append((image_filters.apply_realesrgan, model_manager.resolve_model_to_use(args['use_upscale'], model_type='realesrgan')))
 
     filtered_images = []
     for img, seed, _ in images:
@@ -129,13 +145,13 @@ def apply_filters(req: Request, images: list, user_stopped):
         filtered_images.append((img, seed, True))
 
-    if not req.show_only_filtered_image:
+    if not show_only_filtered_image:
         filtered_images = images + filtered_images
 
     return filtered_images
 
-def save_images(req: Request, images: list):
-    if req.save_to_disk_path is None:
+def save_images(images: list, save_to_disk_path, metadata: dict, show_only_filtered_image):
+    if save_to_disk_path is None:
         return
 
     def get_image_id(i):
@@ -144,25 +160,24 @@ def save_images(req: Request, images: list):
         return img_id
 
     def get_image_basepath(i):
-        session_out_path = os.path.join(req.save_to_disk_path, filename_regex.sub('_', req.session_id))
-        os.makedirs(session_out_path, exist_ok=True)
-        prompt_flattened = filename_regex.sub('_', req.prompt)[:50]
-        return os.path.join(session_out_path, f"{prompt_flattened}_{get_image_id(i)}")
+        os.makedirs(save_to_disk_path, exist_ok=True)
+        prompt_flattened = filename_regex.sub('_', metadata['prompt'])[:50]
+        return os.path.join(save_to_disk_path, f"{prompt_flattened}_{get_image_id(i)}")
 
     for i, img_data in enumerate(images):
         img, seed, filtered = img_data
         img_path = get_image_basepath(i)
 
-        if not filtered or req.show_only_filtered_image:
+        if not filtered or show_only_filtered_image:
             img_metadata_path = img_path + '.txt'
-            metadata = req.json()
-            metadata['seed'] = seed
+            m = metadata.copy()
+            m['seed'] = seed
             with open(img_metadata_path, 'w', encoding='utf-8') as f:
-                f.write(metadata)
+                f.write(json.dumps(m))
 
         img_path += '_filtered' if filtered else ''
+        img_path += '.'
+ metadata['output_format'] + img.save(img_path, quality=metadata['output_quality']) def construct_response(req: Request, images: list): return [ @@ -172,7 +187,7 @@ def construct_response(req: Request, images: list): ) for img, seed, _ in images ] -def get_mk_img_args(req: Request): +def req_to_args(req: Request): args = req.json() args['init_image'] = image_utils.base64_str_to_img(req.init_image) if req.init_image is not None else None @@ -180,21 +195,21 @@ def get_mk_img_args(req: Request): return args -def make_step_callback(req: Request, data_queue: queue.Queue, task_temp_images: list, step_callback): - n_steps = req.num_inference_steps if req.init_image is None else int(req.num_inference_steps * req.prompt_strength) +def make_step_callback(args: dict, data_queue: queue.Queue, task_temp_images: list, step_callback, stream_image_progress: bool): + n_steps = args['num_inference_steps'] if args['init_image'] is None else int(args['num_inference_steps'] * args['prompt_strength']) last_callback_time = -1 - def update_temp_img(req, x_samples, task_temp_images: list): + def update_temp_img(x_samples, task_temp_images: list): partial_images = [] - for i in range(req.num_outputs): + for i in range(args['num_outputs']): img = image_utils.latent_to_img(thread_data, x_samples[i].unsqueeze(0)) buf = image_utils.img_to_buffer(img, output_format='JPEG') del img - thread_data.temp_images[f'{req.request_id}/{i}'] = buf + thread_data.temp_images[f"{args['request_id']}/{i}"] = buf task_temp_images[i] = buf - partial_images.append({'path': f'/image/tmp/{req.request_id}/{i}'}) + partial_images.append({'path': f"/image/tmp/{args['request_id']}/{i}"}) return partial_images def on_image_step(x_samples, i): @@ -206,8 +221,8 @@ def make_step_callback(req: Request, data_queue: queue.Queue, task_temp_images: progress = {"step": i, "step_time": step_time, "total_steps": n_steps} - if req.stream_image_progress and i % 5 == 0: - progress['output'] = update_temp_img(req, x_samples, task_temp_images) + if stream_image_progress and i % 5 == 0: + progress['output'] = update_temp_img(x_samples, task_temp_images) data_queue.put(json.dumps(progress)) diff --git a/ui/sd_internal/task_manager.py b/ui/sd_internal/task_manager.py index aec79239..6b0a1d8c 100644 --- a/ui/sd_internal/task_manager.py +++ b/ui/sd_internal/task_manager.py @@ -76,6 +76,7 @@ class ImageRequest(BaseModel): negative_prompt: str = "" init_image: str = None # base64 mask: str = None # base64 + apply_color_correction: bool = False num_outputs: int = 1 num_inference_steps: int = 50 guidance_scale: float = 7.5 @@ -522,6 +523,7 @@ def render(req : ImageRequest): r.negative_prompt = req.negative_prompt r.init_image = req.init_image r.mask = req.mask + r.apply_color_correction = req.apply_color_correction r.num_outputs = req.num_outputs r.num_inference_steps = req.num_inference_steps r.guidance_scale = req.guidance_scale From cde8c2d3bd3f307d2c19a3fcdda689b9ab2e26b9 Mon Sep 17 00:00:00 2001 From: cmdr2 Date: Fri, 9 Dec 2022 21:30:18 +0530 Subject: [PATCH 15/74] Use a logger --- ui/sd_internal/app.py | 23 +++++++++---- ui/sd_internal/device_manager.py | 23 +++++++------ ui/sd_internal/model_manager.py | 23 +++++++++---- ui/sd_internal/runtime2.py | 7 ++-- ui/sd_internal/task_manager.py | 56 +++++++++++++++++--------------- ui/server.py | 13 ++++---- 6 files changed, 87 insertions(+), 58 deletions(-) diff --git a/ui/sd_internal/app.py b/ui/sd_internal/app.py index d1cec46f..c34f2ea6 100644 --- a/ui/sd_internal/app.py +++ b/ui/sd_internal/app.py @@ -3,9 +3,21 @@ 
import socket import sys import json import traceback +import logging +from rich.logging import RichHandler from sd_internal import task_manager +LOG_FORMAT = '[%(threadName)s] %(message)s' +logging.basicConfig( + level=logging.INFO, + format=LOG_FORMAT, + datefmt="[%X.%f]", + handlers=[RichHandler(markup=True)] +) + +log = logging.getLogger() + SD_DIR = os.getcwd() SD_UI_DIR = os.getenv('SD_UI_PATH', None) @@ -49,8 +61,7 @@ def getConfig(default_val=APP_CONFIG_DEFAULTS): config['net']['listen_to_network'] = (os.getenv('SD_UI_BIND_IP') == '0.0.0.0') return config except Exception as e: - print(str(e)) - print(traceback.format_exc()) + log.warn(traceback.format_exc()) return default_val def setConfig(config): @@ -59,7 +70,7 @@ def setConfig(config): with open(config_json_path, 'w', encoding='utf-8') as f: json.dump(config, f) except: - print(traceback.format_exc()) + log.error(traceback.format_exc()) try: # config.bat config_bat_path = os.path.join(CONFIG_DIR, 'config.bat') @@ -78,7 +89,7 @@ def setConfig(config): with open(config_bat_path, 'w', encoding='utf-8') as f: f.write('\r\n'.join(config_bat)) except: - print(traceback.format_exc()) + log.error(traceback.format_exc()) try: # config.sh config_sh_path = os.path.join(CONFIG_DIR, 'config.sh') @@ -97,7 +108,7 @@ def setConfig(config): with open(config_sh_path, 'w', encoding='utf-8') as f: f.write('\n'.join(config_sh)) except: - print(traceback.format_exc()) + log.error(traceback.format_exc()) def save_model_to_config(ckpt_model_name, vae_model_name, hypernetwork_model_name): config = getConfig() @@ -120,7 +131,7 @@ def update_render_threads(): render_devices = config.get('render_devices', 'auto') active_devices = task_manager.get_devices()['active'].keys() - print('requesting for render_devices', render_devices) + log.debug(f'requesting for render_devices: {render_devices}') task_manager.update_render_threads(render_devices, active_devices) def getUIPlugins(): diff --git a/ui/sd_internal/device_manager.py b/ui/sd_internal/device_manager.py index c490a0c6..a3f91cfb 100644 --- a/ui/sd_internal/device_manager.py +++ b/ui/sd_internal/device_manager.py @@ -2,6 +2,9 @@ import os import torch import traceback import re +import logging + +log = logging.getLogger() COMPARABLE_GPU_PERCENTILE = 0.65 # if a GPU's free_mem is within this % of the GPU with the most free_mem, it will be picked @@ -34,7 +37,7 @@ def get_device_delta(render_devices, active_devices): if 'auto' in render_devices: render_devices = auto_pick_devices(active_devices) if 'cpu' in render_devices: - print('WARNING: Could not find a compatible GPU. Using the CPU, but this will be very slow!') + log.warn('WARNING: Could not find a compatible GPU. Using the CPU, but this will be very slow!') active_devices = set(active_devices) render_devices = set(render_devices) @@ -53,7 +56,7 @@ def auto_pick_devices(currently_active_devices): if device_count == 1: return ['cuda:0'] if is_device_compatible('cuda:0') else ['cpu'] - print('Autoselecting GPU. Using most free memory.') + log.debug('Autoselecting GPU. 
Using most free memory.') devices = [] for device in range(device_count): device = f'cuda:{device}' @@ -64,7 +67,7 @@ def auto_pick_devices(currently_active_devices): mem_free /= float(10**9) mem_total /= float(10**9) device_name = torch.cuda.get_device_name(device) - print(f'{device} detected: {device_name} - Memory (free/total): {round(mem_free, 2)}Gb / {round(mem_total, 2)}Gb') + log.debug(f'{device} detected: {device_name} - Memory (free/total): {round(mem_free, 2)}Gb / {round(mem_total, 2)}Gb') devices.append({'device': device, 'device_name': device_name, 'mem_free': mem_free}) devices.sort(key=lambda x:x['mem_free'], reverse=True) @@ -94,7 +97,7 @@ def device_init(context, device): context.device = 'cpu' context.device_name = get_processor_name() context.precision = 'full' - print('Render device CPU available as', context.device_name) + log.debug(f'Render device CPU available as {context.device_name}') return context.device_name = torch.cuda.get_device_name(device) @@ -102,11 +105,11 @@ def device_init(context, device): # Force full precision on 1660 and 1650 NVIDIA cards to avoid creating green images if needs_to_force_full_precision(context): - print(f'forcing full precision on this GPU, to avoid green images. GPU detected: {context.device_name}') + log.warn(f'forcing full precision on this GPU, to avoid green images. GPU detected: {context.device_name}') # Apply force_full_precision now before models are loaded. context.precision = 'full' - print(f'Setting {device} as active') + log.info(f'Setting {device} as active') torch.cuda.device(device) return @@ -135,7 +138,7 @@ def is_device_compatible(device): try: validate_device_id(device, log_prefix='is_device_compatible') except: - print(str(e)) + log.error(str(e)) return False if device == 'cpu': return True @@ -144,10 +147,10 @@ def is_device_compatible(device): _, mem_total = torch.cuda.mem_get_info(device) mem_total /= float(10**9) if mem_total < 3.0: - print(f'GPU {device} with less than 3 GB of VRAM is not compatible with Stable Diffusion') + log.warn(f'GPU {device} with less than 3 GB of VRAM is not compatible with Stable Diffusion') return False except RuntimeError as e: - print(str(e)) + log.error(str(e)) return False return True @@ -167,5 +170,5 @@ def get_processor_name(): if "model name" in line: return re.sub(".*model name.*:", "", line, 1).strip() except: - print(traceback.format_exc()) + log.error(traceback.format_exc()) return "cpu" diff --git a/ui/sd_internal/model_manager.py b/ui/sd_internal/model_manager.py index 6cf9428a..827434d7 100644 --- a/ui/sd_internal/model_manager.py +++ b/ui/sd_internal/model_manager.py @@ -1,9 +1,12 @@ import os +import logging +import picklescan.scanner +import rich from sd_internal import app, device_manager from sd_internal import Request -import picklescan.scanner -import rich + +log = logging.getLogger() KNOWN_MODEL_TYPES = ['stable-diffusion', 'vae', 'hypernetwork', 'gfpgan', 'realesrgan'] MODEL_EXTENSIONS = { @@ -42,7 +45,7 @@ def resolve_model_to_use(model_name:str=None, model_type:str=None): if model_name: is_sd2 = config.get('test_sd2', False) if model_name.startswith('sd2_') and not is_sd2: # temp hack, until SD2 is unified with 1.4 - print('ERROR: Cannot use SD 2.0 models with SD 1.0 code. Using the sd-v1-4 model instead!') + log.error('ERROR: Cannot use SD 2.0 models with SD 1.0 code. 
Using the sd-v1-4 model instead!') model_name = 'sd-v1-4' # Check models directory @@ -67,7 +70,7 @@ def resolve_model_to_use(model_name:str=None, model_type:str=None): for model_extension in model_extensions: if os.path.exists(default_model_path + model_extension): if model_name is not None: - print(f'Could not find the configured custom model {model_name}{model_extension}. Using the default one: {default_model_path}{model_extension}') + log.warn(f'Could not find the configured custom model {model_name}{model_extension}. Using the default one: {default_model_path}{model_extension}') return default_model_path + model_extension return None @@ -88,13 +91,13 @@ def is_malicious_model(file_path): try: scan_result = picklescan.scanner.scan_file_path(file_path) if scan_result.issues_count > 0 or scan_result.infected_files > 0: - rich.print(":warning: [bold red]Scan %s: %d scanned, %d issue, %d infected.[/bold red]" % (file_path, scan_result.scanned_files, scan_result.issues_count, scan_result.infected_files)) + log.warn(":warning: [bold red]Scan %s: %d scanned, %d issue, %d infected.[/bold red]" % (file_path, scan_result.scanned_files, scan_result.issues_count, scan_result.infected_files)) return True else: - rich.print("Scan %s: [green]%d scanned, %d issue, %d infected.[/green]" % (file_path, scan_result.scanned_files, scan_result.issues_count, scan_result.infected_files)) + log.debug("Scan %s: [green]%d scanned, %d issue, %d infected.[/green]" % (file_path, scan_result.scanned_files, scan_result.issues_count, scan_result.infected_files)) return False except Exception as e: - print('error while scanning', file_path, 'error:', e) + log.error(f'error while scanning: {file_path}, error: {e}') return False def getModels(): @@ -111,7 +114,10 @@ def getModels(): }, } + models_scanned = 0 def listModels(model_type): + nonlocal models_scanned + model_extensions = MODEL_EXTENSIONS.get(model_type, []) models_dir = os.path.join(app.MODELS_DIR, model_type) if not os.path.exists(models_dir): @@ -126,6 +132,7 @@ def getModels(): mtime = os.path.getmtime(model_path) mod_time = known_models[model_path] if model_path in known_models else -1 if mod_time != mtime: + models_scanned += 1 if is_malicious_model(model_path): models['scan-error'] = file return @@ -142,6 +149,8 @@ def getModels(): listModels(model_type='vae') listModels(model_type='hypernetwork') + if models_scanned > 0: log.info(f'[green]Scanned {models_scanned} models. 0 infected[/]') + # legacy custom_weight_path = os.path.join(app.SD_DIR, 'custom-model.ckpt') if os.path.exists(custom_weight_path): diff --git a/ui/sd_internal/runtime2.py b/ui/sd_internal/runtime2.py index f74acc12..c7845079 100644 --- a/ui/sd_internal/runtime2.py +++ b/ui/sd_internal/runtime2.py @@ -6,12 +6,15 @@ import os import base64 import re import traceback +import logging from sd_internal import device_manager, model_manager from sd_internal import Request, Response, Image as ResponseImage, UserInitiatedStop from modules import model_loader, image_generator, image_utils, filters as image_filters +log = logging.getLogger() + thread_data = threading.local() ''' runtime data (bound locally to this thread), for e.g. 
device, references to loaded models, optimization flags etc @@ -69,7 +72,7 @@ def make_images(req: Request, data_queue: queue.Queue, task_temp_images: list, s try: return _make_images_internal(req, data_queue, task_temp_images, step_callback) except Exception as e: - print(traceback.format_exc()) + log.error(traceback.format_exc()) data_queue.put(json.dumps({ "status": 'failed', @@ -91,7 +94,7 @@ def _make_images_internal(req: Request, data_queue: queue.Queue, task_temp_image res = Response(req, images=construct_response(req, images)) res = res.json() data_queue.put(json.dumps(res)) - print('Task completed') + log.info('Task completed') return res diff --git a/ui/sd_internal/task_manager.py b/ui/sd_internal/task_manager.py index 6b0a1d8c..852631c6 100644 --- a/ui/sd_internal/task_manager.py +++ b/ui/sd_internal/task_manager.py @@ -6,6 +6,7 @@ Notes: """ import json import traceback +import logging TASK_TTL = 15 * 60 # seconds, Discard last session's task timeout @@ -16,6 +17,8 @@ from typing import Any, Hashable from pydantic import BaseModel from sd_internal import Request, device_manager +log = logging.getLogger() + THREAD_NAME_PREFIX = 'Runtime-Render/' ERR_LOCK_FAILED = ' failed to acquire lock within timeout.' LOCK_TIMEOUT = 15 # Maximum locking time in seconds before failing a task. @@ -140,11 +143,11 @@ class DataCache(): for key in to_delete: (_, val) = self._base[key] if isinstance(val, RenderTask): - print(f'RenderTask {key} expired. Data removed.') + log.debug(f'RenderTask {key} expired. Data removed.') elif isinstance(val, SessionState): - print(f'Session {key} expired. Data removed.') + log.debug(f'Session {key} expired. Data removed.') else: - print(f'Key {key} expired. Data removed.') + log.debug(f'Key {key} expired. Data removed.') del self._base[key] finally: self._lock.release() @@ -178,8 +181,7 @@ class DataCache(): self._get_ttl_time(ttl), value ) except Exception as e: - print(str(e)) - print(traceback.format_exc()) + log.error(traceback.format_exc()) return False else: return True @@ -190,7 +192,7 @@ class DataCache(): try: ttl, value = self._base.get(key, (None, None)) if ttl is not None and self._is_expired(ttl): - print(f'Session {key} expired. Discarding data.') + log.debug(f'Session {key} expired. 
Discarding data.') del self._base[key] return None return value @@ -234,7 +236,7 @@ class SessionState(): def thread_get_next_task(): from sd_internal import runtime2 if not manager_lock.acquire(blocking=True, timeout=LOCK_TIMEOUT): - print('Render thread on device', runtime2.thread_data.device, 'failed to acquire manager lock.') + log.warn(f'Render thread on device: {runtime2.thread_data.device} failed to acquire manager lock.') return None if len(tasks_queue) <= 0: manager_lock.release() @@ -269,7 +271,7 @@ def thread_render(device): try: runtime2.init(device) except Exception as e: - print(traceback.format_exc()) + log.error(traceback.format_exc()) weak_thread_data[threading.current_thread()] = { 'error': e } @@ -287,7 +289,7 @@ def thread_render(device): session_cache.clean() task_cache.clean() if not weak_thread_data[threading.current_thread()]['alive']: - print(f'Shutting down thread for device {runtime2.thread_data.device}') + log.info(f'Shutting down thread for device {runtime2.thread_data.device}') runtime2.destroy() return if isinstance(current_state_error, SystemExit): @@ -299,7 +301,7 @@ def thread_render(device): idle_event.wait(timeout=1) continue if task.error is not None: - print(task.error) + log.error(task.error) task.response = {"status": 'failed', "detail": str(task.error)} task.buffer_queue.put(json.dumps(task.response)) continue @@ -308,7 +310,7 @@ def thread_render(device): task.response = {"status": 'failed', "detail": str(task.error)} task.buffer_queue.put(json.dumps(task.response)) continue - print(f'Session {task.request.session_id} starting task {id(task)} on {runtime2.thread_data.device_name}') + log.info(f'Session {task.request.session_id} starting task {id(task)} on {runtime2.thread_data.device_name}') if not task.lock.acquire(blocking=False): raise Exception('Got locked task from queue.') try: def step_callback(): @@ -319,7 +321,7 @@ def thread_render(device): if isinstance(current_state_error, StopAsyncIteration): task.error = current_state_error current_state_error = None - print(f'Session {task.request.session_id} sent cancel signal for task {id(task)}') + log.info(f'Session {task.request.session_id} sent cancel signal for task {id(task)}') current_state = ServerStates.LoadingModel runtime2.reload_models_if_necessary(task.request) @@ -331,7 +333,7 @@ def thread_render(device): session_cache.keep(task.request.session_id, TASK_TTL) except Exception as e: task.error = e - print(traceback.format_exc()) + log.error(traceback.format_exc()) continue finally: # Task completed @@ -339,11 +341,11 @@ def thread_render(device): task_cache.keep(id(task), TASK_TTL) session_cache.keep(task.request.session_id, TASK_TTL) if isinstance(task.error, StopAsyncIteration): - print(f'Session {task.request.session_id} task {id(task)} cancelled!') + log.info(f'Session {task.request.session_id} task {id(task)} cancelled!') elif task.error is not None: - print(f'Session {task.request.session_id} task {id(task)} failed!') + log.info(f'Session {task.request.session_id} task {id(task)} failed!') else: - print(f'Session {task.request.session_id} task {id(task)} completed by {runtime2.thread_data.device_name}.') + log.info(f'Session {task.request.session_id} task {id(task)} completed by {runtime2.thread_data.device_name}.') current_state = ServerStates.Online def get_cached_task(task_id:str, update_ttl:bool=False): @@ -429,7 +431,7 @@ def is_alive(device=None): def start_render_thread(device): if not manager_lock.acquire(blocking=True, timeout=LOCK_TIMEOUT): raise 
Exception('start_render_thread' + ERR_LOCK_FAILED) - print('Start new Rendering Thread on device', device) + log.info(f'Start new Rendering Thread on device: {device}') try: rthread = threading.Thread(target=thread_render, kwargs={'device': device}) rthread.daemon = True @@ -441,7 +443,7 @@ def start_render_thread(device): timeout = DEVICE_START_TIMEOUT while not rthread.is_alive() or not rthread in weak_thread_data or not 'device' in weak_thread_data[rthread]: if rthread in weak_thread_data and 'error' in weak_thread_data[rthread]: - print(rthread, device, 'error:', weak_thread_data[rthread]['error']) + log.error(f"{rthread}, {device}, error: {weak_thread_data[rthread]['error']}") return False if timeout <= 0: return False @@ -453,11 +455,11 @@ def stop_render_thread(device): try: device_manager.validate_device_id(device, log_prefix='stop_render_thread') except: - print(traceback.format_exc()) + log.error(traceback.format_exc()) return False if not manager_lock.acquire(blocking=True, timeout=LOCK_TIMEOUT): raise Exception('stop_render_thread' + ERR_LOCK_FAILED) - print('Stopping Rendering Thread on device', device) + log.info(f'Stopping Rendering Thread on device: {device}') try: thread_to_remove = None @@ -480,27 +482,27 @@ def stop_render_thread(device): def update_render_threads(render_devices, active_devices): devices_to_start, devices_to_stop = device_manager.get_device_delta(render_devices, active_devices) - print('devices_to_start', devices_to_start) - print('devices_to_stop', devices_to_stop) + log.debug(f'devices_to_start: {devices_to_start}') + log.debug(f'devices_to_stop: {devices_to_stop}') for device in devices_to_stop: if is_alive(device) <= 0: - print(device, 'is not alive') + log.debug(f'{device} is not alive') continue if not stop_render_thread(device): - print(device, 'could not stop render thread') + log.warn(f'{device} could not stop render thread') for device in devices_to_start: if is_alive(device) >= 1: - print(device, 'already registered.') + log.debug(f'{device} already registered.') continue if not start_render_thread(device): - print(device, 'failed to start.') + log.warn(f'{device} failed to start.') if is_alive() <= 0: # No running devices, probably invalid user config. raise EnvironmentError('ERROR: No active render devices! Please verify the "render_devices" value in config.json') - print('active devices', get_devices()['active']) + log.debug(f"active devices: {get_devices()['active']}") def shutdown_event(): # Signal render thread to close on shutdown global current_state_error diff --git a/ui/server.py b/ui/server.py index 458bc1aa..aaa16ce0 100644 --- a/ui/server.py +++ b/ui/server.py @@ -14,7 +14,9 @@ from pydantic import BaseModel from sd_internal import app, model_manager, task_manager -print('started in ', app.SD_DIR) +log = logging.getLogger() + +log.info(f'started in {app.SD_DIR}') server_api = FastAPI() @@ -84,7 +86,7 @@ async def setAppConfig(req : SetAppConfigRequest): return JSONResponse({'status': 'OK'}, headers=NOCACHE_HEADERS) except Exception as e: - print(traceback.format_exc()) + log.error(traceback.format_exc()) raise HTTPException(status_code=500, detail=str(e)) def update_render_devices_in_config(config, render_devices): @@ -153,8 +155,7 @@ def render(req : task_manager.ImageRequest): except ConnectionRefusedError as e: # Unstarted task pending limit reached, deny queueing too many. 
raise HTTPException(status_code=503, detail=str(e)) # HTTP503 Service Unavailable except Exception as e: - print(e) - print(traceback.format_exc()) + log.error(traceback.format_exc()) raise HTTPException(status_code=500, detail=str(e)) @server_api.get('/image/stream/{task_id:int}') @@ -165,10 +166,10 @@ def stream(task_id:int): #if (id(task) != task_id): raise HTTPException(status_code=409, detail=f'Wrong task id received. Expected:{id(task)}, Received:{task_id}') # HTTP409 Conflict if task.buffer_queue.empty() and not task.lock.locked(): if task.response: - #print(f'Session {session_id} sending cached response') + #log.info(f'Session {session_id} sending cached response') return JSONResponse(task.response, headers=NOCACHE_HEADERS) raise HTTPException(status_code=425, detail='Too Early, task not started yet.') # HTTP425 Too Early - #print(f'Session {session_id} opened live render stream {id(task.buffer_queue)}') + #log.info(f'Session {session_id} opened live render stream {id(task.buffer_queue)}') return StreamingResponse(task.read_buffer_generator(), media_type='application/json') @server_api.get('/image/stop') From a2af811ad29ccd573d5ddbd5497b51bab8e5783a Mon Sep 17 00:00:00 2001 From: cmdr2 Date: Fri, 9 Dec 2022 22:47:34 +0530 Subject: [PATCH 16/74] Disable uvicorn access logging in favor of cleaner server-side logging, we already get all that info; Print the request metadata --- scripts/on_sd_start.bat | 2 +- scripts/on_sd_start.sh | 2 +- ui/sd_internal/app.py | 6 +++--- ui/sd_internal/runtime2.py | 1 + ui/sd_internal/task_manager.py | 2 +- 5 files changed, 7 insertions(+), 6 deletions(-) diff --git a/scripts/on_sd_start.bat b/scripts/on_sd_start.bat index 9bac3757..5f1bfcb1 100644 --- a/scripts/on_sd_start.bat +++ b/scripts/on_sd_start.bat @@ -393,7 +393,7 @@ call python --version @if NOT DEFINED SD_UI_BIND_PORT set SD_UI_BIND_PORT=9000 @if NOT DEFINED SD_UI_BIND_IP set SD_UI_BIND_IP=0.0.0.0 -@uvicorn server:server_api --app-dir "%SD_UI_PATH%" --port %SD_UI_BIND_PORT% --host %SD_UI_BIND_IP% +@uvicorn server:server_api --app-dir "%SD_UI_PATH%" --port %SD_UI_BIND_PORT% --host %SD_UI_BIND_IP% --log-level critical @pause diff --git a/scripts/on_sd_start.sh b/scripts/on_sd_start.sh index afb95295..f4f540e7 100755 --- a/scripts/on_sd_start.sh +++ b/scripts/on_sd_start.sh @@ -322,6 +322,6 @@ cd .. 
export SD_UI_PATH=`pwd`/ui cd stable-diffusion -uvicorn server:server_api --app-dir "$SD_UI_PATH" --port ${SD_UI_BIND_PORT:-9000} --host ${SD_UI_BIND_IP:-0.0.0.0} +uvicorn server:server_api --app-dir "$SD_UI_PATH" --port ${SD_UI_BIND_PORT:-9000} --host ${SD_UI_BIND_IP:-0.0.0.0} --log-level critical read -p "Press any key to continue" diff --git a/ui/sd_internal/app.py b/ui/sd_internal/app.py index c34f2ea6..bfb90399 100644 --- a/ui/sd_internal/app.py +++ b/ui/sd_internal/app.py @@ -8,12 +8,12 @@ from rich.logging import RichHandler from sd_internal import task_manager -LOG_FORMAT = '[%(threadName)s] %(message)s' +LOG_FORMAT = '%(levelname)s %(threadName)s %(message)s' logging.basicConfig( level=logging.INFO, format=LOG_FORMAT, - datefmt="[%X.%f]", - handlers=[RichHandler(markup=True)] + datefmt="%X.%f", + handlers=[RichHandler(markup=True, rich_tracebacks=True, show_level=False)] ) log = logging.getLogger() diff --git a/ui/sd_internal/runtime2.py b/ui/sd_internal/runtime2.py index c7845079..cdffa574 100644 --- a/ui/sd_internal/runtime2.py +++ b/ui/sd_internal/runtime2.py @@ -70,6 +70,7 @@ def reload_models_if_necessary(req: Request): def make_images(req: Request, data_queue: queue.Queue, task_temp_images: list, step_callback): try: + log.info(req) return _make_images_internal(req, data_queue, task_temp_images, step_callback) except Exception as e: log.error(traceback.format_exc()) diff --git a/ui/sd_internal/task_manager.py b/ui/sd_internal/task_manager.py index 852631c6..4eb45ce6 100644 --- a/ui/sd_internal/task_manager.py +++ b/ui/sd_internal/task_manager.py @@ -19,7 +19,7 @@ from sd_internal import Request, device_manager log = logging.getLogger() -THREAD_NAME_PREFIX = 'Runtime-Render/' +THREAD_NAME_PREFIX = '' ERR_LOCK_FAILED = ' failed to acquire lock within timeout.' LOCK_TIMEOUT = 15 # Maximum locking time in seconds before failing a task. # It's better to get an exception than a deadlock... ALWAYS use timeout in critical paths. From 543f13f9a345df332a4cc29911b8c9266ca55e94 Mon Sep 17 00:00:00 2001 From: cmdr2 Date: Sun, 11 Dec 2022 13:19:22 +0530 Subject: [PATCH 17/74] Tweak logging to increase the space available by 3 characters --- ui/sd_internal/app.py | 6 +++--- ui/sd_internal/model_manager.py | 2 +- ui/server.py | 2 ++ 3 files changed, 6 insertions(+), 4 deletions(-) diff --git a/ui/sd_internal/app.py b/ui/sd_internal/app.py index bfb90399..47a2d610 100644 --- a/ui/sd_internal/app.py +++ b/ui/sd_internal/app.py @@ -8,12 +8,12 @@ from rich.logging import RichHandler from sd_internal import task_manager -LOG_FORMAT = '%(levelname)s %(threadName)s %(message)s' +LOG_FORMAT = '%(asctime)s.%(msecs)03d %(levelname)s %(threadName)s %(message)s' logging.basicConfig( level=logging.INFO, format=LOG_FORMAT, - datefmt="%X.%f", - handlers=[RichHandler(markup=True, rich_tracebacks=True, show_level=False)] + datefmt="%X", + handlers=[RichHandler(markup=True, rich_tracebacks=True, show_time=False, show_level=False)] ) log = logging.getLogger() diff --git a/ui/sd_internal/model_manager.py b/ui/sd_internal/model_manager.py index 827434d7..ebeb439c 100644 --- a/ui/sd_internal/model_manager.py +++ b/ui/sd_internal/model_manager.py @@ -149,7 +149,7 @@ def getModels(): listModels(model_type='vae') listModels(model_type='hypernetwork') - if models_scanned > 0: log.info(f'[green]Scanned {models_scanned} models. 0 infected[/]') + if models_scanned > 0: log.info(f'[green]Scanned {models_scanned} models. 
Nothing infected[/]') # legacy custom_weight_path = os.path.join(app.SD_DIR, 'custom-model.ckpt') diff --git a/ui/server.py b/ui/server.py index aaa16ce0..1c537dcd 100644 --- a/ui/server.py +++ b/ui/server.py @@ -5,6 +5,7 @@ Notes: import os import traceback import logging +import datetime from typing import List, Union from fastapi import FastAPI, HTTPException @@ -17,6 +18,7 @@ from sd_internal import app, model_manager, task_manager log = logging.getLogger() log.info(f'started in {app.SD_DIR}') +log.info(f'started at {datetime.datetime.now():%x %X}') server_api = FastAPI() From afb88616d85817470ff369890485b33a48946996 Mon Sep 17 00:00:00 2001 From: cmdr2 Date: Sun, 11 Dec 2022 13:30:16 +0530 Subject: [PATCH 18/74] Load the models after the device init, to let the UI load before the models finish loading --- ui/sd_internal/runtime2.py | 4 +--- ui/sd_internal/task_manager.py | 1 + 2 files changed, 2 insertions(+), 3 deletions(-) diff --git a/ui/sd_internal/runtime2.py b/ui/sd_internal/runtime2.py index cdffa574..06c6fdb6 100644 --- a/ui/sd_internal/runtime2.py +++ b/ui/sd_internal/runtime2.py @@ -39,13 +39,11 @@ def init(device): device_manager.device_init(thread_data, device) - init_and_load_default_models() - def destroy(): for model_type in ('stable-diffusion', 'hypernetwork', 'gfpgan', 'realesrgan'): model_loader.unload_model(thread_data, model_type) -def init_and_load_default_models(): +def load_default_models(): # init default model paths for model_type in model_manager.KNOWN_MODEL_TYPES: thread_data.model_paths[model_type] = model_manager.resolve_model_to_use(model_type=model_type) diff --git a/ui/sd_internal/task_manager.py b/ui/sd_internal/task_manager.py index 4eb45ce6..9d1d26f6 100644 --- a/ui/sd_internal/task_manager.py +++ b/ui/sd_internal/task_manager.py @@ -283,6 +283,7 @@ def thread_render(device): 'alive': True } + runtime2.load_default_models() current_state = ServerStates.Online while True: From d03eed385987ee7e3c80f42172142aa711b287e3 Mon Sep 17 00:00:00 2001 From: cmdr2 Date: Sun, 11 Dec 2022 14:14:59 +0530 Subject: [PATCH 19/74] Simplify the logic for reloading gfpgan and realesrgan models (based on the request), using the code path used for the other model types --- ui/sd_internal/runtime2.py | 29 ++++++++++++++++++----------- ui/server.py | 6 ++++++ 2 files changed, 24 insertions(+), 11 deletions(-) diff --git a/ui/sd_internal/runtime2.py b/ui/sd_internal/runtime2.py index 06c6fdb6..b7ce6fe5 100644 --- a/ui/sd_internal/runtime2.py +++ b/ui/sd_internal/runtime2.py @@ -40,7 +40,7 @@ def init(device): device_manager.device_init(thread_data, device) def destroy(): - for model_type in ('stable-diffusion', 'hypernetwork', 'gfpgan', 'realesrgan'): + for model_type in model_manager.KNOWN_MODEL_TYPES: model_loader.unload_model(thread_data, model_type) def load_default_models(): @@ -52,19 +52,26 @@ def load_default_models(): model_loader.load_model(thread_data, 'stable-diffusion') def reload_models_if_necessary(req: Request): + model_paths_in_req = ( + ('hypernetwork', req.use_hypernetwork_model), + ('gfpgan', req.use_face_correction), + ('realesrgan', req.use_upscale), + ) + if model_manager.is_sd_model_reload_necessary(thread_data, req): thread_data.model_paths['stable-diffusion'] = req.use_stable_diffusion_model thread_data.model_paths['vae'] = req.use_vae_model model_loader.load_model(thread_data, 'stable-diffusion') - if thread_data.model_paths.get('hypernetwork') != req.use_hypernetwork_model: - thread_data.model_paths['hypernetwork'] = req.use_hypernetwork_model + for 
model_type, model_path_in_req in model_paths_in_req:
+        if thread_data.model_paths.get(model_type) != model_path_in_req:
+            thread_data.model_paths[model_type] = model_path_in_req
 
-        if thread_data.model_paths['hypernetwork'] is not None:
-            model_loader.load_model(thread_data, 'hypernetwork')
-        else:
-            model_loader.unload_model(thread_data, 'hypernetwork')
+            if thread_data.model_paths[model_type] is not None:
+                model_loader.load_model(thread_data, model_type)
+            else:
+                model_loader.unload_model(thread_data, model_type)
 
 def make_images(req: Request, data_queue: queue.Queue, task_temp_images: list, step_callback):
     try:
@@ -137,13 +144,13 @@ def apply_filters(args: dict, images: list, user_stopped, show_only_filtered_ima
         return images
 
     filters = []
-    if args['use_face_correction'].startswith('GFPGAN'): filters.append((image_filters.apply_gfpgan, model_manager.resolve_model_to_use(args['use_face_correction'], model_type='gfpgan')))
-    if args['use_upscale'].startswith('RealESRGAN'): filters.append((image_filters.apply_realesrgan, model_manager.resolve_model_to_use(args['use_upscale'], model_type='realesrgan')))
+    if 'gfpgan' in args['use_face_correction'].lower(): filters.append(image_filters.apply_gfpgan)
+    if 'realesrgan' in args['use_upscale'].lower(): filters.append(image_filters.apply_realesrgan)
 
     filtered_images = []
     for img, seed, _ in images:
-        for filter_fn, filter_model_path in filters:
-            img = filter_fn(thread_data, img, filter_model_path)
+        for filter_fn in filters:
+            img = filter_fn(thread_data, img)
 
         filtered_images.append((img, seed, True))
 
diff --git a/ui/server.py b/ui/server.py
index 1c537dcd..258c433d 100644
--- a/ui/server.py
+++ b/ui/server.py
@@ -140,10 +140,16 @@ def ping(session_id:str=None):
 def render(req : task_manager.ImageRequest):
     try:
         app.save_model_to_config(req.use_stable_diffusion_model, req.use_vae_model, req.use_hypernetwork_model)
+
+        # resolve the model paths to use
         req.use_stable_diffusion_model = model_manager.resolve_model_to_use(req.use_stable_diffusion_model, model_type='stable-diffusion')
         req.use_vae_model = model_manager.resolve_model_to_use(req.use_vae_model, model_type='vae')
         req.use_hypernetwork_model = model_manager.resolve_model_to_use(req.use_hypernetwork_model, model_type='hypernetwork')
 
+        if req.use_face_correction: req.use_face_correction = model_manager.resolve_model_to_use(req.use_face_correction, 'gfpgan')
+        if req.use_upscale: req.use_upscale = model_manager.resolve_model_to_use(req.use_upscale, 'realesrgan')
+
+        # enqueue the task
         new_task = task_manager.render(req)
         response = {
             'status': str(task_manager.current_state),

From 6ce6dc3ff630c5953b33c509a004a017a5edfd63 Mon Sep 17 00:00:00 2001
From: cmdr2
Date: Sun, 11 Dec 2022 18:16:29 +0530
Subject: [PATCH 20/74] Get rid of the ugly copying around (and maintaining) of
 multiple request-related fields. Split into two objects: task-related fields,
 and render-related fields. Also remove the ability for request-defined
 full-precision. Full-precision can now be forced by using a
 FORCE_FULL_PRECISION environment variable

---
 ui/media/js/auto-save.js         |   2 -
 ui/media/js/dnd.js               |   8 --
 ui/media/js/engine.js            |   2 -
 ui/media/js/main.js              |   1 -
 ui/media/js/parameters.js        |   9 ---
 ui/sd_internal/__init__.py       |  90 +++------------------
 ui/sd_internal/device_manager.py |  14 +++-
 ui/sd_internal/model_manager.py  |  19 +----
 ui/sd_internal/runtime2.py       | 125 ++++++++++++++++---------------
 ui/sd_internal/task_manager.py   | 119 +++++------------------------
 ui/server.py                     |  31 +++-----
 11 files changed, 115 insertions(+), 305 deletions(-)

diff --git a/ui/media/js/auto-save.js b/ui/media/js/auto-save.js
index 9025f988..df044e2b 100644
--- a/ui/media/js/auto-save.js
+++ b/ui/media/js/auto-save.js
@@ -37,7 +37,6 @@ const SETTINGS_IDS_LIST = [
     "diskPath",
     "sound_toggle",
     "turbo",
-    "use_full_precision",
     "confirm_dangerous_actions",
     "auto_save_settings",
     "apply_color_correction"
@@ -278,7 +277,6 @@ function tryLoadOldSettings() {
         "soundEnabled": "sound_toggle",
         "saveToDisk": "save_to_disk",
         "useCPU": "use_cpu",
-        "useFullPrecision": "use_full_precision",
         "useTurboMode": "turbo",
         "diskPath": "diskPath",
         "useFaceCorrection": "use_face_correction",
diff --git a/ui/media/js/dnd.js b/ui/media/js/dnd.js
index 53b55ac7..8b9819c2 100644
--- a/ui/media/js/dnd.js
+++ b/ui/media/js/dnd.js
@@ -217,13 +217,6 @@ const TASK_MAPPING = {
         readUI: () => turboField.checked,
         parse: (val) => Boolean(val)
     },
-    use_full_precision: { name: 'Use Full Precision',
-        setUI: (use_full_precision) => {
-            useFullPrecisionField.checked = use_full_precision
-        },
-        readUI: () => useFullPrecisionField.checked,
-        parse: (val) => Boolean(val)
-    },
     stream_image_progress: { name: 'Stream Image Progress',
         setUI: (stream_image_progress) => {
@@ -453,7 +446,6 @@ document.addEventListener("dragover", dragOverHandler)
 const TASK_REQ_NO_EXPORT = [
     "use_cpu",
     "turbo",
-    "use_full_precision",
     "save_to_disk_path"
 ]
 const resetSettings = document.getElementById('reset-image-settings')
diff --git a/ui/media/js/engine.js b/ui/media/js/engine.js
index dd34ddb1..61822c0f 100644
--- a/ui/media/js/engine.js
+++ b/ui/media/js/engine.js
@@ -728,7 +728,6 @@
         "stream_image_progress": 'boolean',
         "show_only_filtered_image": 'boolean',
         "turbo": 'boolean',
-        "use_full_precision": 'boolean',
         "output_format": 'string',
         "output_quality": 'number',
     }
@@ -744,7 +743,6 @@
         "stream_image_progress": true,
         "show_only_filtered_image": true,
         "turbo": false,
-        "use_full_precision": false,
         "output_format": "png",
         "output_quality": 75,
     }
diff --git a/ui/media/js/main.js b/ui/media/js/main.js
index 6c08aef7..e058f9e9 100644
--- a/ui/media/js/main.js
+++ b/ui/media/js/main.js
@@ -850,7 +850,6 @@ function getCurrentUserRequest() {
             // allow_nsfw: allowNSFWField.checked,
             turbo: turboField.checked,
             //render_device: undefined, // Set device affinity. Prefer this device, but wont activate.
-            use_full_precision: useFullPrecisionField.checked,
             use_stable_diffusion_model: stableDiffusionModelField.value,
             use_vae_model: vaeModelField.value,
             stream_progress_updates: true,
diff --git a/ui/media/js/parameters.js b/ui/media/js/parameters.js
index 3643d1ec..f326953b 100644
--- a/ui/media/js/parameters.js
+++ b/ui/media/js/parameters.js
@@ -105,14 +105,6 @@ var PARAMETERS = [
         note: "to process in parallel",
         default: false,
     },
-    {
-        id: "use_full_precision",
-        type: ParameterType.checkbox,
-        label: "Use Full Precision",
-        note: "for GPU-only.
warning: this will consume more VRAM", - icon: "fa-crosshairs", - default: false, - }, { id: "auto_save_settings", type: ParameterType.checkbox, @@ -214,7 +206,6 @@ let turboField = document.querySelector('#turbo') let useCPUField = document.querySelector('#use_cpu') let autoPickGPUsField = document.querySelector('#auto_pick_gpus') let useGPUsField = document.querySelector('#use_gpus') -let useFullPrecisionField = document.querySelector('#use_full_precision') let saveToDiskField = document.querySelector('#save_to_disk') let diskPathField = document.querySelector('#diskPath') let listenToNetworkField = document.querySelector("#listen_to_network") diff --git a/ui/sd_internal/__init__.py b/ui/sd_internal/__init__.py index 9dd4f066..a06220ca 100644 --- a/ui/sd_internal/__init__.py +++ b/ui/sd_internal/__init__.py @@ -1,93 +1,22 @@ -import json +from pydantic import BaseModel -class Request: +from modules.types import GenerateImageRequest + +class TaskData(BaseModel): request_id: str = None session_id: str = "session" - prompt: str = "" - negative_prompt: str = "" - init_image: str = None # base64 - mask: str = None # base64 - apply_color_correction = False - num_outputs: int = 1 - num_inference_steps: int = 50 - guidance_scale: float = 7.5 - width: int = 512 - height: int = 512 - seed: int = 42 - prompt_strength: float = 0.8 - sampler: str = None # "ddim", "plms", "heun", "euler", "euler_a", "dpm2", "dpm2_a", "lms" - # allow_nsfw: bool = False - precision: str = "autocast" # or "full" save_to_disk_path: str = None turbo: bool = True - use_full_precision: bool = False use_face_correction: str = None # or "GFPGANv1.3" use_upscale: str = None # or "RealESRGAN_x4plus" or "RealESRGAN_x4plus_anime_6B" use_stable_diffusion_model: str = "sd-v1-4" use_vae_model: str = None use_hypernetwork_model: str = None - hypernetwork_strength: float = 1 show_only_filtered_image: bool = False output_format: str = "jpeg" # or "png" output_quality: int = 75 - - stream_progress_updates: bool = False stream_image_progress: bool = False - def json(self): - return { - "request_id": self.request_id, - "session_id": self.session_id, - "prompt": self.prompt, - "negative_prompt": self.negative_prompt, - "num_outputs": self.num_outputs, - "num_inference_steps": self.num_inference_steps, - "guidance_scale": self.guidance_scale, - "width": self.width, - "height": self.height, - "seed": self.seed, - "prompt_strength": self.prompt_strength, - "sampler": self.sampler, - "apply_color_correction": self.apply_color_correction, - "use_face_correction": self.use_face_correction, - "use_upscale": self.use_upscale, - "use_stable_diffusion_model": self.use_stable_diffusion_model, - "use_vae_model": self.use_vae_model, - "use_hypernetwork_model": self.use_hypernetwork_model, - "hypernetwork_strength": self.hypernetwork_strength, - "output_format": self.output_format, - "output_quality": self.output_quality, - } - - def __str__(self): - return f''' - session_id: {self.session_id} - prompt: {self.prompt} - negative_prompt: {self.negative_prompt} - seed: {self.seed} - num_inference_steps: {self.num_inference_steps} - sampler: {self.sampler} - guidance_scale: {self.guidance_scale} - w: {self.width} - h: {self.height} - precision: {self.precision} - save_to_disk_path: {self.save_to_disk_path} - turbo: {self.turbo} - use_full_precision: {self.use_full_precision} - apply_color_correction: {self.apply_color_correction} - use_face_correction: {self.use_face_correction} - use_upscale: {self.use_upscale} - use_stable_diffusion_model: 
{self.use_stable_diffusion_model} - use_vae_model: {self.use_vae_model} - use_hypernetwork_model: {self.use_hypernetwork_model} - hypernetwork_strength: {self.hypernetwork_strength} - show_only_filtered_image: {self.show_only_filtered_image} - output_format: {self.output_format} - output_quality: {self.output_quality} - - stream_progress_updates: {self.stream_progress_updates} - stream_image_progress: {self.stream_image_progress}''' - class Image: data: str # base64 seed: int @@ -106,17 +35,20 @@ class Image: } class Response: - request: Request + render_request: GenerateImageRequest + task_data: TaskData images: list - def __init__(self, request: Request, images: list): - self.request = request + def __init__(self, render_request: GenerateImageRequest, task_data: TaskData, images: list): + self.render_request = render_request + self.task_data = task_data self.images = images def json(self): res = { "status": 'succeeded', - "request": self.request.json(), + "render_request": self.render_request.dict(), + "task_data": self.task_data.dict(), "output": [], } diff --git a/ui/sd_internal/device_manager.py b/ui/sd_internal/device_manager.py index a3f91cfb..733bab50 100644 --- a/ui/sd_internal/device_manager.py +++ b/ui/sd_internal/device_manager.py @@ -6,6 +6,13 @@ import logging log = logging.getLogger() +''' +Set `FORCE_FULL_PRECISION` in the environment variables, or in `config.bat`/`config.sh` to set full precision (i.e. float32). +Otherwise the models will load at half-precision (i.e. float16). + +Half-precision is fine most of the time. Full precision is only needed for working around GPU bugs (like NVIDIA 16xx GPUs). +''' + COMPARABLE_GPU_PERCENTILE = 0.65 # if a GPU's free_mem is within this % of the GPU with the most free_mem, it will be picked mem_free_threshold = 0 @@ -96,7 +103,7 @@ def device_init(context, device): if device == 'cpu': context.device = 'cpu' context.device_name = get_processor_name() - context.precision = 'full' + context.half_precision = False log.debug(f'Render device CPU available as {context.device_name}') return @@ -107,7 +114,7 @@ def device_init(context, device): if needs_to_force_full_precision(context): log.warn(f'forcing full precision on this GPU, to avoid green images. GPU detected: {context.device_name}') # Apply force_full_precision now before models are loaded. 
- context.precision = 'full' + context.half_precision = False log.info(f'Setting {device} as active') torch.cuda.device(device) @@ -115,6 +122,9 @@ def device_init(context, device): return def needs_to_force_full_precision(context): + if 'FORCE_FULL_PRECISION' in os.environ: + return True + device_name = context.device_name.lower() return (('nvidia' in device_name or 'geforce' in device_name) and (' 1660' in device_name or ' 1650' in device_name)) or ('Quadro T2000' in device_name) diff --git a/ui/sd_internal/model_manager.py b/ui/sd_internal/model_manager.py index ebeb439c..2129af3d 100644 --- a/ui/sd_internal/model_manager.py +++ b/ui/sd_internal/model_manager.py @@ -3,8 +3,7 @@ import logging import picklescan.scanner import rich -from sd_internal import app, device_manager -from sd_internal import Request +from sd_internal import app log = logging.getLogger() @@ -157,19 +156,3 @@ def getModels(): models['options']['stable-diffusion'].append('custom-model') return models - -def is_sd_model_reload_necessary(thread_data, req: Request): - needs_model_reload = False - if 'stable-diffusion' not in thread_data.models or \ - thread_data.model_paths['stable-diffusion'] != req.use_stable_diffusion_model or \ - thread_data.model_paths['vae'] != req.use_vae_model: - - needs_model_reload = True - - if thread_data.device != 'cpu': - if (thread_data.precision == 'autocast' and req.use_full_precision) or \ - (thread_data.precision == 'full' and not req.use_full_precision and not device_manager.needs_to_force_full_precision(thread_data)): - thread_data.precision = 'full' if req.use_full_precision else 'autocast' - needs_model_reload = True - - return needs_model_reload diff --git a/ui/sd_internal/runtime2.py b/ui/sd_internal/runtime2.py index b7ce6fe5..434f489b 100644 --- a/ui/sd_internal/runtime2.py +++ b/ui/sd_internal/runtime2.py @@ -9,13 +9,14 @@ import traceback import logging from sd_internal import device_manager, model_manager -from sd_internal import Request, Response, Image as ResponseImage, UserInitiatedStop +from sd_internal import TaskData, Response, Image as ResponseImage, UserInitiatedStop from modules import model_loader, image_generator, image_utils, filters as image_filters +from modules.types import Context, GenerateImageRequest log = logging.getLogger() -thread_data = threading.local() +thread_data = Context() ''' runtime data (bound locally to this thread), for e.g. 
device, references to loaded models, optimization flags etc ''' @@ -28,14 +29,7 @@ def init(device): ''' thread_data.stop_processing = False thread_data.temp_images = {} - - thread_data.models = {} - thread_data.model_paths = {} - - thread_data.device = None - thread_data.device_name = None - thread_data.precision = 'autocast' - thread_data.vram_optimizations = ('TURBO', 'MOVE_MODELS') + thread_data.partial_x_samples = None device_manager.device_init(thread_data, device) @@ -51,16 +45,16 @@ def load_default_models(): # load mandatory models model_loader.load_model(thread_data, 'stable-diffusion') -def reload_models_if_necessary(req: Request): +def reload_models_if_necessary(task_data: TaskData): model_paths_in_req = ( - ('hypernetwork', req.use_hypernetwork_model), - ('gfpgan', req.use_face_correction), - ('realesrgan', req.use_upscale), + ('hypernetwork', task_data.use_hypernetwork_model), + ('gfpgan', task_data.use_face_correction), + ('realesrgan', task_data.use_upscale), ) - if model_manager.is_sd_model_reload_necessary(thread_data, req): - thread_data.model_paths['stable-diffusion'] = req.use_stable_diffusion_model - thread_data.model_paths['vae'] = req.use_vae_model + if thread_data.model_paths.get('stable-diffusion') != task_data.use_stable_diffusion_model or thread_data.model_paths.get('vae') != task_data.use_vae_model: + thread_data.model_paths['stable-diffusion'] = task_data.use_stable_diffusion_model + thread_data.model_paths['vae'] = task_data.use_vae_model model_loader.load_model(thread_data, 'stable-diffusion') @@ -73,10 +67,17 @@ def reload_models_if_necessary(req: Request): else: model_loader.unload_model(thread_data, model_type) -def make_images(req: Request, data_queue: queue.Queue, task_temp_images: list, step_callback): +def make_images(req: GenerateImageRequest, task_data: TaskData, data_queue: queue.Queue, task_temp_images: list, step_callback): try: - log.info(req) - return _make_images_internal(req, data_queue, task_temp_images, step_callback) + # resolve the model paths to use + resolve_model_paths(task_data) + + # convert init image to PIL.Image + req.init_image = image_utils.base64_str_to_img(req.init_image) if req.init_image is not None else None + req.init_image_mask = image_utils.base64_str_to_img(req.init_image_mask) if req.init_image_mask is not None else None + + # generate + return _make_images_internal(req, task_data, data_queue, task_temp_images, step_callback) except Exception as e: log.error(traceback.format_exc()) @@ -86,66 +87,76 @@ def make_images(req: Request, data_queue: queue.Queue, task_temp_images: list, s })) raise e -def _make_images_internal(req: Request, data_queue: queue.Queue, task_temp_images: list, step_callback): - args = req_to_args(req) +def _make_images_internal(req: GenerateImageRequest, task_data: TaskData, data_queue: queue.Queue, task_temp_images: list, step_callback): + metadata = req.dict() + del metadata['init_image'] + del metadata['init_image_mask'] + print(metadata) - images, user_stopped = generate_images(args, data_queue, task_temp_images, step_callback, req.stream_image_progress) - images = apply_color_correction(args, images, user_stopped) - images = apply_filters(args, images, user_stopped, req.show_only_filtered_image) + images, user_stopped = generate_images(req, data_queue, task_temp_images, step_callback, task_data.stream_image_progress) + images = apply_color_correction(req, images, user_stopped) + images = apply_filters(task_data, images, user_stopped, task_data.show_only_filtered_image) - if 
req.save_to_disk_path is not None:
-        out_path = os.path.join(req.save_to_disk_path, filename_regex.sub('_', req.session_id))
-        save_images(images, out_path, metadata=req.json(), show_only_filtered_image=req.show_only_filtered_image)
+    if task_data.save_to_disk_path is not None:
+        out_path = os.path.join(task_data.save_to_disk_path, filename_regex.sub('_', task_data.session_id))
+        save_images(images, out_path, metadata=metadata, show_only_filtered_image=task_data.show_only_filtered_image)
 
-    res = Response(req, images=construct_response(req, images))
+    res = Response(req, task_data, images=construct_response(task_data, images))
     res = res.json()
     data_queue.put(json.dumps(res))
     log.info('Task completed')
 
     return res
 
-def generate_images(args: dict, data_queue: queue.Queue, task_temp_images: list, step_callback, stream_image_progress: bool):
+def resolve_model_paths(task_data: TaskData):
+    task_data.use_stable_diffusion_model = model_manager.resolve_model_to_use(task_data.use_stable_diffusion_model, model_type='stable-diffusion')
+    task_data.use_vae_model = model_manager.resolve_model_to_use(task_data.use_vae_model, model_type='vae')
+    task_data.use_hypernetwork_model = model_manager.resolve_model_to_use(task_data.use_hypernetwork_model, model_type='hypernetwork')
+
+    if task_data.use_face_correction: task_data.use_face_correction = model_manager.resolve_model_to_use(task_data.use_face_correction, 'gfpgan')
+    if task_data.use_upscale: task_data.use_upscale = model_manager.resolve_model_to_use(task_data.use_upscale, 'realesrgan')
+
+def generate_images(req: GenerateImageRequest, data_queue: queue.Queue, task_temp_images: list, step_callback, stream_image_progress: bool):
     thread_data.temp_images.clear()
 
-    image_generator.on_image_step = make_step_callback(args, data_queue, task_temp_images, step_callback, stream_image_progress)
+    image_generator.on_image_step = make_step_callback(req, data_queue, task_temp_images, step_callback, stream_image_progress)
 
     try:
-        images = image_generator.make_images(context=thread_data, args=args)
+        images = image_generator.make_images(context=thread_data, req=req)
         user_stopped = False
     except UserInitiatedStop:
         images = []
         user_stopped = True
-        if not hasattr(thread_data, 'partial_x_samples') or thread_data.partial_x_samples is None:
-            return images
-        for i in range(args['num_outputs']):
-            images[i] = image_utils.latent_to_img(thread_data, thread_data.partial_x_samples[i].unsqueeze(0))
-
-        del thread_data.partial_x_samples
+        if thread_data.partial_x_samples is not None:
+            for i in range(req.num_outputs):
+                images[i] = image_utils.latent_to_img(thread_data, thread_data.partial_x_samples[i].unsqueeze(0))
+
+            thread_data.partial_x_samples = None
     finally:
         model_loader.gc(thread_data)
 
-    images = [(image, args['seed'] + i, False) for i, image in enumerate(images)]
+    images = [(image, req.seed + i, False) for i, image in enumerate(images)]
 
     return images, user_stopped
 
-def apply_color_correction(args: dict, images: list, user_stopped):
-    if user_stopped or args['init_image'] is None or not args['apply_color_correction']:
+def apply_color_correction(req: GenerateImageRequest, images: list, user_stopped):
+    if user_stopped or req.init_image is None or not req.apply_color_correction:
         return images
 
     for i, img_info in enumerate(images):
         img, seed, filtered = img_info
-        img = image_utils.apply_color_correction(orig_image=args['init_image'], image_to_correct=img)
+        img = image_utils.apply_color_correction(orig_image=req.init_image, image_to_correct=img)
         images[i] = (img, seed, filtered)
 
     return
images -def apply_filters(args: dict, images: list, user_stopped, show_only_filtered_image): - if user_stopped or (args['use_face_correction'] is None and args['use_upscale'] is None): +def apply_filters(task_data: TaskData, images: list, user_stopped, show_only_filtered_image): + if user_stopped or (task_data.use_face_correction is None and task_data.use_upscale is None): return images filters = [] - if 'gfpgan' in args['use_face_correction'].lower(): filters.append(image_filters.apply_gfpgan) - if 'realesrgan' in args['use_face_correction'].lower(): filters.append(image_filters.apply_realesrgan) + if 'gfpgan' in task_data.use_face_correction.lower(): filters.append(image_filters.apply_gfpgan) + if 'realesrgan' in task_data.use_face_correction.lower(): filters.append(image_filters.apply_realesrgan) filtered_images = [] for img, seed, _ in images: @@ -188,37 +199,29 @@ def save_images(images: list, save_to_disk_path, metadata: dict, show_only_filte img_path += '.' + metadata['output_format'] img.save(img_path, quality=metadata['output_quality']) -def construct_response(req: Request, images: list): +def construct_response(task_data: TaskData, images: list): return [ ResponseImage( - data=image_utils.img_to_base64_str(img, req.output_format, req.output_quality), + data=image_utils.img_to_base64_str(img, task_data.output_format, task_data.output_quality), seed=seed ) for img, seed, _ in images ] -def req_to_args(req: Request): - args = req.json() - - args['init_image'] = image_utils.base64_str_to_img(req.init_image) if req.init_image is not None else None - args['mask'] = image_utils.base64_str_to_img(req.mask) if req.mask is not None else None - - return args - -def make_step_callback(args: dict, data_queue: queue.Queue, task_temp_images: list, step_callback, stream_image_progress: bool): - n_steps = args['num_inference_steps'] if args['init_image'] is None else int(args['num_inference_steps'] * args['prompt_strength']) +def make_step_callback(req: GenerateImageRequest, task_data: TaskData, data_queue: queue.Queue, task_temp_images: list, step_callback, stream_image_progress: bool): + n_steps = req.num_inference_steps if req.init_image is None else int(req.num_inference_steps * req.prompt_strength) last_callback_time = -1 def update_temp_img(x_samples, task_temp_images: list): partial_images = [] - for i in range(args['num_outputs']): + for i in range(req.num_outputs): img = image_utils.latent_to_img(thread_data, x_samples[i].unsqueeze(0)) buf = image_utils.img_to_buffer(img, output_format='JPEG') del img - thread_data.temp_images[f"{args['request_id']}/{i}"] = buf + thread_data.temp_images[f"{task_data.request_id}/{i}"] = buf task_temp_images[i] = buf - partial_images.append({'path': f"/image/tmp/{args['request_id']}/{i}"}) + partial_images.append({'path': f"/image/tmp/{task_data.request_id}/{i}"}) return partial_images def on_image_step(x_samples, i): diff --git a/ui/sd_internal/task_manager.py b/ui/sd_internal/task_manager.py index 9d1d26f6..3ec1b99d 100644 --- a/ui/sd_internal/task_manager.py +++ b/ui/sd_internal/task_manager.py @@ -14,8 +14,8 @@ import torch import queue, threading, time, weakref from typing import Any, Hashable -from pydantic import BaseModel -from sd_internal import Request, device_manager +from sd_internal import TaskData, device_manager +from modules.types import GenerateImageRequest log = logging.getLogger() @@ -39,9 +39,10 @@ class ServerStates: class Unavailable(Symbol): pass class RenderTask(): # Task with output queue and completion lock. 
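The hunk below is the heart of this refactor: RenderTask now carries two objects instead of one. As a rough sketch of the split (GenerateImageRequest lives in diffusion-kit's modules.types and its definition is not part of this patch series, so the exact shape below is an assumption inferred from the fields this patch reads; TaskData matches the definition in ui/sd_internal/__init__.py above):

    from pydantic import BaseModel

    class GenerateImageRequest(BaseModel):
        # assumed shape; these are the fields runtime2.py actually uses
        prompt: str = ""
        seed: int = 42
        num_outputs: int = 1
        num_inference_steps: int = 50
        prompt_strength: float = 0.8
        init_image: str = None       # base64; converted to a PIL.Image in make_images()
        init_image_mask: str = None  # base64
        apply_color_correction: bool = False

    class TaskData(BaseModel):
        # subset of the actual class in ui/sd_internal/__init__.py
        request_id: str = None
        session_id: str = "session"
        save_to_disk_path: str = None
        use_face_correction: str = None
        use_upscale: str = None
        show_only_filtered_image: bool = False
        stream_image_progress: bool = False

Everything the sampler needs goes in GenerateImageRequest; everything about where the task came from and what to do with its output goes in TaskData.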
-    def __init__(self, req: Request):
+    def __init__(self, req: GenerateImageRequest, task_data: TaskData):
         req.request_id = id(self)
-        self.request: Request = req # Initial Request
+        self.render_request: GenerateImageRequest = req # Initial Request
+        self.task_data: TaskData = task_data
         self.response: Any = None # Copy of the last response
         self.render_device = None # Select the task affinity. (Not used to change active devices).
         self.temp_images:list = [None] * req.num_outputs * (1 if req.show_only_filtered_image else 2)
@@ -72,55 +73,6 @@ class RenderTask(): # Task with output queue and completion lock.
     def is_pending(self): return bool(not self.response and not self.error)
 
-# defaults from https://huggingface.co/blog/stable_diffusion
-class ImageRequest(BaseModel):
-    session_id: str = "session"
-    prompt: str = ""
-    negative_prompt: str = ""
-    init_image: str = None # base64
-    mask: str = None # base64
-    apply_color_correction: bool = False
-    num_outputs: int = 1
-    num_inference_steps: int = 50
-    guidance_scale: float = 7.5
-    width: int = 512
-    height: int = 512
-    seed: int = 42
-    prompt_strength: float = 0.8
-    sampler: str = None # "ddim", "plms", "heun", "euler", "euler_a", "dpm2", "dpm2_a", "lms"
-    # allow_nsfw: bool = False
-    save_to_disk_path: str = None
-    turbo: bool = True
-    use_cpu: bool = False ##TODO Remove after UI and plugins transition.
-    render_device: str = None # Select the task affinity. (Not used to change active devices).
-    use_full_precision: bool = False
-    use_face_correction: str = None # or "GFPGANv1.3"
-    use_upscale: str = None # or "RealESRGAN_x4plus" or "RealESRGAN_x4plus_anime_6B"
-    use_stable_diffusion_model: str = "sd-v1-4"
-    use_vae_model: str = None
-    use_hypernetwork_model: str = None
-    hypernetwork_strength: float = None
-    show_only_filtered_image: bool = False
-    output_format: str = "jpeg" # or "png"
-    output_quality: int = 75
-
-    stream_progress_updates: bool = False
-    stream_image_progress: bool = False
-
-class FilterRequest(BaseModel):
-    session_id: str = "session"
-    model: str = None
-    name: str = ""
-    init_image: str = None # base64
-    width: int = 512
-    height: int = 512
-    save_to_disk_path: str = None
-    turbo: bool = True
-    render_device: str = None
-    use_full_precision: bool = False
-    output_format: str = "jpeg" # or "png"
-    output_quality: int = 75
-
# Temporary cache that allows querying task results for a short time after they are completed.
class DataCache(): def __init__(self): @@ -311,7 +263,7 @@ def thread_render(device): task.response = {"status": 'failed', "detail": str(task.error)} task.buffer_queue.put(json.dumps(task.response)) continue - log.info(f'Session {task.request.session_id} starting task {id(task)} on {runtime2.thread_data.device_name}') + log.info(f'Session {task.task_data.session_id} starting task {id(task)} on {runtime2.thread_data.device_name}') if not task.lock.acquire(blocking=False): raise Exception('Got locked task from queue.') try: def step_callback(): @@ -322,16 +274,16 @@ def thread_render(device): if isinstance(current_state_error, StopAsyncIteration): task.error = current_state_error current_state_error = None - log.info(f'Session {task.request.session_id} sent cancel signal for task {id(task)}') + log.info(f'Session {task.task_data.session_id} sent cancel signal for task {id(task)}') current_state = ServerStates.LoadingModel - runtime2.reload_models_if_necessary(task.request) + runtime2.reload_models_if_necessary(task.task_data) current_state = ServerStates.Rendering - task.response = runtime2.make_images(task.request, task.buffer_queue, task.temp_images, step_callback) + task.response = runtime2.make_images(task.render_request, task.task_data, task.buffer_queue, task.temp_images, step_callback) # Before looping back to the generator, mark cache as still alive. task_cache.keep(id(task), TASK_TTL) - session_cache.keep(task.request.session_id, TASK_TTL) + session_cache.keep(task.task_data.session_id, TASK_TTL) except Exception as e: task.error = e log.error(traceback.format_exc()) @@ -340,13 +292,13 @@ def thread_render(device): # Task completed task.lock.release() task_cache.keep(id(task), TASK_TTL) - session_cache.keep(task.request.session_id, TASK_TTL) + session_cache.keep(task.task_data.session_id, TASK_TTL) if isinstance(task.error, StopAsyncIteration): - log.info(f'Session {task.request.session_id} task {id(task)} cancelled!') + log.info(f'Session {task.task_data.session_id} task {id(task)} cancelled!') elif task.error is not None: - log.info(f'Session {task.request.session_id} task {id(task)} failed!') + log.info(f'Session {task.task_data.session_id} task {id(task)} failed!') else: - log.info(f'Session {task.request.session_id} task {id(task)} completed by {runtime2.thread_data.device_name}.') + log.info(f'Session {task.task_data.session_id} task {id(task)} completed by {runtime2.thread_data.device_name}.') current_state = ServerStates.Online def get_cached_task(task_id:str, update_ttl:bool=False): @@ -509,53 +461,18 @@ def shutdown_event(): # Signal render thread to close on shutdown global current_state_error current_state_error = SystemExit('Application shutting down.') -def render(req : ImageRequest): +def render(render_req: GenerateImageRequest, task_data: TaskData): current_thread_count = is_alive() if current_thread_count <= 0: # Render thread is dead raise ChildProcessError('Rendering thread has died.') # Alive, check if task in cache - session = get_cached_session(req.session_id, update_ttl=True) + session = get_cached_session(task_data.session_id, update_ttl=True) pending_tasks = list(filter(lambda t: t.is_pending, session.tasks)) if current_thread_count < len(pending_tasks): - raise ConnectionRefusedError(f'Session {req.session_id} already has {len(pending_tasks)} pending tasks out of {current_thread_count}.') + raise ConnectionRefusedError(f'Session {task_data.session_id} already has {len(pending_tasks)} pending tasks out of {current_thread_count}.') - r = Request() - 
r.session_id = req.session_id - r.prompt = req.prompt - r.negative_prompt = req.negative_prompt - r.init_image = req.init_image - r.mask = req.mask - r.apply_color_correction = req.apply_color_correction - r.num_outputs = req.num_outputs - r.num_inference_steps = req.num_inference_steps - r.guidance_scale = req.guidance_scale - r.width = req.width - r.height = req.height - r.seed = req.seed - r.prompt_strength = req.prompt_strength - r.sampler = req.sampler - # r.allow_nsfw = req.allow_nsfw - r.turbo = req.turbo - r.use_full_precision = req.use_full_precision - r.save_to_disk_path = req.save_to_disk_path - r.use_upscale: str = req.use_upscale - r.use_face_correction = req.use_face_correction - r.use_stable_diffusion_model = req.use_stable_diffusion_model - r.use_vae_model = req.use_vae_model - r.use_hypernetwork_model = req.use_hypernetwork_model - r.hypernetwork_strength = req.hypernetwork_strength - r.show_only_filtered_image = req.show_only_filtered_image - r.output_format = req.output_format - r.output_quality = req.output_quality - - r.stream_progress_updates = True # the underlying implementation only supports streaming - r.stream_image_progress = req.stream_image_progress - - if not req.stream_progress_updates: - r.stream_image_progress = False - - new_task = RenderTask(r) + new_task = RenderTask(render_req, task_data) if session.put(new_task, TASK_TTL): # Use twice the normal timeout for adding user requests. # Tries to force session.put to fail before tasks_queue.put would. diff --git a/ui/server.py b/ui/server.py index 258c433d..42bdb0f4 100644 --- a/ui/server.py +++ b/ui/server.py @@ -14,6 +14,8 @@ from starlette.responses import FileResponse, JSONResponse, StreamingResponse from pydantic import BaseModel from sd_internal import app, model_manager, task_manager +from sd_internal import TaskData +from modules.types import GenerateImageRequest log = logging.getLogger() @@ -22,8 +24,6 @@ log.info(f'started at {datetime.datetime.now():%x %X}') server_api = FastAPI() -# don't show access log entries for URLs that start with the given prefix -ACCESS_LOG_SUPPRESS_PATH_PREFIXES = ['/ping', '/image', '/modifier-thumbnails'] NOCACHE_HEADERS={"Cache-Control": "no-cache, no-store, must-revalidate", "Pragma": "no-cache", "Expires": "0"} class NoCacheStaticFiles(StaticFiles): @@ -43,17 +43,6 @@ class SetAppConfigRequest(BaseModel): listen_port: int = None test_sd2: bool = None -class LogSuppressFilter(logging.Filter): - def filter(self, record: logging.LogRecord) -> bool: - path = record.getMessage() - for prefix in ACCESS_LOG_SUPPRESS_PATH_PREFIXES: - if path.find(prefix) != -1: - return False - return True - -# don't log certain requests -logging.getLogger('uvicorn.access').addFilter(LogSuppressFilter()) - server_api.mount('/media', NoCacheStaticFiles(directory=os.path.join(app.SD_UI_DIR, 'media')), name="media") for plugins_dir, dir_prefix in app.UI_PLUGINS_SOURCES: @@ -137,20 +126,18 @@ def ping(session_id:str=None): return JSONResponse(response, headers=NOCACHE_HEADERS) @server_api.post('/render') -def render(req : task_manager.ImageRequest): +def render(req: dict): try: - app.save_model_to_config(req.use_stable_diffusion_model, req.use_vae_model, req.use_hypernetwork_model) + # separate out the request data into rendering and task-specific data + render_req: GenerateImageRequest = GenerateImageRequest.parse_obj(req) + task_data: TaskData = TaskData.parse_obj(req) - # resolve the model paths to use - req.use_stable_diffusion_model = 
model_manager.resolve_model_to_use(req.use_stable_diffusion_model, model_type='stable-diffusion') - req.use_vae_model = model_manager.resolve_model_to_use(req.use_vae_model, model_type='vae') - req.use_hypernetwork_model = model_manager.resolve_model_to_use(req.use_hypernetwork_model, model_type='hypernetwork') + render_req.init_image_mask = req.get('mask') # hack: will rename this in the HTTP API in a future revision - if req.use_face_correction: req.use_face_correction = model_manager.resolve_model_to_use(req.use_face_correction, 'gfpgan') - if req.use_upscale: req.use_upscale = model_manager.resolve_model_to_use(req.use_upscale, 'gfpgan') + app.save_model_to_config(task_data.use_stable_diffusion_model, task_data.use_vae_model, task_data.use_hypernetwork_model) # enqueue the task - new_task = task_manager.render(req) + new_task = task_manager.render(render_req, task_data) response = { 'status': str(task_manager.current_state), 'queue': len(task_manager.tasks_queue), From 0aa79685035559e513ee54bb95febedf2a6cee39 Mon Sep 17 00:00:00 2001 From: cmdr2 Date: Sun, 11 Dec 2022 19:34:07 +0530 Subject: [PATCH 21/74] Move color correction to diffusionkit; Rename color correction to 'Preserve color profile' --- ui/index.html | 2 +- ui/media/js/main.js | 6 +++--- ui/sd_internal/runtime2.py | 18 ++---------------- 3 files changed, 6 insertions(+), 20 deletions(-) diff --git a/ui/index.html b/ui/index.html index 30648424..0273b574 100644 --- a/ui/index.html +++ b/ui/index.html @@ -213,7 +213,7 @@
    [index.html hunk: the HTML markup was lost in extraction; per the commit message and the main.js changes below, a checkbox line in the "Render Settings" section is renamed from "Apply color correction" to "Preserve color profile"]
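The feature this patch moves into diffusionkit, renamed from color correction to preserving the init image's color profile, is a color-transfer step applied to each generated image. A common implementation of that idea is histogram matching against the original init image; a minimal sketch, assuming NumPy, Pillow and scikit-image (the function name is illustrative, and diffusionkit's actual implementation may differ, for example by matching in LAB space):

    import numpy as np
    from PIL import Image
    from skimage.exposure import match_histograms

    def preserve_color_profile(init_image: Image.Image, generated: Image.Image) -> Image.Image:
        # Match the generated image's channel histograms to the init image's,
        # so img2img/inpainting results keep the original color balance.
        reference = np.asarray(init_image.convert('RGB'))
        target = np.asarray(generated.convert('RGB'))
        matched = match_histograms(target, reference, channel_axis=-1)
        return Image.fromarray(np.clip(matched, 0, 255).astype(np.uint8))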
diff --git a/ui/media/js/main.js b/ui/media/js/main.js
index e058f9e9..6e00f8c0 100644
--- a/ui/media/js/main.js
+++ b/ui/media/js/main.js
@@ -761,8 +761,8 @@ function createTask(task) {
         taskConfig += `, Hypernetwork: ${task.reqBody.use_hypernetwork_model}`
         taskConfig += `, Hypernetwork Strength: ${task.reqBody.hypernetwork_strength}`
     }
-    if (task.reqBody.apply_color_correction) {
-        taskConfig += `, Color Correction: true`
+    if (task.reqBody.preserve_init_image_color_profile) {
+        taskConfig += `, Preserve Color Profile: true`
     }
 
     let taskEntry = document.createElement('div')
@@ -871,7 +871,7 @@ function getCurrentUserRequest() {
             if (maskSetting.checked) {
                 newTask.reqBody.mask = imageInpainter.getImg()
             }
-            newTask.reqBody.apply_color_correction = applyColorCorrectionField.checked
+            newTask.reqBody.preserve_init_image_color_profile = applyColorCorrectionField.checked
             newTask.reqBody.sampler = 'ddim'
         } else {
             newTask.reqBody.sampler = samplerField.value
diff --git a/ui/sd_internal/runtime2.py b/ui/sd_internal/runtime2.py
index 434f489b..0e1efc2c 100644
--- a/ui/sd_internal/runtime2.py
+++ b/ui/sd_internal/runtime2.py
@@ -62,10 +62,8 @@ def reload_models_if_necessary(task_data: TaskData):
         if thread_data.model_paths.get(model_type) != model_path_in_req:
             thread_data.model_paths[model_type] = model_path_in_req
 
-            if thread_data.model_paths[model_type] is not None:
-                model_loader.load_model(thread_data, model_type)
-            else:
-                model_loader.unload_model(thread_data, model_type)
+            action_fn = model_loader.unload_model if thread_data.model_paths[model_type] is None else model_loader.load_model
+            action_fn(thread_data, model_type)
 
 def make_images(req: GenerateImageRequest, task_data: TaskData, data_queue: queue.Queue, task_temp_images: list, step_callback):
     try:
@@ -94,7 +92,6 @@ def _make_images_internal(req: GenerateImageRequest, task_data: TaskData, data_q
     print(metadata)
 
     images, user_stopped = generate_images(req, data_queue, task_temp_images, step_callback, task_data.stream_image_progress)
-    images = apply_color_correction(req, images, user_stopped)
     images = apply_filters(task_data, images, user_stopped, task_data.show_only_filtered_image)
 
     if task_data.save_to_disk_path is not None:
@@ -139,17 +136,6 @@ def generate_images(req: GenerateImageRequest, data_queue: queue.Queue, task_tem
 
     return images, user_stopped
 
-def apply_color_correction(req: GenerateImageRequest, images: list, user_stopped):
-    if user_stopped or req.init_image is None or not req.apply_color_correction:
-        return images
-
-    for i, img_info in enumerate(images):
-        img, seed, filtered = img_info
-        img = image_utils.apply_color_correction(orig_image=req.init_image, image_to_correct=img)
-        images[i] = (img, seed, filtered)
-
-    return images
-

From 97919c7e87399b5dc017efae1a83b6f009935ef3 Mon Sep 17 00:00:00 2001
From: cmdr2
Date: Sun, 11 Dec 2022 19:58:12 +0530
Subject: [PATCH 22/74] Simplify the runtime code

---
 ui/sd_internal/model_manager.py | 14 +++++++++++++
 ui/sd_internal/runtime2.py      | 36 +++++----------------------------
 ui/sd_internal/task_manager.py  |  7 ++++---
 3 files changed, 23 insertions(+), 34 deletions(-)

diff --git a/ui/sd_internal/model_manager.py b/ui/sd_internal/model_manager.py
index 2129af3d..3d68494f 100644
--- a/ui/sd_internal/model_manager.py
+++ b/ui/sd_internal/model_manager.py
@@ -4,6 +4,8 @@ import
picklescan.scanner import rich from sd_internal import app +from modules import model_loader +from modules.types import Context log = logging.getLogger() @@ -30,6 +32,18 @@ def init(): make_model_folders() getModels() # run this once, to cache the picklescan results +def load_default_models(context: Context): + # init default model paths + for model_type in KNOWN_MODEL_TYPES: + context.model_paths[model_type] = resolve_model_to_use(model_type=model_type) + + # load mandatory models + model_loader.load_model(context, 'stable-diffusion') + +def unload_all(context: Context): + for model_type in KNOWN_MODEL_TYPES: + model_loader.unload_model(context, model_type) + def resolve_model_to_use(model_name:str=None, model_type:str=None): model_extensions = MODEL_EXTENSIONS.get(model_type, []) default_models = DEFAULT_MODELS.get(model_type, []) diff --git a/ui/sd_internal/runtime2.py b/ui/sd_internal/runtime2.py index 0e1efc2c..399ba960 100644 --- a/ui/sd_internal/runtime2.py +++ b/ui/sd_internal/runtime2.py @@ -33,18 +33,6 @@ def init(device): device_manager.device_init(thread_data, device) -def destroy(): - for model_type in model_manager.KNOWN_MODEL_TYPES: - model_loader.unload_model(thread_data, model_type) - -def load_default_models(): - # init default model paths - for model_type in model_manager.KNOWN_MODEL_TYPES: - thread_data.model_paths[model_type] = model_manager.resolve_model_to_use(model_type=model_type) - - # load mandatory models - model_loader.load_model(thread_data, 'stable-diffusion') - def reload_models_if_necessary(task_data: TaskData): model_paths_in_req = ( ('hypernetwork', task_data.use_hypernetwork_model), @@ -67,14 +55,6 @@ def reload_models_if_necessary(task_data: TaskData): def make_images(req: GenerateImageRequest, task_data: TaskData, data_queue: queue.Queue, task_temp_images: list, step_callback): try: - # resolve the model paths to use - resolve_model_paths(task_data) - - # convert init image to PIL.Image - req.init_image = image_utils.base64_str_to_img(req.init_image) if req.init_image is not None else None - req.init_image_mask = image_utils.base64_str_to_img(req.init_image_mask) if req.init_image_mask is not None else None - - # generate return _make_images_internal(req, task_data, data_queue, task_temp_images, step_callback) except Exception as e: log.error(traceback.format_exc()) @@ -86,17 +66,12 @@ def make_images(req: GenerateImageRequest, task_data: TaskData, data_queue: queu raise e def _make_images_internal(req: GenerateImageRequest, task_data: TaskData, data_queue: queue.Queue, task_temp_images: list, step_callback): - metadata = req.dict() - del metadata['init_image'] - del metadata['init_image_mask'] - print(metadata) - images, user_stopped = generate_images(req, data_queue, task_temp_images, step_callback, task_data.stream_image_progress) images = apply_filters(task_data, images, user_stopped, task_data.show_only_filtered_image) if task_data.save_to_disk_path is not None: out_path = os.path.join(task_data.save_to_disk_path, filename_regex.sub('_', task_data.session_id)) - save_images(images, out_path, metadata=metadata, show_only_filtered_image=task_data.show_only_filtered_image) + save_images(images, out_path, metadata=req.to_metadata(), show_only_filtered_image=task_data.show_only_filtered_image) res = Response(req, task_data, images=construct_response(images)) res = res.json() @@ -114,6 +89,7 @@ def resolve_model_paths(task_data: TaskData): if task_data.use_upscale: task_data.use_upscale = model_manager.resolve_model_to_use(task_data.use_upscale, 
'gfpgan') def generate_images(req: GenerateImageRequest, data_queue: queue.Queue, task_temp_images: list, step_callback, stream_image_progress: bool): + log.info(req.to_metadata()) thread_data.temp_images.clear() image_generator.on_image_step = make_step_callback(req, data_queue, task_temp_images, step_callback, stream_image_progress) @@ -125,13 +101,11 @@ def generate_images(req: GenerateImageRequest, data_queue: queue.Queue, task_tem images = [] user_stopped = True if thread_data.partial_x_samples is not None: - for i in range(req.num_outputs): - images[i] = image_utils.latent_to_img(thread_data, thread_data.partial_x_samples[i].unsqueeze(0)) - - thread_data.partial_x_samples = None + images = image_utils.latent_samples_to_images(thread_data, thread_data.partial_x_samples) + thread_data.partial_x_samples = None finally: model_loader.gc(thread_data) - + images = [(image, req.seed + i, False) for i, image in enumerate(images)] return images, user_stopped diff --git a/ui/sd_internal/task_manager.py b/ui/sd_internal/task_manager.py index 3ec1b99d..c167a5f9 100644 --- a/ui/sd_internal/task_manager.py +++ b/ui/sd_internal/task_manager.py @@ -219,7 +219,7 @@ def thread_get_next_task(): def thread_render(device): global current_state, current_state_error - from sd_internal import runtime2 + from sd_internal import runtime2, model_manager try: runtime2.init(device) except Exception as e: @@ -235,7 +235,7 @@ def thread_render(device): 'alive': True } - runtime2.load_default_models() + model_manager.load_default_models(runtime2.thread_data) current_state = ServerStates.Online while True: @@ -243,7 +243,7 @@ def thread_render(device): task_cache.clean() if not weak_thread_data[threading.current_thread()]['alive']: log.info(f'Shutting down thread for device {runtime2.thread_data.device}') - runtime2.destroy() + model_manager.unload_all(runtime2.thread_data) return if isinstance(current_state_error, SystemExit): current_state = ServerStates.Unavailable @@ -280,6 +280,7 @@ def thread_render(device): runtime2.reload_models_if_necessary(task.task_data) current_state = ServerStates.Rendering + runtime2.resolve_model_paths(task.task_data) task.response = runtime2.make_images(task.render_request, task.task_data, task.buffer_queue, task.temp_images, step_callback) # Before looping back to the generator, mark cache as still alive. 
task_cache.keep(id(task), TASK_TTL) From 096556d8c93615f5a346735cde5ba5ed8d3be733 Mon Sep 17 00:00:00 2001 From: cmdr2 Date: Sun, 11 Dec 2022 20:13:44 +0530 Subject: [PATCH 23/74] Move away the remaining model-related code to the model_manager --- ui/sd_internal/model_manager.py | 30 +++++++++++++++++++++++++++++- ui/sd_internal/runtime2.py | 31 +------------------------------ ui/sd_internal/task_manager.py | 4 ++-- 3 files changed, 32 insertions(+), 33 deletions(-) diff --git a/ui/sd_internal/model_manager.py b/ui/sd_internal/model_manager.py index 3d68494f..ea624322 100644 --- a/ui/sd_internal/model_manager.py +++ b/ui/sd_internal/model_manager.py @@ -3,7 +3,7 @@ import logging import picklescan.scanner import rich -from sd_internal import app +from sd_internal import app, TaskData from modules import model_loader from modules.types import Context @@ -88,6 +88,34 @@ def resolve_model_to_use(model_name:str=None, model_type:str=None): return None +def reload_models_if_necessary(context: Context, task_data: TaskData): + model_paths_in_req = ( + ('hypernetwork', task_data.use_hypernetwork_model), + ('gfpgan', task_data.use_face_correction), + ('realesrgan', task_data.use_upscale), + ) + + if context.model_paths.get('stable-diffusion') != task_data.use_stable_diffusion_model or context.model_paths.get('vae') != task_data.use_vae_model: + context.model_paths['stable-diffusion'] = task_data.use_stable_diffusion_model + context.model_paths['vae'] = task_data.use_vae_model + + model_loader.load_model(context, 'stable-diffusion') + + for model_type, model_path_in_req in model_paths_in_req: + if context.model_paths.get(model_type) != model_path_in_req: + context.model_paths[model_type] = model_path_in_req + + action_fn = model_loader.unload_model if context.model_paths[model_type] is None else model_loader.load_model + action_fn(context, model_type) + +def resolve_model_paths(task_data: TaskData): + task_data.use_stable_diffusion_model = resolve_model_to_use(task_data.use_stable_diffusion_model, model_type='stable-diffusion') + task_data.use_vae_model = resolve_model_to_use(task_data.use_vae_model, model_type='vae') + task_data.use_hypernetwork_model = resolve_model_to_use(task_data.use_hypernetwork_model, model_type='hypernetwork') + + if task_data.use_face_correction: task_data.use_face_correction = resolve_model_to_use(task_data.use_face_correction, 'gfpgan') + if task_data.use_upscale: task_data.use_upscale = resolve_model_to_use(task_data.use_upscale, 'gfpgan') + def make_model_folders(): for model_type in KNOWN_MODEL_TYPES: model_dir_path = os.path.join(app.MODELS_DIR, model_type) diff --git a/ui/sd_internal/runtime2.py b/ui/sd_internal/runtime2.py index 399ba960..1f2e9d27 100644 --- a/ui/sd_internal/runtime2.py +++ b/ui/sd_internal/runtime2.py @@ -1,4 +1,3 @@ -import threading import queue import time import json @@ -8,7 +7,7 @@ import re import traceback import logging -from sd_internal import device_manager, model_manager +from sd_internal import device_manager from sd_internal import TaskData, Response, Image as ResponseImage, UserInitiatedStop from modules import model_loader, image_generator, image_utils, filters as image_filters @@ -33,26 +32,6 @@ def init(device): device_manager.device_init(thread_data, device) -def reload_models_if_necessary(task_data: TaskData): - model_paths_in_req = ( - ('hypernetwork', task_data.use_hypernetwork_model), - ('gfpgan', task_data.use_face_correction), - ('realesrgan', task_data.use_upscale), - ) - - if thread_data.model_paths.get('stable-diffusion') 
!= task_data.use_stable_diffusion_model or thread_data.model_paths.get('vae') != task_data.use_vae_model: - thread_data.model_paths['stable-diffusion'] = task_data.use_stable_diffusion_model - thread_data.model_paths['vae'] = task_data.use_vae_model - - model_loader.load_model(thread_data, 'stable-diffusion') - - for model_type, model_path_in_req in model_paths_in_req: - if thread_data.model_paths.get(model_type) != model_path_in_req: - thread_data.model_paths[model_type] = model_path_in_req - - action_fn = model_loader.unload_model if thread_data.model_paths[model_type] is None else model_loader.load_model - action_fn(thread_data, model_type) - def make_images(req: GenerateImageRequest, task_data: TaskData, data_queue: queue.Queue, task_temp_images: list, step_callback): try: return _make_images_internal(req, task_data, data_queue, task_temp_images, step_callback) @@ -80,14 +59,6 @@ def _make_images_internal(req: GenerateImageRequest, task_data: TaskData, data_q return res -def resolve_model_paths(task_data: TaskData): - task_data.use_stable_diffusion_model = model_manager.resolve_model_to_use(task_data.use_stable_diffusion_model, model_type='stable-diffusion') - task_data.use_vae_model = model_manager.resolve_model_to_use(task_data.use_vae_model, model_type='vae') - task_data.use_hypernetwork_model = model_manager.resolve_model_to_use(task_data.use_hypernetwork_model, model_type='hypernetwork') - - if task_data.use_face_correction: task_data.use_face_correction = model_manager.resolve_model_to_use(task_data.use_face_correction, 'gfpgan') - if task_data.use_upscale: task_data.use_upscale = model_manager.resolve_model_to_use(task_data.use_upscale, 'gfpgan') - def generate_images(req: GenerateImageRequest, data_queue: queue.Queue, task_temp_images: list, step_callback, stream_image_progress: bool): log.info(req.to_metadata()) thread_data.temp_images.clear() diff --git a/ui/sd_internal/task_manager.py b/ui/sd_internal/task_manager.py index c167a5f9..5e6abce2 100644 --- a/ui/sd_internal/task_manager.py +++ b/ui/sd_internal/task_manager.py @@ -277,10 +277,10 @@ def thread_render(device): log.info(f'Session {task.task_data.session_id} sent cancel signal for task {id(task)}') current_state = ServerStates.LoadingModel - runtime2.reload_models_if_necessary(task.task_data) + model_manager.resolve_model_paths(task.task_data) + model_manager.reload_models_if_necessary(runtime2.thread_data, task.task_data) current_state = ServerStates.Rendering - runtime2.resolve_model_paths(task.task_data) task.response = runtime2.make_images(task.render_request, task.task_data, task.buffer_queue, task.temp_images, step_callback) # Before looping back to the generator, mark cache as still alive. 
task_cache.keep(id(task), TASK_TTL) From 1a5b6ef2601f8c097c3147c08c634f7515e3c438 Mon Sep 17 00:00:00 2001 From: cmdr2 Date: Sun, 11 Dec 2022 20:21:25 +0530 Subject: [PATCH 24/74] Rename runtime2.py to renderer.py; Will remove the old runtime soon --- ui/sd_internal/{runtime2.py => renderer.py} | 34 ++++++++++----------- ui/sd_internal/task_manager.py | 32 +++++++++---------- 2 files changed, 33 insertions(+), 33 deletions(-) rename ui/sd_internal/{runtime2.py => renderer.py} (87%) diff --git a/ui/sd_internal/runtime2.py b/ui/sd_internal/renderer.py similarity index 87% rename from ui/sd_internal/runtime2.py rename to ui/sd_internal/renderer.py index 1f2e9d27..645fdd00 100644 --- a/ui/sd_internal/runtime2.py +++ b/ui/sd_internal/renderer.py @@ -15,7 +15,7 @@ from modules.types import Context, GenerateImageRequest log = logging.getLogger() -thread_data = Context() +context = Context() # thread-local ''' runtime data (bound locally to this thread), for e.g. device, references to loaded models, optimization flags etc ''' @@ -24,13 +24,13 @@ filename_regex = re.compile('[^a-zA-Z0-9]') def init(device): ''' - Initializes the fields that will be bound to this runtime's thread_data, and sets the current torch device + Initializes the fields that will be bound to this runtime's context, and sets the current torch device ''' - thread_data.stop_processing = False - thread_data.temp_images = {} - thread_data.partial_x_samples = None + context.stop_processing = False + context.temp_images = {} + context.partial_x_samples = None - device_manager.device_init(thread_data, device) + device_manager.device_init(context, device) def make_images(req: GenerateImageRequest, task_data: TaskData, data_queue: queue.Queue, task_temp_images: list, step_callback): try: @@ -61,21 +61,21 @@ def _make_images_internal(req: GenerateImageRequest, task_data: TaskData, data_q def generate_images(req: GenerateImageRequest, data_queue: queue.Queue, task_temp_images: list, step_callback, stream_image_progress: bool): log.info(req.to_metadata()) - thread_data.temp_images.clear() + context.temp_images.clear() image_generator.on_image_step = make_step_callback(req, data_queue, task_temp_images, step_callback, stream_image_progress) try: - images = image_generator.make_images(context=thread_data, req=req) + images = image_generator.make_images(context=context, req=req) user_stopped = False except UserInitiatedStop: images = [] user_stopped = True - if thread_data.partial_x_samples is not None: - images = image_utils.latent_samples_to_images(thread_data, thread_data.partial_x_samples) - thread_data.partial_x_samples = None + if context.partial_x_samples is not None: + images = image_utils.latent_samples_to_images(context, context.partial_x_samples) + context.partial_x_samples = None finally: - model_loader.gc(thread_data) + model_loader.gc(context) images = [(image, req.seed + i, False) for i, image in enumerate(images)] @@ -92,7 +92,7 @@ def apply_filters(task_data: TaskData, images: list, user_stopped, show_only_fil filtered_images = [] for img, seed, _ in images: for filter_fn in filters: - img = filter_fn(thread_data, img) + img = filter_fn(context, img) filtered_images.append((img, seed, True)) @@ -145,12 +145,12 @@ def make_step_callback(req: GenerateImageRequest, task_data: TaskData, data_queu def update_temp_img(x_samples, task_temp_images: list): partial_images = [] for i in range(req.num_outputs): - img = image_utils.latent_to_img(thread_data, x_samples[i].unsqueeze(0)) + img = image_utils.latent_to_img(context, 
x_samples[i].unsqueeze(0)) buf = image_utils.img_to_buffer(img, output_format='JPEG') del img - thread_data.temp_images[f"{task_data.request_id}/{i}"] = buf + context.temp_images[f"{task_data.request_id}/{i}"] = buf task_temp_images[i] = buf partial_images.append({'path': f"/image/tmp/{task_data.request_id}/{i}"}) return partial_images @@ -158,7 +158,7 @@ def make_step_callback(req: GenerateImageRequest, task_data: TaskData, data_queu def on_image_step(x_samples, i): nonlocal last_callback_time - thread_data.partial_x_samples = x_samples + context.partial_x_samples = x_samples step_time = time.time() - last_callback_time if last_callback_time != -1 else -1 last_callback_time = time.time() @@ -171,7 +171,7 @@ def make_step_callback(req: GenerateImageRequest, task_data: TaskData, data_queu step_callback() - if thread_data.stop_processing: + if context.stop_processing: raise UserInitiatedStop("User requested that we stop processing") return on_image_step diff --git a/ui/sd_internal/task_manager.py b/ui/sd_internal/task_manager.py index 5e6abce2..4810584b 100644 --- a/ui/sd_internal/task_manager.py +++ b/ui/sd_internal/task_manager.py @@ -186,9 +186,9 @@ class SessionState(): return True def thread_get_next_task(): - from sd_internal import runtime2 + from sd_internal import renderer if not manager_lock.acquire(blocking=True, timeout=LOCK_TIMEOUT): - log.warn(f'Render thread on device: {runtime2.thread_data.device} failed to acquire manager lock.') + log.warn(f'Render thread on device: {renderer.context.device} failed to acquire manager lock.') return None if len(tasks_queue) <= 0: manager_lock.release() @@ -196,7 +196,7 @@ def thread_get_next_task(): task = None try: # Select a render task. for queued_task in tasks_queue: - if queued_task.render_device and runtime2.thread_data.device != queued_task.render_device: + if queued_task.render_device and renderer.context.device != queued_task.render_device: # Is asking for a specific render device. if is_alive(queued_task.render_device) > 0: continue # requested device alive, skip current one. @@ -205,7 +205,7 @@ def thread_get_next_task(): queued_task.error = Exception(queued_task.render_device + ' is not currently active.') task = queued_task break - if not queued_task.render_device and runtime2.thread_data.device == 'cpu' and is_alive() > 1: + if not queued_task.render_device and renderer.context.device == 'cpu' and is_alive() > 1: # not asking for any specific devices, cpu want to grab task but other render devices are alive. continue # Skip Tasks, don't run on CPU unless there is nothing else or user asked for it. 
task = queued_task @@ -219,9 +219,9 @@ def thread_get_next_task(): def thread_render(device): global current_state, current_state_error - from sd_internal import runtime2, model_manager + from sd_internal import renderer, model_manager try: - runtime2.init(device) + renderer.init(device) except Exception as e: log.error(traceback.format_exc()) weak_thread_data[threading.current_thread()] = { @@ -230,20 +230,20 @@ def thread_render(device): return weak_thread_data[threading.current_thread()] = { - 'device': runtime2.thread_data.device, - 'device_name': runtime2.thread_data.device_name, + 'device': renderer.context.device, + 'device_name': renderer.context.device_name, 'alive': True } - model_manager.load_default_models(runtime2.thread_data) + model_manager.load_default_models(renderer.context) current_state = ServerStates.Online while True: session_cache.clean() task_cache.clean() if not weak_thread_data[threading.current_thread()]['alive']: - log.info(f'Shutting down thread for device {runtime2.thread_data.device}') - model_manager.unload_all(runtime2.thread_data) + log.info(f'Shutting down thread for device {renderer.context.device}') + model_manager.unload_all(renderer.context) return if isinstance(current_state_error, SystemExit): current_state = ServerStates.Unavailable @@ -263,14 +263,14 @@ def thread_render(device): task.response = {"status": 'failed', "detail": str(task.error)} task.buffer_queue.put(json.dumps(task.response)) continue - log.info(f'Session {task.task_data.session_id} starting task {id(task)} on {runtime2.thread_data.device_name}') + log.info(f'Session {task.task_data.session_id} starting task {id(task)} on {renderer.context.device_name}') if not task.lock.acquire(blocking=False): raise Exception('Got locked task from queue.') try: def step_callback(): global current_state_error if isinstance(current_state_error, SystemExit) or isinstance(current_state_error, StopAsyncIteration) or isinstance(task.error, StopAsyncIteration): - runtime2.thread_data.stop_processing = True + renderer.context.stop_processing = True if isinstance(current_state_error, StopAsyncIteration): task.error = current_state_error current_state_error = None @@ -278,10 +278,10 @@ def thread_render(device): current_state = ServerStates.LoadingModel model_manager.resolve_model_paths(task.task_data) - model_manager.reload_models_if_necessary(runtime2.thread_data, task.task_data) + model_manager.reload_models_if_necessary(renderer.context, task.task_data) current_state = ServerStates.Rendering - task.response = runtime2.make_images(task.render_request, task.task_data, task.buffer_queue, task.temp_images, step_callback) + task.response = renderer.make_images(task.render_request, task.task_data, task.buffer_queue, task.temp_images, step_callback) # Before looping back to the generator, mark cache as still alive. 
task_cache.keep(id(task), TASK_TTL) session_cache.keep(task.task_data.session_id, TASK_TTL) @@ -299,7 +299,7 @@ def thread_render(device): elif task.error is not None: log.info(f'Session {task.task_data.session_id} task {id(task)} failed!') else: - log.info(f'Session {task.task_data.session_id} task {id(task)} completed by {runtime2.thread_data.device_name}.') + log.info(f'Session {task.task_data.session_id} task {id(task)} completed by {renderer.context.device_name}.') current_state = ServerStates.Online def get_cached_task(task_id:str, update_ttl:bool=False): From e45cbbf1cab201bafa4c9985b29359325c520ae1 Mon Sep 17 00:00:00 2001 From: cmdr2 Date: Sun, 11 Dec 2022 20:42:31 +0530 Subject: [PATCH 25/74] Use the turbo setting if requested --- ui/sd_internal/model_manager.py | 9 +++++++++ ui/sd_internal/task_manager.py | 1 + 2 files changed, 10 insertions(+) diff --git a/ui/sd_internal/model_manager.py b/ui/sd_internal/model_manager.py index ea624322..0cc9414c 100644 --- a/ui/sd_internal/model_manager.py +++ b/ui/sd_internal/model_manager.py @@ -37,6 +37,9 @@ def load_default_models(context: Context): for model_type in KNOWN_MODEL_TYPES: context.model_paths[model_type] = resolve_model_to_use(model_type=model_type) + # disable TURBO initially (this should be read from the config eventually) + context.vram_optimizations -= {'TURBO'} + # load mandatory models model_loader.load_model(context, 'stable-diffusion') @@ -116,6 +119,12 @@ def resolve_model_paths(task_data: TaskData): if task_data.use_face_correction: task_data.use_face_correction = resolve_model_to_use(task_data.use_face_correction, 'gfpgan') if task_data.use_upscale: task_data.use_upscale = resolve_model_to_use(task_data.use_upscale, 'gfpgan') +def set_vram_optimizations(context: Context, task_data: TaskData): + if task_data.turbo: + context.vram_optimizations += {'TURBO'} + else: + context.vram_optimizations -= {'TURBO'} + def make_model_folders(): for model_type in KNOWN_MODEL_TYPES: model_dir_path = os.path.join(app.MODELS_DIR, model_type) diff --git a/ui/sd_internal/task_manager.py b/ui/sd_internal/task_manager.py index 4810584b..a6d4dadd 100644 --- a/ui/sd_internal/task_manager.py +++ b/ui/sd_internal/task_manager.py @@ -278,6 +278,7 @@ def thread_render(device): current_state = ServerStates.LoadingModel model_manager.resolve_model_paths(task.task_data) + model_manager.set_vram_optimizations(renderer.context, task.task_data) model_manager.reload_models_if_necessary(renderer.context, task.task_data) current_state = ServerStates.Rendering From b57649828dddeeea14de2b4e25090d93b0673087 Mon Sep 17 00:00:00 2001 From: cmdr2 Date: Mon, 12 Dec 2022 14:01:47 +0530 Subject: [PATCH 26/74] Refactor the save-to-disk code, moving parts of it to diffusionkit --- ui/sd_internal/__init__.py | 3 + ui/sd_internal/renderer.py | 110 +++++++++++++++++++++---------------- 2 files changed, 67 insertions(+), 46 deletions(-) diff --git a/ui/sd_internal/__init__.py b/ui/sd_internal/__init__.py index a06220ca..f0710652 100644 --- a/ui/sd_internal/__init__.py +++ b/ui/sd_internal/__init__.py @@ -7,14 +7,17 @@ class TaskData(BaseModel): session_id: str = "session" save_to_disk_path: str = None turbo: bool = True + use_face_correction: str = None # or "GFPGANv1.3" use_upscale: str = None # or "RealESRGAN_x4plus" or "RealESRGAN_x4plus_anime_6B" use_stable_diffusion_model: str = "sd-v1-4" use_vae_model: str = None use_hypernetwork_model: str = None + show_only_filtered_image: bool = False output_format: str = "jpeg" # or "png" output_quality: int = 75 + 
metadata_output_format: str = "txt" # or "json" stream_image_progress: bool = False class Image: diff --git a/ui/sd_internal/renderer.py b/ui/sd_internal/renderer.py index 645fdd00..440f2bf2 100644 --- a/ui/sd_internal/renderer.py +++ b/ui/sd_internal/renderer.py @@ -10,7 +10,7 @@ import logging from sd_internal import device_manager from sd_internal import TaskData, Response, Image as ResponseImage, UserInitiatedStop -from modules import model_loader, image_generator, image_utils, filters as image_filters +from modules import model_loader, image_generator, image_utils, filters as image_filters, data_utils from modules.types import Context, GenerateImageRequest log = logging.getLogger() @@ -33,8 +33,18 @@ def init(device): device_manager.device_init(context, device) def make_images(req: GenerateImageRequest, task_data: TaskData, data_queue: queue.Queue, task_temp_images: list, step_callback): + log.info(f'request: {req.dict()}') + log.info(f'task data: {task_data.dict()}') + try: - return _make_images_internal(req, task_data, data_queue, task_temp_images, step_callback) + images = _make_images_internal(req, task_data, data_queue, task_temp_images, step_callback) + + res = Response(req, task_data, images=construct_response(images, task_data, base_seed=req.seed)) + res = res.json() + data_queue.put(json.dumps(res)) + log.info('Task completed') + + return res except Exception as e: log.error(traceback.format_exc()) @@ -46,21 +56,15 @@ def make_images(req: GenerateImageRequest, task_data: TaskData, data_queue: queu def _make_images_internal(req: GenerateImageRequest, task_data: TaskData, data_queue: queue.Queue, task_temp_images: list, step_callback): images, user_stopped = generate_images(req, data_queue, task_temp_images, step_callback, task_data.stream_image_progress) - images = apply_filters(task_data, images, user_stopped, task_data.show_only_filtered_image) + filtered_images = apply_filters(task_data, images, user_stopped) if task_data.save_to_disk_path is not None: - out_path = os.path.join(task_data.save_to_disk_path, filename_regex.sub('_', task_data.session_id)) - save_images(images, out_path, metadata=req.to_metadata(), show_only_filtered_image=task_data.show_only_filtered_image) + save_folder_path = os.path.join(task_data.save_to_disk_path, filename_regex.sub('_', task_data.session_id)) + save_to_disk(images, filtered_images, save_folder_path, req, task_data) - res = Response(req, task_data, images=construct_response(images)) - res = res.json() - data_queue.put(json.dumps(res)) - log.info('Task completed') - - return res + return filtered_images if task_data.show_only_filtered_image else images + filtered_images def generate_images(req: GenerateImageRequest, data_queue: queue.Queue, task_temp_images: list, step_callback, stream_image_progress: bool): - log.info(req.to_metadata()) context.temp_images.clear() image_generator.on_image_step = make_step_callback(req, data_queue, task_temp_images, step_callback, stream_image_progress) @@ -77,11 +81,9 @@ def generate_images(req: GenerateImageRequest, data_queue: queue.Queue, task_tem finally: model_loader.gc(context) - images = [(image, req.seed + i, False) for i, image in enumerate(images)] - return images, user_stopped -def apply_filters(task_data: TaskData, images: list, user_stopped, show_only_filtered_image): +def apply_filters(task_data: TaskData, images: list, user_stopped): if user_stopped or (task_data.use_face_correction is None and task_data.use_upscale is None): return images @@ -90,52 +92,68 @@ def apply_filters(task_data: 
TaskData, images: list, user_stopped, show_only_fil if 'realesrgan' in task_data.use_face_correction.lower(): filters.append(image_filters.apply_realesrgan) filtered_images = [] - for img, seed, _ in images: + for img in images: for filter_fn in filters: img = filter_fn(context, img) - filtered_images.append((img, seed, True)) - - if not show_only_filtered_image: - filtered_images = images + filtered_images + filtered_images.append(img) return filtered_images -def save_images(images: list, save_to_disk_path, metadata: dict, show_only_filtered_image): - if save_to_disk_path is None: - return +def save_to_disk(images: list, filtered_images: list, save_folder_path, req: GenerateImageRequest, task_data: TaskData): + metadata = req.dict() + del metadata['init_image'] + del metadata['init_image_mask'] + metadata.update({ + 'use_stable_diffusion_model': task_data.use_stable_diffusion_model, + 'use_vae_model': task_data.use_vae_model, + 'use_hypernetwork_model': task_data.use_hypernetwork_model, + 'use_face_correction': task_data.use_face_correction, + 'use_upscale': task_data.use_upscale, + }) - def get_image_id(i): + metadata_entries = get_metadata_entries(req, task_data) + + if task_data.show_only_filtered_image or filtered_images == images: + data_utils.save_images(filtered_images, save_folder_path, file_name=get_output_filename_callback(req), output_format=task_data.output_format, output_quality=task_data.output_quality) + data_utils.save_metadata(metadata_entries, save_folder_path, file_name=get_output_filename_callback(req), output_format=task_data.metadata_output_format) + else: + data_utils.save_images(images, save_folder_path, file_name=get_output_filename_callback(req), output_format=task_data.output_format, output_quality=task_data.output_quality) + data_utils.save_images(filtered_images, save_folder_path, file_name=get_output_filename_callback(req, suffix='filtered'), output_format=task_data.output_format, output_quality=task_data.output_quality) + data_utils.save_metadata(metadata_entries, save_folder_path, file_name=get_output_filename_callback(req, suffix='filtered'), output_format=task_data.metadata_output_format) + +def get_metadata_entries(req: GenerateImageRequest, task_data: TaskData): + metadata = req.dict() + del metadata['init_image'] + del metadata['init_image_mask'] + metadata.update({ + 'use_stable_diffusion_model': task_data.use_stable_diffusion_model, + 'use_vae_model': task_data.use_vae_model, + 'use_hypernetwork_model': task_data.use_hypernetwork_model, + 'use_face_correction': task_data.use_face_correction, + 'use_upscale': task_data.use_upscale, + }) + + return [metadata.copy().update({'seed': req.seed + i}) for i in range(req.num_outputs)] + +def get_output_filename_callback(req: GenerateImageRequest, suffix=None): + def make_filename(i): img_id = base64.b64encode(int(time.time()+i).to_bytes(8, 'big')).decode() # Generate unique ID based on time. img_id = img_id.translate({43:None, 47:None, 61:None})[-8:] # Remove + / = and keep last 8 chars. 
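# Two slips are visible in the apply_filters code above: both checks read
# task_data.use_face_correction, so a request that only sets use_upscale
# crashes on None.lower(), and the RealESRGAN branch can never be reached via
# use_upscale. (resolve_model_paths has a related slip, resolving use_upscale
# with model_type='gfpgan'.) A corrected sketch of the selection logic, using
# plain strings so it stays self-contained:
def pick_filters(use_face_correction, use_upscale):
    filters = []
    if use_face_correction and 'gfpgan' in use_face_correction.lower():
        filters.append('gfpgan')
    if use_upscale and 'realesrgan' in use_upscale.lower():
        filters.append('realesrgan')
    return filters

assert pick_filters(None, 'RealESRGAN_x4plus') == ['realesrgan']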
-        return img_id
-    def get_image_basepath(i):
-        os.makedirs(save_to_disk_path, exist_ok=True)
-        prompt_flattened = filename_regex.sub('_', metadata['prompt'])[:50]
-        return os.path.join(save_to_disk_path, f"{prompt_flattened}_{get_image_id(i)}")
+        prompt_flattened = filename_regex.sub('_', req.prompt)[:50]
+        name = f"{prompt_flattened}_{img_id}"
+        name = name if suffix is None else f'{name}_{suffix}'
+        return name
 
-    for i, img_data in enumerate(images):
-        img, seed, filtered = img_data
-        img_path = get_image_basepath(i)
+    return make_filename
 
-        if not filtered or show_only_filtered_image:
-            img_metadata_path = img_path + '.txt'
-            m = metadata.copy()
-            m['seed'] = seed
-            with open(img_metadata_path, 'w', encoding='utf-8') as f:
-                f.write(m)
-
-        img_path += '_filtered' if filtered else ''
-        img_path += '.' + metadata['output_format']
-        img.save(img_path, quality=metadata['output_quality'])
 
-def construct_response(task_data: TaskData, images: list):
+def construct_response(images: list, task_data: TaskData, base_seed: int):
     return [
         ResponseImage(
             data=image_utils.img_to_base64_str(img, task_data.output_format, task_data.output_quality),
             seed=base_seed + i
         ) for i, img in enumerate(images)
     ]
 
 def make_step_callback(req: GenerateImageRequest, task_data: TaskData, data_queue: queue.Queue, task_temp_images: list, step_callback, stream_image_progress: bool):

From d0e50584ea54f37961a3ad1851f3cdf9f435ef99 Mon Sep 17 00:00:00 2001
From: cmdr2
Date: Mon, 12 Dec 2022 14:06:20 +0530
Subject: [PATCH 27/74] Expose the metadata format option in the UI

---
 ui/media/js/auto-save.js  |  1 +
 ui/media/js/parameters.js | 13 +++++++++++++
 2 files changed, 14 insertions(+)

diff --git a/ui/media/js/auto-save.js b/ui/media/js/auto-save.js
index df044e2b..44903724 100644
--- a/ui/media/js/auto-save.js
+++ b/ui/media/js/auto-save.js
@@ -38,6 +38,7 @@ const SETTINGS_IDS_LIST = [
     "sound_toggle",
     "turbo",
     "confirm_dangerous_actions",
+    "metadata_output_format",
     "auto_save_settings",
     "apply_color_correction"
 ]
diff --git a/ui/media/js/parameters.js b/ui/media/js/parameters.js
index f326953b..a4167fc3 100644
--- a/ui/media/js/parameters.js
+++ b/ui/media/js/parameters.js
@@ -53,6 +53,19 @@ var PARAMETERS = [
             return ``
         }
     },
+    {
+        id: "metadata_output_format",
+        type: ParameterType.select,
+        label: "Metadata format",
+        note: "the metadata will be saved to disk in this format",
+        default: "txt",
+        options: [
+            {
+                value: "txt",
+                label: "txt"
+            }
+        ],
+    },
     {
         id: "sound_toggle",
         type: ParameterType.checkbox,

From 4bbf683d155489a8d757db4555726ddad2be0a89 Mon Sep 17 00:00:00 2001
From: cmdr2
Date: Mon, 12 Dec 2022 14:41:36 +0530
Subject: [PATCH 28/74] Minor refactor

---
 ui/sd_internal/renderer.py | 39 ++++++++++++++----------------------
 1 file changed, 14 insertions(+), 25 deletions(-)

diff --git a/ui/sd_internal/renderer.py b/ui/sd_internal/renderer.py
index 440f2bf2..2bf079f5 100644
--- a/ui/sd_internal/renderer.py
+++ b/ui/sd_internal/renderer.py
@@ -101,26 +101,23 @@ def save_to_disk(images: list, filtered_images: list, save_folder_path, req: GenerateImageRequest, task_data: TaskData):
-    metadata = req.dict()
-    del metadata['init_image']
-    del metadata['init_image_mask']
-    metadata.update({
-        'use_stable_diffusion_model': task_data.use_stable_diffusion_model,
-        'use_vae_model': task_data.use_vae_model,
-        'use_hypernetwork_model': task_data.use_hypernetwork_model,
-
'use_face_correction': task_data.use_face_correction, - 'use_upscale': task_data.use_upscale, - }) - metadata_entries = get_metadata_entries(req, task_data) if task_data.show_only_filtered_image or filtered_images == images: - data_utils.save_images(filtered_images, save_folder_path, file_name=get_output_filename_callback(req), output_format=task_data.output_format, output_quality=task_data.output_quality) - data_utils.save_metadata(metadata_entries, save_folder_path, file_name=get_output_filename_callback(req), output_format=task_data.metadata_output_format) + data_utils.save_images(filtered_images, save_folder_path, file_name=make_filename_callback(req), output_format=task_data.output_format, output_quality=task_data.output_quality) + data_utils.save_metadata(metadata_entries, save_folder_path, file_name=make_filename_callback(req), output_format=task_data.metadata_output_format) else: - data_utils.save_images(images, save_folder_path, file_name=get_output_filename_callback(req), output_format=task_data.output_format, output_quality=task_data.output_quality) - data_utils.save_images(filtered_images, save_folder_path, file_name=get_output_filename_callback(req, suffix='filtered'), output_format=task_data.output_format, output_quality=task_data.output_quality) - data_utils.save_metadata(metadata_entries, save_folder_path, file_name=get_output_filename_callback(req, suffix='filtered'), output_format=task_data.metadata_output_format) + data_utils.save_images(images, save_folder_path, file_name=make_filename_callback(req), output_format=task_data.output_format, output_quality=task_data.output_quality) + data_utils.save_images(filtered_images, save_folder_path, file_name=make_filename_callback(req, suffix='filtered'), output_format=task_data.output_format, output_quality=task_data.output_quality) + data_utils.save_metadata(metadata_entries, save_folder_path, file_name=make_filename_callback(req, suffix='filtered'), output_format=task_data.metadata_output_format) + +def construct_response(images: list, task_data: TaskData, base_seed: int): + return [ + ResponseImage( + data=image_utils.img_to_base64_str(img, task_data.output_format, task_data.output_quality), + seed=base_seed + i + ) for i, img in enumerate(images) + ] def get_metadata_entries(req: GenerateImageRequest, task_data: TaskData): metadata = req.dict() @@ -136,7 +133,7 @@ def get_metadata_entries(req: GenerateImageRequest, task_data: TaskData): return [metadata.copy().update({'seed': req.seed + i}) for i in range(req.num_outputs)] -def get_output_filename_callback(req: GenerateImageRequest, suffix=None): +def make_filename_callback(req: GenerateImageRequest, suffix=None): def make_filename(i): img_id = base64.b64encode(int(time.time()+i).to_bytes(8, 'big')).decode() # Generate unique ID based on time. img_id = img_id.translate({43:None, 47:None, 61:None})[-8:] # Remove + / = and keep last 8 chars. 
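One real bug is visible in the get_metadata_entries lines above: metadata.copy().update({'seed': req.seed + i}) evaluates to None, because dict.update mutates its receiver and returns None, so the function hands save_metadata a list of Nones. A corrected sketch, keeping the surrounding names:

    def get_metadata_entries(req, task_data):
        metadata = req.dict()
        del metadata['init_image']
        del metadata['init_image_mask']
        metadata.update({
            'use_stable_diffusion_model': task_data.use_stable_diffusion_model,
            'use_vae_model': task_data.use_vae_model,
            'use_hypernetwork_model': task_data.use_hypernetwork_model,
            'use_face_correction': task_data.use_face_correction,
            'use_upscale': task_data.use_upscale,
        })
        # Copy per output image, setting the per-image seed explicitly.
        return [{**metadata, 'seed': req.seed + i} for i in range(req.num_outputs)]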
@@ -148,14 +145,6 @@ def get_output_filename_callback(req: GenerateImageRequest, suffix=None): return make_filename -def construct_response(images: list, task_data: TaskData, base_seed: int): - return [ - ResponseImage( - data=image_utils.img_to_base64_str(img, task_data.output_format, task_data.output_quality), - seed=base_seed + i - ) for i, img in enumerate(images) - ] - def make_step_callback(req: GenerateImageRequest, task_data: TaskData, data_queue: queue.Queue, task_temp_images: list, step_callback, stream_image_progress: bool): n_steps = req.num_inference_steps if req.init_image is None else int(req.num_inference_steps * req.prompt_strength) last_callback_time = -1 From 6b943f88d19d1fa99abd71bb737379aedf523523 Mon Sep 17 00:00:00 2001 From: cmdr2 Date: Mon, 12 Dec 2022 15:18:30 +0530 Subject: [PATCH 29/74] Set uvicorn log level to 'error' --- scripts/on_sd_start.bat | 2 +- scripts/on_sd_start.sh | 2 +- 2 files changed, 2 insertions(+), 2 deletions(-) diff --git a/scripts/on_sd_start.bat b/scripts/on_sd_start.bat index 5f1bfcb1..620d0909 100644 --- a/scripts/on_sd_start.bat +++ b/scripts/on_sd_start.bat @@ -393,7 +393,7 @@ call python --version @if NOT DEFINED SD_UI_BIND_PORT set SD_UI_BIND_PORT=9000 @if NOT DEFINED SD_UI_BIND_IP set SD_UI_BIND_IP=0.0.0.0 -@uvicorn server:server_api --app-dir "%SD_UI_PATH%" --port %SD_UI_BIND_PORT% --host %SD_UI_BIND_IP% --log-level critical +@uvicorn server:server_api --app-dir "%SD_UI_PATH%" --port %SD_UI_BIND_PORT% --host %SD_UI_BIND_IP% --log-level error @pause diff --git a/scripts/on_sd_start.sh b/scripts/on_sd_start.sh index f4f540e7..3e082b81 100755 --- a/scripts/on_sd_start.sh +++ b/scripts/on_sd_start.sh @@ -322,6 +322,6 @@ cd .. export SD_UI_PATH=`pwd`/ui cd stable-diffusion -uvicorn server:server_api --app-dir "$SD_UI_PATH" --port ${SD_UI_BIND_PORT:-9000} --host ${SD_UI_BIND_IP:-0.0.0.0} --log-level critical +uvicorn server:server_api --app-dir "$SD_UI_PATH" --port ${SD_UI_BIND_PORT:-9000} --host ${SD_UI_BIND_IP:-0.0.0.0} --log-level error read -p "Press any key to continue" From ac0961d7d4dca73a8a90f06e91f4b283214f0402 Mon Sep 17 00:00:00 2001 From: cmdr2 Date: Mon, 12 Dec 2022 15:18:56 +0530 Subject: [PATCH 30/74] Typos from the refactor --- ui/sd_internal/model_manager.py | 4 ++-- ui/sd_internal/renderer.py | 6 +++--- ui/sd_internal/task_manager.py | 6 ++++-- 3 files changed, 9 insertions(+), 7 deletions(-) diff --git a/ui/sd_internal/model_manager.py b/ui/sd_internal/model_manager.py index 0cc9414c..f5e8a5c7 100644 --- a/ui/sd_internal/model_manager.py +++ b/ui/sd_internal/model_manager.py @@ -121,9 +121,9 @@ def resolve_model_paths(task_data: TaskData): def set_vram_optimizations(context: Context, task_data: TaskData): if task_data.turbo: - context.vram_optimizations += {'TURBO'} + context.vram_optimizations.add('TURBO') else: - context.vram_optimizations -= {'TURBO'} + context.vram_optimizations.remove('TURBO') def make_model_folders(): for model_type in KNOWN_MODEL_TYPES: diff --git a/ui/sd_internal/renderer.py b/ui/sd_internal/renderer.py index 2bf079f5..aa805ad9 100644 --- a/ui/sd_internal/renderer.py +++ b/ui/sd_internal/renderer.py @@ -55,7 +55,7 @@ def make_images(req: GenerateImageRequest, task_data: TaskData, data_queue: queu raise e def _make_images_internal(req: GenerateImageRequest, task_data: TaskData, data_queue: queue.Queue, task_temp_images: list, step_callback): - images, user_stopped = generate_images(req, data_queue, task_temp_images, step_callback, task_data.stream_image_progress) + images, user_stopped = 
generate_images(req, task_data, data_queue, task_temp_images, step_callback, task_data.stream_image_progress) filtered_images = apply_filters(task_data, images, user_stopped) if task_data.save_to_disk_path is not None: @@ -64,10 +64,10 @@ def _make_images_internal(req: GenerateImageRequest, task_data: TaskData, data_q return filtered_images if task_data.show_only_filtered_image else images + filtered_images -def generate_images(req: GenerateImageRequest, data_queue: queue.Queue, task_temp_images: list, step_callback, stream_image_progress: bool): +def generate_images(req: GenerateImageRequest, task_data: TaskData, data_queue: queue.Queue, task_temp_images: list, step_callback, stream_image_progress: bool): context.temp_images.clear() - image_generator.on_image_step = make_step_callback(req, data_queue, task_temp_images, step_callback, stream_image_progress) + image_generator.on_image_step = make_step_callback(req, task_data, data_queue, task_temp_images, step_callback, stream_image_progress) try: images = image_generator.make_images(context=context, req=req) diff --git a/ui/sd_internal/task_manager.py b/ui/sd_internal/task_manager.py index a6d4dadd..ef2a4a3b 100644 --- a/ui/sd_internal/task_manager.py +++ b/ui/sd_internal/task_manager.py @@ -40,12 +40,12 @@ class ServerStates: class RenderTask(): # Task with output queue and completion lock. def __init__(self, req: GenerateImageRequest, task_data: TaskData): - req.request_id = id(self) + task_data.request_id = id(self) self.render_request: GenerateImageRequest = req # Initial Request self.task_data: TaskData = task_data self.response: Any = None # Copy of the last reponse self.render_device = None # Select the task affinity. (Not used to change active devices). - self.temp_images:list = [None] * req.num_outputs * (1 if req.show_only_filtered_image else 2) + self.temp_images:list = [None] * req.num_outputs * (1 if task_data.show_only_filtered_image else 2) self.error: Exception = None self.lock: threading.Lock = threading.Lock() # Locks at task start and unlocks when task is completed self.buffer_queue: queue.Queue = queue.Queue() # Queue of JSON string segments @@ -235,7 +235,9 @@ def thread_render(device): 'alive': True } + current_state = ServerStates.LoadingModel model_manager.load_default_models(renderer.context) + current_state = ServerStates.Online while True: From fb32a38d9638fc0bf8868b2849e4c19de36408eb Mon Sep 17 00:00:00 2001 From: cmdr2 Date: Mon, 12 Dec 2022 15:21:02 +0530 Subject: [PATCH 31/74] Rename sampler to sampler_name in the API --- ui/index.html | 4 ++-- ui/media/js/auto-save.js | 2 +- ui/media/js/main.js | 8 ++++---- 3 files changed, 7 insertions(+), 7 deletions(-) diff --git a/ui/index.html b/ui/index.html index 0273b574..612759c0 100644 --- a/ui/index.html +++ b/ui/index.html @@ -137,8 +137,8 @@
      - - diff --git a/ui/media/js/auto-save.js b/ui/media/js/auto-save.js index 44903724..c5826087 100644 --- a/ui/media/js/auto-save.js +++ b/ui/media/js/auto-save.js @@ -15,7 +15,7 @@ const SETTINGS_IDS_LIST = [ "stable_diffusion_model", "vae_model", "hypernetwork_model", - "sampler", + "sampler_name", "width", "height", "num_inference_steps", diff --git a/ui/media/js/main.js b/ui/media/js/main.js index 6e00f8c0..8d7d9ce3 100644 --- a/ui/media/js/main.js +++ b/ui/media/js/main.js @@ -30,7 +30,7 @@ let applyColorCorrectionField = document.querySelector('#apply_color_correction' let colorCorrectionSetting = document.querySelector('#apply_color_correction_setting') let promptStrengthSlider = document.querySelector('#prompt_strength_slider') let promptStrengthField = document.querySelector('#prompt_strength') -let samplerField = document.querySelector('#sampler') +let samplerField = document.querySelector('#sampler_name') let samplerSelectionContainer = document.querySelector("#samplerSelection") let useFaceCorrectionField = document.querySelector("#use_face_correction") let useUpscalingField = document.querySelector("#use_upscale") @@ -741,7 +741,7 @@ function onTaskStart(task) { } function createTask(task) { - let taskConfig = `Seed: ${task.seed}, Sampler: ${task.reqBody.sampler}, Inference Steps: ${task.reqBody.num_inference_steps}, Guidance Scale: ${task.reqBody.guidance_scale}, Model: ${task.reqBody.use_stable_diffusion_model}` + let taskConfig = `Seed: ${task.seed}, Sampler: ${task.reqBody.sampler_name}, Inference Steps: ${task.reqBody.num_inference_steps}, Guidance Scale: ${task.reqBody.guidance_scale}, Model: ${task.reqBody.use_stable_diffusion_model}` if (task.reqBody.use_vae_model.trim() !== '') { taskConfig += `, VAE: ${task.reqBody.use_vae_model}` } @@ -872,9 +872,9 @@ function getCurrentUserRequest() { newTask.reqBody.mask = imageInpainter.getImg() } newTask.reqBody.preserve_init_image_color_profile = applyColorCorrectionField.checked - newTask.reqBody.sampler = 'ddim' + newTask.reqBody.sampler_name = 'ddim' } else { - newTask.reqBody.sampler = samplerField.value + newTask.reqBody.sampler_name = samplerField.value } if (saveToDiskField.checked && diskPathField.value.trim() !== '') { newTask.reqBody.save_to_disk_path = diskPathField.value.trim() From 07bd580050b9113e061a5c473f14613cc810dc52 Mon Sep 17 00:00:00 2001 From: cmdr2 Date: Mon, 12 Dec 2022 15:44:22 +0530 Subject: [PATCH 32/74] Typos --- ui/sd_internal/__init__.py | 3 +++ ui/sd_internal/renderer.py | 12 ++++++++---- 2 files changed, 11 insertions(+), 4 deletions(-) diff --git a/ui/sd_internal/__init__.py b/ui/sd_internal/__init__.py index f0710652..7093562a 100644 --- a/ui/sd_internal/__init__.py +++ b/ui/sd_internal/__init__.py @@ -48,6 +48,9 @@ class Response: self.images = images def json(self): + del self.render_request.init_image + del self.render_request.init_image_mask + res = { "status": 'succeeded', "render_request": self.render_request.dict(), diff --git a/ui/sd_internal/renderer.py b/ui/sd_internal/renderer.py index aa805ad9..162b7da0 100644 --- a/ui/sd_internal/renderer.py +++ b/ui/sd_internal/renderer.py @@ -33,7 +33,7 @@ def init(device): device_manager.device_init(context, device) def make_images(req: GenerateImageRequest, task_data: TaskData, data_queue: queue.Queue, task_temp_images: list, step_callback): - log.info(f'request: {req.dict()}') + log.info(f'request: {get_printable_request(req)}') log.info(f'task data: {task_data.dict()}') try: @@ -120,9 +120,7 @@ def construct_response(images: list, 
task_data: TaskData, base_seed: int): ] def get_metadata_entries(req: GenerateImageRequest, task_data: TaskData): - metadata = req.dict() - del metadata['init_image'] - del metadata['init_image_mask'] + metadata = get_printable_request(req) metadata.update({ 'use_stable_diffusion_model': task_data.use_stable_diffusion_model, 'use_vae_model': task_data.use_vae_model, @@ -133,6 +131,12 @@ def get_metadata_entries(req: GenerateImageRequest, task_data: TaskData): return [metadata.copy().update({'seed': req.seed + i}) for i in range(req.num_outputs)] +def get_printable_request(req: GenerateImageRequest): + metadata = req.dict() + del metadata['init_image'] + del metadata['init_image_mask'] + return metadata + def make_filename_callback(req: GenerateImageRequest, suffix=None): def make_filename(i): img_id = base64.b64encode(int(time.time()+i).to_bytes(8, 'big')).decode() # Generate unique ID based on time. From 27963decc9c609b56d5d6babdc646753cf5b434c Mon Sep 17 00:00:00 2001 From: cmdr2 Date: Mon, 12 Dec 2022 18:12:55 +0530 Subject: [PATCH 33/74] Use the multi-filters API --- ui/sd_internal/renderer.py | 13 +++++++------ 1 file changed, 7 insertions(+), 6 deletions(-) diff --git a/ui/sd_internal/renderer.py b/ui/sd_internal/renderer.py index 162b7da0..11a38da8 100644 --- a/ui/sd_internal/renderer.py +++ b/ui/sd_internal/renderer.py @@ -11,7 +11,7 @@ from sd_internal import device_manager from sd_internal import TaskData, Response, Image as ResponseImage, UserInitiatedStop from modules import model_loader, image_generator, image_utils, filters as image_filters, data_utils -from modules.types import Context, GenerateImageRequest +from modules.types import Context, GenerateImageRequest, FilterImageRequest log = logging.getLogger() @@ -88,15 +88,16 @@ def apply_filters(task_data: TaskData, images: list, user_stopped): return images filters = [] - if 'gfpgan' in task_data.use_face_correction.lower(): filters.append(image_filters.apply_gfpgan) - if 'realesrgan' in task_data.use_face_correction.lower(): filters.append(image_filters.apply_realesrgan) + if 'gfpgan' in task_data.use_face_correction.lower(): filters.append('gfpgan') + if 'realesrgan' in task_data.use_face_correction.lower(): filters.append('realesrgan') filtered_images = [] for img in images: - for filter_fn in filters: - img = filter_fn(context, img) + filter_req = FilterImageRequest() + filter_req.init_image = img - filtered_images.append(img) + filtered_image = image_filters.apply(context, filters, filter_req) + filtered_images.append(filtered_image) return filtered_images From a244a6873a6026d2157cda2298f196edbe2c8565 Mon Sep 17 00:00:00 2001 From: cmdr2 Date: Mon, 12 Dec 2022 20:46:11 +0530 Subject: [PATCH 34/74] Use the new 'diffusionkit' package name --- ui/sd_internal/__init__.py | 2 +- ui/sd_internal/model_manager.py | 4 ++-- ui/sd_internal/renderer.py | 4 ++-- ui/sd_internal/task_manager.py | 2 +- ui/server.py | 2 +- 5 files changed, 7 insertions(+), 7 deletions(-) diff --git a/ui/sd_internal/__init__.py b/ui/sd_internal/__init__.py index 7093562a..71073216 100644 --- a/ui/sd_internal/__init__.py +++ b/ui/sd_internal/__init__.py @@ -1,6 +1,6 @@ from pydantic import BaseModel -from modules.types import GenerateImageRequest +from diffusionkit.types import GenerateImageRequest class TaskData(BaseModel): request_id: str = None diff --git a/ui/sd_internal/model_manager.py b/ui/sd_internal/model_manager.py index f5e8a5c7..18418624 100644 --- a/ui/sd_internal/model_manager.py +++ b/ui/sd_internal/model_manager.py @@ -4,8 +4,8 @@ import 
picklescan.scanner import rich from sd_internal import app, TaskData -from modules import model_loader -from modules.types import Context +from diffusionkit import model_loader +from diffusionkit.types import Context log = logging.getLogger() diff --git a/ui/sd_internal/renderer.py b/ui/sd_internal/renderer.py index 11a38da8..1fff48e1 100644 --- a/ui/sd_internal/renderer.py +++ b/ui/sd_internal/renderer.py @@ -10,8 +10,8 @@ import logging from sd_internal import device_manager from sd_internal import TaskData, Response, Image as ResponseImage, UserInitiatedStop -from modules import model_loader, image_generator, image_utils, filters as image_filters, data_utils -from modules.types import Context, GenerateImageRequest, FilterImageRequest +from diffusionkit import model_loader, image_generator, image_utils, filters as image_filters, data_utils +from diffusionkit.types import Context, GenerateImageRequest, FilterImageRequest log = logging.getLogger() diff --git a/ui/sd_internal/task_manager.py b/ui/sd_internal/task_manager.py index ef2a4a3b..f4576372 100644 --- a/ui/sd_internal/task_manager.py +++ b/ui/sd_internal/task_manager.py @@ -15,7 +15,7 @@ import queue, threading, time, weakref from typing import Any, Hashable from sd_internal import TaskData, device_manager -from modules.types import GenerateImageRequest +from diffusionkit.types import GenerateImageRequest log = logging.getLogger() diff --git a/ui/server.py b/ui/server.py index 42bdb0f4..7308dfc8 100644 --- a/ui/server.py +++ b/ui/server.py @@ -15,7 +15,7 @@ from pydantic import BaseModel from sd_internal import app, model_manager, task_manager from sd_internal import TaskData -from modules.types import GenerateImageRequest +from diffusionkit.types import GenerateImageRequest log = logging.getLogger() From a483bd0800e7887892895c2c43f4a022e6b03691 Mon Sep 17 00:00:00 2001 From: cmdr2 Date: Tue, 13 Dec 2022 11:46:13 +0530 Subject: [PATCH 35/74] No need to catch and report exceptions separately in the renderer now --- ui/sd_internal/renderer.py | 21 ++++++--------------- 1 file changed, 6 insertions(+), 15 deletions(-) diff --git a/ui/sd_internal/renderer.py b/ui/sd_internal/renderer.py index 1fff48e1..0697c908 100644 --- a/ui/sd_internal/renderer.py +++ b/ui/sd_internal/renderer.py @@ -36,23 +36,14 @@ def make_images(req: GenerateImageRequest, task_data: TaskData, data_queue: queu log.info(f'request: {get_printable_request(req)}') log.info(f'task data: {task_data.dict()}') - try: - images = _make_images_internal(req, task_data, data_queue, task_temp_images, step_callback) + images = _make_images_internal(req, task_data, data_queue, task_temp_images, step_callback) - res = Response(req, task_data, images=construct_response(images, task_data, base_seed=req.seed)) - res = res.json() - data_queue.put(json.dumps(res)) - log.info('Task completed') + res = Response(req, task_data, images=construct_response(images, task_data, base_seed=req.seed)) + res = res.json() + data_queue.put(json.dumps(res)) + log.info('Task completed') - return res - except Exception as e: - log.error(traceback.format_exc()) - - data_queue.put(json.dumps({ - "status": 'failed', - "detail": str(e) - })) - raise e + return res def _make_images_internal(req: GenerateImageRequest, task_data: TaskData, data_queue: queue.Queue, task_temp_images: list, step_callback): images, user_stopped = generate_images(req, task_data, data_queue, task_temp_images, step_callback, task_data.stream_image_progress) From 6cd0b530c5a791ff33968ae15b66b55fd1415dee Mon Sep 17 00:00:00 2001 From: 
cmdr2 Date: Tue, 13 Dec 2022 15:45:44 +0530 Subject: [PATCH 36/74] Simplify the code for VAE loading, and make it faster to load VAEs (because we don't reload the entire SD model each time a VAE changes); Record the error and end the thread if the SD model fails to load during startup --- ui/sd_internal/model_manager.py | 10 ++++------ ui/sd_internal/task_manager.py | 25 +++++++++++++------------ 2 files changed, 17 insertions(+), 18 deletions(-) diff --git a/ui/sd_internal/model_manager.py b/ui/sd_internal/model_manager.py index 18418624..b6c4c92d 100644 --- a/ui/sd_internal/model_manager.py +++ b/ui/sd_internal/model_manager.py @@ -42,6 +42,8 @@ def load_default_models(context: Context): # load mandatory models model_loader.load_model(context, 'stable-diffusion') + model_loader.load_model(context, 'vae') + model_loader.load_model(context, 'hypernetwork') def unload_all(context: Context): for model_type in KNOWN_MODEL_TYPES: @@ -93,17 +95,13 @@ def resolve_model_to_use(model_name:str=None, model_type:str=None): def reload_models_if_necessary(context: Context, task_data: TaskData): model_paths_in_req = ( + ('stable-diffusion', task_data.use_stable_diffusion_model), + ('vae', task_data.use_vae_model), ('hypernetwork', task_data.use_hypernetwork_model), ('gfpgan', task_data.use_face_correction), ('realesrgan', task_data.use_upscale), ) - if context.model_paths.get('stable-diffusion') != task_data.use_stable_diffusion_model or context.model_paths.get('vae') != task_data.use_vae_model: - context.model_paths['stable-diffusion'] = task_data.use_stable_diffusion_model - context.model_paths['vae'] = task_data.use_vae_model - - model_loader.load_model(context, 'stable-diffusion') - for model_type, model_path_in_req in model_paths_in_req: if context.model_paths.get(model_type) != model_path_in_req: context.model_paths[model_type] = model_path_in_req diff --git a/ui/sd_internal/task_manager.py b/ui/sd_internal/task_manager.py index 0db48d9d..3b8f6082 100644 --- a/ui/sd_internal/task_manager.py +++ b/ui/sd_internal/task_manager.py @@ -222,24 +222,25 @@ def thread_render(device): from sd_internal import renderer, model_manager try: renderer.init(device) + + weak_thread_data[threading.current_thread()] = { + 'device': renderer.context.device, + 'device_name': renderer.context.device_name, + 'alive': True + } + + current_state = ServerStates.LoadingModel + model_manager.load_default_models(renderer.context) + + current_state = ServerStates.Online except Exception as e: log.error(traceback.format_exc()) weak_thread_data[threading.current_thread()] = { - 'error': e + 'error': e, + 'alive': False } return - weak_thread_data[threading.current_thread()] = { - 'device': renderer.context.device, - 'device_name': renderer.context.device_name, - 'alive': True - } - - current_state = ServerStates.LoadingModel - model_manager.load_default_models(renderer.context) - - current_state = ServerStates.Online - while True: session_cache.clean() task_cache.clean() From cb81e2aacd5e47e64b0915b76b6ba6bbcf8e9033 Mon Sep 17 00:00:00 2001 From: cmdr2 Date: Wed, 14 Dec 2022 10:18:01 +0530 Subject: [PATCH 37/74] Fix a bug where the metadata output format wouldn't get sent to the backend --- ui/media/js/main.js | 1 + ui/media/js/parameters.js | 6 +++++- 2 files changed, 6 insertions(+), 1 deletion(-) diff --git a/ui/media/js/main.js b/ui/media/js/main.js index b26f363a..26303420 100644 --- a/ui/media/js/main.js +++ b/ui/media/js/main.js @@ -894,6 +894,7 @@ function getCurrentUserRequest() { show_only_filtered_image: 
showOnlyFilteredImageField.checked, output_format: outputFormatField.value, output_quality: parseInt(outputQualityField.value), + metadata_output_format: document.querySelector('#metadata_output_format').value, original_prompt: promptField.value, active_tags: (activeTags.map(x => x.name)) } diff --git a/ui/media/js/parameters.js b/ui/media/js/parameters.js index a4167fc3..52a4b67c 100644 --- a/ui/media/js/parameters.js +++ b/ui/media/js/parameters.js @@ -57,11 +57,15 @@ var PARAMETERS = [ id: "metadata_output_format", type: ParameterType.select, label: "Metadata format", - note: "the metadata will be saved to disk in this format", + note: "will be saved to disk in this format", default: "txt", options: [ { value: "txt", + label: "txt" + }, + { + value: "json", label: "json" } ], From 0dbce101acec4f621dde37d0a2dc84826fd56088 Mon Sep 17 00:00:00 2001 From: cmdr2 Date: Wed, 14 Dec 2022 10:21:44 +0530 Subject: [PATCH 38/74] sampler -> sampler_name --- ui/media/js/dnd.js | 8 ++++---- 1 file changed, 4 insertions(+), 4 deletions(-) diff --git a/ui/media/js/dnd.js b/ui/media/js/dnd.js index 1887aeaf..fa4a2d43 100644 --- a/ui/media/js/dnd.js +++ b/ui/media/js/dnd.js @@ -160,9 +160,9 @@ const TASK_MAPPING = { readUI: () => (useUpscalingField.checked ? upscaleModelField.value : undefined), parse: (val) => val }, - sampler: { name: 'Sampler', - setUI: (sampler) => { - samplerField.value = sampler + sampler_name: { name: 'Sampler', + setUI: (sampler_name) => { + samplerField.value = sampler_name }, readUI: () => samplerField.value, parse: (val) => val @@ -351,7 +351,7 @@ const TASK_TEXT_MAPPING = { prompt_strength: 'Prompt Strength', use_face_correction: 'Use Face Correction', use_upscale: 'Use Upscaling', - sampler: 'Sampler', + sampler_name: 'Sampler', negative_prompt: 'Negative Prompt', use_stable_diffusion_model: 'Stable Diffusion model', use_hypernetwork_model: 'Hypernetwork model', From d103693811388ad8edd33d70de53139f22dd0150 Mon Sep 17 00:00:00 2001 From: cmdr2 Date: Wed, 14 Dec 2022 10:22:24 +0530 Subject: [PATCH 39/74] Bug in the metadata generation - made an array of None --- ui/sd_internal/renderer.py | 5 ++++- 1 file changed, 4 insertions(+), 1 deletion(-) diff --git a/ui/sd_internal/renderer.py b/ui/sd_internal/renderer.py index 0697c908..0420d51d 100644 --- a/ui/sd_internal/renderer.py +++ b/ui/sd_internal/renderer.py @@ -121,7 +121,10 @@ def get_metadata_entries(req: GenerateImageRequest, task_data: TaskData): 'use_upscale': task_data.use_upscale, }) - return [metadata.copy().update({'seed': req.seed + i}) for i in range(req.num_outputs)] + entries = [metadata.copy() for _ in range(req.num_outputs)] + for i, entry in enumerate(entries): + entry['seed'] = req.seed + i + return entries def get_printable_request(req: GenerateImageRequest): metadata = req.dict() From 84d606408ae103bf1a6c45981b3288993e6451ff Mon Sep 17 00:00:00 2001 From: cmdr2 Date: Wed, 14 Dec 2022 10:31:19 +0530 Subject: [PATCH 40/74] Prompt is now a keyword in the new metadata format generated from diffusionkit --- ui/media/js/dnd.js | 1 + 1 file changed, 1 insertion(+) diff --git a/ui/media/js/dnd.js b/ui/media/js/dnd.js index fa4a2d43..6e733286 100644 --- a/ui/media/js/dnd.js +++ b/ui/media/js/dnd.js @@ -343,6 +343,7 @@ function getModelPath(filename, extensions) } const TASK_TEXT_MAPPING = { + prompt: 'Prompt', width: 'Width', height: 'Height', seed: 'Seed', From 7dc7f70582c2bd085b1e54f62659998367c4e90f Mon Sep 17 00:00:00 2001 From: cmdr2 Date: Wed, 14 Dec 2022 10:34:36 +0530 Subject: [PATCH 41/74] Allow parsing 
.safetensors stable diffusion model path in the metadata parser --- ui/media/js/dnd.js | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/ui/media/js/dnd.js b/ui/media/js/dnd.js index 6e733286..d09fe32f 100644 --- a/ui/media/js/dnd.js +++ b/ui/media/js/dnd.js @@ -171,7 +171,7 @@ const TASK_MAPPING = { setUI: (use_stable_diffusion_model) => { const oldVal = stableDiffusionModelField.value - use_stable_diffusion_model = getModelPath(use_stable_diffusion_model, ['.ckpt']) + use_stable_diffusion_model = getModelPath(use_stable_diffusion_model, ['.ckpt', '.safetensors']) stableDiffusionModelField.value = use_stable_diffusion_model if (!stableDiffusionModelField.value) { From d1ac90e16d9e5d2d396d21148fc7e4849de6c32a Mon Sep 17 00:00:00 2001 From: cmdr2 Date: Wed, 14 Dec 2022 15:43:24 +0530 Subject: [PATCH 42/74] [metadata parsing] Support loading the flat JSON format saved by the next backend; Set the dropdown to None if the value is undefined or null in the metadata --- ui/media/js/dnd.js | 5 +++++ 1 file changed, 5 insertions(+) diff --git a/ui/media/js/dnd.js b/ui/media/js/dnd.js index d09fe32f..0f204326 100644 --- a/ui/media/js/dnd.js +++ b/ui/media/js/dnd.js @@ -184,6 +184,7 @@ const TASK_MAPPING = { use_vae_model: { name: 'VAE model', setUI: (use_vae_model) => { const oldVal = vaeModelField.value + use_vae_model = (use_vae_model === undefined || use_vae_model === null ? '' : use_vae_model) if (use_vae_model !== '') { use_vae_model = getModelPath(use_vae_model, ['.vae.pt', '.ckpt']) @@ -197,6 +198,7 @@ const TASK_MAPPING = { use_hypernetwork_model: { name: 'Hypernetwork model', setUI: (use_hypernetwork_model) => { const oldVal = hypernetworkModelField.value + use_hypernetwork_model = (use_hypernetwork_model === undefined || use_hypernetwork_model === null ? '' : use_hypernetwork_model) if (use_hypernetwork_model !== '') { use_hypernetwork_model = getModelPath(use_hypernetwork_model, ['.pt']) @@ -404,6 +406,9 @@ async function parseContent(text) { if (text.startsWith('{') && text.endsWith('}')) { try { const task = JSON.parse(text) + if (!('reqBody' in task)) { // support the format saved to the disk, by the UI + task.reqBody = Object.assign({}, task) + } restoreTaskToUI(task) return true } catch (e) { From 12e0194c7f5f189c2f9d2429465452bdfc5ef4ec Mon Sep 17 00:00:00 2001 From: cmdr2 Date: Wed, 14 Dec 2022 16:30:08 +0530 Subject: [PATCH 43/74] Allow None as the value type in dnd parsing --- ui/media/js/dnd.js | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/ui/media/js/dnd.js b/ui/media/js/dnd.js index 0f204326..37d99d44 100644 --- a/ui/media/js/dnd.js +++ b/ui/media/js/dnd.js @@ -184,7 +184,7 @@ const TASK_MAPPING = { use_vae_model: { name: 'VAE model', setUI: (use_vae_model) => { const oldVal = vaeModelField.value - use_vae_model = (use_vae_model === undefined || use_vae_model === null ? '' : use_vae_model) + use_vae_model = (use_vae_model === undefined || use_vae_model === null || use_vae_model === 'None' ? '' : use_vae_model) if (use_vae_model !== '') { use_vae_model = getModelPath(use_vae_model, ['.vae.pt', '.ckpt']) @@ -198,7 +198,7 @@ const TASK_MAPPING = { use_hypernetwork_model: { name: 'Hypernetwork model', setUI: (use_hypernetwork_model) => { const oldVal = hypernetworkModelField.value - use_hypernetwork_model = (use_hypernetwork_model === undefined || use_hypernetwork_model === null ? 
'' : use_hypernetwork_model) + use_hypernetwork_model = (use_hypernetwork_model === undefined || use_hypernetwork_model === null || use_hypernetwork_model === 'None' ? '' : use_hypernetwork_model) if (use_hypernetwork_model !== '') { use_hypernetwork_model = getModelPath(use_hypernetwork_model, ['.pt']) From 35ff4f439ef4cb30ed4d0aa201f355914bd7a27d Mon Sep 17 00:00:00 2001 From: cmdr2 Date: Wed, 14 Dec 2022 16:30:19 +0530 Subject: [PATCH 44/74] Refactor save_to_disk --- ui/sd_internal/renderer.py | 57 ++----------------------- ui/sd_internal/save_utils.py | 80 ++++++++++++++++++++++++++++++++++++ 2 files changed, 83 insertions(+), 54 deletions(-) create mode 100644 ui/sd_internal/save_utils.py diff --git a/ui/sd_internal/renderer.py b/ui/sd_internal/renderer.py index 0420d51d..55b2c4a7 100644 --- a/ui/sd_internal/renderer.py +++ b/ui/sd_internal/renderer.py @@ -1,13 +1,9 @@ import queue import time import json -import os -import base64 -import re -import traceback import logging -from sd_internal import device_manager +from sd_internal import device_manager, save_utils from sd_internal import TaskData, Response, Image as ResponseImage, UserInitiatedStop from diffusionkit import model_loader, image_generator, image_utils, filters as image_filters, data_utils @@ -20,8 +16,6 @@ context = Context() # thread-local runtime data (bound locally to this thread), for e.g. device, references to loaded models, optimization flags etc ''' -filename_regex = re.compile('[^a-zA-Z0-9]') - def init(device): ''' Initializes the fields that will be bound to this runtime's context, and sets the current torch device @@ -33,7 +27,7 @@ def init(device): device_manager.device_init(context, device) def make_images(req: GenerateImageRequest, task_data: TaskData, data_queue: queue.Queue, task_temp_images: list, step_callback): - log.info(f'request: {get_printable_request(req)}') + log.info(f'request: {save_utils.get_printable_request(req)}') log.info(f'task data: {task_data.dict()}') images = _make_images_internal(req, task_data, data_queue, task_temp_images, step_callback) @@ -50,8 +44,7 @@ def _make_images_internal(req: GenerateImageRequest, task_data: TaskData, data_q filtered_images = apply_filters(task_data, images, user_stopped) if task_data.save_to_disk_path is not None: - save_folder_path = os.path.join(task_data.save_to_disk_path, filename_regex.sub('_', task_data.session_id)) - save_to_disk(images, filtered_images, save_folder_path, req, task_data) + save_utils.save_to_disk(images, filtered_images, req, task_data) return filtered_images if task_data.show_only_filtered_image else images + filtered_images @@ -92,17 +85,6 @@ def apply_filters(task_data: TaskData, images: list, user_stopped): return filtered_images -def save_to_disk(images: list, filtered_images: list, save_folder_path, req: GenerateImageRequest, task_data: TaskData): - metadata_entries = get_metadata_entries(req, task_data) - - if task_data.show_only_filtered_image or filtered_images == images: - data_utils.save_images(filtered_images, save_folder_path, file_name=make_filename_callback(req), output_format=task_data.output_format, output_quality=task_data.output_quality) - data_utils.save_metadata(metadata_entries, save_folder_path, file_name=make_filename_callback(req), output_format=task_data.metadata_output_format) - else: - data_utils.save_images(images, save_folder_path, file_name=make_filename_callback(req), output_format=task_data.output_format, output_quality=task_data.output_quality) - data_utils.save_images(filtered_images, 
save_folder_path, file_name=make_filename_callback(req, suffix='filtered'), output_format=task_data.output_format, output_quality=task_data.output_quality) - data_utils.save_metadata(metadata_entries, save_folder_path, file_name=make_filename_callback(req, suffix='filtered'), output_format=task_data.metadata_output_format) - def construct_response(images: list, task_data: TaskData, base_seed: int): return [ ResponseImage( @@ -111,39 +93,6 @@ def construct_response(images: list, task_data: TaskData, base_seed: int): ) for i, img in enumerate(images) ] -def get_metadata_entries(req: GenerateImageRequest, task_data: TaskData): - metadata = get_printable_request(req) - metadata.update({ - 'use_stable_diffusion_model': task_data.use_stable_diffusion_model, - 'use_vae_model': task_data.use_vae_model, - 'use_hypernetwork_model': task_data.use_hypernetwork_model, - 'use_face_correction': task_data.use_face_correction, - 'use_upscale': task_data.use_upscale, - }) - - entries = [metadata.copy() for _ in range(req.num_outputs)] - for i, entry in enumerate(entries): - entry['seed'] = req.seed + i - return entries - -def get_printable_request(req: GenerateImageRequest): - metadata = req.dict() - del metadata['init_image'] - del metadata['init_image_mask'] - return metadata - -def make_filename_callback(req: GenerateImageRequest, suffix=None): - def make_filename(i): - img_id = base64.b64encode(int(time.time()+i).to_bytes(8, 'big')).decode() # Generate unique ID based on time. - img_id = img_id.translate({43:None, 47:None, 61:None})[-8:] # Remove + / = and keep last 8 chars. - - prompt_flattened = filename_regex.sub('_', req.prompt)[:50] - name = f"{prompt_flattened}_{img_id}" - name = name if suffix is None else f'{name}_{suffix}' - return name - - return make_filename - def make_step_callback(req: GenerateImageRequest, task_data: TaskData, data_queue: queue.Queue, task_temp_images: list, step_callback, stream_image_progress: bool): n_steps = req.num_inference_steps if req.init_image is None else int(req.num_inference_steps * req.prompt_strength) last_callback_time = -1 diff --git a/ui/sd_internal/save_utils.py b/ui/sd_internal/save_utils.py new file mode 100644 index 00000000..29c4a29c --- /dev/null +++ b/ui/sd_internal/save_utils.py @@ -0,0 +1,80 @@ +import os +import time +import base64 +import re + +from diffusionkit import data_utils +from diffusionkit.types import GenerateImageRequest + +from sd_internal import TaskData + +filename_regex = re.compile('[^a-zA-Z0-9]') + +# keep in sync with `ui/media/js/dnd.js` +TASK_TEXT_MAPPING = { + 'prompt': 'Prompt', + 'width': 'Width', + 'height': 'Height', + 'seed': 'Seed', + 'num_inference_steps': 'Steps', + 'guidance_scale': 'Guidance Scale', + 'prompt_strength': 'Prompt Strength', + 'use_face_correction': 'Use Face Correction', + 'use_upscale': 'Use Upscaling', + 'sampler_name': 'Sampler', + 'negative_prompt': 'Negative Prompt', + 'use_stable_diffusion_model': 'Stable Diffusion model', + 'use_hypernetwork_model': 'Hypernetwork model', + 'hypernetwork_strength': 'Hypernetwork Strength' +} + +def save_to_disk(images: list, filtered_images: list, req: GenerateImageRequest, task_data: TaskData): + save_folder_path = os.path.join(task_data.save_to_disk_path, filename_regex.sub('_', task_data.session_id)) + metadata_entries = get_metadata_entries(req, task_data) + + if task_data.show_only_filtered_image or filtered_images == images: + data_utils.save_images(filtered_images, save_folder_path, file_name=make_filename_callback(req), 
output_format=task_data.output_format, output_quality=task_data.output_quality) + data_utils.save_metadata(metadata_entries, save_folder_path, file_name=make_filename_callback(req), output_format=task_data.metadata_output_format) + else: + data_utils.save_images(images, save_folder_path, file_name=make_filename_callback(req), output_format=task_data.output_format, output_quality=task_data.output_quality) + data_utils.save_images(filtered_images, save_folder_path, file_name=make_filename_callback(req, suffix='filtered'), output_format=task_data.output_format, output_quality=task_data.output_quality) + data_utils.save_metadata(metadata_entries, save_folder_path, file_name=make_filename_callback(req, suffix='filtered'), output_format=task_data.metadata_output_format) + +def get_metadata_entries(req: GenerateImageRequest, task_data: TaskData): + metadata = get_printable_request(req) + metadata.update({ + 'use_stable_diffusion_model': task_data.use_stable_diffusion_model, + 'use_vae_model': task_data.use_vae_model, + 'use_hypernetwork_model': task_data.use_hypernetwork_model, + 'use_face_correction': task_data.use_face_correction, + 'use_upscale': task_data.use_upscale, + }) + + # if text, format it in the text format expected by the UI + is_txt_format = (task_data.metadata_output_format.lower() == 'txt') + if is_txt_format: + metadata = {TASK_TEXT_MAPPING[key]: val for key, val in metadata.items() if key in TASK_TEXT_MAPPING} + + entries = [metadata.copy() for _ in range(req.num_outputs)] + for i, entry in enumerate(entries): + entry['Seed' if is_txt_format else 'seed'] = req.seed + i + + return entries + +def get_printable_request(req: GenerateImageRequest): + metadata = req.dict() + del metadata['init_image'] + del metadata['init_image_mask'] + return metadata + +def make_filename_callback(req: GenerateImageRequest, suffix=None): + def make_filename(i): + img_id = base64.b64encode(int(time.time()+i).to_bytes(8, 'big')).decode() # Generate unique ID based on time. + img_id = img_id.translate({43:None, 47:None, 61:None})[-8:] # Remove + / = and keep last 8 chars. + + prompt_flattened = filename_regex.sub('_', req.prompt)[:50] + name = f"{prompt_flattened}_{img_id}" + name = name if suffix is None else f'{name}_{suffix}' + return name + + return make_filename \ No newline at end of file From fb075a0013d476120ce79f4eb2d9ed19132f75d0 Mon Sep 17 00:00:00 2001 From: cmdr2 Date: Wed, 14 Dec 2022 16:53:50 +0530 Subject: [PATCH 45/74] Fix whitespace --- ui/media/js/main.js | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/ui/media/js/main.js b/ui/media/js/main.js index d309ad0a..c2dc32f3 100644 --- a/ui/media/js/main.js +++ b/ui/media/js/main.js @@ -629,9 +629,9 @@ function onTaskCompleted(task, reqBody, instance, outputContainer, stepUpdate) { time /= 1000 if (task.batchesDone == task.batchCount) { - if (!task.outputMsg.innerText.toLowerCase().includes('error')) { + if (!task.outputMsg.innerText.toLowerCase().includes('error')) { task.outputMsg.innerText = `Processed ${task.numOutputsTotal} images in ${time} seconds` - } + } task.progressBar.style.height = "0px" task.progressBar.style.border = "0px solid var(--background-color3)" task.progressBar.classList.remove("active") From aa01fd058ea44bd38526869c9dc0e04a31d6900f Mon Sep 17 00:00:00 2001 From: cmdr2 Date: Thu, 15 Dec 2022 23:30:06 +0530 Subject: [PATCH 46/74] Set performance level (low, medium, high) instead of a Turbo field. 
The previous Turbo field is equivalent to 'Medium' performance now --- ui/media/js/auto-save.js | 2 +- ui/media/js/main.js | 4 +-- ui/media/js/parameters.js | 20 ++++++++---- ui/sd_internal/__init__.py | 2 +- ui/sd_internal/app.py | 4 ++- ui/sd_internal/device_manager.py | 12 ++++++++ ui/sd_internal/model_manager.py | 52 ++++++++++++++++++++------------ ui/sd_internal/task_manager.py | 2 +- ui/server.py | 2 +- 9 files changed, 67 insertions(+), 33 deletions(-) diff --git a/ui/media/js/auto-save.js b/ui/media/js/auto-save.js index 1677c6c1..91a2d267 100644 --- a/ui/media/js/auto-save.js +++ b/ui/media/js/auto-save.js @@ -36,7 +36,7 @@ const SETTINGS_IDS_LIST = [ "save_to_disk", "diskPath", "sound_toggle", - "turbo", + "performance_level", "confirm_dangerous_actions", "metadata_output_format", "auto_save_settings", diff --git a/ui/media/js/main.js b/ui/media/js/main.js index c2dc32f3..3a669b67 100644 --- a/ui/media/js/main.js +++ b/ui/media/js/main.js @@ -602,7 +602,7 @@ function onTaskCompleted(task, reqBody, instance, outputContainer, stepUpdate) { Suggestions:
      1. If you have set an initial image, please try reducing its dimension to ${MAX_INIT_IMAGE_DIMENSION}x${MAX_INIT_IMAGE_DIMENSION} or smaller.
      - 2. Try disabling the 'Turbo mode' under 'Advanced Settings'.
      + 2. Try picking a lower performance level in the 'Performance Level' setting (in the 'Settings' tab).
      3. Try generating a smaller image.
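The three performance levels this patch introduces map to sets of VRAM optimization flags (see the PERF_LEVEL_TO_VRAM_OPTIMIZATIONS table in the model_manager.py hunk further below). A minimal sketch of that lookup, using only the flag names that appear in the patch; the optimizations_for() helper is illustrative, not part of the patch:

    # Level-to-optimization lookup, as added by this patch in model_manager.py.
    # Per the commit message, the old 'Turbo' checkbox corresponds to 'medium'.
    PERF_LEVEL_TO_VRAM_OPTIMIZATIONS = {
        'low': {'KEEP_ENTIRE_MODEL_IN_CPU'},
        'medium': {'KEEP_FS_AND_CS_IN_CPU', 'SET_ATTENTION_STEP_TO_4'},
        'high': set(),  # no optimizations (the patch writes {}, which is an empty dict, not a set)
    }

    def optimizations_for(level: str) -> set:
        # Unknown levels fall back to no optimizations (an assumption, not in the patch).
        return PERF_LEVEL_TO_VRAM_OPTIMIZATIONS.get(level, set())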
      ` } } else { @@ -887,7 +887,7 @@ function getCurrentUserRequest() { width: parseInt(widthField.value), height: parseInt(heightField.value), // allow_nsfw: allowNSFWField.checked, - turbo: turboField.checked, + performance_level: perfLevelField.value, //render_device: undefined, // Set device affinity. Prefer this device, but wont activate. use_stable_diffusion_model: stableDiffusionModelField.value, use_vae_model: vaeModelField.value, diff --git a/ui/media/js/parameters.js b/ui/media/js/parameters.js index 52a4b67c..865f7bb5 100644 --- a/ui/media/js/parameters.js +++ b/ui/media/js/parameters.js @@ -94,12 +94,20 @@ var PARAMETERS = [ default: true, }, { - id: "turbo", - type: ParameterType.checkbox, - label: "Turbo Mode", - note: "generates images faster, but uses an additional 1 GB of GPU memory", + id: "performance_level", + type: ParameterType.select, + label: "Performance Level", + note: "Faster performance requires more GPU memory

      " + + "High: fastest, maximum GPU memory usage
      " + + "Medium: decent speed, uses 1 GB more memory than Low
      " + + "Low: slowest, for GPUs with 4 GB (or less) memory", icon: "fa-forward", - default: true, + default: "high", + options: [ + {value: "high", label: "High"}, + {value: "medium", label: "Medium"}, + {value: "low", label: "Low"} + ], }, { id: "use_cpu", @@ -219,7 +227,7 @@ function initParameters() { initParameters() -let turboField = document.querySelector('#turbo') +let perfLevelField = document.querySelector('#performance_level') let useCPUField = document.querySelector('#use_cpu') let autoPickGPUsField = document.querySelector('#auto_pick_gpus') let useGPUsField = document.querySelector('#use_gpus') diff --git a/ui/sd_internal/__init__.py b/ui/sd_internal/__init__.py index 71073216..5475c3a6 100644 --- a/ui/sd_internal/__init__.py +++ b/ui/sd_internal/__init__.py @@ -6,7 +6,7 @@ class TaskData(BaseModel): request_id: str = None session_id: str = "session" save_to_disk_path: str = None - turbo: bool = True + performance_level: str = "high" # or "low" or "medium" use_face_correction: str = None # or "GFPGANv1.3" use_upscale: str = None # or "RealESRGAN_x4plus" or "RealESRGAN_x4plus_anime_6B" diff --git a/ui/sd_internal/app.py b/ui/sd_internal/app.py index 47a2d610..7171e6c1 100644 --- a/ui/sd_internal/app.py +++ b/ui/sd_internal/app.py @@ -110,7 +110,7 @@ def setConfig(config): except: log.error(traceback.format_exc()) -def save_model_to_config(ckpt_model_name, vae_model_name, hypernetwork_model_name): +def save_to_config(ckpt_model_name, vae_model_name, hypernetwork_model_name, performance_level): config = getConfig() if 'model' not in config: config['model'] = {} @@ -124,6 +124,8 @@ def save_model_to_config(ckpt_model_name, vae_model_name, hypernetwork_model_nam if hypernetwork_model_name is None or hypernetwork_model_name == "": del config['model']['hypernetwork'] + config['performance_level'] = performance_level + setConfig(config) def update_render_threads(): diff --git a/ui/sd_internal/device_manager.py b/ui/sd_internal/device_manager.py index 733bab50..8b5c49be 100644 --- a/ui/sd_internal/device_manager.py +++ b/ui/sd_internal/device_manager.py @@ -128,6 +128,18 @@ def needs_to_force_full_precision(context): device_name = context.device_name.lower() return (('nvidia' in device_name or 'geforce' in device_name) and (' 1660' in device_name or ' 1650' in device_name)) or ('Quadro T2000' in device_name) +def get_max_perf_level(device): + if device != 'cpu': + _, mem_total = torch.cuda.mem_get_info(device) + mem_total /= float(10**9) + + if mem_total < 4.5: + return 'low' + elif mem_total < 6.5: + return 'medium' + + return 'high' + def validate_device_id(device, log_prefix=''): def is_valid(): if not isinstance(device, str): diff --git a/ui/sd_internal/model_manager.py b/ui/sd_internal/model_manager.py index b6c4c92d..9c3320ea 100644 --- a/ui/sd_internal/model_manager.py +++ b/ui/sd_internal/model_manager.py @@ -3,7 +3,7 @@ import logging import picklescan.scanner import rich -from sd_internal import app, TaskData +from sd_internal import app, TaskData, device_manager from diffusionkit import model_loader from diffusionkit.types import Context @@ -25,6 +25,11 @@ DEFAULT_MODELS = { 'gfpgan': ['GFPGANv1.3'], 'realesrgan': ['RealESRGAN_x4plus'], } +PERF_LEVEL_TO_VRAM_OPTIMIZATIONS = { + 'low': {'KEEP_ENTIRE_MODEL_IN_CPU'}, + 'medium': {'KEEP_FS_AND_CS_IN_CPU', 'SET_ATTENTION_STEP_TO_4'}, + 'high': {}, +} known_models = {} @@ -37,8 +42,7 @@ def load_default_models(context: Context): for model_type in KNOWN_MODEL_TYPES: context.model_paths[model_type] = 
resolve_model_to_use(model_type=model_type) - # disable TURBO initially (this should be read from the config eventually) - context.vram_optimizations -= {'TURBO'} + set_vram_optimizations(context) # load mandatory models model_loader.load_model(context, 'stable-diffusion') @@ -94,20 +98,23 @@ def resolve_model_to_use(model_name:str=None, model_type:str=None): return None def reload_models_if_necessary(context: Context, task_data: TaskData): - model_paths_in_req = ( - ('stable-diffusion', task_data.use_stable_diffusion_model), - ('vae', task_data.use_vae_model), - ('hypernetwork', task_data.use_hypernetwork_model), - ('gfpgan', task_data.use_face_correction), - ('realesrgan', task_data.use_upscale), - ) + model_paths_in_req = { + 'stable-diffusion': task_data.use_stable_diffusion_model, + 'vae': task_data.use_vae_model, + 'hypernetwork': task_data.use_hypernetwork_model, + 'gfpgan': task_data.use_face_correction, + 'realesrgan': task_data.use_upscale, + } + models_to_reload = {model_type: path for model_type, path in model_paths_in_req.items() if context.model_paths.get(model_type) != path} - for model_type, model_path_in_req in model_paths_in_req: - if context.model_paths.get(model_type) != model_path_in_req: - context.model_paths[model_type] = model_path_in_req + if set_vram_optimizations(context): # reload SD + models_to_reload['stable-diffusion'] = model_paths_in_req['stable-diffusion'] - action_fn = model_loader.unload_model if context.model_paths[model_type] is None else model_loader.load_model - action_fn(context, model_type) + for model_type, model_path_in_req in models_to_reload.items(): + context.model_paths[model_type] = model_path_in_req + + action_fn = model_loader.unload_model if context.model_paths[model_type] is None else model_loader.load_model + action_fn(context, model_type) def resolve_model_paths(task_data: TaskData): task_data.use_stable_diffusion_model = resolve_model_to_use(task_data.use_stable_diffusion_model, model_type='stable-diffusion') @@ -117,11 +124,16 @@ def resolve_model_paths(task_data: TaskData): if task_data.use_face_correction: task_data.use_face_correction = resolve_model_to_use(task_data.use_face_correction, 'gfpgan') if task_data.use_upscale: task_data.use_upscale = resolve_model_to_use(task_data.use_upscale, 'gfpgan') -def set_vram_optimizations(context: Context, task_data: TaskData): - if task_data.turbo: - context.vram_optimizations.add('TURBO') - else: - context.vram_optimizations.remove('TURBO') +def set_vram_optimizations(context: Context): + config = app.getConfig() + perf_level = config.get('performance_level', device_manager.get_max_perf_level(context.device)) + vram_optimizations = PERF_LEVEL_TO_VRAM_OPTIMIZATIONS[perf_level] + + if vram_optimizations != context.vram_optimizations: + context.vram_optimizations = vram_optimizations + return True + + return False def make_model_folders(): for model_type in KNOWN_MODEL_TYPES: diff --git a/ui/sd_internal/task_manager.py b/ui/sd_internal/task_manager.py index 3b8f6082..0780283d 100644 --- a/ui/sd_internal/task_manager.py +++ b/ui/sd_internal/task_manager.py @@ -281,7 +281,6 @@ def thread_render(device): current_state = ServerStates.LoadingModel model_manager.resolve_model_paths(task.task_data) - model_manager.set_vram_optimizations(renderer.context, task.task_data) model_manager.reload_models_if_necessary(renderer.context, task.task_data) current_state = ServerStates.Rendering @@ -342,6 +341,7 @@ def get_devices(): 'name': torch.cuda.get_device_name(device), 'mem_free': mem_free, 'mem_total': 
mem_total, + 'max_perf_level': device_manager.get_max_perf_level(device), } # list the compatible devices diff --git a/ui/server.py b/ui/server.py index 7308dfc8..11d3731a 100644 --- a/ui/server.py +++ b/ui/server.py @@ -134,7 +134,7 @@ def render(req: dict): render_req.init_image_mask = req.get('mask') # hack: will rename this in the HTTP API in a future revision - app.save_model_to_config(task_data.use_stable_diffusion_model, task_data.use_vae_model, task_data.use_hypernetwork_model) + app.save_to_config(task_data.use_stable_diffusion_model, task_data.use_vae_model, task_data.use_hypernetwork_model, task_data.performance_level) # enqueue the task new_task = task_manager.render(render_req, task_data) From 7982a9ae257366653a1b9a305fde00de65d9b511 Mon Sep 17 00:00:00 2001 From: cmdr2 Date: Fri, 16 Dec 2022 11:34:49 +0530 Subject: [PATCH 47/74] Change the performance field to GPU Memory Usage instead, and use the 'balanced' profile by default, since it's just 5% slower than 'high', and uses nearly 50% less VRAM --- ui/media/js/auto-save.js | 2 +- ui/media/js/main.js | 4 ++-- ui/media/js/parameters.js | 14 +++++++------- ui/sd_internal/__init__.py | 2 +- ui/sd_internal/app.py | 4 ++-- ui/sd_internal/device_manager.py | 4 ++-- ui/sd_internal/model_manager.py | 23 +++++++++++++++++++---- ui/sd_internal/task_manager.py | 2 +- ui/server.py | 2 +- 9 files changed, 36 insertions(+), 21 deletions(-) diff --git a/ui/media/js/auto-save.js b/ui/media/js/auto-save.js index 91a2d267..934b3f32 100644 --- a/ui/media/js/auto-save.js +++ b/ui/media/js/auto-save.js @@ -36,7 +36,7 @@ const SETTINGS_IDS_LIST = [ "save_to_disk", "diskPath", "sound_toggle", - "performance_level", + "vram_usage_level", "confirm_dangerous_actions", "metadata_output_format", "auto_save_settings", diff --git a/ui/media/js/main.js b/ui/media/js/main.js index 3a669b67..1e7269b4 100644 --- a/ui/media/js/main.js +++ b/ui/media/js/main.js @@ -602,7 +602,7 @@ function onTaskCompleted(task, reqBody, instance, outputContainer, stepUpdate) { Suggestions:
      1. If you have set an initial image, please try reducing its dimension to ${MAX_INIT_IMAGE_DIMENSION}x${MAX_INIT_IMAGE_DIMENSION} or smaller.
      - 2. Try picking a lower performance level in the 'Performance Level' setting (in the 'Settings' tab).
      + 2. Try picking a lower level in the 'GPU Memory Usage' setting (in the 'Settings' tab).
      3. Try generating a smaller image.
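The 'GPU Memory Usage' level requested here is clamped to the maximum level the device supports (get_max_vram_usage_level() in device_manager.py below); patch 48 further down simplifies that comparison to a numeric ranking. A minimal sketch of the resulting clamping logic; the clamp_vram_usage_level() name is illustrative, not part of the patches:

    # Rank the levels so a requested level can be clamped to the device's maximum,
    # mirroring the {'low': 0, 'balanced': 1, 'high': 2} table from patch 48.
    LEVEL_RANK = {'low': 0, 'balanced': 1, 'high': 2}

    def clamp_vram_usage_level(requested: str, device_max: str) -> str:
        # e.g. clamp_vram_usage_level('high', 'balanced') -> 'balanced'
        if LEVEL_RANK[requested] > LEVEL_RANK[device_max]:
            return device_max
        return requested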
      ` } } else { @@ -887,7 +887,7 @@ function getCurrentUserRequest() { width: parseInt(widthField.value), height: parseInt(heightField.value), // allow_nsfw: allowNSFWField.checked, - performance_level: perfLevelField.value, + vram_usage_level: vramUsageLevelField.value, //render_device: undefined, // Set device affinity. Prefer this device, but wont activate. use_stable_diffusion_model: stableDiffusionModelField.value, use_vae_model: vaeModelField.value, diff --git a/ui/media/js/parameters.js b/ui/media/js/parameters.js index 865f7bb5..93f1a266 100644 --- a/ui/media/js/parameters.js +++ b/ui/media/js/parameters.js @@ -94,18 +94,18 @@ var PARAMETERS = [ default: true, }, { - id: "performance_level", + id: "vram_usage_level", type: ParameterType.select, - label: "Performance Level", + label: "GPU Memory Usage", note: "Faster performance requires more GPU memory

      " + + "Balanced: almost as fast as High, significantly lower GPU memory usage
      " + "High: fastest, maximum GPU memory usage
      " + - "Medium: decent speed, uses 1 GB more memory than Low
      " + - "Low: slowest, for GPUs with 4 GB (or less) memory", + "Low: slowest, force-used for GPUs with 4 GB (or less) memory", icon: "fa-forward", - default: "high", + default: "balanced", options: [ + {value: "balanced", label: "Balanced"}, {value: "high", label: "High"}, - {value: "medium", label: "Medium"}, {value: "low", label: "Low"} ], }, @@ -227,7 +227,7 @@ function initParameters() { initParameters() -let perfLevelField = document.querySelector('#performance_level') +let vramUsageLevelField = document.querySelector('#vram_usage_level') let useCPUField = document.querySelector('#use_cpu') let autoPickGPUsField = document.querySelector('#auto_pick_gpus') let useGPUsField = document.querySelector('#use_gpus') diff --git a/ui/sd_internal/__init__.py b/ui/sd_internal/__init__.py index 5475c3a6..6f702749 100644 --- a/ui/sd_internal/__init__.py +++ b/ui/sd_internal/__init__.py @@ -6,7 +6,7 @@ class TaskData(BaseModel): request_id: str = None session_id: str = "session" save_to_disk_path: str = None - performance_level: str = "high" # or "low" or "medium" + vram_usage_level: str = "balanced" # or "low" or "medium" use_face_correction: str = None # or "GFPGANv1.3" use_upscale: str = None # or "RealESRGAN_x4plus" or "RealESRGAN_x4plus_anime_6B" diff --git a/ui/sd_internal/app.py b/ui/sd_internal/app.py index 7171e6c1..063069d2 100644 --- a/ui/sd_internal/app.py +++ b/ui/sd_internal/app.py @@ -110,7 +110,7 @@ def setConfig(config): except: log.error(traceback.format_exc()) -def save_to_config(ckpt_model_name, vae_model_name, hypernetwork_model_name, performance_level): +def save_to_config(ckpt_model_name, vae_model_name, hypernetwork_model_name, vram_usage_level): config = getConfig() if 'model' not in config: config['model'] = {} @@ -124,7 +124,7 @@ def save_to_config(ckpt_model_name, vae_model_name, hypernetwork_model_name, per if hypernetwork_model_name is None or hypernetwork_model_name == "": del config['model']['hypernetwork'] - config['performance_level'] = performance_level + config['vram_usage_level'] = vram_usage_level setConfig(config) diff --git a/ui/sd_internal/device_manager.py b/ui/sd_internal/device_manager.py index 8b5c49be..56508fad 100644 --- a/ui/sd_internal/device_manager.py +++ b/ui/sd_internal/device_manager.py @@ -128,7 +128,7 @@ def needs_to_force_full_precision(context): device_name = context.device_name.lower() return (('nvidia' in device_name or 'geforce' in device_name) and (' 1660' in device_name or ' 1650' in device_name)) or ('Quadro T2000' in device_name) -def get_max_perf_level(device): +def get_max_vram_usage_level(device): if device != 'cpu': _, mem_total = torch.cuda.mem_get_info(device) mem_total /= float(10**9) @@ -136,7 +136,7 @@ def get_max_perf_level(device): if mem_total < 4.5: return 'low' elif mem_total < 6.5: - return 'medium' + return 'balanced' return 'high' diff --git a/ui/sd_internal/model_manager.py b/ui/sd_internal/model_manager.py index 9c3320ea..56f1d9c0 100644 --- a/ui/sd_internal/model_manager.py +++ b/ui/sd_internal/model_manager.py @@ -25,9 +25,9 @@ DEFAULT_MODELS = { 'gfpgan': ['GFPGANv1.3'], 'realesrgan': ['RealESRGAN_x4plus'], } -PERF_LEVEL_TO_VRAM_OPTIMIZATIONS = { +VRAM_USAGE_LEVEL_TO_OPTIMIZATIONS = { + 'balanced': {'KEEP_FS_AND_CS_IN_CPU', 'SET_ATTENTION_STEP_TO_4'}, 'low': {'KEEP_ENTIRE_MODEL_IN_CPU'}, - 'medium': {'KEEP_FS_AND_CS_IN_CPU', 'SET_ATTENTION_STEP_TO_4'}, 'high': {}, } @@ -125,9 +125,24 @@ def resolve_model_paths(task_data: TaskData): if task_data.use_upscale: task_data.use_upscale = 
resolve_model_to_use(task_data.use_upscale, 'gfpgan') def set_vram_optimizations(context: Context): + def is_greater(a, b): # is a > b? + if a == "low": # b will be "low", "balanced" or "high" + return False + elif a == "balanced" and b != "low": # b will be "balanced" or "high" + return False + return True + config = app.getConfig() - perf_level = config.get('performance_level', device_manager.get_max_perf_level(context.device)) - vram_optimizations = PERF_LEVEL_TO_VRAM_OPTIMIZATIONS[perf_level] + + max_usage_level = device_manager.get_max_vram_usage_level(context.device) + vram_usage_level = config.get('vram_usage_level', 'balanced') + + if is_greater(vram_usage_level, max_usage_level): + log.error(f'Requested GPU Memory Usage level ({vram_usage_level}) is higher than what is ' + \ + f'possible ({max_usage_level}) on this device ({context.device}). Using "{max_usage_level}" instead') + vram_usage_level = max_usage_level + + vram_optimizations = VRAM_USAGE_LEVEL_TO_OPTIMIZATIONS[vram_usage_level] if vram_optimizations != context.vram_optimizations: context.vram_optimizations = vram_optimizations diff --git a/ui/sd_internal/task_manager.py b/ui/sd_internal/task_manager.py index 0780283d..094d853c 100644 --- a/ui/sd_internal/task_manager.py +++ b/ui/sd_internal/task_manager.py @@ -341,7 +341,7 @@ def get_devices(): 'name': torch.cuda.get_device_name(device), 'mem_free': mem_free, 'mem_total': mem_total, - 'max_perf_level': device_manager.get_max_perf_level(device), + 'max_vram_usage_level': device_manager.get_max_vram_usage_level(device), } # list the compatible devices diff --git a/ui/server.py b/ui/server.py index 11d3731a..59881975 100644 --- a/ui/server.py +++ b/ui/server.py @@ -134,7 +134,7 @@ def render(req: dict): render_req.init_image_mask = req.get('mask') # hack: will rename this in the HTTP API in a future revision - app.save_to_config(task_data.use_stable_diffusion_model, task_data.use_vae_model, task_data.use_hypernetwork_model, task_data.performance_level) + app.save_to_config(task_data.use_stable_diffusion_model, task_data.use_vae_model, task_data.use_hypernetwork_model, task_data.vram_usage_level) # enqueue the task new_task = task_manager.render(render_req, task_data) From 25639cc3f80f56486308117ab26adee4d15cf19f Mon Sep 17 00:00:00 2001 From: cmdr2 Date: Fri, 16 Dec 2022 14:11:55 +0530 Subject: [PATCH 48/74] Tweak Memory Usage setting text; Fix a bug with the memory usage setting comparison --- ui/media/js/parameters.js | 4 ++-- ui/sd_internal/model_manager.py | 10 ++-------- 2 files changed, 4 insertions(+), 10 deletions(-) diff --git a/ui/media/js/parameters.js b/ui/media/js/parameters.js index 93f1a266..a8bd3b86 100644 --- a/ui/media/js/parameters.js +++ b/ui/media/js/parameters.js @@ -97,8 +97,8 @@ var PARAMETERS = [ id: "vram_usage_level", type: ParameterType.select, label: "GPU Memory Usage", - note: "Faster performance requires more GPU memory

      " + - "Balanced: almost as fast as High, significantly lower GPU memory usage
      " + + note: "Faster performance requires more GPU memory (VRAM)

      " + + "Balanced: nearly as fast as High, much lower VRAM usage
      " + "High: fastest, maximum GPU memory usage
      " + "Low: slowest, force-used for GPUs with 4 GB (or less) memory", icon: "fa-forward", diff --git a/ui/sd_internal/model_manager.py b/ui/sd_internal/model_manager.py index 56f1d9c0..e7729b81 100644 --- a/ui/sd_internal/model_manager.py +++ b/ui/sd_internal/model_manager.py @@ -125,19 +125,13 @@ def resolve_model_paths(task_data: TaskData): if task_data.use_upscale: task_data.use_upscale = resolve_model_to_use(task_data.use_upscale, 'gfpgan') def set_vram_optimizations(context: Context): - def is_greater(a, b): # is a > b? - if a == "low": # b will be "low", "balanced" or "high" - return False - elif a == "balanced" and b != "low": # b will be "balanced" or "high" - return False - return True - config = app.getConfig() max_usage_level = device_manager.get_max_vram_usage_level(context.device) vram_usage_level = config.get('vram_usage_level', 'balanced') - if is_greater(vram_usage_level, max_usage_level): + v = {'low': 0, 'balanced': 1, 'high': 2} + if v[vram_usage_level] > v[max_usage_level]: log.error(f'Requested GPU Memory Usage level ({vram_usage_level}) is higher than what is ' + \ f'possible ({max_usage_level}) on this device ({context.device}). Using "{max_usage_level}" instead') vram_usage_level = max_usage_level From aa8b50280b66cc69ecb449a37068af9bab2e2a43 Mon Sep 17 00:00:00 2001 From: cmdr2 Date: Fri, 16 Dec 2022 15:31:55 +0530 Subject: [PATCH 49/74] Remove the test_sd2 flag, the code now works with SD 2.0 --- ui/media/js/parameters.js | 18 +----------------- ui/sd_internal/app.py | 4 ---- ui/sd_internal/model_manager.py | 5 ----- ui/server.py | 3 --- 4 files changed, 1 insertion(+), 29 deletions(-) diff --git a/ui/media/js/parameters.js b/ui/media/js/parameters.js index a8bd3b86..4e7db4f5 100644 --- a/ui/media/js/parameters.js +++ b/ui/media/js/parameters.js @@ -164,14 +164,6 @@ var PARAMETERS = [ return `` } }, - { - id: "test_sd2", - type: ParameterType.checkbox, - label: "Test SD 2.0", - note: "Experimental! High memory usage! GPU-only! Not the final version! Please restart the program after changing this.", - icon: "fa-fire", - default: false, - }, { id: "use_beta_channel", type: ParameterType.checkbox, @@ -235,7 +227,6 @@ let saveToDiskField = document.querySelector('#save_to_disk') let diskPathField = document.querySelector('#diskPath') let listenToNetworkField = document.querySelector("#listen_to_network") let listenPortField = document.querySelector("#listen_port") -let testSD2Field = document.querySelector("#test_sd2") let useBetaChannelField = document.querySelector("#use_beta_channel") let uiOpenBrowserOnStartField = document.querySelector("#ui_open_browser_on_start") let confirmDangerousActionsField = document.querySelector("#confirm_dangerous_actions") @@ -272,12 +263,6 @@ async function getAppConfig() { if (config.ui && config.ui.open_browser_on_start === false) { uiOpenBrowserOnStartField.checked = false } - if ('test_sd2' in config) { - testSD2Field.checked = config['test_sd2'] - } - - let testSD2SettingEntry = getParameterSettingsEntry('test_sd2') - testSD2SettingEntry.style.display = (config.update_branch === 'beta' ? 
'' : 'none') if (config.net && config.net.listen_to_network === false) { listenToNetworkField.checked = false } @@ -442,8 +427,7 @@ saveSettingsBtn.addEventListener('click', function() { 'update_branch': updateBranch, 'ui_open_browser_on_start': uiOpenBrowserOnStartField.checked, 'listen_to_network': listenToNetworkField.checked, - 'listen_port': listenPortField.value, - 'test_sd2': testSD2Field.checked + 'listen_port': listenPortField.value }) saveSettingsBtn.classList.add('active') asyncDelay(300).then(() => saveSettingsBtn.classList.remove('active')) diff --git a/ui/sd_internal/app.py b/ui/sd_internal/app.py index 063069d2..471ee074 100644 --- a/ui/sd_internal/app.py +++ b/ui/sd_internal/app.py @@ -83,8 +83,6 @@ def setConfig(config): bind_ip = '0.0.0.0' if config['net']['listen_to_network'] else '127.0.0.1' config_bat.append(f"@set SD_UI_BIND_IP={bind_ip}") - config_bat.append(f"@set test_sd2={'Y' if config.get('test_sd2', False) else 'N'}") - if len(config_bat) > 0: with open(config_bat_path, 'w', encoding='utf-8') as f: f.write('\r\n'.join(config_bat)) @@ -102,8 +100,6 @@ def setConfig(config): bind_ip = '0.0.0.0' if config['net']['listen_to_network'] else '127.0.0.1' config_sh.append(f"export SD_UI_BIND_IP={bind_ip}") - config_sh.append(f"export test_sd2=\"{'Y' if config.get('test_sd2', False) else 'N'}\"") - if len(config_sh) > 1: with open(config_sh_path, 'w', encoding='utf-8') as f: f.write('\n'.join(config_sh)) diff --git a/ui/sd_internal/model_manager.py b/ui/sd_internal/model_manager.py index e7729b81..30e97053 100644 --- a/ui/sd_internal/model_manager.py +++ b/ui/sd_internal/model_manager.py @@ -65,11 +65,6 @@ def resolve_model_to_use(model_name:str=None, model_type:str=None): model_name = config['model'][model_type] if model_name: - is_sd2 = config.get('test_sd2', False) - if model_name.startswith('sd2_') and not is_sd2: # temp hack, until SD2 is unified with 1.4 - log.error('ERROR: Cannot use SD 2.0 models with SD 1.0 code. 
Using the sd-v1-4 model instead!')
-            model_name = 'sd-v1-4'
-
         # Check models directory
         models_dir_path = os.path.join(app.MODELS_DIR, model_type, model_name)
         for model_extension in model_extensions:
diff --git a/ui/server.py b/ui/server.py
index 59881975..284fed1b 100644
--- a/ui/server.py
+++ b/ui/server.py
@@ -41,7 +41,6 @@ class SetAppConfigRequest(BaseModel):
     ui_open_browser_on_start: bool = None
     listen_to_network: bool = None
     listen_port: int = None
-    test_sd2: bool = None
 
 server_api.mount('/media', NoCacheStaticFiles(directory=os.path.join(app.SD_UI_DIR, 'media')), name="media")
 
@@ -67,8 +66,6 @@ async def setAppConfig(req : SetAppConfigRequest):
         if 'net' not in config:
             config['net'] = {}
         config['net']['listen_port'] = int(req.listen_port)
-    if req.test_sd2 is not None:
-        config['test_sd2'] = req.test_sd2
 
     try:
         app.setConfig(config)

From 8189b38e6eec2f749c8a049991586bbb06f1185c Mon Sep 17 00:00:00 2001
From: cmdr2
Date: Sat, 17 Dec 2022 15:59:09 +0530
Subject: [PATCH 50/74] Typo in decoding live preview images

---
 ui/sd_internal/renderer.py | 7 +++----
 1 file changed, 3 insertions(+), 4 deletions(-)

diff --git a/ui/sd_internal/renderer.py b/ui/sd_internal/renderer.py
index 55b2c4a7..ce6032bc 100644
--- a/ui/sd_internal/renderer.py
+++ b/ui/sd_internal/renderer.py
@@ -99,15 +99,14 @@ def make_step_callback(req: GenerateImageRequest, task_data: TaskData, data_queu
     def update_temp_img(x_samples, task_temp_images: list):
         partial_images = []
-        for i in range(req.num_outputs):
-            img = image_utils.latent_to_img(context, x_samples[i].unsqueeze(0))
+        images = image_utils.latent_samples_to_images(context, x_samples)
+        for i, img in enumerate(images):
             buf = image_utils.img_to_buffer(img, output_format='JPEG')
-            del img
-
             context.temp_images[f"{task_data.request_id}/{i}"] = buf
             task_temp_images[i] = buf
             partial_images.append({'path': f"/image/tmp/{task_data.request_id}/{i}"})
+        del images
         return partial_images

From 1595f1ed05d3e03c58a55da3d7f2d3f9f2e27650 Mon Sep 17 00:00:00 2001
From: cmdr2
Date: Sat, 17 Dec 2022 16:45:43 +0530
Subject: [PATCH 51/74] Add 6 new samplers; Fix a bug where new tasks wouldn't start if a previous task was stopped

---
 ui/index.html              | 22 ++++++++++++++--------
 ui/sd_internal/renderer.py |  1 +
 2 files changed, 15 insertions(+), 8 deletions(-)

diff --git a/ui/index.html b/ui/index.html
index c746f6ea..9f36033f 100644
--- a/ui/index.html
+++ b/ui/index.html
@@ -133,14 +133,20 @@
 Click to learn more about samplers
diff --git a/ui/sd_internal/renderer.py b/ui/sd_internal/renderer.py
index ce6032bc..d9ab3f0d 100644
--- a/ui/sd_internal/renderer.py
+++ b/ui/sd_internal/renderer.py
@@ -27,6 +27,7 @@ def init(device):
     device_manager.device_init(context, device)
 
 def make_images(req: GenerateImageRequest, task_data: TaskData, data_queue: queue.Queue, task_temp_images: list, step_callback):
+    context.stop_processing = False
     log.info(f'request: {save_utils.get_printable_request(req)}')
     log.info(f'task data: {task_data.dict()}')

From e483071894a9fb6827b4eb7f07fdbfe031dfdffe Mon Sep 17 00:00:00 2001
From: cmdr2
Date: Mon, 19 Dec 2022 19:27:28 +0530
Subject: [PATCH 52/74] Rename diffusionkit to sdkit; Delete runtime.py (historic moment)

---
 ui/sd_internal/__init__.py      |    2 +-
 ui/sd_internal/model_manager.py |    5 +-
 ui/sd_internal/renderer.py      |    4 +-
 ui/sd_internal/runtime.py       | 1078 ------------------------------
 ui/sd_internal/save_utils.py    |   14 +-
 ui/sd_internal/task_manager.py  |    2 +-
 ui/server.py                    |    2 +-
 7 files changed, 14 insertions(+), 1093 
deletions(-) delete mode 100644 ui/sd_internal/runtime.py diff --git a/ui/sd_internal/__init__.py b/ui/sd_internal/__init__.py index 6f702749..3a748431 100644 --- a/ui/sd_internal/__init__.py +++ b/ui/sd_internal/__init__.py @@ -1,6 +1,6 @@ from pydantic import BaseModel -from diffusionkit.types import GenerateImageRequest +from sdkit.types import GenerateImageRequest class TaskData(BaseModel): request_id: str = None diff --git a/ui/sd_internal/model_manager.py b/ui/sd_internal/model_manager.py index 30e97053..ec72e6fa 100644 --- a/ui/sd_internal/model_manager.py +++ b/ui/sd_internal/model_manager.py @@ -1,11 +1,10 @@ import os import logging import picklescan.scanner -import rich from sd_internal import app, TaskData, device_manager -from diffusionkit import model_loader -from diffusionkit.types import Context +from sdkit.models import model_loader +from sdkit.types import Context log = logging.getLogger() diff --git a/ui/sd_internal/renderer.py b/ui/sd_internal/renderer.py index d9ab3f0d..1d289c99 100644 --- a/ui/sd_internal/renderer.py +++ b/ui/sd_internal/renderer.py @@ -6,8 +6,8 @@ import logging from sd_internal import device_manager, save_utils from sd_internal import TaskData, Response, Image as ResponseImage, UserInitiatedStop -from diffusionkit import model_loader, image_generator, image_utils, filters as image_filters, data_utils -from diffusionkit.types import Context, GenerateImageRequest, FilterImageRequest +from sdkit import model_loader, image_generator, image_utils, filters as image_filters +from sdkit.types import Context, GenerateImageRequest, FilterImageRequest log = logging.getLogger() diff --git a/ui/sd_internal/runtime.py b/ui/sd_internal/runtime.py deleted file mode 100644 index 2bf53d0c..00000000 --- a/ui/sd_internal/runtime.py +++ /dev/null @@ -1,1078 +0,0 @@ -"""runtime.py: torch device owned by a thread. -Notes: - Avoid device switching, transfering all models will get too complex. - To use a diffrent device signal the current render device to exit - And then start a new clean thread for the new device. -""" -import json -import os, re -import traceback -import queue -import torch -import numpy as np -from gc import collect as gc_collect -from omegaconf import OmegaConf -from PIL import Image, ImageOps -from tqdm import tqdm, trange -from itertools import islice -from einops import rearrange -import time -from pytorch_lightning import seed_everything -from torch import autocast -from contextlib import nullcontext -from einops import rearrange, repeat -from ldm.util import instantiate_from_config -from transformers import logging - -from gfpgan import GFPGANer -from basicsr.archs.rrdbnet_arch import RRDBNet -from realesrgan import RealESRGANer - -from server import HYPERNETWORK_MODEL_EXTENSIONS# , STABLE_DIFFUSION_MODEL_EXTENSIONS, VAE_MODEL_EXTENSIONS - -from threading import Lock -from safetensors.torch import load_file - -import uuid - -logging.set_verbosity_error() - -# consts -config_yaml = "optimizedSD/v1-inference.yaml" -filename_regex = re.compile('[^a-zA-Z0-9]') -gfpgan_temp_device_lock = Lock() # workaround: gfpgan currently can only start on one device at a time. - -# api stuff -from sd_internal import device_manager -from . 
import Request, Response, Image as ResponseImage -import base64 -from io import BytesIO -#from colorama import Fore - -from threading import local as LocalThreadVars -thread_data = LocalThreadVars() - -def thread_init(device): - # Thread bound properties - thread_data.stop_processing = False - thread_data.temp_images = {} - - thread_data.ckpt_file = None - thread_data.vae_file = None - thread_data.hypernetwork_file = None - thread_data.gfpgan_file = None - thread_data.real_esrgan_file = None - - thread_data.model = None - thread_data.modelCS = None - thread_data.modelFS = None - thread_data.hypernetwork = None - thread_data.hypernetwork_strength = 1 - thread_data.model_gfpgan = None - thread_data.model_real_esrgan = None - - thread_data.model_is_half = False - thread_data.model_fs_is_half = False - thread_data.device = None - thread_data.device_name = None - thread_data.unet_bs = 1 - thread_data.precision = 'autocast' - thread_data.sampler_plms = None - thread_data.sampler_ddim = None - - thread_data.turbo = False - thread_data.force_full_precision = False - thread_data.reduced_memory = True - - thread_data.test_sd2 = isSD2() - - device_manager.device_init(thread_data, device) - -# temp hack, will remove soon -def isSD2(): - try: - SD_UI_DIR = os.getenv('SD_UI_PATH', None) - CONFIG_DIR = os.path.abspath(os.path.join(SD_UI_DIR, '..', 'scripts')) - config_json_path = os.path.join(CONFIG_DIR, 'config.json') - if not os.path.exists(config_json_path): - return False - with open(config_json_path, 'r', encoding='utf-8') as f: - config = json.load(f) - return config.get('test_sd2', False) - except Exception as e: - return False - -def load_model_ckpt(): - if not thread_data.ckpt_file: raise ValueError(f'Thread ckpt_file is undefined.') - if os.path.exists(thread_data.ckpt_file + '.ckpt'): - thread_data.ckpt_file += '.ckpt' - elif os.path.exists(thread_data.ckpt_file + '.safetensors'): - thread_data.ckpt_file += '.safetensors' - elif not os.path.exists(thread_data.ckpt_file): - raise FileNotFoundError(f'Cannot find {thread_data.ckpt_file}.ckpt or .safetensors') - - if not thread_data.precision: - thread_data.precision = 'full' if thread_data.force_full_precision else 'autocast' - - if not thread_data.unet_bs: - thread_data.unet_bs = 1 - - if thread_data.device == 'cpu': - thread_data.precision = 'full' - - print('loading', thread_data.ckpt_file, 'to device', thread_data.device, 'using precision', thread_data.precision) - - if thread_data.test_sd2: - load_model_ckpt_sd2() - else: - load_model_ckpt_sd1() - -def load_model_ckpt_sd1(): - sd, model_ver = load_model_from_config(thread_data.ckpt_file) - li, lo = [], [] - for key, value in sd.items(): - sp = key.split(".") - if (sp[0]) == "model": - if "input_blocks" in sp: - li.append(key) - elif "middle_block" in sp: - li.append(key) - elif "time_embed" in sp: - li.append(key) - else: - lo.append(key) - for key in li: - sd["model1." + key[6:]] = sd.pop(key) - for key in lo: - sd["model2." 
+ key[6:]] = sd.pop(key) - - config = OmegaConf.load(f"{config_yaml}") - - model = instantiate_from_config(config.modelUNet) - _, _ = model.load_state_dict(sd, strict=False) - model.eval() - model.cdevice = torch.device(thread_data.device) - model.unet_bs = thread_data.unet_bs - model.turbo = thread_data.turbo - # if thread_data.device != 'cpu': - # model.to(thread_data.device) - #if thread_data.reduced_memory: - #model.model1.to("cpu") - #model.model2.to("cpu") - thread_data.model = model - - modelCS = instantiate_from_config(config.modelCondStage) - _, _ = modelCS.load_state_dict(sd, strict=False) - modelCS.eval() - modelCS.cond_stage_model.device = torch.device(thread_data.device) - # if thread_data.device != 'cpu': - # if thread_data.reduced_memory: - # modelCS.to('cpu') - # else: - # modelCS.to(thread_data.device) # Preload on device if not already there. - thread_data.modelCS = modelCS - - modelFS = instantiate_from_config(config.modelFirstStage) - _, _ = modelFS.load_state_dict(sd, strict=False) - - if thread_data.vae_file is not None: - try: - loaded = False - for model_extension in ['.ckpt', '.vae.pt']: - if os.path.exists(thread_data.vae_file + model_extension): - print(f"Loading VAE weights from: {thread_data.vae_file}{model_extension}") - vae_ckpt = torch.load(thread_data.vae_file + model_extension, map_location="cpu") - vae_dict = {k: v for k, v in vae_ckpt["state_dict"].items() if k[0:4] != "loss"} - modelFS.first_stage_model.load_state_dict(vae_dict, strict=False) - loaded = True - break - - if not loaded: - print(f'Cannot find VAE: {thread_data.vae_file}') - thread_data.vae_file = None - except: - print(traceback.format_exc()) - print(f'Could not load VAE: {thread_data.vae_file}') - thread_data.vae_file = None - - modelFS.eval() - # if thread_data.device != 'cpu': - # if thread_data.reduced_memory: - # modelFS.to('cpu') - # else: - # modelFS.to(thread_data.device) # Preload on device if not already there. 
- thread_data.modelFS = modelFS - del sd - - if thread_data.device != "cpu" and thread_data.precision == "autocast": - thread_data.model.half() - thread_data.modelCS.half() - thread_data.modelFS.half() - thread_data.model_is_half = True - thread_data.model_fs_is_half = True - else: - thread_data.model_is_half = False - thread_data.model_fs_is_half = False - - print(f'''loaded model - model file: {thread_data.ckpt_file} - model.device: {model.device} - modelCS.device: {modelCS.cond_stage_model.device} - modelFS.device: {thread_data.modelFS.device} - using precision: {thread_data.precision}''') - -def load_model_ckpt_sd2(): - sd, model_ver = load_model_from_config(thread_data.ckpt_file) - - config_file = 'configs/stable-diffusion/v2-inference-v.yaml' if model_ver == 'sd2' else "configs/stable-diffusion/v1-inference.yaml" - config = OmegaConf.load(config_file) - verbose = False - - thread_data.model = instantiate_from_config(config.model) - m, u = thread_data.model.load_state_dict(sd, strict=False) - if len(m) > 0 and verbose: - print("missing keys:") - print(m) - if len(u) > 0 and verbose: - print("unexpected keys:") - print(u) - - thread_data.model.to(thread_data.device) - thread_data.model.eval() - del sd - - thread_data.model.cond_stage_model.device = torch.device(thread_data.device) - - if thread_data.device != "cpu" and thread_data.precision == "autocast": - thread_data.model.half() - thread_data.model_is_half = True - thread_data.model_fs_is_half = True - else: - thread_data.model_is_half = False - thread_data.model_fs_is_half = False - - print(f'''loaded model - model file: {thread_data.ckpt_file} - using precision: {thread_data.precision}''') - -def unload_filters(): - if thread_data.model_gfpgan is not None: - if thread_data.device != 'cpu': thread_data.model_gfpgan.gfpgan.to('cpu') - - del thread_data.model_gfpgan - thread_data.model_gfpgan = None - - if thread_data.model_real_esrgan is not None: - if thread_data.device != 'cpu': thread_data.model_real_esrgan.model.to('cpu') - - del thread_data.model_real_esrgan - thread_data.model_real_esrgan = None - - gc() - -def unload_models(): - if thread_data.model is not None: - print('Unloading models...') - if thread_data.device != 'cpu': - if not thread_data.test_sd2: - thread_data.modelFS.to('cpu') - thread_data.modelCS.to('cpu') - thread_data.model.model1.to("cpu") - thread_data.model.model2.to("cpu") - - del thread_data.model - del thread_data.modelCS - del thread_data.modelFS - - thread_data.model = None - thread_data.modelCS = None - thread_data.modelFS = None - - gc() - -# def wait_model_move_to(model, target_device): # Send to target_device and wait until complete. -# if thread_data.device == target_device: return -# start_mem = torch.cuda.memory_allocated(thread_data.device) / 1e6 -# if start_mem <= 0: return -# model_name = model.__class__.__name__ -# print(f'Device {thread_data.device} - Sending model {model_name} to {target_device} | Memory transfer starting. Memory Used: {round(start_mem)}Mb') -# start_time = time.time() -# model.to(target_device) -# time_step = start_time -# WARNING_TIMEOUT = 1.5 # seconds - Show activity in console after timeout. -# last_mem = start_mem -# is_transfering = True -# while is_transfering: -# time.sleep(0.5) # 500ms -# mem = torch.cuda.memory_allocated(thread_data.device) / 1e6 -# is_transfering = bool(mem > 0 and mem < last_mem) # still stuff loaded, but less than last time. 
-# last_mem = mem -# if not is_transfering: -# break; -# if time.time() - time_step > WARNING_TIMEOUT: # Long delay, print to console to show activity. -# print(f'Device {thread_data.device} - Waiting for Memory transfer. Memory Used: {round(mem)}Mb, Transfered: {round(start_mem - mem)}Mb') -# time_step = time.time() -# print(f'Device {thread_data.device} - {model_name} Moved: {round(start_mem - last_mem)}Mb in {round(time.time() - start_time, 3)} seconds to {target_device}') - -def move_to_cpu(model): - if thread_data.device != "cpu": - d = torch.device(thread_data.device) - mem = torch.cuda.memory_allocated(d) / 1e6 - model.to("cpu") - while torch.cuda.memory_allocated(d) / 1e6 >= mem: - time.sleep(1) - -def load_model_gfpgan(): - if thread_data.gfpgan_file is None: raise ValueError(f'Thread gfpgan_file is undefined.') - model_path = thread_data.gfpgan_file + ".pth" - thread_data.model_gfpgan = GFPGANer(device=torch.device(thread_data.device), model_path=model_path, upscale=1, arch='clean', channel_multiplier=2, bg_upsampler=None) - print('loaded', thread_data.gfpgan_file, 'to', thread_data.model_gfpgan.device, 'precision', thread_data.precision) - -def load_model_real_esrgan(): - if thread_data.real_esrgan_file is None: raise ValueError(f'Thread real_esrgan_file is undefined.') - model_path = thread_data.real_esrgan_file + ".pth" - - RealESRGAN_models = { - 'RealESRGAN_x4plus': RRDBNet(num_in_ch=3, num_out_ch=3, num_feat=64, num_block=23, num_grow_ch=32, scale=4), - 'RealESRGAN_x4plus_anime_6B': RRDBNet(num_in_ch=3, num_out_ch=3, num_feat=64, num_block=6, num_grow_ch=32, scale=4) - } - - model_to_use = RealESRGAN_models[thread_data.real_esrgan_file] - - if thread_data.device == 'cpu': - thread_data.model_real_esrgan = RealESRGANer(device=torch.device(thread_data.device), scale=2, model_path=model_path, model=model_to_use, pre_pad=0, half=False) # cpu does not support half - #thread_data.model_real_esrgan.device = torch.device(thread_data.device) - thread_data.model_real_esrgan.model.to('cpu') - else: - thread_data.model_real_esrgan = RealESRGANer(device=torch.device(thread_data.device), scale=2, model_path=model_path, model=model_to_use, pre_pad=0, half=thread_data.model_is_half) - - thread_data.model_real_esrgan.model.name = thread_data.real_esrgan_file - print('loaded ', thread_data.real_esrgan_file, 'to', thread_data.model_real_esrgan.device, 'precision', thread_data.precision) - - -def get_session_out_path(disk_path, session_id): - if disk_path is None: return None - if session_id is None: return None - - session_out_path = os.path.join(disk_path, filename_regex.sub('_',session_id)) - os.makedirs(session_out_path, exist_ok=True) - return session_out_path - -def get_base_path(disk_path, session_id, prompt, img_id, ext, suffix=None): - if disk_path is None: return None - if session_id is None: return None - if ext is None: raise Exception('Missing ext') - - session_out_path = get_session_out_path(disk_path, session_id) - - prompt_flattened = filename_regex.sub('_', prompt)[:50] - - if suffix is not None: - return os.path.join(session_out_path, f"{prompt_flattened}_{img_id}_{suffix}.{ext}") - return os.path.join(session_out_path, f"{prompt_flattened}_{img_id}.{ext}") - -def apply_filters(filter_name, image_data, model_path=None): - print(f'Applying filter {filter_name}...') - gc() # Free space before loading new data. - - if isinstance(image_data, torch.Tensor): - image_data.to(thread_data.device) - - if filter_name == 'gfpgan': - # This lock is only ever used here. 
No need to use timeout for the request. Should never deadlock. - with gfpgan_temp_device_lock: # Wait for any other devices to complete before starting. - # hack for a bug in facexlib: https://github.com/xinntao/facexlib/pull/19/files - from facexlib.detection import retinaface - retinaface.device = torch.device(thread_data.device) - print('forced retinaface.device to', thread_data.device) - - if model_path is not None and model_path != thread_data.gfpgan_file: - thread_data.gfpgan_file = model_path - load_model_gfpgan() - elif not thread_data.model_gfpgan: - load_model_gfpgan() - if thread_data.model_gfpgan is None: raise Exception('Model "gfpgan" not loaded.') - - print('enhance with', thread_data.gfpgan_file, 'on', thread_data.model_gfpgan.device, 'precision', thread_data.precision) - _, _, output = thread_data.model_gfpgan.enhance(image_data[:,:,::-1], has_aligned=False, only_center_face=False, paste_back=True) - image_data = output[:,:,::-1] - - if filter_name == 'real_esrgan': - if model_path is not None and model_path != thread_data.real_esrgan_file: - thread_data.real_esrgan_file = model_path - load_model_real_esrgan() - elif not thread_data.model_real_esrgan: - load_model_real_esrgan() - if thread_data.model_real_esrgan is None: raise Exception('Model "gfpgan" not loaded.') - print('enhance with', thread_data.real_esrgan_file, 'on', thread_data.model_real_esrgan.device, 'precision', thread_data.precision) - output, _ = thread_data.model_real_esrgan.enhance(image_data[:,:,::-1]) - image_data = output[:,:,::-1] - - return image_data - -def is_model_reload_necessary(req: Request): - # custom model support: - # the req.use_stable_diffusion_model needs to be a valid path - # to the ckpt file (without the extension). - if os.path.exists(req.use_stable_diffusion_model + '.ckpt'): - req.use_stable_diffusion_model += '.ckpt' - elif os.path.exists(req.use_stable_diffusion_model + '.safetensors'): - req.use_stable_diffusion_model += '.safetensors' - elif not os.path.exists(req.use_stable_diffusion_model): - raise FileNotFoundError(f'Cannot find {req.use_stable_diffusion_model}.ckpt or .safetensors') - - needs_model_reload = False - if not thread_data.model or thread_data.ckpt_file != req.use_stable_diffusion_model or thread_data.vae_file != req.use_vae_model: - thread_data.ckpt_file = req.use_stable_diffusion_model - thread_data.vae_file = req.use_vae_model - needs_model_reload = True - - if thread_data.device != 'cpu': - if (thread_data.precision == 'autocast' and (req.use_full_precision or not thread_data.model_is_half)) or \ - (thread_data.precision == 'full' and not req.use_full_precision and not thread_data.force_full_precision): - thread_data.precision = 'full' if req.use_full_precision else 'autocast' - needs_model_reload = True - - return needs_model_reload - -def reload_model(): - unload_models() - unload_filters() - load_model_ckpt() - -def is_hypernetwork_reload_necessary(req: Request): - needs_model_reload = False - if thread_data.hypernetwork_file != req.use_hypernetwork_model: - thread_data.hypernetwork_file = req.use_hypernetwork_model - needs_model_reload = True - - return needs_model_reload - -def load_hypernetwork(): - if thread_data.test_sd2: - # Not yet supported in SD2 - return - - from . 
import hypernetwork - if thread_data.hypernetwork_file is not None: - try: - loaded = False - for model_extension in HYPERNETWORK_MODEL_EXTENSIONS: - if os.path.exists(thread_data.hypernetwork_file + model_extension): - print(f"Loading hypernetwork weights from: {thread_data.hypernetwork_file}{model_extension}") - thread_data.hypernetwork = hypernetwork.load_hypernetwork(thread_data.hypernetwork_file + model_extension) - loaded = True - break - - if not loaded: - print(f'Cannot find hypernetwork: {thread_data.hypernetwork_file}') - thread_data.hypernetwork_file = None - except: - print(traceback.format_exc()) - print(f'Could not load hypernetwork: {thread_data.hypernetwork_file}') - thread_data.hypernetwork_file = None - -def unload_hypernetwork(): - if thread_data.hypernetwork is not None: - print('Unloading hypernetwork...') - if thread_data.device != 'cpu': - for i in thread_data.hypernetwork: - thread_data.hypernetwork[i][0].to('cpu') - thread_data.hypernetwork[i][1].to('cpu') - del thread_data.hypernetwork - thread_data.hypernetwork = None - - gc() - -def reload_hypernetwork(): - unload_hypernetwork() - load_hypernetwork() - -def mk_img(req: Request, data_queue: queue.Queue, task_temp_images: list, step_callback): - try: - return do_mk_img(req, data_queue, task_temp_images, step_callback) - except Exception as e: - print(traceback.format_exc()) - - if thread_data.device != 'cpu' and not thread_data.test_sd2: - thread_data.modelFS.to('cpu') - thread_data.modelCS.to('cpu') - thread_data.model.model1.to("cpu") - thread_data.model.model2.to("cpu") - - gc() # Release from memory. - data_queue.put(json.dumps({ - "status": 'failed', - "detail": str(e) - })) - raise e - -def update_temp_img(req, x_samples, task_temp_images: list): - partial_images = [] - for i in range(req.num_outputs): - if thread_data.test_sd2: - x_sample_ddim = thread_data.model.decode_first_stage(x_samples[i].unsqueeze(0)) - else: - x_sample_ddim = thread_data.modelFS.decode_first_stage(x_samples[i].unsqueeze(0)) - x_sample = torch.clamp((x_sample_ddim + 1.0) / 2.0, min=0.0, max=1.0) - x_sample = 255.0 * rearrange(x_sample[0].cpu().numpy(), "c h w -> h w c") - x_sample = x_sample.astype(np.uint8) - img = Image.fromarray(x_sample) - buf = img_to_buffer(img, output_format='JPEG') - - del img, x_sample, x_sample_ddim - # don't delete x_samples, it is used in the code that called this callback - - thread_data.temp_images[f'{req.request_id}/{i}'] = buf - task_temp_images[i] = buf - partial_images.append({'path': f'/image/tmp/{req.request_id}/{i}'}) - return partial_images - -# Build and return the apropriate generator for do_mk_img -def get_image_progress_generator(req, data_queue: queue.Queue, task_temp_images: list, step_callback, extra_props=None): - if not req.stream_progress_updates: - def empty_callback(x_samples, i): - step_callback() - return empty_callback - - thread_data.partial_x_samples = None - last_callback_time = -1 - def img_callback(x_samples, i): - nonlocal last_callback_time - - thread_data.partial_x_samples = x_samples - step_time = time.time() - last_callback_time if last_callback_time != -1 else -1 - last_callback_time = time.time() - - progress = {"step": i, "step_time": step_time} - if extra_props is not None: - progress.update(extra_props) - - if req.stream_image_progress and i % 5 == 0: - progress['output'] = update_temp_img(req, x_samples, task_temp_images) - - data_queue.put(json.dumps(progress)) - - step_callback() - - if thread_data.stop_processing: - raise UserInitiatedStop("User requested that 
we stop processing") - return img_callback - -def do_mk_img(req: Request, data_queue: queue.Queue, task_temp_images: list, step_callback): - thread_data.stop_processing = False - - res = Response() - res.request = req - res.images = [] - thread_data.hypernetwork_strength = req.hypernetwork_strength - - thread_data.temp_images.clear() - - if thread_data.turbo != req.turbo and not thread_data.test_sd2: - thread_data.turbo = req.turbo - thread_data.model.turbo = req.turbo - - # Start by cleaning memory, loading and unloading things can leave memory allocated. - gc() - - opt_prompt = req.prompt - opt_seed = req.seed - opt_n_iter = 1 - opt_C = 4 - opt_f = 8 - opt_ddim_eta = 0.0 - - print(req, '\n device', torch.device(thread_data.device), "as", thread_data.device_name) - print('\n\n Using precision:', thread_data.precision) - - seed_everything(opt_seed) - - batch_size = req.num_outputs - prompt = opt_prompt - assert prompt is not None - data = [batch_size * [prompt]] - - if thread_data.precision == "autocast" and thread_data.device != "cpu": - precision_scope = autocast - else: - precision_scope = nullcontext - - mask = None - - if req.init_image is None: - handler = _txt2img - - init_latent = None - t_enc = None - else: - handler = _img2img - - init_image = load_img(req.init_image, req.width, req.height) - init_image = init_image.to(thread_data.device) - - if thread_data.device != "cpu" and thread_data.precision == "autocast": - init_image = init_image.half() - - if not thread_data.test_sd2: - thread_data.modelFS.to(thread_data.device) - - init_image = repeat(init_image, '1 ... -> b ...', b=batch_size) - if thread_data.test_sd2: - init_latent = thread_data.model.get_first_stage_encoding(thread_data.model.encode_first_stage(init_image)) # move to latent space - else: - init_latent = thread_data.modelFS.get_first_stage_encoding(thread_data.modelFS.encode_first_stage(init_image)) # move to latent space - - if req.mask is not None: - mask = load_mask(req.mask, req.width, req.height, init_latent.shape[2], init_latent.shape[3], True).to(thread_data.device) - mask = mask[0][0].unsqueeze(0).repeat(4, 1, 1).unsqueeze(0) - mask = repeat(mask, '1 ... -> b ...', b=batch_size) - - if thread_data.device != "cpu" and thread_data.precision == "autocast": - mask = mask.half() - - # Send to CPU and wait until complete. - # wait_model_move_to(thread_data.modelFS, 'cpu') - if not thread_data.test_sd2: - move_to_cpu(thread_data.modelFS) - - assert 0. 
<= req.prompt_strength <= 1., 'can only work with strength in [0.0, 1.0]' - t_enc = int(req.prompt_strength * req.num_inference_steps) - print(f"target t_enc is {t_enc} steps") - - with torch.no_grad(): - for n in trange(opt_n_iter, desc="Sampling"): - for prompts in tqdm(data, desc="data"): - - with precision_scope("cuda"): - if thread_data.reduced_memory and not thread_data.test_sd2: - thread_data.modelCS.to(thread_data.device) - uc = None - if req.guidance_scale != 1.0: - if thread_data.test_sd2: - uc = thread_data.model.get_learned_conditioning(batch_size * [req.negative_prompt]) - else: - uc = thread_data.modelCS.get_learned_conditioning(batch_size * [req.negative_prompt]) - if isinstance(prompts, tuple): - prompts = list(prompts) - - subprompts, weights = split_weighted_subprompts(prompts[0]) - if len(subprompts) > 1: - c = torch.zeros_like(uc) - totalWeight = sum(weights) - # normalize each "sub prompt" and add it - for i in range(len(subprompts)): - weight = weights[i] - # if not skip_normalize: - weight = weight / totalWeight - if thread_data.test_sd2: - c = torch.add(c, thread_data.model.get_learned_conditioning(subprompts[i]), alpha=weight) - else: - c = torch.add(c, thread_data.modelCS.get_learned_conditioning(subprompts[i]), alpha=weight) - else: - if thread_data.test_sd2: - c = thread_data.model.get_learned_conditioning(prompts) - else: - c = thread_data.modelCS.get_learned_conditioning(prompts) - - if thread_data.reduced_memory and not thread_data.test_sd2: - thread_data.modelFS.to(thread_data.device) - - n_steps = req.num_inference_steps if req.init_image is None else t_enc - img_callback = get_image_progress_generator(req, data_queue, task_temp_images, step_callback, {"total_steps": n_steps}) - - # run the handler - try: - print('Running handler...') - if handler == _txt2img: - x_samples = _txt2img(req.width, req.height, req.num_outputs, req.num_inference_steps, req.guidance_scale, None, opt_C, opt_f, opt_ddim_eta, c, uc, opt_seed, img_callback, mask, req.sampler) - else: - x_samples = _img2img(init_latent, t_enc, batch_size, req.guidance_scale, c, uc, req.num_inference_steps, opt_ddim_eta, opt_seed, img_callback, mask, opt_C, req.height, req.width, opt_f) - except UserInitiatedStop: - if not hasattr(thread_data, 'partial_x_samples'): - continue - if thread_data.partial_x_samples is None: - del thread_data.partial_x_samples - continue - x_samples = thread_data.partial_x_samples - del thread_data.partial_x_samples - - print("decoding images") - img_data = [None] * batch_size - for i in range(batch_size): - if thread_data.test_sd2: - x_samples_ddim = thread_data.model.decode_first_stage(x_samples[i].unsqueeze(0)) - else: - x_samples_ddim = thread_data.modelFS.decode_first_stage(x_samples[i].unsqueeze(0)) - x_sample = torch.clamp((x_samples_ddim + 1.0) / 2.0, min=0.0, max=1.0) - x_sample = 255.0 * rearrange(x_sample[0].cpu().numpy(), "c h w -> h w c") - x_sample = x_sample.astype(np.uint8) - img_data[i] = x_sample - del x_samples, x_samples_ddim, x_sample - - print("saving images") - for i in range(batch_size): - img = Image.fromarray(img_data[i]) - img_id = base64.b64encode(int(time.time()+i).to_bytes(8, 'big')).decode() # Generate unique ID based on time. - img_id = img_id.translate({43:None, 47:None, 61:None})[-8:] # Remove + / = and keep last 8 chars. 
- - has_filters = (req.use_face_correction is not None and req.use_face_correction.startswith('GFPGAN')) or \ - (req.use_upscale is not None and req.use_upscale.startswith('RealESRGAN')) - - return_orig_img = not has_filters or not req.show_only_filtered_image - - if thread_data.stop_processing: - return_orig_img = True - - if req.save_to_disk_path is not None: - if return_orig_img: - img_out_path = get_base_path(req.save_to_disk_path, req.session_id, prompts[0], img_id, req.output_format) - save_image(img, img_out_path, req.output_format, req.output_quality) - meta_out_path = get_base_path(req.save_to_disk_path, req.session_id, prompts[0], img_id, 'txt') - save_metadata(meta_out_path, req, prompts[0], opt_seed) - - if return_orig_img: - img_buffer = img_to_buffer(img, req.output_format, req.output_quality) - img_str = buffer_to_base64_str(img_buffer, req.output_format) - res_image_orig = ResponseImage(data=img_str, seed=opt_seed) - res.images.append(res_image_orig) - task_temp_images[i] = img_buffer - - if req.save_to_disk_path is not None: - res_image_orig.path_abs = img_out_path - del img - - if has_filters and not thread_data.stop_processing: - filters_applied = [] - if req.use_face_correction: - img_data[i] = apply_filters('gfpgan', img_data[i], req.use_face_correction) - filters_applied.append(req.use_face_correction) - if req.use_upscale: - img_data[i] = apply_filters('real_esrgan', img_data[i], req.use_upscale) - filters_applied.append(req.use_upscale) - if (len(filters_applied) > 0): - filtered_image = Image.fromarray(img_data[i]) - filtered_buffer = img_to_buffer(filtered_image, req.output_format, req.output_quality) - filtered_img_data = buffer_to_base64_str(filtered_buffer, req.output_format) - response_image = ResponseImage(data=filtered_img_data, seed=opt_seed) - res.images.append(response_image) - task_temp_images[i] = filtered_buffer - if req.save_to_disk_path is not None: - filtered_img_out_path = get_base_path(req.save_to_disk_path, req.session_id, prompts[0], img_id, req.output_format, "_".join(filters_applied)) - save_image(filtered_image, filtered_img_out_path, req.output_format, req.output_quality) - response_image.path_abs = filtered_img_out_path - del filtered_image - # Filter Applied, move to next seed - opt_seed += 1 - - # if thread_data.reduced_memory: - # unload_filters() - if not thread_data.test_sd2: - move_to_cpu(thread_data.modelFS) - del img_data - gc() - if thread_data.device != 'cpu': - print(f'memory_final = {round(torch.cuda.memory_allocated(thread_data.device) / 1e6, 2)}Mb') - - print('Task completed') - res = res.json() - data_queue.put(json.dumps(res)) - - return res - -def save_image(img, img_out_path, output_format="", output_quality=75): - try: - if output_format.upper() == "JPEG": - img.save(img_out_path, quality=output_quality) - else: - img.save(img_out_path) - except: - print('could not save the file', traceback.format_exc()) - -def save_metadata(meta_out_path, req, prompt, opt_seed): - metadata = f'''{prompt} -Width: {req.width} -Height: {req.height} -Seed: {opt_seed} -Steps: {req.num_inference_steps} -Guidance Scale: {req.guidance_scale} -Prompt Strength: {req.prompt_strength} -Use Face Correction: {req.use_face_correction} -Use Upscaling: {req.use_upscale} -Sampler: {req.sampler} -Negative Prompt: {req.negative_prompt} -Stable Diffusion model: {req.use_stable_diffusion_model + '.ckpt'} -VAE model: {req.use_vae_model} -Hypernetwork Model: {req.use_hypernetwork_model} -Hypernetwork Strength: {req.hypernetwork_strength} -''' - try: - with 
open(meta_out_path, 'w', encoding='utf-8') as f: - f.write(metadata) - except: - print('could not save the file', traceback.format_exc()) - -def _txt2img(opt_W, opt_H, opt_n_samples, opt_ddim_steps, opt_scale, start_code, opt_C, opt_f, opt_ddim_eta, c, uc, opt_seed, img_callback, mask, sampler_name): - shape = [opt_n_samples, opt_C, opt_H // opt_f, opt_W // opt_f] - - # Send to CPU and wait until complete. - # wait_model_move_to(thread_data.modelCS, 'cpu') - - if not thread_data.test_sd2: - move_to_cpu(thread_data.modelCS) - - if thread_data.test_sd2 and sampler_name not in ('plms', 'ddim', 'dpm2'): - raise Exception('Only plms, ddim and dpm2 samplers are supported right now, in SD 2.0') - - - # samples, _ = sampler.sample(S=opt.steps, - # conditioning=c, - # batch_size=opt.n_samples, - # shape=shape, - # verbose=False, - # unconditional_guidance_scale=opt.scale, - # unconditional_conditioning=uc, - # eta=opt.ddim_eta, - # x_T=start_code) - - if thread_data.test_sd2: - if sampler_name == 'plms': - from ldm.models.diffusion.plms import PLMSSampler - sampler = PLMSSampler(thread_data.model) - elif sampler_name == 'ddim': - from ldm.models.diffusion.ddim import DDIMSampler - sampler = DDIMSampler(thread_data.model) - sampler.make_schedule(ddim_num_steps=opt_ddim_steps, ddim_eta=opt_ddim_eta, verbose=False) - elif sampler_name == 'dpm2': - from ldm.models.diffusion.dpm_solver import DPMSolverSampler - sampler = DPMSolverSampler(thread_data.model) - - shape = [opt_C, opt_H // opt_f, opt_W // opt_f] - - samples_ddim, intermediates = sampler.sample( - S=opt_ddim_steps, - conditioning=c, - batch_size=opt_n_samples, - seed=opt_seed, - shape=shape, - verbose=False, - unconditional_guidance_scale=opt_scale, - unconditional_conditioning=uc, - eta=opt_ddim_eta, - x_T=start_code, - img_callback=img_callback, - mask=mask, - sampler = sampler_name, - ) - else: - if sampler_name == 'ddim': - thread_data.model.make_schedule(ddim_num_steps=opt_ddim_steps, ddim_eta=opt_ddim_eta, verbose=False) - - samples_ddim = thread_data.model.sample( - S=opt_ddim_steps, - conditioning=c, - seed=opt_seed, - shape=shape, - verbose=False, - unconditional_guidance_scale=opt_scale, - unconditional_conditioning=uc, - eta=opt_ddim_eta, - x_T=start_code, - img_callback=img_callback, - mask=mask, - sampler = sampler_name, - ) - return samples_ddim - -def _img2img(init_latent, t_enc, batch_size, opt_scale, c, uc, opt_ddim_steps, opt_ddim_eta, opt_seed, img_callback, mask, opt_C=1, opt_H=1, opt_W=1, opt_f=1): - # encode (scaled latent) - x_T = None if mask is None else init_latent - - if thread_data.test_sd2: - from ldm.models.diffusion.ddim import DDIMSampler - - sampler = DDIMSampler(thread_data.model) - - sampler.make_schedule(ddim_num_steps=opt_ddim_steps, ddim_eta=opt_ddim_eta, verbose=False) - - z_enc = sampler.stochastic_encode(init_latent, torch.tensor([t_enc] * batch_size).to(thread_data.device)) - - samples_ddim = sampler.decode(z_enc, c, t_enc, unconditional_guidance_scale=opt_scale,unconditional_conditioning=uc, img_callback=img_callback) - - else: - z_enc = thread_data.model.stochastic_encode( - init_latent, - torch.tensor([t_enc] * batch_size).to(thread_data.device), - opt_seed, - opt_ddim_eta, - opt_ddim_steps, - ) - - # decode it - samples_ddim = thread_data.model.sample( - t_enc, - c, - z_enc, - unconditional_guidance_scale=opt_scale, - unconditional_conditioning=uc, - img_callback=img_callback, - mask=mask, - x_T=x_T, - sampler = 'ddim' - ) - return samples_ddim - -def gc(): - gc_collect() - if thread_data.device 
== 'cpu': - return - torch.cuda.empty_cache() - torch.cuda.ipc_collect() - -# internal - -def chunk(it, size): - it = iter(it) - return iter(lambda: tuple(islice(it, size)), ()) - -def load_model_from_config(ckpt, verbose=False): - print(f"Loading model from {ckpt}") - model_ver = 'sd1' - - if ckpt.endswith(".safetensors"): - print("Loading from safetensors") - pl_sd = load_file(ckpt, device="cpu") - else: - pl_sd = torch.load(ckpt, map_location="cpu") - - if "global_step" in pl_sd: - print(f"Global Step: {pl_sd['global_step']}") - - if "state_dict" in pl_sd: - # check for a key that only seems to be present in SD2 models - if 'cond_stage_model.model.ln_final.bias' in pl_sd['state_dict'].keys(): - model_ver = 'sd2' - - return pl_sd["state_dict"], model_ver - else: - return pl_sd, model_ver - -class UserInitiatedStop(Exception): - pass - -def load_img(img_str, w0, h0): - image = base64_str_to_img(img_str).convert("RGB") - w, h = image.size - print(f"loaded input image of size ({w}, {h}) from base64") - if h0 is not None and w0 is not None: - h, w = h0, w0 - - w, h = map(lambda x: x - x % 64, (w, h)) # resize to integer multiple of 64 - image = image.resize((w, h), resample=Image.Resampling.LANCZOS) - image = np.array(image).astype(np.float32) / 255.0 - image = image[None].transpose(0, 3, 1, 2) - image = torch.from_numpy(image) - return 2.*image - 1. - -def load_mask(mask_str, h0, w0, newH, newW, invert=False): - image = base64_str_to_img(mask_str).convert("RGB") - w, h = image.size - print(f"loaded input mask of size ({w}, {h})") - - if invert: - print("inverted") - image = ImageOps.invert(image) - # where_0, where_1 = np.where(image == 0), np.where(image == 255) - # image[where_0], image[where_1] = 255, 0 - - if h0 is not None and w0 is not None: - h, w = h0, w0 - - w, h = map(lambda x: x - x % 64, (w, h)) # resize to integer multiple of 64 - - print(f"New mask size ({w}, {h})") - image = image.resize((newW, newH), resample=Image.Resampling.LANCZOS) - image = np.array(image) - - image = image.astype(np.float32) / 255.0 - image = image[None].transpose(0, 3, 1, 2) - image = torch.from_numpy(image) - return image - -# https://stackoverflow.com/a/61114178 -def img_to_base64_str(img, output_format="PNG", output_quality=75): - buffered = img_to_buffer(img, output_format, quality=output_quality) - return buffer_to_base64_str(buffered, output_format) - -def img_to_buffer(img, output_format="PNG", output_quality=75): - buffered = BytesIO() - if ( output_format.upper() == "JPEG" ): - img.save(buffered, format=output_format, quality=output_quality) - else: - img.save(buffered, format=output_format) - buffered.seek(0) - return buffered - -def buffer_to_base64_str(buffered, output_format="PNG"): - buffered.seek(0) - img_byte = buffered.getvalue() - mime_type = "image/png" if output_format.lower() == "png" else "image/jpeg" - img_str = f"data:{mime_type};base64," + base64.b64encode(img_byte).decode() - return img_str - -def base64_str_to_buffer(img_str): - mime_type = "image/png" if img_str.startswith("data:image/png;") else "image/jpeg" - img_str = img_str[len(f"data:{mime_type};base64,"):] - data = base64.b64decode(img_str) - buffered = BytesIO(data) - return buffered - -def base64_str_to_img(img_str): - buffered = base64_str_to_buffer(img_str) - img = Image.open(buffered) - return img - -def split_weighted_subprompts(text): - """ - grabs all text up to the first occurrence of ':' - uses the grabbed text as a sub-prompt, and takes the value following ':' as weight - if ':' has no value defined, 
defaults to 1.0 - repeats until no text remaining - """ - remaining = len(text) - prompts = [] - weights = [] - while remaining > 0: - if ":" in text: - idx = text.index(":") # first occurrence from start - # grab up to index as sub-prompt - prompt = text[:idx] - remaining -= idx - # remove from main text - text = text[idx+1:] - # find value for weight - if " " in text: - idx = text.index(" ") # first occurence - else: # no space, read to end - idx = len(text) - if idx != 0: - try: - weight = float(text[:idx]) - except: # couldn't treat as float - print(f"Warning: '{text[:idx]}' is not a value, are you missing a space?") - weight = 1.0 - else: # no value found - weight = 1.0 - # remove from main text - remaining -= idx - text = text[idx+1:] - # append the sub-prompt and its weight - prompts.append(prompt) - weights.append(weight) - else: # no : found - if len(text) > 0: # there is still text though - # take remainder as weight 1 - prompts.append(text) - weights.append(1.0) - remaining = 0 - return prompts, weights diff --git a/ui/sd_internal/save_utils.py b/ui/sd_internal/save_utils.py index 29c4a29c..db9d0671 100644 --- a/ui/sd_internal/save_utils.py +++ b/ui/sd_internal/save_utils.py @@ -3,8 +3,8 @@ import time import base64 import re -from diffusionkit import data_utils -from diffusionkit.types import GenerateImageRequest +from sdkit.utils import save_images, save_dicts +from sdkit.types import GenerateImageRequest from sd_internal import TaskData @@ -33,12 +33,12 @@ def save_to_disk(images: list, filtered_images: list, req: GenerateImageRequest, metadata_entries = get_metadata_entries(req, task_data) if task_data.show_only_filtered_image or filtered_images == images: - data_utils.save_images(filtered_images, save_folder_path, file_name=make_filename_callback(req), output_format=task_data.output_format, output_quality=task_data.output_quality) - data_utils.save_metadata(metadata_entries, save_folder_path, file_name=make_filename_callback(req), output_format=task_data.metadata_output_format) + save_images(filtered_images, save_folder_path, file_name=make_filename_callback(req), output_format=task_data.output_format, output_quality=task_data.output_quality) + save_dicts(metadata_entries, save_folder_path, file_name=make_filename_callback(req), output_format=task_data.metadata_output_format) else: - data_utils.save_images(images, save_folder_path, file_name=make_filename_callback(req), output_format=task_data.output_format, output_quality=task_data.output_quality) - data_utils.save_images(filtered_images, save_folder_path, file_name=make_filename_callback(req, suffix='filtered'), output_format=task_data.output_format, output_quality=task_data.output_quality) - data_utils.save_metadata(metadata_entries, save_folder_path, file_name=make_filename_callback(req, suffix='filtered'), output_format=task_data.metadata_output_format) + save_images(images, save_folder_path, file_name=make_filename_callback(req), output_format=task_data.output_format, output_quality=task_data.output_quality) + save_images(filtered_images, save_folder_path, file_name=make_filename_callback(req, suffix='filtered'), output_format=task_data.output_format, output_quality=task_data.output_quality) + save_dicts(metadata_entries, save_folder_path, file_name=make_filename_callback(req, suffix='filtered'), output_format=task_data.metadata_output_format) def get_metadata_entries(req: GenerateImageRequest, task_data: TaskData): metadata = get_printable_request(req) diff --git a/ui/sd_internal/task_manager.py 
b/ui/sd_internal/task_manager.py
index 094d853c..f9080a74 100644
--- a/ui/sd_internal/task_manager.py
+++ b/ui/sd_internal/task_manager.py
@@ -15,7 +15,7 @@ import queue, threading, time, weakref
 from typing import Any, Hashable

 from sd_internal import TaskData, device_manager
-from diffusionkit.types import GenerateImageRequest
+from sdkit.types import GenerateImageRequest

 log = logging.getLogger()

diff --git a/ui/server.py b/ui/server.py
index 284fed1b..43a11347 100644
--- a/ui/server.py
+++ b/ui/server.py
@@ -15,7 +15,7 @@ from pydantic import BaseModel

 from sd_internal import app, model_manager, task_manager
 from sd_internal import TaskData
-from diffusionkit.types import GenerateImageRequest
+from sdkit.types import GenerateImageRequest

 log = logging.getLogger()

From 47e3884994aeca791c0421c47ab5d02ddc5d450f Mon Sep 17 00:00:00 2001
From: cmdr2
Date: Mon, 19 Dec 2022 19:39:15 +0530
Subject: [PATCH 53/74] Rename the Python package to easydiffusion (from
 sd_internal)

---
 ui/easydiffusion/__init__.py | 0
 ui/{sd_internal => easydiffusion}/app.py | 5 ++--
 .../device_manager.py | 3 +--
 .../model_manager.py | 8 +++----
 ui/{sd_internal => easydiffusion}/renderer.py | 23 +++++++++----------
 .../task_manager.py | 12 +++++-----
 .../__init__.py => easydiffusion/types.py} | 0
 ui/easydiffusion/utils/__init__.py | 8 +++++++
 .../utils}/save_utils.py | 10 ++++----
 ui/server.py | 8 +++----
 10 files changed, 40 insertions(+), 37 deletions(-)
 create mode 100644 ui/easydiffusion/__init__.py
 rename ui/{sd_internal => easydiffusion}/app.py (98%)
 rename ui/{sd_internal => easydiffusion}/device_manager.py (99%)
 rename ui/{sd_internal => easydiffusion}/model_manager.py (98%)
 rename ui/{sd_internal => easydiffusion}/renderer.py (84%)
 rename ui/{sd_internal => easydiffusion}/task_manager.py (98%)
 rename ui/{sd_internal/__init__.py => easydiffusion/types.py} (100%)
 create mode 100644 ui/easydiffusion/utils/__init__.py
 rename ui/{sd_internal => easydiffusion/utils}/save_utils.py (91%)

diff --git a/ui/easydiffusion/__init__.py b/ui/easydiffusion/__init__.py
new file mode 100644
index 00000000..e69de29b

diff --git a/ui/sd_internal/app.py b/ui/easydiffusion/app.py
similarity index 98%
rename from ui/sd_internal/app.py
rename to ui/easydiffusion/app.py
index 471ee074..c0af06a3 100644
--- a/ui/sd_internal/app.py
+++ b/ui/easydiffusion/app.py
@@ -6,7 +6,8 @@ import traceback
 import logging
 from rich.logging import RichHandler

-from sd_internal import task_manager
+from easydiffusion import task_manager
+from easydiffusion.utils import log

 LOG_FORMAT = '%(asctime)s.%(msecs)03d %(levelname)s %(threadName)s %(message)s'
 logging.basicConfig(
@@ -16,8 +17,6 @@ logging.basicConfig(
 handlers=[RichHandler(markup=True, rich_tracebacks=True, show_time=False, show_level=False)]
 )

-log = logging.getLogger()
-
 SD_DIR = os.getcwd()

 SD_UI_DIR = os.getenv('SD_UI_PATH', None)

diff --git a/ui/sd_internal/device_manager.py b/ui/easydiffusion/device_manager.py
similarity index 99%
rename from ui/sd_internal/device_manager.py
rename to ui/easydiffusion/device_manager.py
index 56508fad..0f2ab850 100644
--- a/ui/sd_internal/device_manager.py
+++ b/ui/easydiffusion/device_manager.py
@@ -2,9 +2,8 @@ import os
 import torch
 import traceback
 import re
-import logging

-log = logging.getLogger()
+from easydiffusion.utils import log

 '''
 Set `FORCE_FULL_PRECISION` in the environment variables, or in `config.bat`/`config.sh` to set full precision (i.e. float32).
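An aside on this rename patch: besides moving files, it also centralizes logging. Each module previously created an anonymous root logger with `logging.getLogger()`; after this patch they all import one shared, named logger from `easydiffusion.utils` (defined in the new `ui/easydiffusion/utils/__init__.py`, shown further below). A minimal sketch of that pattern; the module function here is a hypothetical stand-in, not code from the patch:

    import logging

    # created once, in easydiffusion/utils/__init__.py
    log = logging.getLogger('easydiffusion')

    # every other module then does: from easydiffusion.utils import log
    def example_startup_step():  # hypothetical function, for illustration only
        log.info('all easydiffusion modules report through the same named logger')
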
diff --git a/ui/sd_internal/model_manager.py b/ui/easydiffusion/model_manager.py similarity index 98% rename from ui/sd_internal/model_manager.py rename to ui/easydiffusion/model_manager.py index ec72e6fa..eb8aa7fd 100644 --- a/ui/sd_internal/model_manager.py +++ b/ui/easydiffusion/model_manager.py @@ -1,13 +1,13 @@ import os -import logging import picklescan.scanner -from sd_internal import app, TaskData, device_manager +from easydiffusion import app, device_manager +from easydiffusion.types import TaskData +from easydiffusion.utils import log + from sdkit.models import model_loader from sdkit.types import Context -log = logging.getLogger() - KNOWN_MODEL_TYPES = ['stable-diffusion', 'vae', 'hypernetwork', 'gfpgan', 'realesrgan'] MODEL_EXTENSIONS = { 'stable-diffusion': ['.ckpt', '.safetensors'], diff --git a/ui/sd_internal/renderer.py b/ui/easydiffusion/renderer.py similarity index 84% rename from ui/sd_internal/renderer.py rename to ui/easydiffusion/renderer.py index 1d289c99..610e14fa 100644 --- a/ui/sd_internal/renderer.py +++ b/ui/easydiffusion/renderer.py @@ -1,16 +1,15 @@ import queue import time import json -import logging -from sd_internal import device_manager, save_utils -from sd_internal import TaskData, Response, Image as ResponseImage, UserInitiatedStop +from easydiffusion import device_manager +from easydiffusion.types import TaskData, Response, Image as ResponseImage, UserInitiatedStop +from easydiffusion.utils import get_printable_request, save_images_to_disk, log -from sdkit import model_loader, image_generator, image_utils, filters as image_filters +from sdkit import model_loader, image_generator, filters as image_filters +from sdkit.utils import img_to_buffer, img_to_base64_str, latent_samples_to_images from sdkit.types import Context, GenerateImageRequest, FilterImageRequest -log = logging.getLogger() - context = Context() # thread-local ''' runtime data (bound locally to this thread), for e.g. 
device, references to loaded models, optimization flags etc @@ -28,7 +27,7 @@ def init(device): def make_images(req: GenerateImageRequest, task_data: TaskData, data_queue: queue.Queue, task_temp_images: list, step_callback): context.stop_processing = False - log.info(f'request: {save_utils.get_printable_request(req)}') + log.info(f'request: {get_printable_request(req)}') log.info(f'task data: {task_data.dict()}') images = _make_images_internal(req, task_data, data_queue, task_temp_images, step_callback) @@ -45,7 +44,7 @@ def _make_images_internal(req: GenerateImageRequest, task_data: TaskData, data_q filtered_images = apply_filters(task_data, images, user_stopped) if task_data.save_to_disk_path is not None: - save_utils.save_to_disk(images, filtered_images, req, task_data) + save_images_to_disk(images, filtered_images, req, task_data) return filtered_images if task_data.show_only_filtered_image else images + filtered_images @@ -61,7 +60,7 @@ def generate_images(req: GenerateImageRequest, task_data: TaskData, data_queue: images = [] user_stopped = True if context.partial_x_samples is not None: - images = image_utils.latent_samples_to_images(context, context.partial_x_samples) + images = latent_samples_to_images(context, context.partial_x_samples) context.partial_x_samples = None finally: model_loader.gc(context) @@ -89,7 +88,7 @@ def apply_filters(task_data: TaskData, images: list, user_stopped): def construct_response(images: list, task_data: TaskData, base_seed: int): return [ ResponseImage( - data=image_utils.img_to_base64_str(img, task_data.output_format, task_data.output_quality), + data=img_to_base64_str(img, task_data.output_format, task_data.output_quality), seed=base_seed + i ) for i, img in enumerate(images) ] @@ -100,9 +99,9 @@ def make_step_callback(req: GenerateImageRequest, task_data: TaskData, data_queu def update_temp_img(x_samples, task_temp_images: list): partial_images = [] - images = image_utils.latent_samples_to_images(context, x_samples) + images = latent_samples_to_images(context, x_samples) for i, img in enumerate(images): - buf = image_utils.img_to_buffer(img, output_format='JPEG') + buf = img_to_buffer(img, output_format='JPEG') context.temp_images[f"{task_data.request_id}/{i}"] = buf task_temp_images[i] = buf diff --git a/ui/sd_internal/task_manager.py b/ui/easydiffusion/task_manager.py similarity index 98% rename from ui/sd_internal/task_manager.py rename to ui/easydiffusion/task_manager.py index f9080a74..19638715 100644 --- a/ui/sd_internal/task_manager.py +++ b/ui/easydiffusion/task_manager.py @@ -6,7 +6,6 @@ Notes: """ import json import traceback -import logging TASK_TTL = 15 * 60 # seconds, Discard last session's task timeout @@ -14,10 +13,11 @@ import torch import queue, threading, time, weakref from typing import Any, Hashable -from sd_internal import TaskData, device_manager -from sdkit.types import GenerateImageRequest +from easydiffusion import device_manager +from easydiffusion.types import TaskData +from easydiffusion.utils import log -log = logging.getLogger() +from sdkit.types import GenerateImageRequest THREAD_NAME_PREFIX = '' ERR_LOCK_FAILED = ' failed to acquire lock within timeout.' 
@@ -186,7 +186,7 @@ class SessionState(): return True def thread_get_next_task(): - from sd_internal import renderer + from easydiffusion import renderer if not manager_lock.acquire(blocking=True, timeout=LOCK_TIMEOUT): log.warn(f'Render thread on device: {renderer.context.device} failed to acquire manager lock.') return None @@ -219,7 +219,7 @@ def thread_get_next_task(): def thread_render(device): global current_state, current_state_error - from sd_internal import renderer, model_manager + from easydiffusion import renderer, model_manager try: renderer.init(device) diff --git a/ui/sd_internal/__init__.py b/ui/easydiffusion/types.py similarity index 100% rename from ui/sd_internal/__init__.py rename to ui/easydiffusion/types.py diff --git a/ui/easydiffusion/utils/__init__.py b/ui/easydiffusion/utils/__init__.py new file mode 100644 index 00000000..8be070b4 --- /dev/null +++ b/ui/easydiffusion/utils/__init__.py @@ -0,0 +1,8 @@ +import logging + +log = logging.getLogger('easydiffusion') + +from .save_utils import ( + save_images_to_disk, + get_printable_request, +) \ No newline at end of file diff --git a/ui/sd_internal/save_utils.py b/ui/easydiffusion/utils/save_utils.py similarity index 91% rename from ui/sd_internal/save_utils.py rename to ui/easydiffusion/utils/save_utils.py index db9d0671..bb1d09c9 100644 --- a/ui/sd_internal/save_utils.py +++ b/ui/easydiffusion/utils/save_utils.py @@ -3,11 +3,11 @@ import time import base64 import re +from easydiffusion.types import TaskData + from sdkit.utils import save_images, save_dicts from sdkit.types import GenerateImageRequest -from sd_internal import TaskData - filename_regex = re.compile('[^a-zA-Z0-9]') # keep in sync with `ui/media/js/dnd.js` @@ -28,9 +28,9 @@ TASK_TEXT_MAPPING = { 'hypernetwork_strength': 'Hypernetwork Strength' } -def save_to_disk(images: list, filtered_images: list, req: GenerateImageRequest, task_data: TaskData): +def save_images_to_disk(images: list, filtered_images: list, req: GenerateImageRequest, task_data: TaskData): save_folder_path = os.path.join(task_data.save_to_disk_path, filename_regex.sub('_', task_data.session_id)) - metadata_entries = get_metadata_entries(req, task_data) + metadata_entries = get_metadata_entries_for_request(req, task_data) if task_data.show_only_filtered_image or filtered_images == images: save_images(filtered_images, save_folder_path, file_name=make_filename_callback(req), output_format=task_data.output_format, output_quality=task_data.output_quality) @@ -40,7 +40,7 @@ def save_to_disk(images: list, filtered_images: list, req: GenerateImageRequest, save_images(filtered_images, save_folder_path, file_name=make_filename_callback(req, suffix='filtered'), output_format=task_data.output_format, output_quality=task_data.output_quality) save_dicts(metadata_entries, save_folder_path, file_name=make_filename_callback(req, suffix='filtered'), output_format=task_data.metadata_output_format) -def get_metadata_entries(req: GenerateImageRequest, task_data: TaskData): +def get_metadata_entries_for_request(req: GenerateImageRequest, task_data: TaskData): metadata = get_printable_request(req) metadata.update({ 'use_stable_diffusion_model': task_data.use_stable_diffusion_model, diff --git a/ui/server.py b/ui/server.py index 43a11347..8ee34fec 100644 --- a/ui/server.py +++ b/ui/server.py @@ -4,7 +4,6 @@ Notes: """ import os import traceback -import logging import datetime from typing import List, Union @@ -13,12 +12,11 @@ from fastapi.staticfiles import StaticFiles from starlette.responses import 
FileResponse, JSONResponse, StreamingResponse from pydantic import BaseModel -from sd_internal import app, model_manager, task_manager -from sd_internal import TaskData +from easydiffusion import app, model_manager, task_manager +from easydiffusion.types import TaskData +from easydiffusion.utils import log from sdkit.types import GenerateImageRequest -log = logging.getLogger() - log.info(f'started in {app.SD_DIR}') log.info(f'started at {datetime.datetime.now():%x %X}') From 5eeef41d8cabcdc485e8a1095198c07f7ac92632 Mon Sep 17 00:00:00 2001 From: cmdr2 Date: Tue, 20 Dec 2022 15:16:47 +0530 Subject: [PATCH 54/74] Update to use the latest sdkit API --- ui/easydiffusion/model_manager.py | 18 ++++++------ ui/easydiffusion/renderer.py | 43 ++++++++++++---------------- ui/easydiffusion/task_manager.py | 4 +-- ui/easydiffusion/types.py | 21 +++++++++++++- ui/easydiffusion/utils/save_utils.py | 3 +- ui/server.py | 3 +- 6 files changed, 50 insertions(+), 42 deletions(-) diff --git a/ui/easydiffusion/model_manager.py b/ui/easydiffusion/model_manager.py index eb8aa7fd..4c9fb1d1 100644 --- a/ui/easydiffusion/model_manager.py +++ b/ui/easydiffusion/model_manager.py @@ -5,14 +5,14 @@ from easydiffusion import app, device_manager from easydiffusion.types import TaskData from easydiffusion.utils import log -from sdkit.models import model_loader -from sdkit.types import Context +from sdkit import Context +from sdkit.models import load_model, unload_model KNOWN_MODEL_TYPES = ['stable-diffusion', 'vae', 'hypernetwork', 'gfpgan', 'realesrgan'] MODEL_EXTENSIONS = { 'stable-diffusion': ['.ckpt', '.safetensors'], - 'vae': ['.vae.pt', '.ckpt'], - 'hypernetwork': ['.pt'], + 'vae': ['.vae.pt', '.ckpt', '.safetensors'], + 'hypernetwork': ['.pt', '.safetensors'], 'gfpgan': ['.pth'], 'realesrgan': ['.pth'], } @@ -44,13 +44,13 @@ def load_default_models(context: Context): set_vram_optimizations(context) # load mandatory models - model_loader.load_model(context, 'stable-diffusion') - model_loader.load_model(context, 'vae') - model_loader.load_model(context, 'hypernetwork') + load_model(context, 'stable-diffusion') + load_model(context, 'vae') + load_model(context, 'hypernetwork') def unload_all(context: Context): for model_type in KNOWN_MODEL_TYPES: - model_loader.unload_model(context, model_type) + unload_model(context, model_type) def resolve_model_to_use(model_name:str=None, model_type:str=None): model_extensions = MODEL_EXTENSIONS.get(model_type, []) @@ -107,7 +107,7 @@ def reload_models_if_necessary(context: Context, task_data: TaskData): for model_type, model_path_in_req in models_to_reload.items(): context.model_paths[model_type] = model_path_in_req - action_fn = model_loader.unload_model if context.model_paths[model_type] is None else model_loader.load_model + action_fn = unload_model if context.model_paths[model_type] is None else load_model action_fn(context, model_type) def resolve_model_paths(task_data: TaskData): diff --git a/ui/easydiffusion/renderer.py b/ui/easydiffusion/renderer.py index 610e14fa..4b9847c0 100644 --- a/ui/easydiffusion/renderer.py +++ b/ui/easydiffusion/renderer.py @@ -3,12 +3,13 @@ import time import json from easydiffusion import device_manager -from easydiffusion.types import TaskData, Response, Image as ResponseImage, UserInitiatedStop +from easydiffusion.types import TaskData, Response, Image as ResponseImage, UserInitiatedStop, GenerateImageRequest from easydiffusion.utils import get_printable_request, save_images_to_disk, log -from sdkit import model_loader, image_generator, 
From 5eeef41d8cabcdc485e8a1095198c07f7ac92632 Mon Sep 17 00:00:00 2001
From: cmdr2
Date: Tue, 20 Dec 2022 15:16:47 +0530
Subject: [PATCH 54/74] Update to use the latest sdkit API

---
 ui/easydiffusion/model_manager.py    | 18 ++++++------
 ui/easydiffusion/renderer.py         | 43 ++++++++++++----------------
 ui/easydiffusion/task_manager.py     |  4 +--
 ui/easydiffusion/types.py            | 21 +++++++++++++-
 ui/easydiffusion/utils/save_utils.py |  3 +-
 ui/server.py                         |  3 +-
 6 files changed, 50 insertions(+), 42 deletions(-)

diff --git a/ui/easydiffusion/model_manager.py b/ui/easydiffusion/model_manager.py
index eb8aa7fd..4c9fb1d1 100644
--- a/ui/easydiffusion/model_manager.py
+++ b/ui/easydiffusion/model_manager.py
@@ -5,14 +5,14 @@ from easydiffusion import app, device_manager
 from easydiffusion.types import TaskData
 from easydiffusion.utils import log
 
-from sdkit.models import model_loader
-from sdkit.types import Context
+from sdkit import Context
+from sdkit.models import load_model, unload_model
 
 KNOWN_MODEL_TYPES = ['stable-diffusion', 'vae', 'hypernetwork', 'gfpgan', 'realesrgan']
 MODEL_EXTENSIONS = {
     'stable-diffusion': ['.ckpt', '.safetensors'],
-    'vae': ['.vae.pt', '.ckpt'],
-    'hypernetwork': ['.pt'],
+    'vae': ['.vae.pt', '.ckpt', '.safetensors'],
+    'hypernetwork': ['.pt', '.safetensors'],
     'gfpgan': ['.pth'],
     'realesrgan': ['.pth'],
 }
@@ -44,13 +44,13 @@ def load_default_models(context: Context):
     set_vram_optimizations(context)
 
     # load mandatory models
-    model_loader.load_model(context, 'stable-diffusion')
-    model_loader.load_model(context, 'vae')
-    model_loader.load_model(context, 'hypernetwork')
+    load_model(context, 'stable-diffusion')
+    load_model(context, 'vae')
+    load_model(context, 'hypernetwork')
 
 def unload_all(context: Context):
     for model_type in KNOWN_MODEL_TYPES:
-        model_loader.unload_model(context, model_type)
+        unload_model(context, model_type)
 
 def resolve_model_to_use(model_name:str=None, model_type:str=None):
     model_extensions = MODEL_EXTENSIONS.get(model_type, [])
@@ -107,7 +107,7 @@ def reload_models_if_necessary(context: Context, task_data: TaskData):
 
     for model_type, model_path_in_req in models_to_reload.items():
         context.model_paths[model_type] = model_path_in_req
-        action_fn = model_loader.unload_model if context.model_paths[model_type] is None else model_loader.load_model
+        action_fn = unload_model if context.model_paths[model_type] is None else load_model
         action_fn(context, model_type)
 
 def resolve_model_paths(task_data: TaskData):
diff --git a/ui/easydiffusion/renderer.py b/ui/easydiffusion/renderer.py
index 610e14fa..4b9847c0 100644
--- a/ui/easydiffusion/renderer.py
+++ b/ui/easydiffusion/renderer.py
@@ -3,12 +3,13 @@ import time
 import json
 
 from easydiffusion import device_manager
-from easydiffusion.types import TaskData, Response, Image as ResponseImage, UserInitiatedStop
+from easydiffusion.types import TaskData, Response, Image as ResponseImage, UserInitiatedStop, GenerateImageRequest
 from easydiffusion.utils import get_printable_request, save_images_to_disk, log
 
-from sdkit import model_loader, image_generator, filters as image_filters
-from sdkit.utils import img_to_buffer, img_to_base64_str, latent_samples_to_images
-from sdkit.types import Context, GenerateImageRequest, FilterImageRequest
+from sdkit import Context
+from sdkit.generate import generate_images
+from sdkit.filter import apply_filters
+from sdkit.utils import img_to_buffer, img_to_base64_str, latent_samples_to_images, gc
 
 context = Context() # thread-local
 '''
@@ -30,7 +31,7 @@ def make_images(req: GenerateImageRequest, task_data: TaskData, data_queue: queu
     log.info(f'request: {get_printable_request(req)}')
     log.info(f'task data: {task_data.dict()}')
 
-    images = _make_images_internal(req, task_data, data_queue, task_temp_images, step_callback)
+    images = make_images_internal(req, task_data, data_queue, task_temp_images, step_callback)
 
     res = Response(req, task_data, images=construct_response(images, task_data, base_seed=req.seed))
     res = res.json()
@@ -39,22 +40,22 @@ def make_images(req: GenerateImageRequest, task_data: TaskData, data_queue: queu
 
     return res
 
-def _make_images_internal(req: GenerateImageRequest, task_data: TaskData, data_queue: queue.Queue, task_temp_images: list, step_callback):
-    images, user_stopped = generate_images(req, task_data, data_queue, task_temp_images, step_callback, task_data.stream_image_progress)
-    filtered_images = apply_filters(task_data, images, user_stopped)
+def make_images_internal(req: GenerateImageRequest, task_data: TaskData, data_queue: queue.Queue, task_temp_images: list, step_callback):
+    images, user_stopped = generate_images_internal(req, task_data, data_queue, task_temp_images, step_callback, task_data.stream_image_progress)
+    filtered_images = filter_images(task_data, images, user_stopped)
 
     if task_data.save_to_disk_path is not None:
         save_images_to_disk(images, filtered_images, req, task_data)
 
     return filtered_images if task_data.show_only_filtered_image else images + filtered_images
 
-def generate_images(req: GenerateImageRequest, task_data: TaskData, data_queue: queue.Queue, task_temp_images: list, step_callback, stream_image_progress: bool):
+def generate_images_internal(req: GenerateImageRequest, task_data: TaskData, data_queue: queue.Queue, task_temp_images: list, step_callback, stream_image_progress: bool):
     context.temp_images.clear()
 
-    image_generator.on_image_step = make_step_callback(req, task_data, data_queue, task_temp_images, step_callback, stream_image_progress)
+    callback = make_step_callback(req, task_data, data_queue, task_temp_images, step_callback, stream_image_progress)
 
     try:
-        images = image_generator.make_images(context=context, req=req)
+        images = generate_images(context, callback=callback, **req.dict())
         user_stopped = False
     except UserInitiatedStop:
         images = []
@@ -63,27 +64,19 @@ def generate_images(req: GenerateImageRequest, task_data: TaskData, data_queue:
             images = latent_samples_to_images(context, context.partial_x_samples)
             context.partial_x_samples = None
     finally:
-        model_loader.gc(context)
+        gc(context)
 
     return images, user_stopped
 
-def apply_filters(task_data: TaskData, images: list, user_stopped):
+def filter_images(task_data: TaskData, images: list, user_stopped):
     if user_stopped or (task_data.use_face_correction is None and task_data.use_upscale is None):
         return images
 
-    filters = []
-    if 'gfpgan' in task_data.use_face_correction.lower(): filters.append('gfpgan')
-    if 'realesrgan' in task_data.use_face_correction.lower(): filters.append('realesrgan')
+    filters_to_apply = []
+    if task_data.use_face_correction and 'gfpgan' in task_data.use_face_correction.lower(): filters_to_apply.append('gfpgan')
+    if task_data.use_upscale and 'realesrgan' in task_data.use_upscale.lower(): filters_to_apply.append('realesrgan')
 
-    filtered_images = []
-    for img in images:
-        filter_req = FilterImageRequest()
-        filter_req.init_image = img
-
-        filtered_image = image_filters.apply(context, filters, filter_req)
-        filtered_images.append(filtered_image)
-
-    return filtered_images
+    return apply_filters(context, filters_to_apply, images)
 
 def construct_response(images: list, task_data: TaskData, base_seed: int):
     return [
diff --git a/ui/easydiffusion/task_manager.py b/ui/easydiffusion/task_manager.py
index 19638715..3a764137 100644
--- a/ui/easydiffusion/task_manager.py
+++ b/ui/easydiffusion/task_manager.py
@@ -14,11 +14,9 @@ import queue, threading, time, weakref
 from typing import Any, Hashable
 
 from easydiffusion import device_manager
-from easydiffusion.types import TaskData
+from easydiffusion.types import TaskData, GenerateImageRequest
 from easydiffusion.utils import log
 
-from sdkit.types import GenerateImageRequest
-
 THREAD_NAME_PREFIX = ''
 ERR_LOCK_FAILED = ' failed to acquire lock within timeout.'
 LOCK_TIMEOUT = 15 # Maximum locking time in seconds before failing a task.
diff --git a/ui/easydiffusion/types.py b/ui/easydiffusion/types.py
index 3a748431..2a10d521 100644
--- a/ui/easydiffusion/types.py
+++ b/ui/easydiffusion/types.py
@@ -1,6 +1,25 @@
 from pydantic import BaseModel
+from typing import Any
 
-from sdkit.types import GenerateImageRequest
+class GenerateImageRequest(BaseModel):
+    prompt: str = ""
+    negative_prompt: str = ""
+
+    seed: int = 42
+    width: int = 512
+    height: int = 512
+
+    num_outputs: int = 1
+    num_inference_steps: int = 50
+    guidance_scale: float = 7.5
+
+    init_image: Any = None
+    init_image_mask: Any = None
+    prompt_strength: float = 0.8
+    preserve_init_image_color_profile: bool = False
+
+    sampler_name: str = None # "ddim", "plms", "heun", "euler", "euler_a", "dpm2", "dpm2_a", "lms"
+    hypernetwork_strength: float = 0
 
 class TaskData(BaseModel):
     request_id: str = None
diff --git a/ui/easydiffusion/utils/save_utils.py b/ui/easydiffusion/utils/save_utils.py
index bb1d09c9..d7fa82a3 100644
--- a/ui/easydiffusion/utils/save_utils.py
+++ b/ui/easydiffusion/utils/save_utils.py
@@ -3,10 +3,9 @@ import time
 import base64
 import re
 
-from easydiffusion.types import TaskData
+from easydiffusion.types import TaskData, GenerateImageRequest
 
 from sdkit.utils import save_images, save_dicts
-from sdkit.types import GenerateImageRequest
 
 filename_regex = re.compile('[^a-zA-Z0-9]')
diff --git a/ui/server.py b/ui/server.py
index 8ee34fec..8d9a129e 100644
--- a/ui/server.py
+++ b/ui/server.py
@@ -13,9 +13,8 @@ from starlette.responses import FileResponse, JSONResponse, StreamingResponse
 from pydantic import BaseModel
 
 from easydiffusion import app, model_manager, task_manager
-from easydiffusion.types import TaskData
+from easydiffusion.types import TaskData, GenerateImageRequest
 from easydiffusion.utils import log
-from sdkit.types import GenerateImageRequest
 
 log.info(f'started in {app.SD_DIR}')
 log.info(f'started at {datetime.datetime.now():%x %X}')
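Patch 54 above replaces sdkit's `model_loader`/`image_generator`/`filters` modules with the flatter `Context`, `load_model`, `generate_images`, and `apply_filters` API. A minimal sketch of the new call sequence, pieced together from the hunks above; the checkpoint path, prompt, and filter choice are illustrative:

    # Sketch of the new sdkit API surface, as used by renderer.py and model_manager.py.
    from sdkit import Context
    from sdkit.models import load_model
    from sdkit.generate import generate_images
    from sdkit.filter import apply_filters
    from sdkit.utils import gc

    context = Context()  # one per render thread, as in renderer.py
    context.model_paths['stable-diffusion'] = '/path/to/sd-v1-4.ckpt'  # illustrative path
    load_model(context, 'stable-diffusion')

    images = generate_images(context, prompt='a watercolor of a lighthouse', seed=42)
    images = apply_filters(context, ['realesrgan'], images)  # optional post-processing
    gc(context)  # free VRAM between tasks, as renderer.py does in its finally block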
From c804a9971ea5d8d00747ccb3d84ad712ad642ce3 Mon Sep 17 00:00:00 2001
From: cmdr2
Date: Thu, 22 Dec 2022 11:54:00 +0530
Subject: [PATCH 55/74] Work-in-progress code for adding a model config
 dropdown in the UI. Doesn't work yet

---
 ui/easydiffusion/model_manager.py    |  7 ++++++-
 ui/easydiffusion/types.py            |  1 +
 ui/easydiffusion/utils/save_utils.py | 12 ++++++------
 ui/index.html                        |  4 ++++
 4 files changed, 17 insertions(+), 7 deletions(-)

diff --git a/ui/easydiffusion/model_manager.py b/ui/easydiffusion/model_manager.py
index 4c9fb1d1..7bc2b3af 100644
--- a/ui/easydiffusion/model_manager.py
+++ b/ui/easydiffusion/model_manager.py
@@ -6,7 +6,8 @@ from easydiffusion.types import TaskData
 from easydiffusion.utils import log
 
 from sdkit import Context
-from sdkit.models import load_model, unload_model
+from sdkit.models import load_model, unload_model, get_known_model_info
+from sdkit.utils import hash_file_quick
 
 KNOWN_MODEL_TYPES = ['stable-diffusion', 'vae', 'hypernetwork', 'gfpgan', 'realesrgan']
 MODEL_EXTENSIONS = {
@@ -104,6 +105,10 @@ def reload_models_if_necessary(context: Context, task_data: TaskData):
     if set_vram_optimizations(context): # reload SD
         models_to_reload['stable-diffusion'] = model_paths_in_req['stable-diffusion']
 
+    if 'stable-diffusion' in models_to_reload:
+        quick_hash = hash_file_quick(models_to_reload['stable-diffusion'])
+        known_model_info = get_known_model_info(quick_hash=quick_hash)
+
     for model_type, model_path_in_req in models_to_reload.items():
         context.model_paths[model_type] = model_path_in_req
diff --git a/ui/easydiffusion/types.py b/ui/easydiffusion/types.py
index 2a10d521..805c8683 100644
--- a/ui/easydiffusion/types.py
+++ b/ui/easydiffusion/types.py
@@ -30,6 +30,7 @@ class TaskData(BaseModel):
     use_face_correction: str = None # or "GFPGANv1.3"
     use_upscale: str = None # or "RealESRGAN_x4plus" or "RealESRGAN_x4plus_anime_6B"
     use_stable_diffusion_model: str = "sd-v1-4"
+    use_stable_diffusion_config: str = "v1-inference"
     use_vae_model: str = None
     use_hypernetwork_model: str = None
 
diff --git a/ui/easydiffusion/utils/save_utils.py b/ui/easydiffusion/utils/save_utils.py
index d7fa82a3..b9ce8aba 100644
--- a/ui/easydiffusion/utils/save_utils.py
+++ b/ui/easydiffusion/utils/save_utils.py
@@ -28,16 +28,16 @@ TASK_TEXT_MAPPING = {
 }
 
 def save_images_to_disk(images: list, filtered_images: list, req: GenerateImageRequest, task_data: TaskData):
-    save_folder_path = os.path.join(task_data.save_to_disk_path, filename_regex.sub('_', task_data.session_id))
+    save_dir_path = os.path.join(task_data.save_to_disk_path, filename_regex.sub('_', task_data.session_id))
     metadata_entries = get_metadata_entries_for_request(req, task_data)
 
     if task_data.show_only_filtered_image or filtered_images == images:
-        save_images(filtered_images, save_folder_path, file_name=make_filename_callback(req), output_format=task_data.output_format, output_quality=task_data.output_quality)
-        save_dicts(metadata_entries, save_folder_path, file_name=make_filename_callback(req), output_format=task_data.metadata_output_format)
+        save_images(filtered_images, save_dir_path, file_name=make_filename_callback(req), output_format=task_data.output_format, output_quality=task_data.output_quality)
+        save_dicts(metadata_entries, save_dir_path, file_name=make_filename_callback(req), output_format=task_data.metadata_output_format)
     else:
-        save_images(images, save_folder_path, file_name=make_filename_callback(req), output_format=task_data.output_format, output_quality=task_data.output_quality)
-        save_images(filtered_images, save_folder_path, file_name=make_filename_callback(req, suffix='filtered'), output_format=task_data.output_format, output_quality=task_data.output_quality)
-        save_dicts(metadata_entries, save_folder_path, file_name=make_filename_callback(req, suffix='filtered'), output_format=task_data.metadata_output_format)
+        save_images(images, save_dir_path, file_name=make_filename_callback(req), output_format=task_data.output_format, output_quality=task_data.output_quality)
+        save_images(filtered_images, save_dir_path, file_name=make_filename_callback(req, suffix='filtered'), output_format=task_data.output_format, output_quality=task_data.output_quality)
+        save_dicts(metadata_entries, save_dir_path, file_name=make_filename_callback(req, suffix='filtered'), output_format=task_data.metadata_output_format)
 
 def get_metadata_entries_for_request(req: GenerateImageRequest, task_data: TaskData):
     metadata = get_printable_request(req)
diff --git a/ui/index.html b/ui/index.html
index 9f36033f..5d12d419 100644
--- a/ui/index.html
+++ b/ui/index.html
@@ -125,6 +125,10 @@
 [hunk body lost in extraction: the HTML markup was stripped; four lines are added after the "Click to learn more about custom models" link, inserting the new model-config dropdown row]
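The `reload_models_if_necessary()` change in patch 55 computes a quick hash of the incoming checkpoint and looks it up in sdkit's known-models table, presumably so that a matching config (such as the new `use_stable_diffusion_config` default of "v1-inference") can be picked automatically. A rough sketch of that lookup; note that the patch computes the result but does not consume it yet ("Doesn't work yet"), and the return shape of `get_known_model_info` is an assumption:

    # Rough sketch of the model-config lookup wired up (but not yet used) above.
    from sdkit.models import get_known_model_info
    from sdkit.utils import hash_file_quick

    model_path = '/path/to/custom-model.ckpt'  # illustrative path
    quick_hash = hash_file_quick(model_path)
    known_model_info = get_known_model_info(quick_hash=quick_hash)  # assumed: None if unrecognized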