Integration bugs

This commit is contained in:
cmdr2 2022-12-24 12:37:20 +05:30
parent d8543d1358
commit 107112d1c4
2 changed files with 10 additions and 13 deletions

View File

@ -5,7 +5,7 @@ from easydiffusion.types import TaskData
from easydiffusion.utils import log from easydiffusion.utils import log
from sdkit import Context from sdkit import Context
from sdkit.models import load_model, unload_model, get_known_model_info, scan_model from sdkit.models import load_model, unload_model, get_model_info_from_db, scan_model
from sdkit.utils import hash_file_quick from sdkit.utils import hash_file_quick
KNOWN_MODEL_TYPES = ['stable-diffusion', 'vae', 'hypernetwork', 'gfpgan', 'realesrgan'] KNOWN_MODEL_TYPES = ['stable-diffusion', 'vae', 'hypernetwork', 'gfpgan', 'realesrgan']
@ -29,6 +29,7 @@ VRAM_USAGE_LEVEL_TO_OPTIMIZATIONS = {
'low': {'KEEP_ENTIRE_MODEL_IN_CPU'}, 'low': {'KEEP_ENTIRE_MODEL_IN_CPU'},
'high': {}, 'high': {},
} }
MODELS_TO_LOAD_ON_START = ['stable-diffusion', 'vae', 'hypernetwork']
known_models = {} known_models = {}
@ -37,16 +38,12 @@ def init():
getModels() # run this once, to cache the picklescan results getModels() # run this once, to cache the picklescan results
def load_default_models(context: Context): def load_default_models(context: Context):
# init default model paths
for model_type in KNOWN_MODEL_TYPES:
context.model_paths[model_type] = resolve_model_to_use(model_type=model_type)
set_vram_optimizations(context) set_vram_optimizations(context)
# load mandatory models # init default model paths
load_model(context, 'stable-diffusion', scan_model=False) # we've scanned them already for model_type in MODELS_TO_LOAD_ON_START:
load_model(context, 'vae') context.model_paths[model_type] = resolve_model_to_use(model_type=model_type)
load_model(context, 'hypernetwork') load_model(context, model_type)
def unload_all(context: Context): def unload_all(context: Context):
for model_type in KNOWN_MODEL_TYPES: for model_type in KNOWN_MODEL_TYPES:
@ -106,7 +103,7 @@ def reload_models_if_necessary(context: Context, task_data: TaskData):
if 'stable-diffusion' in models_to_reload: if 'stable-diffusion' in models_to_reload:
quick_hash = hash_file_quick(models_to_reload['stable-diffusion']) quick_hash = hash_file_quick(models_to_reload['stable-diffusion'])
known_model_info = get_known_model_info(quick_hash=quick_hash) known_model_info = get_model_info_from_db(quick_hash=quick_hash)
for model_type, model_path_in_req in models_to_reload.items(): for model_type, model_path_in_req in models_to_reload.items():
context.model_paths[model_type] = model_path_in_req context.model_paths[model_type] = model_path_in_req
@ -120,7 +117,7 @@ def resolve_model_paths(task_data: TaskData):
task_data.use_hypernetwork_model = resolve_model_to_use(task_data.use_hypernetwork_model, model_type='hypernetwork') task_data.use_hypernetwork_model = resolve_model_to_use(task_data.use_hypernetwork_model, model_type='hypernetwork')
if task_data.use_face_correction: task_data.use_face_correction = resolve_model_to_use(task_data.use_face_correction, 'gfpgan') if task_data.use_face_correction: task_data.use_face_correction = resolve_model_to_use(task_data.use_face_correction, 'gfpgan')
if task_data.use_upscale: task_data.use_upscale = resolve_model_to_use(task_data.use_upscale, 'gfpgan') if task_data.use_upscale: task_data.use_upscale = resolve_model_to_use(task_data.use_upscale, 'realesrgan')
def set_vram_optimizations(context: Context): def set_vram_optimizations(context: Context):
config = app.getConfig() config = app.getConfig()

View File

@ -73,8 +73,8 @@ def filter_images(task_data: TaskData, images: list, user_stopped):
return images return images
filters_to_apply = [] filters_to_apply = []
if 'gfpgan' in task_data.use_face_correction.lower(): filters_to_apply.append('gfpgan') if task_data.use_face_correction and 'gfpgan' in task_data.use_face_correction.lower(): filters_to_apply.append('gfpgan')
if 'realesrgan' in task_data.use_face_correction.lower(): filters_to_apply.append('realesrgan') if task_data.use_upscale and 'realesrgan' in task_data.use_upscale.lower(): filters_to_apply.append('realesrgan')
return apply_filters(context, filters_to_apply, images) return apply_filters(context, filters_to_apply, images)