forked from extern/easydiffusion
Merge pull request #576 from cmdr2/beta
v2.4.16 - Remove the use of git-apply
This commit is contained in:
commit
2e69ffcb5e
@ -21,6 +21,8 @@
|
|||||||
- A `What's New?` tab in the UI
|
- A `What's New?` tab in the UI
|
||||||
|
|
||||||
### Detailed changelog
|
### Detailed changelog
|
||||||
|
* 2.4.16 - 29 Nov 2022 - Bug fixes for SD 2.0 - remove the need for patching, default to SD 1.4 model if trying to load an SD2 model in SD1.4.
|
||||||
|
* 2.4.15 - 25 Nov 2022 - Experimental support for SD 2.0. Uses lots of memory, not optimized, probably GPU-only.
|
||||||
* 2.4.14 - 22 Nov 2022 - Change the backend to a custom fork of Stable Diffusion
|
* 2.4.14 - 22 Nov 2022 - Change the backend to a custom fork of Stable Diffusion
|
||||||
* 2.4.13 - 21 Nov 2022 - Change the modifier weight via mouse wheel, drag to reorder selected modifiers, and some more modifier-related fixes. Thanks @patriceac
|
* 2.4.13 - 21 Nov 2022 - Change the modifier weight via mouse wheel, drag to reorder selected modifiers, and some more modifier-related fixes. Thanks @patriceac
|
||||||
* 2.4.12 - 21 Nov 2022 - Another fix for improving how long images take to generate. Reduces the time taken for an enqueued task to start processing.
|
* 2.4.12 - 21 Nov 2022 - Another fix for improving how long images take to generate. Reduces the time taken for an enqueued task to start processing.
|
||||||
|
@ -29,6 +29,18 @@ call conda activate .\stable-diffusion\env
|
|||||||
call where python
|
call where python
|
||||||
call python --version
|
call python --version
|
||||||
|
|
||||||
|
@rem set the PYTHONPATH
|
||||||
|
cd stable-diffusion
|
||||||
|
set SD_DIR=%cd%
|
||||||
|
|
||||||
|
cd env\lib\site-packages
|
||||||
|
set PYTHONPATH=%SD_DIR%;%cd%
|
||||||
|
cd ..\..\..
|
||||||
|
echo PYTHONPATH=%PYTHONPATH%
|
||||||
|
|
||||||
|
cd ..
|
||||||
|
|
||||||
|
@rem done
|
||||||
echo.
|
echo.
|
||||||
|
|
||||||
cmd /k
|
cmd /k
|
||||||
|
@ -35,6 +35,15 @@ if [ "$0" == "bash" ]; then
|
|||||||
which python
|
which python
|
||||||
python --version
|
python --version
|
||||||
|
|
||||||
|
# set the PYTHONPATH
|
||||||
|
cd stable-diffusion
|
||||||
|
SD_PATH=`pwd`
|
||||||
|
export PYTHONPATH="$SD_PATH:$SD_PATH/env/lib/python3.8/site-packages"
|
||||||
|
echo "PYTHONPATH=$PYTHONPATH"
|
||||||
|
cd ..
|
||||||
|
|
||||||
|
# done
|
||||||
|
|
||||||
echo ""
|
echo ""
|
||||||
else
|
else
|
||||||
file_name=$(basename "${BASH_SOURCE[0]}")
|
file_name=$(basename "${BASH_SOURCE[0]}")
|
||||||
|
@ -27,6 +27,8 @@ if exist "Open Developer Console.cmd" del "Open Developer Console.cmd"
|
|||||||
|
|
||||||
@call python -c "import os; import shutil; frm = 'sd-ui-files\\ui\\hotfix\\9c24e6cd9f499d02c4f21a033736dabd365962dc80fe3aeb57a8f85ea45a20a3.26fead7ea4f0f843f6eb4055dfd25693f1a71f3c6871b184042d4b126244e142'; dst = os.path.join(os.path.expanduser('~'), '.cache', 'huggingface', 'transformers', '9c24e6cd9f499d02c4f21a033736dabd365962dc80fe3aeb57a8f85ea45a20a3.26fead7ea4f0f843f6eb4055dfd25693f1a71f3c6871b184042d4b126244e142'); shutil.copyfile(frm, dst) if os.path.exists(dst) else print(''); print('Hotfixed broken JSON file from OpenAI');"
|
@call python -c "import os; import shutil; frm = 'sd-ui-files\\ui\\hotfix\\9c24e6cd9f499d02c4f21a033736dabd365962dc80fe3aeb57a8f85ea45a20a3.26fead7ea4f0f843f6eb4055dfd25693f1a71f3c6871b184042d4b126244e142'; dst = os.path.join(os.path.expanduser('~'), '.cache', 'huggingface', 'transformers', '9c24e6cd9f499d02c4f21a033736dabd365962dc80fe3aeb57a8f85ea45a20a3.26fead7ea4f0f843f6eb4055dfd25693f1a71f3c6871b184042d4b126244e142'); shutil.copyfile(frm, dst) if os.path.exists(dst) else print(''); print('Hotfixed broken JSON file from OpenAI');"
|
||||||
|
|
||||||
|
if NOT DEFINED test_sd2 set test_sd2=N
|
||||||
|
|
||||||
@>nul findstr /m "sd_git_cloned" scripts\install_status.txt
|
@>nul findstr /m "sd_git_cloned" scripts\install_status.txt
|
||||||
@if "%ERRORLEVEL%" EQU "0" (
|
@if "%ERRORLEVEL%" EQU "0" (
|
||||||
@echo "Stable Diffusion's git repository was already installed. Updating.."
|
@echo "Stable Diffusion's git repository was already installed. Updating.."
|
||||||
@ -37,9 +39,13 @@ if exist "Open Developer Console.cmd" del "Open Developer Console.cmd"
|
|||||||
|
|
||||||
@call git reset --hard
|
@call git reset --hard
|
||||||
@call git pull
|
@call git pull
|
||||||
@call git -c advice.detachedHead=false checkout 7f32368ed1030a6e710537047bacd908adea183a
|
|
||||||
|
|
||||||
@call git apply --whitespace=warn ..\ui\sd_internal\ddim_callback.patch
|
if "%test_sd2%" == "N" (
|
||||||
|
@call git -c advice.detachedHead=false checkout 7f32368ed1030a6e710537047bacd908adea183a
|
||||||
|
)
|
||||||
|
if "%test_sd2%" == "Y" (
|
||||||
|
@call git -c advice.detachedHead=false checkout 8878d67decd3deb3c98472c1e39d2a51dc5950f9
|
||||||
|
)
|
||||||
|
|
||||||
@cd ..
|
@cd ..
|
||||||
) else (
|
) else (
|
||||||
@ -56,8 +62,6 @@ if exist "Open Developer Console.cmd" del "Open Developer Console.cmd"
|
|||||||
@cd stable-diffusion
|
@cd stable-diffusion
|
||||||
@call git -c advice.detachedHead=false checkout 7f32368ed1030a6e710537047bacd908adea183a
|
@call git -c advice.detachedHead=false checkout 7f32368ed1030a6e710537047bacd908adea183a
|
||||||
|
|
||||||
@call git apply --whitespace=warn ..\ui\sd_internal\ddim_callback.patch
|
|
||||||
|
|
||||||
@cd ..
|
@cd ..
|
||||||
)
|
)
|
||||||
|
|
||||||
@ -346,7 +350,9 @@ echo. > "..\models\vae\Put your VAE files here.txt"
|
|||||||
)
|
)
|
||||||
)
|
)
|
||||||
|
|
||||||
|
if "%test_sd2%" == "Y" (
|
||||||
|
@call pip install open_clip_torch==2.0.2
|
||||||
|
)
|
||||||
|
|
||||||
@>nul findstr /m "sd_install_complete" ..\scripts\install_status.txt
|
@>nul findstr /m "sd_install_complete" ..\scripts\install_status.txt
|
||||||
@if "%ERRORLEVEL%" NEQ "0" (
|
@if "%ERRORLEVEL%" NEQ "0" (
|
||||||
|
@ -21,6 +21,10 @@ python -c "import os; import shutil; frm = 'sd-ui-files/ui/hotfix/9c24e6cd9f499d
|
|||||||
# Caution, this file will make your eyes and brain bleed. It's such an unholy mess.
|
# Caution, this file will make your eyes and brain bleed. It's such an unholy mess.
|
||||||
# Note to self: Please rewrite this in Python. For the sake of your own sanity.
|
# Note to self: Please rewrite this in Python. For the sake of your own sanity.
|
||||||
|
|
||||||
|
if [ "$test_sd2" == "" ]; then
|
||||||
|
export test_sd2="N"
|
||||||
|
fi
|
||||||
|
|
||||||
if [ -e "scripts/install_status.txt" ] && [ `grep -c sd_git_cloned scripts/install_status.txt` -gt "0" ]; then
|
if [ -e "scripts/install_status.txt" ] && [ `grep -c sd_git_cloned scripts/install_status.txt` -gt "0" ]; then
|
||||||
echo "Stable Diffusion's git repository was already installed. Updating.."
|
echo "Stable Diffusion's git repository was already installed. Updating.."
|
||||||
|
|
||||||
@ -30,9 +34,12 @@ if [ -e "scripts/install_status.txt" ] && [ `grep -c sd_git_cloned scripts/insta
|
|||||||
|
|
||||||
git reset --hard
|
git reset --hard
|
||||||
git pull
|
git pull
|
||||||
git -c advice.detachedHead=false checkout 7f32368ed1030a6e710537047bacd908adea183a
|
|
||||||
|
|
||||||
git apply --whitespace=warn ../ui/sd_internal/ddim_callback.patch || fail "ddim patch failed"
|
if [ "$test_sd2" == "N" ]; then
|
||||||
|
git -c advice.detachedHead=false checkout 7f32368ed1030a6e710537047bacd908adea183a
|
||||||
|
elif [ "$test_sd2" == "Y" ]; then
|
||||||
|
git -c advice.detachedHead=false checkout 8878d67decd3deb3c98472c1e39d2a51dc5950f9
|
||||||
|
fi
|
||||||
|
|
||||||
cd ..
|
cd ..
|
||||||
else
|
else
|
||||||
@ -47,8 +54,6 @@ else
|
|||||||
cd stable-diffusion
|
cd stable-diffusion
|
||||||
git -c advice.detachedHead=false checkout 7f32368ed1030a6e710537047bacd908adea183a
|
git -c advice.detachedHead=false checkout 7f32368ed1030a6e710537047bacd908adea183a
|
||||||
|
|
||||||
git apply --whitespace=warn ../ui/sd_internal/ddim_callback.patch || fail "ddim patch failed"
|
|
||||||
|
|
||||||
cd ..
|
cd ..
|
||||||
fi
|
fi
|
||||||
|
|
||||||
@ -291,6 +296,9 @@ if [ ! -f "../models/vae/vae-ft-mse-840000-ema-pruned.ckpt" ]; then
|
|||||||
fi
|
fi
|
||||||
fi
|
fi
|
||||||
|
|
||||||
|
if [ "$test_sd2" == "Y" ]; then
|
||||||
|
pip install open_clip_torch==2.0.2
|
||||||
|
fi
|
||||||
|
|
||||||
if [ `grep -c sd_install_complete ../scripts/install_status.txt` -gt "0" ]; then
|
if [ `grep -c sd_install_complete ../scripts/install_status.txt` -gt "0" ]; then
|
||||||
echo sd_weights_downloaded >> ../scripts/install_status.txt
|
echo sd_weights_downloaded >> ../scripts/install_status.txt
|
||||||
|
@ -22,7 +22,7 @@
|
|||||||
<div id="logo">
|
<div id="logo">
|
||||||
<h1>
|
<h1>
|
||||||
Stable Diffusion UI
|
Stable Diffusion UI
|
||||||
<small>v2.4.14 <span id="updateBranchLabel"></span></small>
|
<small>v2.4.16 <span id="updateBranchLabel"></span></small>
|
||||||
</h1>
|
</h1>
|
||||||
</div>
|
</div>
|
||||||
<div id="server-status">
|
<div id="server-status">
|
||||||
|
@ -210,7 +210,7 @@ code {
|
|||||||
}
|
}
|
||||||
.collapsible-content {
|
.collapsible-content {
|
||||||
display: block;
|
display: block;
|
||||||
padding-left: 15px;
|
padding-left: 10px;
|
||||||
}
|
}
|
||||||
.collapsible-content h5 {
|
.collapsible-content h5 {
|
||||||
padding: 5pt 0pt;
|
padding: 5pt 0pt;
|
||||||
@ -658,11 +658,15 @@ input::file-selector-button {
|
|||||||
opacity: 1;
|
opacity: 1;
|
||||||
}
|
}
|
||||||
|
|
||||||
/* MOBILE SUPPORT */
|
/* Small screens */
|
||||||
@media screen and (max-width: 700px) {
|
@media screen and (max-width: 1265px) {
|
||||||
#top-nav {
|
#top-nav {
|
||||||
flex-direction: column;
|
flex-direction: column;
|
||||||
}
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
/* MOBILE SUPPORT */
|
||||||
|
@media screen and (max-width: 700px) {
|
||||||
body {
|
body {
|
||||||
margin: 0px;
|
margin: 0px;
|
||||||
}
|
}
|
||||||
@ -712,7 +716,7 @@ input::file-selector-button {
|
|||||||
padding-right: 0px;
|
padding-right: 0px;
|
||||||
}
|
}
|
||||||
#server-status {
|
#server-status {
|
||||||
display: none;
|
top: 75%;
|
||||||
}
|
}
|
||||||
.popup > div {
|
.popup > div {
|
||||||
padding-left: 5px !important;
|
padding-left: 5px !important;
|
||||||
|
@ -132,6 +132,14 @@ var PARAMETERS = [
|
|||||||
return `<input id="${parameter.id}" name="${parameter.id}" size="6" value="9000" onkeypress="preventNonNumericalInput(event)">`
|
return `<input id="${parameter.id}" name="${parameter.id}" size="6" value="9000" onkeypress="preventNonNumericalInput(event)">`
|
||||||
}
|
}
|
||||||
},
|
},
|
||||||
|
{
|
||||||
|
id: "test_sd2",
|
||||||
|
type: ParameterType.checkbox,
|
||||||
|
label: "Test SD 2.0",
|
||||||
|
note: "Experimental! High memory usage! GPU-only! Not the final version! Please restart the program after changing this.",
|
||||||
|
icon: "fa-fire",
|
||||||
|
default: false,
|
||||||
|
},
|
||||||
{
|
{
|
||||||
id: "use_beta_channel",
|
id: "use_beta_channel",
|
||||||
type: ParameterType.checkbox,
|
type: ParameterType.checkbox,
|
||||||
@ -196,6 +204,7 @@ let saveToDiskField = document.querySelector('#save_to_disk')
|
|||||||
let diskPathField = document.querySelector('#diskPath')
|
let diskPathField = document.querySelector('#diskPath')
|
||||||
let listenToNetworkField = document.querySelector("#listen_to_network")
|
let listenToNetworkField = document.querySelector("#listen_to_network")
|
||||||
let listenPortField = document.querySelector("#listen_port")
|
let listenPortField = document.querySelector("#listen_port")
|
||||||
|
let testSD2Field = document.querySelector("#test_sd2")
|
||||||
let useBetaChannelField = document.querySelector("#use_beta_channel")
|
let useBetaChannelField = document.querySelector("#use_beta_channel")
|
||||||
let uiOpenBrowserOnStartField = document.querySelector("#ui_open_browser_on_start")
|
let uiOpenBrowserOnStartField = document.querySelector("#ui_open_browser_on_start")
|
||||||
|
|
||||||
@ -230,6 +239,12 @@ async function getAppConfig() {
|
|||||||
if (config.ui && config.ui.open_browser_on_start === false) {
|
if (config.ui && config.ui.open_browser_on_start === false) {
|
||||||
uiOpenBrowserOnStartField.checked = false
|
uiOpenBrowserOnStartField.checked = false
|
||||||
}
|
}
|
||||||
|
if ('test_sd2' in config) {
|
||||||
|
testSD2Field.checked = config['test_sd2']
|
||||||
|
}
|
||||||
|
|
||||||
|
let testSD2SettingEntry = getParameterSettingsEntry('test_sd2')
|
||||||
|
testSD2SettingEntry.style.display = (config.update_branch === 'beta' ? '' : 'none')
|
||||||
if (config.net && config.net.listen_to_network === false) {
|
if (config.net && config.net.listen_to_network === false) {
|
||||||
listenToNetworkField.checked = false
|
listenToNetworkField.checked = false
|
||||||
}
|
}
|
||||||
@ -372,7 +387,8 @@ saveSettingsBtn.addEventListener('click', function() {
|
|||||||
'update_branch': updateBranch,
|
'update_branch': updateBranch,
|
||||||
'ui_open_browser_on_start': uiOpenBrowserOnStartField.checked,
|
'ui_open_browser_on_start': uiOpenBrowserOnStartField.checked,
|
||||||
'listen_to_network': listenToNetworkField.checked,
|
'listen_to_network': listenToNetworkField.checked,
|
||||||
'listen_port': listenPortField.value
|
'listen_port': listenPortField.value,
|
||||||
|
'test_sd2': testSD2Field.checked
|
||||||
})
|
})
|
||||||
}
|
}
|
||||||
|
|
||||||
|
84
ui/sd_internal/ddim_callback_sd2.patch
Normal file
84
ui/sd_internal/ddim_callback_sd2.patch
Normal file
@ -0,0 +1,84 @@
|
|||||||
|
diff --git a/ldm/models/diffusion/ddim.py b/ldm/models/diffusion/ddim.py
|
||||||
|
index 27ead0e..6215939 100644
|
||||||
|
--- a/ldm/models/diffusion/ddim.py
|
||||||
|
+++ b/ldm/models/diffusion/ddim.py
|
||||||
|
@@ -100,7 +100,7 @@ class DDIMSampler(object):
|
||||||
|
size = (batch_size, C, H, W)
|
||||||
|
print(f'Data shape for DDIM sampling is {size}, eta {eta}')
|
||||||
|
|
||||||
|
- samples, intermediates = self.ddim_sampling(conditioning, size,
|
||||||
|
+ samples = self.ddim_sampling(conditioning, size,
|
||||||
|
callback=callback,
|
||||||
|
img_callback=img_callback,
|
||||||
|
quantize_denoised=quantize_x0,
|
||||||
|
@@ -117,7 +117,8 @@ class DDIMSampler(object):
|
||||||
|
dynamic_threshold=dynamic_threshold,
|
||||||
|
ucg_schedule=ucg_schedule
|
||||||
|
)
|
||||||
|
- return samples, intermediates
|
||||||
|
+ # return samples, intermediates
|
||||||
|
+ yield from samples
|
||||||
|
|
||||||
|
@torch.no_grad()
|
||||||
|
def ddim_sampling(self, cond, shape,
|
||||||
|
@@ -168,14 +169,15 @@ class DDIMSampler(object):
|
||||||
|
unconditional_conditioning=unconditional_conditioning,
|
||||||
|
dynamic_threshold=dynamic_threshold)
|
||||||
|
img, pred_x0 = outs
|
||||||
|
- if callback: callback(i)
|
||||||
|
- if img_callback: img_callback(pred_x0, i)
|
||||||
|
+ if callback: yield from callback(i)
|
||||||
|
+ if img_callback: yield from img_callback(pred_x0, i)
|
||||||
|
|
||||||
|
if index % log_every_t == 0 or index == total_steps - 1:
|
||||||
|
intermediates['x_inter'].append(img)
|
||||||
|
intermediates['pred_x0'].append(pred_x0)
|
||||||
|
|
||||||
|
- return img, intermediates
|
||||||
|
+ # return img, intermediates
|
||||||
|
+ yield from img_callback(pred_x0, len(iterator)-1)
|
||||||
|
|
||||||
|
@torch.no_grad()
|
||||||
|
def p_sample_ddim(self, x, c, t, index, repeat_noise=False, use_original_steps=False, quantize_denoised=False,
|
||||||
|
diff --git a/ldm/models/diffusion/plms.py b/ldm/models/diffusion/plms.py
|
||||||
|
index 7002a36..0951f39 100644
|
||||||
|
--- a/ldm/models/diffusion/plms.py
|
||||||
|
+++ b/ldm/models/diffusion/plms.py
|
||||||
|
@@ -96,7 +96,7 @@ class PLMSSampler(object):
|
||||||
|
size = (batch_size, C, H, W)
|
||||||
|
print(f'Data shape for PLMS sampling is {size}')
|
||||||
|
|
||||||
|
- samples, intermediates = self.plms_sampling(conditioning, size,
|
||||||
|
+ samples = self.plms_sampling(conditioning, size,
|
||||||
|
callback=callback,
|
||||||
|
img_callback=img_callback,
|
||||||
|
quantize_denoised=quantize_x0,
|
||||||
|
@@ -112,7 +112,8 @@ class PLMSSampler(object):
|
||||||
|
unconditional_conditioning=unconditional_conditioning,
|
||||||
|
dynamic_threshold=dynamic_threshold,
|
||||||
|
)
|
||||||
|
- return samples, intermediates
|
||||||
|
+ #return samples, intermediates
|
||||||
|
+ yield from samples
|
||||||
|
|
||||||
|
@torch.no_grad()
|
||||||
|
def plms_sampling(self, cond, shape,
|
||||||
|
@@ -165,14 +166,15 @@ class PLMSSampler(object):
|
||||||
|
old_eps.append(e_t)
|
||||||
|
if len(old_eps) >= 4:
|
||||||
|
old_eps.pop(0)
|
||||||
|
- if callback: callback(i)
|
||||||
|
- if img_callback: img_callback(pred_x0, i)
|
||||||
|
+ if callback: yield from callback(i)
|
||||||
|
+ if img_callback: yield from img_callback(pred_x0, i)
|
||||||
|
|
||||||
|
if index % log_every_t == 0 or index == total_steps - 1:
|
||||||
|
intermediates['x_inter'].append(img)
|
||||||
|
intermediates['pred_x0'].append(pred_x0)
|
||||||
|
|
||||||
|
- return img, intermediates
|
||||||
|
+ # return img, intermediates
|
||||||
|
+ yield from img_callback(pred_x0, len(iterator)-1)
|
||||||
|
|
||||||
|
@torch.no_grad()
|
||||||
|
def p_sample_plms(self, x, c, t, index, repeat_noise=False, use_original_steps=False, quantize_denoised=False,
|
@ -7,6 +7,7 @@ Notes:
|
|||||||
import json
|
import json
|
||||||
import os, re
|
import os, re
|
||||||
import traceback
|
import traceback
|
||||||
|
import queue
|
||||||
import torch
|
import torch
|
||||||
import numpy as np
|
import numpy as np
|
||||||
from gc import collect as gc_collect
|
from gc import collect as gc_collect
|
||||||
@ -21,7 +22,6 @@ from torch import autocast
|
|||||||
from contextlib import nullcontext
|
from contextlib import nullcontext
|
||||||
from einops import rearrange, repeat
|
from einops import rearrange, repeat
|
||||||
from ldm.util import instantiate_from_config
|
from ldm.util import instantiate_from_config
|
||||||
from optimizedSD.optimUtils import split_weighted_subprompts
|
|
||||||
from transformers import logging
|
from transformers import logging
|
||||||
|
|
||||||
from gfpgan import GFPGANer
|
from gfpgan import GFPGANer
|
||||||
@ -76,8 +76,24 @@ def thread_init(device):
|
|||||||
thread_data.force_full_precision = False
|
thread_data.force_full_precision = False
|
||||||
thread_data.reduced_memory = True
|
thread_data.reduced_memory = True
|
||||||
|
|
||||||
|
thread_data.test_sd2 = isSD2()
|
||||||
|
|
||||||
device_manager.device_init(thread_data, device)
|
device_manager.device_init(thread_data, device)
|
||||||
|
|
||||||
|
# temp hack, will remove soon
|
||||||
|
def isSD2():
|
||||||
|
try:
|
||||||
|
SD_UI_DIR = os.getenv('SD_UI_PATH', None)
|
||||||
|
CONFIG_DIR = os.path.abspath(os.path.join(SD_UI_DIR, '..', 'scripts'))
|
||||||
|
config_json_path = os.path.join(CONFIG_DIR, 'config.json')
|
||||||
|
if not os.path.exists(config_json_path):
|
||||||
|
return False
|
||||||
|
with open(config_json_path, 'r', encoding='utf-8') as f:
|
||||||
|
config = json.load(f)
|
||||||
|
return config.get('test_sd2', False)
|
||||||
|
except Exception as e:
|
||||||
|
return False
|
||||||
|
|
||||||
def load_model_ckpt():
|
def load_model_ckpt():
|
||||||
if not thread_data.ckpt_file: raise ValueError(f'Thread ckpt_file is undefined.')
|
if not thread_data.ckpt_file: raise ValueError(f'Thread ckpt_file is undefined.')
|
||||||
if not os.path.exists(thread_data.ckpt_file + '.ckpt'): raise FileNotFoundError(f'Cannot find {thread_data.ckpt_file}.ckpt')
|
if not os.path.exists(thread_data.ckpt_file + '.ckpt'): raise FileNotFoundError(f'Cannot find {thread_data.ckpt_file}.ckpt')
|
||||||
@ -92,6 +108,13 @@ def load_model_ckpt():
|
|||||||
thread_data.precision = 'full'
|
thread_data.precision = 'full'
|
||||||
|
|
||||||
print('loading', thread_data.ckpt_file + '.ckpt', 'to device', thread_data.device, 'using precision', thread_data.precision)
|
print('loading', thread_data.ckpt_file + '.ckpt', 'to device', thread_data.device, 'using precision', thread_data.precision)
|
||||||
|
|
||||||
|
if thread_data.test_sd2:
|
||||||
|
load_model_ckpt_sd2()
|
||||||
|
else:
|
||||||
|
load_model_ckpt_sd1()
|
||||||
|
|
||||||
|
def load_model_ckpt_sd1():
|
||||||
sd = load_model_from_config(thread_data.ckpt_file + '.ckpt')
|
sd = load_model_from_config(thread_data.ckpt_file + '.ckpt')
|
||||||
li, lo = [], []
|
li, lo = [], []
|
||||||
for key, value in sd.items():
|
for key, value in sd.items():
|
||||||
@ -185,6 +208,38 @@ def load_model_ckpt():
|
|||||||
modelFS.device: {thread_data.modelFS.device}
|
modelFS.device: {thread_data.modelFS.device}
|
||||||
using precision: {thread_data.precision}''')
|
using precision: {thread_data.precision}''')
|
||||||
|
|
||||||
|
def load_model_ckpt_sd2():
|
||||||
|
config_file = 'configs/stable-diffusion/v2-inference-v.yaml' if 'sd2_' in thread_data.ckpt_file else "configs/stable-diffusion/v1-inference.yaml"
|
||||||
|
config = OmegaConf.load(config_file)
|
||||||
|
verbose = False
|
||||||
|
|
||||||
|
sd = load_model_from_config(thread_data.ckpt_file + '.ckpt')
|
||||||
|
|
||||||
|
thread_data.model = instantiate_from_config(config.model)
|
||||||
|
m, u = thread_data.model.load_state_dict(sd, strict=False)
|
||||||
|
if len(m) > 0 and verbose:
|
||||||
|
print("missing keys:")
|
||||||
|
print(m)
|
||||||
|
if len(u) > 0 and verbose:
|
||||||
|
print("unexpected keys:")
|
||||||
|
print(u)
|
||||||
|
|
||||||
|
thread_data.model.to(thread_data.device)
|
||||||
|
thread_data.model.eval()
|
||||||
|
del sd
|
||||||
|
|
||||||
|
if thread_data.device != "cpu" and thread_data.precision == "autocast":
|
||||||
|
thread_data.model.half()
|
||||||
|
thread_data.model_is_half = True
|
||||||
|
thread_data.model_fs_is_half = True
|
||||||
|
else:
|
||||||
|
thread_data.model_is_half = False
|
||||||
|
thread_data.model_fs_is_half = False
|
||||||
|
|
||||||
|
print(f'''loaded model
|
||||||
|
model file: {thread_data.ckpt_file}.ckpt
|
||||||
|
using precision: {thread_data.precision}''')
|
||||||
|
|
||||||
def unload_filters():
|
def unload_filters():
|
||||||
if thread_data.model_gfpgan is not None:
|
if thread_data.model_gfpgan is not None:
|
||||||
if thread_data.device != 'cpu': thread_data.model_gfpgan.gfpgan.to('cpu')
|
if thread_data.device != 'cpu': thread_data.model_gfpgan.gfpgan.to('cpu')
|
||||||
@ -204,10 +259,11 @@ def unload_models():
|
|||||||
if thread_data.model is not None:
|
if thread_data.model is not None:
|
||||||
print('Unloading models...')
|
print('Unloading models...')
|
||||||
if thread_data.device != 'cpu':
|
if thread_data.device != 'cpu':
|
||||||
thread_data.modelFS.to('cpu')
|
if not thread_data.test_sd2:
|
||||||
thread_data.modelCS.to('cpu')
|
thread_data.modelFS.to('cpu')
|
||||||
thread_data.model.model1.to("cpu")
|
thread_data.modelCS.to('cpu')
|
||||||
thread_data.model.model2.to("cpu")
|
thread_data.model.model1.to("cpu")
|
||||||
|
thread_data.model.model2.to("cpu")
|
||||||
|
|
||||||
del thread_data.model
|
del thread_data.model
|
||||||
del thread_data.modelCS
|
del thread_data.modelCS
|
||||||
@ -337,45 +393,73 @@ def apply_filters(filter_name, image_data, model_path=None):
|
|||||||
|
|
||||||
return image_data
|
return image_data
|
||||||
|
|
||||||
def mk_img(req: Request):
|
def is_model_reload_necessary(req: Request):
|
||||||
|
# custom model support:
|
||||||
|
# the req.use_stable_diffusion_model needs to be a valid path
|
||||||
|
# to the ckpt file (without the extension).
|
||||||
|
if not os.path.exists(req.use_stable_diffusion_model + '.ckpt'): raise FileNotFoundError(f'Cannot find {req.use_stable_diffusion_model}.ckpt')
|
||||||
|
|
||||||
|
needs_model_reload = False
|
||||||
|
if not thread_data.model or thread_data.ckpt_file != req.use_stable_diffusion_model or thread_data.vae_file != req.use_vae_model:
|
||||||
|
thread_data.ckpt_file = req.use_stable_diffusion_model
|
||||||
|
thread_data.vae_file = req.use_vae_model
|
||||||
|
needs_model_reload = True
|
||||||
|
|
||||||
|
if thread_data.device != 'cpu':
|
||||||
|
if (thread_data.precision == 'autocast' and (req.use_full_precision or not thread_data.model_is_half)) or \
|
||||||
|
(thread_data.precision == 'full' and not req.use_full_precision and not thread_data.force_full_precision):
|
||||||
|
thread_data.precision = 'full' if req.use_full_precision else 'autocast'
|
||||||
|
needs_model_reload = True
|
||||||
|
|
||||||
|
return needs_model_reload
|
||||||
|
|
||||||
|
def reload_model():
|
||||||
|
unload_models()
|
||||||
|
unload_filters()
|
||||||
|
load_model_ckpt()
|
||||||
|
|
||||||
|
def mk_img(req: Request, data_queue: queue.Queue, task_temp_images: list, step_callback):
|
||||||
try:
|
try:
|
||||||
yield from do_mk_img(req)
|
return do_mk_img(req, data_queue, task_temp_images, step_callback)
|
||||||
except Exception as e:
|
except Exception as e:
|
||||||
print(traceback.format_exc())
|
print(traceback.format_exc())
|
||||||
|
|
||||||
if thread_data.device != 'cpu':
|
if thread_data.device != 'cpu' and not thread_data.test_sd2:
|
||||||
thread_data.modelFS.to('cpu')
|
thread_data.modelFS.to('cpu')
|
||||||
thread_data.modelCS.to('cpu')
|
thread_data.modelCS.to('cpu')
|
||||||
thread_data.model.model1.to("cpu")
|
thread_data.model.model1.to("cpu")
|
||||||
thread_data.model.model2.to("cpu")
|
thread_data.model.model2.to("cpu")
|
||||||
|
|
||||||
gc() # Release from memory.
|
gc() # Release from memory.
|
||||||
yield json.dumps({
|
data_queue.put(json.dumps({
|
||||||
"status": 'failed',
|
"status": 'failed',
|
||||||
"detail": str(e)
|
"detail": str(e)
|
||||||
})
|
}))
|
||||||
|
raise e
|
||||||
|
|
||||||
def update_temp_img(req, x_samples):
|
def update_temp_img(req, x_samples, task_temp_images: list):
|
||||||
partial_images = []
|
partial_images = []
|
||||||
for i in range(req.num_outputs):
|
for i in range(req.num_outputs):
|
||||||
x_sample_ddim = thread_data.modelFS.decode_first_stage(x_samples[i].unsqueeze(0))
|
if thread_data.test_sd2:
|
||||||
|
x_sample_ddim = thread_data.model.decode_first_stage(x_samples[i].unsqueeze(0))
|
||||||
|
else:
|
||||||
|
x_sample_ddim = thread_data.modelFS.decode_first_stage(x_samples[i].unsqueeze(0))
|
||||||
x_sample = torch.clamp((x_sample_ddim + 1.0) / 2.0, min=0.0, max=1.0)
|
x_sample = torch.clamp((x_sample_ddim + 1.0) / 2.0, min=0.0, max=1.0)
|
||||||
x_sample = 255.0 * rearrange(x_sample[0].cpu().numpy(), "c h w -> h w c")
|
x_sample = 255.0 * rearrange(x_sample[0].cpu().numpy(), "c h w -> h w c")
|
||||||
x_sample = x_sample.astype(np.uint8)
|
x_sample = x_sample.astype(np.uint8)
|
||||||
img = Image.fromarray(x_sample)
|
img = Image.fromarray(x_sample)
|
||||||
buf = BytesIO()
|
buf = img_to_buffer(img, output_format='JPEG')
|
||||||
img.save(buf, format='JPEG')
|
|
||||||
buf.seek(0)
|
|
||||||
|
|
||||||
del img, x_sample, x_sample_ddim
|
del img, x_sample, x_sample_ddim
|
||||||
# don't delete x_samples, it is used in the code that called this callback
|
# don't delete x_samples, it is used in the code that called this callback
|
||||||
|
|
||||||
thread_data.temp_images[str(req.session_id) + '/' + str(i)] = buf
|
thread_data.temp_images[str(req.session_id) + '/' + str(i)] = buf
|
||||||
|
task_temp_images[i] = buf
|
||||||
partial_images.append({'path': f'/image/tmp/{req.session_id}/{i}'})
|
partial_images.append({'path': f'/image/tmp/{req.session_id}/{i}'})
|
||||||
return partial_images
|
return partial_images
|
||||||
|
|
||||||
# Build and return the apropriate generator for do_mk_img
|
# Build and return the apropriate generator for do_mk_img
|
||||||
def get_image_progress_generator(req, extra_props=None):
|
def get_image_progress_generator(req, data_queue: queue.Queue, task_temp_images: list, step_callback, extra_props=None):
|
||||||
if not req.stream_progress_updates:
|
if not req.stream_progress_updates:
|
||||||
def empty_callback(x_samples, i): return x_samples
|
def empty_callback(x_samples, i): return x_samples
|
||||||
return empty_callback
|
return empty_callback
|
||||||
@ -394,15 +478,17 @@ def get_image_progress_generator(req, extra_props=None):
|
|||||||
progress.update(extra_props)
|
progress.update(extra_props)
|
||||||
|
|
||||||
if req.stream_image_progress and i % 5 == 0:
|
if req.stream_image_progress and i % 5 == 0:
|
||||||
progress['output'] = update_temp_img(req, x_samples)
|
progress['output'] = update_temp_img(req, x_samples, task_temp_images)
|
||||||
|
|
||||||
yield json.dumps(progress)
|
data_queue.put(json.dumps(progress))
|
||||||
|
|
||||||
|
step_callback()
|
||||||
|
|
||||||
if thread_data.stop_processing:
|
if thread_data.stop_processing:
|
||||||
raise UserInitiatedStop("User requested that we stop processing")
|
raise UserInitiatedStop("User requested that we stop processing")
|
||||||
return img_callback
|
return img_callback
|
||||||
|
|
||||||
def do_mk_img(req: Request):
|
def do_mk_img(req: Request, data_queue: queue.Queue, task_temp_images: list, step_callback):
|
||||||
thread_data.stop_processing = False
|
thread_data.stop_processing = False
|
||||||
|
|
||||||
res = Response()
|
res = Response()
|
||||||
@ -411,29 +497,7 @@ def do_mk_img(req: Request):
|
|||||||
|
|
||||||
thread_data.temp_images.clear()
|
thread_data.temp_images.clear()
|
||||||
|
|
||||||
# custom model support:
|
if thread_data.turbo != req.turbo and not thread_data.test_sd2:
|
||||||
# the req.use_stable_diffusion_model needs to be a valid path
|
|
||||||
# to the ckpt file (without the extension).
|
|
||||||
if not os.path.exists(req.use_stable_diffusion_model + '.ckpt'): raise FileNotFoundError(f'Cannot find {req.use_stable_diffusion_model}.ckpt')
|
|
||||||
|
|
||||||
needs_model_reload = False
|
|
||||||
if not thread_data.model or thread_data.ckpt_file != req.use_stable_diffusion_model or thread_data.vae_file != req.use_vae_model:
|
|
||||||
thread_data.ckpt_file = req.use_stable_diffusion_model
|
|
||||||
thread_data.vae_file = req.use_vae_model
|
|
||||||
needs_model_reload = True
|
|
||||||
|
|
||||||
if thread_data.device != 'cpu':
|
|
||||||
if (thread_data.precision == 'autocast' and (req.use_full_precision or not thread_data.model_is_half)) or \
|
|
||||||
(thread_data.precision == 'full' and not req.use_full_precision and not thread_data.force_full_precision):
|
|
||||||
thread_data.precision = 'full' if req.use_full_precision else 'autocast'
|
|
||||||
needs_model_reload = True
|
|
||||||
|
|
||||||
if needs_model_reload:
|
|
||||||
unload_models()
|
|
||||||
unload_filters()
|
|
||||||
load_model_ckpt()
|
|
||||||
|
|
||||||
if thread_data.turbo != req.turbo:
|
|
||||||
thread_data.turbo = req.turbo
|
thread_data.turbo = req.turbo
|
||||||
thread_data.model.turbo = req.turbo
|
thread_data.model.turbo = req.turbo
|
||||||
|
|
||||||
@ -478,10 +542,14 @@ def do_mk_img(req: Request):
|
|||||||
if thread_data.device != "cpu" and thread_data.precision == "autocast":
|
if thread_data.device != "cpu" and thread_data.precision == "autocast":
|
||||||
init_image = init_image.half()
|
init_image = init_image.half()
|
||||||
|
|
||||||
thread_data.modelFS.to(thread_data.device)
|
if not thread_data.test_sd2:
|
||||||
|
thread_data.modelFS.to(thread_data.device)
|
||||||
|
|
||||||
init_image = repeat(init_image, '1 ... -> b ...', b=batch_size)
|
init_image = repeat(init_image, '1 ... -> b ...', b=batch_size)
|
||||||
init_latent = thread_data.modelFS.get_first_stage_encoding(thread_data.modelFS.encode_first_stage(init_image)) # move to latent space
|
if thread_data.test_sd2:
|
||||||
|
init_latent = thread_data.model.get_first_stage_encoding(thread_data.model.encode_first_stage(init_image)) # move to latent space
|
||||||
|
else:
|
||||||
|
init_latent = thread_data.modelFS.get_first_stage_encoding(thread_data.modelFS.encode_first_stage(init_image)) # move to latent space
|
||||||
|
|
||||||
if req.mask is not None:
|
if req.mask is not None:
|
||||||
mask = load_mask(req.mask, req.width, req.height, init_latent.shape[2], init_latent.shape[3], True).to(thread_data.device)
|
mask = load_mask(req.mask, req.width, req.height, init_latent.shape[2], init_latent.shape[3], True).to(thread_data.device)
|
||||||
@ -493,7 +561,8 @@ def do_mk_img(req: Request):
|
|||||||
|
|
||||||
# Send to CPU and wait until complete.
|
# Send to CPU and wait until complete.
|
||||||
# wait_model_move_to(thread_data.modelFS, 'cpu')
|
# wait_model_move_to(thread_data.modelFS, 'cpu')
|
||||||
move_to_cpu(thread_data.modelFS)
|
if not thread_data.test_sd2:
|
||||||
|
move_to_cpu(thread_data.modelFS)
|
||||||
|
|
||||||
assert 0. <= req.prompt_strength <= 1., 'can only work with strength in [0.0, 1.0]'
|
assert 0. <= req.prompt_strength <= 1., 'can only work with strength in [0.0, 1.0]'
|
||||||
t_enc = int(req.prompt_strength * req.num_inference_steps)
|
t_enc = int(req.prompt_strength * req.num_inference_steps)
|
||||||
@ -509,11 +578,14 @@ def do_mk_img(req: Request):
|
|||||||
for prompts in tqdm(data, desc="data"):
|
for prompts in tqdm(data, desc="data"):
|
||||||
|
|
||||||
with precision_scope("cuda"):
|
with precision_scope("cuda"):
|
||||||
if thread_data.reduced_memory:
|
if thread_data.reduced_memory and not thread_data.test_sd2:
|
||||||
thread_data.modelCS.to(thread_data.device)
|
thread_data.modelCS.to(thread_data.device)
|
||||||
uc = None
|
uc = None
|
||||||
if req.guidance_scale != 1.0:
|
if req.guidance_scale != 1.0:
|
||||||
uc = thread_data.modelCS.get_learned_conditioning(batch_size * [req.negative_prompt])
|
if thread_data.test_sd2:
|
||||||
|
uc = thread_data.model.get_learned_conditioning(batch_size * [req.negative_prompt])
|
||||||
|
else:
|
||||||
|
uc = thread_data.modelCS.get_learned_conditioning(batch_size * [req.negative_prompt])
|
||||||
if isinstance(prompts, tuple):
|
if isinstance(prompts, tuple):
|
||||||
prompts = list(prompts)
|
prompts = list(prompts)
|
||||||
|
|
||||||
@ -526,15 +598,21 @@ def do_mk_img(req: Request):
|
|||||||
weight = weights[i]
|
weight = weights[i]
|
||||||
# if not skip_normalize:
|
# if not skip_normalize:
|
||||||
weight = weight / totalWeight
|
weight = weight / totalWeight
|
||||||
c = torch.add(c, thread_data.modelCS.get_learned_conditioning(subprompts[i]), alpha=weight)
|
if thread_data.test_sd2:
|
||||||
|
c = torch.add(c, thread_data.model.get_learned_conditioning(subprompts[i]), alpha=weight)
|
||||||
|
else:
|
||||||
|
c = torch.add(c, thread_data.modelCS.get_learned_conditioning(subprompts[i]), alpha=weight)
|
||||||
else:
|
else:
|
||||||
c = thread_data.modelCS.get_learned_conditioning(prompts)
|
if thread_data.test_sd2:
|
||||||
|
c = thread_data.model.get_learned_conditioning(prompts)
|
||||||
|
else:
|
||||||
|
c = thread_data.modelCS.get_learned_conditioning(prompts)
|
||||||
|
|
||||||
if thread_data.reduced_memory:
|
if thread_data.reduced_memory and not thread_data.test_sd2:
|
||||||
thread_data.modelFS.to(thread_data.device)
|
thread_data.modelFS.to(thread_data.device)
|
||||||
|
|
||||||
n_steps = req.num_inference_steps if req.init_image is None else t_enc
|
n_steps = req.num_inference_steps if req.init_image is None else t_enc
|
||||||
img_callback = get_image_progress_generator(req, {"total_steps": n_steps})
|
img_callback = get_image_progress_generator(req, data_queue, task_temp_images, step_callback, {"total_steps": n_steps})
|
||||||
|
|
||||||
# run the handler
|
# run the handler
|
||||||
try:
|
try:
|
||||||
@ -542,14 +620,7 @@ def do_mk_img(req: Request):
|
|||||||
if handler == _txt2img:
|
if handler == _txt2img:
|
||||||
x_samples = _txt2img(req.width, req.height, req.num_outputs, req.num_inference_steps, req.guidance_scale, None, opt_C, opt_f, opt_ddim_eta, c, uc, opt_seed, img_callback, mask, req.sampler)
|
x_samples = _txt2img(req.width, req.height, req.num_outputs, req.num_inference_steps, req.guidance_scale, None, opt_C, opt_f, opt_ddim_eta, c, uc, opt_seed, img_callback, mask, req.sampler)
|
||||||
else:
|
else:
|
||||||
x_samples = _img2img(init_latent, t_enc, batch_size, req.guidance_scale, c, uc, req.num_inference_steps, opt_ddim_eta, opt_seed, img_callback, mask)
|
x_samples = _img2img(init_latent, t_enc, batch_size, req.guidance_scale, c, uc, req.num_inference_steps, opt_ddim_eta, opt_seed, img_callback, mask, opt_C, req.height, req.width, opt_f)
|
||||||
|
|
||||||
if req.stream_progress_updates:
|
|
||||||
yield from x_samples
|
|
||||||
if hasattr(thread_data, 'partial_x_samples'):
|
|
||||||
if thread_data.partial_x_samples is not None:
|
|
||||||
x_samples = thread_data.partial_x_samples
|
|
||||||
del thread_data.partial_x_samples
|
|
||||||
except UserInitiatedStop:
|
except UserInitiatedStop:
|
||||||
if not hasattr(thread_data, 'partial_x_samples'):
|
if not hasattr(thread_data, 'partial_x_samples'):
|
||||||
continue
|
continue
|
||||||
@ -562,7 +633,10 @@ def do_mk_img(req: Request):
|
|||||||
print("decoding images")
|
print("decoding images")
|
||||||
img_data = [None] * batch_size
|
img_data = [None] * batch_size
|
||||||
for i in range(batch_size):
|
for i in range(batch_size):
|
||||||
x_samples_ddim = thread_data.modelFS.decode_first_stage(x_samples[i].unsqueeze(0))
|
if thread_data.test_sd2:
|
||||||
|
x_samples_ddim = thread_data.model.decode_first_stage(x_samples[i].unsqueeze(0))
|
||||||
|
else:
|
||||||
|
x_samples_ddim = thread_data.modelFS.decode_first_stage(x_samples[i].unsqueeze(0))
|
||||||
x_sample = torch.clamp((x_samples_ddim + 1.0) / 2.0, min=0.0, max=1.0)
|
x_sample = torch.clamp((x_samples_ddim + 1.0) / 2.0, min=0.0, max=1.0)
|
||||||
x_sample = 255.0 * rearrange(x_sample[0].cpu().numpy(), "c h w -> h w c")
|
x_sample = 255.0 * rearrange(x_sample[0].cpu().numpy(), "c h w -> h w c")
|
||||||
x_sample = x_sample.astype(np.uint8)
|
x_sample = x_sample.astype(np.uint8)
|
||||||
@ -591,9 +665,11 @@ def do_mk_img(req: Request):
|
|||||||
save_metadata(meta_out_path, req, prompts[0], opt_seed)
|
save_metadata(meta_out_path, req, prompts[0], opt_seed)
|
||||||
|
|
||||||
if return_orig_img:
|
if return_orig_img:
|
||||||
img_str = img_to_base64_str(img, req.output_format)
|
img_buffer = img_to_buffer(img, req.output_format)
|
||||||
|
img_str = buffer_to_base64_str(img_buffer, req.output_format)
|
||||||
res_image_orig = ResponseImage(data=img_str, seed=opt_seed)
|
res_image_orig = ResponseImage(data=img_str, seed=opt_seed)
|
||||||
res.images.append(res_image_orig)
|
res.images.append(res_image_orig)
|
||||||
|
task_temp_images[i] = img_buffer
|
||||||
|
|
||||||
if req.save_to_disk_path is not None:
|
if req.save_to_disk_path is not None:
|
||||||
res_image_orig.path_abs = img_out_path
|
res_image_orig.path_abs = img_out_path
|
||||||
@ -609,9 +685,11 @@ def do_mk_img(req: Request):
|
|||||||
filters_applied.append(req.use_upscale)
|
filters_applied.append(req.use_upscale)
|
||||||
if (len(filters_applied) > 0):
|
if (len(filters_applied) > 0):
|
||||||
filtered_image = Image.fromarray(img_data[i])
|
filtered_image = Image.fromarray(img_data[i])
|
||||||
filtered_img_data = img_to_base64_str(filtered_image, req.output_format)
|
filtered_buffer = img_to_buffer(filtered_image, req.output_format)
|
||||||
|
filtered_img_data = buffer_to_base64_str(filtered_buffer, req.output_format)
|
||||||
response_image = ResponseImage(data=filtered_img_data, seed=opt_seed)
|
response_image = ResponseImage(data=filtered_img_data, seed=opt_seed)
|
||||||
res.images.append(response_image)
|
res.images.append(response_image)
|
||||||
|
task_temp_images[i] = filtered_buffer
|
||||||
if req.save_to_disk_path is not None:
|
if req.save_to_disk_path is not None:
|
||||||
filtered_img_out_path = get_base_path(req.save_to_disk_path, req.session_id, prompts[0], img_id, req.output_format, "_".join(filters_applied))
|
filtered_img_out_path = get_base_path(req.save_to_disk_path, req.session_id, prompts[0], img_id, req.output_format, "_".join(filters_applied))
|
||||||
save_image(filtered_image, filtered_img_out_path)
|
save_image(filtered_image, filtered_img_out_path)
|
||||||
@ -622,14 +700,18 @@ def do_mk_img(req: Request):
|
|||||||
|
|
||||||
# if thread_data.reduced_memory:
|
# if thread_data.reduced_memory:
|
||||||
# unload_filters()
|
# unload_filters()
|
||||||
move_to_cpu(thread_data.modelFS)
|
if not thread_data.test_sd2:
|
||||||
|
move_to_cpu(thread_data.modelFS)
|
||||||
del img_data
|
del img_data
|
||||||
gc()
|
gc()
|
||||||
if thread_data.device != 'cpu':
|
if thread_data.device != 'cpu':
|
||||||
print(f'memory_final = {round(torch.cuda.memory_allocated(thread_data.device) / 1e6, 2)}Mb')
|
print(f'memory_final = {round(torch.cuda.memory_allocated(thread_data.device) / 1e6, 2)}Mb')
|
||||||
|
|
||||||
print('Task completed')
|
print('Task completed')
|
||||||
yield json.dumps(res.json())
|
res = res.json()
|
||||||
|
data_queue.put(json.dumps(res))
|
||||||
|
|
||||||
|
return res
|
||||||
|
|
||||||
def save_image(img, img_out_path):
|
def save_image(img, img_out_path):
|
||||||
try:
|
try:
|
||||||
@ -664,51 +746,109 @@ def _txt2img(opt_W, opt_H, opt_n_samples, opt_ddim_steps, opt_scale, start_code,
|
|||||||
# Send to CPU and wait until complete.
|
# Send to CPU and wait until complete.
|
||||||
# wait_model_move_to(thread_data.modelCS, 'cpu')
|
# wait_model_move_to(thread_data.modelCS, 'cpu')
|
||||||
|
|
||||||
move_to_cpu(thread_data.modelCS)
|
if not thread_data.test_sd2:
|
||||||
|
move_to_cpu(thread_data.modelCS)
|
||||||
|
|
||||||
if sampler_name == 'ddim':
|
if thread_data.test_sd2 and sampler_name not in ('plms', 'ddim'):
|
||||||
thread_data.model.make_schedule(ddim_num_steps=opt_ddim_steps, ddim_eta=opt_ddim_eta, verbose=False)
|
raise Exception('Only plms and ddim samplers are supported right now, in SD 2.0')
|
||||||
|
|
||||||
samples_ddim = thread_data.model.sample(
|
|
||||||
S=opt_ddim_steps,
|
|
||||||
conditioning=c,
|
|
||||||
seed=opt_seed,
|
|
||||||
shape=shape,
|
|
||||||
verbose=False,
|
|
||||||
unconditional_guidance_scale=opt_scale,
|
|
||||||
unconditional_conditioning=uc,
|
|
||||||
eta=opt_ddim_eta,
|
|
||||||
x_T=start_code,
|
|
||||||
img_callback=img_callback,
|
|
||||||
mask=mask,
|
|
||||||
sampler = sampler_name,
|
|
||||||
)
|
|
||||||
yield from samples_ddim
|
|
||||||
|
|
||||||
def _img2img(init_latent, t_enc, batch_size, opt_scale, c, uc, opt_ddim_steps, opt_ddim_eta, opt_seed, img_callback, mask):
|
# samples, _ = sampler.sample(S=opt.steps,
|
||||||
|
# conditioning=c,
|
||||||
|
# batch_size=opt.n_samples,
|
||||||
|
# shape=shape,
|
||||||
|
# verbose=False,
|
||||||
|
# unconditional_guidance_scale=opt.scale,
|
||||||
|
# unconditional_conditioning=uc,
|
||||||
|
# eta=opt.ddim_eta,
|
||||||
|
# x_T=start_code)
|
||||||
|
|
||||||
|
if thread_data.test_sd2:
|
||||||
|
from ldm.models.diffusion.ddim import DDIMSampler
|
||||||
|
from ldm.models.diffusion.plms import PLMSSampler
|
||||||
|
|
||||||
|
shape = [opt_C, opt_H // opt_f, opt_W // opt_f]
|
||||||
|
|
||||||
|
if sampler_name == 'plms':
|
||||||
|
sampler = PLMSSampler(thread_data.model)
|
||||||
|
elif sampler_name == 'ddim':
|
||||||
|
sampler = DDIMSampler(thread_data.model)
|
||||||
|
|
||||||
|
sampler.make_schedule(ddim_num_steps=opt_ddim_steps, ddim_eta=opt_ddim_eta, verbose=False)
|
||||||
|
|
||||||
|
|
||||||
|
samples_ddim, intermediates = sampler.sample(
|
||||||
|
S=opt_ddim_steps,
|
||||||
|
conditioning=c,
|
||||||
|
batch_size=opt_n_samples,
|
||||||
|
seed=opt_seed,
|
||||||
|
shape=shape,
|
||||||
|
verbose=False,
|
||||||
|
unconditional_guidance_scale=opt_scale,
|
||||||
|
unconditional_conditioning=uc,
|
||||||
|
eta=opt_ddim_eta,
|
||||||
|
x_T=start_code,
|
||||||
|
img_callback=img_callback,
|
||||||
|
mask=mask,
|
||||||
|
sampler = sampler_name,
|
||||||
|
)
|
||||||
|
else:
|
||||||
|
if sampler_name == 'ddim':
|
||||||
|
thread_data.model.make_schedule(ddim_num_steps=opt_ddim_steps, ddim_eta=opt_ddim_eta, verbose=False)
|
||||||
|
|
||||||
|
samples_ddim = thread_data.model.sample(
|
||||||
|
S=opt_ddim_steps,
|
||||||
|
conditioning=c,
|
||||||
|
seed=opt_seed,
|
||||||
|
shape=shape,
|
||||||
|
verbose=False,
|
||||||
|
unconditional_guidance_scale=opt_scale,
|
||||||
|
unconditional_conditioning=uc,
|
||||||
|
eta=opt_ddim_eta,
|
||||||
|
x_T=start_code,
|
||||||
|
img_callback=img_callback,
|
||||||
|
mask=mask,
|
||||||
|
sampler = sampler_name,
|
||||||
|
)
|
||||||
|
return samples_ddim
|
||||||
|
|
||||||
|
def _img2img(init_latent, t_enc, batch_size, opt_scale, c, uc, opt_ddim_steps, opt_ddim_eta, opt_seed, img_callback, mask, opt_C=1, opt_H=1, opt_W=1, opt_f=1):
|
||||||
# encode (scaled latent)
|
# encode (scaled latent)
|
||||||
z_enc = thread_data.model.stochastic_encode(
|
|
||||||
init_latent,
|
|
||||||
torch.tensor([t_enc] * batch_size).to(thread_data.device),
|
|
||||||
opt_seed,
|
|
||||||
opt_ddim_eta,
|
|
||||||
opt_ddim_steps,
|
|
||||||
)
|
|
||||||
x_T = None if mask is None else init_latent
|
x_T = None if mask is None else init_latent
|
||||||
|
|
||||||
# decode it
|
if thread_data.test_sd2:
|
||||||
samples_ddim = thread_data.model.sample(
|
from ldm.models.diffusion.ddim import DDIMSampler
|
||||||
t_enc,
|
|
||||||
c,
|
sampler = DDIMSampler(thread_data.model)
|
||||||
z_enc,
|
|
||||||
unconditional_guidance_scale=opt_scale,
|
sampler.make_schedule(ddim_num_steps=opt_ddim_steps, ddim_eta=opt_ddim_eta, verbose=False)
|
||||||
unconditional_conditioning=uc,
|
|
||||||
img_callback=img_callback,
|
z_enc = sampler.stochastic_encode(init_latent, torch.tensor([t_enc] * batch_size).to(thread_data.device))
|
||||||
mask=mask,
|
|
||||||
x_T=x_T,
|
samples_ddim = sampler.decode(z_enc, c, t_enc, unconditional_guidance_scale=opt_scale,unconditional_conditioning=uc, img_callback=img_callback)
|
||||||
sampler = 'ddim'
|
|
||||||
)
|
else:
|
||||||
yield from samples_ddim
|
z_enc = thread_data.model.stochastic_encode(
|
||||||
|
init_latent,
|
||||||
|
torch.tensor([t_enc] * batch_size).to(thread_data.device),
|
||||||
|
opt_seed,
|
||||||
|
opt_ddim_eta,
|
||||||
|
opt_ddim_steps,
|
||||||
|
)
|
||||||
|
|
||||||
|
# decode it
|
||||||
|
samples_ddim = thread_data.model.sample(
|
||||||
|
t_enc,
|
||||||
|
c,
|
||||||
|
z_enc,
|
||||||
|
unconditional_guidance_scale=opt_scale,
|
||||||
|
unconditional_conditioning=uc,
|
||||||
|
img_callback=img_callback,
|
||||||
|
mask=mask,
|
||||||
|
x_T=x_T,
|
||||||
|
sampler = 'ddim'
|
||||||
|
)
|
||||||
|
return samples_ddim
|
||||||
|
|
||||||
def gc():
|
def gc():
|
||||||
gc_collect()
|
gc_collect()
|
||||||
@ -776,8 +916,16 @@ def load_mask(mask_str, h0, w0, newH, newW, invert=False):
|
|||||||
|
|
||||||
# https://stackoverflow.com/a/61114178
|
# https://stackoverflow.com/a/61114178
|
||||||
def img_to_base64_str(img, output_format="PNG"):
|
def img_to_base64_str(img, output_format="PNG"):
|
||||||
|
buffered = img_to_buffer(img, output_format)
|
||||||
|
return buffer_to_base64_str(buffered, output_format)
|
||||||
|
|
||||||
|
def img_to_buffer(img, output_format="PNG"):
|
||||||
buffered = BytesIO()
|
buffered = BytesIO()
|
||||||
img.save(buffered, format=output_format)
|
img.save(buffered, format=output_format)
|
||||||
|
buffered.seek(0)
|
||||||
|
return buffered
|
||||||
|
|
||||||
|
def buffer_to_base64_str(buffered, output_format="PNG"):
|
||||||
buffered.seek(0)
|
buffered.seek(0)
|
||||||
img_byte = buffered.getvalue()
|
img_byte = buffered.getvalue()
|
||||||
mime_type = "image/png" if output_format.lower() == "png" else "image/jpeg"
|
mime_type = "image/png" if output_format.lower() == "png" else "image/jpeg"
|
||||||
@ -795,3 +943,48 @@ def base64_str_to_img(img_str):
|
|||||||
buffered = base64_str_to_buffer(img_str)
|
buffered = base64_str_to_buffer(img_str)
|
||||||
img = Image.open(buffered)
|
img = Image.open(buffered)
|
||||||
return img
|
return img
|
||||||
|
|
||||||
|
def split_weighted_subprompts(text):
|
||||||
|
"""
|
||||||
|
grabs all text up to the first occurrence of ':'
|
||||||
|
uses the grabbed text as a sub-prompt, and takes the value following ':' as weight
|
||||||
|
if ':' has no value defined, defaults to 1.0
|
||||||
|
repeats until no text remaining
|
||||||
|
"""
|
||||||
|
remaining = len(text)
|
||||||
|
prompts = []
|
||||||
|
weights = []
|
||||||
|
while remaining > 0:
|
||||||
|
if ":" in text:
|
||||||
|
idx = text.index(":") # first occurrence from start
|
||||||
|
# grab up to index as sub-prompt
|
||||||
|
prompt = text[:idx]
|
||||||
|
remaining -= idx
|
||||||
|
# remove from main text
|
||||||
|
text = text[idx+1:]
|
||||||
|
# find value for weight
|
||||||
|
if " " in text:
|
||||||
|
idx = text.index(" ") # first occurence
|
||||||
|
else: # no space, read to end
|
||||||
|
idx = len(text)
|
||||||
|
if idx != 0:
|
||||||
|
try:
|
||||||
|
weight = float(text[:idx])
|
||||||
|
except: # couldn't treat as float
|
||||||
|
print(f"Warning: '{text[:idx]}' is not a value, are you missing a space?")
|
||||||
|
weight = 1.0
|
||||||
|
else: # no value found
|
||||||
|
weight = 1.0
|
||||||
|
# remove from main text
|
||||||
|
remaining -= idx
|
||||||
|
text = text[idx+1:]
|
||||||
|
# append the sub-prompt and its weight
|
||||||
|
prompts.append(prompt)
|
||||||
|
weights.append(weight)
|
||||||
|
else: # no : found
|
||||||
|
if len(text) > 0: # there is still text though
|
||||||
|
# take remainder as weight 1
|
||||||
|
prompts.append(text)
|
||||||
|
weights.append(1.0)
|
||||||
|
remaining = 0
|
||||||
|
return prompts, weights
|
||||||
|
@ -283,45 +283,26 @@ def thread_render(device):
|
|||||||
print(f'Session {task.request.session_id} starting task {id(task)} on {runtime.thread_data.device_name}')
|
print(f'Session {task.request.session_id} starting task {id(task)} on {runtime.thread_data.device_name}')
|
||||||
if not task.lock.acquire(blocking=False): raise Exception('Got locked task from queue.')
|
if not task.lock.acquire(blocking=False): raise Exception('Got locked task from queue.')
|
||||||
try:
|
try:
|
||||||
if runtime.thread_data.device == 'cpu' and is_alive() > 1:
|
if runtime.is_model_reload_necessary(task.request):
|
||||||
# CPU is not the only device. Keep track of active time to unload resources later.
|
|
||||||
runtime.thread_data.lastActive = time.time()
|
|
||||||
# Open data generator.
|
|
||||||
res = runtime.mk_img(task.request)
|
|
||||||
if current_model_path == task.request.use_stable_diffusion_model:
|
|
||||||
current_state = ServerStates.Rendering
|
|
||||||
else:
|
|
||||||
current_state = ServerStates.LoadingModel
|
current_state = ServerStates.LoadingModel
|
||||||
# Start reading from generator.
|
runtime.reload_model()
|
||||||
dataQueue = None
|
current_model_path = task.request.use_stable_diffusion_model
|
||||||
if task.request.stream_progress_updates:
|
current_vae_path = task.request.use_vae_model
|
||||||
dataQueue = task.buffer_queue
|
|
||||||
for result in res:
|
def step_callback():
|
||||||
if current_state == ServerStates.LoadingModel:
|
global current_state_error
|
||||||
current_state = ServerStates.Rendering
|
|
||||||
current_model_path = task.request.use_stable_diffusion_model
|
|
||||||
current_vae_path = task.request.use_vae_model
|
|
||||||
if isinstance(current_state_error, SystemExit) or isinstance(current_state_error, StopAsyncIteration) or isinstance(task.error, StopAsyncIteration):
|
if isinstance(current_state_error, SystemExit) or isinstance(current_state_error, StopAsyncIteration) or isinstance(task.error, StopAsyncIteration):
|
||||||
runtime.thread_data.stop_processing = True
|
runtime.thread_data.stop_processing = True
|
||||||
if isinstance(current_state_error, StopAsyncIteration):
|
if isinstance(current_state_error, StopAsyncIteration):
|
||||||
task.error = current_state_error
|
task.error = current_state_error
|
||||||
current_state_error = None
|
current_state_error = None
|
||||||
print(f'Session {task.request.session_id} sent cancel signal for task {id(task)}')
|
print(f'Session {task.request.session_id} sent cancel signal for task {id(task)}')
|
||||||
if dataQueue:
|
|
||||||
dataQueue.put(result)
|
task_cache.keep(task.request.session_id, TASK_TTL)
|
||||||
if isinstance(result, str):
|
|
||||||
result = json.loads(result)
|
current_state = ServerStates.Rendering
|
||||||
task.response = result
|
task.response = runtime.mk_img(task.request, task.buffer_queue, task.temp_images, step_callback)
|
||||||
if 'output' in result:
|
|
||||||
for out_obj in result['output']:
|
|
||||||
if 'path' in out_obj:
|
|
||||||
img_id = out_obj['path'][out_obj['path'].rindex('/') + 1:]
|
|
||||||
task.temp_images[int(img_id)] = runtime.thread_data.temp_images[out_obj['path'][11:]]
|
|
||||||
elif 'data' in out_obj:
|
|
||||||
buf = runtime.base64_str_to_buffer(out_obj['data'])
|
|
||||||
task.temp_images[result['output'].index(out_obj)] = buf
|
|
||||||
# Before looping back to the generator, mark cache as still alive.
|
|
||||||
task_cache.keep(task.request.session_id, TASK_TTL)
|
|
||||||
except Exception as e:
|
except Exception as e:
|
||||||
task.error = e
|
task.error = e
|
||||||
print(traceback.format_exc())
|
print(traceback.format_exc())
|
||||||
|
16
ui/server.py
16
ui/server.py
@ -116,6 +116,8 @@ def setConfig(config):
|
|||||||
bind_ip = '0.0.0.0' if config['net']['listen_to_network'] else '127.0.0.1'
|
bind_ip = '0.0.0.0' if config['net']['listen_to_network'] else '127.0.0.1'
|
||||||
config_bat.append(f"@set SD_UI_BIND_IP={bind_ip}")
|
config_bat.append(f"@set SD_UI_BIND_IP={bind_ip}")
|
||||||
|
|
||||||
|
config_bat.append(f"@set test_sd2={'Y' if config.get('test_sd2', False) else 'N'}")
|
||||||
|
|
||||||
if len(config_bat) > 0:
|
if len(config_bat) > 0:
|
||||||
with open(config_bat_path, 'w', encoding='utf-8') as f:
|
with open(config_bat_path, 'w', encoding='utf-8') as f:
|
||||||
f.write('\r\n'.join(config_bat))
|
f.write('\r\n'.join(config_bat))
|
||||||
@ -133,6 +135,8 @@ def setConfig(config):
|
|||||||
bind_ip = '0.0.0.0' if config['net']['listen_to_network'] else '127.0.0.1'
|
bind_ip = '0.0.0.0' if config['net']['listen_to_network'] else '127.0.0.1'
|
||||||
config_sh.append(f"export SD_UI_BIND_IP={bind_ip}")
|
config_sh.append(f"export SD_UI_BIND_IP={bind_ip}")
|
||||||
|
|
||||||
|
config_sh.append(f"export test_sd2=\"{'Y' if config.get('test_sd2', False) else 'N'}\"")
|
||||||
|
|
||||||
if len(config_sh) > 1:
|
if len(config_sh) > 1:
|
||||||
with open(config_sh_path, 'w', encoding='utf-8') as f:
|
with open(config_sh_path, 'w', encoding='utf-8') as f:
|
||||||
f.write('\n'.join(config_sh))
|
f.write('\n'.join(config_sh))
|
||||||
@ -140,12 +144,19 @@ def setConfig(config):
|
|||||||
print(traceback.format_exc())
|
print(traceback.format_exc())
|
||||||
|
|
||||||
def resolve_model_to_use(model_name:str, model_type:str, model_dir:str, model_extensions:list, default_models=[]):
|
def resolve_model_to_use(model_name:str, model_type:str, model_dir:str, model_extensions:list, default_models=[]):
|
||||||
|
config = getConfig()
|
||||||
|
|
||||||
model_dirs = [os.path.join(MODELS_DIR, model_dir), SD_DIR]
|
model_dirs = [os.path.join(MODELS_DIR, model_dir), SD_DIR]
|
||||||
if not model_name: # When None try user configured model.
|
if not model_name: # When None try user configured model.
|
||||||
config = getConfig()
|
# config = getConfig()
|
||||||
if 'model' in config and model_type in config['model']:
|
if 'model' in config and model_type in config['model']:
|
||||||
model_name = config['model'][model_type]
|
model_name = config['model'][model_type]
|
||||||
if model_name:
|
if model_name:
|
||||||
|
is_sd2 = config.get('test_sd2', False)
|
||||||
|
if model_name.startswith('sd2_') and not is_sd2: # temp hack, until SD2 is unified with 1.4
|
||||||
|
print('ERROR: Cannot use SD 2.0 models with SD 1.0 code. Using the sd-v1-4 model instead!')
|
||||||
|
model_name = 'sd-v1-4'
|
||||||
|
|
||||||
# Check models directory
|
# Check models directory
|
||||||
models_dir_path = os.path.join(MODELS_DIR, model_dir, model_name)
|
models_dir_path = os.path.join(MODELS_DIR, model_dir, model_name)
|
||||||
for model_extension in model_extensions:
|
for model_extension in model_extensions:
|
||||||
@ -188,6 +199,7 @@ class SetAppConfigRequest(BaseModel):
|
|||||||
ui_open_browser_on_start: bool = None
|
ui_open_browser_on_start: bool = None
|
||||||
listen_to_network: bool = None
|
listen_to_network: bool = None
|
||||||
listen_port: int = None
|
listen_port: int = None
|
||||||
|
test_sd2: bool = None
|
||||||
|
|
||||||
@app.post('/app_config')
|
@app.post('/app_config')
|
||||||
async def setAppConfig(req : SetAppConfigRequest):
|
async def setAppConfig(req : SetAppConfigRequest):
|
||||||
@ -208,6 +220,8 @@ async def setAppConfig(req : SetAppConfigRequest):
|
|||||||
if 'net' not in config:
|
if 'net' not in config:
|
||||||
config['net'] = {}
|
config['net'] = {}
|
||||||
config['net']['listen_port'] = int(req.listen_port)
|
config['net']['listen_port'] = int(req.listen_port)
|
||||||
|
if req.test_sd2 is not None:
|
||||||
|
config['test_sd2'] = req.test_sd2
|
||||||
try:
|
try:
|
||||||
setConfig(config)
|
setConfig(config)
|
||||||
|
|
||||||
|
Loading…
Reference in New Issue
Block a user