Use only realesrgan_x4 (not anime) for upscaling in codeformer

This commit is contained in:
cmdr2 2023-06-07 16:37:44 +05:30
parent e23f66a697
commit 267c7b85ea

View File

@ -7,10 +7,12 @@ from easydiffusion import device_manager
from easydiffusion.types import GenerateImageRequest from easydiffusion.types import GenerateImageRequest
from easydiffusion.types import Image as ResponseImage from easydiffusion.types import Image as ResponseImage
from easydiffusion.types import Response, TaskData, UserInitiatedStop from easydiffusion.types import Response, TaskData, UserInitiatedStop
from easydiffusion.model_manager import DEFAULT_MODELS, resolve_model_to_use
from easydiffusion.utils import get_printable_request, log, save_images_to_disk from easydiffusion.utils import get_printable_request, log, save_images_to_disk
from sdkit import Context from sdkit import Context
from sdkit.filter import apply_filters from sdkit.filter import apply_filters
from sdkit.generate import generate_images from sdkit.generate import generate_images
from sdkit.models import load_model
from sdkit.utils import ( from sdkit.utils import (
diffusers_latent_samples_to_images, diffusers_latent_samples_to_images,
gc, gc,
@ -157,37 +159,51 @@ def filter_images(req: GenerateImageRequest, task_data: TaskData, images: list,
if user_stopped: if user_stopped:
return images return images
filters_to_apply = []
filter_params = {}
if task_data.block_nsfw: if task_data.block_nsfw:
filters_to_apply.append("nsfw_checker") images = apply_filters(context, "nsfw_checker", images)
if task_data.use_face_correction and "codeformer" in task_data.use_face_correction.lower():
filters_to_apply.append("codeformer")
filter_params["upscale_faces"] = task_data.codeformer_upscale_faces if task_data.use_face_correction and "codeformer" in task_data.use_face_correction.lower():
filter_params["codeformer_fidelity"] = task_data.codeformer_fidelity default_realesrgan = DEFAULT_MODELS["realesrgan"][0]["file_name"]
prev_realesrgan_path = None
if task_data.codeformer_upscale_faces and default_realesrgan not in context.model_paths["realesrgan"]:
prev_realesrgan_path = context.model_paths["realesrgan"]
context.model_paths["realesrgan"] = resolve_model_to_use(default_realesrgan, "realesrgan")
load_model(context, "realesrgan")
try:
images = apply_filters(
context,
"codeformer",
images,
upscale_faces=task_data.codeformer_upscale_faces,
codeformer_fidelity=task_data.codeformer_fidelity,
)
finally:
if prev_realesrgan_path:
context.model_paths["realesrgan"] = prev_realesrgan_path
load_model(context, "realesrgan")
elif task_data.use_face_correction and "gfpgan" in task_data.use_face_correction.lower(): elif task_data.use_face_correction and "gfpgan" in task_data.use_face_correction.lower():
filters_to_apply.append("gfpgan") images = apply_filters(context, "gfpgan", images)
if task_data.use_upscale: if task_data.use_upscale:
if "realesrgan" in task_data.use_upscale.lower(): if "realesrgan" in task_data.use_upscale.lower():
filters_to_apply.append("realesrgan") images = apply_filters(context, "realesrgan", images, scale=task_data.upscale_amount)
elif task_data.use_upscale == "latent_upscaler": elif task_data.use_upscale == "latent_upscaler":
filters_to_apply.append("latent_upscaler") images = apply_filters(
context,
"latent_upscaler",
images,
scale=task_data.upscale_amount,
latent_upscaler_options={
"prompt": req.prompt,
"negative_prompt": req.negative_prompt,
"seed": req.seed,
"num_inference_steps": task_data.latent_upscaler_steps,
"guidance_scale": 0,
},
)
filter_params["latent_upscaler_options"] = { return images
"prompt": req.prompt,
"negative_prompt": req.negative_prompt,
"seed": req.seed,
"num_inference_steps": task_data.latent_upscaler_steps,
"guidance_scale": 0,
}
filter_params["scale"] = task_data.upscale_amount
if len(filters_to_apply) == 0:
return images
return apply_filters(context, filters_to_apply, images, **filter_params)
def construct_response(images: list, seeds: list, task_data: TaskData, base_seed: int): def construct_response(images: list, seeds: list, task_data: TaskData, base_seed: int):