Merge pull request #695 from cmdr2/refactor

v2.5 - move to sdkit
This commit is contained in:
cmdr2 2022-12-24 23:28:10 +05:30 committed by GitHub
commit 206f9b97bb
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23
33 changed files with 1306 additions and 2734 deletions

View File

@ -1,5 +1,22 @@
# What's new?
## v2.5
### Major Changes
- **Nearly twice as fast** - significantly faster speed of image generation. We're now pretty close to automatic1111's speed. Code contributions are welcome to make our project even faster: https://github.com/easydiffusion/sdkit/#is-it-fast
- **Full support for Stable Diffusion 2.1** - supports loading v1.4 or v2.0 or v2.1 models seamlessly. No need to enable "Test SD2", and no need to add `sd2_` to your SD 2.0 model file names.
- **Memory optimized Stable Diffusion 2.1** - you can now use 768x768 models for SD 2.1, with the same low VRAM optimizations that we've always had for SD 1.4.
- **6 new samplers!** - explore the new samplers, some of which can generate great images in less than 10 inference steps!
- **Model Merging** - You can now merge two models (`.ckpt` or `.safetensors`) and output `.ckpt` or `.safetensors` models, optionally in `fp16` precision. Details: https://github.com/cmdr2/stable-diffusion-ui/wiki/Model-Merging
- **Fast loading/unloading of VAEs** - No longer needs to reload the entire Stable Diffusion model, each time you change the VAE
- **Database of known models** - automatically picks the right configuration for known models. E.g. we automatically detect and apply "v" parameterization (required for some SD 2.0 models), and "fp32" attention precision (required for some SD 2.1 models).
- **Color correction for img2img** - an option to preserve the color profile (histogram) of the initial image. This is especially useful if you're getting red-tinted images after inpainting/masking.
- **Three GPU Memory Usage Settings** - `High` (fastest, maximum VRAM usage), `Balanced` (default - almost as fast, significantly lower VRAM usage), `Low` (slowest, very low VRAM usage). The `Low` setting is applied automatically for GPUs with less than 4 GB of VRAM.
- **Save metadata as JSON** - You can now save the metadata files as either text or json files (choose in the Settings tab).
- **Major rewrite of the code** - Most of the codebase has been reorganized and rewritten, to make it more manageable and easier for new developers to contribute features. We've separated our core engine into a new project called `sdkit`, which allows anyone to easily integrate Stable Diffusion (and related modules like GFPGAN etc) into their programming projects (via a simple `pip install sdkit`): https://github.com/easydiffusion/sdkit/
- **Name change** - Last, and probably the least, the UI is now called "Easy Diffusion". It indicates the focus of this project - an easy way for people to play with Stable Diffusion.
Our focus continues to remain on an easy installation experience, and an easy user-interface. While still remaining pretty powerful, in terms of features and speed.
## v2.4
### Major Changes
- **Allow reordering the task queue** (by dragging and dropping tasks). Thanks @madrang

View File

@ -23,23 +23,20 @@ call conda --version
echo.
@rem activate the environment
call conda activate .\stable-diffusion\env
@rem activate the legacy environment (if present) and set PYTHONPATH
if exist "installer_files\env" (
set PYTHONPATH=%cd%\installer_files\env\lib\site-packages
)
if exist "stable-diffusion\env" (
call conda activate .\stable-diffusion\env
set PYTHONPATH=%cd%\stable-diffusion\env\lib\site-packages
)
call where python
call python --version
@rem set the PYTHONPATH
cd stable-diffusion
set SD_DIR=%cd%
cd env\lib\site-packages
set PYTHONPATH=%SD_DIR%;%cd%
cd ..\..\..
echo PYTHONPATH=%PYTHONPATH%
cd ..
@rem done
echo.

View File

@ -24,7 +24,7 @@ if exist "%INSTALL_ENV_DIR%" set PATH=%INSTALL_ENV_DIR%;%INSTALL_ENV_DIR%\Librar
set PACKAGES_TO_INSTALL=
if not exist "%LEGACY_INSTALL_ENV_DIR%\etc\profile.d\conda.sh" (
if not exist "%INSTALL_ENV_DIR%\etc\profile.d\conda.sh" set PACKAGES_TO_INSTALL=%PACKAGES_TO_INSTALL% conda
if not exist "%INSTALL_ENV_DIR%\etc\profile.d\conda.sh" set PACKAGES_TO_INSTALL=%PACKAGES_TO_INSTALL% conda python=3.8.5
)
call git --version >.tmp1 2>.tmp2

View File

@ -39,7 +39,7 @@ if [ -e "$INSTALL_ENV_DIR" ]; then export PATH="$INSTALL_ENV_DIR/bin:$PATH"; fi
PACKAGES_TO_INSTALL=""
if [ ! -e "$LEGACY_INSTALL_ENV_DIR/etc/profile.d/conda.sh" ] && [ ! -e "$INSTALL_ENV_DIR/etc/profile.d/conda.sh" ]; then PACKAGES_TO_INSTALL="$PACKAGES_TO_INSTALL conda"; fi
if [ ! -e "$LEGACY_INSTALL_ENV_DIR/etc/profile.d/conda.sh" ] && [ ! -e "$INSTALL_ENV_DIR/etc/profile.d/conda.sh" ]; then PACKAGES_TO_INSTALL="$PACKAGES_TO_INSTALL conda python=3.8.5"; fi
if ! hash "git" &>/dev/null; then PACKAGES_TO_INSTALL="$PACKAGES_TO_INSTALL git"; fi
if "$MAMBA_ROOT_PREFIX/micromamba" --version &>/dev/null; then umamba_exists="T"; fi

13
scripts/check_modules.py Normal file
View File

@ -0,0 +1,13 @@
'''
This script checks if the given modules exist
'''
import sys
import pkgutil
modules = sys.argv[1:]
missing_modules = []
for m in modules:
if pkgutil.find_loader(m) is None:
print('module', m, 'not found')
exit(1)

View File

@ -26,21 +26,23 @@ if [ "$0" == "bash" ]; then
echo ""
# activate the environment
# activate the legacy environment (if present) and set PYTHONPATH
if [ -e "installer_files/env" ]; then
export PYTHONPATH="$(pwd)/installer_files/env/lib/python3.8/site-packages"
fi
if [ -e "stable-diffusion/env" ]; then
CONDA_BASEPATH=$(conda info --base)
source "$CONDA_BASEPATH/etc/profile.d/conda.sh" # otherwise conda complains about 'shell not initialized' (needed when running in a script)
conda activate ./stable-diffusion/env
export PYTHONPATH="$(pwd)/stable-diffusion/env/lib/python3.8/site-packages"
fi
which python
python --version
# set the PYTHONPATH
cd stable-diffusion
SD_PATH=`pwd`
export PYTHONPATH="$SD_PATH:$SD_PATH/env/lib/python3.8/site-packages"
echo "PYTHONPATH=$PYTHONPATH"
cd ..
# done

View File

@ -53,6 +53,7 @@ if "%update_branch%"=="" (
@xcopy sd-ui-files\ui ui /s /i /Y /q
@copy sd-ui-files\scripts\on_sd_start.bat scripts\ /Y
@copy sd-ui-files\scripts\bootstrap.bat scripts\ /Y
@copy sd-ui-files\scripts\check_modules.py scripts\ /Y
@copy "sd-ui-files\scripts\Start Stable Diffusion UI.cmd" . /Y
@copy "sd-ui-files\scripts\Developer Console.cmd" . /Y

View File

@ -37,6 +37,7 @@ rm -rf ui
cp -Rf sd-ui-files/ui .
cp sd-ui-files/scripts/on_sd_start.sh scripts/
cp sd-ui-files/scripts/bootstrap.sh scripts/
cp sd-ui-files/scripts/check_modules.py scripts/
cp sd-ui-files/scripts/start.sh .
cp sd-ui-files/scripts/developer_console.sh .

View File

@ -5,11 +5,20 @@
@copy sd-ui-files\scripts\on_env_start.bat scripts\ /Y
@copy sd-ui-files\scripts\bootstrap.bat scripts\ /Y
@copy sd-ui-files\scripts\check_modules.py scripts\ /Y
if exist "%cd%\profile" (
set USERPROFILE=%cd%\profile
)
@rem set the correct installer path (current vs legacy)
if exist "%cd%\installer_files\env" (
set INSTALL_ENV_DIR=%cd%\installer_files\env
)
if exist "%cd%\stable-diffusion\env" (
set INSTALL_ENV_DIR=%cd%\stable-diffusion\env
)
@mkdir tmp
@set TMP=%cd%\tmp
@set TEMP=%cd%\tmp
@ -27,150 +36,92 @@ if exist "Open Developer Console.cmd" del "Open Developer Console.cmd"
@call python -c "import os; import shutil; frm = 'sd-ui-files\\ui\\hotfix\\9c24e6cd9f499d02c4f21a033736dabd365962dc80fe3aeb57a8f85ea45a20a3.26fead7ea4f0f843f6eb4055dfd25693f1a71f3c6871b184042d4b126244e142'; dst = os.path.join(os.path.expanduser('~'), '.cache', 'huggingface', 'transformers', '9c24e6cd9f499d02c4f21a033736dabd365962dc80fe3aeb57a8f85ea45a20a3.26fead7ea4f0f843f6eb4055dfd25693f1a71f3c6871b184042d4b126244e142'); shutil.copyfile(frm, dst) if os.path.exists(dst) else print(''); print('Hotfixed broken JSON file from OpenAI');"
if NOT DEFINED test_sd2 set test_sd2=N
@rem create the stable-diffusion folder, to work with legacy installations
if not exist "stable-diffusion" mkdir stable-diffusion
cd stable-diffusion
@>nul findstr /m "sd_git_cloned" scripts\install_status.txt
@if "%ERRORLEVEL%" EQU "0" (
@echo "Stable Diffusion's git repository was already installed. Updating.."
@cd stable-diffusion
@call git remote set-url origin https://github.com/easydiffusion/diffusion-kit.git
@call git reset --hard
@call git pull
if "%test_sd2%" == "N" (
@call git -c advice.detachedHead=false checkout 7f32368ed1030a6e710537047bacd908adea183a
)
if "%test_sd2%" == "Y" (
@call git -c advice.detachedHead=false checkout 733a1f6f9cae9b9a9b83294bf3281b123378cb1f
)
@cd ..
) else (
@echo. & echo "Downloading Stable Diffusion.." & echo.
@call git clone https://github.com/easydiffusion/diffusion-kit.git stable-diffusion && (
@echo sd_git_cloned >> scripts\install_status.txt
) || (
@echo "Error downloading Stable Diffusion. Sorry about that, please try to:" & echo " 1. Run this installer again." & echo " 2. If that doesn't fix it, please try the common troubleshooting steps at https://github.com/cmdr2/stable-diffusion-ui/wiki/Troubleshooting" & echo " 3. If those steps don't help, please copy *all* the error messages in this window, and ask the community at https://discord.com/invite/u9yhsFmEkB" & echo " 4. If that doesn't solve the problem, please file an issue at https://github.com/cmdr2/stable-diffusion-ui/issues" & echo "Thanks!"
pause
@exit /b
)
@cd stable-diffusion
@call git -c advice.detachedHead=false checkout 7f32368ed1030a6e710537047bacd908adea183a
@cd ..
@rem activate the old stable-diffusion env, if it exists
if exist "env" (
call conda activate .\env
)
@cd stable-diffusion
@rem disable the legacy src and ldm folder (otherwise this prevents installing gfpgan and realesrgan)
if exist src rename src src-old
if exist ldm rename ldm ldm-old
@>nul findstr /m "conda_sd_env_created" ..\scripts\install_status.txt
@if "%ERRORLEVEL%" EQU "0" (
@echo "Packages necessary for Stable Diffusion were already installed"
@call conda activate .\env
@rem install torch and torchvision
call python ..\scripts\check_modules.py torch torchvision
if "%ERRORLEVEL%" EQU "0" (
echo "torch and torchvision have already been installed."
) else (
@echo. & echo "Downloading packages necessary for Stable Diffusion.." & echo. & echo "***** This will take some time (depending on the speed of the Internet connection) and may appear to be stuck, but please be patient ***** .." & echo.
echo "Installing torch and torchvision.."
@rmdir /s /q .\env
@REM prevent from using packages from the user's home directory, to avoid conflicts
set PYTHONNOUSERSITE=1
set PYTHONPATH=%INSTALL_ENV_DIR%\lib\site-packages
@REM prevent conda from using packages from the user's home directory, to avoid conflicts
@set PYTHONNOUSERSITE=1
set USERPROFILE=%cd%\profile
set PYTHONPATH=%cd%;%cd%\env\lib\site-packages
@call conda env create --prefix env -f environment.yaml || (
@echo. & echo "Error installing the packages necessary for Stable Diffusion. Sorry about that, please try to:" & echo " 1. Run this installer again." & echo " 2. If that doesn't fix it, please try the common troubleshooting steps at https://github.com/cmdr2/stable-diffusion-ui/wiki/Troubleshooting" & echo " 3. If those steps don't help, please copy *all* the error messages in this window, and ask the community at https://discord.com/invite/u9yhsFmEkB" & echo " 4. If that doesn't solve the problem, please file an issue at https://github.com/cmdr2/stable-diffusion-ui/issues" & echo "Thanks!" & echo.
call pip install --upgrade torch torchvision --extra-index-url https://download.pytorch.org/whl/cu116 || (
echo "Error installing torch. Sorry about that, please try to:" & echo " 1. Run this installer again." & echo " 2. If that doesn't fix it, please try the common troubleshooting steps at https://github.com/cmdr2/stable-diffusion-ui/wiki/Troubleshooting" & echo " 3. If those steps don't help, please copy *all* the error messages in this window, and ask the community at https://discord.com/invite/u9yhsFmEkB" & echo " 4. If that doesn't solve the problem, please file an issue at https://github.com/cmdr2/stable-diffusion-ui/issues" & echo "Thanks!"
pause
exit /b
)
)
@call conda activate .\env
@rem install/upgrade sdkit
call python ..\scripts\check_modules.py sdkit sdkit.models ldm transformers numpy antlr4 gfpgan realesrgan
if "%ERRORLEVEL%" EQU "0" (
echo "sdkit is already installed."
for /f "tokens=*" %%a in ('python -c "import torch; import ldm; import transformers; import numpy; import antlr4; print(42)"') do if "%%a" NEQ "42" (
@echo. & echo "Dependency test failed! Error installing the packages necessary for Stable Diffusion. Sorry about that, please try to:" & echo " 1. Run this installer again." & echo " 2. If that doesn't fix it, please try the common troubleshooting steps at https://github.com/cmdr2/stable-diffusion-ui/wiki/Troubleshooting" & echo " 3. If those steps don't help, please copy *all* the error messages in this window, and ask the community at https://discord.com/invite/u9yhsFmEkB" & echo " 4. If that doesn't solve the problem, please file an issue at https://github.com/cmdr2/stable-diffusion-ui/issues" & echo "Thanks!" & echo.
@REM prevent from using packages from the user's home directory, to avoid conflicts
set PYTHONNOUSERSITE=1
set PYTHONPATH=%INSTALL_ENV_DIR%\lib\site-packages
call >nul pip install --upgrade sdkit || (
echo "Error updating sdkit"
)
) else (
echo "Installing sdkit: https://pypi.org/project/sdkit/"
@REM prevent from using packages from the user's home directory, to avoid conflicts
set PYTHONNOUSERSITE=1
set PYTHONPATH=%INSTALL_ENV_DIR%\lib\site-packages
call pip install sdkit || (
echo "Error installing sdkit. Sorry about that, please try to:" & echo " 1. Run this installer again." & echo " 2. If that doesn't fix it, please try the common troubleshooting steps at https://github.com/cmdr2/stable-diffusion-ui/wiki/Troubleshooting" & echo " 3. If those steps don't help, please copy *all* the error messages in this window, and ask the community at https://discord.com/invite/u9yhsFmEkB" & echo " 4. If that doesn't solve the problem, please file an issue at https://github.com/cmdr2/stable-diffusion-ui/issues" & echo "Thanks!"
pause
exit /b
)
@echo conda_sd_env_created >> ..\scripts\install_status.txt
)
@rem allow rolling back the sdkit-based changes
if exist "src-old" (
if not exist "src" (
rename "src-old" "src"
@rem install rich
call python ..\scripts\check_modules.py rich
if "%ERRORLEVEL%" EQU "0" (
echo "rich has already been installed."
) else (
echo "Installing rich.."
if exist "ldm-old" (
rd /s /q "ldm-old"
)
set PYTHONNOUSERSITE=1
set PYTHONPATH=%INSTALL_ENV_DIR%\lib\site-packages
call pip uninstall -y sdkit stable-diffusion-sdkit
call pip install rich || (
echo "Error installing rich. Sorry about that, please try to:" & echo " 1. Run this installer again." & echo " 2. If that doesn't fix it, please try the common troubleshooting steps at https://github.com/cmdr2/stable-diffusion-ui/wiki/Troubleshooting" & echo " 3. If those steps don't help, please copy *all* the error messages in this window, and ask the community at https://discord.com/invite/u9yhsFmEkB" & echo " 4. If that doesn't solve the problem, please file an issue at https://github.com/cmdr2/stable-diffusion-ui/issues" & echo "Thanks!"
pause
exit /b
)
)
set PATH=C:\Windows\System32;%PATH%
@>nul findstr /m "conda_sd_gfpgan_deps_installed" ..\scripts\install_status.txt
@if "%ERRORLEVEL%" EQU "0" (
@echo "Packages necessary for GFPGAN (Face Correction) were already installed"
) else (
@echo. & echo "Downloading packages necessary for GFPGAN (Face Correction).." & echo.
@set PYTHONNOUSERSITE=1
set USERPROFILE=%cd%\profile
set PYTHONPATH=%cd%;%cd%\env\lib\site-packages
for /f "tokens=*" %%a in ('python -c "from gfpgan import GFPGANer; print(42)"') do if "%%a" NEQ "42" (
@echo. & echo "Dependency test failed! Error installing the packages necessary for GFPGAN (Face Correction). Sorry about that, please try to:" & echo " 1. Run this installer again." & echo " 2. If that doesn't fix it, please try the common troubleshooting steps at https://github.com/cmdr2/stable-diffusion-ui/wiki/Troubleshooting" & echo " 3. If those steps don't help, please copy *all* the error messages in this window, and ask the community at https://discord.com/invite/u9yhsFmEkB" & echo " 4. If that doesn't solve the problem, please file an issue at https://github.com/cmdr2/stable-diffusion-ui/issues" & echo "Thanks!" & echo.
pause
exit /b
)
@echo conda_sd_gfpgan_deps_installed >> ..\scripts\install_status.txt
)
@>nul findstr /m "conda_sd_esrgan_deps_installed" ..\scripts\install_status.txt
@if "%ERRORLEVEL%" EQU "0" (
@echo "Packages necessary for ESRGAN (Resolution Upscaling) were already installed"
) else (
@echo. & echo "Downloading packages necessary for ESRGAN (Resolution Upscaling).." & echo.
@set PYTHONNOUSERSITE=1
set USERPROFILE=%cd%\profile
set PYTHONPATH=%cd%;%cd%\env\lib\site-packages
for /f "tokens=*" %%a in ('python -c "from basicsr.archs.rrdbnet_arch import RRDBNet; from realesrgan import RealESRGANer; print(42)"') do if "%%a" NEQ "42" (
@echo. & echo "Dependency test failed! Error installing the packages necessary for ESRGAN (Resolution Upscaling). Sorry about that, please try to:" & echo " 1. Run this installer again." & echo " 2. If that doesn't fix it, please try the common troubleshooting steps at https://github.com/cmdr2/stable-diffusion-ui/wiki/Troubleshooting" & echo " 3. If those steps don't help, please copy *all* the error messages in this window, and ask the community at https://discord.com/invite/u9yhsFmEkB" & echo " 4. If that doesn't solve the problem, please file an issue at https://github.com/cmdr2/stable-diffusion-ui/issues" & echo "Thanks!" & echo.
pause
exit /b
)
@echo conda_sd_esrgan_deps_installed >> ..\scripts\install_status.txt
)
@>nul findstr /m "conda_sd_ui_deps_installed" ..\scripts\install_status.txt
call python ..\scripts\check_modules.py uvicorn fastapi
@if "%ERRORLEVEL%" EQU "0" (
echo "Packages necessary for Stable Diffusion UI were already installed"
) else (
@echo. & echo "Downloading packages necessary for Stable Diffusion UI.." & echo.
@set PYTHONNOUSERSITE=1
set PYTHONNOUSERSITE=1
set PYTHONPATH=%INSTALL_ENV_DIR%\lib\site-packages
set USERPROFILE=%cd%\profile
set PYTHONPATH=%cd%;%cd%\env\lib\site-packages
@call conda install -c conda-forge -y --prefix env uvicorn fastapi || (
@call conda install -c conda-forge -y uvicorn fastapi || (
echo "Error installing the packages necessary for Stable Diffusion UI. Sorry about that, please try to:" & echo " 1. Run this installer again." & echo " 2. If that doesn't fix it, please try the common troubleshooting steps at https://github.com/cmdr2/stable-diffusion-ui/wiki/Troubleshooting" & echo " 3. If those steps don't help, please copy *all* the error messages in this window, and ask the community at https://discord.com/invite/u9yhsFmEkB" & echo " 4. If that doesn't solve the problem, please file an issue at https://github.com/cmdr2/stable-diffusion-ui/issues" & echo "Thanks!"
pause
exit /b
@ -185,26 +136,6 @@ call WHERE uvicorn > .tmp
exit /b
)
@>nul 2>nul call python -m picklescan --help
@if "%ERRORLEVEL%" NEQ "0" (
@echo. & echo Picklescan not found. Installing
@call pip install picklescan || (
echo "Error installing the picklescan package necessary for Stable Diffusion UI. Sorry about that, please try to:" & echo " 1. Run this installer again." & echo " 2. If that doesn't fix it, please try the common troubleshooting steps at https://github.com/cmdr2/stable-diffusion-ui/wiki/Troubleshooting" & echo " 3. If those steps don't help, please copy *all* the error messages in this window, and ask the community at https://discord.com/invite/u9yhsFmEkB" & echo " 4. If that doesn't solve the problem, please file an issue at https://github.com/cmdr2/stable-diffusion-ui/issues" & echo "Thanks!"
pause
exit /b
)
)
@>nul 2>nul call python -c "import safetensors"
@if "%ERRORLEVEL%" NEQ "0" (
@echo. & echo SafeTensors not found. Installing
@call pip install safetensors || (
echo "Error installing the safetensors package necessary for Stable Diffusion UI. Sorry about that, please try to:" & echo " 1. Run this installer again." & echo " 2. If that doesn't fix it, please try the common troubleshooting steps at https://github.com/cmdr2/stable-diffusion-ui/wiki/Troubleshooting" & echo " 3. If those steps don't help, please copy *all* the error messages in this window, and ask the community at https://discord.com/invite/u9yhsFmEkB" & echo " 4. If that doesn't solve the problem, please file an issue at https://github.com/cmdr2/stable-diffusion-ui/issues" & echo "Thanks!"
pause
exit /b
)
)
@>nul findstr /m "conda_sd_ui_deps_installed" ..\scripts\install_status.txt
@if "%ERRORLEVEL%" NEQ "0" (
@echo conda_sd_ui_deps_installed >> ..\scripts\install_status.txt
@ -212,12 +143,7 @@ call WHERE uvicorn > .tmp
if not exist "..\models\stable-diffusion" mkdir "..\models\stable-diffusion"
if not exist "..\models\vae" mkdir "..\models\vae"
if not exist "..\models\hypernetwork" mkdir "..\models\hypernetwork"
echo. > "..\models\stable-diffusion\Put your custom ckpt files here.txt"
echo. > "..\models\vae\Put your VAE files here.txt"
echo. > "..\models\hypernetwork\Put your hypernetwork files here.txt"
@if exist "sd-v1-4.ckpt" (
for %%I in ("sd-v1-4.ckpt") do if "%%~zI" EQU "4265380512" (
@ -375,10 +301,6 @@ echo. > "..\models\hypernetwork\Put your hypernetwork files here.txt"
)
)
if "%test_sd2%" == "Y" (
@call pip install open_clip_torch==2.0.2
)
@>nul findstr /m "sd_install_complete" ..\scripts\install_status.txt
@if "%ERRORLEVEL%" NEQ "0" (
@echo sd_weights_downloaded >> ..\scripts\install_status.txt
@ -389,10 +311,8 @@ if "%test_sd2%" == "Y" (
@set SD_DIR=%cd%
@cd env\lib\site-packages
@set PYTHONPATH=%SD_DIR%;%cd%
@cd ..\..\..
@echo PYTHONPATH=%PYTHONPATH%
set PYTHONPATH=%INSTALL_ENV_DIR%\lib\site-packages
echo PYTHONPATH=%PYTHONPATH%
call where python
call python --version
@ -401,17 +321,9 @@ call python --version
@set SD_UI_PATH=%cd%\ui
@cd stable-diffusion
@rem
@rem Rewrite easy-install.pth. This fixes the installation if the user has relocated the SDUI installation
@rem
>env\Lib\site-packages\easy-install.pth echo %cd%\src\taming-transformers
>>env\Lib\site-packages\easy-install.pth echo %cd%\src\clip
>>env\Lib\site-packages\easy-install.pth echo %cd%\src\gfpgan
>>env\Lib\site-packages\easy-install.pth echo %cd%\src\realesrgan
@if NOT DEFINED SD_UI_BIND_PORT set SD_UI_BIND_PORT=9000
@if NOT DEFINED SD_UI_BIND_IP set SD_UI_BIND_IP=0.0.0.0
@uvicorn server:app --app-dir "%SD_UI_PATH%" --port %SD_UI_BIND_PORT% --host %SD_UI_BIND_IP%
@uvicorn main:server_api --app-dir "%SD_UI_PATH%" --port %SD_UI_BIND_PORT% --host %SD_UI_BIND_IP% --log-level error
@pause

View File

@ -4,6 +4,7 @@ source ./scripts/functions.sh
cp sd-ui-files/scripts/on_env_start.sh scripts/
cp sd-ui-files/scripts/bootstrap.sh scripts/
cp sd-ui-files/scripts/check_modules.py scripts/
# activate the installer env
CONDA_BASEPATH=$(conda info --base)
@ -21,125 +22,89 @@ python -c "import os; import shutil; frm = 'sd-ui-files/ui/hotfix/9c24e6cd9f499d
# Caution, this file will make your eyes and brain bleed. It's such an unholy mess.
# Note to self: Please rewrite this in Python. For the sake of your own sanity.
if [ "$test_sd2" == "" ]; then
export test_sd2="N"
fi
if [ -e "scripts/install_status.txt" ] && [ `grep -c sd_git_cloned scripts/install_status.txt` -gt "0" ]; then
echo "Stable Diffusion's git repository was already installed. Updating.."
cd stable-diffusion
git remote set-url origin https://github.com/easydiffusion/diffusion-kit.git
git reset --hard
git pull
if [ "$test_sd2" == "N" ]; then
git -c advice.detachedHead=false checkout 7f32368ed1030a6e710537047bacd908adea183a
elif [ "$test_sd2" == "Y" ]; then
git -c advice.detachedHead=false checkout 733a1f6f9cae9b9a9b83294bf3281b123378cb1f
fi
cd ..
else
printf "\n\nDownloading Stable Diffusion..\n\n"
if git clone https://github.com/easydiffusion/diffusion-kit.git stable-diffusion ; then
echo sd_git_cloned >> scripts/install_status.txt
else
fail "git clone of basujindal/stable-diffusion.git failed"
fi
cd stable-diffusion
git -c advice.detachedHead=false checkout 7f32368ed1030a6e710537047bacd908adea183a
cd ..
# set the correct installer path (current vs legacy)
if [ -e "installer_files/env" ]; then
export INSTALL_ENV_DIR="$(pwd)/installer_files/env"
fi
if [ -e "stable-diffusion/env" ]; then
export INSTALL_ENV_DIR="$(pwd)/stable-diffusion/env"
fi
# create the stable-diffusion folder, to work with legacy installations
if [ ! -e "stable-diffusion" ]; then mkdir stable-diffusion; fi
cd stable-diffusion
if [ `grep -c conda_sd_env_created ../scripts/install_status.txt` -gt "0" ]; then
echo "Packages necessary for Stable Diffusion were already installed"
# activate the old stable-diffusion env, if it exists
if [ -e "env" ]; then
conda activate ./env || fail "conda activate failed"
fi
# disable the legacy src and ldm folder (otherwise this prevents installing gfpgan and realesrgan)
if [ -e "src" ]; then mv src src-old; fi
if [ -e "ldm" ]; then mv ldm ldm-old; fi
# install torch and torchvision
if python ../scripts/check_modules.py torch torchvision; then
echo "torch and torchvision have already been installed."
else
printf "\n\nDownloading packages necessary for Stable Diffusion..\n"
printf "\n\n***** This will take some time (depending on the speed of the Internet connection) and may appear to be stuck, but please be patient ***** ..\n\n"
echo "Installing torch and torchvision.."
# prevent conda from using packages from the user's home directory, to avoid conflicts
export PYTHONNOUSERSITE=1
export PYTHONPATH="$(pwd):$(pwd)/env/lib/site-packages"
export PYTHONPATH="$INSTALL_ENV_DIR/lib/python3.8/site-packages"
if conda env create --prefix env --force -f environment.yaml ; then
echo "Installed. Testing.."
if pip install --upgrade torch torchvision --extra-index-url https://download.pytorch.org/whl/cu116 ; then
echo "Installed."
else
fail "'conda env create' failed"
fail "torch install failed"
fi
conda activate ./env || fail "conda activate failed"
out_test=`python -c "import torch; import ldm; import transformers; import numpy; import antlr4; print(42)"`
if [ "$out_test" != "42" ]; then
fail "Dependency test failed"
fi
echo conda_sd_env_created >> ../scripts/install_status.txt
fi
# allow rolling back the sdkit-based changes
if [ -e "src-old" ] && [ ! -e "src" ]; then
mv src-old src
if [ -e "ldm-old" ]; then rm -r ldm-old; fi
pip uninstall -y sdkit stable-diffusion-sdkit
fi
if [ `grep -c conda_sd_gfpgan_deps_installed ../scripts/install_status.txt` -gt "0" ]; then
echo "Packages necessary for GFPGAN (Face Correction) were already installed"
else
printf "\n\nDownloading packages necessary for GFPGAN (Face Correction)..\n"
# install/upgrade sdkit
if python ../scripts/check_modules.py sdkit sdkit.models ldm transformers numpy antlr4 gfpgan realesrgan ; then
echo "sdkit is already installed."
export PYTHONNOUSERSITE=1
export PYTHONPATH="$(pwd):$(pwd)/env/lib/site-packages"
export PYTHONPATH="$INSTALL_ENV_DIR/lib/python3.8/site-packages"
out_test=`python -c "from gfpgan import GFPGANer; print(42)"`
if [ "$out_test" != "42" ]; then
echo "EE The dependency check has failed. This usually means that some system libraries are missing."
echo "EE On Debian/Ubuntu systems, this are often these packages: libsm6 libxext6 libxrender-dev"
echo "EE Other Linux distributions might have different package names for these libraries."
fail "GFPGAN dependency test failed"
fi
echo conda_sd_gfpgan_deps_installed >> ../scripts/install_status.txt
fi
if [ `grep -c conda_sd_esrgan_deps_installed ../scripts/install_status.txt` -gt "0" ]; then
echo "Packages necessary for ESRGAN (Resolution Upscaling) were already installed"
pip install --upgrade sdkit > /dev/null
else
printf "\n\nDownloading packages necessary for ESRGAN (Resolution Upscaling)..\n"
echo "Installing sdkit: https://pypi.org/project/sdkit/"
export PYTHONNOUSERSITE=1
export PYTHONPATH="$(pwd):$(pwd)/env/lib/site-packages"
export PYTHONPATH="$INSTALL_ENV_DIR/lib/python3.8/site-packages"
out_test=`python -c "from basicsr.archs.rrdbnet_arch import RRDBNet; from realesrgan import RealESRGANer; print(42)"`
if [ "$out_test" != "42" ]; then
fail "ESRGAN dependency test failed"
if pip install sdkit ; then
echo "Installed."
else
fail "sdkit install failed"
fi
echo conda_sd_esrgan_deps_installed >> ../scripts/install_status.txt
fi
if [ `grep -c conda_sd_ui_deps_installed ../scripts/install_status.txt` -gt "0" ]; then
# install rich
if python ../scripts/check_modules.py rich; then
echo "rich has already been installed."
else
echo "Installing rich.."
export PYTHONNOUSERSITE=1
export PYTHONPATH="$INSTALL_ENV_DIR/lib/python3.8/site-packages"
if pip install rich ; then
echo "Installed."
else
fail "Install failed for rich"
fi
fi
if python ../scripts/check_modules.py uvicorn fastapi ; then
echo "Packages necessary for Stable Diffusion UI were already installed"
else
printf "\n\nDownloading packages necessary for Stable Diffusion UI..\n\n"
export PYTHONNOUSERSITE=1
export PYTHONPATH="$(pwd):$(pwd)/env/lib/site-packages"
export PYTHONPATH="$INSTALL_ENV_DIR/lib/python3.8/site-packages"
if conda install -c conda-forge --prefix ./env -y uvicorn fastapi ; then
if conda install -c conda-forge -y uvicorn fastapi ; then
echo "Installed. Testing.."
else
fail "'conda install uvicorn' failed"
@ -148,32 +113,9 @@ else
if ! command -v uvicorn &> /dev/null; then
fail "UI packages not found!"
fi
echo conda_sd_ui_deps_installed >> ../scripts/install_status.txt
fi
if python -m picklescan --help >/dev/null 2>&1; then
echo "Picklescan is already installed."
else
echo "Picklescan not found, installing."
pip install picklescan || fail "Picklescan installation failed."
fi
if python -c "import safetensors" --help >/dev/null 2>&1; then
echo "SafeTensors is already installed."
else
echo "SafeTensors not found, installing."
pip install safetensors || fail "SafeTensors installation failed."
fi
mkdir -p "../models/stable-diffusion"
mkdir -p "../models/vae"
mkdir -p "../models/hypernetwork"
echo "" > "../models/stable-diffusion/Put your custom ckpt files here.txt"
echo "" > "../models/vae/Put your VAE files here.txt"
echo "" > "../models/hypernetwork/Put your hypernetwork files here.txt"
if [ -f "sd-v1-4.ckpt" ]; then
model_size=`find "sd-v1-4.ckpt" -printf "%s"`
@ -314,10 +256,6 @@ if [ ! -f "../models/vae/vae-ft-mse-840000-ema-pruned.ckpt" ]; then
fi
fi
if [ "$test_sd2" == "Y" ]; then
pip install open_clip_torch==2.0.2
fi
if [ `grep -c sd_install_complete ../scripts/install_status.txt` -gt "0" ]; then
echo sd_weights_downloaded >> ../scripts/install_status.txt
echo sd_install_complete >> ../scripts/install_status.txt
@ -326,7 +264,8 @@ fi
printf "\n\nStable Diffusion is ready!\n\n"
SD_PATH=`pwd`
export PYTHONPATH="$SD_PATH:$SD_PATH/env/lib/python3.8/site-packages"
export PYTHONPATH="$INSTALL_ENV_DIR/lib/python3.8/site-packages"
echo "PYTHONPATH=$PYTHONPATH"
which python
@ -336,6 +275,6 @@ cd ..
export SD_UI_PATH=`pwd`/ui
cd stable-diffusion
uvicorn server:app --app-dir "$SD_UI_PATH" --port ${SD_UI_BIND_PORT:-9000} --host ${SD_UI_BIND_IP:-0.0.0.0}
uvicorn main:server_api --app-dir "$SD_UI_PATH" --port ${SD_UI_BIND_PORT:-9000} --host ${SD_UI_BIND_IP:-0.0.0.0} --log-level error
read -p "Press any key to continue"

View File

165
ui/easydiffusion/app.py Normal file
View File

@ -0,0 +1,165 @@
import os
import socket
import sys
import json
import traceback
import logging
from rich.logging import RichHandler
from sdkit.utils import log as sdkit_log # hack, so we can overwrite the log config
from easydiffusion import task_manager
from easydiffusion.utils import log
# Remove all handlers associated with the root logger object.
for handler in logging.root.handlers[:]:
logging.root.removeHandler(handler)
LOG_FORMAT = '%(asctime)s.%(msecs)03d %(levelname)s %(threadName)s %(message)s'
logging.basicConfig(
level=logging.INFO,
format=LOG_FORMAT,
datefmt="%X",
handlers=[RichHandler(markup=True, rich_tracebacks=True, show_time=False, show_level=False)]
)
SD_DIR = os.getcwd()
SD_UI_DIR = os.getenv('SD_UI_PATH', None)
sys.path.append(os.path.dirname(SD_UI_DIR))
CONFIG_DIR = os.path.abspath(os.path.join(SD_UI_DIR, '..', 'scripts'))
MODELS_DIR = os.path.abspath(os.path.join(SD_DIR, '..', 'models'))
USER_UI_PLUGINS_DIR = os.path.abspath(os.path.join(SD_DIR, '..', 'plugins', 'ui'))
CORE_UI_PLUGINS_DIR = os.path.abspath(os.path.join(SD_UI_DIR, 'plugins', 'ui'))
UI_PLUGINS_SOURCES = ((CORE_UI_PLUGINS_DIR, 'core'), (USER_UI_PLUGINS_DIR, 'user'))
OUTPUT_DIRNAME = "Stable Diffusion UI" # in the user's home folder
TASK_TTL = 15 * 60 # Discard last session's task timeout
APP_CONFIG_DEFAULTS = {
# auto: selects the cuda device with the most free memory, cuda: use the currently active cuda device.
'render_devices': 'auto', # valid entries: 'auto', 'cpu' or 'cuda:N' (where N is a GPU index)
'update_branch': 'main',
'ui': {
'open_browser_on_start': True,
},
}
def init():
os.makedirs(USER_UI_PLUGINS_DIR, exist_ok=True)
update_render_threads()
def getConfig(default_val=APP_CONFIG_DEFAULTS):
try:
config_json_path = os.path.join(CONFIG_DIR, 'config.json')
if not os.path.exists(config_json_path):
return default_val
with open(config_json_path, 'r', encoding='utf-8') as f:
config = json.load(f)
if 'net' not in config:
config['net'] = {}
if os.getenv('SD_UI_BIND_PORT') is not None:
config['net']['listen_port'] = int(os.getenv('SD_UI_BIND_PORT'))
if os.getenv('SD_UI_BIND_IP') is not None:
config['net']['listen_to_network'] = (os.getenv('SD_UI_BIND_IP') == '0.0.0.0')
return config
except Exception as e:
log.warn(traceback.format_exc())
return default_val
def setConfig(config):
try: # config.json
config_json_path = os.path.join(CONFIG_DIR, 'config.json')
with open(config_json_path, 'w', encoding='utf-8') as f:
json.dump(config, f)
except:
log.error(traceback.format_exc())
try: # config.bat
config_bat_path = os.path.join(CONFIG_DIR, 'config.bat')
config_bat = []
if 'update_branch' in config:
config_bat.append(f"@set update_branch={config['update_branch']}")
config_bat.append(f"@set SD_UI_BIND_PORT={config['net']['listen_port']}")
bind_ip = '0.0.0.0' if config['net']['listen_to_network'] else '127.0.0.1'
config_bat.append(f"@set SD_UI_BIND_IP={bind_ip}")
if len(config_bat) > 0:
with open(config_bat_path, 'w', encoding='utf-8') as f:
f.write('\r\n'.join(config_bat))
except:
log.error(traceback.format_exc())
try: # config.sh
config_sh_path = os.path.join(CONFIG_DIR, 'config.sh')
config_sh = ['#!/bin/bash']
if 'update_branch' in config:
config_sh.append(f"export update_branch={config['update_branch']}")
config_sh.append(f"export SD_UI_BIND_PORT={config['net']['listen_port']}")
bind_ip = '0.0.0.0' if config['net']['listen_to_network'] else '127.0.0.1'
config_sh.append(f"export SD_UI_BIND_IP={bind_ip}")
if len(config_sh) > 1:
with open(config_sh_path, 'w', encoding='utf-8') as f:
f.write('\n'.join(config_sh))
except:
log.error(traceback.format_exc())
def save_to_config(ckpt_model_name, vae_model_name, hypernetwork_model_name, vram_usage_level):
config = getConfig()
if 'model' not in config:
config['model'] = {}
config['model']['stable-diffusion'] = ckpt_model_name
config['model']['vae'] = vae_model_name
config['model']['hypernetwork'] = hypernetwork_model_name
if vae_model_name is None or vae_model_name == "":
del config['model']['vae']
if hypernetwork_model_name is None or hypernetwork_model_name == "":
del config['model']['hypernetwork']
config['vram_usage_level'] = vram_usage_level
setConfig(config)
def update_render_threads():
config = getConfig()
render_devices = config.get('render_devices', 'auto')
active_devices = task_manager.get_devices()['active'].keys()
log.debug(f'requesting for render_devices: {render_devices}')
task_manager.update_render_threads(render_devices, active_devices)
def getUIPlugins():
plugins = []
for plugins_dir, dir_prefix in UI_PLUGINS_SOURCES:
for file in os.listdir(plugins_dir):
if file.endswith('.plugin.js'):
plugins.append(f'/plugins/{dir_prefix}/{file}')
return plugins
def getIPConfig():
try:
ips = socket.gethostbyname_ex(socket.gethostname())
ips[2].append(ips[0])
return ips[2]
except Exception as e:
log.exception(e)
return []
def open_browser():
config = getConfig()
ui = config.get('ui', {})
net = config.get('net', {'listen_port':9000})
port = net.get('listen_port', 9000)
if ui.get('open_browser_on_start', True):
import webbrowser; webbrowser.open(f"http://localhost:{port}")

View File

@ -3,6 +3,15 @@ import torch
import traceback
import re
from easydiffusion.utils import log
'''
Set `FORCE_FULL_PRECISION` in the environment variables, or in `config.bat`/`config.sh` to set full precision (i.e. float32).
Otherwise the models will load at half-precision (i.e. float16).
Half-precision is fine most of the time. Full precision is only needed for working around GPU bugs (like NVIDIA 16xx GPUs).
'''
COMPARABLE_GPU_PERCENTILE = 0.65 # if a GPU's free_mem is within this % of the GPU with the most free_mem, it will be picked
mem_free_threshold = 0
@ -34,7 +43,7 @@ def get_device_delta(render_devices, active_devices):
if 'auto' in render_devices:
render_devices = auto_pick_devices(active_devices)
if 'cpu' in render_devices:
print('WARNING: Could not find a compatible GPU. Using the CPU, but this will be very slow!')
log.warn('WARNING: Could not find a compatible GPU. Using the CPU, but this will be very slow!')
active_devices = set(active_devices)
render_devices = set(render_devices)
@ -53,7 +62,7 @@ def auto_pick_devices(currently_active_devices):
if device_count == 1:
return ['cuda:0'] if is_device_compatible('cuda:0') else ['cpu']
print('Autoselecting GPU. Using most free memory.')
log.debug('Autoselecting GPU. Using most free memory.')
devices = []
for device in range(device_count):
device = f'cuda:{device}'
@ -64,7 +73,7 @@ def auto_pick_devices(currently_active_devices):
mem_free /= float(10**9)
mem_total /= float(10**9)
device_name = torch.cuda.get_device_name(device)
print(f'{device} detected: {device_name} - Memory (free/total): {round(mem_free, 2)}Gb / {round(mem_total, 2)}Gb')
log.debug(f'{device} detected: {device_name} - Memory (free/total): {round(mem_free, 2)}Gb / {round(mem_total, 2)}Gb')
devices.append({'device': device, 'device_name': device_name, 'mem_free': mem_free})
devices.sort(key=lambda x:x['mem_free'], reverse=True)
@ -82,7 +91,7 @@ def auto_pick_devices(currently_active_devices):
devices = list(map(lambda x: x['device'], devices))
return devices
def device_init(thread_data, device):
def device_init(context, device):
'''
This function assumes the 'device' has already been verified to be compatible.
`get_device_delta()` has already filtered out incompatible devices.
@ -91,27 +100,45 @@ def device_init(thread_data, device):
validate_device_id(device, log_prefix='device_init')
if device == 'cpu':
thread_data.device = 'cpu'
thread_data.device_name = get_processor_name()
print('Render device CPU available as', thread_data.device_name)
context.device = 'cpu'
context.device_name = get_processor_name()
context.half_precision = False
log.debug(f'Render device CPU available as {context.device_name}')
return
thread_data.device_name = torch.cuda.get_device_name(device)
thread_data.device = device
context.device_name = torch.cuda.get_device_name(device)
context.device = device
# Force full precision on 1660 and 1650 NVIDIA cards to avoid creating green images
device_name = thread_data.device_name.lower()
thread_data.force_full_precision = (('nvidia' in device_name or 'geforce' in device_name) and (' 1660' in device_name or ' 1650' in device_name)) or ('Quadro T2000' in device_name)
if thread_data.force_full_precision:
print('forcing full precision on NVIDIA 16xx cards, to avoid green images. GPU detected: ', thread_data.device_name)
if needs_to_force_full_precision(context):
log.warn(f'forcing full precision on this GPU, to avoid green images. GPU detected: {context.device_name}')
# Apply force_full_precision now before models are loaded.
thread_data.precision = 'full'
context.half_precision = False
print(f'Setting {device} as active')
log.info(f'Setting {device} as active, with precision: {"half" if context.half_precision else "full"}')
torch.cuda.device(device)
return
def needs_to_force_full_precision(context):
if 'FORCE_FULL_PRECISION' in os.environ:
return True
device_name = context.device_name.lower()
return (('nvidia' in device_name or 'geforce' in device_name) and (' 1660' in device_name or ' 1650' in device_name)) or ('Quadro T2000' in device_name)
def get_max_vram_usage_level(device):
if device != 'cpu':
_, mem_total = torch.cuda.mem_get_info(device)
mem_total /= float(10**9)
if mem_total < 4.5:
return 'low'
elif mem_total < 6.5:
return 'balanced'
return 'high'
def validate_device_id(device, log_prefix=''):
def is_valid():
if not isinstance(device, str):
@ -132,7 +159,7 @@ def is_device_compatible(device):
try:
validate_device_id(device, log_prefix='is_device_compatible')
except:
print(str(e))
log.error(str(e))
return False
if device == 'cpu': return True
@ -141,10 +168,10 @@ def is_device_compatible(device):
_, mem_total = torch.cuda.mem_get_info(device)
mem_total /= float(10**9)
if mem_total < 3.0:
print(f'GPU {device} with less than 3 GB of VRAM is not compatible with Stable Diffusion')
log.warn(f'GPU {device} with less than 3 GB of VRAM is not compatible with Stable Diffusion')
return False
except RuntimeError as e:
print(str(e))
log.error(str(e))
return False
return True
@ -164,5 +191,5 @@ def get_processor_name():
if "model name" in line:
return re.sub(".*model name.*:", "", line, 1).strip()
except:
print(traceback.format_exc())
log.error(traceback.format_exc())
return "cpu"

View File

@ -0,0 +1,223 @@
import os
from easydiffusion import app, device_manager
from easydiffusion.types import TaskData
from easydiffusion.utils import log
from sdkit import Context
from sdkit.models import load_model, unload_model, get_model_info_from_db, scan_model
from sdkit.utils import hash_file_quick
KNOWN_MODEL_TYPES = ['stable-diffusion', 'vae', 'hypernetwork', 'gfpgan', 'realesrgan']
MODEL_EXTENSIONS = {
'stable-diffusion': ['.ckpt', '.safetensors'],
'vae': ['.vae.pt', '.ckpt', '.safetensors'],
'hypernetwork': ['.pt', '.safetensors'],
'gfpgan': ['.pth'],
'realesrgan': ['.pth'],
}
DEFAULT_MODELS = {
'stable-diffusion': [ # needed to support the legacy installations
'custom-model', # only one custom model file was supported initially, creatively named 'custom-model'
'sd-v1-4', # Default fallback.
],
'gfpgan': ['GFPGANv1.3'],
'realesrgan': ['RealESRGAN_x4plus'],
}
VRAM_USAGE_LEVEL_TO_OPTIMIZATIONS = {
'balanced': {'KEEP_FS_AND_CS_IN_CPU', 'SET_ATTENTION_STEP_TO_4'},
'low': {'KEEP_ENTIRE_MODEL_IN_CPU'},
'high': {},
}
MODELS_TO_LOAD_ON_START = ['stable-diffusion', 'vae', 'hypernetwork']
known_models = {}
def init():
make_model_folders()
getModels() # run this once, to cache the picklescan results
def load_default_models(context: Context):
set_vram_optimizations(context)
# init default model paths
for model_type in MODELS_TO_LOAD_ON_START:
context.model_paths[model_type] = resolve_model_to_use(model_type=model_type)
load_model(context, model_type)
def unload_all(context: Context):
for model_type in KNOWN_MODEL_TYPES:
unload_model(context, model_type)
def resolve_model_to_use(model_name:str=None, model_type:str=None):
model_extensions = MODEL_EXTENSIONS.get(model_type, [])
default_models = DEFAULT_MODELS.get(model_type, [])
config = app.getConfig()
model_dirs = [os.path.join(app.MODELS_DIR, model_type), app.SD_DIR]
if not model_name: # When None try user configured model.
# config = getConfig()
if 'model' in config and model_type in config['model']:
model_name = config['model'][model_type]
if model_name:
# Check models directory
models_dir_path = os.path.join(app.MODELS_DIR, model_type, model_name)
for model_extension in model_extensions:
if os.path.exists(models_dir_path + model_extension):
return models_dir_path + model_extension
if os.path.exists(model_name + model_extension):
return os.path.abspath(model_name + model_extension)
# Default locations
if model_name in default_models:
default_model_path = os.path.join(app.SD_DIR, model_name)
for model_extension in model_extensions:
if os.path.exists(default_model_path + model_extension):
return default_model_path + model_extension
# Can't find requested model, check the default paths.
for default_model in default_models:
for model_dir in model_dirs:
default_model_path = os.path.join(model_dir, default_model)
for model_extension in model_extensions:
if os.path.exists(default_model_path + model_extension):
if model_name is not None:
log.warn(f'Could not find the configured custom model {model_name}{model_extension}. Using the default one: {default_model_path}{model_extension}')
return default_model_path + model_extension
return None
def reload_models_if_necessary(context: Context, task_data: TaskData):
model_paths_in_req = {
'stable-diffusion': task_data.use_stable_diffusion_model,
'vae': task_data.use_vae_model,
'hypernetwork': task_data.use_hypernetwork_model,
'gfpgan': task_data.use_face_correction,
'realesrgan': task_data.use_upscale,
}
models_to_reload = {model_type: path for model_type, path in model_paths_in_req.items() if context.model_paths.get(model_type) != path}
if set_vram_optimizations(context): # reload SD
models_to_reload['stable-diffusion'] = model_paths_in_req['stable-diffusion']
if 'stable-diffusion' in models_to_reload:
quick_hash = hash_file_quick(models_to_reload['stable-diffusion'])
known_model_info = get_model_info_from_db(quick_hash=quick_hash)
for model_type, model_path_in_req in models_to_reload.items():
context.model_paths[model_type] = model_path_in_req
action_fn = unload_model if context.model_paths[model_type] is None else load_model
action_fn(context, model_type, scan_model=False) # we've scanned them already
def resolve_model_paths(task_data: TaskData):
task_data.use_stable_diffusion_model = resolve_model_to_use(task_data.use_stable_diffusion_model, model_type='stable-diffusion')
task_data.use_vae_model = resolve_model_to_use(task_data.use_vae_model, model_type='vae')
task_data.use_hypernetwork_model = resolve_model_to_use(task_data.use_hypernetwork_model, model_type='hypernetwork')
if task_data.use_face_correction: task_data.use_face_correction = resolve_model_to_use(task_data.use_face_correction, 'gfpgan')
if task_data.use_upscale: task_data.use_upscale = resolve_model_to_use(task_data.use_upscale, 'realesrgan')
def set_vram_optimizations(context: Context):
config = app.getConfig()
max_usage_level = device_manager.get_max_vram_usage_level(context.device)
vram_usage_level = config.get('vram_usage_level', 'balanced')
v = {'low': 0, 'balanced': 1, 'high': 2}
if v[vram_usage_level] > v[max_usage_level]:
log.error(f'Requested GPU Memory Usage level ({vram_usage_level}) is higher than what is ' + \
f'possible ({max_usage_level}) on this device ({context.device}). Using "{max_usage_level}" instead')
vram_usage_level = max_usage_level
vram_optimizations = VRAM_USAGE_LEVEL_TO_OPTIMIZATIONS[vram_usage_level]
if vram_optimizations != context.vram_optimizations:
context.vram_optimizations = vram_optimizations
return True
return False
def make_model_folders():
for model_type in KNOWN_MODEL_TYPES:
model_dir_path = os.path.join(app.MODELS_DIR, model_type)
os.makedirs(model_dir_path, exist_ok=True)
help_file_name = f'Place your {model_type} model files here.txt'
help_file_contents = f'Supported extensions: {" or ".join(MODEL_EXTENSIONS.get(model_type))}'
with open(os.path.join(model_dir_path, help_file_name), 'w', encoding='utf-8') as f:
f.write(help_file_contents)
def is_malicious_model(file_path):
try:
scan_result = scan_model(file_path)
if scan_result.issues_count > 0 or scan_result.infected_files > 0:
log.warn(":warning: [bold red]Scan %s: %d scanned, %d issue, %d infected.[/bold red]" % (file_path, scan_result.scanned_files, scan_result.issues_count, scan_result.infected_files))
return True
else:
log.debug("Scan %s: [green]%d scanned, %d issue, %d infected.[/green]" % (file_path, scan_result.scanned_files, scan_result.issues_count, scan_result.infected_files))
return False
except Exception as e:
log.error(f'error while scanning: {file_path}, error: {e}')
return False
def getModels():
models = {
'active': {
'stable-diffusion': 'sd-v1-4',
'vae': '',
'hypernetwork': '',
},
'options': {
'stable-diffusion': ['sd-v1-4'],
'vae': [],
'hypernetwork': [],
},
}
models_scanned = 0
def listModels(model_type):
nonlocal models_scanned
model_extensions = MODEL_EXTENSIONS.get(model_type, [])
models_dir = os.path.join(app.MODELS_DIR, model_type)
if not os.path.exists(models_dir):
os.makedirs(models_dir)
for file in os.listdir(models_dir):
for model_extension in model_extensions:
if not file.endswith(model_extension):
continue
model_path = os.path.join(models_dir, file)
mtime = os.path.getmtime(model_path)
mod_time = known_models[model_path] if model_path in known_models else -1
if mod_time != mtime:
models_scanned += 1
if is_malicious_model(model_path):
models['scan-error'] = file
return
known_models[model_path] = mtime
model_name = file[:-len(model_extension)]
models['options'][model_type].append(model_name)
models['options'][model_type] = [*set(models['options'][model_type])] # remove duplicates
models['options'][model_type].sort()
# custom models
listModels(model_type='stable-diffusion')
listModels(model_type='vae')
listModels(model_type='hypernetwork')
if models_scanned > 0: log.info(f'[green]Scanned {models_scanned} models. Nothing infected[/]')
# legacy
custom_weight_path = os.path.join(app.SD_DIR, 'custom-model.ckpt')
if os.path.exists(custom_weight_path):
models['options']['stable-diffusion'].append('custom-model')
return models

View File

@ -0,0 +1,124 @@
import queue
import time
import json
from easydiffusion import device_manager
from easydiffusion.types import TaskData, Response, Image as ResponseImage, UserInitiatedStop, GenerateImageRequest
from easydiffusion.utils import get_printable_request, save_images_to_disk, log
from sdkit import Context
from sdkit.generate import generate_images
from sdkit.filter import apply_filters
from sdkit.utils import img_to_buffer, img_to_base64_str, latent_samples_to_images, gc
context = Context() # thread-local
'''
runtime data (bound locally to this thread), for e.g. device, references to loaded models, optimization flags etc
'''
def init(device):
'''
Initializes the fields that will be bound to this runtime's context, and sets the current torch device
'''
context.stop_processing = False
context.temp_images = {}
context.partial_x_samples = None
device_manager.device_init(context, device)
def make_images(req: GenerateImageRequest, task_data: TaskData, data_queue: queue.Queue, task_temp_images: list, step_callback):
context.stop_processing = False
log.info(f'request: {get_printable_request(req)}')
log.info(f'task data: {task_data.dict()}')
images = make_images_internal(req, task_data, data_queue, task_temp_images, step_callback)
res = Response(req, task_data, images=construct_response(images, task_data, base_seed=req.seed))
res = res.json()
data_queue.put(json.dumps(res))
log.info('Task completed')
return res
def make_images_internal(req: GenerateImageRequest, task_data: TaskData, data_queue: queue.Queue, task_temp_images: list, step_callback):
images, user_stopped = generate_images_internal(req, task_data, data_queue, task_temp_images, step_callback, task_data.stream_image_progress)
filtered_images = filter_images(task_data, images, user_stopped)
if task_data.save_to_disk_path is not None:
save_images_to_disk(images, filtered_images, req, task_data)
return filtered_images if task_data.show_only_filtered_image else images + filtered_images
def generate_images_internal(req: GenerateImageRequest, task_data: TaskData, data_queue: queue.Queue, task_temp_images: list, step_callback, stream_image_progress: bool):
context.temp_images.clear()
callback = make_step_callback(req, task_data, data_queue, task_temp_images, step_callback, stream_image_progress)
try:
images = generate_images(context, callback=callback, **req.dict())
user_stopped = False
except UserInitiatedStop:
images = []
user_stopped = True
if context.partial_x_samples is not None:
images = latent_samples_to_images(context, context.partial_x_samples)
context.partial_x_samples = None
finally:
gc(context)
return images, user_stopped
def filter_images(task_data: TaskData, images: list, user_stopped):
if user_stopped or (task_data.use_face_correction is None and task_data.use_upscale is None):
return images
filters_to_apply = []
if task_data.use_face_correction and 'gfpgan' in task_data.use_face_correction.lower(): filters_to_apply.append('gfpgan')
if task_data.use_upscale and 'realesrgan' in task_data.use_upscale.lower(): filters_to_apply.append('realesrgan')
return apply_filters(context, filters_to_apply, images)
def construct_response(images: list, task_data: TaskData, base_seed: int):
return [
ResponseImage(
data=img_to_base64_str(img, task_data.output_format, task_data.output_quality),
seed=base_seed + i
) for i, img in enumerate(images)
]
def make_step_callback(req: GenerateImageRequest, task_data: TaskData, data_queue: queue.Queue, task_temp_images: list, step_callback, stream_image_progress: bool):
n_steps = req.num_inference_steps if req.init_image is None else int(req.num_inference_steps * req.prompt_strength)
last_callback_time = -1
def update_temp_img(x_samples, task_temp_images: list):
partial_images = []
images = latent_samples_to_images(context, x_samples)
for i, img in enumerate(images):
buf = img_to_buffer(img, output_format='JPEG')
context.temp_images[f"{task_data.request_id}/{i}"] = buf
task_temp_images[i] = buf
partial_images.append({'path': f"/image/tmp/{task_data.request_id}/{i}"})
del images
return partial_images
def on_image_step(x_samples, i):
nonlocal last_callback_time
context.partial_x_samples = x_samples
step_time = time.time() - last_callback_time if last_callback_time != -1 else -1
last_callback_time = time.time()
progress = {"step": i, "step_time": step_time, "total_steps": n_steps}
if stream_image_progress and i % 5 == 0:
progress['output'] = update_temp_img(x_samples, task_temp_images)
data_queue.put(json.dumps(progress))
step_callback()
if context.stop_processing:
raise UserInitiatedStop("User requested that we stop processing")
return on_image_step

219
ui/easydiffusion/server.py Normal file
View File

@ -0,0 +1,219 @@
"""server.py: FastAPI SD-UI Web Host.
Notes:
async endpoints always run on the main thread. Without they run on the thread pool.
"""
import os
import traceback
import datetime
from typing import List, Union
from fastapi import FastAPI, HTTPException
from fastapi.staticfiles import StaticFiles
from starlette.responses import FileResponse, JSONResponse, StreamingResponse
from pydantic import BaseModel
from easydiffusion import app, model_manager, task_manager
from easydiffusion.types import TaskData, GenerateImageRequest
from easydiffusion.utils import log
log.info(f'started in {app.SD_DIR}')
log.info(f'started at {datetime.datetime.now():%x %X}')
server_api = FastAPI()
NOCACHE_HEADERS={"Cache-Control": "no-cache, no-store, must-revalidate", "Pragma": "no-cache", "Expires": "0"}
class NoCacheStaticFiles(StaticFiles):
def is_not_modified(self, response_headers, request_headers) -> bool:
if 'content-type' in response_headers and ('javascript' in response_headers['content-type'] or 'css' in response_headers['content-type']):
response_headers.update(NOCACHE_HEADERS)
return False
return super().is_not_modified(response_headers, request_headers)
class SetAppConfigRequest(BaseModel):
update_branch: str = None
render_devices: Union[List[str], List[int], str, int] = None
model_vae: str = None
ui_open_browser_on_start: bool = None
listen_to_network: bool = None
listen_port: int = None
def init():
server_api.mount('/media', NoCacheStaticFiles(directory=os.path.join(app.SD_UI_DIR, 'media')), name="media")
for plugins_dir, dir_prefix in app.UI_PLUGINS_SOURCES:
server_api.mount(f'/plugins/{dir_prefix}', NoCacheStaticFiles(directory=plugins_dir), name=f"plugins-{dir_prefix}")
@server_api.post('/app_config')
async def set_app_config(req : SetAppConfigRequest):
return set_app_config_internal(req)
@server_api.get('/get/{key:path}')
def read_web_data(key:str=None):
return read_web_data_internal(key)
@server_api.get('/ping') # Get server and optionally session status.
def ping(session_id:str=None):
return ping_internal(session_id)
@server_api.post('/render')
def render(req: dict):
return render_internal(req)
@server_api.get('/image/stream/{task_id:int}')
def stream(task_id:int):
return stream_internal(task_id)
@server_api.get('/image/stop')
def stop(task: int):
return stop_internal(task)
@server_api.get('/image/tmp/{task_id:int}/{img_id:int}')
def get_image(task_id: int, img_id: int):
return get_image_internal(task_id, img_id)
@server_api.get('/')
def read_root():
return FileResponse(os.path.join(app.SD_UI_DIR, 'index.html'), headers=NOCACHE_HEADERS)
@server_api.on_event("shutdown")
def shutdown_event(): # Signal render thread to close on shutdown
task_manager.current_state_error = SystemExit('Application shutting down.')
# API implementations
def set_app_config_internal(req : SetAppConfigRequest):
config = app.getConfig()
if req.update_branch is not None:
config['update_branch'] = req.update_branch
if req.render_devices is not None:
update_render_devices_in_config(config, req.render_devices)
if req.ui_open_browser_on_start is not None:
if 'ui' not in config:
config['ui'] = {}
config['ui']['open_browser_on_start'] = req.ui_open_browser_on_start
if req.listen_to_network is not None:
if 'net' not in config:
config['net'] = {}
config['net']['listen_to_network'] = bool(req.listen_to_network)
if req.listen_port is not None:
if 'net' not in config:
config['net'] = {}
config['net']['listen_port'] = int(req.listen_port)
try:
app.setConfig(config)
if req.render_devices:
app.update_render_threads()
return JSONResponse({'status': 'OK'}, headers=NOCACHE_HEADERS)
except Exception as e:
log.error(traceback.format_exc())
raise HTTPException(status_code=500, detail=str(e))
def update_render_devices_in_config(config, render_devices):
if render_devices not in ('cpu', 'auto') and not render_devices.startswith('cuda:'):
raise HTTPException(status_code=400, detail=f'Invalid render device requested: {render_devices}')
if render_devices.startswith('cuda:'):
render_devices = render_devices.split(',')
config['render_devices'] = render_devices
def read_web_data_internal(key:str=None):
if not key: # /get without parameters, stable-diffusion easter egg.
raise HTTPException(status_code=418, detail="StableDiffusion is drawing a teapot!") # HTTP418 I'm a teapot
elif key == 'app_config':
return JSONResponse(app.getConfig(), headers=NOCACHE_HEADERS)
elif key == 'system_info':
config = app.getConfig()
system_info = {
'devices': task_manager.get_devices(),
'hosts': app.getIPConfig(),
'default_output_dir': os.path.join(os.path.expanduser("~"), app.OUTPUT_DIRNAME),
}
system_info['devices']['config'] = config.get('render_devices', "auto")
return JSONResponse(system_info, headers=NOCACHE_HEADERS)
elif key == 'models':
return JSONResponse(model_manager.getModels(), headers=NOCACHE_HEADERS)
elif key == 'modifiers': return FileResponse(os.path.join(app.SD_UI_DIR, 'modifiers.json'), headers=NOCACHE_HEADERS)
elif key == 'ui_plugins': return JSONResponse(app.getUIPlugins(), headers=NOCACHE_HEADERS)
else:
raise HTTPException(status_code=404, detail=f'Request for unknown {key}') # HTTP404 Not Found
def ping_internal(session_id:str=None):
if task_manager.is_alive() <= 0: # Check that render threads are alive.
if task_manager.current_state_error: raise HTTPException(status_code=500, detail=str(task_manager.current_state_error))
raise HTTPException(status_code=500, detail='Render thread is dead.')
if task_manager.current_state_error and not isinstance(task_manager.current_state_error, StopAsyncIteration): raise HTTPException(status_code=500, detail=str(task_manager.current_state_error))
# Alive
response = {'status': str(task_manager.current_state)}
if session_id:
session = task_manager.get_cached_session(session_id, update_ttl=True)
response['tasks'] = {id(t): t.status for t in session.tasks}
response['devices'] = task_manager.get_devices()
return JSONResponse(response, headers=NOCACHE_HEADERS)
def render_internal(req: dict):
try:
# separate out the request data into rendering and task-specific data
render_req: GenerateImageRequest = GenerateImageRequest.parse_obj(req)
task_data: TaskData = TaskData.parse_obj(req)
render_req.init_image_mask = req.get('mask') # hack: will rename this in the HTTP API in a future revision
app.save_to_config(task_data.use_stable_diffusion_model, task_data.use_vae_model, task_data.use_hypernetwork_model, task_data.vram_usage_level)
# enqueue the task
new_task = task_manager.render(render_req, task_data)
response = {
'status': str(task_manager.current_state),
'queue': len(task_manager.tasks_queue),
'stream': f'/image/stream/{id(new_task)}',
'task': id(new_task)
}
return JSONResponse(response, headers=NOCACHE_HEADERS)
except ChildProcessError as e: # Render thread is dead
raise HTTPException(status_code=500, detail=f'Rendering thread has died.') # HTTP500 Internal Server Error
except ConnectionRefusedError as e: # Unstarted task pending limit reached, deny queueing too many.
raise HTTPException(status_code=503, detail=str(e)) # HTTP503 Service Unavailable
except Exception as e:
log.error(traceback.format_exc())
raise HTTPException(status_code=500, detail=str(e))
def stream_internal(task_id:int):
#TODO Move to WebSockets ??
task = task_manager.get_cached_task(task_id, update_ttl=True)
if not task: raise HTTPException(status_code=404, detail=f'Request {task_id} not found.') # HTTP404 NotFound
#if (id(task) != task_id): raise HTTPException(status_code=409, detail=f'Wrong task id received. Expected:{id(task)}, Received:{task_id}') # HTTP409 Conflict
if task.buffer_queue.empty() and not task.lock.locked():
if task.response:
#log.info(f'Session {session_id} sending cached response')
return JSONResponse(task.response, headers=NOCACHE_HEADERS)
raise HTTPException(status_code=425, detail='Too Early, task not started yet.') # HTTP425 Too Early
#log.info(f'Session {session_id} opened live render stream {id(task.buffer_queue)}')
return StreamingResponse(task.read_buffer_generator(), media_type='application/json')
def stop_internal(task: int):
if not task:
if task_manager.current_state == task_manager.ServerStates.Online or task_manager.current_state == task_manager.ServerStates.Unavailable:
raise HTTPException(status_code=409, detail='Not currently running any tasks.') # HTTP409 Conflict
task_manager.current_state_error = StopAsyncIteration('')
return {'OK'}
task_id = task
task = task_manager.get_cached_task(task_id, update_ttl=False)
if not task: raise HTTPException(status_code=404, detail=f'Task {task_id} was not found.') # HTTP404 Not Found
if isinstance(task.error, StopAsyncIteration): raise HTTPException(status_code=409, detail=f'Task {task_id} is already stopped.') # HTTP409 Conflict
task.error = StopAsyncIteration(f'Task {task_id} stop requested.')
return {'OK'}
def get_image_internal(task_id: int, img_id: int):
task = task_manager.get_cached_task(task_id, update_ttl=True)
if not task: raise HTTPException(status_code=410, detail=f'Task {task_id} could not be found.') # HTTP404 NotFound
if not task.temp_images[img_id]: raise HTTPException(status_code=425, detail='Too Early, task data is not available yet.') # HTTP425 Too Early
try:
img_data = task.temp_images[img_id]
img_data.seek(0)
return StreamingResponse(img_data, media_type='image/jpeg')
except KeyError as e:
raise HTTPException(status_code=500, detail=str(e))

View File

@ -11,12 +11,13 @@ TASK_TTL = 15 * 60 # seconds, Discard last session's task timeout
import torch
import queue, threading, time, weakref
from typing import Any, Generator, Hashable, Optional, Union
from typing import Any, Hashable
from pydantic import BaseModel
from sd_internal import Request, Response, runtime, device_manager
from easydiffusion import device_manager
from easydiffusion.types import TaskData, GenerateImageRequest
from easydiffusion.utils import log
THREAD_NAME_PREFIX = 'Runtime-Render/'
THREAD_NAME_PREFIX = ''
ERR_LOCK_FAILED = ' failed to acquire lock within timeout.'
LOCK_TIMEOUT = 15 # Maximum locking time in seconds before failing a task.
# It's better to get an exception than a deadlock... ALWAYS use timeout in critical paths.
@ -36,12 +37,13 @@ class ServerStates:
class Unavailable(Symbol): pass
class RenderTask(): # Task with output queue and completion lock.
def __init__(self, req: Request):
req.request_id = id(self)
self.request: Request = req # Initial Request
def __init__(self, req: GenerateImageRequest, task_data: TaskData):
task_data.request_id = id(self)
self.render_request: GenerateImageRequest = req # Initial Request
self.task_data: TaskData = task_data
self.response: Any = None # Copy of the last reponse
self.render_device = None # Select the task affinity. (Not used to change active devices).
self.temp_images:list = [None] * req.num_outputs * (1 if req.show_only_filtered_image else 2)
self.temp_images:list = [None] * req.num_outputs * (1 if task_data.show_only_filtered_image else 2)
self.error: Exception = None
self.lock: threading.Lock = threading.Lock() # Locks at task start and unlocks when task is completed
self.buffer_queue: queue.Queue = queue.Queue() # Queue of JSON string segments
@ -69,54 +71,6 @@ class RenderTask(): # Task with output queue and completion lock.
def is_pending(self):
return bool(not self.response and not self.error)
# defaults from https://huggingface.co/blog/stable_diffusion
class ImageRequest(BaseModel):
session_id: str = "session"
prompt: str = ""
negative_prompt: str = ""
init_image: str = None # base64
mask: str = None # base64
num_outputs: int = 1
num_inference_steps: int = 50
guidance_scale: float = 7.5
width: int = 512
height: int = 512
seed: int = 42
prompt_strength: float = 0.8
sampler: str = None # "ddim", "plms", "heun", "euler", "euler_a", "dpm2", "dpm2_a", "lms"
# allow_nsfw: bool = False
save_to_disk_path: str = None
turbo: bool = True
use_cpu: bool = False ##TODO Remove after UI and plugins transition.
render_device: str = None # Select the task affinity. (Not used to change active devices).
use_full_precision: bool = False
use_face_correction: str = None # or "GFPGANv1.3"
use_upscale: str = None # or "RealESRGAN_x4plus" or "RealESRGAN_x4plus_anime_6B"
use_stable_diffusion_model: str = "sd-v1-4"
use_vae_model: str = None
use_hypernetwork_model: str = None
hypernetwork_strength: float = None
show_only_filtered_image: bool = False
output_format: str = "jpeg" # or "png"
output_quality: int = 75
stream_progress_updates: bool = False
stream_image_progress: bool = False
class FilterRequest(BaseModel):
session_id: str = "session"
model: str = None
name: str = ""
init_image: str = None # base64
width: int = 512
height: int = 512
save_to_disk_path: str = None
turbo: bool = True
render_device: str = None
use_full_precision: bool = False
output_format: str = "jpeg" # or "png"
output_quality: int = 75
# Temporary cache to allow to query tasks results for a short time after they are completed.
class DataCache():
def __init__(self):
@ -139,11 +93,11 @@ class DataCache():
for key in to_delete:
(_, val) = self._base[key]
if isinstance(val, RenderTask):
print(f'RenderTask {key} expired. Data removed.')
log.debug(f'RenderTask {key} expired. Data removed.')
elif isinstance(val, SessionState):
print(f'Session {key} expired. Data removed.')
log.debug(f'Session {key} expired. Data removed.')
else:
print(f'Key {key} expired. Data removed.')
log.debug(f'Key {key} expired. Data removed.')
del self._base[key]
finally:
self._lock.release()
@ -177,8 +131,7 @@ class DataCache():
self._get_ttl_time(ttl), value
)
except Exception as e:
print(str(e))
print(traceback.format_exc())
log.error(traceback.format_exc())
return False
else:
return True
@ -189,7 +142,7 @@ class DataCache():
try:
ttl, value = self._base.get(key, (None, None))
if ttl is not None and self._is_expired(ttl):
print(f'Session {key} expired. Discarding data.')
log.debug(f'Session {key} expired. Discarding data.')
del self._base[key]
return None
return value
@ -200,15 +153,9 @@ manager_lock = threading.RLock()
render_threads = []
current_state = ServerStates.Init
current_state_error:Exception = None
current_model_path = None
current_vae_path = None
current_hypernetwork_path = None
tasks_queue = []
session_cache = DataCache()
task_cache = DataCache()
default_model_to_load = None
default_vae_to_load = None
default_hypernetwork_to_load = None
weak_thread_data = weakref.WeakKeyDictionary()
idle_event: threading.Event = threading.Event()
@ -236,40 +183,10 @@ class SessionState():
self._tasks_ids.pop(0)
return True
def preload_model(ckpt_file_path=None, vae_file_path=None, hypernetwork_file_path=None):
global current_state, current_state_error, current_model_path, current_vae_path, current_hypernetwork_path
if ckpt_file_path == None:
ckpt_file_path = default_model_to_load
if vae_file_path == None:
vae_file_path = default_vae_to_load
if hypernetwork_file_path == None:
hypernetwork_file_path = default_hypernetwork_to_load
if ckpt_file_path == current_model_path and vae_file_path == current_vae_path:
return
current_state = ServerStates.LoadingModel
try:
from . import runtime
runtime.thread_data.hypernetwork_file = hypernetwork_file_path
runtime.thread_data.ckpt_file = ckpt_file_path
runtime.thread_data.vae_file = vae_file_path
runtime.load_model_ckpt()
runtime.load_hypernetwork()
current_model_path = ckpt_file_path
current_vae_path = vae_file_path
current_hypernetwork_path = hypernetwork_file_path
current_state_error = None
current_state = ServerStates.Online
except Exception as e:
current_model_path = None
current_vae_path = None
current_state_error = e
current_state = ServerStates.Unavailable
print(traceback.format_exc())
def thread_get_next_task():
from . import runtime
from easydiffusion import renderer
if not manager_lock.acquire(blocking=True, timeout=LOCK_TIMEOUT):
print('Render thread on device', runtime.thread_data.device, 'failed to acquire manager lock.')
log.warn(f'Render thread on device: {renderer.context.device} failed to acquire manager lock.')
return None
if len(tasks_queue) <= 0:
manager_lock.release()
@ -277,7 +194,7 @@ def thread_get_next_task():
task = None
try: # Select a render task.
for queued_task in tasks_queue:
if queued_task.render_device and runtime.thread_data.device != queued_task.render_device:
if queued_task.render_device and renderer.context.device != queued_task.render_device:
# Is asking for a specific render device.
if is_alive(queued_task.render_device) > 0:
continue # requested device alive, skip current one.
@ -286,7 +203,7 @@ def thread_get_next_task():
queued_task.error = Exception(queued_task.render_device + ' is not currently active.')
task = queued_task
break
if not queued_task.render_device and runtime.thread_data.device == 'cpu' and is_alive() > 1:
if not queued_task.render_device and renderer.context.device == 'cpu' and is_alive() > 1:
# not asking for any specific devices, cpu want to grab task but other render devices are alive.
continue # Skip Tasks, don't run on CPU unless there is nothing else or user asked for it.
task = queued_task
@ -298,31 +215,36 @@ def thread_get_next_task():
manager_lock.release()
def thread_render(device):
global current_state, current_state_error, current_model_path, current_vae_path, current_hypernetwork_path
from . import runtime
global current_state, current_state_error
from easydiffusion import renderer, model_manager
try:
runtime.thread_init(device)
except Exception as e:
print(traceback.format_exc())
renderer.init(device)
weak_thread_data[threading.current_thread()] = {
'error': e
}
return
weak_thread_data[threading.current_thread()] = {
'device': runtime.thread_data.device,
'device_name': runtime.thread_data.device_name,
'device': renderer.context.device,
'device_name': renderer.context.device_name,
'alive': True
}
if runtime.thread_data.device != 'cpu' or is_alive() == 1:
preload_model()
current_state = ServerStates.LoadingModel
model_manager.load_default_models(renderer.context)
current_state = ServerStates.Online
except Exception as e:
log.error(traceback.format_exc())
weak_thread_data[threading.current_thread()] = {
'error': e,
'alive': False
}
return
while True:
session_cache.clean()
task_cache.clean()
if not weak_thread_data[threading.current_thread()]['alive']:
print(f'Shutting down thread for device {runtime.thread_data.device}')
runtime.unload_models()
runtime.unload_filters()
log.info(f'Shutting down thread for device {renderer.context.device}')
model_manager.unload_all(renderer.context)
return
if isinstance(current_state_error, SystemExit):
current_state = ServerStates.Unavailable
@ -333,7 +255,7 @@ def thread_render(device):
idle_event.wait(timeout=1)
continue
if task.error is not None:
print(task.error)
log.error(task.error)
task.response = {"status": 'failed', "detail": str(task.error)}
task.buffer_queue.put(json.dumps(task.response))
continue
@ -342,51 +264,45 @@ def thread_render(device):
task.response = {"status": 'failed', "detail": str(task.error)}
task.buffer_queue.put(json.dumps(task.response))
continue
print(f'Session {task.request.session_id} starting task {id(task)} on {runtime.thread_data.device_name}')
log.info(f'Session {task.task_data.session_id} starting task {id(task)} on {renderer.context.device_name}')
if not task.lock.acquire(blocking=False): raise Exception('Got locked task from queue.')
try:
if runtime.is_hypernetwork_reload_necessary(task.request):
runtime.reload_hypernetwork()
current_hypernetwork_path = task.request.use_hypernetwork_model
if runtime.is_model_reload_necessary(task.request):
current_state = ServerStates.LoadingModel
runtime.reload_model()
current_model_path = task.request.use_stable_diffusion_model
current_vae_path = task.request.use_vae_model
def step_callback():
global current_state_error
if isinstance(current_state_error, SystemExit) or isinstance(current_state_error, StopAsyncIteration) or isinstance(task.error, StopAsyncIteration):
runtime.thread_data.stop_processing = True
renderer.context.stop_processing = True
if isinstance(current_state_error, StopAsyncIteration):
task.error = current_state_error
current_state_error = None
print(f'Session {task.request.session_id} sent cancel signal for task {id(task)}')
log.info(f'Session {task.task_data.session_id} sent cancel signal for task {id(task)}')
current_state = ServerStates.LoadingModel
model_manager.resolve_model_paths(task.task_data)
model_manager.reload_models_if_necessary(renderer.context, task.task_data)
current_state = ServerStates.Rendering
task.response = runtime.mk_img(task.request, task.buffer_queue, task.temp_images, step_callback)
task.response = renderer.make_images(task.render_request, task.task_data, task.buffer_queue, task.temp_images, step_callback)
# Before looping back to the generator, mark cache as still alive.
task_cache.keep(id(task), TASK_TTL)
session_cache.keep(task.request.session_id, TASK_TTL)
session_cache.keep(task.task_data.session_id, TASK_TTL)
except Exception as e:
task.error = e
task.response = {"status": 'failed', "detail": str(task.error)}
task.buffer_queue.put(json.dumps(task.response))
print(traceback.format_exc())
log.error(traceback.format_exc())
continue
finally:
# Task completed
task.lock.release()
task_cache.keep(id(task), TASK_TTL)
session_cache.keep(task.request.session_id, TASK_TTL)
session_cache.keep(task.task_data.session_id, TASK_TTL)
if isinstance(task.error, StopAsyncIteration):
print(f'Session {task.request.session_id} task {id(task)} cancelled!')
log.info(f'Session {task.task_data.session_id} task {id(task)} cancelled!')
elif task.error is not None:
print(f'Session {task.request.session_id} task {id(task)} failed!')
log.info(f'Session {task.task_data.session_id} task {id(task)} failed!')
else:
print(f'Session {task.request.session_id} task {id(task)} completed by {runtime.thread_data.device_name}.')
log.info(f'Session {task.task_data.session_id} task {id(task)} completed by {renderer.context.device_name}.')
current_state = ServerStates.Online
def get_cached_task(task_id:str, update_ttl:bool=False):
@ -423,6 +339,7 @@ def get_devices():
'name': torch.cuda.get_device_name(device),
'mem_free': mem_free,
'mem_total': mem_total,
'max_vram_usage_level': device_manager.get_max_vram_usage_level(device),
}
# list the compatible devices
@ -472,7 +389,7 @@ def is_alive(device=None):
def start_render_thread(device):
if not manager_lock.acquire(blocking=True, timeout=LOCK_TIMEOUT): raise Exception('start_render_thread' + ERR_LOCK_FAILED)
print('Start new Rendering Thread on device', device)
log.info(f'Start new Rendering Thread on device: {device}')
try:
rthread = threading.Thread(target=thread_render, kwargs={'device': device})
rthread.daemon = True
@ -484,7 +401,7 @@ def start_render_thread(device):
timeout = DEVICE_START_TIMEOUT
while not rthread.is_alive() or not rthread in weak_thread_data or not 'device' in weak_thread_data[rthread]:
if rthread in weak_thread_data and 'error' in weak_thread_data[rthread]:
print(rthread, device, 'error:', weak_thread_data[rthread]['error'])
log.error(f"{rthread}, {device}, error: {weak_thread_data[rthread]['error']}")
return False
if timeout <= 0:
return False
@ -496,11 +413,11 @@ def stop_render_thread(device):
try:
device_manager.validate_device_id(device, log_prefix='stop_render_thread')
except:
print(traceback.format_exc())
log.error(traceback.format_exc())
return False
if not manager_lock.acquire(blocking=True, timeout=LOCK_TIMEOUT): raise Exception('stop_render_thread' + ERR_LOCK_FAILED)
print('Stopping Rendering Thread on device', device)
log.info(f'Stopping Rendering Thread on device: {device}')
try:
thread_to_remove = None
@ -523,79 +440,44 @@ def stop_render_thread(device):
def update_render_threads(render_devices, active_devices):
devices_to_start, devices_to_stop = device_manager.get_device_delta(render_devices, active_devices)
print('devices_to_start', devices_to_start)
print('devices_to_stop', devices_to_stop)
log.debug(f'devices_to_start: {devices_to_start}')
log.debug(f'devices_to_stop: {devices_to_stop}')
for device in devices_to_stop:
if is_alive(device) <= 0:
print(device, 'is not alive')
log.debug(f'{device} is not alive')
continue
if not stop_render_thread(device):
print(device, 'could not stop render thread')
log.warn(f'{device} could not stop render thread')
for device in devices_to_start:
if is_alive(device) >= 1:
print(device, 'already registered.')
log.debug(f'{device} already registered.')
continue
if not start_render_thread(device):
print(device, 'failed to start.')
log.warn(f'{device} failed to start.')
if is_alive() <= 0: # No running devices, probably invalid user config.
raise EnvironmentError('ERROR: No active render devices! Please verify the "render_devices" value in config.json')
print('active devices', get_devices()['active'])
log.debug(f"active devices: {get_devices()['active']}")
def shutdown_event(): # Signal render thread to close on shutdown
global current_state_error
current_state_error = SystemExit('Application shutting down.')
def render(req : ImageRequest):
def render(render_req: GenerateImageRequest, task_data: TaskData):
current_thread_count = is_alive()
if current_thread_count <= 0: # Render thread is dead
raise ChildProcessError('Rendering thread has died.')
# Alive, check if task in cache
session = get_cached_session(req.session_id, update_ttl=True)
session = get_cached_session(task_data.session_id, update_ttl=True)
pending_tasks = list(filter(lambda t: t.is_pending, session.tasks))
if current_thread_count < len(pending_tasks):
raise ConnectionRefusedError(f'Session {req.session_id} already has {len(pending_tasks)} pending tasks out of {current_thread_count}.')
raise ConnectionRefusedError(f'Session {task_data.session_id} already has {len(pending_tasks)} pending tasks out of {current_thread_count}.')
from . import runtime
r = Request()
r.session_id = req.session_id
r.prompt = req.prompt
r.negative_prompt = req.negative_prompt
r.init_image = req.init_image
r.mask = req.mask
r.num_outputs = req.num_outputs
r.num_inference_steps = req.num_inference_steps
r.guidance_scale = req.guidance_scale
r.width = req.width
r.height = req.height
r.seed = req.seed
r.prompt_strength = req.prompt_strength
r.sampler = req.sampler
# r.allow_nsfw = req.allow_nsfw
r.turbo = req.turbo
r.use_full_precision = req.use_full_precision
r.save_to_disk_path = req.save_to_disk_path
r.use_upscale: str = req.use_upscale
r.use_face_correction = req.use_face_correction
r.use_stable_diffusion_model = req.use_stable_diffusion_model
r.use_vae_model = req.use_vae_model
r.use_hypernetwork_model = req.use_hypernetwork_model
r.hypernetwork_strength = req.hypernetwork_strength
r.show_only_filtered_image = req.show_only_filtered_image
r.output_format = req.output_format
r.output_quality = req.output_quality
r.stream_progress_updates = True # the underlying implementation only supports streaming
r.stream_image_progress = req.stream_image_progress
if not req.stream_progress_updates:
r.stream_image_progress = False
new_task = RenderTask(r)
new_task = RenderTask(render_req, task_data)
if session.put(new_task, TASK_TTL):
# Use twice the normal timeout for adding user requests.
# Tries to force session.put to fail before tasks_queue.put would.

87
ui/easydiffusion/types.py Normal file
View File

@ -0,0 +1,87 @@
from pydantic import BaseModel
from typing import Any
class GenerateImageRequest(BaseModel):
prompt: str = ""
negative_prompt: str = ""
seed: int = 42
width: int = 512
height: int = 512
num_outputs: int = 1
num_inference_steps: int = 50
guidance_scale: float = 7.5
init_image: Any = None
init_image_mask: Any = None
prompt_strength: float = 0.8
preserve_init_image_color_profile = False
sampler_name: str = None # "ddim", "plms", "heun", "euler", "euler_a", "dpm2", "dpm2_a", "lms"
hypernetwork_strength: float = 0
class TaskData(BaseModel):
request_id: str = None
session_id: str = "session"
save_to_disk_path: str = None
vram_usage_level: str = "balanced" # or "low" or "medium"
use_face_correction: str = None # or "GFPGANv1.3"
use_upscale: str = None # or "RealESRGAN_x4plus" or "RealESRGAN_x4plus_anime_6B"
use_stable_diffusion_model: str = "sd-v1-4"
use_stable_diffusion_config: str = "v1-inference"
use_vae_model: str = None
use_hypernetwork_model: str = None
show_only_filtered_image: bool = False
output_format: str = "jpeg" # or "png"
output_quality: int = 75
metadata_output_format: str = "txt" # or "json"
stream_image_progress: bool = False
class Image:
data: str # base64
seed: int
is_nsfw: bool
path_abs: str = None
def __init__(self, data, seed):
self.data = data
self.seed = seed
def json(self):
return {
"data": self.data,
"seed": self.seed,
"path_abs": self.path_abs,
}
class Response:
render_request: GenerateImageRequest
task_data: TaskData
images: list
def __init__(self, render_request: GenerateImageRequest, task_data: TaskData, images: list):
self.render_request = render_request
self.task_data = task_data
self.images = images
def json(self):
del self.render_request.init_image
del self.render_request.init_image_mask
res = {
"status": 'succeeded',
"render_request": self.render_request.dict(),
"task_data": self.task_data.dict(),
"output": [],
}
for image in self.images:
res["output"].append(image.json())
return res
class UserInitiatedStop(Exception):
pass

View File

@ -0,0 +1,8 @@
import logging
log = logging.getLogger('easydiffusion')
from .save_utils import (
save_images_to_disk,
get_printable_request,
)

View File

@ -0,0 +1,79 @@
import os
import time
import base64
import re
from easydiffusion.types import TaskData, GenerateImageRequest
from sdkit.utils import save_images, save_dicts
filename_regex = re.compile('[^a-zA-Z0-9]')
# keep in sync with `ui/media/js/dnd.js`
TASK_TEXT_MAPPING = {
'prompt': 'Prompt',
'width': 'Width',
'height': 'Height',
'seed': 'Seed',
'num_inference_steps': 'Steps',
'guidance_scale': 'Guidance Scale',
'prompt_strength': 'Prompt Strength',
'use_face_correction': 'Use Face Correction',
'use_upscale': 'Use Upscaling',
'sampler_name': 'Sampler',
'negative_prompt': 'Negative Prompt',
'use_stable_diffusion_model': 'Stable Diffusion model',
'use_hypernetwork_model': 'Hypernetwork model',
'hypernetwork_strength': 'Hypernetwork Strength'
}
def save_images_to_disk(images: list, filtered_images: list, req: GenerateImageRequest, task_data: TaskData):
save_dir_path = os.path.join(task_data.save_to_disk_path, filename_regex.sub('_', task_data.session_id))
metadata_entries = get_metadata_entries_for_request(req, task_data)
if task_data.show_only_filtered_image or filtered_images == images:
save_images(filtered_images, save_dir_path, file_name=make_filename_callback(req), output_format=task_data.output_format, output_quality=task_data.output_quality)
save_dicts(metadata_entries, save_dir_path, file_name=make_filename_callback(req), output_format=task_data.metadata_output_format)
else:
save_images(images, save_dir_path, file_name=make_filename_callback(req), output_format=task_data.output_format, output_quality=task_data.output_quality)
save_images(filtered_images, save_dir_path, file_name=make_filename_callback(req, suffix='filtered'), output_format=task_data.output_format, output_quality=task_data.output_quality)
save_dicts(metadata_entries, save_dir_path, file_name=make_filename_callback(req, suffix='filtered'), output_format=task_data.metadata_output_format)
def get_metadata_entries_for_request(req: GenerateImageRequest, task_data: TaskData):
metadata = get_printable_request(req)
metadata.update({
'use_stable_diffusion_model': task_data.use_stable_diffusion_model,
'use_vae_model': task_data.use_vae_model,
'use_hypernetwork_model': task_data.use_hypernetwork_model,
'use_face_correction': task_data.use_face_correction,
'use_upscale': task_data.use_upscale,
})
# if text, format it in the text format expected by the UI
is_txt_format = (task_data.metadata_output_format.lower() == 'txt')
if is_txt_format:
metadata = {TASK_TEXT_MAPPING[key]: val for key, val in metadata.items() if key in TASK_TEXT_MAPPING}
entries = [metadata.copy() for _ in range(req.num_outputs)]
for i, entry in enumerate(entries):
entry['Seed' if is_txt_format else 'seed'] = req.seed + i
return entries
def get_printable_request(req: GenerateImageRequest):
metadata = req.dict()
del metadata['init_image']
del metadata['init_image_mask']
return metadata
def make_filename_callback(req: GenerateImageRequest, suffix=None):
def make_filename(i):
img_id = base64.b64encode(int(time.time()+i).to_bytes(8, 'big')).decode() # Generate unique ID based on time.
img_id = img_id.translate({43:None, 47:None, 61:None})[-8:] # Remove + / = and keep last 8 chars.
prompt_flattened = filename_regex.sub('_', req.prompt)[:50]
name = f"{prompt_flattened}_{img_id}"
name = name if suffix is None else f'{name}_{suffix}'
return name
return make_filename

View File

@ -24,8 +24,8 @@
<div id="top-nav">
<div id="logo">
<h1>
Stable Diffusion UI
<small>v2.4.22 <span id="updateBranchLabel"></span></small>
Easy Diffusion
<small>v2.5.0 <span id="updateBranchLabel"></span></small>
</h1>
</div>
<div id="server-status">
@ -129,22 +129,32 @@
</select>
<a href="https://github.com/cmdr2/stable-diffusion-ui/wiki/Custom-Models" target="_blank"><i class="fa-solid fa-circle-question help-btn"><span class="simple-tooltip top-left">Click to learn more about custom models</span></i></a>
</td></tr>
<!-- <tr id="modelConfigSelection" class="pl-5"><td><label for="model_config">Model Config:</i></label></td><td>
<select id="model_config" name="model_config">
</select>
</td></tr> -->
<tr class="pl-5"><td><label for="vae_model">Custom VAE:</i></label></td><td>
<select id="vae_model" name="vae_model">
<!-- <option value="" selected>None</option> -->
</select>
<a href="https://github.com/cmdr2/stable-diffusion-ui/wiki/VAE-Variational-Auto-Encoder" target="_blank"><i class="fa-solid fa-circle-question help-btn"><span class="simple-tooltip top-left">Click to learn more about VAEs</span></i></a>
</td></tr>
<tr id="samplerSelection" class="pl-5"><td><label for="sampler">Sampler:</label></td><td>
<select id="sampler" name="sampler">
<option value="plms">plms</option>
<option value="ddim">ddim</option>
<option value="heun">heun</option>
<option value="euler">euler</option>
<option value="euler_a" selected>euler_a</option>
<option value="dpm2">dpm2</option>
<option value="dpm2_a">dpm2_a</option>
<option value="lms">lms</option>
<tr id="samplerSelection" class="pl-5"><td><label for="sampler_name">Sampler:</label></td><td>
<select id="sampler_name" name="sampler_name">
<option value="plms">PLMS</option>
<option value="ddim">DDIM</option>
<option value="heun">Heun</option>
<option value="euler">Euler</option>
<option value="euler_a" selected>Euler Ancestral</option>
<option value="dpm2">DPM2</option>
<option value="dpm2_a">DPM2 Ancestral</option>
<option value="lms">LMS</option>
<option value="dpm_solver_stability">DPM Solver (Stability AI)</option>
<option value="dpmpp_2s_a" selected>DPM++ 2s Ancestral</option>
<option value="dpmpp_2m">DPM++ 2m</option>
<option value="dpmpp_sde">DPM++ SDE</option>
<option value="dpm_fast">DPM Fast</option>
<option value="dpm_adaptive">DPM Adaptive</option>
</select>
<a href="https://github.com/cmdr2/stable-diffusion-ui/wiki/How-to-Use#samplers" target="_blank"><i class="fa-solid fa-circle-question help-btn"><span class="simple-tooltip top-left">Click to learn more about samplers</span></i></a>
</td></tr>
@ -220,6 +230,7 @@
<div><ul>
<li><b class="settings-subheader">Render Settings</b></li>
<li class="pl-5"><input id="stream_image_progress" name="stream_image_progress" type="checkbox"> <label for="stream_image_progress">Show a live preview <small>(uses more VRAM, slower images)</small></label></li>
<li id="apply_color_correction_setting" class="pl-5"><input id="apply_color_correction" name="apply_color_correction" type="checkbox"> <label for="apply_color_correction">Preserve color profile <small>(helps during inpainting)</small></label></li>
<li class="pl-5"><input id="use_face_correction" name="use_face_correction" type="checkbox"> <label for="use_face_correction">Fix incorrect faces and eyes <small>(uses GFPGAN)</small></label></li>
<li class="pl-5">
<input id="use_upscale" name="use_upscale" type="checkbox"> <label for="use_upscale">Upscale image by 4x with </label>
@ -416,7 +427,6 @@
async function init() {
await initSettings()
await getModels()
await getDiskPath()
await getAppConfig()
await loadUIPlugins()
await loadModifiers()

10
ui/main.py Normal file
View File

@ -0,0 +1,10 @@
from easydiffusion import model_manager, app, server
from easydiffusion.server import server_api # required for uvicorn
# Init the app
model_manager.init()
app.init()
server.init()
# start the browser ui
app.open_browser()

View File

@ -15,7 +15,7 @@ const SETTINGS_IDS_LIST = [
"stable_diffusion_model",
"vae_model",
"hypernetwork_model",
"sampler",
"sampler_name",
"width",
"height",
"num_inference_steps",
@ -36,10 +36,11 @@ const SETTINGS_IDS_LIST = [
"save_to_disk",
"diskPath",
"sound_toggle",
"turbo",
"use_full_precision",
"vram_usage_level",
"confirm_dangerous_actions",
"auto_save_settings"
"metadata_output_format",
"auto_save_settings",
"apply_color_correction"
]
const IGNORE_BY_DEFAULT = [
@ -277,7 +278,6 @@ function tryLoadOldSettings() {
"soundEnabled": "sound_toggle",
"saveToDisk": "save_to_disk",
"useCPU": "use_cpu",
"useFullPrecision": "use_full_precision",
"useTurboMode": "turbo",
"diskPath": "diskPath",
"useFaceCorrection": "use_face_correction",

View File

@ -25,6 +25,7 @@ function parseBoolean(stringValue) {
case "no":
case "off":
case "0":
case "none":
case null:
case undefined:
return false;
@ -160,9 +161,9 @@ const TASK_MAPPING = {
readUI: () => (useUpscalingField.checked ? upscaleModelField.value : undefined),
parse: (val) => val
},
sampler: { name: 'Sampler',
setUI: (sampler) => {
samplerField.value = sampler
sampler_name: { name: 'Sampler',
setUI: (sampler_name) => {
samplerField.value = sampler_name
},
readUI: () => samplerField.value,
parse: (val) => val
@ -171,7 +172,7 @@ const TASK_MAPPING = {
setUI: (use_stable_diffusion_model) => {
const oldVal = stableDiffusionModelField.value
use_stable_diffusion_model = getModelPath(use_stable_diffusion_model, ['.ckpt'])
use_stable_diffusion_model = getModelPath(use_stable_diffusion_model, ['.ckpt', '.safetensors'])
stableDiffusionModelField.value = use_stable_diffusion_model
if (!stableDiffusionModelField.value) {
@ -184,6 +185,7 @@ const TASK_MAPPING = {
use_vae_model: { name: 'VAE model',
setUI: (use_vae_model) => {
const oldVal = vaeModelField.value
use_vae_model = (use_vae_model === undefined || use_vae_model === null || use_vae_model === 'None' ? '' : use_vae_model)
if (use_vae_model !== '') {
use_vae_model = getModelPath(use_vae_model, ['.vae.pt', '.ckpt'])
@ -197,6 +199,7 @@ const TASK_MAPPING = {
use_hypernetwork_model: { name: 'Hypernetwork model',
setUI: (use_hypernetwork_model) => {
const oldVal = hypernetworkModelField.value
use_hypernetwork_model = (use_hypernetwork_model === undefined || use_hypernetwork_model === null || use_hypernetwork_model === 'None' ? '' : use_hypernetwork_model)
if (use_hypernetwork_model !== '') {
use_hypernetwork_model = getModelPath(use_hypernetwork_model, ['.pt'])
@ -239,13 +242,6 @@ const TASK_MAPPING = {
readUI: () => turboField.checked,
parse: (val) => Boolean(val)
},
use_full_precision: { name: 'Use Full Precision',
setUI: (use_full_precision) => {
useFullPrecisionField.checked = use_full_precision
},
readUI: () => useFullPrecisionField.checked,
parse: (val) => Boolean(val)
},
stream_image_progress: { name: 'Stream Image Progress',
setUI: (stream_image_progress) => {
@ -350,6 +346,7 @@ function getModelPath(filename, extensions)
}
const TASK_TEXT_MAPPING = {
prompt: 'Prompt',
width: 'Width',
height: 'Height',
seed: 'Seed',
@ -358,7 +355,7 @@ const TASK_TEXT_MAPPING = {
prompt_strength: 'Prompt Strength',
use_face_correction: 'Use Face Correction',
use_upscale: 'Use Upscaling',
sampler: 'Sampler',
sampler_name: 'Sampler',
negative_prompt: 'Negative Prompt',
use_stable_diffusion_model: 'Stable Diffusion model',
use_hypernetwork_model: 'Hypernetwork model',
@ -410,6 +407,9 @@ async function parseContent(text) {
if (text.startsWith('{') && text.endsWith('}')) {
try {
const task = JSON.parse(text)
if (!('reqBody' in task)) { // support the format saved to the disk, by the UI
task.reqBody = Object.assign({}, task)
}
restoreTaskToUI(task)
return true
} catch (e) {
@ -477,7 +477,6 @@ document.addEventListener("dragover", dragOverHandler)
const TASK_REQ_NO_EXPORT = [
"use_cpu",
"turbo",
"use_full_precision",
"save_to_disk_path"
]
const resetSettings = document.getElementById('reset-image-settings')

View File

@ -728,7 +728,6 @@
"stream_image_progress": 'boolean',
"show_only_filtered_image": 'boolean',
"turbo": 'boolean',
"use_full_precision": 'boolean',
"output_format": 'string',
"output_quality": 'number',
}
@ -744,7 +743,6 @@
"stream_image_progress": true,
"show_only_filtered_image": true,
"turbo": false,
"use_full_precision": false,
"output_format": "png",
"output_quality": 75,
}

View File

@ -26,9 +26,11 @@ let initImagePreview = document.querySelector("#init_image_preview")
let initImageSizeBox = document.querySelector("#init_image_size_box")
let maskImageSelector = document.querySelector("#mask")
let maskImagePreview = document.querySelector("#mask_preview")
let applyColorCorrectionField = document.querySelector('#apply_color_correction')
let colorCorrectionSetting = document.querySelector('#apply_color_correction_setting')
let promptStrengthSlider = document.querySelector('#prompt_strength_slider')
let promptStrengthField = document.querySelector('#prompt_strength')
let samplerField = document.querySelector('#sampler')
let samplerField = document.querySelector('#sampler_name')
let samplerSelectionContainer = document.querySelector("#samplerSelection")
let useFaceCorrectionField = document.querySelector("#use_face_correction")
let useUpscalingField = document.querySelector("#use_upscale")
@ -610,7 +612,7 @@ function onTaskCompleted(task, reqBody, instance, outputContainer, stepUpdate) {
<b>Suggestions</b>:
<br/>
1. If you have set an initial image, please try reducing its dimension to ${MAX_INIT_IMAGE_DIMENSION}x${MAX_INIT_IMAGE_DIMENSION} or smaller.<br/>
2. Try disabling the '<em>Turbo mode</em>' under '<em>Advanced Settings</em>'.<br/>
2. Try picking a lower level in the '<em>GPU Memory Usage</em>' setting (in the '<em>Settings</em>' tab).<br/>
3. Try generating a smaller image.<br/>`
}
} else {
@ -789,7 +791,8 @@ function createTask(task) {
let w = task.reqBody.width * h / task.reqBody.height >>0
taskConfig += `<div class="task-initimg" style="float:left;"><img style="width:${w}px;height:${h}px;" src="${task.reqBody.init_image}"><div class="task-fs-initimage"></div></div>`
}
taskConfig += `<b>Seed:</b> ${task.seed}, <b>Sampler:</b> ${task.reqBody.sampler}, <b>Inference Steps:</b> ${task.reqBody.num_inference_steps}, <b>Guidance Scale:</b> ${task.reqBody.guidance_scale}, <b>Model:</b> ${task.reqBody.use_stable_diffusion_model}`
taskConfig += `<b>Seed:</b> ${task.seed}, <b>Sampler:</b> ${task.reqBody.sampler_name}, <b>Inference Steps:</b> ${task.reqBody.num_inference_steps}, <b>Guidance Scale:</b> ${task.reqBody.guidance_scale}, <b>Model:</b> ${task.reqBody.use_stable_diffusion_model}`
if (task.reqBody.use_vae_model.trim() !== '') {
taskConfig += `, <b>VAE:</b> ${task.reqBody.use_vae_model}`
}
@ -809,6 +812,9 @@ function createTask(task) {
taskConfig += `, <b>Hypernetwork:</b> ${task.reqBody.use_hypernetwork_model}`
taskConfig += `, <b>Hypernetwork Strength:</b> ${task.reqBody.hypernetwork_strength}`
}
if (task.reqBody.preserve_init_image_color_profile) {
taskConfig += `, <b>Preserve Color Profile:</b> true`
}
let taskEntry = document.createElement('div')
taskEntry.id = `imageTaskContainer-${Date.now()}`
@ -914,9 +920,8 @@ function getCurrentUserRequest() {
width: parseInt(widthField.value),
height: parseInt(heightField.value),
// allow_nsfw: allowNSFWField.checked,
turbo: turboField.checked,
vram_usage_level: vramUsageLevelField.value,
//render_device: undefined, // Set device affinity. Prefer this device, but wont activate.
use_full_precision: useFullPrecisionField.checked,
use_stable_diffusion_model: stableDiffusionModelField.value,
use_vae_model: vaeModelField.value,
stream_progress_updates: true,
@ -924,6 +929,7 @@ function getCurrentUserRequest() {
show_only_filtered_image: showOnlyFilteredImageField.checked,
output_format: outputFormatField.value,
output_quality: parseInt(outputQualityField.value),
metadata_output_format: document.querySelector('#metadata_output_format').value,
original_prompt: promptField.value,
active_tags: (activeTags.map(x => x.name))
}
@ -938,9 +944,10 @@ function getCurrentUserRequest() {
if (maskSetting.checked) {
newTask.reqBody.mask = imageInpainter.getImg()
}
newTask.reqBody.sampler = 'ddim'
newTask.reqBody.preserve_init_image_color_profile = applyColorCorrectionField.checked
newTask.reqBody.sampler_name = 'ddim'
} else {
newTask.reqBody.sampler = samplerField.value
newTask.reqBody.sampler_name = samplerField.value
}
if (saveToDiskField.checked && diskPathField.value.trim() !== '') {
newTask.reqBody.save_to_disk_path = diskPathField.value.trim()
@ -1349,6 +1356,7 @@ function img2imgLoad() {
promptStrengthContainer.style.display = 'table-row'
samplerSelectionContainer.style.display = "none"
initImagePreviewContainer.classList.add("has-image")
colorCorrectionSetting.style.display = ''
initImageSizeBox.textContent = initImagePreview.naturalWidth + " x " + initImagePreview.naturalHeight
imageEditor.setImage(this.src, initImagePreview.naturalWidth, initImagePreview.naturalHeight)
@ -1363,6 +1371,7 @@ function img2imgUnload() {
promptStrengthContainer.style.display = "none"
samplerSelectionContainer.style.display = ""
initImagePreviewContainer.classList.remove("has-image")
colorCorrectionSetting.style.display = 'none'
imageEditor.setImage(null, parseInt(widthField.value), parseInt(heightField.value))
}

View File

@ -53,6 +53,23 @@ var PARAMETERS = [
return `<input id="${parameter.id}" name="${parameter.id}" size="30" disabled>`
}
},
{
id: "metadata_output_format",
type: ParameterType.select,
label: "Metadata format",
note: "will be saved to disk in this format",
default: "txt",
options: [
{
value: "txt",
label: "txt"
},
{
value: "json",
label: "json"
}
],
},
{
id: "sound_toggle",
type: ParameterType.checkbox,
@ -77,12 +94,20 @@ var PARAMETERS = [
default: true,
},
{
id: "turbo",
type: ParameterType.checkbox,
label: "Turbo Mode",
note: "generates images faster, but uses an additional 1 GB of GPU memory",
id: "vram_usage_level",
type: ParameterType.select,
label: "GPU Memory Usage",
note: "Faster performance requires more GPU memory (VRAM)<br/><br/>" +
"<b>Balanced:</b> nearly as fast as High, much lower VRAM usage<br/>" +
"<b>High:</b> fastest, maximum GPU memory usage</br>" +
"<b>Low:</b> slowest, force-used for GPUs with 4 GB (or less) memory",
icon: "fa-forward",
default: true,
default: "balanced",
options: [
{value: "balanced", label: "Balanced"},
{value: "high", label: "High"},
{value: "low", label: "Low"}
],
},
{
id: "use_cpu",
@ -105,14 +130,6 @@ var PARAMETERS = [
note: "to process in parallel",
default: false,
},
{
id: "use_full_precision",
type: ParameterType.checkbox,
label: "Use Full Precision",
note: "for GPU-only. warning: this will consume more VRAM",
icon: "fa-crosshairs",
default: false,
},
{
id: "auto_save_settings",
type: ParameterType.checkbox,
@ -147,14 +164,6 @@ var PARAMETERS = [
return `<input id="${parameter.id}" name="${parameter.id}" size="6" value="9000" onkeypress="preventNonNumericalInput(event)">`
}
},
{
id: "test_sd2",
type: ParameterType.checkbox,
label: "Test SD 2.0",
note: "Experimental! High memory usage! GPU-only! Not the final version! Please restart the program after changing this.",
icon: "fa-fire",
default: false,
},
{
id: "use_beta_channel",
type: ParameterType.checkbox,
@ -210,16 +219,14 @@ function initParameters() {
initParameters()
let turboField = document.querySelector('#turbo')
let vramUsageLevelField = document.querySelector('#vram_usage_level')
let useCPUField = document.querySelector('#use_cpu')
let autoPickGPUsField = document.querySelector('#auto_pick_gpus')
let useGPUsField = document.querySelector('#use_gpus')
let useFullPrecisionField = document.querySelector('#use_full_precision')
let saveToDiskField = document.querySelector('#save_to_disk')
let diskPathField = document.querySelector('#diskPath')
let listenToNetworkField = document.querySelector("#listen_to_network")
let listenPortField = document.querySelector("#listen_port")
let testSD2Field = document.querySelector("#test_sd2")
let useBetaChannelField = document.querySelector("#use_beta_channel")
let uiOpenBrowserOnStartField = document.querySelector("#ui_open_browser_on_start")
let confirmDangerousActionsField = document.querySelector("#confirm_dangerous_actions")
@ -256,12 +263,6 @@ async function getAppConfig() {
if (config.ui && config.ui.open_browser_on_start === false) {
uiOpenBrowserOnStartField.checked = false
}
if ('test_sd2' in config) {
testSD2Field.checked = config['test_sd2']
}
let testSD2SettingEntry = getParameterSettingsEntry('test_sd2')
testSD2SettingEntry.style.display = (config.update_branch === 'beta' ? '' : 'none')
if (config.net && config.net.listen_to_network === false) {
listenToNetworkField.checked = false
}
@ -327,20 +328,10 @@ autoPickGPUsField.addEventListener('click', function() {
gpuSettingEntry.style.display = (this.checked ? 'none' : '')
})
async function getDiskPath() {
try {
async function setDiskPath(defaultDiskPath) {
var diskPath = getSetting("diskPath")
if (diskPath == '' || diskPath == undefined || diskPath == "undefined") {
let res = await fetch('/get/output_dir')
if (res.status === 200) {
res = await res.json()
res = res.output_dir
setSetting("diskPath", res)
}
}
} catch (e) {
console.log('error fetching output dir path', e)
setSetting("diskPath", defaultDiskPath)
}
}
@ -415,6 +406,7 @@ async function getSystemInfo() {
setDeviceInfo(devices)
setHostInfo(res['hosts'])
setDiskPath(res['default_output_dir'])
} catch (e) {
console.log('error fetching devices', e)
}
@ -435,8 +427,7 @@ saveSettingsBtn.addEventListener('click', function() {
'update_branch': updateBranch,
'ui_open_browser_on_start': uiOpenBrowserOnStartField.checked,
'listen_to_network': listenToNetworkField.checked,
'listen_port': listenPortField.value,
'test_sd2': testSD2Field.checked
'listen_port': listenPortField.value
})
saveSettingsBtn.classList.add('active')
asyncDelay(300).then(() => saveSettingsBtn.classList.remove('active'))

View File

@ -1,119 +0,0 @@
import json
class Request:
request_id: str = None
session_id: str = "session"
prompt: str = ""
negative_prompt: str = ""
init_image: str = None # base64
mask: str = None # base64
num_outputs: int = 1
num_inference_steps: int = 50
guidance_scale: float = 7.5
width: int = 512
height: int = 512
seed: int = 42
prompt_strength: float = 0.8
sampler: str = None # "ddim", "plms", "heun", "euler", "euler_a", "dpm2", "dpm2_a", "lms"
# allow_nsfw: bool = False
precision: str = "autocast" # or "full"
save_to_disk_path: str = None
turbo: bool = True
use_full_precision: bool = False
use_face_correction: str = None # or "GFPGANv1.3"
use_upscale: str = None # or "RealESRGAN_x4plus" or "RealESRGAN_x4plus_anime_6B"
use_stable_diffusion_model: str = "sd-v1-4"
use_vae_model: str = None
use_hypernetwork_model: str = None
hypernetwork_strength: float = 1
show_only_filtered_image: bool = False
output_format: str = "jpeg" # or "png"
output_quality: int = 75
stream_progress_updates: bool = False
stream_image_progress: bool = False
def json(self):
return {
"session_id": self.session_id,
"prompt": self.prompt,
"negative_prompt": self.negative_prompt,
"num_outputs": self.num_outputs,
"num_inference_steps": self.num_inference_steps,
"guidance_scale": self.guidance_scale,
"hypernetwork_strengtgh": self.guidance_scale,
"width": self.width,
"height": self.height,
"seed": self.seed,
"prompt_strength": self.prompt_strength,
"sampler": self.sampler,
"use_face_correction": self.use_face_correction,
"use_upscale": self.use_upscale,
"use_stable_diffusion_model": self.use_stable_diffusion_model,
"use_vae_model": self.use_vae_model,
"use_hypernetwork_model": self.use_hypernetwork_model,
"hypernetwork_strength": self.hypernetwork_strength,
"output_format": self.output_format,
"output_quality": self.output_quality,
}
def __str__(self):
return f'''
session_id: {self.session_id}
prompt: {self.prompt}
negative_prompt: {self.negative_prompt}
seed: {self.seed}
num_inference_steps: {self.num_inference_steps}
sampler: {self.sampler}
guidance_scale: {self.guidance_scale}
w: {self.width}
h: {self.height}
precision: {self.precision}
save_to_disk_path: {self.save_to_disk_path}
turbo: {self.turbo}
use_full_precision: {self.use_full_precision}
use_face_correction: {self.use_face_correction}
use_upscale: {self.use_upscale}
use_stable_diffusion_model: {self.use_stable_diffusion_model}
use_vae_model: {self.use_vae_model}
use_hypernetwork_model: {self.use_hypernetwork_model}
hypernetwork_strength: {self.hypernetwork_strength}
show_only_filtered_image: {self.show_only_filtered_image}
output_format: {self.output_format}
output_quality: {self.output_quality}
stream_progress_updates: {self.stream_progress_updates}
stream_image_progress: {self.stream_image_progress}'''
class Image:
data: str # base64
seed: int
is_nsfw: bool
path_abs: str = None
def __init__(self, data, seed):
self.data = data
self.seed = seed
def json(self):
return {
"data": self.data,
"seed": self.seed,
"path_abs": self.path_abs,
}
class Response:
request: Request
images: list
def json(self):
res = {
"status": 'succeeded',
"request": self.request.json(),
"output": [],
}
for image in self.images:
res["output"].append(image.json())
return res

View File

@ -1,162 +0,0 @@
diff --git a/optimizedSD/ddpm.py b/optimizedSD/ddpm.py
index 79058bc..a473411 100644
--- a/optimizedSD/ddpm.py
+++ b/optimizedSD/ddpm.py
@@ -564,12 +564,12 @@ class UNet(DDPM):
unconditional_guidance_scale=unconditional_guidance_scale,
callback=callback, img_callback=img_callback)
+ yield from samples
+
if(self.turbo):
self.model1.to("cpu")
self.model2.to("cpu")
- return samples
-
@torch.no_grad()
def plms_sampling(self, cond,b, img,
ddim_use_original_steps=False,
@@ -608,10 +608,10 @@ class UNet(DDPM):
old_eps.append(e_t)
if len(old_eps) >= 4:
old_eps.pop(0)
- if callback: callback(i)
- if img_callback: img_callback(pred_x0, i)
+ if callback: yield from callback(i)
+ if img_callback: yield from img_callback(pred_x0, i)
- return img
+ yield from img_callback(img, len(iterator)-1)
@torch.no_grad()
def p_sample_plms(self, x, c, t, index, repeat_noise=False, use_original_steps=False, quantize_denoised=False,
@@ -740,13 +740,13 @@ class UNet(DDPM):
unconditional_guidance_scale=unconditional_guidance_scale,
unconditional_conditioning=unconditional_conditioning)
- if callback: callback(i)
- if img_callback: img_callback(x_dec, i)
+ if callback: yield from callback(i)
+ if img_callback: yield from img_callback(x_dec, i)
if mask is not None:
- return x0 * mask + (1. - mask) * x_dec
+ x_dec = x0 * mask + (1. - mask) * x_dec
- return x_dec
+ yield from img_callback(x_dec, len(iterator)-1)
@torch.no_grad()
@@ -820,12 +820,12 @@ class UNet(DDPM):
d = to_d(x, sigma_hat, denoised)
- if callback: callback(i)
- if img_callback: img_callback(x, i)
+ if callback: yield from callback(i)
+ if img_callback: yield from img_callback(x, i)
dt = sigmas[i + 1] - sigma_hat
# Euler method
x = x + d * dt
- return x
+ yield from img_callback(x, len(sigmas)-1)
@torch.no_grad()
def euler_ancestral_sampling(self,ac,x, S, cond, unconditional_conditioning = None, unconditional_guidance_scale = 1,extra_args=None, callback=None, disable=None, img_callback=None):
@@ -852,14 +852,14 @@ class UNet(DDPM):
denoised = e_t_uncond + unconditional_guidance_scale * (e_t - e_t_uncond)
sigma_down, sigma_up = get_ancestral_step(sigmas[i], sigmas[i + 1])
- if callback: callback(i)
- if img_callback: img_callback(x, i)
+ if callback: yield from callback(i)
+ if img_callback: yield from img_callback(x, i)
d = to_d(x, sigmas[i], denoised)
# Euler method
dt = sigma_down - sigmas[i]
x = x + d * dt
x = x + torch.randn_like(x) * sigma_up
- return x
+ yield from img_callback(x, len(sigmas)-1)
@@ -892,8 +892,8 @@ class UNet(DDPM):
denoised = e_t_uncond + unconditional_guidance_scale * (e_t - e_t_uncond)
d = to_d(x, sigma_hat, denoised)
- if callback: callback(i)
- if img_callback: img_callback(x, i)
+ if callback: yield from callback(i)
+ if img_callback: yield from img_callback(x, i)
dt = sigmas[i + 1] - sigma_hat
if sigmas[i + 1] == 0:
# Euler method
@@ -913,7 +913,7 @@ class UNet(DDPM):
d_2 = to_d(x_2, sigmas[i + 1], denoised_2)
d_prime = (d + d_2) / 2
x = x + d_prime * dt
- return x
+ yield from img_callback(x, len(sigmas)-1)
@torch.no_grad()
@@ -944,8 +944,8 @@ class UNet(DDPM):
e_t_uncond, e_t = (x_in + eps * c_out).chunk(2)
denoised = e_t_uncond + unconditional_guidance_scale * (e_t - e_t_uncond)
- if callback: callback(i)
- if img_callback: img_callback(x, i)
+ if callback: yield from callback(i)
+ if img_callback: yield from img_callback(x, i)
d = to_d(x, sigma_hat, denoised)
# Midpoint method, where the midpoint is chosen according to a rho=3 Karras schedule
@@ -966,7 +966,7 @@ class UNet(DDPM):
d_2 = to_d(x_2, sigma_mid, denoised_2)
x = x + d_2 * dt_2
- return x
+ yield from img_callback(x, len(sigmas)-1)
@torch.no_grad()
@@ -994,8 +994,8 @@ class UNet(DDPM):
sigma_down, sigma_up = get_ancestral_step(sigmas[i], sigmas[i + 1])
- if callback: callback(i)
- if img_callback: img_callback(x, i)
+ if callback: yield from callback(i)
+ if img_callback: yield from img_callback(x, i)
d = to_d(x, sigmas[i], denoised)
# Midpoint method, where the midpoint is chosen according to a rho=3 Karras schedule
sigma_mid = ((sigmas[i] ** (1 / 3) + sigma_down ** (1 / 3)) / 2) ** 3
@@ -1016,7 +1016,7 @@ class UNet(DDPM):
d_2 = to_d(x_2, sigma_mid, denoised_2)
x = x + d_2 * dt_2
x = x + torch.randn_like(x) * sigma_up
- return x
+ yield from img_callback(x, len(sigmas)-1)
@torch.no_grad()
@@ -1042,8 +1042,8 @@ class UNet(DDPM):
e_t_uncond, e_t = (x_in + eps * c_out).chunk(2)
denoised = e_t_uncond + unconditional_guidance_scale * (e_t - e_t_uncond)
- if callback: callback(i)
- if img_callback: img_callback(x, i)
+ if callback: yield from callback(i)
+ if img_callback: yield from img_callback(x, i)
d = to_d(x, sigmas[i], denoised)
ds.append(d)
@@ -1054,4 +1054,4 @@ class UNet(DDPM):
cur_order = min(i + 1, order)
coeffs = [linear_multistep_coeff(cur_order, sigmas.cpu(), i, j) for j in range(cur_order)]
x = x + sum(coeff * d for coeff, d in zip(coeffs, reversed(ds)))
- return x
+ yield from img_callback(x, len(sigmas)-1)

View File

@ -1,84 +0,0 @@
diff --git a/ldm/models/diffusion/ddim.py b/ldm/models/diffusion/ddim.py
index 27ead0e..6215939 100644
--- a/ldm/models/diffusion/ddim.py
+++ b/ldm/models/diffusion/ddim.py
@@ -100,7 +100,7 @@ class DDIMSampler(object):
size = (batch_size, C, H, W)
print(f'Data shape for DDIM sampling is {size}, eta {eta}')
- samples, intermediates = self.ddim_sampling(conditioning, size,
+ samples = self.ddim_sampling(conditioning, size,
callback=callback,
img_callback=img_callback,
quantize_denoised=quantize_x0,
@@ -117,7 +117,8 @@ class DDIMSampler(object):
dynamic_threshold=dynamic_threshold,
ucg_schedule=ucg_schedule
)
- return samples, intermediates
+ # return samples, intermediates
+ yield from samples
@torch.no_grad()
def ddim_sampling(self, cond, shape,
@@ -168,14 +169,15 @@ class DDIMSampler(object):
unconditional_conditioning=unconditional_conditioning,
dynamic_threshold=dynamic_threshold)
img, pred_x0 = outs
- if callback: callback(i)
- if img_callback: img_callback(pred_x0, i)
+ if callback: yield from callback(i)
+ if img_callback: yield from img_callback(pred_x0, i)
if index % log_every_t == 0 or index == total_steps - 1:
intermediates['x_inter'].append(img)
intermediates['pred_x0'].append(pred_x0)
- return img, intermediates
+ # return img, intermediates
+ yield from img_callback(pred_x0, len(iterator)-1)
@torch.no_grad()
def p_sample_ddim(self, x, c, t, index, repeat_noise=False, use_original_steps=False, quantize_denoised=False,
diff --git a/ldm/models/diffusion/plms.py b/ldm/models/diffusion/plms.py
index 7002a36..0951f39 100644
--- a/ldm/models/diffusion/plms.py
+++ b/ldm/models/diffusion/plms.py
@@ -96,7 +96,7 @@ class PLMSSampler(object):
size = (batch_size, C, H, W)
print(f'Data shape for PLMS sampling is {size}')
- samples, intermediates = self.plms_sampling(conditioning, size,
+ samples = self.plms_sampling(conditioning, size,
callback=callback,
img_callback=img_callback,
quantize_denoised=quantize_x0,
@@ -112,7 +112,8 @@ class PLMSSampler(object):
unconditional_conditioning=unconditional_conditioning,
dynamic_threshold=dynamic_threshold,
)
- return samples, intermediates
+ #return samples, intermediates
+ yield from samples
@torch.no_grad()
def plms_sampling(self, cond, shape,
@@ -165,14 +166,15 @@ class PLMSSampler(object):
old_eps.append(e_t)
if len(old_eps) >= 4:
old_eps.pop(0)
- if callback: callback(i)
- if img_callback: img_callback(pred_x0, i)
+ if callback: yield from callback(i)
+ if img_callback: yield from img_callback(pred_x0, i)
if index % log_every_t == 0 or index == total_steps - 1:
intermediates['x_inter'].append(img)
intermediates['pred_x0'].append(pred_x0)
- return img, intermediates
+ # return img, intermediates
+ yield from img_callback(pred_x0, len(iterator)-1)
@torch.no_grad()
def p_sample_plms(self, x, c, t, index, repeat_noise=False, use_original_steps=False, quantize_denoised=False,

View File

@ -1,198 +0,0 @@
# this is basically a cut down version of https://github.com/AUTOMATIC1111/stable-diffusion-webui/blob/c9a2cfdf2a53d37c2de1908423e4f548088667ef/modules/hypernetworks/hypernetwork.py, mostly for feature parity
# I, c0bra5, don't really understand how deep learning works. I just know how to port stuff.
import inspect
import torch
import optimizedSD.splitAttention
from . import runtime
from einops import rearrange
optimizer_dict = {optim_name : cls_obj for optim_name, cls_obj in inspect.getmembers(torch.optim, inspect.isclass) if optim_name != "Optimizer"}
loaded_hypernetwork = None
class HypernetworkModule(torch.nn.Module):
multiplier = 0.5
activation_dict = {
"linear": torch.nn.Identity,
"relu": torch.nn.ReLU,
"leakyrelu": torch.nn.LeakyReLU,
"elu": torch.nn.ELU,
"swish": torch.nn.Hardswish,
"tanh": torch.nn.Tanh,
"sigmoid": torch.nn.Sigmoid,
}
activation_dict.update({cls_name.lower(): cls_obj for cls_name, cls_obj in inspect.getmembers(torch.nn.modules.activation) if inspect.isclass(cls_obj) and cls_obj.__module__ == 'torch.nn.modules.activation'})
def __init__(self, dim, state_dict=None, layer_structure=None, activation_func=None, weight_init='Normal',
add_layer_norm=False, use_dropout=False, activate_output=False, last_layer_dropout=False):
super().__init__()
assert layer_structure is not None, "layer_structure must not be None"
assert layer_structure[0] == 1, "Multiplier Sequence should start with size 1!"
assert layer_structure[-1] == 1, "Multiplier Sequence should end with size 1!"
linears = []
for i in range(len(layer_structure) - 1):
# Add a fully-connected layer
linears.append(torch.nn.Linear(int(dim * layer_structure[i]), int(dim * layer_structure[i+1])))
# Add an activation func except last layer
if activation_func == "linear" or activation_func is None or (i >= len(layer_structure) - 2 and not activate_output):
pass
elif activation_func in self.activation_dict:
linears.append(self.activation_dict[activation_func]())
else:
raise RuntimeError(f'hypernetwork uses an unsupported activation function: {activation_func}')
# Add layer normalization
if add_layer_norm:
linears.append(torch.nn.LayerNorm(int(dim * layer_structure[i+1])))
# Add dropout except last layer
if use_dropout and (i < len(layer_structure) - 3 or last_layer_dropout and i < len(layer_structure) - 2):
linears.append(torch.nn.Dropout(p=0.3))
self.linear = torch.nn.Sequential(*linears)
self.fix_old_state_dict(state_dict)
self.load_state_dict(state_dict)
self.to(runtime.thread_data.device)
def fix_old_state_dict(self, state_dict):
changes = {
'linear1.bias': 'linear.0.bias',
'linear1.weight': 'linear.0.weight',
'linear2.bias': 'linear.1.bias',
'linear2.weight': 'linear.1.weight',
}
for fr, to in changes.items():
x = state_dict.get(fr, None)
if x is None:
continue
del state_dict[fr]
state_dict[to] = x
def forward(self, x: torch.Tensor):
return x + self.linear(x) * runtime.thread_data.hypernetwork_strength
def apply_hypernetwork(hypernetwork, context, layer=None):
hypernetwork_layers = hypernetwork.get(context.shape[2], None)
if hypernetwork_layers is None:
return context, context
if layer is not None:
layer.hyper_k = hypernetwork_layers[0]
layer.hyper_v = hypernetwork_layers[1]
context_k = hypernetwork_layers[0](context)
context_v = hypernetwork_layers[1](context)
return context_k, context_v
def get_kv(context, hypernetwork):
if hypernetwork is None:
return context, context
else:
return apply_hypernetwork(runtime.thread_data.hypernetwork, context)
# This might need updating as the optimisedSD code changes
# I think yall have a system for this (patch files in sd_internal) but idk how it works and no amount of searching gave me any clue
# just in case for attribution https://github.com/easydiffusion/diffusion-kit/blob/e8ea0cadd543056059cd951e76d4744de76327d2/optimizedSD/splitAttention.py#L171
def new_cross_attention_forward(self, x, context=None, mask=None):
h = self.heads
q = self.to_q(x)
# default context
context = context if context is not None else x() if inspect.isfunction(x) else x
# hypernetwork!
context_k, context_v = get_kv(context, runtime.thread_data.hypernetwork)
k = self.to_k(context_k)
v = self.to_v(context_v)
del context, x
q, k, v = map(lambda t: rearrange(t, 'b n (h d) -> (b h) n d', h=h), (q, k, v))
limit = k.shape[0]
att_step = self.att_step
q_chunks = list(torch.tensor_split(q, limit//att_step, dim=0))
k_chunks = list(torch.tensor_split(k, limit//att_step, dim=0))
v_chunks = list(torch.tensor_split(v, limit//att_step, dim=0))
q_chunks.reverse()
k_chunks.reverse()
v_chunks.reverse()
sim = torch.zeros(q.shape[0], q.shape[1], v.shape[2], device=q.device)
del k, q, v
for i in range (0, limit, att_step):
q_buffer = q_chunks.pop()
k_buffer = k_chunks.pop()
v_buffer = v_chunks.pop()
sim_buffer = torch.einsum('b i d, b j d -> b i j', q_buffer, k_buffer) * self.scale
del k_buffer, q_buffer
# attention, what we cannot get enough of, by chunks
sim_buffer = sim_buffer.softmax(dim=-1)
sim_buffer = torch.einsum('b i j, b j d -> b i d', sim_buffer, v_buffer)
del v_buffer
sim[i:i+att_step,:,:] = sim_buffer
del sim_buffer
sim = rearrange(sim, '(b h) n d -> b n (h d)', h=h)
return self.to_out(sim)
def load_hypernetwork(path: str):
state_dict = torch.load(path, map_location='cpu')
layer_structure = state_dict.get('layer_structure', [1, 2, 1])
activation_func = state_dict.get('activation_func', None)
weight_init = state_dict.get('weight_initialization', 'Normal')
add_layer_norm = state_dict.get('is_layer_norm', False)
use_dropout = state_dict.get('use_dropout', False)
activate_output = state_dict.get('activate_output', True)
last_layer_dropout = state_dict.get('last_layer_dropout', False)
# this is a bit verbose so leaving it commented out for the poor soul who ever has to debug this
# print(f"layer_structure: {layer_structure}")
# print(f"activation_func: {activation_func}")
# print(f"weight_init: {weight_init}")
# print(f"add_layer_norm: {add_layer_norm}")
# print(f"use_dropout: {use_dropout}")
# print(f"activate_output: {activate_output}")
# print(f"last_layer_dropout: {last_layer_dropout}")
layers = {}
for size, sd in state_dict.items():
if type(size) == int:
layers[size] = (
HypernetworkModule(size, sd[0], layer_structure, activation_func, weight_init, add_layer_norm,
use_dropout, activate_output, last_layer_dropout=last_layer_dropout),
HypernetworkModule(size, sd[1], layer_structure, activation_func, weight_init, add_layer_norm,
use_dropout, activate_output, last_layer_dropout=last_layer_dropout),
)
print(f"hypernetwork loaded")
return layers
# overriding of original function
old_cross_attention_forward = optimizedSD.splitAttention.CrossAttention.forward
# hijacks the cross attention forward function to add hyper network support
def hijack_cross_attention():
print("hypernetwork functionality added to cross attention")
optimizedSD.splitAttention.CrossAttention.forward = new_cross_attention_forward
# there was a cop on board
def unhijack_cross_attention_forward():
print("hypernetwork functionality removed from cross attention")
optimizedSD.splitAttention.CrossAttention.forward = old_cross_attention_forward
hijack_cross_attention()

File diff suppressed because it is too large Load Diff

View File

@ -1,500 +0,0 @@
"""server.py: FastAPI SD-UI Web Host.
Notes:
async endpoints always run on the main thread. Without they run on the thread pool.
"""
import json
import traceback
import sys
import os
import socket
import picklescan.scanner
import rich
SD_DIR = os.getcwd()
print('started in ', SD_DIR)
SD_UI_DIR = os.getenv('SD_UI_PATH', None)
sys.path.append(os.path.dirname(SD_UI_DIR))
CONFIG_DIR = os.path.abspath(os.path.join(SD_UI_DIR, '..', 'scripts'))
MODELS_DIR = os.path.abspath(os.path.join(SD_DIR, '..', 'models'))
USER_UI_PLUGINS_DIR = os.path.abspath(os.path.join(SD_DIR, '..', 'plugins', 'ui'))
CORE_UI_PLUGINS_DIR = os.path.abspath(os.path.join(SD_UI_DIR, 'plugins', 'ui'))
UI_PLUGINS_SOURCES = ((CORE_UI_PLUGINS_DIR, 'core'), (USER_UI_PLUGINS_DIR, 'user'))
STABLE_DIFFUSION_MODEL_EXTENSIONS = ['.ckpt', '.safetensors']
VAE_MODEL_EXTENSIONS = ['.vae.pt', '.ckpt']
HYPERNETWORK_MODEL_EXTENSIONS = ['.pt']
OUTPUT_DIRNAME = "Stable Diffusion UI" # in the user's home folder
TASK_TTL = 15 * 60 # Discard last session's task timeout
APP_CONFIG_DEFAULTS = {
# auto: selects the cuda device with the most free memory, cuda: use the currently active cuda device.
'render_devices': 'auto', # valid entries: 'auto', 'cpu' or 'cuda:N' (where N is a GPU index)
'update_branch': 'main',
'ui': {
'open_browser_on_start': True,
},
}
APP_CONFIG_DEFAULT_MODELS = [
# needed to support the legacy installations
'custom-model', # Check if user has a custom model, use it first.
'sd-v1-4', # Default fallback.
]
from fastapi import FastAPI, HTTPException
from fastapi.staticfiles import StaticFiles
from starlette.responses import FileResponse, JSONResponse, StreamingResponse
from pydantic import BaseModel
import logging
from typing import Any, Generator, Hashable, List, Optional, Union
from sd_internal import Request, Response, task_manager
app = FastAPI()
outpath = os.path.join(os.path.expanduser("~"), OUTPUT_DIRNAME)
os.makedirs(USER_UI_PLUGINS_DIR, exist_ok=True)
# don't show access log entries for URLs that start with the given prefix
ACCESS_LOG_SUPPRESS_PATH_PREFIXES = ['/ping', '/image', '/modifier-thumbnails']
NOCACHE_HEADERS={"Cache-Control": "no-cache, no-store, must-revalidate", "Pragma": "no-cache", "Expires": "0"}
class NoCacheStaticFiles(StaticFiles):
def is_not_modified(self, response_headers, request_headers) -> bool:
if 'content-type' in response_headers and ('javascript' in response_headers['content-type'] or 'css' in response_headers['content-type']):
response_headers.update(NOCACHE_HEADERS)
return False
return super().is_not_modified(response_headers, request_headers)
app.mount('/media', NoCacheStaticFiles(directory=os.path.join(SD_UI_DIR, 'media')), name="media")
for plugins_dir, dir_prefix in UI_PLUGINS_SOURCES:
app.mount(f'/plugins/{dir_prefix}', NoCacheStaticFiles(directory=plugins_dir), name=f"plugins-{dir_prefix}")
def getConfig(default_val=APP_CONFIG_DEFAULTS):
try:
config_json_path = os.path.join(CONFIG_DIR, 'config.json')
if not os.path.exists(config_json_path):
return default_val
with open(config_json_path, 'r', encoding='utf-8') as f:
config = json.load(f)
if 'net' not in config:
config['net'] = {}
if os.getenv('SD_UI_BIND_PORT') is not None:
config['net']['listen_port'] = int(os.getenv('SD_UI_BIND_PORT'))
if os.getenv('SD_UI_BIND_IP') is not None:
config['net']['listen_to_network'] = ( os.getenv('SD_UI_BIND_IP') == '0.0.0.0' )
return config
except Exception as e:
print(str(e))
print(traceback.format_exc())
return default_val
def setConfig(config):
print( json.dumps(config) )
try: # config.json
config_json_path = os.path.join(CONFIG_DIR, 'config.json')
with open(config_json_path, 'w', encoding='utf-8') as f:
json.dump(config, f)
except:
print(traceback.format_exc())
try: # config.bat
config_bat_path = os.path.join(CONFIG_DIR, 'config.bat')
config_bat = []
if 'update_branch' in config:
config_bat.append(f"@set update_branch={config['update_branch']}")
config_bat.append(f"@set SD_UI_BIND_PORT={config['net']['listen_port']}")
bind_ip = '0.0.0.0' if config['net']['listen_to_network'] else '127.0.0.1'
config_bat.append(f"@set SD_UI_BIND_IP={bind_ip}")
config_bat.append(f"@set test_sd2={'Y' if config.get('test_sd2', False) else 'N'}")
if len(config_bat) > 0:
with open(config_bat_path, 'w', encoding='utf-8') as f:
f.write('\r\n'.join(config_bat))
except:
print(traceback.format_exc())
try: # config.sh
config_sh_path = os.path.join(CONFIG_DIR, 'config.sh')
config_sh = ['#!/bin/bash']
if 'update_branch' in config:
config_sh.append(f"export update_branch={config['update_branch']}")
config_sh.append(f"export SD_UI_BIND_PORT={config['net']['listen_port']}")
bind_ip = '0.0.0.0' if config['net']['listen_to_network'] else '127.0.0.1'
config_sh.append(f"export SD_UI_BIND_IP={bind_ip}")
config_sh.append(f"export test_sd2=\"{'Y' if config.get('test_sd2', False) else 'N'}\"")
if len(config_sh) > 1:
with open(config_sh_path, 'w', encoding='utf-8') as f:
f.write('\n'.join(config_sh))
except:
print(traceback.format_exc())
def resolve_model_to_use(model_name:str, model_type:str, model_dir:str, model_extensions:list, default_models=[]):
config = getConfig()
model_dirs = [os.path.join(MODELS_DIR, model_dir), SD_DIR]
if not model_name: # When None try user configured model.
# config = getConfig()
if 'model' in config and model_type in config['model']:
model_name = config['model'][model_type]
if model_name:
is_sd2 = config.get('test_sd2', False)
if model_name.startswith('sd2_') and not is_sd2: # temp hack, until SD2 is unified with 1.4
print('ERROR: Cannot use SD 2.0 models with SD 1.0 code. Using the sd-v1-4 model instead!')
model_name = 'sd-v1-4'
# Check models directory
models_dir_path = os.path.join(MODELS_DIR, model_dir, model_name)
for model_extension in model_extensions:
if os.path.exists(models_dir_path + model_extension):
return models_dir_path
if os.path.exists(model_name + model_extension):
# Direct Path to file
model_name = os.path.abspath(model_name)
return model_name
# Default locations
if model_name in default_models:
default_model_path = os.path.join(SD_DIR, model_name)
for model_extension in model_extensions:
if os.path.exists(default_model_path + model_extension):
return default_model_path
# Can't find requested model, check the default paths.
for default_model in default_models:
for model_dir in model_dirs:
default_model_path = os.path.join(model_dir, default_model)
for model_extension in model_extensions:
if os.path.exists(default_model_path + model_extension):
if model_name is not None:
print(f'Could not find the configured custom model {model_name}{model_extension}. Using the default one: {default_model_path}{model_extension}')
return default_model_path
raise Exception('No valid models found.')
def resolve_ckpt_to_use(model_name:str=None):
return resolve_model_to_use(model_name, model_type='stable-diffusion', model_dir='stable-diffusion', model_extensions=STABLE_DIFFUSION_MODEL_EXTENSIONS, default_models=APP_CONFIG_DEFAULT_MODELS)
def resolve_vae_to_use(model_name:str=None):
try:
return resolve_model_to_use(model_name, model_type='vae', model_dir='vae', model_extensions=VAE_MODEL_EXTENSIONS, default_models=[])
except:
return None
def resolve_hypernetwork_to_use(model_name:str=None):
try:
return resolve_model_to_use(model_name, model_type='hypernetwork', model_dir='hypernetwork', model_extensions=HYPERNETWORK_MODEL_EXTENSIONS, default_models=[])
except:
return None
class SetAppConfigRequest(BaseModel):
update_branch: str = None
render_devices: Union[List[str], List[int], str, int] = None
model_vae: str = None
ui_open_browser_on_start: bool = None
listen_to_network: bool = None
listen_port: int = None
test_sd2: bool = None
@app.post('/app_config')
async def setAppConfig(req : SetAppConfigRequest):
config = getConfig()
if req.update_branch is not None:
config['update_branch'] = req.update_branch
if req.render_devices is not None:
update_render_devices_in_config(config, req.render_devices)
if req.ui_open_browser_on_start is not None:
if 'ui' not in config:
config['ui'] = {}
config['ui']['open_browser_on_start'] = req.ui_open_browser_on_start
if req.listen_to_network is not None:
if 'net' not in config:
config['net'] = {}
config['net']['listen_to_network'] = bool(req.listen_to_network)
if req.listen_port is not None:
if 'net' not in config:
config['net'] = {}
config['net']['listen_port'] = int(req.listen_port)
if req.test_sd2 is not None:
config['test_sd2'] = req.test_sd2
try:
setConfig(config)
if req.render_devices:
update_render_threads()
return JSONResponse({'status': 'OK'}, headers=NOCACHE_HEADERS)
except Exception as e:
print(traceback.format_exc())
raise HTTPException(status_code=500, detail=str(e))
def is_malicious_model(file_path):
try:
scan_result = picklescan.scanner.scan_file_path(file_path)
if scan_result.issues_count > 0 or scan_result.infected_files > 0:
rich.print(":warning: [bold red]Scan %s: %d scanned, %d issue, %d infected.[/bold red]" % (file_path, scan_result.scanned_files, scan_result.issues_count, scan_result.infected_files))
return True
else:
rich.print("Scan %s: [green]%d scanned, %d issue, %d infected.[/green]" % (file_path, scan_result.scanned_files, scan_result.issues_count, scan_result.infected_files))
return False
except Exception as e:
print('error while scanning', file_path, 'error:', e)
return False
known_models = {}
def getModels():
models = {
'active': {
'stable-diffusion': 'sd-v1-4',
'vae': '',
'hypernetwork': '',
},
'options': {
'stable-diffusion': ['sd-v1-4'],
'vae': [],
'hypernetwork': [],
},
}
def listModels(models_dirname, model_type, model_extensions):
models_dir = os.path.join(MODELS_DIR, models_dirname)
if not os.path.exists(models_dir):
os.makedirs(models_dir)
for file in os.listdir(models_dir):
for model_extension in model_extensions:
if not file.endswith(model_extension):
continue
model_path = os.path.join(models_dir, file)
mtime = os.path.getmtime(model_path)
mod_time = known_models[model_path] if model_path in known_models else -1
if mod_time != mtime:
if is_malicious_model(model_path):
models['scan-error'] = file
return
known_models[model_path] = mtime
model_name = file[:-len(model_extension)]
models['options'][model_type].append(model_name)
models['options'][model_type] = [*set(models['options'][model_type])] # remove duplicates
models['options'][model_type].sort()
# custom models
listModels(models_dirname='stable-diffusion', model_type='stable-diffusion', model_extensions=STABLE_DIFFUSION_MODEL_EXTENSIONS)
listModels(models_dirname='vae', model_type='vae', model_extensions=VAE_MODEL_EXTENSIONS)
listModels(models_dirname='hypernetwork', model_type='hypernetwork', model_extensions=HYPERNETWORK_MODEL_EXTENSIONS)
# legacy
custom_weight_path = os.path.join(SD_DIR, 'custom-model.ckpt')
if os.path.exists(custom_weight_path):
models['options']['stable-diffusion'].append('custom-model')
return models
def getUIPlugins():
plugins = []
for plugins_dir, dir_prefix in UI_PLUGINS_SOURCES:
for file in os.listdir(plugins_dir):
if file.endswith('.plugin.js'):
plugins.append(f'/plugins/{dir_prefix}/{file}')
return plugins
def getIPConfig():
try:
ips = socket.gethostbyname_ex(socket.gethostname())
ips[2].append(ips[0])
return ips[2]
except Exception as e:
print(e)
print(traceback.format_exc())
return []
@app.get('/get/{key:path}')
def read_web_data(key:str=None):
if not key: # /get without parameters, stable-diffusion easter egg.
raise HTTPException(status_code=418, detail="StableDiffusion is drawing a teapot!") # HTTP418 I'm a teapot
elif key == 'app_config':
config = getConfig(default_val=None)
if config is None:
config = APP_CONFIG_DEFAULTS
return JSONResponse(config, headers=NOCACHE_HEADERS)
elif key == 'system_info':
config = getConfig()
system_info = {
'devices': task_manager.get_devices(),
'hosts': getIPConfig(),
}
system_info['devices']['config'] = config.get('render_devices', "auto")
return JSONResponse(system_info, headers=NOCACHE_HEADERS)
elif key == 'models':
return JSONResponse(getModels(), headers=NOCACHE_HEADERS)
elif key == 'modifiers': return FileResponse(os.path.join(SD_UI_DIR, 'modifiers.json'), headers=NOCACHE_HEADERS)
elif key == 'output_dir': return JSONResponse({ 'output_dir': outpath }, headers=NOCACHE_HEADERS)
elif key == 'ui_plugins': return JSONResponse(getUIPlugins(), headers=NOCACHE_HEADERS)
else:
raise HTTPException(status_code=404, detail=f'Request for unknown {key}') # HTTP404 Not Found
@app.get('/ping') # Get server and optionally session status.
def ping(session_id:str=None):
if task_manager.is_alive() <= 0: # Check that render threads are alive.
if task_manager.current_state_error: raise HTTPException(status_code=500, detail=str(task_manager.current_state_error))
raise HTTPException(status_code=500, detail='Render thread is dead.')
if task_manager.current_state_error and not isinstance(task_manager.current_state_error, StopAsyncIteration): raise HTTPException(status_code=500, detail=str(task_manager.current_state_error))
# Alive
response = {'status': str(task_manager.current_state)}
if session_id:
session = task_manager.get_cached_session(session_id, update_ttl=True)
response['tasks'] = {id(t): t.status for t in session.tasks}
response['devices'] = task_manager.get_devices()
return JSONResponse(response, headers=NOCACHE_HEADERS)
def save_model_to_config(ckpt_model_name, vae_model_name, hypernetwork_model_name):
config = getConfig()
if 'model' not in config:
config['model'] = {}
config['model']['stable-diffusion'] = ckpt_model_name
config['model']['vae'] = vae_model_name
config['model']['hypernetwork'] = hypernetwork_model_name
if vae_model_name is None or vae_model_name == "":
del config['model']['vae']
if hypernetwork_model_name is None or hypernetwork_model_name == "":
del config['model']['hypernetwork']
setConfig(config)
def update_render_devices_in_config(config, render_devices):
if render_devices not in ('cpu', 'auto') and not render_devices.startswith('cuda:'):
raise HTTPException(status_code=400, detail=f'Invalid render device requested: {render_devices}')
if render_devices.startswith('cuda:'):
render_devices = render_devices.split(',')
config['render_devices'] = render_devices
@app.post('/render')
def render(req : task_manager.ImageRequest):
try:
save_model_to_config(req.use_stable_diffusion_model, req.use_vae_model, req.use_hypernetwork_model)
req.use_stable_diffusion_model = resolve_ckpt_to_use(req.use_stable_diffusion_model)
req.use_vae_model = resolve_vae_to_use(req.use_vae_model)
req.use_hypernetwork_model = resolve_hypernetwork_to_use(req.use_hypernetwork_model)
new_task = task_manager.render(req)
response = {
'status': str(task_manager.current_state),
'queue': len(task_manager.tasks_queue),
'stream': f'/image/stream/{id(new_task)}',
'task': id(new_task)
}
return JSONResponse(response, headers=NOCACHE_HEADERS)
except ChildProcessError as e: # Render thread is dead
raise HTTPException(status_code=500, detail=f'Rendering thread has died.') # HTTP500 Internal Server Error
except ConnectionRefusedError as e: # Unstarted task pending limit reached, deny queueing too many.
raise HTTPException(status_code=503, detail=str(e)) # HTTP503 Service Unavailable
except Exception as e:
print(e)
print(traceback.format_exc())
raise HTTPException(status_code=500, detail=str(e))
@app.get('/image/stream/{task_id:int}')
def stream(task_id:int):
#TODO Move to WebSockets ??
task = task_manager.get_cached_task(task_id, update_ttl=True)
if not task: raise HTTPException(status_code=404, detail=f'Request {task_id} not found.') # HTTP404 NotFound
#if (id(task) != task_id): raise HTTPException(status_code=409, detail=f'Wrong task id received. Expected:{id(task)}, Received:{task_id}') # HTTP409 Conflict
if task.buffer_queue.empty() and not task.lock.locked():
if task.response:
#print(f'Session {session_id} sending cached response')
return JSONResponse(task.response, headers=NOCACHE_HEADERS)
raise HTTPException(status_code=425, detail='Too Early, task not started yet.') # HTTP425 Too Early
#print(f'Session {session_id} opened live render stream {id(task.buffer_queue)}')
return StreamingResponse(task.read_buffer_generator(), media_type='application/json')
@app.get('/image/stop')
def stop(task: int):
if not task:
if task_manager.current_state == task_manager.ServerStates.Online or task_manager.current_state == task_manager.ServerStates.Unavailable:
raise HTTPException(status_code=409, detail='Not currently running any tasks.') # HTTP409 Conflict
task_manager.current_state_error = StopAsyncIteration('')
return {'OK'}
task_id = task
task = task_manager.get_cached_task(task_id, update_ttl=False)
if not task: raise HTTPException(status_code=404, detail=f'Task {task_id} was not found.') # HTTP404 Not Found
if isinstance(task.error, StopAsyncIteration): raise HTTPException(status_code=409, detail=f'Task {task_id} is already stopped.') # HTTP409 Conflict
task.error = StopAsyncIteration(f'Task {task_id} stop requested.')
return {'OK'}
@app.get('/image/tmp/{task_id:int}/{img_id:int}')
def get_image(task_id: int, img_id: int):
task = task_manager.get_cached_task(task_id, update_ttl=True)
if not task: raise HTTPException(status_code=410, detail=f'Task {task_id} could not be found.') # HTTP404 NotFound
if not task.temp_images[img_id]: raise HTTPException(status_code=425, detail='Too Early, task data is not available yet.') # HTTP425 Too Early
try:
img_data = task.temp_images[img_id]
img_data.seek(0)
return StreamingResponse(img_data, media_type='image/jpeg')
except KeyError as e:
raise HTTPException(status_code=500, detail=str(e))
@app.get('/')
def read_root():
return FileResponse(os.path.join(SD_UI_DIR, 'index.html'), headers=NOCACHE_HEADERS)
@app.on_event("shutdown")
def shutdown_event(): # Signal render thread to close on shutdown
task_manager.current_state_error = SystemExit('Application shutting down.')
# don't log certain requests
class LogSuppressFilter(logging.Filter):
def filter(self, record: logging.LogRecord) -> bool:
path = record.getMessage()
for prefix in ACCESS_LOG_SUPPRESS_PATH_PREFIXES:
if path.find(prefix) != -1:
return False
return True
logging.getLogger('uvicorn.access').addFilter(LogSuppressFilter())
# Check models and prepare cache for UI open
getModels()
# Start the task_manager
task_manager.default_model_to_load = resolve_ckpt_to_use()
task_manager.default_vae_to_load = resolve_vae_to_use()
task_manager.default_hypernetwork_to_load = resolve_hypernetwork_to_use()
def update_render_threads():
config = getConfig()
render_devices = config.get('render_devices', 'auto')
active_devices = task_manager.get_devices()['active'].keys()
print('requesting for render_devices', render_devices)
task_manager.update_render_threads(render_devices, active_devices)
update_render_threads()
# start the browser ui
def open_browser():
config = getConfig()
ui = config.get('ui', {})
net = config.get('net', {'listen_port':9000})
port = net.get('listen_port', 9000)
if ui.get('open_browser_on_start', True):
import webbrowser; webbrowser.open(f"http://localhost:{port}")
open_browser()