Backend changes to support stopping a task mid-way. Uses a custom patch for the stable-diffusion codebase, to make it call a callback for DDIM

This commit is contained in:
cmdr2 2022-09-13 19:59:41 +05:30
parent 9ec2010ac2
commit e59c66ae26
5 changed files with 88 additions and 7 deletions

View File

@ -17,6 +17,8 @@
@call git pull
@call git checkout d154155d4c0b43e13ec1f00eb72b7ff9d522fcf9
@call git apply ..\ui\sd_internal\ddim_callback.patch
@cd ..
) else (
@echo. & echo "Downloading Stable Diffusion.." & echo.
@ -31,6 +33,9 @@
@cd stable-diffusion
@call git checkout d154155d4c0b43e13ec1f00eb72b7ff9d522fcf9
@call git apply ..\ui\sd_internal\ddim_callback.patch
@cd ..
)

View File

@ -18,6 +18,8 @@ if [ -e "scripts/install_status.txt" ] && [ `grep -c sd_git_cloned scripts/insta
git pull
git checkout d154155d4c0b43e13ec1f00eb72b7ff9d522fcf9
git apply ../ui/sd_internal/ddim_callback.patch
cd ..
else
printf "\n\nDownloading Stable Diffusion..\n\n"
@ -32,6 +34,9 @@ else
cd stable-diffusion
git checkout d154155d4c0b43e13ec1f00eb72b7ff9d522fcf9
git apply ../ui/sd_internal/ddim_callback.patch
cd ..
fi

View File

@ -0,0 +1,34 @@
diff --git a/optimizedSD/ddpm.py b/optimizedSD/ddpm.py
index dcf7901..1f99adc 100644
--- a/optimizedSD/ddpm.py
+++ b/optimizedSD/ddpm.py
@@ -528,7 +528,8 @@ class UNet(DDPM):
elif sampler == "ddim":
samples = self.ddim_sampling(x_latent, conditioning, S, unconditional_guidance_scale=unconditional_guidance_scale,
unconditional_conditioning=unconditional_conditioning,
- mask = mask,init_latent=x_T,use_original_steps=False)
+ mask = mask,init_latent=x_T,use_original_steps=False,
+ callback=callback, img_callback=img_callback)
# elif sampler == "euler":
# cvd = CompVisDenoiser(self.alphas_cumprod)
@@ -687,7 +688,8 @@ class UNet(DDPM):
@torch.no_grad()
def ddim_sampling(self, x_latent, cond, t_start, unconditional_guidance_scale=1.0, unconditional_conditioning=None,
- mask = None,init_latent=None,use_original_steps=False):
+ mask = None,init_latent=None,use_original_steps=False,
+ callback=None, img_callback=None):
timesteps = self.ddim_timesteps
timesteps = timesteps[:t_start]
@@ -710,6 +712,9 @@ class UNet(DDPM):
x_dec = self.p_sample_ddim(x_dec, cond, ts, index=index, use_original_steps=use_original_steps,
unconditional_guidance_scale=unconditional_guidance_scale,
unconditional_conditioning=unconditional_conditioning)
+
+ if callback: callback(i)
+ if img_callback: img_callback(x_dec, i)
if mask is not None:
return x0 * mask + (1. - mask) * x_dec

View File

@ -35,6 +35,7 @@ from io import BytesIO
# local
session_id = str(uuid.uuid4())[-8:]
stop_processing = False
ckpt_file = None
gfpgan_file = None
@ -185,6 +186,9 @@ def load_model_real_esrgan(real_esrgan_to_use):
def mk_img(req: Request):
global modelFS, device
global model_gfpgan, model_real_esrgan
global stop_processing
stop_processing = False
res = Response()
res.images = []
@ -320,11 +324,26 @@ def mk_img(req: Request):
else:
c = modelCS.get_learned_conditioning(prompts)
partial_x_samples = None
def img_callback(x_samples, i):
nonlocal partial_x_samples
partial_x_samples = x_samples
if stop_processing:
raise UserInitiatedStop("User requested that we stop processing")
# run the handler
if handler == _txt2img:
x_samples = _txt2img(opt_W, opt_H, opt_n_samples, opt_ddim_steps, opt_scale, None, opt_C, opt_f, opt_ddim_eta, c, uc, opt_seed)
else:
x_samples = _img2img(init_latent, t_enc, batch_size, opt_scale, c, uc, opt_ddim_steps, opt_ddim_eta, opt_seed)
try:
if handler == _txt2img:
x_samples = _txt2img(opt_W, opt_H, opt_n_samples, opt_ddim_steps, opt_scale, None, opt_C, opt_f, opt_ddim_eta, c, uc, opt_seed, img_callback)
else:
x_samples = _img2img(init_latent, t_enc, batch_size, opt_scale, c, uc, opt_ddim_steps, opt_ddim_eta, opt_seed, img_callback)
except UserInitiatedStop:
if partial_x_samples is None:
continue
x_samples = partial_x_samples
modelFS.to(device)
@ -411,7 +430,7 @@ def save_metadata(meta_out_path, prompts, opt_seed, opt_W, opt_H, opt_ddim_steps
except:
print('could not save the file', traceback.format_exc())
def _txt2img(opt_W, opt_H, opt_n_samples, opt_ddim_steps, opt_scale, start_code, opt_C, opt_f, opt_ddim_eta, c, uc, opt_seed):
def _txt2img(opt_W, opt_H, opt_n_samples, opt_ddim_steps, opt_scale, start_code, opt_C, opt_f, opt_ddim_eta, c, uc, opt_seed, img_callback):
shape = [opt_n_samples, opt_C, opt_H // opt_f, opt_W // opt_f]
if device != "cpu":
@ -430,12 +449,13 @@ def _txt2img(opt_W, opt_H, opt_n_samples, opt_ddim_steps, opt_scale, start_code,
unconditional_conditioning=uc,
eta=opt_ddim_eta,
x_T=start_code,
img_callback=img_callback,
sampler = 'plms',
)
return samples_ddim
def _img2img(init_latent, t_enc, batch_size, opt_scale, c, uc, opt_ddim_steps, opt_ddim_eta, opt_seed):
def _img2img(init_latent, t_enc, batch_size, opt_scale, c, uc, opt_ddim_steps, opt_ddim_eta, opt_seed, img_callback):
# encode (scaled latent)
z_enc = model.stochastic_encode(
init_latent,
@ -451,6 +471,7 @@ def _img2img(init_latent, t_enc, batch_size, opt_scale, c, uc, opt_ddim_steps, o
z_enc,
unconditional_guidance_scale=opt_scale,
unconditional_conditioning=uc,
img_callback=img_callback,
sampler = 'ddim'
)
@ -479,6 +500,8 @@ def load_model_from_config(ckpt, verbose=False):
return sd
# utils
class UserInitiatedStop(Exception):
pass
def load_img(img_str):
image = base64_str_to_img(img_str).convert("RGB")

View File

@ -83,7 +83,7 @@ async def ping():
return HTTPException(status_code=500, detail=str(e))
@app.post('/image')
async def image(req : ImageRequest):
def image(req : ImageRequest):
from sd_internal import runtime
r = Request()
@ -114,6 +114,20 @@ async def image(req : ImageRequest):
print(traceback.format_exc())
return HTTPException(status_code=500, detail=str(e))
@app.get('/image/stop')
def stop():
try:
if model_is_loading:
return {'ERROR'}
from sd_internal import runtime
runtime.stop_processing = True
return {'OK'}
except Exception as e:
print(traceback.format_exc())
return HTTPException(status_code=500, detail=str(e))
@app.post('/app_config')
async def setAppConfig(req : SetAppConfigRequest):
try: