From 7e53eb658c8105fa86fc285305556d994148e390 Mon Sep 17 00:00:00 2001 From: Olivia Godone-Maresca Date: Wed, 29 Mar 2023 20:56:24 -0400 Subject: [PATCH 1/3] Allow loading/saving app.config from plugins and support custom folder/filename formats from app.config --- ui/easydiffusion/server.py | 12 +- ui/easydiffusion/utils/save_utils.py | 148 +++++++++++++++++++++-- ui/media/js/parameters.js | 172 ++++++++++++++++++++++----- ui/media/js/utils.js | 26 +++- 4 files changed, 315 insertions(+), 43 deletions(-) diff --git a/ui/easydiffusion/server.py b/ui/easydiffusion/server.py index e27f9c5b..9642d735 100644 --- a/ui/easydiffusion/server.py +++ b/ui/easydiffusion/server.py @@ -10,7 +10,7 @@ from typing import List, Union from fastapi import FastAPI, HTTPException from fastapi.staticfiles import StaticFiles from starlette.responses import FileResponse, JSONResponse, StreamingResponse -from pydantic import BaseModel +from pydantic import BaseModel, Extra from easydiffusion import app, model_manager, task_manager from easydiffusion.types import TaskData, GenerateImageRequest, MergeRequest @@ -44,7 +44,7 @@ class NoCacheStaticFiles(StaticFiles): return super().is_not_modified(response_headers, request_headers) -class SetAppConfigRequest(BaseModel): +class SetAppConfigRequest(BaseModel, extra=Extra.allow): update_branch: str = None render_devices: Union[List[str], List[int], str, int] = None model_vae: str = None @@ -136,6 +136,14 @@ def set_app_config_internal(req: SetAppConfigRequest): config["test_diffusers"] = req.test_diffusers + for property, property_value in req.dict().items(): + log.info(f"set_app_config_internal {property} === {property_value}") + if property_value is not None and property in req.__fields__: + log.info(f"set_app_config_internal {property} IS DEFINED PROPERTY") + if property_value is not None and property not in req.__fields__: + log.info(f"set_app_config_internal {property} IS ADDITIONAL PROPERTY") + config[property] = property_value + try: app.setConfig(config) diff --git a/ui/easydiffusion/utils/save_utils.py b/ui/easydiffusion/utils/save_utils.py index 950f04b0..f936ad1b 100644 --- a/ui/easydiffusion/utils/save_utils.py +++ b/ui/easydiffusion/utils/save_utils.py @@ -1,13 +1,18 @@ import os import time +import base64 import re +from easydiffusion import app from easydiffusion.types import TaskData, GenerateImageRequest +from functools import reduce +from datetime import datetime from sdkit.utils import save_images, save_dicts from numpy import base_repr filename_regex = re.compile("[^a-zA-Z0-9._-]") +img_number_regex = re.compile("([0-9]{5,})") # keep in sync with `ui/media/js/dnd.js` TASK_TEXT_MAPPING = { @@ -31,12 +36,86 @@ TASK_TEXT_MAPPING = { # "lora_alpha": "LoRA Strength", } +time_placeholders = { + "$yyyy": "%Y", + "$MM": "%m", + "$dd": "%d", + "$HH": "%H", + "$mm": "%M", + "$ss": "%S", +} + +other_placeholders = { + "$id": lambda req, task_data: filename_regex.sub("_", task_data.session_id), + "$p": lambda req, task_data: filename_regex.sub("_", req.prompt)[:50], + "$s": lambda req, task_data: str(req.seed), +} + +class ImageNumber: + _factory = None + _evaluated = False + + def __init__(self, factory): + self._factory = factory + self._evaluated = None + def __call__(self) -> int: + if self._evaluated is None: + self._evaluated = self._factory() + return self._evaluated + +def format_placeholders(format: str, req: GenerateImageRequest, task_data: TaskData, now = None): + if now is None: + now = time.time() + + for placeholder, time_format in time_placeholders.items(): + if placeholder in format: + format = format.replace(placeholder, datetime.fromtimestamp(now).strftime(time_format)) + for placeholder, replace_func in other_placeholders.items(): + if placeholder in format: + format = format.replace(placeholder, replace_func(req, task_data)) + + return format + +def format_folder_name(format: str, req: GenerateImageRequest, task_data: TaskData): + format = format_placeholders(format, req, task_data) + return filename_regex.sub("_", format) + +def format_file_name( + format: str, + req: GenerateImageRequest, + task_data: TaskData, + now: float, + batch_file_number: int, + folder_img_number: ImageNumber, +): + format = format_placeholders(format, req, task_data, now) + + if "$n" in format: + format = format.replace("$n", f"{folder_img_number():05}") + + if "$tsb64" in format: + img_id = base_repr(int(now * 10000), 36)[-7:] + base_repr(int(batch_file_number), 36) # Base 36 conversion, 0-9, A-Z + format = format.replace("$tsb64", img_id) + + if "$ts" in format: + format = format.replace("$ts", str(int(now * 1000) + batch_file_number)) + + return filename_regex.sub("_", format) def save_images_to_disk(images: list, filtered_images: list, req: GenerateImageRequest, task_data: TaskData): now = time.time() - save_dir_path = os.path.join(task_data.save_to_disk_path, filename_regex.sub("_", task_data.session_id)) + app_config = app.getConfig() + folder_format = app_config.get("folder_format", "$id") + save_dir_path = os.path.join(task_data.save_to_disk_path, format_folder_name(folder_format, req, task_data)) metadata_entries = get_metadata_entries_for_request(req, task_data) - make_filename = make_filename_callback(req, now=now) + file_number = calculate_img_number(save_dir_path, task_data) + make_filename = make_filename_callback( + app_config.get("filename_format", "$p_$tsb64"), + req, + task_data, + file_number, + now=now, + ) if task_data.show_only_filtered_image or filtered_images is images: save_images( @@ -58,7 +137,7 @@ def save_images_to_disk(images: list, filtered_images: list, req: GenerateImageR file_format=task_data.output_format, ) else: - make_filter_filename = make_filename_callback(req, now=now, suffix="filtered") + make_filter_filename = make_filename_callback(req, task_data, file_number, now=now, suffix="filtered") save_images( images, @@ -105,9 +184,6 @@ def get_metadata_entries_for_request(req: GenerateImageRequest, task_data: TaskD if task_data.use_lora_model is None: if "lora_alpha" in metadata: del metadata["lora_alpha"] - - from easydiffusion import app - app_config = app.getConfig() if not app_config.get("test_diffusers", False) and "use_lora_model" in metadata: del metadata["use_lora_model"] @@ -133,16 +209,66 @@ def get_printable_request(req: GenerateImageRequest): return metadata -def make_filename_callback(req: GenerateImageRequest, suffix=None, now=None): +def make_filename_callback( + filename_format: str, + req: GenerateImageRequest, + task_data: TaskData, + folder_img_number: int, + suffix=None, + now=None, +): if now is None: now = time.time() def make_filename(i): - img_id = base_repr(int(now * 10000), 36)[-7:] + base_repr(int(i),36) # Base 36 conversion, 0-9, A-Z - - prompt_flattened = filename_regex.sub("_", req.prompt)[:50] - name = f"{prompt_flattened}_{img_id}" + name = format_file_name(filename_format, req, task_data, now, i, folder_img_number) name = name if suffix is None else f"{name}_{suffix}" + return name return make_filename + +def _calculate_img_number(save_dir_path: str, task_data: TaskData): + def get_highest_img_number(accumulator: int, file: os.DirEntry) -> int: + if not file.is_file: + return accumulator + + if len(list(filter(lambda e: file.name.endswith(e), app.IMAGE_EXTENSIONS))) == 0: + return accumulator + + get_highest_img_number.number_of_images = get_highest_img_number.number_of_images + 1 + + number_match = img_number_regex.match(file.name) + if not number_match: + return accumulator + + file_number = number_match.group().lstrip('0') + + # Handle 00000 + return int(file_number) if file_number else 0 + + get_highest_img_number.number_of_images = 0 + + highest_file_number = -1 + + if os.path.isdir(save_dir_path): + existing_files = list(os.scandir(save_dir_path)) + highest_file_number = reduce(get_highest_img_number, existing_files, -1) + + calculated_img_number = max(highest_file_number, get_highest_img_number.number_of_images - 1) + + if task_data.session_id in _calculate_img_number.session_img_numbers: + calculated_img_number = max( + _calculate_img_number.session_img_numbers[task_data.session_id], + calculated_img_number, + ) + + calculated_img_number = calculated_img_number + 1 + + _calculate_img_number.session_img_numbers[task_data.session_id] = calculated_img_number + return calculated_img_number + +_calculate_img_number.session_img_numbers = {} + +def calculate_img_number(save_dir_path: str, task_data: TaskData): + return ImageNumber(lambda: _calculate_img_number(save_dir_path, task_data)) diff --git a/ui/media/js/parameters.js b/ui/media/js/parameters.js index baa55469..0fc89359 100644 --- a/ui/media/js/parameters.js +++ b/ui/media/js/parameters.js @@ -15,10 +15,13 @@ * JSDoc style * @typedef {object} Parameter * @property {string} id - * @property {ParameterType} type - * @property {string} label - * @property {?string} note + * @property {keyof ParameterType} type + * @property {string | (parameter: Parameter) => (HTMLElement | string)} label + * @property {string | (parameter: Parameter) => (HTMLElement | string) | undefined} note + * @property {(parameter: Parameter) => (HTMLElement | string) | undefined} render + * @property {string | undefined} icon * @property {number|boolean|string} default + * @property {boolean?} saveInAppConfig */ @@ -118,6 +121,7 @@ var PARAMETERS = [ note: "starts the default browser on startup", icon: "fa-window-restore", default: true, + saveInAppConfig: true, }, { id: "vram_usage_level", @@ -179,6 +183,7 @@ var PARAMETERS = [ note: "Other devices on your network can access this web page", icon: "fa-network-wired", default: true, + saveInAppConfig: true, }, { id: "listen_port", @@ -188,7 +193,8 @@ var PARAMETERS = [ icon: "fa-anchor", render: (parameter) => { return `` - } + }, + saveInAppConfig: true, }, { id: "use_beta_channel", @@ -205,6 +211,7 @@ var PARAMETERS = [ note: "Experimental! Can have bugs! Use upcoming features (like LoRA) in our new engine. Please press Save, then restart the program after changing this.", icon: "fa-bolt", default: false, + saveInAppConfig: true, }, ]; @@ -228,6 +235,10 @@ function sliderUpdate(event) { } } +/** + * @param {Parameter} parameter + * @returns {string | HTMLElement} + */ function getParameterElement(parameter) { switch (parameter.type) { case ParameterType.checkbox: @@ -243,29 +254,74 @@ function getParameterElement(parameter) { case ParameterType.custom: return parameter.render(parameter) default: - console.error(`Invalid type for parameter ${parameter.id}`); + console.error(`Invalid type ${parameter.type} for parameter ${parameter.id}`); return "ERROR: Invalid Type" } } let parametersTable = document.querySelector("#system-settings .parameters-table") -/* fill in the system settings popup table */ -function initParameters() { - PARAMETERS.forEach(parameter => { - var element = getParameterElement(parameter) - var note = parameter.note ? `${parameter.note}` : ""; - var icon = parameter.icon ? `` : ""; - var newrow = document.createElement('div') - newrow.innerHTML = ` -
${icon}
-
${note}
-
${element}
` +/** + * fill in the system settings popup table + * @param {Array | undefined} parameters + * */ +function initParameters(parameters) { + parameters.forEach(parameter => { + const element = getParameterElement(parameter) + const elementWrapper = createElement('div') + if (element instanceof Node) { + elementWrapper.appendChild(element) + } else { + elementWrapper.innerHTML = element + } + + const note = typeof parameter.note === 'function' ? parameter.note(parameter) : parameter.note + const noteElements = [] + if (note) { + const noteElement = createElement('small') + if (note instanceof Node) { + noteElement.appendChild(note) + } else { + noteElement.innerHTML = note || '' + } + noteElements.push(noteElement) + } + + const icon = parameter.icon ? [createElement('i', undefined, ['fa', parameter.icon])] : [] + + const label = typeof parameter.label === 'function' ? parameter.label(parameter) : parameter.label + const labelElement = createElement('label', { for: parameter.id }) + if (label instanceof Node) { + labelElement.appendChild(label) + } else { + labelElement.innerHTML = label + } + + const newrow = createElement( + 'div', + { 'data-setting-id': parameter.id, 'data-save-in-app-config': parameter.saveInAppConfig }, + undefined, + [ + createElement('div', undefined, undefined, icon), + createElement('div', undefined, undefined, [labelElement, ...noteElements]), + elementWrapper, + ] + ) parametersTable.appendChild(newrow) parameter.settingsEntry = newrow }) } -initParameters() +initParameters(PARAMETERS) + +// listen to parameters from plugins +PARAMETERS.addEventListener('push', (...items) => { + initParameters(items) + + if (items.find(item => item.saveInAppConfig)) { + console.log('Reloading app config for new parameters', items.map(p => p.id)) + getAppConfig() + } +}) let vramUsageLevelField = document.querySelector('#vram_usage_level') let useCPUField = document.querySelector('#use_cpu') @@ -324,9 +380,44 @@ async function getAppConfig() { document.querySelector("#lora_model_container").style.display = (testDiffusers.checked ? '' : 'none') } + Array.from(parametersTable.children).forEach(parameterRow => { + if (parameterRow.dataset.settingId in config && parameterRow.dataset.saveInAppConfig === 'true') { + const configValue = config[parameterRow.dataset.settingId] + const parameterElement = document.getElementById(parameterRow.dataset.settingId) || + parameterRow.querySelector('input') || parameterRow.querySelector('select') + + switch (parameterElement?.tagName) { + case 'INPUT': + if (parameterElement.type === 'checkbox') { + parameterElement.checked = configValue + } else { + parameterElement.value = configValue + } + parameterElement.dispatchEvent(new Event('change')) + break + case 'SELECT': + if (Array.isArray(configValue)) { + Array.from(parameterElement.options).forEach(option => { + if (configValue.includes(option.value || option.text)) { + option.selected = true + } + }) + } else { + parameterElement.value = configValue + } + parameterElement.dispatchEvent(new Event('change')) + break + } + } + }) + console.log('get config status response', config) + + return config } catch (e) { console.log('get config status error', e) + + return {} } } @@ -486,16 +577,43 @@ saveSettingsBtn.addEventListener('click', function() { alert('The network port must be a number from 1 to 65535') return } - let updateBranch = (useBetaChannelField.checked ? 'beta' : 'main') - changeAppConfig({ + const updateBranch = (useBetaChannelField.checked ? 'beta' : 'main') + + const updateAppConfigRequest = { 'render_devices': getCurrentRenderDeviceSelection(), 'update_branch': updateBranch, - 'ui_open_browser_on_start': uiOpenBrowserOnStartField.checked, - 'listen_to_network': listenToNetworkField.checked, - 'listen_port': listenPortField.value, - 'test_diffusers': testDiffusers.checked - }) - saveSettingsBtn.classList.add('active') - asyncDelay(300).then(() => saveSettingsBtn.classList.remove('active')) -}) + } + Array.from(parametersTable.children).forEach(parameterRow => { + if (parameterRow.dataset.saveInAppConfig === 'true') { + const parameterElement = document.getElementById(parameterRow.dataset.settingId) || + parameterRow.querySelector('input') || parameterRow.querySelector('select') + + switch (parameterElement?.tagName) { + case 'INPUT': + if (parameterElement.type === 'checkbox') { + updateAppConfigRequest[parameterRow.dataset.settingId] = parameterElement.checked + } else { + updateAppConfigRequest[parameterRow.dataset.settingId] = parameterElement.value + } + break + case 'SELECT': + if (parameterElement.multiple) { + updateAppConfigRequest[parameterRow.dataset.settingId] = Array.from(parameterElement.options) + .filter(option => option.selected) + .map(option => option.value || option.text) + } else { + updateAppConfigRequest[parameterRow.dataset.settingId] = parameterElement.value + } + break + default: + console.error(`Setting parameter ${parameterRow.dataset.settingId} couldn't be saved to app.config - element #${parameter.id} is a <${parameterElement?.tagName} /> instead of a or a