Remove the need to use yield in the core loop for streaming results. This also removes the need to patch the Stable Diffusion code, which can be fragile.

cmdr2 2022-11-29 13:03:57 +05:30
parent cb02b5ba18
commit e37be0f954
4 changed files with 77 additions and 96 deletions
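The heart of the change: do_mk_img used to be a generator that yielded JSON progress strings all the way up through mk_img to the task manager loop; it now pushes those strings into a queue.Queue supplied by the caller and simply returns the final result. A minimal, runnable sketch of the pattern this commit moves to (the worker and consume names are illustrative, not from the codebase):

import json
import queue
import threading

def worker(data_queue: queue.Queue):
    # Producer: push progress updates as they happen, then return the result.
    for step in range(3):
        data_queue.put(json.dumps({'step': step, 'total_steps': 3}))
    data_queue.put(json.dumps({'status': 'succeeded'}))
    return {'status': 'succeeded'}

def consume(data_queue: queue.Queue):
    # Consumer: read updates off the queue; there is no generator to drive.
    while True:
        msg = json.loads(data_queue.get())
        print(msg)
        if msg.get('status') in ('succeeded', 'failed'):
            break

q = queue.Queue()
threading.Thread(target=worker, args=(q,)).start()
consume(q)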

View File

@@ -42,13 +42,9 @@ if NOT DEFINED test_sd2 set test_sd2=N
if "%test_sd2%" == "N" (
@call git -c advice.detachedHead=false checkout 7f32368ed1030a6e710537047bacd908adea183a
@call git apply --whitespace=warn ..\ui\sd_internal\ddim_callback.patch
)
if "%test_sd2%" == "Y" (
@call git -c advice.detachedHead=false checkout 6e2f82187f8ecc4ea59ac37dc239cfcc78038f6d
@call git apply ..\ui\sd_internal\ddim_callback_sd2.patch
@call git -c advice.detachedHead=false checkout 8878d67decd3deb3c98472c1e39d2a51dc5950f9
)
@cd ..
@@ -66,8 +62,6 @@ if NOT DEFINED test_sd2 set test_sd2=N
@cd stable-diffusion
@call git -c advice.detachedHead=false checkout 7f32368ed1030a6e710537047bacd908adea183a
@call git apply --whitespace=warn ..\ui\sd_internal\ddim_callback.patch
@cd ..
)

View File

@@ -37,12 +37,8 @@ if [ -e "scripts/install_status.txt" ] && [ `grep -c sd_git_cloned scripts/insta
if [ "$test_sd2" == "N" ]; then
git -c advice.detachedHead=false checkout 7f32368ed1030a6e710537047bacd908adea183a
git apply --whitespace=warn ../ui/sd_internal/ddim_callback.patch || fail "ddim patch failed"
elif [ "$test_sd2" == "Y" ]; then
git -c advice.detachedHead=false checkout 992f111312afa9ec1a01beaa9733cb9728f5acd3
git apply --whitespace=warn ../ui/sd_internal/ddim_callback_sd2.patch || fail "sd2 ddim patch failed"
git -c advice.detachedHead=false checkout 8878d67decd3deb3c98472c1e39d2a51dc5950f9
fi
cd ..
@@ -58,8 +54,6 @@ else
cd stable-diffusion
git -c advice.detachedHead=false checkout 7f32368ed1030a6e710537047bacd908adea183a
git apply --whitespace=warn ../ui/sd_internal/ddim_callback.patch || fail "ddim patch failed"
cd ..
fi

View File

@@ -7,6 +7,7 @@ Notes:
import json
import os, re
import traceback
import queue
import torch
import numpy as np
from gc import collect as gc_collect
@@ -392,9 +393,34 @@ def apply_filters(filter_name, image_data, model_path=None):
return image_data
def mk_img(req: Request):
def is_model_reload_necessary(req: Request):
# custom model support:
# the req.use_stable_diffusion_model needs to be a valid path
# to the ckpt file (without the extension).
if not os.path.exists(req.use_stable_diffusion_model + '.ckpt'): raise FileNotFoundError(f'Cannot find {req.use_stable_diffusion_model}.ckpt')
needs_model_reload = False
if not thread_data.model or thread_data.ckpt_file != req.use_stable_diffusion_model or thread_data.vae_file != req.use_vae_model:
thread_data.ckpt_file = req.use_stable_diffusion_model
thread_data.vae_file = req.use_vae_model
needs_model_reload = True
if thread_data.device != 'cpu':
if (thread_data.precision == 'autocast' and (req.use_full_precision or not thread_data.model_is_half)) or \
(thread_data.precision == 'full' and not req.use_full_precision and not thread_data.force_full_precision):
thread_data.precision = 'full' if req.use_full_precision else 'autocast'
needs_model_reload = True
return needs_model_reload
def reload_model():
unload_models()
unload_filters()
load_model_ckpt()
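Note that splitting the reload check (is_model_reload_necessary) from the reload itself (reload_model) lets the caller switch to a loading-model state before the slow reload begins; the thread_render hunk in the last file of this commit uses the two helpers in exactly that order.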
def mk_img(req: Request, data_queue: queue.Queue, task_temp_images: list, step_callback):
try:
yield from do_mk_img(req)
return do_mk_img(req, data_queue, task_temp_images, step_callback)
except Exception as e:
print(traceback.format_exc())
@@ -405,12 +431,13 @@ def mk_img(req: Request):
thread_data.model.model2.to("cpu")
gc() # Release from memory.
yield json.dumps({
data_queue.put(json.dumps({
"status": 'failed',
"detail": str(e)
})
}))
raise e
def update_temp_img(req, x_samples):
def update_temp_img(req, x_samples, task_temp_images: list):
partial_images = []
for i in range(req.num_outputs):
if thread_data.test_sd2:
@@ -421,19 +448,18 @@ def update_temp_img(req, x_samples):
x_sample = 255.0 * rearrange(x_sample[0].cpu().numpy(), "c h w -> h w c")
x_sample = x_sample.astype(np.uint8)
img = Image.fromarray(x_sample)
buf = BytesIO()
img.save(buf, format='JPEG')
buf.seek(0)
buf = img_to_buffer(img, output_format='JPEG')
del img, x_sample, x_sample_ddim
# don't delete x_samples, it is used in the code that called this callback
thread_data.temp_images[str(req.session_id) + '/' + str(i)] = buf
task_temp_images[i] = buf
partial_images.append({'path': f'/image/tmp/{req.session_id}/{i}'})
return partial_images
# Build and return the appropriate generator for do_mk_img
def get_image_progress_generator(req, extra_props=None):
def get_image_progress_generator(req, data_queue: queue.Queue, task_temp_images: list, step_callback, extra_props=None):
if not req.stream_progress_updates:
def empty_callback(x_samples, i): return x_samples
return empty_callback
@@ -452,15 +478,17 @@ def get_image_progress_generator(req, extra_props=None):
progress.update(extra_props)
if req.stream_image_progress and i % 5 == 0:
progress['output'] = update_temp_img(req, x_samples)
progress['output'] = update_temp_img(req, x_samples, task_temp_images)
yield json.dumps(progress)
data_queue.put(json.dumps(progress))
step_callback()
if thread_data.stop_processing:
raise UserInitiatedStop("User requested that we stop processing")
return img_callback
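The callback built above is handed to the samplers as img_callback, which they invoke once per denoising step with the current latents and the step index; the callback pushes a JSON progress snapshot into data_queue and calls step_callback so the task manager can react (for example, to a cancel request). A minimal sketch of that contract, with fake_sampler as an illustrative stand-in for the real samplers:

def fake_sampler(steps, img_callback):
    # Illustrative stand-in: the real samplers call img_callback once per step.
    x_samples = None  # stands in for the current latent tensor
    for i in range(steps):
        # ... one denoising step would run here ...
        img_callback(x_samples, i)  # progress hook; may raise UserInitiatedStop
    # (samples, intermediates), matching the tuples returned further down
    return x_samples, {}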
def do_mk_img(req: Request):
def do_mk_img(req: Request, data_queue: queue.Queue, task_temp_images: list, step_callback):
thread_data.stop_processing = False
res = Response()
@@ -469,28 +497,6 @@ def do_mk_img(req: Request):
thread_data.temp_images.clear()
# custom model support:
# the req.use_stable_diffusion_model needs to be a valid path
# to the ckpt file (without the extension).
if not os.path.exists(req.use_stable_diffusion_model + '.ckpt'): raise FileNotFoundError(f'Cannot find {req.use_stable_diffusion_model}.ckpt')
needs_model_reload = False
if not thread_data.model or thread_data.ckpt_file != req.use_stable_diffusion_model or thread_data.vae_file != req.use_vae_model:
thread_data.ckpt_file = req.use_stable_diffusion_model
thread_data.vae_file = req.use_vae_model
needs_model_reload = True
if thread_data.device != 'cpu':
if (thread_data.precision == 'autocast' and (req.use_full_precision or not thread_data.model_is_half)) or \
(thread_data.precision == 'full' and not req.use_full_precision and not thread_data.force_full_precision):
thread_data.precision = 'full' if req.use_full_precision else 'autocast'
needs_model_reload = True
if needs_model_reload:
unload_models()
unload_filters()
load_model_ckpt()
if thread_data.turbo != req.turbo and not thread_data.test_sd2:
thread_data.turbo = req.turbo
thread_data.model.turbo = req.turbo
@@ -606,7 +612,7 @@ def do_mk_img(req: Request):
thread_data.modelFS.to(thread_data.device)
n_steps = req.num_inference_steps if req.init_image is None else t_enc
img_callback = get_image_progress_generator(req, {"total_steps": n_steps})
img_callback = get_image_progress_generator(req, data_queue, task_temp_images, step_callback, {"total_steps": n_steps})
# run the handler
try:
@@ -615,13 +621,6 @@ def do_mk_img(req: Request):
x_samples = _txt2img(req.width, req.height, req.num_outputs, req.num_inference_steps, req.guidance_scale, None, opt_C, opt_f, opt_ddim_eta, c, uc, opt_seed, img_callback, mask, req.sampler)
else:
x_samples = _img2img(init_latent, t_enc, batch_size, req.guidance_scale, c, uc, req.num_inference_steps, opt_ddim_eta, opt_seed, img_callback, mask, opt_C, req.height, req.width, opt_f)
if req.stream_progress_updates:
yield from x_samples
if hasattr(thread_data, 'partial_x_samples'):
if thread_data.partial_x_samples is not None:
x_samples = thread_data.partial_x_samples
del thread_data.partial_x_samples
except UserInitiatedStop:
if not hasattr(thread_data, 'partial_x_samples'):
continue
@@ -666,9 +665,11 @@ def do_mk_img(req: Request):
save_metadata(meta_out_path, req, prompts[0], opt_seed)
if return_orig_img:
img_str = img_to_base64_str(img, req.output_format)
img_buffer = img_to_buffer(img, req.output_format)
img_str = buffer_to_base64_str(img_buffer, req.output_format)
res_image_orig = ResponseImage(data=img_str, seed=opt_seed)
res.images.append(res_image_orig)
task_temp_images[i] = img_buffer
if req.save_to_disk_path is not None:
res_image_orig.path_abs = img_out_path
@@ -684,9 +685,11 @@ def do_mk_img(req: Request):
filters_applied.append(req.use_upscale)
if (len(filters_applied) > 0):
filtered_image = Image.fromarray(img_data[i])
filtered_img_data = img_to_base64_str(filtered_image, req.output_format)
filtered_buffer = img_to_buffer(filtered_image, req.output_format)
filtered_img_data = buffer_to_base64_str(filtered_buffer, req.output_format)
response_image = ResponseImage(data=filtered_img_data, seed=opt_seed)
res.images.append(response_image)
task_temp_images[i] = filtered_buffer
if req.save_to_disk_path is not None:
filtered_img_out_path = get_base_path(req.save_to_disk_path, req.session_id, prompts[0], img_id, req.output_format, "_".join(filters_applied))
save_image(filtered_image, filtered_img_out_path)
@@ -705,7 +708,10 @@ def do_mk_img(req: Request):
print(f'memory_final = {round(torch.cuda.memory_allocated(thread_data.device) / 1e6, 2)}Mb')
print('Task completed')
yield json.dumps(res.json())
res = res.json()
data_queue.put(json.dumps(res))
return res
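The final response now travels both ways: it is put on data_queue for streaming clients and returned to the caller, which (in thread_render below) stores it as task.response. Since mk_img now blocks until completion, the two paths agree; a self-contained illustration, with run standing in for do_mk_img's ending:

import json
import queue

def run(data_queue):
    # Stand-in for do_mk_img's ending: queue the final JSON, return the dict.
    res = {'status': 'succeeded'}
    data_queue.put(json.dumps(res))
    return res

q = queue.Queue()
res = run(q)
assert json.loads(q.get()) == res  # streamed payload matches the return value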
def save_image(img, img_out_path):
try:
@@ -771,7 +777,7 @@ def _txt2img(opt_W, opt_H, opt_n_samples, opt_ddim_steps, opt_scale, start_code,
sampler.make_schedule(ddim_num_steps=opt_ddim_steps, ddim_eta=opt_ddim_eta, verbose=False)
samples_ddim = sampler.sample(
samples_ddim, intermediates = sampler.sample(
S=opt_ddim_steps,
conditioning=c,
batch_size=opt_n_samples,
@@ -790,7 +796,7 @@ def _txt2img(opt_W, opt_H, opt_n_samples, opt_ddim_steps, opt_scale, start_code,
if sampler_name == 'ddim':
thread_data.model.make_schedule(ddim_num_steps=opt_ddim_steps, ddim_eta=opt_ddim_eta, verbose=False)
samples_ddim = thread_data.model.sample(
samples_ddim, intermediates = thread_data.model.sample(
S=opt_ddim_steps,
conditioning=c,
seed=opt_seed,
@@ -804,7 +810,7 @@ def _txt2img(opt_W, opt_H, opt_n_samples, opt_ddim_steps, opt_scale, start_code,
mask=mask,
sampler = sampler_name,
)
yield from samples_ddim
return samples_ddim
def _img2img(init_latent, t_enc, batch_size, opt_scale, c, uc, opt_ddim_steps, opt_ddim_eta, opt_seed, img_callback, mask, opt_C=1, opt_H=1, opt_W=1, opt_f=1):
# encode (scaled latent)
@@ -831,7 +837,7 @@ def _img2img(init_latent, t_enc, batch_size, opt_scale, c, uc, opt_ddim_steps, o
)
# decode it
samples_ddim = thread_data.model.sample(
samples_ddim, intermediates = thread_data.model.sample(
t_enc,
c,
z_enc,
@@ -842,7 +848,7 @@ def _img2img(init_latent, t_enc, batch_size, opt_scale, c, uc, opt_ddim_steps, o
x_T=x_T,
sampler = 'ddim'
)
yield from samples_ddim
return samples_ddim
def gc():
gc_collect()
@@ -910,8 +916,16 @@ def load_mask(mask_str, h0, w0, newH, newW, invert=False):
# https://stackoverflow.com/a/61114178
def img_to_base64_str(img, output_format="PNG"):
buffered = img_to_buffer(img, output_format)
return buffer_to_base64_str(buffered, output_format)
def img_to_buffer(img, output_format="PNG"):
buffered = BytesIO()
img.save(buffered, format=output_format)
buffered.seek(0)
return buffered
def buffer_to_base64_str(buffered, output_format="PNG"):
buffered.seek(0)
img_byte = buffered.getvalue()
mime_type = "image/png" if output_format.lower() == "png" else "image/jpeg"
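img_to_base64_str is split so callers can keep the raw BytesIO buffer (now stored in task_temp_images for serving temp images) while still deriving the base64 string for the JSON response. A quick round trip using the helpers above, assuming Pillow is installed:

from PIL import Image

img = Image.new('RGB', (64, 64))         # stand-in for a generated image
buf = img_to_buffer(img, 'JPEG')         # reusable BytesIO for task_temp_images
b64 = buffer_to_base64_str(buf, 'JPEG')  # string for the ResponseImage payload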

View File

@@ -283,45 +283,24 @@ def thread_render(device):
print(f'Session {task.request.session_id} starting task {id(task)} on {runtime.thread_data.device_name}')
if not task.lock.acquire(blocking=False): raise Exception('Got locked task from queue.')
try:
if runtime.thread_data.device == 'cpu' and is_alive() > 1:
# CPU is not the only device. Keep track of active time to unload resources later.
runtime.thread_data.lastActive = time.time()
# Open data generator.
res = runtime.mk_img(task.request)
if current_model_path == task.request.use_stable_diffusion_model:
current_state = ServerStates.Rendering
else:
if runtime.is_model_reload_necessary(task.request):
current_state = ServerStates.LoadingModel
# Start reading from generator.
dataQueue = None
if task.request.stream_progress_updates:
dataQueue = task.buffer_queue
for result in res:
if current_state == ServerStates.LoadingModel:
current_state = ServerStates.Rendering
current_model_path = task.request.use_stable_diffusion_model
current_vae_path = task.request.use_vae_model
runtime.reload_model()
current_model_path = task.request.use_stable_diffusion_model
current_vae_path = task.request.use_vae_model
def step_callback():
if isinstance(current_state_error, SystemExit) or isinstance(current_state_error, StopAsyncIteration) or isinstance(task.error, StopAsyncIteration):
runtime.thread_data.stop_processing = True
if isinstance(current_state_error, StopAsyncIteration):
task.error = current_state_error
current_state_error = None
print(f'Session {task.request.session_id} sent cancel signal for task {id(task)}')
if dataQueue:
dataQueue.put(result)
if isinstance(result, str):
result = json.loads(result)
task.response = result
if 'output' in result:
for out_obj in result['output']:
if 'path' in out_obj:
img_id = out_obj['path'][out_obj['path'].rindex('/') + 1:]
task.temp_images[int(img_id)] = runtime.thread_data.temp_images[out_obj['path'][11:]]
elif 'data' in out_obj:
buf = runtime.base64_str_to_buffer(out_obj['data'])
task.temp_images[result['output'].index(out_obj)] = buf
# Before looping back to the generator, mark cache as still alive.
task_cache.keep(task.request.session_id, TASK_TTL)
task_cache.keep(task.request.session_id, TASK_TTL)
current_state = ServerStates.Rendering
task.response = runtime.mk_img(task.request, task.buffer_queue, task.temp_images, step_callback)
except Exception as e:
task.error = e
print(traceback.format_exc())
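On the consumer side of task.buffer_queue, an HTTP handler can now stream the queued JSON chunks to the client while this render thread fills the queue. A hedged sketch of such a reader; stream_task is illustrative, not this project's actual response streamer:

import json
import queue

def stream_task(buffer_queue: queue.Queue, timeout=30):
    # Illustrative consumer: relay queued JSON chunks until a terminal
    # payload (one carrying a 'status' field) arrives, or the producer stalls.
    while True:
        try:
            chunk = buffer_queue.get(timeout=timeout)
        except queue.Empty:
            return  # producer stalled; stop streaming
        yield chunk
        if json.loads(chunk).get('status') in ('succeeded', 'failed'):
            return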