diff --git a/ui/easydiffusion/model_manager.py b/ui/easydiffusion/model_manager.py index 1ee5ce9d..2e7fef67 100644 --- a/ui/easydiffusion/model_manager.py +++ b/ui/easydiffusion/model_manager.py @@ -9,6 +9,7 @@ from easydiffusion.types import ModelsData from easydiffusion.utils import log from sdkit import Context from sdkit.models import load_model, scan_model, unload_model, download_model, get_model_info_from_db +from sdkit.models.model_loader.controlnet_filters import filters as cn_filters from sdkit.utils import hash_file_quick KNOWN_MODEL_TYPES = [ @@ -19,6 +20,8 @@ KNOWN_MODEL_TYPES = [ "realesrgan", "lora", "codeformer", + "embeddings", + "controlnet", ] MODEL_EXTENSIONS = { "stable-diffusion": [".ckpt", ".safetensors"], @@ -29,6 +32,7 @@ MODEL_EXTENSIONS = { "lora": [".ckpt", ".safetensors"], "codeformer": [".pth"], "embeddings": [".pt", ".bin", ".safetensors"], + "controlnet": [".pth", ".safetensors"], } DEFAULT_MODELS = { "stable-diffusion": [ @@ -177,7 +181,8 @@ def reload_models_if_necessary(context: Context, models_data: ModelsData, models def resolve_model_paths(models_data: ModelsData): model_paths = models_data.model_paths for model_type in model_paths: - if model_type in ("latent_upscaler", "nsfw_checker"): # doesn't use model paths + skip_models = cn_filters + ["latent_upscaler", "nsfw_checker"] + if model_type in skip_models: # doesn't use model paths continue if model_type == "codeformer": download_if_necessary("codeformer", "codeformer.pth", "codeformer-0.1.0") @@ -291,6 +296,7 @@ def getModels(scan_for_malicious: bool = True): "lora": [], "codeformer": ["codeformer"], "embeddings": [], + "controlnet": [], }, } @@ -350,6 +356,7 @@ def getModels(scan_for_malicious: bool = True): listModels(model_type="gfpgan") listModels(model_type="lora") listModels(model_type="embeddings") + listModels(model_type="controlnet") if scan_for_malicious and models_scanned > 0: log.info(f"[green]Scanned {models_scanned} models. Nothing infected[/]") diff --git a/ui/easydiffusion/tasks/render_images.py b/ui/easydiffusion/tasks/render_images.py index f14d478d..8df208b6 100644 --- a/ui/easydiffusion/tasks/render_images.py +++ b/ui/easydiffusion/tasks/render_images.py @@ -210,6 +210,9 @@ def generate_images_internal( if req.init_image is not None and not context.test_diffusers: req.sampler_name = "ddim" + if req.control_image and task_data.control_filter_to_apply: + req.control_image = filter_images(context, req.control_image, task_data.control_filter_to_apply)[0] + if context.test_diffusers: pipe = context.models["stable-diffusion"]["default"] if hasattr(pipe.unet, "_allocate_trt_buffers"): diff --git a/ui/easydiffusion/types.py b/ui/easydiffusion/types.py index 894867b8..181a9505 100644 --- a/ui/easydiffusion/types.py +++ b/ui/easydiffusion/types.py @@ -75,6 +75,7 @@ class TaskData(BaseModel): use_controlnet_model: Union[str, List[str]] = None filters: List[str] = [] filter_params: Dict[str, Dict[str, Any]] = {} + control_filter_to_apply: Union[str, List[str]] = None show_only_filtered_image: bool = False block_nsfw: bool = False @@ -135,6 +136,7 @@ class GenerateImageResponse: def json(self): del self.render_request.init_image del self.render_request.init_image_mask + del self.render_request.control_image task_data = self.task_data.dict() task_data.update(self.output_format.dict()) @@ -212,6 +214,9 @@ def convert_legacy_render_req_to_new(old_req: dict): model_paths["latent_upscaler"] = ( model_paths["latent_upscaler"] if "latent_upscaler" in model_paths["latent_upscaler"].lower() else None ) + if "control_filter_to_apply" in old_req: + filter_model = old_req["control_filter_to_apply"] + model_paths[filter_model] = filter_model if old_req.get("block_nsfw"): model_paths["nsfw_checker"] = "nsfw_checker" diff --git a/ui/index.html b/ui/index.html index d2a6194c..68616ed4 100644 --- a/ui/index.html +++ b/ui/index.html @@ -83,8 +83,8 @@