diff --git a/ui/server.py b/ui/server.py index c8c7324f..67c84949 100644 --- a/ui/server.py +++ b/ui/server.py @@ -84,7 +84,10 @@ async def ping(): model_is_loading = True from sd_internal import runtime - runtime.load_model_ckpt(ckpt_to_use="sd-v1-4") + + custom_weight_path = os.path.join(SCRIPT_DIR, 'custom-model.ckpt') + ckpt_to_use = "sd-v1-4" if not os.path.exists(custom_weight_path) else "custom-model" + runtime.load_model_ckpt(ckpt_to_use=ckpt_to_use) model_loaded = True model_is_loading = False