Integration bugs

This commit is contained in:
cmdr2 2022-12-24 12:37:20 +05:30
parent d8543d1358
commit 107112d1c4
2 changed files with 10 additions and 13 deletions

View File

@ -5,7 +5,7 @@ from easydiffusion.types import TaskData
from easydiffusion.utils import log
from sdkit import Context
from sdkit.models import load_model, unload_model, get_known_model_info, scan_model
from sdkit.models import load_model, unload_model, get_model_info_from_db, scan_model
from sdkit.utils import hash_file_quick
KNOWN_MODEL_TYPES = ['stable-diffusion', 'vae', 'hypernetwork', 'gfpgan', 'realesrgan']
@ -29,6 +29,7 @@ VRAM_USAGE_LEVEL_TO_OPTIMIZATIONS = {
'low': {'KEEP_ENTIRE_MODEL_IN_CPU'},
'high': {},
}
MODELS_TO_LOAD_ON_START = ['stable-diffusion', 'vae', 'hypernetwork']
known_models = {}
@ -37,16 +38,12 @@ def init():
getModels() # run this once, to cache the picklescan results
def load_default_models(context: Context):
# init default model paths
for model_type in KNOWN_MODEL_TYPES:
context.model_paths[model_type] = resolve_model_to_use(model_type=model_type)
set_vram_optimizations(context)
# load mandatory models
load_model(context, 'stable-diffusion', scan_model=False) # we've scanned them already
load_model(context, 'vae')
load_model(context, 'hypernetwork')
# init default model paths
for model_type in MODELS_TO_LOAD_ON_START:
context.model_paths[model_type] = resolve_model_to_use(model_type=model_type)
load_model(context, model_type)
def unload_all(context: Context):
for model_type in KNOWN_MODEL_TYPES:
@ -106,7 +103,7 @@ def reload_models_if_necessary(context: Context, task_data: TaskData):
if 'stable-diffusion' in models_to_reload:
quick_hash = hash_file_quick(models_to_reload['stable-diffusion'])
known_model_info = get_known_model_info(quick_hash=quick_hash)
known_model_info = get_model_info_from_db(quick_hash=quick_hash)
for model_type, model_path_in_req in models_to_reload.items():
context.model_paths[model_type] = model_path_in_req
@ -120,7 +117,7 @@ def resolve_model_paths(task_data: TaskData):
task_data.use_hypernetwork_model = resolve_model_to_use(task_data.use_hypernetwork_model, model_type='hypernetwork')
if task_data.use_face_correction: task_data.use_face_correction = resolve_model_to_use(task_data.use_face_correction, 'gfpgan')
if task_data.use_upscale: task_data.use_upscale = resolve_model_to_use(task_data.use_upscale, 'gfpgan')
if task_data.use_upscale: task_data.use_upscale = resolve_model_to_use(task_data.use_upscale, 'realesrgan')
def set_vram_optimizations(context: Context):
config = app.getConfig()

View File

@ -73,8 +73,8 @@ def filter_images(task_data: TaskData, images: list, user_stopped):
return images
filters_to_apply = []
if 'gfpgan' in task_data.use_face_correction.lower(): filters_to_apply.append('gfpgan')
if 'realesrgan' in task_data.use_face_correction.lower(): filters_to_apply.append('realesrgan')
if task_data.use_face_correction and 'gfpgan' in task_data.use_face_correction.lower(): filters_to_apply.append('gfpgan')
if task_data.use_upscale and 'realesrgan' in task_data.use_upscale.lower(): filters_to_apply.append('realesrgan')
return apply_filters(context, filters_to_apply, images)