Mirror of https://github.com/easydiffusion/easydiffusion.git (synced 2025-08-18 19:39:05 +02:00)
Support custom text encoders and Flux VAEs in the UI
@@ -124,10 +124,10 @@ def update_modules():
# if sdkit is 2.0.15.x (or lower), then diffusers should be restricted to 0.21.4 (see below for the reason)
# otherwise use the current sdkit version (with the corresponding diffusers version)

expected_sdkit_version_str = "2.0.22.8"
expected_sdkit_version_str = "2.0.22.9"
expected_diffusers_version_str = "0.28.2"

legacy_sdkit_version_str = "2.0.15.17"
legacy_sdkit_version_str = "2.0.15.18"
legacy_diffusers_version_str = "0.21.4"

sdkit_version_str = version("sdkit")

@@ -18,6 +18,7 @@ from .impl import (
ping,
load_model,
unload_model,
flush_model_changes,
set_options,
generate_images,
filter_images,
@@ -53,6 +54,7 @@ MODELS_TO_OVERRIDE = {
"codeformer": "--codeformer-models-path",
"embeddings": "--embeddings-dir",
"controlnet": "--controlnet-dir",
"text-encoder": "--text-encoder-dir",
}

WEBUI_PATCHES = [
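Note: MODELS_TO_OVERRIDE maps Easy Diffusion's model-type folders to the backend's command-line flags. A minimal sketch of how such a mapping could be turned into launch arguments for the bundled WebUI process (the build_model_dir_args helper and the models_dir argument are illustrative, not the project's actual launcher code):

    import os

    MODELS_TO_OVERRIDE = {
        "codeformer": "--codeformer-models-path",
        "embeddings": "--embeddings-dir",
        "controlnet": "--controlnet-dir",
        "text-encoder": "--text-encoder-dir",
    }

    def build_model_dir_args(models_dir: str) -> list:
        """Build CLI flags that point the WebUI backend at Easy Diffusion's model folders."""
        args = []
        for model_type, flag in MODELS_TO_OVERRIDE.items():
            model_dir = os.path.join(models_dir, model_type)
            if os.path.isdir(model_dir):
                args += [flag, model_dir]
        return args

    # Example: build_model_dir_args("/data/models")
    # -> ["--controlnet-dir", "/data/models/controlnet", "--text-encoder-dir", "/data/models/text-encoder", ...]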
@@ -27,6 +27,7 @@ webui_opts: dict = None
curr_models = {
"stable-diffusion": None,
"vae": None,
"text-encoder": None,
}

@@ -96,50 +97,51 @@ def ping(timeout=1):


def load_model(context, model_type, **kwargs):
from easydiffusion.app import ROOT_DIR, getConfig

config = getConfig()
models_dir = config.get("models_dir", os.path.join(ROOT_DIR, "models"))

model_path = context.model_paths[model_type]

if model_type == "stable-diffusion":
base_dir = os.path.join(models_dir, model_type)
model_path = os.path.relpath(model_path, base_dir)

# print(f"load model: {model_type=} {model_path=} {curr_models=}")
curr_models[model_type] = model_path


def unload_model(context, model_type, **kwargs):
# print(f"unload model: {model_type=} {curr_models=}")
curr_models[model_type] = None


def flush_model_changes(context):
if webui_opts is None:
print("Server not ready, can't set the model")
return

if model_type == "stable-diffusion":
model_name = os.path.basename(model_path)
model_name = os.path.splitext(model_name)[0]
print(f"setting sd model: {model_name}")
if curr_models[model_type] != model_name:
try:
res = webui_post("/sdapi/v1/options", json={"sd_model_checkpoint": model_name})
if res.status_code != 200:
raise Exception(res.text)
except Exception as e:
raise RuntimeError(
f"The engine failed to set the required options. Please check the logs in the command line window for more details."
)
modules = []
for model_type in ("vae", "text-encoder"):
if curr_models[model_type]:
model_paths = curr_models[model_type]
model_paths = [model_paths] if not isinstance(model_paths, list) else model_paths
modules += model_paths

curr_models[model_type] = model_name
elif model_type == "vae":
if curr_models[model_type] != model_path:
vae_model = [model_path] if model_path else []
opts = {"sd_model_checkpoint": curr_models["stable-diffusion"], "forge_additional_modules": modules}

opts = {"sd_model_checkpoint": curr_models["stable-diffusion"], "forge_additional_modules": vae_model}
print("setting opts 2", opts)
print("Setting backend models", opts)

try:
res = webui_post("/sdapi/v1/options", json=opts)
if res.status_code != 200:
raise Exception(res.text)
except Exception as e:
raise RuntimeError(
f"The engine failed to set the required options. Please check the logs in the command line window for more details."
)

curr_models[model_type] = model_path


def unload_model(context, model_type, **kwargs):
if model_type == "vae":
context.model_paths[model_type] = None
load_model(context, model_type)
try:
res = webui_post("/sdapi/v1/options", json=opts)
print("got res", res.status_code)
if res.status_code != 200:
raise Exception(res.text)
except Exception as e:
raise RuntimeError(
f"The engine failed to set the required options. Please check the logs in the command line window for more details."
)


def generate_images(
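Note: with this change the backend no longer talks to the WebUI server from load_model/unload_model; it only records the requested models and pushes them in a single request when flush_model_changes runs. A simplified sketch of that pattern, assuming a webui_post helper that POSTs JSON to the WebUI API (the server address and the requests-based helper below are illustrative):

    import requests

    WEBUI_URL = "http://localhost:7860"  # assumed address of the bundled WebUI server

    curr_models = {"stable-diffusion": None, "vae": None, "text-encoder": None}

    def webui_post(path, json=None):
        # stand-in for the project's webui_post helper
        return requests.post(WEBUI_URL + path, json=json)

    def load_model(model_type, model_path):
        # only remember the selection; nothing is sent to the server yet
        curr_models[model_type] = model_path

    def unload_model(model_type):
        curr_models[model_type] = None

    def flush_model_changes():
        # VAEs and text encoders are sent to Forge/WebUI as "additional modules"
        modules = []
        for model_type in ("vae", "text-encoder"):
            paths = curr_models[model_type]
            if paths:
                modules += paths if isinstance(paths, list) else [paths]

        opts = {
            "sd_model_checkpoint": curr_models["stable-diffusion"],
            "forge_additional_modules": modules,
        }
        res = webui_post("/sdapi/v1/options", json=opts)
        if res.status_code != 200:
            raise RuntimeError("The engine failed to set the required options.")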
@@ -346,7 +348,7 @@ def refresh_models():
pass

try:
for type in ("checkpoints", "vae"):
for type in ("checkpoints", "vae-and-text-encoders"):
t = Thread(target=make_refresh_call, args=(type,))
t.start()
except Exception as e:

@@ -3,6 +3,7 @@ import shutil
from glob import glob
import traceback
from typing import Union
from os import path

from easydiffusion import app
from easydiffusion.types import ModelsData

@@ -22,6 +23,7 @@ KNOWN_MODEL_TYPES = [
"codeformer",
"embeddings",
"controlnet",
"text-encoder",
]
MODEL_EXTENSIONS = {
"stable-diffusion": [".ckpt", ".safetensors", ".sft", ".gguf"],

@@ -33,6 +35,7 @@ MODEL_EXTENSIONS = {
"codeformer": [".pth"],
"embeddings": [".pt", ".bin", ".safetensors", ".sft"],
"controlnet": [".pth", ".safetensors", ".sft"],
"text-encoder": [".safetensors", ".sft"],
}
DEFAULT_MODELS = {
"stable-diffusion": [

@@ -59,6 +62,7 @@ ALTERNATE_FOLDER_NAMES = { # for WebUI compatibility
"realesrgan": "RealESRGAN",
"lora": "Lora",
"controlnet": "ControlNet",
"text-encoder": "text_encoder",
}

known_models = {}

@@ -204,6 +208,9 @@ def reload_models_if_necessary(context: Context, models_data: ModelsData, models
context.model_load_errors = {}
context.model_load_errors[model_type] = str(e) # storing the entire Exception can lead to memory leaks

if hasattr(backend, "flush_model_changes"):
backend.flush_model_changes(context)


def resolve_model_paths(models_data: ModelsData):
from easydiffusion.backend_manager import backend
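Note: the caller only invokes the new hook when the active backend actually provides it, so backends without flush_model_changes keep working unchanged. A minimal sketch of that guard (the wrapper function name is illustrative):

    def apply_pending_model_changes(backend, context):
        # older backends may not implement flush_model_changes yet
        if hasattr(backend, "flush_model_changes"):
            backend.flush_model_changes(context)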
@@ -224,14 +231,25 @@ def resolve_model_paths(models_data: ModelsData):
for model_type in model_paths:
if model_type in skip_models: # doesn't use model paths
continue
if model_type == "codeformer" and model_paths[model_type]:
download_if_necessary("codeformer", "codeformer.pth", "codeformer-0.1.0")
elif model_type == "controlnet" and model_paths[model_type]:
model_id = model_paths[model_type]
model_info = get_model_info_from_db(model_type=model_type, model_id=model_id)
if model_info:
filename = model_info.get("url", "").split("/")[-1]
download_if_necessary("controlnet", filename, model_id, skip_if_others_exist=False)

if model_type in ("vae", "codeformer", "controlnet", "text-encoder") and model_paths[model_type]:
model_ids = model_paths[model_type]
model_ids = model_ids if isinstance(model_ids, list) else [model_ids]

new_model_paths = []

for model_id in model_ids:
log.info(f"Checking for {model_id=}")
model_info = get_model_info_from_db(model_type=model_type, model_id=model_id)
if model_info:
filename = model_info.get("url", "").split("/")[-1]
download_if_necessary(model_type, filename, model_id, skip_if_others_exist=False)

new_model_paths.append(path.splitext(filename)[0])
else: # not in the model db, probably a regular file
new_model_paths.append(model_id)

model_paths[model_type] = new_model_paths

model_paths[model_type] = resolve_model_to_use(model_paths[model_type], model_type=model_type)
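Note: vae and text-encoder entries can now be a single model ID or a list of IDs. A simplified, standalone sketch of the normalization and model-db lookup above (the lookup and download helpers are passed in so the snippet stays self-contained; the real code calls get_model_info_from_db and download_if_necessary directly):

    from os import path

    def resolve_downloadable_models(model_ids, model_type, get_model_info_from_db, download_if_necessary):
        # accept a single model ID or a list of IDs
        model_ids = model_ids if isinstance(model_ids, list) else [model_ids]

        resolved = []
        for model_id in model_ids:
            info = get_model_info_from_db(model_type=model_type, model_id=model_id)
            if info:
                # known model: download it if needed, then refer to it by its file stem
                filename = info.get("url", "").split("/")[-1]
                download_if_necessary(model_type, filename, model_id, skip_if_others_exist=False)
                resolved.append(path.splitext(filename)[0])
            else:
                # not in the model db: assume it is already a file name or path on disk
                resolved.append(model_id)
        return resolved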
@@ -256,17 +274,31 @@ def download_default_models_if_necessary():


def download_if_necessary(model_type: str, file_name: str, model_id: str, skip_if_others_exist=True):
model_dir = get_model_dirs(model_type)[0]
model_path = os.path.join(model_dir, file_name)
from easydiffusion.backend_manager import backend

expected_hash = get_model_info_from_db(model_type=model_type, model_id=model_id)["quick_hash"]

other_models_exist = any_model_exists(model_type) and skip_if_others_exist
known_model_exists = os.path.exists(model_path)
known_model_is_corrupt = known_model_exists and hash_file_quick(model_path) != expected_hash

if known_model_is_corrupt or (not other_models_exist and not known_model_exists):
print("> download", model_type, model_id)
download_model(model_type, model_id, download_base_dir=app.MODELS_DIR, download_config_if_available=False)
for model_dir in get_model_dirs(model_type):
model_path = os.path.join(model_dir, file_name)

known_model_exists = os.path.exists(model_path)
known_model_is_corrupt = known_model_exists and hash_file_quick(model_path) != expected_hash

needs_download = known_model_is_corrupt or (not other_models_exist and not known_model_exists)

log.info(f"{model_path=} {needs_download=}")
if known_model_exists:
log.info(f"{expected_hash=} {hash_file_quick(model_path)=}")
log.info(f"{known_model_is_corrupt=} {other_models_exist=} {known_model_exists=}")

if not needs_download:
return

print("> download", model_type, model_id)
download_model(model_type, model_id, download_base_dir=app.MODELS_DIR, download_config_if_available=False)

backend.refresh_models()


def migrate_legacy_model_location():
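Note: the download check now runs per model directory and reduces to one predicate: re-download when the known file's quick hash no longer matches, or when neither the known file nor any alternative model of that type exists. A sketch of that predicate (hash_file_quick is passed in as a stand-in for sdkit's quick-hash helper):

    import os

    def needs_download(model_path, expected_hash, other_models_exist, hash_file_quick):
        known_model_exists = os.path.exists(model_path)
        known_model_is_corrupt = known_model_exists and hash_file_quick(model_path) != expected_hash

        # a corrupt file always triggers a re-download; a missing file triggers one
        # only if no other model of this type can be used instead
        return known_model_is_corrupt or (not other_models_exist and not known_model_exists)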
@@ -363,7 +395,7 @@ def getModels(scan_for_malicious: bool = True):
models = {
"options": {
"stable-diffusion": [],
"vae": [],
"vae": [{"ae": "ae (Flux VAE fp16)"}],
"hypernetwork": [],
"lora": [],
"codeformer": [{"codeformer": "CodeFormer"}],

@@ -383,6 +415,11 @@ def getModels(scan_for_malicious: bool = True):
# {"control_v11e_sd15_shuffle": "Shuffle"},
# {"control_v11f1e_sd15_tile": "Tile"},
],
"text-encoder": [
{"t5xxl_fp16": "T5 XXL fp16"},
{"clip_l": "CLIP L"},
{"clip_g": "CLIP G"},
],
},
}

@@ -466,6 +503,7 @@ def getModels(scan_for_malicious: bool = True):
listModels(model_type="lora")
listModels(model_type="embeddings", nameFilter=get_embedding_token)
listModels(model_type="controlnet")
listModels(model_type="text-encoder")

if scan_for_malicious and models_scanned > 0:
log.info(f"[green]Scanned {models_scanned} models. Nothing infected[/]")

@@ -80,6 +80,7 @@ class RenderTaskData(TaskData):
latent_upscaler_steps: int = 10
use_stable_diffusion_model: Union[str, List[str]] = "sd-v1-4"
use_vae_model: Union[str, List[str]] = None
use_text_encoder_model: Union[str, List[str]] = None
use_hypernetwork_model: Union[str, List[str]] = None
use_lora_model: Union[str, List[str]] = None
use_controlnet_model: Union[str, List[str]] = None
@@ -211,6 +212,7 @@ def convert_legacy_render_req_to_new(old_req: dict):
# move the model info
model_paths["stable-diffusion"] = old_req.get("use_stable_diffusion_model")
model_paths["vae"] = old_req.get("use_vae_model")
model_paths["text-encoder"] = old_req.get("use_text_encoder_model")
model_paths["hypernetwork"] = old_req.get("use_hypernetwork_model")
model_paths["lora"] = old_req.get("use_lora_model")
model_paths["controlnet"] = old_req.get("use_controlnet_model")
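Note: for illustration, this is roughly how a legacy render request maps onto the new model_paths dict, with the added use_text_encoder_model field carried across (the request values below are made up):

    old_req = {
        "use_stable_diffusion_model": "flux1-dev",
        "use_vae_model": "ae",
        "use_text_encoder_model": ["t5xxl_fp16", "clip_l"],  # a single name or a list
    }

    model_paths = {
        "stable-diffusion": old_req.get("use_stable_diffusion_model"),
        "vae": old_req.get("use_vae_model"),
        "text-encoder": old_req.get("use_text_encoder_model"),
        "hypernetwork": old_req.get("use_hypernetwork_model"),  # None when absent
        "lora": old_req.get("use_lora_model"),
        "controlnet": old_req.get("use_controlnet_model"),
    }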
@@ -302,6 +302,14 @@
<input id="vae_model" type="text" spellcheck="false" autocomplete="off" class="model-filter" data-path="" />
<a href="https://github.com/easydiffusion/easydiffusion/wiki/VAE-Variational-Auto-Encoder" target="_blank"><i class="fa-solid fa-circle-question help-btn"><span class="simple-tooltip top-left">Click to learn more about VAEs</span></i></a>
</td></tr>
<tr id="text_encoder_model_container" class="pl-5 gated-feature" data-feature-keys="backend_webui">
<td>
<label for="text_encoder_model">Text Encoder:</label>
</td>
<td>
<div id="text_encoder_model" data-path=""></div>
</td>
</tr>
<tr id="samplerSelection" class="pl-5"><td><label for="sampler_name">Sampler:</label></td><td>
<select id="sampler_name" name="sampler_name">
<option value="plms">PLMS</option>

@@ -442,7 +450,7 @@
<tr class="pl-5"><td><label for="guidance_scale_slider">Guidance Scale:</label></td><td> <input id="guidance_scale_slider" name="guidance_scale_slider" class="editor-slider" value="75" type="range" min="11" max="500"> <input id="guidance_scale" name="guidance_scale" size="4" pattern="^[0-9\.]+$" onkeypress="preventNonNumericalInput(event)" inputmode="decimal"></td></tr>
<tr class="pl-5 displayNone warning-label" id="guidanceWarning"><td></td><td id="guidanceWarningText"></td></tr>
<tr id="prompt_strength_container" class="pl-5"><td><label for="prompt_strength_slider">Prompt Strength:</label></td><td> <input id="prompt_strength_slider" name="prompt_strength_slider" class="editor-slider" value="80" type="range" min="0" max="99"> <input id="prompt_strength" name="prompt_strength" size="4" pattern="^[0-9\.]+$" onkeypress="preventNonNumericalInput(event)" inputmode="decimal"><br/></td></tr>
<tr id="distilled_guidance_scale_container" class="pl-5 displayNone"><td><label for="distilled_guidance_scale_slider">Distilled Guidance:</label></td><td> <input id="distilled_guidance_scale_slider" name="distilled_guidance_scale_slider" class="editor-slider" value="35" type="range" min="11" max="500"> <input id="distilled_guidance_scale" name="distilled_guidance_scale" size="4" pattern="^[0-9\.]+$" onkeypress="preventNonNumericalInput(event)" inputmode="decimal"></td></tr>
<tr id="distilled_guidance_scale_container" class="pl-5 gated-feature" data-feature-keys="backend_webui"><td><label for="distilled_guidance_scale_slider">Distilled Guidance:</label></td><td> <input id="distilled_guidance_scale_slider" name="distilled_guidance_scale_slider" class="editor-slider" value="35" type="range" min="11" max="500"> <input id="distilled_guidance_scale" name="distilled_guidance_scale" size="4" pattern="^[0-9\.]+$" onkeypress="preventNonNumericalInput(event)" inputmode="decimal"></td></tr>
<tr id="lora_model_container" class="pl-5 gated-feature" data-feature-keys="backend_ed_diffusers backend_webui">
<td>
<label for="lora_model">LoRA:</label>

@@ -60,6 +60,7 @@ const SETTINGS_IDS_LIST = [
"extract_lora_from_prompt",
"embedding-card-size-selector",
"lora_model",
"text_encoder_model",
"enable_vae_tiling",
"controlnet_alpha",
]
@@ -394,6 +394,45 @@ const TASK_MAPPING = {
return val
},
},
use_text_encoder_model: {
name: "Text Encoder model",
setUI: (use_text_encoder_model) => {
let modelPaths = []
use_text_encoder_model = use_text_encoder_model === null ? "" : use_text_encoder_model
use_text_encoder_model = Array.isArray(use_text_encoder_model) ? use_text_encoder_model : [use_text_encoder_model]
use_text_encoder_model.forEach((m) => {
if (m.includes("models\\text-encoder\\")) {
m = m.split("models\\text-encoder\\")[1]
} else if (m.includes("models\\\\text-encoder\\\\")) {
m = m.split("models\\\\text-encoder\\\\")[1]
} else if (m.includes("models/text-encoder/")) {
m = m.split("models/text-encoder/")[1]
}
m = m.replaceAll("\\\\", "/")
m = getModelPath(m, [".safetensors", ".sft"])
modelPaths.push(m)
})
text_encoderModelField.modelNames = modelPaths
},
readUI: () => {
return text_encoderModelField.modelNames
},
parse: (val) => {
val = !val || val === "None" ? "" : val
if (typeof val === "string" && val.includes(",")) {
val = val.split(",")
val = val.map((v) => v.trim())
val = val.map((v) => v.replaceAll("\\", "\\\\"))
val = val.map((v) => v.replaceAll('"', ""))
val = val.map((v) => v.replaceAll("'", ""))
val = val.map((v) => '"' + v + '"')
val = "[" + val + "]"
val = JSON.parse(val)
}
val = Array.isArray(val) ? val : [val]
return val
},
},
use_hypernetwork_model: {
name: "Hypernetwork model",
setUI: (use_hypernetwork_model) => {

@@ -620,6 +659,7 @@ const TASK_TEXT_MAPPING = {
hypernetwork_strength: "Hypernetwork Strength",
use_lora_model: "LoRA model",
lora_alpha: "LoRA Strength",
use_text_encoder_model: "Text Encoder model",
use_controlnet_model: "ControlNet model",
control_filter_to_apply: "ControlNet Filter",
control_alpha: "ControlNet Strength",

@@ -54,6 +54,7 @@ const taskConfigSetup = {
label: "Hypernetwork Strength",
visible: ({ reqBody }) => !!reqBody?.use_hypernetwork_model,
},
use_text_encoder_model: { label: "Text Encoder", visible: ({ reqBody }) => !!reqBody?.use_text_encoder_model },
use_lora_model: { label: "Lora Model", visible: ({ reqBody }) => !!reqBody?.use_lora_model },
lora_alpha: { label: "Lora Strength", visible: ({ reqBody }) => !!reqBody?.use_lora_model },
preserve_init_image_color_profile: "Preserve Color Profile",

@@ -141,6 +142,7 @@ let tilingField = document.querySelector("#tiling")
let controlnetModelField = new ModelDropdown(document.querySelector("#controlnet_model"), "controlnet", "None", false)
let vaeModelField = new ModelDropdown(document.querySelector("#vae_model"), "vae", "None")
let loraModelField = new MultiModelSelector(document.querySelector("#lora_model"), "lora", "LoRA", 0.5, 0.02)
let textEncoderModelField = new MultiModelSelector(document.querySelector("#text_encoder_model"), "text-encoder", "Text Encoder", 0.5, 0.02, false)
let hypernetworkModelField = new ModelDropdown(document.querySelector("#hypernetwork_model"), "hypernetwork", "None")
let hypernetworkStrengthSlider = document.querySelector("#hypernetwork_strength_slider")
let hypernetworkStrengthField = document.querySelector("#hypernetwork_strength")

@@ -1396,6 +1398,7 @@ function getCurrentUserRequest() {
newTask.reqBody.hypernetwork_strength = parseFloat(hypernetworkStrengthField.value)
}
if (testDiffusers.checked) {
// lora
let loraModelData = loraModelField.value
let modelNames = loraModelData["modelNames"]
let modelStrengths = loraModelData["modelWeights"]
@@ -1408,6 +1411,16 @@ function getCurrentUserRequest() {
newTask.reqBody.lora_alpha = modelStrengths
}

// text encoder
let textEncoderModelNames = textEncoderModelField.modelNames

if (textEncoderModelNames.length > 0) {
textEncoderModelNames = textEncoderModelNames.length == 1 ? textEncoderModelNames[0] : textEncoderModelNames

newTask.reqBody.use_text_encoder_model = textEncoderModelNames
}

// vae tiling
if (tilingField.value !== "none") {
newTask.reqBody.tiling = tilingField.value
}

@@ -1891,8 +1904,31 @@ document.addEventListener("refreshModels", function() {
onControlnetModelChange()
})

// tip for Flux
// utilities for Flux and Chroma
let sdModelField = document.querySelector("#stable_diffusion_model")

// function checkAndSetDependentModels() {
// let sdModel = sdModelField.value.toLowerCase()
// let isFlux = sdModel.includes("flux")
// let isChroma = sdModel.includes("chroma")

// if (isFlux || isChroma) {
// vaeModelField.value = "ae"

// if (isFlux) {
// textEncoderModelField.modelNames = ["t5xxl_fp16", "clip_l"]
// } else {
// textEncoderModelField.modelNames = ["t5xxl_fp16"]
// }
// } else {
// if (vaeModelField.value == "ae") {
// vaeModelField.value = ""
// }
// textEncoderModelField.modelNames = []
// }
// }
// sdModelField.addEventListener("change", checkAndSetDependentModels)

function checkGuidanceValue() {
let guidance = parseFloat(guidanceScaleField.value)
let guidanceWarning = document.querySelector("#guidanceWarning")

@@ -1917,15 +1953,16 @@ sdModelField.addEventListener("change", checkGuidanceValue)
guidanceScaleField.addEventListener("change", checkGuidanceValue)
guidanceScaleSlider.addEventListener("change", checkGuidanceValue)

function checkGuidanceScaleVisibility() {
let guidanceScaleContainer = document.querySelector("#distilled_guidance_scale_container")
if (sdModelField.value.toLowerCase().includes("flux")) {
guidanceScaleContainer.classList.remove("displayNone")
} else {
guidanceScaleContainer.classList.add("displayNone")
}
}
sdModelField.addEventListener("change", checkGuidanceScaleVisibility)
// disabling until we can detect flux models more reliably
// function checkGuidanceScaleVisibility() {
// let guidanceScaleContainer = document.querySelector("#distilled_guidance_scale_container")
// if (sdModelField.value.toLowerCase().includes("flux")) {
// guidanceScaleContainer.classList.remove("displayNone")
// } else {
// guidanceScaleContainer.classList.add("displayNone")
// }
// }
// sdModelField.addEventListener("change", checkGuidanceScaleVisibility)

function checkFluxSampler() {
let samplerWarning = document.querySelector("#fluxSamplerWarning")

@@ -1980,8 +2017,9 @@ schedulerField.addEventListener("change", checkFluxSchedulerSteps)
numInferenceStepsField.addEventListener("change", checkFluxSchedulerSteps)

document.addEventListener("refreshModels", function() {
// checkAndSetDependentModels()
checkGuidanceValue()
checkGuidanceScaleVisibility()
// checkGuidanceScaleVisibility()
checkFluxSampler()
checkFluxScheduler()
checkFluxSchedulerSteps()
@@ -10,6 +10,7 @@ class MultiModelSelector {
root
modelType
modelNameFriendly
showWeights
defaultWeight
weightStep

@@ -35,13 +36,13 @@
if (typeof modelData !== "object") {
throw new Error("Multi-model selector expects an object containing modelNames and modelWeights as keys!")
}
if (!("modelNames" in modelData) || !("modelWeights" in modelData)) {
if (!("modelNames" in modelData) || (this.showWeights && !("modelWeights" in modelData))) {
throw new Error("modelNames or modelWeights not present in the data passed to the multi-model selector")
}

let newModelNames = modelData["modelNames"]
let newModelWeights = modelData["modelWeights"]
if (newModelNames.length !== newModelWeights.length) {
if (newModelWeights && newModelNames.length !== newModelWeights.length) {
throw new Error("Need to pass an equal number of modelNames and modelWeights!")
}

@@ -50,7 +51,7 @@
// the root of all this unholiness is because searchable-models automatically dispatches an update event
// as soon as the value is updated via JS, which is against the DOM pattern of not dispatching an event automatically
// unless the caller explicitly dispatches the event.
this.modelWeights = newModelWeights
this.modelWeights = newModelWeights || []
this.modelNames = newModelNames
}
get disabled() {

@@ -91,10 +92,11 @@
}
}

constructor(root, modelType, modelNameFriendly = undefined, defaultWeight = 0.5, weightStep = 0.02) {
constructor(root, modelType, modelNameFriendly = undefined, defaultWeight = 0.5, weightStep = 0.02, showWeights = true) {
this.root = root
this.modelType = modelType
this.modelNameFriendly = modelNameFriendly || modelType
this.showWeights = showWeights
this.defaultWeight = defaultWeight
this.weightStep = weightStep

@@ -135,10 +137,13 @@

const modelElement = document.createElement("div")
modelElement.className = "model_entry"
modelElement.innerHTML = `
<input id="${this.modelType}_${idx}" class="model_name model-filter" type="text" spellcheck="false" autocomplete="off" data-path="" />
<input class="model_weight" type="number" step="${this.weightStep}" value="${this.defaultWeight}" pattern="^-?[0-9]*\.?[0-9]*$" onkeypress="preventNonNumericalInput(event)">
`
let html = `<input id="${this.modelType}_${idx}" class="model_name model-filter" type="text" spellcheck="false" autocomplete="off" data-path="" />`

if (this.showWeights) {
html += `<input class="model_weight" type="number" step="${this.weightStep}" value="${this.defaultWeight}" pattern="^-?[0-9]*\.?[0-9]*$" onkeypress="preventNonNumericalInput(event)">`
}
modelElement.innerHTML = html

this.modelContainer.appendChild(modelElement)

let modelNameEl = modelElement.querySelector(".model_name")

@@ -160,8 +165,8 @@

modelNameEl.addEventListener("change", makeUpdateEvent("change"))
modelNameEl.addEventListener("input", makeUpdateEvent("input"))
modelWeightEl.addEventListener("change", makeUpdateEvent("change"))
modelWeightEl.addEventListener("input", makeUpdateEvent("input"))
modelWeightEl?.addEventListener("change", makeUpdateEvent("change"))
modelWeightEl?.addEventListener("input", makeUpdateEvent("input"))

let removeBtn = document.createElement("button")
removeBtn.className = "remove_model_btn"

@@ -218,10 +223,14 @@
}

get modelWeights() {
return this.getModelElements(true).map((e) => e.weight.value)
return this.getModelElements(true).map((e) => e.weight?.value)
}

set modelWeights(newModelWeights) {
if (!this.showWeights) {
return
}

this.resizeEntryList(newModelWeights.length)

if (newModelWeights.length === 0) {