mirror of
https://github.com/easydiffusion/easydiffusion.git
synced 2025-08-13 09:47:16 +02:00
Backend changes to support stopping a task mid-way. Uses a custom patch for the stable-diffusion codebase, to make it call a callback for DDIM
This commit is contained in:
@ -35,6 +35,7 @@ from io import BytesIO
|
||||
|
||||
# local
|
||||
session_id = str(uuid.uuid4())[-8:]
|
||||
stop_processing = False
|
||||
|
||||
ckpt_file = None
|
||||
gfpgan_file = None
|
||||
@ -185,6 +186,9 @@ def load_model_real_esrgan(real_esrgan_to_use):
|
||||
def mk_img(req: Request):
|
||||
global modelFS, device
|
||||
global model_gfpgan, model_real_esrgan
|
||||
global stop_processing
|
||||
|
||||
stop_processing = False
|
||||
|
||||
res = Response()
|
||||
res.images = []
|
||||
@ -320,11 +324,26 @@ def mk_img(req: Request):
|
||||
else:
|
||||
c = modelCS.get_learned_conditioning(prompts)
|
||||
|
||||
partial_x_samples = None
|
||||
def img_callback(x_samples, i):
|
||||
nonlocal partial_x_samples
|
||||
|
||||
partial_x_samples = x_samples
|
||||
|
||||
if stop_processing:
|
||||
raise UserInitiatedStop("User requested that we stop processing")
|
||||
|
||||
# run the handler
|
||||
if handler == _txt2img:
|
||||
x_samples = _txt2img(opt_W, opt_H, opt_n_samples, opt_ddim_steps, opt_scale, None, opt_C, opt_f, opt_ddim_eta, c, uc, opt_seed)
|
||||
else:
|
||||
x_samples = _img2img(init_latent, t_enc, batch_size, opt_scale, c, uc, opt_ddim_steps, opt_ddim_eta, opt_seed)
|
||||
try:
|
||||
if handler == _txt2img:
|
||||
x_samples = _txt2img(opt_W, opt_H, opt_n_samples, opt_ddim_steps, opt_scale, None, opt_C, opt_f, opt_ddim_eta, c, uc, opt_seed, img_callback)
|
||||
else:
|
||||
x_samples = _img2img(init_latent, t_enc, batch_size, opt_scale, c, uc, opt_ddim_steps, opt_ddim_eta, opt_seed, img_callback)
|
||||
except UserInitiatedStop:
|
||||
if partial_x_samples is None:
|
||||
continue
|
||||
|
||||
x_samples = partial_x_samples
|
||||
|
||||
modelFS.to(device)
|
||||
|
||||
@ -411,7 +430,7 @@ def save_metadata(meta_out_path, prompts, opt_seed, opt_W, opt_H, opt_ddim_steps
|
||||
except:
|
||||
print('could not save the file', traceback.format_exc())
|
||||
|
||||
def _txt2img(opt_W, opt_H, opt_n_samples, opt_ddim_steps, opt_scale, start_code, opt_C, opt_f, opt_ddim_eta, c, uc, opt_seed):
|
||||
def _txt2img(opt_W, opt_H, opt_n_samples, opt_ddim_steps, opt_scale, start_code, opt_C, opt_f, opt_ddim_eta, c, uc, opt_seed, img_callback):
|
||||
shape = [opt_n_samples, opt_C, opt_H // opt_f, opt_W // opt_f]
|
||||
|
||||
if device != "cpu":
|
||||
@ -430,12 +449,13 @@ def _txt2img(opt_W, opt_H, opt_n_samples, opt_ddim_steps, opt_scale, start_code,
|
||||
unconditional_conditioning=uc,
|
||||
eta=opt_ddim_eta,
|
||||
x_T=start_code,
|
||||
img_callback=img_callback,
|
||||
sampler = 'plms',
|
||||
)
|
||||
|
||||
return samples_ddim
|
||||
|
||||
def _img2img(init_latent, t_enc, batch_size, opt_scale, c, uc, opt_ddim_steps, opt_ddim_eta, opt_seed):
|
||||
def _img2img(init_latent, t_enc, batch_size, opt_scale, c, uc, opt_ddim_steps, opt_ddim_eta, opt_seed, img_callback):
|
||||
# encode (scaled latent)
|
||||
z_enc = model.stochastic_encode(
|
||||
init_latent,
|
||||
@ -451,6 +471,7 @@ def _img2img(init_latent, t_enc, batch_size, opt_scale, c, uc, opt_ddim_steps, o
|
||||
z_enc,
|
||||
unconditional_guidance_scale=opt_scale,
|
||||
unconditional_conditioning=uc,
|
||||
img_callback=img_callback,
|
||||
sampler = 'ddim'
|
||||
)
|
||||
|
||||
@ -479,6 +500,8 @@ def load_model_from_config(ckpt, verbose=False):
|
||||
return sd
|
||||
|
||||
# utils
|
||||
class UserInitiatedStop(Exception):
|
||||
pass
|
||||
|
||||
def load_img(img_str):
|
||||
image = base64_str_to_img(img_str).convert("RGB")
|
||||
|
Reference in New Issue
Block a user