Mirror of https://github.com/easydiffusion/easydiffusion.git (synced 2025-06-19 08:17:49 +02:00)

Commit 7c72608e1c (parent 2edc06c662): First draft for Multi-GPU support
@@ -35,63 +35,144 @@ import base64
 from io import BytesIO
 #from colorama import Fore
 
-# local
-stop_processing = False
-
-temp_images = {}
-
-ckpt_file = None
-gfpgan_file = None
-real_esrgan_file = None
-
-model = None
-modelCS = None
-modelFS = None
-model_gfpgan = None
-model_real_esrgan = None
-
-model_is_half = False
-model_fs_is_half = False
-device = None
-unet_bs = 1
-precision = 'autocast'
-sampler_plms = None
-sampler_ddim = None
-
-has_valid_gpu = False
-force_full_precision = False
-try:
-    gpu = torch.cuda.current_device()
-    gpu_name = torch.cuda.get_device_name(gpu)
-    print('GPU detected: ', gpu_name)
-
-    force_full_precision = ('nvidia' in gpu_name.lower() or 'geforce' in gpu_name.lower()) and (' 1660' in gpu_name or ' 1650' in gpu_name) # otherwise these NVIDIA cards create green images
-    if force_full_precision:
-        print('forcing full precision on NVIDIA 16xx cards, to avoid green images. GPU detected: ', gpu_name)
-
-    mem_free, mem_total = torch.cuda.mem_get_info(gpu)
-    mem_total /= float(10**9)
-    if mem_total < 3.0:
-        print("GPUs with less than 3 GB of VRAM are not compatible with Stable Diffusion")
-        raise Exception()
-
-    has_valid_gpu = True
-except:
+from threading import local as LocalThreadVars
+thread_data = LocalThreadVars()
+
+def device_would_fail(device):
+    if device == 'cpu': return None
+    # Returns None when no issues found, otherwise returns the detected error str.
+    # Memory check
+    try:
+        mem_free, mem_total = torch.cuda.mem_get_info(device)
+        mem_total /= float(10**9)
+        if mem_total < 3.0:
+            return 'GPUs with less than 3 GB of VRAM are not compatible with Stable Diffusion'
+    except RuntimeError as e:
+        return str(e) # Return cuda errors from mem_get_info as strings
+    return None
+
+def device_select(device):
+    if device == 'cpu': return True
+    if not torch.cuda.is_available(): return False
+    failure_msg = device_would_fail(device)
+    if failure_msg:
+        if 'invalid device' in failure_msg:
+            raise NameError(f'GPU "{device}" could not be found. Remove this device from config.render_devices or use one of "auto" or "cuda".')
+        print(failure_msg)
+        return False
+
+    device_name = torch.cuda.get_device_name(device)
+
+    # otherwise these NVIDIA cards create green images
+    thread_data.force_full_precision = ('nvidia' in device_name.lower() or 'geforce' in device_name.lower()) and (' 1660' in device_name or ' 1650' in device_name)
+    if thread_data.force_full_precision:
+        print('forcing full precision on NVIDIA 16xx cards, to avoid green images. GPU detected: ', gpu_name)
+
+    thread_data.device = device
+    thread_data.has_valid_gpu = True
+    return True
+
+def device_init(device_selection=None):
+    # Thread bound properties
+    thread_data.stop_processing = False
+    thread_data.temp_images = {}
+
+    thread_data.ckpt_file = None
+    thread_data.gfpgan_file = None
+    thread_data.real_esrgan_file = None
+
+    thread_data.model = None
+    thread_data.modelCS = None
+    thread_data.modelFS = None
+    thread_data.model_gfpgan = None
+    thread_data.model_real_esrgan = None
+
+    thread_data.model_is_half = False
+    thread_data.model_fs_is_half = False
+    thread_data.device = None
+    thread_data.unet_bs = 1
+    thread_data.precision = 'autocast'
+    thread_data.sampler_plms = None
+    thread_data.sampler_ddim = None
+
+    thread_data.turbo = False
+    thread_data.has_valid_gpu = False
+    thread_data.force_full_precision = False
+
+    if device_selection.lower() == 'cpu':
+        print('CPU requested, skipping gpu init.')
+        thread_data.device = 'cpu'
+        return
+    if not torch.cuda.is_available():
+        print('WARNING: torch.cuda is not available. Using the CPU, but this will be very slow!')
+        return
+    device_count = torch.cuda.device_count()
+    if device_count <= 1 and device_selection == 'auto':
+        device_selection = 'current' # Use 'auto' only when there is more than one compatible device found.
+    if device_selection == 'auto':
+        print('Autoselecting GPU. Using most free memory.')
+        max_mem_free = 0
+        best_device = None
+        for device in range(device_count):
+            mem_free, mem_total = torch.cuda.mem_get_info(device)
+            mem_free /= float(10**9)
+            mem_total /= float(10**9)
+            device_name = torch.cuda.get_device_name(device)
+            print(f'GPU:{device} detected: {device_name} - Memory: {round(mem_total - mem_free, 2)}Go / {round(mem_total, 2)}Go')
+            if max_mem_free < mem_free:
+                max_mem_free = mem_free
+                best_device = device
+        if best_device and device_select(device):
+            print(f'Setting GPU:{device} as active')
+            torch.cuda.device(device)
+            return
+    if isinstance(device_selection, str):
+        device_selection = device_selection.lower()
+        if device_selection.startswith('gpu:'):
+            device_selection = int(device_selection[4:])
+    if device_selection != 'cuda' and device_selection != 'current':
+        if device_select(device_selection):
+            if isinstance(device_selection, int):
+                print(f'Setting GPU:{device_selection} as active')
+            else:
+                print(f'Setting {device_selection} as active')
+            torch.cuda.device(device_selection)
+            return
+    # By default use current device.
+    print('Checking current GPU...')
+    device = torch.cuda.current_device()
+    device_name = torch.cuda.get_device_name(device)
+    print(f'GPU:{device} detected: {device_name}')
+    if device_select(device):
+        return
     print('WARNING: No compatible GPU found. Using the CPU, but this will be very slow!')
-    pass
+    thread_data.device = 'cpu'
 
-def load_model_ckpt(ckpt_to_use, device_to_use='cuda', turbo=False, unet_bs_to_use=1, precision_to_use='autocast'):
-    global ckpt_file, model, modelCS, modelFS, model_is_half, device, unet_bs, precision, model_fs_is_half
+def is_first_cuda_device(device):
+    if thread_data.device == 0 or thread_data.device == '0':
+        return True
+    if thread_data.device == 'cuda' or thread_data.device == 'cuda:0':
+        return True
+    if thread_data.device == torch.device(0):
+        return True
+    return False
 
-    device = device_to_use if has_valid_gpu else 'cpu'
-    precision = precision_to_use if not force_full_precision else 'full'
-    unet_bs = unet_bs_to_use
+def load_model_ckpt():
+    if not thread_data.ckpt_file: raise ValueError(f'Thread ckpt_file is undefined.')
+    if not os.path.exists(thread_data.ckpt_file + '.ckpt'): raise FileNotFoundError(f'Cannot find {thread_data.ckpt_file}.ckpt')
+
+    if not thread_data.precision:
+        thread_data.precision = 'full' if thread_data.force_full_precision else 'autocast'
+    if not thread_data.unet_bs:
+        thread_data.unet_bs = 1
 
     unload_model()
 
-    if device == 'cpu':
-        precision = 'full'
+    if thread_data.device == 'cpu':
+        thread_data.precision = 'full'
 
-    sd = load_model_from_config(f"{ckpt_to_use}.ckpt")
+    print('loading', thread_data.ckpt_file, 'to', thread_data.device, 'using precision', thread_data.precision)
+    sd = load_model_from_config(thread_data.ckpt_file + '.ckpt')
     li, lo = [], []
     for key, value in sd.items():
         sp = key.split(".")
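Note (added commentary, not part of this diff): the point of replacing module-level globals with threading.local is that every attribute set on thread_data is visible only to the thread that set it, so each render thread can keep its own device and loaded models. A minimal, torch-free sketch of that behaviour; the worker function and thread names are illustrative, not code from this commit:

import threading
from threading import local as LocalThreadVars

thread_data = LocalThreadVars()

def worker(device_selection):
    # All threads write to the same 'thread_data' object, but each sees only its own value.
    thread_data.device = device_selection
    print(threading.current_thread().name, '->', thread_data.device)

threads = [threading.Thread(target=worker, args=(dev,), name=f'Runner/{dev}')
           for dev in ('cuda:0', 'cuda:1', 'cpu')]
for t in threads: t.start()
for t in threads: t.join()

Each thread prints the device it was given, even though they all share the same thread_data instance.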
@@ -114,88 +195,84 @@ def load_model_ckpt(ckpt_to_use, device_to_use='cuda', turbo=False, unet_bs_to_u
model = instantiate_from_config(config.modelUNet)
|
model = instantiate_from_config(config.modelUNet)
|
||||||
_, _ = model.load_state_dict(sd, strict=False)
|
_, _ = model.load_state_dict(sd, strict=False)
|
||||||
model.eval()
|
model.eval()
|
||||||
model.cdevice = device
|
model.cdevice = torch.device(thread_data.device)
|
||||||
model.unet_bs = unet_bs
|
model.unet_bs = thread_data.unet_bs
|
||||||
model.turbo = turbo
|
model.turbo = thread_data.turbo
|
||||||
|
if thread_data.device != 'cpu':
|
||||||
|
model.to(thread_data.device)
|
||||||
|
thread_data.model = model
|
||||||
|
|
||||||
modelCS = instantiate_from_config(config.modelCondStage)
|
modelCS = instantiate_from_config(config.modelCondStage)
|
||||||
_, _ = modelCS.load_state_dict(sd, strict=False)
|
_, _ = modelCS.load_state_dict(sd, strict=False)
|
||||||
modelCS.eval()
|
modelCS.eval()
|
||||||
modelCS.cond_stage_model.device = device
|
modelCS.cond_stage_model.device = torch.device(thread_data.device)
|
||||||
|
if thread_data.device != 'cpu':
|
||||||
|
modelCS.to(thread_data.device)
|
||||||
|
thread_data.modelCS = modelCS
|
||||||
|
|
||||||
modelFS = instantiate_from_config(config.modelFirstStage)
|
modelFS = instantiate_from_config(config.modelFirstStage)
|
||||||
_, _ = modelFS.load_state_dict(sd, strict=False)
|
_, _ = modelFS.load_state_dict(sd, strict=False)
|
||||||
modelFS.eval()
|
modelFS.eval()
|
||||||
|
if thread_data.device != 'cpu':
|
||||||
|
modelFS.to(thread_data.device)
|
||||||
|
thread_data.modelFS = modelFS
|
||||||
del sd
|
del sd
|
||||||
|
|
||||||
if device != "cpu" and precision == "autocast":
|
if thread_data.device != "cpu" and thread_data.precision == "autocast":
|
||||||
model.half()
|
thread_data.model.half()
|
||||||
modelCS.half()
|
thread_data.modelCS.half()
|
||||||
modelFS.half()
|
thread_data.modelFS.half()
|
||||||
model_is_half = True
|
thread_data.model_is_half = True
|
||||||
model_fs_is_half = True
|
thread_data.model_fs_is_half = True
|
||||||
else:
|
else:
|
||||||
model_is_half = False
|
thread_data.model_is_half = False
|
||||||
model_fs_is_half = False
|
thread_data.model_fs_is_half = False
|
||||||
|
|
||||||
ckpt_file = ckpt_to_use
|
print('loaded', thread_data.ckpt_file, 'as', model.device, '->', modelCS.cond_stage_model.device, '->', thread_data.modelFS.device, 'using precision', thread_data.precision)
|
||||||
|
|
||||||
print('loaded ', ckpt_file, 'to', device, 'precision', precision)
|
|
||||||
|
|
||||||
def unload_model():
|
def unload_model():
|
||||||
global model, modelCS, modelFS
|
if thread_data.model is not None:
|
||||||
|
print('Unloading models...')
|
||||||
|
del thread_data.model
|
||||||
|
del thread_data.modelCS
|
||||||
|
del thread_data.modelFS
|
||||||
|
thread_data.model = None
|
||||||
|
thread_data.modelCS = None
|
||||||
|
thread_data.modelFS = None
|
||||||
|
|
||||||
if model is not None:
|
def load_model_gfpgan():
|
||||||
del model
|
if thread_data.gfpgan_file is None:
|
||||||
del modelCS
|
print('load_model_gfpgan called without setting gfpgan_file')
|
||||||
del modelFS
|
|
||||||
|
|
||||||
model = None
|
|
||||||
modelCS = None
|
|
||||||
modelFS = None
|
|
||||||
|
|
||||||
def load_model_gfpgan(gfpgan_to_use):
|
|
||||||
global gfpgan_file, model_gfpgan
|
|
||||||
|
|
||||||
if gfpgan_to_use is None:
|
|
||||||
return
|
return
|
||||||
|
if thread_data.device != 'cpu' and not is_first_cuda_device(thread_data.device):
|
||||||
|
#TODO Remove when fixed - A bug with GFPGANer and facexlib needs to be fixed before use on other devices.
|
||||||
|
raise Exception(f'Current device {torch.device(thread_data.device)} is not {torch.device(0)}.')
|
||||||
|
model_path = thread_data.gfpgan_file + ".pth"
|
||||||
|
thread_data.model_gfpgan = GFPGANer(device=torch.device(thread_data.device), model_path=model_path, upscale=1, arch='clean', channel_multiplier=2, bg_upsampler=None)
|
||||||
|
print('loaded', thread_data.gfpgan_file, 'to', thread_data.model_gfpgan.device, 'precision', thread_data.precision)
|
||||||
|
|
||||||
gfpgan_file = gfpgan_to_use
|
def load_model_real_esrgan():
|
||||||
model_path = gfpgan_to_use + ".pth"
|
if thread_data.real_esrgan_file is None:
|
||||||
|
print('load_model_real_esrgan called without setting real_esrgan_file')
|
||||||
if device == 'cpu':
|
|
||||||
model_gfpgan = GFPGANer(model_path=model_path, upscale=1, arch='clean', channel_multiplier=2, bg_upsampler=None, device=torch.device('cpu'))
|
|
||||||
else:
|
|
||||||
model_gfpgan = GFPGANer(model_path=model_path, upscale=1, arch='clean', channel_multiplier=2, bg_upsampler=None, device=torch.device('cuda'))
|
|
||||||
|
|
||||||
print('loaded ', gfpgan_to_use, 'to', device, 'precision', precision)
|
|
||||||
|
|
||||||
def load_model_real_esrgan(real_esrgan_to_use):
|
|
||||||
global real_esrgan_file, model_real_esrgan
|
|
||||||
|
|
||||||
if real_esrgan_to_use is None:
|
|
||||||
return
|
return
|
||||||
|
model_path = thread_data.real_esrgan_file + ".pth"
|
||||||
real_esrgan_file = real_esrgan_to_use
|
|
||||||
model_path = real_esrgan_to_use + ".pth"
|
|
||||||
|
|
||||||
RealESRGAN_models = {
|
RealESRGAN_models = {
|
||||||
'RealESRGAN_x4plus': RRDBNet(num_in_ch=3, num_out_ch=3, num_feat=64, num_block=23, num_grow_ch=32, scale=4),
|
'RealESRGAN_x4plus': RRDBNet(num_in_ch=3, num_out_ch=3, num_feat=64, num_block=23, num_grow_ch=32, scale=4),
|
||||||
'RealESRGAN_x4plus_anime_6B': RRDBNet(num_in_ch=3, num_out_ch=3, num_feat=64, num_block=6, num_grow_ch=32, scale=4)
|
'RealESRGAN_x4plus_anime_6B': RRDBNet(num_in_ch=3, num_out_ch=3, num_feat=64, num_block=6, num_grow_ch=32, scale=4)
|
||||||
}
|
}
|
||||||
|
|
||||||
model_to_use = RealESRGAN_models[real_esrgan_to_use]
|
model_to_use = RealESRGAN_models[thread_data.real_esrgan_file]
|
||||||
|
|
||||||
if device == 'cpu':
|
if thread_data.device == 'cpu':
|
||||||
model_real_esrgan = RealESRGANer(scale=2, model_path=model_path, model=model_to_use, pre_pad=0, half=False) # cpu does not support half
|
thread_data.model_real_esrgan = RealESRGANer(device=torch.device(thread_data.device), scale=2, model_path=model_path, model=model_to_use, pre_pad=0, half=False) # cpu does not support half
|
||||||
model_real_esrgan.device = torch.device('cpu')
|
#thread_data.model_real_esrgan.device = torch.device(thread_data.device)
|
||||||
model_real_esrgan.model.to('cpu')
|
thread_data.model_real_esrgan.model.to('cpu')
|
||||||
else:
|
else:
|
||||||
model_real_esrgan = RealESRGANer(scale=2, model_path=model_path, model=model_to_use, pre_pad=0, half=model_is_half)
|
thread_data.model_real_esrgan = RealESRGANer(device=torch.device(thread_data.device), scale=2, model_path=model_path, model=model_to_use, pre_pad=0, half=thread_data.model_is_half)
|
||||||
|
|
||||||
model_real_esrgan.model.name = real_esrgan_to_use
|
thread_data.model_real_esrgan.model.name = thread_data.real_esrgan_file
|
||||||
|
print('loaded ', thread_data.real_esrgan_file, 'to', thread_data.model_real_esrgan.device, 'precision', thread_data.precision)
|
||||||
print('loaded ', real_esrgan_to_use, 'to', device, 'precision', precision)
|
|
||||||
|
|
||||||
def get_base_path(disk_path, session_id, prompt, ext, suffix=None):
|
def get_base_path(disk_path, session_id, prompt, ext, suffix=None):
|
||||||
if disk_path is None: return None
|
if disk_path is None: return None
|
||||||
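Note (added commentary, not part of this diff): the per-thread fields above act as a small lazy cache further down in this diff -- a model is only reloaded when the requested weights file differs from what the current thread already holds. A rough sketch of that pattern, with load_weights() as a hypothetical stand-in for the real GFPGANer/RealESRGANer/ckpt loaders:

import threading

thread_data = threading.local()

def load_weights(path):
    # Hypothetical stand-in for the actual model loading in this diff.
    return {'loaded_from': path}

def get_gfpgan(requested_file):
    # Reload only when the requested file changed for this thread.
    if getattr(thread_data, 'gfpgan_file', None) != requested_file:
        thread_data.gfpgan_file = requested_file
        thread_data.model_gfpgan = load_weights(requested_file)
    return thread_data.model_gfpgan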
@@ -214,14 +291,22 @@ def get_base_path(disk_path, session_id, prompt, ext, suffix=None):
|
|
||||||
def apply_filters(filter_name, image_data):
|
def apply_filters(filter_name, image_data):
|
||||||
print(f'Applying filter {filter_name}...')
|
print(f'Applying filter {filter_name}...')
|
||||||
|
if isinstance(image_data, torch.Tensor):
|
||||||
|
print(image_data)
|
||||||
|
image_data.to(thread_data.device)
|
||||||
|
|
||||||
gc()
|
gc()
|
||||||
|
|
||||||
if filter_name == 'gfpgan':
|
if filter_name == 'gfpgan':
|
||||||
_, _, output = model_gfpgan.enhance(image_data[:,:,::-1], has_aligned=False, only_center_face=False, paste_back=True)
|
if thread_data.model_gfpgan is None: raise Exception('Model "gfpgan" not loaded.')
|
||||||
|
print('enhance with', thread_data.gfpgan_file, 'on', thread_data.model_gfpgan.device, 'precision', thread_data.precision)
|
||||||
|
_, _, output = thread_data.model_gfpgan.enhance(image_data[:,:,::-1], has_aligned=False, only_center_face=False, paste_back=True)
|
||||||
image_data = output[:,:,::-1]
|
image_data = output[:,:,::-1]
|
||||||
|
|
||||||
if filter_name == 'real_esrgan':
|
if filter_name == 'real_esrgan':
|
||||||
output, _ = model_real_esrgan.enhance(image_data[:,:,::-1])
|
if thread_data.model_real_esrgan is None: raise Exception('Model "gfpgan" not loaded.')
|
||||||
|
print('enhance with', thread_data.real_esrgan_file, 'on', thread_data.model_real_esrgan.device, 'precision', thread_data.precision)
|
||||||
|
output, _ = thread_data.model_real_esrgan.enhance(image_data[:,:,::-1])
|
||||||
image_data = output[:,:,::-1]
|
image_data = output[:,:,::-1]
|
||||||
|
|
||||||
return image_data
|
return image_data
|
||||||
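Note (added commentary, not part of this diff): the image_data[:,:,::-1] slicing around the enhance() calls flips the channel axis, because GFPGAN and Real-ESRGAN operate on OpenCV-style BGR arrays while the rest of the pipeline uses RGB. A small NumPy-only illustration of that round trip:

import numpy as np

rgb = np.zeros((2, 2, 3), dtype=np.uint8)
rgb[..., 0] = 255           # pure red in RGB order

bgr = rgb[:, :, ::-1]       # reverse the channel axis -> BGR order
assert bgr[0, 0, 2] == 255  # the red value now sits in the last slot

back = bgr[:, :, ::-1]      # reversing again restores the RGB image
assert np.array_equal(back, rgb)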
@@ -234,12 +319,12 @@ def mk_img(req: Request):
|
|
||||||
gc()
|
gc()
|
||||||
|
|
||||||
if device != "cpu":
|
if thread_data.device != "cpu":
|
||||||
modelFS.to("cpu")
|
thread_data.modelFS.to("cpu")
|
||||||
modelCS.to("cpu")
|
thread_data.modelCS.to("cpu")
|
||||||
|
|
||||||
model.model1.to("cpu")
|
thread_data.model.model1.to("cpu")
|
||||||
model.model2.to("cpu")
|
thread_data.model.model2.to("cpu")
|
||||||
|
|
||||||
gc()
|
gc()
|
||||||
|
|
||||||
@ -249,66 +334,55 @@ def mk_img(req: Request):
|
|||||||
})
|
})
|
||||||
|
|
||||||
def do_mk_img(req: Request):
|
def do_mk_img(req: Request):
|
||||||
global ckpt_file
|
thread_data.stop_processing = False
|
||||||
global model, modelCS, modelFS, device
|
|
||||||
global model_gfpgan, model_real_esrgan
|
|
||||||
global stop_processing
|
|
||||||
|
|
||||||
stop_processing = False
|
|
||||||
|
|
||||||
res = Response()
|
res = Response()
|
||||||
res.request = req
|
res.request = req
|
||||||
res.images = []
|
res.images = []
|
||||||
|
|
||||||
temp_images.clear()
|
thread_data.temp_images.clear()
|
||||||
|
|
||||||
# custom model support:
|
# custom model support:
|
||||||
# the req.use_stable_diffusion_model needs to be a valid path
|
# the req.use_stable_diffusion_model needs to be a valid path
|
||||||
# to the ckpt file (without the extension).
|
# to the ckpt file (without the extension).
|
||||||
|
if not os.path.exists(req.use_stable_diffusion_model + '.ckpt'): raise FileNotFoundError(f'Cannot find {req.use_stable_diffusion_model}.ckpt')
|
||||||
|
|
||||||
needs_model_reload = False
|
needs_model_reload = False
|
||||||
ckpt_to_use = ckpt_file
|
if thread_data.ckpt_file != req.use_stable_diffusion_model:
|
||||||
if ckpt_to_use != req.use_stable_diffusion_model:
|
thread_data.ckpt_file = req.use_stable_diffusion_model
|
||||||
ckpt_to_use = req.use_stable_diffusion_model
|
|
||||||
needs_model_reload = True
|
needs_model_reload = True
|
||||||
|
|
||||||
model.turbo = req.turbo
|
|
||||||
if req.use_cpu:
|
if req.use_cpu:
|
||||||
if device != 'cpu':
|
if thread_data.device != 'cpu':
|
||||||
device = 'cpu'
|
thread_data.device = 'cpu'
|
||||||
|
if thread_data.model_is_half:
|
||||||
if model_is_half:
|
load_model_ckpt()
|
||||||
load_model_ckpt(ckpt_to_use, device)
|
|
||||||
needs_model_reload = False
|
needs_model_reload = False
|
||||||
|
load_model_gfpgan()
|
||||||
load_model_gfpgan(gfpgan_file)
|
load_model_real_esrgan()
|
||||||
load_model_real_esrgan(real_esrgan_file)
|
|
||||||
else:
|
else:
|
||||||
if has_valid_gpu:
|
if thread_data.has_valid_gpu:
|
||||||
prev_device = device
|
if (thread_data.precision == 'autocast' and (req.use_full_precision or not thread_data.model_is_half)) or \
|
||||||
device = 'cuda'
|
(thread_data.precision == 'full' and not req.use_full_precision and not thread_data.force_full_precision):
|
||||||
|
thread_data.precision = 'full' if req.use_full_precision else 'autocast'
|
||||||
if (precision == 'autocast' and (req.use_full_precision or not model_is_half)) or \
|
load_model_ckpt()
|
||||||
(precision == 'full' and not req.use_full_precision and not force_full_precision):
|
load_model_gfpgan()
|
||||||
|
load_model_real_esrgan()
|
||||||
load_model_ckpt(ckpt_to_use, device, req.turbo, unet_bs, ('full' if req.use_full_precision else 'autocast'))
|
|
||||||
needs_model_reload = False
|
needs_model_reload = False
|
||||||
|
|
||||||
if prev_device != device:
|
|
||||||
load_model_gfpgan(gfpgan_file)
|
|
||||||
load_model_real_esrgan(real_esrgan_file)
|
|
||||||
|
|
||||||
if needs_model_reload:
|
if needs_model_reload:
|
||||||
load_model_ckpt(ckpt_to_use, device, req.turbo, unet_bs, precision)
|
load_model_ckpt()
|
||||||
|
|
||||||
if req.use_face_correction != gfpgan_file:
|
if req.use_face_correction != thread_data.gfpgan_file:
|
||||||
load_model_gfpgan(req.use_face_correction)
|
thread_data.gfpgan_file = req.use_face_correction
|
||||||
|
load_model_gfpgan()
|
||||||
|
if req.use_upscale != thread_data.real_esrgan_file:
|
||||||
|
thread_data.real_esrgan_file = req.use_upscale
|
||||||
|
load_model_real_esrgan()
|
||||||
|
|
||||||
if req.use_upscale != real_esrgan_file:
|
if thread_data.turbo != req.turbo:
|
||||||
load_model_real_esrgan(req.use_upscale)
|
thread_data.turbo = req.turbo
|
||||||
|
thread_data.model.turbo = req.turbo
|
||||||
model.cdevice = device
|
|
||||||
modelCS.cond_stage_model.device = device
|
|
||||||
|
|
||||||
opt_prompt = req.prompt
|
opt_prompt = req.prompt
|
||||||
opt_seed = req.seed
|
opt_seed = req.seed
|
||||||
@ -318,9 +392,8 @@ def do_mk_img(req: Request):
|
|||||||
opt_ddim_eta = 0.0
|
opt_ddim_eta = 0.0
|
||||||
opt_init_img = req.init_image
|
opt_init_img = req.init_image
|
||||||
|
|
||||||
print(req.to_string(), '\n device', device)
|
print(req.to_string(), '\n device', thread_data.device)
|
||||||
|
print('\n\n Using precision:', thread_data.precision)
|
||||||
print('\n\n Using precision:', precision)
|
|
||||||
|
|
||||||
seed_everything(opt_seed)
|
seed_everything(opt_seed)
|
||||||
|
|
||||||
@ -329,7 +402,7 @@ def do_mk_img(req: Request):
|
|||||||
assert prompt is not None
|
assert prompt is not None
|
||||||
data = [batch_size * [prompt]]
|
data = [batch_size * [prompt]]
|
||||||
|
|
||||||
if precision == "autocast" and device != "cpu":
|
if thread_data.precision == "autocast" and thread_data.device != "cpu":
|
||||||
precision_scope = autocast
|
precision_scope = autocast
|
||||||
else:
|
else:
|
||||||
precision_scope = nullcontext
|
precision_scope = nullcontext
|
||||||
@ -345,22 +418,22 @@ def do_mk_img(req: Request):
|
|||||||
handler = _img2img
|
handler = _img2img
|
||||||
|
|
||||||
init_image = load_img(req.init_image, req.width, req.height)
|
init_image = load_img(req.init_image, req.width, req.height)
|
||||||
init_image = init_image.to(device)
|
init_image = init_image.to(thread_data.device)
|
||||||
|
|
||||||
if device != "cpu" and precision == "autocast":
|
if thread_data.device != "cpu" and thread_data.precision == "autocast":
|
||||||
init_image = init_image.half()
|
init_image = init_image.half()
|
||||||
|
|
||||||
modelFS.to(device)
|
thread_data.modelFS.to(thread_data.device)
|
||||||
|
|
||||||
init_image = repeat(init_image, '1 ... -> b ...', b=batch_size)
|
init_image = repeat(init_image, '1 ... -> b ...', b=batch_size)
|
||||||
init_latent = modelFS.get_first_stage_encoding(modelFS.encode_first_stage(init_image)) # move to latent space
|
init_latent = thread_data.modelFS.get_first_stage_encoding(thread_data.modelFS.encode_first_stage(init_image)) # move to latent space
|
||||||
|
|
||||||
if req.mask is not None:
|
if req.mask is not None:
|
||||||
mask = load_mask(req.mask, req.width, req.height, init_latent.shape[2], init_latent.shape[3], True).to(device)
|
mask = load_mask(req.mask, req.width, req.height, init_latent.shape[2], init_latent.shape[3], True).to(thread_data.device)
|
||||||
mask = mask[0][0].unsqueeze(0).repeat(4, 1, 1).unsqueeze(0)
|
mask = mask[0][0].unsqueeze(0).repeat(4, 1, 1).unsqueeze(0)
|
||||||
mask = repeat(mask, '1 ... -> b ...', b=batch_size)
|
mask = repeat(mask, '1 ... -> b ...', b=batch_size)
|
||||||
|
|
||||||
if device != "cpu" and precision == "autocast":
|
if thread_data.device != "cpu" and thread_data.precision == "autocast":
|
||||||
mask = mask.half()
|
mask = mask.half()
|
||||||
|
|
||||||
move_fs_to_cpu()
|
move_fs_to_cpu()
|
||||||
@ -381,10 +454,10 @@ def do_mk_img(req: Request):
|
|||||||
for prompts in tqdm(data, desc="data"):
|
for prompts in tqdm(data, desc="data"):
|
||||||
|
|
||||||
with precision_scope("cuda"):
|
with precision_scope("cuda"):
|
||||||
modelCS.to(device)
|
thread_data.modelCS.to(thread_data.device)
|
||||||
uc = None
|
uc = None
|
||||||
if req.guidance_scale != 1.0:
|
if req.guidance_scale != 1.0:
|
||||||
uc = modelCS.get_learned_conditioning(batch_size * [req.negative_prompt])
|
uc = thread_data.modelCS.get_learned_conditioning(batch_size * [req.negative_prompt])
|
||||||
if isinstance(prompts, tuple):
|
if isinstance(prompts, tuple):
|
||||||
prompts = list(prompts)
|
prompts = list(prompts)
|
||||||
|
|
||||||
@ -397,11 +470,11 @@ def do_mk_img(req: Request):
|
|||||||
weight = weights[i]
|
weight = weights[i]
|
||||||
# if not skip_normalize:
|
# if not skip_normalize:
|
||||||
weight = weight / totalWeight
|
weight = weight / totalWeight
|
||||||
c = torch.add(c, modelCS.get_learned_conditioning(subprompts[i]), alpha=weight)
|
c = torch.add(c, thread_data.modelCS.get_learned_conditioning(subprompts[i]), alpha=weight)
|
||||||
else:
|
else:
|
||||||
c = modelCS.get_learned_conditioning(prompts)
|
c = thread_data.modelCS.get_learned_conditioning(prompts)
|
||||||
|
|
||||||
modelFS.to(device)
|
thread_data.modelFS.to(thread_data.device)
|
||||||
|
|
||||||
partial_x_samples = None
|
partial_x_samples = None
|
||||||
def img_callback(x_samples, i):
|
def img_callback(x_samples, i):
|
||||||
@ -417,7 +490,7 @@ def do_mk_img(req: Request):
|
|||||||
partial_images = []
|
partial_images = []
|
||||||
|
|
||||||
for i in range(batch_size):
|
for i in range(batch_size):
|
||||||
x_samples_ddim = modelFS.decode_first_stage(x_samples[i].unsqueeze(0))
|
x_samples_ddim = thread_data.modelFS.decode_first_stage(x_samples[i].unsqueeze(0))
|
||||||
x_sample = torch.clamp((x_samples_ddim + 1.0) / 2.0, min=0.0, max=1.0)
|
x_sample = torch.clamp((x_samples_ddim + 1.0) / 2.0, min=0.0, max=1.0)
|
||||||
x_sample = 255.0 * rearrange(x_sample[0].cpu().numpy(), "c h w -> h w c")
|
x_sample = 255.0 * rearrange(x_sample[0].cpu().numpy(), "c h w -> h w c")
|
||||||
x_sample = x_sample.astype(np.uint8)
|
x_sample = x_sample.astype(np.uint8)
|
||||||
@ -429,18 +502,19 @@ def do_mk_img(req: Request):
|
|||||||
del img, x_sample, x_samples_ddim
|
del img, x_sample, x_samples_ddim
|
||||||
# don't delete x_samples, it is used in the code that called this callback
|
# don't delete x_samples, it is used in the code that called this callback
|
||||||
|
|
||||||
temp_images[str(req.session_id) + '/' + str(i)] = buf
|
thread_data.temp_images[str(req.session_id) + '/' + str(i)] = buf
|
||||||
partial_images.append({'path': f'/image/tmp/{req.session_id}/{i}'})
|
partial_images.append({'path': f'/image/tmp/{req.session_id}/{i}'})
|
||||||
|
|
||||||
progress['output'] = partial_images
|
progress['output'] = partial_images
|
||||||
|
|
||||||
yield json.dumps(progress)
|
yield json.dumps(progress)
|
||||||
|
|
||||||
if stop_processing:
|
if thread_data.stop_processing:
|
||||||
raise UserInitiatedStop("User requested that we stop processing")
|
raise UserInitiatedStop("User requested that we stop processing")
|
||||||
|
|
||||||
# run the handler
|
# run the handler
|
||||||
try:
|
try:
|
||||||
|
print('Running handler...')
|
||||||
if handler == _txt2img:
|
if handler == _txt2img:
|
||||||
x_samples = _txt2img(req.width, req.height, req.num_outputs, req.num_inference_steps, req.guidance_scale, None, opt_C, opt_f, opt_ddim_eta, c, uc, opt_seed, img_callback, mask, req.sampler)
|
x_samples = _txt2img(req.width, req.height, req.num_outputs, req.num_inference_steps, req.guidance_scale, None, opt_C, opt_f, opt_ddim_eta, c, uc, opt_seed, img_callback, mask, req.sampler)
|
||||||
else:
|
else:
|
||||||
@ -458,7 +532,7 @@ def do_mk_img(req: Request):
|
|||||||
print("saving images")
|
print("saving images")
|
||||||
for i in range(batch_size):
|
for i in range(batch_size):
|
||||||
|
|
||||||
x_samples_ddim = modelFS.decode_first_stage(x_samples[i].unsqueeze(0))
|
x_samples_ddim = thread_data.modelFS.decode_first_stage(x_samples[i].unsqueeze(0))
|
||||||
x_sample = torch.clamp((x_samples_ddim + 1.0) / 2.0, min=0.0, max=1.0)
|
x_sample = torch.clamp((x_samples_ddim + 1.0) / 2.0, min=0.0, max=1.0)
|
||||||
x_sample = 255.0 * rearrange(x_sample[0].cpu().numpy(), "c h w -> h w c")
|
x_sample = 255.0 * rearrange(x_sample[0].cpu().numpy(), "c h w -> h w c")
|
||||||
x_sample = x_sample.astype(np.uint8)
|
x_sample = x_sample.astype(np.uint8)
|
||||||
@ -469,7 +543,7 @@ def do_mk_img(req: Request):
|
|||||||
|
|
||||||
return_orig_img = not has_filters or not req.show_only_filtered_image
|
return_orig_img = not has_filters or not req.show_only_filtered_image
|
||||||
|
|
||||||
if stop_processing:
|
if thread_data.stop_processing:
|
||||||
return_orig_img = True
|
return_orig_img = True
|
||||||
|
|
||||||
if req.save_to_disk_path is not None:
|
if req.save_to_disk_path is not None:
|
||||||
@ -489,7 +563,7 @@ def do_mk_img(req: Request):
|
|||||||
|
|
||||||
del img
|
del img
|
||||||
|
|
||||||
if has_filters and not stop_processing:
|
if has_filters and not thread_data.stop_processing:
|
||||||
filters_applied = []
|
filters_applied = []
|
||||||
if req.use_face_correction:
|
if req.use_face_correction:
|
||||||
x_sample = apply_filters('gfpgan', x_sample)
|
x_sample = apply_filters('gfpgan', x_sample)
|
||||||
@ -514,7 +588,7 @@ def do_mk_img(req: Request):
|
|||||||
move_fs_to_cpu()
|
move_fs_to_cpu()
|
||||||
gc()
|
gc()
|
||||||
del x_samples, x_samples_ddim, x_sample
|
del x_samples, x_samples_ddim, x_sample
|
||||||
print("memory_final = ", torch.cuda.memory_allocated() / 1e6)
|
print(f'memory_final = {round(torch.cuda.memory_allocated(thread_data.device) / 1e6, 2)}Mo')
|
||||||
|
|
||||||
print('Task completed')
|
print('Task completed')
|
||||||
|
|
||||||
@ -527,7 +601,7 @@ def save_image(img, img_out_path):
|
|||||||
print('could not save the file', traceback.format_exc())
|
print('could not save the file', traceback.format_exc())
|
||||||
|
|
||||||
def save_metadata(meta_out_path, req, prompt, opt_seed):
|
def save_metadata(meta_out_path, req, prompt, opt_seed):
|
||||||
metadata = f"""{prompt}
|
metadata = f'''{prompt}
|
||||||
Width: {req.width}
|
Width: {req.width}
|
||||||
Height: {req.height}
|
Height: {req.height}
|
||||||
Seed: {opt_seed}
|
Seed: {opt_seed}
|
||||||
@ -538,8 +612,8 @@ Use Face Correction: {req.use_face_correction}
|
|||||||
Use Upscaling: {req.use_upscale}
|
Use Upscaling: {req.use_upscale}
|
||||||
Sampler: {req.sampler}
|
Sampler: {req.sampler}
|
||||||
Negative Prompt: {req.negative_prompt}
|
Negative Prompt: {req.negative_prompt}
|
||||||
Stable Diffusion Model: {req.use_stable_diffusion_model + '.ckpt'}
|
Stable Diffusion model: {req.use_stable_diffusion_model + '.ckpt'}
|
||||||
"""
|
'''
|
||||||
try:
|
try:
|
||||||
with open(meta_out_path, 'w') as f:
|
with open(meta_out_path, 'w') as f:
|
||||||
f.write(metadata)
|
f.write(metadata)
|
||||||
@ -549,16 +623,19 @@ Stable Diffusion Model: {req.use_stable_diffusion_model + '.ckpt'}
|
|||||||
def _txt2img(opt_W, opt_H, opt_n_samples, opt_ddim_steps, opt_scale, start_code, opt_C, opt_f, opt_ddim_eta, c, uc, opt_seed, img_callback, mask, sampler_name):
|
def _txt2img(opt_W, opt_H, opt_n_samples, opt_ddim_steps, opt_scale, start_code, opt_C, opt_f, opt_ddim_eta, c, uc, opt_seed, img_callback, mask, sampler_name):
|
||||||
shape = [opt_n_samples, opt_C, opt_H // opt_f, opt_W // opt_f]
|
shape = [opt_n_samples, opt_C, opt_H // opt_f, opt_W // opt_f]
|
||||||
|
|
||||||
if device != "cpu":
|
if thread_data.device != "cpu":
|
||||||
mem = torch.cuda.memory_allocated() / 1e6
|
mem = torch.cuda.memory_allocated(thread_data.device) / 1e6
|
||||||
modelCS.to("cpu")
|
print('Device:', thread_data.device, 'CS_Model, Memory transfer starting. Memory Used:', round(mem, 2), 'Mo')
|
||||||
while torch.cuda.memory_allocated() / 1e6 >= mem:
|
thread_data.modelCS.to("cpu")
|
||||||
|
while torch.cuda.memory_allocated(thread_data.device) / 1e6 >= mem and mem > 0:
|
||||||
|
print('Device:', thread_data.device, 'Waiting Memory transfer. Memory Used:', round(mem, 2), 'Mo')
|
||||||
time.sleep(1)
|
time.sleep(1)
|
||||||
|
print('Transfered', round(mem - torch.cuda.memory_allocated(thread_data.device) / 1e6, 2), 'Mo')
|
||||||
|
|
||||||
if sampler_name == 'ddim':
|
if sampler_name == 'ddim':
|
||||||
model.make_schedule(ddim_num_steps=opt_ddim_steps, ddim_eta=opt_ddim_eta, verbose=False)
|
thread_data.model.make_schedule(ddim_num_steps=opt_ddim_steps, ddim_eta=opt_ddim_eta, verbose=False)
|
||||||
|
|
||||||
samples_ddim = model.sample(
|
samples_ddim = thread_data.model.sample(
|
||||||
S=opt_ddim_steps,
|
S=opt_ddim_steps,
|
||||||
conditioning=c,
|
conditioning=c,
|
||||||
seed=opt_seed,
|
seed=opt_seed,
|
||||||
@ -572,14 +649,13 @@ def _txt2img(opt_W, opt_H, opt_n_samples, opt_ddim_steps, opt_scale, start_code,
|
|||||||
mask=mask,
|
mask=mask,
|
||||||
sampler = sampler_name,
|
sampler = sampler_name,
|
||||||
)
|
)
|
||||||
|
|
||||||
yield from samples_ddim
|
yield from samples_ddim
|
||||||
|
|
||||||
def _img2img(init_latent, t_enc, batch_size, opt_scale, c, uc, opt_ddim_steps, opt_ddim_eta, opt_seed, img_callback, mask):
|
def _img2img(init_latent, t_enc, batch_size, opt_scale, c, uc, opt_ddim_steps, opt_ddim_eta, opt_seed, img_callback, mask):
|
||||||
# encode (scaled latent)
|
# encode (scaled latent)
|
||||||
z_enc = model.stochastic_encode(
|
z_enc = thread_data.model.stochastic_encode(
|
||||||
init_latent,
|
init_latent,
|
||||||
torch.tensor([t_enc] * batch_size).to(device),
|
torch.tensor([t_enc] * batch_size).to(thread_data.device),
|
||||||
opt_seed,
|
opt_seed,
|
||||||
opt_ddim_eta,
|
opt_ddim_eta,
|
||||||
opt_ddim_steps,
|
opt_ddim_steps,
|
||||||
@ -587,7 +663,7 @@ def _img2img(init_latent, t_enc, batch_size, opt_scale, c, uc, opt_ddim_steps, o
|
|||||||
x_T = None if mask is None else init_latent
|
x_T = None if mask is None else init_latent
|
||||||
|
|
||||||
# decode it
|
# decode it
|
||||||
samples_ddim = model.sample(
|
samples_ddim = thread_data.model.sample(
|
||||||
t_enc,
|
t_enc,
|
||||||
c,
|
c,
|
||||||
z_enc,
|
z_enc,
|
||||||
@ -602,16 +678,19 @@ def _img2img(init_latent, t_enc, batch_size, opt_scale, c, uc, opt_ddim_steps, o
|
|||||||
yield from samples_ddim
|
yield from samples_ddim
|
||||||
|
|
||||||
def move_fs_to_cpu():
|
def move_fs_to_cpu():
|
||||||
if device != "cpu":
|
if thread_data.device != "cpu":
|
||||||
mem = torch.cuda.memory_allocated() / 1e6
|
mem = torch.cuda.memory_allocated(thread_data.device) / 1e6
|
||||||
modelFS.to("cpu")
|
print('Device:', thread_data.device, 'FS_Model, Memory transfer starting. Memory Used:', round(mem, 2), 'Mo')
|
||||||
while torch.cuda.memory_allocated() / 1e6 >= mem:
|
thread_data.modelFS.to("cpu")
|
||||||
|
while torch.cuda.memory_allocated(thread_data.device) / 1e6 >= mem and mem > 0:
|
||||||
|
print('Device:', thread_data.device, 'Waiting for Memory transfer. Memory Used:', round(mem, 2), 'Mo')
|
||||||
time.sleep(1)
|
time.sleep(1)
|
||||||
|
print('Transfered', round(mem - torch.cuda.memory_allocated(thread_data.device) / 1e6, 2), 'Mo')
|
||||||
|
|
||||||
def gc():
|
def gc():
|
||||||
if device == 'cpu':
|
#gc.collect()
|
||||||
|
if thread_data.device == 'cpu':
|
||||||
return
|
return
|
||||||
|
|
||||||
torch.cuda.empty_cache()
|
torch.cuda.empty_cache()
|
||||||
torch.cuda.ipc_collect()
|
torch.cuda.ipc_collect()
|
||||||
|
|
||||||
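Note (added commentary, not part of this diff): move_fs_to_cpu() and _txt2img() above poll torch.cuda.memory_allocated(device) until the .to("cpu") transfer has actually released VRAM on the thread's device. A minimal sketch of that wait-for-release pattern, with an explicit timeout added (the loop in the diff itself has none) and a no-op fallback when CUDA is unavailable:

import time
import torch

def move_to_cpu_and_wait(module: torch.nn.Module, device, timeout=10.0):
    # Mirror of the pattern in move_fs_to_cpu(): snapshot VRAM use, move the
    # module off the GPU (nn.Module.to() modifies it in place), then poll until usage drops.
    if not torch.cuda.is_available():
        return
    before = torch.cuda.memory_allocated(device)
    module.to('cpu')
    deadline = time.time() + timeout
    while before > 0 and torch.cuda.memory_allocated(device) >= before:
        if time.time() > deadline:
            break  # unlike the loop in the diff, don't wait forever
        time.sleep(0.1)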
@@ -621,7 +700,6 @@ def chunk(it, size):
it = iter(it)
|
it = iter(it)
|
||||||
return iter(lambda: tuple(islice(it, size)), ())
|
return iter(lambda: tuple(islice(it, size)), ())
|
||||||
|
|
||||||
|
|
||||||
def load_model_from_config(ckpt, verbose=False):
|
def load_model_from_config(ckpt, verbose=False):
|
||||||
print(f"Loading model from {ckpt}")
|
print(f"Loading model from {ckpt}")
|
||||||
pl_sd = torch.load(ckpt, map_location="cpu")
|
pl_sd = torch.load(ckpt, map_location="cpu")
|
||||||
@@ -9,6 +9,10 @@ from typing import Any, Generator, Hashable, Optional, Union
 from pydantic import BaseModel
 from sd_internal import Request, Response
 
+ERR_LOCK_FAILED = ' failed to acquire lock within timeout.'
+LOCK_TIMEOUT = 15 # Maximum locking time in seconds before failing a task.
+# It's better to get an exception than a deadlock... ALWAYS use timeout in critical paths.
+
 class SymbolClass(type): # Print nicely formatted Symbol names.
     def __repr__(self): return self.__qualname__
     def __str__(self): return self.__name__
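Note (added commentary, not part of this diff): these two constants back the locking convention used throughout the file -- always acquire with a timeout and raise instead of silently deadlocking. A minimal sketch of that convention; guarded_section is an illustrative name:

import threading

LOCK_TIMEOUT = 15
ERR_LOCK_FAILED = ' failed to acquire lock within timeout.'

_lock = threading.Lock()

def guarded_section():
    if not _lock.acquire(blocking=True, timeout=LOCK_TIMEOUT):
        raise Exception('guarded_section' + ERR_LOCK_FAILED)
    try:
        pass  # critical section goes here
    finally:
        _lock.release()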
@@ -66,17 +70,30 @@ class ImageRequest(BaseModel):
stream_progress_updates: bool = False
|
stream_progress_updates: bool = False
|
||||||
stream_image_progress: bool = False
|
stream_image_progress: bool = False
|
||||||
|
|
||||||
|
class FilterRequest(BaseModel):
|
||||||
|
session_id: str = "session"
|
||||||
|
model: str = None
|
||||||
|
name: str = ""
|
||||||
|
init_image: str = None # base64
|
||||||
|
width: int = 512
|
||||||
|
height: int = 512
|
||||||
|
save_to_disk_path: str = None
|
||||||
|
turbo: bool = True
|
||||||
|
use_cpu: bool = False
|
||||||
|
use_full_precision: bool = False
|
||||||
|
output_format: str = "jpeg" # or "png"
|
||||||
|
|
||||||
# Temporary cache to allow to query tasks results for a short time after they are completed.
|
# Temporary cache to allow to query tasks results for a short time after they are completed.
|
||||||
class TaskCache():
|
class TaskCache():
|
||||||
def __init__(self):
|
def __init__(self):
|
||||||
self._base = dict()
|
self._base = dict()
|
||||||
self._lock: threading.Lock = threading.RLock()
|
self._lock: threading.Lock = threading.Lock()
|
||||||
def _get_ttl_time(self, ttl: int) -> int:
|
def _get_ttl_time(self, ttl: int) -> int:
|
||||||
return int(time.time()) + ttl
|
return int(time.time()) + ttl
|
||||||
def _is_expired(self, timestamp: int) -> bool:
|
def _is_expired(self, timestamp: int) -> bool:
|
||||||
return int(time.time()) >= timestamp
|
return int(time.time()) >= timestamp
|
||||||
def clean(self) -> None:
|
def clean(self) -> None:
|
||||||
if not self._lock.acquire(blocking=True, timeout=10): raise Exception('TaskCache.clean failed to acquire lock within timeout.')
|
if not self._lock.acquire(blocking=True, timeout=LOCK_TIMEOUT): raise Exception('TaskCache.clean' + ERR_LOCK_FAILED)
|
||||||
try:
|
try:
|
||||||
# Create a list of expired keys to delete
|
# Create a list of expired keys to delete
|
||||||
to_delete = []
|
to_delete = []
|
||||||
@ -91,11 +108,11 @@ class TaskCache():
|
|||||||
finally:
|
finally:
|
||||||
self._lock.release()
|
self._lock.release()
|
||||||
def clear(self) -> None:
|
def clear(self) -> None:
|
||||||
if not self._lock.acquire(blocking=True, timeout=10): raise Exception('TaskCache.clear failed to acquire lock within timeout.')
|
if not self._lock.acquire(blocking=True, timeout=LOCK_TIMEOUT): raise Exception('TaskCache.clear' + ERR_LOCK_FAILED)
|
||||||
try: self._base.clear()
|
try: self._base.clear()
|
||||||
finally: self._lock.release()
|
finally: self._lock.release()
|
||||||
def delete(self, key: Hashable) -> bool:
|
def delete(self, key: Hashable) -> bool:
|
||||||
if not self._lock.acquire(blocking=True, timeout=10): raise Exception('TaskCache.delete failed to acquire lock within timeout.')
|
if not self._lock.acquire(blocking=True, timeout=LOCK_TIMEOUT): raise Exception('TaskCache.delete' + ERR_LOCK_FAILED)
|
||||||
try:
|
try:
|
||||||
if key not in self._base:
|
if key not in self._base:
|
||||||
return False
|
return False
|
||||||
@ -104,7 +121,7 @@ class TaskCache():
|
|||||||
finally:
|
finally:
|
||||||
self._lock.release()
|
self._lock.release()
|
||||||
def keep(self, key: Hashable, ttl: int) -> bool:
|
def keep(self, key: Hashable, ttl: int) -> bool:
|
||||||
if not self._lock.acquire(blocking=True, timeout=10): raise Exception('TaskCache.keep failed to acquire lock within timeout.')
|
if not self._lock.acquire(blocking=True, timeout=LOCK_TIMEOUT): raise Exception('TaskCache.keep' + ERR_LOCK_FAILED)
|
||||||
try:
|
try:
|
||||||
if key in self._base:
|
if key in self._base:
|
||||||
_, value = self._base.get(key)
|
_, value = self._base.get(key)
|
||||||
@ -114,7 +131,7 @@ class TaskCache():
|
|||||||
finally:
|
finally:
|
||||||
self._lock.release()
|
self._lock.release()
|
||||||
def put(self, key: Hashable, value: Any, ttl: int) -> bool:
|
def put(self, key: Hashable, value: Any, ttl: int) -> bool:
|
||||||
if not self._lock.acquire(blocking=True, timeout=10): raise Exception('TaskCache.put failed to acquire lock within timeout.')
|
if not self._lock.acquire(blocking=True, timeout=LOCK_TIMEOUT): raise Exception('TaskCache.put' + ERR_LOCK_FAILED)
|
||||||
try:
|
try:
|
||||||
self._base[key] = (
|
self._base[key] = (
|
||||||
self._get_ttl_time(ttl), value
|
self._get_ttl_time(ttl), value
|
||||||
@ -128,21 +145,23 @@ class TaskCache():
|
|||||||
finally:
|
finally:
|
||||||
self._lock.release()
|
self._lock.release()
|
||||||
def tryGet(self, key: Hashable) -> Any:
|
def tryGet(self, key: Hashable) -> Any:
|
||||||
if not self._lock.acquire(blocking=True, timeout=10): raise Exception('TaskCache.tryGet failed to acquire lock within timeout.')
|
if not self._lock.acquire(blocking=True, timeout=LOCK_TIMEOUT): raise Exception('TaskCache.tryGet' + ERR_LOCK_FAILED)
|
||||||
try:
|
try:
|
||||||
ttl, value = self._base.get(key, (None, None))
|
ttl, value = self._base.get(key, (None, None))
|
||||||
if ttl is not None and self._is_expired(ttl):
|
if ttl is not None and self._is_expired(ttl):
|
||||||
print(f'Session {key} expired. Discarding data.')
|
print(f'Session {key} expired. Discarding data.')
|
||||||
self.delete(key)
|
del self._base[key]
|
||||||
return None
|
return None
|
||||||
return value
|
return value
|
||||||
finally:
|
finally:
|
||||||
self._lock.release()
|
self._lock.release()
|
||||||
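Note (added commentary, not part of this diff): TaskCache is essentially a TTL map -- put() stores an (expiry, value) pair, keep() pushes the expiry forward while a task is still being streamed, and tryGet() drops entries whose expiry has passed. The same idea without the locking, for illustration only:

import time

cache = {}  # key -> (expiry_timestamp, value)

def put(key, value, ttl):
    cache[key] = (time.time() + ttl, value)

def keep(key, ttl):
    if key in cache:
        _, value = cache[key]
        cache[key] = (time.time() + ttl, value)

def try_get(key):
    expiry, value = cache.get(key, (None, None))
    if expiry is not None and time.time() >= expiry:
        del cache[key]  # expired: discard, as TaskCache.tryGet now does directly
        return None
    return value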
|
|
||||||
|
manager_lock = threading.Lock()
|
||||||
|
render_threads = []
|
||||||
current_state = ServerStates.Init
|
current_state = ServerStates.Init
|
||||||
current_state_error:Exception = None
|
current_state_error:Exception = None
|
||||||
current_model_path = None
|
current_model_path = None
|
||||||
tasks_queue = queue.Queue()
|
tasks_queue = []
|
||||||
task_cache = TaskCache()
|
task_cache = TaskCache()
|
||||||
default_model_to_load = None
|
default_model_to_load = None
|
||||||
|
|
||||||
@ -155,7 +174,8 @@ def preload_model(file_path=None):
|
|||||||
current_state = ServerStates.LoadingModel
|
current_state = ServerStates.LoadingModel
|
||||||
try:
|
try:
|
||||||
from . import runtime
|
from . import runtime
|
||||||
runtime.load_model_ckpt(ckpt_to_use=file_path)
|
runtime.thread_data.ckpt_file = file_path
|
||||||
|
runtime.load_model_ckpt()
|
||||||
current_model_path = file_path
|
current_model_path = file_path
|
||||||
current_state_error = None
|
current_state_error = None
|
||||||
current_state = ServerStates.Online
|
current_state = ServerStates.Online
|
||||||
@ -165,43 +185,62 @@ def preload_model(file_path=None):
|
|||||||
current_state = ServerStates.Unavailable
|
current_state = ServerStates.Unavailable
|
||||||
print(traceback.format_exc())
|
print(traceback.format_exc())
|
||||||
|
|
||||||
def thread_render():
|
def thread_render(device):
|
||||||
global current_state, current_state_error, current_model_path
|
global current_state, current_state_error, current_model_path
|
||||||
from . import runtime
|
from . import runtime
|
||||||
current_state = ServerStates.Online
|
try:
|
||||||
|
runtime.device_init(device)
|
||||||
|
except:
|
||||||
|
print(traceback.format_exc())
|
||||||
|
return
|
||||||
preload_model()
|
preload_model()
|
||||||
|
current_state = ServerStates.Online
|
||||||
while True:
|
while True:
|
||||||
task_cache.clean()
|
task_cache.clean()
|
||||||
if isinstance(current_state_error, SystemExit):
|
if isinstance(current_state_error, SystemExit):
|
||||||
current_state = ServerStates.Unavailable
|
current_state = ServerStates.Unavailable
|
||||||
return
|
return
|
||||||
task = None
|
task = None
|
||||||
try:
|
if not manager_lock.acquire(blocking=True, timeout=LOCK_TIMEOUT):
|
||||||
task = tasks_queue.get(timeout=1)
|
print('Render thread on device', runtime.thread_data.device, 'failed to acquire manager lock.')
|
||||||
except queue.Empty as e:
|
time.sleep(1)
|
||||||
if isinstance(current_state_error, SystemExit):
|
continue
|
||||||
current_state = ServerStates.Unavailable
|
if len(tasks_queue) <= 0:
|
||||||
return
|
manager_lock.release()
|
||||||
else: continue
|
time.sleep(1)
|
||||||
|
continue
|
||||||
|
try: # Select a render task.
|
||||||
|
for queued_task in tasks_queue:
|
||||||
|
if queued_task.request.use_cpu and runtime.thread_data.device != 'cpu':
|
||||||
|
continue # Cuda Tasks
|
||||||
|
if not queued_task.request.use_cpu and runtime.thread_data.device == 'cpu':
|
||||||
|
continue # CPU Tasks
|
||||||
|
if queued_task.request.use_face_correction and not runtime.is_first_cuda_device(runtime.thread_data.device):
|
||||||
|
continue #TODO Remove when fixed - A bug with GFPGANer and facexlib needs to be fixed before use on other devices.
|
||||||
|
task = queued_task
|
||||||
|
break
|
||||||
|
if task is not None:
|
||||||
|
del tasks_queue[tasks_queue.index(task)]
|
||||||
|
finally:
|
||||||
|
manager_lock.release()
|
||||||
|
if task is None:
|
||||||
|
time.sleep(1)
|
||||||
|
continue
|
||||||
#if current_model_path != task.request.use_stable_diffusion_model:
|
#if current_model_path != task.request.use_stable_diffusion_model:
|
||||||
# preload_model(task.request.use_stable_diffusion_model)
|
# preload_model(task.request.use_stable_diffusion_model)
|
||||||
if current_state_error:
|
if current_state_error:
|
||||||
task.error = current_state_error
|
task.error = current_state_error
|
||||||
continue
|
continue
|
||||||
print(f'Session {task.request.session_id} starting task {id(task)}')
|
print(f'Session {task.request.session_id} starting task {id(task)}')
|
||||||
|
if not task.lock.acquire(blocking=False): raise Exception('Got locked task from queue.')
|
||||||
try:
|
try:
|
||||||
task.lock.acquire(blocking=False)
|
# Open data generator.
|
||||||
res = runtime.mk_img(task.request)
|
res = runtime.mk_img(task.request)
|
||||||
if current_model_path == task.request.use_stable_diffusion_model:
|
if current_model_path == task.request.use_stable_diffusion_model:
|
||||||
current_state = ServerStates.Rendering
|
current_state = ServerStates.Rendering
|
||||||
else:
|
else:
|
||||||
current_state = ServerStates.LoadingModel
|
current_state = ServerStates.LoadingModel
|
||||||
except Exception as e:
|
# Start reading from generator.
|
||||||
task.error = e
|
|
||||||
task.lock.release()
|
|
||||||
tasks_queue.task_done()
|
|
||||||
print(traceback.format_exc())
|
|
||||||
continue
|
|
||||||
dataQueue = None
|
dataQueue = None
|
||||||
if task.request.stream_progress_updates:
|
if task.request.stream_progress_updates:
|
||||||
dataQueue = task.buffer_queue
|
dataQueue = task.buffer_queue
|
||||||
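Note (added commentary, not part of this diff): the task-selection loop added above is why tasks_queue became a plain list guarded by manager_lock -- each render thread scans for a task its device can actually run (CPU tasks stay on the CPU thread, and face-correction tasks are kept on the first CUDA device as a temporary GFPGAN workaround). A stripped-down sketch of that filtering; QueuedTask and select_task are illustrative names:

import threading

manager_lock = threading.Lock()
tasks_queue = []  # plain list so a thread can pick any compatible task, not just the head

class QueuedTask:
    def __init__(self, use_cpu=False, use_face_correction=False):
        self.use_cpu = use_cpu
        self.use_face_correction = use_face_correction

def select_task(device, on_first_cuda_device):
    # Scan under the lock and remove the first task this device may run.
    with manager_lock:
        for task in tasks_queue:
            if task.use_cpu != (device == 'cpu'):
                continue  # keep CPU tasks on the CPU thread and GPU tasks on GPU threads
            if task.use_face_correction and not on_first_cuda_device:
                continue  # GFPGAN workaround: only the first CUDA device takes these
            tasks_queue.remove(task)
            return task
    return None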
@@ -224,13 +263,18 @@ def thread_render():
for out_obj in result['output']:
|
for out_obj in result['output']:
|
||||||
if 'path' in out_obj:
|
if 'path' in out_obj:
|
||||||
img_id = out_obj['path'][out_obj['path'].rindex('/') + 1:]
|
img_id = out_obj['path'][out_obj['path'].rindex('/') + 1:]
|
||||||
task.temp_images[int(img_id)] = runtime.temp_images[out_obj['path'][11:]]
|
task.temp_images[int(img_id)] = runtime.thread_data.temp_images[out_obj['path'][11:]]
|
||||||
elif 'data' in out_obj:
|
elif 'data' in out_obj:
|
||||||
task.temp_images[result['output'].index(out_obj)] = out_obj['data']
|
task.temp_images[result['output'].index(out_obj)] = out_obj['data']
|
||||||
|
# Before looping back to the generator, mark cache as still alive.
|
||||||
task_cache.keep(task.request.session_id, TASK_TTL)
|
task_cache.keep(task.request.session_id, TASK_TTL)
|
||||||
|
except Exception as e:
|
||||||
|
task.error = e
|
||||||
|
print(traceback.format_exc())
|
||||||
|
continue
|
||||||
|
finally:
|
||||||
# Task completed
|
# Task completed
|
||||||
task.lock.release()
|
task.lock.release()
|
||||||
tasks_queue.task_done()
|
|
||||||
task_cache.keep(task.request.session_id, TASK_TTL)
|
task_cache.keep(task.request.session_id, TASK_TTL)
|
||||||
if isinstance(task.error, StopAsyncIteration):
|
if isinstance(task.error, StopAsyncIteration):
|
||||||
print(f'Session {task.request.session_id} task {id(task)} cancelled!')
|
print(f'Session {task.request.session_id} task {id(task)} cancelled!')
|
||||||
@ -240,19 +284,37 @@ def thread_render():
|
|||||||
print(f'Session {task.request.session_id} task {id(task)} completed.')
|
print(f'Session {task.request.session_id} task {id(task)} completed.')
|
||||||
current_state = ServerStates.Online
|
current_state = ServerStates.Online
|
||||||
|
|
||||||
render_thread = threading.Thread(target=thread_render)
|
def is_alive(name=None):
|
||||||
|
if not manager_lock.acquire(blocking=True, timeout=LOCK_TIMEOUT): raise Exception('is_alive' + ERR_LOCK_FAILED)
|
||||||
|
nbr_alive = 0
|
||||||
|
try:
|
||||||
|
for rthread in render_threads:
|
||||||
|
if name and not rthread.name.endswith(name):
|
||||||
|
continue
|
||||||
|
if rthread.is_alive():
|
||||||
|
nbr_alive += 1
|
||||||
|
return nbr_alive
|
||||||
|
finally:
|
||||||
|
manager_lock.release()
|
||||||
|
|
||||||
def start_render_thread():
|
def start_render_thread(device='auto'):
|
||||||
# Start Rendering Thread
|
if not manager_lock.acquire(blocking=True, timeout=LOCK_TIMEOUT): raise Exception('start_render_threads' + ERR_LOCK_FAILED)
|
||||||
render_thread.daemon = True
|
print('Start new Rendering Thread on device', device)
|
||||||
render_thread.start()
|
try:
|
||||||
|
rthread = threading.Thread(target=thread_render, kwargs={'device': device})
|
||||||
|
rthread.daemon = True
|
||||||
|
rthread.name = 'Runner/' + device
|
||||||
|
rthread.start()
|
||||||
|
render_threads.append(rthread)
|
||||||
|
finally:
|
||||||
|
manager_lock.release()
|
||||||
|
|
||||||
def shutdown_event(): # Signal render thread to close on shutdown
|
def shutdown_event(): # Signal render thread to close on shutdown
|
||||||
global current_state_error
|
global current_state_error
|
||||||
current_state_error = SystemExit('Application shutting down.')
|
current_state_error = SystemExit('Application shutting down.')
|
||||||
|
|
||||||
def render(req : ImageRequest):
|
def render(req : ImageRequest):
|
||||||
if not render_thread.is_alive(): # Render thread is dead
|
if not is_alive(): # Render thread is dead
|
||||||
raise ChildProcessError('Rendering thread has died.')
|
raise ChildProcessError('Rendering thread has died.')
|
||||||
# Alive, check if task in cache
|
# Alive, check if task in cache
|
||||||
task = task_cache.tryGet(req.session_id)
|
task = task_cache.tryGet(req.session_id)
|
||||||
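Note (added commentary, not part of this diff): with thread_render() now taking a device and is_alive() counting live workers, startup can launch one render thread per configured device. A hypothetical launcher built on the task_manager API added in this diff (device strings follow the config values shown in ui/server.py below):

# Hypothetical startup wiring; start_render_thread() and is_alive() are the functions added above.
from sd_internal import task_manager

render_devices = ['GPU:0', 'GPU:1']  # e.g. taken from config['render_devices']
for device in render_devices:
    task_manager.start_render_thread(device)

if task_manager.is_alive() <= 0:
    raise EnvironmentError('No render threads could be started.')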
@@ -293,6 +355,12 @@ def render(req : ImageRequest):
|
|
||||||
new_task = RenderTask(r)
|
new_task = RenderTask(r)
|
||||||
if task_cache.put(r.session_id, new_task, TASK_TTL):
|
if task_cache.put(r.session_id, new_task, TASK_TTL):
|
||||||
tasks_queue.put(new_task, block=True, timeout=30)
|
# Use twice the normal timeout for adding user requests.
|
||||||
|
# Tries to force task_cache.put to fail before tasks_queue.put would.
|
||||||
|
if manager_lock.acquire(blocking=True, timeout=LOCK_TIMEOUT * 2):
|
||||||
|
try:
|
||||||
|
tasks_queue.append(new_task)
|
||||||
return new_task
|
return new_task
|
||||||
|
finally:
|
||||||
|
manager_lock.release()
|
||||||
raise RuntimeError('Failed to add task to cache.')
|
raise RuntimeError('Failed to add task to cache.')
|
||||||
ui/server.py (320 changed lines)

@@ -15,14 +15,24 @@ MODELS_DIR = os.path.abspath(os.path.join(SD_DIR, '..', 'models'))
 
 OUTPUT_DIRNAME = "Stable Diffusion UI" # in the user's home folder
 TASK_TTL = 15 * 60 # Discard last session's task timeout
+APP_CONFIG_DEFAULTS = {
+    # auto: selects the cuda device with the most free memory, cuda: use the currently active cuda device.
+    'render_devices': ['auto'], # ['cuda'] or ['CPU', 'GPU:0', 'GPU:1', ...] or ['cpu']
+    'update_branch': 'main',
+}
+APP_CONFIG_DEFAULT_MODELS = [
+    # needed to support the legacy installations
+    'custom-model', # Check if user has a custom model, use it first.
+    'sd-v1-4', # Default fallback.
+]
 
 from fastapi import FastAPI, HTTPException
 from fastapi.staticfiles import StaticFiles
 from starlette.responses import FileResponse, JSONResponse, StreamingResponse
 from pydantic import BaseModel
 import logging
-import queue, threading, time
-from typing import Any, Generator, Hashable, Optional, Union
+#import queue, threading, time
+from typing import Any, Generator, Hashable, List, Optional, Union
 
 from sd_internal import Request, Response, task_manager
 
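Note (added commentary, not part of this diff): with APP_CONFIG_DEFAULTS in place, a user config.json can pin specific devices instead of the default 'auto'. An illustrative example of such a config, written out with the standard json module (the values are made up):

import json

# Illustrative user config: render on two fixed GPUs instead of auto-selecting one.
example_config = {
    'render_devices': ['GPU:0', 'GPU:1'],  # default is ['auto']
    'update_branch': 'main',
}
print(json.dumps(example_config, indent=4))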
@@ -37,52 +47,173 @@ ACCESS_LOG_SUPPRESS_PATH_PREFIXES = ['/ping', '/image', '/modifier-thumbnails']
NOCACHE_HEADERS={"Cache-Control": "no-cache, no-store, must-revalidate", "Pragma": "no-cache", "Expires": "0"}
|
NOCACHE_HEADERS={"Cache-Control": "no-cache, no-store, must-revalidate", "Pragma": "no-cache", "Expires": "0"}
|
||||||
app.mount('/media', StaticFiles(directory=os.path.join(SD_UI_DIR, 'media/')), name="media")
|
app.mount('/media', StaticFiles(directory=os.path.join(SD_UI_DIR, 'media/')), name="media")
|
||||||
|
|
||||||
class SetAppConfigRequest(BaseModel):
|
config_cached = None
|
||||||
update_branch: str = "main"
|
config_last_mod_time = 0
|
||||||
|
def getConfig(default_val=APP_CONFIG_DEFAULTS):
|
||||||
|
global config_cached, config_last_mod_time
|
||||||
|
try:
|
||||||
|
config_json_path = os.path.join(CONFIG_DIR, 'config.json')
|
||||||
|
if not os.path.exists(config_json_path):
|
||||||
|
return default_val
|
||||||
|
if config_last_mod_time > 0 and config_cached is not None:
|
||||||
|
# Don't read if file was not modified
|
||||||
|
mtime = os.path.getmtime(config_json_path)
|
||||||
|
if mtime <= config_last_mod_time:
|
||||||
|
return config_cached
|
||||||
|
with open(config_json_path, 'r') as f:
|
||||||
|
config_cached = json.load(f)
|
||||||
|
config_last_mod_time = os.path.getmtime(config_json_path)
|
||||||
|
return config_cached
|
||||||
|
except Exception as e:
|
||||||
|
print(str(e))
|
||||||
|
print(traceback.format_exc())
|
||||||
|
return default_val
|
||||||
|
|
||||||
-# needs to support the legacy installations
-def get_initial_model_to_load():
-    custom_weight_path = os.path.join(SD_DIR, 'custom-model.ckpt')
-    ckpt_to_use = "sd-v1-4" if not os.path.exists(custom_weight_path) else "custom-model"
-
-    ckpt_to_use = os.path.join(SD_DIR, ckpt_to_use)
-
-    config = getConfig()
-    if 'model' in config and 'stable-diffusion' in config['model']:
-        model_name = config['model']['stable-diffusion']
-        model_path = resolve_model_to_use(model_name)
-        if os.path.exists(model_path + '.ckpt'):
-            ckpt_to_use = model_path
-        else:
-            print('Could not find the configured custom model at:', model_path + '.ckpt', '. Using the default one:', ckpt_to_use + '.ckpt')
-
-    return ckpt_to_use
-
+def setConfig(config):
+    try: # config.json
+        config_json_path = os.path.join(CONFIG_DIR, 'config.json')
+        with open(config_json_path, 'w') as f:
+            json.dump(config, f)
+    except:
+        print(traceback.format_exc())
+
+    if 'render_devices' in config:
+        gpu_devices = list(filter(lambda dev: dev.startswith('GPU:'), config['render_devices']))
+    else:
+        gpu_devices = []
+
+    try: # config.bat
+        config_bat = [
+            f"@set update_branch={config['update_branch']}"
+        ]
+        if len(gpu_devices) > 0:
+            config_bat.append(f"@set CUDA_VISIBLE_DEVICES={','.join(gpu_devices)}")
+        config_bat_path = os.path.join(CONFIG_DIR, 'config.bat')
+        with open(config_bat_path, 'w') as f:
+            f.write('\r\n'.join(config_bat))
+    except Exception as e:
+        print(traceback.format_exc())
+
+    try: # config.sh
+        config_sh = [
+            '#!/bin/bash',
+            f"export update_branch={config['update_branch']}"
+        ]
+        if len(gpu_devices) > 0:
+            config_sh.append(f"CUDA_VISIBLE_DEVICES={','.join(gpu_devices)}")
+        config_sh_path = os.path.join(CONFIG_DIR, 'config.sh')
+        with open(config_sh_path, 'w') as f:
+            f.write('\n'.join(config_sh))
+    except Exception as e:
+        print(traceback.format_exc())
+
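A quick sketch of the batch/shell fragments setConfig produces for a two-GPU config (expected output shown in the comments; the device labels are written verbatim from the config):

config = {'update_branch': 'main', 'render_devices': ['GPU:0', 'GPU:1']}
gpu_devices = list(filter(lambda dev: dev.startswith('GPU:'), config['render_devices']))

config_bat = [f"@set update_branch={config['update_branch']}"]
config_bat.append(f"@set CUDA_VISIBLE_DEVICES={','.join(gpu_devices)}")
print('\r\n'.join(config_bat))
# @set update_branch=main
# @set CUDA_VISIBLE_DEVICES=GPU:0,GPU:1

config_sh = ['#!/bin/bash', f"export update_branch={config['update_branch']}"]
config_sh.append(f"CUDA_VISIBLE_DEVICES={','.join(gpu_devices)}")
print('\n'.join(config_sh))
# #!/bin/bash
# export update_branch=main
# CUDA_VISIBLE_DEVICES=GPU:0,GPU:1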
-def resolve_model_to_use(model_name):
-    if model_name in ('sd-v1-4', 'custom-model'):
-        model_path = os.path.join(MODELS_DIR, 'stable-diffusion', model_name)
-        legacy_model_path = os.path.join(SD_DIR, model_name)
-        if not os.path.exists(model_path + '.ckpt') and os.path.exists(legacy_model_path + '.ckpt'):
-            model_path = legacy_model_path
-    else:
-        model_path = os.path.join(MODELS_DIR, 'stable-diffusion', model_name)
-    return model_path
-
+def resolve_model_to_use(model_name:str=None):
+    if not model_name: # When None, try the user-configured model.
+        config = getConfig()
+        if 'model' in config and 'stable-diffusion' in config['model']:
+            model_name = config['model']['stable-diffusion']
+    if model_name:
+        if os.path.exists(model_name + '.ckpt'):
+            # Direct path to a file
+            return model_name
+        # Check the models directory
+        models_dir_path = os.path.join(MODELS_DIR, 'stable-diffusion', model_name)
+        if os.path.exists(models_dir_path + '.ckpt'):
+            return models_dir_path
+    # Default locations
+    if model_name in APP_CONFIG_DEFAULT_MODELS:
+        default_model_path = os.path.join(SD_DIR, model_name)
+        if os.path.exists(default_model_path + '.ckpt'):
+            return default_model_path
+    # Can't find the requested model, check the default paths.
+    for default_model in APP_CONFIG_DEFAULT_MODELS:
+        default_model_path = os.path.join(SD_DIR, default_model + '.ckpt')
+        if os.path.exists(default_model_path):
+            print('Could not find the configured custom model:', model_name, '. Using the default one:', default_model_path)
+            return default_model_path
+    raise Exception('No valid models found.')
+
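The resolution order, as a usage sketch (paths are illustrative; the result depends on which .ckpt files actually exist on disk):

resolve_model_to_use('/data/models/mine')  # returns '/data/models/mine' if '/data/models/mine.ckpt' exists
resolve_model_to_use('my-model')           # returns '<MODELS_DIR>/stable-diffusion/my-model' if that checkpoint exists
resolve_model_to_use('sd-v1-4')            # bundled defaults are also looked up next to SD_DIR
resolve_model_to_use()                     # None: uses config['model']['stable-diffusion'], else the first default found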
+class SetAppConfigRequest(BaseModel):
+    update_branch: str = None
+    render_devices: Union[List[str], List[int], str, int] = None
+
+@app.post('/app_config')
+async def setAppConfig(req : SetAppConfigRequest):
+    config = getConfig()
+    if req.update_branch:
+        config['update_branch'] = req.update_branch
+    if req.render_devices and hasattr(req.render_devices, "__len__"): # strings, arrays of strings, or numbers.
+        render_devices = []
+        if isinstance(req.render_devices, str):
+            req.render_devices = req.render_devices.split(',')
+        if isinstance(req.render_devices, list):
+            for gpu in req.render_devices:
+                if isinstance(gpu, int):
+                    render_devices.append('GPU:' + str(gpu))
+                else:
+                    render_devices.append(gpu)
+        if isinstance(req.render_devices, int):
+            render_devices.append('GPU:' + str(req.render_devices))
+        if len(render_devices) > 0:
+            config['render_devices'] = render_devices
+    try:
+        setConfig(config)
+        return JSONResponse({'status': 'OK'}, headers=NOCACHE_HEADERS)
+    except Exception as e:
+        print(traceback.format_exc())
+        return HTTPException(status_code=500, detail=str(e))
+
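A minimal client-side sketch for the reworked endpoint (assumes the UI is listening on localhost:9000, as in the webbrowser call at the bottom of this file, and that the requests package is installed):

import requests

# Switch rendering to GPUs 0 and 1 and stay on the main update branch.
resp = requests.post('http://localhost:9000/app_config', json={
    'update_branch': 'main',
    'render_devices': 'GPU:0,GPU:1',  # a comma-separated string or a list are both accepted
})
print(resp.json())  # {'status': 'OK'} on success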
+def getModels():
+    models = {
+        'active': {
+            'stable-diffusion': 'sd-v1-4',
+        },
+        'options': {
+            'stable-diffusion': ['sd-v1-4'],
+        },
+    }
+
+    # custom models
+    sd_models_dir = os.path.join(MODELS_DIR, 'stable-diffusion')
+    for file in os.listdir(sd_models_dir):
+        if file.endswith('.ckpt'):
+            model_name = os.path.splitext(file)[0]
+            models['options']['stable-diffusion'].append(model_name)
+
+    # legacy
+    custom_weight_path = os.path.join(SD_DIR, 'custom-model.ckpt')
+    if os.path.exists(custom_weight_path):
+        models['active']['stable-diffusion'] = 'custom-model'
+        models['options']['stable-diffusion'].append('custom-model')
+
+    config = getConfig()
+    if 'model' in config and 'stable-diffusion' in config['model']:
+        models['active']['stable-diffusion'] = config['model']['stable-diffusion']
+
+    return models
+
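The shape getModels returns, sketched for a models folder that contains one extra checkpoint (illustrative; the exact lists depend on the .ckpt files present and on config.json):

# Example return value when models/stable-diffusion/ contains 'my-model.ckpt'
# and config.json selects it as the active model:
{
    'active':  { 'stable-diffusion': 'my-model' },
    'options': { 'stable-diffusion': ['sd-v1-4', 'my-model'] },
}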
+@app.get('/get/{key:path}')
+def read_web_data(key:str=None):
+    if not key: # /get without parameters, stable-diffusion easter egg.
+        return HTTPException(status_code=418, detail="StableDiffusion is drawing a teapot!") # HTTP418 I'm a teapot
+    elif key == 'app_config':
+        config = getConfig(default_val=None)
+        if config is None:
+            return HTTPException(status_code=500, detail="Config file is missing or unreadable")
+        return JSONResponse(config, headers=NOCACHE_HEADERS)
+    elif key == 'models':
+        return JSONResponse(getModels(), headers=NOCACHE_HEADERS)
+    elif key == 'modifiers': return FileResponse(os.path.join(SD_UI_DIR, 'modifiers.json'), headers=NOCACHE_HEADERS)
+    elif key == 'output_dir': return JSONResponse({ 'output_dir': outpath }, headers=NOCACHE_HEADERS)
+    else:
+        return HTTPException(status_code=404, detail=f'Request for unknown {key}') # HTTP404 Not Found
+
@app.on_event("shutdown")
|
|
||||||
def shutdown_event(): # Signal render thread to close on shutdown
|
|
||||||
task_manager.current_state_error = SystemExit('Application shutting down.')
|
|
||||||
|
|
||||||
@app.get('/')
|
|
||||||
def read_root():
|
|
||||||
return FileResponse(os.path.join(SD_UI_DIR, 'index.html'), headers=NOCACHE_HEADERS)
|
|
||||||
|
|
||||||
 @app.get('/ping') # Get server and optionally session status.
 def ping(session_id:str=None):
-    if not task_manager.render_thread.is_alive(): # Render thread is dead.
-        if task_manager.current_state_error: return HTTPException(status_code=500, detail=str(current_state_error))
+    if task_manager.is_alive() <= 0: # Check that render threads are alive.
+        if task_manager.current_state_error: return HTTPException(status_code=500, detail=str(task_manager.current_state_error))
         return HTTPException(status_code=500, detail='Render thread is dead.')
-    if task_manager.current_state_error and not isinstance(task_manager.current_state_error, StopAsyncIteration): return HTTPException(status_code=500, detail=str(current_state_error))
+    if task_manager.current_state_error and not isinstance(task_manager.current_state_error, StopAsyncIteration): return HTTPException(status_code=500, detail=str(task_manager.current_state_error))
     # Alive
     response = {'status': str(task_manager.current_state)}
     if session_id:
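ping no longer checks a single render_thread; it calls task_manager.is_alive(), which the startup code below also calls with a specific device ('cpu'). A sketch of the contract this implies, namely a count of live render threads optionally filtered by device (an assumption here; the actual implementation lives in task_manager):

render_threads = []  # hypothetical registry; the real one is kept inside task_manager

def is_alive(device=None):
    # Count running render threads, optionally only those bound to the given device.
    count = 0
    for thread in render_threads:
        if device is not None and getattr(thread, 'device', None) != device:
            continue
        if thread.is_alive():
            count += 1
    return count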
@@ -119,7 +250,7 @@ def render(req : task_manager.ImageRequest):
     new_task = task_manager.render(req)
     response = {
         'status': str(task_manager.current_state),
-        'queue': task_manager.tasks_queue.qsize(),
+        'queue': len(task_manager.tasks_queue),
         'stream': f'/image/stream/{req.session_id}/{id(new_task)}',
         'task': id(new_task)
     }
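With tasks_queue now a plain sequence rather than a queue.Queue, a successful POST to /render reports its length directly, roughly like this (a sketch with made-up values):

# Example response body from /render (illustrative values only):
{
    'status': 'Online',
    'queue': 2,                                    # len(task_manager.tasks_queue)
    'stream': '/image/stream/session-1/140230948',
    'task': 140230948,
}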
@@ -172,100 +303,13 @@ def get_image(session_id, img_id):
     except KeyError as e:
         return HTTPException(status_code=500, detail=str(e))

-@app.post('/app_config')
-async def setAppConfig(req : SetAppConfigRequest):
-    try:
-        config = {
-            'update_branch': req.update_branch
-        }
-
-        config_json_str = json.dumps(config)
-        config_bat_str = f'@set update_branch={req.update_branch}'
-        config_sh_str = f'export update_branch={req.update_branch}'
-
-        config_json_path = os.path.join(CONFIG_DIR, 'config.json')
-        config_bat_path = os.path.join(CONFIG_DIR, 'config.bat')
-        config_sh_path = os.path.join(CONFIG_DIR, 'config.sh')
-
-        with open(config_json_path, 'w') as f:
-            f.write(config_json_str)
-
-        with open(config_bat_path, 'w') as f:
-            f.write(config_bat_str)
-
-        with open(config_sh_path, 'w') as f:
-            f.write(config_sh_str)
-
-        return {'OK'}
-    except Exception as e:
-        print(traceback.format_exc())
-        return HTTPException(status_code=500, detail=str(e))
-
-def getConfig(default_val={}):
-    try:
-        config_json_path = os.path.join(CONFIG_DIR, 'config.json')
-        if not os.path.exists(config_json_path):
-            return default_val
-        with open(config_json_path, 'r') as f:
-            return json.load(f)
-    except Exception as e:
-        print(str(e))
-        print(traceback.format_exc())
-        return default_val
-
-def setConfig(config):
-    try:
-        config_json_path = os.path.join(CONFIG_DIR, 'config.json')
-        with open(config_json_path, 'w') as f:
-            return json.dump(config, f)
-    except:
-        print(str(e))
-        print(traceback.format_exc())
-
-def getModels():
-    models = {
-        'active': {
-            'stable-diffusion': 'sd-v1-4',
-        },
-        'options': {
-            'stable-diffusion': ['sd-v1-4'],
-        },
-    }
-
-    # custom models
-    sd_models_dir = os.path.join(MODELS_DIR, 'stable-diffusion')
-    for file in os.listdir(sd_models_dir):
-        if file.endswith('.ckpt'):
-            model_name = os.path.splitext(file)[0]
-            models['options']['stable-diffusion'].append(model_name)
-
-    # legacy
-    custom_weight_path = os.path.join(SD_DIR, 'custom-model.ckpt')
-    if os.path.exists(custom_weight_path):
-        models['active']['stable-diffusion'] = 'custom-model'
-        models['options']['stable-diffusion'].append('custom-model')
-
-    config = getConfig()
-    if 'model' in config and 'stable-diffusion' in config['model']:
-        models['active']['stable-diffusion'] = config['model']['stable-diffusion']
-
-    return models
-
-@app.get('/get/{key:path}')
-def read_web_data(key:str=None):
-    if not key: # /get without parameters, stable-diffusion easter egg.
-        return HTTPException(status_code=418, detail="StableDiffusion is drawing a teapot!") # HTTP418 I'm a teapot
-    elif key == 'app_config':
-        config = getConfig(default_val=None)
-        if config is None:
-            return HTTPException(status_code=500, detail="Config file is missing or unreadable")
-        return JSONResponse(config, headers=NOCACHE_HEADERS)
-    elif key == 'models':
-        return JSONResponse(getModels(), headers=NOCACHE_HEADERS)
-    elif key == 'modifiers': return FileResponse(os.path.join(SD_UI_DIR, 'modifiers.json'), headers=NOCACHE_HEADERS)
-    elif key == 'output_dir': return JSONResponse({ 'output_dir': outpath }, headers=NOCACHE_HEADERS)
-    else:
-        return HTTPException(status_code=404, detail=f'Request for unknown {key}') # HTTP404 Not Found
+@app.get('/')
+def read_root():
+    return FileResponse(os.path.join(SD_UI_DIR, 'index.html'), headers=NOCACHE_HEADERS)
+
+@app.on_event("shutdown")
+def shutdown_event(): # Signal render thread to close on shutdown
+    task_manager.current_state_error = SystemExit('Application shutting down.')

 # don't log certain requests
 class LogSuppressFilter(logging.Filter):
@@ -277,8 +321,26 @@ class LogSuppressFilter(logging.Filter):
         return True
 logging.getLogger('uvicorn.access').addFilter(LogSuppressFilter())

-task_manager.default_model_to_load = get_initial_model_to_load()
-task_manager.start_render_thread()
+config = getConfig()
+
+# Start the task_manager
+task_manager.default_model_to_load = resolve_model_to_use()
+if 'render_devices' in config: # Start a new thread for each device.
+    if isinstance(config['render_devices'], str):
+        config['render_devices'] = config['render_devices'].split(',')
+    if not isinstance(config['render_devices'], list):
+        raise Exception('Invalid render_devices value in config.')
+    for device in config['render_devices']:
+        task_manager.start_render_thread(device)
+
+allow_cpu = False
+if task_manager.is_alive() <= 0: # No running devices, apply defaults.
+    # Select the best GPU device by free memory, if there is more than one device.
+    task_manager.start_render_thread('auto')
+    allow_cpu = True
+
+# Allow the CPU to be used for renders, if not already enabled in the current config.
+if task_manager.is_alive('cpu') <= 0 and allow_cpu:
+    task_manager.start_render_thread('cpu')
+
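Taken together, the startup fan-out above reduces to roughly this decision flow (a condensed sketch, not code from the commit; start_render_thread and is_alive are the task_manager entry points used above):

def start_render_threads(config):
    # Start one render thread per configured device, e.g. ['GPU:0', 'GPU:1'], ['auto'] or ['cpu'].
    devices = config.get('render_devices', [])
    if isinstance(devices, str):
        devices = devices.split(',')
    for device in devices:
        task_manager.start_render_thread(device)
    # Nothing started: fall back to the best free-memory GPU, and only then allow a CPU thread.
    if task_manager.is_alive() <= 0:
        task_manager.start_render_thread('auto')
        if task_manager.is_alive('cpu') <= 0:
            task_manager.start_render_thread('cpu')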
 # start the browser ui
 import webbrowser; webbrowser.open('http://localhost:9000')