# File: easydiffusion/ui/sd_internal/runtime.py
# (410 lines, 14 KiB, Python)

import sys
import os
import uuid
import re
import torch
import traceback
import numpy as np
from omegaconf import OmegaConf
from pytorch_lightning import logging
from einops import rearrange
from PIL import Image, ImageOps, ImageChops
from ldm.generate import Generate
import transformers
from gfpgan import GFPGANer
from basicsr.archs.rrdbnet_arch import RRDBNet
from realesrgan import RealESRGANer
transformers.logging.set_verbosity_error()
from . import Request, Response, Image as ResponseImage
import base64
import json
from io import BytesIO
# Characters not allowed in generated output filenames (replaced with '_').
filename_regex = re.compile('[^a-zA-Z0-9]')

# Module-level runtime state; populated by the load_model_* functions below.
generator = None
gfpgan_file = None
real_esrgan_file = None
model_gfpgan = None
model_real_esrgan = None
device = None
precision = 'autocast'
has_valid_gpu = False
force_full_precision = False

# local
stop_processing = False
temp_images = {}

# Probe the current CUDA device once at import time. Any failure (no CUDA
# build, no driver/device, or a card below the VRAM threshold) leaves
# has_valid_gpu False, so load_model_ckpt() falls back to the CPU.
try:
    gpu = torch.cuda.current_device()
    gpu_name = torch.cuda.get_device_name(gpu)
    print('GPU detected: ', gpu_name)

    force_full_precision = ('nvidia' in gpu_name.lower() or 'geforce' in gpu_name.lower()) and (' 1660' in gpu_name or ' 1650' in gpu_name) # otherwise these NVIDIA cards create green images
    if force_full_precision:
        print('forcing full precision on NVIDIA 16xx cards, to avoid green images. GPU detected: ', gpu_name)

    mem_free, mem_total = torch.cuda.mem_get_info(gpu)
    mem_total /= float(10**9)  # convert to GB for the threshold below
    if mem_total < 3.0:
        print("GPUs with less than 3 GB of VRAM are not compatible with Stable Diffusion")
        raise Exception()

    has_valid_gpu = True
except Exception:
    # Narrowed from a bare `except:` so Ctrl-C / SystemExit during startup
    # are not silently turned into "no GPU".
    print('WARNING: No compatible GPU found. Using the CPU, but this will be very slow!')
def load_model_ckpt(ckpt_to_use, device_to_use='cuda', precision_to_use='autocast'):
    """Load a Stable Diffusion checkpoint into the module-level ``generator``.

    Args:
        ckpt_to_use: checkpoint path without the ``.ckpt`` extension.
        device_to_use: requested device; replaced by ``'cpu'`` when no
            compatible GPU was detected at import time.
        precision_to_use: requested precision; replaced by ``'full'`` when
            ``force_full_precision`` was set at import time.

    Calls ``sys.exit(-1)`` if ``configs/models.yaml`` cannot be read or does
    not contain the ``stable-diffusion-1.4`` entry.
    """
    # device/precision must be written to the module globals: gc(),
    # load_model_gfpgan(), load_model_real_esrgan() and do_mk_img() all read
    # them. Assigning them as plain locals left device=None and
    # precision='autocast', so CPU runs were treated as CUDA/half precision.
    global generator, device, precision

    device = device_to_use if has_valid_gpu else 'cpu'
    precision = precision_to_use if not force_full_precision else 'full'

    try:
        config = 'configs/models.yaml'
        model = 'stable-diffusion-1.4'

        models = OmegaConf.load(config)
        width = models[model].width
        height = models[model].height
        config = models[model].config
        weights = ckpt_to_use + '.ckpt'
    except (FileNotFoundError, IOError, KeyError) as e:
        print(f'{e}. Aborting.')
        sys.exit(-1)

    generator = Generate(
        width=width,
        height=height,
        sampler_name='ddim',
        weights=weights,
        full_precision=(precision == 'full'),
        config=config,
        grid=False,
        # this is solely for recreating the prompt
        seamless=False,
        embedding_path=None,
        device_type=device,
        ignore_ctrl_c=True,
    )

    # gets rid of annoying messages about random seed
    logging.getLogger('pytorch_lightning').setLevel(logging.ERROR)

    # preload the model
    generator.load_model()
def load_model_gfpgan(gfpgan_to_use):
    """Load a GFPGAN face-correction model into ``model_gfpgan``.

    Args:
        gfpgan_to_use: model path without the ``.pth`` extension, or None to
            do nothing (any previously loaded model is left in place).

    The model is placed on the CPU when the module-level ``device`` is
    ``'cpu'`` and on CUDA otherwise.
    """
    global gfpgan_file, model_gfpgan

    if gfpgan_to_use is None:
        return

    model_path = gfpgan_to_use + ".pth"
    target_device = torch.device('cpu' if device == 'cpu' else 'cuda')

    model_gfpgan = GFPGANer(model_path=model_path, upscale=1, arch='clean', channel_multiplier=2, bg_upsampler=None, device=target_device)

    # Record the name only after construction succeeded. Otherwise a failed
    # load would make do_mk_img() believe this model is already loaded and
    # keep using the previous (or missing) one.
    gfpgan_file = gfpgan_to_use

    print('loaded ', gfpgan_to_use, 'to', device, 'precision', precision)
def load_model_real_esrgan(real_esrgan_to_use):
    """Load a RealESRGAN upscaler into ``model_real_esrgan``.

    Args:
        real_esrgan_to_use: one of ``'RealESRGAN_x4plus'`` or
            ``'RealESRGAN_x4plus_anime_6B'`` (also used as the model path,
            without the ``.pth`` extension), or None to do nothing.

    Raises:
        KeyError: if ``real_esrgan_to_use`` is not a supported model name.

    On the CPU the upscaler runs in full precision; on other devices it uses
    half precision unless the module-level ``precision`` is ``'full'``.
    """
    global real_esrgan_file, model_real_esrgan

    if real_esrgan_to_use is None:
        return

    model_path = real_esrgan_to_use + ".pth"

    # The supported variants differ only in their residual-block count, so
    # build just the requested network instead of instantiating every
    # architecture on each load.
    num_blocks_by_model = {
        'RealESRGAN_x4plus': 23,
        'RealESRGAN_x4plus_anime_6B': 6,
    }
    if real_esrgan_to_use not in num_blocks_by_model:
        raise KeyError(f'unsupported RealESRGAN model: {real_esrgan_to_use!r}')
    model_to_use = RRDBNet(num_in_ch=3, num_out_ch=3, num_feat=64, num_block=num_blocks_by_model[real_esrgan_to_use], num_grow_ch=32, scale=4)

    if device == 'cpu':
        upsampler = RealESRGANer(scale=2, model_path=model_path, model=model_to_use, pre_pad=0, half=False) # cpu does not support half
        upsampler.device = torch.device('cpu')
        upsampler.model.to('cpu')
    else:
        upsampler = RealESRGANer(scale=2, model_path=model_path, model=model_to_use, pre_pad=0, half=(precision != 'full'))

    upsampler.model.name = real_esrgan_to_use

    # Publish the model and its name together, only after loading succeeded,
    # so a failed load cannot leave the name and the model out of sync.
    model_real_esrgan = upsampler
    real_esrgan_file = real_esrgan_to_use

    print('loaded ', real_esrgan_to_use, 'to', device, 'precision', precision)
def mk_img(req: Request):
    """Run an image-generation request, yielding JSON strings.

    Yields everything ``do_mk_img(req)`` yields. If it raises, the traceback
    is printed, cached GPU memory is released (best-effort), and a single
    ``{"status": "failed", "detail": <message>}`` JSON string is yielded
    instead of propagating the exception.
    """
    try:
        yield from do_mk_img(req)
    except Exception as e:
        print(traceback.format_exc())

        # Cleanup must not prevent the client from receiving the failure
        # status below, so an error here is only logged.
        try:
            gc()
        except Exception:
            print('could not release GPU memory', traceback.format_exc())

        yield json.dumps({
            "status": 'failed',
            "detail": str(e)
        })
def _load_resized_input(img_str, width, height, invert=False):
    """Decode a base64 image, optionally invert it, and resize to multiples of 64.

    The requested ``width``/``height`` replace the image's own size when both
    are given. The size is then rounded down to a multiple of 64 per axis.
    """
    image = base64_str_to_img(img_str)
    if invert:
        image = ImageChops.invert(image)

    w, h = image.size
    print(f"loaded input image of size ({w}, {h}) from base64")
    if width is not None and height is not None:
        h, w = height, width

    w, h = map(lambda x: x - x % 64, (w, h)) # resize to integer multiple of 64
    return image.resize((w, h), resample=Image.Resampling.LANCZOS)


def do_mk_img(req: Request):
    """Generate images for ``req`` and yield one JSON-encoded Response.

    Steps:
    - Switch the face-correction/upscaling models if the request names
      different ones than are loaded.
    - Prepare the optional init image and mask, then run
      ``generator.prompt2image``.
    - For each result, optionally save the image and a metadata ``.txt``
      under ``req.save_to_disk_path/req.session_id``, and apply the
      requested filters.

    The original image is included unless filters are active and
    ``req.show_only_filtered_image`` is set.
    """
    # Declare the flag global so the `not stop_processing` check below reads
    # the module-level flag (presumably set by other code to cancel a task -
    # verify) instead of a local that is always False.
    global stop_processing
    stop_processing = False

    if req.use_face_correction != gfpgan_file:
        load_model_gfpgan(req.use_face_correction)

    if req.use_upscale != real_esrgan_file:
        load_model_real_esrgan(req.use_upscale)

    init_image = None
    init_mask = None

    if req.init_image is not None:
        image = _load_resized_input(req.init_image, req.width, req.height)
        init_image = generator._create_init_image(image)

        # if image has a transparent area and no mask was provided, then try to generate mask
        if generator._has_transparency(image) and req.mask is None:
            print('>> Initial image has transparent areas. Will inpaint in these regions.')
            if generator._check_for_erasure(image):
                print(
                    '>> WARNING: Colors underneath the transparent region seem to have been erased.\n',
                    '>> Inpainting will be suboptimal. Please preserve the colors when making\n',
                    '>> a transparency mask, or provide mask explicitly using --init_mask (-M).'
                )
            init_mask = generator._create_init_mask(image) # this returns a torch tensor

        if device != "cpu" and precision != "full":
            init_image = init_image.half()

        if req.mask is not None:
            # The mask image is inverted before being turned into the init mask.
            mask_image = _load_resized_input(req.mask, req.width, req.height, invert=True)
            init_mask = generator._create_init_mask(mask_image)

    if init_mask is not None:
        req.sampler = 'plms' # hack to force the underlying implementation to initialize DDIM properly

    result = generator.prompt2image(
        req.prompt,
        iterations = req.num_outputs,
        steps = req.num_inference_steps,
        seed = req.seed,
        cfg_scale = req.guidance_scale,
        ddim_eta = 0.0,
        skip_normalize = False,
        image_callback = None,
        step_callback = None,
        width = req.width,
        height = req.height,
        sampler_name = req.sampler,
        seamless = False,
        log_tokenization= False,
        with_variations = None,
        variation_amount = 0.0,
        # these are specific to img2img and inpaint
        init_img = init_image,
        init_mask = init_mask,
        fit = False,
        strength = req.prompt_strength,
        init_img_is_path = False,
        # these are specific to GFPGAN/ESRGAN
        gfpgan_strength= 0,
        save_original = False,
        upscale = None,
        negative_prompt= req.negative_prompt,
    )

    has_filters = (req.use_face_correction is not None and req.use_face_correction.startswith('GFPGAN')) or \
                  (req.use_upscale is not None and req.use_upscale.startswith('RealESRGAN'))
    print('has filter', has_filters)

    return_orig_img = not has_filters or not req.show_only_filtered_image

    res = Response()
    res.request = req
    res.images = []

    if req.save_to_disk_path is not None:
        session_out_path = os.path.join(req.save_to_disk_path, req.session_id)
        os.makedirs(session_out_path, exist_ok=True)
        # The filename prefix depends only on the prompt; compute it once.
        prompt_flattened = filename_regex.sub('_', req.prompt)[:50]
    else:
        session_out_path = None

    for img, seed in result:
        if req.save_to_disk_path is not None:
            img_id = str(uuid.uuid4())[-8:]
            file_path = f"{prompt_flattened}_{img_id}"
            img_out_path = os.path.join(session_out_path, f"{file_path}.{req.output_format}")
            meta_out_path = os.path.join(session_out_path, f"{file_path}.txt")

            if return_orig_img:
                save_image(img, img_out_path)

            save_metadata(meta_out_path, req.prompt, seed, req.width, req.height, req.num_inference_steps, req.guidance_scale, req.prompt_strength, req.use_face_correction, req.use_upscale, req.sampler, req.negative_prompt)

        if return_orig_img:
            img_data = img_to_base64_str(img)
            res_image_orig = ResponseImage(data=img_data, seed=seed)
            res.images.append(res_image_orig)

            if req.save_to_disk_path is not None:
                res_image_orig.path_abs = img_out_path

        if has_filters and not stop_processing:
            print('Applying filters..')

            gc()
            filter_names = []

            np_img = np.array(img.convert('RGB'), dtype=np.uint8)

            if req.use_face_correction:
                _, _, np_img = model_gfpgan.enhance(np_img, has_aligned=False, only_center_face=False, paste_back=True)
                filter_names.append(req.use_face_correction)

            if req.use_upscale:
                np_img, _ = model_real_esrgan.enhance(np_img)
                filter_names.append(req.use_upscale)

            filtered_image = Image.fromarray(np_img)
            filtered_img_data = img_to_base64_str(filtered_image)
            res_image_filtered = ResponseImage(data=filtered_img_data, seed=seed)
            res.images.append(res_image_filtered)

            if req.save_to_disk_path is not None:
                filters_applied = "_".join(filter_names)
                filtered_img_out_path = os.path.join(session_out_path, f"{file_path}_{filters_applied}.{req.output_format}")
                save_image(filtered_image, filtered_img_out_path)
                res_image_filtered.path_abs = filtered_img_out_path

            del filtered_image

        del img

    print('Task completed')

    yield json.dumps(res.json())
def save_image(img, img_out_path):
    """Save ``img`` to ``img_out_path`` via its ``save`` method; never raises.

    Saving is best-effort: an error is printed with its traceback instead of
    aborting the generation request.
    """
    try:
        img.save(img_out_path)
    except Exception:
        # Narrowed from a bare `except:` so KeyboardInterrupt/SystemExit propagate.
        print('could not save the file', traceback.format_exc())
def save_metadata(meta_out_path, prompt, seed, width, height, num_inference_steps, guidance_scale, prompt_strength, use_correct_face, use_upscale, sampler_name, negative_prompt):
    """Write the generation parameters of one image to a text file.

    The first line is the prompt, followed by one ``Label: value`` line per
    setting. Writing is best-effort: failures are printed, not raised.
    """
    metadata = f"{prompt}\nWidth: {width}\nHeight: {height}\nSeed: {seed}\nSteps: {num_inference_steps}\nGuidance Scale: {guidance_scale}\nPrompt Strength: {prompt_strength}\nUse Face Correction: {use_correct_face}\nUse Upscaling: {use_upscale}\nSampler: {sampler_name}\nNegative Prompt: {negative_prompt}"
    try:
        # Explicit UTF-8: prompts may contain characters the platform default
        # encoding (e.g. cp1252 on Windows) cannot encode.
        with open(meta_out_path, 'w', encoding='utf-8') as f:
            f.write(metadata)
    except Exception:
        # Narrowed from a bare `except:` so KeyboardInterrupt/SystemExit propagate.
        print('could not save the file', traceback.format_exc())
def gc():
    """Release cached CUDA memory; no-op on the CPU or when CUDA is unavailable.

    The availability check matters when ``device`` has not been set yet or
    torch has no usable CUDA. In that case ``torch.cuda.ipc_collect()`` tries
    to initialize CUDA and fails, which would break error handling in
    callers such as mk_img().
    """
    if device == 'cpu' or not torch.cuda.is_available():
        return

    torch.cuda.empty_cache()
    torch.cuda.ipc_collect()
def load_img(img_str, w0, h0):
    """Decode a base64 image into a float tensor scaled to [-1, 1].

    The image is converted to RGB and resized to (w0, h0) when both are
    given, otherwise kept at its own size. Both dimensions are rounded down
    to a multiple of 64. The returned tensor has a leading batch axis of
    size 1 and channels moved in front of height/width.
    """
    rgb = base64_str_to_img(img_str).convert("RGB")
    width, height = rgb.size
    print(f"loaded input image of size ({width}, {height}) from base64")

    if not (h0 is None or w0 is None):
        width, height = w0, h0

    # resize to integer multiple of 64
    width -= width % 64
    height -= height % 64
    rgb = rgb.resize((width, height), resample=Image.Resampling.LANCZOS)

    pixels = np.asarray(rgb).astype(np.float32) / 255.0
    batched = pixels[np.newaxis].transpose(0, 3, 1, 2)
    tensor = torch.from_numpy(batched)
    return tensor * 2.0 - 1.0
def load_mask(mask_str, h0, w0, newH, newW, invert=False):
    """Decode a base64 mask into a float tensor with values in [0, 1].

    Args:
        mask_str: base64 string of the mask image.
        h0, w0: accepted for interface compatibility; they do not affect the
            output (the mask is always resized to ``newW`` x ``newH``).
        newH, newW: target height and width of the mask.
        invert: if True, invert the mask colors before resizing.

    Returns:
        A tensor with a leading batch axis of size 1 and the RGB channels
        moved in front of height/width.
    """
    image = base64_str_to_img(mask_str).convert("RGB")
    w, h = image.size
    print(f"loaded input mask of size ({w}, {h})")

    if invert:
        print("inverted")
        image = ImageOps.invert(image)

    # Log the size the mask is actually resized to. The previous log printed
    # a size derived from h0/w0 that the resize never used.
    print(f"New mask size ({newW}, {newH})")
    image = image.resize((newW, newH), resample=Image.Resampling.LANCZOS)

    image = np.array(image)
    image = image.astype(np.float32) / 255.0
    image = image[None].transpose(0, 3, 1, 2)
    image = torch.from_numpy(image)
    return image
# https://stackoverflow.com/a/61114178
def img_to_base64_str(img):
    """Encode an image as a ``data:image/png;base64,...`` string.

    The image is serialized as PNG into an in-memory buffer via its
    ``save`` method, and the resulting bytes are base64-encoded.
    """
    with BytesIO() as png_buffer:
        img.save(png_buffer, format="PNG")
        png_bytes = png_buffer.getvalue()

    encoded = base64.b64encode(png_bytes).decode()
    return "data:image/png;base64," + encoded
def base64_str_to_img(img_str):
    """Decode a base64 image string into a PIL image.

    Accepts a ``data:<mime>;base64,<payload>`` URL for any MIME type, or a
    bare base64 payload. The header is removed by locating the first comma
    rather than by dropping a fixed number of characters. Dropping a fixed
    count corrupted bare payloads and data URLs whose header is longer than
    ``data:image/png;base64,``.
    """
    if img_str.startswith('data:'):
        _, _, img_str = img_str.partition(',')

    data = base64.b64decode(img_str)
    return Image.open(BytesIO(data))