Option to apply color correction (balances the histogram) during inpainting; Refactor the runtime to use a general-purpose dict

This commit is contained in:
cmdr2 2022-12-09 19:39:56 +05:30
parent f1de0be679
commit 79cc84b611
6 changed files with 66 additions and 35 deletions

View File

@ -213,6 +213,7 @@
<div><ul>
<li><b class="settings-subheader">Render Settings</b></li>
<li class="pl-5"><input id="stream_image_progress" name="stream_image_progress" type="checkbox"> <label for="stream_image_progress">Show a live preview <small>(uses more VRAM, slower images)</small></label></li>
<li id="apply_color_correction_setting" class="pl-5"><input id="apply_color_correction" name="apply_color_correction" type="checkbox"> <label for="apply_color_correction">Apply color correction <small>(helps during inpainting)</small></label></li>
<li class="pl-5"><input id="use_face_correction" name="use_face_correction" type="checkbox"> <label for="use_face_correction">Fix incorrect faces and eyes <small>(uses GFPGAN)</small></label></li>
<li class="pl-5">
<input id="use_upscale" name="use_upscale" type="checkbox"> <label for="use_upscale">Upscale image by 4x with </label>

View File

@ -39,7 +39,8 @@ const SETTINGS_IDS_LIST = [
"turbo",
"use_full_precision",
"confirm_dangerous_actions",
"auto_save_settings"
"auto_save_settings",
"apply_color_correction"
]
const IGNORE_BY_DEFAULT = [

View File

@ -26,6 +26,8 @@ let initImagePreview = document.querySelector("#init_image_preview")
let initImageSizeBox = document.querySelector("#init_image_size_box")
let maskImageSelector = document.querySelector("#mask")
let maskImagePreview = document.querySelector("#mask_preview")
let applyColorCorrectionField = document.querySelector('#apply_color_correction')
let colorCorrectionSetting = document.querySelector('#apply_color_correction_setting')
let promptStrengthSlider = document.querySelector('#prompt_strength_slider')
let promptStrengthField = document.querySelector('#prompt_strength')
let samplerField = document.querySelector('#sampler')
@ -759,6 +761,9 @@ function createTask(task) {
taskConfig += `, <b>Hypernetwork:</b> ${task.reqBody.use_hypernetwork_model}`
taskConfig += `, <b>Hypernetwork Strength:</b> ${task.reqBody.hypernetwork_strength}`
}
if (task.reqBody.apply_color_correction) {
taskConfig += `, <b>Color Correction:</b> true`
}
let taskEntry = document.createElement('div')
taskEntry.id = `imageTaskContainer-${Date.now()}`
@ -867,6 +872,7 @@ function getCurrentUserRequest() {
if (maskSetting.checked) {
newTask.reqBody.mask = imageInpainter.getImg()
}
newTask.reqBody.apply_color_correction = applyColorCorrectionField.checked
newTask.reqBody.sampler = 'ddim'
} else {
newTask.reqBody.sampler = samplerField.value
@ -1257,6 +1263,7 @@ function img2imgLoad() {
promptStrengthContainer.style.display = 'table-row'
samplerSelectionContainer.style.display = "none"
initImagePreviewContainer.classList.add("has-image")
colorCorrectionSetting.style.display = ''
initImageSizeBox.textContent = initImagePreview.naturalWidth + " x " + initImagePreview.naturalHeight
imageEditor.setImage(this.src, initImagePreview.naturalWidth, initImagePreview.naturalHeight)
@ -1271,6 +1278,7 @@ function img2imgUnload() {
promptStrengthContainer.style.display = "none"
samplerSelectionContainer.style.display = ""
initImagePreviewContainer.classList.remove("has-image")
colorCorrectionSetting.style.display = 'none'
imageEditor.setImage(null, parseInt(widthField.value), parseInt(heightField.value))
}

View File

@ -7,6 +7,7 @@ class Request:
negative_prompt: str = ""
init_image: str = None # base64
mask: str = None # base64
apply_color_correction = False
num_outputs: int = 1
num_inference_steps: int = 50
guidance_scale: float = 7.5
@ -35,6 +36,7 @@ class Request:
def json(self):
return {
"request_id": self.request_id,
"session_id": self.session_id,
"prompt": self.prompt,
"negative_prompt": self.negative_prompt,
@ -46,6 +48,7 @@ class Request:
"seed": self.seed,
"prompt_strength": self.prompt_strength,
"sampler": self.sampler,
"apply_color_correction": self.apply_color_correction,
"use_face_correction": self.use_face_correction,
"use_upscale": self.use_upscale,
"use_stable_diffusion_model": self.use_stable_diffusion_model,
@ -71,6 +74,7 @@ class Request:
save_to_disk_path: {self.save_to_disk_path}
turbo: {self.turbo}
use_full_precision: {self.use_full_precision}
apply_color_correction: {self.apply_color_correction}
use_face_correction: {self.use_face_correction}
use_upscale: {self.use_upscale}
use_stable_diffusion_model: {self.use_stable_diffusion_model}

View File

@ -78,10 +78,15 @@ def make_images(req: Request, data_queue: queue.Queue, task_temp_images: list, s
raise e
def _make_images_internal(req: Request, data_queue: queue.Queue, task_temp_images: list, step_callback):
images, user_stopped = generate_images(req, data_queue, task_temp_images, step_callback)
images = apply_filters(req, images, user_stopped)
args = req_to_args(req)
save_images(req, images)
images, user_stopped = generate_images(args, data_queue, task_temp_images, step_callback, req.stream_image_progress)
images = apply_color_correction(args, images, user_stopped)
images = apply_filters(args, images, user_stopped, req.show_only_filtered_image)
if req.save_to_disk_path is not None:
out_path = os.path.join(req.save_to_disk_path, filename_regex.sub('_', req.session_id))
save_images(images, out_path, metadata=req.json(), show_only_filtered_image=req.show_only_filtered_image)
res = Response(req, images=construct_response(req, images))
res = res.json()
@ -90,37 +95,48 @@ def _make_images_internal(req: Request, data_queue: queue.Queue, task_temp_image
return res
def generate_images(req: Request, data_queue: queue.Queue, task_temp_images: list, step_callback):
def generate_images(args: dict, data_queue: queue.Queue, task_temp_images: list, step_callback, stream_image_progress: bool):
thread_data.temp_images.clear()
image_generator.on_image_step = make_step_callback(req, data_queue, task_temp_images, step_callback)
image_generator.on_image_step = make_step_callback(args, data_queue, task_temp_images, step_callback, stream_image_progress)
try:
images = image_generator.make_images(context=thread_data, args=get_mk_img_args(req))
images = image_generator.make_images(context=thread_data, args=args)
user_stopped = False
except UserInitiatedStop:
images = []
user_stopped = True
if not hasattr(thread_data, 'partial_x_samples') or thread_data.partial_x_samples is None:
return images
for i in range(req.num_outputs):
for i in range(args['num_outputs']):
images[i] = image_utils.latent_to_img(thread_data, thread_data.partial_x_samples[i].unsqueeze(0))
del thread_data.partial_x_samples
finally:
model_loader.gc(thread_data)
images = [(image, req.seed + i, False) for i, image in enumerate(images)]
images = [(image, args['seed'] + i, False) for i, image in enumerate(images)]
return images, user_stopped
def apply_filters(req: Request, images: list, user_stopped):
if user_stopped or (req.use_face_correction is None and req.use_upscale is None):
def apply_color_correction(args: dict, images: list, user_stopped):
if user_stopped or args['init_image'] is None or not args['apply_color_correction']:
return images
for i, img_info in enumerate(images):
img, seed, filtered = img_info
img = image_utils.apply_color_correction(orig_image=args['init_image'], image_to_correct=img)
images[i] = (img, seed, filtered)
return images
def apply_filters(args: dict, images: list, user_stopped, show_only_filtered_image):
if user_stopped or (args['use_face_correction'] is None and args['use_upscale'] is None):
return images
filters = []
if req.use_face_correction.startswith('GFPGAN'): filters.append((image_filters.apply_gfpgan, model_manager.resolve_model_to_use(req.use_face_correction, model_type='gfpgan')))
if req.use_upscale.startswith('RealESRGAN'): filters.append((image_filters.apply_realesrgan, model_manager.resolve_model_to_use(req.use_upscale, model_type='realesrgan')))
if args['use_face_correction'].startswith('GFPGAN'): filters.append((image_filters.apply_gfpgan, model_manager.resolve_model_to_use(args['use_face_correction'], model_type='gfpgan')))
if args['use_face_correction'].use_upscale.startswith('RealESRGAN'): filters.append((image_filters.apply_realesrgan, model_manager.resolve_model_to_use(args['use_upscale'], model_type='realesrgan')))
filtered_images = []
for img, seed, _ in images:
@ -129,13 +145,13 @@ def apply_filters(req: Request, images: list, user_stopped):
filtered_images.append((img, seed, True))
if not req.show_only_filtered_image:
if not show_only_filtered_image:
filtered_images = images + filtered_images
return filtered_images
def save_images(req: Request, images: list):
if req.save_to_disk_path is None:
def save_images(images: list, save_to_disk_path, metadata: dict, show_only_filtered_image):
if save_to_disk_path is None:
return
def get_image_id(i):
@ -144,25 +160,24 @@ def save_images(req: Request, images: list):
return img_id
def get_image_basepath(i):
session_out_path = os.path.join(req.save_to_disk_path, filename_regex.sub('_', req.session_id))
os.makedirs(session_out_path, exist_ok=True)
prompt_flattened = filename_regex.sub('_', req.prompt)[:50]
return os.path.join(session_out_path, f"{prompt_flattened}_{get_image_id(i)}")
os.makedirs(save_to_disk_path, exist_ok=True)
prompt_flattened = filename_regex.sub('_', metadata['prompt'])[:50]
return os.path.join(save_to_disk_path, f"{prompt_flattened}_{get_image_id(i)}")
for i, img_data in enumerate(images):
img, seed, filtered = img_data
img_path = get_image_basepath(i)
if not filtered or req.show_only_filtered_image:
if not filtered or show_only_filtered_image:
img_metadata_path = img_path + '.txt'
metadata = req.json()
metadata['seed'] = seed
m = metadata.copy()
m['seed'] = seed
with open(img_metadata_path, 'w', encoding='utf-8') as f:
f.write(metadata)
f.write(m)
img_path += '_filtered' if filtered else ''
img_path += '.' + req.output_format
img.save(img_path, quality=req.output_quality)
img_path += '.' + metadata['output_format']
img.save(img_path, quality=metadata['output_quality'])
def construct_response(req: Request, images: list):
return [
@ -172,7 +187,7 @@ def construct_response(req: Request, images: list):
) for img, seed, _ in images
]
def get_mk_img_args(req: Request):
def req_to_args(req: Request):
args = req.json()
args['init_image'] = image_utils.base64_str_to_img(req.init_image) if req.init_image is not None else None
@ -180,21 +195,21 @@ def get_mk_img_args(req: Request):
return args
def make_step_callback(req: Request, data_queue: queue.Queue, task_temp_images: list, step_callback):
n_steps = req.num_inference_steps if req.init_image is None else int(req.num_inference_steps * req.prompt_strength)
def make_step_callback(args: dict, data_queue: queue.Queue, task_temp_images: list, step_callback, stream_image_progress: bool):
n_steps = args['num_inference_steps'] if args['init_image'] is None else int(args['num_inference_steps'] * args['prompt_strength'])
last_callback_time = -1
def update_temp_img(req, x_samples, task_temp_images: list):
def update_temp_img(x_samples, task_temp_images: list):
partial_images = []
for i in range(req.num_outputs):
for i in range(args['num_outputs']):
img = image_utils.latent_to_img(thread_data, x_samples[i].unsqueeze(0))
buf = image_utils.img_to_buffer(img, output_format='JPEG')
del img
thread_data.temp_images[f'{req.request_id}/{i}'] = buf
thread_data.temp_images[f"{args['request_id']}/{i}"] = buf
task_temp_images[i] = buf
partial_images.append({'path': f'/image/tmp/{req.request_id}/{i}'})
partial_images.append({'path': f"/image/tmp/{args['request_id']}/{i}"})
return partial_images
def on_image_step(x_samples, i):
@ -206,8 +221,8 @@ def make_step_callback(req: Request, data_queue: queue.Queue, task_temp_images:
progress = {"step": i, "step_time": step_time, "total_steps": n_steps}
if req.stream_image_progress and i % 5 == 0:
progress['output'] = update_temp_img(req, x_samples, task_temp_images)
if stream_image_progress and i % 5 == 0:
progress['output'] = update_temp_img(x_samples, task_temp_images)
data_queue.put(json.dumps(progress))

View File

@ -76,6 +76,7 @@ class ImageRequest(BaseModel):
negative_prompt: str = ""
init_image: str = None # base64
mask: str = None # base64
apply_color_correction: bool = False
num_outputs: int = 1
num_inference_steps: int = 50
guidance_scale: float = 7.5
@ -522,6 +523,7 @@ def render(req : ImageRequest):
r.negative_prompt = req.negative_prompt
r.init_image = req.init_image
r.mask = req.mask
r.apply_color_correction = req.apply_color_correction
r.num_outputs = req.num_outputs
r.num_inference_steps = req.num_inference_steps
r.guidance_scale = req.guidance_scale