From 16410d90b81d9ceec1cc71e347224faae491d40c Mon Sep 17 00:00:00 2001 From: cmdr2 Date: Fri, 9 Dec 2022 15:21:49 +0530 Subject: [PATCH] Use the simplified model loading API in diffusion-kit; Catch and report exceptions while generating images --- ui/sd_internal/runtime2.py | 29 +++++++++++++++++++++-------- 1 file changed, 21 insertions(+), 8 deletions(-) diff --git a/ui/sd_internal/runtime2.py b/ui/sd_internal/runtime2.py index 90a868a5..0af0ead4 100644 --- a/ui/sd_internal/runtime2.py +++ b/ui/sd_internal/runtime2.py @@ -5,6 +5,7 @@ import json import os import base64 import re +import traceback from sd_internal import device_manager, model_manager from sd_internal import Request, Response, Image as ResponseImage, UserInitiatedStop @@ -38,10 +39,10 @@ def init(device): init_and_load_default_models() def destroy(): - model_loader.unload_sd_model(thread_data) - model_loader.unload_gfpgan_model(thread_data) - model_loader.unload_realesrgan_model(thread_data) - model_loader.unload_hypernetwork_model(thread_data) + model_loader.unload_model(thread_data, 'stable-diffusion') + model_loader.unload_model(thread_data, 'gfpgan') + model_loader.unload_model(thread_data, 'realesrgan') + model_loader.unload_model(thread_data, 'hypernetwork') def init_and_load_default_models(): # init default model paths @@ -52,24 +53,36 @@ def init_and_load_default_models(): thread_data.model_paths['realesrgan'] = model_manager.resolve_realesrgan_model_to_use() # load mandatory models - model_loader.load_sd_model(thread_data) + model_loader.load_model(thread_data, 'stable-diffusion') def reload_models_if_necessary(req: Request): if model_manager.is_sd_model_reload_necessary(thread_data, req): thread_data.model_paths['stable-diffusion'] = req.use_stable_diffusion_model thread_data.model_paths['vae'] = req.use_vae_model - model_loader.load_sd_model(thread_data) + model_loader.load_model(thread_data, 'stable-diffusion') if thread_data.model_paths.get('hypernetwork') != req.use_hypernetwork_model: thread_data.model_paths['hypernetwork'] = req.use_hypernetwork_model if thread_data.model_paths['hypernetwork'] is not None: - model_loader.load_hypernetwork_model(thread_data) + model_loader.load_model(thread_data, 'hypernetwork') else: - model_loader.unload_hypernetwork_model(thread_data) + model_loader.unload_model(thread_data, 'hypernetwork') def make_images(req: Request, data_queue: queue.Queue, task_temp_images: list, step_callback): + try: + return _make_images_internal(req, data_queue, task_temp_images, step_callback) + except Exception as e: + print(traceback.format_exc()) + + data_queue.put(json.dumps({ + "status": 'failed', + "detail": str(e) + })) + raise e + +def _make_images_internal(req: Request, data_queue: queue.Queue, task_temp_images: list, step_callback): images, user_stopped = generate_images(req, data_queue, task_temp_images, step_callback) images = apply_filters(req, images, user_stopped)