forked from extern/easydiffusion
Refactor the save-to-disk code, moving parts of it to diffusionkit
This commit is contained in: parent e45cbbf1ca, commit b57649828d
@@ -7,14 +7,17 @@ class TaskData(BaseModel):
     session_id: str = "session"
     save_to_disk_path: str = None
     turbo: bool = True
+
     use_face_correction: str = None # or "GFPGANv1.3"
     use_upscale: str = None # or "RealESRGAN_x4plus" or "RealESRGAN_x4plus_anime_6B"
     use_stable_diffusion_model: str = "sd-v1-4"
     use_vae_model: str = None
     use_hypernetwork_model: str = None
+
     show_only_filtered_image: bool = False
     output_format: str = "jpeg" # or "png"
     output_quality: int = 75
+    metadata_output_format: str = "txt" # or "json"
     stream_image_progress: bool = False

 class Image:
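
The new metadata_output_format field controls the format of the per-image metadata files written alongside the saved images (it is passed to the diffusionkit save helpers further down in this diff). A minimal, purely illustrative sketch of constructing a TaskData with it — the field values here are invented for the example:

    from sd_internal import TaskData

    task_data = TaskData(
        session_id='session-42',            # hypothetical session id
        save_to_disk_path='/tmp/outputs',   # hypothetical output folder
        output_format='png',
        metadata_output_format='json',      # new field: 'txt' (default) or 'json'
    )
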
@@ -10,7 +10,7 @@ import logging
 from sd_internal import device_manager
 from sd_internal import TaskData, Response, Image as ResponseImage, UserInitiatedStop

-from modules import model_loader, image_generator, image_utils, filters as image_filters
+from modules import model_loader, image_generator, image_utils, filters as image_filters, data_utils
 from modules.types import Context, GenerateImageRequest

 log = logging.getLogger()
@@ -33,8 +33,18 @@ def init(device):
     device_manager.device_init(context, device)

 def make_images(req: GenerateImageRequest, task_data: TaskData, data_queue: queue.Queue, task_temp_images: list, step_callback):
+    log.info(f'request: {req.dict()}')
+    log.info(f'task data: {task_data.dict()}')
+
     try:
-        return _make_images_internal(req, task_data, data_queue, task_temp_images, step_callback)
+        images = _make_images_internal(req, task_data, data_queue, task_temp_images, step_callback)
+
+        res = Response(req, task_data, images=construct_response(images, task_data, base_seed=req.seed))
+        res = res.json()
+        data_queue.put(json.dumps(res))
+        log.info('Task completed')
+
+        return res
     except Exception as e:
         log.error(traceback.format_exc())

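
With this change make_images itself serializes the final Response and pushes it onto data_queue. For reference, a minimal sketch of how a consumer of that queue could read the payload back (hypothetical; the task-manager side is not part of this diff):

    import json

    raw = data_queue.get()   # JSON string pushed by make_images once the task finishes
    res = json.loads(raw)    # serialized Response: request, task data and base64-encoded images
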
@@ -46,21 +56,15 @@ def make_images(req: GenerateImageRequest, task_data: TaskData, data_queue: queue.Queue, task_temp_images: list, step_callback):

 def _make_images_internal(req: GenerateImageRequest, task_data: TaskData, data_queue: queue.Queue, task_temp_images: list, step_callback):
     images, user_stopped = generate_images(req, data_queue, task_temp_images, step_callback, task_data.stream_image_progress)
-    images = apply_filters(task_data, images, user_stopped, task_data.show_only_filtered_image)
+    filtered_images = apply_filters(task_data, images, user_stopped)

     if task_data.save_to_disk_path is not None:
-        out_path = os.path.join(task_data.save_to_disk_path, filename_regex.sub('_', task_data.session_id))
-        save_images(images, out_path, metadata=req.to_metadata(), show_only_filtered_image=task_data.show_only_filtered_image)
+        save_folder_path = os.path.join(task_data.save_to_disk_path, filename_regex.sub('_', task_data.session_id))
+        save_to_disk(images, filtered_images, save_folder_path, req, task_data)

-    res = Response(req, task_data, images=construct_response(images))
-    res = res.json()
-    data_queue.put(json.dumps(res))
-    log.info('Task completed')
-
-    return res
+    return filtered_images if task_data.show_only_filtered_image else images + filtered_images

 def generate_images(req: GenerateImageRequest, data_queue: queue.Queue, task_temp_images: list, step_callback, stream_image_progress: bool):
-    log.info(req.to_metadata())
     context.temp_images.clear()

     image_generator.on_image_step = make_step_callback(req, data_queue, task_temp_images, step_callback, stream_image_progress)
@@ -77,11 +81,9 @@ def generate_images(req: GenerateImageRequest, data_queue: queue.Queue, task_temp_images: list, step_callback, stream_image_progress: bool):
     finally:
         model_loader.gc(context)

-    images = [(image, req.seed + i, False) for i, image in enumerate(images)]
-
     return images, user_stopped

-def apply_filters(task_data: TaskData, images: list, user_stopped, show_only_filtered_image):
+def apply_filters(task_data: TaskData, images: list, user_stopped):
     if user_stopped or (task_data.use_face_correction is None and task_data.use_upscale is None):
         return images

@@ -90,52 +92,68 @@ def apply_filters(task_data: TaskData, images: list, user_stopped, show_only_filtered_image):
     if 'realesrgan' in task_data.use_face_correction.lower(): filters.append(image_filters.apply_realesrgan)

     filtered_images = []
-    for img, seed, _ in images:
+    for img in images:
         for filter_fn in filters:
             img = filter_fn(context, img)

-        filtered_images.append((img, seed, True))
+        filtered_images.append(img)

-    if not show_only_filtered_image:
-        filtered_images = images + filtered_images
-
     return filtered_images

-def save_images(images: list, save_to_disk_path, metadata: dict, show_only_filtered_image):
-    if save_to_disk_path is None:
-        return
+def save_to_disk(images: list, filtered_images: list, save_folder_path, req: GenerateImageRequest, task_data: TaskData):
+    metadata = req.dict()
+    del metadata['init_image']
+    del metadata['init_image_mask']
+    metadata.update({
+        'use_stable_diffusion_model': task_data.use_stable_diffusion_model,
+        'use_vae_model': task_data.use_vae_model,
+        'use_hypernetwork_model': task_data.use_hypernetwork_model,
+        'use_face_correction': task_data.use_face_correction,
+        'use_upscale': task_data.use_upscale,
+    })

-    def get_image_id(i):
+    metadata_entries = get_metadata_entries(req, task_data)
+
+    if task_data.show_only_filtered_image or filtered_images == images:
+        data_utils.save_images(filtered_images, save_folder_path, file_name=get_output_filename_callback(req), output_format=task_data.output_format, output_quality=task_data.output_quality)
+        data_utils.save_metadata(metadata_entries, save_folder_path, file_name=get_output_filename_callback(req), output_format=task_data.metadata_output_format)
+    else:
+        data_utils.save_images(images, save_folder_path, file_name=get_output_filename_callback(req), output_format=task_data.output_format, output_quality=task_data.output_quality)
+        data_utils.save_images(filtered_images, save_folder_path, file_name=get_output_filename_callback(req, suffix='filtered'), output_format=task_data.output_format, output_quality=task_data.output_quality)
+        data_utils.save_metadata(metadata_entries, save_folder_path, file_name=get_output_filename_callback(req, suffix='filtered'), output_format=task_data.metadata_output_format)
+
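
The actual file-writing logic now lives in diffusionkit's data_utils module, which this diff only calls into. A rough sketch of what those helpers presumably look like, inferred from the call sites above (hypothetical; the real implementations are in diffusionkit and are not part of this commit):

    import os, json

    def save_images(images, save_folder_path, file_name, output_format='jpeg', output_quality=75):
        os.makedirs(save_folder_path, exist_ok=True)
        for i, img in enumerate(images):
            # file_name is a callback mapping the output index to a base filename
            path = os.path.join(save_folder_path, f'{file_name(i)}.{output_format}')
            img.save(path, quality=output_quality)  # PIL-style save

    def save_metadata(metadata_entries, save_folder_path, file_name, output_format='txt'):
        os.makedirs(save_folder_path, exist_ok=True)
        for i, metadata in enumerate(metadata_entries):
            path = os.path.join(save_folder_path, f'{file_name(i)}.{output_format}')
            with open(path, 'w', encoding='utf-8') as f:
                if output_format == 'json':
                    json.dump(metadata, f, indent=2)
                else:
                    f.write('\n'.join(f'{k}: {v}' for k, v in metadata.items()))
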
+def get_metadata_entries(req: GenerateImageRequest, task_data: TaskData):
+    metadata = req.dict()
+    del metadata['init_image']
+    del metadata['init_image_mask']
+    metadata.update({
+        'use_stable_diffusion_model': task_data.use_stable_diffusion_model,
+        'use_vae_model': task_data.use_vae_model,
+        'use_hypernetwork_model': task_data.use_hypernetwork_model,
+        'use_face_correction': task_data.use_face_correction,
+        'use_upscale': task_data.use_upscale,
+    })
+
+    return [metadata.copy().update({'seed': req.seed + i}) for i in range(req.num_outputs)]
+
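
Note for readers of this hunk: dict.update() returns None in Python, so the comprehension above yields a list of None values rather than per-image metadata dicts. A small sketch of the apparent intent, one metadata dict per output image with its own seed (hypothetical helper, not part of this commit):

    def build_metadata_entries(metadata: dict, base_seed: int, num_outputs: int):
        # Hypothetical variant: copy the shared metadata once per output image
        # and attach that image's seed, since dict.update() cannot be used inline.
        entries = []
        for i in range(num_outputs):
            entry = metadata.copy()
            entry['seed'] = base_seed + i   # each output gets its own seed
            entries.append(entry)
        return entries

get_metadata_entries would then end with return build_metadata_entries(metadata, req.seed, req.num_outputs) instead of the comprehension.
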
+def get_output_filename_callback(req: GenerateImageRequest, suffix=None):
+    def make_filename(i):
         img_id = base64.b64encode(int(time.time()+i).to_bytes(8, 'big')).decode() # Generate unique ID based on time.
         img_id = img_id.translate({43:None, 47:None, 61:None})[-8:] # Remove + / = and keep last 8 chars.
-        return img_id

-    def get_image_basepath(i):
-        os.makedirs(save_to_disk_path, exist_ok=True)
-        prompt_flattened = filename_regex.sub('_', metadata['prompt'])[:50]
-        return os.path.join(save_to_disk_path, f"{prompt_flattened}_{get_image_id(i)}")
+        prompt_flattened = filename_regex.sub('_', req.prompt)[:50]
+        name = f"{prompt_flattened}_{img_id}"
+        name = name if suffix is None else f'{name}_{suffix}'
+        return name

-    for i, img_data in enumerate(images):
-        img, seed, filtered = img_data
-        img_path = get_image_basepath(i)
+    return make_filename

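
For illustration, the callback returned by get_output_filename_callback maps an output index to a base filename; the extension is added by the diffusionkit save helpers. With a hypothetical prompt the result would look roughly like this (values invented for the example):

    make_filename = get_output_filename_callback(req, suffix='filtered')
    make_filename(0)   # e.g. 'a_photo_of_an_astronaut_riding_a_horse_Ab3dE9xQ_filtered'
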
-        if not filtered or show_only_filtered_image:
-            img_metadata_path = img_path + '.txt'
-            m = metadata.copy()
-            m['seed'] = seed
-            with open(img_metadata_path, 'w', encoding='utf-8') as f:
-                f.write(m)
-
-        img_path += '_filtered' if filtered else ''
-        img_path += '.' + metadata['output_format']
-        img.save(img_path, quality=metadata['output_quality'])
-
-def construct_response(task_data: TaskData, images: list):
+def construct_response(images: list, task_data: TaskData, base_seed: int):
     return [
         ResponseImage(
             data=image_utils.img_to_base64_str(img, task_data.output_format, task_data.output_quality),
-            seed=seed
-        ) for img, seed, _ in images
+            seed=base_seed + i
+        ) for i, img in enumerate(images)
     ]

 def make_step_callback(req: GenerateImageRequest, task_data: TaskData, data_queue: queue.Queue, task_temp_images: list, step_callback, stream_image_progress: bool):