Remove the need to use yield in the core loop for streaming results. This in turn removes the need to patch the Stable Diffusion code, which can be fragile.

cmdr2 2022-11-29 13:03:57 +05:30
parent cb02b5ba18
commit e37be0f954
4 changed files with 77 additions and 96 deletions
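
The change in a nutshell: the core loop used to be a generator that yielded progress JSON upward through several layers, which required patching the samplers so they could yield too. It now pushes updates into a queue.Queue and invokes a step callback, so the samplers only need an ordinary callback. A minimal sketch of the new pattern, with simplified, hypothetical names (the real mk_img takes a Request plus the queue, a temp-image list, and the callback):

import json
import queue

def core_loop(steps: int, data_queue: queue.Queue, step_callback) -> dict:
    # No yield: progress goes into the queue, and the callback gives the
    # caller a hook on every step (cancellation, cache keep-alive, ...).
    for i in range(steps):
        data_queue.put(json.dumps({'step': i, 'total_steps': steps}))
        step_callback()
    result = {'status': 'succeeded'}
    data_queue.put(json.dumps(result))
    return result

data_queue = queue.Queue()
result = core_loop(10, data_queue, step_callback=lambda: None)
while not data_queue.empty():
    print(data_queue.get())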

View File (Windows batch installer script)

@@ -42,13 +42,9 @@ if NOT DEFINED test_sd2 set test_sd2=N
 if "%test_sd2%" == "N" (
     @call git -c advice.detachedHead=false checkout 7f32368ed1030a6e710537047bacd908adea183a
-    @call git apply --whitespace=warn ..\ui\sd_internal\ddim_callback.patch
 )
 if "%test_sd2%" == "Y" (
-    @call git -c advice.detachedHead=false checkout 6e2f82187f8ecc4ea59ac37dc239cfcc78038f6d
-    @call git apply ..\ui\sd_internal\ddim_callback_sd2.patch
+    @call git -c advice.detachedHead=false checkout 8878d67decd3deb3c98472c1e39d2a51dc5950f9
 )
 @cd ..
@@ -66,8 +62,6 @@ if NOT DEFINED test_sd2 set test_sd2=N
     @cd stable-diffusion
     @call git -c advice.detachedHead=false checkout 7f32368ed1030a6e710537047bacd908adea183a
-    @call git apply --whitespace=warn ..\ui\sd_internal\ddim_callback.patch
     @cd ..
 )

View File (shell installer script)

@@ -37,12 +37,8 @@ if [ -e "scripts/install_status.txt" ] && [ `grep -c sd_git_cloned scripts/insta
 if [ "$test_sd2" == "N" ]; then
     git -c advice.detachedHead=false checkout 7f32368ed1030a6e710537047bacd908adea183a
-    git apply --whitespace=warn ../ui/sd_internal/ddim_callback.patch || fail "ddim patch failed"
 elif [ "$test_sd2" == "Y" ]; then
-    git -c advice.detachedHead=false checkout 992f111312afa9ec1a01beaa9733cb9728f5acd3
-    git apply --whitespace=warn ../ui/sd_internal/ddim_callback_sd2.patch || fail "sd2 ddim patch failed"
+    git -c advice.detachedHead=false checkout 8878d67decd3deb3c98472c1e39d2a51dc5950f9
 fi
 cd ..
@@ -58,8 +54,6 @@ else
     cd stable-diffusion
     git -c advice.detachedHead=false checkout 7f32368ed1030a6e710537047bacd908adea183a
-    git apply --whitespace=warn ../ui/sd_internal/ddim_callback.patch || fail "ddim patch failed"
     cd ..
 fi

View File (Python runtime module: image generation)

@@ -7,6 +7,7 @@ Notes:
 import json
 import os, re
 import traceback
+import queue
 import torch
 import numpy as np
 from gc import collect as gc_collect
@@ -392,9 +393,34 @@ def apply_filters(filter_name, image_data, model_path=None):
     return image_data

-def mk_img(req: Request):
+def is_model_reload_necessary(req: Request):
+    # custom model support:
+    #  the req.use_stable_diffusion_model needs to be a valid path
+    #  to the ckpt file (without the extension).
+    if not os.path.exists(req.use_stable_diffusion_model + '.ckpt'): raise FileNotFoundError(f'Cannot find {req.use_stable_diffusion_model}.ckpt')
+
+    needs_model_reload = False
+    if not thread_data.model or thread_data.ckpt_file != req.use_stable_diffusion_model or thread_data.vae_file != req.use_vae_model:
+        thread_data.ckpt_file = req.use_stable_diffusion_model
+        thread_data.vae_file = req.use_vae_model
+        needs_model_reload = True
+
+    if thread_data.device != 'cpu':
+        if (thread_data.precision == 'autocast' and (req.use_full_precision or not thread_data.model_is_half)) or \
+           (thread_data.precision == 'full' and not req.use_full_precision and not thread_data.force_full_precision):
+            thread_data.precision = 'full' if req.use_full_precision else 'autocast'
+            needs_model_reload = True
+
+    return needs_model_reload
+
+def reload_model():
+    unload_models()
+    unload_filters()
+    load_model_ckpt()
+
+def mk_img(req: Request, data_queue: queue.Queue, task_temp_images: list, step_callback):
     try:
-        yield from do_mk_img(req)
+        return do_mk_img(req, data_queue, task_temp_images, step_callback)
     except Exception as e:
         print(traceback.format_exc())
@@ -405,12 +431,13 @@ def mk_img(req: Request):
             thread_data.model.model2.to("cpu")
         gc() # Release from memory.
-        yield json.dumps({
+        data_queue.put(json.dumps({
             "status": 'failed',
             "detail": str(e)
-        })
+        }))
+        raise e

-def update_temp_img(req, x_samples):
+def update_temp_img(req, x_samples, task_temp_images: list):
     partial_images = []
     for i in range(req.num_outputs):
         if thread_data.test_sd2:
@@ -421,19 +448,18 @@ def update_temp_img(req, x_samples):
         x_sample = 255.0 * rearrange(x_sample[0].cpu().numpy(), "c h w -> h w c")
         x_sample = x_sample.astype(np.uint8)
         img = Image.fromarray(x_sample)
-        buf = BytesIO()
-        img.save(buf, format='JPEG')
-        buf.seek(0)
+        buf = img_to_buffer(img, output_format='JPEG')

         del img, x_sample, x_sample_ddim
         # don't delete x_samples, it is used in the code that called this callback

         thread_data.temp_images[str(req.session_id) + '/' + str(i)] = buf
+        task_temp_images[i] = buf
         partial_images.append({'path': f'/image/tmp/{req.session_id}/{i}'})
     return partial_images

 # Build and return the apropriate generator for do_mk_img
-def get_image_progress_generator(req, extra_props=None):
+def get_image_progress_generator(req, data_queue: queue.Queue, task_temp_images: list, step_callback, extra_props=None):
     if not req.stream_progress_updates:
         def empty_callback(x_samples, i): return x_samples
         return empty_callback
@@ -452,15 +478,17 @@ def get_image_progress_generator(req, extra_props=None):
             progress.update(extra_props)

         if req.stream_image_progress and i % 5 == 0:
-            progress['output'] = update_temp_img(req, x_samples)
+            progress['output'] = update_temp_img(req, x_samples, task_temp_images)

-        yield json.dumps(progress)
+        data_queue.put(json.dumps(progress))
+        step_callback()

         if thread_data.stop_processing:
             raise UserInitiatedStop("User requested that we stop processing")

     return img_callback

-def do_mk_img(req: Request):
+def do_mk_img(req: Request, data_queue: queue.Queue, task_temp_images: list, step_callback):
     thread_data.stop_processing = False
     res = Response()
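
For orientation: the samplers invoke img_callback(x_samples, i) once per denoising step; with this change that callback is a plain function (queue push plus step_callback) rather than a generator the sampler had to yield from, which is exactly why the ddim_callback patches become unnecessary. A toy sketch of the contract, with a fake sampler loop and hypothetical names:

import json
import queue

def get_image_progress_callback(total_steps, data_queue, step_callback):
    def img_callback(x_samples, i):
        # Push a progress update, then give the caller its per-step hook.
        data_queue.put(json.dumps({'step': i, 'total_steps': total_steps}))
        step_callback()
    return img_callback

def fake_sampler(steps, img_callback):
    latents = None  # stand-in for the evolving latent tensor
    for i in range(steps):
        img_callback(latents, i)  # a plain call; no 'yield from' needed
    return latents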
@@ -469,28 +497,6 @@ def do_mk_img(req: Request):

     thread_data.temp_images.clear()

-    # custom model support:
-    #  the req.use_stable_diffusion_model needs to be a valid path
-    #  to the ckpt file (without the extension).
-    if not os.path.exists(req.use_stable_diffusion_model + '.ckpt'): raise FileNotFoundError(f'Cannot find {req.use_stable_diffusion_model}.ckpt')
-
-    needs_model_reload = False
-    if not thread_data.model or thread_data.ckpt_file != req.use_stable_diffusion_model or thread_data.vae_file != req.use_vae_model:
-        thread_data.ckpt_file = req.use_stable_diffusion_model
-        thread_data.vae_file = req.use_vae_model
-        needs_model_reload = True
-
-    if thread_data.device != 'cpu':
-        if (thread_data.precision == 'autocast' and (req.use_full_precision or not thread_data.model_is_half)) or \
-           (thread_data.precision == 'full' and not req.use_full_precision and not thread_data.force_full_precision):
-            thread_data.precision = 'full' if req.use_full_precision else 'autocast'
-            needs_model_reload = True
-
-    if needs_model_reload:
-        unload_models()
-        unload_filters()
-        load_model_ckpt()
-
     if thread_data.turbo != req.turbo and not thread_data.test_sd2:
         thread_data.turbo = req.turbo
         thread_data.model.turbo = req.turbo
@@ -606,7 +612,7 @@ def do_mk_img(req: Request):
             thread_data.modelFS.to(thread_data.device)

         n_steps = req.num_inference_steps if req.init_image is None else t_enc
-        img_callback = get_image_progress_generator(req, {"total_steps": n_steps})
+        img_callback = get_image_progress_generator(req, data_queue, task_temp_images, step_callback, {"total_steps": n_steps})

         # run the handler
         try:
@@ -615,13 +621,6 @@ def do_mk_img(req: Request):
                 x_samples = _txt2img(req.width, req.height, req.num_outputs, req.num_inference_steps, req.guidance_scale, None, opt_C, opt_f, opt_ddim_eta, c, uc, opt_seed, img_callback, mask, req.sampler)
             else:
                 x_samples = _img2img(init_latent, t_enc, batch_size, req.guidance_scale, c, uc, req.num_inference_steps, opt_ddim_eta, opt_seed, img_callback, mask, opt_C, req.height, req.width, opt_f)
-
-            if req.stream_progress_updates:
-                yield from x_samples
-            if hasattr(thread_data, 'partial_x_samples'):
-                if thread_data.partial_x_samples is not None:
-                    x_samples = thread_data.partial_x_samples
-                del thread_data.partial_x_samples
         except UserInitiatedStop:
             if not hasattr(thread_data, 'partial_x_samples'):
                 continue
@@ -666,9 +665,11 @@ def do_mk_img(req: Request):
                     save_metadata(meta_out_path, req, prompts[0], opt_seed)

                 if return_orig_img:
-                    img_str = img_to_base64_str(img, req.output_format)
+                    img_buffer = img_to_buffer(img, req.output_format)
+                    img_str = buffer_to_base64_str(img_buffer, req.output_format)
                     res_image_orig = ResponseImage(data=img_str, seed=opt_seed)
                     res.images.append(res_image_orig)
+                    task_temp_images[i] = img_buffer

                     if req.save_to_disk_path is not None:
                         res_image_orig.path_abs = img_out_path
@@ -684,9 +685,11 @@ def do_mk_img(req: Request):
                         filters_applied.append(req.use_upscale)
                     if (len(filters_applied) > 0):
                         filtered_image = Image.fromarray(img_data[i])
-                        filtered_img_data = img_to_base64_str(filtered_image, req.output_format)
+                        filtered_buffer = img_to_buffer(filtered_image, req.output_format)
+                        filtered_img_data = buffer_to_base64_str(filtered_buffer, req.output_format)
                         response_image = ResponseImage(data=filtered_img_data, seed=opt_seed)
                         res.images.append(response_image)
+                        task_temp_images[i] = filtered_buffer
                         if req.save_to_disk_path is not None:
                             filtered_img_out_path = get_base_path(req.save_to_disk_path, req.session_id, prompts[0], img_id, req.output_format, "_".join(filters_applied))
                             save_image(filtered_image, filtered_img_out_path)
@@ -705,7 +708,10 @@ def do_mk_img(req: Request):
         print(f'memory_final = {round(torch.cuda.memory_allocated(thread_data.device) / 1e6, 2)}Mb')

     print('Task completed')
-    yield json.dumps(res.json())
+    res = res.json()
+    data_queue.put(json.dumps(res))
+
+    return res

def save_image(img, img_out_path):
    try:
@@ -771,7 +777,7 @@ def _txt2img(opt_W, opt_H, opt_n_samples, opt_ddim_steps, opt_scale, start_code,
         sampler.make_schedule(ddim_num_steps=opt_ddim_steps, ddim_eta=opt_ddim_eta, verbose=False)

-        samples_ddim = sampler.sample(
+        samples_ddim, intermediates = sampler.sample(
             S=opt_ddim_steps,
             conditioning=c,
             batch_size=opt_n_samples,
@@ -790,7 +796,7 @@ def _txt2img(opt_W, opt_H, opt_n_samples, opt_ddim_steps, opt_scale, start_code,
         if sampler_name == 'ddim':
             thread_data.model.make_schedule(ddim_num_steps=opt_ddim_steps, ddim_eta=opt_ddim_eta, verbose=False)

-        samples_ddim = thread_data.model.sample(
+        samples_ddim, intermediates = thread_data.model.sample(
             S=opt_ddim_steps,
             conditioning=c,
             seed=opt_seed,
@@ -804,7 +810,7 @@ def _txt2img(opt_W, opt_H, opt_n_samples, opt_ddim_steps, opt_scale, start_code,
             mask=mask,
             sampler = sampler_name,
         )
-    yield from samples_ddim
+    return samples_ddim

def _img2img(init_latent, t_enc, batch_size, opt_scale, c, uc, opt_ddim_steps, opt_ddim_eta, opt_seed, img_callback, mask, opt_C=1, opt_H=1, opt_W=1, opt_f=1):
    # encode (scaled latent)
@@ -831,7 +837,7 @@ def _img2img(init_latent, t_enc, batch_size, opt_scale, c, uc, opt_ddim_steps, o
     )

     # decode it
-    samples_ddim = thread_data.model.sample(
+    samples_ddim, intermediates = thread_data.model.sample(
         t_enc,
         c,
         z_enc,
@@ -842,7 +848,7 @@ def _img2img(init_latent, t_enc, batch_size, opt_scale, c, uc, opt_ddim_steps, o
         x_T=x_T,
         sampler = 'ddim'
     )
-    yield from samples_ddim
+    return samples_ddim

def gc():
    gc_collect()
@@ -910,8 +916,16 @@ def load_mask(mask_str, h0, w0, newH, newW, invert=False):

 # https://stackoverflow.com/a/61114178
 def img_to_base64_str(img, output_format="PNG"):
+    buffered = img_to_buffer(img, output_format)
+    return buffer_to_base64_str(buffered, output_format)
+
+def img_to_buffer(img, output_format="PNG"):
     buffered = BytesIO()
     img.save(buffered, format=output_format)
+    buffered.seek(0)
+    return buffered
+
+def buffer_to_base64_str(buffered, output_format="PNG"):
     buffered.seek(0)
     img_byte = buffered.getvalue()
     mime_type = "image/png" if output_format.lower() == "png" else "image/jpeg"
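
The helper split above lets a single BytesIO buffer serve two consumers: the base64 data URL placed in the JSON response, and the raw buffer stored in task_temp_images for the temp-image route. A usage sketch (requires Pillow; the base64 tail of buffer_to_base64_str is not shown in the diff and is inferred from the mime_type line, so treat it as an assumption):

import base64
from io import BytesIO
from PIL import Image

def img_to_buffer(img, output_format="PNG"):
    buffered = BytesIO()
    img.save(buffered, format=output_format)
    buffered.seek(0)
    return buffered

def buffer_to_base64_str(buffered, output_format="PNG"):
    buffered.seek(0)
    img_byte = buffered.getvalue()
    mime_type = "image/png" if output_format.lower() == "png" else "image/jpeg"
    return f"data:{mime_type};base64," + base64.b64encode(img_byte).decode()

img = Image.new("RGB", (64, 64), "gray")
buf = img_to_buffer(img, "PNG")               # kept for task_temp_images[i]
data_url = buffer_to_base64_str(buf, "PNG")   # sent in the response JSON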

View File (Python task-manager module: render thread)

@@ -283,45 +283,24 @@ def thread_render(device):
         print(f'Session {task.request.session_id} starting task {id(task)} on {runtime.thread_data.device_name}')
         if not task.lock.acquire(blocking=False): raise Exception('Got locked task from queue.')
         try:
-            if runtime.thread_data.device == 'cpu' and is_alive() > 1:
-                # CPU is not the only device. Keep track of active time to unload resources later.
-                runtime.thread_data.lastActive = time.time()
-            # Open data generator.
-            res = runtime.mk_img(task.request)
-            if current_model_path == task.request.use_stable_diffusion_model:
-                current_state = ServerStates.Rendering
-            else:
+            if runtime.is_model_reload_necessary(task.request):
                 current_state = ServerStates.LoadingModel
-            # Start reading from generator.
-            dataQueue = None
-            if task.request.stream_progress_updates:
-                dataQueue = task.buffer_queue
-            for result in res:
-                if current_state == ServerStates.LoadingModel:
-                    current_state = ServerStates.Rendering
-                    current_model_path = task.request.use_stable_diffusion_model
-                    current_vae_path = task.request.use_vae_model
+                runtime.reload_model()
+                current_model_path = task.request.use_stable_diffusion_model
+                current_vae_path = task.request.use_vae_model
+
+            def step_callback():
                 if isinstance(current_state_error, SystemExit) or isinstance(current_state_error, StopAsyncIteration) or isinstance(task.error, StopAsyncIteration):
                     runtime.thread_data.stop_processing = True
                 if isinstance(current_state_error, StopAsyncIteration):
                     task.error = current_state_error
                     current_state_error = None
                     print(f'Session {task.request.session_id} sent cancel signal for task {id(task)}')
-                if dataQueue:
-                    dataQueue.put(result)
-                if isinstance(result, str):
-                    result = json.loads(result)
-                task.response = result
-                if 'output' in result:
-                    for out_obj in result['output']:
-                        if 'path' in out_obj:
-                            img_id = out_obj['path'][out_obj['path'].rindex('/') + 1:]
-                            task.temp_images[int(img_id)] = runtime.thread_data.temp_images[out_obj['path'][11:]]
-                        elif 'data' in out_obj:
-                            buf = runtime.base64_str_to_buffer(out_obj['data'])
-                            task.temp_images[result['output'].index(out_obj)] = buf
-                # Before looping back to the generator, mark cache as still alive.
                 task_cache.keep(task.request.session_id, TASK_TTL)
+
+            current_state = ServerStates.Rendering
+            task.response = runtime.mk_img(task.request, task.buffer_queue, task.temp_images, step_callback)
         except Exception as e:
             task.error = e
             print(traceback.format_exc())
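
Not shown in this diff: whoever serves the HTTP response now simply drains task.buffer_queue instead of driving a generator. A hypothetical consumer sketch (using task.response and task.error as completion signals is an assumption, based on thread_render storing the final result and any failure on the task above):

import queue

def stream_task(task):
    # Yield each JSON chunk to the HTTP layer as thread_render produces it.
    while True:
        try:
            yield task.buffer_queue.get(timeout=1)
        except queue.Empty:
            if task.response is not None or task.error is not None:
                break  # render thread finished (or failed); nothing more is coming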