easydiffusion/ui/sd_internal/runtime.py

1078 lines
43 KiB
Python
Raw Normal View History

"""runtime.py: torch device owned by a thread.
Notes:
Avoid device switching; transferring all models would get too complex.
To use a different device, signal the current render device to exit,
and then start a new clean thread for the new device.
"""
import json
import os, re
import traceback
import queue
import torch
import numpy as np
2022-10-21 09:53:43 +02:00
from gc import collect as gc_collect
from omegaconf import OmegaConf
2022-09-15 14:24:03 +02:00
from PIL import Image, ImageOps
from tqdm import tqdm, trange
from itertools import islice
from einops import rearrange
import time
from pytorch_lightning import seed_everything
from torch import autocast
from contextlib import nullcontext
from einops import rearrange, repeat
from ldm.util import instantiate_from_config
from transformers import logging
from gfpgan import GFPGANer
from basicsr.archs.rrdbnet_arch import RRDBNet
from realesrgan import RealESRGANer
from server import HYPERNETWORK_MODEL_EXTENSIONS# , STABLE_DIFFUSION_MODEL_EXTENSIONS, VAE_MODEL_EXTENSIONS
from threading import Lock
from safetensors.torch import load_file
import uuid
logging.set_verbosity_error()
# consts
# Split-model (optimizedSD) config used by load_model_ckpt_sd1().
config_yaml = "optimizedSD/v1-inference.yaml"
# Matches every character that is not an ASCII letter or digit; used to turn
# session ids and prompts into file-name-safe strings.
filename_regex = re.compile('[^a-zA-Z0-9]')
gfpgan_temp_device_lock = Lock() # workaround: gfpgan currently can only start on one device at a time.
# api stuff
from sd_internal import device_manager
from . import Request, Response, Image as ResponseImage
import base64
from io import BytesIO
2022-09-16 18:02:08 +02:00
#from colorama import Fore
2022-10-17 03:41:39 +02:00
from threading import local as LocalThreadVars
# Thread-local state: each render thread keeps its own device, models and flags here (initialised by thread_init).
thread_data = LocalThreadVars()
def thread_init(device):
    """Reset the thread-bound render state and bind this thread to `device`.

    Model handles, model file names and runtime flags on `thread_data` are
    (re)set to their defaults. Then the `test_sd2` flag is read via isSD2(),
    and device_manager.device_init() performs the device-specific setup.
    """
    # Order matters only for readability; every entry is a plain attribute reset.
    defaults = {
        'stop_processing': False,
        'temp_images': {},
        'ckpt_file': None,
        'vae_file': None,
        'hypernetwork_file': None,
        'gfpgan_file': None,
        'real_esrgan_file': None,
        'model': None,
        'modelCS': None,
        'modelFS': None,
        'hypernetwork': None,
        'hypernetwork_strength': 1,
        'model_gfpgan': None,
        'model_real_esrgan': None,
        'model_is_half': False,
        'model_fs_is_half': False,
        'device': None,
        'device_name': None,
        'unet_bs': 1,
        'precision': 'autocast',
        'sampler_plms': None,
        'sampler_ddim': None,
        'turbo': False,
        'force_full_precision': False,
        'reduced_memory': True,
    }
    for attr_name, default_value in defaults.items():
        setattr(thread_data, attr_name, default_value)

    thread_data.test_sd2 = isSD2()
    device_manager.device_init(thread_data, device)
# temp hack, will remove soon
def isSD2():
    """Return the `test_sd2` flag from `<SD_UI_PATH>/../scripts/config.json`.

    Best effort: returns False when SD_UI_PATH is not set, when the config
    file is missing or unreadable, when it is not valid JSON, or when it is
    not a JSON object. Otherwise returns the value stored under 'test_sd2'
    (False if the key is absent).
    """
    sd_ui_dir = os.getenv('SD_UI_PATH', None)
    if sd_ui_dir is None:
        # Explicit instead of relying on os.path.join(None, ...) raising TypeError.
        return False

    config_dir = os.path.abspath(os.path.join(sd_ui_dir, '..', 'scripts'))
    config_json_path = os.path.join(config_dir, 'config.json')
    try:
        with open(config_json_path, 'r', encoding='utf-8') as f:
            config = json.load(f)
    except (OSError, ValueError):
        # OSError: missing/unreadable file; ValueError covers JSONDecodeError
        # and UnicodeDecodeError.
        return False

    if not isinstance(config, dict):
        return False
    return config.get('test_sd2', False)
def load_model_ckpt():
    """Resolve the checkpoint file, settle precision defaults and load the model.

    `thread_data.ckpt_file` may be given without an extension. '.ckpt' is
    tried first, then '.safetensors', then the bare path. The SD2 or SD1
    loader is then chosen from `thread_data.test_sd2`.

    Raises:
        ValueError: if thread_data.ckpt_file is not set.
        FileNotFoundError: if no matching checkpoint file exists.
    """
    ckpt_base = thread_data.ckpt_file
    if not ckpt_base:
        raise ValueError('Thread ckpt_file is undefined.')

    for extension in ('.ckpt', '.safetensors'):
        if os.path.exists(ckpt_base + extension):
            thread_data.ckpt_file = ckpt_base + extension
            break
    else:
        if not os.path.exists(ckpt_base):
            raise FileNotFoundError(f'Cannot find {ckpt_base}.ckpt or .safetensors')

    if not thread_data.precision:
        thread_data.precision = 'autocast' if not thread_data.force_full_precision else 'full'
    thread_data.unet_bs = thread_data.unet_bs or 1

    # Half precision is not used on the CPU.
    if thread_data.device == 'cpu':
        thread_data.precision = 'full'

    print('loading', thread_data.ckpt_file, 'to device', thread_data.device, 'using precision', thread_data.precision)
    loader = load_model_ckpt_sd2 if thread_data.test_sd2 else load_model_ckpt_sd1
    loader()
def load_model_ckpt_sd1():
    """Load an SD1.x checkpoint into the optimizedSD split models.

    The checkpoint's "model.*" weights are re-keyed. Input blocks, the middle
    block and the time embedding go to "model1.*"; everything else goes to
    "model2.*". The state dict is then loaded into the modelUNet,
    modelCondStage and modelFirstStage sections of `config_yaml`. The results
    are stored on thread_data.model / modelCS / modelFS.

    If thread_data.vae_file is set, its weights (from '<vae_file>.ckpt' or
    '<vae_file>.vae.pt') override the first-stage model. A missing or
    unloadable VAE is reported and thread_data.vae_file is reset to None.
    On a non-CPU device with 'autocast' precision, all three models are
    converted to half precision.
    """
    sd = load_model_from_config(thread_data.ckpt_file)

    # Collect the keys first; the dict is mutated only after iteration finishes.
    li, lo = [], []
    for key in sd:
        sp = key.split(".")
        if sp[0] == "model":
            if "input_blocks" in sp or "middle_block" in sp or "time_embed" in sp:
                li.append(key)
            else:
                lo.append(key)
    # key[6:] strips the leading "model." prefix.
    for key in li:
        sd["model1." + key[6:]] = sd.pop(key)
    for key in lo:
        sd["model2." + key[6:]] = sd.pop(key)

    config = OmegaConf.load(f"{config_yaml}")

    model = instantiate_from_config(config.modelUNet)
    _, _ = model.load_state_dict(sd, strict=False)
    model.eval()
    model.cdevice = torch.device(thread_data.device)
    model.unet_bs = thread_data.unet_bs
    model.turbo = thread_data.turbo
    # Not moved to thread_data.device here; do_mk_img moves the sub-models on demand.
    thread_data.model = model

    modelCS = instantiate_from_config(config.modelCondStage)
    _, _ = modelCS.load_state_dict(sd, strict=False)
    modelCS.eval()
    modelCS.cond_stage_model.device = torch.device(thread_data.device)
    thread_data.modelCS = modelCS

    modelFS = instantiate_from_config(config.modelFirstStage)
    _, _ = modelFS.load_state_dict(sd, strict=False)
    if thread_data.vae_file is not None:
        try:
            loaded = False
            for model_extension in ['.ckpt', '.vae.pt']:
                vae_path = thread_data.vae_file + model_extension
                if os.path.exists(vae_path):
                    print(f"Loading VAE weights from: {thread_data.vae_file}{model_extension}")
                    vae_ckpt = torch.load(vae_path, map_location="cpu")
                    # Keys starting with "loss" are skipped (presumably training-only state).
                    vae_dict = {k: v for k, v in vae_ckpt["state_dict"].items() if k[0:4] != "loss"}
                    modelFS.first_stage_model.load_state_dict(vae_dict, strict=False)
                    # Release the CPU copy of the VAE checkpoint right away.
                    del vae_ckpt, vae_dict
                    loaded = True
                    break
            if not loaded:
                print(f'Cannot find VAE: {thread_data.vae_file}')
                thread_data.vae_file = None
        except Exception:
            # Best effort: a broken VAE must not abort model loading, but a
            # KeyboardInterrupt/SystemExit must still propagate.
            print(traceback.format_exc())
            print(f'Could not load VAE: {thread_data.vae_file}')
            thread_data.vae_file = None
    modelFS.eval()
    thread_data.modelFS = modelFS
    del sd

    if thread_data.device != "cpu" and thread_data.precision == "autocast":
        thread_data.model.half()
        thread_data.modelCS.half()
        thread_data.modelFS.half()
        thread_data.model_is_half = True
        thread_data.model_fs_is_half = True
    else:
        thread_data.model_is_half = False
        thread_data.model_fs_is_half = False

    print(f'''loaded model
model file: {thread_data.ckpt_file}
model.device: {model.device}
modelCS.device: {modelCS.cond_stage_model.device}
modelFS.device: {thread_data.modelFS.device}
using precision: {thread_data.precision}''')
def load_model_ckpt_sd2():
    """Load thread_data.ckpt_file as a single (non-split) model on the SD2 code path.

    The v2 v-prediction inference config is used when the checkpoint's file
    name contains 'sd2_'; otherwise the v1 inference config is used. The model
    is moved to thread_data.device, set to eval mode and, on a non-CPU device
    with 'autocast' precision, converted to half precision.
    """
    # Match on the file name only: a parent directory that happens to contain
    # 'sd2_' must not switch a v1 checkpoint to the v2 config.
    ckpt_name = os.path.basename(thread_data.ckpt_file)
    if 'sd2_' in ckpt_name:
        config_file = 'configs/stable-diffusion/v2-inference-v.yaml'
    else:
        config_file = "configs/stable-diffusion/v1-inference.yaml"
    config = OmegaConf.load(config_file)
    verbose = False  # flip to True to print missing/unexpected state-dict keys

    sd = load_model_from_config(thread_data.ckpt_file)
    thread_data.model = instantiate_from_config(config.model)
    m, u = thread_data.model.load_state_dict(sd, strict=False)
    if len(m) > 0 and verbose:
        print("missing keys:")
        print(m)
    if len(u) > 0 and verbose:
        print("unexpected keys:")
        print(u)

    thread_data.model.to(thread_data.device)
    thread_data.model.eval()
    del sd

    thread_data.model.cond_stage_model.device = torch.device(thread_data.device)
    if thread_data.device != "cpu" and thread_data.precision == "autocast":
        thread_data.model.half()
        thread_data.model_is_half = True
        thread_data.model_fs_is_half = True
    else:
        thread_data.model_is_half = False
        thread_data.model_fs_is_half = False

    print(f'''loaded model
model file: {thread_data.ckpt_file}
using precision: {thread_data.precision}''')
def unload_filters():
    """Drop this thread's GFPGAN and RealESRGAN models, then run gc().

    On a non-CPU device each network is moved back to the CPU before its
    reference is released. Both slots end up as None.
    """
    on_device = thread_data.device != 'cpu'

    if thread_data.model_gfpgan is not None and on_device:
        thread_data.model_gfpgan.gfpgan.to('cpu')
    thread_data.model_gfpgan = None

    if thread_data.model_real_esrgan is not None and on_device:
        thread_data.model_real_esrgan.model.to('cpu')
    thread_data.model_real_esrgan = None

    gc()
def unload_models():
    """Release this thread's Stable Diffusion models, then run gc().

    When a model is loaded on a non-CPU device in SD1 mode, the split
    sub-models (modelFS, modelCS, model.model1, model.model2) are moved back
    to the CPU first. The model / modelCS / modelFS slots always end up None.
    """
    slots = ('model', 'modelCS', 'modelFS')

    if thread_data.model is not None:
        print('Unloading models...')
        if thread_data.device != 'cpu' and not thread_data.test_sd2:
            parts = (thread_data.modelFS, thread_data.modelCS,
                     thread_data.model.model1, thread_data.model.model2)
            for part in parts:
                part.to("cpu")
        for slot in slots:
            delattr(thread_data, slot)

    for slot in slots:
        setattr(thread_data, slot, None)

    gc()
# def wait_model_move_to(model, target_device): # Send to target_device and wait until complete.
# if thread_data.device == target_device: return
# start_mem = torch.cuda.memory_allocated(thread_data.device) / 1e6
# if start_mem <= 0: return
# model_name = model.__class__.__name__
# print(f'Device {thread_data.device} - Sending model {model_name} to {target_device} | Memory transfer starting. Memory Used: {round(start_mem)}Mb')
# start_time = time.time()
# model.to(target_device)
# time_step = start_time
# WARNING_TIMEOUT = 1.5 # seconds - Show activity in console after timeout.
# last_mem = start_mem
# is_transfering = True
# while is_transfering:
# time.sleep(0.5) # 500ms
# mem = torch.cuda.memory_allocated(thread_data.device) / 1e6
# is_transfering = bool(mem > 0 and mem < last_mem) # still stuff loaded, but less than last time.
# last_mem = mem
# if not is_transfering:
# break;
# if time.time() - time_step > WARNING_TIMEOUT: # Long delay, print to console to show activity.
# print(f'Device {thread_data.device} - Waiting for Memory transfer. Memory Used: {round(mem)}Mb, Transfered: {round(start_mem - mem)}Mb')
# time_step = time.time()
# print(f'Device {thread_data.device} - {model_name} Moved: {round(start_mem - last_mem)}Mb in {round(time.time() - start_time, 3)} seconds to {target_device}')
def move_to_cpu(model, timeout=10.0, poll_interval=1.0):
    """Move `model` to the CPU and wait for the device's allocated memory to drop.

    Does nothing on CPU threads. Otherwise it calls `model.to("cpu")` and
    polls torch.cuda.memory_allocated() until it falls below its pre-move
    value.

    The wait is bounded. The original loop polled forever when the model had
    no tensors on the device (it was already on the CPU) or when the
    allocation never dropped. Now the wait is skipped in the first case and
    gives up after `timeout` seconds otherwise.

    Args:
        model: torch module to move.
        timeout: maximum number of seconds to wait for the memory drop.
        poll_interval: seconds to sleep between memory checks.
    """
    if thread_data.device == "cpu":
        return

    d = torch.device(thread_data.device)
    mem = torch.cuda.memory_allocated(d) / 1e6
    had_device_tensors = any(p.device.type != "cpu" for p in model.parameters())
    model.to("cpu")
    if not had_device_tensors:
        # Nothing was transferred, so the allocation would never drop below `mem`.
        return

    deadline = time.monotonic() + timeout
    while torch.cuda.memory_allocated(d) / 1e6 >= mem:
        if time.monotonic() >= deadline:
            print(f'move_to_cpu: device memory did not drop within {timeout}s, continuing')
            break
        time.sleep(poll_interval)
def load_model_gfpgan():
    """Create this thread's GFPGAN face restorer from '<gfpgan_file>.pth'.

    Raises:
        ValueError: if thread_data.gfpgan_file is None.
    """
    gfpgan_file = thread_data.gfpgan_file
    if gfpgan_file is None:
        raise ValueError('Thread gfpgan_file is undefined.')

    weights_path = gfpgan_file + ".pth"
    target_device = torch.device(thread_data.device)
    thread_data.model_gfpgan = GFPGANer(
        device=target_device,
        model_path=weights_path,
        upscale=1,
        arch='clean',
        channel_multiplier=2,
        bg_upsampler=None,
    )

    print('loaded', gfpgan_file, 'to', thread_data.model_gfpgan.device, 'precision', thread_data.precision)
def load_model_real_esrgan():
    """Create this thread's RealESRGAN upscaler from '<real_esrgan_file>.pth'.

    thread_data.real_esrgan_file must be a supported model name
    ('RealESRGAN_x4plus' or 'RealESRGAN_x4plus_anime_6B'); it selects the
    RRDBNet architecture. CPU threads always run in full precision. Other
    devices follow thread_data.model_is_half.

    Raises:
        ValueError: if thread_data.real_esrgan_file is None.
        KeyError: if the model name is not one of the supported names.
    """
    if thread_data.real_esrgan_file is None:
        raise ValueError('Thread real_esrgan_file is undefined.')

    model_path = thread_data.real_esrgan_file + ".pth"

    # The supported models differ only in their number of RRDB blocks. Build just
    # the requested network instead of instantiating both on every call.
    num_blocks_by_model = {
        'RealESRGAN_x4plus': 23,
        'RealESRGAN_x4plus_anime_6B': 6,
    }
    num_block = num_blocks_by_model[thread_data.real_esrgan_file]
    model_to_use = RRDBNet(num_in_ch=3, num_out_ch=3, num_feat=64, num_block=num_block, num_grow_ch=32, scale=4)

    if thread_data.device == 'cpu':
        thread_data.model_real_esrgan = RealESRGANer(device=torch.device(thread_data.device), scale=2, model_path=model_path, model=model_to_use, pre_pad=0, half=False)  # cpu does not support half
        thread_data.model_real_esrgan.model.to('cpu')
    else:
        thread_data.model_real_esrgan = RealESRGANer(device=torch.device(thread_data.device), scale=2, model_path=model_path, model=model_to_use, pre_pad=0, half=thread_data.model_is_half)

    thread_data.model_real_esrgan.model.name = thread_data.real_esrgan_file
    print('loaded ', thread_data.real_esrgan_file, 'to', thread_data.model_real_esrgan.device, 'precision', thread_data.precision)
def get_session_out_path(disk_path, session_id):
    """Return the per-session output directory, creating it if needed.

    Every non-alphanumeric character of `session_id` is replaced with '_'
    via `filename_regex`. Returns None when `disk_path` or `session_id` is None.
    """
    if disk_path is None or session_id is None:
        return None

    safe_session_id = filename_regex.sub('_', session_id)
    out_dir = os.path.join(disk_path, safe_session_id)
    os.makedirs(out_dir, exist_ok=True)
    return out_dir
def get_base_path(disk_path, session_id, prompt, img_id, ext, suffix=None):
    """Build the output file path for one image (or its metadata) of a session.

    The name is '<prompt>_<img_id>[_<suffix>].<ext>'. The prompt is
    sanitised with `filename_regex` and truncated to 50 characters. The
    session directory is created if missing.

    Returns:
        The path, or None when `disk_path` or `session_id` is None.

    Raises:
        Exception: if `ext` is None.
    """
    if disk_path is None or session_id is None:
        return None
    if ext is None:
        raise Exception('Missing ext')

    session_dir = get_session_out_path(disk_path, session_id)
    stem = f"{filename_regex.sub('_', prompt)[:50]}_{img_id}"
    if suffix is not None:
        stem = f"{stem}_{suffix}"
    return os.path.join(session_dir, f"{stem}.{ext}")
def apply_filters(filter_name, image_data, model_path=None):
    """Run one post-processing filter over a single image and return the result.

    Args:
        filter_name: 'gfpgan' (face correction) or 'real_esrgan' (upscaling).
            Any other name returns `image_data` unchanged.
        image_data: the image to process. Its channel axis is reversed with
            `[:, :, ::-1]` before and after the enhancer call. In this module,
            do_mk_img passes a uint8 H x W x C numpy array.
        model_path: filter model name. When it differs from the currently
            loaded one, the filter model is (re)loaded; when None, the current
            model is used (loading it if none is loaded yet).

    Raises:
        Exception: if the requested filter model could not be loaded.
    """
    print(f'Applying filter {filter_name}...')
    gc()  # Free space before loading new data.

    # NOTE(review): a former `image_data.to(thread_data.device)` for tensors was
    # dropped. Tensor.to() returns a new tensor, so discarding its result made it
    # a no-op. Torch tensors would also reject the negative-step `[:, :, ::-1]`
    # slicing used below.

    if filter_name == 'gfpgan':
        # This lock is only ever used here. No need to use timeout for the request. Should never deadlock.
        with gfpgan_temp_device_lock:  # Wait for any other devices to complete before starting.
            # hack for a bug in facexlib: https://github.com/xinntao/facexlib/pull/19/files
            from facexlib.detection import retinaface
            retinaface.device = torch.device(thread_data.device)
            print('forced retinaface.device to', thread_data.device)

            if model_path is not None and model_path != thread_data.gfpgan_file:
                thread_data.gfpgan_file = model_path
                load_model_gfpgan()
            elif not thread_data.model_gfpgan:
                load_model_gfpgan()
            if thread_data.model_gfpgan is None:
                raise Exception('Model "gfpgan" not loaded.')

            print('enhance with', thread_data.gfpgan_file, 'on', thread_data.model_gfpgan.device, 'precision', thread_data.precision)
            _, _, output = thread_data.model_gfpgan.enhance(image_data[:, :, ::-1], has_aligned=False, only_center_face=False, paste_back=True)
            image_data = output[:, :, ::-1]

    elif filter_name == 'real_esrgan':
        if model_path is not None and model_path != thread_data.real_esrgan_file:
            thread_data.real_esrgan_file = model_path
            load_model_real_esrgan()
        elif not thread_data.model_real_esrgan:
            load_model_real_esrgan()
        if thread_data.model_real_esrgan is None:
            # Previously (wrongly) reported as 'Model "gfpgan" not loaded.'
            raise Exception('Model "real_esrgan" not loaded.')

        print('enhance with', thread_data.real_esrgan_file, 'on', thread_data.model_real_esrgan.device, 'precision', thread_data.precision)
        output, _ = thread_data.model_real_esrgan.enhance(image_data[:, :, ::-1])
        image_data = output[:, :, ::-1]

    return image_data
def is_model_reload_necessary(req: Request):
    """Decide whether the Stable Diffusion model must be (re)loaded for `req`.

    Side effects:
      * req.use_stable_diffusion_model gets '.ckpt' or '.safetensors'
        appended when such a file exists for the given base path.
      * thread_data.ckpt_file / vae_file are updated when they differ from
        the request (or when no model is loaded).
      * On non-CPU devices, thread_data.precision is switched when the
        requested precision differs from the current one.

    Raises:
        FileNotFoundError: if neither the path nor a .ckpt/.safetensors variant exists.
    """
    # custom model support: req.use_stable_diffusion_model is a path to the
    # checkpoint, normally given without its extension.
    model_base = req.use_stable_diffusion_model
    for extension in ('.ckpt', '.safetensors'):
        if os.path.exists(model_base + extension):
            req.use_stable_diffusion_model = model_base + extension
            break
    else:
        if not os.path.exists(model_base):
            raise FileNotFoundError(f'Cannot find {model_base}.ckpt or .safetensors')

    reload_needed = False

    model_changed = (
        not thread_data.model
        or thread_data.ckpt_file != req.use_stable_diffusion_model
        or thread_data.vae_file != req.use_vae_model
    )
    if model_changed:
        thread_data.ckpt_file = req.use_stable_diffusion_model
        thread_data.vae_file = req.use_vae_model
        reload_needed = True

    if thread_data.device == 'cpu':
        return reload_needed

    current_precision = thread_data.precision
    if current_precision == 'autocast':
        switch_precision = req.use_full_precision or not thread_data.model_is_half
    elif current_precision == 'full':
        switch_precision = not req.use_full_precision and not thread_data.force_full_precision
    else:
        switch_precision = False

    if switch_precision:
        thread_data.precision = 'full' if req.use_full_precision else 'autocast'
        reload_needed = True

    return reload_needed
def reload_model():
    """Unload the current SD models and filters, then load thread_data.ckpt_file afresh."""
    for step in (unload_models, unload_filters, load_model_ckpt):
        step()
def is_hypernetwork_reload_necessary(req: Request):
    """Return True, recording the new file name, when `req` asks for a different hypernetwork."""
    requested_file = req.use_hypernetwork_model
    if thread_data.hypernetwork_file == requested_file:
        return False
    thread_data.hypernetwork_file = requested_file
    return True
def load_hypernetwork():
    """Load thread_data.hypernetwork_file into thread_data.hypernetwork.

    Skipped in SD2 mode (not yet supported). Each extension in
    HYPERNETWORK_MODEL_EXTENSIONS is tried in order. If no file is found, or
    loading fails, the problem is printed and thread_data.hypernetwork_file
    is reset to None.
    """
    if thread_data.test_sd2:
        # Not yet supported in SD2
        return

    from . import hypernetwork
    if thread_data.hypernetwork_file is None:
        return

    try:
        loaded = False
        for model_extension in HYPERNETWORK_MODEL_EXTENSIONS:
            hypernetwork_path = thread_data.hypernetwork_file + model_extension
            if os.path.exists(hypernetwork_path):
                print(f"Loading hypernetwork weights from: {thread_data.hypernetwork_file}{model_extension}")
                thread_data.hypernetwork = hypernetwork.load_hypernetwork(hypernetwork_path)
                loaded = True
                break

        if not loaded:
            print(f'Cannot find hypernetwork: {thread_data.hypernetwork_file}')
            thread_data.hypernetwork_file = None
    except Exception:
        # Best effort: a bad hypernetwork file must not abort the render, but
        # KeyboardInterrupt/SystemExit (caught by the former bare except) must propagate.
        print(traceback.format_exc())
        print(f'Could not load hypernetwork: {thread_data.hypernetwork_file}')
        thread_data.hypernetwork_file = None
def unload_hypernetwork():
    """Release the loaded hypernetwork (moving its layers to the CPU first on GPU threads), then run gc()."""
    if thread_data.hypernetwork is not None:
        print('Unloading hypernetwork...')
        if thread_data.device != 'cpu':
            for key in thread_data.hypernetwork:
                layer_pair = thread_data.hypernetwork[key]
                layer_pair[0].to('cpu')
                layer_pair[1].to('cpu')
        del thread_data.hypernetwork
    thread_data.hypernetwork = None

    gc()
def reload_hypernetwork():
    """Drop the current hypernetwork and load thread_data.hypernetwork_file (if any)."""
    for step in (unload_hypernetwork, load_hypernetwork):
        step()
def mk_img(req: Request, data_queue: queue.Queue, task_temp_images: list, step_callback):
    """Render `req` via do_mk_img, reporting any failure on `data_queue`.

    On an exception this function does four things:
      1. Prints the traceback.
      2. Moves SD1 sub-models back to the CPU and releases memory (best effort).
      3. Queues a JSON {"status": "failed", "detail": str(error)} message.
      4. Re-raises the original exception.
    """
    try:
        return do_mk_img(req, data_queue, task_temp_images, step_callback)
    except Exception as e:
        print(traceback.format_exc())

        # The failure may have happened before the models were loaded (the slots are
        # None then). A cleanup error must neither hide the original exception nor
        # stop the "failed" status from reaching the client.
        try:
            if thread_data.device != 'cpu' and not thread_data.test_sd2:
                for sub_model in (thread_data.modelFS, thread_data.modelCS):
                    if sub_model is not None:
                        sub_model.to('cpu')
                if thread_data.model is not None:
                    thread_data.model.model1.to("cpu")
                    thread_data.model.model2.to("cpu")
            gc()  # Release from memory.
        except Exception:
            print(traceback.format_exc())

        data_queue.put(json.dumps({
            "status": 'failed',
            "detail": str(e)
        }))
        raise
def update_temp_img(req, x_samples, task_temp_images: list):
    """Decode in-progress latents into JPEG previews for the client.

    For each of the req.num_outputs samples, the latent is decoded with the
    first-stage model and (x + 1) / 2 is clamped to [0, 1]. The result is
    scaled to uint8 and encoded as JPEG. The buffer is stored in
    thread_data.temp_images under '<session_id>/<index>' and in
    task_temp_images[index].

    Returns:
        One {'path': '/image/tmp/<session_id>/<index>'} dict per output.
    """
    # SD2 decodes through the single model; SD1 through the split first-stage model.
    decoder = thread_data.model if thread_data.test_sd2 else thread_data.modelFS

    previews = []
    for idx in range(req.num_outputs):
        decoded = decoder.decode_first_stage(x_samples[idx].unsqueeze(0))
        pixels = torch.clamp((decoded + 1.0) / 2.0, min=0.0, max=1.0)
        pixels = (255.0 * rearrange(pixels[0].cpu().numpy(), "c h w -> h w c")).astype(np.uint8)
        preview = Image.fromarray(pixels)
        jpeg_buf = img_to_buffer(preview, output_format='JPEG')
        del preview, pixels, decoded
        # x_samples is still used by the sampler code that invoked this callback; keep it.

        thread_data.temp_images[str(req.session_id) + '/' + str(idx)] = jpeg_buf
        task_temp_images[idx] = jpeg_buf
        previews.append({'path': f'/image/tmp/{req.session_id}/{idx}'})
    return previews
# Build and return the appropriate progress callback for do_mk_img
def get_image_progress_generator(req, data_queue: queue.Queue, task_temp_images: list, step_callback, extra_props=None):
    """Create the per-step sampler callback used by do_mk_img.

    If req.stream_progress_updates is false, the callback simply returns its
    input. Otherwise each call does the following:
      * stores the latest latents in thread_data.partial_x_samples;
      * queues a JSON progress message {"step", "step_time", **extra_props},
        adding decoded previews under "output" on every 5th step when
        req.stream_image_progress is set;
      * calls step_callback();
      * raises UserInitiatedStop once thread_data.stop_processing is set.
    step_time is -1 on the first call.
    """
    if not req.stream_progress_updates:
        def empty_callback(x_samples, i):
            return x_samples
        return empty_callback

    thread_data.partial_x_samples = None
    previous_time = -1

    def img_callback(x_samples, i):
        nonlocal previous_time
        thread_data.partial_x_samples = x_samples

        if previous_time == -1:
            step_time = -1
        else:
            step_time = time.time() - previous_time
        previous_time = time.time()

        progress = {"step": i, "step_time": step_time}
        if extra_props is not None:
            progress.update(extra_props)
        if req.stream_image_progress and i % 5 == 0:
            progress['output'] = update_temp_img(req, x_samples, task_temp_images)
        data_queue.put(json.dumps(progress))

        step_callback()

        if thread_data.stop_processing:
            raise UserInitiatedStop("User requested that we stop processing")

    return img_callback
def do_mk_img(req: Request, data_queue: queue.Queue, task_temp_images: list, step_callback):
    """Run one render request end to end on this thread's device.

    The request is processed in these steps:
      1. Build the text conditioning, supporting weighted sub-prompts.
      2. Run the txt2img or img2img sampler.
      3. Decode the latents and apply the requested GFPGAN/RealESRGAN filters.
      4. Optionally save images and metadata under req.save_to_disk_path.
      5. Queue the final JSON response on `data_queue`.

    If the user stops the render (UserInitiatedStop from the progress
    callback), the partial latents captured by progress streaming are used
    when available. Otherwise that prompt batch is skipped.

    Returns:
        The response as produced by `Response.json()`.
    """
    thread_data.stop_processing = False

    res = Response()
    res.request = req
    res.images = []
    thread_data.hypernetwork_strength = req.hypernetwork_strength

    thread_data.temp_images.clear()

    if thread_data.turbo != req.turbo and not thread_data.test_sd2:
        thread_data.turbo = req.turbo
        thread_data.model.turbo = req.turbo

    # Start by cleaning memory, loading and unloading things can leave memory allocated.
    gc()

    opt_prompt = req.prompt
    opt_seed = req.seed
    opt_n_iter = 1
    opt_C = 4
    opt_f = 8
    opt_ddim_eta = 0.0

    print(req, '\n device', torch.device(thread_data.device), "as", thread_data.device_name)
    print('\n\n Using precision:', thread_data.precision)

    seed_everything(opt_seed)

    batch_size = req.num_outputs
    prompt = opt_prompt
    assert prompt is not None
    data = [batch_size * [prompt]]

    if thread_data.precision == "autocast" and thread_data.device != "cpu":
        precision_scope = autocast
    else:
        precision_scope = nullcontext

    mask = None

    if req.init_image is None:
        handler = _txt2img
        init_latent = None
        t_enc = None
    else:
        handler = _img2img

        init_image = load_img(req.init_image, req.width, req.height)
        init_image = init_image.to(thread_data.device)

        if thread_data.device != "cpu" and thread_data.precision == "autocast":
            init_image = init_image.half()

        if not thread_data.test_sd2:
            thread_data.modelFS.to(thread_data.device)

        init_image = repeat(init_image, '1 ... -> b ...', b=batch_size)
        if thread_data.test_sd2:
            init_latent = thread_data.model.get_first_stage_encoding(thread_data.model.encode_first_stage(init_image))  # move to latent space
        else:
            init_latent = thread_data.modelFS.get_first_stage_encoding(thread_data.modelFS.encode_first_stage(init_image))  # move to latent space

        if req.mask is not None:
            mask = load_mask(req.mask, req.width, req.height, init_latent.shape[2], init_latent.shape[3], True).to(thread_data.device)
            mask = mask[0][0].unsqueeze(0).repeat(4, 1, 1).unsqueeze(0)
            mask = repeat(mask, '1 ... -> b ...', b=batch_size)

            if thread_data.device != "cpu" and thread_data.precision == "autocast":
                mask = mask.half()

        # Send to CPU and wait until complete.
        if not thread_data.test_sd2:
            move_to_cpu(thread_data.modelFS)

        assert 0. <= req.prompt_strength <= 1., 'can only work with strength in [0.0, 1.0]'
        t_enc = int(req.prompt_strength * req.num_inference_steps)
        print(f"target t_enc is {t_enc} steps")

    # Called for its side effect of creating the session directory.
    if req.save_to_disk_path is not None:
        session_out_path = get_session_out_path(req.save_to_disk_path, req.session_id)
    else:
        session_out_path = None

    with torch.no_grad():
        for _ in trange(opt_n_iter, desc="Sampling"):
            for prompts in tqdm(data, desc="data"):

                with precision_scope("cuda"):
                    if thread_data.reduced_memory and not thread_data.test_sd2:
                        thread_data.modelCS.to(thread_data.device)
                    # SD2 keeps the text encoder inside the single model; SD1 uses modelCS.
                    cond_model = thread_data.model if thread_data.test_sd2 else thread_data.modelCS

                    uc = None
                    if req.guidance_scale != 1.0:
                        uc = cond_model.get_learned_conditioning(batch_size * [req.negative_prompt])
                    if isinstance(prompts, tuple):
                        prompts = list(prompts)

                    subprompts, weights = split_weighted_subprompts(prompts[0])
                    if len(subprompts) > 1:
                        c = None
                        totalWeight = sum(weights)
                        # normalize each "sub prompt" and add it
                        for subprompt, weight in zip(subprompts, weights):
                            sub_c = cond_model.get_learned_conditioning(subprompt)
                            if c is None:
                                # uc is None when guidance_scale == 1.0, and
                                # torch.zeros_like(None) used to raise TypeError here.
                                # Fall back to a batch_size-sized zero tensor shaped
                                # like the sub-prompt conditioning.
                                if uc is not None:
                                    c = torch.zeros_like(uc)
                                else:
                                    c = sub_c.new_zeros((batch_size, *sub_c.shape[1:]))
                            c = torch.add(c, sub_c, alpha=weight / totalWeight)
                    else:
                        c = cond_model.get_learned_conditioning(prompts)

                    if thread_data.reduced_memory and not thread_data.test_sd2:
                        thread_data.modelFS.to(thread_data.device)

                    n_steps = req.num_inference_steps if req.init_image is None else t_enc
                    img_callback = get_image_progress_generator(req, data_queue, task_temp_images, step_callback, {"total_steps": n_steps})

                    # run the handler
                    try:
                        print('Running handler...')
                        if handler == _txt2img:
                            x_samples = _txt2img(req.width, req.height, req.num_outputs, req.num_inference_steps, req.guidance_scale, None, opt_C, opt_f, opt_ddim_eta, c, uc, opt_seed, img_callback, mask, req.sampler)
                        else:
                            x_samples = _img2img(init_latent, t_enc, batch_size, req.guidance_scale, c, uc, req.num_inference_steps, opt_ddim_eta, opt_seed, img_callback, mask, opt_C, req.height, req.width, opt_f)
                    except UserInitiatedStop:
                        # Keep whatever the progress callback captured, if anything.
                        if not hasattr(thread_data, 'partial_x_samples'):
                            continue
                        if thread_data.partial_x_samples is None:
                            del thread_data.partial_x_samples
                            continue

                        x_samples = thread_data.partial_x_samples
                        del thread_data.partial_x_samples

                    print("decoding images")
                    img_data = [None] * batch_size
                    for i in range(batch_size):
                        if thread_data.test_sd2:
                            x_samples_ddim = thread_data.model.decode_first_stage(x_samples[i].unsqueeze(0))
                        else:
                            x_samples_ddim = thread_data.modelFS.decode_first_stage(x_samples[i].unsqueeze(0))
                        x_sample = torch.clamp((x_samples_ddim + 1.0) / 2.0, min=0.0, max=1.0)
                        x_sample = 255.0 * rearrange(x_sample[0].cpu().numpy(), "c h w -> h w c")
                        x_sample = x_sample.astype(np.uint8)
                        img_data[i] = x_sample
                    del x_samples, x_samples_ddim, x_sample

                    print("saving images")
                    for i in range(batch_size):
                        img = Image.fromarray(img_data[i])
                        img_id = base64.b64encode(int(time.time() + i).to_bytes(8, 'big')).decode()  # Generate unique ID based on time.
                        img_id = img_id.translate({43: None, 47: None, 61: None})[-8:]  # Remove + / = and keep last 8 chars.

                        has_filters = (req.use_face_correction is not None and req.use_face_correction.startswith('GFPGAN')) or \
                                      (req.use_upscale is not None and req.use_upscale.startswith('RealESRGAN'))

                        return_orig_img = not has_filters or not req.show_only_filtered_image

                        if thread_data.stop_processing:
                            return_orig_img = True

                        if req.save_to_disk_path is not None:
                            if return_orig_img:
                                img_out_path = get_base_path(req.save_to_disk_path, req.session_id, prompts[0], img_id, req.output_format)
                                save_image(img, img_out_path, req.output_format, req.output_quality)
                            meta_out_path = get_base_path(req.save_to_disk_path, req.session_id, prompts[0], img_id, 'txt')
                            save_metadata(meta_out_path, req, prompts[0], opt_seed)

                        if return_orig_img:
                            img_buffer = img_to_buffer(img, req.output_format, req.output_quality)
                            img_str = buffer_to_base64_str(img_buffer, req.output_format)
                            res_image_orig = ResponseImage(data=img_str, seed=opt_seed)
                            res.images.append(res_image_orig)
                            task_temp_images[i] = img_buffer

                            if req.save_to_disk_path is not None:
                                res_image_orig.path_abs = img_out_path
                        del img

                        if has_filters and not thread_data.stop_processing:
                            filters_applied = []
                            if req.use_face_correction:
                                img_data[i] = apply_filters('gfpgan', img_data[i], req.use_face_correction)
                                filters_applied.append(req.use_face_correction)
                            if req.use_upscale:
                                img_data[i] = apply_filters('real_esrgan', img_data[i], req.use_upscale)
                                filters_applied.append(req.use_upscale)
                            if len(filters_applied) > 0:
                                filtered_image = Image.fromarray(img_data[i])
                                filtered_buffer = img_to_buffer(filtered_image, req.output_format, req.output_quality)
                                filtered_img_data = buffer_to_base64_str(filtered_buffer, req.output_format)
                                response_image = ResponseImage(data=filtered_img_data, seed=opt_seed)
                                res.images.append(response_image)
                                task_temp_images[i] = filtered_buffer
                                if req.save_to_disk_path is not None:
                                    filtered_img_out_path = get_base_path(req.save_to_disk_path, req.session_id, prompts[0], img_id, req.output_format, "_".join(filters_applied))
                                    save_image(filtered_image, filtered_img_out_path, req.output_format, req.output_quality)
                                    response_image.path_abs = filtered_img_out_path
                                del filtered_image
                        # Filter Applied, move to next seed
                        opt_seed += 1

                    if not thread_data.test_sd2:
                        move_to_cpu(thread_data.modelFS)
                    del img_data
                    gc()
                    if thread_data.device != 'cpu':
                        print(f'memory_final = {round(torch.cuda.memory_allocated(thread_data.device) / 1e6, 2)}Mb')

    print('Task completed')
    res = res.json()
    data_queue.put(json.dumps(res))

    return res
def save_image(img, img_out_path, output_format="", output_quality=75):
    """Save `img` to `img_out_path`, passing `output_quality` only for JPEG.

    Best effort: a failed save is reported on stdout and otherwise ignored, so
    it does not abort the render. KeyboardInterrupt/SystemExit (which the
    former bare `except:` swallowed) now propagate.
    """
    try:
        if output_format.upper() == "JPEG":
            img.save(img_out_path, quality=output_quality)
        else:
            img.save(img_out_path)
    except Exception:
        print('could not save the file', traceback.format_exc())
def save_metadata(meta_out_path, req, prompt, opt_seed):
    """Write a plain-text summary of the render settings next to an image.

    The first line is the prompt; each following line is 'Label: value'.
    Write errors are printed and otherwise ignored (best effort).
    """
    # is_model_reload_necessary() may already have appended the real extension
    # (.ckpt or .safetensors) to the request. Only add '.ckpt' when no known
    # extension is present, instead of producing 'x.ckpt.ckpt' / 'x.safetensors.ckpt'.
    sd_model = req.use_stable_diffusion_model
    if isinstance(sd_model, str) and not sd_model.endswith(('.ckpt', '.safetensors')):
        sd_model += '.ckpt'

    lines = [
        f'{prompt}',
        f'Width: {req.width}',
        f'Height: {req.height}',
        f'Seed: {opt_seed}',
        f'Steps: {req.num_inference_steps}',
        f'Guidance Scale: {req.guidance_scale}',
        f'Prompt Strength: {req.prompt_strength}',
        f'Use Face Correction: {req.use_face_correction}',
        f'Use Upscaling: {req.use_upscale}',
        f'Sampler: {req.sampler}',
        f'Negative Prompt: {req.negative_prompt}',
        f'Stable Diffusion model: {sd_model}',
        f'VAE model: {req.use_vae_model}',
        f'Hypernetwork Model: {req.use_hypernetwork_model}',
        f'Hypernetwork Strength: {req.hypernetwork_strength}',
    ]
    metadata = '\n'.join(lines) + '\n'

    try:
        with open(meta_out_path, 'w', encoding='utf-8') as f:
            f.write(metadata)
    except Exception:
        print('could not save the file', traceback.format_exc())
def _txt2img(opt_W, opt_H, opt_n_samples, opt_ddim_steps, opt_scale, start_code, opt_C, opt_f, opt_ddim_eta, c, uc, opt_seed, img_callback, mask, sampler_name):
    """Sample `opt_n_samples` latents for text-to-image and return them (not yet decoded).

    In SD1 mode, modelCS is first moved back to the CPU, and sampling is
    delegated to thread_data.model.sample (with a DDIM schedule set up for
    'ddim'). In SD2 mode, a PLMS, DDIM or DPM-Solver sampler is built around
    thread_data.model.

    Raises:
        Exception: in SD2 mode, for any sampler other than 'plms', 'ddim' or 'dpm2'.
    """
    shape = [opt_n_samples, opt_C, opt_H // opt_f, opt_W // opt_f]

    # Send to CPU and wait until complete.
    if not thread_data.test_sd2:
        move_to_cpu(thread_data.modelCS)

    if thread_data.test_sd2 and sampler_name not in ('plms', 'ddim', 'dpm2'):
        # The message used to omit 'dpm2', which is accepted just above.
        raise Exception('Only plms, ddim and dpm2 samplers are supported right now, in SD 2.0')

    if thread_data.test_sd2:
        if sampler_name == 'plms':
            from ldm.models.diffusion.plms import PLMSSampler
            sampler = PLMSSampler(thread_data.model)
        elif sampler_name == 'ddim':
            from ldm.models.diffusion.ddim import DDIMSampler
            sampler = DDIMSampler(thread_data.model)
            sampler.make_schedule(ddim_num_steps=opt_ddim_steps, ddim_eta=opt_ddim_eta, verbose=False)
        elif sampler_name == 'dpm2':
            from ldm.models.diffusion.dpm_solver import DPMSolverSampler
            sampler = DPMSolverSampler(thread_data.model)

        # The SD2 samplers take the per-sample shape; the batch size is passed separately.
        shape = [opt_C, opt_H // opt_f, opt_W // opt_f]

        samples_ddim, _ = sampler.sample(
            S=opt_ddim_steps,
            conditioning=c,
            batch_size=opt_n_samples,
            seed=opt_seed,
            shape=shape,
            verbose=False,
            unconditional_guidance_scale=opt_scale,
            unconditional_conditioning=uc,
            eta=opt_ddim_eta,
            x_T=start_code,
            img_callback=img_callback,
            mask=mask,
            sampler=sampler_name,
        )
    else:
        if sampler_name == 'ddim':
            thread_data.model.make_schedule(ddim_num_steps=opt_ddim_steps, ddim_eta=opt_ddim_eta, verbose=False)

        samples_ddim = thread_data.model.sample(
            S=opt_ddim_steps,
            conditioning=c,
            seed=opt_seed,
            shape=shape,
            verbose=False,
            unconditional_guidance_scale=opt_scale,
            unconditional_conditioning=uc,
            eta=opt_ddim_eta,
            x_T=start_code,
            img_callback=img_callback,
            mask=mask,
            sampler=sampler_name,
        )

    return samples_ddim
def _img2img(init_latent, t_enc, batch_size, opt_scale, c, uc, opt_ddim_steps, opt_ddim_eta, opt_seed, img_callback, mask, opt_C=1, opt_H=1, opt_W=1, opt_f=1):
    """Run image-to-image sampling starting from ``init_latent``.

    The latent is first passed through ``stochastic_encode`` at timestep
    ``t_enc`` (one entry per batch item), then decoded/sampled from there.

    Args:
        init_latent: starting latent; also used as ``x_T`` when ``mask`` is given
            (non-SD 2.0 path only).
        t_enc: timestep passed to ``stochastic_encode`` and to the decode/sample call.
        batch_size: number of copies of ``t_enc`` in the timestep tensor.
        opt_scale: unconditional guidance scale.
        c, uc: conditional and unconditional conditioning.
        opt_ddim_steps, opt_ddim_eta: DDIM schedule settings.
        opt_seed: forwarded to ``stochastic_encode`` on the non-SD 2.0 path.
        img_callback: per-step callback forwarded to the sampler.
        mask: optional mask forwarded to ``sample`` on the non-SD 2.0 path.
        opt_C, opt_H, opt_W, opt_f: accepted but not used here.

    Returns:
        The sampled latents.
    """
    start_latent = init_latent if mask is not None else None

    if not thread_data.test_sd2:
        timesteps = torch.tensor([t_enc] * batch_size).to(thread_data.device)
        z_enc = thread_data.model.stochastic_encode(
            init_latent,
            timesteps,
            opt_seed,
            opt_ddim_eta,
            opt_ddim_steps,
        )
        # decode it
        return thread_data.model.sample(
            t_enc,
            c,
            z_enc,
            unconditional_guidance_scale=opt_scale,
            unconditional_conditioning=uc,
            img_callback=img_callback,
            mask=mask,
            x_T=start_latent,
            sampler='ddim',
        )

    from ldm.models.diffusion.ddim import DDIMSampler
    ddim = DDIMSampler(thread_data.model)
    ddim.make_schedule(ddim_num_steps=opt_ddim_steps, ddim_eta=opt_ddim_eta, verbose=False)
    timesteps = torch.tensor([t_enc] * batch_size).to(thread_data.device)
    z_enc = ddim.stochastic_encode(init_latent, timesteps)
    # NOTE(review): this SD 2.0 path does not forward mask, opt_seed or x_T -- confirm
    # whether inpainting is meant to be unsupported here.
    return ddim.decode(
        z_enc,
        c,
        t_enc,
        unconditional_guidance_scale=opt_scale,
        unconditional_conditioning=uc,
        img_callback=img_callback,
    )
def gc():
    """Release memory.

    Always runs Python's garbage collector. Unless ``thread_data.device`` is
    'cpu', it also calls ``torch.cuda.empty_cache()`` and ``torch.cuda.ipc_collect()``.
    """
    gc_collect()
    if thread_data.device != 'cpu':
        torch.cuda.empty_cache()
        torch.cuda.ipc_collect()
# internal
def chunk(it, size):
    """Return an iterator over consecutive tuples of up to ``size`` items from ``it``.

    The final tuple may be shorter. Iteration stops at the first empty tuple,
    so a ``size`` of 0 yields nothing.
    """
    source = iter(it)  # resolved eagerly so a non-iterable argument fails at call time

    def _pieces():
        while True:
            piece = tuple(islice(source, size))
            if not piece:
                return
            yield piece

    return _pieces()
def load_model_from_config(ckpt, verbose=False):
    """Load checkpoint weights from ``ckpt`` onto the CPU.

    ``.safetensors`` files are read with safetensors' ``load_file``; any other
    path goes through ``torch.load``. If the loaded mapping has a
    ``"state_dict"`` entry, that entry is returned; otherwise the mapping
    itself is returned. ``verbose`` is accepted but unused.
    """
    print(f"Loading model from {ckpt}")
    if not ckpt.endswith(".safetensors"):
        # NOTE(review): torch.load unpickles the file -- only load trusted checkpoints.
        checkpoint = torch.load(ckpt, map_location="cpu")
    else:
        print("Loading from safetensors")
        checkpoint = load_file(ckpt, device="cpu")

    if "global_step" in checkpoint:
        print(f"Global Step: {checkpoint['global_step']}")

    # Some checkpoints nest the weights under "state_dict"; others are the weights.
    return checkpoint["state_dict"] if "state_dict" in checkpoint else checkpoint
class UserInitiatedStop(Exception):
    """Exception type for a user-initiated stop; adds nothing beyond ``Exception``."""
def load_img(img_str, w0, h0):
    """Decode a base64 image into a ``(1, 3, H, W)`` float32 tensor in ``[-1, 1]``.

    The target size is ``(w0, h0)`` when both are given, otherwise the image's
    own size. Each side is rounded down to a multiple of 64 before a LANCZOS resize.
    """
    image = base64_str_to_img(img_str).convert("RGB")
    src_w, src_h = image.size
    print(f"loaded input image of size ({src_w}, {src_h}) from base64")

    if h0 is not None and w0 is not None:
        target_w, target_h = w0, h0
    else:
        target_w, target_h = src_w, src_h
    # resize to integer multiple of 64
    target_w -= target_w % 64
    target_h -= target_h % 64

    image = image.resize((target_w, target_h), resample=Image.Resampling.LANCZOS)
    pixels = np.array(image).astype(np.float32) / 255.0  # (H, W, 3) in [0, 1]
    tensor = torch.from_numpy(pixels[None].transpose(0, 3, 1, 2))  # -> (1, 3, H, W)
    return 2. * tensor - 1.
def load_mask(mask_str, h0, w0, newH, newW, invert=False):
    """Decode a base64 mask image into a ``(1, 3, newH, newW)`` float32 tensor in ``[0, 1]``.

    Args:
        mask_str: base64 (data URL) encoded mask image.
        h0, w0: accepted for interface compatibility but not used. The mask is
            always resized to ``(newW, newH)``.
        newH, newW: output height and width.
        invert: when true, the mask colours are inverted before resizing.
    """
    image = base64_str_to_img(mask_str).convert("RGB")
    w, h = image.size
    print(f"loaded input mask of size ({w}, {h})")

    if invert:
        print("inverted")
        image = ImageOps.invert(image)

    # Log the size actually used for the resize. A 64-aligned size derived from
    # h0/w0 used to be printed here, although it never affected the output.
    print(f"New mask size ({newW}, {newH})")
    image = image.resize((newW, newH), resample=Image.Resampling.LANCZOS)
    mask = np.array(image).astype(np.float32) / 255.0  # (newH, newW, 3)
    mask = mask[None].transpose(0, 3, 1, 2)  # -> (1, 3, newH, newW)
    return torch.from_numpy(mask)
# https://stackoverflow.com/a/61114178
def img_to_base64_str(img, output_format="PNG", output_quality=75):
    """Encode ``img`` as a ``data:image/...;base64,`` string.

    Args:
        img: object with ``save(fp, format=..., [quality=...])`` (presumably a PIL image).
        output_format: format name passed to ``img.save``. "png" (any case)
            gives an image/png MIME type; anything else gives image/jpeg.
        output_quality: forwarded as ``quality`` only for JPEG output.
    """
    # img_to_buffer names this parameter ``output_quality``.
    buffered = img_to_buffer(img, output_format, output_quality=output_quality)
    return buffer_to_base64_str(buffered, output_format)


def img_to_buffer(img, output_format="PNG", output_quality=75):
    """Save ``img`` into a new in-memory buffer, rewound to position 0.

    ``output_quality`` is passed to ``img.save`` as ``quality`` only when
    ``output_format`` is "JPEG" (case-insensitive).
    """
    buffered = BytesIO()
    if output_format.upper() == "JPEG":
        img.save(buffered, format=output_format, quality=output_quality)
    else:
        img.save(buffered, format=output_format)
    buffered.seek(0)
    return buffered


def buffer_to_base64_str(buffered, output_format="PNG"):
    """Return the full contents of ``buffered`` as a base64 data URL.

    The MIME type is image/png when ``output_format`` is "png" (any case),
    otherwise image/jpeg. The buffer is rewound to position 0.
    """
    buffered.seek(0)
    img_byte = buffered.getvalue()
    mime_type = "image/png" if output_format.lower() == "png" else "image/jpeg"
    return f"data:{mime_type};base64," + base64.b64encode(img_byte).decode()
def base64_str_to_buffer(img_str):
    """Decode a base64 image string into an in-memory ``BytesIO``.

    A leading ``data:<mime>;base64,`` header is stripped whatever the MIME type
    is (png, jpeg, jpg, webp, ...). A bare base64 payload is decoded as-is.
    """
    if img_str.startswith("data:"):
        # Everything up to and including the first ',' is the data-URL header.
        _, _, img_str = img_str.partition(",")
    data = base64.b64decode(img_str)
    return BytesIO(data)
def base64_str_to_img(img_str):
    """Decode a base64 image string and open it with PIL's ``Image.open``."""
    return Image.open(base64_str_to_buffer(img_str))
def split_weighted_subprompts(text):
    """Split a prompt into sub-prompts and weights using ``text:weight`` syntax.

    Repeatedly takes the text up to the next ':' as a sub-prompt, and the token
    after it (up to the next space, or the end) as that sub-prompt's weight.

    Rules:
        - A missing weight (':' followed directly by a space or by the end of
          the text) defaults to 1.0.
        - An unparseable weight prints a warning and defaults to 1.0.
        - Trailing text without a ':' becomes a final sub-prompt with weight 1.0.

    Example: ``"a cat:0.7 a dog"`` -> ``(["a cat", "a dog"], [0.7, 1.0])``.

    Returns:
        ``(prompts, weights)``: two parallel lists, both empty for ``""``.
    """
    prompts = []
    weights = []
    while text:
        prompt, colon, rest = text.partition(":")
        if not colon:
            # No further ':' -- the remainder is a sub-prompt with the default weight.
            prompts.append(text)
            weights.append(1.0)
            break
        # The weight runs from just after ':' to the next space (or the end of the text).
        weight_str, _, text = rest.partition(" ")
        weight = 1.0
        if weight_str:
            try:
                weight = float(weight_str)
            except ValueError:  # couldn't treat as float
                print(f"Warning: '{weight_str}' is not a value, are you missing a space?")
        prompts.append(prompt)
        weights.append(weight)
    return prompts, weights