diff --git a/ui/sd_internal/runtime2.py b/ui/sd_internal/runtime2.py index 06c6fdb6..b7ce6fe5 100644 --- a/ui/sd_internal/runtime2.py +++ b/ui/sd_internal/runtime2.py @@ -40,7 +40,7 @@ def init(device): device_manager.device_init(thread_data, device) def destroy(): - for model_type in ('stable-diffusion', 'hypernetwork', 'gfpgan', 'realesrgan'): + for model_type in model_manager.KNOWN_MODEL_TYPES: model_loader.unload_model(thread_data, model_type) def load_default_models(): @@ -52,19 +52,26 @@ def load_default_models(): model_loader.load_model(thread_data, 'stable-diffusion') def reload_models_if_necessary(req: Request): + model_paths_in_req = ( + ('hypernetwork', req.use_hypernetwork_model), + ('gfpgan', req.use_face_correction), + ('realesrgan', req.use_upscale), + ) + if model_manager.is_sd_model_reload_necessary(thread_data, req): thread_data.model_paths['stable-diffusion'] = req.use_stable_diffusion_model thread_data.model_paths['vae'] = req.use_vae_model model_loader.load_model(thread_data, 'stable-diffusion') - if thread_data.model_paths.get('hypernetwork') != req.use_hypernetwork_model: - thread_data.model_paths['hypernetwork'] = req.use_hypernetwork_model + for model_type, model_path_in_req in model_paths_in_req: + if thread_data.model_paths.get(model_type) != model_path_in_req: + thread_data.model_paths[model_type] = model_path_in_req - if thread_data.model_paths['hypernetwork'] is not None: - model_loader.load_model(thread_data, 'hypernetwork') - else: - model_loader.unload_model(thread_data, 'hypernetwork') + if thread_data.model_paths[model_type] is not None: + model_loader.load_model(thread_data, model_type) + else: + model_loader.unload_model(thread_data, model_type) def make_images(req: Request, data_queue: queue.Queue, task_temp_images: list, step_callback): try: @@ -137,13 +144,13 @@ def apply_filters(args: dict, images: list, user_stopped, show_only_filtered_ima return images filters = [] - if args['use_face_correction'].startswith('GFPGAN'): filters.append((image_filters.apply_gfpgan, model_manager.resolve_model_to_use(args['use_face_correction'], model_type='gfpgan'))) - if args['use_face_correction'].use_upscale.startswith('RealESRGAN'): filters.append((image_filters.apply_realesrgan, model_manager.resolve_model_to_use(args['use_upscale'], model_type='realesrgan'))) + if 'gfpgan' in args['use_face_correction'].lower(): filters.append(image_filters.apply_gfpgan) + if 'realesrgan' in args['use_face_correction'].lower(): filters.append(image_filters.apply_realesrgan) filtered_images = [] for img, seed, _ in images: - for filter_fn, filter_model_path in filters: - img = filter_fn(thread_data, img, filter_model_path) + for filter_fn in filters: + img = filter_fn(thread_data, img) filtered_images.append((img, seed, True)) diff --git a/ui/server.py b/ui/server.py index 1c537dcd..258c433d 100644 --- a/ui/server.py +++ b/ui/server.py @@ -140,10 +140,16 @@ def ping(session_id:str=None): def render(req : task_manager.ImageRequest): try: app.save_model_to_config(req.use_stable_diffusion_model, req.use_vae_model, req.use_hypernetwork_model) + + # resolve the model paths to use req.use_stable_diffusion_model = model_manager.resolve_model_to_use(req.use_stable_diffusion_model, model_type='stable-diffusion') req.use_vae_model = model_manager.resolve_model_to_use(req.use_vae_model, model_type='vae') req.use_hypernetwork_model = model_manager.resolve_model_to_use(req.use_hypernetwork_model, model_type='hypernetwork') + if req.use_face_correction: req.use_face_correction = model_manager.resolve_model_to_use(req.use_face_correction, 'gfpgan') + if req.use_upscale: req.use_upscale = model_manager.resolve_model_to_use(req.use_upscale, 'gfpgan') + + # enqueue the task new_task = task_manager.render(req) response = { 'status': str(task_manager.current_state),