Merge pull request #348 from madrang/multi-gpu

Multi gpu
cmdr2 2022-10-26 19:33:44 +05:30 committed by GitHub
commit ec14429238
3 changed files with 872 additions and 501 deletions

runtime.py

@@ -1,8 +1,15 @@
"""runtime.py: torch device owned by a thread.
Notes:
Avoid device switching, transfering all models will get too complex.
To use a diffrent device signal the current render device to exit
And then start a new clean thread for the new device.
"""
import json
import os, re
import traceback
import torch
import numpy as np
from gc import collect as gc_collect
from omegaconf import OmegaConf
from PIL import Image, ImageOps
from tqdm import tqdm, trange
@@ -35,63 +42,144 @@ import base64
from io import BytesIO
#from colorama import Fore

from threading import local as LocalThreadVars
thread_data = LocalThreadVars()

def device_would_fail(device):
    if device == 'cpu': return None
    # Returns None when no issues found, otherwise returns the detected error str.
    # Memory check
    try:
        mem_free, mem_total = torch.cuda.mem_get_info(device)
        mem_total /= float(10**9)
        if mem_total < 3.0:
            return 'GPUs with less than 3 GB of VRAM are not compatible with Stable Diffusion'
    except RuntimeError as e:
        return str(e) # Return cuda errors from mem_get_info as strings
    return None

def device_select(device):
    if device == 'cpu': return True
    if not torch.cuda.is_available(): return False
    failure_msg = device_would_fail(device)
    if failure_msg:
        if 'invalid device' in failure_msg:
            raise NameError(f'GPU "{device}" could not be found. Remove this device from config.render_devices or use one of "auto" or "cuda".')
        print(failure_msg)
        return False

    device_name = torch.cuda.get_device_name(device)

    # otherwise these NVIDIA cards create green images
    thread_data.force_full_precision = ('nvidia' in device_name.lower() or 'geforce' in device_name.lower()) and (' 1660' in device_name or ' 1650' in device_name)
    if thread_data.force_full_precision:
        print('forcing full precision on NVIDIA 16xx cards, to avoid green images. GPU detected: ', device_name)

    thread_data.device = device
    thread_data.has_valid_gpu = True
    return True

def device_init(device_selection=None):
    # Thread bound properties
    thread_data.stop_processing = False
    thread_data.temp_images = {}

    thread_data.ckpt_file = None
    thread_data.gfpgan_file = None
    thread_data.real_esrgan_file = None

    thread_data.model = None
    thread_data.modelCS = None
    thread_data.modelFS = None
    thread_data.model_gfpgan = None
    thread_data.model_real_esrgan = None

    thread_data.model_is_half = False
    thread_data.model_fs_is_half = False
    thread_data.device = None
    thread_data.unet_bs = 1
    thread_data.precision = 'autocast'
    thread_data.sampler_plms = None
    thread_data.sampler_ddim = None

    thread_data.turbo = False
    thread_data.has_valid_gpu = False
    thread_data.force_full_precision = False
    thread_data.reduced_memory = True

    if device_selection.lower() == 'cpu':
        print('CPU requested, skipping gpu init.')
        thread_data.device = 'cpu'
        return
    if not torch.cuda.is_available():
        print('WARNING: torch.cuda is not available. Using the CPU, but this will be very slow!')
        return
    device_count = torch.cuda.device_count()
    if device_count <= 1 and device_selection == 'auto':
        device_selection = 'current' # Use 'auto' only when there is more than one compatible device found.
    if device_selection == 'auto':
        print('Autoselecting GPU. Using most free memory.')
        max_mem_free = 0
        best_device = None
        for device in range(device_count):
            mem_free, mem_total = torch.cuda.mem_get_info(device)
            mem_free /= float(10**9)
            mem_total /= float(10**9)
            device_name = torch.cuda.get_device_name(device)
            print(f'GPU:{device} detected: {device_name} - Memory: {round(mem_total - mem_free, 2)}Go / {round(mem_total, 2)}Go')
            if max_mem_free < mem_free:
                max_mem_free = mem_free
                best_device = device
        if best_device and device_select(device):
            print(f'Setting GPU:{device} as active')
            torch.cuda.device(device)
            return
    if isinstance(device_selection, str):
        device_selection = device_selection.lower()
        if device_selection.startswith('gpu:'):
            device_selection = int(device_selection[4:])
    if device_selection != 'cuda' and device_selection != 'current' and device_selection != 'gpu':
        if device_select(device_selection):
            if isinstance(device_selection, int):
                print(f'Setting GPU:{device_selection} as active')
            else:
                print(f'Setting {device_selection} as active')
            torch.cuda.device(device_selection)
            return
    # By default use current device.
    print('Checking current GPU...')
    device = torch.cuda.current_device()
    device_name = torch.cuda.get_device_name(device)
    print(f'GPU:{device} detected: {device_name}')
    if device_select(device):
        return
    print('WARNING: No compatible GPU found. Using the CPU, but this will be very slow!')
    thread_data.device = 'cpu'

def is_first_cuda_device(device):
    if device is None: return False
    if device == 0 or device == '0': return True
    if device == 'cuda' or device == 'cuda:0': return True
    if device == 'gpu' or device == 'gpu:0': return True
    if device == 'current': return True
    if device == torch.device(0): return True
    return False

def load_model_ckpt():
    if not thread_data.ckpt_file: raise ValueError(f'Thread ckpt_file is undefined.')
    if not os.path.exists(thread_data.ckpt_file + '.ckpt'): raise FileNotFoundError(f'Cannot find {thread_data.ckpt_file}.ckpt')

    if not thread_data.precision:
        thread_data.precision = 'full' if thread_data.force_full_precision else 'autocast'
    if not thread_data.unet_bs:
        thread_data.unet_bs = 1

    if thread_data.device == 'cpu':
        thread_data.precision = 'full'

    print('loading', thread_data.ckpt_file, 'to', thread_data.device, 'using precision', thread_data.precision)
    sd = load_model_from_config(thread_data.ckpt_file + '.ckpt')
    li, lo = [], []
    for key, value in sd.items():
        sp = key.split(".")
@@ -114,88 +202,127 @@ def load_model_ckpt(ckpt_to_use, device_to_use='cuda', turbo=False, unet_bs_to_u
    model = instantiate_from_config(config.modelUNet)
    _, _ = model.load_state_dict(sd, strict=False)
    model.eval()
    model.cdevice = torch.device(thread_data.device)
    model.unet_bs = thread_data.unet_bs
    model.turbo = thread_data.turbo
    if thread_data.device != 'cpu':
        model.to(thread_data.device)
    #if thread_data.reduced_memory:
        #model.model1.to("cpu")
        #model.model2.to("cpu")
    thread_data.model = model

    modelCS = instantiate_from_config(config.modelCondStage)
    _, _ = modelCS.load_state_dict(sd, strict=False)
    modelCS.eval()
    modelCS.cond_stage_model.device = torch.device(thread_data.device)
    if thread_data.device != 'cpu':
        if thread_data.reduced_memory:
            modelCS.to('cpu')
        else:
            modelCS.to(thread_data.device) # Preload on device if not already there.
    thread_data.modelCS = modelCS

    modelFS = instantiate_from_config(config.modelFirstStage)
    _, _ = modelFS.load_state_dict(sd, strict=False)
    modelFS.eval()
    if thread_data.device != 'cpu':
        if thread_data.reduced_memory:
            modelFS.to('cpu')
        else:
            modelFS.to(thread_data.device) # Preload on device if not already there.
    thread_data.modelFS = modelFS
    del sd

    if thread_data.device != "cpu" and thread_data.precision == "autocast":
        thread_data.model.half()
        thread_data.modelCS.half()
        thread_data.modelFS.half()
        thread_data.model_is_half = True
        thread_data.model_fs_is_half = True
    else:
        thread_data.model_is_half = False
        thread_data.model_fs_is_half = False

    print('loaded', thread_data.ckpt_file, 'as', model.device, '->', modelCS.cond_stage_model.device, '->', thread_data.modelFS.device, 'using precision', thread_data.precision)

def unload_filters():
    if thread_data.model_gfpgan is not None:
        del thread_data.model_gfpgan
    thread_data.model_gfpgan = None

    if thread_data.model_real_esrgan is not None:
        del thread_data.model_real_esrgan
    thread_data.model_real_esrgan = None

def unload_models():
    if thread_data.model is not None:
        print('Unloading models...')
        del thread_data.model
        del thread_data.modelCS
        del thread_data.modelFS

    thread_data.model = None
    thread_data.modelCS = None
    thread_data.modelFS = None

def wait_model_move_to(model, target_device): # Send to target_device and wait until complete.
    if thread_data.device == target_device: return
    start_mem = torch.cuda.memory_allocated(thread_data.device) / 1e6
    if start_mem <= 0: return
    model_name = model.__class__.__name__
    print(f'Device:{thread_data.device} - Sending model {model_name} to {target_device} | Memory transfer starting. Memory Used: {round(start_mem)}Mo')
    start_time = time.time()
    model.to(target_device)
    time_step = start_time
    WARNING_TIMEOUT = 1.5 # seconds - Show activity in console after timeout.
    last_mem = start_mem
    is_transfering = True
    while is_transfering:
        time.sleep(0.5) # 500ms
        mem = torch.cuda.memory_allocated(thread_data.device) / 1e6
        is_transfering = bool(mem > 0 and mem < last_mem) # still stuff loaded, but less than last time.
        last_mem = mem
        if not is_transfering:
            break;
        if time.time() - time_step > WARNING_TIMEOUT: # Long delay, print to console to show activity.
            print(f'Device:{thread_data.device} - Waiting for Memory transfer. Memory Used: {round(mem)}Mo, Transfered: {round(start_mem - mem)}Mo')
            time_step = time.time()
    print(f'Device:{thread_data.device} - {model_name} Moved: {round(start_mem - last_mem)}Mo in {round(time.time() - start_time, 3)} seconds to {target_device}')

def load_model_gfpgan():
    if thread_data.gfpgan_file is None: raise ValueError(f'Thread gfpgan_file is undefined.')
    #print('load_model_gfpgan called without setting gfpgan_file')
    #return
    if not is_first_cuda_device(thread_data.device):
        #TODO Remove when fixed - A bug with GFPGANer and facexlib needs to be fixed before use on other devices.
        raise Exception(f'Current device {torch.device(thread_data.device)} is not {torch.device(0)}. Cannot run GFPGANer.')
    model_path = thread_data.gfpgan_file + ".pth"
    thread_data.model_gfpgan = GFPGANer(device=torch.device(thread_data.device), model_path=model_path, upscale=1, arch='clean', channel_multiplier=2, bg_upsampler=None)
    print('loaded', thread_data.gfpgan_file, 'to', thread_data.model_gfpgan.device, 'precision', thread_data.precision)

def load_model_real_esrgan():
    if thread_data.real_esrgan_file is None: raise ValueError(f'Thread real_esrgan_file is undefined.')
    #print('load_model_real_esrgan called without setting real_esrgan_file')
    #return
    model_path = thread_data.real_esrgan_file + ".pth"

    RealESRGAN_models = {
        'RealESRGAN_x4plus': RRDBNet(num_in_ch=3, num_out_ch=3, num_feat=64, num_block=23, num_grow_ch=32, scale=4),
        'RealESRGAN_x4plus_anime_6B': RRDBNet(num_in_ch=3, num_out_ch=3, num_feat=64, num_block=6, num_grow_ch=32, scale=4)
    }

    model_to_use = RealESRGAN_models[thread_data.real_esrgan_file]

    if thread_data.device == 'cpu':
        thread_data.model_real_esrgan = RealESRGANer(device=torch.device(thread_data.device), scale=2, model_path=model_path, model=model_to_use, pre_pad=0, half=False) # cpu does not support half
        #thread_data.model_real_esrgan.device = torch.device(thread_data.device)
        thread_data.model_real_esrgan.model.to('cpu')
    else:
        thread_data.model_real_esrgan = RealESRGANer(device=torch.device(thread_data.device), scale=2, model_path=model_path, model=model_to_use, pre_pad=0, half=thread_data.model_is_half)

    thread_data.model_real_esrgan.model.name = thread_data.real_esrgan_file
    print('loaded ', thread_data.real_esrgan_file, 'to', thread_data.model_real_esrgan.device, 'precision', thread_data.precision)

def get_base_path(disk_path, session_id, prompt, img_id, ext, suffix=None):
    if disk_path is None: return None
@@ -207,21 +334,37 @@ def get_base_path(disk_path, session_id, prompt, img_id, ext, suffix=None):
    prompt_flattened = filename_regex.sub('_', prompt)[:50]

    if suffix is not None:
        return os.path.join(session_out_path, f"{prompt_flattened}_{img_id}_{suffix}.{ext}")
    return os.path.join(session_out_path, f"{prompt_flattened}_{img_id}.{ext}")
def apply_filters(filter_name, image_data, model_path=None):
    print(f'Applying filter {filter_name}...')
    gc() # Free space before loading new data.

    if isinstance(image_data, torch.Tensor):
        print(image_data)
        image_data.to(thread_data.device)

    if filter_name == 'gfpgan':
        if model_path is not None and model_path != thread_data.gfpgan_file:
            thread_data.gfpgan_file = model_path
            load_model_gfpgan()
        elif not thread_data.model_gfpgan:
            load_model_gfpgan()
        if thread_data.model_gfpgan is None: raise Exception('Model "gfpgan" not loaded.')
        print('enhance with', thread_data.gfpgan_file, 'on', thread_data.model_gfpgan.device, 'precision', thread_data.precision)
        _, _, output = thread_data.model_gfpgan.enhance(image_data[:,:,::-1], has_aligned=False, only_center_face=False, paste_back=True)
        image_data = output[:,:,::-1]

    if filter_name == 'real_esrgan':
        if model_path is not None and model_path != thread_data.real_esrgan_file:
            thread_data.real_esrgan_file = model_path
            load_model_real_esrgan()
        elif not thread_data.model_real_esrgan:
            load_model_real_esrgan()
        if thread_data.model_real_esrgan is None: raise Exception('Model "gfpgan" not loaded.')
        print('enhance with', thread_data.real_esrgan_file, 'on', thread_data.model_real_esrgan.device, 'precision', thread_data.precision)
        output, _ = thread_data.model_real_esrgan.enhance(image_data[:,:,::-1])
        image_data = output[:,:,::-1]

    return image_data
@@ -231,84 +374,98 @@ def mk_img(req: Request):
        yield from do_mk_img(req)
    except Exception as e:
        print(traceback.format_exc())
        # Model crashed, release all resources in unknown state.
        unload_models()
        unload_filters()
        gc() # Release from memory.
        yield json.dumps({
            "status": 'failed',
            "detail": str(e)
        })
def update_temp_img(req, x_samples):
    partial_images = []
    for i in range(req.num_outputs):
        x_sample_ddim = thread_data.modelFS.decode_first_stage(x_samples[i].unsqueeze(0))
        x_sample = torch.clamp((x_sample_ddim + 1.0) / 2.0, min=0.0, max=1.0)
        x_sample = 255.0 * rearrange(x_sample[0].cpu().numpy(), "c h w -> h w c")
        x_sample = x_sample.astype(np.uint8)
        img = Image.fromarray(x_sample)
        buf = BytesIO()
        img.save(buf, format='JPEG')
        buf.seek(0)

        del img, x_sample, x_sample_ddim
        # don't delete x_samples, it is used in the code that called this callback
        thread_data.temp_images[str(req.session_id) + '/' + str(i)] = buf
        partial_images.append({'path': f'/image/tmp/{req.session_id}/{i}'})
    return partial_images

# Build and return the apropriate generator for do_mk_img
def get_image_progress_generator(req, extra_props=None):
    if not req.stream_progress_updates:
        def empty_callback(x_samples, i): return x_samples
        return empty_callback

    thread_data.partial_x_samples = None
    last_callback_time = -1
    def img_callback(x_samples, i):
        nonlocal last_callback_time

        thread_data.partial_x_samples = x_samples
        step_time = time.time() - last_callback_time if last_callback_time != -1 else -1
        last_callback_time = time.time()

        progress = {"step": i, "step_time": step_time}
        if extra_props is not None:
            progress.update(extra_props)

        if req.stream_image_progress and i % 5 == 0:
            progress['output'] = update_temp_img(req, x_samples)

        yield json.dumps(progress)

        if thread_data.stop_processing:
            raise UserInitiatedStop("User requested that we stop processing")
    return img_callback
def do_mk_img(req: Request):
    thread_data.stop_processing = False

    res = Response()
    res.request = req
    res.images = []

    thread_data.temp_images.clear()

    # custom model support:
    #  the req.use_stable_diffusion_model needs to be a valid path
    #  to the ckpt file (without the extension).
    if not os.path.exists(req.use_stable_diffusion_model + '.ckpt'): raise FileNotFoundError(f'Cannot find {req.use_stable_diffusion_model}.ckpt')

    needs_model_reload = False
    if not thread_data.model or thread_data.ckpt_file != req.use_stable_diffusion_model:
        thread_data.ckpt_file = req.use_stable_diffusion_model
        needs_model_reload = True

    if thread_data.has_valid_gpu:
        if (thread_data.precision == 'autocast' and (req.use_full_precision or not thread_data.model_is_half)) or \
            (thread_data.precision == 'full' and not req.use_full_precision and not thread_data.force_full_precision):
            thread_data.precision = 'full' if req.use_full_precision else 'autocast'
            needs_model_reload = True

    if needs_model_reload:
        unload_models()
        unload_filters()
        load_model_ckpt()

    if thread_data.turbo != req.turbo:
        thread_data.turbo = req.turbo
        thread_data.model.turbo = req.turbo

    # Start by cleaning memory, loading and unloading things can leave memory allocated.
    gc()
    opt_prompt = req.prompt
    opt_seed = req.seed
@@ -316,11 +473,9 @@ def do_mk_img(req: Request):
    opt_C = 4
    opt_f = 8
    opt_ddim_eta = 0.0

    print(req.to_string(), '\n device', thread_data.device)
    print('\n\n Using precision:', thread_data.precision)

    seed_everything(opt_seed)
@@ -329,7 +484,7 @@ def do_mk_img(req: Request):
    assert prompt is not None
    data = [batch_size * [prompt]]

    if thread_data.precision == "autocast" and thread_data.device != "cpu":
        precision_scope = autocast
    else:
        precision_scope = nullcontext
@@ -345,25 +500,26 @@ def do_mk_img(req: Request):
        handler = _img2img

        init_image = load_img(req.init_image, req.width, req.height)
        init_image = init_image.to(thread_data.device)

        if thread_data.device != "cpu" and thread_data.precision == "autocast":
            init_image = init_image.half()

        thread_data.modelFS.to(thread_data.device)

        init_image = repeat(init_image, '1 ... -> b ...', b=batch_size)
        init_latent = thread_data.modelFS.get_first_stage_encoding(thread_data.modelFS.encode_first_stage(init_image))  # move to latent space

        if req.mask is not None:
            mask = load_mask(req.mask, req.width, req.height, init_latent.shape[2], init_latent.shape[3], True).to(thread_data.device)
            mask = mask[0][0].unsqueeze(0).repeat(4, 1, 1).unsqueeze(0)
            mask = repeat(mask, '1 ... -> b ...', b=batch_size)

            if thread_data.device != "cpu" and thread_data.precision == "autocast":
                mask = mask.half()

        # Send to CPU and wait until complete.
        wait_model_move_to(thread_data.modelFS, 'cpu')

        assert 0. <= req.prompt_strength <= 1., 'can only work with strength in [0.0, 1.0]'
        t_enc = int(req.prompt_strength * req.num_inference_steps)
@@ -375,16 +531,16 @@ def do_mk_img(req: Request):
    else:
        session_out_path = None

    with torch.no_grad():
        for n in trange(opt_n_iter, desc="Sampling"):
            for prompts in tqdm(data, desc="data"):

                with precision_scope("cuda"):
                    if thread_data.reduced_memory:
                        thread_data.modelCS.to(thread_data.device)
                    uc = None
                    if req.guidance_scale != 1.0:
                        uc = thread_data.modelCS.get_learned_conditioning(batch_size * [req.negative_prompt])
                    if isinstance(prompts, tuple):
                        prompts = list(prompts)
@@ -397,85 +553,65 @@ def do_mk_img(req: Request):
                            weight = weights[i]
                            # if not skip_normalize:
                            weight = weight / totalWeight
                            c = torch.add(c, thread_data.modelCS.get_learned_conditioning(subprompts[i]), alpha=weight)
                    else:
                        c = thread_data.modelCS.get_learned_conditioning(prompts)

                    if thread_data.reduced_memory:
                        thread_data.modelFS.to(thread_data.device)

                    n_steps = req.num_inference_steps if req.init_image is None else t_enc
                    img_callback = get_image_progress_generator(req, {"total_steps": n_steps})

                    # run the handler
                    try:
                        print('Running handler...')
                        if handler == _txt2img:
                            x_samples = _txt2img(req.width, req.height, req.num_outputs, req.num_inference_steps, req.guidance_scale, None, opt_C, opt_f, opt_ddim_eta, c, uc, opt_seed, img_callback, mask, req.sampler)
                        else:
                            x_samples = _img2img(init_latent, t_enc, batch_size, req.guidance_scale, c, uc, req.num_inference_steps, opt_ddim_eta, opt_seed, img_callback, mask)

                        if req.stream_progress_updates:
                            yield from x_samples
                        if hasattr(thread_data, 'partial_x_samples'):
                            if thread_data.partial_x_samples is not None:
                                x_samples = thread_data.partial_x_samples
                            del thread_data.partial_x_samples
                    except UserInitiatedStop:
                        if not hasattr(thread_data, 'partial_x_samples'):
                            continue
                        if thread_data.partial_x_samples is None:
                            del thread_data.partial_x_samples
                            continue
                        x_samples = thread_data.partial_x_samples
                        del thread_data.partial_x_samples

                    print("decoding images")
                    img_data = [None] * batch_size
                    for i in range(batch_size):
                        x_samples_ddim = thread_data.modelFS.decode_first_stage(x_samples[i].unsqueeze(0))
                        x_sample = torch.clamp((x_samples_ddim + 1.0) / 2.0, min=0.0, max=1.0)
                        x_sample = 255.0 * rearrange(x_sample[0].cpu().numpy(), "c h w -> h w c")
                        x_sample = x_sample.astype(np.uint8)
                        img_data[i] = x_sample
                    del x_samples, x_samples_ddim, x_sample

                    if thread_data.reduced_memory:
                        # Send to CPU and wait until complete.
                        wait_model_move_to(thread_data.modelFS, 'cpu')

                    print("saving images")
                    for i in range(batch_size):
                        img = Image.fromarray(img_data[i])
                        img_id = base64.b64encode(int(time.time()+i).to_bytes(8, 'big')).decode() # Generate unique ID based on time.
                        img_id = img_id.translate({43:None, 47:None, 61:None})[-8:] # Remove + / = and keep last 8 chars.

                        has_filters = (req.use_face_correction is not None and req.use_face_correction.startswith('GFPGAN')) or \
                            (req.use_upscale is not None and req.use_upscale.startswith('RealESRGAN'))

                        return_orig_img = not has_filters or not req.show_only_filtered_image

                        if thread_data.stop_processing:
                            return_orig_img = True

                        if req.save_to_disk_path is not None:
@@ -492,19 +628,18 @@ def do_mk_img(req: Request):
                            if req.save_to_disk_path is not None:
                                res_image_orig.path_abs = img_out_path

                        del img

                        if has_filters and not thread_data.stop_processing:
                            filters_applied = []
                            if req.use_face_correction:
                                img_data[i] = apply_filters('gfpgan', img_data[i], req.use_face_correction)
                                filters_applied.append(req.use_face_correction)
                            if req.use_upscale:
                                img_data[i] = apply_filters('real_esrgan', img_data[i], req.use_upscale)
                                filters_applied.append(req.use_upscale)
                            if (len(filters_applied) > 0):
                                filtered_image = Image.fromarray(img_data[i])
                                filtered_img_data = img_to_base64_str(filtered_image, req.output_format)
                                response_image = ResponseImage(data=filtered_img_data, seed=opt_seed)
                                res.images.append(response_image)
@@ -513,17 +648,17 @@ def do_mk_img(req: Request):
                                    save_image(filtered_image, filtered_img_out_path)
                                    response_image.path_abs = filtered_img_out_path
                                del filtered_image
                        # Filter Applied, move to next seed
                        opt_seed += 1

                    if thread_data.reduced_memory:
                        unload_filters()
                    del img_data
                    gc()
                    if thread_data.device != 'cpu':
                        print(f'memory_final = {round(torch.cuda.memory_allocated(thread_data.device) / 1e6, 2)}Mo')

    print('Task completed')
    yield json.dumps(res.json())

def save_image(img, img_out_path):
@@ -533,7 +668,7 @@ def save_image(img, img_out_path):
        print('could not save the file', traceback.format_exc())

def save_metadata(meta_out_path, req, prompt, opt_seed):
    metadata = f'''{prompt}
Width: {req.width}
Height: {req.height}
Seed: {opt_seed}
@@ -544,8 +679,8 @@ Use Face Correction: {req.use_face_correction}
Use Upscaling: {req.use_upscale}
Sampler: {req.sampler}
Negative Prompt: {req.negative_prompt}
Stable Diffusion model: {req.use_stable_diffusion_model + '.ckpt'}
'''
    try:
        with open(meta_out_path, 'w', encoding='utf-8') as f:
            f.write(metadata)
@@ -555,16 +690,13 @@ Stable Diffusion Model: {req.use_stable_diffusion_model + '.ckpt'}
def _txt2img(opt_W, opt_H, opt_n_samples, opt_ddim_steps, opt_scale, start_code, opt_C, opt_f, opt_ddim_eta, c, uc, opt_seed, img_callback, mask, sampler_name):
    shape = [opt_n_samples, opt_C, opt_H // opt_f, opt_W // opt_f]

    # Send to CPU and wait until complete.
    wait_model_move_to(thread_data.modelCS, 'cpu')

    if sampler_name == 'ddim':
        thread_data.model.make_schedule(ddim_num_steps=opt_ddim_steps, ddim_eta=opt_ddim_eta, verbose=False)

    samples_ddim = thread_data.model.sample(
        S=opt_ddim_steps,
        conditioning=c,
        seed=opt_seed,
@@ -578,14 +710,13 @@ def _txt2img(opt_W, opt_H, opt_n_samples, opt_ddim_steps, opt_scale, start_code,
        mask=mask,
        sampler = sampler_name,
    )
    yield from samples_ddim

def _img2img(init_latent, t_enc, batch_size, opt_scale, c, uc, opt_ddim_steps, opt_ddim_eta, opt_seed, img_callback, mask):
    # encode (scaled latent)
    z_enc = thread_data.model.stochastic_encode(
        init_latent,
        torch.tensor([t_enc] * batch_size).to(thread_data.device),
        opt_seed,
        opt_ddim_eta,
        opt_ddim_steps,
@@ -593,7 +724,7 @@ def _img2img(init_latent, t_enc, batch_size, opt_scale, c, uc, opt_ddim_steps, o
    x_T = None if mask is None else init_latent

    # decode it
    samples_ddim = thread_data.model.sample(
        t_enc,
        c,
        z_enc,
@@ -604,20 +735,12 @@ def _img2img(init_latent, t_enc, batch_size, opt_scale, c, uc, opt_ddim_steps, o
        x_T=x_T,
        sampler = 'ddim'
    )
    yield from samples_ddim

def gc():
    gc_collect()
    if thread_data.device == 'cpu':
        return
    torch.cuda.empty_cache()
    torch.cuda.ipc_collect()
@@ -627,7 +750,6 @@ def chunk(it, size):
    it = iter(it)
    return iter(lambda: tuple(islice(it, size)), ())

def load_model_from_config(ckpt, verbose=False):
    print(f"Loading model from {ckpt}")
    pl_sd = torch.load(ckpt, map_location="cpu")
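
Note on the pattern in runtime.py above: every torch device is owned by exactly one render thread, and all per-device state lives in a threading.local() instance (thread_data) instead of module-level globals, so two GPUs never share model objects. The snippet below is a minimal, self-contained sketch of that idea only; it is not part of the commit, and names such as fake_load_model and render_worker are made up for illustration.

import threading

thread_data = threading.local()  # each thread sees its own attributes on this object

def fake_load_model(device):
    # stand-in for load_model_ckpt(): stores state on the thread-local object only
    thread_data.device = device
    thread_data.model = f'model-on-{device}'

def render_worker(device):
    fake_load_model(device)  # initialise this thread's own device state
    print(threading.current_thread().name, '->', thread_data.model)

if __name__ == '__main__':
    # one thread per device; to switch devices, a thread exits and a new one is started
    for dev in ('cuda:0', 'cuda:1', 'cpu'):
        t = threading.Thread(target=render_worker, args=(dev,), name=f'Runtime-Render/{dev}')
        t.daemon = True
        t.start()
        t.join()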

task_manager.py

@@ -1,14 +1,25 @@
"""task_manager.py: manage tasks dispatching and render threads.
Notes:
    render_threads should be the only hard reference held by the manager to the threads.
    Use weak_thread_data to store all other data using weak keys.
    This will allow for garbage collection after the thread dies.
"""
import json
import traceback

TASK_TTL = 15 * 60 # seconds, Discard last session's task timeout

import queue, threading, time, weakref
from typing import Any, Generator, Hashable, Optional, Union

from pydantic import BaseModel
from sd_internal import Request, Response

THREAD_NAME_PREFIX = 'Runtime-Render/'
ERR_LOCK_FAILED = ' failed to acquire lock within timeout.'
LOCK_TIMEOUT = 15 # Maximum locking time in seconds before failing a task.
# It's better to get an exception than a deadlock... ALWAYS use timeout in critical paths.

class SymbolClass(type): # Print nicely formatted Symbol names.
    def __repr__(self): return self.__qualname__
    def __str__(self): return self.__name__
@@ -25,7 +36,7 @@ class RenderTask(): # Task with output queue and completion lock.
    def __init__(self, req: Request):
        self.request: Request = req # Initial Request
        self.response: Any = None # Copy of the last reponse
        self.temp_images:list = [None] * req.num_outputs * (1 if req.show_only_filtered_image else 2)
        self.error: Exception = None
        self.lock: threading.Lock = threading.Lock() # Locks at task start and unlocks when task is completed
        self.buffer_queue: queue.Queue = queue.Queue() # Queue of JSON string segments
@@ -66,17 +77,30 @@ class ImageRequest(BaseModel):
    stream_progress_updates: bool = False
    stream_image_progress: bool = False

class FilterRequest(BaseModel):
    session_id: str = "session"
    model: str = None
    name: str = ""
    init_image: str = None # base64
    width: int = 512
    height: int = 512
    save_to_disk_path: str = None
    turbo: bool = True
    use_cpu: bool = False
    use_full_precision: bool = False
    output_format: str = "jpeg" # or "png"

# Temporary cache to allow to query tasks results for a short time after they are completed.
class TaskCache():
    def __init__(self):
        self._base = dict()
        self._lock: threading.Lock = threading.Lock()
    def _get_ttl_time(self, ttl: int) -> int:
        return int(time.time()) + ttl
    def _is_expired(self, timestamp: int) -> bool:
        return int(time.time()) >= timestamp
    def clean(self) -> None:
        if not self._lock.acquire(blocking=True, timeout=LOCK_TIMEOUT): raise Exception('TaskCache.clean' + ERR_LOCK_FAILED)
        try:
            # Create a list of expired keys to delete
            to_delete = []
@@ -91,11 +115,11 @@ class TaskCache():
        finally:
            self._lock.release()
    def clear(self) -> None:
        if not self._lock.acquire(blocking=True, timeout=LOCK_TIMEOUT): raise Exception('TaskCache.clear' + ERR_LOCK_FAILED)
        try: self._base.clear()
        finally: self._lock.release()
    def delete(self, key: Hashable) -> bool:
        if not self._lock.acquire(blocking=True, timeout=LOCK_TIMEOUT): raise Exception('TaskCache.delete' + ERR_LOCK_FAILED)
        try:
            if key not in self._base:
                return False
@@ -104,7 +128,7 @@ class TaskCache():
        finally:
            self._lock.release()
    def keep(self, key: Hashable, ttl: int) -> bool:
        if not self._lock.acquire(blocking=True, timeout=LOCK_TIMEOUT): raise Exception('TaskCache.keep' + ERR_LOCK_FAILED)
        try:
            if key in self._base:
                _, value = self._base.get(key)
@@ -114,7 +138,7 @@ class TaskCache():
        finally:
            self._lock.release()
    def put(self, key: Hashable, value: Any, ttl: int) -> bool:
        if not self._lock.acquire(blocking=True, timeout=LOCK_TIMEOUT): raise Exception('TaskCache.put' + ERR_LOCK_FAILED)
        try:
            self._base[key] = (
                self._get_ttl_time(ttl), value
@@ -128,23 +152,26 @@ class TaskCache():
        finally:
            self._lock.release()
    def tryGet(self, key: Hashable) -> Any:
        if not self._lock.acquire(blocking=True, timeout=LOCK_TIMEOUT): raise Exception('TaskCache.tryGet' + ERR_LOCK_FAILED)
        try:
            ttl, value = self._base.get(key, (None, None))
            if ttl is not None and self._is_expired(ttl):
                print(f'Session {key} expired. Discarding data.')
                del self._base[key]
                return None
            return value
        finally:
            self._lock.release()

manager_lock = threading.RLock()
render_threads = []
current_state = ServerStates.Init
current_state_error:Exception = None
current_model_path = None
tasks_queue = []
task_cache = TaskCache()
default_model_to_load = None
weak_thread_data = weakref.WeakKeyDictionary()

def preload_model(file_path=None):
    global current_state, current_state_error, current_model_path
@@ -155,7 +182,8 @@ def preload_model(file_path=None):
    current_state = ServerStates.LoadingModel
    try:
        from . import runtime
        runtime.thread_data.ckpt_file = file_path
        runtime.load_model_ckpt()
        current_model_path = file_path
        current_state_error = None
        current_state = ServerStates.Online
@@ -165,72 +193,130 @@ def preload_model(file_path=None):
        current_state = ServerStates.Unavailable
        print(traceback.format_exc())

def thread_get_next_task():
    if not manager_lock.acquire(blocking=True, timeout=LOCK_TIMEOUT):
        print('Render thread on device', runtime.thread_data.device, 'failed to acquire manager lock.')
        return None
    if len(tasks_queue) <= 0:
        manager_lock.release()
        return None
    from . import runtime
    task = None
    try: # Select a render task.
        for queued_task in tasks_queue:
            if queued_task.request.use_face_correction: # TODO Remove when fixed - A bug with GFPGANer and facexlib needs to be fixed before use on other devices.
                if is_alive(0) <= 0: # Allows GFPGANer only on cuda:0.
                    queued_task.error = Exception('cuda:0 is not available with the current config. Remove GFPGANer filter to run task.')
                    task = queued_task
                    break
                if queued_task.request.use_cpu:
                    queued_task.error = Exception('Cpu cannot be used to run this task. Remove GFPGANer filter to run task.')
                    task = queued_task
                    break
                if not runtime.is_first_cuda_device(runtime.thread_data.device):
                    continue # Wait for cuda:0
            if queued_task.request.use_cpu and runtime.thread_data.device != 'cpu':
                if is_alive('cpu') > 0:
                    continue # CPU Tasks, Skip GPU device
                else:
                    queued_task.error = Exception('Cpu is not enabled in render_devices.')
                    task = queued_task
                    break
            if not queued_task.request.use_cpu and runtime.thread_data.device == 'cpu':
                if is_alive() > 1: # cpu is alive, so need more than one.
                    continue # GPU Tasks, don't run on CPU unless there is nothing else.
                else:
                    queued_task.error = Exception('No active gpu found. Please check the error message in the command-line window at startup.')
                    task = queued_task
                    break
            task = queued_task
            break
        if task is not None:
            del tasks_queue[tasks_queue.index(task)]
        return task
    finally:
        manager_lock.release()

def thread_render(device):
    global current_state, current_state_error, current_model_path
    from . import runtime
    weak_thread_data[threading.current_thread()] = {
        'device': device
    }
    try:
        runtime.device_init(device)
    except:
        print(traceback.format_exc())
        return
    weak_thread_data[threading.current_thread()] = {
        'device': runtime.thread_data.device
    }
    if runtime.thread_data.device != 'cpu':
        preload_model()
    current_state = ServerStates.Online
    while True:
        task_cache.clean()
        if isinstance(current_state_error, SystemExit):
            current_state = ServerStates.Unavailable
            return
        task = thread_get_next_task()
        if task is None:
            time.sleep(1)
            continue
        if task.error is not None:
            print(task.error)
            task.response = {"status": 'failed', "detail": str(task.error)}
            task.buffer_queue.put(json.dumps(task.response))
            continue
        if current_state_error:
            task.error = current_state_error
            task.response = {"status": 'failed', "detail": str(task.error)}
            task.buffer_queue.put(json.dumps(task.response))
            continue
        print(f'Session {task.request.session_id} starting task {id(task)}')
        if not task.lock.acquire(blocking=False): raise Exception('Got locked task from queue.')
        try:
            # Open data generator.
            res = runtime.mk_img(task.request)
            if current_model_path == task.request.use_stable_diffusion_model:
                current_state = ServerStates.Rendering
            else:
                current_state = ServerStates.LoadingModel
            # Start reading from generator.
            dataQueue = None
            if task.request.stream_progress_updates:
                dataQueue = task.buffer_queue
            for result in res:
                if current_state == ServerStates.LoadingModel:
                    current_state = ServerStates.Rendering
                    current_model_path = task.request.use_stable_diffusion_model
                if isinstance(current_state_error, SystemExit) or isinstance(current_state_error, StopAsyncIteration) or isinstance(task.error, StopAsyncIteration):
                    runtime.thread_data.stop_processing = True
                    if isinstance(current_state_error, StopAsyncIteration):
                        task.error = current_state_error
                        current_state_error = None
                        print(f'Session {task.request.session_id} sent cancel signal for task {id(task)}')
                if dataQueue:
                    dataQueue.put(result)
                if isinstance(result, str):
                    result = json.loads(result)
                task.response = result
                if 'output' in result:
                    for out_obj in result['output']:
                        if 'path' in out_obj:
                            img_id = out_obj['path'][out_obj['path'].rindex('/') + 1:]
                            task.temp_images[int(img_id)] = runtime.thread_data.temp_images[out_obj['path'][11:]]
                        elif 'data' in out_obj:
                            task.temp_images[result['output'].index(out_obj)] = out_obj['data']
                # Before looping back to the generator, mark cache as still alive.
                task_cache.keep(task.request.session_id, TASK_TTL)
        except Exception as e:
            task.error = e
            print(traceback.format_exc())
            continue
        finally:
            # Task completed
            task.lock.release()
        task_cache.keep(task.request.session_id, TASK_TTL)
        if isinstance(task.error, StopAsyncIteration):
            print(f'Session {task.request.session_id} task {id(task)} cancelled!')
@@ -240,19 +326,62 @@ def thread_render():
        print(f'Session {task.request.session_id} task {id(task)} completed.')
        current_state = ServerStates.Online

def get_cached_task(session_id:str, update_ttl:bool=False):
    # By calling keep before tryGet, wont discard if was expired.
    if update_ttl and not task_cache.keep(session_id, TASK_TTL):
        # Failed to keep task, already gone.
        return None
    return task_cache.tryGet(session_id)

def is_first_cuda_device(device):
    from . import runtime # When calling runtime from outside thread_render DO NOT USE thread specific attributes or functions.
    return runtime.is_first_cuda_device(device)

def is_alive(name=None):
    if not manager_lock.acquire(blocking=True, timeout=LOCK_TIMEOUT): raise Exception('is_alive' + ERR_LOCK_FAILED)
    nbr_alive = 0
    try:
        for rthread in render_threads:
            if name is not None:
                weak_data = weak_thread_data.get(rthread)
                if weak_data is None or weak_data['device'] is None:
                    print('The thread', rthread.name, 'is registered but has no data store in the task manager.')
                    continue
                thread_name = str(weak_data['device']).lower()
                if is_first_cuda_device(name):
                    if not is_first_cuda_device(thread_name):
                        continue
                elif thread_name != name:
                    continue
            if rthread.is_alive():
                nbr_alive += 1
        return nbr_alive
    finally:
        manager_lock.release()

def start_render_thread(device='auto'):
    if not manager_lock.acquire(blocking=True, timeout=LOCK_TIMEOUT): raise Exception('start_render_threads' + ERR_LOCK_FAILED)
    print('Start new Rendering Thread on device', device)
    try:
        rthread = threading.Thread(target=thread_render, kwargs={'device': device})
        rthread.daemon = True
        rthread.name = THREAD_NAME_PREFIX + device
        rthread.start()
        timeout = LOCK_TIMEOUT
        while not rthread.is_alive():
            if timeout <= 0: raise Exception('render_thread', rthread.name, 'failed to start before timeout or has crashed.')
            timeout -= 1
            time.sleep(1)
        render_threads.append(rthread)
    finally:
        manager_lock.release()

def shutdown_event(): # Signal render thread to close on shutdown
    global current_state_error
    current_state_error = SystemExit('Application shutting down.')

def render(req : ImageRequest):
    if is_alive() <= 0: # Render thread is dead
        raise ChildProcessError('Rendering thread has died.')
    # Alive, check if task in cache
    task = task_cache.tryGet(req.session_id)
@@ -294,6 +423,12 @@ def render(req : ImageRequest):
    new_task = RenderTask(r)
    if task_cache.put(r.session_id, new_task, TASK_TTL):
        # Tries to force task_cache.put to fail before tasks_queue.put would.
        if manager_lock.acquire(blocking=True, timeout=LOCK_TIMEOUT * 2):
            try:
                tasks_queue.append(new_task)
                return new_task
            finally:
                manager_lock.release()
    raise RuntimeError('Failed to add task to cache.')
View File
@ -1,3 +1,7 @@
"""server.py: FastAPI SD-UI Web Host.
Notes:
async endpoints always run on the main thread. Without async they run on the thread pool.
"""
import json import json
import traceback import traceback
@ -16,17 +20,29 @@ UI_PLUGINS_DIR = os.path.abspath(os.path.join(SD_DIR, '..', 'plugins', 'ui'))
OUTPUT_DIRNAME = "Stable Diffusion UI" # in the user's home folder OUTPUT_DIRNAME = "Stable Diffusion UI" # in the user's home folder
TASK_TTL = 15 * 60 # Discard last session's task timeout TASK_TTL = 15 * 60 # Discard last session's task timeout
APP_CONFIG_DEFAULTS = {
# auto: selects the cuda device with the most free memory, cuda: use the currently active cuda device.
'render_devices': ['auto'], # ['cuda'] or ['CPU', 'GPU:0', 'GPU:1', ...] or ['cpu']
'update_branch': 'main',
}
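For reference, a minimal sketch (hypothetical helper, not code from this diff) of how the 'auto' behaviour described above, picking the CUDA device with the most free memory, could be implemented with torch:

import torch

def pick_device_with_most_free_memory():
    # Hypothetical helper: scan the visible CUDA devices and return the one
    # with the most free memory, falling back to the CPU when CUDA is absent.
    if not torch.cuda.is_available():
        return 'cpu'
    best_device, best_free = 'cuda:0', 0
    for i in range(torch.cuda.device_count()):
        free, _total = torch.cuda.mem_get_info(i)
        if free > best_free:
            best_device, best_free = f'cuda:{i}', free
    return best_device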
APP_CONFIG_DEFAULT_MODELS = [
# needed to support the legacy installations
'custom-model', # Check if user has a custom model, use it first.
'sd-v1-4', # Default fallback.
]
import asyncio
from fastapi import FastAPI, HTTPException from fastapi import FastAPI, HTTPException
from fastapi.staticfiles import StaticFiles from fastapi.staticfiles import StaticFiles
from starlette.responses import FileResponse, JSONResponse, StreamingResponse from starlette.responses import FileResponse, JSONResponse, StreamingResponse
from pydantic import BaseModel from pydantic import BaseModel
import logging import logging
import queue, threading, time #import queue, threading, time
from typing import Any, Generator, Hashable, Optional, Union from typing import Any, Generator, Hashable, List, Optional, Union
from sd_internal import Request, Response, task_manager from sd_internal import Request, Response, task_manager
LOOP = asyncio.get_event_loop()
app = FastAPI() app = FastAPI()
modifiers_cache = None modifiers_cache = None
@ -42,191 +58,135 @@ NOCACHE_HEADERS={"Cache-Control": "no-cache, no-store, must-revalidate", "Pragma
app.mount('/media', StaticFiles(directory=os.path.join(SD_UI_DIR, 'media')), name="media") app.mount('/media', StaticFiles(directory=os.path.join(SD_UI_DIR, 'media')), name="media")
app.mount('/plugins', StaticFiles(directory=UI_PLUGINS_DIR), name="plugins") app.mount('/plugins', StaticFiles(directory=UI_PLUGINS_DIR), name="plugins")
class SetAppConfigRequest(BaseModel): config_cached = None
update_branch: str = "main" config_last_mod_time = 0
def getConfig(default_val=APP_CONFIG_DEFAULTS):
# needs to support the legacy installations global config_cached, config_last_mod_time
def get_initial_model_to_load():
custom_weight_path = os.path.join(SD_DIR, 'custom-model.ckpt')
ckpt_to_use = "sd-v1-4" if not os.path.exists(custom_weight_path) else "custom-model"
ckpt_to_use = os.path.join(SD_DIR, ckpt_to_use)
config = getConfig()
if 'model' in config and 'stable-diffusion' in config['model']:
model_name = config['model']['stable-diffusion']
model_path = resolve_model_to_use(model_name)
if os.path.exists(model_path + '.ckpt'):
ckpt_to_use = model_path
else:
print('Could not find the configured custom model at:', model_path + '.ckpt', '. Using the default one:', ckpt_to_use + '.ckpt')
return ckpt_to_use
def resolve_model_to_use(model_name):
if model_name in ('sd-v1-4', 'custom-model'):
model_path = os.path.join(MODELS_DIR, 'stable-diffusion', model_name)
legacy_model_path = os.path.join(SD_DIR, model_name)
if not os.path.exists(model_path + '.ckpt') and os.path.exists(legacy_model_path + '.ckpt'):
model_path = legacy_model_path
else:
model_path = os.path.join(MODELS_DIR, 'stable-diffusion', model_name)
return model_path
@app.on_event("shutdown")
def shutdown_event(): # Signal render thread to close on shutdown
task_manager.current_state_error = SystemExit('Application shutting down.')
@app.get('/')
def read_root():
return FileResponse(os.path.join(SD_UI_DIR, 'index.html'), headers=NOCACHE_HEADERS)
@app.get('/ping') # Get server and optionally session status.
def ping(session_id:str=None):
if not task_manager.render_thread.is_alive(): # Render thread is dead.
if task_manager.current_state_error: raise HTTPException(status_code=500, detail=str(task_manager.current_state_error))
raise HTTPException(status_code=500, detail='Render thread is dead.')
if task_manager.current_state_error and not isinstance(task_manager.current_state_error, StopAsyncIteration): raise HTTPException(status_code=500, detail=str(task_manager.current_state_error))
# Alive
response = {'status': str(task_manager.current_state)}
if session_id:
task = task_manager.task_cache.tryGet(session_id)
if task:
response['task'] = id(task)
if task.lock.locked():
response['session'] = 'running'
elif isinstance(task.error, StopAsyncIteration):
response['session'] = 'stopped'
elif task.error:
response['session'] = 'error'
elif not task.buffer_queue.empty():
response['session'] = 'buffer'
elif task.response:
response['session'] = 'completed'
else:
response['session'] = 'pending'
return JSONResponse(response, headers=NOCACHE_HEADERS)
def save_model_to_config(model_name):
config = getConfig()
if 'model' not in config:
config['model'] = {}
config['model']['stable-diffusion'] = model_name
setConfig(config)
@app.post('/render')
def render(req : task_manager.ImageRequest):
try:
save_model_to_config(req.use_stable_diffusion_model)
req.use_stable_diffusion_model = resolve_model_to_use(req.use_stable_diffusion_model)
new_task = task_manager.render(req)
response = {
'status': str(task_manager.current_state),
'queue': task_manager.tasks_queue.qsize(),
'stream': f'/image/stream/{req.session_id}/{id(new_task)}',
'task': id(new_task)
}
return JSONResponse(response, headers=NOCACHE_HEADERS)
except ChildProcessError as e: # Render thread is dead
raise HTTPException(status_code=500, detail=f'Rendering thread has died.') # HTTP500 Internal Server Error
except ConnectionRefusedError as e: # Unstarted task pending, deny queueing more than one.
raise HTTPException(status_code=503, detail=f'Session {req.session_id} has an already pending task.') # HTTP503 Service Unavailable
except Exception as e:
raise HTTPException(status_code=500, detail=str(e))
@app.get('/image/stream/{session_id:str}/{task_id:int}')
def stream(session_id:str, task_id:int):
#TODO Move to WebSockets ??
task = task_manager.task_cache.tryGet(session_id)
if not task: raise HTTPException(status_code=410, detail='No request received.') # HTTP410 Gone
if (id(task) != task_id): raise HTTPException(status_code=409, detail=f'Wrong task id received. Expected:{id(task)}, Received:{task_id}') # HTTP409 Conflict
if task.buffer_queue.empty() and not task.lock.locked():
if task.response:
#print(f'Session {session_id} sending cached response')
return JSONResponse(task.response, headers=NOCACHE_HEADERS)
raise HTTPException(status_code=425, detail='Too Early, task not started yet.') # HTTP425 Too Early
#print(f'Session {session_id} opened live render stream {id(task.buffer_queue)}')
return StreamingResponse(task.read_buffer_generator(), media_type='application/json')
@app.get('/image/stop')
def stop(session_id:str=None):
if not session_id:
if task_manager.current_state == task_manager.ServerStates.Online or task_manager.current_state == task_manager.ServerStates.Unavailable:
raise HTTPException(status_code=409, detail='Not currently running any tasks.') # HTTP409 Conflict
task_manager.current_state_error = StopAsyncIteration('')
return {'OK'}
task = task_manager.task_cache.tryGet(session_id)
if not task: raise HTTPException(status_code=404, detail=f'Session {session_id} has no active task.') # HTTP404 Not Found
if isinstance(task.error, StopAsyncIteration): raise HTTPException(status_code=409, detail=f'Session {session_id} task is already stopped.') # HTTP409 Conflict
task.error = StopAsyncIteration('')
return {'OK'}
@app.get('/image/tmp/{session_id}/{img_id:int}')
def get_image(session_id, img_id):
task = task_manager.task_cache.tryGet(session_id)
if not task: raise HTTPException(status_code=410, detail=f'Session {session_id} has not submitted a task.') # HTTP410 Gone
if not task.temp_images[img_id]: raise HTTPException(status_code=425, detail='Too Early, task data is not available yet.') # HTTP425 Too Early
try:
img_data = task.temp_images[img_id]
if isinstance(img_data, str):
return img_data
img_data.seek(0)
return StreamingResponse(img_data, media_type='image/jpeg')
except KeyError as e:
raise HTTPException(status_code=500, detail=str(e))
@app.post('/app_config')
async def setAppConfig(req : SetAppConfigRequest):
try:
config = {
'update_branch': req.update_branch
}
config_json_str = json.dumps(config)
config_bat_str = f'@set update_branch={req.update_branch}'
config_sh_str = f'export update_branch={req.update_branch}'
config_json_path = os.path.join(CONFIG_DIR, 'config.json')
config_bat_path = os.path.join(CONFIG_DIR, 'config.bat')
config_sh_path = os.path.join(CONFIG_DIR, 'config.sh')
with open(config_json_path, 'w', encoding='utf-8') as f:
f.write(config_json_str)
with open(config_bat_path, 'w', encoding='utf-8') as f:
f.write(config_bat_str)
with open(config_sh_path, 'w', encoding='utf-8') as f:
f.write(config_sh_str)
return {'OK'}
except Exception as e:
print(traceback.format_exc())
raise HTTPException(status_code=500, detail=str(e))
def getConfig(default_val={}):
try: try:
config_json_path = os.path.join(CONFIG_DIR, 'config.json') config_json_path = os.path.join(CONFIG_DIR, 'config.json')
if not os.path.exists(config_json_path): if not os.path.exists(config_json_path):
return default_val return default_val
if config_last_mod_time > 0 and config_cached is not None:
# Don't read if file was not modified
mtime = os.path.getmtime(config_json_path)
if mtime <= config_last_mod_time:
return config_cached
with open(config_json_path, 'r', encoding='utf-8') as f: with open(config_json_path, 'r', encoding='utf-8') as f:
return json.load(f) config_cached = json.load(f)
config_last_mod_time = os.path.getmtime(config_json_path)
return config_cached
except Exception as e: except Exception as e:
print(str(e)) print(str(e))
print(traceback.format_exc()) print(traceback.format_exc())
return default_val return default_val
def setConfig(config): def setConfig(config):
try: try: # config.json
config_json_path = os.path.join(CONFIG_DIR, 'config.json') config_json_path = os.path.join(CONFIG_DIR, 'config.json')
with open(config_json_path, 'w', encoding='utf-8') as f: with open(config_json_path, 'w', encoding='utf-8') as f:
return json.dump(config, f) return json.dump(config, f)
except Exception as e: except:
print(str(e))
print(traceback.format_exc()) print(traceback.format_exc())
if 'render_devices' in config:
gpu_devices = list(filter(lambda dev: dev.lower().startswith('gpu') or dev.lower().startswith('cuda'), config['render_devices']))
else:
gpu_devices = []
has_first_cuda_device = False
for device in gpu_devices:
if not task_manager.is_first_cuda_device(device): continue
has_first_cuda_device = True
break
if len(gpu_devices) > 0 and not has_first_cuda_device:
print('WARNING: GFPGANer only works on GPU:0, use CUDA_VISIBLE_DEVICES if GFPGANer is needed on a specific GPU.')
print('Using CUDA_VISIBLE_DEVICES will remap the selected devices starting at GPU:0, fixing GFPGANer')
try: # config.bat
config_bat = [
f"@set update_branch={config['update_branch']}"
]
if len(gpu_devices) > 0 and not has_first_cuda_device:
config_bat.append('::Set the devices visible inside SD-UI here')
config_bat.append(f"::@set CUDA_VISIBLE_DEVICES={','.join(gpu_devices)}") # Needs better detection for edge cases, add as a comment for now.
print('Add the line "@set CUDA_VISIBLE_DEVICES=N" to config.bat, where N lists the GPUs to use')
config_bat_path = os.path.join(CONFIG_DIR, 'config.bat')
with open(config_bat_path, 'w', encoding='utf-8') as f:
f.write('\r\n'.join(config_bat))
except Exception as e:
print(traceback.format_exc())
try: # config.sh
config_sh = [
'#!/bin/bash',
f"export update_branch={config['update_branch']}"
]
if len(gpu_devices) > 0 and not has_first_cuda_device:
config_sh.append('#Set the devices visible inside SD-UI here')
config_sh.append(f"#CUDA_VISIBLE_DEVICES={','.join(gpu_devices)}") # Needs better detection for edge cases, add as a comment for now.
print('Add the line "CUDA_VISIBLE_DEVICES=N" to config.sh, where N lists the GPUs to use')
config_sh_path = os.path.join(CONFIG_DIR, 'config.sh')
with open(config_sh_path, 'w', encoding='utf-8') as f:
f.write('\n'.join(config_sh))
except Exception as e:
print(traceback.format_exc())
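To illustrate the CUDA_VISIBLE_DEVICES remapping referred to above (values are examples, not part of this diff): the variable must be set before CUDA is initialised, and the listed GPUs are renumbered from 0 inside the process, which is what lets GFPGANer's hard-coded GPU:0 land on a different physical card.

import os
os.environ['CUDA_VISIBLE_DEVICES'] = '2,3'  # hypothetical: expose only physical GPUs 2 and 3

import torch  # imported after setting the variable so the mask is honoured
print(torch.cuda.device_count())  # on a 4-GPU machine this reports 2; physical GPU 2 is now cuda:0 / GPU:0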
def resolve_model_to_use(model_name:str=None):
if not model_name: # When None try user configured model.
config = getConfig()
if 'model' in config and 'stable-diffusion' in config['model']:
model_name = config['model']['stable-diffusion']
if model_name:
if os.path.exists(model_name + '.ckpt'):
# Direct Path to file
return model_name
# Check models directory
models_dir_path = os.path.join(MODELS_DIR, 'stable-diffusion', model_name)
if os.path.exists(models_dir_path + '.ckpt'):
return models_dir_path
# Default locations
if model_name in APP_CONFIG_DEFAULT_MODELS:
default_model_path = os.path.join(SD_DIR, model_name)
if os.path.exists(default_model_path + '.ckpt'):
return default_model_path
# Can't find requested model, check the default paths.
for default_model in APP_CONFIG_DEFAULT_MODELS:
default_model_path = os.path.join(SD_DIR, default_model)
if os.path.exists(default_model_path + '.ckpt'):
print(f'Could not find the configured custom model {model_name}.ckpt. Using the default one: {default_model_path}.ckpt')
return default_model_path
raise Exception('No valid models found.')
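A short usage sketch of the lookup order above (call sites are hypothetical):

model_path = resolve_model_to_use('sd-v1-4')  # direct path first, then MODELS_DIR/stable-diffusion, then SD_DIR defaults
model_path = resolve_model_to_use()           # no name: reads config['model']['stable-diffusion'] first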
class SetAppConfigRequest(BaseModel):
update_branch: str = None
render_devices: Union[List[str], List[int], str, int] = None
@app.post('/app_config')
async def setAppConfig(req : SetAppConfigRequest):
config = getConfig()
if req.update_branch:
config['update_branch'] = req.update_branch
if req.render_devices and hasattr(req.render_devices, "__len__"): # strings, array of strings or numbers.
render_devices = []
if isinstance(req.render_devices, str):
req.render_devices = req.render_devices.split(',')
if isinstance(req.render_devices, list):
for gpu in req.render_devices:
if isinstance(gpu, int):
render_devices.append('GPU:' + str(gpu))
else:
render_devices.append(gpu)
if isinstance(req.render_devices, int):
render_devices.append('GPU:' + str(req.render_devices))
if len(render_devices) > 0:
config['render_devices'] = render_devices
try:
setConfig(config)
return JSONResponse({'status': 'OK'}, headers=NOCACHE_HEADERS)
except Exception as e:
print(traceback.format_exc())
raise HTTPException(status_code=500, detail=str(e))
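A hedged usage sketch for this endpoint (URL, port and values are examples; the requests library is assumed to be available):

import requests

resp = requests.post('http://localhost:9000/app_config', json={
    'update_branch': 'main',
    'render_devices': 'GPU:0,GPU:1',  # also accepted: a list such as ['GPU:0', 'GPU:1']
})
print(resp.json())  # {'status': 'OK'} on success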
def getModels(): def getModels():
models = { models = {
'active': { 'active': {
@ -282,6 +242,112 @@ def read_web_data(key:str=None):
else: else:
raise HTTPException(status_code=404, detail=f'Request for unknown {key}') # HTTP404 Not Found raise HTTPException(status_code=404, detail=f'Request for unknown {key}') # HTTP404 Not Found
@app.get('/ping') # Get server and optionally session status.
def ping(session_id:str=None):
if task_manager.is_alive() <= 0: # Check that render threads are alive.
if task_manager.current_state_error: raise HTTPException(status_code=500, detail=str(task_manager.current_state_error))
raise HTTPException(status_code=500, detail='Render thread is dead.')
if task_manager.current_state_error and not isinstance(task_manager.current_state_error, StopAsyncIteration): raise HTTPException(status_code=500, detail=str(task_manager.current_state_error))
# Alive
response = {'status': str(task_manager.current_state)}
if session_id:
task = task_manager.get_cached_task(session_id, update_ttl=True)
if task:
response['task'] = id(task)
if task.lock.locked():
response['session'] = 'running'
elif isinstance(task.error, StopAsyncIteration):
response['session'] = 'stopped'
elif task.error:
response['session'] = 'error'
elif not task.buffer_queue.empty():
response['session'] = 'buffer'
elif task.response:
response['session'] = 'completed'
else:
response['session'] = 'pending'
return JSONResponse(response, headers=NOCACHE_HEADERS)
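For illustration, a client could poll this endpoint to follow a session (hypothetical snippet; the session id and port are assumptions):

import time
import requests

session_id = 'example-session'  # hypothetical id
while True:
    status = requests.get('http://localhost:9000/ping', params={'session_id': session_id}).json()
    # 'session' is one of: pending, running, buffer, completed, stopped, error
    if status.get('session') in ('completed', 'stopped', 'error'):
        break
    time.sleep(1)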
def save_model_to_config(model_name):
config = getConfig()
if 'model' not in config:
config['model'] = {}
config['model']['stable-diffusion'] = model_name
setConfig(config)
@app.post('/render')
def render(req : task_manager.ImageRequest):
if req.use_cpu and task_manager.is_alive('cpu') <= 0: raise HTTPException(status_code=403, detail=f'CPU rendering is not enabled in config.json or the thread has died...') # HTTP403 Forbidden
if req.use_face_correction and task_manager.is_alive(0) <= 0: #TODO Remove when GFPGANer is fixed upstream.
raise HTTPException(status_code=412, detail=f'GFPGANer only works on GPU:0, use CUDA_VISIBLE_DEVICES if GFPGANer is needed on a specific GPU.') # HTTP412 Precondition Failed
try:
save_model_to_config(req.use_stable_diffusion_model)
req.use_stable_diffusion_model = resolve_model_to_use(req.use_stable_diffusion_model)
new_task = task_manager.render(req)
response = {
'status': str(task_manager.current_state),
'queue': len(task_manager.tasks_queue),
'stream': f'/image/stream/{req.session_id}/{id(new_task)}',
'task': id(new_task)
}
return JSONResponse(response, headers=NOCACHE_HEADERS)
except ChildProcessError as e: # Render thread is dead
raise HTTPException(status_code=500, detail=f'Rendering thread has died.') # HTTP500 Internal Server Error
except ConnectionRefusedError as e: # Unstarted task pending, deny queueing more than one.
raise HTTPException(status_code=503, detail=f'Session {req.session_id} has an already pending task.') # HTTP503 Service Unavailable
except Exception as e:
raise HTTPException(status_code=500, detail=str(e))
@app.get('/image/stream/{session_id:str}/{task_id:int}')
def stream(session_id:str, task_id:int):
#TODO Move to WebSockets ??
task = task_manager.get_cached_task(session_id, update_ttl=True)
if not task: raise HTTPException(status_code=410, detail='No request received.') # HTTP410 Gone
if (id(task) != task_id): raise HTTPException(status_code=409, detail=f'Wrong task id received. Expected:{id(task)}, Received:{task_id}') # HTTP409 Conflict
if task.buffer_queue.empty() and not task.lock.locked():
if task.response:
#print(f'Session {session_id} sending cached response')
return JSONResponse(task.response, headers=NOCACHE_HEADERS)
raise HTTPException(status_code=425, detail='Too Early, task not started yet.') # HTTP425 Too Early
#print(f'Session {session_id} opened live render stream {id(task.buffer_queue)}')
return StreamingResponse(task.read_buffer_generator(), media_type='application/json')
@app.get('/image/stop')
def stop(session_id:str=None):
if not session_id:
if task_manager.current_state == task_manager.ServerStates.Online or task_manager.current_state == task_manager.ServerStates.Unavailable:
raise HTTPException(status_code=409, detail='Not currently running any tasks.') # HTTP409 Conflict
task_manager.current_state_error = StopAsyncIteration('')
return {'OK'}
task = task_manager.get_cached_task(session_id, update_ttl=False)
if not task: raise HTTPException(status_code=404, detail=f'Session {session_id} has no active task.') # HTTP404 Not Found
if isinstance(task.error, StopAsyncIteration): raise HTTPException(status_code=409, detail=f'Session {session_id} task is already stopped.') # HTTP409 Conflict
task.error = StopAsyncIteration('')
return {'OK'}
@app.get('/image/tmp/{session_id}/{img_id:int}')
def get_image(session_id, img_id):
task = task_manager.get_cached_task(session_id, update_ttl=True)
if not task: raise HTTPException(status_code=410, detail=f'Session {session_id} has not submitted a task.') # HTTP410 Gone
if not task.temp_images[img_id]: raise HTTPException(status_code=425, detail='Too Early, task data is not available yet.') # HTTP425 Too Early
try:
img_data = task.temp_images[img_id]
if isinstance(img_data, str):
return img_data
img_data.seek(0)
return StreamingResponse(img_data, media_type='image/jpeg')
except KeyError as e:
raise HTTPException(status_code=500, detail=str(e))
@app.get('/')
def read_root():
return FileResponse(os.path.join(SD_UI_DIR, 'index.html'), headers=NOCACHE_HEADERS)
@app.on_event("shutdown")
def shutdown_event(): # Signal render thread to close on shutdown
task_manager.current_state_error = SystemExit('Application shutting down.')
# don't log certain requests # don't log certain requests
class LogSuppressFilter(logging.Filter): class LogSuppressFilter(logging.Filter):
def filter(self, record: logging.LogRecord) -> bool: def filter(self, record: logging.LogRecord) -> bool:
@ -292,8 +358,56 @@ class LogSuppressFilter(logging.Filter):
return True return True
logging.getLogger('uvicorn.access').addFilter(LogSuppressFilter()) logging.getLogger('uvicorn.access').addFilter(LogSuppressFilter())
task_manager.default_model_to_load = get_initial_model_to_load() config = getConfig()
task_manager.start_render_thread()
async def check_status(): # Task to Validate user config shortly after startup.
# Check that the loaded config.json yielded a server in a known valid state.
# When issues are found, try to fix them when possible and warn the user.
device_count = 0
# Wait for devices to register and/or change names.
THREAD_START_DELAY = 5 # seconds - Give time for devices/threads to start.
for i in range(10): # Maximum number of retries.
await asyncio.sleep(THREAD_START_DELAY)
new_count = task_manager.is_alive()
# Stop retrying once no more devices show up.
if new_count > 0 and device_count == new_count: break
device_count = new_count
if 'render_devices' in config and task_manager.is_alive() <= 0: # No running devices, probably invalid user config. Try to apply defaults.
print('WARNING: No active render devices after loading config. Validate "render_devices" in config.json')
task_manager.start_render_thread('auto') # Detect best device for renders
task_manager.start_render_thread('cpu') # Allow CPU to be used for renders
await asyncio.sleep(THREAD_START_DELAY) # delay message after thread start.
print('Default render devices loaded to replace the invalid render_devices in config.json:', config['render_devices'])
display_warning = False
if 'render_devices' not in config and task_manager.is_alive(0) <= 0: # No config set, running in auto mode and cuda:0 is not active
task_manager.start_render_thread('cuda') # Another cuda device was selected as better and cuda:0 is missing, start it...
display_warning = True # And warn user to update settings...
await asyncio.sleep(THREAD_START_DELAY) # delay message after thread start.
if display_warning or task_manager.is_alive(0) <= 0:
print('WARNING: GFPGANer only works on GPU:0, use CUDA_VISIBLE_DEVICES if GFPGANer is needed on a specific GPU.')
print('Using CUDA_VISIBLE_DEVICES will remap the selected devices starting at GPU:0, fixing GFPGANer')
print('Add the line "@set CUDA_VISIBLE_DEVICES=N" to config.bat, where N lists the GPUs to use')
print('Add the line "CUDA_VISIBLE_DEVICES=N" to config.sh, where N lists the GPUs to use')
# Start the task_manager
task_manager.default_model_to_load = resolve_model_to_use()
if 'render_devices' in config: # Start a new thread for each device.
if isinstance(config['render_devices'], str):
config['render_devices'] = config['render_devices'].split(',')
if not isinstance(config['render_devices'], list):
raise Exception('Invalid render_devices value in config.')
for device in config['render_devices']:
task_manager.start_render_thread(device)
else:
# Select best GPU device using free memory, if more than one device.
task_manager.start_render_thread('auto') # Detect best device for renders
task_manager.start_render_thread('cpu') # Allow CPU to be used for renders
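An illustrative render_devices entry matching the parsing above (values are examples): it can be a list or a comma-separated string, and CPU and GPU workers can be mixed.

example_config = {  # hypothetical contents of config.json
    'update_branch': 'main',
    'render_devices': ['GPU:0', 'GPU:1', 'cpu'],  # or the string 'GPU:0,GPU:1,cpu'
}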
# Task to Validate user config shortly after startup.
LOOP.create_task(check_status())
# start the browser ui # start the browser ui
import webbrowser; webbrowser.open('http://localhost:9000') import webbrowser; webbrowser.open('http://localhost:9000')