Placeholder changes for SD 2.0 support, haven't tested yet

cmdr2 2022-11-25 12:17:44 +05:30
parent b70235ff92
commit ea7b28c9d5
8 changed files with 169 additions and 25 deletions

View File

@@ -21,6 +21,7 @@
- A `What's New?` tab in the UI
### Detailed changelog
+* 2.4.15 - 25 Nov 2022 - Experimental support for SD 2.0. Uses lots of memory, not optimized, probably GPU-only.
* 2.4.14 - 22 Nov 2022 - Change the backend to a custom fork of Stable Diffusion
* 2.4.13 - 21 Nov 2022 - Change the modifier weight via mouse wheel, drag to reorder selected modifiers, and some more modifier-related fixes. Thanks @patriceac
* 2.4.12 - 21 Nov 2022 - Another fix for improving how long images take to generate. Reduces the time taken for an enqueued task to start processing.

View File

@@ -27,6 +27,8 @@ if exist "Open Developer Console.cmd" del "Open Developer Console.cmd"
@call python -c "import os; import shutil; frm = 'sd-ui-files\\ui\\hotfix\\9c24e6cd9f499d02c4f21a033736dabd365962dc80fe3aeb57a8f85ea45a20a3.26fead7ea4f0f843f6eb4055dfd25693f1a71f3c6871b184042d4b126244e142'; dst = os.path.join(os.path.expanduser('~'), '.cache', 'huggingface', 'transformers', '9c24e6cd9f499d02c4f21a033736dabd365962dc80fe3aeb57a8f85ea45a20a3.26fead7ea4f0f843f6eb4055dfd25693f1a71f3c6871b184042d4b126244e142'); shutil.copyfile(frm, dst) if os.path.exists(dst) else print(''); print('Hotfixed broken JSON file from OpenAI');"
+if NOT DEFINED test_sd2 set test_sd2=N
@>nul findstr /m "sd_git_cloned" scripts\install_status.txt
@if "%ERRORLEVEL%" EQU "0" (
@echo "Stable Diffusion's git repository was already installed. Updating.."
@@ -37,9 +39,17 @@ if exist "Open Developer Console.cmd" del "Open Developer Console.cmd"
@call git reset --hard
@call git pull
-@call git -c advice.detachedHead=false checkout 7f32368ed1030a6e710537047bacd908adea183a
-@call git apply --whitespace=warn ..\ui\sd_internal\ddim_callback.patch
+if "%test_sd2%" == "N" (
+    @call git -c advice.detachedHead=false checkout 7f32368ed1030a6e710537047bacd908adea183a
+    @call git apply --whitespace=warn ..\ui\sd_internal\ddim_callback.patch
+)
+if "%test_sd2%" == "Y" (
+    @call git -c advice.detachedHead=false checkout 7da785cfa0d128368bc1357b54d380ba33dc1138
+    @call git apply --whitespace=warn ..\ui\sd_internal\ddim_callback_sd2.patch
+)
@cd ..
) else (
@@ -346,7 +356,9 @@ echo. > "..\models\vae\Put your VAE files here.txt"
)
)
if "%test_sd2%" == "Y" (
@call pip install open_clip_torch==2.0.2
)
@>nul findstr /m "sd_install_complete" ..\scripts\install_status.txt
@if "%ERRORLEVEL%" NEQ "0" (

View File

@@ -21,6 +21,10 @@ python -c "import os; import shutil; frm = 'sd-ui-files/ui/hotfix/9c24e6cd9f499d
# Caution, this file will make your eyes and brain bleed. It's such an unholy mess.
# Note to self: Please rewrite this in Python. For the sake of your own sanity.
if [ "$test_sd2" == "" ]; then
export test_sd2="N"
fi
if [ -e "scripts/install_status.txt" ] && [ `grep -c sd_git_cloned scripts/install_status.txt` -gt "0" ]; then
echo "Stable Diffusion's git repository was already installed. Updating.."
@@ -30,9 +34,16 @@ if [ -e "scripts/install_status.txt" ] && [ `grep -c sd_git_cloned scripts/insta
git reset --hard
git pull
-git -c advice.detachedHead=false checkout 7f32368ed1030a6e710537047bacd908adea183a
-git apply --whitespace=warn ../ui/sd_internal/ddim_callback.patch || fail "ddim patch failed"
+if [ "$test_sd2" == "N" ]; then
+    git -c advice.detachedHead=false checkout 7f32368ed1030a6e710537047bacd908adea183a
+    git apply --whitespace=warn ../ui/sd_internal/ddim_callback.patch || fail "ddim patch failed"
+elif [ "$test_sd2" == "Y" ]; then
+    git -c advice.detachedHead=false checkout 7da785cfa0d128368bc1357b54d380ba33dc1138
+    git apply --whitespace=warn ../ui/sd_internal/ddim_callback_sd2.patch || fail "sd2 ddim patch failed"
+fi
cd ..
else
@@ -291,6 +302,9 @@ if [ ! -f "../models/vae/vae-ft-mse-840000-ema-pruned.ckpt" ]; then
fi
fi
if [ "$test_sd2" == "Y" ]; then
pip install open_clip_torch==2.0.2
fi
if [ `grep -c sd_install_complete ../scripts/install_status.txt` -gt "0" ]; then
echo sd_weights_downloaded >> ../scripts/install_status.txt

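The two installer scripts above read the same `test_sd2` flag from the environment: when it is `Y`, they pin the custom Stable Diffusion fork to a different commit, apply `ddim_callback_sd2.patch` instead of `ddim_callback.patch`, and install `open_clip_torch`. That extra dependency is there because SD 2.0 swapped OpenAI's CLIP text encoder for OpenCLIP. A minimal sketch of what the package provides (the prompt string is illustrative):

```python
# Minimal sketch: SD 2.0 prompts are tokenized/encoded with OpenCLIP rather
# than OpenAI CLIP, which is why the installers add open_clip_torch.
import open_clip

tokens = open_clip.tokenize(["an astronaut riding a horse"])
print(tokens.shape)  # torch.Size([1, 77]) - CLIP-style fixed context length
```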
View File

@@ -22,7 +22,7 @@
<div id="logo">
<h1>
Stable Diffusion UI
-<small>v2.4.14 <span id="updateBranchLabel"></span></small>
+<small>v2.4.15 <span id="updateBranchLabel"></span></small>
</h1>
</div>
<div id="server-status">

View File

@@ -132,6 +132,14 @@ var PARAMETERS = [
return `<input id="${parameter.id}" name="${parameter.id}" size="6" value="9000" onkeypress="preventNonNumericalInput(event)">`
}
},
+    {
+        id: "test_sd2",
+        type: ParameterType.checkbox,
+        label: "Test SD 2.0",
+        note: "Experimental! High memory usage! GPU-only! Please restart the program after changing this.",
+        icon: "fa-fire",
+        default: false,
+    },
{
id: "use_beta_channel",
type: ParameterType.checkbox,
@@ -196,6 +204,7 @@ let saveToDiskField = document.querySelector('#save_to_disk')
let diskPathField = document.querySelector('#diskPath')
let listenToNetworkField = document.querySelector("#listen_to_network")
let listenPortField = document.querySelector("#listen_port")
+let testSD2Field = document.querySelector("#test_sd2")
let useBetaChannelField = document.querySelector("#use_beta_channel")
let uiOpenBrowserOnStartField = document.querySelector("#ui_open_browser_on_start")
@@ -230,6 +239,9 @@ async function getAppConfig() {
if (config.ui && config.ui.open_browser_on_start === false) {
uiOpenBrowserOnStartField.checked = false
}
+    if ('test_sd2' in config) {
+        testSD2Field.checked = config['test_sd2']
+    }
if (config.net && config.net.listen_to_network === false) {
listenToNetworkField.checked = false
}
@@ -372,7 +384,8 @@ saveSettingsBtn.addEventListener('click', function() {
'update_branch': updateBranch,
'ui_open_browser_on_start': uiOpenBrowserOnStartField.checked,
'listen_to_network': listenToNetworkField.checked,
-        'listen_port': listenPortField.value
+        'listen_port': listenPortField.value,
+        'test_sd2': testSD2Field.checked
})
}

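The round trip for the new checkbox is contained in the hunks above: `getAppConfig()` seeds the checkbox from the server's config, and the save handler posts `test_sd2` back with the other settings. A small sketch of the persistence behavior, assuming the server stores settings in the same `config.json` that `isSD2()` reads in the runtime hunks further down:

```python
# Sketch: the flag defaults to False when absent, matching the checkbox
# default and the config.get('test_sd2', False) fallback in the runtime.
import json

saved = json.loads('{"update_branch": "main"}')  # config without the new key
print(saved.get('test_sd2', False))              # -> False, feature stays off

saved['test_sd2'] = True                         # what the settings POST stores
print(json.dumps(saved))
```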
View File

@@ -0,0 +1,20 @@
diff --git a/optimizedSD/ddpm.py b/optimizedSD/ddpm.py
index 1bbdd02..cd00cc3 100644
--- a/optimizedSD/ddpm.py
+++ b/optimizedSD/ddpm.py
@@ -348,6 +348,7 @@ class DDPM(pl.LightningModule):
     def sample(self, batch_size=16, return_intermediates=False):
         image_size = self.image_size
         channels = self.channels
+        print('sampler 2')
         return self.p_sample_loop((batch_size, channels, image_size, image_size),
                                   return_intermediates=return_intermediates)
@@ -1090,6 +1091,7 @@ class LatentDiffusion(DDPM):
     def sample(self, cond, batch_size=16, return_intermediates=False, x_T=None,
                verbose=True, timesteps=None, quantize_denoised=False,
                mask=None, x0=None, shape=None, **kwargs):
+        print('sampler 1')
         if shape is None:
             shape = (batch_size, self.channels, self.image_size, self.image_size)
         if cond is not None:

View File

@@ -76,8 +76,24 @@ def thread_init(device):
    thread_data.force_full_precision = False
    thread_data.reduced_memory = True
+    thread_data.test_sd2 = isSD2()

    device_manager.device_init(thread_data, device)

+# temp hack, will remove soon
+def isSD2():
+    try:
+        SD_UI_DIR = os.getenv('SD_UI_PATH', None)
+        CONFIG_DIR = os.path.abspath(os.path.join(SD_UI_DIR, '..', 'scripts'))
+        config_json_path = os.path.join(CONFIG_DIR, 'config.json')
+        if not os.path.exists(config_json_path):
+            return False
+        with open(config_json_path, 'r', encoding='utf-8') as f:
+            config = json.load(f)
+        return config.get('test_sd2', False)
+    except Exception as e:
+        return False

def load_model_ckpt():
    if not thread_data.ckpt_file: raise ValueError(f'Thread ckpt_file is undefined.')
    if not os.path.exists(thread_data.ckpt_file + '.ckpt'): raise FileNotFoundError(f'Cannot find {thread_data.ckpt_file}.ckpt')
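One subtlety in `isSD2()`: when `SD_UI_PATH` is unset, `os.getenv` returns `None` and `os.path.join(None, '..', 'scripts')` raises `TypeError`, so the broad `except` is what turns a missing environment variable into a plain `False`. A standalone demonstration of that fallback path (not part of the commit):

```python
# Why the try/except in isSD2() also covers an unset SD_UI_PATH.
import os

SD_UI_DIR = os.getenv('SOME_UNSET_VARIABLE', None)  # -> None
try:
    os.path.abspath(os.path.join(SD_UI_DIR, '..', 'scripts'))
except TypeError as e:
    print('isSD2() would return False here:', e)
```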
@@ -92,6 +108,13 @@ def load_model_ckpt():
        thread_data.precision = 'full'

    print('loading', thread_data.ckpt_file + '.ckpt', 'to device', thread_data.device, 'using precision', thread_data.precision)

+    if thread_data.test_sd2:
+        load_model_ckpt_sd2()
+    else:
+        load_model_ckpt_sd1()
+
+def load_model_ckpt_sd1():
    sd = load_model_from_config(thread_data.ckpt_file + '.ckpt')
    li, lo = [], []
    for key, value in sd.items():
@@ -185,6 +208,37 @@
modelFS.device: {thread_data.modelFS.device}
using precision: {thread_data.precision}''')

+def load_model_ckpt_sd2():
+    config = OmegaConf.load('configs/stable-diffusion/v2-inference-v.yaml')
+    verbose = False
+
+    sd = load_model_from_config(thread_data.ckpt_file + '.ckpt')
+    thread_data.model = instantiate_from_config(config.model)
+    m, u = thread_data.model.load_state_dict(sd, strict=False)
+    if len(m) > 0 and verbose:
+        print("missing keys:")
+        print(m)
+    if len(u) > 0 and verbose:
+        print("unexpected keys:")
+        print(u)
+
+    thread_data.model.to(thread_data.device)
+    thread_data.model.eval()
+    del sd
+
+    if thread_data.device != "cpu" and thread_data.precision == "autocast":
+        thread_data.model.half()
+        thread_data.model_is_half = True
+        thread_data.model_fs_is_half = True
+    else:
+        thread_data.model_is_half = False
+        thread_data.model_fs_is_half = False
+
+    print(f'''loaded model
+model file: {thread_data.ckpt_file}.ckpt
+using precision: {thread_data.precision}''')

def unload_filters():
    if thread_data.model_gfpgan is not None:
        if thread_data.device != 'cpu': thread_data.model_gfpgan.gfpgan.to('cpu')
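`load_model_ckpt_sd2()` follows the standard `ldm`-style loading recipe: build the model from the v2 inference YAML via `instantiate_from_config`, then load the checkpoint weights with `strict=False` so missing or renamed keys are reported rather than fatal. A hedged standalone sketch of the same pattern (`instantiate_from_config` comes from the Stable Diffusion repo's `ldm.util`; the paths are illustrative):

```python
# Sketch of the SD 2.0 load pattern used above, under those assumptions.
import torch
from omegaconf import OmegaConf
from ldm.util import instantiate_from_config  # provided by the SD fork

def load_sd2(config_path, ckpt_path, device='cuda'):
    config = OmegaConf.load(config_path)            # e.g. v2-inference-v.yaml
    state = torch.load(ckpt_path, map_location='cpu')
    sd = state.get('state_dict', state)             # checkpoints may nest weights
    model = instantiate_from_config(config.model)   # builds the LatentDiffusion
    missing, unexpected = model.load_state_dict(sd, strict=False)
    return model.to(device).eval()
```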
@@ -204,10 +258,11 @@ def unload_models():
    if thread_data.model is not None:
        print('Unloading models...')
        if thread_data.device != 'cpu':
-            thread_data.modelFS.to('cpu')
-            thread_data.modelCS.to('cpu')
-            thread_data.model.model1.to("cpu")
-            thread_data.model.model2.to("cpu")
+            if not thread_data.test_sd2:
+                thread_data.modelFS.to('cpu')
+                thread_data.modelCS.to('cpu')
+                thread_data.model.model1.to("cpu")
+                thread_data.model.model2.to("cpu")

        del thread_data.model
        del thread_data.modelCS
@@ -343,7 +398,7 @@ def mk_img(req: Request):
    except Exception as e:
        print(traceback.format_exc())

-        if thread_data.device != 'cpu':
+        if thread_data.device != 'cpu' and not thread_data.test_sd2:
            thread_data.modelFS.to('cpu')
            thread_data.modelCS.to('cpu')
            thread_data.model.model1.to("cpu")
@@ -358,7 +413,10 @@
def update_temp_img(req, x_samples):
    partial_images = []
    for i in range(req.num_outputs):
-        x_sample_ddim = thread_data.modelFS.decode_first_stage(x_samples[i].unsqueeze(0))
+        if thread_data.test_sd2:
+            x_sample_ddim = thread_data.model.decode_first_stage(x_samples[i].unsqueeze(0))
+        else:
+            x_sample_ddim = thread_data.modelFS.decode_first_stage(x_samples[i].unsqueeze(0))
        x_sample = torch.clamp((x_sample_ddim + 1.0) / 2.0, min=0.0, max=1.0)
        x_sample = 255.0 * rearrange(x_sample[0].cpu().numpy(), "c h w -> h w c")
        x_sample = x_sample.astype(np.uint8)
@@ -433,7 +491,7 @@ def do_mk_img(req: Request):
            unload_filters()
        load_model_ckpt()

-    if thread_data.turbo != req.turbo:
+    if thread_data.turbo != req.turbo and not thread_data.test_sd2:
        thread_data.turbo = req.turbo
        thread_data.model.turbo = req.turbo
@@ -478,10 +536,14 @@ def do_mk_img(req: Request):
        if thread_data.device != "cpu" and thread_data.precision == "autocast":
            init_image = init_image.half()

-        thread_data.modelFS.to(thread_data.device)
+        if not thread_data.test_sd2:
+            thread_data.modelFS.to(thread_data.device)

        init_image = repeat(init_image, '1 ... -> b ...', b=batch_size)
-        init_latent = thread_data.modelFS.get_first_stage_encoding(thread_data.modelFS.encode_first_stage(init_image))  # move to latent space
+        if thread_data.test_sd2:
+            init_latent = thread_data.model.get_first_stage_encoding(thread_data.model.encode_first_stage(init_image))  # move to latent space
+        else:
+            init_latent = thread_data.modelFS.get_first_stage_encoding(thread_data.modelFS.encode_first_stage(init_image))  # move to latent space

        if req.mask is not None:
            mask = load_mask(req.mask, req.width, req.height, init_latent.shape[2], init_latent.shape[3], True).to(thread_data.device)
@@ -493,7 +555,8 @@ def do_mk_img(req: Request):
        # Send to CPU and wait until complete.
        # wait_model_move_to(thread_data.modelFS, 'cpu')

-        move_to_cpu(thread_data.modelFS)
+        if not thread_data.test_sd2:
+            move_to_cpu(thread_data.modelFS)

        assert 0. <= req.prompt_strength <= 1., 'can only work with strength in [0.0, 1.0]'
        t_enc = int(req.prompt_strength * req.num_inference_steps)
@@ -509,11 +572,14 @@ def do_mk_img(req: Request):
        for prompts in tqdm(data, desc="data"):
            with precision_scope("cuda"):
-                if thread_data.reduced_memory:
+                if thread_data.reduced_memory and not thread_data.test_sd2:
                    thread_data.modelCS.to(thread_data.device)
                uc = None
                if req.guidance_scale != 1.0:
-                    uc = thread_data.modelCS.get_learned_conditioning(batch_size * [req.negative_prompt])
+                    if thread_data.test_sd2:
+                        uc = thread_data.model.get_learned_conditioning(batch_size * [req.negative_prompt])
+                    else:
+                        uc = thread_data.modelCS.get_learned_conditioning(batch_size * [req.negative_prompt])

                if isinstance(prompts, tuple):
                    prompts = list(prompts)
@@ -526,11 +592,17 @@ def do_mk_img(req: Request):
                        weight = weights[i]
                        # if not skip_normalize:
                        weight = weight / totalWeight
-                        c = torch.add(c, thread_data.modelCS.get_learned_conditioning(subprompts[i]), alpha=weight)
+                        if thread_data.test_sd2:
+                            c = torch.add(c, thread_data.model.get_learned_conditioning(subprompts[i]), alpha=weight)
+                        else:
+                            c = torch.add(c, thread_data.modelCS.get_learned_conditioning(subprompts[i]), alpha=weight)
                else:
-                    c = thread_data.modelCS.get_learned_conditioning(prompts)
+                    if thread_data.test_sd2:
+                        c = thread_data.model.get_learned_conditioning(prompts)
+                    else:
+                        c = thread_data.modelCS.get_learned_conditioning(prompts)

-                if thread_data.reduced_memory:
+                if thread_data.reduced_memory and not thread_data.test_sd2:
                    thread_data.modelFS.to(thread_data.device)

                n_steps = req.num_inference_steps if req.init_image is None else t_enc
@@ -562,7 +634,10 @@ def do_mk_img(req: Request):
                print("decoding images")
                img_data = [None] * batch_size
                for i in range(batch_size):
-                    x_samples_ddim = thread_data.modelFS.decode_first_stage(x_samples[i].unsqueeze(0))
+                    if thread_data.test_sd2:
+                        x_samples_ddim = thread_data.model.decode_first_stage(x_samples[i].unsqueeze(0))
+                    else:
+                        x_samples_ddim = thread_data.modelFS.decode_first_stage(x_samples[i].unsqueeze(0))
                    x_sample = torch.clamp((x_samples_ddim + 1.0) / 2.0, min=0.0, max=1.0)
                    x_sample = 255.0 * rearrange(x_sample[0].cpu().numpy(), "c h w -> h w c")
                    x_sample = x_sample.astype(np.uint8)
@@ -622,7 +697,8 @@ def do_mk_img(req: Request):
        # if thread_data.reduced_memory:
        #     unload_filters()
-        move_to_cpu(thread_data.modelFS)
+        if not thread_data.test_sd2:
+            move_to_cpu(thread_data.modelFS)

        del img_data
        gc()
        if thread_data.device != 'cpu':
@@ -664,7 +740,11 @@ def _txt2img(opt_W, opt_H, opt_n_samples, opt_ddim_steps, opt_scale, start_code,
    # Send to CPU and wait until complete.
    # wait_model_move_to(thread_data.modelCS, 'cpu')

-    move_to_cpu(thread_data.modelCS)
+    if not thread_data.test_sd2:
+        move_to_cpu(thread_data.modelCS)
+
+    if thread_data.test_sd2 and sampler_name not in ('plms', 'ddim'):
+        raise Exception('Only plms and ddim samplers are supported right now, in SD 2.0')

    if sampler_name == 'ddim':
        thread_data.model.make_schedule(ddim_num_steps=opt_ddim_steps, ddim_eta=opt_ddim_eta, verbose=False)

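A way to read the many `thread_data.test_sd2` branches in this file: the SD 1.x path in this fork splits the network into `modelCS` (conditioning / text encoder) and `modelFS` (first-stage VAE) so each can be moved to the CPU independently, while the SD 2.0 path keeps a single monolithic `thread_data.model` that plays both roles. A hypothetical pair of helpers (not in the diff) that would express the dispatch once:

```python
# Hypothetical helpers summarizing the repeated test_sd2 branches above.
def first_stage(thread_data):
    # the module that owns encode_first_stage / decode_first_stage
    return thread_data.model if thread_data.test_sd2 else thread_data.modelFS

def cond_stage(thread_data):
    # the module that owns get_learned_conditioning
    return thread_data.model if thread_data.test_sd2 else thread_data.modelCS
```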
View File

@@ -116,6 +116,8 @@ def setConfig(config):
            bind_ip = '0.0.0.0' if config['net']['listen_to_network'] else '127.0.0.1'
            config_bat.append(f"@set SD_UI_BIND_IP={bind_ip}")

+        config_bat.append(f"@set test_sd2={'Y' if config.get('test_sd2', False) else 'N'}")

        if len(config_bat) > 0:
            with open(config_bat_path, 'w', encoding='utf-8') as f:
                f.write('\r\n'.join(config_bat))
@@ -133,6 +135,8 @@ def setConfig(config):
            bind_ip = '0.0.0.0' if config['net']['listen_to_network'] else '127.0.0.1'
            config_sh.append(f"export SD_UI_BIND_IP={bind_ip}")

+        config_sh.append(f"export test_sd2=\"{'Y' if config.get('test_sd2', False) else 'N'}\"")

        if len(config_sh) > 1:
            with open(config_sh_path, 'w', encoding='utf-8') as f:
                f.write('\n'.join(config_sh))
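`setConfig()` mirrors the JSON setting into the two generated startup scripts as a `Y`/`N` environment variable, so the batch and shell installers can branch on `%test_sd2%` / `$test_sd2` without parsing JSON. A small sketch of the exact lines those two f-strings emit when the checkbox is on (the generated file names, presumably `config.bat` and `config.sh`, are inferred from the `config_bat_path` / `config_sh_path` variables):

```python
# Reproduces the two appended lines for a config with test_sd2 enabled.
config = {'test_sd2': True}
flag = 'Y' if config.get('test_sd2', False) else 'N'
print(f"@set test_sd2={flag}")        # appended to the Windows startup script
print(f'export test_sd2="{flag}"')    # appended to the shell startup script
```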