forked from extern/easydiffusion
commit
118a4862ab
20
CHANGES.md
20
CHANGES.md
@ -5,7 +5,7 @@
|
||||
- **Nearly twice as fast** - significantly faster speed of image generation. We're now pretty close to automatic1111's speed. Code contributions are welcome to make our project even faster: https://github.com/easydiffusion/sdkit/#is-it-fast
|
||||
- **Full support for Stable Diffusion 2.1 (including CPU)** - supports loading v1.4 or v2.0 or v2.1 models seamlessly. No need to enable "Test SD2", and no need to add `sd2_` to your SD 2.0 model file names. Works on CPU as well.
|
||||
- **Memory optimized Stable Diffusion 2.1** - you can now use Stable Diffusion 2.1 models, with the same low VRAM optimizations that we've always had for SD 1.4. Please note, the SD 2.0 and 2.1 models require more GPU and System RAM, as compared to the SD 1.4 and 1.5 models.
|
||||
- **6 new samplers!** - explore the new samplers, some of which can generate great images in less than 10 inference steps!
|
||||
- **11 new samplers!** - explore the new samplers, some of which can generate great images in less than 10 inference steps! We've added the Karras and UniPC samplers.
|
||||
- **Model Merging** - You can now merge two models (`.ckpt` or `.safetensors`) and output `.ckpt` or `.safetensors` models, optionally in `fp16` precision. Details: https://github.com/cmdr2/stable-diffusion-ui/wiki/Model-Merging
|
||||
- **Fast loading/unloading of VAEs** - No longer needs to reload the entire Stable Diffusion model, each time you change the VAE
|
||||
- **Database of known models** - automatically picks the right configuration for known models. E.g. we automatically detect and apply "v" parameterization (required for some SD 2.0 models), and "fp32" attention precision (required for some SD 2.1 models).
|
||||
@ -19,6 +19,24 @@
|
||||
Our focus continues to remain on an easy installation experience, and an easy user-interface. While still remaining pretty powerful, in terms of features and speed.
|
||||
|
||||
### Detailed changelog
|
||||
* 2.5.22 - 28 Feb 2023 - Minor styling changes to UI buttons, and the models dropdown.
|
||||
* 2.5.22 - 28 Feb 2023 - Lots of UI-related bug fixes. Thanks @patriceac.
|
||||
* 2.5.21 - 22 Feb 2023 - An option to control the size of the image thumbnails. You can use the `Display options` in the top-right corner to change this. Thanks @JeLuf.
|
||||
* 2.5.20 - 20 Feb 2023 - Support saving images in WEBP format (which consumes less disk space, with similar quality). Thanks @ogmaresca.
|
||||
* 2.5.20 - 18 Feb 2023 - A setting to block NSFW images from being generated. You can enable this setting in the Settings tab.
|
||||
* 2.5.19 - 17 Feb 2023 - Initial support for server-side plugins. Currently supports overriding the `get_cond_and_uncond()` function.
|
||||
* 2.5.18 - 17 Feb 2023 - 5 new samplers! UniPC samplers, some of which produce images in less than 15 steps. Thanks @Schorny.
|
||||
* 2.5.16 - 13 Feb 2023 - Searchable dropdown for models. This is useful if you have a LOT of models. You can type part of the model name, to auto-search through your models. Thanks @patriceac for the feature, and @AssassinJN for help in UI tweaks!
|
||||
* 2.5.16 - 13 Feb 2023 - Lots of fixes and improvements to the installer. First round of changes to add Mac support. Thanks @JeLuf.
|
||||
* 2.5.16 - 13 Feb 2023 - UI bug fixes for the inpainter editor. Thanks @patriceac.
|
||||
* 2.5.16 - 13 Feb 2023 - Fix broken task reorder. Thanks @JeLuf.
|
||||
* 2.5.16 - 13 Feb 2023 - Remove a task if all the images inside it have been removed. Thanks @AssassinJN.
|
||||
* 2.5.16 - 10 Feb 2023 - Embed metadata into the JPG/PNG images, if selected in the "Settings" tab (under "Metadata format"). Thanks @patriceac.
|
||||
* 2.5.16 - 10 Feb 2023 - Sort models alphabetically in the models dropdown. Thanks @ogmaresca.
|
||||
* 2.5.16 - 10 Feb 2023 - Support multiple GFPGAN models. Download new GFPGAN models into the `models/gfpgan` folder, and refresh the UI to use it. Thanks @JeLuf.
|
||||
* 2.5.16 - 10 Feb 2023 - Allow a server to enforce a fixed directory path to save images. This is useful if the server is exposed to a lot of users. This can be set in the `config.json` file as `force_save_path: "/path/to/fixed/save/dir"`. E.g. `force_save_path: "D:/user_images"`. Thanks @JeLuf.
|
||||
* 2.5.16 - 10 Feb 2023 - The "Make Images" button now shows the correct amount of images it'll create when using operators like `{}` or `|`. For e.g. if the prompt is `Photo of a {woman, man}`, then the button will say `Make 2 Images`. Thanks @JeLuf.
|
||||
* 2.5.16 - 10 Feb 2023 - A bunch of UI-related bug fixes. Thanks @patriceac.
|
||||
* 2.5.15 - 8 Feb 2023 - Allow using 'balanced' VRAM usage mode on GPUs with 4 GB or less of VRAM. This mode used to be called 'Turbo' in the previous version.
|
||||
* 2.5.14 - 8 Feb 2023 - Fix broken auto-save settings. We renamed `sampler` to `sampler_name`, which caused old settings to fail.
|
||||
* 2.5.14 - 6 Feb 2023 - Simplify the UI for merging models, and some other minor UI tweaks. Better error reporting if a model failed to load.
|
||||
|
@ -1,8 +1,27 @@
|
||||
@echo off
|
||||
|
||||
cd /d %~dp0
|
||||
echo Install dir: %~dp0
|
||||
|
||||
set PATH=C:\Windows\System32;%PATH%
|
||||
|
||||
if exist "on_sd_start.bat" (
|
||||
echo ================================================================================
|
||||
echo.
|
||||
echo !!!! WARNING !!!!
|
||||
echo.
|
||||
echo It looks like you're trying to run the installation script from a source code
|
||||
echo download. This will not work.
|
||||
echo.
|
||||
echo Recommended: Please close this window and download the installer from
|
||||
echo https://stable-diffusion-ui.github.io/docs/installation/
|
||||
echo.
|
||||
echo ================================================================================
|
||||
echo.
|
||||
pause
|
||||
exit /b
|
||||
)
|
||||
|
||||
@rem set legacy installer's PATH, if it exists
|
||||
if exist "installer" set PATH=%cd%\installer;%cd%\installer\Library\bin;%cd%\installer\Scripts;%cd%\installer\Library\usr\bin;%PATH%
|
||||
|
||||
|
@ -25,6 +25,15 @@ case "${OS_ARCH}" in
|
||||
*) echo "Unknown system architecture: $OS_ARCH! This script runs only on x86_64 or arm64" && exit
|
||||
esac
|
||||
|
||||
if ! which curl; then fail "'curl' not found. Please install curl."; fi
|
||||
if ! which tar; then fail "'tar' not found. Please install tar."; fi
|
||||
if ! which bzip2; then fail "'bzip2' not found. Please install bzip2."; fi
|
||||
|
||||
if pwd | grep ' '; then fail "The installation directory's path contains a space character. Conda will fail to install. Please change the directory."; fi
|
||||
if [ -f /proc/cpuinfo ]; then
|
||||
if ! cat /proc/cpuinfo | grep avx | uniq; then fail "Your CPU doesn't support AVX."; fi
|
||||
fi
|
||||
|
||||
# https://mamba.readthedocs.io/en/latest/installation.html
|
||||
if [ "$OS_NAME" == "linux" ] && [ "$OS_ARCH" == "arm64" ]; then OS_ARCH="aarch64"; fi
|
||||
|
||||
@ -52,7 +61,7 @@ if [ "$PACKAGES_TO_INSTALL" != "" ]; then
|
||||
echo "Downloading micromamba from $MICROMAMBA_DOWNLOAD_URL to $MAMBA_ROOT_PREFIX/micromamba"
|
||||
|
||||
mkdir -p "$MAMBA_ROOT_PREFIX"
|
||||
curl -L "$MICROMAMBA_DOWNLOAD_URL" | tar -xvj bin/micromamba -O > "$MAMBA_ROOT_PREFIX/micromamba"
|
||||
curl -L "$MICROMAMBA_DOWNLOAD_URL" | tar -xvj -O bin/micromamba > "$MAMBA_ROOT_PREFIX/micromamba"
|
||||
|
||||
if [ "$?" != "0" ]; then
|
||||
echo
|
||||
|
@ -28,5 +28,12 @@ EOF
|
||||
|
||||
}
|
||||
|
||||
filesize() {
|
||||
case "$(uname -s)" in
|
||||
Linux*) stat -c "%s" $1;;
|
||||
Darwin*) stat -f "%z" $1;;
|
||||
*) echo "Unknown OS: $OS_NAME! This script runs only on Linux or Mac" && exit
|
||||
esac
|
||||
}
|
||||
|
||||
|
||||
|
@ -1,6 +1,6 @@
|
||||
@echo off
|
||||
|
||||
@echo. & echo "Stable Diffusion UI - v2" & echo.
|
||||
@echo. & echo "Easy Diffusion - v2" & echo.
|
||||
|
||||
set PATH=C:\Windows\System32;%PATH%
|
||||
|
||||
@ -28,7 +28,7 @@ if "%update_branch%"=="" (
|
||||
|
||||
@>nul findstr /m "sd_ui_git_cloned" scripts\install_status.txt
|
||||
@if "%ERRORLEVEL%" EQU "0" (
|
||||
@echo "Stable Diffusion UI's git repository was already installed. Updating from %update_branch%.."
|
||||
@echo "Easy Diffusion's git repository was already installed. Updating from %update_branch%.."
|
||||
|
||||
@cd sd-ui-files
|
||||
|
||||
@ -38,13 +38,13 @@ if "%update_branch%"=="" (
|
||||
|
||||
@cd ..
|
||||
) else (
|
||||
@echo. & echo "Downloading Stable Diffusion UI.." & echo.
|
||||
@echo. & echo "Downloading Easy Diffusion..." & echo.
|
||||
@echo "Using the %update_branch% channel" & echo.
|
||||
|
||||
@call git clone -b "%update_branch%" https://github.com/cmdr2/stable-diffusion-ui.git sd-ui-files && (
|
||||
@echo sd_ui_git_cloned >> scripts\install_status.txt
|
||||
) || (
|
||||
@echo "Error downloading Stable Diffusion UI. Sorry about that, please try to:" & echo " 1. Run this installer again." & echo " 2. If that doesn't fix it, please try the common troubleshooting steps at https://github.com/cmdr2/stable-diffusion-ui/wiki/Troubleshooting" & echo " 3. If those steps don't help, please copy *all* the error messages in this window, and ask the community at https://discord.com/invite/u9yhsFmEkB" & echo " 4. If that doesn't solve the problem, please file an issue at https://github.com/cmdr2/stable-diffusion-ui/issues" & echo "Thanks!"
|
||||
@echo "Error downloading Easy Diffusion. Sorry about that, please try to:" & echo " 1. Run this installer again." & echo " 2. If that doesn't fix it, please try the common troubleshooting steps at https://github.com/cmdr2/stable-diffusion-ui/wiki/Troubleshooting" & echo " 3. If those steps don't help, please copy *all* the error messages in this window, and ask the community at https://discord.com/invite/u9yhsFmEkB" & echo " 4. If that doesn't solve the problem, please file an issue at https://github.com/cmdr2/stable-diffusion-ui/issues" & echo "Thanks!"
|
||||
pause
|
||||
@exit /b
|
||||
)
|
||||
|
@ -2,7 +2,7 @@
|
||||
|
||||
source ./scripts/functions.sh
|
||||
|
||||
printf "\n\nStable Diffusion UI\n\n"
|
||||
printf "\n\nEasy Diffusion\n\n"
|
||||
|
||||
if [ -f "scripts/config.sh" ]; then
|
||||
source scripts/config.sh
|
||||
@ -13,7 +13,7 @@ if [ "$update_branch" == "" ]; then
|
||||
fi
|
||||
|
||||
if [ -f "scripts/install_status.txt" ] && [ `grep -c sd_ui_git_cloned scripts/install_status.txt` -gt "0" ]; then
|
||||
echo "Stable Diffusion UI's git repository was already installed. Updating from $update_branch.."
|
||||
echo "Easy Diffusion's git repository was already installed. Updating from $update_branch.."
|
||||
|
||||
cd sd-ui-files
|
||||
|
||||
@ -23,7 +23,7 @@ if [ -f "scripts/install_status.txt" ] && [ `grep -c sd_ui_git_cloned scripts/in
|
||||
|
||||
cd ..
|
||||
else
|
||||
printf "\n\nDownloading Stable Diffusion UI..\n\n"
|
||||
printf "\n\nDownloading Easy Diffusion..\n\n"
|
||||
printf "Using the $update_branch channel\n\n"
|
||||
|
||||
if git clone -b "$update_branch" https://github.com/cmdr2/stable-diffusion-ui.git sd-ui-files ; then
|
||||
@ -40,7 +40,6 @@ cp sd-ui-files/scripts/bootstrap.sh scripts/
|
||||
cp sd-ui-files/scripts/check_modules.py scripts/
|
||||
cp sd-ui-files/scripts/start.sh .
|
||||
cp sd-ui-files/scripts/developer_console.sh .
|
||||
cp sd-ui-files/scripts/functions.sh scripts/
|
||||
|
||||
./scripts/on_sd_start.sh
|
||||
|
||||
read -p "Press any key to continue"
|
||||
exec ./scripts/on_sd_start.sh
|
||||
|
@ -26,7 +26,7 @@ if exist "%cd%\stable-diffusion\env" (
|
||||
@rem activate the installer env
|
||||
call conda activate
|
||||
@if "%ERRORLEVEL%" NEQ "0" (
|
||||
@echo. & echo "Error activating conda for Stable Diffusion. Sorry about that, please try to:" & echo " 1. Run this installer again." & echo " 2. If that doesn't fix it, please try the common troubleshooting steps at https://github.com/cmdr2/stable-diffusion-ui/wiki/Troubleshooting" & echo " 3. If those steps don't help, please copy *all* the error messages in this window, and ask the community at https://discord.com/invite/u9yhsFmEkB" & echo " 4. If that doesn't solve the problem, please file an issue at https://github.com/cmdr2/stable-diffusion-ui/issues" & echo "Thanks!" & echo.
|
||||
@echo. & echo "Error activating conda for Easy Diffusion. Sorry about that, please try to:" & echo " 1. Run this installer again." & echo " 2. If that doesn't fix it, please try the common troubleshooting steps at https://github.com/cmdr2/stable-diffusion-ui/wiki/Troubleshooting" & echo " 3. If those steps don't help, please copy *all* the error messages in this window, and ask the community at https://discord.com/invite/u9yhsFmEkB" & echo " 4. If that doesn't solve the problem, please file an issue at https://github.com/cmdr2/stable-diffusion-ui/issues" & echo "Thanks!" & echo.
|
||||
pause
|
||||
exit /b
|
||||
)
|
||||
@ -61,6 +61,9 @@ if exist "GFPGANv1.3.pth" move GFPGANv1.3.pth ..\models\gfpgan\
|
||||
if exist "RealESRGAN_x4plus.pth" move RealESRGAN_x4plus.pth ..\models\realesrgan\
|
||||
if exist "RealESRGAN_x4plus_anime_6B.pth" move RealESRGAN_x4plus_anime_6B.pth ..\models\realesrgan\
|
||||
|
||||
if not exist "%INSTALL_ENV_DIR%\DLLs\libssl-1_1-x64.dll" copy "%INSTALL_ENV_DIR%\Library\bin\libssl-1_1-x64.dll" "%INSTALL_ENV_DIR%\DLLs\"
|
||||
if not exist "%INSTALL_ENV_DIR%\DLLs\libcrypto-1_1-x64.dll" copy "%INSTALL_ENV_DIR%\Library\bin\libcrypto-1_1-x64.dll" "%INSTALL_ENV_DIR%\DLLs\"
|
||||
|
||||
@rem install torch and torchvision
|
||||
call python ..\scripts\check_modules.py torch torchvision
|
||||
if "%ERRORLEVEL%" EQU "0" (
|
||||
@ -92,7 +95,7 @@ if "%ERRORLEVEL%" EQU "0" (
|
||||
set PYTHONNOUSERSITE=1
|
||||
set PYTHONPATH=%INSTALL_ENV_DIR%\lib\site-packages
|
||||
|
||||
call python -m pip install --upgrade sdkit==1.0.35 -q || (
|
||||
call python -m pip install --upgrade sdkit==1.0.43 -q || (
|
||||
echo "Error updating sdkit"
|
||||
)
|
||||
)
|
||||
@ -103,7 +106,7 @@ if "%ERRORLEVEL%" EQU "0" (
|
||||
set PYTHONNOUSERSITE=1
|
||||
set PYTHONPATH=%INSTALL_ENV_DIR%\lib\site-packages
|
||||
|
||||
call python -m pip install sdkit==1.0.35 || (
|
||||
call python -m pip install sdkit==1.0.43 || (
|
||||
echo "Error installing sdkit. Sorry about that, please try to:" & echo " 1. Run this installer again." & echo " 2. If that doesn't fix it, please try the common troubleshooting steps at https://github.com/cmdr2/stable-diffusion-ui/wiki/Troubleshooting" & echo " 3. If those steps don't help, please copy *all* the error messages in this window, and ask the community at https://discord.com/invite/u9yhsFmEkB" & echo " 4. If that doesn't solve the problem, please file an issue at https://github.com/cmdr2/stable-diffusion-ui/issues" & echo "Thanks!"
|
||||
pause
|
||||
exit /b
|
||||
@ -113,7 +116,7 @@ if "%ERRORLEVEL%" EQU "0" (
|
||||
call python -c "from importlib.metadata import version; print('sdkit version:', version('sdkit'))"
|
||||
|
||||
@rem upgrade stable-diffusion-sdkit
|
||||
call python -m pip install --upgrade stable-diffusion-sdkit==2.1.1 -q || (
|
||||
call python -m pip install --upgrade stable-diffusion-sdkit==2.1.3 -q || (
|
||||
echo "Error updating stable-diffusion-sdkit"
|
||||
)
|
||||
call python -c "from importlib.metadata import version; print('stable-diffusion version:', version('stable-diffusion-sdkit'))"
|
||||
@ -139,15 +142,15 @@ set PATH=C:\Windows\System32;%PATH%
|
||||
|
||||
call python ..\scripts\check_modules.py uvicorn fastapi
|
||||
@if "%ERRORLEVEL%" EQU "0" (
|
||||
echo "Packages necessary for Stable Diffusion UI were already installed"
|
||||
echo "Packages necessary for Easy Diffusion were already installed"
|
||||
) else (
|
||||
@echo. & echo "Downloading packages necessary for Stable Diffusion UI.." & echo.
|
||||
@echo. & echo "Downloading packages necessary for Easy Diffusion..." & echo.
|
||||
|
||||
set PYTHONNOUSERSITE=1
|
||||
set PYTHONPATH=%INSTALL_ENV_DIR%\lib\site-packages
|
||||
|
||||
@call conda install -c conda-forge -y uvicorn fastapi || (
|
||||
echo "Error installing the packages necessary for Stable Diffusion UI. Sorry about that, please try to:" & echo " 1. Run this installer again." & echo " 2. If that doesn't fix it, please try the common troubleshooting steps at https://github.com/cmdr2/stable-diffusion-ui/wiki/Troubleshooting" & echo " 3. If those steps don't help, please copy *all* the error messages in this window, and ask the community at https://discord.com/invite/u9yhsFmEkB" & echo " 4. If that doesn't solve the problem, please file an issue at https://github.com/cmdr2/stable-diffusion-ui/issues" & echo "Thanks!"
|
||||
echo "Error installing the packages necessary for Easy Diffusion. Sorry about that, please try to:" & echo " 1. Run this installer again." & echo " 2. If that doesn't fix it, please try the common troubleshooting steps at https://github.com/cmdr2/stable-diffusion-ui/wiki/Troubleshooting" & echo " 3. If those steps don't help, please copy *all* the error messages in this window, and ask the community at https://discord.com/invite/u9yhsFmEkB" & echo " 4. If that doesn't solve the problem, please file an issue at https://github.com/cmdr2/stable-diffusion-ui/issues" & echo "Thanks!"
|
||||
pause
|
||||
exit /b
|
||||
)
|
||||
@ -328,7 +331,7 @@ call WHERE uvicorn > .tmp
|
||||
@echo sd_install_complete >> ..\scripts\install_status.txt
|
||||
)
|
||||
|
||||
@echo. & echo "Stable Diffusion is ready!" & echo.
|
||||
@echo. & echo "Easy Diffusion installation complete! Starting the server!" & echo.
|
||||
|
||||
@set SD_DIR=%cd%
|
||||
|
||||
|
@ -1,11 +1,12 @@
|
||||
#!/bin/bash
|
||||
|
||||
source ./scripts/functions.sh
|
||||
|
||||
cp sd-ui-files/scripts/functions.sh scripts/
|
||||
cp sd-ui-files/scripts/on_env_start.sh scripts/
|
||||
cp sd-ui-files/scripts/bootstrap.sh scripts/
|
||||
cp sd-ui-files/scripts/check_modules.py scripts/
|
||||
|
||||
source ./scripts/functions.sh
|
||||
|
||||
# activate the installer env
|
||||
CONDA_BASEPATH=$(conda info --base)
|
||||
source "$CONDA_BASEPATH/etc/profile.d/conda.sh" # avoids the 'shell not initialized' error
|
||||
@ -80,7 +81,7 @@ if python ../scripts/check_modules.py sdkit sdkit.models ldm transformers numpy
|
||||
export PYTHONNOUSERSITE=1
|
||||
export PYTHONPATH="$INSTALL_ENV_DIR/lib/python3.8/site-packages"
|
||||
|
||||
python -m pip install --upgrade sdkit==1.0.35 -q
|
||||
python -m pip install --upgrade sdkit==1.0.43 -q
|
||||
fi
|
||||
else
|
||||
echo "Installing sdkit: https://pypi.org/project/sdkit/"
|
||||
@ -88,7 +89,7 @@ else
|
||||
export PYTHONNOUSERSITE=1
|
||||
export PYTHONPATH="$INSTALL_ENV_DIR/lib/python3.8/site-packages"
|
||||
|
||||
if python -m pip install sdkit==1.0.35 ; then
|
||||
if python -m pip install sdkit==1.0.43 ; then
|
||||
echo "Installed."
|
||||
else
|
||||
fail "sdkit install failed"
|
||||
@ -98,7 +99,7 @@ fi
|
||||
python -c "from importlib.metadata import version; print('sdkit version:', version('sdkit'))"
|
||||
|
||||
# upgrade stable-diffusion-sdkit
|
||||
python -m pip install --upgrade stable-diffusion-sdkit==2.1.1 -q
|
||||
python -m pip install --upgrade stable-diffusion-sdkit==2.1.3 -q
|
||||
python -c "from importlib.metadata import version; print('stable-diffusion version:', version('stable-diffusion-sdkit'))"
|
||||
|
||||
# install rich
|
||||
@ -118,9 +119,9 @@ else
|
||||
fi
|
||||
|
||||
if python ../scripts/check_modules.py uvicorn fastapi ; then
|
||||
echo "Packages necessary for Stable Diffusion UI were already installed"
|
||||
echo "Packages necessary for Easy Diffusion were already installed"
|
||||
else
|
||||
printf "\n\nDownloading packages necessary for Stable Diffusion UI..\n\n"
|
||||
printf "\n\nDownloading packages necessary for Easy Diffusion..\n\n"
|
||||
|
||||
export PYTHONNOUSERSITE=1
|
||||
export PYTHONPATH="$INSTALL_ENV_DIR/lib/python3.8/site-packages"
|
||||
@ -137,7 +138,7 @@ else
|
||||
fi
|
||||
|
||||
if [ -f "../models/stable-diffusion/sd-v1-4.ckpt" ]; then
|
||||
model_size=`find "../models/stable-diffusion/sd-v1-4.ckpt" -printf "%s"`
|
||||
model_size=`filesize "../models/stable-diffusion/sd-v1-4.ckpt"`
|
||||
|
||||
if [ "$model_size" -eq "4265380512" ] || [ "$model_size" -eq "7703807346" ] || [ "$model_size" -eq "7703810927" ]; then
|
||||
echo "Data files (weights) necessary for Stable Diffusion were already downloaded"
|
||||
@ -153,7 +154,7 @@ if [ ! -f "../models/stable-diffusion/sd-v1-4.ckpt" ]; then
|
||||
curl -L -k https://huggingface.co/CompVis/stable-diffusion-v-1-4-original/resolve/main/sd-v1-4.ckpt > ../models/stable-diffusion/sd-v1-4.ckpt
|
||||
|
||||
if [ -f "../models/stable-diffusion/sd-v1-4.ckpt" ]; then
|
||||
model_size=`find "../models/stable-diffusion/sd-v1-4.ckpt" -printf "%s"`
|
||||
model_size=`filesize "../models/stable-diffusion/sd-v1-4.ckpt"`
|
||||
if [ ! "$model_size" == "4265380512" ]; then
|
||||
fail "The downloaded model file was invalid! Bytes downloaded: $model_size"
|
||||
fi
|
||||
@ -164,7 +165,7 @@ fi
|
||||
|
||||
|
||||
if [ -f "../models/gfpgan/GFPGANv1.3.pth" ]; then
|
||||
model_size=`find "../models/gfpgan/GFPGANv1.3.pth" -printf "%s"`
|
||||
model_size=`filesize "../models/gfpgan/GFPGANv1.3.pth"`
|
||||
|
||||
if [ "$model_size" -eq "348632874" ]; then
|
||||
echo "Data files (weights) necessary for GFPGAN (Face Correction) were already downloaded"
|
||||
@ -180,7 +181,7 @@ if [ ! -f "../models/gfpgan/GFPGANv1.3.pth" ]; then
|
||||
curl -L -k https://github.com/TencentARC/GFPGAN/releases/download/v1.3.0/GFPGANv1.3.pth > ../models/gfpgan/GFPGANv1.3.pth
|
||||
|
||||
if [ -f "../models/gfpgan/GFPGANv1.3.pth" ]; then
|
||||
model_size=`find "../models/gfpgan/GFPGANv1.3.pth" -printf "%s"`
|
||||
model_size=`filesize "../models/gfpgan/GFPGANv1.3.pth"`
|
||||
if [ ! "$model_size" -eq "348632874" ]; then
|
||||
fail "The downloaded GFPGAN model file was invalid! Bytes downloaded: $model_size"
|
||||
fi
|
||||
@ -191,7 +192,7 @@ fi
|
||||
|
||||
|
||||
if [ -f "../models/realesrgan/RealESRGAN_x4plus.pth" ]; then
|
||||
model_size=`find "../models/realesrgan/RealESRGAN_x4plus.pth" -printf "%s"`
|
||||
model_size=`filesize "../models/realesrgan/RealESRGAN_x4plus.pth"`
|
||||
|
||||
if [ "$model_size" -eq "67040989" ]; then
|
||||
echo "Data files (weights) necessary for ESRGAN (Resolution Upscaling) x4plus were already downloaded"
|
||||
@ -207,7 +208,7 @@ if [ ! -f "../models/realesrgan/RealESRGAN_x4plus.pth" ]; then
|
||||
curl -L -k https://github.com/xinntao/Real-ESRGAN/releases/download/v0.1.0/RealESRGAN_x4plus.pth > ../models/realesrgan/RealESRGAN_x4plus.pth
|
||||
|
||||
if [ -f "../models/realesrgan/RealESRGAN_x4plus.pth" ]; then
|
||||
model_size=`find "../models/realesrgan/RealESRGAN_x4plus.pth" -printf "%s"`
|
||||
model_size=`filesize "../models/realesrgan/RealESRGAN_x4plus.pth"`
|
||||
if [ ! "$model_size" -eq "67040989" ]; then
|
||||
fail "The downloaded ESRGAN x4plus model file was invalid! Bytes downloaded: $model_size"
|
||||
fi
|
||||
@ -218,7 +219,7 @@ fi
|
||||
|
||||
|
||||
if [ -f "../models/realesrgan/RealESRGAN_x4plus_anime_6B.pth" ]; then
|
||||
model_size=`find "../models/realesrgan/RealESRGAN_x4plus_anime_6B.pth" -printf "%s"`
|
||||
model_size=`filesize "../models/realesrgan/RealESRGAN_x4plus_anime_6B.pth"`
|
||||
|
||||
if [ "$model_size" -eq "17938799" ]; then
|
||||
echo "Data files (weights) necessary for ESRGAN (Resolution Upscaling) x4plus_anime were already downloaded"
|
||||
@ -234,7 +235,7 @@ if [ ! -f "../models/realesrgan/RealESRGAN_x4plus_anime_6B.pth" ]; then
|
||||
curl -L -k https://github.com/xinntao/Real-ESRGAN/releases/download/v0.2.2.4/RealESRGAN_x4plus_anime_6B.pth > ../models/realesrgan/RealESRGAN_x4plus_anime_6B.pth
|
||||
|
||||
if [ -f "../models/realesrgan/RealESRGAN_x4plus_anime_6B.pth" ]; then
|
||||
model_size=`find "../models/realesrgan/RealESRGAN_x4plus_anime_6B.pth" -printf "%s"`
|
||||
model_size=`filesize "../models/realesrgan/RealESRGAN_x4plus_anime_6B.pth"`
|
||||
if [ ! "$model_size" -eq "17938799" ]; then
|
||||
fail "The downloaded ESRGAN x4plus_anime model file was invalid! Bytes downloaded: $model_size"
|
||||
fi
|
||||
@ -245,7 +246,7 @@ fi
|
||||
|
||||
|
||||
if [ -f "../models/vae/vae-ft-mse-840000-ema-pruned.ckpt" ]; then
|
||||
model_size=`find ../models/vae/vae-ft-mse-840000-ema-pruned.ckpt -printf "%s"`
|
||||
model_size=`filesize "../models/vae/vae-ft-mse-840000-ema-pruned.ckpt"`
|
||||
|
||||
if [ "$model_size" -eq "334695179" ]; then
|
||||
echo "Data files (weights) necessary for the default VAE (sd-vae-ft-mse-original) were already downloaded"
|
||||
@ -261,7 +262,7 @@ if [ ! -f "../models/vae/vae-ft-mse-840000-ema-pruned.ckpt" ]; then
|
||||
curl -L -k https://huggingface.co/stabilityai/sd-vae-ft-mse-original/resolve/main/vae-ft-mse-840000-ema-pruned.ckpt > ../models/vae/vae-ft-mse-840000-ema-pruned.ckpt
|
||||
|
||||
if [ -f "../models/vae/vae-ft-mse-840000-ema-pruned.ckpt" ]; then
|
||||
model_size=`find ../models/vae/vae-ft-mse-840000-ema-pruned.ckpt -printf "%s"`
|
||||
model_size=`filesize "../models/vae/vae-ft-mse-840000-ema-pruned.ckpt"`
|
||||
if [ ! "$model_size" -eq "334695179" ]; then
|
||||
printf "\n\nError: The downloaded default VAE (sd-vae-ft-mse-original) file was invalid! Bytes downloaded: $model_size\n\n"
|
||||
printf "\n\nError downloading the data files (weights) for the default VAE (sd-vae-ft-mse-original). Sorry about that, please try to:\n 1. Run this installer again.\n 2. If that doesn't fix it, please try the common troubleshooting steps at https://github.com/cmdr2/stable-diffusion-ui/wiki/Troubleshooting\n 3. If those steps don't help, please copy *all* the error messages in this window, and ask the community at https://discord.com/invite/u9yhsFmEkB\n 4. If that doesn't solve the problem, please file an issue at https://github.com/cmdr2/stable-diffusion-ui/issues\nThanks!\n\n"
|
||||
@ -280,7 +281,7 @@ if [ `grep -c sd_install_complete ../scripts/install_status.txt` -gt "0" ]; then
|
||||
echo sd_install_complete >> ../scripts/install_status.txt
|
||||
fi
|
||||
|
||||
printf "\n\nStable Diffusion is ready!\n\n"
|
||||
printf "\n\nEasy Diffusion installation complete, starting the server!\n\n"
|
||||
|
||||
SD_PATH=`pwd`
|
||||
|
||||
|
@ -2,6 +2,24 @@
|
||||
|
||||
cd "$(dirname "${BASH_SOURCE[0]}")"
|
||||
|
||||
if [ -f "on_sd_start.bat" ]; then
|
||||
echo ================================================================================
|
||||
echo
|
||||
echo !!!! WARNING !!!!
|
||||
echo
|
||||
echo It looks like you\'re trying to run the installation script from a source code
|
||||
echo download. This will not work.
|
||||
echo
|
||||
echo Recommended: Please close this window and download the installer from
|
||||
echo https://stable-diffusion-ui.github.io/docs/installation/
|
||||
echo
|
||||
echo ================================================================================
|
||||
echo
|
||||
read
|
||||
exit 1
|
||||
fi
|
||||
|
||||
|
||||
# set legacy installer's PATH, if it exists
|
||||
if [ -e "installer" ]; then export PATH="$(pwd)/installer/bin:$PATH"; fi
|
||||
|
||||
|
@ -4,9 +4,10 @@ import sys
|
||||
import json
|
||||
import traceback
|
||||
import logging
|
||||
import shlex
|
||||
from rich.logging import RichHandler
|
||||
|
||||
from sdkit.utils import log as sdkit_log # hack, so we can overwrite the log config
|
||||
from sdkit.utils import log as sdkit_log # hack, so we can overwrite the log config
|
||||
|
||||
from easydiffusion import task_manager
|
||||
from easydiffusion.utils import log
|
||||
@ -15,138 +16,205 @@ from easydiffusion.utils import log
|
||||
for handler in logging.root.handlers[:]:
|
||||
logging.root.removeHandler(handler)
|
||||
|
||||
LOG_FORMAT = '%(asctime)s.%(msecs)03d %(levelname)s %(threadName)s %(message)s'
|
||||
LOG_FORMAT = "%(asctime)s.%(msecs)03d %(levelname)s %(threadName)s %(message)s"
|
||||
logging.basicConfig(
|
||||
level=logging.INFO,
|
||||
format=LOG_FORMAT,
|
||||
datefmt="%X",
|
||||
handlers=[RichHandler(markup=True, rich_tracebacks=False, show_time=False, show_level=False)],
|
||||
level=logging.INFO,
|
||||
format=LOG_FORMAT,
|
||||
datefmt="%X",
|
||||
handlers=[RichHandler(markup=True, rich_tracebacks=False, show_time=False, show_level=False)],
|
||||
)
|
||||
|
||||
SD_DIR = os.getcwd()
|
||||
|
||||
SD_UI_DIR = os.getenv('SD_UI_PATH', None)
|
||||
SD_UI_DIR = os.getenv("SD_UI_PATH", None)
|
||||
|
||||
CONFIG_DIR = os.path.abspath(os.path.join(SD_UI_DIR, "..", "scripts"))
|
||||
MODELS_DIR = os.path.abspath(os.path.join(SD_DIR, "..", "models"))
|
||||
|
||||
USER_PLUGINS_DIR = os.path.abspath(os.path.join(SD_DIR, "..", "plugins"))
|
||||
CORE_PLUGINS_DIR = os.path.abspath(os.path.join(SD_UI_DIR, "plugins"))
|
||||
|
||||
USER_UI_PLUGINS_DIR = os.path.join(USER_PLUGINS_DIR, "ui")
|
||||
CORE_UI_PLUGINS_DIR = os.path.join(CORE_PLUGINS_DIR, "ui")
|
||||
USER_SERVER_PLUGINS_DIR = os.path.join(USER_PLUGINS_DIR, "server")
|
||||
UI_PLUGINS_SOURCES = ((CORE_UI_PLUGINS_DIR, "core"), (USER_UI_PLUGINS_DIR, "user"))
|
||||
|
||||
sys.path.append(os.path.dirname(SD_UI_DIR))
|
||||
sys.path.append(USER_SERVER_PLUGINS_DIR)
|
||||
|
||||
CONFIG_DIR = os.path.abspath(os.path.join(SD_UI_DIR, '..', 'scripts'))
|
||||
MODELS_DIR = os.path.abspath(os.path.join(SD_DIR, '..', 'models'))
|
||||
|
||||
USER_UI_PLUGINS_DIR = os.path.abspath(os.path.join(SD_DIR, '..', 'plugins', 'ui'))
|
||||
CORE_UI_PLUGINS_DIR = os.path.abspath(os.path.join(SD_UI_DIR, 'plugins', 'ui'))
|
||||
UI_PLUGINS_SOURCES = ((CORE_UI_PLUGINS_DIR, 'core'), (USER_UI_PLUGINS_DIR, 'user'))
|
||||
|
||||
OUTPUT_DIRNAME = "Stable Diffusion UI" # in the user's home folder
|
||||
TASK_TTL = 15 * 60 # Discard last session's task timeout
|
||||
OUTPUT_DIRNAME = "Stable Diffusion UI" # in the user's home folder
|
||||
PRESERVE_CONFIG_VARS = ["FORCE_FULL_PRECISION"]
|
||||
TASK_TTL = 15 * 60 # Discard last session's task timeout
|
||||
APP_CONFIG_DEFAULTS = {
|
||||
# auto: selects the cuda device with the most free memory, cuda: use the currently active cuda device.
|
||||
'render_devices': 'auto', # valid entries: 'auto', 'cpu' or 'cuda:N' (where N is a GPU index)
|
||||
'update_branch': 'main',
|
||||
'ui': {
|
||||
'open_browser_on_start': True,
|
||||
"render_devices": "auto", # valid entries: 'auto', 'cpu' or 'cuda:N' (where N is a GPU index)
|
||||
"update_branch": "main",
|
||||
"ui": {
|
||||
"open_browser_on_start": True,
|
||||
},
|
||||
}
|
||||
|
||||
|
||||
def init():
|
||||
os.makedirs(USER_UI_PLUGINS_DIR, exist_ok=True)
|
||||
os.makedirs(USER_SERVER_PLUGINS_DIR, exist_ok=True)
|
||||
|
||||
load_server_plugins()
|
||||
|
||||
update_render_threads()
|
||||
|
||||
|
||||
def getConfig(default_val=APP_CONFIG_DEFAULTS):
|
||||
try:
|
||||
config_json_path = os.path.join(CONFIG_DIR, 'config.json')
|
||||
config_json_path = os.path.join(CONFIG_DIR, "config.json")
|
||||
if not os.path.exists(config_json_path):
|
||||
return default_val
|
||||
with open(config_json_path, 'r', encoding='utf-8') as f:
|
||||
config = json.load(f)
|
||||
if 'net' not in config:
|
||||
config['net'] = {}
|
||||
if os.getenv('SD_UI_BIND_PORT') is not None:
|
||||
config['net']['listen_port'] = int(os.getenv('SD_UI_BIND_PORT'))
|
||||
if os.getenv('SD_UI_BIND_IP') is not None:
|
||||
config['net']['listen_to_network'] = (os.getenv('SD_UI_BIND_IP') == '0.0.0.0')
|
||||
return config
|
||||
config = default_val
|
||||
else:
|
||||
with open(config_json_path, "r", encoding="utf-8") as f:
|
||||
config = json.load(f)
|
||||
if "net" not in config:
|
||||
config["net"] = {}
|
||||
if os.getenv("SD_UI_BIND_PORT") is not None:
|
||||
config["net"]["listen_port"] = int(os.getenv("SD_UI_BIND_PORT"))
|
||||
else:
|
||||
config["net"]["listen_port"] = 9000
|
||||
if os.getenv("SD_UI_BIND_IP") is not None:
|
||||
config["net"]["listen_to_network"] = os.getenv("SD_UI_BIND_IP") == "0.0.0.0"
|
||||
else:
|
||||
config["net"]["listen_to_network"] = True
|
||||
return config
|
||||
except Exception as e:
|
||||
log.warn(traceback.format_exc())
|
||||
return default_val
|
||||
|
||||
|
||||
def setConfig(config):
|
||||
try: # config.json
|
||||
config_json_path = os.path.join(CONFIG_DIR, 'config.json')
|
||||
with open(config_json_path, 'w', encoding='utf-8') as f:
|
||||
try: # config.json
|
||||
config_json_path = os.path.join(CONFIG_DIR, "config.json")
|
||||
with open(config_json_path, "w", encoding="utf-8") as f:
|
||||
json.dump(config, f)
|
||||
except:
|
||||
log.error(traceback.format_exc())
|
||||
|
||||
try: # config.bat
|
||||
config_bat_path = os.path.join(CONFIG_DIR, 'config.bat')
|
||||
try: # config.bat
|
||||
config_bat_path = os.path.join(CONFIG_DIR, "config.bat")
|
||||
config_bat = []
|
||||
|
||||
if 'update_branch' in config:
|
||||
if "update_branch" in config:
|
||||
config_bat.append(f"@set update_branch={config['update_branch']}")
|
||||
|
||||
config_bat.append(f"@set SD_UI_BIND_PORT={config['net']['listen_port']}")
|
||||
bind_ip = '0.0.0.0' if config['net']['listen_to_network'] else '127.0.0.1'
|
||||
bind_ip = "0.0.0.0" if config["net"]["listen_to_network"] else "127.0.0.1"
|
||||
config_bat.append(f"@set SD_UI_BIND_IP={bind_ip}")
|
||||
|
||||
# Preserve these variables if they are set
|
||||
for var in PRESERVE_CONFIG_VARS:
|
||||
if os.getenv(var) is not None:
|
||||
config_bat.append(f"@set {var}={os.getenv(var)}")
|
||||
|
||||
if len(config_bat) > 0:
|
||||
with open(config_bat_path, 'w', encoding='utf-8') as f:
|
||||
f.write('\r\n'.join(config_bat))
|
||||
with open(config_bat_path, "w", encoding="utf-8") as f:
|
||||
f.write("\n".join(config_bat))
|
||||
except:
|
||||
log.error(traceback.format_exc())
|
||||
|
||||
try: # config.sh
|
||||
config_sh_path = os.path.join(CONFIG_DIR, 'config.sh')
|
||||
config_sh = ['#!/bin/bash']
|
||||
try: # config.sh
|
||||
config_sh_path = os.path.join(CONFIG_DIR, "config.sh")
|
||||
config_sh = ["#!/bin/bash"]
|
||||
|
||||
if 'update_branch' in config:
|
||||
if "update_branch" in config:
|
||||
config_sh.append(f"export update_branch={config['update_branch']}")
|
||||
|
||||
config_sh.append(f"export SD_UI_BIND_PORT={config['net']['listen_port']}")
|
||||
bind_ip = '0.0.0.0' if config['net']['listen_to_network'] else '127.0.0.1'
|
||||
bind_ip = "0.0.0.0" if config["net"]["listen_to_network"] else "127.0.0.1"
|
||||
config_sh.append(f"export SD_UI_BIND_IP={bind_ip}")
|
||||
|
||||
# Preserve these variables if they are set
|
||||
for var in PRESERVE_CONFIG_VARS:
|
||||
if os.getenv(var) is not None:
|
||||
config_bat.append(f'export {var}="{shlex.quote(os.getenv(var))}"')
|
||||
|
||||
if len(config_sh) > 1:
|
||||
with open(config_sh_path, 'w', encoding='utf-8') as f:
|
||||
f.write('\n'.join(config_sh))
|
||||
with open(config_sh_path, "w", encoding="utf-8") as f:
|
||||
f.write("\n".join(config_sh))
|
||||
except:
|
||||
log.error(traceback.format_exc())
|
||||
|
||||
|
||||
def save_to_config(ckpt_model_name, vae_model_name, hypernetwork_model_name, vram_usage_level):
|
||||
config = getConfig()
|
||||
if 'model' not in config:
|
||||
config['model'] = {}
|
||||
if "model" not in config:
|
||||
config["model"] = {}
|
||||
|
||||
config['model']['stable-diffusion'] = ckpt_model_name
|
||||
config['model']['vae'] = vae_model_name
|
||||
config['model']['hypernetwork'] = hypernetwork_model_name
|
||||
config["model"]["stable-diffusion"] = ckpt_model_name
|
||||
config["model"]["vae"] = vae_model_name
|
||||
config["model"]["hypernetwork"] = hypernetwork_model_name
|
||||
|
||||
if vae_model_name is None or vae_model_name == "":
|
||||
del config['model']['vae']
|
||||
del config["model"]["vae"]
|
||||
if hypernetwork_model_name is None or hypernetwork_model_name == "":
|
||||
del config['model']['hypernetwork']
|
||||
del config["model"]["hypernetwork"]
|
||||
|
||||
config['vram_usage_level'] = vram_usage_level
|
||||
config["vram_usage_level"] = vram_usage_level
|
||||
|
||||
setConfig(config)
|
||||
|
||||
|
||||
def update_render_threads():
|
||||
config = getConfig()
|
||||
render_devices = config.get('render_devices', 'auto')
|
||||
active_devices = task_manager.get_devices()['active'].keys()
|
||||
render_devices = config.get("render_devices", "auto")
|
||||
active_devices = task_manager.get_devices()["active"].keys()
|
||||
|
||||
log.debug(f'requesting for render_devices: {render_devices}')
|
||||
log.debug(f"requesting for render_devices: {render_devices}")
|
||||
task_manager.update_render_threads(render_devices, active_devices)
|
||||
|
||||
|
||||
def getUIPlugins():
|
||||
plugins = []
|
||||
|
||||
for plugins_dir, dir_prefix in UI_PLUGINS_SOURCES:
|
||||
for file in os.listdir(plugins_dir):
|
||||
if file.endswith('.plugin.js'):
|
||||
plugins.append(f'/plugins/{dir_prefix}/{file}')
|
||||
if file.endswith(".plugin.js"):
|
||||
plugins.append(f"/plugins/{dir_prefix}/{file}")
|
||||
|
||||
return plugins
|
||||
|
||||
|
||||
def load_server_plugins():
|
||||
if not os.path.exists(USER_SERVER_PLUGINS_DIR):
|
||||
return
|
||||
|
||||
import importlib
|
||||
|
||||
def load_plugin(file):
|
||||
mod_path = file.replace(".py", "")
|
||||
return importlib.import_module(mod_path)
|
||||
|
||||
def apply_plugin(file, plugin):
|
||||
if hasattr(plugin, "get_cond_and_uncond"):
|
||||
import sdkit.generate.image_generator
|
||||
|
||||
sdkit.generate.image_generator.get_cond_and_uncond = plugin.get_cond_and_uncond
|
||||
log.info(f"Overridden get_cond_and_uncond with the one in the server plugin: {file}")
|
||||
|
||||
for file in os.listdir(USER_SERVER_PLUGINS_DIR):
|
||||
file_path = os.path.join(USER_SERVER_PLUGINS_DIR, file)
|
||||
if (not os.path.isdir(file_path) and not file_path.endswith("_plugin.py")) or (
|
||||
os.path.isdir(file_path) and not file_path.endswith("_plugin")
|
||||
):
|
||||
continue
|
||||
|
||||
try:
|
||||
log.info(f"Loading server plugin: {file}")
|
||||
mod = load_plugin(file)
|
||||
|
||||
log.info(f"Applying server plugin: {file}")
|
||||
apply_plugin(file, mod)
|
||||
except:
|
||||
log.warn(f"Error while loading a server plugin")
|
||||
log.warn(traceback.format_exc())
|
||||
|
||||
|
||||
def getIPConfig():
|
||||
try:
|
||||
ips = socket.gethostbyname_ex(socket.gethostname())
|
||||
@ -156,10 +224,13 @@ def getIPConfig():
|
||||
log.exception(e)
|
||||
return []
|
||||
|
||||
|
||||
def open_browser():
|
||||
config = getConfig()
|
||||
ui = config.get('ui', {})
|
||||
net = config.get('net', {'listen_port':9000})
|
||||
port = net.get('listen_port', 9000)
|
||||
if ui.get('open_browser_on_start', True):
|
||||
import webbrowser; webbrowser.open(f"http://localhost:{port}")
|
||||
ui = config.get("ui", {})
|
||||
net = config.get("net", {"listen_port": 9000})
|
||||
port = net.get("listen_port", 9000)
|
||||
if ui.get("open_browser_on_start", True):
|
||||
import webbrowser
|
||||
|
||||
webbrowser.open(f"http://localhost:{port}")
|
||||
|
@ -5,45 +5,54 @@ import re
|
||||
|
||||
from easydiffusion.utils import log
|
||||
|
||||
'''
|
||||
"""
|
||||
Set `FORCE_FULL_PRECISION` in the environment variables, or in `config.bat`/`config.sh` to set full precision (i.e. float32).
|
||||
Otherwise the models will load at half-precision (i.e. float16).
|
||||
|
||||
Half-precision is fine most of the time. Full precision is only needed for working around GPU bugs (like NVIDIA 16xx GPUs).
|
||||
'''
|
||||
"""
|
||||
|
||||
COMPARABLE_GPU_PERCENTILE = 0.65 # if a GPU's free_mem is within this % of the GPU with the most free_mem, it will be picked
|
||||
COMPARABLE_GPU_PERCENTILE = (
|
||||
0.65 # if a GPU's free_mem is within this % of the GPU with the most free_mem, it will be picked
|
||||
)
|
||||
|
||||
mem_free_threshold = 0
|
||||
|
||||
|
||||
def get_device_delta(render_devices, active_devices):
|
||||
'''
|
||||
"""
|
||||
render_devices: 'cpu', or 'auto' or ['cuda:N'...]
|
||||
active_devices: ['cpu', 'cuda:N'...]
|
||||
'''
|
||||
"""
|
||||
|
||||
if render_devices in ('cpu', 'auto'):
|
||||
if render_devices in ("cpu", "auto"):
|
||||
render_devices = [render_devices]
|
||||
elif render_devices is not None:
|
||||
if isinstance(render_devices, str):
|
||||
render_devices = [render_devices]
|
||||
if isinstance(render_devices, list) and len(render_devices) > 0:
|
||||
render_devices = list(filter(lambda x: x.startswith('cuda:'), render_devices))
|
||||
render_devices = list(filter(lambda x: x.startswith("cuda:"), render_devices))
|
||||
if len(render_devices) == 0:
|
||||
raise Exception('Invalid render_devices value in config.json. Valid: {"render_devices": ["cuda:0", "cuda:1"...]}, or {"render_devices": "cpu"} or {"render_devices": "auto"}')
|
||||
raise Exception(
|
||||
'Invalid render_devices value in config.json. Valid: {"render_devices": ["cuda:0", "cuda:1"...]}, or {"render_devices": "cpu"} or {"render_devices": "auto"}'
|
||||
)
|
||||
|
||||
render_devices = list(filter(lambda x: is_device_compatible(x), render_devices))
|
||||
if len(render_devices) == 0:
|
||||
raise Exception('Sorry, none of the render_devices configured in config.json are compatible with Stable Diffusion')
|
||||
raise Exception(
|
||||
"Sorry, none of the render_devices configured in config.json are compatible with Stable Diffusion"
|
||||
)
|
||||
else:
|
||||
raise Exception('Invalid render_devices value in config.json. Valid: {"render_devices": ["cuda:0", "cuda:1"...]}, or {"render_devices": "cpu"} or {"render_devices": "auto"}')
|
||||
raise Exception(
|
||||
'Invalid render_devices value in config.json. Valid: {"render_devices": ["cuda:0", "cuda:1"...]}, or {"render_devices": "cpu"} or {"render_devices": "auto"}'
|
||||
)
|
||||
else:
|
||||
render_devices = ['auto']
|
||||
render_devices = ["auto"]
|
||||
|
||||
if 'auto' in render_devices:
|
||||
if "auto" in render_devices:
|
||||
render_devices = auto_pick_devices(active_devices)
|
||||
if 'cpu' in render_devices:
|
||||
log.warn('WARNING: Could not find a compatible GPU. Using the CPU, but this will be very slow!')
|
||||
if "cpu" in render_devices:
|
||||
log.warn("WARNING: Could not find a compatible GPU. Using the CPU, but this will be very slow!")
|
||||
|
||||
active_devices = set(active_devices)
|
||||
render_devices = set(render_devices)
|
||||
@ -53,19 +62,21 @@ def get_device_delta(render_devices, active_devices):
|
||||
|
||||
return devices_to_start, devices_to_stop
|
||||
|
||||
|
||||
def auto_pick_devices(currently_active_devices):
|
||||
global mem_free_threshold
|
||||
|
||||
if not torch.cuda.is_available(): return ['cpu']
|
||||
if not torch.cuda.is_available():
|
||||
return ["cpu"]
|
||||
|
||||
device_count = torch.cuda.device_count()
|
||||
if device_count == 1:
|
||||
return ['cuda:0'] if is_device_compatible('cuda:0') else ['cpu']
|
||||
return ["cuda:0"] if is_device_compatible("cuda:0") else ["cpu"]
|
||||
|
||||
log.debug('Autoselecting GPU. Using most free memory.')
|
||||
log.debug("Autoselecting GPU. Using most free memory.")
|
||||
devices = []
|
||||
for device in range(device_count):
|
||||
device = f'cuda:{device}'
|
||||
device = f"cuda:{device}"
|
||||
if not is_device_compatible(device):
|
||||
continue
|
||||
|
||||
@ -73,11 +84,13 @@ def auto_pick_devices(currently_active_devices):
|
||||
mem_free /= float(10**9)
|
||||
mem_total /= float(10**9)
|
||||
device_name = torch.cuda.get_device_name(device)
|
||||
log.debug(f'{device} detected: {device_name} - Memory (free/total): {round(mem_free, 2)}Gb / {round(mem_total, 2)}Gb')
|
||||
devices.append({'device': device, 'device_name': device_name, 'mem_free': mem_free})
|
||||
log.debug(
|
||||
f"{device} detected: {device_name} - Memory (free/total): {round(mem_free, 2)}Gb / {round(mem_total, 2)}Gb"
|
||||
)
|
||||
devices.append({"device": device, "device_name": device_name, "mem_free": mem_free})
|
||||
|
||||
devices.sort(key=lambda x:x['mem_free'], reverse=True)
|
||||
max_mem_free = devices[0]['mem_free']
|
||||
devices.sort(key=lambda x: x["mem_free"], reverse=True)
|
||||
max_mem_free = devices[0]["mem_free"]
|
||||
curr_mem_free_threshold = COMPARABLE_GPU_PERCENTILE * max_mem_free
|
||||
mem_free_threshold = max(curr_mem_free_threshold, mem_free_threshold)
|
||||
|
||||
@ -87,23 +100,26 @@ def auto_pick_devices(currently_active_devices):
|
||||
# always be very low (since their VRAM contains the model).
|
||||
# These already-running devices probably aren't terrible, since they were picked in the past.
|
||||
# Worst case, the user can restart the program and that'll get rid of them.
|
||||
devices = list(filter((lambda x: x['mem_free'] > mem_free_threshold or x['device'] in currently_active_devices), devices))
|
||||
devices = list(map(lambda x: x['device'], devices))
|
||||
devices = list(
|
||||
filter((lambda x: x["mem_free"] > mem_free_threshold or x["device"] in currently_active_devices), devices)
|
||||
)
|
||||
devices = list(map(lambda x: x["device"], devices))
|
||||
return devices
|
||||
|
||||
|
||||
def device_init(context, device):
|
||||
'''
|
||||
"""
|
||||
This function assumes the 'device' has already been verified to be compatible.
|
||||
`get_device_delta()` has already filtered out incompatible devices.
|
||||
'''
|
||||
"""
|
||||
|
||||
validate_device_id(device, log_prefix='device_init')
|
||||
validate_device_id(device, log_prefix="device_init")
|
||||
|
||||
if device == 'cpu':
|
||||
context.device = 'cpu'
|
||||
if device == "cpu":
|
||||
context.device = "cpu"
|
||||
context.device_name = get_processor_name()
|
||||
context.half_precision = False
|
||||
log.debug(f'Render device CPU available as {context.device_name}')
|
||||
log.debug(f"Render device CPU available as {context.device_name}")
|
||||
return
|
||||
|
||||
context.device_name = torch.cuda.get_device_name(device)
|
||||
@ -111,7 +127,7 @@ def device_init(context, device):
|
||||
|
||||
# Force full precision on 1660 and 1650 NVIDIA cards to avoid creating green images
|
||||
if needs_to_force_full_precision(context):
|
||||
log.warn(f'forcing full precision on this GPU, to avoid green images. GPU detected: {context.device_name}')
|
||||
log.warn(f"forcing full precision on this GPU, to avoid green images. GPU detected: {context.device_name}")
|
||||
# Apply force_full_precision now before models are loaded.
|
||||
context.half_precision = False
|
||||
|
||||
@ -120,72 +136,93 @@ def device_init(context, device):
|
||||
|
||||
return
|
||||
|
||||
|
||||
def needs_to_force_full_precision(context):
|
||||
if 'FORCE_FULL_PRECISION' in os.environ:
|
||||
if "FORCE_FULL_PRECISION" in os.environ:
|
||||
return True
|
||||
|
||||
device_name = context.device_name.lower()
|
||||
return (('nvidia' in device_name or 'geforce' in device_name or 'quadro' in device_name) and (' 1660' in device_name or ' 1650' in device_name or ' t400' in device_name or ' t500' in device_name or ' t550' in device_name or ' t600' in device_name or ' t1000' in device_name or ' t1200' in device_name or ' t2000' in device_name)) or ('tesla k40m' in device_name)
|
||||
return (
|
||||
("nvidia" in device_name or "geforce" in device_name or "quadro" in device_name)
|
||||
and (
|
||||
" 1660" in device_name
|
||||
or " 1650" in device_name
|
||||
or " t400" in device_name
|
||||
or " t550" in device_name
|
||||
or " t600" in device_name
|
||||
or " t1000" in device_name
|
||||
or " t1200" in device_name
|
||||
or " t2000" in device_name
|
||||
)
|
||||
) or ("tesla k40m" in device_name)
|
||||
|
||||
|
||||
def get_max_vram_usage_level(device):
|
||||
if device != 'cpu':
|
||||
if device != "cpu":
|
||||
_, mem_total = torch.cuda.mem_get_info(device)
|
||||
mem_total /= float(10**9)
|
||||
|
||||
if mem_total < 4.5:
|
||||
return 'low'
|
||||
return "low"
|
||||
elif mem_total < 6.5:
|
||||
return 'balanced'
|
||||
return "balanced"
|
||||
|
||||
return 'high'
|
||||
return "high"
|
||||
|
||||
def validate_device_id(device, log_prefix=''):
|
||||
|
||||
def validate_device_id(device, log_prefix=""):
|
||||
def is_valid():
|
||||
if not isinstance(device, str):
|
||||
return False
|
||||
if device == 'cpu':
|
||||
if device == "cpu":
|
||||
return True
|
||||
if not device.startswith('cuda:') or not device[5:].isnumeric():
|
||||
if not device.startswith("cuda:") or not device[5:].isnumeric():
|
||||
return False
|
||||
return True
|
||||
|
||||
if not is_valid():
|
||||
raise EnvironmentError(f"{log_prefix}: device id should be 'cpu', or 'cuda:N' (where N is an integer index for the GPU). Got: {device}")
|
||||
raise EnvironmentError(
|
||||
f"{log_prefix}: device id should be 'cpu', or 'cuda:N' (where N is an integer index for the GPU). Got: {device}"
|
||||
)
|
||||
|
||||
|
||||
def is_device_compatible(device):
|
||||
'''
|
||||
"""
|
||||
Returns True/False, and prints any compatibility errors
|
||||
'''
|
||||
# static variable "history".
|
||||
is_device_compatible.history = getattr(is_device_compatible, 'history', {})
|
||||
"""
|
||||
# static variable "history".
|
||||
is_device_compatible.history = getattr(is_device_compatible, "history", {})
|
||||
try:
|
||||
validate_device_id(device, log_prefix='is_device_compatible')
|
||||
validate_device_id(device, log_prefix="is_device_compatible")
|
||||
except:
|
||||
log.error(str(e))
|
||||
return False
|
||||
|
||||
if device == 'cpu': return True
|
||||
if device == "cpu":
|
||||
return True
|
||||
# Memory check
|
||||
try:
|
||||
_, mem_total = torch.cuda.mem_get_info(device)
|
||||
mem_total /= float(10**9)
|
||||
if mem_total < 3.0:
|
||||
if is_device_compatible.history.get(device) == None:
|
||||
log.warn(f'GPU {device} with less than 3 GB of VRAM is not compatible with Stable Diffusion')
|
||||
is_device_compatible.history[device] = 1
|
||||
log.warn(f"GPU {device} with less than 3 GB of VRAM is not compatible with Stable Diffusion")
|
||||
is_device_compatible.history[device] = 1
|
||||
return False
|
||||
except RuntimeError as e:
|
||||
log.error(str(e))
|
||||
return False
|
||||
return True
|
||||
|
||||
|
||||
def get_processor_name():
|
||||
try:
|
||||
import platform, subprocess
|
||||
|
||||
if platform.system() == "Windows":
|
||||
return platform.processor()
|
||||
elif platform.system() == "Darwin":
|
||||
os.environ['PATH'] = os.environ['PATH'] + os.pathsep + '/usr/sbin'
|
||||
os.environ["PATH"] = os.environ["PATH"] + os.pathsep + "/usr/sbin"
|
||||
command = "sysctl -n machdep.cpu.brand_string"
|
||||
return subprocess.check_output(command).strip()
|
||||
elif platform.system() == "Linux":
|
||||
|
@ -1,36 +1,37 @@
|
||||
import os
|
||||
|
||||
from easydiffusion import app, device_manager
|
||||
from easydiffusion import app
|
||||
from easydiffusion.types import TaskData
|
||||
from easydiffusion.utils import log
|
||||
|
||||
from sdkit import Context
|
||||
from sdkit.models import load_model, unload_model, get_model_info_from_db, scan_model
|
||||
from sdkit.utils import hash_file_quick
|
||||
from sdkit.models import load_model, unload_model, scan_model
|
||||
|
||||
KNOWN_MODEL_TYPES = ['stable-diffusion', 'vae', 'hypernetwork', 'gfpgan', 'realesrgan']
|
||||
KNOWN_MODEL_TYPES = ["stable-diffusion", "vae", "hypernetwork", "gfpgan", "realesrgan"]
|
||||
MODEL_EXTENSIONS = {
|
||||
'stable-diffusion': ['.ckpt', '.safetensors'],
|
||||
'vae': ['.vae.pt', '.ckpt', '.safetensors'],
|
||||
'hypernetwork': ['.pt', '.safetensors'],
|
||||
'gfpgan': ['.pth'],
|
||||
'realesrgan': ['.pth'],
|
||||
"stable-diffusion": [".ckpt", ".safetensors"],
|
||||
"vae": [".vae.pt", ".ckpt", ".safetensors"],
|
||||
"hypernetwork": [".pt", ".safetensors"],
|
||||
"gfpgan": [".pth"],
|
||||
"realesrgan": [".pth"],
|
||||
}
|
||||
DEFAULT_MODELS = {
|
||||
'stable-diffusion': [ # needed to support the legacy installations
|
||||
'custom-model', # only one custom model file was supported initially, creatively named 'custom-model'
|
||||
'sd-v1-4', # Default fallback.
|
||||
"stable-diffusion": [ # needed to support the legacy installations
|
||||
"custom-model", # only one custom model file was supported initially, creatively named 'custom-model'
|
||||
"sd-v1-4", # Default fallback.
|
||||
],
|
||||
'gfpgan': ['GFPGANv1.3'],
|
||||
'realesrgan': ['RealESRGAN_x4plus'],
|
||||
"gfpgan": ["GFPGANv1.3"],
|
||||
"realesrgan": ["RealESRGAN_x4plus"],
|
||||
}
|
||||
MODELS_TO_LOAD_ON_START = ['stable-diffusion', 'vae', 'hypernetwork']
|
||||
MODELS_TO_LOAD_ON_START = ["stable-diffusion", "vae", "hypernetwork"]
|
||||
|
||||
known_models = {}
|
||||
|
||||
|
||||
def init():
|
||||
make_model_folders()
|
||||
getModels() # run this once, to cache the picklescan results
|
||||
getModels() # run this once, to cache the picklescan results
|
||||
|
||||
|
||||
def load_default_models(context: Context):
|
||||
set_vram_optimizations(context)
|
||||
@ -39,27 +40,28 @@ def load_default_models(context: Context):
|
||||
for model_type in MODELS_TO_LOAD_ON_START:
|
||||
context.model_paths[model_type] = resolve_model_to_use(model_type=model_type)
|
||||
try:
|
||||
load_model(context, model_type)
|
||||
load_model(context, model_type)
|
||||
except Exception as e:
|
||||
log.error(f'[red]Error while loading {model_type} model: {context.model_paths[model_type]}[/red]')
|
||||
log.error(f'[red]Error: {e}[/red]')
|
||||
log.error(f'[red]Consider removing the model from the model folder.[red]')
|
||||
log.error(f"[red]Error while loading {model_type} model: {context.model_paths[model_type]}[/red]")
|
||||
log.error(f"[red]Error: {e}[/red]")
|
||||
log.error(f"[red]Consider removing the model from the model folder.[red]")
|
||||
|
||||
|
||||
def unload_all(context: Context):
|
||||
for model_type in KNOWN_MODEL_TYPES:
|
||||
unload_model(context, model_type)
|
||||
|
||||
def resolve_model_to_use(model_name:str=None, model_type:str=None):
|
||||
|
||||
def resolve_model_to_use(model_name: str = None, model_type: str = None):
|
||||
model_extensions = MODEL_EXTENSIONS.get(model_type, [])
|
||||
default_models = DEFAULT_MODELS.get(model_type, [])
|
||||
config = app.getConfig()
|
||||
|
||||
model_dirs = [os.path.join(app.MODELS_DIR, model_type), app.SD_DIR]
|
||||
if not model_name: # When None try user configured model.
|
||||
if not model_name: # When None try user configured model.
|
||||
# config = getConfig()
|
||||
if 'model' in config and model_type in config['model']:
|
||||
model_name = config['model'][model_type]
|
||||
if "model" in config and model_type in config["model"]:
|
||||
model_name = config["model"][model_type]
|
||||
|
||||
if model_name:
|
||||
# Check models directory
|
||||
@ -84,41 +86,55 @@ def resolve_model_to_use(model_name:str=None, model_type:str=None):
|
||||
for model_extension in model_extensions:
|
||||
if os.path.exists(default_model_path + model_extension):
|
||||
if model_name is not None:
|
||||
log.warn(f'Could not find the configured custom model {model_name}{model_extension}. Using the default one: {default_model_path}{model_extension}')
|
||||
log.warn(
|
||||
f"Could not find the configured custom model {model_name}{model_extension}. Using the default one: {default_model_path}{model_extension}"
|
||||
)
|
||||
return default_model_path + model_extension
|
||||
|
||||
return None
|
||||
|
||||
|
||||
def reload_models_if_necessary(context: Context, task_data: TaskData):
|
||||
model_paths_in_req = {
|
||||
'stable-diffusion': task_data.use_stable_diffusion_model,
|
||||
'vae': task_data.use_vae_model,
|
||||
'hypernetwork': task_data.use_hypernetwork_model,
|
||||
'gfpgan': task_data.use_face_correction,
|
||||
'realesrgan': task_data.use_upscale,
|
||||
"stable-diffusion": task_data.use_stable_diffusion_model,
|
||||
"vae": task_data.use_vae_model,
|
||||
"hypernetwork": task_data.use_hypernetwork_model,
|
||||
"gfpgan": task_data.use_face_correction,
|
||||
"realesrgan": task_data.use_upscale,
|
||||
"nsfw_checker": True if task_data.block_nsfw else None,
|
||||
}
|
||||
models_to_reload = {
|
||||
model_type: path
|
||||
for model_type, path in model_paths_in_req.items()
|
||||
if context.model_paths.get(model_type) != path
|
||||
}
|
||||
models_to_reload = {model_type: path for model_type, path in model_paths_in_req.items() if context.model_paths.get(model_type) != path}
|
||||
|
||||
if set_vram_optimizations(context): # reload SD
|
||||
models_to_reload['stable-diffusion'] = model_paths_in_req['stable-diffusion']
|
||||
if set_vram_optimizations(context): # reload SD
|
||||
models_to_reload["stable-diffusion"] = model_paths_in_req["stable-diffusion"]
|
||||
|
||||
for model_type, model_path_in_req in models_to_reload.items():
|
||||
context.model_paths[model_type] = model_path_in_req
|
||||
|
||||
action_fn = unload_model if context.model_paths[model_type] is None else load_model
|
||||
action_fn(context, model_type, scan_model=False) # we've scanned them already
|
||||
action_fn(context, model_type, scan_model=False) # we've scanned them already
|
||||
|
||||
|
||||
def resolve_model_paths(task_data: TaskData):
|
||||
task_data.use_stable_diffusion_model = resolve_model_to_use(task_data.use_stable_diffusion_model, model_type='stable-diffusion')
|
||||
task_data.use_vae_model = resolve_model_to_use(task_data.use_vae_model, model_type='vae')
|
||||
task_data.use_hypernetwork_model = resolve_model_to_use(task_data.use_hypernetwork_model, model_type='hypernetwork')
|
||||
task_data.use_stable_diffusion_model = resolve_model_to_use(
|
||||
task_data.use_stable_diffusion_model, model_type="stable-diffusion"
|
||||
)
|
||||
task_data.use_vae_model = resolve_model_to_use(task_data.use_vae_model, model_type="vae")
|
||||
task_data.use_hypernetwork_model = resolve_model_to_use(task_data.use_hypernetwork_model, model_type="hypernetwork")
|
||||
|
||||
if task_data.use_face_correction:
|
||||
task_data.use_face_correction = resolve_model_to_use(task_data.use_face_correction, "gfpgan")
|
||||
if task_data.use_upscale:
|
||||
task_data.use_upscale = resolve_model_to_use(task_data.use_upscale, "realesrgan")
|
||||
|
||||
if task_data.use_face_correction: task_data.use_face_correction = resolve_model_to_use(task_data.use_face_correction, 'gfpgan')
|
||||
if task_data.use_upscale: task_data.use_upscale = resolve_model_to_use(task_data.use_upscale, 'realesrgan')
|
||||
|
||||
def set_vram_optimizations(context: Context):
|
||||
config = app.getConfig()
|
||||
vram_usage_level = config.get('vram_usage_level', 'balanced')
|
||||
vram_usage_level = config.get("vram_usage_level", "balanced")
|
||||
|
||||
if vram_usage_level != context.vram_usage_level:
|
||||
context.vram_usage_level = vram_usage_level
|
||||
@ -126,42 +142,53 @@ def set_vram_optimizations(context: Context):
|
||||
|
||||
return False
|
||||
|
||||
|
||||
def make_model_folders():
|
||||
for model_type in KNOWN_MODEL_TYPES:
|
||||
model_dir_path = os.path.join(app.MODELS_DIR, model_type)
|
||||
|
||||
os.makedirs(model_dir_path, exist_ok=True)
|
||||
|
||||
help_file_name = f'Place your {model_type} model files here.txt'
|
||||
help_file_name = f"Place your {model_type} model files here.txt"
|
||||
help_file_contents = f'Supported extensions: {" or ".join(MODEL_EXTENSIONS.get(model_type))}'
|
||||
|
||||
with open(os.path.join(model_dir_path, help_file_name), 'w', encoding='utf-8') as f:
|
||||
with open(os.path.join(model_dir_path, help_file_name), "w", encoding="utf-8") as f:
|
||||
f.write(help_file_contents)
|
||||
|
||||
|
||||
def is_malicious_model(file_path):
|
||||
try:
|
||||
if file_path.endswith(".safetensors"):
|
||||
return False
|
||||
scan_result = scan_model(file_path)
|
||||
if scan_result.issues_count > 0 or scan_result.infected_files > 0:
|
||||
log.warn(":warning: [bold red]Scan %s: %d scanned, %d issue, %d infected.[/bold red]" % (file_path, scan_result.scanned_files, scan_result.issues_count, scan_result.infected_files))
|
||||
log.warn(
|
||||
":warning: [bold red]Scan %s: %d scanned, %d issue, %d infected.[/bold red]"
|
||||
% (file_path, scan_result.scanned_files, scan_result.issues_count, scan_result.infected_files)
|
||||
)
|
||||
return True
|
||||
else:
|
||||
log.debug("Scan %s: [green]%d scanned, %d issue, %d infected.[/green]" % (file_path, scan_result.scanned_files, scan_result.issues_count, scan_result.infected_files))
|
||||
log.debug(
|
||||
"Scan %s: [green]%d scanned, %d issue, %d infected.[/green]"
|
||||
% (file_path, scan_result.scanned_files, scan_result.issues_count, scan_result.infected_files)
|
||||
)
|
||||
return False
|
||||
except Exception as e:
|
||||
log.error(f'error while scanning: {file_path}, error: {e}')
|
||||
log.error(f"error while scanning: {file_path}, error: {e}")
|
||||
return False
|
||||
|
||||
|
||||
def getModels():
|
||||
models = {
|
||||
'active': {
|
||||
'stable-diffusion': 'sd-v1-4',
|
||||
'vae': '',
|
||||
'hypernetwork': '',
|
||||
"active": {
|
||||
"stable-diffusion": "sd-v1-4",
|
||||
"vae": "",
|
||||
"hypernetwork": "",
|
||||
},
|
||||
'options': {
|
||||
'stable-diffusion': ['sd-v1-4'],
|
||||
'vae': [],
|
||||
'hypernetwork': [],
|
||||
"options": {
|
||||
"stable-diffusion": ["sd-v1-4"],
|
||||
"vae": [],
|
||||
"hypernetwork": [],
|
||||
},
|
||||
}
|
||||
|
||||
@ -171,13 +198,16 @@ def getModels():
|
||||
"Raised when picklescan reports a problem with a model"
|
||||
pass
|
||||
|
||||
def scan_directory(directory, suffixes):
|
||||
def scan_directory(directory, suffixes, directoriesFirst: bool = True):
|
||||
nonlocal models_scanned
|
||||
tree = []
|
||||
for entry in os.scandir(directory):
|
||||
for entry in sorted(
|
||||
os.scandir(directory), key=lambda entry: (entry.is_file() == directoriesFirst, entry.name.lower())
|
||||
):
|
||||
if entry.is_file():
|
||||
matching_suffix = list(filter(lambda s: entry.name.endswith(s), suffixes))
|
||||
if len(matching_suffix) == 0: continue
|
||||
if len(matching_suffix) == 0:
|
||||
continue
|
||||
matching_suffix = matching_suffix[0]
|
||||
|
||||
mtime = entry.stat().st_mtime
|
||||
@ -187,11 +217,12 @@ def getModels():
|
||||
if is_malicious_model(entry.path):
|
||||
raise MaliciousModelException(entry.path)
|
||||
known_models[entry.path] = mtime
|
||||
tree.append(entry.name[:-len(matching_suffix)])
|
||||
tree.append(entry.name[: -len(matching_suffix)])
|
||||
elif entry.is_dir():
|
||||
scan=scan_directory(entry.path, suffixes)
|
||||
scan = scan_directory(entry.path, suffixes, directoriesFirst=False)
|
||||
|
||||
if len(scan) != 0:
|
||||
tree.append( (entry.name, scan ) )
|
||||
tree.append((entry.name, scan))
|
||||
return tree
|
||||
|
||||
def listModels(model_type):
|
||||
@ -203,20 +234,22 @@ def getModels():
|
||||
os.makedirs(models_dir)
|
||||
|
||||
try:
|
||||
models['options'][model_type] = scan_directory(models_dir, model_extensions)
|
||||
models["options"][model_type] = scan_directory(models_dir, model_extensions)
|
||||
except MaliciousModelException as e:
|
||||
models['scan-error'] = e
|
||||
models["scan-error"] = e
|
||||
|
||||
# custom models
|
||||
listModels(model_type='stable-diffusion')
|
||||
listModels(model_type='vae')
|
||||
listModels(model_type='hypernetwork')
|
||||
listModels(model_type="stable-diffusion")
|
||||
listModels(model_type="vae")
|
||||
listModels(model_type="hypernetwork")
|
||||
listModels(model_type="gfpgan")
|
||||
|
||||
if models_scanned > 0: log.info(f'[green]Scanned {models_scanned} models. Nothing infected[/]')
|
||||
if models_scanned > 0:
|
||||
log.info(f"[green]Scanned {models_scanned} models. Nothing infected[/]")
|
||||
|
||||
# legacy
|
||||
custom_weight_path = os.path.join(app.SD_DIR, 'custom-model.ckpt')
|
||||
custom_weight_path = os.path.join(app.SD_DIR, "custom-model.ckpt")
|
||||
if os.path.exists(custom_weight_path):
|
||||
models['options']['stable-diffusion'].append('custom-model')
|
||||
models["options"]["stable-diffusion"].append("custom-model")
|
||||
|
||||
return models
|
||||
|
@ -12,22 +12,26 @@ from sdkit.generate import generate_images
|
||||
from sdkit.filter import apply_filters
|
||||
from sdkit.utils import img_to_buffer, img_to_base64_str, latent_samples_to_images, gc
|
||||
|
||||
context = Context() # thread-local
|
||||
'''
|
||||
context = Context() # thread-local
|
||||
"""
|
||||
runtime data (bound locally to this thread), for e.g. device, references to loaded models, optimization flags etc
|
||||
'''
|
||||
"""
|
||||
|
||||
|
||||
def init(device):
|
||||
'''
|
||||
"""
|
||||
Initializes the fields that will be bound to this runtime's context, and sets the current torch device
|
||||
'''
|
||||
"""
|
||||
context.stop_processing = False
|
||||
context.temp_images = {}
|
||||
context.partial_x_samples = None
|
||||
|
||||
device_manager.device_init(context, device)
|
||||
|
||||
def make_images(req: GenerateImageRequest, task_data: TaskData, data_queue: queue.Queue, task_temp_images: list, step_callback):
|
||||
|
||||
def make_images(
|
||||
req: GenerateImageRequest, task_data: TaskData, data_queue: queue.Queue, task_temp_images: list, step_callback
|
||||
):
|
||||
context.stop_processing = False
|
||||
print_task_info(req, task_data)
|
||||
|
||||
@ -36,18 +40,25 @@ def make_images(req: GenerateImageRequest, task_data: TaskData, data_queue: queu
|
||||
res = Response(req, task_data, images=construct_response(images, seeds, task_data, base_seed=req.seed))
|
||||
res = res.json()
|
||||
data_queue.put(json.dumps(res))
|
||||
log.info('Task completed')
|
||||
log.info("Task completed")
|
||||
|
||||
return res
|
||||
|
||||
def print_task_info(req: GenerateImageRequest, task_data: TaskData):
|
||||
req_str = pprint.pformat(get_printable_request(req)).replace("[","\[")
|
||||
task_str = pprint.pformat(task_data.dict()).replace("[","\[")
|
||||
log.info(f'request: {req_str}')
|
||||
log.info(f'task data: {task_str}')
|
||||
|
||||
def make_images_internal(req: GenerateImageRequest, task_data: TaskData, data_queue: queue.Queue, task_temp_images: list, step_callback):
|
||||
images, user_stopped = generate_images_internal(req, task_data, data_queue, task_temp_images, step_callback, task_data.stream_image_progress)
|
||||
def print_task_info(req: GenerateImageRequest, task_data: TaskData):
|
||||
req_str = pprint.pformat(get_printable_request(req)).replace("[", "\[")
|
||||
task_str = pprint.pformat(task_data.dict()).replace("[", "\[")
|
||||
log.info(f"request: {req_str}")
|
||||
log.info(f"task data: {task_str}")
|
||||
|
||||
|
||||
def make_images_internal(
|
||||
req: GenerateImageRequest, task_data: TaskData, data_queue: queue.Queue, task_temp_images: list, step_callback
|
||||
):
|
||||
|
||||
images, user_stopped = generate_images_internal(
|
||||
req, task_data, data_queue, task_temp_images, step_callback, task_data.stream_image_progress, task_data.stream_image_progress_interval
|
||||
)
|
||||
filtered_images = filter_images(task_data, images, user_stopped)
|
||||
|
||||
if task_data.save_to_disk_path is not None:
|
||||
@ -59,13 +70,23 @@ def make_images_internal(req: GenerateImageRequest, task_data: TaskData, data_qu
|
||||
else:
|
||||
return images + filtered_images, seeds + seeds
|
||||
|
||||
def generate_images_internal(req: GenerateImageRequest, task_data: TaskData, data_queue: queue.Queue, task_temp_images: list, step_callback, stream_image_progress: bool):
|
||||
|
||||
def generate_images_internal(
|
||||
req: GenerateImageRequest,
|
||||
task_data: TaskData,
|
||||
data_queue: queue.Queue,
|
||||
task_temp_images: list,
|
||||
step_callback,
|
||||
stream_image_progress: bool,
|
||||
stream_image_progress_interval: int,
|
||||
):
|
||||
context.temp_images.clear()
|
||||
|
||||
callback = make_step_callback(req, task_data, data_queue, task_temp_images, step_callback, stream_image_progress)
|
||||
callback = make_step_callback(req, task_data, data_queue, task_temp_images, step_callback, stream_image_progress, stream_image_progress_interval)
|
||||
|
||||
try:
|
||||
if req.init_image is not None: req.sampler_name = 'ddim'
|
||||
if req.init_image is not None:
|
||||
req.sampler_name = "ddim"
|
||||
|
||||
images = generate_images(context, callback=callback, **req.dict())
|
||||
user_stopped = False
|
||||
@ -75,31 +96,50 @@ def generate_images_internal(req: GenerateImageRequest, task_data: TaskData, dat
|
||||
if context.partial_x_samples is not None:
|
||||
images = latent_samples_to_images(context, context.partial_x_samples)
|
||||
finally:
|
||||
if hasattr(context, 'partial_x_samples') and context.partial_x_samples is not None:
|
||||
if hasattr(context, "partial_x_samples") and context.partial_x_samples is not None:
|
||||
del context.partial_x_samples
|
||||
context.partial_x_samples = None
|
||||
|
||||
return images, user_stopped
|
||||
|
||||
|
||||
def filter_images(task_data: TaskData, images: list, user_stopped):
|
||||
if user_stopped or (task_data.use_face_correction is None and task_data.use_upscale is None):
|
||||
if user_stopped:
|
||||
return images
|
||||
|
||||
filters_to_apply = []
|
||||
if task_data.use_face_correction and 'gfpgan' in task_data.use_face_correction.lower(): filters_to_apply.append('gfpgan')
|
||||
if task_data.use_upscale and 'realesrgan' in task_data.use_upscale.lower(): filters_to_apply.append('realesrgan')
|
||||
if task_data.block_nsfw:
|
||||
filters_to_apply.append("nsfw_checker")
|
||||
if task_data.use_face_correction and "gfpgan" in task_data.use_face_correction.lower():
|
||||
filters_to_apply.append("gfpgan")
|
||||
if task_data.use_upscale and "realesrgan" in task_data.use_upscale.lower():
|
||||
filters_to_apply.append("realesrgan")
|
||||
|
||||
if len(filters_to_apply) == 0:
|
||||
return images
|
||||
|
||||
return apply_filters(context, filters_to_apply, images, scale=task_data.upscale_amount)
|
||||
|
||||
|
||||
def construct_response(images: list, seeds: list, task_data: TaskData, base_seed: int):
|
||||
return [
|
||||
ResponseImage(
|
||||
data=img_to_base64_str(img, task_data.output_format, task_data.output_quality),
|
||||
seed=seed,
|
||||
) for img, seed in zip(images, seeds)
|
||||
)
|
||||
for img, seed in zip(images, seeds)
|
||||
]
|
||||
|
||||
def make_step_callback(req: GenerateImageRequest, task_data: TaskData, data_queue: queue.Queue, task_temp_images: list, step_callback, stream_image_progress: bool):
|
||||
|
||||
def make_step_callback(
|
||||
req: GenerateImageRequest,
|
||||
task_data: TaskData,
|
||||
data_queue: queue.Queue,
|
||||
task_temp_images: list,
|
||||
step_callback,
|
||||
stream_image_progress: bool,
|
||||
stream_image_progress_interval: int,
|
||||
):
|
||||
n_steps = req.num_inference_steps if req.init_image is None else int(req.num_inference_steps * req.prompt_strength)
|
||||
last_callback_time = -1
|
||||
|
||||
@ -107,11 +147,11 @@ def make_step_callback(req: GenerateImageRequest, task_data: TaskData, data_queu
|
||||
partial_images = []
|
||||
images = latent_samples_to_images(context, x_samples)
|
||||
for i, img in enumerate(images):
|
||||
buf = img_to_buffer(img, output_format='JPEG')
|
||||
buf = img_to_buffer(img, output_format="JPEG")
|
||||
|
||||
context.temp_images[f"{task_data.request_id}/{i}"] = buf
|
||||
task_temp_images[i] = buf
|
||||
partial_images.append({'path': f"/image/tmp/{task_data.request_id}/{i}"})
|
||||
partial_images.append({"path": f"/image/tmp/{task_data.request_id}/{i}"})
|
||||
del images
|
||||
return partial_images
|
||||
|
||||
@ -124,8 +164,8 @@ def make_step_callback(req: GenerateImageRequest, task_data: TaskData, data_queu
|
||||
|
||||
progress = {"step": i, "step_time": step_time, "total_steps": n_steps}
|
||||
|
||||
if stream_image_progress and i % 5 == 0:
|
||||
progress['output'] = update_temp_img(x_samples, task_temp_images)
|
||||
if stream_image_progress and stream_image_progress_interval > 0 and i % stream_image_progress_interval == 0:
|
||||
progress["output"] = update_temp_img(x_samples, task_temp_images)
|
||||
|
||||
data_queue.put(json.dumps(progress))
|
||||
|
||||
|
@ -16,21 +16,25 @@ from easydiffusion import app, model_manager, task_manager
|
||||
from easydiffusion.types import TaskData, GenerateImageRequest, MergeRequest
|
||||
from easydiffusion.utils import log
|
||||
|
||||
log.info(f'started in {app.SD_DIR}')
|
||||
log.info(f'started at {datetime.datetime.now():%x %X}')
|
||||
log.info(f"started in {app.SD_DIR}")
|
||||
log.info(f"started at {datetime.datetime.now():%x %X}")
|
||||
|
||||
server_api = FastAPI()
|
||||
|
||||
NOCACHE_HEADERS={"Cache-Control": "no-cache, no-store, must-revalidate", "Pragma": "no-cache", "Expires": "0"}
|
||||
NOCACHE_HEADERS = {"Cache-Control": "no-cache, no-store, must-revalidate", "Pragma": "no-cache", "Expires": "0"}
|
||||
|
||||
|
||||
class NoCacheStaticFiles(StaticFiles):
|
||||
def is_not_modified(self, response_headers, request_headers) -> bool:
|
||||
if 'content-type' in response_headers and ('javascript' in response_headers['content-type'] or 'css' in response_headers['content-type']):
|
||||
if "content-type" in response_headers and (
|
||||
"javascript" in response_headers["content-type"] or "css" in response_headers["content-type"]
|
||||
):
|
||||
response_headers.update(NOCACHE_HEADERS)
|
||||
return False
|
||||
|
||||
return super().is_not_modified(response_headers, request_headers)
|
||||
|
||||
|
||||
class SetAppConfigRequest(BaseModel):
|
||||
update_branch: str = None
|
||||
render_devices: Union[List[str], List[int], str, int] = None
|
||||
@ -39,203 +43,243 @@ class SetAppConfigRequest(BaseModel):
|
||||
listen_to_network: bool = None
|
||||
listen_port: int = None
|
||||
|
||||
|
||||
def init():
|
||||
server_api.mount('/media', NoCacheStaticFiles(directory=os.path.join(app.SD_UI_DIR, 'media')), name="media")
|
||||
server_api.mount("/media", NoCacheStaticFiles(directory=os.path.join(app.SD_UI_DIR, "media")), name="media")
|
||||
|
||||
for plugins_dir, dir_prefix in app.UI_PLUGINS_SOURCES:
|
||||
server_api.mount(f'/plugins/{dir_prefix}', NoCacheStaticFiles(directory=plugins_dir), name=f"plugins-{dir_prefix}")
|
||||
server_api.mount(
|
||||
f"/plugins/{dir_prefix}", NoCacheStaticFiles(directory=plugins_dir), name=f"plugins-{dir_prefix}"
|
||||
)
|
||||
|
||||
@server_api.post('/app_config')
|
||||
async def set_app_config(req : SetAppConfigRequest):
|
||||
@server_api.post("/app_config")
|
||||
async def set_app_config(req: SetAppConfigRequest):
|
||||
return set_app_config_internal(req)
|
||||
|
||||
@server_api.get('/get/{key:path}')
|
||||
def read_web_data(key:str=None):
|
||||
@server_api.get("/get/{key:path}")
|
||||
def read_web_data(key: str = None):
|
||||
return read_web_data_internal(key)
|
||||
|
||||
@server_api.get('/ping') # Get server and optionally session status.
|
||||
def ping(session_id:str=None):
|
||||
@server_api.get("/ping") # Get server and optionally session status.
|
||||
def ping(session_id: str = None):
|
||||
return ping_internal(session_id)
|
||||
|
||||
@server_api.post('/render')
|
||||
@server_api.post("/render")
|
||||
def render(req: dict):
|
||||
return render_internal(req)
|
||||
|
||||
@server_api.post('/model/merge')
|
||||
@server_api.post("/model/merge")
|
||||
def model_merge(req: dict):
|
||||
print(req)
|
||||
return model_merge_internal(req)
|
||||
|
||||
@server_api.get('/image/stream/{task_id:int}')
|
||||
def stream(task_id:int):
|
||||
@server_api.get("/image/stream/{task_id:int}")
|
||||
def stream(task_id: int):
|
||||
return stream_internal(task_id)
|
||||
|
||||
@server_api.get('/image/stop')
|
||||
@server_api.get("/image/stop")
|
||||
def stop(task: int):
|
||||
return stop_internal(task)
|
||||
|
||||
@server_api.get('/image/tmp/{task_id:int}/{img_id:int}')
|
||||
@server_api.get("/image/tmp/{task_id:int}/{img_id:int}")
|
||||
def get_image(task_id: int, img_id: int):
|
||||
return get_image_internal(task_id, img_id)
|
||||
|
||||
@server_api.get('/')
|
||||
@server_api.get("/")
|
||||
def read_root():
|
||||
return FileResponse(os.path.join(app.SD_UI_DIR, 'index.html'), headers=NOCACHE_HEADERS)
|
||||
return FileResponse(os.path.join(app.SD_UI_DIR, "index.html"), headers=NOCACHE_HEADERS)
|
||||
|
||||
@server_api.on_event("shutdown")
|
||||
def shutdown_event(): # Signal render thread to close on shutdown
|
||||
task_manager.current_state_error = SystemExit('Application shutting down.')
|
||||
def shutdown_event(): # Signal render thread to close on shutdown
|
||||
task_manager.current_state_error = SystemExit("Application shutting down.")
|
||||
|
||||
|
||||
# API implementations
|
||||
def set_app_config_internal(req : SetAppConfigRequest):
|
||||
def set_app_config_internal(req: SetAppConfigRequest):
|
||||
config = app.getConfig()
|
||||
if req.update_branch is not None:
|
||||
config['update_branch'] = req.update_branch
|
||||
config["update_branch"] = req.update_branch
|
||||
if req.render_devices is not None:
|
||||
update_render_devices_in_config(config, req.render_devices)
|
||||
if req.ui_open_browser_on_start is not None:
|
||||
if 'ui' not in config:
|
||||
config['ui'] = {}
|
||||
config['ui']['open_browser_on_start'] = req.ui_open_browser_on_start
|
||||
if "ui" not in config:
|
||||
config["ui"] = {}
|
||||
config["ui"]["open_browser_on_start"] = req.ui_open_browser_on_start
|
||||
if req.listen_to_network is not None:
|
||||
if 'net' not in config:
|
||||
config['net'] = {}
|
||||
config['net']['listen_to_network'] = bool(req.listen_to_network)
|
||||
if "net" not in config:
|
||||
config["net"] = {}
|
||||
config["net"]["listen_to_network"] = bool(req.listen_to_network)
|
||||
if req.listen_port is not None:
|
||||
if 'net' not in config:
|
||||
config['net'] = {}
|
||||
config['net']['listen_port'] = int(req.listen_port)
|
||||
if "net" not in config:
|
||||
config["net"] = {}
|
||||
config["net"]["listen_port"] = int(req.listen_port)
|
||||
try:
|
||||
app.setConfig(config)
|
||||
|
||||
if req.render_devices:
|
||||
app.update_render_threads()
|
||||
|
||||
return JSONResponse({'status': 'OK'}, headers=NOCACHE_HEADERS)
|
||||
return JSONResponse({"status": "OK"}, headers=NOCACHE_HEADERS)
|
||||
except Exception as e:
|
||||
log.error(traceback.format_exc())
|
||||
raise HTTPException(status_code=500, detail=str(e))
|
||||
|
||||
|
||||
def update_render_devices_in_config(config, render_devices):
|
||||
if render_devices not in ('cpu', 'auto') and not render_devices.startswith('cuda:'):
|
||||
raise HTTPException(status_code=400, detail=f'Invalid render device requested: {render_devices}')
|
||||
if render_devices not in ("cpu", "auto") and not render_devices.startswith("cuda:"):
|
||||
raise HTTPException(status_code=400, detail=f"Invalid render device requested: {render_devices}")
|
||||
|
||||
if render_devices.startswith('cuda:'):
|
||||
render_devices = render_devices.split(',')
|
||||
if render_devices.startswith("cuda:"):
|
||||
render_devices = render_devices.split(",")
|
||||
|
||||
config['render_devices'] = render_devices
|
||||
config["render_devices"] = render_devices
|
||||
|
||||
def read_web_data_internal(key:str=None):
|
||||
if not key: # /get without parameters, stable-diffusion easter egg.
|
||||
raise HTTPException(status_code=418, detail="StableDiffusion is drawing a teapot!") # HTTP418 I'm a teapot
|
||||
elif key == 'app_config':
|
||||
|
||||
def read_web_data_internal(key: str = None):
|
||||
if not key: # /get without parameters, stable-diffusion easter egg.
|
||||
raise HTTPException(status_code=418, detail="StableDiffusion is drawing a teapot!") # HTTP418 I'm a teapot
|
||||
elif key == "app_config":
|
||||
return JSONResponse(app.getConfig(), headers=NOCACHE_HEADERS)
|
||||
elif key == 'system_info':
|
||||
elif key == "system_info":
|
||||
config = app.getConfig()
|
||||
system_info = {
|
||||
'devices': task_manager.get_devices(),
|
||||
'hosts': app.getIPConfig(),
|
||||
'default_output_dir': os.path.join(os.path.expanduser("~"), app.OUTPUT_DIRNAME),
|
||||
}
|
||||
system_info['devices']['config'] = config.get('render_devices', "auto")
|
||||
return JSONResponse(system_info, headers=NOCACHE_HEADERS)
|
||||
elif key == 'models':
|
||||
return JSONResponse(model_manager.getModels(), headers=NOCACHE_HEADERS)
|
||||
elif key == 'modifiers': return FileResponse(os.path.join(app.SD_UI_DIR, 'modifiers.json'), headers=NOCACHE_HEADERS)
|
||||
elif key == 'ui_plugins': return JSONResponse(app.getUIPlugins(), headers=NOCACHE_HEADERS)
|
||||
else:
|
||||
raise HTTPException(status_code=404, detail=f'Request for unknown {key}') # HTTP404 Not Found
|
||||
|
||||
def ping_internal(session_id:str=None):
|
||||
if task_manager.is_alive() <= 0: # Check that render threads are alive.
|
||||
if task_manager.current_state_error: raise HTTPException(status_code=500, detail=str(task_manager.current_state_error))
|
||||
raise HTTPException(status_code=500, detail='Render thread is dead.')
|
||||
if task_manager.current_state_error and not isinstance(task_manager.current_state_error, StopAsyncIteration): raise HTTPException(status_code=500, detail=str(task_manager.current_state_error))
|
||||
output_dir = config.get("force_save_path", os.path.join(os.path.expanduser("~"), app.OUTPUT_DIRNAME))
|
||||
|
||||
system_info = {
|
||||
"devices": task_manager.get_devices(),
|
||||
"hosts": app.getIPConfig(),
|
||||
"default_output_dir": output_dir,
|
||||
"enforce_output_dir": ("force_save_path" in config),
|
||||
}
|
||||
system_info["devices"]["config"] = config.get("render_devices", "auto")
|
||||
return JSONResponse(system_info, headers=NOCACHE_HEADERS)
|
||||
elif key == "models":
|
||||
return JSONResponse(model_manager.getModels(), headers=NOCACHE_HEADERS)
|
||||
elif key == "modifiers":
|
||||
return FileResponse(os.path.join(app.SD_UI_DIR, "modifiers.json"), headers=NOCACHE_HEADERS)
|
||||
elif key == "ui_plugins":
|
||||
return JSONResponse(app.getUIPlugins(), headers=NOCACHE_HEADERS)
|
||||
else:
|
||||
raise HTTPException(status_code=404, detail=f"Request for unknown {key}") # HTTP404 Not Found
|
||||
|
||||
|
||||
def ping_internal(session_id: str = None):
|
||||
if task_manager.is_alive() <= 0: # Check that render threads are alive.
|
||||
if task_manager.current_state_error:
|
||||
raise HTTPException(status_code=500, detail=str(task_manager.current_state_error))
|
||||
raise HTTPException(status_code=500, detail="Render thread is dead.")
|
||||
if task_manager.current_state_error and not isinstance(task_manager.current_state_error, StopAsyncIteration):
|
||||
raise HTTPException(status_code=500, detail=str(task_manager.current_state_error))
|
||||
# Alive
|
||||
response = {'status': str(task_manager.current_state)}
|
||||
response = {"status": str(task_manager.current_state)}
|
||||
if session_id:
|
||||
session = task_manager.get_cached_session(session_id, update_ttl=True)
|
||||
response['tasks'] = {id(t): t.status for t in session.tasks}
|
||||
response['devices'] = task_manager.get_devices()
|
||||
response["tasks"] = {id(t): t.status for t in session.tasks}
|
||||
response["devices"] = task_manager.get_devices()
|
||||
return JSONResponse(response, headers=NOCACHE_HEADERS)
|
||||
|
||||
|
||||
def render_internal(req: dict):
|
||||
try:
|
||||
# separate out the request data into rendering and task-specific data
|
||||
render_req: GenerateImageRequest = GenerateImageRequest.parse_obj(req)
|
||||
task_data: TaskData = TaskData.parse_obj(req)
|
||||
|
||||
render_req.init_image_mask = req.get('mask') # hack: will rename this in the HTTP API in a future revision
|
||||
# Overwrite user specified save path
|
||||
config = app.getConfig()
|
||||
if "force_save_path" in config:
|
||||
task_data.save_to_disk_path = config["force_save_path"]
|
||||
|
||||
app.save_to_config(task_data.use_stable_diffusion_model, task_data.use_vae_model, task_data.use_hypernetwork_model, task_data.vram_usage_level)
|
||||
render_req.init_image_mask = req.get("mask") # hack: will rename this in the HTTP API in a future revision
|
||||
|
||||
app.save_to_config(
|
||||
task_data.use_stable_diffusion_model,
|
||||
task_data.use_vae_model,
|
||||
task_data.use_hypernetwork_model,
|
||||
task_data.vram_usage_level,
|
||||
)
|
||||
|
||||
# enqueue the task
|
||||
new_task = task_manager.render(render_req, task_data)
|
||||
response = {
|
||||
'status': str(task_manager.current_state),
|
||||
'queue': len(task_manager.tasks_queue),
|
||||
'stream': f'/image/stream/{id(new_task)}',
|
||||
'task': id(new_task)
|
||||
"status": str(task_manager.current_state),
|
||||
"queue": len(task_manager.tasks_queue),
|
||||
"stream": f"/image/stream/{id(new_task)}",
|
||||
"task": id(new_task),
|
||||
}
|
||||
return JSONResponse(response, headers=NOCACHE_HEADERS)
|
||||
except ChildProcessError as e: # Render thread is dead
|
||||
raise HTTPException(status_code=500, detail=f'Rendering thread has died.') # HTTP500 Internal Server Error
|
||||
except ConnectionRefusedError as e: # Unstarted task pending limit reached, deny queueing too many.
|
||||
raise HTTPException(status_code=503, detail=str(e)) # HTTP503 Service Unavailable
|
||||
except ChildProcessError as e: # Render thread is dead
|
||||
raise HTTPException(status_code=500, detail=f"Rendering thread has died.") # HTTP500 Internal Server Error
|
||||
except ConnectionRefusedError as e: # Unstarted task pending limit reached, deny queueing too many.
|
||||
raise HTTPException(status_code=503, detail=str(e)) # HTTP503 Service Unavailable
|
||||
except Exception as e:
|
||||
log.error(traceback.format_exc())
|
||||
raise HTTPException(status_code=500, detail=str(e))
|
||||
|
||||
|
||||
def model_merge_internal(req: dict):
|
||||
try:
|
||||
from sdkit.train import merge_models
|
||||
from easydiffusion.utils.save_utils import filename_regex
|
||||
|
||||
mergeReq: MergeRequest = MergeRequest.parse_obj(req)
|
||||
|
||||
merge_models(model_manager.resolve_model_to_use(mergeReq.model0,'stable-diffusion'),
|
||||
model_manager.resolve_model_to_use(mergeReq.model1,'stable-diffusion'),
|
||||
mergeReq.ratio,
|
||||
os.path.join(app.MODELS_DIR, 'stable-diffusion', filename_regex.sub('_', mergeReq.out_path)),
|
||||
mergeReq.use_fp16
|
||||
|
||||
merge_models(
|
||||
model_manager.resolve_model_to_use(mergeReq.model0, "stable-diffusion"),
|
||||
model_manager.resolve_model_to_use(mergeReq.model1, "stable-diffusion"),
|
||||
mergeReq.ratio,
|
||||
os.path.join(app.MODELS_DIR, "stable-diffusion", filename_regex.sub("_", mergeReq.out_path)),
|
||||
mergeReq.use_fp16,
|
||||
)
|
||||
return JSONResponse({'status':'OK'}, headers=NOCACHE_HEADERS)
|
||||
return JSONResponse({"status": "OK"}, headers=NOCACHE_HEADERS)
|
||||
except Exception as e:
|
||||
log.error(traceback.format_exc())
|
||||
raise HTTPException(status_code=500, detail=str(e))
|
||||
|
||||
def stream_internal(task_id:int):
|
||||
#TODO Move to WebSockets ??
|
||||
|
||||
def stream_internal(task_id: int):
|
||||
# TODO Move to WebSockets ??
|
||||
task = task_manager.get_cached_task(task_id, update_ttl=True)
|
||||
if not task: raise HTTPException(status_code=404, detail=f'Request {task_id} not found.') # HTTP404 NotFound
|
||||
#if (id(task) != task_id): raise HTTPException(status_code=409, detail=f'Wrong task id received. Expected:{id(task)}, Received:{task_id}') # HTTP409 Conflict
|
||||
if not task:
|
||||
raise HTTPException(status_code=404, detail=f"Request {task_id} not found.") # HTTP404 NotFound
|
||||
# if (id(task) != task_id): raise HTTPException(status_code=409, detail=f'Wrong task id received. Expected:{id(task)}, Received:{task_id}') # HTTP409 Conflict
|
||||
if task.buffer_queue.empty() and not task.lock.locked():
|
||||
if task.response:
|
||||
#log.info(f'Session {session_id} sending cached response')
|
||||
# log.info(f'Session {session_id} sending cached response')
|
||||
return JSONResponse(task.response, headers=NOCACHE_HEADERS)
|
||||
raise HTTPException(status_code=425, detail='Too Early, task not started yet.') # HTTP425 Too Early
|
||||
#log.info(f'Session {session_id} opened live render stream {id(task.buffer_queue)}')
|
||||
return StreamingResponse(task.read_buffer_generator(), media_type='application/json')
|
||||
raise HTTPException(status_code=425, detail="Too Early, task not started yet.") # HTTP425 Too Early
|
||||
# log.info(f'Session {session_id} opened live render stream {id(task.buffer_queue)}')
|
||||
return StreamingResponse(task.read_buffer_generator(), media_type="application/json")
|
||||
|
||||
|
||||
def stop_internal(task: int):
|
||||
if not task:
|
||||
if task_manager.current_state == task_manager.ServerStates.Online or task_manager.current_state == task_manager.ServerStates.Unavailable:
|
||||
raise HTTPException(status_code=409, detail='Not currently running any tasks.') # HTTP409 Conflict
|
||||
task_manager.current_state_error = StopAsyncIteration('')
|
||||
return {'OK'}
|
||||
if (
|
||||
task_manager.current_state == task_manager.ServerStates.Online
|
||||
or task_manager.current_state == task_manager.ServerStates.Unavailable
|
||||
):
|
||||
raise HTTPException(status_code=409, detail="Not currently running any tasks.") # HTTP409 Conflict
|
||||
task_manager.current_state_error = StopAsyncIteration("")
|
||||
return {"OK"}
|
||||
task_id = task
|
||||
task = task_manager.get_cached_task(task_id, update_ttl=False)
|
||||
if not task: raise HTTPException(status_code=404, detail=f'Task {task_id} was not found.') # HTTP404 Not Found
|
||||
if isinstance(task.error, StopAsyncIteration): raise HTTPException(status_code=409, detail=f'Task {task_id} is already stopped.') # HTTP409 Conflict
|
||||
task.error = StopAsyncIteration(f'Task {task_id} stop requested.')
|
||||
return {'OK'}
|
||||
if not task:
|
||||
raise HTTPException(status_code=404, detail=f"Task {task_id} was not found.") # HTTP404 Not Found
|
||||
if isinstance(task.error, StopAsyncIteration):
|
||||
raise HTTPException(status_code=409, detail=f"Task {task_id} is already stopped.") # HTTP409 Conflict
|
||||
task.error = StopAsyncIteration(f"Task {task_id} stop requested.")
|
||||
return {"OK"}
|
||||
|
||||
|
||||
def get_image_internal(task_id: int, img_id: int):
|
||||
task = task_manager.get_cached_task(task_id, update_ttl=True)
|
||||
if not task: raise HTTPException(status_code=410, detail=f'Task {task_id} could not be found.') # HTTP404 NotFound
|
||||
if not task.temp_images[img_id]: raise HTTPException(status_code=425, detail='Too Early, task data is not available yet.') # HTTP425 Too Early
|
||||
if not task:
|
||||
raise HTTPException(status_code=410, detail=f"Task {task_id} could not be found.") # HTTP404 NotFound
|
||||
if not task.temp_images[img_id]:
|
||||
raise HTTPException(status_code=425, detail="Too Early, task data is not available yet.") # HTTP425 Too Early
|
||||
try:
|
||||
img_data = task.temp_images[img_id]
|
||||
img_data.seek(0)
|
||||
return StreamingResponse(img_data, media_type='image/jpeg')
|
||||
return StreamingResponse(img_data, media_type="image/jpeg")
|
||||
except KeyError as e:
|
||||
raise HTTPException(status_code=500, detail=str(e))
|
||||
|
@ -7,7 +7,7 @@ Notes:
|
||||
import json
|
||||
import traceback
|
||||
|
||||
TASK_TTL = 15 * 60 # seconds, Discard last session's task timeout
|
||||
TASK_TTL = 15 * 60 # seconds, Discard last session's task timeout
|
||||
|
||||
import torch
|
||||
import queue, threading, time, weakref
|
||||
@ -19,71 +19,98 @@ from easydiffusion.utils import log
|
||||
|
||||
from sdkit.utils import gc
|
||||
|
||||
THREAD_NAME_PREFIX = ''
|
||||
ERR_LOCK_FAILED = ' failed to acquire lock within timeout.'
|
||||
LOCK_TIMEOUT = 15 # Maximum locking time in seconds before failing a task.
|
||||
THREAD_NAME_PREFIX = ""
|
||||
ERR_LOCK_FAILED = " failed to acquire lock within timeout."
|
||||
LOCK_TIMEOUT = 15 # Maximum locking time in seconds before failing a task.
|
||||
# It's better to get an exception than a deadlock... ALWAYS use timeout in critical paths.
|
||||
|
||||
DEVICE_START_TIMEOUT = 60 # seconds - Maximum time to wait for a render device to init.
|
||||
DEVICE_START_TIMEOUT = 60 # seconds - Maximum time to wait for a render device to init.
|
||||
|
||||
|
||||
class SymbolClass(type): # Print nicely formatted Symbol names.
|
||||
def __repr__(self):
|
||||
return self.__qualname__
|
||||
|
||||
def __str__(self):
|
||||
return self.__name__
|
||||
|
||||
|
||||
class Symbol(metaclass=SymbolClass):
|
||||
pass
|
||||
|
||||
class SymbolClass(type): # Print nicely formatted Symbol names.
|
||||
def __repr__(self): return self.__qualname__
|
||||
def __str__(self): return self.__name__
|
||||
class Symbol(metaclass=SymbolClass): pass
|
||||
|
||||
class ServerStates:
|
||||
class Init(Symbol): pass
|
||||
class LoadingModel(Symbol): pass
|
||||
class Online(Symbol): pass
|
||||
class Rendering(Symbol): pass
|
||||
class Unavailable(Symbol): pass
|
||||
class Init(Symbol):
|
||||
pass
|
||||
|
||||
class RenderTask(): # Task with output queue and completion lock.
|
||||
class LoadingModel(Symbol):
|
||||
pass
|
||||
|
||||
class Online(Symbol):
|
||||
pass
|
||||
|
||||
class Rendering(Symbol):
|
||||
pass
|
||||
|
||||
class Unavailable(Symbol):
|
||||
pass
|
||||
|
||||
|
||||
class RenderTask: # Task with output queue and completion lock.
|
||||
def __init__(self, req: GenerateImageRequest, task_data: TaskData):
|
||||
task_data.request_id = id(self)
|
||||
self.render_request: GenerateImageRequest = req # Initial Request
|
||||
self.task_data: TaskData = task_data
|
||||
self.response: Any = None # Copy of the last reponse
|
||||
self.render_device = None # Select the task affinity. (Not used to change active devices).
|
||||
self.temp_images:list = [None] * req.num_outputs * (1 if task_data.show_only_filtered_image else 2)
|
||||
self.response: Any = None # Copy of the last reponse
|
||||
self.render_device = None # Select the task affinity. (Not used to change active devices).
|
||||
self.temp_images: list = [None] * req.num_outputs * (1 if task_data.show_only_filtered_image else 2)
|
||||
self.error: Exception = None
|
||||
self.lock: threading.Lock = threading.Lock() # Locks at task start and unlocks when task is completed
|
||||
self.buffer_queue: queue.Queue = queue.Queue() # Queue of JSON string segments
|
||||
self.lock: threading.Lock = threading.Lock() # Locks at task start and unlocks when task is completed
|
||||
self.buffer_queue: queue.Queue = queue.Queue() # Queue of JSON string segments
|
||||
|
||||
async def read_buffer_generator(self):
|
||||
try:
|
||||
while not self.buffer_queue.empty():
|
||||
res = self.buffer_queue.get(block=False)
|
||||
self.buffer_queue.task_done()
|
||||
yield res
|
||||
except queue.Empty as e: yield
|
||||
except queue.Empty as e:
|
||||
yield
|
||||
|
||||
@property
|
||||
def status(self):
|
||||
if self.lock.locked():
|
||||
return 'running'
|
||||
return "running"
|
||||
if isinstance(self.error, StopAsyncIteration):
|
||||
return 'stopped'
|
||||
return "stopped"
|
||||
if self.error:
|
||||
return 'error'
|
||||
return "error"
|
||||
if not self.buffer_queue.empty():
|
||||
return 'buffer'
|
||||
return "buffer"
|
||||
if self.response:
|
||||
return 'completed'
|
||||
return 'pending'
|
||||
return "completed"
|
||||
return "pending"
|
||||
|
||||
@property
|
||||
def is_pending(self):
|
||||
return bool(not self.response and not self.error)
|
||||
|
||||
|
||||
# Temporary cache to allow to query tasks results for a short time after they are completed.
|
||||
class DataCache():
|
||||
class DataCache:
|
||||
def __init__(self):
|
||||
self._base = dict()
|
||||
self._lock: threading.Lock = threading.Lock()
|
||||
|
||||
def _get_ttl_time(self, ttl: int) -> int:
|
||||
return int(time.time()) + ttl
|
||||
|
||||
def _is_expired(self, timestamp: int) -> bool:
|
||||
return int(time.time()) >= timestamp
|
||||
|
||||
def clean(self) -> None:
|
||||
if not self._lock.acquire(blocking=True, timeout=LOCK_TIMEOUT): raise Exception('DataCache.clean' + ERR_LOCK_FAILED)
|
||||
if not self._lock.acquire(blocking=True, timeout=LOCK_TIMEOUT):
|
||||
raise Exception("DataCache.clean" + ERR_LOCK_FAILED)
|
||||
try:
|
||||
# Create a list of expired keys to delete
|
||||
to_delete = []
|
||||
@ -95,20 +122,26 @@ class DataCache():
|
||||
for key in to_delete:
|
||||
(_, val) = self._base[key]
|
||||
if isinstance(val, RenderTask):
|
||||
log.debug(f'RenderTask {key} expired. Data removed.')
|
||||
log.debug(f"RenderTask {key} expired. Data removed.")
|
||||
elif isinstance(val, SessionState):
|
||||
log.debug(f'Session {key} expired. Data removed.')
|
||||
log.debug(f"Session {key} expired. Data removed.")
|
||||
else:
|
||||
log.debug(f'Key {key} expired. Data removed.')
|
||||
log.debug(f"Key {key} expired. Data removed.")
|
||||
del self._base[key]
|
||||
finally:
|
||||
self._lock.release()
|
||||
|
||||
def clear(self) -> None:
|
||||
if not self._lock.acquire(blocking=True, timeout=LOCK_TIMEOUT): raise Exception('DataCache.clear' + ERR_LOCK_FAILED)
|
||||
try: self._base.clear()
|
||||
finally: self._lock.release()
|
||||
if not self._lock.acquire(blocking=True, timeout=LOCK_TIMEOUT):
|
||||
raise Exception("DataCache.clear" + ERR_LOCK_FAILED)
|
||||
try:
|
||||
self._base.clear()
|
||||
finally:
|
||||
self._lock.release()
|
||||
|
||||
def delete(self, key: Hashable) -> bool:
|
||||
if not self._lock.acquire(blocking=True, timeout=LOCK_TIMEOUT): raise Exception('DataCache.delete' + ERR_LOCK_FAILED)
|
||||
if not self._lock.acquire(blocking=True, timeout=LOCK_TIMEOUT):
|
||||
raise Exception("DataCache.delete" + ERR_LOCK_FAILED)
|
||||
try:
|
||||
if key not in self._base:
|
||||
return False
|
||||
@ -116,8 +149,10 @@ class DataCache():
|
||||
return True
|
||||
finally:
|
||||
self._lock.release()
|
||||
|
||||
def keep(self, key: Hashable, ttl: int) -> bool:
|
||||
if not self._lock.acquire(blocking=True, timeout=LOCK_TIMEOUT): raise Exception('DataCache.keep' + ERR_LOCK_FAILED)
|
||||
if not self._lock.acquire(blocking=True, timeout=LOCK_TIMEOUT):
|
||||
raise Exception("DataCache.keep" + ERR_LOCK_FAILED)
|
||||
try:
|
||||
if key in self._base:
|
||||
_, value = self._base.get(key)
|
||||
@ -126,12 +161,12 @@ class DataCache():
|
||||
return False
|
||||
finally:
|
||||
self._lock.release()
|
||||
|
||||
def put(self, key: Hashable, value: Any, ttl: int) -> bool:
|
||||
if not self._lock.acquire(blocking=True, timeout=LOCK_TIMEOUT): raise Exception('DataCache.put' + ERR_LOCK_FAILED)
|
||||
if not self._lock.acquire(blocking=True, timeout=LOCK_TIMEOUT):
|
||||
raise Exception("DataCache.put" + ERR_LOCK_FAILED)
|
||||
try:
|
||||
self._base[key] = (
|
||||
self._get_ttl_time(ttl), value
|
||||
)
|
||||
self._base[key] = (self._get_ttl_time(ttl), value)
|
||||
except Exception as e:
|
||||
log.error(traceback.format_exc())
|
||||
return False
|
||||
@ -139,35 +174,41 @@ class DataCache():
|
||||
return True
|
||||
finally:
|
||||
self._lock.release()
|
||||
|
||||
def tryGet(self, key: Hashable) -> Any:
|
||||
if not self._lock.acquire(blocking=True, timeout=LOCK_TIMEOUT): raise Exception('DataCache.tryGet' + ERR_LOCK_FAILED)
|
||||
if not self._lock.acquire(blocking=True, timeout=LOCK_TIMEOUT):
|
||||
raise Exception("DataCache.tryGet" + ERR_LOCK_FAILED)
|
||||
try:
|
||||
ttl, value = self._base.get(key, (None, None))
|
||||
if ttl is not None and self._is_expired(ttl):
|
||||
log.debug(f'Session {key} expired. Discarding data.')
|
||||
log.debug(f"Session {key} expired. Discarding data.")
|
||||
del self._base[key]
|
||||
return None
|
||||
return value
|
||||
finally:
|
||||
self._lock.release()
|
||||
|
||||
|
||||
manager_lock = threading.RLock()
|
||||
render_threads = []
|
||||
current_state = ServerStates.Init
|
||||
current_state_error:Exception = None
|
||||
current_state_error: Exception = None
|
||||
tasks_queue = []
|
||||
session_cache = DataCache()
|
||||
task_cache = DataCache()
|
||||
weak_thread_data = weakref.WeakKeyDictionary()
|
||||
idle_event: threading.Event = threading.Event()
|
||||
|
||||
class SessionState():
|
||||
|
||||
class SessionState:
|
||||
def __init__(self, id: str):
|
||||
self._id = id
|
||||
self._tasks_ids = []
|
||||
|
||||
@property
|
||||
def id(self):
|
||||
return self._id
|
||||
|
||||
@property
|
||||
def tasks(self):
|
||||
tasks = []
|
||||
@ -176,6 +217,7 @@ class SessionState():
|
||||
if task:
|
||||
tasks.append(task)
|
||||
return tasks
|
||||
|
||||
def put(self, task, ttl=TASK_TTL):
|
||||
task_id = id(task)
|
||||
self._tasks_ids.append(task_id)
|
||||
@ -185,10 +227,12 @@ class SessionState():
|
||||
self._tasks_ids.pop(0)
|
||||
return True
|
||||
|
||||
|
||||
def thread_get_next_task():
|
||||
from easydiffusion import renderer
|
||||
|
||||
if not manager_lock.acquire(blocking=True, timeout=LOCK_TIMEOUT):
|
||||
log.warn(f'Render thread on device: {renderer.context.device} failed to acquire manager lock.')
|
||||
log.warn(f"Render thread on device: {renderer.context.device} failed to acquire manager lock.")
|
||||
return None
|
||||
if len(tasks_queue) <= 0:
|
||||
manager_lock.release()
|
||||
@ -202,10 +246,10 @@ def thread_get_next_task():
|
||||
continue # requested device alive, skip current one.
|
||||
else:
|
||||
# Requested device is not active, return error to UI.
|
||||
queued_task.error = Exception(queued_task.render_device + ' is not currently active.')
|
||||
queued_task.error = Exception(queued_task.render_device + " is not currently active.")
|
||||
task = queued_task
|
||||
break
|
||||
if not queued_task.render_device and renderer.context.device == 'cpu' and is_alive() > 1:
|
||||
if not queued_task.render_device and renderer.context.device == "cpu" and is_alive() > 1:
|
||||
# not asking for any specific devices, cpu want to grab task but other render devices are alive.
|
||||
continue # Skip Tasks, don't run on CPU unless there is nothing else or user asked for it.
|
||||
task = queued_task
|
||||
@ -216,17 +260,19 @@ def thread_get_next_task():
|
||||
finally:
|
||||
manager_lock.release()
|
||||
|
||||
|
||||
def thread_render(device):
|
||||
global current_state, current_state_error
|
||||
|
||||
from easydiffusion import renderer, model_manager
|
||||
|
||||
try:
|
||||
renderer.init(device)
|
||||
|
||||
weak_thread_data[threading.current_thread()] = {
|
||||
'device': renderer.context.device,
|
||||
'device_name': renderer.context.device_name,
|
||||
'alive': True
|
||||
"device": renderer.context.device,
|
||||
"device_name": renderer.context.device_name,
|
||||
"alive": True,
|
||||
}
|
||||
|
||||
current_state = ServerStates.LoadingModel
|
||||
@ -235,17 +281,14 @@ def thread_render(device):
|
||||
current_state = ServerStates.Online
|
||||
except Exception as e:
|
||||
log.error(traceback.format_exc())
|
||||
weak_thread_data[threading.current_thread()] = {
|
||||
'error': e,
|
||||
'alive': False
|
||||
}
|
||||
weak_thread_data[threading.current_thread()] = {"error": e, "alive": False}
|
||||
return
|
||||
|
||||
while True:
|
||||
session_cache.clean()
|
||||
task_cache.clean()
|
||||
if not weak_thread_data[threading.current_thread()]['alive']:
|
||||
log.info(f'Shutting down thread for device {renderer.context.device}')
|
||||
if not weak_thread_data[threading.current_thread()]["alive"]:
|
||||
log.info(f"Shutting down thread for device {renderer.context.device}")
|
||||
model_manager.unload_all(renderer.context)
|
||||
return
|
||||
if isinstance(current_state_error, SystemExit):
|
||||
@ -258,39 +301,47 @@ def thread_render(device):
|
||||
continue
|
||||
if task.error is not None:
|
||||
log.error(task.error)
|
||||
task.response = {"status": 'failed', "detail": str(task.error)}
|
||||
task.response = {"status": "failed", "detail": str(task.error)}
|
||||
task.buffer_queue.put(json.dumps(task.response))
|
||||
continue
|
||||
if current_state_error:
|
||||
task.error = current_state_error
|
||||
task.response = {"status": 'failed', "detail": str(task.error)}
|
||||
task.response = {"status": "failed", "detail": str(task.error)}
|
||||
task.buffer_queue.put(json.dumps(task.response))
|
||||
continue
|
||||
log.info(f'Session {task.task_data.session_id} starting task {id(task)} on {renderer.context.device_name}')
|
||||
if not task.lock.acquire(blocking=False): raise Exception('Got locked task from queue.')
|
||||
log.info(f"Session {task.task_data.session_id} starting task {id(task)} on {renderer.context.device_name}")
|
||||
if not task.lock.acquire(blocking=False):
|
||||
raise Exception("Got locked task from queue.")
|
||||
try:
|
||||
|
||||
def step_callback():
|
||||
global current_state_error
|
||||
|
||||
if isinstance(current_state_error, SystemExit) or isinstance(current_state_error, StopAsyncIteration) or isinstance(task.error, StopAsyncIteration):
|
||||
if (
|
||||
isinstance(current_state_error, SystemExit)
|
||||
or isinstance(current_state_error, StopAsyncIteration)
|
||||
or isinstance(task.error, StopAsyncIteration)
|
||||
):
|
||||
renderer.context.stop_processing = True
|
||||
if isinstance(current_state_error, StopAsyncIteration):
|
||||
task.error = current_state_error
|
||||
current_state_error = None
|
||||
log.info(f'Session {task.task_data.session_id} sent cancel signal for task {id(task)}')
|
||||
log.info(f"Session {task.task_data.session_id} sent cancel signal for task {id(task)}")
|
||||
|
||||
current_state = ServerStates.LoadingModel
|
||||
model_manager.resolve_model_paths(task.task_data)
|
||||
model_manager.reload_models_if_necessary(renderer.context, task.task_data)
|
||||
|
||||
current_state = ServerStates.Rendering
|
||||
task.response = renderer.make_images(task.render_request, task.task_data, task.buffer_queue, task.temp_images, step_callback)
|
||||
task.response = renderer.make_images(
|
||||
task.render_request, task.task_data, task.buffer_queue, task.temp_images, step_callback
|
||||
)
|
||||
# Before looping back to the generator, mark cache as still alive.
|
||||
task_cache.keep(id(task), TASK_TTL)
|
||||
session_cache.keep(task.task_data.session_id, TASK_TTL)
|
||||
except Exception as e:
|
||||
task.error = str(e)
|
||||
task.response = {"status": 'failed', "detail": str(task.error)}
|
||||
task.response = {"status": "failed", "detail": str(task.error)}
|
||||
task.buffer_queue.put(json.dumps(task.response))
|
||||
log.error(traceback.format_exc())
|
||||
finally:
|
||||
@ -299,21 +350,25 @@ def thread_render(device):
|
||||
task_cache.keep(id(task), TASK_TTL)
|
||||
session_cache.keep(task.task_data.session_id, TASK_TTL)
|
||||
if isinstance(task.error, StopAsyncIteration):
|
||||
log.info(f'Session {task.task_data.session_id} task {id(task)} cancelled!')
|
||||
log.info(f"Session {task.task_data.session_id} task {id(task)} cancelled!")
|
||||
elif task.error is not None:
|
||||
log.info(f'Session {task.task_data.session_id} task {id(task)} failed!')
|
||||
log.info(f"Session {task.task_data.session_id} task {id(task)} failed!")
|
||||
else:
|
||||
log.info(f'Session {task.task_data.session_id} task {id(task)} completed by {renderer.context.device_name}.')
|
||||
log.info(
|
||||
f"Session {task.task_data.session_id} task {id(task)} completed by {renderer.context.device_name}."
|
||||
)
|
||||
current_state = ServerStates.Online
|
||||
|
||||
def get_cached_task(task_id:str, update_ttl:bool=False):
|
||||
|
||||
def get_cached_task(task_id: str, update_ttl: bool = False):
|
||||
# By calling keep before tryGet, wont discard if was expired.
|
||||
if update_ttl and not task_cache.keep(task_id, TASK_TTL):
|
||||
# Failed to keep task, already gone.
|
||||
return None
|
||||
return task_cache.tryGet(task_id)
|
||||
|
||||
def get_cached_session(session_id:str, update_ttl:bool=False):
|
||||
|
||||
def get_cached_session(session_id: str, update_ttl: bool = False):
|
||||
if update_ttl:
|
||||
session_cache.keep(session_id, TASK_TTL)
|
||||
session = session_cache.tryGet(session_id)
|
||||
@ -322,64 +377,68 @@ def get_cached_session(session_id:str, update_ttl:bool=False):
|
||||
session_cache.put(session_id, session, TASK_TTL)
|
||||
return session
|
||||
|
||||
|
||||
def get_devices():
|
||||
devices = {
|
||||
'all': {},
|
||||
'active': {},
|
||||
"all": {},
|
||||
"active": {},
|
||||
}
|
||||
|
||||
def get_device_info(device):
|
||||
if device == 'cpu':
|
||||
return {'name': device_manager.get_processor_name()}
|
||||
|
||||
if device == "cpu":
|
||||
return {"name": device_manager.get_processor_name()}
|
||||
|
||||
mem_free, mem_total = torch.cuda.mem_get_info(device)
|
||||
mem_free /= float(10**9)
|
||||
mem_total /= float(10**9)
|
||||
|
||||
return {
|
||||
'name': torch.cuda.get_device_name(device),
|
||||
'mem_free': mem_free,
|
||||
'mem_total': mem_total,
|
||||
'max_vram_usage_level': device_manager.get_max_vram_usage_level(device),
|
||||
"name": torch.cuda.get_device_name(device),
|
||||
"mem_free": mem_free,
|
||||
"mem_total": mem_total,
|
||||
"max_vram_usage_level": device_manager.get_max_vram_usage_level(device),
|
||||
}
|
||||
|
||||
# list the compatible devices
|
||||
gpu_count = torch.cuda.device_count()
|
||||
for device in range(gpu_count):
|
||||
device = f'cuda:{device}'
|
||||
device = f"cuda:{device}"
|
||||
if not device_manager.is_device_compatible(device):
|
||||
continue
|
||||
|
||||
devices['all'].update({device: get_device_info(device)})
|
||||
devices["all"].update({device: get_device_info(device)})
|
||||
|
||||
devices['all'].update({'cpu': get_device_info('cpu')})
|
||||
devices["all"].update({"cpu": get_device_info("cpu")})
|
||||
|
||||
# list the activated devices
|
||||
if not manager_lock.acquire(blocking=True, timeout=LOCK_TIMEOUT): raise Exception('get_devices' + ERR_LOCK_FAILED)
|
||||
if not manager_lock.acquire(blocking=True, timeout=LOCK_TIMEOUT):
|
||||
raise Exception("get_devices" + ERR_LOCK_FAILED)
|
||||
try:
|
||||
for rthread in render_threads:
|
||||
if not rthread.is_alive():
|
||||
continue
|
||||
weak_data = weak_thread_data.get(rthread)
|
||||
if not weak_data or not 'device' in weak_data or not 'device_name' in weak_data:
|
||||
if not weak_data or not "device" in weak_data or not "device_name" in weak_data:
|
||||
continue
|
||||
device = weak_data['device']
|
||||
devices['active'].update({device: get_device_info(device)})
|
||||
device = weak_data["device"]
|
||||
devices["active"].update({device: get_device_info(device)})
|
||||
finally:
|
||||
manager_lock.release()
|
||||
|
||||
return devices
|
||||
|
||||
|
||||
def is_alive(device=None):
|
||||
if not manager_lock.acquire(blocking=True, timeout=LOCK_TIMEOUT): raise Exception('is_alive' + ERR_LOCK_FAILED)
|
||||
if not manager_lock.acquire(blocking=True, timeout=LOCK_TIMEOUT):
|
||||
raise Exception("is_alive" + ERR_LOCK_FAILED)
|
||||
nbr_alive = 0
|
||||
try:
|
||||
for rthread in render_threads:
|
||||
if device is not None:
|
||||
weak_data = weak_thread_data.get(rthread)
|
||||
if weak_data is None or not 'device' in weak_data or weak_data['device'] is None:
|
||||
if weak_data is None or not "device" in weak_data or weak_data["device"] is None:
|
||||
continue
|
||||
thread_device = weak_data['device']
|
||||
thread_device = weak_data["device"]
|
||||
if thread_device != device:
|
||||
continue
|
||||
if rthread.is_alive():
|
||||
@ -388,11 +447,13 @@ def is_alive(device=None):
|
||||
finally:
|
||||
manager_lock.release()
|
||||
|
||||
|
||||
def start_render_thread(device):
|
||||
if not manager_lock.acquire(blocking=True, timeout=LOCK_TIMEOUT): raise Exception('start_render_thread' + ERR_LOCK_FAILED)
|
||||
log.info(f'Start new Rendering Thread on device: {device}')
|
||||
if not manager_lock.acquire(blocking=True, timeout=LOCK_TIMEOUT):
|
||||
raise Exception("start_render_thread" + ERR_LOCK_FAILED)
|
||||
log.info(f"Start new Rendering Thread on device: {device}")
|
||||
try:
|
||||
rthread = threading.Thread(target=thread_render, kwargs={'device': device})
|
||||
rthread = threading.Thread(target=thread_render, kwargs={"device": device})
|
||||
rthread.daemon = True
|
||||
rthread.name = THREAD_NAME_PREFIX + device
|
||||
rthread.start()
|
||||
@ -400,8 +461,8 @@ def start_render_thread(device):
|
||||
finally:
|
||||
manager_lock.release()
|
||||
timeout = DEVICE_START_TIMEOUT
|
||||
while not rthread.is_alive() or not rthread in weak_thread_data or not 'device' in weak_thread_data[rthread]:
|
||||
if rthread in weak_thread_data and 'error' in weak_thread_data[rthread]:
|
||||
while not rthread.is_alive() or not rthread in weak_thread_data or not "device" in weak_thread_data[rthread]:
|
||||
if rthread in weak_thread_data and "error" in weak_thread_data[rthread]:
|
||||
log.error(f"{rthread}, {device}, error: {weak_thread_data[rthread]['error']}")
|
||||
return False
|
||||
if timeout <= 0:
|
||||
@ -410,25 +471,27 @@ def start_render_thread(device):
|
||||
time.sleep(1)
|
||||
return True
|
||||
|
||||
|
||||
def stop_render_thread(device):
|
||||
try:
|
||||
device_manager.validate_device_id(device, log_prefix='stop_render_thread')
|
||||
device_manager.validate_device_id(device, log_prefix="stop_render_thread")
|
||||
except:
|
||||
log.error(traceback.format_exc())
|
||||
return False
|
||||
|
||||
if not manager_lock.acquire(blocking=True, timeout=LOCK_TIMEOUT): raise Exception('stop_render_thread' + ERR_LOCK_FAILED)
|
||||
log.info(f'Stopping Rendering Thread on device: {device}')
|
||||
if not manager_lock.acquire(blocking=True, timeout=LOCK_TIMEOUT):
|
||||
raise Exception("stop_render_thread" + ERR_LOCK_FAILED)
|
||||
log.info(f"Stopping Rendering Thread on device: {device}")
|
||||
|
||||
try:
|
||||
thread_to_remove = None
|
||||
for rthread in render_threads:
|
||||
weak_data = weak_thread_data.get(rthread)
|
||||
if weak_data is None or not 'device' in weak_data or weak_data['device'] is None:
|
||||
if weak_data is None or not "device" in weak_data or weak_data["device"] is None:
|
||||
continue
|
||||
thread_device = weak_data['device']
|
||||
thread_device = weak_data["device"]
|
||||
if thread_device == device:
|
||||
weak_data['alive'] = False
|
||||
weak_data["alive"] = False
|
||||
thread_to_remove = rthread
|
||||
break
|
||||
if thread_to_remove is not None:
|
||||
@ -439,44 +502,51 @@ def stop_render_thread(device):
|
||||
|
||||
return False
|
||||
|
||||
|
||||
def update_render_threads(render_devices, active_devices):
|
||||
devices_to_start, devices_to_stop = device_manager.get_device_delta(render_devices, active_devices)
|
||||
log.debug(f'devices_to_start: {devices_to_start}')
|
||||
log.debug(f'devices_to_stop: {devices_to_stop}')
|
||||
log.debug(f"devices_to_start: {devices_to_start}")
|
||||
log.debug(f"devices_to_stop: {devices_to_stop}")
|
||||
|
||||
for device in devices_to_stop:
|
||||
if is_alive(device) <= 0:
|
||||
log.debug(f'{device} is not alive')
|
||||
log.debug(f"{device} is not alive")
|
||||
continue
|
||||
if not stop_render_thread(device):
|
||||
log.warn(f'{device} could not stop render thread')
|
||||
log.warn(f"{device} could not stop render thread")
|
||||
|
||||
for device in devices_to_start:
|
||||
if is_alive(device) >= 1:
|
||||
log.debug(f'{device} already registered.')
|
||||
log.debug(f"{device} already registered.")
|
||||
continue
|
||||
if not start_render_thread(device):
|
||||
log.warn(f'{device} failed to start.')
|
||||
log.warn(f"{device} failed to start.")
|
||||
|
||||
if is_alive() <= 0: # No running devices, probably invalid user config.
|
||||
raise EnvironmentError('ERROR: No active render devices! Please verify the "render_devices" value in config.json')
|
||||
if is_alive() <= 0: # No running devices, probably invalid user config.
|
||||
raise EnvironmentError(
|
||||
'ERROR: No active render devices! Please verify the "render_devices" value in config.json'
|
||||
)
|
||||
|
||||
log.debug(f"active devices: {get_devices()['active']}")
|
||||
|
||||
def shutdown_event(): # Signal render thread to close on shutdown
|
||||
|
||||
def shutdown_event(): # Signal render thread to close on shutdown
|
||||
global current_state_error
|
||||
current_state_error = SystemExit('Application shutting down.')
|
||||
current_state_error = SystemExit("Application shutting down.")
|
||||
|
||||
|
||||
def render(render_req: GenerateImageRequest, task_data: TaskData):
|
||||
current_thread_count = is_alive()
|
||||
if current_thread_count <= 0: # Render thread is dead
|
||||
raise ChildProcessError('Rendering thread has died.')
|
||||
raise ChildProcessError("Rendering thread has died.")
|
||||
|
||||
# Alive, check if task in cache
|
||||
session = get_cached_session(task_data.session_id, update_ttl=True)
|
||||
pending_tasks = list(filter(lambda t: t.is_pending, session.tasks))
|
||||
if current_thread_count < len(pending_tasks):
|
||||
raise ConnectionRefusedError(f'Session {task_data.session_id} already has {len(pending_tasks)} pending tasks out of {current_thread_count}.')
|
||||
raise ConnectionRefusedError(
|
||||
f"Session {task_data.session_id} already has {len(pending_tasks)} pending tasks out of {current_thread_count}."
|
||||
)
|
||||
|
||||
new_task = RenderTask(render_req, task_data)
|
||||
if session.put(new_task, TASK_TTL):
|
||||
@ -489,4 +559,4 @@ def render(render_req: GenerateImageRequest, task_data: TaskData):
|
||||
return new_task
|
||||
finally:
|
||||
manager_lock.release()
|
||||
raise RuntimeError('Failed to add task to cache.')
|
||||
raise RuntimeError("Failed to add task to cache.")
|
||||
|
@ -1,6 +1,7 @@
|
||||
from pydantic import BaseModel
|
||||
from typing import Any
|
||||
|
||||
|
||||
class GenerateImageRequest(BaseModel):
|
||||
prompt: str = ""
|
||||
negative_prompt: str = ""
|
||||
@ -18,28 +19,32 @@ class GenerateImageRequest(BaseModel):
|
||||
prompt_strength: float = 0.8
|
||||
preserve_init_image_color_profile = False
|
||||
|
||||
sampler_name: str = None # "ddim", "plms", "heun", "euler", "euler_a", "dpm2", "dpm2_a", "lms"
|
||||
sampler_name: str = None # "ddim", "plms", "heun", "euler", "euler_a", "dpm2", "dpm2_a", "lms"
|
||||
hypernetwork_strength: float = 0
|
||||
|
||||
|
||||
class TaskData(BaseModel):
|
||||
request_id: str = None
|
||||
session_id: str = "session"
|
||||
save_to_disk_path: str = None
|
||||
vram_usage_level: str = "balanced" # or "low" or "medium"
|
||||
vram_usage_level: str = "balanced" # or "low" or "medium"
|
||||
|
||||
use_face_correction: str = None # or "GFPGANv1.3"
|
||||
use_upscale: str = None # or "RealESRGAN_x4plus" or "RealESRGAN_x4plus_anime_6B"
|
||||
upscale_amount: int = 4 # or 2
|
||||
use_face_correction: str = None # or "GFPGANv1.3"
|
||||
use_upscale: str = None # or "RealESRGAN_x4plus" or "RealESRGAN_x4plus_anime_6B"
|
||||
upscale_amount: int = 4 # or 2
|
||||
use_stable_diffusion_model: str = "sd-v1-4"
|
||||
# use_stable_diffusion_config: str = "v1-inference"
|
||||
use_vae_model: str = None
|
||||
use_hypernetwork_model: str = None
|
||||
|
||||
show_only_filtered_image: bool = False
|
||||
output_format: str = "jpeg" # or "png"
|
||||
block_nsfw: bool = False
|
||||
output_format: str = "jpeg" # or "png" or "webp"
|
||||
output_quality: int = 75
|
||||
metadata_output_format: str = "txt" # or "json"
|
||||
metadata_output_format: str = "txt" # or "json"
|
||||
stream_image_progress: bool = False
|
||||
stream_image_progress_interval: int = 5
|
||||
|
||||
|
||||
class MergeRequest(BaseModel):
|
||||
model0: str = None
|
||||
@ -48,8 +53,9 @@ class MergeRequest(BaseModel):
|
||||
out_path: str = "mix"
|
||||
use_fp16 = True
|
||||
|
||||
|
||||
class Image:
|
||||
data: str # base64
|
||||
data: str # base64
|
||||
seed: int
|
||||
is_nsfw: bool
|
||||
path_abs: str = None
|
||||
@ -65,6 +71,7 @@ class Image:
|
||||
"path_abs": self.path_abs,
|
||||
}
|
||||
|
||||
|
||||
class Response:
|
||||
render_request: GenerateImageRequest
|
||||
task_data: TaskData
|
||||
@ -80,7 +87,7 @@ class Response:
|
||||
del self.render_request.init_image_mask
|
||||
|
||||
res = {
|
||||
"status": 'succeeded',
|
||||
"status": "succeeded",
|
||||
"render_request": self.render_request.dict(),
|
||||
"task_data": self.task_data.dict(),
|
||||
"output": [],
|
||||
@ -91,5 +98,6 @@ class Response:
|
||||
|
||||
return res
|
||||
|
||||
|
||||
class UserInitiatedStop(Exception):
|
||||
pass
|
||||
|
@ -1,8 +1,8 @@
|
||||
import logging
|
||||
|
||||
log = logging.getLogger('easydiffusion')
|
||||
log = logging.getLogger("easydiffusion")
|
||||
|
||||
from .save_utils import (
|
||||
save_images_to_disk,
|
||||
get_printable_request,
|
||||
)
|
||||
)
|
||||
|
@ -7,82 +7,126 @@ from easydiffusion.types import TaskData, GenerateImageRequest
|
||||
|
||||
from sdkit.utils import save_images, save_dicts
|
||||
|
||||
filename_regex = re.compile('[^a-zA-Z0-9._-]')
|
||||
filename_regex = re.compile("[^a-zA-Z0-9._-]")
|
||||
|
||||
# keep in sync with `ui/media/js/dnd.js`
|
||||
TASK_TEXT_MAPPING = {
|
||||
'prompt': 'Prompt',
|
||||
'width': 'Width',
|
||||
'height': 'Height',
|
||||
'seed': 'Seed',
|
||||
'num_inference_steps': 'Steps',
|
||||
'guidance_scale': 'Guidance Scale',
|
||||
'prompt_strength': 'Prompt Strength',
|
||||
'use_face_correction': 'Use Face Correction',
|
||||
'use_upscale': 'Use Upscaling',
|
||||
'upscale_amount': 'Upscale By',
|
||||
'sampler_name': 'Sampler',
|
||||
'negative_prompt': 'Negative Prompt',
|
||||
'use_stable_diffusion_model': 'Stable Diffusion model',
|
||||
'use_hypernetwork_model': 'Hypernetwork model',
|
||||
'hypernetwork_strength': 'Hypernetwork Strength'
|
||||
"prompt": "Prompt",
|
||||
"width": "Width",
|
||||
"height": "Height",
|
||||
"seed": "Seed",
|
||||
"num_inference_steps": "Steps",
|
||||
"guidance_scale": "Guidance Scale",
|
||||
"prompt_strength": "Prompt Strength",
|
||||
"use_face_correction": "Use Face Correction",
|
||||
"use_upscale": "Use Upscaling",
|
||||
"upscale_amount": "Upscale By",
|
||||
"sampler_name": "Sampler",
|
||||
"negative_prompt": "Negative Prompt",
|
||||
"use_stable_diffusion_model": "Stable Diffusion model",
|
||||
"use_vae_model": "VAE model",
|
||||
"use_hypernetwork_model": "Hypernetwork model",
|
||||
"hypernetwork_strength": "Hypernetwork Strength",
|
||||
}
|
||||
|
||||
|
||||
def save_images_to_disk(images: list, filtered_images: list, req: GenerateImageRequest, task_data: TaskData):
|
||||
now = time.time()
|
||||
save_dir_path = os.path.join(task_data.save_to_disk_path, filename_regex.sub('_', task_data.session_id))
|
||||
save_dir_path = os.path.join(task_data.save_to_disk_path, filename_regex.sub("_", task_data.session_id))
|
||||
metadata_entries = get_metadata_entries_for_request(req, task_data)
|
||||
make_filename = make_filename_callback(req, now=now)
|
||||
|
||||
if task_data.show_only_filtered_image or filtered_images is images:
|
||||
save_images(filtered_images, save_dir_path, file_name=make_filename, output_format=task_data.output_format, output_quality=task_data.output_quality)
|
||||
save_dicts(metadata_entries, save_dir_path, file_name=make_filename, output_format=task_data.metadata_output_format)
|
||||
save_images(
|
||||
filtered_images,
|
||||
save_dir_path,
|
||||
file_name=make_filename,
|
||||
output_format=task_data.output_format,
|
||||
output_quality=task_data.output_quality,
|
||||
)
|
||||
if task_data.metadata_output_format.lower() in ["json", "txt", "embed"]:
|
||||
save_dicts(
|
||||
metadata_entries,
|
||||
save_dir_path,
|
||||
file_name=make_filename,
|
||||
output_format=task_data.metadata_output_format,
|
||||
file_format=task_data.output_format,
|
||||
)
|
||||
else:
|
||||
make_filter_filename = make_filename_callback(req, now=now, suffix='filtered')
|
||||
make_filter_filename = make_filename_callback(req, now=now, suffix="filtered")
|
||||
|
||||
save_images(
|
||||
images,
|
||||
save_dir_path,
|
||||
file_name=make_filename,
|
||||
output_format=task_data.output_format,
|
||||
output_quality=task_data.output_quality,
|
||||
)
|
||||
save_images(
|
||||
filtered_images,
|
||||
save_dir_path,
|
||||
file_name=make_filter_filename,
|
||||
output_format=task_data.output_format,
|
||||
output_quality=task_data.output_quality,
|
||||
)
|
||||
if task_data.metadata_output_format.lower() in ["json", "txt", "embed"]:
|
||||
save_dicts(
|
||||
metadata_entries,
|
||||
save_dir_path,
|
||||
file_name=make_filter_filename,
|
||||
output_format=task_data.metadata_output_format,
|
||||
file_format=task_data.output_format,
|
||||
)
|
||||
|
||||
save_images(images, save_dir_path, file_name=make_filename, output_format=task_data.output_format, output_quality=task_data.output_quality)
|
||||
save_images(filtered_images, save_dir_path, file_name=make_filter_filename, output_format=task_data.output_format, output_quality=task_data.output_quality)
|
||||
save_dicts(metadata_entries, save_dir_path, file_name=make_filter_filename, output_format=task_data.metadata_output_format)
|
||||
|
||||
def get_metadata_entries_for_request(req: GenerateImageRequest, task_data: TaskData):
|
||||
metadata = get_printable_request(req)
|
||||
metadata.update({
|
||||
'use_stable_diffusion_model': task_data.use_stable_diffusion_model,
|
||||
'use_vae_model': task_data.use_vae_model,
|
||||
'use_hypernetwork_model': task_data.use_hypernetwork_model,
|
||||
'use_face_correction': task_data.use_face_correction,
|
||||
'use_upscale': task_data.use_upscale,
|
||||
})
|
||||
if metadata['use_upscale'] is not None:
|
||||
metadata['upscale_amount'] = task_data.upscale_amount
|
||||
metadata.update(
|
||||
{
|
||||
"use_stable_diffusion_model": task_data.use_stable_diffusion_model,
|
||||
"use_vae_model": task_data.use_vae_model,
|
||||
"use_hypernetwork_model": task_data.use_hypernetwork_model,
|
||||
"use_face_correction": task_data.use_face_correction,
|
||||
"use_upscale": task_data.use_upscale,
|
||||
}
|
||||
)
|
||||
if metadata["use_upscale"] is not None:
|
||||
metadata["upscale_amount"] = task_data.upscale_amount
|
||||
if task_data.use_hypernetwork_model is None:
|
||||
del metadata["hypernetwork_strength"]
|
||||
|
||||
# if text, format it in the text format expected by the UI
|
||||
is_txt_format = (task_data.metadata_output_format.lower() == 'txt')
|
||||
is_txt_format = task_data.metadata_output_format.lower() == "txt"
|
||||
if is_txt_format:
|
||||
metadata = {TASK_TEXT_MAPPING[key]: val for key, val in metadata.items() if key in TASK_TEXT_MAPPING}
|
||||
|
||||
entries = [metadata.copy() for _ in range(req.num_outputs)]
|
||||
for i, entry in enumerate(entries):
|
||||
entry['Seed' if is_txt_format else 'seed'] = req.seed + i
|
||||
entry["Seed" if is_txt_format else "seed"] = req.seed + i
|
||||
|
||||
return entries
|
||||
|
||||
|
||||
def get_printable_request(req: GenerateImageRequest):
|
||||
metadata = req.dict()
|
||||
del metadata['init_image']
|
||||
del metadata['init_image_mask']
|
||||
del metadata["init_image"]
|
||||
del metadata["init_image_mask"]
|
||||
if req.init_image is None:
|
||||
del metadata["prompt_strength"]
|
||||
return metadata
|
||||
|
||||
|
||||
def make_filename_callback(req: GenerateImageRequest, suffix=None, now=None):
|
||||
if now is None:
|
||||
now = time.time()
|
||||
def make_filename(i):
|
||||
img_id = base64.b64encode(int(now+i).to_bytes(8, 'big')).decode() # Generate unique ID based on time.
|
||||
img_id = img_id.translate({43:None, 47:None, 61:None})[-8:] # Remove + / = and keep last 8 chars.
|
||||
|
||||
prompt_flattened = filename_regex.sub('_', req.prompt)[:50]
|
||||
def make_filename(i):
|
||||
img_id = base64.b64encode(int(now + i).to_bytes(8, "big")).decode() # Generate unique ID based on time.
|
||||
img_id = img_id.translate({43: None, 47: None, 61: None})[-8:] # Remove + / = and keep last 8 chars.
|
||||
|
||||
prompt_flattened = filename_regex.sub("_", req.prompt)[:50]
|
||||
name = f"{prompt_flattened}_{img_id}"
|
||||
name = name if suffix is None else f'{name}_{suffix}'
|
||||
name = name if suffix is None else f"{name}_{suffix}"
|
||||
return name
|
||||
|
||||
return make_filename
|
||||
|
@ -14,6 +14,7 @@
|
||||
<link rel="stylesheet" href="/media/css/modifier-thumbnails.css">
|
||||
<link rel="stylesheet" href="/media/css/fontawesome-all.min.css">
|
||||
<link rel="stylesheet" href="/media/css/image-editor.css">
|
||||
<link rel="stylesheet" href="/media/css/searchable-models.css">
|
||||
<link rel="manifest" href="/media/manifest.webmanifest">
|
||||
<script src="/media/js/jquery-3.6.1.min.js"></script>
|
||||
<script src="/media/js/jquery-confirm.min.js"></script>
|
||||
@ -25,7 +26,7 @@
|
||||
<div id="logo">
|
||||
<h1>
|
||||
Easy Diffusion
|
||||
<small>v2.5.15 <span id="updateBranchLabel"></span></small>
|
||||
<small>v2.5.22 <span id="updateBranchLabel"></span></small>
|
||||
</h1>
|
||||
</div>
|
||||
<div id="server-status">
|
||||
@ -50,7 +51,7 @@
|
||||
<div id="editor">
|
||||
<div id="editor-inputs">
|
||||
<div id="editor-inputs-prompt" class="row">
|
||||
<label for="prompt"><b>Enter Prompt</b></label> <small>or</small> <button id="promptsFromFileBtn">Load from a file</button>
|
||||
<label for="prompt"><b>Enter Prompt</b></label> <small>or</small> <button id="promptsFromFileBtn" class="tertiaryButton">Load from a file</button>
|
||||
<textarea id="prompt" class="col-free">a photograph of an astronaut riding a horse</textarea>
|
||||
<input id="prompt_from_file" name="prompt_from_file" type="file" /> <!-- hidden -->
|
||||
<label for="negative_prompt" class="collapsible" id="negative_prompt_handle">
|
||||
@ -69,7 +70,7 @@
|
||||
<div id="init_image_preview_container" class="image_preview_container">
|
||||
<div id="init_image_wrapper">
|
||||
<img id="init_image_preview" src="" />
|
||||
<span id="init_image_size_box"></span>
|
||||
<span id="init_image_size_box" class="img_bottom_label"></span>
|
||||
<button class="init_image_clear image_clear_btn"><i class="fa-solid fa-xmark"></i></button>
|
||||
</div>
|
||||
<div id="init_image_buttons">
|
||||
@ -125,10 +126,9 @@
|
||||
<tr><b class="settings-subheader">Image Settings</b></tr>
|
||||
<tr class="pl-5"><td><label for="seed">Seed:</label></td><td><input id="seed" name="seed" size="10" value="0" onkeypress="preventNonNumericalInput(event)"> <input id="random_seed" name="random_seed" type="checkbox" checked><label for="random_seed">Random</label></td></tr>
|
||||
<tr class="pl-5"><td><label for="num_outputs_total">Number of Images:</label></td><td><input id="num_outputs_total" name="num_outputs_total" value="1" size="1" onkeypress="preventNonNumericalInput(event)"> <label><small>(total)</small></label> <input id="num_outputs_parallel" name="num_outputs_parallel" value="1" size="1" onkeypress="preventNonNumericalInput(event)"> <label for="num_outputs_parallel"><small>(in parallel)</small></label></td></tr>
|
||||
<tr class="pl-5"><td><label for="stable_diffusion_model">Model:</label></td><td>
|
||||
<select id="stable_diffusion_model" name="stable_diffusion_model">
|
||||
<!-- <option value="sd-v1-4" selected>sd-v1-4</option> -->
|
||||
</select>
|
||||
<tr class="pl-5"><td><label for="stable_diffusion_model">Model:</label></td><td class="model-input">
|
||||
<input id="stable_diffusion_model" type="text" spellcheck="false" autocomplete="off" class="model-filter" data-path="" />
|
||||
<button id="reload-models" class="secondaryButton reloadModels"><i class='fa-solid fa-rotate'></i></button>
|
||||
<a href="https://github.com/cmdr2/stable-diffusion-ui/wiki/Custom-Models" target="_blank"><i class="fa-solid fa-circle-question help-btn"><span class="simple-tooltip top-left">Click to learn more about custom models</span></i></a>
|
||||
</td></tr>
|
||||
<!-- <tr id="modelConfigSelection" class="pl-5"><td><label for="model_config">Model Config:</i></label></td><td>
|
||||
@ -136,9 +136,7 @@
|
||||
</select>
|
||||
</td></tr> -->
|
||||
<tr class="pl-5"><td><label for="vae_model">Custom VAE:</i></label></td><td>
|
||||
<select id="vae_model" name="vae_model">
|
||||
<!-- <option value="" selected>None</option> -->
|
||||
</select>
|
||||
<input id="vae_model" type="text" spellcheck="false" autocomplete="off" class="model-filter" data-path="" />
|
||||
<a href="https://github.com/cmdr2/stable-diffusion-ui/wiki/VAE-Variational-Auto-Encoder" target="_blank"><i class="fa-solid fa-circle-question help-btn"><span class="simple-tooltip top-left">Click to learn more about VAEs</span></i></a>
|
||||
</td></tr>
|
||||
<tr id="samplerSelection" class="pl-5"><td><label for="sampler_name">Sampler:</label></td><td>
|
||||
@ -157,6 +155,11 @@
|
||||
<option value="dpmpp_sde">DPM++ SDE</option>
|
||||
<option value="dpm_fast">DPM Fast</option>
|
||||
<option value="dpm_adaptive">DPM Adaptive</option>
|
||||
<option value="unipc_snr">UniPC SNR</option>
|
||||
<option value="unipc_tu">UniPC TU</option>
|
||||
<option value="unipc_snr_2">UniPC SNR 2</option>
|
||||
<option value="unipc_tu_2">UniPC TC 2</option>
|
||||
<option value="unipc_tq">UniPC TQ</option>
|
||||
</select>
|
||||
<a href="https://github.com/cmdr2/stable-diffusion-ui/wiki/How-to-Use#samplers" target="_blank"><i class="fa-solid fa-circle-question help-btn"><span class="simple-tooltip top-left">Click to learn more about samplers</span></i></a>
|
||||
</td></tr>
|
||||
@ -210,9 +213,7 @@
|
||||
<tr class="pl-5"><td><label for="guidance_scale_slider">Guidance Scale:</label></td><td> <input id="guidance_scale_slider" name="guidance_scale_slider" class="editor-slider" value="75" type="range" min="11" max="500"> <input id="guidance_scale" name="guidance_scale" size="4" pattern="^[0-9\.]+$" onkeypress="preventNonNumericalInput(event)"></td></tr>
|
||||
<tr id="prompt_strength_container" class="pl-5"><td><label for="prompt_strength_slider">Prompt Strength:</label></td><td> <input id="prompt_strength_slider" name="prompt_strength_slider" class="editor-slider" value="80" type="range" min="0" max="99"> <input id="prompt_strength" name="prompt_strength" size="4" pattern="^[0-9\.]+$" onkeypress="preventNonNumericalInput(event)"><br/></td></tr>
|
||||
<tr class="pl-5"><td><label for="hypernetwork_model">Hypernetwork:</i></label></td><td>
|
||||
<select id="hypernetwork_model" name="hypernetwork_model">
|
||||
<!-- <option value="" selected>None</option> -->
|
||||
</select>
|
||||
<input id="hypernetwork_model" type="text" spellcheck="false" autocomplete="off" class="model-filter" data-path="" />
|
||||
</td></tr>
|
||||
<tr id="hypernetwork_strength_container" class="pl-5">
|
||||
<td><label for="hypernetwork_strength_slider">Hypernetwork Strength:</label></td>
|
||||
@ -222,9 +223,10 @@
|
||||
<select id="output_format" name="output_format">
|
||||
<option value="jpeg" selected>jpeg</option>
|
||||
<option value="png">png</option>
|
||||
<option value="webp">webp</option>
|
||||
</select>
|
||||
</td></tr>
|
||||
<tr class="pl-5" id="output_quality_row"><td><label for="output_quality">JPEG Quality:</label></td><td>
|
||||
<tr class="pl-5" id="output_quality_row"><td><label for="output_quality">Image Quality:</label></td><td>
|
||||
<input id="output_quality_slider" name="output_quality" class="editor-slider" value="75" type="range" min="10" max="95"> <input id="output_quality" name="output_quality" size="4" pattern="^[0-9\.]+$" onkeypress="preventNonNumericalInput(event)">
|
||||
</td></tr>
|
||||
</table></div>
|
||||
@ -232,7 +234,7 @@
|
||||
<div><ul>
|
||||
<li><b class="settings-subheader">Render Settings</b></li>
|
||||
<li class="pl-5"><input id="stream_image_progress" name="stream_image_progress" type="checkbox"> <label for="stream_image_progress">Show a live preview <small>(uses more VRAM, slower images)</small></label></li>
|
||||
<li class="pl-5"><input id="use_face_correction" name="use_face_correction" type="checkbox"> <label for="use_face_correction">Fix incorrect faces and eyes <small>(uses GFPGAN)</small></label></li>
|
||||
<li class="pl-5"><input id="use_face_correction" name="use_face_correction" type="checkbox"> <label for="use_face_correction">Fix incorrect faces and eyes</label> <div style="display:inline-block;"><input id="gfpgan_model" type="text" spellcheck="false" autocomplete="off" class="model-filter" data-path="" /></div></li>
|
||||
<li class="pl-5">
|
||||
<input id="use_upscale" name="use_upscale" type="checkbox"> <label for="use_upscale">Scale up by</label>
|
||||
<select id="upscale_amount" name="upscale_amount">
|
||||
@ -281,8 +283,34 @@
|
||||
and selecting the desired modifiers.<br/><br/>
|
||||
Click "Image Settings" for additional settings like seed, image size, number of images to generate etc.<br/><br/>Enjoy! :)
|
||||
</div>
|
||||
<div id="preview-tools">
|
||||
<button id="clear-all-previews" class="secondaryButton"><i class="fa-solid fa-trash-can"></i> Clear All</button>
|
||||
|
||||
<div id="preview-content">
|
||||
<div id="preview-tools">
|
||||
<button id="clear-all-previews" class="secondaryButton"><i class="fa-solid fa-trash-can icon"></i> Clear All</button>
|
||||
<button id="save-all-images" class="tertiaryButton"><i class="fa-solid fa-download icon"></i> Download All Images</button>
|
||||
<div class="display-settings">
|
||||
<button id="auto_scroll_btn" class="tertiaryButton">
|
||||
<i class="fa-solid fa-arrows-up-to-line icon"></i>
|
||||
<input id="auto_scroll" name="auto_scroll" type="checkbox" style="display: none">
|
||||
<span class="simple-tooltip left">
|
||||
Scroll to generated image (<span class="state">OFF</span>)
|
||||
</span>
|
||||
</button>
|
||||
<button class="dropdown tertiaryButton">
|
||||
<i class="fa-solid fa-magnifying-glass-plus icon dropbtn"></i>
|
||||
<span class="simple-tooltip left">
|
||||
Image Size
|
||||
</span>
|
||||
</button>
|
||||
<div class="dropdown-content">
|
||||
<div class="dropdown-item">
|
||||
<input id="thumbnail_size" name="thumbnail_size" class="editor-slider" type="range" value="70" min="5" max="200" oninput="sliderUpdate(event)">
|
||||
<input id="thumbnail_size-input" name="thumbnail_size-input" size="3" value="70" pattern="^[0-9.]+$" onkeypress="preventNonNumericalInput(event)" oninput="sliderUpdate(event)"> %
|
||||
</div>
|
||||
</div>
|
||||
</div>
|
||||
<div class="clearfix" style="clear: both;"></div>
|
||||
</div>
|
||||
</div>
|
||||
</div>
|
||||
</div>
|
||||
@ -425,6 +453,7 @@
|
||||
<script src="media/js/image-modifiers.js"></script>
|
||||
<script src="media/js/auto-save.js"></script>
|
||||
|
||||
<script src="media/js/searchable-models.js"></script>
|
||||
<script src="media/js/main.js"></script>
|
||||
<script src="media/js/themes.js"></script>
|
||||
<script src="media/js/dnd.js"></script>
|
||||
|
@ -31,7 +31,7 @@
|
||||
}
|
||||
|
||||
.editor-options-container > * > *.active {
|
||||
border: 2px solid #3584e4;
|
||||
border: 1px solid #3584e4;
|
||||
}
|
||||
|
||||
.image_editor_opacity .editor-options-container > * > *:not(.active) {
|
||||
@ -160,6 +160,7 @@
|
||||
padding: var(--popup-padding);
|
||||
min-height: calc(100vh - (2 * var(--popup-margin)));
|
||||
max-width: none;
|
||||
min-width: fit-content;
|
||||
}
|
||||
|
||||
.image-editor-popup h1 {
|
||||
|
@ -123,6 +123,9 @@ code {
|
||||
.imgPreviewItemClearBtn {
|
||||
opacity: 0;
|
||||
}
|
||||
.imgContainer .img_bottom_label {
|
||||
opacity: 0;
|
||||
}
|
||||
.imgPreviewItemClearBtn:hover {
|
||||
background: rgb(177, 27, 0);
|
||||
}
|
||||
@ -132,6 +135,9 @@ code {
|
||||
.imgContainer:hover > .imgPreviewItemClearBtn {
|
||||
opacity: 1;
|
||||
}
|
||||
.imgContainer:hover > .img_bottom_label {
|
||||
opacity: 60%;
|
||||
}
|
||||
.imgItemInfo * {
|
||||
margin-bottom: 7px;
|
||||
}
|
||||
@ -193,7 +199,7 @@ code {
|
||||
flex: 0 0 70px;
|
||||
background: var(--accent-color);
|
||||
border: var(--primary-button-border);
|
||||
color: rgb(255, 221, 255);
|
||||
color: var(--accent-text-color);
|
||||
width: 100%;
|
||||
height: 30pt;
|
||||
}
|
||||
@ -402,10 +408,8 @@ div.img-preview img {
|
||||
display: none;
|
||||
position: absolute;
|
||||
z-index: 2;
|
||||
width: max-content;
|
||||
|
||||
background: var(--background-color4);
|
||||
border: 2px solid var(--background-color2);
|
||||
border-radius: 7px;
|
||||
padding: 5px;
|
||||
margin-bottom: 15px;
|
||||
box-shadow: 0 20px 28px 0 rgba(0, 0, 0, 0.15), 0 6px 20px 0 rgba(0, 0, 0, 0.15);
|
||||
@ -413,6 +417,36 @@ div.img-preview img {
|
||||
.dropdown:hover .dropdown-content {
|
||||
display: block;
|
||||
}
|
||||
.dropdown:hover + .dropdown-content {
|
||||
display: block;
|
||||
}
|
||||
.dropdown-content:hover {
|
||||
display: block;
|
||||
}
|
||||
|
||||
.display-settings {
|
||||
float: right;
|
||||
position: relative;
|
||||
}
|
||||
|
||||
.display-settings .dropdown-content {
|
||||
right: 0px;
|
||||
top: 12pt;
|
||||
}
|
||||
|
||||
.dropdown-item {
|
||||
padding: 4px;
|
||||
background: var(--background-color4);
|
||||
border: 2px solid var(--background-color2);
|
||||
}
|
||||
|
||||
.dropdown-item:first-child {
|
||||
border-radius: 7px 7px 0px 0px;
|
||||
}
|
||||
|
||||
.dropdown-item:last-child {
|
||||
border-radius: 0px 0px 7px 7px;
|
||||
}
|
||||
|
||||
.imageTaskContainer {
|
||||
border: 1px solid var(--background-color2);
|
||||
@ -468,6 +502,7 @@ div.img-preview img {
|
||||
background: var(--accent-color);
|
||||
border: var(--primary-button-border);
|
||||
color: rgb(255, 221, 255);
|
||||
padding: 3pt 6pt;
|
||||
}
|
||||
.secondaryButton {
|
||||
background: rgb(132, 8, 0);
|
||||
@ -479,17 +514,26 @@ div.img-preview img {
|
||||
.secondaryButton:hover {
|
||||
background: rgb(177, 27, 0);
|
||||
}
|
||||
.useSettings {
|
||||
background: var(--accent-color);
|
||||
border: 1px solid var(--accent-color);
|
||||
color: rgb(255, 221, 255);
|
||||
.tertiaryButton {
|
||||
background: var(--tertiary-background-color);
|
||||
color: var(--tertiary-color);
|
||||
border: 1px solid var(--tertiary-border-color);
|
||||
padding: 3pt 6pt;
|
||||
border-radius: 5px;
|
||||
}
|
||||
.tertiaryButton:hover {
|
||||
background: hsl(var(--accent-hue), 100%, calc(var(--accent-lightness) + 6%));
|
||||
color: var(--accent-text-color);
|
||||
}
|
||||
.tertiaryButton.pressed {
|
||||
border-style: inset;
|
||||
background: hsl(var(--accent-hue), 100%, calc(var(--accent-lightness) + 6%));
|
||||
color: var(--accent-text-color);
|
||||
}
|
||||
.useSettings {
|
||||
margin-right: 6pt;
|
||||
float: right;
|
||||
}
|
||||
.useSettings:hover {
|
||||
background: hsl(var(--accent-hue), 100%, calc(var(--accent-lightness) + 6%));
|
||||
}
|
||||
.stopTask {
|
||||
float: right;
|
||||
}
|
||||
@ -577,6 +621,9 @@ div.img-preview img {
|
||||
} */
|
||||
|
||||
#init_image_size_box {
|
||||
border-radius: 6px 0px;
|
||||
}
|
||||
.img_bottom_label {
|
||||
position: absolute;
|
||||
right: 0px;
|
||||
bottom: 0px;
|
||||
@ -586,7 +633,6 @@ div.img-preview img {
|
||||
text-shadow: 0px 0px 4px black;
|
||||
opacity: 60%;
|
||||
font-size: 12px;
|
||||
border-radius: 6px 0px;
|
||||
}
|
||||
|
||||
#editor-settings {
|
||||
@ -603,7 +649,6 @@ div.img-preview img {
|
||||
}
|
||||
|
||||
#editor-settings-entries ul {
|
||||
margin: 0px;
|
||||
padding: 0px;
|
||||
}
|
||||
|
||||
@ -750,6 +795,13 @@ input::file-selector-button {
|
||||
right: calc(var(--input-border-size) + var(--input-switch-padding));
|
||||
opacity: 1;
|
||||
}
|
||||
.model-filter {
|
||||
width: 90%;
|
||||
padding-right: 20px;
|
||||
white-space: nowrap;
|
||||
overflow: hidden;
|
||||
text-overflow: ellipsis;
|
||||
}
|
||||
|
||||
/* Small screens */
|
||||
@media screen and (max-width: 1265px) {
|
||||
@ -787,12 +839,6 @@ input::file-selector-button {
|
||||
width: 100%;
|
||||
object-fit: contain;
|
||||
}
|
||||
.dropdown-content {
|
||||
width: auto !important;
|
||||
transform: none !important;
|
||||
left: 0px;
|
||||
right: 0px;
|
||||
}
|
||||
#editor {
|
||||
padding: 16px 8px;
|
||||
}
|
||||
@ -825,6 +871,12 @@ input::file-selector-button {
|
||||
.simple-tooltip {
|
||||
display: none;
|
||||
}
|
||||
#preview-tools button {
|
||||
font-size: 0px;
|
||||
}
|
||||
#preview-tools button .icon {
|
||||
font-size: 12pt;
|
||||
}
|
||||
}
|
||||
|
||||
@media screen and (max-width: 500px) {
|
||||
@ -857,7 +909,7 @@ input::file-selector-button {
|
||||
#promptsFromFileBtn {
|
||||
font-size: 9pt;
|
||||
display: inline;
|
||||
background-color: var(--accent-color);
|
||||
padding: 2pt;
|
||||
}
|
||||
|
||||
.section-button {
|
||||
@ -890,18 +942,19 @@ input::file-selector-button {
|
||||
|
||||
/* SIMPLE TOOTIP */
|
||||
.simple-tooltip {
|
||||
border-radius: 3px;
|
||||
font-weight: bold;
|
||||
font-size: 12px;
|
||||
border-radius: 3px;
|
||||
font-weight: bold;
|
||||
font-size: 12px;
|
||||
background-color: var(--background-color3);
|
||||
|
||||
visibility: hidden;
|
||||
opacity: 0;
|
||||
position: absolute;
|
||||
width: max-content;
|
||||
max-width: 300px;
|
||||
padding: 8px 12px;
|
||||
transition: 0.3s all;
|
||||
opacity: 0;
|
||||
position: absolute;
|
||||
width: max-content;
|
||||
max-width: 300px;
|
||||
padding: 8px 12px;
|
||||
transition: 0.3s all;
|
||||
z-index: 1000;
|
||||
|
||||
pointer-events: none;
|
||||
}
|
||||
@ -1203,3 +1256,7 @@ body.wait-pause {
|
||||
.jconfirm.jconfirm-modern .jconfirm-box {
|
||||
background-color: var(--background-color1);
|
||||
}
|
||||
|
||||
.displayNone {
|
||||
display:none !important;
|
||||
}
|
||||
|
99
ui/media/css/searchable-models.css
Normal file
99
ui/media/css/searchable-models.css
Normal file
@ -0,0 +1,99 @@
|
||||
.model-list {
|
||||
position: absolute;
|
||||
margin-block-start: 2px;
|
||||
display: none;
|
||||
padding-inline-start: 0;
|
||||
max-height: 200px;
|
||||
overflow: auto;
|
||||
background: var(--input-background-color);
|
||||
border: var(--input-border-size) solid var(--input-border-color);
|
||||
border-radius: var(--input-border-radius);
|
||||
color: var(--input-text-color);
|
||||
z-index: 1;
|
||||
line-height: normal;
|
||||
}
|
||||
|
||||
.model-list ul {
|
||||
padding-right: 20px;
|
||||
padding-inline-start: 0;
|
||||
margin-top: 3pt;
|
||||
}
|
||||
|
||||
.model-list li {
|
||||
padding-top: 3px;
|
||||
padding-bottom: 3px;
|
||||
}
|
||||
|
||||
.model-list .icon {
|
||||
padding-right: 3pt;
|
||||
}
|
||||
|
||||
.model-result {
|
||||
list-style: none;
|
||||
}
|
||||
|
||||
.model-no-result {
|
||||
color: var(--text-color);
|
||||
list-style: none;
|
||||
padding: 3px 6px 3px 6px;
|
||||
font-size: 9pt;
|
||||
font-style: italic;
|
||||
display: none;
|
||||
}
|
||||
|
||||
.model-list li.model-folder {
|
||||
color: var(--text-color);
|
||||
list-style: none;
|
||||
padding: 6px 6px 6px 6px;
|
||||
font-size: 9pt;
|
||||
font-weight: bold;
|
||||
border-top: 1px solid var(--background-color1);
|
||||
}
|
||||
|
||||
.model-list li.model-file {
|
||||
color: var(--input-text-color);
|
||||
list-style: none;
|
||||
padding-left: 12px;
|
||||
padding-right:20px;
|
||||
font-size: 10pt;
|
||||
font-weight: normal;
|
||||
transition: none;
|
||||
transition:property: none;
|
||||
cursor: default;
|
||||
}
|
||||
|
||||
.model-list li.model-file.in-root-folder {
|
||||
padding-left: 6px;
|
||||
}
|
||||
|
||||
.model-list li.model-file.selected {
|
||||
background: grey;
|
||||
}
|
||||
|
||||
.model-selector {
|
||||
cursor: pointer;
|
||||
}
|
||||
|
||||
.model-selector-arrow {
|
||||
position: absolute;
|
||||
width: 17px;
|
||||
margin: 5px -17px;
|
||||
padding-top: 3px;
|
||||
cursor: pointer;
|
||||
font-size: 8pt;
|
||||
transition: none;
|
||||
}
|
||||
|
||||
.model-input {
|
||||
white-space: nowrap;
|
||||
}
|
||||
|
||||
.reloadModels {
|
||||
background: var(--background-color2);
|
||||
border: none;
|
||||
padding: 0px 0px;
|
||||
}
|
||||
|
||||
#reload-models.secondaryButton:hover {
|
||||
background: var(--background-color2);
|
||||
}
|
@ -27,9 +27,13 @@
|
||||
--input-border-size: 1px;
|
||||
--accent-color: hsl(var(--accent-hue), 100%, var(--accent-lightness));
|
||||
--accent-color-hover: hsl(var(--accent-hue), 100%, var(--accent-lightness-hover));
|
||||
--accent-text-color: rgb(255, 221, 255);
|
||||
--primary-button-border: none;
|
||||
--input-switch-padding: 1px;
|
||||
--input-height: 18px;
|
||||
--tertiary-background-color: hsl(var(--main-hue), var(--main-saturation), calc(var(--value-base) + (2 * var(--value-step))));
|
||||
--tertiary-border-color: hsl(var(--main-hue), var(--main-saturation), calc(var(--value-base) + (3 * var(--value-step))));
|
||||
--tertiary-color: var(--input-text-color)
|
||||
|
||||
/* Main theme color, hex color fallback. */
|
||||
--theme-color-fallback: #673AB6;
|
||||
@ -48,6 +52,11 @@
|
||||
--input-border-color: grey;
|
||||
|
||||
--theme-color-fallback: #aaaaaa;
|
||||
|
||||
--tertiary-background-color: hsl(var(--main-hue), var(--main-saturation), calc(var(--value-base) + (16.8 * var(--value-step))));
|
||||
--tertiary-border-color: hsl(var(--main-hue), var(--main-saturation), calc(var(--value-base) + (12 * var(--value-step))));
|
||||
|
||||
--accent-text-color: white;
|
||||
}
|
||||
|
||||
.theme-discord {
|
||||
@ -64,6 +73,10 @@
|
||||
--input-border-color: var(--input-background-color);
|
||||
|
||||
--theme-color-fallback: #202225;
|
||||
|
||||
--tertiary-background-color: hsl(var(--main-hue), var(--main-saturation), calc(var(--value-base) + (3.5 * var(--value-step))));
|
||||
--tertiary-border-color: hsl(var(--main-hue), var(--main-saturation), calc(var(--value-base) + (4.5 * var(--value-step))));
|
||||
--accent-text-color: white;
|
||||
}
|
||||
|
||||
.theme-cool-blue {
|
||||
@ -81,6 +94,10 @@
|
||||
--accent-hue: 212;
|
||||
|
||||
--theme-color-fallback: #0056b8;
|
||||
|
||||
--tertiary-background-color: hsl(var(--main-hue), var(--main-saturation), calc(var(--value-base) + (3.5 * var(--value-step))));
|
||||
--tertiary-border-color: hsl(var(--main-hue), var(--main-saturation), calc(var(--value-base) + (4.5 * var(--value-step))));
|
||||
--accent-text-color: #f7fbff;
|
||||
}
|
||||
|
||||
|
||||
@ -97,6 +114,9 @@
|
||||
--input-background-color: var(--background-color3);
|
||||
|
||||
--theme-color-fallback: #5300b8;
|
||||
|
||||
--tertiary-background-color: hsl(var(--main-hue), var(--main-saturation), calc(var(--value-base) + (3.5 * var(--value-step))));
|
||||
--tertiary-border-color: hsl(var(--main-hue), var(--main-saturation), calc(var(--value-base) + (4.5 * var(--value-step))));
|
||||
}
|
||||
|
||||
.theme-super-dark {
|
||||
@ -131,6 +151,9 @@
|
||||
--input-background-color: hsl(222, var(--main-saturation), calc(var(--value-base) - (2 * var(--value-step))));
|
||||
--input-text-color: #FF0000;
|
||||
--input-border-color: #005E05;
|
||||
|
||||
--tertiary-color: white;
|
||||
--accent-text-color: #f7fbff;
|
||||
}
|
||||
|
||||
|
||||
|
@ -27,8 +27,10 @@ const SETTINGS_IDS_LIST = [
|
||||
"negative_prompt",
|
||||
"stream_image_progress",
|
||||
"use_face_correction",
|
||||
"gfpgan_model",
|
||||
"use_upscale",
|
||||
"upscale_amount",
|
||||
"block_nsfw",
|
||||
"show_only_filtered_image",
|
||||
"upscale_model",
|
||||
"preview-image",
|
||||
@ -42,7 +44,9 @@ const SETTINGS_IDS_LIST = [
|
||||
"metadata_output_format",
|
||||
"auto_save_settings",
|
||||
"apply_color_correction",
|
||||
"process_order_toggle"
|
||||
"process_order_toggle",
|
||||
"thumbnail_size",
|
||||
"auto_scroll"
|
||||
]
|
||||
|
||||
const IGNORE_BY_DEFAULT = [
|
||||
@ -92,6 +96,9 @@ async function initSettings() {
|
||||
}
|
||||
|
||||
function getSetting(element) {
|
||||
if (element.dataset && 'path' in element.dataset) {
|
||||
return element.dataset.path
|
||||
}
|
||||
if (typeof element === "string" || element instanceof String) {
|
||||
element = SETTINGS[element].element
|
||||
}
|
||||
@ -101,6 +108,10 @@ function getSetting(element) {
|
||||
return element.value
|
||||
}
|
||||
function setSetting(element, value) {
|
||||
if (element.dataset && 'path' in element.dataset) {
|
||||
element.dataset.path = value
|
||||
return // no need to dispatch any event here because the models are not loaded yet
|
||||
}
|
||||
if (typeof element === "string" || element instanceof String) {
|
||||
element = SETTINGS[element].element
|
||||
}
|
||||
|
@ -2,7 +2,7 @@
|
||||
|
||||
const EXT_REGEX = /(?:\.([^.]+))?$/
|
||||
const TEXT_EXTENSIONS = ['txt', 'json']
|
||||
const IMAGE_EXTENSIONS = ['jpg', 'jpeg', 'png', 'bmp', 'tiff', 'tif', 'tga']
|
||||
const IMAGE_EXTENSIONS = ['jpg', 'jpeg', 'png', 'bmp', 'tiff', 'tif', 'tga', 'webp']
|
||||
|
||||
function parseBoolean(stringValue) {
|
||||
if (typeof stringValue === 'boolean') {
|
||||
@ -154,10 +154,21 @@ const TASK_MAPPING = {
|
||||
|
||||
use_face_correction: { name: 'Use Face Correction',
|
||||
setUI: (use_face_correction) => {
|
||||
useFaceCorrectionField.checked = parseBoolean(use_face_correction)
|
||||
const oldVal = gfpganModelField.value
|
||||
gfpganModelField.value = getModelPath(use_face_correction, ['.pth'])
|
||||
if (gfpganModelField.value) { // Is a valid value for the field.
|
||||
useFaceCorrectionField.checked = true
|
||||
gfpganModelField.disabled = false
|
||||
} else { // Not a valid value, restore the old value and disable the filter.
|
||||
gfpganModelField.disabled = true
|
||||
gfpganModelField.value = oldVal
|
||||
useFaceCorrectionField.checked = false
|
||||
}
|
||||
|
||||
//useFaceCorrectionField.checked = parseBoolean(use_face_correction)
|
||||
},
|
||||
readUI: () => useFaceCorrectionField.checked,
|
||||
parse: (val) => parseBoolean(val)
|
||||
readUI: () => (useFaceCorrectionField.checked ? gfpganModelField.value : undefined),
|
||||
parse: (val) => val
|
||||
},
|
||||
use_upscale: { name: 'Use Upscaling',
|
||||
setUI: (use_upscale) => {
|
||||
@ -324,6 +335,7 @@ function restoreTaskToUI(task, fieldsToSkip) {
|
||||
// properly reset checkboxes
|
||||
if (!('use_face_correction' in task.reqBody)) {
|
||||
useFaceCorrectionField.checked = false
|
||||
gfpganModelField.disabled = true
|
||||
}
|
||||
if (!('use_upscale' in task.reqBody)) {
|
||||
useUpscalingField.checked = false
|
||||
@ -345,6 +357,7 @@ function restoreTaskToUI(task, fieldsToSkip) {
|
||||
initImagePreview.addEventListener('load', function() {
|
||||
if (Boolean(task.reqBody.mask)) {
|
||||
imageInpainter.setImg(task.reqBody.mask)
|
||||
maskSetting.checked = true
|
||||
}
|
||||
}, { once: true })
|
||||
initImagePreview.src = task.reqBody.init_image
|
||||
@ -363,12 +376,19 @@ function readUI() {
|
||||
}
|
||||
function getModelPath(filename, extensions)
|
||||
{
|
||||
let pathIdx = filename.lastIndexOf('/') // Linux, Mac paths
|
||||
if (pathIdx < 0) {
|
||||
pathIdx = filename.lastIndexOf('\\') // Windows paths.
|
||||
if (typeof filename !== "string") {
|
||||
return
|
||||
}
|
||||
|
||||
let pathIdx
|
||||
if (filename.includes('/models/stable-diffusion/')) {
|
||||
pathIdx = filename.indexOf('/models/stable-diffusion/') + 25 // Linux, Mac paths
|
||||
}
|
||||
else if (filename.includes('\\models\\stable-diffusion\\')) {
|
||||
pathIdx = filename.indexOf('\\models\\stable-diffusion\\') + 25 // Linux, Mac paths
|
||||
}
|
||||
if (pathIdx >= 0) {
|
||||
filename = filename.slice(pathIdx + 1)
|
||||
filename = filename.slice(pathIdx)
|
||||
}
|
||||
extensions.forEach(ext => {
|
||||
if (filename.endsWith(ext)) {
|
||||
@ -513,7 +533,7 @@ function dragOverHandler(ev) {
|
||||
ev.dataTransfer.dropEffect = "copy"
|
||||
|
||||
let img = new Image()
|
||||
img.src = location.host + '/media/images/favicon-32x32.png'
|
||||
img.src = '//' + location.host + '/media/images/favicon-32x32.png'
|
||||
ev.dataTransfer.setDragImage(img, 16, 16)
|
||||
}
|
||||
|
||||
|
@ -741,6 +741,7 @@
|
||||
"stream_progress_updates": true,
|
||||
"stream_image_progress": true,
|
||||
"show_only_filtered_image": true,
|
||||
"block_nsfw": false,
|
||||
"output_format": "png",
|
||||
"output_quality": 75,
|
||||
}
|
||||
|
@ -244,8 +244,8 @@ var IMAGE_EDITOR_SECTIONS = [
|
||||
var sub_element = document.createElement("div")
|
||||
sub_element.style.background = `var(--background-color3)`
|
||||
sub_element.style.filter = `blur(${blur_amount}px)`
|
||||
sub_element.style.width = `${size - 4}px`
|
||||
sub_element.style.height = `${size - 4}px`
|
||||
sub_element.style.width = `${size - 2}px`
|
||||
sub_element.style.height = `${size - 2}px`
|
||||
sub_element.style['border-radius'] = `${size}px`
|
||||
element.style.background = "none"
|
||||
element.appendChild(sub_element)
|
||||
|
@ -16,7 +16,7 @@ const modifierThumbnailPath = 'media/modifier-thumbnails'
|
||||
const activeCardClass = 'modifier-card-active'
|
||||
const CUSTOM_MODIFIERS_KEY = "customModifiers"
|
||||
|
||||
function createModifierCard(name, previews) {
|
||||
function createModifierCard(name, previews, removeBy) {
|
||||
const modifierCard = document.createElement('div')
|
||||
modifierCard.className = 'modifier-card'
|
||||
modifierCard.innerHTML = `
|
||||
@ -44,10 +44,10 @@ function createModifierCard(name, previews) {
|
||||
}
|
||||
|
||||
const maxLabelLength = 30
|
||||
const nameWithoutBy = name.replace('by ', '')
|
||||
const cardLabel = removeBy ? name.replace('by ', '') : name
|
||||
|
||||
if(nameWithoutBy.length <= maxLabelLength) {
|
||||
label.querySelector('p').innerText = nameWithoutBy
|
||||
if(cardLabel.length <= maxLabelLength) {
|
||||
label.querySelector('p').innerText = cardLabel
|
||||
} else {
|
||||
const tooltipText = document.createElement('span')
|
||||
tooltipText.className = 'tooltip-text'
|
||||
@ -56,13 +56,14 @@ function createModifierCard(name, previews) {
|
||||
label.classList.add('tooltip')
|
||||
label.appendChild(tooltipText)
|
||||
|
||||
label.querySelector('p').innerText = nameWithoutBy.substring(0, maxLabelLength) + '...'
|
||||
label.querySelector('p').innerText = cardLabel.substring(0, maxLabelLength) + '...'
|
||||
}
|
||||
label.querySelector('p').dataset.fullName = name // preserve the full name
|
||||
|
||||
return modifierCard
|
||||
}
|
||||
|
||||
function createModifierGroup(modifierGroup, initiallyExpanded) {
|
||||
function createModifierGroup(modifierGroup, initiallyExpanded, removeBy) {
|
||||
const title = modifierGroup.category
|
||||
const modifiers = modifierGroup.modifiers
|
||||
|
||||
@ -79,9 +80,9 @@ function createModifierGroup(modifierGroup, initiallyExpanded) {
|
||||
|
||||
modifiers.forEach(modObj => {
|
||||
const modifierName = modObj.modifier
|
||||
const modifierPreviews = modObj?.previews?.map(preview => `${modifierThumbnailPath}/${preview.path}`)
|
||||
const modifierPreviews = modObj?.previews?.map(preview => `${IMAGE_REGEX.test(preview.image) ? preview.image : modifierThumbnailPath + '/' + preview.path}`)
|
||||
|
||||
const modifierCard = createModifierCard(modifierName, modifierPreviews)
|
||||
const modifierCard = createModifierCard(modifierName, modifierPreviews, removeBy)
|
||||
|
||||
if(typeof modifierCard == 'object') {
|
||||
modifiersEl.appendChild(modifierCard)
|
||||
@ -114,6 +115,7 @@ function createModifierGroup(modifierGroup, initiallyExpanded) {
|
||||
modifiersEl.appendChild(brk)
|
||||
|
||||
let e = document.createElement('div')
|
||||
e.className = 'modifier-category'
|
||||
e.appendChild(titleEl)
|
||||
e.appendChild(modifiersEl)
|
||||
|
||||
@ -137,7 +139,7 @@ async function loadModifiers() {
|
||||
res.reverse()
|
||||
|
||||
res.forEach((modifierGroup, idx) => {
|
||||
createModifierGroup(modifierGroup, idx === res.length - 1)
|
||||
createModifierGroup(modifierGroup, idx === res.length - 1, modifierGroup === 'Artist' ? true : false) // only remove "By " for artists
|
||||
})
|
||||
|
||||
createCollapsibles(editorModifierEntries)
|
||||
@ -153,7 +155,7 @@ async function loadModifiers() {
|
||||
function refreshModifiersState(newTags) {
|
||||
// clear existing modifiers
|
||||
document.querySelector('#editor-modifiers').querySelectorAll('.modifier-card').forEach(modifierCard => {
|
||||
const modifierName = modifierCard.querySelector('.modifier-card-label').innerText
|
||||
const modifierName = modifierCard.querySelector('.modifier-card-label p').dataset.fullName // pick the full modifier name
|
||||
if (activeTags.map(x => x.name).includes(modifierName)) {
|
||||
modifierCard.classList.remove(activeCardClass)
|
||||
modifierCard.querySelector('.modifier-card-image-overlay').innerText = '+'
|
||||
@ -165,13 +167,16 @@ function refreshModifiersState(newTags) {
|
||||
newTags.forEach(tag => {
|
||||
let found = false
|
||||
document.querySelector('#editor-modifiers').querySelectorAll('.modifier-card').forEach(modifierCard => {
|
||||
const modifierName = modifierCard.querySelector('.modifier-card-label').innerText
|
||||
if (tag == modifierName) {
|
||||
const modifierName = modifierCard.querySelector('.modifier-card-label p').dataset.fullName
|
||||
const shortModifierName = modifierCard.querySelector('.modifier-card-label p').innerText
|
||||
if (trimModifiers(tag) == trimModifiers(modifierName)) {
|
||||
// add modifier to active array
|
||||
if (!activeTags.map(x => x.name).includes(tag)) { // only add each tag once even if several custom modifier cards share the same tag
|
||||
const imageModifierCard = modifierCard.cloneNode(true)
|
||||
imageModifierCard.querySelector('.modifier-card-label p').innerText = shortModifierName
|
||||
activeTags.push({
|
||||
'name': modifierName,
|
||||
'element': modifierCard.cloneNode(true),
|
||||
'element': imageModifierCard,
|
||||
'originElement': modifierCard
|
||||
})
|
||||
}
|
||||
@ -181,7 +186,7 @@ function refreshModifiersState(newTags) {
|
||||
}
|
||||
})
|
||||
if (found == false) { // custom tag went missing, create one here
|
||||
let modifierCard = createModifierCard(tag, undefined) // create a modifier card for the missing tag, no image
|
||||
let modifierCard = createModifierCard(tag, undefined, false) // create a modifier card for the missing tag, no image
|
||||
|
||||
modifierCard.addEventListener('click', () => {
|
||||
if (activeTags.map(x => x.name).includes(tag)) {
|
||||
|
@ -33,18 +33,23 @@ let promptStrengthField = document.querySelector('#prompt_strength')
|
||||
let samplerField = document.querySelector('#sampler_name')
|
||||
let samplerSelectionContainer = document.querySelector("#samplerSelection")
|
||||
let useFaceCorrectionField = document.querySelector("#use_face_correction")
|
||||
let gfpganModelField = new ModelDropdown(document.querySelector("#gfpgan_model"), 'gfpgan')
|
||||
let useUpscalingField = document.querySelector("#use_upscale")
|
||||
let upscaleModelField = document.querySelector("#upscale_model")
|
||||
let upscaleAmountField = document.querySelector("#upscale_amount")
|
||||
let stableDiffusionModelField = document.querySelector('#stable_diffusion_model')
|
||||
let vaeModelField = document.querySelector('#vae_model')
|
||||
let hypernetworkModelField = document.querySelector('#hypernetwork_model')
|
||||
let stableDiffusionModelField = new ModelDropdown(document.querySelector('#stable_diffusion_model'), 'stable-diffusion')
|
||||
let vaeModelField = new ModelDropdown(document.querySelector('#vae_model'), 'vae', 'None')
|
||||
let hypernetworkModelField = new ModelDropdown(document.querySelector('#hypernetwork_model'), 'hypernetwork', 'None')
|
||||
let hypernetworkStrengthSlider = document.querySelector('#hypernetwork_strength_slider')
|
||||
let hypernetworkStrengthField = document.querySelector('#hypernetwork_strength')
|
||||
let outputFormatField = document.querySelector('#output_format')
|
||||
let blockNSFWField = document.querySelector('#block_nsfw')
|
||||
let showOnlyFilteredImageField = document.querySelector("#show_only_filtered_image")
|
||||
let updateBranchLabel = document.querySelector("#updateBranchLabel")
|
||||
let streamImageProgressField = document.querySelector("#stream_image_progress")
|
||||
let thumbnailSizeField = document.querySelector("#thumbnail_size-input")
|
||||
let autoscrollBtn = document.querySelector("#auto_scroll_btn")
|
||||
let autoScroll = document.querySelector("#auto_scroll")
|
||||
|
||||
let makeImageBtn = document.querySelector('#makeImage')
|
||||
let stopImageBtn = document.querySelector('#stopImage')
|
||||
@ -60,12 +65,14 @@ let promptStrengthContainer = document.querySelector('#prompt_strength_container
|
||||
let initialText = document.querySelector("#initial-text")
|
||||
let previewTools = document.querySelector("#preview-tools")
|
||||
let clearAllPreviewsBtn = document.querySelector("#clear-all-previews")
|
||||
let saveAllImagesBtn = document.querySelector("#save-all-images")
|
||||
|
||||
let maskSetting = document.querySelector('#enable_mask')
|
||||
|
||||
const processOrder = document.querySelector('#process_order_toggle')
|
||||
|
||||
let imagePreview = document.querySelector("#preview")
|
||||
let imagePreviewContent = document.querySelector("#preview-content")
|
||||
imagePreview.addEventListener('drop', function(ev) {
|
||||
const data = ev.dataTransfer?.getData("text/plain");
|
||||
if (!data) {
|
||||
@ -77,7 +84,7 @@ imagePreview.addEventListener('drop', function(ev) {
|
||||
}
|
||||
ev.preventDefault()
|
||||
let moveTarget = ev.target
|
||||
while (moveTarget && typeof moveTarget === 'object' && moveTarget.parentNode !== imagePreview) {
|
||||
while (moveTarget && typeof moveTarget === 'object' && moveTarget.parentNode !== imagePreviewContent) {
|
||||
moveTarget = moveTarget.parentNode
|
||||
}
|
||||
if (moveTarget === initialText || moveTarget === previewTools) {
|
||||
@ -87,17 +94,17 @@ imagePreview.addEventListener('drop', function(ev) {
|
||||
return
|
||||
}
|
||||
if (moveTarget) {
|
||||
const childs = Array.from(imagePreview.children)
|
||||
const childs = Array.from(imagePreviewContent.children)
|
||||
if (moveTarget.nextSibling && childs.indexOf(movedTask) < childs.indexOf(moveTarget)) {
|
||||
// Move after the target if lower than current position.
|
||||
moveTarget = moveTarget.nextSibling
|
||||
}
|
||||
}
|
||||
const newNode = imagePreview.insertBefore(movedTask, moveTarget || previewTools.nextSibling)
|
||||
const newNode = imagePreviewContent.insertBefore(movedTask, moveTarget || previewTools.nextSibling)
|
||||
if (newNode === movedTask) {
|
||||
return
|
||||
}
|
||||
imagePreview.removeChild(movedTask)
|
||||
imagePreviewContent.removeChild(movedTask)
|
||||
const task = htmlTaskMap.get(movedTask)
|
||||
if (task) {
|
||||
htmlTaskMap.delete(movedTask)
|
||||
@ -264,9 +271,26 @@ function showImages(reqBody, res, outputContainer, livePreview) {
|
||||
<span class="imgSeedLabel"></span>
|
||||
</div>
|
||||
<button class="imgPreviewItemClearBtn image_clear_btn"><i class="fa-solid fa-xmark"></i></button>
|
||||
<span class="img_bottom_label"></span>
|
||||
</div>
|
||||
`
|
||||
outputContainer.appendChild(imageItemElem)
|
||||
const imageRemoveBtn = imageItemElem.querySelector('.imgPreviewItemClearBtn')
|
||||
let parentTaskContainer = imageRemoveBtn.closest('.imageTaskContainer')
|
||||
imageRemoveBtn.addEventListener('click', (e) => {
|
||||
shiftOrConfirm(e, "Remove the image from the results?", () => {
|
||||
imageItemElem.style.display = 'none'
|
||||
let allHidden = true;
|
||||
let children = parentTaskContainer.querySelectorAll('.imgItem');
|
||||
for(let x = 0; x < children.length; x++) {
|
||||
let child = children[x];
|
||||
if(child.style.display != "none") {
|
||||
allHidden = false;
|
||||
}
|
||||
}
|
||||
if(allHidden === true) {parentTaskContainer.classList.add("displayNone")}
|
||||
})
|
||||
})
|
||||
}
|
||||
const imageElem = imageItemElem.querySelector('img')
|
||||
imageElem.src = imageData
|
||||
@ -276,12 +300,11 @@ function showImages(reqBody, res, outputContainer, livePreview) {
|
||||
imageElem.setAttribute('data-steps', imageInferenceSteps)
|
||||
imageElem.setAttribute('data-guidance', imageGuidanceScale)
|
||||
|
||||
const imageRemoveBtn = imageItemElem.querySelector('.imgPreviewItemClearBtn')
|
||||
imageRemoveBtn.addEventListener('click', (e) => {
|
||||
console.log(e)
|
||||
shiftOrConfirm(e, "Remove the image from the results?", () => { imageItemElem.style.display = 'none' })
|
||||
imageElem.addEventListener('load', function() {
|
||||
imageItemElem.querySelector('.img_bottom_label').innerText = `${this.naturalWidth} x ${this.naturalHeight}`
|
||||
})
|
||||
|
||||
|
||||
const imageInfo = imageItemElem.querySelector('.imgItemInfo')
|
||||
imageInfo.style.visibility = (livePreview ? 'hidden' : 'visible')
|
||||
|
||||
@ -413,7 +436,7 @@ function onUpscaleClick(req, img) {
|
||||
|
||||
function onFixFacesClick(req, img) {
|
||||
enqueueImageVariationTask(req, img, {
|
||||
use_face_correction: 'GFPGANv1.3'
|
||||
use_face_correction: gfpganModelField.value
|
||||
})
|
||||
}
|
||||
|
||||
@ -706,12 +729,18 @@ async function onTaskStart(task) {
|
||||
if (task.batchCount > 1) {
|
||||
// Each output render batch needs it's own task reqBody instance to avoid altering the other runs after they are completed.
|
||||
newTaskReqBody = Object.assign({}, task.reqBody)
|
||||
if (task.batchesDone == task.batchCount-1) {
|
||||
// Last batch of the task
|
||||
// If the number of parallel jobs is no factor of the total number of images, the last batch must create less than "parallel jobs count" images
|
||||
// E.g. with numOutputsTotal = 6 and num_outputs = 5, the last batch shall only generate 1 image.
|
||||
newTaskReqBody.num_outputs = task.numOutputsTotal - task.reqBody.num_outputs * (task.batchCount-1)
|
||||
}
|
||||
}
|
||||
|
||||
const startSeed = task.seed || newTaskReqBody.seed
|
||||
const genSeeds = Boolean(typeof newTaskReqBody.seed !== 'number' || (newTaskReqBody.seed === task.seed && task.numOutputsTotal > 1))
|
||||
if (genSeeds) {
|
||||
newTaskReqBody.seed = parseInt(startSeed) + (task.batchesDone * newTaskReqBody.num_outputs)
|
||||
newTaskReqBody.seed = parseInt(startSeed) + (task.batchesDone * task.reqBody.num_outputs)
|
||||
}
|
||||
|
||||
// Update the seed *before* starting the processing so it's retained if user stops the task
|
||||
@ -774,7 +803,10 @@ function createInitImageHover(taskEntry) {
|
||||
img.src = taskEntry.querySelector('div.task-initimg > img').src
|
||||
$tooltip.append(img)
|
||||
$tooltip.append(`<div class="top-right"><button>Use as Input</button></div>`)
|
||||
$tooltip.find('button').on('click', (e) => { onUseAsInputClick(null,img) } )
|
||||
$tooltip.find('button').on('click', (e) => {
|
||||
e.stopPropagation()
|
||||
onUseAsInputClick(null,img)
|
||||
})
|
||||
}
|
||||
|
||||
let startX, startY;
|
||||
@ -839,7 +871,7 @@ function createTask(task) {
|
||||
<i class="drag-handle fa-solid fa-grip"></i>
|
||||
<div class="taskStatusLabel">Enqueued</div>
|
||||
<button class="secondaryButton stopTask"><i class="fa-solid fa-trash-can"></i> Remove</button>
|
||||
<button class="secondaryButton useSettings"><i class="fa-solid fa-redo"></i> Use these settings</button>
|
||||
<button class="tertiaryButton useSettings"><i class="fa-solid fa-redo"></i> Use these settings</button>
|
||||
<div class="preview-prompt"></div>
|
||||
<div class="taskConfig">${taskConfig}</div>
|
||||
<div class="outputMsg"></div>
|
||||
@ -906,7 +938,7 @@ function createTask(task) {
|
||||
})
|
||||
|
||||
task.isProcessing = true
|
||||
taskEntry = imagePreview.insertBefore(taskEntry, previewTools.nextSibling)
|
||||
taskEntry = imagePreviewContent.insertBefore(taskEntry, previewTools.nextSibling)
|
||||
htmlTaskMap.set(taskEntry, task)
|
||||
|
||||
task.previewPrompt.innerText = task.reqBody.prompt
|
||||
@ -929,6 +961,7 @@ function getCurrentUserRequest() {
|
||||
|
||||
reqBody: {
|
||||
seed,
|
||||
used_random_seed: randomSeedField.checked,
|
||||
negative_prompt: negativePromptField.value.trim(),
|
||||
num_outputs: numOutputsParallel,
|
||||
num_inference_steps: parseInt(numInferenceStepsField.value),
|
||||
@ -943,9 +976,10 @@ function getCurrentUserRequest() {
|
||||
stream_progress_updates: true,
|
||||
stream_image_progress: (numOutputsTotal > 50 ? false : streamImageProgressField.checked),
|
||||
show_only_filtered_image: showOnlyFilteredImageField.checked,
|
||||
block_nsfw: blockNSFWField.checked,
|
||||
output_format: outputFormatField.value,
|
||||
output_quality: parseInt(outputQualityField.value),
|
||||
metadata_output_format: document.querySelector('#metadata_output_format').value,
|
||||
metadata_output_format: metadataOutputFormatField.value,
|
||||
original_prompt: promptField.value,
|
||||
active_tags: (activeTags.map(x => x.name)),
|
||||
inactive_tags: (activeTags.filter(tag => tag.inactive === true).map(x => x.name))
|
||||
@ -970,7 +1004,7 @@ function getCurrentUserRequest() {
|
||||
newTask.reqBody.save_to_disk_path = diskPathField.value.trim()
|
||||
}
|
||||
if (useFaceCorrectionField.checked) {
|
||||
newTask.reqBody.use_face_correction = 'GFPGANv1.3'
|
||||
newTask.reqBody.use_face_correction = gfpganModelField.value
|
||||
}
|
||||
if (useUpscalingField.checked) {
|
||||
newTask.reqBody.use_upscale = upscaleModelField.value
|
||||
@ -1015,6 +1049,8 @@ function getPrompts(prompts) {
|
||||
promptsToMake = applyPermuteOperator(promptsToMake)
|
||||
promptsToMake = applySetOperator(promptsToMake)
|
||||
|
||||
PLUGINS['GET_PROMPTS_HOOK'].forEach(fn => { promptsToMake = fn(promptsToMake) })
|
||||
|
||||
return promptsToMake
|
||||
}
|
||||
|
||||
@ -1099,7 +1135,7 @@ function createFileName(prompt, seed, steps, guidance, outputFormat) {
|
||||
// fileName += `${tagString}`
|
||||
|
||||
// add the file extension
|
||||
fileName += '.' + (outputFormat === 'png' ? 'png' : 'jpeg')
|
||||
fileName += '.' + outputFormat
|
||||
|
||||
return fileName
|
||||
}
|
||||
@ -1134,6 +1170,20 @@ clearAllPreviewsBtn.addEventListener('click', (e) => { shiftOrConfirm(e, "Clear
|
||||
taskEntries.forEach(removeTask)
|
||||
})})
|
||||
|
||||
saveAllImagesBtn.addEventListener('click', (e) => {
|
||||
let i = 0
|
||||
document.querySelectorAll(".imageTaskContainer").forEach(container => {
|
||||
let req = htmlTaskMap.get(container)
|
||||
container.querySelectorAll(".imgContainer img").forEach(img => {
|
||||
if (img.closest('.imgItem').style.display === 'none') {
|
||||
return
|
||||
}
|
||||
setTimeout(() => {onDownloadImageClick(req, img)}, i*200)
|
||||
i = i+1
|
||||
})
|
||||
})
|
||||
})
|
||||
|
||||
stopImageBtn.addEventListener('click', (e) => { shiftOrConfirm(e, "Stop all the tasks?", async function(e) {
|
||||
await stopAllTasks()
|
||||
})})
|
||||
@ -1142,7 +1192,7 @@ widthField.addEventListener('change', onDimensionChange)
|
||||
heightField.addEventListener('change', onDimensionChange)
|
||||
|
||||
function renameMakeImageButton() {
|
||||
let totalImages = Math.max(parseInt(numOutputsTotalField.value), parseInt(numOutputsParallelField.value))
|
||||
let totalImages = Math.max(parseInt(numOutputsTotalField.value), parseInt(numOutputsParallelField.value)) * getPrompts().length
|
||||
let imageLabel = 'Image'
|
||||
if (totalImages > 1) {
|
||||
imageLabel = totalImages + ' Images'
|
||||
@ -1168,6 +1218,12 @@ function onDimensionChange() {
|
||||
}
|
||||
|
||||
diskPathField.disabled = !saveToDiskField.checked
|
||||
metadataOutputFormatField.disabled = !saveToDiskField.checked
|
||||
|
||||
gfpganModelField.disabled = !useFaceCorrectionField.checked
|
||||
useFaceCorrectionField.addEventListener('change', function(e) {
|
||||
gfpganModelField.disabled = !this.checked
|
||||
})
|
||||
|
||||
upscaleModelField.disabled = !useUpscalingField.checked
|
||||
upscaleAmountField.disabled = !useUpscalingField.checked
|
||||
@ -1254,7 +1310,7 @@ function updateHypernetworkStrengthContainer() {
|
||||
hypernetworkModelField.addEventListener('change', updateHypernetworkStrengthContainer)
|
||||
updateHypernetworkStrengthContainer()
|
||||
|
||||
/********************* JPEG Quality **********************/
|
||||
/********************* JPEG/WEBP Quality **********************/
|
||||
function updateOutputQuality() {
|
||||
outputQualityField.value = 0 | outputQualitySlider.value
|
||||
outputQualityField.dispatchEvent(new Event("change"))
|
||||
@ -1276,77 +1332,43 @@ outputQualityField.addEventListener('input', debounce(updateOutputQualitySlider,
|
||||
updateOutputQuality()
|
||||
|
||||
outputFormatField.addEventListener('change', e => {
|
||||
if (outputFormatField.value == 'jpeg') {
|
||||
outputQualityRow.style.display='table-row'
|
||||
} else {
|
||||
if (outputFormatField.value === 'png') {
|
||||
outputQualityRow.style.display='none'
|
||||
} else {
|
||||
outputQualityRow.style.display='table-row'
|
||||
}
|
||||
})
|
||||
|
||||
async function getModels() {
|
||||
try {
|
||||
const sd_model_setting_key = "stable_diffusion_model"
|
||||
const vae_model_setting_key = "vae_model"
|
||||
const hypernetwork_model_key = "hypernetwork_model"
|
||||
const selectedSDModel = SETTINGS[sd_model_setting_key].value
|
||||
const selectedVaeModel = SETTINGS[vae_model_setting_key].value
|
||||
const selectedHypernetworkModel = SETTINGS[hypernetwork_model_key].value
|
||||
|
||||
const models = await SD.getModels()
|
||||
const modelsOptions = models['options']
|
||||
if ("scan-error" in models) {
|
||||
// let previewPane = document.getElementById('tab-content-wrapper')
|
||||
let previewPane = document.getElementById('preview')
|
||||
previewPane.style.background="red"
|
||||
previewPane.style.textAlign="center"
|
||||
previewPane.innerHTML = '<H1>🔥Malware alert!🔥</H1><h2>The file <i>' + models['scan-error'] + '</i> in your <tt>models/stable-diffusion</tt> folder is probably malware infected.</h2><h2>Please delete this file from the folder before proceeding!</h2>After deleting the file, reload this page.<br><br><button onClick="window.location.reload();">Reload Page</button>'
|
||||
makeImageBtn.disabled = true
|
||||
}
|
||||
|
||||
const stableDiffusionOptions = modelsOptions['stable-diffusion']
|
||||
const vaeOptions = modelsOptions['vae']
|
||||
const hypernetworkOptions = modelsOptions['hypernetwork']
|
||||
|
||||
vaeOptions.unshift('') // add a None option
|
||||
hypernetworkOptions.unshift('') // add a None option
|
||||
|
||||
function createModelOptions(modelField, selectedModel, path="") {
|
||||
return function fn(modelName) {
|
||||
if (typeof(modelName) == 'string') {
|
||||
const modelOption = document.createElement('option')
|
||||
modelOption.value = path + modelName
|
||||
modelOption.innerHTML = modelName !== '' ? (path != "" ? " "+modelName : modelName) : 'None'
|
||||
|
||||
if (path + modelName === selectedModel) {
|
||||
modelOption.selected = true
|
||||
}
|
||||
modelField.appendChild(modelOption)
|
||||
} else {
|
||||
const modelGroup = document.createElement('optgroup')
|
||||
modelGroup.label = path + modelName[0]
|
||||
modelField.appendChild(modelGroup)
|
||||
modelName[1].forEach( createModelOptions(modelField, selectedModel, path + modelName[0] + "/" ) )
|
||||
/********************* Zoom Slider **********************/
|
||||
thumbnailSizeField.addEventListener('change', () => {
|
||||
(function (s) {
|
||||
for (var j =0; j < document.styleSheets.length; j++) {
|
||||
let cssSheet = document.styleSheets[j]
|
||||
for (var i = 0; i < cssSheet.cssRules.length; i++) {
|
||||
var rule = cssSheet.cssRules[i];
|
||||
if (rule.selectorText == "div.img-preview img") {
|
||||
rule.style['max-height'] = s+'vh';
|
||||
rule.style['max-width'] = s+'vw';
|
||||
return;
|
||||
}
|
||||
}
|
||||
}
|
||||
})(thumbnailSizeField.value)
|
||||
})
|
||||
|
||||
stableDiffusionOptions.forEach(createModelOptions(stableDiffusionModelField, selectedSDModel))
|
||||
vaeOptions.forEach(createModelOptions(vaeModelField, selectedVaeModel))
|
||||
hypernetworkOptions.forEach(createModelOptions(hypernetworkModelField, selectedHypernetworkModel))
|
||||
|
||||
stableDiffusionModelField.dispatchEvent(new Event('change'))
|
||||
vaeModelField.dispatchEvent(new Event('change'))
|
||||
hypernetworkModelField.dispatchEvent(new Event('change'))
|
||||
|
||||
// TODO: set default for model here too
|
||||
SETTINGS[sd_model_setting_key].default = stableDiffusionOptions[0]
|
||||
if (getSetting(sd_model_setting_key) == '' || SETTINGS[sd_model_setting_key].value == '') {
|
||||
setSetting(sd_model_setting_key, stableDiffusionOptions[0])
|
||||
}
|
||||
} catch (e) {
|
||||
console.log('get models error', e)
|
||||
function onAutoScrollUpdate() {
|
||||
if (autoScroll.checked) {
|
||||
autoscrollBtn.classList.add('pressed')
|
||||
} else {
|
||||
autoscrollBtn.classList.remove('pressed')
|
||||
}
|
||||
autoscrollBtn.querySelector(".state").innerHTML = (autoScroll.checked ? "ON" : "OFF")
|
||||
}
|
||||
autoscrollBtn.addEventListener('click', function() {
|
||||
autoScroll.checked = !autoScroll.checked
|
||||
autoScroll.dispatchEvent(new Event("change"))
|
||||
onAutoScrollUpdate()
|
||||
})
|
||||
autoScroll.addEventListener('change', onAutoScrollUpdate)
|
||||
|
||||
function checkRandomSeed() {
|
||||
if (randomSeedField.checked) {
|
||||
@ -1490,6 +1512,9 @@ function resumeClient() {
|
||||
})
|
||||
}
|
||||
|
||||
promptField.addEventListener("input", debounce( renameMakeImageButton, 1000) )
|
||||
|
||||
|
||||
pauseBtn.addEventListener("click", function () {
|
||||
pauseClient = true
|
||||
pauseBtn.style.display="none"
|
||||
@ -1522,3 +1547,7 @@ window.addEventListener("beforeunload", function(e) {
|
||||
|
||||
createCollapsibles()
|
||||
prettifyInputs(document);
|
||||
|
||||
// set the textbox as focused on start
|
||||
promptField.focus()
|
||||
promptField.selectionStart = promptField.value.length
|
||||
|
@ -7,6 +7,7 @@
|
||||
checkbox: "checkbox",
|
||||
select: "select",
|
||||
select_multiple: "select_multiple",
|
||||
slider: "slider",
|
||||
custom: "custom",
|
||||
};
|
||||
|
||||
@ -60,6 +61,10 @@ var PARAMETERS = [
|
||||
note: "will be saved to disk in this format",
|
||||
default: "txt",
|
||||
options: [
|
||||
{
|
||||
value: "none",
|
||||
label: "none"
|
||||
},
|
||||
{
|
||||
value: "txt",
|
||||
label: "txt"
|
||||
@ -67,9 +72,21 @@ var PARAMETERS = [
|
||||
{
|
||||
value: "json",
|
||||
label: "json"
|
||||
},
|
||||
{
|
||||
value: "embed",
|
||||
label: "embed"
|
||||
}
|
||||
],
|
||||
},
|
||||
{
|
||||
id: "block_nsfw",
|
||||
type: ParameterType.checkbox,
|
||||
label: "Block NSFW images",
|
||||
note: "blurs out NSFW images",
|
||||
icon: "fa-land-mine-on",
|
||||
default: false,
|
||||
},
|
||||
{
|
||||
id: "sound_toggle",
|
||||
type: ParameterType.checkbox,
|
||||
@ -183,6 +200,18 @@ function getParameterSettingsEntry(id) {
|
||||
return parameter[0].settingsEntry
|
||||
}
|
||||
|
||||
function sliderUpdate(event) {
|
||||
if (event.srcElement.id.endsWith('-input')) {
|
||||
let slider = document.getElementById(event.srcElement.id.slice(0,-6))
|
||||
slider.value = event.srcElement.value
|
||||
slider.dispatchEvent(new Event("change"))
|
||||
} else {
|
||||
let field = document.getElementById(event.srcElement.id+'-input')
|
||||
field.value = event.srcElement.value
|
||||
field.dispatchEvent(new Event("change"))
|
||||
}
|
||||
}
|
||||
|
||||
function getParameterElement(parameter) {
|
||||
switch (parameter.type) {
|
||||
case ParameterType.checkbox:
|
||||
@ -193,6 +222,8 @@ function getParameterElement(parameter) {
|
||||
var options = (parameter.options || []).map(option => `<option value="${option.value}">${option.label}</option>`).join("")
|
||||
var multiple = (parameter.type == ParameterType.select_multiple ? 'multiple' : '')
|
||||
return `<select id="${parameter.id}" name="${parameter.id}" ${multiple}>${options}</select>`
|
||||
case ParameterType.slider:
|
||||
return `<input id="${parameter.id}" name="${parameter.id}" class="editor-slider" type="range" value="${parameter.default}" min="${parameter.slider_min}" max="${parameter.slider_max}" oninput="sliderUpdate(event)"> <input id="${parameter.id}-input" name="${parameter.id}-input" size="4" value="${parameter.default}" pattern="^[0-9\.]+$" onkeypress="preventNonNumericalInput(event)" oninput="sliderUpdate(event)"> ${parameter.slider_unit}`
|
||||
case ParameterType.custom:
|
||||
return parameter.render(parameter)
|
||||
default:
|
||||
@ -226,6 +257,7 @@ let autoPickGPUsField = document.querySelector('#auto_pick_gpus')
|
||||
let useGPUsField = document.querySelector('#use_gpus')
|
||||
let saveToDiskField = document.querySelector('#save_to_disk')
|
||||
let diskPathField = document.querySelector('#diskPath')
|
||||
let metadataOutputFormatField = document.querySelector('#metadata_output_format')
|
||||
let listenToNetworkField = document.querySelector("#listen_to_network")
|
||||
let listenPortField = document.querySelector("#listen_port")
|
||||
let useBetaChannelField = document.querySelector("#use_beta_channel")
|
||||
@ -279,6 +311,7 @@ async function getAppConfig() {
|
||||
|
||||
saveToDiskField.addEventListener('change', function(e) {
|
||||
diskPathField.disabled = !this.checked
|
||||
metadataOutputFormatField.disabled = !this.checked
|
||||
})
|
||||
|
||||
function getCurrentRenderDeviceSelection() {
|
||||
@ -329,9 +362,9 @@ autoPickGPUsField.addEventListener('click', function() {
|
||||
gpuSettingEntry.style.display = (this.checked ? 'none' : '')
|
||||
})
|
||||
|
||||
async function setDiskPath(defaultDiskPath) {
|
||||
async function setDiskPath(defaultDiskPath, force=false) {
|
||||
var diskPath = getSetting("diskPath")
|
||||
if (diskPath == '' || diskPath == undefined || diskPath == "undefined") {
|
||||
if (force || diskPath == '' || diskPath == undefined || diskPath == "undefined") {
|
||||
setSetting("diskPath", defaultDiskPath)
|
||||
}
|
||||
}
|
||||
@ -407,7 +440,17 @@ async function getSystemInfo() {
|
||||
|
||||
setDeviceInfo(devices)
|
||||
setHostInfo(res['hosts'])
|
||||
setDiskPath(res['default_output_dir'])
|
||||
let force = false
|
||||
if (res['enforce_output_dir'] !== undefined) {
|
||||
force = res['enforce_output_dir']
|
||||
if (force == true) {
|
||||
saveToDiskField.checked = true
|
||||
metadataOutputFormatField.disabled = false
|
||||
}
|
||||
saveToDiskField.disabled = force
|
||||
diskPathField.disabled = force
|
||||
}
|
||||
setDiskPath(res['default_output_dir'], force)
|
||||
} catch (e) {
|
||||
console.log('error fetching devices', e)
|
||||
}
|
||||
@ -433,3 +476,4 @@ saveSettingsBtn.addEventListener('click', function() {
|
||||
saveSettingsBtn.classList.add('active')
|
||||
asyncDelay(300).then(() => saveSettingsBtn.classList.remove('active'))
|
||||
})
|
||||
|
||||
|
@ -25,11 +25,13 @@ const PLUGINS = {
|
||||
* })
|
||||
*/
|
||||
IMAGE_INFO_BUTTONS: [],
|
||||
GET_PROMPTS_HOOK: [],
|
||||
MODIFIERS_LOAD: [],
|
||||
TASK_CREATE: [],
|
||||
OUTPUTS_FORMATS: new ServiceContainer(
|
||||
function png() { return (reqBody) => new SD.RenderTask(reqBody) }
|
||||
, function jpeg() { return (reqBody) => new SD.RenderTask(reqBody) }
|
||||
, function webp() { return (reqBody) => new SD.RenderTask(reqBody) }
|
||||
),
|
||||
}
|
||||
PLUGINS.OUTPUTS_FORMATS.register = function(...args) {
|
||||
|
687
ui/media/js/searchable-models.js
Normal file
687
ui/media/js/searchable-models.js
Normal file
@ -0,0 +1,687 @@
|
||||
"use strict"
|
||||
|
||||
let modelsCache
|
||||
let modelsOptions
|
||||
|
||||
/*
|
||||
*** SEARCHABLE MODELS ***
|
||||
Creates searchable dropdowns for SD, VAE, or HN models.
|
||||
Also adds a reload models button (placed next to SD models, reloads everything including VAE and HN models).
|
||||
More reload buttons may be added at strategic UI locations as needed.
|
||||
Merely calling getModels() makes all the magic happen behind the scene to refresh the dropdowns.
|
||||
|
||||
HOW TO CREATE A MODEL DROPDOWN:
|
||||
1) Create an input element. Make sure to add a data-path property, as this is how model dropdowns are identified in auto-save.js.
|
||||
<input id="stable_diffusion_model" type="text" spellcheck="false" autocomplete="off" class="model-filter" data-path="" />
|
||||
|
||||
2) Just declare one of these for your own dropdown (remember to change the element id, e.g. #stable_diffusion_models to your own input's id).
|
||||
let stableDiffusionModelField = new ModelDropdown(document.querySelector('#stable_diffusion_model'), 'stable-diffusion')
|
||||
let vaeModelField = new ModelDropdown(document.querySelector('#vae_model'), 'vae', 'None')
|
||||
let hypernetworkModelField = new ModelDropdown(document.querySelector('#hypernetwork_model'), 'hypernetwork', 'None')
|
||||
|
||||
3) Model dropdowns will be refreshed automatically when the reload models button is invoked.
|
||||
*/
|
||||
class ModelDropdown
|
||||
{
|
||||
modelFilter //= document.querySelector("#model-filter")
|
||||
modelFilterArrow //= document.querySelector("#model-filter-arrow")
|
||||
modelList //= document.querySelector("#model-list")
|
||||
modelResult //= document.querySelector("#model-result")
|
||||
modelNoResult //= document.querySelector("#model-no-result")
|
||||
|
||||
currentSelection //= { elem: undefined, value: '', path: ''}
|
||||
highlightedModelEntry //= undefined
|
||||
activeModel //= undefined
|
||||
|
||||
inputModels //= undefined
|
||||
modelKey //= undefined
|
||||
flatModelList //= []
|
||||
noneEntry //= ''
|
||||
modelFilterInitialized //= undefined
|
||||
|
||||
/* MIMIC A REGULAR INPUT FIELD */
|
||||
get parentElement() {
|
||||
return this.modelFilter.parentElement
|
||||
}
|
||||
get parentNode() {
|
||||
return this.modelFilter.parentNode
|
||||
}
|
||||
get value() {
|
||||
return this.modelFilter.dataset.path
|
||||
}
|
||||
set value(path) {
|
||||
this.modelFilter.dataset.path = path
|
||||
this.selectEntry(path)
|
||||
}
|
||||
get disabled() {
|
||||
return this.modelFilter.disabled
|
||||
}
|
||||
set disabled(state) {
|
||||
this.modelFilter.disabled = state
|
||||
if (this.modelFilterArrow) {
|
||||
this.modelFilterArrow.style.color = state ? 'dimgray' : ''
|
||||
}
|
||||
}
|
||||
get modelElements() {
|
||||
return this.modelList.querySelectorAll('.model-file')
|
||||
}
|
||||
addEventListener(type, listener, options) {
|
||||
return this.modelFilter.addEventListener(type, listener, options)
|
||||
}
|
||||
dispatchEvent(event) {
|
||||
return this.modelFilter.dispatchEvent(event)
|
||||
}
|
||||
appendChild(option) {
|
||||
// do nothing
|
||||
}
|
||||
|
||||
// remember 'this' - http://blog.niftysnippets.org/2008/04/you-must-remember-this.html
|
||||
bind(f, obj) {
|
||||
return function() {
|
||||
return f.apply(obj, arguments)
|
||||
}
|
||||
}
|
||||
|
||||
/* SEARCHABLE INPUT */
|
||||
constructor (input, modelKey, noneEntry = '') {
|
||||
this.modelFilter = input
|
||||
this.noneEntry = noneEntry
|
||||
this.modelKey = modelKey
|
||||
|
||||
if (modelsOptions !== undefined) { // reuse models from cache (only useful for plugins, which are loaded after models)
|
||||
this.inputModels = modelsOptions[this.modelKey]
|
||||
this.populateModels()
|
||||
}
|
||||
document.addEventListener("refreshModels", this.bind(function(e) {
|
||||
// reload the models
|
||||
this.inputModels = modelsOptions[this.modelKey]
|
||||
this.populateModels()
|
||||
}, this))
|
||||
}
|
||||
|
||||
saveCurrentSelection(elem, value, path) {
|
||||
this.currentSelection.elem = elem
|
||||
this.currentSelection.value = value
|
||||
this.currentSelection.path = path
|
||||
this.modelFilter.dataset.path = path
|
||||
this.modelFilter.value = value
|
||||
this.modelFilter.dispatchEvent(new Event('change'))
|
||||
}
|
||||
|
||||
processClick(e) {
|
||||
e.preventDefault()
|
||||
if (e.srcElement.classList.contains('model-file') || e.srcElement.classList.contains('fa-file')) {
|
||||
const elem = e.srcElement.classList.contains('model-file') ? e.srcElement : e.srcElement.parentElement
|
||||
this.saveCurrentSelection(elem, elem.innerText, elem.dataset.path)
|
||||
this.hideModelList()
|
||||
this.modelFilter.focus()
|
||||
this.modelFilter.select()
|
||||
}
|
||||
}
|
||||
|
||||
getPreviousVisibleSibling(elem) {
|
||||
const modelElements = Array.from(this.modelElements)
|
||||
const index = modelElements.indexOf(elem)
|
||||
if (index <= 0) {
|
||||
return undefined
|
||||
}
|
||||
|
||||
return modelElements.slice(0, index).reverse().find(e => e.style.display === 'list-item')
|
||||
}
|
||||
|
||||
getLastVisibleChild(elem) {
|
||||
let lastElementChild = elem.lastElementChild
|
||||
if (lastElementChild.style.display == 'list-item') return lastElementChild
|
||||
return this.getPreviousVisibleSibling(lastElementChild)
|
||||
}
|
||||
|
||||
getNextVisibleSibling(elem) {
|
||||
const modelElements = Array.from(this.modelElements)
|
||||
const index = modelElements.indexOf(elem)
|
||||
return modelElements.slice(index + 1).find(e => e.style.display === 'list-item')
|
||||
}
|
||||
|
||||
getFirstVisibleChild(elem) {
|
||||
let firstElementChild = elem.firstElementChild
|
||||
if (firstElementChild.style.display == 'list-item') return firstElementChild
|
||||
return this.getNextVisibleSibling(firstElementChild)
|
||||
}
|
||||
|
||||
selectModelEntry(elem) {
|
||||
if (elem) {
|
||||
if (this.highlightedModelEntry !== undefined) {
|
||||
this.highlightedModelEntry.classList.remove('selected')
|
||||
}
|
||||
this.saveCurrentSelection(elem, elem.innerText, elem.dataset.path)
|
||||
elem.classList.add('selected')
|
||||
elem.scrollIntoView({block: 'nearest'})
|
||||
this.highlightedModelEntry = elem
|
||||
}
|
||||
}
|
||||
|
||||
selectPreviousFile() {
|
||||
const elem = this.getPreviousVisibleSibling(this.highlightedModelEntry)
|
||||
if (elem) {
|
||||
this.selectModelEntry(elem)
|
||||
}
|
||||
else
|
||||
{
|
||||
//this.highlightedModelEntry.parentElement.parentElement.scrollIntoView({block: 'nearest'})
|
||||
this.highlightedModelEntry.closest('.model-list').scrollTop = 0
|
||||
}
|
||||
this.modelFilter.select()
|
||||
}
|
||||
|
||||
selectNextFile() {
|
||||
this.selectModelEntry(this.getNextVisibleSibling(this.highlightedModelEntry))
|
||||
this.modelFilter.select()
|
||||
}
|
||||
|
||||
selectFirstFile() {
|
||||
this.selectModelEntry(this.modelList.querySelector('.model-file'))
|
||||
this.highlightedModelEntry.scrollIntoView({block: 'nearest'})
|
||||
this.modelFilter.select()
|
||||
}
|
||||
|
||||
selectLastFile() {
|
||||
const elems = this.modelList.querySelectorAll('.model-file:last-child')
|
||||
this.selectModelEntry(elems[elems.length -1])
|
||||
this.modelFilter.select()
|
||||
}
|
||||
|
||||
resetSelection() {
|
||||
this.hideModelList()
|
||||
this.showAllEntries()
|
||||
this.modelFilter.value = this.currentSelection.value
|
||||
this.modelFilter.focus()
|
||||
this.modelFilter.select()
|
||||
}
|
||||
|
||||
validEntrySelected() {
|
||||
return (this.modelNoResult.style.display === 'none')
|
||||
}
|
||||
|
||||
processKey(e) {
|
||||
switch (e.key) {
|
||||
case 'Escape':
|
||||
e.preventDefault()
|
||||
this.resetSelection()
|
||||
break
|
||||
case 'Enter':
|
||||
e.preventDefault()
|
||||
if (this.validEntrySelected()) {
|
||||
if (this.modelList.style.display != 'block') {
|
||||
this.showModelList()
|
||||
}
|
||||
else
|
||||
{
|
||||
this.saveCurrentSelection(this.highlightedModelEntry, this.highlightedModelEntry.innerText, this.highlightedModelEntry.dataset.path)
|
||||
this.hideModelList()
|
||||
this.showAllEntries()
|
||||
}
|
||||
this.modelFilter.focus()
|
||||
}
|
||||
else
|
||||
{
|
||||
this.resetSelection()
|
||||
}
|
||||
break
|
||||
case 'ArrowUp':
|
||||
e.preventDefault()
|
||||
if (this.validEntrySelected()) {
|
||||
this.selectPreviousFile()
|
||||
}
|
||||
break
|
||||
case 'ArrowDown':
|
||||
e.preventDefault()
|
||||
if (this.validEntrySelected()) {
|
||||
this.selectNextFile()
|
||||
}
|
||||
break
|
||||
case 'ArrowLeft':
|
||||
if (this.modelList.style.display != 'block') {
|
||||
e.preventDefault()
|
||||
}
|
||||
break
|
||||
case 'ArrowRight':
|
||||
if (this.modelList.style.display != 'block') {
|
||||
e.preventDefault()
|
||||
}
|
||||
break
|
||||
case 'PageUp':
|
||||
e.preventDefault()
|
||||
if (this.validEntrySelected()) {
|
||||
this.selectPreviousFile()
|
||||
this.selectPreviousFile()
|
||||
this.selectPreviousFile()
|
||||
this.selectPreviousFile()
|
||||
this.selectPreviousFile()
|
||||
this.selectPreviousFile()
|
||||
this.selectPreviousFile()
|
||||
this.selectPreviousFile()
|
||||
}
|
||||
break
|
||||
case 'PageDown':
|
||||
e.preventDefault()
|
||||
if (this.validEntrySelected()) {
|
||||
this.selectNextFile()
|
||||
this.selectNextFile()
|
||||
this.selectNextFile()
|
||||
this.selectNextFile()
|
||||
this.selectNextFile()
|
||||
this.selectNextFile()
|
||||
this.selectNextFile()
|
||||
this.selectNextFile()
|
||||
}
|
||||
break
|
||||
case 'Home':
|
||||
//if (this.modelList.style.display != 'block') {
|
||||
e.preventDefault()
|
||||
if (this.validEntrySelected()) {
|
||||
this.selectFirstFile()
|
||||
}
|
||||
//}
|
||||
break
|
||||
case 'End':
|
||||
//if (this.modelList.style.display != 'block') {
|
||||
e.preventDefault()
|
||||
if (this.validEntrySelected()) {
|
||||
this.selectLastFile()
|
||||
}
|
||||
//}
|
||||
break
|
||||
default:
|
||||
//console.log(e.key)
|
||||
}
|
||||
}
|
||||
|
||||
modelListFocus() {
|
||||
this.selectEntry()
|
||||
this.showAllEntries()
|
||||
}
|
||||
|
||||
showModelList() {
|
||||
this.modelList.style.display = 'block'
|
||||
this.selectEntry()
|
||||
this.showAllEntries()
|
||||
//this.modelFilter.value = ''
|
||||
this.modelFilter.select() // preselect the entire string so user can just start typing.
|
||||
this.modelFilter.focus()
|
||||
this.modelFilter.style.cursor = 'auto'
|
||||
}
|
||||
|
||||
hideModelList() {
|
||||
this.modelList.style.display = 'none'
|
||||
this.modelFilter.value = this.currentSelection.value
|
||||
this.modelFilter.style.cursor = ''
|
||||
}
|
||||
|
||||
toggleModelList(e) {
|
||||
e.preventDefault()
|
||||
if (!this.modelFilter.disabled) {
|
||||
if (this.modelList.style.display != 'block') {
|
||||
this.showModelList()
|
||||
}
|
||||
else
|
||||
{
|
||||
this.hideModelList()
|
||||
this.modelFilter.select()
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
selectEntry(path) {
|
||||
if (path !== undefined) {
|
||||
const entries = this.modelElements;
|
||||
|
||||
for (const elem of entries) {
|
||||
if (elem.dataset.path == path) {
|
||||
this.saveCurrentSelection(elem, elem.innerText, elem.dataset.path)
|
||||
this.highlightedModelEntry = elem
|
||||
elem.scrollIntoView({block: 'nearest'})
|
||||
break
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
if (this.currentSelection.elem !== undefined) {
|
||||
// select the previous element
|
||||
if (this.highlightedModelEntry !== undefined && this.highlightedModelEntry != this.currentSelection.elem) {
|
||||
this.highlightedModelEntry.classList.remove('selected')
|
||||
}
|
||||
this.currentSelection.elem.classList.add('selected')
|
||||
this.highlightedModelEntry = this.currentSelection.elem
|
||||
this.currentSelection.elem.scrollIntoView({block: 'nearest'})
|
||||
}
|
||||
else
|
||||
{
|
||||
this.selectFirstFile()
|
||||
}
|
||||
}
|
||||
|
||||
highlightModelAtPosition(e) {
|
||||
let elem = document.elementFromPoint(e.clientX, e.clientY)
|
||||
|
||||
if (elem.classList.contains('model-file')) {
|
||||
this.highlightModel(elem)
|
||||
}
|
||||
}
|
||||
|
||||
highlightModel(elem) {
|
||||
if (elem.classList.contains('model-file')) {
|
||||
if (this.highlightedModelEntry !== undefined && this.highlightedModelEntry != elem) {
|
||||
this.highlightedModelEntry.classList.remove('selected')
|
||||
}
|
||||
elem.classList.add('selected')
|
||||
this.highlightedModelEntry = elem
|
||||
}
|
||||
}
|
||||
|
||||
showAllEntries() {
|
||||
this.modelList.querySelectorAll('li').forEach(function(li) {
|
||||
if (li.id !== 'model-no-result') {
|
||||
li.style.display = 'list-item'
|
||||
}
|
||||
})
|
||||
this.modelNoResult.style.display = 'none'
|
||||
}
|
||||
|
||||
filterList(e) {
|
||||
const filter = this.modelFilter.value.toLowerCase()
|
||||
let found = false
|
||||
let showAllChildren = false
|
||||
|
||||
this.modelList.querySelectorAll('li').forEach(function(li) {
|
||||
if (li.classList.contains('model-folder')) {
|
||||
showAllChildren = false
|
||||
}
|
||||
if (filter == '') {
|
||||
li.style.display = 'list-item'
|
||||
found = true
|
||||
} else if (showAllChildren || li.textContent.toLowerCase().match(filter)) {
|
||||
li.style.display = 'list-item'
|
||||
if (li.classList.contains('model-folder') && li.firstChild.textContent.toLowerCase().match(filter)) {
|
||||
showAllChildren = true
|
||||
}
|
||||
found = true
|
||||
} else {
|
||||
li.style.display = 'none'
|
||||
}
|
||||
})
|
||||
|
||||
if (found) {
|
||||
this.modelResult.style.display = 'list-item'
|
||||
this.modelNoResult.style.display = 'none'
|
||||
const elem = this.getNextVisibleSibling(this.modelList.querySelector('.model-file'))
|
||||
this.highlightModel(elem)
|
||||
elem.scrollIntoView({block: 'nearest'})
|
||||
}
|
||||
else
|
||||
{
|
||||
this.modelResult.style.display = 'none'
|
||||
this.modelNoResult.style.display = 'list-item'
|
||||
}
|
||||
this.modelList.style.display = 'block'
|
||||
}
|
||||
|
||||
/* MODEL LOADER */
|
||||
getElementDimensions(element) {
|
||||
// Clone the element
|
||||
const clone = element.cloneNode(true)
|
||||
|
||||
// Copy the styles of the original element to the cloned element
|
||||
const originalStyles = window.getComputedStyle(element)
|
||||
for (let i = 0; i < originalStyles.length; i++) {
|
||||
const property = originalStyles[i]
|
||||
clone.style[property] = originalStyles.getPropertyValue(property)
|
||||
}
|
||||
|
||||
// Set its visibility to hidden and display to inline-block
|
||||
clone.style.visibility = "hidden"
|
||||
clone.style.display = "inline-block"
|
||||
|
||||
// Put the cloned element next to the original element
|
||||
element.parentNode.insertBefore(clone, element.nextSibling)
|
||||
|
||||
// Get its width and height
|
||||
const width = clone.offsetWidth
|
||||
const height = clone.offsetHeight
|
||||
|
||||
// Remove it from the DOM
|
||||
clone.remove()
|
||||
|
||||
// Return its width and height
|
||||
return { width, height }
|
||||
}
|
||||
|
||||
/**
|
||||
* @param {Array<string>} models
|
||||
*/
|
||||
sortStringArray(models) {
|
||||
models.sort((a, b) => a.localeCompare(b, undefined, { sensitivity: 'base' }))
|
||||
}
|
||||
|
||||
populateModels() {
|
||||
this.activeModel = this.modelFilter.dataset.path
|
||||
|
||||
this.currentSelection = { elem: undefined, value: '', path: ''}
|
||||
this.highlightedModelEntry = undefined
|
||||
this.flatModelList = []
|
||||
|
||||
if(this.modelList !== undefined) {
|
||||
this.modelList.remove()
|
||||
this.modelFilterArrow.remove()
|
||||
}
|
||||
this.createDropdown()
|
||||
}
|
||||
|
||||
createDropdown() {
|
||||
// create dropdown entries
|
||||
let rootModelList = this.createRootModelList(this.inputModels)
|
||||
this.modelFilter.insertAdjacentElement('afterend', rootModelList)
|
||||
this.modelFilter.insertAdjacentElement(
|
||||
'afterend',
|
||||
this.createElement(
|
||||
'i',
|
||||
{ id: `${this.modelFilter.id}-model-filter-arrow` },
|
||||
['model-selector-arrow', 'fa-solid', 'fa-angle-down'],
|
||||
),
|
||||
)
|
||||
this.modelFilter.classList.add('model-selector')
|
||||
this.modelFilterArrow = document.querySelector(`#${this.modelFilter.id}-model-filter-arrow`)
|
||||
if (this.modelFilterArrow) {
|
||||
this.modelFilterArrow.style.color = this.modelFilter.disabled ? 'dimgray' : ''
|
||||
}
|
||||
this.modelList = document.querySelector(`#${this.modelFilter.id}-model-list`)
|
||||
this.modelResult = document.querySelector(`#${this.modelFilter.id}-model-result`)
|
||||
this.modelNoResult = document.querySelector(`#${this.modelFilter.id}-model-no-result`)
|
||||
|
||||
if (this.modelFilterInitialized !== true) {
|
||||
this.modelFilter.addEventListener('input', this.bind(this.filterList, this))
|
||||
this.modelFilter.addEventListener('focus', this.bind(this.modelListFocus, this))
|
||||
this.modelFilter.addEventListener('blur', this.bind(this.hideModelList, this))
|
||||
this.modelFilter.addEventListener('click', this.bind(this.showModelList, this))
|
||||
this.modelFilter.addEventListener('keydown', this.bind(this.processKey, this))
|
||||
|
||||
this.modelFilterInitialized = true
|
||||
}
|
||||
this.modelFilterArrow.addEventListener('mousedown', this.bind(this.toggleModelList, this))
|
||||
this.modelList.addEventListener('mousemove', this.bind(this.highlightModelAtPosition, this))
|
||||
this.modelList.addEventListener('mousedown', this.bind(this.processClick, this))
|
||||
|
||||
let mf = this.modelFilter
|
||||
this.modelFilter.addEventListener('focus', function() {
|
||||
let modelFilterStyle = window.getComputedStyle(mf)
|
||||
rootModelList.style.minWidth = modelFilterStyle.width
|
||||
})
|
||||
|
||||
this.selectEntry(this.activeModel)
|
||||
}
|
||||
|
||||
/**
|
||||
*
|
||||
* @param {string} tag
|
||||
* @param {object} attributes
|
||||
* @param {Array<string>} classes
|
||||
* @returns {HTMLElement}
|
||||
*/
|
||||
createElement(tagName, attributes, classes, text, icon) {
|
||||
const element = document.createElement(tagName)
|
||||
if (attributes) {
|
||||
Object.entries(attributes).forEach(([key, value]) => {
|
||||
element.setAttribute(key, value)
|
||||
})
|
||||
}
|
||||
if (classes) {
|
||||
classes.forEach(className => element.classList.add(className))
|
||||
}
|
||||
if (icon) {
|
||||
let iconEl = document.createElement('i')
|
||||
iconEl.className = icon + ' icon'
|
||||
element.appendChild(iconEl)
|
||||
}
|
||||
if (text) {
|
||||
element.appendChild(document.createTextNode(text))
|
||||
}
|
||||
return element
|
||||
}
|
||||
|
||||
/**
|
||||
* @param {Array<string | object} modelTree
|
||||
* @param {string} folderName
|
||||
* @param {boolean} isRootFolder
|
||||
* @returns {HTMLElement}
|
||||
*/
|
||||
createModelNodeList(folderName, modelTree, isRootFolder) {
|
||||
const listElement = this.createElement('ul')
|
||||
|
||||
const foldersMap = new Map()
|
||||
const modelsMap = new Map()
|
||||
|
||||
modelTree.forEach(model => {
|
||||
if (Array.isArray(model)) {
|
||||
const [childFolderName, childModels] = model
|
||||
foldersMap.set(
|
||||
childFolderName,
|
||||
this.createModelNodeList(
|
||||
`${folderName || ''}/${childFolderName}`,
|
||||
childModels,
|
||||
false,
|
||||
),
|
||||
)
|
||||
} else {
|
||||
const classes = ['model-file']
|
||||
if (isRootFolder) {
|
||||
classes.push('in-root-folder')
|
||||
}
|
||||
// Remove the leading slash from the model path
|
||||
const fullPath = folderName ? `${folderName.substring(1)}/${model}` : model
|
||||
modelsMap.set(
|
||||
model,
|
||||
this.createElement('li', { 'data-path': fullPath }, classes, model, 'fa-regular fa-file'),
|
||||
)
|
||||
}
|
||||
})
|
||||
|
||||
const childFolderNames = Array.from(foldersMap.keys())
|
||||
this.sortStringArray(childFolderNames)
|
||||
const folderElements = childFolderNames.map(name => foldersMap.get(name))
|
||||
|
||||
const modelNames = Array.from(modelsMap.keys())
|
||||
this.sortStringArray(modelNames)
|
||||
const modelElements = modelNames.map(name => modelsMap.get(name))
|
||||
|
||||
if (modelElements.length && folderName) {
|
||||
listElement.appendChild(this.createElement('li', undefined, ['model-folder'], folderName.substring(1), 'fa-solid fa-folder-open'))
|
||||
}
|
||||
|
||||
// const allModelElements = isRootFolder ? [...folderElements, ...modelElements] : [...modelElements, ...folderElements]
|
||||
const allModelElements = [...modelElements, ...folderElements]
|
||||
allModelElements.forEach(e => listElement.appendChild(e))
|
||||
return listElement
|
||||
}
|
||||
|
||||
/**
|
||||
* @param {object} modelTree
|
||||
* @returns {HTMLElement}
|
||||
*/
|
||||
createRootModelList(modelTree) {
|
||||
const rootList = this.createElement(
|
||||
'ul',
|
||||
{ id: `${this.modelFilter.id}-model-list` },
|
||||
['model-list'],
|
||||
)
|
||||
rootList.appendChild(
|
||||
this.createElement(
|
||||
'li',
|
||||
{ id: `${this.modelFilter.id}-model-no-result` },
|
||||
['model-no-result'],
|
||||
'No result'
|
||||
),
|
||||
)
|
||||
|
||||
if (this.noneEntry) {
|
||||
rootList.appendChild(
|
||||
this.createElement(
|
||||
'li',
|
||||
{ 'data-path': '' },
|
||||
['model-file', 'in-root-folder'],
|
||||
this.noneEntry,
|
||||
),
|
||||
)
|
||||
}
|
||||
|
||||
if (modelTree.length > 0) {
|
||||
const containerListItem = this.createElement(
|
||||
'li',
|
||||
{ id: `${this.modelFilter.id}-model-result` },
|
||||
['model-result'],
|
||||
)
|
||||
//console.log(containerListItem)
|
||||
containerListItem.appendChild(this.createModelNodeList(undefined, modelTree, true))
|
||||
rootList.appendChild(containerListItem)
|
||||
}
|
||||
|
||||
return rootList
|
||||
}
|
||||
}
|
||||
|
||||
/* (RE)LOAD THE MODELS */
|
||||
async function getModels() {
|
||||
try {
|
||||
modelsCache = await SD.getModels()
|
||||
modelsOptions = modelsCache['options']
|
||||
if ("scan-error" in modelsCache) {
|
||||
// let previewPane = document.getElementById('tab-content-wrapper')
|
||||
let previewPane = document.getElementById('preview')
|
||||
previewPane.style.background="red"
|
||||
previewPane.style.textAlign="center"
|
||||
previewPane.innerHTML = '<H1>🔥Malware alert!🔥</H1><h2>The file <i>' + modelsCache['scan-error'] + '</i> in your <tt>models/stable-diffusion</tt> folder is probably malware infected.</h2><h2>Please delete this file from the folder before proceeding!</h2>After deleting the file, reload this page.<br><br><button onClick="window.location.reload();">Reload Page</button>'
|
||||
makeImageBtn.disabled = true
|
||||
}
|
||||
|
||||
/* This code should no longer be needed. Commenting out for now, will cleanup later.
|
||||
const sd_model_setting_key = "stable_diffusion_model"
|
||||
const vae_model_setting_key = "vae_model"
|
||||
const hypernetwork_model_key = "hypernetwork_model"
|
||||
|
||||
const stableDiffusionOptions = modelsOptions['stable-diffusion']
|
||||
const vaeOptions = modelsOptions['vae']
|
||||
const hypernetworkOptions = modelsOptions['hypernetwork']
|
||||
|
||||
// TODO: set default for model here too
|
||||
SETTINGS[sd_model_setting_key].default = stableDiffusionOptions[0]
|
||||
if (getSetting(sd_model_setting_key) == '' || SETTINGS[sd_model_setting_key].value == '') {
|
||||
setSetting(sd_model_setting_key, stableDiffusionOptions[0])
|
||||
}
|
||||
*/
|
||||
|
||||
// notify ModelDropdown objects to refresh
|
||||
document.dispatchEvent(new Event('refreshModels'))
|
||||
} catch (e) {
|
||||
console.log('get models error', e)
|
||||
}
|
||||
}
|
||||
|
||||
// reload models button
|
||||
document.querySelector('#reload-models').addEventListener('click', getModels)
|
@ -20,19 +20,6 @@ function getNextSibling(elem, selector) {
|
||||
}
|
||||
}
|
||||
|
||||
function findClosestAncestor(element, selector) {
|
||||
if (!element || !element.parentNode) {
|
||||
// reached the top of the DOM tree, return null
|
||||
return null;
|
||||
} else if (element.parentNode.matches(selector)) {
|
||||
// found an ancestor that matches the selector, return it
|
||||
return element.parentNode;
|
||||
} else {
|
||||
// continue searching upwards
|
||||
return findClosestAncestor(element.parentNode, selector);
|
||||
}
|
||||
}
|
||||
|
||||
|
||||
/* Panel Stuff */
|
||||
|
||||
@ -522,6 +509,9 @@ function makeQuerablePromise(promise) {
|
||||
/* inserts custom html to allow prettifying of inputs */
|
||||
function prettifyInputs(root_element) {
|
||||
root_element.querySelectorAll(`input[type="checkbox"]`).forEach(element => {
|
||||
if (element.style.display === "none") {
|
||||
return
|
||||
}
|
||||
var parent = element.parentNode;
|
||||
if (!parent.classList.contains("input-toggle")) {
|
||||
var wrapper = document.createElement("div");
|
||||
|
@ -1,27 +1,7 @@
|
||||
(function () {
|
||||
"use strict"
|
||||
|
||||
var styleSheet = document.createElement("style");
|
||||
styleSheet.textContent = `
|
||||
.auto-scroll {
|
||||
float: right;
|
||||
}
|
||||
`;
|
||||
document.head.appendChild(styleSheet);
|
||||
|
||||
const autoScrollControl = document.createElement('div');
|
||||
autoScrollControl.innerHTML = `<input id="auto_scroll" name="auto_scroll" type="checkbox">
|
||||
<label for="auto_scroll">Scroll to generated image</label>`
|
||||
autoScrollControl.className = "auto-scroll"
|
||||
clearAllPreviewsBtn.parentNode.insertBefore(autoScrollControl, clearAllPreviewsBtn.nextSibling)
|
||||
prettifyInputs(document);
|
||||
let autoScroll = document.querySelector("#auto_scroll")
|
||||
|
||||
// save/restore the toggle state
|
||||
autoScroll.addEventListener('click', (e) => {
|
||||
localStorage.setItem('auto_scroll', autoScroll.checked)
|
||||
})
|
||||
autoScroll.checked = localStorage.getItem('auto_scroll') == "true"
|
||||
|
||||
// observe for changes in the preview pane
|
||||
var observer = new MutationObserver(function (mutations) {
|
||||
@ -39,7 +19,10 @@
|
||||
|
||||
function Autoscroll(target) {
|
||||
if (autoScroll.checked && target !== null) {
|
||||
target.parentElement.parentElement.parentElement.scrollIntoView();
|
||||
const img = target.querySelector('img')
|
||||
img.addEventListener('load', function() {
|
||||
img.closest('.imageTaskContainer').scrollIntoView()
|
||||
}, { once: true })
|
||||
}
|
||||
}
|
||||
})()
|
||||
|
@ -134,7 +134,7 @@
|
||||
/////////////////////// Tab implementation
|
||||
document.querySelector('.tab-container')?.insertAdjacentHTML('beforeend', `
|
||||
<span id="tab-merge" class="tab">
|
||||
<span><i class="fa fa-code-merge icon"></i> Merge models <small>(beta)</small></span>
|
||||
<span><i class="fa fa-code-merge icon"></i> Merge models</span>
|
||||
</span>
|
||||
`)
|
||||
|
||||
@ -241,13 +241,9 @@
|
||||
<div class="merge-container panel-box">
|
||||
<div class="merge-input">
|
||||
<p><label for="#mergeModelA">Select Model A:</label></p>
|
||||
<select id="mergeModelA">
|
||||
<option>A</option>
|
||||
</select>
|
||||
<input id="mergeModelA" type="text" spellcheck="false" autocomplete="off" class="model-filter" data-path="" />
|
||||
<p><label for="#mergeModelB">Select Model B:</label></p>
|
||||
<select id="mergeModelB">
|
||||
<option>A</option>
|
||||
</select>
|
||||
<input id="mergeModelB" type="text" spellcheck="false" autocomplete="off" class="model-filter" data-path="" />
|
||||
<br/><br/>
|
||||
<p id="merge-warning"><small><b>Important:</b> Please merge models of similar type.<br/>For e.g. <code>SD 1.4</code> models with only <code>SD 1.4/1.5</code> models,<br/><code>SD 2.0</code> with <code>SD 2.0</code>-type, and <code>SD 2.1</code> with <code>SD 2.1</code>-type models.</small></p>
|
||||
<br/>
|
||||
@ -338,19 +334,10 @@
|
||||
linkTabContents(tabSettingsSingle)
|
||||
linkTabContents(tabSettingsBatch)
|
||||
|
||||
/////////////////////// Event Listener
|
||||
document.addEventListener('tabClick', (e) => {
|
||||
if (e.detail.name == 'merge') {
|
||||
console.log('Activate')
|
||||
let modelList = stableDiffusionModelField.cloneNode(true)
|
||||
modelList.id = "mergeModelA"
|
||||
document.querySelector("#mergeModelA").replaceWith(modelList)
|
||||
modelList = stableDiffusionModelField.cloneNode(true)
|
||||
modelList.id = "mergeModelB"
|
||||
document.querySelector("#mergeModelB").replaceWith(modelList)
|
||||
updateChart()
|
||||
}
|
||||
})
|
||||
console.log('Activate')
|
||||
let mergeModelAField = new ModelDropdown(document.querySelector('#mergeModelA'), 'stable-diffusion')
|
||||
let mergeModelBField = new ModelDropdown(document.querySelector('#mergeModelB'), 'stable-diffusion')
|
||||
updateChart()
|
||||
|
||||
// slider
|
||||
const singleMergeRatioField = document.querySelector('#single-merge-ratio')
|
||||
|
@ -38,9 +38,9 @@
|
||||
i.parentElement.classList.add('modifier-toggle-inactive')
|
||||
}
|
||||
// refresh activeTags
|
||||
let modifierName = i.parentElement.getElementsByClassName('modifier-card-label')[0].getElementsByTagName("p")[0].innerText
|
||||
let modifierName = i.parentElement.getElementsByClassName('modifier-card-label')[0].getElementsByTagName("p")[0].dataset.fullName
|
||||
activeTags = activeTags.map(obj => {
|
||||
if (obj.name === modifierName) {
|
||||
if (trimModifiers(obj.name) === trimModifiers(modifierName)) {
|
||||
return {...obj, inactive: (obj.element.classList.contains('modifier-toggle-inactive'))};
|
||||
}
|
||||
|
||||
|
Loading…
Reference in New Issue
Block a user