mirror of
https://github.com/easydiffusion/easydiffusion.git
synced 2025-01-14 10:19:46 +01:00
Integration bugs
This commit is contained in:
parent
d8543d1358
commit
107112d1c4
@ -5,7 +5,7 @@ from easydiffusion.types import TaskData
|
|||||||
from easydiffusion.utils import log
|
from easydiffusion.utils import log
|
||||||
|
|
||||||
from sdkit import Context
|
from sdkit import Context
|
||||||
from sdkit.models import load_model, unload_model, get_known_model_info, scan_model
|
from sdkit.models import load_model, unload_model, get_model_info_from_db, scan_model
|
||||||
from sdkit.utils import hash_file_quick
|
from sdkit.utils import hash_file_quick
|
||||||
|
|
||||||
KNOWN_MODEL_TYPES = ['stable-diffusion', 'vae', 'hypernetwork', 'gfpgan', 'realesrgan']
|
KNOWN_MODEL_TYPES = ['stable-diffusion', 'vae', 'hypernetwork', 'gfpgan', 'realesrgan']
|
||||||
@ -29,6 +29,7 @@ VRAM_USAGE_LEVEL_TO_OPTIMIZATIONS = {
|
|||||||
'low': {'KEEP_ENTIRE_MODEL_IN_CPU'},
|
'low': {'KEEP_ENTIRE_MODEL_IN_CPU'},
|
||||||
'high': {},
|
'high': {},
|
||||||
}
|
}
|
||||||
|
MODELS_TO_LOAD_ON_START = ['stable-diffusion', 'vae', 'hypernetwork']
|
||||||
|
|
||||||
known_models = {}
|
known_models = {}
|
||||||
|
|
||||||
@ -37,16 +38,12 @@ def init():
|
|||||||
getModels() # run this once, to cache the picklescan results
|
getModels() # run this once, to cache the picklescan results
|
||||||
|
|
||||||
def load_default_models(context: Context):
|
def load_default_models(context: Context):
|
||||||
# init default model paths
|
|
||||||
for model_type in KNOWN_MODEL_TYPES:
|
|
||||||
context.model_paths[model_type] = resolve_model_to_use(model_type=model_type)
|
|
||||||
|
|
||||||
set_vram_optimizations(context)
|
set_vram_optimizations(context)
|
||||||
|
|
||||||
# load mandatory models
|
# init default model paths
|
||||||
load_model(context, 'stable-diffusion', scan_model=False) # we've scanned them already
|
for model_type in MODELS_TO_LOAD_ON_START:
|
||||||
load_model(context, 'vae')
|
context.model_paths[model_type] = resolve_model_to_use(model_type=model_type)
|
||||||
load_model(context, 'hypernetwork')
|
load_model(context, model_type)
|
||||||
|
|
||||||
def unload_all(context: Context):
|
def unload_all(context: Context):
|
||||||
for model_type in KNOWN_MODEL_TYPES:
|
for model_type in KNOWN_MODEL_TYPES:
|
||||||
@ -106,7 +103,7 @@ def reload_models_if_necessary(context: Context, task_data: TaskData):
|
|||||||
|
|
||||||
if 'stable-diffusion' in models_to_reload:
|
if 'stable-diffusion' in models_to_reload:
|
||||||
quick_hash = hash_file_quick(models_to_reload['stable-diffusion'])
|
quick_hash = hash_file_quick(models_to_reload['stable-diffusion'])
|
||||||
known_model_info = get_known_model_info(quick_hash=quick_hash)
|
known_model_info = get_model_info_from_db(quick_hash=quick_hash)
|
||||||
|
|
||||||
for model_type, model_path_in_req in models_to_reload.items():
|
for model_type, model_path_in_req in models_to_reload.items():
|
||||||
context.model_paths[model_type] = model_path_in_req
|
context.model_paths[model_type] = model_path_in_req
|
||||||
@ -120,7 +117,7 @@ def resolve_model_paths(task_data: TaskData):
|
|||||||
task_data.use_hypernetwork_model = resolve_model_to_use(task_data.use_hypernetwork_model, model_type='hypernetwork')
|
task_data.use_hypernetwork_model = resolve_model_to_use(task_data.use_hypernetwork_model, model_type='hypernetwork')
|
||||||
|
|
||||||
if task_data.use_face_correction: task_data.use_face_correction = resolve_model_to_use(task_data.use_face_correction, 'gfpgan')
|
if task_data.use_face_correction: task_data.use_face_correction = resolve_model_to_use(task_data.use_face_correction, 'gfpgan')
|
||||||
if task_data.use_upscale: task_data.use_upscale = resolve_model_to_use(task_data.use_upscale, 'gfpgan')
|
if task_data.use_upscale: task_data.use_upscale = resolve_model_to_use(task_data.use_upscale, 'realesrgan')
|
||||||
|
|
||||||
def set_vram_optimizations(context: Context):
|
def set_vram_optimizations(context: Context):
|
||||||
config = app.getConfig()
|
config = app.getConfig()
|
||||||
|
@ -73,8 +73,8 @@ def filter_images(task_data: TaskData, images: list, user_stopped):
|
|||||||
return images
|
return images
|
||||||
|
|
||||||
filters_to_apply = []
|
filters_to_apply = []
|
||||||
if 'gfpgan' in task_data.use_face_correction.lower(): filters_to_apply.append('gfpgan')
|
if task_data.use_face_correction and 'gfpgan' in task_data.use_face_correction.lower(): filters_to_apply.append('gfpgan')
|
||||||
if 'realesrgan' in task_data.use_face_correction.lower(): filters_to_apply.append('realesrgan')
|
if task_data.use_upscale and 'realesrgan' in task_data.use_upscale.lower(): filters_to_apply.append('realesrgan')
|
||||||
|
|
||||||
return apply_filters(context, filters_to_apply, images)
|
return apply_filters(context, filters_to_apply, images)
|
||||||
|
|
||||||
|
Loading…
Reference in New Issue
Block a user