mirror of
https://github.com/easydiffusion/easydiffusion.git
synced 2024-11-22 00:03:20 +01:00
Tweaks to load sd1 models in sd2 code, typos
This commit is contained in:
parent
ea7b28c9d5
commit
02dd3e457d
@ -46,7 +46,7 @@ if NOT DEFINED test_sd2 set test_sd2=N
|
||||
@call git apply --whitespace=warn ..\ui\sd_internal\ddim_callback.patch
|
||||
)
|
||||
if "%test_sd2%" == "Y" (
|
||||
@call git -c advice.detachedHead=false checkout 7da785cfa0d128368bc1357b54d380ba33dc1138
|
||||
@call git -c advice.detachedHead=false checkout 5a14697a8f4a43a56b575a0b1d02a48b37fb9b94
|
||||
|
||||
@call git apply --whitespace=warn ..\ui\sd_internal\ddim_callback_sd2.patch
|
||||
)
|
||||
|
@ -40,7 +40,7 @@ if [ -e "scripts/install_status.txt" ] && [ `grep -c sd_git_cloned scripts/insta
|
||||
|
||||
git apply --whitespace=warn ../ui/sd_internal/ddim_callback.patch || fail "ddim patch failed"
|
||||
elif [ "$test_sd2" == "Y" ]; then
|
||||
git -c advice.detachedHead=false checkout 7da785cfa0d128368bc1357b54d380ba33dc1138
|
||||
git -c advice.detachedHead=false checkout 5a14697a8f4a43a56b575a0b1d02a48b37fb9b94
|
||||
|
||||
git apply --whitespace=warn ../ui/sd_internal/ddim_callback_sd2.patch || fail "sd2 ddim patch failed"
|
||||
fi
|
||||
|
@ -21,7 +21,6 @@ from torch import autocast
|
||||
from contextlib import nullcontext
|
||||
from einops import rearrange, repeat
|
||||
from ldm.util import instantiate_from_config
|
||||
from optimizedSD.optimUtils import split_weighted_subprompts
|
||||
from transformers import logging
|
||||
|
||||
from gfpgan import GFPGANer
|
||||
@ -209,7 +208,8 @@ def load_model_ckpt_sd1():
|
||||
using precision: {thread_data.precision}''')
|
||||
|
||||
def load_model_ckpt_sd2():
|
||||
config = OmegaConf.load('configs/stable-diffusion/v2-inference-v.yaml')
|
||||
config_file = 'configs/stable-diffusion/v2-inference-v.yaml' if 'sd2_' in thread_data.ckpt_file else "configs/stable-diffusion/v1-inference.yaml"
|
||||
config = OmegaConf.load(config_file)
|
||||
verbose = False
|
||||
|
||||
sd = load_model_from_config(thread_data.ckpt_file + '.ckpt')
|
||||
@ -875,3 +875,48 @@ def base64_str_to_img(img_str):
|
||||
buffered = base64_str_to_buffer(img_str)
|
||||
img = Image.open(buffered)
|
||||
return img
|
||||
|
||||
def split_weighted_subprompts(text):
|
||||
"""
|
||||
grabs all text up to the first occurrence of ':'
|
||||
uses the grabbed text as a sub-prompt, and takes the value following ':' as weight
|
||||
if ':' has no value defined, defaults to 1.0
|
||||
repeats until no text remaining
|
||||
"""
|
||||
remaining = len(text)
|
||||
prompts = []
|
||||
weights = []
|
||||
while remaining > 0:
|
||||
if ":" in text:
|
||||
idx = text.index(":") # first occurrence from start
|
||||
# grab up to index as sub-prompt
|
||||
prompt = text[:idx]
|
||||
remaining -= idx
|
||||
# remove from main text
|
||||
text = text[idx+1:]
|
||||
# find value for weight
|
||||
if " " in text:
|
||||
idx = text.index(" ") # first occurence
|
||||
else: # no space, read to end
|
||||
idx = len(text)
|
||||
if idx != 0:
|
||||
try:
|
||||
weight = float(text[:idx])
|
||||
except: # couldn't treat as float
|
||||
print(f"Warning: '{text[:idx]}' is not a value, are you missing a space?")
|
||||
weight = 1.0
|
||||
else: # no value found
|
||||
weight = 1.0
|
||||
# remove from main text
|
||||
remaining -= idx
|
||||
text = text[idx+1:]
|
||||
# append the sub-prompt and its weight
|
||||
prompts.append(prompt)
|
||||
weights.append(weight)
|
||||
else: # no : found
|
||||
if len(text) > 0: # there is still text though
|
||||
# take remainder as weight 1
|
||||
prompts.append(text)
|
||||
weights.append(1.0)
|
||||
remaining = 0
|
||||
return prompts, weights
|
||||
|
@ -192,6 +192,7 @@ class SetAppConfigRequest(BaseModel):
|
||||
ui_open_browser_on_start: bool = None
|
||||
listen_to_network: bool = None
|
||||
listen_port: int = None
|
||||
test_sd2: bool = None
|
||||
|
||||
@app.post('/app_config')
|
||||
async def setAppConfig(req : SetAppConfigRequest):
|
||||
@ -212,6 +213,8 @@ async def setAppConfig(req : SetAppConfigRequest):
|
||||
if 'net' not in config:
|
||||
config['net'] = {}
|
||||
config['net']['listen_port'] = int(req.listen_port)
|
||||
if req.test_sd2 is not None:
|
||||
config['test_sd2'] = req.test_sd2
|
||||
try:
|
||||
setConfig(config)
|
||||
|
||||
|
Loading…
Reference in New Issue
Block a user