From 107112d1c485fdaaed048c55859d5001c574ca0e Mon Sep 17 00:00:00 2001 From: cmdr2 Date: Sat, 24 Dec 2022 12:37:20 +0530 Subject: [PATCH] Integration bugs --- ui/easydiffusion/model_manager.py | 19 ++++++++----------- ui/easydiffusion/renderer.py | 4 ++-- 2 files changed, 10 insertions(+), 13 deletions(-) diff --git a/ui/easydiffusion/model_manager.py b/ui/easydiffusion/model_manager.py index 367312ca..e08328fe 100644 --- a/ui/easydiffusion/model_manager.py +++ b/ui/easydiffusion/model_manager.py @@ -5,7 +5,7 @@ from easydiffusion.types import TaskData from easydiffusion.utils import log from sdkit import Context -from sdkit.models import load_model, unload_model, get_known_model_info, scan_model +from sdkit.models import load_model, unload_model, get_model_info_from_db, scan_model from sdkit.utils import hash_file_quick KNOWN_MODEL_TYPES = ['stable-diffusion', 'vae', 'hypernetwork', 'gfpgan', 'realesrgan'] @@ -29,6 +29,7 @@ VRAM_USAGE_LEVEL_TO_OPTIMIZATIONS = { 'low': {'KEEP_ENTIRE_MODEL_IN_CPU'}, 'high': {}, } +MODELS_TO_LOAD_ON_START = ['stable-diffusion', 'vae', 'hypernetwork'] known_models = {} @@ -37,16 +38,12 @@ def init(): getModels() # run this once, to cache the picklescan results def load_default_models(context: Context): - # init default model paths - for model_type in KNOWN_MODEL_TYPES: - context.model_paths[model_type] = resolve_model_to_use(model_type=model_type) - set_vram_optimizations(context) - # load mandatory models - load_model(context, 'stable-diffusion', scan_model=False) # we've scanned them already - load_model(context, 'vae') - load_model(context, 'hypernetwork') + # init default model paths + for model_type in MODELS_TO_LOAD_ON_START: + context.model_paths[model_type] = resolve_model_to_use(model_type=model_type) + load_model(context, model_type) def unload_all(context: Context): for model_type in KNOWN_MODEL_TYPES: @@ -106,7 +103,7 @@ def reload_models_if_necessary(context: Context, task_data: TaskData): if 'stable-diffusion' in models_to_reload: quick_hash = hash_file_quick(models_to_reload['stable-diffusion']) - known_model_info = get_known_model_info(quick_hash=quick_hash) + known_model_info = get_model_info_from_db(quick_hash=quick_hash) for model_type, model_path_in_req in models_to_reload.items(): context.model_paths[model_type] = model_path_in_req @@ -120,7 +117,7 @@ def resolve_model_paths(task_data: TaskData): task_data.use_hypernetwork_model = resolve_model_to_use(task_data.use_hypernetwork_model, model_type='hypernetwork') if task_data.use_face_correction: task_data.use_face_correction = resolve_model_to_use(task_data.use_face_correction, 'gfpgan') - if task_data.use_upscale: task_data.use_upscale = resolve_model_to_use(task_data.use_upscale, 'gfpgan') + if task_data.use_upscale: task_data.use_upscale = resolve_model_to_use(task_data.use_upscale, 'realesrgan') def set_vram_optimizations(context: Context): config = app.getConfig() diff --git a/ui/easydiffusion/renderer.py b/ui/easydiffusion/renderer.py index 4b9847c0..b55f3667 100644 --- a/ui/easydiffusion/renderer.py +++ b/ui/easydiffusion/renderer.py @@ -73,8 +73,8 @@ def filter_images(task_data: TaskData, images: list, user_stopped): return images filters_to_apply = [] - if 'gfpgan' in task_data.use_face_correction.lower(): filters_to_apply.append('gfpgan') - if 'realesrgan' in task_data.use_face_correction.lower(): filters_to_apply.append('realesrgan') + if task_data.use_face_correction and 'gfpgan' in task_data.use_face_correction.lower(): filters_to_apply.append('gfpgan') + if task_data.use_upscale and 'realesrgan' in task_data.use_upscale.lower(): filters_to_apply.append('realesrgan') return apply_filters(context, filters_to_apply, images)