diff --git a/CHANGES.md b/CHANGES.md
index db274ede..baafcb18 100644
--- a/CHANGES.md
+++ b/CHANGES.md
@@ -21,6 +21,7 @@
 - A `What's New?` tab in the UI
 
 ### Detailed changelog
+* 2.4.15 - 25 Nov 2022 - Experimental support for SD 2.0. Uses lots of memory, not optimized, probably GPU-only.
 * 2.4.14 - 22 Nov 2022 - Change the backend to a custom fork of Stable Diffusion
 * 2.4.13 - 21 Nov 2022 - Change the modifier weight via mouse wheel, drag to reorder selected modifiers, and some more modifier-related fixes. Thanks @patriceac
 * 2.4.12 - 21 Nov 2022 - Another fix for improving how long images take to generate. Reduces the time taken for an enqueued task to start processing.
diff --git a/scripts/on_sd_start.bat b/scripts/on_sd_start.bat
index 6e4ffd36..8b2e3e02 100644
--- a/scripts/on_sd_start.bat
+++ b/scripts/on_sd_start.bat
@@ -27,6 +27,8 @@ if exist "Open Developer Console.cmd" del "Open Developer Console.cmd"
 @call python -c "import os; import shutil; frm = 'sd-ui-files\\ui\\hotfix\\9c24e6cd9f499d02c4f21a033736dabd365962dc80fe3aeb57a8f85ea45a20a3.26fead7ea4f0f843f6eb4055dfd25693f1a71f3c6871b184042d4b126244e142'; dst = os.path.join(os.path.expanduser('~'), '.cache', 'huggingface', 'transformers', '9c24e6cd9f499d02c4f21a033736dabd365962dc80fe3aeb57a8f85ea45a20a3.26fead7ea4f0f843f6eb4055dfd25693f1a71f3c6871b184042d4b126244e142'); shutil.copyfile(frm, dst) if os.path.exists(dst) else print(''); print('Hotfixed broken JSON file from OpenAI');"
 
+if NOT DEFINED test_sd2 set test_sd2=N
+
 @>nul findstr /m "sd_git_cloned" scripts\install_status.txt
 @if "%ERRORLEVEL%" EQU "0" (
     @echo "Stable Diffusion's git repository was already installed. Updating.."
@@ -37,9 +39,17 @@ if exist "Open Developer Console.cmd" del "Open Developer Console.cmd"
     @call git reset --hard
     @call git pull
 
-    @call git -c advice.detachedHead=false checkout 7f32368ed1030a6e710537047bacd908adea183a
-    @call git apply --whitespace=warn ..\ui\sd_internal\ddim_callback.patch
+    if "%test_sd2%" == "N" (
+        @call git -c advice.detachedHead=false checkout 7f32368ed1030a6e710537047bacd908adea183a
+
+        @call git apply --whitespace=warn ..\ui\sd_internal\ddim_callback.patch
+    )
+    if "%test_sd2%" == "Y" (
+        @call git -c advice.detachedHead=false checkout 7da785cfa0d128368bc1357b54d380ba33dc1138
+
+        @call git apply --whitespace=warn ..\ui\sd_internal\ddim_callback_sd2.patch
+    )
 
     @cd ..
 ) else (
@@ -346,7 +356,9 @@ echo. > "..\models\vae\Put your VAE files here.txt"
     )
 )
-
+if "%test_sd2%" == "Y" (
+    @call pip install open_clip_torch==2.0.2
+)
 
 @>nul findstr /m "sd_install_complete" ..\scripts\install_status.txt
 @if "%ERRORLEVEL%" NEQ "0" (
diff --git a/scripts/on_sd_start.sh b/scripts/on_sd_start.sh
index ff6a04d4..133dac3e 100644
--- a/scripts/on_sd_start.sh
+++ b/scripts/on_sd_start.sh
@@ -21,6 +21,10 @@ python -c "import os; import shutil; frm = 'sd-ui-files/ui/hotfix/9c24e6cd9f499d
 # Caution, this file will make your eyes and brain bleed. It's such an unholy mess.
 # Note to self: Please rewrite this in Python. For the sake of your own sanity.
 
+if [ "$test_sd2" == "" ]; then
+    export test_sd2="N"
+fi
+
 if [ -e "scripts/install_status.txt" ] && [ `grep -c sd_git_cloned scripts/install_status.txt` -gt "0" ]; then
     echo "Stable Diffusion's git repository was already installed. Updating.."
@@ -30,9 +34,16 @@ if [ -e "scripts/install_status.txt" ] && [ `grep -c sd_git_cloned scripts/insta
     git reset --hard
     git pull
 
-    git -c advice.detachedHead=false checkout 7f32368ed1030a6e710537047bacd908adea183a
-    git apply --whitespace=warn ../ui/sd_internal/ddim_callback.patch || fail "ddim patch failed"
+    if [ "$test_sd2" == "N" ]; then
+        git -c advice.detachedHead=false checkout 7f32368ed1030a6e710537047bacd908adea183a
+
+        git apply --whitespace=warn ../ui/sd_internal/ddim_callback.patch || fail "ddim patch failed"
+    elif [ "$test_sd2" == "Y" ]; then
+        git -c advice.detachedHead=false checkout 7da785cfa0d128368bc1357b54d380ba33dc1138
+
+        git apply --whitespace=warn ../ui/sd_internal/ddim_callback_sd2.patch || fail "sd2 ddim patch failed"
+    fi
 
     cd ..
 else
@@ -291,6 +302,9 @@ if [ ! -f "../models/vae/vae-ft-mse-840000-ema-pruned.ckpt" ]; then
         fi
     fi
 
+if [ "$test_sd2" == "Y" ]; then
+    pip install open_clip_torch==2.0.2
+fi
 
 if [ `grep -c sd_install_complete ../scripts/install_status.txt` -gt "0" ]; then
     echo sd_weights_downloaded >> ../scripts/install_status.txt
diff --git a/ui/index.html b/ui/index.html
index 3e85b254..edbccf97 100644
--- a/ui/index.html
+++ b/ui/index.html
@@ -22,7 +22,7 @@
diff --git a/ui/media/js/parameters.js b/ui/media/js/parameters.js
index 2e5bc75c..dac008ef 100644
--- a/ui/media/js/parameters.js
+++ b/ui/media/js/parameters.js
@@ -132,6 +132,14 @@ var PARAMETERS = [
             return ``
         }
     },
+    {
+        id: "test_sd2",
+        type: ParameterType.checkbox,
+        label: "Test SD 2.0",
+        note: "Experimental! High memory usage! GPU-only! Please restart the program after changing this.",
+        icon: "fa-fire",
+        default: false,
+    },
     {
         id: "use_beta_channel",
         type: ParameterType.checkbox,
@@ -196,6 +204,7 @@
 let saveToDiskField = document.querySelector('#save_to_disk')
 let diskPathField = document.querySelector('#diskPath')
 let listenToNetworkField = document.querySelector("#listen_to_network")
 let listenPortField = document.querySelector("#listen_port")
+let testSD2Field = document.querySelector("#test_sd2")
 let useBetaChannelField = document.querySelector("#use_beta_channel")
 let uiOpenBrowserOnStartField = document.querySelector("#ui_open_browser_on_start")
@@ -230,6 +239,9 @@ async function getAppConfig() {
     if (config.ui && config.ui.open_browser_on_start === false) {
         uiOpenBrowserOnStartField.checked = false
     }
+    if ('test_sd2' in config) {
+        testSD2Field.checked = config['test_sd2']
+    }
     if (config.net && config.net.listen_to_network === false) {
         listenToNetworkField.checked = false
     }
@@ -372,7 +384,8 @@ saveSettingsBtn.addEventListener('click', function() {
         'update_branch': updateBranch,
         'ui_open_browser_on_start': uiOpenBrowserOnStartField.checked,
         'listen_to_network': listenToNetworkField.checked,
-        'listen_port': listenPortField.value
+        'listen_port': listenPortField.value,
+        'test_sd2': testSD2Field.checked
     })
 }
diff --git a/ui/sd_internal/ddim_callback_sd2.patch b/ui/sd_internal/ddim_callback_sd2.patch
new file mode 100644
index 00000000..00700d00
--- /dev/null
+++ b/ui/sd_internal/ddim_callback_sd2.patch
@@ -0,0 +1,20 @@
+diff --git a/optimizedSD/ddpm.py b/optimizedSD/ddpm.py
+index 1bbdd02..cd00cc3 100644
+--- a/optimizedSD/ddpm.py
++++ b/optimizedSD/ddpm.py
+@@ -348,6 +348,7 @@ class DDPM(pl.LightningModule):
+     def sample(self, batch_size=16, return_intermediates=False):
+         image_size = self.image_size
+         channels = self.channels
++        print('sampler 2')
+         return self.p_sample_loop((batch_size, channels, image_size, image_size),
+                                   return_intermediates=return_intermediates)
+ 
+@@ -1090,6 +1091,7 @@ class LatentDiffusion(DDPM):
+     def sample(self, cond, batch_size=16, return_intermediates=False, x_T=None,
+                verbose=True, timesteps=None, quantize_denoised=False,
+                mask=None, x0=None, shape=None, **kwargs):
++        print('sampler 1')
+         if shape is None:
+             shape = (batch_size, self.channels, self.image_size, self.image_size)
+         if cond is not None:
diff --git a/ui/sd_internal/runtime.py b/ui/sd_internal/runtime.py
index e217965d..d9521780 100644
--- a/ui/sd_internal/runtime.py
+++ b/ui/sd_internal/runtime.py
@@ -76,8 +76,24 @@ def thread_init(device):
     thread_data.force_full_precision = False
     thread_data.reduced_memory = True
 
+    thread_data.test_sd2 = isSD2()
+
     device_manager.device_init(thread_data, device)
 
+# temp hack, will remove soon
+def isSD2():
+    try:
+        SD_UI_DIR = os.getenv('SD_UI_PATH', None)
+        CONFIG_DIR = os.path.abspath(os.path.join(SD_UI_DIR, '..', 'scripts'))
+        config_json_path = os.path.join(CONFIG_DIR, 'config.json')
+        if not os.path.exists(config_json_path):
+            return False
+        with open(config_json_path, 'r', encoding='utf-8') as f:
+            config = json.load(f)
+        return config.get('test_sd2', False)
+    except Exception as e:
+        return False
+
 def load_model_ckpt():
     if not thread_data.ckpt_file: raise ValueError(f'Thread ckpt_file is undefined.')
     if not os.path.exists(thread_data.ckpt_file + '.ckpt'): raise FileNotFoundError(f'Cannot find {thread_data.ckpt_file}.ckpt')
@@ -92,6 +108,13 @@ def load_model_ckpt():
         thread_data.precision = 'full'
 
     print('loading', thread_data.ckpt_file + '.ckpt', 'to device', thread_data.device, 'using precision', thread_data.precision)
+
+    if thread_data.test_sd2:
+        load_model_ckpt_sd2()
+    else:
+        load_model_ckpt_sd1()
+
+def load_model_ckpt_sd1():
     sd = load_model_from_config(thread_data.ckpt_file + '.ckpt')
     li, lo = [], []
     for key, value in sd.items():
@@ -185,6 +208,37 @@
 modelFS.device: {thread_data.modelFS.device}
 using precision: {thread_data.precision}''')
 
+def load_model_ckpt_sd2():
+    config = OmegaConf.load('configs/stable-diffusion/v2-inference-v.yaml')
+    verbose = False
+
+    sd = load_model_from_config(thread_data.ckpt_file + '.ckpt')
+
+    thread_data.model = instantiate_from_config(config.model)
+    m, u = thread_data.model.load_state_dict(sd, strict=False)
+    if len(m) > 0 and verbose:
+        print("missing keys:")
+        print(m)
+    if len(u) > 0 and verbose:
+        print("unexpected keys:")
+        print(u)
+
+    thread_data.model.to(thread_data.device)
+    thread_data.model.eval()
+    del sd
+
+    if thread_data.device != "cpu" and thread_data.precision == "autocast":
+        thread_data.model.half()
+        thread_data.model_is_half = True
+        thread_data.model_fs_is_half = True
+    else:
+        thread_data.model_is_half = False
+        thread_data.model_fs_is_half = False
+
+    print(f'''loaded model
+ model file: {thread_data.ckpt_file}.ckpt
+ using precision: {thread_data.precision}''')
+
 def unload_filters():
     if thread_data.model_gfpgan is not None:
         if thread_data.device != 'cpu': thread_data.model_gfpgan.gfpgan.to('cpu')
@@ -204,10 +258,11 @@ def unload_models():
     if thread_data.model is not None:
         print('Unloading models...')
         if thread_data.device != 'cpu':
-            thread_data.modelFS.to('cpu')
-            thread_data.modelCS.to('cpu')
-            thread_data.model.model1.to("cpu")
-            thread_data.model.model2.to("cpu")
+            if not thread_data.test_sd2:
+                thread_data.modelFS.to('cpu')
+                thread_data.modelCS.to('cpu')
+                thread_data.model.model1.to("cpu")
+                thread_data.model.model2.to("cpu")
 
         del thread_data.model
         del thread_data.modelCS
@@ -343,7 +398,7 @@ def mk_img(req: Request):
     except Exception as e:
         print(traceback.format_exc())
 
-        if thread_data.device != 'cpu':
+        if thread_data.device != 'cpu' and not thread_data.test_sd2:
             thread_data.modelFS.to('cpu')
             thread_data.modelCS.to('cpu')
             thread_data.model.model1.to("cpu")
@@ -358,7 +413,10 @@
 def update_temp_img(req, x_samples):
     partial_images = []
     for i in range(req.num_outputs):
-        x_sample_ddim = thread_data.modelFS.decode_first_stage(x_samples[i].unsqueeze(0))
+        if thread_data.test_sd2:
+            x_sample_ddim = thread_data.model.decode_first_stage(x_samples[i].unsqueeze(0))
+        else:
+            x_sample_ddim = thread_data.modelFS.decode_first_stage(x_samples[i].unsqueeze(0))
         x_sample = torch.clamp((x_sample_ddim + 1.0) / 2.0, min=0.0, max=1.0)
         x_sample = 255.0 * rearrange(x_sample[0].cpu().numpy(), "c h w -> h w c")
         x_sample = x_sample.astype(np.uint8)
@@ -433,7 +491,7 @@ def do_mk_img(req: Request):
             unload_filters()
         load_model_ckpt()
 
-    if thread_data.turbo != req.turbo:
+    if thread_data.turbo != req.turbo and not thread_data.test_sd2:
         thread_data.turbo = req.turbo
         thread_data.model.turbo = req.turbo
@@ -478,10 +536,14 @@ def do_mk_img(req: Request):
         if thread_data.device != "cpu" and thread_data.precision == "autocast":
             init_image = init_image.half()
 
-        thread_data.modelFS.to(thread_data.device)
+        if not thread_data.test_sd2:
+            thread_data.modelFS.to(thread_data.device)
 
         init_image = repeat(init_image, '1 ... -> b ...', b=batch_size)
-        init_latent = thread_data.modelFS.get_first_stage_encoding(thread_data.modelFS.encode_first_stage(init_image)) # move to latent space
+        if thread_data.test_sd2:
+            init_latent = thread_data.model.get_first_stage_encoding(thread_data.model.encode_first_stage(init_image)) # move to latent space
+        else:
+            init_latent = thread_data.modelFS.get_first_stage_encoding(thread_data.modelFS.encode_first_stage(init_image)) # move to latent space
 
         if req.mask is not None:
             mask = load_mask(req.mask, req.width, req.height, init_latent.shape[2], init_latent.shape[3], True).to(thread_data.device)
@@ -493,7 +555,8 @@
 
     # Send to CPU and wait until complete.
     # wait_model_move_to(thread_data.modelFS, 'cpu')
-    move_to_cpu(thread_data.modelFS)
+    if not thread_data.test_sd2:
+        move_to_cpu(thread_data.modelFS)
 
     assert 0. <= req.prompt_strength <= 1., 'can only work with strength in [0.0, 1.0]'
     t_enc = int(req.prompt_strength * req.num_inference_steps)
@@ -509,11 +572,14 @@
     for prompts in tqdm(data, desc="data"):
 
         with precision_scope("cuda"):
-            if thread_data.reduced_memory:
+            if thread_data.reduced_memory and not thread_data.test_sd2:
                 thread_data.modelCS.to(thread_data.device)
             uc = None
             if req.guidance_scale != 1.0:
-                uc = thread_data.modelCS.get_learned_conditioning(batch_size * [req.negative_prompt])
+                if thread_data.test_sd2:
+                    uc = thread_data.model.get_learned_conditioning(batch_size * [req.negative_prompt])
+                else:
+                    uc = thread_data.modelCS.get_learned_conditioning(batch_size * [req.negative_prompt])
             if isinstance(prompts, tuple): prompts = list(prompts)
@@ -526,11 +592,17 @@
                     weight = weights[i]
                     # if not skip_normalize:
                     weight = weight / totalWeight
-                    c = torch.add(c, thread_data.modelCS.get_learned_conditioning(subprompts[i]), alpha=weight)
+                    if thread_data.test_sd2:
+                        c = torch.add(c, thread_data.model.get_learned_conditioning(subprompts[i]), alpha=weight)
+                    else:
+                        c = torch.add(c, thread_data.modelCS.get_learned_conditioning(subprompts[i]), alpha=weight)
             else:
-                c = thread_data.modelCS.get_learned_conditioning(prompts)
+                if thread_data.test_sd2:
+                    c = thread_data.model.get_learned_conditioning(prompts)
+                else:
+                    c = thread_data.modelCS.get_learned_conditioning(prompts)
 
-            if thread_data.reduced_memory:
+            if thread_data.reduced_memory and not thread_data.test_sd2:
                 thread_data.modelFS.to(thread_data.device)
 
             n_steps = req.num_inference_steps if req.init_image is None else t_enc
@@ -562,7 +634,10 @@
             print("decoding images")
             img_data = [None] * batch_size
             for i in range(batch_size):
-                x_samples_ddim = thread_data.modelFS.decode_first_stage(x_samples[i].unsqueeze(0))
+                if thread_data.test_sd2:
+                    x_samples_ddim = thread_data.model.decode_first_stage(x_samples[i].unsqueeze(0))
+                else:
+                    x_samples_ddim = thread_data.modelFS.decode_first_stage(x_samples[i].unsqueeze(0))
                 x_sample = torch.clamp((x_samples_ddim + 1.0) / 2.0, min=0.0, max=1.0)
                 x_sample = 255.0 * rearrange(x_sample[0].cpu().numpy(), "c h w -> h w c")
                 x_sample = x_sample.astype(np.uint8)
@@ -622,7 +697,8 @@
     # if thread_data.reduced_memory:
     #     unload_filters()
-    move_to_cpu(thread_data.modelFS)
+    if not thread_data.test_sd2:
+        move_to_cpu(thread_data.modelFS)
     del img_data
     gc()
     if thread_data.device != 'cpu':
@@ -664,7 +740,11 @@ def _txt2img(opt_W, opt_H, opt_n_samples, opt_ddim_steps, opt_scale, start_code,
 
     # Send to CPU and wait until complete.
     # wait_model_move_to(thread_data.modelCS, 'cpu')
-    move_to_cpu(thread_data.modelCS)
+    if not thread_data.test_sd2:
+        move_to_cpu(thread_data.modelCS)
+
+    if thread_data.test_sd2 and sampler_name not in ('plms', 'ddim'):
+        raise Exception('Only plms and ddim samplers are supported right now, in SD 2.0')
 
     if sampler_name == 'ddim':
         thread_data.model.make_schedule(ddim_num_steps=opt_ddim_steps, ddim_eta=opt_ddim_eta, verbose=False)
diff --git a/ui/server.py b/ui/server.py
index 61635f18..8e5343db 100644
--- a/ui/server.py
+++ b/ui/server.py
@@ -116,6 +116,8 @@ def setConfig(config):
         bind_ip = '0.0.0.0' if config['net']['listen_to_network'] else '127.0.0.1'
         config_bat.append(f"@set SD_UI_BIND_IP={bind_ip}")
 
+    config_bat.append(f"@set test_sd2={'Y' if config.get('test_sd2', False) else 'N'}")
+
     if len(config_bat) > 0:
         with open(config_bat_path, 'w', encoding='utf-8') as f:
             f.write('\r\n'.join(config_bat))
@@ -133,6 +135,8 @@
         bind_ip = '0.0.0.0' if config['net']['listen_to_network'] else '127.0.0.1'
         config_sh.append(f"export SD_UI_BIND_IP={bind_ip}")
 
+    config_sh.append(f"export test_sd2=\"{'Y' if config.get('test_sd2', False) else 'N'}\"")
+
     if len(config_sh) > 1:
         with open(config_sh_path, 'w', encoding='utf-8') as f:
            f.write('\n'.join(config_sh))
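
For reference, the `test_sd2` flag makes a full round trip through the changes above: the new checkbox in `parameters.js` saves `test_sd2` into the app config, `server.py` exports it as a `test_sd2=Y/N` environment variable via `config.bat`/`config.sh`, the start scripts branch on that variable to pick the fork commit and patch file, and `runtime.py`'s `isSD2()` reads the flag back out of `scripts/config.json` at thread init. A minimal standalone sketch of that read-back, assuming the same directory layout as the patch (the `'.'` fallback for an unset `SD_UI_PATH` is added here only to keep the sketch runnable; the committed code passes None through):

import json
import os

# Mirrors isSD2() from ui/sd_internal/runtime.py: resolve scripts/config.json
# relative to SD_UI_PATH and read the experimental flag back out.
sd_ui_dir = os.getenv('SD_UI_PATH', '.')  # set by the launcher; '.' is a sketch-only fallback
config_json_path = os.path.join(
    os.path.abspath(os.path.join(sd_ui_dir, '..', 'scripts')), 'config.json')

test_sd2 = False
if os.path.exists(config_json_path):
    with open(config_json_path, 'r', encoding='utf-8') as f:
        test_sd2 = json.load(f).get('test_sd2', False)

# on_sd_start.bat / on_sd_start.sh branch on the same flag (exported as Y/N by server.py):
#   N -> checkout 7f32368e... and apply ddim_callback.patch      (existing SD 1.x fork)
#   Y -> checkout 7da785cf... and apply ddim_callback_sd2.patch  (experimental SD 2.0 fork)
print('SD 2.0 test mode:', test_sd2)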