"""runtime.py: torch device owned by a thread.
Notes:
    Avoid device switching; transferring all models between devices would get too complex.
    To use a different device, signal the current render device to exit,
    and then start a new clean thread for the new device.
"""
import json
import os, re
import traceback
import torch
import numpy as np
from gc import collect as gc_collect
from omegaconf import OmegaConf
from PIL import Image, ImageOps
from tqdm import tqdm, trange
from itertools import islice
from einops import rearrange
import time
from pytorch_lightning import seed_everything
from torch import autocast
from contextlib import nullcontext
from einops import rearrange, repeat
from ldm.util import instantiate_from_config
from optimizedSD.optimUtils import split_weighted_subprompts
from transformers import logging

from gfpgan import GFPGANer
from basicsr.archs.rrdbnet_arch import RRDBNet
from realesrgan import RealESRGANer

import uuid

logging.set_verbosity_error()

# consts
config_yaml = "optimizedSD/v1-inference.yaml"
filename_regex = re.compile('[^a-zA-Z0-9]')
force_gfpgan_to_cuda0 = True # workaround: gfpgan currently works only on cuda:0

# api stuff
from sd_internal import device_manager
from . import Request, Response, Image as ResponseImage
import base64
from io import BytesIO
#from colorama import Fore

from threading import local as LocalThreadVars
# Thread-local storage: every render thread keeps its own device settings and loaded models here.
thread_data = LocalThreadVars()

def thread_init(device):
    """Reset all thread-bound properties to their defaults, then initialise the device.

    Must be called on the render thread itself, because ``thread_data`` is thread-local.

    Args:
        device: forwarded unchanged to ``device_manager.device_init``, which presumably
            fills in ``thread_data.device`` / ``thread_data.device_name`` (confirm there).
    """
    # Thread bound properties
    thread_data.stop_processing = False
    thread_data.temp_images = {}

    thread_data.ckpt_file = None
    thread_data.vae_file = None
    thread_data.gfpgan_file = None
    thread_data.real_esrgan_file = None

    thread_data.model = None
    thread_data.modelCS = None
    thread_data.modelFS = None
    thread_data.model_gfpgan = None
    thread_data.model_real_esrgan = None

    thread_data.model_is_half = False
    thread_data.model_fs_is_half = False
    thread_data.device = None
    thread_data.device_name = None
    thread_data.unet_bs = 1
    thread_data.precision = 'autocast'
    thread_data.sampler_plms = None
    thread_data.sampler_ddim = None

    thread_data.turbo = False
    thread_data.force_full_precision = False
    thread_data.reduced_memory = True

    device_manager.device_init(thread_data, device)

def load_model_ckpt():
    """Load ``thread_data.ckpt_file + '.ckpt'`` into the UNet, cond-stage and first-stage models.

    The results are stored on ``thread_data.model``, ``thread_data.modelCS`` and
    ``thread_data.modelFS``. If ``thread_data.vae_file`` is set, its weights (``.ckpt`` or
    ``.vae.pt``) replace the first-stage VAE weights. If the VAE is missing or broken,
    ``thread_data.vae_file`` is reset to None and loading continues.
    On a non-CPU device with 'autocast' precision, all three models are converted to half.

    Raises:
        ValueError: if ``thread_data.ckpt_file`` is not set.
        FileNotFoundError: if the checkpoint file does not exist.
    """
    if not thread_data.ckpt_file:
        raise ValueError('Thread ckpt_file is undefined.')
    if not os.path.exists(thread_data.ckpt_file + '.ckpt'):
        raise FileNotFoundError(f'Cannot find {thread_data.ckpt_file}.ckpt')

    if not thread_data.precision:
        thread_data.precision = 'full' if thread_data.force_full_precision else 'autocast'

    if not thread_data.unet_bs:
        thread_data.unet_bs = 1

    if thread_data.device == 'cpu':
        thread_data.precision = 'full'

    print('loading', thread_data.ckpt_file + '.ckpt', 'to device', thread_data.device, 'using precision', thread_data.precision)
    sd = load_model_from_config(thread_data.ckpt_file + '.ckpt')

    # Re-key the "model.*" weights for the UNet's two halves:
    # input blocks, the middle block and the time embedding go to "model1.*",
    # every other "model.*" weight goes to "model2.*".
    model1_parts = ("input_blocks", "middle_block", "time_embed")
    li, lo = [], []
    for key in sd:
        sp = key.split(".")
        if sp[0] != "model":
            continue
        if any(part in sp for part in model1_parts):
            li.append(key)
        else:
            lo.append(key)
    # key[6:] strips the leading "model." prefix.
    for key in li:
        sd["model1." + key[6:]] = sd.pop(key)
    for key in lo:
        sd["model2." + key[6:]] = sd.pop(key)

    config = OmegaConf.load(config_yaml)

    model = instantiate_from_config(config.modelUNet)
    _, _ = model.load_state_dict(sd, strict=False)
    model.eval()
    model.cdevice = torch.device(thread_data.device)
    model.unet_bs = thread_data.unet_bs
    model.turbo = thread_data.turbo
    # No explicit .to(device) here (weights come from a map_location="cpu" load);
    # the sub-models are moved to the device on demand while rendering.
    thread_data.model = model

    modelCS = instantiate_from_config(config.modelCondStage)
    _, _ = modelCS.load_state_dict(sd, strict=False)
    modelCS.eval()
    modelCS.cond_stage_model.device = torch.device(thread_data.device)
    thread_data.modelCS = modelCS

    modelFS = instantiate_from_config(config.modelFirstStage)
    _, _ = modelFS.load_state_dict(sd, strict=False)

    if thread_data.vae_file is not None:
        try:
            loaded = False
            for model_extension in ['.ckpt', '.vae.pt']:
                vae_path = thread_data.vae_file + model_extension
                if os.path.exists(vae_path):
                    print(f"Loading VAE weights from: {thread_data.vae_file}{model_extension}")
                    # NOTE(security): torch.load unpickles the file; only load VAE files from trusted sources.
                    vae_ckpt = torch.load(vae_path, map_location="cpu")
                    # Skip every key starting with "loss"; load the rest into the first-stage model.
                    vae_dict = {k: v for k, v in vae_ckpt["state_dict"].items() if k[0:4] != "loss"}
                    modelFS.first_stage_model.load_state_dict(vae_dict, strict=False)
                    loaded = True
                    break

            if not loaded:
                print(f'Cannot find VAE: {thread_data.vae_file}')
                thread_data.vae_file = None
        except Exception:
            # A bad VAE is not fatal: report it and keep the weights modelFS already has.
            print(traceback.format_exc())
            print(f'Could not load VAE: {thread_data.vae_file}')
            thread_data.vae_file = None

    modelFS.eval()
    thread_data.modelFS = modelFS
    del sd

    if thread_data.device != "cpu" and thread_data.precision == "autocast":
        thread_data.model.half()
        thread_data.modelCS.half()
        thread_data.modelFS.half()
        thread_data.model_is_half = True
        thread_data.model_fs_is_half = True
    else:
        thread_data.model_is_half = False
        thread_data.model_fs_is_half = False

    print(f'''loaded model
 model file: {thread_data.ckpt_file}.ckpt
 model.device: {model.device}
 modelCS.device: {modelCS.cond_stage_model.device}
 modelFS.device: {thread_data.modelFS.device}
 using precision: {thread_data.precision}''')

def unload_filters():
    """Drop this thread's GFPGAN and RealESRGAN models and release their memory.

    On a non-CPU device the underlying networks are moved to the CPU before being dropped.
    """
    if thread_data.model_gfpgan is not None:
        if thread_data.device != 'cpu':
            thread_data.model_gfpgan.gfpgan.to('cpu')
        del thread_data.model_gfpgan
    thread_data.model_gfpgan = None

    if thread_data.model_real_esrgan is not None:
        if thread_data.device != 'cpu':
            thread_data.model_real_esrgan.model.to('cpu')
        del thread_data.model_real_esrgan
    thread_data.model_real_esrgan = None

    gc()

def unload_models():
    """Drop this thread's UNet / cond-stage / first-stage models and release their memory."""
    if thread_data.model is not None:
        print('Unloading models...')
        if thread_data.device != 'cpu':
            thread_data.modelFS.to('cpu')
            thread_data.modelCS.to('cpu')
            thread_data.model.model1.to("cpu")
            thread_data.model.model2.to("cpu")

        del thread_data.model
        del thread_data.modelCS
        del thread_data.modelFS

    thread_data.model = None
    thread_data.modelCS = None
    thread_data.modelFS = None

    gc()

# def wait_model_move_to(model, target_device): # Send to target_device and wait until complete.
#     if thread_data.device == target_device: return
#     start_mem = torch.cuda.memory_allocated(thread_data.device) / 1e6
#     if start_mem <= 0: return
#     model_name = model.__class__.__name__
#     print(f'Device {thread_data.device} - Sending model {model_name} to {target_device} | Memory transfer starting. Memory Used: {round(start_mem)}Mb')
#     start_time = time.time()
#     model.to(target_device)
#     time_step = start_time
#     WARNING_TIMEOUT = 1.5 # seconds - Show activity in console after timeout.
#     last_mem = start_mem
#     is_transfering = True
#     while is_transfering:
#         time.sleep(0.5) # 500ms
#         mem = torch.cuda.memory_allocated(thread_data.device) / 1e6
#         is_transfering = bool(mem > 0 and mem < last_mem) # still stuff loaded, but less than last time.
#         last_mem = mem
#         if not is_transfering:
#             break;
#         if time.time() - time_step > WARNING_TIMEOUT: # Long delay, print to console to show activity.
#             print(f'Device {thread_data.device} - Waiting for Memory transfer. Memory Used: {round(mem)}Mb, Transfered: {round(start_mem - mem)}Mb')
#             time_step = time.time()
#     print(f'Device {thread_data.device} - {model_name} Moved: {round(start_mem - last_mem)}Mb in {round(time.time() - start_time, 3)} seconds to {target_device}')

def _has_tensors_off_cpu(model):
    """Return True if *model* may have parameters outside CPU memory.

    Objects without a ``parameters()`` method are conservatively reported as True.
    """
    parameters = getattr(model, 'parameters', None)
    if parameters is None:
        return True
    return any(p.device.type != 'cpu' for p in parameters())

def move_to_cpu(model, timeout=10.0):
    """Move *model* to the CPU and wait, bounded by *timeout* seconds, for device memory to drop.

    No-op when this thread's device is 'cpu'. The wait is skipped entirely when the model
    had nothing off the CPU, or when no device memory was allocated. In those cases the
    allocated amount can never fall below the snapshot, and the old unbounded loop spun forever.
    """
    if thread_data.device == "cpu":
        return

    d = torch.device(thread_data.device)
    needs_wait = _has_tensors_off_cpu(model)
    mem = torch.cuda.memory_allocated(d) / 1e6
    model.to("cpu")

    if not needs_wait or mem <= 0:
        return

    deadline = time.time() + timeout
    while torch.cuda.memory_allocated(d) / 1e6 >= mem:
        if time.time() >= deadline:
            print(f'Device {thread_data.device} - {model.__class__.__name__} did not release memory within {timeout}s, continuing')
            return
        time.sleep(0.1)

def load_model_gfpgan():
    """Load the GFPGAN face-correction model ``thread_data.gfpgan_file + '.pth'`` on this thread's device.

    Raises:
        ValueError: if ``thread_data.gfpgan_file`` is not set.
    """
    if thread_data.gfpgan_file is None:
        raise ValueError('Thread gfpgan_file is undefined.')

    # hack for a bug in facexlib: https://github.com/xinntao/facexlib/pull/19/files
    from facexlib.detection import retinaface
    retinaface.device = torch.device(thread_data.device)
    print('forced retinaface.device to', thread_data.device)

    model_path = thread_data.gfpgan_file + ".pth"
    thread_data.model_gfpgan = GFPGANer(device=torch.device(thread_data.device), model_path=model_path, upscale=1, arch='clean', channel_multiplier=2, bg_upsampler=None)
    print('loaded', thread_data.gfpgan_file, 'to', thread_data.model_gfpgan.device, 'precision', thread_data.precision)

def load_model_real_esrgan():
    """Load the RealESRGAN upscaler named by ``thread_data.real_esrgan_file`` on this thread's device.

    ``thread_data.real_esrgan_file`` is both the weights path (without '.pth') and the key
    that selects the network architecture below.

    Raises:
        ValueError: if ``thread_data.real_esrgan_file`` is not set.
        KeyError: if the name is not one of the known architectures.
    """
    if thread_data.real_esrgan_file is None:
        raise ValueError('Thread real_esrgan_file is undefined.')
    model_path = thread_data.real_esrgan_file + ".pth"

    # Factories, so only the requested network gets constructed.
    RealESRGAN_models = {
        'RealESRGAN_x4plus': lambda: RRDBNet(num_in_ch=3, num_out_ch=3, num_feat=64, num_block=23, num_grow_ch=32, scale=4),
        'RealESRGAN_x4plus_anime_6B': lambda: RRDBNet(num_in_ch=3, num_out_ch=3, num_feat=64, num_block=6, num_grow_ch=32, scale=4),
    }

    model_to_use = RealESRGAN_models[thread_data.real_esrgan_file]()

    if thread_data.device == 'cpu':
        thread_data.model_real_esrgan = RealESRGANer(device=torch.device(thread_data.device), scale=2, model_path=model_path, model=model_to_use, pre_pad=0, half=False) # cpu does not support half
        thread_data.model_real_esrgan.model.to('cpu')
    else:
        thread_data.model_real_esrgan = RealESRGANer(device=torch.device(thread_data.device), scale=2, model_path=model_path, model=model_to_use, pre_pad=0, half=thread_data.model_is_half)

    thread_data.model_real_esrgan.model.name = thread_data.real_esrgan_file
    print('loaded ', thread_data.real_esrgan_file, 'to', thread_data.model_real_esrgan.device, 'precision', thread_data.precision)

def get_session_out_path(disk_path, session_id):
    """Return (and create) ``disk_path/<sanitised session_id>``, or None if either argument is None."""
    if disk_path is None: return None
    if session_id is None: return None

    session_out_path = os.path.join(disk_path, filename_regex.sub('_', session_id))
    os.makedirs(session_out_path, exist_ok=True)
    return session_out_path

def get_base_path(disk_path, session_id, prompt, img_id, ext, suffix=None):
    """Build the output file path ``<session dir>/<prompt[:50]>_<img_id>[_<suffix>].<ext>``.

    Non-alphanumeric characters of *prompt* are replaced with '_'. Returns None when
    *disk_path* or *session_id* is None. Creates the session directory as a side effect.

    Raises:
        Exception: if *ext* is None.
    """
    if disk_path is None: return None
    if session_id is None: return None
    if ext is None: raise Exception('Missing ext')

    session_out_path = get_session_out_path(disk_path, session_id)

    prompt_flattened = filename_regex.sub('_', prompt)[:50]

    if suffix is not None:
        return os.path.join(session_out_path, f"{prompt_flattened}_{img_id}_{suffix}.{ext}")
    return os.path.join(session_out_path, f"{prompt_flattened}_{img_id}.{ext}")

def apply_filters(filter_name, image_data, model_path=None):
    """Run one post-processing filter over *image_data* and return the result.

    Args:
        filter_name: 'gfpgan' (face correction) or 'real_esrgan' (upscaling). Any other
            value returns *image_data* unchanged.
        image_data: image array indexed as [row, col, channel]. The channel axis is reversed
            before and after enhancement (presumably RGB <-> BGR for the models; confirm).
        model_path: model name/path without extension. The filter model is (re)loaded when this
            differs from the one currently loaded on this thread.

    Raises:
        Exception: if the requested filter model is not loaded.
    """
    print(f'Applying filter {filter_name}...')
    gc() # Free space before loading new data.

    if isinstance(image_data, torch.Tensor):
        # Tensor.to() is not in-place, so keep the moved tensor.
        image_data = image_data.to(thread_data.device)

    if filter_name == 'gfpgan':
        if model_path is not None and model_path != thread_data.gfpgan_file:
            thread_data.gfpgan_file = model_path
            load_model_gfpgan()
        elif not thread_data.model_gfpgan:
            load_model_gfpgan()
        if thread_data.model_gfpgan is None: raise Exception('Model "gfpgan" not loaded.')
        print('enhance with', thread_data.gfpgan_file, 'on', thread_data.model_gfpgan.device, 'precision', thread_data.precision)
        _, _, output = thread_data.model_gfpgan.enhance(image_data[:,:,::-1], has_aligned=False, only_center_face=False, paste_back=True)
        image_data = output[:,:,::-1]

    if filter_name == 'real_esrgan':
        if model_path is not None and model_path != thread_data.real_esrgan_file:
            thread_data.real_esrgan_file = model_path
            load_model_real_esrgan()
        elif not thread_data.model_real_esrgan:
            load_model_real_esrgan()
        if thread_data.model_real_esrgan is None: raise Exception('Model "real_esrgan" not loaded.')
        print('enhance with', thread_data.real_esrgan_file, 'on', thread_data.model_real_esrgan.device, 'precision', thread_data.precision)
        output, _ = thread_data.model_real_esrgan.enhance(image_data[:,:,::-1])
        image_data = output[:,:,::-1]

    return image_data

def mk_img(req: 'Request'):
    """Render entry point: yield everything ``do_mk_img`` yields, turning failures into a JSON message.

    On any exception the traceback is printed and the loaded sub-models are moved back to the
    CPU (best effort). Then ``{"status": "failed", "detail": str(e)}`` is yielded as JSON.
    """
    try:
        yield from do_mk_img(req)
    except Exception as e:
        print(traceback.format_exc())

        try:
            if thread_data.device != 'cpu':
                # Any of these can still be None if the failure happened before/while loading.
                if thread_data.modelFS is not None:
                    thread_data.modelFS.to('cpu')
                if thread_data.modelCS is not None:
                    thread_data.modelCS.to('cpu')
                if thread_data.model is not None:
                    thread_data.model.model1.to("cpu")
                    thread_data.model.model2.to("cpu")

            gc() # Release from memory.
        except Exception:
            # Cleanup is best effort; it must never hide the original error from the client.
            print(traceback.format_exc())

        yield json.dumps({
            "status": 'failed',
            "detail": str(e)
        })

def _latent_to_uint8_image(latent):
    """Decode one latent sample with ``thread_data.modelFS`` into an (h, w, c) uint8 numpy array."""
    decoded = thread_data.modelFS.decode_first_stage(latent.unsqueeze(0))
    decoded = torch.clamp((decoded + 1.0) / 2.0, min=0.0, max=1.0)
    pixels = 255.0 * rearrange(decoded[0].cpu().numpy(), "c h w -> h w c")
    return pixels.astype(np.uint8)

def update_temp_img(req, x_samples):
    """Decode the current latents into JPEG previews stored in ``thread_data.temp_images``.

    Returns one ``{'path': '/image/tmp/<session_id>/<index>'}`` dict per requested output.
    """
    partial_images = []
    for i in range(req.num_outputs):
        img = Image.fromarray(_latent_to_uint8_image(x_samples[i]))
        buf = BytesIO()
        img.save(buf, format='JPEG')
        buf.seek(0)

        del img
        # don't delete x_samples, it is used in the code that called this callback

        thread_data.temp_images[str(req.session_id) + '/' + str(i)] = buf
        partial_images.append({'path': f'/image/tmp/{req.session_id}/{i}'})
    return partial_images

def get_image_progress_generator(req, extra_props=None):
    """Build the per-step image callback passed to the sampler by do_mk_img.

    Without ``req.stream_progress_updates`` the callback simply returns its first argument.
    Otherwise it is a generator function that:
      - records the latest latents in ``thread_data.partial_x_samples``;
      - yields a JSON progress message with the step index and the time since the previous
        step, merged with *extra_props*. Every 5th step it also includes the preview paths
        when ``req.stream_image_progress`` is set.

    The generator raises UserInitiatedStop once ``thread_data.stop_processing`` is set.
    """
    if not req.stream_progress_updates:
        def empty_callback(x_samples, i): return x_samples
        return empty_callback

    thread_data.partial_x_samples = None
    last_callback_time = -1
    def img_callback(x_samples, i):
        nonlocal last_callback_time

        thread_data.partial_x_samples = x_samples
        step_time = time.time() - last_callback_time if last_callback_time != -1 else -1
        last_callback_time = time.time()

        progress = {"step": i, "step_time": step_time}
        if extra_props is not None:
            progress.update(extra_props)

        if req.stream_image_progress and i % 5 == 0:
            progress['output'] = update_temp_img(req, x_samples)

        yield json.dumps(progress)

        if thread_data.stop_processing:
            raise UserInitiatedStop("User requested that we stop processing")
    return img_callback

def do_mk_img(req: 'Request'):
    """Generator that renders one request on this thread's device.

    It (re)loads the checkpoint when the model, VAE or precision changed, samples (txt2img, or
    img2img when ``req.init_image`` is set), decodes the images, and optionally applies
    GFPGAN / RealESRGAN. It saves files and metadata when ``req.save_to_disk_path`` is set.
    When ``req.stream_progress_updates`` is set, the handler's output is yielded as it is
    produced (presumably the JSON progress messages from the callback). The final yield is
    ``json.dumps(res.json())``.

    NOTE(review): when ``req.stream_progress_updates`` is False, the handler generator is never
    iterated here and ``x_samples`` stays a generator. Confirm that path is unused or handled.

    Raises:
        FileNotFoundError: if ``req.use_stable_diffusion_model + '.ckpt'`` does not exist.
    """
    thread_data.stop_processing = False

    res = Response()
    res.request = req
    res.images = []

    thread_data.temp_images.clear()

    # custom model support:
    #  the req.use_stable_diffusion_model needs to be a valid path
    #  to the ckpt file (without the extension).
    if not os.path.exists(req.use_stable_diffusion_model + '.ckpt'):
        raise FileNotFoundError(f'Cannot find {req.use_stable_diffusion_model}.ckpt')

    needs_model_reload = False
    if not thread_data.model or thread_data.ckpt_file != req.use_stable_diffusion_model or thread_data.vae_file != req.use_vae_model:
        thread_data.ckpt_file = req.use_stable_diffusion_model
        thread_data.vae_file = req.use_vae_model
        needs_model_reload = True

    if thread_data.device != 'cpu':
        if (thread_data.precision == 'autocast' and (req.use_full_precision or not thread_data.model_is_half)) or \
            (thread_data.precision == 'full' and not req.use_full_precision and not thread_data.force_full_precision):
            thread_data.precision = 'full' if req.use_full_precision else 'autocast'
            needs_model_reload = True

    if needs_model_reload:
        unload_models()
        unload_filters()
        load_model_ckpt()

    if thread_data.turbo != req.turbo:
        thread_data.turbo = req.turbo
        thread_data.model.turbo = req.turbo

    # Start by cleaning memory, loading and unloading things can leave memory allocated.
    gc()

    opt_prompt = req.prompt
    opt_seed = req.seed
    opt_n_iter = 1
    opt_C = 4
    opt_f = 8
    opt_ddim_eta = 0.0

    print(req, '\n device', torch.device(thread_data.device), "as", thread_data.device_name)
    print('\n\n Using precision:', thread_data.precision)

    seed_everything(opt_seed)

    batch_size = req.num_outputs
    prompt = opt_prompt
    assert prompt is not None
    data = [batch_size * [prompt]]

    if thread_data.precision == "autocast" and thread_data.device != "cpu":
        precision_scope = autocast
    else:
        precision_scope = nullcontext

    mask = None

    if req.init_image is None:
        handler = _txt2img

        init_latent = None
        t_enc = None
    else:
        handler = _img2img

        init_image = load_img(req.init_image, req.width, req.height)
        init_image = init_image.to(thread_data.device)

        if thread_data.device != "cpu" and thread_data.precision == "autocast":
            init_image = init_image.half()

        thread_data.modelFS.to(thread_data.device)

        init_image = repeat(init_image, '1 ... -> b ...', b=batch_size)
        init_latent = thread_data.modelFS.get_first_stage_encoding(thread_data.modelFS.encode_first_stage(init_image))  # move to latent space

        if req.mask is not None:
            mask = load_mask(req.mask, req.width, req.height, init_latent.shape[2], init_latent.shape[3], True).to(thread_data.device)
            mask = mask[0][0].unsqueeze(0).repeat(4, 1, 1).unsqueeze(0)
            mask = repeat(mask, '1 ... -> b ...', b=batch_size)

            if thread_data.device != "cpu" and thread_data.precision == "autocast":
                mask = mask.half()

        # Send to CPU and wait until complete.
        move_to_cpu(thread_data.modelFS)

        assert 0. <= req.prompt_strength <= 1., 'can only work with strength in [0.0, 1.0]'
        t_enc = int(req.prompt_strength * req.num_inference_steps)
        print(f"target t_enc is {t_enc} steps")

    if req.save_to_disk_path is not None:
        session_out_path = get_session_out_path(req.save_to_disk_path, req.session_id)
    else:
        session_out_path = None

    with torch.no_grad():
        for _ in trange(opt_n_iter, desc="Sampling"):
            for prompts in tqdm(data, desc="data"):

                with precision_scope("cuda"):
                    if thread_data.reduced_memory:
                        thread_data.modelCS.to(thread_data.device)
                    uc = None
                    if req.guidance_scale != 1.0:
                        uc = thread_data.modelCS.get_learned_conditioning(batch_size * [req.negative_prompt])
                    if isinstance(prompts, tuple):
                        prompts = list(prompts)

                    subprompts, weights = split_weighted_subprompts(prompts[0])
                    if len(subprompts) > 1:
                        c = torch.zeros_like(uc)
                        totalWeight = sum(weights)
                        # normalize each "sub prompt" and add it
                        for subprompt, weight in zip(subprompts, weights):
                            weight = weight / totalWeight
                            c = torch.add(c, thread_data.modelCS.get_learned_conditioning(subprompt), alpha=weight)
                    else:
                        c = thread_data.modelCS.get_learned_conditioning(prompts)

                    if thread_data.reduced_memory:
                        thread_data.modelFS.to(thread_data.device)

                    n_steps = req.num_inference_steps if req.init_image is None else t_enc
                    img_callback = get_image_progress_generator(req, {"total_steps": n_steps})

                    # run the handler
                    try:
                        print('Running handler...')
                        if handler == _txt2img:
                            x_samples = _txt2img(req.width, req.height, req.num_outputs, req.num_inference_steps, req.guidance_scale, None, opt_C, opt_f, opt_ddim_eta, c, uc, opt_seed, img_callback, mask, req.sampler)
                        else:
                            x_samples = _img2img(init_latent, t_enc, batch_size, req.guidance_scale, c, uc, req.num_inference_steps, opt_ddim_eta, opt_seed, img_callback, mask)

                        if req.stream_progress_updates:
                            yield from x_samples
                        if hasattr(thread_data, 'partial_x_samples'):
                            if thread_data.partial_x_samples is not None:
                                x_samples = thread_data.partial_x_samples
                            del thread_data.partial_x_samples
                    except UserInitiatedStop:
                        if not hasattr(thread_data, 'partial_x_samples'):
                            continue
                        if thread_data.partial_x_samples is None:
                            del thread_data.partial_x_samples
                            continue
                        x_samples = thread_data.partial_x_samples
                        del thread_data.partial_x_samples

                    print("decoding images")
                    img_data = [_latent_to_uint8_image(x_samples[i]) for i in range(batch_size)]
                    del x_samples

                    print("saving images")
                    for i in range(batch_size):
                        img = Image.fromarray(img_data[i])
                        img_id = base64.b64encode(int(time.time()+i).to_bytes(8, 'big')).decode() # Generate unique ID based on time.
                        img_id = img_id.translate({43:None, 47:None, 61:None})[-8:] # Remove + / = and keep last 8 chars.

                        has_filters = (req.use_face_correction is not None and req.use_face_correction.startswith('GFPGAN')) or \
                                      (req.use_upscale is not None and req.use_upscale.startswith('RealESRGAN'))

                        return_orig_img = not has_filters or not req.show_only_filtered_image

                        if thread_data.stop_processing:
                            return_orig_img = True

                        if req.save_to_disk_path is not None:
                            if return_orig_img:
                                img_out_path = get_base_path(req.save_to_disk_path, req.session_id, prompts[0], img_id, req.output_format)
                                save_image(img, img_out_path)
                            meta_out_path = get_base_path(req.save_to_disk_path, req.session_id, prompts[0], img_id, 'txt')
                            save_metadata(meta_out_path, req, prompts[0], opt_seed)

                        if return_orig_img:
                            img_str = img_to_base64_str(img, req.output_format)
                            res_image_orig = ResponseImage(data=img_str, seed=opt_seed)
                            res.images.append(res_image_orig)

                            if req.save_to_disk_path is not None:
                                res_image_orig.path_abs = img_out_path
                        del img

                        if has_filters and not thread_data.stop_processing:
                            filters_applied = []
                            if req.use_face_correction:
                                img_data[i] = apply_filters('gfpgan', img_data[i], req.use_face_correction)
                                filters_applied.append(req.use_face_correction)
                            if req.use_upscale:
                                img_data[i] = apply_filters('real_esrgan', img_data[i], req.use_upscale)
                                filters_applied.append(req.use_upscale)
                            if filters_applied:
                                filtered_image = Image.fromarray(img_data[i])
                                filtered_img_data = img_to_base64_str(filtered_image, req.output_format)
                                response_image = ResponseImage(data=filtered_img_data, seed=opt_seed)
                                res.images.append(response_image)
                                if req.save_to_disk_path is not None:
                                    filtered_img_out_path = get_base_path(req.save_to_disk_path, req.session_id, prompts[0], img_id, req.output_format, "_".join(filters_applied))
                                    save_image(filtered_image, filtered_img_out_path)
                                    response_image.path_abs = filtered_img_out_path
                                del filtered_image
                        # Filter Applied, move to next seed
                        opt_seed += 1

                    move_to_cpu(thread_data.modelFS)
                    del img_data
                    gc()
                    if thread_data.device != 'cpu':
                        print(f'memory_final = {round(torch.cuda.memory_allocated(thread_data.device) / 1e6, 2)}Mb')

    print('Task completed')

    yield json.dumps(res.json())

def save_image(img, img_out_path):
    """Save *img* to *img_out_path*. Errors are printed, not raised (best effort)."""
    try:
        img.save(img_out_path)
    except Exception:
        print('could not save the file', traceback.format_exc())

def save_metadata(meta_out_path, req, prompt, opt_seed):
    """Write the generation settings to *meta_out_path*, one field per line (best effort)."""
    metadata = f'''{prompt}
Width: {req.width}
Height: {req.height}
Seed: {opt_seed}
Steps: {req.num_inference_steps}
Guidance Scale: {req.guidance_scale}
Prompt Strength: {req.prompt_strength}
Use Face Correction: {req.use_face_correction}
Use Upscaling: {req.use_upscale}
Sampler: {req.sampler}
Negative Prompt: {req.negative_prompt}
Stable Diffusion model: {req.use_stable_diffusion_model}.ckpt
VAE model: {req.use_vae_model}
'''
    try:
        with open(meta_out_path, 'w', encoding='utf-8') as f:
            f.write(metadata)
    except Exception:
        print('could not save the file', traceback.format_exc())

def _txt2img(opt_W, opt_H, opt_n_samples, opt_ddim_steps, opt_scale, start_code, opt_C, opt_f, opt_ddim_eta, c, uc, opt_seed, img_callback, mask, sampler_name):
    """Generator: sample latents of shape [n_samples, C, H // f, W // f] from noise.

    Moves modelCS to the CPU first (once iteration starts), then re-yields everything
    ``thread_data.model.sample(...)`` yields.
    """
    shape = [opt_n_samples, opt_C, opt_H // opt_f, opt_W // opt_f]

    # Send to CPU and wait until complete.
    move_to_cpu(thread_data.modelCS)

    if sampler_name == 'ddim':
        thread_data.model.make_schedule(ddim_num_steps=opt_ddim_steps, ddim_eta=opt_ddim_eta, verbose=False)

    samples_ddim = thread_data.model.sample(
        S=opt_ddim_steps,
        conditioning=c,
        seed=opt_seed,
        shape=shape,
        verbose=False,
        unconditional_guidance_scale=opt_scale,
        unconditional_conditioning=uc,
        eta=opt_ddim_eta,
        x_T=start_code,
        img_callback=img_callback,
        mask=mask,
        sampler=sampler_name,
    )
    yield from samples_ddim

def _img2img(init_latent, t_enc, batch_size, opt_scale, c, uc, opt_ddim_steps, opt_ddim_eta, opt_seed, img_callback, mask):
    """Generator: noise *init_latent* to step *t_enc*, then denoise it with the 'ddim' sampler.

    When a mask is given, *init_latent* is also passed as ``x_T``. Re-yields everything
    ``thread_data.model.sample(...)`` yields.
    """
    # encode (scaled latent)
    z_enc = thread_data.model.stochastic_encode(
        init_latent,
        torch.tensor([t_enc] * batch_size).to(thread_data.device),
        opt_seed,
        opt_ddim_eta,
        opt_ddim_steps,
    )
    x_T = None if mask is None else init_latent

    # decode it
    samples_ddim = thread_data.model.sample(
        t_enc,
        c,
        z_enc,
        unconditional_guidance_scale=opt_scale,
        unconditional_conditioning=uc,
        img_callback=img_callback,
        mask=mask,
        x_T=x_T,
        sampler='ddim'
    )
    yield from samples_ddim

def gc():
    """Run Python garbage collection and, off the CPU, also empty the CUDA caching allocator."""
    gc_collect()
    if thread_data.device == 'cpu':
        return
    torch.cuda.empty_cache()
    torch.cuda.ipc_collect()

# internal

def chunk(it, size):
    """Yield successive tuples of up to *size* items from iterable *it*."""
    it = iter(it)
    return iter(lambda: tuple(islice(it, size)), ())

def load_model_from_config(ckpt, verbose=False):
    """Load checkpoint *ckpt* on the CPU and return its ``state_dict`` entry.

    NOTE(security): torch.load unpickles the file, so only load checkpoints from trusted sources.
    """
    print(f"Loading model from {ckpt}")
    pl_sd = torch.load(ckpt, map_location="cpu")
    if "global_step" in pl_sd:
        print(f"Global Step: {pl_sd['global_step']}")
    sd = pl_sd["state_dict"]
    return sd

# utils
class UserInitiatedStop(Exception):
    """Raised from the progress callback when the user asked to stop processing."""
    pass

def load_img(img_str, w0, h0):
    """Decode a base64 image into a (1, 3, h, w) float tensor scaled to [-1, 1].

    The target size is (w0, h0) when both are given, otherwise the image's own size.
    Either way it is rounded down to a multiple of 64.
    """
    image = base64_str_to_img(img_str).convert("RGB")
    w, h = image.size
    print(f"loaded input image of size ({w}, {h}) from base64")
    if h0 is not None and w0 is not None:
        h, w = h0, w0

    w, h = map(lambda x: x - x % 64, (w, h))  # resize to integer multiple of 64
    image = image.resize((w, h), resample=Image.Resampling.LANCZOS)
    image = np.array(image).astype(np.float32) / 255.0
    image = image[None].transpose(0, 3, 1, 2)
    image = torch.from_numpy(image)
    return 2.*image - 1.

def load_mask(mask_str, h0, w0, newH, newW, invert=False):
    """Decode a base64 mask into a (1, 3, newH, newW) float tensor in [0, 1], optionally inverted.

    *h0* / *w0* are accepted for interface compatibility. The mask is always resized to (newW, newH).
    """
    image = base64_str_to_img(mask_str).convert("RGB")
    w, h = image.size
    print(f"loaded input mask of size ({w}, {h})")

    if invert:
        print("inverted")
        image = ImageOps.invert(image)

    print(f"New mask size ({newW}, {newH})")
    image = image.resize((newW, newH), resample=Image.Resampling.LANCZOS)
    image = np.array(image)

    image = image.astype(np.float32) / 255.0
    image = image[None].transpose(0, 3, 1, 2)
    image = torch.from_numpy(image)
    return image

# https://stackoverflow.com/a/61114178
def img_to_base64_str(img, output_format="PNG"):
    """Encode a PIL image as a ``data:image/png|jpeg;base64,...`` URL in *output_format*."""
    buffered = BytesIO()
    img.save(buffered, format=output_format)
    buffered.seek(0)
    img_byte = buffered.getvalue()
    mime_type = "image/png" if output_format.lower() == "png" else "image/jpeg"
    img_str = f"data:{mime_type};base64," + base64.b64encode(img_byte).decode()
    return img_str

def base64_str_to_buffer(img_str):
    """Decode a base64 data URL (``data:<mime>;base64,<payload>``) or bare payload into a BytesIO.

    Splits on the first comma instead of slicing a fixed prefix length. The old slicing
    corrupted bare payloads and any MIME type other than image/png or image/jpeg.
    """
    if img_str.startswith("data:") and "," in img_str:
        img_str = img_str.split(",", 1)[1]
    data = base64.b64decode(img_str)
    buffered = BytesIO(data)
    return buffered

def base64_str_to_img(img_str):
    """Decode a base64 image string into a PIL image (opened lazily)."""
    buffered = base64_str_to_buffer(img_str)
    img = Image.open(buffered)
    return img