2022-09-14 13:22:03 +02:00
import json
2022-09-02 10:28:36 +02:00
import os , re
2022-09-05 13:21:43 +02:00
import traceback
2022-09-02 10:28:36 +02:00
import torch
import numpy as np
from omegaconf import OmegaConf
2022-09-15 14:24:03 +02:00
from PIL import Image , ImageOps
2022-09-02 10:28:36 +02:00
from tqdm import tqdm , trange
from itertools import islice
from einops import rearrange
import time
from pytorch_lightning import seed_everything
from torch import autocast
from contextlib import nullcontext
from einops import rearrange , repeat
from ldm . util import instantiate_from_config
from optimizedSD . optimUtils import split_weighted_subprompts
from transformers import logging
2022-09-09 17:35:24 +02:00
from gfpgan import GFPGANer
from basicsr . archs . rrdbnet_arch import RRDBNet
from realesrgan import RealESRGANer
2022-09-03 20:42:48 +02:00
import uuid
2022-09-02 10:28:36 +02:00
logging . set_verbosity_error ( )
# consts
config_yaml = " optimizedSD/v1-inference.yaml "
2022-09-09 17:35:24 +02:00
filename_regex = re . compile ( ' [^a-zA-Z0-9] ' )
2022-09-02 10:28:36 +02:00
# api stuff
from . import Request , Response , Image as ResponseImage
import base64
from io import BytesIO
2022-09-16 18:02:08 +02:00
#from colorama import Fore
2022-09-02 10:28:36 +02:00
2022-10-17 03:41:39 +02:00
from threading import local as LocalThreadVars
thread_data = LocalThreadVars ( )
2022-09-11 07:04:04 +02:00
2022-10-17 03:41:39 +02:00
def device_would_fail ( device ) :
if device == ' cpu ' : return None
# Returns None when no issues found, otherwise returns the detected error str.
# Memory check
try :
mem_free , mem_total = torch . cuda . mem_get_info ( device )
mem_total / = float ( 10 * * 9 )
if mem_total < 3.0 :
return ' GPUs with less than 3 GB of VRAM are not compatible with Stable Diffusion '
except RuntimeError as e :
return str ( e ) # Return cuda errors from mem_get_info as strings
return None
def device_select ( device ) :
if device == ' cpu ' : return True
if not torch . cuda . is_available ( ) : return False
failure_msg = device_would_fail ( device )
if failure_msg :
if ' invalid device ' in failure_msg :
raise NameError ( f ' GPU " { device } " could not be found. Remove this device from config.render_devices or use one of " auto " or " cuda " . ' )
print ( failure_msg )
return False
device_name = torch . cuda . get_device_name ( device )
# otherwise these NVIDIA cards create green images
thread_data . force_full_precision = ( ' nvidia ' in device_name . lower ( ) or ' geforce ' in device_name . lower ( ) ) and ( ' 1660 ' in device_name or ' 1650 ' in device_name )
if thread_data . force_full_precision :
print ( ' forcing full precision on NVIDIA 16xx cards, to avoid green images. GPU detected: ' , gpu_name )
2022-09-11 07:04:04 +02:00
2022-10-17 03:41:39 +02:00
thread_data . device = device
thread_data . has_valid_gpu = True
return True
def device_init ( device_selection = None ) :
# Thread bound properties
thread_data . stop_processing = False
thread_data . temp_images = { }
thread_data . ckpt_file = None
thread_data . gfpgan_file = None
thread_data . real_esrgan_file = None
thread_data . model = None
thread_data . modelCS = None
thread_data . modelFS = None
thread_data . model_gfpgan = None
thread_data . model_real_esrgan = None
thread_data . model_is_half = False
thread_data . model_fs_is_half = False
thread_data . device = None
thread_data . unet_bs = 1
thread_data . precision = ' autocast '
thread_data . sampler_plms = None
thread_data . sampler_ddim = None
thread_data . turbo = False
thread_data . has_valid_gpu = False
thread_data . force_full_precision = False
if device_selection . lower ( ) == ' cpu ' :
print ( ' CPU requested, skipping gpu init. ' )
thread_data . device = ' cpu '
return
if not torch . cuda . is_available ( ) :
print ( ' WARNING: torch.cuda is not available. Using the CPU, but this will be very slow! ' )
return
device_count = torch . cuda . device_count ( )
if device_count < = 1 and device_selection == ' auto ' :
device_selection = ' current ' # Use 'auto' only when there is more than one compatible device found.
if device_selection == ' auto ' :
print ( ' Autoselecting GPU. Using most free memory. ' )
max_mem_free = 0
best_device = None
for device in range ( device_count ) :
mem_free , mem_total = torch . cuda . mem_get_info ( device )
mem_free / = float ( 10 * * 9 )
mem_total / = float ( 10 * * 9 )
device_name = torch . cuda . get_device_name ( device )
print ( f ' GPU: { device } detected: { device_name } - Memory: { round ( mem_total - mem_free , 2 ) } Go / { round ( mem_total , 2 ) } Go ' )
if max_mem_free < mem_free :
max_mem_free = mem_free
best_device = device
if best_device and device_select ( device ) :
print ( f ' Setting GPU: { device } as active ' )
torch . cuda . device ( device )
return
if isinstance ( device_selection , str ) :
device_selection = device_selection . lower ( )
if device_selection . startswith ( ' gpu: ' ) :
device_selection = int ( device_selection [ 4 : ] )
2022-10-17 07:05:51 +02:00
if device_selection != ' cuda ' and device_selection != ' current ' and device_selection != ' gpu ' :
2022-10-17 03:41:39 +02:00
if device_select ( device_selection ) :
if isinstance ( device_selection , int ) :
print ( f ' Setting GPU: { device_selection } as active ' )
else :
print ( f ' Setting { device_selection } as active ' )
torch . cuda . device ( device_selection )
return
# By default use current device.
print ( ' Checking current GPU... ' )
device = torch . cuda . current_device ( )
device_name = torch . cuda . get_device_name ( device )
print ( f ' GPU: { device } detected: { device_name } ' )
if device_select ( device ) :
return
2022-09-09 17:35:24 +02:00
print ( ' WARNING: No compatible GPU found. Using the CPU, but this will be very slow! ' )
2022-10-17 03:41:39 +02:00
thread_data . device = ' cpu '
def is_first_cuda_device ( device ) :
2022-10-17 05:06:41 +02:00
if device is None : return False
if device == 0 or device == ' 0 ' : return True
if device == ' cuda ' or device == ' cuda:0 ' : return True
2022-10-17 07:05:51 +02:00
if device == ' gpu ' or device == ' gpu:0 ' : return True
if device == ' current ' : return True
2022-10-17 05:06:41 +02:00
if device == torch . device ( 0 ) : return True
2022-10-17 03:41:39 +02:00
return False
def load_model_ckpt ( ) :
if not thread_data . ckpt_file : raise ValueError ( f ' Thread ckpt_file is undefined. ' )
if not os . path . exists ( thread_data . ckpt_file + ' .ckpt ' ) : raise FileNotFoundError ( f ' Cannot find { thread_data . ckpt_file } .ckpt ' )
if not thread_data . precision :
thread_data . precision = ' full ' if thread_data . force_full_precision else ' autocast '
if not thread_data . unet_bs :
thread_data . unet_bs = 1
2022-09-02 10:28:36 +02:00
2022-10-08 09:49:18 +02:00
unload_model ( )
2022-10-17 03:41:39 +02:00
if thread_data . device == ' cpu ' :
thread_data . precision = ' full '
2022-09-09 17:35:24 +02:00
2022-10-17 03:41:39 +02:00
print ( ' loading ' , thread_data . ckpt_file , ' to ' , thread_data . device , ' using precision ' , thread_data . precision )
sd = load_model_from_config ( thread_data . ckpt_file + ' .ckpt ' )
2022-09-02 10:28:36 +02:00
li , lo = [ ] , [ ]
for key , value in sd . items ( ) :
sp = key . split ( " . " )
if ( sp [ 0 ] ) == " model " :
if " input_blocks " in sp :
li . append ( key )
elif " middle_block " in sp :
li . append ( key )
elif " time_embed " in sp :
li . append ( key )
else :
lo . append ( key )
for key in li :
sd [ " model1. " + key [ 6 : ] ] = sd . pop ( key )
for key in lo :
sd [ " model2. " + key [ 6 : ] ] = sd . pop ( key )
config = OmegaConf . load ( f " { config_yaml } " )
model = instantiate_from_config ( config . modelUNet )
_ , _ = model . load_state_dict ( sd , strict = False )
model . eval ( )
2022-10-17 03:41:39 +02:00
model . cdevice = torch . device ( thread_data . device )
model . unet_bs = thread_data . unet_bs
model . turbo = thread_data . turbo
if thread_data . device != ' cpu ' :
model . to ( thread_data . device )
thread_data . model = model
2022-09-02 10:28:36 +02:00
modelCS = instantiate_from_config ( config . modelCondStage )
_ , _ = modelCS . load_state_dict ( sd , strict = False )
modelCS . eval ( )
2022-10-17 03:41:39 +02:00
modelCS . cond_stage_model . device = torch . device ( thread_data . device )
if thread_data . device != ' cpu ' :
modelCS . to ( thread_data . device )
thread_data . modelCS = modelCS
2022-09-02 10:28:36 +02:00
modelFS = instantiate_from_config ( config . modelFirstStage )
_ , _ = modelFS . load_state_dict ( sd , strict = False )
modelFS . eval ( )
2022-10-17 03:41:39 +02:00
if thread_data . device != ' cpu ' :
modelFS . to ( thread_data . device )
thread_data . modelFS = modelFS
2022-09-02 10:28:36 +02:00
del sd
2022-10-17 03:41:39 +02:00
if thread_data . device != " cpu " and thread_data . precision == " autocast " :
thread_data . model . half ( )
thread_data . modelCS . half ( )
thread_data . modelFS . half ( )
thread_data . model_is_half = True
thread_data . model_fs_is_half = True
2022-09-02 10:28:36 +02:00
else :
2022-10-17 03:41:39 +02:00
thread_data . model_is_half = False
thread_data . model_fs_is_half = False
2022-09-02 10:28:36 +02:00
2022-10-17 03:41:39 +02:00
print ( ' loaded ' , thread_data . ckpt_file , ' as ' , model . device , ' -> ' , modelCS . cond_stage_model . device , ' -> ' , thread_data . modelFS . device , ' using precision ' , thread_data . precision )
2022-09-09 17:35:24 +02:00
2022-10-08 09:49:18 +02:00
def unload_model ( ) :
2022-10-17 03:41:39 +02:00
if thread_data . model is not None :
print ( ' Unloading models... ' )
del thread_data . model
del thread_data . modelCS
del thread_data . modelFS
thread_data . model = None
thread_data . modelCS = None
thread_data . modelFS = None
def load_model_gfpgan ( ) :
if thread_data . gfpgan_file is None :
print ( ' load_model_gfpgan called without setting gfpgan_file ' )
2022-09-09 17:35:24 +02:00
return
2022-10-17 03:41:39 +02:00
if thread_data . device != ' cpu ' and not is_first_cuda_device ( thread_data . device ) :
#TODO Remove when fixed - A bug with GFPGANer and facexlib needs to be fixed before use on other devices.
raise Exception ( f ' Current device { torch . device ( thread_data . device ) } is not { torch . device ( 0 ) } . ' )
model_path = thread_data . gfpgan_file + " .pth "
thread_data . model_gfpgan = GFPGANer ( device = torch . device ( thread_data . device ) , model_path = model_path , upscale = 1 , arch = ' clean ' , channel_multiplier = 2 , bg_upsampler = None )
print ( ' loaded ' , thread_data . gfpgan_file , ' to ' , thread_data . model_gfpgan . device , ' precision ' , thread_data . precision )
def load_model_real_esrgan ( ) :
if thread_data . real_esrgan_file is None :
print ( ' load_model_real_esrgan called without setting real_esrgan_file ' )
2022-09-09 17:35:24 +02:00
return
2022-10-17 03:41:39 +02:00
model_path = thread_data . real_esrgan_file + " .pth "
2022-09-09 17:35:24 +02:00
RealESRGAN_models = {
' RealESRGAN_x4plus ' : RRDBNet ( num_in_ch = 3 , num_out_ch = 3 , num_feat = 64 , num_block = 23 , num_grow_ch = 32 , scale = 4 ) ,
' RealESRGAN_x4plus_anime_6B ' : RRDBNet ( num_in_ch = 3 , num_out_ch = 3 , num_feat = 64 , num_block = 6 , num_grow_ch = 32 , scale = 4 )
}
2022-10-17 03:41:39 +02:00
model_to_use = RealESRGAN_models [ thread_data . real_esrgan_file ]
2022-09-09 17:35:24 +02:00
2022-10-17 03:41:39 +02:00
if thread_data . device == ' cpu ' :
thread_data . model_real_esrgan = RealESRGANer ( device = torch . device ( thread_data . device ) , scale = 2 , model_path = model_path , model = model_to_use , pre_pad = 0 , half = False ) # cpu does not support half
#thread_data.model_real_esrgan.device = torch.device(thread_data.device)
thread_data . model_real_esrgan . model . to ( ' cpu ' )
2022-09-09 17:35:24 +02:00
else :
2022-10-17 03:41:39 +02:00
thread_data . model_real_esrgan = RealESRGANer ( device = torch . device ( thread_data . device ) , scale = 2 , model_path = model_path , model = model_to_use , pre_pad = 0 , half = thread_data . model_is_half )
2022-09-09 17:35:24 +02:00
2022-10-17 03:41:39 +02:00
thread_data . model_real_esrgan . model . name = thread_data . real_esrgan_file
print ( ' loaded ' , thread_data . real_esrgan_file , ' to ' , thread_data . model_real_esrgan . device , ' precision ' , thread_data . precision )
2022-09-09 17:35:24 +02:00
2022-10-15 05:20:57 +02:00
def get_base_path ( disk_path , session_id , prompt , ext , suffix = None ) :
if disk_path is None : return None
if session_id is None : return None
if ext is None : raise Exception ( ' Missing ext ' )
session_out_path = os . path . join ( disk_path , session_id )
os . makedirs ( session_out_path , exist_ok = True )
prompt_flattened = filename_regex . sub ( ' _ ' , prompt ) [ : 50 ]
img_id = str ( uuid . uuid4 ( ) ) [ - 8 : ]
if suffix is not None :
return os . path . join ( session_out_path , f " { prompt_flattened } _ { img_id } _ { suffix } . { ext } " )
return os . path . join ( session_out_path , f " { prompt_flattened } _ { img_id } . { ext } " )
def apply_filters ( filter_name , image_data ) :
print ( f ' Applying filter { filter_name } ... ' )
2022-10-17 03:41:39 +02:00
if isinstance ( image_data , torch . Tensor ) :
print ( image_data )
image_data . to ( thread_data . device )
2022-10-15 05:20:57 +02:00
gc ( )
if filter_name == ' gfpgan ' :
2022-10-17 03:41:39 +02:00
if thread_data . model_gfpgan is None : raise Exception ( ' Model " gfpgan " not loaded. ' )
print ( ' enhance with ' , thread_data . gfpgan_file , ' on ' , thread_data . model_gfpgan . device , ' precision ' , thread_data . precision )
_ , _ , output = thread_data . model_gfpgan . enhance ( image_data [ : , : , : : - 1 ] , has_aligned = False , only_center_face = False , paste_back = True )
2022-10-15 05:20:57 +02:00
image_data = output [ : , : , : : - 1 ]
if filter_name == ' real_esrgan ' :
2022-10-17 03:41:39 +02:00
if thread_data . model_real_esrgan is None : raise Exception ( ' Model " gfpgan " not loaded. ' )
print ( ' enhance with ' , thread_data . real_esrgan_file , ' on ' , thread_data . model_real_esrgan . device , ' precision ' , thread_data . precision )
output , _ = thread_data . model_real_esrgan . enhance ( image_data [ : , : , : : - 1 ] )
2022-10-15 05:20:57 +02:00
image_data = output [ : , : , : : - 1 ]
return image_data
2022-09-02 10:28:36 +02:00
def mk_img ( req : Request ) :
2022-09-21 18:23:25 +02:00
try :
yield from do_mk_img ( req )
except Exception as e :
2022-09-22 14:34:11 +02:00
print ( traceback . format_exc ( ) )
2022-09-21 18:23:25 +02:00
gc ( )
2022-09-22 14:34:11 +02:00
2022-10-17 03:41:39 +02:00
if thread_data . device != " cpu " :
thread_data . modelFS . to ( " cpu " )
thread_data . modelCS . to ( " cpu " )
2022-09-22 18:44:25 +02:00
2022-10-17 03:41:39 +02:00
thread_data . model . model1 . to ( " cpu " )
thread_data . model . model2 . to ( " cpu " )
2022-09-22 18:44:25 +02:00
gc ( )
2022-09-22 14:34:11 +02:00
yield json . dumps ( {
" status " : ' failed ' ,
" detail " : str ( e )
} )
2022-09-21 18:23:25 +02:00
def do_mk_img ( req : Request ) :
2022-10-17 03:41:39 +02:00
thread_data . stop_processing = False
2022-09-02 10:28:36 +02:00
res = Response ( )
2022-09-14 06:45:35 +02:00
res . request = req
2022-09-02 10:28:36 +02:00
res . images = [ ]
2022-10-17 03:41:39 +02:00
thread_data . temp_images . clear ( )
2022-09-14 18:59:42 +02:00
2022-10-06 10:58:02 +02:00
# custom model support:
# the req.use_stable_diffusion_model needs to be a valid path
# to the ckpt file (without the extension).
2022-10-17 03:41:39 +02:00
if not os . path . exists ( req . use_stable_diffusion_model + ' .ckpt ' ) : raise FileNotFoundError ( f ' Cannot find { req . use_stable_diffusion_model } .ckpt ' )
2022-10-06 10:58:02 +02:00
needs_model_reload = False
2022-10-17 03:41:39 +02:00
if thread_data . ckpt_file != req . use_stable_diffusion_model :
thread_data . ckpt_file = req . use_stable_diffusion_model
2022-10-06 10:58:02 +02:00
needs_model_reload = True
2022-09-02 10:28:36 +02:00
if req . use_cpu :
2022-10-17 03:41:39 +02:00
if thread_data . device != ' cpu ' :
thread_data . device = ' cpu '
if thread_data . model_is_half :
load_model_ckpt ( )
2022-10-06 10:58:02 +02:00
needs_model_reload = False
2022-10-17 03:41:39 +02:00
load_model_gfpgan ( )
load_model_real_esrgan ( )
2022-09-02 10:28:36 +02:00
else :
2022-10-17 03:41:39 +02:00
if thread_data . has_valid_gpu :
if ( thread_data . precision == ' autocast ' and ( req . use_full_precision or not thread_data . model_is_half ) ) or \
( thread_data . precision == ' full ' and not req . use_full_precision and not thread_data . force_full_precision ) :
thread_data . precision = ' full ' if req . use_full_precision else ' autocast '
load_model_ckpt ( )
load_model_gfpgan ( )
load_model_real_esrgan ( )
2022-10-06 10:58:02 +02:00
needs_model_reload = False
2022-09-02 10:28:36 +02:00
2022-10-06 10:58:02 +02:00
if needs_model_reload :
2022-10-17 03:41:39 +02:00
load_model_ckpt ( )
2022-09-09 17:35:24 +02:00
2022-10-17 03:41:39 +02:00
if req . use_face_correction != thread_data . gfpgan_file :
thread_data . gfpgan_file = req . use_face_correction
load_model_gfpgan ( )
if req . use_upscale != thread_data . real_esrgan_file :
thread_data . real_esrgan_file = req . use_upscale
load_model_real_esrgan ( )
2022-09-02 10:28:36 +02:00
2022-10-17 03:41:39 +02:00
if thread_data . turbo != req . turbo :
thread_data . turbo = req . turbo
thread_data . model . turbo = req . turbo
2022-09-02 10:28:36 +02:00
opt_prompt = req . prompt
opt_seed = req . seed
opt_n_iter = 1
opt_C = 4
opt_f = 8
opt_ddim_eta = 0.0
opt_init_img = req . init_image
2022-10-17 03:41:39 +02:00
print ( req . to_string ( ) , ' \n device ' , thread_data . device )
print ( ' \n \n Using precision: ' , thread_data . precision )
2022-09-07 12:02:34 +02:00
2022-09-02 10:28:36 +02:00
seed_everything ( opt_seed )
2022-10-15 05:20:57 +02:00
batch_size = req . num_outputs
2022-09-02 10:28:36 +02:00
prompt = opt_prompt
assert prompt is not None
data = [ batch_size * [ prompt ] ]
2022-10-17 03:41:39 +02:00
if thread_data . precision == " autocast " and thread_data . device != " cpu " :
2022-09-02 10:28:36 +02:00
precision_scope = autocast
else :
precision_scope = nullcontext
2022-09-15 14:24:03 +02:00
mask = None
2022-09-02 10:28:36 +02:00
if req . init_image is None :
handler = _txt2img
init_latent = None
t_enc = None
else :
handler = _img2img
2022-10-15 05:20:57 +02:00
init_image = load_img ( req . init_image , req . width , req . height )
2022-10-17 03:41:39 +02:00
init_image = init_image . to ( thread_data . device )
2022-09-02 10:28:36 +02:00
2022-10-17 03:41:39 +02:00
if thread_data . device != " cpu " and thread_data . precision == " autocast " :
2022-09-02 10:28:36 +02:00
init_image = init_image . half ( )
2022-10-17 03:41:39 +02:00
thread_data . modelFS . to ( thread_data . device )
2022-09-02 10:28:36 +02:00
init_image = repeat ( init_image , ' 1 ... -> b ... ' , b = batch_size )
2022-10-17 03:41:39 +02:00
init_latent = thread_data . modelFS . get_first_stage_encoding ( thread_data . modelFS . encode_first_stage ( init_image ) ) # move to latent space
2022-09-02 10:28:36 +02:00
2022-09-15 14:24:03 +02:00
if req . mask is not None :
2022-10-17 03:41:39 +02:00
mask = load_mask ( req . mask , req . width , req . height , init_latent . shape [ 2 ] , init_latent . shape [ 3 ] , True ) . to ( thread_data . device )
2022-09-15 14:24:03 +02:00
mask = mask [ 0 ] [ 0 ] . unsqueeze ( 0 ) . repeat ( 4 , 1 , 1 ) . unsqueeze ( 0 )
mask = repeat ( mask , ' 1 ... -> b ... ' , b = batch_size )
2022-10-17 03:41:39 +02:00
if thread_data . device != " cpu " and thread_data . precision == " autocast " :
2022-09-15 14:24:03 +02:00
mask = mask . half ( )
2022-09-22 18:44:25 +02:00
move_fs_to_cpu ( )
2022-09-02 10:28:36 +02:00
2022-10-15 05:20:57 +02:00
assert 0. < = req . prompt_strength < = 1. , ' can only work with strength in [0.0, 1.0] '
t_enc = int ( req . prompt_strength * req . num_inference_steps )
2022-09-02 10:28:36 +02:00
print ( f " target t_enc is { t_enc } steps " )
2022-10-15 05:20:57 +02:00
if req . save_to_disk_path is not None :
session_out_path = os . path . join ( req . save_to_disk_path , req . session_id )
2022-09-03 20:42:48 +02:00
os . makedirs ( session_out_path , exist_ok = True )
else :
session_out_path = None
2022-09-02 10:28:36 +02:00
seeds = " "
with torch . no_grad ( ) :
for n in trange ( opt_n_iter , desc = " Sampling " ) :
for prompts in tqdm ( data , desc = " data " ) :
with precision_scope ( " cuda " ) :
2022-10-17 03:41:39 +02:00
thread_data . modelCS . to ( thread_data . device )
2022-09-02 10:28:36 +02:00
uc = None
2022-10-15 05:20:57 +02:00
if req . guidance_scale != 1.0 :
2022-10-17 03:41:39 +02:00
uc = thread_data . modelCS . get_learned_conditioning ( batch_size * [ req . negative_prompt ] )
2022-09-02 10:28:36 +02:00
if isinstance ( prompts , tuple ) :
prompts = list ( prompts )
subprompts , weights = split_weighted_subprompts ( prompts [ 0 ] )
if len ( subprompts ) > 1 :
c = torch . zeros_like ( uc )
totalWeight = sum ( weights )
# normalize each "sub prompt" and add it
for i in range ( len ( subprompts ) ) :
weight = weights [ i ]
# if not skip_normalize:
weight = weight / totalWeight
2022-10-17 03:41:39 +02:00
c = torch . add ( c , thread_data . modelCS . get_learned_conditioning ( subprompts [ i ] ) , alpha = weight )
2022-09-02 10:28:36 +02:00
else :
2022-10-17 03:41:39 +02:00
c = thread_data . modelCS . get_learned_conditioning ( prompts )
2022-09-02 10:28:36 +02:00
2022-10-17 03:41:39 +02:00
thread_data . modelFS . to ( thread_data . device )
2022-09-14 18:59:42 +02:00
2022-09-13 16:29:41 +02:00
partial_x_samples = None
def img_callback ( x_samples , i ) :
nonlocal partial_x_samples
partial_x_samples = x_samples
2022-09-14 13:22:03 +02:00
if req . stream_progress_updates :
2022-10-15 05:20:57 +02:00
n_steps = req . num_inference_steps if req . init_image is None else t_enc
2022-09-23 10:08:33 +02:00
progress = { " step " : i , " total_steps " : n_steps }
2022-09-14 18:59:42 +02:00
2022-09-22 18:44:25 +02:00
if req . stream_image_progress and i % 5 == 0 :
2022-09-14 18:59:42 +02:00
partial_images = [ ]
for i in range ( batch_size ) :
2022-10-17 03:41:39 +02:00
x_samples_ddim = thread_data . modelFS . decode_first_stage ( x_samples [ i ] . unsqueeze ( 0 ) )
2022-09-14 18:59:42 +02:00
x_sample = torch . clamp ( ( x_samples_ddim + 1.0 ) / 2.0 , min = 0.0 , max = 1.0 )
x_sample = 255.0 * rearrange ( x_sample [ 0 ] . cpu ( ) . numpy ( ) , " c h w -> h w c " )
x_sample = x_sample . astype ( np . uint8 )
img = Image . fromarray ( x_sample )
buf = BytesIO ( )
img . save ( buf , format = ' JPEG ' )
buf . seek ( 0 )
2022-09-21 18:23:25 +02:00
del img , x_sample , x_samples_ddim
# don't delete x_samples, it is used in the code that called this callback
2022-10-17 03:41:39 +02:00
thread_data . temp_images [ str ( req . session_id ) + ' / ' + str ( i ) ] = buf
2022-09-14 18:59:42 +02:00
partial_images . append ( { ' path ' : f ' /image/tmp/ { req . session_id } / { i } ' } )
progress [ ' output ' ] = partial_images
yield json . dumps ( progress )
2022-09-14 13:22:03 +02:00
2022-10-17 03:41:39 +02:00
if thread_data . stop_processing :
2022-09-13 16:29:41 +02:00
raise UserInitiatedStop ( " User requested that we stop processing " )
2022-09-02 10:28:36 +02:00
# run the handler
2022-09-13 16:29:41 +02:00
try :
2022-10-17 03:41:39 +02:00
print ( ' Running handler... ' )
2022-09-13 16:29:41 +02:00
if handler == _txt2img :
2022-10-15 05:20:57 +02:00
x_samples = _txt2img ( req . width , req . height , req . num_outputs , req . num_inference_steps , req . guidance_scale , None , opt_C , opt_f , opt_ddim_eta , c , uc , opt_seed , img_callback , mask , req . sampler )
2022-09-13 16:29:41 +02:00
else :
2022-10-15 05:20:57 +02:00
x_samples = _img2img ( init_latent , t_enc , batch_size , req . guidance_scale , c , uc , req . num_inference_steps , opt_ddim_eta , opt_seed , img_callback , mask )
2022-09-14 13:22:03 +02:00
2022-09-22 20:49:05 +02:00
yield from x_samples
2022-09-14 13:22:03 +02:00
x_samples = partial_x_samples
2022-09-13 16:29:41 +02:00
except UserInitiatedStop :
if partial_x_samples is None :
continue
x_samples = partial_x_samples
2022-09-02 10:28:36 +02:00
print ( " saving images " )
for i in range ( batch_size ) :
2022-10-17 03:41:39 +02:00
x_samples_ddim = thread_data . modelFS . decode_first_stage ( x_samples [ i ] . unsqueeze ( 0 ) )
2022-09-02 10:28:36 +02:00
x_sample = torch . clamp ( ( x_samples_ddim + 1.0 ) / 2.0 , min = 0.0 , max = 1.0 )
x_sample = 255.0 * rearrange ( x_sample [ 0 ] . cpu ( ) . numpy ( ) , " c h w -> h w c " )
2022-09-09 17:35:24 +02:00
x_sample = x_sample . astype ( np . uint8 )
img = Image . fromarray ( x_sample )
2022-09-08 17:50:27 +02:00
2022-10-15 05:20:57 +02:00
has_filters = ( req . use_face_correction is not None and req . use_face_correction . startswith ( ' GFPGAN ' ) ) or \
( req . use_upscale is not None and req . use_upscale . startswith ( ' RealESRGAN ' ) )
2022-09-21 14:11:42 +02:00
2022-10-15 05:20:57 +02:00
return_orig_img = not has_filters or not req . show_only_filtered_image
2022-09-21 15:59:27 +02:00
2022-10-17 03:41:39 +02:00
if thread_data . stop_processing :
2022-09-21 15:59:27 +02:00
return_orig_img = True
2022-10-15 05:20:57 +02:00
if req . save_to_disk_path is not None :
2022-09-21 15:59:27 +02:00
if return_orig_img :
2022-10-15 05:20:57 +02:00
img_out_path = get_base_path ( req . save_to_disk_path , req . session_id , prompts [ 0 ] , req . output_format )
2022-09-09 17:35:24 +02:00
save_image ( img , img_out_path )
2022-10-15 05:20:57 +02:00
meta_out_path = get_base_path ( req . save_to_disk_path , req . session_id , prompts [ 0 ] , ' txt ' )
save_metadata ( meta_out_path , req , prompts [ 0 ] , opt_seed )
2022-09-09 17:35:24 +02:00
2022-09-21 15:59:27 +02:00
if return_orig_img :
2022-10-15 05:20:57 +02:00
img_data = img_to_base64_str ( img , req . output_format )
2022-09-14 06:45:35 +02:00
res_image_orig = ResponseImage ( data = img_data , seed = opt_seed )
res . images . append ( res_image_orig )
2022-10-15 05:20:57 +02:00
if req . save_to_disk_path is not None :
2022-09-14 06:45:35 +02:00
res_image_orig . path_abs = img_out_path
2022-09-08 19:49:35 +02:00
2022-09-21 18:23:25 +02:00
del img
2022-10-17 03:41:39 +02:00
if has_filters and not thread_data . stop_processing :
2022-09-09 17:35:24 +02:00
filters_applied = [ ]
2022-10-15 05:20:57 +02:00
if req . use_face_correction :
x_sample = apply_filters ( ' gfpgan ' , x_sample )
filters_applied . append ( req . use_face_correction )
if req . use_upscale :
x_sample = apply_filters ( ' real_esrgan ' , x_sample )
filters_applied . append ( req . use_upscale )
if ( len ( filters_applied ) > 0 ) :
filtered_image = Image . fromarray ( x_sample )
filtered_img_data = img_to_base64_str ( filtered_image , req . output_format )
response_image = ResponseImage ( data = filtered_img_data , seed = req . seed )
res . images . append ( response_image )
if req . save_to_disk_path is not None :
filtered_img_out_path = get_base_path ( req . save_to_disk_path , req . session_id , prompts [ 0 ] , req . output_format , " _ " . join ( filters_applied ) )
save_image ( filtered_image , filtered_img_out_path )
response_image . path_abs = filtered_img_out_path
del filtered_image
2022-09-21 18:23:25 +02:00
2022-09-02 10:28:36 +02:00
seeds + = str ( opt_seed ) + " , "
opt_seed + = 1
2022-09-22 18:44:25 +02:00
move_fs_to_cpu ( )
2022-09-21 18:23:25 +02:00
gc ( )
del x_samples , x_samples_ddim , x_sample
2022-10-17 03:41:39 +02:00
print ( f ' memory_final = { round ( torch . cuda . memory_allocated ( thread_data . device ) / 1e6 , 2 ) } Mo ' )
2022-09-02 10:28:36 +02:00
2022-09-16 18:02:08 +02:00
print ( ' Task completed ' )
2022-09-15 14:24:03 +02:00
2022-09-22 20:49:05 +02:00
yield json . dumps ( res . json ( ) )
2022-09-02 10:28:36 +02:00
2022-09-09 17:35:24 +02:00
def save_image ( img , img_out_path ) :
try :
img . save ( img_out_path )
except :
print ( ' could not save the file ' , traceback . format_exc ( ) )
2022-10-15 05:20:57 +02:00
def save_metadata ( meta_out_path , req , prompt , opt_seed ) :
2022-10-17 03:41:39 +02:00
metadata = f ''' { prompt }
2022-10-15 05:20:57 +02:00
Width : { req . width }
Height : { req . height }
Seed : { opt_seed }
Steps : { req . num_inference_steps }
Guidance Scale : { req . guidance_scale }
Prompt Strength : { req . prompt_strength }
Use Face Correction : { req . use_face_correction }
Use Upscaling : { req . use_upscale }
Sampler : { req . sampler }
Negative Prompt : { req . negative_prompt }
2022-10-17 03:41:39 +02:00
Stable Diffusion model : { req . use_stable_diffusion_model + ' .ckpt ' }
'''
2022-09-09 17:35:24 +02:00
try :
with open ( meta_out_path , ' w ' ) as f :
f . write ( metadata )
except :
print ( ' could not save the file ' , traceback . format_exc ( ) )
2022-09-22 20:49:05 +02:00
def _txt2img ( opt_W , opt_H , opt_n_samples , opt_ddim_steps , opt_scale , start_code , opt_C , opt_f , opt_ddim_eta , c , uc , opt_seed , img_callback , mask , sampler_name ) :
2022-09-02 10:28:36 +02:00
shape = [ opt_n_samples , opt_C , opt_H / / opt_f , opt_W / / opt_f ]
2022-10-17 03:41:39 +02:00
if thread_data . device != " cpu " :
mem = torch . cuda . memory_allocated ( thread_data . device ) / 1e6
print ( ' Device: ' , thread_data . device , ' CS_Model, Memory transfer starting. Memory Used: ' , round ( mem , 2 ) , ' Mo ' )
thread_data . modelCS . to ( " cpu " )
while torch . cuda . memory_allocated ( thread_data . device ) / 1e6 > = mem and mem > 0 :
print ( ' Device: ' , thread_data . device , ' Waiting Memory transfer. Memory Used: ' , round ( mem , 2 ) , ' Mo ' )
2022-09-02 10:28:36 +02:00
time . sleep ( 1 )
2022-10-17 03:41:39 +02:00
print ( ' Transfered ' , round ( mem - torch . cuda . memory_allocated ( thread_data . device ) / 1e6 , 2 ) , ' Mo ' )
2022-09-02 10:28:36 +02:00
2022-09-23 17:52:44 +02:00
if sampler_name == ' ddim ' :
2022-10-17 03:41:39 +02:00
thread_data . model . make_schedule ( ddim_num_steps = opt_ddim_steps , ddim_eta = opt_ddim_eta , verbose = False )
2022-09-23 17:52:44 +02:00
2022-10-17 03:41:39 +02:00
samples_ddim = thread_data . model . sample (
2022-09-02 10:28:36 +02:00
S = opt_ddim_steps ,
conditioning = c ,
seed = opt_seed ,
shape = shape ,
verbose = False ,
unconditional_guidance_scale = opt_scale ,
unconditional_conditioning = uc ,
eta = opt_ddim_eta ,
x_T = start_code ,
2022-09-13 16:29:41 +02:00
img_callback = img_callback ,
2022-09-15 14:24:03 +02:00
mask = mask ,
2022-09-22 20:49:05 +02:00
sampler = sampler_name ,
2022-09-02 10:28:36 +02:00
)
2022-09-22 20:49:05 +02:00
yield from samples_ddim
2022-09-02 10:28:36 +02:00
2022-09-22 20:49:05 +02:00
def _img2img ( init_latent , t_enc , batch_size , opt_scale , c , uc , opt_ddim_steps , opt_ddim_eta , opt_seed , img_callback , mask ) :
2022-09-02 10:28:36 +02:00
# encode (scaled latent)
2022-10-17 03:41:39 +02:00
z_enc = thread_data . model . stochastic_encode (
2022-09-02 10:28:36 +02:00
init_latent ,
2022-10-17 03:41:39 +02:00
torch . tensor ( [ t_enc ] * batch_size ) . to ( thread_data . device ) ,
2022-09-02 10:28:36 +02:00
opt_seed ,
opt_ddim_eta ,
opt_ddim_steps ,
)
2022-09-15 14:24:03 +02:00
x_T = None if mask is None else init_latent
2022-09-02 10:28:36 +02:00
# decode it
2022-10-17 03:41:39 +02:00
samples_ddim = thread_data . model . sample (
2022-09-02 10:28:36 +02:00
t_enc ,
c ,
z_enc ,
unconditional_guidance_scale = opt_scale ,
unconditional_conditioning = uc ,
2022-09-13 16:29:41 +02:00
img_callback = img_callback ,
2022-09-15 14:24:03 +02:00
mask = mask ,
x_T = x_T ,
2022-09-02 10:28:36 +02:00
sampler = ' ddim '
)
2022-09-22 20:49:05 +02:00
yield from samples_ddim
2022-09-02 10:28:36 +02:00
2022-09-22 18:44:25 +02:00
def move_fs_to_cpu ( ) :
2022-10-17 03:41:39 +02:00
if thread_data . device != " cpu " :
mem = torch . cuda . memory_allocated ( thread_data . device ) / 1e6
print ( ' Device: ' , thread_data . device , ' FS_Model, Memory transfer starting. Memory Used: ' , round ( mem , 2 ) , ' Mo ' )
thread_data . modelFS . to ( " cpu " )
while torch . cuda . memory_allocated ( thread_data . device ) / 1e6 > = mem and mem > 0 :
print ( ' Device: ' , thread_data . device , ' Waiting for Memory transfer. Memory Used: ' , round ( mem , 2 ) , ' Mo ' )
2022-09-22 18:44:25 +02:00
time . sleep ( 1 )
2022-10-17 03:41:39 +02:00
print ( ' Transfered ' , round ( mem - torch . cuda . memory_allocated ( thread_data . device ) / 1e6 , 2 ) , ' Mo ' )
2022-09-22 18:44:25 +02:00
2022-09-09 17:35:24 +02:00
def gc ( ) :
2022-10-17 03:41:39 +02:00
#gc.collect()
if thread_data . device == ' cpu ' :
2022-09-09 17:35:24 +02:00
return
torch . cuda . empty_cache ( )
torch . cuda . ipc_collect ( )
2022-09-02 10:28:36 +02:00
# internal
def chunk ( it , size ) :
it = iter ( it )
return iter ( lambda : tuple ( islice ( it , size ) ) , ( ) )
def load_model_from_config ( ckpt , verbose = False ) :
print ( f " Loading model from { ckpt } " )
pl_sd = torch . load ( ckpt , map_location = " cpu " )
if " global_step " in pl_sd :
print ( f " Global Step: { pl_sd [ ' global_step ' ] } " )
sd = pl_sd [ " state_dict " ]
return sd
# utils
2022-09-13 16:29:41 +02:00
class UserInitiatedStop ( Exception ) :
pass
2022-09-02 10:28:36 +02:00
2022-09-14 08:06:55 +02:00
def load_img ( img_str , w0 , h0 ) :
2022-09-02 10:28:36 +02:00
image = base64_str_to_img ( img_str ) . convert ( " RGB " )
w , h = image . size
print ( f " loaded input image of size ( { w } , { h } ) from base64 " )
2022-09-14 08:06:55 +02:00
if h0 is not None and w0 is not None :
h , w = h0 , w0
2022-09-02 10:28:36 +02:00
w , h = map ( lambda x : x - x % 64 , ( w , h ) ) # resize to integer multiple of 64
2022-09-14 08:06:55 +02:00
image = image . resize ( ( w , h ) , resample = Image . Resampling . LANCZOS )
2022-09-02 10:28:36 +02:00
image = np . array ( image ) . astype ( np . float32 ) / 255.0
image = image [ None ] . transpose ( 0 , 3 , 1 , 2 )
image = torch . from_numpy ( image )
return 2. * image - 1.
2022-09-15 14:24:03 +02:00
def load_mask ( mask_str , h0 , w0 , newH , newW , invert = False ) :
image = base64_str_to_img ( mask_str ) . convert ( " RGB " )
w , h = image . size
print ( f " loaded input mask of size ( { w } , { h } ) " )
if invert :
print ( " inverted " )
image = ImageOps . invert ( image )
# where_0, where_1 = np.where(image == 0), np.where(image == 255)
# image[where_0], image[where_1] = 255, 0
if h0 is not None and w0 is not None :
h , w = h0 , w0
w , h = map ( lambda x : x - x % 64 , ( w , h ) ) # resize to integer multiple of 64
print ( f " New mask size ( { w } , { h } ) " )
image = image . resize ( ( newW , newH ) , resample = Image . Resampling . LANCZOS )
image = np . array ( image )
image = image . astype ( np . float32 ) / 255.0
image = image [ None ] . transpose ( 0 , 3 , 1 , 2 )
image = torch . from_numpy ( image )
return image
2022-09-02 10:28:36 +02:00
# https://stackoverflow.com/a/61114178
2022-10-06 11:35:34 +02:00
def img_to_base64_str ( img , output_format = " PNG " ) :
2022-09-02 10:28:36 +02:00
buffered = BytesIO ( )
2022-10-06 11:35:34 +02:00
img . save ( buffered , format = output_format )
2022-09-02 10:28:36 +02:00
buffered . seek ( 0 )
img_byte = buffered . getvalue ( )
img_str = " data:image/png;base64, " + base64 . b64encode ( img_byte ) . decode ( )
return img_str
def base64_str_to_img ( img_str ) :
img_str = img_str [ len ( " data:image/png;base64, " ) : ]
data = base64 . b64decode ( img_str )
buffered = BytesIO ( data )
img = Image . open ( buffered )
return img