From 35ff4f439ef4cb30ed4d0aa201f355914bd7a27d Mon Sep 17 00:00:00 2001 From: cmdr2 Date: Wed, 14 Dec 2022 16:30:19 +0530 Subject: [PATCH] Refactor save_to_disk --- ui/sd_internal/renderer.py | 57 ++----------------------- ui/sd_internal/save_utils.py | 80 ++++++++++++++++++++++++++++++++++++ 2 files changed, 83 insertions(+), 54 deletions(-) create mode 100644 ui/sd_internal/save_utils.py diff --git a/ui/sd_internal/renderer.py b/ui/sd_internal/renderer.py index 0420d51d..55b2c4a7 100644 --- a/ui/sd_internal/renderer.py +++ b/ui/sd_internal/renderer.py @@ -1,13 +1,9 @@ import queue import time import json -import os -import base64 -import re -import traceback import logging -from sd_internal import device_manager +from sd_internal import device_manager, save_utils from sd_internal import TaskData, Response, Image as ResponseImage, UserInitiatedStop from diffusionkit import model_loader, image_generator, image_utils, filters as image_filters, data_utils @@ -20,8 +16,6 @@ context = Context() # thread-local runtime data (bound locally to this thread), for e.g. device, references to loaded models, optimization flags etc ''' -filename_regex = re.compile('[^a-zA-Z0-9]') - def init(device): ''' Initializes the fields that will be bound to this runtime's context, and sets the current torch device @@ -33,7 +27,7 @@ def init(device): device_manager.device_init(context, device) def make_images(req: GenerateImageRequest, task_data: TaskData, data_queue: queue.Queue, task_temp_images: list, step_callback): - log.info(f'request: {get_printable_request(req)}') + log.info(f'request: {save_utils.get_printable_request(req)}') log.info(f'task data: {task_data.dict()}') images = _make_images_internal(req, task_data, data_queue, task_temp_images, step_callback) @@ -50,8 +44,7 @@ def _make_images_internal(req: GenerateImageRequest, task_data: TaskData, data_q filtered_images = apply_filters(task_data, images, user_stopped) if task_data.save_to_disk_path is not None: - save_folder_path = os.path.join(task_data.save_to_disk_path, filename_regex.sub('_', task_data.session_id)) - save_to_disk(images, filtered_images, save_folder_path, req, task_data) + save_utils.save_to_disk(images, filtered_images, req, task_data) return filtered_images if task_data.show_only_filtered_image else images + filtered_images @@ -92,17 +85,6 @@ def apply_filters(task_data: TaskData, images: list, user_stopped): return filtered_images -def save_to_disk(images: list, filtered_images: list, save_folder_path, req: GenerateImageRequest, task_data: TaskData): - metadata_entries = get_metadata_entries(req, task_data) - - if task_data.show_only_filtered_image or filtered_images == images: - data_utils.save_images(filtered_images, save_folder_path, file_name=make_filename_callback(req), output_format=task_data.output_format, output_quality=task_data.output_quality) - data_utils.save_metadata(metadata_entries, save_folder_path, file_name=make_filename_callback(req), output_format=task_data.metadata_output_format) - else: - data_utils.save_images(images, save_folder_path, file_name=make_filename_callback(req), output_format=task_data.output_format, output_quality=task_data.output_quality) - data_utils.save_images(filtered_images, save_folder_path, file_name=make_filename_callback(req, suffix='filtered'), output_format=task_data.output_format, output_quality=task_data.output_quality) - data_utils.save_metadata(metadata_entries, save_folder_path, file_name=make_filename_callback(req, suffix='filtered'), output_format=task_data.metadata_output_format) - def construct_response(images: list, task_data: TaskData, base_seed: int): return [ ResponseImage( @@ -111,39 +93,6 @@ def construct_response(images: list, task_data: TaskData, base_seed: int): ) for i, img in enumerate(images) ] -def get_metadata_entries(req: GenerateImageRequest, task_data: TaskData): - metadata = get_printable_request(req) - metadata.update({ - 'use_stable_diffusion_model': task_data.use_stable_diffusion_model, - 'use_vae_model': task_data.use_vae_model, - 'use_hypernetwork_model': task_data.use_hypernetwork_model, - 'use_face_correction': task_data.use_face_correction, - 'use_upscale': task_data.use_upscale, - }) - - entries = [metadata.copy() for _ in range(req.num_outputs)] - for i, entry in enumerate(entries): - entry['seed'] = req.seed + i - return entries - -def get_printable_request(req: GenerateImageRequest): - metadata = req.dict() - del metadata['init_image'] - del metadata['init_image_mask'] - return metadata - -def make_filename_callback(req: GenerateImageRequest, suffix=None): - def make_filename(i): - img_id = base64.b64encode(int(time.time()+i).to_bytes(8, 'big')).decode() # Generate unique ID based on time. - img_id = img_id.translate({43:None, 47:None, 61:None})[-8:] # Remove + / = and keep last 8 chars. - - prompt_flattened = filename_regex.sub('_', req.prompt)[:50] - name = f"{prompt_flattened}_{img_id}" - name = name if suffix is None else f'{name}_{suffix}' - return name - - return make_filename - def make_step_callback(req: GenerateImageRequest, task_data: TaskData, data_queue: queue.Queue, task_temp_images: list, step_callback, stream_image_progress: bool): n_steps = req.num_inference_steps if req.init_image is None else int(req.num_inference_steps * req.prompt_strength) last_callback_time = -1 diff --git a/ui/sd_internal/save_utils.py b/ui/sd_internal/save_utils.py new file mode 100644 index 00000000..29c4a29c --- /dev/null +++ b/ui/sd_internal/save_utils.py @@ -0,0 +1,80 @@ +import os +import time +import base64 +import re + +from diffusionkit import data_utils +from diffusionkit.types import GenerateImageRequest + +from sd_internal import TaskData + +filename_regex = re.compile('[^a-zA-Z0-9]') + +# keep in sync with `ui/media/js/dnd.js` +TASK_TEXT_MAPPING = { + 'prompt': 'Prompt', + 'width': 'Width', + 'height': 'Height', + 'seed': 'Seed', + 'num_inference_steps': 'Steps', + 'guidance_scale': 'Guidance Scale', + 'prompt_strength': 'Prompt Strength', + 'use_face_correction': 'Use Face Correction', + 'use_upscale': 'Use Upscaling', + 'sampler_name': 'Sampler', + 'negative_prompt': 'Negative Prompt', + 'use_stable_diffusion_model': 'Stable Diffusion model', + 'use_hypernetwork_model': 'Hypernetwork model', + 'hypernetwork_strength': 'Hypernetwork Strength' +} + +def save_to_disk(images: list, filtered_images: list, req: GenerateImageRequest, task_data: TaskData): + save_folder_path = os.path.join(task_data.save_to_disk_path, filename_regex.sub('_', task_data.session_id)) + metadata_entries = get_metadata_entries(req, task_data) + + if task_data.show_only_filtered_image or filtered_images == images: + data_utils.save_images(filtered_images, save_folder_path, file_name=make_filename_callback(req), output_format=task_data.output_format, output_quality=task_data.output_quality) + data_utils.save_metadata(metadata_entries, save_folder_path, file_name=make_filename_callback(req), output_format=task_data.metadata_output_format) + else: + data_utils.save_images(images, save_folder_path, file_name=make_filename_callback(req), output_format=task_data.output_format, output_quality=task_data.output_quality) + data_utils.save_images(filtered_images, save_folder_path, file_name=make_filename_callback(req, suffix='filtered'), output_format=task_data.output_format, output_quality=task_data.output_quality) + data_utils.save_metadata(metadata_entries, save_folder_path, file_name=make_filename_callback(req, suffix='filtered'), output_format=task_data.metadata_output_format) + +def get_metadata_entries(req: GenerateImageRequest, task_data: TaskData): + metadata = get_printable_request(req) + metadata.update({ + 'use_stable_diffusion_model': task_data.use_stable_diffusion_model, + 'use_vae_model': task_data.use_vae_model, + 'use_hypernetwork_model': task_data.use_hypernetwork_model, + 'use_face_correction': task_data.use_face_correction, + 'use_upscale': task_data.use_upscale, + }) + + # if text, format it in the text format expected by the UI + is_txt_format = (task_data.metadata_output_format.lower() == 'txt') + if is_txt_format: + metadata = {TASK_TEXT_MAPPING[key]: val for key, val in metadata.items() if key in TASK_TEXT_MAPPING} + + entries = [metadata.copy() for _ in range(req.num_outputs)] + for i, entry in enumerate(entries): + entry['Seed' if is_txt_format else 'seed'] = req.seed + i + + return entries + +def get_printable_request(req: GenerateImageRequest): + metadata = req.dict() + del metadata['init_image'] + del metadata['init_image_mask'] + return metadata + +def make_filename_callback(req: GenerateImageRequest, suffix=None): + def make_filename(i): + img_id = base64.b64encode(int(time.time()+i).to_bytes(8, 'big')).decode() # Generate unique ID based on time. + img_id = img_id.translate({43:None, 47:None, 61:None})[-8:] # Remove + / = and keep last 8 chars. + + prompt_flattened = filename_regex.sub('_', req.prompt)[:50] + name = f"{prompt_flattened}_{img_id}" + name = name if suffix is None else f'{name}_{suffix}' + return name + + return make_filename \ No newline at end of file