forked from extern/easydiffusion
Placeholder changes for SD 2.0 support, haven't tested yet
This commit is contained in:
parent
b70235ff92
commit
ea7b28c9d5
@ -21,6 +21,7 @@
|
||||
- A `What's New?` tab in the UI
|
||||
|
||||
### Detailed changelog
|
||||
* 2.4.15 - 25 Nov 2022 - Experimental support for SD 2.0. Uses lots of memory, not optimized, probably GPU-only.
|
||||
* 2.4.14 - 22 Nov 2022 - Change the backend to a custom fork of Stable Diffusion
|
||||
* 2.4.13 - 21 Nov 2022 - Change the modifier weight via mouse wheel, drag to reorder selected modifiers, and some more modifier-related fixes. Thanks @patriceac
|
||||
* 2.4.12 - 21 Nov 2022 - Another fix for improving how long images take to generate. Reduces the time taken for an enqueued task to start processing.
|
||||
|
@ -27,6 +27,8 @@ if exist "Open Developer Console.cmd" del "Open Developer Console.cmd"
|
||||
|
||||
@call python -c "import os; import shutil; frm = 'sd-ui-files\\ui\\hotfix\\9c24e6cd9f499d02c4f21a033736dabd365962dc80fe3aeb57a8f85ea45a20a3.26fead7ea4f0f843f6eb4055dfd25693f1a71f3c6871b184042d4b126244e142'; dst = os.path.join(os.path.expanduser('~'), '.cache', 'huggingface', 'transformers', '9c24e6cd9f499d02c4f21a033736dabd365962dc80fe3aeb57a8f85ea45a20a3.26fead7ea4f0f843f6eb4055dfd25693f1a71f3c6871b184042d4b126244e142'); shutil.copyfile(frm, dst) if os.path.exists(dst) else print(''); print('Hotfixed broken JSON file from OpenAI');"
|
||||
|
||||
if NOT DEFINED test_sd2 set test_sd2=N
|
||||
|
||||
@>nul findstr /m "sd_git_cloned" scripts\install_status.txt
|
||||
@if "%ERRORLEVEL%" EQU "0" (
|
||||
@echo "Stable Diffusion's git repository was already installed. Updating.."
|
||||
@ -37,9 +39,17 @@ if exist "Open Developer Console.cmd" del "Open Developer Console.cmd"
|
||||
|
||||
@call git reset --hard
|
||||
@call git pull
|
||||
@call git -c advice.detachedHead=false checkout 7f32368ed1030a6e710537047bacd908adea183a
|
||||
|
||||
@call git apply --whitespace=warn ..\ui\sd_internal\ddim_callback.patch
|
||||
if "%test_sd2%" == "N" (
|
||||
@call git -c advice.detachedHead=false checkout 7f32368ed1030a6e710537047bacd908adea183a
|
||||
|
||||
@call git apply --whitespace=warn ..\ui\sd_internal\ddim_callback.patch
|
||||
)
|
||||
if "%test_sd2%" == "Y" (
|
||||
@call git -c advice.detachedHead=false checkout 7da785cfa0d128368bc1357b54d380ba33dc1138
|
||||
|
||||
@call git apply --whitespace=warn ..\ui\sd_internal\ddim_callback_sd2.patch
|
||||
)
|
||||
|
||||
@cd ..
|
||||
) else (
|
||||
@ -346,7 +356,9 @@ echo. > "..\models\vae\Put your VAE files here.txt"
|
||||
)
|
||||
)
|
||||
|
||||
|
||||
if "%test_sd2%" == "Y" (
|
||||
@call pip install open_clip_torch==2.0.2
|
||||
)
|
||||
|
||||
@>nul findstr /m "sd_install_complete" ..\scripts\install_status.txt
|
||||
@if "%ERRORLEVEL%" NEQ "0" (
|
||||
|
@ -21,6 +21,10 @@ python -c "import os; import shutil; frm = 'sd-ui-files/ui/hotfix/9c24e6cd9f499d
|
||||
# Caution, this file will make your eyes and brain bleed. It's such an unholy mess.
|
||||
# Note to self: Please rewrite this in Python. For the sake of your own sanity.
|
||||
|
||||
if [ "$test_sd2" == "" ]; then
|
||||
export test_sd2="N"
|
||||
fi
|
||||
|
||||
if [ -e "scripts/install_status.txt" ] && [ `grep -c sd_git_cloned scripts/install_status.txt` -gt "0" ]; then
|
||||
echo "Stable Diffusion's git repository was already installed. Updating.."
|
||||
|
||||
@ -30,9 +34,16 @@ if [ -e "scripts/install_status.txt" ] && [ `grep -c sd_git_cloned scripts/insta
|
||||
|
||||
git reset --hard
|
||||
git pull
|
||||
git -c advice.detachedHead=false checkout 7f32368ed1030a6e710537047bacd908adea183a
|
||||
|
||||
git apply --whitespace=warn ../ui/sd_internal/ddim_callback.patch || fail "ddim patch failed"
|
||||
if [ "$test_sd2" == "N" ]; then
|
||||
git -c advice.detachedHead=false checkout 7f32368ed1030a6e710537047bacd908adea183a
|
||||
|
||||
git apply --whitespace=warn ../ui/sd_internal/ddim_callback.patch || fail "ddim patch failed"
|
||||
elif [ "$test_sd2" == "Y" ]; then
|
||||
git -c advice.detachedHead=false checkout 7da785cfa0d128368bc1357b54d380ba33dc1138
|
||||
|
||||
git apply --whitespace=warn ../ui/sd_internal/ddim_callback_sd2.patch || fail "sd2 ddim patch failed"
|
||||
fi
|
||||
|
||||
cd ..
|
||||
else
|
||||
@ -291,6 +302,9 @@ if [ ! -f "../models/vae/vae-ft-mse-840000-ema-pruned.ckpt" ]; then
|
||||
fi
|
||||
fi
|
||||
|
||||
if [ "$test_sd2" == "Y" ]; then
|
||||
pip install open_clip_torch==2.0.2
|
||||
fi
|
||||
|
||||
if [ `grep -c sd_install_complete ../scripts/install_status.txt` -gt "0" ]; then
|
||||
echo sd_weights_downloaded >> ../scripts/install_status.txt
|
||||
|
@ -22,7 +22,7 @@
|
||||
<div id="logo">
|
||||
<h1>
|
||||
Stable Diffusion UI
|
||||
<small>v2.4.14 <span id="updateBranchLabel"></span></small>
|
||||
<small>v2.4.15 <span id="updateBranchLabel"></span></small>
|
||||
</h1>
|
||||
</div>
|
||||
<div id="server-status">
|
||||
|
@ -132,6 +132,14 @@ var PARAMETERS = [
|
||||
return `<input id="${parameter.id}" name="${parameter.id}" size="6" value="9000" onkeypress="preventNonNumericalInput(event)">`
|
||||
}
|
||||
},
|
||||
{
|
||||
id: "test_sd2",
|
||||
type: ParameterType.checkbox,
|
||||
label: "Test SD 2.0",
|
||||
note: "Experimental! High memory usage! GPU-only! Please restart the program after changing this.",
|
||||
icon: "fa-fire",
|
||||
default: false,
|
||||
},
|
||||
{
|
||||
id: "use_beta_channel",
|
||||
type: ParameterType.checkbox,
|
||||
@ -196,6 +204,7 @@ let saveToDiskField = document.querySelector('#save_to_disk')
|
||||
let diskPathField = document.querySelector('#diskPath')
|
||||
let listenToNetworkField = document.querySelector("#listen_to_network")
|
||||
let listenPortField = document.querySelector("#listen_port")
|
||||
let testSD2Field = document.querySelector("#test_sd2")
|
||||
let useBetaChannelField = document.querySelector("#use_beta_channel")
|
||||
let uiOpenBrowserOnStartField = document.querySelector("#ui_open_browser_on_start")
|
||||
|
||||
@ -230,6 +239,9 @@ async function getAppConfig() {
|
||||
if (config.ui && config.ui.open_browser_on_start === false) {
|
||||
uiOpenBrowserOnStartField.checked = false
|
||||
}
|
||||
if ('test_sd2' in config) {
|
||||
testSD2Field.checked = config['test_sd2']
|
||||
}
|
||||
if (config.net && config.net.listen_to_network === false) {
|
||||
listenToNetworkField.checked = false
|
||||
}
|
||||
@ -372,7 +384,8 @@ saveSettingsBtn.addEventListener('click', function() {
|
||||
'update_branch': updateBranch,
|
||||
'ui_open_browser_on_start': uiOpenBrowserOnStartField.checked,
|
||||
'listen_to_network': listenToNetworkField.checked,
|
||||
'listen_port': listenPortField.value
|
||||
'listen_port': listenPortField.value,
|
||||
'test_sd2': testSD2Field.checked
|
||||
})
|
||||
}
|
||||
|
||||
|
20
ui/sd_internal/ddim_callback_sd2.patch
Normal file
20
ui/sd_internal/ddim_callback_sd2.patch
Normal file
@ -0,0 +1,20 @@
|
||||
diff --git a/optimizedSD/ddpm.py b/optimizedSD/ddpm.py
|
||||
index 1bbdd02..cd00cc3 100644
|
||||
--- a/optimizedSD/ddpm.py
|
||||
+++ b/optimizedSD/ddpm.py
|
||||
@@ -348,6 +348,7 @@ class DDPM(pl.LightningModule):
|
||||
def sample(self, batch_size=16, return_intermediates=False):
|
||||
image_size = self.image_size
|
||||
channels = self.channels
|
||||
+ print('sampler 2')
|
||||
return self.p_sample_loop((batch_size, channels, image_size, image_size),
|
||||
return_intermediates=return_intermediates)
|
||||
|
||||
@@ -1090,6 +1091,7 @@ class LatentDiffusion(DDPM):
|
||||
def sample(self, cond, batch_size=16, return_intermediates=False, x_T=None,
|
||||
verbose=True, timesteps=None, quantize_denoised=False,
|
||||
mask=None, x0=None, shape=None, **kwargs):
|
||||
+ print('sampler 1')
|
||||
if shape is None:
|
||||
shape = (batch_size, self.channels, self.image_size, self.image_size)
|
||||
if cond is not None:
|
@ -76,8 +76,24 @@ def thread_init(device):
|
||||
thread_data.force_full_precision = False
|
||||
thread_data.reduced_memory = True
|
||||
|
||||
thread_data.test_sd2 = isSD2()
|
||||
|
||||
device_manager.device_init(thread_data, device)
|
||||
|
||||
# temp hack, will remove soon
|
||||
def isSD2():
|
||||
try:
|
||||
SD_UI_DIR = os.getenv('SD_UI_PATH', None)
|
||||
CONFIG_DIR = os.path.abspath(os.path.join(SD_UI_DIR, '..', 'scripts'))
|
||||
config_json_path = os.path.join(CONFIG_DIR, 'config.json')
|
||||
if not os.path.exists(config_json_path):
|
||||
return False
|
||||
with open(config_json_path, 'r', encoding='utf-8') as f:
|
||||
config = json.load(f)
|
||||
return config.get('test_sd2', False)
|
||||
except Exception as e:
|
||||
return False
|
||||
|
||||
def load_model_ckpt():
|
||||
if not thread_data.ckpt_file: raise ValueError(f'Thread ckpt_file is undefined.')
|
||||
if not os.path.exists(thread_data.ckpt_file + '.ckpt'): raise FileNotFoundError(f'Cannot find {thread_data.ckpt_file}.ckpt')
|
||||
@ -92,6 +108,13 @@ def load_model_ckpt():
|
||||
thread_data.precision = 'full'
|
||||
|
||||
print('loading', thread_data.ckpt_file + '.ckpt', 'to device', thread_data.device, 'using precision', thread_data.precision)
|
||||
|
||||
if thread_data.test_sd2:
|
||||
load_model_ckpt_sd2()
|
||||
else:
|
||||
load_model_ckpt_sd1()
|
||||
|
||||
def load_model_ckpt_sd1():
|
||||
sd = load_model_from_config(thread_data.ckpt_file + '.ckpt')
|
||||
li, lo = [], []
|
||||
for key, value in sd.items():
|
||||
@ -185,6 +208,37 @@ def load_model_ckpt():
|
||||
modelFS.device: {thread_data.modelFS.device}
|
||||
using precision: {thread_data.precision}''')
|
||||
|
||||
def load_model_ckpt_sd2():
|
||||
config = OmegaConf.load('configs/stable-diffusion/v2-inference-v.yaml')
|
||||
verbose = False
|
||||
|
||||
sd = load_model_from_config(thread_data.ckpt_file + '.ckpt')
|
||||
|
||||
thread_data.model = instantiate_from_config(config.model)
|
||||
m, u = thread_data.model.load_state_dict(sd, strict=False)
|
||||
if len(m) > 0 and verbose:
|
||||
print("missing keys:")
|
||||
print(m)
|
||||
if len(u) > 0 and verbose:
|
||||
print("unexpected keys:")
|
||||
print(u)
|
||||
|
||||
thread_data.model.to(thread_data.device)
|
||||
thread_data.model.eval()
|
||||
del sd
|
||||
|
||||
if thread_data.device != "cpu" and thread_data.precision == "autocast":
|
||||
thread_data.model.half()
|
||||
thread_data.model_is_half = True
|
||||
thread_data.model_fs_is_half = True
|
||||
else:
|
||||
thread_data.model_is_half = False
|
||||
thread_data.model_fs_is_half = False
|
||||
|
||||
print(f'''loaded model
|
||||
model file: {thread_data.ckpt_file}.ckpt
|
||||
using precision: {thread_data.precision}''')
|
||||
|
||||
def unload_filters():
|
||||
if thread_data.model_gfpgan is not None:
|
||||
if thread_data.device != 'cpu': thread_data.model_gfpgan.gfpgan.to('cpu')
|
||||
@ -204,10 +258,11 @@ def unload_models():
|
||||
if thread_data.model is not None:
|
||||
print('Unloading models...')
|
||||
if thread_data.device != 'cpu':
|
||||
thread_data.modelFS.to('cpu')
|
||||
thread_data.modelCS.to('cpu')
|
||||
thread_data.model.model1.to("cpu")
|
||||
thread_data.model.model2.to("cpu")
|
||||
if not thread_data.test_sd2:
|
||||
thread_data.modelFS.to('cpu')
|
||||
thread_data.modelCS.to('cpu')
|
||||
thread_data.model.model1.to("cpu")
|
||||
thread_data.model.model2.to("cpu")
|
||||
|
||||
del thread_data.model
|
||||
del thread_data.modelCS
|
||||
@ -343,7 +398,7 @@ def mk_img(req: Request):
|
||||
except Exception as e:
|
||||
print(traceback.format_exc())
|
||||
|
||||
if thread_data.device != 'cpu':
|
||||
if thread_data.device != 'cpu' and not thread_data.test_sd2:
|
||||
thread_data.modelFS.to('cpu')
|
||||
thread_data.modelCS.to('cpu')
|
||||
thread_data.model.model1.to("cpu")
|
||||
@ -358,7 +413,10 @@ def mk_img(req: Request):
|
||||
def update_temp_img(req, x_samples):
|
||||
partial_images = []
|
||||
for i in range(req.num_outputs):
|
||||
x_sample_ddim = thread_data.modelFS.decode_first_stage(x_samples[i].unsqueeze(0))
|
||||
if thread_data.test_sd2:
|
||||
x_sample_ddim = thread_data.model.decode_first_stage(x_samples[i].unsqueeze(0))
|
||||
else:
|
||||
x_sample_ddim = thread_data.modelFS.decode_first_stage(x_samples[i].unsqueeze(0))
|
||||
x_sample = torch.clamp((x_sample_ddim + 1.0) / 2.0, min=0.0, max=1.0)
|
||||
x_sample = 255.0 * rearrange(x_sample[0].cpu().numpy(), "c h w -> h w c")
|
||||
x_sample = x_sample.astype(np.uint8)
|
||||
@ -433,7 +491,7 @@ def do_mk_img(req: Request):
|
||||
unload_filters()
|
||||
load_model_ckpt()
|
||||
|
||||
if thread_data.turbo != req.turbo:
|
||||
if thread_data.turbo != req.turbo and not thread_data.test_sd2:
|
||||
thread_data.turbo = req.turbo
|
||||
thread_data.model.turbo = req.turbo
|
||||
|
||||
@ -478,10 +536,14 @@ def do_mk_img(req: Request):
|
||||
if thread_data.device != "cpu" and thread_data.precision == "autocast":
|
||||
init_image = init_image.half()
|
||||
|
||||
thread_data.modelFS.to(thread_data.device)
|
||||
if not thread_data.test_sd2:
|
||||
thread_data.modelFS.to(thread_data.device)
|
||||
|
||||
init_image = repeat(init_image, '1 ... -> b ...', b=batch_size)
|
||||
init_latent = thread_data.modelFS.get_first_stage_encoding(thread_data.modelFS.encode_first_stage(init_image)) # move to latent space
|
||||
if thread_data.test_sd2:
|
||||
init_latent = thread_data.model.get_first_stage_encoding(thread_data.model.encode_first_stage(init_image)) # move to latent space
|
||||
else:
|
||||
init_latent = thread_data.modelFS.get_first_stage_encoding(thread_data.modelFS.encode_first_stage(init_image)) # move to latent space
|
||||
|
||||
if req.mask is not None:
|
||||
mask = load_mask(req.mask, req.width, req.height, init_latent.shape[2], init_latent.shape[3], True).to(thread_data.device)
|
||||
@ -493,7 +555,8 @@ def do_mk_img(req: Request):
|
||||
|
||||
# Send to CPU and wait until complete.
|
||||
# wait_model_move_to(thread_data.modelFS, 'cpu')
|
||||
move_to_cpu(thread_data.modelFS)
|
||||
if not thread_data.test_sd2:
|
||||
move_to_cpu(thread_data.modelFS)
|
||||
|
||||
assert 0. <= req.prompt_strength <= 1., 'can only work with strength in [0.0, 1.0]'
|
||||
t_enc = int(req.prompt_strength * req.num_inference_steps)
|
||||
@ -509,11 +572,14 @@ def do_mk_img(req: Request):
|
||||
for prompts in tqdm(data, desc="data"):
|
||||
|
||||
with precision_scope("cuda"):
|
||||
if thread_data.reduced_memory:
|
||||
if thread_data.reduced_memory and not thread_data.test_sd2:
|
||||
thread_data.modelCS.to(thread_data.device)
|
||||
uc = None
|
||||
if req.guidance_scale != 1.0:
|
||||
uc = thread_data.modelCS.get_learned_conditioning(batch_size * [req.negative_prompt])
|
||||
if thread_data.test_sd2:
|
||||
uc = thread_data.model.get_learned_conditioning(batch_size * [req.negative_prompt])
|
||||
else:
|
||||
uc = thread_data.modelCS.get_learned_conditioning(batch_size * [req.negative_prompt])
|
||||
if isinstance(prompts, tuple):
|
||||
prompts = list(prompts)
|
||||
|
||||
@ -526,11 +592,17 @@ def do_mk_img(req: Request):
|
||||
weight = weights[i]
|
||||
# if not skip_normalize:
|
||||
weight = weight / totalWeight
|
||||
c = torch.add(c, thread_data.modelCS.get_learned_conditioning(subprompts[i]), alpha=weight)
|
||||
if thread_data.test_sd2:
|
||||
c = torch.add(c, thread_data.model.get_learned_conditioning(subprompts[i]), alpha=weight)
|
||||
else:
|
||||
c = torch.add(c, thread_data.modelCS.get_learned_conditioning(subprompts[i]), alpha=weight)
|
||||
else:
|
||||
c = thread_data.modelCS.get_learned_conditioning(prompts)
|
||||
if thread_data.test_sd2:
|
||||
c = thread_data.model.get_learned_conditioning(prompts)
|
||||
else:
|
||||
c = thread_data.modelCS.get_learned_conditioning(prompts)
|
||||
|
||||
if thread_data.reduced_memory:
|
||||
if thread_data.reduced_memory and not thread_data.test_sd2:
|
||||
thread_data.modelFS.to(thread_data.device)
|
||||
|
||||
n_steps = req.num_inference_steps if req.init_image is None else t_enc
|
||||
@ -562,7 +634,10 @@ def do_mk_img(req: Request):
|
||||
print("decoding images")
|
||||
img_data = [None] * batch_size
|
||||
for i in range(batch_size):
|
||||
x_samples_ddim = thread_data.modelFS.decode_first_stage(x_samples[i].unsqueeze(0))
|
||||
if thread_data.test_sd2:
|
||||
x_samples_ddim = thread_data.model.decode_first_stage(x_samples[i].unsqueeze(0))
|
||||
else:
|
||||
x_samples_ddim = thread_data.modelFS.decode_first_stage(x_samples[i].unsqueeze(0))
|
||||
x_sample = torch.clamp((x_samples_ddim + 1.0) / 2.0, min=0.0, max=1.0)
|
||||
x_sample = 255.0 * rearrange(x_sample[0].cpu().numpy(), "c h w -> h w c")
|
||||
x_sample = x_sample.astype(np.uint8)
|
||||
@ -622,7 +697,8 @@ def do_mk_img(req: Request):
|
||||
|
||||
# if thread_data.reduced_memory:
|
||||
# unload_filters()
|
||||
move_to_cpu(thread_data.modelFS)
|
||||
if not thread_data.test_sd2:
|
||||
move_to_cpu(thread_data.modelFS)
|
||||
del img_data
|
||||
gc()
|
||||
if thread_data.device != 'cpu':
|
||||
@ -664,7 +740,11 @@ def _txt2img(opt_W, opt_H, opt_n_samples, opt_ddim_steps, opt_scale, start_code,
|
||||
# Send to CPU and wait until complete.
|
||||
# wait_model_move_to(thread_data.modelCS, 'cpu')
|
||||
|
||||
move_to_cpu(thread_data.modelCS)
|
||||
if not thread_data.test_sd2:
|
||||
move_to_cpu(thread_data.modelCS)
|
||||
|
||||
if thread_data.test_sd2 and sampler_name not in ('plms', 'ddim'):
|
||||
raise Exception('Only plms and ddim samplers are supported right now, in SD 2.0')
|
||||
|
||||
if sampler_name == 'ddim':
|
||||
thread_data.model.make_schedule(ddim_num_steps=opt_ddim_steps, ddim_eta=opt_ddim_eta, verbose=False)
|
||||
|
@ -116,6 +116,8 @@ def setConfig(config):
|
||||
bind_ip = '0.0.0.0' if config['net']['listen_to_network'] else '127.0.0.1'
|
||||
config_bat.append(f"@set SD_UI_BIND_IP={bind_ip}")
|
||||
|
||||
config_bat.append(f"@set test_sd2={'Y' if config.get('test_sd2', False) else 'N'}")
|
||||
|
||||
if len(config_bat) > 0:
|
||||
with open(config_bat_path, 'w', encoding='utf-8') as f:
|
||||
f.write('\r\n'.join(config_bat))
|
||||
@ -133,6 +135,8 @@ def setConfig(config):
|
||||
bind_ip = '0.0.0.0' if config['net']['listen_to_network'] else '127.0.0.1'
|
||||
config_sh.append(f"export SD_UI_BIND_IP={bind_ip}")
|
||||
|
||||
config_sh.append(f"export test_sd2=\"{'Y' if config.get('test_sd2', False) else 'N'}\"")
|
||||
|
||||
if len(config_sh) > 1:
|
||||
with open(config_sh_path, 'w', encoding='utf-8') as f:
|
||||
f.write('\n'.join(config_sh))
|
||||
|
Loading…
Reference in New Issue
Block a user