First draft for Multi-GPU support

This commit is contained in:
Marc-Andre Ferland 2022-10-16 21:41:39 -04:00
parent 2edc06c662
commit 7c72608e1c
3 changed files with 584 additions and 376 deletions

View File

@ -35,63 +35,144 @@ import base64
from io import BytesIO
#from colorama import Fore
# local
stop_processing = False
temp_images = {}
from threading import local as LocalThreadVars
thread_data = LocalThreadVars()
ckpt_file = None
gfpgan_file = None
real_esrgan_file = None
model = None
modelCS = None
modelFS = None
model_gfpgan = None
model_real_esrgan = None
model_is_half = False
model_fs_is_half = False
device = None
unet_bs = 1
precision = 'autocast'
sampler_plms = None
sampler_ddim = None
has_valid_gpu = False
force_full_precision = False
try:
gpu = torch.cuda.current_device()
gpu_name = torch.cuda.get_device_name(gpu)
print('GPU detected: ', gpu_name)
force_full_precision = ('nvidia' in gpu_name.lower() or 'geforce' in gpu_name.lower()) and (' 1660' in gpu_name or ' 1650' in gpu_name) # otherwise these NVIDIA cards create green images
if force_full_precision:
print('forcing full precision on NVIDIA 16xx cards, to avoid green images. GPU detected: ', gpu_name)
mem_free, mem_total = torch.cuda.mem_get_info(gpu)
def device_would_fail(device):
if device == 'cpu': return None
# Returns None when no issues found, otherwise returns the detected error str.
# Memory check
try:
mem_free, mem_total = torch.cuda.mem_get_info(device)
mem_total /= float(10**9)
if mem_total < 3.0:
print("GPUs with less than 3 GB of VRAM are not compatible with Stable Diffusion")
raise Exception()
return 'GPUs with less than 3 GB of VRAM are not compatible with Stable Diffusion'
except RuntimeError as e:
return str(e) # Return cuda errors from mem_get_info as strings
return None
has_valid_gpu = True
except:
def device_select(device):
if device == 'cpu': return True
if not torch.cuda.is_available(): return False
failure_msg = device_would_fail(device)
if failure_msg:
if 'invalid device' in failure_msg:
raise NameError(f'GPU "{device}" could not be found. Remove this device from config.render_devices or use one of "auto" or "cuda".')
print(failure_msg)
return False
device_name = torch.cuda.get_device_name(device)
# otherwise these NVIDIA cards create green images
thread_data.force_full_precision = ('nvidia' in device_name.lower() or 'geforce' in device_name.lower()) and (' 1660' in device_name or ' 1650' in device_name)
if thread_data.force_full_precision:
print('forcing full precision on NVIDIA 16xx cards, to avoid green images. GPU detected: ', gpu_name)
thread_data.device = device
thread_data.has_valid_gpu = True
return True
def device_init(device_selection=None):
# Thread bound properties
thread_data.stop_processing = False
thread_data.temp_images = {}
thread_data.ckpt_file = None
thread_data.gfpgan_file = None
thread_data.real_esrgan_file = None
thread_data.model = None
thread_data.modelCS = None
thread_data.modelFS = None
thread_data.model_gfpgan = None
thread_data.model_real_esrgan = None
thread_data.model_is_half = False
thread_data.model_fs_is_half = False
thread_data.device = None
thread_data.unet_bs = 1
thread_data.precision = 'autocast'
thread_data.sampler_plms = None
thread_data.sampler_ddim = None
thread_data.turbo = False
thread_data.has_valid_gpu = False
thread_data.force_full_precision = False
if device_selection.lower() == 'cpu':
print('CPU requested, skipping gpu init.')
thread_data.device = 'cpu'
return
if not torch.cuda.is_available():
print('WARNING: torch.cuda is not available. Using the CPU, but this will be very slow!')
return
device_count = torch.cuda.device_count()
if device_count <= 1 and device_selection == 'auto':
device_selection = 'current' # Use 'auto' only when there is more than one compatible device found.
if device_selection == 'auto':
print('Autoselecting GPU. Using most free memory.')
max_mem_free = 0
best_device = None
for device in range(device_count):
mem_free, mem_total = torch.cuda.mem_get_info(device)
mem_free /= float(10**9)
mem_total /= float(10**9)
device_name = torch.cuda.get_device_name(device)
print(f'GPU:{device} detected: {device_name} - Memory: {round(mem_total - mem_free, 2)}Go / {round(mem_total, 2)}Go')
if max_mem_free < mem_free:
max_mem_free = mem_free
best_device = device
if best_device and device_select(device):
print(f'Setting GPU:{device} as active')
torch.cuda.device(device)
return
if isinstance(device_selection, str):
device_selection = device_selection.lower()
if device_selection.startswith('gpu:'):
device_selection = int(device_selection[4:])
if device_selection != 'cuda' and device_selection != 'current':
if device_select(device_selection):
if isinstance(device_selection, int):
print(f'Setting GPU:{device_selection} as active')
else:
print(f'Setting {device_selection} as active')
torch.cuda.device(device_selection)
return
# By default use current device.
print('Checking current GPU...')
device = torch.cuda.current_device()
device_name = torch.cuda.get_device_name(device)
print(f'GPU:{device} detected: {device_name}')
if device_select(device):
return
print('WARNING: No compatible GPU found. Using the CPU, but this will be very slow!')
pass
thread_data.device = 'cpu'
def load_model_ckpt(ckpt_to_use, device_to_use='cuda', turbo=False, unet_bs_to_use=1, precision_to_use='autocast'):
global ckpt_file, model, modelCS, modelFS, model_is_half, device, unet_bs, precision, model_fs_is_half
def is_first_cuda_device(device):
if thread_data.device == 0 or thread_data.device == '0':
return True
if thread_data.device == 'cuda' or thread_data.device == 'cuda:0':
return True
if thread_data.device == torch.device(0):
return True
return False
device = device_to_use if has_valid_gpu else 'cpu'
precision = precision_to_use if not force_full_precision else 'full'
unet_bs = unet_bs_to_use
def load_model_ckpt():
if not thread_data.ckpt_file: raise ValueError(f'Thread ckpt_file is undefined.')
if not os.path.exists(thread_data.ckpt_file + '.ckpt'): raise FileNotFoundError(f'Cannot find {thread_data.ckpt_file}.ckpt')
if not thread_data.precision:
thread_data.precision = 'full' if thread_data.force_full_precision else 'autocast'
if not thread_data.unet_bs:
thread_data.unet_bs = 1
unload_model()
if device == 'cpu':
precision = 'full'
if thread_data.device == 'cpu':
thread_data.precision = 'full'
sd = load_model_from_config(f"{ckpt_to_use}.ckpt")
print('loading', thread_data.ckpt_file, 'to', thread_data.device, 'using precision', thread_data.precision)
sd = load_model_from_config(thread_data.ckpt_file + '.ckpt')
li, lo = [], []
for key, value in sd.items():
sp = key.split(".")
@ -114,88 +195,84 @@ def load_model_ckpt(ckpt_to_use, device_to_use='cuda', turbo=False, unet_bs_to_u
model = instantiate_from_config(config.modelUNet)
_, _ = model.load_state_dict(sd, strict=False)
model.eval()
model.cdevice = device
model.unet_bs = unet_bs
model.turbo = turbo
model.cdevice = torch.device(thread_data.device)
model.unet_bs = thread_data.unet_bs
model.turbo = thread_data.turbo
if thread_data.device != 'cpu':
model.to(thread_data.device)
thread_data.model = model
modelCS = instantiate_from_config(config.modelCondStage)
_, _ = modelCS.load_state_dict(sd, strict=False)
modelCS.eval()
modelCS.cond_stage_model.device = device
modelCS.cond_stage_model.device = torch.device(thread_data.device)
if thread_data.device != 'cpu':
modelCS.to(thread_data.device)
thread_data.modelCS = modelCS
modelFS = instantiate_from_config(config.modelFirstStage)
_, _ = modelFS.load_state_dict(sd, strict=False)
modelFS.eval()
if thread_data.device != 'cpu':
modelFS.to(thread_data.device)
thread_data.modelFS = modelFS
del sd
if device != "cpu" and precision == "autocast":
model.half()
modelCS.half()
modelFS.half()
model_is_half = True
model_fs_is_half = True
if thread_data.device != "cpu" and thread_data.precision == "autocast":
thread_data.model.half()
thread_data.modelCS.half()
thread_data.modelFS.half()
thread_data.model_is_half = True
thread_data.model_fs_is_half = True
else:
model_is_half = False
model_fs_is_half = False
thread_data.model_is_half = False
thread_data.model_fs_is_half = False
ckpt_file = ckpt_to_use
print('loaded ', ckpt_file, 'to', device, 'precision', precision)
print('loaded', thread_data.ckpt_file, 'as', model.device, '->', modelCS.cond_stage_model.device, '->', thread_data.modelFS.device, 'using precision', thread_data.precision)
def unload_model():
global model, modelCS, modelFS
if thread_data.model is not None:
print('Unloading models...')
del thread_data.model
del thread_data.modelCS
del thread_data.modelFS
thread_data.model = None
thread_data.modelCS = None
thread_data.modelFS = None
if model is not None:
del model
del modelCS
del modelFS
model = None
modelCS = None
modelFS = None
def load_model_gfpgan(gfpgan_to_use):
global gfpgan_file, model_gfpgan
if gfpgan_to_use is None:
def load_model_gfpgan():
if thread_data.gfpgan_file is None:
print('load_model_gfpgan called without setting gfpgan_file')
return
if thread_data.device != 'cpu' and not is_first_cuda_device(thread_data.device):
#TODO Remove when fixed - A bug with GFPGANer and facexlib needs to be fixed before use on other devices.
raise Exception(f'Current device {torch.device(thread_data.device)} is not {torch.device(0)}.')
model_path = thread_data.gfpgan_file + ".pth"
thread_data.model_gfpgan = GFPGANer(device=torch.device(thread_data.device), model_path=model_path, upscale=1, arch='clean', channel_multiplier=2, bg_upsampler=None)
print('loaded', thread_data.gfpgan_file, 'to', thread_data.model_gfpgan.device, 'precision', thread_data.precision)
gfpgan_file = gfpgan_to_use
model_path = gfpgan_to_use + ".pth"
if device == 'cpu':
model_gfpgan = GFPGANer(model_path=model_path, upscale=1, arch='clean', channel_multiplier=2, bg_upsampler=None, device=torch.device('cpu'))
else:
model_gfpgan = GFPGANer(model_path=model_path, upscale=1, arch='clean', channel_multiplier=2, bg_upsampler=None, device=torch.device('cuda'))
print('loaded ', gfpgan_to_use, 'to', device, 'precision', precision)
def load_model_real_esrgan(real_esrgan_to_use):
global real_esrgan_file, model_real_esrgan
if real_esrgan_to_use is None:
def load_model_real_esrgan():
if thread_data.real_esrgan_file is None:
print('load_model_real_esrgan called without setting real_esrgan_file')
return
real_esrgan_file = real_esrgan_to_use
model_path = real_esrgan_to_use + ".pth"
model_path = thread_data.real_esrgan_file + ".pth"
RealESRGAN_models = {
'RealESRGAN_x4plus': RRDBNet(num_in_ch=3, num_out_ch=3, num_feat=64, num_block=23, num_grow_ch=32, scale=4),
'RealESRGAN_x4plus_anime_6B': RRDBNet(num_in_ch=3, num_out_ch=3, num_feat=64, num_block=6, num_grow_ch=32, scale=4)
}
model_to_use = RealESRGAN_models[real_esrgan_to_use]
model_to_use = RealESRGAN_models[thread_data.real_esrgan_file]
if device == 'cpu':
model_real_esrgan = RealESRGANer(scale=2, model_path=model_path, model=model_to_use, pre_pad=0, half=False) # cpu does not support half
model_real_esrgan.device = torch.device('cpu')
model_real_esrgan.model.to('cpu')
if thread_data.device == 'cpu':
thread_data.model_real_esrgan = RealESRGANer(device=torch.device(thread_data.device), scale=2, model_path=model_path, model=model_to_use, pre_pad=0, half=False) # cpu does not support half
#thread_data.model_real_esrgan.device = torch.device(thread_data.device)
thread_data.model_real_esrgan.model.to('cpu')
else:
model_real_esrgan = RealESRGANer(scale=2, model_path=model_path, model=model_to_use, pre_pad=0, half=model_is_half)
thread_data.model_real_esrgan = RealESRGANer(device=torch.device(thread_data.device), scale=2, model_path=model_path, model=model_to_use, pre_pad=0, half=thread_data.model_is_half)
model_real_esrgan.model.name = real_esrgan_to_use
print('loaded ', real_esrgan_to_use, 'to', device, 'precision', precision)
thread_data.model_real_esrgan.model.name = thread_data.real_esrgan_file
print('loaded ', thread_data.real_esrgan_file, 'to', thread_data.model_real_esrgan.device, 'precision', thread_data.precision)
def get_base_path(disk_path, session_id, prompt, ext, suffix=None):
if disk_path is None: return None
@ -214,14 +291,22 @@ def get_base_path(disk_path, session_id, prompt, ext, suffix=None):
def apply_filters(filter_name, image_data):
print(f'Applying filter {filter_name}...')
if isinstance(image_data, torch.Tensor):
print(image_data)
image_data.to(thread_data.device)
gc()
if filter_name == 'gfpgan':
_, _, output = model_gfpgan.enhance(image_data[:,:,::-1], has_aligned=False, only_center_face=False, paste_back=True)
if thread_data.model_gfpgan is None: raise Exception('Model "gfpgan" not loaded.')
print('enhance with', thread_data.gfpgan_file, 'on', thread_data.model_gfpgan.device, 'precision', thread_data.precision)
_, _, output = thread_data.model_gfpgan.enhance(image_data[:,:,::-1], has_aligned=False, only_center_face=False, paste_back=True)
image_data = output[:,:,::-1]
if filter_name == 'real_esrgan':
output, _ = model_real_esrgan.enhance(image_data[:,:,::-1])
if thread_data.model_real_esrgan is None: raise Exception('Model "gfpgan" not loaded.')
print('enhance with', thread_data.real_esrgan_file, 'on', thread_data.model_real_esrgan.device, 'precision', thread_data.precision)
output, _ = thread_data.model_real_esrgan.enhance(image_data[:,:,::-1])
image_data = output[:,:,::-1]
return image_data
@ -234,12 +319,12 @@ def mk_img(req: Request):
gc()
if device != "cpu":
modelFS.to("cpu")
modelCS.to("cpu")
if thread_data.device != "cpu":
thread_data.modelFS.to("cpu")
thread_data.modelCS.to("cpu")
model.model1.to("cpu")
model.model2.to("cpu")
thread_data.model.model1.to("cpu")
thread_data.model.model2.to("cpu")
gc()
@ -249,66 +334,55 @@ def mk_img(req: Request):
})
def do_mk_img(req: Request):
global ckpt_file
global model, modelCS, modelFS, device
global model_gfpgan, model_real_esrgan
global stop_processing
stop_processing = False
thread_data.stop_processing = False
res = Response()
res.request = req
res.images = []
temp_images.clear()
thread_data.temp_images.clear()
# custom model support:
# the req.use_stable_diffusion_model needs to be a valid path
# to the ckpt file (without the extension).
if not os.path.exists(req.use_stable_diffusion_model + '.ckpt'): raise FileNotFoundError(f'Cannot find {req.use_stable_diffusion_model}.ckpt')
needs_model_reload = False
ckpt_to_use = ckpt_file
if ckpt_to_use != req.use_stable_diffusion_model:
ckpt_to_use = req.use_stable_diffusion_model
if thread_data.ckpt_file != req.use_stable_diffusion_model:
thread_data.ckpt_file = req.use_stable_diffusion_model
needs_model_reload = True
model.turbo = req.turbo
if req.use_cpu:
if device != 'cpu':
device = 'cpu'
if model_is_half:
load_model_ckpt(ckpt_to_use, device)
if thread_data.device != 'cpu':
thread_data.device = 'cpu'
if thread_data.model_is_half:
load_model_ckpt()
needs_model_reload = False
load_model_gfpgan(gfpgan_file)
load_model_real_esrgan(real_esrgan_file)
load_model_gfpgan()
load_model_real_esrgan()
else:
if has_valid_gpu:
prev_device = device
device = 'cuda'
if (precision == 'autocast' and (req.use_full_precision or not model_is_half)) or \
(precision == 'full' and not req.use_full_precision and not force_full_precision):
load_model_ckpt(ckpt_to_use, device, req.turbo, unet_bs, ('full' if req.use_full_precision else 'autocast'))
if thread_data.has_valid_gpu:
if (thread_data.precision == 'autocast' and (req.use_full_precision or not thread_data.model_is_half)) or \
(thread_data.precision == 'full' and not req.use_full_precision and not thread_data.force_full_precision):
thread_data.precision = 'full' if req.use_full_precision else 'autocast'
load_model_ckpt()
load_model_gfpgan()
load_model_real_esrgan()
needs_model_reload = False
if prev_device != device:
load_model_gfpgan(gfpgan_file)
load_model_real_esrgan(real_esrgan_file)
if needs_model_reload:
load_model_ckpt(ckpt_to_use, device, req.turbo, unet_bs, precision)
load_model_ckpt()
if req.use_face_correction != gfpgan_file:
load_model_gfpgan(req.use_face_correction)
if req.use_face_correction != thread_data.gfpgan_file:
thread_data.gfpgan_file = req.use_face_correction
load_model_gfpgan()
if req.use_upscale != thread_data.real_esrgan_file:
thread_data.real_esrgan_file = req.use_upscale
load_model_real_esrgan()
if req.use_upscale != real_esrgan_file:
load_model_real_esrgan(req.use_upscale)
model.cdevice = device
modelCS.cond_stage_model.device = device
if thread_data.turbo != req.turbo:
thread_data.turbo = req.turbo
thread_data.model.turbo = req.turbo
opt_prompt = req.prompt
opt_seed = req.seed
@ -318,9 +392,8 @@ def do_mk_img(req: Request):
opt_ddim_eta = 0.0
opt_init_img = req.init_image
print(req.to_string(), '\n device', device)
print('\n\n Using precision:', precision)
print(req.to_string(), '\n device', thread_data.device)
print('\n\n Using precision:', thread_data.precision)
seed_everything(opt_seed)
@ -329,7 +402,7 @@ def do_mk_img(req: Request):
assert prompt is not None
data = [batch_size * [prompt]]
if precision == "autocast" and device != "cpu":
if thread_data.precision == "autocast" and thread_data.device != "cpu":
precision_scope = autocast
else:
precision_scope = nullcontext
@ -345,22 +418,22 @@ def do_mk_img(req: Request):
handler = _img2img
init_image = load_img(req.init_image, req.width, req.height)
init_image = init_image.to(device)
init_image = init_image.to(thread_data.device)
if device != "cpu" and precision == "autocast":
if thread_data.device != "cpu" and thread_data.precision == "autocast":
init_image = init_image.half()
modelFS.to(device)
thread_data.modelFS.to(thread_data.device)
init_image = repeat(init_image, '1 ... -> b ...', b=batch_size)
init_latent = modelFS.get_first_stage_encoding(modelFS.encode_first_stage(init_image)) # move to latent space
init_latent = thread_data.modelFS.get_first_stage_encoding(thread_data.modelFS.encode_first_stage(init_image)) # move to latent space
if req.mask is not None:
mask = load_mask(req.mask, req.width, req.height, init_latent.shape[2], init_latent.shape[3], True).to(device)
mask = load_mask(req.mask, req.width, req.height, init_latent.shape[2], init_latent.shape[3], True).to(thread_data.device)
mask = mask[0][0].unsqueeze(0).repeat(4, 1, 1).unsqueeze(0)
mask = repeat(mask, '1 ... -> b ...', b=batch_size)
if device != "cpu" and precision == "autocast":
if thread_data.device != "cpu" and thread_data.precision == "autocast":
mask = mask.half()
move_fs_to_cpu()
@ -381,10 +454,10 @@ def do_mk_img(req: Request):
for prompts in tqdm(data, desc="data"):
with precision_scope("cuda"):
modelCS.to(device)
thread_data.modelCS.to(thread_data.device)
uc = None
if req.guidance_scale != 1.0:
uc = modelCS.get_learned_conditioning(batch_size * [req.negative_prompt])
uc = thread_data.modelCS.get_learned_conditioning(batch_size * [req.negative_prompt])
if isinstance(prompts, tuple):
prompts = list(prompts)
@ -397,11 +470,11 @@ def do_mk_img(req: Request):
weight = weights[i]
# if not skip_normalize:
weight = weight / totalWeight
c = torch.add(c, modelCS.get_learned_conditioning(subprompts[i]), alpha=weight)
c = torch.add(c, thread_data.modelCS.get_learned_conditioning(subprompts[i]), alpha=weight)
else:
c = modelCS.get_learned_conditioning(prompts)
c = thread_data.modelCS.get_learned_conditioning(prompts)
modelFS.to(device)
thread_data.modelFS.to(thread_data.device)
partial_x_samples = None
def img_callback(x_samples, i):
@ -417,7 +490,7 @@ def do_mk_img(req: Request):
partial_images = []
for i in range(batch_size):
x_samples_ddim = modelFS.decode_first_stage(x_samples[i].unsqueeze(0))
x_samples_ddim = thread_data.modelFS.decode_first_stage(x_samples[i].unsqueeze(0))
x_sample = torch.clamp((x_samples_ddim + 1.0) / 2.0, min=0.0, max=1.0)
x_sample = 255.0 * rearrange(x_sample[0].cpu().numpy(), "c h w -> h w c")
x_sample = x_sample.astype(np.uint8)
@ -429,18 +502,19 @@ def do_mk_img(req: Request):
del img, x_sample, x_samples_ddim
# don't delete x_samples, it is used in the code that called this callback
temp_images[str(req.session_id) + '/' + str(i)] = buf
thread_data.temp_images[str(req.session_id) + '/' + str(i)] = buf
partial_images.append({'path': f'/image/tmp/{req.session_id}/{i}'})
progress['output'] = partial_images
yield json.dumps(progress)
if stop_processing:
if thread_data.stop_processing:
raise UserInitiatedStop("User requested that we stop processing")
# run the handler
try:
print('Running handler...')
if handler == _txt2img:
x_samples = _txt2img(req.width, req.height, req.num_outputs, req.num_inference_steps, req.guidance_scale, None, opt_C, opt_f, opt_ddim_eta, c, uc, opt_seed, img_callback, mask, req.sampler)
else:
@ -458,7 +532,7 @@ def do_mk_img(req: Request):
print("saving images")
for i in range(batch_size):
x_samples_ddim = modelFS.decode_first_stage(x_samples[i].unsqueeze(0))
x_samples_ddim = thread_data.modelFS.decode_first_stage(x_samples[i].unsqueeze(0))
x_sample = torch.clamp((x_samples_ddim + 1.0) / 2.0, min=0.0, max=1.0)
x_sample = 255.0 * rearrange(x_sample[0].cpu().numpy(), "c h w -> h w c")
x_sample = x_sample.astype(np.uint8)
@ -469,7 +543,7 @@ def do_mk_img(req: Request):
return_orig_img = not has_filters or not req.show_only_filtered_image
if stop_processing:
if thread_data.stop_processing:
return_orig_img = True
if req.save_to_disk_path is not None:
@ -489,7 +563,7 @@ def do_mk_img(req: Request):
del img
if has_filters and not stop_processing:
if has_filters and not thread_data.stop_processing:
filters_applied = []
if req.use_face_correction:
x_sample = apply_filters('gfpgan', x_sample)
@ -514,7 +588,7 @@ def do_mk_img(req: Request):
move_fs_to_cpu()
gc()
del x_samples, x_samples_ddim, x_sample
print("memory_final = ", torch.cuda.memory_allocated() / 1e6)
print(f'memory_final = {round(torch.cuda.memory_allocated(thread_data.device) / 1e6, 2)}Mo')
print('Task completed')
@ -527,7 +601,7 @@ def save_image(img, img_out_path):
print('could not save the file', traceback.format_exc())
def save_metadata(meta_out_path, req, prompt, opt_seed):
metadata = f"""{prompt}
metadata = f'''{prompt}
Width: {req.width}
Height: {req.height}
Seed: {opt_seed}
@ -538,8 +612,8 @@ Use Face Correction: {req.use_face_correction}
Use Upscaling: {req.use_upscale}
Sampler: {req.sampler}
Negative Prompt: {req.negative_prompt}
Stable Diffusion Model: {req.use_stable_diffusion_model + '.ckpt'}
"""
Stable Diffusion model: {req.use_stable_diffusion_model + '.ckpt'}
'''
try:
with open(meta_out_path, 'w') as f:
f.write(metadata)
@ -549,16 +623,19 @@ Stable Diffusion Model: {req.use_stable_diffusion_model + '.ckpt'}
def _txt2img(opt_W, opt_H, opt_n_samples, opt_ddim_steps, opt_scale, start_code, opt_C, opt_f, opt_ddim_eta, c, uc, opt_seed, img_callback, mask, sampler_name):
shape = [opt_n_samples, opt_C, opt_H // opt_f, opt_W // opt_f]
if device != "cpu":
mem = torch.cuda.memory_allocated() / 1e6
modelCS.to("cpu")
while torch.cuda.memory_allocated() / 1e6 >= mem:
if thread_data.device != "cpu":
mem = torch.cuda.memory_allocated(thread_data.device) / 1e6
print('Device:', thread_data.device, 'CS_Model, Memory transfer starting. Memory Used:', round(mem, 2), 'Mo')
thread_data.modelCS.to("cpu")
while torch.cuda.memory_allocated(thread_data.device) / 1e6 >= mem and mem > 0:
print('Device:', thread_data.device, 'Waiting Memory transfer. Memory Used:', round(mem, 2), 'Mo')
time.sleep(1)
print('Transfered', round(mem - torch.cuda.memory_allocated(thread_data.device) / 1e6, 2), 'Mo')
if sampler_name == 'ddim':
model.make_schedule(ddim_num_steps=opt_ddim_steps, ddim_eta=opt_ddim_eta, verbose=False)
thread_data.model.make_schedule(ddim_num_steps=opt_ddim_steps, ddim_eta=opt_ddim_eta, verbose=False)
samples_ddim = model.sample(
samples_ddim = thread_data.model.sample(
S=opt_ddim_steps,
conditioning=c,
seed=opt_seed,
@ -572,14 +649,13 @@ def _txt2img(opt_W, opt_H, opt_n_samples, opt_ddim_steps, opt_scale, start_code,
mask=mask,
sampler = sampler_name,
)
yield from samples_ddim
def _img2img(init_latent, t_enc, batch_size, opt_scale, c, uc, opt_ddim_steps, opt_ddim_eta, opt_seed, img_callback, mask):
# encode (scaled latent)
z_enc = model.stochastic_encode(
z_enc = thread_data.model.stochastic_encode(
init_latent,
torch.tensor([t_enc] * batch_size).to(device),
torch.tensor([t_enc] * batch_size).to(thread_data.device),
opt_seed,
opt_ddim_eta,
opt_ddim_steps,
@ -587,7 +663,7 @@ def _img2img(init_latent, t_enc, batch_size, opt_scale, c, uc, opt_ddim_steps, o
x_T = None if mask is None else init_latent
# decode it
samples_ddim = model.sample(
samples_ddim = thread_data.model.sample(
t_enc,
c,
z_enc,
@ -602,16 +678,19 @@ def _img2img(init_latent, t_enc, batch_size, opt_scale, c, uc, opt_ddim_steps, o
yield from samples_ddim
def move_fs_to_cpu():
if device != "cpu":
mem = torch.cuda.memory_allocated() / 1e6
modelFS.to("cpu")
while torch.cuda.memory_allocated() / 1e6 >= mem:
if thread_data.device != "cpu":
mem = torch.cuda.memory_allocated(thread_data.device) / 1e6
print('Device:', thread_data.device, 'FS_Model, Memory transfer starting. Memory Used:', round(mem, 2), 'Mo')
thread_data.modelFS.to("cpu")
while torch.cuda.memory_allocated(thread_data.device) / 1e6 >= mem and mem > 0:
print('Device:', thread_data.device, 'Waiting for Memory transfer. Memory Used:', round(mem, 2), 'Mo')
time.sleep(1)
print('Transfered', round(mem - torch.cuda.memory_allocated(thread_data.device) / 1e6, 2), 'Mo')
def gc():
if device == 'cpu':
#gc.collect()
if thread_data.device == 'cpu':
return
torch.cuda.empty_cache()
torch.cuda.ipc_collect()
@ -621,7 +700,6 @@ def chunk(it, size):
it = iter(it)
return iter(lambda: tuple(islice(it, size)), ())
def load_model_from_config(ckpt, verbose=False):
print(f"Loading model from {ckpt}")
pl_sd = torch.load(ckpt, map_location="cpu")

View File

@ -9,6 +9,10 @@ from typing import Any, Generator, Hashable, Optional, Union
from pydantic import BaseModel
from sd_internal import Request, Response
ERR_LOCK_FAILED = ' failed to acquire lock within timeout.'
LOCK_TIMEOUT = 15 # Maximum locking time in seconds before failing a task.
# It's better to get an exception than a deadlock... ALWAYS use timeout in critical paths.
class SymbolClass(type): # Print nicely formatted Symbol names.
def __repr__(self): return self.__qualname__
def __str__(self): return self.__name__
@ -66,17 +70,30 @@ class ImageRequest(BaseModel):
stream_progress_updates: bool = False
stream_image_progress: bool = False
class FilterRequest(BaseModel):
session_id: str = "session"
model: str = None
name: str = ""
init_image: str = None # base64
width: int = 512
height: int = 512
save_to_disk_path: str = None
turbo: bool = True
use_cpu: bool = False
use_full_precision: bool = False
output_format: str = "jpeg" # or "png"
# Temporary cache to allow to query tasks results for a short time after they are completed.
class TaskCache():
def __init__(self):
self._base = dict()
self._lock: threading.Lock = threading.RLock()
self._lock: threading.Lock = threading.Lock()
def _get_ttl_time(self, ttl: int) -> int:
return int(time.time()) + ttl
def _is_expired(self, timestamp: int) -> bool:
return int(time.time()) >= timestamp
def clean(self) -> None:
if not self._lock.acquire(blocking=True, timeout=10): raise Exception('TaskCache.clean failed to acquire lock within timeout.')
if not self._lock.acquire(blocking=True, timeout=LOCK_TIMEOUT): raise Exception('TaskCache.clean' + ERR_LOCK_FAILED)
try:
# Create a list of expired keys to delete
to_delete = []
@ -91,11 +108,11 @@ class TaskCache():
finally:
self._lock.release()
def clear(self) -> None:
if not self._lock.acquire(blocking=True, timeout=10): raise Exception('TaskCache.clear failed to acquire lock within timeout.')
if not self._lock.acquire(blocking=True, timeout=LOCK_TIMEOUT): raise Exception('TaskCache.clear' + ERR_LOCK_FAILED)
try: self._base.clear()
finally: self._lock.release()
def delete(self, key: Hashable) -> bool:
if not self._lock.acquire(blocking=True, timeout=10): raise Exception('TaskCache.delete failed to acquire lock within timeout.')
if not self._lock.acquire(blocking=True, timeout=LOCK_TIMEOUT): raise Exception('TaskCache.delete' + ERR_LOCK_FAILED)
try:
if key not in self._base:
return False
@ -104,7 +121,7 @@ class TaskCache():
finally:
self._lock.release()
def keep(self, key: Hashable, ttl: int) -> bool:
if not self._lock.acquire(blocking=True, timeout=10): raise Exception('TaskCache.keep failed to acquire lock within timeout.')
if not self._lock.acquire(blocking=True, timeout=LOCK_TIMEOUT): raise Exception('TaskCache.keep' + ERR_LOCK_FAILED)
try:
if key in self._base:
_, value = self._base.get(key)
@ -114,7 +131,7 @@ class TaskCache():
finally:
self._lock.release()
def put(self, key: Hashable, value: Any, ttl: int) -> bool:
if not self._lock.acquire(blocking=True, timeout=10): raise Exception('TaskCache.put failed to acquire lock within timeout.')
if not self._lock.acquire(blocking=True, timeout=LOCK_TIMEOUT): raise Exception('TaskCache.put' + ERR_LOCK_FAILED)
try:
self._base[key] = (
self._get_ttl_time(ttl), value
@ -128,21 +145,23 @@ class TaskCache():
finally:
self._lock.release()
def tryGet(self, key: Hashable) -> Any:
if not self._lock.acquire(blocking=True, timeout=10): raise Exception('TaskCache.tryGet failed to acquire lock within timeout.')
if not self._lock.acquire(blocking=True, timeout=LOCK_TIMEOUT): raise Exception('TaskCache.tryGet' + ERR_LOCK_FAILED)
try:
ttl, value = self._base.get(key, (None, None))
if ttl is not None and self._is_expired(ttl):
print(f'Session {key} expired. Discarding data.')
self.delete(key)
del self._base[key]
return None
return value
finally:
self._lock.release()
manager_lock = threading.Lock()
render_threads = []
current_state = ServerStates.Init
current_state_error:Exception = None
current_model_path = None
tasks_queue = queue.Queue()
tasks_queue = []
task_cache = TaskCache()
default_model_to_load = None
@ -155,7 +174,8 @@ def preload_model(file_path=None):
current_state = ServerStates.LoadingModel
try:
from . import runtime
runtime.load_model_ckpt(ckpt_to_use=file_path)
runtime.thread_data.ckpt_file = file_path
runtime.load_model_ckpt()
current_model_path = file_path
current_state_error = None
current_state = ServerStates.Online
@ -165,43 +185,62 @@ def preload_model(file_path=None):
current_state = ServerStates.Unavailable
print(traceback.format_exc())
def thread_render():
def thread_render(device):
global current_state, current_state_error, current_model_path
from . import runtime
current_state = ServerStates.Online
try:
runtime.device_init(device)
except:
print(traceback.format_exc())
return
preload_model()
current_state = ServerStates.Online
while True:
task_cache.clean()
if isinstance(current_state_error, SystemExit):
current_state = ServerStates.Unavailable
return
task = None
try:
task = tasks_queue.get(timeout=1)
except queue.Empty as e:
if isinstance(current_state_error, SystemExit):
current_state = ServerStates.Unavailable
return
else: continue
if not manager_lock.acquire(blocking=True, timeout=LOCK_TIMEOUT):
print('Render thread on device', runtime.thread_data.device, 'failed to acquire manager lock.')
time.sleep(1)
continue
if len(tasks_queue) <= 0:
manager_lock.release()
time.sleep(1)
continue
try: # Select a render task.
for queued_task in tasks_queue:
if queued_task.request.use_cpu and runtime.thread_data.device != 'cpu':
continue # Cuda Tasks
if not queued_task.request.use_cpu and runtime.thread_data.device == 'cpu':
continue # CPU Tasks
if queued_task.request.use_face_correction and not runtime.is_first_cuda_device(runtime.thread_data.device):
continue #TODO Remove when fixed - A bug with GFPGANer and facexlib needs to be fixed before use on other devices.
task = queued_task
break
if task is not None:
del tasks_queue[tasks_queue.index(task)]
finally:
manager_lock.release()
if task is None:
time.sleep(1)
continue
#if current_model_path != task.request.use_stable_diffusion_model:
# preload_model(task.request.use_stable_diffusion_model)
if current_state_error:
task.error = current_state_error
continue
print(f'Session {task.request.session_id} starting task {id(task)}')
if not task.lock.acquire(blocking=False): raise Exception('Got locked task from queue.')
try:
task.lock.acquire(blocking=False)
# Open data generator.
res = runtime.mk_img(task.request)
if current_model_path == task.request.use_stable_diffusion_model:
current_state = ServerStates.Rendering
else:
current_state = ServerStates.LoadingModel
except Exception as e:
task.error = e
task.lock.release()
tasks_queue.task_done()
print(traceback.format_exc())
continue
# Start reading from generator.
dataQueue = None
if task.request.stream_progress_updates:
dataQueue = task.buffer_queue
@ -224,13 +263,18 @@ def thread_render():
for out_obj in result['output']:
if 'path' in out_obj:
img_id = out_obj['path'][out_obj['path'].rindex('/') + 1:]
task.temp_images[int(img_id)] = runtime.temp_images[out_obj['path'][11:]]
task.temp_images[int(img_id)] = runtime.thread_data.temp_images[out_obj['path'][11:]]
elif 'data' in out_obj:
task.temp_images[result['output'].index(out_obj)] = out_obj['data']
# Before looping back to the generator, mark cache as still alive.
task_cache.keep(task.request.session_id, TASK_TTL)
except Exception as e:
task.error = e
print(traceback.format_exc())
continue
finally:
# Task completed
task.lock.release()
tasks_queue.task_done()
task_cache.keep(task.request.session_id, TASK_TTL)
if isinstance(task.error, StopAsyncIteration):
print(f'Session {task.request.session_id} task {id(task)} cancelled!')
@ -240,19 +284,37 @@ def thread_render():
print(f'Session {task.request.session_id} task {id(task)} completed.')
current_state = ServerStates.Online
render_thread = threading.Thread(target=thread_render)
def is_alive(name=None):
if not manager_lock.acquire(blocking=True, timeout=LOCK_TIMEOUT): raise Exception('is_alive' + ERR_LOCK_FAILED)
nbr_alive = 0
try:
for rthread in render_threads:
if name and not rthread.name.endswith(name):
continue
if rthread.is_alive():
nbr_alive += 1
return nbr_alive
finally:
manager_lock.release()
def start_render_thread():
# Start Rendering Thread
render_thread.daemon = True
render_thread.start()
def start_render_thread(device='auto'):
if not manager_lock.acquire(blocking=True, timeout=LOCK_TIMEOUT): raise Exception('start_render_threads' + ERR_LOCK_FAILED)
print('Start new Rendering Thread on device', device)
try:
rthread = threading.Thread(target=thread_render, kwargs={'device': device})
rthread.daemon = True
rthread.name = 'Runner/' + device
rthread.start()
render_threads.append(rthread)
finally:
manager_lock.release()
def shutdown_event(): # Signal render thread to close on shutdown
global current_state_error
current_state_error = SystemExit('Application shutting down.')
def render(req : ImageRequest):
if not render_thread.is_alive(): # Render thread is dead
if not is_alive(): # Render thread is dead
raise ChildProcessError('Rendering thread has died.')
# Alive, check if task in cache
task = task_cache.tryGet(req.session_id)
@ -293,6 +355,12 @@ def render(req : ImageRequest):
new_task = RenderTask(r)
if task_cache.put(r.session_id, new_task, TASK_TTL):
tasks_queue.put(new_task, block=True, timeout=30)
# Use twice the normal timeout for adding user requests.
# Tries to force task_cache.put to fail before tasks_queue.put would.
if manager_lock.acquire(blocking=True, timeout=LOCK_TIMEOUT * 2):
try:
tasks_queue.append(new_task)
return new_task
finally:
manager_lock.release()
raise RuntimeError('Failed to add task to cache.')

View File

@ -15,14 +15,24 @@ MODELS_DIR = os.path.abspath(os.path.join(SD_DIR, '..', 'models'))
OUTPUT_DIRNAME = "Stable Diffusion UI" # in the user's home folder
TASK_TTL = 15 * 60 # Discard last session's task timeout
APP_CONFIG_DEFAULTS = {
# auto: selects the cuda device with the most free memory, cuda: use the currently active cuda device.
'render_devices': ['auto'], # ['cuda'] or ['CPU', 'GPU:0', 'GPU:1', ...] or ['cpu']
'update_branch': 'main',
}
APP_CONFIG_DEFAULT_MODELS = [
# needed to support the legacy installations
'custom-model', # Check if user has a custom model, use it first.
'sd-v1-4', # Default fallback.
]
from fastapi import FastAPI, HTTPException
from fastapi.staticfiles import StaticFiles
from starlette.responses import FileResponse, JSONResponse, StreamingResponse
from pydantic import BaseModel
import logging
import queue, threading, time
from typing import Any, Generator, Hashable, Optional, Union
#import queue, threading, time
from typing import Any, Generator, Hashable, List, Optional, Union
from sd_internal import Request, Response, task_manager
@ -37,52 +47,173 @@ ACCESS_LOG_SUPPRESS_PATH_PREFIXES = ['/ping', '/image', '/modifier-thumbnails']
NOCACHE_HEADERS={"Cache-Control": "no-cache, no-store, must-revalidate", "Pragma": "no-cache", "Expires": "0"}
app.mount('/media', StaticFiles(directory=os.path.join(SD_UI_DIR, 'media/')), name="media")
class SetAppConfigRequest(BaseModel):
update_branch: str = "main"
config_cached = None
config_last_mod_time = 0
def getConfig(default_val=APP_CONFIG_DEFAULTS):
global config_cached, config_last_mod_time
try:
config_json_path = os.path.join(CONFIG_DIR, 'config.json')
if not os.path.exists(config_json_path):
return default_val
if config_last_mod_time > 0 and config_cached is not None:
# Don't read if file was not modified
mtime = os.path.getmtime(config_json_path)
if mtime <= config_last_mod_time:
return config_cached
with open(config_json_path, 'r') as f:
config_cached = json.load(f)
config_last_mod_time = os.path.getmtime(config_json_path)
return config_cached
except Exception as e:
print(str(e))
print(traceback.format_exc())
return default_val
# needs to support the legacy installations
def get_initial_model_to_load():
custom_weight_path = os.path.join(SD_DIR, 'custom-model.ckpt')
ckpt_to_use = "sd-v1-4" if not os.path.exists(custom_weight_path) else "custom-model"
def setConfig(config):
try: # config.json
config_json_path = os.path.join(CONFIG_DIR, 'config.json')
with open(config_json_path, 'w') as f:
return json.dump(config, f)
except:
print(traceback.format_exc())
ckpt_to_use = os.path.join(SD_DIR, ckpt_to_use)
if 'render_devices' in config:
gpu_devices = filter(lambda dev: dev.startswith('GPU:'), config['render_devices'])
else:
gpu_devices = []
try: # config.bat
config_bat = [
f"@set update_branch={config['update_branch']}"
]
if len(gpu_devices) > 0:
config_sh.append(f"@set CUDA_VISIBLE_DEVICES={','.join(gpu_devices)}")
config_bat_path = os.path.join(CONFIG_DIR, 'config.bat')
with open(config_bat_path, 'w') as f:
f.write(f.write('\r\n'.join(config_bat)))
except Exception as e:
print(traceback.format_exc())
try: # config.sh
config_sh = [
'#!/bin/bash'
f"export update_branch={config['update_branch']}"
]
if len(gpu_devices) > 0:
config_sh.append(f"CUDA_VISIBLE_DEVICES={','.join(gpu_devices)}")
config_sh_path = os.path.join(CONFIG_DIR, 'config.sh')
with open(config_sh_path, 'w') as f:
f.write('\n'.join(config_sh))
except Exception as e:
print(traceback.format_exc())
def resolve_model_to_use(model_name:str=None):
if not model_name: # When None try user configured model.
config = getConfig()
if 'model' in config and 'stable-diffusion' in config['model']:
model_name = config['model']['stable-diffusion']
model_path = resolve_model_to_use(model_name)
if model_name:
if os.path.exists(model_name + '.ckpt'):
# Direct Path to file
return model_name
# Check models directory
models_dir_path = os.path.join(MODELS_DIR, 'stable-diffusion', model_name)
if os.path.exists(models_dir_path + '.ckpt'):
return models_dir_path
# Default locations
if model_name in APP_CONFIG_DEFAULT_MODELS:
default_model_path = os.path.join(SD_DIR, model_name)
if os.path.exists(default_model_path + '.ckpt'):
return default_model_path
# Can't find requested model, check the default paths.
for default_model in APP_CONFIG_DEFAULT_MODELS:
default_model_path = os.path.join(SD_DIR, default_model + '.ckpt')
if os.path.exists(default_model_path):
print('Could not find the configured custom model at:', model_path + '.ckpt', '. Using the default one:', default_model_path + '.ckpt')
return default_model_path
raise Exception('No valid models found.')
if os.path.exists(model_path + '.ckpt'):
ckpt_to_use = model_path
class SetAppConfigRequest(BaseModel):
update_branch: str = None
render_devices: Union[List[str], List[int], str, int] = None
@app.post('/app_config')
async def setAppConfig(req : SetAppConfigRequest):
config = getConfig()
if req.update_branch:
config['update_branch'] = req.update_branch
if req.render_devices and hasattr(req.render_devices, "__len__"): # strings, array of strings or numbers.
render_devices = []
if isinstance(req.render_devices, str):
req.render_devices = req.render_devices.split(',')
if isinstance(req.render_devices, list):
for gpu in req.render_devices:
if isinstance(req.render_devices, int):
render_devices.append('GPU:' + gpu)
else:
print('Could not find the configured custom model at:', model_path + '.ckpt', '. Using the default one:', ckpt_to_use + '.ckpt')
return ckpt_to_use
render_devices.append(gpu)
if isinstance(req.render_devices, int):
render_devices.append('GPU:' + req.render_devices)
if len(render_devices) > 0:
config['render_devices'] = render_devices
try:
setConfig(config)
return JSONResponse({'status': 'OK'}, headers=NOCACHE_HEADERS)
except Exception as e:
print(traceback.format_exc())
return HTTPException(status_code=500, detail=str(e))
def resolve_model_to_use(model_name):
if model_name in ('sd-v1-4', 'custom-model'):
model_path = os.path.join(MODELS_DIR, 'stable-diffusion', model_name)
def getModels():
models = {
'active': {
'stable-diffusion': 'sd-v1-4',
},
'options': {
'stable-diffusion': ['sd-v1-4'],
},
}
legacy_model_path = os.path.join(SD_DIR, model_name)
if not os.path.exists(model_path + '.ckpt') and os.path.exists(legacy_model_path + '.ckpt'):
model_path = legacy_model_path
# custom models
sd_models_dir = os.path.join(MODELS_DIR, 'stable-diffusion')
for file in os.listdir(sd_models_dir):
if file.endswith('.ckpt'):
model_name = os.path.splitext(file)[0]
models['options']['stable-diffusion'].append(model_name)
# legacy
custom_weight_path = os.path.join(SD_DIR, 'custom-model.ckpt')
if os.path.exists(custom_weight_path):
models['active']['stable-diffusion'] = 'custom-model'
models['options']['stable-diffusion'].append('custom-model')
config = getConfig()
if 'model' in config and 'stable-diffusion' in config['model']:
models['active']['stable-diffusion'] = config['model']['stable-diffusion']
return models
@app.get('/get/{key:path}')
def read_web_data(key:str=None):
if not key: # /get without parameters, stable-diffusion easter egg.
return HTTPException(status_code=418, detail="StableDiffusion is drawing a teapot!") # HTTP418 I'm a teapot
elif key == 'app_config':
config = getConfig(default_val=None)
if config is None:
return HTTPException(status_code=500, detail="Config file is missing or unreadable")
return JSONResponse(config, headers=NOCACHE_HEADERS)
elif key == 'models':
return JSONResponse(getModels(), headers=NOCACHE_HEADERS)
elif key == 'modifiers': return FileResponse(os.path.join(SD_UI_DIR, 'modifiers.json'), headers=NOCACHE_HEADERS)
elif key == 'output_dir': return JSONResponse({ 'output_dir': outpath }, headers=NOCACHE_HEADERS)
else:
model_path = os.path.join(MODELS_DIR, 'stable-diffusion', model_name)
return model_path
@app.on_event("shutdown")
def shutdown_event(): # Signal render thread to close on shutdown
task_manager.current_state_error = SystemExit('Application shutting down.')
@app.get('/')
def read_root():
return FileResponse(os.path.join(SD_UI_DIR, 'index.html'), headers=NOCACHE_HEADERS)
return HTTPException(status_code=404, detail=f'Request for unknown {key}') # HTTP404 Not Found
@app.get('/ping') # Get server and optionally session status.
def ping(session_id:str=None):
if not task_manager.render_thread.is_alive(): # Render thread is dead.
if task_manager.current_state_error: return HTTPException(status_code=500, detail=str(current_state_error))
if task_manager.is_alive() <= 0: # Check that render threads are alive.
if task_manager.current_state_error: return HTTPException(status_code=500, detail=str(task_manager.current_state_error))
return HTTPException(status_code=500, detail='Render thread is dead.')
if task_manager.current_state_error and not isinstance(task_manager.current_state_error, StopAsyncIteration): return HTTPException(status_code=500, detail=str(current_state_error))
if task_manager.current_state_error and not isinstance(task_manager.current_state_error, StopAsyncIteration): return HTTPException(status_code=500, detail=str(task_manager.current_state_error))
# Alive
response = {'status': str(task_manager.current_state)}
if session_id:
@ -119,7 +250,7 @@ def render(req : task_manager.ImageRequest):
new_task = task_manager.render(req)
response = {
'status': str(task_manager.current_state),
'queue': task_manager.tasks_queue.qsize(),
'queue': len(task_manager.tasks_queue),
'stream': f'/image/stream/{req.session_id}/{id(new_task)}',
'task': id(new_task)
}
@ -172,100 +303,13 @@ def get_image(session_id, img_id):
except KeyError as e:
return HTTPException(status_code=500, detail=str(e))
@app.post('/app_config')
async def setAppConfig(req : SetAppConfigRequest):
try:
config = {
'update_branch': req.update_branch
}
@app.get('/')
def read_root():
return FileResponse(os.path.join(SD_UI_DIR, 'index.html'), headers=NOCACHE_HEADERS)
config_json_str = json.dumps(config)
config_bat_str = f'@set update_branch={req.update_branch}'
config_sh_str = f'export update_branch={req.update_branch}'
config_json_path = os.path.join(CONFIG_DIR, 'config.json')
config_bat_path = os.path.join(CONFIG_DIR, 'config.bat')
config_sh_path = os.path.join(CONFIG_DIR, 'config.sh')
with open(config_json_path, 'w') as f:
f.write(config_json_str)
with open(config_bat_path, 'w') as f:
f.write(config_bat_str)
with open(config_sh_path, 'w') as f:
f.write(config_sh_str)
return {'OK'}
except Exception as e:
print(traceback.format_exc())
return HTTPException(status_code=500, detail=str(e))
def getConfig(default_val={}):
try:
config_json_path = os.path.join(CONFIG_DIR, 'config.json')
if not os.path.exists(config_json_path):
return default_val
with open(config_json_path, 'r') as f:
return json.load(f)
except Exception as e:
print(str(e))
print(traceback.format_exc())
return default_val
def setConfig(config):
try:
config_json_path = os.path.join(CONFIG_DIR, 'config.json')
with open(config_json_path, 'w') as f:
return json.dump(config, f)
except:
print(str(e))
print(traceback.format_exc())
def getModels():
models = {
'active': {
'stable-diffusion': 'sd-v1-4',
},
'options': {
'stable-diffusion': ['sd-v1-4'],
},
}
# custom models
sd_models_dir = os.path.join(MODELS_DIR, 'stable-diffusion')
for file in os.listdir(sd_models_dir):
if file.endswith('.ckpt'):
model_name = os.path.splitext(file)[0]
models['options']['stable-diffusion'].append(model_name)
# legacy
custom_weight_path = os.path.join(SD_DIR, 'custom-model.ckpt')
if os.path.exists(custom_weight_path):
models['active']['stable-diffusion'] = 'custom-model'
models['options']['stable-diffusion'].append('custom-model')
config = getConfig()
if 'model' in config and 'stable-diffusion' in config['model']:
models['active']['stable-diffusion'] = config['model']['stable-diffusion']
return models
@app.get('/get/{key:path}')
def read_web_data(key:str=None):
if not key: # /get without parameters, stable-diffusion easter egg.
return HTTPException(status_code=418, detail="StableDiffusion is drawing a teapot!") # HTTP418 I'm a teapot
elif key == 'app_config':
config = getConfig(default_val=None)
if config is None:
return HTTPException(status_code=500, detail="Config file is missing or unreadable")
return JSONResponse(config, headers=NOCACHE_HEADERS)
elif key == 'models':
return JSONResponse(getModels(), headers=NOCACHE_HEADERS)
elif key == 'modifiers': return FileResponse(os.path.join(SD_UI_DIR, 'modifiers.json'), headers=NOCACHE_HEADERS)
elif key == 'output_dir': return JSONResponse({ 'output_dir': outpath }, headers=NOCACHE_HEADERS)
else:
return HTTPException(status_code=404, detail=f'Request for unknown {key}') # HTTP404 Not Found
@app.on_event("shutdown")
def shutdown_event(): # Signal render thread to close on shutdown
task_manager.current_state_error = SystemExit('Application shutting down.')
# don't log certain requests
class LogSuppressFilter(logging.Filter):
@ -277,8 +321,26 @@ class LogSuppressFilter(logging.Filter):
return True
logging.getLogger('uvicorn.access').addFilter(LogSuppressFilter())
task_manager.default_model_to_load = get_initial_model_to_load()
task_manager.start_render_thread()
config = getConfig()
# Start the task_manager
task_manager.default_model_to_load = resolve_model_to_use()
if 'render_devices' in config: # Start a new thread for each device.
if isinstance(config['render_devices'], str):
config['render_devices'] = config['render_devices'].split(',')
if not isinstance(config['render_devices'], list):
raise Exception('Invalid render_devices value in config.')
for device in config['render_devices']:
task_manager.start_render_thread(device)
allow_cpu = False
if task_manager.is_alive() <= 0: # No running devices, apply defaults.
# Select best device GPU device using free memory if more than one device.
task_manager.start_render_thread('auto')
allow_cpu = True
# Allow CPU to be used for renders if not already enabled in current config.
if task_manager.is_alive('cpu') <= 0 and allow_cpu:
task_manager.start_render_thread('cpu')
# start the browser ui
import webbrowser; webbrowser.open('http://localhost:9000')