forked from extern/easydiffusion
Refactor save_to_disk
This commit is contained in:
parent
12e0194c7f
commit
35ff4f439e
@ -1,13 +1,9 @@
|
|||||||
import queue
|
import queue
|
||||||
import time
|
import time
|
||||||
import json
|
import json
|
||||||
import os
|
|
||||||
import base64
|
|
||||||
import re
|
|
||||||
import traceback
|
|
||||||
import logging
|
import logging
|
||||||
|
|
||||||
from sd_internal import device_manager
|
from sd_internal import device_manager, save_utils
|
||||||
from sd_internal import TaskData, Response, Image as ResponseImage, UserInitiatedStop
|
from sd_internal import TaskData, Response, Image as ResponseImage, UserInitiatedStop
|
||||||
|
|
||||||
from diffusionkit import model_loader, image_generator, image_utils, filters as image_filters, data_utils
|
from diffusionkit import model_loader, image_generator, image_utils, filters as image_filters, data_utils
|
||||||
@ -20,8 +16,6 @@ context = Context() # thread-local
|
|||||||
runtime data (bound locally to this thread), for e.g. device, references to loaded models, optimization flags etc
|
runtime data (bound locally to this thread), for e.g. device, references to loaded models, optimization flags etc
|
||||||
'''
|
'''
|
||||||
|
|
||||||
filename_regex = re.compile('[^a-zA-Z0-9]')
|
|
||||||
|
|
||||||
def init(device):
|
def init(device):
|
||||||
'''
|
'''
|
||||||
Initializes the fields that will be bound to this runtime's context, and sets the current torch device
|
Initializes the fields that will be bound to this runtime's context, and sets the current torch device
|
||||||
@ -33,7 +27,7 @@ def init(device):
|
|||||||
device_manager.device_init(context, device)
|
device_manager.device_init(context, device)
|
||||||
|
|
||||||
def make_images(req: GenerateImageRequest, task_data: TaskData, data_queue: queue.Queue, task_temp_images: list, step_callback):
|
def make_images(req: GenerateImageRequest, task_data: TaskData, data_queue: queue.Queue, task_temp_images: list, step_callback):
|
||||||
log.info(f'request: {get_printable_request(req)}')
|
log.info(f'request: {save_utils.get_printable_request(req)}')
|
||||||
log.info(f'task data: {task_data.dict()}')
|
log.info(f'task data: {task_data.dict()}')
|
||||||
|
|
||||||
images = _make_images_internal(req, task_data, data_queue, task_temp_images, step_callback)
|
images = _make_images_internal(req, task_data, data_queue, task_temp_images, step_callback)
|
||||||
@ -50,8 +44,7 @@ def _make_images_internal(req: GenerateImageRequest, task_data: TaskData, data_q
|
|||||||
filtered_images = apply_filters(task_data, images, user_stopped)
|
filtered_images = apply_filters(task_data, images, user_stopped)
|
||||||
|
|
||||||
if task_data.save_to_disk_path is not None:
|
if task_data.save_to_disk_path is not None:
|
||||||
save_folder_path = os.path.join(task_data.save_to_disk_path, filename_regex.sub('_', task_data.session_id))
|
save_utils.save_to_disk(images, filtered_images, req, task_data)
|
||||||
save_to_disk(images, filtered_images, save_folder_path, req, task_data)
|
|
||||||
|
|
||||||
return filtered_images if task_data.show_only_filtered_image else images + filtered_images
|
return filtered_images if task_data.show_only_filtered_image else images + filtered_images
|
||||||
|
|
||||||
@ -92,17 +85,6 @@ def apply_filters(task_data: TaskData, images: list, user_stopped):
|
|||||||
|
|
||||||
return filtered_images
|
return filtered_images
|
||||||
|
|
||||||
def save_to_disk(images: list, filtered_images: list, save_folder_path, req: GenerateImageRequest, task_data: TaskData):
|
|
||||||
metadata_entries = get_metadata_entries(req, task_data)
|
|
||||||
|
|
||||||
if task_data.show_only_filtered_image or filtered_images == images:
|
|
||||||
data_utils.save_images(filtered_images, save_folder_path, file_name=make_filename_callback(req), output_format=task_data.output_format, output_quality=task_data.output_quality)
|
|
||||||
data_utils.save_metadata(metadata_entries, save_folder_path, file_name=make_filename_callback(req), output_format=task_data.metadata_output_format)
|
|
||||||
else:
|
|
||||||
data_utils.save_images(images, save_folder_path, file_name=make_filename_callback(req), output_format=task_data.output_format, output_quality=task_data.output_quality)
|
|
||||||
data_utils.save_images(filtered_images, save_folder_path, file_name=make_filename_callback(req, suffix='filtered'), output_format=task_data.output_format, output_quality=task_data.output_quality)
|
|
||||||
data_utils.save_metadata(metadata_entries, save_folder_path, file_name=make_filename_callback(req, suffix='filtered'), output_format=task_data.metadata_output_format)
|
|
||||||
|
|
||||||
def construct_response(images: list, task_data: TaskData, base_seed: int):
|
def construct_response(images: list, task_data: TaskData, base_seed: int):
|
||||||
return [
|
return [
|
||||||
ResponseImage(
|
ResponseImage(
|
||||||
@ -111,39 +93,6 @@ def construct_response(images: list, task_data: TaskData, base_seed: int):
|
|||||||
) for i, img in enumerate(images)
|
) for i, img in enumerate(images)
|
||||||
]
|
]
|
||||||
|
|
||||||
def get_metadata_entries(req: GenerateImageRequest, task_data: TaskData):
|
|
||||||
metadata = get_printable_request(req)
|
|
||||||
metadata.update({
|
|
||||||
'use_stable_diffusion_model': task_data.use_stable_diffusion_model,
|
|
||||||
'use_vae_model': task_data.use_vae_model,
|
|
||||||
'use_hypernetwork_model': task_data.use_hypernetwork_model,
|
|
||||||
'use_face_correction': task_data.use_face_correction,
|
|
||||||
'use_upscale': task_data.use_upscale,
|
|
||||||
})
|
|
||||||
|
|
||||||
entries = [metadata.copy() for _ in range(req.num_outputs)]
|
|
||||||
for i, entry in enumerate(entries):
|
|
||||||
entry['seed'] = req.seed + i
|
|
||||||
return entries
|
|
||||||
|
|
||||||
def get_printable_request(req: GenerateImageRequest):
|
|
||||||
metadata = req.dict()
|
|
||||||
del metadata['init_image']
|
|
||||||
del metadata['init_image_mask']
|
|
||||||
return metadata
|
|
||||||
|
|
||||||
def make_filename_callback(req: GenerateImageRequest, suffix=None):
|
|
||||||
def make_filename(i):
|
|
||||||
img_id = base64.b64encode(int(time.time()+i).to_bytes(8, 'big')).decode() # Generate unique ID based on time.
|
|
||||||
img_id = img_id.translate({43:None, 47:None, 61:None})[-8:] # Remove + / = and keep last 8 chars.
|
|
||||||
|
|
||||||
prompt_flattened = filename_regex.sub('_', req.prompt)[:50]
|
|
||||||
name = f"{prompt_flattened}_{img_id}"
|
|
||||||
name = name if suffix is None else f'{name}_{suffix}'
|
|
||||||
return name
|
|
||||||
|
|
||||||
return make_filename
|
|
||||||
|
|
||||||
def make_step_callback(req: GenerateImageRequest, task_data: TaskData, data_queue: queue.Queue, task_temp_images: list, step_callback, stream_image_progress: bool):
|
def make_step_callback(req: GenerateImageRequest, task_data: TaskData, data_queue: queue.Queue, task_temp_images: list, step_callback, stream_image_progress: bool):
|
||||||
n_steps = req.num_inference_steps if req.init_image is None else int(req.num_inference_steps * req.prompt_strength)
|
n_steps = req.num_inference_steps if req.init_image is None else int(req.num_inference_steps * req.prompt_strength)
|
||||||
last_callback_time = -1
|
last_callback_time = -1
|
||||||
|
80
ui/sd_internal/save_utils.py
Normal file
80
ui/sd_internal/save_utils.py
Normal file
@ -0,0 +1,80 @@
|
|||||||
|
import os
|
||||||
|
import time
|
||||||
|
import base64
|
||||||
|
import re
|
||||||
|
|
||||||
|
from diffusionkit import data_utils
|
||||||
|
from diffusionkit.types import GenerateImageRequest
|
||||||
|
|
||||||
|
from sd_internal import TaskData
|
||||||
|
|
||||||
|
filename_regex = re.compile('[^a-zA-Z0-9]')
|
||||||
|
|
||||||
|
# keep in sync with `ui/media/js/dnd.js`
|
||||||
|
TASK_TEXT_MAPPING = {
|
||||||
|
'prompt': 'Prompt',
|
||||||
|
'width': 'Width',
|
||||||
|
'height': 'Height',
|
||||||
|
'seed': 'Seed',
|
||||||
|
'num_inference_steps': 'Steps',
|
||||||
|
'guidance_scale': 'Guidance Scale',
|
||||||
|
'prompt_strength': 'Prompt Strength',
|
||||||
|
'use_face_correction': 'Use Face Correction',
|
||||||
|
'use_upscale': 'Use Upscaling',
|
||||||
|
'sampler_name': 'Sampler',
|
||||||
|
'negative_prompt': 'Negative Prompt',
|
||||||
|
'use_stable_diffusion_model': 'Stable Diffusion model',
|
||||||
|
'use_hypernetwork_model': 'Hypernetwork model',
|
||||||
|
'hypernetwork_strength': 'Hypernetwork Strength'
|
||||||
|
}
|
||||||
|
|
||||||
|
def save_to_disk(images: list, filtered_images: list, req: GenerateImageRequest, task_data: TaskData):
|
||||||
|
save_folder_path = os.path.join(task_data.save_to_disk_path, filename_regex.sub('_', task_data.session_id))
|
||||||
|
metadata_entries = get_metadata_entries(req, task_data)
|
||||||
|
|
||||||
|
if task_data.show_only_filtered_image or filtered_images == images:
|
||||||
|
data_utils.save_images(filtered_images, save_folder_path, file_name=make_filename_callback(req), output_format=task_data.output_format, output_quality=task_data.output_quality)
|
||||||
|
data_utils.save_metadata(metadata_entries, save_folder_path, file_name=make_filename_callback(req), output_format=task_data.metadata_output_format)
|
||||||
|
else:
|
||||||
|
data_utils.save_images(images, save_folder_path, file_name=make_filename_callback(req), output_format=task_data.output_format, output_quality=task_data.output_quality)
|
||||||
|
data_utils.save_images(filtered_images, save_folder_path, file_name=make_filename_callback(req, suffix='filtered'), output_format=task_data.output_format, output_quality=task_data.output_quality)
|
||||||
|
data_utils.save_metadata(metadata_entries, save_folder_path, file_name=make_filename_callback(req, suffix='filtered'), output_format=task_data.metadata_output_format)
|
||||||
|
|
||||||
|
def get_metadata_entries(req: GenerateImageRequest, task_data: TaskData):
|
||||||
|
metadata = get_printable_request(req)
|
||||||
|
metadata.update({
|
||||||
|
'use_stable_diffusion_model': task_data.use_stable_diffusion_model,
|
||||||
|
'use_vae_model': task_data.use_vae_model,
|
||||||
|
'use_hypernetwork_model': task_data.use_hypernetwork_model,
|
||||||
|
'use_face_correction': task_data.use_face_correction,
|
||||||
|
'use_upscale': task_data.use_upscale,
|
||||||
|
})
|
||||||
|
|
||||||
|
# if text, format it in the text format expected by the UI
|
||||||
|
is_txt_format = (task_data.metadata_output_format.lower() == 'txt')
|
||||||
|
if is_txt_format:
|
||||||
|
metadata = {TASK_TEXT_MAPPING[key]: val for key, val in metadata.items() if key in TASK_TEXT_MAPPING}
|
||||||
|
|
||||||
|
entries = [metadata.copy() for _ in range(req.num_outputs)]
|
||||||
|
for i, entry in enumerate(entries):
|
||||||
|
entry['Seed' if is_txt_format else 'seed'] = req.seed + i
|
||||||
|
|
||||||
|
return entries
|
||||||
|
|
||||||
|
def get_printable_request(req: GenerateImageRequest):
|
||||||
|
metadata = req.dict()
|
||||||
|
del metadata['init_image']
|
||||||
|
del metadata['init_image_mask']
|
||||||
|
return metadata
|
||||||
|
|
||||||
|
def make_filename_callback(req: GenerateImageRequest, suffix=None):
|
||||||
|
def make_filename(i):
|
||||||
|
img_id = base64.b64encode(int(time.time()+i).to_bytes(8, 'big')).decode() # Generate unique ID based on time.
|
||||||
|
img_id = img_id.translate({43:None, 47:None, 61:None})[-8:] # Remove + / = and keep last 8 chars.
|
||||||
|
|
||||||
|
prompt_flattened = filename_regex.sub('_', req.prompt)[:50]
|
||||||
|
name = f"{prompt_flattened}_{img_id}"
|
||||||
|
name = name if suffix is None else f'{name}_{suffix}'
|
||||||
|
return name
|
||||||
|
|
||||||
|
return make_filename
|
Loading…
Reference in New Issue
Block a user