import sys
import os
import uuid
import re
import torch
import traceback
import numpy as np
from omegaconf import OmegaConf
from pytorch_lightning import logging
from einops import rearrange
from PIL import Image, ImageOps, ImageChops
from ldm.generate import Generate
import transformers

from gfpgan import GFPGANer
from basicsr.archs.rrdbnet_arch import RRDBNet
from realesrgan import RealESRGANer

transformers.logging.set_verbosity_error()

from . import Request, Response, Image as ResponseImage
import base64
import json
from io import BytesIO

filename_regex = re.compile('[^a-zA-Z0-9]')

generator = None

gfpgan_file = None
real_esrgan_file = None
model_gfpgan = None
model_real_esrgan = None

device = None
precision = 'autocast'

has_valid_gpu = False
force_full_precision = False

# local
stop_processing = False
temp_images = {}

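# Probe for a usable CUDA GPU at import time; on any failure, everything below
# falls back to the (much slower) CPU path.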
try:
    gpu = torch.cuda.current_device()
    gpu_name = torch.cuda.get_device_name(gpu)
    print('GPU detected: ', gpu_name)

    force_full_precision = ('nvidia' in gpu_name.lower() or 'geforce' in gpu_name.lower()) and (' 1660' in gpu_name or ' 1650' in gpu_name) # otherwise these NVIDIA cards create green images
    if force_full_precision:
        print('forcing full precision on NVIDIA 16xx cards, to avoid green images. GPU detected: ', gpu_name)

    mem_free, mem_total = torch.cuda.mem_get_info(gpu)
    mem_total /= float(10**9)
    if mem_total < 3.0:
        print("GPUs with less than 3 GB of VRAM are not compatible with Stable Diffusion")
        raise Exception()

    has_valid_gpu = True
except Exception:
    print('WARNING: No compatible GPU found. Using the CPU, but this will be very slow!')

def load_model_ckpt(ckpt_to_use, device_to_use='cuda', precision_to_use='autocast'):
    # 'device' and 'precision' are module-level state, read later by the
    # GFPGAN/RealESRGAN loaders and by gc(), so they must be declared global here
    global generator, device, precision

    device = device_to_use if has_valid_gpu else 'cpu'
    precision = precision_to_use if not force_full_precision else 'full'

    try:
        config = 'configs/models.yaml'
        model = 'stable-diffusion-1.4'

        models = OmegaConf.load(config)
        width = models[model].width
        height = models[model].height
        config = models[model].config
        weights = ckpt_to_use + '.ckpt'
    except (FileNotFoundError, IOError, KeyError) as e:
        print(f'{e}. Aborting.')
        sys.exit(-1)

    generator = Generate(
        width=width,
        height=height,
        sampler_name='ddim',
        weights=weights,
        full_precision=(precision == 'full'),
        config=config,
        grid=False,
        # this is solely for recreating the prompt
        seamless=False,
        embedding_path=None,
        device_type=device,
        ignore_ctrl_c=True,
    )

    # gets rid of annoying messages about random seed
    logging.getLogger('pytorch_lightning').setLevel(logging.ERROR)

    # preload the model
    generator.load_model()

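# load_model_ckpt() reads only the 'width', 'height' and 'config' keys of the
# selected entry in configs/models.yaml. An illustrative stanza (the config
# path and values are assumptions, not pinned by this file):
#
#   stable-diffusion-1.4:
#       config: configs/stable-diffusion/v1-inference.yaml
#       width: 512
#       height: 512
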
def load_model_gfpgan(gfpgan_to_use):
    global gfpgan_file, model_gfpgan

    if gfpgan_to_use is None:
        return

    gfpgan_file = gfpgan_to_use
    model_path = gfpgan_to_use + ".pth"

    # identical settings on both devices; only the torch.device differs
    model_gfpgan = GFPGANer(model_path=model_path, upscale=1, arch='clean', channel_multiplier=2, bg_upsampler=None, device=torch.device('cpu' if device == 'cpu' else 'cuda'))

    print('loaded ', gfpgan_to_use, 'to', device, 'precision', precision)

def load_model_real_esrgan(real_esrgan_to_use):
    global real_esrgan_file, model_real_esrgan

    if real_esrgan_to_use is None:
        return

    real_esrgan_file = real_esrgan_to_use
    model_path = real_esrgan_to_use + ".pth"

    RealESRGAN_models = {
        'RealESRGAN_x4plus': RRDBNet(num_in_ch=3, num_out_ch=3, num_feat=64, num_block=23, num_grow_ch=32, scale=4),
        'RealESRGAN_x4plus_anime_6B': RRDBNet(num_in_ch=3, num_out_ch=3, num_feat=64, num_block=6, num_grow_ch=32, scale=4),
    }

    model_to_use = RealESRGAN_models[real_esrgan_to_use]

    if device == 'cpu':
        model_real_esrgan = RealESRGANer(scale=2, model_path=model_path, model=model_to_use, pre_pad=0, half=False) # the CPU does not support half precision
        model_real_esrgan.device = torch.device('cpu')
        model_real_esrgan.model.to('cpu')
    else:
        model_real_esrgan = RealESRGANer(scale=2, model_path=model_path, model=model_to_use, pre_pad=0, half=(precision != 'full'))

    model_real_esrgan.model.name = real_esrgan_to_use

    print('loaded ', real_esrgan_to_use, 'to', device, 'precision', precision)

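# Note: both loaders expect the weights file '<model name>.pth' to be reachable
# from the current working directory, e.g. load_model_real_esrgan('RealESRGAN_x4plus')
# loads 'RealESRGAN_x4plus.pth'.
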
def mk_img(req: Request):
    try:
        yield from do_mk_img(req)
    except Exception as e:
        print(traceback.format_exc())

        gc()

        # if device != "cpu":
        #     modelFS.to("cpu")
        #     modelCS.to("cpu")

        #     model.model1.to("cpu")
        #     model.model2.to("cpu")

        #     gc()

        yield json.dumps({
            "status": 'failed',
            "detail": str(e)
        })

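# Example (illustrative) of driving the generator from a caller; each yielded
# chunk is a JSON string ('send_chunk' is a hypothetical callback):
#
#   for chunk in mk_img(req):
#       send_chunk(chunk)
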
def do_mk_img(req: Request):
    # reset the module-level stop flag for this task
    global stop_processing
    stop_processing = False

    if req.use_face_correction != gfpgan_file:
        load_model_gfpgan(req.use_face_correction)

    if req.use_upscale != real_esrgan_file:
        load_model_real_esrgan(req.use_upscale)

    init_image = None
    init_mask = None

    if req.init_image is not None:
        image = base64_str_to_img(req.init_image)

        w, h = image.size
        print(f"loaded input image of size ({w}, {h}) from base64")
        if req.width is not None and req.height is not None:
            h, w = req.height, req.width

        w, h = map(lambda x: x - x % 64, (w, h)) # resize to integer multiple of 64
        image = image.resize((w, h), resample=Image.Resampling.LANCZOS)
        init_image = generator._create_init_image(image)

        if generator._has_transparency(image) and req.mask is None: # if image has a transparent area and no mask was provided, then try to generate mask
            print('>> Initial image has transparent areas. Will inpaint in these regions.')
            if generator._check_for_erasure(image):
                print(
                    '>> WARNING: Colors underneath the transparent region seem to have been erased.\n',
                    '>> Inpainting will be suboptimal. Please preserve the colors when making\n',
                    '>> a transparency mask, or provide mask explicitly using --init_mask (-M).'
                )
            init_mask = generator._create_init_mask(image) # this returns a torch tensor

        if device != "cpu" and precision != "full":
            init_image = init_image.half()

    if req.mask is not None:
        image = base64_str_to_img(req.mask)

        image = ImageChops.invert(image)

        w, h = image.size
        print(f"loaded input mask of size ({w}, {h}) from base64")
        if req.width is not None and req.height is not None:
            h, w = req.height, req.width

        w, h = map(lambda x: x - x % 64, (w, h)) # resize to integer multiple of 64
        image = image.resize((w, h), resample=Image.Resampling.LANCZOS)

        init_mask = generator._create_init_mask(image)

    if init_mask is not None:
        req.sampler = 'plms' # hack to force the underlying implementation to initialize DDIM properly

    result = generator.prompt2image(
        req.prompt,
        iterations=req.num_outputs,
        steps=req.num_inference_steps,
        seed=req.seed,
        cfg_scale=req.guidance_scale,
        ddim_eta=0.0,
        skip_normalize=False,
        image_callback=None,
        step_callback=None,
        width=req.width,
        height=req.height,
        sampler_name=req.sampler,
        seamless=False,
        log_tokenization=False,
        with_variations=None,
        variation_amount=0.0,
        # these are specific to img2img and inpaint
        init_img=init_image,
        init_mask=init_mask,
        fit=False,
        strength=req.prompt_strength,
        init_img_is_path=False,
        # these are specific to GFPGAN/ESRGAN
        gfpgan_strength=0,
        save_original=False,
        upscale=None,
        negative_prompt=req.negative_prompt,
    )

    has_filters = (req.use_face_correction is not None and req.use_face_correction.startswith('GFPGAN')) or \
                  (req.use_upscale is not None and req.use_upscale.startswith('RealESRGAN'))

    print('has filters:', has_filters)

    return_orig_img = not has_filters or not req.show_only_filtered_image

    res = Response()
    res.request = req
    res.images = []

    if req.save_to_disk_path is not None:
        session_out_path = os.path.join(req.save_to_disk_path, req.session_id)
        os.makedirs(session_out_path, exist_ok=True)
    else:
        session_out_path = None

    for img, seed in result:
        if req.save_to_disk_path is not None:
            prompt_flattened = filename_regex.sub('_', req.prompt)
            prompt_flattened = prompt_flattened[:50]

            img_id = str(uuid.uuid4())[-8:]

            file_path = f"{prompt_flattened}_{img_id}"
            img_out_path = os.path.join(session_out_path, f"{file_path}.{req.output_format}")
            meta_out_path = os.path.join(session_out_path, f"{file_path}.txt")

            if return_orig_img:
                save_image(img, img_out_path)

            save_metadata(meta_out_path, req.prompt, seed, req.width, req.height, req.num_inference_steps, req.guidance_scale, req.prompt_strength, req.use_face_correction, req.use_upscale, req.sampler, req.negative_prompt)

        if return_orig_img:
            img_data = img_to_base64_str(img)
            res_image_orig = ResponseImage(data=img_data, seed=seed)
            res.images.append(res_image_orig)

            if req.save_to_disk_path is not None:
                res_image_orig.path_abs = img_out_path

        if has_filters and not stop_processing:
            print('Applying filters..')

            gc()
            filters_applied = []

            np_img = img.convert('RGB')
            np_img = np.array(np_img, dtype=np.uint8)

            if req.use_face_correction:
                _, _, np_img = model_gfpgan.enhance(np_img, has_aligned=False, only_center_face=False, paste_back=True)
                filters_applied.append(req.use_face_correction)

            if req.use_upscale:
                np_img, _ = model_real_esrgan.enhance(np_img)
                filters_applied.append(req.use_upscale)

            filtered_image = Image.fromarray(np_img)

            filtered_img_data = img_to_base64_str(filtered_image)
            res_image_filtered = ResponseImage(data=filtered_img_data, seed=seed)
            res.images.append(res_image_filtered)

            filters_applied = "_".join(filters_applied)

            if req.save_to_disk_path is not None:
                filtered_img_out_path = os.path.join(session_out_path, f"{file_path}_{filters_applied}.{req.output_format}")
                save_image(filtered_image, filtered_img_out_path)
                res_image_filtered.path_abs = filtered_img_out_path

            del filtered_image

        del img

    print('Task completed')

    yield json.dumps(res.json())

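# On disk, each rendered image produces, under <save_to_disk_path>/<session_id>/:
#   <prompt>_<id>.<output_format>              the raw render (skipped when only
#                                              filtered output was requested)
#   <prompt>_<id>.txt                          the metadata written below
#   <prompt>_<id>_<filters>.<output_format>    the GFPGAN/RealESRGAN result
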
def save_image(img, img_out_path):
    try:
        img.save(img_out_path)
    except Exception:
        print('could not save the file', traceback.format_exc())

def save_metadata(meta_out_path, prompt, seed, width, height, num_inference_steps, guidance_scale, prompt_strength, use_correct_face, use_upscale, sampler_name, negative_prompt):
    metadata = f"{prompt}\nWidth: {width}\nHeight: {height}\nSeed: {seed}\nSteps: {num_inference_steps}\nGuidance Scale: {guidance_scale}\nPrompt Strength: {prompt_strength}\nUse Face Correction: {use_correct_face}\nUse Upscaling: {use_upscale}\nSampler: {sampler_name}\nNegative Prompt: {negative_prompt}"

    try:
        with open(meta_out_path, 'w') as f:
            f.write(metadata)
    except Exception:
        print('could not save the file', traceback.format_exc())

def gc():
    if device == 'cpu':
        return

    # release cached CUDA allocations and inter-process handles between tasks
    torch.cuda.empty_cache()
    torch.cuda.ipc_collect()

def load_img(img_str, w0, h0):
    image = base64_str_to_img(img_str).convert("RGB")
    w, h = image.size
    print(f"loaded input image of size ({w}, {h}) from base64")
    if h0 is not None and w0 is not None:
        h, w = h0, w0

    w, h = map(lambda x: x - x % 64, (w, h)) # resize to integer multiple of 64
    image = image.resize((w, h), resample=Image.Resampling.LANCZOS)
    image = np.array(image).astype(np.float32) / 255.0
    image = image[None].transpose(0, 3, 1, 2) # HWC -> 1CHW
    image = torch.from_numpy(image)
    return 2. * image - 1. # map [0, 1] to [-1, 1]

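# Shape check (illustrative): a 512x512 RGB input becomes a tensor of shape
# (1, 3, 512, 512) with values in [-1.0, 1.0], since pixel 0 maps to -1.0,
# 127.5 to 0.0 and 255 to 1.0.
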
def load_mask(mask_str, h0, w0, newH, newW, invert=False):
    image = base64_str_to_img(mask_str).convert("RGB")
    w, h = image.size
    print(f"loaded input mask of size ({w}, {h})")

    if invert:
        print("inverted")
        image = ImageOps.invert(image)
        # where_0, where_1 = np.where(image == 0), np.where(image == 255)
        # image[where_0], image[where_1] = 255, 0

    if h0 is not None and w0 is not None:
        h, w = h0, w0

    w, h = map(lambda x: x - x % 64, (w, h)) # round to integer multiple of 64

    print(f"New mask size ({w}, {h})")
    # note: the rounded (w, h) is only logged; the resize below uses the
    # (newW, newH) dimensions supplied by the caller
    image = image.resize((newW, newH), resample=Image.Resampling.LANCZOS)
    image = np.array(image)

    image = image.astype(np.float32) / 255.0
    image = image[None].transpose(0, 3, 1, 2)
    image = torch.from_numpy(image)
    return image

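# Base64 helpers: images travel to and from the browser as PNG data URLs. A PNG
# payload always encodes to a string starting 'data:image/png;base64,iVBOR...'.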
# https://stackoverflow.com/a/61114178
def img_to_base64_str(img):
    buffered = BytesIO()
    img.save(buffered, format="PNG")
    buffered.seek(0)
    img_byte = buffered.getvalue()
    img_str = "data:image/png;base64," + base64.b64encode(img_byte).decode()
    return img_str

def base64_str_to_img(img_str):
    img_str = img_str[len("data:image/png;base64,"):]
    data = base64.b64decode(img_str)
    buffered = BytesIO(data)
    img = Image.open(buffered)
    return img
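
# A minimal round-trip check for the helpers above (illustrative):
#
#   im = Image.new('RGB', (64, 64))
#   s = img_to_base64_str(im)
#   assert base64_str_to_img(s).size == (64, 64)
#
# And the typical call sequence for this module, assuming a Request populated
# by the server layer ('sd-v1-4' is an illustrative checkpoint name):
#
#   load_model_ckpt('sd-v1-4')      # loads 'sd-v1-4.ckpt'
#   for chunk in mk_img(req):
#       ...                         # stream each JSON chunk to the client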