Support for CodeFormer

Depends on https://github.com/easydiffusion/sdkit/pull/34.
This commit is contained in:
patriceac 2023-05-17 02:04:20 -07:00
parent 0adaf6c0a0
commit a25364732b
2 changed files with 5 additions and 2 deletions

View File

@ -107,11 +107,12 @@ def resolve_model_to_use(model_name: str = None, model_type: str = None):
def reload_models_if_necessary(context: Context, task_data: TaskData): def reload_models_if_necessary(context: Context, task_data: TaskData):
face_correction_model = "codeformer" if "codeformer" in task_data.use_face_correction.lower() else "gfpgan"
model_paths_in_req = { model_paths_in_req = {
"stable-diffusion": task_data.use_stable_diffusion_model, "stable-diffusion": task_data.use_stable_diffusion_model,
"vae": task_data.use_vae_model, "vae": task_data.use_vae_model,
"hypernetwork": task_data.use_hypernetwork_model, "hypernetwork": task_data.use_hypernetwork_model,
"gfpgan": task_data.use_face_correction, face_correction_model: task_data.use_face_correction,
"realesrgan": task_data.use_upscale, "realesrgan": task_data.use_upscale,
"nsfw_checker": True if task_data.block_nsfw else None, "nsfw_checker": True if task_data.block_nsfw else None,
"lora": task_data.use_lora_model, "lora": task_data.use_lora_model,

View File

@ -158,7 +158,9 @@ def filter_images(task_data: TaskData, images: list, user_stopped):
filters_to_apply = [] filters_to_apply = []
if task_data.block_nsfw: if task_data.block_nsfw:
filters_to_apply.append("nsfw_checker") filters_to_apply.append("nsfw_checker")
if task_data.use_face_correction and "gfpgan" in task_data.use_face_correction.lower(): if task_data.use_face_correction and "codeformer" in task_data.use_face_correction.lower():
filters_to_apply.append("codeformer")
elif task_data.use_face_correction and "gfpgan" in task_data.use_face_correction.lower():
filters_to_apply.append("gfpgan") filters_to_apply.append("gfpgan")
if task_data.use_upscale and "realesrgan" in task_data.use_upscale.lower(): if task_data.use_upscale and "realesrgan" in task_data.use_upscale.lower():
filters_to_apply.append("realesrgan") filters_to_apply.append("realesrgan")