forked from extern/easydiffusion
Backend changes to support stopping a task mid-way. Uses a custom patch for the stable-diffusion codebase, to make it call a callback for DDIM
This commit is contained in:
parent
9ec2010ac2
commit
e59c66ae26
@ -17,6 +17,8 @@
|
||||
@call git pull
|
||||
@call git checkout d154155d4c0b43e13ec1f00eb72b7ff9d522fcf9
|
||||
|
||||
@call git apply ..\ui\sd_internal\ddim_callback.patch
|
||||
|
||||
@cd ..
|
||||
) else (
|
||||
@echo. & echo "Downloading Stable Diffusion.." & echo.
|
||||
@ -31,6 +33,9 @@
|
||||
|
||||
@cd stable-diffusion
|
||||
@call git checkout d154155d4c0b43e13ec1f00eb72b7ff9d522fcf9
|
||||
|
||||
@call git apply ..\ui\sd_internal\ddim_callback.patch
|
||||
|
||||
@cd ..
|
||||
)
|
||||
|
||||
|
@ -18,6 +18,8 @@ if [ -e "scripts/install_status.txt" ] && [ `grep -c sd_git_cloned scripts/insta
|
||||
git pull
|
||||
git checkout d154155d4c0b43e13ec1f00eb72b7ff9d522fcf9
|
||||
|
||||
git apply ../ui/sd_internal/ddim_callback.patch
|
||||
|
||||
cd ..
|
||||
else
|
||||
printf "\n\nDownloading Stable Diffusion..\n\n"
|
||||
@ -32,6 +34,9 @@ else
|
||||
|
||||
cd stable-diffusion
|
||||
git checkout d154155d4c0b43e13ec1f00eb72b7ff9d522fcf9
|
||||
|
||||
git apply ../ui/sd_internal/ddim_callback.patch
|
||||
|
||||
cd ..
|
||||
fi
|
||||
|
||||
|
34
ui/sd_internal/ddim_callback.patch
Normal file
34
ui/sd_internal/ddim_callback.patch
Normal file
@ -0,0 +1,34 @@
|
||||
diff --git a/optimizedSD/ddpm.py b/optimizedSD/ddpm.py
|
||||
index dcf7901..1f99adc 100644
|
||||
--- a/optimizedSD/ddpm.py
|
||||
+++ b/optimizedSD/ddpm.py
|
||||
@@ -528,7 +528,8 @@ class UNet(DDPM):
|
||||
elif sampler == "ddim":
|
||||
samples = self.ddim_sampling(x_latent, conditioning, S, unconditional_guidance_scale=unconditional_guidance_scale,
|
||||
unconditional_conditioning=unconditional_conditioning,
|
||||
- mask = mask,init_latent=x_T,use_original_steps=False)
|
||||
+ mask = mask,init_latent=x_T,use_original_steps=False,
|
||||
+ callback=callback, img_callback=img_callback)
|
||||
|
||||
# elif sampler == "euler":
|
||||
# cvd = CompVisDenoiser(self.alphas_cumprod)
|
||||
@@ -687,7 +688,8 @@ class UNet(DDPM):
|
||||
|
||||
@torch.no_grad()
|
||||
def ddim_sampling(self, x_latent, cond, t_start, unconditional_guidance_scale=1.0, unconditional_conditioning=None,
|
||||
- mask = None,init_latent=None,use_original_steps=False):
|
||||
+ mask = None,init_latent=None,use_original_steps=False,
|
||||
+ callback=None, img_callback=None):
|
||||
|
||||
timesteps = self.ddim_timesteps
|
||||
timesteps = timesteps[:t_start]
|
||||
@@ -710,6 +712,9 @@ class UNet(DDPM):
|
||||
x_dec = self.p_sample_ddim(x_dec, cond, ts, index=index, use_original_steps=use_original_steps,
|
||||
unconditional_guidance_scale=unconditional_guidance_scale,
|
||||
unconditional_conditioning=unconditional_conditioning)
|
||||
+
|
||||
+ if callback: callback(i)
|
||||
+ if img_callback: img_callback(x_dec, i)
|
||||
|
||||
if mask is not None:
|
||||
return x0 * mask + (1. - mask) * x_dec
|
@ -35,6 +35,7 @@ from io import BytesIO
|
||||
|
||||
# local
|
||||
session_id = str(uuid.uuid4())[-8:]
|
||||
stop_processing = False
|
||||
|
||||
ckpt_file = None
|
||||
gfpgan_file = None
|
||||
@ -185,6 +186,9 @@ def load_model_real_esrgan(real_esrgan_to_use):
|
||||
def mk_img(req: Request):
|
||||
global modelFS, device
|
||||
global model_gfpgan, model_real_esrgan
|
||||
global stop_processing
|
||||
|
||||
stop_processing = False
|
||||
|
||||
res = Response()
|
||||
res.images = []
|
||||
@ -320,11 +324,26 @@ def mk_img(req: Request):
|
||||
else:
|
||||
c = modelCS.get_learned_conditioning(prompts)
|
||||
|
||||
partial_x_samples = None
|
||||
def img_callback(x_samples, i):
|
||||
nonlocal partial_x_samples
|
||||
|
||||
partial_x_samples = x_samples
|
||||
|
||||
if stop_processing:
|
||||
raise UserInitiatedStop("User requested that we stop processing")
|
||||
|
||||
# run the handler
|
||||
if handler == _txt2img:
|
||||
x_samples = _txt2img(opt_W, opt_H, opt_n_samples, opt_ddim_steps, opt_scale, None, opt_C, opt_f, opt_ddim_eta, c, uc, opt_seed)
|
||||
else:
|
||||
x_samples = _img2img(init_latent, t_enc, batch_size, opt_scale, c, uc, opt_ddim_steps, opt_ddim_eta, opt_seed)
|
||||
try:
|
||||
if handler == _txt2img:
|
||||
x_samples = _txt2img(opt_W, opt_H, opt_n_samples, opt_ddim_steps, opt_scale, None, opt_C, opt_f, opt_ddim_eta, c, uc, opt_seed, img_callback)
|
||||
else:
|
||||
x_samples = _img2img(init_latent, t_enc, batch_size, opt_scale, c, uc, opt_ddim_steps, opt_ddim_eta, opt_seed, img_callback)
|
||||
except UserInitiatedStop:
|
||||
if partial_x_samples is None:
|
||||
continue
|
||||
|
||||
x_samples = partial_x_samples
|
||||
|
||||
modelFS.to(device)
|
||||
|
||||
@ -411,7 +430,7 @@ def save_metadata(meta_out_path, prompts, opt_seed, opt_W, opt_H, opt_ddim_steps
|
||||
except:
|
||||
print('could not save the file', traceback.format_exc())
|
||||
|
||||
def _txt2img(opt_W, opt_H, opt_n_samples, opt_ddim_steps, opt_scale, start_code, opt_C, opt_f, opt_ddim_eta, c, uc, opt_seed):
|
||||
def _txt2img(opt_W, opt_H, opt_n_samples, opt_ddim_steps, opt_scale, start_code, opt_C, opt_f, opt_ddim_eta, c, uc, opt_seed, img_callback):
|
||||
shape = [opt_n_samples, opt_C, opt_H // opt_f, opt_W // opt_f]
|
||||
|
||||
if device != "cpu":
|
||||
@ -430,12 +449,13 @@ def _txt2img(opt_W, opt_H, opt_n_samples, opt_ddim_steps, opt_scale, start_code,
|
||||
unconditional_conditioning=uc,
|
||||
eta=opt_ddim_eta,
|
||||
x_T=start_code,
|
||||
img_callback=img_callback,
|
||||
sampler = 'plms',
|
||||
)
|
||||
|
||||
return samples_ddim
|
||||
|
||||
def _img2img(init_latent, t_enc, batch_size, opt_scale, c, uc, opt_ddim_steps, opt_ddim_eta, opt_seed):
|
||||
def _img2img(init_latent, t_enc, batch_size, opt_scale, c, uc, opt_ddim_steps, opt_ddim_eta, opt_seed, img_callback):
|
||||
# encode (scaled latent)
|
||||
z_enc = model.stochastic_encode(
|
||||
init_latent,
|
||||
@ -451,6 +471,7 @@ def _img2img(init_latent, t_enc, batch_size, opt_scale, c, uc, opt_ddim_steps, o
|
||||
z_enc,
|
||||
unconditional_guidance_scale=opt_scale,
|
||||
unconditional_conditioning=uc,
|
||||
img_callback=img_callback,
|
||||
sampler = 'ddim'
|
||||
)
|
||||
|
||||
@ -479,6 +500,8 @@ def load_model_from_config(ckpt, verbose=False):
|
||||
return sd
|
||||
|
||||
# utils
|
||||
class UserInitiatedStop(Exception):
|
||||
pass
|
||||
|
||||
def load_img(img_str):
|
||||
image = base64_str_to_img(img_str).convert("RGB")
|
||||
|
16
ui/server.py
16
ui/server.py
@ -83,7 +83,7 @@ async def ping():
|
||||
return HTTPException(status_code=500, detail=str(e))
|
||||
|
||||
@app.post('/image')
|
||||
async def image(req : ImageRequest):
|
||||
def image(req : ImageRequest):
|
||||
from sd_internal import runtime
|
||||
|
||||
r = Request()
|
||||
@ -114,6 +114,20 @@ async def image(req : ImageRequest):
|
||||
print(traceback.format_exc())
|
||||
return HTTPException(status_code=500, detail=str(e))
|
||||
|
||||
@app.get('/image/stop')
|
||||
def stop():
|
||||
try:
|
||||
if model_is_loading:
|
||||
return {'ERROR'}
|
||||
|
||||
from sd_internal import runtime
|
||||
runtime.stop_processing = True
|
||||
|
||||
return {'OK'}
|
||||
except Exception as e:
|
||||
print(traceback.format_exc())
|
||||
return HTTPException(status_code=500, detail=str(e))
|
||||
|
||||
@app.post('/app_config')
|
||||
async def setAppConfig(req : SetAppConfigRequest):
|
||||
try:
|
||||
|
Loading…
Reference in New Issue
Block a user