Merge pull request #695 from cmdr2/refactor

v2.5 - move to sdkit

Commit 206f9b97bb by cmdr2, 2022-12-24 23:28:10 +05:30, committed by GitHub.
33 changed files with 1306 additions and 2734 deletions

CHANGES.md:

@@ -1,5 +1,22 @@
 # What's new?
+## v2.5
+### Major Changes
+- **Nearly twice as fast** - significantly faster image generation. We're now pretty close to automatic1111's speed. Code contributions are welcome to make our project even faster: https://github.com/easydiffusion/sdkit/#is-it-fast
+- **Full support for Stable Diffusion 2.1** - loads v1.4, v2.0 or v2.1 models seamlessly. No need to enable "Test SD2", and no need to add `sd2_` to your SD 2.0 model file names.
+- **Memory-optimized Stable Diffusion 2.1** - you can now use 768x768 models for SD 2.1, with the same low-VRAM optimizations that we've always had for SD 1.4.
+- **6 new samplers!** - explore the new samplers, some of which can generate great images in less than 10 inference steps!
+- **Model Merging** - You can now merge two models (`.ckpt` or `.safetensors`) and output `.ckpt` or `.safetensors` models, optionally in `fp16` precision. Details: https://github.com/cmdr2/stable-diffusion-ui/wiki/Model-Merging
+- **Fast loading/unloading of VAEs** - no need to reload the entire Stable Diffusion model each time you change the VAE.
+- **Database of known models** - automatically picks the right configuration for known models. E.g. we automatically detect and apply "v" parameterization (required for some SD 2.0 models), and "fp32" attention precision (required for some SD 2.1 models).
+- **Color correction for img2img** - an option to preserve the color profile (histogram) of the initial image. This is especially useful if you're getting red-tinted images after inpainting/masking.
+- **Three GPU Memory Usage settings** - `High` (fastest, maximum VRAM usage), `Balanced` (default - almost as fast, significantly lower VRAM usage), `Low` (slowest, very low VRAM usage). The `Low` setting is applied automatically for GPUs with less than 4 GB of VRAM.
+- **Save metadata as JSON** - You can now save the metadata files as either text or JSON files (choose in the Settings tab).
+- **Major rewrite of the code** - Most of the codebase has been reorganized and rewritten, to make it more manageable and easier for new developers to contribute features. We've separated our core engine into a new project called `sdkit`, which allows anyone to easily integrate Stable Diffusion (and related modules like GFPGAN etc.) into their programming projects (via a simple `pip install sdkit`): https://github.com/easydiffusion/sdkit/
+- **Name change** - Last, and probably the least, the UI is now called "Easy Diffusion". It indicates the focus of this project - an easy way for people to play with Stable Diffusion.
+
+Our focus remains on an easy installation experience and an easy user interface, while staying pretty powerful in terms of features and speed.
 ## v2.4
 ### Major Changes
 - **Allow reordering the task queue** (by dragging and dropping tasks). Thanks @madrang

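The `sdkit` engine mentioned in the changelog can be driven directly, outside this UI. Here is a minimal text-to-image sketch, using only calls that appear in the new `renderer.py` and `model_manager.py` files further below; the checkpoint path is a placeholder you'd point at your own model file:

```python
from sdkit import Context
from sdkit.models import load_model
from sdkit.generate import generate_images

context = Context()

# placeholder path -- use any Stable Diffusion 1.4 / 2.0 / 2.1 checkpoint you have
context.model_paths['stable-diffusion'] = '/path/to/sd-v1-4.ckpt'
load_model(context, 'stable-diffusion')

images = generate_images(context, prompt='Photograph of an astronaut riding a horse',
                         seed=42, width=512, height=512)
images[0].save('astronaut.png')  # generate_images returns PIL images
```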
scripts/Developer Console.cmd:

@@ -23,23 +23,20 @@ call conda --version

 echo.

-@rem activate the environment
+@rem activate the legacy environment (if present) and set PYTHONPATH
+if exist "installer_files\env" (
+    set PYTHONPATH=%cd%\installer_files\env\lib\site-packages
+)
+if exist "stable-diffusion\env" (
     call conda activate .\stable-diffusion\env
+    set PYTHONPATH=%cd%\stable-diffusion\env\lib\site-packages
+)

 call where python
 call python --version

-@rem set the PYTHONPATH
-cd stable-diffusion
-set SD_DIR=%cd%
-cd env\lib\site-packages
-set PYTHONPATH=%SD_DIR%;%cd%
-cd ..\..\..
-
 echo PYTHONPATH=%PYTHONPATH%

-cd ..
-
 @rem done
 echo.

scripts/bootstrap.bat:

@@ -24,7 +24,7 @@ if exist "%INSTALL_ENV_DIR%" set PATH=%INSTALL_ENV_DIR%;%INSTALL_ENV_DIR%\Librar

 set PACKAGES_TO_INSTALL=

 if not exist "%LEGACY_INSTALL_ENV_DIR%\etc\profile.d\conda.sh" (
-    if not exist "%INSTALL_ENV_DIR%\etc\profile.d\conda.sh" set PACKAGES_TO_INSTALL=%PACKAGES_TO_INSTALL% conda
+    if not exist "%INSTALL_ENV_DIR%\etc\profile.d\conda.sh" set PACKAGES_TO_INSTALL=%PACKAGES_TO_INSTALL% conda python=3.8.5
 )

 call git --version >.tmp1 2>.tmp2

scripts/bootstrap.sh:

@@ -39,7 +39,7 @@ if [ -e "$INSTALL_ENV_DIR" ]; then export PATH="$INSTALL_ENV_DIR/bin:$PATH"; fi

 PACKAGES_TO_INSTALL=""

-if [ ! -e "$LEGACY_INSTALL_ENV_DIR/etc/profile.d/conda.sh" ] && [ ! -e "$INSTALL_ENV_DIR/etc/profile.d/conda.sh" ]; then PACKAGES_TO_INSTALL="$PACKAGES_TO_INSTALL conda"; fi
+if [ ! -e "$LEGACY_INSTALL_ENV_DIR/etc/profile.d/conda.sh" ] && [ ! -e "$INSTALL_ENV_DIR/etc/profile.d/conda.sh" ]; then PACKAGES_TO_INSTALL="$PACKAGES_TO_INSTALL conda python=3.8.5"; fi
 if ! hash "git" &>/dev/null; then PACKAGES_TO_INSTALL="$PACKAGES_TO_INSTALL git"; fi

 if "$MAMBA_ROOT_PREFIX/micromamba" --version &>/dev/null; then umamba_exists="T"; fi

scripts/check_modules.py (new file, 13 lines):

@@ -0,0 +1,13 @@
'''
This script checks if the given modules exist
'''
import sys
import pkgutil

modules = sys.argv[1:]
missing_modules = []
for m in modules:
    if pkgutil.find_loader(m) is None:
        print('module', m, 'not found')
        exit(1)
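The installer scripts below invoke this with a list of module names, e.g. `python scripts/check_modules.py torch torchvision`, and branch on the exit code: 0 means everything is importable, 1 triggers a fresh `pip install`.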

scripts/developer_console.sh:

@@ -26,21 +26,23 @@ if [ "$0" == "bash" ]; then
     echo ""

-    # activate the environment
+    # activate the legacy environment (if present) and set PYTHONPATH
+    if [ -e "installer_files/env" ]; then
+        export PYTHONPATH="$(pwd)/installer_files/env/lib/python3.8/site-packages"
+    fi
+    if [ -e "stable-diffusion/env" ]; then
     CONDA_BASEPATH=$(conda info --base)
     source "$CONDA_BASEPATH/etc/profile.d/conda.sh" # otherwise conda complains about 'shell not initialized' (needed when running in a script)
     conda activate ./stable-diffusion/env
+        export PYTHONPATH="$(pwd)/stable-diffusion/env/lib/python3.8/site-packages"
+    fi

     which python
     python --version

-    # set the PYTHONPATH
-    cd stable-diffusion
-    SD_PATH=`pwd`
-    export PYTHONPATH="$SD_PATH:$SD_PATH/env/lib/python3.8/site-packages"
     echo "PYTHONPATH=$PYTHONPATH"
-    cd ..

     # done

scripts/on_env_start.bat:

@@ -53,6 +53,7 @@ if "%update_branch%"=="" (

 @xcopy sd-ui-files\ui ui /s /i /Y /q
 @copy sd-ui-files\scripts\on_sd_start.bat scripts\ /Y
 @copy sd-ui-files\scripts\bootstrap.bat scripts\ /Y
+@copy sd-ui-files\scripts\check_modules.py scripts\ /Y
 @copy "sd-ui-files\scripts\Start Stable Diffusion UI.cmd" . /Y
 @copy "sd-ui-files\scripts\Developer Console.cmd" . /Y

scripts/on_env_start.sh:

@@ -37,6 +37,7 @@ rm -rf ui

 cp -Rf sd-ui-files/ui .
 cp sd-ui-files/scripts/on_sd_start.sh scripts/
 cp sd-ui-files/scripts/bootstrap.sh scripts/
+cp sd-ui-files/scripts/check_modules.py scripts/
 cp sd-ui-files/scripts/start.sh .
 cp sd-ui-files/scripts/developer_console.sh .

scripts/on_sd_start.bat:

@@ -5,11 +5,20 @@

 @copy sd-ui-files\scripts\on_env_start.bat scripts\ /Y
 @copy sd-ui-files\scripts\bootstrap.bat scripts\ /Y
+@copy sd-ui-files\scripts\check_modules.py scripts\ /Y

 if exist "%cd%\profile" (
     set USERPROFILE=%cd%\profile
 )

+@rem set the correct installer path (current vs legacy)
+if exist "%cd%\installer_files\env" (
+    set INSTALL_ENV_DIR=%cd%\installer_files\env
+)
+if exist "%cd%\stable-diffusion\env" (
+    set INSTALL_ENV_DIR=%cd%\stable-diffusion\env
+)
+
 @mkdir tmp
 @set TMP=%cd%\tmp
 @set TEMP=%cd%\tmp
@@ -27,150 +36,92 @@ if exist "Open Developer Console.cmd" del "Open Developer Console.cmd"

 @call python -c "import os; import shutil; frm = 'sd-ui-files\\ui\\hotfix\\9c24e6cd9f499d02c4f21a033736dabd365962dc80fe3aeb57a8f85ea45a20a3.26fead7ea4f0f843f6eb4055dfd25693f1a71f3c6871b184042d4b126244e142'; dst = os.path.join(os.path.expanduser('~'), '.cache', 'huggingface', 'transformers', '9c24e6cd9f499d02c4f21a033736dabd365962dc80fe3aeb57a8f85ea45a20a3.26fead7ea4f0f843f6eb4055dfd25693f1a71f3c6871b184042d4b126244e142'); shutil.copyfile(frm, dst) if os.path.exists(dst) else print(''); print('Hotfixed broken JSON file from OpenAI');"

-if NOT DEFINED test_sd2 set test_sd2=N
-
-@>nul findstr /m "sd_git_cloned" scripts\install_status.txt
-@if "%ERRORLEVEL%" EQU "0" (
-    @echo "Stable Diffusion's git repository was already installed. Updating.."
-    @cd stable-diffusion
-    @call git remote set-url origin https://github.com/easydiffusion/diffusion-kit.git
-    @call git reset --hard
-    @call git pull
-    if "%test_sd2%" == "N" (
-        @call git -c advice.detachedHead=false checkout 7f32368ed1030a6e710537047bacd908adea183a
-    )
-    if "%test_sd2%" == "Y" (
-        @call git -c advice.detachedHead=false checkout 733a1f6f9cae9b9a9b83294bf3281b123378cb1f
-    )
-    @cd ..
-) else (
-    @echo. & echo "Downloading Stable Diffusion.." & echo.
-    @call git clone https://github.com/easydiffusion/diffusion-kit.git stable-diffusion && (
-        @echo sd_git_cloned >> scripts\install_status.txt
-    ) || (
-        @echo "Error downloading Stable Diffusion. Sorry about that, please try to:" & echo " 1. Run this installer again." & echo " 2. If that doesn't fix it, please try the common troubleshooting steps at https://github.com/cmdr2/stable-diffusion-ui/wiki/Troubleshooting" & echo " 3. If those steps don't help, please copy *all* the error messages in this window, and ask the community at https://discord.com/invite/u9yhsFmEkB" & echo " 4. If that doesn't solve the problem, please file an issue at https://github.com/cmdr2/stable-diffusion-ui/issues" & echo "Thanks!"
-        pause
-        @exit /b
-    )
-    @cd stable-diffusion
-    @call git -c advice.detachedHead=false checkout 7f32368ed1030a6e710537047bacd908adea183a
-    @cd ..
-)
-
-@cd stable-diffusion
-
-@>nul findstr /m "conda_sd_env_created" ..\scripts\install_status.txt
-@if "%ERRORLEVEL%" EQU "0" (
-    @echo "Packages necessary for Stable Diffusion were already installed"
-    @call conda activate .\env
-) else (
-    @echo. & echo "Downloading packages necessary for Stable Diffusion.." & echo. & echo "***** This will take some time (depending on the speed of the Internet connection) and may appear to be stuck, but please be patient ***** .." & echo.
-    @rmdir /s /q .\env
-    @REM prevent conda from using packages from the user's home directory, to avoid conflicts
-    @set PYTHONNOUSERSITE=1
-    set USERPROFILE=%cd%\profile
-    set PYTHONPATH=%cd%;%cd%\env\lib\site-packages
-    @call conda env create --prefix env -f environment.yaml || (
-        @echo. & echo "Error installing the packages necessary for Stable Diffusion. Sorry about that, please try to:" & echo " 1. Run this installer again." & echo " 2. If that doesn't fix it, please try the common troubleshooting steps at https://github.com/cmdr2/stable-diffusion-ui/wiki/Troubleshooting" & echo " 3. If those steps don't help, please copy *all* the error messages in this window, and ask the community at https://discord.com/invite/u9yhsFmEkB" & echo " 4. If that doesn't solve the problem, please file an issue at https://github.com/cmdr2/stable-diffusion-ui/issues" & echo "Thanks!" & echo.
-        pause
-        exit /b
-    )
-    @call conda activate .\env
-    for /f "tokens=*" %%a in ('python -c "import torch; import ldm; import transformers; import numpy; import antlr4; print(42)"') do if "%%a" NEQ "42" (
-        @echo. & echo "Dependency test failed! Error installing the packages necessary for Stable Diffusion. Sorry about that, please try to:" & echo " 1. Run this installer again." & echo " 2. If that doesn't fix it, please try the common troubleshooting steps at https://github.com/cmdr2/stable-diffusion-ui/wiki/Troubleshooting" & echo " 3. If those steps don't help, please copy *all* the error messages in this window, and ask the community at https://discord.com/invite/u9yhsFmEkB" & echo " 4. If that doesn't solve the problem, please file an issue at https://github.com/cmdr2/stable-diffusion-ui/issues" & echo "Thanks!" & echo.
-        pause
-        exit /b
-    )
-    @echo conda_sd_env_created >> ..\scripts\install_status.txt
-)
-
-@rem allow rolling back the sdkit-based changes
-if exist "src-old" (
-    if not exist "src" (
-        rename "src-old" "src"
-        if exist "ldm-old" (
-            rd /s /q "ldm-old"
-        )
-        call pip uninstall -y sdkit stable-diffusion-sdkit
-    )
-)
+@rem create the stable-diffusion folder, to work with legacy installations
+if not exist "stable-diffusion" mkdir stable-diffusion
+cd stable-diffusion
+
+@rem activate the old stable-diffusion env, if it exists
+if exist "env" (
+    call conda activate .\env
+)
+
+@rem disable the legacy src and ldm folder (otherwise this prevents installing gfpgan and realesrgan)
+if exist src rename src src-old
+if exist ldm rename ldm ldm-old
+
+@rem install torch and torchvision
+call python ..\scripts\check_modules.py torch torchvision
+if "%ERRORLEVEL%" EQU "0" (
+    echo "torch and torchvision have already been installed."
+) else (
+    echo "Installing torch and torchvision.."
+
+    @REM prevent from using packages from the user's home directory, to avoid conflicts
+    set PYTHONNOUSERSITE=1
+    set PYTHONPATH=%INSTALL_ENV_DIR%\lib\site-packages
+
+    call pip install --upgrade torch torchvision --extra-index-url https://download.pytorch.org/whl/cu116 || (
+        echo "Error installing torch. Sorry about that, please try to:" & echo " 1. Run this installer again." & echo " 2. If that doesn't fix it, please try the common troubleshooting steps at https://github.com/cmdr2/stable-diffusion-ui/wiki/Troubleshooting" & echo " 3. If those steps don't help, please copy *all* the error messages in this window, and ask the community at https://discord.com/invite/u9yhsFmEkB" & echo " 4. If that doesn't solve the problem, please file an issue at https://github.com/cmdr2/stable-diffusion-ui/issues" & echo "Thanks!"
+        pause
+        exit /b
+    )
+)
+
+@rem install/upgrade sdkit
+call python ..\scripts\check_modules.py sdkit sdkit.models ldm transformers numpy antlr4 gfpgan realesrgan
+if "%ERRORLEVEL%" EQU "0" (
+    echo "sdkit is already installed."
+
+    @REM prevent from using packages from the user's home directory, to avoid conflicts
+    set PYTHONNOUSERSITE=1
+    set PYTHONPATH=%INSTALL_ENV_DIR%\lib\site-packages
+
+    call >nul pip install --upgrade sdkit || (
+        echo "Error updating sdkit"
+    )
+) else (
+    echo "Installing sdkit: https://pypi.org/project/sdkit/"
+
+    @REM prevent from using packages from the user's home directory, to avoid conflicts
+    set PYTHONNOUSERSITE=1
+    set PYTHONPATH=%INSTALL_ENV_DIR%\lib\site-packages
+
+    call pip install sdkit || (
+        echo "Error installing sdkit. Sorry about that, please try to:" & echo " 1. Run this installer again." & echo " 2. If that doesn't fix it, please try the common troubleshooting steps at https://github.com/cmdr2/stable-diffusion-ui/wiki/Troubleshooting" & echo " 3. If those steps don't help, please copy *all* the error messages in this window, and ask the community at https://discord.com/invite/u9yhsFmEkB" & echo " 4. If that doesn't solve the problem, please file an issue at https://github.com/cmdr2/stable-diffusion-ui/issues" & echo "Thanks!"
+        pause
+        exit /b
+    )
+)
+
+@rem install rich
+call python ..\scripts\check_modules.py rich
+if "%ERRORLEVEL%" EQU "0" (
+    echo "rich has already been installed."
+) else (
+    echo "Installing rich.."
+
+    set PYTHONNOUSERSITE=1
+    set PYTHONPATH=%INSTALL_ENV_DIR%\lib\site-packages
+
+    call pip install rich || (
+        echo "Error installing rich. Sorry about that, please try to:" & echo " 1. Run this installer again." & echo " 2. If that doesn't fix it, please try the common troubleshooting steps at https://github.com/cmdr2/stable-diffusion-ui/wiki/Troubleshooting" & echo " 3. If those steps don't help, please copy *all* the error messages in this window, and ask the community at https://discord.com/invite/u9yhsFmEkB" & echo " 4. If that doesn't solve the problem, please file an issue at https://github.com/cmdr2/stable-diffusion-ui/issues" & echo "Thanks!"
+        pause
+        exit /b
+    )
+)

 set PATH=C:\Windows\System32;%PATH%

-@>nul findstr /m "conda_sd_gfpgan_deps_installed" ..\scripts\install_status.txt
-@if "%ERRORLEVEL%" EQU "0" (
-    @echo "Packages necessary for GFPGAN (Face Correction) were already installed"
-) else (
-    @echo. & echo "Downloading packages necessary for GFPGAN (Face Correction).." & echo.
-    @set PYTHONNOUSERSITE=1
-    set USERPROFILE=%cd%\profile
-    set PYTHONPATH=%cd%;%cd%\env\lib\site-packages
-    for /f "tokens=*" %%a in ('python -c "from gfpgan import GFPGANer; print(42)"') do if "%%a" NEQ "42" (
-        @echo. & echo "Dependency test failed! Error installing the packages necessary for GFPGAN (Face Correction). Sorry about that, please try to:" & echo " 1. Run this installer again." & echo " 2. If that doesn't fix it, please try the common troubleshooting steps at https://github.com/cmdr2/stable-diffusion-ui/wiki/Troubleshooting" & echo " 3. If those steps don't help, please copy *all* the error messages in this window, and ask the community at https://discord.com/invite/u9yhsFmEkB" & echo " 4. If that doesn't solve the problem, please file an issue at https://github.com/cmdr2/stable-diffusion-ui/issues" & echo "Thanks!" & echo.
-        pause
-        exit /b
-    )
-    @echo conda_sd_gfpgan_deps_installed >> ..\scripts\install_status.txt
-)
-
-@>nul findstr /m "conda_sd_esrgan_deps_installed" ..\scripts\install_status.txt
-@if "%ERRORLEVEL%" EQU "0" (
-    @echo "Packages necessary for ESRGAN (Resolution Upscaling) were already installed"
-) else (
-    @echo. & echo "Downloading packages necessary for ESRGAN (Resolution Upscaling).." & echo.
-    @set PYTHONNOUSERSITE=1
-    set USERPROFILE=%cd%\profile
-    set PYTHONPATH=%cd%;%cd%\env\lib\site-packages
-    for /f "tokens=*" %%a in ('python -c "from basicsr.archs.rrdbnet_arch import RRDBNet; from realesrgan import RealESRGANer; print(42)"') do if "%%a" NEQ "42" (
-        @echo. & echo "Dependency test failed! Error installing the packages necessary for ESRGAN (Resolution Upscaling). Sorry about that, please try to:" & echo " 1. Run this installer again." & echo " 2. If that doesn't fix it, please try the common troubleshooting steps at https://github.com/cmdr2/stable-diffusion-ui/wiki/Troubleshooting" & echo " 3. If those steps don't help, please copy *all* the error messages in this window, and ask the community at https://discord.com/invite/u9yhsFmEkB" & echo " 4. If that doesn't solve the problem, please file an issue at https://github.com/cmdr2/stable-diffusion-ui/issues" & echo "Thanks!" & echo.
-        pause
-        exit /b
-    )
-    @echo conda_sd_esrgan_deps_installed >> ..\scripts\install_status.txt
-)
-
-@>nul findstr /m "conda_sd_ui_deps_installed" ..\scripts\install_status.txt
+call python ..\scripts\check_modules.py uvicorn fastapi
 @if "%ERRORLEVEL%" EQU "0" (
     echo "Packages necessary for Stable Diffusion UI were already installed"
 ) else (
     @echo. & echo "Downloading packages necessary for Stable Diffusion UI.." & echo.
-    @set PYTHONNOUSERSITE=1
-    set USERPROFILE=%cd%\profile
-    set PYTHONPATH=%cd%;%cd%\env\lib\site-packages
-    @call conda install -c conda-forge -y --prefix env uvicorn fastapi || (
+    set PYTHONNOUSERSITE=1
+    set PYTHONPATH=%INSTALL_ENV_DIR%\lib\site-packages
+    @call conda install -c conda-forge -y uvicorn fastapi || (
         echo "Error installing the packages necessary for Stable Diffusion UI. Sorry about that, please try to:" & echo " 1. Run this installer again." & echo " 2. If that doesn't fix it, please try the common troubleshooting steps at https://github.com/cmdr2/stable-diffusion-ui/wiki/Troubleshooting" & echo " 3. If those steps don't help, please copy *all* the error messages in this window, and ask the community at https://discord.com/invite/u9yhsFmEkB" & echo " 4. If that doesn't solve the problem, please file an issue at https://github.com/cmdr2/stable-diffusion-ui/issues" & echo "Thanks!"
         pause
         exit /b
@@ -185,26 +136,6 @@ call WHERE uvicorn > .tmp
     exit /b
 )

-@>nul 2>nul call python -m picklescan --help
-@if "%ERRORLEVEL%" NEQ "0" (
-    @echo. & echo Picklescan not found. Installing
-    @call pip install picklescan || (
-        echo "Error installing the picklescan package necessary for Stable Diffusion UI. Sorry about that, please try to:" & echo " 1. Run this installer again." & echo " 2. If that doesn't fix it, please try the common troubleshooting steps at https://github.com/cmdr2/stable-diffusion-ui/wiki/Troubleshooting" & echo " 3. If those steps don't help, please copy *all* the error messages in this window, and ask the community at https://discord.com/invite/u9yhsFmEkB" & echo " 4. If that doesn't solve the problem, please file an issue at https://github.com/cmdr2/stable-diffusion-ui/issues" & echo "Thanks!"
-        pause
-        exit /b
-    )
-)
-
-@>nul 2>nul call python -c "import safetensors"
-@if "%ERRORLEVEL%" NEQ "0" (
-    @echo. & echo SafeTensors not found. Installing
-    @call pip install safetensors || (
-        echo "Error installing the safetensors package necessary for Stable Diffusion UI. Sorry about that, please try to:" & echo " 1. Run this installer again." & echo " 2. If that doesn't fix it, please try the common troubleshooting steps at https://github.com/cmdr2/stable-diffusion-ui/wiki/Troubleshooting" & echo " 3. If those steps don't help, please copy *all* the error messages in this window, and ask the community at https://discord.com/invite/u9yhsFmEkB" & echo " 4. If that doesn't solve the problem, please file an issue at https://github.com/cmdr2/stable-diffusion-ui/issues" & echo "Thanks!"
-        pause
-        exit /b
-    )
-)
-
 @>nul findstr /m "conda_sd_ui_deps_installed" ..\scripts\install_status.txt
 @if "%ERRORLEVEL%" NEQ "0" (
     @echo conda_sd_ui_deps_installed >> ..\scripts\install_status.txt
@@ -212,12 +143,7 @@ call WHERE uvicorn > .tmp

-if not exist "..\models\stable-diffusion" mkdir "..\models\stable-diffusion"
 if not exist "..\models\vae" mkdir "..\models\vae"
-if not exist "..\models\hypernetwork" mkdir "..\models\hypernetwork"
-echo. > "..\models\stable-diffusion\Put your custom ckpt files here.txt"
-echo. > "..\models\vae\Put your VAE files here.txt"
-echo. > "..\models\hypernetwork\Put your hypernetwork files here.txt"

 @if exist "sd-v1-4.ckpt" (
     for %%I in ("sd-v1-4.ckpt") do if "%%~zI" EQU "4265380512" (
@@ -375,10 +301,6 @@ echo. > "..\models\hypernetwork\Put your hypernetwork files here.txt"
     )
 )

-if "%test_sd2%" == "Y" (
-    @call pip install open_clip_torch==2.0.2
-)
-
 @>nul findstr /m "sd_install_complete" ..\scripts\install_status.txt
 @if "%ERRORLEVEL%" NEQ "0" (
     @echo sd_weights_downloaded >> ..\scripts\install_status.txt
@@ -389,10 +311,8 @@

 @set SD_DIR=%cd%

-@cd env\lib\site-packages
-@set PYTHONPATH=%SD_DIR%;%cd%
-@cd ..\..\..
-@echo PYTHONPATH=%PYTHONPATH%
+set PYTHONPATH=%INSTALL_ENV_DIR%\lib\site-packages
+echo PYTHONPATH=%PYTHONPATH%

 call where python
 call python --version
@@ -401,17 +321,9 @@ call python --version

 @set SD_UI_PATH=%cd%\ui
 @cd stable-diffusion

-@rem
-@rem Rewrite easy-install.pth. This fixes the installation if the user has relocated the SDUI installation
-@rem
->env\Lib\site-packages\easy-install.pth echo %cd%\src\taming-transformers
->>env\Lib\site-packages\easy-install.pth echo %cd%\src\clip
->>env\Lib\site-packages\easy-install.pth echo %cd%\src\gfpgan
->>env\Lib\site-packages\easy-install.pth echo %cd%\src\realesrgan
-
 @if NOT DEFINED SD_UI_BIND_PORT set SD_UI_BIND_PORT=9000
 @if NOT DEFINED SD_UI_BIND_IP set SD_UI_BIND_IP=0.0.0.0

-@uvicorn server:app --app-dir "%SD_UI_PATH%" --port %SD_UI_BIND_PORT% --host %SD_UI_BIND_IP%
+@uvicorn main:server_api --app-dir "%SD_UI_PATH%" --port %SD_UI_BIND_PORT% --host %SD_UI_BIND_IP% --log-level error

 @pause

scripts/on_sd_start.sh:

@@ -4,6 +4,7 @@ source ./scripts/functions.sh

 cp sd-ui-files/scripts/on_env_start.sh scripts/
 cp sd-ui-files/scripts/bootstrap.sh scripts/
+cp sd-ui-files/scripts/check_modules.py scripts/

 # activate the installer env
 CONDA_BASEPATH=$(conda info --base)
@@ -21,125 +22,89 @@ python -c "import os; import shutil; frm = 'sd-ui-files/ui/hotfix/9c24e6cd9f499d

 # Caution, this file will make your eyes and brain bleed. It's such an unholy mess.
 # Note to self: Please rewrite this in Python. For the sake of your own sanity.

-if [ "$test_sd2" == "" ]; then
-    export test_sd2="N"
-fi
-
-if [ -e "scripts/install_status.txt" ] && [ `grep -c sd_git_cloned scripts/install_status.txt` -gt "0" ]; then
-    echo "Stable Diffusion's git repository was already installed. Updating.."
-    cd stable-diffusion
-    git remote set-url origin https://github.com/easydiffusion/diffusion-kit.git
-    git reset --hard
-    git pull
-    if [ "$test_sd2" == "N" ]; then
-        git -c advice.detachedHead=false checkout 7f32368ed1030a6e710537047bacd908adea183a
-    elif [ "$test_sd2" == "Y" ]; then
-        git -c advice.detachedHead=false checkout 733a1f6f9cae9b9a9b83294bf3281b123378cb1f
-    fi
-    cd ..
-else
-    printf "\n\nDownloading Stable Diffusion..\n\n"
-    if git clone https://github.com/easydiffusion/diffusion-kit.git stable-diffusion ; then
-        echo sd_git_cloned >> scripts/install_status.txt
-    else
-        fail "git clone of basujindal/stable-diffusion.git failed"
-    fi
-    cd stable-diffusion
-    git -c advice.detachedHead=false checkout 7f32368ed1030a6e710537047bacd908adea183a
-    cd ..
-fi
-
-cd stable-diffusion
-
-if [ `grep -c conda_sd_env_created ../scripts/install_status.txt` -gt "0" ]; then
-    echo "Packages necessary for Stable Diffusion were already installed"
-    conda activate ./env || fail "conda activate failed"
-else
-    printf "\n\nDownloading packages necessary for Stable Diffusion..\n"
-    printf "\n\n***** This will take some time (depending on the speed of the Internet connection) and may appear to be stuck, but please be patient ***** ..\n\n"
-    # prevent conda from using packages from the user's home directory, to avoid conflicts
-    export PYTHONNOUSERSITE=1
-    export PYTHONPATH="$(pwd):$(pwd)/env/lib/site-packages"
-    if conda env create --prefix env --force -f environment.yaml ; then
-        echo "Installed. Testing.."
-    else
-        fail "'conda env create' failed"
-    fi
-    conda activate ./env || fail "conda activate failed"
-    out_test=`python -c "import torch; import ldm; import transformers; import numpy; import antlr4; print(42)"`
-    if [ "$out_test" != "42" ]; then
-        fail "Dependency test failed"
-    fi
-    echo conda_sd_env_created >> ../scripts/install_status.txt
-fi
-
-# allow rolling back the sdkit-based changes
-if [ -e "src-old" ] && [ ! -e "src" ]; then
-    mv src-old src
-    if [ -e "ldm-old" ]; then rm -r ldm-old; fi
-    pip uninstall -y sdkit stable-diffusion-sdkit
-fi
-
-if [ `grep -c conda_sd_gfpgan_deps_installed ../scripts/install_status.txt` -gt "0" ]; then
-    echo "Packages necessary for GFPGAN (Face Correction) were already installed"
-else
-    printf "\n\nDownloading packages necessary for GFPGAN (Face Correction)..\n"
-    export PYTHONNOUSERSITE=1
-    export PYTHONPATH="$(pwd):$(pwd)/env/lib/site-packages"
-    out_test=`python -c "from gfpgan import GFPGANer; print(42)"`
-    if [ "$out_test" != "42" ]; then
-        echo "EE The dependency check has failed. This usually means that some system libraries are missing."
-        echo "EE On Debian/Ubuntu systems, this are often these packages: libsm6 libxext6 libxrender-dev"
-        echo "EE Other Linux distributions might have different package names for these libraries."
-        fail "GFPGAN dependency test failed"
-    fi
-    echo conda_sd_gfpgan_deps_installed >> ../scripts/install_status.txt
-fi
-
-if [ `grep -c conda_sd_esrgan_deps_installed ../scripts/install_status.txt` -gt "0" ]; then
-    echo "Packages necessary for ESRGAN (Resolution Upscaling) were already installed"
-else
-    printf "\n\nDownloading packages necessary for ESRGAN (Resolution Upscaling)..\n"
-    export PYTHONNOUSERSITE=1
-    export PYTHONPATH="$(pwd):$(pwd)/env/lib/site-packages"
-    out_test=`python -c "from basicsr.archs.rrdbnet_arch import RRDBNet; from realesrgan import RealESRGANer; print(42)"`
-    if [ "$out_test" != "42" ]; then
-        fail "ESRGAN dependency test failed"
-    fi
-    echo conda_sd_esrgan_deps_installed >> ../scripts/install_status.txt
-fi
-
-if [ `grep -c conda_sd_ui_deps_installed ../scripts/install_status.txt` -gt "0" ]; then
+# set the correct installer path (current vs legacy)
+if [ -e "installer_files/env" ]; then
+    export INSTALL_ENV_DIR="$(pwd)/installer_files/env"
+fi
+if [ -e "stable-diffusion/env" ]; then
+    export INSTALL_ENV_DIR="$(pwd)/stable-diffusion/env"
+fi
+
+# create the stable-diffusion folder, to work with legacy installations
+if [ ! -e "stable-diffusion" ]; then mkdir stable-diffusion; fi
+cd stable-diffusion
+
+# activate the old stable-diffusion env, if it exists
+if [ -e "env" ]; then
+    conda activate ./env || fail "conda activate failed"
+fi
+
+# disable the legacy src and ldm folder (otherwise this prevents installing gfpgan and realesrgan)
+if [ -e "src" ]; then mv src src-old; fi
+if [ -e "ldm" ]; then mv ldm ldm-old; fi
+
+# install torch and torchvision
+if python ../scripts/check_modules.py torch torchvision; then
+    echo "torch and torchvision have already been installed."
+else
+    echo "Installing torch and torchvision.."
+    export PYTHONNOUSERSITE=1
+    export PYTHONPATH="$INSTALL_ENV_DIR/lib/python3.8/site-packages"
+    if pip install --upgrade torch torchvision --extra-index-url https://download.pytorch.org/whl/cu116 ; then
+        echo "Installed."
+    else
+        fail "torch install failed"
+    fi
+fi
+
+# install/upgrade sdkit
+if python ../scripts/check_modules.py sdkit sdkit.models ldm transformers numpy antlr4 gfpgan realesrgan ; then
+    echo "sdkit is already installed."
+    export PYTHONNOUSERSITE=1
+    export PYTHONPATH="$INSTALL_ENV_DIR/lib/python3.8/site-packages"
+    pip install --upgrade sdkit > /dev/null
+else
+    echo "Installing sdkit: https://pypi.org/project/sdkit/"
+    export PYTHONNOUSERSITE=1
+    export PYTHONPATH="$INSTALL_ENV_DIR/lib/python3.8/site-packages"
+    if pip install sdkit ; then
+        echo "Installed."
+    else
+        fail "sdkit install failed"
+    fi
+fi
+
+# install rich
+if python ../scripts/check_modules.py rich; then
+    echo "rich has already been installed."
+else
+    echo "Installing rich.."
+    export PYTHONNOUSERSITE=1
+    export PYTHONPATH="$INSTALL_ENV_DIR/lib/python3.8/site-packages"
+    if pip install rich ; then
+        echo "Installed."
+    else
+        fail "Install failed for rich"
+    fi
+fi
+
+if python ../scripts/check_modules.py uvicorn fastapi ; then
     echo "Packages necessary for Stable Diffusion UI were already installed"
 else
     printf "\n\nDownloading packages necessary for Stable Diffusion UI..\n\n"
     export PYTHONNOUSERSITE=1
-    export PYTHONPATH="$(pwd):$(pwd)/env/lib/site-packages"
-    if conda install -c conda-forge --prefix ./env -y uvicorn fastapi ; then
+    export PYTHONPATH="$INSTALL_ENV_DIR/lib/python3.8/site-packages"
+    if conda install -c conda-forge -y uvicorn fastapi ; then
         echo "Installed. Testing.."
     else
         fail "'conda install uvicorn' failed"
@@ -148,32 +113,9 @@ else
     if ! command -v uvicorn &> /dev/null; then
         fail "UI packages not found!"
     fi
-    echo conda_sd_ui_deps_installed >> ../scripts/install_status.txt
 fi

-if python -m picklescan --help >/dev/null 2>&1; then
-    echo "Picklescan is already installed."
-else
-    echo "Picklescan not found, installing."
-    pip install picklescan || fail "Picklescan installation failed."
-fi
-
-if python -c "import safetensors" --help >/dev/null 2>&1; then
-    echo "SafeTensors is already installed."
-else
-    echo "SafeTensors not found, installing."
-    pip install safetensors || fail "SafeTensors installation failed."
-fi
-
-mkdir -p "../models/stable-diffusion"
 mkdir -p "../models/vae"
-mkdir -p "../models/hypernetwork"
-echo "" > "../models/stable-diffusion/Put your custom ckpt files here.txt"
-echo "" > "../models/vae/Put your VAE files here.txt"
-echo "" > "../models/hypernetwork/Put your hypernetwork files here.txt"

 if [ -f "sd-v1-4.ckpt" ]; then
     model_size=`find "sd-v1-4.ckpt" -printf "%s"`
@@ -314,10 +256,6 @@ if [ ! -f "../models/vae/vae-ft-mse-840000-ema-pruned.ckpt" ]; then
     fi
 fi

-if [ "$test_sd2" == "Y" ]; then
-    pip install open_clip_torch==2.0.2
-fi
-
 if [ `grep -c sd_install_complete ../scripts/install_status.txt` -gt "0" ]; then
     echo sd_weights_downloaded >> ../scripts/install_status.txt
     echo sd_install_complete >> ../scripts/install_status.txt
@@ -326,7 +264,8 @@ fi

 printf "\n\nStable Diffusion is ready!\n\n"

 SD_PATH=`pwd`
-export PYTHONPATH="$SD_PATH:$SD_PATH/env/lib/python3.8/site-packages"
+export PYTHONPATH="$INSTALL_ENV_DIR/lib/python3.8/site-packages"
 echo "PYTHONPATH=$PYTHONPATH"

 which python

@@ -336,6 +275,6 @@ cd ..
 export SD_UI_PATH=`pwd`/ui
 cd stable-diffusion

-uvicorn server:app --app-dir "$SD_UI_PATH" --port ${SD_UI_BIND_PORT:-9000} --host ${SD_UI_BIND_IP:-0.0.0.0}
+uvicorn main:server_api --app-dir "$SD_UI_PATH" --port ${SD_UI_BIND_PORT:-9000} --host ${SD_UI_BIND_IP:-0.0.0.0} --log-level error

 read -p "Press any key to continue"

ui/easydiffusion/app.py (new file, 165 lines):

@@ -0,0 +1,165 @@
import os
import socket
import sys
import json
import traceback
import logging

from rich.logging import RichHandler

from sdkit.utils import log as sdkit_log  # hack, so we can overwrite the log config

from easydiffusion import task_manager
from easydiffusion.utils import log

# Remove all handlers associated with the root logger object.
for handler in logging.root.handlers[:]:
    logging.root.removeHandler(handler)

LOG_FORMAT = '%(asctime)s.%(msecs)03d %(levelname)s %(threadName)s %(message)s'
logging.basicConfig(
    level=logging.INFO,
    format=LOG_FORMAT,
    datefmt="%X",
    handlers=[RichHandler(markup=True, rich_tracebacks=True, show_time=False, show_level=False)]
)

SD_DIR = os.getcwd()

SD_UI_DIR = os.getenv('SD_UI_PATH', None)
sys.path.append(os.path.dirname(SD_UI_DIR))

CONFIG_DIR = os.path.abspath(os.path.join(SD_UI_DIR, '..', 'scripts'))
MODELS_DIR = os.path.abspath(os.path.join(SD_DIR, '..', 'models'))

USER_UI_PLUGINS_DIR = os.path.abspath(os.path.join(SD_DIR, '..', 'plugins', 'ui'))
CORE_UI_PLUGINS_DIR = os.path.abspath(os.path.join(SD_UI_DIR, 'plugins', 'ui'))
UI_PLUGINS_SOURCES = ((CORE_UI_PLUGINS_DIR, 'core'), (USER_UI_PLUGINS_DIR, 'user'))

OUTPUT_DIRNAME = "Stable Diffusion UI"  # in the user's home folder
TASK_TTL = 15 * 60  # Discard last session's task timeout

APP_CONFIG_DEFAULTS = {
    # auto: selects the cuda device with the most free memory, cuda: use the currently active cuda device.
    'render_devices': 'auto',  # valid entries: 'auto', 'cpu' or 'cuda:N' (where N is a GPU index)
    'update_branch': 'main',
    'ui': {
        'open_browser_on_start': True,
    },
}

def init():
    os.makedirs(USER_UI_PLUGINS_DIR, exist_ok=True)

    update_render_threads()

def getConfig(default_val=APP_CONFIG_DEFAULTS):
    try:
        config_json_path = os.path.join(CONFIG_DIR, 'config.json')
        if not os.path.exists(config_json_path):
            return default_val
        with open(config_json_path, 'r', encoding='utf-8') as f:
            config = json.load(f)
            if 'net' not in config:
                config['net'] = {}
            if os.getenv('SD_UI_BIND_PORT') is not None:
                config['net']['listen_port'] = int(os.getenv('SD_UI_BIND_PORT'))
            if os.getenv('SD_UI_BIND_IP') is not None:
                config['net']['listen_to_network'] = (os.getenv('SD_UI_BIND_IP') == '0.0.0.0')
            return config
    except Exception as e:
        log.warn(traceback.format_exc())
        return default_val

def setConfig(config):
    try:  # config.json
        config_json_path = os.path.join(CONFIG_DIR, 'config.json')
        with open(config_json_path, 'w', encoding='utf-8') as f:
            json.dump(config, f)
    except:
        log.error(traceback.format_exc())

    try:  # config.bat
        config_bat_path = os.path.join(CONFIG_DIR, 'config.bat')
        config_bat = []

        if 'update_branch' in config:
            config_bat.append(f"@set update_branch={config['update_branch']}")

        config_bat.append(f"@set SD_UI_BIND_PORT={config['net']['listen_port']}")
        bind_ip = '0.0.0.0' if config['net']['listen_to_network'] else '127.0.0.1'
        config_bat.append(f"@set SD_UI_BIND_IP={bind_ip}")

        if len(config_bat) > 0:
            with open(config_bat_path, 'w', encoding='utf-8') as f:
                f.write('\r\n'.join(config_bat))
    except:
        log.error(traceback.format_exc())

    try:  # config.sh
        config_sh_path = os.path.join(CONFIG_DIR, 'config.sh')
        config_sh = ['#!/bin/bash']

        if 'update_branch' in config:
            config_sh.append(f"export update_branch={config['update_branch']}")

        config_sh.append(f"export SD_UI_BIND_PORT={config['net']['listen_port']}")
        bind_ip = '0.0.0.0' if config['net']['listen_to_network'] else '127.0.0.1'
        config_sh.append(f"export SD_UI_BIND_IP={bind_ip}")

        if len(config_sh) > 1:
            with open(config_sh_path, 'w', encoding='utf-8') as f:
                f.write('\n'.join(config_sh))
    except:
        log.error(traceback.format_exc())

def save_to_config(ckpt_model_name, vae_model_name, hypernetwork_model_name, vram_usage_level):
    config = getConfig()
    if 'model' not in config:
        config['model'] = {}

    config['model']['stable-diffusion'] = ckpt_model_name
    config['model']['vae'] = vae_model_name
    config['model']['hypernetwork'] = hypernetwork_model_name

    if vae_model_name is None or vae_model_name == "":
        del config['model']['vae']
    if hypernetwork_model_name is None or hypernetwork_model_name == "":
        del config['model']['hypernetwork']

    config['vram_usage_level'] = vram_usage_level

    setConfig(config)

def update_render_threads():
    config = getConfig()
    render_devices = config.get('render_devices', 'auto')
    active_devices = task_manager.get_devices()['active'].keys()

    log.debug(f'requesting for render_devices: {render_devices}')
    task_manager.update_render_threads(render_devices, active_devices)

def getUIPlugins():
    plugins = []

    for plugins_dir, dir_prefix in UI_PLUGINS_SOURCES:
        for file in os.listdir(plugins_dir):
            if file.endswith('.plugin.js'):
                plugins.append(f'/plugins/{dir_prefix}/{file}')

    return plugins

def getIPConfig():
    try:
        ips = socket.gethostbyname_ex(socket.gethostname())
        ips[2].append(ips[0])
        return ips[2]
    except Exception as e:
        log.exception(e)
        return []

def open_browser():
    config = getConfig()
    ui = config.get('ui', {})
    net = config.get('net', {'listen_port': 9000})
    port = net.get('listen_port', 9000)

    if ui.get('open_browser_on_start', True):
        import webbrowser; webbrowser.open(f"http://localhost:{port}")
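A quick sketch of how these config helpers are meant to be used (the keys come from `APP_CONFIG_DEFAULTS` above; note that `setConfig` expects `config['net']` to be populated, since it unconditionally writes the port and bind-IP lines):

```python
from easydiffusion import app

config = app.getConfig()  # falls back to APP_CONFIG_DEFAULTS if scripts/config.json is missing
config['vram_usage_level'] = 'low'  # 'low', 'balanced' or 'high' (see model_manager.py below)
config['net'] = {'listen_port': 9000, 'listen_to_network': False}
app.setConfig(config)  # persists to config.json, and mirrors into config.bat / config.sh
```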

ui/easydiffusion/device_manager.py:

@@ -3,6 +3,15 @@ import torch
 import traceback
 import re

+from easydiffusion.utils import log
+
+'''
+Set `FORCE_FULL_PRECISION` in the environment variables, or in `config.bat`/`config.sh` to set full precision (i.e. float32).
+Otherwise the models will load at half-precision (i.e. float16).
+Half-precision is fine most of the time. Full precision is only needed for working around GPU bugs (like NVIDIA 16xx GPUs).
+'''
+
 COMPARABLE_GPU_PERCENTILE = 0.65  # if a GPU's free_mem is within this % of the GPU with the most free_mem, it will be picked

 mem_free_threshold = 0
@@ -34,7 +43,7 @@ def get_device_delta(render_devices, active_devices):
     if 'auto' in render_devices:
         render_devices = auto_pick_devices(active_devices)
         if 'cpu' in render_devices:
-            print('WARNING: Could not find a compatible GPU. Using the CPU, but this will be very slow!')
+            log.warn('WARNING: Could not find a compatible GPU. Using the CPU, but this will be very slow!')

     active_devices = set(active_devices)
     render_devices = set(render_devices)

@@ -53,7 +62,7 @@ def auto_pick_devices(currently_active_devices):
     if device_count == 1:
         return ['cuda:0'] if is_device_compatible('cuda:0') else ['cpu']

-    print('Autoselecting GPU. Using most free memory.')
+    log.debug('Autoselecting GPU. Using most free memory.')
     devices = []
     for device in range(device_count):
         device = f'cuda:{device}'

@@ -64,7 +73,7 @@ def auto_pick_devices(currently_active_devices):
         mem_free /= float(10**9)
         mem_total /= float(10**9)
         device_name = torch.cuda.get_device_name(device)
-        print(f'{device} detected: {device_name} - Memory (free/total): {round(mem_free, 2)}Gb / {round(mem_total, 2)}Gb')
+        log.debug(f'{device} detected: {device_name} - Memory (free/total): {round(mem_free, 2)}Gb / {round(mem_total, 2)}Gb')
         devices.append({'device': device, 'device_name': device_name, 'mem_free': mem_free})

     devices.sort(key=lambda x: x['mem_free'], reverse=True)

@@ -82,7 +91,7 @@ def auto_pick_devices(currently_active_devices):
     devices = list(map(lambda x: x['device'], devices))
     return devices

-def device_init(thread_data, device):
+def device_init(context, device):
     '''
     This function assumes the 'device' has already been verified to be compatible.
     `get_device_delta()` has already filtered out incompatible devices.

@@ -91,27 +100,45 @@ def device_init(context, device):
     validate_device_id(device, log_prefix='device_init')

     if device == 'cpu':
-        thread_data.device = 'cpu'
-        thread_data.device_name = get_processor_name()
-        print('Render device CPU available as', thread_data.device_name)
+        context.device = 'cpu'
+        context.device_name = get_processor_name()
+        context.half_precision = False
+        log.debug(f'Render device CPU available as {context.device_name}')
         return

-    thread_data.device_name = torch.cuda.get_device_name(device)
-    thread_data.device = device
+    context.device_name = torch.cuda.get_device_name(device)
+    context.device = device

     # Force full precision on 1660 and 1650 NVIDIA cards to avoid creating green images
-    device_name = thread_data.device_name.lower()
-    thread_data.force_full_precision = (('nvidia' in device_name or 'geforce' in device_name) and (' 1660' in device_name or ' 1650' in device_name)) or ('Quadro T2000' in device_name)
-    if thread_data.force_full_precision:
-        print('forcing full precision on NVIDIA 16xx cards, to avoid green images. GPU detected: ', thread_data.device_name)
+    if needs_to_force_full_precision(context):
+        log.warn(f'forcing full precision on this GPU, to avoid green images. GPU detected: {context.device_name}')
         # Apply force_full_precision now before models are loaded.
-        thread_data.precision = 'full'
+        context.half_precision = False

-    print(f'Setting {device} as active')
+    log.info(f'Setting {device} as active, with precision: {"half" if context.half_precision else "full"}')
     torch.cuda.device(device)

     return

+def needs_to_force_full_precision(context):
+    if 'FORCE_FULL_PRECISION' in os.environ:
+        return True
+
+    device_name = context.device_name.lower()
+    return (('nvidia' in device_name or 'geforce' in device_name) and (' 1660' in device_name or ' 1650' in device_name)) or ('Quadro T2000' in device_name)
+
+def get_max_vram_usage_level(device):
+    if device != 'cpu':
+        _, mem_total = torch.cuda.mem_get_info(device)
+        mem_total /= float(10**9)
+
+        if mem_total < 4.5:
+            return 'low'
+        elif mem_total < 6.5:
+            return 'balanced'
+
+    return 'high'
+
 def validate_device_id(device, log_prefix=''):
     def is_valid():
         if not isinstance(device, str):

@@ -132,7 +159,7 @@ def is_device_compatible(device):
     try:
         validate_device_id(device, log_prefix='is_device_compatible')
     except:
-        print(str(e))
+        log.error(str(e))
         return False

     if device == 'cpu': return True

@@ -141,10 +168,10 @@ def is_device_compatible(device):
         _, mem_total = torch.cuda.mem_get_info(device)
         mem_total /= float(10**9)
         if mem_total < 3.0:
-            print(f'GPU {device} with less than 3 GB of VRAM is not compatible with Stable Diffusion')
+            log.warn(f'GPU {device} with less than 3 GB of VRAM is not compatible with Stable Diffusion')
             return False
     except RuntimeError as e:
-        print(str(e))
+        log.error(str(e))
         return False

     return True

@@ -164,5 +191,5 @@ def get_processor_name():
             if "model name" in line:
                 return re.sub(".*model name.*:", "", line, 1).strip()
     except:
-        print(traceback.format_exc())
+        log.error(traceback.format_exc())
     return "cpu"

ui/easydiffusion/model_manager.py (new file, 223 lines):

@@ -0,0 +1,223 @@
import os

from easydiffusion import app, device_manager
from easydiffusion.types import TaskData
from easydiffusion.utils import log

from sdkit import Context
from sdkit.models import load_model, unload_model, get_model_info_from_db, scan_model
from sdkit.utils import hash_file_quick

KNOWN_MODEL_TYPES = ['stable-diffusion', 'vae', 'hypernetwork', 'gfpgan', 'realesrgan']
MODEL_EXTENSIONS = {
    'stable-diffusion': ['.ckpt', '.safetensors'],
    'vae': ['.vae.pt', '.ckpt', '.safetensors'],
    'hypernetwork': ['.pt', '.safetensors'],
    'gfpgan': ['.pth'],
    'realesrgan': ['.pth'],
}
DEFAULT_MODELS = {
    'stable-diffusion': [  # needed to support the legacy installations
        'custom-model',  # only one custom model file was supported initially, creatively named 'custom-model'
        'sd-v1-4',  # Default fallback.
    ],
    'gfpgan': ['GFPGANv1.3'],
    'realesrgan': ['RealESRGAN_x4plus'],
}
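# The sets below are the optimization flags handed to sdkit for each GPU Memory Usage
# setting from the changelog: 'balanced' is the default, and 'low' is applied
# automatically on GPUs with very little VRAM (see get_max_vram_usage_level in device_manager.py).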
VRAM_USAGE_LEVEL_TO_OPTIMIZATIONS = {
    'balanced': {'KEEP_FS_AND_CS_IN_CPU', 'SET_ATTENTION_STEP_TO_4'},
    'low': {'KEEP_ENTIRE_MODEL_IN_CPU'},
    'high': {},
}

MODELS_TO_LOAD_ON_START = ['stable-diffusion', 'vae', 'hypernetwork']

known_models = {}

def init():
    make_model_folders()
    getModels()  # run this once, to cache the picklescan results

def load_default_models(context: Context):
    set_vram_optimizations(context)

    # init default model paths
    for model_type in MODELS_TO_LOAD_ON_START:
        context.model_paths[model_type] = resolve_model_to_use(model_type=model_type)
        load_model(context, model_type)

def unload_all(context: Context):
    for model_type in KNOWN_MODEL_TYPES:
        unload_model(context, model_type)

def resolve_model_to_use(model_name: str=None, model_type: str=None):
    model_extensions = MODEL_EXTENSIONS.get(model_type, [])
    default_models = DEFAULT_MODELS.get(model_type, [])
    config = app.getConfig()

    model_dirs = [os.path.join(app.MODELS_DIR, model_type), app.SD_DIR]
    if not model_name:  # When None try user configured model.
        # config = getConfig()
        if 'model' in config and model_type in config['model']:
            model_name = config['model'][model_type]

    if model_name:
        # Check models directory
        models_dir_path = os.path.join(app.MODELS_DIR, model_type, model_name)
        for model_extension in model_extensions:
            if os.path.exists(models_dir_path + model_extension):
                return models_dir_path + model_extension
            if os.path.exists(model_name + model_extension):
                return os.path.abspath(model_name + model_extension)

    # Default locations
    if model_name in default_models:
        default_model_path = os.path.join(app.SD_DIR, model_name)
        for model_extension in model_extensions:
            if os.path.exists(default_model_path + model_extension):
                return default_model_path + model_extension

    # Can't find requested model, check the default paths.
    for default_model in default_models:
        for model_dir in model_dirs:
            default_model_path = os.path.join(model_dir, default_model)
            for model_extension in model_extensions:
                if os.path.exists(default_model_path + model_extension):
                    if model_name is not None:
                        log.warn(f'Could not find the configured custom model {model_name}{model_extension}. Using the default one: {default_model_path}{model_extension}')
                    return default_model_path + model_extension

    return None

def reload_models_if_necessary(context: Context, task_data: TaskData):
    model_paths_in_req = {
        'stable-diffusion': task_data.use_stable_diffusion_model,
        'vae': task_data.use_vae_model,
        'hypernetwork': task_data.use_hypernetwork_model,
        'gfpgan': task_data.use_face_correction,
        'realesrgan': task_data.use_upscale,
    }
    models_to_reload = {model_type: path for model_type, path in model_paths_in_req.items() if context.model_paths.get(model_type) != path}

    if set_vram_optimizations(context):  # reload SD
        models_to_reload['stable-diffusion'] = model_paths_in_req['stable-diffusion']

    if 'stable-diffusion' in models_to_reload:
        quick_hash = hash_file_quick(models_to_reload['stable-diffusion'])
        known_model_info = get_model_info_from_db(quick_hash=quick_hash)

    for model_type, model_path_in_req in models_to_reload.items():
        context.model_paths[model_type] = model_path_in_req

        action_fn = unload_model if context.model_paths[model_type] is None else load_model
        action_fn(context, model_type, scan_model=False)  # we've scanned them already

def resolve_model_paths(task_data: TaskData):
    task_data.use_stable_diffusion_model = resolve_model_to_use(task_data.use_stable_diffusion_model, model_type='stable-diffusion')
    task_data.use_vae_model = resolve_model_to_use(task_data.use_vae_model, model_type='vae')
    task_data.use_hypernetwork_model = resolve_model_to_use(task_data.use_hypernetwork_model, model_type='hypernetwork')

    if task_data.use_face_correction: task_data.use_face_correction = resolve_model_to_use(task_data.use_face_correction, 'gfpgan')
    if task_data.use_upscale: task_data.use_upscale = resolve_model_to_use(task_data.use_upscale, 'realesrgan')

def set_vram_optimizations(context: Context):
    config = app.getConfig()

    max_usage_level = device_manager.get_max_vram_usage_level(context.device)
    vram_usage_level = config.get('vram_usage_level', 'balanced')

    v = {'low': 0, 'balanced': 1, 'high': 2}
    if v[vram_usage_level] > v[max_usage_level]:
        log.error(f'Requested GPU Memory Usage level ({vram_usage_level}) is higher than what is ' + \
                  f'possible ({max_usage_level}) on this device ({context.device}). Using "{max_usage_level}" instead')
        vram_usage_level = max_usage_level

    vram_optimizations = VRAM_USAGE_LEVEL_TO_OPTIMIZATIONS[vram_usage_level]

    if vram_optimizations != context.vram_optimizations:
        context.vram_optimizations = vram_optimizations
        return True

    return False

def make_model_folders():
    for model_type in KNOWN_MODEL_TYPES:
        model_dir_path = os.path.join(app.MODELS_DIR, model_type)

        os.makedirs(model_dir_path, exist_ok=True)

        help_file_name = f'Place your {model_type} model files here.txt'
        help_file_contents = f'Supported extensions: {" or ".join(MODEL_EXTENSIONS.get(model_type))}'

        with open(os.path.join(model_dir_path, help_file_name), 'w', encoding='utf-8') as f:
            f.write(help_file_contents)

def is_malicious_model(file_path):
    try:
        scan_result = scan_model(file_path)
        if scan_result.issues_count > 0 or scan_result.infected_files > 0:
            log.warn(":warning: [bold red]Scan %s: %d scanned, %d issue, %d infected.[/bold red]" % (file_path, scan_result.scanned_files, scan_result.issues_count, scan_result.infected_files))
            return True
        else:
            log.debug("Scan %s: [green]%d scanned, %d issue, %d infected.[/green]" % (file_path, scan_result.scanned_files, scan_result.issues_count, scan_result.infected_files))
            return False
    except Exception as e:
        log.error(f'error while scanning: {file_path}, error: {e}')
    return False

def getModels():
    models = {
        'active': {
            'stable-diffusion': 'sd-v1-4',
            'vae': '',
            'hypernetwork': '',
        },
        'options': {
            'stable-diffusion': ['sd-v1-4'],
            'vae': [],
            'hypernetwork': [],
        },
    }

    models_scanned = 0

    def listModels(model_type):
        nonlocal models_scanned

        model_extensions = MODEL_EXTENSIONS.get(model_type, [])
        models_dir = os.path.join(app.MODELS_DIR, model_type)
        if not os.path.exists(models_dir):
            os.makedirs(models_dir)

        for file in os.listdir(models_dir):
            for model_extension in model_extensions:
                if not file.endswith(model_extension):
                    continue

                model_path = os.path.join(models_dir, file)
                mtime = os.path.getmtime(model_path)
                mod_time = known_models[model_path] if model_path in known_models else -1
                if mod_time != mtime:
                    models_scanned += 1
                    if is_malicious_model(model_path):
                        models['scan-error'] = file
                        return
                known_models[model_path] = mtime

                model_name = file[:-len(model_extension)]
                models['options'][model_type].append(model_name)

        models['options'][model_type] = [*set(models['options'][model_type])]  # remove duplicates
        models['options'][model_type].sort()

    # custom models
    listModels(model_type='stable-diffusion')
    listModels(model_type='vae')
    listModels(model_type='hypernetwork')

    if models_scanned > 0: log.info(f'[green]Scanned {models_scanned} models. Nothing infected[/]')

    # legacy
    custom_weight_path = os.path.join(app.SD_DIR, 'custom-model.ckpt')
    if os.path.exists(custom_weight_path):
        models['options']['stable-diffusion'].append('custom-model')

    return models
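For example, resolving a model file for a request (a sketch; the returned path depends on which files actually exist in your install):

```python
from easydiffusion import model_manager

# e.g. '<models>/stable-diffusion/sd-v1-4.ckpt' or '.safetensors', else a default fallback
path = model_manager.resolve_model_to_use('sd-v1-4', model_type='stable-diffusion')
```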

ui/easydiffusion/renderer.py (new file, 124 lines):

@@ -0,0 +1,124 @@
import queue
import time
import json
from easydiffusion import device_manager
from easydiffusion.types import TaskData, Response, Image as ResponseImage, UserInitiatedStop, GenerateImageRequest
from easydiffusion.utils import get_printable_request, save_images_to_disk, log
from sdkit import Context
from sdkit.generate import generate_images
from sdkit.filter import apply_filters
from sdkit.utils import img_to_buffer, img_to_base64_str, latent_samples_to_images, gc
context = Context() # thread-local
'''
Runtime data bound locally to this thread: e.g. the device, references to loaded models, optimization flags, etc.
'''
def init(device):
'''
Initializes the fields that will be bound to this runtime's context, and sets the current torch device
'''
context.stop_processing = False
context.temp_images = {}
context.partial_x_samples = None
device_manager.device_init(context, device)
def make_images(req: GenerateImageRequest, task_data: TaskData, data_queue: queue.Queue, task_temp_images: list, step_callback):
context.stop_processing = False
log.info(f'request: {get_printable_request(req)}')
log.info(f'task data: {task_data.dict()}')
images = make_images_internal(req, task_data, data_queue, task_temp_images, step_callback)
res = Response(req, task_data, images=construct_response(images, task_data, base_seed=req.seed))
res = res.json()
data_queue.put(json.dumps(res))
log.info('Task completed')
return res
def make_images_internal(req: GenerateImageRequest, task_data: TaskData, data_queue: queue.Queue, task_temp_images: list, step_callback):
images, user_stopped = generate_images_internal(req, task_data, data_queue, task_temp_images, step_callback, task_data.stream_image_progress)
filtered_images = filter_images(task_data, images, user_stopped)
if task_data.save_to_disk_path is not None:
save_images_to_disk(images, filtered_images, req, task_data)
return filtered_images if task_data.show_only_filtered_image else images + filtered_images
def generate_images_internal(req: GenerateImageRequest, task_data: TaskData, data_queue: queue.Queue, task_temp_images: list, step_callback, stream_image_progress: bool):
context.temp_images.clear()
callback = make_step_callback(req, task_data, data_queue, task_temp_images, step_callback, stream_image_progress)
try:
images = generate_images(context, callback=callback, **req.dict())
user_stopped = False
except UserInitiatedStop:
images = []
user_stopped = True
if context.partial_x_samples is not None:
images = latent_samples_to_images(context, context.partial_x_samples)
context.partial_x_samples = None
finally:
gc(context)
return images, user_stopped
def filter_images(task_data: TaskData, images: list, user_stopped):
if user_stopped or (task_data.use_face_correction is None and task_data.use_upscale is None):
return images
filters_to_apply = []
if task_data.use_face_correction and 'gfpgan' in task_data.use_face_correction.lower(): filters_to_apply.append('gfpgan')
if task_data.use_upscale and 'realesrgan' in task_data.use_upscale.lower(): filters_to_apply.append('realesrgan')
return apply_filters(context, filters_to_apply, images)
def construct_response(images: list, task_data: TaskData, base_seed: int):
return [
ResponseImage(
data=img_to_base64_str(img, task_data.output_format, task_data.output_quality),
seed=base_seed + i
) for i, img in enumerate(images)
]
def make_step_callback(req: GenerateImageRequest, task_data: TaskData, data_queue: queue.Queue, task_temp_images: list, step_callback, stream_image_progress: bool):
n_steps = req.num_inference_steps if req.init_image is None else int(req.num_inference_steps * req.prompt_strength)
last_callback_time = -1
def update_temp_img(x_samples, task_temp_images: list):
partial_images = []
images = latent_samples_to_images(context, x_samples)
for i, img in enumerate(images):
buf = img_to_buffer(img, output_format='JPEG')
context.temp_images[f"{task_data.request_id}/{i}"] = buf
task_temp_images[i] = buf
partial_images.append({'path': f"/image/tmp/{task_data.request_id}/{i}"})
del images
return partial_images
def on_image_step(x_samples, i):
nonlocal last_callback_time
context.partial_x_samples = x_samples
step_time = time.time() - last_callback_time if last_callback_time != -1 else -1
last_callback_time = time.time()
progress = {"step": i, "step_time": step_time, "total_steps": n_steps}
if stream_image_progress and i % 5 == 0:
progress['output'] = update_temp_img(x_samples, task_temp_images)
data_queue.put(json.dumps(progress))
step_callback()
if context.stop_processing:
raise UserInitiatedStop("User requested that we stop processing")
return on_image_step
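
Taken together, a render thread drives this module roughly as follows. A minimal sketch; the device id and request values are illustrative, and the real wiring lives in task_manager.thread_render():

# Sketch only: single-threaded usage of the renderer module.
import queue
from easydiffusion import renderer, model_manager
from easydiffusion.types import GenerateImageRequest, TaskData

renderer.init('cuda:0')                               # bind this thread's context to a device (illustrative id)
model_manager.load_default_models(renderer.context)   # load the default SD model into the context
req = GenerateImageRequest(prompt='an astronaut riding a horse', num_inference_steps=25)
task_data = TaskData(session_id='demo')
buffer = queue.Queue()
temp_images = [None] * req.num_outputs * 2            # mirrors RenderTask's buffer sizing
renderer.make_images(req, task_data, buffer, temp_images, step_callback=lambda: None)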

ui/easydiffusion/server.py Normal file

@ -0,0 +1,219 @@
"""server.py: FastAPI SD-UI Web Host.
Notes:
async endpoints always run on the main thread. Without async, they run on the thread pool.
"""
import os
import traceback
import datetime
from typing import List, Union
from fastapi import FastAPI, HTTPException
from fastapi.staticfiles import StaticFiles
from starlette.responses import FileResponse, JSONResponse, StreamingResponse
from pydantic import BaseModel
from easydiffusion import app, model_manager, task_manager
from easydiffusion.types import TaskData, GenerateImageRequest
from easydiffusion.utils import log
log.info(f'started in {app.SD_DIR}')
log.info(f'started at {datetime.datetime.now():%x %X}')
server_api = FastAPI()
NOCACHE_HEADERS={"Cache-Control": "no-cache, no-store, must-revalidate", "Pragma": "no-cache", "Expires": "0"}
class NoCacheStaticFiles(StaticFiles):
def is_not_modified(self, response_headers, request_headers) -> bool:
if 'content-type' in response_headers and ('javascript' in response_headers['content-type'] or 'css' in response_headers['content-type']):
response_headers.update(NOCACHE_HEADERS)
return False
return super().is_not_modified(response_headers, request_headers)
class SetAppConfigRequest(BaseModel):
update_branch: str = None
render_devices: Union[List[str], List[int], str, int] = None
model_vae: str = None
ui_open_browser_on_start: bool = None
listen_to_network: bool = None
listen_port: int = None
def init():
server_api.mount('/media', NoCacheStaticFiles(directory=os.path.join(app.SD_UI_DIR, 'media')), name="media")
for plugins_dir, dir_prefix in app.UI_PLUGINS_SOURCES:
server_api.mount(f'/plugins/{dir_prefix}', NoCacheStaticFiles(directory=plugins_dir), name=f"plugins-{dir_prefix}")
@server_api.post('/app_config')
async def set_app_config(req : SetAppConfigRequest):
return set_app_config_internal(req)
@server_api.get('/get/{key:path}')
def read_web_data(key:str=None):
return read_web_data_internal(key)
@server_api.get('/ping') # Get server and optionally session status.
def ping(session_id:str=None):
return ping_internal(session_id)
@server_api.post('/render')
def render(req: dict):
return render_internal(req)
@server_api.get('/image/stream/{task_id:int}')
def stream(task_id:int):
return stream_internal(task_id)
@server_api.get('/image/stop')
def stop(task: int):
return stop_internal(task)
@server_api.get('/image/tmp/{task_id:int}/{img_id:int}')
def get_image(task_id: int, img_id: int):
return get_image_internal(task_id, img_id)
@server_api.get('/')
def read_root():
return FileResponse(os.path.join(app.SD_UI_DIR, 'index.html'), headers=NOCACHE_HEADERS)
@server_api.on_event("shutdown")
def shutdown_event(): # Signal render thread to close on shutdown
task_manager.current_state_error = SystemExit('Application shutting down.')
# API implementations
def set_app_config_internal(req : SetAppConfigRequest):
config = app.getConfig()
if req.update_branch is not None:
config['update_branch'] = req.update_branch
if req.render_devices is not None:
update_render_devices_in_config(config, req.render_devices)
if req.ui_open_browser_on_start is not None:
if 'ui' not in config:
config['ui'] = {}
config['ui']['open_browser_on_start'] = req.ui_open_browser_on_start
if req.listen_to_network is not None:
if 'net' not in config:
config['net'] = {}
config['net']['listen_to_network'] = bool(req.listen_to_network)
if req.listen_port is not None:
if 'net' not in config:
config['net'] = {}
config['net']['listen_port'] = int(req.listen_port)
try:
app.setConfig(config)
if req.render_devices:
app.update_render_threads()
return JSONResponse({'status': 'OK'}, headers=NOCACHE_HEADERS)
except Exception as e:
log.error(traceback.format_exc())
raise HTTPException(status_code=500, detail=str(e))
def update_render_devices_in_config(config, render_devices):
if render_devices not in ('cpu', 'auto') and not render_devices.startswith('cuda:'):
raise HTTPException(status_code=400, detail=f'Invalid render device requested: {render_devices}')
if render_devices.startswith('cuda:'):
render_devices = render_devices.split(',')
config['render_devices'] = render_devices
def read_web_data_internal(key:str=None):
if not key: # /get without parameters, stable-diffusion easter egg.
raise HTTPException(status_code=418, detail="StableDiffusion is drawing a teapot!") # HTTP418 I'm a teapot
elif key == 'app_config':
return JSONResponse(app.getConfig(), headers=NOCACHE_HEADERS)
elif key == 'system_info':
config = app.getConfig()
system_info = {
'devices': task_manager.get_devices(),
'hosts': app.getIPConfig(),
'default_output_dir': os.path.join(os.path.expanduser("~"), app.OUTPUT_DIRNAME),
}
system_info['devices']['config'] = config.get('render_devices', "auto")
return JSONResponse(system_info, headers=NOCACHE_HEADERS)
elif key == 'models':
return JSONResponse(model_manager.getModels(), headers=NOCACHE_HEADERS)
elif key == 'modifiers': return FileResponse(os.path.join(app.SD_UI_DIR, 'modifiers.json'), headers=NOCACHE_HEADERS)
elif key == 'ui_plugins': return JSONResponse(app.getUIPlugins(), headers=NOCACHE_HEADERS)
else:
raise HTTPException(status_code=404, detail=f'Request for unknown {key}') # HTTP404 Not Found
def ping_internal(session_id:str=None):
if task_manager.is_alive() <= 0: # Check that render threads are alive.
if task_manager.current_state_error: raise HTTPException(status_code=500, detail=str(task_manager.current_state_error))
raise HTTPException(status_code=500, detail='Render thread is dead.')
if task_manager.current_state_error and not isinstance(task_manager.current_state_error, StopAsyncIteration): raise HTTPException(status_code=500, detail=str(task_manager.current_state_error))
# Alive
response = {'status': str(task_manager.current_state)}
if session_id:
session = task_manager.get_cached_session(session_id, update_ttl=True)
response['tasks'] = {id(t): t.status for t in session.tasks}
response['devices'] = task_manager.get_devices()
return JSONResponse(response, headers=NOCACHE_HEADERS)
def render_internal(req: dict):
try:
# separate out the request data into rendering and task-specific data
render_req: GenerateImageRequest = GenerateImageRequest.parse_obj(req)
task_data: TaskData = TaskData.parse_obj(req)
render_req.init_image_mask = req.get('mask') # hack: will rename this in the HTTP API in a future revision
app.save_to_config(task_data.use_stable_diffusion_model, task_data.use_vae_model, task_data.use_hypernetwork_model, task_data.vram_usage_level)
# enqueue the task
new_task = task_manager.render(render_req, task_data)
response = {
'status': str(task_manager.current_state),
'queue': len(task_manager.tasks_queue),
'stream': f'/image/stream/{id(new_task)}',
'task': id(new_task)
}
return JSONResponse(response, headers=NOCACHE_HEADERS)
except ChildProcessError as e: # Render thread is dead
raise HTTPException(status_code=500, detail=f'Rendering thread has died.') # HTTP500 Internal Server Error
except ConnectionRefusedError as e: # Unstarted task pending limit reached, deny queueing too many.
raise HTTPException(status_code=503, detail=str(e)) # HTTP503 Service Unavailable
except Exception as e:
log.error(traceback.format_exc())
raise HTTPException(status_code=500, detail=str(e))
def stream_internal(task_id:int):
#TODO Move to WebSockets ??
task = task_manager.get_cached_task(task_id, update_ttl=True)
if not task: raise HTTPException(status_code=404, detail=f'Request {task_id} not found.') # HTTP404 NotFound
#if (id(task) != task_id): raise HTTPException(status_code=409, detail=f'Wrong task id received. Expected:{id(task)}, Received:{task_id}') # HTTP409 Conflict
if task.buffer_queue.empty() and not task.lock.locked():
if task.response:
#log.info(f'Session {session_id} sending cached response')
return JSONResponse(task.response, headers=NOCACHE_HEADERS)
raise HTTPException(status_code=425, detail='Too Early, task not started yet.') # HTTP425 Too Early
#log.info(f'Session {session_id} opened live render stream {id(task.buffer_queue)}')
return StreamingResponse(task.read_buffer_generator(), media_type='application/json')
def stop_internal(task: int):
if not task:
if task_manager.current_state == task_manager.ServerStates.Online or task_manager.current_state == task_manager.ServerStates.Unavailable:
raise HTTPException(status_code=409, detail='Not currently running any tasks.') # HTTP409 Conflict
task_manager.current_state_error = StopAsyncIteration('')
return {'OK'}
task_id = task
task = task_manager.get_cached_task(task_id, update_ttl=False)
if not task: raise HTTPException(status_code=404, detail=f'Task {task_id} was not found.') # HTTP404 Not Found
if isinstance(task.error, StopAsyncIteration): raise HTTPException(status_code=409, detail=f'Task {task_id} is already stopped.') # HTTP409 Conflict
task.error = StopAsyncIteration(f'Task {task_id} stop requested.')
return {'OK'}
def get_image_internal(task_id: int, img_id: int):
task = task_manager.get_cached_task(task_id, update_ttl=True)
if not task: raise HTTPException(status_code=410, detail=f'Task {task_id} could not be found.') # HTTP410 Gone
if not task.temp_images[img_id]: raise HTTPException(status_code=425, detail='Too Early, task data is not available yet.') # HTTP425 Too Early
try:
img_data = task.temp_images[img_id]
img_data.seek(0)
return StreamingResponse(img_data, media_type='image/jpeg')
except KeyError as e:
raise HTTPException(status_code=500, detail=str(e))
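
For reference, the request/stream cycle these endpoints implement looks roughly like this from a client. A hedged sketch: the host and port are assumptions, while the paths and response fields ('stream', 'task') come from render_internal above.

# Sketch only: a minimal HTTP client for the /render + /image/stream cycle.
import requests

res = requests.post('http://localhost:9000/render', json={
    'prompt': 'a watercolor fox',
    'width': 512,
    'height': 512,
    'session_id': 'demo',
}).json()
# res['stream'] is '/image/stream/<task id>'; read it until the final response arrives
with requests.get(f"http://localhost:9000{res['stream']}", stream=True) as stream:
    for chunk in stream.iter_content(chunk_size=None):
        print(chunk.decode(), end='')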

ui/easydiffusion/task_manager.py

@ -11,12 +11,13 @@ TASK_TTL = 15 * 60 # seconds, Discard last session's task timeout
 import torch
 import queue, threading, time, weakref
-from typing import Any, Generator, Hashable, Optional, Union
-from pydantic import BaseModel
-from sd_internal import Request, Response, runtime, device_manager
+from typing import Any, Hashable
+from easydiffusion import device_manager
+from easydiffusion.types import TaskData, GenerateImageRequest
+from easydiffusion.utils import log
-THREAD_NAME_PREFIX = 'Runtime-Render/'
+THREAD_NAME_PREFIX = ''
 ERR_LOCK_FAILED = ' failed to acquire lock within timeout.'
 LOCK_TIMEOUT = 15 # Maximum locking time in seconds before failing a task.
 # It's better to get an exception than a deadlock... ALWAYS use timeout in critical paths.
@ -36,12 +37,13 @@ class ServerStates:
     class Unavailable(Symbol): pass

 class RenderTask(): # Task with output queue and completion lock.
-    def __init__(self, req: Request):
-        req.request_id = id(self)
-        self.request: Request = req # Initial Request
+    def __init__(self, req: GenerateImageRequest, task_data: TaskData):
+        task_data.request_id = id(self)
+        self.render_request: GenerateImageRequest = req # Initial Request
+        self.task_data: TaskData = task_data
         self.response: Any = None # Copy of the last reponse
         self.render_device = None # Select the task affinity. (Not used to change active devices).
-        self.temp_images:list = [None] * req.num_outputs * (1 if req.show_only_filtered_image else 2)
+        self.temp_images:list = [None] * req.num_outputs * (1 if task_data.show_only_filtered_image else 2)
         self.error: Exception = None
         self.lock: threading.Lock = threading.Lock() # Locks at task start and unlocks when task is completed
         self.buffer_queue: queue.Queue = queue.Queue() # Queue of JSON string segments
@ -69,54 +71,6 @@ class RenderTask(): # Task with output queue and completion lock.
     def is_pending(self):
         return bool(not self.response and not self.error)

-# defaults from https://huggingface.co/blog/stable_diffusion
-class ImageRequest(BaseModel):
-    session_id: str = "session"
-    prompt: str = ""
-    negative_prompt: str = ""
-    init_image: str = None # base64
-    mask: str = None # base64
-    num_outputs: int = 1
-    num_inference_steps: int = 50
-    guidance_scale: float = 7.5
-    width: int = 512
-    height: int = 512
-    seed: int = 42
-    prompt_strength: float = 0.8
-    sampler: str = None # "ddim", "plms", "heun", "euler", "euler_a", "dpm2", "dpm2_a", "lms"
-    # allow_nsfw: bool = False
-    save_to_disk_path: str = None
-    turbo: bool = True
-    use_cpu: bool = False ##TODO Remove after UI and plugins transition.
-    render_device: str = None # Select the task affinity. (Not used to change active devices).
-    use_full_precision: bool = False
-    use_face_correction: str = None # or "GFPGANv1.3"
-    use_upscale: str = None # or "RealESRGAN_x4plus" or "RealESRGAN_x4plus_anime_6B"
-    use_stable_diffusion_model: str = "sd-v1-4"
-    use_vae_model: str = None
-    use_hypernetwork_model: str = None
-    hypernetwork_strength: float = None
-    show_only_filtered_image: bool = False
-    output_format: str = "jpeg" # or "png"
-    output_quality: int = 75
-    stream_progress_updates: bool = False
-    stream_image_progress: bool = False
-
-class FilterRequest(BaseModel):
-    session_id: str = "session"
-    model: str = None
-    name: str = ""
-    init_image: str = None # base64
-    width: int = 512
-    height: int = 512
-    save_to_disk_path: str = None
-    turbo: bool = True
-    render_device: str = None
-    use_full_precision: bool = False
-    output_format: str = "jpeg" # or "png"
-    output_quality: int = 75

 # Temporary cache to allow to query tasks results for a short time after they are completed.
 class DataCache():
     def __init__(self):
@ -139,11 +93,11 @@ class DataCache():
             for key in to_delete:
                 (_, val) = self._base[key]
                 if isinstance(val, RenderTask):
-                    print(f'RenderTask {key} expired. Data removed.')
+                    log.debug(f'RenderTask {key} expired. Data removed.')
                 elif isinstance(val, SessionState):
-                    print(f'Session {key} expired. Data removed.')
+                    log.debug(f'Session {key} expired. Data removed.')
                 else:
-                    print(f'Key {key} expired. Data removed.')
+                    log.debug(f'Key {key} expired. Data removed.')
                 del self._base[key]
         finally:
             self._lock.release()
@ -177,8 +131,7 @@ class DataCache():
                     self._get_ttl_time(ttl), value
                 )
             except Exception as e:
-                print(str(e))
-                print(traceback.format_exc())
+                log.error(traceback.format_exc())
                 return False
             else:
                 return True
@ -189,7 +142,7 @@ class DataCache():
         try:
             ttl, value = self._base.get(key, (None, None))
             if ttl is not None and self._is_expired(ttl):
-                print(f'Session {key} expired. Discarding data.')
+                log.debug(f'Session {key} expired. Discarding data.')
                 del self._base[key]
                 return None
             return value
@ -200,15 +153,9 @@ manager_lock = threading.RLock()
 render_threads = []
 current_state = ServerStates.Init
 current_state_error:Exception = None
-current_model_path = None
-current_vae_path = None
-current_hypernetwork_path = None
 tasks_queue = []
 session_cache = DataCache()
 task_cache = DataCache()
-default_model_to_load = None
-default_vae_to_load = None
-default_hypernetwork_to_load = None
 weak_thread_data = weakref.WeakKeyDictionary()
 idle_event: threading.Event = threading.Event()
@ -236,40 +183,10 @@ class SessionState():
             self._tasks_ids.pop(0)
         return True

-def preload_model(ckpt_file_path=None, vae_file_path=None, hypernetwork_file_path=None):
-    global current_state, current_state_error, current_model_path, current_vae_path, current_hypernetwork_path
-    if ckpt_file_path == None:
-        ckpt_file_path = default_model_to_load
-    if vae_file_path == None:
-        vae_file_path = default_vae_to_load
-    if hypernetwork_file_path == None:
-        hypernetwork_file_path = default_hypernetwork_to_load
-    if ckpt_file_path == current_model_path and vae_file_path == current_vae_path:
-        return
-    current_state = ServerStates.LoadingModel
-    try:
-        from . import runtime
-        runtime.thread_data.hypernetwork_file = hypernetwork_file_path
-        runtime.thread_data.ckpt_file = ckpt_file_path
-        runtime.thread_data.vae_file = vae_file_path
-        runtime.load_model_ckpt()
-        runtime.load_hypernetwork()
-        current_model_path = ckpt_file_path
-        current_vae_path = vae_file_path
-        current_hypernetwork_path = hypernetwork_file_path
-        current_state_error = None
-        current_state = ServerStates.Online
-    except Exception as e:
-        current_model_path = None
-        current_vae_path = None
-        current_state_error = e
-        current_state = ServerStates.Unavailable
-        print(traceback.format_exc())

 def thread_get_next_task():
-    from . import runtime
+    from easydiffusion import renderer
     if not manager_lock.acquire(blocking=True, timeout=LOCK_TIMEOUT):
-        print('Render thread on device', runtime.thread_data.device, 'failed to acquire manager lock.')
+        log.warn(f'Render thread on device: {renderer.context.device} failed to acquire manager lock.')
         return None
     if len(tasks_queue) <= 0:
         manager_lock.release()
@ -277,7 +194,7 @@ def thread_get_next_task():
     task = None
     try: # Select a render task.
         for queued_task in tasks_queue:
-            if queued_task.render_device and runtime.thread_data.device != queued_task.render_device:
+            if queued_task.render_device and renderer.context.device != queued_task.render_device:
                 # Is asking for a specific render device.
                 if is_alive(queued_task.render_device) > 0:
                     continue # requested device alive, skip current one.
@ -286,7 +203,7 @@ def thread_get_next_task():
                 queued_task.error = Exception(queued_task.render_device + ' is not currently active.')
                 task = queued_task
                 break
-            if not queued_task.render_device and runtime.thread_data.device == 'cpu' and is_alive() > 1:
+            if not queued_task.render_device and renderer.context.device == 'cpu' and is_alive() > 1:
                 # not asking for any specific devices, cpu want to grab task but other render devices are alive.
                 continue # Skip Tasks, don't run on CPU unless there is nothing else or user asked for it.
             task = queued_task
@ -298,31 +215,36 @@ def thread_get_next_task():
         manager_lock.release()

 def thread_render(device):
-    global current_state, current_state_error, current_model_path, current_vae_path, current_hypernetwork_path
-    from . import runtime
+    global current_state, current_state_error
+    from easydiffusion import renderer, model_manager
     try:
-        runtime.thread_init(device)
-    except Exception as e:
-        print(traceback.format_exc())
-        weak_thread_data[threading.current_thread()] = {
-            'error': e
-        }
-        return
-    weak_thread_data[threading.current_thread()] = {
-        'device': runtime.thread_data.device,
-        'device_name': runtime.thread_data.device_name,
-        'alive': True
-    }
-    if runtime.thread_data.device != 'cpu' or is_alive() == 1:
-        preload_model()
-        current_state = ServerStates.Online
+        renderer.init(device)
+        weak_thread_data[threading.current_thread()] = {
+            'device': renderer.context.device,
+            'device_name': renderer.context.device_name,
+            'alive': True
+        }
+        current_state = ServerStates.LoadingModel
+        model_manager.load_default_models(renderer.context)
+        current_state = ServerStates.Online
+    except Exception as e:
+        log.error(traceback.format_exc())
+        weak_thread_data[threading.current_thread()] = {
+            'error': e,
+            'alive': False
+        }
+        return
     while True:
         session_cache.clean()
         task_cache.clean()
         if not weak_thread_data[threading.current_thread()]['alive']:
-            print(f'Shutting down thread for device {runtime.thread_data.device}')
-            runtime.unload_models()
-            runtime.unload_filters()
+            log.info(f'Shutting down thread for device {renderer.context.device}')
+            model_manager.unload_all(renderer.context)
             return
         if isinstance(current_state_error, SystemExit):
             current_state = ServerStates.Unavailable
@ -333,7 +255,7 @@ def thread_render(device):
             idle_event.wait(timeout=1)
             continue
         if task.error is not None:
-            print(task.error)
+            log.error(task.error)
             task.response = {"status": 'failed', "detail": str(task.error)}
             task.buffer_queue.put(json.dumps(task.response))
             continue
@ -342,51 +264,45 @@ def thread_render(device):
             task.response = {"status": 'failed', "detail": str(task.error)}
             task.buffer_queue.put(json.dumps(task.response))
             continue
-        print(f'Session {task.request.session_id} starting task {id(task)} on {runtime.thread_data.device_name}')
+        log.info(f'Session {task.task_data.session_id} starting task {id(task)} on {renderer.context.device_name}')
         if not task.lock.acquire(blocking=False): raise Exception('Got locked task from queue.')
         try:
-            if runtime.is_hypernetwork_reload_necessary(task.request):
-                runtime.reload_hypernetwork()
-                current_hypernetwork_path = task.request.use_hypernetwork_model
-            if runtime.is_model_reload_necessary(task.request):
-                current_state = ServerStates.LoadingModel
-                runtime.reload_model()
-                current_model_path = task.request.use_stable_diffusion_model
-                current_vae_path = task.request.use_vae_model
             def step_callback():
                 global current_state_error
                 if isinstance(current_state_error, SystemExit) or isinstance(current_state_error, StopAsyncIteration) or isinstance(task.error, StopAsyncIteration):
-                    runtime.thread_data.stop_processing = True
+                    renderer.context.stop_processing = True
                     if isinstance(current_state_error, StopAsyncIteration):
                         task.error = current_state_error
                         current_state_error = None
-                        print(f'Session {task.request.session_id} sent cancel signal for task {id(task)}')
+                        log.info(f'Session {task.task_data.session_id} sent cancel signal for task {id(task)}')
+            current_state = ServerStates.LoadingModel
+            model_manager.resolve_model_paths(task.task_data)
+            model_manager.reload_models_if_necessary(renderer.context, task.task_data)
             current_state = ServerStates.Rendering
-            task.response = runtime.mk_img(task.request, task.buffer_queue, task.temp_images, step_callback)
+            task.response = renderer.make_images(task.render_request, task.task_data, task.buffer_queue, task.temp_images, step_callback)
             # Before looping back to the generator, mark cache as still alive.
             task_cache.keep(id(task), TASK_TTL)
-            session_cache.keep(task.request.session_id, TASK_TTL)
+            session_cache.keep(task.task_data.session_id, TASK_TTL)
         except Exception as e:
             task.error = e
             task.response = {"status": 'failed', "detail": str(task.error)}
             task.buffer_queue.put(json.dumps(task.response))
-            print(traceback.format_exc())
+            log.error(traceback.format_exc())
             continue
         finally:
             # Task completed
             task.lock.release()
         task_cache.keep(id(task), TASK_TTL)
-        session_cache.keep(task.request.session_id, TASK_TTL)
+        session_cache.keep(task.task_data.session_id, TASK_TTL)
         if isinstance(task.error, StopAsyncIteration):
-            print(f'Session {task.request.session_id} task {id(task)} cancelled!')
+            log.info(f'Session {task.task_data.session_id} task {id(task)} cancelled!')
         elif task.error is not None:
-            print(f'Session {task.request.session_id} task {id(task)} failed!')
+            log.info(f'Session {task.task_data.session_id} task {id(task)} failed!')
         else:
-            print(f'Session {task.request.session_id} task {id(task)} completed by {runtime.thread_data.device_name}.')
+            log.info(f'Session {task.task_data.session_id} task {id(task)} completed by {renderer.context.device_name}.')
         current_state = ServerStates.Online

 def get_cached_task(task_id:str, update_ttl:bool=False):
@ -423,6 +339,7 @@ def get_devices():
                 'name': torch.cuda.get_device_name(device),
                 'mem_free': mem_free,
                 'mem_total': mem_total,
+                'max_vram_usage_level': device_manager.get_max_vram_usage_level(device),
             }
     # list the compatible devices
@ -472,7 +389,7 @@ def is_alive(device=None):
 def start_render_thread(device):
     if not manager_lock.acquire(blocking=True, timeout=LOCK_TIMEOUT): raise Exception('start_render_thread' + ERR_LOCK_FAILED)
-    print('Start new Rendering Thread on device', device)
+    log.info(f'Start new Rendering Thread on device: {device}')
     try:
         rthread = threading.Thread(target=thread_render, kwargs={'device': device})
         rthread.daemon = True
@ -484,7 +401,7 @@ def start_render_thread(device):
     timeout = DEVICE_START_TIMEOUT
     while not rthread.is_alive() or not rthread in weak_thread_data or not 'device' in weak_thread_data[rthread]:
         if rthread in weak_thread_data and 'error' in weak_thread_data[rthread]:
-            print(rthread, device, 'error:', weak_thread_data[rthread]['error'])
+            log.error(f"{rthread}, {device}, error: {weak_thread_data[rthread]['error']}")
             return False
         if timeout <= 0:
             return False
@ -496,11 +413,11 @@ def stop_render_thread(device):
     try:
         device_manager.validate_device_id(device, log_prefix='stop_render_thread')
     except:
-        print(traceback.format_exc())
+        log.error(traceback.format_exc())
         return False
     if not manager_lock.acquire(blocking=True, timeout=LOCK_TIMEOUT): raise Exception('stop_render_thread' + ERR_LOCK_FAILED)
-    print('Stopping Rendering Thread on device', device)
+    log.info(f'Stopping Rendering Thread on device: {device}')
     try:
         thread_to_remove = None
@ -523,79 +440,44 @@ def stop_render_thread(device):
 def update_render_threads(render_devices, active_devices):
     devices_to_start, devices_to_stop = device_manager.get_device_delta(render_devices, active_devices)
-    print('devices_to_start', devices_to_start)
-    print('devices_to_stop', devices_to_stop)
+    log.debug(f'devices_to_start: {devices_to_start}')
+    log.debug(f'devices_to_stop: {devices_to_stop}')
     for device in devices_to_stop:
         if is_alive(device) <= 0:
-            print(device, 'is not alive')
+            log.debug(f'{device} is not alive')
             continue
         if not stop_render_thread(device):
-            print(device, 'could not stop render thread')
+            log.warn(f'{device} could not stop render thread')
     for device in devices_to_start:
         if is_alive(device) >= 1:
-            print(device, 'already registered.')
+            log.debug(f'{device} already registered.')
             continue
         if not start_render_thread(device):
-            print(device, 'failed to start.')
+            log.warn(f'{device} failed to start.')
     if is_alive() <= 0: # No running devices, probably invalid user config.
         raise EnvironmentError('ERROR: No active render devices! Please verify the "render_devices" value in config.json')
-    print('active devices', get_devices()['active'])
+    log.debug(f"active devices: {get_devices()['active']}")

 def shutdown_event(): # Signal render thread to close on shutdown
     global current_state_error
     current_state_error = SystemExit('Application shutting down.')

-def render(req : ImageRequest):
+def render(render_req: GenerateImageRequest, task_data: TaskData):
     current_thread_count = is_alive()
     if current_thread_count <= 0: # Render thread is dead
         raise ChildProcessError('Rendering thread has died.')
     # Alive, check if task in cache
-    session = get_cached_session(req.session_id, update_ttl=True)
+    session = get_cached_session(task_data.session_id, update_ttl=True)
     pending_tasks = list(filter(lambda t: t.is_pending, session.tasks))
     if current_thread_count < len(pending_tasks):
-        raise ConnectionRefusedError(f'Session {req.session_id} already has {len(pending_tasks)} pending tasks out of {current_thread_count}.')
+        raise ConnectionRefusedError(f'Session {task_data.session_id} already has {len(pending_tasks)} pending tasks out of {current_thread_count}.')
-    from . import runtime
-    r = Request()
-    r.session_id = req.session_id
-    r.prompt = req.prompt
-    r.negative_prompt = req.negative_prompt
-    r.init_image = req.init_image
-    r.mask = req.mask
-    r.num_outputs = req.num_outputs
-    r.num_inference_steps = req.num_inference_steps
-    r.guidance_scale = req.guidance_scale
-    r.width = req.width
-    r.height = req.height
-    r.seed = req.seed
-    r.prompt_strength = req.prompt_strength
-    r.sampler = req.sampler
-    # r.allow_nsfw = req.allow_nsfw
-    r.turbo = req.turbo
-    r.use_full_precision = req.use_full_precision
-    r.save_to_disk_path = req.save_to_disk_path
-    r.use_upscale: str = req.use_upscale
-    r.use_face_correction = req.use_face_correction
-    r.use_stable_diffusion_model = req.use_stable_diffusion_model
-    r.use_vae_model = req.use_vae_model
-    r.use_hypernetwork_model = req.use_hypernetwork_model
-    r.hypernetwork_strength = req.hypernetwork_strength
-    r.show_only_filtered_image = req.show_only_filtered_image
-    r.output_format = req.output_format
-    r.output_quality = req.output_quality
-    r.stream_progress_updates = True # the underlying implementation only supports streaming
-    r.stream_image_progress = req.stream_image_progress
-    if not req.stream_progress_updates:
-        r.stream_image_progress = False
-    new_task = RenderTask(r)
+    new_task = RenderTask(render_req, task_data)
     if session.put(new_task, TASK_TTL):
         # Use twice the normal timeout for adding user requests.
         # Tries to force session.put to fail before tasks_queue.put would.

ui/easydiffusion/types.py Normal file

@ -0,0 +1,87 @@
from pydantic import BaseModel
from typing import Any
class GenerateImageRequest(BaseModel):
prompt: str = ""
negative_prompt: str = ""
seed: int = 42
width: int = 512
height: int = 512
num_outputs: int = 1
num_inference_steps: int = 50
guidance_scale: float = 7.5
init_image: Any = None
init_image_mask: Any = None
prompt_strength: float = 0.8
preserve_init_image_color_profile: bool = False
sampler_name: str = None # "ddim", "plms", "heun", "euler", "euler_a", "dpm2", "dpm2_a", "lms"
hypernetwork_strength: float = 0
class TaskData(BaseModel):
request_id: str = None
session_id: str = "session"
save_to_disk_path: str = None
vram_usage_level: str = "balanced" # or "low" or "medium"
use_face_correction: str = None # or "GFPGANv1.3"
use_upscale: str = None # or "RealESRGAN_x4plus" or "RealESRGAN_x4plus_anime_6B"
use_stable_diffusion_model: str = "sd-v1-4"
use_stable_diffusion_config: str = "v1-inference"
use_vae_model: str = None
use_hypernetwork_model: str = None
show_only_filtered_image: bool = False
output_format: str = "jpeg" # or "png"
output_quality: int = 75
metadata_output_format: str = "txt" # or "json"
stream_image_progress: bool = False
class Image:
data: str # base64
seed: int
is_nsfw: bool
path_abs: str = None
def __init__(self, data, seed):
self.data = data
self.seed = seed
def json(self):
return {
"data": self.data,
"seed": self.seed,
"path_abs": self.path_abs,
}
class Response:
render_request: GenerateImageRequest
task_data: TaskData
images: list
def __init__(self, render_request: GenerateImageRequest, task_data: TaskData, images: list):
self.render_request = render_request
self.task_data = task_data
self.images = images
def json(self):
del self.render_request.init_image
del self.render_request.init_image_mask
res = {
"status": 'succeeded',
"render_request": self.render_request.dict(),
"task_data": self.task_data.dict(),
"output": [],
}
for image in self.images:
res["output"].append(image.json())
return res
class UserInitiatedStop(Exception):
pass
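
Note that server.py's render_internal parses both models from the same request dict, relying on pydantic ignoring unknown fields by default. A small sketch, with illustrative values:

# Sketch: one incoming dict feeds both pydantic models, mirroring render_internal().
req = {'prompt': 'a lighthouse at dusk', 'seed': 7, 'use_upscale': 'RealESRGAN_x4plus'}
render_req = GenerateImageRequest.parse_obj(req)  # picks up prompt and seed, ignores use_upscale
task_data = TaskData.parse_obj(req)               # picks up use_upscale, ignores prompt and seed
assert render_req.seed == 7 and task_data.use_upscale == 'RealESRGAN_x4plus'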

ui/easydiffusion/utils/__init__.py Normal file

@ -0,0 +1,8 @@
import logging
log = logging.getLogger('easydiffusion')
from .save_utils import (
save_images_to_disk,
get_printable_request,
)

ui/easydiffusion/utils/save_utils.py Normal file

@ -0,0 +1,79 @@
import os
import time
import base64
import re
from easydiffusion.types import TaskData, GenerateImageRequest
from sdkit.utils import save_images, save_dicts
filename_regex = re.compile('[^a-zA-Z0-9]')
# keep in sync with `ui/media/js/dnd.js`
TASK_TEXT_MAPPING = {
'prompt': 'Prompt',
'width': 'Width',
'height': 'Height',
'seed': 'Seed',
'num_inference_steps': 'Steps',
'guidance_scale': 'Guidance Scale',
'prompt_strength': 'Prompt Strength',
'use_face_correction': 'Use Face Correction',
'use_upscale': 'Use Upscaling',
'sampler_name': 'Sampler',
'negative_prompt': 'Negative Prompt',
'use_stable_diffusion_model': 'Stable Diffusion model',
'use_hypernetwork_model': 'Hypernetwork model',
'hypernetwork_strength': 'Hypernetwork Strength'
}
def save_images_to_disk(images: list, filtered_images: list, req: GenerateImageRequest, task_data: TaskData):
save_dir_path = os.path.join(task_data.save_to_disk_path, filename_regex.sub('_', task_data.session_id))
metadata_entries = get_metadata_entries_for_request(req, task_data)
if task_data.show_only_filtered_image or filtered_images == images:
save_images(filtered_images, save_dir_path, file_name=make_filename_callback(req), output_format=task_data.output_format, output_quality=task_data.output_quality)
save_dicts(metadata_entries, save_dir_path, file_name=make_filename_callback(req), output_format=task_data.metadata_output_format)
else:
save_images(images, save_dir_path, file_name=make_filename_callback(req), output_format=task_data.output_format, output_quality=task_data.output_quality)
save_images(filtered_images, save_dir_path, file_name=make_filename_callback(req, suffix='filtered'), output_format=task_data.output_format, output_quality=task_data.output_quality)
save_dicts(metadata_entries, save_dir_path, file_name=make_filename_callback(req, suffix='filtered'), output_format=task_data.metadata_output_format)
def get_metadata_entries_for_request(req: GenerateImageRequest, task_data: TaskData):
metadata = get_printable_request(req)
metadata.update({
'use_stable_diffusion_model': task_data.use_stable_diffusion_model,
'use_vae_model': task_data.use_vae_model,
'use_hypernetwork_model': task_data.use_hypernetwork_model,
'use_face_correction': task_data.use_face_correction,
'use_upscale': task_data.use_upscale,
})
# if text, format it in the text format expected by the UI
is_txt_format = (task_data.metadata_output_format.lower() == 'txt')
if is_txt_format:
metadata = {TASK_TEXT_MAPPING[key]: val for key, val in metadata.items() if key in TASK_TEXT_MAPPING}
entries = [metadata.copy() for _ in range(req.num_outputs)]
for i, entry in enumerate(entries):
entry['Seed' if is_txt_format else 'seed'] = req.seed + i
return entries
def get_printable_request(req: GenerateImageRequest):
metadata = req.dict()
del metadata['init_image']
del metadata['init_image_mask']
return metadata
def make_filename_callback(req: GenerateImageRequest, suffix=None):
def make_filename(i):
img_id = base64.b64encode(int(time.time()+i).to_bytes(8, 'big')).decode() # Generate unique ID based on time.
img_id = img_id.translate({43:None, 47:None, 61:None})[-8:] # Remove + / = and keep last 8 chars.
prompt_flattened = filename_regex.sub('_', req.prompt)[:50]
name = f"{prompt_flattened}_{img_id}"
name = name if suffix is None else f'{name}_{suffix}'
return name
return make_filename
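
A small sketch of what these helpers produce for a two-image request (values are illustrative):

# Sketch: metadata entries and filenames for num_outputs=2.
from easydiffusion.types import GenerateImageRequest, TaskData

req = GenerateImageRequest(prompt='a red panda', seed=100, num_outputs=2)
task_data = TaskData(session_id='demo', metadata_output_format='json')
entries = get_metadata_entries_for_request(req, task_data)
assert [e['seed'] for e in entries] == [100, 101]  # one metadata entry per image, seed incremented
make_name = make_filename_callback(req, suffix='filtered')
print(make_name(0))  # e.g. 'a_red_panda_<time-based id>_filtered'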

ui/index.html

@ -24,8 +24,8 @@
<div id="top-nav"> <div id="top-nav">
<div id="logo"> <div id="logo">
<h1> <h1>
Stable Diffusion UI Easy Diffusion
<small>v2.4.22 <span id="updateBranchLabel"></span></small> <small>v2.5.0 <span id="updateBranchLabel"></span></small>
</h1> </h1>
</div> </div>
<div id="server-status"> <div id="server-status">
@ -129,22 +129,32 @@
</select> </select>
<a href="https://github.com/cmdr2/stable-diffusion-ui/wiki/Custom-Models" target="_blank"><i class="fa-solid fa-circle-question help-btn"><span class="simple-tooltip top-left">Click to learn more about custom models</span></i></a> <a href="https://github.com/cmdr2/stable-diffusion-ui/wiki/Custom-Models" target="_blank"><i class="fa-solid fa-circle-question help-btn"><span class="simple-tooltip top-left">Click to learn more about custom models</span></i></a>
</td></tr> </td></tr>
<!-- <tr id="modelConfigSelection" class="pl-5"><td><label for="model_config">Model Config:</i></label></td><td>
<select id="model_config" name="model_config">
</select>
</td></tr> -->
<tr class="pl-5"><td><label for="vae_model">Custom VAE:</i></label></td><td> <tr class="pl-5"><td><label for="vae_model">Custom VAE:</i></label></td><td>
<select id="vae_model" name="vae_model"> <select id="vae_model" name="vae_model">
<!-- <option value="" selected>None</option> --> <!-- <option value="" selected>None</option> -->
</select> </select>
<a href="https://github.com/cmdr2/stable-diffusion-ui/wiki/VAE-Variational-Auto-Encoder" target="_blank"><i class="fa-solid fa-circle-question help-btn"><span class="simple-tooltip top-left">Click to learn more about VAEs</span></i></a> <a href="https://github.com/cmdr2/stable-diffusion-ui/wiki/VAE-Variational-Auto-Encoder" target="_blank"><i class="fa-solid fa-circle-question help-btn"><span class="simple-tooltip top-left">Click to learn more about VAEs</span></i></a>
</td></tr> </td></tr>
<tr id="samplerSelection" class="pl-5"><td><label for="sampler">Sampler:</label></td><td> <tr id="samplerSelection" class="pl-5"><td><label for="sampler_name">Sampler:</label></td><td>
<select id="sampler" name="sampler"> <select id="sampler_name" name="sampler_name">
<option value="plms">plms</option> <option value="plms">PLMS</option>
<option value="ddim">ddim</option> <option value="ddim">DDIM</option>
<option value="heun">heun</option> <option value="heun">Heun</option>
<option value="euler">euler</option> <option value="euler">Euler</option>
<option value="euler_a" selected>euler_a</option> <option value="euler_a" selected>Euler Ancestral</option>
<option value="dpm2">dpm2</option> <option value="dpm2">DPM2</option>
<option value="dpm2_a">dpm2_a</option> <option value="dpm2_a">DPM2 Ancestral</option>
<option value="lms">lms</option> <option value="lms">LMS</option>
<option value="dpm_solver_stability">DPM Solver (Stability AI)</option>
<option value="dpmpp_2s_a" selected>DPM++ 2s Ancestral</option>
<option value="dpmpp_2m">DPM++ 2m</option>
<option value="dpmpp_sde">DPM++ SDE</option>
<option value="dpm_fast">DPM Fast</option>
<option value="dpm_adaptive">DPM Adaptive</option>
</select> </select>
<a href="https://github.com/cmdr2/stable-diffusion-ui/wiki/How-to-Use#samplers" target="_blank"><i class="fa-solid fa-circle-question help-btn"><span class="simple-tooltip top-left">Click to learn more about samplers</span></i></a> <a href="https://github.com/cmdr2/stable-diffusion-ui/wiki/How-to-Use#samplers" target="_blank"><i class="fa-solid fa-circle-question help-btn"><span class="simple-tooltip top-left">Click to learn more about samplers</span></i></a>
</td></tr> </td></tr>
@ -220,6 +230,7 @@
<div><ul> <div><ul>
<li><b class="settings-subheader">Render Settings</b></li> <li><b class="settings-subheader">Render Settings</b></li>
<li class="pl-5"><input id="stream_image_progress" name="stream_image_progress" type="checkbox"> <label for="stream_image_progress">Show a live preview <small>(uses more VRAM, slower images)</small></label></li> <li class="pl-5"><input id="stream_image_progress" name="stream_image_progress" type="checkbox"> <label for="stream_image_progress">Show a live preview <small>(uses more VRAM, slower images)</small></label></li>
<li id="apply_color_correction_setting" class="pl-5"><input id="apply_color_correction" name="apply_color_correction" type="checkbox"> <label for="apply_color_correction">Preserve color profile <small>(helps during inpainting)</small></label></li>
<li class="pl-5"><input id="use_face_correction" name="use_face_correction" type="checkbox"> <label for="use_face_correction">Fix incorrect faces and eyes <small>(uses GFPGAN)</small></label></li> <li class="pl-5"><input id="use_face_correction" name="use_face_correction" type="checkbox"> <label for="use_face_correction">Fix incorrect faces and eyes <small>(uses GFPGAN)</small></label></li>
<li class="pl-5"> <li class="pl-5">
<input id="use_upscale" name="use_upscale" type="checkbox"> <label for="use_upscale">Upscale image by 4x with </label> <input id="use_upscale" name="use_upscale" type="checkbox"> <label for="use_upscale">Upscale image by 4x with </label>
@ -416,7 +427,6 @@
async function init() { async function init() {
await initSettings() await initSettings()
await getModels() await getModels()
await getDiskPath()
await getAppConfig() await getAppConfig()
await loadUIPlugins() await loadUIPlugins()
await loadModifiers() await loadModifiers()

ui/main.py Normal file

@ -0,0 +1,10 @@
from easydiffusion import model_manager, app, server
from easydiffusion.server import server_api # required for uvicorn
# Init the app
model_manager.init()
app.init()
server.init()
# start the browser ui
app.open_browser()
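
Since server_api is re-exported for uvicorn, the app is presumably served along these lines; a sketch (the host and port are assumptions, and the shipped start scripts do this wiring for you):

# Hypothetical launch sketch.
import uvicorn

uvicorn.run('main:server_api', host='0.0.0.0', port=9000)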

ui/media/js/auto-save.js

@ -15,7 +15,7 @@ const SETTINGS_IDS_LIST = [
"stable_diffusion_model", "stable_diffusion_model",
"vae_model", "vae_model",
"hypernetwork_model", "hypernetwork_model",
"sampler", "sampler_name",
"width", "width",
"height", "height",
"num_inference_steps", "num_inference_steps",
@ -36,10 +36,11 @@ const SETTINGS_IDS_LIST = [
"save_to_disk", "save_to_disk",
"diskPath", "diskPath",
"sound_toggle", "sound_toggle",
"turbo", "vram_usage_level",
"use_full_precision",
"confirm_dangerous_actions", "confirm_dangerous_actions",
"auto_save_settings" "metadata_output_format",
"auto_save_settings",
"apply_color_correction"
] ]
const IGNORE_BY_DEFAULT = [ const IGNORE_BY_DEFAULT = [
@ -277,7 +278,6 @@ function tryLoadOldSettings() {
"soundEnabled": "sound_toggle", "soundEnabled": "sound_toggle",
"saveToDisk": "save_to_disk", "saveToDisk": "save_to_disk",
"useCPU": "use_cpu", "useCPU": "use_cpu",
"useFullPrecision": "use_full_precision",
"useTurboMode": "turbo", "useTurboMode": "turbo",
"diskPath": "diskPath", "diskPath": "diskPath",
"useFaceCorrection": "use_face_correction", "useFaceCorrection": "use_face_correction",

ui/media/js/dnd.js

@ -25,6 +25,7 @@ function parseBoolean(stringValue) {
case "no": case "no":
case "off": case "off":
case "0": case "0":
case "none":
case null: case null:
case undefined: case undefined:
return false; return false;
@ -160,9 +161,9 @@ const TASK_MAPPING = {
readUI: () => (useUpscalingField.checked ? upscaleModelField.value : undefined), readUI: () => (useUpscalingField.checked ? upscaleModelField.value : undefined),
parse: (val) => val parse: (val) => val
}, },
sampler: { name: 'Sampler', sampler_name: { name: 'Sampler',
setUI: (sampler) => { setUI: (sampler_name) => {
samplerField.value = sampler samplerField.value = sampler_name
}, },
readUI: () => samplerField.value, readUI: () => samplerField.value,
parse: (val) => val parse: (val) => val
@ -171,7 +172,7 @@ const TASK_MAPPING = {
setUI: (use_stable_diffusion_model) => { setUI: (use_stable_diffusion_model) => {
const oldVal = stableDiffusionModelField.value const oldVal = stableDiffusionModelField.value
use_stable_diffusion_model = getModelPath(use_stable_diffusion_model, ['.ckpt']) use_stable_diffusion_model = getModelPath(use_stable_diffusion_model, ['.ckpt', '.safetensors'])
stableDiffusionModelField.value = use_stable_diffusion_model stableDiffusionModelField.value = use_stable_diffusion_model
if (!stableDiffusionModelField.value) { if (!stableDiffusionModelField.value) {
@ -184,6 +185,7 @@ const TASK_MAPPING = {
use_vae_model: { name: 'VAE model', use_vae_model: { name: 'VAE model',
setUI: (use_vae_model) => { setUI: (use_vae_model) => {
const oldVal = vaeModelField.value const oldVal = vaeModelField.value
use_vae_model = (use_vae_model === undefined || use_vae_model === null || use_vae_model === 'None' ? '' : use_vae_model)
if (use_vae_model !== '') { if (use_vae_model !== '') {
use_vae_model = getModelPath(use_vae_model, ['.vae.pt', '.ckpt']) use_vae_model = getModelPath(use_vae_model, ['.vae.pt', '.ckpt'])
@ -197,6 +199,7 @@ const TASK_MAPPING = {
use_hypernetwork_model: { name: 'Hypernetwork model', use_hypernetwork_model: { name: 'Hypernetwork model',
setUI: (use_hypernetwork_model) => { setUI: (use_hypernetwork_model) => {
const oldVal = hypernetworkModelField.value const oldVal = hypernetworkModelField.value
use_hypernetwork_model = (use_hypernetwork_model === undefined || use_hypernetwork_model === null || use_hypernetwork_model === 'None' ? '' : use_hypernetwork_model)
if (use_hypernetwork_model !== '') { if (use_hypernetwork_model !== '') {
use_hypernetwork_model = getModelPath(use_hypernetwork_model, ['.pt']) use_hypernetwork_model = getModelPath(use_hypernetwork_model, ['.pt'])
@ -239,13 +242,6 @@ const TASK_MAPPING = {
readUI: () => turboField.checked, readUI: () => turboField.checked,
parse: (val) => Boolean(val) parse: (val) => Boolean(val)
}, },
use_full_precision: { name: 'Use Full Precision',
setUI: (use_full_precision) => {
useFullPrecisionField.checked = use_full_precision
},
readUI: () => useFullPrecisionField.checked,
parse: (val) => Boolean(val)
},
stream_image_progress: { name: 'Stream Image Progress', stream_image_progress: { name: 'Stream Image Progress',
setUI: (stream_image_progress) => { setUI: (stream_image_progress) => {
@ -350,6 +346,7 @@ function getModelPath(filename, extensions)
} }
const TASK_TEXT_MAPPING = { const TASK_TEXT_MAPPING = {
prompt: 'Prompt',
width: 'Width', width: 'Width',
height: 'Height', height: 'Height',
seed: 'Seed', seed: 'Seed',
@ -358,7 +355,7 @@ const TASK_TEXT_MAPPING = {
prompt_strength: 'Prompt Strength', prompt_strength: 'Prompt Strength',
use_face_correction: 'Use Face Correction', use_face_correction: 'Use Face Correction',
use_upscale: 'Use Upscaling', use_upscale: 'Use Upscaling',
sampler: 'Sampler', sampler_name: 'Sampler',
negative_prompt: 'Negative Prompt', negative_prompt: 'Negative Prompt',
use_stable_diffusion_model: 'Stable Diffusion model', use_stable_diffusion_model: 'Stable Diffusion model',
use_hypernetwork_model: 'Hypernetwork model', use_hypernetwork_model: 'Hypernetwork model',
@ -410,6 +407,9 @@ async function parseContent(text) {
if (text.startsWith('{') && text.endsWith('}')) { if (text.startsWith('{') && text.endsWith('}')) {
try { try {
const task = JSON.parse(text) const task = JSON.parse(text)
if (!('reqBody' in task)) { // support the format saved to the disk, by the UI
task.reqBody = Object.assign({}, task)
}
restoreTaskToUI(task) restoreTaskToUI(task)
return true return true
} catch (e) { } catch (e) {
@ -477,7 +477,6 @@ document.addEventListener("dragover", dragOverHandler)
const TASK_REQ_NO_EXPORT = [ const TASK_REQ_NO_EXPORT = [
"use_cpu", "use_cpu",
"turbo", "turbo",
"use_full_precision",
"save_to_disk_path" "save_to_disk_path"
] ]
const resetSettings = document.getElementById('reset-image-settings') const resetSettings = document.getElementById('reset-image-settings')


@ -728,7 +728,6 @@
"stream_image_progress": 'boolean', "stream_image_progress": 'boolean',
"show_only_filtered_image": 'boolean', "show_only_filtered_image": 'boolean',
"turbo": 'boolean', "turbo": 'boolean',
"use_full_precision": 'boolean',
"output_format": 'string', "output_format": 'string',
"output_quality": 'number', "output_quality": 'number',
} }
@ -744,7 +743,6 @@
"stream_image_progress": true, "stream_image_progress": true,
"show_only_filtered_image": true, "show_only_filtered_image": true,
"turbo": false, "turbo": false,
"use_full_precision": false,
"output_format": "png", "output_format": "png",
"output_quality": 75, "output_quality": 75,
} }

ui/media/js/main.js

@ -26,9 +26,11 @@ let initImagePreview = document.querySelector("#init_image_preview")
let initImageSizeBox = document.querySelector("#init_image_size_box") let initImageSizeBox = document.querySelector("#init_image_size_box")
let maskImageSelector = document.querySelector("#mask") let maskImageSelector = document.querySelector("#mask")
let maskImagePreview = document.querySelector("#mask_preview") let maskImagePreview = document.querySelector("#mask_preview")
let applyColorCorrectionField = document.querySelector('#apply_color_correction')
let colorCorrectionSetting = document.querySelector('#apply_color_correction_setting')
let promptStrengthSlider = document.querySelector('#prompt_strength_slider') let promptStrengthSlider = document.querySelector('#prompt_strength_slider')
let promptStrengthField = document.querySelector('#prompt_strength') let promptStrengthField = document.querySelector('#prompt_strength')
let samplerField = document.querySelector('#sampler') let samplerField = document.querySelector('#sampler_name')
let samplerSelectionContainer = document.querySelector("#samplerSelection") let samplerSelectionContainer = document.querySelector("#samplerSelection")
let useFaceCorrectionField = document.querySelector("#use_face_correction") let useFaceCorrectionField = document.querySelector("#use_face_correction")
let useUpscalingField = document.querySelector("#use_upscale") let useUpscalingField = document.querySelector("#use_upscale")
@ -610,7 +612,7 @@ function onTaskCompleted(task, reqBody, instance, outputContainer, stepUpdate) {
<b>Suggestions</b>: <b>Suggestions</b>:
<br/> <br/>
1. If you have set an initial image, please try reducing its dimension to ${MAX_INIT_IMAGE_DIMENSION}x${MAX_INIT_IMAGE_DIMENSION} or smaller.<br/> 1. If you have set an initial image, please try reducing its dimension to ${MAX_INIT_IMAGE_DIMENSION}x${MAX_INIT_IMAGE_DIMENSION} or smaller.<br/>
2. Try disabling the '<em>Turbo mode</em>' under '<em>Advanced Settings</em>'.<br/> 2. Try picking a lower level in the '<em>GPU Memory Usage</em>' setting (in the '<em>Settings</em>' tab).<br/>
3. Try generating a smaller image.<br/>` 3. Try generating a smaller image.<br/>`
} }
} else { } else {
@@ -789,7 +791,8 @@ function createTask(task) {
let w = task.reqBody.width * h / task.reqBody.height >>0 let w = task.reqBody.width * h / task.reqBody.height >>0
taskConfig += `<div class="task-initimg" style="float:left;"><img style="width:${w}px;height:${h}px;" src="${task.reqBody.init_image}"><div class="task-fs-initimage"></div></div>` taskConfig += `<div class="task-initimg" style="float:left;"><img style="width:${w}px;height:${h}px;" src="${task.reqBody.init_image}"><div class="task-fs-initimage"></div></div>`
} }
taskConfig += `<b>Seed:</b> ${task.seed}, <b>Sampler:</b> ${task.reqBody.sampler}, <b>Inference Steps:</b> ${task.reqBody.num_inference_steps}, <b>Guidance Scale:</b> ${task.reqBody.guidance_scale}, <b>Model:</b> ${task.reqBody.use_stable_diffusion_model}` taskConfig += `<b>Seed:</b> ${task.seed}, <b>Sampler:</b> ${task.reqBody.sampler_name}, <b>Inference Steps:</b> ${task.reqBody.num_inference_steps}, <b>Guidance Scale:</b> ${task.reqBody.guidance_scale}, <b>Model:</b> ${task.reqBody.use_stable_diffusion_model}`
if (task.reqBody.use_vae_model.trim() !== '') { if (task.reqBody.use_vae_model.trim() !== '') {
taskConfig += `, <b>VAE:</b> ${task.reqBody.use_vae_model}` taskConfig += `, <b>VAE:</b> ${task.reqBody.use_vae_model}`
} }
@@ -809,6 +812,9 @@ function createTask(task) {
taskConfig += `, <b>Hypernetwork:</b> ${task.reqBody.use_hypernetwork_model}` taskConfig += `, <b>Hypernetwork:</b> ${task.reqBody.use_hypernetwork_model}`
taskConfig += `, <b>Hypernetwork Strength:</b> ${task.reqBody.hypernetwork_strength}` taskConfig += `, <b>Hypernetwork Strength:</b> ${task.reqBody.hypernetwork_strength}`
} }
if (task.reqBody.preserve_init_image_color_profile) {
taskConfig += `, <b>Preserve Color Profile:</b> true`
}
let taskEntry = document.createElement('div') let taskEntry = document.createElement('div')
taskEntry.id = `imageTaskContainer-${Date.now()}` taskEntry.id = `imageTaskContainer-${Date.now()}`
@@ -914,9 +920,8 @@ function getCurrentUserRequest() {
width: parseInt(widthField.value), width: parseInt(widthField.value),
height: parseInt(heightField.value), height: parseInt(heightField.value),
// allow_nsfw: allowNSFWField.checked, // allow_nsfw: allowNSFWField.checked,
turbo: turboField.checked, vram_usage_level: vramUsageLevelField.value,
//render_device: undefined, // Set device affinity. Prefer this device, but wont activate. //render_device: undefined, // Set device affinity. Prefer this device, but wont activate.
use_full_precision: useFullPrecisionField.checked,
use_stable_diffusion_model: stableDiffusionModelField.value, use_stable_diffusion_model: stableDiffusionModelField.value,
use_vae_model: vaeModelField.value, use_vae_model: vaeModelField.value,
stream_progress_updates: true, stream_progress_updates: true,
@@ -924,6 +929,7 @@ function getCurrentUserRequest() {
show_only_filtered_image: showOnlyFilteredImageField.checked, show_only_filtered_image: showOnlyFilteredImageField.checked,
output_format: outputFormatField.value, output_format: outputFormatField.value,
output_quality: parseInt(outputQualityField.value), output_quality: parseInt(outputQualityField.value),
metadata_output_format: document.querySelector('#metadata_output_format').value,
original_prompt: promptField.value, original_prompt: promptField.value,
active_tags: (activeTags.map(x => x.name)) active_tags: (activeTags.map(x => x.name))
} }
@@ -938,9 +944,10 @@ function getCurrentUserRequest() {
if (maskSetting.checked) { if (maskSetting.checked) {
newTask.reqBody.mask = imageInpainter.getImg() newTask.reqBody.mask = imageInpainter.getImg()
} }
newTask.reqBody.sampler = 'ddim' newTask.reqBody.preserve_init_image_color_profile = applyColorCorrectionField.checked
newTask.reqBody.sampler_name = 'ddim'
} else { } else {
newTask.reqBody.sampler = samplerField.value newTask.reqBody.sampler_name = samplerField.value
} }
if (saveToDiskField.checked && diskPathField.value.trim() !== '') { if (saveToDiskField.checked && diskPathField.value.trim() !== '') {
newTask.reqBody.save_to_disk_path = diskPathField.value.trim() newTask.reqBody.save_to_disk_path = diskPathField.value.trim()
@@ -1349,6 +1356,7 @@ function img2imgLoad() {
promptStrengthContainer.style.display = 'table-row' promptStrengthContainer.style.display = 'table-row'
samplerSelectionContainer.style.display = "none" samplerSelectionContainer.style.display = "none"
initImagePreviewContainer.classList.add("has-image") initImagePreviewContainer.classList.add("has-image")
colorCorrectionSetting.style.display = ''
initImageSizeBox.textContent = initImagePreview.naturalWidth + " x " + initImagePreview.naturalHeight initImageSizeBox.textContent = initImagePreview.naturalWidth + " x " + initImagePreview.naturalHeight
imageEditor.setImage(this.src, initImagePreview.naturalWidth, initImagePreview.naturalHeight) imageEditor.setImage(this.src, initImagePreview.naturalWidth, initImagePreview.naturalHeight)
@@ -1363,6 +1371,7 @@ function img2imgUnload() {
promptStrengthContainer.style.display = "none" promptStrengthContainer.style.display = "none"
samplerSelectionContainer.style.display = "" samplerSelectionContainer.style.display = ""
initImagePreviewContainer.classList.remove("has-image") initImagePreviewContainer.classList.remove("has-image")
colorCorrectionSetting.style.display = 'none'
imageEditor.setImage(null, parseInt(widthField.value), parseInt(heightField.value)) imageEditor.setImage(null, parseInt(widthField.value), parseInt(heightField.value))
} }

View File

@@ -53,6 +53,23 @@ var PARAMETERS = [
return `<input id="${parameter.id}" name="${parameter.id}" size="30" disabled>` return `<input id="${parameter.id}" name="${parameter.id}" size="30" disabled>`
} }
}, },
{
id: "metadata_output_format",
type: ParameterType.select,
label: "Metadata format",
note: "will be saved to disk in this format",
default: "txt",
options: [
{
value: "txt",
label: "txt"
},
{
value: "json",
label: "json"
}
],
},
{ {
id: "sound_toggle", id: "sound_toggle",
type: ParameterType.checkbox, type: ParameterType.checkbox,
@@ -77,12 +94,20 @@ var PARAMETERS = [
default: true, default: true,
}, },
{ {
id: "turbo", id: "vram_usage_level",
type: ParameterType.checkbox, type: ParameterType.select,
label: "Turbo Mode", label: "GPU Memory Usage",
note: "generates images faster, but uses an additional 1 GB of GPU memory", note: "Faster performance requires more GPU memory (VRAM)<br/><br/>" +
"<b>Balanced:</b> nearly as fast as High, much lower VRAM usage<br/>" +
"<b>High:</b> fastest, maximum GPU memory usage</br>" +
"<b>Low:</b> slowest, force-used for GPUs with 4 GB (or less) memory",
icon: "fa-forward", icon: "fa-forward",
default: true, default: "balanced",
options: [
{value: "balanced", label: "Balanced"},
{value: "high", label: "High"},
{value: "low", label: "Low"}
],
}, },
{ {
id: "use_cpu", id: "use_cpu",
@@ -105,14 +130,6 @@ var PARAMETERS = [
note: "to process in parallel", note: "to process in parallel",
default: false, default: false,
}, },
{
id: "use_full_precision",
type: ParameterType.checkbox,
label: "Use Full Precision",
note: "for GPU-only. warning: this will consume more VRAM",
icon: "fa-crosshairs",
default: false,
},
{ {
id: "auto_save_settings", id: "auto_save_settings",
type: ParameterType.checkbox, type: ParameterType.checkbox,
@@ -147,14 +164,6 @@ var PARAMETERS = [
return `<input id="${parameter.id}" name="${parameter.id}" size="6" value="9000" onkeypress="preventNonNumericalInput(event)">` return `<input id="${parameter.id}" name="${parameter.id}" size="6" value="9000" onkeypress="preventNonNumericalInput(event)">`
} }
}, },
{
id: "test_sd2",
type: ParameterType.checkbox,
label: "Test SD 2.0",
note: "Experimental! High memory usage! GPU-only! Not the final version! Please restart the program after changing this.",
icon: "fa-fire",
default: false,
},
{ {
id: "use_beta_channel", id: "use_beta_channel",
type: ParameterType.checkbox, type: ParameterType.checkbox,
@@ -210,16 +219,14 @@ function initParameters() {
initParameters() initParameters()
let turboField = document.querySelector('#turbo') let vramUsageLevelField = document.querySelector('#vram_usage_level')
let useCPUField = document.querySelector('#use_cpu') let useCPUField = document.querySelector('#use_cpu')
let autoPickGPUsField = document.querySelector('#auto_pick_gpus') let autoPickGPUsField = document.querySelector('#auto_pick_gpus')
let useGPUsField = document.querySelector('#use_gpus') let useGPUsField = document.querySelector('#use_gpus')
let useFullPrecisionField = document.querySelector('#use_full_precision')
let saveToDiskField = document.querySelector('#save_to_disk') let saveToDiskField = document.querySelector('#save_to_disk')
let diskPathField = document.querySelector('#diskPath') let diskPathField = document.querySelector('#diskPath')
let listenToNetworkField = document.querySelector("#listen_to_network") let listenToNetworkField = document.querySelector("#listen_to_network")
let listenPortField = document.querySelector("#listen_port") let listenPortField = document.querySelector("#listen_port")
let testSD2Field = document.querySelector("#test_sd2")
let useBetaChannelField = document.querySelector("#use_beta_channel") let useBetaChannelField = document.querySelector("#use_beta_channel")
let uiOpenBrowserOnStartField = document.querySelector("#ui_open_browser_on_start") let uiOpenBrowserOnStartField = document.querySelector("#ui_open_browser_on_start")
let confirmDangerousActionsField = document.querySelector("#confirm_dangerous_actions") let confirmDangerousActionsField = document.querySelector("#confirm_dangerous_actions")
@@ -256,12 +263,6 @@ async function getAppConfig() {
if (config.ui && config.ui.open_browser_on_start === false) { if (config.ui && config.ui.open_browser_on_start === false) {
uiOpenBrowserOnStartField.checked = false uiOpenBrowserOnStartField.checked = false
} }
if ('test_sd2' in config) {
testSD2Field.checked = config['test_sd2']
}
let testSD2SettingEntry = getParameterSettingsEntry('test_sd2')
testSD2SettingEntry.style.display = (config.update_branch === 'beta' ? '' : 'none')
if (config.net && config.net.listen_to_network === false) { if (config.net && config.net.listen_to_network === false) {
listenToNetworkField.checked = false listenToNetworkField.checked = false
} }
@@ -327,20 +328,10 @@ autoPickGPUsField.addEventListener('click', function() {
gpuSettingEntry.style.display = (this.checked ? 'none' : '') gpuSettingEntry.style.display = (this.checked ? 'none' : '')
}) })
async function getDiskPath() { async function setDiskPath(defaultDiskPath) {
try {
var diskPath = getSetting("diskPath") var diskPath = getSetting("diskPath")
if (diskPath == '' || diskPath == undefined || diskPath == "undefined") { if (diskPath == '' || diskPath == undefined || diskPath == "undefined") {
let res = await fetch('/get/output_dir') setSetting("diskPath", defaultDiskPath)
if (res.status === 200) {
res = await res.json()
res = res.output_dir
setSetting("diskPath", res)
}
}
} catch (e) {
console.log('error fetching output dir path', e)
} }
} }
@@ -415,6 +406,7 @@ async function getSystemInfo() {
setDeviceInfo(devices) setDeviceInfo(devices)
setHostInfo(res['hosts']) setHostInfo(res['hosts'])
setDiskPath(res['default_output_dir'])
} catch (e) { } catch (e) {
console.log('error fetching devices', e) console.log('error fetching devices', e)
} }
@@ -435,8 +427,7 @@ saveSettingsBtn.addEventListener('click', function() {
'update_branch': updateBranch, 'update_branch': updateBranch,
'ui_open_browser_on_start': uiOpenBrowserOnStartField.checked, 'ui_open_browser_on_start': uiOpenBrowserOnStartField.checked,
'listen_to_network': listenToNetworkField.checked, 'listen_to_network': listenToNetworkField.checked,
'listen_port': listenPortField.value, 'listen_port': listenPortField.value
'test_sd2': testSD2Field.checked
}) })
saveSettingsBtn.classList.add('active') saveSettingsBtn.classList.add('active')
asyncDelay(300).then(() => saveSettingsBtn.classList.remove('active')) asyncDelay(300).then(() => saveSettingsBtn.classList.remove('active'))
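Note: `getSystemInfo` above now seeds the disk path from the `/get/system_info` payload instead of the removed `/get/output_dir` endpoint. A hedged sketch of what a client sees, assuming the new server adds a `default_output_dir` field to `system_info` (the UI code reads exactly that key; `devices` and `hosts` follow the legacy server shown below):

```python
import requests  # assumed installed; any HTTP client works

res = requests.get("http://localhost:9000/get/system_info")  # default port 9000
info = res.json()

print(info["devices"])                 # render device info, plus 'config'
print(info["hosts"])                   # hostnames/IPs of the machine
print(info.get("default_output_dir"))  # assumed new field, consumed by setDiskPath()
```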

View File

@@ -1,119 +0,0 @@
import json
class Request:
request_id: str = None
session_id: str = "session"
prompt: str = ""
negative_prompt: str = ""
init_image: str = None # base64
mask: str = None # base64
num_outputs: int = 1
num_inference_steps: int = 50
guidance_scale: float = 7.5
width: int = 512
height: int = 512
seed: int = 42
prompt_strength: float = 0.8
sampler: str = None # "ddim", "plms", "heun", "euler", "euler_a", "dpm2", "dpm2_a", "lms"
# allow_nsfw: bool = False
precision: str = "autocast" # or "full"
save_to_disk_path: str = None
turbo: bool = True
use_full_precision: bool = False
use_face_correction: str = None # or "GFPGANv1.3"
use_upscale: str = None # or "RealESRGAN_x4plus" or "RealESRGAN_x4plus_anime_6B"
use_stable_diffusion_model: str = "sd-v1-4"
use_vae_model: str = None
use_hypernetwork_model: str = None
hypernetwork_strength: float = 1
show_only_filtered_image: bool = False
output_format: str = "jpeg" # or "png"
output_quality: int = 75
stream_progress_updates: bool = False
stream_image_progress: bool = False
def json(self):
return {
"session_id": self.session_id,
"prompt": self.prompt,
"negative_prompt": self.negative_prompt,
"num_outputs": self.num_outputs,
"num_inference_steps": self.num_inference_steps,
"guidance_scale": self.guidance_scale,
"hypernetwork_strengtgh": self.guidance_scale,
"width": self.width,
"height": self.height,
"seed": self.seed,
"prompt_strength": self.prompt_strength,
"sampler": self.sampler,
"use_face_correction": self.use_face_correction,
"use_upscale": self.use_upscale,
"use_stable_diffusion_model": self.use_stable_diffusion_model,
"use_vae_model": self.use_vae_model,
"use_hypernetwork_model": self.use_hypernetwork_model,
"hypernetwork_strength": self.hypernetwork_strength,
"output_format": self.output_format,
"output_quality": self.output_quality,
}
def __str__(self):
return f'''
session_id: {self.session_id}
prompt: {self.prompt}
negative_prompt: {self.negative_prompt}
seed: {self.seed}
num_inference_steps: {self.num_inference_steps}
sampler: {self.sampler}
guidance_scale: {self.guidance_scale}
w: {self.width}
h: {self.height}
precision: {self.precision}
save_to_disk_path: {self.save_to_disk_path}
turbo: {self.turbo}
use_full_precision: {self.use_full_precision}
use_face_correction: {self.use_face_correction}
use_upscale: {self.use_upscale}
use_stable_diffusion_model: {self.use_stable_diffusion_model}
use_vae_model: {self.use_vae_model}
use_hypernetwork_model: {self.use_hypernetwork_model}
hypernetwork_strength: {self.hypernetwork_strength}
show_only_filtered_image: {self.show_only_filtered_image}
output_format: {self.output_format}
output_quality: {self.output_quality}
stream_progress_updates: {self.stream_progress_updates}
stream_image_progress: {self.stream_image_progress}'''
class Image:
data: str # base64
seed: int
is_nsfw: bool
path_abs: str = None
def __init__(self, data, seed):
self.data = data
self.seed = seed
def json(self):
return {
"data": self.data,
"seed": self.seed,
"path_abs": self.path_abs,
}
class Response:
request: Request
images: list
def json(self):
res = {
"status": 'succeeded',
"request": self.request.json(),
"output": [],
}
for image in self.images:
res["output"].append(image.json())
return res
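For reference, the deleted `Request`/`Response` classes were consumed roughly as below; a minimal sketch using only the fields and constructors shown above (the module itself no longer exists after this commit):

```python
req = Request()
req.prompt = "a photograph of an astronaut riding a horse"
req.sampler = "euler_a"
req.use_stable_diffusion_model = "sd-v1-4"

img = Image(data="<base64-encoded image>", seed=req.seed)

resp = Response()
resp.request = req
resp.images = [img]
print(resp.json()["status"])  # 'succeeded'
```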

View File

@@ -1,162 +0,0 @@
diff --git a/optimizedSD/ddpm.py b/optimizedSD/ddpm.py
index 79058bc..a473411 100644
--- a/optimizedSD/ddpm.py
+++ b/optimizedSD/ddpm.py
@@ -564,12 +564,12 @@ class UNet(DDPM):
unconditional_guidance_scale=unconditional_guidance_scale,
callback=callback, img_callback=img_callback)
+ yield from samples
+
if(self.turbo):
self.model1.to("cpu")
self.model2.to("cpu")
- return samples
-
@torch.no_grad()
def plms_sampling(self, cond,b, img,
ddim_use_original_steps=False,
@@ -608,10 +608,10 @@ class UNet(DDPM):
old_eps.append(e_t)
if len(old_eps) >= 4:
old_eps.pop(0)
- if callback: callback(i)
- if img_callback: img_callback(pred_x0, i)
+ if callback: yield from callback(i)
+ if img_callback: yield from img_callback(pred_x0, i)
- return img
+ yield from img_callback(img, len(iterator)-1)
@torch.no_grad()
def p_sample_plms(self, x, c, t, index, repeat_noise=False, use_original_steps=False, quantize_denoised=False,
@@ -740,13 +740,13 @@ class UNet(DDPM):
unconditional_guidance_scale=unconditional_guidance_scale,
unconditional_conditioning=unconditional_conditioning)
- if callback: callback(i)
- if img_callback: img_callback(x_dec, i)
+ if callback: yield from callback(i)
+ if img_callback: yield from img_callback(x_dec, i)
if mask is not None:
- return x0 * mask + (1. - mask) * x_dec
+ x_dec = x0 * mask + (1. - mask) * x_dec
- return x_dec
+ yield from img_callback(x_dec, len(iterator)-1)
@torch.no_grad()
@@ -820,12 +820,12 @@ class UNet(DDPM):
d = to_d(x, sigma_hat, denoised)
- if callback: callback(i)
- if img_callback: img_callback(x, i)
+ if callback: yield from callback(i)
+ if img_callback: yield from img_callback(x, i)
dt = sigmas[i + 1] - sigma_hat
# Euler method
x = x + d * dt
- return x
+ yield from img_callback(x, len(sigmas)-1)
@torch.no_grad()
def euler_ancestral_sampling(self,ac,x, S, cond, unconditional_conditioning = None, unconditional_guidance_scale = 1,extra_args=None, callback=None, disable=None, img_callback=None):
@@ -852,14 +852,14 @@ class UNet(DDPM):
denoised = e_t_uncond + unconditional_guidance_scale * (e_t - e_t_uncond)
sigma_down, sigma_up = get_ancestral_step(sigmas[i], sigmas[i + 1])
- if callback: callback(i)
- if img_callback: img_callback(x, i)
+ if callback: yield from callback(i)
+ if img_callback: yield from img_callback(x, i)
d = to_d(x, sigmas[i], denoised)
# Euler method
dt = sigma_down - sigmas[i]
x = x + d * dt
x = x + torch.randn_like(x) * sigma_up
- return x
+ yield from img_callback(x, len(sigmas)-1)
@@ -892,8 +892,8 @@ class UNet(DDPM):
denoised = e_t_uncond + unconditional_guidance_scale * (e_t - e_t_uncond)
d = to_d(x, sigma_hat, denoised)
- if callback: callback(i)
- if img_callback: img_callback(x, i)
+ if callback: yield from callback(i)
+ if img_callback: yield from img_callback(x, i)
dt = sigmas[i + 1] - sigma_hat
if sigmas[i + 1] == 0:
# Euler method
@@ -913,7 +913,7 @@ class UNet(DDPM):
d_2 = to_d(x_2, sigmas[i + 1], denoised_2)
d_prime = (d + d_2) / 2
x = x + d_prime * dt
- return x
+ yield from img_callback(x, len(sigmas)-1)
@torch.no_grad()
@@ -944,8 +944,8 @@ class UNet(DDPM):
e_t_uncond, e_t = (x_in + eps * c_out).chunk(2)
denoised = e_t_uncond + unconditional_guidance_scale * (e_t - e_t_uncond)
- if callback: callback(i)
- if img_callback: img_callback(x, i)
+ if callback: yield from callback(i)
+ if img_callback: yield from img_callback(x, i)
d = to_d(x, sigma_hat, denoised)
# Midpoint method, where the midpoint is chosen according to a rho=3 Karras schedule
@@ -966,7 +966,7 @@ class UNet(DDPM):
d_2 = to_d(x_2, sigma_mid, denoised_2)
x = x + d_2 * dt_2
- return x
+ yield from img_callback(x, len(sigmas)-1)
@torch.no_grad()
@@ -994,8 +994,8 @@ class UNet(DDPM):
sigma_down, sigma_up = get_ancestral_step(sigmas[i], sigmas[i + 1])
- if callback: callback(i)
- if img_callback: img_callback(x, i)
+ if callback: yield from callback(i)
+ if img_callback: yield from img_callback(x, i)
d = to_d(x, sigmas[i], denoised)
# Midpoint method, where the midpoint is chosen according to a rho=3 Karras schedule
sigma_mid = ((sigmas[i] ** (1 / 3) + sigma_down ** (1 / 3)) / 2) ** 3
@@ -1016,7 +1016,7 @@ class UNet(DDPM):
d_2 = to_d(x_2, sigma_mid, denoised_2)
x = x + d_2 * dt_2
x = x + torch.randn_like(x) * sigma_up
- return x
+ yield from img_callback(x, len(sigmas)-1)
@torch.no_grad()
@@ -1042,8 +1042,8 @@ class UNet(DDPM):
e_t_uncond, e_t = (x_in + eps * c_out).chunk(2)
denoised = e_t_uncond + unconditional_guidance_scale * (e_t - e_t_uncond)
- if callback: callback(i)
- if img_callback: img_callback(x, i)
+ if callback: yield from callback(i)
+ if img_callback: yield from img_callback(x, i)
d = to_d(x, sigmas[i], denoised)
ds.append(d)
@@ -1054,4 +1054,4 @@ class UNet(DDPM):
cur_order = min(i + 1, order)
coeffs = [linear_multistep_coeff(cur_order, sigmas.cpu(), i, j) for j in range(cur_order)]
x = x + sum(coeff * d for coeff, d in zip(coeffs, reversed(ds)))
- return x
+ yield from img_callback(x, len(sigmas)-1)
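The deleted patch above (and the ldm DDIM/PLMS patch in the next file) applies one transformation throughout: every sampler's `return x` becomes `yield from img_callback(x, ...)`, turning the sampling loop into a generator so intermediate images can be streamed to the UI instead of only returned at the end. A minimal standalone sketch of the pattern, with illustrative names (not the actual UNet/DDPM API):

```python
def img_callback(x, step):
    # In the real code this encodes a preview image; here it just yields a dict.
    yield {"step": step, "preview": x}

def sampler_loop(steps, img_callback=None):
    x = 0.0
    for i in range(steps):
        x += 1.0  # stand-in for one denoising step
        if img_callback:
            yield from img_callback(x, i)
    # was: return x -- instead, re-emit the final state as the last update,
    # exactly as the patch does with len(sigmas)-1 / len(iterator)-1
    if img_callback:
        yield from img_callback(x, steps - 1)

for update in sampler_loop(3, img_callback):
    print(update)
```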

View File

@@ -1,84 +0,0 @@
diff --git a/ldm/models/diffusion/ddim.py b/ldm/models/diffusion/ddim.py
index 27ead0e..6215939 100644
--- a/ldm/models/diffusion/ddim.py
+++ b/ldm/models/diffusion/ddim.py
@@ -100,7 +100,7 @@ class DDIMSampler(object):
size = (batch_size, C, H, W)
print(f'Data shape for DDIM sampling is {size}, eta {eta}')
- samples, intermediates = self.ddim_sampling(conditioning, size,
+ samples = self.ddim_sampling(conditioning, size,
callback=callback,
img_callback=img_callback,
quantize_denoised=quantize_x0,
@@ -117,7 +117,8 @@ class DDIMSampler(object):
dynamic_threshold=dynamic_threshold,
ucg_schedule=ucg_schedule
)
- return samples, intermediates
+ # return samples, intermediates
+ yield from samples
@torch.no_grad()
def ddim_sampling(self, cond, shape,
@@ -168,14 +169,15 @@ class DDIMSampler(object):
unconditional_conditioning=unconditional_conditioning,
dynamic_threshold=dynamic_threshold)
img, pred_x0 = outs
- if callback: callback(i)
- if img_callback: img_callback(pred_x0, i)
+ if callback: yield from callback(i)
+ if img_callback: yield from img_callback(pred_x0, i)
if index % log_every_t == 0 or index == total_steps - 1:
intermediates['x_inter'].append(img)
intermediates['pred_x0'].append(pred_x0)
- return img, intermediates
+ # return img, intermediates
+ yield from img_callback(pred_x0, len(iterator)-1)
@torch.no_grad()
def p_sample_ddim(self, x, c, t, index, repeat_noise=False, use_original_steps=False, quantize_denoised=False,
diff --git a/ldm/models/diffusion/plms.py b/ldm/models/diffusion/plms.py
index 7002a36..0951f39 100644
--- a/ldm/models/diffusion/plms.py
+++ b/ldm/models/diffusion/plms.py
@@ -96,7 +96,7 @@ class PLMSSampler(object):
size = (batch_size, C, H, W)
print(f'Data shape for PLMS sampling is {size}')
- samples, intermediates = self.plms_sampling(conditioning, size,
+ samples = self.plms_sampling(conditioning, size,
callback=callback,
img_callback=img_callback,
quantize_denoised=quantize_x0,
@@ -112,7 +112,8 @@ class PLMSSampler(object):
unconditional_conditioning=unconditional_conditioning,
dynamic_threshold=dynamic_threshold,
)
- return samples, intermediates
+ #return samples, intermediates
+ yield from samples
@torch.no_grad()
def plms_sampling(self, cond, shape,
@@ -165,14 +166,15 @@ class PLMSSampler(object):
old_eps.append(e_t)
if len(old_eps) >= 4:
old_eps.pop(0)
- if callback: callback(i)
- if img_callback: img_callback(pred_x0, i)
+ if callback: yield from callback(i)
+ if img_callback: yield from img_callback(pred_x0, i)
if index % log_every_t == 0 or index == total_steps - 1:
intermediates['x_inter'].append(img)
intermediates['pred_x0'].append(pred_x0)
- return img, intermediates
+ # return img, intermediates
+ yield from img_callback(pred_x0, len(iterator)-1)
@torch.no_grad()
def p_sample_plms(self, x, c, t, index, repeat_noise=False, use_original_steps=False, quantize_denoised=False,

View File

@@ -1,198 +0,0 @@
# this is basically a cut down version of https://github.com/AUTOMATIC1111/stable-diffusion-webui/blob/c9a2cfdf2a53d37c2de1908423e4f548088667ef/modules/hypernetworks/hypernetwork.py, mostly for feature parity
# I, c0bra5, don't really understand how deep learning works. I just know how to port stuff.
import inspect
import torch
import optimizedSD.splitAttention
from . import runtime
from einops import rearrange
optimizer_dict = {optim_name : cls_obj for optim_name, cls_obj in inspect.getmembers(torch.optim, inspect.isclass) if optim_name != "Optimizer"}
loaded_hypernetwork = None
class HypernetworkModule(torch.nn.Module):
multiplier = 0.5
activation_dict = {
"linear": torch.nn.Identity,
"relu": torch.nn.ReLU,
"leakyrelu": torch.nn.LeakyReLU,
"elu": torch.nn.ELU,
"swish": torch.nn.Hardswish,
"tanh": torch.nn.Tanh,
"sigmoid": torch.nn.Sigmoid,
}
activation_dict.update({cls_name.lower(): cls_obj for cls_name, cls_obj in inspect.getmembers(torch.nn.modules.activation) if inspect.isclass(cls_obj) and cls_obj.__module__ == 'torch.nn.modules.activation'})
def __init__(self, dim, state_dict=None, layer_structure=None, activation_func=None, weight_init='Normal',
add_layer_norm=False, use_dropout=False, activate_output=False, last_layer_dropout=False):
super().__init__()
assert layer_structure is not None, "layer_structure must not be None"
assert layer_structure[0] == 1, "Multiplier Sequence should start with size 1!"
assert layer_structure[-1] == 1, "Multiplier Sequence should end with size 1!"
linears = []
for i in range(len(layer_structure) - 1):
# Add a fully-connected layer
linears.append(torch.nn.Linear(int(dim * layer_structure[i]), int(dim * layer_structure[i+1])))
# Add an activation func except last layer
if activation_func == "linear" or activation_func is None or (i >= len(layer_structure) - 2 and not activate_output):
pass
elif activation_func in self.activation_dict:
linears.append(self.activation_dict[activation_func]())
else:
raise RuntimeError(f'hypernetwork uses an unsupported activation function: {activation_func}')
# Add layer normalization
if add_layer_norm:
linears.append(torch.nn.LayerNorm(int(dim * layer_structure[i+1])))
# Add dropout except last layer
if use_dropout and (i < len(layer_structure) - 3 or last_layer_dropout and i < len(layer_structure) - 2):
linears.append(torch.nn.Dropout(p=0.3))
self.linear = torch.nn.Sequential(*linears)
self.fix_old_state_dict(state_dict)
self.load_state_dict(state_dict)
self.to(runtime.thread_data.device)
def fix_old_state_dict(self, state_dict):
changes = {
'linear1.bias': 'linear.0.bias',
'linear1.weight': 'linear.0.weight',
'linear2.bias': 'linear.1.bias',
'linear2.weight': 'linear.1.weight',
}
for fr, to in changes.items():
x = state_dict.get(fr, None)
if x is None:
continue
del state_dict[fr]
state_dict[to] = x
def forward(self, x: torch.Tensor):
return x + self.linear(x) * runtime.thread_data.hypernetwork_strength
def apply_hypernetwork(hypernetwork, context, layer=None):
hypernetwork_layers = hypernetwork.get(context.shape[2], None)
if hypernetwork_layers is None:
return context, context
if layer is not None:
layer.hyper_k = hypernetwork_layers[0]
layer.hyper_v = hypernetwork_layers[1]
context_k = hypernetwork_layers[0](context)
context_v = hypernetwork_layers[1](context)
return context_k, context_v
def get_kv(context, hypernetwork):
if hypernetwork is None:
return context, context
else:
return apply_hypernetwork(runtime.thread_data.hypernetwork, context)
# This might need updating as the optimizedSD code changes
# I think y'all have a system for this (patch files in sd_internal), but I don't know how it works and no amount of searching gave me any clue
# just in case for attribution https://github.com/easydiffusion/diffusion-kit/blob/e8ea0cadd543056059cd951e76d4744de76327d2/optimizedSD/splitAttention.py#L171
def new_cross_attention_forward(self, x, context=None, mask=None):
h = self.heads
q = self.to_q(x)
# default context
context = context if context is not None else x() if inspect.isfunction(x) else x
# hypernetwork!
context_k, context_v = get_kv(context, runtime.thread_data.hypernetwork)
k = self.to_k(context_k)
v = self.to_v(context_v)
del context, x
q, k, v = map(lambda t: rearrange(t, 'b n (h d) -> (b h) n d', h=h), (q, k, v))
limit = k.shape[0]
att_step = self.att_step
q_chunks = list(torch.tensor_split(q, limit//att_step, dim=0))
k_chunks = list(torch.tensor_split(k, limit//att_step, dim=0))
v_chunks = list(torch.tensor_split(v, limit//att_step, dim=0))
q_chunks.reverse()
k_chunks.reverse()
v_chunks.reverse()
sim = torch.zeros(q.shape[0], q.shape[1], v.shape[2], device=q.device)
del k, q, v
for i in range (0, limit, att_step):
q_buffer = q_chunks.pop()
k_buffer = k_chunks.pop()
v_buffer = v_chunks.pop()
sim_buffer = torch.einsum('b i d, b j d -> b i j', q_buffer, k_buffer) * self.scale
del k_buffer, q_buffer
# attention, what we cannot get enough of, by chunks
sim_buffer = sim_buffer.softmax(dim=-1)
sim_buffer = torch.einsum('b i j, b j d -> b i d', sim_buffer, v_buffer)
del v_buffer
sim[i:i+att_step,:,:] = sim_buffer
del sim_buffer
sim = rearrange(sim, '(b h) n d -> b n (h d)', h=h)
return self.to_out(sim)
def load_hypernetwork(path: str):
state_dict = torch.load(path, map_location='cpu')
layer_structure = state_dict.get('layer_structure', [1, 2, 1])
activation_func = state_dict.get('activation_func', None)
weight_init = state_dict.get('weight_initialization', 'Normal')
add_layer_norm = state_dict.get('is_layer_norm', False)
use_dropout = state_dict.get('use_dropout', False)
activate_output = state_dict.get('activate_output', True)
last_layer_dropout = state_dict.get('last_layer_dropout', False)
# this is a bit verbose so leaving it commented out for the poor soul who ever has to debug this
# print(f"layer_structure: {layer_structure}")
# print(f"activation_func: {activation_func}")
# print(f"weight_init: {weight_init}")
# print(f"add_layer_norm: {add_layer_norm}")
# print(f"use_dropout: {use_dropout}")
# print(f"activate_output: {activate_output}")
# print(f"last_layer_dropout: {last_layer_dropout}")
layers = {}
for size, sd in state_dict.items():
if type(size) == int:
layers[size] = (
HypernetworkModule(size, sd[0], layer_structure, activation_func, weight_init, add_layer_norm,
use_dropout, activate_output, last_layer_dropout=last_layer_dropout),
HypernetworkModule(size, sd[1], layer_structure, activation_func, weight_init, add_layer_norm,
use_dropout, activate_output, last_layer_dropout=last_layer_dropout),
)
print(f"hypernetwork loaded")
return layers
# overriding of original function
old_cross_attention_forward = optimizedSD.splitAttention.CrossAttention.forward
# hijacks the cross attention forward function to add hyper network support
def hijack_cross_attention():
print("hypernetwork functionality added to cross attention")
optimizedSD.splitAttention.CrossAttention.forward = new_cross_attention_forward
# there was a cop on board
def unhijack_cross_attention_forward():
print("hypernetwork functionality removed from cross attention")
optimizedSD.splitAttention.CrossAttention.forward = old_cross_attention_forward
hijack_cross_attention()
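A hedged sketch of how this module was wired at runtime, inferred only from the code above: `runtime.thread_data` is assumed to carry `device`, `hypernetwork`, and `hypernetwork_strength` (the file reads all three), and `hijack_cross_attention()` runs at import time. The module path and file name are illustrative:

```python
from sd_internal import runtime, hypernetwork  # module path assumed

runtime.thread_data.hypernetwork_strength = 0.8
runtime.thread_data.hypernetwork = hypernetwork.load_hypernetwork(
    "models/hypernetwork/example.pt")  # hypothetical path

# ... run inference as usual; new_cross_attention_forward() routes the k/v
# projections through the loaded HypernetworkModule pairs ...

runtime.thread_data.hypernetwork = None  # detach again when done
```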

File diff suppressed because it is too large

View File

@@ -1,500 +0,0 @@
"""server.py: FastAPI SD-UI Web Host.
Notes:
async endpoints always run on the main thread. Otherwise, they run on the thread pool.
"""
import json
import traceback
import sys
import os
import socket
import picklescan.scanner
import rich
SD_DIR = os.getcwd()
print('started in ', SD_DIR)
SD_UI_DIR = os.getenv('SD_UI_PATH', None)
sys.path.append(os.path.dirname(SD_UI_DIR))
CONFIG_DIR = os.path.abspath(os.path.join(SD_UI_DIR, '..', 'scripts'))
MODELS_DIR = os.path.abspath(os.path.join(SD_DIR, '..', 'models'))
USER_UI_PLUGINS_DIR = os.path.abspath(os.path.join(SD_DIR, '..', 'plugins', 'ui'))
CORE_UI_PLUGINS_DIR = os.path.abspath(os.path.join(SD_UI_DIR, 'plugins', 'ui'))
UI_PLUGINS_SOURCES = ((CORE_UI_PLUGINS_DIR, 'core'), (USER_UI_PLUGINS_DIR, 'user'))
STABLE_DIFFUSION_MODEL_EXTENSIONS = ['.ckpt', '.safetensors']
VAE_MODEL_EXTENSIONS = ['.vae.pt', '.ckpt']
HYPERNETWORK_MODEL_EXTENSIONS = ['.pt']
OUTPUT_DIRNAME = "Stable Diffusion UI" # in the user's home folder
TASK_TTL = 15 * 60 # Discard last session's task timeout
APP_CONFIG_DEFAULTS = {
# auto: selects the cuda device with the most free memory, cuda: use the currently active cuda device.
'render_devices': 'auto', # valid entries: 'auto', 'cpu' or 'cuda:N' (where N is a GPU index)
'update_branch': 'main',
'ui': {
'open_browser_on_start': True,
},
}
APP_CONFIG_DEFAULT_MODELS = [
# needed to support the legacy installations
'custom-model', # Check if user has a custom model, use it first.
'sd-v1-4', # Default fallback.
]
from fastapi import FastAPI, HTTPException
from fastapi.staticfiles import StaticFiles
from starlette.responses import FileResponse, JSONResponse, StreamingResponse
from pydantic import BaseModel
import logging
from typing import Any, Generator, Hashable, List, Optional, Union
from sd_internal import Request, Response, task_manager
app = FastAPI()
outpath = os.path.join(os.path.expanduser("~"), OUTPUT_DIRNAME)
os.makedirs(USER_UI_PLUGINS_DIR, exist_ok=True)
# don't show access log entries for URLs that start with the given prefix
ACCESS_LOG_SUPPRESS_PATH_PREFIXES = ['/ping', '/image', '/modifier-thumbnails']
NOCACHE_HEADERS={"Cache-Control": "no-cache, no-store, must-revalidate", "Pragma": "no-cache", "Expires": "0"}
class NoCacheStaticFiles(StaticFiles):
def is_not_modified(self, response_headers, request_headers) -> bool:
if 'content-type' in response_headers and ('javascript' in response_headers['content-type'] or 'css' in response_headers['content-type']):
response_headers.update(NOCACHE_HEADERS)
return False
return super().is_not_modified(response_headers, request_headers)
app.mount('/media', NoCacheStaticFiles(directory=os.path.join(SD_UI_DIR, 'media')), name="media")
for plugins_dir, dir_prefix in UI_PLUGINS_SOURCES:
app.mount(f'/plugins/{dir_prefix}', NoCacheStaticFiles(directory=plugins_dir), name=f"plugins-{dir_prefix}")
def getConfig(default_val=APP_CONFIG_DEFAULTS):
try:
config_json_path = os.path.join(CONFIG_DIR, 'config.json')
if not os.path.exists(config_json_path):
return default_val
with open(config_json_path, 'r', encoding='utf-8') as f:
config = json.load(f)
if 'net' not in config:
config['net'] = {}
if os.getenv('SD_UI_BIND_PORT') is not None:
config['net']['listen_port'] = int(os.getenv('SD_UI_BIND_PORT'))
if os.getenv('SD_UI_BIND_IP') is not None:
config['net']['listen_to_network'] = ( os.getenv('SD_UI_BIND_IP') == '0.0.0.0' )
return config
except Exception as e:
print(str(e))
print(traceback.format_exc())
return default_val
def setConfig(config):
print( json.dumps(config) )
try: # config.json
config_json_path = os.path.join(CONFIG_DIR, 'config.json')
with open(config_json_path, 'w', encoding='utf-8') as f:
json.dump(config, f)
except:
print(traceback.format_exc())
try: # config.bat
config_bat_path = os.path.join(CONFIG_DIR, 'config.bat')
config_bat = []
if 'update_branch' in config:
config_bat.append(f"@set update_branch={config['update_branch']}")
config_bat.append(f"@set SD_UI_BIND_PORT={config['net']['listen_port']}")
bind_ip = '0.0.0.0' if config['net']['listen_to_network'] else '127.0.0.1'
config_bat.append(f"@set SD_UI_BIND_IP={bind_ip}")
config_bat.append(f"@set test_sd2={'Y' if config.get('test_sd2', False) else 'N'}")
if len(config_bat) > 0:
with open(config_bat_path, 'w', encoding='utf-8') as f:
f.write('\r\n'.join(config_bat))
except:
print(traceback.format_exc())
try: # config.sh
config_sh_path = os.path.join(CONFIG_DIR, 'config.sh')
config_sh = ['#!/bin/bash']
if 'update_branch' in config:
config_sh.append(f"export update_branch={config['update_branch']}")
config_sh.append(f"export SD_UI_BIND_PORT={config['net']['listen_port']}")
bind_ip = '0.0.0.0' if config['net']['listen_to_network'] else '127.0.0.1'
config_sh.append(f"export SD_UI_BIND_IP={bind_ip}")
config_sh.append(f"export test_sd2=\"{'Y' if config.get('test_sd2', False) else 'N'}\"")
if len(config_sh) > 1:
with open(config_sh_path, 'w', encoding='utf-8') as f:
f.write('\n'.join(config_sh))
except:
print(traceback.format_exc())
def resolve_model_to_use(model_name:str, model_type:str, model_dir:str, model_extensions:list, default_models=[]):
config = getConfig()
model_dirs = [os.path.join(MODELS_DIR, model_dir), SD_DIR]
if not model_name: # When None try user configured model.
# config = getConfig()
if 'model' in config and model_type in config['model']:
model_name = config['model'][model_type]
if model_name:
is_sd2 = config.get('test_sd2', False)
if model_name.startswith('sd2_') and not is_sd2: # temp hack, until SD2 is unified with 1.4
print('ERROR: Cannot use SD 2.0 models with SD 1.0 code. Using the sd-v1-4 model instead!')
model_name = 'sd-v1-4'
# Check models directory
models_dir_path = os.path.join(MODELS_DIR, model_dir, model_name)
for model_extension in model_extensions:
if os.path.exists(models_dir_path + model_extension):
return models_dir_path
if os.path.exists(model_name + model_extension):
# Direct Path to file
model_name = os.path.abspath(model_name)
return model_name
# Default locations
if model_name in default_models:
default_model_path = os.path.join(SD_DIR, model_name)
for model_extension in model_extensions:
if os.path.exists(default_model_path + model_extension):
return default_model_path
# Can't find requested model, check the default paths.
for default_model in default_models:
for model_dir in model_dirs:
default_model_path = os.path.join(model_dir, default_model)
for model_extension in model_extensions:
if os.path.exists(default_model_path + model_extension):
if model_name is not None:
print(f'Could not find the configured custom model {model_name}{model_extension}. Using the default one: {default_model_path}{model_extension}')
return default_model_path
raise Exception('No valid models found.')
def resolve_ckpt_to_use(model_name:str=None):
return resolve_model_to_use(model_name, model_type='stable-diffusion', model_dir='stable-diffusion', model_extensions=STABLE_DIFFUSION_MODEL_EXTENSIONS, default_models=APP_CONFIG_DEFAULT_MODELS)
def resolve_vae_to_use(model_name:str=None):
try:
return resolve_model_to_use(model_name, model_type='vae', model_dir='vae', model_extensions=VAE_MODEL_EXTENSIONS, default_models=[])
except:
return None
def resolve_hypernetwork_to_use(model_name:str=None):
try:
return resolve_model_to_use(model_name, model_type='hypernetwork', model_dir='hypernetwork', model_extensions=HYPERNETWORK_MODEL_EXTENSIONS, default_models=[])
except:
return None
class SetAppConfigRequest(BaseModel):
update_branch: str = None
render_devices: Union[List[str], List[int], str, int] = None
model_vae: str = None
ui_open_browser_on_start: bool = None
listen_to_network: bool = None
listen_port: int = None
test_sd2: bool = None
@app.post('/app_config')
async def setAppConfig(req : SetAppConfigRequest):
config = getConfig()
if req.update_branch is not None:
config['update_branch'] = req.update_branch
if req.render_devices is not None:
update_render_devices_in_config(config, req.render_devices)
if req.ui_open_browser_on_start is not None:
if 'ui' not in config:
config['ui'] = {}
config['ui']['open_browser_on_start'] = req.ui_open_browser_on_start
if req.listen_to_network is not None:
if 'net' not in config:
config['net'] = {}
config['net']['listen_to_network'] = bool(req.listen_to_network)
if req.listen_port is not None:
if 'net' not in config:
config['net'] = {}
config['net']['listen_port'] = int(req.listen_port)
if req.test_sd2 is not None:
config['test_sd2'] = req.test_sd2
try:
setConfig(config)
if req.render_devices:
update_render_threads()
return JSONResponse({'status': 'OK'}, headers=NOCACHE_HEADERS)
except Exception as e:
print(traceback.format_exc())
raise HTTPException(status_code=500, detail=str(e))
def is_malicious_model(file_path):
try:
scan_result = picklescan.scanner.scan_file_path(file_path)
if scan_result.issues_count > 0 or scan_result.infected_files > 0:
rich.print(":warning: [bold red]Scan %s: %d scanned, %d issue, %d infected.[/bold red]" % (file_path, scan_result.scanned_files, scan_result.issues_count, scan_result.infected_files))
return True
else:
rich.print("Scan %s: [green]%d scanned, %d issue, %d infected.[/green]" % (file_path, scan_result.scanned_files, scan_result.issues_count, scan_result.infected_files))
return False
except Exception as e:
print('error while scanning', file_path, 'error:', e)
return False
known_models = {}
def getModels():
models = {
'active': {
'stable-diffusion': 'sd-v1-4',
'vae': '',
'hypernetwork': '',
},
'options': {
'stable-diffusion': ['sd-v1-4'],
'vae': [],
'hypernetwork': [],
},
}
def listModels(models_dirname, model_type, model_extensions):
models_dir = os.path.join(MODELS_DIR, models_dirname)
if not os.path.exists(models_dir):
os.makedirs(models_dir)
for file in os.listdir(models_dir):
for model_extension in model_extensions:
if not file.endswith(model_extension):
continue
model_path = os.path.join(models_dir, file)
mtime = os.path.getmtime(model_path)
mod_time = known_models[model_path] if model_path in known_models else -1
if mod_time != mtime:
if is_malicious_model(model_path):
models['scan-error'] = file
return
known_models[model_path] = mtime
model_name = file[:-len(model_extension)]
models['options'][model_type].append(model_name)
models['options'][model_type] = [*set(models['options'][model_type])] # remove duplicates
models['options'][model_type].sort()
# custom models
listModels(models_dirname='stable-diffusion', model_type='stable-diffusion', model_extensions=STABLE_DIFFUSION_MODEL_EXTENSIONS)
listModels(models_dirname='vae', model_type='vae', model_extensions=VAE_MODEL_EXTENSIONS)
listModels(models_dirname='hypernetwork', model_type='hypernetwork', model_extensions=HYPERNETWORK_MODEL_EXTENSIONS)
# legacy
custom_weight_path = os.path.join(SD_DIR, 'custom-model.ckpt')
if os.path.exists(custom_weight_path):
models['options']['stable-diffusion'].append('custom-model')
return models
def getUIPlugins():
plugins = []
for plugins_dir, dir_prefix in UI_PLUGINS_SOURCES:
for file in os.listdir(plugins_dir):
if file.endswith('.plugin.js'):
plugins.append(f'/plugins/{dir_prefix}/{file}')
return plugins
def getIPConfig():
try:
ips = socket.gethostbyname_ex(socket.gethostname())
ips[2].append(ips[0])
return ips[2]
except Exception as e:
print(e)
print(traceback.format_exc())
return []
@app.get('/get/{key:path}')
def read_web_data(key:str=None):
if not key: # /get without parameters, stable-diffusion easter egg.
raise HTTPException(status_code=418, detail="StableDiffusion is drawing a teapot!") # HTTP418 I'm a teapot
elif key == 'app_config':
config = getConfig(default_val=None)
if config is None:
config = APP_CONFIG_DEFAULTS
return JSONResponse(config, headers=NOCACHE_HEADERS)
elif key == 'system_info':
config = getConfig()
system_info = {
'devices': task_manager.get_devices(),
'hosts': getIPConfig(),
}
system_info['devices']['config'] = config.get('render_devices', "auto")
return JSONResponse(system_info, headers=NOCACHE_HEADERS)
elif key == 'models':
return JSONResponse(getModels(), headers=NOCACHE_HEADERS)
elif key == 'modifiers': return FileResponse(os.path.join(SD_UI_DIR, 'modifiers.json'), headers=NOCACHE_HEADERS)
elif key == 'output_dir': return JSONResponse({ 'output_dir': outpath }, headers=NOCACHE_HEADERS)
elif key == 'ui_plugins': return JSONResponse(getUIPlugins(), headers=NOCACHE_HEADERS)
else:
raise HTTPException(status_code=404, detail=f'Request for unknown {key}') # HTTP404 Not Found
@app.get('/ping') # Get server and optionally session status.
def ping(session_id:str=None):
if task_manager.is_alive() <= 0: # Check that render threads are alive.
if task_manager.current_state_error: raise HTTPException(status_code=500, detail=str(task_manager.current_state_error))
raise HTTPException(status_code=500, detail='Render thread is dead.')
if task_manager.current_state_error and not isinstance(task_manager.current_state_error, StopAsyncIteration): raise HTTPException(status_code=500, detail=str(task_manager.current_state_error))
# Alive
response = {'status': str(task_manager.current_state)}
if session_id:
session = task_manager.get_cached_session(session_id, update_ttl=True)
response['tasks'] = {id(t): t.status for t in session.tasks}
response['devices'] = task_manager.get_devices()
return JSONResponse(response, headers=NOCACHE_HEADERS)
def save_model_to_config(ckpt_model_name, vae_model_name, hypernetwork_model_name):
config = getConfig()
if 'model' not in config:
config['model'] = {}
config['model']['stable-diffusion'] = ckpt_model_name
config['model']['vae'] = vae_model_name
config['model']['hypernetwork'] = hypernetwork_model_name
if vae_model_name is None or vae_model_name == "":
del config['model']['vae']
if hypernetwork_model_name is None or hypernetwork_model_name == "":
del config['model']['hypernetwork']
setConfig(config)
def update_render_devices_in_config(config, render_devices):
if render_devices not in ('cpu', 'auto') and not render_devices.startswith('cuda:'):
raise HTTPException(status_code=400, detail=f'Invalid render device requested: {render_devices}')
if render_devices.startswith('cuda:'):
render_devices = render_devices.split(',')
config['render_devices'] = render_devices
@app.post('/render')
def render(req : task_manager.ImageRequest):
try:
save_model_to_config(req.use_stable_diffusion_model, req.use_vae_model, req.use_hypernetwork_model)
req.use_stable_diffusion_model = resolve_ckpt_to_use(req.use_stable_diffusion_model)
req.use_vae_model = resolve_vae_to_use(req.use_vae_model)
req.use_hypernetwork_model = resolve_hypernetwork_to_use(req.use_hypernetwork_model)
new_task = task_manager.render(req)
response = {
'status': str(task_manager.current_state),
'queue': len(task_manager.tasks_queue),
'stream': f'/image/stream/{id(new_task)}',
'task': id(new_task)
}
return JSONResponse(response, headers=NOCACHE_HEADERS)
except ChildProcessError as e: # Render thread is dead
raise HTTPException(status_code=500, detail='Rendering thread has died.') # HTTP500 Internal Server Error
except ConnectionRefusedError as e: # Unstarted task pending limit reached, deny queueing too many.
raise HTTPException(status_code=503, detail=str(e)) # HTTP503 Service Unavailable
except Exception as e:
print(e)
print(traceback.format_exc())
raise HTTPException(status_code=500, detail=str(e))
@app.get('/image/stream/{task_id:int}')
def stream(task_id:int):
#TODO Move to WebSockets ??
task = task_manager.get_cached_task(task_id, update_ttl=True)
if not task: raise HTTPException(status_code=404, detail=f'Request {task_id} not found.') # HTTP404 NotFound
#if (id(task) != task_id): raise HTTPException(status_code=409, detail=f'Wrong task id received. Expected:{id(task)}, Received:{task_id}') # HTTP409 Conflict
if task.buffer_queue.empty() and not task.lock.locked():
if task.response:
#print(f'Session {session_id} sending cached response')
return JSONResponse(task.response, headers=NOCACHE_HEADERS)
raise HTTPException(status_code=425, detail='Too Early, task not started yet.') # HTTP425 Too Early
#print(f'Session {session_id} opened live render stream {id(task.buffer_queue)}')
return StreamingResponse(task.read_buffer_generator(), media_type='application/json')
@app.get('/image/stop')
def stop(task: int):
if not task:
if task_manager.current_state == task_manager.ServerStates.Online or task_manager.current_state == task_manager.ServerStates.Unavailable:
raise HTTPException(status_code=409, detail='Not currently running any tasks.') # HTTP409 Conflict
task_manager.current_state_error = StopAsyncIteration('')
return {'OK'}
task_id = task
task = task_manager.get_cached_task(task_id, update_ttl=False)
if not task: raise HTTPException(status_code=404, detail=f'Task {task_id} was not found.') # HTTP404 Not Found
if isinstance(task.error, StopAsyncIteration): raise HTTPException(status_code=409, detail=f'Task {task_id} is already stopped.') # HTTP409 Conflict
task.error = StopAsyncIteration(f'Task {task_id} stop requested.')
return {'OK'}
@app.get('/image/tmp/{task_id:int}/{img_id:int}')
def get_image(task_id: int, img_id: int):
task = task_manager.get_cached_task(task_id, update_ttl=True)
if not task: raise HTTPException(status_code=410, detail=f'Task {task_id} could not be found.') # HTTP410 Gone
if not task.temp_images[img_id]: raise HTTPException(status_code=425, detail='Too Early, task data is not available yet.') # HTTP425 Too Early
try:
img_data = task.temp_images[img_id]
img_data.seek(0)
return StreamingResponse(img_data, media_type='image/jpeg')
except KeyError as e:
raise HTTPException(status_code=500, detail=str(e))
@app.get('/')
def read_root():
return FileResponse(os.path.join(SD_UI_DIR, 'index.html'), headers=NOCACHE_HEADERS)
@app.on_event("shutdown")
def shutdown_event(): # Signal render thread to close on shutdown
task_manager.current_state_error = SystemExit('Application shutting down.')
# don't log certain requests
class LogSuppressFilter(logging.Filter):
def filter(self, record: logging.LogRecord) -> bool:
path = record.getMessage()
for prefix in ACCESS_LOG_SUPPRESS_PATH_PREFIXES:
if path.find(prefix) != -1:
return False
return True
logging.getLogger('uvicorn.access').addFilter(LogSuppressFilter())
# Check models and prepare cache for UI open
getModels()
# Start the task_manager
task_manager.default_model_to_load = resolve_ckpt_to_use()
task_manager.default_vae_to_load = resolve_vae_to_use()
task_manager.default_hypernetwork_to_load = resolve_hypernetwork_to_use()
def update_render_threads():
config = getConfig()
render_devices = config.get('render_devices', 'auto')
active_devices = task_manager.get_devices()['active'].keys()
print('requesting for render_devices', render_devices)
task_manager.update_render_threads(render_devices, active_devices)
update_render_threads()
# start the browser ui
def open_browser():
config = getConfig()
ui = config.get('ui', {})
net = config.get('net', {'listen_port':9000})
port = net.get('listen_port', 9000)
if ui.get('open_browser_on_start', True):
import webbrowser; webbrowser.open(f"http://localhost:{port}")
open_browser()
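For context, the deleted server exposed a small task API. A hedged client sketch using only the endpoints shown above (`POST /render`, then streaming `GET /image/stream/{task}`); the request fields follow the UI's `reqBody` from `getCurrentUserRequest`, and the full `ImageRequest` schema lived in `task_manager`, which is not shown here:

```python
import requests  # assumed installed; any HTTP client works

base = "http://localhost:9000"  # default listen_port

req = {
    "prompt": "a photograph of an astronaut riding a horse",
    "width": 512, "height": 512,
    "num_inference_steps": 25,
    "guidance_scale": 7.5,
    "stream_progress_updates": True,
}

task = requests.post(f"{base}/render", json=req).json()
print("task id:", task["task"], "queue length:", task["queue"])

# /image/stream/{id} streams JSON progress chunks until the task completes
with requests.get(base + task["stream"], stream=True) as r:
    for chunk in r.iter_content(chunk_size=None):
        print(chunk.decode("utf-8", errors="ignore"))
```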