Merge branch 'beta' into Custom-modifiers-as-a-plugin

This commit is contained in:
cmdr2 2022-12-01 14:57:39 +05:30 committed by GitHub
commit 1ead764a02
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23
28 changed files with 1060 additions and 495 deletions

27
3rd-PARTY-LICENSES Normal file
View File

@ -0,0 +1,27 @@
jquery-confirm
==============
https://craftpip.github.io/jquery-confirm/
jquery-confirm is licensed under the MIT license:
The MIT License (MIT)
Copyright (c) 2019 Boniface Pereira
Permission is hereby granted, free of charge, to any person obtaining a copy
of this software and associated documentation files (the "Software"), to deal
in the Software without restriction, including without limitation the rights
to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
copies of the Software, and to permit persons to whom the Software is
furnished to do so, subject to the following conditions:
The above copyright notice and this permission notice shall be included in all
copies or substantial portions of the Software.
THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
SOFTWARE.

View File

@ -19,8 +19,16 @@
- Configuration to prevent the browser from opening on startup - Configuration to prevent the browser from opening on startup
- Lots of minor bug fixes - Lots of minor bug fixes
- A `What's New?` tab in the UI - A `What's New?` tab in the UI
- Ask for a confimation before clearing the results pane or stopping a render task. The dialog can be skipped by holding down the shift key while clicking on the button.
- Show the network addresses of the server in the systems setting dialog
### Detailed changelog ### Detailed changelog
* 2.4.17 - 30 Nov 2022 - Scroll to generated image. Thanks @patriceac
* 2.4.17 - 30 Nov 2022 - Show the network addresses of the server in the systems setting dialog. Thanks @JeLuf
* 2.4.17 - 30 Nov 2022 - Fix a bug where GFPGAN wouldn't work properly when multiple GPUs tried to run it at the same time. Thanks @madrang
* 2.4.17 - 30 Nov 2022 - Confirm before stopping or clearing all the tasks. Thanks @JeLuf
* 2.4.16 - 29 Nov 2022 - Bug fixes for SD 2.0 - remove the need for patching, default to SD 1.4 model if trying to load an SD2 model in SD1.4.
* 2.4.15 - 25 Nov 2022 - Experimental support for SD 2.0. Uses lots of memory, not optimized, probably GPU-only.
* 2.4.14 - 22 Nov 2022 - Change the backend to a custom fork of Stable Diffusion * 2.4.14 - 22 Nov 2022 - Change the backend to a custom fork of Stable Diffusion
* 2.4.13 - 21 Nov 2022 - Change the modifier weight via mouse wheel, drag to reorder selected modifiers, and some more modifier-related fixes. Thanks @patriceac * 2.4.13 - 21 Nov 2022 - Change the modifier weight via mouse wheel, drag to reorder selected modifiers, and some more modifier-related fixes. Thanks @patriceac
* 2.4.12 - 21 Nov 2022 - Another fix for improving how long images take to generate. Reduces the time taken for an enqueued task to start processing. * 2.4.12 - 21 Nov 2022 - Another fix for improving how long images take to generate. Reduces the time taken for an enqueued task to start processing.

View File

@ -29,6 +29,18 @@ call conda activate .\stable-diffusion\env
call where python call where python
call python --version call python --version
@rem set the PYTHONPATH
cd stable-diffusion
set SD_DIR=%cd%
cd env\lib\site-packages
set PYTHONPATH=%SD_DIR%;%cd%
cd ..\..\..
echo PYTHONPATH=%PYTHONPATH%
cd ..
@rem done
echo. echo.
cmd /k cmd /k

View File

@ -35,6 +35,15 @@ if [ "$0" == "bash" ]; then
which python which python
python --version python --version
# set the PYTHONPATH
cd stable-diffusion
SD_PATH=`pwd`
export PYTHONPATH="$SD_PATH:$SD_PATH/env/lib/python3.8/site-packages"
echo "PYTHONPATH=$PYTHONPATH"
cd ..
# done
echo "" echo ""
else else
file_name=$(basename "${BASH_SOURCE[0]}") file_name=$(basename "${BASH_SOURCE[0]}")

View File

@ -27,6 +27,8 @@ if exist "Open Developer Console.cmd" del "Open Developer Console.cmd"
@call python -c "import os; import shutil; frm = 'sd-ui-files\\ui\\hotfix\\9c24e6cd9f499d02c4f21a033736dabd365962dc80fe3aeb57a8f85ea45a20a3.26fead7ea4f0f843f6eb4055dfd25693f1a71f3c6871b184042d4b126244e142'; dst = os.path.join(os.path.expanduser('~'), '.cache', 'huggingface', 'transformers', '9c24e6cd9f499d02c4f21a033736dabd365962dc80fe3aeb57a8f85ea45a20a3.26fead7ea4f0f843f6eb4055dfd25693f1a71f3c6871b184042d4b126244e142'); shutil.copyfile(frm, dst) if os.path.exists(dst) else print(''); print('Hotfixed broken JSON file from OpenAI');" @call python -c "import os; import shutil; frm = 'sd-ui-files\\ui\\hotfix\\9c24e6cd9f499d02c4f21a033736dabd365962dc80fe3aeb57a8f85ea45a20a3.26fead7ea4f0f843f6eb4055dfd25693f1a71f3c6871b184042d4b126244e142'; dst = os.path.join(os.path.expanduser('~'), '.cache', 'huggingface', 'transformers', '9c24e6cd9f499d02c4f21a033736dabd365962dc80fe3aeb57a8f85ea45a20a3.26fead7ea4f0f843f6eb4055dfd25693f1a71f3c6871b184042d4b126244e142'); shutil.copyfile(frm, dst) if os.path.exists(dst) else print(''); print('Hotfixed broken JSON file from OpenAI');"
if NOT DEFINED test_sd2 set test_sd2=N
@>nul findstr /m "sd_git_cloned" scripts\install_status.txt @>nul findstr /m "sd_git_cloned" scripts\install_status.txt
@if "%ERRORLEVEL%" EQU "0" ( @if "%ERRORLEVEL%" EQU "0" (
@echo "Stable Diffusion's git repository was already installed. Updating.." @echo "Stable Diffusion's git repository was already installed. Updating.."
@ -37,9 +39,13 @@ if exist "Open Developer Console.cmd" del "Open Developer Console.cmd"
@call git reset --hard @call git reset --hard
@call git pull @call git pull
@call git -c advice.detachedHead=false checkout 7f32368ed1030a6e710537047bacd908adea183a
@call git apply --whitespace=warn ..\ui\sd_internal\ddim_callback.patch if "%test_sd2%" == "N" (
@call git -c advice.detachedHead=false checkout 7f32368ed1030a6e710537047bacd908adea183a
)
if "%test_sd2%" == "Y" (
@call git -c advice.detachedHead=false checkout 5d647c5459f4cd790672512222bc41903c01bb71
)
@cd .. @cd ..
) else ( ) else (
@ -56,8 +62,6 @@ if exist "Open Developer Console.cmd" del "Open Developer Console.cmd"
@cd stable-diffusion @cd stable-diffusion
@call git -c advice.detachedHead=false checkout 7f32368ed1030a6e710537047bacd908adea183a @call git -c advice.detachedHead=false checkout 7f32368ed1030a6e710537047bacd908adea183a
@call git apply --whitespace=warn ..\ui\sd_internal\ddim_callback.patch
@cd .. @cd ..
) )
@ -346,7 +350,9 @@ echo. > "..\models\vae\Put your VAE files here.txt"
) )
) )
if "%test_sd2%" == "Y" (
@call pip install open_clip_torch==2.0.2
)
@>nul findstr /m "sd_install_complete" ..\scripts\install_status.txt @>nul findstr /m "sd_install_complete" ..\scripts\install_status.txt
@if "%ERRORLEVEL%" NEQ "0" ( @if "%ERRORLEVEL%" NEQ "0" (

View File

@ -21,6 +21,10 @@ python -c "import os; import shutil; frm = 'sd-ui-files/ui/hotfix/9c24e6cd9f499d
# Caution, this file will make your eyes and brain bleed. It's such an unholy mess. # Caution, this file will make your eyes and brain bleed. It's such an unholy mess.
# Note to self: Please rewrite this in Python. For the sake of your own sanity. # Note to self: Please rewrite this in Python. For the sake of your own sanity.
if [ "$test_sd2" == "" ]; then
export test_sd2="N"
fi
if [ -e "scripts/install_status.txt" ] && [ `grep -c sd_git_cloned scripts/install_status.txt` -gt "0" ]; then if [ -e "scripts/install_status.txt" ] && [ `grep -c sd_git_cloned scripts/install_status.txt` -gt "0" ]; then
echo "Stable Diffusion's git repository was already installed. Updating.." echo "Stable Diffusion's git repository was already installed. Updating.."
@ -30,9 +34,12 @@ if [ -e "scripts/install_status.txt" ] && [ `grep -c sd_git_cloned scripts/insta
git reset --hard git reset --hard
git pull git pull
git -c advice.detachedHead=false checkout 7f32368ed1030a6e710537047bacd908adea183a
git apply --whitespace=warn ../ui/sd_internal/ddim_callback.patch || fail "ddim patch failed" if [ "$test_sd2" == "N" ]; then
git -c advice.detachedHead=false checkout 7f32368ed1030a6e710537047bacd908adea183a
elif [ "$test_sd2" == "Y" ]; then
git -c advice.detachedHead=false checkout 5d647c5459f4cd790672512222bc41903c01bb71
fi
cd .. cd ..
else else
@ -47,8 +54,6 @@ else
cd stable-diffusion cd stable-diffusion
git -c advice.detachedHead=false checkout 7f32368ed1030a6e710537047bacd908adea183a git -c advice.detachedHead=false checkout 7f32368ed1030a6e710537047bacd908adea183a
git apply --whitespace=warn ../ui/sd_internal/ddim_callback.patch || fail "ddim patch failed"
cd .. cd ..
fi fi
@ -291,6 +296,9 @@ if [ ! -f "../models/vae/vae-ft-mse-840000-ema-pruned.ckpt" ]; then
fi fi
fi fi
if [ "$test_sd2" == "Y" ]; then
pip install open_clip_torch==2.0.2
fi
if [ `grep -c sd_install_complete ../scripts/install_status.txt` -gt "0" ]; then if [ `grep -c sd_install_complete ../scripts/install_status.txt` -gt "0" ]; then
echo sd_weights_downloaded >> ../scripts/install_status.txt echo sd_weights_downloaded >> ../scripts/install_status.txt

View File

@ -1,6 +0,0 @@
@call conda --version
@call git --version
cd %CONDA_PREFIX%\..\scripts
on_env_start.bat

View File

@ -1,12 +0,0 @@
#!/bin/bash
conda-unpack
source $CONDA_PREFIX/etc/profile.d/conda.sh
conda --version
git --version
cd $CONDA_PREFIX/../scripts
./on_env_start.sh

View File

@ -3,6 +3,7 @@
<head> <head>
<title>Stable Diffusion UI</title> <title>Stable Diffusion UI</title>
<meta name="viewport" content="width=device-width, initial-scale=1.0"> <meta name="viewport" content="width=device-width, initial-scale=1.0">
<meta name="theme-color" content="#673AB6">
<link rel="icon" type="image/png" href="/media/images/favicon-16x16.png" sizes="16x16"> <link rel="icon" type="image/png" href="/media/images/favicon-16x16.png" sizes="16x16">
<link rel="icon" type="image/png" href="/media/images/favicon-32x32.png" sizes="32x32"> <link rel="icon" type="image/png" href="/media/images/favicon-32x32.png" sizes="32x32">
<link rel="stylesheet" href="/media/css/fonts.css"> <link rel="stylesheet" href="/media/css/fonts.css">
@ -12,7 +13,10 @@
<link rel="stylesheet" href="/media/css/modifier-thumbnails.css"> <link rel="stylesheet" href="/media/css/modifier-thumbnails.css">
<link rel="stylesheet" href="/media/css/fontawesome-all.min.css"> <link rel="stylesheet" href="/media/css/fontawesome-all.min.css">
<link rel="stylesheet" href="/media/css/drawingboard.min.css"> <link rel="stylesheet" href="/media/css/drawingboard.min.css">
<link rel="stylesheet" href="/media/css/jquery-confirm.min.css">
<link rel="manifest" href="/media/manifest.webmanifest">
<script src="/media/js/jquery-3.6.1.min.js"></script> <script src="/media/js/jquery-3.6.1.min.js"></script>
<script src="/media/js/jquery-confirm.min.js"></script>
<script src="/media/js/drawingboard.min.js"></script> <script src="/media/js/drawingboard.min.js"></script>
<script src="/media/js/marked.min.js"></script> <script src="/media/js/marked.min.js"></script>
</head> </head>
@ -22,7 +26,7 @@
<div id="logo"> <div id="logo">
<h1> <h1>
Stable Diffusion UI Stable Diffusion UI
<small>v2.4.14 <span id="updateBranchLabel"></span></small> <small>v2.4.17 <span id="updateBranchLabel"></span></small>
</h1> </h1>
</div> </div>
<div id="server-status"> <div id="server-status">
@ -67,7 +71,7 @@
<div id="init_image_wrapper"> <div id="init_image_wrapper">
<img id="init_image_preview" src="" /> <img id="init_image_preview" src="" />
<span id="init_image_size_box"></span> <span id="init_image_size_box"></span>
<button class="init_image_clear image_clear_btn">X</button> <button class="init_image_clear image_clear_btn"><i class="fa-solid fa-xmark"></i></button>
</div> </div>
<br/> <br/>
@ -82,7 +86,7 @@
</div> </div>
<div id="editor-inputs-tags-container" class="row"> <div id="editor-inputs-tags-container" class="row">
<label>Image Modifiers: <small>(click an Image Modifier to remove it)</small></label> <label>Image Modifiers <i class="fa-solid fa-circle-question help-btn"><span class="simple-tooltip right">click an Image Modifier to remove it, use Ctrl+Mouse Wheel to adjust its weight</span></i>:</label>
<div id="editor-inputs-tags-list"></div> <div id="editor-inputs-tags-list"></div>
</div> </div>
@ -250,8 +254,17 @@
<br/><br/> <br/><br/>
<div> <div>
<h3><i class="fa fa-microchip icon"></i> System Info</h3> <h3><i class="fa fa-microchip icon"></i> System Info</h3>
<div id="system-info"></div> <div id="system-info">
<table>
<tr><td><label>Processor:</label></td><td id="system-info-cpu" class="value"></td></tr>
<tr><td><label>Compatible Graphics Cards (all):</label></td><td id="system-info-gpus-all" class="value"></td></tr>
<tr><td></td><td>&nbsp;</td></tr>
<tr><td><label>Used for rendering 🔥:</label></td><td id="system-info-rendering-devices" class="value"></td></tr>
<tr><td><label>Server Addresses <i class="fa-solid fa-circle-question help-btn"><span class="simple-tooltip right">You can access Stable Diffusion UI from other devices using these addresses</span></i> :</label></td><td id="system-info-server-hosts" class="value"></td></tr>
</table>
</div>
</div> </div>
</div> </div>
</div> </div>
<div id="tab-content-about" class="tab-content"> <div id="tab-content-about" class="tab-content">
@ -348,7 +361,7 @@ async function init() {
await getAppConfig() await getAppConfig()
await loadUIPlugins() await loadUIPlugins()
await loadModifiers() await loadModifiers()
await getDevices() await getSystemInfo()
setInterval(healthCheck, HEALTH_PING_INTERVAL * 1000) setInterval(healthCheck, HEALTH_PING_INTERVAL * 1000)
healthCheck() healthCheck()

9
ui/media/css/jquery-confirm.min.css vendored Normal file

File diff suppressed because one or more lines are too long

View File

@ -64,6 +64,11 @@ code {
top: 0px; top: 0px;
right: 0px; right: 0px;
} }
.image_clear_btn:active {
position: absolute;
top: 0px;
left: auto;
}
.settings-box ul { .settings-box ul {
font-size: 9pt; font-size: 9pt;
margin-bottom: 5px; margin-bottom: 5px;
@ -210,7 +215,7 @@ code {
} }
.collapsible-content { .collapsible-content {
display: block; display: block;
padding-left: 15px; padding-left: 10px;
} }
.collapsible-content h5 { .collapsible-content h5 {
padding: 5pt 0pt; padding: 5pt 0pt;
@ -658,11 +663,15 @@ input::file-selector-button {
opacity: 1; opacity: 1;
} }
/* MOBILE SUPPORT */ /* Small screens */
@media screen and (max-width: 700px) { @media screen and (max-width: 1265px) {
#top-nav { #top-nav {
flex-direction: column; flex-direction: column;
} }
}
/* MOBILE SUPPORT */
@media screen and (max-width: 700px) {
body { body {
margin: 0px; margin: 0px;
} }
@ -712,7 +721,7 @@ input::file-selector-button {
padding-right: 0px; padding-right: 0px;
} }
#server-status { #server-status {
display: none; top: 75%;
} }
.popup > div { .popup > div {
padding-left: 5px !important; padding-left: 5px !important;
@ -730,6 +739,15 @@ input::file-selector-button {
} }
} }
@media screen and (max-width: 500px) {
#server-status #server-status-msg {
display: none;
}
#server-status:hover #server-status-msg {
display: inline;
}
}
@media (min-width: 700px) { @media (min-width: 700px) {
/* #editor { /* #editor {
max-width: 480px; max-width: 480px;
@ -997,8 +1015,17 @@ button:hover {
button:active { button:active {
transition-duration: 0.1s; transition-duration: 0.1s;
background-color: hsl(var(--accent-hue), 100%, calc(var(--accent-lightness) + 24%)); background-color: hsl(var(--accent-hue), 100%, calc(var(--accent-lightness) + 24%));
position: relative;
top: 1px;
left: 1px;
} }
button#save-system-settings-btn { button#save-system-settings-btn {
padding: 4pt 8pt; padding: 4pt 8pt;
} }
#ip-info a {
color:var(--text-color)
}
#ip-info div {
line-height: 200%;
}

View File

@ -30,6 +30,9 @@
--primary-button-border: none; --primary-button-border: none;
--input-switch-padding: 1px; --input-switch-padding: 1px;
--input-height: 18px; --input-height: 18px;
/* Main theme color, hex color fallback. */
--theme-color-fallback: #673AB6;
} }
.theme-light { .theme-light {
@ -44,6 +47,8 @@
--input-text-color: black; --input-text-color: black;
--input-background-color: #f8f9fa; --input-background-color: #f8f9fa;
--input-border-color: grey; --input-border-color: grey;
--theme-color-fallback: #aaaaaa;
} }
.theme-discord { .theme-discord {
@ -58,6 +63,8 @@
--input-border-size: 2px; --input-border-size: 2px;
--input-background-color: #202225; --input-background-color: #202225;
--input-border-color: var(--input-background-color); --input-border-color: var(--input-background-color);
--theme-color-fallback: #202225;
} }
.theme-cool-blue { .theme-cool-blue {
@ -71,8 +78,10 @@
--background-color4: hsl(var(--main-hue), var(--main-saturation), calc(var(--value-base) - (3 * var(--value-step)))); --background-color4: hsl(var(--main-hue), var(--main-saturation), calc(var(--value-base) - (3 * var(--value-step))));
--input-background-color: var(--background-color3); --input-background-color: var(--background-color3);
--accent-hue: 212; --accent-hue: 212;
--theme-color-fallback: #0056b8;
} }
@ -87,6 +96,8 @@
--background-color4: hsl(var(--main-hue), var(--main-saturation), calc(var(--value-base) - (3 * var(--value-step)))); --background-color4: hsl(var(--main-hue), var(--main-saturation), calc(var(--value-base) - (3 * var(--value-step))));
--input-background-color: var(--background-color3); --input-background-color: var(--background-color3);
--theme-color-fallback: #5300b8;
} }
.theme-super-dark { .theme-super-dark {
@ -101,6 +112,8 @@
--input-background-color: var(--background-color3); --input-background-color: var(--background-color3);
--input-border-size: 0px; --input-border-size: 0px;
--theme-color-fallback: #000000;
} }
.theme-wild { .theme-wild {
@ -117,8 +130,8 @@
--input-border-size: 1px; --input-border-size: 1px;
--input-background-color: hsl(222, var(--main-saturation), calc(var(--value-base) - (2 * var(--value-step)))); --input-background-color: hsl(222, var(--main-saturation), calc(var(--value-base) - (2 * var(--value-step))));
--input-text-color: red; --input-text-color: #FF0000;
--input-border-color: green; --input-border-color: #005E05;
} }
.theme-gnomie { .theme-gnomie {
@ -136,6 +149,8 @@
--input-background-color: #2a2a2a; --input-background-color: #2a2a2a;
--input-border-size: 0px; --input-border-size: 0px;
--input-border-color: var(--input-background-color); --input-border-color: var(--input-background-color);
--theme-color-fallback: #2168bf;
} }
.theme-gnomie .panel-box { .theme-gnomie .panel-box {

View File

@ -35,6 +35,7 @@ const SETTINGS_IDS_LIST = [
"sound_toggle", "sound_toggle",
"turbo", "turbo",
"use_full_precision", "use_full_precision",
"confirm_dangerous_actions",
"auto_save_settings" "auto_save_settings"
] ]
@ -55,6 +56,9 @@ async function initSettings() {
if (!element) { if (!element) {
console.error(`Missing settings element ${id}`) console.error(`Missing settings element ${id}`)
} }
if (id in SETTINGS) { // don't create it again
return
}
SETTINGS[id] = { SETTINGS[id] = {
key: id, key: id,
element: element, element: element,

View File

@ -192,9 +192,9 @@ const TASK_MAPPING = {
parse: (val) => val parse: (val) => val
}, },
numOutputsParallel: { name: 'Parallel Images', num_outputs: { name: 'Parallel Images',
setUI: (numOutputsParallel) => { setUI: (num_outputs) => {
numOutputsParallelField.value = numOutputsParallel numOutputsParallelField.value = num_outputs
}, },
readUI: () => parseInt(numOutputsParallelField.value), readUI: () => parseInt(numOutputsParallelField.value),
parse: (val) => val parse: (val) => val
@ -328,6 +328,7 @@ function getModelPath(filename, extensions)
filename = filename.slice(0, filename.length - ext.length) filename = filename.slice(0, filename.length - ext.length)
} }
}) })
return filename
} }
const TASK_TEXT_MAPPING = { const TASK_TEXT_MAPPING = {

View File

@ -85,14 +85,13 @@ function createModifierGroup(modifierGroup, initiallyExpanded) {
if(typeof modifierCard == 'object') { if(typeof modifierCard == 'object') {
modifiersEl.appendChild(modifierCard) modifiersEl.appendChild(modifierCard)
const trimmedName = trimModifiers(modifierName)
modifierCard.addEventListener('click', () => { modifierCard.addEventListener('click', () => {
if (activeTags.map(x => x.name).includes(modifierName)) { if (activeTags.map(x => trimModifiers(x.name)).includes(trimmedName)) {
// remove modifier from active array // remove modifier from active array
activeTags = activeTags.filter(x => x.name != modifierName) activeTags = activeTags.filter(x => trimModifiers(x.name) != trimmedName)
modifierCard.classList.remove(activeCardClass) toggleCardState(trimmedName, false)
modifierCard.querySelector('.modifier-card-image-overlay').innerText = '+'
} else { } else {
// add modifier to active array // add modifier to active array
activeTags.push({ activeTags.push({
@ -101,10 +100,7 @@ function createModifierGroup(modifierGroup, initiallyExpanded) {
'originElement': modifierCard, 'originElement': modifierCard,
'previews': modifierPreviews 'previews': modifierPreviews
}) })
toggleCardState(trimmedName, true)
modifierCard.classList.add(activeCardClass)
modifierCard.querySelector('.modifier-card-image-overlay').innerText = '-'
} }
refreshTagsList() refreshTagsList()
@ -125,6 +121,10 @@ function createModifierGroup(modifierGroup, initiallyExpanded) {
return e return e
} }
function trimModifiers(tag) {
return tag.replace(/^\(+|\)+$/g, '').replace(/^\[+|\]+$/g, '')
}
async function loadModifiers() { async function loadModifiers() {
try { try {
let res = await fetch('/get/modifiers') let res = await fetch('/get/modifiers')
@ -221,9 +221,8 @@ function refreshTagsList() {
tag.element.addEventListener('click', () => { tag.element.addEventListener('click', () => {
let idx = activeTags.indexOf(tag) let idx = activeTags.indexOf(tag)
if (idx !== -1 && activeTags[idx].originElement !== undefined) { if (idx !== -1) {
activeTags[idx].originElement.classList.remove(activeCardClass) toggleCardState(activeTags[idx].name, false)
activeTags[idx].originElement.querySelector('.modifier-card-image-overlay').innerText = '+'
activeTags.splice(idx, 1) activeTags.splice(idx, 1)
refreshTagsList() refreshTagsList()
@ -236,6 +235,22 @@ function refreshTagsList() {
editorModifierTagsList.appendChild(brk) editorModifierTagsList.appendChild(brk)
} }
function toggleCardState(modifierName, makeActive) {
document.querySelector('#editor-modifiers').querySelectorAll('.modifier-card').forEach(card => {
const name = card.querySelector('.modifier-card-label').innerText
if (trimModifiers(modifierName) == trimModifiers(name)) {
if(makeActive) {
card.classList.add(activeCardClass)
card.querySelector('.modifier-card-image-overlay').innerText = '-'
}
else{
card.classList.remove(activeCardClass)
card.querySelector('.modifier-card-image-overlay').innerText = '+'
}
}
})
}
function changePreviewImages(val) { function changePreviewImages(val) {
const previewImages = document.querySelectorAll('.modifier-card-image-container img') const previewImages = document.querySelectorAll('.modifier-card-image-container img')

10
ui/media/js/jquery-confirm.min.js vendored Normal file

File diff suppressed because one or more lines are too long

View File

@ -138,6 +138,35 @@ function isServerAvailable() {
} }
} }
// shiftOrConfirm(e, prompt, fn)
// e : MouseEvent
// prompt : Text to be shown as prompt. Should be a question to which "yes" is a good answer.
// fn : function to be called if the user confirms the dialog or has the shift key pressed
//
// If the user had the shift key pressed while clicking, the function fn will be executed.
// If the setting "confirm_dangerous_actions" in the system settings is disabled, the function
// fn will be executed.
// Otherwise, a confirmation dialog is shown. If the user confirms, the function fn will also
// be executed.
function shiftOrConfirm(e, prompt, fn) {
e.stopPropagation()
if (e.shiftKey || !confirmDangerousActionsField.checked) {
fn(e)
} else {
$.confirm({
theme: 'modern',
title: prompt,
useBootstrap: false,
animateFromElement: false,
content: '<small>Tip: To skip this dialog, use shift-click or disable the "Confirm dangerous actions" setting in the Settings tab.</small>',
buttons: {
yes: () => { fn(e) },
cancel: () => {}
}
});
}
}
function logMsg(msg, level, outputMsg) { function logMsg(msg, level, outputMsg) {
if (outputMsg.hasChildNodes()) { if (outputMsg.hasChildNodes()) {
outputMsg.appendChild(document.createElement('br')) outputMsg.appendChild(document.createElement('br'))
@ -169,34 +198,6 @@ function playSound() {
}) })
} }
} }
function setSystemInfo(devices) {
let cpu = devices.all.cpu.name
let allGPUs = Object.keys(devices.all).filter(d => d != 'cpu')
let activeGPUs = Object.keys(devices.active)
function ID_TO_TEXT(d) {
let info = devices.all[d]
if ("mem_free" in info && "mem_total" in info) {
return `${info.name} <small>(${d}) (${info.mem_free.toFixed(1)}Gb free / ${info.mem_total.toFixed(1)} Gb total)</small>`
} else {
return `${info.name} <small>(${d}) (no memory info)</small>`
}
}
allGPUs = allGPUs.map(ID_TO_TEXT)
activeGPUs = activeGPUs.map(ID_TO_TEXT)
let systemInfo = `
<table>
<tr><td><label>Processor:</label></td><td class="value">${cpu}</td></tr>
<tr><td><label>Compatible Graphics Cards (all):</label></td><td class="value">${allGPUs.join('</br>')}</td></tr>
<tr><td></td><td>&nbsp;</td></tr>
<tr><td><label>Used for rendering 🔥:</label></td><td class="value">${activeGPUs.join('</br>')}</td></tr>
</table>`
let systemInfoEl = document.querySelector('#system-info')
systemInfoEl.innerHTML = systemInfo
}
async function healthCheck() { async function healthCheck() {
try { try {
@ -231,7 +232,7 @@ async function healthCheck() {
break break
} }
if (serverState.devices) { if (serverState.devices) {
setSystemInfo(serverState.devices) setDeviceInfo(serverState.devices)
} }
serverState.time = Date.now() serverState.time = Date.now()
} catch (e) { } catch (e) {
@ -887,24 +888,26 @@ function createTask(task) {
task['progressBar'] = taskEntry.querySelector('.progress-bar') task['progressBar'] = taskEntry.querySelector('.progress-bar')
task['stopTask'] = taskEntry.querySelector('.stopTask') task['stopTask'] = taskEntry.querySelector('.stopTask')
task['stopTask'].addEventListener('click', async function(e) { task['stopTask'].addEventListener('click', (e) => {
e.stopPropagation() let question = (task['isProcessing'] ? "Stop this task?" : "Remove this task?")
if (task['isProcessing']) { shiftOrConfirm(e, question, async function(e) {
task.isProcessing = false if (task['isProcessing']) {
task.progressBar.classList.remove("active") task.isProcessing = false
try { task.progressBar.classList.remove("active")
let res = await fetch('/image/stop?session_id=' + sessionId) try {
} catch (e) { let res = await fetch('/image/stop?session_id=' + sessionId)
console.log(e) } catch (e) {
} console.log(e)
} else { }
let idx = taskQueue.indexOf(task) } else {
if (idx >= 0) { let idx = taskQueue.indexOf(task)
taskQueue.splice(idx, 1) if (idx >= 0) {
} taskQueue.splice(idx, 1)
}
taskEntry.remove() removeTask(taskEntry)
} }
})
}) })
task['useSettings'] = taskEntry.querySelector('.useSettings') task['useSettings'] = taskEntry.querySelector('.useSettings')
@ -934,10 +937,10 @@ function getPrompts() {
prompts = prompts.filter(prompt => prompt !== '') prompts = prompts.filter(prompt => prompt !== '')
if (activeTags.length > 0) { if (activeTags.length > 0) {
const promptTags = activeTags.map(x => x.name).join(", ") const promptTags = activeTags.map(x => x.name).join(", ")
prompts = prompts.map((prompt) => `${prompt}, ${promptTags}`) prompts = prompts.map((prompt) => `${prompt}, ${promptTags}`)
} }
let promptsToMake = applySetOperator(prompts) let promptsToMake = applySetOperator(prompts)
promptsToMake = applyPermuteOperator(promptsToMake) promptsToMake = applyPermuteOperator(promptsToMake)
@ -1047,21 +1050,25 @@ async function stopAllTasks() {
} }
} }
clearAllPreviewsBtn.addEventListener('click', async function() { function removeTask(taskToRemove) {
taskToRemove.remove()
if (document.querySelector('.imageTaskContainer') === null) {
previewTools.style.display = 'none'
initialText.style.display = 'block'
}
}
clearAllPreviewsBtn.addEventListener('click', (e) => { shiftOrConfirm(e, "Clear all the results and tasks in this window?", async function() {
await stopAllTasks() await stopAllTasks()
let taskEntries = document.querySelectorAll('.imageTaskContainer') let taskEntries = document.querySelectorAll('.imageTaskContainer')
taskEntries.forEach(task => { taskEntries.forEach(removeTask)
task.remove() })})
})
previewTools.style.display = 'none' stopImageBtn.addEventListener('click', (e) => { shiftOrConfirm(e, "Stop all the tasks?", async function(e) {
initialText.style.display = 'block'
})
stopImageBtn.addEventListener('click', async function() {
await stopAllTasks() await stopAllTasks()
}) })})
widthField.addEventListener('change', onDimensionChange) widthField.addEventListener('change', onDimensionChange)
heightField.addEventListener('change', onDimensionChange) heightField.addEventListener('change', onDimensionChange)

View File

@ -5,9 +5,9 @@
*/ */
var ParameterType = { var ParameterType = {
checkbox: "checkbox", checkbox: "checkbox",
select: "select", select: "select",
select_multiple: "select_multiple", select_multiple: "select_multiple",
custom: "custom", custom: "custom",
}; };
/** /**
@ -23,166 +23,182 @@
/** @type {Array.<Parameter>} */ /** @type {Array.<Parameter>} */
var PARAMETERS = [ var PARAMETERS = [
{ {
id: "theme", id: "theme",
type: ParameterType.select, type: ParameterType.select,
label: "Theme", label: "Theme",
default: "theme-default", default: "theme-default",
note: "customize the look and feel of the ui", note: "customize the look and feel of the ui",
options: [ // Note: options expanded dynamically options: [ // Note: options expanded dynamically
{ {
value: "theme-default", value: "theme-default",
label: "Default" label: "Default"
} }
], ],
icon: "fa-palette" icon: "fa-palette"
}, },
{ {
id: "save_to_disk", id: "save_to_disk",
type: ParameterType.checkbox, type: ParameterType.checkbox,
label: "Auto-Save Images", label: "Auto-Save Images",
note: "automatically saves images to the specified location", note: "automatically saves images to the specified location",
icon: "fa-download", icon: "fa-download",
default: false, default: false,
}, },
{ {
id: "diskPath", id: "diskPath",
type: ParameterType.custom, type: ParameterType.custom,
label: "Save Location", label: "Save Location",
render: (parameter) => { render: (parameter) => {
return `<input id="${parameter.id}" name="${parameter.id}" size="30" disabled>` return `<input id="${parameter.id}" name="${parameter.id}" size="30" disabled>`
} }
}, },
{ {
id: "sound_toggle", id: "sound_toggle",
type: ParameterType.checkbox, type: ParameterType.checkbox,
label: "Enable Sound", label: "Enable Sound",
note: "plays a sound on task completion", note: "plays a sound on task completion",
icon: "fa-volume-low", icon: "fa-volume-low",
default: true, default: true,
}, },
{ {
id: "ui_open_browser_on_start", id: "ui_open_browser_on_start",
type: ParameterType.checkbox, type: ParameterType.checkbox,
label: "Open browser on startup", label: "Open browser on startup",
note: "starts the default browser on startup", note: "starts the default browser on startup",
icon: "fa-window-restore", icon: "fa-window-restore",
default: true, default: true,
}, },
{ {
id: "turbo", id: "turbo",
type: ParameterType.checkbox, type: ParameterType.checkbox,
label: "Turbo Mode", label: "Turbo Mode",
note: "generates images faster, but uses an additional 1 GB of GPU memory", note: "generates images faster, but uses an additional 1 GB of GPU memory",
icon: "fa-forward", icon: "fa-forward",
default: true, default: true,
}, },
{ {
id: "use_cpu", id: "use_cpu",
type: ParameterType.checkbox, type: ParameterType.checkbox,
label: "Use CPU (not GPU)", label: "Use CPU (not GPU)",
note: "warning: this will be *very* slow", note: "warning: this will be *very* slow",
icon: "fa-microchip", icon: "fa-microchip",
default: false, default: false,
}, },
{ {
id: "auto_pick_gpus", id: "auto_pick_gpus",
type: ParameterType.checkbox, type: ParameterType.checkbox,
label: "Automatically pick the GPUs (experimental)", label: "Automatically pick the GPUs (experimental)",
default: false, default: false,
}, },
{ {
id: "use_gpus", id: "use_gpus",
type: ParameterType.select_multiple, type: ParameterType.select_multiple,
label: "GPUs to use (experimental)", label: "GPUs to use (experimental)",
note: "to process in parallel", note: "to process in parallel",
default: false, default: false,
}, },
{ {
id: "use_full_precision", id: "use_full_precision",
type: ParameterType.checkbox, type: ParameterType.checkbox,
label: "Use Full Precision", label: "Use Full Precision",
note: "for GPU-only. warning: this will consume more VRAM", note: "for GPU-only. warning: this will consume more VRAM",
icon: "fa-crosshairs", icon: "fa-crosshairs",
default: false, default: false,
}, },
{ {
id: "auto_save_settings", id: "auto_save_settings",
type: ParameterType.checkbox, type: ParameterType.checkbox,
label: "Auto-Save Settings", label: "Auto-Save Settings",
note: "restores settings on browser load", note: "restores settings on browser load",
icon: "fa-gear", icon: "fa-gear",
default: true, default: true,
}, },
{ {
id: "listen_to_network", id: "confirm_dangerous_actions",
type: ParameterType.checkbox, type: ParameterType.checkbox,
label: "Make Stable Diffusion available on your network", label: "Confirm dangerous actions",
note: "Other devices on your network can access this web page", note: "Actions that might lead to data loss must either be clicked with the shift key pressed, or confirmed in an 'Are you sure?' dialog",
icon: "fa-network-wired", icon: "fa-check-double",
default: true, default: true,
}, },
{ {
id: "listen_port", id: "listen_to_network",
type: ParameterType.custom, type: ParameterType.checkbox,
label: "Network port", label: "Make Stable Diffusion available on your network",
note: "Port that this server listens to. The '9000' part in 'http://localhost:9000'", note: "Other devices on your network can access this web page",
icon: "fa-anchor", icon: "fa-network-wired",
render: (parameter) => { default: true,
return `<input id="${parameter.id}" name="${parameter.id}" size="6" value="9000" onkeypress="preventNonNumericalInput(event)">` },
} {
}, id: "listen_port",
{ type: ParameterType.custom,
id: "use_beta_channel", label: "Network port",
type: ParameterType.checkbox, note: "Port that this server listens to. The '9000' part in 'http://localhost:9000'",
label: "Beta channel", icon: "fa-anchor",
note: "Get the latest features immediately (but could be less stable). Please restart the program after changing this.", render: (parameter) => {
icon: "fa-fire", return `<input id="${parameter.id}" name="${parameter.id}" size="6" value="9000" onkeypress="preventNonNumericalInput(event)">`
default: false, }
}, },
{
id: "test_sd2",
type: ParameterType.checkbox,
label: "Test SD 2.0",
note: "Experimental! High memory usage! GPU-only! Not the final version! Please restart the program after changing this.",
icon: "fa-fire",
default: false,
},
{
id: "use_beta_channel",
type: ParameterType.checkbox,
label: "Beta channel",
note: "Get the latest features immediately (but could be less stable). Please restart the program after changing this.",
icon: "fa-fire",
default: false,
},
]; ];
function getParameterSettingsEntry(id) { function getParameterSettingsEntry(id) {
let parameter = PARAMETERS.filter(p => p.id === id) let parameter = PARAMETERS.filter(p => p.id === id)
if (parameter.length === 0) { if (parameter.length === 0) {
return return
} }
return parameter[0].settingsEntry return parameter[0].settingsEntry
} }
function getParameterElement(parameter) { function getParameterElement(parameter) {
switch (parameter.type) { switch (parameter.type) {
case ParameterType.checkbox: case ParameterType.checkbox:
var is_checked = parameter.default ? " checked" : ""; var is_checked = parameter.default ? " checked" : "";
return `<input id="${parameter.id}" name="${parameter.id}"${is_checked} type="checkbox">` return `<input id="${parameter.id}" name="${parameter.id}"${is_checked} type="checkbox">`
case ParameterType.select: case ParameterType.select:
case ParameterType.select_multiple: case ParameterType.select_multiple:
var options = (parameter.options || []).map(option => `<option value="${option.value}">${option.label}</option>`).join("") var options = (parameter.options || []).map(option => `<option value="${option.value}">${option.label}</option>`).join("")
var multiple = (parameter.type == ParameterType.select_multiple ? 'multiple' : '') var multiple = (parameter.type == ParameterType.select_multiple ? 'multiple' : '')
return `<select id="${parameter.id}" name="${parameter.id}" ${multiple}>${options}</select>` return `<select id="${parameter.id}" name="${parameter.id}" ${multiple}>${options}</select>`
case ParameterType.custom: case ParameterType.custom:
return parameter.render(parameter) return parameter.render(parameter)
default: default:
console.error(`Invalid type for parameter ${parameter.id}`); console.error(`Invalid type for parameter ${parameter.id}`);
return "ERROR: Invalid Type" return "ERROR: Invalid Type"
} }
} }
let parametersTable = document.querySelector("#system-settings .parameters-table") let parametersTable = document.querySelector("#system-settings .parameters-table")
/* fill in the system settings popup table */ /* fill in the system settings popup table */
function initParameters() { function initParameters() {
PARAMETERS.forEach(parameter => { PARAMETERS.forEach(parameter => {
var element = getParameterElement(parameter) var element = getParameterElement(parameter)
var note = parameter.note ? `<small>${parameter.note}</small>` : ""; var note = parameter.note ? `<small>${parameter.note}</small>` : "";
var icon = parameter.icon ? `<i class="fa ${parameter.icon}"></i>` : ""; var icon = parameter.icon ? `<i class="fa ${parameter.icon}"></i>` : "";
var newrow = document.createElement('div') var newrow = document.createElement('div')
newrow.innerHTML = ` newrow.innerHTML = `
<div>${icon}</div> <div>${icon}</div>
<div><label for="${parameter.id}">${parameter.label}</label>${note}</div> <div><label for="${parameter.id}">${parameter.label}</label>${note}</div>
<div>${element}</div>` <div>${element}</div>`
parametersTable.appendChild(newrow) parametersTable.appendChild(newrow)
parameter.settingsEntry = newrow parameter.settingsEntry = newrow
}) })
} }
initParameters() initParameters()
@ -196,11 +212,14 @@ let saveToDiskField = document.querySelector('#save_to_disk')
let diskPathField = document.querySelector('#diskPath') let diskPathField = document.querySelector('#diskPath')
let listenToNetworkField = document.querySelector("#listen_to_network") let listenToNetworkField = document.querySelector("#listen_to_network")
let listenPortField = document.querySelector("#listen_port") let listenPortField = document.querySelector("#listen_port")
let testSD2Field = document.querySelector("#test_sd2")
let useBetaChannelField = document.querySelector("#use_beta_channel") let useBetaChannelField = document.querySelector("#use_beta_channel")
let uiOpenBrowserOnStartField = document.querySelector("#ui_open_browser_on_start") let uiOpenBrowserOnStartField = document.querySelector("#ui_open_browser_on_start")
let confirmDangerousActionsField = document.querySelector("#confirm_dangerous_actions")
let saveSettingsBtn = document.querySelector('#save-system-settings-btn') let saveSettingsBtn = document.querySelector('#save-system-settings-btn')
async function changeAppConfig(configDelta) { async function changeAppConfig(configDelta) {
try { try {
let res = await fetch('/app_config', { let res = await fetch('/app_config', {
@ -230,12 +249,18 @@ async function getAppConfig() {
if (config.ui && config.ui.open_browser_on_start === false) { if (config.ui && config.ui.open_browser_on_start === false) {
uiOpenBrowserOnStartField.checked = false uiOpenBrowserOnStartField.checked = false
} }
if (config.net && config.net.listen_to_network === false) { if ('test_sd2' in config) {
listenToNetworkField.checked = false testSD2Field.checked = config['test_sd2']
} }
if (config.net && config.net.listen_port !== undefined) {
listenPortField.value = config.net.listen_port let testSD2SettingEntry = getParameterSettingsEntry('test_sd2')
} testSD2SettingEntry.style.display = (config.update_branch === 'beta' ? '' : 'none')
if (config.net && config.net.listen_to_network === false) {
listenToNetworkField.checked = false
}
if (config.net && config.net.listen_port !== undefined) {
listenPortField.value = config.net.listen_port
}
console.log('get config status response', config) console.log('get config status response', config)
} catch (e) { } catch (e) {
@ -263,7 +288,6 @@ function getCurrentRenderDeviceSelection() {
useCPUField.addEventListener('click', function() { useCPUField.addEventListener('click', function() {
let gpuSettingEntry = getParameterSettingsEntry('use_gpus') let gpuSettingEntry = getParameterSettingsEntry('use_gpus')
let autoPickGPUSettingEntry = getParameterSettingsEntry('auto_pick_gpus') let autoPickGPUSettingEntry = getParameterSettingsEntry('auto_pick_gpus')
console.log("hello", this.checked);
if (this.checked) { if (this.checked) {
gpuSettingEntry.style.display = 'none' gpuSettingEntry.style.display = 'none'
autoPickGPUSettingEntry.style.display = 'none' autoPickGPUSettingEntry.style.display = 'none'
@ -313,14 +337,45 @@ async function getDiskPath() {
} }
} }
async function getDevices() { function setDeviceInfo(devices) {
let cpu = devices.all.cpu.name
let allGPUs = Object.keys(devices.all).filter(d => d != 'cpu')
let activeGPUs = Object.keys(devices.active)
function ID_TO_TEXT(d) {
let info = devices.all[d]
if ("mem_free" in info && "mem_total" in info) {
return `${info.name} <small>(${d}) (${info.mem_free.toFixed(1)}Gb free / ${info.mem_total.toFixed(1)} Gb total)</small>`
} else {
return `${info.name} <small>(${d}) (no memory info)</small>`
}
}
allGPUs = allGPUs.map(ID_TO_TEXT)
activeGPUs = activeGPUs.map(ID_TO_TEXT)
let systemInfoEl = document.querySelector('#system-info')
systemInfoEl.querySelector('#system-info-cpu').innerText = cpu
systemInfoEl.querySelector('#system-info-gpus-all').innerHTML = allGPUs.join('</br>')
systemInfoEl.querySelector('#system-info-rendering-devices').innerHTML = activeGPUs.join('</br>')
}
function setHostInfo(hosts) {
let port = listenPortField.value
hosts = hosts.map(addr => `http://${addr}:${port}/`).map(url => `<div><a href="${url}">${url}</a></div>`)
document.querySelector('#system-info-server-hosts').innerHTML = hosts.join('')
}
async function getSystemInfo() {
try { try {
let res = await fetch('/get/devices') let res = await fetch('/get/system_info')
if (res.status === 200) { if (res.status === 200) {
res = await res.json() res = await res.json()
let devices = res['devices']
let hosts = res['hosts']
let allDeviceIds = Object.keys(res['all']).filter(d => d !== 'cpu') let allDeviceIds = Object.keys(devices['all']).filter(d => d !== 'cpu')
let activeDeviceIds = Object.keys(res['active']).filter(d => d !== 'cpu') let activeDeviceIds = Object.keys(devices['active']).filter(d => d !== 'cpu')
if (activeDeviceIds.length === 0) { if (activeDeviceIds.length === 0) {
useCPUField.checked = true useCPUField.checked = true
@ -338,11 +393,11 @@ async function getDevices() {
useCPUField.disabled = true // no compatible GPUs, so make the CPU mandatory useCPUField.disabled = true // no compatible GPUs, so make the CPU mandatory
} }
autoPickGPUsField.checked = (res['config'] === 'auto') autoPickGPUsField.checked = (devices['config'] === 'auto')
useGPUsField.innerHTML = '' useGPUsField.innerHTML = ''
allDeviceIds.forEach(device => { allDeviceIds.forEach(device => {
let deviceName = res['all'][device]['name'] let deviceName = devices['all'][device]['name']
let deviceOption = `<option value="${device}">${deviceName} (${device})</option>` let deviceOption = `<option value="${device}">${deviceName} (${device})</option>`
useGPUsField.insertAdjacentHTML('beforeend', deviceOption) useGPUsField.insertAdjacentHTML('beforeend', deviceOption)
}) })
@ -353,6 +408,9 @@ async function getDevices() {
} else { } else {
$('#use_gpus').val(activeDeviceIds) $('#use_gpus').val(activeDeviceIds)
} }
setDeviceInfo(devices)
setHostInfo(hosts)
} }
} catch (e) { } catch (e) {
console.log('error fetching devices', e) console.log('error fetching devices', e)
@ -360,22 +418,23 @@ async function getDevices() {
} }
saveSettingsBtn.addEventListener('click', function() { saveSettingsBtn.addEventListener('click', function() {
let updateBranch = (useBetaChannelField.checked ? 'beta' : 'main') let updateBranch = (useBetaChannelField.checked ? 'beta' : 'main')
if (listenPortField.value == '') { if (listenPortField.value == '') {
alert('The network port field must not be empty.') alert('The network port field must not be empty.')
} else if (listenPortField.value<1 || listenPortField.value>65535) { } else if (listenPortField.value<1 || listenPortField.value>65535) {
alert('The network port must be a number from 1 to 65535') alert('The network port must be a number from 1 to 65535')
} else { } else {
changeAppConfig({ changeAppConfig({
'render_devices': getCurrentRenderDeviceSelection(), 'render_devices': getCurrentRenderDeviceSelection(),
'update_branch': updateBranch, 'update_branch': updateBranch,
'ui_open_browser_on_start': uiOpenBrowserOnStartField.checked, 'ui_open_browser_on_start': uiOpenBrowserOnStartField.checked,
'listen_to_network': listenToNetworkField.checked, 'listen_to_network': listenToNetworkField.checked,
'listen_port': listenPortField.value 'listen_port': listenPortField.value,
}) 'test_sd2': testSD2Field.checked
} })
}
saveSettingsBtn.classList.add('active') saveSettingsBtn.classList.add('active')
asyncDelay(300).then(() => saveSettingsBtn.classList.remove('active')) asyncDelay(300).then(() => saveSettingsBtn.classList.remove('active'))
}) })

View File

@ -60,6 +60,7 @@ function themeFieldChanged() {
body.style = ""; body.style = "";
var theme = THEMES.find(t => t.key == theme_key); var theme = THEMES.find(t => t.key == theme_key);
let borderColor = undefined
if (theme) { if (theme) {
// refresh variables incase they are back referencing // refresh variables incase they are back referencing
Array.from(DEFAULT_THEME.rule.style) Array.from(DEFAULT_THEME.rule.style)
@ -67,7 +68,14 @@ function themeFieldChanged() {
.forEach(cssVariable => { .forEach(cssVariable => {
body.style.setProperty(cssVariable, DEFAULT_THEME.rule.style.getPropertyValue(cssVariable)); body.style.setProperty(cssVariable, DEFAULT_THEME.rule.style.getPropertyValue(cssVariable));
}); });
borderColor = theme.rule.style.getPropertyValue('--input-border-color').trim()
if (!borderColor.startsWith('#')) {
borderColor = theme.rule.style.getPropertyValue('--theme-color-fallback')
}
} else {
borderColor = DEFAULT_THEME.rule.style.getPropertyValue('--theme-color-fallback')
} }
document.querySelector('meta[name="theme-color"]').setAttribute("content", borderColor)
} }
themeField.addEventListener('change', themeFieldChanged); themeField.addEventListener('change', themeFieldChanged);

View File

@ -1,17 +1,17 @@
// https://gomakethings.com/finding-the-next-and-previous-sibling-elements-that-match-a-selector-with-vanilla-js/ // https://gomakethings.com/finding-the-next-and-previous-sibling-elements-that-match-a-selector-with-vanilla-js/
function getNextSibling(elem, selector) { function getNextSibling(elem, selector) {
// Get the next sibling element // Get the next sibling element
var sibling = elem.nextElementSibling var sibling = elem.nextElementSibling
// If there's no selector, return the first sibling // If there's no selector, return the first sibling
if (!selector) return sibling if (!selector) return sibling
// If the sibling matches our selector, use it // If the sibling matches our selector, use it
// If not, jump to the next sibling and continue the loop // If not, jump to the next sibling and continue the loop
while (sibling) { while (sibling) {
if (sibling.matches(selector)) return sibling if (sibling.matches(selector)) return sibling
sibling = sibling.nextElementSibling sibling = sibling.nextElementSibling
} }
} }

View File

@ -0,0 +1,8 @@
{
"name": "Stable Diffusion UI",
"display": "standalone",
"display_override": [
"window-controls-overlay"
],
"theme_color": "#000000"
}

View File

@ -0,0 +1,42 @@
(function () {
"use strict"
var styleSheet = document.createElement("style");
styleSheet.textContent = `
.auto-scroll {
float: right;
}
`;
document.head.appendChild(styleSheet);
const autoScrollControl = document.createElement('div');
autoScrollControl.innerHTML = `<input id="auto_scroll" name="auto_scroll" type="checkbox">
<label for="auto_scroll">Scroll to generated image</label>`
autoScrollControl.className = "auto-scroll"
clearAllPreviewsBtn.parentNode.insertBefore(autoScrollControl, clearAllPreviewsBtn.nextSibling)
prettifyInputs(document);
let autoScroll = document.querySelector("#auto_scroll")
SETTINGS_IDS_LIST.push("auto_scroll")
initSettings()
// observe for changes in the preview pane
var observer = new MutationObserver(function (mutations) {
mutations.forEach(function (mutation) {
if (mutation.target.className == 'img-batch') {
Autoscroll(mutation.target)
}
})
})
observer.observe(document.getElementById('preview'), {
childList: true,
subtree: true
})
function Autoscroll(target) {
if (autoScroll.checked && target !== null) {
target.parentElement.parentElement.parentElement.scrollIntoView();
}
}
})()

View File

@ -18,40 +18,42 @@
let overlays = document.querySelector('#editor-inputs-tags-list').querySelectorAll('.modifier-card-overlay') let overlays = document.querySelector('#editor-inputs-tags-list').querySelectorAll('.modifier-card-overlay')
overlays.forEach (i => { overlays.forEach (i => {
i.onwheel = (e) => { i.onwheel = (e) => {
e.preventDefault() if (e.ctrlKey == true) {
e.preventDefault()
const delta = Math.sign(event.deltaY)
let s = i.parentElement.getElementsByClassName('modifier-card-label')[0].getElementsByTagName("p")[0].innerText const delta = Math.sign(event.deltaY)
if (delta < 0) { let s = i.parentElement.getElementsByClassName('modifier-card-label')[0].getElementsByTagName("p")[0].innerText
// wheel scrolling up if (delta < 0) {
if (s.substring(0, 1) == '[' && s.substring(s.length-1) == ']') { // wheel scrolling up
s = s.substring(1, s.length - 1) if (s.substring(0, 1) == '[' && s.substring(s.length-1) == ']') {
} s = s.substring(1, s.length - 1)
else }
{ else
if (s.substring(0, 10) !== '('.repeat(10) && s.substring(s.length-10) !== ')'.repeat(10)) { {
s = '(' + s + ')' if (s.substring(0, 10) !== '('.repeat(10) && s.substring(s.length-10) !== ')'.repeat(10)) {
s = '(' + s + ')'
}
} }
} }
} else{
else{ // wheel scrolling down
// wheel scrolling down if (s.substring(0, 1) == '(' && s.substring(s.length-1) == ')') {
if (s.substring(0, 1) == '(' && s.substring(s.length-1) == ')') { s = s.substring(1, s.length - 1)
s = s.substring(1, s.length - 1) }
} else
else {
{ if (s.substring(0, 10) !== '['.repeat(10) && s.substring(s.length-10) !== ']'.repeat(10)) {
if (s.substring(0, 10) !== '['.repeat(10) && s.substring(s.length-10) !== ']'.repeat(10)) { s = '[' + s + ']'
s = '[' + s + ']' }
} }
} }
} i.parentElement.getElementsByClassName('modifier-card-label')[0].getElementsByTagName("p")[0].innerText = s
i.parentElement.getElementsByClassName('modifier-card-label')[0].getElementsByTagName("p")[0].innerText = s // update activeTags
// update activeTags for (let it = 0; it < overlays.length; it++) {
for (let it = 0; it < overlays.length; it++) { if (i == overlays[it]) {
if (i == overlays[it]) { activeTags[it].name = s
activeTags[it].name = s break
break }
} }
} }
} }

View File

@ -0,0 +1,84 @@
diff --git a/ldm/models/diffusion/ddim.py b/ldm/models/diffusion/ddim.py
index 27ead0e..6215939 100644
--- a/ldm/models/diffusion/ddim.py
+++ b/ldm/models/diffusion/ddim.py
@@ -100,7 +100,7 @@ class DDIMSampler(object):
size = (batch_size, C, H, W)
print(f'Data shape for DDIM sampling is {size}, eta {eta}')
- samples, intermediates = self.ddim_sampling(conditioning, size,
+ samples = self.ddim_sampling(conditioning, size,
callback=callback,
img_callback=img_callback,
quantize_denoised=quantize_x0,
@@ -117,7 +117,8 @@ class DDIMSampler(object):
dynamic_threshold=dynamic_threshold,
ucg_schedule=ucg_schedule
)
- return samples, intermediates
+ # return samples, intermediates
+ yield from samples
@torch.no_grad()
def ddim_sampling(self, cond, shape,
@@ -168,14 +169,15 @@ class DDIMSampler(object):
unconditional_conditioning=unconditional_conditioning,
dynamic_threshold=dynamic_threshold)
img, pred_x0 = outs
- if callback: callback(i)
- if img_callback: img_callback(pred_x0, i)
+ if callback: yield from callback(i)
+ if img_callback: yield from img_callback(pred_x0, i)
if index % log_every_t == 0 or index == total_steps - 1:
intermediates['x_inter'].append(img)
intermediates['pred_x0'].append(pred_x0)
- return img, intermediates
+ # return img, intermediates
+ yield from img_callback(pred_x0, len(iterator)-1)
@torch.no_grad()
def p_sample_ddim(self, x, c, t, index, repeat_noise=False, use_original_steps=False, quantize_denoised=False,
diff --git a/ldm/models/diffusion/plms.py b/ldm/models/diffusion/plms.py
index 7002a36..0951f39 100644
--- a/ldm/models/diffusion/plms.py
+++ b/ldm/models/diffusion/plms.py
@@ -96,7 +96,7 @@ class PLMSSampler(object):
size = (batch_size, C, H, W)
print(f'Data shape for PLMS sampling is {size}')
- samples, intermediates = self.plms_sampling(conditioning, size,
+ samples = self.plms_sampling(conditioning, size,
callback=callback,
img_callback=img_callback,
quantize_denoised=quantize_x0,
@@ -112,7 +112,8 @@ class PLMSSampler(object):
unconditional_conditioning=unconditional_conditioning,
dynamic_threshold=dynamic_threshold,
)
- return samples, intermediates
+ #return samples, intermediates
+ yield from samples
@torch.no_grad()
def plms_sampling(self, cond, shape,
@@ -165,14 +166,15 @@ class PLMSSampler(object):
old_eps.append(e_t)
if len(old_eps) >= 4:
old_eps.pop(0)
- if callback: callback(i)
- if img_callback: img_callback(pred_x0, i)
+ if callback: yield from callback(i)
+ if img_callback: yield from img_callback(pred_x0, i)
if index % log_every_t == 0 or index == total_steps - 1:
intermediates['x_inter'].append(img)
intermediates['pred_x0'].append(pred_x0)
- return img, intermediates
+ # return img, intermediates
+ yield from img_callback(pred_x0, len(iterator)-1)
@torch.no_grad()
def p_sample_plms(self, x, c, t, index, repeat_noise=False, use_original_steps=False, quantize_denoised=False,

View File

@ -101,7 +101,7 @@ def device_init(thread_data, device):
# Force full precision on 1660 and 1650 NVIDIA cards to avoid creating green images # Force full precision on 1660 and 1650 NVIDIA cards to avoid creating green images
device_name = thread_data.device_name.lower() device_name = thread_data.device_name.lower()
thread_data.force_full_precision = ('nvidia' in device_name or 'geforce' in device_name) and (' 1660' in device_name or ' 1650' in device_name) thread_data.force_full_precision = (('nvidia' in device_name or 'geforce' in device_name) and (' 1660' in device_name or ' 1650' in device_name)) or ('Quadro T2000' in device_name)
if thread_data.force_full_precision: if thread_data.force_full_precision:
print('forcing full precision on NVIDIA 16xx cards, to avoid green images. GPU detected: ', thread_data.device_name) print('forcing full precision on NVIDIA 16xx cards, to avoid green images. GPU detected: ', thread_data.device_name)
# Apply force_full_precision now before models are loaded. # Apply force_full_precision now before models are loaded.

View File

@ -7,6 +7,7 @@ Notes:
import json import json
import os, re import os, re
import traceback import traceback
import queue
import torch import torch
import numpy as np import numpy as np
from gc import collect as gc_collect from gc import collect as gc_collect
@ -21,13 +22,14 @@ from torch import autocast
from contextlib import nullcontext from contextlib import nullcontext
from einops import rearrange, repeat from einops import rearrange, repeat
from ldm.util import instantiate_from_config from ldm.util import instantiate_from_config
from optimizedSD.optimUtils import split_weighted_subprompts
from transformers import logging from transformers import logging
from gfpgan import GFPGANer from gfpgan import GFPGANer
from basicsr.archs.rrdbnet_arch import RRDBNet from basicsr.archs.rrdbnet_arch import RRDBNet
from realesrgan import RealESRGANer from realesrgan import RealESRGANer
from threading import Lock
import uuid import uuid
logging.set_verbosity_error() logging.set_verbosity_error()
@ -35,7 +37,7 @@ logging.set_verbosity_error()
# consts # consts
config_yaml = "optimizedSD/v1-inference.yaml" config_yaml = "optimizedSD/v1-inference.yaml"
filename_regex = re.compile('[^a-zA-Z0-9]') filename_regex = re.compile('[^a-zA-Z0-9]')
force_gfpgan_to_cuda0 = True # workaround: gfpgan currently works only on cuda:0 gfpgan_temp_device_lock = Lock() # workaround: gfpgan currently can only start on one device at a time.
# api stuff # api stuff
from sd_internal import device_manager from sd_internal import device_manager
@ -76,8 +78,24 @@ def thread_init(device):
thread_data.force_full_precision = False thread_data.force_full_precision = False
thread_data.reduced_memory = True thread_data.reduced_memory = True
thread_data.test_sd2 = isSD2()
device_manager.device_init(thread_data, device) device_manager.device_init(thread_data, device)
# temp hack, will remove soon
def isSD2():
try:
SD_UI_DIR = os.getenv('SD_UI_PATH', None)
CONFIG_DIR = os.path.abspath(os.path.join(SD_UI_DIR, '..', 'scripts'))
config_json_path = os.path.join(CONFIG_DIR, 'config.json')
if not os.path.exists(config_json_path):
return False
with open(config_json_path, 'r', encoding='utf-8') as f:
config = json.load(f)
return config.get('test_sd2', False)
except Exception as e:
return False
def load_model_ckpt(): def load_model_ckpt():
if not thread_data.ckpt_file: raise ValueError(f'Thread ckpt_file is undefined.') if not thread_data.ckpt_file: raise ValueError(f'Thread ckpt_file is undefined.')
if not os.path.exists(thread_data.ckpt_file + '.ckpt'): raise FileNotFoundError(f'Cannot find {thread_data.ckpt_file}.ckpt') if not os.path.exists(thread_data.ckpt_file + '.ckpt'): raise FileNotFoundError(f'Cannot find {thread_data.ckpt_file}.ckpt')
@ -92,6 +110,13 @@ def load_model_ckpt():
thread_data.precision = 'full' thread_data.precision = 'full'
print('loading', thread_data.ckpt_file + '.ckpt', 'to device', thread_data.device, 'using precision', thread_data.precision) print('loading', thread_data.ckpt_file + '.ckpt', 'to device', thread_data.device, 'using precision', thread_data.precision)
if thread_data.test_sd2:
load_model_ckpt_sd2()
else:
load_model_ckpt_sd1()
def load_model_ckpt_sd1():
sd = load_model_from_config(thread_data.ckpt_file + '.ckpt') sd = load_model_from_config(thread_data.ckpt_file + '.ckpt')
li, lo = [], [] li, lo = [], []
for key, value in sd.items(): for key, value in sd.items():
@ -185,6 +210,38 @@ def load_model_ckpt():
modelFS.device: {thread_data.modelFS.device} modelFS.device: {thread_data.modelFS.device}
using precision: {thread_data.precision}''') using precision: {thread_data.precision}''')
def load_model_ckpt_sd2():
config_file = 'configs/stable-diffusion/v2-inference-v.yaml' if 'sd2_' in thread_data.ckpt_file else "configs/stable-diffusion/v1-inference.yaml"
config = OmegaConf.load(config_file)
verbose = False
sd = load_model_from_config(thread_data.ckpt_file + '.ckpt')
thread_data.model = instantiate_from_config(config.model)
m, u = thread_data.model.load_state_dict(sd, strict=False)
if len(m) > 0 and verbose:
print("missing keys:")
print(m)
if len(u) > 0 and verbose:
print("unexpected keys:")
print(u)
thread_data.model.to(thread_data.device)
thread_data.model.eval()
del sd
if thread_data.device != "cpu" and thread_data.precision == "autocast":
thread_data.model.half()
thread_data.model_is_half = True
thread_data.model_fs_is_half = True
else:
thread_data.model_is_half = False
thread_data.model_fs_is_half = False
print(f'''loaded model
model file: {thread_data.ckpt_file}.ckpt
using precision: {thread_data.precision}''')
def unload_filters(): def unload_filters():
if thread_data.model_gfpgan is not None: if thread_data.model_gfpgan is not None:
if thread_data.device != 'cpu': thread_data.model_gfpgan.gfpgan.to('cpu') if thread_data.device != 'cpu': thread_data.model_gfpgan.gfpgan.to('cpu')
@ -204,10 +261,11 @@ def unload_models():
if thread_data.model is not None: if thread_data.model is not None:
print('Unloading models...') print('Unloading models...')
if thread_data.device != 'cpu': if thread_data.device != 'cpu':
thread_data.modelFS.to('cpu') if not thread_data.test_sd2:
thread_data.modelCS.to('cpu') thread_data.modelFS.to('cpu')
thread_data.model.model1.to("cpu") thread_data.modelCS.to('cpu')
thread_data.model.model2.to("cpu") thread_data.model.model1.to("cpu")
thread_data.model.model2.to("cpu")
del thread_data.model del thread_data.model
del thread_data.modelCS del thread_data.modelCS
@ -253,12 +311,6 @@ def move_to_cpu(model):
def load_model_gfpgan(): def load_model_gfpgan():
if thread_data.gfpgan_file is None: raise ValueError(f'Thread gfpgan_file is undefined.') if thread_data.gfpgan_file is None: raise ValueError(f'Thread gfpgan_file is undefined.')
# hack for a bug in facexlib: https://github.com/xinntao/facexlib/pull/19/files
from facexlib.detection import retinaface
retinaface.device = torch.device(thread_data.device)
print('forced retinaface.device to', thread_data.device)
model_path = thread_data.gfpgan_file + ".pth" model_path = thread_data.gfpgan_file + ".pth"
thread_data.model_gfpgan = GFPGANer(device=torch.device(thread_data.device), model_path=model_path, upscale=1, arch='clean', channel_multiplier=2, bg_upsampler=None) thread_data.model_gfpgan = GFPGANer(device=torch.device(thread_data.device), model_path=model_path, upscale=1, arch='clean', channel_multiplier=2, bg_upsampler=None)
print('loaded', thread_data.gfpgan_file, 'to', thread_data.model_gfpgan.device, 'precision', thread_data.precision) print('loaded', thread_data.gfpgan_file, 'to', thread_data.model_gfpgan.device, 'precision', thread_data.precision)
@ -314,15 +366,23 @@ def apply_filters(filter_name, image_data, model_path=None):
image_data.to(thread_data.device) image_data.to(thread_data.device)
if filter_name == 'gfpgan': if filter_name == 'gfpgan':
if model_path is not None and model_path != thread_data.gfpgan_file: # This lock is only ever used here. No need to use timeout for the request. Should never deadlock.
thread_data.gfpgan_file = model_path with gfpgan_temp_device_lock: # Wait for any other devices to complete before starting.
load_model_gfpgan() # hack for a bug in facexlib: https://github.com/xinntao/facexlib/pull/19/files
elif not thread_data.model_gfpgan: from facexlib.detection import retinaface
load_model_gfpgan() retinaface.device = torch.device(thread_data.device)
if thread_data.model_gfpgan is None: raise Exception('Model "gfpgan" not loaded.') print('forced retinaface.device to', thread_data.device)
print('enhance with', thread_data.gfpgan_file, 'on', thread_data.model_gfpgan.device, 'precision', thread_data.precision)
_, _, output = thread_data.model_gfpgan.enhance(image_data[:,:,::-1], has_aligned=False, only_center_face=False, paste_back=True) if model_path is not None and model_path != thread_data.gfpgan_file:
image_data = output[:,:,::-1] thread_data.gfpgan_file = model_path
load_model_gfpgan()
elif not thread_data.model_gfpgan:
load_model_gfpgan()
if thread_data.model_gfpgan is None: raise Exception('Model "gfpgan" not loaded.')
print('enhance with', thread_data.gfpgan_file, 'on', thread_data.model_gfpgan.device, 'precision', thread_data.precision)
_, _, output = thread_data.model_gfpgan.enhance(image_data[:,:,::-1], has_aligned=False, only_center_face=False, paste_back=True)
image_data = output[:,:,::-1]
if filter_name == 'real_esrgan': if filter_name == 'real_esrgan':
if model_path is not None and model_path != thread_data.real_esrgan_file: if model_path is not None and model_path != thread_data.real_esrgan_file:
@ -337,45 +397,73 @@ def apply_filters(filter_name, image_data, model_path=None):
return image_data return image_data
def mk_img(req: Request): def is_model_reload_necessary(req: Request):
# custom model support:
# the req.use_stable_diffusion_model needs to be a valid path
# to the ckpt file (without the extension).
if not os.path.exists(req.use_stable_diffusion_model + '.ckpt'): raise FileNotFoundError(f'Cannot find {req.use_stable_diffusion_model}.ckpt')
needs_model_reload = False
if not thread_data.model or thread_data.ckpt_file != req.use_stable_diffusion_model or thread_data.vae_file != req.use_vae_model:
thread_data.ckpt_file = req.use_stable_diffusion_model
thread_data.vae_file = req.use_vae_model
needs_model_reload = True
if thread_data.device != 'cpu':
if (thread_data.precision == 'autocast' and (req.use_full_precision or not thread_data.model_is_half)) or \
(thread_data.precision == 'full' and not req.use_full_precision and not thread_data.force_full_precision):
thread_data.precision = 'full' if req.use_full_precision else 'autocast'
needs_model_reload = True
return needs_model_reload
def reload_model():
unload_models()
unload_filters()
load_model_ckpt()
def mk_img(req: Request, data_queue: queue.Queue, task_temp_images: list, step_callback):
try: try:
yield from do_mk_img(req) return do_mk_img(req, data_queue, task_temp_images, step_callback)
except Exception as e: except Exception as e:
print(traceback.format_exc()) print(traceback.format_exc())
if thread_data.device != 'cpu': if thread_data.device != 'cpu' and not thread_data.test_sd2:
thread_data.modelFS.to('cpu') thread_data.modelFS.to('cpu')
thread_data.modelCS.to('cpu') thread_data.modelCS.to('cpu')
thread_data.model.model1.to("cpu") thread_data.model.model1.to("cpu")
thread_data.model.model2.to("cpu") thread_data.model.model2.to("cpu")
gc() # Release from memory. gc() # Release from memory.
yield json.dumps({ data_queue.put(json.dumps({
"status": 'failed', "status": 'failed',
"detail": str(e) "detail": str(e)
}) }))
raise e
def update_temp_img(req, x_samples): def update_temp_img(req, x_samples, task_temp_images: list):
partial_images = [] partial_images = []
for i in range(req.num_outputs): for i in range(req.num_outputs):
x_sample_ddim = thread_data.modelFS.decode_first_stage(x_samples[i].unsqueeze(0)) if thread_data.test_sd2:
x_sample_ddim = thread_data.model.decode_first_stage(x_samples[i].unsqueeze(0))
else:
x_sample_ddim = thread_data.modelFS.decode_first_stage(x_samples[i].unsqueeze(0))
x_sample = torch.clamp((x_sample_ddim + 1.0) / 2.0, min=0.0, max=1.0) x_sample = torch.clamp((x_sample_ddim + 1.0) / 2.0, min=0.0, max=1.0)
x_sample = 255.0 * rearrange(x_sample[0].cpu().numpy(), "c h w -> h w c") x_sample = 255.0 * rearrange(x_sample[0].cpu().numpy(), "c h w -> h w c")
x_sample = x_sample.astype(np.uint8) x_sample = x_sample.astype(np.uint8)
img = Image.fromarray(x_sample) img = Image.fromarray(x_sample)
buf = BytesIO() buf = img_to_buffer(img, output_format='JPEG')
img.save(buf, format='JPEG')
buf.seek(0)
del img, x_sample, x_sample_ddim del img, x_sample, x_sample_ddim
# don't delete x_samples, it is used in the code that called this callback # don't delete x_samples, it is used in the code that called this callback
thread_data.temp_images[str(req.session_id) + '/' + str(i)] = buf thread_data.temp_images[str(req.session_id) + '/' + str(i)] = buf
task_temp_images[i] = buf
partial_images.append({'path': f'/image/tmp/{req.session_id}/{i}'}) partial_images.append({'path': f'/image/tmp/{req.session_id}/{i}'})
return partial_images return partial_images
# Build and return the apropriate generator for do_mk_img # Build and return the apropriate generator for do_mk_img
def get_image_progress_generator(req, extra_props=None): def get_image_progress_generator(req, data_queue: queue.Queue, task_temp_images: list, step_callback, extra_props=None):
if not req.stream_progress_updates: if not req.stream_progress_updates:
def empty_callback(x_samples, i): return x_samples def empty_callback(x_samples, i): return x_samples
return empty_callback return empty_callback
@ -394,15 +482,17 @@ def get_image_progress_generator(req, extra_props=None):
progress.update(extra_props) progress.update(extra_props)
if req.stream_image_progress and i % 5 == 0: if req.stream_image_progress and i % 5 == 0:
progress['output'] = update_temp_img(req, x_samples) progress['output'] = update_temp_img(req, x_samples, task_temp_images)
yield json.dumps(progress) data_queue.put(json.dumps(progress))
step_callback()
if thread_data.stop_processing: if thread_data.stop_processing:
raise UserInitiatedStop("User requested that we stop processing") raise UserInitiatedStop("User requested that we stop processing")
return img_callback return img_callback
def do_mk_img(req: Request): def do_mk_img(req: Request, data_queue: queue.Queue, task_temp_images: list, step_callback):
thread_data.stop_processing = False thread_data.stop_processing = False
res = Response() res = Response()
@ -411,29 +501,7 @@ def do_mk_img(req: Request):
thread_data.temp_images.clear() thread_data.temp_images.clear()
# custom model support: if thread_data.turbo != req.turbo and not thread_data.test_sd2:
# the req.use_stable_diffusion_model needs to be a valid path
# to the ckpt file (without the extension).
if not os.path.exists(req.use_stable_diffusion_model + '.ckpt'): raise FileNotFoundError(f'Cannot find {req.use_stable_diffusion_model}.ckpt')
needs_model_reload = False
if not thread_data.model or thread_data.ckpt_file != req.use_stable_diffusion_model or thread_data.vae_file != req.use_vae_model:
thread_data.ckpt_file = req.use_stable_diffusion_model
thread_data.vae_file = req.use_vae_model
needs_model_reload = True
if thread_data.device != 'cpu':
if (thread_data.precision == 'autocast' and (req.use_full_precision or not thread_data.model_is_half)) or \
(thread_data.precision == 'full' and not req.use_full_precision and not thread_data.force_full_precision):
thread_data.precision = 'full' if req.use_full_precision else 'autocast'
needs_model_reload = True
if needs_model_reload:
unload_models()
unload_filters()
load_model_ckpt()
if thread_data.turbo != req.turbo:
thread_data.turbo = req.turbo thread_data.turbo = req.turbo
thread_data.model.turbo = req.turbo thread_data.model.turbo = req.turbo
@ -478,10 +546,14 @@ def do_mk_img(req: Request):
if thread_data.device != "cpu" and thread_data.precision == "autocast": if thread_data.device != "cpu" and thread_data.precision == "autocast":
init_image = init_image.half() init_image = init_image.half()
thread_data.modelFS.to(thread_data.device) if not thread_data.test_sd2:
thread_data.modelFS.to(thread_data.device)
init_image = repeat(init_image, '1 ... -> b ...', b=batch_size) init_image = repeat(init_image, '1 ... -> b ...', b=batch_size)
init_latent = thread_data.modelFS.get_first_stage_encoding(thread_data.modelFS.encode_first_stage(init_image)) # move to latent space if thread_data.test_sd2:
init_latent = thread_data.model.get_first_stage_encoding(thread_data.model.encode_first_stage(init_image)) # move to latent space
else:
init_latent = thread_data.modelFS.get_first_stage_encoding(thread_data.modelFS.encode_first_stage(init_image)) # move to latent space
if req.mask is not None: if req.mask is not None:
mask = load_mask(req.mask, req.width, req.height, init_latent.shape[2], init_latent.shape[3], True).to(thread_data.device) mask = load_mask(req.mask, req.width, req.height, init_latent.shape[2], init_latent.shape[3], True).to(thread_data.device)
@ -493,7 +565,8 @@ def do_mk_img(req: Request):
# Send to CPU and wait until complete. # Send to CPU and wait until complete.
# wait_model_move_to(thread_data.modelFS, 'cpu') # wait_model_move_to(thread_data.modelFS, 'cpu')
move_to_cpu(thread_data.modelFS) if not thread_data.test_sd2:
move_to_cpu(thread_data.modelFS)
assert 0. <= req.prompt_strength <= 1., 'can only work with strength in [0.0, 1.0]' assert 0. <= req.prompt_strength <= 1., 'can only work with strength in [0.0, 1.0]'
t_enc = int(req.prompt_strength * req.num_inference_steps) t_enc = int(req.prompt_strength * req.num_inference_steps)
@ -509,11 +582,14 @@ def do_mk_img(req: Request):
for prompts in tqdm(data, desc="data"): for prompts in tqdm(data, desc="data"):
with precision_scope("cuda"): with precision_scope("cuda"):
if thread_data.reduced_memory: if thread_data.reduced_memory and not thread_data.test_sd2:
thread_data.modelCS.to(thread_data.device) thread_data.modelCS.to(thread_data.device)
uc = None uc = None
if req.guidance_scale != 1.0: if req.guidance_scale != 1.0:
uc = thread_data.modelCS.get_learned_conditioning(batch_size * [req.negative_prompt]) if thread_data.test_sd2:
uc = thread_data.model.get_learned_conditioning(batch_size * [req.negative_prompt])
else:
uc = thread_data.modelCS.get_learned_conditioning(batch_size * [req.negative_prompt])
if isinstance(prompts, tuple): if isinstance(prompts, tuple):
prompts = list(prompts) prompts = list(prompts)
@ -526,15 +602,21 @@ def do_mk_img(req: Request):
weight = weights[i] weight = weights[i]
# if not skip_normalize: # if not skip_normalize:
weight = weight / totalWeight weight = weight / totalWeight
c = torch.add(c, thread_data.modelCS.get_learned_conditioning(subprompts[i]), alpha=weight) if thread_data.test_sd2:
c = torch.add(c, thread_data.model.get_learned_conditioning(subprompts[i]), alpha=weight)
else:
c = torch.add(c, thread_data.modelCS.get_learned_conditioning(subprompts[i]), alpha=weight)
else: else:
c = thread_data.modelCS.get_learned_conditioning(prompts) if thread_data.test_sd2:
c = thread_data.model.get_learned_conditioning(prompts)
else:
c = thread_data.modelCS.get_learned_conditioning(prompts)
if thread_data.reduced_memory: if thread_data.reduced_memory and not thread_data.test_sd2:
thread_data.modelFS.to(thread_data.device) thread_data.modelFS.to(thread_data.device)
n_steps = req.num_inference_steps if req.init_image is None else t_enc n_steps = req.num_inference_steps if req.init_image is None else t_enc
img_callback = get_image_progress_generator(req, {"total_steps": n_steps}) img_callback = get_image_progress_generator(req, data_queue, task_temp_images, step_callback, {"total_steps": n_steps})
# run the handler # run the handler
try: try:
@ -542,14 +624,7 @@ def do_mk_img(req: Request):
if handler == _txt2img: if handler == _txt2img:
x_samples = _txt2img(req.width, req.height, req.num_outputs, req.num_inference_steps, req.guidance_scale, None, opt_C, opt_f, opt_ddim_eta, c, uc, opt_seed, img_callback, mask, req.sampler) x_samples = _txt2img(req.width, req.height, req.num_outputs, req.num_inference_steps, req.guidance_scale, None, opt_C, opt_f, opt_ddim_eta, c, uc, opt_seed, img_callback, mask, req.sampler)
else: else:
x_samples = _img2img(init_latent, t_enc, batch_size, req.guidance_scale, c, uc, req.num_inference_steps, opt_ddim_eta, opt_seed, img_callback, mask) x_samples = _img2img(init_latent, t_enc, batch_size, req.guidance_scale, c, uc, req.num_inference_steps, opt_ddim_eta, opt_seed, img_callback, mask, opt_C, req.height, req.width, opt_f)
if req.stream_progress_updates:
yield from x_samples
if hasattr(thread_data, 'partial_x_samples'):
if thread_data.partial_x_samples is not None:
x_samples = thread_data.partial_x_samples
del thread_data.partial_x_samples
except UserInitiatedStop: except UserInitiatedStop:
if not hasattr(thread_data, 'partial_x_samples'): if not hasattr(thread_data, 'partial_x_samples'):
continue continue
@ -562,7 +637,10 @@ def do_mk_img(req: Request):
print("decoding images") print("decoding images")
img_data = [None] * batch_size img_data = [None] * batch_size
for i in range(batch_size): for i in range(batch_size):
x_samples_ddim = thread_data.modelFS.decode_first_stage(x_samples[i].unsqueeze(0)) if thread_data.test_sd2:
x_samples_ddim = thread_data.model.decode_first_stage(x_samples[i].unsqueeze(0))
else:
x_samples_ddim = thread_data.modelFS.decode_first_stage(x_samples[i].unsqueeze(0))
x_sample = torch.clamp((x_samples_ddim + 1.0) / 2.0, min=0.0, max=1.0) x_sample = torch.clamp((x_samples_ddim + 1.0) / 2.0, min=0.0, max=1.0)
x_sample = 255.0 * rearrange(x_sample[0].cpu().numpy(), "c h w -> h w c") x_sample = 255.0 * rearrange(x_sample[0].cpu().numpy(), "c h w -> h w c")
x_sample = x_sample.astype(np.uint8) x_sample = x_sample.astype(np.uint8)
@ -591,9 +669,11 @@ def do_mk_img(req: Request):
save_metadata(meta_out_path, req, prompts[0], opt_seed) save_metadata(meta_out_path, req, prompts[0], opt_seed)
if return_orig_img: if return_orig_img:
img_str = img_to_base64_str(img, req.output_format) img_buffer = img_to_buffer(img, req.output_format)
img_str = buffer_to_base64_str(img_buffer, req.output_format)
res_image_orig = ResponseImage(data=img_str, seed=opt_seed) res_image_orig = ResponseImage(data=img_str, seed=opt_seed)
res.images.append(res_image_orig) res.images.append(res_image_orig)
task_temp_images[i] = img_buffer
if req.save_to_disk_path is not None: if req.save_to_disk_path is not None:
res_image_orig.path_abs = img_out_path res_image_orig.path_abs = img_out_path
@ -609,9 +689,11 @@ def do_mk_img(req: Request):
filters_applied.append(req.use_upscale) filters_applied.append(req.use_upscale)
if (len(filters_applied) > 0): if (len(filters_applied) > 0):
filtered_image = Image.fromarray(img_data[i]) filtered_image = Image.fromarray(img_data[i])
filtered_img_data = img_to_base64_str(filtered_image, req.output_format) filtered_buffer = img_to_buffer(filtered_image, req.output_format)
filtered_img_data = buffer_to_base64_str(filtered_buffer, req.output_format)
response_image = ResponseImage(data=filtered_img_data, seed=opt_seed) response_image = ResponseImage(data=filtered_img_data, seed=opt_seed)
res.images.append(response_image) res.images.append(response_image)
task_temp_images[i] = filtered_buffer
if req.save_to_disk_path is not None: if req.save_to_disk_path is not None:
filtered_img_out_path = get_base_path(req.save_to_disk_path, req.session_id, prompts[0], img_id, req.output_format, "_".join(filters_applied)) filtered_img_out_path = get_base_path(req.save_to_disk_path, req.session_id, prompts[0], img_id, req.output_format, "_".join(filters_applied))
save_image(filtered_image, filtered_img_out_path) save_image(filtered_image, filtered_img_out_path)
@ -622,14 +704,18 @@ def do_mk_img(req: Request):
# if thread_data.reduced_memory: # if thread_data.reduced_memory:
# unload_filters() # unload_filters()
move_to_cpu(thread_data.modelFS) if not thread_data.test_sd2:
move_to_cpu(thread_data.modelFS)
del img_data del img_data
gc() gc()
if thread_data.device != 'cpu': if thread_data.device != 'cpu':
print(f'memory_final = {round(torch.cuda.memory_allocated(thread_data.device) / 1e6, 2)}Mb') print(f'memory_final = {round(torch.cuda.memory_allocated(thread_data.device) / 1e6, 2)}Mb')
print('Task completed') print('Task completed')
yield json.dumps(res.json()) res = res.json()
data_queue.put(json.dumps(res))
return res
def save_image(img, img_out_path): def save_image(img, img_out_path):
try: try:
@ -664,51 +750,109 @@ def _txt2img(opt_W, opt_H, opt_n_samples, opt_ddim_steps, opt_scale, start_code,
# Send to CPU and wait until complete. # Send to CPU and wait until complete.
# wait_model_move_to(thread_data.modelCS, 'cpu') # wait_model_move_to(thread_data.modelCS, 'cpu')
move_to_cpu(thread_data.modelCS) if not thread_data.test_sd2:
move_to_cpu(thread_data.modelCS)
if sampler_name == 'ddim': if thread_data.test_sd2 and sampler_name not in ('plms', 'ddim'):
thread_data.model.make_schedule(ddim_num_steps=opt_ddim_steps, ddim_eta=opt_ddim_eta, verbose=False) raise Exception('Only plms and ddim samplers are supported right now, in SD 2.0')
samples_ddim = thread_data.model.sample(
S=opt_ddim_steps,
conditioning=c,
seed=opt_seed,
shape=shape,
verbose=False,
unconditional_guidance_scale=opt_scale,
unconditional_conditioning=uc,
eta=opt_ddim_eta,
x_T=start_code,
img_callback=img_callback,
mask=mask,
sampler = sampler_name,
)
yield from samples_ddim
def _img2img(init_latent, t_enc, batch_size, opt_scale, c, uc, opt_ddim_steps, opt_ddim_eta, opt_seed, img_callback, mask): # samples, _ = sampler.sample(S=opt.steps,
# conditioning=c,
# batch_size=opt.n_samples,
# shape=shape,
# verbose=False,
# unconditional_guidance_scale=opt.scale,
# unconditional_conditioning=uc,
# eta=opt.ddim_eta,
# x_T=start_code)
if thread_data.test_sd2:
from ldm.models.diffusion.ddim import DDIMSampler
from ldm.models.diffusion.plms import PLMSSampler
shape = [opt_C, opt_H // opt_f, opt_W // opt_f]
if sampler_name == 'plms':
sampler = PLMSSampler(thread_data.model)
elif sampler_name == 'ddim':
sampler = DDIMSampler(thread_data.model)
sampler.make_schedule(ddim_num_steps=opt_ddim_steps, ddim_eta=opt_ddim_eta, verbose=False)
samples_ddim, intermediates = sampler.sample(
S=opt_ddim_steps,
conditioning=c,
batch_size=opt_n_samples,
seed=opt_seed,
shape=shape,
verbose=False,
unconditional_guidance_scale=opt_scale,
unconditional_conditioning=uc,
eta=opt_ddim_eta,
x_T=start_code,
img_callback=img_callback,
mask=mask,
sampler = sampler_name,
)
else:
if sampler_name == 'ddim':
thread_data.model.make_schedule(ddim_num_steps=opt_ddim_steps, ddim_eta=opt_ddim_eta, verbose=False)
samples_ddim = thread_data.model.sample(
S=opt_ddim_steps,
conditioning=c,
seed=opt_seed,
shape=shape,
verbose=False,
unconditional_guidance_scale=opt_scale,
unconditional_conditioning=uc,
eta=opt_ddim_eta,
x_T=start_code,
img_callback=img_callback,
mask=mask,
sampler = sampler_name,
)
return samples_ddim
def _img2img(init_latent, t_enc, batch_size, opt_scale, c, uc, opt_ddim_steps, opt_ddim_eta, opt_seed, img_callback, mask, opt_C=1, opt_H=1, opt_W=1, opt_f=1):
# encode (scaled latent) # encode (scaled latent)
z_enc = thread_data.model.stochastic_encode(
init_latent,
torch.tensor([t_enc] * batch_size).to(thread_data.device),
opt_seed,
opt_ddim_eta,
opt_ddim_steps,
)
x_T = None if mask is None else init_latent x_T = None if mask is None else init_latent
# decode it if thread_data.test_sd2:
samples_ddim = thread_data.model.sample( from ldm.models.diffusion.ddim import DDIMSampler
t_enc,
c, sampler = DDIMSampler(thread_data.model)
z_enc,
unconditional_guidance_scale=opt_scale, sampler.make_schedule(ddim_num_steps=opt_ddim_steps, ddim_eta=opt_ddim_eta, verbose=False)
unconditional_conditioning=uc,
img_callback=img_callback, z_enc = sampler.stochastic_encode(init_latent, torch.tensor([t_enc] * batch_size).to(thread_data.device))
mask=mask,
x_T=x_T, samples_ddim = sampler.decode(z_enc, c, t_enc, unconditional_guidance_scale=opt_scale,unconditional_conditioning=uc, img_callback=img_callback)
sampler = 'ddim'
) else:
yield from samples_ddim z_enc = thread_data.model.stochastic_encode(
init_latent,
torch.tensor([t_enc] * batch_size).to(thread_data.device),
opt_seed,
opt_ddim_eta,
opt_ddim_steps,
)
# decode it
samples_ddim = thread_data.model.sample(
t_enc,
c,
z_enc,
unconditional_guidance_scale=opt_scale,
unconditional_conditioning=uc,
img_callback=img_callback,
mask=mask,
x_T=x_T,
sampler = 'ddim'
)
return samples_ddim
def gc(): def gc():
gc_collect() gc_collect()
@ -776,8 +920,16 @@ def load_mask(mask_str, h0, w0, newH, newW, invert=False):
# https://stackoverflow.com/a/61114178 # https://stackoverflow.com/a/61114178
def img_to_base64_str(img, output_format="PNG"): def img_to_base64_str(img, output_format="PNG"):
buffered = img_to_buffer(img, output_format)
return buffer_to_base64_str(buffered, output_format)
def img_to_buffer(img, output_format="PNG"):
buffered = BytesIO() buffered = BytesIO()
img.save(buffered, format=output_format) img.save(buffered, format=output_format)
buffered.seek(0)
return buffered
def buffer_to_base64_str(buffered, output_format="PNG"):
buffered.seek(0) buffered.seek(0)
img_byte = buffered.getvalue() img_byte = buffered.getvalue()
mime_type = "image/png" if output_format.lower() == "png" else "image/jpeg" mime_type = "image/png" if output_format.lower() == "png" else "image/jpeg"
@ -795,3 +947,48 @@ def base64_str_to_img(img_str):
buffered = base64_str_to_buffer(img_str) buffered = base64_str_to_buffer(img_str)
img = Image.open(buffered) img = Image.open(buffered)
return img return img
def split_weighted_subprompts(text):
"""
grabs all text up to the first occurrence of ':'
uses the grabbed text as a sub-prompt, and takes the value following ':' as weight
if ':' has no value defined, defaults to 1.0
repeats until no text remaining
"""
remaining = len(text)
prompts = []
weights = []
while remaining > 0:
if ":" in text:
idx = text.index(":") # first occurrence from start
# grab up to index as sub-prompt
prompt = text[:idx]
remaining -= idx
# remove from main text
text = text[idx+1:]
# find value for weight
if " " in text:
idx = text.index(" ") # first occurence
else: # no space, read to end
idx = len(text)
if idx != 0:
try:
weight = float(text[:idx])
except: # couldn't treat as float
print(f"Warning: '{text[:idx]}' is not a value, are you missing a space?")
weight = 1.0
else: # no value found
weight = 1.0
# remove from main text
remaining -= idx
text = text[idx+1:]
# append the sub-prompt and its weight
prompts.append(prompt)
weights.append(weight)
else: # no : found
if len(text) > 0: # there is still text though
# take remainder as weight 1
prompts.append(text)
weights.append(1.0)
remaining = 0
return prompts, weights

View File

@ -283,45 +283,26 @@ def thread_render(device):
print(f'Session {task.request.session_id} starting task {id(task)} on {runtime.thread_data.device_name}') print(f'Session {task.request.session_id} starting task {id(task)} on {runtime.thread_data.device_name}')
if not task.lock.acquire(blocking=False): raise Exception('Got locked task from queue.') if not task.lock.acquire(blocking=False): raise Exception('Got locked task from queue.')
try: try:
if runtime.thread_data.device == 'cpu' and is_alive() > 1: if runtime.is_model_reload_necessary(task.request):
# CPU is not the only device. Keep track of active time to unload resources later.
runtime.thread_data.lastActive = time.time()
# Open data generator.
res = runtime.mk_img(task.request)
if current_model_path == task.request.use_stable_diffusion_model:
current_state = ServerStates.Rendering
else:
current_state = ServerStates.LoadingModel current_state = ServerStates.LoadingModel
# Start reading from generator. runtime.reload_model()
dataQueue = None current_model_path = task.request.use_stable_diffusion_model
if task.request.stream_progress_updates: current_vae_path = task.request.use_vae_model
dataQueue = task.buffer_queue
for result in res: def step_callback():
if current_state == ServerStates.LoadingModel: global current_state_error
current_state = ServerStates.Rendering
current_model_path = task.request.use_stable_diffusion_model
current_vae_path = task.request.use_vae_model
if isinstance(current_state_error, SystemExit) or isinstance(current_state_error, StopAsyncIteration) or isinstance(task.error, StopAsyncIteration): if isinstance(current_state_error, SystemExit) or isinstance(current_state_error, StopAsyncIteration) or isinstance(task.error, StopAsyncIteration):
runtime.thread_data.stop_processing = True runtime.thread_data.stop_processing = True
if isinstance(current_state_error, StopAsyncIteration): if isinstance(current_state_error, StopAsyncIteration):
task.error = current_state_error task.error = current_state_error
current_state_error = None current_state_error = None
print(f'Session {task.request.session_id} sent cancel signal for task {id(task)}') print(f'Session {task.request.session_id} sent cancel signal for task {id(task)}')
if dataQueue:
dataQueue.put(result) task_cache.keep(task.request.session_id, TASK_TTL)
if isinstance(result, str):
result = json.loads(result) current_state = ServerStates.Rendering
task.response = result task.response = runtime.mk_img(task.request, task.buffer_queue, task.temp_images, step_callback)
if 'output' in result:
for out_obj in result['output']:
if 'path' in out_obj:
img_id = out_obj['path'][out_obj['path'].rindex('/') + 1:]
task.temp_images[int(img_id)] = runtime.thread_data.temp_images[out_obj['path'][11:]]
elif 'data' in out_obj:
buf = runtime.base64_str_to_buffer(out_obj['data'])
task.temp_images[result['output'].index(out_obj)] = buf
# Before looping back to the generator, mark cache as still alive.
task_cache.keep(task.request.session_id, TASK_TTL)
except Exception as e: except Exception as e:
task.error = e task.error = e
print(traceback.format_exc()) print(traceback.format_exc())

View File

@ -7,6 +7,7 @@ import traceback
import sys import sys
import os import os
import socket
import picklescan.scanner import picklescan.scanner
import rich import rich
@ -116,6 +117,8 @@ def setConfig(config):
bind_ip = '0.0.0.0' if config['net']['listen_to_network'] else '127.0.0.1' bind_ip = '0.0.0.0' if config['net']['listen_to_network'] else '127.0.0.1'
config_bat.append(f"@set SD_UI_BIND_IP={bind_ip}") config_bat.append(f"@set SD_UI_BIND_IP={bind_ip}")
config_bat.append(f"@set test_sd2={'Y' if config.get('test_sd2', False) else 'N'}")
if len(config_bat) > 0: if len(config_bat) > 0:
with open(config_bat_path, 'w', encoding='utf-8') as f: with open(config_bat_path, 'w', encoding='utf-8') as f:
f.write('\r\n'.join(config_bat)) f.write('\r\n'.join(config_bat))
@ -133,6 +136,8 @@ def setConfig(config):
bind_ip = '0.0.0.0' if config['net']['listen_to_network'] else '127.0.0.1' bind_ip = '0.0.0.0' if config['net']['listen_to_network'] else '127.0.0.1'
config_sh.append(f"export SD_UI_BIND_IP={bind_ip}") config_sh.append(f"export SD_UI_BIND_IP={bind_ip}")
config_sh.append(f"export test_sd2=\"{'Y' if config.get('test_sd2', False) else 'N'}\"")
if len(config_sh) > 1: if len(config_sh) > 1:
with open(config_sh_path, 'w', encoding='utf-8') as f: with open(config_sh_path, 'w', encoding='utf-8') as f:
f.write('\n'.join(config_sh)) f.write('\n'.join(config_sh))
@ -140,12 +145,19 @@ def setConfig(config):
print(traceback.format_exc()) print(traceback.format_exc())
def resolve_model_to_use(model_name:str, model_type:str, model_dir:str, model_extensions:list, default_models=[]): def resolve_model_to_use(model_name:str, model_type:str, model_dir:str, model_extensions:list, default_models=[]):
config = getConfig()
model_dirs = [os.path.join(MODELS_DIR, model_dir), SD_DIR] model_dirs = [os.path.join(MODELS_DIR, model_dir), SD_DIR]
if not model_name: # When None try user configured model. if not model_name: # When None try user configured model.
config = getConfig() # config = getConfig()
if 'model' in config and model_type in config['model']: if 'model' in config and model_type in config['model']:
model_name = config['model'][model_type] model_name = config['model'][model_type]
if model_name: if model_name:
is_sd2 = config.get('test_sd2', False)
if model_name.startswith('sd2_') and not is_sd2: # temp hack, until SD2 is unified with 1.4
print('ERROR: Cannot use SD 2.0 models with SD 1.0 code. Using the sd-v1-4 model instead!')
model_name = 'sd-v1-4'
# Check models directory # Check models directory
models_dir_path = os.path.join(MODELS_DIR, model_dir, model_name) models_dir_path = os.path.join(MODELS_DIR, model_dir, model_name)
for model_extension in model_extensions: for model_extension in model_extensions:
@ -188,6 +200,7 @@ class SetAppConfigRequest(BaseModel):
ui_open_browser_on_start: bool = None ui_open_browser_on_start: bool = None
listen_to_network: bool = None listen_to_network: bool = None
listen_port: int = None listen_port: int = None
test_sd2: bool = None
@app.post('/app_config') @app.post('/app_config')
async def setAppConfig(req : SetAppConfigRequest): async def setAppConfig(req : SetAppConfigRequest):
@ -208,6 +221,8 @@ async def setAppConfig(req : SetAppConfigRequest):
if 'net' not in config: if 'net' not in config:
config['net'] = {} config['net'] = {}
config['net']['listen_port'] = int(req.listen_port) config['net']['listen_port'] = int(req.listen_port)
if req.test_sd2 is not None:
config['test_sd2'] = req.test_sd2
try: try:
setConfig(config) setConfig(config)
@ -230,9 +245,9 @@ def is_malicious_model(file_path):
return False return False
except Exception as e: except Exception as e:
print('error while scanning', file_path, 'error:', e) print('error while scanning', file_path, 'error:', e)
return False return False
known_models = {}
def getModels(): def getModels():
models = { models = {
'active': { 'active': {
@ -255,9 +270,14 @@ def getModels():
if not file.endswith(model_extension): if not file.endswith(model_extension):
continue continue
if is_malicious_model(os.path.join(models_dir, file)): model_path = os.path.join(models_dir, file)
models['scan-error'] = file mtime = os.path.getmtime(model_path)
return mod_time = known_models[model_path] if model_path in known_models else -1
if mod_time != mtime:
if is_malicious_model(model_path):
models['scan-error'] = file
return
known_models[model_path] = mtime
model_name = file[:-len(model_extension)] model_name = file[:-len(model_extension)]
models['options'][model_type].append(model_name) models['options'][model_type].append(model_name)
@ -286,6 +306,11 @@ def getUIPlugins():
return plugins return plugins
def getIPConfig():
ips = socket.gethostbyname_ex(socket.gethostname())
ips[2].append(ips[0])
return ips[2]
@app.get('/get/{key:path}') @app.get('/get/{key:path}')
def read_web_data(key:str=None): def read_web_data(key:str=None):
if not key: # /get without parameters, stable-diffusion easter egg. if not key: # /get without parameters, stable-diffusion easter egg.
@ -295,11 +320,14 @@ def read_web_data(key:str=None):
if config is None: if config is None:
config = APP_CONFIG_DEFAULTS config = APP_CONFIG_DEFAULTS
return JSONResponse(config, headers=NOCACHE_HEADERS) return JSONResponse(config, headers=NOCACHE_HEADERS)
elif key == 'devices': elif key == 'system_info':
config = getConfig() config = getConfig()
devices = task_manager.get_devices() system_info = {
devices['config'] = config.get('render_devices', "auto") 'devices': task_manager.get_devices(),
return JSONResponse(devices, headers=NOCACHE_HEADERS) 'hosts': getIPConfig(),
}
system_info['devices']['config'] = config.get('render_devices', "auto")
return JSONResponse(system_info, headers=NOCACHE_HEADERS)
elif key == 'models': elif key == 'models':
return JSONResponse(getModels(), headers=NOCACHE_HEADERS) return JSONResponse(getModels(), headers=NOCACHE_HEADERS)
elif key == 'modifiers': return FileResponse(os.path.join(SD_UI_DIR, 'modifiers.json'), headers=NOCACHE_HEADERS) elif key == 'modifiers': return FileResponse(os.path.join(SD_UI_DIR, 'modifiers.json'), headers=NOCACHE_HEADERS)
@ -435,6 +463,9 @@ class LogSuppressFilter(logging.Filter):
return True return True
logging.getLogger('uvicorn.access').addFilter(LogSuppressFilter()) logging.getLogger('uvicorn.access').addFilter(LogSuppressFilter())
# Check models and prepare cache for UI open
getModels()
# Start the task_manager # Start the task_manager
task_manager.default_model_to_load = resolve_ckpt_to_use() task_manager.default_model_to_load = resolve_ckpt_to_use()
task_manager.default_vae_to_load = resolve_vae_to_use() task_manager.default_vae_to_load = resolve_vae_to_use()