Merge pull request #404 from cmdr2/multi-gpu

Support for multiple GPUs, and improvements to RAM and VRAM usage
cmdr2 2022-10-27 23:52:11 +05:30 committed by GitHub
commit 5d8bda1178
4 changed files with 883 additions and 502 deletions

index.html

@@ -18,7 +18,7 @@
<div id="container">
<div id="top-nav">
<div id="logo">
<h1>Stable Diffusion UI <small>v2.3.5 <span id="updateBranchLabel"></span></small></h1>
<h1>Stable Diffusion UI <small>v2.3.6 <span id="updateBranchLabel"></span></small></h1>
</div>
<ul id="top-nav-items">
<li class="dropdown">

sd_internal/runtime.py

@@ -1,8 +1,15 @@
"""runtime.py: torch device owned by a thread.
Notes:
Avoid device switching; transferring all the models would get too complex.
To use a different device, signal the current render thread to exit,
then start a new, clean thread for the new device.
"""
import json
import os, re
import traceback
import torch
import numpy as np
from gc import collect as gc_collect
from omegaconf import OmegaConf
from PIL import Image, ImageOps
from tqdm import tqdm, trange
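The docstring above is the heart of this PR's threading model: a render thread binds one device for its whole lifetime instead of shuttling models around. A minimal standalone sketch of that pattern (illustrative names, not code from this commit):

import queue
import threading

_local = threading.local()  # per-thread storage: each render thread owns its device

def render_worker(device: str, tasks: queue.Queue):
    _local.device = device  # bound once; models this thread loads stay on this device
    while True:
        task = tasks.get()
        if task is None:  # to switch devices, retire this thread...
            break         # ...and start a clean one for the new device
        print(f'{_local.device}: rendering {task}')

tasks = queue.Queue()
worker = threading.Thread(target=render_worker, args=('cuda:0', tasks), daemon=True)
worker.start()
tasks.put('demo-task')
tasks.put(None)
worker.join()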
@@ -35,63 +42,145 @@ import base64
from io import BytesIO
#from colorama import Fore
# local
stop_processing = False
temp_images = {}
from threading import local as LocalThreadVars
thread_data = LocalThreadVars()
ckpt_file = None
gfpgan_file = None
real_esrgan_file = None
def device_would_fail(device):
if device == 'cpu': return None
# Returns None when no issues found, otherwise returns the detected error str.
# Memory check
try:
mem_free, mem_total = torch.cuda.mem_get_info(device)
mem_total /= float(10**9)
if mem_total < 3.0:
return 'GPUs with less than 3 GB of VRAM are not compatible with Stable Diffusion'
except RuntimeError as e:
return str(e) # Return cuda errors from mem_get_info as strings
return None
model = None
modelCS = None
modelFS = None
model_gfpgan = None
model_real_esrgan = None
def device_select(device):
if device == 'cpu': return True
if not torch.cuda.is_available(): return False
failure_msg = device_would_fail(device)
if failure_msg:
if 'invalid device' in failure_msg:
raise NameError(f'GPU "{device}" could not be found. Remove this device from config.render_devices or use one of "auto" or "cuda".')
print(failure_msg)
return False
model_is_half = False
model_fs_is_half = False
device = None
unet_bs = 1
precision = 'autocast'
sampler_plms = None
sampler_ddim = None
device_name = torch.cuda.get_device_name(device)
has_valid_gpu = False
force_full_precision = False
try:
gpu = torch.cuda.current_device()
gpu_name = torch.cuda.get_device_name(gpu)
print('GPU detected: ', gpu_name)
# otherwise these NVIDIA cards create green images
thread_data.force_full_precision = ('nvidia' in device_name.lower() or 'geforce' in device_name.lower()) and (' 1660' in device_name or ' 1650' in device_name)
if thread_data.force_full_precision:
print('forcing full precision on NVIDIA 16xx cards, to avoid green images. GPU detected: ', device_name)
force_full_precision = ('nvidia' in gpu_name.lower() or 'geforce' in gpu_name.lower()) and (' 1660' in gpu_name or ' 1650' in gpu_name) # otherwise these NVIDIA cards create green images
if force_full_precision:
print('forcing full precision on NVIDIA 16xx cards, to avoid green images. GPU detected: ', gpu_name)
thread_data.device = device
thread_data.has_valid_gpu = True
return True
mem_free, mem_total = torch.cuda.mem_get_info(gpu)
mem_total /= float(10**9)
if mem_total < 3.0:
print("GPUs with less than 3 GB of VRAM are not compatible with Stable Diffusion")
raise Exception()
def device_init(device_selection=None):
# Thread bound properties
thread_data.stop_processing = False
thread_data.temp_images = {}
has_valid_gpu = True
except:
thread_data.ckpt_file = None
thread_data.gfpgan_file = None
thread_data.real_esrgan_file = None
thread_data.model = None
thread_data.modelCS = None
thread_data.modelFS = None
thread_data.model_gfpgan = None
thread_data.model_real_esrgan = None
thread_data.model_is_half = False
thread_data.model_fs_is_half = False
thread_data.device = None
thread_data.unet_bs = 1
thread_data.precision = 'autocast'
thread_data.sampler_plms = None
thread_data.sampler_ddim = None
thread_data.turbo = False
thread_data.has_valid_gpu = False
thread_data.force_full_precision = False
thread_data.reduced_memory = True
if device_selection.lower() == 'cpu':
print('CPU requested, skipping gpu init.')
thread_data.device = 'cpu'
return
if not torch.cuda.is_available():
print('WARNING: torch.cuda is not available. Using the CPU, but this will be very slow!')
thread_data.device = 'cpu'
return
device_count = torch.cuda.device_count()
if device_count <= 1 and device_selection == 'auto':
device_selection = 'current' # Use 'auto' only when more than one compatible device is found.
if device_selection == 'auto':
print('Autoselecting GPU. Using most free memory.')
max_mem_free = 0
best_device = None
for device in range(device_count):
mem_free, mem_total = torch.cuda.mem_get_info(device)
mem_free /= float(10**9)
mem_total /= float(10**9)
device_name = torch.cuda.get_device_name(device)
print(f'GPU:{device} detected: {device_name} - Memory: {round(mem_total - mem_free, 2)}GB used / {round(mem_total, 2)}GB total')
if max_mem_free < mem_free:
max_mem_free = mem_free
best_device = device
if best_device is not None and device_select(best_device):
print(f'Setting GPU:{best_device} as active')
torch.cuda.set_device(best_device)
return
if isinstance(device_selection, str):
device_selection = device_selection.lower()
if device_selection.startswith('gpu:'):
device_selection = int(device_selection[4:])
if device_selection not in ('cuda', 'current', 'gpu'):
if device_select(device_selection):
if isinstance(device_selection, int):
print(f'Setting GPU:{device_selection} as active')
else:
print(f'Setting {device_selection} as active')
torch.cuda.set_device(device_selection)
return
# By default use current device.
print('Checking current GPU...')
device = torch.cuda.current_device()
device_name = torch.cuda.get_device_name(device)
print(f'GPU:{device} detected: {device_name}')
if device_select(device):
return
print('WARNING: No compatible GPU found. Using the CPU, but this will be very slow!')
thread_data.device = 'cpu'
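The 'auto' branch above reduces to scanning torch.cuda.mem_get_info for the device reporting the most free memory. A hedged standalone sketch of that rule (assumes torch; the function name is illustrative):

import torch

def pick_freest_gpu():
    best_device, max_mem_free = None, 0
    for dev in range(torch.cuda.device_count()):
        mem_free, _mem_total = torch.cuda.mem_get_info(dev)  # bytes
        if mem_free > max_mem_free:
            best_device, max_mem_free = dev, mem_free
    return best_device  # None when no CUDA device reports free memory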
def load_model_ckpt(ckpt_to_use, device_to_use='cuda', turbo=False, unet_bs_to_use=1, precision_to_use='autocast'):
global ckpt_file, model, modelCS, modelFS, model_is_half, device, unet_bs, precision, model_fs_is_half
def is_first_cuda_device(device):
if device is None: return False
if device == 0 or device == '0': return True
if device == 'cuda' or device == 'cuda:0': return True
if device == 'gpu' or device == 'gpu:0': return True
if device == 'current': return True
if device == torch.device(0): return True
return False
device = device_to_use if has_valid_gpu else 'cpu'
precision = precision_to_use if not force_full_precision else 'full'
unet_bs = unet_bs_to_use
def load_model_ckpt():
if not thread_data.ckpt_file: raise ValueError(f'Thread ckpt_file is undefined.')
if not os.path.exists(thread_data.ckpt_file + '.ckpt'): raise FileNotFoundError(f'Cannot find {thread_data.ckpt_file}.ckpt')
unload_model()
if not thread_data.precision:
thread_data.precision = 'full' if thread_data.force_full_precision else 'autocast'
if device == 'cpu':
precision = 'full'
if not thread_data.unet_bs:
thread_data.unet_bs = 1
sd = load_model_from_config(f"{ckpt_to_use}.ckpt")
if thread_data.device == 'cpu':
thread_data.precision = 'full'
print('loading', thread_data.ckpt_file, 'to', thread_data.device, 'using precision', thread_data.precision)
sd = load_model_from_config(thread_data.ckpt_file + '.ckpt')
li, lo = [], []
for key, value in sd.items():
sp = key.split(".")
@@ -114,88 +203,127 @@ def load_model_ckpt(ckpt_to_use, device_to_use='cuda', turbo=False, unet_bs_to_u
model = instantiate_from_config(config.modelUNet)
_, _ = model.load_state_dict(sd, strict=False)
model.eval()
model.cdevice = device
model.unet_bs = unet_bs
model.turbo = turbo
model.cdevice = torch.device(thread_data.device)
model.unet_bs = thread_data.unet_bs
model.turbo = thread_data.turbo
if thread_data.device != 'cpu':
model.to(thread_data.device)
#if thread_data.reduced_memory:
#model.model1.to("cpu")
#model.model2.to("cpu")
thread_data.model = model
modelCS = instantiate_from_config(config.modelCondStage)
_, _ = modelCS.load_state_dict(sd, strict=False)
modelCS.eval()
modelCS.cond_stage_model.device = device
modelCS.cond_stage_model.device = torch.device(thread_data.device)
if thread_data.device != 'cpu':
if thread_data.reduced_memory:
modelCS.to('cpu')
else:
modelCS.to(thread_data.device) # Preload on device if not already there.
thread_data.modelCS = modelCS
modelFS = instantiate_from_config(config.modelFirstStage)
_, _ = modelFS.load_state_dict(sd, strict=False)
modelFS.eval()
if thread_data.device != 'cpu':
if thread_data.reduced_memory:
modelFS.to('cpu')
else:
modelFS.to(thread_data.device) # Preload on device if not already there.
thread_data.modelFS = modelFS
del sd
if device != "cpu" and precision == "autocast":
model.half()
modelCS.half()
modelFS.half()
model_is_half = True
model_fs_is_half = True
if thread_data.device != "cpu" and thread_data.precision == "autocast":
thread_data.model.half()
thread_data.modelCS.half()
thread_data.modelFS.half()
thread_data.model_is_half = True
thread_data.model_fs_is_half = True
else:
model_is_half = False
model_fs_is_half = False
thread_data.model_is_half = False
thread_data.model_fs_is_half = False
ckpt_file = ckpt_to_use
print('loaded', thread_data.ckpt_file, 'as', model.device, '->', modelCS.cond_stage_model.device, '->', thread_data.modelFS.device, 'using precision', thread_data.precision)
print('loaded ', ckpt_file, 'to', device, 'precision', precision)
def unload_filters():
if thread_data.model_gfpgan is not None:
del thread_data.model_gfpgan
thread_data.model_gfpgan = None
def unload_model():
global model, modelCS, modelFS
if thread_data.model_real_esrgan is not None:
del thread_data.model_real_esrgan
thread_data.model_real_esrgan = None
if model is not None:
del model
del modelCS
del modelFS
def unload_models():
if thread_data.model is not None:
print('Unloading models...')
del thread_data.model
del thread_data.modelCS
del thread_data.modelFS
model = None
modelCS = None
modelFS = None
thread_data.model = None
thread_data.modelCS = None
thread_data.modelFS = None
def load_model_gfpgan(gfpgan_to_use):
global gfpgan_file, model_gfpgan
def wait_model_move_to(model, target_device): # Send to target_device and wait until complete.
if thread_data.device == target_device: return
start_mem = torch.cuda.memory_allocated(thread_data.device) / 1e6
if start_mem <= 0: return
model_name = model.__class__.__name__
print(f'Device:{thread_data.device} - Sending model {model_name} to {target_device} | Memory transfer starting. Memory Used: {round(start_mem)}MB')
start_time = time.time()
model.to(target_device)
time_step = start_time
WARNING_TIMEOUT = 1.5 # seconds - Show activity in console after timeout.
last_mem = start_mem
is_transferring = True
while is_transferring:
time.sleep(0.5) # 500ms
mem = torch.cuda.memory_allocated(thread_data.device) / 1e6
is_transferring = bool(mem > 0 and mem < last_mem) # still stuff loaded, but less than last time.
last_mem = mem
if not is_transferring:
break
if time.time() - time_step > WARNING_TIMEOUT: # Long delay, print to console to show activity.
print(f'Device:{thread_data.device} - Waiting for memory transfer. Memory Used: {round(mem)}MB, Transferred: {round(start_mem - mem)}MB')
time_step = time.time()
print(f'Device:{thread_data.device} - {model_name} moved: {round(start_mem - last_mem)}MB in {round(time.time() - start_time, 3)} seconds to {target_device}')
if gfpgan_to_use is None:
return
def load_model_gfpgan():
if thread_data.gfpgan_file is None: raise ValueError(f'Thread gfpgan_file is undefined.')
#print('load_model_gfpgan called without setting gfpgan_file')
#return
if not is_first_cuda_device(thread_data.device):
#TODO Remove when fixed - A bug with GFPGANer and facexlib needs to be fixed before use on other devices.
raise Exception(f'Current device {torch.device(thread_data.device)} is not {torch.device(0)}. Cannot run GFPGANer.')
model_path = thread_data.gfpgan_file + ".pth"
thread_data.model_gfpgan = GFPGANer(device=torch.device(thread_data.device), model_path=model_path, upscale=1, arch='clean', channel_multiplier=2, bg_upsampler=None)
print('loaded', thread_data.gfpgan_file, 'to', thread_data.model_gfpgan.device, 'precision', thread_data.precision)
gfpgan_file = gfpgan_to_use
model_path = gfpgan_to_use + ".pth"
if device == 'cpu':
model_gfpgan = GFPGANer(model_path=model_path, upscale=1, arch='clean', channel_multiplier=2, bg_upsampler=None, device=torch.device('cpu'))
else:
model_gfpgan = GFPGANer(model_path=model_path, upscale=1, arch='clean', channel_multiplier=2, bg_upsampler=None, device=torch.device('cuda'))
print('loaded ', gfpgan_to_use, 'to', device, 'precision', precision)
def load_model_real_esrgan(real_esrgan_to_use):
global real_esrgan_file, model_real_esrgan
if real_esrgan_to_use is None:
return
real_esrgan_file = real_esrgan_to_use
model_path = real_esrgan_to_use + ".pth"
def load_model_real_esrgan():
if thread_data.real_esrgan_file is None: raise ValueError(f'Thread real_esrgan_file is undefined.')
#print('load_model_real_esrgan called without setting real_esrgan_file')
#return
model_path = thread_data.real_esrgan_file + ".pth"
RealESRGAN_models = {
'RealESRGAN_x4plus': RRDBNet(num_in_ch=3, num_out_ch=3, num_feat=64, num_block=23, num_grow_ch=32, scale=4),
'RealESRGAN_x4plus_anime_6B': RRDBNet(num_in_ch=3, num_out_ch=3, num_feat=64, num_block=6, num_grow_ch=32, scale=4)
}
model_to_use = RealESRGAN_models[real_esrgan_to_use]
model_to_use = RealESRGAN_models[thread_data.real_esrgan_file]
if device == 'cpu':
model_real_esrgan = RealESRGANer(scale=2, model_path=model_path, model=model_to_use, pre_pad=0, half=False) # cpu does not support half
model_real_esrgan.device = torch.device('cpu')
model_real_esrgan.model.to('cpu')
if thread_data.device == 'cpu':
thread_data.model_real_esrgan = RealESRGANer(device=torch.device(thread_data.device), scale=2, model_path=model_path, model=model_to_use, pre_pad=0, half=False) # cpu does not support half
#thread_data.model_real_esrgan.device = torch.device(thread_data.device)
thread_data.model_real_esrgan.model.to('cpu')
else:
model_real_esrgan = RealESRGANer(scale=2, model_path=model_path, model=model_to_use, pre_pad=0, half=model_is_half)
thread_data.model_real_esrgan = RealESRGANer(device=torch.device(thread_data.device), scale=2, model_path=model_path, model=model_to_use, pre_pad=0, half=thread_data.model_is_half)
model_real_esrgan.model.name = real_esrgan_to_use
print('loaded ', real_esrgan_to_use, 'to', device, 'precision', precision)
thread_data.model_real_esrgan.model.name = thread_data.real_esrgan_file
print('loaded ', thread_data.real_esrgan_file, 'to', thread_data.model_real_esrgan.device, 'precision', thread_data.precision)
def get_base_path(disk_path, session_id, prompt, img_id, ext, suffix=None):
if disk_path is None: return None
@@ -206,22 +334,38 @@ def get_base_path(disk_path, session_id, prompt, img_id, ext, suffix=None):
os.makedirs(session_out_path, exist_ok=True)
prompt_flattened = filename_regex.sub('_', prompt)[:50]
if suffix is not None:
return os.path.join(session_out_path, f"{prompt_flattened}_{img_id}_{suffix}.{ext}")
return os.path.join(session_out_path, f"{prompt_flattened}_{img_id}.{ext}")
def apply_filters(filter_name, image_data):
def apply_filters(filter_name, image_data, model_path=None):
print(f'Applying filter {filter_name}...')
gc()
gc() # Free space before loading new data.
if isinstance(image_data, torch.Tensor):
image_data = image_data.to(thread_data.device)
if filter_name == 'gfpgan':
_, _, output = model_gfpgan.enhance(image_data[:,:,::-1], has_aligned=False, only_center_face=False, paste_back=True)
if model_path is not None and model_path != thread_data.gfpgan_file:
thread_data.gfpgan_file = model_path
load_model_gfpgan()
elif not thread_data.model_gfpgan:
load_model_gfpgan()
if thread_data.model_gfpgan is None: raise Exception('Model "gfpgan" not loaded.')
print('enhance with', thread_data.gfpgan_file, 'on', thread_data.model_gfpgan.device, 'precision', thread_data.precision)
_, _, output = thread_data.model_gfpgan.enhance(image_data[:,:,::-1], has_aligned=False, only_center_face=False, paste_back=True)
image_data = output[:,:,::-1]
if filter_name == 'real_esrgan':
output, _ = model_real_esrgan.enhance(image_data[:,:,::-1])
if model_path is not None and model_path != thread_data.real_esrgan_file:
thread_data.real_esrgan_file = model_path
load_model_real_esrgan()
elif not thread_data.model_real_esrgan:
load_model_real_esrgan()
if thread_data.model_real_esrgan is None: raise Exception('Model "real_esrgan" not loaded.')
print('enhance with', thread_data.real_esrgan_file, 'on', thread_data.model_real_esrgan.device, 'precision', thread_data.precision)
output, _ = thread_data.model_real_esrgan.enhance(image_data[:,:,::-1])
image_data = output[:,:,::-1]
return image_data
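apply_filters now reloads a filter's weights only when the caller asks for a different path; otherwise it lazily loads on first use and reuses the cached model. That rule, distilled into a standalone helper (illustrative, not the shipped code):

def ensure_filter_loaded(requested_path, loaded_path, model, load_fn):
    if requested_path is not None and requested_path != loaded_path:
        return requested_path, load_fn(requested_path)  # hot-swap to new weights
    if model is None:
        return loaded_path, load_fn(loaded_path)        # first use: lazy load
    return loaded_path, model                           # cached: reuse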
@@ -232,83 +376,105 @@ def mk_img(req: Request):
except Exception as e:
print(traceback.format_exc())
gc()
if device != "cpu":
modelFS.to("cpu")
modelCS.to("cpu")
model.model1.to("cpu")
model.model2.to("cpu")
gc()
if thread_data.reduced_memory:
thread_data.modelFS.to('cpu')
thread_data.modelCS.to('cpu')
thread_data.model.model1.to("cpu")
thread_data.model.model2.to("cpu")
else:
# Model crashed, release all resources in unknown state.
unload_models()
unload_filters()
gc() # Release from memory.
yield json.dumps({
"status": 'failed',
"detail": str(e)
})
def do_mk_img(req: Request):
global ckpt_file
global model, modelCS, modelFS, device
global model_gfpgan, model_real_esrgan
global stop_processing
def update_temp_img(req, x_samples):
partial_images = []
for i in range(req.num_outputs):
x_sample_ddim = thread_data.modelFS.decode_first_stage(x_samples[i].unsqueeze(0))
x_sample = torch.clamp((x_sample_ddim + 1.0) / 2.0, min=0.0, max=1.0)
x_sample = 255.0 * rearrange(x_sample[0].cpu().numpy(), "c h w -> h w c")
x_sample = x_sample.astype(np.uint8)
img = Image.fromarray(x_sample)
buf = BytesIO()
img.save(buf, format='JPEG')
buf.seek(0)
stop_processing = False
del img, x_sample, x_sample_ddim
# don't delete x_samples, it is used in the code that called this callback
thread_data.temp_images[str(req.session_id) + '/' + str(i)] = buf
partial_images.append({'path': f'/image/tmp/{req.session_id}/{i}'})
return partial_images
# Build and return the appropriate generator for do_mk_img
def get_image_progress_generator(req, extra_props=None):
if not req.stream_progress_updates:
def empty_callback(x_samples, i): return x_samples
return empty_callback
thread_data.partial_x_samples = None
last_callback_time = -1
def img_callback(x_samples, i):
nonlocal last_callback_time
thread_data.partial_x_samples = x_samples
step_time = time.time() - last_callback_time if last_callback_time != -1 else -1
last_callback_time = time.time()
progress = {"step": i, "step_time": step_time}
if extra_props is not None:
progress.update(extra_props)
if req.stream_image_progress and i % 5 == 0:
progress['output'] = update_temp_img(req, x_samples)
yield json.dumps(progress)
if thread_data.stop_processing:
raise UserInitiatedStop("User requested that we stop processing")
return img_callback
def do_mk_img(req: Request):
thread_data.stop_processing = False
res = Response()
res.request = req
res.images = []
temp_images.clear()
thread_data.temp_images.clear()
# custom model support:
# the req.use_stable_diffusion_model needs to be a valid path
# to the ckpt file (without the extension).
if not os.path.exists(req.use_stable_diffusion_model + '.ckpt'): raise FileNotFoundError(f'Cannot find {req.use_stable_diffusion_model}.ckpt')
needs_model_reload = False
ckpt_to_use = ckpt_file
if ckpt_to_use != req.use_stable_diffusion_model:
ckpt_to_use = req.use_stable_diffusion_model
if not thread_data.model or thread_data.ckpt_file != req.use_stable_diffusion_model:
thread_data.ckpt_file = req.use_stable_diffusion_model
needs_model_reload = True
model.turbo = req.turbo
if req.use_cpu:
if device != 'cpu':
device = 'cpu'
if model_is_half:
load_model_ckpt(ckpt_to_use, device)
needs_model_reload = False
load_model_gfpgan(gfpgan_file)
load_model_real_esrgan(real_esrgan_file)
else:
if has_valid_gpu:
prev_device = device
device = 'cuda'
if (precision == 'autocast' and (req.use_full_precision or not model_is_half)) or \
(precision == 'full' and not req.use_full_precision and not force_full_precision):
load_model_ckpt(ckpt_to_use, device, req.turbo, unet_bs, ('full' if req.use_full_precision else 'autocast'))
needs_model_reload = False
if prev_device != device:
load_model_gfpgan(gfpgan_file)
load_model_real_esrgan(real_esrgan_file)
if thread_data.has_valid_gpu:
if (thread_data.precision == 'autocast' and (req.use_full_precision or not thread_data.model_is_half)) or \
(thread_data.precision == 'full' and not req.use_full_precision and not thread_data.force_full_precision):
thread_data.precision = 'full' if req.use_full_precision else 'autocast'
needs_model_reload = True
if needs_model_reload:
load_model_ckpt(ckpt_to_use, device, req.turbo, unet_bs, precision)
unload_models()
unload_filters()
load_model_ckpt()
if req.use_face_correction != gfpgan_file:
load_model_gfpgan(req.use_face_correction)
if thread_data.turbo != req.turbo:
thread_data.turbo = req.turbo
thread_data.model.turbo = req.turbo
if req.use_upscale != real_esrgan_file:
load_model_real_esrgan(req.use_upscale)
model.cdevice = device
modelCS.cond_stage_model.device = device
# Start by cleaning memory, loading and unloading things can leave memory allocated.
gc()
opt_prompt = req.prompt
opt_seed = req.seed
@@ -316,11 +482,9 @@ def do_mk_img(req: Request):
opt_C = 4
opt_f = 8
opt_ddim_eta = 0.0
opt_init_img = req.init_image
print(req.to_string(), '\n device', device)
print('\n\n Using precision:', precision)
print(req.to_string(), '\n device', thread_data.device)
print('\n\n Using precision:', thread_data.precision)
seed_everything(opt_seed)
@@ -329,7 +493,7 @@ def do_mk_img(req: Request):
assert prompt is not None
data = [batch_size * [prompt]]
if precision == "autocast" and device != "cpu":
if thread_data.precision == "autocast" and thread_data.device != "cpu":
precision_scope = autocast
else:
precision_scope = nullcontext
@@ -345,25 +509,26 @@ def do_mk_img(req: Request):
handler = _img2img
init_image = load_img(req.init_image, req.width, req.height)
init_image = init_image.to(device)
init_image = init_image.to(thread_data.device)
if device != "cpu" and precision == "autocast":
if thread_data.device != "cpu" and thread_data.precision == "autocast":
init_image = init_image.half()
modelFS.to(device)
thread_data.modelFS.to(thread_data.device)
init_image = repeat(init_image, '1 ... -> b ...', b=batch_size)
init_latent = modelFS.get_first_stage_encoding(modelFS.encode_first_stage(init_image)) # move to latent space
init_latent = thread_data.modelFS.get_first_stage_encoding(thread_data.modelFS.encode_first_stage(init_image)) # move to latent space
if req.mask is not None:
mask = load_mask(req.mask, req.width, req.height, init_latent.shape[2], init_latent.shape[3], True).to(device)
mask = load_mask(req.mask, req.width, req.height, init_latent.shape[2], init_latent.shape[3], True).to(thread_data.device)
mask = mask[0][0].unsqueeze(0).repeat(4, 1, 1).unsqueeze(0)
mask = repeat(mask, '1 ... -> b ...', b=batch_size)
if device != "cpu" and precision == "autocast":
if thread_data.device != "cpu" and thread_data.precision == "autocast":
mask = mask.half()
move_fs_to_cpu()
# Send to CPU and wait until complete.
wait_model_move_to(thread_data.modelFS, 'cpu')
assert 0. <= req.prompt_strength <= 1., 'can only work with strength in [0.0, 1.0]'
t_enc = int(req.prompt_strength * req.num_inference_steps)
@@ -375,16 +540,16 @@ def do_mk_img(req: Request):
else:
session_out_path = None
seeds = ""
with torch.no_grad():
for n in trange(opt_n_iter, desc="Sampling"):
for prompts in tqdm(data, desc="data"):
with precision_scope("cuda"):
modelCS.to(device)
if thread_data.reduced_memory:
thread_data.modelCS.to(thread_data.device)
uc = None
if req.guidance_scale != 1.0:
uc = modelCS.get_learned_conditioning(batch_size * [req.negative_prompt])
uc = thread_data.modelCS.get_learned_conditioning(batch_size * [req.negative_prompt])
if isinstance(prompts, tuple):
prompts = list(prompts)
@@ -397,85 +562,65 @@ def do_mk_img(req: Request):
weight = weights[i]
# if not skip_normalize:
weight = weight / totalWeight
c = torch.add(c, modelCS.get_learned_conditioning(subprompts[i]), alpha=weight)
c = torch.add(c, thread_data.modelCS.get_learned_conditioning(subprompts[i]), alpha=weight)
else:
c = modelCS.get_learned_conditioning(prompts)
c = thread_data.modelCS.get_learned_conditioning(prompts)
modelFS.to(device)
if thread_data.reduced_memory:
thread_data.modelFS.to(thread_data.device)
partial_x_samples = None
last_callback_time = -1
def img_callback(x_samples, i):
nonlocal partial_x_samples, last_callback_time
partial_x_samples = x_samples
if req.stream_progress_updates:
n_steps = req.num_inference_steps if req.init_image is None else t_enc
step_time = time.time() - last_callback_time if last_callback_time != -1 else -1
last_callback_time = time.time()
progress = {"step": i, "total_steps": n_steps, "step_time": step_time}
if req.stream_image_progress and i % 5 == 0:
partial_images = []
for i in range(batch_size):
x_samples_ddim = modelFS.decode_first_stage(x_samples[i].unsqueeze(0))
x_sample = torch.clamp((x_samples_ddim + 1.0) / 2.0, min=0.0, max=1.0)
x_sample = 255.0 * rearrange(x_sample[0].cpu().numpy(), "c h w -> h w c")
x_sample = x_sample.astype(np.uint8)
img = Image.fromarray(x_sample)
buf = BytesIO()
img.save(buf, format='JPEG')
buf.seek(0)
del img, x_sample, x_samples_ddim
# don't delete x_samples, it is used in the code that called this callback
temp_images[str(req.session_id) + '/' + str(i)] = buf
partial_images.append({'path': f'/image/tmp/{req.session_id}/{i}'})
progress['output'] = partial_images
yield json.dumps(progress)
if stop_processing:
raise UserInitiatedStop("User requested that we stop processing")
n_steps = req.num_inference_steps if req.init_image is None else t_enc
img_callback = get_image_progress_generator(req, {"total_steps": n_steps})
# run the handler
try:
print('Running handler...')
if handler == _txt2img:
x_samples = _txt2img(req.width, req.height, req.num_outputs, req.num_inference_steps, req.guidance_scale, None, opt_C, opt_f, opt_ddim_eta, c, uc, opt_seed, img_callback, mask, req.sampler)
else:
x_samples = _img2img(init_latent, t_enc, batch_size, req.guidance_scale, c, uc, req.num_inference_steps, opt_ddim_eta, opt_seed, img_callback, mask)
yield from x_samples
x_samples = partial_x_samples
if req.stream_progress_updates:
yield from x_samples
if hasattr(thread_data, 'partial_x_samples'):
if thread_data.partial_x_samples is not None:
x_samples = thread_data.partial_x_samples
del thread_data.partial_x_samples
except UserInitiatedStop:
if partial_x_samples is None:
if not hasattr(thread_data, 'partial_x_samples'):
continue
if thread_data.partial_x_samples is None:
del thread_data.partial_x_samples
continue
x_samples = thread_data.partial_x_samples
del thread_data.partial_x_samples
x_samples = partial_x_samples
print("saving images")
print("decoding images")
img_data = [None] * batch_size
for i in range(batch_size):
img_id = base64.b64encode(int(time.time()+i).to_bytes(8, 'big')).decode() # Generate unique ID based on time.
img_id = img_id.translate({43:None, 47:None, 61:None})[-8:] # Remove + / = and keep last 8 chars.
x_samples_ddim = modelFS.decode_first_stage(x_samples[i].unsqueeze(0))
x_samples_ddim = thread_data.modelFS.decode_first_stage(x_samples[i].unsqueeze(0))
x_sample = torch.clamp((x_samples_ddim + 1.0) / 2.0, min=0.0, max=1.0)
x_sample = 255.0 * rearrange(x_sample[0].cpu().numpy(), "c h w -> h w c")
x_sample = x_sample.astype(np.uint8)
img = Image.fromarray(x_sample)
img_data[i] = x_sample
del x_samples, x_samples_ddim, x_sample
if thread_data.reduced_memory:
# Send to CPU and wait until complete.
wait_model_move_to(thread_data.modelFS, 'cpu')
print("saving images")
for i in range(batch_size):
img = Image.fromarray(img_data[i])
img_id = base64.b64encode(int(time.time()+i).to_bytes(8, 'big')).decode() # Generate unique ID based on time.
img_id = img_id.translate({43:None, 47:None, 61:None})[-8:] # Remove + / = and keep last 8 chars.
has_filters = (req.use_face_correction is not None and req.use_face_correction.startswith('GFPGAN')) or \
(req.use_upscale is not None and req.use_upscale.startswith('RealESRGAN'))
return_orig_img = not has_filters or not req.show_only_filtered_image
if stop_processing:
if thread_data.stop_processing:
return_orig_img = True
if req.save_to_disk_path is not None:
@@ -486,25 +631,24 @@ def do_mk_img(req: Request):
save_metadata(meta_out_path, req, prompts[0], opt_seed)
if return_orig_img:
img_data = img_to_base64_str(img, req.output_format)
res_image_orig = ResponseImage(data=img_data, seed=opt_seed)
img_str = img_to_base64_str(img, req.output_format)
res_image_orig = ResponseImage(data=img_str, seed=opt_seed)
res.images.append(res_image_orig)
if req.save_to_disk_path is not None:
res_image_orig.path_abs = img_out_path
del img
if has_filters and not stop_processing:
if has_filters and not thread_data.stop_processing:
filters_applied = []
if req.use_face_correction:
x_sample = apply_filters('gfpgan', x_sample)
img_data[i] = apply_filters('gfpgan', img_data[i], req.use_face_correction)
filters_applied.append(req.use_face_correction)
if req.use_upscale:
x_sample = apply_filters('real_esrgan', x_sample)
img_data[i] = apply_filters('real_esrgan', img_data[i], req.use_upscale)
filters_applied.append(req.use_upscale)
if (len(filters_applied) > 0):
filtered_image = Image.fromarray(x_sample)
filtered_image = Image.fromarray(img_data[i])
filtered_img_data = img_to_base64_str(filtered_image, req.output_format)
response_image = ResponseImage(data=filtered_img_data, seed=opt_seed)
res.images.append(response_image)
@@ -513,17 +657,17 @@ def do_mk_img(req: Request):
save_image(filtered_image, filtered_img_out_path)
response_image.path_abs = filtered_img_out_path
del filtered_image
seeds += str(opt_seed) + ","
# Filter Applied, move to next seed
opt_seed += 1
move_fs_to_cpu()
if thread_data.reduced_memory:
unload_filters()
del img_data
gc()
del x_samples, x_samples_ddim, x_sample
print("memory_final = ", torch.cuda.memory_allocated() / 1e6)
if thread_data.device != 'cpu':
print(f'memory_final = {round(torch.cuda.memory_allocated(thread_data.device) / 1e6, 2)}MB')
print('Task completed')
yield json.dumps(res.json())
def save_image(img, img_out_path):
@@ -533,7 +677,7 @@ def save_image(img, img_out_path):
print('could not save the file', traceback.format_exc())
def save_metadata(meta_out_path, req, prompt, opt_seed):
metadata = f"""{prompt}
metadata = f'''{prompt}
Width: {req.width}
Height: {req.height}
Seed: {opt_seed}
@@ -544,8 +688,8 @@ Use Face Correction: {req.use_face_correction}
Use Upscaling: {req.use_upscale}
Sampler: {req.sampler}
Negative Prompt: {req.negative_prompt}
Stable Diffusion Model: {req.use_stable_diffusion_model + '.ckpt'}
"""
Stable Diffusion model: {req.use_stable_diffusion_model + '.ckpt'}
'''
try:
with open(meta_out_path, 'w', encoding='utf-8') as f:
f.write(metadata)
@@ -555,16 +699,13 @@ Stable Diffusion Model: {req.use_stable_diffusion_model + '.ckpt'}
def _txt2img(opt_W, opt_H, opt_n_samples, opt_ddim_steps, opt_scale, start_code, opt_C, opt_f, opt_ddim_eta, c, uc, opt_seed, img_callback, mask, sampler_name):
shape = [opt_n_samples, opt_C, opt_H // opt_f, opt_W // opt_f]
if device != "cpu":
mem = torch.cuda.memory_allocated() / 1e6
modelCS.to("cpu")
while torch.cuda.memory_allocated() / 1e6 >= mem:
time.sleep(1)
# Send to CPU and wait until complete.
wait_model_move_to(thread_data.modelCS, 'cpu')
if sampler_name == 'ddim':
model.make_schedule(ddim_num_steps=opt_ddim_steps, ddim_eta=opt_ddim_eta, verbose=False)
thread_data.model.make_schedule(ddim_num_steps=opt_ddim_steps, ddim_eta=opt_ddim_eta, verbose=False)
samples_ddim = model.sample(
samples_ddim = thread_data.model.sample(
S=opt_ddim_steps,
conditioning=c,
seed=opt_seed,
@@ -578,14 +719,13 @@ def _txt2img(opt_W, opt_H, opt_n_samples, opt_ddim_steps, opt_scale, start_code,
mask=mask,
sampler = sampler_name,
)
yield from samples_ddim
def _img2img(init_latent, t_enc, batch_size, opt_scale, c, uc, opt_ddim_steps, opt_ddim_eta, opt_seed, img_callback, mask):
# encode (scaled latent)
z_enc = model.stochastic_encode(
z_enc = thread_data.model.stochastic_encode(
init_latent,
torch.tensor([t_enc] * batch_size).to(device),
torch.tensor([t_enc] * batch_size).to(thread_data.device),
opt_seed,
opt_ddim_eta,
opt_ddim_steps,
@@ -593,7 +733,7 @@ def _img2img(init_latent, t_enc, batch_size, opt_scale, c, uc, opt_ddim_steps, o
x_T = None if mask is None else init_latent
# decode it
samples_ddim = model.sample(
samples_ddim = thread_data.model.sample(
t_enc,
c,
z_enc,
@@ -604,20 +744,12 @@ def _img2img(init_latent, t_enc, batch_size, opt_scale, c, uc, opt_ddim_steps, o
x_T=x_T,
sampler = 'ddim'
)
yield from samples_ddim
def move_fs_to_cpu():
if device != "cpu":
mem = torch.cuda.memory_allocated() / 1e6
modelFS.to("cpu")
while torch.cuda.memory_allocated() / 1e6 >= mem:
time.sleep(1)
def gc():
if device == 'cpu':
gc_collect()
if thread_data.device == 'cpu':
return
torch.cuda.empty_cache()
torch.cuda.ipc_collect()
@@ -627,7 +759,6 @@ def chunk(it, size):
it = iter(it)
return iter(lambda: tuple(islice(it, size)), ())
def load_model_from_config(ckpt, verbose=False):
print(f"Loading model from {ckpt}")
pl_sd = torch.load(ckpt, map_location="cpu")

sd_internal/task_manager.py

@@ -1,14 +1,25 @@
"""task_manager.py: manage tasks dispatching and render threads.
Notes:
render_threads should be the only hard reference held by the manager to the threads.
Use weak_thread_data to store all other data using weak keys.
This will allow for garbage collection after the thread dies.
"""
import json
import traceback
TASK_TTL = 15 * 60 # seconds. Discard the last session's task after this timeout.
import queue, threading, time
import queue, threading, time, weakref
from typing import Any, Generator, Hashable, Optional, Union
from pydantic import BaseModel
from sd_internal import Request, Response
THREAD_NAME_PREFIX = 'Runtime-Render/'
ERR_LOCK_FAILED = ' failed to acquire lock within timeout.'
LOCK_TIMEOUT = 15 # Maximum locking time in seconds before failing a task.
# It's better to get an exception than a deadlock... ALWAYS use timeout in critical paths.
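The weak-reference scheme from the docstring can be demonstrated in isolation: once the last hard reference to a thread is dropped, its entry disappears from the WeakKeyDictionary on its own (self-contained example, not from this commit):

import gc
import threading
import weakref

weak_data = weakref.WeakKeyDictionary()
t = threading.Thread(target=lambda: None)
weak_data[t] = {'device': 'cuda:0'}
t.start()
t.join()
del t        # drop the only hard reference to the finished thread
gc.collect()
print(len(weak_data))  # 0 -- the dead thread's data was collected with it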
class SymbolClass(type): # Print nicely formatted Symbol names.
def __repr__(self): return self.__qualname__
def __str__(self): return self.__name__
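For illustration, a symbol class built on this metaclass prints as its bare name instead of the default <class '...'> form:

class Rendering(metaclass=SymbolClass):
    pass

print(Rendering)        # Rendering  (__str__ -> __name__)
print(repr(Rendering))  # Rendering  (__repr__ -> __qualname__)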
@@ -25,7 +36,7 @@ class RenderTask(): # Task with output queue and completion lock.
def __init__(self, req: Request):
self.request: Request = req # Initial Request
self.response: Any = None # Copy of the last response
self.temp_images:[] = [None] * req.num_outputs * (1 if req.show_only_filtered_image else 2)
self.temp_images:list = [None] * req.num_outputs * (1 if req.show_only_filtered_image else 2)
self.error: Exception = None
self.lock: threading.Lock = threading.Lock() # Locks at task start and unlocks when task is completed
self.buffer_queue: queue.Queue = queue.Queue() # Queue of JSON string segments
@@ -66,17 +77,30 @@ class ImageRequest(BaseModel):
stream_progress_updates: bool = False
stream_image_progress: bool = False
class FilterRequest(BaseModel):
session_id: str = "session"
model: str = None
name: str = ""
init_image: str = None # base64
width: int = 512
height: int = 512
save_to_disk_path: str = None
turbo: bool = True
use_cpu: bool = False
use_full_precision: bool = False
output_format: str = "jpeg" # or "png"
# Temporary cache that allows querying task results for a short time after they complete.
class TaskCache():
def __init__(self):
self._base = dict()
self._lock: threading.Lock = threading.RLock()
self._lock: threading.Lock = threading.Lock()
def _get_ttl_time(self, ttl: int) -> int:
return int(time.time()) + ttl
def _is_expired(self, timestamp: int) -> bool:
return int(time.time()) >= timestamp
def clean(self) -> None:
if not self._lock.acquire(blocking=True, timeout=10): raise Exception('TaskCache.clean failed to acquire lock within timeout.')
if not self._lock.acquire(blocking=True, timeout=LOCK_TIMEOUT): raise Exception('TaskCache.clean' + ERR_LOCK_FAILED)
try:
# Create a list of expired keys to delete
to_delete = []
@@ -91,11 +115,11 @@ class TaskCache():
finally:
self._lock.release()
def clear(self) -> None:
if not self._lock.acquire(blocking=True, timeout=10): raise Exception('TaskCache.clear failed to acquire lock within timeout.')
if not self._lock.acquire(blocking=True, timeout=LOCK_TIMEOUT): raise Exception('TaskCache.clear' + ERR_LOCK_FAILED)
try: self._base.clear()
finally: self._lock.release()
def delete(self, key: Hashable) -> bool:
if not self._lock.acquire(blocking=True, timeout=10): raise Exception('TaskCache.delete failed to acquire lock within timeout.')
if not self._lock.acquire(blocking=True, timeout=LOCK_TIMEOUT): raise Exception('TaskCache.delete' + ERR_LOCK_FAILED)
try:
if key not in self._base:
return False
@@ -104,7 +128,7 @@ class TaskCache():
finally:
self._lock.release()
def keep(self, key: Hashable, ttl: int) -> bool:
if not self._lock.acquire(blocking=True, timeout=10): raise Exception('TaskCache.keep failed to acquire lock within timeout.')
if not self._lock.acquire(blocking=True, timeout=LOCK_TIMEOUT): raise Exception('TaskCache.keep' + ERR_LOCK_FAILED)
try:
if key in self._base:
_, value = self._base.get(key)
@@ -114,7 +138,7 @@ class TaskCache():
finally:
self._lock.release()
def put(self, key: Hashable, value: Any, ttl: int) -> bool:
if not self._lock.acquire(blocking=True, timeout=10): raise Exception('TaskCache.put failed to acquire lock within timeout.')
if not self._lock.acquire(blocking=True, timeout=LOCK_TIMEOUT): raise Exception('TaskCache.put' + ERR_LOCK_FAILED)
try:
self._base[key] = (
self._get_ttl_time(ttl), value
@@ -128,23 +152,26 @@ class TaskCache():
finally:
self._lock.release()
def tryGet(self, key: Hashable) -> Any:
if not self._lock.acquire(blocking=True, timeout=10): raise Exception('TaskCache.tryGet failed to acquire lock within timeout.')
if not self._lock.acquire(blocking=True, timeout=LOCK_TIMEOUT): raise Exception('TaskCache.tryGet' + ERR_LOCK_FAILED)
try:
ttl, value = self._base.get(key, (None, None))
if ttl is not None and self._is_expired(ttl):
print(f'Session {key} expired. Discarding data.')
self.delete(key)
del self._base[key]
return None
return value
finally:
self._lock.release()
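A short usage sketch of the cache above: entries get a TTL at put time, keep() extends it while a task is still running, and tryGet() discards anything that has already expired.

cache = TaskCache()
cache.put('session-1', {'status': 'pending'}, TASK_TTL)
task = cache.tryGet('session-1')   # fresh entry: returns the stored value
cache.keep('session-1', TASK_TTL)  # still in use: push the expiry forward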
manager_lock = threading.RLock()
render_threads = []
current_state = ServerStates.Init
current_state_error:Exception = None
current_model_path = None
tasks_queue = queue.Queue()
tasks_queue = []
task_cache = TaskCache()
default_model_to_load = None
weak_thread_data = weakref.WeakKeyDictionary()
def preload_model(file_path=None):
global current_state, current_state_error, current_model_path
@@ -155,7 +182,8 @@ def preload_model(file_path=None):
current_state = ServerStates.LoadingModel
try:
from . import runtime
runtime.load_model_ckpt(ckpt_to_use=file_path)
runtime.thread_data.ckpt_file = file_path
runtime.load_model_ckpt()
current_model_path = file_path
current_state_error = None
current_state = ServerStates.Online
@@ -165,72 +193,129 @@ def preload_model(file_path=None):
current_state = ServerStates.Unavailable
print(traceback.format_exc())
def thread_render():
def thread_get_next_task():
if not manager_lock.acquire(blocking=True, timeout=LOCK_TIMEOUT):
print('Render thread on device', runtime.thread_data.device, 'failed to acquire manager lock.')
return None
if len(tasks_queue) <= 0:
manager_lock.release()
return None
from . import runtime
task = None
try: # Select a render task.
for queued_task in tasks_queue:
if queued_task.request.use_face_correction: # TODO Remove when fixed - A bug with GFPGANer and facexlib needs to be fixed before use on other devices.
if is_alive(0) <= 0: # Allows GFPGANer only on cuda:0.
queued_task.error = Exception('cuda:0 is not available with the current config. Remove GFPGANer filter to run task.')
task = queued_task
break
if queued_task.request.use_cpu:
queued_task.error = Exception('Cpu cannot be used to run this task. Remove GFPGANer filter to run task.')
task = queued_task
break
if not runtime.is_first_cuda_device(runtime.thread_data.device):
continue # Wait for cuda:0
if queued_task.request.use_cpu and runtime.thread_data.device != 'cpu':
if is_alive('cpu') > 0:
continue # CPU Tasks, Skip GPU device
else:
queued_task.error = Exception('Cpu is not enabled in render_devices.')
task = queued_task
break
if not queued_task.request.use_cpu and runtime.thread_data.device == 'cpu':
if is_alive() > 1: # cpu is alive, so need more than one.
continue # GPU Tasks, don't run on CPU unless there is nothing else.
else:
queued_task.error = Exception('No active gpu found. Please check the error message in the command-line window at startup.')
task = queued_task
break
task = queued_task
break
if task is not None:
del tasks_queue[tasks_queue.index(task)]
return task
finally:
manager_lock.release()
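The selection loop above encodes the device-affinity rules in prose comments; distilled into a single predicate (an illustrative simplification: the shipped loop also attaches errors to tasks that can never run):

def may_take(req, device, alive_threads, cpu_thread_alive):
    if req.use_face_correction and device != 'cuda:0':
        return False  # GFPGANer is pinned to cuda:0 for now
    if req.use_cpu and device != 'cpu':
        return not cpu_thread_alive  # leave CPU tasks to a CPU thread if one exists
    if not req.use_cpu and device == 'cpu':
        return alive_threads <= 1  # CPU takes GPU tasks only as a last resort
    return True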
def thread_render(device):
global current_state, current_state_error, current_model_path
from . import runtime
current_state = ServerStates.Online
weak_thread_data[threading.current_thread()] = {
'device': device
}
try:
runtime.device_init(device)
except:
print(traceback.format_exc())
return
weak_thread_data[threading.current_thread()] = {
'device': runtime.thread_data.device
}
preload_model()
current_state = ServerStates.Online
while True:
task_cache.clean()
if isinstance(current_state_error, SystemExit):
current_state = ServerStates.Unavailable
return
task = None
try:
task = tasks_queue.get(timeout=1)
except queue.Empty as e:
if isinstance(current_state_error, SystemExit):
current_state = ServerStates.Unavailable
return
else: continue
#if current_model_path != task.request.use_stable_diffusion_model:
# preload_model(task.request.use_stable_diffusion_model)
task = thread_get_next_task()
if task is None:
time.sleep(1)
continue
if task.error is not None:
print(task.error)
task.response = {"status": 'failed', "detail": str(task.error)}
task.buffer_queue.put(json.dumps(task.response))
continue
if current_state_error:
task.error = current_state_error
task.response = {"status": 'failed', "detail": str(task.error)}
task.buffer_queue.put(json.dumps(task.response))
continue
print(f'Session {task.request.session_id} starting task {id(task)}')
if not task.lock.acquire(blocking=False): raise Exception('Got locked task from queue.')
try:
task.lock.acquire(blocking=False)
# Open data generator.
res = runtime.mk_img(task.request)
if current_model_path == task.request.use_stable_diffusion_model:
current_state = ServerStates.Rendering
else:
current_state = ServerStates.LoadingModel
# Start reading from generator.
dataQueue = None
if task.request.stream_progress_updates:
dataQueue = task.buffer_queue
for result in res:
if current_state == ServerStates.LoadingModel:
current_state = ServerStates.Rendering
current_model_path = task.request.use_stable_diffusion_model
if isinstance(current_state_error, SystemExit) or isinstance(current_state_error, StopAsyncIteration) or isinstance(task.error, StopAsyncIteration):
runtime.thread_data.stop_processing = True
if isinstance(current_state_error, StopAsyncIteration):
task.error = current_state_error
current_state_error = None
print(f'Session {task.request.session_id} sent cancel signal for task {id(task)}')
if dataQueue:
dataQueue.put(result)
if isinstance(result, str):
result = json.loads(result)
task.response = result
if 'output' in result:
for out_obj in result['output']:
if 'path' in out_obj:
img_id = out_obj['path'][out_obj['path'].rindex('/') + 1:]
task.temp_images[int(img_id)] = runtime.thread_data.temp_images[out_obj['path'][11:]]
elif 'data' in out_obj:
task.temp_images[result['output'].index(out_obj)] = out_obj['data']
# Before looping back to the generator, mark cache as still alive.
task_cache.keep(task.request.session_id, TASK_TTL)
except Exception as e:
task.error = e
task.lock.release()
tasks_queue.task_done()
print(traceback.format_exc())
continue
dataQueue = None
if task.request.stream_progress_updates:
dataQueue = task.buffer_queue
for result in res:
if current_state == ServerStates.LoadingModel:
current_state = ServerStates.Rendering
current_model_path = task.request.use_stable_diffusion_model
if isinstance(current_state_error, SystemExit) or isinstance(current_state_error, StopAsyncIteration) or isinstance(task.error, StopAsyncIteration):
runtime.stop_processing = True
if isinstance(current_state_error, StopAsyncIteration):
task.error = current_state_error
current_state_error = None
print(f'Session {task.request.session_id} sent cancel signal for task {id(task)}')
if dataQueue:
dataQueue.put(result)
if isinstance(result, str):
result = json.loads(result)
task.response = result
if 'output' in result:
for out_obj in result['output']:
if 'path' in out_obj:
img_id = out_obj['path'][out_obj['path'].rindex('/') + 1:]
task.temp_images[int(img_id)] = runtime.temp_images[out_obj['path'][11:]]
elif 'data' in out_obj:
task.temp_images[result['output'].index(out_obj)] = out_obj['data']
task_cache.keep(task.request.session_id, TASK_TTL)
# Task completed
task.lock.release()
tasks_queue.task_done()
finally:
# Task completed
task.lock.release()
task_cache.keep(task.request.session_id, TASK_TTL)
if isinstance(task.error, StopAsyncIteration):
print(f'Session {task.request.session_id} task {id(task)} cancelled!')
@@ -240,19 +325,62 @@ def thread_render():
print(f'Session {task.request.session_id} task {id(task)} completed.')
current_state = ServerStates.Online
render_thread = threading.Thread(target=thread_render)
def get_cached_task(session_id:str, update_ttl:bool=False):
# By calling keep before tryGet, the task won't be discarded even if it had already expired.
if update_ttl and not task_cache.keep(session_id, TASK_TTL):
# Failed to keep task, already gone.
return None
return task_cache.tryGet(session_id)
def start_render_thread():
# Start Rendering Thread
render_thread.daemon = True
render_thread.start()
def is_first_cuda_device(device):
from . import runtime # When calling runtime from outside thread_render DO NOT USE thread specific attributes or functions.
return runtime.is_first_cuda_device(device)
def is_alive(name=None):
if not manager_lock.acquire(blocking=True, timeout=LOCK_TIMEOUT): raise Exception('is_alive' + ERR_LOCK_FAILED)
nbr_alive = 0
try:
for rthread in render_threads:
if name is not None:
weak_data = weak_thread_data.get(rthread)
if weak_data is None or weak_data['device'] is None:
print('The thread', rthread.name, 'is registered but has no data stored in the task manager.')
continue
thread_name = str(weak_data['device']).lower()
if is_first_cuda_device(name):
if not is_first_cuda_device(thread_name):
continue
elif thread_name != name:
continue
if rthread.is_alive():
nbr_alive += 1
return nbr_alive
finally:
manager_lock.release()
def start_render_thread(device='auto'):
if not manager_lock.acquire(blocking=True, timeout=LOCK_TIMEOUT): raise Exception('start_render_threads' + ERR_LOCK_FAILED)
print('Start new Rendering Thread on device', device)
try:
rthread = threading.Thread(target=thread_render, kwargs={'device': device})
rthread.daemon = True
rthread.name = THREAD_NAME_PREFIX + device
rthread.start()
timeout = LOCK_TIMEOUT
while not rthread.is_alive():
if timeout <= 0: raise Exception('render_thread', rthread.name, 'failed to start before timeout or has crashed.')
timeout -= 1
time.sleep(1)
render_threads.append(rthread)
finally:
manager_lock.release()
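Usage sketch: the server is expected to call this once per entry in config['render_devices'], with the same device strings runtime.device_init accepts:

for device in ['auto']:  # e.g. ['cpu'], ['cuda'], or ['gpu:0', 'gpu:1']
    start_render_thread(device)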
def shutdown_event(): # Signal render thread to close on shutdown
global current_state_error
current_state_error = SystemExit('Application shutting down.')
def render(req : ImageRequest):
if not render_thread.is_alive(): # Render thread is dead
if is_alive() <= 0: # Render thread is dead
raise ChildProcessError('Rendering thread has died.')
# Alive, check if task in cache
task = task_cache.tryGet(req.session_id)
@@ -294,6 +422,12 @@ def render(req : ImageRequest):
new_task = RenderTask(r)
if task_cache.put(r.session_id, new_task, TASK_TTL):
tasks_queue.put(new_task, block=True, timeout=30)
return new_task
# Use twice the normal lock timeout when adding user requests,
# so that task_cache.put fails before this queue append would.
if manager_lock.acquire(blocking=True, timeout=LOCK_TIMEOUT * 2):
try:
tasks_queue.append(new_task)
return new_task
finally:
manager_lock.release()
raise RuntimeError('Failed to add task to cache.')

server.py

@@ -1,3 +1,7 @@
"""server.py: FastAPI SD-UI Web Host.
Notes:
async endpoints always run on the main thread. Otherwise they run on the thread pool.
"""
import json
import traceback
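The docstring's note is standard FastAPI behavior rather than something this commit adds; a minimal self-contained illustration:

import threading
from fastapi import FastAPI

demo = FastAPI()

@demo.get('/loop')
async def on_event_loop():
    # async def: runs on the event loop (the main thread under uvicorn)
    return {'thread': threading.current_thread().name}

@demo.get('/pool')
def on_thread_pool():
    # plain def: FastAPI dispatches it to a worker thread pool
    return {'thread': threading.current_thread().name}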
@@ -16,17 +20,29 @@ UI_PLUGINS_DIR = os.path.abspath(os.path.join(SD_DIR, '..', 'plugins', 'ui'))
OUTPUT_DIRNAME = "Stable Diffusion UI" # in the user's home folder
TASK_TTL = 15 * 60 # Discard last session's task timeout
APP_CONFIG_DEFAULTS = {
# auto: selects the cuda device with the most free memory, cuda: use the currently active cuda device.
'render_devices': ['auto'], # ['cuda'] or ['CPU', 'GPU:0', 'GPU:1', ...] or ['cpu']
'update_branch': 'main',
}
APP_CONFIG_DEFAULT_MODELS = [
# needed to support the legacy installations
'custom-model', # Check if user has a custom model, use it first.
'sd-v1-4', # Default fallback.
]
import asyncio
from fastapi import FastAPI, HTTPException
from fastapi.staticfiles import StaticFiles
from starlette.responses import FileResponse, JSONResponse, StreamingResponse
from pydantic import BaseModel
import logging
import queue, threading, time
from typing import Any, Generator, Hashable, Optional, Union
#import queue, threading, time
from typing import Any, Generator, Hashable, List, Optional, Union
from sd_internal import Request, Response, task_manager
LOOP = asyncio.get_event_loop()
app = FastAPI()
modifiers_cache = None
@@ -42,191 +58,137 @@ NOCACHE_HEADERS={"Cache-Control": "no-cache, no-store, must-revalidate", "Pragma
app.mount('/media', StaticFiles(directory=os.path.join(SD_UI_DIR, 'media')), name="media")
app.mount('/plugins', StaticFiles(directory=UI_PLUGINS_DIR), name="plugins")
class SetAppConfigRequest(BaseModel):
update_branch: str = "main"
# needs to support the legacy installations
def get_initial_model_to_load():
custom_weight_path = os.path.join(SD_DIR, 'custom-model.ckpt')
ckpt_to_use = "sd-v1-4" if not os.path.exists(custom_weight_path) else "custom-model"
ckpt_to_use = os.path.join(SD_DIR, ckpt_to_use)
config = getConfig()
if 'model' in config and 'stable-diffusion' in config['model']:
model_name = config['model']['stable-diffusion']
model_path = resolve_model_to_use(model_name)
if os.path.exists(model_path + '.ckpt'):
ckpt_to_use = model_path
else:
print('Could not find the configured custom model at:', model_path + '.ckpt', '. Using the default one:', ckpt_to_use + '.ckpt')
return ckpt_to_use
def resolve_model_to_use(model_name):
if model_name in ('sd-v1-4', 'custom-model'):
model_path = os.path.join(MODELS_DIR, 'stable-diffusion', model_name)
legacy_model_path = os.path.join(SD_DIR, model_name)
if not os.path.exists(model_path + '.ckpt') and os.path.exists(legacy_model_path + '.ckpt'):
model_path = legacy_model_path
else:
model_path = os.path.join(MODELS_DIR, 'stable-diffusion', model_name)
return model_path
@app.on_event("shutdown")
def shutdown_event(): # Signal render thread to close on shutdown
task_manager.current_state_error = SystemExit('Application shutting down.')
@app.get('/')
def read_root():
return FileResponse(os.path.join(SD_UI_DIR, 'index.html'), headers=NOCACHE_HEADERS)
@app.get('/ping') # Get server and optionally session status.
def ping(session_id:str=None):
if task_manager.is_alive() <= 0: # Render thread is dead.
if task_manager.current_state_error: raise HTTPException(status_code=500, detail=str(task_manager.current_state_error))
raise HTTPException(status_code=500, detail='Render thread is dead.')
if task_manager.current_state_error and not isinstance(task_manager.current_state_error, StopAsyncIteration): raise HTTPException(status_code=500, detail=str(task_manager.current_state_error))
# Alive
response = {'status': str(task_manager.current_state)}
if session_id:
task = task_manager.task_cache.tryGet(session_id)
if task:
response['task'] = id(task)
if task.lock.locked():
response['session'] = 'running'
elif isinstance(task.error, StopAsyncIteration):
response['session'] = 'stopped'
elif task.error:
response['session'] = 'error'
elif not task.buffer_queue.empty():
response['session'] = 'buffer'
elif task.response:
response['session'] = 'completed'
else:
response['session'] = 'pending'
return JSONResponse(response, headers=NOCACHE_HEADERS)
def save_model_to_config(model_name):
config = getConfig()
if 'model' not in config:
config['model'] = {}
config['model']['stable-diffusion'] = model_name
setConfig(config)
@app.post('/render')
def render(req : task_manager.ImageRequest):
try:
save_model_to_config(req.use_stable_diffusion_model)
req.use_stable_diffusion_model = resolve_model_to_use(req.use_stable_diffusion_model)
new_task = task_manager.render(req)
response = {
'status': str(task_manager.current_state),
'queue': len(task_manager.tasks_queue),
'stream': f'/image/stream/{req.session_id}/{id(new_task)}',
'task': id(new_task)
}
return JSONResponse(response, headers=NOCACHE_HEADERS)
except ChildProcessError as e: # Render thread is dead
raise HTTPException(status_code=500, detail=f'Rendering thread has died.') # HTTP500 Internal Server Error
except ConnectionRefusedError as e: # Unstarted task pending, deny queueing more than one.
raise HTTPException(status_code=503, detail=f'Session {req.session_id} has an already pending task.') # HTTP503 Service Unavailable
except Exception as e:
raise HTTPException(status_code=500, detail=str(e))
@app.get('/image/stream/{session_id:str}/{task_id:int}')
def stream(session_id:str, task_id:int):
#TODO Move to WebSockets ??
task = task_manager.task_cache.tryGet(session_id)
if not task: raise HTTPException(status_code=410, detail='No request received.') # HTTP410 Gone
if (id(task) != task_id): raise HTTPException(status_code=409, detail=f'Wrong task id received. Expected:{id(task)}, Received:{task_id}') # HTTP409 Conflict
if task.buffer_queue.empty() and not task.lock.locked():
if task.response:
#print(f'Session {session_id} sending cached response')
return JSONResponse(task.response, headers=NOCACHE_HEADERS)
raise HTTPException(status_code=425, detail='Too Early, task not started yet.') # HTTP425 Too Early
#print(f'Session {session_id} opened live render stream {id(task.buffer_queue)}')
return StreamingResponse(task.read_buffer_generator(), media_type='application/json')
@app.get('/image/stop')
def stop(session_id:str=None):
if not session_id:
if task_manager.current_state == task_manager.ServerStates.Online or task_manager.current_state == task_manager.ServerStates.Unavailable:
raise HTTPException(status_code=409, detail='Not currently running any tasks.') # HTTP409 Conflict
task_manager.current_state_error = StopAsyncIteration('')
return {'OK'}
task = task_manager.task_cache.tryGet(session_id)
if not task: raise HTTPException(status_code=404, detail=f'Session {session_id} has no active task.') # HTTP404 Not Found
if isinstance(task.error, StopAsyncIteration): raise HTTPException(status_code=409, detail=f'Session {session_id} task is already stopped.') # HTTP409 Conflict
task.error = StopAsyncIteration('')
return {'OK'}
@app.get('/image/tmp/{session_id}/{img_id:int}')
def get_image(session_id, img_id):
task = task_manager.task_cache.tryGet(session_id)
if not task: raise HTTPException(status_code=410, detail=f'Session {session_id} has not submitted a task.') # HTTP410 Gone
if not task.temp_images[img_id]: raise HTTPException(status_code=425, detail='Too Early, task data is not available yet.') # HTTP425 Too Early
try:
img_data = task.temp_images[img_id]
if isinstance(img_data, str):
return img_data
img_data.seek(0)
return StreamingResponse(img_data, media_type='image/jpeg')
except KeyError as e:
raise HTTPException(status_code=500, detail=str(e))
@app.post('/app_config')
async def setAppConfig(req : SetAppConfigRequest):
try:
config = {
'update_branch': req.update_branch
}
config_json_str = json.dumps(config)
config_bat_str = f'@set update_branch={req.update_branch}'
config_sh_str = f'export update_branch={req.update_branch}'
config_json_path = os.path.join(CONFIG_DIR, 'config.json')
config_bat_path = os.path.join(CONFIG_DIR, 'config.bat')
config_sh_path = os.path.join(CONFIG_DIR, 'config.sh')
with open(config_json_path, 'w', encoding='utf-8') as f:
f.write(config_json_str)
with open(config_bat_path, 'w', encoding='utf-8') as f:
f.write(config_bat_str)
with open(config_sh_path, 'w', encoding='utf-8') as f:
f.write(config_sh_str)
return {'OK'}
except Exception as e:
print(traceback.format_exc())
raise HTTPException(status_code=500, detail=str(e))
config_cached = None
config_last_mod_time = 0
def getConfig(default_val=APP_CONFIG_DEFAULTS):
global config_cached, config_last_mod_time
try:
config_json_path = os.path.join(CONFIG_DIR, 'config.json')
if not os.path.exists(config_json_path):
return default_val
if config_last_mod_time > 0 and config_cached is not None:
# Don't read if file was not modified
mtime = os.path.getmtime(config_json_path)
if mtime <= config_last_mod_time:
return config_cached
with open(config_json_path, 'r', encoding='utf-8') as f:
config_cached = json.load(f)
config_last_mod_time = os.path.getmtime(config_json_path)
return config_cached
except Exception as e:
print(str(e))
print(traceback.format_exc())
return default_val
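# Editor's sketch (hedged): how the mtime cache above behaves. `_demo_config_cache`
# is an illustrative name, defined here only as an example and never called.
def _demo_config_cache():
    cfg_first = getConfig()   # first call: reads config.json (or returns default_val when the file is absent)
    cfg_second = getConfig()  # unchanged mtime: served from config_cached, without re-reading the file
    config_json_path = os.path.join(CONFIG_DIR, 'config.json')
    if os.path.exists(config_json_path):
        os.utime(config_json_path, None)  # bump the file's mtime...
        cfg_second = getConfig()          # ...so this call re-reads it from disk
    return cfg_first, cfg_second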
def setConfig(config):
    try: # config.json
        config_json_path = os.path.join(CONFIG_DIR, 'config.json')
        with open(config_json_path, 'w', encoding='utf-8') as f:
            json.dump(config, f) # no early return here; config.bat and config.sh below must still be written
    except:
        print(traceback.format_exc())
if 'render_devices' in config:
gpu_devices = list(filter(lambda dev: dev.lower().startswith('gpu') or dev.lower().startswith('cuda'), config['render_devices'])) # list() so len() and repeated iteration below work
else:
gpu_devices = []
has_first_cuda_device = False
for device in gpu_devices:
if not task_manager.is_first_cuda_device(device): continue
has_first_cuda_device = True
break
if len(gpu_devices) > 0 and not has_first_cuda_device:
print('WARNING: GFPGANer only works on GPU:0, use CUDA_VISIBLE_DEVICES if GFPGANer is needed on a specific GPU.')
print('Using CUDA_VISIBLE_DEVICES will remap the selected devices starting at GPU:0, which fixes GFPGANer.')
try: # config.bat
config_bat = [
f"@set update_branch={config['update_branch']}"
]
if len(gpu_devices) > 0 and not has_first_cuda_device:
config_bat.append('::Set the devices visible inside SD-UI here')
config_bat.append(f"::@set CUDA_VISIBLE_DEVICES={','.join(gpu_devices)}") # Needs better detection for edge cases, add as a comment for now.
print('Add the line "@set CUDA_VISIBLE_DEVICES=N" where N is the GPUs to use to config.bat')
config_bat_path = os.path.join(CONFIG_DIR, 'config.bat')
with open(config_bat_path, 'w', encoding='utf-8') as f:
f.write('\r\n'.join(config_bat))
except Exception as e:
print(traceback.format_exc())
try: # config.sh
config_sh = [
'#!/bin/bash',
f"export update_branch={config['update_branch']}"
]
if len(gpu_devices) > 0 and not has_first_cuda_device:
config_sh.append('#Set the devices visible inside SD-UI here')
config_sh.append(f"#CUDA_VISIBLE_DEVICES={','.join(gpu_devices)}") # Needs better detection for edge cases, add as a comment for now.
print('Add the line "CUDA_VISIBLE_DEVICES=N" where N is the GPUs to use to config.sh')
config_sh_path = os.path.join(CONFIG_DIR, 'config.sh')
with open(config_sh_path, 'w', encoding='utf-8') as f:
f.write('\n'.join(config_sh))
except Exception as e:
print(traceback.format_exc())
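# Editor's sketch (hedged): the artifacts setConfig() writes for a sample config
# (values illustrative):
#   setConfig({'update_branch': 'main', 'render_devices': ['cuda:1']})
#     config.json -> {"update_branch": "main", "render_devices": ["cuda:1"]}
#     config.bat  -> @set update_branch=main           (CRLF-joined lines)
#     config.sh   -> #!/bin/bash
#                    export update_branch=main         (LF-joined lines)
# Because 'cuda:1' is a GPU device but not the first cuda device, both scripts also
# receive a commented-out CUDA_VISIBLE_DEVICES=cuda:1 hint, and the warnings above print.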
def resolve_model_to_use(model_name:str=None):
if not model_name: # When None try user configured model.
config = getConfig()
if 'model' in config and 'stable-diffusion' in config['model']:
model_name = config['model']['stable-diffusion']
if model_name:
# Check models directory
models_dir_path = os.path.join(MODELS_DIR, 'stable-diffusion', model_name)
if os.path.exists(models_dir_path + '.ckpt'):
return models_dir_path
if os.path.exists(model_name + '.ckpt'):
# Direct Path to file
model_name = os.path.abspath(model_name)
return model_name
# Default locations
if model_name in APP_CONFIG_DEFAULT_MODELS:
default_model_path = os.path.join(SD_DIR, model_name)
if os.path.exists(default_model_path + '.ckpt'):
return default_model_path
# Can't find requested model, check the default paths.
for default_model in APP_CONFIG_DEFAULT_MODELS:
default_model_path = os.path.join(SD_DIR, default_model)
if os.path.exists(default_model_path + '.ckpt'):
if model_name is not None:
print(f'Could not find the configured custom model {model_name}.ckpt. Using the default one: {default_model_path}.ckpt')
return default_model_path
raise Exception('No valid models found.')
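# Editor's sketch (hedged): the resolution order implemented above, using the
# illustrative name 'custom-model' (not a bundled model):
#   resolve_model_to_use('custom-model') checks, in order:
#     1. MODELS_DIR/stable-diffusion/custom-model.ckpt   (user models folder)
#     2. custom-model.ckpt as a direct path (returned as an absolute path)
#     3. SD_DIR/custom-model.ckpt, only if 'custom-model' is in APP_CONFIG_DEFAULT_MODELS
#     4. SD_DIR/<default>.ckpt for each default model, after printing a warning
#   resolve_model_to_use(None) first consults config['model']['stable-diffusion'].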
class SetAppConfigRequest(BaseModel):
update_branch: str = None
render_devices: Union[List[str], List[int], str, int] = None
@app.post('/app_config')
async def setAppConfig(req : SetAppConfigRequest):
config = getConfig()
if req.update_branch:
config['update_branch'] = req.update_branch
if req.render_devices is not None: # a string, a comma-separated string, a list of strings/ints, or a bare int.
    render_devices = []
    if isinstance(req.render_devices, str):
        req.render_devices = req.render_devices.split(',')
    if isinstance(req.render_devices, int):
        req.render_devices = [req.render_devices]
    if isinstance(req.render_devices, list):
        for gpu in req.render_devices:
            if isinstance(gpu, int):
                render_devices.append('GPU:' + str(gpu))
            else:
                render_devices.append(gpu)
    if len(render_devices) > 0:
        config['render_devices'] = render_devices
try:
setConfig(config)
return JSONResponse({'status': 'OK'}, headers=NOCACHE_HEADERS)
except Exception as e:
print(traceback.format_exc())
raise HTTPException(status_code=500, detail=str(e))
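# Editor's sketch (hedged): a pure helper mirroring the render_devices normalization
# performed by the endpoint above. `_normalize_render_devices` is an illustrative
# name, not used elsewhere in the server.
def _normalize_render_devices(value):
    if isinstance(value, str):
        value = value.split(',')  # 'cuda:0,cuda:1' -> ['cuda:0', 'cuda:1']
    if isinstance(value, int):
        value = [value]           # 2 -> [2]
    return ['GPU:' + str(v) if isinstance(v, int) else v for v in value]
# e.g. _normalize_render_devices('cuda:0,cuda:1') == ['cuda:0', 'cuda:1']
#      _normalize_render_devices([0, 'cuda:1'])   == ['GPU:0', 'cuda:1']
#      _normalize_render_devices(2)               == ['GPU:2']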
def getModels():
models = {
'active': {
@ -282,6 +244,112 @@ def read_web_data(key:str=None):
else:
raise HTTPException(status_code=404, detail=f'Request for unknown {key}') # HTTP404 Not Found
@app.get('/ping') # Get server and optionally session status.
def ping(session_id:str=None):
if task_manager.is_alive() <= 0: # Check that render threads are alive.
if task_manager.current_state_error: raise HTTPException(status_code=500, detail=str(task_manager.current_state_error))
raise HTTPException(status_code=500, detail='Render thread is dead.')
if task_manager.current_state_error and not isinstance(task_manager.current_state_error, StopAsyncIteration): raise HTTPException(status_code=500, detail=str(task_manager.current_state_error))
# Alive
response = {'status': str(task_manager.current_state)}
if session_id:
task = task_manager.get_cached_task(session_id, update_ttl=True)
if task:
response['task'] = id(task)
if task.lock.locked():
response['session'] = 'running'
elif isinstance(task.error, StopAsyncIteration):
response['session'] = 'stopped'
elif task.error:
response['session'] = 'error'
elif not task.buffer_queue.empty():
response['session'] = 'buffer'
elif task.response:
response['session'] = 'completed'
else:
response['session'] = 'pending'
return JSONResponse(response, headers=NOCACHE_HEADERS)
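# Editor's sketch (hedged): polling /ping from a client, assuming the default
# http://localhost:9000 address opened at the bottom of this file. The `requests`
# dependency and the `_poll_session` name are illustrative.
def _poll_session(session_id):
    import requests
    r = requests.get('http://localhost:9000/ping', params={'session_id': session_id})
    info = r.json()  # e.g. {'status': 'Online', 'session': 'running', 'task': 140234}
    return info.get('session')  # 'running', 'stopped', 'error', 'buffer', 'completed' or 'pending'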
def save_model_to_config(model_name):
config = getConfig()
if 'model' not in config:
config['model'] = {}
config['model']['stable-diffusion'] = model_name
setConfig(config)
@app.post('/render')
def render(req : task_manager.ImageRequest):
if req.use_cpu and task_manager.is_alive('cpu') <= 0: raise HTTPException(status_code=403, detail='CPU rendering is not enabled in config.json, or the thread has died...') # HTTP403 Forbidden
if req.use_face_correction and task_manager.is_alive(0) <= 0: #TODO Remove when GFPGANer is fixed upstream.
raise HTTPException(status_code=412, detail='GFPGANer only works on GPU:0, use CUDA_VISIBLE_DEVICES if GFPGANer is needed on a specific GPU.') # HTTP412 Precondition Failed
try:
save_model_to_config(req.use_stable_diffusion_model)
req.use_stable_diffusion_model = resolve_model_to_use(req.use_stable_diffusion_model)
new_task = task_manager.render(req)
response = {
'status': str(task_manager.current_state),
'queue': len(task_manager.tasks_queue),
'stream': f'/image/stream/{req.session_id}/{id(new_task)}',
'task': id(new_task)
}
return JSONResponse(response, headers=NOCACHE_HEADERS)
except ChildProcessError as e: # Render thread is dead
raise HTTPException(status_code=500, detail=f'Rendering thread has died.') # HTTP500 Internal Server Error
except ConnectionRefusedError as e: # Unstarted task pending, deny queueing more than one.
raise HTTPException(status_code=503, detail=f'Session {req.session_id} has an already pending task.') # HTTP503 Service Unavailable
except Exception as e:
raise HTTPException(status_code=500, detail=str(e))
@app.get('/image/stream/{session_id:str}/{task_id:int}')
def stream(session_id:str, task_id:int):
#TODO Move to WebSockets ??
task = task_manager.get_cached_task(session_id, update_ttl=True)
if not task: raise HTTPException(status_code=410, detail='No request received.') # HTTP410 Gone
if (id(task) != task_id): raise HTTPException(status_code=409, detail=f'Wrong task id received. Expected:{id(task)}, Received:{task_id}') # HTTP409 Conflict
if task.buffer_queue.empty() and not task.lock.locked():
if task.response:
#print(f'Session {session_id} sending cached response')
return JSONResponse(task.response, headers=NOCACHE_HEADERS)
raise HTTPException(status_code=425, detail='Too Early, task not started yet.') # HTTP425 Too Early
#print(f'Session {session_id} opened live render stream {id(task.buffer_queue)}')
return StreamingResponse(task.read_buffer_generator(), media_type='application/json')
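# Editor's sketch (hedged): following the live render stream from a client. The raw
# body is the sequence of JSON fragments produced by read_buffer_generator(); parsing
# is left out here. The `requests` dependency and the `_follow_stream` name are illustrative.
def _follow_stream(session_id, task_id):
    import requests
    url = f'http://localhost:9000/image/stream/{session_id}/{task_id}'
    with requests.get(url, stream=True) as r:
        for chunk in r.iter_content(chunk_size=None):
            print('received', len(chunk), 'bytes')  # progress updates, then the final response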
@app.get('/image/stop')
def stop(session_id:str=None):
if not session_id:
if task_manager.current_state == task_manager.ServerStates.Online or task_manager.current_state == task_manager.ServerStates.Unavailable:
raise HTTPException(status_code=409, detail='Not currently running any tasks.') # HTTP409 Conflict
task_manager.current_state_error = StopAsyncIteration('')
return {'OK'}
task = task_manager.get_cached_task(session_id, update_ttl=False)
if not task: raise HTTPException(status_code=404, detail=f'Session {session_id} has no active task.') # HTTP404 Not Found
if isinstance(task.error, StopAsyncIteration): raise HTTPException(status_code=409, detail=f'Session {session_id} task is already stopped.') # HTTP409 Conflict
task.error = StopAsyncIteration('')
return {'OK'}
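# Editor's sketch (hedged): the two stop modes above, assuming the default
# http://localhost:9000 address opened at the bottom of this file.
#   GET /image/stop                  -> signals every render thread to stop its current task
#   GET /image/stop?session_id=<id>  -> stops only that session's pending/running task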
@app.get('/image/tmp/{session_id}/{img_id:int}')
def get_image(session_id, img_id):
task = task_manager.get_cached_task(session_id, update_ttl=True)
if not task: raise HTTPException(status_code=410, detail=f'Session {session_id} has not submitted a task.') # HTTP410 Gone
if not task.temp_images[img_id]: raise HTTPException(status_code=425, detail='Too Early, task data is not available yet.') # HTTP425 Too Early
try:
img_data = task.temp_images[img_id]
if isinstance(img_data, str):
return img_data
img_data.seek(0)
return StreamingResponse(img_data, media_type='image/jpeg')
except KeyError as e:
raise HTTPException(status_code=500, detail=str(e))
@app.get('/')
def read_root():
return FileResponse(os.path.join(SD_UI_DIR, 'index.html'), headers=NOCACHE_HEADERS)
@app.on_event("shutdown")
def shutdown_event(): # Signal render thread to close on shutdown
task_manager.current_state_error = SystemExit('Application shutting down.')
# don't log certain requests
class LogSuppressFilter(logging.Filter):
def filter(self, record: logging.LogRecord) -> bool:
@ -292,8 +360,56 @@ class LogSuppressFilter(logging.Filter):
return True
logging.getLogger('uvicorn.access').addFilter(LogSuppressFilter())
task_manager.default_model_to_load = get_initial_model_to_load()
task_manager.start_render_thread()
config = getConfig()
async def check_status(): # Task to Validate user config shortly after startup.
# Check that the loaded config.json yielded a server in a known valid state.
# When issues are found, try to fix them when possible and warn the user.
device_count = 0
# Wait for devices to register and/or change names.
THREAD_START_DELAY = 5 # seconds - Give time for devices/threads to start.
for i in range(10): # Maximum number of retries.
await asyncio.sleep(THREAD_START_DELAY)
new_count = task_manager.is_alive()
# Stop retrying once the device count has stabilized (no new devices showing up).
if new_count > 0 and device_count == new_count: break
device_count = new_count
if 'render_devices' in config and task_manager.is_alive() <= 0: # No running devices, probably invalid user config. Try to apply defaults.
print('WARNING: No active render devices after loading config. Validate "render_devices" in config.json')
task_manager.start_render_thread('auto') # Detect best device for renders
task_manager.start_render_thread('cpu') # Allow CPU to be used for renders
await asyncio.sleep(THREAD_START_DELAY) # delay message after thread start.
print('Default render devices started, replacing the render_devices from config.json that failed to load:', config['render_devices'])
display_warning = False
if 'render_devices' not in config and task_manager.is_alive(0) <= 0: # No config set, running in auto mode and without cuda:0
task_manager.start_render_thread('cuda') # 'auto' chose a different cuda device and cuda:0 is missing, start it...
display_warning = True # And warn the user to update their settings...
await asyncio.sleep(THREAD_START_DELAY) # delay message after thread start.
if display_warning or task_manager.is_alive(0) <= 0:
print('WARNING: GFPGANer only works on GPU:0, use CUDA_VISIBLE_DEVICES if GFPGANer is needed on a specific GPU.')
print('Using CUDA_VISIBLE_DEVICES will remap the selected devices starting at GPU:0, which fixes GFPGANer.')
print('Add the line "@set CUDA_VISIBLE_DEVICES=N" to config.bat, where N is the GPUs to use.')
print('Add the line "CUDA_VISIBLE_DEVICES=N" to config.sh, where N is the GPUs to use.')
# Start the task_manager
task_manager.default_model_to_load = resolve_model_to_use()
if 'render_devices' in config: # Start a new thread for each device.
if isinstance(config['render_devices'], str):
config['render_devices'] = config['render_devices'].split(',')
if not isinstance(config['render_devices'], list):
raise Exception('Invalid render_devices value in config.')
for device in config['render_devices']:
task_manager.start_render_thread(device)
else:
# Select the best GPU by free memory, when more than one device is present.
task_manager.start_render_thread('auto') # Detect best device for renders
task_manager.start_render_thread('cpu') # Allow CPU to be used for renders
# Task to Validate user config shortly after startup.
LOOP.create_task(check_status())
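# Editor's sketch (hedged): a config.json exercising the startup paths above. The
# values are illustrative, not defaults shipped with the project.
#   {
#       "update_branch": "main",
#       "render_devices": ["cuda:0", "cuda:1"],
#       "model": {"stable-diffusion": "sd-v1-4"}
#   }
# When "render_devices" is absent, the 'auto' and 'cpu' threads above are started instead.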
# start the browser ui
import webbrowser; webbrowser.open('http://localhost:9000')