Merge pull request #576 from cmdr2/beta

v2.4.16 - Remove the use of git-apply
Authored by cmdr2 on 2022-11-30 12:12:17 +05:30; committed via GitHub
commit 2e69ffcb5e
12 changed files with 480 additions and 151 deletions

View File

@ -21,6 +21,8 @@
- A `What's New?` tab in the UI
### Detailed changelog
* 2.4.16 - 29 Nov 2022 - Bug fixes for SD 2.0 - remove the need for patching, and default to the SD 1.4 model if an SD 2.0 model is selected while SD 2.0 mode is disabled.
* 2.4.15 - 25 Nov 2022 - Experimental support for SD 2.0. Uses lots of memory, not optimized, probably GPU-only.
* 2.4.14 - 22 Nov 2022 - Change the backend to a custom fork of Stable Diffusion
* 2.4.13 - 21 Nov 2022 - Change the modifier weight via mouse wheel, drag to reorder selected modifiers, and some more modifier-related fixes. Thanks @patriceac
* 2.4.12 - 21 Nov 2022 - Another fix for improving how long images take to generate. Reduces the time taken for an enqueued task to start processing.

View File

@ -29,6 +29,18 @@ call conda activate .\stable-diffusion\env
call where python
call python --version
@rem set the PYTHONPATH
cd stable-diffusion
set SD_DIR=%cd%
cd env\lib\site-packages
set PYTHONPATH=%SD_DIR%;%cd%
cd ..\..\..
echo PYTHONPATH=%PYTHONPATH%
cd ..
@rem done
echo.
cmd /k

View File

@ -35,6 +35,15 @@ if [ "$0" == "bash" ]; then
which python
python --version
# set the PYTHONPATH
cd stable-diffusion
SD_PATH=`pwd`
export PYTHONPATH="$SD_PATH:$SD_PATH/env/lib/python3.8/site-packages"
echo "PYTHONPATH=$PYTHONPATH"
cd ..
# done
echo ""
else
file_name=$(basename "${BASH_SOURCE[0]}")
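Both developer consoles now put the stable-diffusion checkout and its bundled site-packages on PYTHONPATH, so the SD modules can be imported directly. A quick, illustrative check (run from inside the developer console; assumes the default folder layout):

# illustrative sanity check, not part of this commit
import sys
print([p for p in sys.path if 'stable-diffusion' in p.lower()])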

View File

@ -27,6 +27,8 @@ if exist "Open Developer Console.cmd" del "Open Developer Console.cmd"
@call python -c "import os; import shutil; frm = 'sd-ui-files\\ui\\hotfix\\9c24e6cd9f499d02c4f21a033736dabd365962dc80fe3aeb57a8f85ea45a20a3.26fead7ea4f0f843f6eb4055dfd25693f1a71f3c6871b184042d4b126244e142'; dst = os.path.join(os.path.expanduser('~'), '.cache', 'huggingface', 'transformers', '9c24e6cd9f499d02c4f21a033736dabd365962dc80fe3aeb57a8f85ea45a20a3.26fead7ea4f0f843f6eb4055dfd25693f1a71f3c6871b184042d4b126244e142'); shutil.copyfile(frm, dst) if os.path.exists(dst) else print(''); print('Hotfixed broken JSON file from OpenAI');"
if NOT DEFINED test_sd2 set test_sd2=N
@>nul findstr /m "sd_git_cloned" scripts\install_status.txt
@if "%ERRORLEVEL%" EQU "0" (
@echo "Stable Diffusion's git repository was already installed. Updating.."
@ -37,9 +39,13 @@ if exist "Open Developer Console.cmd" del "Open Developer Console.cmd"
@call git reset --hard
@call git pull
@call git -c advice.detachedHead=false checkout 7f32368ed1030a6e710537047bacd908adea183a
@call git apply --whitespace=warn ..\ui\sd_internal\ddim_callback.patch
if "%test_sd2%" == "N" (
@call git -c advice.detachedHead=false checkout 7f32368ed1030a6e710537047bacd908adea183a
)
if "%test_sd2%" == "Y" (
@call git -c advice.detachedHead=false checkout 8878d67decd3deb3c98472c1e39d2a51dc5950f9
)
@cd ..
) else (
@ -56,8 +62,6 @@ if exist "Open Developer Console.cmd" del "Open Developer Console.cmd"
@cd stable-diffusion
@call git -c advice.detachedHead=false checkout 7f32368ed1030a6e710537047bacd908adea183a
@call git apply --whitespace=warn ..\ui\sd_internal\ddim_callback.patch
@cd ..
)
@ -346,7 +350,9 @@ echo. > "..\models\vae\Put your VAE files here.txt"
)
)
if "%test_sd2%" == "Y" (
@call pip install open_clip_torch==2.0.2
)
@>nul findstr /m "sd_install_complete" ..\scripts\install_status.txt
@if "%ERRORLEVEL%" NEQ "0" (

View File

@ -21,6 +21,10 @@ python -c "import os; import shutil; frm = 'sd-ui-files/ui/hotfix/9c24e6cd9f499d
# Caution, this file will make your eyes and brain bleed. It's such an unholy mess.
# Note to self: Please rewrite this in Python. For the sake of your own sanity.
if [ "$test_sd2" == "" ]; then
export test_sd2="N"
fi
if [ -e "scripts/install_status.txt" ] && [ `grep -c sd_git_cloned scripts/install_status.txt` -gt "0" ]; then
echo "Stable Diffusion's git repository was already installed. Updating.."
@ -30,9 +34,12 @@ if [ -e "scripts/install_status.txt" ] && [ `grep -c sd_git_cloned scripts/insta
git reset --hard
git pull
git -c advice.detachedHead=false checkout 7f32368ed1030a6e710537047bacd908adea183a
git apply --whitespace=warn ../ui/sd_internal/ddim_callback.patch || fail "ddim patch failed"
if [ "$test_sd2" == "N" ]; then
git -c advice.detachedHead=false checkout 7f32368ed1030a6e710537047bacd908adea183a
elif [ "$test_sd2" == "Y" ]; then
git -c advice.detachedHead=false checkout 8878d67decd3deb3c98472c1e39d2a51dc5950f9
fi
cd ..
else
@ -47,8 +54,6 @@ else
cd stable-diffusion
git -c advice.detachedHead=false checkout 7f32368ed1030a6e710537047bacd908adea183a
git apply --whitespace=warn ../ui/sd_internal/ddim_callback.patch || fail "ddim patch failed"
cd ..
fi
@ -291,6 +296,9 @@ if [ ! -f "../models/vae/vae-ft-mse-840000-ema-pruned.ckpt" ]; then
fi
fi
if [ "$test_sd2" == "Y" ]; then
pip install open_clip_torch==2.0.2
fi
if [ `grep -c sd_install_complete ../scripts/install_status.txt` -gt "0" ]; then
echo sd_weights_downloaded >> ../scripts/install_status.txt

View File

@ -22,7 +22,7 @@
<div id="logo"> <div id="logo">
<h1> <h1>
Stable Diffusion UI Stable Diffusion UI
<small>v2.4.14 <span id="updateBranchLabel"></span></small> <small>v2.4.16 <span id="updateBranchLabel"></span></small>
</h1> </h1>
</div> </div>
<div id="server-status"> <div id="server-status">

View File

@ -210,7 +210,7 @@ code {
}
.collapsible-content {
display: block;
padding-left: 15px;
padding-left: 10px;
}
.collapsible-content h5 {
padding: 5pt 0pt;
@ -658,11 +658,15 @@ input::file-selector-button {
opacity: 1;
}
/* MOBILE SUPPORT */
@media screen and (max-width: 700px) {
/* Small screens */
@media screen and (max-width: 1265px) {
#top-nav {
flex-direction: column;
}
}
/* MOBILE SUPPORT */
@media screen and (max-width: 700px) {
body {
margin: 0px;
}
@ -712,7 +716,7 @@ input::file-selector-button {
padding-right: 0px;
}
#server-status {
display: none;
top: 75%;
}
.popup > div {
padding-left: 5px !important;

View File

@ -132,6 +132,14 @@ var PARAMETERS = [
return `<input id="${parameter.id}" name="${parameter.id}" size="6" value="9000" onkeypress="preventNonNumericalInput(event)">`
}
},
{
id: "test_sd2",
type: ParameterType.checkbox,
label: "Test SD 2.0",
note: "Experimental! High memory usage! GPU-only! Not the final version! Please restart the program after changing this.",
icon: "fa-fire",
default: false,
},
{
id: "use_beta_channel",
type: ParameterType.checkbox,
@ -196,6 +204,7 @@ let saveToDiskField = document.querySelector('#save_to_disk')
let diskPathField = document.querySelector('#diskPath')
let listenToNetworkField = document.querySelector("#listen_to_network")
let listenPortField = document.querySelector("#listen_port")
let testSD2Field = document.querySelector("#test_sd2")
let useBetaChannelField = document.querySelector("#use_beta_channel")
let uiOpenBrowserOnStartField = document.querySelector("#ui_open_browser_on_start")
@ -230,6 +239,12 @@ async function getAppConfig() {
if (config.ui && config.ui.open_browser_on_start === false) {
uiOpenBrowserOnStartField.checked = false
}
if ('test_sd2' in config) {
testSD2Field.checked = config['test_sd2']
}
let testSD2SettingEntry = getParameterSettingsEntry('test_sd2')
testSD2SettingEntry.style.display = (config.update_branch === 'beta' ? '' : 'none')
if (config.net && config.net.listen_to_network === false) {
listenToNetworkField.checked = false
}
@ -372,7 +387,8 @@ saveSettingsBtn.addEventListener('click', function() {
'update_branch': updateBranch,
'ui_open_browser_on_start': uiOpenBrowserOnStartField.checked,
'listen_to_network': listenToNetworkField.checked,
'listen_port': listenPortField.value
'listen_port': listenPortField.value,
'test_sd2': testSD2Field.checked
})
}

View File

@ -0,0 +1,84 @@
diff --git a/ldm/models/diffusion/ddim.py b/ldm/models/diffusion/ddim.py
index 27ead0e..6215939 100644
--- a/ldm/models/diffusion/ddim.py
+++ b/ldm/models/diffusion/ddim.py
@@ -100,7 +100,7 @@ class DDIMSampler(object):
size = (batch_size, C, H, W)
print(f'Data shape for DDIM sampling is {size}, eta {eta}')
- samples, intermediates = self.ddim_sampling(conditioning, size,
+ samples = self.ddim_sampling(conditioning, size,
callback=callback,
img_callback=img_callback,
quantize_denoised=quantize_x0,
@@ -117,7 +117,8 @@ class DDIMSampler(object):
dynamic_threshold=dynamic_threshold,
ucg_schedule=ucg_schedule
)
- return samples, intermediates
+ # return samples, intermediates
+ yield from samples
@torch.no_grad()
def ddim_sampling(self, cond, shape,
@@ -168,14 +169,15 @@ class DDIMSampler(object):
unconditional_conditioning=unconditional_conditioning,
dynamic_threshold=dynamic_threshold)
img, pred_x0 = outs
- if callback: callback(i)
- if img_callback: img_callback(pred_x0, i)
+ if callback: yield from callback(i)
+ if img_callback: yield from img_callback(pred_x0, i)
if index % log_every_t == 0 or index == total_steps - 1:
intermediates['x_inter'].append(img)
intermediates['pred_x0'].append(pred_x0)
- return img, intermediates
+ # return img, intermediates
+ yield from img_callback(pred_x0, len(iterator)-1)
@torch.no_grad()
def p_sample_ddim(self, x, c, t, index, repeat_noise=False, use_original_steps=False, quantize_denoised=False,
diff --git a/ldm/models/diffusion/plms.py b/ldm/models/diffusion/plms.py
index 7002a36..0951f39 100644
--- a/ldm/models/diffusion/plms.py
+++ b/ldm/models/diffusion/plms.py
@@ -96,7 +96,7 @@ class PLMSSampler(object):
size = (batch_size, C, H, W)
print(f'Data shape for PLMS sampling is {size}')
- samples, intermediates = self.plms_sampling(conditioning, size,
+ samples = self.plms_sampling(conditioning, size,
callback=callback,
img_callback=img_callback,
quantize_denoised=quantize_x0,
@@ -112,7 +112,8 @@ class PLMSSampler(object):
unconditional_conditioning=unconditional_conditioning,
dynamic_threshold=dynamic_threshold,
)
- return samples, intermediates
+ #return samples, intermediates
+ yield from samples
@torch.no_grad()
def plms_sampling(self, cond, shape,
@@ -165,14 +166,15 @@ class PLMSSampler(object):
old_eps.append(e_t)
if len(old_eps) >= 4:
old_eps.pop(0)
- if callback: callback(i)
- if img_callback: img_callback(pred_x0, i)
+ if callback: yield from callback(i)
+ if img_callback: yield from img_callback(pred_x0, i)
if index % log_every_t == 0 or index == total_steps - 1:
intermediates['x_inter'].append(img)
intermediates['pred_x0'].append(pred_x0)
- return img, intermediates
+ # return img, intermediates
+ yield from img_callback(pred_x0, len(iterator)-1)
@torch.no_grad()
def p_sample_plms(self, x, c, t, index, repeat_noise=False, use_original_steps=False, quantize_denoised=False,

View File

@ -7,6 +7,7 @@ Notes:
import json
import os, re
import traceback
import queue
import torch
import numpy as np
from gc import collect as gc_collect
@ -21,7 +22,6 @@ from torch import autocast
from contextlib import nullcontext
from einops import rearrange, repeat
from ldm.util import instantiate_from_config
from optimizedSD.optimUtils import split_weighted_subprompts
from transformers import logging
from gfpgan import GFPGANer
@ -76,8 +76,24 @@ def thread_init(device):
thread_data.force_full_precision = False
thread_data.reduced_memory = True
thread_data.test_sd2 = isSD2()
device_manager.device_init(thread_data, device)
# temp hack, will remove soon
def isSD2():
try:
SD_UI_DIR = os.getenv('SD_UI_PATH', None)
CONFIG_DIR = os.path.abspath(os.path.join(SD_UI_DIR, '..', 'scripts'))
config_json_path = os.path.join(CONFIG_DIR, 'config.json')
if not os.path.exists(config_json_path):
return False
with open(config_json_path, 'r', encoding='utf-8') as f:
config = json.load(f)
return config.get('test_sd2', False)
except Exception as e:
return False
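isSD2() is a temporary hack (as the comment above says): it just reads the test_sd2 flag from scripts/config.json, which the server writes when the setting is toggled in the UI (see the ui/server.py changes below). A minimal sketch of the check, assuming the flag has been enabled:

import json, os
# scripts/config.json would contain {"test_sd2": true} once the UI setting is saved
with open(os.path.join('scripts', 'config.json'), 'r', encoding='utf-8') as f:
    print(json.load(f).get('test_sd2', False))   # -> True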
def load_model_ckpt():
if not thread_data.ckpt_file: raise ValueError(f'Thread ckpt_file is undefined.')
if not os.path.exists(thread_data.ckpt_file + '.ckpt'): raise FileNotFoundError(f'Cannot find {thread_data.ckpt_file}.ckpt')
@ -92,6 +108,13 @@ def load_model_ckpt():
thread_data.precision = 'full'
print('loading', thread_data.ckpt_file + '.ckpt', 'to device', thread_data.device, 'using precision', thread_data.precision)
if thread_data.test_sd2:
load_model_ckpt_sd2()
else:
load_model_ckpt_sd1()
def load_model_ckpt_sd1():
sd = load_model_from_config(thread_data.ckpt_file + '.ckpt')
li, lo = [], []
for key, value in sd.items():
@ -185,6 +208,38 @@ def load_model_ckpt():
modelFS.device: {thread_data.modelFS.device}
using precision: {thread_data.precision}''')
def load_model_ckpt_sd2():
config_file = 'configs/stable-diffusion/v2-inference-v.yaml' if 'sd2_' in thread_data.ckpt_file else "configs/stable-diffusion/v1-inference.yaml"
config = OmegaConf.load(config_file)
verbose = False
sd = load_model_from_config(thread_data.ckpt_file + '.ckpt')
thread_data.model = instantiate_from_config(config.model)
m, u = thread_data.model.load_state_dict(sd, strict=False)
if len(m) > 0 and verbose:
print("missing keys:")
print(m)
if len(u) > 0 and verbose:
print("unexpected keys:")
print(u)
thread_data.model.to(thread_data.device)
thread_data.model.eval()
del sd
if thread_data.device != "cpu" and thread_data.precision == "autocast":
thread_data.model.half()
thread_data.model_is_half = True
thread_data.model_fs_is_half = True
else:
thread_data.model_is_half = False
thread_data.model_fs_is_half = False
print(f'''loaded model
model file: {thread_data.ckpt_file}.ckpt
using precision: {thread_data.precision}''')
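load_model_ckpt_sd2() selects the SD 2.x inference config purely from the checkpoint filename: only names containing 'sd2_' get the v2-inference-v.yaml config, everything else falls back to the SD 1.x config. An illustrative sketch of that selection (the checkpoint name below is hypothetical):

ckpt_file = 'models/stable-diffusion/sd2_768-v-ema'   # hypothetical checkpoint path, no extension
config_file = 'configs/stable-diffusion/v2-inference-v.yaml' if 'sd2_' in ckpt_file else 'configs/stable-diffusion/v1-inference.yaml'
print(config_file)   # -> configs/stable-diffusion/v2-inference-v.yaml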
def unload_filters():
if thread_data.model_gfpgan is not None:
if thread_data.device != 'cpu': thread_data.model_gfpgan.gfpgan.to('cpu')
@ -204,10 +259,11 @@ def unload_models():
if thread_data.model is not None:
print('Unloading models...')
if thread_data.device != 'cpu':
thread_data.modelFS.to('cpu')
thread_data.modelCS.to('cpu')
thread_data.model.model1.to("cpu")
thread_data.model.model2.to("cpu")
if not thread_data.test_sd2:
thread_data.modelFS.to('cpu')
thread_data.modelCS.to('cpu')
thread_data.model.model1.to("cpu")
thread_data.model.model2.to("cpu")
del thread_data.model
del thread_data.modelCS
@ -337,45 +393,73 @@ def apply_filters(filter_name, image_data, model_path=None):
return image_data
def mk_img(req: Request):
def is_model_reload_necessary(req: Request):
# custom model support:
# the req.use_stable_diffusion_model needs to be a valid path
# to the ckpt file (without the extension).
if not os.path.exists(req.use_stable_diffusion_model + '.ckpt'): raise FileNotFoundError(f'Cannot find {req.use_stable_diffusion_model}.ckpt')
needs_model_reload = False
if not thread_data.model or thread_data.ckpt_file != req.use_stable_diffusion_model or thread_data.vae_file != req.use_vae_model:
thread_data.ckpt_file = req.use_stable_diffusion_model
thread_data.vae_file = req.use_vae_model
needs_model_reload = True
if thread_data.device != 'cpu':
if (thread_data.precision == 'autocast' and (req.use_full_precision or not thread_data.model_is_half)) or \
(thread_data.precision == 'full' and not req.use_full_precision and not thread_data.force_full_precision):
thread_data.precision = 'full' if req.use_full_precision else 'autocast'
needs_model_reload = True
return needs_model_reload
def reload_model():
unload_models()
unload_filters()
load_model_ckpt()
def mk_img(req: Request, data_queue: queue.Queue, task_temp_images: list, step_callback):
try:
yield from do_mk_img(req)
return do_mk_img(req, data_queue, task_temp_images, step_callback)
except Exception as e:
print(traceback.format_exc())
if thread_data.device != 'cpu':
if thread_data.device != 'cpu' and not thread_data.test_sd2:
thread_data.modelFS.to('cpu')
thread_data.modelCS.to('cpu')
thread_data.model.model1.to("cpu")
thread_data.model.model2.to("cpu")
gc() # Release from memory.
yield json.dumps({
data_queue.put(json.dumps({
"status": 'failed',
"detail": str(e)
})
}))
raise e
def update_temp_img(req, x_samples):
def update_temp_img(req, x_samples, task_temp_images: list):
partial_images = []
for i in range(req.num_outputs):
x_sample_ddim = thread_data.modelFS.decode_first_stage(x_samples[i].unsqueeze(0))
if thread_data.test_sd2:
x_sample_ddim = thread_data.model.decode_first_stage(x_samples[i].unsqueeze(0))
else:
x_sample_ddim = thread_data.modelFS.decode_first_stage(x_samples[i].unsqueeze(0))
x_sample = torch.clamp((x_sample_ddim + 1.0) / 2.0, min=0.0, max=1.0)
x_sample = 255.0 * rearrange(x_sample[0].cpu().numpy(), "c h w -> h w c")
x_sample = x_sample.astype(np.uint8)
img = Image.fromarray(x_sample)
buf = BytesIO()
buf = img_to_buffer(img, output_format='JPEG')
img.save(buf, format='JPEG')
buf.seek(0)
del img, x_sample, x_sample_ddim
# don't delete x_samples, it is used in the code that called this callback
thread_data.temp_images[str(req.session_id) + '/' + str(i)] = buf
task_temp_images[i] = buf
partial_images.append({'path': f'/image/tmp/{req.session_id}/{i}'})
return partial_images
# Build and return the apropriate generator for do_mk_img
def get_image_progress_generator(req, extra_props=None):
def get_image_progress_generator(req, data_queue: queue.Queue, task_temp_images: list, step_callback, extra_props=None):
if not req.stream_progress_updates:
def empty_callback(x_samples, i): return x_samples
return empty_callback
@ -394,15 +478,17 @@ def get_image_progress_generator(req, extra_props=None):
progress.update(extra_props)
if req.stream_image_progress and i % 5 == 0:
progress['output'] = update_temp_img(req, x_samples)
progress['output'] = update_temp_img(req, x_samples, task_temp_images)
yield json.dumps(progress)
data_queue.put(json.dumps(progress))
step_callback()
if thread_data.stop_processing:
raise UserInitiatedStop("User requested that we stop processing")
return img_callback
def do_mk_img(req: Request):
def do_mk_img(req: Request, data_queue: queue.Queue, task_temp_images: list, step_callback):
thread_data.stop_processing = False
res = Response()
@ -411,29 +497,7 @@ def do_mk_img(req: Request):
thread_data.temp_images.clear()
# custom model support:
# the req.use_stable_diffusion_model needs to be a valid path
# to the ckpt file (without the extension).
if not os.path.exists(req.use_stable_diffusion_model + '.ckpt'): raise FileNotFoundError(f'Cannot find {req.use_stable_diffusion_model}.ckpt')
needs_model_reload = False
if not thread_data.model or thread_data.ckpt_file != req.use_stable_diffusion_model or thread_data.vae_file != req.use_vae_model:
thread_data.ckpt_file = req.use_stable_diffusion_model
thread_data.vae_file = req.use_vae_model
needs_model_reload = True
if thread_data.device != 'cpu':
if (thread_data.precision == 'autocast' and (req.use_full_precision or not thread_data.model_is_half)) or \
(thread_data.precision == 'full' and not req.use_full_precision and not thread_data.force_full_precision):
thread_data.precision = 'full' if req.use_full_precision else 'autocast'
needs_model_reload = True
if needs_model_reload:
unload_models()
unload_filters()
load_model_ckpt()
if thread_data.turbo != req.turbo:
if thread_data.turbo != req.turbo and not thread_data.test_sd2:
thread_data.turbo = req.turbo
thread_data.model.turbo = req.turbo
@ -478,10 +542,14 @@ def do_mk_img(req: Request):
if thread_data.device != "cpu" and thread_data.precision == "autocast": if thread_data.device != "cpu" and thread_data.precision == "autocast":
init_image = init_image.half() init_image = init_image.half()
thread_data.modelFS.to(thread_data.device) if not thread_data.test_sd2:
thread_data.modelFS.to(thread_data.device)
init_image = repeat(init_image, '1 ... -> b ...', b=batch_size) init_image = repeat(init_image, '1 ... -> b ...', b=batch_size)
init_latent = thread_data.modelFS.get_first_stage_encoding(thread_data.modelFS.encode_first_stage(init_image)) # move to latent space if thread_data.test_sd2:
init_latent = thread_data.model.get_first_stage_encoding(thread_data.model.encode_first_stage(init_image)) # move to latent space
else:
init_latent = thread_data.modelFS.get_first_stage_encoding(thread_data.modelFS.encode_first_stage(init_image)) # move to latent space
if req.mask is not None: if req.mask is not None:
mask = load_mask(req.mask, req.width, req.height, init_latent.shape[2], init_latent.shape[3], True).to(thread_data.device) mask = load_mask(req.mask, req.width, req.height, init_latent.shape[2], init_latent.shape[3], True).to(thread_data.device)
@ -493,7 +561,8 @@ def do_mk_img(req: Request):
# Send to CPU and wait until complete.
# wait_model_move_to(thread_data.modelFS, 'cpu')
move_to_cpu(thread_data.modelFS)
if not thread_data.test_sd2:
move_to_cpu(thread_data.modelFS)
assert 0. <= req.prompt_strength <= 1., 'can only work with strength in [0.0, 1.0]'
t_enc = int(req.prompt_strength * req.num_inference_steps)
@ -509,11 +578,14 @@ def do_mk_img(req: Request):
for prompts in tqdm(data, desc="data"):
with precision_scope("cuda"):
if thread_data.reduced_memory:
if thread_data.reduced_memory and not thread_data.test_sd2:
thread_data.modelCS.to(thread_data.device)
uc = None
if req.guidance_scale != 1.0:
uc = thread_data.modelCS.get_learned_conditioning(batch_size * [req.negative_prompt])
if thread_data.test_sd2:
uc = thread_data.model.get_learned_conditioning(batch_size * [req.negative_prompt])
else:
uc = thread_data.modelCS.get_learned_conditioning(batch_size * [req.negative_prompt])
if isinstance(prompts, tuple):
prompts = list(prompts)
@ -526,15 +598,21 @@ def do_mk_img(req: Request):
weight = weights[i]
# if not skip_normalize:
weight = weight / totalWeight
c = torch.add(c, thread_data.modelCS.get_learned_conditioning(subprompts[i]), alpha=weight)
if thread_data.test_sd2:
c = torch.add(c, thread_data.model.get_learned_conditioning(subprompts[i]), alpha=weight)
else:
c = torch.add(c, thread_data.modelCS.get_learned_conditioning(subprompts[i]), alpha=weight)
else:
c = thread_data.modelCS.get_learned_conditioning(prompts)
if thread_data.test_sd2:
c = thread_data.model.get_learned_conditioning(prompts)
else:
c = thread_data.modelCS.get_learned_conditioning(prompts)
if thread_data.reduced_memory:
if thread_data.reduced_memory and not thread_data.test_sd2:
thread_data.modelFS.to(thread_data.device)
n_steps = req.num_inference_steps if req.init_image is None else t_enc
img_callback = get_image_progress_generator(req, {"total_steps": n_steps})
img_callback = get_image_progress_generator(req, data_queue, task_temp_images, step_callback, {"total_steps": n_steps})
# run the handler
try:
@ -542,14 +620,7 @@ def do_mk_img(req: Request):
if handler == _txt2img:
x_samples = _txt2img(req.width, req.height, req.num_outputs, req.num_inference_steps, req.guidance_scale, None, opt_C, opt_f, opt_ddim_eta, c, uc, opt_seed, img_callback, mask, req.sampler)
else:
x_samples = _img2img(init_latent, t_enc, batch_size, req.guidance_scale, c, uc, req.num_inference_steps, opt_ddim_eta, opt_seed, img_callback, mask)
x_samples = _img2img(init_latent, t_enc, batch_size, req.guidance_scale, c, uc, req.num_inference_steps, opt_ddim_eta, opt_seed, img_callback, mask, opt_C, req.height, req.width, opt_f)
if req.stream_progress_updates:
yield from x_samples
if hasattr(thread_data, 'partial_x_samples'):
if thread_data.partial_x_samples is not None:
x_samples = thread_data.partial_x_samples
del thread_data.partial_x_samples
except UserInitiatedStop:
if not hasattr(thread_data, 'partial_x_samples'):
continue
@ -562,7 +633,10 @@ def do_mk_img(req: Request):
print("decoding images") print("decoding images")
img_data = [None] * batch_size img_data = [None] * batch_size
for i in range(batch_size): for i in range(batch_size):
x_samples_ddim = thread_data.modelFS.decode_first_stage(x_samples[i].unsqueeze(0)) if thread_data.test_sd2:
x_samples_ddim = thread_data.model.decode_first_stage(x_samples[i].unsqueeze(0))
else:
x_samples_ddim = thread_data.modelFS.decode_first_stage(x_samples[i].unsqueeze(0))
x_sample = torch.clamp((x_samples_ddim + 1.0) / 2.0, min=0.0, max=1.0) x_sample = torch.clamp((x_samples_ddim + 1.0) / 2.0, min=0.0, max=1.0)
x_sample = 255.0 * rearrange(x_sample[0].cpu().numpy(), "c h w -> h w c") x_sample = 255.0 * rearrange(x_sample[0].cpu().numpy(), "c h w -> h w c")
x_sample = x_sample.astype(np.uint8) x_sample = x_sample.astype(np.uint8)
@ -591,9 +665,11 @@ def do_mk_img(req: Request):
save_metadata(meta_out_path, req, prompts[0], opt_seed)
if return_orig_img:
img_str = img_to_base64_str(img, req.output_format)
img_buffer = img_to_buffer(img, req.output_format)
img_str = buffer_to_base64_str(img_buffer, req.output_format)
res_image_orig = ResponseImage(data=img_str, seed=opt_seed)
res.images.append(res_image_orig)
task_temp_images[i] = img_buffer
if req.save_to_disk_path is not None:
res_image_orig.path_abs = img_out_path
@ -609,9 +685,11 @@ def do_mk_img(req: Request):
filters_applied.append(req.use_upscale)
if (len(filters_applied) > 0):
filtered_image = Image.fromarray(img_data[i])
filtered_img_data = img_to_base64_str(filtered_image, req.output_format)
filtered_buffer = img_to_buffer(filtered_image, req.output_format)
filtered_img_data = buffer_to_base64_str(filtered_buffer, req.output_format)
response_image = ResponseImage(data=filtered_img_data, seed=opt_seed)
res.images.append(response_image)
task_temp_images[i] = filtered_buffer
if req.save_to_disk_path is not None:
filtered_img_out_path = get_base_path(req.save_to_disk_path, req.session_id, prompts[0], img_id, req.output_format, "_".join(filters_applied))
save_image(filtered_image, filtered_img_out_path)
@ -622,14 +700,18 @@ def do_mk_img(req: Request):
# if thread_data.reduced_memory:
# unload_filters()
move_to_cpu(thread_data.modelFS)
if not thread_data.test_sd2:
move_to_cpu(thread_data.modelFS)
del img_data
gc()
if thread_data.device != 'cpu':
print(f'memory_final = {round(torch.cuda.memory_allocated(thread_data.device) / 1e6, 2)}Mb')
print('Task completed')
yield json.dumps(res.json())
res = res.json()
data_queue.put(json.dumps(res))
return res
def save_image(img, img_out_path):
try:
@ -664,51 +746,109 @@ def _txt2img(opt_W, opt_H, opt_n_samples, opt_ddim_steps, opt_scale, start_code,
# Send to CPU and wait until complete.
# wait_model_move_to(thread_data.modelCS, 'cpu')
move_to_cpu(thread_data.modelCS)
if not thread_data.test_sd2:
move_to_cpu(thread_data.modelCS)
if sampler_name == 'ddim':
thread_data.model.make_schedule(ddim_num_steps=opt_ddim_steps, ddim_eta=opt_ddim_eta, verbose=False)
if thread_data.test_sd2 and sampler_name not in ('plms', 'ddim'):
raise Exception('Only plms and ddim samplers are supported right now, in SD 2.0')
samples_ddim = thread_data.model.sample(
S=opt_ddim_steps,
conditioning=c,
seed=opt_seed,
shape=shape,
verbose=False,
unconditional_guidance_scale=opt_scale,
unconditional_conditioning=uc,
eta=opt_ddim_eta,
x_T=start_code,
img_callback=img_callback,
mask=mask,
sampler = sampler_name,
)
yield from samples_ddim
def _img2img(init_latent, t_enc, batch_size, opt_scale, c, uc, opt_ddim_steps, opt_ddim_eta, opt_seed, img_callback, mask):
# samples, _ = sampler.sample(S=opt.steps,
# conditioning=c,
# batch_size=opt.n_samples,
# shape=shape,
# verbose=False,
# unconditional_guidance_scale=opt.scale,
# unconditional_conditioning=uc,
# eta=opt.ddim_eta,
# x_T=start_code)
if thread_data.test_sd2:
from ldm.models.diffusion.ddim import DDIMSampler
from ldm.models.diffusion.plms import PLMSSampler
shape = [opt_C, opt_H // opt_f, opt_W // opt_f]
if sampler_name == 'plms':
sampler = PLMSSampler(thread_data.model)
elif sampler_name == 'ddim':
sampler = DDIMSampler(thread_data.model)
sampler.make_schedule(ddim_num_steps=opt_ddim_steps, ddim_eta=opt_ddim_eta, verbose=False)
samples_ddim, intermediates = sampler.sample(
S=opt_ddim_steps,
conditioning=c,
batch_size=opt_n_samples,
seed=opt_seed,
shape=shape,
verbose=False,
unconditional_guidance_scale=opt_scale,
unconditional_conditioning=uc,
eta=opt_ddim_eta,
x_T=start_code,
img_callback=img_callback,
mask=mask,
sampler = sampler_name,
)
else:
if sampler_name == 'ddim':
thread_data.model.make_schedule(ddim_num_steps=opt_ddim_steps, ddim_eta=opt_ddim_eta, verbose=False)
samples_ddim = thread_data.model.sample(
S=opt_ddim_steps,
conditioning=c,
seed=opt_seed,
shape=shape,
verbose=False,
unconditional_guidance_scale=opt_scale,
unconditional_conditioning=uc,
eta=opt_ddim_eta,
x_T=start_code,
img_callback=img_callback,
mask=mask,
sampler = sampler_name,
)
return samples_ddim
def _img2img(init_latent, t_enc, batch_size, opt_scale, c, uc, opt_ddim_steps, opt_ddim_eta, opt_seed, img_callback, mask, opt_C=1, opt_H=1, opt_W=1, opt_f=1):
# encode (scaled latent)
z_enc = thread_data.model.stochastic_encode(
init_latent,
torch.tensor([t_enc] * batch_size).to(thread_data.device),
opt_seed,
opt_ddim_eta,
opt_ddim_steps,
)
x_T = None if mask is None else init_latent
# decode it
samples_ddim = thread_data.model.sample(
t_enc,
c,
z_enc,
unconditional_guidance_scale=opt_scale,
unconditional_conditioning=uc,
img_callback=img_callback,
mask=mask,
x_T=x_T,
sampler = 'ddim'
)
yield from samples_ddim
if thread_data.test_sd2:
from ldm.models.diffusion.ddim import DDIMSampler
sampler = DDIMSampler(thread_data.model)
sampler.make_schedule(ddim_num_steps=opt_ddim_steps, ddim_eta=opt_ddim_eta, verbose=False)
z_enc = sampler.stochastic_encode(init_latent, torch.tensor([t_enc] * batch_size).to(thread_data.device))
samples_ddim = sampler.decode(z_enc, c, t_enc, unconditional_guidance_scale=opt_scale,unconditional_conditioning=uc, img_callback=img_callback)
else:
z_enc = thread_data.model.stochastic_encode(
init_latent,
torch.tensor([t_enc] * batch_size).to(thread_data.device),
opt_seed,
opt_ddim_eta,
opt_ddim_steps,
)
# decode it
samples_ddim = thread_data.model.sample(
t_enc,
c,
z_enc,
unconditional_guidance_scale=opt_scale,
unconditional_conditioning=uc,
img_callback=img_callback,
mask=mask,
x_T=x_T,
sampler = 'ddim'
)
return samples_ddim
def gc():
gc_collect()
@ -776,8 +916,16 @@ def load_mask(mask_str, h0, w0, newH, newW, invert=False):
# https://stackoverflow.com/a/61114178
def img_to_base64_str(img, output_format="PNG"):
buffered = img_to_buffer(img, output_format)
return buffer_to_base64_str(buffered, output_format)
def img_to_buffer(img, output_format="PNG"):
buffered = BytesIO()
img.save(buffered, format=output_format)
buffered.seek(0)
return buffered
def buffer_to_base64_str(buffered, output_format="PNG"):
buffered.seek(0)
img_byte = buffered.getvalue()
mime_type = "image/png" if output_format.lower() == "png" else "image/jpeg"
@ -795,3 +943,48 @@ def base64_str_to_img(img_str):
buffered = base64_str_to_buffer(img_str)
img = Image.open(buffered)
return img
def split_weighted_subprompts(text):
"""
grabs all text up to the first occurrence of ':'
uses the grabbed text as a sub-prompt, and takes the value following ':' as weight
if ':' has no value defined, defaults to 1.0
repeats until no text remaining
"""
remaining = len(text)
prompts = []
weights = []
while remaining > 0:
if ":" in text:
idx = text.index(":") # first occurrence from start
# grab up to index as sub-prompt
prompt = text[:idx]
remaining -= idx
# remove from main text
text = text[idx+1:]
# find value for weight
if " " in text:
idx = text.index(" ") # first occurence
else: # no space, read to end
idx = len(text)
if idx != 0:
try:
weight = float(text[:idx])
except: # couldn't treat as float
print(f"Warning: '{text[:idx]}' is not a value, are you missing a space?")
weight = 1.0
else: # no value found
weight = 1.0
# remove from main text
remaining -= idx
text = text[idx+1:]
# append the sub-prompt and its weight
prompts.append(prompt)
weights.append(weight)
else: # no : found
if len(text) > 0: # there is still text though
# take remainder as weight 1
prompts.append(text)
weights.append(1.0)
remaining = 0
return prompts, weights
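split_weighted_subprompts() is a local replacement for the helper that was previously imported from optimizedSD.optimUtils. Illustrative behaviour:

prompts, weights = split_weighted_subprompts('an astronaut:1.2 riding a horse')
# prompts == ['an astronaut', 'riding a horse'], weights == [1.2, 1.0]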

View File

@ -283,45 +283,26 @@ def thread_render(device):
print(f'Session {task.request.session_id} starting task {id(task)} on {runtime.thread_data.device_name}')
if not task.lock.acquire(blocking=False): raise Exception('Got locked task from queue.')
try:
if runtime.thread_data.device == 'cpu' and is_alive() > 1:
# CPU is not the only device. Keep track of active time to unload resources later.
runtime.thread_data.lastActive = time.time()
# Open data generator.
res = runtime.mk_img(task.request)
if current_model_path == task.request.use_stable_diffusion_model:
current_state = ServerStates.Rendering
else:
current_state = ServerStates.LoadingModel
# Start reading from generator.
dataQueue = None
if task.request.stream_progress_updates:
dataQueue = task.buffer_queue
for result in res:
if current_state == ServerStates.LoadingModel:
current_state = ServerStates.Rendering
current_model_path = task.request.use_stable_diffusion_model
current_vae_path = task.request.use_vae_model
if runtime.is_model_reload_necessary(task.request):
current_state = ServerStates.LoadingModel
runtime.reload_model()
current_model_path = task.request.use_stable_diffusion_model
current_vae_path = task.request.use_vae_model
def step_callback():
global current_state_error
if isinstance(current_state_error, SystemExit) or isinstance(current_state_error, StopAsyncIteration) or isinstance(task.error, StopAsyncIteration):
runtime.thread_data.stop_processing = True
if isinstance(current_state_error, StopAsyncIteration):
task.error = current_state_error
current_state_error = None
print(f'Session {task.request.session_id} sent cancel signal for task {id(task)}')
if dataQueue:
dataQueue.put(result)
if isinstance(result, str):
result = json.loads(result)
task.response = result
if 'output' in result:
for out_obj in result['output']:
if 'path' in out_obj:
img_id = out_obj['path'][out_obj['path'].rindex('/') + 1:]
task.temp_images[int(img_id)] = runtime.thread_data.temp_images[out_obj['path'][11:]]
elif 'data' in out_obj:
buf = runtime.base64_str_to_buffer(out_obj['data'])
task.temp_images[result['output'].index(out_obj)] = buf
# Before looping back to the generator, mark cache as still alive.
task_cache.keep(task.request.session_id, TASK_TTL)
task_cache.keep(task.request.session_id, TASK_TTL)
current_state = ServerStates.Rendering
task.response = runtime.mk_img(task.request, task.buffer_queue, task.temp_images, step_callback)
except Exception as e:
task.error = e
print(traceback.format_exc())

View File

@ -116,6 +116,8 @@ def setConfig(config):
bind_ip = '0.0.0.0' if config['net']['listen_to_network'] else '127.0.0.1'
config_bat.append(f"@set SD_UI_BIND_IP={bind_ip}")
config_bat.append(f"@set test_sd2={'Y' if config.get('test_sd2', False) else 'N'}")
if len(config_bat) > 0:
with open(config_bat_path, 'w', encoding='utf-8') as f:
f.write('\r\n'.join(config_bat))
@ -133,6 +135,8 @@ def setConfig(config):
bind_ip = '0.0.0.0' if config['net']['listen_to_network'] else '127.0.0.1'
config_sh.append(f"export SD_UI_BIND_IP={bind_ip}")
config_sh.append(f"export test_sd2=\"{'Y' if config.get('test_sd2', False) else 'N'}\"")
if len(config_sh) > 1:
with open(config_sh_path, 'w', encoding='utf-8') as f:
f.write('\n'.join(config_sh))
@ -140,12 +144,19 @@ def setConfig(config):
print(traceback.format_exc())
def resolve_model_to_use(model_name:str, model_type:str, model_dir:str, model_extensions:list, default_models=[]):
config = getConfig()
model_dirs = [os.path.join(MODELS_DIR, model_dir), SD_DIR]
if not model_name: # When None try user configured model.
config = getConfig()
# config = getConfig()
if 'model' in config and model_type in config['model']:
model_name = config['model'][model_type]
if model_name:
is_sd2 = config.get('test_sd2', False)
if model_name.startswith('sd2_') and not is_sd2: # temp hack, until SD2 is unified with 1.4
print('ERROR: Cannot use SD 2.0 models with SD 1.0 code. Using the sd-v1-4 model instead!')
model_name = 'sd-v1-4'
# Check models directory
models_dir_path = os.path.join(MODELS_DIR, model_dir, model_name)
for model_extension in model_extensions:
@ -188,6 +199,7 @@ class SetAppConfigRequest(BaseModel):
ui_open_browser_on_start: bool = None
listen_to_network: bool = None
listen_port: int = None
test_sd2: bool = None
@app.post('/app_config')
async def setAppConfig(req : SetAppConfigRequest):
@ -208,6 +220,8 @@ async def setAppConfig(req : SetAppConfigRequest):
if 'net' not in config:
config['net'] = {}
config['net']['listen_port'] = int(req.listen_port)
if req.test_sd2 is not None:
config['test_sd2'] = req.test_sd2
try:
setConfig(config)
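The new test_sd2 flag flows end to end: the UI checkbox posts it to /app_config, setAppConfig() stores it in config.json, setConfig() mirrors it into config.bat / config.sh as the test_sd2 environment variable, and the installer plus runtime read it on the next start. An illustrative request against a locally running UI (Python; the default port 9000 is assumed):

import requests  # illustrative client; the web UI sends the same payload from the browser

requests.post('http://localhost:9000/app_config', json={'test_sd2': True})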