From 536082c1a625d8c45dffafd0acf201163c6d8d9b Mon Sep 17 00:00:00 2001 From: cmdr2 Date: Thu, 31 Aug 2023 15:57:53 +0530 Subject: [PATCH] Save filtered images to disk if required by the API, for e.g. when clicking 'Upscale' or 'Fix Faces on the image --- ui/easydiffusion/server.py | 12 +++-- ui/easydiffusion/tasks/filter_images.py | 61 ++++++++++++++++++++++--- ui/easydiffusion/tasks/render_images.py | 49 +++++++++++++------- ui/easydiffusion/types.py | 13 +++++- ui/easydiffusion/utils/save_utils.py | 58 +++++++++++++++-------- ui/media/js/main.js | 4 ++ 6 files changed, 151 insertions(+), 46 deletions(-) diff --git a/ui/easydiffusion/server.py b/ui/easydiffusion/server.py index 5af198fe..b8a7b3ff 100644 --- a/ui/easydiffusion/server.py +++ b/ui/easydiffusion/server.py @@ -15,8 +15,10 @@ from easydiffusion.types import ( FilterImageRequest, MergeRequest, TaskData, + RenderTaskData, ModelsData, OutputFormatData, + SaveToDiskData, convert_legacy_render_req_to_new, ) from easydiffusion.utils import log @@ -262,9 +264,10 @@ def render_internal(req: dict): # separate out the request data into rendering and task-specific data render_req: GenerateImageRequest = GenerateImageRequest.parse_obj(req) - task_data: TaskData = TaskData.parse_obj(req) + task_data: RenderTaskData = RenderTaskData.parse_obj(req) models_data: ModelsData = ModelsData.parse_obj(req) output_format: OutputFormatData = OutputFormatData.parse_obj(req) + save_data: SaveToDiskData = SaveToDiskData.parse_obj(req) # Overwrite user specified save path config = app.getConfig() @@ -281,7 +284,7 @@ def render_internal(req: dict): ) # enqueue the task - task = RenderTask(render_req, task_data, models_data, output_format) + task = RenderTask(render_req, task_data, models_data, output_format, save_data) return enqueue_task(task) except HTTPException as e: raise e @@ -292,13 +295,14 @@ def render_internal(req: dict): def filter_internal(req: dict): try: - session_id = req.get("session_id", "session") filter_req: FilterImageRequest = FilterImageRequest.parse_obj(req) + task_data: TaskData = TaskData.parse_obj(req) models_data: ModelsData = ModelsData.parse_obj(req) output_format: OutputFormatData = OutputFormatData.parse_obj(req) + save_data: SaveToDiskData = SaveToDiskData.parse_obj(req) # enqueue the task - task = FilterTask(filter_req, session_id, models_data, output_format) + task = FilterTask(filter_req, task_data, models_data, output_format, save_data) return enqueue_task(task) except HTTPException as e: raise e diff --git a/ui/easydiffusion/tasks/filter_images.py b/ui/easydiffusion/tasks/filter_images.py index 1e653e3e..7d3d1326 100644 --- a/ui/easydiffusion/tasks/filter_images.py +++ b/ui/easydiffusion/tasks/filter_images.py @@ -1,12 +1,25 @@ +import os import json import pprint +import time + +from numpy import base_repr from sdkit.filter import apply_filters from sdkit.models import load_model -from sdkit.utils import img_to_base64_str, get_image, log +from sdkit.utils import img_to_base64_str, get_image, log, save_images from easydiffusion import model_manager, runtime -from easydiffusion.types import FilterImageRequest, FilterImageResponse, ModelsData, OutputFormatData +from easydiffusion.types import ( + FilterImageRequest, + FilterImageResponse, + ModelsData, + OutputFormatData, + SaveToDiskData, + TaskData, + GenerateImageRequest, +) +from easydiffusion.utils.save_utils import format_folder_name from .task import Task @@ -15,13 +28,22 @@ class FilterTask(Task): "For applying filters to input images" def __init__( - self, req: FilterImageRequest, session_id: str, models_data: ModelsData, output_format: OutputFormatData + self, + req: FilterImageRequest, + task_data: TaskData, + models_data: ModelsData, + output_format: OutputFormatData, + save_data: SaveToDiskData, ): - super().__init__(session_id) + super().__init__(task_data.session_id) + + task_data.request_id = self.id self.request = req + self.task_data = task_data self.models_data = models_data self.output_format = output_format + self.save_data = save_data # convert to multi-filter format, if necessary if isinstance(req.filter, str): @@ -34,13 +56,15 @@ class FilterTask(Task): def run(self): "Runs the image filtering task on the assigned thread" + from easydiffusion import app + context = runtime.context model_manager.resolve_model_paths(self.models_data) model_manager.reload_models_if_necessary(context, self.models_data) model_manager.fail_if_models_did_not_load(context) - print_task_info(self.request, self.models_data, self.output_format) + print_task_info(self.request, self.models_data, self.output_format, self.save_data) if isinstance(self.request.image, list): images = [get_image(img) for img in self.request.image] @@ -50,6 +74,26 @@ class FilterTask(Task): images = filter_images(context, images, self.request.filter, self.request.filter_params) output_format = self.output_format + + if self.save_data.save_to_disk_path is not None: + app_config = app.getConfig() + folder_format = app_config.get("folder_format", "$id") + + dummy_req = GenerateImageRequest() + img_id = base_repr(int(time.time() * 10000), 36)[-7:] # Base 36 conversion, 0-9, A-Z + + save_dir_path = os.path.join( + self.save_data.save_to_disk_path, format_folder_name(folder_format, dummy_req, self.task_data) + ) + save_images( + images, + save_dir_path, + file_name=img_id, + output_format=output_format.output_format, + output_quality=output_format.output_quality, + output_lossless=output_format.output_lossless, + ) + images = [ img_to_base64_str( img, output_format.output_format, output_format.output_quality, output_format.output_lossless @@ -60,6 +104,7 @@ class FilterTask(Task): res = FilterImageResponse(self.request, self.models_data, images=images) res = res.json() self.buffer_queue.put(json.dumps(res)) + log.info("Filter task completed") self.response = res @@ -105,11 +150,15 @@ def after_filter(context, filter_name, filter_params, previous_state): load_model(context, "realesrgan") -def print_task_info(req: FilterImageRequest, models_data: ModelsData, output_format: OutputFormatData): +def print_task_info( + req: FilterImageRequest, models_data: ModelsData, output_format: OutputFormatData, save_data: SaveToDiskData +): req_str = pprint.pformat({"filter": req.filter, "filter_params": req.filter_params}).replace("[", "\[") models_data = pprint.pformat(models_data.dict()).replace("[", "\[") output_format = pprint.pformat(output_format.dict()).replace("[", "\[") + save_data = pprint.pformat(save_data.dict()).replace("[", "\[") log.info(f"request: {req_str}") log.info(f"models data: {models_data}") log.info(f"output format: {output_format}") + log.info(f"save data: {save_data}") diff --git a/ui/easydiffusion/tasks/render_images.py b/ui/easydiffusion/tasks/render_images.py index b512e707..528a33b1 100644 --- a/ui/easydiffusion/tasks/render_images.py +++ b/ui/easydiffusion/tasks/render_images.py @@ -4,9 +4,9 @@ import queue import time from easydiffusion import model_manager, runtime -from easydiffusion.types import GenerateImageRequest, ModelsData, OutputFormatData +from easydiffusion.types import GenerateImageRequest, ModelsData, OutputFormatData, SaveToDiskData from easydiffusion.types import Image as ResponseImage -from easydiffusion.types import GenerateImageResponse, TaskData, UserInitiatedStop +from easydiffusion.types import GenerateImageResponse, RenderTaskData, UserInitiatedStop from easydiffusion.utils import get_printable_request, log, save_images_to_disk from sdkit.generate import generate_images from sdkit.utils import ( @@ -28,15 +28,23 @@ class RenderTask(Task): "For image generation" def __init__( - self, req: GenerateImageRequest, task_data: TaskData, models_data: ModelsData, output_format: OutputFormatData + self, + req: GenerateImageRequest, + task_data: RenderTaskData, + models_data: ModelsData, + output_format: OutputFormatData, + save_data: SaveToDiskData, ): super().__init__(task_data.session_id) task_data.request_id = self.id - self.render_request: GenerateImageRequest = req # Initial Request - self.task_data: TaskData = task_data + + self.render_request = req # Initial Request + self.task_data = task_data self.models_data = models_data self.output_format = output_format + self.save_data = save_data + self.temp_images: list = [None] * req.num_outputs * (1 if task_data.show_only_filtered_image else 2) def run(self): @@ -87,6 +95,7 @@ class RenderTask(Task): self.task_data, self.models_data, self.output_format, + self.save_data, self.buffer_queue, self.temp_images, step_callback, @@ -129,22 +138,23 @@ class RenderTask(Task): def make_images( context, req: GenerateImageRequest, - task_data: TaskData, + task_data: RenderTaskData, models_data: ModelsData, output_format: OutputFormatData, + save_data: SaveToDiskData, data_queue: queue.Queue, task_temp_images: list, step_callback, ): context.stop_processing = False - print_task_info(req, task_data, models_data, output_format) + print_task_info(req, task_data, models_data, output_format, save_data) images, seeds = make_images_internal( - context, req, task_data, models_data, output_format, data_queue, task_temp_images, step_callback + context, req, task_data, models_data, output_format, save_data, data_queue, task_temp_images, step_callback ) res = GenerateImageResponse( - req, task_data, models_data, output_format, images=construct_response(images, seeds, output_format) + req, task_data, models_data, output_format, save_data, images=construct_response(images, seeds, output_format) ) res = res.json() data_queue.put(json.dumps(res)) @@ -154,25 +164,32 @@ def make_images( def print_task_info( - req: GenerateImageRequest, task_data: TaskData, models_data: ModelsData, output_format: OutputFormatData + req: GenerateImageRequest, + task_data: RenderTaskData, + models_data: ModelsData, + output_format: OutputFormatData, + save_data: SaveToDiskData, ): - req_str = pprint.pformat(get_printable_request(req, task_data, output_format)).replace("[", "\[") + req_str = pprint.pformat(get_printable_request(req, task_data, output_format, save_data)).replace("[", "\[") task_str = pprint.pformat(task_data.dict()).replace("[", "\[") models_data = pprint.pformat(models_data.dict()).replace("[", "\[") output_format = pprint.pformat(output_format.dict()).replace("[", "\[") + save_data = pprint.pformat(save_data.dict()).replace("[", "\[") log.info(f"request: {req_str}") log.info(f"task data: {task_str}") # log.info(f"models data: {models_data}") log.info(f"output format: {output_format}") + log.info(f"save data: {save_data}") def make_images_internal( context, req: GenerateImageRequest, - task_data: TaskData, + task_data: RenderTaskData, models_data: ModelsData, output_format: OutputFormatData, + save_data: SaveToDiskData, data_queue: queue.Queue, task_temp_images: list, step_callback, @@ -194,8 +211,8 @@ def make_images_internal( filters, filter_params = task_data.filters, task_data.filter_params filtered_images = filter_images(context, images, filters, filter_params) if not user_stopped else images - if task_data.save_to_disk_path is not None: - save_images_to_disk(images, filtered_images, req, task_data, output_format) + if save_data.save_to_disk_path is not None: + save_images_to_disk(images, filtered_images, req, task_data, output_format, save_data) seeds = [*range(req.seed, req.seed + len(images))] if task_data.show_only_filtered_image or filtered_images is images: @@ -207,7 +224,7 @@ def make_images_internal( def generate_images_internal( context, req: GenerateImageRequest, - task_data: TaskData, + task_data: RenderTaskData, models_data: ModelsData, data_queue: queue.Queue, task_temp_images: list, @@ -298,7 +315,7 @@ def construct_response(images: list, seeds: list, output_format: OutputFormatDat def make_step_callback( context, req: GenerateImageRequest, - task_data: TaskData, + task_data: RenderTaskData, data_queue: queue.Queue, task_temp_images: list, step_callback, diff --git a/ui/easydiffusion/types.py b/ui/easydiffusion/types.py index 29082438..ec331c29 100644 --- a/ui/easydiffusion/types.py +++ b/ui/easydiffusion/types.py @@ -58,10 +58,17 @@ class OutputFormatData(BaseModel): output_lossless: bool = False +class SaveToDiskData(BaseModel): + save_to_disk_path: str = None + metadata_output_format: str = "txt" # or "json" + + class TaskData(BaseModel): request_id: str = None session_id: str = "session" - save_to_disk_path: str = None + + +class RenderTaskData(TaskData): vram_usage_level: str = "balanced" # or "low" or "medium" use_face_correction: Union[str, List[str]] = None # or "GFPGANv1.3" @@ -80,7 +87,6 @@ class TaskData(BaseModel): show_only_filtered_image: bool = False block_nsfw: bool = False - metadata_output_format: str = "txt" # or "json" stream_image_progress: bool = False stream_image_progress_interval: int = 5 clip_skip: bool = False @@ -126,12 +132,14 @@ class GenerateImageResponse: task_data: TaskData, models_data: ModelsData, output_format: OutputFormatData, + save_data: SaveToDiskData, images: list, ): self.render_request = render_request self.task_data = task_data self.models_data = models_data self.output_format = output_format + self.save_data = save_data self.images = images def json(self): @@ -141,6 +149,7 @@ class GenerateImageResponse: task_data = self.task_data.dict() task_data.update(self.output_format.dict()) + task_data.update(self.save_data.dict()) res = { "status": "succeeded", diff --git a/ui/easydiffusion/utils/save_utils.py b/ui/easydiffusion/utils/save_utils.py index 3bd5efba..95bb37c5 100644 --- a/ui/easydiffusion/utils/save_utils.py +++ b/ui/easydiffusion/utils/save_utils.py @@ -7,7 +7,7 @@ from datetime import datetime from functools import reduce from easydiffusion import app -from easydiffusion.types import GenerateImageRequest, TaskData, OutputFormatData +from easydiffusion.types import GenerateImageRequest, TaskData, RenderTaskData, OutputFormatData, SaveToDiskData from numpy import base_repr from sdkit.utils import save_dicts, save_images from sdkit.models.model_loader.embeddings import get_embedding_token @@ -95,7 +95,7 @@ def format_folder_name(format: str, req: GenerateImageRequest, task_data: TaskDa def format_file_name( format: str, req: GenerateImageRequest, - task_data: TaskData, + task_data: RenderTaskData, now: float, batch_file_number: int, folder_img_number: ImageNumber, @@ -118,13 +118,18 @@ def format_file_name( def save_images_to_disk( - images: list, filtered_images: list, req: GenerateImageRequest, task_data: TaskData, output_format: OutputFormatData + images: list, + filtered_images: list, + req: GenerateImageRequest, + task_data: RenderTaskData, + output_format: OutputFormatData, + save_data: SaveToDiskData, ): now = time.time() app_config = app.getConfig() folder_format = app_config.get("folder_format", "$id") - save_dir_path = os.path.join(task_data.save_to_disk_path, format_folder_name(folder_format, req, task_data)) - metadata_entries = get_metadata_entries_for_request(req, task_data, output_format) + save_dir_path = os.path.join(save_data.save_to_disk_path, format_folder_name(folder_format, req, task_data)) + metadata_entries = get_metadata_entries_for_request(req, task_data, output_format, save_data) file_number = calculate_img_number(save_dir_path, task_data) make_filename = make_filename_callback( app_config.get("filename_format", "$p_$tsb64"), @@ -143,8 +148,8 @@ def save_images_to_disk( output_quality=output_format.output_quality, output_lossless=output_format.output_lossless, ) - if task_data.metadata_output_format: - for metadata_output_format in task_data.metadata_output_format.split(","): + if save_data.metadata_output_format: + for metadata_output_format in save_data.metadata_output_format.split(","): if metadata_output_format.lower() in ["json", "txt", "embed"]: save_dicts( metadata_entries, @@ -179,8 +184,8 @@ def save_images_to_disk( output_quality=output_format.output_quality, output_lossless=output_format.output_lossless, ) - if task_data.metadata_output_format: - for metadata_output_format in task_data.metadata_output_format.split(","): + if save_data.metadata_output_format: + for metadata_output_format in save_data.metadata_output_format.split(","): if metadata_output_format.lower() in ["json", "txt", "embed"]: save_dicts( metadata_entries, @@ -191,11 +196,13 @@ def save_images_to_disk( ) -def get_metadata_entries_for_request(req: GenerateImageRequest, task_data: TaskData, output_format: OutputFormatData): - metadata = get_printable_request(req, task_data, output_format) +def get_metadata_entries_for_request( + req: GenerateImageRequest, task_data: RenderTaskData, output_format: OutputFormatData, save_data: SaveToDiskData +): + metadata = get_printable_request(req, task_data, output_format, save_data) # if text, format it in the text format expected by the UI - is_txt_format = task_data.metadata_output_format and "txt" in task_data.metadata_output_format.lower().split(",") + is_txt_format = save_data.metadata_output_format and "txt" in save_data.metadata_output_format.lower().split(",") if is_txt_format: def format_value(value): @@ -214,10 +221,13 @@ def get_metadata_entries_for_request(req: GenerateImageRequest, task_data: TaskD return entries -def get_printable_request(req: GenerateImageRequest, task_data: TaskData, output_format: OutputFormatData): +def get_printable_request( + req: GenerateImageRequest, task_data: RenderTaskData, output_format: OutputFormatData, save_data: SaveToDiskData +): req_metadata = req.dict() task_data_metadata = task_data.dict() task_data_metadata.update(output_format.dict()) + task_data_metadata.update(save_data.dict()) app_config = app.getConfig() using_diffusers = app_config.get("test_diffusers", True) @@ -240,7 +250,9 @@ def get_printable_request(req: GenerateImageRequest, task_data: TaskData, output # Check if the filename has the right extension if not any(map(lambda ext: entry.name.endswith(ext), embeddings_extensions)): continue - embedding_name_regex = regex.compile(r"(^|[\s,])" + regex.escape(get_embedding_token(entry.name)) + r"([+-]*$|[\s,]|[+-]+[\s,])") + embedding_name_regex = regex.compile( + r"(^|[\s,])" + regex.escape(get_embedding_token(entry.name)) + r"([+-]*$|[\s,]|[+-]+[\s,])" + ) if embedding_name_regex.search(req.prompt) or embedding_name_regex.search(req.negative_prompt): used_embeddings.append(entry.path) elif entry.is_dir(): @@ -269,7 +281,17 @@ def get_printable_request(req: GenerateImageRequest, task_data: TaskData, output del metadata[key] else: for key in ( - x for x in ["use_lora_model", "lora_alpha", "clip_skip", "tiling", "latent_upscaler_steps", "use_controlnet_model", "control_filter_to_apply"] if x in metadata + x + for x in [ + "use_lora_model", + "lora_alpha", + "clip_skip", + "tiling", + "latent_upscaler_steps", + "use_controlnet_model", + "control_filter_to_apply", + ] + if x in metadata ): del metadata[key] @@ -279,7 +301,7 @@ def get_printable_request(req: GenerateImageRequest, task_data: TaskData, output def make_filename_callback( filename_format: str, req: GenerateImageRequest, - task_data: TaskData, + task_data: RenderTaskData, folder_img_number: int, suffix=None, now=None, @@ -296,7 +318,7 @@ def make_filename_callback( return make_filename -def _calculate_img_number(save_dir_path: str, task_data: TaskData): +def _calculate_img_number(save_dir_path: str, task_data: RenderTaskData): def get_highest_img_number(accumulator: int, file: os.DirEntry) -> int: if not file.is_file: return accumulator @@ -340,5 +362,5 @@ def _calculate_img_number(save_dir_path: str, task_data: TaskData): _calculate_img_number.session_img_numbers = {} -def calculate_img_number(save_dir_path: str, task_data: TaskData): +def calculate_img_number(save_dir_path: str, task_data: RenderTaskData): return ImageNumber(lambda: _calculate_img_number(save_dir_path, task_data)) diff --git a/ui/media/js/main.js b/ui/media/js/main.js index d26fe0db..ade330d8 100644 --- a/ui/media/js/main.js +++ b/ui/media/js/main.js @@ -852,6 +852,10 @@ function applyInlineFilter(filterName, path, filterParams, img, statusText, tool } filterReq.model_paths[filterName] = path + if (saveToDiskField.checked && diskPathField.value.trim() !== "") { + filterReq.save_to_disk_path = diskPathField.value.trim() + } + tools.spinnerStatus.innerText = statusText tools.spinner.classList.remove("displayNone")