import queue
import time
import json
import os
import base64
import re
import traceback
import logging

from sd_internal import device_manager
from sd_internal import TaskData, Response, Image as ResponseImage, UserInitiatedStop

from modules import model_loader, image_generator, image_utils, filters as image_filters, data_utils
from modules.types import Context, GenerateImageRequest

log = logging.getLogger()

context = Context() # thread-local
'''
runtime data (bound locally to this thread), e.g. the device, references to loaded models, optimization flags etc.
'''
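
# matches every character that isn't filename-safe; used with .sub('_', ...) to
# sanitize session ids and prompts before they become folder/file names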
filename_regex = re.compile('[^a-zA-Z0-9]')

def init(device):
    '''
    Initializes the fields that will be bound to this runtime's context, and sets the current torch device
'''
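    # per-thread state: a cancellation flag, streamed preview buffers, and the
    # latest latents (kept so a stopped task can still return partial images)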
context.stop_processing = False
context.temp_images = {}
context.partial_x_samples = None

    device_manager.device_init(context, device)

def make_images(req: GenerateImageRequest, task_data: TaskData, data_queue: queue.Queue, task_temp_images: list, step_callback):
log.info(f'request: {get_printable_request(req)}')
log.info(f'task data: {task_data.dict()}')
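
    # everything the client sees flows through data_queue as JSON strings:
    # step progress, the final response, or a failure notice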
    try:
images = _make_images_internal(req, task_data, data_queue, task_temp_images, step_callback)

        res = Response(req, task_data, images=construct_response(images, task_data, base_seed=req.seed))
res = res.json()
data_queue.put(json.dumps(res))
log.info('Task completed')

        return res
except Exception as e:
log.error(traceback.format_exc())

        # report the failure on the streaming queue before re-raising to the caller
        data_queue.put(json.dumps({
"status": 'failed',
"detail": str(e)
}))
        raise

def _make_images_internal(req: GenerateImageRequest, task_data: TaskData, data_queue: queue.Queue, task_temp_images: list, step_callback):
images, user_stopped = generate_images(req, task_data, data_queue, task_temp_images, step_callback, task_data.stream_image_progress)
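    # keep the raw generator output too; both the original and the filtered
    # versions may be saved and returned, depending on the task settings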
filtered_images = apply_filters(task_data, images, user_stopped)

    if task_data.save_to_disk_path is not None:
save_folder_path = os.path.join(task_data.save_to_disk_path, filename_regex.sub('_', task_data.session_id))
save_to_disk(images, filtered_images, save_folder_path, req, task_data)

    return filtered_images if task_data.show_only_filtered_image else images + filtered_images

def generate_images(req: GenerateImageRequest, task_data: TaskData, data_queue: queue.Queue, task_temp_images: list, step_callback, stream_image_progress: bool):
context.temp_images.clear()

    image_generator.on_image_step = make_step_callback(req, task_data, data_queue, task_temp_images, step_callback, stream_image_progress)

    try:
images = image_generator.make_images(context=context, req=req)
user_stopped = False
except UserInitiatedStop:
images = []
user_stopped = True
if context.partial_x_samples is not None:
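            # decode whatever latents completed before the stop request, so the
            # user still receives partial images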
images = image_utils.latent_samples_to_images(context, context.partial_x_samples)
context.partial_x_samples = None
finally:
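        # release cached intermediates whether generation finished, failed, or was stopped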
model_loader.gc(context)

    return images, user_stopped

def apply_filters(task_data: TaskData, images: list, user_stopped):
if user_stopped or (task_data.use_face_correction is None and task_data.use_upscale is None):
return images

    filters = []
    if task_data.use_face_correction and 'gfpgan' in task_data.use_face_correction.lower(): filters.append(image_filters.apply_gfpgan)
    if task_data.use_upscale and 'realesrgan' in task_data.use_upscale.lower(): filters.append(image_filters.apply_realesrgan)

    filtered_images = []
for img in images:
for filter_fn in filters:
img = filter_fn(context, img)

        filtered_images.append(img)

    return filtered_images

def save_to_disk(images: list, filtered_images: list, save_folder_path, req: GenerateImageRequest, task_data: TaskData):
metadata_entries = get_metadata_entries(req, task_data)

    # a no-op filter pass leaves both lists identical, so save only one copy
    if task_data.show_only_filtered_image or filtered_images == images:
data_utils.save_images(filtered_images, save_folder_path, file_name=make_filename_callback(req), output_format=task_data.output_format, output_quality=task_data.output_quality)
data_utils.save_metadata(metadata_entries, save_folder_path, file_name=make_filename_callback(req), output_format=task_data.metadata_output_format)
else:
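        # save both versions: the raw images, and filtered copies marked with a 'filtered' suffix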
data_utils.save_images(images, save_folder_path, file_name=make_filename_callback(req), output_format=task_data.output_format, output_quality=task_data.output_quality)
data_utils.save_images(filtered_images, save_folder_path, file_name=make_filename_callback(req, suffix='filtered'), output_format=task_data.output_format, output_quality=task_data.output_quality)
data_utils.save_metadata(metadata_entries, save_folder_path, file_name=make_filename_callback(req, suffix='filtered'), output_format=task_data.metadata_output_format)

def construct_response(images: list, task_data: TaskData, base_seed: int):
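    # one ResponseImage per output image; seeds are assigned sequentially from base_seed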
return [
ResponseImage(
data=image_utils.img_to_base64_str(img, task_data.output_format, task_data.output_quality),
seed=base_seed + i
) for i, img in enumerate(images)
]

def get_metadata_entries(req: GenerateImageRequest, task_data: TaskData):
metadata = get_printable_request(req)
metadata.update({
'use_stable_diffusion_model': task_data.use_stable_diffusion_model,
'use_vae_model': task_data.use_vae_model,
'use_hypernetwork_model': task_data.use_hypernetwork_model,
'use_face_correction': task_data.use_face_correction,
'use_upscale': task_data.use_upscale,
})

    # dict.update() returns None, so build a fresh dict per output image instead
    # of metadata.copy().update(...); each entry gets its own sequential seed
    return [{**metadata, 'seed': req.seed + i} for i in range(req.num_outputs)]

def get_printable_request(req: GenerateImageRequest):
metadata = req.dict()
del metadata['init_image']
del metadata['init_image_mask']
return metadata

def make_filename_callback(req: GenerateImageRequest, suffix=None):
def make_filename(i):
img_id = base64.b64encode(int(time.time()+i).to_bytes(8, 'big')).decode() # Generate unique ID based on time.
img_id = img_id.translate({43:None, 47:None, 61:None})[-8:] # Remove + / = and keep last 8 chars.

        prompt_flattened = filename_regex.sub('_', req.prompt)[:50]
name = f"{prompt_flattened}_{img_id}"
name = name if suffix is None else f'{name}_{suffix}'
return name

    return make_filename

def make_step_callback(req: GenerateImageRequest, task_data: TaskData, data_queue: queue.Queue, task_temp_images: list, step_callback, stream_image_progress: bool):
n_steps = req.num_inference_steps if req.init_image is None else int(req.num_inference_steps * req.prompt_strength)
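    # img2img starts from the init image, so only a prompt_strength fraction of the schedule runs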
last_callback_time = -1
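
    # decodes the current latents into JPEG previews, caches them in
    # context.temp_images (served at /image/tmp/<request_id>/<i>) and in
    # task_temp_images, and returns the preview URLs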
    def update_temp_img(x_samples, task_temp_images: list):
partial_images = []
for i in range(req.num_outputs):
img = image_utils.latent_to_img(context, x_samples[i].unsqueeze(0))
buf = image_utils.img_to_buffer(img, output_format='JPEG')

            del img

            context.temp_images[f"{task_data.request_id}/{i}"] = buf
task_temp_images[i] = buf
partial_images.append({'path': f"/image/tmp/{task_data.request_id}/{i}"})
return partial_images

    def on_image_step(x_samples, i):
nonlocal last_callback_time

        context.partial_x_samples = x_samples
step_time = time.time() - last_callback_time if last_callback_time != -1 else -1
last_callback_time = time.time()

        progress = {"step": i, "step_time": step_time, "total_steps": n_steps}

        # decoding latents to JPEG previews is expensive, so only stream one every 5 steps
        if stream_image_progress and i % 5 == 0:
progress['output'] = update_temp_img(x_samples, task_temp_images)

        data_queue.put(json.dumps(progress))

        step_callback()

        # raising here unwinds the sampler loop; generate_images() catches this
        # and converts any partial samples into images
        if context.stop_processing:
raise UserInitiatedStop("User requested that we stop processing")

    return on_image_step
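
# Minimal usage sketch (hypothetical field values; in the real app the task
# manager drives these calls on a worker thread):
#
#     init('cuda:0')  # once per thread, before any tasks
#     req = GenerateImageRequest(prompt='an astronaut riding a horse', seed=42)
#     task_data = TaskData(session_id='session-1')
#     buffers = [None] * req.num_outputs
#     make_images(req, task_data, queue.Queue(), buffers, step_callback=lambda: None)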