mirror of
https://github.com/easydiffusion/easydiffusion.git
synced 2025-06-18 15:57:11 +02:00
First draft for Multi-GPU support
This commit is contained in:
parent
2edc06c662
commit
7c72608e1c
@ -35,63 +35,144 @@ import base64
|
||||
from io import BytesIO
|
||||
#from colorama import Fore
|
||||
|
||||
# local
|
||||
stop_processing = False
|
||||
temp_images = {}
|
||||
from threading import local as LocalThreadVars
|
||||
thread_data = LocalThreadVars()
|
||||
|
||||
ckpt_file = None
|
||||
gfpgan_file = None
|
||||
real_esrgan_file = None
|
||||
|
||||
model = None
|
||||
modelCS = None
|
||||
modelFS = None
|
||||
model_gfpgan = None
|
||||
model_real_esrgan = None
|
||||
|
||||
model_is_half = False
|
||||
model_fs_is_half = False
|
||||
device = None
|
||||
unet_bs = 1
|
||||
precision = 'autocast'
|
||||
sampler_plms = None
|
||||
sampler_ddim = None
|
||||
|
||||
has_valid_gpu = False
|
||||
force_full_precision = False
|
||||
def device_would_fail(device):
|
||||
if device == 'cpu': return None
|
||||
# Returns None when no issues found, otherwise returns the detected error str.
|
||||
# Memory check
|
||||
try:
|
||||
gpu = torch.cuda.current_device()
|
||||
gpu_name = torch.cuda.get_device_name(gpu)
|
||||
print('GPU detected: ', gpu_name)
|
||||
|
||||
force_full_precision = ('nvidia' in gpu_name.lower() or 'geforce' in gpu_name.lower()) and (' 1660' in gpu_name or ' 1650' in gpu_name) # otherwise these NVIDIA cards create green images
|
||||
if force_full_precision:
|
||||
print('forcing full precision on NVIDIA 16xx cards, to avoid green images. GPU detected: ', gpu_name)
|
||||
|
||||
mem_free, mem_total = torch.cuda.mem_get_info(gpu)
|
||||
mem_free, mem_total = torch.cuda.mem_get_info(device)
|
||||
mem_total /= float(10**9)
|
||||
if mem_total < 3.0:
|
||||
print("GPUs with less than 3 GB of VRAM are not compatible with Stable Diffusion")
|
||||
raise Exception()
|
||||
return 'GPUs with less than 3 GB of VRAM are not compatible with Stable Diffusion'
|
||||
except RuntimeError as e:
|
||||
return str(e) # Return cuda errors from mem_get_info as strings
|
||||
return None
|
||||
|
||||
has_valid_gpu = True
|
||||
except:
|
||||
def device_select(device):
|
||||
if device == 'cpu': return True
|
||||
if not torch.cuda.is_available(): return False
|
||||
failure_msg = device_would_fail(device)
|
||||
if failure_msg:
|
||||
if 'invalid device' in failure_msg:
|
||||
raise NameError(f'GPU "{device}" could not be found. Remove this device from config.render_devices or use one of "auto" or "cuda".')
|
||||
print(failure_msg)
|
||||
return False
|
||||
|
||||
device_name = torch.cuda.get_device_name(device)
|
||||
|
||||
# otherwise these NVIDIA cards create green images
|
||||
thread_data.force_full_precision = ('nvidia' in device_name.lower() or 'geforce' in device_name.lower()) and (' 1660' in device_name or ' 1650' in device_name)
|
||||
if thread_data.force_full_precision:
|
||||
print('forcing full precision on NVIDIA 16xx cards, to avoid green images. GPU detected: ', gpu_name)
|
||||
|
||||
thread_data.device = device
|
||||
thread_data.has_valid_gpu = True
|
||||
return True
|
||||
|
||||
def device_init(device_selection=None):
|
||||
# Thread bound properties
|
||||
thread_data.stop_processing = False
|
||||
thread_data.temp_images = {}
|
||||
|
||||
thread_data.ckpt_file = None
|
||||
thread_data.gfpgan_file = None
|
||||
thread_data.real_esrgan_file = None
|
||||
|
||||
thread_data.model = None
|
||||
thread_data.modelCS = None
|
||||
thread_data.modelFS = None
|
||||
thread_data.model_gfpgan = None
|
||||
thread_data.model_real_esrgan = None
|
||||
|
||||
thread_data.model_is_half = False
|
||||
thread_data.model_fs_is_half = False
|
||||
thread_data.device = None
|
||||
thread_data.unet_bs = 1
|
||||
thread_data.precision = 'autocast'
|
||||
thread_data.sampler_plms = None
|
||||
thread_data.sampler_ddim = None
|
||||
|
||||
thread_data.turbo = False
|
||||
thread_data.has_valid_gpu = False
|
||||
thread_data.force_full_precision = False
|
||||
|
||||
if device_selection.lower() == 'cpu':
|
||||
print('CPU requested, skipping gpu init.')
|
||||
thread_data.device = 'cpu'
|
||||
return
|
||||
if not torch.cuda.is_available():
|
||||
print('WARNING: torch.cuda is not available. Using the CPU, but this will be very slow!')
|
||||
return
|
||||
device_count = torch.cuda.device_count()
|
||||
if device_count <= 1 and device_selection == 'auto':
|
||||
device_selection = 'current' # Use 'auto' only when there is more than one compatible device found.
|
||||
if device_selection == 'auto':
|
||||
print('Autoselecting GPU. Using most free memory.')
|
||||
max_mem_free = 0
|
||||
best_device = None
|
||||
for device in range(device_count):
|
||||
mem_free, mem_total = torch.cuda.mem_get_info(device)
|
||||
mem_free /= float(10**9)
|
||||
mem_total /= float(10**9)
|
||||
device_name = torch.cuda.get_device_name(device)
|
||||
print(f'GPU:{device} detected: {device_name} - Memory: {round(mem_total - mem_free, 2)}Go / {round(mem_total, 2)}Go')
|
||||
if max_mem_free < mem_free:
|
||||
max_mem_free = mem_free
|
||||
best_device = device
|
||||
if best_device and device_select(device):
|
||||
print(f'Setting GPU:{device} as active')
|
||||
torch.cuda.device(device)
|
||||
return
|
||||
if isinstance(device_selection, str):
|
||||
device_selection = device_selection.lower()
|
||||
if device_selection.startswith('gpu:'):
|
||||
device_selection = int(device_selection[4:])
|
||||
if device_selection != 'cuda' and device_selection != 'current':
|
||||
if device_select(device_selection):
|
||||
if isinstance(device_selection, int):
|
||||
print(f'Setting GPU:{device_selection} as active')
|
||||
else:
|
||||
print(f'Setting {device_selection} as active')
|
||||
torch.cuda.device(device_selection)
|
||||
return
|
||||
# By default use current device.
|
||||
print('Checking current GPU...')
|
||||
device = torch.cuda.current_device()
|
||||
device_name = torch.cuda.get_device_name(device)
|
||||
print(f'GPU:{device} detected: {device_name}')
|
||||
if device_select(device):
|
||||
return
|
||||
print('WARNING: No compatible GPU found. Using the CPU, but this will be very slow!')
|
||||
pass
|
||||
thread_data.device = 'cpu'
|
||||
|
||||
def load_model_ckpt(ckpt_to_use, device_to_use='cuda', turbo=False, unet_bs_to_use=1, precision_to_use='autocast'):
|
||||
global ckpt_file, model, modelCS, modelFS, model_is_half, device, unet_bs, precision, model_fs_is_half
|
||||
def is_first_cuda_device(device):
|
||||
if thread_data.device == 0 or thread_data.device == '0':
|
||||
return True
|
||||
if thread_data.device == 'cuda' or thread_data.device == 'cuda:0':
|
||||
return True
|
||||
if thread_data.device == torch.device(0):
|
||||
return True
|
||||
return False
|
||||
|
||||
device = device_to_use if has_valid_gpu else 'cpu'
|
||||
precision = precision_to_use if not force_full_precision else 'full'
|
||||
unet_bs = unet_bs_to_use
|
||||
def load_model_ckpt():
|
||||
if not thread_data.ckpt_file: raise ValueError(f'Thread ckpt_file is undefined.')
|
||||
if not os.path.exists(thread_data.ckpt_file + '.ckpt'): raise FileNotFoundError(f'Cannot find {thread_data.ckpt_file}.ckpt')
|
||||
|
||||
if not thread_data.precision:
|
||||
thread_data.precision = 'full' if thread_data.force_full_precision else 'autocast'
|
||||
if not thread_data.unet_bs:
|
||||
thread_data.unet_bs = 1
|
||||
|
||||
unload_model()
|
||||
|
||||
if device == 'cpu':
|
||||
precision = 'full'
|
||||
if thread_data.device == 'cpu':
|
||||
thread_data.precision = 'full'
|
||||
|
||||
sd = load_model_from_config(f"{ckpt_to_use}.ckpt")
|
||||
print('loading', thread_data.ckpt_file, 'to', thread_data.device, 'using precision', thread_data.precision)
|
||||
sd = load_model_from_config(thread_data.ckpt_file + '.ckpt')
|
||||
li, lo = [], []
|
||||
for key, value in sd.items():
|
||||
sp = key.split(".")
|
||||
@ -114,88 +195,84 @@ def load_model_ckpt(ckpt_to_use, device_to_use='cuda', turbo=False, unet_bs_to_u
|
||||
model = instantiate_from_config(config.modelUNet)
|
||||
_, _ = model.load_state_dict(sd, strict=False)
|
||||
model.eval()
|
||||
model.cdevice = device
|
||||
model.unet_bs = unet_bs
|
||||
model.turbo = turbo
|
||||
model.cdevice = torch.device(thread_data.device)
|
||||
model.unet_bs = thread_data.unet_bs
|
||||
model.turbo = thread_data.turbo
|
||||
if thread_data.device != 'cpu':
|
||||
model.to(thread_data.device)
|
||||
thread_data.model = model
|
||||
|
||||
modelCS = instantiate_from_config(config.modelCondStage)
|
||||
_, _ = modelCS.load_state_dict(sd, strict=False)
|
||||
modelCS.eval()
|
||||
modelCS.cond_stage_model.device = device
|
||||
modelCS.cond_stage_model.device = torch.device(thread_data.device)
|
||||
if thread_data.device != 'cpu':
|
||||
modelCS.to(thread_data.device)
|
||||
thread_data.modelCS = modelCS
|
||||
|
||||
modelFS = instantiate_from_config(config.modelFirstStage)
|
||||
_, _ = modelFS.load_state_dict(sd, strict=False)
|
||||
modelFS.eval()
|
||||
if thread_data.device != 'cpu':
|
||||
modelFS.to(thread_data.device)
|
||||
thread_data.modelFS = modelFS
|
||||
del sd
|
||||
|
||||
if device != "cpu" and precision == "autocast":
|
||||
model.half()
|
||||
modelCS.half()
|
||||
modelFS.half()
|
||||
model_is_half = True
|
||||
model_fs_is_half = True
|
||||
if thread_data.device != "cpu" and thread_data.precision == "autocast":
|
||||
thread_data.model.half()
|
||||
thread_data.modelCS.half()
|
||||
thread_data.modelFS.half()
|
||||
thread_data.model_is_half = True
|
||||
thread_data.model_fs_is_half = True
|
||||
else:
|
||||
model_is_half = False
|
||||
model_fs_is_half = False
|
||||
thread_data.model_is_half = False
|
||||
thread_data.model_fs_is_half = False
|
||||
|
||||
ckpt_file = ckpt_to_use
|
||||
|
||||
print('loaded ', ckpt_file, 'to', device, 'precision', precision)
|
||||
print('loaded', thread_data.ckpt_file, 'as', model.device, '->', modelCS.cond_stage_model.device, '->', thread_data.modelFS.device, 'using precision', thread_data.precision)
|
||||
|
||||
def unload_model():
|
||||
global model, modelCS, modelFS
|
||||
if thread_data.model is not None:
|
||||
print('Unloading models...')
|
||||
del thread_data.model
|
||||
del thread_data.modelCS
|
||||
del thread_data.modelFS
|
||||
thread_data.model = None
|
||||
thread_data.modelCS = None
|
||||
thread_data.modelFS = None
|
||||
|
||||
if model is not None:
|
||||
del model
|
||||
del modelCS
|
||||
del modelFS
|
||||
|
||||
model = None
|
||||
modelCS = None
|
||||
modelFS = None
|
||||
|
||||
def load_model_gfpgan(gfpgan_to_use):
|
||||
global gfpgan_file, model_gfpgan
|
||||
|
||||
if gfpgan_to_use is None:
|
||||
def load_model_gfpgan():
|
||||
if thread_data.gfpgan_file is None:
|
||||
print('load_model_gfpgan called without setting gfpgan_file')
|
||||
return
|
||||
if thread_data.device != 'cpu' and not is_first_cuda_device(thread_data.device):
|
||||
#TODO Remove when fixed - A bug with GFPGANer and facexlib needs to be fixed before use on other devices.
|
||||
raise Exception(f'Current device {torch.device(thread_data.device)} is not {torch.device(0)}.')
|
||||
model_path = thread_data.gfpgan_file + ".pth"
|
||||
thread_data.model_gfpgan = GFPGANer(device=torch.device(thread_data.device), model_path=model_path, upscale=1, arch='clean', channel_multiplier=2, bg_upsampler=None)
|
||||
print('loaded', thread_data.gfpgan_file, 'to', thread_data.model_gfpgan.device, 'precision', thread_data.precision)
|
||||
|
||||
gfpgan_file = gfpgan_to_use
|
||||
model_path = gfpgan_to_use + ".pth"
|
||||
|
||||
if device == 'cpu':
|
||||
model_gfpgan = GFPGANer(model_path=model_path, upscale=1, arch='clean', channel_multiplier=2, bg_upsampler=None, device=torch.device('cpu'))
|
||||
else:
|
||||
model_gfpgan = GFPGANer(model_path=model_path, upscale=1, arch='clean', channel_multiplier=2, bg_upsampler=None, device=torch.device('cuda'))
|
||||
|
||||
print('loaded ', gfpgan_to_use, 'to', device, 'precision', precision)
|
||||
|
||||
def load_model_real_esrgan(real_esrgan_to_use):
|
||||
global real_esrgan_file, model_real_esrgan
|
||||
|
||||
if real_esrgan_to_use is None:
|
||||
def load_model_real_esrgan():
|
||||
if thread_data.real_esrgan_file is None:
|
||||
print('load_model_real_esrgan called without setting real_esrgan_file')
|
||||
return
|
||||
|
||||
real_esrgan_file = real_esrgan_to_use
|
||||
model_path = real_esrgan_to_use + ".pth"
|
||||
model_path = thread_data.real_esrgan_file + ".pth"
|
||||
|
||||
RealESRGAN_models = {
|
||||
'RealESRGAN_x4plus': RRDBNet(num_in_ch=3, num_out_ch=3, num_feat=64, num_block=23, num_grow_ch=32, scale=4),
|
||||
'RealESRGAN_x4plus_anime_6B': RRDBNet(num_in_ch=3, num_out_ch=3, num_feat=64, num_block=6, num_grow_ch=32, scale=4)
|
||||
}
|
||||
|
||||
model_to_use = RealESRGAN_models[real_esrgan_to_use]
|
||||
model_to_use = RealESRGAN_models[thread_data.real_esrgan_file]
|
||||
|
||||
if device == 'cpu':
|
||||
model_real_esrgan = RealESRGANer(scale=2, model_path=model_path, model=model_to_use, pre_pad=0, half=False) # cpu does not support half
|
||||
model_real_esrgan.device = torch.device('cpu')
|
||||
model_real_esrgan.model.to('cpu')
|
||||
if thread_data.device == 'cpu':
|
||||
thread_data.model_real_esrgan = RealESRGANer(device=torch.device(thread_data.device), scale=2, model_path=model_path, model=model_to_use, pre_pad=0, half=False) # cpu does not support half
|
||||
#thread_data.model_real_esrgan.device = torch.device(thread_data.device)
|
||||
thread_data.model_real_esrgan.model.to('cpu')
|
||||
else:
|
||||
model_real_esrgan = RealESRGANer(scale=2, model_path=model_path, model=model_to_use, pre_pad=0, half=model_is_half)
|
||||
thread_data.model_real_esrgan = RealESRGANer(device=torch.device(thread_data.device), scale=2, model_path=model_path, model=model_to_use, pre_pad=0, half=thread_data.model_is_half)
|
||||
|
||||
model_real_esrgan.model.name = real_esrgan_to_use
|
||||
|
||||
print('loaded ', real_esrgan_to_use, 'to', device, 'precision', precision)
|
||||
thread_data.model_real_esrgan.model.name = thread_data.real_esrgan_file
|
||||
print('loaded ', thread_data.real_esrgan_file, 'to', thread_data.model_real_esrgan.device, 'precision', thread_data.precision)
|
||||
|
||||
def get_base_path(disk_path, session_id, prompt, ext, suffix=None):
|
||||
if disk_path is None: return None
|
||||
@ -214,14 +291,22 @@ def get_base_path(disk_path, session_id, prompt, ext, suffix=None):
|
||||
|
||||
def apply_filters(filter_name, image_data):
|
||||
print(f'Applying filter {filter_name}...')
|
||||
if isinstance(image_data, torch.Tensor):
|
||||
print(image_data)
|
||||
image_data.to(thread_data.device)
|
||||
|
||||
gc()
|
||||
|
||||
if filter_name == 'gfpgan':
|
||||
_, _, output = model_gfpgan.enhance(image_data[:,:,::-1], has_aligned=False, only_center_face=False, paste_back=True)
|
||||
if thread_data.model_gfpgan is None: raise Exception('Model "gfpgan" not loaded.')
|
||||
print('enhance with', thread_data.gfpgan_file, 'on', thread_data.model_gfpgan.device, 'precision', thread_data.precision)
|
||||
_, _, output = thread_data.model_gfpgan.enhance(image_data[:,:,::-1], has_aligned=False, only_center_face=False, paste_back=True)
|
||||
image_data = output[:,:,::-1]
|
||||
|
||||
if filter_name == 'real_esrgan':
|
||||
output, _ = model_real_esrgan.enhance(image_data[:,:,::-1])
|
||||
if thread_data.model_real_esrgan is None: raise Exception('Model "gfpgan" not loaded.')
|
||||
print('enhance with', thread_data.real_esrgan_file, 'on', thread_data.model_real_esrgan.device, 'precision', thread_data.precision)
|
||||
output, _ = thread_data.model_real_esrgan.enhance(image_data[:,:,::-1])
|
||||
image_data = output[:,:,::-1]
|
||||
|
||||
return image_data
|
||||
@ -234,12 +319,12 @@ def mk_img(req: Request):
|
||||
|
||||
gc()
|
||||
|
||||
if device != "cpu":
|
||||
modelFS.to("cpu")
|
||||
modelCS.to("cpu")
|
||||
if thread_data.device != "cpu":
|
||||
thread_data.modelFS.to("cpu")
|
||||
thread_data.modelCS.to("cpu")
|
||||
|
||||
model.model1.to("cpu")
|
||||
model.model2.to("cpu")
|
||||
thread_data.model.model1.to("cpu")
|
||||
thread_data.model.model2.to("cpu")
|
||||
|
||||
gc()
|
||||
|
||||
@ -249,66 +334,55 @@ def mk_img(req: Request):
|
||||
})
|
||||
|
||||
def do_mk_img(req: Request):
|
||||
global ckpt_file
|
||||
global model, modelCS, modelFS, device
|
||||
global model_gfpgan, model_real_esrgan
|
||||
global stop_processing
|
||||
|
||||
stop_processing = False
|
||||
thread_data.stop_processing = False
|
||||
|
||||
res = Response()
|
||||
res.request = req
|
||||
res.images = []
|
||||
|
||||
temp_images.clear()
|
||||
thread_data.temp_images.clear()
|
||||
|
||||
# custom model support:
|
||||
# the req.use_stable_diffusion_model needs to be a valid path
|
||||
# to the ckpt file (without the extension).
|
||||
if not os.path.exists(req.use_stable_diffusion_model + '.ckpt'): raise FileNotFoundError(f'Cannot find {req.use_stable_diffusion_model}.ckpt')
|
||||
|
||||
needs_model_reload = False
|
||||
ckpt_to_use = ckpt_file
|
||||
if ckpt_to_use != req.use_stable_diffusion_model:
|
||||
ckpt_to_use = req.use_stable_diffusion_model
|
||||
if thread_data.ckpt_file != req.use_stable_diffusion_model:
|
||||
thread_data.ckpt_file = req.use_stable_diffusion_model
|
||||
needs_model_reload = True
|
||||
|
||||
model.turbo = req.turbo
|
||||
if req.use_cpu:
|
||||
if device != 'cpu':
|
||||
device = 'cpu'
|
||||
|
||||
if model_is_half:
|
||||
load_model_ckpt(ckpt_to_use, device)
|
||||
if thread_data.device != 'cpu':
|
||||
thread_data.device = 'cpu'
|
||||
if thread_data.model_is_half:
|
||||
load_model_ckpt()
|
||||
needs_model_reload = False
|
||||
|
||||
load_model_gfpgan(gfpgan_file)
|
||||
load_model_real_esrgan(real_esrgan_file)
|
||||
load_model_gfpgan()
|
||||
load_model_real_esrgan()
|
||||
else:
|
||||
if has_valid_gpu:
|
||||
prev_device = device
|
||||
device = 'cuda'
|
||||
|
||||
if (precision == 'autocast' and (req.use_full_precision or not model_is_half)) or \
|
||||
(precision == 'full' and not req.use_full_precision and not force_full_precision):
|
||||
|
||||
load_model_ckpt(ckpt_to_use, device, req.turbo, unet_bs, ('full' if req.use_full_precision else 'autocast'))
|
||||
if thread_data.has_valid_gpu:
|
||||
if (thread_data.precision == 'autocast' and (req.use_full_precision or not thread_data.model_is_half)) or \
|
||||
(thread_data.precision == 'full' and not req.use_full_precision and not thread_data.force_full_precision):
|
||||
thread_data.precision = 'full' if req.use_full_precision else 'autocast'
|
||||
load_model_ckpt()
|
||||
load_model_gfpgan()
|
||||
load_model_real_esrgan()
|
||||
needs_model_reload = False
|
||||
|
||||
if prev_device != device:
|
||||
load_model_gfpgan(gfpgan_file)
|
||||
load_model_real_esrgan(real_esrgan_file)
|
||||
|
||||
if needs_model_reload:
|
||||
load_model_ckpt(ckpt_to_use, device, req.turbo, unet_bs, precision)
|
||||
load_model_ckpt()
|
||||
|
||||
if req.use_face_correction != gfpgan_file:
|
||||
load_model_gfpgan(req.use_face_correction)
|
||||
if req.use_face_correction != thread_data.gfpgan_file:
|
||||
thread_data.gfpgan_file = req.use_face_correction
|
||||
load_model_gfpgan()
|
||||
if req.use_upscale != thread_data.real_esrgan_file:
|
||||
thread_data.real_esrgan_file = req.use_upscale
|
||||
load_model_real_esrgan()
|
||||
|
||||
if req.use_upscale != real_esrgan_file:
|
||||
load_model_real_esrgan(req.use_upscale)
|
||||
|
||||
model.cdevice = device
|
||||
modelCS.cond_stage_model.device = device
|
||||
if thread_data.turbo != req.turbo:
|
||||
thread_data.turbo = req.turbo
|
||||
thread_data.model.turbo = req.turbo
|
||||
|
||||
opt_prompt = req.prompt
|
||||
opt_seed = req.seed
|
||||
@ -318,9 +392,8 @@ def do_mk_img(req: Request):
|
||||
opt_ddim_eta = 0.0
|
||||
opt_init_img = req.init_image
|
||||
|
||||
print(req.to_string(), '\n device', device)
|
||||
|
||||
print('\n\n Using precision:', precision)
|
||||
print(req.to_string(), '\n device', thread_data.device)
|
||||
print('\n\n Using precision:', thread_data.precision)
|
||||
|
||||
seed_everything(opt_seed)
|
||||
|
||||
@ -329,7 +402,7 @@ def do_mk_img(req: Request):
|
||||
assert prompt is not None
|
||||
data = [batch_size * [prompt]]
|
||||
|
||||
if precision == "autocast" and device != "cpu":
|
||||
if thread_data.precision == "autocast" and thread_data.device != "cpu":
|
||||
precision_scope = autocast
|
||||
else:
|
||||
precision_scope = nullcontext
|
||||
@ -345,22 +418,22 @@ def do_mk_img(req: Request):
|
||||
handler = _img2img
|
||||
|
||||
init_image = load_img(req.init_image, req.width, req.height)
|
||||
init_image = init_image.to(device)
|
||||
init_image = init_image.to(thread_data.device)
|
||||
|
||||
if device != "cpu" and precision == "autocast":
|
||||
if thread_data.device != "cpu" and thread_data.precision == "autocast":
|
||||
init_image = init_image.half()
|
||||
|
||||
modelFS.to(device)
|
||||
thread_data.modelFS.to(thread_data.device)
|
||||
|
||||
init_image = repeat(init_image, '1 ... -> b ...', b=batch_size)
|
||||
init_latent = modelFS.get_first_stage_encoding(modelFS.encode_first_stage(init_image)) # move to latent space
|
||||
init_latent = thread_data.modelFS.get_first_stage_encoding(thread_data.modelFS.encode_first_stage(init_image)) # move to latent space
|
||||
|
||||
if req.mask is not None:
|
||||
mask = load_mask(req.mask, req.width, req.height, init_latent.shape[2], init_latent.shape[3], True).to(device)
|
||||
mask = load_mask(req.mask, req.width, req.height, init_latent.shape[2], init_latent.shape[3], True).to(thread_data.device)
|
||||
mask = mask[0][0].unsqueeze(0).repeat(4, 1, 1).unsqueeze(0)
|
||||
mask = repeat(mask, '1 ... -> b ...', b=batch_size)
|
||||
|
||||
if device != "cpu" and precision == "autocast":
|
||||
if thread_data.device != "cpu" and thread_data.precision == "autocast":
|
||||
mask = mask.half()
|
||||
|
||||
move_fs_to_cpu()
|
||||
@ -381,10 +454,10 @@ def do_mk_img(req: Request):
|
||||
for prompts in tqdm(data, desc="data"):
|
||||
|
||||
with precision_scope("cuda"):
|
||||
modelCS.to(device)
|
||||
thread_data.modelCS.to(thread_data.device)
|
||||
uc = None
|
||||
if req.guidance_scale != 1.0:
|
||||
uc = modelCS.get_learned_conditioning(batch_size * [req.negative_prompt])
|
||||
uc = thread_data.modelCS.get_learned_conditioning(batch_size * [req.negative_prompt])
|
||||
if isinstance(prompts, tuple):
|
||||
prompts = list(prompts)
|
||||
|
||||
@ -397,11 +470,11 @@ def do_mk_img(req: Request):
|
||||
weight = weights[i]
|
||||
# if not skip_normalize:
|
||||
weight = weight / totalWeight
|
||||
c = torch.add(c, modelCS.get_learned_conditioning(subprompts[i]), alpha=weight)
|
||||
c = torch.add(c, thread_data.modelCS.get_learned_conditioning(subprompts[i]), alpha=weight)
|
||||
else:
|
||||
c = modelCS.get_learned_conditioning(prompts)
|
||||
c = thread_data.modelCS.get_learned_conditioning(prompts)
|
||||
|
||||
modelFS.to(device)
|
||||
thread_data.modelFS.to(thread_data.device)
|
||||
|
||||
partial_x_samples = None
|
||||
def img_callback(x_samples, i):
|
||||
@ -417,7 +490,7 @@ def do_mk_img(req: Request):
|
||||
partial_images = []
|
||||
|
||||
for i in range(batch_size):
|
||||
x_samples_ddim = modelFS.decode_first_stage(x_samples[i].unsqueeze(0))
|
||||
x_samples_ddim = thread_data.modelFS.decode_first_stage(x_samples[i].unsqueeze(0))
|
||||
x_sample = torch.clamp((x_samples_ddim + 1.0) / 2.0, min=0.0, max=1.0)
|
||||
x_sample = 255.0 * rearrange(x_sample[0].cpu().numpy(), "c h w -> h w c")
|
||||
x_sample = x_sample.astype(np.uint8)
|
||||
@ -429,18 +502,19 @@ def do_mk_img(req: Request):
|
||||
del img, x_sample, x_samples_ddim
|
||||
# don't delete x_samples, it is used in the code that called this callback
|
||||
|
||||
temp_images[str(req.session_id) + '/' + str(i)] = buf
|
||||
thread_data.temp_images[str(req.session_id) + '/' + str(i)] = buf
|
||||
partial_images.append({'path': f'/image/tmp/{req.session_id}/{i}'})
|
||||
|
||||
progress['output'] = partial_images
|
||||
|
||||
yield json.dumps(progress)
|
||||
|
||||
if stop_processing:
|
||||
if thread_data.stop_processing:
|
||||
raise UserInitiatedStop("User requested that we stop processing")
|
||||
|
||||
# run the handler
|
||||
try:
|
||||
print('Running handler...')
|
||||
if handler == _txt2img:
|
||||
x_samples = _txt2img(req.width, req.height, req.num_outputs, req.num_inference_steps, req.guidance_scale, None, opt_C, opt_f, opt_ddim_eta, c, uc, opt_seed, img_callback, mask, req.sampler)
|
||||
else:
|
||||
@ -458,7 +532,7 @@ def do_mk_img(req: Request):
|
||||
print("saving images")
|
||||
for i in range(batch_size):
|
||||
|
||||
x_samples_ddim = modelFS.decode_first_stage(x_samples[i].unsqueeze(0))
|
||||
x_samples_ddim = thread_data.modelFS.decode_first_stage(x_samples[i].unsqueeze(0))
|
||||
x_sample = torch.clamp((x_samples_ddim + 1.0) / 2.0, min=0.0, max=1.0)
|
||||
x_sample = 255.0 * rearrange(x_sample[0].cpu().numpy(), "c h w -> h w c")
|
||||
x_sample = x_sample.astype(np.uint8)
|
||||
@ -469,7 +543,7 @@ def do_mk_img(req: Request):
|
||||
|
||||
return_orig_img = not has_filters or not req.show_only_filtered_image
|
||||
|
||||
if stop_processing:
|
||||
if thread_data.stop_processing:
|
||||
return_orig_img = True
|
||||
|
||||
if req.save_to_disk_path is not None:
|
||||
@ -489,7 +563,7 @@ def do_mk_img(req: Request):
|
||||
|
||||
del img
|
||||
|
||||
if has_filters and not stop_processing:
|
||||
if has_filters and not thread_data.stop_processing:
|
||||
filters_applied = []
|
||||
if req.use_face_correction:
|
||||
x_sample = apply_filters('gfpgan', x_sample)
|
||||
@ -514,7 +588,7 @@ def do_mk_img(req: Request):
|
||||
move_fs_to_cpu()
|
||||
gc()
|
||||
del x_samples, x_samples_ddim, x_sample
|
||||
print("memory_final = ", torch.cuda.memory_allocated() / 1e6)
|
||||
print(f'memory_final = {round(torch.cuda.memory_allocated(thread_data.device) / 1e6, 2)}Mo')
|
||||
|
||||
print('Task completed')
|
||||
|
||||
@ -527,7 +601,7 @@ def save_image(img, img_out_path):
|
||||
print('could not save the file', traceback.format_exc())
|
||||
|
||||
def save_metadata(meta_out_path, req, prompt, opt_seed):
|
||||
metadata = f"""{prompt}
|
||||
metadata = f'''{prompt}
|
||||
Width: {req.width}
|
||||
Height: {req.height}
|
||||
Seed: {opt_seed}
|
||||
@ -538,8 +612,8 @@ Use Face Correction: {req.use_face_correction}
|
||||
Use Upscaling: {req.use_upscale}
|
||||
Sampler: {req.sampler}
|
||||
Negative Prompt: {req.negative_prompt}
|
||||
Stable Diffusion Model: {req.use_stable_diffusion_model + '.ckpt'}
|
||||
"""
|
||||
Stable Diffusion model: {req.use_stable_diffusion_model + '.ckpt'}
|
||||
'''
|
||||
try:
|
||||
with open(meta_out_path, 'w') as f:
|
||||
f.write(metadata)
|
||||
@ -549,16 +623,19 @@ Stable Diffusion Model: {req.use_stable_diffusion_model + '.ckpt'}
|
||||
def _txt2img(opt_W, opt_H, opt_n_samples, opt_ddim_steps, opt_scale, start_code, opt_C, opt_f, opt_ddim_eta, c, uc, opt_seed, img_callback, mask, sampler_name):
|
||||
shape = [opt_n_samples, opt_C, opt_H // opt_f, opt_W // opt_f]
|
||||
|
||||
if device != "cpu":
|
||||
mem = torch.cuda.memory_allocated() / 1e6
|
||||
modelCS.to("cpu")
|
||||
while torch.cuda.memory_allocated() / 1e6 >= mem:
|
||||
if thread_data.device != "cpu":
|
||||
mem = torch.cuda.memory_allocated(thread_data.device) / 1e6
|
||||
print('Device:', thread_data.device, 'CS_Model, Memory transfer starting. Memory Used:', round(mem, 2), 'Mo')
|
||||
thread_data.modelCS.to("cpu")
|
||||
while torch.cuda.memory_allocated(thread_data.device) / 1e6 >= mem and mem > 0:
|
||||
print('Device:', thread_data.device, 'Waiting Memory transfer. Memory Used:', round(mem, 2), 'Mo')
|
||||
time.sleep(1)
|
||||
print('Transfered', round(mem - torch.cuda.memory_allocated(thread_data.device) / 1e6, 2), 'Mo')
|
||||
|
||||
if sampler_name == 'ddim':
|
||||
model.make_schedule(ddim_num_steps=opt_ddim_steps, ddim_eta=opt_ddim_eta, verbose=False)
|
||||
thread_data.model.make_schedule(ddim_num_steps=opt_ddim_steps, ddim_eta=opt_ddim_eta, verbose=False)
|
||||
|
||||
samples_ddim = model.sample(
|
||||
samples_ddim = thread_data.model.sample(
|
||||
S=opt_ddim_steps,
|
||||
conditioning=c,
|
||||
seed=opt_seed,
|
||||
@ -572,14 +649,13 @@ def _txt2img(opt_W, opt_H, opt_n_samples, opt_ddim_steps, opt_scale, start_code,
|
||||
mask=mask,
|
||||
sampler = sampler_name,
|
||||
)
|
||||
|
||||
yield from samples_ddim
|
||||
|
||||
def _img2img(init_latent, t_enc, batch_size, opt_scale, c, uc, opt_ddim_steps, opt_ddim_eta, opt_seed, img_callback, mask):
|
||||
# encode (scaled latent)
|
||||
z_enc = model.stochastic_encode(
|
||||
z_enc = thread_data.model.stochastic_encode(
|
||||
init_latent,
|
||||
torch.tensor([t_enc] * batch_size).to(device),
|
||||
torch.tensor([t_enc] * batch_size).to(thread_data.device),
|
||||
opt_seed,
|
||||
opt_ddim_eta,
|
||||
opt_ddim_steps,
|
||||
@ -587,7 +663,7 @@ def _img2img(init_latent, t_enc, batch_size, opt_scale, c, uc, opt_ddim_steps, o
|
||||
x_T = None if mask is None else init_latent
|
||||
|
||||
# decode it
|
||||
samples_ddim = model.sample(
|
||||
samples_ddim = thread_data.model.sample(
|
||||
t_enc,
|
||||
c,
|
||||
z_enc,
|
||||
@ -602,16 +678,19 @@ def _img2img(init_latent, t_enc, batch_size, opt_scale, c, uc, opt_ddim_steps, o
|
||||
yield from samples_ddim
|
||||
|
||||
def move_fs_to_cpu():
|
||||
if device != "cpu":
|
||||
mem = torch.cuda.memory_allocated() / 1e6
|
||||
modelFS.to("cpu")
|
||||
while torch.cuda.memory_allocated() / 1e6 >= mem:
|
||||
if thread_data.device != "cpu":
|
||||
mem = torch.cuda.memory_allocated(thread_data.device) / 1e6
|
||||
print('Device:', thread_data.device, 'FS_Model, Memory transfer starting. Memory Used:', round(mem, 2), 'Mo')
|
||||
thread_data.modelFS.to("cpu")
|
||||
while torch.cuda.memory_allocated(thread_data.device) / 1e6 >= mem and mem > 0:
|
||||
print('Device:', thread_data.device, 'Waiting for Memory transfer. Memory Used:', round(mem, 2), 'Mo')
|
||||
time.sleep(1)
|
||||
print('Transfered', round(mem - torch.cuda.memory_allocated(thread_data.device) / 1e6, 2), 'Mo')
|
||||
|
||||
def gc():
|
||||
if device == 'cpu':
|
||||
#gc.collect()
|
||||
if thread_data.device == 'cpu':
|
||||
return
|
||||
|
||||
torch.cuda.empty_cache()
|
||||
torch.cuda.ipc_collect()
|
||||
|
||||
@ -621,7 +700,6 @@ def chunk(it, size):
|
||||
it = iter(it)
|
||||
return iter(lambda: tuple(islice(it, size)), ())
|
||||
|
||||
|
||||
def load_model_from_config(ckpt, verbose=False):
|
||||
print(f"Loading model from {ckpt}")
|
||||
pl_sd = torch.load(ckpt, map_location="cpu")
|
||||
|
@ -9,6 +9,10 @@ from typing import Any, Generator, Hashable, Optional, Union
|
||||
from pydantic import BaseModel
|
||||
from sd_internal import Request, Response
|
||||
|
||||
ERR_LOCK_FAILED = ' failed to acquire lock within timeout.'
|
||||
LOCK_TIMEOUT = 15 # Maximum locking time in seconds before failing a task.
|
||||
# It's better to get an exception than a deadlock... ALWAYS use timeout in critical paths.
|
||||
|
||||
class SymbolClass(type): # Print nicely formatted Symbol names.
|
||||
def __repr__(self): return self.__qualname__
|
||||
def __str__(self): return self.__name__
|
||||
@ -66,17 +70,30 @@ class ImageRequest(BaseModel):
|
||||
stream_progress_updates: bool = False
|
||||
stream_image_progress: bool = False
|
||||
|
||||
class FilterRequest(BaseModel):
|
||||
session_id: str = "session"
|
||||
model: str = None
|
||||
name: str = ""
|
||||
init_image: str = None # base64
|
||||
width: int = 512
|
||||
height: int = 512
|
||||
save_to_disk_path: str = None
|
||||
turbo: bool = True
|
||||
use_cpu: bool = False
|
||||
use_full_precision: bool = False
|
||||
output_format: str = "jpeg" # or "png"
|
||||
|
||||
# Temporary cache to allow to query tasks results for a short time after they are completed.
|
||||
class TaskCache():
|
||||
def __init__(self):
|
||||
self._base = dict()
|
||||
self._lock: threading.Lock = threading.RLock()
|
||||
self._lock: threading.Lock = threading.Lock()
|
||||
def _get_ttl_time(self, ttl: int) -> int:
|
||||
return int(time.time()) + ttl
|
||||
def _is_expired(self, timestamp: int) -> bool:
|
||||
return int(time.time()) >= timestamp
|
||||
def clean(self) -> None:
|
||||
if not self._lock.acquire(blocking=True, timeout=10): raise Exception('TaskCache.clean failed to acquire lock within timeout.')
|
||||
if not self._lock.acquire(blocking=True, timeout=LOCK_TIMEOUT): raise Exception('TaskCache.clean' + ERR_LOCK_FAILED)
|
||||
try:
|
||||
# Create a list of expired keys to delete
|
||||
to_delete = []
|
||||
@ -91,11 +108,11 @@ class TaskCache():
|
||||
finally:
|
||||
self._lock.release()
|
||||
def clear(self) -> None:
|
||||
if not self._lock.acquire(blocking=True, timeout=10): raise Exception('TaskCache.clear failed to acquire lock within timeout.')
|
||||
if not self._lock.acquire(blocking=True, timeout=LOCK_TIMEOUT): raise Exception('TaskCache.clear' + ERR_LOCK_FAILED)
|
||||
try: self._base.clear()
|
||||
finally: self._lock.release()
|
||||
def delete(self, key: Hashable) -> bool:
|
||||
if not self._lock.acquire(blocking=True, timeout=10): raise Exception('TaskCache.delete failed to acquire lock within timeout.')
|
||||
if not self._lock.acquire(blocking=True, timeout=LOCK_TIMEOUT): raise Exception('TaskCache.delete' + ERR_LOCK_FAILED)
|
||||
try:
|
||||
if key not in self._base:
|
||||
return False
|
||||
@ -104,7 +121,7 @@ class TaskCache():
|
||||
finally:
|
||||
self._lock.release()
|
||||
def keep(self, key: Hashable, ttl: int) -> bool:
|
||||
if not self._lock.acquire(blocking=True, timeout=10): raise Exception('TaskCache.keep failed to acquire lock within timeout.')
|
||||
if not self._lock.acquire(blocking=True, timeout=LOCK_TIMEOUT): raise Exception('TaskCache.keep' + ERR_LOCK_FAILED)
|
||||
try:
|
||||
if key in self._base:
|
||||
_, value = self._base.get(key)
|
||||
@ -114,7 +131,7 @@ class TaskCache():
|
||||
finally:
|
||||
self._lock.release()
|
||||
def put(self, key: Hashable, value: Any, ttl: int) -> bool:
|
||||
if not self._lock.acquire(blocking=True, timeout=10): raise Exception('TaskCache.put failed to acquire lock within timeout.')
|
||||
if not self._lock.acquire(blocking=True, timeout=LOCK_TIMEOUT): raise Exception('TaskCache.put' + ERR_LOCK_FAILED)
|
||||
try:
|
||||
self._base[key] = (
|
||||
self._get_ttl_time(ttl), value
|
||||
@ -128,21 +145,23 @@ class TaskCache():
|
||||
finally:
|
||||
self._lock.release()
|
||||
def tryGet(self, key: Hashable) -> Any:
|
||||
if not self._lock.acquire(blocking=True, timeout=10): raise Exception('TaskCache.tryGet failed to acquire lock within timeout.')
|
||||
if not self._lock.acquire(blocking=True, timeout=LOCK_TIMEOUT): raise Exception('TaskCache.tryGet' + ERR_LOCK_FAILED)
|
||||
try:
|
||||
ttl, value = self._base.get(key, (None, None))
|
||||
if ttl is not None and self._is_expired(ttl):
|
||||
print(f'Session {key} expired. Discarding data.')
|
||||
self.delete(key)
|
||||
del self._base[key]
|
||||
return None
|
||||
return value
|
||||
finally:
|
||||
self._lock.release()
|
||||
|
||||
manager_lock = threading.Lock()
|
||||
render_threads = []
|
||||
current_state = ServerStates.Init
|
||||
current_state_error:Exception = None
|
||||
current_model_path = None
|
||||
tasks_queue = queue.Queue()
|
||||
tasks_queue = []
|
||||
task_cache = TaskCache()
|
||||
default_model_to_load = None
|
||||
|
||||
@ -155,7 +174,8 @@ def preload_model(file_path=None):
|
||||
current_state = ServerStates.LoadingModel
|
||||
try:
|
||||
from . import runtime
|
||||
runtime.load_model_ckpt(ckpt_to_use=file_path)
|
||||
runtime.thread_data.ckpt_file = file_path
|
||||
runtime.load_model_ckpt()
|
||||
current_model_path = file_path
|
||||
current_state_error = None
|
||||
current_state = ServerStates.Online
|
||||
@ -165,43 +185,62 @@ def preload_model(file_path=None):
|
||||
current_state = ServerStates.Unavailable
|
||||
print(traceback.format_exc())
|
||||
|
||||
def thread_render():
|
||||
def thread_render(device):
|
||||
global current_state, current_state_error, current_model_path
|
||||
from . import runtime
|
||||
current_state = ServerStates.Online
|
||||
try:
|
||||
runtime.device_init(device)
|
||||
except:
|
||||
print(traceback.format_exc())
|
||||
return
|
||||
preload_model()
|
||||
current_state = ServerStates.Online
|
||||
while True:
|
||||
task_cache.clean()
|
||||
if isinstance(current_state_error, SystemExit):
|
||||
current_state = ServerStates.Unavailable
|
||||
return
|
||||
task = None
|
||||
try:
|
||||
task = tasks_queue.get(timeout=1)
|
||||
except queue.Empty as e:
|
||||
if isinstance(current_state_error, SystemExit):
|
||||
current_state = ServerStates.Unavailable
|
||||
return
|
||||
else: continue
|
||||
if not manager_lock.acquire(blocking=True, timeout=LOCK_TIMEOUT):
|
||||
print('Render thread on device', runtime.thread_data.device, 'failed to acquire manager lock.')
|
||||
time.sleep(1)
|
||||
continue
|
||||
if len(tasks_queue) <= 0:
|
||||
manager_lock.release()
|
||||
time.sleep(1)
|
||||
continue
|
||||
try: # Select a render task.
|
||||
for queued_task in tasks_queue:
|
||||
if queued_task.request.use_cpu and runtime.thread_data.device != 'cpu':
|
||||
continue # Cuda Tasks
|
||||
if not queued_task.request.use_cpu and runtime.thread_data.device == 'cpu':
|
||||
continue # CPU Tasks
|
||||
if queued_task.request.use_face_correction and not runtime.is_first_cuda_device(runtime.thread_data.device):
|
||||
continue #TODO Remove when fixed - A bug with GFPGANer and facexlib needs to be fixed before use on other devices.
|
||||
task = queued_task
|
||||
break
|
||||
if task is not None:
|
||||
del tasks_queue[tasks_queue.index(task)]
|
||||
finally:
|
||||
manager_lock.release()
|
||||
if task is None:
|
||||
time.sleep(1)
|
||||
continue
|
||||
#if current_model_path != task.request.use_stable_diffusion_model:
|
||||
# preload_model(task.request.use_stable_diffusion_model)
|
||||
if current_state_error:
|
||||
task.error = current_state_error
|
||||
continue
|
||||
print(f'Session {task.request.session_id} starting task {id(task)}')
|
||||
if not task.lock.acquire(blocking=False): raise Exception('Got locked task from queue.')
|
||||
try:
|
||||
task.lock.acquire(blocking=False)
|
||||
# Open data generator.
|
||||
res = runtime.mk_img(task.request)
|
||||
if current_model_path == task.request.use_stable_diffusion_model:
|
||||
current_state = ServerStates.Rendering
|
||||
else:
|
||||
current_state = ServerStates.LoadingModel
|
||||
except Exception as e:
|
||||
task.error = e
|
||||
task.lock.release()
|
||||
tasks_queue.task_done()
|
||||
print(traceback.format_exc())
|
||||
continue
|
||||
# Start reading from generator.
|
||||
dataQueue = None
|
||||
if task.request.stream_progress_updates:
|
||||
dataQueue = task.buffer_queue
|
||||
@ -224,13 +263,18 @@ def thread_render():
|
||||
for out_obj in result['output']:
|
||||
if 'path' in out_obj:
|
||||
img_id = out_obj['path'][out_obj['path'].rindex('/') + 1:]
|
||||
task.temp_images[int(img_id)] = runtime.temp_images[out_obj['path'][11:]]
|
||||
task.temp_images[int(img_id)] = runtime.thread_data.temp_images[out_obj['path'][11:]]
|
||||
elif 'data' in out_obj:
|
||||
task.temp_images[result['output'].index(out_obj)] = out_obj['data']
|
||||
# Before looping back to the generator, mark cache as still alive.
|
||||
task_cache.keep(task.request.session_id, TASK_TTL)
|
||||
except Exception as e:
|
||||
task.error = e
|
||||
print(traceback.format_exc())
|
||||
continue
|
||||
finally:
|
||||
# Task completed
|
||||
task.lock.release()
|
||||
tasks_queue.task_done()
|
||||
task_cache.keep(task.request.session_id, TASK_TTL)
|
||||
if isinstance(task.error, StopAsyncIteration):
|
||||
print(f'Session {task.request.session_id} task {id(task)} cancelled!')
|
||||
@ -240,19 +284,37 @@ def thread_render():
|
||||
print(f'Session {task.request.session_id} task {id(task)} completed.')
|
||||
current_state = ServerStates.Online
|
||||
|
||||
render_thread = threading.Thread(target=thread_render)
|
||||
def is_alive(name=None):
|
||||
if not manager_lock.acquire(blocking=True, timeout=LOCK_TIMEOUT): raise Exception('is_alive' + ERR_LOCK_FAILED)
|
||||
nbr_alive = 0
|
||||
try:
|
||||
for rthread in render_threads:
|
||||
if name and not rthread.name.endswith(name):
|
||||
continue
|
||||
if rthread.is_alive():
|
||||
nbr_alive += 1
|
||||
return nbr_alive
|
||||
finally:
|
||||
manager_lock.release()
|
||||
|
||||
def start_render_thread():
|
||||
# Start Rendering Thread
|
||||
render_thread.daemon = True
|
||||
render_thread.start()
|
||||
def start_render_thread(device='auto'):
|
||||
if not manager_lock.acquire(blocking=True, timeout=LOCK_TIMEOUT): raise Exception('start_render_threads' + ERR_LOCK_FAILED)
|
||||
print('Start new Rendering Thread on device', device)
|
||||
try:
|
||||
rthread = threading.Thread(target=thread_render, kwargs={'device': device})
|
||||
rthread.daemon = True
|
||||
rthread.name = 'Runner/' + device
|
||||
rthread.start()
|
||||
render_threads.append(rthread)
|
||||
finally:
|
||||
manager_lock.release()
|
||||
|
||||
def shutdown_event(): # Signal render thread to close on shutdown
|
||||
global current_state_error
|
||||
current_state_error = SystemExit('Application shutting down.')
|
||||
|
||||
def render(req : ImageRequest):
|
||||
if not render_thread.is_alive(): # Render thread is dead
|
||||
if not is_alive(): # Render thread is dead
|
||||
raise ChildProcessError('Rendering thread has died.')
|
||||
# Alive, check if task in cache
|
||||
task = task_cache.tryGet(req.session_id)
|
||||
@ -293,6 +355,12 @@ def render(req : ImageRequest):
|
||||
|
||||
new_task = RenderTask(r)
|
||||
if task_cache.put(r.session_id, new_task, TASK_TTL):
|
||||
tasks_queue.put(new_task, block=True, timeout=30)
|
||||
# Use twice the normal timeout for adding user requests.
|
||||
# Tries to force task_cache.put to fail before tasks_queue.put would.
|
||||
if manager_lock.acquire(blocking=True, timeout=LOCK_TIMEOUT * 2):
|
||||
try:
|
||||
tasks_queue.append(new_task)
|
||||
return new_task
|
||||
finally:
|
||||
manager_lock.release()
|
||||
raise RuntimeError('Failed to add task to cache.')
|
||||
|
320
ui/server.py
320
ui/server.py
@ -15,14 +15,24 @@ MODELS_DIR = os.path.abspath(os.path.join(SD_DIR, '..', 'models'))
|
||||
|
||||
OUTPUT_DIRNAME = "Stable Diffusion UI" # in the user's home folder
|
||||
TASK_TTL = 15 * 60 # Discard last session's task timeout
|
||||
APP_CONFIG_DEFAULTS = {
|
||||
# auto: selects the cuda device with the most free memory, cuda: use the currently active cuda device.
|
||||
'render_devices': ['auto'], # ['cuda'] or ['CPU', 'GPU:0', 'GPU:1', ...] or ['cpu']
|
||||
'update_branch': 'main',
|
||||
}
|
||||
APP_CONFIG_DEFAULT_MODELS = [
|
||||
# needed to support the legacy installations
|
||||
'custom-model', # Check if user has a custom model, use it first.
|
||||
'sd-v1-4', # Default fallback.
|
||||
]
|
||||
|
||||
from fastapi import FastAPI, HTTPException
|
||||
from fastapi.staticfiles import StaticFiles
|
||||
from starlette.responses import FileResponse, JSONResponse, StreamingResponse
|
||||
from pydantic import BaseModel
|
||||
import logging
|
||||
import queue, threading, time
|
||||
from typing import Any, Generator, Hashable, Optional, Union
|
||||
#import queue, threading, time
|
||||
from typing import Any, Generator, Hashable, List, Optional, Union
|
||||
|
||||
from sd_internal import Request, Response, task_manager
|
||||
|
||||
@ -37,52 +47,173 @@ ACCESS_LOG_SUPPRESS_PATH_PREFIXES = ['/ping', '/image', '/modifier-thumbnails']
|
||||
NOCACHE_HEADERS={"Cache-Control": "no-cache, no-store, must-revalidate", "Pragma": "no-cache", "Expires": "0"}
|
||||
app.mount('/media', StaticFiles(directory=os.path.join(SD_UI_DIR, 'media/')), name="media")
|
||||
|
||||
class SetAppConfigRequest(BaseModel):
|
||||
update_branch: str = "main"
|
||||
config_cached = None
|
||||
config_last_mod_time = 0
|
||||
def getConfig(default_val=APP_CONFIG_DEFAULTS):
|
||||
global config_cached, config_last_mod_time
|
||||
try:
|
||||
config_json_path = os.path.join(CONFIG_DIR, 'config.json')
|
||||
if not os.path.exists(config_json_path):
|
||||
return default_val
|
||||
if config_last_mod_time > 0 and config_cached is not None:
|
||||
# Don't read if file was not modified
|
||||
mtime = os.path.getmtime(config_json_path)
|
||||
if mtime <= config_last_mod_time:
|
||||
return config_cached
|
||||
with open(config_json_path, 'r') as f:
|
||||
config_cached = json.load(f)
|
||||
config_last_mod_time = os.path.getmtime(config_json_path)
|
||||
return config_cached
|
||||
except Exception as e:
|
||||
print(str(e))
|
||||
print(traceback.format_exc())
|
||||
return default_val
|
||||
|
||||
# needs to support the legacy installations
|
||||
def get_initial_model_to_load():
|
||||
custom_weight_path = os.path.join(SD_DIR, 'custom-model.ckpt')
|
||||
ckpt_to_use = "sd-v1-4" if not os.path.exists(custom_weight_path) else "custom-model"
|
||||
def setConfig(config):
|
||||
try: # config.json
|
||||
config_json_path = os.path.join(CONFIG_DIR, 'config.json')
|
||||
with open(config_json_path, 'w') as f:
|
||||
return json.dump(config, f)
|
||||
except:
|
||||
print(traceback.format_exc())
|
||||
|
||||
ckpt_to_use = os.path.join(SD_DIR, ckpt_to_use)
|
||||
if 'render_devices' in config:
|
||||
gpu_devices = filter(lambda dev: dev.startswith('GPU:'), config['render_devices'])
|
||||
else:
|
||||
gpu_devices = []
|
||||
|
||||
try: # config.bat
|
||||
config_bat = [
|
||||
f"@set update_branch={config['update_branch']}"
|
||||
]
|
||||
if len(gpu_devices) > 0:
|
||||
config_sh.append(f"@set CUDA_VISIBLE_DEVICES={','.join(gpu_devices)}")
|
||||
config_bat_path = os.path.join(CONFIG_DIR, 'config.bat')
|
||||
with open(config_bat_path, 'w') as f:
|
||||
f.write(f.write('\r\n'.join(config_bat)))
|
||||
except Exception as e:
|
||||
print(traceback.format_exc())
|
||||
|
||||
try: # config.sh
|
||||
config_sh = [
|
||||
'#!/bin/bash'
|
||||
f"export update_branch={config['update_branch']}"
|
||||
]
|
||||
if len(gpu_devices) > 0:
|
||||
config_sh.append(f"CUDA_VISIBLE_DEVICES={','.join(gpu_devices)}")
|
||||
config_sh_path = os.path.join(CONFIG_DIR, 'config.sh')
|
||||
with open(config_sh_path, 'w') as f:
|
||||
f.write('\n'.join(config_sh))
|
||||
except Exception as e:
|
||||
print(traceback.format_exc())
|
||||
|
||||
def resolve_model_to_use(model_name:str=None):
|
||||
if not model_name: # When None try user configured model.
|
||||
config = getConfig()
|
||||
if 'model' in config and 'stable-diffusion' in config['model']:
|
||||
model_name = config['model']['stable-diffusion']
|
||||
model_path = resolve_model_to_use(model_name)
|
||||
if model_name:
|
||||
if os.path.exists(model_name + '.ckpt'):
|
||||
# Direct Path to file
|
||||
return model_name
|
||||
# Check models directory
|
||||
models_dir_path = os.path.join(MODELS_DIR, 'stable-diffusion', model_name)
|
||||
if os.path.exists(models_dir_path + '.ckpt'):
|
||||
return models_dir_path
|
||||
# Default locations
|
||||
if model_name in APP_CONFIG_DEFAULT_MODELS:
|
||||
default_model_path = os.path.join(SD_DIR, model_name)
|
||||
if os.path.exists(default_model_path + '.ckpt'):
|
||||
return default_model_path
|
||||
# Can't find requested model, check the default paths.
|
||||
for default_model in APP_CONFIG_DEFAULT_MODELS:
|
||||
default_model_path = os.path.join(SD_DIR, default_model + '.ckpt')
|
||||
if os.path.exists(default_model_path):
|
||||
print('Could not find the configured custom model at:', model_path + '.ckpt', '. Using the default one:', default_model_path + '.ckpt')
|
||||
return default_model_path
|
||||
raise Exception('No valid models found.')
|
||||
|
||||
if os.path.exists(model_path + '.ckpt'):
|
||||
ckpt_to_use = model_path
|
||||
class SetAppConfigRequest(BaseModel):
|
||||
update_branch: str = None
|
||||
render_devices: Union[List[str], List[int], str, int] = None
|
||||
|
||||
@app.post('/app_config')
|
||||
async def setAppConfig(req : SetAppConfigRequest):
|
||||
config = getConfig()
|
||||
if req.update_branch:
|
||||
config['update_branch'] = req.update_branch
|
||||
if req.render_devices and hasattr(req.render_devices, "__len__"): # strings, array of strings or numbers.
|
||||
render_devices = []
|
||||
if isinstance(req.render_devices, str):
|
||||
req.render_devices = req.render_devices.split(',')
|
||||
if isinstance(req.render_devices, list):
|
||||
for gpu in req.render_devices:
|
||||
if isinstance(req.render_devices, int):
|
||||
render_devices.append('GPU:' + gpu)
|
||||
else:
|
||||
print('Could not find the configured custom model at:', model_path + '.ckpt', '. Using the default one:', ckpt_to_use + '.ckpt')
|
||||
return ckpt_to_use
|
||||
render_devices.append(gpu)
|
||||
if isinstance(req.render_devices, int):
|
||||
render_devices.append('GPU:' + req.render_devices)
|
||||
if len(render_devices) > 0:
|
||||
config['render_devices'] = render_devices
|
||||
try:
|
||||
setConfig(config)
|
||||
return JSONResponse({'status': 'OK'}, headers=NOCACHE_HEADERS)
|
||||
except Exception as e:
|
||||
print(traceback.format_exc())
|
||||
return HTTPException(status_code=500, detail=str(e))
|
||||
|
||||
def resolve_model_to_use(model_name):
|
||||
if model_name in ('sd-v1-4', 'custom-model'):
|
||||
model_path = os.path.join(MODELS_DIR, 'stable-diffusion', model_name)
|
||||
def getModels():
|
||||
models = {
|
||||
'active': {
|
||||
'stable-diffusion': 'sd-v1-4',
|
||||
},
|
||||
'options': {
|
||||
'stable-diffusion': ['sd-v1-4'],
|
||||
},
|
||||
}
|
||||
|
||||
legacy_model_path = os.path.join(SD_DIR, model_name)
|
||||
if not os.path.exists(model_path + '.ckpt') and os.path.exists(legacy_model_path + '.ckpt'):
|
||||
model_path = legacy_model_path
|
||||
# custom models
|
||||
sd_models_dir = os.path.join(MODELS_DIR, 'stable-diffusion')
|
||||
for file in os.listdir(sd_models_dir):
|
||||
if file.endswith('.ckpt'):
|
||||
model_name = os.path.splitext(file)[0]
|
||||
models['options']['stable-diffusion'].append(model_name)
|
||||
|
||||
# legacy
|
||||
custom_weight_path = os.path.join(SD_DIR, 'custom-model.ckpt')
|
||||
if os.path.exists(custom_weight_path):
|
||||
models['active']['stable-diffusion'] = 'custom-model'
|
||||
models['options']['stable-diffusion'].append('custom-model')
|
||||
|
||||
config = getConfig()
|
||||
if 'model' in config and 'stable-diffusion' in config['model']:
|
||||
models['active']['stable-diffusion'] = config['model']['stable-diffusion']
|
||||
|
||||
return models
|
||||
|
||||
@app.get('/get/{key:path}')
|
||||
def read_web_data(key:str=None):
|
||||
if not key: # /get without parameters, stable-diffusion easter egg.
|
||||
return HTTPException(status_code=418, detail="StableDiffusion is drawing a teapot!") # HTTP418 I'm a teapot
|
||||
elif key == 'app_config':
|
||||
config = getConfig(default_val=None)
|
||||
if config is None:
|
||||
return HTTPException(status_code=500, detail="Config file is missing or unreadable")
|
||||
return JSONResponse(config, headers=NOCACHE_HEADERS)
|
||||
elif key == 'models':
|
||||
return JSONResponse(getModels(), headers=NOCACHE_HEADERS)
|
||||
elif key == 'modifiers': return FileResponse(os.path.join(SD_UI_DIR, 'modifiers.json'), headers=NOCACHE_HEADERS)
|
||||
elif key == 'output_dir': return JSONResponse({ 'output_dir': outpath }, headers=NOCACHE_HEADERS)
|
||||
else:
|
||||
model_path = os.path.join(MODELS_DIR, 'stable-diffusion', model_name)
|
||||
return model_path
|
||||
|
||||
@app.on_event("shutdown")
|
||||
def shutdown_event(): # Signal render thread to close on shutdown
|
||||
task_manager.current_state_error = SystemExit('Application shutting down.')
|
||||
|
||||
@app.get('/')
|
||||
def read_root():
|
||||
return FileResponse(os.path.join(SD_UI_DIR, 'index.html'), headers=NOCACHE_HEADERS)
|
||||
return HTTPException(status_code=404, detail=f'Request for unknown {key}') # HTTP404 Not Found
|
||||
|
||||
@app.get('/ping') # Get server and optionally session status.
|
||||
def ping(session_id:str=None):
|
||||
if not task_manager.render_thread.is_alive(): # Render thread is dead.
|
||||
if task_manager.current_state_error: return HTTPException(status_code=500, detail=str(current_state_error))
|
||||
if task_manager.is_alive() <= 0: # Check that render threads are alive.
|
||||
if task_manager.current_state_error: return HTTPException(status_code=500, detail=str(task_manager.current_state_error))
|
||||
return HTTPException(status_code=500, detail='Render thread is dead.')
|
||||
if task_manager.current_state_error and not isinstance(task_manager.current_state_error, StopAsyncIteration): return HTTPException(status_code=500, detail=str(current_state_error))
|
||||
if task_manager.current_state_error and not isinstance(task_manager.current_state_error, StopAsyncIteration): return HTTPException(status_code=500, detail=str(task_manager.current_state_error))
|
||||
# Alive
|
||||
response = {'status': str(task_manager.current_state)}
|
||||
if session_id:
|
||||
@ -119,7 +250,7 @@ def render(req : task_manager.ImageRequest):
|
||||
new_task = task_manager.render(req)
|
||||
response = {
|
||||
'status': str(task_manager.current_state),
|
||||
'queue': task_manager.tasks_queue.qsize(),
|
||||
'queue': len(task_manager.tasks_queue),
|
||||
'stream': f'/image/stream/{req.session_id}/{id(new_task)}',
|
||||
'task': id(new_task)
|
||||
}
|
||||
@ -172,100 +303,13 @@ def get_image(session_id, img_id):
|
||||
except KeyError as e:
|
||||
return HTTPException(status_code=500, detail=str(e))
|
||||
|
||||
@app.post('/app_config')
|
||||
async def setAppConfig(req : SetAppConfigRequest):
|
||||
try:
|
||||
config = {
|
||||
'update_branch': req.update_branch
|
||||
}
|
||||
@app.get('/')
|
||||
def read_root():
|
||||
return FileResponse(os.path.join(SD_UI_DIR, 'index.html'), headers=NOCACHE_HEADERS)
|
||||
|
||||
config_json_str = json.dumps(config)
|
||||
config_bat_str = f'@set update_branch={req.update_branch}'
|
||||
config_sh_str = f'export update_branch={req.update_branch}'
|
||||
|
||||
config_json_path = os.path.join(CONFIG_DIR, 'config.json')
|
||||
config_bat_path = os.path.join(CONFIG_DIR, 'config.bat')
|
||||
config_sh_path = os.path.join(CONFIG_DIR, 'config.sh')
|
||||
|
||||
with open(config_json_path, 'w') as f:
|
||||
f.write(config_json_str)
|
||||
|
||||
with open(config_bat_path, 'w') as f:
|
||||
f.write(config_bat_str)
|
||||
|
||||
with open(config_sh_path, 'w') as f:
|
||||
f.write(config_sh_str)
|
||||
|
||||
return {'OK'}
|
||||
except Exception as e:
|
||||
print(traceback.format_exc())
|
||||
return HTTPException(status_code=500, detail=str(e))
|
||||
|
||||
def getConfig(default_val={}):
|
||||
try:
|
||||
config_json_path = os.path.join(CONFIG_DIR, 'config.json')
|
||||
if not os.path.exists(config_json_path):
|
||||
return default_val
|
||||
with open(config_json_path, 'r') as f:
|
||||
return json.load(f)
|
||||
except Exception as e:
|
||||
print(str(e))
|
||||
print(traceback.format_exc())
|
||||
return default_val
|
||||
|
||||
def setConfig(config):
|
||||
try:
|
||||
config_json_path = os.path.join(CONFIG_DIR, 'config.json')
|
||||
with open(config_json_path, 'w') as f:
|
||||
return json.dump(config, f)
|
||||
except:
|
||||
print(str(e))
|
||||
print(traceback.format_exc())
|
||||
|
||||
def getModels():
|
||||
models = {
|
||||
'active': {
|
||||
'stable-diffusion': 'sd-v1-4',
|
||||
},
|
||||
'options': {
|
||||
'stable-diffusion': ['sd-v1-4'],
|
||||
},
|
||||
}
|
||||
|
||||
# custom models
|
||||
sd_models_dir = os.path.join(MODELS_DIR, 'stable-diffusion')
|
||||
for file in os.listdir(sd_models_dir):
|
||||
if file.endswith('.ckpt'):
|
||||
model_name = os.path.splitext(file)[0]
|
||||
models['options']['stable-diffusion'].append(model_name)
|
||||
|
||||
# legacy
|
||||
custom_weight_path = os.path.join(SD_DIR, 'custom-model.ckpt')
|
||||
if os.path.exists(custom_weight_path):
|
||||
models['active']['stable-diffusion'] = 'custom-model'
|
||||
models['options']['stable-diffusion'].append('custom-model')
|
||||
|
||||
config = getConfig()
|
||||
if 'model' in config and 'stable-diffusion' in config['model']:
|
||||
models['active']['stable-diffusion'] = config['model']['stable-diffusion']
|
||||
|
||||
return models
|
||||
|
||||
@app.get('/get/{key:path}')
|
||||
def read_web_data(key:str=None):
|
||||
if not key: # /get without parameters, stable-diffusion easter egg.
|
||||
return HTTPException(status_code=418, detail="StableDiffusion is drawing a teapot!") # HTTP418 I'm a teapot
|
||||
elif key == 'app_config':
|
||||
config = getConfig(default_val=None)
|
||||
if config is None:
|
||||
return HTTPException(status_code=500, detail="Config file is missing or unreadable")
|
||||
return JSONResponse(config, headers=NOCACHE_HEADERS)
|
||||
elif key == 'models':
|
||||
return JSONResponse(getModels(), headers=NOCACHE_HEADERS)
|
||||
elif key == 'modifiers': return FileResponse(os.path.join(SD_UI_DIR, 'modifiers.json'), headers=NOCACHE_HEADERS)
|
||||
elif key == 'output_dir': return JSONResponse({ 'output_dir': outpath }, headers=NOCACHE_HEADERS)
|
||||
else:
|
||||
return HTTPException(status_code=404, detail=f'Request for unknown {key}') # HTTP404 Not Found
|
||||
@app.on_event("shutdown")
|
||||
def shutdown_event(): # Signal render thread to close on shutdown
|
||||
task_manager.current_state_error = SystemExit('Application shutting down.')
|
||||
|
||||
# don't log certain requests
|
||||
class LogSuppressFilter(logging.Filter):
|
||||
@ -277,8 +321,26 @@ class LogSuppressFilter(logging.Filter):
|
||||
return True
|
||||
logging.getLogger('uvicorn.access').addFilter(LogSuppressFilter())
|
||||
|
||||
task_manager.default_model_to_load = get_initial_model_to_load()
|
||||
task_manager.start_render_thread()
|
||||
config = getConfig()
|
||||
# Start the task_manager
|
||||
task_manager.default_model_to_load = resolve_model_to_use()
|
||||
if 'render_devices' in config: # Start a new thread for each device.
|
||||
if isinstance(config['render_devices'], str):
|
||||
config['render_devices'] = config['render_devices'].split(',')
|
||||
if not isinstance(config['render_devices'], list):
|
||||
raise Exception('Invalid render_devices value in config.')
|
||||
for device in config['render_devices']:
|
||||
task_manager.start_render_thread(device)
|
||||
|
||||
allow_cpu = False
|
||||
if task_manager.is_alive() <= 0: # No running devices, apply defaults.
|
||||
# Select best device GPU device using free memory if more than one device.
|
||||
task_manager.start_render_thread('auto')
|
||||
allow_cpu = True
|
||||
|
||||
# Allow CPU to be used for renders if not already enabled in current config.
|
||||
if task_manager.is_alive('cpu') <= 0 and allow_cpu:
|
||||
task_manager.start_render_thread('cpu')
|
||||
|
||||
# start the browser ui
|
||||
import webbrowser; webbrowser.open('http://localhost:9000')
|
Loading…
x
Reference in New Issue
Block a user