diff --git a/scripts/on_sd_start.bat b/scripts/on_sd_start.bat index 8b2e3e02..57337107 100644 --- a/scripts/on_sd_start.bat +++ b/scripts/on_sd_start.bat @@ -46,7 +46,7 @@ if NOT DEFINED test_sd2 set test_sd2=N @call git apply --whitespace=warn ..\ui\sd_internal\ddim_callback.patch ) if "%test_sd2%" == "Y" ( - @call git -c advice.detachedHead=false checkout 7da785cfa0d128368bc1357b54d380ba33dc1138 + @call git -c advice.detachedHead=false checkout 5a14697a8f4a43a56b575a0b1d02a48b37fb9b94 @call git apply --whitespace=warn ..\ui\sd_internal\ddim_callback_sd2.patch ) diff --git a/scripts/on_sd_start.sh b/scripts/on_sd_start.sh index 133dac3e..a61c5e1c 100644 --- a/scripts/on_sd_start.sh +++ b/scripts/on_sd_start.sh @@ -40,7 +40,7 @@ if [ -e "scripts/install_status.txt" ] && [ `grep -c sd_git_cloned scripts/insta git apply --whitespace=warn ../ui/sd_internal/ddim_callback.patch || fail "ddim patch failed" elif [ "$test_sd2" == "Y" ]; then - git -c advice.detachedHead=false checkout 7da785cfa0d128368bc1357b54d380ba33dc1138 + git -c advice.detachedHead=false checkout 5a14697a8f4a43a56b575a0b1d02a48b37fb9b94 git apply --whitespace=warn ../ui/sd_internal/ddim_callback_sd2.patch || fail "sd2 ddim patch failed" fi diff --git a/ui/sd_internal/runtime.py b/ui/sd_internal/runtime.py index d9521780..b5b61f65 100644 --- a/ui/sd_internal/runtime.py +++ b/ui/sd_internal/runtime.py @@ -21,7 +21,6 @@ from torch import autocast from contextlib import nullcontext from einops import rearrange, repeat from ldm.util import instantiate_from_config -from optimizedSD.optimUtils import split_weighted_subprompts from transformers import logging from gfpgan import GFPGANer @@ -209,7 +208,8 @@ def load_model_ckpt_sd1(): using precision: {thread_data.precision}''') def load_model_ckpt_sd2(): - config = OmegaConf.load('configs/stable-diffusion/v2-inference-v.yaml') + config_file = 'configs/stable-diffusion/v2-inference-v.yaml' if 'sd2_' in thread_data.ckpt_file else "configs/stable-diffusion/v1-inference.yaml" + config = OmegaConf.load(config_file) verbose = False sd = load_model_from_config(thread_data.ckpt_file + '.ckpt') @@ -875,3 +875,48 @@ def base64_str_to_img(img_str): buffered = base64_str_to_buffer(img_str) img = Image.open(buffered) return img + +def split_weighted_subprompts(text): + """ + grabs all text up to the first occurrence of ':' + uses the grabbed text as a sub-prompt, and takes the value following ':' as weight + if ':' has no value defined, defaults to 1.0 + repeats until no text remaining + """ + remaining = len(text) + prompts = [] + weights = [] + while remaining > 0: + if ":" in text: + idx = text.index(":") # first occurrence from start + # grab up to index as sub-prompt + prompt = text[:idx] + remaining -= idx + # remove from main text + text = text[idx+1:] + # find value for weight + if " " in text: + idx = text.index(" ") # first occurence + else: # no space, read to end + idx = len(text) + if idx != 0: + try: + weight = float(text[:idx]) + except: # couldn't treat as float + print(f"Warning: '{text[:idx]}' is not a value, are you missing a space?") + weight = 1.0 + else: # no value found + weight = 1.0 + # remove from main text + remaining -= idx + text = text[idx+1:] + # append the sub-prompt and its weight + prompts.append(prompt) + weights.append(weight) + else: # no : found + if len(text) > 0: # there is still text though + # take remainder as weight 1 + prompts.append(text) + weights.append(1.0) + remaining = 0 + return prompts, weights diff --git a/ui/server.py b/ui/server.py index 8e5343db..fffbb43c 100644 --- a/ui/server.py +++ b/ui/server.py @@ -192,6 +192,7 @@ class SetAppConfigRequest(BaseModel): ui_open_browser_on_start: bool = None listen_to_network: bool = None listen_port: int = None + test_sd2: bool = None @app.post('/app_config') async def setAppConfig(req : SetAppConfigRequest): @@ -212,6 +213,8 @@ async def setAppConfig(req : SetAppConfigRequest): if 'net' not in config: config['net'] = {} config['net']['listen_port'] = int(req.listen_port) + if req.test_sd2 is not None: + config['test_sd2'] = req.test_sd2 try: setConfig(config)