diff --git a/scripts/on_sd_start.bat b/scripts/on_sd_start.bat index b742c17c..80c373da 100644 --- a/scripts/on_sd_start.bat +++ b/scripts/on_sd_start.bat @@ -95,7 +95,7 @@ if "%ERRORLEVEL%" EQU "0" ( set PYTHONNOUSERSITE=1 set PYTHONPATH=%INSTALL_ENV_DIR%\lib\site-packages - call python -m pip install --upgrade sdkit==1.0.48 -q || ( + call python -m pip install --upgrade sdkit==1.0.49 -q || ( echo "Error updating sdkit" ) ) @@ -106,7 +106,7 @@ if "%ERRORLEVEL%" EQU "0" ( set PYTHONNOUSERSITE=1 set PYTHONPATH=%INSTALL_ENV_DIR%\lib\site-packages - call python -m pip install sdkit==1.0.48 || ( + call python -m pip install sdkit==1.0.49 || ( echo "Error installing sdkit. Sorry about that, please try to:" & echo " 1. Run this installer again." & echo " 2. If that doesn't fix it, please try the common troubleshooting steps at https://github.com/cmdr2/stable-diffusion-ui/wiki/Troubleshooting" & echo " 3. If those steps don't help, please copy *all* the error messages in this window, and ask the community at https://discord.com/invite/u9yhsFmEkB" & echo " 4. If that doesn't solve the problem, please file an issue at https://github.com/cmdr2/stable-diffusion-ui/issues" & echo "Thanks!" pause exit /b diff --git a/scripts/on_sd_start.sh b/scripts/on_sd_start.sh index b62a2fc2..32ebdf84 100755 --- a/scripts/on_sd_start.sh +++ b/scripts/on_sd_start.sh @@ -103,7 +103,7 @@ if python ../scripts/check_modules.py sdkit sdkit.models ldm transformers numpy export PYTHONNOUSERSITE=1 export PYTHONPATH="$INSTALL_ENV_DIR/lib/python3.8/site-packages" - python -m pip install --upgrade sdkit==1.0.48 -q + python -m pip install --upgrade sdkit==1.0.49 -q fi else echo "Installing sdkit: https://pypi.org/project/sdkit/" @@ -111,7 +111,7 @@ else export PYTHONNOUSERSITE=1 export PYTHONPATH="$INSTALL_ENV_DIR/lib/python3.8/site-packages" - if python -m pip install sdkit==1.0.48 ; then + if python -m pip install sdkit==1.0.49 ; then echo "Installed." else fail "sdkit install failed" diff --git a/ui/easydiffusion/model_manager.py b/ui/easydiffusion/model_manager.py index 116edf33..a06c56cf 100644 --- a/ui/easydiffusion/model_manager.py +++ b/ui/easydiffusion/model_manager.py @@ -7,13 +7,14 @@ from easydiffusion.utils import log from sdkit import Context from sdkit.models import load_model, unload_model, scan_model -KNOWN_MODEL_TYPES = ["stable-diffusion", "vae", "hypernetwork", "gfpgan", "realesrgan"] +KNOWN_MODEL_TYPES = ["stable-diffusion", "vae", "hypernetwork", "gfpgan", "realesrgan", "lora"] MODEL_EXTENSIONS = { "stable-diffusion": [".ckpt", ".safetensors"], "vae": [".vae.pt", ".ckpt", ".safetensors"], "hypernetwork": [".pt", ".safetensors"], "gfpgan": [".pth"], "realesrgan": [".pth"], + "lora": [".ckpt", ".safetensors"], } DEFAULT_MODELS = { "stable-diffusion": [ # needed to support the legacy installations @@ -23,7 +24,7 @@ DEFAULT_MODELS = { "gfpgan": ["GFPGANv1.3"], "realesrgan": ["RealESRGAN_x4plus"], } -MODELS_TO_LOAD_ON_START = ["stable-diffusion", "vae", "hypernetwork"] +MODELS_TO_LOAD_ON_START = ["stable-diffusion", "vae", "hypernetwork", "lora"] known_models = {} @@ -102,6 +103,7 @@ def reload_models_if_necessary(context: Context, task_data: TaskData): "gfpgan": task_data.use_face_correction, "realesrgan": task_data.use_upscale, "nsfw_checker": True if task_data.block_nsfw else None, + "lora": task_data.use_lora_model, } models_to_reload = { model_type: path @@ -125,6 +127,7 @@ def resolve_model_paths(task_data: TaskData): ) task_data.use_vae_model = resolve_model_to_use(task_data.use_vae_model, model_type="vae") task_data.use_hypernetwork_model = resolve_model_to_use(task_data.use_hypernetwork_model, model_type="hypernetwork") + task_data.use_lora_model = resolve_model_to_use(task_data.use_lora_model, model_type="lora") if task_data.use_face_correction: task_data.use_face_correction = resolve_model_to_use(task_data.use_face_correction, "gfpgan") @@ -184,11 +187,13 @@ def getModels(): "stable-diffusion": "sd-v1-4", "vae": "", "hypernetwork": "", + "lora": "", }, "options": { "stable-diffusion": ["sd-v1-4"], "vae": [], "hypernetwork": [], + "lora": [], }, } @@ -243,6 +248,7 @@ def getModels(): listModels(model_type="vae") listModels(model_type="hypernetwork") listModels(model_type="gfpgan") + listModels(model_type="lora") if models_scanned > 0: log.info(f"[green]Scanned {models_scanned} models. Nothing infected[/]") diff --git a/ui/easydiffusion/renderer.py b/ui/easydiffusion/renderer.py index c5dc88b4..fd8b7f7a 100644 --- a/ui/easydiffusion/renderer.py +++ b/ui/easydiffusion/renderer.py @@ -10,7 +10,7 @@ from easydiffusion.utils import get_printable_request, save_images_to_disk, log from sdkit import Context from sdkit.generate import generate_images from sdkit.filter import apply_filters -from sdkit.utils import img_to_buffer, img_to_base64_str, latent_samples_to_images, gc +from sdkit.utils import img_to_buffer, img_to_base64_str, latent_samples_to_images, diffusers_latent_samples_to_images context = Context() # thread-local """ @@ -26,6 +26,11 @@ def init(device): context.temp_images = {} context.partial_x_samples = None + from easydiffusion import app + + app_config = app.getConfig() + context.test_diffusers = app_config.get("test_diffusers", False) + device_manager.device_init(context, device) @@ -57,7 +62,13 @@ def make_images_internal( ): images, user_stopped = generate_images_internal( - req, task_data, data_queue, task_temp_images, step_callback, task_data.stream_image_progress, task_data.stream_image_progress_interval + req, + task_data, + data_queue, + task_temp_images, + step_callback, + task_data.stream_image_progress, + task_data.stream_image_progress_interval, ) filtered_images = filter_images(task_data, images, user_stopped) @@ -82,10 +93,18 @@ def generate_images_internal( ): context.temp_images.clear() - callback = make_step_callback(req, task_data, data_queue, task_temp_images, step_callback, stream_image_progress, stream_image_progress_interval) + callback = make_step_callback( + req, + task_data, + data_queue, + task_temp_images, + step_callback, + stream_image_progress, + stream_image_progress_interval, + ) try: - if req.init_image is not None: + if req.init_image is not None and not context.test_diffusers: req.sampler_name = "ddim" images = generate_images(context, callback=callback, **req.dict()) @@ -94,10 +113,14 @@ def generate_images_internal( images = [] user_stopped = True if context.partial_x_samples is not None: - images = latent_samples_to_images(context, context.partial_x_samples) + if context.test_diffusers: + images = diffusers_latent_samples_to_images(context, context.partial_x_samples) + else: + images = latent_samples_to_images(context, context.partial_x_samples) finally: if hasattr(context, "partial_x_samples") and context.partial_x_samples is not None: - del context.partial_x_samples + if not context.test_diffusers: + del context.partial_x_samples context.partial_x_samples = None return images, user_stopped @@ -145,7 +168,12 @@ def make_step_callback( def update_temp_img(x_samples, task_temp_images: list): partial_images = [] - images = latent_samples_to_images(context, x_samples) + + if context.test_diffusers: + images = diffusers_latent_samples_to_images(context, x_samples) + else: + images = latent_samples_to_images(context, x_samples) + if task_data.block_nsfw: images = apply_filters(context, "nsfw_checker", images) @@ -158,17 +186,21 @@ def make_step_callback( del images return partial_images - def on_image_step(x_samples, i): + def on_image_step(x_samples, i, *args): nonlocal last_callback_time - context.partial_x_samples = x_samples + if context.test_diffusers: + context.partial_x_samples = (x_samples, args[0]) + else: + context.partial_x_samples = x_samples + step_time = time.time() - last_callback_time if last_callback_time != -1 else -1 last_callback_time = time.time() progress = {"step": i, "step_time": step_time, "total_steps": n_steps} if stream_image_progress and stream_image_progress_interval > 0 and i % stream_image_progress_interval == 0: - progress["output"] = update_temp_img(x_samples, task_temp_images) + progress["output"] = update_temp_img(context.partial_x_samples, task_temp_images) data_queue.put(json.dumps(progress)) diff --git a/ui/easydiffusion/server.py b/ui/easydiffusion/server.py index 1d05a1f0..e27f9c5b 100644 --- a/ui/easydiffusion/server.py +++ b/ui/easydiffusion/server.py @@ -29,10 +29,10 @@ NOCACHE_HEADERS = {"Cache-Control": "no-cache, no-store, must-revalidate", "Prag class NoCacheStaticFiles(StaticFiles): def __init__(self, directory: str): # follow_symlink is only available on fastapi >= 0.92.0 - if (os.path.islink(directory)): - super().__init__(directory = os.path.realpath(directory)) + if os.path.islink(directory): + super().__init__(directory=os.path.realpath(directory)) else: - super().__init__(directory = directory) + super().__init__(directory=directory) def is_not_modified(self, response_headers, request_headers) -> bool: if "content-type" in response_headers and ( @@ -51,11 +51,12 @@ class SetAppConfigRequest(BaseModel): ui_open_browser_on_start: bool = None listen_to_network: bool = None listen_port: int = None + test_diffusers: bool = False def init(): mimetypes.init() - mimetypes.add_type('text/css', '.css') + mimetypes.add_type("text/css", ".css") if os.path.isdir(app.CUSTOM_MODIFIERS_DIR): server_api.mount( @@ -132,6 +133,9 @@ def set_app_config_internal(req: SetAppConfigRequest): if "net" not in config: config["net"] = {} config["net"]["listen_port"] = int(req.listen_port) + + config["test_diffusers"] = req.test_diffusers + try: app.setConfig(config) diff --git a/ui/easydiffusion/types.py b/ui/easydiffusion/types.py index 8e7044f3..ebacc864 100644 --- a/ui/easydiffusion/types.py +++ b/ui/easydiffusion/types.py @@ -21,6 +21,7 @@ class GenerateImageRequest(BaseModel): sampler_name: str = None # "ddim", "plms", "heun", "euler", "euler_a", "dpm2", "dpm2_a", "lms" hypernetwork_strength: float = 0 + lora_alpha: float = 0 class TaskData(BaseModel): @@ -36,6 +37,7 @@ class TaskData(BaseModel): # use_stable_diffusion_config: str = "v1-inference" use_vae_model: str = None use_hypernetwork_model: str = None + use_lora_model: str = None show_only_filtered_image: bool = False block_nsfw: bool = False diff --git a/ui/index.html b/ui/index.html index 7112c022..8821b75e 100644 --- a/ui/index.html +++ b/ui/index.html @@ -162,7 +162,7 @@ - + Click to learn more about samplers @@ -217,6 +217,13 @@