diff --git a/ui/sd_internal/model_manager.py b/ui/sd_internal/model_manager.py index 3d68494f..ea624322 100644 --- a/ui/sd_internal/model_manager.py +++ b/ui/sd_internal/model_manager.py @@ -3,7 +3,7 @@ import logging import picklescan.scanner import rich -from sd_internal import app +from sd_internal import app, TaskData from modules import model_loader from modules.types import Context @@ -88,6 +88,34 @@ def resolve_model_to_use(model_name:str=None, model_type:str=None): return None +def reload_models_if_necessary(context: Context, task_data: TaskData): + model_paths_in_req = ( + ('hypernetwork', task_data.use_hypernetwork_model), + ('gfpgan', task_data.use_face_correction), + ('realesrgan', task_data.use_upscale), + ) + + if context.model_paths.get('stable-diffusion') != task_data.use_stable_diffusion_model or context.model_paths.get('vae') != task_data.use_vae_model: + context.model_paths['stable-diffusion'] = task_data.use_stable_diffusion_model + context.model_paths['vae'] = task_data.use_vae_model + + model_loader.load_model(context, 'stable-diffusion') + + for model_type, model_path_in_req in model_paths_in_req: + if context.model_paths.get(model_type) != model_path_in_req: + context.model_paths[model_type] = model_path_in_req + + action_fn = model_loader.unload_model if context.model_paths[model_type] is None else model_loader.load_model + action_fn(context, model_type) + +def resolve_model_paths(task_data: TaskData): + task_data.use_stable_diffusion_model = resolve_model_to_use(task_data.use_stable_diffusion_model, model_type='stable-diffusion') + task_data.use_vae_model = resolve_model_to_use(task_data.use_vae_model, model_type='vae') + task_data.use_hypernetwork_model = resolve_model_to_use(task_data.use_hypernetwork_model, model_type='hypernetwork') + + if task_data.use_face_correction: task_data.use_face_correction = resolve_model_to_use(task_data.use_face_correction, 'gfpgan') + if task_data.use_upscale: task_data.use_upscale = resolve_model_to_use(task_data.use_upscale, 'gfpgan') + def make_model_folders(): for model_type in KNOWN_MODEL_TYPES: model_dir_path = os.path.join(app.MODELS_DIR, model_type) diff --git a/ui/sd_internal/runtime2.py b/ui/sd_internal/runtime2.py index 399ba960..1f2e9d27 100644 --- a/ui/sd_internal/runtime2.py +++ b/ui/sd_internal/runtime2.py @@ -1,4 +1,3 @@ -import threading import queue import time import json @@ -8,7 +7,7 @@ import re import traceback import logging -from sd_internal import device_manager, model_manager +from sd_internal import device_manager from sd_internal import TaskData, Response, Image as ResponseImage, UserInitiatedStop from modules import model_loader, image_generator, image_utils, filters as image_filters @@ -33,26 +32,6 @@ def init(device): device_manager.device_init(thread_data, device) -def reload_models_if_necessary(task_data: TaskData): - model_paths_in_req = ( - ('hypernetwork', task_data.use_hypernetwork_model), - ('gfpgan', task_data.use_face_correction), - ('realesrgan', task_data.use_upscale), - ) - - if thread_data.model_paths.get('stable-diffusion') != task_data.use_stable_diffusion_model or thread_data.model_paths.get('vae') != task_data.use_vae_model: - thread_data.model_paths['stable-diffusion'] = task_data.use_stable_diffusion_model - thread_data.model_paths['vae'] = task_data.use_vae_model - - model_loader.load_model(thread_data, 'stable-diffusion') - - for model_type, model_path_in_req in model_paths_in_req: - if thread_data.model_paths.get(model_type) != model_path_in_req: - thread_data.model_paths[model_type] = model_path_in_req - - action_fn = model_loader.unload_model if thread_data.model_paths[model_type] is None else model_loader.load_model - action_fn(thread_data, model_type) - def make_images(req: GenerateImageRequest, task_data: TaskData, data_queue: queue.Queue, task_temp_images: list, step_callback): try: return _make_images_internal(req, task_data, data_queue, task_temp_images, step_callback) @@ -80,14 +59,6 @@ def _make_images_internal(req: GenerateImageRequest, task_data: TaskData, data_q return res -def resolve_model_paths(task_data: TaskData): - task_data.use_stable_diffusion_model = model_manager.resolve_model_to_use(task_data.use_stable_diffusion_model, model_type='stable-diffusion') - task_data.use_vae_model = model_manager.resolve_model_to_use(task_data.use_vae_model, model_type='vae') - task_data.use_hypernetwork_model = model_manager.resolve_model_to_use(task_data.use_hypernetwork_model, model_type='hypernetwork') - - if task_data.use_face_correction: task_data.use_face_correction = model_manager.resolve_model_to_use(task_data.use_face_correction, 'gfpgan') - if task_data.use_upscale: task_data.use_upscale = model_manager.resolve_model_to_use(task_data.use_upscale, 'gfpgan') - def generate_images(req: GenerateImageRequest, data_queue: queue.Queue, task_temp_images: list, step_callback, stream_image_progress: bool): log.info(req.to_metadata()) thread_data.temp_images.clear() diff --git a/ui/sd_internal/task_manager.py b/ui/sd_internal/task_manager.py index c167a5f9..5e6abce2 100644 --- a/ui/sd_internal/task_manager.py +++ b/ui/sd_internal/task_manager.py @@ -277,10 +277,10 @@ def thread_render(device): log.info(f'Session {task.task_data.session_id} sent cancel signal for task {id(task)}') current_state = ServerStates.LoadingModel - runtime2.reload_models_if_necessary(task.task_data) + model_manager.resolve_model_paths(task.task_data) + model_manager.reload_models_if_necessary(runtime2.thread_data, task.task_data) current_state = ServerStates.Rendering - runtime2.resolve_model_paths(task.task_data) task.response = runtime2.make_images(task.render_request, task.task_data, task.buffer_queue, task.temp_images, step_callback) # Before looping back to the generator, mark cache as still alive. task_cache.keep(id(task), TASK_TTL)