Formatting

cmdr2 2023-04-28 16:38:55 +05:30
parent 07f52c38ef
commit d18cefc519
8 changed files with 213 additions and 105 deletions
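
The diff below looks like the output of an automated formatting pass over the Python sources: string literals are normalized to double quotes, long calls and literals are wrapped across multiple lines, and imports are regrouped and sorted. As a rough sketch only, assuming tools like black and isort, an assumed source path, and an assumed line length (none of which are confirmed by the commit itself), such a pass could be reproduced with:

# Hypothetical sketch of a formatting pass; the tools, the source path and the
# line-length setting are assumptions, not taken from this commit.
import subprocess

SOURCE_DIR = "ui/easydiffusion"  # assumed location of the Python package

# Group and sort imports first, then reformat the code itself.
subprocess.run(["isort", "--profile", "black", SOURCE_DIR], check=True)
subprocess.run(["black", "--line-length", "120", SOURCE_DIR], check=True)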

View File

@@ -1,17 +1,15 @@
import json
import logging
import os
import socket
import sys
import json
import traceback
import logging
import shlex
import urllib
from rich.logging import RichHandler
from sdkit.utils import log as sdkit_log # hack, so we can overwrite the log config
from easydiffusion import task_manager
from easydiffusion.utils import log
from rich.logging import RichHandler
from sdkit.utils import log as sdkit_log # hack, so we can overwrite the log config
# Remove all handlers associated with the root logger object.
for handler in logging.root.handlers[:]:
@@ -55,10 +53,34 @@ APP_CONFIG_DEFAULTS = {
},
}
IMAGE_EXTENSIONS = [".png", ".apng", ".jpg", ".jpeg", ".jfif", ".pjpeg", ".pjp", ".jxl", ".gif", ".webp", ".avif", ".svg"]
IMAGE_EXTENSIONS = [
".png",
".apng",
".jpg",
".jpeg",
".jfif",
".pjpeg",
".pjp",
".jxl",
".gif",
".webp",
".avif",
".svg",
]
CUSTOM_MODIFIERS_DIR = os.path.abspath(os.path.join(SD_DIR, "..", "modifiers"))
CUSTOM_MODIFIERS_PORTRAIT_EXTENSIONS=[".portrait", "_portrait", " portrait", "-portrait"]
CUSTOM_MODIFIERS_LANDSCAPE_EXTENSIONS=[".landscape", "_landscape", " landscape", "-landscape"]
CUSTOM_MODIFIERS_PORTRAIT_EXTENSIONS = [
".portrait",
"_portrait",
" portrait",
"-portrait",
]
CUSTOM_MODIFIERS_LANDSCAPE_EXTENSIONS = [
".landscape",
"_landscape",
" landscape",
"-landscape",
]
def init():
os.makedirs(USER_UI_PLUGINS_DIR, exist_ok=True)
@@ -84,7 +106,7 @@ def getConfig(default_val=APP_CONFIG_DEFAULTS):
if os.getenv("SD_UI_BIND_IP") is not None:
config["net"]["listen_to_network"] = os.getenv("SD_UI_BIND_IP") == "0.0.0.0"
return config
except Exception as e:
except Exception:
log.warn(traceback.format_exc())
return default_val
@@ -97,6 +119,7 @@ def setConfig(config):
except:
log.error(traceback.format_exc())
def save_to_config(ckpt_model_name, vae_model_name, hypernetwork_model_name, vram_usage_level):
config = getConfig()
if "model" not in config:
@@ -191,11 +214,12 @@ def open_browser():
webbrowser.open(f"http://localhost:{port}")
def get_image_modifiers():
modifiers_json_path = os.path.join(SD_UI_DIR, "modifiers.json")
modifier_categories = {}
original_category_order=[]
original_category_order = []
with open(modifiers_json_path, "r", encoding="utf-8") as f:
modifiers_file = json.load(f)
@@ -205,14 +229,14 @@ def get_image_modifiers():
# convert modifiers from a list of objects to a dict of dicts
for category_item in modifiers_file:
category_name = category_item['category']
category_name = category_item["category"]
original_category_order.append(category_name)
category = {}
for modifier_item in category_item['modifiers']:
for modifier_item in category_item["modifiers"]:
modifier = {}
for preview_item in modifier_item['previews']:
modifier[preview_item['name']] = preview_item['path']
category[modifier_item['modifier']] = modifier
for preview_item in modifier_item["previews"]:
modifier[preview_item["name"]] = preview_item["path"]
category[modifier_item["modifier"]] = modifier
modifier_categories[category_name] = category
def scan_directory(directory_path: str, category_name="Modifiers"):
@@ -225,12 +249,27 @@ def get_image_modifiers():
modifier_name = entry.name[: -len(file_extension[0])]
modifier_path = f"custom/{entry.path[len(CUSTOM_MODIFIERS_DIR) + 1:]}"
# URL encode path segments
modifier_path = "/".join(map(lambda segment: urllib.parse.quote(segment), modifier_path.split("/")))
modifier_path = "/".join(
map(
lambda segment: urllib.parse.quote(segment),
modifier_path.split("/"),
)
)
is_portrait = True
is_landscape = True
portrait_extension = list(filter(lambda e: modifier_name.lower().endswith(e), CUSTOM_MODIFIERS_PORTRAIT_EXTENSIONS))
landscape_extension = list(filter(lambda e: modifier_name.lower().endswith(e), CUSTOM_MODIFIERS_LANDSCAPE_EXTENSIONS))
portrait_extension = list(
filter(
lambda e: modifier_name.lower().endswith(e),
CUSTOM_MODIFIERS_PORTRAIT_EXTENSIONS,
)
)
landscape_extension = list(
filter(
lambda e: modifier_name.lower().endswith(e),
CUSTOM_MODIFIERS_LANDSCAPE_EXTENSIONS,
)
)
if len(portrait_extension) > 0:
is_landscape = False
@@ -238,24 +277,24 @@ def get_image_modifiers():
elif len(landscape_extension) > 0:
is_portrait = False
modifier_name = modifier_name[: -len(landscape_extension[0])]
if (category_name not in modifier_categories):
if category_name not in modifier_categories:
modifier_categories[category_name] = {}
category = modifier_categories[category_name]
if (modifier_name not in category):
if modifier_name not in category:
category[modifier_name] = {}
if (is_portrait or "portrait" not in category[modifier_name]):
if is_portrait or "portrait" not in category[modifier_name]:
category[modifier_name]["portrait"] = modifier_path
if (is_landscape or "landscape" not in category[modifier_name]):
if is_landscape or "landscape" not in category[modifier_name]:
category[modifier_name]["landscape"] = modifier_path
elif entry.is_dir():
scan_directory(
entry.path,
entry.name if directory_path==CUSTOM_MODIFIERS_DIR else f"{category_name}/{entry.name}",
entry.name if directory_path == CUSTOM_MODIFIERS_DIR else f"{category_name}/{entry.name}",
)
scan_directory(CUSTOM_MODIFIERS_DIR)
@@ -268,12 +307,12 @@ def get_image_modifiers():
# convert the modifiers back into a list of objects
modifier_categories_list = []
for category_name in [*original_category_order, *custom_categories]:
category = { 'category': category_name, 'modifiers': [] }
category = {"category": category_name, "modifiers": []}
for modifier_name in sorted(modifier_categories[category_name].keys(), key=str.casefold):
modifier = { 'modifier': modifier_name, 'previews': [] }
modifier = {"modifier": modifier_name, "previews": []}
for preview_name, preview_path in modifier_categories[category_name][modifier_name].items():
modifier['previews'].append({ 'name': preview_name, 'path': preview_path })
category['modifiers'].append(modifier)
modifier["previews"].append({"name": preview_name, "path": preview_path})
category["modifiers"].append(modifier)
modifier_categories_list.append(category)
return modifier_categories_list

View File

@@ -1,9 +1,9 @@
import os
import platform
import torch
import traceback
import re
import traceback
import torch
from easydiffusion.utils import log
"""
@@ -98,8 +98,8 @@ def auto_pick_devices(currently_active_devices):
continue
mem_free, mem_total = torch.cuda.mem_get_info(device)
mem_free /= float(10**9)
mem_total /= float(10**9)
mem_free /= float(10 ** 9)
mem_total /= float(10 ** 9)
device_name = torch.cuda.get_device_name(device)
log.debug(
f"{device} detected: {device_name} - Memory (free/total): {round(mem_free, 2)}Gb / {round(mem_total, 2)}Gb"
@@ -118,7 +118,10 @@ def auto_pick_devices(currently_active_devices):
# These already-running devices probably aren't terrible, since they were picked in the past.
# Worst case, the user can restart the program and that'll get rid of them.
devices = list(
filter((lambda x: x["mem_free"] > mem_free_threshold or x["device"] in currently_active_devices), devices)
filter(
(lambda x: x["mem_free"] > mem_free_threshold or x["device"] in currently_active_devices),
devices,
)
)
devices = list(map(lambda x: x["device"], devices))
return devices
@@ -178,7 +181,7 @@ def get_max_vram_usage_level(device):
else:
return "high"
mem_total /= float(10**9)
mem_total /= float(10 ** 9)
if mem_total < 4.5:
return "low"
elif mem_total < 6.5:
@@ -220,7 +223,7 @@ def is_device_compatible(device):
# Memory check
try:
_, mem_total = torch.cuda.mem_get_info(device)
mem_total /= float(10**9)
mem_total /= float(10 ** 9)
if mem_total < 3.0:
if is_device_compatible.history.get(device) == None:
log.warn(f"GPU {device} with less than 3 GB of VRAM is not compatible with Stable Diffusion")

View File

@@ -3,11 +3,17 @@ import os
from easydiffusion import app
from easydiffusion.types import TaskData
from easydiffusion.utils import log
from sdkit import Context
from sdkit.models import load_model, unload_model, scan_model
from sdkit.models import load_model, scan_model, unload_model
KNOWN_MODEL_TYPES = ["stable-diffusion", "vae", "hypernetwork", "gfpgan", "realesrgan", "lora"]
KNOWN_MODEL_TYPES = [
"stable-diffusion",
"vae",
"hypernetwork",
"gfpgan",
"realesrgan",
"lora",
]
MODEL_EXTENSIONS = {
"stable-diffusion": [".ckpt", ".safetensors"],
"vae": [".vae.pt", ".ckpt", ".safetensors"],
@@ -44,13 +50,15 @@ def load_default_models(context: Context):
load_model(
context,
model_type,
scan_model = context.model_paths[model_type] != None and not context.model_paths[model_type].endswith('.safetensors')
scan_model=context.model_paths[model_type] != None
and not context.model_paths[model_type].endswith(".safetensors"),
)
except Exception as e:
log.error(f"[red]Error while loading {model_type} model: {context.model_paths[model_type]}[/red]")
log.exception(e)
del context.model_paths[model_type]
def unload_all(context: Context):
for model_type in KNOWN_MODEL_TYPES:
unload_model(context, model_type)
@@ -170,13 +178,23 @@ def is_malicious_model(file_path):
if scan_result.issues_count > 0 or scan_result.infected_files > 0:
log.warn(
":warning: [bold red]Scan %s: %d scanned, %d issue, %d infected.[/bold red]"
% (file_path, scan_result.scanned_files, scan_result.issues_count, scan_result.infected_files)
% (
file_path,
scan_result.scanned_files,
scan_result.issues_count,
scan_result.infected_files,
)
)
return True
else:
log.debug(
"Scan %s: [green]%d scanned, %d issue, %d infected.[/green]"
% (file_path, scan_result.scanned_files, scan_result.issues_count, scan_result.infected_files)
% (
file_path,
scan_result.scanned_files,
scan_result.issues_count,
scan_result.infected_files,
)
)
return False
except Exception as e:
@@ -204,13 +222,13 @@ def getModels():
class MaliciousModelException(Exception):
"Raised when picklescan reports a problem with a model"
pass
def scan_directory(directory, suffixes, directoriesFirst: bool = True):
nonlocal models_scanned
tree = []
for entry in sorted(
os.scandir(directory), key=lambda entry: (entry.is_file() == directoriesFirst, entry.name.lower())
os.scandir(directory),
key=lambda entry: (entry.is_file() == directoriesFirst, entry.name.lower()),
):
if entry.is_file():
matching_suffix = list(filter(lambda s: entry.name.endswith(s), suffixes))

View File

@@ -1,21 +1,22 @@
import queue
import time
import json
import pprint
import queue
import time
from easydiffusion import device_manager
from easydiffusion.types import TaskData, Response, Image as ResponseImage, UserInitiatedStop, GenerateImageRequest
from easydiffusion.utils import get_printable_request, save_images_to_disk, log
from easydiffusion.types import GenerateImageRequest
from easydiffusion.types import Image as ResponseImage
from easydiffusion.types import Response, TaskData, UserInitiatedStop
from easydiffusion.utils import get_printable_request, log, save_images_to_disk
from sdkit import Context
from sdkit.generate import generate_images
from sdkit.filter import apply_filters
from sdkit.generate import generate_images
from sdkit.utils import (
img_to_buffer,
img_to_base64_str,
latent_samples_to_images,
diffusers_latent_samples_to_images,
gc,
img_to_base64_str,
img_to_buffer,
latent_samples_to_images,
)
context = Context() # thread-local
@@ -43,14 +44,22 @@ def init(device):
def make_images(
req: GenerateImageRequest, task_data: TaskData, data_queue: queue.Queue, task_temp_images: list, step_callback
req: GenerateImageRequest,
task_data: TaskData,
data_queue: queue.Queue,
task_temp_images: list,
step_callback,
):
context.stop_processing = False
print_task_info(req, task_data)
images, seeds = make_images_internal(req, task_data, data_queue, task_temp_images, step_callback)
res = Response(req, task_data, images=construct_response(images, seeds, task_data, base_seed=req.seed))
res = Response(
req,
task_data,
images=construct_response(images, seeds, task_data, base_seed=req.seed),
)
res = res.json()
data_queue.put(json.dumps(res))
log.info("Task completed")
@@ -66,7 +75,11 @@ def print_task_info(req: GenerateImageRequest, task_data: TaskData):
def make_images_internal(
req: GenerateImageRequest, task_data: TaskData, data_queue: queue.Queue, task_temp_images: list, step_callback
req: GenerateImageRequest,
task_data: TaskData,
data_queue: queue.Queue,
task_temp_images: list,
step_callback,
):
images, user_stopped = generate_images_internal(
req,
@@ -155,7 +168,12 @@ def filter_images(task_data: TaskData, images: list, user_stopped):
def construct_response(images: list, seeds: list, task_data: TaskData, base_seed: int):
return [
ResponseImage(
data=img_to_base64_str(img, task_data.output_format, task_data.output_quality, task_data.output_lossless),
data=img_to_base64_str(
img,
task_data.output_format,
task_data.output_quality,
task_data.output_lossless,
),
seed=seed,
)
for img, seed in zip(images, seeds)

View File

@@ -2,28 +2,30 @@
Notes:
async endpoints always run on the main thread. Otherwise they run on the thread pool.
"""
import datetime
import mimetypes
import os
import traceback
import datetime
from typing import List, Union
from easydiffusion import app, model_manager, task_manager
from easydiffusion.types import GenerateImageRequest, MergeRequest, TaskData
from easydiffusion.utils import log
from fastapi import FastAPI, HTTPException
from fastapi.staticfiles import StaticFiles
from starlette.responses import FileResponse, JSONResponse, StreamingResponse
from pydantic import BaseModel, Extra
from easydiffusion import app, model_manager, task_manager
from easydiffusion.types import TaskData, GenerateImageRequest, MergeRequest
from easydiffusion.utils import log
import mimetypes
from starlette.responses import FileResponse, JSONResponse, StreamingResponse
log.info(f"started in {app.SD_DIR}")
log.info(f"started at {datetime.datetime.now():%x %X}")
server_api = FastAPI()
NOCACHE_HEADERS = {"Cache-Control": "no-cache, no-store, must-revalidate", "Pragma": "no-cache", "Expires": "0"}
NOCACHE_HEADERS = {
"Cache-Control": "no-cache, no-store, must-revalidate",
"Pragma": "no-cache",
"Expires": "0",
}
class NoCacheStaticFiles(StaticFiles):
@@ -65,11 +67,17 @@ def init():
name="custom-thumbnails",
)
server_api.mount("/media", NoCacheStaticFiles(directory=os.path.join(app.SD_UI_DIR, "media")), name="media")
server_api.mount(
"/media",
NoCacheStaticFiles(directory=os.path.join(app.SD_UI_DIR, "media")),
name="media",
)
for plugins_dir, dir_prefix in app.UI_PLUGINS_SOURCES:
server_api.mount(
f"/plugins/{dir_prefix}", NoCacheStaticFiles(directory=plugins_dir), name=f"plugins-{dir_prefix}"
f"/plugins/{dir_prefix}",
NoCacheStaticFiles(directory=plugins_dir),
name=f"plugins-{dir_prefix}",
)
@server_api.post("/app_config")
@@ -246,8 +254,8 @@ def render_internal(req: dict):
def model_merge_internal(req: dict):
try:
from sdkit.train import merge_models
from easydiffusion.utils.save_utils import filename_regex
from sdkit.train import merge_models
mergeReq: MergeRequest = MergeRequest.parse_obj(req)
@@ -255,7 +263,11 @@ def model_merge_internal(req: dict):
model_manager.resolve_model_to_use(mergeReq.model0, "stable-diffusion"),
model_manager.resolve_model_to_use(mergeReq.model1, "stable-diffusion"),
mergeReq.ratio,
os.path.join(app.MODELS_DIR, "stable-diffusion", filename_regex.sub("_", mergeReq.out_path)),
os.path.join(
app.MODELS_DIR,
"stable-diffusion",
filename_regex.sub("_", mergeReq.out_path),
),
mergeReq.use_fp16,
)
return JSONResponse({"status": "OK"}, headers=NOCACHE_HEADERS)

View File

@@ -9,14 +9,16 @@ import traceback
TASK_TTL = 15 * 60 # seconds, Discard last session's task timeout
import torch
import queue, threading, time, weakref
import queue
import threading
import time
import weakref
from typing import Any, Hashable
import torch
from easydiffusion import device_manager
from easydiffusion.types import TaskData, GenerateImageRequest
from easydiffusion.types import GenerateImageRequest, TaskData
from easydiffusion.utils import log
from sdkit.utils import gc
THREAD_NAME_PREFIX = ""
@@ -167,7 +169,7 @@ class DataCache:
raise Exception("DataCache.put" + ERR_LOCK_FAILED)
try:
self._base[key] = (self._get_ttl_time(ttl), value)
except Exception as e:
except Exception:
log.error(traceback.format_exc())
return False
else:
@@ -264,7 +266,7 @@ def thread_get_next_task():
def thread_render(device):
global current_state, current_state_error
from easydiffusion import renderer, model_manager
from easydiffusion import model_manager, renderer
try:
renderer.init(device)
@@ -337,7 +339,11 @@ def thread_render(device):
current_state = ServerStates.Rendering
task.response = renderer.make_images(
task.render_request, task.task_data, task.buffer_queue, task.temp_images, step_callback
task.render_request,
task.task_data,
task.buffer_queue,
task.temp_images,
step_callback,
)
# Before looping back to the generator, mark cache as still alive.
task_cache.keep(id(task), TASK_TTL)
@@ -392,8 +398,8 @@ def get_devices():
return {"name": device_manager.get_processor_name()}
mem_free, mem_total = torch.cuda.mem_get_info(device)
mem_free /= float(10**9)
mem_total /= float(10**9)
mem_free /= float(10 ** 9)
mem_total /= float(10 ** 9)
return {
"name": torch.cuda.get_device_name(device),

View File

@@ -1,6 +1,7 @@
from pydantic import BaseModel
from typing import Any
from pydantic import BaseModel
class GenerateImageRequest(BaseModel):
prompt: str = ""

View File

@@ -1,14 +1,13 @@
import os
import time
import re
import time
from datetime import datetime
from functools import reduce
from easydiffusion import app
from easydiffusion.types import TaskData, GenerateImageRequest
from functools import reduce
from datetime import datetime
from sdkit.utils import save_images, save_dicts
from easydiffusion.types import GenerateImageRequest, TaskData
from numpy import base_repr
from sdkit.utils import save_dicts, save_images
filename_regex = re.compile("[^a-zA-Z0-9._-]")
img_number_regex = re.compile("([0-9]{5,})")
@@ -50,6 +49,7 @@ other_placeholders = {
"$s": lambda req, task_data: str(req.seed),
}
class ImageNumber:
_factory = None
_evaluated = False
@@ -57,12 +57,14 @@ class ImageNumber:
def __init__(self, factory):
self._factory = factory
self._evaluated = None
def __call__(self) -> int:
if self._evaluated is None:
self._evaluated = self._factory()
return self._evaluated
def format_placeholders(format: str, req: GenerateImageRequest, task_data: TaskData, now = None):
def format_placeholders(format: str, req: GenerateImageRequest, task_data: TaskData, now=None):
if now is None:
now = time.time()
@@ -75,10 +77,12 @@ def format_placeholders(format: str, req: GenerateImageRequest, task_data: TaskD
return format
def format_folder_name(format: str, req: GenerateImageRequest, task_data: TaskData):
format = format_placeholders(format, req, task_data)
return filename_regex.sub("_", format)
def format_file_name(
format: str,
req: GenerateImageRequest,
@@ -88,19 +92,22 @@ def format_file_name(
folder_img_number: ImageNumber,
):
format = format_placeholders(format, req, task_data, now)
if "$n" in format:
format = format.replace("$n", f"{folder_img_number():05}")
if "$tsb64" in format:
img_id = base_repr(int(now * 10000), 36)[-7:] + base_repr(int(batch_file_number), 36) # Base 36 conversion, 0-9, A-Z
img_id = base_repr(int(now * 10000), 36)[-7:] + base_repr(
int(batch_file_number), 36
) # Base 36 conversion, 0-9, A-Z
format = format.replace("$tsb64", img_id)
if "$ts" in format:
format = format.replace("$ts", str(int(now * 1000) + batch_file_number))
return filename_regex.sub("_", format)
def save_images_to_disk(images: list, filtered_images: list, req: GenerateImageRequest, task_data: TaskData):
now = time.time()
app_config = app.getConfig()
@@ -126,7 +133,7 @@ def save_images_to_disk(images: list, filtered_images: list, req: GenerateImageR
output_lossless=task_data.output_lossless,
)
if task_data.metadata_output_format:
for metadata_output_format in task_data.metadata_output_format.split(','):
for metadata_output_format in task_data.metadata_output_format.split(","):
if metadata_output_format.lower() in ["json", "txt", "embed"]:
save_dicts(
metadata_entries,
@@ -142,7 +149,8 @@ def save_images_to_disk(images: list, filtered_images: list, req: GenerateImageR
task_data,
file_number,
now=now,
suffix="filtered")
suffix="filtered",
)
save_images(
images,
@@ -233,27 +241,28 @@ def make_filename_callback(
return make_filename
def _calculate_img_number(save_dir_path: str, task_data: TaskData):
def get_highest_img_number(accumulator: int, file: os.DirEntry) -> int:
if not file.is_file:
return accumulator
if len(list(filter(lambda e: file.name.endswith(e), app.IMAGE_EXTENSIONS))) == 0:
return accumulator
get_highest_img_number.number_of_images = get_highest_img_number.number_of_images + 1
number_match = img_number_regex.match(file.name)
if not number_match:
return accumulator
file_number = number_match.group().lstrip('0')
file_number = number_match.group().lstrip("0")
# Handle 00000
return int(file_number) if file_number else 0
get_highest_img_number.number_of_images = 0
highest_file_number = -1
if os.path.isdir(save_dir_path):
@@ -267,13 +276,15 @@ def _calculate_img_number(save_dir_path: str, task_data: TaskData):
_calculate_img_number.session_img_numbers[task_data.session_id],
calculated_img_number,
)
calculated_img_number = calculated_img_number + 1
_calculate_img_number.session_img_numbers[task_data.session_id] = calculated_img_number
return calculated_img_number
_calculate_img_number.session_img_numbers = {}
def calculate_img_number(save_dir_path: str, task_data: TaskData):
return ImageNumber(lambda: _calculate_img_number(save_dir_path, task_data))