diff --git a/ui/sd_internal/model_manager.py b/ui/sd_internal/model_manager.py index c7d22e1d..6cf9428a 100644 --- a/ui/sd_internal/model_manager.py +++ b/ui/sd_internal/model_manager.py @@ -28,7 +28,7 @@ def init(): make_model_folders() getModels() # run this once, to cache the picklescan results -def resolve_model_to_use(model_name:str, model_type:str): +def resolve_model_to_use(model_name:str=None, model_type:str=None): model_extensions = MODEL_EXTENSIONS.get(model_type, []) default_models = DEFAULT_MODELS.get(model_type, []) config = app.getConfig() @@ -72,21 +72,6 @@ def resolve_model_to_use(model_name:str, model_type:str): return None -def resolve_sd_model_to_use(model_name:str=None): - return resolve_model_to_use(model_name, model_type='stable-diffusion') - -def resolve_vae_model_to_use(model_name:str=None): - return resolve_model_to_use(model_name, model_type='vae') - -def resolve_hypernetwork_model_to_use(model_name:str=None): - return resolve_model_to_use(model_name, model_type='hypernetwork') - -def resolve_gfpgan_model_to_use(model_name:str=None): - return resolve_model_to_use(model_name, model_type='gfpgan') - -def resolve_realesrgan_model_to_use(model_name:str=None): - return resolve_model_to_use(model_name, model_type='realesrgan') - def make_model_folders(): for model_type in KNOWN_MODEL_TYPES: model_dir_path = os.path.join(app.MODELS_DIR, model_type) diff --git a/ui/sd_internal/runtime2.py b/ui/sd_internal/runtime2.py index 0af0ead4..3d4ff8ff 100644 --- a/ui/sd_internal/runtime2.py +++ b/ui/sd_internal/runtime2.py @@ -39,18 +39,13 @@ def init(device): init_and_load_default_models() def destroy(): - model_loader.unload_model(thread_data, 'stable-diffusion') - model_loader.unload_model(thread_data, 'gfpgan') - model_loader.unload_model(thread_data, 'realesrgan') - model_loader.unload_model(thread_data, 'hypernetwork') + for model_type in ('stable-diffusion', 'hypernetwork', 'gfpgan', 'realesrgan'): + model_loader.unload_model(thread_data, model_type) def init_and_load_default_models(): # init default model paths - thread_data.model_paths['stable-diffusion'] = model_manager.resolve_sd_model_to_use() - thread_data.model_paths['vae'] = model_manager.resolve_vae_model_to_use() - thread_data.model_paths['hypernetwork'] = model_manager.resolve_hypernetwork_model_to_use() - thread_data.model_paths['gfpgan'] = model_manager.resolve_gfpgan_model_to_use() - thread_data.model_paths['realesrgan'] = model_manager.resolve_realesrgan_model_to_use() + for model_type in ('stable-diffusion', 'vae', 'hypernetwork', 'gfpgan', 'realesrgan'): + thread_data.model_paths[model_type] = model_manager.resolve_model_to_use(model_type=model_type) # load mandatory models model_loader.load_model(thread_data, 'stable-diffusion') @@ -119,8 +114,8 @@ def apply_filters(req: Request, images: list, user_stopped): return images filters = [] - if req.use_face_correction.startswith('GFPGAN'): filters.append((image_filters.apply_gfpgan, model_manager.resolve_gfpgan_model_to_use(req.use_face_correction))) - if req.use_upscale.startswith('RealESRGAN'): filters.append((image_filters.apply_realesrgan, model_manager.resolve_realesrgan_model_to_use(req.use_upscale))) + if req.use_face_correction.startswith('GFPGAN'): filters.append((image_filters.apply_gfpgan, model_manager.resolve_model_to_use(req.use_face_correction, model_type='gfpgan'))) + if req.use_upscale.startswith('RealESRGAN'): filters.append((image_filters.apply_realesrgan, model_manager.resolve_model_to_use(req.use_upscale, model_type='realesrgan'))) filtered_images = [] for img, seed, _ in images: diff --git a/ui/sd_internal/task_manager.py b/ui/sd_internal/task_manager.py index c6ab9737..aec79239 100644 --- a/ui/sd_internal/task_manager.py +++ b/ui/sd_internal/task_manager.py @@ -11,10 +11,10 @@ TASK_TTL = 15 * 60 # seconds, Discard last session's task timeout import torch import queue, threading, time, weakref -from typing import Any, Generator, Hashable, Optional, Union +from typing import Any, Hashable from pydantic import BaseModel -from sd_internal import Request, Response, runtime, device_manager +from sd_internal import Request, device_manager THREAD_NAME_PREFIX = 'Runtime-Render/' ERR_LOCK_FAILED = ' failed to acquire lock within timeout.' diff --git a/ui/server.py b/ui/server.py index b3c54342..ca5cac79 100644 --- a/ui/server.py +++ b/ui/server.py @@ -136,9 +136,9 @@ def ping(session_id:str=None): def render(req : task_manager.ImageRequest): try: app.save_model_to_config(req.use_stable_diffusion_model, req.use_vae_model, req.use_hypernetwork_model) - req.use_stable_diffusion_model = model_manager.resolve_sd_model_to_use(req.use_stable_diffusion_model) - req.use_vae_model = model_manager.resolve_vae_model_to_use(req.use_vae_model) - req.use_hypernetwork_model = model_manager.resolve_hypernetwork_model_to_use(req.use_hypernetwork_model) + req.use_stable_diffusion_model = model_manager.resolve_model_to_use(req.use_stable_diffusion_model, model_type='stable-diffusion') + req.use_vae_model = model_manager.resolve_model_to_use(req.use_vae_model, model_type='vae') + req.use_hypernetwork_model = model_manager.resolve_model_to_use(req.use_hypernetwork_model, model_type='hypernetwork') new_task = task_manager.render(req) response = {