Backend changes to support stopping a task mid-way. Uses a custom patch for the stable-diffusion codebase, to make it call a callback for DDIM

This commit is contained in:
cmdr2
2022-09-13 19:59:41 +05:30
parent 9ec2010ac2
commit e59c66ae26
5 changed files with 88 additions and 7 deletions

View File

@ -35,6 +35,7 @@ from io import BytesIO
# local
session_id = str(uuid.uuid4())[-8:]
stop_processing = False
ckpt_file = None
gfpgan_file = None
@ -185,6 +186,9 @@ def load_model_real_esrgan(real_esrgan_to_use):
def mk_img(req: Request):
global modelFS, device
global model_gfpgan, model_real_esrgan
global stop_processing
stop_processing = False
res = Response()
res.images = []
@ -320,11 +324,26 @@ def mk_img(req: Request):
else:
c = modelCS.get_learned_conditioning(prompts)
partial_x_samples = None
def img_callback(x_samples, i):
nonlocal partial_x_samples
partial_x_samples = x_samples
if stop_processing:
raise UserInitiatedStop("User requested that we stop processing")
# run the handler
if handler == _txt2img:
x_samples = _txt2img(opt_W, opt_H, opt_n_samples, opt_ddim_steps, opt_scale, None, opt_C, opt_f, opt_ddim_eta, c, uc, opt_seed)
else:
x_samples = _img2img(init_latent, t_enc, batch_size, opt_scale, c, uc, opt_ddim_steps, opt_ddim_eta, opt_seed)
try:
if handler == _txt2img:
x_samples = _txt2img(opt_W, opt_H, opt_n_samples, opt_ddim_steps, opt_scale, None, opt_C, opt_f, opt_ddim_eta, c, uc, opt_seed, img_callback)
else:
x_samples = _img2img(init_latent, t_enc, batch_size, opt_scale, c, uc, opt_ddim_steps, opt_ddim_eta, opt_seed, img_callback)
except UserInitiatedStop:
if partial_x_samples is None:
continue
x_samples = partial_x_samples
modelFS.to(device)
@ -411,7 +430,7 @@ def save_metadata(meta_out_path, prompts, opt_seed, opt_W, opt_H, opt_ddim_steps
except:
print('could not save the file', traceback.format_exc())
def _txt2img(opt_W, opt_H, opt_n_samples, opt_ddim_steps, opt_scale, start_code, opt_C, opt_f, opt_ddim_eta, c, uc, opt_seed):
def _txt2img(opt_W, opt_H, opt_n_samples, opt_ddim_steps, opt_scale, start_code, opt_C, opt_f, opt_ddim_eta, c, uc, opt_seed, img_callback):
shape = [opt_n_samples, opt_C, opt_H // opt_f, opt_W // opt_f]
if device != "cpu":
@ -430,12 +449,13 @@ def _txt2img(opt_W, opt_H, opt_n_samples, opt_ddim_steps, opt_scale, start_code,
unconditional_conditioning=uc,
eta=opt_ddim_eta,
x_T=start_code,
img_callback=img_callback,
sampler = 'plms',
)
return samples_ddim
def _img2img(init_latent, t_enc, batch_size, opt_scale, c, uc, opt_ddim_steps, opt_ddim_eta, opt_seed):
def _img2img(init_latent, t_enc, batch_size, opt_scale, c, uc, opt_ddim_steps, opt_ddim_eta, opt_seed, img_callback):
# encode (scaled latent)
z_enc = model.stochastic_encode(
init_latent,
@ -451,6 +471,7 @@ def _img2img(init_latent, t_enc, batch_size, opt_scale, c, uc, opt_ddim_steps, o
z_enc,
unconditional_guidance_scale=opt_scale,
unconditional_conditioning=uc,
img_callback=img_callback,
sampler = 'ddim'
)
@ -479,6 +500,8 @@ def load_model_from_config(ckpt, verbose=False):
return sd
# utils
class UserInitiatedStop(Exception):
pass
def load_img(img_str):
image = base64_str_to_img(img_str).convert("RGB")