Merge pull request #1285 from ogmaresca/save-clip-skip-in-metadata-files

Add Clip Skip to metadata files
This commit is contained in:
cmdr2 2023-05-20 10:44:58 +05:30 committed by GitHub
commit 566cb55f36
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23
3 changed files with 39 additions and 36 deletions

View File

@ -72,7 +72,7 @@ def make_images(
def print_task_info(req: GenerateImageRequest, task_data: TaskData):
req_str = pprint.pformat(get_printable_request(req)).replace("[", "\[")
req_str = pprint.pformat(get_printable_request(req, task_data)).replace("[", "\[")
task_str = pprint.pformat(task_data.dict()).replace("[", "\[")
log.info(f"request: {req_str}")
log.info(f"task data: {task_str}")

View File

@ -15,23 +15,24 @@ img_number_regex = re.compile("([0-9]{5,})")
# keep in sync with `ui/media/js/dnd.js`
TASK_TEXT_MAPPING = {
"prompt": "Prompt",
"negative_prompt": "Negative Prompt",
"seed": "Seed",
"use_stable_diffusion_model": "Stable Diffusion model",
"clip_skip": "Clip Skip",
"use_vae_model": "VAE model",
"sampler_name": "Sampler",
"width": "Width",
"height": "Height",
"seed": "Seed",
"num_inference_steps": "Steps",
"guidance_scale": "Guidance Scale",
"prompt_strength": "Prompt Strength",
"use_lora_model": "LoRA model",
"lora_alpha": "LoRA Strength",
"use_hypernetwork_model": "Hypernetwork model",
"hypernetwork_strength": "Hypernetwork Strength",
"use_face_correction": "Use Face Correction",
"use_upscale": "Use Upscaling",
"upscale_amount": "Upscale By",
"sampler_name": "Sampler",
"negative_prompt": "Negative Prompt",
"use_stable_diffusion_model": "Stable Diffusion model",
"use_vae_model": "VAE model",
"use_hypernetwork_model": "Hypernetwork model",
"hypernetwork_strength": "Hypernetwork Strength",
"use_lora_model": "LoRA model",
"lora_alpha": "LoRA Strength",
}
time_placeholders = {
@ -179,27 +180,7 @@ def save_images_to_disk(images: list, filtered_images: list, req: GenerateImageR
def get_metadata_entries_for_request(req: GenerateImageRequest, task_data: TaskData):
metadata = get_printable_request(req)
metadata.update(
{
"use_stable_diffusion_model": task_data.use_stable_diffusion_model,
"use_vae_model": task_data.use_vae_model,
"use_hypernetwork_model": task_data.use_hypernetwork_model,
"use_lora_model": task_data.use_lora_model,
"use_face_correction": task_data.use_face_correction,
"use_upscale": task_data.use_upscale,
}
)
if metadata["use_upscale"] is not None:
metadata["upscale_amount"] = task_data.upscale_amount
if task_data.use_hypernetwork_model is None:
del metadata["hypernetwork_strength"]
if task_data.use_lora_model is None:
if "lora_alpha" in metadata:
del metadata["lora_alpha"]
app_config = app.getConfig()
if not app_config.get("test_diffusers", False) and "use_lora_model" in metadata:
del metadata["use_lora_model"]
metadata = get_printable_request(req, task_data)
# if text, format it in the text format expected by the UI
is_txt_format = task_data.metadata_output_format.lower() == "txt"
@ -213,12 +194,33 @@ def get_metadata_entries_for_request(req: GenerateImageRequest, task_data: TaskD
return entries
def get_printable_request(req: GenerateImageRequest):
metadata = req.dict()
del metadata["init_image"]
del metadata["init_image_mask"]
if req.init_image is None:
def get_printable_request(req: GenerateImageRequest, task_data: TaskData):
req_metadata = req.dict()
task_data_metadata = task_data.dict()
# Save the metadata in the order defined in TASK_TEXT_MAPPING
metadata = {}
for key in TASK_TEXT_MAPPING.keys():
if key in req_metadata:
metadata[key] = req_metadata[key]
elif key in task_data_metadata:
metadata[key] = task_data_metadata[key]
# Clean up the metadata
if req.init_image is None and "prompt_strength" in metadata:
del metadata["prompt_strength"]
if task_data.use_upscale is None and "upscale_amount" in metadata:
del metadata["upscale_amount"]
if task_data.use_hypernetwork_model is None and "hypernetwork_strength" in metadata:
del metadata["hypernetwork_strength"]
if task_data.use_lora_model is None and "lora_alpha" in metadata:
del metadata["lora_alpha"]
app_config = app.getConfig()
if not app_config.get("test_diffusers", False):
for key in (x for x in ["use_lora_model", "lora_alpha", "clip_skip"] if x in metadata):
del metadata[key]
return metadata

View File

@ -37,6 +37,7 @@ function parseBoolean(stringValue) {
}
}
// keep in sync with `ui/easydiffusion/utils/save_utils.py`
const TASK_MAPPING = {
prompt: {
name: "Prompt",