Merge pull request #1087 from ogmaresca/custom-folder-filename-formats-2

Allow loading/saving app.config from plugins and support custom folder/filename formats from app.config
This commit is contained in:
cmdr2 2023-04-06 16:19:59 +05:30 committed by GitHub
commit 0778078350
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23
4 changed files with 309 additions and 43 deletions

View File

@ -10,7 +10,7 @@ from typing import List, Union
from fastapi import FastAPI, HTTPException
from fastapi.staticfiles import StaticFiles
from starlette.responses import FileResponse, JSONResponse, StreamingResponse
from pydantic import BaseModel
from pydantic import BaseModel, Extra
from easydiffusion import app, model_manager, task_manager
from easydiffusion.types import TaskData, GenerateImageRequest, MergeRequest
@ -44,7 +44,7 @@ class NoCacheStaticFiles(StaticFiles):
return super().is_not_modified(response_headers, request_headers)
class SetAppConfigRequest(BaseModel):
class SetAppConfigRequest(BaseModel, extra=Extra.allow):
update_branch: str = None
render_devices: Union[List[str], List[int], str, int] = None
model_vae: str = None
@ -136,6 +136,10 @@ def set_app_config_internal(req: SetAppConfigRequest):
config["test_diffusers"] = req.test_diffusers
for property, property_value in req.dict().items():
if property_value is not None and property not in req.__fields__:
config[property] = property_value
try:
app.setConfig(config)

View File

@ -1,13 +1,18 @@
import os
import time
import base64
import re
from easydiffusion import app
from easydiffusion.types import TaskData, GenerateImageRequest
from functools import reduce
from datetime import datetime
from sdkit.utils import save_images, save_dicts
from numpy import base_repr
filename_regex = re.compile("[^a-zA-Z0-9._-]")
img_number_regex = re.compile("([0-9]{5,})")
# keep in sync with `ui/media/js/dnd.js`
TASK_TEXT_MAPPING = {
@ -31,12 +36,86 @@ TASK_TEXT_MAPPING = {
"lora_alpha": "LoRA Strength",
}
time_placeholders = {
"$yyyy": "%Y",
"$MM": "%m",
"$dd": "%d",
"$HH": "%H",
"$mm": "%M",
"$ss": "%S",
}
other_placeholders = {
"$id": lambda req, task_data: filename_regex.sub("_", task_data.session_id),
"$p": lambda req, task_data: filename_regex.sub("_", req.prompt)[:50],
"$s": lambda req, task_data: str(req.seed),
}
class ImageNumber:
_factory = None
_evaluated = False
def __init__(self, factory):
self._factory = factory
self._evaluated = None
def __call__(self) -> int:
if self._evaluated is None:
self._evaluated = self._factory()
return self._evaluated
def format_placeholders(format: str, req: GenerateImageRequest, task_data: TaskData, now = None):
if now is None:
now = time.time()
for placeholder, time_format in time_placeholders.items():
if placeholder in format:
format = format.replace(placeholder, datetime.fromtimestamp(now).strftime(time_format))
for placeholder, replace_func in other_placeholders.items():
if placeholder in format:
format = format.replace(placeholder, replace_func(req, task_data))
return format
def format_folder_name(format: str, req: GenerateImageRequest, task_data: TaskData):
format = format_placeholders(format, req, task_data)
return filename_regex.sub("_", format)
def format_file_name(
format: str,
req: GenerateImageRequest,
task_data: TaskData,
now: float,
batch_file_number: int,
folder_img_number: ImageNumber,
):
format = format_placeholders(format, req, task_data, now)
if "$n" in format:
format = format.replace("$n", f"{folder_img_number():05}")
if "$tsb64" in format:
img_id = base_repr(int(now * 10000), 36)[-7:] + base_repr(int(batch_file_number), 36) # Base 36 conversion, 0-9, A-Z
format = format.replace("$tsb64", img_id)
if "$ts" in format:
format = format.replace("$ts", str(int(now * 1000) + batch_file_number))
return filename_regex.sub("_", format)
def save_images_to_disk(images: list, filtered_images: list, req: GenerateImageRequest, task_data: TaskData):
now = time.time()
save_dir_path = os.path.join(task_data.save_to_disk_path, filename_regex.sub("_", task_data.session_id))
app_config = app.getConfig()
folder_format = app_config.get("folder_format", "$id")
save_dir_path = os.path.join(task_data.save_to_disk_path, format_folder_name(folder_format, req, task_data))
metadata_entries = get_metadata_entries_for_request(req, task_data)
make_filename = make_filename_callback(req, now=now)
file_number = calculate_img_number(save_dir_path, task_data)
make_filename = make_filename_callback(
app_config.get("filename_format", "$p_$tsb64"),
req,
task_data,
file_number,
now=now,
)
if task_data.show_only_filtered_image or filtered_images is images:
save_images(
@ -58,7 +137,7 @@ def save_images_to_disk(images: list, filtered_images: list, req: GenerateImageR
file_format=task_data.output_format,
)
else:
make_filter_filename = make_filename_callback(req, now=now, suffix="filtered")
make_filter_filename = make_filename_callback(req, task_data, file_number, now=now, suffix="filtered")
save_images(
images,
@ -105,9 +184,6 @@ def get_metadata_entries_for_request(req: GenerateImageRequest, task_data: TaskD
if task_data.use_lora_model is None:
if "lora_alpha" in metadata:
del metadata["lora_alpha"]
from easydiffusion import app
app_config = app.getConfig()
if not app_config.get("test_diffusers", False) and "use_lora_model" in metadata:
del metadata["use_lora_model"]
@ -133,16 +209,66 @@ def get_printable_request(req: GenerateImageRequest):
return metadata
def make_filename_callback(req: GenerateImageRequest, suffix=None, now=None):
def make_filename_callback(
filename_format: str,
req: GenerateImageRequest,
task_data: TaskData,
folder_img_number: int,
suffix=None,
now=None,
):
if now is None:
now = time.time()
def make_filename(i):
img_id = base_repr(int(now * 10000), 36)[-7:] + base_repr(int(i),36) # Base 36 conversion, 0-9, A-Z
prompt_flattened = filename_regex.sub("_", req.prompt)[:50]
name = f"{prompt_flattened}_{img_id}"
name = format_file_name(filename_format, req, task_data, now, i, folder_img_number)
name = name if suffix is None else f"{name}_{suffix}"
return name
return make_filename
def _calculate_img_number(save_dir_path: str, task_data: TaskData):
def get_highest_img_number(accumulator: int, file: os.DirEntry) -> int:
if not file.is_file:
return accumulator
if len(list(filter(lambda e: file.name.endswith(e), app.IMAGE_EXTENSIONS))) == 0:
return accumulator
get_highest_img_number.number_of_images = get_highest_img_number.number_of_images + 1
number_match = img_number_regex.match(file.name)
if not number_match:
return accumulator
file_number = number_match.group().lstrip('0')
# Handle 00000
return int(file_number) if file_number else 0
get_highest_img_number.number_of_images = 0
highest_file_number = -1
if os.path.isdir(save_dir_path):
existing_files = list(os.scandir(save_dir_path))
highest_file_number = reduce(get_highest_img_number, existing_files, -1)
calculated_img_number = max(highest_file_number, get_highest_img_number.number_of_images - 1)
if task_data.session_id in _calculate_img_number.session_img_numbers:
calculated_img_number = max(
_calculate_img_number.session_img_numbers[task_data.session_id],
calculated_img_number,
)
calculated_img_number = calculated_img_number + 1
_calculate_img_number.session_img_numbers[task_data.session_id] = calculated_img_number
return calculated_img_number
_calculate_img_number.session_img_numbers = {}
def calculate_img_number(save_dir_path: str, task_data: TaskData):
return ImageNumber(lambda: _calculate_img_number(save_dir_path, task_data))

View File

@ -15,10 +15,13 @@
* JSDoc style
* @typedef {object} Parameter
* @property {string} id
* @property {ParameterType} type
* @property {string} label
* @property {?string} note
* @property {keyof ParameterType} type
* @property {string | (parameter: Parameter) => (HTMLElement | string)} label
* @property {string | (parameter: Parameter) => (HTMLElement | string) | undefined} note
* @property {(parameter: Parameter) => (HTMLElement | string) | undefined} render
* @property {string | undefined} icon
* @property {number|boolean|string} default
* @property {boolean?} saveInAppConfig
*/
@ -118,6 +121,7 @@ var PARAMETERS = [
note: "starts the default browser on startup",
icon: "fa-window-restore",
default: true,
saveInAppConfig: true,
},
{
id: "vram_usage_level",
@ -179,6 +183,7 @@ var PARAMETERS = [
note: "Other devices on your network can access this web page",
icon: "fa-network-wired",
default: true,
saveInAppConfig: true,
},
{
id: "listen_port",
@ -188,7 +193,8 @@ var PARAMETERS = [
icon: "fa-anchor",
render: (parameter) => {
return `<input id="${parameter.id}" name="${parameter.id}" size="6" value="9000" onkeypress="preventNonNumericalInput(event)">`
}
},
saveInAppConfig: true,
},
{
id: "use_beta_channel",
@ -205,6 +211,7 @@ var PARAMETERS = [
note: "<b>Experimental! Can have bugs!</b> Use upcoming features (like LoRA) in our new engine. Please press Save, then restart the program after changing this.",
icon: "fa-bolt",
default: false,
saveInAppConfig: true,
},
];
@ -228,6 +235,10 @@ function sliderUpdate(event) {
}
}
/**
* @param {Parameter} parameter
* @returns {string | HTMLElement}
*/
function getParameterElement(parameter) {
switch (parameter.type) {
case ParameterType.checkbox:
@ -243,29 +254,74 @@ function getParameterElement(parameter) {
case ParameterType.custom:
return parameter.render(parameter)
default:
console.error(`Invalid type for parameter ${parameter.id}`);
console.error(`Invalid type ${parameter.type} for parameter ${parameter.id}`);
return "ERROR: Invalid Type"
}
}
let parametersTable = document.querySelector("#system-settings .parameters-table")
/* fill in the system settings popup table */
function initParameters() {
PARAMETERS.forEach(parameter => {
var element = getParameterElement(parameter)
var note = parameter.note ? `<small>${parameter.note}</small>` : "";
var icon = parameter.icon ? `<i class="fa ${parameter.icon}"></i>` : "";
var newrow = document.createElement('div')
newrow.innerHTML = `
<div>${icon}</div>
<div><label for="${parameter.id}">${parameter.label}</label>${note}</div>
<div>${element}</div>`
/**
* fill in the system settings popup table
* @param {Array<Parameter> | undefined} parameters
* */
function initParameters(parameters) {
parameters.forEach(parameter => {
const element = getParameterElement(parameter)
const elementWrapper = createElement('div')
if (element instanceof Node) {
elementWrapper.appendChild(element)
} else {
elementWrapper.innerHTML = element
}
const note = typeof parameter.note === 'function' ? parameter.note(parameter) : parameter.note
const noteElements = []
if (note) {
const noteElement = createElement('small')
if (note instanceof Node) {
noteElement.appendChild(note)
} else {
noteElement.innerHTML = note || ''
}
noteElements.push(noteElement)
}
const icon = parameter.icon ? [createElement('i', undefined, ['fa', parameter.icon])] : []
const label = typeof parameter.label === 'function' ? parameter.label(parameter) : parameter.label
const labelElement = createElement('label', { for: parameter.id })
if (label instanceof Node) {
labelElement.appendChild(label)
} else {
labelElement.innerHTML = label
}
const newrow = createElement(
'div',
{ 'data-setting-id': parameter.id, 'data-save-in-app-config': parameter.saveInAppConfig },
undefined,
[
createElement('div', undefined, undefined, icon),
createElement('div', undefined, undefined, [labelElement, ...noteElements]),
elementWrapper,
]
)
parametersTable.appendChild(newrow)
parameter.settingsEntry = newrow
})
}
initParameters()
initParameters(PARAMETERS)
// listen to parameters from plugins
PARAMETERS.addEventListener('push', (...items) => {
initParameters(items)
if (items.find(item => item.saveInAppConfig)) {
console.log('Reloading app config for new parameters', items.map(p => p.id))
getAppConfig()
}
})
let vramUsageLevelField = document.querySelector('#vram_usage_level')
let useCPUField = document.querySelector('#use_cpu')
@ -330,9 +386,44 @@ async function getAppConfig() {
document.querySelector("#lora_alpha_container").style.display = (testDiffusers.checked && loraModelField.value !== "" ? '' : 'none')
}
Array.from(parametersTable.children).forEach(parameterRow => {
if (parameterRow.dataset.settingId in config && parameterRow.dataset.saveInAppConfig === 'true') {
const configValue = config[parameterRow.dataset.settingId]
const parameterElement = document.getElementById(parameterRow.dataset.settingId) ||
parameterRow.querySelector('input') || parameterRow.querySelector('select')
switch (parameterElement?.tagName) {
case 'INPUT':
if (parameterElement.type === 'checkbox') {
parameterElement.checked = configValue
} else {
parameterElement.value = configValue
}
parameterElement.dispatchEvent(new Event('change'))
break
case 'SELECT':
if (Array.isArray(configValue)) {
Array.from(parameterElement.options).forEach(option => {
if (configValue.includes(option.value || option.text)) {
option.selected = true
}
})
} else {
parameterElement.value = configValue
}
parameterElement.dispatchEvent(new Event('change'))
break
}
}
})
console.log('get config status response', config)
return config
} catch (e) {
console.log('get config status error', e)
return {}
}
}
@ -492,16 +583,43 @@ saveSettingsBtn.addEventListener('click', function() {
alert('The network port must be a number from 1 to 65535')
return
}
let updateBranch = (useBetaChannelField.checked ? 'beta' : 'main')
changeAppConfig({
const updateBranch = (useBetaChannelField.checked ? 'beta' : 'main')
const updateAppConfigRequest = {
'render_devices': getCurrentRenderDeviceSelection(),
'update_branch': updateBranch,
'ui_open_browser_on_start': uiOpenBrowserOnStartField.checked,
'listen_to_network': listenToNetworkField.checked,
'listen_port': listenPortField.value,
'test_diffusers': testDiffusers.checked
})
saveSettingsBtn.classList.add('active')
asyncDelay(300).then(() => saveSettingsBtn.classList.remove('active'))
})
}
Array.from(parametersTable.children).forEach(parameterRow => {
if (parameterRow.dataset.saveInAppConfig === 'true') {
const parameterElement = document.getElementById(parameterRow.dataset.settingId) ||
parameterRow.querySelector('input') || parameterRow.querySelector('select')
switch (parameterElement?.tagName) {
case 'INPUT':
if (parameterElement.type === 'checkbox') {
updateAppConfigRequest[parameterRow.dataset.settingId] = parameterElement.checked
} else {
updateAppConfigRequest[parameterRow.dataset.settingId] = parameterElement.value
}
break
case 'SELECT':
if (parameterElement.multiple) {
updateAppConfigRequest[parameterRow.dataset.settingId] = Array.from(parameterElement.options)
.filter(option => option.selected)
.map(option => option.value || option.text)
} else {
updateAppConfigRequest[parameterRow.dataset.settingId] = parameterElement.value
}
break
default:
console.error(`Setting parameter ${parameterRow.dataset.settingId} couldn't be saved to app.config - element #${parameter.id} is a <${parameterElement?.tagName} /> instead of a <input /> or a <select />!`)
break
}
}
})
const savePromise = changeAppConfig(updateAppConfigRequest)
saveSettingsBtn.classList.add('active')
Promise.all([savePromise, asyncDelay(300)]).then(() => saveSettingsBtn.classList.remove('active'))
})

View File

@ -683,14 +683,16 @@ class ServiceContainer {
* @param {string} tag
* @param {object} attributes
* @param {string | Array<string>} classes
* @param {string | HTMLElement | Array<string | HTMLElement>}
* @param {string | Node | Array<string | Node>}
* @returns {HTMLElement}
*/
function createElement(tagName, attributes, classes, textOrElements) {
const element = document.createElement(tagName)
if (attributes) {
Object.entries(attributes).forEach(([key, value]) => {
if (value !== undefined && value !== null) {
element.setAttribute(key, value)
}
});
}
if (classes) {
@ -699,7 +701,7 @@ function createElement(tagName, attributes, classes, textOrElements) {
if (textOrElements) {
const children = Array.isArray(textOrElements) ? textOrElements : [textOrElements]
children.forEach(textOrElem => {
if (textOrElem instanceof HTMLElement) {
if (textOrElem instanceof Node) {
element.appendChild(textOrElem)
} else {
element.appendChild(document.createTextNode(textOrElem))
@ -708,3 +710,19 @@ function createElement(tagName, attributes, classes, textOrElements) {
}
return element
}
/**
* Add a listener for arrays
* @param {keyof Array} method
* @param {(args) => {}} callback
*/
Array.prototype.addEventListener = function(method, callback) {
const originalFunction = this[method]
if (originalFunction) {
this[method] = function() {
console.log(`Array.${method}()`, arguments)
originalFunction.apply(this, arguments)
callback.apply(this, arguments)
}
}
}