diff --git a/scripts/on_sd_start.bat b/scripts/on_sd_start.bat index 9a0357b4..0361c51b 100644 --- a/scripts/on_sd_start.bat +++ b/scripts/on_sd_start.bat @@ -17,6 +17,8 @@ @call git pull @call git checkout d154155d4c0b43e13ec1f00eb72b7ff9d522fcf9 + @call git apply ..\ui\sd_internal\ddim_callback.patch + @cd .. ) else ( @echo. & echo "Downloading Stable Diffusion.." & echo. @@ -31,6 +33,9 @@ @cd stable-diffusion @call git checkout d154155d4c0b43e13ec1f00eb72b7ff9d522fcf9 + + @call git apply ..\ui\sd_internal\ddim_callback.patch + @cd .. ) diff --git a/scripts/on_sd_start.sh b/scripts/on_sd_start.sh index cff3ddcb..b999bc63 100755 --- a/scripts/on_sd_start.sh +++ b/scripts/on_sd_start.sh @@ -18,6 +18,8 @@ if [ -e "scripts/install_status.txt" ] && [ `grep -c sd_git_cloned scripts/insta git pull git checkout d154155d4c0b43e13ec1f00eb72b7ff9d522fcf9 + git apply ../ui/sd_internal/ddim_callback.patch + cd .. else printf "\n\nDownloading Stable Diffusion..\n\n" @@ -32,6 +34,9 @@ else cd stable-diffusion git checkout d154155d4c0b43e13ec1f00eb72b7ff9d522fcf9 + + git apply ../ui/sd_internal/ddim_callback.patch + cd .. fi diff --git a/ui/sd_internal/ddim_callback.patch b/ui/sd_internal/ddim_callback.patch new file mode 100644 index 00000000..685c2ce5 --- /dev/null +++ b/ui/sd_internal/ddim_callback.patch @@ -0,0 +1,34 @@ +diff --git a/optimizedSD/ddpm.py b/optimizedSD/ddpm.py +index dcf7901..1f99adc 100644 +--- a/optimizedSD/ddpm.py ++++ b/optimizedSD/ddpm.py +@@ -528,7 +528,8 @@ class UNet(DDPM): + elif sampler == "ddim": + samples = self.ddim_sampling(x_latent, conditioning, S, unconditional_guidance_scale=unconditional_guidance_scale, + unconditional_conditioning=unconditional_conditioning, +- mask = mask,init_latent=x_T,use_original_steps=False) ++ mask = mask,init_latent=x_T,use_original_steps=False, ++ callback=callback, img_callback=img_callback) + + # elif sampler == "euler": + # cvd = CompVisDenoiser(self.alphas_cumprod) +@@ -687,7 +688,8 @@ class UNet(DDPM): + + @torch.no_grad() + def ddim_sampling(self, x_latent, cond, t_start, unconditional_guidance_scale=1.0, unconditional_conditioning=None, +- mask = None,init_latent=None,use_original_steps=False): ++ mask = None,init_latent=None,use_original_steps=False, ++ callback=None, img_callback=None): + + timesteps = self.ddim_timesteps + timesteps = timesteps[:t_start] +@@ -710,6 +712,9 @@ class UNet(DDPM): + x_dec = self.p_sample_ddim(x_dec, cond, ts, index=index, use_original_steps=use_original_steps, + unconditional_guidance_scale=unconditional_guidance_scale, + unconditional_conditioning=unconditional_conditioning) ++ ++ if callback: callback(i) ++ if img_callback: img_callback(x_dec, i) + + if mask is not None: + return x0 * mask + (1. - mask) * x_dec diff --git a/ui/sd_internal/runtime.py b/ui/sd_internal/runtime.py index 90661961..5fe16ab5 100644 --- a/ui/sd_internal/runtime.py +++ b/ui/sd_internal/runtime.py @@ -35,6 +35,7 @@ from io import BytesIO # local session_id = str(uuid.uuid4())[-8:] +stop_processing = False ckpt_file = None gfpgan_file = None @@ -185,6 +186,9 @@ def load_model_real_esrgan(real_esrgan_to_use): def mk_img(req: Request): global modelFS, device global model_gfpgan, model_real_esrgan + global stop_processing + + stop_processing = False res = Response() res.images = [] @@ -320,11 +324,26 @@ def mk_img(req: Request): else: c = modelCS.get_learned_conditioning(prompts) + partial_x_samples = None + def img_callback(x_samples, i): + nonlocal partial_x_samples + + partial_x_samples = x_samples + + if stop_processing: + raise UserInitiatedStop("User requested that we stop processing") + # run the handler - if handler == _txt2img: - x_samples = _txt2img(opt_W, opt_H, opt_n_samples, opt_ddim_steps, opt_scale, None, opt_C, opt_f, opt_ddim_eta, c, uc, opt_seed) - else: - x_samples = _img2img(init_latent, t_enc, batch_size, opt_scale, c, uc, opt_ddim_steps, opt_ddim_eta, opt_seed) + try: + if handler == _txt2img: + x_samples = _txt2img(opt_W, opt_H, opt_n_samples, opt_ddim_steps, opt_scale, None, opt_C, opt_f, opt_ddim_eta, c, uc, opt_seed, img_callback) + else: + x_samples = _img2img(init_latent, t_enc, batch_size, opt_scale, c, uc, opt_ddim_steps, opt_ddim_eta, opt_seed, img_callback) + except UserInitiatedStop: + if partial_x_samples is None: + continue + + x_samples = partial_x_samples modelFS.to(device) @@ -411,7 +430,7 @@ def save_metadata(meta_out_path, prompts, opt_seed, opt_W, opt_H, opt_ddim_steps except: print('could not save the file', traceback.format_exc()) -def _txt2img(opt_W, opt_H, opt_n_samples, opt_ddim_steps, opt_scale, start_code, opt_C, opt_f, opt_ddim_eta, c, uc, opt_seed): +def _txt2img(opt_W, opt_H, opt_n_samples, opt_ddim_steps, opt_scale, start_code, opt_C, opt_f, opt_ddim_eta, c, uc, opt_seed, img_callback): shape = [opt_n_samples, opt_C, opt_H // opt_f, opt_W // opt_f] if device != "cpu": @@ -430,12 +449,13 @@ def _txt2img(opt_W, opt_H, opt_n_samples, opt_ddim_steps, opt_scale, start_code, unconditional_conditioning=uc, eta=opt_ddim_eta, x_T=start_code, + img_callback=img_callback, sampler = 'plms', ) return samples_ddim -def _img2img(init_latent, t_enc, batch_size, opt_scale, c, uc, opt_ddim_steps, opt_ddim_eta, opt_seed): +def _img2img(init_latent, t_enc, batch_size, opt_scale, c, uc, opt_ddim_steps, opt_ddim_eta, opt_seed, img_callback): # encode (scaled latent) z_enc = model.stochastic_encode( init_latent, @@ -451,6 +471,7 @@ def _img2img(init_latent, t_enc, batch_size, opt_scale, c, uc, opt_ddim_steps, o z_enc, unconditional_guidance_scale=opt_scale, unconditional_conditioning=uc, + img_callback=img_callback, sampler = 'ddim' ) @@ -479,6 +500,8 @@ def load_model_from_config(ckpt, verbose=False): return sd # utils +class UserInitiatedStop(Exception): + pass def load_img(img_str): image = base64_str_to_img(img_str).convert("RGB") diff --git a/ui/server.py b/ui/server.py index 3fff7a9c..f55c1091 100644 --- a/ui/server.py +++ b/ui/server.py @@ -83,7 +83,7 @@ async def ping(): return HTTPException(status_code=500, detail=str(e)) @app.post('/image') -async def image(req : ImageRequest): +def image(req : ImageRequest): from sd_internal import runtime r = Request() @@ -114,6 +114,20 @@ async def image(req : ImageRequest): print(traceback.format_exc()) return HTTPException(status_code=500, detail=str(e)) +@app.get('/image/stop') +def stop(): + try: + if model_is_loading: + return {'ERROR'} + + from sd_internal import runtime + runtime.stop_processing = True + + return {'OK'} + except Exception as e: + print(traceback.format_exc()) + return HTTPException(status_code=500, detail=str(e)) + @app.post('/app_config') async def setAppConfig(req : SetAppConfigRequest): try: