Basic embeddings support

This commit is contained in:
JeLuF 2023-06-22 23:48:55 +02:00
parent 4bf78521ce
commit 3dc62a8857

View File

@ -27,6 +27,7 @@ MODEL_EXTENSIONS = {
"realesrgan": [".pth"],
"lora": [".ckpt", ".safetensors"],
"codeformer": [".pth"],
"embeddings": [".pt", ".bin", ".safetensors"],
}
DEFAULT_MODELS = {
"stable-diffusion": [
@ -58,6 +59,9 @@ def init():
def load_default_models(context: Context):
set_vram_optimizations(context)
config = app.getConfig()
context.embeddings_path = os.path.join(app.MODELS_DIR, "embeddings")
# init default model paths
for model_type in MODELS_TO_LOAD_ON_START:
context.model_paths[model_type] = resolve_model_to_use(model_type=model_type, fail_if_not_found=False)
@ -318,6 +322,7 @@ def getModels():
"hypernetwork": [],
"lora": [],
"codeformer": ["codeformer"],
"embeddings": [],
},
}
@ -374,6 +379,7 @@ def getModels():
listModels(model_type="hypernetwork")
listModels(model_type="gfpgan")
listModels(model_type="lora")
listModels(model_type="embeddings")
if models_scanned > 0:
log.info(f"[green]Scanned {models_scanned} models. Nothing infected[/]")