mirror of
https://github.com/easydiffusion/easydiffusion.git
synced 2025-08-10 16:17:45 +02:00
New samplers for txt2img: "ddim", "plms", "heun", "euler", "euler_a", "dpm2", "dpm2_a", "lms"
This commit is contained in:
@ -275,6 +275,7 @@ def do_mk_img(req: Request):
|
||||
opt_use_upscale = req.use_upscale
|
||||
opt_show_only_filtered = req.show_only_filtered_image
|
||||
opt_format = 'png'
|
||||
opt_sampler_name = req.sampler
|
||||
|
||||
print(req.to_string(), '\n device', device)
|
||||
|
||||
@ -399,12 +400,11 @@ def do_mk_img(req: Request):
|
||||
# run the handler
|
||||
try:
|
||||
if handler == _txt2img:
|
||||
x_samples = _txt2img(opt_W, opt_H, opt_n_samples, opt_ddim_steps, opt_scale, None, opt_C, opt_f, opt_ddim_eta, c, uc, opt_seed, img_callback, req.stream_progress_updates, mask)
|
||||
x_samples = _txt2img(opt_W, opt_H, opt_n_samples, opt_ddim_steps, opt_scale, None, opt_C, opt_f, opt_ddim_eta, c, uc, opt_seed, img_callback, mask, opt_sampler_name)
|
||||
else:
|
||||
x_samples = _img2img(init_latent, t_enc, batch_size, opt_scale, c, uc, opt_ddim_steps, opt_ddim_eta, opt_seed, img_callback, req.stream_progress_updates, mask)
|
||||
x_samples = _img2img(init_latent, t_enc, batch_size, opt_scale, c, uc, opt_ddim_steps, opt_ddim_eta, opt_seed, img_callback, mask)
|
||||
|
||||
if req.stream_progress_updates:
|
||||
yield from x_samples
|
||||
yield from x_samples
|
||||
|
||||
x_samples = partial_x_samples
|
||||
except UserInitiatedStop:
|
||||
@ -443,7 +443,7 @@ def do_mk_img(req: Request):
|
||||
if return_orig_img:
|
||||
save_image(img, img_out_path)
|
||||
|
||||
save_metadata(meta_out_path, prompts, opt_seed, opt_W, opt_H, opt_ddim_steps, opt_scale, opt_strength, opt_use_face_correction, opt_use_upscale)
|
||||
save_metadata(meta_out_path, prompts, opt_seed, opt_W, opt_H, opt_ddim_steps, opt_scale, opt_strength, opt_use_face_correction, opt_use_upscale, opt_sampler_name)
|
||||
|
||||
if return_orig_img:
|
||||
img_data = img_to_base64_str(img)
|
||||
@ -496,10 +496,7 @@ def do_mk_img(req: Request):
|
||||
|
||||
print('Task completed')
|
||||
|
||||
if req.stream_progress_updates:
|
||||
yield json.dumps(res.json())
|
||||
else:
|
||||
return res
|
||||
yield json.dumps(res.json())
|
||||
|
||||
def save_image(img, img_out_path):
|
||||
try:
|
||||
@ -507,8 +504,8 @@ def save_image(img, img_out_path):
|
||||
except:
|
||||
print('could not save the file', traceback.format_exc())
|
||||
|
||||
def save_metadata(meta_out_path, prompts, opt_seed, opt_W, opt_H, opt_ddim_steps, opt_scale, opt_prompt_strength, opt_correct_face, opt_upscale):
|
||||
metadata = f"{prompts[0]}\nWidth: {opt_W}\nHeight: {opt_H}\nSeed: {opt_seed}\nSteps: {opt_ddim_steps}\nGuidance Scale: {opt_scale}\nPrompt Strength: {opt_prompt_strength}\nUse Face Correction: {opt_correct_face}\nUse Upscaling: {opt_upscale}"
|
||||
def save_metadata(meta_out_path, prompts, opt_seed, opt_W, opt_H, opt_ddim_steps, opt_scale, opt_prompt_strength, opt_correct_face, opt_upscale, sampler_name):
|
||||
metadata = f"{prompts[0]}\nWidth: {opt_W}\nHeight: {opt_H}\nSeed: {opt_seed}\nSteps: {opt_ddim_steps}\nGuidance Scale: {opt_scale}\nPrompt Strength: {opt_prompt_strength}\nUse Face Correction: {opt_correct_face}\nUse Upscaling: {opt_upscale}\nSampler: {sampler_name}"
|
||||
|
||||
try:
|
||||
with open(meta_out_path, 'w') as f:
|
||||
@ -516,7 +513,7 @@ def save_metadata(meta_out_path, prompts, opt_seed, opt_W, opt_H, opt_ddim_steps
|
||||
except:
|
||||
print('could not save the file', traceback.format_exc())
|
||||
|
||||
def _txt2img(opt_W, opt_H, opt_n_samples, opt_ddim_steps, opt_scale, start_code, opt_C, opt_f, opt_ddim_eta, c, uc, opt_seed, img_callback, streaming_callbacks, mask):
|
||||
def _txt2img(opt_W, opt_H, opt_n_samples, opt_ddim_steps, opt_scale, start_code, opt_C, opt_f, opt_ddim_eta, c, uc, opt_seed, img_callback, mask, sampler_name):
|
||||
shape = [opt_n_samples, opt_C, opt_H // opt_f, opt_W // opt_f]
|
||||
|
||||
if device != "cpu":
|
||||
@ -536,17 +533,13 @@ def _txt2img(opt_W, opt_H, opt_n_samples, opt_ddim_steps, opt_scale, start_code,
|
||||
eta=opt_ddim_eta,
|
||||
x_T=start_code,
|
||||
img_callback=img_callback,
|
||||
streaming_callbacks=streaming_callbacks,
|
||||
mask=mask,
|
||||
sampler = 'plms',
|
||||
sampler = sampler_name,
|
||||
)
|
||||
|
||||
if streaming_callbacks:
|
||||
yield from samples_ddim
|
||||
else:
|
||||
return samples_ddim
|
||||
yield from samples_ddim
|
||||
|
||||
def _img2img(init_latent, t_enc, batch_size, opt_scale, c, uc, opt_ddim_steps, opt_ddim_eta, opt_seed, img_callback, streaming_callbacks, mask):
|
||||
def _img2img(init_latent, t_enc, batch_size, opt_scale, c, uc, opt_ddim_steps, opt_ddim_eta, opt_seed, img_callback, mask):
|
||||
# encode (scaled latent)
|
||||
z_enc = model.stochastic_encode(
|
||||
init_latent,
|
||||
@ -565,16 +558,12 @@ def _img2img(init_latent, t_enc, batch_size, opt_scale, c, uc, opt_ddim_steps, o
|
||||
unconditional_guidance_scale=opt_scale,
|
||||
unconditional_conditioning=uc,
|
||||
img_callback=img_callback,
|
||||
streaming_callbacks=streaming_callbacks,
|
||||
mask=mask,
|
||||
x_T=x_T,
|
||||
sampler = 'ddim'
|
||||
)
|
||||
|
||||
if streaming_callbacks:
|
||||
yield from samples_ddim
|
||||
else:
|
||||
return samples_ddim
|
||||
yield from samples_ddim
|
||||
|
||||
def move_fs_to_cpu():
|
||||
if device != "cpu":
|
||||
|
Reference in New Issue
Block a user