diff --git a/ui/sd_internal/__init__.py b/ui/sd_internal/__init__.py
index 6f702749..3a748431 100644
--- a/ui/sd_internal/__init__.py
+++ b/ui/sd_internal/__init__.py
@@ -1,6 +1,6 @@
 from pydantic import BaseModel
 
-from diffusionkit.types import GenerateImageRequest
+from sdkit.types import GenerateImageRequest
 
 class TaskData(BaseModel):
     request_id: str = None
diff --git a/ui/sd_internal/model_manager.py b/ui/sd_internal/model_manager.py
index 30e97053..ec72e6fa 100644
--- a/ui/sd_internal/model_manager.py
+++ b/ui/sd_internal/model_manager.py
@@ -1,11 +1,10 @@
 import os
 import logging
 import picklescan.scanner
-import rich
 
 from sd_internal import app, TaskData, device_manager
-from diffusionkit import model_loader
-from diffusionkit.types import Context
+from sdkit.models import model_loader
+from sdkit.types import Context
 
 log = logging.getLogger()
 
diff --git a/ui/sd_internal/renderer.py b/ui/sd_internal/renderer.py
index d9ab3f0d..1d289c99 100644
--- a/ui/sd_internal/renderer.py
+++ b/ui/sd_internal/renderer.py
@@ -6,8 +6,8 @@
 import logging
 
 from sd_internal import device_manager, save_utils
 from sd_internal import TaskData, Response, Image as ResponseImage, UserInitiatedStop
-from diffusionkit import model_loader, image_generator, image_utils, filters as image_filters, data_utils
-from diffusionkit.types import Context, GenerateImageRequest, FilterImageRequest
+from sdkit import model_loader, image_generator, image_utils, filters as image_filters
+from sdkit.types import Context, GenerateImageRequest, FilterImageRequest
 
 log = logging.getLogger()
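Note on the rename so far: the three hunks above only change import paths — `diffusionkit.model_loader` becomes `sdkit.models.model_loader`, the top-level `diffusionkit` helpers become `sdkit`, and `diffusionkit.types` becomes `sdkit.types` (the unused `rich` and `data_utils` imports are dropped along the way). A minimal sketch of a compatibility shim for code that must run on both sides of the rename; the module paths are taken from the hunks above, but the try/except pattern itself is illustrative and not part of this diff:

    # Hedged sketch: prefer the new sdkit layout, fall back to a
    # pre-rename checkout that still ships diffusionkit.
    try:
        from sdkit.types import Context, GenerateImageRequest
    except ImportError:
        from diffusionkit.types import Context, GenerateImageRequest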
diff --git a/ui/sd_internal/runtime.py b/ui/sd_internal/runtime.py
deleted file mode 100644
index 2bf53d0c..00000000
--- a/ui/sd_internal/runtime.py
+++ /dev/null
@@ -1,1078 +0,0 @@
-"""runtime.py: torch device owned by a thread.
-Notes:
-    Avoid device switching, transfering all models will get too complex.
-    To use a diffrent device signal the current render device to exit
-    And then start a new clean thread for the new device.
-"""
-import json
-import os, re
-import traceback
-import queue
-import torch
-import numpy as np
-from gc import collect as gc_collect
-from omegaconf import OmegaConf
-from PIL import Image, ImageOps
-from tqdm import tqdm, trange
-from itertools import islice
-from einops import rearrange
-import time
-from pytorch_lightning import seed_everything
-from torch import autocast
-from contextlib import nullcontext
-from einops import rearrange, repeat
-from ldm.util import instantiate_from_config
-from transformers import logging
-
-from gfpgan import GFPGANer
-from basicsr.archs.rrdbnet_arch import RRDBNet
-from realesrgan import RealESRGANer
-
-from server import HYPERNETWORK_MODEL_EXTENSIONS# , STABLE_DIFFUSION_MODEL_EXTENSIONS, VAE_MODEL_EXTENSIONS
-
-from threading import Lock
-from safetensors.torch import load_file
-
-import uuid
-
-logging.set_verbosity_error()
-
-# consts
-config_yaml = "optimizedSD/v1-inference.yaml"
-filename_regex = re.compile('[^a-zA-Z0-9]')
-gfpgan_temp_device_lock = Lock() # workaround: gfpgan currently can only start on one device at a time.
-
-# api stuff
-from sd_internal import device_manager
-from . import Request, Response, Image as ResponseImage
-import base64
-from io import BytesIO
-#from colorama import Fore
-
-from threading import local as LocalThreadVars
-thread_data = LocalThreadVars()
-
-def thread_init(device):
-    # Thread bound properties
-    thread_data.stop_processing = False
-    thread_data.temp_images = {}
-
-    thread_data.ckpt_file = None
-    thread_data.vae_file = None
-    thread_data.hypernetwork_file = None
-    thread_data.gfpgan_file = None
-    thread_data.real_esrgan_file = None
-
-    thread_data.model = None
-    thread_data.modelCS = None
-    thread_data.modelFS = None
-    thread_data.hypernetwork = None
-    thread_data.hypernetwork_strength = 1
-    thread_data.model_gfpgan = None
-    thread_data.model_real_esrgan = None
-
-    thread_data.model_is_half = False
-    thread_data.model_fs_is_half = False
-    thread_data.device = None
-    thread_data.device_name = None
-    thread_data.unet_bs = 1
-    thread_data.precision = 'autocast'
-    thread_data.sampler_plms = None
-    thread_data.sampler_ddim = None
-
-    thread_data.turbo = False
-    thread_data.force_full_precision = False
-    thread_data.reduced_memory = True
-
-    thread_data.test_sd2 = isSD2()
-
-    device_manager.device_init(thread_data, device)
-
-# temp hack, will remove soon
-def isSD2():
-    try:
-        SD_UI_DIR = os.getenv('SD_UI_PATH', None)
-        CONFIG_DIR = os.path.abspath(os.path.join(SD_UI_DIR, '..', 'scripts'))
-        config_json_path = os.path.join(CONFIG_DIR, 'config.json')
-        if not os.path.exists(config_json_path):
-            return False
-        with open(config_json_path, 'r', encoding='utf-8') as f:
-            config = json.load(f)
-            return config.get('test_sd2', False)
-    except Exception as e:
-        return False
-
-def load_model_ckpt():
-    if not thread_data.ckpt_file: raise ValueError(f'Thread ckpt_file is undefined.')
-    if os.path.exists(thread_data.ckpt_file + '.ckpt'):
-        thread_data.ckpt_file += '.ckpt'
-    elif os.path.exists(thread_data.ckpt_file + '.safetensors'):
-        thread_data.ckpt_file += '.safetensors'
-    elif not os.path.exists(thread_data.ckpt_file):
-        raise FileNotFoundError(f'Cannot find {thread_data.ckpt_file}.ckpt or .safetensors')
-
-    if not thread_data.precision:
-        thread_data.precision = 'full' if thread_data.force_full_precision else 'autocast'
-
-    if not thread_data.unet_bs:
-        thread_data.unet_bs = 1
-
-    if thread_data.device == 'cpu':
-        thread_data.precision = 'full'
-
-    print('loading', thread_data.ckpt_file, 'to device', thread_data.device, 'using precision', thread_data.precision)
-
-    if thread_data.test_sd2:
-        load_model_ckpt_sd2()
-    else:
-        load_model_ckpt_sd1()
-
-def load_model_ckpt_sd1():
-    sd, model_ver = load_model_from_config(thread_data.ckpt_file)
-    li, lo = [], []
-    for key, value in sd.items():
-        sp = key.split(".")
-        if (sp[0]) == "model":
-            if "input_blocks" in sp:
-                li.append(key)
-            elif "middle_block" in sp:
-                li.append(key)
-            elif "time_embed" in sp:
-                li.append(key)
-            else:
-                lo.append(key)
-    for key in li:
-        sd["model1." + key[6:]] = sd.pop(key)
-    for key in lo:
-        sd["model2." + key[6:]] = sd.pop(key)
-
-    config = OmegaConf.load(f"{config_yaml}")
-
-    model = instantiate_from_config(config.modelUNet)
-    _, _ = model.load_state_dict(sd, strict=False)
-    model.eval()
-    model.cdevice = torch.device(thread_data.device)
-    model.unet_bs = thread_data.unet_bs
-    model.turbo = thread_data.turbo
-    # if thread_data.device != 'cpu':
-    #     model.to(thread_data.device)
-    #if thread_data.reduced_memory:
-        #model.model1.to("cpu")
-        #model.model2.to("cpu")
-    thread_data.model = model
-
-    modelCS = instantiate_from_config(config.modelCondStage)
-    _, _ = modelCS.load_state_dict(sd, strict=False)
-    modelCS.eval()
-    modelCS.cond_stage_model.device = torch.device(thread_data.device)
-    # if thread_data.device != 'cpu':
-    #     if thread_data.reduced_memory:
-    #         modelCS.to('cpu')
-    #     else:
-    #         modelCS.to(thread_data.device) # Preload on device if not already there.
-    thread_data.modelCS = modelCS
-
-    modelFS = instantiate_from_config(config.modelFirstStage)
-    _, _ = modelFS.load_state_dict(sd, strict=False)
-
-    if thread_data.vae_file is not None:
-        try:
-            loaded = False
-            for model_extension in ['.ckpt', '.vae.pt']:
-                if os.path.exists(thread_data.vae_file + model_extension):
-                    print(f"Loading VAE weights from: {thread_data.vae_file}{model_extension}")
-                    vae_ckpt = torch.load(thread_data.vae_file + model_extension, map_location="cpu")
-                    vae_dict = {k: v for k, v in vae_ckpt["state_dict"].items() if k[0:4] != "loss"}
-                    modelFS.first_stage_model.load_state_dict(vae_dict, strict=False)
-                    loaded = True
-                    break
-
-            if not loaded:
-                print(f'Cannot find VAE: {thread_data.vae_file}')
-                thread_data.vae_file = None
-        except:
-            print(traceback.format_exc())
-            print(f'Could not load VAE: {thread_data.vae_file}')
-            thread_data.vae_file = None
-
-    modelFS.eval()
-    # if thread_data.device != 'cpu':
-    #     if thread_data.reduced_memory:
-    #         modelFS.to('cpu')
-    #     else:
-    #         modelFS.to(thread_data.device) # Preload on device if not already there.
-    thread_data.modelFS = modelFS
-    del sd
-
-    if thread_data.device != "cpu" and thread_data.precision == "autocast":
-        thread_data.model.half()
-        thread_data.modelCS.half()
-        thread_data.modelFS.half()
-        thread_data.model_is_half = True
-        thread_data.model_fs_is_half = True
-    else:
-        thread_data.model_is_half = False
-        thread_data.model_fs_is_half = False
-
-    print(f'''loaded model
- model file: {thread_data.ckpt_file}
- model.device: {model.device}
- modelCS.device: {modelCS.cond_stage_model.device}
- modelFS.device: {thread_data.modelFS.device}
- using precision: {thread_data.precision}''')
-
-def load_model_ckpt_sd2():
-    sd, model_ver = load_model_from_config(thread_data.ckpt_file)
-
-    config_file = 'configs/stable-diffusion/v2-inference-v.yaml' if model_ver == 'sd2' else "configs/stable-diffusion/v1-inference.yaml"
-    config = OmegaConf.load(config_file)
-    verbose = False
-
-    thread_data.model = instantiate_from_config(config.model)
-    m, u = thread_data.model.load_state_dict(sd, strict=False)
-    if len(m) > 0 and verbose:
-        print("missing keys:")
-        print(m)
-    if len(u) > 0 and verbose:
-        print("unexpected keys:")
-        print(u)
-
-    thread_data.model.to(thread_data.device)
-    thread_data.model.eval()
-    del sd
-
-    thread_data.model.cond_stage_model.device = torch.device(thread_data.device)
-
-    if thread_data.device != "cpu" and thread_data.precision == "autocast":
-        thread_data.model.half()
-        thread_data.model_is_half = True
-        thread_data.model_fs_is_half = True
-    else:
-        thread_data.model_is_half = False
-        thread_data.model_fs_is_half = False
-
-    print(f'''loaded model
- model file: {thread_data.ckpt_file}
- using precision: {thread_data.precision}''')
-
-def unload_filters():
-    if thread_data.model_gfpgan is not None:
-        if thread_data.device != 'cpu': thread_data.model_gfpgan.gfpgan.to('cpu')
-
-        del thread_data.model_gfpgan
-        thread_data.model_gfpgan = None
-
-    if thread_data.model_real_esrgan is not None:
-        if thread_data.device != 'cpu': thread_data.model_real_esrgan.model.to('cpu')
-
-        del thread_data.model_real_esrgan
-        thread_data.model_real_esrgan = None
-
-    gc()
-
-def unload_models():
-    if thread_data.model is not None:
-        print('Unloading models...')
-        if thread_data.device != 'cpu':
-            if not thread_data.test_sd2:
-                thread_data.modelFS.to('cpu')
-                thread_data.modelCS.to('cpu')
-                thread_data.model.model1.to("cpu")
-                thread_data.model.model2.to("cpu")
-
-        del thread_data.model
-        del thread_data.modelCS
-        del thread_data.modelFS
-
-    thread_data.model = None
-    thread_data.modelCS = None
-    thread_data.modelFS = None
-
-    gc()
-
-# def wait_model_move_to(model, target_device): # Send to target_device and wait until complete.
-#     if thread_data.device == target_device: return
-#     start_mem = torch.cuda.memory_allocated(thread_data.device) / 1e6
-#     if start_mem <= 0: return
-#     model_name = model.__class__.__name__
-#     print(f'Device {thread_data.device} - Sending model {model_name} to {target_device} | Memory transfer starting. Memory Used: {round(start_mem)}Mb')
-#     start_time = time.time()
-#     model.to(target_device)
-#     time_step = start_time
-#     WARNING_TIMEOUT = 1.5 # seconds - Show activity in console after timeout.
-#     last_mem = start_mem
-#     is_transfering = True
-#     while is_transfering:
-#         time.sleep(0.5) # 500ms
-#         mem = torch.cuda.memory_allocated(thread_data.device) / 1e6
-#         is_transfering = bool(mem > 0 and mem < last_mem) # still stuff loaded, but less than last time.
-#         last_mem = mem
-#         if not is_transfering:
-#             break;
-#         if time.time() - time_step > WARNING_TIMEOUT: # Long delay, print to console to show activity.
-#             print(f'Device {thread_data.device} - Waiting for Memory transfer. Memory Used: {round(mem)}Mb, Transfered: {round(start_mem - mem)}Mb')
-#             time_step = time.time()
-#     print(f'Device {thread_data.device} - {model_name} Moved: {round(start_mem - last_mem)}Mb in {round(time.time() - start_time, 3)} seconds to {target_device}')
-
-def move_to_cpu(model):
-    if thread_data.device != "cpu":
-        d = torch.device(thread_data.device)
-        mem = torch.cuda.memory_allocated(d) / 1e6
-        model.to("cpu")
-        while torch.cuda.memory_allocated(d) / 1e6 >= mem:
-            time.sleep(1)
-
-def load_model_gfpgan():
-    if thread_data.gfpgan_file is None: raise ValueError(f'Thread gfpgan_file is undefined.')
-    model_path = thread_data.gfpgan_file + ".pth"
-    thread_data.model_gfpgan = GFPGANer(device=torch.device(thread_data.device), model_path=model_path, upscale=1, arch='clean', channel_multiplier=2, bg_upsampler=None)
-    print('loaded', thread_data.gfpgan_file, 'to', thread_data.model_gfpgan.device, 'precision', thread_data.precision)
-
-def load_model_real_esrgan():
-    if thread_data.real_esrgan_file is None: raise ValueError(f'Thread real_esrgan_file is undefined.')
-    model_path = thread_data.real_esrgan_file + ".pth"
-
-    RealESRGAN_models = {
-        'RealESRGAN_x4plus': RRDBNet(num_in_ch=3, num_out_ch=3, num_feat=64, num_block=23, num_grow_ch=32, scale=4),
-        'RealESRGAN_x4plus_anime_6B': RRDBNet(num_in_ch=3, num_out_ch=3, num_feat=64, num_block=6, num_grow_ch=32, scale=4)
-    }
-
-    model_to_use = RealESRGAN_models[thread_data.real_esrgan_file]
-
-    if thread_data.device == 'cpu':
-        thread_data.model_real_esrgan = RealESRGANer(device=torch.device(thread_data.device), scale=2, model_path=model_path, model=model_to_use, pre_pad=0, half=False) # cpu does not support half
-        #thread_data.model_real_esrgan.device = torch.device(thread_data.device)
-        thread_data.model_real_esrgan.model.to('cpu')
-    else:
-        thread_data.model_real_esrgan = RealESRGANer(device=torch.device(thread_data.device), scale=2, model_path=model_path, model=model_to_use, pre_pad=0, half=thread_data.model_is_half)
-
-    thread_data.model_real_esrgan.model.name = thread_data.real_esrgan_file
-    print('loaded ', thread_data.real_esrgan_file, 'to', thread_data.model_real_esrgan.device, 'precision', thread_data.precision)
-
-
-def get_session_out_path(disk_path, session_id):
-    if disk_path is None: return None
-    if session_id is None: return None
-
-    session_out_path = os.path.join(disk_path, filename_regex.sub('_',session_id))
-    os.makedirs(session_out_path, exist_ok=True)
-    return session_out_path
-
-def get_base_path(disk_path, session_id, prompt, img_id, ext, suffix=None):
-    if disk_path is None: return None
-    if session_id is None: return None
-    if ext is None: raise Exception('Missing ext')
-
-    session_out_path = get_session_out_path(disk_path, session_id)
-
-    prompt_flattened = filename_regex.sub('_', prompt)[:50]
-
-    if suffix is not None:
-        return os.path.join(session_out_path, f"{prompt_flattened}_{img_id}_{suffix}.{ext}")
-    return os.path.join(session_out_path, f"{prompt_flattened}_{img_id}.{ext}")
-
-def apply_filters(filter_name, image_data, model_path=None):
-    print(f'Applying filter {filter_name}...')
-    gc() # Free space before loading new data.
-
-    if isinstance(image_data, torch.Tensor):
-        image_data.to(thread_data.device)
-
-    if filter_name == 'gfpgan':
-        # This lock is only ever used here. No need to use timeout for the request. Should never deadlock.
-        with gfpgan_temp_device_lock: # Wait for any other devices to complete before starting.
-            # hack for a bug in facexlib: https://github.com/xinntao/facexlib/pull/19/files
-            from facexlib.detection import retinaface
-            retinaface.device = torch.device(thread_data.device)
-            print('forced retinaface.device to', thread_data.device)
-
-            if model_path is not None and model_path != thread_data.gfpgan_file:
-                thread_data.gfpgan_file = model_path
-                load_model_gfpgan()
-            elif not thread_data.model_gfpgan:
-                load_model_gfpgan()
-            if thread_data.model_gfpgan is None: raise Exception('Model "gfpgan" not loaded.')
-
-            print('enhance with', thread_data.gfpgan_file, 'on', thread_data.model_gfpgan.device, 'precision', thread_data.precision)
-            _, _, output = thread_data.model_gfpgan.enhance(image_data[:,:,::-1], has_aligned=False, only_center_face=False, paste_back=True)
-            image_data = output[:,:,::-1]
-
-    if filter_name == 'real_esrgan':
-        if model_path is not None and model_path != thread_data.real_esrgan_file:
-            thread_data.real_esrgan_file = model_path
-            load_model_real_esrgan()
-        elif not thread_data.model_real_esrgan:
-            load_model_real_esrgan()
-        if thread_data.model_real_esrgan is None: raise Exception('Model "gfpgan" not loaded.')
-        print('enhance with', thread_data.real_esrgan_file, 'on', thread_data.model_real_esrgan.device, 'precision', thread_data.precision)
-        output, _ = thread_data.model_real_esrgan.enhance(image_data[:,:,::-1])
-        image_data = output[:,:,::-1]
-
-    return image_data
-
-def is_model_reload_necessary(req: Request):
-    # custom model support:
-    #  the req.use_stable_diffusion_model needs to be a valid path
-    #  to the ckpt file (without the extension).
-    if os.path.exists(req.use_stable_diffusion_model + '.ckpt'):
-        req.use_stable_diffusion_model += '.ckpt'
-    elif os.path.exists(req.use_stable_diffusion_model + '.safetensors'):
-        req.use_stable_diffusion_model += '.safetensors'
-    elif not os.path.exists(req.use_stable_diffusion_model):
-        raise FileNotFoundError(f'Cannot find {req.use_stable_diffusion_model}.ckpt or .safetensors')
-
-    needs_model_reload = False
-    if not thread_data.model or thread_data.ckpt_file != req.use_stable_diffusion_model or thread_data.vae_file != req.use_vae_model:
-        thread_data.ckpt_file = req.use_stable_diffusion_model
-        thread_data.vae_file = req.use_vae_model
-        needs_model_reload = True
-
-    if thread_data.device != 'cpu':
-        if (thread_data.precision == 'autocast' and (req.use_full_precision or not thread_data.model_is_half)) or \
-            (thread_data.precision == 'full' and not req.use_full_precision and not thread_data.force_full_precision):
-            thread_data.precision = 'full' if req.use_full_precision else 'autocast'
-            needs_model_reload = True
-
-    return needs_model_reload
-
-def reload_model():
-    unload_models()
-    unload_filters()
-    load_model_ckpt()
-
-def is_hypernetwork_reload_necessary(req: Request):
-    needs_model_reload = False
-    if thread_data.hypernetwork_file != req.use_hypernetwork_model:
-        thread_data.hypernetwork_file = req.use_hypernetwork_model
-        needs_model_reload = True
-
-    return needs_model_reload
-
-def load_hypernetwork():
-    if thread_data.test_sd2:
-        # Not yet supported in SD2
-        return
-    from . import hypernetwork
-    if thread_data.hypernetwork_file is not None:
-        try:
-            loaded = False
-            for model_extension in HYPERNETWORK_MODEL_EXTENSIONS:
-                if os.path.exists(thread_data.hypernetwork_file + model_extension):
-                    print(f"Loading hypernetwork weights from: {thread_data.hypernetwork_file}{model_extension}")
-                    thread_data.hypernetwork = hypernetwork.load_hypernetwork(thread_data.hypernetwork_file + model_extension)
-                    loaded = True
-                    break
-
-            if not loaded:
-                print(f'Cannot find hypernetwork: {thread_data.hypernetwork_file}')
-                thread_data.hypernetwork_file = None
-        except:
-            print(traceback.format_exc())
-            print(f'Could not load hypernetwork: {thread_data.hypernetwork_file}')
-            thread_data.hypernetwork_file = None
-
-def unload_hypernetwork():
-    if thread_data.hypernetwork is not None:
-        print('Unloading hypernetwork...')
-        if thread_data.device != 'cpu':
-            for i in thread_data.hypernetwork:
-                thread_data.hypernetwork[i][0].to('cpu')
-                thread_data.hypernetwork[i][1].to('cpu')
-        del thread_data.hypernetwork
-        thread_data.hypernetwork = None
-
-    gc()
-
-def reload_hypernetwork():
-    unload_hypernetwork()
-    load_hypernetwork()
-
-def mk_img(req: Request, data_queue: queue.Queue, task_temp_images: list, step_callback):
-    try:
-        return do_mk_img(req, data_queue, task_temp_images, step_callback)
-    except Exception as e:
-        print(traceback.format_exc())
-
-        if thread_data.device != 'cpu' and not thread_data.test_sd2:
-            thread_data.modelFS.to('cpu')
-            thread_data.modelCS.to('cpu')
-            thread_data.model.model1.to("cpu")
-            thread_data.model.model2.to("cpu")
-
-        gc() # Release from memory.
-        data_queue.put(json.dumps({
-            "status": 'failed',
-            "detail": str(e)
-        }))
-        raise e
-
-def update_temp_img(req, x_samples, task_temp_images: list):
-    partial_images = []
-    for i in range(req.num_outputs):
-        if thread_data.test_sd2:
-            x_sample_ddim = thread_data.model.decode_first_stage(x_samples[i].unsqueeze(0))
-        else:
-            x_sample_ddim = thread_data.modelFS.decode_first_stage(x_samples[i].unsqueeze(0))
-        x_sample = torch.clamp((x_sample_ddim + 1.0) / 2.0, min=0.0, max=1.0)
-        x_sample = 255.0 * rearrange(x_sample[0].cpu().numpy(), "c h w -> h w c")
-        x_sample = x_sample.astype(np.uint8)
-        img = Image.fromarray(x_sample)
-        buf = img_to_buffer(img, output_format='JPEG')
-
-        del img, x_sample, x_sample_ddim
-        # don't delete x_samples, it is used in the code that called this callback
-
-        thread_data.temp_images[f'{req.request_id}/{i}'] = buf
-        task_temp_images[i] = buf
-        partial_images.append({'path': f'/image/tmp/{req.request_id}/{i}'})
-    return partial_images
-
-# Build and return the apropriate generator for do_mk_img
-def get_image_progress_generator(req, data_queue: queue.Queue, task_temp_images: list, step_callback, extra_props=None):
-    if not req.stream_progress_updates:
-        def empty_callback(x_samples, i):
-            step_callback()
-        return empty_callback
-
-    thread_data.partial_x_samples = None
-    last_callback_time = -1
-    def img_callback(x_samples, i):
-        nonlocal last_callback_time
-
-        thread_data.partial_x_samples = x_samples
-        step_time = time.time() - last_callback_time if last_callback_time != -1 else -1
-        last_callback_time = time.time()
-
-        progress = {"step": i, "step_time": step_time}
-        if extra_props is not None:
-            progress.update(extra_props)
-
-        if req.stream_image_progress and i % 5 == 0:
-            progress['output'] = update_temp_img(req, x_samples, task_temp_images)
-
-        data_queue.put(json.dumps(progress))
-
-        step_callback()
-
-        if thread_data.stop_processing:
-            raise UserInitiatedStop("User requested that we stop processing")
-    return img_callback
-
-def do_mk_img(req: Request, data_queue: queue.Queue, task_temp_images: list, step_callback):
-    thread_data.stop_processing = False
-
-    res = Response()
-    res.request = req
-    res.images = []
-    thread_data.hypernetwork_strength = req.hypernetwork_strength
-
-    thread_data.temp_images.clear()
-
-    if thread_data.turbo != req.turbo and not thread_data.test_sd2:
-        thread_data.turbo = req.turbo
-        thread_data.model.turbo = req.turbo
-
-    # Start by cleaning memory, loading and unloading things can leave memory allocated.
-    gc()
-
-    opt_prompt = req.prompt
-    opt_seed = req.seed
-    opt_n_iter = 1
-    opt_C = 4
-    opt_f = 8
-    opt_ddim_eta = 0.0
-
-    print(req, '\n device', torch.device(thread_data.device), "as", thread_data.device_name)
-    print('\n\n Using precision:', thread_data.precision)
-
-    seed_everything(opt_seed)
-
-    batch_size = req.num_outputs
-    prompt = opt_prompt
-    assert prompt is not None
-    data = [batch_size * [prompt]]
-
-    if thread_data.precision == "autocast" and thread_data.device != "cpu":
-        precision_scope = autocast
-    else:
-        precision_scope = nullcontext
-
-    mask = None
-
-    if req.init_image is None:
-        handler = _txt2img
-
-        init_latent = None
-        t_enc = None
-    else:
-        handler = _img2img
-
-        init_image = load_img(req.init_image, req.width, req.height)
-        init_image = init_image.to(thread_data.device)
-
-        if thread_data.device != "cpu" and thread_data.precision == "autocast":
-            init_image = init_image.half()
-
-        if not thread_data.test_sd2:
-            thread_data.modelFS.to(thread_data.device)
-
-        init_image = repeat(init_image, '1 ... -> b ...', b=batch_size)
-        if thread_data.test_sd2:
-            init_latent = thread_data.model.get_first_stage_encoding(thread_data.model.encode_first_stage(init_image))  # move to latent space
-        else:
-            init_latent = thread_data.modelFS.get_first_stage_encoding(thread_data.modelFS.encode_first_stage(init_image))  # move to latent space
-
-        if req.mask is not None:
-            mask = load_mask(req.mask, req.width, req.height, init_latent.shape[2], init_latent.shape[3], True).to(thread_data.device)
-            mask = mask[0][0].unsqueeze(0).repeat(4, 1, 1).unsqueeze(0)
-            mask = repeat(mask, '1 ... -> b ...', b=batch_size)
-
-            if thread_data.device != "cpu" and thread_data.precision == "autocast":
-                mask = mask.half()
-
-        # Send to CPU and wait until complete.
-        # wait_model_move_to(thread_data.modelFS, 'cpu')
-        if not thread_data.test_sd2:
-            move_to_cpu(thread_data.modelFS)
-        assert 0. <= req.prompt_strength <= 1., 'can only work with strength in [0.0, 1.0]'
-        t_enc = int(req.prompt_strength * req.num_inference_steps)
-        print(f"target t_enc is {t_enc} steps")
-
-    with torch.no_grad():
-        for n in trange(opt_n_iter, desc="Sampling"):
-            for prompts in tqdm(data, desc="data"):
-
-                with precision_scope("cuda"):
-                    if thread_data.reduced_memory and not thread_data.test_sd2:
-                        thread_data.modelCS.to(thread_data.device)
-                    uc = None
-                    if req.guidance_scale != 1.0:
-                        if thread_data.test_sd2:
-                            uc = thread_data.model.get_learned_conditioning(batch_size * [req.negative_prompt])
-                        else:
-                            uc = thread_data.modelCS.get_learned_conditioning(batch_size * [req.negative_prompt])
-                    if isinstance(prompts, tuple):
-                        prompts = list(prompts)
-
-                    subprompts, weights = split_weighted_subprompts(prompts[0])
-                    if len(subprompts) > 1:
-                        c = torch.zeros_like(uc)
-                        totalWeight = sum(weights)
-                        # normalize each "sub prompt" and add it
-                        for i in range(len(subprompts)):
-                            weight = weights[i]
-                            # if not skip_normalize:
-                            weight = weight / totalWeight
-                            if thread_data.test_sd2:
-                                c = torch.add(c, thread_data.model.get_learned_conditioning(subprompts[i]), alpha=weight)
-                            else:
-                                c = torch.add(c, thread_data.modelCS.get_learned_conditioning(subprompts[i]), alpha=weight)
-                    else:
-                        if thread_data.test_sd2:
-                            c = thread_data.model.get_learned_conditioning(prompts)
-                        else:
-                            c = thread_data.modelCS.get_learned_conditioning(prompts)
-
-                    if thread_data.reduced_memory and not thread_data.test_sd2:
-                        thread_data.modelFS.to(thread_data.device)
-
-                    n_steps = req.num_inference_steps if req.init_image is None else t_enc
-                    img_callback = get_image_progress_generator(req, data_queue, task_temp_images, step_callback, {"total_steps": n_steps})
-
-                    # run the handler
-                    try:
-                        print('Running handler...')
-                        if handler == _txt2img:
-                            x_samples = _txt2img(req.width, req.height, req.num_outputs, req.num_inference_steps, req.guidance_scale, None, opt_C, opt_f, opt_ddim_eta, c, uc, opt_seed, img_callback, mask, req.sampler)
-                        else:
-                            x_samples = _img2img(init_latent, t_enc, batch_size, req.guidance_scale, c, uc, req.num_inference_steps, opt_ddim_eta, opt_seed, img_callback, mask, opt_C, req.height, req.width, opt_f)
-                    except UserInitiatedStop:
-                        if not hasattr(thread_data, 'partial_x_samples'):
-                            continue
-                        if thread_data.partial_x_samples is None:
-                            del thread_data.partial_x_samples
-                            continue
-                        x_samples = thread_data.partial_x_samples
-                        del thread_data.partial_x_samples
-
-                    print("decoding images")
-                    img_data = [None] * batch_size
-                    for i in range(batch_size):
-                        if thread_data.test_sd2:
-                            x_samples_ddim = thread_data.model.decode_first_stage(x_samples[i].unsqueeze(0))
-                        else:
-                            x_samples_ddim = thread_data.modelFS.decode_first_stage(x_samples[i].unsqueeze(0))
-                        x_sample = torch.clamp((x_samples_ddim + 1.0) / 2.0, min=0.0, max=1.0)
-                        x_sample = 255.0 * rearrange(x_sample[0].cpu().numpy(), "c h w -> h w c")
-                        x_sample = x_sample.astype(np.uint8)
-                        img_data[i] = x_sample
-                    del x_samples, x_samples_ddim, x_sample
-
-                    print("saving images")
-                    for i in range(batch_size):
-                        img = Image.fromarray(img_data[i])
-                        img_id = base64.b64encode(int(time.time()+i).to_bytes(8, 'big')).decode() # Generate unique ID based on time.
-                        img_id = img_id.translate({43:None, 47:None, 61:None})[-8:] # Remove + / = and keep last 8 chars.
-
-                        has_filters = (req.use_face_correction is not None and req.use_face_correction.startswith('GFPGAN')) or \
-                            (req.use_upscale is not None and req.use_upscale.startswith('RealESRGAN'))
-
-                        return_orig_img = not has_filters or not req.show_only_filtered_image
-
-                        if thread_data.stop_processing:
-                            return_orig_img = True
-
-                        if req.save_to_disk_path is not None:
-                            if return_orig_img:
-                                img_out_path = get_base_path(req.save_to_disk_path, req.session_id, prompts[0], img_id, req.output_format)
-                                save_image(img, img_out_path, req.output_format, req.output_quality)
-                            meta_out_path = get_base_path(req.save_to_disk_path, req.session_id, prompts[0], img_id, 'txt')
-                            save_metadata(meta_out_path, req, prompts[0], opt_seed)
-
-                        if return_orig_img:
-                            img_buffer = img_to_buffer(img, req.output_format, req.output_quality)
-                            img_str = buffer_to_base64_str(img_buffer, req.output_format)
-                            res_image_orig = ResponseImage(data=img_str, seed=opt_seed)
-                            res.images.append(res_image_orig)
-                            task_temp_images[i] = img_buffer
-
-                            if req.save_to_disk_path is not None:
-                                res_image_orig.path_abs = img_out_path
-                        del img
-
-                        if has_filters and not thread_data.stop_processing:
-                            filters_applied = []
-                            if req.use_face_correction:
-                                img_data[i] = apply_filters('gfpgan', img_data[i], req.use_face_correction)
-                                filters_applied.append(req.use_face_correction)
-                            if req.use_upscale:
-                                img_data[i] = apply_filters('real_esrgan', img_data[i], req.use_upscale)
-                                filters_applied.append(req.use_upscale)
-                            if (len(filters_applied) > 0):
-                                filtered_image = Image.fromarray(img_data[i])
-                                filtered_buffer = img_to_buffer(filtered_image, req.output_format, req.output_quality)
-                                filtered_img_data = buffer_to_base64_str(filtered_buffer, req.output_format)
-                                response_image = ResponseImage(data=filtered_img_data, seed=opt_seed)
-                                res.images.append(response_image)
-                                task_temp_images[i] = filtered_buffer
-                                if req.save_to_disk_path is not None:
-                                    filtered_img_out_path = get_base_path(req.save_to_disk_path, req.session_id, prompts[0], img_id, req.output_format, "_".join(filters_applied))
-                                    save_image(filtered_image, filtered_img_out_path, req.output_format, req.output_quality)
-                                    response_image.path_abs = filtered_img_out_path
-                                del filtered_image
-                        # Filter Applied, move to next seed
-                        opt_seed += 1
-
-                    # if thread_data.reduced_memory:
-                    #     unload_filters()
-                    if not thread_data.test_sd2:
-                        move_to_cpu(thread_data.modelFS)
-                    del img_data
-                    gc()
-                    if thread_data.device != 'cpu':
-                        print(f'memory_final = {round(torch.cuda.memory_allocated(thread_data.device) / 1e6, 2)}Mb')
-
-    print('Task completed')
-    res = res.json()
-    data_queue.put(json.dumps(res))
-
-    return res
-
-def save_image(img, img_out_path, output_format="", output_quality=75):
-    try:
-        if output_format.upper() == "JPEG":
-            img.save(img_out_path, quality=output_quality)
-        else:
-            img.save(img_out_path)
-    except:
-        print('could not save the file', traceback.format_exc())
-
-def save_metadata(meta_out_path, req, prompt, opt_seed):
-    metadata = f'''{prompt}
-Width: {req.width}
-Height: {req.height}
-Seed: {opt_seed}
-Steps: {req.num_inference_steps}
-Guidance Scale: {req.guidance_scale}
-Prompt Strength: {req.prompt_strength}
-Use Face Correction: {req.use_face_correction}
-Use Upscaling: {req.use_upscale}
-Sampler: {req.sampler}
-Negative Prompt: {req.negative_prompt}
-Stable Diffusion model: {req.use_stable_diffusion_model + '.ckpt'}
-VAE model: {req.use_vae_model}
-Hypernetwork Model: {req.use_hypernetwork_model}
-Hypernetwork Strength: {req.hypernetwork_strength}
-'''
-    try:
-        with open(meta_out_path, 'w', encoding='utf-8') as f:
-            f.write(metadata)
-    except:
-        print('could not save the file', traceback.format_exc())
-
-def _txt2img(opt_W, opt_H, opt_n_samples, opt_ddim_steps, opt_scale, start_code, opt_C, opt_f, opt_ddim_eta, c, uc, opt_seed, img_callback, mask, sampler_name):
-    shape = [opt_n_samples, opt_C, opt_H // opt_f, opt_W // opt_f]
-
-    # Send to CPU and wait until complete.
-    # wait_model_move_to(thread_data.modelCS, 'cpu')
-
-    if not thread_data.test_sd2:
-        move_to_cpu(thread_data.modelCS)
-
-    if thread_data.test_sd2 and sampler_name not in ('plms', 'ddim', 'dpm2'):
-        raise Exception('Only plms, ddim and dpm2 samplers are supported right now, in SD 2.0')
-
-
-    # samples, _ = sampler.sample(S=opt.steps,
-    #                  conditioning=c,
-    #                  batch_size=opt.n_samples,
-    #                  shape=shape,
-    #                  verbose=False,
-    #                  unconditional_guidance_scale=opt.scale,
-    #                  unconditional_conditioning=uc,
-    #                  eta=opt.ddim_eta,
-    #                  x_T=start_code)
-
-    if thread_data.test_sd2:
-        if sampler_name == 'plms':
-            from ldm.models.diffusion.plms import PLMSSampler
-            sampler = PLMSSampler(thread_data.model)
-        elif sampler_name == 'ddim':
-            from ldm.models.diffusion.ddim import DDIMSampler
-            sampler = DDIMSampler(thread_data.model)
-            sampler.make_schedule(ddim_num_steps=opt_ddim_steps, ddim_eta=opt_ddim_eta, verbose=False)
-        elif sampler_name == 'dpm2':
-            from ldm.models.diffusion.dpm_solver import DPMSolverSampler
-            sampler = DPMSolverSampler(thread_data.model)
-
-        shape = [opt_C, opt_H // opt_f, opt_W // opt_f]
-
-        samples_ddim, intermediates = sampler.sample(
-            S=opt_ddim_steps,
-            conditioning=c,
-            batch_size=opt_n_samples,
-            seed=opt_seed,
-            shape=shape,
-            verbose=False,
-            unconditional_guidance_scale=opt_scale,
-            unconditional_conditioning=uc,
-            eta=opt_ddim_eta,
-            x_T=start_code,
-            img_callback=img_callback,
-            mask=mask,
-            sampler = sampler_name,
-        )
-    else:
-        if sampler_name == 'ddim':
-            thread_data.model.make_schedule(ddim_num_steps=opt_ddim_steps, ddim_eta=opt_ddim_eta, verbose=False)
-
-        samples_ddim = thread_data.model.sample(
-            S=opt_ddim_steps,
-            conditioning=c,
-            seed=opt_seed,
-            shape=shape,
-            verbose=False,
-            unconditional_guidance_scale=opt_scale,
-            unconditional_conditioning=uc,
-            eta=opt_ddim_eta,
-            x_T=start_code,
-            img_callback=img_callback,
-            mask=mask,
-            sampler = sampler_name,
-        )
-    return samples_ddim
-
-def _img2img(init_latent, t_enc, batch_size, opt_scale, c, uc, opt_ddim_steps, opt_ddim_eta, opt_seed, img_callback, mask, opt_C=1, opt_H=1, opt_W=1, opt_f=1):
-    # encode (scaled latent)
-    x_T = None if mask is None else init_latent
-
-    if thread_data.test_sd2:
-        from ldm.models.diffusion.ddim import DDIMSampler
-
-        sampler = DDIMSampler(thread_data.model)
-
-        sampler.make_schedule(ddim_num_steps=opt_ddim_steps, ddim_eta=opt_ddim_eta, verbose=False)
-
-        z_enc = sampler.stochastic_encode(init_latent, torch.tensor([t_enc] * batch_size).to(thread_data.device))
-
-        samples_ddim = sampler.decode(z_enc, c, t_enc, unconditional_guidance_scale=opt_scale,unconditional_conditioning=uc, img_callback=img_callback)
-
-    else:
-        z_enc = thread_data.model.stochastic_encode(
-            init_latent,
-            torch.tensor([t_enc] * batch_size).to(thread_data.device),
-            opt_seed,
-            opt_ddim_eta,
-            opt_ddim_steps,
-        )
-
-        # decode it
-        samples_ddim = thread_data.model.sample(
-            t_enc,
-            c,
-            z_enc,
-            unconditional_guidance_scale=opt_scale,
-            unconditional_conditioning=uc,
-            img_callback=img_callback,
-            mask=mask,
-            x_T=x_T,
-            sampler = 'ddim'
-        )
-    return samples_ddim
-
-def gc():
-    gc_collect()
-    if thread_data.device == 'cpu':
-        return
-    torch.cuda.empty_cache()
-    torch.cuda.ipc_collect()
-
-# internal
-
-def chunk(it, size):
-    it = iter(it)
-    return iter(lambda: tuple(islice(it, size)), ())
-
-def load_model_from_config(ckpt, verbose=False):
-    print(f"Loading model from {ckpt}")
-    model_ver = 'sd1'
-
-    if ckpt.endswith(".safetensors"):
-        print("Loading from safetensors")
-        pl_sd = load_file(ckpt, device="cpu")
-    else:
-        pl_sd = torch.load(ckpt, map_location="cpu")
-
-    if "global_step" in pl_sd:
-        print(f"Global Step: {pl_sd['global_step']}")
-
-    if "state_dict" in pl_sd:
-        # check for a key that only seems to be present in SD2 models
-        if 'cond_stage_model.model.ln_final.bias' in pl_sd['state_dict'].keys():
-            model_ver = 'sd2'
-
-        return pl_sd["state_dict"], model_ver
-    else:
-        return pl_sd, model_ver
-
-class UserInitiatedStop(Exception):
-    pass
-
-def load_img(img_str, w0, h0):
-    image = base64_str_to_img(img_str).convert("RGB")
-    w, h = image.size
-    print(f"loaded input image of size ({w}, {h}) from base64")
-    if h0 is not None and w0 is not None:
-        h, w = h0, w0
-
-    w, h = map(lambda x: x - x % 64, (w, h)) # resize to integer multiple of 64
-    image = image.resize((w, h), resample=Image.Resampling.LANCZOS)
-    image = np.array(image).astype(np.float32) / 255.0
-    image = image[None].transpose(0, 3, 1, 2)
-    image = torch.from_numpy(image)
-    return 2.*image - 1.
-
-def load_mask(mask_str, h0, w0, newH, newW, invert=False):
-    image = base64_str_to_img(mask_str).convert("RGB")
-    w, h = image.size
-    print(f"loaded input mask of size ({w}, {h})")
-
-    if invert:
-        print("inverted")
-        image = ImageOps.invert(image)
-        # where_0, where_1 = np.where(image == 0), np.where(image == 255)
-        # image[where_0], image[where_1] = 255, 0
-
-    if h0 is not None and w0 is not None:
-        h, w = h0, w0
-
-    w, h = map(lambda x: x - x % 64, (w, h)) # resize to integer multiple of 64
-
-    print(f"New mask size ({w}, {h})")
-    image = image.resize((newW, newH), resample=Image.Resampling.LANCZOS)
-    image = np.array(image)
-
-    image = image.astype(np.float32) / 255.0
-    image = image[None].transpose(0, 3, 1, 2)
-    image = torch.from_numpy(image)
-    return image
-
-# https://stackoverflow.com/a/61114178
-def img_to_base64_str(img, output_format="PNG", output_quality=75):
-    buffered = img_to_buffer(img, output_format, quality=output_quality)
-    return buffer_to_base64_str(buffered, output_format)
-
-def img_to_buffer(img, output_format="PNG", output_quality=75):
-    buffered = BytesIO()
-    if ( output_format.upper() == "JPEG" ):
-        img.save(buffered, format=output_format, quality=output_quality)
-    else:
-        img.save(buffered, format=output_format)
-    buffered.seek(0)
-    return buffered
-
-def buffer_to_base64_str(buffered, output_format="PNG"):
-    buffered.seek(0)
-    img_byte = buffered.getvalue()
-    mime_type = "image/png" if output_format.lower() == "png" else "image/jpeg"
-    img_str = f"data:{mime_type};base64," + base64.b64encode(img_byte).decode()
-    return img_str
-
-def base64_str_to_buffer(img_str):
-    mime_type = "image/png" if img_str.startswith("data:image/png;") else "image/jpeg"
-    img_str = img_str[len(f"data:{mime_type};base64,"):]
-    data = base64.b64decode(img_str)
-    buffered = BytesIO(data)
-    return buffered
-
-def base64_str_to_img(img_str):
-    buffered = base64_str_to_buffer(img_str)
-    img = Image.open(buffered)
-    return img
-
-def split_weighted_subprompts(text):
-    """
-    grabs all text up to the first occurrence of ':'
-    uses the grabbed text as a sub-prompt, and takes the value following ':' as weight
-    if ':' has no value defined, defaults to 1.0
-    repeats until no text remaining
-    """
-    remaining = len(text)
-    prompts = []
-    weights = []
-    while remaining > 0:
-        if ":" in text:
-            idx = text.index(":") # first occurrence from start
-            # grab up to index as sub-prompt
-            prompt = text[:idx]
-            remaining -= idx
-            # remove from main text
-            text = text[idx+1:]
-            # find value for weight
-            if " " in text:
-                idx = text.index(" ") # first occurence
-            else: # no space, read to end
-                idx = len(text)
-            if idx != 0:
-                try:
-                    weight = float(text[:idx])
-                except: # couldn't treat as float
-                    print(f"Warning: '{text[:idx]}' is not a value, are you missing a space?")
-                    weight = 1.0
-            else: # no value found
-                weight = 1.0
-            # remove from main text
-            remaining -= idx
-            text = text[idx+1:]
-            # append the sub-prompt and its weight
-            prompts.append(prompt)
-            weights.append(weight)
-        else: # no : found
-            if len(text) > 0: # there is still text though
-                # take remainder as weight 1
-                prompts.append(text)
-                weights.append(1.0)
-            remaining = 0
-    return prompts, weights
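The file deleted above was the pre-sdkit render loop: one thread owns a torch device, `thread_data` carries the loaded models, and `do_mk_img` drives sampling. Its `prompt:weight` syntax is worth pinning down before it disappears; here is a worked example of `split_weighted_subprompts` as defined above (behavior traced from the code, not re-run here):

    # "an apple:1 a banana:2" splits on ':' and reads a float weight after it.
    prompts, weights = split_weighted_subprompts("an apple:1 a banana:2")
    assert prompts == ["an apple", "a banana"]
    assert weights == [1.0, 2.0]
    # do_mk_img then divides each weight by sum(weights) and blends the text
    # conditionings with torch.add(c, cond, alpha=weight).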
diff --git a/ui/sd_internal/save_utils.py b/ui/sd_internal/save_utils.py
index 29c4a29c..db9d0671 100644
--- a/ui/sd_internal/save_utils.py
+++ b/ui/sd_internal/save_utils.py
@@ -3,8 +3,8 @@
 import time
 import base64
 import re
 
-from diffusionkit import data_utils
-from diffusionkit.types import GenerateImageRequest
+from sdkit.utils import save_images, save_dicts
+from sdkit.types import GenerateImageRequest
 
 from sd_internal import TaskData
@@ -33,12 +33,12 @@ def save_to_disk(images: list, filtered_images: list, req: GenerateImageRequest,
     metadata_entries = get_metadata_entries(req, task_data)
 
     if task_data.show_only_filtered_image or filtered_images == images:
-        data_utils.save_images(filtered_images, save_folder_path, file_name=make_filename_callback(req), output_format=task_data.output_format, output_quality=task_data.output_quality)
-        data_utils.save_metadata(metadata_entries, save_folder_path, file_name=make_filename_callback(req), output_format=task_data.metadata_output_format)
+        save_images(filtered_images, save_folder_path, file_name=make_filename_callback(req), output_format=task_data.output_format, output_quality=task_data.output_quality)
+        save_dicts(metadata_entries, save_folder_path, file_name=make_filename_callback(req), output_format=task_data.metadata_output_format)
     else:
-        data_utils.save_images(images, save_folder_path, file_name=make_filename_callback(req), output_format=task_data.output_format, output_quality=task_data.output_quality)
-        data_utils.save_images(filtered_images, save_folder_path, file_name=make_filename_callback(req, suffix='filtered'), output_format=task_data.output_format, output_quality=task_data.output_quality)
-        data_utils.save_metadata(metadata_entries, save_folder_path, file_name=make_filename_callback(req, suffix='filtered'), output_format=task_data.metadata_output_format)
+        save_images(images, save_folder_path, file_name=make_filename_callback(req), output_format=task_data.output_format, output_quality=task_data.output_quality)
+        save_images(filtered_images, save_folder_path, file_name=make_filename_callback(req, suffix='filtered'), output_format=task_data.output_format, output_quality=task_data.output_quality)
+        save_dicts(metadata_entries, save_folder_path, file_name=make_filename_callback(req, suffix='filtered'), output_format=task_data.metadata_output_format)
 
 def get_metadata_entries(req: GenerateImageRequest, task_data: TaskData):
     metadata = get_printable_request(req)
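The save_utils.py change above swaps `data_utils.save_images`/`data_utils.save_metadata` for sdkit's `save_images`/`save_dicts`. A usage sketch follows, using only the keyword arguments visible in the hunk; the placeholder values and the dict-per-image shape of the metadata are assumptions, not confirmed by the diff:

    from PIL import Image
    from sdkit.utils import save_images, save_dicts

    images = [Image.new("RGB", (512, 512))]            # stand-in for generated images
    metadata_entries = [{"prompt": "an astronaut"}]    # assumed metadata dicts

    save_images(images, "outputs/session_1", file_name="astronaut",
                output_format="jpeg", output_quality=75)
    save_dicts(metadata_entries, "outputs/session_1", file_name="astronaut",
               output_format="txt")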
diff --git a/ui/sd_internal/task_manager.py b/ui/sd_internal/task_manager.py
index 094d853c..f9080a74 100644
--- a/ui/sd_internal/task_manager.py
+++ b/ui/sd_internal/task_manager.py
@@ -15,7 +15,7 @@
 import queue, threading, time, weakref
 from typing import Any, Hashable
 
 from sd_internal import TaskData, device_manager
-from diffusionkit.types import GenerateImageRequest
+from sdkit.types import GenerateImageRequest
 
 log = logging.getLogger()
diff --git a/ui/server.py b/ui/server.py
index 284fed1b..43a11347 100644
--- a/ui/server.py
+++ b/ui/server.py
@@ -15,7 +15,7 @@
 from pydantic import BaseModel
 
 from sd_internal import app, model_manager, task_manager
 from sd_internal import TaskData
-from diffusionkit.types import GenerateImageRequest
+from sdkit.types import GenerateImageRequest
 
 log = logging.getLogger()
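Across task_manager.py and server.py the only change is the `GenerateImageRequest` import moving to `sdkit.types`. A quick, hypothetical smoke test (not part of the diff) can confirm the new import surface resolves after upgrading; every module name below appears somewhere in the hunks above:

    import importlib

    for mod in ("sdkit", "sdkit.types", "sdkit.models", "sdkit.utils"):
        importlib.import_module(mod)  # raises ImportError if the rename is incomplete
    print("sdkit import surface OK")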