mirror of
https://github.com/easydiffusion/easydiffusion.git
synced 2024-11-23 00:33:28 +01:00
Simplify the logic for reloading gfpgan and realesrgan models (based on the request), using the code path used for the other model types
This commit is contained in:
parent
afb88616d8
commit
d03eed3859
@ -40,7 +40,7 @@ def init(device):
|
||||
device_manager.device_init(thread_data, device)
|
||||
|
||||
def destroy():
|
||||
for model_type in ('stable-diffusion', 'hypernetwork', 'gfpgan', 'realesrgan'):
|
||||
for model_type in model_manager.KNOWN_MODEL_TYPES:
|
||||
model_loader.unload_model(thread_data, model_type)
|
||||
|
||||
def load_default_models():
|
||||
@ -52,19 +52,26 @@ def load_default_models():
|
||||
model_loader.load_model(thread_data, 'stable-diffusion')
|
||||
|
||||
def reload_models_if_necessary(req: Request):
|
||||
model_paths_in_req = (
|
||||
('hypernetwork', req.use_hypernetwork_model),
|
||||
('gfpgan', req.use_face_correction),
|
||||
('realesrgan', req.use_upscale),
|
||||
)
|
||||
|
||||
if model_manager.is_sd_model_reload_necessary(thread_data, req):
|
||||
thread_data.model_paths['stable-diffusion'] = req.use_stable_diffusion_model
|
||||
thread_data.model_paths['vae'] = req.use_vae_model
|
||||
|
||||
model_loader.load_model(thread_data, 'stable-diffusion')
|
||||
|
||||
if thread_data.model_paths.get('hypernetwork') != req.use_hypernetwork_model:
|
||||
thread_data.model_paths['hypernetwork'] = req.use_hypernetwork_model
|
||||
for model_type, model_path_in_req in model_paths_in_req:
|
||||
if thread_data.model_paths.get(model_type) != model_path_in_req:
|
||||
thread_data.model_paths[model_type] = model_path_in_req
|
||||
|
||||
if thread_data.model_paths['hypernetwork'] is not None:
|
||||
model_loader.load_model(thread_data, 'hypernetwork')
|
||||
if thread_data.model_paths[model_type] is not None:
|
||||
model_loader.load_model(thread_data, model_type)
|
||||
else:
|
||||
model_loader.unload_model(thread_data, 'hypernetwork')
|
||||
model_loader.unload_model(thread_data, model_type)
|
||||
|
||||
def make_images(req: Request, data_queue: queue.Queue, task_temp_images: list, step_callback):
|
||||
try:
|
||||
@ -137,13 +144,13 @@ def apply_filters(args: dict, images: list, user_stopped, show_only_filtered_ima
|
||||
return images
|
||||
|
||||
filters = []
|
||||
if args['use_face_correction'].startswith('GFPGAN'): filters.append((image_filters.apply_gfpgan, model_manager.resolve_model_to_use(args['use_face_correction'], model_type='gfpgan')))
|
||||
if args['use_face_correction'].use_upscale.startswith('RealESRGAN'): filters.append((image_filters.apply_realesrgan, model_manager.resolve_model_to_use(args['use_upscale'], model_type='realesrgan')))
|
||||
if 'gfpgan' in args['use_face_correction'].lower(): filters.append(image_filters.apply_gfpgan)
|
||||
if 'realesrgan' in args['use_face_correction'].lower(): filters.append(image_filters.apply_realesrgan)
|
||||
|
||||
filtered_images = []
|
||||
for img, seed, _ in images:
|
||||
for filter_fn, filter_model_path in filters:
|
||||
img = filter_fn(thread_data, img, filter_model_path)
|
||||
for filter_fn in filters:
|
||||
img = filter_fn(thread_data, img)
|
||||
|
||||
filtered_images.append((img, seed, True))
|
||||
|
||||
|
@ -140,10 +140,16 @@ def ping(session_id:str=None):
|
||||
def render(req : task_manager.ImageRequest):
|
||||
try:
|
||||
app.save_model_to_config(req.use_stable_diffusion_model, req.use_vae_model, req.use_hypernetwork_model)
|
||||
|
||||
# resolve the model paths to use
|
||||
req.use_stable_diffusion_model = model_manager.resolve_model_to_use(req.use_stable_diffusion_model, model_type='stable-diffusion')
|
||||
req.use_vae_model = model_manager.resolve_model_to_use(req.use_vae_model, model_type='vae')
|
||||
req.use_hypernetwork_model = model_manager.resolve_model_to_use(req.use_hypernetwork_model, model_type='hypernetwork')
|
||||
|
||||
if req.use_face_correction: req.use_face_correction = model_manager.resolve_model_to_use(req.use_face_correction, 'gfpgan')
|
||||
if req.use_upscale: req.use_upscale = model_manager.resolve_model_to_use(req.use_upscale, 'gfpgan')
|
||||
|
||||
# enqueue the task
|
||||
new_task = task_manager.render(req)
|
||||
response = {
|
||||
'status': str(task_manager.current_state),
|
||||
|
Loading…
Reference in New Issue
Block a user