Refactor save_to_disk

This commit is contained in:
cmdr2 2022-12-14 16:30:19 +05:30
parent 12e0194c7f
commit 35ff4f439e
2 changed files with 83 additions and 54 deletions

View File

@ -1,13 +1,9 @@
import queue import queue
import time import time
import json import json
import os
import base64
import re
import traceback
import logging import logging
from sd_internal import device_manager from sd_internal import device_manager, save_utils
from sd_internal import TaskData, Response, Image as ResponseImage, UserInitiatedStop from sd_internal import TaskData, Response, Image as ResponseImage, UserInitiatedStop
from diffusionkit import model_loader, image_generator, image_utils, filters as image_filters, data_utils from diffusionkit import model_loader, image_generator, image_utils, filters as image_filters, data_utils
@ -20,8 +16,6 @@ context = Context() # thread-local
runtime data (bound locally to this thread), for e.g. device, references to loaded models, optimization flags etc runtime data (bound locally to this thread), for e.g. device, references to loaded models, optimization flags etc
''' '''
filename_regex = re.compile('[^a-zA-Z0-9]')
def init(device): def init(device):
''' '''
Initializes the fields that will be bound to this runtime's context, and sets the current torch device Initializes the fields that will be bound to this runtime's context, and sets the current torch device
@ -33,7 +27,7 @@ def init(device):
device_manager.device_init(context, device) device_manager.device_init(context, device)
def make_images(req: GenerateImageRequest, task_data: TaskData, data_queue: queue.Queue, task_temp_images: list, step_callback): def make_images(req: GenerateImageRequest, task_data: TaskData, data_queue: queue.Queue, task_temp_images: list, step_callback):
log.info(f'request: {get_printable_request(req)}') log.info(f'request: {save_utils.get_printable_request(req)}')
log.info(f'task data: {task_data.dict()}') log.info(f'task data: {task_data.dict()}')
images = _make_images_internal(req, task_data, data_queue, task_temp_images, step_callback) images = _make_images_internal(req, task_data, data_queue, task_temp_images, step_callback)
@ -50,8 +44,7 @@ def _make_images_internal(req: GenerateImageRequest, task_data: TaskData, data_q
filtered_images = apply_filters(task_data, images, user_stopped) filtered_images = apply_filters(task_data, images, user_stopped)
if task_data.save_to_disk_path is not None: if task_data.save_to_disk_path is not None:
save_folder_path = os.path.join(task_data.save_to_disk_path, filename_regex.sub('_', task_data.session_id)) save_utils.save_to_disk(images, filtered_images, req, task_data)
save_to_disk(images, filtered_images, save_folder_path, req, task_data)
return filtered_images if task_data.show_only_filtered_image else images + filtered_images return filtered_images if task_data.show_only_filtered_image else images + filtered_images
@ -92,17 +85,6 @@ def apply_filters(task_data: TaskData, images: list, user_stopped):
return filtered_images return filtered_images
def save_to_disk(images: list, filtered_images: list, save_folder_path, req: GenerateImageRequest, task_data: TaskData):
metadata_entries = get_metadata_entries(req, task_data)
if task_data.show_only_filtered_image or filtered_images == images:
data_utils.save_images(filtered_images, save_folder_path, file_name=make_filename_callback(req), output_format=task_data.output_format, output_quality=task_data.output_quality)
data_utils.save_metadata(metadata_entries, save_folder_path, file_name=make_filename_callback(req), output_format=task_data.metadata_output_format)
else:
data_utils.save_images(images, save_folder_path, file_name=make_filename_callback(req), output_format=task_data.output_format, output_quality=task_data.output_quality)
data_utils.save_images(filtered_images, save_folder_path, file_name=make_filename_callback(req, suffix='filtered'), output_format=task_data.output_format, output_quality=task_data.output_quality)
data_utils.save_metadata(metadata_entries, save_folder_path, file_name=make_filename_callback(req, suffix='filtered'), output_format=task_data.metadata_output_format)
def construct_response(images: list, task_data: TaskData, base_seed: int): def construct_response(images: list, task_data: TaskData, base_seed: int):
return [ return [
ResponseImage( ResponseImage(
@ -111,39 +93,6 @@ def construct_response(images: list, task_data: TaskData, base_seed: int):
) for i, img in enumerate(images) ) for i, img in enumerate(images)
] ]
def get_metadata_entries(req: GenerateImageRequest, task_data: TaskData):
metadata = get_printable_request(req)
metadata.update({
'use_stable_diffusion_model': task_data.use_stable_diffusion_model,
'use_vae_model': task_data.use_vae_model,
'use_hypernetwork_model': task_data.use_hypernetwork_model,
'use_face_correction': task_data.use_face_correction,
'use_upscale': task_data.use_upscale,
})
entries = [metadata.copy() for _ in range(req.num_outputs)]
for i, entry in enumerate(entries):
entry['seed'] = req.seed + i
return entries
def get_printable_request(req: GenerateImageRequest):
metadata = req.dict()
del metadata['init_image']
del metadata['init_image_mask']
return metadata
def make_filename_callback(req: GenerateImageRequest, suffix=None):
def make_filename(i):
img_id = base64.b64encode(int(time.time()+i).to_bytes(8, 'big')).decode() # Generate unique ID based on time.
img_id = img_id.translate({43:None, 47:None, 61:None})[-8:] # Remove + / = and keep last 8 chars.
prompt_flattened = filename_regex.sub('_', req.prompt)[:50]
name = f"{prompt_flattened}_{img_id}"
name = name if suffix is None else f'{name}_{suffix}'
return name
return make_filename
def make_step_callback(req: GenerateImageRequest, task_data: TaskData, data_queue: queue.Queue, task_temp_images: list, step_callback, stream_image_progress: bool): def make_step_callback(req: GenerateImageRequest, task_data: TaskData, data_queue: queue.Queue, task_temp_images: list, step_callback, stream_image_progress: bool):
n_steps = req.num_inference_steps if req.init_image is None else int(req.num_inference_steps * req.prompt_strength) n_steps = req.num_inference_steps if req.init_image is None else int(req.num_inference_steps * req.prompt_strength)
last_callback_time = -1 last_callback_time = -1

View File

@ -0,0 +1,80 @@
import os
import time
import base64
import re
from diffusionkit import data_utils
from diffusionkit.types import GenerateImageRequest
from sd_internal import TaskData
filename_regex = re.compile('[^a-zA-Z0-9]')
# keep in sync with `ui/media/js/dnd.js`
TASK_TEXT_MAPPING = {
'prompt': 'Prompt',
'width': 'Width',
'height': 'Height',
'seed': 'Seed',
'num_inference_steps': 'Steps',
'guidance_scale': 'Guidance Scale',
'prompt_strength': 'Prompt Strength',
'use_face_correction': 'Use Face Correction',
'use_upscale': 'Use Upscaling',
'sampler_name': 'Sampler',
'negative_prompt': 'Negative Prompt',
'use_stable_diffusion_model': 'Stable Diffusion model',
'use_hypernetwork_model': 'Hypernetwork model',
'hypernetwork_strength': 'Hypernetwork Strength'
}
def save_to_disk(images: list, filtered_images: list, req: GenerateImageRequest, task_data: TaskData):
save_folder_path = os.path.join(task_data.save_to_disk_path, filename_regex.sub('_', task_data.session_id))
metadata_entries = get_metadata_entries(req, task_data)
if task_data.show_only_filtered_image or filtered_images == images:
data_utils.save_images(filtered_images, save_folder_path, file_name=make_filename_callback(req), output_format=task_data.output_format, output_quality=task_data.output_quality)
data_utils.save_metadata(metadata_entries, save_folder_path, file_name=make_filename_callback(req), output_format=task_data.metadata_output_format)
else:
data_utils.save_images(images, save_folder_path, file_name=make_filename_callback(req), output_format=task_data.output_format, output_quality=task_data.output_quality)
data_utils.save_images(filtered_images, save_folder_path, file_name=make_filename_callback(req, suffix='filtered'), output_format=task_data.output_format, output_quality=task_data.output_quality)
data_utils.save_metadata(metadata_entries, save_folder_path, file_name=make_filename_callback(req, suffix='filtered'), output_format=task_data.metadata_output_format)
def get_metadata_entries(req: GenerateImageRequest, task_data: TaskData):
metadata = get_printable_request(req)
metadata.update({
'use_stable_diffusion_model': task_data.use_stable_diffusion_model,
'use_vae_model': task_data.use_vae_model,
'use_hypernetwork_model': task_data.use_hypernetwork_model,
'use_face_correction': task_data.use_face_correction,
'use_upscale': task_data.use_upscale,
})
# if text, format it in the text format expected by the UI
is_txt_format = (task_data.metadata_output_format.lower() == 'txt')
if is_txt_format:
metadata = {TASK_TEXT_MAPPING[key]: val for key, val in metadata.items() if key in TASK_TEXT_MAPPING}
entries = [metadata.copy() for _ in range(req.num_outputs)]
for i, entry in enumerate(entries):
entry['Seed' if is_txt_format else 'seed'] = req.seed + i
return entries
def get_printable_request(req: GenerateImageRequest):
metadata = req.dict()
del metadata['init_image']
del metadata['init_image_mask']
return metadata
def make_filename_callback(req: GenerateImageRequest, suffix=None):
def make_filename(i):
img_id = base64.b64encode(int(time.time()+i).to_bytes(8, 'big')).decode() # Generate unique ID based on time.
img_id = img_id.translate({43:None, 47:None, 61:None})[-8:] # Remove + / = and keep last 8 chars.
prompt_flattened = filename_regex.sub('_', req.prompt)[:50]
name = f"{prompt_flattened}_{img_id}"
name = name if suffix is None else f'{name}_{suffix}'
return name
return make_filename