diff --git a/scripts/check_modules.py b/scripts/check_modules.py index 6bf3b1b3..bc0f46e7 100644 --- a/scripts/check_modules.py +++ b/scripts/check_modules.py @@ -18,7 +18,7 @@ os_name = platform.system() modules_to_check = { "torch": ("1.11.0", "1.13.1", "2.0.0"), "torchvision": ("0.12.0", "0.14.1", "0.15.1"), - "sdkit": "1.0.154", + "sdkit": "1.0.155", "stable-diffusion-sdkit": "2.1.4", "rich": "12.6.0", "uvicorn": "0.19.0", diff --git a/ui/easydiffusion/model_manager.py b/ui/easydiffusion/model_manager.py index 7a4ddde3..63f79859 100644 --- a/ui/easydiffusion/model_manager.py +++ b/ui/easydiffusion/model_manager.py @@ -186,6 +186,12 @@ def resolve_model_paths(models_data: ModelsData): continue if model_type == "codeformer": download_if_necessary("codeformer", "codeformer.pth", "codeformer-0.1.0") + elif model_type == "controlnet": + model_id = model_paths[model_type] + model_info = get_model_info_from_db(model_type=model_type, model_id=model_id) + if model_info: + filename = model_info.get("url", "").split("/")[-1] + download_if_necessary("controlnet", filename, model_id, skip_if_others_exist=False) model_paths[model_type] = resolve_model_to_use(model_paths[model_type], model_type=model_type) @@ -209,17 +215,17 @@ def download_default_models_if_necessary(): print(model_type, "model(s) found.") -def download_if_necessary(model_type: str, file_name: str, model_id: str): +def download_if_necessary(model_type: str, file_name: str, model_id: str, skip_if_others_exist=True): model_path = os.path.join(app.MODELS_DIR, model_type, file_name) expected_hash = get_model_info_from_db(model_type=model_type, model_id=model_id)["quick_hash"] - other_models_exist = any_model_exists(model_type) + other_models_exist = any_model_exists(model_type) and skip_if_others_exist known_model_exists = os.path.exists(model_path) known_model_is_corrupt = known_model_exists and hash_file_quick(model_path) != expected_hash if known_model_is_corrupt or (not other_models_exist and not known_model_exists): print("> download", model_type, model_id) - download_model(model_type, model_id, download_base_dir=app.MODELS_DIR) + download_model(model_type, model_id, download_base_dir=app.MODELS_DIR, download_config_if_available=False) def migrate_legacy_model_location(): @@ -296,7 +302,20 @@ def getModels(scan_for_malicious: bool = True): "lora": [], "codeformer": [{"codeformer": "CodeFormer"}], "embeddings": [], - "controlnet": [], + "controlnet": [ + {"control_v11p_sd15_canny": "Canny (*)"}, + {"control_v11p_sd15_openpose": "OpenPose (*)"}, + {"control_v11p_sd15_normalbae": "Normal BAE (*)"}, + {"control_v11f1p_sd15_depth": "Depth (*)"}, + {"control_v11p_sd15_scribble": "Scribble"}, + {"control_v11p_sd15_softedge": "Soft Edge"}, + {"control_v11p_sd15_inpaint": "Inpaint"}, + {"control_v11p_sd15_lineart": "Line Art"}, + {"control_v11p_sd15s2_lineart_anime": "Line Art Anime"}, + {"control_v11p_sd15_mlsd": "Straight Lines"}, + {"control_v11p_sd15_seg": "Segment"}, + {"control_v11e_sd15_shuffle": "Shuffle"}, + ], }, } diff --git a/ui/index.html b/ui/index.html index 68616ed4..29a147ff 100644 --- a/ui/index.html +++ b/ui/index.html @@ -151,7 +151,7 @@