forked from extern/easydiffusion
Option to apply color correction (balances the histogram) during inpainting; Refactor the runtime to use a general-purpose dict
This commit is contained in:
parent f1de0be679
commit 79cc84b611
@@ -213,6 +213,7 @@
 <div><ul>
 <li><b class="settings-subheader">Render Settings</b></li>
 <li class="pl-5"><input id="stream_image_progress" name="stream_image_progress" type="checkbox"> <label for="stream_image_progress">Show a live preview <small>(uses more VRAM, slower images)</small></label></li>
+<li id="apply_color_correction_setting" class="pl-5"><input id="apply_color_correction" name="apply_color_correction" type="checkbox"> <label for="apply_color_correction">Apply color correction <small>(helps during inpainting)</small></label></li>
 <li class="pl-5"><input id="use_face_correction" name="use_face_correction" type="checkbox"> <label for="use_face_correction">Fix incorrect faces and eyes <small>(uses GFPGAN)</small></label></li>
 <li class="pl-5">
 <input id="use_upscale" name="use_upscale" type="checkbox"> <label for="use_upscale">Upscale image by 4x with </label>

@@ -39,7 +39,8 @@ const SETTINGS_IDS_LIST = [
     "turbo",
     "use_full_precision",
     "confirm_dangerous_actions",
-    "auto_save_settings"
+    "auto_save_settings",
+    "apply_color_correction"
 ]
 
 const IGNORE_BY_DEFAULT = [

@@ -26,6 +26,8 @@ let initImagePreview = document.querySelector("#init_image_preview")
 let initImageSizeBox = document.querySelector("#init_image_size_box")
 let maskImageSelector = document.querySelector("#mask")
 let maskImagePreview = document.querySelector("#mask_preview")
+let applyColorCorrectionField = document.querySelector('#apply_color_correction')
+let colorCorrectionSetting = document.querySelector('#apply_color_correction_setting')
 let promptStrengthSlider = document.querySelector('#prompt_strength_slider')
 let promptStrengthField = document.querySelector('#prompt_strength')
 let samplerField = document.querySelector('#sampler')

@@ -759,6 +761,9 @@ function createTask(task) {
         taskConfig += `, <b>Hypernetwork:</b> ${task.reqBody.use_hypernetwork_model}`
         taskConfig += `, <b>Hypernetwork Strength:</b> ${task.reqBody.hypernetwork_strength}`
     }
+    if (task.reqBody.apply_color_correction) {
+        taskConfig += `, <b>Color Correction:</b> true`
+    }
 
     let taskEntry = document.createElement('div')
     taskEntry.id = `imageTaskContainer-${Date.now()}`

@@ -867,6 +872,7 @@ function getCurrentUserRequest() {
         if (maskSetting.checked) {
             newTask.reqBody.mask = imageInpainter.getImg()
         }
+        newTask.reqBody.apply_color_correction = applyColorCorrectionField.checked
         newTask.reqBody.sampler = 'ddim'
     } else {
         newTask.reqBody.sampler = samplerField.value

@@ -1257,6 +1263,7 @@ function img2imgLoad() {
     promptStrengthContainer.style.display = 'table-row'
     samplerSelectionContainer.style.display = "none"
     initImagePreviewContainer.classList.add("has-image")
+    colorCorrectionSetting.style.display = ''
 
     initImageSizeBox.textContent = initImagePreview.naturalWidth + " x " + initImagePreview.naturalHeight
     imageEditor.setImage(this.src, initImagePreview.naturalWidth, initImagePreview.naturalHeight)

@@ -1271,6 +1278,7 @@ function img2imgUnload() {
     promptStrengthContainer.style.display = "none"
     samplerSelectionContainer.style.display = ""
     initImagePreviewContainer.classList.remove("has-image")
+    colorCorrectionSetting.style.display = 'none'
     imageEditor.setImage(null, parseInt(widthField.value), parseInt(heightField.value))
 
 }

@@ -7,6 +7,7 @@ class Request:
     negative_prompt: str = ""
     init_image: str = None # base64
     mask: str = None # base64
+    apply_color_correction = False
     num_outputs: int = 1
     num_inference_steps: int = 50
     guidance_scale: float = 7.5

@@ -35,6 +36,7 @@ class Request:
 
     def json(self):
         return {
+            "request_id": self.request_id,
             "session_id": self.session_id,
             "prompt": self.prompt,
             "negative_prompt": self.negative_prompt,

@@ -46,6 +48,7 @@ class Request:
             "seed": self.seed,
             "prompt_strength": self.prompt_strength,
             "sampler": self.sampler,
+            "apply_color_correction": self.apply_color_correction,
             "use_face_correction": self.use_face_correction,
             "use_upscale": self.use_upscale,
             "use_stable_diffusion_model": self.use_stable_diffusion_model,

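Note: after this refactor, req.json() is the contract between the web layer and the image runtime, which now receives a plain dict instead of a Request object. Any new Request field (such as apply_color_correction above) therefore also has to be exported by json(). A self-contained sketch of the failure mode when a field is forgotten (RequestStub is illustrative, not the project's class):

# Illustrative stub: the dict-based runtime only sees what json() exports.
class RequestStub:
    def __init__(self):
        self.apply_color_correction = True

    def json(self) -> dict:
        return {}  # forgot to export 'apply_color_correction'

args = RequestStub().json()
try:
    args['apply_color_correction']
except KeyError:
    print("field missing from json() -> KeyError in the dict-based runtime")
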
@@ -71,6 +74,7 @@ class Request:
     save_to_disk_path: {self.save_to_disk_path}
     turbo: {self.turbo}
     use_full_precision: {self.use_full_precision}
+    apply_color_correction: {self.apply_color_correction}
     use_face_correction: {self.use_face_correction}
     use_upscale: {self.use_upscale}
     use_stable_diffusion_model: {self.use_stable_diffusion_model}

@@ -78,10 +78,15 @@ def make_images(req: Request, data_queue: queue.Queue, task_temp_images: list, step_callback):
         raise e
 
 def _make_images_internal(req: Request, data_queue: queue.Queue, task_temp_images: list, step_callback):
-    images, user_stopped = generate_images(req, data_queue, task_temp_images, step_callback)
-    images = apply_filters(req, images, user_stopped)
+    args = req_to_args(req)
 
-    save_images(req, images)
+    images, user_stopped = generate_images(args, data_queue, task_temp_images, step_callback, req.stream_image_progress)
+    images = apply_color_correction(args, images, user_stopped)
+    images = apply_filters(args, images, user_stopped, req.show_only_filtered_image)
+
+    if req.save_to_disk_path is not None:
+        out_path = os.path.join(req.save_to_disk_path, filename_regex.sub('_', req.session_id))
+        save_images(images, out_path, metadata=req.json(), show_only_filtered_image=req.show_only_filtered_image)
 
     res = Response(req, images=construct_response(req, images))
     res = res.json()

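Throughout the refactored runtime, images is a list of (image, seed, filtered) tuples that every stage consumes and returns, which is what lets _make_images_internal read as a straight pipeline: generate, color-correct, filter, save. A minimal self-contained sketch of that convention (the stage bodies are placeholders, not the project's implementations):

# Each stage maps a list of (image, seed, filtered) tuples to a new list.
def generate(args: dict) -> list:
    return [(f"img-{i}", args['seed'] + i, False) for i in range(args['num_outputs'])]

def color_correct(args: dict, images: list) -> list:
    # Mirrors the no-op guard in apply_color_correction in the next hunk.
    if args.get('init_image') is None or not args.get('apply_color_correction'):
        return images
    return [(f"corrected({img})", seed, filtered) for img, seed, filtered in images]

args = {'seed': 42, 'num_outputs': 2, 'init_image': 'init.png', 'apply_color_correction': True}
print(color_correct(args, generate(args)))
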
@@ -90,37 +95,48 @@ def _make_images_internal(req: Request, data_queue: queue.Queue, task_temp_images: list, step_callback):
 
     return res
 
-def generate_images(req: Request, data_queue: queue.Queue, task_temp_images: list, step_callback):
+def generate_images(args: dict, data_queue: queue.Queue, task_temp_images: list, step_callback, stream_image_progress: bool):
     thread_data.temp_images.clear()
 
-    image_generator.on_image_step = make_step_callback(req, data_queue, task_temp_images, step_callback)
+    image_generator.on_image_step = make_step_callback(args, data_queue, task_temp_images, step_callback, stream_image_progress)
 
     try:
-        images = image_generator.make_images(context=thread_data, args=get_mk_img_args(req))
+        images = image_generator.make_images(context=thread_data, args=args)
         user_stopped = False
     except UserInitiatedStop:
         images = []
         user_stopped = True
        if not hasattr(thread_data, 'partial_x_samples') or thread_data.partial_x_samples is None:
            return images
-        for i in range(req.num_outputs):
+        for i in range(args['num_outputs']):
            images[i] = image_utils.latent_to_img(thread_data, thread_data.partial_x_samples[i].unsqueeze(0))
 
        del thread_data.partial_x_samples
     finally:
        model_loader.gc(thread_data)
 
-    images = [(image, req.seed + i, False) for i, image in enumerate(images)]
+    images = [(image, args['seed'] + i, False) for i, image in enumerate(images)]
 
     return images, user_stopped
 
-def apply_filters(req: Request, images: list, user_stopped):
-    if user_stopped or (req.use_face_correction is None and req.use_upscale is None):
+def apply_color_correction(args: dict, images: list, user_stopped):
+    if user_stopped or args['init_image'] is None or not args['apply_color_correction']:
         return images
 
+    for i, img_info in enumerate(images):
+        img, seed, filtered = img_info
+        img = image_utils.apply_color_correction(orig_image=args['init_image'], image_to_correct=img)
+        images[i] = (img, seed, filtered)
+
+    return images
+
+def apply_filters(args: dict, images: list, user_stopped, show_only_filtered_image):
+    if user_stopped or (args['use_face_correction'] is None and args['use_upscale'] is None):
+        return images
+
     filters = []
-    if req.use_face_correction.startswith('GFPGAN'): filters.append((image_filters.apply_gfpgan, model_manager.resolve_model_to_use(req.use_face_correction, model_type='gfpgan')))
-    if req.use_upscale.startswith('RealESRGAN'): filters.append((image_filters.apply_realesrgan, model_manager.resolve_model_to_use(req.use_upscale, model_type='realesrgan')))
+    if args['use_face_correction'].startswith('GFPGAN'): filters.append((image_filters.apply_gfpgan, model_manager.resolve_model_to_use(args['use_face_correction'], model_type='gfpgan')))
+    if args['use_upscale'].startswith('RealESRGAN'): filters.append((image_filters.apply_realesrgan, model_manager.resolve_model_to_use(args['use_upscale'], model_type='realesrgan')))
 
     filtered_images = []
     for img, seed, _ in images:

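image_utils.apply_color_correction itself is not part of this diff; per the commit message it balances the output's histogram against the init image, which counteracts the color drift img2img and inpainting can introduce. A plausible sketch of such a routine, assuming scikit-image >= 0.19 and Pillow are available; the project's actual implementation may differ:

# Plausible histogram-matching color correction; NOT the project's verbatim code.
import numpy as np
from PIL import Image
from skimage.exposure import match_histograms

def apply_color_correction(orig_image: Image.Image, image_to_correct: Image.Image) -> Image.Image:
    reference = np.asarray(orig_image.convert('RGB'))
    target = np.asarray(image_to_correct.convert('RGB'))
    # Match each RGB channel's histogram to the init image's histogram.
    matched = match_histograms(target, reference, channel_axis=-1)
    return Image.fromarray(np.clip(matched, 0, 255).astype(np.uint8))
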
@@ -129,13 +145,13 @@ def apply_filters(req: Request, images: list, user_stopped):
 
         filtered_images.append((img, seed, True))
 
-    if not req.show_only_filtered_image:
+    if not show_only_filtered_image:
         filtered_images = images + filtered_images
 
     return filtered_images
 
-def save_images(req: Request, images: list):
-    if req.save_to_disk_path is None:
+def save_images(images: list, save_to_disk_path, metadata: dict, show_only_filtered_image):
+    if save_to_disk_path is None:
         return
 
     def get_image_id(i):

@@ -144,25 +160,24 @@ def save_images(req: Request, images: list):
         return img_id
 
     def get_image_basepath(i):
-        session_out_path = os.path.join(req.save_to_disk_path, filename_regex.sub('_', req.session_id))
-        os.makedirs(session_out_path, exist_ok=True)
-        prompt_flattened = filename_regex.sub('_', req.prompt)[:50]
-        return os.path.join(session_out_path, f"{prompt_flattened}_{get_image_id(i)}")
+        os.makedirs(save_to_disk_path, exist_ok=True)
+        prompt_flattened = filename_regex.sub('_', metadata['prompt'])[:50]
+        return os.path.join(save_to_disk_path, f"{prompt_flattened}_{get_image_id(i)}")
 
     for i, img_data in enumerate(images):
         img, seed, filtered = img_data
         img_path = get_image_basepath(i)
 
-        if not filtered or req.show_only_filtered_image:
+        if not filtered or show_only_filtered_image:
             img_metadata_path = img_path + '.txt'
-            metadata = req.json()
-            metadata['seed'] = seed
+            m = metadata.copy()
+            m['seed'] = seed
             with open(img_metadata_path, 'w', encoding='utf-8') as f:
-                f.write(metadata)
+                f.write(m)
 
         img_path += '_filtered' if filtered else ''
-        img_path += '.' + req.output_format
-        img.save(img_path, quality=req.output_quality)
+        img_path += '.' + metadata['output_format']
+        img.save(img_path, quality=metadata['output_quality'])
 
 def construct_response(req: Request, images: list):
     return [

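A note on the m = metadata.copy() above: save_images now receives a single metadata dict for the whole batch, so it must copy before stamping each image's seed; writing the seed into the shared dict would leak one image's seed into every later sidecar file. A self-contained sketch of the pitfall:

# Why the per-image copy matters when one dict is shared across the batch.
metadata = {'prompt': 'a cat'}

buggy = []
for seed in (42, 43):
    metadata['seed'] = seed   # mutates the shared dict in place
    buggy.append(metadata)    # every entry aliases that same dict
print([m['seed'] for m in buggy])   # [43, 43] -- first seed lost

fixed = []
for seed in (42, 43):
    m = metadata.copy()       # per-image copy, as in the diff
    m['seed'] = seed
    fixed.append(m)
print([m['seed'] for m in fixed])   # [42, 43]
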
@@ -172,7 +187,7 @@ def construct_response(req: Request, images: list):
         ) for img, seed, _ in images
     ]
 
-def get_mk_img_args(req: Request):
+def req_to_args(req: Request):
     args = req.json()
 
     args['init_image'] = image_utils.base64_str_to_img(req.init_image) if req.init_image is not None else None

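req_to_args starts from req.json() and then overwrites the fields that need decoding, e.g. the base64 init_image string becomes an image object before the dict reaches the generator. image_utils.base64_str_to_img is not shown in this diff; a plausible sketch, assuming Pillow:

# Plausible base64 -> PIL decoding; NOT the project's verbatim helper.
import base64
import io
from PIL import Image

def base64_str_to_img(b64: str) -> Image.Image:
    if ',' in b64:
        b64 = b64.split(',', 1)[1]  # strip a data-URL prefix such as 'data:image/png;base64,'
    return Image.open(io.BytesIO(base64.b64decode(b64)))
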
@@ -180,21 +195,21 @@ def get_mk_img_args(req: Request):
 
     return args
 
-def make_step_callback(req: Request, data_queue: queue.Queue, task_temp_images: list, step_callback):
-    n_steps = req.num_inference_steps if req.init_image is None else int(req.num_inference_steps * req.prompt_strength)
+def make_step_callback(args: dict, data_queue: queue.Queue, task_temp_images: list, step_callback, stream_image_progress: bool):
+    n_steps = args['num_inference_steps'] if args['init_image'] is None else int(args['num_inference_steps'] * args['prompt_strength'])
     last_callback_time = -1
 
-    def update_temp_img(req, x_samples, task_temp_images: list):
+    def update_temp_img(x_samples, task_temp_images: list):
         partial_images = []
-        for i in range(req.num_outputs):
+        for i in range(args['num_outputs']):
             img = image_utils.latent_to_img(thread_data, x_samples[i].unsqueeze(0))
             buf = image_utils.img_to_buffer(img, output_format='JPEG')
 
             del img
 
-            thread_data.temp_images[f'{req.request_id}/{i}'] = buf
+            thread_data.temp_images[f"{args['request_id']}/{i}"] = buf
             task_temp_images[i] = buf
-            partial_images.append({'path': f'/image/tmp/{req.request_id}/{i}'})
+            partial_images.append({'path': f"/image/tmp/{args['request_id']}/{i}"})
         return partial_images
 
     def on_image_step(x_samples, i):

@@ -206,8 +221,8 @@ def make_step_callback(req: Request, data_queue: queue.Queue, task_temp_images: list, step_callback):
 
         progress = {"step": i, "step_time": step_time, "total_steps": n_steps}
 
-        if req.stream_image_progress and i % 5 == 0:
-            progress['output'] = update_temp_img(req, x_samples, task_temp_images)
+        if stream_image_progress and i % 5 == 0:
+            progress['output'] = update_temp_img(x_samples, task_temp_images)
 
         data_queue.put(json.dumps(progress))
 

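on_image_step throttles streaming to every 5th step and pushes each progress event onto data_queue as a JSON string, with stream_image_progress now passed in explicitly instead of read off the request. A self-contained sketch of that producer/consumer shape (the loop below is illustrative):

# Progress events: produced every 5th step, consumed as JSON from the queue.
import json
import queue

data_queue = queue.Queue()
stream_image_progress = True

for i in range(20):
    if stream_image_progress and i % 5 == 0:
        data_queue.put(json.dumps({'step': i, 'total_steps': 20}))

while not data_queue.empty():
    print(json.loads(data_queue.get()))
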
@@ -76,6 +76,7 @@ class ImageRequest(BaseModel):
     negative_prompt: str = ""
     init_image: str = None # base64
     mask: str = None # base64
+    apply_color_correction: bool = False
     num_outputs: int = 1
     num_inference_steps: int = 50
     guidance_scale: float = 7.5

@@ -522,6 +523,7 @@ def render(req : ImageRequest):
     r.negative_prompt = req.negative_prompt
     r.init_image = req.init_image
     r.mask = req.mask
+    r.apply_color_correction = req.apply_color_correction
     r.num_outputs = req.num_outputs
     r.num_inference_steps = req.num_inference_steps
     r.guidance_scale = req.guidance_scale

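For API clients, the new option is one more boolean in the render request body; the server copies it field-by-field onto the internal Request as shown above. A hypothetical payload (values are examples, not taken from this diff):

# Hypothetical render payload exercising the new flag.
payload = {
    'prompt': 'a lighthouse at dusk',
    'init_image': '<base64-encoded PNG>',
    'mask': '<base64-encoded PNG>',
    'apply_color_correction': True,
    'num_outputs': 1,
}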