diff --git a/scripts/on_sd_start.bat b/scripts/on_sd_start.bat index 526793fe..d8967e1f 100644 --- a/scripts/on_sd_start.bat +++ b/scripts/on_sd_start.bat @@ -15,7 +15,7 @@ @call git reset --hard @call git pull - @call git checkout d154155d4c0b43e13ec1f00eb72b7ff9d522fcf9 + @call git checkout f6cfebffa752ee11a7b07497b8529d5971de916c @call git apply ..\ui\sd_internal\ddim_callback.patch @@ -32,7 +32,7 @@ ) @cd stable-diffusion - @call git checkout d154155d4c0b43e13ec1f00eb72b7ff9d522fcf9 + @call git checkout f6cfebffa752ee11a7b07497b8529d5971de916c @call git apply ..\ui\sd_internal\ddim_callback.patch diff --git a/ui/index.html b/ui/index.html index d7e08dae..32133ea6 100644 --- a/ui/index.html +++ b/ui/index.html @@ -308,7 +308,7 @@
' + e.stack + '', res) setStatus('request', 'error', 'error') progressBar.style.display = 'none' + res = undefined } if (!res) { diff --git a/ui/sd_internal/ddim_callback.patch b/ui/sd_internal/ddim_callback.patch index 1fbd4380..7e40d614 100644 --- a/ui/sd_internal/ddim_callback.patch +++ b/ui/sd_internal/ddim_callback.patch @@ -1,7 +1,16 @@ diff --git a/optimizedSD/ddpm.py b/optimizedSD/ddpm.py -index dcf7901..4028a70 100644 +index b967b55..75ddd8b 100644 --- a/optimizedSD/ddpm.py +++ b/optimizedSD/ddpm.py +@@ -22,7 +22,7 @@ from ldm.util import exists, default, instantiate_from_config + from ldm.modules.diffusionmodules.util import make_beta_schedule + from ldm.modules.diffusionmodules.util import make_ddim_sampling_parameters, make_ddim_timesteps, noise_like + from ldm.modules.diffusionmodules.util import make_beta_schedule, extract_into_tensor, noise_like +-from samplers import CompVisDenoiser, get_ancestral_step, to_d, append_dims,linear_multistep_coeff ++from .samplers import CompVisDenoiser, get_ancestral_step, to_d, append_dims,linear_multistep_coeff + + def disabled_train(self): + """Overwrite model.train with this function to make sure train/eval mode @@ -485,6 +485,7 @@ class UNet(DDPM): log_every_t=100, unconditional_guidance_scale=1., @@ -25,11 +34,11 @@ index dcf7901..4028a70 100644 + callback=callback, img_callback=img_callback, + streaming_callbacks=streaming_callbacks) - # elif sampler == "euler": - # cvd = CompVisDenoiser(self.alphas_cumprod) -@@ -536,11 +540,15 @@ class UNet(DDPM): - # samples = self.heun_sampling(noise, sig, conditioning, unconditional_conditioning=unconditional_conditioning, - # unconditional_guidance_scale=unconditional_guidance_scale) + elif sampler == "euler": + self.make_schedule(ddim_num_steps=S, ddim_eta=eta, verbose=False) +@@ -555,11 +559,15 @@ class UNet(DDPM): + samples = self.lms_sampling(self.alphas_cumprod,x_latent, S, conditioning, unconditional_conditioning=unconditional_conditioning, + unconditional_guidance_scale=unconditional_guidance_scale) + if streaming_callbacks: # this line needs to be right after the sampling() call + yield from samples @@ -44,7 +53,7 @@ index dcf7901..4028a70 100644 @torch.no_grad() def plms_sampling(self, cond,b, img, -@@ -548,7 +556,8 @@ class UNet(DDPM): +@@ -567,7 +575,8 @@ class UNet(DDPM): callback=None, quantize_denoised=False, mask=None, x0=None, img_callback=None, log_every_t=100, temperature=1., noise_dropout=0., score_corrector=None, corrector_kwargs=None, @@ -54,13 +63,13 @@ index dcf7901..4028a70 100644 device = self.betas.device timesteps = self.ddim_timesteps -@@ -580,10 +589,22 @@ class UNet(DDPM): +@@ -599,10 +608,21 @@ class UNet(DDPM): old_eps.append(e_t) if len(old_eps) >= 4: old_eps.pop(0) - if callback: callback(i) - if img_callback: img_callback(pred_x0, i) - +- - return img + if callback: + if streaming_callbacks: @@ -80,7 +89,7 @@ index dcf7901..4028a70 100644 @torch.no_grad() def p_sample_plms(self, x, c, t, index, repeat_noise=False, use_original_steps=False, quantize_denoised=False, -@@ -687,7 +708,9 @@ class UNet(DDPM): +@@ -706,7 +726,9 @@ class UNet(DDPM): @torch.no_grad() def ddim_sampling(self, x_latent, cond, t_start, unconditional_guidance_scale=1.0, unconditional_conditioning=None, @@ -91,11 +100,10 @@ index dcf7901..4028a70 100644 timesteps = self.ddim_timesteps timesteps = timesteps[:t_start] -@@ -710,11 +733,25 @@ class UNet(DDPM): - x_dec = self.p_sample_ddim(x_dec, cond, ts, index=index, use_original_steps=use_original_steps, +@@ -730,10 +752,24 @@ class UNet(DDPM): unconditional_guidance_scale=unconditional_guidance_scale, unconditional_conditioning=unconditional_conditioning) -+ + + if callback: + if streaming_callbacks: + yield from callback(i) @@ -106,7 +114,7 @@ index dcf7901..4028a70 100644 + yield from img_callback(x_dec, i) + else: + img_callback(x_dec, i) - ++ if mask is not None: - return x0 * mask + (1. - mask) * x_dec + x_dec = x0 * mask + (1. - mask) * x_dec @@ -119,3 +127,16 @@ index dcf7901..4028a70 100644 @torch.no_grad() +diff --git a/optimizedSD/openaimodelSplit.py b/optimizedSD/openaimodelSplit.py +index abc3098..7a32ffe 100644 +--- a/optimizedSD/openaimodelSplit.py ++++ b/optimizedSD/openaimodelSplit.py +@@ -13,7 +13,7 @@ from ldm.modules.diffusionmodules.util import ( + normalization, + timestep_embedding, + ) +-from splitAttention import SpatialTransformer ++from .splitAttention import SpatialTransformer + + + class AttentionPool2d(nn.Module): diff --git a/ui/sd_internal/runtime.py b/ui/sd_internal/runtime.py index ed6e7c80..9434414e 100644 --- a/ui/sd_internal/runtime.py +++ b/ui/sd_internal/runtime.py @@ -189,8 +189,23 @@ def mk_img(req: Request): try: yield from do_mk_img(req) except Exception as e: + print(traceback.format_exc()) + gc() - raise e + + if device != "cpu": + modelFS.to("cpu") + modelCS.to("cpu") + + model.model1.to("cpu") + model.model2.to("cpu") + + gc() + + yield json.dumps({ + "status": 'failed', + "detail": str(e) + }) def do_mk_img(req: Request): global model, modelCS, modelFS, device @@ -306,11 +321,7 @@ def do_mk_img(req: Request): if device != "cpu" and precision == "autocast": mask = mask.half() - if device != "cpu": - mem = torch.cuda.memory_allocated() / 1e6 - modelFS.to("cpu") - while torch.cuda.memory_allocated() / 1e6 >= mem: - time.sleep(1) + move_fs_to_cpu() assert 0. <= opt_strength <= 1., 'can only work with strength in [0.0, 1.0]' t_enc = int(opt_strength * opt_ddim_steps) @@ -359,7 +370,7 @@ def do_mk_img(req: Request): if req.stream_progress_updates: progress = {"step": i, "total_steps": opt_ddim_steps} - if req.stream_image_progress: + if req.stream_image_progress and i % 5 == 0: partial_images = [] for i in range(batch_size): @@ -478,12 +489,8 @@ def do_mk_img(req: Request): seeds += str(opt_seed) + "," opt_seed += 1 + move_fs_to_cpu() gc() - if device != "cpu": - mem = torch.cuda.memory_allocated() / 1e6 - modelFS.to("cpu") - while torch.cuda.memory_allocated() / 1e6 >= mem: - time.sleep(1) del x_samples, x_samples_ddim, x_sample print("memory_final = ", torch.cuda.memory_allocated() / 1e6) @@ -569,6 +576,13 @@ def _img2img(init_latent, t_enc, batch_size, opt_scale, c, uc, opt_ddim_steps, o else: return samples_ddim +def move_fs_to_cpu(): + if device != "cpu": + mem = torch.cuda.memory_allocated() / 1e6 + modelFS.to("cpu") + while torch.cuda.memory_allocated() / 1e6 >= mem: + time.sleep(1) + def gc(): if device == 'cpu': return diff --git a/ui/server.py b/ui/server.py index 3d7bec53..260e9b31 100644 --- a/ui/server.py +++ b/ui/server.py @@ -139,16 +139,24 @@ def image(req : ImageRequest): r.use_face_correction = req.use_face_correction r.show_only_filtered_image = req.show_only_filtered_image - r.stream_progress_updates = req.stream_progress_updates + r.stream_progress_updates = True # the underlying implementation only supports streaming r.stream_image_progress = req.stream_image_progress try: + if not req.stream_progress_updates: + r.stream_image_progress = False + res = runtime.mk_img(r) - if r.stream_progress_updates: + if req.stream_progress_updates: return StreamingResponse(res, media_type='application/json') - else: - return res.json() + else: # compatibility mode: buffer the streaming responses, and return the last one + last_result = None + + for result in res: + last_result = result + + return json.loads(last_result) except Exception as e: print(traceback.format_exc()) return HTTPException(status_code=500, detail=str(e))