mirror of
https://github.com/easydiffusion/easydiffusion.git
synced 2024-11-25 17:55:09 +01:00
Formatting
This commit is contained in:
parent
07f52c38ef
commit
d18cefc519
@ -1,17 +1,15 @@
|
|||||||
|
import json
|
||||||
|
import logging
|
||||||
import os
|
import os
|
||||||
import socket
|
import socket
|
||||||
import sys
|
import sys
|
||||||
import json
|
|
||||||
import traceback
|
import traceback
|
||||||
import logging
|
|
||||||
import shlex
|
|
||||||
import urllib
|
import urllib
|
||||||
from rich.logging import RichHandler
|
|
||||||
|
|
||||||
from sdkit.utils import log as sdkit_log # hack, so we can overwrite the log config
|
|
||||||
|
|
||||||
from easydiffusion import task_manager
|
from easydiffusion import task_manager
|
||||||
from easydiffusion.utils import log
|
from easydiffusion.utils import log
|
||||||
|
from rich.logging import RichHandler
|
||||||
|
from sdkit.utils import log as sdkit_log # hack, so we can overwrite the log config
|
||||||
|
|
||||||
# Remove all handlers associated with the root logger object.
|
# Remove all handlers associated with the root logger object.
|
||||||
for handler in logging.root.handlers[:]:
|
for handler in logging.root.handlers[:]:
|
||||||
@ -55,10 +53,34 @@ APP_CONFIG_DEFAULTS = {
|
|||||||
},
|
},
|
||||||
}
|
}
|
||||||
|
|
||||||
IMAGE_EXTENSIONS = [".png", ".apng", ".jpg", ".jpeg", ".jfif", ".pjpeg", ".pjp", ".jxl", ".gif", ".webp", ".avif", ".svg"]
|
IMAGE_EXTENSIONS = [
|
||||||
|
".png",
|
||||||
|
".apng",
|
||||||
|
".jpg",
|
||||||
|
".jpeg",
|
||||||
|
".jfif",
|
||||||
|
".pjpeg",
|
||||||
|
".pjp",
|
||||||
|
".jxl",
|
||||||
|
".gif",
|
||||||
|
".webp",
|
||||||
|
".avif",
|
||||||
|
".svg",
|
||||||
|
]
|
||||||
CUSTOM_MODIFIERS_DIR = os.path.abspath(os.path.join(SD_DIR, "..", "modifiers"))
|
CUSTOM_MODIFIERS_DIR = os.path.abspath(os.path.join(SD_DIR, "..", "modifiers"))
|
||||||
CUSTOM_MODIFIERS_PORTRAIT_EXTENSIONS=[".portrait", "_portrait", " portrait", "-portrait"]
|
CUSTOM_MODIFIERS_PORTRAIT_EXTENSIONS = [
|
||||||
CUSTOM_MODIFIERS_LANDSCAPE_EXTENSIONS=[".landscape", "_landscape", " landscape", "-landscape"]
|
".portrait",
|
||||||
|
"_portrait",
|
||||||
|
" portrait",
|
||||||
|
"-portrait",
|
||||||
|
]
|
||||||
|
CUSTOM_MODIFIERS_LANDSCAPE_EXTENSIONS = [
|
||||||
|
".landscape",
|
||||||
|
"_landscape",
|
||||||
|
" landscape",
|
||||||
|
"-landscape",
|
||||||
|
]
|
||||||
|
|
||||||
|
|
||||||
def init():
|
def init():
|
||||||
os.makedirs(USER_UI_PLUGINS_DIR, exist_ok=True)
|
os.makedirs(USER_UI_PLUGINS_DIR, exist_ok=True)
|
||||||
@ -84,7 +106,7 @@ def getConfig(default_val=APP_CONFIG_DEFAULTS):
|
|||||||
if os.getenv("SD_UI_BIND_IP") is not None:
|
if os.getenv("SD_UI_BIND_IP") is not None:
|
||||||
config["net"]["listen_to_network"] = os.getenv("SD_UI_BIND_IP") == "0.0.0.0"
|
config["net"]["listen_to_network"] = os.getenv("SD_UI_BIND_IP") == "0.0.0.0"
|
||||||
return config
|
return config
|
||||||
except Exception as e:
|
except Exception:
|
||||||
log.warn(traceback.format_exc())
|
log.warn(traceback.format_exc())
|
||||||
return default_val
|
return default_val
|
||||||
|
|
||||||
@ -97,6 +119,7 @@ def setConfig(config):
|
|||||||
except:
|
except:
|
||||||
log.error(traceback.format_exc())
|
log.error(traceback.format_exc())
|
||||||
|
|
||||||
|
|
||||||
def save_to_config(ckpt_model_name, vae_model_name, hypernetwork_model_name, vram_usage_level):
|
def save_to_config(ckpt_model_name, vae_model_name, hypernetwork_model_name, vram_usage_level):
|
||||||
config = getConfig()
|
config = getConfig()
|
||||||
if "model" not in config:
|
if "model" not in config:
|
||||||
@ -191,11 +214,12 @@ def open_browser():
|
|||||||
|
|
||||||
webbrowser.open(f"http://localhost:{port}")
|
webbrowser.open(f"http://localhost:{port}")
|
||||||
|
|
||||||
|
|
||||||
def get_image_modifiers():
|
def get_image_modifiers():
|
||||||
modifiers_json_path = os.path.join(SD_UI_DIR, "modifiers.json")
|
modifiers_json_path = os.path.join(SD_UI_DIR, "modifiers.json")
|
||||||
|
|
||||||
modifier_categories = {}
|
modifier_categories = {}
|
||||||
original_category_order=[]
|
original_category_order = []
|
||||||
with open(modifiers_json_path, "r", encoding="utf-8") as f:
|
with open(modifiers_json_path, "r", encoding="utf-8") as f:
|
||||||
modifiers_file = json.load(f)
|
modifiers_file = json.load(f)
|
||||||
|
|
||||||
@ -205,14 +229,14 @@ def get_image_modifiers():
|
|||||||
|
|
||||||
# convert modifiers from a list of objects to a dict of dicts
|
# convert modifiers from a list of objects to a dict of dicts
|
||||||
for category_item in modifiers_file:
|
for category_item in modifiers_file:
|
||||||
category_name = category_item['category']
|
category_name = category_item["category"]
|
||||||
original_category_order.append(category_name)
|
original_category_order.append(category_name)
|
||||||
category = {}
|
category = {}
|
||||||
for modifier_item in category_item['modifiers']:
|
for modifier_item in category_item["modifiers"]:
|
||||||
modifier = {}
|
modifier = {}
|
||||||
for preview_item in modifier_item['previews']:
|
for preview_item in modifier_item["previews"]:
|
||||||
modifier[preview_item['name']] = preview_item['path']
|
modifier[preview_item["name"]] = preview_item["path"]
|
||||||
category[modifier_item['modifier']] = modifier
|
category[modifier_item["modifier"]] = modifier
|
||||||
modifier_categories[category_name] = category
|
modifier_categories[category_name] = category
|
||||||
|
|
||||||
def scan_directory(directory_path: str, category_name="Modifiers"):
|
def scan_directory(directory_path: str, category_name="Modifiers"):
|
||||||
@ -225,12 +249,27 @@ def get_image_modifiers():
|
|||||||
modifier_name = entry.name[: -len(file_extension[0])]
|
modifier_name = entry.name[: -len(file_extension[0])]
|
||||||
modifier_path = f"custom/{entry.path[len(CUSTOM_MODIFIERS_DIR) + 1:]}"
|
modifier_path = f"custom/{entry.path[len(CUSTOM_MODIFIERS_DIR) + 1:]}"
|
||||||
# URL encode path segments
|
# URL encode path segments
|
||||||
modifier_path = "/".join(map(lambda segment: urllib.parse.quote(segment), modifier_path.split("/")))
|
modifier_path = "/".join(
|
||||||
|
map(
|
||||||
|
lambda segment: urllib.parse.quote(segment),
|
||||||
|
modifier_path.split("/"),
|
||||||
|
)
|
||||||
|
)
|
||||||
is_portrait = True
|
is_portrait = True
|
||||||
is_landscape = True
|
is_landscape = True
|
||||||
|
|
||||||
portrait_extension = list(filter(lambda e: modifier_name.lower().endswith(e), CUSTOM_MODIFIERS_PORTRAIT_EXTENSIONS))
|
portrait_extension = list(
|
||||||
landscape_extension = list(filter(lambda e: modifier_name.lower().endswith(e), CUSTOM_MODIFIERS_LANDSCAPE_EXTENSIONS))
|
filter(
|
||||||
|
lambda e: modifier_name.lower().endswith(e),
|
||||||
|
CUSTOM_MODIFIERS_PORTRAIT_EXTENSIONS,
|
||||||
|
)
|
||||||
|
)
|
||||||
|
landscape_extension = list(
|
||||||
|
filter(
|
||||||
|
lambda e: modifier_name.lower().endswith(e),
|
||||||
|
CUSTOM_MODIFIERS_LANDSCAPE_EXTENSIONS,
|
||||||
|
)
|
||||||
|
)
|
||||||
|
|
||||||
if len(portrait_extension) > 0:
|
if len(portrait_extension) > 0:
|
||||||
is_landscape = False
|
is_landscape = False
|
||||||
@ -238,24 +277,24 @@ def get_image_modifiers():
|
|||||||
elif len(landscape_extension) > 0:
|
elif len(landscape_extension) > 0:
|
||||||
is_portrait = False
|
is_portrait = False
|
||||||
modifier_name = modifier_name[: -len(landscape_extension[0])]
|
modifier_name = modifier_name[: -len(landscape_extension[0])]
|
||||||
|
|
||||||
if (category_name not in modifier_categories):
|
if category_name not in modifier_categories:
|
||||||
modifier_categories[category_name] = {}
|
modifier_categories[category_name] = {}
|
||||||
|
|
||||||
category = modifier_categories[category_name]
|
category = modifier_categories[category_name]
|
||||||
|
|
||||||
if (modifier_name not in category):
|
if modifier_name not in category:
|
||||||
category[modifier_name] = {}
|
category[modifier_name] = {}
|
||||||
|
|
||||||
if (is_portrait or "portrait" not in category[modifier_name]):
|
if is_portrait or "portrait" not in category[modifier_name]:
|
||||||
category[modifier_name]["portrait"] = modifier_path
|
category[modifier_name]["portrait"] = modifier_path
|
||||||
|
|
||||||
if (is_landscape or "landscape" not in category[modifier_name]):
|
if is_landscape or "landscape" not in category[modifier_name]:
|
||||||
category[modifier_name]["landscape"] = modifier_path
|
category[modifier_name]["landscape"] = modifier_path
|
||||||
elif entry.is_dir():
|
elif entry.is_dir():
|
||||||
scan_directory(
|
scan_directory(
|
||||||
entry.path,
|
entry.path,
|
||||||
entry.name if directory_path==CUSTOM_MODIFIERS_DIR else f"{category_name}/{entry.name}",
|
entry.name if directory_path == CUSTOM_MODIFIERS_DIR else f"{category_name}/{entry.name}",
|
||||||
)
|
)
|
||||||
|
|
||||||
scan_directory(CUSTOM_MODIFIERS_DIR)
|
scan_directory(CUSTOM_MODIFIERS_DIR)
|
||||||
@ -268,12 +307,12 @@ def get_image_modifiers():
|
|||||||
# convert the modifiers back into a list of objects
|
# convert the modifiers back into a list of objects
|
||||||
modifier_categories_list = []
|
modifier_categories_list = []
|
||||||
for category_name in [*original_category_order, *custom_categories]:
|
for category_name in [*original_category_order, *custom_categories]:
|
||||||
category = { 'category': category_name, 'modifiers': [] }
|
category = {"category": category_name, "modifiers": []}
|
||||||
for modifier_name in sorted(modifier_categories[category_name].keys(), key=str.casefold):
|
for modifier_name in sorted(modifier_categories[category_name].keys(), key=str.casefold):
|
||||||
modifier = { 'modifier': modifier_name, 'previews': [] }
|
modifier = {"modifier": modifier_name, "previews": []}
|
||||||
for preview_name, preview_path in modifier_categories[category_name][modifier_name].items():
|
for preview_name, preview_path in modifier_categories[category_name][modifier_name].items():
|
||||||
modifier['previews'].append({ 'name': preview_name, 'path': preview_path })
|
modifier["previews"].append({"name": preview_name, "path": preview_path})
|
||||||
category['modifiers'].append(modifier)
|
category["modifiers"].append(modifier)
|
||||||
modifier_categories_list.append(category)
|
modifier_categories_list.append(category)
|
||||||
|
|
||||||
return modifier_categories_list
|
return modifier_categories_list
|
||||||
|
@ -1,9 +1,9 @@
|
|||||||
import os
|
import os
|
||||||
import platform
|
import platform
|
||||||
import torch
|
|
||||||
import traceback
|
|
||||||
import re
|
import re
|
||||||
|
import traceback
|
||||||
|
|
||||||
|
import torch
|
||||||
from easydiffusion.utils import log
|
from easydiffusion.utils import log
|
||||||
|
|
||||||
"""
|
"""
|
||||||
@ -98,8 +98,8 @@ def auto_pick_devices(currently_active_devices):
|
|||||||
continue
|
continue
|
||||||
|
|
||||||
mem_free, mem_total = torch.cuda.mem_get_info(device)
|
mem_free, mem_total = torch.cuda.mem_get_info(device)
|
||||||
mem_free /= float(10**9)
|
mem_free /= float(10 ** 9)
|
||||||
mem_total /= float(10**9)
|
mem_total /= float(10 ** 9)
|
||||||
device_name = torch.cuda.get_device_name(device)
|
device_name = torch.cuda.get_device_name(device)
|
||||||
log.debug(
|
log.debug(
|
||||||
f"{device} detected: {device_name} - Memory (free/total): {round(mem_free, 2)}Gb / {round(mem_total, 2)}Gb"
|
f"{device} detected: {device_name} - Memory (free/total): {round(mem_free, 2)}Gb / {round(mem_total, 2)}Gb"
|
||||||
@ -118,7 +118,10 @@ def auto_pick_devices(currently_active_devices):
|
|||||||
# These already-running devices probably aren't terrible, since they were picked in the past.
|
# These already-running devices probably aren't terrible, since they were picked in the past.
|
||||||
# Worst case, the user can restart the program and that'll get rid of them.
|
# Worst case, the user can restart the program and that'll get rid of them.
|
||||||
devices = list(
|
devices = list(
|
||||||
filter((lambda x: x["mem_free"] > mem_free_threshold or x["device"] in currently_active_devices), devices)
|
filter(
|
||||||
|
(lambda x: x["mem_free"] > mem_free_threshold or x["device"] in currently_active_devices),
|
||||||
|
devices,
|
||||||
|
)
|
||||||
)
|
)
|
||||||
devices = list(map(lambda x: x["device"], devices))
|
devices = list(map(lambda x: x["device"], devices))
|
||||||
return devices
|
return devices
|
||||||
@ -178,7 +181,7 @@ def get_max_vram_usage_level(device):
|
|||||||
else:
|
else:
|
||||||
return "high"
|
return "high"
|
||||||
|
|
||||||
mem_total /= float(10**9)
|
mem_total /= float(10 ** 9)
|
||||||
if mem_total < 4.5:
|
if mem_total < 4.5:
|
||||||
return "low"
|
return "low"
|
||||||
elif mem_total < 6.5:
|
elif mem_total < 6.5:
|
||||||
@ -220,7 +223,7 @@ def is_device_compatible(device):
|
|||||||
# Memory check
|
# Memory check
|
||||||
try:
|
try:
|
||||||
_, mem_total = torch.cuda.mem_get_info(device)
|
_, mem_total = torch.cuda.mem_get_info(device)
|
||||||
mem_total /= float(10**9)
|
mem_total /= float(10 ** 9)
|
||||||
if mem_total < 3.0:
|
if mem_total < 3.0:
|
||||||
if is_device_compatible.history.get(device) == None:
|
if is_device_compatible.history.get(device) == None:
|
||||||
log.warn(f"GPU {device} with less than 3 GB of VRAM is not compatible with Stable Diffusion")
|
log.warn(f"GPU {device} with less than 3 GB of VRAM is not compatible with Stable Diffusion")
|
||||||
|
@ -3,11 +3,17 @@ import os
|
|||||||
from easydiffusion import app
|
from easydiffusion import app
|
||||||
from easydiffusion.types import TaskData
|
from easydiffusion.types import TaskData
|
||||||
from easydiffusion.utils import log
|
from easydiffusion.utils import log
|
||||||
|
|
||||||
from sdkit import Context
|
from sdkit import Context
|
||||||
from sdkit.models import load_model, unload_model, scan_model
|
from sdkit.models import load_model, scan_model, unload_model
|
||||||
|
|
||||||
KNOWN_MODEL_TYPES = ["stable-diffusion", "vae", "hypernetwork", "gfpgan", "realesrgan", "lora"]
|
KNOWN_MODEL_TYPES = [
|
||||||
|
"stable-diffusion",
|
||||||
|
"vae",
|
||||||
|
"hypernetwork",
|
||||||
|
"gfpgan",
|
||||||
|
"realesrgan",
|
||||||
|
"lora",
|
||||||
|
]
|
||||||
MODEL_EXTENSIONS = {
|
MODEL_EXTENSIONS = {
|
||||||
"stable-diffusion": [".ckpt", ".safetensors"],
|
"stable-diffusion": [".ckpt", ".safetensors"],
|
||||||
"vae": [".vae.pt", ".ckpt", ".safetensors"],
|
"vae": [".vae.pt", ".ckpt", ".safetensors"],
|
||||||
@ -44,13 +50,15 @@ def load_default_models(context: Context):
|
|||||||
load_model(
|
load_model(
|
||||||
context,
|
context,
|
||||||
model_type,
|
model_type,
|
||||||
scan_model = context.model_paths[model_type] != None and not context.model_paths[model_type].endswith('.safetensors')
|
scan_model=context.model_paths[model_type] != None
|
||||||
|
and not context.model_paths[model_type].endswith(".safetensors"),
|
||||||
)
|
)
|
||||||
except Exception as e:
|
except Exception as e:
|
||||||
log.error(f"[red]Error while loading {model_type} model: {context.model_paths[model_type]}[/red]")
|
log.error(f"[red]Error while loading {model_type} model: {context.model_paths[model_type]}[/red]")
|
||||||
log.exception(e)
|
log.exception(e)
|
||||||
del context.model_paths[model_type]
|
del context.model_paths[model_type]
|
||||||
|
|
||||||
|
|
||||||
def unload_all(context: Context):
|
def unload_all(context: Context):
|
||||||
for model_type in KNOWN_MODEL_TYPES:
|
for model_type in KNOWN_MODEL_TYPES:
|
||||||
unload_model(context, model_type)
|
unload_model(context, model_type)
|
||||||
@ -170,13 +178,23 @@ def is_malicious_model(file_path):
|
|||||||
if scan_result.issues_count > 0 or scan_result.infected_files > 0:
|
if scan_result.issues_count > 0 or scan_result.infected_files > 0:
|
||||||
log.warn(
|
log.warn(
|
||||||
":warning: [bold red]Scan %s: %d scanned, %d issue, %d infected.[/bold red]"
|
":warning: [bold red]Scan %s: %d scanned, %d issue, %d infected.[/bold red]"
|
||||||
% (file_path, scan_result.scanned_files, scan_result.issues_count, scan_result.infected_files)
|
% (
|
||||||
|
file_path,
|
||||||
|
scan_result.scanned_files,
|
||||||
|
scan_result.issues_count,
|
||||||
|
scan_result.infected_files,
|
||||||
|
)
|
||||||
)
|
)
|
||||||
return True
|
return True
|
||||||
else:
|
else:
|
||||||
log.debug(
|
log.debug(
|
||||||
"Scan %s: [green]%d scanned, %d issue, %d infected.[/green]"
|
"Scan %s: [green]%d scanned, %d issue, %d infected.[/green]"
|
||||||
% (file_path, scan_result.scanned_files, scan_result.issues_count, scan_result.infected_files)
|
% (
|
||||||
|
file_path,
|
||||||
|
scan_result.scanned_files,
|
||||||
|
scan_result.issues_count,
|
||||||
|
scan_result.infected_files,
|
||||||
|
)
|
||||||
)
|
)
|
||||||
return False
|
return False
|
||||||
except Exception as e:
|
except Exception as e:
|
||||||
@ -204,13 +222,13 @@ def getModels():
|
|||||||
|
|
||||||
class MaliciousModelException(Exception):
|
class MaliciousModelException(Exception):
|
||||||
"Raised when picklescan reports a problem with a model"
|
"Raised when picklescan reports a problem with a model"
|
||||||
pass
|
|
||||||
|
|
||||||
def scan_directory(directory, suffixes, directoriesFirst: bool = True):
|
def scan_directory(directory, suffixes, directoriesFirst: bool = True):
|
||||||
nonlocal models_scanned
|
nonlocal models_scanned
|
||||||
tree = []
|
tree = []
|
||||||
for entry in sorted(
|
for entry in sorted(
|
||||||
os.scandir(directory), key=lambda entry: (entry.is_file() == directoriesFirst, entry.name.lower())
|
os.scandir(directory),
|
||||||
|
key=lambda entry: (entry.is_file() == directoriesFirst, entry.name.lower()),
|
||||||
):
|
):
|
||||||
if entry.is_file():
|
if entry.is_file():
|
||||||
matching_suffix = list(filter(lambda s: entry.name.endswith(s), suffixes))
|
matching_suffix = list(filter(lambda s: entry.name.endswith(s), suffixes))
|
||||||
|
@ -1,21 +1,22 @@
|
|||||||
import queue
|
|
||||||
import time
|
|
||||||
import json
|
import json
|
||||||
import pprint
|
import pprint
|
||||||
|
import queue
|
||||||
|
import time
|
||||||
|
|
||||||
from easydiffusion import device_manager
|
from easydiffusion import device_manager
|
||||||
from easydiffusion.types import TaskData, Response, Image as ResponseImage, UserInitiatedStop, GenerateImageRequest
|
from easydiffusion.types import GenerateImageRequest
|
||||||
from easydiffusion.utils import get_printable_request, save_images_to_disk, log
|
from easydiffusion.types import Image as ResponseImage
|
||||||
|
from easydiffusion.types import Response, TaskData, UserInitiatedStop
|
||||||
|
from easydiffusion.utils import get_printable_request, log, save_images_to_disk
|
||||||
from sdkit import Context
|
from sdkit import Context
|
||||||
from sdkit.generate import generate_images
|
|
||||||
from sdkit.filter import apply_filters
|
from sdkit.filter import apply_filters
|
||||||
|
from sdkit.generate import generate_images
|
||||||
from sdkit.utils import (
|
from sdkit.utils import (
|
||||||
img_to_buffer,
|
|
||||||
img_to_base64_str,
|
|
||||||
latent_samples_to_images,
|
|
||||||
diffusers_latent_samples_to_images,
|
diffusers_latent_samples_to_images,
|
||||||
gc,
|
gc,
|
||||||
|
img_to_base64_str,
|
||||||
|
img_to_buffer,
|
||||||
|
latent_samples_to_images,
|
||||||
)
|
)
|
||||||
|
|
||||||
context = Context() # thread-local
|
context = Context() # thread-local
|
||||||
@ -43,14 +44,22 @@ def init(device):
|
|||||||
|
|
||||||
|
|
||||||
def make_images(
|
def make_images(
|
||||||
req: GenerateImageRequest, task_data: TaskData, data_queue: queue.Queue, task_temp_images: list, step_callback
|
req: GenerateImageRequest,
|
||||||
|
task_data: TaskData,
|
||||||
|
data_queue: queue.Queue,
|
||||||
|
task_temp_images: list,
|
||||||
|
step_callback,
|
||||||
):
|
):
|
||||||
context.stop_processing = False
|
context.stop_processing = False
|
||||||
print_task_info(req, task_data)
|
print_task_info(req, task_data)
|
||||||
|
|
||||||
images, seeds = make_images_internal(req, task_data, data_queue, task_temp_images, step_callback)
|
images, seeds = make_images_internal(req, task_data, data_queue, task_temp_images, step_callback)
|
||||||
|
|
||||||
res = Response(req, task_data, images=construct_response(images, seeds, task_data, base_seed=req.seed))
|
res = Response(
|
||||||
|
req,
|
||||||
|
task_data,
|
||||||
|
images=construct_response(images, seeds, task_data, base_seed=req.seed),
|
||||||
|
)
|
||||||
res = res.json()
|
res = res.json()
|
||||||
data_queue.put(json.dumps(res))
|
data_queue.put(json.dumps(res))
|
||||||
log.info("Task completed")
|
log.info("Task completed")
|
||||||
@ -66,7 +75,11 @@ def print_task_info(req: GenerateImageRequest, task_data: TaskData):
|
|||||||
|
|
||||||
|
|
||||||
def make_images_internal(
|
def make_images_internal(
|
||||||
req: GenerateImageRequest, task_data: TaskData, data_queue: queue.Queue, task_temp_images: list, step_callback
|
req: GenerateImageRequest,
|
||||||
|
task_data: TaskData,
|
||||||
|
data_queue: queue.Queue,
|
||||||
|
task_temp_images: list,
|
||||||
|
step_callback,
|
||||||
):
|
):
|
||||||
images, user_stopped = generate_images_internal(
|
images, user_stopped = generate_images_internal(
|
||||||
req,
|
req,
|
||||||
@ -155,7 +168,12 @@ def filter_images(task_data: TaskData, images: list, user_stopped):
|
|||||||
def construct_response(images: list, seeds: list, task_data: TaskData, base_seed: int):
|
def construct_response(images: list, seeds: list, task_data: TaskData, base_seed: int):
|
||||||
return [
|
return [
|
||||||
ResponseImage(
|
ResponseImage(
|
||||||
data=img_to_base64_str(img, task_data.output_format, task_data.output_quality, task_data.output_lossless),
|
data=img_to_base64_str(
|
||||||
|
img,
|
||||||
|
task_data.output_format,
|
||||||
|
task_data.output_quality,
|
||||||
|
task_data.output_lossless,
|
||||||
|
),
|
||||||
seed=seed,
|
seed=seed,
|
||||||
)
|
)
|
||||||
for img, seed in zip(images, seeds)
|
for img, seed in zip(images, seeds)
|
||||||
|
@ -2,28 +2,30 @@
|
|||||||
Notes:
|
Notes:
|
||||||
async endpoints always run on the main thread. Without they run on the thread pool.
|
async endpoints always run on the main thread. Without they run on the thread pool.
|
||||||
"""
|
"""
|
||||||
|
import datetime
|
||||||
|
import mimetypes
|
||||||
import os
|
import os
|
||||||
import traceback
|
import traceback
|
||||||
import datetime
|
|
||||||
from typing import List, Union
|
from typing import List, Union
|
||||||
|
|
||||||
|
from easydiffusion import app, model_manager, task_manager
|
||||||
|
from easydiffusion.types import GenerateImageRequest, MergeRequest, TaskData
|
||||||
|
from easydiffusion.utils import log
|
||||||
from fastapi import FastAPI, HTTPException
|
from fastapi import FastAPI, HTTPException
|
||||||
from fastapi.staticfiles import StaticFiles
|
from fastapi.staticfiles import StaticFiles
|
||||||
from starlette.responses import FileResponse, JSONResponse, StreamingResponse
|
|
||||||
from pydantic import BaseModel, Extra
|
from pydantic import BaseModel, Extra
|
||||||
|
from starlette.responses import FileResponse, JSONResponse, StreamingResponse
|
||||||
from easydiffusion import app, model_manager, task_manager
|
|
||||||
from easydiffusion.types import TaskData, GenerateImageRequest, MergeRequest
|
|
||||||
from easydiffusion.utils import log
|
|
||||||
|
|
||||||
import mimetypes
|
|
||||||
|
|
||||||
log.info(f"started in {app.SD_DIR}")
|
log.info(f"started in {app.SD_DIR}")
|
||||||
log.info(f"started at {datetime.datetime.now():%x %X}")
|
log.info(f"started at {datetime.datetime.now():%x %X}")
|
||||||
|
|
||||||
server_api = FastAPI()
|
server_api = FastAPI()
|
||||||
|
|
||||||
NOCACHE_HEADERS = {"Cache-Control": "no-cache, no-store, must-revalidate", "Pragma": "no-cache", "Expires": "0"}
|
NOCACHE_HEADERS = {
|
||||||
|
"Cache-Control": "no-cache, no-store, must-revalidate",
|
||||||
|
"Pragma": "no-cache",
|
||||||
|
"Expires": "0",
|
||||||
|
}
|
||||||
|
|
||||||
|
|
||||||
class NoCacheStaticFiles(StaticFiles):
|
class NoCacheStaticFiles(StaticFiles):
|
||||||
@ -65,11 +67,17 @@ def init():
|
|||||||
name="custom-thumbnails",
|
name="custom-thumbnails",
|
||||||
)
|
)
|
||||||
|
|
||||||
server_api.mount("/media", NoCacheStaticFiles(directory=os.path.join(app.SD_UI_DIR, "media")), name="media")
|
server_api.mount(
|
||||||
|
"/media",
|
||||||
|
NoCacheStaticFiles(directory=os.path.join(app.SD_UI_DIR, "media")),
|
||||||
|
name="media",
|
||||||
|
)
|
||||||
|
|
||||||
for plugins_dir, dir_prefix in app.UI_PLUGINS_SOURCES:
|
for plugins_dir, dir_prefix in app.UI_PLUGINS_SOURCES:
|
||||||
server_api.mount(
|
server_api.mount(
|
||||||
f"/plugins/{dir_prefix}", NoCacheStaticFiles(directory=plugins_dir), name=f"plugins-{dir_prefix}"
|
f"/plugins/{dir_prefix}",
|
||||||
|
NoCacheStaticFiles(directory=plugins_dir),
|
||||||
|
name=f"plugins-{dir_prefix}",
|
||||||
)
|
)
|
||||||
|
|
||||||
@server_api.post("/app_config")
|
@server_api.post("/app_config")
|
||||||
@ -246,8 +254,8 @@ def render_internal(req: dict):
|
|||||||
|
|
||||||
def model_merge_internal(req: dict):
|
def model_merge_internal(req: dict):
|
||||||
try:
|
try:
|
||||||
from sdkit.train import merge_models
|
|
||||||
from easydiffusion.utils.save_utils import filename_regex
|
from easydiffusion.utils.save_utils import filename_regex
|
||||||
|
from sdkit.train import merge_models
|
||||||
|
|
||||||
mergeReq: MergeRequest = MergeRequest.parse_obj(req)
|
mergeReq: MergeRequest = MergeRequest.parse_obj(req)
|
||||||
|
|
||||||
@ -255,7 +263,11 @@ def model_merge_internal(req: dict):
|
|||||||
model_manager.resolve_model_to_use(mergeReq.model0, "stable-diffusion"),
|
model_manager.resolve_model_to_use(mergeReq.model0, "stable-diffusion"),
|
||||||
model_manager.resolve_model_to_use(mergeReq.model1, "stable-diffusion"),
|
model_manager.resolve_model_to_use(mergeReq.model1, "stable-diffusion"),
|
||||||
mergeReq.ratio,
|
mergeReq.ratio,
|
||||||
os.path.join(app.MODELS_DIR, "stable-diffusion", filename_regex.sub("_", mergeReq.out_path)),
|
os.path.join(
|
||||||
|
app.MODELS_DIR,
|
||||||
|
"stable-diffusion",
|
||||||
|
filename_regex.sub("_", mergeReq.out_path),
|
||||||
|
),
|
||||||
mergeReq.use_fp16,
|
mergeReq.use_fp16,
|
||||||
)
|
)
|
||||||
return JSONResponse({"status": "OK"}, headers=NOCACHE_HEADERS)
|
return JSONResponse({"status": "OK"}, headers=NOCACHE_HEADERS)
|
||||||
|
@ -9,14 +9,16 @@ import traceback
|
|||||||
|
|
||||||
TASK_TTL = 15 * 60 # seconds, Discard last session's task timeout
|
TASK_TTL = 15 * 60 # seconds, Discard last session's task timeout
|
||||||
|
|
||||||
import torch
|
import queue
|
||||||
import queue, threading, time, weakref
|
import threading
|
||||||
|
import time
|
||||||
|
import weakref
|
||||||
from typing import Any, Hashable
|
from typing import Any, Hashable
|
||||||
|
|
||||||
|
import torch
|
||||||
from easydiffusion import device_manager
|
from easydiffusion import device_manager
|
||||||
from easydiffusion.types import TaskData, GenerateImageRequest
|
from easydiffusion.types import GenerateImageRequest, TaskData
|
||||||
from easydiffusion.utils import log
|
from easydiffusion.utils import log
|
||||||
|
|
||||||
from sdkit.utils import gc
|
from sdkit.utils import gc
|
||||||
|
|
||||||
THREAD_NAME_PREFIX = ""
|
THREAD_NAME_PREFIX = ""
|
||||||
@ -167,7 +169,7 @@ class DataCache:
|
|||||||
raise Exception("DataCache.put" + ERR_LOCK_FAILED)
|
raise Exception("DataCache.put" + ERR_LOCK_FAILED)
|
||||||
try:
|
try:
|
||||||
self._base[key] = (self._get_ttl_time(ttl), value)
|
self._base[key] = (self._get_ttl_time(ttl), value)
|
||||||
except Exception as e:
|
except Exception:
|
||||||
log.error(traceback.format_exc())
|
log.error(traceback.format_exc())
|
||||||
return False
|
return False
|
||||||
else:
|
else:
|
||||||
@ -264,7 +266,7 @@ def thread_get_next_task():
|
|||||||
def thread_render(device):
|
def thread_render(device):
|
||||||
global current_state, current_state_error
|
global current_state, current_state_error
|
||||||
|
|
||||||
from easydiffusion import renderer, model_manager
|
from easydiffusion import model_manager, renderer
|
||||||
|
|
||||||
try:
|
try:
|
||||||
renderer.init(device)
|
renderer.init(device)
|
||||||
@ -337,7 +339,11 @@ def thread_render(device):
|
|||||||
|
|
||||||
current_state = ServerStates.Rendering
|
current_state = ServerStates.Rendering
|
||||||
task.response = renderer.make_images(
|
task.response = renderer.make_images(
|
||||||
task.render_request, task.task_data, task.buffer_queue, task.temp_images, step_callback
|
task.render_request,
|
||||||
|
task.task_data,
|
||||||
|
task.buffer_queue,
|
||||||
|
task.temp_images,
|
||||||
|
step_callback,
|
||||||
)
|
)
|
||||||
# Before looping back to the generator, mark cache as still alive.
|
# Before looping back to the generator, mark cache as still alive.
|
||||||
task_cache.keep(id(task), TASK_TTL)
|
task_cache.keep(id(task), TASK_TTL)
|
||||||
@ -392,8 +398,8 @@ def get_devices():
|
|||||||
return {"name": device_manager.get_processor_name()}
|
return {"name": device_manager.get_processor_name()}
|
||||||
|
|
||||||
mem_free, mem_total = torch.cuda.mem_get_info(device)
|
mem_free, mem_total = torch.cuda.mem_get_info(device)
|
||||||
mem_free /= float(10**9)
|
mem_free /= float(10 ** 9)
|
||||||
mem_total /= float(10**9)
|
mem_total /= float(10 ** 9)
|
||||||
|
|
||||||
return {
|
return {
|
||||||
"name": torch.cuda.get_device_name(device),
|
"name": torch.cuda.get_device_name(device),
|
||||||
|
@ -1,6 +1,7 @@
|
|||||||
from pydantic import BaseModel
|
|
||||||
from typing import Any
|
from typing import Any
|
||||||
|
|
||||||
|
from pydantic import BaseModel
|
||||||
|
|
||||||
|
|
||||||
class GenerateImageRequest(BaseModel):
|
class GenerateImageRequest(BaseModel):
|
||||||
prompt: str = ""
|
prompt: str = ""
|
||||||
|
@ -1,14 +1,13 @@
|
|||||||
import os
|
import os
|
||||||
import time
|
|
||||||
import re
|
import re
|
||||||
|
import time
|
||||||
|
from datetime import datetime
|
||||||
|
from functools import reduce
|
||||||
|
|
||||||
from easydiffusion import app
|
from easydiffusion import app
|
||||||
from easydiffusion.types import TaskData, GenerateImageRequest
|
from easydiffusion.types import GenerateImageRequest, TaskData
|
||||||
from functools import reduce
|
|
||||||
from datetime import datetime
|
|
||||||
|
|
||||||
from sdkit.utils import save_images, save_dicts
|
|
||||||
from numpy import base_repr
|
from numpy import base_repr
|
||||||
|
from sdkit.utils import save_dicts, save_images
|
||||||
|
|
||||||
filename_regex = re.compile("[^a-zA-Z0-9._-]")
|
filename_regex = re.compile("[^a-zA-Z0-9._-]")
|
||||||
img_number_regex = re.compile("([0-9]{5,})")
|
img_number_regex = re.compile("([0-9]{5,})")
|
||||||
@ -50,6 +49,7 @@ other_placeholders = {
|
|||||||
"$s": lambda req, task_data: str(req.seed),
|
"$s": lambda req, task_data: str(req.seed),
|
||||||
}
|
}
|
||||||
|
|
||||||
|
|
||||||
class ImageNumber:
|
class ImageNumber:
|
||||||
_factory = None
|
_factory = None
|
||||||
_evaluated = False
|
_evaluated = False
|
||||||
@ -57,12 +57,14 @@ class ImageNumber:
|
|||||||
def __init__(self, factory):
|
def __init__(self, factory):
|
||||||
self._factory = factory
|
self._factory = factory
|
||||||
self._evaluated = None
|
self._evaluated = None
|
||||||
|
|
||||||
def __call__(self) -> int:
|
def __call__(self) -> int:
|
||||||
if self._evaluated is None:
|
if self._evaluated is None:
|
||||||
self._evaluated = self._factory()
|
self._evaluated = self._factory()
|
||||||
return self._evaluated
|
return self._evaluated
|
||||||
|
|
||||||
def format_placeholders(format: str, req: GenerateImageRequest, task_data: TaskData, now = None):
|
|
||||||
|
def format_placeholders(format: str, req: GenerateImageRequest, task_data: TaskData, now=None):
|
||||||
if now is None:
|
if now is None:
|
||||||
now = time.time()
|
now = time.time()
|
||||||
|
|
||||||
@ -75,10 +77,12 @@ def format_placeholders(format: str, req: GenerateImageRequest, task_data: TaskD
|
|||||||
|
|
||||||
return format
|
return format
|
||||||
|
|
||||||
|
|
||||||
def format_folder_name(format: str, req: GenerateImageRequest, task_data: TaskData):
|
def format_folder_name(format: str, req: GenerateImageRequest, task_data: TaskData):
|
||||||
format = format_placeholders(format, req, task_data)
|
format = format_placeholders(format, req, task_data)
|
||||||
return filename_regex.sub("_", format)
|
return filename_regex.sub("_", format)
|
||||||
|
|
||||||
|
|
||||||
def format_file_name(
|
def format_file_name(
|
||||||
format: str,
|
format: str,
|
||||||
req: GenerateImageRequest,
|
req: GenerateImageRequest,
|
||||||
@ -88,19 +92,22 @@ def format_file_name(
|
|||||||
folder_img_number: ImageNumber,
|
folder_img_number: ImageNumber,
|
||||||
):
|
):
|
||||||
format = format_placeholders(format, req, task_data, now)
|
format = format_placeholders(format, req, task_data, now)
|
||||||
|
|
||||||
if "$n" in format:
|
if "$n" in format:
|
||||||
format = format.replace("$n", f"{folder_img_number():05}")
|
format = format.replace("$n", f"{folder_img_number():05}")
|
||||||
|
|
||||||
if "$tsb64" in format:
|
if "$tsb64" in format:
|
||||||
img_id = base_repr(int(now * 10000), 36)[-7:] + base_repr(int(batch_file_number), 36) # Base 36 conversion, 0-9, A-Z
|
img_id = base_repr(int(now * 10000), 36)[-7:] + base_repr(
|
||||||
|
int(batch_file_number), 36
|
||||||
|
) # Base 36 conversion, 0-9, A-Z
|
||||||
format = format.replace("$tsb64", img_id)
|
format = format.replace("$tsb64", img_id)
|
||||||
|
|
||||||
if "$ts" in format:
|
if "$ts" in format:
|
||||||
format = format.replace("$ts", str(int(now * 1000) + batch_file_number))
|
format = format.replace("$ts", str(int(now * 1000) + batch_file_number))
|
||||||
|
|
||||||
return filename_regex.sub("_", format)
|
return filename_regex.sub("_", format)
|
||||||
|
|
||||||
|
|
||||||
def save_images_to_disk(images: list, filtered_images: list, req: GenerateImageRequest, task_data: TaskData):
|
def save_images_to_disk(images: list, filtered_images: list, req: GenerateImageRequest, task_data: TaskData):
|
||||||
now = time.time()
|
now = time.time()
|
||||||
app_config = app.getConfig()
|
app_config = app.getConfig()
|
||||||
@ -126,7 +133,7 @@ def save_images_to_disk(images: list, filtered_images: list, req: GenerateImageR
|
|||||||
output_lossless=task_data.output_lossless,
|
output_lossless=task_data.output_lossless,
|
||||||
)
|
)
|
||||||
if task_data.metadata_output_format:
|
if task_data.metadata_output_format:
|
||||||
for metadata_output_format in task_data.metadata_output_format.split(','):
|
for metadata_output_format in task_data.metadata_output_format.split(","):
|
||||||
if metadata_output_format.lower() in ["json", "txt", "embed"]:
|
if metadata_output_format.lower() in ["json", "txt", "embed"]:
|
||||||
save_dicts(
|
save_dicts(
|
||||||
metadata_entries,
|
metadata_entries,
|
||||||
@ -142,7 +149,8 @@ def save_images_to_disk(images: list, filtered_images: list, req: GenerateImageR
|
|||||||
task_data,
|
task_data,
|
||||||
file_number,
|
file_number,
|
||||||
now=now,
|
now=now,
|
||||||
suffix="filtered")
|
suffix="filtered",
|
||||||
|
)
|
||||||
|
|
||||||
save_images(
|
save_images(
|
||||||
images,
|
images,
|
||||||
@ -233,27 +241,28 @@ def make_filename_callback(
|
|||||||
|
|
||||||
return make_filename
|
return make_filename
|
||||||
|
|
||||||
|
|
||||||
def _calculate_img_number(save_dir_path: str, task_data: TaskData):
|
def _calculate_img_number(save_dir_path: str, task_data: TaskData):
|
||||||
def get_highest_img_number(accumulator: int, file: os.DirEntry) -> int:
|
def get_highest_img_number(accumulator: int, file: os.DirEntry) -> int:
|
||||||
if not file.is_file:
|
if not file.is_file:
|
||||||
return accumulator
|
return accumulator
|
||||||
|
|
||||||
if len(list(filter(lambda e: file.name.endswith(e), app.IMAGE_EXTENSIONS))) == 0:
|
if len(list(filter(lambda e: file.name.endswith(e), app.IMAGE_EXTENSIONS))) == 0:
|
||||||
return accumulator
|
return accumulator
|
||||||
|
|
||||||
get_highest_img_number.number_of_images = get_highest_img_number.number_of_images + 1
|
get_highest_img_number.number_of_images = get_highest_img_number.number_of_images + 1
|
||||||
|
|
||||||
number_match = img_number_regex.match(file.name)
|
number_match = img_number_regex.match(file.name)
|
||||||
if not number_match:
|
if not number_match:
|
||||||
return accumulator
|
return accumulator
|
||||||
|
|
||||||
file_number = number_match.group().lstrip('0')
|
file_number = number_match.group().lstrip("0")
|
||||||
|
|
||||||
# Handle 00000
|
# Handle 00000
|
||||||
return int(file_number) if file_number else 0
|
return int(file_number) if file_number else 0
|
||||||
|
|
||||||
get_highest_img_number.number_of_images = 0
|
get_highest_img_number.number_of_images = 0
|
||||||
|
|
||||||
highest_file_number = -1
|
highest_file_number = -1
|
||||||
|
|
||||||
if os.path.isdir(save_dir_path):
|
if os.path.isdir(save_dir_path):
|
||||||
@ -267,13 +276,15 @@ def _calculate_img_number(save_dir_path: str, task_data: TaskData):
|
|||||||
_calculate_img_number.session_img_numbers[task_data.session_id],
|
_calculate_img_number.session_img_numbers[task_data.session_id],
|
||||||
calculated_img_number,
|
calculated_img_number,
|
||||||
)
|
)
|
||||||
|
|
||||||
calculated_img_number = calculated_img_number + 1
|
calculated_img_number = calculated_img_number + 1
|
||||||
|
|
||||||
_calculate_img_number.session_img_numbers[task_data.session_id] = calculated_img_number
|
_calculate_img_number.session_img_numbers[task_data.session_id] = calculated_img_number
|
||||||
return calculated_img_number
|
return calculated_img_number
|
||||||
|
|
||||||
|
|
||||||
_calculate_img_number.session_img_numbers = {}
|
_calculate_img_number.session_img_numbers = {}
|
||||||
|
|
||||||
|
|
||||||
def calculate_img_number(save_dir_path: str, task_data: TaskData):
|
def calculate_img_number(save_dir_path: str, task_data: TaskData):
|
||||||
return ImageNumber(lambda: _calculate_img_number(save_dir_path, task_data))
|
return ImageNumber(lambda: _calculate_img_number(save_dir_path, task_data))
|
||||||
|
Loading…
Reference in New Issue
Block a user