2022-10-19 05:58:55 +02:00
"""runtime.py: torch device owned by a thread.
Notes:
    Avoid device switching; transferring all models will get too complex.
    To use a different device, signal the current render device to exit
    and then start a new clean thread for the new device.
"""
2022-09-14 13:22:03 +02:00
import json
2022-09-02 10:28:36 +02:00
import os , re
2022-09-05 13:21:43 +02:00
import traceback
2022-09-02 10:28:36 +02:00
import torch
import numpy as np
2022-10-21 09:53:43 +02:00
from gc import collect as gc_collect
2022-09-02 10:28:36 +02:00
from omegaconf import OmegaConf
2022-09-15 14:24:03 +02:00
from PIL import Image , ImageOps
2022-09-02 10:28:36 +02:00
from tqdm import tqdm , trange
from itertools import islice
from einops import rearrange
import time
from pytorch_lightning import seed_everything
from torch import autocast
from contextlib import nullcontext
from einops import rearrange , repeat
from ldm . util import instantiate_from_config
from optimizedSD . optimUtils import split_weighted_subprompts
from transformers import logging
2022-09-09 17:35:24 +02:00
from gfpgan import GFPGANer
from basicsr . archs . rrdbnet_arch import RRDBNet
from realesrgan import RealESRGANer
2022-09-03 20:42:48 +02:00
import uuid
2022-09-02 10:28:36 +02:00
# `logging` here is `transformers.logging` (imported above): only show errors.
logging.set_verbosity_error()

# consts
# Model description file handed to OmegaConf.load() when a checkpoint is loaded.
config_yaml = "optimizedSD/v1-inference.yaml"
# Matches any single character that is not an ASCII letter or digit; used to
# turn session ids and prompts into safe file-name fragments.
filename_regex = re.compile('[^a-zA-Z0-9]')
force_gfpgan_to_cuda0 = True  # workaround: gfpgan currently works only on cuda:0
2022-09-02 10:28:36 +02:00
# api stuff
2022-11-14 06:53:22 +01:00
from sd_internal import device_manager
2022-09-02 10:28:36 +02:00
from . import Request , Response , Image as ResponseImage
import base64
from io import BytesIO
2022-09-16 18:02:08 +02:00
#from colorama import Fore
2022-09-02 10:28:36 +02:00
2022-10-17 03:41:39 +02:00
# Thread-local storage: each render thread sets and reads its own copy of the
# attributes assigned on `thread_data` (see thread_init).
from threading import local as LocalThreadVars
thread_data = LocalThreadVars ( )
2022-09-11 07:04:04 +02:00
2022-11-14 06:53:22 +01:00
def thread_init(device):
    """Reset all thread-bound runtime properties and initialise the device.

    Every attribute below is (re)assigned on the calling thread's
    ``thread_data`` and then ``device`` is handed, together with
    ``thread_data``, to ``device_manager.device_init``.

    Args:
        device: forwarded unchanged to ``device_manager.device_init``.
    """
    # Thread bound properties, in the order they are assigned.
    initial_state = {
        'stop_processing': False,
        'temp_images': {},  # fresh dict on every call
        'ckpt_file': None,
        'vae_file': None,
        'gfpgan_file': None,
        'real_esrgan_file': None,
        'model': None,
        'modelCS': None,
        'modelFS': None,
        'model_gfpgan': None,
        'model_real_esrgan': None,
        'model_is_half': False,
        'model_fs_is_half': False,
        'device': None,
        'device_name': None,
        'unet_bs': 1,
        'precision': 'autocast',
        'sampler_plms': None,
        'sampler_ddim': None,
        'turbo': False,
        'force_full_precision': False,
        'reduced_memory': True,
    }
    for attr_name, attr_value in initial_state.items():
        setattr(thread_data, attr_name, attr_value)

    device_manager.device_init(thread_data, device)
2022-10-17 03:41:39 +02:00
def load_model_ckpt():
    """Load the checkpoint named by ``thread_data.ckpt_file`` into three models.

    The state dict returned by ``load_model_from_config`` is re-keyed so that
    the UNet weights are split into ``model1.*`` (input blocks, middle block,
    time embedding) and ``model2.*`` (everything else under ``model.``).  The
    UNet, conditioning-stage and first-stage models described in
    ``config_yaml`` are then instantiated, loaded non-strictly from that state
    dict, put in eval mode and stored on ``thread_data`` as ``model``,
    ``modelCS`` and ``modelFS``.

    When ``thread_data.vae_file`` is set, weights from ``<vae_file>.ckpt`` or
    ``<vae_file>.vae.pt`` (first one that exists) are loaded into
    ``modelFS.first_stage_model``.  A missing or unreadable VAE is reported on
    stdout and ``thread_data.vae_file`` is reset to ``None``; it does not abort
    the load.

    Side effects on ``thread_data``: ``precision`` and ``unet_bs`` get defaults
    when falsy, ``precision`` is forced to ``'full'`` on the ``'cpu'`` device,
    and ``model_is_half`` / ``model_fs_is_half`` reflect whether the models
    were converted with ``.half()``.

    Raises:
        ValueError: ``thread_data.ckpt_file`` is empty.
        FileNotFoundError: ``<ckpt_file>.ckpt`` does not exist.
    """
    if not thread_data.ckpt_file:
        raise ValueError('Thread ckpt_file is undefined.')
    ckpt_path = thread_data.ckpt_file + '.ckpt'
    if not os.path.exists(ckpt_path):
        raise FileNotFoundError(f'Cannot find {thread_data.ckpt_file}.ckpt')

    if not thread_data.precision:
        thread_data.precision = 'full' if thread_data.force_full_precision else 'autocast'

    if not thread_data.unet_bs:
        thread_data.unet_bs = 1

    if thread_data.device == 'cpu':
        thread_data.precision = 'full'

    print('loading', ckpt_path, 'to device', thread_data.device, 'using precision', thread_data.precision)
    sd = load_model_from_config(ckpt_path)

    # Split the UNet weights ("model.<...>") into the two halves expected by
    # the optimized UNet: model1 = input/middle/time-embed, model2 = the rest.
    first_half_markers = ("input_blocks", "middle_block", "time_embed")
    li, lo = [], []
    for key in sd:
        sp = key.split(".")
        if sp[0] != "model":
            continue
        if any(marker in sp for marker in first_half_markers):
            li.append(key)
        else:
            lo.append(key)
    # key[6:] drops the leading "model." prefix.
    for key in li:
        sd["model1." + key[6:]] = sd.pop(key)
    for key in lo:
        sd["model2." + key[6:]] = sd.pop(key)

    config = OmegaConf.load(config_yaml)

    model = instantiate_from_config(config.modelUNet)
    model.load_state_dict(sd, strict=False)
    model.eval()
    model.cdevice = torch.device(thread_data.device)
    model.unet_bs = thread_data.unet_bs
    model.turbo = thread_data.turbo
    thread_data.model = model

    modelCS = instantiate_from_config(config.modelCondStage)
    modelCS.load_state_dict(sd, strict=False)
    modelCS.eval()
    modelCS.cond_stage_model.device = torch.device(thread_data.device)
    thread_data.modelCS = modelCS

    modelFS = instantiate_from_config(config.modelFirstStage)
    modelFS.load_state_dict(sd, strict=False)

    if thread_data.vae_file is not None:
        try:
            loaded = False
            for model_extension in ['.ckpt', '.vae.pt']:
                vae_path = thread_data.vae_file + model_extension
                if os.path.exists(vae_path):
                    print(f"Loading VAE weights from: {thread_data.vae_file}{model_extension}")
                    # NOTE(review): torch.load unpickles the file; only load
                    # VAE files from trusted sources.
                    vae_ckpt = torch.load(vae_path, map_location="cpu")
                    # Keys starting with "loss" are skipped.
                    vae_dict = {k: v for k, v in vae_ckpt["state_dict"].items() if k[0:4] != "loss"}
                    modelFS.first_stage_model.load_state_dict(vae_dict, strict=False)
                    loaded = True
                    break

            if not loaded:
                print(f'Cannot find VAE: {thread_data.vae_file}')
                thread_data.vae_file = None
        except Exception:
            # Best effort: keep going without the custom VAE, but let
            # KeyboardInterrupt / SystemExit propagate.
            print(traceback.format_exc())
            print(f'Could not load VAE: {thread_data.vae_file}')
            thread_data.vae_file = None

    modelFS.eval()
    thread_data.modelFS = modelFS
    del sd

    use_half = thread_data.device != "cpu" and thread_data.precision == "autocast"
    if use_half:
        thread_data.model.half()
        thread_data.modelCS.half()
        thread_data.modelFS.half()
    thread_data.model_is_half = use_half
    thread_data.model_fs_is_half = use_half

    print(
        'loaded model\n'
        f' model file: {thread_data.ckpt_file}.ckpt\n'
        f' model.device: {model.device}\n'
        f' modelCS.device: {modelCS.cond_stage_model.device}\n'
        f' modelFS.device: {thread_data.modelFS.device}\n'
        f' using precision: {thread_data.precision}'
    )
2022-09-09 17:35:24 +02:00
2022-10-21 09:53:43 +02:00
def unload_filters():
    """Drop the GFPGAN and RealESRGAN models held by this thread, then gc().

    For each filter that is loaded, its network is first moved to the CPU
    (unless the thread device is already ``'cpu'``), and the ``thread_data``
    attribute is deleted and reset to ``None``.
    """
    # (attribute on thread_data, attribute of that wrapper holding the network)
    filter_slots = (
        ('model_gfpgan', 'gfpgan'),
        ('model_real_esrgan', 'model'),
    )
    for slot, network_attr in filter_slots:
        if getattr(thread_data, slot) is None:
            continue
        if thread_data.device != 'cpu':
            getattr(getattr(thread_data, slot), network_attr).to('cpu')
        delattr(thread_data, slot)
        setattr(thread_data, slot, None)
    gc()
2022-10-21 09:53:43 +02:00
def unload_models():
    """Release the Stable Diffusion models held by this thread, then gc().

    When a UNet is loaded and the thread device is not ``'cpu'``, the
    first-stage model, the conditioning model and both UNet halves are moved to
    the CPU first.  ``thread_data.model``, ``modelCS`` and ``modelFS`` are
    always left as ``None``.

    ``modelFS`` / ``modelCS`` may still be ``None`` while ``model`` is set
    (``load_model_ckpt`` assigns ``model`` before building the other two, so a
    failure in between leaves that state); those are skipped instead of
    raising ``AttributeError``.
    """
    if thread_data.model is not None:
        print('Unloading models...')
        if thread_data.device != 'cpu':
            if thread_data.modelFS is not None:
                thread_data.modelFS.to('cpu')
            if thread_data.modelCS is not None:
                thread_data.modelCS.to('cpu')
            thread_data.model.model1.to("cpu")
            thread_data.model.model2.to("cpu")

        del thread_data.model
        del thread_data.modelCS
        del thread_data.modelFS

    thread_data.model = None
    thread_data.modelCS = None
    thread_data.modelFS = None
    gc()
2022-11-19 07:23:33 +01:00
# def wait_model_move_to(model, target_device): # Send to target_device and wait until complete.
# if thread_data.device == target_device: return
# start_mem = torch.cuda.memory_allocated(thread_data.device) / 1e6
# if start_mem <= 0: return
# model_name = model.__class__.__name__
# print(f'Device {thread_data.device} - Sending model {model_name} to {target_device} | Memory transfer starting. Memory Used: {round(start_mem)}Mb')
# start_time = time.time()
# model.to(target_device)
# time_step = start_time
# WARNING_TIMEOUT = 1.5 # seconds - Show activity in console after timeout.
# last_mem = start_mem
# is_transfering = True
# while is_transfering:
# time.sleep(0.5) # 500ms
# mem = torch.cuda.memory_allocated(thread_data.device) / 1e6
# is_transfering = bool(mem > 0 and mem < last_mem) # still stuff loaded, but less than last time.
# last_mem = mem
# if not is_transfering:
# break;
# if time.time() - time_step > WARNING_TIMEOUT: # Long delay, print to console to show activity.
# print(f'Device {thread_data.device} - Waiting for Memory transfer. Memory Used: {round(mem)}Mb, Transfered: {round(start_mem - mem)}Mb')
# time_step = time.time()
# print(f'Device {thread_data.device} - {model_name} Moved: {round(start_mem - last_mem)}Mb in {round(time.time() - start_time, 3)} seconds to {target_device}')
def move_to_cpu(model, timeout=30.0, poll_interval=1.0):
    """Move ``model`` to the CPU and wait until device memory has dropped.

    No-op when the thread device is ``"cpu"``.  Otherwise the allocated memory
    on the thread's device is sampled before ``model.to("cpu")`` and polled
    afterwards until it is lower than that sample.

    The wait is skipped when nothing can be released (no memory was allocated,
    or none of the model's parameters/buffers were on a non-CPU device) and it
    is bounded by ``timeout``; previously either situation made the loop spin
    forever.

    Args:
        model: module to move; must provide ``to``, ``parameters`` and
            ``buffers`` (assumes a ``torch.nn.Module`` - confirm for callers
            passing other objects).
        timeout: maximum number of seconds to wait for the memory to drop.
        poll_interval: seconds to sleep between two memory checks.
    """
    if thread_data.device == "cpu":
        return

    d = torch.device(thread_data.device)
    mem = torch.cuda.memory_allocated(d) / 1e6
    had_device_tensors = (
        any(p.device.type != "cpu" for p in model.parameters())
        or any(b.device.type != "cpu" for b in model.buffers())
    )

    model.to("cpu")

    if mem <= 0 or not had_device_tensors:
        # Nothing of this model lived on the device: memory cannot decrease.
        return

    deadline = time.time() + timeout
    while torch.cuda.memory_allocated(d) / 1e6 >= mem:
        if time.time() >= deadline:
            print(f'move_to_cpu: memory on {thread_data.device} did not decrease within {timeout} seconds, continuing.')
            return
        time.sleep(poll_interval)
2022-10-21 09:53:43 +02:00
2022-10-17 03:41:39 +02:00
def load_model_gfpgan():
    """Create the GFPGAN face-correction model for this thread's device.

    Loads ``<thread_data.gfpgan_file>.pth`` and stores the resulting
    ``GFPGANer`` in ``thread_data.model_gfpgan``.

    Raises:
        ValueError: ``thread_data.gfpgan_file`` is ``None``.
    """
    if thread_data.gfpgan_file is None:
        raise ValueError('Thread gfpgan_file is undefined.')

    target_device = torch.device(thread_data.device)

    # hack for a bug in facexlib: https://github.com/xinntao/facexlib/pull/19/files
    from facexlib.detection import retinaface
    retinaface.device = target_device
    print('forced retinaface.device to', thread_data.device)

    weights_path = thread_data.gfpgan_file + ".pth"
    thread_data.model_gfpgan = GFPGANer(
        device=target_device,
        model_path=weights_path,
        upscale=1,
        arch='clean',
        channel_multiplier=2,
        bg_upsampler=None,
    )
    print('loaded', thread_data.gfpgan_file, 'to', thread_data.model_gfpgan.device, 'precision', thread_data.precision)
def load_model_real_esrgan():
    """Create the RealESRGAN upscaler for this thread's device.

    ``thread_data.real_esrgan_file`` is used both as the weights path (with
    ``.pth`` appended) and as the key selecting the network architecture.  The
    resulting ``RealESRGANer`` is stored in ``thread_data.model_real_esrgan``.
    Half precision is never used on the ``'cpu'`` device; elsewhere it follows
    ``thread_data.model_is_half``.

    Only the selected architecture is instantiated (the previous version built
    both networks on every call and discarded one).

    Raises:
        ValueError: ``thread_data.real_esrgan_file`` is ``None``.
        KeyError: ``thread_data.real_esrgan_file`` is not a known model name.
    """
    if thread_data.real_esrgan_file is None:
        raise ValueError('Thread real_esrgan_file is undefined.')
    model_path = thread_data.real_esrgan_file + ".pth"

    # The supported models differ only in their number of RRDB blocks.
    num_blocks_by_model = {
        'RealESRGAN_x4plus': 23,
        'RealESRGAN_x4plus_anime_6B': 6,
    }
    num_block = num_blocks_by_model[thread_data.real_esrgan_file]
    model_to_use = RRDBNet(num_in_ch=3, num_out_ch=3, num_feat=64, num_block=num_block, num_grow_ch=32, scale=4)

    on_cpu = thread_data.device == 'cpu'
    use_half = False if on_cpu else thread_data.model_is_half  # cpu does not support half
    thread_data.model_real_esrgan = RealESRGANer(
        device=torch.device(thread_data.device),
        scale=2,
        model_path=model_path,
        model=model_to_use,
        pre_pad=0,
        half=use_half,
    )
    if on_cpu:
        thread_data.model_real_esrgan.model.to('cpu')

    thread_data.model_real_esrgan.model.name = thread_data.real_esrgan_file
    print('loaded ', thread_data.real_esrgan_file, 'to', thread_data.model_real_esrgan.device, 'precision', thread_data.precision)
2022-09-09 17:35:24 +02:00
2022-11-03 00:23:48 +01:00
def get_session_out_path(disk_path, session_id):
    """Return (and create) the output directory of a session.

    The directory is ``disk_path`` joined with ``session_id`` in which every
    character that is not an ASCII letter or digit is replaced by ``_``.

    Returns:
        The directory path, or ``None`` when ``disk_path`` or ``session_id``
        is ``None`` (nothing is created in that case).
    """
    if disk_path is None or session_id is None:
        return None

    safe_session_name = filename_regex.sub('_', session_id)
    out_dir = os.path.join(disk_path, safe_session_name)
    os.makedirs(out_dir, exist_ok=True)
    return out_dir
2022-10-18 03:27:15 +02:00
def get_base_path(disk_path, session_id, prompt, img_id, ext, suffix=None):
    """Build the output file path for one image (or its metadata file).

    The file name is ``<prompt>_<img_id>[_<suffix>].<ext>`` inside the session
    directory, where ``<prompt>`` is the prompt with every non-alphanumeric
    character replaced by ``_`` and cut to 50 characters.

    Returns:
        The path, or ``None`` when ``disk_path`` or ``session_id`` is ``None``.

    Raises:
        Exception: ``ext`` is ``None``.
    """
    if disk_path is None or session_id is None:
        return None
    if ext is None:
        raise Exception('Missing ext')

    session_dir = get_session_out_path(disk_path, session_id)
    safe_prompt = filename_regex.sub('_', prompt)[:50]

    stem = f"{safe_prompt}_{img_id}"
    if suffix is not None:
        stem = f"{stem}_{suffix}"
    return os.path.join(session_dir, f"{stem}.{ext}")
2022-10-21 09:53:43 +02:00
def apply_filters(filter_name, image_data, model_path=None):
    """Run one post-processing filter over an image and return the result.

    Args:
        filter_name: ``'gfpgan'`` or ``'real_esrgan'``; any other value returns
            ``image_data`` unchanged.
        image_data: image array indexed as ``[:, :, ::-1]`` before and after
            the filter, i.e. the channel order is reversed for the model and
            restored afterwards (assumes an H x W x C array - confirm with
            callers).
        model_path: model to use.  When given and different from the one
            currently recorded on ``thread_data``, that model is (re)loaded;
            otherwise the model is loaded only if none is loaded yet.

    Raises:
        Exception: the requested model is still not loaded after the load
            attempt.
    """
    print(f'Applying filter {filter_name}...')
    gc()  # Free space before loading new data.

    if isinstance(image_data, torch.Tensor):
        # Tensor.to() returns a new tensor; the result used to be discarded.
        image_data = image_data.to(thread_data.device)

    if filter_name == 'gfpgan':
        if model_path is not None and model_path != thread_data.gfpgan_file:
            thread_data.gfpgan_file = model_path
            load_model_gfpgan()
        elif not thread_data.model_gfpgan:
            load_model_gfpgan()
        if thread_data.model_gfpgan is None:
            raise Exception('Model "gfpgan" not loaded.')

        print('enhance with', thread_data.gfpgan_file, 'on', thread_data.model_gfpgan.device, 'precision', thread_data.precision)
        _, _, output = thread_data.model_gfpgan.enhance(image_data[:, :, ::-1], has_aligned=False, only_center_face=False, paste_back=True)
        image_data = output[:, :, ::-1]

    if filter_name == 'real_esrgan':
        if model_path is not None and model_path != thread_data.real_esrgan_file:
            thread_data.real_esrgan_file = model_path
            load_model_real_esrgan()
        elif not thread_data.model_real_esrgan:
            load_model_real_esrgan()
        if thread_data.model_real_esrgan is None:
            # The message used to name "gfpgan" here (copy/paste slip).
            raise Exception('Model "real_esrgan" not loaded.')

        print('enhance with', thread_data.real_esrgan_file, 'on', thread_data.model_real_esrgan.device, 'precision', thread_data.precision)
        output, _ = thread_data.model_real_esrgan.enhance(image_data[:, :, ::-1])
        image_data = output[:, :, ::-1]

    return image_data
2022-09-02 10:28:36 +02:00
def mk_img(req: Request):
    """Generator wrapper around ``do_mk_img`` that reports failures as JSON.

    Yields everything ``do_mk_img(req)`` yields.  If it raises, the traceback
    is printed, any loaded models are moved to the CPU (when the thread device
    is not ``'cpu'``), ``gc()`` is called and a final
    ``{"status": "failed", "detail": <message>}`` JSON string is yielded.

    The models are checked for ``None`` before being moved: the failure may
    have happened before or during model loading, and an ``AttributeError``
    raised here would replace the "failed" response.
    """
    try:
        yield from do_mk_img(req)
    except Exception as e:
        print(traceback.format_exc())

        if thread_data.device != 'cpu':
            if thread_data.modelFS is not None:
                thread_data.modelFS.to('cpu')
            if thread_data.modelCS is not None:
                thread_data.modelCS.to('cpu')
            if thread_data.model is not None:
                thread_data.model.model1.to("cpu")
                thread_data.model.model2.to("cpu")

        gc()  # Release from memory.
        yield json.dumps({
            "status": 'failed',
            "detail": str(e)
        })
2022-09-21 18:23:25 +02:00
2022-10-21 09:53:43 +02:00
def update_temp_img(req, x_samples):
    """Decode the current samples to JPEG previews and publish them.

    For each of the ``req.num_outputs`` samples, the latent is decoded with
    ``thread_data.modelFS``, converted to an 8-bit image and saved as JPEG into
    an in-memory buffer stored in ``thread_data.temp_images`` under
    ``"<session_id>/<index>"``.

    Returns:
        A list of ``{'path': '/image/tmp/<session_id>/<index>'}`` dicts, one
        per sample.
    """
    previews = []
    for idx in range(req.num_outputs):
        decoded = thread_data.modelFS.decode_first_stage(x_samples[idx].unsqueeze(0))
        # Map from [-1, 1] to [0, 1], then to 0..255 in h w c layout.
        pixels = torch.clamp((decoded + 1.0) / 2.0, min=0.0, max=1.0)
        pixels = 255.0 * rearrange(pixels[0].cpu().numpy(), "c h w -> h w c")
        pixels = pixels.astype(np.uint8)

        preview = Image.fromarray(pixels)
        jpeg_buf = BytesIO()
        preview.save(jpeg_buf, format='JPEG')
        jpeg_buf.seek(0)

        del preview, pixels, decoded
        # don't delete x_samples, it is used in the code that called this callback

        temp_key = str(req.session_id) + '/' + str(idx)
        thread_data.temp_images[temp_key] = jpeg_buf
        previews.append({'path': f'/image/tmp/{req.session_id}/{idx}'})
    return previews
# Build and return the appropriate generator for do_mk_img
2022-10-22 07:23:39 +02:00
def get_image_progress_generator(req, extra_props=None):
    """Return the per-step callback matching the request's streaming settings.

    Without ``req.stream_progress_updates`` the callback simply returns the
    samples it is given.  Otherwise ``thread_data.partial_x_samples`` is reset
    to ``None`` and the returned callback is a generator function that, on
    every step:

    * remembers the samples in ``thread_data.partial_x_samples``;
    * yields a JSON string with ``step``, ``step_time`` (seconds since the
      previous call, ``-1`` on the first call) and the ``extra_props`` entries;
    * adds preview images under ``output`` on every 5th step when
      ``req.stream_image_progress`` is set;
    * raises ``UserInitiatedStop`` after yielding if
      ``thread_data.stop_processing`` is set.
    """
    if req.stream_progress_updates:
        thread_data.partial_x_samples = None
        previous_call_time = -1

        def img_callback(x_samples, i):
            nonlocal previous_call_time

            thread_data.partial_x_samples = x_samples
            if previous_call_time == -1:
                step_time = -1
            else:
                step_time = time.time() - previous_call_time
            previous_call_time = time.time()

            progress = {"step": i, "step_time": step_time}
            if extra_props is not None:
                progress.update(extra_props)

            wants_preview = req.stream_image_progress and i % 5 == 0
            if wants_preview:
                progress['output'] = update_temp_img(req, x_samples)

            yield json.dumps(progress)

            if thread_data.stop_processing:
                raise UserInitiatedStop("User requested that we stop processing")

        return img_callback

    def empty_callback(x_samples, i):
        return x_samples

    return empty_callback
2022-09-21 18:23:25 +02:00
def do_mk_img(req: Request):
    """Render the images described by ``req``; generator of JSON strings.

    Steps, in order:

    1. (Re)load the models when the requested checkpoint, VAE or precision
       differs from what this thread currently holds.
    2. For img2img requests, encode the init image (and optional mask) to
       latent space.
    3. Compute the text conditioning, run ``_txt2img`` / ``_img2img`` and,
       when ``req.stream_progress_updates`` is set, yield the progress JSON
       strings produced by the sampler callback.
    4. Decode the samples, optionally save them and their metadata under
       ``req.save_to_disk_path``, and apply the face-correction / upscale
       filters.
    5. Yield ``json.dumps(res.json())`` with all produced images.

    A ``UserInitiatedStop`` raised while sampling is handled here: the last
    partial samples (if any) are decoded and returned unfiltered.

    Raises:
        FileNotFoundError: ``<req.use_stable_diffusion_model>.ckpt`` is missing.
        AssertionError: ``req.prompt`` is ``None`` or, for img2img,
            ``req.prompt_strength`` is outside ``[0, 1]``.
    """
    thread_data.stop_processing = False

    res = Response()
    res.request = req
    res.images = []

    thread_data.temp_images.clear()

    # custom model support:
    #  the req.use_stable_diffusion_model needs to be a valid path
    #  to the ckpt file (without the extension).
    if not os.path.exists(req.use_stable_diffusion_model + '.ckpt'):
        raise FileNotFoundError(f'Cannot find {req.use_stable_diffusion_model}.ckpt')

    needs_model_reload = False
    if not thread_data.model or thread_data.ckpt_file != req.use_stable_diffusion_model or thread_data.vae_file != req.use_vae_model:
        thread_data.ckpt_file = req.use_stable_diffusion_model
        thread_data.vae_file = req.use_vae_model
        needs_model_reload = True

    if thread_data.device != 'cpu':
        if (thread_data.precision == 'autocast' and (req.use_full_precision or not thread_data.model_is_half)) or \
                (thread_data.precision == 'full' and not req.use_full_precision and not thread_data.force_full_precision):
            thread_data.precision = 'full' if req.use_full_precision else 'autocast'
            needs_model_reload = True

    if needs_model_reload:
        unload_models()
        unload_filters()
        load_model_ckpt()

    if thread_data.turbo != req.turbo:
        thread_data.turbo = req.turbo
        thread_data.model.turbo = req.turbo

    # Start by cleaning memory, loading and unloading things can leave memory allocated.
    gc()

    opt_prompt = req.prompt
    opt_seed = req.seed
    opt_n_iter = 1
    opt_C = 4
    opt_f = 8
    opt_ddim_eta = 0.0

    print(req, '\n    device', torch.device(thread_data.device), "as", thread_data.device_name)
    print('\n\n    Using precision:', thread_data.precision)

    seed_everything(opt_seed)

    batch_size = req.num_outputs

    prompt = opt_prompt
    assert prompt is not None
    data = [batch_size * [prompt]]

    if thread_data.precision == "autocast" and thread_data.device != "cpu":
        precision_scope = autocast
    else:
        precision_scope = nullcontext

    mask = None

    if req.init_image is None:
        handler = _txt2img

        init_latent = None
        t_enc = None
    else:
        handler = _img2img

        init_image = load_img(req.init_image, req.width, req.height)
        init_image = init_image.to(thread_data.device)

        if thread_data.device != "cpu" and thread_data.precision == "autocast":
            init_image = init_image.half()

        thread_data.modelFS.to(thread_data.device)

        init_image = repeat(init_image, '1 ... -> b ...', b=batch_size)
        init_latent = thread_data.modelFS.get_first_stage_encoding(thread_data.modelFS.encode_first_stage(init_image))  # move to latent space

        if req.mask is not None:
            mask = load_mask(req.mask, req.width, req.height, init_latent.shape[2], init_latent.shape[3], True).to(thread_data.device)
            mask = mask[0][0].unsqueeze(0).repeat(4, 1, 1).unsqueeze(0)
            mask = repeat(mask, '1 ... -> b ...', b=batch_size)

            if thread_data.device != "cpu" and thread_data.precision == "autocast":
                mask = mask.half()

        # Send to CPU and wait until complete.
        move_to_cpu(thread_data.modelFS)

        assert 0. <= req.prompt_strength <= 1., 'can only work with strength in [0.0, 1.0]'
        t_enc = int(req.prompt_strength * req.num_inference_steps)
        print(f"target t_enc is {t_enc} steps")

    # get_session_out_path() also creates the session directory.
    if req.save_to_disk_path is not None:
        session_out_path = get_session_out_path(req.save_to_disk_path, req.session_id)
    else:
        session_out_path = None

    with torch.no_grad():
        for n in trange(opt_n_iter, desc="Sampling"):
            for prompts in tqdm(data, desc="data"):

                with precision_scope("cuda"):
                    if thread_data.reduced_memory:
                        thread_data.modelCS.to(thread_data.device)
                    uc = None
                    if req.guidance_scale != 1.0:
                        uc = thread_data.modelCS.get_learned_conditioning(batch_size * [req.negative_prompt])
                    if isinstance(prompts, tuple):
                        prompts = list(prompts)

                    subprompts, weights = split_weighted_subprompts(prompts[0])
                    if len(subprompts) > 1:
                        # uc is None when guidance_scale == 1.0, so the
                        # accumulator cannot always be sized from it.
                        c = None if uc is None else torch.zeros_like(uc)
                        totalWeight = sum(weights)
                        # normalize each "sub prompt" and add it
                        for i in range(len(subprompts)):
                            weight = weights[i] / totalWeight
                            sub_cond = thread_data.modelCS.get_learned_conditioning(subprompts[i])
                            if c is None:
                                # Same batch-sized zero accumulator as zeros_like(uc)
                                # (assumes sub_cond has a leading batch dim of 1).
                                c = torch.zeros_like(repeat(sub_cond, '1 ... -> b ...', b=batch_size))
                            c = torch.add(c, sub_cond, alpha=weight)
                    else:
                        c = thread_data.modelCS.get_learned_conditioning(prompts)

                    if thread_data.reduced_memory:
                        thread_data.modelFS.to(thread_data.device)

                    n_steps = req.num_inference_steps if req.init_image is None else t_enc
                    img_callback = get_image_progress_generator(req, {"total_steps": n_steps})

                    # run the handler
                    try:
                        print('Running handler...')
                        if handler == _txt2img:
                            x_samples = _txt2img(req.width, req.height, req.num_outputs, req.num_inference_steps, req.guidance_scale, None, opt_C, opt_f, opt_ddim_eta, c, uc, opt_seed, img_callback, mask, req.sampler)
                        else:
                            x_samples = _img2img(init_latent, t_enc, batch_size, req.guidance_scale, c, uc, req.num_inference_steps, opt_ddim_eta, opt_seed, img_callback, mask)

                        if req.stream_progress_updates:
                            # In streaming mode the handler result is iterated
                            # for its progress JSON strings.
                            yield from x_samples
                        if hasattr(thread_data, 'partial_x_samples'):
                            if thread_data.partial_x_samples is not None:
                                x_samples = thread_data.partial_x_samples
                            del thread_data.partial_x_samples
                    except UserInitiatedStop:
                        # Keep whatever partial result the callback recorded.
                        if not hasattr(thread_data, 'partial_x_samples'):
                            continue
                        if thread_data.partial_x_samples is None:
                            del thread_data.partial_x_samples
                            continue
                        x_samples = thread_data.partial_x_samples
                        del thread_data.partial_x_samples

                    print("decoding images")
                    img_data = [None] * batch_size
                    for i in range(batch_size):
                        x_samples_ddim = thread_data.modelFS.decode_first_stage(x_samples[i].unsqueeze(0))
                        x_sample = torch.clamp((x_samples_ddim + 1.0) / 2.0, min=0.0, max=1.0)
                        x_sample = 255.0 * rearrange(x_sample[0].cpu().numpy(), "c h w -> h w c")
                        x_sample = x_sample.astype(np.uint8)
                        img_data[i] = x_sample
                    del x_samples, x_samples_ddim, x_sample

                    print("saving images")
                    for i in range(batch_size):
                        img = Image.fromarray(img_data[i])
                        img_id = base64.b64encode(int(time.time() + i).to_bytes(8, 'big')).decode()  # Generate unique ID based on time.
                        img_id = img_id.translate({43: None, 47: None, 61: None})[-8:]  # Remove + / = and keep last 8 chars.

                        has_filters = (req.use_face_correction is not None and req.use_face_correction.startswith('GFPGAN')) or \
                            (req.use_upscale is not None and req.use_upscale.startswith('RealESRGAN'))

                        return_orig_img = not has_filters or not req.show_only_filtered_image

                        if thread_data.stop_processing:
                            return_orig_img = True

                        if req.save_to_disk_path is not None:
                            if return_orig_img:
                                img_out_path = get_base_path(req.save_to_disk_path, req.session_id, prompts[0], img_id, req.output_format)
                                save_image(img, img_out_path)
                            meta_out_path = get_base_path(req.save_to_disk_path, req.session_id, prompts[0], img_id, 'txt')
                            save_metadata(meta_out_path, req, prompts[0], opt_seed)

                        if return_orig_img:
                            img_str = img_to_base64_str(img, req.output_format)
                            res_image_orig = ResponseImage(data=img_str, seed=opt_seed)
                            res.images.append(res_image_orig)

                            if req.save_to_disk_path is not None:
                                res_image_orig.path_abs = img_out_path
                        del img

                        if has_filters and not thread_data.stop_processing:
                            filters_applied = []
                            if req.use_face_correction:
                                img_data[i] = apply_filters('gfpgan', img_data[i], req.use_face_correction)
                                filters_applied.append(req.use_face_correction)
                            if req.use_upscale:
                                img_data[i] = apply_filters('real_esrgan', img_data[i], req.use_upscale)
                                filters_applied.append(req.use_upscale)
                            if len(filters_applied) > 0:
                                filtered_image = Image.fromarray(img_data[i])
                                filtered_img_data = img_to_base64_str(filtered_image, req.output_format)
                                response_image = ResponseImage(data=filtered_img_data, seed=opt_seed)
                                res.images.append(response_image)
                                if req.save_to_disk_path is not None:
                                    filtered_img_out_path = get_base_path(req.save_to_disk_path, req.session_id, prompts[0], img_id, req.output_format, "_".join(filters_applied))
                                    save_image(filtered_image, filtered_img_out_path)
                                    response_image.path_abs = filtered_img_out_path
                                del filtered_image
                        # Filter Applied, move to next seed
                        opt_seed += 1

                    move_to_cpu(thread_data.modelFS)
                    del img_data
                    gc()
                    if thread_data.device != 'cpu':
                        print(f'memory_final = {round(torch.cuda.memory_allocated(thread_data.device) / 1e6, 2)}Mb')

    print('Task completed')

    yield json.dumps(res.json())
2022-09-02 10:28:36 +02:00
2022-09-09 17:35:24 +02:00
def save_image(img, img_out_path):
    """Save ``img`` to ``img_out_path``; failures are printed, not raised.

    Args:
        img: object with a ``save(path)`` method (a PIL image in this file).
        img_out_path: destination path passed to ``img.save``.

    Only ``Exception`` subclasses are swallowed, so ``KeyboardInterrupt`` and
    ``SystemExit`` still propagate (the previous bare ``except`` hid them).
    """
    try:
        img.save(img_out_path)
    except Exception:
        print('could not save the file', traceback.format_exc())
2022-10-15 05:20:57 +02:00
def save_metadata(meta_out_path, req, prompt, opt_seed):
    """Write a plain-text summary of the request settings to ``meta_out_path``.

    The first line of the file is the prompt; every following line is a
    ``Label: value`` pair read from ``req``. Writing is best-effort: an
    ordinary error while opening or writing the file is printed with its
    traceback and swallowed.

    Args:
        meta_out_path: Path of the text file to create (UTF-8).
        req: Request object; the attributes read are exactly those
            interpolated below.
        prompt: Prompt text placed on the first line.
        opt_seed: Seed value recorded in the file.

    Returns:
        None.
    """
    # NOTE: the text is built outside the ``try`` on purpose (as before), so a
    # malformed ``req`` still surfaces as an error instead of being logged as
    # a file-save problem.
    metadata = f'''{prompt}
Width: {req.width}
Height: {req.height}
Seed: {opt_seed}
Steps: {req.num_inference_steps}
Guidance Scale: {req.guidance_scale}
Prompt Strength: {req.prompt_strength}
Use Face Correction: {req.use_face_correction}
Use Upscaling: {req.use_upscale}
Sampler: {req.sampler}
Negative Prompt: {req.negative_prompt}
Stable Diffusion model: {req.use_stable_diffusion_model + '.ckpt'}
VAE model: {req.use_vae_model}
'''
    try:
        with open(meta_out_path, 'w', encoding='utf-8') as f:
            f.write(metadata)
    except Exception:
        # Not a bare ``except``: KeyboardInterrupt / SystemExit must keep
        # propagating so the process or worker thread can still be stopped.
        print('could not save the file', traceback.format_exc())
2022-09-22 20:49:05 +02:00
def _txt2img(opt_W, opt_H, opt_n_samples, opt_ddim_steps, opt_scale, start_code, opt_C, opt_f, opt_ddim_eta, c, uc, opt_seed, img_callback, mask, sampler_name):
    """Generator that relays everything ``thread_data.model.sample`` yields.

    The sample shape is ``[opt_n_samples, opt_C, opt_H // opt_f, opt_W // opt_f]``.
    Because this is a generator, none of the work below happens until the
    caller starts iterating.

    Args:
        opt_W, opt_H: Requested width and height, floor-divided by ``opt_f``.
        opt_n_samples: First entry of the sample shape.
        opt_ddim_steps: Passed to the sampler as ``S`` (and to ``make_schedule``
            when the sampler is ``'ddim'``).
        opt_scale: Passed as ``unconditional_guidance_scale``.
        start_code: Passed as ``x_T``.
        opt_C: Second entry of the sample shape.
        opt_f: Divisor applied to ``opt_H`` / ``opt_W``.
        opt_ddim_eta: Passed as ``eta`` (and ``ddim_eta`` for the schedule).
        c, uc: Passed as ``conditioning`` / ``unconditional_conditioning``.
        opt_seed: Passed as ``seed``.
        img_callback: Passed through unchanged.
        mask: Passed through unchanged.
        sampler_name: Passed as ``sampler``; ``'ddim'`` additionally triggers
            ``make_schedule``.
    """
    latent_shape = [opt_n_samples, opt_C, opt_H // opt_f, opt_W // opt_f]

    # modelCS is handed to move_to_cpu before sampling starts
    # (presumably to free device memory -- helper is defined elsewhere).
    move_to_cpu(thread_data.modelCS)

    if sampler_name == 'ddim':
        thread_data.model.make_schedule(
            ddim_num_steps=opt_ddim_steps,
            ddim_eta=opt_ddim_eta,
            verbose=False,
        )

    sample_kwargs = {
        'S': opt_ddim_steps,
        'conditioning': c,
        'seed': opt_seed,
        'shape': latent_shape,
        'verbose': False,
        'unconditional_guidance_scale': opt_scale,
        'unconditional_conditioning': uc,
        'eta': opt_ddim_eta,
        'x_T': start_code,
        'img_callback': img_callback,
        'mask': mask,
        'sampler': sampler_name,
    }
    yield from thread_data.model.sample(**sample_kwargs)
2022-09-02 10:28:36 +02:00
2022-09-22 20:49:05 +02:00
def _img2img(init_latent, t_enc, batch_size, opt_scale, c, uc, opt_ddim_steps, opt_ddim_eta, opt_seed, img_callback, mask):
    """Generator that encodes ``init_latent`` and relays the sampler's output.

    ``thread_data.model.stochastic_encode`` is called first, then its result
    is handed to ``thread_data.model.sample`` with the ``'ddim'`` sampler and
    everything that call yields is re-yielded. Because this is a generator,
    nothing runs until the caller starts iterating.

    Args:
        init_latent: Latent passed to ``stochastic_encode``; also used as
            ``x_T`` when a mask is supplied.
        t_enc: Repeated ``batch_size`` times to build the tensor given to
            ``stochastic_encode``; also the first argument to ``sample``.
        batch_size: Number of repetitions of ``t_enc``.
        opt_scale: Passed as ``unconditional_guidance_scale``.
        c, uc: Conditioning and unconditional conditioning.
        opt_ddim_steps, opt_ddim_eta, opt_seed: Forwarded to
            ``stochastic_encode``.
        img_callback: Passed through unchanged.
        mask: Passed through unchanged; ``None`` disables ``x_T``.
    """
    # encode (scaled latent)
    encode_steps = torch.tensor([t_enc] * batch_size).to(thread_data.device)
    z_enc = thread_data.model.stochastic_encode(
        init_latent,
        encode_steps,
        opt_seed,
        opt_ddim_eta,
        opt_ddim_steps,
    )

    # x_T is only supplied when a mask is in play.
    if mask is None:
        x_T = None
    else:
        x_T = init_latent

    # decode it
    yield from thread_data.model.sample(
        t_enc,
        c,
        z_enc,
        unconditional_guidance_scale=opt_scale,
        unconditional_conditioning=uc,
        img_callback=img_callback,
        mask=mask,
        x_T=x_T,
        sampler='ddim',
    )
2022-09-02 10:28:36 +02:00
2022-09-09 17:35:24 +02:00
def gc():
    """Run the Python garbage collector, then CUDA cache cleanup off-CPU.

    ``gc_collect()`` always runs. When ``thread_data.device`` is anything
    other than ``'cpu'``, ``torch.cuda.empty_cache()`` and
    ``torch.cuda.ipc_collect()`` are called as well.
    """
    gc_collect()
    if thread_data.device != 'cpu':
        torch.cuda.empty_cache()
        torch.cuda.ipc_collect()
2022-09-02 10:28:36 +02:00
# internal
def chunk(it, size):
    """Return an iterator over consecutive ``size``-length tuples of ``it``.

    The final tuple may be shorter. Iteration stops at the first empty tuple,
    so an exhausted input (or ``size == 0``) produces no items.

    Args:
        it: Any iterable; it is wrapped with ``iter()`` immediately.
        size: Maximum number of elements per tuple.
    """
    source = iter(it)

    def take_next():
        # Pull up to ``size`` items; an empty tuple is the stop sentinel.
        return tuple(islice(source, size))

    return iter(take_next, ())
def load_model_from_config(ckpt, verbose=False):
    """Load a checkpoint onto the CPU and return its ``"state_dict"`` entry.

    Prints the checkpoint path, and the ``"global_step"`` value when the
    checkpoint contains one.

    Args:
        ckpt: Path (or file-like object) accepted by ``torch.load``.
        verbose: Accepted but not used by this function.

    Returns:
        The object stored under the ``"state_dict"`` key.

    Raises:
        KeyError: If the checkpoint has no ``"state_dict"`` key.
    """
    print(f"Loading model from {ckpt}")
    # NOTE(review): torch.load unpickles the file -- only load checkpoints
    # from sources you trust.
    checkpoint = torch.load(ckpt, map_location="cpu")
    if "global_step" in checkpoint:
        print(f"Global Step: {checkpoint['global_step']}")
    return checkpoint["state_dict"]
# utils
2022-09-13 16:29:41 +02:00
class UserInitiatedStop(Exception):
    """Plain ``Exception`` subclass that adds no behaviour of its own.

    NOTE(review): the name suggests it signals a stop requested by the user;
    the raise/except sites are outside this part of the file -- confirm there.
    """
2022-09-02 10:28:36 +02:00
2022-09-14 08:06:55 +02:00
def load_img(img_str, w0, h0):
    """Decode a base64 image string into a float tensor scaled to [-1, 1].

    The image is converted to RGB, resized with Lanczos resampling to a size
    whose sides are rounded down to multiples of 64, and returned with a
    leading batch axis and channels moved ahead of height/width.

    Args:
        img_str: Base64 image string understood by ``base64_str_to_img``.
        w0: Target width, or ``None``.
        h0: Target height, or ``None``. The decoded image's own size is used
            unless both ``w0`` and ``h0`` are given.

    Returns:
        ``torch.Tensor`` of dtype float32 with layout (1, 3, height, width).
    """
    source = base64_str_to_img(img_str).convert("RGB")
    width, height = source.size
    print(f"loaded input image of size ({width}, {height}) from base64")

    if not (h0 is None or w0 is None):
        width, height = w0, h0

    # resize to integer multiple of 64
    width = width - width % 64
    height = height - height % 64
    resized = source.resize((width, height), resample=Image.Resampling.LANCZOS)

    pixels = np.array(resized).astype(np.float32) / 255.0
    # (H, W, C) -> (1, C, H, W)
    batched = pixels[None].transpose(0, 3, 1, 2)
    tensor = torch.from_numpy(batched)
    return 2. * tensor - 1.
2022-09-15 14:24:03 +02:00
def load_mask(mask_str, h0, w0, newH, newW, invert=False):
    """Decode a base64 mask string into a float tensor scaled to [0, 1].

    The mask is converted to RGB, optionally inverted, resized with Lanczos
    resampling to ``(newW, newH)``, and returned with a leading batch axis
    and channels moved ahead of height/width.

    Args:
        mask_str: Base64 image string understood by ``base64_str_to_img``.
        h0: Height used only for the "New mask size" log line, or ``None``.
        w0: Width used only for the "New mask size" log line, or ``None``.
        newH: Height the mask is actually resized to.
        newW: Width the mask is actually resized to.
        invert: When true, the mask is inverted with ``ImageOps.invert``.

    Returns:
        ``torch.Tensor`` of dtype float32 with layout (1, 3, newH, newW).
    """
    mask_img = base64_str_to_img(mask_str).convert("RGB")
    log_w, log_h = mask_img.size
    print(f"loaded input mask of size ({log_w}, {log_h})")

    if invert:
        print("inverted")
        mask_img = ImageOps.invert(mask_img)

    if not (h0 is None or w0 is None):
        log_w, log_h = w0, h0

    # Rounded down to a multiple of 64. These values feed only the log line
    # below; the resize itself uses newW / newH.
    log_w = log_w - log_w % 64
    log_h = log_h - log_h % 64
    print(f"New mask size ({log_w}, {log_h})")

    mask_img = mask_img.resize((newW, newH), resample=Image.Resampling.LANCZOS)
    pixels = np.array(mask_img)
    pixels = pixels.astype(np.float32) / 255.0
    # (H, W, C) -> (1, C, H, W)
    batched = pixels[None].transpose(0, 3, 1, 2)
    return torch.from_numpy(batched)
2022-09-02 10:28:36 +02:00
# https://stackoverflow.com/a/61114178
2022-10-06 11:35:34 +02:00
def img_to_base64_str(img, output_format="PNG"):
    """Serialise ``img`` and return it as a base64 ``data:`` URL string.

    Args:
        img: Object exposing ``save(file_obj, format=...)``.
        output_format: Format name handed to ``img.save``. ``"png"`` (any
            case) is labelled ``image/png``; every other value is labelled
            ``image/jpeg``.

    Returns:
        ``"data:<mime type>;base64,<payload>"``.
    """
    stream = BytesIO()
    img.save(stream, format=output_format)
    raw_bytes = stream.getvalue()

    if output_format.lower() == "png":
        mime_type = "image/png"
    else:
        mime_type = "image/jpeg"

    payload = base64.b64encode(raw_bytes).decode()
    return f"data:{mime_type};base64," + payload
2022-11-01 09:52:42 +01:00
def base64_str_to_buffer(img_str):
    """Decode a base64 image string into an in-memory byte buffer.

    A ``data:<mime type>;base64,`` header is removed when present. The header
    is cut at its terminating comma rather than by an assumed length, so
    headers other than ``image/png`` / ``image/jpeg`` (for example
    ``image/jpg``) and strings with no header at all are decoded without
    losing leading payload characters.

    Args:
        img_str: Base64 text, optionally prefixed with a data-URL header.

    Returns:
        ``io.BytesIO`` positioned at the start of the decoded bytes.
    """
    if img_str.startswith("data:"):
        # A comma cannot occur inside base64 data, so the first comma is
        # always the end of the header.
        _header, separator, payload = img_str.partition(",")
        if separator:
            img_str = payload

    return BytesIO(base64.b64decode(img_str))
def base64_str_to_img(img_str):
    """Open a base64 image string as a PIL image.

    Args:
        img_str: Base64 image string accepted by ``base64_str_to_buffer``.

    Returns:
        The result of ``Image.open`` on the decoded bytes.
    """
    return Image.open(base64_str_to_buffer(img_str))