Formatting

This commit is contained in:
cmdr2 2023-04-28 16:38:55 +05:30
parent 07f52c38ef
commit d18cefc519
8 changed files with 213 additions and 105 deletions

View File

@ -1,17 +1,15 @@
import json
import logging
import os import os
import socket import socket
import sys import sys
import json
import traceback import traceback
import logging
import shlex
import urllib import urllib
from rich.logging import RichHandler
from sdkit.utils import log as sdkit_log # hack, so we can overwrite the log config
from easydiffusion import task_manager from easydiffusion import task_manager
from easydiffusion.utils import log from easydiffusion.utils import log
from rich.logging import RichHandler
from sdkit.utils import log as sdkit_log # hack, so we can overwrite the log config
# Remove all handlers associated with the root logger object. # Remove all handlers associated with the root logger object.
for handler in logging.root.handlers[:]: for handler in logging.root.handlers[:]:
@ -55,10 +53,34 @@ APP_CONFIG_DEFAULTS = {
}, },
} }
IMAGE_EXTENSIONS = [".png", ".apng", ".jpg", ".jpeg", ".jfif", ".pjpeg", ".pjp", ".jxl", ".gif", ".webp", ".avif", ".svg"] IMAGE_EXTENSIONS = [
".png",
".apng",
".jpg",
".jpeg",
".jfif",
".pjpeg",
".pjp",
".jxl",
".gif",
".webp",
".avif",
".svg",
]
CUSTOM_MODIFIERS_DIR = os.path.abspath(os.path.join(SD_DIR, "..", "modifiers")) CUSTOM_MODIFIERS_DIR = os.path.abspath(os.path.join(SD_DIR, "..", "modifiers"))
CUSTOM_MODIFIERS_PORTRAIT_EXTENSIONS=[".portrait", "_portrait", " portrait", "-portrait"] CUSTOM_MODIFIERS_PORTRAIT_EXTENSIONS = [
CUSTOM_MODIFIERS_LANDSCAPE_EXTENSIONS=[".landscape", "_landscape", " landscape", "-landscape"] ".portrait",
"_portrait",
" portrait",
"-portrait",
]
CUSTOM_MODIFIERS_LANDSCAPE_EXTENSIONS = [
".landscape",
"_landscape",
" landscape",
"-landscape",
]
def init(): def init():
os.makedirs(USER_UI_PLUGINS_DIR, exist_ok=True) os.makedirs(USER_UI_PLUGINS_DIR, exist_ok=True)
@ -84,7 +106,7 @@ def getConfig(default_val=APP_CONFIG_DEFAULTS):
if os.getenv("SD_UI_BIND_IP") is not None: if os.getenv("SD_UI_BIND_IP") is not None:
config["net"]["listen_to_network"] = os.getenv("SD_UI_BIND_IP") == "0.0.0.0" config["net"]["listen_to_network"] = os.getenv("SD_UI_BIND_IP") == "0.0.0.0"
return config return config
except Exception as e: except Exception:
log.warn(traceback.format_exc()) log.warn(traceback.format_exc())
return default_val return default_val
@ -97,6 +119,7 @@ def setConfig(config):
except: except:
log.error(traceback.format_exc()) log.error(traceback.format_exc())
def save_to_config(ckpt_model_name, vae_model_name, hypernetwork_model_name, vram_usage_level): def save_to_config(ckpt_model_name, vae_model_name, hypernetwork_model_name, vram_usage_level):
config = getConfig() config = getConfig()
if "model" not in config: if "model" not in config:
@ -191,11 +214,12 @@ def open_browser():
webbrowser.open(f"http://localhost:{port}") webbrowser.open(f"http://localhost:{port}")
def get_image_modifiers(): def get_image_modifiers():
modifiers_json_path = os.path.join(SD_UI_DIR, "modifiers.json") modifiers_json_path = os.path.join(SD_UI_DIR, "modifiers.json")
modifier_categories = {} modifier_categories = {}
original_category_order=[] original_category_order = []
with open(modifiers_json_path, "r", encoding="utf-8") as f: with open(modifiers_json_path, "r", encoding="utf-8") as f:
modifiers_file = json.load(f) modifiers_file = json.load(f)
@ -205,14 +229,14 @@ def get_image_modifiers():
# convert modifiers from a list of objects to a dict of dicts # convert modifiers from a list of objects to a dict of dicts
for category_item in modifiers_file: for category_item in modifiers_file:
category_name = category_item['category'] category_name = category_item["category"]
original_category_order.append(category_name) original_category_order.append(category_name)
category = {} category = {}
for modifier_item in category_item['modifiers']: for modifier_item in category_item["modifiers"]:
modifier = {} modifier = {}
for preview_item in modifier_item['previews']: for preview_item in modifier_item["previews"]:
modifier[preview_item['name']] = preview_item['path'] modifier[preview_item["name"]] = preview_item["path"]
category[modifier_item['modifier']] = modifier category[modifier_item["modifier"]] = modifier
modifier_categories[category_name] = category modifier_categories[category_name] = category
def scan_directory(directory_path: str, category_name="Modifiers"): def scan_directory(directory_path: str, category_name="Modifiers"):
@ -225,12 +249,27 @@ def get_image_modifiers():
modifier_name = entry.name[: -len(file_extension[0])] modifier_name = entry.name[: -len(file_extension[0])]
modifier_path = f"custom/{entry.path[len(CUSTOM_MODIFIERS_DIR) + 1:]}" modifier_path = f"custom/{entry.path[len(CUSTOM_MODIFIERS_DIR) + 1:]}"
# URL encode path segments # URL encode path segments
modifier_path = "/".join(map(lambda segment: urllib.parse.quote(segment), modifier_path.split("/"))) modifier_path = "/".join(
map(
lambda segment: urllib.parse.quote(segment),
modifier_path.split("/"),
)
)
is_portrait = True is_portrait = True
is_landscape = True is_landscape = True
portrait_extension = list(filter(lambda e: modifier_name.lower().endswith(e), CUSTOM_MODIFIERS_PORTRAIT_EXTENSIONS)) portrait_extension = list(
landscape_extension = list(filter(lambda e: modifier_name.lower().endswith(e), CUSTOM_MODIFIERS_LANDSCAPE_EXTENSIONS)) filter(
lambda e: modifier_name.lower().endswith(e),
CUSTOM_MODIFIERS_PORTRAIT_EXTENSIONS,
)
)
landscape_extension = list(
filter(
lambda e: modifier_name.lower().endswith(e),
CUSTOM_MODIFIERS_LANDSCAPE_EXTENSIONS,
)
)
if len(portrait_extension) > 0: if len(portrait_extension) > 0:
is_landscape = False is_landscape = False
@ -238,24 +277,24 @@ def get_image_modifiers():
elif len(landscape_extension) > 0: elif len(landscape_extension) > 0:
is_portrait = False is_portrait = False
modifier_name = modifier_name[: -len(landscape_extension[0])] modifier_name = modifier_name[: -len(landscape_extension[0])]
if (category_name not in modifier_categories): if category_name not in modifier_categories:
modifier_categories[category_name] = {} modifier_categories[category_name] = {}
category = modifier_categories[category_name] category = modifier_categories[category_name]
if (modifier_name not in category): if modifier_name not in category:
category[modifier_name] = {} category[modifier_name] = {}
if (is_portrait or "portrait" not in category[modifier_name]): if is_portrait or "portrait" not in category[modifier_name]:
category[modifier_name]["portrait"] = modifier_path category[modifier_name]["portrait"] = modifier_path
if (is_landscape or "landscape" not in category[modifier_name]): if is_landscape or "landscape" not in category[modifier_name]:
category[modifier_name]["landscape"] = modifier_path category[modifier_name]["landscape"] = modifier_path
elif entry.is_dir(): elif entry.is_dir():
scan_directory( scan_directory(
entry.path, entry.path,
entry.name if directory_path==CUSTOM_MODIFIERS_DIR else f"{category_name}/{entry.name}", entry.name if directory_path == CUSTOM_MODIFIERS_DIR else f"{category_name}/{entry.name}",
) )
scan_directory(CUSTOM_MODIFIERS_DIR) scan_directory(CUSTOM_MODIFIERS_DIR)
@ -268,12 +307,12 @@ def get_image_modifiers():
# convert the modifiers back into a list of objects # convert the modifiers back into a list of objects
modifier_categories_list = [] modifier_categories_list = []
for category_name in [*original_category_order, *custom_categories]: for category_name in [*original_category_order, *custom_categories]:
category = { 'category': category_name, 'modifiers': [] } category = {"category": category_name, "modifiers": []}
for modifier_name in sorted(modifier_categories[category_name].keys(), key=str.casefold): for modifier_name in sorted(modifier_categories[category_name].keys(), key=str.casefold):
modifier = { 'modifier': modifier_name, 'previews': [] } modifier = {"modifier": modifier_name, "previews": []}
for preview_name, preview_path in modifier_categories[category_name][modifier_name].items(): for preview_name, preview_path in modifier_categories[category_name][modifier_name].items():
modifier['previews'].append({ 'name': preview_name, 'path': preview_path }) modifier["previews"].append({"name": preview_name, "path": preview_path})
category['modifiers'].append(modifier) category["modifiers"].append(modifier)
modifier_categories_list.append(category) modifier_categories_list.append(category)
return modifier_categories_list return modifier_categories_list

View File

@ -1,9 +1,9 @@
import os import os
import platform import platform
import torch
import traceback
import re import re
import traceback
import torch
from easydiffusion.utils import log from easydiffusion.utils import log
""" """
@ -98,8 +98,8 @@ def auto_pick_devices(currently_active_devices):
continue continue
mem_free, mem_total = torch.cuda.mem_get_info(device) mem_free, mem_total = torch.cuda.mem_get_info(device)
mem_free /= float(10**9) mem_free /= float(10 ** 9)
mem_total /= float(10**9) mem_total /= float(10 ** 9)
device_name = torch.cuda.get_device_name(device) device_name = torch.cuda.get_device_name(device)
log.debug( log.debug(
f"{device} detected: {device_name} - Memory (free/total): {round(mem_free, 2)}Gb / {round(mem_total, 2)}Gb" f"{device} detected: {device_name} - Memory (free/total): {round(mem_free, 2)}Gb / {round(mem_total, 2)}Gb"
@ -118,7 +118,10 @@ def auto_pick_devices(currently_active_devices):
# These already-running devices probably aren't terrible, since they were picked in the past. # These already-running devices probably aren't terrible, since they were picked in the past.
# Worst case, the user can restart the program and that'll get rid of them. # Worst case, the user can restart the program and that'll get rid of them.
devices = list( devices = list(
filter((lambda x: x["mem_free"] > mem_free_threshold or x["device"] in currently_active_devices), devices) filter(
(lambda x: x["mem_free"] > mem_free_threshold or x["device"] in currently_active_devices),
devices,
)
) )
devices = list(map(lambda x: x["device"], devices)) devices = list(map(lambda x: x["device"], devices))
return devices return devices
@ -178,7 +181,7 @@ def get_max_vram_usage_level(device):
else: else:
return "high" return "high"
mem_total /= float(10**9) mem_total /= float(10 ** 9)
if mem_total < 4.5: if mem_total < 4.5:
return "low" return "low"
elif mem_total < 6.5: elif mem_total < 6.5:
@ -220,7 +223,7 @@ def is_device_compatible(device):
# Memory check # Memory check
try: try:
_, mem_total = torch.cuda.mem_get_info(device) _, mem_total = torch.cuda.mem_get_info(device)
mem_total /= float(10**9) mem_total /= float(10 ** 9)
if mem_total < 3.0: if mem_total < 3.0:
if is_device_compatible.history.get(device) == None: if is_device_compatible.history.get(device) == None:
log.warn(f"GPU {device} with less than 3 GB of VRAM is not compatible with Stable Diffusion") log.warn(f"GPU {device} with less than 3 GB of VRAM is not compatible with Stable Diffusion")

View File

@ -3,11 +3,17 @@ import os
from easydiffusion import app from easydiffusion import app
from easydiffusion.types import TaskData from easydiffusion.types import TaskData
from easydiffusion.utils import log from easydiffusion.utils import log
from sdkit import Context from sdkit import Context
from sdkit.models import load_model, unload_model, scan_model from sdkit.models import load_model, scan_model, unload_model
KNOWN_MODEL_TYPES = ["stable-diffusion", "vae", "hypernetwork", "gfpgan", "realesrgan", "lora"] KNOWN_MODEL_TYPES = [
"stable-diffusion",
"vae",
"hypernetwork",
"gfpgan",
"realesrgan",
"lora",
]
MODEL_EXTENSIONS = { MODEL_EXTENSIONS = {
"stable-diffusion": [".ckpt", ".safetensors"], "stable-diffusion": [".ckpt", ".safetensors"],
"vae": [".vae.pt", ".ckpt", ".safetensors"], "vae": [".vae.pt", ".ckpt", ".safetensors"],
@ -44,13 +50,15 @@ def load_default_models(context: Context):
load_model( load_model(
context, context,
model_type, model_type,
scan_model = context.model_paths[model_type] != None and not context.model_paths[model_type].endswith('.safetensors') scan_model=context.model_paths[model_type] != None
and not context.model_paths[model_type].endswith(".safetensors"),
) )
except Exception as e: except Exception as e:
log.error(f"[red]Error while loading {model_type} model: {context.model_paths[model_type]}[/red]") log.error(f"[red]Error while loading {model_type} model: {context.model_paths[model_type]}[/red]")
log.exception(e) log.exception(e)
del context.model_paths[model_type] del context.model_paths[model_type]
def unload_all(context: Context): def unload_all(context: Context):
for model_type in KNOWN_MODEL_TYPES: for model_type in KNOWN_MODEL_TYPES:
unload_model(context, model_type) unload_model(context, model_type)
@ -170,13 +178,23 @@ def is_malicious_model(file_path):
if scan_result.issues_count > 0 or scan_result.infected_files > 0: if scan_result.issues_count > 0 or scan_result.infected_files > 0:
log.warn( log.warn(
":warning: [bold red]Scan %s: %d scanned, %d issue, %d infected.[/bold red]" ":warning: [bold red]Scan %s: %d scanned, %d issue, %d infected.[/bold red]"
% (file_path, scan_result.scanned_files, scan_result.issues_count, scan_result.infected_files) % (
file_path,
scan_result.scanned_files,
scan_result.issues_count,
scan_result.infected_files,
)
) )
return True return True
else: else:
log.debug( log.debug(
"Scan %s: [green]%d scanned, %d issue, %d infected.[/green]" "Scan %s: [green]%d scanned, %d issue, %d infected.[/green]"
% (file_path, scan_result.scanned_files, scan_result.issues_count, scan_result.infected_files) % (
file_path,
scan_result.scanned_files,
scan_result.issues_count,
scan_result.infected_files,
)
) )
return False return False
except Exception as e: except Exception as e:
@ -204,13 +222,13 @@ def getModels():
class MaliciousModelException(Exception): class MaliciousModelException(Exception):
"Raised when picklescan reports a problem with a model" "Raised when picklescan reports a problem with a model"
pass
def scan_directory(directory, suffixes, directoriesFirst: bool = True): def scan_directory(directory, suffixes, directoriesFirst: bool = True):
nonlocal models_scanned nonlocal models_scanned
tree = [] tree = []
for entry in sorted( for entry in sorted(
os.scandir(directory), key=lambda entry: (entry.is_file() == directoriesFirst, entry.name.lower()) os.scandir(directory),
key=lambda entry: (entry.is_file() == directoriesFirst, entry.name.lower()),
): ):
if entry.is_file(): if entry.is_file():
matching_suffix = list(filter(lambda s: entry.name.endswith(s), suffixes)) matching_suffix = list(filter(lambda s: entry.name.endswith(s), suffixes))

View File

@ -1,21 +1,22 @@
import queue
import time
import json import json
import pprint import pprint
import queue
import time
from easydiffusion import device_manager from easydiffusion import device_manager
from easydiffusion.types import TaskData, Response, Image as ResponseImage, UserInitiatedStop, GenerateImageRequest from easydiffusion.types import GenerateImageRequest
from easydiffusion.utils import get_printable_request, save_images_to_disk, log from easydiffusion.types import Image as ResponseImage
from easydiffusion.types import Response, TaskData, UserInitiatedStop
from easydiffusion.utils import get_printable_request, log, save_images_to_disk
from sdkit import Context from sdkit import Context
from sdkit.generate import generate_images
from sdkit.filter import apply_filters from sdkit.filter import apply_filters
from sdkit.generate import generate_images
from sdkit.utils import ( from sdkit.utils import (
img_to_buffer,
img_to_base64_str,
latent_samples_to_images,
diffusers_latent_samples_to_images, diffusers_latent_samples_to_images,
gc, gc,
img_to_base64_str,
img_to_buffer,
latent_samples_to_images,
) )
context = Context() # thread-local context = Context() # thread-local
@ -43,14 +44,22 @@ def init(device):
def make_images( def make_images(
req: GenerateImageRequest, task_data: TaskData, data_queue: queue.Queue, task_temp_images: list, step_callback req: GenerateImageRequest,
task_data: TaskData,
data_queue: queue.Queue,
task_temp_images: list,
step_callback,
): ):
context.stop_processing = False context.stop_processing = False
print_task_info(req, task_data) print_task_info(req, task_data)
images, seeds = make_images_internal(req, task_data, data_queue, task_temp_images, step_callback) images, seeds = make_images_internal(req, task_data, data_queue, task_temp_images, step_callback)
res = Response(req, task_data, images=construct_response(images, seeds, task_data, base_seed=req.seed)) res = Response(
req,
task_data,
images=construct_response(images, seeds, task_data, base_seed=req.seed),
)
res = res.json() res = res.json()
data_queue.put(json.dumps(res)) data_queue.put(json.dumps(res))
log.info("Task completed") log.info("Task completed")
@ -66,7 +75,11 @@ def print_task_info(req: GenerateImageRequest, task_data: TaskData):
def make_images_internal( def make_images_internal(
req: GenerateImageRequest, task_data: TaskData, data_queue: queue.Queue, task_temp_images: list, step_callback req: GenerateImageRequest,
task_data: TaskData,
data_queue: queue.Queue,
task_temp_images: list,
step_callback,
): ):
images, user_stopped = generate_images_internal( images, user_stopped = generate_images_internal(
req, req,
@ -155,7 +168,12 @@ def filter_images(task_data: TaskData, images: list, user_stopped):
def construct_response(images: list, seeds: list, task_data: TaskData, base_seed: int): def construct_response(images: list, seeds: list, task_data: TaskData, base_seed: int):
return [ return [
ResponseImage( ResponseImage(
data=img_to_base64_str(img, task_data.output_format, task_data.output_quality, task_data.output_lossless), data=img_to_base64_str(
img,
task_data.output_format,
task_data.output_quality,
task_data.output_lossless,
),
seed=seed, seed=seed,
) )
for img, seed in zip(images, seeds) for img, seed in zip(images, seeds)

View File

@ -2,28 +2,30 @@
Notes: Notes:
async endpoints always run on the main thread. Without they run on the thread pool. async endpoints always run on the main thread. Without they run on the thread pool.
""" """
import datetime
import mimetypes
import os import os
import traceback import traceback
import datetime
from typing import List, Union from typing import List, Union
from easydiffusion import app, model_manager, task_manager
from easydiffusion.types import GenerateImageRequest, MergeRequest, TaskData
from easydiffusion.utils import log
from fastapi import FastAPI, HTTPException from fastapi import FastAPI, HTTPException
from fastapi.staticfiles import StaticFiles from fastapi.staticfiles import StaticFiles
from starlette.responses import FileResponse, JSONResponse, StreamingResponse
from pydantic import BaseModel, Extra from pydantic import BaseModel, Extra
from starlette.responses import FileResponse, JSONResponse, StreamingResponse
from easydiffusion import app, model_manager, task_manager
from easydiffusion.types import TaskData, GenerateImageRequest, MergeRequest
from easydiffusion.utils import log
import mimetypes
log.info(f"started in {app.SD_DIR}") log.info(f"started in {app.SD_DIR}")
log.info(f"started at {datetime.datetime.now():%x %X}") log.info(f"started at {datetime.datetime.now():%x %X}")
server_api = FastAPI() server_api = FastAPI()
NOCACHE_HEADERS = {"Cache-Control": "no-cache, no-store, must-revalidate", "Pragma": "no-cache", "Expires": "0"} NOCACHE_HEADERS = {
"Cache-Control": "no-cache, no-store, must-revalidate",
"Pragma": "no-cache",
"Expires": "0",
}
class NoCacheStaticFiles(StaticFiles): class NoCacheStaticFiles(StaticFiles):
@ -65,11 +67,17 @@ def init():
name="custom-thumbnails", name="custom-thumbnails",
) )
server_api.mount("/media", NoCacheStaticFiles(directory=os.path.join(app.SD_UI_DIR, "media")), name="media") server_api.mount(
"/media",
NoCacheStaticFiles(directory=os.path.join(app.SD_UI_DIR, "media")),
name="media",
)
for plugins_dir, dir_prefix in app.UI_PLUGINS_SOURCES: for plugins_dir, dir_prefix in app.UI_PLUGINS_SOURCES:
server_api.mount( server_api.mount(
f"/plugins/{dir_prefix}", NoCacheStaticFiles(directory=plugins_dir), name=f"plugins-{dir_prefix}" f"/plugins/{dir_prefix}",
NoCacheStaticFiles(directory=plugins_dir),
name=f"plugins-{dir_prefix}",
) )
@server_api.post("/app_config") @server_api.post("/app_config")
@ -246,8 +254,8 @@ def render_internal(req: dict):
def model_merge_internal(req: dict): def model_merge_internal(req: dict):
try: try:
from sdkit.train import merge_models
from easydiffusion.utils.save_utils import filename_regex from easydiffusion.utils.save_utils import filename_regex
from sdkit.train import merge_models
mergeReq: MergeRequest = MergeRequest.parse_obj(req) mergeReq: MergeRequest = MergeRequest.parse_obj(req)
@ -255,7 +263,11 @@ def model_merge_internal(req: dict):
model_manager.resolve_model_to_use(mergeReq.model0, "stable-diffusion"), model_manager.resolve_model_to_use(mergeReq.model0, "stable-diffusion"),
model_manager.resolve_model_to_use(mergeReq.model1, "stable-diffusion"), model_manager.resolve_model_to_use(mergeReq.model1, "stable-diffusion"),
mergeReq.ratio, mergeReq.ratio,
os.path.join(app.MODELS_DIR, "stable-diffusion", filename_regex.sub("_", mergeReq.out_path)), os.path.join(
app.MODELS_DIR,
"stable-diffusion",
filename_regex.sub("_", mergeReq.out_path),
),
mergeReq.use_fp16, mergeReq.use_fp16,
) )
return JSONResponse({"status": "OK"}, headers=NOCACHE_HEADERS) return JSONResponse({"status": "OK"}, headers=NOCACHE_HEADERS)

View File

@ -9,14 +9,16 @@ import traceback
TASK_TTL = 15 * 60 # seconds, Discard last session's task timeout TASK_TTL = 15 * 60 # seconds, Discard last session's task timeout
import torch import queue
import queue, threading, time, weakref import threading
import time
import weakref
from typing import Any, Hashable from typing import Any, Hashable
import torch
from easydiffusion import device_manager from easydiffusion import device_manager
from easydiffusion.types import TaskData, GenerateImageRequest from easydiffusion.types import GenerateImageRequest, TaskData
from easydiffusion.utils import log from easydiffusion.utils import log
from sdkit.utils import gc from sdkit.utils import gc
THREAD_NAME_PREFIX = "" THREAD_NAME_PREFIX = ""
@ -167,7 +169,7 @@ class DataCache:
raise Exception("DataCache.put" + ERR_LOCK_FAILED) raise Exception("DataCache.put" + ERR_LOCK_FAILED)
try: try:
self._base[key] = (self._get_ttl_time(ttl), value) self._base[key] = (self._get_ttl_time(ttl), value)
except Exception as e: except Exception:
log.error(traceback.format_exc()) log.error(traceback.format_exc())
return False return False
else: else:
@ -264,7 +266,7 @@ def thread_get_next_task():
def thread_render(device): def thread_render(device):
global current_state, current_state_error global current_state, current_state_error
from easydiffusion import renderer, model_manager from easydiffusion import model_manager, renderer
try: try:
renderer.init(device) renderer.init(device)
@ -337,7 +339,11 @@ def thread_render(device):
current_state = ServerStates.Rendering current_state = ServerStates.Rendering
task.response = renderer.make_images( task.response = renderer.make_images(
task.render_request, task.task_data, task.buffer_queue, task.temp_images, step_callback task.render_request,
task.task_data,
task.buffer_queue,
task.temp_images,
step_callback,
) )
# Before looping back to the generator, mark cache as still alive. # Before looping back to the generator, mark cache as still alive.
task_cache.keep(id(task), TASK_TTL) task_cache.keep(id(task), TASK_TTL)
@ -392,8 +398,8 @@ def get_devices():
return {"name": device_manager.get_processor_name()} return {"name": device_manager.get_processor_name()}
mem_free, mem_total = torch.cuda.mem_get_info(device) mem_free, mem_total = torch.cuda.mem_get_info(device)
mem_free /= float(10**9) mem_free /= float(10 ** 9)
mem_total /= float(10**9) mem_total /= float(10 ** 9)
return { return {
"name": torch.cuda.get_device_name(device), "name": torch.cuda.get_device_name(device),

View File

@ -1,6 +1,7 @@
from pydantic import BaseModel
from typing import Any from typing import Any
from pydantic import BaseModel
class GenerateImageRequest(BaseModel): class GenerateImageRequest(BaseModel):
prompt: str = "" prompt: str = ""

View File

@ -1,14 +1,13 @@
import os import os
import time
import re import re
import time
from datetime import datetime
from functools import reduce
from easydiffusion import app from easydiffusion import app
from easydiffusion.types import TaskData, GenerateImageRequest from easydiffusion.types import GenerateImageRequest, TaskData
from functools import reduce
from datetime import datetime
from sdkit.utils import save_images, save_dicts
from numpy import base_repr from numpy import base_repr
from sdkit.utils import save_dicts, save_images
filename_regex = re.compile("[^a-zA-Z0-9._-]") filename_regex = re.compile("[^a-zA-Z0-9._-]")
img_number_regex = re.compile("([0-9]{5,})") img_number_regex = re.compile("([0-9]{5,})")
@ -50,6 +49,7 @@ other_placeholders = {
"$s": lambda req, task_data: str(req.seed), "$s": lambda req, task_data: str(req.seed),
} }
class ImageNumber: class ImageNumber:
_factory = None _factory = None
_evaluated = False _evaluated = False
@ -57,12 +57,14 @@ class ImageNumber:
def __init__(self, factory): def __init__(self, factory):
self._factory = factory self._factory = factory
self._evaluated = None self._evaluated = None
def __call__(self) -> int: def __call__(self) -> int:
if self._evaluated is None: if self._evaluated is None:
self._evaluated = self._factory() self._evaluated = self._factory()
return self._evaluated return self._evaluated
def format_placeholders(format: str, req: GenerateImageRequest, task_data: TaskData, now = None):
def format_placeholders(format: str, req: GenerateImageRequest, task_data: TaskData, now=None):
if now is None: if now is None:
now = time.time() now = time.time()
@ -75,10 +77,12 @@ def format_placeholders(format: str, req: GenerateImageRequest, task_data: TaskD
return format return format
def format_folder_name(format: str, req: GenerateImageRequest, task_data: TaskData): def format_folder_name(format: str, req: GenerateImageRequest, task_data: TaskData):
format = format_placeholders(format, req, task_data) format = format_placeholders(format, req, task_data)
return filename_regex.sub("_", format) return filename_regex.sub("_", format)
def format_file_name( def format_file_name(
format: str, format: str,
req: GenerateImageRequest, req: GenerateImageRequest,
@ -88,19 +92,22 @@ def format_file_name(
folder_img_number: ImageNumber, folder_img_number: ImageNumber,
): ):
format = format_placeholders(format, req, task_data, now) format = format_placeholders(format, req, task_data, now)
if "$n" in format: if "$n" in format:
format = format.replace("$n", f"{folder_img_number():05}") format = format.replace("$n", f"{folder_img_number():05}")
if "$tsb64" in format: if "$tsb64" in format:
img_id = base_repr(int(now * 10000), 36)[-7:] + base_repr(int(batch_file_number), 36) # Base 36 conversion, 0-9, A-Z img_id = base_repr(int(now * 10000), 36)[-7:] + base_repr(
int(batch_file_number), 36
) # Base 36 conversion, 0-9, A-Z
format = format.replace("$tsb64", img_id) format = format.replace("$tsb64", img_id)
if "$ts" in format: if "$ts" in format:
format = format.replace("$ts", str(int(now * 1000) + batch_file_number)) format = format.replace("$ts", str(int(now * 1000) + batch_file_number))
return filename_regex.sub("_", format) return filename_regex.sub("_", format)
def save_images_to_disk(images: list, filtered_images: list, req: GenerateImageRequest, task_data: TaskData): def save_images_to_disk(images: list, filtered_images: list, req: GenerateImageRequest, task_data: TaskData):
now = time.time() now = time.time()
app_config = app.getConfig() app_config = app.getConfig()
@ -126,7 +133,7 @@ def save_images_to_disk(images: list, filtered_images: list, req: GenerateImageR
output_lossless=task_data.output_lossless, output_lossless=task_data.output_lossless,
) )
if task_data.metadata_output_format: if task_data.metadata_output_format:
for metadata_output_format in task_data.metadata_output_format.split(','): for metadata_output_format in task_data.metadata_output_format.split(","):
if metadata_output_format.lower() in ["json", "txt", "embed"]: if metadata_output_format.lower() in ["json", "txt", "embed"]:
save_dicts( save_dicts(
metadata_entries, metadata_entries,
@ -142,7 +149,8 @@ def save_images_to_disk(images: list, filtered_images: list, req: GenerateImageR
task_data, task_data,
file_number, file_number,
now=now, now=now,
suffix="filtered") suffix="filtered",
)
save_images( save_images(
images, images,
@ -233,27 +241,28 @@ def make_filename_callback(
return make_filename return make_filename
def _calculate_img_number(save_dir_path: str, task_data: TaskData): def _calculate_img_number(save_dir_path: str, task_data: TaskData):
def get_highest_img_number(accumulator: int, file: os.DirEntry) -> int: def get_highest_img_number(accumulator: int, file: os.DirEntry) -> int:
if not file.is_file: if not file.is_file:
return accumulator return accumulator
if len(list(filter(lambda e: file.name.endswith(e), app.IMAGE_EXTENSIONS))) == 0: if len(list(filter(lambda e: file.name.endswith(e), app.IMAGE_EXTENSIONS))) == 0:
return accumulator return accumulator
get_highest_img_number.number_of_images = get_highest_img_number.number_of_images + 1 get_highest_img_number.number_of_images = get_highest_img_number.number_of_images + 1
number_match = img_number_regex.match(file.name) number_match = img_number_regex.match(file.name)
if not number_match: if not number_match:
return accumulator return accumulator
file_number = number_match.group().lstrip('0') file_number = number_match.group().lstrip("0")
# Handle 00000 # Handle 00000
return int(file_number) if file_number else 0 return int(file_number) if file_number else 0
get_highest_img_number.number_of_images = 0 get_highest_img_number.number_of_images = 0
highest_file_number = -1 highest_file_number = -1
if os.path.isdir(save_dir_path): if os.path.isdir(save_dir_path):
@ -267,13 +276,15 @@ def _calculate_img_number(save_dir_path: str, task_data: TaskData):
_calculate_img_number.session_img_numbers[task_data.session_id], _calculate_img_number.session_img_numbers[task_data.session_id],
calculated_img_number, calculated_img_number,
) )
calculated_img_number = calculated_img_number + 1 calculated_img_number = calculated_img_number + 1
_calculate_img_number.session_img_numbers[task_data.session_id] = calculated_img_number _calculate_img_number.session_img_numbers[task_data.session_id] = calculated_img_number
return calculated_img_number return calculated_img_number
_calculate_img_number.session_img_numbers = {} _calculate_img_number.session_img_numbers = {}
def calculate_img_number(save_dir_path: str, task_data: TaskData): def calculate_img_number(save_dir_path: str, task_data: TaskData):
return ImageNumber(lambda: _calculate_img_number(save_dir_path, task_data)) return ImageNumber(lambda: _calculate_img_number(save_dir_path, task_data))