mirror of
https://github.com/easydiffusion/easydiffusion.git
synced 2024-11-22 00:03:20 +01:00
Formatting
This commit is contained in:
parent
07f52c38ef
commit
d18cefc519
@ -1,17 +1,15 @@
|
||||
import json
|
||||
import logging
|
||||
import os
|
||||
import socket
|
||||
import sys
|
||||
import json
|
||||
import traceback
|
||||
import logging
|
||||
import shlex
|
||||
import urllib
|
||||
from rich.logging import RichHandler
|
||||
|
||||
from sdkit.utils import log as sdkit_log # hack, so we can overwrite the log config
|
||||
|
||||
from easydiffusion import task_manager
|
||||
from easydiffusion.utils import log
|
||||
from rich.logging import RichHandler
|
||||
from sdkit.utils import log as sdkit_log # hack, so we can overwrite the log config
|
||||
|
||||
# Remove all handlers associated with the root logger object.
|
||||
for handler in logging.root.handlers[:]:
|
||||
@ -55,10 +53,34 @@ APP_CONFIG_DEFAULTS = {
|
||||
},
|
||||
}
|
||||
|
||||
IMAGE_EXTENSIONS = [".png", ".apng", ".jpg", ".jpeg", ".jfif", ".pjpeg", ".pjp", ".jxl", ".gif", ".webp", ".avif", ".svg"]
|
||||
IMAGE_EXTENSIONS = [
|
||||
".png",
|
||||
".apng",
|
||||
".jpg",
|
||||
".jpeg",
|
||||
".jfif",
|
||||
".pjpeg",
|
||||
".pjp",
|
||||
".jxl",
|
||||
".gif",
|
||||
".webp",
|
||||
".avif",
|
||||
".svg",
|
||||
]
|
||||
CUSTOM_MODIFIERS_DIR = os.path.abspath(os.path.join(SD_DIR, "..", "modifiers"))
|
||||
CUSTOM_MODIFIERS_PORTRAIT_EXTENSIONS=[".portrait", "_portrait", " portrait", "-portrait"]
|
||||
CUSTOM_MODIFIERS_LANDSCAPE_EXTENSIONS=[".landscape", "_landscape", " landscape", "-landscape"]
|
||||
CUSTOM_MODIFIERS_PORTRAIT_EXTENSIONS = [
|
||||
".portrait",
|
||||
"_portrait",
|
||||
" portrait",
|
||||
"-portrait",
|
||||
]
|
||||
CUSTOM_MODIFIERS_LANDSCAPE_EXTENSIONS = [
|
||||
".landscape",
|
||||
"_landscape",
|
||||
" landscape",
|
||||
"-landscape",
|
||||
]
|
||||
|
||||
|
||||
def init():
|
||||
os.makedirs(USER_UI_PLUGINS_DIR, exist_ok=True)
|
||||
@ -84,7 +106,7 @@ def getConfig(default_val=APP_CONFIG_DEFAULTS):
|
||||
if os.getenv("SD_UI_BIND_IP") is not None:
|
||||
config["net"]["listen_to_network"] = os.getenv("SD_UI_BIND_IP") == "0.0.0.0"
|
||||
return config
|
||||
except Exception as e:
|
||||
except Exception:
|
||||
log.warn(traceback.format_exc())
|
||||
return default_val
|
||||
|
||||
@ -97,6 +119,7 @@ def setConfig(config):
|
||||
except:
|
||||
log.error(traceback.format_exc())
|
||||
|
||||
|
||||
def save_to_config(ckpt_model_name, vae_model_name, hypernetwork_model_name, vram_usage_level):
|
||||
config = getConfig()
|
||||
if "model" not in config:
|
||||
@ -191,11 +214,12 @@ def open_browser():
|
||||
|
||||
webbrowser.open(f"http://localhost:{port}")
|
||||
|
||||
|
||||
def get_image_modifiers():
|
||||
modifiers_json_path = os.path.join(SD_UI_DIR, "modifiers.json")
|
||||
|
||||
modifier_categories = {}
|
||||
original_category_order=[]
|
||||
original_category_order = []
|
||||
with open(modifiers_json_path, "r", encoding="utf-8") as f:
|
||||
modifiers_file = json.load(f)
|
||||
|
||||
@ -205,14 +229,14 @@ def get_image_modifiers():
|
||||
|
||||
# convert modifiers from a list of objects to a dict of dicts
|
||||
for category_item in modifiers_file:
|
||||
category_name = category_item['category']
|
||||
category_name = category_item["category"]
|
||||
original_category_order.append(category_name)
|
||||
category = {}
|
||||
for modifier_item in category_item['modifiers']:
|
||||
for modifier_item in category_item["modifiers"]:
|
||||
modifier = {}
|
||||
for preview_item in modifier_item['previews']:
|
||||
modifier[preview_item['name']] = preview_item['path']
|
||||
category[modifier_item['modifier']] = modifier
|
||||
for preview_item in modifier_item["previews"]:
|
||||
modifier[preview_item["name"]] = preview_item["path"]
|
||||
category[modifier_item["modifier"]] = modifier
|
||||
modifier_categories[category_name] = category
|
||||
|
||||
def scan_directory(directory_path: str, category_name="Modifiers"):
|
||||
@ -225,12 +249,27 @@ def get_image_modifiers():
|
||||
modifier_name = entry.name[: -len(file_extension[0])]
|
||||
modifier_path = f"custom/{entry.path[len(CUSTOM_MODIFIERS_DIR) + 1:]}"
|
||||
# URL encode path segments
|
||||
modifier_path = "/".join(map(lambda segment: urllib.parse.quote(segment), modifier_path.split("/")))
|
||||
modifier_path = "/".join(
|
||||
map(
|
||||
lambda segment: urllib.parse.quote(segment),
|
||||
modifier_path.split("/"),
|
||||
)
|
||||
)
|
||||
is_portrait = True
|
||||
is_landscape = True
|
||||
|
||||
portrait_extension = list(filter(lambda e: modifier_name.lower().endswith(e), CUSTOM_MODIFIERS_PORTRAIT_EXTENSIONS))
|
||||
landscape_extension = list(filter(lambda e: modifier_name.lower().endswith(e), CUSTOM_MODIFIERS_LANDSCAPE_EXTENSIONS))
|
||||
portrait_extension = list(
|
||||
filter(
|
||||
lambda e: modifier_name.lower().endswith(e),
|
||||
CUSTOM_MODIFIERS_PORTRAIT_EXTENSIONS,
|
||||
)
|
||||
)
|
||||
landscape_extension = list(
|
||||
filter(
|
||||
lambda e: modifier_name.lower().endswith(e),
|
||||
CUSTOM_MODIFIERS_LANDSCAPE_EXTENSIONS,
|
||||
)
|
||||
)
|
||||
|
||||
if len(portrait_extension) > 0:
|
||||
is_landscape = False
|
||||
@ -238,24 +277,24 @@ def get_image_modifiers():
|
||||
elif len(landscape_extension) > 0:
|
||||
is_portrait = False
|
||||
modifier_name = modifier_name[: -len(landscape_extension[0])]
|
||||
|
||||
if (category_name not in modifier_categories):
|
||||
|
||||
if category_name not in modifier_categories:
|
||||
modifier_categories[category_name] = {}
|
||||
|
||||
|
||||
category = modifier_categories[category_name]
|
||||
|
||||
if (modifier_name not in category):
|
||||
if modifier_name not in category:
|
||||
category[modifier_name] = {}
|
||||
|
||||
if (is_portrait or "portrait" not in category[modifier_name]):
|
||||
if is_portrait or "portrait" not in category[modifier_name]:
|
||||
category[modifier_name]["portrait"] = modifier_path
|
||||
|
||||
if (is_landscape or "landscape" not in category[modifier_name]):
|
||||
|
||||
if is_landscape or "landscape" not in category[modifier_name]:
|
||||
category[modifier_name]["landscape"] = modifier_path
|
||||
elif entry.is_dir():
|
||||
scan_directory(
|
||||
entry.path,
|
||||
entry.name if directory_path==CUSTOM_MODIFIERS_DIR else f"{category_name}/{entry.name}",
|
||||
entry.name if directory_path == CUSTOM_MODIFIERS_DIR else f"{category_name}/{entry.name}",
|
||||
)
|
||||
|
||||
scan_directory(CUSTOM_MODIFIERS_DIR)
|
||||
@ -268,12 +307,12 @@ def get_image_modifiers():
|
||||
# convert the modifiers back into a list of objects
|
||||
modifier_categories_list = []
|
||||
for category_name in [*original_category_order, *custom_categories]:
|
||||
category = { 'category': category_name, 'modifiers': [] }
|
||||
category = {"category": category_name, "modifiers": []}
|
||||
for modifier_name in sorted(modifier_categories[category_name].keys(), key=str.casefold):
|
||||
modifier = { 'modifier': modifier_name, 'previews': [] }
|
||||
modifier = {"modifier": modifier_name, "previews": []}
|
||||
for preview_name, preview_path in modifier_categories[category_name][modifier_name].items():
|
||||
modifier['previews'].append({ 'name': preview_name, 'path': preview_path })
|
||||
category['modifiers'].append(modifier)
|
||||
modifier["previews"].append({"name": preview_name, "path": preview_path})
|
||||
category["modifiers"].append(modifier)
|
||||
modifier_categories_list.append(category)
|
||||
|
||||
return modifier_categories_list
|
||||
|
@ -1,9 +1,9 @@
|
||||
import os
|
||||
import platform
|
||||
import torch
|
||||
import traceback
|
||||
import re
|
||||
import traceback
|
||||
|
||||
import torch
|
||||
from easydiffusion.utils import log
|
||||
|
||||
"""
|
||||
@ -98,8 +98,8 @@ def auto_pick_devices(currently_active_devices):
|
||||
continue
|
||||
|
||||
mem_free, mem_total = torch.cuda.mem_get_info(device)
|
||||
mem_free /= float(10**9)
|
||||
mem_total /= float(10**9)
|
||||
mem_free /= float(10 ** 9)
|
||||
mem_total /= float(10 ** 9)
|
||||
device_name = torch.cuda.get_device_name(device)
|
||||
log.debug(
|
||||
f"{device} detected: {device_name} - Memory (free/total): {round(mem_free, 2)}Gb / {round(mem_total, 2)}Gb"
|
||||
@ -118,7 +118,10 @@ def auto_pick_devices(currently_active_devices):
|
||||
# These already-running devices probably aren't terrible, since they were picked in the past.
|
||||
# Worst case, the user can restart the program and that'll get rid of them.
|
||||
devices = list(
|
||||
filter((lambda x: x["mem_free"] > mem_free_threshold or x["device"] in currently_active_devices), devices)
|
||||
filter(
|
||||
(lambda x: x["mem_free"] > mem_free_threshold or x["device"] in currently_active_devices),
|
||||
devices,
|
||||
)
|
||||
)
|
||||
devices = list(map(lambda x: x["device"], devices))
|
||||
return devices
|
||||
@ -178,7 +181,7 @@ def get_max_vram_usage_level(device):
|
||||
else:
|
||||
return "high"
|
||||
|
||||
mem_total /= float(10**9)
|
||||
mem_total /= float(10 ** 9)
|
||||
if mem_total < 4.5:
|
||||
return "low"
|
||||
elif mem_total < 6.5:
|
||||
@ -220,7 +223,7 @@ def is_device_compatible(device):
|
||||
# Memory check
|
||||
try:
|
||||
_, mem_total = torch.cuda.mem_get_info(device)
|
||||
mem_total /= float(10**9)
|
||||
mem_total /= float(10 ** 9)
|
||||
if mem_total < 3.0:
|
||||
if is_device_compatible.history.get(device) == None:
|
||||
log.warn(f"GPU {device} with less than 3 GB of VRAM is not compatible with Stable Diffusion")
|
||||
|
@ -3,11 +3,17 @@ import os
|
||||
from easydiffusion import app
|
||||
from easydiffusion.types import TaskData
|
||||
from easydiffusion.utils import log
|
||||
|
||||
from sdkit import Context
|
||||
from sdkit.models import load_model, unload_model, scan_model
|
||||
from sdkit.models import load_model, scan_model, unload_model
|
||||
|
||||
KNOWN_MODEL_TYPES = ["stable-diffusion", "vae", "hypernetwork", "gfpgan", "realesrgan", "lora"]
|
||||
KNOWN_MODEL_TYPES = [
|
||||
"stable-diffusion",
|
||||
"vae",
|
||||
"hypernetwork",
|
||||
"gfpgan",
|
||||
"realesrgan",
|
||||
"lora",
|
||||
]
|
||||
MODEL_EXTENSIONS = {
|
||||
"stable-diffusion": [".ckpt", ".safetensors"],
|
||||
"vae": [".vae.pt", ".ckpt", ".safetensors"],
|
||||
@ -44,13 +50,15 @@ def load_default_models(context: Context):
|
||||
load_model(
|
||||
context,
|
||||
model_type,
|
||||
scan_model = context.model_paths[model_type] != None and not context.model_paths[model_type].endswith('.safetensors')
|
||||
scan_model=context.model_paths[model_type] != None
|
||||
and not context.model_paths[model_type].endswith(".safetensors"),
|
||||
)
|
||||
except Exception as e:
|
||||
log.error(f"[red]Error while loading {model_type} model: {context.model_paths[model_type]}[/red]")
|
||||
log.exception(e)
|
||||
del context.model_paths[model_type]
|
||||
|
||||
|
||||
def unload_all(context: Context):
|
||||
for model_type in KNOWN_MODEL_TYPES:
|
||||
unload_model(context, model_type)
|
||||
@ -170,13 +178,23 @@ def is_malicious_model(file_path):
|
||||
if scan_result.issues_count > 0 or scan_result.infected_files > 0:
|
||||
log.warn(
|
||||
":warning: [bold red]Scan %s: %d scanned, %d issue, %d infected.[/bold red]"
|
||||
% (file_path, scan_result.scanned_files, scan_result.issues_count, scan_result.infected_files)
|
||||
% (
|
||||
file_path,
|
||||
scan_result.scanned_files,
|
||||
scan_result.issues_count,
|
||||
scan_result.infected_files,
|
||||
)
|
||||
)
|
||||
return True
|
||||
else:
|
||||
log.debug(
|
||||
"Scan %s: [green]%d scanned, %d issue, %d infected.[/green]"
|
||||
% (file_path, scan_result.scanned_files, scan_result.issues_count, scan_result.infected_files)
|
||||
% (
|
||||
file_path,
|
||||
scan_result.scanned_files,
|
||||
scan_result.issues_count,
|
||||
scan_result.infected_files,
|
||||
)
|
||||
)
|
||||
return False
|
||||
except Exception as e:
|
||||
@ -204,13 +222,13 @@ def getModels():
|
||||
|
||||
class MaliciousModelException(Exception):
|
||||
"Raised when picklescan reports a problem with a model"
|
||||
pass
|
||||
|
||||
def scan_directory(directory, suffixes, directoriesFirst: bool = True):
|
||||
nonlocal models_scanned
|
||||
tree = []
|
||||
for entry in sorted(
|
||||
os.scandir(directory), key=lambda entry: (entry.is_file() == directoriesFirst, entry.name.lower())
|
||||
os.scandir(directory),
|
||||
key=lambda entry: (entry.is_file() == directoriesFirst, entry.name.lower()),
|
||||
):
|
||||
if entry.is_file():
|
||||
matching_suffix = list(filter(lambda s: entry.name.endswith(s), suffixes))
|
||||
|
@ -1,21 +1,22 @@
|
||||
import queue
|
||||
import time
|
||||
import json
|
||||
import pprint
|
||||
import queue
|
||||
import time
|
||||
|
||||
from easydiffusion import device_manager
|
||||
from easydiffusion.types import TaskData, Response, Image as ResponseImage, UserInitiatedStop, GenerateImageRequest
|
||||
from easydiffusion.utils import get_printable_request, save_images_to_disk, log
|
||||
|
||||
from easydiffusion.types import GenerateImageRequest
|
||||
from easydiffusion.types import Image as ResponseImage
|
||||
from easydiffusion.types import Response, TaskData, UserInitiatedStop
|
||||
from easydiffusion.utils import get_printable_request, log, save_images_to_disk
|
||||
from sdkit import Context
|
||||
from sdkit.generate import generate_images
|
||||
from sdkit.filter import apply_filters
|
||||
from sdkit.generate import generate_images
|
||||
from sdkit.utils import (
|
||||
img_to_buffer,
|
||||
img_to_base64_str,
|
||||
latent_samples_to_images,
|
||||
diffusers_latent_samples_to_images,
|
||||
gc,
|
||||
img_to_base64_str,
|
||||
img_to_buffer,
|
||||
latent_samples_to_images,
|
||||
)
|
||||
|
||||
context = Context() # thread-local
|
||||
@ -43,14 +44,22 @@ def init(device):
|
||||
|
||||
|
||||
def make_images(
|
||||
req: GenerateImageRequest, task_data: TaskData, data_queue: queue.Queue, task_temp_images: list, step_callback
|
||||
req: GenerateImageRequest,
|
||||
task_data: TaskData,
|
||||
data_queue: queue.Queue,
|
||||
task_temp_images: list,
|
||||
step_callback,
|
||||
):
|
||||
context.stop_processing = False
|
||||
print_task_info(req, task_data)
|
||||
|
||||
images, seeds = make_images_internal(req, task_data, data_queue, task_temp_images, step_callback)
|
||||
|
||||
res = Response(req, task_data, images=construct_response(images, seeds, task_data, base_seed=req.seed))
|
||||
res = Response(
|
||||
req,
|
||||
task_data,
|
||||
images=construct_response(images, seeds, task_data, base_seed=req.seed),
|
||||
)
|
||||
res = res.json()
|
||||
data_queue.put(json.dumps(res))
|
||||
log.info("Task completed")
|
||||
@ -66,7 +75,11 @@ def print_task_info(req: GenerateImageRequest, task_data: TaskData):
|
||||
|
||||
|
||||
def make_images_internal(
|
||||
req: GenerateImageRequest, task_data: TaskData, data_queue: queue.Queue, task_temp_images: list, step_callback
|
||||
req: GenerateImageRequest,
|
||||
task_data: TaskData,
|
||||
data_queue: queue.Queue,
|
||||
task_temp_images: list,
|
||||
step_callback,
|
||||
):
|
||||
images, user_stopped = generate_images_internal(
|
||||
req,
|
||||
@ -155,7 +168,12 @@ def filter_images(task_data: TaskData, images: list, user_stopped):
|
||||
def construct_response(images: list, seeds: list, task_data: TaskData, base_seed: int):
|
||||
return [
|
||||
ResponseImage(
|
||||
data=img_to_base64_str(img, task_data.output_format, task_data.output_quality, task_data.output_lossless),
|
||||
data=img_to_base64_str(
|
||||
img,
|
||||
task_data.output_format,
|
||||
task_data.output_quality,
|
||||
task_data.output_lossless,
|
||||
),
|
||||
seed=seed,
|
||||
)
|
||||
for img, seed in zip(images, seeds)
|
||||
|
@ -2,28 +2,30 @@
|
||||
Notes:
|
||||
async endpoints always run on the main thread. Without they run on the thread pool.
|
||||
"""
|
||||
import datetime
|
||||
import mimetypes
|
||||
import os
|
||||
import traceback
|
||||
import datetime
|
||||
from typing import List, Union
|
||||
|
||||
from easydiffusion import app, model_manager, task_manager
|
||||
from easydiffusion.types import GenerateImageRequest, MergeRequest, TaskData
|
||||
from easydiffusion.utils import log
|
||||
from fastapi import FastAPI, HTTPException
|
||||
from fastapi.staticfiles import StaticFiles
|
||||
from starlette.responses import FileResponse, JSONResponse, StreamingResponse
|
||||
from pydantic import BaseModel, Extra
|
||||
|
||||
from easydiffusion import app, model_manager, task_manager
|
||||
from easydiffusion.types import TaskData, GenerateImageRequest, MergeRequest
|
||||
from easydiffusion.utils import log
|
||||
|
||||
import mimetypes
|
||||
from starlette.responses import FileResponse, JSONResponse, StreamingResponse
|
||||
|
||||
log.info(f"started in {app.SD_DIR}")
|
||||
log.info(f"started at {datetime.datetime.now():%x %X}")
|
||||
|
||||
server_api = FastAPI()
|
||||
|
||||
NOCACHE_HEADERS = {"Cache-Control": "no-cache, no-store, must-revalidate", "Pragma": "no-cache", "Expires": "0"}
|
||||
NOCACHE_HEADERS = {
|
||||
"Cache-Control": "no-cache, no-store, must-revalidate",
|
||||
"Pragma": "no-cache",
|
||||
"Expires": "0",
|
||||
}
|
||||
|
||||
|
||||
class NoCacheStaticFiles(StaticFiles):
|
||||
@ -65,11 +67,17 @@ def init():
|
||||
name="custom-thumbnails",
|
||||
)
|
||||
|
||||
server_api.mount("/media", NoCacheStaticFiles(directory=os.path.join(app.SD_UI_DIR, "media")), name="media")
|
||||
server_api.mount(
|
||||
"/media",
|
||||
NoCacheStaticFiles(directory=os.path.join(app.SD_UI_DIR, "media")),
|
||||
name="media",
|
||||
)
|
||||
|
||||
for plugins_dir, dir_prefix in app.UI_PLUGINS_SOURCES:
|
||||
server_api.mount(
|
||||
f"/plugins/{dir_prefix}", NoCacheStaticFiles(directory=plugins_dir), name=f"plugins-{dir_prefix}"
|
||||
f"/plugins/{dir_prefix}",
|
||||
NoCacheStaticFiles(directory=plugins_dir),
|
||||
name=f"plugins-{dir_prefix}",
|
||||
)
|
||||
|
||||
@server_api.post("/app_config")
|
||||
@ -246,8 +254,8 @@ def render_internal(req: dict):
|
||||
|
||||
def model_merge_internal(req: dict):
|
||||
try:
|
||||
from sdkit.train import merge_models
|
||||
from easydiffusion.utils.save_utils import filename_regex
|
||||
from sdkit.train import merge_models
|
||||
|
||||
mergeReq: MergeRequest = MergeRequest.parse_obj(req)
|
||||
|
||||
@ -255,7 +263,11 @@ def model_merge_internal(req: dict):
|
||||
model_manager.resolve_model_to_use(mergeReq.model0, "stable-diffusion"),
|
||||
model_manager.resolve_model_to_use(mergeReq.model1, "stable-diffusion"),
|
||||
mergeReq.ratio,
|
||||
os.path.join(app.MODELS_DIR, "stable-diffusion", filename_regex.sub("_", mergeReq.out_path)),
|
||||
os.path.join(
|
||||
app.MODELS_DIR,
|
||||
"stable-diffusion",
|
||||
filename_regex.sub("_", mergeReq.out_path),
|
||||
),
|
||||
mergeReq.use_fp16,
|
||||
)
|
||||
return JSONResponse({"status": "OK"}, headers=NOCACHE_HEADERS)
|
||||
|
@ -9,14 +9,16 @@ import traceback
|
||||
|
||||
TASK_TTL = 15 * 60 # seconds, Discard last session's task timeout
|
||||
|
||||
import torch
|
||||
import queue, threading, time, weakref
|
||||
import queue
|
||||
import threading
|
||||
import time
|
||||
import weakref
|
||||
from typing import Any, Hashable
|
||||
|
||||
import torch
|
||||
from easydiffusion import device_manager
|
||||
from easydiffusion.types import TaskData, GenerateImageRequest
|
||||
from easydiffusion.types import GenerateImageRequest, TaskData
|
||||
from easydiffusion.utils import log
|
||||
|
||||
from sdkit.utils import gc
|
||||
|
||||
THREAD_NAME_PREFIX = ""
|
||||
@ -167,7 +169,7 @@ class DataCache:
|
||||
raise Exception("DataCache.put" + ERR_LOCK_FAILED)
|
||||
try:
|
||||
self._base[key] = (self._get_ttl_time(ttl), value)
|
||||
except Exception as e:
|
||||
except Exception:
|
||||
log.error(traceback.format_exc())
|
||||
return False
|
||||
else:
|
||||
@ -264,7 +266,7 @@ def thread_get_next_task():
|
||||
def thread_render(device):
|
||||
global current_state, current_state_error
|
||||
|
||||
from easydiffusion import renderer, model_manager
|
||||
from easydiffusion import model_manager, renderer
|
||||
|
||||
try:
|
||||
renderer.init(device)
|
||||
@ -337,7 +339,11 @@ def thread_render(device):
|
||||
|
||||
current_state = ServerStates.Rendering
|
||||
task.response = renderer.make_images(
|
||||
task.render_request, task.task_data, task.buffer_queue, task.temp_images, step_callback
|
||||
task.render_request,
|
||||
task.task_data,
|
||||
task.buffer_queue,
|
||||
task.temp_images,
|
||||
step_callback,
|
||||
)
|
||||
# Before looping back to the generator, mark cache as still alive.
|
||||
task_cache.keep(id(task), TASK_TTL)
|
||||
@ -392,8 +398,8 @@ def get_devices():
|
||||
return {"name": device_manager.get_processor_name()}
|
||||
|
||||
mem_free, mem_total = torch.cuda.mem_get_info(device)
|
||||
mem_free /= float(10**9)
|
||||
mem_total /= float(10**9)
|
||||
mem_free /= float(10 ** 9)
|
||||
mem_total /= float(10 ** 9)
|
||||
|
||||
return {
|
||||
"name": torch.cuda.get_device_name(device),
|
||||
|
@ -1,6 +1,7 @@
|
||||
from pydantic import BaseModel
|
||||
from typing import Any
|
||||
|
||||
from pydantic import BaseModel
|
||||
|
||||
|
||||
class GenerateImageRequest(BaseModel):
|
||||
prompt: str = ""
|
||||
|
@ -1,14 +1,13 @@
|
||||
import os
|
||||
import time
|
||||
import re
|
||||
import time
|
||||
from datetime import datetime
|
||||
from functools import reduce
|
||||
|
||||
from easydiffusion import app
|
||||
from easydiffusion.types import TaskData, GenerateImageRequest
|
||||
from functools import reduce
|
||||
from datetime import datetime
|
||||
|
||||
from sdkit.utils import save_images, save_dicts
|
||||
from easydiffusion.types import GenerateImageRequest, TaskData
|
||||
from numpy import base_repr
|
||||
from sdkit.utils import save_dicts, save_images
|
||||
|
||||
filename_regex = re.compile("[^a-zA-Z0-9._-]")
|
||||
img_number_regex = re.compile("([0-9]{5,})")
|
||||
@ -50,6 +49,7 @@ other_placeholders = {
|
||||
"$s": lambda req, task_data: str(req.seed),
|
||||
}
|
||||
|
||||
|
||||
class ImageNumber:
|
||||
_factory = None
|
||||
_evaluated = False
|
||||
@ -57,12 +57,14 @@ class ImageNumber:
|
||||
def __init__(self, factory):
|
||||
self._factory = factory
|
||||
self._evaluated = None
|
||||
|
||||
def __call__(self) -> int:
|
||||
if self._evaluated is None:
|
||||
self._evaluated = self._factory()
|
||||
return self._evaluated
|
||||
|
||||
def format_placeholders(format: str, req: GenerateImageRequest, task_data: TaskData, now = None):
|
||||
|
||||
def format_placeholders(format: str, req: GenerateImageRequest, task_data: TaskData, now=None):
|
||||
if now is None:
|
||||
now = time.time()
|
||||
|
||||
@ -75,10 +77,12 @@ def format_placeholders(format: str, req: GenerateImageRequest, task_data: TaskD
|
||||
|
||||
return format
|
||||
|
||||
|
||||
def format_folder_name(format: str, req: GenerateImageRequest, task_data: TaskData):
|
||||
format = format_placeholders(format, req, task_data)
|
||||
return filename_regex.sub("_", format)
|
||||
|
||||
|
||||
def format_file_name(
|
||||
format: str,
|
||||
req: GenerateImageRequest,
|
||||
@ -88,19 +92,22 @@ def format_file_name(
|
||||
folder_img_number: ImageNumber,
|
||||
):
|
||||
format = format_placeholders(format, req, task_data, now)
|
||||
|
||||
|
||||
if "$n" in format:
|
||||
format = format.replace("$n", f"{folder_img_number():05}")
|
||||
|
||||
|
||||
if "$tsb64" in format:
|
||||
img_id = base_repr(int(now * 10000), 36)[-7:] + base_repr(int(batch_file_number), 36) # Base 36 conversion, 0-9, A-Z
|
||||
img_id = base_repr(int(now * 10000), 36)[-7:] + base_repr(
|
||||
int(batch_file_number), 36
|
||||
) # Base 36 conversion, 0-9, A-Z
|
||||
format = format.replace("$tsb64", img_id)
|
||||
|
||||
|
||||
if "$ts" in format:
|
||||
format = format.replace("$ts", str(int(now * 1000) + batch_file_number))
|
||||
|
||||
return filename_regex.sub("_", format)
|
||||
|
||||
|
||||
def save_images_to_disk(images: list, filtered_images: list, req: GenerateImageRequest, task_data: TaskData):
|
||||
now = time.time()
|
||||
app_config = app.getConfig()
|
||||
@ -126,7 +133,7 @@ def save_images_to_disk(images: list, filtered_images: list, req: GenerateImageR
|
||||
output_lossless=task_data.output_lossless,
|
||||
)
|
||||
if task_data.metadata_output_format:
|
||||
for metadata_output_format in task_data.metadata_output_format.split(','):
|
||||
for metadata_output_format in task_data.metadata_output_format.split(","):
|
||||
if metadata_output_format.lower() in ["json", "txt", "embed"]:
|
||||
save_dicts(
|
||||
metadata_entries,
|
||||
@ -142,7 +149,8 @@ def save_images_to_disk(images: list, filtered_images: list, req: GenerateImageR
|
||||
task_data,
|
||||
file_number,
|
||||
now=now,
|
||||
suffix="filtered")
|
||||
suffix="filtered",
|
||||
)
|
||||
|
||||
save_images(
|
||||
images,
|
||||
@ -233,27 +241,28 @@ def make_filename_callback(
|
||||
|
||||
return make_filename
|
||||
|
||||
|
||||
def _calculate_img_number(save_dir_path: str, task_data: TaskData):
|
||||
def get_highest_img_number(accumulator: int, file: os.DirEntry) -> int:
|
||||
if not file.is_file:
|
||||
return accumulator
|
||||
|
||||
|
||||
if len(list(filter(lambda e: file.name.endswith(e), app.IMAGE_EXTENSIONS))) == 0:
|
||||
return accumulator
|
||||
|
||||
|
||||
get_highest_img_number.number_of_images = get_highest_img_number.number_of_images + 1
|
||||
|
||||
|
||||
number_match = img_number_regex.match(file.name)
|
||||
if not number_match:
|
||||
return accumulator
|
||||
|
||||
file_number = number_match.group().lstrip('0')
|
||||
|
||||
|
||||
file_number = number_match.group().lstrip("0")
|
||||
|
||||
# Handle 00000
|
||||
return int(file_number) if file_number else 0
|
||||
|
||||
|
||||
get_highest_img_number.number_of_images = 0
|
||||
|
||||
|
||||
highest_file_number = -1
|
||||
|
||||
if os.path.isdir(save_dir_path):
|
||||
@ -267,13 +276,15 @@ def _calculate_img_number(save_dir_path: str, task_data: TaskData):
|
||||
_calculate_img_number.session_img_numbers[task_data.session_id],
|
||||
calculated_img_number,
|
||||
)
|
||||
|
||||
|
||||
calculated_img_number = calculated_img_number + 1
|
||||
|
||||
|
||||
_calculate_img_number.session_img_numbers[task_data.session_id] = calculated_img_number
|
||||
return calculated_img_number
|
||||
|
||||
|
||||
_calculate_img_number.session_img_numbers = {}
|
||||
|
||||
|
||||
def calculate_img_number(save_dir_path: str, task_data: TaskData):
|
||||
return ImageNumber(lambda: _calculate_img_number(save_dir_path, task_data))
|
||||
|
Loading…
Reference in New Issue
Block a user