New samplers for txt2img: "ddim", "plms", "heun", "euler", "euler_a", "dpm2", "dpm2_a", "lms"

This commit is contained in:
cmdr2
2022-09-23 00:19:05 +05:30
parent b0c15bc430
commit 956b3d89db
5 changed files with 291 additions and 96 deletions

View File

@ -275,6 +275,7 @@ def do_mk_img(req: Request):
opt_use_upscale = req.use_upscale
opt_show_only_filtered = req.show_only_filtered_image
opt_format = 'png'
opt_sampler_name = req.sampler
print(req.to_string(), '\n device', device)
@ -399,12 +400,11 @@ def do_mk_img(req: Request):
# run the handler
try:
if handler == _txt2img:
x_samples = _txt2img(opt_W, opt_H, opt_n_samples, opt_ddim_steps, opt_scale, None, opt_C, opt_f, opt_ddim_eta, c, uc, opt_seed, img_callback, req.stream_progress_updates, mask)
x_samples = _txt2img(opt_W, opt_H, opt_n_samples, opt_ddim_steps, opt_scale, None, opt_C, opt_f, opt_ddim_eta, c, uc, opt_seed, img_callback, mask, opt_sampler_name)
else:
x_samples = _img2img(init_latent, t_enc, batch_size, opt_scale, c, uc, opt_ddim_steps, opt_ddim_eta, opt_seed, img_callback, req.stream_progress_updates, mask)
x_samples = _img2img(init_latent, t_enc, batch_size, opt_scale, c, uc, opt_ddim_steps, opt_ddim_eta, opt_seed, img_callback, mask)
if req.stream_progress_updates:
yield from x_samples
yield from x_samples
x_samples = partial_x_samples
except UserInitiatedStop:
@ -443,7 +443,7 @@ def do_mk_img(req: Request):
if return_orig_img:
save_image(img, img_out_path)
save_metadata(meta_out_path, prompts, opt_seed, opt_W, opt_H, opt_ddim_steps, opt_scale, opt_strength, opt_use_face_correction, opt_use_upscale)
save_metadata(meta_out_path, prompts, opt_seed, opt_W, opt_H, opt_ddim_steps, opt_scale, opt_strength, opt_use_face_correction, opt_use_upscale, opt_sampler_name)
if return_orig_img:
img_data = img_to_base64_str(img)
@ -496,10 +496,7 @@ def do_mk_img(req: Request):
print('Task completed')
if req.stream_progress_updates:
yield json.dumps(res.json())
else:
return res
yield json.dumps(res.json())
def save_image(img, img_out_path):
try:
@ -507,8 +504,8 @@ def save_image(img, img_out_path):
except:
print('could not save the file', traceback.format_exc())
def save_metadata(meta_out_path, prompts, opt_seed, opt_W, opt_H, opt_ddim_steps, opt_scale, opt_prompt_strength, opt_correct_face, opt_upscale):
metadata = f"{prompts[0]}\nWidth: {opt_W}\nHeight: {opt_H}\nSeed: {opt_seed}\nSteps: {opt_ddim_steps}\nGuidance Scale: {opt_scale}\nPrompt Strength: {opt_prompt_strength}\nUse Face Correction: {opt_correct_face}\nUse Upscaling: {opt_upscale}"
def save_metadata(meta_out_path, prompts, opt_seed, opt_W, opt_H, opt_ddim_steps, opt_scale, opt_prompt_strength, opt_correct_face, opt_upscale, sampler_name):
metadata = f"{prompts[0]}\nWidth: {opt_W}\nHeight: {opt_H}\nSeed: {opt_seed}\nSteps: {opt_ddim_steps}\nGuidance Scale: {opt_scale}\nPrompt Strength: {opt_prompt_strength}\nUse Face Correction: {opt_correct_face}\nUse Upscaling: {opt_upscale}\nSampler: {sampler_name}"
try:
with open(meta_out_path, 'w') as f:
@ -516,7 +513,7 @@ def save_metadata(meta_out_path, prompts, opt_seed, opt_W, opt_H, opt_ddim_steps
except:
print('could not save the file', traceback.format_exc())
def _txt2img(opt_W, opt_H, opt_n_samples, opt_ddim_steps, opt_scale, start_code, opt_C, opt_f, opt_ddim_eta, c, uc, opt_seed, img_callback, streaming_callbacks, mask):
def _txt2img(opt_W, opt_H, opt_n_samples, opt_ddim_steps, opt_scale, start_code, opt_C, opt_f, opt_ddim_eta, c, uc, opt_seed, img_callback, mask, sampler_name):
shape = [opt_n_samples, opt_C, opt_H // opt_f, opt_W // opt_f]
if device != "cpu":
@ -536,17 +533,13 @@ def _txt2img(opt_W, opt_H, opt_n_samples, opt_ddim_steps, opt_scale, start_code,
eta=opt_ddim_eta,
x_T=start_code,
img_callback=img_callback,
streaming_callbacks=streaming_callbacks,
mask=mask,
sampler = 'plms',
sampler = sampler_name,
)
if streaming_callbacks:
yield from samples_ddim
else:
return samples_ddim
yield from samples_ddim
def _img2img(init_latent, t_enc, batch_size, opt_scale, c, uc, opt_ddim_steps, opt_ddim_eta, opt_seed, img_callback, streaming_callbacks, mask):
def _img2img(init_latent, t_enc, batch_size, opt_scale, c, uc, opt_ddim_steps, opt_ddim_eta, opt_seed, img_callback, mask):
# encode (scaled latent)
z_enc = model.stochastic_encode(
init_latent,
@ -565,16 +558,12 @@ def _img2img(init_latent, t_enc, batch_size, opt_scale, c, uc, opt_ddim_steps, o
unconditional_guidance_scale=opt_scale,
unconditional_conditioning=uc,
img_callback=img_callback,
streaming_callbacks=streaming_callbacks,
mask=mask,
x_T=x_T,
sampler = 'ddim'
)
if streaming_callbacks:
yield from samples_ddim
else:
return samples_ddim
yield from samples_ddim
def move_fs_to_cpu():
if device != "cpu":