diff --git a/ui/easydiffusion/model_manager.py b/ui/easydiffusion/model_manager.py index d75292c9..bdecc109 100644 --- a/ui/easydiffusion/model_manager.py +++ b/ui/easydiffusion/model_manager.py @@ -2,6 +2,7 @@ import os import shutil from glob import glob import traceback +from typing import Union from easydiffusion import app from easydiffusion.types import TaskData @@ -93,7 +94,14 @@ def unload_all(context: Context): del context.model_load_errors[model_type] -def resolve_model_to_use(model_name: str = None, model_type: str = None, fail_if_not_found: bool = True): +def resolve_model_to_use(model_name: Union[str, list] = None, model_type: str = None, fail_if_not_found: bool = True): + model_names = model_name if isinstance(model_name, list) else [model_name] + model_paths = [resolve_model_to_use_single(m, model_type, fail_if_not_found) for m in model_names] + + return model_paths[0] if len(model_paths) == 1 else model_paths + + +def resolve_model_to_use_single(model_name: str = None, model_type: str = None, fail_if_not_found: bool = True): model_extensions = MODEL_EXTENSIONS.get(model_type, []) default_models = DEFAULT_MODELS.get(model_type, []) config = app.getConfig() diff --git a/ui/easydiffusion/types.py b/ui/easydiffusion/types.py index abf8db29..a9e49a24 100644 --- a/ui/easydiffusion/types.py +++ b/ui/easydiffusion/types.py @@ -1,4 +1,4 @@ -from typing import Any +from typing import Any, List, Union from pydantic import BaseModel @@ -22,7 +22,7 @@ class GenerateImageRequest(BaseModel): sampler_name: str = None # "ddim", "plms", "heun", "euler", "euler_a", "dpm2", "dpm2_a", "lms" hypernetwork_strength: float = 0 - lora_alpha: float = 0 + lora_alpha: Union[float, List[float]] = 0 tiling: str = "none" # "none", "x", "y", "xy" @@ -32,15 +32,14 @@ class TaskData(BaseModel): save_to_disk_path: str = None vram_usage_level: str = "balanced" # or "low" or "medium" - use_face_correction: str = None # or "GFPGANv1.3" - use_upscale: str = None # or "RealESRGAN_x4plus" or "RealESRGAN_x4plus_anime_6B" or "latent_upscaler" + use_face_correction: Union[str, List[str]] = None # or "GFPGANv1.3" + use_upscale: Union[str, List[str]] = None upscale_amount: int = 4 # or 2 latent_upscaler_steps: int = 10 - use_stable_diffusion_model: str = "sd-v1-4" - # use_stable_diffusion_config: str = "v1-inference" - use_vae_model: str = None - use_hypernetwork_model: str = None - use_lora_model: str = None + use_stable_diffusion_model: Union[str, List[str]] = "sd-v1-4" + use_vae_model: Union[str, List[str]] = None + use_hypernetwork_model: Union[str, List[str]] = None + use_lora_model: Union[str, List[str]] = None show_only_filtered_image: bool = False block_nsfw: bool = False