diff --git a/ui/index.html b/ui/index.html
index 0094201b..30648424 100644
--- a/ui/index.html
+++ b/ui/index.html
@@ -213,6 +213,7 @@
-
+ -
-
-
diff --git a/ui/media/js/auto-save.js b/ui/media/js/auto-save.js
index f503779a..9025f988 100644
--- a/ui/media/js/auto-save.js
+++ b/ui/media/js/auto-save.js
@@ -39,7 +39,8 @@ const SETTINGS_IDS_LIST = [
"turbo",
"use_full_precision",
"confirm_dangerous_actions",
- "auto_save_settings"
+ "auto_save_settings",
+ "apply_color_correction"
]
const IGNORE_BY_DEFAULT = [
diff --git a/ui/media/js/main.js b/ui/media/js/main.js
index 67ff7a9a..6c08aef7 100644
--- a/ui/media/js/main.js
+++ b/ui/media/js/main.js
@@ -26,6 +26,8 @@ let initImagePreview = document.querySelector("#init_image_preview")
let initImageSizeBox = document.querySelector("#init_image_size_box")
let maskImageSelector = document.querySelector("#mask")
let maskImagePreview = document.querySelector("#mask_preview")
+let applyColorCorrectionField = document.querySelector('#apply_color_correction')
+let colorCorrectionSetting = document.querySelector('#apply_color_correction_setting')
let promptStrengthSlider = document.querySelector('#prompt_strength_slider')
let promptStrengthField = document.querySelector('#prompt_strength')
let samplerField = document.querySelector('#sampler')
@@ -759,6 +761,9 @@ function createTask(task) {
taskConfig += `, Hypernetwork: ${task.reqBody.use_hypernetwork_model}`
taskConfig += `, Hypernetwork Strength: ${task.reqBody.hypernetwork_strength}`
}
+ if (task.reqBody.apply_color_correction) {
+ taskConfig += `, Color Correction: true`
+ }
let taskEntry = document.createElement('div')
taskEntry.id = `imageTaskContainer-${Date.now()}`
@@ -867,6 +872,7 @@ function getCurrentUserRequest() {
if (maskSetting.checked) {
newTask.reqBody.mask = imageInpainter.getImg()
}
+ newTask.reqBody.apply_color_correction = applyColorCorrectionField.checked
newTask.reqBody.sampler = 'ddim'
} else {
newTask.reqBody.sampler = samplerField.value
@@ -1257,6 +1263,7 @@ function img2imgLoad() {
promptStrengthContainer.style.display = 'table-row'
samplerSelectionContainer.style.display = "none"
initImagePreviewContainer.classList.add("has-image")
+ colorCorrectionSetting.style.display = ''
initImageSizeBox.textContent = initImagePreview.naturalWidth + " x " + initImagePreview.naturalHeight
imageEditor.setImage(this.src, initImagePreview.naturalWidth, initImagePreview.naturalHeight)
@@ -1271,6 +1278,7 @@ function img2imgUnload() {
promptStrengthContainer.style.display = "none"
samplerSelectionContainer.style.display = ""
initImagePreviewContainer.classList.remove("has-image")
+ colorCorrectionSetting.style.display = 'none'
imageEditor.setImage(null, parseInt(widthField.value), parseInt(heightField.value))
}
diff --git a/ui/sd_internal/__init__.py b/ui/sd_internal/__init__.py
index b001d3f9..9dd4f066 100644
--- a/ui/sd_internal/__init__.py
+++ b/ui/sd_internal/__init__.py
@@ -7,6 +7,7 @@ class Request:
negative_prompt: str = ""
init_image: str = None # base64
mask: str = None # base64
+ apply_color_correction = False
num_outputs: int = 1
num_inference_steps: int = 50
guidance_scale: float = 7.5
@@ -35,6 +36,7 @@ class Request:
def json(self):
return {
+ "request_id": self.request_id,
"session_id": self.session_id,
"prompt": self.prompt,
"negative_prompt": self.negative_prompt,
@@ -46,6 +48,7 @@ class Request:
"seed": self.seed,
"prompt_strength": self.prompt_strength,
"sampler": self.sampler,
+ "apply_color_correction": self.apply_color_correction,
"use_face_correction": self.use_face_correction,
"use_upscale": self.use_upscale,
"use_stable_diffusion_model": self.use_stable_diffusion_model,
@@ -71,6 +74,7 @@ class Request:
save_to_disk_path: {self.save_to_disk_path}
turbo: {self.turbo}
use_full_precision: {self.use_full_precision}
+ apply_color_correction: {self.apply_color_correction}
use_face_correction: {self.use_face_correction}
use_upscale: {self.use_upscale}
use_stable_diffusion_model: {self.use_stable_diffusion_model}
diff --git a/ui/sd_internal/runtime2.py b/ui/sd_internal/runtime2.py
index 0fcf8eb7..f74acc12 100644
--- a/ui/sd_internal/runtime2.py
+++ b/ui/sd_internal/runtime2.py
@@ -78,10 +78,15 @@ def make_images(req: Request, data_queue: queue.Queue, task_temp_images: list, s
raise e
def _make_images_internal(req: Request, data_queue: queue.Queue, task_temp_images: list, step_callback):
- images, user_stopped = generate_images(req, data_queue, task_temp_images, step_callback)
- images = apply_filters(req, images, user_stopped)
+ args = req_to_args(req)
- save_images(req, images)
+ images, user_stopped = generate_images(args, data_queue, task_temp_images, step_callback, req.stream_image_progress)
+ images = apply_color_correction(args, images, user_stopped)
+ images = apply_filters(args, images, user_stopped, req.show_only_filtered_image)
+
+ if req.save_to_disk_path is not None:
+ out_path = os.path.join(req.save_to_disk_path, filename_regex.sub('_', req.session_id))
+ save_images(images, out_path, metadata=req.json(), show_only_filtered_image=req.show_only_filtered_image)
res = Response(req, images=construct_response(req, images))
res = res.json()
@@ -90,37 +95,48 @@ def _make_images_internal(req: Request, data_queue: queue.Queue, task_temp_image
return res
-def generate_images(req: Request, data_queue: queue.Queue, task_temp_images: list, step_callback):
+def generate_images(args: dict, data_queue: queue.Queue, task_temp_images: list, step_callback, stream_image_progress: bool):
thread_data.temp_images.clear()
- image_generator.on_image_step = make_step_callback(req, data_queue, task_temp_images, step_callback)
+ image_generator.on_image_step = make_step_callback(args, data_queue, task_temp_images, step_callback, stream_image_progress)
try:
- images = image_generator.make_images(context=thread_data, args=get_mk_img_args(req))
+ images = image_generator.make_images(context=thread_data, args=args)
user_stopped = False
except UserInitiatedStop:
images = []
user_stopped = True
if not hasattr(thread_data, 'partial_x_samples') or thread_data.partial_x_samples is None:
return images
- for i in range(req.num_outputs):
+ for i in range(args['num_outputs']):
images[i] = image_utils.latent_to_img(thread_data, thread_data.partial_x_samples[i].unsqueeze(0))
del thread_data.partial_x_samples
finally:
model_loader.gc(thread_data)
- images = [(image, req.seed + i, False) for i, image in enumerate(images)]
+ images = [(image, args['seed'] + i, False) for i, image in enumerate(images)]
return images, user_stopped
-def apply_filters(req: Request, images: list, user_stopped):
- if user_stopped or (req.use_face_correction is None and req.use_upscale is None):
+def apply_color_correction(args: dict, images: list, user_stopped):
+ if user_stopped or args['init_image'] is None or not args['apply_color_correction']:
+ return images
+
+ for i, img_info in enumerate(images):
+ img, seed, filtered = img_info
+ img = image_utils.apply_color_correction(orig_image=args['init_image'], image_to_correct=img)
+ images[i] = (img, seed, filtered)
+
+ return images
+
+def apply_filters(args: dict, images: list, user_stopped, show_only_filtered_image):
+ if user_stopped or (args['use_face_correction'] is None and args['use_upscale'] is None):
return images
filters = []
- if req.use_face_correction.startswith('GFPGAN'): filters.append((image_filters.apply_gfpgan, model_manager.resolve_model_to_use(req.use_face_correction, model_type='gfpgan')))
- if req.use_upscale.startswith('RealESRGAN'): filters.append((image_filters.apply_realesrgan, model_manager.resolve_model_to_use(req.use_upscale, model_type='realesrgan')))
+ if args['use_face_correction'].startswith('GFPGAN'): filters.append((image_filters.apply_gfpgan, model_manager.resolve_model_to_use(args['use_face_correction'], model_type='gfpgan')))
+ if args['use_face_correction'].use_upscale.startswith('RealESRGAN'): filters.append((image_filters.apply_realesrgan, model_manager.resolve_model_to_use(args['use_upscale'], model_type='realesrgan')))
filtered_images = []
for img, seed, _ in images:
@@ -129,13 +145,13 @@ def apply_filters(req: Request, images: list, user_stopped):
filtered_images.append((img, seed, True))
- if not req.show_only_filtered_image:
+ if not show_only_filtered_image:
filtered_images = images + filtered_images
return filtered_images
-def save_images(req: Request, images: list):
- if req.save_to_disk_path is None:
+def save_images(images: list, save_to_disk_path, metadata: dict, show_only_filtered_image):
+ if save_to_disk_path is None:
return
def get_image_id(i):
@@ -144,25 +160,24 @@ def save_images(req: Request, images: list):
return img_id
def get_image_basepath(i):
- session_out_path = os.path.join(req.save_to_disk_path, filename_regex.sub('_', req.session_id))
- os.makedirs(session_out_path, exist_ok=True)
- prompt_flattened = filename_regex.sub('_', req.prompt)[:50]
- return os.path.join(session_out_path, f"{prompt_flattened}_{get_image_id(i)}")
+ os.makedirs(save_to_disk_path, exist_ok=True)
+ prompt_flattened = filename_regex.sub('_', metadata['prompt'])[:50]
+ return os.path.join(save_to_disk_path, f"{prompt_flattened}_{get_image_id(i)}")
for i, img_data in enumerate(images):
img, seed, filtered = img_data
img_path = get_image_basepath(i)
- if not filtered or req.show_only_filtered_image:
+ if not filtered or show_only_filtered_image:
img_metadata_path = img_path + '.txt'
- metadata = req.json()
- metadata['seed'] = seed
+ m = metadata.copy()
+ m['seed'] = seed
with open(img_metadata_path, 'w', encoding='utf-8') as f:
- f.write(metadata)
+ f.write(m)
img_path += '_filtered' if filtered else ''
- img_path += '.' + req.output_format
- img.save(img_path, quality=req.output_quality)
+ img_path += '.' + metadata['output_format']
+ img.save(img_path, quality=metadata['output_quality'])
def construct_response(req: Request, images: list):
return [
@@ -172,7 +187,7 @@ def construct_response(req: Request, images: list):
) for img, seed, _ in images
]
-def get_mk_img_args(req: Request):
+def req_to_args(req: Request):
args = req.json()
args['init_image'] = image_utils.base64_str_to_img(req.init_image) if req.init_image is not None else None
@@ -180,21 +195,21 @@ def get_mk_img_args(req: Request):
return args
-def make_step_callback(req: Request, data_queue: queue.Queue, task_temp_images: list, step_callback):
- n_steps = req.num_inference_steps if req.init_image is None else int(req.num_inference_steps * req.prompt_strength)
+def make_step_callback(args: dict, data_queue: queue.Queue, task_temp_images: list, step_callback, stream_image_progress: bool):
+ n_steps = args['num_inference_steps'] if args['init_image'] is None else int(args['num_inference_steps'] * args['prompt_strength'])
last_callback_time = -1
- def update_temp_img(req, x_samples, task_temp_images: list):
+ def update_temp_img(x_samples, task_temp_images: list):
partial_images = []
- for i in range(req.num_outputs):
+ for i in range(args['num_outputs']):
img = image_utils.latent_to_img(thread_data, x_samples[i].unsqueeze(0))
buf = image_utils.img_to_buffer(img, output_format='JPEG')
del img
- thread_data.temp_images[f'{req.request_id}/{i}'] = buf
+ thread_data.temp_images[f"{args['request_id']}/{i}"] = buf
task_temp_images[i] = buf
- partial_images.append({'path': f'/image/tmp/{req.request_id}/{i}'})
+ partial_images.append({'path': f"/image/tmp/{args['request_id']}/{i}"})
return partial_images
def on_image_step(x_samples, i):
@@ -206,8 +221,8 @@ def make_step_callback(req: Request, data_queue: queue.Queue, task_temp_images:
progress = {"step": i, "step_time": step_time, "total_steps": n_steps}
- if req.stream_image_progress and i % 5 == 0:
- progress['output'] = update_temp_img(req, x_samples, task_temp_images)
+ if stream_image_progress and i % 5 == 0:
+ progress['output'] = update_temp_img(x_samples, task_temp_images)
data_queue.put(json.dumps(progress))
diff --git a/ui/sd_internal/task_manager.py b/ui/sd_internal/task_manager.py
index aec79239..6b0a1d8c 100644
--- a/ui/sd_internal/task_manager.py
+++ b/ui/sd_internal/task_manager.py
@@ -76,6 +76,7 @@ class ImageRequest(BaseModel):
negative_prompt: str = ""
init_image: str = None # base64
mask: str = None # base64
+ apply_color_correction: bool = False
num_outputs: int = 1
num_inference_steps: int = 50
guidance_scale: float = 7.5
@@ -522,6 +523,7 @@ def render(req : ImageRequest):
r.negative_prompt = req.negative_prompt
r.init_image = req.init_image
r.mask = req.mask
+ r.apply_color_correction = req.apply_color_correction
r.num_outputs = req.num_outputs
r.num_inference_steps = req.num_inference_steps
r.guidance_scale = req.guidance_scale