diff --git a/ui/easydiffusion/model_manager.py b/ui/easydiffusion/model_manager.py index 0a1f1b5c..2a8b57fd 100644 --- a/ui/easydiffusion/model_manager.py +++ b/ui/easydiffusion/model_manager.py @@ -113,15 +113,17 @@ def resolve_model_to_use(model_name: str = None, model_type: str = None): def reload_models_if_necessary(context: Context, task_data: TaskData): - use_upscale_lower = task_data.use_upscale.lower() if task_data.use_upscale else "" + face_fix_lower = task_data.use_face_correction.lower() if task_data.use_face_correction else "" + upscale_lower = task_data.use_upscale.lower() if task_data.use_upscale else "" model_paths_in_req = { "stable-diffusion": task_data.use_stable_diffusion_model, "vae": task_data.use_vae_model, "hypernetwork": task_data.use_hypernetwork_model, - "gfpgan": task_data.use_face_correction, - "realesrgan": task_data.use_upscale if "realesrgan" in use_upscale_lower else None, - "latent_upscaler": True if task_data.use_upscale == "latent_upscaler" else None, + "codeformer": task_data.use_face_correction if "codeformer" in face_fix_lower else None, + "gfpgan": task_data.use_face_correction if "gfpgan" in face_fix_lower else None, + "realesrgan": task_data.use_upscale if "realesrgan" in upscale_lower else None, + "latent_upscaler": True if "latent_upscaler" in upscale_lower else None, "nsfw_checker": True if task_data.block_nsfw else None, "lora": task_data.use_lora_model, } diff --git a/ui/easydiffusion/renderer.py b/ui/easydiffusion/renderer.py index e2dae34f..1ebd05ec 100644 --- a/ui/easydiffusion/renderer.py +++ b/ui/easydiffusion/renderer.py @@ -160,7 +160,9 @@ def filter_images(req: GenerateImageRequest, task_data: TaskData, images: list, filter_params = {} if task_data.block_nsfw: filters_to_apply.append("nsfw_checker") - if task_data.use_face_correction and "gfpgan" in task_data.use_face_correction.lower(): + if task_data.use_face_correction and "codeformer" in task_data.use_face_correction.lower(): + filters_to_apply.append("codeformer") + elif task_data.use_face_correction and "gfpgan" in task_data.use_face_correction.lower(): filters_to_apply.append("gfpgan") if task_data.use_upscale: if "realesrgan" in task_data.use_upscale.lower():