diff --git a/scripts/on_sd_start.bat b/scripts/on_sd_start.bat index b742c17c..80c373da 100644 --- a/scripts/on_sd_start.bat +++ b/scripts/on_sd_start.bat @@ -95,7 +95,7 @@ if "%ERRORLEVEL%" EQU "0" ( set PYTHONNOUSERSITE=1 set PYTHONPATH=%INSTALL_ENV_DIR%\lib\site-packages - call python -m pip install --upgrade sdkit==1.0.48 -q || ( + call python -m pip install --upgrade sdkit==1.0.49 -q || ( echo "Error updating sdkit" ) ) @@ -106,7 +106,7 @@ if "%ERRORLEVEL%" EQU "0" ( set PYTHONNOUSERSITE=1 set PYTHONPATH=%INSTALL_ENV_DIR%\lib\site-packages - call python -m pip install sdkit==1.0.48 || ( + call python -m pip install sdkit==1.0.49 || ( echo "Error installing sdkit. Sorry about that, please try to:" & echo " 1. Run this installer again." & echo " 2. If that doesn't fix it, please try the common troubleshooting steps at https://github.com/cmdr2/stable-diffusion-ui/wiki/Troubleshooting" & echo " 3. If those steps don't help, please copy *all* the error messages in this window, and ask the community at https://discord.com/invite/u9yhsFmEkB" & echo " 4. If that doesn't solve the problem, please file an issue at https://github.com/cmdr2/stable-diffusion-ui/issues" & echo "Thanks!" pause exit /b diff --git a/scripts/on_sd_start.sh b/scripts/on_sd_start.sh index b62a2fc2..32ebdf84 100755 --- a/scripts/on_sd_start.sh +++ b/scripts/on_sd_start.sh @@ -103,7 +103,7 @@ if python ../scripts/check_modules.py sdkit sdkit.models ldm transformers numpy export PYTHONNOUSERSITE=1 export PYTHONPATH="$INSTALL_ENV_DIR/lib/python3.8/site-packages" - python -m pip install --upgrade sdkit==1.0.48 -q + python -m pip install --upgrade sdkit==1.0.49 -q fi else echo "Installing sdkit: https://pypi.org/project/sdkit/" @@ -111,7 +111,7 @@ else export PYTHONNOUSERSITE=1 export PYTHONPATH="$INSTALL_ENV_DIR/lib/python3.8/site-packages" - if python -m pip install sdkit==1.0.48 ; then + if python -m pip install sdkit==1.0.49 ; then echo "Installed." else fail "sdkit install failed" diff --git a/ui/easydiffusion/model_manager.py b/ui/easydiffusion/model_manager.py index 116edf33..a06c56cf 100644 --- a/ui/easydiffusion/model_manager.py +++ b/ui/easydiffusion/model_manager.py @@ -7,13 +7,14 @@ from easydiffusion.utils import log from sdkit import Context from sdkit.models import load_model, unload_model, scan_model -KNOWN_MODEL_TYPES = ["stable-diffusion", "vae", "hypernetwork", "gfpgan", "realesrgan"] +KNOWN_MODEL_TYPES = ["stable-diffusion", "vae", "hypernetwork", "gfpgan", "realesrgan", "lora"] MODEL_EXTENSIONS = { "stable-diffusion": [".ckpt", ".safetensors"], "vae": [".vae.pt", ".ckpt", ".safetensors"], "hypernetwork": [".pt", ".safetensors"], "gfpgan": [".pth"], "realesrgan": [".pth"], + "lora": [".ckpt", ".safetensors"], } DEFAULT_MODELS = { "stable-diffusion": [ # needed to support the legacy installations @@ -23,7 +24,7 @@ DEFAULT_MODELS = { "gfpgan": ["GFPGANv1.3"], "realesrgan": ["RealESRGAN_x4plus"], } -MODELS_TO_LOAD_ON_START = ["stable-diffusion", "vae", "hypernetwork"] +MODELS_TO_LOAD_ON_START = ["stable-diffusion", "vae", "hypernetwork", "lora"] known_models = {} @@ -102,6 +103,7 @@ def reload_models_if_necessary(context: Context, task_data: TaskData): "gfpgan": task_data.use_face_correction, "realesrgan": task_data.use_upscale, "nsfw_checker": True if task_data.block_nsfw else None, + "lora": task_data.use_lora_model, } models_to_reload = { model_type: path @@ -125,6 +127,7 @@ def resolve_model_paths(task_data: TaskData): ) task_data.use_vae_model = resolve_model_to_use(task_data.use_vae_model, model_type="vae") task_data.use_hypernetwork_model = resolve_model_to_use(task_data.use_hypernetwork_model, model_type="hypernetwork") + task_data.use_lora_model = resolve_model_to_use(task_data.use_lora_model, model_type="lora") if task_data.use_face_correction: task_data.use_face_correction = resolve_model_to_use(task_data.use_face_correction, "gfpgan") @@ -184,11 +187,13 @@ def getModels(): "stable-diffusion": "sd-v1-4", "vae": "", "hypernetwork": "", + "lora": "", }, "options": { "stable-diffusion": ["sd-v1-4"], "vae": [], "hypernetwork": [], + "lora": [], }, } @@ -243,6 +248,7 @@ def getModels(): listModels(model_type="vae") listModels(model_type="hypernetwork") listModels(model_type="gfpgan") + listModels(model_type="lora") if models_scanned > 0: log.info(f"[green]Scanned {models_scanned} models. Nothing infected[/]") diff --git a/ui/easydiffusion/renderer.py b/ui/easydiffusion/renderer.py index c5dc88b4..fd8b7f7a 100644 --- a/ui/easydiffusion/renderer.py +++ b/ui/easydiffusion/renderer.py @@ -10,7 +10,7 @@ from easydiffusion.utils import get_printable_request, save_images_to_disk, log from sdkit import Context from sdkit.generate import generate_images from sdkit.filter import apply_filters -from sdkit.utils import img_to_buffer, img_to_base64_str, latent_samples_to_images, gc +from sdkit.utils import img_to_buffer, img_to_base64_str, latent_samples_to_images, diffusers_latent_samples_to_images context = Context() # thread-local """ @@ -26,6 +26,11 @@ def init(device): context.temp_images = {} context.partial_x_samples = None + from easydiffusion import app + + app_config = app.getConfig() + context.test_diffusers = app_config.get("test_diffusers", False) + device_manager.device_init(context, device) @@ -57,7 +62,13 @@ def make_images_internal( ): images, user_stopped = generate_images_internal( - req, task_data, data_queue, task_temp_images, step_callback, task_data.stream_image_progress, task_data.stream_image_progress_interval + req, + task_data, + data_queue, + task_temp_images, + step_callback, + task_data.stream_image_progress, + task_data.stream_image_progress_interval, ) filtered_images = filter_images(task_data, images, user_stopped) @@ -82,10 +93,18 @@ def generate_images_internal( ): context.temp_images.clear() - callback = make_step_callback(req, task_data, data_queue, task_temp_images, step_callback, stream_image_progress, stream_image_progress_interval) + callback = make_step_callback( + req, + task_data, + data_queue, + task_temp_images, + step_callback, + stream_image_progress, + stream_image_progress_interval, + ) try: - if req.init_image is not None: + if req.init_image is not None and not context.test_diffusers: req.sampler_name = "ddim" images = generate_images(context, callback=callback, **req.dict()) @@ -94,10 +113,14 @@ def generate_images_internal( images = [] user_stopped = True if context.partial_x_samples is not None: - images = latent_samples_to_images(context, context.partial_x_samples) + if context.test_diffusers: + images = diffusers_latent_samples_to_images(context, context.partial_x_samples) + else: + images = latent_samples_to_images(context, context.partial_x_samples) finally: if hasattr(context, "partial_x_samples") and context.partial_x_samples is not None: - del context.partial_x_samples + if not context.test_diffusers: + del context.partial_x_samples context.partial_x_samples = None return images, user_stopped @@ -145,7 +168,12 @@ def make_step_callback( def update_temp_img(x_samples, task_temp_images: list): partial_images = [] - images = latent_samples_to_images(context, x_samples) + + if context.test_diffusers: + images = diffusers_latent_samples_to_images(context, x_samples) + else: + images = latent_samples_to_images(context, x_samples) + if task_data.block_nsfw: images = apply_filters(context, "nsfw_checker", images) @@ -158,17 +186,21 @@ def make_step_callback( del images return partial_images - def on_image_step(x_samples, i): + def on_image_step(x_samples, i, *args): nonlocal last_callback_time - context.partial_x_samples = x_samples + if context.test_diffusers: + context.partial_x_samples = (x_samples, args[0]) + else: + context.partial_x_samples = x_samples + step_time = time.time() - last_callback_time if last_callback_time != -1 else -1 last_callback_time = time.time() progress = {"step": i, "step_time": step_time, "total_steps": n_steps} if stream_image_progress and stream_image_progress_interval > 0 and i % stream_image_progress_interval == 0: - progress["output"] = update_temp_img(x_samples, task_temp_images) + progress["output"] = update_temp_img(context.partial_x_samples, task_temp_images) data_queue.put(json.dumps(progress)) diff --git a/ui/easydiffusion/server.py b/ui/easydiffusion/server.py index 1d05a1f0..e27f9c5b 100644 --- a/ui/easydiffusion/server.py +++ b/ui/easydiffusion/server.py @@ -29,10 +29,10 @@ NOCACHE_HEADERS = {"Cache-Control": "no-cache, no-store, must-revalidate", "Prag class NoCacheStaticFiles(StaticFiles): def __init__(self, directory: str): # follow_symlink is only available on fastapi >= 0.92.0 - if (os.path.islink(directory)): - super().__init__(directory = os.path.realpath(directory)) + if os.path.islink(directory): + super().__init__(directory=os.path.realpath(directory)) else: - super().__init__(directory = directory) + super().__init__(directory=directory) def is_not_modified(self, response_headers, request_headers) -> bool: if "content-type" in response_headers and ( @@ -51,11 +51,12 @@ class SetAppConfigRequest(BaseModel): ui_open_browser_on_start: bool = None listen_to_network: bool = None listen_port: int = None + test_diffusers: bool = False def init(): mimetypes.init() - mimetypes.add_type('text/css', '.css') + mimetypes.add_type("text/css", ".css") if os.path.isdir(app.CUSTOM_MODIFIERS_DIR): server_api.mount( @@ -132,6 +133,9 @@ def set_app_config_internal(req: SetAppConfigRequest): if "net" not in config: config["net"] = {} config["net"]["listen_port"] = int(req.listen_port) + + config["test_diffusers"] = req.test_diffusers + try: app.setConfig(config) diff --git a/ui/easydiffusion/types.py b/ui/easydiffusion/types.py index 8e7044f3..ebacc864 100644 --- a/ui/easydiffusion/types.py +++ b/ui/easydiffusion/types.py @@ -21,6 +21,7 @@ class GenerateImageRequest(BaseModel): sampler_name: str = None # "ddim", "plms", "heun", "euler", "euler_a", "dpm2", "dpm2_a", "lms" hypernetwork_strength: float = 0 + lora_alpha: float = 0 class TaskData(BaseModel): @@ -36,6 +37,7 @@ class TaskData(BaseModel): # use_stable_diffusion_config: str = "v1-inference" use_vae_model: str = None use_hypernetwork_model: str = None + use_lora_model: str = None show_only_filtered_image: bool = False block_nsfw: bool = False diff --git a/ui/index.html b/ui/index.html index 7112c022..8821b75e 100644 --- a/ui/index.html +++ b/ui/index.html @@ -162,7 +162,7 @@ - + Click to learn more about samplers @@ -217,6 +217,13 @@
+ + + + + +
+ diff --git a/ui/media/js/auto-save.js b/ui/media/js/auto-save.js index 4179263e..f791ef74 100644 --- a/ui/media/js/auto-save.js +++ b/ui/media/js/auto-save.js @@ -15,6 +15,7 @@ const SETTINGS_IDS_LIST = [ "stable_diffusion_model", "vae_model", "hypernetwork_model", + "lora_model", "sampler_name", "width", "height", @@ -22,6 +23,7 @@ const SETTINGS_IDS_LIST = [ "guidance_scale", "prompt_strength", "hypernetwork_strength", + "lora_alpha", "output_format", "output_quality", "negative_prompt", diff --git a/ui/media/js/main.js b/ui/media/js/main.js index 55c42af6..1b70a678 100644 --- a/ui/media/js/main.js +++ b/ui/media/js/main.js @@ -46,6 +46,9 @@ let vaeModelField = new ModelDropdown(document.querySelector('#vae_model'), 'vae let hypernetworkModelField = new ModelDropdown(document.querySelector('#hypernetwork_model'), 'hypernetwork', 'None') let hypernetworkStrengthSlider = document.querySelector('#hypernetwork_strength_slider') let hypernetworkStrengthField = document.querySelector('#hypernetwork_strength') +let loraModelField = new ModelDropdown(document.querySelector('#lora_model'), 'lora', 'None') +let loraAlphaSlider = document.querySelector('#lora_alpha_slider') +let loraAlphaField = document.querySelector('#lora_alpha') let outputFormatField = document.querySelector('#output_format') let blockNSFWField = document.querySelector('#block_nsfw') let showOnlyFilteredImageField = document.querySelector("#show_only_filtered_image") @@ -931,6 +934,9 @@ function createTask(task) { taskConfig += `, Hypernetwork: ${task.reqBody.use_hypernetwork_model}` taskConfig += `, Hypernetwork Strength: ${task.reqBody.hypernetwork_strength}` } + if (task.reqBody.use_lora_model) { + taskConfig += `, LoRA: ${task.reqBody.use_lora_model}` + } if (task.reqBody.preserve_init_image_color_profile) { taskConfig += `, Preserve Color Profile: true` } @@ -1041,9 +1047,11 @@ function getCurrentUserRequest() { height: parseInt(heightField.value), // allow_nsfw: allowNSFWField.checked, vram_usage_level: vramUsageLevelField.value, + sampler_name: samplerField.value, //render_device: undefined, // Set device affinity. Prefer this device, but wont activate. use_stable_diffusion_model: stableDiffusionModelField.value, use_vae_model: vaeModelField.value, + use_lora_model: loraModelField.value, stream_progress_updates: true, stream_image_progress: (numOutputsTotal > 50 ? false : streamImageProgressField.checked), show_only_filtered_image: showOnlyFilteredImageField.checked, @@ -1067,9 +1075,9 @@ function getCurrentUserRequest() { newTask.reqBody.mask = imageInpainter.getImg() } newTask.reqBody.preserve_init_image_color_profile = applyColorCorrectionField.checked - newTask.reqBody.sampler_name = 'ddim' - } else { - newTask.reqBody.sampler_name = samplerField.value + if (!testDiffusers.checked) { + newTask.reqBody.sampler_name = 'ddim' + } } if (saveToDiskField.checked && diskPathField.value.trim() !== '') { newTask.reqBody.save_to_disk_path = diskPathField.value.trim() @@ -1458,6 +1466,34 @@ function updateHypernetworkStrengthContainer() { hypernetworkModelField.addEventListener('change', updateHypernetworkStrengthContainer) updateHypernetworkStrengthContainer() +/********************* LoRA alpha **********************/ +function updateLoraAlpha() { + loraAlphaField.value = loraAlphaSlider.value / 100 + loraAlphaField.dispatchEvent(new Event("change")) +} + +function updateLoraAlphaSlider() { + if (loraAlphaField.value < 0) { + loraAlphaField.value = 0 + } else if (loraAlphaField.value > 0.99) { + loraAlphaField.value = 0.99 + } + + loraAlphaSlider.value = loraAlphaField.value * 100 + loraAlphaSlider.dispatchEvent(new Event("change")) +} + +loraAlphaSlider.addEventListener('input', updateLoraAlpha) +loraAlphaField.addEventListener('input', updateLoraAlphaSlider) +updateLoraAlpha() + +// function updateLoraAlphaContainer() { +// document.querySelector("#lora_alpha_container").style.display = (loraModelField.value === "" ? 'none' : '') +// } +// loraModelField.addEventListener('change', updateLoraAlphaContainer) +// updateLoraAlphaContainer() +document.querySelector("#lora_alpha_container").style.display = 'none' + /********************* JPEG/WEBP Quality **********************/ function updateOutputQuality() { outputQualityField.value = 0 | outputQualitySlider.value @@ -1550,7 +1586,9 @@ loadImg2ImgFromFile() function img2imgLoad() { promptStrengthContainer.style.display = 'table-row' - samplerSelectionContainer.style.display = "none" + if (!testDiffusers.checked) { + samplerSelectionContainer.style.display = "none" + } initImagePreviewContainer.classList.add("has-image") colorCorrectionSetting.style.display = '' @@ -1565,7 +1603,9 @@ function img2imgUnload() { maskSetting.checked = false promptStrengthContainer.style.display = "none" - samplerSelectionContainer.style.display = "" + if (!testDiffusers.checked) { + samplerSelectionContainer.style.display = "" + } initImagePreviewContainer.classList.remove("has-image") colorCorrectionSetting.style.display = 'none' imageEditor.setImage(null, parseInt(widthField.value), parseInt(heightField.value)) diff --git a/ui/media/js/parameters.js b/ui/media/js/parameters.js index 0e4b9b22..cb0c7a4c 100644 --- a/ui/media/js/parameters.js +++ b/ui/media/js/parameters.js @@ -190,6 +190,14 @@ var PARAMETERS = [ icon: "fa-fire", default: false, }, + { + id: "test_diffusers", + type: ParameterType.checkbox, + label: "Test Diffusers", + note: "Experimental! Can have bugs! Use upcoming features (like LoRA) in our new engine. Please press Save, then restart the program after changing this.", + icon: "fa-bolt", + default: false, + }, ]; function getParameterSettingsEntry(id) { @@ -263,6 +271,7 @@ let listenPortField = document.querySelector("#listen_port") let useBetaChannelField = document.querySelector("#use_beta_channel") let uiOpenBrowserOnStartField = document.querySelector("#ui_open_browser_on_start") let confirmDangerousActionsField = document.querySelector("#confirm_dangerous_actions") +let testDiffusers = document.querySelector("#test_diffusers") let saveSettingsBtn = document.querySelector('#save-system-settings-btn') @@ -302,6 +311,10 @@ async function getAppConfig() { if (config.net && config.net.listen_port !== undefined) { listenPortField.value = config.net.listen_port } + if (config.test_diffusers !== undefined) { + testDiffusers.checked = config.test_diffusers + document.querySelector("#lora_model_container").style.display = (testDiffusers.checked ? '' : 'none') + } console.log('get config status response', config) } catch (e) { @@ -471,7 +484,8 @@ saveSettingsBtn.addEventListener('click', function() { 'update_branch': updateBranch, 'ui_open_browser_on_start': uiOpenBrowserOnStartField.checked, 'listen_to_network': listenToNetworkField.checked, - 'listen_port': listenPortField.value + 'listen_port': listenPortField.value, + 'test_diffusers': testDiffusers.checked }) saveSettingsBtn.classList.add('active') asyncDelay(300).then(() => saveSettingsBtn.classList.remove('active'))