From 97919c7e87399b5dc017efae1a83b6f009935ef3 Mon Sep 17 00:00:00 2001 From: cmdr2 Date: Sun, 11 Dec 2022 19:58:12 +0530 Subject: [PATCH] Simplify the runtime code --- ui/sd_internal/model_manager.py | 14 +++++++++++++ ui/sd_internal/runtime2.py | 36 +++++---------------------------- ui/sd_internal/task_manager.py | 7 ++++--- 3 files changed, 23 insertions(+), 34 deletions(-) diff --git a/ui/sd_internal/model_manager.py b/ui/sd_internal/model_manager.py index 2129af3d..3d68494f 100644 --- a/ui/sd_internal/model_manager.py +++ b/ui/sd_internal/model_manager.py @@ -4,6 +4,8 @@ import picklescan.scanner import rich from sd_internal import app +from modules import model_loader +from modules.types import Context log = logging.getLogger() @@ -30,6 +32,18 @@ def init(): make_model_folders() getModels() # run this once, to cache the picklescan results +def load_default_models(context: Context): + # init default model paths + for model_type in KNOWN_MODEL_TYPES: + context.model_paths[model_type] = resolve_model_to_use(model_type=model_type) + + # load mandatory models + model_loader.load_model(context, 'stable-diffusion') + +def unload_all(context: Context): + for model_type in KNOWN_MODEL_TYPES: + model_loader.unload_model(context, model_type) + def resolve_model_to_use(model_name:str=None, model_type:str=None): model_extensions = MODEL_EXTENSIONS.get(model_type, []) default_models = DEFAULT_MODELS.get(model_type, []) diff --git a/ui/sd_internal/runtime2.py b/ui/sd_internal/runtime2.py index 0e1efc2c..399ba960 100644 --- a/ui/sd_internal/runtime2.py +++ b/ui/sd_internal/runtime2.py @@ -33,18 +33,6 @@ def init(device): device_manager.device_init(thread_data, device) -def destroy(): - for model_type in model_manager.KNOWN_MODEL_TYPES: - model_loader.unload_model(thread_data, model_type) - -def load_default_models(): - # init default model paths - for model_type in model_manager.KNOWN_MODEL_TYPES: - thread_data.model_paths[model_type] = model_manager.resolve_model_to_use(model_type=model_type) - - # load mandatory models - model_loader.load_model(thread_data, 'stable-diffusion') - def reload_models_if_necessary(task_data: TaskData): model_paths_in_req = ( ('hypernetwork', task_data.use_hypernetwork_model), @@ -67,14 +55,6 @@ def reload_models_if_necessary(task_data: TaskData): def make_images(req: GenerateImageRequest, task_data: TaskData, data_queue: queue.Queue, task_temp_images: list, step_callback): try: - # resolve the model paths to use - resolve_model_paths(task_data) - - # convert init image to PIL.Image - req.init_image = image_utils.base64_str_to_img(req.init_image) if req.init_image is not None else None - req.init_image_mask = image_utils.base64_str_to_img(req.init_image_mask) if req.init_image_mask is not None else None - - # generate return _make_images_internal(req, task_data, data_queue, task_temp_images, step_callback) except Exception as e: log.error(traceback.format_exc()) @@ -86,17 +66,12 @@ def make_images(req: GenerateImageRequest, task_data: TaskData, data_queue: queu raise e def _make_images_internal(req: GenerateImageRequest, task_data: TaskData, data_queue: queue.Queue, task_temp_images: list, step_callback): - metadata = req.dict() - del metadata['init_image'] - del metadata['init_image_mask'] - print(metadata) - images, user_stopped = generate_images(req, data_queue, task_temp_images, step_callback, task_data.stream_image_progress) images = apply_filters(task_data, images, user_stopped, task_data.show_only_filtered_image) if task_data.save_to_disk_path is not None: out_path = os.path.join(task_data.save_to_disk_path, filename_regex.sub('_', task_data.session_id)) - save_images(images, out_path, metadata=metadata, show_only_filtered_image=task_data.show_only_filtered_image) + save_images(images, out_path, metadata=req.to_metadata(), show_only_filtered_image=task_data.show_only_filtered_image) res = Response(req, task_data, images=construct_response(images)) res = res.json() @@ -114,6 +89,7 @@ def resolve_model_paths(task_data: TaskData): if task_data.use_upscale: task_data.use_upscale = model_manager.resolve_model_to_use(task_data.use_upscale, 'gfpgan') def generate_images(req: GenerateImageRequest, data_queue: queue.Queue, task_temp_images: list, step_callback, stream_image_progress: bool): + log.info(req.to_metadata()) thread_data.temp_images.clear() image_generator.on_image_step = make_step_callback(req, data_queue, task_temp_images, step_callback, stream_image_progress) @@ -125,13 +101,11 @@ def generate_images(req: GenerateImageRequest, data_queue: queue.Queue, task_tem images = [] user_stopped = True if thread_data.partial_x_samples is not None: - for i in range(req.num_outputs): - images[i] = image_utils.latent_to_img(thread_data, thread_data.partial_x_samples[i].unsqueeze(0)) - - thread_data.partial_x_samples = None + images = image_utils.latent_samples_to_images(thread_data, thread_data.partial_x_samples) + thread_data.partial_x_samples = None finally: model_loader.gc(thread_data) - + images = [(image, req.seed + i, False) for i, image in enumerate(images)] return images, user_stopped diff --git a/ui/sd_internal/task_manager.py b/ui/sd_internal/task_manager.py index 3ec1b99d..c167a5f9 100644 --- a/ui/sd_internal/task_manager.py +++ b/ui/sd_internal/task_manager.py @@ -219,7 +219,7 @@ def thread_get_next_task(): def thread_render(device): global current_state, current_state_error - from sd_internal import runtime2 + from sd_internal import runtime2, model_manager try: runtime2.init(device) except Exception as e: @@ -235,7 +235,7 @@ def thread_render(device): 'alive': True } - runtime2.load_default_models() + model_manager.load_default_models(runtime2.thread_data) current_state = ServerStates.Online while True: @@ -243,7 +243,7 @@ def thread_render(device): task_cache.clean() if not weak_thread_data[threading.current_thread()]['alive']: log.info(f'Shutting down thread for device {runtime2.thread_data.device}') - runtime2.destroy() + model_manager.unload_all(runtime2.thread_data) return if isinstance(current_state_error, SystemExit): current_state = ServerStates.Unavailable @@ -280,6 +280,7 @@ def thread_render(device): runtime2.reload_models_if_necessary(task.task_data) current_state = ServerStates.Rendering + runtime2.resolve_model_paths(task.task_data) task.response = runtime2.make_images(task.render_request, task.task_data, task.buffer_queue, task.temp_images, step_callback) # Before looping back to the generator, mark cache as still alive. task_cache.keep(id(task), TASK_TTL)