diff --git a/scripts/on_sd_start.bat b/scripts/on_sd_start.bat index 87b6ada0..e088ed57 100644 --- a/scripts/on_sd_start.bat +++ b/scripts/on_sd_start.bat @@ -42,13 +42,9 @@ if NOT DEFINED test_sd2 set test_sd2=N if "%test_sd2%" == "N" ( @call git -c advice.detachedHead=false checkout 7f32368ed1030a6e710537047bacd908adea183a - - @call git apply --whitespace=warn ..\ui\sd_internal\ddim_callback.patch ) if "%test_sd2%" == "Y" ( - @call git -c advice.detachedHead=false checkout 6e2f82187f8ecc4ea59ac37dc239cfcc78038f6d - - @call git apply ..\ui\sd_internal\ddim_callback_sd2.patch + @call git -c advice.detachedHead=false checkout 8878d67decd3deb3c98472c1e39d2a51dc5950f9 ) @cd .. @@ -66,8 +62,6 @@ if NOT DEFINED test_sd2 set test_sd2=N @cd stable-diffusion @call git -c advice.detachedHead=false checkout 7f32368ed1030a6e710537047bacd908adea183a - @call git apply --whitespace=warn ..\ui\sd_internal\ddim_callback.patch - @cd .. ) diff --git a/scripts/on_sd_start.sh b/scripts/on_sd_start.sh index 460d58a7..199a9ae8 100755 --- a/scripts/on_sd_start.sh +++ b/scripts/on_sd_start.sh @@ -37,12 +37,8 @@ if [ -e "scripts/install_status.txt" ] && [ `grep -c sd_git_cloned scripts/insta if [ "$test_sd2" == "N" ]; then git -c advice.detachedHead=false checkout 7f32368ed1030a6e710537047bacd908adea183a - - git apply --whitespace=warn ../ui/sd_internal/ddim_callback.patch || fail "ddim patch failed" elif [ "$test_sd2" == "Y" ]; then - git -c advice.detachedHead=false checkout 992f111312afa9ec1a01beaa9733cb9728f5acd3 - - git apply --whitespace=warn ../ui/sd_internal/ddim_callback_sd2.patch || fail "sd2 ddim patch failed" + git -c advice.detachedHead=false checkout 8878d67decd3deb3c98472c1e39d2a51dc5950f9 fi cd .. @@ -58,8 +54,6 @@ else cd stable-diffusion git -c advice.detachedHead=false checkout 7f32368ed1030a6e710537047bacd908adea183a - git apply --whitespace=warn ../ui/sd_internal/ddim_callback.patch || fail "ddim patch failed" - cd .. fi diff --git a/ui/sd_internal/runtime.py b/ui/sd_internal/runtime.py index 2a989e46..69031fba 100644 --- a/ui/sd_internal/runtime.py +++ b/ui/sd_internal/runtime.py @@ -7,6 +7,7 @@ Notes: import json import os, re import traceback +import queue import torch import numpy as np from gc import collect as gc_collect @@ -392,9 +393,34 @@ def apply_filters(filter_name, image_data, model_path=None): return image_data -def mk_img(req: Request): +def is_model_reload_necessary(req: Request): + # custom model support: + # the req.use_stable_diffusion_model needs to be a valid path + # to the ckpt file (without the extension). + if not os.path.exists(req.use_stable_diffusion_model + '.ckpt'): raise FileNotFoundError(f'Cannot find {req.use_stable_diffusion_model}.ckpt') + + needs_model_reload = False + if not thread_data.model or thread_data.ckpt_file != req.use_stable_diffusion_model or thread_data.vae_file != req.use_vae_model: + thread_data.ckpt_file = req.use_stable_diffusion_model + thread_data.vae_file = req.use_vae_model + needs_model_reload = True + + if thread_data.device != 'cpu': + if (thread_data.precision == 'autocast' and (req.use_full_precision or not thread_data.model_is_half)) or \ + (thread_data.precision == 'full' and not req.use_full_precision and not thread_data.force_full_precision): + thread_data.precision = 'full' if req.use_full_precision else 'autocast' + needs_model_reload = True + + return needs_model_reload + +def reload_model(): + unload_models() + unload_filters() + load_model_ckpt() + +def mk_img(req: Request, data_queue: queue.Queue, task_temp_images: list, step_callback): try: - yield from do_mk_img(req) + return do_mk_img(req, data_queue, task_temp_images, step_callback) except Exception as e: print(traceback.format_exc()) @@ -405,12 +431,13 @@ def mk_img(req: Request): thread_data.model.model2.to("cpu") gc() # Release from memory. - yield json.dumps({ + data_queue.put(json.dumps({ "status": 'failed', "detail": str(e) - }) + })) + raise e -def update_temp_img(req, x_samples): +def update_temp_img(req, x_samples, task_temp_images: list): partial_images = [] for i in range(req.num_outputs): if thread_data.test_sd2: @@ -421,19 +448,18 @@ def update_temp_img(req, x_samples): x_sample = 255.0 * rearrange(x_sample[0].cpu().numpy(), "c h w -> h w c") x_sample = x_sample.astype(np.uint8) img = Image.fromarray(x_sample) - buf = BytesIO() - img.save(buf, format='JPEG') - buf.seek(0) + buf = img_to_buffer(img, output_format='JPEG') del img, x_sample, x_sample_ddim # don't delete x_samples, it is used in the code that called this callback thread_data.temp_images[str(req.session_id) + '/' + str(i)] = buf + task_temp_images[i] = buf partial_images.append({'path': f'/image/tmp/{req.session_id}/{i}'}) return partial_images # Build and return the apropriate generator for do_mk_img -def get_image_progress_generator(req, extra_props=None): +def get_image_progress_generator(req, data_queue: queue.Queue, task_temp_images: list, step_callback, extra_props=None): if not req.stream_progress_updates: def empty_callback(x_samples, i): return x_samples return empty_callback @@ -452,15 +478,17 @@ def get_image_progress_generator(req, extra_props=None): progress.update(extra_props) if req.stream_image_progress and i % 5 == 0: - progress['output'] = update_temp_img(req, x_samples) + progress['output'] = update_temp_img(req, x_samples, task_temp_images) - yield json.dumps(progress) + data_queue.put(json.dumps(progress)) + + step_callback() if thread_data.stop_processing: raise UserInitiatedStop("User requested that we stop processing") return img_callback -def do_mk_img(req: Request): +def do_mk_img(req: Request, data_queue: queue.Queue, task_temp_images: list, step_callback): thread_data.stop_processing = False res = Response() @@ -469,28 +497,6 @@ def do_mk_img(req: Request): thread_data.temp_images.clear() - # custom model support: - # the req.use_stable_diffusion_model needs to be a valid path - # to the ckpt file (without the extension). - if not os.path.exists(req.use_stable_diffusion_model + '.ckpt'): raise FileNotFoundError(f'Cannot find {req.use_stable_diffusion_model}.ckpt') - - needs_model_reload = False - if not thread_data.model or thread_data.ckpt_file != req.use_stable_diffusion_model or thread_data.vae_file != req.use_vae_model: - thread_data.ckpt_file = req.use_stable_diffusion_model - thread_data.vae_file = req.use_vae_model - needs_model_reload = True - - if thread_data.device != 'cpu': - if (thread_data.precision == 'autocast' and (req.use_full_precision or not thread_data.model_is_half)) or \ - (thread_data.precision == 'full' and not req.use_full_precision and not thread_data.force_full_precision): - thread_data.precision = 'full' if req.use_full_precision else 'autocast' - needs_model_reload = True - - if needs_model_reload: - unload_models() - unload_filters() - load_model_ckpt() - if thread_data.turbo != req.turbo and not thread_data.test_sd2: thread_data.turbo = req.turbo thread_data.model.turbo = req.turbo @@ -606,7 +612,7 @@ def do_mk_img(req: Request): thread_data.modelFS.to(thread_data.device) n_steps = req.num_inference_steps if req.init_image is None else t_enc - img_callback = get_image_progress_generator(req, {"total_steps": n_steps}) + img_callback = get_image_progress_generator(req, data_queue, task_temp_images, step_callback, {"total_steps": n_steps}) # run the handler try: @@ -615,13 +621,6 @@ def do_mk_img(req: Request): x_samples = _txt2img(req.width, req.height, req.num_outputs, req.num_inference_steps, req.guidance_scale, None, opt_C, opt_f, opt_ddim_eta, c, uc, opt_seed, img_callback, mask, req.sampler) else: x_samples = _img2img(init_latent, t_enc, batch_size, req.guidance_scale, c, uc, req.num_inference_steps, opt_ddim_eta, opt_seed, img_callback, mask, opt_C, req.height, req.width, opt_f) - - if req.stream_progress_updates: - yield from x_samples - if hasattr(thread_data, 'partial_x_samples'): - if thread_data.partial_x_samples is not None: - x_samples = thread_data.partial_x_samples - del thread_data.partial_x_samples except UserInitiatedStop: if not hasattr(thread_data, 'partial_x_samples'): continue @@ -666,9 +665,11 @@ def do_mk_img(req: Request): save_metadata(meta_out_path, req, prompts[0], opt_seed) if return_orig_img: - img_str = img_to_base64_str(img, req.output_format) + img_buffer = img_to_buffer(img, req.output_format) + img_str = buffer_to_base64_str(img_buffer, req.output_format) res_image_orig = ResponseImage(data=img_str, seed=opt_seed) res.images.append(res_image_orig) + task_temp_images[i] = img_buffer if req.save_to_disk_path is not None: res_image_orig.path_abs = img_out_path @@ -684,9 +685,11 @@ def do_mk_img(req: Request): filters_applied.append(req.use_upscale) if (len(filters_applied) > 0): filtered_image = Image.fromarray(img_data[i]) - filtered_img_data = img_to_base64_str(filtered_image, req.output_format) + filtered_buffer = img_to_buffer(filtered_image, req.output_format) + filtered_img_data = buffer_to_base64_str(filtered_buffer, req.output_format) response_image = ResponseImage(data=filtered_img_data, seed=opt_seed) res.images.append(response_image) + task_temp_images[i] = filtered_buffer if req.save_to_disk_path is not None: filtered_img_out_path = get_base_path(req.save_to_disk_path, req.session_id, prompts[0], img_id, req.output_format, "_".join(filters_applied)) save_image(filtered_image, filtered_img_out_path) @@ -705,7 +708,10 @@ def do_mk_img(req: Request): print(f'memory_final = {round(torch.cuda.memory_allocated(thread_data.device) / 1e6, 2)}Mb') print('Task completed') - yield json.dumps(res.json()) + res = res.json() + data_queue.put(json.dumps(res)) + + return res def save_image(img, img_out_path): try: @@ -771,7 +777,7 @@ def _txt2img(opt_W, opt_H, opt_n_samples, opt_ddim_steps, opt_scale, start_code, sampler.make_schedule(ddim_num_steps=opt_ddim_steps, ddim_eta=opt_ddim_eta, verbose=False) - samples_ddim = sampler.sample( + samples_ddim, intermediates = sampler.sample( S=opt_ddim_steps, conditioning=c, batch_size=opt_n_samples, @@ -790,7 +796,7 @@ def _txt2img(opt_W, opt_H, opt_n_samples, opt_ddim_steps, opt_scale, start_code, if sampler_name == 'ddim': thread_data.model.make_schedule(ddim_num_steps=opt_ddim_steps, ddim_eta=opt_ddim_eta, verbose=False) - samples_ddim = thread_data.model.sample( + samples_ddim, intermediates = thread_data.model.sample( S=opt_ddim_steps, conditioning=c, seed=opt_seed, @@ -804,7 +810,7 @@ def _txt2img(opt_W, opt_H, opt_n_samples, opt_ddim_steps, opt_scale, start_code, mask=mask, sampler = sampler_name, ) - yield from samples_ddim + return samples_ddim def _img2img(init_latent, t_enc, batch_size, opt_scale, c, uc, opt_ddim_steps, opt_ddim_eta, opt_seed, img_callback, mask, opt_C=1, opt_H=1, opt_W=1, opt_f=1): # encode (scaled latent) @@ -831,7 +837,7 @@ def _img2img(init_latent, t_enc, batch_size, opt_scale, c, uc, opt_ddim_steps, o ) # decode it - samples_ddim = thread_data.model.sample( + samples_ddim, intermediates = thread_data.model.sample( t_enc, c, z_enc, @@ -842,7 +848,7 @@ def _img2img(init_latent, t_enc, batch_size, opt_scale, c, uc, opt_ddim_steps, o x_T=x_T, sampler = 'ddim' ) - yield from samples_ddim + return samples_ddim def gc(): gc_collect() @@ -910,8 +916,16 @@ def load_mask(mask_str, h0, w0, newH, newW, invert=False): # https://stackoverflow.com/a/61114178 def img_to_base64_str(img, output_format="PNG"): + buffered = img_to_buffer(img, output_format) + return buffer_to_base64_str(buffered, output_format) + +def img_to_buffer(img, output_format="PNG"): buffered = BytesIO() img.save(buffered, format=output_format) + buffered.seek(0) + return buffered + +def buffer_to_base64_str(buffered, output_format="PNG"): buffered.seek(0) img_byte = buffered.getvalue() mime_type = "image/png" if output_format.lower() == "png" else "image/jpeg" diff --git a/ui/sd_internal/task_manager.py b/ui/sd_internal/task_manager.py index bd87517b..cfee79f3 100644 --- a/ui/sd_internal/task_manager.py +++ b/ui/sd_internal/task_manager.py @@ -283,45 +283,24 @@ def thread_render(device): print(f'Session {task.request.session_id} starting task {id(task)} on {runtime.thread_data.device_name}') if not task.lock.acquire(blocking=False): raise Exception('Got locked task from queue.') try: - if runtime.thread_data.device == 'cpu' and is_alive() > 1: - # CPU is not the only device. Keep track of active time to unload resources later. - runtime.thread_data.lastActive = time.time() - # Open data generator. - res = runtime.mk_img(task.request) - if current_model_path == task.request.use_stable_diffusion_model: - current_state = ServerStates.Rendering - else: + if runtime.is_model_reload_necessary(task.request): current_state = ServerStates.LoadingModel - # Start reading from generator. - dataQueue = None - if task.request.stream_progress_updates: - dataQueue = task.buffer_queue - for result in res: - if current_state == ServerStates.LoadingModel: - current_state = ServerStates.Rendering - current_model_path = task.request.use_stable_diffusion_model - current_vae_path = task.request.use_vae_model + runtime.reload_model() + current_model_path = task.request.use_stable_diffusion_model + current_vae_path = task.request.use_vae_model + + def step_callback(): if isinstance(current_state_error, SystemExit) or isinstance(current_state_error, StopAsyncIteration) or isinstance(task.error, StopAsyncIteration): runtime.thread_data.stop_processing = True if isinstance(current_state_error, StopAsyncIteration): task.error = current_state_error current_state_error = None print(f'Session {task.request.session_id} sent cancel signal for task {id(task)}') - if dataQueue: - dataQueue.put(result) - if isinstance(result, str): - result = json.loads(result) - task.response = result - if 'output' in result: - for out_obj in result['output']: - if 'path' in out_obj: - img_id = out_obj['path'][out_obj['path'].rindex('/') + 1:] - task.temp_images[int(img_id)] = runtime.thread_data.temp_images[out_obj['path'][11:]] - elif 'data' in out_obj: - buf = runtime.base64_str_to_buffer(out_obj['data']) - task.temp_images[result['output'].index(out_obj)] = buf - # Before looping back to the generator, mark cache as still alive. - task_cache.keep(task.request.session_id, TASK_TTL) + + task_cache.keep(task.request.session_id, TASK_TTL) + + current_state = ServerStates.Rendering + task.response = runtime.mk_img(task.request, task.buffer_queue, task.temp_images, step_callback) except Exception as e: task.error = e print(traceback.format_exc())