Simplify the logic for reloading gfpgan and realesrgan models (based on the request), using the code path used for the other model types

This commit is contained in:
cmdr2 2022-12-11 14:14:59 +05:30
parent afb88616d8
commit d03eed3859
2 changed files with 24 additions and 11 deletions

View File

@ -40,7 +40,7 @@ def init(device):
device_manager.device_init(thread_data, device)
def destroy():
for model_type in ('stable-diffusion', 'hypernetwork', 'gfpgan', 'realesrgan'):
for model_type in model_manager.KNOWN_MODEL_TYPES:
model_loader.unload_model(thread_data, model_type)
def load_default_models():
@ -52,19 +52,26 @@ def load_default_models():
model_loader.load_model(thread_data, 'stable-diffusion')
def reload_models_if_necessary(req: Request):
model_paths_in_req = (
('hypernetwork', req.use_hypernetwork_model),
('gfpgan', req.use_face_correction),
('realesrgan', req.use_upscale),
)
if model_manager.is_sd_model_reload_necessary(thread_data, req):
thread_data.model_paths['stable-diffusion'] = req.use_stable_diffusion_model
thread_data.model_paths['vae'] = req.use_vae_model
model_loader.load_model(thread_data, 'stable-diffusion')
if thread_data.model_paths.get('hypernetwork') != req.use_hypernetwork_model:
thread_data.model_paths['hypernetwork'] = req.use_hypernetwork_model
for model_type, model_path_in_req in model_paths_in_req:
if thread_data.model_paths.get(model_type) != model_path_in_req:
thread_data.model_paths[model_type] = model_path_in_req
if thread_data.model_paths['hypernetwork'] is not None:
model_loader.load_model(thread_data, 'hypernetwork')
else:
model_loader.unload_model(thread_data, 'hypernetwork')
if thread_data.model_paths[model_type] is not None:
model_loader.load_model(thread_data, model_type)
else:
model_loader.unload_model(thread_data, model_type)
def make_images(req: Request, data_queue: queue.Queue, task_temp_images: list, step_callback):
try:
@ -137,13 +144,13 @@ def apply_filters(args: dict, images: list, user_stopped, show_only_filtered_ima
return images
filters = []
if args['use_face_correction'].startswith('GFPGAN'): filters.append((image_filters.apply_gfpgan, model_manager.resolve_model_to_use(args['use_face_correction'], model_type='gfpgan')))
if args['use_face_correction'].use_upscale.startswith('RealESRGAN'): filters.append((image_filters.apply_realesrgan, model_manager.resolve_model_to_use(args['use_upscale'], model_type='realesrgan')))
if 'gfpgan' in args['use_face_correction'].lower(): filters.append(image_filters.apply_gfpgan)
if 'realesrgan' in args['use_face_correction'].lower(): filters.append(image_filters.apply_realesrgan)
filtered_images = []
for img, seed, _ in images:
for filter_fn, filter_model_path in filters:
img = filter_fn(thread_data, img, filter_model_path)
for filter_fn in filters:
img = filter_fn(thread_data, img)
filtered_images.append((img, seed, True))

View File

@ -140,10 +140,16 @@ def ping(session_id:str=None):
def render(req : task_manager.ImageRequest):
try:
app.save_model_to_config(req.use_stable_diffusion_model, req.use_vae_model, req.use_hypernetwork_model)
# resolve the model paths to use
req.use_stable_diffusion_model = model_manager.resolve_model_to_use(req.use_stable_diffusion_model, model_type='stable-diffusion')
req.use_vae_model = model_manager.resolve_model_to_use(req.use_vae_model, model_type='vae')
req.use_hypernetwork_model = model_manager.resolve_model_to_use(req.use_hypernetwork_model, model_type='hypernetwork')
if req.use_face_correction: req.use_face_correction = model_manager.resolve_model_to_use(req.use_face_correction, 'gfpgan')
if req.use_upscale: req.use_upscale = model_manager.resolve_model_to_use(req.use_upscale, 'gfpgan')
# enqueue the task
new_task = task_manager.render(req)
response = {
'status': str(task_manager.current_state),