mirror of
https://github.com/easydiffusion/easydiffusion.git
synced 2025-01-13 17:58:26 +01:00
Integration bugs
This commit is contained in:
parent
d8543d1358
commit
107112d1c4
@ -5,7 +5,7 @@ from easydiffusion.types import TaskData
|
||||
from easydiffusion.utils import log
|
||||
|
||||
from sdkit import Context
|
||||
from sdkit.models import load_model, unload_model, get_known_model_info, scan_model
|
||||
from sdkit.models import load_model, unload_model, get_model_info_from_db, scan_model
|
||||
from sdkit.utils import hash_file_quick
|
||||
|
||||
KNOWN_MODEL_TYPES = ['stable-diffusion', 'vae', 'hypernetwork', 'gfpgan', 'realesrgan']
|
||||
@ -29,6 +29,7 @@ VRAM_USAGE_LEVEL_TO_OPTIMIZATIONS = {
|
||||
'low': {'KEEP_ENTIRE_MODEL_IN_CPU'},
|
||||
'high': {},
|
||||
}
|
||||
MODELS_TO_LOAD_ON_START = ['stable-diffusion', 'vae', 'hypernetwork']
|
||||
|
||||
known_models = {}
|
||||
|
||||
@ -37,16 +38,12 @@ def init():
|
||||
getModels() # run this once, to cache the picklescan results
|
||||
|
||||
def load_default_models(context: Context):
|
||||
# init default model paths
|
||||
for model_type in KNOWN_MODEL_TYPES:
|
||||
context.model_paths[model_type] = resolve_model_to_use(model_type=model_type)
|
||||
|
||||
set_vram_optimizations(context)
|
||||
|
||||
# load mandatory models
|
||||
load_model(context, 'stable-diffusion', scan_model=False) # we've scanned them already
|
||||
load_model(context, 'vae')
|
||||
load_model(context, 'hypernetwork')
|
||||
# init default model paths
|
||||
for model_type in MODELS_TO_LOAD_ON_START:
|
||||
context.model_paths[model_type] = resolve_model_to_use(model_type=model_type)
|
||||
load_model(context, model_type)
|
||||
|
||||
def unload_all(context: Context):
|
||||
for model_type in KNOWN_MODEL_TYPES:
|
||||
@ -106,7 +103,7 @@ def reload_models_if_necessary(context: Context, task_data: TaskData):
|
||||
|
||||
if 'stable-diffusion' in models_to_reload:
|
||||
quick_hash = hash_file_quick(models_to_reload['stable-diffusion'])
|
||||
known_model_info = get_known_model_info(quick_hash=quick_hash)
|
||||
known_model_info = get_model_info_from_db(quick_hash=quick_hash)
|
||||
|
||||
for model_type, model_path_in_req in models_to_reload.items():
|
||||
context.model_paths[model_type] = model_path_in_req
|
||||
@ -120,7 +117,7 @@ def resolve_model_paths(task_data: TaskData):
|
||||
task_data.use_hypernetwork_model = resolve_model_to_use(task_data.use_hypernetwork_model, model_type='hypernetwork')
|
||||
|
||||
if task_data.use_face_correction: task_data.use_face_correction = resolve_model_to_use(task_data.use_face_correction, 'gfpgan')
|
||||
if task_data.use_upscale: task_data.use_upscale = resolve_model_to_use(task_data.use_upscale, 'gfpgan')
|
||||
if task_data.use_upscale: task_data.use_upscale = resolve_model_to_use(task_data.use_upscale, 'realesrgan')
|
||||
|
||||
def set_vram_optimizations(context: Context):
|
||||
config = app.getConfig()
|
||||
|
@ -73,8 +73,8 @@ def filter_images(task_data: TaskData, images: list, user_stopped):
|
||||
return images
|
||||
|
||||
filters_to_apply = []
|
||||
if 'gfpgan' in task_data.use_face_correction.lower(): filters_to_apply.append('gfpgan')
|
||||
if 'realesrgan' in task_data.use_face_correction.lower(): filters_to_apply.append('realesrgan')
|
||||
if task_data.use_face_correction and 'gfpgan' in task_data.use_face_correction.lower(): filters_to_apply.append('gfpgan')
|
||||
if task_data.use_upscale and 'realesrgan' in task_data.use_upscale.lower(): filters_to_apply.append('realesrgan')
|
||||
|
||||
return apply_filters(context, filters_to_apply, images)
|
||||
|
||||
|
Loading…
Reference in New Issue
Block a user