diff --git a/ui/easydiffusion/app.py b/ui/easydiffusion/app.py index 2edf544f..abd003d1 100644 --- a/ui/easydiffusion/app.py +++ b/ui/easydiffusion/app.py @@ -37,7 +37,6 @@ ROOT_DIR = os.path.abspath(os.path.join(SD_DIR, "..")) SD_UI_DIR = os.getenv("SD_UI_PATH", None) CONFIG_DIR = os.path.abspath(os.path.join(SD_UI_DIR, "..", "scripts")) -MODELS_DIR = os.path.abspath(os.path.join(SD_DIR, "..", "models")) BUCKET_DIR = os.path.abspath(os.path.join(SD_DIR, "..", "bucket")) USER_PLUGINS_DIR = os.path.abspath(os.path.join(SD_DIR, "..", "plugins")) @@ -92,14 +91,21 @@ CUSTOM_MODIFIERS_LANDSCAPE_EXTENSIONS = [ "-landscape", ] +MODELS_DIR = os.path.abspath(os.path.join(SD_DIR, "..", "models")) + def init(): + global MODELS_DIR + os.makedirs(USER_UI_PLUGINS_DIR, exist_ok=True) os.makedirs(USER_SERVER_PLUGINS_DIR, exist_ok=True) # https://pytorch.org/docs/stable/storage.html warnings.filterwarnings("ignore", category=UserWarning, message="TypedStorage is deprecated") + config = getConfig() + MODELS_DIR = config.get("models_dir", MODELS_DIR) + def init_render_threads(): load_server_plugins() @@ -170,6 +176,8 @@ getConfig.__use_v3_engine_on_startup = None def setConfig(config): + global MODELS_DIR + try: # config.yaml config_yaml_path = os.path.join(CONFIG_DIR, "..", "config.yaml") config_yaml_path = os.path.abspath(config_yaml_path) @@ -206,6 +214,9 @@ def setConfig(config): except: log.error(traceback.format_exc()) + if config.get("models_dir"): + MODELS_DIR = config["models_dir"] + def save_to_config(ckpt_model_name, vae_model_name, hypernetwork_model_name, vram_usage_level): config = getConfig() diff --git a/ui/easydiffusion/server.py b/ui/easydiffusion/server.py index 058f3b62..311b3a03 100644 --- a/ui/easydiffusion/server.py +++ b/ui/easydiffusion/server.py @@ -67,6 +67,7 @@ class SetAppConfigRequest(BaseModel, extra=Extra.allow): listen_to_network: bool = None listen_port: int = None use_v3_engine: bool = True + models_dir: str = None def init(): @@ -176,6 +177,7 @@ def set_app_config_internal(req: SetAppConfigRequest): config["net"]["listen_port"] = int(req.listen_port) config["use_v3_engine"] = req.use_v3_engine + config["models_dir"] = req.models_dir for property, property_value in req.dict().items(): if property_value is not None and property not in req.__fields__ and property not in PROTECTED_CONFIG_KEYS: @@ -207,7 +209,12 @@ def read_web_data_internal(key: str = None, **kwargs): if not key: # /get without parameters, stable-diffusion easter egg. raise HTTPException(status_code=418, detail="StableDiffusion is drawing a teapot!") # HTTP418 I'm a teapot elif key == "app_config": - return JSONResponse(app.getConfig(), headers=NOCACHE_HEADERS) + config = app.getConfig() + + if "models_dir" not in config: + config["models_dir"] = app.MODELS_DIR + + return JSONResponse(config, headers=NOCACHE_HEADERS) elif key == "system_info": config = app.getConfig() diff --git a/ui/media/js/parameters.js b/ui/media/js/parameters.js index a656f09a..96d119ef 100644 --- a/ui/media/js/parameters.js +++ b/ui/media/js/parameters.js @@ -97,6 +97,17 @@ var PARAMETERS = [ }, ], }, + { + id: "models_dir", + type: ParameterType.custom, + icon: "fa-folder-tree", + label: "Models Folder", + note: "Path to the 'models' folder. Please save and refresh the page after changing this.", + saveInAppConfig: true, + render: (parameter) => { + return `` + }, + }, { id: "block_nsfw", type: ParameterType.checkbox, @@ -422,6 +433,7 @@ let uiOpenBrowserOnStartField = document.querySelector("#ui_open_browser_on_star let confirmDangerousActionsField = document.querySelector("#confirm_dangerous_actions") let testDiffusers = document.querySelector("#use_v3_engine") let profileNameField = document.querySelector("#profileName") +let modelsDirField = document.querySelector("#models_dir") let saveSettingsBtn = document.querySelector("#save-system-settings-btn") @@ -463,6 +475,7 @@ async function getAppConfig() { if (config.net && config.net.listen_port !== undefined) { listenPortField.value = config.net.listen_port } + modelsDirField.value = config.models_dir let testDiffusersEnabled = true if (config.use_v3_engine === false) {