2022-09-14 13:22:03 +02:00
import json
2022-09-02 10:28:36 +02:00
import os , re
2022-09-05 13:21:43 +02:00
import traceback
2022-09-02 10:28:36 +02:00
import torch
import numpy as np
from omegaconf import OmegaConf
2022-09-15 14:24:03 +02:00
from PIL import Image , ImageOps
2022-09-02 10:28:36 +02:00
from tqdm import tqdm , trange
from itertools import islice
from einops import rearrange
import time
from pytorch_lightning import seed_everything
from torch import autocast
from contextlib import nullcontext
from einops import rearrange , repeat
from ldm . util import instantiate_from_config
from optimizedSD . optimUtils import split_weighted_subprompts
from transformers import logging
2022-09-09 17:35:24 +02:00
from gfpgan import GFPGANer
from basicsr . archs . rrdbnet_arch import RRDBNet
from realesrgan import RealESRGANer
2022-09-03 20:42:48 +02:00
import uuid
2022-09-02 10:28:36 +02:00
# Restrict the transformers library's log output to error-level messages.
logging.set_verbosity_error()

# consts
# Model configuration file read with OmegaConf.load() when a checkpoint is loaded.
config_yaml = "optimizedSD/v1-inference.yaml"
# Matches any character that is not an ASCII letter or digit; used to flatten a
# prompt into a string that is safe to use as part of a file name.
filename_regex = re.compile(r'[^a-zA-Z0-9]')
2022-09-02 10:28:36 +02:00
# api stuff
from . import Request , Response , Image as ResponseImage
import base64
from io import BytesIO
2022-09-16 18:02:08 +02:00
#from colorama import Fore
2022-09-02 10:28:36 +02:00
# local
2022-09-13 16:29:41 +02:00
# Cancellation flag: reset to False when a request starts and polled after
# every sampling step. Presumably set to True from outside this module.
stop_processing = False
# In-memory preview images (BytesIO), keyed by "<session_id>/<image index>".
temp_images = dict()

# Names (without file extension) of the weight files behind the loaded models.
ckpt_file = gfpgan_file = real_esrgan_file = None

# Stable Diffusion components: UNet, conditioning stage and first stage.
model = modelCS = modelFS = None

# Optional post-processing models (face correction / upscaling).
model_gfpgan = model_real_esrgan = None

# Whether the UNet + conditioning stage, and the first stage, were converted
# to half precision by the last load_model_ckpt() call.
model_is_half = model_fs_is_half = False

device = None            # 'cpu' or a CUDA device string, set by load_model_ckpt()
unet_bs = 1              # value copied onto model.unet_bs
precision = 'autocast'   # 'autocast' or 'full'

sampler_plms = sampler_ddim = None

# Both are updated by the GPU probe that follows.
has_valid_gpu = False
force_full_precision = False
# Probe the current CUDA device once, at import time.
#   * has_valid_gpu is set to True only when a device is found AND it reports
#     at least 3 GB of total memory.
#   * force_full_precision is set when the device name identifies an NVIDIA
#     GeForce 1650/1660 card.
# Any failure along the way leaves has_valid_gpu at its previous value and the
# module falls back to the CPU.
try:
    gpu = torch.cuda.current_device()
    gpu_name = torch.cuda.get_device_name(gpu)
    print('GPU detected: ', gpu_name)

    # otherwise these NVIDIA cards create green images
    force_full_precision = ('nvidia' in gpu_name.lower() or 'geforce' in gpu_name.lower()) and ('1660' in gpu_name or '1650' in gpu_name)
    if force_full_precision:
        print('forcing full precision on NVIDIA 16xx cards, to avoid green images. GPU detected: ', gpu_name)

    mem_free, mem_total = torch.cuda.mem_get_info(gpu)
    mem_total /= float(10**9)  # bytes -> GB (decimal)
    if mem_total < 3.0:
        print("GPUs with less than 3 GB of VRAM are not compatible with Stable Diffusion")
        raise RuntimeError(f"insufficient GPU memory: {mem_total:.2f} GB")

    has_valid_gpu = True
except Exception:
    # `Exception` rather than a bare `except`, so Ctrl+C / SystemExit raised
    # during import are not silently converted into "no GPU".
    print('WARNING: No compatible GPU found. Using the CPU, but this will be very slow!')
2022-09-09 17:35:24 +02:00
def load_model_ckpt(ckpt_to_use, device_to_use='cuda', turbo=False, unet_bs_to_use=1, precision_to_use='autocast', half_model_fs=False):
    """Load a Stable Diffusion checkpoint into the module-level model globals.

    ckpt_to_use      -- checkpoint path without the ".ckpt" extension
    device_to_use    -- requested device; ignored (CPU is used) when no valid GPU was detected
    turbo            -- stored on the UNet as ``model.turbo``
    unet_bs_to_use   -- stored on the UNet as ``model.unet_bs``
    precision_to_use -- 'autocast' or 'full'; forced to 'full' on the CPU and
                        when ``force_full_precision`` is set
    half_model_fs    -- convert the first-stage model to half precision

    Updates ckpt_file, device, precision, unet_bs, model, modelCS, modelFS,
    model_is_half and model_fs_is_half.
    """
    global ckpt_file, model, modelCS, modelFS, model_is_half, device, unet_bs, precision, model_fs_is_half

    ckpt_file = ckpt_to_use
    device = device_to_use if has_valid_gpu else 'cpu'
    precision = 'full' if force_full_precision else precision_to_use
    unet_bs = unet_bs_to_use

    if device == 'cpu':
        precision = 'full'

    sd = load_model_from_config(f"{ckpt_file}.ckpt")

    # The UNet weights live under "model.*" in the checkpoint, but the config
    # instantiates them as two halves: keys touching the input blocks, middle
    # block or time embedding go to "model1.*", all other "model.*" keys go to
    # "model2.*".
    first_half_markers = ("input_blocks", "middle_block", "time_embed")
    first_half_keys, second_half_keys = [], []
    for name in sd:
        parts = name.split(".")
        if parts[0] != "model":
            continue
        if any(marker in parts for marker in first_half_markers):
            first_half_keys.append(name)
        else:
            second_half_keys.append(name)

    # name[6:] drops the leading "model." prefix.
    for name in first_half_keys:
        sd["model1." + name[6:]] = sd.pop(name)
    for name in second_half_keys:
        sd["model2." + name[6:]] = sd.pop(name)

    config = OmegaConf.load(f"{config_yaml}")

    # UNet
    model = instantiate_from_config(config.modelUNet)
    model.load_state_dict(sd, strict=False)
    model.eval()
    model.cdevice = device
    model.unet_bs = unet_bs
    model.turbo = turbo

    # Conditioning stage
    modelCS = instantiate_from_config(config.modelCondStage)
    modelCS.load_state_dict(sd, strict=False)
    modelCS.eval()
    modelCS.cond_stage_model.device = device

    # First stage
    modelFS = instantiate_from_config(config.modelFirstStage)
    modelFS.load_state_dict(sd, strict=False)
    modelFS.eval()

    del sd

    # Half precision is only used on a GPU in autocast mode.
    use_half = device != "cpu" and precision == "autocast"
    if use_half:
        model.half()
        modelCS.half()
    model_is_half = use_half

    if half_model_fs:
        modelFS.half()
    model_fs_is_half = bool(half_model_fs)

    print('loaded', ckpt_file, 'to', device, 'precision', precision)
def load_model_gfpgan(gfpgan_to_use):
    """Load the GFPGAN face-correction model into ``model_gfpgan``.

    gfpgan_to_use -- weight file path without the ".pth" extension, or None
                     to do nothing (the currently loaded model is kept).

    The model is placed on the CPU when the module-level ``device`` is 'cpu',
    otherwise on 'cuda'.
    """
    global gfpgan_file, model_gfpgan

    if gfpgan_to_use is None:
        return

    gfpgan_file = gfpgan_to_use
    model_path = gfpgan_to_use + ".pth"

    target_device = torch.device('cpu' if device == 'cpu' else 'cuda')
    model_gfpgan = GFPGANer(
        model_path=model_path,
        upscale=1,
        arch='clean',
        channel_multiplier=2,
        bg_upsampler=None,
        device=target_device,
    )

    print('loaded', gfpgan_to_use, 'to', device, 'precision', precision)
def load_model_real_esrgan(real_esrgan_to_use):
    """Load a RealESRGAN upscaler into ``model_real_esrgan``.

    real_esrgan_to_use -- model name / weight file path without the ".pth"
                          extension ('RealESRGAN_x4plus' or
                          'RealESRGAN_x4plus_anime_6B'), or None to do nothing.

    Raises KeyError for a name that is not one of the supported models.
    """
    global real_esrgan_file, model_real_esrgan

    if real_esrgan_to_use is None:
        return

    real_esrgan_file = real_esrgan_to_use
    model_path = real_esrgan_to_use + ".pth"

    # The supported variants differ only in their number of RRDB blocks.
    # Only the requested network is instantiated, instead of building every
    # architecture up front and discarding all but one.
    num_block_by_model = {
        'RealESRGAN_x4plus': 23,
        'RealESRGAN_x4plus_anime_6B': 6,
    }
    num_block = num_block_by_model[real_esrgan_to_use]
    model_to_use = RRDBNet(num_in_ch=3, num_out_ch=3, num_feat=64, num_block=num_block, num_grow_ch=32, scale=4)

    if device == 'cpu':
        model_real_esrgan = RealESRGANer(scale=2, model_path=model_path, model=model_to_use, pre_pad=0, half=False)  # cpu does not support half
        model_real_esrgan.device = torch.device('cpu')
        model_real_esrgan.model.to('cpu')
    else:
        model_real_esrgan = RealESRGANer(scale=2, model_path=model_path, model=model_to_use, pre_pad=0, half=model_is_half)

    model_real_esrgan.model.name = real_esrgan_to_use

    print('loaded', real_esrgan_to_use, 'to', device, 'precision', precision)
2022-09-02 10:28:36 +02:00
def mk_img(req: Request):
    """Run an image request, converting any failure into a JSON error message.

    Generator: yields everything ``do_mk_img(req)`` yields. If that raises, the
    traceback is printed, GPU memory is released on a best-effort basis, and a
    final ``{"status": "failed", "detail": <message>}`` JSON string is yielded.
    """
    try:
        yield from do_mk_img(req)
    except Exception as e:
        print(traceback.format_exc())

        # Cleanup is best-effort. do_mk_img() deletes the model globals before
        # reloading them, so after a failed reload they may not exist; an error
        # here must not prevent the failure message below from being sent.
        try:
            gc()

            if device != "cpu":
                modelFS.to("cpu")
                modelCS.to("cpu")

                model.model1.to("cpu")
                model.model2.to("cpu")

                gc()
        except Exception:
            print(traceback.format_exc())

        yield json.dumps({
            "status": 'failed',
            "detail": str(e)
        })
2022-09-21 18:23:25 +02:00
def do_mk_img(req: Request):
    """Generate the images described by ``req``.

    Generator. While sampling it yields whatever the sampler yields (JSON
    progress strings produced by the step callback when
    ``req.stream_progress_updates`` is set); at the end it yields one JSON
    string with the final response (``json.dumps(res.json())``).

    Steps:
      1. Reload the models if the requested device / precision differs from
         what is currently loaded, and load the requested filter models.
      2. For image-to-image requests, encode the init image (and mask).
      3. Encode the prompt (supports weighted sub-prompts), run the sampler.
      4. Decode each sample, optionally save it to disk, optionally apply
         face correction / upscaling, and collect the results.

    If ``stop_processing`` becomes true during sampling, the latest partial
    samples are decoded and returned without filters.
    """
    global model, modelCS, modelFS, device
    global model_gfpgan, model_real_esrgan
    global stop_processing

    stop_processing = False

    res = Response()
    res.request = req
    res.images = []

    temp_images.clear()

    model.turbo = req.turbo
    if req.use_cpu:
        if device != 'cpu':
            device = 'cpu'

            if model_is_half:
                # Half precision weights cannot be used on the CPU: reload.
                del model, modelCS, modelFS
                load_model_ckpt(ckpt_file, device)
                load_model_gfpgan(gfpgan_file)
                load_model_real_esrgan(real_esrgan_file)
    else:
        if has_valid_gpu:
            prev_device = device
            device = 'cuda'

            # Reload when the loaded precision does not match the request:
            # the first stage is kept in half precision only for img2img.
            if (precision == 'autocast' and (req.use_full_precision or not model_is_half)) or \
                    (precision == 'full' and not req.use_full_precision and not force_full_precision) or \
                    (req.init_image is None and model_fs_is_half) or \
                    (req.init_image is not None and not model_fs_is_half and not force_full_precision):

                del model, modelCS, modelFS
                load_model_ckpt(ckpt_file, device, req.turbo, unet_bs, ('full' if req.use_full_precision else 'autocast'), half_model_fs=(req.init_image is not None and not req.use_full_precision))

            if prev_device != device:
                load_model_gfpgan(gfpgan_file)
                load_model_real_esrgan(real_esrgan_file)

    if req.use_face_correction != gfpgan_file:
        load_model_gfpgan(req.use_face_correction)

    if req.use_upscale != real_esrgan_file:
        load_model_real_esrgan(req.use_upscale)

    model.cdevice = device
    modelCS.cond_stage_model.device = device

    opt_prompt = req.prompt
    opt_seed = req.seed
    opt_n_samples = req.num_outputs
    opt_n_iter = 1
    opt_scale = req.guidance_scale
    opt_C = 4
    opt_H = req.height
    opt_W = req.width
    opt_f = 8
    opt_ddim_steps = req.num_inference_steps
    opt_ddim_eta = 0.0
    opt_strength = req.prompt_strength
    opt_save_to_disk_path = req.save_to_disk_path
    opt_init_img = req.init_image
    opt_use_face_correction = req.use_face_correction
    opt_use_upscale = req.use_upscale
    opt_show_only_filtered = req.show_only_filtered_image
    opt_format = 'png'
    opt_sampler_name = req.sampler

    print(req.to_string(), '\n    device', device)

    print('\n\n    Using precision:', precision)

    seed_everything(opt_seed)

    batch_size = opt_n_samples
    prompt = opt_prompt
    assert prompt is not None
    data = [batch_size * [prompt]]

    if precision == "autocast" and device != "cpu":
        precision_scope = autocast
    else:
        precision_scope = nullcontext

    mask = None

    if req.init_image is None:
        handler = _txt2img

        init_latent = None
        t_enc = None
    else:
        handler = _img2img

        init_image = load_img(req.init_image, opt_W, opt_H)
        init_image = init_image.to(device)

        if device != "cpu" and precision == "autocast":
            init_image = init_image.half()

        modelFS.to(device)

        init_image = repeat(init_image, '1 ... -> b ...', b=batch_size)
        init_latent = modelFS.get_first_stage_encoding(modelFS.encode_first_stage(init_image))  # move to latent space

        if req.mask is not None:
            # load_mask(mask_str, h0, w0, newH, newW, invert): height comes
            # before width. h0/w0 only feed a log message; the mask itself is
            # resized to the latent's spatial size.
            mask = load_mask(req.mask, opt_H, opt_W, init_latent.shape[2], init_latent.shape[3], True).to(device)
            # Keep one channel of the mask and replicate it 4 times.
            mask = mask[0][0].unsqueeze(0).repeat(4, 1, 1).unsqueeze(0)
            mask = repeat(mask, '1 ... -> b ...', b=batch_size)

            if device != "cpu" and precision == "autocast":
                mask = mask.half()

        move_fs_to_cpu()

        assert 0. <= opt_strength <= 1., 'can only work with strength in [0.0, 1.0]'
        t_enc = int(opt_strength * opt_ddim_steps)
        print(f"target t_enc is {t_enc} steps")

    if opt_save_to_disk_path is not None:
        session_out_path = os.path.join(opt_save_to_disk_path, req.session_id)
        os.makedirs(session_out_path, exist_ok=True)
    else:
        session_out_path = None

    seeds = ""
    with torch.no_grad():
        for n in trange(opt_n_iter, desc="Sampling"):
            for prompts in tqdm(data, desc="data"):

                with precision_scope("cuda"):
                    modelCS.to(device)
                    uc = None
                    if opt_scale != 1.0:
                        uc = modelCS.get_learned_conditioning(batch_size * [""])
                    if isinstance(prompts, tuple):
                        prompts = list(prompts)

                    subprompts, weights = split_weighted_subprompts(prompts[0])
                    if len(subprompts) > 1:
                        # uc is None when the guidance scale is 1.0, so it
                        # cannot always serve as the template for the
                        # accumulator; fall back to an empty-prompt encoding
                        # of the same batch size in that case.
                        if uc is not None:
                            template = uc
                        else:
                            template = modelCS.get_learned_conditioning(batch_size * [""])
                        c = torch.zeros_like(template)
                        totalWeight = sum(weights)
                        # normalize each "sub prompt" and add it
                        for subprompt, weight in zip(subprompts, weights):
                            weight = weight / totalWeight
                            c = torch.add(c, modelCS.get_learned_conditioning(subprompt), alpha=weight)
                    else:
                        c = modelCS.get_learned_conditioning(prompts)

                    modelFS.to(device)

                    partial_x_samples = None

                    def img_callback(x_samples, i):
                        """Per-step callback (generator): records the latest samples, streams progress, honours stop requests."""
                        nonlocal partial_x_samples

                        partial_x_samples = x_samples

                        if req.stream_progress_updates:
                            n_steps = opt_ddim_steps if req.init_image is None else t_enc
                            progress = {
                                "status": "progress",
                                "progress": {
                                    "step": i, "total_steps": n_steps
                                }
                            }

                            # Decode preview images every 5th step only.
                            if req.stream_image_progress and i % 5 == 0:
                                partial_images = []

                                # Separate name for the image index so the
                                # step index `i` is not overwritten.
                                for img_idx in range(batch_size):
                                    x_samples_ddim = modelFS.decode_first_stage(x_samples[img_idx].unsqueeze(0))
                                    x_sample = torch.clamp((x_samples_ddim + 1.0) / 2.0, min=0.0, max=1.0)
                                    x_sample = 255.0 * rearrange(x_sample[0].cpu().numpy(), "c h w -> h w c")
                                    x_sample = x_sample.astype(np.uint8)
                                    img = Image.fromarray(x_sample)
                                    buf = BytesIO()
                                    img.save(buf, format='JPEG')
                                    buf.seek(0)

                                    del img, x_sample, x_samples_ddim
                                    # don't delete x_samples, it is used in the code that called this callback

                                    temp_images[str(req.session_id) + '/' + str(img_idx)] = buf
                                    partial_images.append({'path': f'/image/tmp/{req.session_id}/{img_idx}'})

                                progress['output'] = partial_images

                            yield json.dumps(progress)

                        if stop_processing:
                            raise UserInitiatedStop("User requested that we stop processing")

                    # run the handler
                    try:
                        if handler == _txt2img:
                            x_samples = _txt2img(opt_W, opt_H, opt_n_samples, opt_ddim_steps, opt_scale, None, opt_C, opt_f, opt_ddim_eta, c, uc, opt_seed, img_callback, mask, opt_sampler_name)
                        else:
                            x_samples = _img2img(init_latent, t_enc, batch_size, opt_scale, c, uc, opt_ddim_steps, opt_ddim_eta, opt_seed, img_callback, mask)

                        # The handlers are generators: forward their progress
                        # messages, then pick up the samples recorded by the
                        # callback on the last step.
                        yield from x_samples

                        x_samples = partial_x_samples
                    except UserInitiatedStop:
                        if partial_x_samples is None:
                            continue

                        x_samples = partial_x_samples

                    print("saving images")
                    for i in range(batch_size):

                        x_samples_ddim = modelFS.decode_first_stage(x_samples[i].unsqueeze(0))
                        x_sample = torch.clamp((x_samples_ddim + 1.0) / 2.0, min=0.0, max=1.0)
                        x_sample = 255.0 * rearrange(x_sample[0].cpu().numpy(), "c h w -> h w c")
                        x_sample = x_sample.astype(np.uint8)
                        img = Image.fromarray(x_sample)

                        has_filters = (opt_use_face_correction is not None and opt_use_face_correction.startswith('GFPGAN')) or \
                                      (opt_use_upscale is not None and opt_use_upscale.startswith('RealESRGAN'))

                        return_orig_img = not has_filters or not opt_show_only_filtered

                        # A stopped task skips the filters, so the unfiltered
                        # image is the only thing there is to return.
                        if stop_processing:
                            return_orig_img = True

                        if opt_save_to_disk_path is not None:
                            prompt_flattened = filename_regex.sub('_', prompts[0])
                            prompt_flattened = prompt_flattened[:50]

                            img_id = str(uuid.uuid4())[-8:]

                            file_path = f"{prompt_flattened}_{img_id}"
                            img_out_path = os.path.join(session_out_path, f"{file_path}.{opt_format}")
                            meta_out_path = os.path.join(session_out_path, f"{file_path}.txt")

                            if return_orig_img:
                                save_image(img, img_out_path)

                            save_metadata(meta_out_path, prompts, opt_seed, opt_W, opt_H, opt_ddim_steps, opt_scale, opt_strength, opt_use_face_correction, opt_use_upscale, opt_sampler_name)

                        if return_orig_img:
                            img_data = img_to_base64_str(img)
                            res_image_orig = ResponseImage(data=img_data, seed=opt_seed)
                            res.images.append(res_image_orig)

                            if opt_save_to_disk_path is not None:
                                res_image_orig.path_abs = img_out_path

                        del img

                        if has_filters and not stop_processing:
                            print('Applying filters..')

                            gc()
                            filters_applied = []

                            # [:, :, ::-1] reverses the channel axis before and
                            # after each filter call.
                            if opt_use_face_correction:
                                _, _, output = model_gfpgan.enhance(x_sample[:, :, ::-1], has_aligned=False, only_center_face=False, paste_back=True)
                                x_sample = output[:, :, ::-1]
                                filters_applied.append(opt_use_face_correction)

                            if opt_use_upscale:
                                output, _ = model_real_esrgan.enhance(x_sample[:, :, ::-1])
                                x_sample = output[:, :, ::-1]
                                filters_applied.append(opt_use_upscale)

                            filtered_image = Image.fromarray(x_sample)

                            filtered_img_data = img_to_base64_str(filtered_image)
                            res_image_filtered = ResponseImage(data=filtered_img_data, seed=opt_seed)
                            res.images.append(res_image_filtered)

                            filters_applied = "_".join(filters_applied)

                            if opt_save_to_disk_path is not None:
                                filtered_img_out_path = os.path.join(session_out_path, f"{file_path}_{filters_applied}.{opt_format}")
                                save_image(filtered_image, filtered_img_out_path)
                                res_image_filtered.path_abs = filtered_img_out_path

                            del filtered_image

                        seeds += str(opt_seed) + ","
                        opt_seed += 1

                    move_fs_to_cpu()
                    gc()
                    del x_samples, x_samples_ddim, x_sample
                    print("memory_final = ", torch.cuda.memory_allocated() / 1e6)

    print('Task completed')

    yield json.dumps(res.json())
2022-09-02 10:28:36 +02:00
2022-09-09 17:35:24 +02:00
def save_image(img, img_out_path):
    """Save ``img`` to ``img_out_path`` (best effort).

    A failure to save is reported on stdout with its traceback and otherwise
    ignored, so that one unwritable file does not abort the whole task.
    """
    try:
        img.save(img_out_path)
    except Exception:
        # `Exception`, not a bare `except`: KeyboardInterrupt / SystemExit
        # must still propagate.
        print('could not save the file', traceback.format_exc())
2022-09-22 20:49:05 +02:00
def save_metadata(meta_out_path, prompts, opt_seed, opt_W, opt_H, opt_ddim_steps, opt_scale, opt_prompt_strength, opt_correct_face, opt_upscale, sampler_name):
    """Write the generation settings for one image to a text file (best effort).

    The first line is ``prompts[0]``; each following line is a "Label: value"
    pair. A failure to write is reported on stdout with its traceback and
    otherwise ignored.
    """
    metadata = f"{prompts[0]}\nWidth: {opt_W}\nHeight: {opt_H}\nSeed: {opt_seed}\nSteps: {opt_ddim_steps}\nGuidance Scale: {opt_scale}\nPrompt Strength: {opt_prompt_strength}\nUse Face Correction: {opt_correct_face}\nUse Upscaling: {opt_upscale}\nSampler: {sampler_name}"

    try:
        # Explicit UTF-8: prompts are free text, and the platform default
        # encoding cannot represent every character on all systems.
        with open(meta_out_path, 'w', encoding='utf-8') as f:
            f.write(metadata)
    except Exception:
        print('could not save the file', traceback.format_exc())
2022-09-22 20:49:05 +02:00
def _txt2img(opt_W, opt_H, opt_n_samples, opt_ddim_steps, opt_scale, start_code, opt_C, opt_f, opt_ddim_eta, c, uc, opt_seed, img_callback, mask, sampler_name):
    """Text-to-image sampling (generator).

    Moves the conditioning model off the GPU, then delegates to
    ``model.sample`` and re-yields whatever it yields. The latent shape is
    ``[opt_n_samples, opt_C, opt_H // opt_f, opt_W // opt_f]``.
    """
    latent_shape = [opt_n_samples, opt_C, opt_H // opt_f, opt_W // opt_f]

    if device != "cpu":
        # Move the conditioning model to the CPU and wait until the reported
        # CUDA allocation has dropped below its previous level.
        allocated_before = torch.cuda.memory_allocated() / 1e6
        modelCS.to("cpu")
        while torch.cuda.memory_allocated() / 1e6 >= allocated_before:
            time.sleep(1)

    if sampler_name == 'ddim':
        model.make_schedule(ddim_num_steps=opt_ddim_steps, ddim_eta=opt_ddim_eta, verbose=False)

    sampling_kwargs = dict(
        S=opt_ddim_steps,
        conditioning=c,
        seed=opt_seed,
        shape=latent_shape,
        verbose=False,
        unconditional_guidance_scale=opt_scale,
        unconditional_conditioning=uc,
        eta=opt_ddim_eta,
        x_T=start_code,
        img_callback=img_callback,
        mask=mask,
        sampler=sampler_name,
    )

    yield from model.sample(**sampling_kwargs)
2022-09-02 10:28:36 +02:00
2022-09-22 20:49:05 +02:00
def _img2img(init_latent, t_enc, batch_size, opt_scale, c, uc, opt_ddim_steps, opt_ddim_eta, opt_seed, img_callback, mask):
    """Image-to-image sampling (generator).

    Encodes ``init_latent`` with ``model.stochastic_encode`` for ``t_enc``
    steps, then runs ``model.sample`` with the 'ddim' sampler and re-yields
    whatever it yields. When a mask is given, the init latent is also passed
    to the sampler as ``x_T``.
    """
    # encode (scaled latent)
    step_per_sample = torch.tensor([t_enc] * batch_size).to(device)
    z_enc = model.stochastic_encode(
        init_latent,
        step_per_sample,
        opt_seed,
        opt_ddim_eta,
        opt_ddim_steps,
    )

    if mask is None:
        x_T = None
    else:
        x_T = init_latent

    # decode it
    yield from model.sample(
        t_enc,
        c,
        z_enc,
        unconditional_guidance_scale=opt_scale,
        unconditional_conditioning=uc,
        img_callback=img_callback,
        mask=mask,
        x_T=x_T,
        sampler='ddim',
    )
2022-09-02 10:28:36 +02:00
2022-09-22 18:44:25 +02:00
def move_fs_to_cpu():
    """Move the first-stage model to the CPU; a no-op when already running on the CPU.

    After the move, polls once per second until the reported CUDA allocation
    is below the level measured before the move.
    """
    if device == "cpu":
        return

    allocated_before = torch.cuda.memory_allocated() / 1e6
    modelFS.to("cpu")
    while torch.cuda.memory_allocated() / 1e6 >= allocated_before:
        time.sleep(1)
2022-09-09 17:35:24 +02:00
def gc():
    """Release cached CUDA memory; does nothing when running on the CPU."""
    if device != 'cpu':
        torch.cuda.empty_cache()
        torch.cuda.ipc_collect()
2022-09-02 10:28:36 +02:00
# internal
def chunk(it, size):
    """Return an iterator over consecutive tuples of up to ``size`` items from ``it``.

    The last tuple may be shorter; iteration ends at the first empty tuple.
    """
    source = iter(it)

    def _batches():
        while True:
            batch = tuple(islice(source, size))
            if batch == ():
                return
            yield batch

    return _batches()
def load_model_from_config(ckpt, verbose=False):
    """Read a checkpoint file and return its ``state_dict`` entry.

    ckpt    -- path of the checkpoint file
    verbose -- accepted but not used

    The checkpoint is always loaded onto the CPU.
    """
    print(f"Loading model from {ckpt}")
    # NOTE: torch.load unpickles the file; only load checkpoints from a
    # trusted source.
    checkpoint = torch.load(ckpt, map_location="cpu")
    if "global_step" in checkpoint:
        print(f"Global Step: {checkpoint['global_step']}")
    return checkpoint["state_dict"]
# utils
2022-09-13 16:29:41 +02:00
class UserInitiatedStop(Exception):
    """Raised from the sampling callback when ``stop_processing`` is set."""
2022-09-02 10:28:36 +02:00
2022-09-14 08:06:55 +02:00
def load_img(img_str, w0, h0):
    """Decode a base64 image string into a float tensor scaled to [-1, 1].

    img_str -- base64 image string, as accepted by ``base64_str_to_img``
    w0, h0  -- target width / height; when either is None the image's own
               size is used

    The size is rounded down to a multiple of 64 before resizing. The result
    has a leading batch axis of 1 and the channel axis moved in front of the
    two spatial axes.
    """
    picture = base64_str_to_img(img_str).convert("RGB")
    width, height = picture.size
    print(f"loaded input image of size ({width}, {height}) from base64")

    if h0 is not None and w0 is not None:
        width, height = w0, h0

    # resize to integer multiple of 64
    width -= width % 64
    height -= height % 64
    picture = picture.resize((width, height), resample=Image.Resampling.LANCZOS)

    pixels = np.array(picture).astype(np.float32) / 255.0
    pixels = pixels[None].transpose(0, 3, 1, 2)
    tensor = torch.from_numpy(pixels)
    return 2. * tensor - 1.
2022-09-15 14:24:03 +02:00
def load_mask(mask_str, h0, w0, newH, newW, invert=False):
    """Decode a base64 mask image into a float tensor scaled to [0, 1].

    mask_str   -- base64 image string, as accepted by ``base64_str_to_img``
    h0, w0     -- nominal height / width; used only for the size that is logged
    newH, newW -- size the mask is actually resized to
    invert     -- invert the mask colours before resizing

    The result has a leading batch axis of 1 and the channel axis moved in
    front of the two spatial axes.
    """
    mask_image = base64_str_to_img(mask_str).convert("RGB")
    width, height = mask_image.size
    print(f"loaded input mask of size ({width}, {height})")

    if invert:
        print("inverted")
        mask_image = ImageOps.invert(mask_image)
        # where_0, where_1 = np.where(image == 0), np.where(image == 255)
        # image[where_0], image[where_1] = 255, 0

    if h0 is not None and w0 is not None:
        width, height = w0, h0

    # resize to integer multiple of 64 (this value is only reported; the
    # resize below uses newW / newH)
    width -= width % 64
    height -= height % 64
    print(f"New mask size ({width}, {height})")

    mask_image = mask_image.resize((newW, newH), resample=Image.Resampling.LANCZOS)

    values = np.array(mask_image).astype(np.float32) / 255.0
    values = values[None].transpose(0, 3, 1, 2)
    return torch.from_numpy(values)
2022-09-02 10:28:36 +02:00
# https://stackoverflow.com/a/61114178
def img_to_base64_str(img):
    """Encode ``img`` as PNG and return it as a ``data:image/png;base64,`` URL string."""
    png_buffer = BytesIO()
    img.save(png_buffer, format="PNG")
    png_buffer.seek(0)

    encoded = base64.b64encode(png_buffer.getvalue()).decode()
    return "data:image/png;base64," + encoded
def base64_str_to_img(img_str):
    """Decode a base64 image string into a PIL image.

    img_str -- either a data URL ("data:<media type>;base64,<payload>") or
               the bare base64 payload.

    Returns the image opened with ``Image.open`` from an in-memory buffer.
    """
    # Strip the data-URL header by splitting at the first comma, instead of
    # cutting a fixed number of characters: the header length depends on the
    # media type (image/png, image/jpeg, image/webp, ...), and a bare payload
    # has no header at all.
    if img_str.startswith("data:"):
        _, _, img_str = img_str.partition(",")

    data = base64.b64decode(img_str)
    return Image.open(BytesIO(data))