forked from extern/easydiffusion
Merge pull request #404 from cmdr2/multi-gpu
Support for multiple GPUs, and improvements to RAM and VRAM usage
This commit is contained in:
commit
5d8bda1178
@ -18,7 +18,7 @@
|
||||
<div id="container">
|
||||
<div id="top-nav">
|
||||
<div id="logo">
|
||||
<h1>Stable Diffusion UI <small>v2.3.5 <span id="updateBranchLabel"></span></small></h1>
|
||||
<h1>Stable Diffusion UI <small>v2.3.6 <span id="updateBranchLabel"></span></small></h1>
|
||||
</div>
|
||||
<ul id="top-nav-items">
|
||||
<li class="dropdown">
|
||||
|
@ -1,8 +1,15 @@
|
||||
"""runtime.py: torch device owned by a thread.
|
||||
Notes:
|
||||
Avoid device switching, transfering all models will get too complex.
|
||||
To use a diffrent device signal the current render device to exit
|
||||
And then start a new clean thread for the new device.
|
||||
"""
|
||||
import json
|
||||
import os, re
|
||||
import traceback
|
||||
import torch
|
||||
import numpy as np
|
||||
from gc import collect as gc_collect
|
||||
from omegaconf import OmegaConf
|
||||
from PIL import Image, ImageOps
|
||||
from tqdm import tqdm, trange
|
||||
@ -35,63 +42,145 @@ import base64
|
||||
from io import BytesIO
|
||||
#from colorama import Fore
|
||||
|
||||
# local
|
||||
stop_processing = False
|
||||
temp_images = {}
|
||||
from threading import local as LocalThreadVars
|
||||
thread_data = LocalThreadVars()
|
||||
|
||||
ckpt_file = None
|
||||
gfpgan_file = None
|
||||
real_esrgan_file = None
|
||||
def device_would_fail(device):
|
||||
if device == 'cpu': return None
|
||||
# Returns None when no issues found, otherwise returns the detected error str.
|
||||
# Memory check
|
||||
try:
|
||||
mem_free, mem_total = torch.cuda.mem_get_info(device)
|
||||
mem_total /= float(10**9)
|
||||
if mem_total < 3.0:
|
||||
return 'GPUs with less than 3 GB of VRAM are not compatible with Stable Diffusion'
|
||||
except RuntimeError as e:
|
||||
return str(e) # Return cuda errors from mem_get_info as strings
|
||||
return None
|
||||
|
||||
model = None
|
||||
modelCS = None
|
||||
modelFS = None
|
||||
model_gfpgan = None
|
||||
model_real_esrgan = None
|
||||
def device_select(device):
|
||||
if device == 'cpu': return True
|
||||
if not torch.cuda.is_available(): return False
|
||||
failure_msg = device_would_fail(device)
|
||||
if failure_msg:
|
||||
if 'invalid device' in failure_msg:
|
||||
raise NameError(f'GPU "{device}" could not be found. Remove this device from config.render_devices or use one of "auto" or "cuda".')
|
||||
print(failure_msg)
|
||||
return False
|
||||
|
||||
model_is_half = False
|
||||
model_fs_is_half = False
|
||||
device = None
|
||||
unet_bs = 1
|
||||
precision = 'autocast'
|
||||
sampler_plms = None
|
||||
sampler_ddim = None
|
||||
device_name = torch.cuda.get_device_name(device)
|
||||
|
||||
has_valid_gpu = False
|
||||
force_full_precision = False
|
||||
try:
|
||||
gpu = torch.cuda.current_device()
|
||||
gpu_name = torch.cuda.get_device_name(gpu)
|
||||
print('GPU detected: ', gpu_name)
|
||||
# otherwise these NVIDIA cards create green images
|
||||
thread_data.force_full_precision = ('nvidia' in device_name.lower() or 'geforce' in device_name.lower()) and (' 1660' in device_name or ' 1650' in device_name)
|
||||
if thread_data.force_full_precision:
|
||||
print('forcing full precision on NVIDIA 16xx cards, to avoid green images. GPU detected: ', device_name)
|
||||
|
||||
force_full_precision = ('nvidia' in gpu_name.lower() or 'geforce' in gpu_name.lower()) and (' 1660' in gpu_name or ' 1650' in gpu_name) # otherwise these NVIDIA cards create green images
|
||||
if force_full_precision:
|
||||
print('forcing full precision on NVIDIA 16xx cards, to avoid green images. GPU detected: ', gpu_name)
|
||||
thread_data.device = device
|
||||
thread_data.has_valid_gpu = True
|
||||
return True
|
||||
|
||||
mem_free, mem_total = torch.cuda.mem_get_info(gpu)
|
||||
mem_total /= float(10**9)
|
||||
if mem_total < 3.0:
|
||||
print("GPUs with less than 3 GB of VRAM are not compatible with Stable Diffusion")
|
||||
raise Exception()
|
||||
def device_init(device_selection=None):
|
||||
# Thread bound properties
|
||||
thread_data.stop_processing = False
|
||||
thread_data.temp_images = {}
|
||||
|
||||
has_valid_gpu = True
|
||||
except:
|
||||
thread_data.ckpt_file = None
|
||||
thread_data.gfpgan_file = None
|
||||
thread_data.real_esrgan_file = None
|
||||
|
||||
thread_data.model = None
|
||||
thread_data.modelCS = None
|
||||
thread_data.modelFS = None
|
||||
thread_data.model_gfpgan = None
|
||||
thread_data.model_real_esrgan = None
|
||||
|
||||
thread_data.model_is_half = False
|
||||
thread_data.model_fs_is_half = False
|
||||
thread_data.device = None
|
||||
thread_data.unet_bs = 1
|
||||
thread_data.precision = 'autocast'
|
||||
thread_data.sampler_plms = None
|
||||
thread_data.sampler_ddim = None
|
||||
|
||||
thread_data.turbo = False
|
||||
thread_data.has_valid_gpu = False
|
||||
thread_data.force_full_precision = False
|
||||
thread_data.reduced_memory = True
|
||||
|
||||
if device_selection.lower() == 'cpu':
|
||||
print('CPU requested, skipping gpu init.')
|
||||
thread_data.device = 'cpu'
|
||||
return
|
||||
if not torch.cuda.is_available():
|
||||
print('WARNING: torch.cuda is not available. Using the CPU, but this will be very slow!')
|
||||
thread_data.device = 'cpu'
|
||||
return
|
||||
device_count = torch.cuda.device_count()
|
||||
if device_count <= 1 and device_selection == 'auto':
|
||||
device_selection = 'current' # Use 'auto' only when there is more than one compatible device found.
|
||||
if device_selection == 'auto':
|
||||
print('Autoselecting GPU. Using most free memory.')
|
||||
max_mem_free = 0
|
||||
best_device = None
|
||||
for device in range(device_count):
|
||||
mem_free, mem_total = torch.cuda.mem_get_info(device)
|
||||
mem_free /= float(10**9)
|
||||
mem_total /= float(10**9)
|
||||
device_name = torch.cuda.get_device_name(device)
|
||||
print(f'GPU:{device} detected: {device_name} - Memory: {round(mem_total - mem_free, 2)}Go / {round(mem_total, 2)}Go')
|
||||
if max_mem_free < mem_free:
|
||||
max_mem_free = mem_free
|
||||
best_device = device
|
||||
if best_device and device_select(device):
|
||||
print(f'Setting GPU:{device} as active')
|
||||
torch.cuda.device(device)
|
||||
return
|
||||
if isinstance(device_selection, str):
|
||||
device_selection = device_selection.lower()
|
||||
if device_selection.startswith('gpu:'):
|
||||
device_selection = int(device_selection[4:])
|
||||
if device_selection != 'cuda' and device_selection != 'current' and device_selection != 'gpu':
|
||||
if device_select(device_selection):
|
||||
if isinstance(device_selection, int):
|
||||
print(f'Setting GPU:{device_selection} as active')
|
||||
else:
|
||||
print(f'Setting {device_selection} as active')
|
||||
torch.cuda.device(device_selection)
|
||||
return
|
||||
# By default use current device.
|
||||
print('Checking current GPU...')
|
||||
device = torch.cuda.current_device()
|
||||
device_name = torch.cuda.get_device_name(device)
|
||||
print(f'GPU:{device} detected: {device_name}')
|
||||
if device_select(device):
|
||||
return
|
||||
print('WARNING: No compatible GPU found. Using the CPU, but this will be very slow!')
|
||||
pass
|
||||
thread_data.device = 'cpu'
|
||||
|
||||
def load_model_ckpt(ckpt_to_use, device_to_use='cuda', turbo=False, unet_bs_to_use=1, precision_to_use='autocast'):
|
||||
global ckpt_file, model, modelCS, modelFS, model_is_half, device, unet_bs, precision, model_fs_is_half
|
||||
def is_first_cuda_device(device):
|
||||
if device is None: return False
|
||||
if device == 0 or device == '0': return True
|
||||
if device == 'cuda' or device == 'cuda:0': return True
|
||||
if device == 'gpu' or device == 'gpu:0': return True
|
||||
if device == 'current': return True
|
||||
if device == torch.device(0): return True
|
||||
return False
|
||||
|
||||
device = device_to_use if has_valid_gpu else 'cpu'
|
||||
precision = precision_to_use if not force_full_precision else 'full'
|
||||
unet_bs = unet_bs_to_use
|
||||
def load_model_ckpt():
|
||||
if not thread_data.ckpt_file: raise ValueError(f'Thread ckpt_file is undefined.')
|
||||
if not os.path.exists(thread_data.ckpt_file + '.ckpt'): raise FileNotFoundError(f'Cannot find {thread_data.ckpt_file}.ckpt')
|
||||
|
||||
unload_model()
|
||||
if not thread_data.precision:
|
||||
thread_data.precision = 'full' if thread_data.force_full_precision else 'autocast'
|
||||
|
||||
if device == 'cpu':
|
||||
precision = 'full'
|
||||
if not thread_data.unet_bs:
|
||||
thread_data.unet_bs = 1
|
||||
|
||||
sd = load_model_from_config(f"{ckpt_to_use}.ckpt")
|
||||
if thread_data.device == 'cpu':
|
||||
thread_data.precision = 'full'
|
||||
|
||||
print('loading', thread_data.ckpt_file, 'to', thread_data.device, 'using precision', thread_data.precision)
|
||||
sd = load_model_from_config(thread_data.ckpt_file + '.ckpt')
|
||||
li, lo = [], []
|
||||
for key, value in sd.items():
|
||||
sp = key.split(".")
|
||||
@ -114,88 +203,127 @@ def load_model_ckpt(ckpt_to_use, device_to_use='cuda', turbo=False, unet_bs_to_u
|
||||
model = instantiate_from_config(config.modelUNet)
|
||||
_, _ = model.load_state_dict(sd, strict=False)
|
||||
model.eval()
|
||||
model.cdevice = device
|
||||
model.unet_bs = unet_bs
|
||||
model.turbo = turbo
|
||||
model.cdevice = torch.device(thread_data.device)
|
||||
model.unet_bs = thread_data.unet_bs
|
||||
model.turbo = thread_data.turbo
|
||||
if thread_data.device != 'cpu':
|
||||
model.to(thread_data.device)
|
||||
#if thread_data.reduced_memory:
|
||||
#model.model1.to("cpu")
|
||||
#model.model2.to("cpu")
|
||||
thread_data.model = model
|
||||
|
||||
modelCS = instantiate_from_config(config.modelCondStage)
|
||||
_, _ = modelCS.load_state_dict(sd, strict=False)
|
||||
modelCS.eval()
|
||||
modelCS.cond_stage_model.device = device
|
||||
modelCS.cond_stage_model.device = torch.device(thread_data.device)
|
||||
if thread_data.device != 'cpu':
|
||||
if thread_data.reduced_memory:
|
||||
modelCS.to('cpu')
|
||||
else:
|
||||
modelCS.to(thread_data.device) # Preload on device if not already there.
|
||||
thread_data.modelCS = modelCS
|
||||
|
||||
modelFS = instantiate_from_config(config.modelFirstStage)
|
||||
_, _ = modelFS.load_state_dict(sd, strict=False)
|
||||
modelFS.eval()
|
||||
if thread_data.device != 'cpu':
|
||||
if thread_data.reduced_memory:
|
||||
modelFS.to('cpu')
|
||||
else:
|
||||
modelFS.to(thread_data.device) # Preload on device if not already there.
|
||||
thread_data.modelFS = modelFS
|
||||
del sd
|
||||
|
||||
if device != "cpu" and precision == "autocast":
|
||||
model.half()
|
||||
modelCS.half()
|
||||
modelFS.half()
|
||||
model_is_half = True
|
||||
model_fs_is_half = True
|
||||
if thread_data.device != "cpu" and thread_data.precision == "autocast":
|
||||
thread_data.model.half()
|
||||
thread_data.modelCS.half()
|
||||
thread_data.modelFS.half()
|
||||
thread_data.model_is_half = True
|
||||
thread_data.model_fs_is_half = True
|
||||
else:
|
||||
model_is_half = False
|
||||
model_fs_is_half = False
|
||||
thread_data.model_is_half = False
|
||||
thread_data.model_fs_is_half = False
|
||||
|
||||
ckpt_file = ckpt_to_use
|
||||
print('loaded', thread_data.ckpt_file, 'as', model.device, '->', modelCS.cond_stage_model.device, '->', thread_data.modelFS.device, 'using precision', thread_data.precision)
|
||||
|
||||
print('loaded ', ckpt_file, 'to', device, 'precision', precision)
|
||||
def unload_filters():
|
||||
if thread_data.model_gfpgan is not None:
|
||||
del thread_data.model_gfpgan
|
||||
thread_data.model_gfpgan = None
|
||||
|
||||
def unload_model():
|
||||
global model, modelCS, modelFS
|
||||
if thread_data.model_real_esrgan is not None:
|
||||
del thread_data.model_real_esrgan
|
||||
thread_data.model_real_esrgan = None
|
||||
|
||||
if model is not None:
|
||||
del model
|
||||
del modelCS
|
||||
del modelFS
|
||||
def unload_models():
|
||||
if thread_data.model is not None:
|
||||
print('Unloading models...')
|
||||
del thread_data.model
|
||||
del thread_data.modelCS
|
||||
del thread_data.modelFS
|
||||
|
||||
model = None
|
||||
modelCS = None
|
||||
modelFS = None
|
||||
thread_data.model = None
|
||||
thread_data.modelCS = None
|
||||
thread_data.modelFS = None
|
||||
|
||||
def load_model_gfpgan(gfpgan_to_use):
|
||||
global gfpgan_file, model_gfpgan
|
||||
def wait_model_move_to(model, target_device): # Send to target_device and wait until complete.
|
||||
if thread_data.device == target_device: return
|
||||
start_mem = torch.cuda.memory_allocated(thread_data.device) / 1e6
|
||||
if start_mem <= 0: return
|
||||
model_name = model.__class__.__name__
|
||||
print(f'Device:{thread_data.device} - Sending model {model_name} to {target_device} | Memory transfer starting. Memory Used: {round(start_mem)}Mo')
|
||||
start_time = time.time()
|
||||
model.to(target_device)
|
||||
time_step = start_time
|
||||
WARNING_TIMEOUT = 1.5 # seconds - Show activity in console after timeout.
|
||||
last_mem = start_mem
|
||||
is_transfering = True
|
||||
while is_transfering:
|
||||
time.sleep(0.5) # 500ms
|
||||
mem = torch.cuda.memory_allocated(thread_data.device) / 1e6
|
||||
is_transfering = bool(mem > 0 and mem < last_mem) # still stuff loaded, but less than last time.
|
||||
last_mem = mem
|
||||
if not is_transfering:
|
||||
break;
|
||||
if time.time() - time_step > WARNING_TIMEOUT: # Long delay, print to console to show activity.
|
||||
print(f'Device:{thread_data.device} - Waiting for Memory transfer. Memory Used: {round(mem)}Mo, Transfered: {round(start_mem - mem)}Mo')
|
||||
time_step = time.time()
|
||||
print(f'Device:{thread_data.device} - {model_name} Moved: {round(start_mem - last_mem)}Mo in {round(time.time() - start_time, 3)} seconds to {target_device}')
|
||||
|
||||
if gfpgan_to_use is None:
|
||||
return
|
||||
def load_model_gfpgan():
|
||||
if thread_data.gfpgan_file is None: raise ValueError(f'Thread gfpgan_file is undefined.')
|
||||
#print('load_model_gfpgan called without setting gfpgan_file')
|
||||
#return
|
||||
if not is_first_cuda_device(thread_data.device):
|
||||
#TODO Remove when fixed - A bug with GFPGANer and facexlib needs to be fixed before use on other devices.
|
||||
raise Exception(f'Current device {torch.device(thread_data.device)} is not {torch.device(0)}. Cannot run GFPGANer.')
|
||||
model_path = thread_data.gfpgan_file + ".pth"
|
||||
thread_data.model_gfpgan = GFPGANer(device=torch.device(thread_data.device), model_path=model_path, upscale=1, arch='clean', channel_multiplier=2, bg_upsampler=None)
|
||||
print('loaded', thread_data.gfpgan_file, 'to', thread_data.model_gfpgan.device, 'precision', thread_data.precision)
|
||||
|
||||
gfpgan_file = gfpgan_to_use
|
||||
model_path = gfpgan_to_use + ".pth"
|
||||
|
||||
if device == 'cpu':
|
||||
model_gfpgan = GFPGANer(model_path=model_path, upscale=1, arch='clean', channel_multiplier=2, bg_upsampler=None, device=torch.device('cpu'))
|
||||
else:
|
||||
model_gfpgan = GFPGANer(model_path=model_path, upscale=1, arch='clean', channel_multiplier=2, bg_upsampler=None, device=torch.device('cuda'))
|
||||
|
||||
print('loaded ', gfpgan_to_use, 'to', device, 'precision', precision)
|
||||
|
||||
def load_model_real_esrgan(real_esrgan_to_use):
|
||||
global real_esrgan_file, model_real_esrgan
|
||||
|
||||
if real_esrgan_to_use is None:
|
||||
return
|
||||
|
||||
real_esrgan_file = real_esrgan_to_use
|
||||
model_path = real_esrgan_to_use + ".pth"
|
||||
def load_model_real_esrgan():
|
||||
if thread_data.real_esrgan_file is None: raise ValueError(f'Thread real_esrgan_file is undefined.')
|
||||
#print('load_model_real_esrgan called without setting real_esrgan_file')
|
||||
#return
|
||||
model_path = thread_data.real_esrgan_file + ".pth"
|
||||
|
||||
RealESRGAN_models = {
|
||||
'RealESRGAN_x4plus': RRDBNet(num_in_ch=3, num_out_ch=3, num_feat=64, num_block=23, num_grow_ch=32, scale=4),
|
||||
'RealESRGAN_x4plus_anime_6B': RRDBNet(num_in_ch=3, num_out_ch=3, num_feat=64, num_block=6, num_grow_ch=32, scale=4)
|
||||
}
|
||||
|
||||
model_to_use = RealESRGAN_models[real_esrgan_to_use]
|
||||
model_to_use = RealESRGAN_models[thread_data.real_esrgan_file]
|
||||
|
||||
if device == 'cpu':
|
||||
model_real_esrgan = RealESRGANer(scale=2, model_path=model_path, model=model_to_use, pre_pad=0, half=False) # cpu does not support half
|
||||
model_real_esrgan.device = torch.device('cpu')
|
||||
model_real_esrgan.model.to('cpu')
|
||||
if thread_data.device == 'cpu':
|
||||
thread_data.model_real_esrgan = RealESRGANer(device=torch.device(thread_data.device), scale=2, model_path=model_path, model=model_to_use, pre_pad=0, half=False) # cpu does not support half
|
||||
#thread_data.model_real_esrgan.device = torch.device(thread_data.device)
|
||||
thread_data.model_real_esrgan.model.to('cpu')
|
||||
else:
|
||||
model_real_esrgan = RealESRGANer(scale=2, model_path=model_path, model=model_to_use, pre_pad=0, half=model_is_half)
|
||||
thread_data.model_real_esrgan = RealESRGANer(device=torch.device(thread_data.device), scale=2, model_path=model_path, model=model_to_use, pre_pad=0, half=thread_data.model_is_half)
|
||||
|
||||
model_real_esrgan.model.name = real_esrgan_to_use
|
||||
|
||||
print('loaded ', real_esrgan_to_use, 'to', device, 'precision', precision)
|
||||
thread_data.model_real_esrgan.model.name = thread_data.real_esrgan_file
|
||||
print('loaded ', thread_data.real_esrgan_file, 'to', thread_data.model_real_esrgan.device, 'precision', thread_data.precision)
|
||||
|
||||
def get_base_path(disk_path, session_id, prompt, img_id, ext, suffix=None):
|
||||
if disk_path is None: return None
|
||||
@ -206,22 +334,38 @@ def get_base_path(disk_path, session_id, prompt, img_id, ext, suffix=None):
|
||||
os.makedirs(session_out_path, exist_ok=True)
|
||||
|
||||
prompt_flattened = filename_regex.sub('_', prompt)[:50]
|
||||
|
||||
|
||||
if suffix is not None:
|
||||
return os.path.join(session_out_path, f"{prompt_flattened}_{img_id}_{suffix}.{ext}")
|
||||
return os.path.join(session_out_path, f"{prompt_flattened}_{img_id}.{ext}")
|
||||
|
||||
def apply_filters(filter_name, image_data):
|
||||
def apply_filters(filter_name, image_data, model_path=None):
|
||||
print(f'Applying filter {filter_name}...')
|
||||
gc()
|
||||
gc() # Free space before loading new data.
|
||||
if isinstance(image_data, torch.Tensor):
|
||||
print(image_data)
|
||||
image_data.to(thread_data.device)
|
||||
|
||||
if filter_name == 'gfpgan':
|
||||
_, _, output = model_gfpgan.enhance(image_data[:,:,::-1], has_aligned=False, only_center_face=False, paste_back=True)
|
||||
if model_path is not None and model_path != thread_data.gfpgan_file:
|
||||
thread_data.gfpgan_file = model_path
|
||||
load_model_gfpgan()
|
||||
elif not thread_data.model_gfpgan:
|
||||
load_model_gfpgan()
|
||||
if thread_data.model_gfpgan is None: raise Exception('Model "gfpgan" not loaded.')
|
||||
print('enhance with', thread_data.gfpgan_file, 'on', thread_data.model_gfpgan.device, 'precision', thread_data.precision)
|
||||
_, _, output = thread_data.model_gfpgan.enhance(image_data[:,:,::-1], has_aligned=False, only_center_face=False, paste_back=True)
|
||||
image_data = output[:,:,::-1]
|
||||
|
||||
if filter_name == 'real_esrgan':
|
||||
output, _ = model_real_esrgan.enhance(image_data[:,:,::-1])
|
||||
if model_path is not None and model_path != thread_data.real_esrgan_file:
|
||||
thread_data.real_esrgan_file = model_path
|
||||
load_model_real_esrgan()
|
||||
elif not thread_data.model_real_esrgan:
|
||||
load_model_real_esrgan()
|
||||
if thread_data.model_real_esrgan is None: raise Exception('Model "gfpgan" not loaded.')
|
||||
print('enhance with', thread_data.real_esrgan_file, 'on', thread_data.model_real_esrgan.device, 'precision', thread_data.precision)
|
||||
output, _ = thread_data.model_real_esrgan.enhance(image_data[:,:,::-1])
|
||||
image_data = output[:,:,::-1]
|
||||
|
||||
return image_data
|
||||
@ -232,83 +376,105 @@ def mk_img(req: Request):
|
||||
except Exception as e:
|
||||
print(traceback.format_exc())
|
||||
|
||||
gc()
|
||||
|
||||
if device != "cpu":
|
||||
modelFS.to("cpu")
|
||||
modelCS.to("cpu")
|
||||
|
||||
model.model1.to("cpu")
|
||||
model.model2.to("cpu")
|
||||
|
||||
gc()
|
||||
if thread_data.reduced_memory:
|
||||
thread_data.modelFS.to('cpu')
|
||||
thread_data.modelCS.to('cpu')
|
||||
thread_data.model.model1.to("cpu")
|
||||
thread_data.model.model2.to("cpu")
|
||||
else:
|
||||
# Model crashed, release all resources in unknown state.
|
||||
unload_models()
|
||||
unload_filters()
|
||||
|
||||
gc() # Release from memory.
|
||||
yield json.dumps({
|
||||
"status": 'failed',
|
||||
"detail": str(e)
|
||||
})
|
||||
|
||||
def do_mk_img(req: Request):
|
||||
global ckpt_file
|
||||
global model, modelCS, modelFS, device
|
||||
global model_gfpgan, model_real_esrgan
|
||||
global stop_processing
|
||||
def update_temp_img(req, x_samples):
|
||||
partial_images = []
|
||||
for i in range(req.num_outputs):
|
||||
x_sample_ddim = thread_data.modelFS.decode_first_stage(x_samples[i].unsqueeze(0))
|
||||
x_sample = torch.clamp((x_sample_ddim + 1.0) / 2.0, min=0.0, max=1.0)
|
||||
x_sample = 255.0 * rearrange(x_sample[0].cpu().numpy(), "c h w -> h w c")
|
||||
x_sample = x_sample.astype(np.uint8)
|
||||
img = Image.fromarray(x_sample)
|
||||
buf = BytesIO()
|
||||
img.save(buf, format='JPEG')
|
||||
buf.seek(0)
|
||||
|
||||
stop_processing = False
|
||||
del img, x_sample, x_sample_ddim
|
||||
# don't delete x_samples, it is used in the code that called this callback
|
||||
|
||||
thread_data.temp_images[str(req.session_id) + '/' + str(i)] = buf
|
||||
partial_images.append({'path': f'/image/tmp/{req.session_id}/{i}'})
|
||||
return partial_images
|
||||
|
||||
# Build and return the apropriate generator for do_mk_img
|
||||
def get_image_progress_generator(req, extra_props=None):
|
||||
if not req.stream_progress_updates:
|
||||
def empty_callback(x_samples, i): return x_samples
|
||||
return empty_callback
|
||||
|
||||
thread_data.partial_x_samples = None
|
||||
last_callback_time = -1
|
||||
def img_callback(x_samples, i):
|
||||
nonlocal last_callback_time
|
||||
|
||||
thread_data.partial_x_samples = x_samples
|
||||
step_time = time.time() - last_callback_time if last_callback_time != -1 else -1
|
||||
last_callback_time = time.time()
|
||||
|
||||
progress = {"step": i, "step_time": step_time}
|
||||
if extra_props is not None:
|
||||
progress.update(extra_props)
|
||||
|
||||
if req.stream_image_progress and i % 5 == 0:
|
||||
progress['output'] = update_temp_img(req, x_samples)
|
||||
|
||||
yield json.dumps(progress)
|
||||
|
||||
if thread_data.stop_processing:
|
||||
raise UserInitiatedStop("User requested that we stop processing")
|
||||
return img_callback
|
||||
|
||||
def do_mk_img(req: Request):
|
||||
thread_data.stop_processing = False
|
||||
|
||||
res = Response()
|
||||
res.request = req
|
||||
res.images = []
|
||||
|
||||
temp_images.clear()
|
||||
thread_data.temp_images.clear()
|
||||
|
||||
# custom model support:
|
||||
# the req.use_stable_diffusion_model needs to be a valid path
|
||||
# to the ckpt file (without the extension).
|
||||
if not os.path.exists(req.use_stable_diffusion_model + '.ckpt'): raise FileNotFoundError(f'Cannot find {req.use_stable_diffusion_model}.ckpt')
|
||||
|
||||
needs_model_reload = False
|
||||
ckpt_to_use = ckpt_file
|
||||
if ckpt_to_use != req.use_stable_diffusion_model:
|
||||
ckpt_to_use = req.use_stable_diffusion_model
|
||||
if not thread_data.model or thread_data.ckpt_file != req.use_stable_diffusion_model:
|
||||
thread_data.ckpt_file = req.use_stable_diffusion_model
|
||||
needs_model_reload = True
|
||||
|
||||
model.turbo = req.turbo
|
||||
if req.use_cpu:
|
||||
if device != 'cpu':
|
||||
device = 'cpu'
|
||||
|
||||
if model_is_half:
|
||||
load_model_ckpt(ckpt_to_use, device)
|
||||
needs_model_reload = False
|
||||
|
||||
load_model_gfpgan(gfpgan_file)
|
||||
load_model_real_esrgan(real_esrgan_file)
|
||||
else:
|
||||
if has_valid_gpu:
|
||||
prev_device = device
|
||||
device = 'cuda'
|
||||
|
||||
if (precision == 'autocast' and (req.use_full_precision or not model_is_half)) or \
|
||||
(precision == 'full' and not req.use_full_precision and not force_full_precision):
|
||||
|
||||
load_model_ckpt(ckpt_to_use, device, req.turbo, unet_bs, ('full' if req.use_full_precision else 'autocast'))
|
||||
needs_model_reload = False
|
||||
|
||||
if prev_device != device:
|
||||
load_model_gfpgan(gfpgan_file)
|
||||
load_model_real_esrgan(real_esrgan_file)
|
||||
if thread_data.has_valid_gpu:
|
||||
if (thread_data.precision == 'autocast' and (req.use_full_precision or not thread_data.model_is_half)) or \
|
||||
(thread_data.precision == 'full' and not req.use_full_precision and not thread_data.force_full_precision):
|
||||
thread_data.precision = 'full' if req.use_full_precision else 'autocast'
|
||||
needs_model_reload = True
|
||||
|
||||
if needs_model_reload:
|
||||
load_model_ckpt(ckpt_to_use, device, req.turbo, unet_bs, precision)
|
||||
unload_models()
|
||||
unload_filters()
|
||||
load_model_ckpt()
|
||||
|
||||
if req.use_face_correction != gfpgan_file:
|
||||
load_model_gfpgan(req.use_face_correction)
|
||||
if thread_data.turbo != req.turbo:
|
||||
thread_data.turbo = req.turbo
|
||||
thread_data.model.turbo = req.turbo
|
||||
|
||||
if req.use_upscale != real_esrgan_file:
|
||||
load_model_real_esrgan(req.use_upscale)
|
||||
|
||||
model.cdevice = device
|
||||
modelCS.cond_stage_model.device = device
|
||||
# Start by cleaning memory, loading and unloading things can leave memory allocated.
|
||||
gc()
|
||||
|
||||
opt_prompt = req.prompt
|
||||
opt_seed = req.seed
|
||||
@ -316,11 +482,9 @@ def do_mk_img(req: Request):
|
||||
opt_C = 4
|
||||
opt_f = 8
|
||||
opt_ddim_eta = 0.0
|
||||
opt_init_img = req.init_image
|
||||
|
||||
print(req.to_string(), '\n device', device)
|
||||
|
||||
print('\n\n Using precision:', precision)
|
||||
print(req.to_string(), '\n device', thread_data.device)
|
||||
print('\n\n Using precision:', thread_data.precision)
|
||||
|
||||
seed_everything(opt_seed)
|
||||
|
||||
@ -329,7 +493,7 @@ def do_mk_img(req: Request):
|
||||
assert prompt is not None
|
||||
data = [batch_size * [prompt]]
|
||||
|
||||
if precision == "autocast" and device != "cpu":
|
||||
if thread_data.precision == "autocast" and thread_data.device != "cpu":
|
||||
precision_scope = autocast
|
||||
else:
|
||||
precision_scope = nullcontext
|
||||
@ -345,25 +509,26 @@ def do_mk_img(req: Request):
|
||||
handler = _img2img
|
||||
|
||||
init_image = load_img(req.init_image, req.width, req.height)
|
||||
init_image = init_image.to(device)
|
||||
init_image = init_image.to(thread_data.device)
|
||||
|
||||
if device != "cpu" and precision == "autocast":
|
||||
if thread_data.device != "cpu" and thread_data.precision == "autocast":
|
||||
init_image = init_image.half()
|
||||
|
||||
modelFS.to(device)
|
||||
thread_data.modelFS.to(thread_data.device)
|
||||
|
||||
init_image = repeat(init_image, '1 ... -> b ...', b=batch_size)
|
||||
init_latent = modelFS.get_first_stage_encoding(modelFS.encode_first_stage(init_image)) # move to latent space
|
||||
init_latent = thread_data.modelFS.get_first_stage_encoding(thread_data.modelFS.encode_first_stage(init_image)) # move to latent space
|
||||
|
||||
if req.mask is not None:
|
||||
mask = load_mask(req.mask, req.width, req.height, init_latent.shape[2], init_latent.shape[3], True).to(device)
|
||||
mask = load_mask(req.mask, req.width, req.height, init_latent.shape[2], init_latent.shape[3], True).to(thread_data.device)
|
||||
mask = mask[0][0].unsqueeze(0).repeat(4, 1, 1).unsqueeze(0)
|
||||
mask = repeat(mask, '1 ... -> b ...', b=batch_size)
|
||||
|
||||
if device != "cpu" and precision == "autocast":
|
||||
if thread_data.device != "cpu" and thread_data.precision == "autocast":
|
||||
mask = mask.half()
|
||||
|
||||
move_fs_to_cpu()
|
||||
# Send to CPU and wait until complete.
|
||||
wait_model_move_to(thread_data.modelFS, 'cpu')
|
||||
|
||||
assert 0. <= req.prompt_strength <= 1., 'can only work with strength in [0.0, 1.0]'
|
||||
t_enc = int(req.prompt_strength * req.num_inference_steps)
|
||||
@ -375,16 +540,16 @@ def do_mk_img(req: Request):
|
||||
else:
|
||||
session_out_path = None
|
||||
|
||||
seeds = ""
|
||||
with torch.no_grad():
|
||||
for n in trange(opt_n_iter, desc="Sampling"):
|
||||
for prompts in tqdm(data, desc="data"):
|
||||
|
||||
with precision_scope("cuda"):
|
||||
modelCS.to(device)
|
||||
if thread_data.reduced_memory:
|
||||
thread_data.modelCS.to(thread_data.device)
|
||||
uc = None
|
||||
if req.guidance_scale != 1.0:
|
||||
uc = modelCS.get_learned_conditioning(batch_size * [req.negative_prompt])
|
||||
uc = thread_data.modelCS.get_learned_conditioning(batch_size * [req.negative_prompt])
|
||||
if isinstance(prompts, tuple):
|
||||
prompts = list(prompts)
|
||||
|
||||
@ -397,85 +562,65 @@ def do_mk_img(req: Request):
|
||||
weight = weights[i]
|
||||
# if not skip_normalize:
|
||||
weight = weight / totalWeight
|
||||
c = torch.add(c, modelCS.get_learned_conditioning(subprompts[i]), alpha=weight)
|
||||
c = torch.add(c, thread_data.modelCS.get_learned_conditioning(subprompts[i]), alpha=weight)
|
||||
else:
|
||||
c = modelCS.get_learned_conditioning(prompts)
|
||||
c = thread_data.modelCS.get_learned_conditioning(prompts)
|
||||
|
||||
modelFS.to(device)
|
||||
if thread_data.reduced_memory:
|
||||
thread_data.modelFS.to(thread_data.device)
|
||||
|
||||
partial_x_samples = None
|
||||
last_callback_time = -1
|
||||
def img_callback(x_samples, i):
|
||||
nonlocal partial_x_samples, last_callback_time
|
||||
|
||||
partial_x_samples = x_samples
|
||||
|
||||
if req.stream_progress_updates:
|
||||
n_steps = req.num_inference_steps if req.init_image is None else t_enc
|
||||
step_time = time.time() - last_callback_time if last_callback_time != -1 else -1
|
||||
last_callback_time = time.time()
|
||||
|
||||
progress = {"step": i, "total_steps": n_steps, "step_time": step_time}
|
||||
|
||||
if req.stream_image_progress and i % 5 == 0:
|
||||
partial_images = []
|
||||
|
||||
for i in range(batch_size):
|
||||
x_samples_ddim = modelFS.decode_first_stage(x_samples[i].unsqueeze(0))
|
||||
x_sample = torch.clamp((x_samples_ddim + 1.0) / 2.0, min=0.0, max=1.0)
|
||||
x_sample = 255.0 * rearrange(x_sample[0].cpu().numpy(), "c h w -> h w c")
|
||||
x_sample = x_sample.astype(np.uint8)
|
||||
img = Image.fromarray(x_sample)
|
||||
buf = BytesIO()
|
||||
img.save(buf, format='JPEG')
|
||||
buf.seek(0)
|
||||
|
||||
del img, x_sample, x_samples_ddim
|
||||
# don't delete x_samples, it is used in the code that called this callback
|
||||
|
||||
temp_images[str(req.session_id) + '/' + str(i)] = buf
|
||||
partial_images.append({'path': f'/image/tmp/{req.session_id}/{i}'})
|
||||
|
||||
progress['output'] = partial_images
|
||||
|
||||
yield json.dumps(progress)
|
||||
|
||||
if stop_processing:
|
||||
raise UserInitiatedStop("User requested that we stop processing")
|
||||
n_steps = req.num_inference_steps if req.init_image is None else t_enc
|
||||
img_callback = get_image_progress_generator(req, {"total_steps": n_steps})
|
||||
|
||||
# run the handler
|
||||
try:
|
||||
print('Running handler...')
|
||||
if handler == _txt2img:
|
||||
x_samples = _txt2img(req.width, req.height, req.num_outputs, req.num_inference_steps, req.guidance_scale, None, opt_C, opt_f, opt_ddim_eta, c, uc, opt_seed, img_callback, mask, req.sampler)
|
||||
else:
|
||||
x_samples = _img2img(init_latent, t_enc, batch_size, req.guidance_scale, c, uc, req.num_inference_steps, opt_ddim_eta, opt_seed, img_callback, mask)
|
||||
|
||||
yield from x_samples
|
||||
|
||||
x_samples = partial_x_samples
|
||||
if req.stream_progress_updates:
|
||||
yield from x_samples
|
||||
if hasattr(thread_data, 'partial_x_samples'):
|
||||
if thread_data.partial_x_samples is not None:
|
||||
x_samples = thread_data.partial_x_samples
|
||||
del thread_data.partial_x_samples
|
||||
except UserInitiatedStop:
|
||||
if partial_x_samples is None:
|
||||
if not hasattr(thread_data, 'partial_x_samples'):
|
||||
continue
|
||||
if thread_data.partial_x_samples is None:
|
||||
del thread_data.partial_x_samples
|
||||
continue
|
||||
x_samples = thread_data.partial_x_samples
|
||||
del thread_data.partial_x_samples
|
||||
|
||||
x_samples = partial_x_samples
|
||||
|
||||
print("saving images")
|
||||
print("decoding images")
|
||||
img_data = [None] * batch_size
|
||||
for i in range(batch_size):
|
||||
img_id = base64.b64encode(int(time.time()+i).to_bytes(8, 'big')).decode() # Generate unique ID based on time.
|
||||
img_id = img_id.translate({43:None, 47:None, 61:None})[-8:] # Remove + / = and keep last 8 chars.
|
||||
|
||||
x_samples_ddim = modelFS.decode_first_stage(x_samples[i].unsqueeze(0))
|
||||
x_samples_ddim = thread_data.modelFS.decode_first_stage(x_samples[i].unsqueeze(0))
|
||||
x_sample = torch.clamp((x_samples_ddim + 1.0) / 2.0, min=0.0, max=1.0)
|
||||
x_sample = 255.0 * rearrange(x_sample[0].cpu().numpy(), "c h w -> h w c")
|
||||
x_sample = x_sample.astype(np.uint8)
|
||||
img = Image.fromarray(x_sample)
|
||||
img_data[i] = x_sample
|
||||
del x_samples, x_samples_ddim, x_sample
|
||||
|
||||
if thread_data.reduced_memory:
|
||||
# Send to CPU and wait until complete.
|
||||
wait_model_move_to(thread_data.modelFS, 'cpu')
|
||||
|
||||
print("saving images")
|
||||
for i in range(batch_size):
|
||||
img = Image.fromarray(img_data[i])
|
||||
img_id = base64.b64encode(int(time.time()+i).to_bytes(8, 'big')).decode() # Generate unique ID based on time.
|
||||
img_id = img_id.translate({43:None, 47:None, 61:None})[-8:] # Remove + / = and keep last 8 chars.
|
||||
|
||||
has_filters = (req.use_face_correction is not None and req.use_face_correction.startswith('GFPGAN')) or \
|
||||
(req.use_upscale is not None and req.use_upscale.startswith('RealESRGAN'))
|
||||
|
||||
return_orig_img = not has_filters or not req.show_only_filtered_image
|
||||
|
||||
if stop_processing:
|
||||
if thread_data.stop_processing:
|
||||
return_orig_img = True
|
||||
|
||||
if req.save_to_disk_path is not None:
|
||||
@ -486,25 +631,24 @@ def do_mk_img(req: Request):
|
||||
save_metadata(meta_out_path, req, prompts[0], opt_seed)
|
||||
|
||||
if return_orig_img:
|
||||
img_data = img_to_base64_str(img, req.output_format)
|
||||
res_image_orig = ResponseImage(data=img_data, seed=opt_seed)
|
||||
img_str = img_to_base64_str(img, req.output_format)
|
||||
res_image_orig = ResponseImage(data=img_str, seed=opt_seed)
|
||||
res.images.append(res_image_orig)
|
||||
|
||||
if req.save_to_disk_path is not None:
|
||||
res_image_orig.path_abs = img_out_path
|
||||
|
||||
del img
|
||||
|
||||
if has_filters and not stop_processing:
|
||||
if has_filters and not thread_data.stop_processing:
|
||||
filters_applied = []
|
||||
if req.use_face_correction:
|
||||
x_sample = apply_filters('gfpgan', x_sample)
|
||||
img_data[i] = apply_filters('gfpgan', img_data[i], req.use_face_correction)
|
||||
filters_applied.append(req.use_face_correction)
|
||||
if req.use_upscale:
|
||||
x_sample = apply_filters('real_esrgan', x_sample)
|
||||
img_data[i] = apply_filters('real_esrgan', img_data[i], req.use_upscale)
|
||||
filters_applied.append(req.use_upscale)
|
||||
if (len(filters_applied) > 0):
|
||||
filtered_image = Image.fromarray(x_sample)
|
||||
filtered_image = Image.fromarray(img_data[i])
|
||||
filtered_img_data = img_to_base64_str(filtered_image, req.output_format)
|
||||
response_image = ResponseImage(data=filtered_img_data, seed=opt_seed)
|
||||
res.images.append(response_image)
|
||||
@ -513,17 +657,17 @@ def do_mk_img(req: Request):
|
||||
save_image(filtered_image, filtered_img_out_path)
|
||||
response_image.path_abs = filtered_img_out_path
|
||||
del filtered_image
|
||||
|
||||
seeds += str(opt_seed) + ","
|
||||
# Filter Applied, move to next seed
|
||||
opt_seed += 1
|
||||
|
||||
move_fs_to_cpu()
|
||||
if thread_data.reduced_memory:
|
||||
unload_filters()
|
||||
del img_data
|
||||
gc()
|
||||
del x_samples, x_samples_ddim, x_sample
|
||||
print("memory_final = ", torch.cuda.memory_allocated() / 1e6)
|
||||
if thread_data.device != 'cpu':
|
||||
print(f'memory_final = {round(torch.cuda.memory_allocated(thread_data.device) / 1e6, 2)}Mo')
|
||||
|
||||
print('Task completed')
|
||||
|
||||
yield json.dumps(res.json())
|
||||
|
||||
def save_image(img, img_out_path):
|
||||
@ -533,7 +677,7 @@ def save_image(img, img_out_path):
|
||||
print('could not save the file', traceback.format_exc())
|
||||
|
||||
def save_metadata(meta_out_path, req, prompt, opt_seed):
|
||||
metadata = f"""{prompt}
|
||||
metadata = f'''{prompt}
|
||||
Width: {req.width}
|
||||
Height: {req.height}
|
||||
Seed: {opt_seed}
|
||||
@ -544,8 +688,8 @@ Use Face Correction: {req.use_face_correction}
|
||||
Use Upscaling: {req.use_upscale}
|
||||
Sampler: {req.sampler}
|
||||
Negative Prompt: {req.negative_prompt}
|
||||
Stable Diffusion Model: {req.use_stable_diffusion_model + '.ckpt'}
|
||||
"""
|
||||
Stable Diffusion model: {req.use_stable_diffusion_model + '.ckpt'}
|
||||
'''
|
||||
try:
|
||||
with open(meta_out_path, 'w', encoding='utf-8') as f:
|
||||
f.write(metadata)
|
||||
@ -555,16 +699,13 @@ Stable Diffusion Model: {req.use_stable_diffusion_model + '.ckpt'}
|
||||
def _txt2img(opt_W, opt_H, opt_n_samples, opt_ddim_steps, opt_scale, start_code, opt_C, opt_f, opt_ddim_eta, c, uc, opt_seed, img_callback, mask, sampler_name):
|
||||
shape = [opt_n_samples, opt_C, opt_H // opt_f, opt_W // opt_f]
|
||||
|
||||
if device != "cpu":
|
||||
mem = torch.cuda.memory_allocated() / 1e6
|
||||
modelCS.to("cpu")
|
||||
while torch.cuda.memory_allocated() / 1e6 >= mem:
|
||||
time.sleep(1)
|
||||
# Send to CPU and wait until complete.
|
||||
wait_model_move_to(thread_data.modelCS, 'cpu')
|
||||
|
||||
if sampler_name == 'ddim':
|
||||
model.make_schedule(ddim_num_steps=opt_ddim_steps, ddim_eta=opt_ddim_eta, verbose=False)
|
||||
thread_data.model.make_schedule(ddim_num_steps=opt_ddim_steps, ddim_eta=opt_ddim_eta, verbose=False)
|
||||
|
||||
samples_ddim = model.sample(
|
||||
samples_ddim = thread_data.model.sample(
|
||||
S=opt_ddim_steps,
|
||||
conditioning=c,
|
||||
seed=opt_seed,
|
||||
@ -578,14 +719,13 @@ def _txt2img(opt_W, opt_H, opt_n_samples, opt_ddim_steps, opt_scale, start_code,
|
||||
mask=mask,
|
||||
sampler = sampler_name,
|
||||
)
|
||||
|
||||
yield from samples_ddim
|
||||
|
||||
def _img2img(init_latent, t_enc, batch_size, opt_scale, c, uc, opt_ddim_steps, opt_ddim_eta, opt_seed, img_callback, mask):
|
||||
# encode (scaled latent)
|
||||
z_enc = model.stochastic_encode(
|
||||
z_enc = thread_data.model.stochastic_encode(
|
||||
init_latent,
|
||||
torch.tensor([t_enc] * batch_size).to(device),
|
||||
torch.tensor([t_enc] * batch_size).to(thread_data.device),
|
||||
opt_seed,
|
||||
opt_ddim_eta,
|
||||
opt_ddim_steps,
|
||||
@ -593,7 +733,7 @@ def _img2img(init_latent, t_enc, batch_size, opt_scale, c, uc, opt_ddim_steps, o
|
||||
x_T = None if mask is None else init_latent
|
||||
|
||||
# decode it
|
||||
samples_ddim = model.sample(
|
||||
samples_ddim = thread_data.model.sample(
|
||||
t_enc,
|
||||
c,
|
||||
z_enc,
|
||||
@ -604,20 +744,12 @@ def _img2img(init_latent, t_enc, batch_size, opt_scale, c, uc, opt_ddim_steps, o
|
||||
x_T=x_T,
|
||||
sampler = 'ddim'
|
||||
)
|
||||
|
||||
yield from samples_ddim
|
||||
|
||||
def move_fs_to_cpu():
|
||||
if device != "cpu":
|
||||
mem = torch.cuda.memory_allocated() / 1e6
|
||||
modelFS.to("cpu")
|
||||
while torch.cuda.memory_allocated() / 1e6 >= mem:
|
||||
time.sleep(1)
|
||||
|
||||
def gc():
|
||||
if device == 'cpu':
|
||||
gc_collect()
|
||||
if thread_data.device == 'cpu':
|
||||
return
|
||||
|
||||
torch.cuda.empty_cache()
|
||||
torch.cuda.ipc_collect()
|
||||
|
||||
@ -627,7 +759,6 @@ def chunk(it, size):
|
||||
it = iter(it)
|
||||
return iter(lambda: tuple(islice(it, size)), ())
|
||||
|
||||
|
||||
def load_model_from_config(ckpt, verbose=False):
|
||||
print(f"Loading model from {ckpt}")
|
||||
pl_sd = torch.load(ckpt, map_location="cpu")
|
||||
|
@ -1,14 +1,25 @@
|
||||
"""task_manager.py: manage tasks dispatching and render threads.
|
||||
Notes:
|
||||
render_threads should be the only hard reference held by the manager to the threads.
|
||||
Use weak_thread_data to store all other data using weak keys.
|
||||
This will allow for garbage collection after the thread dies.
|
||||
"""
|
||||
import json
|
||||
import traceback
|
||||
|
||||
TASK_TTL = 15 * 60 # seconds, Discard last session's task timeout
|
||||
|
||||
import queue, threading, time
|
||||
import queue, threading, time, weakref
|
||||
from typing import Any, Generator, Hashable, Optional, Union
|
||||
|
||||
from pydantic import BaseModel
|
||||
from sd_internal import Request, Response
|
||||
|
||||
THREAD_NAME_PREFIX = 'Runtime-Render/'
|
||||
ERR_LOCK_FAILED = ' failed to acquire lock within timeout.'
|
||||
LOCK_TIMEOUT = 15 # Maximum locking time in seconds before failing a task.
|
||||
# It's better to get an exception than a deadlock... ALWAYS use timeout in critical paths.
|
||||
|
||||
class SymbolClass(type): # Print nicely formatted Symbol names.
|
||||
def __repr__(self): return self.__qualname__
|
||||
def __str__(self): return self.__name__
|
||||
@ -25,7 +36,7 @@ class RenderTask(): # Task with output queue and completion lock.
|
||||
def __init__(self, req: Request):
|
||||
self.request: Request = req # Initial Request
|
||||
self.response: Any = None # Copy of the last reponse
|
||||
self.temp_images:[] = [None] * req.num_outputs * (1 if req.show_only_filtered_image else 2)
|
||||
self.temp_images:list = [None] * req.num_outputs * (1 if req.show_only_filtered_image else 2)
|
||||
self.error: Exception = None
|
||||
self.lock: threading.Lock = threading.Lock() # Locks at task start and unlocks when task is completed
|
||||
self.buffer_queue: queue.Queue = queue.Queue() # Queue of JSON string segments
|
||||
@ -66,17 +77,30 @@ class ImageRequest(BaseModel):
|
||||
stream_progress_updates: bool = False
|
||||
stream_image_progress: bool = False
|
||||
|
||||
class FilterRequest(BaseModel):
|
||||
session_id: str = "session"
|
||||
model: str = None
|
||||
name: str = ""
|
||||
init_image: str = None # base64
|
||||
width: int = 512
|
||||
height: int = 512
|
||||
save_to_disk_path: str = None
|
||||
turbo: bool = True
|
||||
use_cpu: bool = False
|
||||
use_full_precision: bool = False
|
||||
output_format: str = "jpeg" # or "png"
|
||||
|
||||
# Temporary cache to allow to query tasks results for a short time after they are completed.
|
||||
class TaskCache():
|
||||
def __init__(self):
|
||||
self._base = dict()
|
||||
self._lock: threading.Lock = threading.RLock()
|
||||
self._lock: threading.Lock = threading.Lock()
|
||||
def _get_ttl_time(self, ttl: int) -> int:
|
||||
return int(time.time()) + ttl
|
||||
def _is_expired(self, timestamp: int) -> bool:
|
||||
return int(time.time()) >= timestamp
|
||||
def clean(self) -> None:
|
||||
if not self._lock.acquire(blocking=True, timeout=10): raise Exception('TaskCache.clean failed to acquire lock within timeout.')
|
||||
if not self._lock.acquire(blocking=True, timeout=LOCK_TIMEOUT): raise Exception('TaskCache.clean' + ERR_LOCK_FAILED)
|
||||
try:
|
||||
# Create a list of expired keys to delete
|
||||
to_delete = []
|
||||
@ -91,11 +115,11 @@ class TaskCache():
|
||||
finally:
|
||||
self._lock.release()
|
||||
def clear(self) -> None:
|
||||
if not self._lock.acquire(blocking=True, timeout=10): raise Exception('TaskCache.clear failed to acquire lock within timeout.')
|
||||
if not self._lock.acquire(blocking=True, timeout=LOCK_TIMEOUT): raise Exception('TaskCache.clear' + ERR_LOCK_FAILED)
|
||||
try: self._base.clear()
|
||||
finally: self._lock.release()
|
||||
def delete(self, key: Hashable) -> bool:
|
||||
if not self._lock.acquire(blocking=True, timeout=10): raise Exception('TaskCache.delete failed to acquire lock within timeout.')
|
||||
if not self._lock.acquire(blocking=True, timeout=LOCK_TIMEOUT): raise Exception('TaskCache.delete' + ERR_LOCK_FAILED)
|
||||
try:
|
||||
if key not in self._base:
|
||||
return False
|
||||
@ -104,7 +128,7 @@ class TaskCache():
|
||||
finally:
|
||||
self._lock.release()
|
||||
def keep(self, key: Hashable, ttl: int) -> bool:
|
||||
if not self._lock.acquire(blocking=True, timeout=10): raise Exception('TaskCache.keep failed to acquire lock within timeout.')
|
||||
if not self._lock.acquire(blocking=True, timeout=LOCK_TIMEOUT): raise Exception('TaskCache.keep' + ERR_LOCK_FAILED)
|
||||
try:
|
||||
if key in self._base:
|
||||
_, value = self._base.get(key)
|
||||
@ -114,7 +138,7 @@ class TaskCache():
|
||||
finally:
|
||||
self._lock.release()
|
||||
def put(self, key: Hashable, value: Any, ttl: int) -> bool:
|
||||
if not self._lock.acquire(blocking=True, timeout=10): raise Exception('TaskCache.put failed to acquire lock within timeout.')
|
||||
if not self._lock.acquire(blocking=True, timeout=LOCK_TIMEOUT): raise Exception('TaskCache.put' + ERR_LOCK_FAILED)
|
||||
try:
|
||||
self._base[key] = (
|
||||
self._get_ttl_time(ttl), value
|
||||
@ -128,23 +152,26 @@ class TaskCache():
|
||||
finally:
|
||||
self._lock.release()
|
||||
def tryGet(self, key: Hashable) -> Any:
|
||||
if not self._lock.acquire(blocking=True, timeout=10): raise Exception('TaskCache.tryGet failed to acquire lock within timeout.')
|
||||
if not self._lock.acquire(blocking=True, timeout=LOCK_TIMEOUT): raise Exception('TaskCache.tryGet' + ERR_LOCK_FAILED)
|
||||
try:
|
||||
ttl, value = self._base.get(key, (None, None))
|
||||
if ttl is not None and self._is_expired(ttl):
|
||||
print(f'Session {key} expired. Discarding data.')
|
||||
self.delete(key)
|
||||
del self._base[key]
|
||||
return None
|
||||
return value
|
||||
finally:
|
||||
self._lock.release()
|
||||
|
||||
manager_lock = threading.RLock()
|
||||
render_threads = []
|
||||
current_state = ServerStates.Init
|
||||
current_state_error:Exception = None
|
||||
current_model_path = None
|
||||
tasks_queue = queue.Queue()
|
||||
tasks_queue = []
|
||||
task_cache = TaskCache()
|
||||
default_model_to_load = None
|
||||
weak_thread_data = weakref.WeakKeyDictionary()
|
||||
|
||||
def preload_model(file_path=None):
|
||||
global current_state, current_state_error, current_model_path
|
||||
@ -155,7 +182,8 @@ def preload_model(file_path=None):
|
||||
current_state = ServerStates.LoadingModel
|
||||
try:
|
||||
from . import runtime
|
||||
runtime.load_model_ckpt(ckpt_to_use=file_path)
|
||||
runtime.thread_data.ckpt_file = file_path
|
||||
runtime.load_model_ckpt()
|
||||
current_model_path = file_path
|
||||
current_state_error = None
|
||||
current_state = ServerStates.Online
|
||||
@ -165,72 +193,129 @@ def preload_model(file_path=None):
|
||||
current_state = ServerStates.Unavailable
|
||||
print(traceback.format_exc())
|
||||
|
||||
def thread_render():
|
||||
def thread_get_next_task():
|
||||
if not manager_lock.acquire(blocking=True, timeout=LOCK_TIMEOUT):
|
||||
print('Render thread on device', runtime.thread_data.device, 'failed to acquire manager lock.')
|
||||
return None
|
||||
if len(tasks_queue) <= 0:
|
||||
manager_lock.release()
|
||||
return None
|
||||
from . import runtime
|
||||
task = None
|
||||
try: # Select a render task.
|
||||
for queued_task in tasks_queue:
|
||||
if queued_task.request.use_face_correction: # TODO Remove when fixed - A bug with GFPGANer and facexlib needs to be fixed before use on other devices.
|
||||
if is_alive(0) <= 0: # Allows GFPGANer only on cuda:0.
|
||||
queued_task.error = Exception('cuda:0 is not available with the current config. Remove GFPGANer filter to run task.')
|
||||
task = queued_task
|
||||
break
|
||||
if queued_task.request.use_cpu:
|
||||
queued_task.error = Exception('Cpu cannot be used to run this task. Remove GFPGANer filter to run task.')
|
||||
task = queued_task
|
||||
break
|
||||
if not runtime.is_first_cuda_device(runtime.thread_data.device):
|
||||
continue # Wait for cuda:0
|
||||
if queued_task.request.use_cpu and runtime.thread_data.device != 'cpu':
|
||||
if is_alive('cpu') > 0:
|
||||
continue # CPU Tasks, Skip GPU device
|
||||
else:
|
||||
queued_task.error = Exception('Cpu is not enabled in render_devices.')
|
||||
task = queued_task
|
||||
break
|
||||
if not queued_task.request.use_cpu and runtime.thread_data.device == 'cpu':
|
||||
if is_alive() > 1: # cpu is alive, so need more than one.
|
||||
continue # GPU Tasks, don't run on CPU unless there is nothing else.
|
||||
else:
|
||||
queued_task.error = Exception('No active gpu found. Please check the error message in the command-line window at startup.')
|
||||
task = queued_task
|
||||
break
|
||||
task = queued_task
|
||||
break
|
||||
if task is not None:
|
||||
del tasks_queue[tasks_queue.index(task)]
|
||||
return task
|
||||
finally:
|
||||
manager_lock.release()
|
||||
|
||||
def thread_render(device):
|
||||
global current_state, current_state_error, current_model_path
|
||||
from . import runtime
|
||||
current_state = ServerStates.Online
|
||||
weak_thread_data[threading.current_thread()] = {
|
||||
'device': device
|
||||
}
|
||||
try:
|
||||
runtime.device_init(device)
|
||||
except:
|
||||
print(traceback.format_exc())
|
||||
return
|
||||
weak_thread_data[threading.current_thread()] = {
|
||||
'device': runtime.thread_data.device
|
||||
}
|
||||
preload_model()
|
||||
current_state = ServerStates.Online
|
||||
while True:
|
||||
task_cache.clean()
|
||||
if isinstance(current_state_error, SystemExit):
|
||||
current_state = ServerStates.Unavailable
|
||||
return
|
||||
task = None
|
||||
try:
|
||||
task = tasks_queue.get(timeout=1)
|
||||
except queue.Empty as e:
|
||||
if isinstance(current_state_error, SystemExit):
|
||||
current_state = ServerStates.Unavailable
|
||||
return
|
||||
else: continue
|
||||
#if current_model_path != task.request.use_stable_diffusion_model:
|
||||
# preload_model(task.request.use_stable_diffusion_model)
|
||||
task = thread_get_next_task()
|
||||
if task is None:
|
||||
time.sleep(1)
|
||||
continue
|
||||
if task.error is not None:
|
||||
print(task.error)
|
||||
task.response = {"status": 'failed', "detail": str(task.error)}
|
||||
task.buffer_queue.put(json.dumps(task.response))
|
||||
continue
|
||||
if current_state_error:
|
||||
task.error = current_state_error
|
||||
task.response = {"status": 'failed', "detail": str(task.error)}
|
||||
task.buffer_queue.put(json.dumps(task.response))
|
||||
continue
|
||||
print(f'Session {task.request.session_id} starting task {id(task)}')
|
||||
if not task.lock.acquire(blocking=False): raise Exception('Got locked task from queue.')
|
||||
try:
|
||||
task.lock.acquire(blocking=False)
|
||||
# Open data generator.
|
||||
res = runtime.mk_img(task.request)
|
||||
if current_model_path == task.request.use_stable_diffusion_model:
|
||||
current_state = ServerStates.Rendering
|
||||
else:
|
||||
current_state = ServerStates.LoadingModel
|
||||
# Start reading from generator.
|
||||
dataQueue = None
|
||||
if task.request.stream_progress_updates:
|
||||
dataQueue = task.buffer_queue
|
||||
for result in res:
|
||||
if current_state == ServerStates.LoadingModel:
|
||||
current_state = ServerStates.Rendering
|
||||
current_model_path = task.request.use_stable_diffusion_model
|
||||
if isinstance(current_state_error, SystemExit) or isinstance(current_state_error, StopAsyncIteration) or isinstance(task.error, StopAsyncIteration):
|
||||
runtime.thread_data.stop_processing = True
|
||||
if isinstance(current_state_error, StopAsyncIteration):
|
||||
task.error = current_state_error
|
||||
current_state_error = None
|
||||
print(f'Session {task.request.session_id} sent cancel signal for task {id(task)}')
|
||||
if dataQueue:
|
||||
dataQueue.put(result)
|
||||
if isinstance(result, str):
|
||||
result = json.loads(result)
|
||||
task.response = result
|
||||
if 'output' in result:
|
||||
for out_obj in result['output']:
|
||||
if 'path' in out_obj:
|
||||
img_id = out_obj['path'][out_obj['path'].rindex('/') + 1:]
|
||||
task.temp_images[int(img_id)] = runtime.thread_data.temp_images[out_obj['path'][11:]]
|
||||
elif 'data' in out_obj:
|
||||
task.temp_images[result['output'].index(out_obj)] = out_obj['data']
|
||||
# Before looping back to the generator, mark cache as still alive.
|
||||
task_cache.keep(task.request.session_id, TASK_TTL)
|
||||
except Exception as e:
|
||||
task.error = e
|
||||
task.lock.release()
|
||||
tasks_queue.task_done()
|
||||
print(traceback.format_exc())
|
||||
continue
|
||||
dataQueue = None
|
||||
if task.request.stream_progress_updates:
|
||||
dataQueue = task.buffer_queue
|
||||
for result in res:
|
||||
if current_state == ServerStates.LoadingModel:
|
||||
current_state = ServerStates.Rendering
|
||||
current_model_path = task.request.use_stable_diffusion_model
|
||||
if isinstance(current_state_error, SystemExit) or isinstance(current_state_error, StopAsyncIteration) or isinstance(task.error, StopAsyncIteration):
|
||||
runtime.stop_processing = True
|
||||
if isinstance(current_state_error, StopAsyncIteration):
|
||||
task.error = current_state_error
|
||||
current_state_error = None
|
||||
print(f'Session {task.request.session_id} sent cancel signal for task {id(task)}')
|
||||
if dataQueue:
|
||||
dataQueue.put(result)
|
||||
if isinstance(result, str):
|
||||
result = json.loads(result)
|
||||
task.response = result
|
||||
if 'output' in result:
|
||||
for out_obj in result['output']:
|
||||
if 'path' in out_obj:
|
||||
img_id = out_obj['path'][out_obj['path'].rindex('/') + 1:]
|
||||
task.temp_images[int(img_id)] = runtime.temp_images[out_obj['path'][11:]]
|
||||
elif 'data' in out_obj:
|
||||
task.temp_images[result['output'].index(out_obj)] = out_obj['data']
|
||||
task_cache.keep(task.request.session_id, TASK_TTL)
|
||||
# Task completed
|
||||
task.lock.release()
|
||||
tasks_queue.task_done()
|
||||
finally:
|
||||
# Task completed
|
||||
task.lock.release()
|
||||
task_cache.keep(task.request.session_id, TASK_TTL)
|
||||
if isinstance(task.error, StopAsyncIteration):
|
||||
print(f'Session {task.request.session_id} task {id(task)} cancelled!')
|
||||
@ -240,19 +325,62 @@ def thread_render():
|
||||
print(f'Session {task.request.session_id} task {id(task)} completed.')
|
||||
current_state = ServerStates.Online
|
||||
|
||||
render_thread = threading.Thread(target=thread_render)
|
||||
def get_cached_task(session_id:str, update_ttl:bool=False):
|
||||
# By calling keep before tryGet, wont discard if was expired.
|
||||
if update_ttl and not task_cache.keep(session_id, TASK_TTL):
|
||||
# Failed to keep task, already gone.
|
||||
return None
|
||||
return task_cache.tryGet(session_id)
|
||||
|
||||
def start_render_thread():
|
||||
# Start Rendering Thread
|
||||
render_thread.daemon = True
|
||||
render_thread.start()
|
||||
def is_first_cuda_device(device):
|
||||
from . import runtime # When calling runtime from outside thread_render DO NOT USE thread specific attributes or functions.
|
||||
return runtime.is_first_cuda_device(device)
|
||||
|
||||
def is_alive(name=None):
|
||||
if not manager_lock.acquire(blocking=True, timeout=LOCK_TIMEOUT): raise Exception('is_alive' + ERR_LOCK_FAILED)
|
||||
nbr_alive = 0
|
||||
try:
|
||||
for rthread in render_threads:
|
||||
if name is not None:
|
||||
weak_data = weak_thread_data.get(rthread)
|
||||
if weak_data is None or weak_data['device'] is None:
|
||||
print('The thread', rthread.name, 'is registered but has no data store in the task manager.')
|
||||
continue
|
||||
thread_name = str(weak_data['device']).lower()
|
||||
if is_first_cuda_device(name):
|
||||
if not is_first_cuda_device(thread_name):
|
||||
continue
|
||||
elif thread_name != name:
|
||||
continue
|
||||
if rthread.is_alive():
|
||||
nbr_alive += 1
|
||||
return nbr_alive
|
||||
finally:
|
||||
manager_lock.release()
|
||||
|
||||
def start_render_thread(device='auto'):
|
||||
if not manager_lock.acquire(blocking=True, timeout=LOCK_TIMEOUT): raise Exception('start_render_threads' + ERR_LOCK_FAILED)
|
||||
print('Start new Rendering Thread on device', device)
|
||||
try:
|
||||
rthread = threading.Thread(target=thread_render, kwargs={'device': device})
|
||||
rthread.daemon = True
|
||||
rthread.name = THREAD_NAME_PREFIX + device
|
||||
rthread.start()
|
||||
timeout = LOCK_TIMEOUT
|
||||
while not rthread.is_alive():
|
||||
if timeout <= 0: raise Exception('render_thread', rthread.name, 'failed to start before timeout or has crashed.')
|
||||
timeout -= 1
|
||||
time.sleep(1)
|
||||
render_threads.append(rthread)
|
||||
finally:
|
||||
manager_lock.release()
|
||||
|
||||
def shutdown_event(): # Signal render thread to close on shutdown
|
||||
global current_state_error
|
||||
current_state_error = SystemExit('Application shutting down.')
|
||||
|
||||
def render(req : ImageRequest):
|
||||
if not render_thread.is_alive(): # Render thread is dead
|
||||
if is_alive() <= 0: # Render thread is dead
|
||||
raise ChildProcessError('Rendering thread has died.')
|
||||
# Alive, check if task in cache
|
||||
task = task_cache.tryGet(req.session_id)
|
||||
@ -294,6 +422,12 @@ def render(req : ImageRequest):
|
||||
|
||||
new_task = RenderTask(r)
|
||||
if task_cache.put(r.session_id, new_task, TASK_TTL):
|
||||
tasks_queue.put(new_task, block=True, timeout=30)
|
||||
return new_task
|
||||
# Use twice the normal timeout for adding user requests.
|
||||
# Tries to force task_cache.put to fail before tasks_queue.put would.
|
||||
if manager_lock.acquire(blocking=True, timeout=LOCK_TIMEOUT * 2):
|
||||
try:
|
||||
tasks_queue.append(new_task)
|
||||
return new_task
|
||||
finally:
|
||||
manager_lock.release()
|
||||
raise RuntimeError('Failed to add task to cache.')
|
||||
|
462
ui/server.py
462
ui/server.py
@ -1,3 +1,7 @@
|
||||
"""server.py: FastAPI SD-UI Web Host.
|
||||
Notes:
|
||||
async endpoints always run on the main thread. Without they run on the thread pool.
|
||||
"""
|
||||
import json
|
||||
import traceback
|
||||
|
||||
@ -16,17 +20,29 @@ UI_PLUGINS_DIR = os.path.abspath(os.path.join(SD_DIR, '..', 'plugins', 'ui'))
|
||||
|
||||
OUTPUT_DIRNAME = "Stable Diffusion UI" # in the user's home folder
|
||||
TASK_TTL = 15 * 60 # Discard last session's task timeout
|
||||
APP_CONFIG_DEFAULTS = {
|
||||
# auto: selects the cuda device with the most free memory, cuda: use the currently active cuda device.
|
||||
'render_devices': ['auto'], # ['cuda'] or ['CPU', 'GPU:0', 'GPU:1', ...] or ['cpu']
|
||||
'update_branch': 'main',
|
||||
}
|
||||
APP_CONFIG_DEFAULT_MODELS = [
|
||||
# needed to support the legacy installations
|
||||
'custom-model', # Check if user has a custom model, use it first.
|
||||
'sd-v1-4', # Default fallback.
|
||||
]
|
||||
|
||||
import asyncio
|
||||
from fastapi import FastAPI, HTTPException
|
||||
from fastapi.staticfiles import StaticFiles
|
||||
from starlette.responses import FileResponse, JSONResponse, StreamingResponse
|
||||
from pydantic import BaseModel
|
||||
import logging
|
||||
import queue, threading, time
|
||||
from typing import Any, Generator, Hashable, Optional, Union
|
||||
#import queue, threading, time
|
||||
from typing import Any, Generator, Hashable, List, Optional, Union
|
||||
|
||||
from sd_internal import Request, Response, task_manager
|
||||
|
||||
LOOP = asyncio.get_event_loop()
|
||||
app = FastAPI()
|
||||
|
||||
modifiers_cache = None
|
||||
@ -42,191 +58,137 @@ NOCACHE_HEADERS={"Cache-Control": "no-cache, no-store, must-revalidate", "Pragma
|
||||
app.mount('/media', StaticFiles(directory=os.path.join(SD_UI_DIR, 'media')), name="media")
|
||||
app.mount('/plugins', StaticFiles(directory=UI_PLUGINS_DIR), name="plugins")
|
||||
|
||||
class SetAppConfigRequest(BaseModel):
|
||||
update_branch: str = "main"
|
||||
|
||||
# needs to support the legacy installations
|
||||
def get_initial_model_to_load():
|
||||
custom_weight_path = os.path.join(SD_DIR, 'custom-model.ckpt')
|
||||
ckpt_to_use = "sd-v1-4" if not os.path.exists(custom_weight_path) else "custom-model"
|
||||
|
||||
ckpt_to_use = os.path.join(SD_DIR, ckpt_to_use)
|
||||
|
||||
config = getConfig()
|
||||
if 'model' in config and 'stable-diffusion' in config['model']:
|
||||
model_name = config['model']['stable-diffusion']
|
||||
model_path = resolve_model_to_use(model_name)
|
||||
|
||||
if os.path.exists(model_path + '.ckpt'):
|
||||
ckpt_to_use = model_path
|
||||
else:
|
||||
print('Could not find the configured custom model at:', model_path + '.ckpt', '. Using the default one:', ckpt_to_use + '.ckpt')
|
||||
return ckpt_to_use
|
||||
|
||||
def resolve_model_to_use(model_name):
|
||||
if model_name in ('sd-v1-4', 'custom-model'):
|
||||
model_path = os.path.join(MODELS_DIR, 'stable-diffusion', model_name)
|
||||
|
||||
legacy_model_path = os.path.join(SD_DIR, model_name)
|
||||
if not os.path.exists(model_path + '.ckpt') and os.path.exists(legacy_model_path + '.ckpt'):
|
||||
model_path = legacy_model_path
|
||||
else:
|
||||
model_path = os.path.join(MODELS_DIR, 'stable-diffusion', model_name)
|
||||
return model_path
|
||||
|
||||
@app.on_event("shutdown")
|
||||
def shutdown_event(): # Signal render thread to close on shutdown
|
||||
task_manager.current_state_error = SystemExit('Application shutting down.')
|
||||
|
||||
@app.get('/')
|
||||
def read_root():
|
||||
return FileResponse(os.path.join(SD_UI_DIR, 'index.html'), headers=NOCACHE_HEADERS)
|
||||
|
||||
@app.get('/ping') # Get server and optionally session status.
|
||||
def ping(session_id:str=None):
|
||||
if not task_manager.render_thread.is_alive(): # Render thread is dead.
|
||||
if task_manager.current_state_error: raise HTTPException(status_code=500, detail=str(task_manager.current_state_error))
|
||||
raise HTTPException(status_code=500, detail='Render thread is dead.')
|
||||
if task_manager.current_state_error and not isinstance(task_manager.current_state_error, StopAsyncIteration): raise HTTPException(status_code=500, detail=str(task_manager.current_state_error))
|
||||
# Alive
|
||||
response = {'status': str(task_manager.current_state)}
|
||||
if session_id:
|
||||
task = task_manager.task_cache.tryGet(session_id)
|
||||
if task:
|
||||
response['task'] = id(task)
|
||||
if task.lock.locked():
|
||||
response['session'] = 'running'
|
||||
elif isinstance(task.error, StopAsyncIteration):
|
||||
response['session'] = 'stopped'
|
||||
elif task.error:
|
||||
response['session'] = 'error'
|
||||
elif not task.buffer_queue.empty():
|
||||
response['session'] = 'buffer'
|
||||
elif task.response:
|
||||
response['session'] = 'completed'
|
||||
else:
|
||||
response['session'] = 'pending'
|
||||
return JSONResponse(response, headers=NOCACHE_HEADERS)
|
||||
|
||||
def save_model_to_config(model_name):
|
||||
config = getConfig()
|
||||
if 'model' not in config:
|
||||
config['model'] = {}
|
||||
|
||||
config['model']['stable-diffusion'] = model_name
|
||||
setConfig(config)
|
||||
|
||||
@app.post('/render')
|
||||
def render(req : task_manager.ImageRequest):
|
||||
try:
|
||||
save_model_to_config(req.use_stable_diffusion_model)
|
||||
req.use_stable_diffusion_model = resolve_model_to_use(req.use_stable_diffusion_model)
|
||||
new_task = task_manager.render(req)
|
||||
response = {
|
||||
'status': str(task_manager.current_state),
|
||||
'queue': task_manager.tasks_queue.qsize(),
|
||||
'stream': f'/image/stream/{req.session_id}/{id(new_task)}',
|
||||
'task': id(new_task)
|
||||
}
|
||||
return JSONResponse(response, headers=NOCACHE_HEADERS)
|
||||
except ChildProcessError as e: # Render thread is dead
|
||||
raise HTTPException(status_code=500, detail=f'Rendering thread has died.') # HTTP500 Internal Server Error
|
||||
except ConnectionRefusedError as e: # Unstarted task pending, deny queueing more than one.
|
||||
raise HTTPException(status_code=503, detail=f'Session {req.session_id} has an already pending task.') # HTTP503 Service Unavailable
|
||||
except Exception as e:
|
||||
raise HTTPException(status_code=500, detail=str(e))
|
||||
|
||||
@app.get('/image/stream/{session_id:str}/{task_id:int}')
|
||||
def stream(session_id:str, task_id:int):
|
||||
#TODO Move to WebSockets ??
|
||||
task = task_manager.task_cache.tryGet(session_id)
|
||||
if not task: raise HTTPException(status_code=410, detail='No request received.') # HTTP410 Gone
|
||||
if (id(task) != task_id): raise HTTPException(status_code=409, detail=f'Wrong task id received. Expected:{id(task)}, Received:{task_id}') # HTTP409 Conflict
|
||||
if task.buffer_queue.empty() and not task.lock.locked():
|
||||
if task.response:
|
||||
#print(f'Session {session_id} sending cached response')
|
||||
return JSONResponse(task.response, headers=NOCACHE_HEADERS)
|
||||
raise HTTPException(status_code=425, detail='Too Early, task not started yet.') # HTTP425 Too Early
|
||||
#print(f'Session {session_id} opened live render stream {id(task.buffer_queue)}')
|
||||
return StreamingResponse(task.read_buffer_generator(), media_type='application/json')
|
||||
|
||||
@app.get('/image/stop')
|
||||
def stop(session_id:str=None):
|
||||
if not session_id:
|
||||
if task_manager.current_state == task_manager.ServerStates.Online or task_manager.current_state == task_manager.ServerStates.Unavailable:
|
||||
raise HTTPException(status_code=409, detail='Not currently running any tasks.') # HTTP409 Conflict
|
||||
task_manager.current_state_error = StopAsyncIteration('')
|
||||
return {'OK'}
|
||||
task = task_manager.task_cache.tryGet(session_id)
|
||||
if not task: raise HTTPException(status_code=404, detail=f'Session {session_id} has no active task.') # HTTP404 Not Found
|
||||
if isinstance(task.error, StopAsyncIteration): raise HTTPException(status_code=409, detail=f'Session {session_id} task is already stopped.') # HTTP409 Conflict
|
||||
task.error = StopAsyncIteration('')
|
||||
return {'OK'}
|
||||
|
||||
@app.get('/image/tmp/{session_id}/{img_id:int}')
|
||||
def get_image(session_id, img_id):
|
||||
task = task_manager.task_cache.tryGet(session_id)
|
||||
if not task: raise HTTPException(status_code=410, detail=f'Session {session_id} has not submitted a task.') # HTTP410 Gone
|
||||
if not task.temp_images[img_id]: raise HTTPException(status_code=425, detail='Too Early, task data is not available yet.') # HTTP425 Too Early
|
||||
try:
|
||||
img_data = task.temp_images[img_id]
|
||||
if isinstance(img_data, str):
|
||||
return img_data
|
||||
img_data.seek(0)
|
||||
return StreamingResponse(img_data, media_type='image/jpeg')
|
||||
except KeyError as e:
|
||||
raise HTTPException(status_code=500, detail=str(e))
|
||||
|
||||
@app.post('/app_config')
|
||||
async def setAppConfig(req : SetAppConfigRequest):
|
||||
try:
|
||||
config = {
|
||||
'update_branch': req.update_branch
|
||||
}
|
||||
|
||||
config_json_str = json.dumps(config)
|
||||
config_bat_str = f'@set update_branch={req.update_branch}'
|
||||
config_sh_str = f'export update_branch={req.update_branch}'
|
||||
|
||||
config_json_path = os.path.join(CONFIG_DIR, 'config.json')
|
||||
config_bat_path = os.path.join(CONFIG_DIR, 'config.bat')
|
||||
config_sh_path = os.path.join(CONFIG_DIR, 'config.sh')
|
||||
|
||||
with open(config_json_path, 'w', encoding='utf-8') as f:
|
||||
f.write(config_json_str)
|
||||
|
||||
with open(config_bat_path, 'w', encoding='utf-8') as f:
|
||||
f.write(config_bat_str)
|
||||
|
||||
with open(config_sh_path, 'w', encoding='utf-8') as f:
|
||||
f.write(config_sh_str)
|
||||
|
||||
return {'OK'}
|
||||
except Exception as e:
|
||||
print(traceback.format_exc())
|
||||
raise HTTPException(status_code=500, detail=str(e))
|
||||
|
||||
def getConfig(default_val={}):
|
||||
config_cached = None
|
||||
config_last_mod_time = 0
|
||||
def getConfig(default_val=APP_CONFIG_DEFAULTS):
|
||||
global config_cached, config_last_mod_time
|
||||
try:
|
||||
config_json_path = os.path.join(CONFIG_DIR, 'config.json')
|
||||
if not os.path.exists(config_json_path):
|
||||
return default_val
|
||||
if config_last_mod_time > 0 and config_cached is not None:
|
||||
# Don't read if file was not modified
|
||||
mtime = os.path.getmtime(config_json_path)
|
||||
if mtime <= config_last_mod_time:
|
||||
return config_cached
|
||||
with open(config_json_path, 'r', encoding='utf-8') as f:
|
||||
return json.load(f)
|
||||
config_cached = json.load(f)
|
||||
config_last_mod_time = os.path.getmtime(config_json_path)
|
||||
return config_cached
|
||||
except Exception as e:
|
||||
print(str(e))
|
||||
print(traceback.format_exc())
|
||||
return default_val
|
||||
|
||||
def setConfig(config):
|
||||
try:
|
||||
try: # config.json
|
||||
config_json_path = os.path.join(CONFIG_DIR, 'config.json')
|
||||
with open(config_json_path, 'w', encoding='utf-8') as f:
|
||||
return json.dump(config, f)
|
||||
except Exception as e:
|
||||
print(str(e))
|
||||
except:
|
||||
print(traceback.format_exc())
|
||||
|
||||
if 'render_devices' in config:
|
||||
gpu_devices = filter(lambda dev: dev.lower().startswith('gpu') or dev.lower().startswith('cuda'), config['render_devices'])
|
||||
else:
|
||||
gpu_devices = []
|
||||
|
||||
has_first_cuda_device = False
|
||||
for device in gpu_devices:
|
||||
if not task_manager.is_first_cuda_device(device): continue
|
||||
has_first_cuda_device = True
|
||||
break
|
||||
if len(gpu_devices) > 0 and not has_first_cuda_device:
|
||||
print('WARNING: GFPGANer only works on GPU:0, use CUDA_VISIBLE_DEVICES if GFPGANer is needed on a specific GPU.')
|
||||
print('Using CUDA_VISIBLE_DEVICES will remap the selected devices starting at GPU:0 fixing GFPGANer')
|
||||
|
||||
try: # config.bat
|
||||
config_bat = [
|
||||
f"@set update_branch={config['update_branch']}"
|
||||
]
|
||||
if len(gpu_devices) > 0 and not has_first_cuda_device:
|
||||
config_bat.append('::Set the devices visible inside SD-UI here')
|
||||
config_bat.append(f"::@set CUDA_VISIBLE_DEVICES={','.join(gpu_devices)}") # Needs better detection for edge cases, add as a comment for now.
|
||||
print('Add the line "@set CUDA_VISIBLE_DEVICES=N" where N is the GPUs to use to config.bat')
|
||||
config_bat_path = os.path.join(CONFIG_DIR, 'config.bat')
|
||||
with open(config_bat_path, 'w', encoding='utf-8') as f:
|
||||
f.write(f.write('\r\n'.join(config_bat)))
|
||||
except Exception as e:
|
||||
print(traceback.format_exc())
|
||||
|
||||
try: # config.sh
|
||||
config_sh = [
|
||||
'#!/bin/bash'
|
||||
f"export update_branch={config['update_branch']}"
|
||||
]
|
||||
if len(gpu_devices) > 0 and not has_first_cuda_device:
|
||||
config_sh.append('#Set the devices visible inside SD-UI here')
|
||||
config_sh.append(f"#CUDA_VISIBLE_DEVICES={','.join(gpu_devices)}") # Needs better detection for edge cases, add as a comment for now.
|
||||
print('Add the line "CUDA_VISIBLE_DEVICES=N" where N is the GPUs to use to config.sh')
|
||||
config_sh_path = os.path.join(CONFIG_DIR, 'config.sh')
|
||||
with open(config_sh_path, 'w', encoding='utf-8') as f:
|
||||
f.write('\n'.join(config_sh))
|
||||
except Exception as e:
|
||||
print(traceback.format_exc())
|
||||
|
||||
def resolve_model_to_use(model_name:str=None):
|
||||
if not model_name: # When None try user configured model.
|
||||
config = getConfig()
|
||||
if 'model' in config and 'stable-diffusion' in config['model']:
|
||||
model_name = config['model']['stable-diffusion']
|
||||
if model_name:
|
||||
# Check models directory
|
||||
models_dir_path = os.path.join(MODELS_DIR, 'stable-diffusion', model_name)
|
||||
if os.path.exists(models_dir_path + '.ckpt'):
|
||||
return models_dir_path
|
||||
if os.path.exists(model_name + '.ckpt'):
|
||||
# Direct Path to file
|
||||
model_name = os.path.abspath(model_name)
|
||||
return model_name
|
||||
# Default locations
|
||||
if model_name in APP_CONFIG_DEFAULT_MODELS:
|
||||
default_model_path = os.path.join(SD_DIR, model_name)
|
||||
if os.path.exists(default_model_path + '.ckpt'):
|
||||
return default_model_path
|
||||
# Can't find requested model, check the default paths.
|
||||
for default_model in APP_CONFIG_DEFAULT_MODELS:
|
||||
default_model_path = os.path.join(SD_DIR, default_model)
|
||||
if os.path.exists(default_model_path + '.ckpt'):
|
||||
if model_name is not None:
|
||||
print(f'Could not find the configured custom model {model_name}.ckpt. Using the default one: {default_model_path}.ckpt')
|
||||
return default_model_path
|
||||
raise Exception('No valid models found.')
|
||||
|
||||
class SetAppConfigRequest(BaseModel):
|
||||
update_branch: str = None
|
||||
render_devices: Union[List[str], List[int], str, int] = None
|
||||
|
||||
@app.post('/app_config')
|
||||
async def setAppConfig(req : SetAppConfigRequest):
|
||||
config = getConfig()
|
||||
if req.update_branch:
|
||||
config['update_branch'] = req.update_branch
|
||||
if req.render_devices and hasattr(req.render_devices, "__len__"): # strings, array of strings or numbers.
|
||||
render_devices = []
|
||||
if isinstance(req.render_devices, str):
|
||||
req.render_devices = req.render_devices.split(',')
|
||||
if isinstance(req.render_devices, list):
|
||||
for gpu in req.render_devices:
|
||||
if isinstance(req.render_devices, int):
|
||||
render_devices.append('GPU:' + gpu)
|
||||
else:
|
||||
render_devices.append(gpu)
|
||||
if isinstance(req.render_devices, int):
|
||||
render_devices.append('GPU:' + req.render_devices)
|
||||
if len(render_devices) > 0:
|
||||
config['render_devices'] = render_devices
|
||||
try:
|
||||
setConfig(config)
|
||||
return JSONResponse({'status': 'OK'}, headers=NOCACHE_HEADERS)
|
||||
except Exception as e:
|
||||
print(traceback.format_exc())
|
||||
raise HTTPException(status_code=500, detail=str(e))
|
||||
|
||||
def getModels():
|
||||
models = {
|
||||
'active': {
|
||||
@ -282,6 +244,112 @@ def read_web_data(key:str=None):
|
||||
else:
|
||||
raise HTTPException(status_code=404, detail=f'Request for unknown {key}') # HTTP404 Not Found
|
||||
|
||||
@app.get('/ping') # Get server and optionally session status.
|
||||
def ping(session_id:str=None):
|
||||
if task_manager.is_alive() <= 0: # Check that render threads are alive.
|
||||
if task_manager.current_state_error: raise HTTPException(status_code=500, detail=str(task_manager.current_state_error))
|
||||
raise HTTPException(status_code=500, detail='Render thread is dead.')
|
||||
if task_manager.current_state_error and not isinstance(task_manager.current_state_error, StopAsyncIteration): raise HTTPException(status_code=500, detail=str(task_manager.current_state_error))
|
||||
# Alive
|
||||
response = {'status': str(task_manager.current_state)}
|
||||
if session_id:
|
||||
task = task_manager.get_cached_task(session_id, update_ttl=True)
|
||||
if task:
|
||||
response['task'] = id(task)
|
||||
if task.lock.locked():
|
||||
response['session'] = 'running'
|
||||
elif isinstance(task.error, StopAsyncIteration):
|
||||
response['session'] = 'stopped'
|
||||
elif task.error:
|
||||
response['session'] = 'error'
|
||||
elif not task.buffer_queue.empty():
|
||||
response['session'] = 'buffer'
|
||||
elif task.response:
|
||||
response['session'] = 'completed'
|
||||
else:
|
||||
response['session'] = 'pending'
|
||||
return JSONResponse(response, headers=NOCACHE_HEADERS)
|
||||
|
||||
def save_model_to_config(model_name):
|
||||
config = getConfig()
|
||||
if 'model' not in config:
|
||||
config['model'] = {}
|
||||
|
||||
config['model']['stable-diffusion'] = model_name
|
||||
setConfig(config)
|
||||
|
||||
@app.post('/render')
|
||||
def render(req : task_manager.ImageRequest):
|
||||
if req.use_cpu and task_manager.is_alive('cpu') <= 0: raise HTTPException(status_code=403, detail=f'CPU rendering is not enabled in config.json or the thread has died...') # HTTP403 Forbidden
|
||||
if req.use_face_correction and task_manager.is_alive(0) <= 0: #TODO Remove when GFPGANer is fixed upstream.
|
||||
raise HTTPException(status_code=412, detail=f'GFPGANer only works GPU:0, use CUDA_VISIBLE_DEVICES if GFPGANer is needed on a specific GPU.') # HTTP412 Precondition Failed
|
||||
try:
|
||||
save_model_to_config(req.use_stable_diffusion_model)
|
||||
req.use_stable_diffusion_model = resolve_model_to_use(req.use_stable_diffusion_model)
|
||||
new_task = task_manager.render(req)
|
||||
response = {
|
||||
'status': str(task_manager.current_state),
|
||||
'queue': len(task_manager.tasks_queue),
|
||||
'stream': f'/image/stream/{req.session_id}/{id(new_task)}',
|
||||
'task': id(new_task)
|
||||
}
|
||||
return JSONResponse(response, headers=NOCACHE_HEADERS)
|
||||
except ChildProcessError as e: # Render thread is dead
|
||||
raise HTTPException(status_code=500, detail=f'Rendering thread has died.') # HTTP500 Internal Server Error
|
||||
except ConnectionRefusedError as e: # Unstarted task pending, deny queueing more than one.
|
||||
raise HTTPException(status_code=503, detail=f'Session {req.session_id} has an already pending task.') # HTTP503 Service Unavailable
|
||||
except Exception as e:
|
||||
raise HTTPException(status_code=500, detail=str(e))
|
||||
|
||||
@app.get('/image/stream/{session_id:str}/{task_id:int}')
|
||||
def stream(session_id:str, task_id:int):
|
||||
#TODO Move to WebSockets ??
|
||||
task = task_manager.get_cached_task(session_id, update_ttl=True)
|
||||
if not task: raise HTTPException(status_code=410, detail='No request received.') # HTTP410 Gone
|
||||
if (id(task) != task_id): raise HTTPException(status_code=409, detail=f'Wrong task id received. Expected:{id(task)}, Received:{task_id}') # HTTP409 Conflict
|
||||
if task.buffer_queue.empty() and not task.lock.locked():
|
||||
if task.response:
|
||||
#print(f'Session {session_id} sending cached response')
|
||||
return JSONResponse(task.response, headers=NOCACHE_HEADERS)
|
||||
raise HTTPException(status_code=425, detail='Too Early, task not started yet.') # HTTP425 Too Early
|
||||
#print(f'Session {session_id} opened live render stream {id(task.buffer_queue)}')
|
||||
return StreamingResponse(task.read_buffer_generator(), media_type='application/json')
|
||||
|
||||
@app.get('/image/stop')
|
||||
def stop(session_id:str=None):
|
||||
if not session_id:
|
||||
if task_manager.current_state == task_manager.ServerStates.Online or task_manager.current_state == task_manager.ServerStates.Unavailable:
|
||||
raise HTTPException(status_code=409, detail='Not currently running any tasks.') # HTTP409 Conflict
|
||||
task_manager.current_state_error = StopAsyncIteration('')
|
||||
return {'OK'}
|
||||
task = task_manager.get_cached_task(session_id, update_ttl=False)
|
||||
if not task: raise HTTPException(status_code=404, detail=f'Session {session_id} has no active task.') # HTTP404 Not Found
|
||||
if isinstance(task.error, StopAsyncIteration): raise HTTPException(status_code=409, detail=f'Session {session_id} task is already stopped.') # HTTP409 Conflict
|
||||
task.error = StopAsyncIteration('')
|
||||
return {'OK'}
|
||||
|
||||
@app.get('/image/tmp/{session_id}/{img_id:int}')
|
||||
def get_image(session_id, img_id):
|
||||
task = task_manager.get_cached_task(session_id, update_ttl=True)
|
||||
if not task: raise HTTPException(status_code=410, detail=f'Session {session_id} has not submitted a task.') # HTTP410 Gone
|
||||
if not task.temp_images[img_id]: raise HTTPException(status_code=425, detail='Too Early, task data is not available yet.') # HTTP425 Too Early
|
||||
try:
|
||||
img_data = task.temp_images[img_id]
|
||||
if isinstance(img_data, str):
|
||||
return img_data
|
||||
img_data.seek(0)
|
||||
return StreamingResponse(img_data, media_type='image/jpeg')
|
||||
except KeyError as e:
|
||||
raise HTTPException(status_code=500, detail=str(e))
|
||||
|
||||
@app.get('/')
|
||||
def read_root():
|
||||
return FileResponse(os.path.join(SD_UI_DIR, 'index.html'), headers=NOCACHE_HEADERS)
|
||||
|
||||
@app.on_event("shutdown")
|
||||
def shutdown_event(): # Signal render thread to close on shutdown
|
||||
task_manager.current_state_error = SystemExit('Application shutting down.')
|
||||
|
||||
# don't log certain requests
|
||||
class LogSuppressFilter(logging.Filter):
|
||||
def filter(self, record: logging.LogRecord) -> bool:
|
||||
@ -292,8 +360,56 @@ class LogSuppressFilter(logging.Filter):
|
||||
return True
|
||||
logging.getLogger('uvicorn.access').addFilter(LogSuppressFilter())
|
||||
|
||||
task_manager.default_model_to_load = get_initial_model_to_load()
|
||||
task_manager.start_render_thread()
|
||||
config = getConfig()
|
||||
|
||||
async def check_status(): # Task to Validate user config shortly after startup.
|
||||
# Check that the loaded config.json yielded a server in a known valid state.
|
||||
# When issues are found, try to fix them when possible and warn the user.
|
||||
device_count = 0
|
||||
# Wait for devices to register and/or change names.
|
||||
THREAD_START_DELAY = 5 # seconds - Give time for devices/threads to start.
|
||||
for i in range(10): # Maximum number of retry.
|
||||
await asyncio.sleep(THREAD_START_DELAY)
|
||||
new_count = task_manager.is_alive()
|
||||
# Stops retry once no more devices show up.
|
||||
if new_count > 0 and device_count == new_count: break
|
||||
device_count = new_count
|
||||
|
||||
if 'render_devices' in config and task_manager.is_alive() <= 0: # No running devices, probably invalid user config. Try to apply defaults.
|
||||
print('WARNING: No active render devices after loading config. Validate "render_devices" in config.json')
|
||||
task_manager.start_render_thread('auto') # Detect best device for renders
|
||||
task_manager.start_render_thread('cpu') # Allow CPU to be used for renders
|
||||
await asyncio.sleep(THREAD_START_DELAY) # delay message after thread start.
|
||||
print('Default render devices loaded to replace missing render_devices', config['render_devices'])
|
||||
|
||||
display_warning = False
|
||||
if not 'render_devices' in config and task_manager.is_alive(0) <= 0: # No config set, is on auto mode and without cuda:0
|
||||
task_manager.start_render_thread('cuda') # An other cuda device is better and cuda:0 is missing, start it...
|
||||
display_warning = True # And warn user to update settings...
|
||||
await asyncio.sleep(THREAD_START_DELAY) # delay message after thread start.
|
||||
|
||||
if display_warning or task_manager.is_alive(0) <= 0:
|
||||
print('WARNING: GFPGANer only works on GPU:0, use CUDA_VISIBLE_DEVICES if GFPGANer is needed on a specific GPU.')
|
||||
print('Using CUDA_VISIBLE_DEVICES will remap the selected devices starting at GPU:0 fixing GFPGANer')
|
||||
print('Add the line "@set CUDA_VISIBLE_DEVICES=N" where N is the GPUs to use to config.bat')
|
||||
print('Add the line "CUDA_VISIBLE_DEVICES=N" where N is the GPUs to use to config.sh')
|
||||
|
||||
# Start the task_manager
|
||||
task_manager.default_model_to_load = resolve_model_to_use()
|
||||
if 'render_devices' in config: # Start a new thread for each device.
|
||||
if isinstance(config['render_devices'], str):
|
||||
config['render_devices'] = config['render_devices'].split(',')
|
||||
if not isinstance(config['render_devices'], list):
|
||||
raise Exception('Invalid render_devices value in config.')
|
||||
for device in config['render_devices']:
|
||||
task_manager.start_render_thread(device)
|
||||
else:
|
||||
# Select best GPU device using free memory, if more than one device.
|
||||
task_manager.start_render_thread('auto') # Detect best device for renders
|
||||
task_manager.start_render_thread('cpu') # Allow CPU to be used for renders
|
||||
|
||||
# Task to Validate user config shortly after startup.
|
||||
LOOP.create_task(check_status())
|
||||
|
||||
# start the browser ui
|
||||
import webbrowser; webbrowser.open('http://localhost:9000')
|
Loading…
Reference in New Issue
Block a user