forked from extern/easydiffusion
Remove the need to use yield in the core loop for streaming results. This removes the need to patch the Stable Diffusion code, which can be fragile.
This commit is contained in:
parent
cb02b5ba18
commit
e37be0f954
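For orientation before the hunks: this commit replaces a generator-based core loop (do_mk_img yielding progress JSON up through patched sampler code) with a plain function that pushes progress into a queue and paces itself through a caller-supplied callback. A minimal sketch of the new pattern, using hypothetical stand-in names (render_task is not the project's real API):

    import json
    import queue

    def render_task(data_queue: queue.Queue, step_callback) -> dict:
        # Stand-in for do_mk_img(): a plain function, not a generator.
        for step in range(3):
            # Progress is pushed into the queue instead of yielded.
            data_queue.put(json.dumps({"step": step, "total_steps": 3}))
            step_callback()  # gives the task thread a chance to react (e.g. cancel)
        return {"status": "succeeded"}  # the final result is simply returned

    buffer_queue = queue.Queue()
    response = render_task(buffer_queue, step_callback=lambda: None)
    while not buffer_queue.empty():
        print(buffer_queue.get())  # progress JSON, in order
    print(response)

Because the renderer no longer has to be a generator, the Stable Diffusion samplers no longer need to yield either, which is what previously required the ddim_callback patches removed in the script hunks below.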
@@ -42,13 +42,9 @@ if NOT DEFINED test_sd2 set test_sd2=N
if "%test_sd2%" == "N" (
    @call git -c advice.detachedHead=false checkout 7f32368ed1030a6e710537047bacd908adea183a
    @call git apply --whitespace=warn ..\ui\sd_internal\ddim_callback.patch
)
if "%test_sd2%" == "Y" (
    @call git -c advice.detachedHead=false checkout 6e2f82187f8ecc4ea59ac37dc239cfcc78038f6d
    @call git apply ..\ui\sd_internal\ddim_callback_sd2.patch
    @call git -c advice.detachedHead=false checkout 8878d67decd3deb3c98472c1e39d2a51dc5950f9
)

@cd ..
@@ -66,8 +62,6 @@ if NOT DEFINED test_sd2 set test_sd2=N
    @cd stable-diffusion
    @call git -c advice.detachedHead=false checkout 7f32368ed1030a6e710537047bacd908adea183a
    @call git apply --whitespace=warn ..\ui\sd_internal\ddim_callback.patch
    @cd ..
)
@@ -37,12 +37,8 @@ if [ -e "scripts/install_status.txt" ] && [ `grep -c sd_git_cloned scripts/insta
    if [ "$test_sd2" == "N" ]; then
        git -c advice.detachedHead=false checkout 7f32368ed1030a6e710537047bacd908adea183a
        git apply --whitespace=warn ../ui/sd_internal/ddim_callback.patch || fail "ddim patch failed"
    elif [ "$test_sd2" == "Y" ]; then
        git -c advice.detachedHead=false checkout 992f111312afa9ec1a01beaa9733cb9728f5acd3
        git apply --whitespace=warn ../ui/sd_internal/ddim_callback_sd2.patch || fail "sd2 ddim patch failed"
        git -c advice.detachedHead=false checkout 8878d67decd3deb3c98472c1e39d2a51dc5950f9
    fi

    cd ..
@@ -58,8 +54,6 @@ else
    cd stable-diffusion
    git -c advice.detachedHead=false checkout 7f32368ed1030a6e710537047bacd908adea183a
    git apply --whitespace=warn ../ui/sd_internal/ddim_callback.patch || fail "ddim patch failed"
    cd ..
fi
@@ -7,6 +7,7 @@ Notes:
import json
import os, re
import traceback
import queue
import torch
import numpy as np
from gc import collect as gc_collect
@@ -392,9 +393,34 @@ def apply_filters(filter_name, image_data, model_path=None):

    return image_data

def mk_img(req: Request):
def is_model_reload_necessary(req: Request):
    # custom model support:
    # the req.use_stable_diffusion_model needs to be a valid path
    # to the ckpt file (without the extension).
    if not os.path.exists(req.use_stable_diffusion_model + '.ckpt'): raise FileNotFoundError(f'Cannot find {req.use_stable_diffusion_model}.ckpt')

    needs_model_reload = False
    if not thread_data.model or thread_data.ckpt_file != req.use_stable_diffusion_model or thread_data.vae_file != req.use_vae_model:
        thread_data.ckpt_file = req.use_stable_diffusion_model
        thread_data.vae_file = req.use_vae_model
        needs_model_reload = True

    if thread_data.device != 'cpu':
        if (thread_data.precision == 'autocast' and (req.use_full_precision or not thread_data.model_is_half)) or \
            (thread_data.precision == 'full' and not req.use_full_precision and not thread_data.force_full_precision):
            thread_data.precision = 'full' if req.use_full_precision else 'autocast'
            needs_model_reload = True

    return needs_model_reload

def reload_model():
    unload_models()
    unload_filters()
    load_model_ckpt()

def mk_img(req: Request, data_queue: queue.Queue, task_temp_images: list, step_callback):
    try:
        yield from do_mk_img(req)
        return do_mk_img(req, data_queue, task_temp_images, step_callback)
    except Exception as e:
        print(traceback.format_exc())
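Splitting the old inline reload logic into is_model_reload_necessary() and reload_model() lets the caller update server state between the cheap check and the expensive reload. A toy sketch of that call order (names hypothetical; the real check also records the requested paths on thread_data):

    def is_reload_necessary(loaded_ckpt: str, requested_ckpt: str) -> bool:
        # Cheap comparison of what is loaded against what the request needs.
        return loaded_ckpt != requested_ckpt

    def reload(requested_ckpt: str) -> str:
        # Expensive: unload the old model, load the new checkpoint.
        return requested_ckpt

    state, loaded = 'Idle', 'sd-v1-4'
    if is_reload_necessary(loaded, 'custom-model'):
        state = 'LoadingModel'  # the UI can show this before the reload starts
        loaded = reload('custom-model')
    state = 'Rendering'
    print(state, loaded)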
@@ -405,12 +431,13 @@ def mk_img(req: Request):
            thread_data.model.model2.to("cpu")

        gc() # Release from memory.
        yield json.dumps({
        data_queue.put(json.dumps({
            "status": 'failed',
            "detail": str(e)
        })
        }))
        raise e

def update_temp_img(req, x_samples):
def update_temp_img(req, x_samples, task_temp_images: list):
    partial_images = []
    for i in range(req.num_outputs):
        if thread_data.test_sd2:
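update_temp_img() now writes each preview buffer into a caller-owned task_temp_images list instead of the global thread_data.temp_images dict, so the task manager can hand the very same list to the HTTP layer. A hypothetical sketch of that handoff:

    from io import BytesIO

    num_outputs = 2
    task_temp_images = [None] * num_outputs  # owned by the task, shared by reference

    def store_preview(i: int, jpeg_bytes: bytes):
        # Mirrors task_temp_images[i] = buf in update_temp_img()
        task_temp_images[i] = BytesIO(jpeg_bytes)

    store_preview(0, b"\xff\xd8\xff")  # stand-in JPEG bytes
    print(task_temp_images)  # slot 0 filled, slot 1 still pending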
@@ -421,19 +448,18 @@ def update_temp_img(req, x_samples):
        x_sample = 255.0 * rearrange(x_sample[0].cpu().numpy(), "c h w -> h w c")
        x_sample = x_sample.astype(np.uint8)
        img = Image.fromarray(x_sample)
        buf = BytesIO()
        img.save(buf, format='JPEG')
        buf.seek(0)
        buf = img_to_buffer(img, output_format='JPEG')

        del img, x_sample, x_sample_ddim
        # don't delete x_samples, it is used in the code that called this callback

        thread_data.temp_images[str(req.session_id) + '/' + str(i)] = buf
        task_temp_images[i] = buf
        partial_images.append({'path': f'/image/tmp/{req.session_id}/{i}'})
    return partial_images

# Build and return the apropriate generator for do_mk_img
def get_image_progress_generator(req, extra_props=None):
def get_image_progress_generator(req, data_queue: queue.Queue, task_temp_images: list, step_callback, extra_props=None):
    if not req.stream_progress_updates:
        def empty_callback(x_samples, i): return x_samples
        return empty_callback
@@ -452,15 +478,17 @@ def get_image_progress_generator(req, extra_props=None):
            progress.update(extra_props)

        if req.stream_image_progress and i % 5 == 0:
            progress['output'] = update_temp_img(req, x_samples)
            progress['output'] = update_temp_img(req, x_samples, task_temp_images)

        yield json.dumps(progress)
        data_queue.put(json.dumps(progress))

        step_callback()

        if thread_data.stop_processing:
            raise UserInitiatedStop("User requested that we stop processing")
    return img_callback

def do_mk_img(req: Request):
def do_mk_img(req: Request, data_queue: queue.Queue, task_temp_images: list, step_callback):
    thread_data.stop_processing = False

    res = Response()
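The callback returned by get_image_progress_generator() is no longer a generator: it pushes progress onto data_queue, then invokes step_callback so the task thread can run its per-step bookkeeping. A self-contained sketch of that shape (UserInitiatedStop replaced with RuntimeError to keep it runnable; names hypothetical):

    import json
    import queue

    def make_img_callback(data_queue: queue.Queue, step_callback, total_steps: int, stop_flag: dict):
        def img_callback(x_samples, i):
            data_queue.put(json.dumps({"step": i, "total_steps": total_steps}))
            step_callback()            # task-manager hook: cancel checks, cache keep-alive
            if stop_flag.get("stop"):  # set when the user cancels
                raise RuntimeError("User requested that we stop processing")
        return img_callback

    q = queue.Queue()
    cb = make_img_callback(q, step_callback=lambda: None, total_steps=2, stop_flag={})
    for i in range(2):
        cb(None, i)
    print(q.get(), q.get())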
@@ -469,28 +497,6 @@ def do_mk_img(req: Request):

    thread_data.temp_images.clear()

    # custom model support:
    # the req.use_stable_diffusion_model needs to be a valid path
    # to the ckpt file (without the extension).
    if not os.path.exists(req.use_stable_diffusion_model + '.ckpt'): raise FileNotFoundError(f'Cannot find {req.use_stable_diffusion_model}.ckpt')

    needs_model_reload = False
    if not thread_data.model or thread_data.ckpt_file != req.use_stable_diffusion_model or thread_data.vae_file != req.use_vae_model:
        thread_data.ckpt_file = req.use_stable_diffusion_model
        thread_data.vae_file = req.use_vae_model
        needs_model_reload = True

    if thread_data.device != 'cpu':
        if (thread_data.precision == 'autocast' and (req.use_full_precision or not thread_data.model_is_half)) or \
            (thread_data.precision == 'full' and not req.use_full_precision and not thread_data.force_full_precision):
            thread_data.precision = 'full' if req.use_full_precision else 'autocast'
            needs_model_reload = True

    if needs_model_reload:
        unload_models()
        unload_filters()
        load_model_ckpt()

    if thread_data.turbo != req.turbo and not thread_data.test_sd2:
        thread_data.turbo = req.turbo
        thread_data.model.turbo = req.turbo
@@ -606,7 +612,7 @@ def do_mk_img(req: Request):
        thread_data.modelFS.to(thread_data.device)

    n_steps = req.num_inference_steps if req.init_image is None else t_enc
    img_callback = get_image_progress_generator(req, {"total_steps": n_steps})
    img_callback = get_image_progress_generator(req, data_queue, task_temp_images, step_callback, {"total_steps": n_steps})

    # run the handler
    try:
@@ -615,13 +621,6 @@ def do_mk_img(req: Request):
                x_samples = _txt2img(req.width, req.height, req.num_outputs, req.num_inference_steps, req.guidance_scale, None, opt_C, opt_f, opt_ddim_eta, c, uc, opt_seed, img_callback, mask, req.sampler)
            else:
                x_samples = _img2img(init_latent, t_enc, batch_size, req.guidance_scale, c, uc, req.num_inference_steps, opt_ddim_eta, opt_seed, img_callback, mask, opt_C, req.height, req.width, opt_f)

            if req.stream_progress_updates:
                yield from x_samples
            if hasattr(thread_data, 'partial_x_samples'):
                if thread_data.partial_x_samples is not None:
                    x_samples = thread_data.partial_x_samples
                del thread_data.partial_x_samples
        except UserInitiatedStop:
            if not hasattr(thread_data, 'partial_x_samples'):
                continue
@@ -666,9 +665,11 @@ def do_mk_img(req: Request):
                    save_metadata(meta_out_path, req, prompts[0], opt_seed)

                if return_orig_img:
                    img_str = img_to_base64_str(img, req.output_format)
                    img_buffer = img_to_buffer(img, req.output_format)
                    img_str = buffer_to_base64_str(img_buffer, req.output_format)
                    res_image_orig = ResponseImage(data=img_str, seed=opt_seed)
                    res.images.append(res_image_orig)
                    task_temp_images[i] = img_buffer

                if req.save_to_disk_path is not None:
                    res_image_orig.path_abs = img_out_path
@@ -684,9 +685,11 @@ def do_mk_img(req: Request):
                        filters_applied.append(req.use_upscale)
                    if (len(filters_applied) > 0):
                        filtered_image = Image.fromarray(img_data[i])
                        filtered_img_data = img_to_base64_str(filtered_image, req.output_format)
                        filtered_buffer = img_to_buffer(filtered_image, req.output_format)
                        filtered_img_data = buffer_to_base64_str(filtered_buffer, req.output_format)
                        response_image = ResponseImage(data=filtered_img_data, seed=opt_seed)
                        res.images.append(response_image)
                        task_temp_images[i] = filtered_buffer
                        if req.save_to_disk_path is not None:
                            filtered_img_out_path = get_base_path(req.save_to_disk_path, req.session_id, prompts[0], img_id, req.output_format, "_".join(filters_applied))
                            save_image(filtered_image, filtered_img_out_path)
@@ -705,7 +708,10 @@ def do_mk_img(req: Request):
        print(f'memory_final = {round(torch.cuda.memory_allocated(thread_data.device) / 1e6, 2)}Mb')

    print('Task completed')
    yield json.dumps(res.json())
    res = res.json()
    data_queue.put(json.dumps(res))

    return res

def save_image(img, img_out_path):
    try:
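The end of do_mk_img() shows the dual delivery: the final response dict is serialized onto data_queue (so streaming clients receive it as the last message) and also returned (so the task thread can assign it to task.response without re-parsing JSON). A tiny sketch:

    import json
    import queue

    def finish(res: dict, data_queue: queue.Queue) -> dict:
        data_queue.put(json.dumps(res))  # last message for streaming clients
        return res                       # direct return value for the task thread

    q = queue.Queue()
    response = finish({"status": "succeeded", "output": []}, q)
    assert json.loads(q.get()) == response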
@@ -771,7 +777,7 @@ def _txt2img(opt_W, opt_H, opt_n_samples, opt_ddim_steps, opt_scale, start_code,
        sampler.make_schedule(ddim_num_steps=opt_ddim_steps, ddim_eta=opt_ddim_eta, verbose=False)

        samples_ddim = sampler.sample(
        samples_ddim, intermediates = sampler.sample(
            S=opt_ddim_steps,
            conditioning=c,
            batch_size=opt_n_samples,
@@ -790,7 +796,7 @@ def _txt2img(opt_W, opt_H, opt_n_samples, opt_ddim_steps, opt_scale, start_code,
        if sampler_name == 'ddim':
            thread_data.model.make_schedule(ddim_num_steps=opt_ddim_steps, ddim_eta=opt_ddim_eta, verbose=False)

        samples_ddim = thread_data.model.sample(
        samples_ddim, intermediates = thread_data.model.sample(
            S=opt_ddim_steps,
            conditioning=c,
            seed=opt_seed,
@@ -804,7 +810,7 @@ def _txt2img(opt_W, opt_H, opt_n_samples, opt_ddim_steps, opt_scale, start_code,
            mask=mask,
            sampler = sampler_name,
        )
    yield from samples_ddim
    return samples_ddim

def _img2img(init_latent, t_enc, batch_size, opt_scale, c, uc, opt_ddim_steps, opt_ddim_eta, opt_seed, img_callback, mask, opt_C=1, opt_H=1, opt_W=1, opt_f=1):
    # encode (scaled latent)
@@ -831,7 +837,7 @@ def _img2img(init_latent, t_enc, batch_size, opt_scale, c, uc, opt_ddim_steps, o
        )

        # decode it
        samples_ddim = thread_data.model.sample(
        samples_ddim, intermediates = thread_data.model.sample(
            t_enc,
            c,
            z_enc,
@@ -842,7 +848,7 @@ def _img2img(init_latent, t_enc, batch_size, opt_scale, c, uc, opt_ddim_steps, o
            x_T=x_T,
            sampler = 'ddim'
        )
    yield from samples_ddim
    return samples_ddim

def gc():
    gc_collect()
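With the yield from lines gone, _txt2img() and _img2img() rely on the samplers driving img_callback themselves and returning a (samples, intermediates) pair, which is why the checked-out Stable Diffusion revisions no longer need the ddim_callback patches. A hypothetical sketch of a sampler honoring that convention:

    def sample(S: int, img_callback=None):
        samples, intermediates = None, []
        for i in range(S):
            samples = f"latents@step{i}"  # stand-in for the real tensors
            intermediates.append(samples)
            if img_callback is not None:
                img_callback(samples, i)  # progress hook; may raise to cancel
        return samples, intermediates     # a tuple, so callers unpack both

    samples_ddim, intermediates = sample(3, img_callback=lambda x, i: print("step", i))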
@@ -910,8 +916,16 @@ def load_mask(mask_str, h0, w0, newH, newW, invert=False):

# https://stackoverflow.com/a/61114178
def img_to_base64_str(img, output_format="PNG"):
    buffered = img_to_buffer(img, output_format)
    return buffer_to_base64_str(buffered, output_format)

def img_to_buffer(img, output_format="PNG"):
    buffered = BytesIO()
    img.save(buffered, format=output_format)
    buffered.seek(0)
    return buffered

def buffer_to_base64_str(buffered, output_format="PNG"):
    buffered.seek(0)
    img_byte = buffered.getvalue()
    mime_type = "image/png" if output_format.lower() == "png" else "image/jpeg"
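The base64 helper is split so the raw buffer can be reused: the same BytesIO that feeds the temp-image endpoint can also be encoded for the response. A usage sketch (the data-URL return shape of buffer_to_base64_str is assumed from the visible mime_type line, since the hunk cuts off before the return):

    import base64
    from io import BytesIO
    from PIL import Image

    img = Image.new("RGB", (8, 8), "red")

    buffered = BytesIO()
    img.save(buffered, format="PNG")
    buffered.seek(0)  # this is what img_to_buffer() returns

    payload = base64.b64encode(buffered.getvalue()).decode()
    print(f"data:image/png;base64,{payload}"[:40])  # what buffer_to_base64_str() likely builds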
@@ -283,45 +283,24 @@ def thread_render(device):
        print(f'Session {task.request.session_id} starting task {id(task)} on {runtime.thread_data.device_name}')
        if not task.lock.acquire(blocking=False): raise Exception('Got locked task from queue.')
        try:
            if runtime.thread_data.device == 'cpu' and is_alive() > 1:
                # CPU is not the only device. Keep track of active time to unload resources later.
                runtime.thread_data.lastActive = time.time()
            # Open data generator.
            res = runtime.mk_img(task.request)
            if current_model_path == task.request.use_stable_diffusion_model:
                current_state = ServerStates.Rendering
            else:
            if runtime.is_model_reload_necessary(task.request):
                current_state = ServerStates.LoadingModel
            # Start reading from generator.
            dataQueue = None
            if task.request.stream_progress_updates:
                dataQueue = task.buffer_queue
            for result in res:
                if current_state == ServerStates.LoadingModel:
                    current_state = ServerStates.Rendering
                    current_model_path = task.request.use_stable_diffusion_model
                    current_vae_path = task.request.use_vae_model
                runtime.reload_model()
                current_model_path = task.request.use_stable_diffusion_model
                current_vae_path = task.request.use_vae_model

            def step_callback():
                if isinstance(current_state_error, SystemExit) or isinstance(current_state_error, StopAsyncIteration) or isinstance(task.error, StopAsyncIteration):
                    runtime.thread_data.stop_processing = True
                    if isinstance(current_state_error, StopAsyncIteration):
                        task.error = current_state_error
                        current_state_error = None
                        print(f'Session {task.request.session_id} sent cancel signal for task {id(task)}')
                if dataQueue:
                    dataQueue.put(result)
                if isinstance(result, str):
                    result = json.loads(result)
                task.response = result
                if 'output' in result:
                    for out_obj in result['output']:
                        if 'path' in out_obj:
                            img_id = out_obj['path'][out_obj['path'].rindex('/') + 1:]
                            task.temp_images[int(img_id)] = runtime.thread_data.temp_images[out_obj['path'][11:]]
                        elif 'data' in out_obj:
                            buf = runtime.base64_str_to_buffer(out_obj['data'])
                            task.temp_images[result['output'].index(out_obj)] = buf
                # Before looping back to the generator, mark cache as still alive.
                task_cache.keep(task.request.session_id, TASK_TTL)

                task_cache.keep(task.request.session_id, TASK_TTL)

            current_state = ServerStates.Rendering
            task.response = runtime.mk_img(task.request, task.buffer_queue, task.temp_images, step_callback)
        except Exception as e:
            task.error = e
            print(traceback.format_exc())
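In the old design, cancellation was handled while iterating the generator; now it lives in step_callback, which the renderer invokes once per step. A reduced, runnable sketch of that control flow (names hypothetical, timing exaggerated for the demo):

    import queue
    import threading
    import time

    cancel_requested = threading.Event()  # set by a cancel request in the real server
    stop_processing = threading.Event()   # read by the render loop

    def step_callback():
        if cancel_requested.is_set():
            stop_processing.set()

    def fake_mk_img(data_queue: queue.Queue, step_callback) -> str:
        for i in range(100):
            if stop_processing.is_set():
                return "stopped"
            data_queue.put(f"step {i}")
            step_callback()
            time.sleep(0.01)
        return "done"

    q = queue.Queue()
    threading.Timer(0.05, cancel_requested.set).start()  # simulate a user cancel
    print(fake_mk_img(q, step_callback))  # "stopped"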