Revert "Revert "Revert "Merge pull request #112 from cmdr2/develop"""
This reverts commit 788dcbf471.
@@ -16,17 +16,12 @@ from ldm.util import instantiate_from_config
from optimizedSD.optimUtils import split_weighted_subprompts
from transformers import logging

from gfpgan import GFPGANer
from basicsr.archs.rrdbnet_arch import RRDBNet
from realesrgan import RealESRGANer

import uuid

logging.set_verbosity_error()

# consts
config_yaml = "optimizedSD/v1-inference.yaml"
filename_regex = re.compile('[^a-zA-Z0-9]')

# api stuff
from . import Request, Response, Image as ResponseImage
@@ -36,16 +31,10 @@ from io import BytesIO
# local
session_id = str(uuid.uuid4())[-8:]

ckpt_file = None
gfpgan_file = None
real_esrgan_file = None

ckpt = None
model = None
modelCS = None
modelFS = None
model_gfpgan = None
model_real_esrgan = None

model_is_half = False
model_fs_is_half = False
device = None
@@ -54,30 +43,25 @@ precision = 'autocast'
sampler_plms = None
sampler_ddim = None

has_valid_gpu = False
force_full_precision = False
try:
    gpu_name = torch.cuda.get_device_name(torch.cuda.current_device())
    has_valid_gpu = True
    force_full_precision = ('nvidia' in gpu_name.lower()) and ('1660' in gpu_name or ' 1650' in gpu_name) # otherwise these NVIDIA cards create green images
    force_full_precision = ('nvidia' in gpu_name.lower()) and (' 1660' in gpu_name or ' 1650' in gpu_name) # otherwise these NVIDIA cards create green images
    if force_full_precision:
        print('forcing full precision on NVIDIA 16xx cards, to avoid green images. GPU detected: ', gpu_name)
        print('forcing full precision on NVIDIA 16xx cards, to avoid green images')
except:
    print('WARNING: No compatible GPU found. Using the CPU, but this will be very slow!')
    pass

def load_model_ckpt(ckpt_to_use, device_to_use='cuda', turbo=False, unet_bs_to_use=1, precision_to_use='autocast', half_model_fs=False):
    global ckpt_file, model, modelCS, modelFS, model_is_half, device, unet_bs, precision, model_fs_is_half
# api
def load_model(ckpt_to_use, device_to_use='cuda', turbo=False, unet_bs_to_use=1, precision_to_use='autocast', half_model_fs=False):
    global ckpt, model, modelCS, modelFS, model_is_half, device, unet_bs, precision, model_fs_is_half

    ckpt_file = ckpt_to_use
    device = device_to_use if has_valid_gpu else 'cpu'
    ckpt = ckpt_to_use
    device = device_to_use
    precision = precision_to_use if not force_full_precision else 'full'
    unet_bs = unet_bs_to_use

    if device == 'cpu':
        precision = 'full'

    sd = load_model_from_config(f"{ckpt_file}.ckpt")
    sd = load_model_from_config(f"{ckpt}")
    li, lo = [], []
    for key, value in sd.items():
        sp = key.split(".")
@@ -127,89 +111,29 @@ def load_model_ckpt(ckpt_to_use, device_to_use='cuda', turbo=False, unet_bs_to_u
    else:
        model_fs_is_half = False

    print('loaded ', ckpt_file, 'to', device, 'precision', precision)

def load_model_gfpgan(gfpgan_to_use):
    global gfpgan_file, model_gfpgan

    if gfpgan_to_use is None:
        return

    gfpgan_file = gfpgan_to_use
    model_path = gfpgan_to_use + ".pth"

    if device == 'cpu':
        model_gfpgan = GFPGANer(model_path=model_path, upscale=1, arch='clean', channel_multiplier=2, bg_upsampler=None, device=torch.device('cpu'))
    else:
        model_gfpgan = GFPGANer(model_path=model_path, upscale=1, arch='clean', channel_multiplier=2, bg_upsampler=None, device=torch.device('cuda'))

    print('loaded ', gfpgan_to_use, 'to', device, 'precision', precision)

def load_model_real_esrgan(real_esrgan_to_use):
    global real_esrgan_file, model_real_esrgan

    if real_esrgan_to_use is None:
        return

    real_esrgan_file = real_esrgan_to_use
    model_path = real_esrgan_to_use + ".pth"

    RealESRGAN_models = {
        'RealESRGAN_x4plus': RRDBNet(num_in_ch=3, num_out_ch=3, num_feat=64, num_block=23, num_grow_ch=32, scale=4),
        'RealESRGAN_x4plus_anime_6B': RRDBNet(num_in_ch=3, num_out_ch=3, num_feat=64, num_block=6, num_grow_ch=32, scale=4)
    }

    model_to_use = RealESRGAN_models[real_esrgan_to_use]

    if device == 'cpu':
        model_real_esrgan = RealESRGANer(scale=2, model_path=model_path, model=model_to_use, pre_pad=0, half=False) # cpu does not support half
        model_real_esrgan.device = torch.device('cpu')
        model_real_esrgan.model.to('cpu')
    else:
        model_real_esrgan = RealESRGANer(scale=2, model_path=model_path, model=model_to_use, pre_pad=0, half=model_is_half)

    model_real_esrgan.model.name = real_esrgan_to_use

    print('loaded ', real_esrgan_to_use, 'to', device, 'precision', precision)

def mk_img(req: Request):
    global modelFS, device
    global model_gfpgan, model_real_esrgan

    res = Response()
    res.images = []

    model.turbo = req.turbo
    if req.use_cpu:
        if device != 'cpu':
            device = 'cpu'
            device = 'cpu'

            if model_is_half:
                load_model_ckpt(ckpt_file, device)

            load_model_gfpgan(gfpgan_file)
            load_model_real_esrgan(real_esrgan_file)
            if model_is_half:
                print('reloading model for cpu')
                load_model(ckpt, device)
    else:
        if has_valid_gpu:
            prev_device = device
            device = 'cuda'
            device = 'cuda'

            if (precision == 'autocast' and (req.use_full_precision or not model_is_half)) or \
                (precision == 'full' and not req.use_full_precision and not force_full_precision) or \
                (req.init_image is None and model_fs_is_half) or \
                (req.init_image is not None and not model_fs_is_half and not force_full_precision):
            if (precision == 'autocast' and (req.use_full_precision or not model_is_half)) or \
                (precision == 'full' and not req.use_full_precision and not force_full_precision) or \
                (req.init_image is None and model_fs_is_half) or \
                (req.init_image is not None and not model_fs_is_half and not force_full_precision):

                load_model_ckpt(ckpt_file, device, model.turbo, unet_bs, ('full' if req.use_full_precision else 'autocast'), half_model_fs=(req.init_image is not None and not req.use_full_precision))

                if prev_device != device:
                    load_model_gfpgan(gfpgan_file)
                    load_model_real_esrgan(real_esrgan_file)

                if req.use_face_correction != gfpgan_file:
                    load_model_gfpgan(req.use_face_correction)

                if req.use_upscale != real_esrgan_file:
                    load_model_real_esrgan(req.use_upscale)
                print('reloading model for cuda')
                load_model(ckpt, device, model.turbo, unet_bs, ('full' if req.use_full_precision else 'autocast'), half_model_fs=(req.init_image is not None and not req.use_full_precision))

    model.cdevice = device
    modelCS.cond_stage_model.device = device
@@ -228,9 +152,6 @@ def mk_img(req: Request):
    opt_strength = req.prompt_strength
    opt_save_to_disk_path = req.save_to_disk_path
    opt_init_img = req.init_image
    opt_use_face_correction = req.use_face_correction
    opt_use_upscale = req.use_upscale
    opt_show_only_filtered = req.show_only_filtered_image
    opt_format = 'png'

    print(req.to_string(), '\n device', device)
@@ -324,54 +245,29 @@ def mk_img(req: Request):
    x_samples_ddim = modelFS.decode_first_stage(x_samples[i].unsqueeze(0))
    x_sample = torch.clamp((x_samples_ddim + 1.0) / 2.0, min=0.0, max=1.0)
    x_sample = 255.0 * rearrange(x_sample[0].cpu().numpy(), "c h w -> h w c")
    x_sample = x_sample.astype(np.uint8)
    img = Image.fromarray(x_sample)
    img = Image.fromarray(x_sample.astype(np.uint8))

    img_data = img_to_base64_str(img)
    res.images.append(ResponseImage(data=img_data, seed=opt_seed))

    if opt_save_to_disk_path is not None:
        prompt_flattened = filename_regex.sub('_', prompts[0])
        prompt_flattened = prompt_flattened[:50]
        try:
            prompt_flattened = "_".join(re.split(":| ", prompts[0]))
            prompt_flattened = prompt_flattened.replace(',', '')
            prompt_flattened = prompt_flattened[:50]

            img_id = str(uuid.uuid4())[-8:]
            img_id = str(uuid.uuid4())[-8:]

            file_path = f"{prompt_flattened}_{img_id}"
            img_out_path = os.path.join(session_out_path, f"{file_path}.{opt_format}")
            meta_out_path = os.path.join(session_out_path, f"{file_path}.txt")
            file_path = f"{prompt_flattened}_{img_id}"
            img_out_path = os.path.join(session_out_path, f"{file_path}.{opt_format}")
            meta_out_path = os.path.join(session_out_path, f"{file_path}.txt")

        if not opt_show_only_filtered:
            save_image(img, img_out_path)

        save_metadata(meta_out_path, prompts, opt_seed, opt_W, opt_H, opt_ddim_steps, opt_scale, opt_strength, opt_use_face_correction, opt_use_upscale)

    if not opt_show_only_filtered:
        img_data = img_to_base64_str(img)
        res.images.append(ResponseImage(data=img_data, seed=opt_seed))

    if (opt_use_face_correction is not None and opt_use_face_correction.startswith('GFPGAN')) or \
        (opt_use_upscale is not None and opt_use_upscale.startswith('RealESRGAN')):

        gc()
        filters_applied = []

        if opt_use_face_correction:
            _, _, output = model_gfpgan.enhance(x_sample[:,:,::-1], has_aligned=False, only_center_face=False, paste_back=True)
            x_sample = output[:,:,::-1]
            filters_applied.append(opt_use_face_correction)

        if opt_use_upscale:
            output, _ = model_real_esrgan.enhance(x_sample[:,:,::-1])
            x_sample = output[:,:,::-1]
            filters_applied.append(opt_use_upscale)

        filtered_image = Image.fromarray(x_sample)

        filtered_img_data = img_to_base64_str(filtered_image)
        res.images.append(ResponseImage(data=filtered_img_data, seed=opt_seed))

        filters_applied = "_".join(filters_applied)

        if opt_save_to_disk_path is not None:
            filtered_img_out_path = os.path.join(session_out_path, f"{file_path}_{filters_applied}.{opt_format}")
            save_image(filtered_image, filtered_img_out_path)
            metadata = f"{prompts[0]}\nWidth: {opt_W}\nHeight: {opt_H}\nSeed: {opt_seed}\nSteps: {opt_ddim_steps}\nGuidance Scale: {opt_scale}"
            img.save(img_out_path)
            with open(meta_out_path, 'w') as f:
                f.write(metadata)
        except:
            print('could not save the file', traceback.format_exc())

    seeds += str(opt_seed) + ","
    opt_seed += 1
@@ -386,21 +282,6 @@ def mk_img(req: Request):

    return res

def save_image(img, img_out_path):
    try:
        img.save(img_out_path)
    except:
        print('could not save the file', traceback.format_exc())

def save_metadata(meta_out_path, prompts, opt_seed, opt_W, opt_H, opt_ddim_steps, opt_scale, opt_prompt_strength, opt_correct_face, opt_upscale):
    metadata = f"{prompts[0]}\nWidth: {opt_W}\nHeight: {opt_H}\nSeed: {opt_seed}\nSteps: {opt_ddim_steps}\nGuidance Scale: {opt_scale}\nPrompt Strength: {opt_prompt_strength}\nUse Face Correction: {opt_correct_face}\nUse Upscaling: {opt_upscale}"

    try:
        with open(meta_out_path, 'w') as f:
            f.write(metadata)
    except:
        print('could not save the file', traceback.format_exc())

def _txt2img(opt_W, opt_H, opt_n_samples, opt_ddim_steps, opt_scale, start_code, opt_C, opt_f, opt_ddim_eta, c, uc, opt_seed):
    shape = [opt_n_samples, opt_C, opt_H // opt_f, opt_W // opt_f]

@@ -446,13 +327,6 @@ def _img2img(init_latent, t_enc, batch_size, opt_scale, c, uc, opt_ddim_steps, o

    return samples_ddim

def gc():
    if device == 'cpu':
        return

    torch.cuda.empty_cache()
    torch.cuda.ipc_collect()

# internal

def chunk(it, size):
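Both sides of this diff share the same precision rule: NVIDIA 16xx cards (GTX 1650/1660) produce green images at half precision, so 'full' is forced for them, and a machine with no usable CUDA device falls back to the CPU, which also requires full precision. The following is a minimal standalone sketch of that rule, not code from this repository; it assumes only that torch is installed, and pick_precision is a hypothetical helper name:

import torch

def pick_precision(requested='autocast'):
    # Mirrors the detection logic in the diff above: 16xx NVIDIA cards
    # render green images at half precision, so force 'full' for them.
    try:
        gpu_name = torch.cuda.get_device_name(torch.cuda.current_device())
        if 'nvidia' in gpu_name.lower() and (' 1660' in gpu_name or ' 1650' in gpu_name):
            return 'full'
        return requested
    except Exception:
        # No compatible GPU detected: run on the CPU at full precision.
        return 'full'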