diff --git a/ui/sd_internal/app.py b/ui/sd_internal/app.py index 09242319..00d2e718 100644 --- a/ui/sd_internal/app.py +++ b/ui/sd_internal/app.py @@ -18,10 +18,6 @@ USER_UI_PLUGINS_DIR = os.path.abspath(os.path.join(SD_DIR, '..', 'plugins', 'ui' CORE_UI_PLUGINS_DIR = os.path.abspath(os.path.join(SD_UI_DIR, 'plugins', 'ui')) UI_PLUGINS_SOURCES = ((CORE_UI_PLUGINS_DIR, 'core'), (USER_UI_PLUGINS_DIR, 'user')) -STABLE_DIFFUSION_MODEL_EXTENSIONS = ['.ckpt', '.safetensors'] -VAE_MODEL_EXTENSIONS = ['.vae.pt', '.ckpt'] -HYPERNETWORK_MODEL_EXTENSIONS = ['.pt'] - OUTPUT_DIRNAME = "Stable Diffusion UI" # in the user's home folder TASK_TTL = 15 * 60 # Discard last session's task timeout APP_CONFIG_DEFAULTS = { diff --git a/ui/sd_internal/model_manager.py b/ui/sd_internal/model_manager.py index 3941b130..906038e1 100644 --- a/ui/sd_internal/model_manager.py +++ b/ui/sd_internal/model_manager.py @@ -4,6 +4,10 @@ from sd_internal import app import picklescan.scanner import rich +STABLE_DIFFUSION_MODEL_EXTENSIONS = ['.ckpt', '.safetensors'] +VAE_MODEL_EXTENSIONS = ['.vae.pt', '.ckpt'] +HYPERNETWORK_MODEL_EXTENSIONS = ['.pt'] + default_model_to_load = None default_vae_to_load = None default_hypernetwork_to_load = None @@ -59,22 +63,17 @@ def resolve_model_to_use(model_name:str, model_type:str, model_dir:str, model_ex print(f'Could not find the configured custom model {model_name}{model_extension}. Using the default one: {default_model_path}{model_extension}') return default_model_path + model_extension - raise Exception('No valid models found.') + print(f'No valid models found for model_name: {model_name}') + return None def resolve_ckpt_to_use(model_name:str=None): - return resolve_model_to_use(model_name, model_type='stable-diffusion', model_dir='stable-diffusion', model_extensions=app.STABLE_DIFFUSION_MODEL_EXTENSIONS, default_models=app.APP_CONFIG_DEFAULT_MODELS) + return resolve_model_to_use(model_name, model_type='stable-diffusion', model_dir='stable-diffusion', model_extensions=STABLE_DIFFUSION_MODEL_EXTENSIONS, default_models=app.APP_CONFIG_DEFAULT_MODELS) def resolve_vae_to_use(model_name:str=None): - try: - return resolve_model_to_use(model_name, model_type='vae', model_dir='vae', model_extensions=app.VAE_MODEL_EXTENSIONS, default_models=[]) - except: - return None + return resolve_model_to_use(model_name, model_type='vae', model_dir='vae', model_extensions=VAE_MODEL_EXTENSIONS, default_models=[]) def resolve_hypernetwork_to_use(model_name:str=None): - try: - return resolve_model_to_use(model_name, model_type='hypernetwork', model_dir='hypernetwork', model_extensions=app.HYPERNETWORK_MODEL_EXTENSIONS, default_models=[]) - except: - return None + return resolve_model_to_use(model_name, model_type='hypernetwork', model_dir='hypernetwork', model_extensions=HYPERNETWORK_MODEL_EXTENSIONS, default_models=[]) def is_malicious_model(file_path): try: @@ -129,9 +128,9 @@ def getModels(): models['options'][model_type].sort() # custom models - listModels(models_dirname='stable-diffusion', model_type='stable-diffusion', model_extensions=app.STABLE_DIFFUSION_MODEL_EXTENSIONS) - listModels(models_dirname='vae', model_type='vae', model_extensions=app.VAE_MODEL_EXTENSIONS) - listModels(models_dirname='hypernetwork', model_type='hypernetwork', model_extensions=app.HYPERNETWORK_MODEL_EXTENSIONS) + listModels(models_dirname='stable-diffusion', model_type='stable-diffusion', model_extensions=STABLE_DIFFUSION_MODEL_EXTENSIONS) + listModels(models_dirname='vae', model_type='vae', model_extensions=VAE_MODEL_EXTENSIONS) + listModels(models_dirname='hypernetwork', model_type='hypernetwork', model_extensions=HYPERNETWORK_MODEL_EXTENSIONS) # legacy custom_weight_path = os.path.join(app.SD_DIR, 'custom-model.ckpt') diff --git a/ui/sd_internal/runtime2.py b/ui/sd_internal/runtime2.py index 32befcde..fc8d944d 100644 --- a/ui/sd_internal/runtime2.py +++ b/ui/sd_internal/runtime2.py @@ -1,7 +1,8 @@ import threading import queue -from sd_internal import device_manager, Request, Response, Image as ResponseImage +from sd_internal import device_manager, model_manager +from sd_internal import Request, Response, Image as ResponseImage from modules import model_loader, image_generator, image_utils @@ -18,7 +19,7 @@ def init(device): thread_data.temp_images = {} thread_data.models = {} - thread_data.loaded_model_paths = {} + thread_data.model_paths = {} thread_data.device = None thread_data.device_name = None @@ -27,14 +28,34 @@ def init(device): device_manager.device_init(thread_data, device) - reload_models() + load_default_models() def destroy(): model_loader.unload_sd_model(thread_data) model_loader.unload_gfpgan_model(thread_data) model_loader.unload_realesrgan_model(thread_data) -def reload_models(req: Request=None): +def load_default_models(): + thread_data.model_paths['stable-diffusion'] = model_manager.default_model_to_load + thread_data.model_paths['vae'] = model_manager.default_vae_to_load + + model_loader.load_sd_model(thread_data) + +def reload_models_if_necessary(req: Request=None): + needs_model_reload = False + if 'stable-diffusion' not in thread_data.models or thread_data.ckpt_file != req.use_stable_diffusion_model or thread_data.vae_file != req.use_vae_model: + thread_data.ckpt_file = req.use_stable_diffusion_model + thread_data.vae_file = req.use_vae_model + needs_model_reload = True + + if thread_data.device != 'cpu': + if (thread_data.precision == 'autocast' and (req.use_full_precision or not thread_data.model_is_half)) or \ + (thread_data.precision == 'full' and not req.use_full_precision and not thread_data.force_full_precision): + thread_data.precision = 'full' if req.use_full_precision else 'autocast' + needs_model_reload = True + + return needs_model_reload + if is_hypernetwork_reload_necessary(task.request): current_state = ServerStates.LoadingModel runtime.reload_hypernetwork() diff --git a/ui/sd_internal/task_manager.py b/ui/sd_internal/task_manager.py index 7a58ac14..04cc9a69 100644 --- a/ui/sd_internal/task_manager.py +++ b/ui/sd_internal/task_manager.py @@ -310,9 +310,6 @@ def thread_render(device): print(f'Session {task.request.session_id} starting task {id(task)} on {runtime2.thread_data.device_name}') if not task.lock.acquire(blocking=False): raise Exception('Got locked task from queue.') try: - current_state = ServerStates.LoadingModel - runtime2.reload_models(task.request) - def step_callback(): global current_state_error @@ -323,6 +320,9 @@ def thread_render(device): current_state_error = None print(f'Session {task.request.session_id} sent cancel signal for task {id(task)}') + current_state = ServerStates.LoadingModel + runtime2.reload_models_if_necessary(task.request) + current_state = ServerStates.Rendering task.response = runtime2.make_image(task.request, task.buffer_queue, task.temp_images, step_callback) # Before looping back to the generator, mark cache as still alive.