diff --git a/ui/sd_internal/__init__.py b/ui/sd_internal/__init__.py index a06220ca..f0710652 100644 --- a/ui/sd_internal/__init__.py +++ b/ui/sd_internal/__init__.py @@ -7,14 +7,17 @@ class TaskData(BaseModel): session_id: str = "session" save_to_disk_path: str = None turbo: bool = True + use_face_correction: str = None # or "GFPGANv1.3" use_upscale: str = None # or "RealESRGAN_x4plus" or "RealESRGAN_x4plus_anime_6B" use_stable_diffusion_model: str = "sd-v1-4" use_vae_model: str = None use_hypernetwork_model: str = None + show_only_filtered_image: bool = False output_format: str = "jpeg" # or "png" output_quality: int = 75 + metadata_output_format: str = "txt" # or "json" stream_image_progress: bool = False class Image: diff --git a/ui/sd_internal/renderer.py b/ui/sd_internal/renderer.py index 645fdd00..440f2bf2 100644 --- a/ui/sd_internal/renderer.py +++ b/ui/sd_internal/renderer.py @@ -10,7 +10,7 @@ import logging from sd_internal import device_manager from sd_internal import TaskData, Response, Image as ResponseImage, UserInitiatedStop -from modules import model_loader, image_generator, image_utils, filters as image_filters +from modules import model_loader, image_generator, image_utils, filters as image_filters, data_utils from modules.types import Context, GenerateImageRequest log = logging.getLogger() @@ -33,8 +33,18 @@ def init(device): device_manager.device_init(context, device) def make_images(req: GenerateImageRequest, task_data: TaskData, data_queue: queue.Queue, task_temp_images: list, step_callback): + log.info(f'request: {req.dict()}') + log.info(f'task data: {task_data.dict()}') + try: - return _make_images_internal(req, task_data, data_queue, task_temp_images, step_callback) + images = _make_images_internal(req, task_data, data_queue, task_temp_images, step_callback) + + res = Response(req, task_data, images=construct_response(images, task_data, base_seed=req.seed)) + res = res.json() + data_queue.put(json.dumps(res)) + log.info('Task completed') + + return res except Exception as e: log.error(traceback.format_exc()) @@ -46,21 +56,15 @@ def make_images(req: GenerateImageRequest, task_data: TaskData, data_queue: queu def _make_images_internal(req: GenerateImageRequest, task_data: TaskData, data_queue: queue.Queue, task_temp_images: list, step_callback): images, user_stopped = generate_images(req, data_queue, task_temp_images, step_callback, task_data.stream_image_progress) - images = apply_filters(task_data, images, user_stopped, task_data.show_only_filtered_image) + filtered_images = apply_filters(task_data, images, user_stopped) if task_data.save_to_disk_path is not None: - out_path = os.path.join(task_data.save_to_disk_path, filename_regex.sub('_', task_data.session_id)) - save_images(images, out_path, metadata=req.to_metadata(), show_only_filtered_image=task_data.show_only_filtered_image) + save_folder_path = os.path.join(task_data.save_to_disk_path, filename_regex.sub('_', task_data.session_id)) + save_to_disk(images, filtered_images, save_folder_path, req, task_data) - res = Response(req, task_data, images=construct_response(images)) - res = res.json() - data_queue.put(json.dumps(res)) - log.info('Task completed') - - return res + return filtered_images if task_data.show_only_filtered_image else images + filtered_images def generate_images(req: GenerateImageRequest, data_queue: queue.Queue, task_temp_images: list, step_callback, stream_image_progress: bool): - log.info(req.to_metadata()) context.temp_images.clear() image_generator.on_image_step = make_step_callback(req, data_queue, task_temp_images, step_callback, stream_image_progress) @@ -77,11 +81,9 @@ def generate_images(req: GenerateImageRequest, data_queue: queue.Queue, task_tem finally: model_loader.gc(context) - images = [(image, req.seed + i, False) for i, image in enumerate(images)] - return images, user_stopped -def apply_filters(task_data: TaskData, images: list, user_stopped, show_only_filtered_image): +def apply_filters(task_data: TaskData, images: list, user_stopped): if user_stopped or (task_data.use_face_correction is None and task_data.use_upscale is None): return images @@ -90,52 +92,68 @@ def apply_filters(task_data: TaskData, images: list, user_stopped, show_only_fil if 'realesrgan' in task_data.use_face_correction.lower(): filters.append(image_filters.apply_realesrgan) filtered_images = [] - for img, seed, _ in images: + for img in images: for filter_fn in filters: img = filter_fn(context, img) - filtered_images.append((img, seed, True)) - - if not show_only_filtered_image: - filtered_images = images + filtered_images + filtered_images.append(img) return filtered_images -def save_images(images: list, save_to_disk_path, metadata: dict, show_only_filtered_image): - if save_to_disk_path is None: - return +def save_to_disk(images: list, filtered_images: list, save_folder_path, req: GenerateImageRequest, task_data: TaskData): + metadata = req.dict() + del metadata['init_image'] + del metadata['init_image_mask'] + metadata.update({ + 'use_stable_diffusion_model': task_data.use_stable_diffusion_model, + 'use_vae_model': task_data.use_vae_model, + 'use_hypernetwork_model': task_data.use_hypernetwork_model, + 'use_face_correction': task_data.use_face_correction, + 'use_upscale': task_data.use_upscale, + }) - def get_image_id(i): + metadata_entries = get_metadata_entries(req, task_data) + + if task_data.show_only_filtered_image or filtered_images == images: + data_utils.save_images(filtered_images, save_folder_path, file_name=get_output_filename_callback(req), output_format=task_data.output_format, output_quality=task_data.output_quality) + data_utils.save_metadata(metadata_entries, save_folder_path, file_name=get_output_filename_callback(req), output_format=task_data.metadata_output_format) + else: + data_utils.save_images(images, save_folder_path, file_name=get_output_filename_callback(req), output_format=task_data.output_format, output_quality=task_data.output_quality) + data_utils.save_images(filtered_images, save_folder_path, file_name=get_output_filename_callback(req, suffix='filtered'), output_format=task_data.output_format, output_quality=task_data.output_quality) + data_utils.save_metadata(metadata_entries, save_folder_path, file_name=get_output_filename_callback(req, suffix='filtered'), output_format=task_data.metadata_output_format) + +def get_metadata_entries(req: GenerateImageRequest, task_data: TaskData): + metadata = req.dict() + del metadata['init_image'] + del metadata['init_image_mask'] + metadata.update({ + 'use_stable_diffusion_model': task_data.use_stable_diffusion_model, + 'use_vae_model': task_data.use_vae_model, + 'use_hypernetwork_model': task_data.use_hypernetwork_model, + 'use_face_correction': task_data.use_face_correction, + 'use_upscale': task_data.use_upscale, + }) + + return [metadata.copy().update({'seed': req.seed + i}) for i in range(req.num_outputs)] + +def get_output_filename_callback(req: GenerateImageRequest, suffix=None): + def make_filename(i): img_id = base64.b64encode(int(time.time()+i).to_bytes(8, 'big')).decode() # Generate unique ID based on time. img_id = img_id.translate({43:None, 47:None, 61:None})[-8:] # Remove + / = and keep last 8 chars. - return img_id - def get_image_basepath(i): - os.makedirs(save_to_disk_path, exist_ok=True) - prompt_flattened = filename_regex.sub('_', metadata['prompt'])[:50] - return os.path.join(save_to_disk_path, f"{prompt_flattened}_{get_image_id(i)}") + prompt_flattened = filename_regex.sub('_', req.prompt)[:50] + name = f"{prompt_flattened}_{img_id}" + name = name if suffix is None else f'{name}_{suffix}' + return name - for i, img_data in enumerate(images): - img, seed, filtered = img_data - img_path = get_image_basepath(i) + return make_filename - if not filtered or show_only_filtered_image: - img_metadata_path = img_path + '.txt' - m = metadata.copy() - m['seed'] = seed - with open(img_metadata_path, 'w', encoding='utf-8') as f: - f.write(m) - - img_path += '_filtered' if filtered else '' - img_path += '.' + metadata['output_format'] - img.save(img_path, quality=metadata['output_quality']) - -def construct_response(task_data: TaskData, images: list): +def construct_response(images: list, task_data: TaskData, base_seed: int): return [ ResponseImage( data=image_utils.img_to_base64_str(img, task_data.output_format, task_data.output_quality), - seed=seed - ) for img, seed, _ in images + seed=base_seed + i + ) for i, img in enumerate(images) ] def make_step_callback(req: GenerateImageRequest, task_data: TaskData, data_queue: queue.Queue, task_temp_images: list, step_callback, stream_image_progress: bool):