First draft for Multi-GPU support

This commit is contained in:
Marc-Andre Ferland
2022-10-16 21:41:39 -04:00
parent 2edc06c662
commit 7c72608e1c
3 changed files with 584 additions and 376 deletions

View File

@ -35,63 +35,144 @@ import base64
from io import BytesIO
#from colorama import Fore
# local
stop_processing = False
temp_images = {}
from threading import local as LocalThreadVars
thread_data = LocalThreadVars()
ckpt_file = None
gfpgan_file = None
real_esrgan_file = None
def device_would_fail(device):
if device == 'cpu': return None
# Returns None when no issues found, otherwise returns the detected error str.
# Memory check
try:
mem_free, mem_total = torch.cuda.mem_get_info(device)
mem_total /= float(10**9)
if mem_total < 3.0:
return 'GPUs with less than 3 GB of VRAM are not compatible with Stable Diffusion'
except RuntimeError as e:
return str(e) # Return cuda errors from mem_get_info as strings
return None
model = None
modelCS = None
modelFS = None
model_gfpgan = None
model_real_esrgan = None
def device_select(device):
if device == 'cpu': return True
if not torch.cuda.is_available(): return False
failure_msg = device_would_fail(device)
if failure_msg:
if 'invalid device' in failure_msg:
raise NameError(f'GPU "{device}" could not be found. Remove this device from config.render_devices or use one of "auto" or "cuda".')
print(failure_msg)
return False
model_is_half = False
model_fs_is_half = False
device = None
unet_bs = 1
precision = 'autocast'
sampler_plms = None
sampler_ddim = None
device_name = torch.cuda.get_device_name(device)
has_valid_gpu = False
force_full_precision = False
try:
gpu = torch.cuda.current_device()
gpu_name = torch.cuda.get_device_name(gpu)
print('GPU detected: ', gpu_name)
force_full_precision = ('nvidia' in gpu_name.lower() or 'geforce' in gpu_name.lower()) and (' 1660' in gpu_name or ' 1650' in gpu_name) # otherwise these NVIDIA cards create green images
if force_full_precision:
# otherwise these NVIDIA cards create green images
thread_data.force_full_precision = ('nvidia' in device_name.lower() or 'geforce' in device_name.lower()) and (' 1660' in device_name or ' 1650' in device_name)
if thread_data.force_full_precision:
print('forcing full precision on NVIDIA 16xx cards, to avoid green images. GPU detected: ', gpu_name)
mem_free, mem_total = torch.cuda.mem_get_info(gpu)
mem_total /= float(10**9)
if mem_total < 3.0:
print("GPUs with less than 3 GB of VRAM are not compatible with Stable Diffusion")
raise Exception()
thread_data.device = device
thread_data.has_valid_gpu = True
return True
has_valid_gpu = True
except:
def device_init(device_selection=None):
# Thread bound properties
thread_data.stop_processing = False
thread_data.temp_images = {}
thread_data.ckpt_file = None
thread_data.gfpgan_file = None
thread_data.real_esrgan_file = None
thread_data.model = None
thread_data.modelCS = None
thread_data.modelFS = None
thread_data.model_gfpgan = None
thread_data.model_real_esrgan = None
thread_data.model_is_half = False
thread_data.model_fs_is_half = False
thread_data.device = None
thread_data.unet_bs = 1
thread_data.precision = 'autocast'
thread_data.sampler_plms = None
thread_data.sampler_ddim = None
thread_data.turbo = False
thread_data.has_valid_gpu = False
thread_data.force_full_precision = False
if device_selection.lower() == 'cpu':
print('CPU requested, skipping gpu init.')
thread_data.device = 'cpu'
return
if not torch.cuda.is_available():
print('WARNING: torch.cuda is not available. Using the CPU, but this will be very slow!')
return
device_count = torch.cuda.device_count()
if device_count <= 1 and device_selection == 'auto':
device_selection = 'current' # Use 'auto' only when there is more than one compatible device found.
if device_selection == 'auto':
print('Autoselecting GPU. Using most free memory.')
max_mem_free = 0
best_device = None
for device in range(device_count):
mem_free, mem_total = torch.cuda.mem_get_info(device)
mem_free /= float(10**9)
mem_total /= float(10**9)
device_name = torch.cuda.get_device_name(device)
print(f'GPU:{device} detected: {device_name} - Memory: {round(mem_total - mem_free, 2)}Go / {round(mem_total, 2)}Go')
if max_mem_free < mem_free:
max_mem_free = mem_free
best_device = device
if best_device and device_select(device):
print(f'Setting GPU:{device} as active')
torch.cuda.device(device)
return
if isinstance(device_selection, str):
device_selection = device_selection.lower()
if device_selection.startswith('gpu:'):
device_selection = int(device_selection[4:])
if device_selection != 'cuda' and device_selection != 'current':
if device_select(device_selection):
if isinstance(device_selection, int):
print(f'Setting GPU:{device_selection} as active')
else:
print(f'Setting {device_selection} as active')
torch.cuda.device(device_selection)
return
# By default use current device.
print('Checking current GPU...')
device = torch.cuda.current_device()
device_name = torch.cuda.get_device_name(device)
print(f'GPU:{device} detected: {device_name}')
if device_select(device):
return
print('WARNING: No compatible GPU found. Using the CPU, but this will be very slow!')
pass
thread_data.device = 'cpu'
def load_model_ckpt(ckpt_to_use, device_to_use='cuda', turbo=False, unet_bs_to_use=1, precision_to_use='autocast'):
global ckpt_file, model, modelCS, modelFS, model_is_half, device, unet_bs, precision, model_fs_is_half
def is_first_cuda_device(device):
if thread_data.device == 0 or thread_data.device == '0':
return True
if thread_data.device == 'cuda' or thread_data.device == 'cuda:0':
return True
if thread_data.device == torch.device(0):
return True
return False
device = device_to_use if has_valid_gpu else 'cpu'
precision = precision_to_use if not force_full_precision else 'full'
unet_bs = unet_bs_to_use
def load_model_ckpt():
if not thread_data.ckpt_file: raise ValueError(f'Thread ckpt_file is undefined.')
if not os.path.exists(thread_data.ckpt_file + '.ckpt'): raise FileNotFoundError(f'Cannot find {thread_data.ckpt_file}.ckpt')
if not thread_data.precision:
thread_data.precision = 'full' if thread_data.force_full_precision else 'autocast'
if not thread_data.unet_bs:
thread_data.unet_bs = 1
unload_model()
if device == 'cpu':
precision = 'full'
if thread_data.device == 'cpu':
thread_data.precision = 'full'
sd = load_model_from_config(f"{ckpt_to_use}.ckpt")
print('loading', thread_data.ckpt_file, 'to', thread_data.device, 'using precision', thread_data.precision)
sd = load_model_from_config(thread_data.ckpt_file + '.ckpt')
li, lo = [], []
for key, value in sd.items():
sp = key.split(".")
@ -114,88 +195,84 @@ def load_model_ckpt(ckpt_to_use, device_to_use='cuda', turbo=False, unet_bs_to_u
model = instantiate_from_config(config.modelUNet)
_, _ = model.load_state_dict(sd, strict=False)
model.eval()
model.cdevice = device
model.unet_bs = unet_bs
model.turbo = turbo
model.cdevice = torch.device(thread_data.device)
model.unet_bs = thread_data.unet_bs
model.turbo = thread_data.turbo
if thread_data.device != 'cpu':
model.to(thread_data.device)
thread_data.model = model
modelCS = instantiate_from_config(config.modelCondStage)
_, _ = modelCS.load_state_dict(sd, strict=False)
modelCS.eval()
modelCS.cond_stage_model.device = device
modelCS.cond_stage_model.device = torch.device(thread_data.device)
if thread_data.device != 'cpu':
modelCS.to(thread_data.device)
thread_data.modelCS = modelCS
modelFS = instantiate_from_config(config.modelFirstStage)
_, _ = modelFS.load_state_dict(sd, strict=False)
modelFS.eval()
if thread_data.device != 'cpu':
modelFS.to(thread_data.device)
thread_data.modelFS = modelFS
del sd
if device != "cpu" and precision == "autocast":
model.half()
modelCS.half()
modelFS.half()
model_is_half = True
model_fs_is_half = True
if thread_data.device != "cpu" and thread_data.precision == "autocast":
thread_data.model.half()
thread_data.modelCS.half()
thread_data.modelFS.half()
thread_data.model_is_half = True
thread_data.model_fs_is_half = True
else:
model_is_half = False
model_fs_is_half = False
thread_data.model_is_half = False
thread_data.model_fs_is_half = False
ckpt_file = ckpt_to_use
print('loaded ', ckpt_file, 'to', device, 'precision', precision)
print('loaded', thread_data.ckpt_file, 'as', model.device, '->', modelCS.cond_stage_model.device, '->', thread_data.modelFS.device, 'using precision', thread_data.precision)
def unload_model():
global model, modelCS, modelFS
if thread_data.model is not None:
print('Unloading models...')
del thread_data.model
del thread_data.modelCS
del thread_data.modelFS
thread_data.model = None
thread_data.modelCS = None
thread_data.modelFS = None
if model is not None:
del model
del modelCS
del modelFS
model = None
modelCS = None
modelFS = None
def load_model_gfpgan(gfpgan_to_use):
global gfpgan_file, model_gfpgan
if gfpgan_to_use is None:
def load_model_gfpgan():
if thread_data.gfpgan_file is None:
print('load_model_gfpgan called without setting gfpgan_file')
return
if thread_data.device != 'cpu' and not is_first_cuda_device(thread_data.device):
#TODO Remove when fixed - A bug with GFPGANer and facexlib needs to be fixed before use on other devices.
raise Exception(f'Current device {torch.device(thread_data.device)} is not {torch.device(0)}.')
model_path = thread_data.gfpgan_file + ".pth"
thread_data.model_gfpgan = GFPGANer(device=torch.device(thread_data.device), model_path=model_path, upscale=1, arch='clean', channel_multiplier=2, bg_upsampler=None)
print('loaded', thread_data.gfpgan_file, 'to', thread_data.model_gfpgan.device, 'precision', thread_data.precision)
gfpgan_file = gfpgan_to_use
model_path = gfpgan_to_use + ".pth"
if device == 'cpu':
model_gfpgan = GFPGANer(model_path=model_path, upscale=1, arch='clean', channel_multiplier=2, bg_upsampler=None, device=torch.device('cpu'))
else:
model_gfpgan = GFPGANer(model_path=model_path, upscale=1, arch='clean', channel_multiplier=2, bg_upsampler=None, device=torch.device('cuda'))
print('loaded ', gfpgan_to_use, 'to', device, 'precision', precision)
def load_model_real_esrgan(real_esrgan_to_use):
global real_esrgan_file, model_real_esrgan
if real_esrgan_to_use is None:
def load_model_real_esrgan():
if thread_data.real_esrgan_file is None:
print('load_model_real_esrgan called without setting real_esrgan_file')
return
real_esrgan_file = real_esrgan_to_use
model_path = real_esrgan_to_use + ".pth"
model_path = thread_data.real_esrgan_file + ".pth"
RealESRGAN_models = {
'RealESRGAN_x4plus': RRDBNet(num_in_ch=3, num_out_ch=3, num_feat=64, num_block=23, num_grow_ch=32, scale=4),
'RealESRGAN_x4plus_anime_6B': RRDBNet(num_in_ch=3, num_out_ch=3, num_feat=64, num_block=6, num_grow_ch=32, scale=4)
}
model_to_use = RealESRGAN_models[real_esrgan_to_use]
model_to_use = RealESRGAN_models[thread_data.real_esrgan_file]
if device == 'cpu':
model_real_esrgan = RealESRGANer(scale=2, model_path=model_path, model=model_to_use, pre_pad=0, half=False) # cpu does not support half
model_real_esrgan.device = torch.device('cpu')
model_real_esrgan.model.to('cpu')
if thread_data.device == 'cpu':
thread_data.model_real_esrgan = RealESRGANer(device=torch.device(thread_data.device), scale=2, model_path=model_path, model=model_to_use, pre_pad=0, half=False) # cpu does not support half
#thread_data.model_real_esrgan.device = torch.device(thread_data.device)
thread_data.model_real_esrgan.model.to('cpu')
else:
model_real_esrgan = RealESRGANer(scale=2, model_path=model_path, model=model_to_use, pre_pad=0, half=model_is_half)
thread_data.model_real_esrgan = RealESRGANer(device=torch.device(thread_data.device), scale=2, model_path=model_path, model=model_to_use, pre_pad=0, half=thread_data.model_is_half)
model_real_esrgan.model.name = real_esrgan_to_use
print('loaded ', real_esrgan_to_use, 'to', device, 'precision', precision)
thread_data.model_real_esrgan.model.name = thread_data.real_esrgan_file
print('loaded ', thread_data.real_esrgan_file, 'to', thread_data.model_real_esrgan.device, 'precision', thread_data.precision)
def get_base_path(disk_path, session_id, prompt, ext, suffix=None):
if disk_path is None: return None
@ -214,14 +291,22 @@ def get_base_path(disk_path, session_id, prompt, ext, suffix=None):
def apply_filters(filter_name, image_data):
print(f'Applying filter {filter_name}...')
if isinstance(image_data, torch.Tensor):
print(image_data)
image_data.to(thread_data.device)
gc()
if filter_name == 'gfpgan':
_, _, output = model_gfpgan.enhance(image_data[:,:,::-1], has_aligned=False, only_center_face=False, paste_back=True)
if thread_data.model_gfpgan is None: raise Exception('Model "gfpgan" not loaded.')
print('enhance with', thread_data.gfpgan_file, 'on', thread_data.model_gfpgan.device, 'precision', thread_data.precision)
_, _, output = thread_data.model_gfpgan.enhance(image_data[:,:,::-1], has_aligned=False, only_center_face=False, paste_back=True)
image_data = output[:,:,::-1]
if filter_name == 'real_esrgan':
output, _ = model_real_esrgan.enhance(image_data[:,:,::-1])
if thread_data.model_real_esrgan is None: raise Exception('Model "gfpgan" not loaded.')
print('enhance with', thread_data.real_esrgan_file, 'on', thread_data.model_real_esrgan.device, 'precision', thread_data.precision)
output, _ = thread_data.model_real_esrgan.enhance(image_data[:,:,::-1])
image_data = output[:,:,::-1]
return image_data
@ -234,12 +319,12 @@ def mk_img(req: Request):
gc()
if device != "cpu":
modelFS.to("cpu")
modelCS.to("cpu")
if thread_data.device != "cpu":
thread_data.modelFS.to("cpu")
thread_data.modelCS.to("cpu")
model.model1.to("cpu")
model.model2.to("cpu")
thread_data.model.model1.to("cpu")
thread_data.model.model2.to("cpu")
gc()
@ -249,66 +334,55 @@ def mk_img(req: Request):
})
def do_mk_img(req: Request):
global ckpt_file
global model, modelCS, modelFS, device
global model_gfpgan, model_real_esrgan
global stop_processing
stop_processing = False
thread_data.stop_processing = False
res = Response()
res.request = req
res.images = []
temp_images.clear()
thread_data.temp_images.clear()
# custom model support:
# the req.use_stable_diffusion_model needs to be a valid path
# to the ckpt file (without the extension).
if not os.path.exists(req.use_stable_diffusion_model + '.ckpt'): raise FileNotFoundError(f'Cannot find {req.use_stable_diffusion_model}.ckpt')
needs_model_reload = False
ckpt_to_use = ckpt_file
if ckpt_to_use != req.use_stable_diffusion_model:
ckpt_to_use = req.use_stable_diffusion_model
if thread_data.ckpt_file != req.use_stable_diffusion_model:
thread_data.ckpt_file = req.use_stable_diffusion_model
needs_model_reload = True
model.turbo = req.turbo
if req.use_cpu:
if device != 'cpu':
device = 'cpu'
if model_is_half:
load_model_ckpt(ckpt_to_use, device)
if thread_data.device != 'cpu':
thread_data.device = 'cpu'
if thread_data.model_is_half:
load_model_ckpt()
needs_model_reload = False
load_model_gfpgan(gfpgan_file)
load_model_real_esrgan(real_esrgan_file)
load_model_gfpgan()
load_model_real_esrgan()
else:
if has_valid_gpu:
prev_device = device
device = 'cuda'
if (precision == 'autocast' and (req.use_full_precision or not model_is_half)) or \
(precision == 'full' and not req.use_full_precision and not force_full_precision):
load_model_ckpt(ckpt_to_use, device, req.turbo, unet_bs, ('full' if req.use_full_precision else 'autocast'))
if thread_data.has_valid_gpu:
if (thread_data.precision == 'autocast' and (req.use_full_precision or not thread_data.model_is_half)) or \
(thread_data.precision == 'full' and not req.use_full_precision and not thread_data.force_full_precision):
thread_data.precision = 'full' if req.use_full_precision else 'autocast'
load_model_ckpt()
load_model_gfpgan()
load_model_real_esrgan()
needs_model_reload = False
if prev_device != device:
load_model_gfpgan(gfpgan_file)
load_model_real_esrgan(real_esrgan_file)
if needs_model_reload:
load_model_ckpt(ckpt_to_use, device, req.turbo, unet_bs, precision)
load_model_ckpt()
if req.use_face_correction != gfpgan_file:
load_model_gfpgan(req.use_face_correction)
if req.use_face_correction != thread_data.gfpgan_file:
thread_data.gfpgan_file = req.use_face_correction
load_model_gfpgan()
if req.use_upscale != thread_data.real_esrgan_file:
thread_data.real_esrgan_file = req.use_upscale
load_model_real_esrgan()
if req.use_upscale != real_esrgan_file:
load_model_real_esrgan(req.use_upscale)
model.cdevice = device
modelCS.cond_stage_model.device = device
if thread_data.turbo != req.turbo:
thread_data.turbo = req.turbo
thread_data.model.turbo = req.turbo
opt_prompt = req.prompt
opt_seed = req.seed
@ -318,9 +392,8 @@ def do_mk_img(req: Request):
opt_ddim_eta = 0.0
opt_init_img = req.init_image
print(req.to_string(), '\n device', device)
print('\n\n Using precision:', precision)
print(req.to_string(), '\n device', thread_data.device)
print('\n\n Using precision:', thread_data.precision)
seed_everything(opt_seed)
@ -329,7 +402,7 @@ def do_mk_img(req: Request):
assert prompt is not None
data = [batch_size * [prompt]]
if precision == "autocast" and device != "cpu":
if thread_data.precision == "autocast" and thread_data.device != "cpu":
precision_scope = autocast
else:
precision_scope = nullcontext
@ -345,22 +418,22 @@ def do_mk_img(req: Request):
handler = _img2img
init_image = load_img(req.init_image, req.width, req.height)
init_image = init_image.to(device)
init_image = init_image.to(thread_data.device)
if device != "cpu" and precision == "autocast":
if thread_data.device != "cpu" and thread_data.precision == "autocast":
init_image = init_image.half()
modelFS.to(device)
thread_data.modelFS.to(thread_data.device)
init_image = repeat(init_image, '1 ... -> b ...', b=batch_size)
init_latent = modelFS.get_first_stage_encoding(modelFS.encode_first_stage(init_image)) # move to latent space
init_latent = thread_data.modelFS.get_first_stage_encoding(thread_data.modelFS.encode_first_stage(init_image)) # move to latent space
if req.mask is not None:
mask = load_mask(req.mask, req.width, req.height, init_latent.shape[2], init_latent.shape[3], True).to(device)
mask = load_mask(req.mask, req.width, req.height, init_latent.shape[2], init_latent.shape[3], True).to(thread_data.device)
mask = mask[0][0].unsqueeze(0).repeat(4, 1, 1).unsqueeze(0)
mask = repeat(mask, '1 ... -> b ...', b=batch_size)
if device != "cpu" and precision == "autocast":
if thread_data.device != "cpu" and thread_data.precision == "autocast":
mask = mask.half()
move_fs_to_cpu()
@ -381,10 +454,10 @@ def do_mk_img(req: Request):
for prompts in tqdm(data, desc="data"):
with precision_scope("cuda"):
modelCS.to(device)
thread_data.modelCS.to(thread_data.device)
uc = None
if req.guidance_scale != 1.0:
uc = modelCS.get_learned_conditioning(batch_size * [req.negative_prompt])
uc = thread_data.modelCS.get_learned_conditioning(batch_size * [req.negative_prompt])
if isinstance(prompts, tuple):
prompts = list(prompts)
@ -397,11 +470,11 @@ def do_mk_img(req: Request):
weight = weights[i]
# if not skip_normalize:
weight = weight / totalWeight
c = torch.add(c, modelCS.get_learned_conditioning(subprompts[i]), alpha=weight)
c = torch.add(c, thread_data.modelCS.get_learned_conditioning(subprompts[i]), alpha=weight)
else:
c = modelCS.get_learned_conditioning(prompts)
c = thread_data.modelCS.get_learned_conditioning(prompts)
modelFS.to(device)
thread_data.modelFS.to(thread_data.device)
partial_x_samples = None
def img_callback(x_samples, i):
@ -417,7 +490,7 @@ def do_mk_img(req: Request):
partial_images = []
for i in range(batch_size):
x_samples_ddim = modelFS.decode_first_stage(x_samples[i].unsqueeze(0))
x_samples_ddim = thread_data.modelFS.decode_first_stage(x_samples[i].unsqueeze(0))
x_sample = torch.clamp((x_samples_ddim + 1.0) / 2.0, min=0.0, max=1.0)
x_sample = 255.0 * rearrange(x_sample[0].cpu().numpy(), "c h w -> h w c")
x_sample = x_sample.astype(np.uint8)
@ -429,18 +502,19 @@ def do_mk_img(req: Request):
del img, x_sample, x_samples_ddim
# don't delete x_samples, it is used in the code that called this callback
temp_images[str(req.session_id) + '/' + str(i)] = buf
thread_data.temp_images[str(req.session_id) + '/' + str(i)] = buf
partial_images.append({'path': f'/image/tmp/{req.session_id}/{i}'})
progress['output'] = partial_images
yield json.dumps(progress)
if stop_processing:
if thread_data.stop_processing:
raise UserInitiatedStop("User requested that we stop processing")
# run the handler
try:
print('Running handler...')
if handler == _txt2img:
x_samples = _txt2img(req.width, req.height, req.num_outputs, req.num_inference_steps, req.guidance_scale, None, opt_C, opt_f, opt_ddim_eta, c, uc, opt_seed, img_callback, mask, req.sampler)
else:
@ -458,7 +532,7 @@ def do_mk_img(req: Request):
print("saving images")
for i in range(batch_size):
x_samples_ddim = modelFS.decode_first_stage(x_samples[i].unsqueeze(0))
x_samples_ddim = thread_data.modelFS.decode_first_stage(x_samples[i].unsqueeze(0))
x_sample = torch.clamp((x_samples_ddim + 1.0) / 2.0, min=0.0, max=1.0)
x_sample = 255.0 * rearrange(x_sample[0].cpu().numpy(), "c h w -> h w c")
x_sample = x_sample.astype(np.uint8)
@ -469,7 +543,7 @@ def do_mk_img(req: Request):
return_orig_img = not has_filters or not req.show_only_filtered_image
if stop_processing:
if thread_data.stop_processing:
return_orig_img = True
if req.save_to_disk_path is not None:
@ -489,7 +563,7 @@ def do_mk_img(req: Request):
del img
if has_filters and not stop_processing:
if has_filters and not thread_data.stop_processing:
filters_applied = []
if req.use_face_correction:
x_sample = apply_filters('gfpgan', x_sample)
@ -514,7 +588,7 @@ def do_mk_img(req: Request):
move_fs_to_cpu()
gc()
del x_samples, x_samples_ddim, x_sample
print("memory_final = ", torch.cuda.memory_allocated() / 1e6)
print(f'memory_final = {round(torch.cuda.memory_allocated(thread_data.device) / 1e6, 2)}Mo')
print('Task completed')
@ -527,7 +601,7 @@ def save_image(img, img_out_path):
print('could not save the file', traceback.format_exc())
def save_metadata(meta_out_path, req, prompt, opt_seed):
metadata = f"""{prompt}
metadata = f'''{prompt}
Width: {req.width}
Height: {req.height}
Seed: {opt_seed}
@ -538,8 +612,8 @@ Use Face Correction: {req.use_face_correction}
Use Upscaling: {req.use_upscale}
Sampler: {req.sampler}
Negative Prompt: {req.negative_prompt}
Stable Diffusion Model: {req.use_stable_diffusion_model + '.ckpt'}
"""
Stable Diffusion model: {req.use_stable_diffusion_model + '.ckpt'}
'''
try:
with open(meta_out_path, 'w') as f:
f.write(metadata)
@ -549,16 +623,19 @@ Stable Diffusion Model: {req.use_stable_diffusion_model + '.ckpt'}
def _txt2img(opt_W, opt_H, opt_n_samples, opt_ddim_steps, opt_scale, start_code, opt_C, opt_f, opt_ddim_eta, c, uc, opt_seed, img_callback, mask, sampler_name):
shape = [opt_n_samples, opt_C, opt_H // opt_f, opt_W // opt_f]
if device != "cpu":
mem = torch.cuda.memory_allocated() / 1e6
modelCS.to("cpu")
while torch.cuda.memory_allocated() / 1e6 >= mem:
if thread_data.device != "cpu":
mem = torch.cuda.memory_allocated(thread_data.device) / 1e6
print('Device:', thread_data.device, 'CS_Model, Memory transfer starting. Memory Used:', round(mem, 2), 'Mo')
thread_data.modelCS.to("cpu")
while torch.cuda.memory_allocated(thread_data.device) / 1e6 >= mem and mem > 0:
print('Device:', thread_data.device, 'Waiting Memory transfer. Memory Used:', round(mem, 2), 'Mo')
time.sleep(1)
print('Transfered', round(mem - torch.cuda.memory_allocated(thread_data.device) / 1e6, 2), 'Mo')
if sampler_name == 'ddim':
model.make_schedule(ddim_num_steps=opt_ddim_steps, ddim_eta=opt_ddim_eta, verbose=False)
thread_data.model.make_schedule(ddim_num_steps=opt_ddim_steps, ddim_eta=opt_ddim_eta, verbose=False)
samples_ddim = model.sample(
samples_ddim = thread_data.model.sample(
S=opt_ddim_steps,
conditioning=c,
seed=opt_seed,
@ -572,14 +649,13 @@ def _txt2img(opt_W, opt_H, opt_n_samples, opt_ddim_steps, opt_scale, start_code,
mask=mask,
sampler = sampler_name,
)
yield from samples_ddim
def _img2img(init_latent, t_enc, batch_size, opt_scale, c, uc, opt_ddim_steps, opt_ddim_eta, opt_seed, img_callback, mask):
# encode (scaled latent)
z_enc = model.stochastic_encode(
z_enc = thread_data.model.stochastic_encode(
init_latent,
torch.tensor([t_enc] * batch_size).to(device),
torch.tensor([t_enc] * batch_size).to(thread_data.device),
opt_seed,
opt_ddim_eta,
opt_ddim_steps,
@ -587,7 +663,7 @@ def _img2img(init_latent, t_enc, batch_size, opt_scale, c, uc, opt_ddim_steps, o
x_T = None if mask is None else init_latent
# decode it
samples_ddim = model.sample(
samples_ddim = thread_data.model.sample(
t_enc,
c,
z_enc,
@ -602,16 +678,19 @@ def _img2img(init_latent, t_enc, batch_size, opt_scale, c, uc, opt_ddim_steps, o
yield from samples_ddim
def move_fs_to_cpu():
if device != "cpu":
mem = torch.cuda.memory_allocated() / 1e6
modelFS.to("cpu")
while torch.cuda.memory_allocated() / 1e6 >= mem:
if thread_data.device != "cpu":
mem = torch.cuda.memory_allocated(thread_data.device) / 1e6
print('Device:', thread_data.device, 'FS_Model, Memory transfer starting. Memory Used:', round(mem, 2), 'Mo')
thread_data.modelFS.to("cpu")
while torch.cuda.memory_allocated(thread_data.device) / 1e6 >= mem and mem > 0:
print('Device:', thread_data.device, 'Waiting for Memory transfer. Memory Used:', round(mem, 2), 'Mo')
time.sleep(1)
print('Transfered', round(mem - torch.cuda.memory_allocated(thread_data.device) / 1e6, 2), 'Mo')
def gc():
if device == 'cpu':
#gc.collect()
if thread_data.device == 'cpu':
return
torch.cuda.empty_cache()
torch.cuda.ipc_collect()
@ -621,7 +700,6 @@ def chunk(it, size):
it = iter(it)
return iter(lambda: tuple(islice(it, size)), ())
def load_model_from_config(ckpt, verbose=False):
print(f"Loading model from {ckpt}")
pl_sd = torch.load(ckpt, map_location="cpu")