"""Request-driven txt2img / img2img runtime built on the optimizedSD models.

`load_model` populates the module-level model globals and `mk_img` serves a
single `Request`, returning a `Response` with base64-encoded PNG images.
"""
import os
import re
import traceback
import torch
import numpy as np
from omegaconf import OmegaConf
from PIL import Image
from tqdm import tqdm, trange
from itertools import islice
from einops import rearrange
import time
from pytorch_lightning import seed_everything
from torch import autocast
from contextlib import nullcontext
from einops import rearrange, repeat
from ldm.util import instantiate_from_config
from optimizedSD.optimUtils import split_weighted_subprompts
from transformers import logging

import uuid

logging.set_verbosity_error()

# consts
config_yaml = "optimizedSD/v1-inference.yaml"

# api stuff
from . import Request, Response, Image as ResponseImage
import base64
from io import BytesIO

# local
session_id = str(uuid.uuid4())[-8:]  # one short id per process; used as the output sub-folder name

ckpt = None
model = None
modelCS = None
modelFS = None
model_is_half = False
model_fs_is_half = False
device = None
unet_bs = 1
precision = 'autocast'
sampler_plms = None
sampler_ddim = None

# State-dict key parts that are routed to the "model1." prefix; every other
# "model.*" key goes to "model2.".
_MODEL1_KEY_PARTS = ("input_blocks", "middle_block", "time_embed")

# Characters that must not end up in a file name built from a user prompt
# (path separators, Windows-reserved characters, control characters).
_UNSAFE_FILENAME_CHARS = re.compile(r'[\\/<>:"|?*\x00-\x1f]')


# api
def load_model(ckpt_to_use, device_to_use='cuda', turbo=False, unet_bs_to_use=1, precision_to_use='autocast', half_model_fs=False):
    """Load the checkpoint and (re)build the three module-level models.

    Sets the globals ``ckpt``, ``device``, ``precision``, ``unet_bs``,
    ``model``, ``modelCS``, ``modelFS``, ``model_is_half`` and
    ``model_fs_is_half``.

    Args:
        ckpt_to_use: Path of the checkpoint file to load.
        device_to_use: Target device name, e.g. ``'cuda'`` or ``'cpu'``.
        turbo: Stored on the UNet model as ``model.turbo``.
        unet_bs_to_use: Stored on the UNet model as ``model.unet_bs``.
        precision_to_use: ``'autocast'`` converts ``model`` and ``modelCS`` to
            half precision when the device is not the CPU.
        half_model_fs: When true, ``modelFS`` is converted to half precision.
    """
    global ckpt, model, modelCS, modelFS, model_is_half, device, unet_bs, precision, model_fs_is_half

    ckpt = ckpt_to_use
    device = device_to_use
    precision = precision_to_use
    unet_bs = unet_bs_to_use

    sd = load_model_from_config(f"{ckpt}")

    # Collect the keys first; the dict is mutated only after iteration ends.
    model1_keys, model2_keys = [], []
    for key in sd:
        parts = key.split(".")
        if parts[0] != "model":
            continue
        if any(name in parts for name in _MODEL1_KEY_PARTS):
            model1_keys.append(key)
        else:
            model2_keys.append(key)

    # key[6:] drops the leading "model." before the new prefix is applied.
    # NOTE(review): model1/model2 are presumably the two halves of the split
    # optimizedSD UNet - confirm against the config.
    for key in model1_keys:
        sd["model1." + key[6:]] = sd.pop(key)
    for key in model2_keys:
        sd["model2." + key[6:]] = sd.pop(key)

    config = OmegaConf.load(f"{config_yaml}")

    model = instantiate_from_config(config.modelUNet)
    model.load_state_dict(sd, strict=False)
    model.eval()
    model.cdevice = device
    model.unet_bs = unet_bs
    model.turbo = turbo

    modelCS = instantiate_from_config(config.modelCondStage)
    modelCS.load_state_dict(sd, strict=False)
    modelCS.eval()
    modelCS.cond_stage_model.device = device

    modelFS = instantiate_from_config(config.modelFirstStage)
    modelFS.load_state_dict(sd, strict=False)
    modelFS.eval()
    del sd

    model_is_half = device != "cpu" and precision == "autocast"
    if model_is_half:
        model.half()
        modelCS.half()

    model_fs_is_half = bool(half_model_fs)
    if model_fs_is_half:
        modelFS.half()


def mk_img(req: Request):
    """Generate images for ``req`` and return them in a ``Response``.

    Reads from ``req``: ``turbo``, ``use_cpu``, ``use_full_precision``,
    ``init_image``, ``prompt``, ``seed``, ``num_outputs``, ``guidance_scale``,
    ``height``, ``width``, ``num_inference_steps``, ``prompt_strength`` and
    ``save_to_disk_path``. Runs img2img when ``req.init_image`` is set and
    txt2img otherwise. The model is reloaded first when its current
    device/precision state does not match what the request needs.

    Each produced image is appended to ``res.images`` as a base64 PNG data URL
    together with its seed (the request seed, incremented per image). When
    ``req.save_to_disk_path`` is set, the image and a small metadata text file
    are also written below ``<save_to_disk_path>/<session_id>/``; a failed
    save is reported on stdout and does not abort the request.

    Raises:
        AssertionError: if the prompt is ``None`` or, for img2img, the prompt
            strength is outside ``[0.0, 1.0]``.
    """
    global device

    res = Response()
    res.images = []

    model.turbo = req.turbo
    if req.use_cpu:
        device = 'cpu'

        if model_is_half or model_fs_is_half:
            print('reloading model for cpu')
            # Pass turbo and unet_bs through: the load_model defaults would
            # otherwise silently reset them on the freshly built model.
            load_model(ckpt, device, req.turbo, unet_bs)
    else:
        device = 'cuda'

        wanted_precision = 'full' if req.use_full_precision else 'autocast'
        wanted_half_fs = req.init_image is not None and not req.use_full_precision

        # Reload only when the loaded state differs from the wanted one.
        # (Checking "img2img and FS not half" on its own would reload the
        # checkpoint on every full-precision img2img request.)
        if (precision != wanted_precision
                or model_is_half != (wanted_precision == 'autocast')
                or model_fs_is_half != wanted_half_fs):
            print('reloading model for cuda')
            load_model(ckpt, device, model.turbo, unet_bs, wanted_precision, half_model_fs=wanted_half_fs)

    model.cdevice = device
    modelCS.cond_stage_model.device = device

    opt_prompt = req.prompt
    opt_seed = req.seed
    opt_n_samples = req.num_outputs
    opt_n_iter = 1
    opt_scale = req.guidance_scale
    opt_C = 4
    opt_H = req.height
    opt_W = req.width
    opt_f = 8
    opt_ddim_steps = req.num_inference_steps
    opt_ddim_eta = 0.0
    opt_strength = req.prompt_strength
    opt_save_to_disk_path = req.save_to_disk_path
    opt_init_img = req.init_image
    opt_format = 'png'

    print(req.to_string(), '\n    device', device)

    seed_everything(opt_seed)

    batch_size = opt_n_samples
    prompt = opt_prompt
    assert prompt is not None
    data = [batch_size * [prompt]]

    if precision == "autocast" and device != "cpu":
        precision_scope = autocast
    else:
        precision_scope = nullcontext

    if opt_init_img is None:
        handler = _txt2img

        init_latent = None
        t_enc = None
    else:
        handler = _img2img

        init_image = load_img(opt_init_img)
        init_image = init_image.to(device)

        if device != "cpu" and precision == "autocast":
            init_image = init_image.half()

        modelFS.to(device)

        init_image = repeat(init_image, '1 ... -> b ...', b=batch_size)
        init_latent = modelFS.get_first_stage_encoding(modelFS.encode_first_stage(init_image))  # move to latent space

        _offload_to_cpu(modelFS)

        assert 0. <= opt_strength <= 1., 'can only work with strength in [0.0, 1.0]'
        t_enc = int(opt_strength * opt_ddim_steps)
        print(f"target t_enc is {t_enc} steps")

    if opt_save_to_disk_path is not None:
        session_out_path = os.path.join(opt_save_to_disk_path, session_id)
        os.makedirs(session_out_path, exist_ok=True)
    else:
        session_out_path = None

    with torch.no_grad():
        for n in trange(opt_n_iter, desc="Sampling"):
            for prompts in tqdm(data, desc="data"):

                with precision_scope("cuda"):
                    modelCS.to(device)
                    uc = None
                    if opt_scale != 1.0:
                        uc = modelCS.get_learned_conditioning(batch_size * [""])
                    if isinstance(prompts, tuple):
                        prompts = list(prompts)

                    subprompts, weights = split_weighted_subprompts(prompts[0])
                    if len(subprompts) > 1:
                        # uc is None when guidance is disabled (scale == 1.0);
                        # an empty-prompt encoding then serves purely as the
                        # shape/dtype template for the accumulator.
                        template = uc if uc is not None else modelCS.get_learned_conditioning(batch_size * [""])
                        c = torch.zeros_like(template)
                        total_weight = sum(weights)
                        # normalize each "sub prompt" and add it
                        for subprompt, weight in zip(subprompts, weights):
                            c = torch.add(c, modelCS.get_learned_conditioning(subprompt), alpha=weight / total_weight)
                    else:
                        c = modelCS.get_learned_conditioning(prompts)

                    # run the handler
                    if handler == _txt2img:
                        x_samples = _txt2img(opt_W, opt_H, opt_n_samples, opt_ddim_steps, opt_scale, None, opt_C, opt_f, opt_ddim_eta, c, uc, opt_seed)
                    else:
                        x_samples = _img2img(init_latent, t_enc, batch_size, opt_scale, c, uc, opt_ddim_steps, opt_ddim_eta, opt_seed)

                    modelFS.to(device)

                    print("saving images")
                    for i in range(batch_size):
                        # Decode one latent at a time.
                        x_samples_ddim = modelFS.decode_first_stage(x_samples[i].unsqueeze(0))
                        x_sample = torch.clamp((x_samples_ddim + 1.0) / 2.0, min=0.0, max=1.0)
                        x_sample = 255.0 * rearrange(x_sample[0].cpu().numpy(), "c h w -> h w c")
                        img = Image.fromarray(x_sample.astype(np.uint8))
                        img_data = img_to_base64_str(img)
                        res.images.append(ResponseImage(data=img_data, seed=opt_seed))

                        if opt_save_to_disk_path is not None:
                            # Best effort: a failed save must not lose the
                            # in-memory result already added to the response.
                            try:
                                img_id = str(uuid.uuid4())[-8:]
                                file_path = f"{_prompt_to_file_stem(prompts[0])}_{img_id}"
                                img_out_path = os.path.join(session_out_path, f"{file_path}.{opt_format}")
                                meta_out_path = os.path.join(session_out_path, f"{file_path}.txt")

                                metadata = f"{prompts[0]}\nWidth: {opt_W}\nHeight: {opt_H}\nSeed: {opt_seed}\nSteps: {opt_ddim_steps}\nGuidance Scale: {opt_scale}"
                                img.save(img_out_path)
                                with open(meta_out_path, 'w') as f:
                                    f.write(metadata)
                            except Exception:
                                print('could not save the file', traceback.format_exc())

                        opt_seed += 1

                    _offload_to_cpu(modelFS)

                    del x_samples
                    print("memory_final = ", torch.cuda.memory_allocated() / 1e6)

    return res


def _txt2img(opt_W, opt_H, opt_n_samples, opt_ddim_steps, opt_scale, start_code, opt_C, opt_f, opt_ddim_eta, c, uc, opt_seed):
    """Sample ``opt_n_samples`` latents from the prompt conditioning ``c``.

    The latent shape is ``[opt_n_samples, opt_C, opt_H // opt_f, opt_W // opt_f]``.
    ``modelCS`` is moved off the GPU before sampling starts. Returns whatever
    ``model.sample`` returns (called with ``sampler='plms'``).
    """
    shape = [opt_n_samples, opt_C, opt_H // opt_f, opt_W // opt_f]

    _offload_to_cpu(modelCS)

    samples_ddim = model.sample(
        S=opt_ddim_steps,
        conditioning=c,
        seed=opt_seed,
        shape=shape,
        verbose=False,
        unconditional_guidance_scale=opt_scale,
        unconditional_conditioning=uc,
        eta=opt_ddim_eta,
        x_T=start_code,
        sampler='plms',
    )

    return samples_ddim


def _img2img(init_latent, t_enc, batch_size, opt_scale, c, uc, opt_ddim_steps, opt_ddim_eta, opt_seed):
    """Encode ``init_latent`` up to step ``t_enc`` and sample back from there.

    Returns whatever ``model.sample`` returns (called with ``sampler='ddim'``).
    """
    # encode (scaled latent)
    z_enc = model.stochastic_encode(
        init_latent,
        torch.tensor([t_enc] * batch_size).to(device),
        opt_seed,
        opt_ddim_eta,
        opt_ddim_steps,
    )
    # decode it
    samples_ddim = model.sample(
        t_enc,
        c,
        z_enc,
        unconditional_guidance_scale=opt_scale,
        unconditional_conditioning=uc,
        sampler='ddim',
    )

    return samples_ddim


# internal
def _offload_to_cpu(module):
    """Move ``module`` to the CPU and wait for CUDA allocated memory to drop.

    Does nothing when the active global ``device`` is the CPU. The wait polls
    ``torch.cuda.memory_allocated()`` once per second until it reports less
    than it did before the move.
    """
    if device == "cpu":
        return

    mem = torch.cuda.memory_allocated() / 1e6
    module.to("cpu")
    while torch.cuda.memory_allocated() / 1e6 >= mem:
        time.sleep(1)


def _prompt_to_file_stem(prompt):
    """Turn a prompt into a file-name stem of at most 50 characters.

    Colons and spaces become underscores and commas are dropped. Path
    separators and other characters that are unsafe in file names are replaced
    with underscores, so a prompt can never address a path outside the output
    folder.
    """
    stem = "_".join(re.split(":| ", prompt))
    stem = stem.replace(',', '')
    stem = _UNSAFE_FILENAME_CHARS.sub('_', stem)
    return stem[:50]


def chunk(it, size):
    """Yield successive tuples of up to ``size`` items from ``it``."""
    it = iter(it)
    return iter(lambda: tuple(islice(it, size)), ())


def load_model_from_config(ckpt, verbose=False):
    """Load the checkpoint at ``ckpt`` onto the CPU and return its ``state_dict``.

    ``verbose`` is accepted for interface compatibility and is not used.
    """
    print(f"Loading model from {ckpt}")
    # SECURITY NOTE: torch.load unpickles the file, which can execute arbitrary
    # code. Only load checkpoints obtained from a trusted source.
    pl_sd = torch.load(ckpt, map_location="cpu")
    if "global_step" in pl_sd:
        print(f"Global Step: {pl_sd['global_step']}")
    sd = pl_sd["state_dict"]
    return sd


# utils
def load_img(img_str):
    """Decode a base64 image string into a ``1 x 3 x H x W`` float tensor in ``[-1, 1]``.

    The image is converted to RGB and resized so that both sides are rounded
    down to a multiple of 64.
    """
    image = base64_str_to_img(img_str).convert("RGB")
    w, h = image.size
    print(f"loaded input image of size ({w}, {h}) from base64")
    w, h = (side - side % 64 for side in (w, h))  # resize to integer multiple of 64
    image = image.resize((w, h), resample=Image.LANCZOS)
    image = np.array(image).astype(np.float32) / 255.0
    image = image[None].transpose(0, 3, 1, 2)  # add batch axis, HWC -> CHW
    image = torch.from_numpy(image)
    return 2.*image - 1.


# https://stackoverflow.com/a/61114178
def img_to_base64_str(img):
    """Encode a PIL image as a ``data:image/png;base64,...`` string."""
    buffered = BytesIO()
    img.save(buffered, format="PNG")
    img_byte = buffered.getvalue()
    img_str = "data:image/png;base64," + base64.b64encode(img_byte).decode()
    return img_str


def base64_str_to_img(img_str):
    """Decode a base64 string, with or without a ``data:`` URL header, into a PIL image.

    The header is stripped up to the first comma rather than by a fixed
    length, so any ``data:<mime>;base64,`` prefix and plain base64 input are
    both handled.
    """
    if img_str.startswith("data:"):
        _, _, img_str = img_str.partition(",")
    data = base64.b64decode(img_str)
    buffered = BytesIO(data)
    img = Image.open(buffered)
    return img