diff --git a/ui/sd_internal/__init__.py b/ui/sd_internal/__init__.py index a2abe294..9f71f249 100644 --- a/ui/sd_internal/__init__.py +++ b/ui/sd_internal/__init__.py @@ -40,7 +40,6 @@ class Request: "num_outputs": self.num_outputs, "num_inference_steps": self.num_inference_steps, "guidance_scale": self.guidance_scale, - "hypernetwork_strengtgh": self.guidance_scale, "width": self.width, "height": self.height, "seed": self.seed, diff --git a/ui/sd_internal/app.py b/ui/sd_internal/app.py new file mode 100644 index 00000000..09242319 --- /dev/null +++ b/ui/sd_internal/app.py @@ -0,0 +1,156 @@ +import os +import socket +import sys +import json +import traceback + +from sd_internal import task_manager + +SD_DIR = os.getcwd() + +SD_UI_DIR = os.getenv('SD_UI_PATH', None) +sys.path.append(os.path.dirname(SD_UI_DIR)) + +CONFIG_DIR = os.path.abspath(os.path.join(SD_UI_DIR, '..', 'scripts')) +MODELS_DIR = os.path.abspath(os.path.join(SD_DIR, '..', 'models')) + +USER_UI_PLUGINS_DIR = os.path.abspath(os.path.join(SD_DIR, '..', 'plugins', 'ui')) +CORE_UI_PLUGINS_DIR = os.path.abspath(os.path.join(SD_UI_DIR, 'plugins', 'ui')) +UI_PLUGINS_SOURCES = ((CORE_UI_PLUGINS_DIR, 'core'), (USER_UI_PLUGINS_DIR, 'user')) + +STABLE_DIFFUSION_MODEL_EXTENSIONS = ['.ckpt', '.safetensors'] +VAE_MODEL_EXTENSIONS = ['.vae.pt', '.ckpt'] +HYPERNETWORK_MODEL_EXTENSIONS = ['.pt'] + +OUTPUT_DIRNAME = "Stable Diffusion UI" # in the user's home folder +TASK_TTL = 15 * 60 # Discard last session's task timeout +APP_CONFIG_DEFAULTS = { + # auto: selects the cuda device with the most free memory, cuda: use the currently active cuda device. + 'render_devices': 'auto', # valid entries: 'auto', 'cpu' or 'cuda:N' (where N is a GPU index) + 'update_branch': 'main', + 'ui': { + 'open_browser_on_start': True, + }, +} +DEFAULT_MODELS = [ + # needed to support the legacy installations + 'custom-model', # Check if user has a custom model, use it first. + 'sd-v1-4', # Default fallback. 
+] + +def init(): + os.makedirs(USER_UI_PLUGINS_DIR, exist_ok=True) + + update_render_threads() + +def getConfig(default_val=APP_CONFIG_DEFAULTS): + try: + config_json_path = os.path.join(CONFIG_DIR, 'config.json') + if not os.path.exists(config_json_path): + return default_val + with open(config_json_path, 'r', encoding='utf-8') as f: + config = json.load(f) + if 'net' not in config: + config['net'] = {} + if os.getenv('SD_UI_BIND_PORT') is not None: + config['net']['listen_port'] = int(os.getenv('SD_UI_BIND_PORT')) + if os.getenv('SD_UI_BIND_IP') is not None: + config['net']['listen_to_network'] = (os.getenv('SD_UI_BIND_IP') == '0.0.0.0') + return config + except Exception as e: + print(str(e)) + print(traceback.format_exc()) + return default_val + +def setConfig(config): + try: # config.json + config_json_path = os.path.join(CONFIG_DIR, 'config.json') + with open(config_json_path, 'w', encoding='utf-8') as f: + json.dump(config, f) + except: + print(traceback.format_exc()) + + try: # config.bat + config_bat_path = os.path.join(CONFIG_DIR, 'config.bat') + config_bat = [] + + if 'update_branch' in config: + config_bat.append(f"@set update_branch={config['update_branch']}") + + config_bat.append(f"@set SD_UI_BIND_PORT={config['net']['listen_port']}") + bind_ip = '0.0.0.0' if config['net']['listen_to_network'] else '127.0.0.1' + config_bat.append(f"@set SD_UI_BIND_IP={bind_ip}") + + config_bat.append(f"@set test_sd2={'Y' if config.get('test_sd2', False) else 'N'}") + + if len(config_bat) > 0: + with open(config_bat_path, 'w', encoding='utf-8') as f: + f.write('\r\n'.join(config_bat)) + except: + print(traceback.format_exc()) + + try: # config.sh + config_sh_path = os.path.join(CONFIG_DIR, 'config.sh') + config_sh = ['#!/bin/bash'] + + if 'update_branch' in config: + config_sh.append(f"export update_branch={config['update_branch']}") + + config_sh.append(f"export SD_UI_BIND_PORT={config['net']['listen_port']}") + bind_ip = '0.0.0.0' if 
config['net']['listen_to_network'] else '127.0.0.1' + config_sh.append(f"export SD_UI_BIND_IP={bind_ip}") + + config_sh.append(f"export test_sd2=\"{'Y' if config.get('test_sd2', False) else 'N'}\"") + + if len(config_sh) > 1: + with open(config_sh_path, 'w', encoding='utf-8') as f: + f.write('\n'.join(config_sh)) + except: + print(traceback.format_exc()) + +def save_model_to_config(ckpt_model_name, vae_model_name, hypernetwork_model_name): + config = getConfig() + if 'model' not in config: + config['model'] = {} + + config['model']['stable-diffusion'] = ckpt_model_name + config['model']['vae'] = vae_model_name + config['model']['hypernetwork'] = hypernetwork_model_name + + if vae_model_name is None or vae_model_name == "": + del config['model']['vae'] + if hypernetwork_model_name is None or hypernetwork_model_name == "": + del config['model']['hypernetwork'] + + setConfig(config) + +def update_render_threads(): + config = getConfig() + render_devices = config.get('render_devices', 'auto') + active_devices = task_manager.get_devices()['active'].keys() + + print('requesting for render_devices', render_devices) + task_manager.update_render_threads(render_devices, active_devices) + +def getUIPlugins(): + plugins = [] + + for plugins_dir, dir_prefix in UI_PLUGINS_SOURCES: + for file in os.listdir(plugins_dir): + if file.endswith('.plugin.js'): + plugins.append(f'/plugins/{dir_prefix}/{file}') + + return plugins + +def getIPConfig(): + ips = socket.gethostbyname_ex(socket.gethostname()) + ips[2].append(ips[0]) + return ips[2] + +def open_browser(): + config = getConfig() + ui = config.get('ui', {}) + net = config.get('net', {'listen_port':9000}) + port = net.get('listen_port', 9000) + if ui.get('open_browser_on_start', True): + import webbrowser; webbrowser.open(f"http://localhost:{port}") diff --git a/ui/sd_internal/device_manager.py b/ui/sd_internal/device_manager.py index d2c6430b..3c1d8bf1 100644 --- a/ui/sd_internal/device_manager.py +++ 
b/ui/sd_internal/device_manager.py @@ -82,7 +82,7 @@ def auto_pick_devices(currently_active_devices): devices = list(map(lambda x: x['device'], devices)) return devices -def device_init(thread_data, device): +def device_init(context, device): ''' This function assumes the 'device' has already been verified to be compatible. `get_device_delta()` has already filtered out incompatible devices. @@ -91,21 +91,22 @@ def device_init(thread_data, device): validate_device_id(device, log_prefix='device_init') if device == 'cpu': - thread_data.device = 'cpu' - thread_data.device_name = get_processor_name() - print('Render device CPU available as', thread_data.device_name) + context.device = 'cpu' + context.device_name = get_processor_name() + context.precision = 'full' + print('Render device CPU available as', context.device_name) return - thread_data.device_name = torch.cuda.get_device_name(device) - thread_data.device = device + context.device_name = torch.cuda.get_device_name(device) + context.device = device # Force full precision on 1660 and 1650 NVIDIA cards to avoid creating green images - device_name = thread_data.device_name.lower() - thread_data.force_full_precision = (('nvidia' in device_name or 'geforce' in device_name) and (' 1660' in device_name or ' 1650' in device_name)) or ('Quadro T2000' in device_name) - if thread_data.force_full_precision: - print('forcing full precision on NVIDIA 16xx cards, to avoid green images. GPU detected: ', thread_data.device_name) + device_name = context.device_name.lower() + force_full_precision = (('nvidia' in device_name or 'geforce' in device_name) and (' 1660' in device_name or ' 1650' in device_name)) or ('Quadro T2000' in device_name) + if force_full_precision: + print('forcing full precision on NVIDIA 16xx cards, to avoid green images. GPU detected: ', context.device_name) # Apply force_full_precision now before models are loaded. 
- thread_data.precision = 'full' + context.precision = 'full' print(f'Setting {device} as active') torch.cuda.device(device) diff --git a/ui/sd_internal/model_manager.py b/ui/sd_internal/model_manager.py new file mode 100644 index 00000000..3941b130 --- /dev/null +++ b/ui/sd_internal/model_manager.py @@ -0,0 +1,141 @@ +import os + +from sd_internal import app +import picklescan.scanner +import rich + +default_model_to_load = None +default_vae_to_load = None +default_hypernetwork_to_load = None + +known_models = {} + +def init(): + global default_model_to_load, default_vae_to_load, default_hypernetwork_to_load + + default_model_to_load = resolve_ckpt_to_use() + default_vae_to_load = resolve_vae_to_use() + default_hypernetwork_to_load = resolve_hypernetwork_to_use() + + getModels() # run this once, to cache the picklescan results + +def resolve_model_to_use(model_name:str, model_type:str, model_dir:str, model_extensions:list, default_models=[]): + config = app.getConfig() + + model_dirs = [os.path.join(app.MODELS_DIR, model_dir), app.SD_DIR] + if not model_name: # When None try user configured model. + # config = getConfig() + if 'model' in config and model_type in config['model']: + model_name = config['model'][model_type] + + if model_name: + is_sd2 = config.get('test_sd2', False) + if model_name.startswith('sd2_') and not is_sd2: # temp hack, until SD2 is unified with 1.4 + print('ERROR: Cannot use SD 2.0 models with SD 1.0 code. 
Using the sd-v1-4 model instead!') + model_name = 'sd-v1-4' + + # Check models directory + models_dir_path = os.path.join(app.MODELS_DIR, model_dir, model_name) + for model_extension in model_extensions: + if os.path.exists(models_dir_path + model_extension): + return models_dir_path + model_extension + if os.path.exists(model_name + model_extension): + return os.path.abspath(model_name + model_extension) + + # Default locations + if model_name in default_models: + default_model_path = os.path.join(app.SD_DIR, model_name) + for model_extension in model_extensions: + if os.path.exists(default_model_path + model_extension): + return default_model_path + model_extension + + # Can't find requested model, check the default paths. + for default_model in default_models: + for model_dir in model_dirs: + default_model_path = os.path.join(model_dir, default_model) + for model_extension in model_extensions: + if os.path.exists(default_model_path + model_extension): + if model_name is not None: + print(f'Could not find the configured custom model {model_name}{model_extension}. 
Using the default one: {default_model_path}{model_extension}') + return default_model_path + model_extension + + raise Exception('No valid models found.') + +def resolve_ckpt_to_use(model_name:str=None): + return resolve_model_to_use(model_name, model_type='stable-diffusion', model_dir='stable-diffusion', model_extensions=app.STABLE_DIFFUSION_MODEL_EXTENSIONS, default_models=app.DEFAULT_MODELS) + +def resolve_vae_to_use(model_name:str=None): + try: + return resolve_model_to_use(model_name, model_type='vae', model_dir='vae', model_extensions=app.VAE_MODEL_EXTENSIONS, default_models=[]) + except: + return None + +def resolve_hypernetwork_to_use(model_name:str=None): + try: + return resolve_model_to_use(model_name, model_type='hypernetwork', model_dir='hypernetwork', model_extensions=app.HYPERNETWORK_MODEL_EXTENSIONS, default_models=[]) + except: + return None + +def is_malicious_model(file_path): + try: + scan_result = picklescan.scanner.scan_file_path(file_path) + if scan_result.issues_count > 0 or scan_result.infected_files > 0: + rich.print(":warning: [bold red]Scan %s: %d scanned, %d issue, %d infected.[/bold red]" % (file_path, scan_result.scanned_files, scan_result.issues_count, scan_result.infected_files)) + return True + else: + rich.print("Scan %s: [green]%d scanned, %d issue, %d infected.[/green]" % (file_path, scan_result.scanned_files, scan_result.issues_count, scan_result.infected_files)) + return False + except Exception as e: + print('error while scanning', file_path, 'error:', e) + return False + +def getModels(): + models = { + 'active': { + 'stable-diffusion': 'sd-v1-4', + 'vae': '', + 'hypernetwork': '', + }, + 'options': { + 'stable-diffusion': ['sd-v1-4'], + 'vae': [], + 'hypernetwork': [], + }, + } + + def listModels(models_dirname, model_type, model_extensions): + models_dir = os.path.join(app.MODELS_DIR, models_dirname) + if not os.path.exists(models_dir): + os.makedirs(models_dir) + + for file in os.listdir(models_dir): + for
model_extension in model_extensions: + if not file.endswith(model_extension): + continue + + model_path = os.path.join(models_dir, file) + mtime = os.path.getmtime(model_path) + mod_time = known_models[model_path] if model_path in known_models else -1 + if mod_time != mtime: + if is_malicious_model(model_path): + models['scan-error'] = file + return + known_models[model_path] = mtime + + model_name = file[:-len(model_extension)] + models['options'][model_type].append(model_name) + + models['options'][model_type] = [*set(models['options'][model_type])] # remove duplicates + models['options'][model_type].sort() + + # custom models + listModels(models_dirname='stable-diffusion', model_type='stable-diffusion', model_extensions=app.STABLE_DIFFUSION_MODEL_EXTENSIONS) + listModels(models_dirname='vae', model_type='vae', model_extensions=app.VAE_MODEL_EXTENSIONS) + listModels(models_dirname='hypernetwork', model_type='hypernetwork', model_extensions=app.HYPERNETWORK_MODEL_EXTENSIONS) + + # legacy + custom_weight_path = os.path.join(app.SD_DIR, 'custom-model.ckpt') + if os.path.exists(custom_weight_path): + models['options']['stable-diffusion'].append('custom-model') + + return models diff --git a/ui/sd_internal/runtime2.py b/ui/sd_internal/runtime2.py new file mode 100644 index 00000000..32befcde --- /dev/null +++ b/ui/sd_internal/runtime2.py @@ -0,0 +1,95 @@ +import threading +import queue + +from sd_internal import device_manager, Request, Response, Image as ResponseImage + +from modules import model_loader, image_generator, image_utils + +thread_data = threading.local() +''' +runtime data (bound locally to this thread), for e.g. 
device, references to loaded models, optimization flags etc +''' + +def init(device): + ''' + Initializes the fields that will be bound to this runtime's thread_data, and sets the current torch device + ''' + thread_data.stop_processing = False + thread_data.temp_images = {} + + thread_data.models = {} + thread_data.loaded_model_paths = {} + + thread_data.device = None + thread_data.device_name = None + thread_data.precision = 'autocast' + thread_data.vram_optimizations = ('TURBO', 'MOVE_MODELS') + + device_manager.device_init(thread_data, device) + + reload_models() + +def destroy(): + model_loader.unload_sd_model(thread_data) + model_loader.unload_gfpgan_model(thread_data) + model_loader.unload_realesrgan_model(thread_data) + +def reload_models(req: Request=None): + if is_hypernetwork_reload_necessary(task.request): + current_state = ServerStates.LoadingModel + runtime.reload_hypernetwork() + + if is_model_reload_necessary(task.request): + current_state = ServerStates.LoadingModel + runtime.reload_model() + +def load_models(): + if ckpt_file_path == None: + ckpt_file_path = default_model_to_load + if vae_file_path == None: + vae_file_path = default_vae_to_load + if hypernetwork_file_path == None: + hypernetwork_file_path = default_hypernetwork_to_load + if ckpt_file_path == current_model_path and vae_file_path == current_vae_path: + return + current_state = ServerStates.LoadingModel + try: + from sd_internal import runtime2 + runtime.thread_data.hypernetwork_file = hypernetwork_file_path + runtime.thread_data.ckpt_file = ckpt_file_path + runtime.thread_data.vae_file = vae_file_path + runtime.load_model_ckpt() + runtime.load_hypernetwork() + current_model_path = ckpt_file_path + current_vae_path = vae_file_path + current_hypernetwork_path = hypernetwork_file_path + current_state_error = None + current_state = ServerStates.Online + except Exception as e: + current_model_path = None + current_vae_path = None + current_state_error = e + current_state = 
ServerStates.Unavailable + print(traceback.format_exc()) + +def make_image(req: Request, data_queue: queue.Queue, task_temp_images: list, step_callback): + try: + images = image_generator.make_image(context=thread_data, args=get_mk_img_args(req)) + except UserInitiatedStop: + pass + +def get_mk_img_args(req: Request): + args = req.json() + + if req.init_image is not None: + args['init_image'] = image_utils.base64_str_to_img(req.init_image) + + if req.mask is not None: + args['mask'] = image_utils.base64_str_to_img(req.mask) + + return args + +def on_image_step(x_samples, i): + pass + +image_generator.on_image_step = on_image_step diff --git a/ui/sd_internal/task_manager.py b/ui/sd_internal/task_manager.py index 41fc00f6..4c2b5c6c 100644 --- a/ui/sd_internal/task_manager.py +++ b/ui/sd_internal/task_manager.py @@ -177,50 +177,14 @@ manager_lock = threading.RLock() render_threads = [] current_state = ServerStates.Init current_state_error:Exception = None -current_model_path = None -current_vae_path = None -current_hypernetwork_path = None tasks_queue = [] task_cache = TaskCache() -default_model_to_load = None -default_vae_to_load = None -default_hypernetwork_to_load = None weak_thread_data = weakref.WeakKeyDictionary() -def preload_model(ckpt_file_path=None, vae_file_path=None, hypernetwork_file_path=None): - global current_state, current_state_error, current_model_path, current_vae_path, current_hypernetwork_path - if ckpt_file_path == None: - ckpt_file_path = default_model_to_load - if vae_file_path == None: - vae_file_path = default_vae_to_load - if hypernetwork_file_path == None: - hypernetwork_file_path = default_hypernetwork_to_load - if ckpt_file_path == current_model_path and vae_file_path == current_vae_path: - return - current_state = ServerStates.LoadingModel - try: - from . 
import runtime - runtime.thread_data.hypernetwork_file = hypernetwork_file_path - runtime.thread_data.ckpt_file = ckpt_file_path - runtime.thread_data.vae_file = vae_file_path - runtime.load_model_ckpt() - runtime.load_hypernetwork() - current_model_path = ckpt_file_path - current_vae_path = vae_file_path - current_hypernetwork_path = hypernetwork_file_path - current_state_error = None - current_state = ServerStates.Online - except Exception as e: - current_model_path = None - current_vae_path = None - current_state_error = e - current_state = ServerStates.Unavailable - print(traceback.format_exc()) - def thread_get_next_task(): - from . import runtime + from sd_internal import runtime2 if not manager_lock.acquire(blocking=True, timeout=LOCK_TIMEOUT): - print('Render thread on device', runtime.thread_data.device, 'failed to acquire manager lock.') + print('Render thread on device', runtime2.thread_data.device, 'failed to acquire manager lock.') return None if len(tasks_queue) <= 0: manager_lock.release() @@ -228,7 +192,7 @@ def thread_get_next_task(): task = None try: # Select a render task. for queued_task in tasks_queue: - if queued_task.render_device and runtime.thread_data.device != queued_task.render_device: + if queued_task.render_device and runtime2.thread_data.device != queued_task.render_device: # Is asking for a specific render device. if is_alive(queued_task.render_device) > 0: continue # requested device alive, skip current one. @@ -237,7 +201,7 @@ def thread_get_next_task(): queued_task.error = Exception(queued_task.render_device + ' is not currently active.') task = queued_task break - if not queued_task.render_device and runtime.thread_data.device == 'cpu' and is_alive() > 1: + if not queued_task.render_device and runtime2.thread_data.device == 'cpu' and is_alive() > 1: # not asking for any specific devices, cpu want to grab task but other render devices are alive. 
continue # Skip Tasks, don't run on CPU unless there is nothing else or user asked for it. task = queued_task @@ -249,30 +213,31 @@ def thread_get_next_task(): manager_lock.release() def thread_render(device): - global current_state, current_state_error, current_model_path, current_vae_path, current_hypernetwork_path - from . import runtime + global current_state, current_state_error + + from sd_internal import runtime2 try: - runtime.thread_init(device) + runtime2.init(device) except Exception as e: print(traceback.format_exc()) weak_thread_data[threading.current_thread()] = { 'error': e } return + weak_thread_data[threading.current_thread()] = { - 'device': runtime.thread_data.device, - 'device_name': runtime.thread_data.device_name, + 'device': runtime2.thread_data.device, + 'device_name': runtime2.thread_data.device_name, 'alive': True } - if runtime.thread_data.device != 'cpu' or is_alive() == 1: - preload_model() - current_state = ServerStates.Online + + current_state = ServerStates.Online + while True: task_cache.clean() if not weak_thread_data[threading.current_thread()]['alive']: - print(f'Shutting down thread for device {runtime.thread_data.device}') - runtime.unload_models() - runtime.unload_filters() + print(f'Shutting down thread for device {runtime2.thread_data.device}') + runtime2.destroy() return if isinstance(current_state_error, SystemExit): current_state = ServerStates.Unavailable @@ -291,24 +256,17 @@ def thread_render(device): task.response = {"status": 'failed', "detail": str(task.error)} task.buffer_queue.put(json.dumps(task.response)) continue - print(f'Session {task.request.session_id} starting task {id(task)} on {runtime.thread_data.device_name}') + print(f'Session {task.request.session_id} starting task {id(task)} on {runtime2.thread_data.device_name}') if not task.lock.acquire(blocking=False): raise Exception('Got locked task from queue.') try: - if runtime.is_hypernetwork_reload_necessary(task.request): - runtime.reload_hypernetwork() - 
current_hypernetwork_path = task.request.use_hypernetwork_model - - if runtime.is_model_reload_necessary(task.request): - current_state = ServerStates.LoadingModel - runtime.reload_model() - current_model_path = task.request.use_stable_diffusion_model - current_vae_path = task.request.use_vae_model + current_state = ServerStates.LoadingModel + runtime2.reload_models(task.request) def step_callback(): global current_state_error if isinstance(current_state_error, SystemExit) or isinstance(current_state_error, StopAsyncIteration) or isinstance(task.error, StopAsyncIteration): - runtime.thread_data.stop_processing = True + runtime2.thread_data.stop_processing = True if isinstance(current_state_error, StopAsyncIteration): task.error = current_state_error current_state_error = None @@ -317,7 +275,7 @@ def thread_render(device): task_cache.keep(task.request.session_id, TASK_TTL) current_state = ServerStates.Rendering - task.response = runtime.mk_img(task.request, task.buffer_queue, task.temp_images, step_callback) + task.response = runtime2.make_image(task.request, task.buffer_queue, task.temp_images, step_callback) except Exception as e: task.error = e print(traceback.format_exc()) @@ -331,7 +289,7 @@ def thread_render(device): elif task.error is not None: print(f'Session {task.request.session_id} task {id(task)} failed!') else: - print(f'Session {task.request.session_id} task {id(task)} completed by {runtime.thread_data.device_name}.') + print(f'Session {task.request.session_id} task {id(task)} completed by {runtime2.thread_data.device_name}.') current_state = ServerStates.Online def get_cached_task(session_id:str, update_ttl:bool=False): @@ -493,8 +451,7 @@ def render(req : ImageRequest): if task and not task.response and not task.error and not task.lock.locked(): # Unstarted task pending, deny queueing more than one. raise ConnectionRefusedError(f'Session {req.session_id} has an already pending task.') - # - from . 
import runtime + r = Request() r.session_id = req.session_id r.prompt = req.prompt diff --git a/ui/server.py b/ui/server.py index 804994a2..7922d8f3 100644 --- a/ui/server.py +++ b/ui/server.py @@ -2,64 +2,21 @@ Notes: async endpoints always run on the main thread. Without they run on the thread pool. """ -import json -import traceback - -import sys import os -import socket -import picklescan.scanner -import rich - -SD_DIR = os.getcwd() -print('started in ', SD_DIR) - -SD_UI_DIR = os.getenv('SD_UI_PATH', None) -sys.path.append(os.path.dirname(SD_UI_DIR)) - -CONFIG_DIR = os.path.abspath(os.path.join(SD_UI_DIR, '..', 'scripts')) -MODELS_DIR = os.path.abspath(os.path.join(SD_DIR, '..', 'models')) - -USER_UI_PLUGINS_DIR = os.path.abspath(os.path.join(SD_DIR, '..', 'plugins', 'ui')) -CORE_UI_PLUGINS_DIR = os.path.abspath(os.path.join(SD_UI_DIR, 'plugins', 'ui')) -UI_PLUGINS_SOURCES = ((CORE_UI_PLUGINS_DIR, 'core'), (USER_UI_PLUGINS_DIR, 'user')) - -STABLE_DIFFUSION_MODEL_EXTENSIONS = ['.ckpt', '.safetensors'] -VAE_MODEL_EXTENSIONS = ['.vae.pt', '.ckpt'] -HYPERNETWORK_MODEL_EXTENSIONS = ['.pt'] - -OUTPUT_DIRNAME = "Stable Diffusion UI" # in the user's home folder -TASK_TTL = 15 * 60 # Discard last session's task timeout -APP_CONFIG_DEFAULTS = { - # auto: selects the cuda device with the most free memory, cuda: use the currently active cuda device. - 'render_devices': 'auto', # valid entries: 'auto', 'cpu' or 'cuda:N' (where N is a GPU index) - 'update_branch': 'main', - 'ui': { - 'open_browser_on_start': True, - }, -} -APP_CONFIG_DEFAULT_MODELS = [ - # needed to support the legacy installations - 'custom-model', # Check if user has a custom model, use it first. - 'sd-v1-4', # Default fallback. 
-] +import traceback +import logging +from typing import List, Union from fastapi import FastAPI, HTTPException from fastapi.staticfiles import StaticFiles from starlette.responses import FileResponse, JSONResponse, StreamingResponse from pydantic import BaseModel -import logging -#import queue, threading, time -from typing import Any, Generator, Hashable, List, Optional, Union -from sd_internal import Request, Response, task_manager +from sd_internal import app, model_manager, task_manager -app = FastAPI() +print('started in ', app.SD_DIR) -modifiers_cache = None -outpath = os.path.join(os.path.expanduser("~"), OUTPUT_DIRNAME) - -os.makedirs(USER_UI_PLUGINS_DIR, exist_ok=True) +server_api = FastAPI() # don't show access log entries for URLs that start with the given prefix ACCESS_LOG_SUPPRESS_PATH_PREFIXES = ['/ping', '/image', '/modifier-thumbnails'] @@ -74,132 +31,6 @@ class NoCacheStaticFiles(StaticFiles): return super().is_not_modified(response_headers, request_headers) -app.mount('/media', NoCacheStaticFiles(directory=os.path.join(SD_UI_DIR, 'media')), name="media") - -for plugins_dir, dir_prefix in UI_PLUGINS_SOURCES: - app.mount(f'/plugins/{dir_prefix}', NoCacheStaticFiles(directory=plugins_dir), name=f"plugins-{dir_prefix}") - -def getConfig(default_val=APP_CONFIG_DEFAULTS): - try: - config_json_path = os.path.join(CONFIG_DIR, 'config.json') - if not os.path.exists(config_json_path): - return default_val - with open(config_json_path, 'r', encoding='utf-8') as f: - config = json.load(f) - if 'net' not in config: - config['net'] = {} - if os.getenv('SD_UI_BIND_PORT') is not None: - config['net']['listen_port'] = int(os.getenv('SD_UI_BIND_PORT')) - if os.getenv('SD_UI_BIND_IP') is not None: - config['net']['listen_to_network'] = ( os.getenv('SD_UI_BIND_IP') == '0.0.0.0' ) - return config - except Exception as e: - print(str(e)) - print(traceback.format_exc()) - return default_val - -def setConfig(config): - print( json.dumps(config) ) - try: # config.json - 
config_json_path = os.path.join(CONFIG_DIR, 'config.json') - with open(config_json_path, 'w', encoding='utf-8') as f: - json.dump(config, f) - except: - print(traceback.format_exc()) - - try: # config.bat - config_bat_path = os.path.join(CONFIG_DIR, 'config.bat') - config_bat = [] - - if 'update_branch' in config: - config_bat.append(f"@set update_branch={config['update_branch']}") - - config_bat.append(f"@set SD_UI_BIND_PORT={config['net']['listen_port']}") - bind_ip = '0.0.0.0' if config['net']['listen_to_network'] else '127.0.0.1' - config_bat.append(f"@set SD_UI_BIND_IP={bind_ip}") - - config_bat.append(f"@set test_sd2={'Y' if config.get('test_sd2', False) else 'N'}") - - if len(config_bat) > 0: - with open(config_bat_path, 'w', encoding='utf-8') as f: - f.write('\r\n'.join(config_bat)) - except: - print(traceback.format_exc()) - - try: # config.sh - config_sh_path = os.path.join(CONFIG_DIR, 'config.sh') - config_sh = ['#!/bin/bash'] - - if 'update_branch' in config: - config_sh.append(f"export update_branch={config['update_branch']}") - - config_sh.append(f"export SD_UI_BIND_PORT={config['net']['listen_port']}") - bind_ip = '0.0.0.0' if config['net']['listen_to_network'] else '127.0.0.1' - config_sh.append(f"export SD_UI_BIND_IP={bind_ip}") - - config_sh.append(f"export test_sd2=\"{'Y' if config.get('test_sd2', False) else 'N'}\"") - - if len(config_sh) > 1: - with open(config_sh_path, 'w', encoding='utf-8') as f: - f.write('\n'.join(config_sh)) - except: - print(traceback.format_exc()) - -def resolve_model_to_use(model_name:str, model_type:str, model_dir:str, model_extensions:list, default_models=[]): - config = getConfig() - - model_dirs = [os.path.join(MODELS_DIR, model_dir), SD_DIR] - if not model_name: # When None try user configured model. 
- # config = getConfig() - if 'model' in config and model_type in config['model']: - model_name = config['model'][model_type] - if model_name: - is_sd2 = config.get('test_sd2', False) - if model_name.startswith('sd2_') and not is_sd2: # temp hack, until SD2 is unified with 1.4 - print('ERROR: Cannot use SD 2.0 models with SD 1.0 code. Using the sd-v1-4 model instead!') - model_name = 'sd-v1-4' - - # Check models directory - models_dir_path = os.path.join(MODELS_DIR, model_dir, model_name) - for model_extension in model_extensions: - if os.path.exists(models_dir_path + model_extension): - return models_dir_path - if os.path.exists(model_name + model_extension): - # Direct Path to file - model_name = os.path.abspath(model_name) - return model_name - # Default locations - if model_name in default_models: - default_model_path = os.path.join(SD_DIR, model_name) - for model_extension in model_extensions: - if os.path.exists(default_model_path + model_extension): - return default_model_path - # Can't find requested model, check the default paths. - for default_model in default_models: - for model_dir in model_dirs: - default_model_path = os.path.join(model_dir, default_model) - for model_extension in model_extensions: - if os.path.exists(default_model_path + model_extension): - if model_name is not None: - print(f'Could not find the configured custom model {model_name}{model_extension}. 
Using the default one: {default_model_path}{model_extension}') - return default_model_path - raise Exception('No valid models found.') - -def resolve_ckpt_to_use(model_name:str=None): - return resolve_model_to_use(model_name, model_type='stable-diffusion', model_dir='stable-diffusion', model_extensions=STABLE_DIFFUSION_MODEL_EXTENSIONS, default_models=APP_CONFIG_DEFAULT_MODELS) - -def resolve_vae_to_use(model_name:str=None): - try: - return resolve_model_to_use(model_name, model_type='vae', model_dir='vae', model_extensions=VAE_MODEL_EXTENSIONS, default_models=[]) - except: - return None - -def resolve_hypernetwork_to_use(model_name:str=None): - try: - return resolve_model_to_use(model_name, model_type='hypernetwork', model_dir='hypernetwork', model_extensions=HYPERNETWORK_MODEL_EXTENSIONS, default_models=[]) - except: - return None - class SetAppConfigRequest(BaseModel): update_branch: str = None render_devices: Union[List[str], List[int], str, int] = None @@ -209,9 +40,25 @@ class SetAppConfigRequest(BaseModel): listen_port: int = None test_sd2: bool = None -@app.post('/app_config') +class LogSuppressFilter(logging.Filter): + def filter(self, record: logging.LogRecord) -> bool: + path = record.getMessage() + for prefix in ACCESS_LOG_SUPPRESS_PATH_PREFIXES: + if path.find(prefix) != -1: + return False + return True + +# don't log certain requests +logging.getLogger('uvicorn.access').addFilter(LogSuppressFilter()) + +server_api.mount('/media', NoCacheStaticFiles(directory=os.path.join(app.SD_UI_DIR, 'media')), name="media") + +for plugins_dir, dir_prefix in app.UI_PLUGINS_SOURCES: + server_api.mount(f'/plugins/{dir_prefix}', NoCacheStaticFiles(directory=plugins_dir), name=f"plugins-{dir_prefix}") + +@server_api.post('/app_config') async def setAppConfig(req : SetAppConfigRequest): - config = app.getConfig() if req.update_branch is not None: config['update_branch'] = req.update_branch if req.render_devices is not None: @@ -231,121 +78,48 @@ async def
setAppConfig(req : SetAppConfigRequest): if req.test_sd2 is not None: config['test_sd2'] = req.test_sd2 try: - setConfig(config) + app.setConfig(config) if req.render_devices: - update_render_threads() + app.update_render_threads() return JSONResponse({'status': 'OK'}, headers=NOCACHE_HEADERS) except Exception as e: print(traceback.format_exc()) raise HTTPException(status_code=500, detail=str(e)) -def is_malicious_model(file_path): - try: - scan_result = picklescan.scanner.scan_file_path(file_path) - if scan_result.issues_count > 0 or scan_result.infected_files > 0: - rich.print(":warning: [bold red]Scan %s: %d scanned, %d issue, %d infected.[/bold red]" % (file_path, scan_result.scanned_files, scan_result.issues_count, scan_result.infected_files)) - return True - else: - rich.print("Scan %s: [green]%d scanned, %d issue, %d infected.[/green]" % (file_path, scan_result.scanned_files, scan_result.issues_count, scan_result.infected_files)) - return False - except Exception as e: - print('error while scanning', file_path, 'error:', e) - return False +def update_render_devices_in_config(config, render_devices): + if render_devices not in ('cpu', 'auto') and not render_devices.startswith('cuda:'): + raise HTTPException(status_code=400, detail=f'Invalid render device requested: {render_devices}') -known_models = {} -def getModels(): - models = { - 'active': { - 'stable-diffusion': 'sd-v1-4', - 'vae': '', - 'hypernetwork': '', - }, - 'options': { - 'stable-diffusion': ['sd-v1-4'], - 'vae': [], - 'hypernetwork': [], - }, - } + if render_devices.startswith('cuda:'): + render_devices = render_devices.split(',') - def listModels(models_dirname, model_type, model_extensions): - models_dir = os.path.join(MODELS_DIR, models_dirname) - if not os.path.exists(models_dir): - os.makedirs(models_dir) + config['render_devices'] = render_devices - for file in os.listdir(models_dir): - for model_extension in model_extensions: - if not file.endswith(model_extension): - continue - - 
model_path = os.path.join(models_dir, file) - mtime = os.path.getmtime(model_path) - mod_time = known_models[model_path] if model_path in known_models else -1 - if mod_time != mtime: - if is_malicious_model(model_path): - models['scan-error'] = file - return - known_models[model_path] = mtime - - model_name = file[:-len(model_extension)] - models['options'][model_type].append(model_name) - - models['options'][model_type] = [*set(models['options'][model_type])] # remove duplicates - models['options'][model_type].sort() - - # custom models - listModels(models_dirname='stable-diffusion', model_type='stable-diffusion', model_extensions=STABLE_DIFFUSION_MODEL_EXTENSIONS) - listModels(models_dirname='vae', model_type='vae', model_extensions=VAE_MODEL_EXTENSIONS) - listModels(models_dirname='hypernetwork', model_type='hypernetwork', model_extensions=HYPERNETWORK_MODEL_EXTENSIONS) - # legacy - custom_weight_path = os.path.join(SD_DIR, 'custom-model.ckpt') - if os.path.exists(custom_weight_path): - models['options']['stable-diffusion'].append('custom-model') - - return models - -def getUIPlugins(): - plugins = [] - - for plugins_dir, dir_prefix in UI_PLUGINS_SOURCES: - for file in os.listdir(plugins_dir): - if file.endswith('.plugin.js'): - plugins.append(f'/plugins/{dir_prefix}/{file}') - - return plugins - -def getIPConfig(): - ips = socket.gethostbyname_ex(socket.gethostname()) - ips[2].append(ips[0]) - return ips[2] - -@app.get('/get/{key:path}') +@server_api.get('/get/{key:path}') def read_web_data(key:str=None): if not key: # /get without parameters, stable-diffusion easter egg. 
raise HTTPException(status_code=418, detail="StableDiffusion is drawing a teapot!") # HTTP418 I'm a teapot elif key == 'app_config': - config = getConfig(default_val=None) - if config is None: - config = APP_CONFIG_DEFAULTS - return JSONResponse(config, headers=NOCACHE_HEADERS) + return JSONResponse(app.getConfig(), headers=NOCACHE_HEADERS) elif key == 'system_info': - config = getConfig() + config = app.getConfig() system_info = { 'devices': task_manager.get_devices(), - 'hosts': getIPConfig(), + 'hosts': app.getIPConfig(), + 'default_output_dir': os.path.join(os.path.expanduser("~"), app.OUTPUT_DIRNAME), } system_info['devices']['config'] = config.get('render_devices', "auto") return JSONResponse(system_info, headers=NOCACHE_HEADERS) elif key == 'models': - return JSONResponse(getModels(), headers=NOCACHE_HEADERS) - elif key == 'modifiers': return FileResponse(os.path.join(SD_UI_DIR, 'modifiers.json'), headers=NOCACHE_HEADERS) - elif key == 'output_dir': return JSONResponse({ 'output_dir': outpath }, headers=NOCACHE_HEADERS) - elif key == 'ui_plugins': return JSONResponse(getUIPlugins(), headers=NOCACHE_HEADERS) + return JSONResponse(model_manager.getModels(), headers=NOCACHE_HEADERS) + elif key == 'modifiers': return FileResponse(os.path.join(app.SD_UI_DIR, 'modifiers.json'), headers=NOCACHE_HEADERS) + elif key == 'ui_plugins': return JSONResponse(app.getUIPlugins(), headers=NOCACHE_HEADERS) else: raise HTTPException(status_code=404, detail=f'Request for unknown {key}') # HTTP404 Not Found -@app.get('/ping') # Get server and optionally session status. +@server_api.get('/ping') # Get server and optionally session status. def ping(session_id:str=None): if task_manager.is_alive() <= 0: # Check that render threads are alive. 
if task_manager.current_state_error: raise HTTPException(status_code=500, detail=str(task_manager.current_state_error)) @@ -372,38 +146,14 @@ def ping(session_id:str=None): response['devices'] = task_manager.get_devices() return JSONResponse(response, headers=NOCACHE_HEADERS) -def save_model_to_config(ckpt_model_name, vae_model_name, hypernetwork_model_name): - config = getConfig() - if 'model' not in config: - config['model'] = {} - - config['model']['stable-diffusion'] = ckpt_model_name - config['model']['vae'] = vae_model_name - config['model']['hypernetwork'] = hypernetwork_model_name - - if vae_model_name is None or vae_model_name == "": - del config['model']['vae'] - if hypernetwork_model_name is None or hypernetwork_model_name == "": - del config['model']['hypernetwork'] - - setConfig(config) - -def update_render_devices_in_config(config, render_devices): - if render_devices not in ('cpu', 'auto') and not render_devices.startswith('cuda:'): - raise HTTPException(status_code=400, detail=f'Invalid render device requested: {render_devices}') - - if render_devices.startswith('cuda:'): - render_devices = render_devices.split(',') - - config['render_devices'] = render_devices - -@app.post('/render') +@server_api.post('/render') def render(req : task_manager.ImageRequest): try: - save_model_to_config(req.use_stable_diffusion_model, req.use_vae_model, req.use_hypernetwork_model) - req.use_stable_diffusion_model = resolve_ckpt_to_use(req.use_stable_diffusion_model) - req.use_vae_model = resolve_vae_to_use(req.use_vae_model) - req.use_hypernetwork_model = resolve_hypernetwork_to_use(req.use_hypernetwork_model) + app.save_model_to_config(req.use_stable_diffusion_model, req.use_vae_model, req.use_hypernetwork_model) + req.use_stable_diffusion_model = model_manager.resolve_ckpt_to_use(req.use_stable_diffusion_model) + req.use_vae_model = model_manager.resolve_vae_to_use(req.use_vae_model) + req.use_hypernetwork_model = 
model_manager.resolve_hypernetwork_to_use(req.use_hypernetwork_model) + new_task = task_manager.render(req) response = { 'status': str(task_manager.current_state), @@ -419,7 +169,7 @@ def render(req : task_manager.ImageRequest): except Exception as e: raise HTTPException(status_code=500, detail=str(e)) -@app.get('/image/stream/{session_id:str}/{task_id:int}') +@server_api.get('/image/stream/{session_id:str}/{task_id:int}') def stream(session_id:str, task_id:int): #TODO Move to WebSockets ?? task = task_manager.get_cached_task(session_id, update_ttl=True) @@ -433,7 +183,7 @@ def stream(session_id:str, task_id:int): #print(f'Session {session_id} opened live render stream {id(task.buffer_queue)}') return StreamingResponse(task.read_buffer_generator(), media_type='application/json') -@app.get('/image/stop') +@server_api.get('/image/stop') def stop(session_id:str=None): if not session_id: if task_manager.current_state == task_manager.ServerStates.Online or task_manager.current_state == task_manager.ServerStates.Unavailable: @@ -446,7 +196,7 @@ def stop(session_id:str=None): task.error = StopAsyncIteration('') return {'OK'} -@app.get('/image/tmp/{session_id}/{img_id:int}') +@server_api.get('/image/tmp/{session_id}/{img_id:int}') def get_image(session_id, img_id): task = task_manager.get_cached_task(session_id, update_ttl=True) if not task: raise HTTPException(status_code=410, detail=f'Session {session_id} has not submitted a task.') # HTTP410 Gone @@ -458,49 +208,17 @@ def get_image(session_id, img_id): except KeyError as e: raise HTTPException(status_code=500, detail=str(e)) -@app.get('/') +@server_api.get('/') def read_root(): - return FileResponse(os.path.join(SD_UI_DIR, 'index.html'), headers=NOCACHE_HEADERS) + return FileResponse(os.path.join(app.SD_UI_DIR, 'index.html'), headers=NOCACHE_HEADERS) -@app.on_event("shutdown") +@server_api.on_event("shutdown") def shutdown_event(): # Signal render thread to close on shutdown task_manager.current_state_error = 
SystemExit('Application shutting down.') -# don't log certain requests -class LogSuppressFilter(logging.Filter): - def filter(self, record: logging.LogRecord) -> bool: - path = record.getMessage() - for prefix in ACCESS_LOG_SUPPRESS_PATH_PREFIXES: - if path.find(prefix) != -1: - return False - return True -logging.getLogger('uvicorn.access').addFilter(LogSuppressFilter()) - -# Check models and prepare cache for UI open -getModels() - -# Start the task_manager -task_manager.default_model_to_load = resolve_ckpt_to_use() -task_manager.default_vae_to_load = resolve_vae_to_use() -task_manager.default_hypernetwork_to_load = resolve_hypernetwork_to_use() - -def update_render_threads(): - config = getConfig() - render_devices = config.get('render_devices', 'auto') - active_devices = task_manager.get_devices()['active'].keys() - - print('requesting for render_devices', render_devices) - task_manager.update_render_threads(render_devices, active_devices) - -update_render_threads() +# Init the app +model_manager.init() +app.init() # start the browser ui -def open_browser(): - config = getConfig() - ui = config.get('ui', {}) - net = config.get('net', {'listen_port':9000}) - port = net.get('listen_port', 9000) - if ui.get('open_browser_on_start', True): - import webbrowser; webbrowser.open(f"http://localhost:{port}") - -open_browser() +app.open_browser()