diff --git a/CHANGES.md b/CHANGES.md index 324bf261..9681dbe8 100644 --- a/CHANGES.md +++ b/CHANGES.md @@ -1,5 +1,22 @@ # What's new? +## v2.5 +### Major Changes +- **Nearly twice as fast** - significantly faster speed of image generation. We're now pretty close to automatic1111's speed. Code contributions are welcome to make our project even faster: https://github.com/easydiffusion/sdkit/#is-it-fast +- **Full support for Stable Diffusion 2.1** - supports loading v1.4 or v2.0 or v2.1 models seamlessly. No need to enable "Test SD2", and no need to add `sd2_` to your SD 2.0 model file names. +- **Memory optimized Stable Diffusion 2.1** - you can now use 768x768 models for SD 2.1, with the same low VRAM optimizations that we've always had for SD 1.4. +- **6 new samplers!** - explore the new samplers, some of which can generate great images in less than 10 inference steps! +- **Model Merging** - You can now merge two models (`.ckpt` or `.safetensors`) and output `.ckpt` or `.safetensors` models, optionally in `fp16` precision. Details: https://github.com/cmdr2/stable-diffusion-ui/wiki/Model-Merging +- **Fast loading/unloading of VAEs** - No longer needs to reload the entire Stable Diffusion model, each time you change the VAE +- **Database of known models** - automatically picks the right configuration for known models. E.g. we automatically detect and apply "v" parameterization (required for some SD 2.0 models), and "fp32" attention precision (required for some SD 2.1 models). +- **Color correction for img2img** - an option to preserve the color profile (histogram) of the initial image. This is especially useful if you're getting red-tinted images after inpainting/masking. +- **Three GPU Memory Usage Settings** - `High` (fastest, maximum VRAM usage), `Balanced` (default - almost as fast, significantly lower VRAM usage), `Low` (slowest, very low VRAM usage). The `Low` setting is applied automatically for GPUs with less than 4 GB of VRAM. 
+- **Save metadata as JSON** - You can now save the metadata files as either text or json files (choose in the Settings tab). +- **Major rewrite of the code** - Most of the codebase has been reorganized and rewritten, to make it more manageable and easier for new developers to contribute features. We've separated our core engine into a new project called `sdkit`, which allows anyone to easily integrate Stable Diffusion (and related modules like GFPGAN etc) into their programming projects (via a simple `pip install sdkit`): https://github.com/easydiffusion/sdkit/ +- **Name change** - Last, and probably the least, the UI is now called "Easy Diffusion". It indicates the focus of this project - an easy way for people to play with Stable Diffusion. + +Our focus continues to remain on an easy installation experience and an easy user-interface, while still remaining pretty powerful in terms of features and speed. + ## v2.4 ### Major Changes - **Allow reordering the task queue** (by dragging and dropping tasks). Thanks @madrang diff --git a/scripts/Developer Console.cmd b/scripts/Developer Console.cmd index 750e4311..921a9dca 100644 --- a/scripts/Developer Console.cmd +++ b/scripts/Developer Console.cmd @@ -23,23 +23,20 @@ call conda --version echo. -@rem activate the environment -call conda activate .\stable-diffusion\env +@rem activate the legacy environment (if present) and set PYTHONPATH +if exist "installer_files\env" ( + set PYTHONPATH=%cd%\installer_files\env\lib\site-packages +) +if exist "stable-diffusion\env" ( + call conda activate .\stable-diffusion\env + set PYTHONPATH=%cd%\stable-diffusion\env\lib\site-packages +) call where python call python --version -@rem set the PYTHONPATH -cd stable-diffusion -set SD_DIR=%cd% - -cd env\lib\site-packages -set PYTHONPATH=%SD_DIR%;%cd% -cd ..\..\.. echo PYTHONPATH=%PYTHONPATH% -cd .. - @rem done echo. 
diff --git a/scripts/bootstrap.bat b/scripts/bootstrap.bat index cb0909d0..c58ddcce 100644 --- a/scripts/bootstrap.bat +++ b/scripts/bootstrap.bat @@ -24,7 +24,7 @@ if exist "%INSTALL_ENV_DIR%" set PATH=%INSTALL_ENV_DIR%;%INSTALL_ENV_DIR%\Librar set PACKAGES_TO_INSTALL= if not exist "%LEGACY_INSTALL_ENV_DIR%\etc\profile.d\conda.sh" ( - if not exist "%INSTALL_ENV_DIR%\etc\profile.d\conda.sh" set PACKAGES_TO_INSTALL=%PACKAGES_TO_INSTALL% conda + if not exist "%INSTALL_ENV_DIR%\etc\profile.d\conda.sh" set PACKAGES_TO_INSTALL=%PACKAGES_TO_INSTALL% conda python=3.8.5 ) call git --version >.tmp1 2>.tmp2 diff --git a/scripts/bootstrap.sh b/scripts/bootstrap.sh index 6a929133..6a8df9b3 100755 --- a/scripts/bootstrap.sh +++ b/scripts/bootstrap.sh @@ -39,7 +39,7 @@ if [ -e "$INSTALL_ENV_DIR" ]; then export PATH="$INSTALL_ENV_DIR/bin:$PATH"; fi PACKAGES_TO_INSTALL="" -if [ ! -e "$LEGACY_INSTALL_ENV_DIR/etc/profile.d/conda.sh" ] && [ ! -e "$INSTALL_ENV_DIR/etc/profile.d/conda.sh" ]; then PACKAGES_TO_INSTALL="$PACKAGES_TO_INSTALL conda"; fi +if [ ! -e "$LEGACY_INSTALL_ENV_DIR/etc/profile.d/conda.sh" ] && [ ! -e "$INSTALL_ENV_DIR/etc/profile.d/conda.sh" ]; then PACKAGES_TO_INSTALL="$PACKAGES_TO_INSTALL conda python=3.8.5"; fi if ! 
hash "git" &>/dev/null; then PACKAGES_TO_INSTALL="$PACKAGES_TO_INSTALL git"; fi if "$MAMBA_ROOT_PREFIX/micromamba" --version &>/dev/null; then umamba_exists="T"; fi diff --git a/scripts/check_modules.py b/scripts/check_modules.py new file mode 100644 index 00000000..416ad851 --- /dev/null +++ b/scripts/check_modules.py @@ -0,0 +1,13 @@ +''' +This script checks if the given modules exist +''' + +import sys +import pkgutil + +modules = sys.argv[1:] +missing_modules = [] +for m in modules: + if pkgutil.find_loader(m) is None: + print('module', m, 'not found') + exit(1) diff --git a/scripts/developer_console.sh b/scripts/developer_console.sh index 49e71b34..73972568 100755 --- a/scripts/developer_console.sh +++ b/scripts/developer_console.sh @@ -26,21 +26,23 @@ if [ "$0" == "bash" ]; then echo "" - # activate the environment - CONDA_BASEPATH=$(conda info --base) - source "$CONDA_BASEPATH/etc/profile.d/conda.sh" # otherwise conda complains about 'shell not initialized' (needed when running in a script) + # activate the legacy environment (if present) and set PYTHONPATH + if [ -e "installer_files/env" ]; then + export PYTHONPATH="$(pwd)/installer_files/env/lib/python3.8/site-packages" + fi + if [ -e "stable-diffusion/env" ]; then + CONDA_BASEPATH=$(conda info --base) + source "$CONDA_BASEPATH/etc/profile.d/conda.sh" # otherwise conda complains about 'shell not initialized' (needed when running in a script) - conda activate ./stable-diffusion/env + conda activate ./stable-diffusion/env + + export PYTHONPATH="$(pwd)/stable-diffusion/env/lib/python3.8/site-packages" + fi which python python --version - # set the PYTHONPATH - cd stable-diffusion - SD_PATH=`pwd` - export PYTHONPATH="$SD_PATH:$SD_PATH/env/lib/python3.8/site-packages" echo "PYTHONPATH=$PYTHONPATH" - cd .. 
# done diff --git a/scripts/on_env_start.bat b/scripts/on_env_start.bat index b18b4f4e..13d54161 100644 --- a/scripts/on_env_start.bat +++ b/scripts/on_env_start.bat @@ -53,6 +53,7 @@ if "%update_branch%"=="" ( @xcopy sd-ui-files\ui ui /s /i /Y /q @copy sd-ui-files\scripts\on_sd_start.bat scripts\ /Y @copy sd-ui-files\scripts\bootstrap.bat scripts\ /Y +@copy sd-ui-files\scripts\check_modules.py scripts\ /Y @copy "sd-ui-files\scripts\Start Stable Diffusion UI.cmd" . /Y @copy "sd-ui-files\scripts\Developer Console.cmd" . /Y diff --git a/scripts/on_env_start.sh b/scripts/on_env_start.sh index 185e4fd4..cac307c2 100755 --- a/scripts/on_env_start.sh +++ b/scripts/on_env_start.sh @@ -37,6 +37,7 @@ rm -rf ui cp -Rf sd-ui-files/ui . cp sd-ui-files/scripts/on_sd_start.sh scripts/ cp sd-ui-files/scripts/bootstrap.sh scripts/ +cp sd-ui-files/scripts/check_modules.py scripts/ cp sd-ui-files/scripts/start.sh . cp sd-ui-files/scripts/developer_console.sh . diff --git a/scripts/on_sd_start.bat b/scripts/on_sd_start.bat index 22e5d714..60118056 100644 --- a/scripts/on_sd_start.bat +++ b/scripts/on_sd_start.bat @@ -5,11 +5,20 @@ @copy sd-ui-files\scripts\on_env_start.bat scripts\ /Y @copy sd-ui-files\scripts\bootstrap.bat scripts\ /Y +@copy sd-ui-files\scripts\check_modules.py scripts\ /Y if exist "%cd%\profile" ( set USERPROFILE=%cd%\profile ) +@rem set the correct installer path (current vs legacy) +if exist "%cd%\installer_files\env" ( + set INSTALL_ENV_DIR=%cd%\installer_files\env +) +if exist "%cd%\stable-diffusion\env" ( + set INSTALL_ENV_DIR=%cd%\stable-diffusion\env +) + @mkdir tmp @set TMP=%cd%\tmp @set TEMP=%cd%\tmp @@ -27,150 +36,92 @@ if exist "Open Developer Console.cmd" del "Open Developer Console.cmd" @call python -c "import os; import shutil; frm = 'sd-ui-files\\ui\\hotfix\\9c24e6cd9f499d02c4f21a033736dabd365962dc80fe3aeb57a8f85ea45a20a3.26fead7ea4f0f843f6eb4055dfd25693f1a71f3c6871b184042d4b126244e142'; dst = os.path.join(os.path.expanduser('~'), '.cache', 
'huggingface', 'transformers', '9c24e6cd9f499d02c4f21a033736dabd365962dc80fe3aeb57a8f85ea45a20a3.26fead7ea4f0f843f6eb4055dfd25693f1a71f3c6871b184042d4b126244e142'); shutil.copyfile(frm, dst) if os.path.exists(dst) else print(''); print('Hotfixed broken JSON file from OpenAI');" -if NOT DEFINED test_sd2 set test_sd2=N +@rem create the stable-diffusion folder, to work with legacy installations +if not exist "stable-diffusion" mkdir stable-diffusion +cd stable-diffusion -@>nul findstr /m "sd_git_cloned" scripts\install_status.txt -@if "%ERRORLEVEL%" EQU "0" ( - @echo "Stable Diffusion's git repository was already installed. Updating.." - - @cd stable-diffusion - - @call git remote set-url origin https://github.com/easydiffusion/diffusion-kit.git - - @call git reset --hard - @call git pull - - if "%test_sd2%" == "N" ( - @call git -c advice.detachedHead=false checkout 7f32368ed1030a6e710537047bacd908adea183a - ) - if "%test_sd2%" == "Y" ( - @call git -c advice.detachedHead=false checkout 733a1f6f9cae9b9a9b83294bf3281b123378cb1f - ) - - @cd .. -) else ( - @echo. & echo "Downloading Stable Diffusion.." & echo. - - @call git clone https://github.com/easydiffusion/diffusion-kit.git stable-diffusion && ( - @echo sd_git_cloned >> scripts\install_status.txt - ) || ( - @echo "Error downloading Stable Diffusion. Sorry about that, please try to:" & echo " 1. Run this installer again." & echo " 2. If that doesn't fix it, please try the common troubleshooting steps at https://github.com/cmdr2/stable-diffusion-ui/wiki/Troubleshooting" & echo " 3. If those steps don't help, please copy *all* the error messages in this window, and ask the community at https://discord.com/invite/u9yhsFmEkB" & echo " 4. If that doesn't solve the problem, please file an issue at https://github.com/cmdr2/stable-diffusion-ui/issues" & echo "Thanks!" - pause - @exit /b - ) - - @cd stable-diffusion - @call git -c advice.detachedHead=false checkout 7f32368ed1030a6e710537047bacd908adea183a - - @cd .. 
+@rem activate the old stable-diffusion env, if it exists +if exist "env" ( + call conda activate .\env ) -@cd stable-diffusion +@rem disable the legacy src and ldm folder (otherwise this prevents installing gfpgan and realesrgan) +if exist src rename src src-old +if exist ldm rename ldm ldm-old -@>nul findstr /m "conda_sd_env_created" ..\scripts\install_status.txt -@if "%ERRORLEVEL%" EQU "0" ( - @echo "Packages necessary for Stable Diffusion were already installed" - - @call conda activate .\env +@rem install torch and torchvision +call python ..\scripts\check_modules.py torch torchvision +if "%ERRORLEVEL%" EQU "0" ( + echo "torch and torchvision have already been installed." ) else ( - @echo. & echo "Downloading packages necessary for Stable Diffusion.." & echo. & echo "***** This will take some time (depending on the speed of the Internet connection) and may appear to be stuck, but please be patient ***** .." & echo. + echo "Installing torch and torchvision.." - @rmdir /s /q .\env + @REM prevent from using packages from the user's home directory, to avoid conflicts + set PYTHONNOUSERSITE=1 + set PYTHONPATH=%INSTALL_ENV_DIR%\lib\site-packages - @REM prevent conda from using packages from the user's home directory, to avoid conflicts - @set PYTHONNOUSERSITE=1 - - set USERPROFILE=%cd%\profile - - set PYTHONPATH=%cd%;%cd%\env\lib\site-packages - - @call conda env create --prefix env -f environment.yaml || ( - @echo. & echo "Error installing the packages necessary for Stable Diffusion. Sorry about that, please try to:" & echo " 1. Run this installer again." & echo " 2. If that doesn't fix it, please try the common troubleshooting steps at https://github.com/cmdr2/stable-diffusion-ui/wiki/Troubleshooting" & echo " 3. If those steps don't help, please copy *all* the error messages in this window, and ask the community at https://discord.com/invite/u9yhsFmEkB" & echo " 4. 
If that doesn't solve the problem, please file an issue at https://github.com/cmdr2/stable-diffusion-ui/issues" & echo "Thanks!" & echo. + call pip install --upgrade torch torchvision --extra-index-url https://download.pytorch.org/whl/cu116 || ( + echo "Error installing torch. Sorry about that, please try to:" & echo " 1. Run this installer again." & echo " 2. If that doesn't fix it, please try the common troubleshooting steps at https://github.com/cmdr2/stable-diffusion-ui/wiki/Troubleshooting" & echo " 3. If those steps don't help, please copy *all* the error messages in this window, and ask the community at https://discord.com/invite/u9yhsFmEkB" & echo " 4. If that doesn't solve the problem, please file an issue at https://github.com/cmdr2/stable-diffusion-ui/issues" & echo "Thanks!" pause exit /b ) +) - @call conda activate .\env +@rem install/upgrade sdkit +call python ..\scripts\check_modules.py sdkit sdkit.models ldm transformers numpy antlr4 gfpgan realesrgan +if "%ERRORLEVEL%" EQU "0" ( + echo "sdkit is already installed." - for /f "tokens=*" %%a in ('python -c "import torch; import ldm; import transformers; import numpy; import antlr4; print(42)"') do if "%%a" NEQ "42" ( - @echo. & echo "Dependency test failed! Error installing the packages necessary for Stable Diffusion. Sorry about that, please try to:" & echo " 1. Run this installer again." & echo " 2. If that doesn't fix it, please try the common troubleshooting steps at https://github.com/cmdr2/stable-diffusion-ui/wiki/Troubleshooting" & echo " 3. If those steps don't help, please copy *all* the error messages in this window, and ask the community at https://discord.com/invite/u9yhsFmEkB" & echo " 4. If that doesn't solve the problem, please file an issue at https://github.com/cmdr2/stable-diffusion-ui/issues" & echo "Thanks!" & echo. 
+ @REM prevent from using packages from the user's home directory, to avoid conflicts + set PYTHONNOUSERSITE=1 + set PYTHONPATH=%INSTALL_ENV_DIR%\lib\site-packages + + call >nul pip install --upgrade sdkit || ( + echo "Error updating sdkit" + ) +) else ( + echo "Installing sdkit: https://pypi.org/project/sdkit/" + + @REM prevent from using packages from the user's home directory, to avoid conflicts + set PYTHONNOUSERSITE=1 + set PYTHONPATH=%INSTALL_ENV_DIR%\lib\site-packages + + call pip install sdkit || ( + echo "Error installing sdkit. Sorry about that, please try to:" & echo " 1. Run this installer again." & echo " 2. If that doesn't fix it, please try the common troubleshooting steps at https://github.com/cmdr2/stable-diffusion-ui/wiki/Troubleshooting" & echo " 3. If those steps don't help, please copy *all* the error messages in this window, and ask the community at https://discord.com/invite/u9yhsFmEkB" & echo " 4. If that doesn't solve the problem, please file an issue at https://github.com/cmdr2/stable-diffusion-ui/issues" & echo "Thanks!" pause exit /b ) - - @echo conda_sd_env_created >> ..\scripts\install_status.txt ) -@rem allow rolling back the sdkit-based changes -if exist "src-old" ( - if not exist "src" ( - rename "src-old" "src" +@rem install rich +call python ..\scripts\check_modules.py rich +if "%ERRORLEVEL%" EQU "0" ( + echo "rich has already been installed." +) else ( + echo "Installing rich.." - if exist "ldm-old" ( - rd /s /q "ldm-old" - ) + set PYTHONNOUSERSITE=1 + set PYTHONPATH=%INSTALL_ENV_DIR%\lib\site-packages - call pip uninstall -y sdkit stable-diffusion-sdkit + call pip install rich || ( + echo "Error installing rich. Sorry about that, please try to:" & echo " 1. Run this installer again." & echo " 2. If that doesn't fix it, please try the common troubleshooting steps at https://github.com/cmdr2/stable-diffusion-ui/wiki/Troubleshooting" & echo " 3. 
If those steps don't help, please copy *all* the error messages in this window, and ask the community at https://discord.com/invite/u9yhsFmEkB" & echo " 4. If that doesn't solve the problem, please file an issue at https://github.com/cmdr2/stable-diffusion-ui/issues" & echo "Thanks!" + pause + exit /b ) ) set PATH=C:\Windows\System32;%PATH% -@>nul findstr /m "conda_sd_gfpgan_deps_installed" ..\scripts\install_status.txt -@if "%ERRORLEVEL%" EQU "0" ( - @echo "Packages necessary for GFPGAN (Face Correction) were already installed" -) else ( - @echo. & echo "Downloading packages necessary for GFPGAN (Face Correction).." & echo. - - @set PYTHONNOUSERSITE=1 - - set USERPROFILE=%cd%\profile - - set PYTHONPATH=%cd%;%cd%\env\lib\site-packages - - for /f "tokens=*" %%a in ('python -c "from gfpgan import GFPGANer; print(42)"') do if "%%a" NEQ "42" ( - @echo. & echo "Dependency test failed! Error installing the packages necessary for GFPGAN (Face Correction). Sorry about that, please try to:" & echo " 1. Run this installer again." & echo " 2. If that doesn't fix it, please try the common troubleshooting steps at https://github.com/cmdr2/stable-diffusion-ui/wiki/Troubleshooting" & echo " 3. If those steps don't help, please copy *all* the error messages in this window, and ask the community at https://discord.com/invite/u9yhsFmEkB" & echo " 4. If that doesn't solve the problem, please file an issue at https://github.com/cmdr2/stable-diffusion-ui/issues" & echo "Thanks!" & echo. - pause - exit /b - ) - - @echo conda_sd_gfpgan_deps_installed >> ..\scripts\install_status.txt -) - -@>nul findstr /m "conda_sd_esrgan_deps_installed" ..\scripts\install_status.txt -@if "%ERRORLEVEL%" EQU "0" ( - @echo "Packages necessary for ESRGAN (Resolution Upscaling) were already installed" -) else ( - @echo. & echo "Downloading packages necessary for ESRGAN (Resolution Upscaling).." & echo. 
- - @set PYTHONNOUSERSITE=1 - - set USERPROFILE=%cd%\profile - - set PYTHONPATH=%cd%;%cd%\env\lib\site-packages - - for /f "tokens=*" %%a in ('python -c "from basicsr.archs.rrdbnet_arch import RRDBNet; from realesrgan import RealESRGANer; print(42)"') do if "%%a" NEQ "42" ( - @echo. & echo "Dependency test failed! Error installing the packages necessary for ESRGAN (Resolution Upscaling). Sorry about that, please try to:" & echo " 1. Run this installer again." & echo " 2. If that doesn't fix it, please try the common troubleshooting steps at https://github.com/cmdr2/stable-diffusion-ui/wiki/Troubleshooting" & echo " 3. If those steps don't help, please copy *all* the error messages in this window, and ask the community at https://discord.com/invite/u9yhsFmEkB" & echo " 4. If that doesn't solve the problem, please file an issue at https://github.com/cmdr2/stable-diffusion-ui/issues" & echo "Thanks!" & echo. - pause - exit /b - ) - - @echo conda_sd_esrgan_deps_installed >> ..\scripts\install_status.txt -) - -@>nul findstr /m "conda_sd_ui_deps_installed" ..\scripts\install_status.txt +call python ..\scripts\check_modules.py uvicorn fastapi @if "%ERRORLEVEL%" EQU "0" ( echo "Packages necessary for Stable Diffusion UI were already installed" ) else ( @echo. & echo "Downloading packages necessary for Stable Diffusion UI.." & echo. - @set PYTHONNOUSERSITE=1 + set PYTHONNOUSERSITE=1 + set PYTHONPATH=%INSTALL_ENV_DIR%\lib\site-packages - set USERPROFILE=%cd%\profile - - set PYTHONPATH=%cd%;%cd%\env\lib\site-packages - - @call conda install -c conda-forge -y --prefix env uvicorn fastapi || ( + @call conda install -c conda-forge -y uvicorn fastapi || ( echo "Error installing the packages necessary for Stable Diffusion UI. Sorry about that, please try to:" & echo " 1. Run this installer again." & echo " 2. If that doesn't fix it, please try the common troubleshooting steps at https://github.com/cmdr2/stable-diffusion-ui/wiki/Troubleshooting" & echo " 3. 
If those steps don't help, please copy *all* the error messages in this window, and ask the community at https://discord.com/invite/u9yhsFmEkB" & echo " 4. If that doesn't solve the problem, please file an issue at https://github.com/cmdr2/stable-diffusion-ui/issues" & echo "Thanks!" pause exit /b @@ -185,26 +136,6 @@ call WHERE uvicorn > .tmp exit /b ) -@>nul 2>nul call python -m picklescan --help -@if "%ERRORLEVEL%" NEQ "0" ( - @echo. & echo Picklescan not found. Installing - @call pip install picklescan || ( - echo "Error installing the picklescan package necessary for Stable Diffusion UI. Sorry about that, please try to:" & echo " 1. Run this installer again." & echo " 2. If that doesn't fix it, please try the common troubleshooting steps at https://github.com/cmdr2/stable-diffusion-ui/wiki/Troubleshooting" & echo " 3. If those steps don't help, please copy *all* the error messages in this window, and ask the community at https://discord.com/invite/u9yhsFmEkB" & echo " 4. If that doesn't solve the problem, please file an issue at https://github.com/cmdr2/stable-diffusion-ui/issues" & echo "Thanks!" - pause - exit /b - ) -) - -@>nul 2>nul call python -c "import safetensors" -@if "%ERRORLEVEL%" NEQ "0" ( - @echo. & echo SafeTensors not found. Installing - @call pip install safetensors || ( - echo "Error installing the safetensors package necessary for Stable Diffusion UI. Sorry about that, please try to:" & echo " 1. Run this installer again." & echo " 2. If that doesn't fix it, please try the common troubleshooting steps at https://github.com/cmdr2/stable-diffusion-ui/wiki/Troubleshooting" & echo " 3. If those steps don't help, please copy *all* the error messages in this window, and ask the community at https://discord.com/invite/u9yhsFmEkB" & echo " 4. If that doesn't solve the problem, please file an issue at https://github.com/cmdr2/stable-diffusion-ui/issues" & echo "Thanks!" 
- pause - exit /b - ) -) - @>nul findstr /m "conda_sd_ui_deps_installed" ..\scripts\install_status.txt @if "%ERRORLEVEL%" NEQ "0" ( @echo conda_sd_ui_deps_installed >> ..\scripts\install_status.txt @@ -212,12 +143,7 @@ call WHERE uvicorn > .tmp -if not exist "..\models\stable-diffusion" mkdir "..\models\stable-diffusion" if not exist "..\models\vae" mkdir "..\models\vae" -if not exist "..\models\hypernetwork" mkdir "..\models\hypernetwork" -echo. > "..\models\stable-diffusion\Put your custom ckpt files here.txt" -echo. > "..\models\vae\Put your VAE files here.txt" -echo. > "..\models\hypernetwork\Put your hypernetwork files here.txt" @if exist "sd-v1-4.ckpt" ( for %%I in ("sd-v1-4.ckpt") do if "%%~zI" EQU "4265380512" ( @@ -375,10 +301,6 @@ echo. > "..\models\hypernetwork\Put your hypernetwork files here.txt" ) ) -if "%test_sd2%" == "Y" ( - @call pip install open_clip_torch==2.0.2 -) - @>nul findstr /m "sd_install_complete" ..\scripts\install_status.txt @if "%ERRORLEVEL%" NEQ "0" ( @echo sd_weights_downloaded >> ..\scripts\install_status.txt @@ -389,10 +311,8 @@ if "%test_sd2%" == "Y" ( @set SD_DIR=%cd% -@cd env\lib\site-packages -@set PYTHONPATH=%SD_DIR%;%cd% -@cd ..\..\.. -@echo PYTHONPATH=%PYTHONPATH% +set PYTHONPATH=%INSTALL_ENV_DIR%\lib\site-packages +echo PYTHONPATH=%PYTHONPATH% call where python call python --version @@ -401,17 +321,9 @@ call python --version @set SD_UI_PATH=%cd%\ui @cd stable-diffusion -@rem -@rem Rewrite easy-install.pth. 
This fixes the installation if the user has relocated the SDUI installation -@rem ->env\Lib\site-packages\easy-install.pth echo %cd%\src\taming-transformers ->>env\Lib\site-packages\easy-install.pth echo %cd%\src\clip ->>env\Lib\site-packages\easy-install.pth echo %cd%\src\gfpgan ->>env\Lib\site-packages\easy-install.pth echo %cd%\src\realesrgan - @if NOT DEFINED SD_UI_BIND_PORT set SD_UI_BIND_PORT=9000 @if NOT DEFINED SD_UI_BIND_IP set SD_UI_BIND_IP=0.0.0.0 -@uvicorn server:app --app-dir "%SD_UI_PATH%" --port %SD_UI_BIND_PORT% --host %SD_UI_BIND_IP% +@uvicorn main:server_api --app-dir "%SD_UI_PATH%" --port %SD_UI_BIND_PORT% --host %SD_UI_BIND_IP% --log-level error @pause diff --git a/scripts/on_sd_start.sh b/scripts/on_sd_start.sh index 353a315e..ce898be3 100755 --- a/scripts/on_sd_start.sh +++ b/scripts/on_sd_start.sh @@ -4,6 +4,7 @@ source ./scripts/functions.sh cp sd-ui-files/scripts/on_env_start.sh scripts/ cp sd-ui-files/scripts/bootstrap.sh scripts/ +cp sd-ui-files/scripts/check_modules.py scripts/ # activate the installer env CONDA_BASEPATH=$(conda info --base) @@ -21,125 +22,89 @@ python -c "import os; import shutil; frm = 'sd-ui-files/ui/hotfix/9c24e6cd9f499d # Caution, this file will make your eyes and brain bleed. It's such an unholy mess. # Note to self: Please rewrite this in Python. For the sake of your own sanity. -if [ "$test_sd2" == "" ]; then - export test_sd2="N" -fi - -if [ -e "scripts/install_status.txt" ] && [ `grep -c sd_git_cloned scripts/install_status.txt` -gt "0" ]; then - echo "Stable Diffusion's git repository was already installed. Updating.." - - cd stable-diffusion - - git remote set-url origin https://github.com/easydiffusion/diffusion-kit.git - - git reset --hard - git pull - - if [ "$test_sd2" == "N" ]; then - git -c advice.detachedHead=false checkout 7f32368ed1030a6e710537047bacd908adea183a - elif [ "$test_sd2" == "Y" ]; then - git -c advice.detachedHead=false checkout 733a1f6f9cae9b9a9b83294bf3281b123378cb1f - fi - - cd .. 
-else - printf "\n\nDownloading Stable Diffusion..\n\n" - - if git clone https://github.com/easydiffusion/diffusion-kit.git stable-diffusion ; then - echo sd_git_cloned >> scripts/install_status.txt - else - fail "git clone of basujindal/stable-diffusion.git failed" - fi - - cd stable-diffusion - git -c advice.detachedHead=false checkout 7f32368ed1030a6e710537047bacd908adea183a - - cd .. +# set the correct installer path (current vs legacy) +if [ -e "installer_files/env" ]; then + export INSTALL_ENV_DIR="$(pwd)/installer_files/env" +fi +if [ -e "stable-diffusion/env" ]; then + export INSTALL_ENV_DIR="$(pwd)/stable-diffusion/env" fi +# create the stable-diffusion folder, to work with legacy installations +if [ ! -e "stable-diffusion" ]; then mkdir stable-diffusion; fi cd stable-diffusion -if [ `grep -c conda_sd_env_created ../scripts/install_status.txt` -gt "0" ]; then - echo "Packages necessary for Stable Diffusion were already installed" - +# activate the old stable-diffusion env, if it exists +if [ -e "env" ]; then conda activate ./env || fail "conda activate failed" +fi + +# disable the legacy src and ldm folder (otherwise this prevents installing gfpgan and realesrgan) +if [ -e "src" ]; then mv src src-old; fi +if [ -e "ldm" ]; then mv ldm ldm-old; fi + +# install torch and torchvision +if python ../scripts/check_modules.py torch torchvision; then + echo "torch and torchvision have already been installed." else - printf "\n\nDownloading packages necessary for Stable Diffusion..\n" - printf "\n\n***** This will take some time (depending on the speed of the Internet connection) and may appear to be stuck, but please be patient ***** ..\n\n" + echo "Installing torch and torchvision.." 
- # prevent conda from using packages from the user's home directory, to avoid conflicts export PYTHONNOUSERSITE=1 - export PYTHONPATH="$(pwd):$(pwd)/env/lib/site-packages" + export PYTHONPATH="$INSTALL_ENV_DIR/lib/python3.8/site-packages" - if conda env create --prefix env --force -f environment.yaml ; then - echo "Installed. Testing.." + if pip install --upgrade torch torchvision --extra-index-url https://download.pytorch.org/whl/cu116 ; then + echo "Installed." else - fail "'conda env create' failed" + fail "torch install failed" fi - - conda activate ./env || fail "conda activate failed" - - out_test=`python -c "import torch; import ldm; import transformers; import numpy; import antlr4; print(42)"` - if [ "$out_test" != "42" ]; then - fail "Dependency test failed" - fi - - echo conda_sd_env_created >> ../scripts/install_status.txt fi -# allow rolling back the sdkit-based changes -if [ -e "src-old" ] && [ ! -e "src" ]; then - mv src-old src - - if [ -e "ldm-old" ]; then rm -r ldm-old; fi - - pip uninstall -y sdkit stable-diffusion-sdkit -fi - -if [ `grep -c conda_sd_gfpgan_deps_installed ../scripts/install_status.txt` -gt "0" ]; then - echo "Packages necessary for GFPGAN (Face Correction) were already installed" -else - printf "\n\nDownloading packages necessary for GFPGAN (Face Correction)..\n" +# install/upgrade sdkit +if python ../scripts/check_modules.py sdkit sdkit.models ldm transformers numpy antlr4 gfpgan realesrgan ; then + echo "sdkit is already installed." export PYTHONNOUSERSITE=1 - export PYTHONPATH="$(pwd):$(pwd)/env/lib/site-packages" + export PYTHONPATH="$INSTALL_ENV_DIR/lib/python3.8/site-packages" - out_test=`python -c "from gfpgan import GFPGANer; print(42)"` - if [ "$out_test" != "42" ]; then - echo "EE The dependency check has failed. This usually means that some system libraries are missing." 
- echo "EE On Debian/Ubuntu systems, this are often these packages: libsm6 libxext6 libxrender-dev" - echo "EE Other Linux distributions might have different package names for these libraries." - fail "GFPGAN dependency test failed" - fi - - echo conda_sd_gfpgan_deps_installed >> ../scripts/install_status.txt -fi - -if [ `grep -c conda_sd_esrgan_deps_installed ../scripts/install_status.txt` -gt "0" ]; then - echo "Packages necessary for ESRGAN (Resolution Upscaling) were already installed" + pip install --upgrade sdkit > /dev/null else - printf "\n\nDownloading packages necessary for ESRGAN (Resolution Upscaling)..\n" + echo "Installing sdkit: https://pypi.org/project/sdkit/" export PYTHONNOUSERSITE=1 - export PYTHONPATH="$(pwd):$(pwd)/env/lib/site-packages" + export PYTHONPATH="$INSTALL_ENV_DIR/lib/python3.8/site-packages" - out_test=`python -c "from basicsr.archs.rrdbnet_arch import RRDBNet; from realesrgan import RealESRGANer; print(42)"` - if [ "$out_test" != "42" ]; then - fail "ESRGAN dependency test failed" + if pip install sdkit ; then + echo "Installed." + else + fail "sdkit install failed" fi - - echo conda_sd_esrgan_deps_installed >> ../scripts/install_status.txt fi -if [ `grep -c conda_sd_ui_deps_installed ../scripts/install_status.txt` -gt "0" ]; then +# install rich +if python ../scripts/check_modules.py rich; then + echo "rich has already been installed." +else + echo "Installing rich.." + + export PYTHONNOUSERSITE=1 + export PYTHONPATH="$INSTALL_ENV_DIR/lib/python3.8/site-packages" + + if pip install rich ; then + echo "Installed." 
+ else + fail "Install failed for rich" + fi +fi + +if python ../scripts/check_modules.py uvicorn fastapi ; then echo "Packages necessary for Stable Diffusion UI were already installed" else printf "\n\nDownloading packages necessary for Stable Diffusion UI..\n\n" export PYTHONNOUSERSITE=1 - export PYTHONPATH="$(pwd):$(pwd)/env/lib/site-packages" + export PYTHONPATH="$INSTALL_ENV_DIR/lib/python3.8/site-packages" - if conda install -c conda-forge --prefix ./env -y uvicorn fastapi ; then + if conda install -c conda-forge -y uvicorn fastapi ; then echo "Installed. Testing.." else fail "'conda install uvicorn' failed" @@ -148,32 +113,9 @@ else if ! command -v uvicorn &> /dev/null; then fail "UI packages not found!" fi - - echo conda_sd_ui_deps_installed >> ../scripts/install_status.txt fi -if python -m picklescan --help >/dev/null 2>&1; then - echo "Picklescan is already installed." -else - echo "Picklescan not found, installing." - pip install picklescan || fail "Picklescan installation failed." -fi - -if python -c "import safetensors" --help >/dev/null 2>&1; then - echo "SafeTensors is already installed." -else - echo "SafeTensors not found, installing." - pip install safetensors || fail "SafeTensors installation failed." -fi - - - -mkdir -p "../models/stable-diffusion" mkdir -p "../models/vae" -mkdir -p "../models/hypernetwork" -echo "" > "../models/stable-diffusion/Put your custom ckpt files here.txt" -echo "" > "../models/vae/Put your VAE files here.txt" -echo "" > "../models/hypernetwork/Put your hypernetwork files here.txt" if [ -f "sd-v1-4.ckpt" ]; then model_size=`find "sd-v1-4.ckpt" -printf "%s"` @@ -314,10 +256,6 @@ if [ ! 
-f "../models/vae/vae-ft-mse-840000-ema-pruned.ckpt" ]; then fi fi -if [ "$test_sd2" == "Y" ]; then - pip install open_clip_torch==2.0.2 -fi - if [ `grep -c sd_install_complete ../scripts/install_status.txt` -gt "0" ]; then echo sd_weights_downloaded >> ../scripts/install_status.txt echo sd_install_complete >> ../scripts/install_status.txt @@ -326,7 +264,8 @@ fi printf "\n\nStable Diffusion is ready!\n\n" SD_PATH=`pwd` -export PYTHONPATH="$SD_PATH:$SD_PATH/env/lib/python3.8/site-packages" + +export PYTHONPATH="$INSTALL_ENV_DIR/lib/python3.8/site-packages" echo "PYTHONPATH=$PYTHONPATH" which python @@ -336,6 +275,6 @@ cd .. export SD_UI_PATH=`pwd`/ui cd stable-diffusion -uvicorn server:app --app-dir "$SD_UI_PATH" --port ${SD_UI_BIND_PORT:-9000} --host ${SD_UI_BIND_IP:-0.0.0.0} +uvicorn main:server_api --app-dir "$SD_UI_PATH" --port ${SD_UI_BIND_PORT:-9000} --host ${SD_UI_BIND_IP:-0.0.0.0} --log-level error read -p "Press any key to continue" diff --git a/ui/easydiffusion/__init__.py b/ui/easydiffusion/__init__.py new file mode 100644 index 00000000..e69de29b diff --git a/ui/easydiffusion/app.py b/ui/easydiffusion/app.py new file mode 100644 index 00000000..2a8d0804 --- /dev/null +++ b/ui/easydiffusion/app.py @@ -0,0 +1,165 @@ +import os +import socket +import sys +import json +import traceback +import logging +from rich.logging import RichHandler + +from sdkit.utils import log as sdkit_log # hack, so we can overwrite the log config + +from easydiffusion import task_manager +from easydiffusion.utils import log + +# Remove all handlers associated with the root logger object. 
+for handler in logging.root.handlers[:]: + logging.root.removeHandler(handler) + +LOG_FORMAT = '%(asctime)s.%(msecs)03d %(levelname)s %(threadName)s %(message)s' +logging.basicConfig( + level=logging.INFO, + format=LOG_FORMAT, + datefmt="%X", + handlers=[RichHandler(markup=True, rich_tracebacks=True, show_time=False, show_level=False)] +) + +SD_DIR = os.getcwd() + +SD_UI_DIR = os.getenv('SD_UI_PATH', None) +sys.path.append(os.path.dirname(SD_UI_DIR)) + +CONFIG_DIR = os.path.abspath(os.path.join(SD_UI_DIR, '..', 'scripts')) +MODELS_DIR = os.path.abspath(os.path.join(SD_DIR, '..', 'models')) + +USER_UI_PLUGINS_DIR = os.path.abspath(os.path.join(SD_DIR, '..', 'plugins', 'ui')) +CORE_UI_PLUGINS_DIR = os.path.abspath(os.path.join(SD_UI_DIR, 'plugins', 'ui')) +UI_PLUGINS_SOURCES = ((CORE_UI_PLUGINS_DIR, 'core'), (USER_UI_PLUGINS_DIR, 'user')) + +OUTPUT_DIRNAME = "Stable Diffusion UI" # in the user's home folder +TASK_TTL = 15 * 60 # Discard last session's task timeout +APP_CONFIG_DEFAULTS = { + # auto: selects the cuda device with the most free memory, cuda: use the currently active cuda device. 
+ 'render_devices': 'auto', # valid entries: 'auto', 'cpu' or 'cuda:N' (where N is a GPU index) + 'update_branch': 'main', + 'ui': { + 'open_browser_on_start': True, + }, +} + +def init(): + os.makedirs(USER_UI_PLUGINS_DIR, exist_ok=True) + + update_render_threads() + +def getConfig(default_val=APP_CONFIG_DEFAULTS): + try: + config_json_path = os.path.join(CONFIG_DIR, 'config.json') + if not os.path.exists(config_json_path): + return default_val + with open(config_json_path, 'r', encoding='utf-8') as f: + config = json.load(f) + if 'net' not in config: + config['net'] = {} + if os.getenv('SD_UI_BIND_PORT') is not None: + config['net']['listen_port'] = int(os.getenv('SD_UI_BIND_PORT')) + if os.getenv('SD_UI_BIND_IP') is not None: + config['net']['listen_to_network'] = (os.getenv('SD_UI_BIND_IP') == '0.0.0.0') + return config + except Exception as e: + log.warn(traceback.format_exc()) + return default_val + +def setConfig(config): + try: # config.json + config_json_path = os.path.join(CONFIG_DIR, 'config.json') + with open(config_json_path, 'w', encoding='utf-8') as f: + json.dump(config, f) + except: + log.error(traceback.format_exc()) + + try: # config.bat + config_bat_path = os.path.join(CONFIG_DIR, 'config.bat') + config_bat = [] + + if 'update_branch' in config: + config_bat.append(f"@set update_branch={config['update_branch']}") + + config_bat.append(f"@set SD_UI_BIND_PORT={config['net']['listen_port']}") + bind_ip = '0.0.0.0' if config['net']['listen_to_network'] else '127.0.0.1' + config_bat.append(f"@set SD_UI_BIND_IP={bind_ip}") + + if len(config_bat) > 0: + with open(config_bat_path, 'w', encoding='utf-8') as f: + f.write('\r\n'.join(config_bat)) + except: + log.error(traceback.format_exc()) + + try: # config.sh + config_sh_path = os.path.join(CONFIG_DIR, 'config.sh') + config_sh = ['#!/bin/bash'] + + if 'update_branch' in config: + config_sh.append(f"export update_branch={config['update_branch']}") + + config_sh.append(f"export 
SD_UI_BIND_PORT={config['net']['listen_port']}") + bind_ip = '0.0.0.0' if config['net']['listen_to_network'] else '127.0.0.1' + config_sh.append(f"export SD_UI_BIND_IP={bind_ip}") + + if len(config_sh) > 1: + with open(config_sh_path, 'w', encoding='utf-8') as f: + f.write('\n'.join(config_sh)) + except: + log.error(traceback.format_exc()) + +def save_to_config(ckpt_model_name, vae_model_name, hypernetwork_model_name, vram_usage_level): + config = getConfig() + if 'model' not in config: + config['model'] = {} + + config['model']['stable-diffusion'] = ckpt_model_name + config['model']['vae'] = vae_model_name + config['model']['hypernetwork'] = hypernetwork_model_name + + if vae_model_name is None or vae_model_name == "": + del config['model']['vae'] + if hypernetwork_model_name is None or hypernetwork_model_name == "": + del config['model']['hypernetwork'] + + config['vram_usage_level'] = vram_usage_level + + setConfig(config) + +def update_render_threads(): + config = getConfig() + render_devices = config.get('render_devices', 'auto') + active_devices = task_manager.get_devices()['active'].keys() + + log.debug(f'requesting for render_devices: {render_devices}') + task_manager.update_render_threads(render_devices, active_devices) + +def getUIPlugins(): + plugins = [] + + for plugins_dir, dir_prefix in UI_PLUGINS_SOURCES: + for file in os.listdir(plugins_dir): + if file.endswith('.plugin.js'): + plugins.append(f'/plugins/{dir_prefix}/{file}') + + return plugins + +def getIPConfig(): + try: + ips = socket.gethostbyname_ex(socket.gethostname()) + ips[2].append(ips[0]) + return ips[2] + except Exception as e: + log.exception(e) + return [] + +def open_browser(): + config = getConfig() + ui = config.get('ui', {}) + net = config.get('net', {'listen_port':9000}) + port = net.get('listen_port', 9000) + if ui.get('open_browser_on_start', True): + import webbrowser; webbrowser.open(f"http://localhost:{port}") diff --git a/ui/sd_internal/device_manager.py 
b/ui/easydiffusion/device_manager.py similarity index 73% rename from ui/sd_internal/device_manager.py rename to ui/easydiffusion/device_manager.py index d2c6430b..252aeee4 100644 --- a/ui/sd_internal/device_manager.py +++ b/ui/easydiffusion/device_manager.py @@ -3,6 +3,15 @@ import torch import traceback import re +from easydiffusion.utils import log + +''' +Set `FORCE_FULL_PRECISION` in the environment variables, or in `config.bat`/`config.sh` to set full precision (i.e. float32). +Otherwise the models will load at half-precision (i.e. float16). + +Half-precision is fine most of the time. Full precision is only needed for working around GPU bugs (like NVIDIA 16xx GPUs). +''' + COMPARABLE_GPU_PERCENTILE = 0.65 # if a GPU's free_mem is within this % of the GPU with the most free_mem, it will be picked mem_free_threshold = 0 @@ -34,7 +43,7 @@ def get_device_delta(render_devices, active_devices): if 'auto' in render_devices: render_devices = auto_pick_devices(active_devices) if 'cpu' in render_devices: - print('WARNING: Could not find a compatible GPU. Using the CPU, but this will be very slow!') + log.warn('WARNING: Could not find a compatible GPU. Using the CPU, but this will be very slow!') active_devices = set(active_devices) render_devices = set(render_devices) @@ -53,7 +62,7 @@ def auto_pick_devices(currently_active_devices): if device_count == 1: return ['cuda:0'] if is_device_compatible('cuda:0') else ['cpu'] - print('Autoselecting GPU. Using most free memory.') + log.debug('Autoselecting GPU. 
Using most free memory.') devices = [] for device in range(device_count): device = f'cuda:{device}' @@ -64,7 +73,7 @@ def auto_pick_devices(currently_active_devices): mem_free /= float(10**9) mem_total /= float(10**9) device_name = torch.cuda.get_device_name(device) - print(f'{device} detected: {device_name} - Memory (free/total): {round(mem_free, 2)}Gb / {round(mem_total, 2)}Gb') + log.debug(f'{device} detected: {device_name} - Memory (free/total): {round(mem_free, 2)}Gb / {round(mem_total, 2)}Gb') devices.append({'device': device, 'device_name': device_name, 'mem_free': mem_free}) devices.sort(key=lambda x:x['mem_free'], reverse=True) @@ -82,7 +91,7 @@ def auto_pick_devices(currently_active_devices): devices = list(map(lambda x: x['device'], devices)) return devices -def device_init(thread_data, device): +def device_init(context, device): ''' This function assumes the 'device' has already been verified to be compatible. `get_device_delta()` has already filtered out incompatible devices. 
@@ -91,27 +100,45 @@ def device_init(thread_data, device): validate_device_id(device, log_prefix='device_init') if device == 'cpu': - thread_data.device = 'cpu' - thread_data.device_name = get_processor_name() - print('Render device CPU available as', thread_data.device_name) + context.device = 'cpu' + context.device_name = get_processor_name() + context.half_precision = False + log.debug(f'Render device CPU available as {context.device_name}') return - thread_data.device_name = torch.cuda.get_device_name(device) - thread_data.device = device + context.device_name = torch.cuda.get_device_name(device) + context.device = device # Force full precision on 1660 and 1650 NVIDIA cards to avoid creating green images - device_name = thread_data.device_name.lower() - thread_data.force_full_precision = (('nvidia' in device_name or 'geforce' in device_name) and (' 1660' in device_name or ' 1650' in device_name)) or ('Quadro T2000' in device_name) - if thread_data.force_full_precision: - print('forcing full precision on NVIDIA 16xx cards, to avoid green images. GPU detected: ', thread_data.device_name) + if needs_to_force_full_precision(context): + log.warn(f'forcing full precision on this GPU, to avoid green images. GPU detected: {context.device_name}') # Apply force_full_precision now before models are loaded. 
- thread_data.precision = 'full' + context.half_precision = False - print(f'Setting {device} as active') + log.info(f'Setting {device} as active, with precision: {"half" if context.half_precision else "full"}') torch.cuda.device(device) return +def needs_to_force_full_precision(context): + if 'FORCE_FULL_PRECISION' in os.environ: + return True + + device_name = context.device_name.lower() + return (('nvidia' in device_name or 'geforce' in device_name) and (' 1660' in device_name or ' 1650' in device_name)) or ('Quadro T2000' in device_name) + +def get_max_vram_usage_level(device): + if device != 'cpu': + _, mem_total = torch.cuda.mem_get_info(device) + mem_total /= float(10**9) + + if mem_total < 4.5: + return 'low' + elif mem_total < 6.5: + return 'balanced' + + return 'high' + def validate_device_id(device, log_prefix=''): def is_valid(): if not isinstance(device, str): @@ -132,7 +159,7 @@ def is_device_compatible(device): try: validate_device_id(device, log_prefix='is_device_compatible') except: - print(str(e)) + log.error(str(e)) return False if device == 'cpu': return True @@ -141,10 +168,10 @@ def is_device_compatible(device): _, mem_total = torch.cuda.mem_get_info(device) mem_total /= float(10**9) if mem_total < 3.0: - print(f'GPU {device} with less than 3 GB of VRAM is not compatible with Stable Diffusion') + log.warn(f'GPU {device} with less than 3 GB of VRAM is not compatible with Stable Diffusion') return False except RuntimeError as e: - print(str(e)) + log.error(str(e)) return False return True @@ -164,5 +191,5 @@ def get_processor_name(): if "model name" in line: return re.sub(".*model name.*:", "", line, 1).strip() except: - print(traceback.format_exc()) + log.error(traceback.format_exc()) return "cpu" diff --git a/ui/easydiffusion/model_manager.py b/ui/easydiffusion/model_manager.py new file mode 100644 index 00000000..e08328fe --- /dev/null +++ b/ui/easydiffusion/model_manager.py @@ -0,0 +1,223 @@ +import os + +from easydiffusion import app, 
device_manager
+from easydiffusion.types import TaskData
+from easydiffusion.utils import log
+
+from sdkit import Context
+from sdkit.models import load_model, unload_model, get_model_info_from_db, scan_model
+from sdkit.utils import hash_file_quick
+
+KNOWN_MODEL_TYPES = ['stable-diffusion', 'vae', 'hypernetwork', 'gfpgan', 'realesrgan']
+MODEL_EXTENSIONS = {
+    'stable-diffusion': ['.ckpt', '.safetensors'],
+    'vae': ['.vae.pt', '.ckpt', '.safetensors'],
+    'hypernetwork': ['.pt', '.safetensors'],
+    'gfpgan': ['.pth'],
+    'realesrgan': ['.pth'],
+}
+DEFAULT_MODELS = {
+    'stable-diffusion': [ # needed to support the legacy installations
+        'custom-model', # only one custom model file was supported initially, creatively named 'custom-model'
+        'sd-v1-4', # Default fallback.
+    ],
+    'gfpgan': ['GFPGANv1.3'],
+    'realesrgan': ['RealESRGAN_x4plus'],
+}
+VRAM_USAGE_LEVEL_TO_OPTIMIZATIONS = {
+    'balanced': {'KEEP_FS_AND_CS_IN_CPU', 'SET_ATTENTION_STEP_TO_4'},
+    'low': {'KEEP_ENTIRE_MODEL_IN_CPU'},
+    'high': set(), # must be a set like the other levels; '{}' would be an empty dict
+}
+MODELS_TO_LOAD_ON_START = ['stable-diffusion', 'vae', 'hypernetwork']
+
+known_models = {}
+
+def init():
+    make_model_folders()
+    getModels() # run this once, to cache the picklescan results
+
+def load_default_models(context: Context):
+    set_vram_optimizations(context)
+
+    # init default model paths
+    for model_type in MODELS_TO_LOAD_ON_START:
+        context.model_paths[model_type] = resolve_model_to_use(model_type=model_type)
+        load_model(context, model_type)
+
+def unload_all(context: Context):
+    for model_type in KNOWN_MODEL_TYPES:
+        unload_model(context, model_type)
+
+def resolve_model_to_use(model_name:str=None, model_type:str=None):
+    model_extensions = MODEL_EXTENSIONS.get(model_type, [])
+    default_models = DEFAULT_MODELS.get(model_type, [])
+    config = app.getConfig()
+
+    model_dirs = [os.path.join(app.MODELS_DIR, model_type), app.SD_DIR]
+    if not model_name: # When None try user configured model.
+ # config = getConfig() + if 'model' in config and model_type in config['model']: + model_name = config['model'][model_type] + + if model_name: + # Check models directory + models_dir_path = os.path.join(app.MODELS_DIR, model_type, model_name) + for model_extension in model_extensions: + if os.path.exists(models_dir_path + model_extension): + return models_dir_path + model_extension + if os.path.exists(model_name + model_extension): + return os.path.abspath(model_name + model_extension) + + # Default locations + if model_name in default_models: + default_model_path = os.path.join(app.SD_DIR, model_name) + for model_extension in model_extensions: + if os.path.exists(default_model_path + model_extension): + return default_model_path + model_extension + + # Can't find requested model, check the default paths. + for default_model in default_models: + for model_dir in model_dirs: + default_model_path = os.path.join(model_dir, default_model) + for model_extension in model_extensions: + if os.path.exists(default_model_path + model_extension): + if model_name is not None: + log.warn(f'Could not find the configured custom model {model_name}{model_extension}. 
Using the default one: {default_model_path}{model_extension}') + return default_model_path + model_extension + + return None + +def reload_models_if_necessary(context: Context, task_data: TaskData): + model_paths_in_req = { + 'stable-diffusion': task_data.use_stable_diffusion_model, + 'vae': task_data.use_vae_model, + 'hypernetwork': task_data.use_hypernetwork_model, + 'gfpgan': task_data.use_face_correction, + 'realesrgan': task_data.use_upscale, + } + models_to_reload = {model_type: path for model_type, path in model_paths_in_req.items() if context.model_paths.get(model_type) != path} + + if set_vram_optimizations(context): # reload SD + models_to_reload['stable-diffusion'] = model_paths_in_req['stable-diffusion'] + + if 'stable-diffusion' in models_to_reload: + quick_hash = hash_file_quick(models_to_reload['stable-diffusion']) + known_model_info = get_model_info_from_db(quick_hash=quick_hash) + + for model_type, model_path_in_req in models_to_reload.items(): + context.model_paths[model_type] = model_path_in_req + + action_fn = unload_model if context.model_paths[model_type] is None else load_model + action_fn(context, model_type, scan_model=False) # we've scanned them already + +def resolve_model_paths(task_data: TaskData): + task_data.use_stable_diffusion_model = resolve_model_to_use(task_data.use_stable_diffusion_model, model_type='stable-diffusion') + task_data.use_vae_model = resolve_model_to_use(task_data.use_vae_model, model_type='vae') + task_data.use_hypernetwork_model = resolve_model_to_use(task_data.use_hypernetwork_model, model_type='hypernetwork') + + if task_data.use_face_correction: task_data.use_face_correction = resolve_model_to_use(task_data.use_face_correction, 'gfpgan') + if task_data.use_upscale: task_data.use_upscale = resolve_model_to_use(task_data.use_upscale, 'realesrgan') + +def set_vram_optimizations(context: Context): + config = app.getConfig() + + max_usage_level = device_manager.get_max_vram_usage_level(context.device) + 
vram_usage_level = config.get('vram_usage_level', 'balanced') + + v = {'low': 0, 'balanced': 1, 'high': 2} + if v[vram_usage_level] > v[max_usage_level]: + log.error(f'Requested GPU Memory Usage level ({vram_usage_level}) is higher than what is ' + \ + f'possible ({max_usage_level}) on this device ({context.device}). Using "{max_usage_level}" instead') + vram_usage_level = max_usage_level + + vram_optimizations = VRAM_USAGE_LEVEL_TO_OPTIMIZATIONS[vram_usage_level] + + if vram_optimizations != context.vram_optimizations: + context.vram_optimizations = vram_optimizations + return True + + return False + +def make_model_folders(): + for model_type in KNOWN_MODEL_TYPES: + model_dir_path = os.path.join(app.MODELS_DIR, model_type) + + os.makedirs(model_dir_path, exist_ok=True) + + help_file_name = f'Place your {model_type} model files here.txt' + help_file_contents = f'Supported extensions: {" or ".join(MODEL_EXTENSIONS.get(model_type))}' + + with open(os.path.join(model_dir_path, help_file_name), 'w', encoding='utf-8') as f: + f.write(help_file_contents) + +def is_malicious_model(file_path): + try: + scan_result = scan_model(file_path) + if scan_result.issues_count > 0 or scan_result.infected_files > 0: + log.warn(":warning: [bold red]Scan %s: %d scanned, %d issue, %d infected.[/bold red]" % (file_path, scan_result.scanned_files, scan_result.issues_count, scan_result.infected_files)) + return True + else: + log.debug("Scan %s: [green]%d scanned, %d issue, %d infected.[/green]" % (file_path, scan_result.scanned_files, scan_result.issues_count, scan_result.infected_files)) + return False + except Exception as e: + log.error(f'error while scanning: {file_path}, error: {e}') + return False + +def getModels(): + models = { + 'active': { + 'stable-diffusion': 'sd-v1-4', + 'vae': '', + 'hypernetwork': '', + }, + 'options': { + 'stable-diffusion': ['sd-v1-4'], + 'vae': [], + 'hypernetwork': [], + }, + } + + models_scanned = 0 + def listModels(model_type): + nonlocal 
models_scanned + + model_extensions = MODEL_EXTENSIONS.get(model_type, []) + models_dir = os.path.join(app.MODELS_DIR, model_type) + if not os.path.exists(models_dir): + os.makedirs(models_dir) + + for file in os.listdir(models_dir): + for model_extension in model_extensions: + if not file.endswith(model_extension): + continue + + model_path = os.path.join(models_dir, file) + mtime = os.path.getmtime(model_path) + mod_time = known_models[model_path] if model_path in known_models else -1 + if mod_time != mtime: + models_scanned += 1 + if is_malicious_model(model_path): + models['scan-error'] = file + return + known_models[model_path] = mtime + + model_name = file[:-len(model_extension)] + models['options'][model_type].append(model_name) + + models['options'][model_type] = [*set(models['options'][model_type])] # remove duplicates + models['options'][model_type].sort() + + # custom models + listModels(model_type='stable-diffusion') + listModels(model_type='vae') + listModels(model_type='hypernetwork') + + if models_scanned > 0: log.info(f'[green]Scanned {models_scanned} models. 
Nothing infected[/]') + + # legacy + custom_weight_path = os.path.join(app.SD_DIR, 'custom-model.ckpt') + if os.path.exists(custom_weight_path): + models['options']['stable-diffusion'].append('custom-model') + + return models diff --git a/ui/easydiffusion/renderer.py b/ui/easydiffusion/renderer.py new file mode 100644 index 00000000..b55f3667 --- /dev/null +++ b/ui/easydiffusion/renderer.py @@ -0,0 +1,124 @@ +import queue +import time +import json + +from easydiffusion import device_manager +from easydiffusion.types import TaskData, Response, Image as ResponseImage, UserInitiatedStop, GenerateImageRequest +from easydiffusion.utils import get_printable_request, save_images_to_disk, log + +from sdkit import Context +from sdkit.generate import generate_images +from sdkit.filter import apply_filters +from sdkit.utils import img_to_buffer, img_to_base64_str, latent_samples_to_images, gc + +context = Context() # thread-local +''' +runtime data (bound locally to this thread), for e.g. device, references to loaded models, optimization flags etc +''' + +def init(device): + ''' + Initializes the fields that will be bound to this runtime's context, and sets the current torch device + ''' + context.stop_processing = False + context.temp_images = {} + context.partial_x_samples = None + + device_manager.device_init(context, device) + +def make_images(req: GenerateImageRequest, task_data: TaskData, data_queue: queue.Queue, task_temp_images: list, step_callback): + context.stop_processing = False + log.info(f'request: {get_printable_request(req)}') + log.info(f'task data: {task_data.dict()}') + + images = make_images_internal(req, task_data, data_queue, task_temp_images, step_callback) + + res = Response(req, task_data, images=construct_response(images, task_data, base_seed=req.seed)) + res = res.json() + data_queue.put(json.dumps(res)) + log.info('Task completed') + + return res + +def make_images_internal(req: GenerateImageRequest, task_data: TaskData, data_queue: queue.Queue, 
task_temp_images: list, step_callback): + images, user_stopped = generate_images_internal(req, task_data, data_queue, task_temp_images, step_callback, task_data.stream_image_progress) + filtered_images = filter_images(task_data, images, user_stopped) + + if task_data.save_to_disk_path is not None: + save_images_to_disk(images, filtered_images, req, task_data) + + return filtered_images if task_data.show_only_filtered_image else images + filtered_images + +def generate_images_internal(req: GenerateImageRequest, task_data: TaskData, data_queue: queue.Queue, task_temp_images: list, step_callback, stream_image_progress: bool): + context.temp_images.clear() + + callback = make_step_callback(req, task_data, data_queue, task_temp_images, step_callback, stream_image_progress) + + try: + images = generate_images(context, callback=callback, **req.dict()) + user_stopped = False + except UserInitiatedStop: + images = [] + user_stopped = True + if context.partial_x_samples is not None: + images = latent_samples_to_images(context, context.partial_x_samples) + context.partial_x_samples = None + finally: + gc(context) + + return images, user_stopped + +def filter_images(task_data: TaskData, images: list, user_stopped): + if user_stopped or (task_data.use_face_correction is None and task_data.use_upscale is None): + return images + + filters_to_apply = [] + if task_data.use_face_correction and 'gfpgan' in task_data.use_face_correction.lower(): filters_to_apply.append('gfpgan') + if task_data.use_upscale and 'realesrgan' in task_data.use_upscale.lower(): filters_to_apply.append('realesrgan') + + return apply_filters(context, filters_to_apply, images) + +def construct_response(images: list, task_data: TaskData, base_seed: int): + return [ + ResponseImage( + data=img_to_base64_str(img, task_data.output_format, task_data.output_quality), + seed=base_seed + i + ) for i, img in enumerate(images) + ] + +def make_step_callback(req: GenerateImageRequest, task_data: TaskData, data_queue: 
queue.Queue, task_temp_images: list, step_callback, stream_image_progress: bool):
+    n_steps = req.num_inference_steps if req.init_image is None else int(req.num_inference_steps * req.prompt_strength)
+    last_callback_time = -1
+
+    def update_temp_img(x_samples, task_temp_images: list):
+        partial_images = []
+        images = latent_samples_to_images(context, x_samples)
+        for i, img in enumerate(images):
+            buf = img_to_buffer(img, output_format='JPEG')
+
+            context.temp_images[f"{task_data.request_id}/{i}"] = buf
+            task_temp_images[i] = buf
+            partial_images.append({'path': f"/image/tmp/{task_data.request_id}/{i}"})
+        del images
+        return partial_images
+
+    def on_image_step(x_samples, i):
+        nonlocal last_callback_time
+
+        context.partial_x_samples = x_samples
+        step_time = time.time() - last_callback_time if last_callback_time != -1 else -1
+        last_callback_time = time.time()
+
+        progress = {"step": i, "step_time": step_time, "total_steps": n_steps}
+
+        if stream_image_progress and i % 5 == 0:
+            progress['output'] = update_temp_img(x_samples, task_temp_images)
+
+        data_queue.put(json.dumps(progress))
+
+        step_callback()
+
+        if context.stop_processing:
+            raise UserInitiatedStop("User requested that we stop processing")
+
+    return on_image_step
diff --git a/ui/easydiffusion/server.py b/ui/easydiffusion/server.py
new file mode 100644
index 00000000..56535f1f
--- /dev/null
+++ b/ui/easydiffusion/server.py
@@ -0,0 +1,219 @@
+"""server.py: FastAPI SD-UI Web Host.
+Notes:
+    async endpoints always run on the main thread. Otherwise they run on the thread pool.
+""" +import os +import traceback +import datetime +from typing import List, Union + +from fastapi import FastAPI, HTTPException +from fastapi.staticfiles import StaticFiles +from starlette.responses import FileResponse, JSONResponse, StreamingResponse +from pydantic import BaseModel + +from easydiffusion import app, model_manager, task_manager +from easydiffusion.types import TaskData, GenerateImageRequest +from easydiffusion.utils import log + +log.info(f'started in {app.SD_DIR}') +log.info(f'started at {datetime.datetime.now():%x %X}') + +server_api = FastAPI() + +NOCACHE_HEADERS={"Cache-Control": "no-cache, no-store, must-revalidate", "Pragma": "no-cache", "Expires": "0"} + +class NoCacheStaticFiles(StaticFiles): + def is_not_modified(self, response_headers, request_headers) -> bool: + if 'content-type' in response_headers and ('javascript' in response_headers['content-type'] or 'css' in response_headers['content-type']): + response_headers.update(NOCACHE_HEADERS) + return False + + return super().is_not_modified(response_headers, request_headers) + +class SetAppConfigRequest(BaseModel): + update_branch: str = None + render_devices: Union[List[str], List[int], str, int] = None + model_vae: str = None + ui_open_browser_on_start: bool = None + listen_to_network: bool = None + listen_port: int = None + +def init(): + server_api.mount('/media', NoCacheStaticFiles(directory=os.path.join(app.SD_UI_DIR, 'media')), name="media") + + for plugins_dir, dir_prefix in app.UI_PLUGINS_SOURCES: + server_api.mount(f'/plugins/{dir_prefix}', NoCacheStaticFiles(directory=plugins_dir), name=f"plugins-{dir_prefix}") + + @server_api.post('/app_config') + async def set_app_config(req : SetAppConfigRequest): + return set_app_config_internal(req) + + @server_api.get('/get/{key:path}') + def read_web_data(key:str=None): + return read_web_data_internal(key) + + @server_api.get('/ping') # Get server and optionally session status. 
+ def ping(session_id:str=None): + return ping_internal(session_id) + + @server_api.post('/render') + def render(req: dict): + return render_internal(req) + + @server_api.get('/image/stream/{task_id:int}') + def stream(task_id:int): + return stream_internal(task_id) + + @server_api.get('/image/stop') + def stop(task: int): + return stop_internal(task) + + @server_api.get('/image/tmp/{task_id:int}/{img_id:int}') + def get_image(task_id: int, img_id: int): + return get_image_internal(task_id, img_id) + + @server_api.get('/') + def read_root(): + return FileResponse(os.path.join(app.SD_UI_DIR, 'index.html'), headers=NOCACHE_HEADERS) + + @server_api.on_event("shutdown") + def shutdown_event(): # Signal render thread to close on shutdown + task_manager.current_state_error = SystemExit('Application shutting down.') + +# API implementations +def set_app_config_internal(req : SetAppConfigRequest): + config = app.getConfig() + if req.update_branch is not None: + config['update_branch'] = req.update_branch + if req.render_devices is not None: + update_render_devices_in_config(config, req.render_devices) + if req.ui_open_browser_on_start is not None: + if 'ui' not in config: + config['ui'] = {} + config['ui']['open_browser_on_start'] = req.ui_open_browser_on_start + if req.listen_to_network is not None: + if 'net' not in config: + config['net'] = {} + config['net']['listen_to_network'] = bool(req.listen_to_network) + if req.listen_port is not None: + if 'net' not in config: + config['net'] = {} + config['net']['listen_port'] = int(req.listen_port) + try: + app.setConfig(config) + + if req.render_devices: + app.update_render_threads() + + return JSONResponse({'status': 'OK'}, headers=NOCACHE_HEADERS) + except Exception as e: + log.error(traceback.format_exc()) + raise HTTPException(status_code=500, detail=str(e)) + +def update_render_devices_in_config(config, render_devices): + if render_devices not in ('cpu', 'auto') and not render_devices.startswith('cuda:'): + raise 
HTTPException(status_code=400, detail=f'Invalid render device requested: {render_devices}') + + if render_devices.startswith('cuda:'): + render_devices = render_devices.split(',') + + config['render_devices'] = render_devices + +def read_web_data_internal(key:str=None): + if not key: # /get without parameters, stable-diffusion easter egg. + raise HTTPException(status_code=418, detail="StableDiffusion is drawing a teapot!") # HTTP418 I'm a teapot + elif key == 'app_config': + return JSONResponse(app.getConfig(), headers=NOCACHE_HEADERS) + elif key == 'system_info': + config = app.getConfig() + system_info = { + 'devices': task_manager.get_devices(), + 'hosts': app.getIPConfig(), + 'default_output_dir': os.path.join(os.path.expanduser("~"), app.OUTPUT_DIRNAME), + } + system_info['devices']['config'] = config.get('render_devices', "auto") + return JSONResponse(system_info, headers=NOCACHE_HEADERS) + elif key == 'models': + return JSONResponse(model_manager.getModels(), headers=NOCACHE_HEADERS) + elif key == 'modifiers': return FileResponse(os.path.join(app.SD_UI_DIR, 'modifiers.json'), headers=NOCACHE_HEADERS) + elif key == 'ui_plugins': return JSONResponse(app.getUIPlugins(), headers=NOCACHE_HEADERS) + else: + raise HTTPException(status_code=404, detail=f'Request for unknown {key}') # HTTP404 Not Found + +def ping_internal(session_id:str=None): + if task_manager.is_alive() <= 0: # Check that render threads are alive. 
+ if task_manager.current_state_error: raise HTTPException(status_code=500, detail=str(task_manager.current_state_error)) + raise HTTPException(status_code=500, detail='Render thread is dead.') + if task_manager.current_state_error and not isinstance(task_manager.current_state_error, StopAsyncIteration): raise HTTPException(status_code=500, detail=str(task_manager.current_state_error)) + # Alive + response = {'status': str(task_manager.current_state)} + if session_id: + session = task_manager.get_cached_session(session_id, update_ttl=True) + response['tasks'] = {id(t): t.status for t in session.tasks} + response['devices'] = task_manager.get_devices() + return JSONResponse(response, headers=NOCACHE_HEADERS) + +def render_internal(req: dict): + try: + # separate out the request data into rendering and task-specific data + render_req: GenerateImageRequest = GenerateImageRequest.parse_obj(req) + task_data: TaskData = TaskData.parse_obj(req) + + render_req.init_image_mask = req.get('mask') # hack: will rename this in the HTTP API in a future revision + + app.save_to_config(task_data.use_stable_diffusion_model, task_data.use_vae_model, task_data.use_hypernetwork_model, task_data.vram_usage_level) + + # enqueue the task + new_task = task_manager.render(render_req, task_data) + response = { + 'status': str(task_manager.current_state), + 'queue': len(task_manager.tasks_queue), + 'stream': f'/image/stream/{id(new_task)}', + 'task': id(new_task) + } + return JSONResponse(response, headers=NOCACHE_HEADERS) + except ChildProcessError as e: # Render thread is dead + raise HTTPException(status_code=500, detail=f'Rendering thread has died.') # HTTP500 Internal Server Error + except ConnectionRefusedError as e: # Unstarted task pending limit reached, deny queueing too many. 
+ raise HTTPException(status_code=503, detail=str(e)) # HTTP503 Service Unavailable + except Exception as e: + log.error(traceback.format_exc()) + raise HTTPException(status_code=500, detail=str(e)) + +def stream_internal(task_id:int): + #TODO Move to WebSockets ?? + task = task_manager.get_cached_task(task_id, update_ttl=True) + if not task: raise HTTPException(status_code=404, detail=f'Request {task_id} not found.') # HTTP404 NotFound + #if (id(task) != task_id): raise HTTPException(status_code=409, detail=f'Wrong task id received. Expected:{id(task)}, Received:{task_id}') # HTTP409 Conflict + if task.buffer_queue.empty() and not task.lock.locked(): + if task.response: + #log.info(f'Session {session_id} sending cached response') + return JSONResponse(task.response, headers=NOCACHE_HEADERS) + raise HTTPException(status_code=425, detail='Too Early, task not started yet.') # HTTP425 Too Early + #log.info(f'Session {session_id} opened live render stream {id(task.buffer_queue)}') + return StreamingResponse(task.read_buffer_generator(), media_type='application/json') + +def stop_internal(task: int): + if not task: + if task_manager.current_state == task_manager.ServerStates.Online or task_manager.current_state == task_manager.ServerStates.Unavailable: + raise HTTPException(status_code=409, detail='Not currently running any tasks.') # HTTP409 Conflict + task_manager.current_state_error = StopAsyncIteration('') + return {'OK'} + task_id = task + task = task_manager.get_cached_task(task_id, update_ttl=False) + if not task: raise HTTPException(status_code=404, detail=f'Task {task_id} was not found.') # HTTP404 Not Found + if isinstance(task.error, StopAsyncIteration): raise HTTPException(status_code=409, detail=f'Task {task_id} is already stopped.') # HTTP409 Conflict + task.error = StopAsyncIteration(f'Task {task_id} stop requested.') + return {'OK'} + +def get_image_internal(task_id: int, img_id: int): + task = task_manager.get_cached_task(task_id, update_ttl=True) + if 
not task: raise HTTPException(status_code=410, detail=f'Task {task_id} could not be found.') # HTTP410 Gone + if not task.temp_images[img_id]: raise HTTPException(status_code=425, detail='Too Early, task data is not available yet.') # HTTP425 Too Early + try: + img_data = task.temp_images[img_id] + img_data.seek(0) + return StreamingResponse(img_data, media_type='image/jpeg') + except KeyError as e: + raise HTTPException(status_code=500, detail=str(e)) diff --git a/ui/sd_internal/task_manager.py b/ui/easydiffusion/task_manager.py similarity index 65% rename from ui/sd_internal/task_manager.py rename to ui/easydiffusion/task_manager.py index a8c5a4b2..3a764137 100644 --- a/ui/sd_internal/task_manager.py +++ b/ui/easydiffusion/task_manager.py @@ -11,12 +11,13 @@ TASK_TTL = 15 * 60 # seconds, Discard last session's task timeout import torch import queue, threading, time, weakref -from typing import Any, Generator, Hashable, Optional, Union +from typing import Any, Hashable -from pydantic import BaseModel -from sd_internal import Request, Response, runtime, device_manager +from easydiffusion import device_manager +from easydiffusion.types import TaskData, GenerateImageRequest +from easydiffusion.utils import log -THREAD_NAME_PREFIX = 'Runtime-Render/' +THREAD_NAME_PREFIX = '' ERR_LOCK_FAILED = ' failed to acquire lock within timeout.' LOCK_TIMEOUT = 15 # Maximum locking time in seconds before failing a task. # It's better to get an exception than a deadlock... ALWAYS use timeout in critical paths. @@ -36,12 +37,13 @@ class ServerStates: class Unavailable(Symbol): pass class RenderTask(): # Task with output queue and completion lock.
- def __init__(self, req: Request): - req.request_id = id(self) - self.request: Request = req # Initial Request + def __init__(self, req: GenerateImageRequest, task_data: TaskData): + task_data.request_id = id(self) + self.render_request: GenerateImageRequest = req # Initial Request + self.task_data: TaskData = task_data self.response: Any = None # Copy of the last reponse self.render_device = None # Select the task affinity. (Not used to change active devices). - self.temp_images:list = [None] * req.num_outputs * (1 if req.show_only_filtered_image else 2) + self.temp_images:list = [None] * req.num_outputs * (1 if task_data.show_only_filtered_image else 2) self.error: Exception = None self.lock: threading.Lock = threading.Lock() # Locks at task start and unlocks when task is completed self.buffer_queue: queue.Queue = queue.Queue() # Queue of JSON string segments @@ -69,54 +71,6 @@ class RenderTask(): # Task with output queue and completion lock. def is_pending(self): return bool(not self.response and not self.error) -# defaults from https://huggingface.co/blog/stable_diffusion -class ImageRequest(BaseModel): - session_id: str = "session" - prompt: str = "" - negative_prompt: str = "" - init_image: str = None # base64 - mask: str = None # base64 - num_outputs: int = 1 - num_inference_steps: int = 50 - guidance_scale: float = 7.5 - width: int = 512 - height: int = 512 - seed: int = 42 - prompt_strength: float = 0.8 - sampler: str = None # "ddim", "plms", "heun", "euler", "euler_a", "dpm2", "dpm2_a", "lms" - # allow_nsfw: bool = False - save_to_disk_path: str = None - turbo: bool = True - use_cpu: bool = False ##TODO Remove after UI and plugins transition. - render_device: str = None # Select the task affinity. (Not used to change active devices). 
- use_full_precision: bool = False - use_face_correction: str = None # or "GFPGANv1.3" - use_upscale: str = None # or "RealESRGAN_x4plus" or "RealESRGAN_x4plus_anime_6B" - use_stable_diffusion_model: str = "sd-v1-4" - use_vae_model: str = None - use_hypernetwork_model: str = None - hypernetwork_strength: float = None - show_only_filtered_image: bool = False - output_format: str = "jpeg" # or "png" - output_quality: int = 75 - - stream_progress_updates: bool = False - stream_image_progress: bool = False - -class FilterRequest(BaseModel): - session_id: str = "session" - model: str = None - name: str = "" - init_image: str = None # base64 - width: int = 512 - height: int = 512 - save_to_disk_path: str = None - turbo: bool = True - render_device: str = None - use_full_precision: bool = False - output_format: str = "jpeg" # or "png" - output_quality: int = 75 - # Temporary cache to allow to query tasks results for a short time after they are completed. class DataCache(): def __init__(self): @@ -139,11 +93,11 @@ class DataCache(): for key in to_delete: (_, val) = self._base[key] if isinstance(val, RenderTask): - print(f'RenderTask {key} expired. Data removed.') + log.debug(f'RenderTask {key} expired. Data removed.') elif isinstance(val, SessionState): - print(f'Session {key} expired. Data removed.') + log.debug(f'Session {key} expired. Data removed.') else: - print(f'Key {key} expired. Data removed.') + log.debug(f'Key {key} expired. Data removed.') del self._base[key] finally: self._lock.release() @@ -177,8 +131,7 @@ class DataCache(): self._get_ttl_time(ttl), value ) except Exception as e: - print(str(e)) - print(traceback.format_exc()) + log.error(traceback.format_exc()) return False else: return True @@ -189,7 +142,7 @@ class DataCache(): try: ttl, value = self._base.get(key, (None, None)) if ttl is not None and self._is_expired(ttl): - print(f'Session {key} expired. Discarding data.') + log.debug(f'Session {key} expired. 
Discarding data.') del self._base[key] return None return value @@ -200,15 +153,9 @@ manager_lock = threading.RLock() render_threads = [] current_state = ServerStates.Init current_state_error:Exception = None -current_model_path = None -current_vae_path = None -current_hypernetwork_path = None tasks_queue = [] session_cache = DataCache() task_cache = DataCache() -default_model_to_load = None -default_vae_to_load = None -default_hypernetwork_to_load = None weak_thread_data = weakref.WeakKeyDictionary() idle_event: threading.Event = threading.Event() @@ -236,40 +183,10 @@ class SessionState(): self._tasks_ids.pop(0) return True -def preload_model(ckpt_file_path=None, vae_file_path=None, hypernetwork_file_path=None): - global current_state, current_state_error, current_model_path, current_vae_path, current_hypernetwork_path - if ckpt_file_path == None: - ckpt_file_path = default_model_to_load - if vae_file_path == None: - vae_file_path = default_vae_to_load - if hypernetwork_file_path == None: - hypernetwork_file_path = default_hypernetwork_to_load - if ckpt_file_path == current_model_path and vae_file_path == current_vae_path: - return - current_state = ServerStates.LoadingModel - try: - from . import runtime - runtime.thread_data.hypernetwork_file = hypernetwork_file_path - runtime.thread_data.ckpt_file = ckpt_file_path - runtime.thread_data.vae_file = vae_file_path - runtime.load_model_ckpt() - runtime.load_hypernetwork() - current_model_path = ckpt_file_path - current_vae_path = vae_file_path - current_hypernetwork_path = hypernetwork_file_path - current_state_error = None - current_state = ServerStates.Online - except Exception as e: - current_model_path = None - current_vae_path = None - current_state_error = e - current_state = ServerStates.Unavailable - print(traceback.format_exc()) - def thread_get_next_task(): - from . 
import runtime + from easydiffusion import renderer if not manager_lock.acquire(blocking=True, timeout=LOCK_TIMEOUT): - print('Render thread on device', runtime.thread_data.device, 'failed to acquire manager lock.') + log.warn(f'Render thread on device: {renderer.context.device} failed to acquire manager lock.') return None if len(tasks_queue) <= 0: manager_lock.release() @@ -277,7 +194,7 @@ def thread_get_next_task(): task = None try: # Select a render task. for queued_task in tasks_queue: - if queued_task.render_device and runtime.thread_data.device != queued_task.render_device: + if queued_task.render_device and renderer.context.device != queued_task.render_device: # Is asking for a specific render device. if is_alive(queued_task.render_device) > 0: continue # requested device alive, skip current one. @@ -286,7 +203,7 @@ def thread_get_next_task(): queued_task.error = Exception(queued_task.render_device + ' is not currently active.') task = queued_task break - if not queued_task.render_device and runtime.thread_data.device == 'cpu' and is_alive() > 1: + if not queued_task.render_device and renderer.context.device == 'cpu' and is_alive() > 1: # not asking for any specific devices, cpu want to grab task but other render devices are alive. continue # Skip Tasks, don't run on CPU unless there is nothing else or user asked for it. task = queued_task @@ -298,31 +215,36 @@ def thread_get_next_task(): manager_lock.release() def thread_render(device): - global current_state, current_state_error, current_model_path, current_vae_path, current_hypernetwork_path - from . 
import runtime + global current_state, current_state_error + + from easydiffusion import renderer, model_manager try: - runtime.thread_init(device) - except Exception as e: - print(traceback.format_exc()) + renderer.init(device) + weak_thread_data[threading.current_thread()] = { - 'error': e + 'device': renderer.context.device, + 'device_name': renderer.context.device_name, + 'alive': True + } + + current_state = ServerStates.LoadingModel + model_manager.load_default_models(renderer.context) + + current_state = ServerStates.Online + except Exception as e: + log.error(traceback.format_exc()) + weak_thread_data[threading.current_thread()] = { + 'error': e, + 'alive': False } return - weak_thread_data[threading.current_thread()] = { - 'device': runtime.thread_data.device, - 'device_name': runtime.thread_data.device_name, - 'alive': True - } - if runtime.thread_data.device != 'cpu' or is_alive() == 1: - preload_model() - current_state = ServerStates.Online + while True: session_cache.clean() task_cache.clean() if not weak_thread_data[threading.current_thread()]['alive']: - print(f'Shutting down thread for device {runtime.thread_data.device}') - runtime.unload_models() - runtime.unload_filters() + log.info(f'Shutting down thread for device {renderer.context.device}') + model_manager.unload_all(renderer.context) return if isinstance(current_state_error, SystemExit): current_state = ServerStates.Unavailable @@ -333,7 +255,7 @@ def thread_render(device): idle_event.wait(timeout=1) continue if task.error is not None: - print(task.error) + log.error(task.error) task.response = {"status": 'failed', "detail": str(task.error)} task.buffer_queue.put(json.dumps(task.response)) continue @@ -342,51 +264,45 @@ def thread_render(device): task.response = {"status": 'failed', "detail": str(task.error)} task.buffer_queue.put(json.dumps(task.response)) continue - print(f'Session {task.request.session_id} starting task {id(task)} on {runtime.thread_data.device_name}') + log.info(f'Session 
{task.task_data.session_id} starting task {id(task)} on {renderer.context.device_name}') if not task.lock.acquire(blocking=False): raise Exception('Got locked task from queue.') try: - if runtime.is_hypernetwork_reload_necessary(task.request): - runtime.reload_hypernetwork() - current_hypernetwork_path = task.request.use_hypernetwork_model - - if runtime.is_model_reload_necessary(task.request): - current_state = ServerStates.LoadingModel - runtime.reload_model() - current_model_path = task.request.use_stable_diffusion_model - current_vae_path = task.request.use_vae_model - def step_callback(): global current_state_error if isinstance(current_state_error, SystemExit) or isinstance(current_state_error, StopAsyncIteration) or isinstance(task.error, StopAsyncIteration): - runtime.thread_data.stop_processing = True + renderer.context.stop_processing = True if isinstance(current_state_error, StopAsyncIteration): task.error = current_state_error current_state_error = None - print(f'Session {task.request.session_id} sent cancel signal for task {id(task)}') + log.info(f'Session {task.task_data.session_id} sent cancel signal for task {id(task)}') + + current_state = ServerStates.LoadingModel + model_manager.resolve_model_paths(task.task_data) + model_manager.reload_models_if_necessary(renderer.context, task.task_data) current_state = ServerStates.Rendering - task.response = runtime.mk_img(task.request, task.buffer_queue, task.temp_images, step_callback) + task.response = renderer.make_images(task.render_request, task.task_data, task.buffer_queue, task.temp_images, step_callback) # Before looping back to the generator, mark cache as still alive. 
task_cache.keep(id(task), TASK_TTL) - session_cache.keep(task.request.session_id, TASK_TTL) + session_cache.keep(task.task_data.session_id, TASK_TTL) except Exception as e: task.error = e task.response = {"status": 'failed', "detail": str(task.error)} task.buffer_queue.put(json.dumps(task.response)) - print(traceback.format_exc()) + log.error(traceback.format_exc()) continue finally: # Task completed task.lock.release() task_cache.keep(id(task), TASK_TTL) - session_cache.keep(task.request.session_id, TASK_TTL) + session_cache.keep(task.task_data.session_id, TASK_TTL) if isinstance(task.error, StopAsyncIteration): - print(f'Session {task.request.session_id} task {id(task)} cancelled!') + log.info(f'Session {task.task_data.session_id} task {id(task)} cancelled!') elif task.error is not None: - print(f'Session {task.request.session_id} task {id(task)} failed!') + log.info(f'Session {task.task_data.session_id} task {id(task)} failed!') else: - print(f'Session {task.request.session_id} task {id(task)} completed by {runtime.thread_data.device_name}.') + log.info(f'Session {task.task_data.session_id} task {id(task)} completed by {renderer.context.device_name}.') current_state = ServerStates.Online def get_cached_task(task_id:str, update_ttl:bool=False): @@ -423,6 +339,7 @@ def get_devices(): 'name': torch.cuda.get_device_name(device), 'mem_free': mem_free, 'mem_total': mem_total, + 'max_vram_usage_level': device_manager.get_max_vram_usage_level(device), } # list the compatible devices @@ -472,7 +389,7 @@ def is_alive(device=None): def start_render_thread(device): if not manager_lock.acquire(blocking=True, timeout=LOCK_TIMEOUT): raise Exception('start_render_thread' + ERR_LOCK_FAILED) - print('Start new Rendering Thread on device', device) + log.info(f'Start new Rendering Thread on device: {device}') try: rthread = threading.Thread(target=thread_render, kwargs={'device': device}) rthread.daemon = True @@ -484,7 +401,7 @@ def start_render_thread(device): timeout = 
DEVICE_START_TIMEOUT while not rthread.is_alive() or not rthread in weak_thread_data or not 'device' in weak_thread_data[rthread]: if rthread in weak_thread_data and 'error' in weak_thread_data[rthread]: - print(rthread, device, 'error:', weak_thread_data[rthread]['error']) + log.error(f"{rthread}, {device}, error: {weak_thread_data[rthread]['error']}") return False if timeout <= 0: return False @@ -496,11 +413,11 @@ def stop_render_thread(device): try: device_manager.validate_device_id(device, log_prefix='stop_render_thread') except: - print(traceback.format_exc()) + log.error(traceback.format_exc()) return False if not manager_lock.acquire(blocking=True, timeout=LOCK_TIMEOUT): raise Exception('stop_render_thread' + ERR_LOCK_FAILED) - print('Stopping Rendering Thread on device', device) + log.info(f'Stopping Rendering Thread on device: {device}') try: thread_to_remove = None @@ -523,79 +440,44 @@ def stop_render_thread(device): def update_render_threads(render_devices, active_devices): devices_to_start, devices_to_stop = device_manager.get_device_delta(render_devices, active_devices) - print('devices_to_start', devices_to_start) - print('devices_to_stop', devices_to_stop) + log.debug(f'devices_to_start: {devices_to_start}') + log.debug(f'devices_to_stop: {devices_to_stop}') for device in devices_to_stop: if is_alive(device) <= 0: - print(device, 'is not alive') + log.debug(f'{device} is not alive') continue if not stop_render_thread(device): - print(device, 'could not stop render thread') + log.warn(f'{device} could not stop render thread') for device in devices_to_start: if is_alive(device) >= 1: - print(device, 'already registered.') + log.debug(f'{device} already registered.') continue if not start_render_thread(device): - print(device, 'failed to start.') + log.warn(f'{device} failed to start.') if is_alive() <= 0: # No running devices, probably invalid user config. raise EnvironmentError('ERROR: No active render devices! 
Please verify the "render_devices" value in config.json') - print('active devices', get_devices()['active']) + log.debug(f"active devices: {get_devices()['active']}") def shutdown_event(): # Signal render thread to close on shutdown global current_state_error current_state_error = SystemExit('Application shutting down.') -def render(req : ImageRequest): +def render(render_req: GenerateImageRequest, task_data: TaskData): current_thread_count = is_alive() if current_thread_count <= 0: # Render thread is dead raise ChildProcessError('Rendering thread has died.') # Alive, check if task in cache - session = get_cached_session(req.session_id, update_ttl=True) + session = get_cached_session(task_data.session_id, update_ttl=True) pending_tasks = list(filter(lambda t: t.is_pending, session.tasks)) if current_thread_count < len(pending_tasks): - raise ConnectionRefusedError(f'Session {req.session_id} already has {len(pending_tasks)} pending tasks out of {current_thread_count}.') + raise ConnectionRefusedError(f'Session {task_data.session_id} already has {len(pending_tasks)} pending tasks out of {current_thread_count}.') - from . 
import runtime - r = Request() - r.session_id = req.session_id - r.prompt = req.prompt - r.negative_prompt = req.negative_prompt - r.init_image = req.init_image - r.mask = req.mask - r.num_outputs = req.num_outputs - r.num_inference_steps = req.num_inference_steps - r.guidance_scale = req.guidance_scale - r.width = req.width - r.height = req.height - r.seed = req.seed - r.prompt_strength = req.prompt_strength - r.sampler = req.sampler - # r.allow_nsfw = req.allow_nsfw - r.turbo = req.turbo - r.use_full_precision = req.use_full_precision - r.save_to_disk_path = req.save_to_disk_path - r.use_upscale: str = req.use_upscale - r.use_face_correction = req.use_face_correction - r.use_stable_diffusion_model = req.use_stable_diffusion_model - r.use_vae_model = req.use_vae_model - r.use_hypernetwork_model = req.use_hypernetwork_model - r.hypernetwork_strength = req.hypernetwork_strength - r.show_only_filtered_image = req.show_only_filtered_image - r.output_format = req.output_format - r.output_quality = req.output_quality - - r.stream_progress_updates = True # the underlying implementation only supports streaming - r.stream_image_progress = req.stream_image_progress - - if not req.stream_progress_updates: - r.stream_image_progress = False - - new_task = RenderTask(r) + new_task = RenderTask(render_req, task_data) if session.put(new_task, TASK_TTL): # Use twice the normal timeout for adding user requests. # Tries to force session.put to fail before tasks_queue.put would. 
diff --git a/ui/easydiffusion/types.py b/ui/easydiffusion/types.py new file mode 100644 index 00000000..805c8683 --- /dev/null +++ b/ui/easydiffusion/types.py @@ -0,0 +1,87 @@ +from pydantic import BaseModel +from typing import Any + +class GenerateImageRequest(BaseModel): + prompt: str = "" + negative_prompt: str = "" + + seed: int = 42 + width: int = 512 + height: int = 512 + + num_outputs: int = 1 + num_inference_steps: int = 50 + guidance_scale: float = 7.5 + + init_image: Any = None + init_image_mask: Any = None + prompt_strength: float = 0.8 + preserve_init_image_color_profile = False + + sampler_name: str = None # "ddim", "plms", "heun", "euler", "euler_a", "dpm2", "dpm2_a", "lms" + hypernetwork_strength: float = 0 + +class TaskData(BaseModel): + request_id: str = None + session_id: str = "session" + save_to_disk_path: str = None + vram_usage_level: str = "balanced" # or "low" or "medium" + + use_face_correction: str = None # or "GFPGANv1.3" + use_upscale: str = None # or "RealESRGAN_x4plus" or "RealESRGAN_x4plus_anime_6B" + use_stable_diffusion_model: str = "sd-v1-4" + use_stable_diffusion_config: str = "v1-inference" + use_vae_model: str = None + use_hypernetwork_model: str = None + + show_only_filtered_image: bool = False + output_format: str = "jpeg" # or "png" + output_quality: int = 75 + metadata_output_format: str = "txt" # or "json" + stream_image_progress: bool = False + +class Image: + data: str # base64 + seed: int + is_nsfw: bool + path_abs: str = None + + def __init__(self, data, seed): + self.data = data + self.seed = seed + + def json(self): + return { + "data": self.data, + "seed": self.seed, + "path_abs": self.path_abs, + } + +class Response: + render_request: GenerateImageRequest + task_data: TaskData + images: list + + def __init__(self, render_request: GenerateImageRequest, task_data: TaskData, images: list): + self.render_request = render_request + self.task_data = task_data + self.images = images + + def json(self): + del 
self.render_request.init_image + del self.render_request.init_image_mask + + res = { + "status": 'succeeded', + "render_request": self.render_request.dict(), + "task_data": self.task_data.dict(), + "output": [], + } + + for image in self.images: + res["output"].append(image.json()) + + return res + +class UserInitiatedStop(Exception): + pass diff --git a/ui/easydiffusion/utils/__init__.py b/ui/easydiffusion/utils/__init__.py new file mode 100644 index 00000000..8be070b4 --- /dev/null +++ b/ui/easydiffusion/utils/__init__.py @@ -0,0 +1,8 @@ +import logging + +log = logging.getLogger('easydiffusion') + +from .save_utils import ( + save_images_to_disk, + get_printable_request, +) \ No newline at end of file diff --git a/ui/easydiffusion/utils/save_utils.py b/ui/easydiffusion/utils/save_utils.py new file mode 100644 index 00000000..b9ce8aba --- /dev/null +++ b/ui/easydiffusion/utils/save_utils.py @@ -0,0 +1,79 @@ +import os +import time +import base64 +import re + +from easydiffusion.types import TaskData, GenerateImageRequest + +from sdkit.utils import save_images, save_dicts + +filename_regex = re.compile('[^a-zA-Z0-9]') + +# keep in sync with `ui/media/js/dnd.js` +TASK_TEXT_MAPPING = { + 'prompt': 'Prompt', + 'width': 'Width', + 'height': 'Height', + 'seed': 'Seed', + 'num_inference_steps': 'Steps', + 'guidance_scale': 'Guidance Scale', + 'prompt_strength': 'Prompt Strength', + 'use_face_correction': 'Use Face Correction', + 'use_upscale': 'Use Upscaling', + 'sampler_name': 'Sampler', + 'negative_prompt': 'Negative Prompt', + 'use_stable_diffusion_model': 'Stable Diffusion model', + 'use_hypernetwork_model': 'Hypernetwork model', + 'hypernetwork_strength': 'Hypernetwork Strength' +} + +def save_images_to_disk(images: list, filtered_images: list, req: GenerateImageRequest, task_data: TaskData): + save_dir_path = os.path.join(task_data.save_to_disk_path, filename_regex.sub('_', task_data.session_id)) + metadata_entries = get_metadata_entries_for_request(req, 
task_data) + + if task_data.show_only_filtered_image or filtered_images == images: + save_images(filtered_images, save_dir_path, file_name=make_filename_callback(req), output_format=task_data.output_format, output_quality=task_data.output_quality) + save_dicts(metadata_entries, save_dir_path, file_name=make_filename_callback(req), output_format=task_data.metadata_output_format) + else: + save_images(images, save_dir_path, file_name=make_filename_callback(req), output_format=task_data.output_format, output_quality=task_data.output_quality) + save_images(filtered_images, save_dir_path, file_name=make_filename_callback(req, suffix='filtered'), output_format=task_data.output_format, output_quality=task_data.output_quality) + save_dicts(metadata_entries, save_dir_path, file_name=make_filename_callback(req, suffix='filtered'), output_format=task_data.metadata_output_format) + +def get_metadata_entries_for_request(req: GenerateImageRequest, task_data: TaskData): + metadata = get_printable_request(req) + metadata.update({ + 'use_stable_diffusion_model': task_data.use_stable_diffusion_model, + 'use_vae_model': task_data.use_vae_model, + 'use_hypernetwork_model': task_data.use_hypernetwork_model, + 'use_face_correction': task_data.use_face_correction, + 'use_upscale': task_data.use_upscale, + }) + + # if text, format it in the text format expected by the UI + is_txt_format = (task_data.metadata_output_format.lower() == 'txt') + if is_txt_format: + metadata = {TASK_TEXT_MAPPING[key]: val for key, val in metadata.items() if key in TASK_TEXT_MAPPING} + + entries = [metadata.copy() for _ in range(req.num_outputs)] + for i, entry in enumerate(entries): + entry['Seed' if is_txt_format else 'seed'] = req.seed + i + + return entries + +def get_printable_request(req: GenerateImageRequest): + metadata = req.dict() + del metadata['init_image'] + del metadata['init_image_mask'] + return metadata + +def make_filename_callback(req: GenerateImageRequest, suffix=None): + def 
make_filename(i): + img_id = base64.b64encode(int(time.time()+i).to_bytes(8, 'big')).decode() # Generate unique ID based on time. + img_id = img_id.translate({43:None, 47:None, 61:None})[-8:] # Remove + / = and keep last 8 chars. + + prompt_flattened = filename_regex.sub('_', req.prompt)[:50] + name = f"{prompt_flattened}_{img_id}" + name = name if suffix is None else f'{name}_{suffix}' + return name + + return make_filename \ No newline at end of file diff --git a/ui/index.html b/ui/index.html index 46f367d5..9cd66eb5 100644 --- a/ui/index.html +++ b/ui/index.html @@ -24,8 +24,8 @@
@@ -129,22 +129,32 @@ Click to learn more about custom models + Click to learn more about VAEs - - + + + + + + + + + + + + + + Click to learn more about samplers @@ -220,6 +230,7 @@
  • Render Settings
  • +
  • @@ -416,7 +427,6 @@ async function init() { await initSettings() await getModels() - await getDiskPath() await getAppConfig() await loadUIPlugins() await loadModifiers() diff --git a/ui/main.py b/ui/main.py new file mode 100644 index 00000000..77def4a5 --- /dev/null +++ b/ui/main.py @@ -0,0 +1,10 @@ +from easydiffusion import model_manager, app, server +from easydiffusion.server import server_api # required for uvicorn + +# Init the app +model_manager.init() +app.init() +server.init() + +# start the browser ui +app.open_browser() diff --git a/ui/media/js/auto-save.js b/ui/media/js/auto-save.js index c1a9ce49..934b3f32 100644 --- a/ui/media/js/auto-save.js +++ b/ui/media/js/auto-save.js @@ -15,7 +15,7 @@ const SETTINGS_IDS_LIST = [ "stable_diffusion_model", "vae_model", "hypernetwork_model", - "sampler", + "sampler_name", "width", "height", "num_inference_steps", @@ -36,10 +36,11 @@ const SETTINGS_IDS_LIST = [ "save_to_disk", "diskPath", "sound_toggle", - "turbo", - "use_full_precision", + "vram_usage_level", "confirm_dangerous_actions", - "auto_save_settings" + "metadata_output_format", + "auto_save_settings", + "apply_color_correction" ] const IGNORE_BY_DEFAULT = [ @@ -277,7 +278,6 @@ function tryLoadOldSettings() { "soundEnabled": "sound_toggle", "saveToDisk": "save_to_disk", "useCPU": "use_cpu", - "useFullPrecision": "use_full_precision", "useTurboMode": "turbo", "diskPath": "diskPath", "useFaceCorrection": "use_face_correction", diff --git a/ui/media/js/dnd.js b/ui/media/js/dnd.js index bec618c6..b5d91880 100644 --- a/ui/media/js/dnd.js +++ b/ui/media/js/dnd.js @@ -25,6 +25,7 @@ function parseBoolean(stringValue) { case "no": case "off": case "0": + case "none": case null: case undefined: return false; @@ -160,9 +161,9 @@ const TASK_MAPPING = { readUI: () => (useUpscalingField.checked ? 
upscaleModelField.value : undefined), parse: (val) => val }, - sampler: { name: 'Sampler', - setUI: (sampler) => { - samplerField.value = sampler + sampler_name: { name: 'Sampler', + setUI: (sampler_name) => { + samplerField.value = sampler_name }, readUI: () => samplerField.value, parse: (val) => val @@ -171,7 +172,7 @@ const TASK_MAPPING = { setUI: (use_stable_diffusion_model) => { const oldVal = stableDiffusionModelField.value - use_stable_diffusion_model = getModelPath(use_stable_diffusion_model, ['.ckpt']) + use_stable_diffusion_model = getModelPath(use_stable_diffusion_model, ['.ckpt', '.safetensors']) stableDiffusionModelField.value = use_stable_diffusion_model if (!stableDiffusionModelField.value) { @@ -184,6 +185,7 @@ const TASK_MAPPING = { use_vae_model: { name: 'VAE model', setUI: (use_vae_model) => { const oldVal = vaeModelField.value + use_vae_model = (use_vae_model === undefined || use_vae_model === null || use_vae_model === 'None' ? '' : use_vae_model) if (use_vae_model !== '') { use_vae_model = getModelPath(use_vae_model, ['.vae.pt', '.ckpt']) @@ -197,6 +199,7 @@ const TASK_MAPPING = { use_hypernetwork_model: { name: 'Hypernetwork model', setUI: (use_hypernetwork_model) => { const oldVal = hypernetworkModelField.value + use_hypernetwork_model = (use_hypernetwork_model === undefined || use_hypernetwork_model === null || use_hypernetwork_model === 'None' ? 
'' : use_hypernetwork_model) if (use_hypernetwork_model !== '') { use_hypernetwork_model = getModelPath(use_hypernetwork_model, ['.pt']) @@ -239,13 +242,6 @@ const TASK_MAPPING = { readUI: () => turboField.checked, parse: (val) => Boolean(val) }, - use_full_precision: { name: 'Use Full Precision', - setUI: (use_full_precision) => { - useFullPrecisionField.checked = use_full_precision - }, - readUI: () => useFullPrecisionField.checked, - parse: (val) => Boolean(val) - }, stream_image_progress: { name: 'Stream Image Progress', setUI: (stream_image_progress) => { @@ -350,6 +346,7 @@ function getModelPath(filename, extensions) } const TASK_TEXT_MAPPING = { + prompt: 'Prompt', width: 'Width', height: 'Height', seed: 'Seed', @@ -358,7 +355,7 @@ const TASK_TEXT_MAPPING = { prompt_strength: 'Prompt Strength', use_face_correction: 'Use Face Correction', use_upscale: 'Use Upscaling', - sampler: 'Sampler', + sampler_name: 'Sampler', negative_prompt: 'Negative Prompt', use_stable_diffusion_model: 'Stable Diffusion model', use_hypernetwork_model: 'Hypernetwork model', @@ -410,6 +407,9 @@ async function parseContent(text) { if (text.startsWith('{') && text.endsWith('}')) { try { const task = JSON.parse(text) + if (!('reqBody' in task)) { // support the format saved to the disk, by the UI + task.reqBody = Object.assign({}, task) + } restoreTaskToUI(task) return true } catch (e) { @@ -477,7 +477,6 @@ document.addEventListener("dragover", dragOverHandler) const TASK_REQ_NO_EXPORT = [ "use_cpu", "turbo", - "use_full_precision", "save_to_disk_path" ] const resetSettings = document.getElementById('reset-image-settings') diff --git a/ui/media/js/engine.js b/ui/media/js/engine.js index 467a45c5..54f63b3c 100644 --- a/ui/media/js/engine.js +++ b/ui/media/js/engine.js @@ -728,7 +728,6 @@ "stream_image_progress": 'boolean', "show_only_filtered_image": 'boolean', "turbo": 'boolean', - "use_full_precision": 'boolean', "output_format": 'string', "output_quality": 'number', } @@ -744,7 +743,6 
@@ "stream_image_progress": true, "show_only_filtered_image": true, "turbo": false, - "use_full_precision": false, "output_format": "png", "output_quality": 75, } diff --git a/ui/media/js/main.js b/ui/media/js/main.js index b518c8c1..5e79ec56 100644 --- a/ui/media/js/main.js +++ b/ui/media/js/main.js @@ -26,9 +26,11 @@ let initImagePreview = document.querySelector("#init_image_preview") let initImageSizeBox = document.querySelector("#init_image_size_box") let maskImageSelector = document.querySelector("#mask") let maskImagePreview = document.querySelector("#mask_preview") +let applyColorCorrectionField = document.querySelector('#apply_color_correction') +let colorCorrectionSetting = document.querySelector('#apply_color_correction_setting') let promptStrengthSlider = document.querySelector('#prompt_strength_slider') let promptStrengthField = document.querySelector('#prompt_strength') -let samplerField = document.querySelector('#sampler') +let samplerField = document.querySelector('#sampler_name') let samplerSelectionContainer = document.querySelector("#samplerSelection") let useFaceCorrectionField = document.querySelector("#use_face_correction") let useUpscalingField = document.querySelector("#use_upscale") @@ -610,7 +612,7 @@ function onTaskCompleted(task, reqBody, instance, outputContainer, stepUpdate) { Suggestions:
    1. If you have set an initial image, please try reducing its dimension to ${MAX_INIT_IMAGE_DIMENSION}x${MAX_INIT_IMAGE_DIMENSION} or smaller.
    - 2. Try disabling the 'Turbo mode' under 'Advanced Settings'.
    + 2. Try picking a lower level in the 'GPU Memory Usage' setting (in the 'Settings' tab).
    3. Try generating a smaller image.
    ` } } else { @@ -786,10 +788,11 @@ function createTask(task) { if (task.reqBody.init_image !== undefined) { let h = 80 - let w = task.reqBody.width * h / task.reqBody.height >>0 + let w = task.reqBody.width * h / task.reqBody.height >>0 taskConfig += `
    ` } - taskConfig += `Seed: ${task.seed}, Sampler: ${task.reqBody.sampler}, Inference Steps: ${task.reqBody.num_inference_steps}, Guidance Scale: ${task.reqBody.guidance_scale}, Model: ${task.reqBody.use_stable_diffusion_model}` + taskConfig += `Seed: ${task.seed}, Sampler: ${task.reqBody.sampler_name}, Inference Steps: ${task.reqBody.num_inference_steps}, Guidance Scale: ${task.reqBody.guidance_scale}, Model: ${task.reqBody.use_stable_diffusion_model}` + if (task.reqBody.use_vae_model.trim() !== '') { taskConfig += `, VAE: ${task.reqBody.use_vae_model}` } @@ -809,6 +812,9 @@ function createTask(task) { taskConfig += `, Hypernetwork: ${task.reqBody.use_hypernetwork_model}` taskConfig += `, Hypernetwork Strength: ${task.reqBody.hypernetwork_strength}` } + if (task.reqBody.preserve_init_image_color_profile) { + taskConfig += `, Preserve Color Profile: true` + } let taskEntry = document.createElement('div') taskEntry.id = `imageTaskContainer-${Date.now()}` @@ -914,9 +920,8 @@ function getCurrentUserRequest() { width: parseInt(widthField.value), height: parseInt(heightField.value), // allow_nsfw: allowNSFWField.checked, - turbo: turboField.checked, + vram_usage_level: vramUsageLevelField.value, //render_device: undefined, // Set device affinity. Prefer this device, but wont activate. 
- use_full_precision: useFullPrecisionField.checked, use_stable_diffusion_model: stableDiffusionModelField.value, use_vae_model: vaeModelField.value, stream_progress_updates: true, @@ -924,6 +929,7 @@ function getCurrentUserRequest() { show_only_filtered_image: showOnlyFilteredImageField.checked, output_format: outputFormatField.value, output_quality: parseInt(outputQualityField.value), + metadata_output_format: document.querySelector('#metadata_output_format').value, original_prompt: promptField.value, active_tags: (activeTags.map(x => x.name)) } @@ -938,9 +944,10 @@ function getCurrentUserRequest() { if (maskSetting.checked) { newTask.reqBody.mask = imageInpainter.getImg() } - newTask.reqBody.sampler = 'ddim' + newTask.reqBody.preserve_init_image_color_profile = applyColorCorrectionField.checked + newTask.reqBody.sampler_name = 'ddim' } else { - newTask.reqBody.sampler = samplerField.value + newTask.reqBody.sampler_name = samplerField.value } if (saveToDiskField.checked && diskPathField.value.trim() !== '') { newTask.reqBody.save_to_disk_path = diskPathField.value.trim() @@ -1349,6 +1356,7 @@ function img2imgLoad() { promptStrengthContainer.style.display = 'table-row' samplerSelectionContainer.style.display = "none" initImagePreviewContainer.classList.add("has-image") + colorCorrectionSetting.style.display = '' initImageSizeBox.textContent = initImagePreview.naturalWidth + " x " + initImagePreview.naturalHeight imageEditor.setImage(this.src, initImagePreview.naturalWidth, initImagePreview.naturalHeight) @@ -1363,6 +1371,7 @@ function img2imgUnload() { promptStrengthContainer.style.display = "none" samplerSelectionContainer.style.display = "" initImagePreviewContainer.classList.remove("has-image") + colorCorrectionSetting.style.display = 'none' imageEditor.setImage(null, parseInt(widthField.value), parseInt(heightField.value)) } diff --git a/ui/media/js/parameters.js b/ui/media/js/parameters.js index 5c522d7e..4e7db4f5 100644 --- a/ui/media/js/parameters.js +++ 
b/ui/media/js/parameters.js @@ -53,6 +53,23 @@ var PARAMETERS = [ return `` } }, + { + id: "metadata_output_format", + type: ParameterType.select, + label: "Metadata format", + note: "will be saved to disk in this format", + default: "txt", + options: [ + { + value: "txt", + label: "txt" + }, + { + value: "json", + label: "json" + } + ], + }, { id: "sound_toggle", type: ParameterType.checkbox, @@ -77,12 +94,20 @@ var PARAMETERS = [ default: true, }, { - id: "turbo", - type: ParameterType.checkbox, - label: "Turbo Mode", - note: "generates images faster, but uses an additional 1 GB of GPU memory", + id: "vram_usage_level", + type: ParameterType.select, + label: "GPU Memory Usage", + note: "Faster performance requires more GPU memory (VRAM)

    " + + "Balanced: nearly as fast as High, much lower VRAM usage
    " + + "High: fastest, maximum GPU memory usage
    " + + "Low: slowest, force-used for GPUs with 4 GB (or less) memory", icon: "fa-forward", - default: true, + default: "balanced", + options: [ + {value: "balanced", label: "Balanced"}, + {value: "high", label: "High"}, + {value: "low", label: "Low"} + ], }, { id: "use_cpu", @@ -105,14 +130,6 @@ var PARAMETERS = [ note: "to process in parallel", default: false, }, - { - id: "use_full_precision", - type: ParameterType.checkbox, - label: "Use Full Precision", - note: "for GPU-only. warning: this will consume more VRAM", - icon: "fa-crosshairs", - default: false, - }, { id: "auto_save_settings", type: ParameterType.checkbox, @@ -147,14 +164,6 @@ var PARAMETERS = [ return `` } }, - { - id: "test_sd2", - type: ParameterType.checkbox, - label: "Test SD 2.0", - note: "Experimental! High memory usage! GPU-only! Not the final version! Please restart the program after changing this.", - icon: "fa-fire", - default: false, - }, { id: "use_beta_channel", type: ParameterType.checkbox, @@ -210,16 +219,14 @@ function initParameters() { initParameters() -let turboField = document.querySelector('#turbo') +let vramUsageLevelField = document.querySelector('#vram_usage_level') let useCPUField = document.querySelector('#use_cpu') let autoPickGPUsField = document.querySelector('#auto_pick_gpus') let useGPUsField = document.querySelector('#use_gpus') -let useFullPrecisionField = document.querySelector('#use_full_precision') let saveToDiskField = document.querySelector('#save_to_disk') let diskPathField = document.querySelector('#diskPath') let listenToNetworkField = document.querySelector("#listen_to_network") let listenPortField = document.querySelector("#listen_port") -let testSD2Field = document.querySelector("#test_sd2") let useBetaChannelField = document.querySelector("#use_beta_channel") let uiOpenBrowserOnStartField = document.querySelector("#ui_open_browser_on_start") let confirmDangerousActionsField = document.querySelector("#confirm_dangerous_actions") @@ -256,12 +263,6 @@ 
async function getAppConfig() { if (config.ui && config.ui.open_browser_on_start === false) { uiOpenBrowserOnStartField.checked = false } - if ('test_sd2' in config) { - testSD2Field.checked = config['test_sd2'] - } - - let testSD2SettingEntry = getParameterSettingsEntry('test_sd2') - testSD2SettingEntry.style.display = (config.update_branch === 'beta' ? '' : 'none') if (config.net && config.net.listen_to_network === false) { listenToNetworkField.checked = false } @@ -327,20 +328,10 @@ autoPickGPUsField.addEventListener('click', function() { gpuSettingEntry.style.display = (this.checked ? 'none' : '') }) -async function getDiskPath() { - try { - var diskPath = getSetting("diskPath") - if (diskPath == '' || diskPath == undefined || diskPath == "undefined") { - let res = await fetch('/get/output_dir') - if (res.status === 200) { - res = await res.json() - res = res.output_dir - - setSetting("diskPath", res) - } - } - } catch (e) { - console.log('error fetching output dir path', e) +async function setDiskPath(defaultDiskPath) { + var diskPath = getSetting("diskPath") + if (diskPath == '' || diskPath == undefined || diskPath == "undefined") { + setSetting("diskPath", defaultDiskPath) } } @@ -415,6 +406,7 @@ async function getSystemInfo() { setDeviceInfo(devices) setHostInfo(res['hosts']) + setDiskPath(res['default_output_dir']) } catch (e) { console.log('error fetching devices', e) } @@ -435,8 +427,7 @@ saveSettingsBtn.addEventListener('click', function() { 'update_branch': updateBranch, 'ui_open_browser_on_start': uiOpenBrowserOnStartField.checked, 'listen_to_network': listenToNetworkField.checked, - 'listen_port': listenPortField.value, - 'test_sd2': testSD2Field.checked + 'listen_port': listenPortField.value }) saveSettingsBtn.classList.add('active') asyncDelay(300).then(() => saveSettingsBtn.classList.remove('active')) diff --git a/ui/sd_internal/__init__.py b/ui/sd_internal/__init__.py deleted file mode 100644 index 0a1590f0..00000000 --- 
a/ui/sd_internal/__init__.py +++ /dev/null @@ -1,119 +0,0 @@ -import json - -class Request: - request_id: str = None - session_id: str = "session" - prompt: str = "" - negative_prompt: str = "" - init_image: str = None # base64 - mask: str = None # base64 - num_outputs: int = 1 - num_inference_steps: int = 50 - guidance_scale: float = 7.5 - width: int = 512 - height: int = 512 - seed: int = 42 - prompt_strength: float = 0.8 - sampler: str = None # "ddim", "plms", "heun", "euler", "euler_a", "dpm2", "dpm2_a", "lms" - # allow_nsfw: bool = False - precision: str = "autocast" # or "full" - save_to_disk_path: str = None - turbo: bool = True - use_full_precision: bool = False - use_face_correction: str = None # or "GFPGANv1.3" - use_upscale: str = None # or "RealESRGAN_x4plus" or "RealESRGAN_x4plus_anime_6B" - use_stable_diffusion_model: str = "sd-v1-4" - use_vae_model: str = None - use_hypernetwork_model: str = None - hypernetwork_strength: float = 1 - show_only_filtered_image: bool = False - output_format: str = "jpeg" # or "png" - output_quality: int = 75 - - stream_progress_updates: bool = False - stream_image_progress: bool = False - - def json(self): - return { - "session_id": self.session_id, - "prompt": self.prompt, - "negative_prompt": self.negative_prompt, - "num_outputs": self.num_outputs, - "num_inference_steps": self.num_inference_steps, - "guidance_scale": self.guidance_scale, - "hypernetwork_strengtgh": self.guidance_scale, - "width": self.width, - "height": self.height, - "seed": self.seed, - "prompt_strength": self.prompt_strength, - "sampler": self.sampler, - "use_face_correction": self.use_face_correction, - "use_upscale": self.use_upscale, - "use_stable_diffusion_model": self.use_stable_diffusion_model, - "use_vae_model": self.use_vae_model, - "use_hypernetwork_model": self.use_hypernetwork_model, - "hypernetwork_strength": self.hypernetwork_strength, - "output_format": self.output_format, - "output_quality": self.output_quality, - } - - def 
__str__(self): - return f''' - session_id: {self.session_id} - prompt: {self.prompt} - negative_prompt: {self.negative_prompt} - seed: {self.seed} - num_inference_steps: {self.num_inference_steps} - sampler: {self.sampler} - guidance_scale: {self.guidance_scale} - w: {self.width} - h: {self.height} - precision: {self.precision} - save_to_disk_path: {self.save_to_disk_path} - turbo: {self.turbo} - use_full_precision: {self.use_full_precision} - use_face_correction: {self.use_face_correction} - use_upscale: {self.use_upscale} - use_stable_diffusion_model: {self.use_stable_diffusion_model} - use_vae_model: {self.use_vae_model} - use_hypernetwork_model: {self.use_hypernetwork_model} - hypernetwork_strength: {self.hypernetwork_strength} - show_only_filtered_image: {self.show_only_filtered_image} - output_format: {self.output_format} - output_quality: {self.output_quality} - - stream_progress_updates: {self.stream_progress_updates} - stream_image_progress: {self.stream_image_progress}''' - -class Image: - data: str # base64 - seed: int - is_nsfw: bool - path_abs: str = None - - def __init__(self, data, seed): - self.data = data - self.seed = seed - - def json(self): - return { - "data": self.data, - "seed": self.seed, - "path_abs": self.path_abs, - } - -class Response: - request: Request - images: list - - def json(self): - res = { - "status": 'succeeded', - "request": self.request.json(), - "output": [], - } - - for image in self.images: - res["output"].append(image.json()) - - return res diff --git a/ui/sd_internal/ddim_callback.patch b/ui/sd_internal/ddim_callback.patch deleted file mode 100644 index e4dd69e0..00000000 --- a/ui/sd_internal/ddim_callback.patch +++ /dev/null @@ -1,162 +0,0 @@ -diff --git a/optimizedSD/ddpm.py b/optimizedSD/ddpm.py -index 79058bc..a473411 100644 ---- a/optimizedSD/ddpm.py -+++ b/optimizedSD/ddpm.py -@@ -564,12 +564,12 @@ class UNet(DDPM): - unconditional_guidance_scale=unconditional_guidance_scale, - callback=callback, 
img_callback=img_callback) - -+ yield from samples -+ - if(self.turbo): - self.model1.to("cpu") - self.model2.to("cpu") - -- return samples -- - @torch.no_grad() - def plms_sampling(self, cond,b, img, - ddim_use_original_steps=False, -@@ -608,10 +608,10 @@ class UNet(DDPM): - old_eps.append(e_t) - if len(old_eps) >= 4: - old_eps.pop(0) -- if callback: callback(i) -- if img_callback: img_callback(pred_x0, i) -+ if callback: yield from callback(i) -+ if img_callback: yield from img_callback(pred_x0, i) - -- return img -+ yield from img_callback(img, len(iterator)-1) - - @torch.no_grad() - def p_sample_plms(self, x, c, t, index, repeat_noise=False, use_original_steps=False, quantize_denoised=False, -@@ -740,13 +740,13 @@ class UNet(DDPM): - unconditional_guidance_scale=unconditional_guidance_scale, - unconditional_conditioning=unconditional_conditioning) - -- if callback: callback(i) -- if img_callback: img_callback(x_dec, i) -+ if callback: yield from callback(i) -+ if img_callback: yield from img_callback(x_dec, i) - - if mask is not None: -- return x0 * mask + (1. - mask) * x_dec -+ x_dec = x0 * mask + (1. 
- mask) * x_dec - -- return x_dec -+ yield from img_callback(x_dec, len(iterator)-1) - - - @torch.no_grad() -@@ -820,12 +820,12 @@ class UNet(DDPM): - - - d = to_d(x, sigma_hat, denoised) -- if callback: callback(i) -- if img_callback: img_callback(x, i) -+ if callback: yield from callback(i) -+ if img_callback: yield from img_callback(x, i) - dt = sigmas[i + 1] - sigma_hat - # Euler method - x = x + d * dt -- return x -+ yield from img_callback(x, len(sigmas)-1) - - @torch.no_grad() - def euler_ancestral_sampling(self,ac,x, S, cond, unconditional_conditioning = None, unconditional_guidance_scale = 1,extra_args=None, callback=None, disable=None, img_callback=None): -@@ -852,14 +852,14 @@ class UNet(DDPM): - denoised = e_t_uncond + unconditional_guidance_scale * (e_t - e_t_uncond) - - sigma_down, sigma_up = get_ancestral_step(sigmas[i], sigmas[i + 1]) -- if callback: callback(i) -- if img_callback: img_callback(x, i) -+ if callback: yield from callback(i) -+ if img_callback: yield from img_callback(x, i) - d = to_d(x, sigmas[i], denoised) - # Euler method - dt = sigma_down - sigmas[i] - x = x + d * dt - x = x + torch.randn_like(x) * sigma_up -- return x -+ yield from img_callback(x, len(sigmas)-1) - - - -@@ -892,8 +892,8 @@ class UNet(DDPM): - denoised = e_t_uncond + unconditional_guidance_scale * (e_t - e_t_uncond) - - d = to_d(x, sigma_hat, denoised) -- if callback: callback(i) -- if img_callback: img_callback(x, i) -+ if callback: yield from callback(i) -+ if img_callback: yield from img_callback(x, i) - dt = sigmas[i + 1] - sigma_hat - if sigmas[i + 1] == 0: - # Euler method -@@ -913,7 +913,7 @@ class UNet(DDPM): - d_2 = to_d(x_2, sigmas[i + 1], denoised_2) - d_prime = (d + d_2) / 2 - x = x + d_prime * dt -- return x -+ yield from img_callback(x, len(sigmas)-1) - - - @torch.no_grad() -@@ -944,8 +944,8 @@ class UNet(DDPM): - e_t_uncond, e_t = (x_in + eps * c_out).chunk(2) - denoised = e_t_uncond + unconditional_guidance_scale * (e_t - e_t_uncond) - -- if 
callback: callback(i) -- if img_callback: img_callback(x, i) -+ if callback: yield from callback(i) -+ if img_callback: yield from img_callback(x, i) - - d = to_d(x, sigma_hat, denoised) - # Midpoint method, where the midpoint is chosen according to a rho=3 Karras schedule -@@ -966,7 +966,7 @@ class UNet(DDPM): - - d_2 = to_d(x_2, sigma_mid, denoised_2) - x = x + d_2 * dt_2 -- return x -+ yield from img_callback(x, len(sigmas)-1) - - - @torch.no_grad() -@@ -994,8 +994,8 @@ class UNet(DDPM): - - - sigma_down, sigma_up = get_ancestral_step(sigmas[i], sigmas[i + 1]) -- if callback: callback(i) -- if img_callback: img_callback(x, i) -+ if callback: yield from callback(i) -+ if img_callback: yield from img_callback(x, i) - d = to_d(x, sigmas[i], denoised) - # Midpoint method, where the midpoint is chosen according to a rho=3 Karras schedule - sigma_mid = ((sigmas[i] ** (1 / 3) + sigma_down ** (1 / 3)) / 2) ** 3 -@@ -1016,7 +1016,7 @@ class UNet(DDPM): - d_2 = to_d(x_2, sigma_mid, denoised_2) - x = x + d_2 * dt_2 - x = x + torch.randn_like(x) * sigma_up -- return x -+ yield from img_callback(x, len(sigmas)-1) - - - @torch.no_grad() -@@ -1042,8 +1042,8 @@ class UNet(DDPM): - e_t_uncond, e_t = (x_in + eps * c_out).chunk(2) - denoised = e_t_uncond + unconditional_guidance_scale * (e_t - e_t_uncond) - -- if callback: callback(i) -- if img_callback: img_callback(x, i) -+ if callback: yield from callback(i) -+ if img_callback: yield from img_callback(x, i) - - d = to_d(x, sigmas[i], denoised) - ds.append(d) -@@ -1054,4 +1054,4 @@ class UNet(DDPM): - cur_order = min(i + 1, order) - coeffs = [linear_multistep_coeff(cur_order, sigmas.cpu(), i, j) for j in range(cur_order)] - x = x + sum(coeff * d for coeff, d in zip(coeffs, reversed(ds))) -- return x -+ yield from img_callback(x, len(sigmas)-1) diff --git a/ui/sd_internal/ddim_callback_sd2.patch b/ui/sd_internal/ddim_callback_sd2.patch deleted file mode 100644 index cadf81ca..00000000 --- a/ui/sd_internal/ddim_callback_sd2.patch 
+++ /dev/null @@ -1,84 +0,0 @@ -diff --git a/ldm/models/diffusion/ddim.py b/ldm/models/diffusion/ddim.py -index 27ead0e..6215939 100644 ---- a/ldm/models/diffusion/ddim.py -+++ b/ldm/models/diffusion/ddim.py -@@ -100,7 +100,7 @@ class DDIMSampler(object): - size = (batch_size, C, H, W) - print(f'Data shape for DDIM sampling is {size}, eta {eta}') - -- samples, intermediates = self.ddim_sampling(conditioning, size, -+ samples = self.ddim_sampling(conditioning, size, - callback=callback, - img_callback=img_callback, - quantize_denoised=quantize_x0, -@@ -117,7 +117,8 @@ class DDIMSampler(object): - dynamic_threshold=dynamic_threshold, - ucg_schedule=ucg_schedule - ) -- return samples, intermediates -+ # return samples, intermediates -+ yield from samples - - @torch.no_grad() - def ddim_sampling(self, cond, shape, -@@ -168,14 +169,15 @@ class DDIMSampler(object): - unconditional_conditioning=unconditional_conditioning, - dynamic_threshold=dynamic_threshold) - img, pred_x0 = outs -- if callback: callback(i) -- if img_callback: img_callback(pred_x0, i) -+ if callback: yield from callback(i) -+ if img_callback: yield from img_callback(pred_x0, i) - - if index % log_every_t == 0 or index == total_steps - 1: - intermediates['x_inter'].append(img) - intermediates['pred_x0'].append(pred_x0) - -- return img, intermediates -+ # return img, intermediates -+ yield from img_callback(pred_x0, len(iterator)-1) - - @torch.no_grad() - def p_sample_ddim(self, x, c, t, index, repeat_noise=False, use_original_steps=False, quantize_denoised=False, -diff --git a/ldm/models/diffusion/plms.py b/ldm/models/diffusion/plms.py -index 7002a36..0951f39 100644 ---- a/ldm/models/diffusion/plms.py -+++ b/ldm/models/diffusion/plms.py -@@ -96,7 +96,7 @@ class PLMSSampler(object): - size = (batch_size, C, H, W) - print(f'Data shape for PLMS sampling is {size}') - -- samples, intermediates = self.plms_sampling(conditioning, size, -+ samples = self.plms_sampling(conditioning, size, - callback=callback, - 
img_callback=img_callback, - quantize_denoised=quantize_x0, -@@ -112,7 +112,8 @@ class PLMSSampler(object): - unconditional_conditioning=unconditional_conditioning, - dynamic_threshold=dynamic_threshold, - ) -- return samples, intermediates -+ #return samples, intermediates -+ yield from samples - - @torch.no_grad() - def plms_sampling(self, cond, shape, -@@ -165,14 +166,15 @@ class PLMSSampler(object): - old_eps.append(e_t) - if len(old_eps) >= 4: - old_eps.pop(0) -- if callback: callback(i) -- if img_callback: img_callback(pred_x0, i) -+ if callback: yield from callback(i) -+ if img_callback: yield from img_callback(pred_x0, i) - - if index % log_every_t == 0 or index == total_steps - 1: - intermediates['x_inter'].append(img) - intermediates['pred_x0'].append(pred_x0) - -- return img, intermediates -+ # return img, intermediates -+ yield from img_callback(pred_x0, len(iterator)-1) - - @torch.no_grad() - def p_sample_plms(self, x, c, t, index, repeat_noise=False, use_original_steps=False, quantize_denoised=False, diff --git a/ui/sd_internal/hypernetwork.py b/ui/sd_internal/hypernetwork.py deleted file mode 100644 index 979a74f3..00000000 --- a/ui/sd_internal/hypernetwork.py +++ /dev/null @@ -1,198 +0,0 @@ -# this is basically a cut down version of https://github.com/AUTOMATIC1111/stable-diffusion-webui/blob/c9a2cfdf2a53d37c2de1908423e4f548088667ef/modules/hypernetworks/hypernetwork.py, mostly for feature parity -# I, c0bra5, don't really understand how deep learning works. I just know how to port stuff. - -import inspect -import torch -import optimizedSD.splitAttention -from . 
import runtime -from einops import rearrange - -optimizer_dict = {optim_name : cls_obj for optim_name, cls_obj in inspect.getmembers(torch.optim, inspect.isclass) if optim_name != "Optimizer"} - -loaded_hypernetwork = None - -class HypernetworkModule(torch.nn.Module): - multiplier = 0.5 - activation_dict = { - "linear": torch.nn.Identity, - "relu": torch.nn.ReLU, - "leakyrelu": torch.nn.LeakyReLU, - "elu": torch.nn.ELU, - "swish": torch.nn.Hardswish, - "tanh": torch.nn.Tanh, - "sigmoid": torch.nn.Sigmoid, - } - activation_dict.update({cls_name.lower(): cls_obj for cls_name, cls_obj in inspect.getmembers(torch.nn.modules.activation) if inspect.isclass(cls_obj) and cls_obj.__module__ == 'torch.nn.modules.activation'}) - - def __init__(self, dim, state_dict=None, layer_structure=None, activation_func=None, weight_init='Normal', - add_layer_norm=False, use_dropout=False, activate_output=False, last_layer_dropout=False): - super().__init__() - - assert layer_structure is not None, "layer_structure must not be None" - assert layer_structure[0] == 1, "Multiplier Sequence should start with size 1!" - assert layer_structure[-1] == 1, "Multiplier Sequence should end with size 1!" 
- - linears = [] - for i in range(len(layer_structure) - 1): - - # Add a fully-connected layer - linears.append(torch.nn.Linear(int(dim * layer_structure[i]), int(dim * layer_structure[i+1]))) - - # Add an activation func except last layer - if activation_func == "linear" or activation_func is None or (i >= len(layer_structure) - 2 and not activate_output): - pass - elif activation_func in self.activation_dict: - linears.append(self.activation_dict[activation_func]()) - else: - raise RuntimeError(f'hypernetwork uses an unsupported activation function: {activation_func}') - - # Add layer normalization - if add_layer_norm: - linears.append(torch.nn.LayerNorm(int(dim * layer_structure[i+1]))) - - # Add dropout except last layer - if use_dropout and (i < len(layer_structure) - 3 or last_layer_dropout and i < len(layer_structure) - 2): - linears.append(torch.nn.Dropout(p=0.3)) - - self.linear = torch.nn.Sequential(*linears) - - self.fix_old_state_dict(state_dict) - self.load_state_dict(state_dict) - - self.to(runtime.thread_data.device) - - def fix_old_state_dict(self, state_dict): - changes = { - 'linear1.bias': 'linear.0.bias', - 'linear1.weight': 'linear.0.weight', - 'linear2.bias': 'linear.1.bias', - 'linear2.weight': 'linear.1.weight', - } - - for fr, to in changes.items(): - x = state_dict.get(fr, None) - if x is None: - continue - - del state_dict[fr] - state_dict[to] = x - - def forward(self, x: torch.Tensor): - return x + self.linear(x) * runtime.thread_data.hypernetwork_strength - -def apply_hypernetwork(hypernetwork, context, layer=None): - hypernetwork_layers = hypernetwork.get(context.shape[2], None) - - if hypernetwork_layers is None: - return context, context - - if layer is not None: - layer.hyper_k = hypernetwork_layers[0] - layer.hyper_v = hypernetwork_layers[1] - - context_k = hypernetwork_layers[0](context) - context_v = hypernetwork_layers[1](context) - return context_k, context_v - -def get_kv(context, hypernetwork): - if hypernetwork is None: - 
return context, context - else: - return apply_hypernetwork(runtime.thread_data.hypernetwork, context) - -# This might need updating as the optimisedSD code changes -# I think yall have a system for this (patch files in sd_internal) but idk how it works and no amount of searching gave me any clue -# just in case for attribution https://github.com/easydiffusion/diffusion-kit/blob/e8ea0cadd543056059cd951e76d4744de76327d2/optimizedSD/splitAttention.py#L171 -def new_cross_attention_forward(self, x, context=None, mask=None): - h = self.heads - - q = self.to_q(x) - # default context - context = context if context is not None else x() if inspect.isfunction(x) else x - # hypernetwork! - context_k, context_v = get_kv(context, runtime.thread_data.hypernetwork) - k = self.to_k(context_k) - v = self.to_v(context_v) - del context, x - - q, k, v = map(lambda t: rearrange(t, 'b n (h d) -> (b h) n d', h=h), (q, k, v)) - - - limit = k.shape[0] - att_step = self.att_step - q_chunks = list(torch.tensor_split(q, limit//att_step, dim=0)) - k_chunks = list(torch.tensor_split(k, limit//att_step, dim=0)) - v_chunks = list(torch.tensor_split(v, limit//att_step, dim=0)) - - q_chunks.reverse() - k_chunks.reverse() - v_chunks.reverse() - sim = torch.zeros(q.shape[0], q.shape[1], v.shape[2], device=q.device) - del k, q, v - for i in range (0, limit, att_step): - - q_buffer = q_chunks.pop() - k_buffer = k_chunks.pop() - v_buffer = v_chunks.pop() - sim_buffer = torch.einsum('b i d, b j d -> b i j', q_buffer, k_buffer) * self.scale - - del k_buffer, q_buffer - # attention, what we cannot get enough of, by chunks - - sim_buffer = sim_buffer.softmax(dim=-1) - - sim_buffer = torch.einsum('b i j, b j d -> b i d', sim_buffer, v_buffer) - del v_buffer - sim[i:i+att_step,:,:] = sim_buffer - - del sim_buffer - sim = rearrange(sim, '(b h) n d -> b n (h d)', h=h) - return self.to_out(sim) - - -def load_hypernetwork(path: str): - - state_dict = torch.load(path, map_location='cpu') - - layer_structure = 
state_dict.get('layer_structure', [1, 2, 1]) - activation_func = state_dict.get('activation_func', None) - weight_init = state_dict.get('weight_initialization', 'Normal') - add_layer_norm = state_dict.get('is_layer_norm', False) - use_dropout = state_dict.get('use_dropout', False) - activate_output = state_dict.get('activate_output', True) - last_layer_dropout = state_dict.get('last_layer_dropout', False) - # this is a bit verbose so leaving it commented out for the poor soul who ever has to debug this - # print(f"layer_structure: {layer_structure}") - # print(f"activation_func: {activation_func}") - # print(f"weight_init: {weight_init}") - # print(f"add_layer_norm: {add_layer_norm}") - # print(f"use_dropout: {use_dropout}") - # print(f"activate_output: {activate_output}") - # print(f"last_layer_dropout: {last_layer_dropout}") - - layers = {} - for size, sd in state_dict.items(): - if type(size) == int: - layers[size] = ( - HypernetworkModule(size, sd[0], layer_structure, activation_func, weight_init, add_layer_norm, - use_dropout, activate_output, last_layer_dropout=last_layer_dropout), - HypernetworkModule(size, sd[1], layer_structure, activation_func, weight_init, add_layer_norm, - use_dropout, activate_output, last_layer_dropout=last_layer_dropout), - ) - print(f"hypernetwork loaded") - return layers - - - -# overriding of original function -old_cross_attention_forward = optimizedSD.splitAttention.CrossAttention.forward -# hijacks the cross attention forward function to add hyper network support -def hijack_cross_attention(): - print("hypernetwork functionality added to cross attention") - optimizedSD.splitAttention.CrossAttention.forward = new_cross_attention_forward -# there was a cop on board -def unhijack_cross_attention_forward(): - print("hypernetwork functionality removed from cross attention") - optimizedSD.splitAttention.CrossAttention.forward = old_cross_attention_forward - -hijack_cross_attention() \ No newline at end of file diff --git 
a/ui/sd_internal/runtime.py b/ui/sd_internal/runtime.py deleted file mode 100644 index 2bf53d0c..00000000 --- a/ui/sd_internal/runtime.py +++ /dev/null @@ -1,1078 +0,0 @@ -"""runtime.py: torch device owned by a thread. -Notes: - Avoid device switching, transfering all models will get too complex. - To use a diffrent device signal the current render device to exit - And then start a new clean thread for the new device. -""" -import json -import os, re -import traceback -import queue -import torch -import numpy as np -from gc import collect as gc_collect -from omegaconf import OmegaConf -from PIL import Image, ImageOps -from tqdm import tqdm, trange -from itertools import islice -from einops import rearrange -import time -from pytorch_lightning import seed_everything -from torch import autocast -from contextlib import nullcontext -from einops import rearrange, repeat -from ldm.util import instantiate_from_config -from transformers import logging - -from gfpgan import GFPGANer -from basicsr.archs.rrdbnet_arch import RRDBNet -from realesrgan import RealESRGANer - -from server import HYPERNETWORK_MODEL_EXTENSIONS# , STABLE_DIFFUSION_MODEL_EXTENSIONS, VAE_MODEL_EXTENSIONS - -from threading import Lock -from safetensors.torch import load_file - -import uuid - -logging.set_verbosity_error() - -# consts -config_yaml = "optimizedSD/v1-inference.yaml" -filename_regex = re.compile('[^a-zA-Z0-9]') -gfpgan_temp_device_lock = Lock() # workaround: gfpgan currently can only start on one device at a time. - -# api stuff -from sd_internal import device_manager -from . 
import Request, Response, Image as ResponseImage -import base64 -from io import BytesIO -#from colorama import Fore - -from threading import local as LocalThreadVars -thread_data = LocalThreadVars() - -def thread_init(device): - # Thread bound properties - thread_data.stop_processing = False - thread_data.temp_images = {} - - thread_data.ckpt_file = None - thread_data.vae_file = None - thread_data.hypernetwork_file = None - thread_data.gfpgan_file = None - thread_data.real_esrgan_file = None - - thread_data.model = None - thread_data.modelCS = None - thread_data.modelFS = None - thread_data.hypernetwork = None - thread_data.hypernetwork_strength = 1 - thread_data.model_gfpgan = None - thread_data.model_real_esrgan = None - - thread_data.model_is_half = False - thread_data.model_fs_is_half = False - thread_data.device = None - thread_data.device_name = None - thread_data.unet_bs = 1 - thread_data.precision = 'autocast' - thread_data.sampler_plms = None - thread_data.sampler_ddim = None - - thread_data.turbo = False - thread_data.force_full_precision = False - thread_data.reduced_memory = True - - thread_data.test_sd2 = isSD2() - - device_manager.device_init(thread_data, device) - -# temp hack, will remove soon -def isSD2(): - try: - SD_UI_DIR = os.getenv('SD_UI_PATH', None) - CONFIG_DIR = os.path.abspath(os.path.join(SD_UI_DIR, '..', 'scripts')) - config_json_path = os.path.join(CONFIG_DIR, 'config.json') - if not os.path.exists(config_json_path): - return False - with open(config_json_path, 'r', encoding='utf-8') as f: - config = json.load(f) - return config.get('test_sd2', False) - except Exception as e: - return False - -def load_model_ckpt(): - if not thread_data.ckpt_file: raise ValueError(f'Thread ckpt_file is undefined.') - if os.path.exists(thread_data.ckpt_file + '.ckpt'): - thread_data.ckpt_file += '.ckpt' - elif os.path.exists(thread_data.ckpt_file + '.safetensors'): - thread_data.ckpt_file += '.safetensors' - elif not 
os.path.exists(thread_data.ckpt_file): - raise FileNotFoundError(f'Cannot find {thread_data.ckpt_file}.ckpt or .safetensors') - - if not thread_data.precision: - thread_data.precision = 'full' if thread_data.force_full_precision else 'autocast' - - if not thread_data.unet_bs: - thread_data.unet_bs = 1 - - if thread_data.device == 'cpu': - thread_data.precision = 'full' - - print('loading', thread_data.ckpt_file, 'to device', thread_data.device, 'using precision', thread_data.precision) - - if thread_data.test_sd2: - load_model_ckpt_sd2() - else: - load_model_ckpt_sd1() - -def load_model_ckpt_sd1(): - sd, model_ver = load_model_from_config(thread_data.ckpt_file) - li, lo = [], [] - for key, value in sd.items(): - sp = key.split(".") - if (sp[0]) == "model": - if "input_blocks" in sp: - li.append(key) - elif "middle_block" in sp: - li.append(key) - elif "time_embed" in sp: - li.append(key) - else: - lo.append(key) - for key in li: - sd["model1." + key[6:]] = sd.pop(key) - for key in lo: - sd["model2." + key[6:]] = sd.pop(key) - - config = OmegaConf.load(f"{config_yaml}") - - model = instantiate_from_config(config.modelUNet) - _, _ = model.load_state_dict(sd, strict=False) - model.eval() - model.cdevice = torch.device(thread_data.device) - model.unet_bs = thread_data.unet_bs - model.turbo = thread_data.turbo - # if thread_data.device != 'cpu': - # model.to(thread_data.device) - #if thread_data.reduced_memory: - #model.model1.to("cpu") - #model.model2.to("cpu") - thread_data.model = model - - modelCS = instantiate_from_config(config.modelCondStage) - _, _ = modelCS.load_state_dict(sd, strict=False) - modelCS.eval() - modelCS.cond_stage_model.device = torch.device(thread_data.device) - # if thread_data.device != 'cpu': - # if thread_data.reduced_memory: - # modelCS.to('cpu') - # else: - # modelCS.to(thread_data.device) # Preload on device if not already there. 
- thread_data.modelCS = modelCS - - modelFS = instantiate_from_config(config.modelFirstStage) - _, _ = modelFS.load_state_dict(sd, strict=False) - - if thread_data.vae_file is not None: - try: - loaded = False - for model_extension in ['.ckpt', '.vae.pt']: - if os.path.exists(thread_data.vae_file + model_extension): - print(f"Loading VAE weights from: {thread_data.vae_file}{model_extension}") - vae_ckpt = torch.load(thread_data.vae_file + model_extension, map_location="cpu") - vae_dict = {k: v for k, v in vae_ckpt["state_dict"].items() if k[0:4] != "loss"} - modelFS.first_stage_model.load_state_dict(vae_dict, strict=False) - loaded = True - break - - if not loaded: - print(f'Cannot find VAE: {thread_data.vae_file}') - thread_data.vae_file = None - except: - print(traceback.format_exc()) - print(f'Could not load VAE: {thread_data.vae_file}') - thread_data.vae_file = None - - modelFS.eval() - # if thread_data.device != 'cpu': - # if thread_data.reduced_memory: - # modelFS.to('cpu') - # else: - # modelFS.to(thread_data.device) # Preload on device if not already there. 
- thread_data.modelFS = modelFS - del sd - - if thread_data.device != "cpu" and thread_data.precision == "autocast": - thread_data.model.half() - thread_data.modelCS.half() - thread_data.modelFS.half() - thread_data.model_is_half = True - thread_data.model_fs_is_half = True - else: - thread_data.model_is_half = False - thread_data.model_fs_is_half = False - - print(f'''loaded model - model file: {thread_data.ckpt_file} - model.device: {model.device} - modelCS.device: {modelCS.cond_stage_model.device} - modelFS.device: {thread_data.modelFS.device} - using precision: {thread_data.precision}''') - -def load_model_ckpt_sd2(): - sd, model_ver = load_model_from_config(thread_data.ckpt_file) - - config_file = 'configs/stable-diffusion/v2-inference-v.yaml' if model_ver == 'sd2' else "configs/stable-diffusion/v1-inference.yaml" - config = OmegaConf.load(config_file) - verbose = False - - thread_data.model = instantiate_from_config(config.model) - m, u = thread_data.model.load_state_dict(sd, strict=False) - if len(m) > 0 and verbose: - print("missing keys:") - print(m) - if len(u) > 0 and verbose: - print("unexpected keys:") - print(u) - - thread_data.model.to(thread_data.device) - thread_data.model.eval() - del sd - - thread_data.model.cond_stage_model.device = torch.device(thread_data.device) - - if thread_data.device != "cpu" and thread_data.precision == "autocast": - thread_data.model.half() - thread_data.model_is_half = True - thread_data.model_fs_is_half = True - else: - thread_data.model_is_half = False - thread_data.model_fs_is_half = False - - print(f'''loaded model - model file: {thread_data.ckpt_file} - using precision: {thread_data.precision}''') - -def unload_filters(): - if thread_data.model_gfpgan is not None: - if thread_data.device != 'cpu': thread_data.model_gfpgan.gfpgan.to('cpu') - - del thread_data.model_gfpgan - thread_data.model_gfpgan = None - - if thread_data.model_real_esrgan is not None: - if thread_data.device != 'cpu': 
thread_data.model_real_esrgan.model.to('cpu') - - del thread_data.model_real_esrgan - thread_data.model_real_esrgan = None - - gc() - -def unload_models(): - if thread_data.model is not None: - print('Unloading models...') - if thread_data.device != 'cpu': - if not thread_data.test_sd2: - thread_data.modelFS.to('cpu') - thread_data.modelCS.to('cpu') - thread_data.model.model1.to("cpu") - thread_data.model.model2.to("cpu") - - del thread_data.model - del thread_data.modelCS - del thread_data.modelFS - - thread_data.model = None - thread_data.modelCS = None - thread_data.modelFS = None - - gc() - -# def wait_model_move_to(model, target_device): # Send to target_device and wait until complete. -# if thread_data.device == target_device: return -# start_mem = torch.cuda.memory_allocated(thread_data.device) / 1e6 -# if start_mem <= 0: return -# model_name = model.__class__.__name__ -# print(f'Device {thread_data.device} - Sending model {model_name} to {target_device} | Memory transfer starting. Memory Used: {round(start_mem)}Mb') -# start_time = time.time() -# model.to(target_device) -# time_step = start_time -# WARNING_TIMEOUT = 1.5 # seconds - Show activity in console after timeout. -# last_mem = start_mem -# is_transfering = True -# while is_transfering: -# time.sleep(0.5) # 500ms -# mem = torch.cuda.memory_allocated(thread_data.device) / 1e6 -# is_transfering = bool(mem > 0 and mem < last_mem) # still stuff loaded, but less than last time. -# last_mem = mem -# if not is_transfering: -# break; -# if time.time() - time_step > WARNING_TIMEOUT: # Long delay, print to console to show activity. -# print(f'Device {thread_data.device} - Waiting for Memory transfer. 
Memory Used: {round(mem)}Mb, Transfered: {round(start_mem - mem)}Mb') -# time_step = time.time() -# print(f'Device {thread_data.device} - {model_name} Moved: {round(start_mem - last_mem)}Mb in {round(time.time() - start_time, 3)} seconds to {target_device}') - -def move_to_cpu(model): - if thread_data.device != "cpu": - d = torch.device(thread_data.device) - mem = torch.cuda.memory_allocated(d) / 1e6 - model.to("cpu") - while torch.cuda.memory_allocated(d) / 1e6 >= mem: - time.sleep(1) - -def load_model_gfpgan(): - if thread_data.gfpgan_file is None: raise ValueError(f'Thread gfpgan_file is undefined.') - model_path = thread_data.gfpgan_file + ".pth" - thread_data.model_gfpgan = GFPGANer(device=torch.device(thread_data.device), model_path=model_path, upscale=1, arch='clean', channel_multiplier=2, bg_upsampler=None) - print('loaded', thread_data.gfpgan_file, 'to', thread_data.model_gfpgan.device, 'precision', thread_data.precision) - -def load_model_real_esrgan(): - if thread_data.real_esrgan_file is None: raise ValueError(f'Thread real_esrgan_file is undefined.') - model_path = thread_data.real_esrgan_file + ".pth" - - RealESRGAN_models = { - 'RealESRGAN_x4plus': RRDBNet(num_in_ch=3, num_out_ch=3, num_feat=64, num_block=23, num_grow_ch=32, scale=4), - 'RealESRGAN_x4plus_anime_6B': RRDBNet(num_in_ch=3, num_out_ch=3, num_feat=64, num_block=6, num_grow_ch=32, scale=4) - } - - model_to_use = RealESRGAN_models[thread_data.real_esrgan_file] - - if thread_data.device == 'cpu': - thread_data.model_real_esrgan = RealESRGANer(device=torch.device(thread_data.device), scale=2, model_path=model_path, model=model_to_use, pre_pad=0, half=False) # cpu does not support half - #thread_data.model_real_esrgan.device = torch.device(thread_data.device) - thread_data.model_real_esrgan.model.to('cpu') - else: - thread_data.model_real_esrgan = RealESRGANer(device=torch.device(thread_data.device), scale=2, model_path=model_path, model=model_to_use, pre_pad=0, half=thread_data.model_is_half) 
- - thread_data.model_real_esrgan.model.name = thread_data.real_esrgan_file - print('loaded ', thread_data.real_esrgan_file, 'to', thread_data.model_real_esrgan.device, 'precision', thread_data.precision) - - -def get_session_out_path(disk_path, session_id): - if disk_path is None: return None - if session_id is None: return None - - session_out_path = os.path.join(disk_path, filename_regex.sub('_',session_id)) - os.makedirs(session_out_path, exist_ok=True) - return session_out_path - -def get_base_path(disk_path, session_id, prompt, img_id, ext, suffix=None): - if disk_path is None: return None - if session_id is None: return None - if ext is None: raise Exception('Missing ext') - - session_out_path = get_session_out_path(disk_path, session_id) - - prompt_flattened = filename_regex.sub('_', prompt)[:50] - - if suffix is not None: - return os.path.join(session_out_path, f"{prompt_flattened}_{img_id}_{suffix}.{ext}") - return os.path.join(session_out_path, f"{prompt_flattened}_{img_id}.{ext}") - -def apply_filters(filter_name, image_data, model_path=None): - print(f'Applying filter {filter_name}...') - gc() # Free space before loading new data. - - if isinstance(image_data, torch.Tensor): - image_data.to(thread_data.device) - - if filter_name == 'gfpgan': - # This lock is only ever used here. No need to use timeout for the request. Should never deadlock. - with gfpgan_temp_device_lock: # Wait for any other devices to complete before starting. 
- # hack for a bug in facexlib: https://github.com/xinntao/facexlib/pull/19/files - from facexlib.detection import retinaface - retinaface.device = torch.device(thread_data.device) - print('forced retinaface.device to', thread_data.device) - - if model_path is not None and model_path != thread_data.gfpgan_file: - thread_data.gfpgan_file = model_path - load_model_gfpgan() - elif not thread_data.model_gfpgan: - load_model_gfpgan() - if thread_data.model_gfpgan is None: raise Exception('Model "gfpgan" not loaded.') - - print('enhance with', thread_data.gfpgan_file, 'on', thread_data.model_gfpgan.device, 'precision', thread_data.precision) - _, _, output = thread_data.model_gfpgan.enhance(image_data[:,:,::-1], has_aligned=False, only_center_face=False, paste_back=True) - image_data = output[:,:,::-1] - - if filter_name == 'real_esrgan': - if model_path is not None and model_path != thread_data.real_esrgan_file: - thread_data.real_esrgan_file = model_path - load_model_real_esrgan() - elif not thread_data.model_real_esrgan: - load_model_real_esrgan() - if thread_data.model_real_esrgan is None: raise Exception('Model "gfpgan" not loaded.') - print('enhance with', thread_data.real_esrgan_file, 'on', thread_data.model_real_esrgan.device, 'precision', thread_data.precision) - output, _ = thread_data.model_real_esrgan.enhance(image_data[:,:,::-1]) - image_data = output[:,:,::-1] - - return image_data - -def is_model_reload_necessary(req: Request): - # custom model support: - # the req.use_stable_diffusion_model needs to be a valid path - # to the ckpt file (without the extension). 
- if os.path.exists(req.use_stable_diffusion_model + '.ckpt'): - req.use_stable_diffusion_model += '.ckpt' - elif os.path.exists(req.use_stable_diffusion_model + '.safetensors'): - req.use_stable_diffusion_model += '.safetensors' - elif not os.path.exists(req.use_stable_diffusion_model): - raise FileNotFoundError(f'Cannot find {req.use_stable_diffusion_model}.ckpt or .safetensors') - - needs_model_reload = False - if not thread_data.model or thread_data.ckpt_file != req.use_stable_diffusion_model or thread_data.vae_file != req.use_vae_model: - thread_data.ckpt_file = req.use_stable_diffusion_model - thread_data.vae_file = req.use_vae_model - needs_model_reload = True - - if thread_data.device != 'cpu': - if (thread_data.precision == 'autocast' and (req.use_full_precision or not thread_data.model_is_half)) or \ - (thread_data.precision == 'full' and not req.use_full_precision and not thread_data.force_full_precision): - thread_data.precision = 'full' if req.use_full_precision else 'autocast' - needs_model_reload = True - - return needs_model_reload - -def reload_model(): - unload_models() - unload_filters() - load_model_ckpt() - -def is_hypernetwork_reload_necessary(req: Request): - needs_model_reload = False - if thread_data.hypernetwork_file != req.use_hypernetwork_model: - thread_data.hypernetwork_file = req.use_hypernetwork_model - needs_model_reload = True - - return needs_model_reload - -def load_hypernetwork(): - if thread_data.test_sd2: - # Not yet supported in SD2 - return - - from . 
import hypernetwork - if thread_data.hypernetwork_file is not None: - try: - loaded = False - for model_extension in HYPERNETWORK_MODEL_EXTENSIONS: - if os.path.exists(thread_data.hypernetwork_file + model_extension): - print(f"Loading hypernetwork weights from: {thread_data.hypernetwork_file}{model_extension}") - thread_data.hypernetwork = hypernetwork.load_hypernetwork(thread_data.hypernetwork_file + model_extension) - loaded = True - break - - if not loaded: - print(f'Cannot find hypernetwork: {thread_data.hypernetwork_file}') - thread_data.hypernetwork_file = None - except: - print(traceback.format_exc()) - print(f'Could not load hypernetwork: {thread_data.hypernetwork_file}') - thread_data.hypernetwork_file = None - -def unload_hypernetwork(): - if thread_data.hypernetwork is not None: - print('Unloading hypernetwork...') - if thread_data.device != 'cpu': - for i in thread_data.hypernetwork: - thread_data.hypernetwork[i][0].to('cpu') - thread_data.hypernetwork[i][1].to('cpu') - del thread_data.hypernetwork - thread_data.hypernetwork = None - - gc() - -def reload_hypernetwork(): - unload_hypernetwork() - load_hypernetwork() - -def mk_img(req: Request, data_queue: queue.Queue, task_temp_images: list, step_callback): - try: - return do_mk_img(req, data_queue, task_temp_images, step_callback) - except Exception as e: - print(traceback.format_exc()) - - if thread_data.device != 'cpu' and not thread_data.test_sd2: - thread_data.modelFS.to('cpu') - thread_data.modelCS.to('cpu') - thread_data.model.model1.to("cpu") - thread_data.model.model2.to("cpu") - - gc() # Release from memory. 
- data_queue.put(json.dumps({ - "status": 'failed', - "detail": str(e) - })) - raise e - -def update_temp_img(req, x_samples, task_temp_images: list): - partial_images = [] - for i in range(req.num_outputs): - if thread_data.test_sd2: - x_sample_ddim = thread_data.model.decode_first_stage(x_samples[i].unsqueeze(0)) - else: - x_sample_ddim = thread_data.modelFS.decode_first_stage(x_samples[i].unsqueeze(0)) - x_sample = torch.clamp((x_sample_ddim + 1.0) / 2.0, min=0.0, max=1.0) - x_sample = 255.0 * rearrange(x_sample[0].cpu().numpy(), "c h w -> h w c") - x_sample = x_sample.astype(np.uint8) - img = Image.fromarray(x_sample) - buf = img_to_buffer(img, output_format='JPEG') - - del img, x_sample, x_sample_ddim - # don't delete x_samples, it is used in the code that called this callback - - thread_data.temp_images[f'{req.request_id}/{i}'] = buf - task_temp_images[i] = buf - partial_images.append({'path': f'/image/tmp/{req.request_id}/{i}'}) - return partial_images - -# Build and return the apropriate generator for do_mk_img -def get_image_progress_generator(req, data_queue: queue.Queue, task_temp_images: list, step_callback, extra_props=None): - if not req.stream_progress_updates: - def empty_callback(x_samples, i): - step_callback() - return empty_callback - - thread_data.partial_x_samples = None - last_callback_time = -1 - def img_callback(x_samples, i): - nonlocal last_callback_time - - thread_data.partial_x_samples = x_samples - step_time = time.time() - last_callback_time if last_callback_time != -1 else -1 - last_callback_time = time.time() - - progress = {"step": i, "step_time": step_time} - if extra_props is not None: - progress.update(extra_props) - - if req.stream_image_progress and i % 5 == 0: - progress['output'] = update_temp_img(req, x_samples, task_temp_images) - - data_queue.put(json.dumps(progress)) - - step_callback() - - if thread_data.stop_processing: - raise UserInitiatedStop("User requested that we stop processing") - return img_callback - -def 
do_mk_img(req: Request, data_queue: queue.Queue, task_temp_images: list, step_callback): - thread_data.stop_processing = False - - res = Response() - res.request = req - res.images = [] - thread_data.hypernetwork_strength = req.hypernetwork_strength - - thread_data.temp_images.clear() - - if thread_data.turbo != req.turbo and not thread_data.test_sd2: - thread_data.turbo = req.turbo - thread_data.model.turbo = req.turbo - - # Start by cleaning memory, loading and unloading things can leave memory allocated. - gc() - - opt_prompt = req.prompt - opt_seed = req.seed - opt_n_iter = 1 - opt_C = 4 - opt_f = 8 - opt_ddim_eta = 0.0 - - print(req, '\n device', torch.device(thread_data.device), "as", thread_data.device_name) - print('\n\n Using precision:', thread_data.precision) - - seed_everything(opt_seed) - - batch_size = req.num_outputs - prompt = opt_prompt - assert prompt is not None - data = [batch_size * [prompt]] - - if thread_data.precision == "autocast" and thread_data.device != "cpu": - precision_scope = autocast - else: - precision_scope = nullcontext - - mask = None - - if req.init_image is None: - handler = _txt2img - - init_latent = None - t_enc = None - else: - handler = _img2img - - init_image = load_img(req.init_image, req.width, req.height) - init_image = init_image.to(thread_data.device) - - if thread_data.device != "cpu" and thread_data.precision == "autocast": - init_image = init_image.half() - - if not thread_data.test_sd2: - thread_data.modelFS.to(thread_data.device) - - init_image = repeat(init_image, '1 ... 
-> b ...', b=batch_size) - if thread_data.test_sd2: - init_latent = thread_data.model.get_first_stage_encoding(thread_data.model.encode_first_stage(init_image)) # move to latent space - else: - init_latent = thread_data.modelFS.get_first_stage_encoding(thread_data.modelFS.encode_first_stage(init_image)) # move to latent space - - if req.mask is not None: - mask = load_mask(req.mask, req.width, req.height, init_latent.shape[2], init_latent.shape[3], True).to(thread_data.device) - mask = mask[0][0].unsqueeze(0).repeat(4, 1, 1).unsqueeze(0) - mask = repeat(mask, '1 ... -> b ...', b=batch_size) - - if thread_data.device != "cpu" and thread_data.precision == "autocast": - mask = mask.half() - - # Send to CPU and wait until complete. - # wait_model_move_to(thread_data.modelFS, 'cpu') - if not thread_data.test_sd2: - move_to_cpu(thread_data.modelFS) - - assert 0. <= req.prompt_strength <= 1., 'can only work with strength in [0.0, 1.0]' - t_enc = int(req.prompt_strength * req.num_inference_steps) - print(f"target t_enc is {t_enc} steps") - - with torch.no_grad(): - for n in trange(opt_n_iter, desc="Sampling"): - for prompts in tqdm(data, desc="data"): - - with precision_scope("cuda"): - if thread_data.reduced_memory and not thread_data.test_sd2: - thread_data.modelCS.to(thread_data.device) - uc = None - if req.guidance_scale != 1.0: - if thread_data.test_sd2: - uc = thread_data.model.get_learned_conditioning(batch_size * [req.negative_prompt]) - else: - uc = thread_data.modelCS.get_learned_conditioning(batch_size * [req.negative_prompt]) - if isinstance(prompts, tuple): - prompts = list(prompts) - - subprompts, weights = split_weighted_subprompts(prompts[0]) - if len(subprompts) > 1: - c = torch.zeros_like(uc) - totalWeight = sum(weights) - # normalize each "sub prompt" and add it - for i in range(len(subprompts)): - weight = weights[i] - # if not skip_normalize: - weight = weight / totalWeight - if thread_data.test_sd2: - c = torch.add(c, 
thread_data.model.get_learned_conditioning(subprompts[i]), alpha=weight) - else: - c = torch.add(c, thread_data.modelCS.get_learned_conditioning(subprompts[i]), alpha=weight) - else: - if thread_data.test_sd2: - c = thread_data.model.get_learned_conditioning(prompts) - else: - c = thread_data.modelCS.get_learned_conditioning(prompts) - - if thread_data.reduced_memory and not thread_data.test_sd2: - thread_data.modelFS.to(thread_data.device) - - n_steps = req.num_inference_steps if req.init_image is None else t_enc - img_callback = get_image_progress_generator(req, data_queue, task_temp_images, step_callback, {"total_steps": n_steps}) - - # run the handler - try: - print('Running handler...') - if handler == _txt2img: - x_samples = _txt2img(req.width, req.height, req.num_outputs, req.num_inference_steps, req.guidance_scale, None, opt_C, opt_f, opt_ddim_eta, c, uc, opt_seed, img_callback, mask, req.sampler) - else: - x_samples = _img2img(init_latent, t_enc, batch_size, req.guidance_scale, c, uc, req.num_inference_steps, opt_ddim_eta, opt_seed, img_callback, mask, opt_C, req.height, req.width, opt_f) - except UserInitiatedStop: - if not hasattr(thread_data, 'partial_x_samples'): - continue - if thread_data.partial_x_samples is None: - del thread_data.partial_x_samples - continue - x_samples = thread_data.partial_x_samples - del thread_data.partial_x_samples - - print("decoding images") - img_data = [None] * batch_size - for i in range(batch_size): - if thread_data.test_sd2: - x_samples_ddim = thread_data.model.decode_first_stage(x_samples[i].unsqueeze(0)) - else: - x_samples_ddim = thread_data.modelFS.decode_first_stage(x_samples[i].unsqueeze(0)) - x_sample = torch.clamp((x_samples_ddim + 1.0) / 2.0, min=0.0, max=1.0) - x_sample = 255.0 * rearrange(x_sample[0].cpu().numpy(), "c h w -> h w c") - x_sample = x_sample.astype(np.uint8) - img_data[i] = x_sample - del x_samples, x_samples_ddim, x_sample - - print("saving images") - for i in range(batch_size): - img = 
Image.fromarray(img_data[i]) - img_id = base64.b64encode(int(time.time()+i).to_bytes(8, 'big')).decode() # Generate unique ID based on time. - img_id = img_id.translate({43:None, 47:None, 61:None})[-8:] # Remove + / = and keep last 8 chars. - - has_filters = (req.use_face_correction is not None and req.use_face_correction.startswith('GFPGAN')) or \ - (req.use_upscale is not None and req.use_upscale.startswith('RealESRGAN')) - - return_orig_img = not has_filters or not req.show_only_filtered_image - - if thread_data.stop_processing: - return_orig_img = True - - if req.save_to_disk_path is not None: - if return_orig_img: - img_out_path = get_base_path(req.save_to_disk_path, req.session_id, prompts[0], img_id, req.output_format) - save_image(img, img_out_path, req.output_format, req.output_quality) - meta_out_path = get_base_path(req.save_to_disk_path, req.session_id, prompts[0], img_id, 'txt') - save_metadata(meta_out_path, req, prompts[0], opt_seed) - - if return_orig_img: - img_buffer = img_to_buffer(img, req.output_format, req.output_quality) - img_str = buffer_to_base64_str(img_buffer, req.output_format) - res_image_orig = ResponseImage(data=img_str, seed=opt_seed) - res.images.append(res_image_orig) - task_temp_images[i] = img_buffer - - if req.save_to_disk_path is not None: - res_image_orig.path_abs = img_out_path - del img - - if has_filters and not thread_data.stop_processing: - filters_applied = [] - if req.use_face_correction: - img_data[i] = apply_filters('gfpgan', img_data[i], req.use_face_correction) - filters_applied.append(req.use_face_correction) - if req.use_upscale: - img_data[i] = apply_filters('real_esrgan', img_data[i], req.use_upscale) - filters_applied.append(req.use_upscale) - if (len(filters_applied) > 0): - filtered_image = Image.fromarray(img_data[i]) - filtered_buffer = img_to_buffer(filtered_image, req.output_format, req.output_quality) - filtered_img_data = buffer_to_base64_str(filtered_buffer, req.output_format) - response_image = 
ResponseImage(data=filtered_img_data, seed=opt_seed) - res.images.append(response_image) - task_temp_images[i] = filtered_buffer - if req.save_to_disk_path is not None: - filtered_img_out_path = get_base_path(req.save_to_disk_path, req.session_id, prompts[0], img_id, req.output_format, "_".join(filters_applied)) - save_image(filtered_image, filtered_img_out_path, req.output_format, req.output_quality) - response_image.path_abs = filtered_img_out_path - del filtered_image - # Filter Applied, move to next seed - opt_seed += 1 - - # if thread_data.reduced_memory: - # unload_filters() - if not thread_data.test_sd2: - move_to_cpu(thread_data.modelFS) - del img_data - gc() - if thread_data.device != 'cpu': - print(f'memory_final = {round(torch.cuda.memory_allocated(thread_data.device) / 1e6, 2)}Mb') - - print('Task completed') - res = res.json() - data_queue.put(json.dumps(res)) - - return res - -def save_image(img, img_out_path, output_format="", output_quality=75): - try: - if output_format.upper() == "JPEG": - img.save(img_out_path, quality=output_quality) - else: - img.save(img_out_path) - except: - print('could not save the file', traceback.format_exc()) - -def save_metadata(meta_out_path, req, prompt, opt_seed): - metadata = f'''{prompt} -Width: {req.width} -Height: {req.height} -Seed: {opt_seed} -Steps: {req.num_inference_steps} -Guidance Scale: {req.guidance_scale} -Prompt Strength: {req.prompt_strength} -Use Face Correction: {req.use_face_correction} -Use Upscaling: {req.use_upscale} -Sampler: {req.sampler} -Negative Prompt: {req.negative_prompt} -Stable Diffusion model: {req.use_stable_diffusion_model + '.ckpt'} -VAE model: {req.use_vae_model} -Hypernetwork Model: {req.use_hypernetwork_model} -Hypernetwork Strength: {req.hypernetwork_strength} -''' - try: - with open(meta_out_path, 'w', encoding='utf-8') as f: - f.write(metadata) - except: - print('could not save the file', traceback.format_exc()) - -def _txt2img(opt_W, opt_H, opt_n_samples, opt_ddim_steps, 
opt_scale, start_code, opt_C, opt_f, opt_ddim_eta, c, uc, opt_seed, img_callback, mask, sampler_name): - shape = [opt_n_samples, opt_C, opt_H // opt_f, opt_W // opt_f] - - # Send to CPU and wait until complete. - # wait_model_move_to(thread_data.modelCS, 'cpu') - - if not thread_data.test_sd2: - move_to_cpu(thread_data.modelCS) - - if thread_data.test_sd2 and sampler_name not in ('plms', 'ddim', 'dpm2'): - raise Exception('Only plms, ddim and dpm2 samplers are supported right now, in SD 2.0') - - - # samples, _ = sampler.sample(S=opt.steps, - # conditioning=c, - # batch_size=opt.n_samples, - # shape=shape, - # verbose=False, - # unconditional_guidance_scale=opt.scale, - # unconditional_conditioning=uc, - # eta=opt.ddim_eta, - # x_T=start_code) - - if thread_data.test_sd2: - if sampler_name == 'plms': - from ldm.models.diffusion.plms import PLMSSampler - sampler = PLMSSampler(thread_data.model) - elif sampler_name == 'ddim': - from ldm.models.diffusion.ddim import DDIMSampler - sampler = DDIMSampler(thread_data.model) - sampler.make_schedule(ddim_num_steps=opt_ddim_steps, ddim_eta=opt_ddim_eta, verbose=False) - elif sampler_name == 'dpm2': - from ldm.models.diffusion.dpm_solver import DPMSolverSampler - sampler = DPMSolverSampler(thread_data.model) - - shape = [opt_C, opt_H // opt_f, opt_W // opt_f] - - samples_ddim, intermediates = sampler.sample( - S=opt_ddim_steps, - conditioning=c, - batch_size=opt_n_samples, - seed=opt_seed, - shape=shape, - verbose=False, - unconditional_guidance_scale=opt_scale, - unconditional_conditioning=uc, - eta=opt_ddim_eta, - x_T=start_code, - img_callback=img_callback, - mask=mask, - sampler = sampler_name, - ) - else: - if sampler_name == 'ddim': - thread_data.model.make_schedule(ddim_num_steps=opt_ddim_steps, ddim_eta=opt_ddim_eta, verbose=False) - - samples_ddim = thread_data.model.sample( - S=opt_ddim_steps, - conditioning=c, - seed=opt_seed, - shape=shape, - verbose=False, - unconditional_guidance_scale=opt_scale, - 
unconditional_conditioning=uc, - eta=opt_ddim_eta, - x_T=start_code, - img_callback=img_callback, - mask=mask, - sampler = sampler_name, - ) - return samples_ddim - -def _img2img(init_latent, t_enc, batch_size, opt_scale, c, uc, opt_ddim_steps, opt_ddim_eta, opt_seed, img_callback, mask, opt_C=1, opt_H=1, opt_W=1, opt_f=1): - # encode (scaled latent) - x_T = None if mask is None else init_latent - - if thread_data.test_sd2: - from ldm.models.diffusion.ddim import DDIMSampler - - sampler = DDIMSampler(thread_data.model) - - sampler.make_schedule(ddim_num_steps=opt_ddim_steps, ddim_eta=opt_ddim_eta, verbose=False) - - z_enc = sampler.stochastic_encode(init_latent, torch.tensor([t_enc] * batch_size).to(thread_data.device)) - - samples_ddim = sampler.decode(z_enc, c, t_enc, unconditional_guidance_scale=opt_scale,unconditional_conditioning=uc, img_callback=img_callback) - - else: - z_enc = thread_data.model.stochastic_encode( - init_latent, - torch.tensor([t_enc] * batch_size).to(thread_data.device), - opt_seed, - opt_ddim_eta, - opt_ddim_steps, - ) - - # decode it - samples_ddim = thread_data.model.sample( - t_enc, - c, - z_enc, - unconditional_guidance_scale=opt_scale, - unconditional_conditioning=uc, - img_callback=img_callback, - mask=mask, - x_T=x_T, - sampler = 'ddim' - ) - return samples_ddim - -def gc(): - gc_collect() - if thread_data.device == 'cpu': - return - torch.cuda.empty_cache() - torch.cuda.ipc_collect() - -# internal - -def chunk(it, size): - it = iter(it) - return iter(lambda: tuple(islice(it, size)), ()) - -def load_model_from_config(ckpt, verbose=False): - print(f"Loading model from {ckpt}") - model_ver = 'sd1' - - if ckpt.endswith(".safetensors"): - print("Loading from safetensors") - pl_sd = load_file(ckpt, device="cpu") - else: - pl_sd = torch.load(ckpt, map_location="cpu") - - if "global_step" in pl_sd: - print(f"Global Step: {pl_sd['global_step']}") - - if "state_dict" in pl_sd: - # check for a key that only seems to be present in SD2 models - 
if 'cond_stage_model.model.ln_final.bias' in pl_sd['state_dict'].keys(): - model_ver = 'sd2' - - return pl_sd["state_dict"], model_ver - else: - return pl_sd, model_ver - -class UserInitiatedStop(Exception): - pass - -def load_img(img_str, w0, h0): - image = base64_str_to_img(img_str).convert("RGB") - w, h = image.size - print(f"loaded input image of size ({w}, {h}) from base64") - if h0 is not None and w0 is not None: - h, w = h0, w0 - - w, h = map(lambda x: x - x % 64, (w, h)) # resize to integer multiple of 64 - image = image.resize((w, h), resample=Image.Resampling.LANCZOS) - image = np.array(image).astype(np.float32) / 255.0 - image = image[None].transpose(0, 3, 1, 2) - image = torch.from_numpy(image) - return 2.*image - 1. - -def load_mask(mask_str, h0, w0, newH, newW, invert=False): - image = base64_str_to_img(mask_str).convert("RGB") - w, h = image.size - print(f"loaded input mask of size ({w}, {h})") - - if invert: - print("inverted") - image = ImageOps.invert(image) - # where_0, where_1 = np.where(image == 0), np.where(image == 255) - # image[where_0], image[where_1] = 255, 0 - - if h0 is not None and w0 is not None: - h, w = h0, w0 - - w, h = map(lambda x: x - x % 64, (w, h)) # resize to integer multiple of 64 - - print(f"New mask size ({w}, {h})") - image = image.resize((newW, newH), resample=Image.Resampling.LANCZOS) - image = np.array(image) - - image = image.astype(np.float32) / 255.0 - image = image[None].transpose(0, 3, 1, 2) - image = torch.from_numpy(image) - return image - -# https://stackoverflow.com/a/61114178 -def img_to_base64_str(img, output_format="PNG", output_quality=75): - buffered = img_to_buffer(img, output_format, quality=output_quality) - return buffer_to_base64_str(buffered, output_format) - -def img_to_buffer(img, output_format="PNG", output_quality=75): - buffered = BytesIO() - if ( output_format.upper() == "JPEG" ): - img.save(buffered, format=output_format, quality=output_quality) - else: - img.save(buffered, 
format=output_format) - buffered.seek(0) - return buffered - -def buffer_to_base64_str(buffered, output_format="PNG"): - buffered.seek(0) - img_byte = buffered.getvalue() - mime_type = "image/png" if output_format.lower() == "png" else "image/jpeg" - img_str = f"data:{mime_type};base64," + base64.b64encode(img_byte).decode() - return img_str - -def base64_str_to_buffer(img_str): - mime_type = "image/png" if img_str.startswith("data:image/png;") else "image/jpeg" - img_str = img_str[len(f"data:{mime_type};base64,"):] - data = base64.b64decode(img_str) - buffered = BytesIO(data) - return buffered - -def base64_str_to_img(img_str): - buffered = base64_str_to_buffer(img_str) - img = Image.open(buffered) - return img - -def split_weighted_subprompts(text): - """ - grabs all text up to the first occurrence of ':' - uses the grabbed text as a sub-prompt, and takes the value following ':' as weight - if ':' has no value defined, defaults to 1.0 - repeats until no text remaining - """ - remaining = len(text) - prompts = [] - weights = [] - while remaining > 0: - if ":" in text: - idx = text.index(":") # first occurrence from start - # grab up to index as sub-prompt - prompt = text[:idx] - remaining -= idx - # remove from main text - text = text[idx+1:] - # find value for weight - if " " in text: - idx = text.index(" ") # first occurence - else: # no space, read to end - idx = len(text) - if idx != 0: - try: - weight = float(text[:idx]) - except: # couldn't treat as float - print(f"Warning: '{text[:idx]}' is not a value, are you missing a space?") - weight = 1.0 - else: # no value found - weight = 1.0 - # remove from main text - remaining -= idx - text = text[idx+1:] - # append the sub-prompt and its weight - prompts.append(prompt) - weights.append(weight) - else: # no : found - if len(text) > 0: # there is still text though - # take remainder as weight 1 - prompts.append(text) - weights.append(1.0) - remaining = 0 - return prompts, weights diff --git a/ui/server.py 
b/ui/server.py deleted file mode 100644 index d69b03fb..00000000 --- a/ui/server.py +++ /dev/null @@ -1,500 +0,0 @@ -"""server.py: FastAPI SD-UI Web Host. -Notes: - async endpoints always run on the main thread. Without they run on the thread pool. -""" -import json -import traceback - -import sys -import os -import socket -import picklescan.scanner -import rich - -SD_DIR = os.getcwd() -print('started in ', SD_DIR) - -SD_UI_DIR = os.getenv('SD_UI_PATH', None) -sys.path.append(os.path.dirname(SD_UI_DIR)) - -CONFIG_DIR = os.path.abspath(os.path.join(SD_UI_DIR, '..', 'scripts')) -MODELS_DIR = os.path.abspath(os.path.join(SD_DIR, '..', 'models')) - -USER_UI_PLUGINS_DIR = os.path.abspath(os.path.join(SD_DIR, '..', 'plugins', 'ui')) -CORE_UI_PLUGINS_DIR = os.path.abspath(os.path.join(SD_UI_DIR, 'plugins', 'ui')) -UI_PLUGINS_SOURCES = ((CORE_UI_PLUGINS_DIR, 'core'), (USER_UI_PLUGINS_DIR, 'user')) - -STABLE_DIFFUSION_MODEL_EXTENSIONS = ['.ckpt', '.safetensors'] -VAE_MODEL_EXTENSIONS = ['.vae.pt', '.ckpt'] -HYPERNETWORK_MODEL_EXTENSIONS = ['.pt'] - -OUTPUT_DIRNAME = "Stable Diffusion UI" # in the user's home folder -TASK_TTL = 15 * 60 # Discard last session's task timeout -APP_CONFIG_DEFAULTS = { - # auto: selects the cuda device with the most free memory, cuda: use the currently active cuda device. - 'render_devices': 'auto', # valid entries: 'auto', 'cpu' or 'cuda:N' (where N is a GPU index) - 'update_branch': 'main', - 'ui': { - 'open_browser_on_start': True, - }, -} -APP_CONFIG_DEFAULT_MODELS = [ - # needed to support the legacy installations - 'custom-model', # Check if user has a custom model, use it first. - 'sd-v1-4', # Default fallback. 
-] - -from fastapi import FastAPI, HTTPException -from fastapi.staticfiles import StaticFiles -from starlette.responses import FileResponse, JSONResponse, StreamingResponse -from pydantic import BaseModel -import logging -from typing import Any, Generator, Hashable, List, Optional, Union - -from sd_internal import Request, Response, task_manager - -app = FastAPI() - -outpath = os.path.join(os.path.expanduser("~"), OUTPUT_DIRNAME) - -os.makedirs(USER_UI_PLUGINS_DIR, exist_ok=True) - -# don't show access log entries for URLs that start with the given prefix -ACCESS_LOG_SUPPRESS_PATH_PREFIXES = ['/ping', '/image', '/modifier-thumbnails'] - -NOCACHE_HEADERS={"Cache-Control": "no-cache, no-store, must-revalidate", "Pragma": "no-cache", "Expires": "0"} - -class NoCacheStaticFiles(StaticFiles): - def is_not_modified(self, response_headers, request_headers) -> bool: - if 'content-type' in response_headers and ('javascript' in response_headers['content-type'] or 'css' in response_headers['content-type']): - response_headers.update(NOCACHE_HEADERS) - return False - - return super().is_not_modified(response_headers, request_headers) - -app.mount('/media', NoCacheStaticFiles(directory=os.path.join(SD_UI_DIR, 'media')), name="media") - -for plugins_dir, dir_prefix in UI_PLUGINS_SOURCES: - app.mount(f'/plugins/{dir_prefix}', NoCacheStaticFiles(directory=plugins_dir), name=f"plugins-{dir_prefix}") - -def getConfig(default_val=APP_CONFIG_DEFAULTS): - try: - config_json_path = os.path.join(CONFIG_DIR, 'config.json') - if not os.path.exists(config_json_path): - return default_val - with open(config_json_path, 'r', encoding='utf-8') as f: - config = json.load(f) - if 'net' not in config: - config['net'] = {} - if os.getenv('SD_UI_BIND_PORT') is not None: - config['net']['listen_port'] = int(os.getenv('SD_UI_BIND_PORT')) - if os.getenv('SD_UI_BIND_IP') is not None: - config['net']['listen_to_network'] = ( os.getenv('SD_UI_BIND_IP') == '0.0.0.0' ) - return config - except Exception as 
e: - print(str(e)) - print(traceback.format_exc()) - return default_val - -def setConfig(config): - print( json.dumps(config) ) - try: # config.json - config_json_path = os.path.join(CONFIG_DIR, 'config.json') - with open(config_json_path, 'w', encoding='utf-8') as f: - json.dump(config, f) - except: - print(traceback.format_exc()) - - try: # config.bat - config_bat_path = os.path.join(CONFIG_DIR, 'config.bat') - config_bat = [] - - if 'update_branch' in config: - config_bat.append(f"@set update_branch={config['update_branch']}") - - config_bat.append(f"@set SD_UI_BIND_PORT={config['net']['listen_port']}") - bind_ip = '0.0.0.0' if config['net']['listen_to_network'] else '127.0.0.1' - config_bat.append(f"@set SD_UI_BIND_IP={bind_ip}") - - config_bat.append(f"@set test_sd2={'Y' if config.get('test_sd2', False) else 'N'}") - - if len(config_bat) > 0: - with open(config_bat_path, 'w', encoding='utf-8') as f: - f.write('\r\n'.join(config_bat)) - except: - print(traceback.format_exc()) - - try: # config.sh - config_sh_path = os.path.join(CONFIG_DIR, 'config.sh') - config_sh = ['#!/bin/bash'] - - if 'update_branch' in config: - config_sh.append(f"export update_branch={config['update_branch']}") - - config_sh.append(f"export SD_UI_BIND_PORT={config['net']['listen_port']}") - bind_ip = '0.0.0.0' if config['net']['listen_to_network'] else '127.0.0.1' - config_sh.append(f"export SD_UI_BIND_IP={bind_ip}") - - config_sh.append(f"export test_sd2=\"{'Y' if config.get('test_sd2', False) else 'N'}\"") - - if len(config_sh) > 1: - with open(config_sh_path, 'w', encoding='utf-8') as f: - f.write('\n'.join(config_sh)) - except: - print(traceback.format_exc()) - -def resolve_model_to_use(model_name:str, model_type:str, model_dir:str, model_extensions:list, default_models=[]): - config = getConfig() - - model_dirs = [os.path.join(MODELS_DIR, model_dir), SD_DIR] - if not model_name: # When None try user configured model. 
- # config = getConfig() - if 'model' in config and model_type in config['model']: - model_name = config['model'][model_type] - if model_name: - is_sd2 = config.get('test_sd2', False) - if model_name.startswith('sd2_') and not is_sd2: # temp hack, until SD2 is unified with 1.4 - print('ERROR: Cannot use SD 2.0 models with SD 1.0 code. Using the sd-v1-4 model instead!') - model_name = 'sd-v1-4' - - # Check models directory - models_dir_path = os.path.join(MODELS_DIR, model_dir, model_name) - for model_extension in model_extensions: - if os.path.exists(models_dir_path + model_extension): - return models_dir_path - if os.path.exists(model_name + model_extension): - # Direct Path to file - model_name = os.path.abspath(model_name) - return model_name - # Default locations - if model_name in default_models: - default_model_path = os.path.join(SD_DIR, model_name) - for model_extension in model_extensions: - if os.path.exists(default_model_path + model_extension): - return default_model_path - # Can't find requested model, check the default paths. - for default_model in default_models: - for model_dir in model_dirs: - default_model_path = os.path.join(model_dir, default_model) - for model_extension in model_extensions: - if os.path.exists(default_model_path + model_extension): - if model_name is not None: - print(f'Could not find the configured custom model {model_name}{model_extension}. 
Using the default one: {default_model_path}{model_extension}') - return default_model_path - raise Exception('No valid models found.') - -def resolve_ckpt_to_use(model_name:str=None): - return resolve_model_to_use(model_name, model_type='stable-diffusion', model_dir='stable-diffusion', model_extensions=STABLE_DIFFUSION_MODEL_EXTENSIONS, default_models=APP_CONFIG_DEFAULT_MODELS) - -def resolve_vae_to_use(model_name:str=None): - try: - return resolve_model_to_use(model_name, model_type='vae', model_dir='vae', model_extensions=VAE_MODEL_EXTENSIONS, default_models=[]) - except: - return None - -def resolve_hypernetwork_to_use(model_name:str=None): - try: - return resolve_model_to_use(model_name, model_type='hypernetwork', model_dir='hypernetwork', model_extensions=HYPERNETWORK_MODEL_EXTENSIONS, default_models=[]) - except: - return None - -class SetAppConfigRequest(BaseModel): - update_branch: str = None - render_devices: Union[List[str], List[int], str, int] = None - model_vae: str = None - ui_open_browser_on_start: bool = None - listen_to_network: bool = None - listen_port: int = None - test_sd2: bool = None - -@app.post('/app_config') -async def setAppConfig(req : SetAppConfigRequest): - config = getConfig() - if req.update_branch is not None: - config['update_branch'] = req.update_branch - if req.render_devices is not None: - update_render_devices_in_config(config, req.render_devices) - if req.ui_open_browser_on_start is not None: - if 'ui' not in config: - config['ui'] = {} - config['ui']['open_browser_on_start'] = req.ui_open_browser_on_start - if req.listen_to_network is not None: - if 'net' not in config: - config['net'] = {} - config['net']['listen_to_network'] = bool(req.listen_to_network) - if req.listen_port is not None: - if 'net' not in config: - config['net'] = {} - config['net']['listen_port'] = int(req.listen_port) - if req.test_sd2 is not None: - config['test_sd2'] = req.test_sd2 - try: - setConfig(config) - - if req.render_devices: - 
update_render_threads() - - return JSONResponse({'status': 'OK'}, headers=NOCACHE_HEADERS) - except Exception as e: - print(traceback.format_exc()) - raise HTTPException(status_code=500, detail=str(e)) - -def is_malicious_model(file_path): - try: - scan_result = picklescan.scanner.scan_file_path(file_path) - if scan_result.issues_count > 0 or scan_result.infected_files > 0: - rich.print(":warning: [bold red]Scan %s: %d scanned, %d issue, %d infected.[/bold red]" % (file_path, scan_result.scanned_files, scan_result.issues_count, scan_result.infected_files)) - return True - else: - rich.print("Scan %s: [green]%d scanned, %d issue, %d infected.[/green]" % (file_path, scan_result.scanned_files, scan_result.issues_count, scan_result.infected_files)) - return False - except Exception as e: - print('error while scanning', file_path, 'error:', e) - return False - -known_models = {} -def getModels(): - models = { - 'active': { - 'stable-diffusion': 'sd-v1-4', - 'vae': '', - 'hypernetwork': '', - }, - 'options': { - 'stable-diffusion': ['sd-v1-4'], - 'vae': [], - 'hypernetwork': [], - }, - } - - def listModels(models_dirname, model_type, model_extensions): - models_dir = os.path.join(MODELS_DIR, models_dirname) - if not os.path.exists(models_dir): - os.makedirs(models_dir) - - for file in os.listdir(models_dir): - for model_extension in model_extensions: - if not file.endswith(model_extension): - continue - - model_path = os.path.join(models_dir, file) - mtime = os.path.getmtime(model_path) - mod_time = known_models[model_path] if model_path in known_models else -1 - if mod_time != mtime: - if is_malicious_model(model_path): - models['scan-error'] = file - return - known_models[model_path] = mtime - - model_name = file[:-len(model_extension)] - models['options'][model_type].append(model_name) - - models['options'][model_type] = [*set(models['options'][model_type])] # remove duplicates - models['options'][model_type].sort() - - # custom models - 
listModels(models_dirname='stable-diffusion', model_type='stable-diffusion', model_extensions=STABLE_DIFFUSION_MODEL_EXTENSIONS) - listModels(models_dirname='vae', model_type='vae', model_extensions=VAE_MODEL_EXTENSIONS) - listModels(models_dirname='hypernetwork', model_type='hypernetwork', model_extensions=HYPERNETWORK_MODEL_EXTENSIONS) - # legacy - custom_weight_path = os.path.join(SD_DIR, 'custom-model.ckpt') - if os.path.exists(custom_weight_path): - models['options']['stable-diffusion'].append('custom-model') - - return models - -def getUIPlugins(): - plugins = [] - - for plugins_dir, dir_prefix in UI_PLUGINS_SOURCES: - for file in os.listdir(plugins_dir): - if file.endswith('.plugin.js'): - plugins.append(f'/plugins/{dir_prefix}/{file}') - - return plugins - -def getIPConfig(): - try: - ips = socket.gethostbyname_ex(socket.gethostname()) - ips[2].append(ips[0]) - return ips[2] - except Exception as e: - print(e) - print(traceback.format_exc()) - return [] - - -@app.get('/get/{key:path}') -def read_web_data(key:str=None): - if not key: # /get without parameters, stable-diffusion easter egg. 
- raise HTTPException(status_code=418, detail="StableDiffusion is drawing a teapot!") # HTTP418 I'm a teapot - elif key == 'app_config': - config = getConfig(default_val=None) - if config is None: - config = APP_CONFIG_DEFAULTS - return JSONResponse(config, headers=NOCACHE_HEADERS) - elif key == 'system_info': - config = getConfig() - system_info = { - 'devices': task_manager.get_devices(), - 'hosts': getIPConfig(), - } - system_info['devices']['config'] = config.get('render_devices', "auto") - return JSONResponse(system_info, headers=NOCACHE_HEADERS) - elif key == 'models': - return JSONResponse(getModels(), headers=NOCACHE_HEADERS) - elif key == 'modifiers': return FileResponse(os.path.join(SD_UI_DIR, 'modifiers.json'), headers=NOCACHE_HEADERS) - elif key == 'output_dir': return JSONResponse({ 'output_dir': outpath }, headers=NOCACHE_HEADERS) - elif key == 'ui_plugins': return JSONResponse(getUIPlugins(), headers=NOCACHE_HEADERS) - else: - raise HTTPException(status_code=404, detail=f'Request for unknown {key}') # HTTP404 Not Found - -@app.get('/ping') # Get server and optionally session status. -def ping(session_id:str=None): - if task_manager.is_alive() <= 0: # Check that render threads are alive. 
- if task_manager.current_state_error: raise HTTPException(status_code=500, detail=str(task_manager.current_state_error)) - raise HTTPException(status_code=500, detail='Render thread is dead.') - if task_manager.current_state_error and not isinstance(task_manager.current_state_error, StopAsyncIteration): raise HTTPException(status_code=500, detail=str(task_manager.current_state_error)) - # Alive - response = {'status': str(task_manager.current_state)} - if session_id: - session = task_manager.get_cached_session(session_id, update_ttl=True) - response['tasks'] = {id(t): t.status for t in session.tasks} - response['devices'] = task_manager.get_devices() - return JSONResponse(response, headers=NOCACHE_HEADERS) - -def save_model_to_config(ckpt_model_name, vae_model_name, hypernetwork_model_name): - config = getConfig() - if 'model' not in config: - config['model'] = {} - - config['model']['stable-diffusion'] = ckpt_model_name - config['model']['vae'] = vae_model_name - config['model']['hypernetwork'] = hypernetwork_model_name - - if vae_model_name is None or vae_model_name == "": - del config['model']['vae'] - if hypernetwork_model_name is None or hypernetwork_model_name == "": - del config['model']['hypernetwork'] - - setConfig(config) - -def update_render_devices_in_config(config, render_devices): - if render_devices not in ('cpu', 'auto') and not render_devices.startswith('cuda:'): - raise HTTPException(status_code=400, detail=f'Invalid render device requested: {render_devices}') - - if render_devices.startswith('cuda:'): - render_devices = render_devices.split(',') - - config['render_devices'] = render_devices - -@app.post('/render') -def render(req : task_manager.ImageRequest): - try: - save_model_to_config(req.use_stable_diffusion_model, req.use_vae_model, req.use_hypernetwork_model) - req.use_stable_diffusion_model = resolve_ckpt_to_use(req.use_stable_diffusion_model) - req.use_vae_model = resolve_vae_to_use(req.use_vae_model) - req.use_hypernetwork_model = 
resolve_hypernetwork_to_use(req.use_hypernetwork_model) - new_task = task_manager.render(req) - response = { - 'status': str(task_manager.current_state), - 'queue': len(task_manager.tasks_queue), - 'stream': f'/image/stream/{id(new_task)}', - 'task': id(new_task) - } - return JSONResponse(response, headers=NOCACHE_HEADERS) - except ChildProcessError as e: # Render thread is dead - raise HTTPException(status_code=500, detail=f'Rendering thread has died.') # HTTP500 Internal Server Error - except ConnectionRefusedError as e: # Unstarted task pending limit reached, deny queueing too many. - raise HTTPException(status_code=503, detail=str(e)) # HTTP503 Service Unavailable - except Exception as e: - print(e) - print(traceback.format_exc()) - raise HTTPException(status_code=500, detail=str(e)) - -@app.get('/image/stream/{task_id:int}') -def stream(task_id:int): - #TODO Move to WebSockets ?? - task = task_manager.get_cached_task(task_id, update_ttl=True) - if not task: raise HTTPException(status_code=404, detail=f'Request {task_id} not found.') # HTTP404 NotFound - #if (id(task) != task_id): raise HTTPException(status_code=409, detail=f'Wrong task id received. 
Expected:{id(task)}, Received:{task_id}') # HTTP409 Conflict - if task.buffer_queue.empty() and not task.lock.locked(): - if task.response: - #print(f'Session {session_id} sending cached response') - return JSONResponse(task.response, headers=NOCACHE_HEADERS) - raise HTTPException(status_code=425, detail='Too Early, task not started yet.') # HTTP425 Too Early - #print(f'Session {session_id} opened live render stream {id(task.buffer_queue)}') - return StreamingResponse(task.read_buffer_generator(), media_type='application/json') - -@app.get('/image/stop') -def stop(task: int): - if not task: - if task_manager.current_state == task_manager.ServerStates.Online or task_manager.current_state == task_manager.ServerStates.Unavailable: - raise HTTPException(status_code=409, detail='Not currently running any tasks.') # HTTP409 Conflict - task_manager.current_state_error = StopAsyncIteration('') - return {'OK'} - task_id = task - task = task_manager.get_cached_task(task_id, update_ttl=False) - if not task: raise HTTPException(status_code=404, detail=f'Task {task_id} was not found.') # HTTP404 Not Found - if isinstance(task.error, StopAsyncIteration): raise HTTPException(status_code=409, detail=f'Task {task_id} is already stopped.') # HTTP409 Conflict - task.error = StopAsyncIteration(f'Task {task_id} stop requested.') - return {'OK'} - -@app.get('/image/tmp/{task_id:int}/{img_id:int}') -def get_image(task_id: int, img_id: int): - task = task_manager.get_cached_task(task_id, update_ttl=True) - if not task: raise HTTPException(status_code=410, detail=f'Task {task_id} could not be found.') # HTTP404 NotFound - if not task.temp_images[img_id]: raise HTTPException(status_code=425, detail='Too Early, task data is not available yet.') # HTTP425 Too Early - try: - img_data = task.temp_images[img_id] - img_data.seek(0) - return StreamingResponse(img_data, media_type='image/jpeg') - except KeyError as e: - raise HTTPException(status_code=500, detail=str(e)) - -@app.get('/') -def 
read_root(): - return FileResponse(os.path.join(SD_UI_DIR, 'index.html'), headers=NOCACHE_HEADERS) - -@app.on_event("shutdown") -def shutdown_event(): # Signal render thread to close on shutdown - task_manager.current_state_error = SystemExit('Application shutting down.') - -# don't log certain requests -class LogSuppressFilter(logging.Filter): - def filter(self, record: logging.LogRecord) -> bool: - path = record.getMessage() - for prefix in ACCESS_LOG_SUPPRESS_PATH_PREFIXES: - if path.find(prefix) != -1: - return False - return True -logging.getLogger('uvicorn.access').addFilter(LogSuppressFilter()) - -# Check models and prepare cache for UI open -getModels() - -# Start the task_manager -task_manager.default_model_to_load = resolve_ckpt_to_use() -task_manager.default_vae_to_load = resolve_vae_to_use() -task_manager.default_hypernetwork_to_load = resolve_hypernetwork_to_use() - -def update_render_threads(): - config = getConfig() - render_devices = config.get('render_devices', 'auto') - active_devices = task_manager.get_devices()['active'].keys() - - print('requesting for render_devices', render_devices) - task_manager.update_render_threads(render_devices, active_devices) - -update_render_threads() - -# start the browser ui -def open_browser(): - config = getConfig() - ui = config.get('ui', {}) - net = config.get('net', {'listen_port':9000}) - port = net.get('listen_port', 9000) - if ui.get('open_browser_on_start', True): - import webbrowser; webbrowser.open(f"http://localhost:{port}") - -open_browser()