mirror of
https://github.com/easydiffusion/easydiffusion.git
synced 2025-08-13 17:57:20 +02:00
Keep v2 files in the repo, for the updater
This commit is contained in:
342
ui/sd_internal/runtime.py
Normal file
342
ui/sd_internal/runtime.py
Normal file
@ -0,0 +1,342 @@
|
||||
import os, re
|
||||
import torch
|
||||
import numpy as np
|
||||
from omegaconf import OmegaConf
|
||||
from PIL import Image
|
||||
from tqdm import tqdm, trange
|
||||
from itertools import islice
|
||||
from einops import rearrange
|
||||
import time
|
||||
from pytorch_lightning import seed_everything
|
||||
from torch import autocast
|
||||
from contextlib import nullcontext
|
||||
from einops import rearrange, repeat
|
||||
from ldm.util import instantiate_from_config
|
||||
from optimizedSD.optimUtils import split_weighted_subprompts
|
||||
from transformers import logging
|
||||
|
||||
logging.set_verbosity_error()
|
||||
|
||||
# consts
|
||||
config_yaml = "optimizedSD/v1-inference.yaml"
|
||||
|
||||
# api stuff
|
||||
from . import Request, Response, Image as ResponseImage
|
||||
import base64
|
||||
from io import BytesIO
|
||||
|
||||
# local
|
||||
ckpt = None
|
||||
model = None
|
||||
modelCS = None
|
||||
modelFS = None
|
||||
model_is_half = False
|
||||
model_fs_is_half = False
|
||||
device = None
|
||||
unet_bs = 1
|
||||
precision = 'autocast'
|
||||
sampler_plms = None
|
||||
sampler_ddim = None
|
||||
|
||||
# api
|
||||
def load_model(ckpt_to_use, device_to_use='cuda', turbo=False, unet_bs_to_use=1, precision_to_use='autocast', half_model_fs=False):
|
||||
global ckpt, model, modelCS, modelFS, model_is_half, device, unet_bs, precision, model_fs_is_half
|
||||
|
||||
ckpt = ckpt_to_use
|
||||
device = device_to_use
|
||||
precision = precision_to_use
|
||||
unet_bs = unet_bs_to_use
|
||||
|
||||
sd = load_model_from_config(f"{ckpt}")
|
||||
li, lo = [], []
|
||||
for key, value in sd.items():
|
||||
sp = key.split(".")
|
||||
if (sp[0]) == "model":
|
||||
if "input_blocks" in sp:
|
||||
li.append(key)
|
||||
elif "middle_block" in sp:
|
||||
li.append(key)
|
||||
elif "time_embed" in sp:
|
||||
li.append(key)
|
||||
else:
|
||||
lo.append(key)
|
||||
for key in li:
|
||||
sd["model1." + key[6:]] = sd.pop(key)
|
||||
for key in lo:
|
||||
sd["model2." + key[6:]] = sd.pop(key)
|
||||
|
||||
config = OmegaConf.load(f"{config_yaml}")
|
||||
|
||||
model = instantiate_from_config(config.modelUNet)
|
||||
_, _ = model.load_state_dict(sd, strict=False)
|
||||
model.eval()
|
||||
model.cdevice = device
|
||||
model.unet_bs = unet_bs
|
||||
model.turbo = turbo
|
||||
|
||||
modelCS = instantiate_from_config(config.modelCondStage)
|
||||
_, _ = modelCS.load_state_dict(sd, strict=False)
|
||||
modelCS.eval()
|
||||
modelCS.cond_stage_model.device = device
|
||||
|
||||
modelFS = instantiate_from_config(config.modelFirstStage)
|
||||
_, _ = modelFS.load_state_dict(sd, strict=False)
|
||||
modelFS.eval()
|
||||
del sd
|
||||
|
||||
if device != "cpu" and precision == "autocast":
|
||||
model.half()
|
||||
modelCS.half()
|
||||
model_is_half = True
|
||||
else:
|
||||
model_is_half = False
|
||||
|
||||
if half_model_fs:
|
||||
modelFS.half()
|
||||
model_fs_is_half = True
|
||||
else:
|
||||
model_fs_is_half = False
|
||||
|
||||
def mk_img(req: Request):
|
||||
global modelFS, device
|
||||
|
||||
res = Response()
|
||||
res.images = []
|
||||
|
||||
model.turbo = req.turbo
|
||||
if req.use_cpu:
|
||||
device = 'cpu'
|
||||
|
||||
if model_is_half:
|
||||
print('reloading model for cpu')
|
||||
load_model(ckpt, device)
|
||||
else:
|
||||
device = 'cuda'
|
||||
|
||||
if (precision == 'autocast' and (req.use_full_precision or not model_is_half)) or \
|
||||
(precision == 'full' and not req.use_full_precision) or \
|
||||
(req.init_image is None and model_fs_is_half) or \
|
||||
(req.init_image is not None and not model_fs_is_half):
|
||||
|
||||
print('reloading model for cuda')
|
||||
load_model(ckpt, device, model.turbo, unet_bs, ('full' if req.use_full_precision else 'autocast'), half_model_fs=(req.init_image is not None and not req.use_full_precision))
|
||||
|
||||
model.cdevice = device
|
||||
modelCS.cond_stage_model.device = device
|
||||
|
||||
opt_prompt = req.prompt
|
||||
opt_seed = req.seed
|
||||
opt_n_samples = req.num_outputs
|
||||
opt_n_iter = 1
|
||||
opt_scale = req.guidance_scale
|
||||
opt_C = 4
|
||||
opt_H = req.height
|
||||
opt_W = req.width
|
||||
opt_f = 8
|
||||
opt_ddim_steps = req.num_inference_steps
|
||||
opt_ddim_eta = 0.0
|
||||
opt_strength = req.prompt_strength
|
||||
opt_save_to_disk_path = req.save_to_disk_path
|
||||
opt_init_img = req.init_image
|
||||
opt_format = 'png'
|
||||
|
||||
print(req.to_string(), 'device', device)
|
||||
|
||||
seed_everything(opt_seed)
|
||||
|
||||
batch_size = opt_n_samples
|
||||
prompt = opt_prompt
|
||||
assert prompt is not None
|
||||
data = [batch_size * [prompt]]
|
||||
|
||||
if precision == "autocast" and device != "cpu":
|
||||
precision_scope = autocast
|
||||
else:
|
||||
precision_scope = nullcontext
|
||||
|
||||
if req.init_image is None:
|
||||
handler = _txt2img
|
||||
|
||||
init_latent = None
|
||||
t_enc = None
|
||||
else:
|
||||
handler = _img2img
|
||||
|
||||
init_image = load_img(req.init_image)
|
||||
init_image = init_image.to(device)
|
||||
|
||||
if device != "cpu" and precision == "autocast":
|
||||
init_image = init_image.half()
|
||||
|
||||
modelFS.to(device)
|
||||
|
||||
init_image = repeat(init_image, '1 ... -> b ...', b=batch_size)
|
||||
init_latent = modelFS.get_first_stage_encoding(modelFS.encode_first_stage(init_image)) # move to latent space
|
||||
|
||||
if device != "cpu":
|
||||
mem = torch.cuda.memory_allocated() / 1e6
|
||||
modelFS.to("cpu")
|
||||
while torch.cuda.memory_allocated() / 1e6 >= mem:
|
||||
time.sleep(1)
|
||||
|
||||
assert 0. <= opt_strength <= 1., 'can only work with strength in [0.0, 1.0]'
|
||||
t_enc = int(opt_strength * opt_ddim_steps)
|
||||
print(f"target t_enc is {t_enc} steps")
|
||||
|
||||
seeds = ""
|
||||
with torch.no_grad():
|
||||
for n in trange(opt_n_iter, desc="Sampling"):
|
||||
for prompts in tqdm(data, desc="data"):
|
||||
|
||||
if opt_save_to_disk_path is not None:
|
||||
sample_path = os.path.join(opt_save_to_disk_path, "_".join(re.split(":| ", prompts[0])))[:150]
|
||||
os.makedirs(sample_path, exist_ok=True)
|
||||
base_count = len(os.listdir(sample_path))
|
||||
|
||||
with precision_scope("cuda"):
|
||||
modelCS.to(device)
|
||||
uc = None
|
||||
if opt_scale != 1.0:
|
||||
uc = modelCS.get_learned_conditioning(batch_size * [""])
|
||||
if isinstance(prompts, tuple):
|
||||
prompts = list(prompts)
|
||||
|
||||
subprompts, weights = split_weighted_subprompts(prompts[0])
|
||||
if len(subprompts) > 1:
|
||||
c = torch.zeros_like(uc)
|
||||
totalWeight = sum(weights)
|
||||
# normalize each "sub prompt" and add it
|
||||
for i in range(len(subprompts)):
|
||||
weight = weights[i]
|
||||
# if not skip_normalize:
|
||||
weight = weight / totalWeight
|
||||
c = torch.add(c, modelCS.get_learned_conditioning(subprompts[i]), alpha=weight)
|
||||
else:
|
||||
c = modelCS.get_learned_conditioning(prompts)
|
||||
|
||||
# run the handler
|
||||
if handler == _txt2img:
|
||||
x_samples = _txt2img(opt_W, opt_H, opt_n_samples, opt_ddim_steps, opt_scale, None, opt_C, opt_f, opt_ddim_eta, c, uc, opt_seed)
|
||||
else:
|
||||
x_samples = _img2img(init_latent, t_enc, batch_size, opt_scale, c, uc, opt_ddim_steps, opt_ddim_eta, opt_seed)
|
||||
|
||||
modelFS.to(device)
|
||||
|
||||
print("saving images")
|
||||
for i in range(batch_size):
|
||||
|
||||
x_samples_ddim = modelFS.decode_first_stage(x_samples[i].unsqueeze(0))
|
||||
x_sample = torch.clamp((x_samples_ddim + 1.0) / 2.0, min=0.0, max=1.0)
|
||||
x_sample = 255.0 * rearrange(x_sample[0].cpu().numpy(), "c h w -> h w c")
|
||||
img = Image.fromarray(x_sample.astype(np.uint8))
|
||||
|
||||
img_data = img_to_base64_str(img)
|
||||
res.images.append(ResponseImage(data=img_data, seed=opt_seed))
|
||||
|
||||
if opt_save_to_disk_path is not None:
|
||||
img.save(
|
||||
os.path.join(sample_path, "seed_" + str(opt_seed) + "_" + f"{base_count:05}.{opt_format}")
|
||||
)
|
||||
base_count += 1
|
||||
|
||||
seeds += str(opt_seed) + ","
|
||||
opt_seed += 1
|
||||
|
||||
if device != "cpu":
|
||||
mem = torch.cuda.memory_allocated() / 1e6
|
||||
modelFS.to("cpu")
|
||||
while torch.cuda.memory_allocated() / 1e6 >= mem:
|
||||
time.sleep(1)
|
||||
del x_samples
|
||||
print("memory_final = ", torch.cuda.memory_allocated() / 1e6)
|
||||
|
||||
return res
|
||||
|
||||
def _txt2img(opt_W, opt_H, opt_n_samples, opt_ddim_steps, opt_scale, start_code, opt_C, opt_f, opt_ddim_eta, c, uc, opt_seed):
|
||||
shape = [opt_n_samples, opt_C, opt_H // opt_f, opt_W // opt_f]
|
||||
|
||||
if device != "cpu":
|
||||
mem = torch.cuda.memory_allocated() / 1e6
|
||||
modelCS.to("cpu")
|
||||
while torch.cuda.memory_allocated() / 1e6 >= mem:
|
||||
time.sleep(1)
|
||||
|
||||
samples_ddim = model.sample(
|
||||
S=opt_ddim_steps,
|
||||
conditioning=c,
|
||||
seed=opt_seed,
|
||||
shape=shape,
|
||||
verbose=False,
|
||||
unconditional_guidance_scale=opt_scale,
|
||||
unconditional_conditioning=uc,
|
||||
eta=opt_ddim_eta,
|
||||
x_T=start_code,
|
||||
sampler = 'plms',
|
||||
)
|
||||
|
||||
return samples_ddim
|
||||
|
||||
def _img2img(init_latent, t_enc, batch_size, opt_scale, c, uc, opt_ddim_steps, opt_ddim_eta, opt_seed):
|
||||
# encode (scaled latent)
|
||||
z_enc = model.stochastic_encode(
|
||||
init_latent,
|
||||
torch.tensor([t_enc] * batch_size).to(device),
|
||||
opt_seed,
|
||||
opt_ddim_eta,
|
||||
opt_ddim_steps,
|
||||
)
|
||||
# decode it
|
||||
samples_ddim = model.sample(
|
||||
t_enc,
|
||||
c,
|
||||
z_enc,
|
||||
unconditional_guidance_scale=opt_scale,
|
||||
unconditional_conditioning=uc,
|
||||
sampler = 'ddim'
|
||||
)
|
||||
|
||||
return samples_ddim
|
||||
|
||||
# internal
|
||||
|
||||
def chunk(it, size):
|
||||
it = iter(it)
|
||||
return iter(lambda: tuple(islice(it, size)), ())
|
||||
|
||||
|
||||
def load_model_from_config(ckpt, verbose=False):
|
||||
print(f"Loading model from {ckpt}")
|
||||
pl_sd = torch.load(ckpt, map_location="cpu")
|
||||
if "global_step" in pl_sd:
|
||||
print(f"Global Step: {pl_sd['global_step']}")
|
||||
sd = pl_sd["state_dict"]
|
||||
return sd
|
||||
|
||||
# utils
|
||||
|
||||
def load_img(img_str):
|
||||
image = base64_str_to_img(img_str).convert("RGB")
|
||||
w, h = image.size
|
||||
print(f"loaded input image of size ({w}, {h}) from base64")
|
||||
w, h = map(lambda x: x - x % 64, (w, h)) # resize to integer multiple of 64
|
||||
image = image.resize((w, h), resample=Image.LANCZOS)
|
||||
image = np.array(image).astype(np.float32) / 255.0
|
||||
image = image[None].transpose(0, 3, 1, 2)
|
||||
image = torch.from_numpy(image)
|
||||
return 2.*image - 1.
|
||||
|
||||
# https://stackoverflow.com/a/61114178
|
||||
def img_to_base64_str(img):
|
||||
buffered = BytesIO()
|
||||
img.save(buffered, format="PNG")
|
||||
buffered.seek(0)
|
||||
img_byte = buffered.getvalue()
|
||||
img_str = "data:image/png;base64," + base64.b64encode(img_byte).decode()
|
||||
return img_str
|
||||
|
||||
def base64_str_to_img(img_str):
|
||||
img_str = img_str[len("data:image/png;base64,"):]
|
||||
data = base64.b64decode(img_str)
|
||||
buffered = BytesIO(data)
|
||||
img = Image.open(buffered)
|
||||
return img
|
Reference in New Issue
Block a user