diff --git a/CHANGES.md b/CHANGES.md
index db274ede..baafcb18 100644
--- a/CHANGES.md
+++ b/CHANGES.md
@@ -21,6 +21,7 @@
- A `What's New?` tab in the UI
### Detailed changelog
+* 2.4.15 - 25 Nov 2022 - Experimental support for SD 2.0. Uses a lot of memory, is not yet optimized, and is probably GPU-only.
* 2.4.14 - 22 Nov 2022 - Change the backend to a custom fork of Stable Diffusion
* 2.4.13 - 21 Nov 2022 - Change the modifier weight via mouse wheel, drag to reorder selected modifiers, and some more modifier-related fixes. Thanks @patriceac
* 2.4.12 - 21 Nov 2022 - Another fix for improving how long images take to generate. Reduces the time taken for an enqueued task to start processing.
diff --git a/scripts/on_sd_start.bat b/scripts/on_sd_start.bat
index 6e4ffd36..8b2e3e02 100644
--- a/scripts/on_sd_start.bat
+++ b/scripts/on_sd_start.bat
@@ -27,6 +27,8 @@ if exist "Open Developer Console.cmd" del "Open Developer Console.cmd"
@call python -c "import os; import shutil; frm = 'sd-ui-files\\ui\\hotfix\\9c24e6cd9f499d02c4f21a033736dabd365962dc80fe3aeb57a8f85ea45a20a3.26fead7ea4f0f843f6eb4055dfd25693f1a71f3c6871b184042d4b126244e142'; dst = os.path.join(os.path.expanduser('~'), '.cache', 'huggingface', 'transformers', '9c24e6cd9f499d02c4f21a033736dabd365962dc80fe3aeb57a8f85ea45a20a3.26fead7ea4f0f843f6eb4055dfd25693f1a71f3c6871b184042d4b126244e142'); shutil.copyfile(frm, dst) if os.path.exists(dst) else print(''); print('Hotfixed broken JSON file from OpenAI');"
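+@rem Default the SD 2.0 test flag to N (off) when config.bat didn't set it.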
+if NOT DEFINED test_sd2 set test_sd2=N
+
@>nul findstr /m "sd_git_cloned" scripts\install_status.txt
@if "%ERRORLEVEL%" EQU "0" (
@echo "Stable Diffusion's git repository was already installed. Updating.."
@@ -37,9 +39,17 @@ if exist "Open Developer Console.cmd" del "Open Developer Console.cmd"
@call git reset --hard
@call git pull
- @call git -c advice.detachedHead=false checkout 7f32368ed1030a6e710537047bacd908adea183a
- @call git apply --whitespace=warn ..\ui\sd_internal\ddim_callback.patch
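+ @rem SD 1.x and SD 2.0 pin different Stable Diffusion commits and need different ddim_callback patches.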
+ if "%test_sd2%" == "N" (
+ @call git -c advice.detachedHead=false checkout 7f32368ed1030a6e710537047bacd908adea183a
+
+ @call git apply --whitespace=warn ..\ui\sd_internal\ddim_callback.patch
+ )
+ if "%test_sd2%" == "Y" (
+ @call git -c advice.detachedHead=false checkout 7da785cfa0d128368bc1357b54d380ba33dc1138
+
+ @call git apply --whitespace=warn ..\ui\sd_internal\ddim_callback_sd2.patch
+ )
@cd ..
) else (
@@ -346,7 +356,9 @@ echo. > "..\models\vae\Put your VAE files here.txt"
)
)
-
+if "%test_sd2%" == "Y" (
+ @call pip install open_clip_torch==2.0.2
+)
@>nul findstr /m "sd_install_complete" ..\scripts\install_status.txt
@if "%ERRORLEVEL%" NEQ "0" (
diff --git a/scripts/on_sd_start.sh b/scripts/on_sd_start.sh
index ff6a04d4..133dac3e 100644
--- a/scripts/on_sd_start.sh
+++ b/scripts/on_sd_start.sh
@@ -21,6 +21,10 @@ python -c "import os; import shutil; frm = 'sd-ui-files/ui/hotfix/9c24e6cd9f499d
# Caution, this file will make your eyes and brain bleed. It's such an unholy mess.
# Note to self: Please rewrite this in Python. For the sake of your own sanity.
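+# Default the SD 2.0 test flag to N (off) when config.sh didn't set it.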
+if [ "$test_sd2" == "" ]; then
+ export test_sd2="N"
+fi
+
if [ -e "scripts/install_status.txt" ] && [ `grep -c sd_git_cloned scripts/install_status.txt` -gt "0" ]; then
echo "Stable Diffusion's git repository was already installed. Updating.."
@@ -30,9 +34,16 @@ if [ -e "scripts/install_status.txt" ] && [ `grep -c sd_git_cloned scripts/insta
git reset --hard
git pull
- git -c advice.detachedHead=false checkout 7f32368ed1030a6e710537047bacd908adea183a
- git apply --whitespace=warn ../ui/sd_internal/ddim_callback.patch || fail "ddim patch failed"
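+ # SD 1.x and SD 2.0 pin different Stable Diffusion commits and need different ddim_callback patches.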
+ if [ "$test_sd2" == "N" ]; then
+ git -c advice.detachedHead=false checkout 7f32368ed1030a6e710537047bacd908adea183a
+
+ git apply --whitespace=warn ../ui/sd_internal/ddim_callback.patch || fail "ddim patch failed"
+ elif [ "$test_sd2" == "Y" ]; then
+ git -c advice.detachedHead=false checkout 7da785cfa0d128368bc1357b54d380ba33dc1138
+
+ git apply --whitespace=warn ../ui/sd_internal/ddim_callback_sd2.patch || fail "sd2 ddim patch failed"
+ fi
cd ..
else
@@ -291,6 +302,9 @@ if [ ! -f "../models/vae/vae-ft-mse-840000-ema-pruned.ckpt" ]; then
fi
fi
+if [ "$test_sd2" == "Y" ]; then
+ pip install open_clip_torch==2.0.2
+fi
if [ `grep -c sd_install_complete ../scripts/install_status.txt` -gt "0" ]; then
echo sd_weights_downloaded >> ../scripts/install_status.txt
diff --git a/ui/index.html b/ui/index.html
index 3e85b254..edbccf97 100644
--- a/ui/index.html
+++ b/ui/index.html
@@ -22,7 +22,7 @@
Stable Diffusion UI
- v2.4.14
+ v2.4.15
diff --git a/ui/media/js/parameters.js b/ui/media/js/parameters.js
index 2e5bc75c..dac008ef 100644
--- a/ui/media/js/parameters.js
+++ b/ui/media/js/parameters.js
@@ -132,6 +132,14 @@ var PARAMETERS = [
return ``
}
},
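+ // Experimental SD 2.0 toggle. Persisted to config.json and exported as the test_sd2 env var at startup.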
+ {
+ id: "test_sd2",
+ type: ParameterType.checkbox,
+ label: "Test SD 2.0",
+ note: "Experimental! High memory usage! GPU-only! Please restart the program after changing this.",
+ icon: "fa-fire",
+ default: false,
+ },
{
id: "use_beta_channel",
type: ParameterType.checkbox,
@@ -196,6 +204,7 @@ let saveToDiskField = document.querySelector('#save_to_disk')
let diskPathField = document.querySelector('#diskPath')
let listenToNetworkField = document.querySelector("#listen_to_network")
let listenPortField = document.querySelector("#listen_port")
+let testSD2Field = document.querySelector("#test_sd2")
let useBetaChannelField = document.querySelector("#use_beta_channel")
let uiOpenBrowserOnStartField = document.querySelector("#ui_open_browser_on_start")
@@ -230,6 +239,9 @@ async function getAppConfig() {
if (config.ui && config.ui.open_browser_on_start === false) {
uiOpenBrowserOnStartField.checked = false
}
+ if ('test_sd2' in config) {
+ testSD2Field.checked = config['test_sd2']
+ }
if (config.net && config.net.listen_to_network === false) {
listenToNetworkField.checked = false
}
@@ -372,7 +384,8 @@ saveSettingsBtn.addEventListener('click', function() {
'update_branch': updateBranch,
'ui_open_browser_on_start': uiOpenBrowserOnStartField.checked,
'listen_to_network': listenToNetworkField.checked,
- 'listen_port': listenPortField.value
+ 'listen_port': listenPortField.value,
+ 'test_sd2': testSD2Field.checked
})
}
diff --git a/ui/sd_internal/ddim_callback_sd2.patch b/ui/sd_internal/ddim_callback_sd2.patch
new file mode 100644
index 00000000..00700d00
--- /dev/null
+++ b/ui/sd_internal/ddim_callback_sd2.patch
@@ -0,0 +1,20 @@
+diff --git a/optimizedSD/ddpm.py b/optimizedSD/ddpm.py
+index 1bbdd02..cd00cc3 100644
+--- a/optimizedSD/ddpm.py
++++ b/optimizedSD/ddpm.py
+@@ -348,6 +348,7 @@ class DDPM(pl.LightningModule):
+ def sample(self, batch_size=16, return_intermediates=False):
+ image_size = self.image_size
+ channels = self.channels
++ print('sampler 2')
+ return self.p_sample_loop((batch_size, channels, image_size, image_size),
+ return_intermediates=return_intermediates)
+
+@@ -1090,6 +1091,7 @@ class LatentDiffusion(DDPM):
+ def sample(self, cond, batch_size=16, return_intermediates=False, x_T=None,
+ verbose=True, timesteps=None, quantize_denoised=False,
+ mask=None, x0=None, shape=None, **kwargs):
++ print('sampler 1')
+ if shape is None:
+ shape = (batch_size, self.channels, self.image_size, self.image_size)
+ if cond is not None:
diff --git a/ui/sd_internal/runtime.py b/ui/sd_internal/runtime.py
index e217965d..d9521780 100644
--- a/ui/sd_internal/runtime.py
+++ b/ui/sd_internal/runtime.py
@@ -76,8 +76,24 @@ def thread_init(device):
thread_data.force_full_precision = False
thread_data.reduced_memory = True
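+ # Cache the SD 2.0 test flag for this thread; isSD2() below reads it from the UI's config.json.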
+ thread_data.test_sd2 = isSD2()
+
device_manager.device_init(thread_data, device)
+# Temporary hack, to be removed soon: detect the SD 2.0 test flag by reading the UI's config.json directly.
+def isSD2():
+ try:
+ SD_UI_DIR = os.getenv('SD_UI_PATH', None)
+ CONFIG_DIR = os.path.abspath(os.path.join(SD_UI_DIR, '..', 'scripts'))
+ config_json_path = os.path.join(CONFIG_DIR, 'config.json')
+ if not os.path.exists(config_json_path):
+ return False
+ with open(config_json_path, 'r', encoding='utf-8') as f:
+ config = json.load(f)
+ return config.get('test_sd2', False)
+ except Exception:
+ return False
+
def load_model_ckpt():
if not thread_data.ckpt_file: raise ValueError(f'Thread ckpt_file is undefined.')
if not os.path.exists(thread_data.ckpt_file + '.ckpt'): raise FileNotFoundError(f'Cannot find {thread_data.ckpt_file}.ckpt')
@@ -92,6 +108,13 @@ def load_model_ckpt():
thread_data.precision = 'full'
print('loading', thread_data.ckpt_file + '.ckpt', 'to device', thread_data.device, 'using precision', thread_data.precision)
+
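+ # Dispatch to the loader matching the configured model generation.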
+ if thread_data.test_sd2:
+ load_model_ckpt_sd2()
+ else:
+ load_model_ckpt_sd1()
+
+def load_model_ckpt_sd1():
sd = load_model_from_config(thread_data.ckpt_file + '.ckpt')
li, lo = [], []
for key, value in sd.items():
@@ -185,6 +208,37 @@ def load_model_ckpt():
modelFS.device: {thread_data.modelFS.device}
using precision: {thread_data.precision}''')
+def load_model_ckpt_sd2():
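+ # SD 2.0 loads as a single LatentDiffusion model from the v-prediction config; there is no modelCS/modelFS split as in the SD 1.x path.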
+ config = OmegaConf.load('configs/stable-diffusion/v2-inference-v.yaml')
+ verbose = False
+
+ sd = load_model_from_config(thread_data.ckpt_file + '.ckpt')
+
+ thread_data.model = instantiate_from_config(config.model)
+ m, u = thread_data.model.load_state_dict(sd, strict=False)
+ if len(m) > 0 and verbose:
+ print("missing keys:")
+ print(m)
+ if len(u) > 0 and verbose:
+ print("unexpected keys:")
+ print(u)
+
+ thread_data.model.to(thread_data.device)
+ thread_data.model.eval()
+ del sd
+
+ if thread_data.device != "cpu" and thread_data.precision == "autocast":
+ thread_data.model.half()
+ thread_data.model_is_half = True
+ thread_data.model_fs_is_half = True
+ else:
+ thread_data.model_is_half = False
+ thread_data.model_fs_is_half = False
+
+ print(f'''loaded model
+ model file: {thread_data.ckpt_file}.ckpt
+ using precision: {thread_data.precision}''')
+
def unload_filters():
if thread_data.model_gfpgan is not None:
if thread_data.device != 'cpu': thread_data.model_gfpgan.gfpgan.to('cpu')
@@ -204,10 +258,11 @@ def unload_models():
if thread_data.model is not None:
print('Unloading models...')
if thread_data.device != 'cpu':
- thread_data.modelFS.to('cpu')
- thread_data.modelCS.to('cpu')
- thread_data.model.model1.to("cpu")
- thread_data.model.model2.to("cpu")
+ if not thread_data.test_sd2:
+ thread_data.modelFS.to('cpu')
+ thread_data.modelCS.to('cpu')
+ thread_data.model.model1.to("cpu")
+ thread_data.model.model2.to("cpu")
del thread_data.model
del thread_data.modelCS
@@ -343,7 +398,7 @@ def mk_img(req: Request):
except Exception as e:
print(traceback.format_exc())
- if thread_data.device != 'cpu':
+ if thread_data.device != 'cpu' and not thread_data.test_sd2:
thread_data.modelFS.to('cpu')
thread_data.modelCS.to('cpu')
thread_data.model.model1.to("cpu")
@@ -358,7 +413,10 @@ def mk_img(req: Request):
def update_temp_img(req, x_samples):
partial_images = []
for i in range(req.num_outputs):
- x_sample_ddim = thread_data.modelFS.decode_first_stage(x_samples[i].unsqueeze(0))
+ if thread_data.test_sd2:
+ x_sample_ddim = thread_data.model.decode_first_stage(x_samples[i].unsqueeze(0))
+ else:
+ x_sample_ddim = thread_data.modelFS.decode_first_stage(x_samples[i].unsqueeze(0))
x_sample = torch.clamp((x_sample_ddim + 1.0) / 2.0, min=0.0, max=1.0)
x_sample = 255.0 * rearrange(x_sample[0].cpu().numpy(), "c h w -> h w c")
x_sample = x_sample.astype(np.uint8)
@@ -433,7 +491,7 @@ def do_mk_img(req: Request):
unload_filters()
load_model_ckpt()
- if thread_data.turbo != req.turbo:
+ if thread_data.turbo != req.turbo and not thread_data.test_sd2:
thread_data.turbo = req.turbo
thread_data.model.turbo = req.turbo
@@ -478,10 +536,14 @@ def do_mk_img(req: Request):
if thread_data.device != "cpu" and thread_data.precision == "autocast":
init_image = init_image.half()
- thread_data.modelFS.to(thread_data.device)
+ if not thread_data.test_sd2:
+ thread_data.modelFS.to(thread_data.device)
init_image = repeat(init_image, '1 ... -> b ...', b=batch_size)
- init_latent = thread_data.modelFS.get_first_stage_encoding(thread_data.modelFS.encode_first_stage(init_image)) # move to latent space
+ if thread_data.test_sd2:
+ init_latent = thread_data.model.get_first_stage_encoding(thread_data.model.encode_first_stage(init_image)) # move to latent space
+ else:
+ init_latent = thread_data.modelFS.get_first_stage_encoding(thread_data.modelFS.encode_first_stage(init_image)) # move to latent space
if req.mask is not None:
mask = load_mask(req.mask, req.width, req.height, init_latent.shape[2], init_latent.shape[3], True).to(thread_data.device)
@@ -493,7 +555,8 @@ def do_mk_img(req: Request):
# Send to CPU and wait until complete.
# wait_model_move_to(thread_data.modelFS, 'cpu')
- move_to_cpu(thread_data.modelFS)
+ if not thread_data.test_sd2:
+ move_to_cpu(thread_data.modelFS)
assert 0. <= req.prompt_strength <= 1., 'can only work with strength in [0.0, 1.0]'
t_enc = int(req.prompt_strength * req.num_inference_steps)
@@ -509,11 +572,14 @@ def do_mk_img(req: Request):
for prompts in tqdm(data, desc="data"):
with precision_scope("cuda"):
- if thread_data.reduced_memory:
+ if thread_data.reduced_memory and not thread_data.test_sd2:
thread_data.modelCS.to(thread_data.device)
uc = None
if req.guidance_scale != 1.0:
- uc = thread_data.modelCS.get_learned_conditioning(batch_size * [req.negative_prompt])
+ if thread_data.test_sd2:
+ uc = thread_data.model.get_learned_conditioning(batch_size * [req.negative_prompt])
+ else:
+ uc = thread_data.modelCS.get_learned_conditioning(batch_size * [req.negative_prompt])
if isinstance(prompts, tuple):
prompts = list(prompts)
@@ -526,11 +592,17 @@ def do_mk_img(req: Request):
weight = weights[i]
# if not skip_normalize:
weight = weight / totalWeight
- c = torch.add(c, thread_data.modelCS.get_learned_conditioning(subprompts[i]), alpha=weight)
+ if thread_data.test_sd2:
+ c = torch.add(c, thread_data.model.get_learned_conditioning(subprompts[i]), alpha=weight)
+ else:
+ c = torch.add(c, thread_data.modelCS.get_learned_conditioning(subprompts[i]), alpha=weight)
else:
- c = thread_data.modelCS.get_learned_conditioning(prompts)
+ if thread_data.test_sd2:
+ c = thread_data.model.get_learned_conditioning(prompts)
+ else:
+ c = thread_data.modelCS.get_learned_conditioning(prompts)
- if thread_data.reduced_memory:
+ if thread_data.reduced_memory and not thread_data.test_sd2:
thread_data.modelFS.to(thread_data.device)
n_steps = req.num_inference_steps if req.init_image is None else t_enc
@@ -562,7 +634,10 @@ def do_mk_img(req: Request):
print("decoding images")
img_data = [None] * batch_size
for i in range(batch_size):
- x_samples_ddim = thread_data.modelFS.decode_first_stage(x_samples[i].unsqueeze(0))
+ if thread_data.test_sd2:
+ x_samples_ddim = thread_data.model.decode_first_stage(x_samples[i].unsqueeze(0))
+ else:
+ x_samples_ddim = thread_data.modelFS.decode_first_stage(x_samples[i].unsqueeze(0))
x_sample = torch.clamp((x_samples_ddim + 1.0) / 2.0, min=0.0, max=1.0)
x_sample = 255.0 * rearrange(x_sample[0].cpu().numpy(), "c h w -> h w c")
x_sample = x_sample.astype(np.uint8)
@@ -622,7 +697,8 @@ def do_mk_img(req: Request):
# if thread_data.reduced_memory:
# unload_filters()
- move_to_cpu(thread_data.modelFS)
+ if not thread_data.test_sd2:
+ move_to_cpu(thread_data.modelFS)
del img_data
gc()
if thread_data.device != 'cpu':
@@ -664,7 +740,11 @@ def _txt2img(opt_W, opt_H, opt_n_samples, opt_ddim_steps, opt_scale, start_code,
# Send to CPU and wait until complete.
# wait_model_move_to(thread_data.modelCS, 'cpu')
- move_to_cpu(thread_data.modelCS)
+ if not thread_data.test_sd2:
+ move_to_cpu(thread_data.modelCS)
+
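+ # Only the plms and ddim code paths have been wired up for the SD 2.0 model so far.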
+ if thread_data.test_sd2 and sampler_name not in ('plms', 'ddim'):
+ raise Exception('Only the plms and ddim samplers are currently supported with SD 2.0')
if sampler_name == 'ddim':
thread_data.model.make_schedule(ddim_num_steps=opt_ddim_steps, ddim_eta=opt_ddim_eta, verbose=False)
diff --git a/ui/server.py b/ui/server.py
index 61635f18..8e5343db 100644
--- a/ui/server.py
+++ b/ui/server.py
@@ -116,6 +116,8 @@ def setConfig(config):
bind_ip = '0.0.0.0' if config['net']['listen_to_network'] else '127.0.0.1'
config_bat.append(f"@set SD_UI_BIND_IP={bind_ip}")
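+ # Export the flag so on_sd_start.bat can check out the matching Stable Diffusion commit and patch.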
+ config_bat.append(f"@set test_sd2={'Y' if config.get('test_sd2', False) else 'N'}")
+
if len(config_bat) > 0:
with open(config_bat_path, 'w', encoding='utf-8') as f:
f.write('\r\n'.join(config_bat))
@@ -133,6 +135,8 @@ def setConfig(config):
bind_ip = '0.0.0.0' if config['net']['listen_to_network'] else '127.0.0.1'
config_sh.append(f"export SD_UI_BIND_IP={bind_ip}")
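+ # Export the same flag for the shell startup script, on_sd_start.sh.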
+ config_sh.append(f"export test_sd2=\"{'Y' if config.get('test_sd2', False) else 'N'}\"")
+
if len(config_sh) > 1:
with open(config_sh_path, 'w', encoding='utf-8') as f:
f.write('\n'.join(config_sh))