Simplify the logic for reloading gfpgan and realesrgan models (based on the request), using the code path used for the other model types

This commit is contained in:
cmdr2 2022-12-11 14:14:59 +05:30
parent afb88616d8
commit d03eed3859
2 changed files with 24 additions and 11 deletions

View File

@ -40,7 +40,7 @@ def init(device):
device_manager.device_init(thread_data, device) device_manager.device_init(thread_data, device)
def destroy(): def destroy():
for model_type in ('stable-diffusion', 'hypernetwork', 'gfpgan', 'realesrgan'): for model_type in model_manager.KNOWN_MODEL_TYPES:
model_loader.unload_model(thread_data, model_type) model_loader.unload_model(thread_data, model_type)
def load_default_models(): def load_default_models():
@ -52,19 +52,26 @@ def load_default_models():
model_loader.load_model(thread_data, 'stable-diffusion') model_loader.load_model(thread_data, 'stable-diffusion')
def reload_models_if_necessary(req: Request): def reload_models_if_necessary(req: Request):
model_paths_in_req = (
('hypernetwork', req.use_hypernetwork_model),
('gfpgan', req.use_face_correction),
('realesrgan', req.use_upscale),
)
if model_manager.is_sd_model_reload_necessary(thread_data, req): if model_manager.is_sd_model_reload_necessary(thread_data, req):
thread_data.model_paths['stable-diffusion'] = req.use_stable_diffusion_model thread_data.model_paths['stable-diffusion'] = req.use_stable_diffusion_model
thread_data.model_paths['vae'] = req.use_vae_model thread_data.model_paths['vae'] = req.use_vae_model
model_loader.load_model(thread_data, 'stable-diffusion') model_loader.load_model(thread_data, 'stable-diffusion')
if thread_data.model_paths.get('hypernetwork') != req.use_hypernetwork_model: for model_type, model_path_in_req in model_paths_in_req:
thread_data.model_paths['hypernetwork'] = req.use_hypernetwork_model if thread_data.model_paths.get(model_type) != model_path_in_req:
thread_data.model_paths[model_type] = model_path_in_req
if thread_data.model_paths['hypernetwork'] is not None: if thread_data.model_paths[model_type] is not None:
model_loader.load_model(thread_data, 'hypernetwork') model_loader.load_model(thread_data, model_type)
else: else:
model_loader.unload_model(thread_data, 'hypernetwork') model_loader.unload_model(thread_data, model_type)
def make_images(req: Request, data_queue: queue.Queue, task_temp_images: list, step_callback): def make_images(req: Request, data_queue: queue.Queue, task_temp_images: list, step_callback):
try: try:
@ -137,13 +144,13 @@ def apply_filters(args: dict, images: list, user_stopped, show_only_filtered_ima
return images return images
filters = [] filters = []
if args['use_face_correction'].startswith('GFPGAN'): filters.append((image_filters.apply_gfpgan, model_manager.resolve_model_to_use(args['use_face_correction'], model_type='gfpgan'))) if 'gfpgan' in args['use_face_correction'].lower(): filters.append(image_filters.apply_gfpgan)
if args['use_face_correction'].use_upscale.startswith('RealESRGAN'): filters.append((image_filters.apply_realesrgan, model_manager.resolve_model_to_use(args['use_upscale'], model_type='realesrgan'))) if 'realesrgan' in args['use_face_correction'].lower(): filters.append(image_filters.apply_realesrgan)
filtered_images = [] filtered_images = []
for img, seed, _ in images: for img, seed, _ in images:
for filter_fn, filter_model_path in filters: for filter_fn in filters:
img = filter_fn(thread_data, img, filter_model_path) img = filter_fn(thread_data, img)
filtered_images.append((img, seed, True)) filtered_images.append((img, seed, True))

View File

@ -140,10 +140,16 @@ def ping(session_id:str=None):
def render(req : task_manager.ImageRequest): def render(req : task_manager.ImageRequest):
try: try:
app.save_model_to_config(req.use_stable_diffusion_model, req.use_vae_model, req.use_hypernetwork_model) app.save_model_to_config(req.use_stable_diffusion_model, req.use_vae_model, req.use_hypernetwork_model)
# resolve the model paths to use
req.use_stable_diffusion_model = model_manager.resolve_model_to_use(req.use_stable_diffusion_model, model_type='stable-diffusion') req.use_stable_diffusion_model = model_manager.resolve_model_to_use(req.use_stable_diffusion_model, model_type='stable-diffusion')
req.use_vae_model = model_manager.resolve_model_to_use(req.use_vae_model, model_type='vae') req.use_vae_model = model_manager.resolve_model_to_use(req.use_vae_model, model_type='vae')
req.use_hypernetwork_model = model_manager.resolve_model_to_use(req.use_hypernetwork_model, model_type='hypernetwork') req.use_hypernetwork_model = model_manager.resolve_model_to_use(req.use_hypernetwork_model, model_type='hypernetwork')
if req.use_face_correction: req.use_face_correction = model_manager.resolve_model_to_use(req.use_face_correction, 'gfpgan')
if req.use_upscale: req.use_upscale = model_manager.resolve_model_to_use(req.use_upscale, 'gfpgan')
# enqueue the task
new_task = task_manager.render(req) new_task = task_manager.render(req)
response = { response = {
'status': str(task_manager.current_state), 'status': str(task_manager.current_state),