First draft for Multi-GPU support

Marc-Andre Ferland 2022-10-16 21:41:39 -04:00
parent 2edc06c662
commit 7c72608e1c
3 changed files with 584 additions and 376 deletions


@ -35,63 +35,144 @@ import base64
from io import BytesIO from io import BytesIO
#from colorama import Fore #from colorama import Fore
# local from threading import local as LocalThreadVars
stop_processing = False thread_data = LocalThreadVars()
temp_images = {}
ckpt_file = None def device_would_fail(device):
gfpgan_file = None if device == 'cpu': return None
real_esrgan_file = None # Returns None when no issues found, otherwise returns the detected error str.
# Memory check
model = None try:
modelCS = None mem_free, mem_total = torch.cuda.mem_get_info(device)
modelFS = None
model_gfpgan = None
model_real_esrgan = None
model_is_half = False
model_fs_is_half = False
device = None
unet_bs = 1
precision = 'autocast'
sampler_plms = None
sampler_ddim = None
has_valid_gpu = False
force_full_precision = False
try:
gpu = torch.cuda.current_device()
gpu_name = torch.cuda.get_device_name(gpu)
print('GPU detected: ', gpu_name)
force_full_precision = ('nvidia' in gpu_name.lower() or 'geforce' in gpu_name.lower()) and (' 1660' in gpu_name or ' 1650' in gpu_name) # otherwise these NVIDIA cards create green images
if force_full_precision:
print('forcing full precision on NVIDIA 16xx cards, to avoid green images. GPU detected: ', gpu_name)
mem_free, mem_total = torch.cuda.mem_get_info(gpu)
mem_total /= float(10**9) mem_total /= float(10**9)
if mem_total < 3.0: if mem_total < 3.0:
print("GPUs with less than 3 GB of VRAM are not compatible with Stable Diffusion") return 'GPUs with less than 3 GB of VRAM are not compatible with Stable Diffusion'
raise Exception() except RuntimeError as e:
return str(e) # Return cuda errors from mem_get_info as strings
return None
has_valid_gpu = True def device_select(device):
except: if device == 'cpu': return True
if not torch.cuda.is_available(): return False
failure_msg = device_would_fail(device)
if failure_msg:
if 'invalid device' in failure_msg:
raise NameError(f'GPU "{device}" could not be found. Remove this device from config.render_devices or use one of "auto" or "cuda".')
print(failure_msg)
return False
device_name = torch.cuda.get_device_name(device)
# otherwise these NVIDIA cards create green images
thread_data.force_full_precision = ('nvidia' in device_name.lower() or 'geforce' in device_name.lower()) and (' 1660' in device_name or ' 1650' in device_name)
if thread_data.force_full_precision:
print('forcing full precision on NVIDIA 16xx cards, to avoid green images. GPU detected: ', device_name)
thread_data.device = device
thread_data.has_valid_gpu = True
return True
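As a quick aside (not part of this commit), the two helpers above compose into a simple probe: device_would_fail reports why a CUDA index is unusable and device_select binds the calling thread to it. A minimal sketch, with usable_devices as a hypothetical helper name:

    import torch

    def usable_devices():
        # List every CUDA index that device_would_fail raises no objection to.
        return [d for d in range(torch.cuda.device_count())
                if device_would_fail(d) is None]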
def device_init(device_selection=None):
# Thread bound properties
thread_data.stop_processing = False
thread_data.temp_images = {}
thread_data.ckpt_file = None
thread_data.gfpgan_file = None
thread_data.real_esrgan_file = None
thread_data.model = None
thread_data.modelCS = None
thread_data.modelFS = None
thread_data.model_gfpgan = None
thread_data.model_real_esrgan = None
thread_data.model_is_half = False
thread_data.model_fs_is_half = False
thread_data.device = None
thread_data.unet_bs = 1
thread_data.precision = 'autocast'
thread_data.sampler_plms = None
thread_data.sampler_ddim = None
thread_data.turbo = False
thread_data.has_valid_gpu = False
thread_data.force_full_precision = False
if isinstance(device_selection, str) and device_selection.lower() == 'cpu':
print('CPU requested, skipping gpu init.')
thread_data.device = 'cpu'
return
if not torch.cuda.is_available():
print('WARNING: torch.cuda is not available. Using the CPU, but this will be very slow!')
thread_data.device = 'cpu'
return
device_count = torch.cuda.device_count()
if device_count <= 1 and device_selection == 'auto':
device_selection = 'current' # Use 'auto' only when there is more than one compatible device found.
if device_selection == 'auto':
print('Autoselecting GPU. Using most free memory.')
max_mem_free = 0
best_device = None
for device in range(device_count):
mem_free, mem_total = torch.cuda.mem_get_info(device)
mem_free /= float(10**9)
mem_total /= float(10**9)
device_name = torch.cuda.get_device_name(device)
print(f'GPU:{device} detected: {device_name} - Memory: {round(mem_total - mem_free, 2)}GB / {round(mem_total, 2)}GB')
if max_mem_free < mem_free:
max_mem_free = mem_free
best_device = device
if best_device is not None and device_select(best_device):
print(f'Setting GPU:{best_device} as active')
torch.cuda.device(best_device)
return
if isinstance(device_selection, str):
device_selection = device_selection.lower()
if device_selection.startswith('gpu:'):
device_selection = int(device_selection[4:])
if device_selection != 'cuda' and device_selection != 'current':
if device_select(device_selection):
if isinstance(device_selection, int):
print(f'Setting GPU:{device_selection} as active')
else:
print(f'Setting {device_selection} as active')
torch.cuda.device(device_selection)
return
# By default use current device.
print('Checking current GPU...')
device = torch.cuda.current_device()
device_name = torch.cuda.get_device_name(device)
print(f'GPU:{device} detected: {device_name}')
if device_select(device):
return
print('WARNING: No compatible GPU found. Using the CPU, but this will be very slow!') print('WARNING: No compatible GPU found. Using the CPU, but this will be very slow!')
pass thread_data.device = 'cpu'
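To make the thread-local design above concrete, a minimal sketch of the per-thread startup order device_init expects, mirroring how the task manager further down uses it (the runtime import path and the ckpt_path argument are illustrative assumptions):

    from sd_internal import runtime

    def render_worker(device_selection, ckpt_path):
        runtime.device_init(device_selection)      # binds runtime.thread_data.device for this thread only
        runtime.thread_data.ckpt_file = ckpt_path  # load_model_ckpt() reads the path from thread_data
        runtime.load_model_ckpt()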
def load_model_ckpt(ckpt_to_use, device_to_use='cuda', turbo=False, unet_bs_to_use=1, precision_to_use='autocast'): def is_first_cuda_device(device):
global ckpt_file, model, modelCS, modelFS, model_is_half, device, unet_bs, precision, model_fs_is_half if thread_data.device == 0 or thread_data.device == '0':
return True
if thread_data.device == 'cuda' or thread_data.device == 'cuda:0':
return True
if thread_data.device == torch.device(0):
return True
return False
device = device_to_use if has_valid_gpu else 'cpu' def load_model_ckpt():
precision = precision_to_use if not force_full_precision else 'full' if not thread_data.ckpt_file: raise ValueError(f'Thread ckpt_file is undefined.')
unet_bs = unet_bs_to_use if not os.path.exists(thread_data.ckpt_file + '.ckpt'): raise FileNotFoundError(f'Cannot find {thread_data.ckpt_file}.ckpt')
if not thread_data.precision:
thread_data.precision = 'full' if thread_data.force_full_precision else 'autocast'
if not thread_data.unet_bs:
thread_data.unet_bs = 1
unload_model() unload_model()
if device == 'cpu': if thread_data.device == 'cpu':
precision = 'full' thread_data.precision = 'full'
sd = load_model_from_config(f"{ckpt_to_use}.ckpt") print('loading', thread_data.ckpt_file, 'to', thread_data.device, 'using precision', thread_data.precision)
sd = load_model_from_config(thread_data.ckpt_file + '.ckpt')
li, lo = [], [] li, lo = [], []
for key, value in sd.items(): for key, value in sd.items():
sp = key.split(".") sp = key.split(".")
@ -114,88 +195,84 @@ def load_model_ckpt(ckpt_to_use, device_to_use='cuda', turbo=False, unet_bs_to_u
model = instantiate_from_config(config.modelUNet) model = instantiate_from_config(config.modelUNet)
_, _ = model.load_state_dict(sd, strict=False) _, _ = model.load_state_dict(sd, strict=False)
model.eval() model.eval()
model.cdevice = device model.cdevice = torch.device(thread_data.device)
model.unet_bs = unet_bs model.unet_bs = thread_data.unet_bs
model.turbo = turbo model.turbo = thread_data.turbo
if thread_data.device != 'cpu':
model.to(thread_data.device)
thread_data.model = model
modelCS = instantiate_from_config(config.modelCondStage) modelCS = instantiate_from_config(config.modelCondStage)
_, _ = modelCS.load_state_dict(sd, strict=False) _, _ = modelCS.load_state_dict(sd, strict=False)
modelCS.eval() modelCS.eval()
modelCS.cond_stage_model.device = device modelCS.cond_stage_model.device = torch.device(thread_data.device)
if thread_data.device != 'cpu':
modelCS.to(thread_data.device)
thread_data.modelCS = modelCS
modelFS = instantiate_from_config(config.modelFirstStage) modelFS = instantiate_from_config(config.modelFirstStage)
_, _ = modelFS.load_state_dict(sd, strict=False) _, _ = modelFS.load_state_dict(sd, strict=False)
modelFS.eval() modelFS.eval()
if thread_data.device != 'cpu':
modelFS.to(thread_data.device)
thread_data.modelFS = modelFS
del sd del sd
if device != "cpu" and precision == "autocast": if thread_data.device != "cpu" and thread_data.precision == "autocast":
model.half() thread_data.model.half()
modelCS.half() thread_data.modelCS.half()
modelFS.half() thread_data.modelFS.half()
model_is_half = True thread_data.model_is_half = True
model_fs_is_half = True thread_data.model_fs_is_half = True
else: else:
model_is_half = False thread_data.model_is_half = False
model_fs_is_half = False thread_data.model_fs_is_half = False
ckpt_file = ckpt_to_use print('loaded', thread_data.ckpt_file, 'as', model.device, '->', modelCS.cond_stage_model.device, '->', thread_data.modelFS.device, 'using precision', thread_data.precision)
print('loaded ', ckpt_file, 'to', device, 'precision', precision)
def unload_model(): def unload_model():
global model, modelCS, modelFS if thread_data.model is not None:
print('Unloading models...')
del thread_data.model
del thread_data.modelCS
del thread_data.modelFS
thread_data.model = None
thread_data.modelCS = None
thread_data.modelFS = None
if model is not None: def load_model_gfpgan():
del model if thread_data.gfpgan_file is None:
del modelCS print('load_model_gfpgan called without setting gfpgan_file')
del modelFS
model = None
modelCS = None
modelFS = None
def load_model_gfpgan(gfpgan_to_use):
global gfpgan_file, model_gfpgan
if gfpgan_to_use is None:
return return
if thread_data.device != 'cpu' and not is_first_cuda_device(thread_data.device):
#TODO Remove when fixed - A bug with GFPGANer and facexlib needs to be fixed before use on other devices.
raise Exception(f'Current device {torch.device(thread_data.device)} is not {torch.device(0)}.')
model_path = thread_data.gfpgan_file + ".pth"
thread_data.model_gfpgan = GFPGANer(device=torch.device(thread_data.device), model_path=model_path, upscale=1, arch='clean', channel_multiplier=2, bg_upsampler=None)
print('loaded', thread_data.gfpgan_file, 'to', thread_data.model_gfpgan.device, 'precision', thread_data.precision)
gfpgan_file = gfpgan_to_use def load_model_real_esrgan():
model_path = gfpgan_to_use + ".pth" if thread_data.real_esrgan_file is None:
print('load_model_real_esrgan called without setting real_esrgan_file')
if device == 'cpu':
model_gfpgan = GFPGANer(model_path=model_path, upscale=1, arch='clean', channel_multiplier=2, bg_upsampler=None, device=torch.device('cpu'))
else:
model_gfpgan = GFPGANer(model_path=model_path, upscale=1, arch='clean', channel_multiplier=2, bg_upsampler=None, device=torch.device('cuda'))
print('loaded ', gfpgan_to_use, 'to', device, 'precision', precision)
def load_model_real_esrgan(real_esrgan_to_use):
global real_esrgan_file, model_real_esrgan
if real_esrgan_to_use is None:
return return
model_path = thread_data.real_esrgan_file + ".pth"
real_esrgan_file = real_esrgan_to_use
model_path = real_esrgan_to_use + ".pth"
RealESRGAN_models = { RealESRGAN_models = {
'RealESRGAN_x4plus': RRDBNet(num_in_ch=3, num_out_ch=3, num_feat=64, num_block=23, num_grow_ch=32, scale=4), 'RealESRGAN_x4plus': RRDBNet(num_in_ch=3, num_out_ch=3, num_feat=64, num_block=23, num_grow_ch=32, scale=4),
'RealESRGAN_x4plus_anime_6B': RRDBNet(num_in_ch=3, num_out_ch=3, num_feat=64, num_block=6, num_grow_ch=32, scale=4) 'RealESRGAN_x4plus_anime_6B': RRDBNet(num_in_ch=3, num_out_ch=3, num_feat=64, num_block=6, num_grow_ch=32, scale=4)
} }
model_to_use = RealESRGAN_models[real_esrgan_to_use] model_to_use = RealESRGAN_models[thread_data.real_esrgan_file]
if device == 'cpu': if thread_data.device == 'cpu':
model_real_esrgan = RealESRGANer(scale=2, model_path=model_path, model=model_to_use, pre_pad=0, half=False) # cpu does not support half thread_data.model_real_esrgan = RealESRGANer(device=torch.device(thread_data.device), scale=2, model_path=model_path, model=model_to_use, pre_pad=0, half=False) # cpu does not support half
model_real_esrgan.device = torch.device('cpu') #thread_data.model_real_esrgan.device = torch.device(thread_data.device)
model_real_esrgan.model.to('cpu') thread_data.model_real_esrgan.model.to('cpu')
else: else:
model_real_esrgan = RealESRGANer(scale=2, model_path=model_path, model=model_to_use, pre_pad=0, half=model_is_half) thread_data.model_real_esrgan = RealESRGANer(device=torch.device(thread_data.device), scale=2, model_path=model_path, model=model_to_use, pre_pad=0, half=thread_data.model_is_half)
model_real_esrgan.model.name = real_esrgan_to_use thread_data.model_real_esrgan.model.name = thread_data.real_esrgan_file
print('loaded ', thread_data.real_esrgan_file, 'to', thread_data.model_real_esrgan.device, 'precision', thread_data.precision)
print('loaded ', real_esrgan_to_use, 'to', device, 'precision', precision)
def get_base_path(disk_path, session_id, prompt, ext, suffix=None): def get_base_path(disk_path, session_id, prompt, ext, suffix=None):
if disk_path is None: return None if disk_path is None: return None
@ -214,14 +291,22 @@ def get_base_path(disk_path, session_id, prompt, ext, suffix=None):
def apply_filters(filter_name, image_data): def apply_filters(filter_name, image_data):
print(f'Applying filter {filter_name}...') print(f'Applying filter {filter_name}...')
if isinstance(image_data, torch.Tensor):
print(image_data)
image_data = image_data.to(thread_data.device)
gc() gc()
if filter_name == 'gfpgan': if filter_name == 'gfpgan':
_, _, output = model_gfpgan.enhance(image_data[:,:,::-1], has_aligned=False, only_center_face=False, paste_back=True) if thread_data.model_gfpgan is None: raise Exception('Model "gfpgan" not loaded.')
print('enhance with', thread_data.gfpgan_file, 'on', thread_data.model_gfpgan.device, 'precision', thread_data.precision)
_, _, output = thread_data.model_gfpgan.enhance(image_data[:,:,::-1], has_aligned=False, only_center_face=False, paste_back=True)
image_data = output[:,:,::-1] image_data = output[:,:,::-1]
if filter_name == 'real_esrgan': if filter_name == 'real_esrgan':
output, _ = model_real_esrgan.enhance(image_data[:,:,::-1]) if thread_data.model_real_esrgan is None: raise Exception('Model "real_esrgan" not loaded.')
print('enhance with', thread_data.real_esrgan_file, 'on', thread_data.model_real_esrgan.device, 'precision', thread_data.precision)
output, _ = thread_data.model_real_esrgan.enhance(image_data[:,:,::-1])
image_data = output[:,:,::-1] image_data = output[:,:,::-1]
return image_data return image_data
@ -234,12 +319,12 @@ def mk_img(req: Request):
gc() gc()
if device != "cpu": if thread_data.device != "cpu":
modelFS.to("cpu") thread_data.modelFS.to("cpu")
modelCS.to("cpu") thread_data.modelCS.to("cpu")
model.model1.to("cpu") thread_data.model.model1.to("cpu")
model.model2.to("cpu") thread_data.model.model2.to("cpu")
gc() gc()
@ -249,66 +334,55 @@ def mk_img(req: Request):
}) })
def do_mk_img(req: Request): def do_mk_img(req: Request):
global ckpt_file thread_data.stop_processing = False
global model, modelCS, modelFS, device
global model_gfpgan, model_real_esrgan
global stop_processing
stop_processing = False
res = Response() res = Response()
res.request = req res.request = req
res.images = [] res.images = []
temp_images.clear() thread_data.temp_images.clear()
# custom model support: # custom model support:
# the req.use_stable_diffusion_model needs to be a valid path # the req.use_stable_diffusion_model needs to be a valid path
# to the ckpt file (without the extension). # to the ckpt file (without the extension).
if not os.path.exists(req.use_stable_diffusion_model + '.ckpt'): raise FileNotFoundError(f'Cannot find {req.use_stable_diffusion_model}.ckpt')
needs_model_reload = False needs_model_reload = False
ckpt_to_use = ckpt_file if thread_data.ckpt_file != req.use_stable_diffusion_model:
if ckpt_to_use != req.use_stable_diffusion_model: thread_data.ckpt_file = req.use_stable_diffusion_model
ckpt_to_use = req.use_stable_diffusion_model
needs_model_reload = True needs_model_reload = True
model.turbo = req.turbo
if req.use_cpu: if req.use_cpu:
if device != 'cpu': if thread_data.device != 'cpu':
device = 'cpu' thread_data.device = 'cpu'
if thread_data.model_is_half:
if model_is_half: load_model_ckpt()
load_model_ckpt(ckpt_to_use, device)
needs_model_reload = False needs_model_reload = False
load_model_gfpgan()
load_model_gfpgan(gfpgan_file) load_model_real_esrgan()
load_model_real_esrgan(real_esrgan_file)
else: else:
if has_valid_gpu: if thread_data.has_valid_gpu:
prev_device = device if (thread_data.precision == 'autocast' and (req.use_full_precision or not thread_data.model_is_half)) or \
device = 'cuda' (thread_data.precision == 'full' and not req.use_full_precision and not thread_data.force_full_precision):
thread_data.precision = 'full' if req.use_full_precision else 'autocast'
if (precision == 'autocast' and (req.use_full_precision or not model_is_half)) or \ load_model_ckpt()
(precision == 'full' and not req.use_full_precision and not force_full_precision): load_model_gfpgan()
load_model_real_esrgan()
load_model_ckpt(ckpt_to_use, device, req.turbo, unet_bs, ('full' if req.use_full_precision else 'autocast'))
needs_model_reload = False needs_model_reload = False
if prev_device != device:
load_model_gfpgan(gfpgan_file)
load_model_real_esrgan(real_esrgan_file)
if needs_model_reload: if needs_model_reload:
load_model_ckpt(ckpt_to_use, device, req.turbo, unet_bs, precision) load_model_ckpt()
if req.use_face_correction != gfpgan_file: if req.use_face_correction != thread_data.gfpgan_file:
load_model_gfpgan(req.use_face_correction) thread_data.gfpgan_file = req.use_face_correction
load_model_gfpgan()
if req.use_upscale != thread_data.real_esrgan_file:
thread_data.real_esrgan_file = req.use_upscale
load_model_real_esrgan()
if req.use_upscale != real_esrgan_file: if thread_data.turbo != req.turbo:
load_model_real_esrgan(req.use_upscale) thread_data.turbo = req.turbo
thread_data.model.turbo = req.turbo
model.cdevice = device
modelCS.cond_stage_model.device = device
opt_prompt = req.prompt opt_prompt = req.prompt
opt_seed = req.seed opt_seed = req.seed
@ -318,9 +392,8 @@ def do_mk_img(req: Request):
opt_ddim_eta = 0.0 opt_ddim_eta = 0.0
opt_init_img = req.init_image opt_init_img = req.init_image
print(req.to_string(), '\n device', device) print(req.to_string(), '\n device', thread_data.device)
print('\n\n Using precision:', thread_data.precision)
print('\n\n Using precision:', precision)
seed_everything(opt_seed) seed_everything(opt_seed)
@ -329,7 +402,7 @@ def do_mk_img(req: Request):
assert prompt is not None assert prompt is not None
data = [batch_size * [prompt]] data = [batch_size * [prompt]]
if precision == "autocast" and device != "cpu": if thread_data.precision == "autocast" and thread_data.device != "cpu":
precision_scope = autocast precision_scope = autocast
else: else:
precision_scope = nullcontext precision_scope = nullcontext
@ -345,22 +418,22 @@ def do_mk_img(req: Request):
handler = _img2img handler = _img2img
init_image = load_img(req.init_image, req.width, req.height) init_image = load_img(req.init_image, req.width, req.height)
init_image = init_image.to(device) init_image = init_image.to(thread_data.device)
if device != "cpu" and precision == "autocast": if thread_data.device != "cpu" and thread_data.precision == "autocast":
init_image = init_image.half() init_image = init_image.half()
modelFS.to(device) thread_data.modelFS.to(thread_data.device)
init_image = repeat(init_image, '1 ... -> b ...', b=batch_size) init_image = repeat(init_image, '1 ... -> b ...', b=batch_size)
init_latent = modelFS.get_first_stage_encoding(modelFS.encode_first_stage(init_image)) # move to latent space init_latent = thread_data.modelFS.get_first_stage_encoding(thread_data.modelFS.encode_first_stage(init_image)) # move to latent space
if req.mask is not None: if req.mask is not None:
mask = load_mask(req.mask, req.width, req.height, init_latent.shape[2], init_latent.shape[3], True).to(device) mask = load_mask(req.mask, req.width, req.height, init_latent.shape[2], init_latent.shape[3], True).to(thread_data.device)
mask = mask[0][0].unsqueeze(0).repeat(4, 1, 1).unsqueeze(0) mask = mask[0][0].unsqueeze(0).repeat(4, 1, 1).unsqueeze(0)
mask = repeat(mask, '1 ... -> b ...', b=batch_size) mask = repeat(mask, '1 ... -> b ...', b=batch_size)
if device != "cpu" and precision == "autocast": if thread_data.device != "cpu" and thread_data.precision == "autocast":
mask = mask.half() mask = mask.half()
move_fs_to_cpu() move_fs_to_cpu()
@ -381,10 +454,10 @@ def do_mk_img(req: Request):
for prompts in tqdm(data, desc="data"): for prompts in tqdm(data, desc="data"):
with precision_scope("cuda"): with precision_scope("cuda"):
modelCS.to(device) thread_data.modelCS.to(thread_data.device)
uc = None uc = None
if req.guidance_scale != 1.0: if req.guidance_scale != 1.0:
uc = modelCS.get_learned_conditioning(batch_size * [req.negative_prompt]) uc = thread_data.modelCS.get_learned_conditioning(batch_size * [req.negative_prompt])
if isinstance(prompts, tuple): if isinstance(prompts, tuple):
prompts = list(prompts) prompts = list(prompts)
@ -397,11 +470,11 @@ def do_mk_img(req: Request):
weight = weights[i] weight = weights[i]
# if not skip_normalize: # if not skip_normalize:
weight = weight / totalWeight weight = weight / totalWeight
c = torch.add(c, modelCS.get_learned_conditioning(subprompts[i]), alpha=weight) c = torch.add(c, thread_data.modelCS.get_learned_conditioning(subprompts[i]), alpha=weight)
else: else:
c = modelCS.get_learned_conditioning(prompts) c = thread_data.modelCS.get_learned_conditioning(prompts)
modelFS.to(device) thread_data.modelFS.to(thread_data.device)
partial_x_samples = None partial_x_samples = None
def img_callback(x_samples, i): def img_callback(x_samples, i):
@ -417,7 +490,7 @@ def do_mk_img(req: Request):
partial_images = [] partial_images = []
for i in range(batch_size): for i in range(batch_size):
x_samples_ddim = modelFS.decode_first_stage(x_samples[i].unsqueeze(0)) x_samples_ddim = thread_data.modelFS.decode_first_stage(x_samples[i].unsqueeze(0))
x_sample = torch.clamp((x_samples_ddim + 1.0) / 2.0, min=0.0, max=1.0) x_sample = torch.clamp((x_samples_ddim + 1.0) / 2.0, min=0.0, max=1.0)
x_sample = 255.0 * rearrange(x_sample[0].cpu().numpy(), "c h w -> h w c") x_sample = 255.0 * rearrange(x_sample[0].cpu().numpy(), "c h w -> h w c")
x_sample = x_sample.astype(np.uint8) x_sample = x_sample.astype(np.uint8)
@ -429,18 +502,19 @@ def do_mk_img(req: Request):
del img, x_sample, x_samples_ddim del img, x_sample, x_samples_ddim
# don't delete x_samples, it is used in the code that called this callback # don't delete x_samples, it is used in the code that called this callback
temp_images[str(req.session_id) + '/' + str(i)] = buf thread_data.temp_images[str(req.session_id) + '/' + str(i)] = buf
partial_images.append({'path': f'/image/tmp/{req.session_id}/{i}'}) partial_images.append({'path': f'/image/tmp/{req.session_id}/{i}'})
progress['output'] = partial_images progress['output'] = partial_images
yield json.dumps(progress) yield json.dumps(progress)
if stop_processing: if thread_data.stop_processing:
raise UserInitiatedStop("User requested that we stop processing") raise UserInitiatedStop("User requested that we stop processing")
# run the handler # run the handler
try: try:
print('Running handler...')
if handler == _txt2img: if handler == _txt2img:
x_samples = _txt2img(req.width, req.height, req.num_outputs, req.num_inference_steps, req.guidance_scale, None, opt_C, opt_f, opt_ddim_eta, c, uc, opt_seed, img_callback, mask, req.sampler) x_samples = _txt2img(req.width, req.height, req.num_outputs, req.num_inference_steps, req.guidance_scale, None, opt_C, opt_f, opt_ddim_eta, c, uc, opt_seed, img_callback, mask, req.sampler)
else: else:
@ -458,7 +532,7 @@ def do_mk_img(req: Request):
print("saving images") print("saving images")
for i in range(batch_size): for i in range(batch_size):
x_samples_ddim = modelFS.decode_first_stage(x_samples[i].unsqueeze(0)) x_samples_ddim = thread_data.modelFS.decode_first_stage(x_samples[i].unsqueeze(0))
x_sample = torch.clamp((x_samples_ddim + 1.0) / 2.0, min=0.0, max=1.0) x_sample = torch.clamp((x_samples_ddim + 1.0) / 2.0, min=0.0, max=1.0)
x_sample = 255.0 * rearrange(x_sample[0].cpu().numpy(), "c h w -> h w c") x_sample = 255.0 * rearrange(x_sample[0].cpu().numpy(), "c h w -> h w c")
x_sample = x_sample.astype(np.uint8) x_sample = x_sample.astype(np.uint8)
@ -469,7 +543,7 @@ def do_mk_img(req: Request):
return_orig_img = not has_filters or not req.show_only_filtered_image return_orig_img = not has_filters or not req.show_only_filtered_image
if stop_processing: if thread_data.stop_processing:
return_orig_img = True return_orig_img = True
if req.save_to_disk_path is not None: if req.save_to_disk_path is not None:
@ -489,7 +563,7 @@ def do_mk_img(req: Request):
del img del img
if has_filters and not stop_processing: if has_filters and not thread_data.stop_processing:
filters_applied = [] filters_applied = []
if req.use_face_correction: if req.use_face_correction:
x_sample = apply_filters('gfpgan', x_sample) x_sample = apply_filters('gfpgan', x_sample)
@ -514,7 +588,7 @@ def do_mk_img(req: Request):
move_fs_to_cpu() move_fs_to_cpu()
gc() gc()
del x_samples, x_samples_ddim, x_sample del x_samples, x_samples_ddim, x_sample
print("memory_final = ", torch.cuda.memory_allocated() / 1e6) print(f'memory_final = {round(torch.cuda.memory_allocated(thread_data.device) / 1e6, 2)}Mo')
print('Task completed') print('Task completed')
@ -527,7 +601,7 @@ def save_image(img, img_out_path):
print('could not save the file', traceback.format_exc()) print('could not save the file', traceback.format_exc())
def save_metadata(meta_out_path, req, prompt, opt_seed): def save_metadata(meta_out_path, req, prompt, opt_seed):
metadata = f"""{prompt} metadata = f'''{prompt}
Width: {req.width} Width: {req.width}
Height: {req.height} Height: {req.height}
Seed: {opt_seed} Seed: {opt_seed}
@ -538,8 +612,8 @@ Use Face Correction: {req.use_face_correction}
Use Upscaling: {req.use_upscale} Use Upscaling: {req.use_upscale}
Sampler: {req.sampler} Sampler: {req.sampler}
Negative Prompt: {req.negative_prompt} Negative Prompt: {req.negative_prompt}
Stable Diffusion Model: {req.use_stable_diffusion_model + '.ckpt'} Stable Diffusion model: {req.use_stable_diffusion_model + '.ckpt'}
""" '''
try: try:
with open(meta_out_path, 'w') as f: with open(meta_out_path, 'w') as f:
f.write(metadata) f.write(metadata)
@ -549,16 +623,19 @@ Stable Diffusion Model: {req.use_stable_diffusion_model + '.ckpt'}
def _txt2img(opt_W, opt_H, opt_n_samples, opt_ddim_steps, opt_scale, start_code, opt_C, opt_f, opt_ddim_eta, c, uc, opt_seed, img_callback, mask, sampler_name): def _txt2img(opt_W, opt_H, opt_n_samples, opt_ddim_steps, opt_scale, start_code, opt_C, opt_f, opt_ddim_eta, c, uc, opt_seed, img_callback, mask, sampler_name):
shape = [opt_n_samples, opt_C, opt_H // opt_f, opt_W // opt_f] shape = [opt_n_samples, opt_C, opt_H // opt_f, opt_W // opt_f]
if device != "cpu": if thread_data.device != "cpu":
mem = torch.cuda.memory_allocated() / 1e6 mem = torch.cuda.memory_allocated(thread_data.device) / 1e6
modelCS.to("cpu") print('Device:', thread_data.device, 'CS_Model, Memory transfer starting. Memory Used:', round(mem, 2), 'Mo')
while torch.cuda.memory_allocated() / 1e6 >= mem: thread_data.modelCS.to("cpu")
while torch.cuda.memory_allocated(thread_data.device) / 1e6 >= mem and mem > 0:
print('Device:', thread_data.device, 'Waiting for memory transfer. Memory Used:', round(mem, 2), 'MB')
time.sleep(1) time.sleep(1)
print('Transferred', round(mem - torch.cuda.memory_allocated(thread_data.device) / 1e6, 2), 'MB')
if sampler_name == 'ddim': if sampler_name == 'ddim':
model.make_schedule(ddim_num_steps=opt_ddim_steps, ddim_eta=opt_ddim_eta, verbose=False) thread_data.model.make_schedule(ddim_num_steps=opt_ddim_steps, ddim_eta=opt_ddim_eta, verbose=False)
samples_ddim = model.sample( samples_ddim = thread_data.model.sample(
S=opt_ddim_steps, S=opt_ddim_steps,
conditioning=c, conditioning=c,
seed=opt_seed, seed=opt_seed,
@ -572,14 +649,13 @@ def _txt2img(opt_W, opt_H, opt_n_samples, opt_ddim_steps, opt_scale, start_code,
mask=mask, mask=mask,
sampler = sampler_name, sampler = sampler_name,
) )
yield from samples_ddim yield from samples_ddim
def _img2img(init_latent, t_enc, batch_size, opt_scale, c, uc, opt_ddim_steps, opt_ddim_eta, opt_seed, img_callback, mask): def _img2img(init_latent, t_enc, batch_size, opt_scale, c, uc, opt_ddim_steps, opt_ddim_eta, opt_seed, img_callback, mask):
# encode (scaled latent) # encode (scaled latent)
z_enc = model.stochastic_encode( z_enc = thread_data.model.stochastic_encode(
init_latent, init_latent,
torch.tensor([t_enc] * batch_size).to(device), torch.tensor([t_enc] * batch_size).to(thread_data.device),
opt_seed, opt_seed,
opt_ddim_eta, opt_ddim_eta,
opt_ddim_steps, opt_ddim_steps,
@ -587,7 +663,7 @@ def _img2img(init_latent, t_enc, batch_size, opt_scale, c, uc, opt_ddim_steps, o
x_T = None if mask is None else init_latent x_T = None if mask is None else init_latent
# decode it # decode it
samples_ddim = model.sample( samples_ddim = thread_data.model.sample(
t_enc, t_enc,
c, c,
z_enc, z_enc,
@ -602,16 +678,19 @@ def _img2img(init_latent, t_enc, batch_size, opt_scale, c, uc, opt_ddim_steps, o
yield from samples_ddim yield from samples_ddim
def move_fs_to_cpu(): def move_fs_to_cpu():
if device != "cpu": if thread_data.device != "cpu":
mem = torch.cuda.memory_allocated() / 1e6 mem = torch.cuda.memory_allocated(thread_data.device) / 1e6
modelFS.to("cpu") print('Device:', thread_data.device, 'FS_Model, Memory transfer starting. Memory Used:', round(mem, 2), 'Mo')
while torch.cuda.memory_allocated() / 1e6 >= mem: thread_data.modelFS.to("cpu")
while torch.cuda.memory_allocated(thread_data.device) / 1e6 >= mem and mem > 0:
print('Device:', thread_data.device, 'Waiting for memory transfer. Memory Used:', round(mem, 2), 'MB')
time.sleep(1) time.sleep(1)
print('Transferred', round(mem - torch.cuda.memory_allocated(thread_data.device) / 1e6, 2), 'MB')
def gc(): def gc():
if device == 'cpu': #gc.collect()
if thread_data.device == 'cpu':
return return
torch.cuda.empty_cache() torch.cuda.empty_cache()
torch.cuda.ipc_collect() torch.cuda.ipc_collect()
@ -621,7 +700,6 @@ def chunk(it, size):
it = iter(it) it = iter(it)
return iter(lambda: tuple(islice(it, size)), ()) return iter(lambda: tuple(islice(it, size)), ())
def load_model_from_config(ckpt, verbose=False): def load_model_from_config(ckpt, verbose=False):
print(f"Loading model from {ckpt}") print(f"Loading model from {ckpt}")
pl_sd = torch.load(ckpt, map_location="cpu") pl_sd = torch.load(ckpt, map_location="cpu")


@ -9,6 +9,10 @@ from typing import Any, Generator, Hashable, Optional, Union
from pydantic import BaseModel from pydantic import BaseModel
from sd_internal import Request, Response from sd_internal import Request, Response
ERR_LOCK_FAILED = ' failed to acquire lock within timeout.'
LOCK_TIMEOUT = 15 # Maximum locking time in seconds before failing a task.
# It's better to get an exception than a deadlock... ALWAYS use timeout in critical paths.
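The lock discipline these constants back, shown as a standalone sketch (guarded_section is an illustrative name, not a function in this file):

    import threading

    lock = threading.Lock()

    def guarded_section():
        if not lock.acquire(blocking=True, timeout=LOCK_TIMEOUT):
            raise Exception('guarded_section' + ERR_LOCK_FAILED)
        try:
            pass  # critical section; an exception here still releases the lock
        finally:
            lock.release()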
class SymbolClass(type): # Print nicely formatted Symbol names. class SymbolClass(type): # Print nicely formatted Symbol names.
def __repr__(self): return self.__qualname__ def __repr__(self): return self.__qualname__
def __str__(self): return self.__name__ def __str__(self): return self.__name__
@ -66,17 +70,30 @@ class ImageRequest(BaseModel):
stream_progress_updates: bool = False stream_progress_updates: bool = False
stream_image_progress: bool = False stream_image_progress: bool = False
class FilterRequest(BaseModel):
session_id: str = "session"
model: str = None
name: str = ""
init_image: str = None # base64
width: int = 512
height: int = 512
save_to_disk_path: str = None
turbo: bool = True
use_cpu: bool = False
use_full_precision: bool = False
output_format: str = "jpeg" # or "png"
# Temporary cache to allow to query tasks results for a short time after they are completed. # Temporary cache to allow to query tasks results for a short time after they are completed.
class TaskCache(): class TaskCache():
def __init__(self): def __init__(self):
self._base = dict() self._base = dict()
self._lock: threading.Lock = threading.RLock() self._lock: threading.Lock = threading.Lock()
def _get_ttl_time(self, ttl: int) -> int: def _get_ttl_time(self, ttl: int) -> int:
return int(time.time()) + ttl return int(time.time()) + ttl
def _is_expired(self, timestamp: int) -> bool: def _is_expired(self, timestamp: int) -> bool:
return int(time.time()) >= timestamp return int(time.time()) >= timestamp
def clean(self) -> None: def clean(self) -> None:
if not self._lock.acquire(blocking=True, timeout=10): raise Exception('TaskCache.clean failed to acquire lock within timeout.') if not self._lock.acquire(blocking=True, timeout=LOCK_TIMEOUT): raise Exception('TaskCache.clean' + ERR_LOCK_FAILED)
try: try:
# Create a list of expired keys to delete # Create a list of expired keys to delete
to_delete = [] to_delete = []
@ -91,11 +108,11 @@ class TaskCache():
finally: finally:
self._lock.release() self._lock.release()
def clear(self) -> None: def clear(self) -> None:
if not self._lock.acquire(blocking=True, timeout=10): raise Exception('TaskCache.clear failed to acquire lock within timeout.') if not self._lock.acquire(blocking=True, timeout=LOCK_TIMEOUT): raise Exception('TaskCache.clear' + ERR_LOCK_FAILED)
try: self._base.clear() try: self._base.clear()
finally: self._lock.release() finally: self._lock.release()
def delete(self, key: Hashable) -> bool: def delete(self, key: Hashable) -> bool:
if not self._lock.acquire(blocking=True, timeout=10): raise Exception('TaskCache.delete failed to acquire lock within timeout.') if not self._lock.acquire(blocking=True, timeout=LOCK_TIMEOUT): raise Exception('TaskCache.delete' + ERR_LOCK_FAILED)
try: try:
if key not in self._base: if key not in self._base:
return False return False
@ -104,7 +121,7 @@ class TaskCache():
finally: finally:
self._lock.release() self._lock.release()
def keep(self, key: Hashable, ttl: int) -> bool: def keep(self, key: Hashable, ttl: int) -> bool:
if not self._lock.acquire(blocking=True, timeout=10): raise Exception('TaskCache.keep failed to acquire lock within timeout.') if not self._lock.acquire(blocking=True, timeout=LOCK_TIMEOUT): raise Exception('TaskCache.keep' + ERR_LOCK_FAILED)
try: try:
if key in self._base: if key in self._base:
_, value = self._base.get(key) _, value = self._base.get(key)
@ -114,7 +131,7 @@ class TaskCache():
finally: finally:
self._lock.release() self._lock.release()
def put(self, key: Hashable, value: Any, ttl: int) -> bool: def put(self, key: Hashable, value: Any, ttl: int) -> bool:
if not self._lock.acquire(blocking=True, timeout=10): raise Exception('TaskCache.put failed to acquire lock within timeout.') if not self._lock.acquire(blocking=True, timeout=LOCK_TIMEOUT): raise Exception('TaskCache.put' + ERR_LOCK_FAILED)
try: try:
self._base[key] = ( self._base[key] = (
self._get_ttl_time(ttl), value self._get_ttl_time(ttl), value
@ -128,21 +145,23 @@ class TaskCache():
finally: finally:
self._lock.release() self._lock.release()
def tryGet(self, key: Hashable) -> Any: def tryGet(self, key: Hashable) -> Any:
if not self._lock.acquire(blocking=True, timeout=10): raise Exception('TaskCache.tryGet failed to acquire lock within timeout.') if not self._lock.acquire(blocking=True, timeout=LOCK_TIMEOUT): raise Exception('TaskCache.tryGet' + ERR_LOCK_FAILED)
try: try:
ttl, value = self._base.get(key, (None, None)) ttl, value = self._base.get(key, (None, None))
if ttl is not None and self._is_expired(ttl): if ttl is not None and self._is_expired(ttl):
print(f'Session {key} expired. Discarding data.') print(f'Session {key} expired. Discarding data.')
self.delete(key) del self._base[key]
return None return None
return value return value
finally: finally:
self._lock.release() self._lock.release()
manager_lock = threading.Lock()
render_threads = []
current_state = ServerStates.Init current_state = ServerStates.Init
current_state_error:Exception = None current_state_error:Exception = None
current_model_path = None current_model_path = None
tasks_queue = queue.Queue() tasks_queue = []
task_cache = TaskCache() task_cache = TaskCache()
default_model_to_load = None default_model_to_load = None
@ -155,7 +174,8 @@ def preload_model(file_path=None):
current_state = ServerStates.LoadingModel current_state = ServerStates.LoadingModel
try: try:
from . import runtime from . import runtime
runtime.load_model_ckpt(ckpt_to_use=file_path) runtime.thread_data.ckpt_file = file_path
runtime.load_model_ckpt()
current_model_path = file_path current_model_path = file_path
current_state_error = None current_state_error = None
current_state = ServerStates.Online current_state = ServerStates.Online
@ -165,43 +185,62 @@ def preload_model(file_path=None):
current_state = ServerStates.Unavailable current_state = ServerStates.Unavailable
print(traceback.format_exc()) print(traceback.format_exc())
def thread_render(): def thread_render(device):
global current_state, current_state_error, current_model_path global current_state, current_state_error, current_model_path
from . import runtime from . import runtime
current_state = ServerStates.Online try:
runtime.device_init(device)
except:
print(traceback.format_exc())
return
preload_model() preload_model()
current_state = ServerStates.Online
while True: while True:
task_cache.clean() task_cache.clean()
if isinstance(current_state_error, SystemExit): if isinstance(current_state_error, SystemExit):
current_state = ServerStates.Unavailable current_state = ServerStates.Unavailable
return return
task = None task = None
try: if not manager_lock.acquire(blocking=True, timeout=LOCK_TIMEOUT):
task = tasks_queue.get(timeout=1) print('Render thread on device', runtime.thread_data.device, 'failed to acquire manager lock.')
except queue.Empty as e: time.sleep(1)
if isinstance(current_state_error, SystemExit): continue
current_state = ServerStates.Unavailable if len(tasks_queue) <= 0:
return manager_lock.release()
else: continue time.sleep(1)
continue
try: # Select a render task.
for queued_task in tasks_queue:
if queued_task.request.use_cpu and runtime.thread_data.device != 'cpu':
continue # Cuda Tasks
if not queued_task.request.use_cpu and runtime.thread_data.device == 'cpu':
continue # CPU Tasks
if queued_task.request.use_face_correction and not runtime.is_first_cuda_device(runtime.thread_data.device):
continue #TODO Remove when fixed - A bug with GFPGANer and facexlib needs to be fixed before use on other devices.
task = queued_task
break
if task is not None:
del tasks_queue[tasks_queue.index(task)]
finally:
manager_lock.release()
if task is None:
time.sleep(1)
continue
#if current_model_path != task.request.use_stable_diffusion_model: #if current_model_path != task.request.use_stable_diffusion_model:
# preload_model(task.request.use_stable_diffusion_model) # preload_model(task.request.use_stable_diffusion_model)
if current_state_error: if current_state_error:
task.error = current_state_error task.error = current_state_error
continue continue
print(f'Session {task.request.session_id} starting task {id(task)}') print(f'Session {task.request.session_id} starting task {id(task)}')
if not task.lock.acquire(blocking=False): raise Exception('Got locked task from queue.')
try: try:
task.lock.acquire(blocking=False) # Open data generator.
res = runtime.mk_img(task.request) res = runtime.mk_img(task.request)
if current_model_path == task.request.use_stable_diffusion_model: if current_model_path == task.request.use_stable_diffusion_model:
current_state = ServerStates.Rendering current_state = ServerStates.Rendering
else: else:
current_state = ServerStates.LoadingModel current_state = ServerStates.LoadingModel
except Exception as e: # Start reading from generator.
task.error = e
task.lock.release()
tasks_queue.task_done()
print(traceback.format_exc())
continue
dataQueue = None dataQueue = None
if task.request.stream_progress_updates: if task.request.stream_progress_updates:
dataQueue = task.buffer_queue dataQueue = task.buffer_queue
@ -224,13 +263,18 @@ def thread_render():
for out_obj in result['output']: for out_obj in result['output']:
if 'path' in out_obj: if 'path' in out_obj:
img_id = out_obj['path'][out_obj['path'].rindex('/') + 1:] img_id = out_obj['path'][out_obj['path'].rindex('/') + 1:]
task.temp_images[int(img_id)] = runtime.temp_images[out_obj['path'][11:]] task.temp_images[int(img_id)] = runtime.thread_data.temp_images[out_obj['path'][11:]]
elif 'data' in out_obj: elif 'data' in out_obj:
task.temp_images[result['output'].index(out_obj)] = out_obj['data'] task.temp_images[result['output'].index(out_obj)] = out_obj['data']
# Before looping back to the generator, mark cache as still alive.
task_cache.keep(task.request.session_id, TASK_TTL) task_cache.keep(task.request.session_id, TASK_TTL)
except Exception as e:
task.error = e
print(traceback.format_exc())
continue
finally:
# Task completed # Task completed
task.lock.release() task.lock.release()
tasks_queue.task_done()
task_cache.keep(task.request.session_id, TASK_TTL) task_cache.keep(task.request.session_id, TASK_TTL)
if isinstance(task.error, StopAsyncIteration): if isinstance(task.error, StopAsyncIteration):
print(f'Session {task.request.session_id} task {id(task)} cancelled!') print(f'Session {task.request.session_id} task {id(task)} cancelled!')
@ -240,19 +284,37 @@ def thread_render():
print(f'Session {task.request.session_id} task {id(task)} completed.') print(f'Session {task.request.session_id} task {id(task)} completed.')
current_state = ServerStates.Online current_state = ServerStates.Online
render_thread = threading.Thread(target=thread_render) def is_alive(name=None):
if not manager_lock.acquire(blocking=True, timeout=LOCK_TIMEOUT): raise Exception('is_alive' + ERR_LOCK_FAILED)
nbr_alive = 0
try:
for rthread in render_threads:
if name and not rthread.name.endswith(name):
continue
if rthread.is_alive():
nbr_alive += 1
return nbr_alive
finally:
manager_lock.release()
def start_render_thread(): def start_render_thread(device='auto'):
# Start Rendering Thread if not manager_lock.acquire(blocking=True, timeout=LOCK_TIMEOUT): raise Exception('start_render_threads' + ERR_LOCK_FAILED)
render_thread.daemon = True print('Start new Rendering Thread on device', device)
render_thread.start() try:
rthread = threading.Thread(target=thread_render, kwargs={'device': device})
rthread.daemon = True
rthread.name = 'Runner/' + device
rthread.start()
render_threads.append(rthread)
finally:
manager_lock.release()
def shutdown_event(): # Signal render thread to close on shutdown def shutdown_event(): # Signal render thread to close on shutdown
global current_state_error global current_state_error
current_state_error = SystemExit('Application shutting down.') current_state_error = SystemExit('Application shutting down.')
def render(req : ImageRequest): def render(req : ImageRequest):
if not render_thread.is_alive(): # Render thread is dead if not is_alive(): # Render thread is dead
raise ChildProcessError('Rendering thread has died.') raise ChildProcessError('Rendering thread has died.')
# Alive, check if task in cache # Alive, check if task in cache
task = task_cache.tryGet(req.session_id) task = task_cache.tryGet(req.session_id)
@ -293,6 +355,12 @@ def render(req : ImageRequest):
new_task = RenderTask(r) new_task = RenderTask(r)
if task_cache.put(r.session_id, new_task, TASK_TTL): if task_cache.put(r.session_id, new_task, TASK_TTL):
tasks_queue.put(new_task, block=True, timeout=30) # Use twice the normal timeout for adding user requests.
# Tries to force task_cache.put to fail before tasks_queue.put would.
if manager_lock.acquire(blocking=True, timeout=LOCK_TIMEOUT * 2):
try:
tasks_queue.append(new_task)
return new_task return new_task
finally:
manager_lock.release()
raise RuntimeError('Failed to add task to cache.') raise RuntimeError('Failed to add task to cache.')


@ -15,14 +15,24 @@ MODELS_DIR = os.path.abspath(os.path.join(SD_DIR, '..', 'models'))
OUTPUT_DIRNAME = "Stable Diffusion UI" # in the user's home folder OUTPUT_DIRNAME = "Stable Diffusion UI" # in the user's home folder
TASK_TTL = 15 * 60 # Discard last session's task timeout TASK_TTL = 15 * 60 # Discard last session's task timeout
APP_CONFIG_DEFAULTS = {
# auto: selects the cuda device with the most free memory, cuda: use the currently active cuda device.
'render_devices': ['auto'], # ['cuda'] or ['CPU', 'GPU:0', 'GPU:1', ...] or ['cpu']
'update_branch': 'main',
}
APP_CONFIG_DEFAULT_MODELS = [
# needed to support the legacy installations
'custom-model', # Check if user has a custom model, use it first.
'sd-v1-4', # Default fallback.
]
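For illustration, a hypothetical override of the defaults above that pins rendering to two specific GPUs; it could be written with the setConfig helper defined later in this file:

    setConfig({
        'render_devices': ['GPU:0', 'GPU:1'],  # example values; 'auto' stays the default
        'update_branch': 'main',
    })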
from fastapi import FastAPI, HTTPException from fastapi import FastAPI, HTTPException
from fastapi.staticfiles import StaticFiles from fastapi.staticfiles import StaticFiles
from starlette.responses import FileResponse, JSONResponse, StreamingResponse from starlette.responses import FileResponse, JSONResponse, StreamingResponse
from pydantic import BaseModel from pydantic import BaseModel
import logging import logging
import queue, threading, time #import queue, threading, time
from typing import Any, Generator, Hashable, Optional, Union from typing import Any, Generator, Hashable, List, Optional, Union
from sd_internal import Request, Response, task_manager from sd_internal import Request, Response, task_manager
@ -37,52 +47,173 @@ ACCESS_LOG_SUPPRESS_PATH_PREFIXES = ['/ping', '/image', '/modifier-thumbnails']
NOCACHE_HEADERS={"Cache-Control": "no-cache, no-store, must-revalidate", "Pragma": "no-cache", "Expires": "0"} NOCACHE_HEADERS={"Cache-Control": "no-cache, no-store, must-revalidate", "Pragma": "no-cache", "Expires": "0"}
app.mount('/media', StaticFiles(directory=os.path.join(SD_UI_DIR, 'media/')), name="media") app.mount('/media', StaticFiles(directory=os.path.join(SD_UI_DIR, 'media/')), name="media")
class SetAppConfigRequest(BaseModel): config_cached = None
update_branch: str = "main" config_last_mod_time = 0
def getConfig(default_val=APP_CONFIG_DEFAULTS):
global config_cached, config_last_mod_time
try:
config_json_path = os.path.join(CONFIG_DIR, 'config.json')
if not os.path.exists(config_json_path):
return default_val
if config_last_mod_time > 0 and config_cached is not None:
# Don't read if file was not modified
mtime = os.path.getmtime(config_json_path)
if mtime <= config_last_mod_time:
return config_cached
with open(config_json_path, 'r') as f:
config_cached = json.load(f)
config_last_mod_time = os.path.getmtime(config_json_path)
return config_cached
except Exception as e:
print(str(e))
print(traceback.format_exc())
return default_val
# needs to support the legacy installations def setConfig(config):
def get_initial_model_to_load(): try: # config.json
custom_weight_path = os.path.join(SD_DIR, 'custom-model.ckpt') config_json_path = os.path.join(CONFIG_DIR, 'config.json')
ckpt_to_use = "sd-v1-4" if not os.path.exists(custom_weight_path) else "custom-model" with open(config_json_path, 'w') as f:
return json.dump(config, f)
except:
print(traceback.format_exc())
ckpt_to_use = os.path.join(SD_DIR, ckpt_to_use) if 'render_devices' in config:
gpu_devices = list(filter(lambda dev: dev.startswith('GPU:'), config['render_devices']))
else:
gpu_devices = []
try: # config.bat
config_bat = [
f"@set update_branch={config['update_branch']}"
]
if len(gpu_devices) > 0:
config_sh.append(f"@set CUDA_VISIBLE_DEVICES={','.join(gpu_devices)}")
config_bat_path = os.path.join(CONFIG_DIR, 'config.bat')
with open(config_bat_path, 'w') as f:
f.write('\r\n'.join(config_bat))
except Exception as e:
print(traceback.format_exc())
try: # config.sh
config_sh = [
'#!/bin/bash',
f"export update_branch={config['update_branch']}"
]
if len(gpu_devices) > 0:
config_sh.append(f"CUDA_VISIBLE_DEVICES={','.join(gpu_devices)}")
config_sh_path = os.path.join(CONFIG_DIR, 'config.sh')
with open(config_sh_path, 'w') as f:
f.write('\n'.join(config_sh))
except Exception as e:
print(traceback.format_exc())
def resolve_model_to_use(model_name:str=None):
if not model_name: # When None try user configured model.
config = getConfig() config = getConfig()
if 'model' in config and 'stable-diffusion' in config['model']: if 'model' in config and 'stable-diffusion' in config['model']:
model_name = config['model']['stable-diffusion'] model_name = config['model']['stable-diffusion']
model_path = resolve_model_to_use(model_name) if model_name:
if os.path.exists(model_name + '.ckpt'):
# Direct Path to file
return model_name
# Check models directory
models_dir_path = os.path.join(MODELS_DIR, 'stable-diffusion', model_name)
if os.path.exists(models_dir_path + '.ckpt'):
return models_dir_path
# Default locations
if model_name in APP_CONFIG_DEFAULT_MODELS:
default_model_path = os.path.join(SD_DIR, model_name)
if os.path.exists(default_model_path + '.ckpt'):
return default_model_path
# Can't find requested model, check the default paths.
for default_model in APP_CONFIG_DEFAULT_MODELS:
default_model_path = os.path.join(SD_DIR, default_model + '.ckpt')
if os.path.exists(default_model_path):
print('Could not find the configured custom model at:', model_name, '. Using the default one:', default_model_path)
return default_model_path
raise Exception('No valid models found.')
if os.path.exists(model_path + '.ckpt'): class SetAppConfigRequest(BaseModel):
ckpt_to_use = model_path update_branch: str = None
render_devices: Union[List[str], List[int], str, int] = None
@app.post('/app_config')
async def setAppConfig(req : SetAppConfigRequest):
config = getConfig()
if req.update_branch:
config['update_branch'] = req.update_branch
if req.render_devices and hasattr(req.render_devices, "__len__"): # strings, array of strings or numbers.
render_devices = []
if isinstance(req.render_devices, str):
req.render_devices = req.render_devices.split(',')
if isinstance(req.render_devices, list):
for gpu in req.render_devices:
if isinstance(gpu, int):
render_devices.append('GPU:' + str(gpu))
else: else:
print('Could not find the configured custom model at:', model_path + '.ckpt', '. Using the default one:', ckpt_to_use + '.ckpt') render_devices.append(gpu)
return ckpt_to_use if isinstance(req.render_devices, int):
render_devices.append('GPU:' + str(req.render_devices))
if len(render_devices) > 0:
config['render_devices'] = render_devices
try:
setConfig(config)
return JSONResponse({'status': 'OK'}, headers=NOCACHE_HEADERS)
except Exception as e:
print(traceback.format_exc())
return HTTPException(status_code=500, detail=str(e))
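A sketch of how a client might call this endpoint with the standard library (host and port match the webbrowser.open call at the end of this file; the device list is an example value):

    import json, urllib.request

    payload = {'update_branch': 'main', 'render_devices': ['GPU:0', 'GPU:1']}
    req = urllib.request.Request('http://localhost:9000/app_config',
                                 data=json.dumps(payload).encode(),
                                 headers={'Content-Type': 'application/json'},
                                 method='POST')
    print(urllib.request.urlopen(req).read())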
def resolve_model_to_use(model_name): def getModels():
if model_name in ('sd-v1-4', 'custom-model'): models = {
model_path = os.path.join(MODELS_DIR, 'stable-diffusion', model_name) 'active': {
'stable-diffusion': 'sd-v1-4',
},
'options': {
'stable-diffusion': ['sd-v1-4'],
},
}
legacy_model_path = os.path.join(SD_DIR, model_name) # custom models
if not os.path.exists(model_path + '.ckpt') and os.path.exists(legacy_model_path + '.ckpt'): sd_models_dir = os.path.join(MODELS_DIR, 'stable-diffusion')
model_path = legacy_model_path for file in os.listdir(sd_models_dir):
if file.endswith('.ckpt'):
model_name = os.path.splitext(file)[0]
models['options']['stable-diffusion'].append(model_name)
# legacy
custom_weight_path = os.path.join(SD_DIR, 'custom-model.ckpt')
if os.path.exists(custom_weight_path):
models['active']['stable-diffusion'] = 'custom-model'
models['options']['stable-diffusion'].append('custom-model')
config = getConfig()
if 'model' in config and 'stable-diffusion' in config['model']:
models['active']['stable-diffusion'] = config['model']['stable-diffusion']
return models
@app.get('/get/{key:path}')
def read_web_data(key:str=None):
if not key: # /get without parameters, stable-diffusion easter egg.
return HTTPException(status_code=418, detail="StableDiffusion is drawing a teapot!") # HTTP418 I'm a teapot
elif key == 'app_config':
config = getConfig(default_val=None)
if config is None:
return HTTPException(status_code=500, detail="Config file is missing or unreadable")
return JSONResponse(config, headers=NOCACHE_HEADERS)
elif key == 'models':
return JSONResponse(getModels(), headers=NOCACHE_HEADERS)
elif key == 'modifiers': return FileResponse(os.path.join(SD_UI_DIR, 'modifiers.json'), headers=NOCACHE_HEADERS)
elif key == 'output_dir': return JSONResponse({ 'output_dir': outpath }, headers=NOCACHE_HEADERS)
else: else:
model_path = os.path.join(MODELS_DIR, 'stable-diffusion', model_name) return HTTPException(status_code=404, detail=f'Request for unknown {key}') # HTTP404 Not Found
return model_path
@app.on_event("shutdown")
def shutdown_event(): # Signal render thread to close on shutdown
task_manager.current_state_error = SystemExit('Application shutting down.')
@app.get('/')
def read_root():
return FileResponse(os.path.join(SD_UI_DIR, 'index.html'), headers=NOCACHE_HEADERS)
@app.get('/ping') # Get server and optionally session status. @app.get('/ping') # Get server and optionally session status.
def ping(session_id:str=None): def ping(session_id:str=None):
if not task_manager.render_thread.is_alive(): # Render thread is dead. if task_manager.is_alive() <= 0: # Check that render threads are alive.
if task_manager.current_state_error: return HTTPException(status_code=500, detail=str(current_state_error)) if task_manager.current_state_error: return HTTPException(status_code=500, detail=str(task_manager.current_state_error))
return HTTPException(status_code=500, detail='Render thread is dead.') return HTTPException(status_code=500, detail='Render thread is dead.')
if task_manager.current_state_error and not isinstance(task_manager.current_state_error, StopAsyncIteration): return HTTPException(status_code=500, detail=str(current_state_error)) if task_manager.current_state_error and not isinstance(task_manager.current_state_error, StopAsyncIteration): return HTTPException(status_code=500, detail=str(task_manager.current_state_error))
# Alive # Alive
response = {'status': str(task_manager.current_state)} response = {'status': str(task_manager.current_state)}
if session_id: if session_id:
@ -119,7 +250,7 @@ def render(req : task_manager.ImageRequest):
new_task = task_manager.render(req) new_task = task_manager.render(req)
response = { response = {
'status': str(task_manager.current_state), 'status': str(task_manager.current_state),
'queue': task_manager.tasks_queue.qsize(), 'queue': len(task_manager.tasks_queue),
'stream': f'/image/stream/{req.session_id}/{id(new_task)}', 'stream': f'/image/stream/{req.session_id}/{id(new_task)}',
'task': id(new_task) 'task': id(new_task)
} }
@ -172,100 +303,13 @@ def get_image(session_id, img_id):
except KeyError as e: except KeyError as e:
return HTTPException(status_code=500, detail=str(e)) return HTTPException(status_code=500, detail=str(e))
@app.post('/app_config') @app.get('/')
async def setAppConfig(req : SetAppConfigRequest): def read_root():
try: return FileResponse(os.path.join(SD_UI_DIR, 'index.html'), headers=NOCACHE_HEADERS)
config = {
'update_branch': req.update_branch
}
config_json_str = json.dumps(config) @app.on_event("shutdown")
config_bat_str = f'@set update_branch={req.update_branch}' def shutdown_event(): # Signal render thread to close on shutdown
config_sh_str = f'export update_branch={req.update_branch}' task_manager.current_state_error = SystemExit('Application shutting down.')
config_json_path = os.path.join(CONFIG_DIR, 'config.json')
config_bat_path = os.path.join(CONFIG_DIR, 'config.bat')
config_sh_path = os.path.join(CONFIG_DIR, 'config.sh')
with open(config_json_path, 'w') as f:
f.write(config_json_str)
with open(config_bat_path, 'w') as f:
f.write(config_bat_str)
with open(config_sh_path, 'w') as f:
f.write(config_sh_str)
return {'OK'}
except Exception as e:
print(traceback.format_exc())
return HTTPException(status_code=500, detail=str(e))
def getConfig(default_val={}):
try:
config_json_path = os.path.join(CONFIG_DIR, 'config.json')
if not os.path.exists(config_json_path):
return default_val
with open(config_json_path, 'r') as f:
return json.load(f)
except Exception as e:
print(str(e))
print(traceback.format_exc())
return default_val
def setConfig(config):
try:
config_json_path = os.path.join(CONFIG_DIR, 'config.json')
with open(config_json_path, 'w') as f:
return json.dump(config, f)
except:
print(str(e))
print(traceback.format_exc())
def getModels():
models = {
'active': {
'stable-diffusion': 'sd-v1-4',
},
'options': {
'stable-diffusion': ['sd-v1-4'],
},
}
# custom models
sd_models_dir = os.path.join(MODELS_DIR, 'stable-diffusion')
for file in os.listdir(sd_models_dir):
if file.endswith('.ckpt'):
model_name = os.path.splitext(file)[0]
models['options']['stable-diffusion'].append(model_name)
# legacy
custom_weight_path = os.path.join(SD_DIR, 'custom-model.ckpt')
if os.path.exists(custom_weight_path):
models['active']['stable-diffusion'] = 'custom-model'
models['options']['stable-diffusion'].append('custom-model')
config = getConfig()
if 'model' in config and 'stable-diffusion' in config['model']:
models['active']['stable-diffusion'] = config['model']['stable-diffusion']
return models
@app.get('/get/{key:path}')
def read_web_data(key:str=None):
if not key: # /get without parameters, stable-diffusion easter egg.
return HTTPException(status_code=418, detail="StableDiffusion is drawing a teapot!") # HTTP418 I'm a teapot
elif key == 'app_config':
config = getConfig(default_val=None)
if config is None:
return HTTPException(status_code=500, detail="Config file is missing or unreadable")
return JSONResponse(config, headers=NOCACHE_HEADERS)
elif key == 'models':
return JSONResponse(getModels(), headers=NOCACHE_HEADERS)
elif key == 'modifiers': return FileResponse(os.path.join(SD_UI_DIR, 'modifiers.json'), headers=NOCACHE_HEADERS)
elif key == 'output_dir': return JSONResponse({ 'output_dir': outpath }, headers=NOCACHE_HEADERS)
else:
return HTTPException(status_code=404, detail=f'Request for unknown {key}') # HTTP404 Not Found
# don't log certain requests # don't log certain requests
class LogSuppressFilter(logging.Filter): class LogSuppressFilter(logging.Filter):
@ -277,8 +321,26 @@ class LogSuppressFilter(logging.Filter):
return True return True
logging.getLogger('uvicorn.access').addFilter(LogSuppressFilter()) logging.getLogger('uvicorn.access').addFilter(LogSuppressFilter())
task_manager.default_model_to_load = get_initial_model_to_load() config = getConfig()
task_manager.start_render_thread() # Start the task_manager
task_manager.default_model_to_load = resolve_model_to_use()
if 'render_devices' in config: # Start a new thread for each device.
if isinstance(config['render_devices'], str):
config['render_devices'] = config['render_devices'].split(',')
if not isinstance(config['render_devices'], list):
raise Exception('Invalid render_devices value in config.')
for device in config['render_devices']:
task_manager.start_render_thread(device)
allow_cpu = False
if task_manager.is_alive() <= 0: # No running devices, apply defaults.
# Select the best GPU device using free memory, if more than one device is available.
task_manager.start_render_thread('auto')
allow_cpu = True
# Allow CPU to be used for renders if not already enabled in current config.
if task_manager.is_alive('cpu') <= 0 and allow_cpu:
task_manager.start_render_thread('cpu')
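As a hedged illustration, the block above is equivalent to the following for a config that lists two GPUs explicitly; each thread runs device_init and preload_model on its own device:

    for device in ['GPU:0', 'GPU:1']:  # hypothetical render_devices value
        task_manager.start_render_thread(device)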
# start the browser ui # start the browser ui
import webbrowser; webbrowser.open('http://localhost:9000') import webbrowser; webbrowser.open('http://localhost:9000')