Format code, PEP8 using Black

This commit is contained in:
cmdr2 2023-02-14 18:47:50 +05:30
parent 0ad08c609d
commit 2eb317c6b6
9 changed files with 728 additions and 475 deletions

View File

@ -7,7 +7,7 @@ import logging
import shlex import shlex
from rich.logging import RichHandler from rich.logging import RichHandler
from sdkit.utils import log as sdkit_log # hack, so we can overwrite the log config from sdkit.utils import log as sdkit_log # hack, so we can overwrite the log config
from easydiffusion import task_manager from easydiffusion import task_manager
from easydiffusion.utils import log from easydiffusion.utils import log
@ -16,83 +16,86 @@ from easydiffusion.utils import log
for handler in logging.root.handlers[:]: for handler in logging.root.handlers[:]:
logging.root.removeHandler(handler) logging.root.removeHandler(handler)
LOG_FORMAT = '%(asctime)s.%(msecs)03d %(levelname)s %(threadName)s %(message)s' LOG_FORMAT = "%(asctime)s.%(msecs)03d %(levelname)s %(threadName)s %(message)s"
logging.basicConfig( logging.basicConfig(
level=logging.INFO, level=logging.INFO,
format=LOG_FORMAT, format=LOG_FORMAT,
datefmt="%X", datefmt="%X",
handlers=[RichHandler(markup=True, rich_tracebacks=False, show_time=False, show_level=False)], handlers=[RichHandler(markup=True, rich_tracebacks=False, show_time=False, show_level=False)],
) )
SD_DIR = os.getcwd() SD_DIR = os.getcwd()
SD_UI_DIR = os.getenv('SD_UI_PATH', None) SD_UI_DIR = os.getenv("SD_UI_PATH", None)
sys.path.append(os.path.dirname(SD_UI_DIR)) sys.path.append(os.path.dirname(SD_UI_DIR))
CONFIG_DIR = os.path.abspath(os.path.join(SD_UI_DIR, '..', 'scripts')) CONFIG_DIR = os.path.abspath(os.path.join(SD_UI_DIR, "..", "scripts"))
MODELS_DIR = os.path.abspath(os.path.join(SD_DIR, '..', 'models')) MODELS_DIR = os.path.abspath(os.path.join(SD_DIR, "..", "models"))
USER_UI_PLUGINS_DIR = os.path.abspath(os.path.join(SD_DIR, '..', 'plugins', 'ui')) USER_UI_PLUGINS_DIR = os.path.abspath(os.path.join(SD_DIR, "..", "plugins", "ui"))
CORE_UI_PLUGINS_DIR = os.path.abspath(os.path.join(SD_UI_DIR, 'plugins', 'ui')) CORE_UI_PLUGINS_DIR = os.path.abspath(os.path.join(SD_UI_DIR, "plugins", "ui"))
UI_PLUGINS_SOURCES = ((CORE_UI_PLUGINS_DIR, 'core'), (USER_UI_PLUGINS_DIR, 'user')) UI_PLUGINS_SOURCES = ((CORE_UI_PLUGINS_DIR, "core"), (USER_UI_PLUGINS_DIR, "user"))
OUTPUT_DIRNAME = "Stable Diffusion UI" # in the user's home folder OUTPUT_DIRNAME = "Stable Diffusion UI" # in the user's home folder
PRESERVE_CONFIG_VARS = ['FORCE_FULL_PRECISION'] PRESERVE_CONFIG_VARS = ["FORCE_FULL_PRECISION"]
TASK_TTL = 15 * 60 # Discard last session's task timeout TASK_TTL = 15 * 60 # Discard last session's task timeout
APP_CONFIG_DEFAULTS = { APP_CONFIG_DEFAULTS = {
# auto: selects the cuda device with the most free memory, cuda: use the currently active cuda device. # auto: selects the cuda device with the most free memory, cuda: use the currently active cuda device.
'render_devices': 'auto', # valid entries: 'auto', 'cpu' or 'cuda:N' (where N is a GPU index) "render_devices": "auto", # valid entries: 'auto', 'cpu' or 'cuda:N' (where N is a GPU index)
'update_branch': 'main', "update_branch": "main",
'ui': { "ui": {
'open_browser_on_start': True, "open_browser_on_start": True,
} },
} }
def init(): def init():
os.makedirs(USER_UI_PLUGINS_DIR, exist_ok=True) os.makedirs(USER_UI_PLUGINS_DIR, exist_ok=True)
update_render_threads() update_render_threads()
def getConfig(default_val=APP_CONFIG_DEFAULTS): def getConfig(default_val=APP_CONFIG_DEFAULTS):
try: try:
config_json_path = os.path.join(CONFIG_DIR, 'config.json') config_json_path = os.path.join(CONFIG_DIR, "config.json")
if not os.path.exists(config_json_path): if not os.path.exists(config_json_path):
config = default_val config = default_val
else: else:
with open(config_json_path, 'r', encoding='utf-8') as f: with open(config_json_path, "r", encoding="utf-8") as f:
config = json.load(f) config = json.load(f)
if 'net' not in config: if "net" not in config:
config['net'] = {} config["net"] = {}
if os.getenv('SD_UI_BIND_PORT') is not None: if os.getenv("SD_UI_BIND_PORT") is not None:
config['net']['listen_port'] = int(os.getenv('SD_UI_BIND_PORT')) config["net"]["listen_port"] = int(os.getenv("SD_UI_BIND_PORT"))
else: else:
config['net']['listen_port'] = 9000 config["net"]["listen_port"] = 9000
if os.getenv('SD_UI_BIND_IP') is not None: if os.getenv("SD_UI_BIND_IP") is not None:
config['net']['listen_to_network'] = (os.getenv('SD_UI_BIND_IP') == '0.0.0.0') config["net"]["listen_to_network"] = os.getenv("SD_UI_BIND_IP") == "0.0.0.0"
else: else:
config['net']['listen_to_network'] = True config["net"]["listen_to_network"] = True
return config return config
except Exception as e: except Exception as e:
log.warn(traceback.format_exc()) log.warn(traceback.format_exc())
return default_val return default_val
def setConfig(config): def setConfig(config):
try: # config.json try: # config.json
config_json_path = os.path.join(CONFIG_DIR, 'config.json') config_json_path = os.path.join(CONFIG_DIR, "config.json")
with open(config_json_path, 'w', encoding='utf-8') as f: with open(config_json_path, "w", encoding="utf-8") as f:
json.dump(config, f) json.dump(config, f)
except: except:
log.error(traceback.format_exc()) log.error(traceback.format_exc())
try: # config.bat try: # config.bat
config_bat_path = os.path.join(CONFIG_DIR, 'config.bat') config_bat_path = os.path.join(CONFIG_DIR, "config.bat")
config_bat = [] config_bat = []
if 'update_branch' in config: if "update_branch" in config:
config_bat.append(f"@set update_branch={config['update_branch']}") config_bat.append(f"@set update_branch={config['update_branch']}")
config_bat.append(f"@set SD_UI_BIND_PORT={config['net']['listen_port']}") config_bat.append(f"@set SD_UI_BIND_PORT={config['net']['listen_port']}")
bind_ip = '0.0.0.0' if config['net']['listen_to_network'] else '127.0.0.1' bind_ip = "0.0.0.0" if config["net"]["listen_to_network"] else "127.0.0.1"
config_bat.append(f"@set SD_UI_BIND_IP={bind_ip}") config_bat.append(f"@set SD_UI_BIND_IP={bind_ip}")
# Preserve these variables if they are set # Preserve these variables if they are set
@ -101,20 +104,20 @@ def setConfig(config):
config_bat.append(f"@set {var}={os.getenv(var)}") config_bat.append(f"@set {var}={os.getenv(var)}")
if len(config_bat) > 0: if len(config_bat) > 0:
with open(config_bat_path, 'w', encoding='utf-8') as f: with open(config_bat_path, "w", encoding="utf-8") as f:
f.write('\r\n'.join(config_bat)) f.write("\r\n".join(config_bat))
except: except:
log.error(traceback.format_exc()) log.error(traceback.format_exc())
try: # config.sh try: # config.sh
config_sh_path = os.path.join(CONFIG_DIR, 'config.sh') config_sh_path = os.path.join(CONFIG_DIR, "config.sh")
config_sh = ['#!/bin/bash'] config_sh = ["#!/bin/bash"]
if 'update_branch' in config: if "update_branch" in config:
config_sh.append(f"export update_branch={config['update_branch']}") config_sh.append(f"export update_branch={config['update_branch']}")
config_sh.append(f"export SD_UI_BIND_PORT={config['net']['listen_port']}") config_sh.append(f"export SD_UI_BIND_PORT={config['net']['listen_port']}")
bind_ip = '0.0.0.0' if config['net']['listen_to_network'] else '127.0.0.1' bind_ip = "0.0.0.0" if config["net"]["listen_to_network"] else "127.0.0.1"
config_sh.append(f"export SD_UI_BIND_IP={bind_ip}") config_sh.append(f"export SD_UI_BIND_IP={bind_ip}")
# Preserve these variables if they are set # Preserve these variables if they are set
@ -123,47 +126,51 @@ def setConfig(config):
config_bat.append(f'export {var}="{shlex.quote(os.getenv(var))}"') config_bat.append(f'export {var}="{shlex.quote(os.getenv(var))}"')
if len(config_sh) > 1: if len(config_sh) > 1:
with open(config_sh_path, 'w', encoding='utf-8') as f: with open(config_sh_path, "w", encoding="utf-8") as f:
f.write('\n'.join(config_sh)) f.write("\n".join(config_sh))
except: except:
log.error(traceback.format_exc()) log.error(traceback.format_exc())
def save_to_config(ckpt_model_name, vae_model_name, hypernetwork_model_name, vram_usage_level): def save_to_config(ckpt_model_name, vae_model_name, hypernetwork_model_name, vram_usage_level):
config = getConfig() config = getConfig()
if 'model' not in config: if "model" not in config:
config['model'] = {} config["model"] = {}
config['model']['stable-diffusion'] = ckpt_model_name config["model"]["stable-diffusion"] = ckpt_model_name
config['model']['vae'] = vae_model_name config["model"]["vae"] = vae_model_name
config['model']['hypernetwork'] = hypernetwork_model_name config["model"]["hypernetwork"] = hypernetwork_model_name
if vae_model_name is None or vae_model_name == "": if vae_model_name is None or vae_model_name == "":
del config['model']['vae'] del config["model"]["vae"]
if hypernetwork_model_name is None or hypernetwork_model_name == "": if hypernetwork_model_name is None or hypernetwork_model_name == "":
del config['model']['hypernetwork'] del config["model"]["hypernetwork"]
config['vram_usage_level'] = vram_usage_level config["vram_usage_level"] = vram_usage_level
setConfig(config) setConfig(config)
def update_render_threads(): def update_render_threads():
config = getConfig() config = getConfig()
render_devices = config.get('render_devices', 'auto') render_devices = config.get("render_devices", "auto")
active_devices = task_manager.get_devices()['active'].keys() active_devices = task_manager.get_devices()["active"].keys()
log.debug(f'requesting for render_devices: {render_devices}') log.debug(f"requesting for render_devices: {render_devices}")
task_manager.update_render_threads(render_devices, active_devices) task_manager.update_render_threads(render_devices, active_devices)
def getUIPlugins(): def getUIPlugins():
plugins = [] plugins = []
for plugins_dir, dir_prefix in UI_PLUGINS_SOURCES: for plugins_dir, dir_prefix in UI_PLUGINS_SOURCES:
for file in os.listdir(plugins_dir): for file in os.listdir(plugins_dir):
if file.endswith('.plugin.js'): if file.endswith(".plugin.js"):
plugins.append(f'/plugins/{dir_prefix}/{file}') plugins.append(f"/plugins/{dir_prefix}/{file}")
return plugins return plugins
def getIPConfig(): def getIPConfig():
try: try:
ips = socket.gethostbyname_ex(socket.gethostname()) ips = socket.gethostbyname_ex(socket.gethostname())
@ -173,10 +180,13 @@ def getIPConfig():
log.exception(e) log.exception(e)
return [] return []
def open_browser(): def open_browser():
config = getConfig() config = getConfig()
ui = config.get('ui', {}) ui = config.get("ui", {})
net = config.get('net', {'listen_port':9000}) net = config.get("net", {"listen_port": 9000})
port = net.get('listen_port', 9000) port = net.get("listen_port", 9000)
if ui.get('open_browser_on_start', True): if ui.get("open_browser_on_start", True):
import webbrowser; webbrowser.open(f"http://localhost:{port}") import webbrowser
webbrowser.open(f"http://localhost:{port}")

View File

@ -5,45 +5,54 @@ import re
from easydiffusion.utils import log from easydiffusion.utils import log
''' """
Set `FORCE_FULL_PRECISION` in the environment variables, or in `config.bat`/`config.sh` to set full precision (i.e. float32). Set `FORCE_FULL_PRECISION` in the environment variables, or in `config.bat`/`config.sh` to set full precision (i.e. float32).
Otherwise the models will load at half-precision (i.e. float16). Otherwise the models will load at half-precision (i.e. float16).
Half-precision is fine most of the time. Full precision is only needed for working around GPU bugs (like NVIDIA 16xx GPUs). Half-precision is fine most of the time. Full precision is only needed for working around GPU bugs (like NVIDIA 16xx GPUs).
''' """
COMPARABLE_GPU_PERCENTILE = 0.65 # if a GPU's free_mem is within this % of the GPU with the most free_mem, it will be picked COMPARABLE_GPU_PERCENTILE = (
0.65 # if a GPU's free_mem is within this % of the GPU with the most free_mem, it will be picked
)
mem_free_threshold = 0 mem_free_threshold = 0
def get_device_delta(render_devices, active_devices): def get_device_delta(render_devices, active_devices):
''' """
render_devices: 'cpu', or 'auto' or ['cuda:N'...] render_devices: 'cpu', or 'auto' or ['cuda:N'...]
active_devices: ['cpu', 'cuda:N'...] active_devices: ['cpu', 'cuda:N'...]
''' """
if render_devices in ('cpu', 'auto'): if render_devices in ("cpu", "auto"):
render_devices = [render_devices] render_devices = [render_devices]
elif render_devices is not None: elif render_devices is not None:
if isinstance(render_devices, str): if isinstance(render_devices, str):
render_devices = [render_devices] render_devices = [render_devices]
if isinstance(render_devices, list) and len(render_devices) > 0: if isinstance(render_devices, list) and len(render_devices) > 0:
render_devices = list(filter(lambda x: x.startswith('cuda:'), render_devices)) render_devices = list(filter(lambda x: x.startswith("cuda:"), render_devices))
if len(render_devices) == 0: if len(render_devices) == 0:
raise Exception('Invalid render_devices value in config.json. Valid: {"render_devices": ["cuda:0", "cuda:1"...]}, or {"render_devices": "cpu"} or {"render_devices": "auto"}') raise Exception(
'Invalid render_devices value in config.json. Valid: {"render_devices": ["cuda:0", "cuda:1"...]}, or {"render_devices": "cpu"} or {"render_devices": "auto"}'
)
render_devices = list(filter(lambda x: is_device_compatible(x), render_devices)) render_devices = list(filter(lambda x: is_device_compatible(x), render_devices))
if len(render_devices) == 0: if len(render_devices) == 0:
raise Exception('Sorry, none of the render_devices configured in config.json are compatible with Stable Diffusion') raise Exception(
"Sorry, none of the render_devices configured in config.json are compatible with Stable Diffusion"
)
else: else:
raise Exception('Invalid render_devices value in config.json. Valid: {"render_devices": ["cuda:0", "cuda:1"...]}, or {"render_devices": "cpu"} or {"render_devices": "auto"}') raise Exception(
'Invalid render_devices value in config.json. Valid: {"render_devices": ["cuda:0", "cuda:1"...]}, or {"render_devices": "cpu"} or {"render_devices": "auto"}'
)
else: else:
render_devices = ['auto'] render_devices = ["auto"]
if 'auto' in render_devices: if "auto" in render_devices:
render_devices = auto_pick_devices(active_devices) render_devices = auto_pick_devices(active_devices)
if 'cpu' in render_devices: if "cpu" in render_devices:
log.warn('WARNING: Could not find a compatible GPU. Using the CPU, but this will be very slow!') log.warn("WARNING: Could not find a compatible GPU. Using the CPU, but this will be very slow!")
active_devices = set(active_devices) active_devices = set(active_devices)
render_devices = set(render_devices) render_devices = set(render_devices)
@ -53,19 +62,21 @@ def get_device_delta(render_devices, active_devices):
return devices_to_start, devices_to_stop return devices_to_start, devices_to_stop
def auto_pick_devices(currently_active_devices): def auto_pick_devices(currently_active_devices):
global mem_free_threshold global mem_free_threshold
if not torch.cuda.is_available(): return ['cpu'] if not torch.cuda.is_available():
return ["cpu"]
device_count = torch.cuda.device_count() device_count = torch.cuda.device_count()
if device_count == 1: if device_count == 1:
return ['cuda:0'] if is_device_compatible('cuda:0') else ['cpu'] return ["cuda:0"] if is_device_compatible("cuda:0") else ["cpu"]
log.debug('Autoselecting GPU. Using most free memory.') log.debug("Autoselecting GPU. Using most free memory.")
devices = [] devices = []
for device in range(device_count): for device in range(device_count):
device = f'cuda:{device}' device = f"cuda:{device}"
if not is_device_compatible(device): if not is_device_compatible(device):
continue continue
@ -73,11 +84,13 @@ def auto_pick_devices(currently_active_devices):
mem_free /= float(10**9) mem_free /= float(10**9)
mem_total /= float(10**9) mem_total /= float(10**9)
device_name = torch.cuda.get_device_name(device) device_name = torch.cuda.get_device_name(device)
log.debug(f'{device} detected: {device_name} - Memory (free/total): {round(mem_free, 2)}Gb / {round(mem_total, 2)}Gb') log.debug(
devices.append({'device': device, 'device_name': device_name, 'mem_free': mem_free}) f"{device} detected: {device_name} - Memory (free/total): {round(mem_free, 2)}Gb / {round(mem_total, 2)}Gb"
)
devices.append({"device": device, "device_name": device_name, "mem_free": mem_free})
devices.sort(key=lambda x:x['mem_free'], reverse=True) devices.sort(key=lambda x: x["mem_free"], reverse=True)
max_mem_free = devices[0]['mem_free'] max_mem_free = devices[0]["mem_free"]
curr_mem_free_threshold = COMPARABLE_GPU_PERCENTILE * max_mem_free curr_mem_free_threshold = COMPARABLE_GPU_PERCENTILE * max_mem_free
mem_free_threshold = max(curr_mem_free_threshold, mem_free_threshold) mem_free_threshold = max(curr_mem_free_threshold, mem_free_threshold)
@ -87,23 +100,26 @@ def auto_pick_devices(currently_active_devices):
# always be very low (since their VRAM contains the model). # always be very low (since their VRAM contains the model).
# These already-running devices probably aren't terrible, since they were picked in the past. # These already-running devices probably aren't terrible, since they were picked in the past.
# Worst case, the user can restart the program and that'll get rid of them. # Worst case, the user can restart the program and that'll get rid of them.
devices = list(filter((lambda x: x['mem_free'] > mem_free_threshold or x['device'] in currently_active_devices), devices)) devices = list(
devices = list(map(lambda x: x['device'], devices)) filter((lambda x: x["mem_free"] > mem_free_threshold or x["device"] in currently_active_devices), devices)
)
devices = list(map(lambda x: x["device"], devices))
return devices return devices
def device_init(context, device): def device_init(context, device):
''' """
This function assumes the 'device' has already been verified to be compatible. This function assumes the 'device' has already been verified to be compatible.
`get_device_delta()` has already filtered out incompatible devices. `get_device_delta()` has already filtered out incompatible devices.
''' """
validate_device_id(device, log_prefix='device_init') validate_device_id(device, log_prefix="device_init")
if device == 'cpu': if device == "cpu":
context.device = 'cpu' context.device = "cpu"
context.device_name = get_processor_name() context.device_name = get_processor_name()
context.half_precision = False context.half_precision = False
log.debug(f'Render device CPU available as {context.device_name}') log.debug(f"Render device CPU available as {context.device_name}")
return return
context.device_name = torch.cuda.get_device_name(device) context.device_name = torch.cuda.get_device_name(device)
@ -111,7 +127,7 @@ def device_init(context, device):
# Force full precision on 1660 and 1650 NVIDIA cards to avoid creating green images # Force full precision on 1660 and 1650 NVIDIA cards to avoid creating green images
if needs_to_force_full_precision(context): if needs_to_force_full_precision(context):
log.warn(f'forcing full precision on this GPU, to avoid green images. GPU detected: {context.device_name}') log.warn(f"forcing full precision on this GPU, to avoid green images. GPU detected: {context.device_name}")
# Apply force_full_precision now before models are loaded. # Apply force_full_precision now before models are loaded.
context.half_precision = False context.half_precision = False
@ -120,72 +136,90 @@ def device_init(context, device):
return return
def needs_to_force_full_precision(context): def needs_to_force_full_precision(context):
if 'FORCE_FULL_PRECISION' in os.environ: if "FORCE_FULL_PRECISION" in os.environ:
return True return True
device_name = context.device_name.lower() device_name = context.device_name.lower()
return (('nvidia' in device_name or 'geforce' in device_name or 'quadro' in device_name) and (' 1660' in device_name or ' 1650' in device_name or ' t400' in device_name or ' t550' in device_name or ' t600' in device_name or ' t1000' in device_name or ' t1200' in device_name or ' t2000' in device_name)) return ("nvidia" in device_name or "geforce" in device_name or "quadro" in device_name) and (
" 1660" in device_name
or " 1650" in device_name
or " t400" in device_name
or " t550" in device_name
or " t600" in device_name
or " t1000" in device_name
or " t1200" in device_name
or " t2000" in device_name
)
def get_max_vram_usage_level(device): def get_max_vram_usage_level(device):
if device != 'cpu': if device != "cpu":
_, mem_total = torch.cuda.mem_get_info(device) _, mem_total = torch.cuda.mem_get_info(device)
mem_total /= float(10**9) mem_total /= float(10**9)
if mem_total < 4.5: if mem_total < 4.5:
return 'low' return "low"
elif mem_total < 6.5: elif mem_total < 6.5:
return 'balanced' return "balanced"
return 'high' return "high"
def validate_device_id(device, log_prefix=''):
def validate_device_id(device, log_prefix=""):
def is_valid(): def is_valid():
if not isinstance(device, str): if not isinstance(device, str):
return False return False
if device == 'cpu': if device == "cpu":
return True return True
if not device.startswith('cuda:') or not device[5:].isnumeric(): if not device.startswith("cuda:") or not device[5:].isnumeric():
return False return False
return True return True
if not is_valid(): if not is_valid():
raise EnvironmentError(f"{log_prefix}: device id should be 'cpu', or 'cuda:N' (where N is an integer index for the GPU). Got: {device}") raise EnvironmentError(
f"{log_prefix}: device id should be 'cpu', or 'cuda:N' (where N is an integer index for the GPU). Got: {device}"
)
def is_device_compatible(device): def is_device_compatible(device):
''' """
Returns True/False, and prints any compatibility errors Returns True/False, and prints any compatibility errors
''' """
# static variable "history". # static variable "history".
is_device_compatible.history = getattr(is_device_compatible, 'history', {}) is_device_compatible.history = getattr(is_device_compatible, "history", {})
try: try:
validate_device_id(device, log_prefix='is_device_compatible') validate_device_id(device, log_prefix="is_device_compatible")
except: except:
log.error(str(e)) log.error(str(e))
return False return False
if device == 'cpu': return True if device == "cpu":
return True
# Memory check # Memory check
try: try:
_, mem_total = torch.cuda.mem_get_info(device) _, mem_total = torch.cuda.mem_get_info(device)
mem_total /= float(10**9) mem_total /= float(10**9)
if mem_total < 3.0: if mem_total < 3.0:
if is_device_compatible.history.get(device) == None: if is_device_compatible.history.get(device) == None:
log.warn(f'GPU {device} with less than 3 GB of VRAM is not compatible with Stable Diffusion') log.warn(f"GPU {device} with less than 3 GB of VRAM is not compatible with Stable Diffusion")
is_device_compatible.history[device] = 1 is_device_compatible.history[device] = 1
return False return False
except RuntimeError as e: except RuntimeError as e:
log.error(str(e)) log.error(str(e))
return False return False
return True return True
def get_processor_name(): def get_processor_name():
try: try:
import platform, subprocess import platform, subprocess
if platform.system() == "Windows": if platform.system() == "Windows":
return platform.processor() return platform.processor()
elif platform.system() == "Darwin": elif platform.system() == "Darwin":
os.environ['PATH'] = os.environ['PATH'] + os.pathsep + '/usr/sbin' os.environ["PATH"] = os.environ["PATH"] + os.pathsep + "/usr/sbin"
command = "sysctl -n machdep.cpu.brand_string" command = "sysctl -n machdep.cpu.brand_string"
return subprocess.check_output(command).strip() return subprocess.check_output(command).strip()
elif platform.system() == "Linux": elif platform.system() == "Linux":

View File

@ -8,29 +8,31 @@ from sdkit import Context
from sdkit.models import load_model, unload_model, get_model_info_from_db, scan_model from sdkit.models import load_model, unload_model, get_model_info_from_db, scan_model
from sdkit.utils import hash_file_quick from sdkit.utils import hash_file_quick
KNOWN_MODEL_TYPES = ['stable-diffusion', 'vae', 'hypernetwork', 'gfpgan', 'realesrgan'] KNOWN_MODEL_TYPES = ["stable-diffusion", "vae", "hypernetwork", "gfpgan", "realesrgan"]
MODEL_EXTENSIONS = { MODEL_EXTENSIONS = {
'stable-diffusion': ['.ckpt', '.safetensors'], "stable-diffusion": [".ckpt", ".safetensors"],
'vae': ['.vae.pt', '.ckpt', '.safetensors'], "vae": [".vae.pt", ".ckpt", ".safetensors"],
'hypernetwork': ['.pt', '.safetensors'], "hypernetwork": [".pt", ".safetensors"],
'gfpgan': ['.pth'], "gfpgan": [".pth"],
'realesrgan': ['.pth'], "realesrgan": [".pth"],
} }
DEFAULT_MODELS = { DEFAULT_MODELS = {
'stable-diffusion': [ # needed to support the legacy installations "stable-diffusion": [ # needed to support the legacy installations
'custom-model', # only one custom model file was supported initially, creatively named 'custom-model' "custom-model", # only one custom model file was supported initially, creatively named 'custom-model'
'sd-v1-4', # Default fallback. "sd-v1-4", # Default fallback.
], ],
'gfpgan': ['GFPGANv1.3'], "gfpgan": ["GFPGANv1.3"],
'realesrgan': ['RealESRGAN_x4plus'], "realesrgan": ["RealESRGAN_x4plus"],
} }
MODELS_TO_LOAD_ON_START = ['stable-diffusion', 'vae', 'hypernetwork'] MODELS_TO_LOAD_ON_START = ["stable-diffusion", "vae", "hypernetwork"]
known_models = {} known_models = {}
def init(): def init():
make_model_folders() make_model_folders()
getModels() # run this once, to cache the picklescan results getModels() # run this once, to cache the picklescan results
def load_default_models(context: Context): def load_default_models(context: Context):
set_vram_optimizations(context) set_vram_optimizations(context)
@ -39,27 +41,28 @@ def load_default_models(context: Context):
for model_type in MODELS_TO_LOAD_ON_START: for model_type in MODELS_TO_LOAD_ON_START:
context.model_paths[model_type] = resolve_model_to_use(model_type=model_type) context.model_paths[model_type] = resolve_model_to_use(model_type=model_type)
try: try:
load_model(context, model_type) load_model(context, model_type)
except Exception as e: except Exception as e:
log.error(f'[red]Error while loading {model_type} model: {context.model_paths[model_type]}[/red]') log.error(f"[red]Error while loading {model_type} model: {context.model_paths[model_type]}[/red]")
log.error(f'[red]Error: {e}[/red]') log.error(f"[red]Error: {e}[/red]")
log.error(f'[red]Consider removing the model from the model folder.[red]') log.error(f"[red]Consider removing the model from the model folder.[red]")
def unload_all(context: Context): def unload_all(context: Context):
for model_type in KNOWN_MODEL_TYPES: for model_type in KNOWN_MODEL_TYPES:
unload_model(context, model_type) unload_model(context, model_type)
def resolve_model_to_use(model_name:str=None, model_type:str=None):
def resolve_model_to_use(model_name: str = None, model_type: str = None):
model_extensions = MODEL_EXTENSIONS.get(model_type, []) model_extensions = MODEL_EXTENSIONS.get(model_type, [])
default_models = DEFAULT_MODELS.get(model_type, []) default_models = DEFAULT_MODELS.get(model_type, [])
config = app.getConfig() config = app.getConfig()
model_dirs = [os.path.join(app.MODELS_DIR, model_type), app.SD_DIR] model_dirs = [os.path.join(app.MODELS_DIR, model_type), app.SD_DIR]
if not model_name: # When None try user configured model. if not model_name: # When None try user configured model.
# config = getConfig() # config = getConfig()
if 'model' in config and model_type in config['model']: if "model" in config and model_type in config["model"]:
model_name = config['model'][model_type] model_name = config["model"][model_type]
if model_name: if model_name:
# Check models directory # Check models directory
@ -84,41 +87,54 @@ def resolve_model_to_use(model_name:str=None, model_type:str=None):
for model_extension in model_extensions: for model_extension in model_extensions:
if os.path.exists(default_model_path + model_extension): if os.path.exists(default_model_path + model_extension):
if model_name is not None: if model_name is not None:
log.warn(f'Could not find the configured custom model {model_name}{model_extension}. Using the default one: {default_model_path}{model_extension}') log.warn(
f"Could not find the configured custom model {model_name}{model_extension}. Using the default one: {default_model_path}{model_extension}"
)
return default_model_path + model_extension return default_model_path + model_extension
return None return None
def reload_models_if_necessary(context: Context, task_data: TaskData): def reload_models_if_necessary(context: Context, task_data: TaskData):
model_paths_in_req = { model_paths_in_req = {
'stable-diffusion': task_data.use_stable_diffusion_model, "stable-diffusion": task_data.use_stable_diffusion_model,
'vae': task_data.use_vae_model, "vae": task_data.use_vae_model,
'hypernetwork': task_data.use_hypernetwork_model, "hypernetwork": task_data.use_hypernetwork_model,
'gfpgan': task_data.use_face_correction, "gfpgan": task_data.use_face_correction,
'realesrgan': task_data.use_upscale, "realesrgan": task_data.use_upscale,
}
models_to_reload = {
model_type: path
for model_type, path in model_paths_in_req.items()
if context.model_paths.get(model_type) != path
} }
models_to_reload = {model_type: path for model_type, path in model_paths_in_req.items() if context.model_paths.get(model_type) != path}
if set_vram_optimizations(context): # reload SD if set_vram_optimizations(context): # reload SD
models_to_reload['stable-diffusion'] = model_paths_in_req['stable-diffusion'] models_to_reload["stable-diffusion"] = model_paths_in_req["stable-diffusion"]
for model_type, model_path_in_req in models_to_reload.items(): for model_type, model_path_in_req in models_to_reload.items():
context.model_paths[model_type] = model_path_in_req context.model_paths[model_type] = model_path_in_req
action_fn = unload_model if context.model_paths[model_type] is None else load_model action_fn = unload_model if context.model_paths[model_type] is None else load_model
action_fn(context, model_type, scan_model=False) # we've scanned them already action_fn(context, model_type, scan_model=False) # we've scanned them already
def resolve_model_paths(task_data: TaskData): def resolve_model_paths(task_data: TaskData):
task_data.use_stable_diffusion_model = resolve_model_to_use(task_data.use_stable_diffusion_model, model_type='stable-diffusion') task_data.use_stable_diffusion_model = resolve_model_to_use(
task_data.use_vae_model = resolve_model_to_use(task_data.use_vae_model, model_type='vae') task_data.use_stable_diffusion_model, model_type="stable-diffusion"
task_data.use_hypernetwork_model = resolve_model_to_use(task_data.use_hypernetwork_model, model_type='hypernetwork') )
task_data.use_vae_model = resolve_model_to_use(task_data.use_vae_model, model_type="vae")
task_data.use_hypernetwork_model = resolve_model_to_use(task_data.use_hypernetwork_model, model_type="hypernetwork")
if task_data.use_face_correction:
task_data.use_face_correction = resolve_model_to_use(task_data.use_face_correction, "gfpgan")
if task_data.use_upscale:
task_data.use_upscale = resolve_model_to_use(task_data.use_upscale, "realesrgan")
if task_data.use_face_correction: task_data.use_face_correction = resolve_model_to_use(task_data.use_face_correction, 'gfpgan')
if task_data.use_upscale: task_data.use_upscale = resolve_model_to_use(task_data.use_upscale, 'realesrgan')
def set_vram_optimizations(context: Context): def set_vram_optimizations(context: Context):
config = app.getConfig() config = app.getConfig()
vram_usage_level = config.get('vram_usage_level', 'balanced') vram_usage_level = config.get("vram_usage_level", "balanced")
if vram_usage_level != context.vram_usage_level: if vram_usage_level != context.vram_usage_level:
context.vram_usage_level = vram_usage_level context.vram_usage_level = vram_usage_level
@ -126,42 +142,51 @@ def set_vram_optimizations(context: Context):
return False return False
def make_model_folders(): def make_model_folders():
for model_type in KNOWN_MODEL_TYPES: for model_type in KNOWN_MODEL_TYPES:
model_dir_path = os.path.join(app.MODELS_DIR, model_type) model_dir_path = os.path.join(app.MODELS_DIR, model_type)
os.makedirs(model_dir_path, exist_ok=True) os.makedirs(model_dir_path, exist_ok=True)
help_file_name = f'Place your {model_type} model files here.txt' help_file_name = f"Place your {model_type} model files here.txt"
help_file_contents = f'Supported extensions: {" or ".join(MODEL_EXTENSIONS.get(model_type))}' help_file_contents = f'Supported extensions: {" or ".join(MODEL_EXTENSIONS.get(model_type))}'
with open(os.path.join(model_dir_path, help_file_name), 'w', encoding='utf-8') as f: with open(os.path.join(model_dir_path, help_file_name), "w", encoding="utf-8") as f:
f.write(help_file_contents) f.write(help_file_contents)
def is_malicious_model(file_path): def is_malicious_model(file_path):
try: try:
scan_result = scan_model(file_path) scan_result = scan_model(file_path)
if scan_result.issues_count > 0 or scan_result.infected_files > 0: if scan_result.issues_count > 0 or scan_result.infected_files > 0:
log.warn(":warning: [bold red]Scan %s: %d scanned, %d issue, %d infected.[/bold red]" % (file_path, scan_result.scanned_files, scan_result.issues_count, scan_result.infected_files)) log.warn(
":warning: [bold red]Scan %s: %d scanned, %d issue, %d infected.[/bold red]"
% (file_path, scan_result.scanned_files, scan_result.issues_count, scan_result.infected_files)
)
return True return True
else: else:
log.debug("Scan %s: [green]%d scanned, %d issue, %d infected.[/green]" % (file_path, scan_result.scanned_files, scan_result.issues_count, scan_result.infected_files)) log.debug(
"Scan %s: [green]%d scanned, %d issue, %d infected.[/green]"
% (file_path, scan_result.scanned_files, scan_result.issues_count, scan_result.infected_files)
)
return False return False
except Exception as e: except Exception as e:
log.error(f'error while scanning: {file_path}, error: {e}') log.error(f"error while scanning: {file_path}, error: {e}")
return False return False
def getModels(): def getModels():
models = { models = {
'active': { "active": {
'stable-diffusion': 'sd-v1-4', "stable-diffusion": "sd-v1-4",
'vae': '', "vae": "",
'hypernetwork': '', "hypernetwork": "",
}, },
'options': { "options": {
'stable-diffusion': ['sd-v1-4'], "stable-diffusion": ["sd-v1-4"],
'vae': [], "vae": [],
'hypernetwork': [], "hypernetwork": [],
}, },
} }
@ -171,13 +196,16 @@ def getModels():
"Raised when picklescan reports a problem with a model" "Raised when picklescan reports a problem with a model"
pass pass
def scan_directory(directory, suffixes, directoriesFirst:bool=True): def scan_directory(directory, suffixes, directoriesFirst: bool = True):
nonlocal models_scanned nonlocal models_scanned
tree = [] tree = []
for entry in sorted(os.scandir(directory), key = lambda entry: (entry.is_file() == directoriesFirst, entry.name.lower())): for entry in sorted(
os.scandir(directory), key=lambda entry: (entry.is_file() == directoriesFirst, entry.name.lower())
):
if entry.is_file(): if entry.is_file():
matching_suffix = list(filter(lambda s: entry.name.endswith(s), suffixes)) matching_suffix = list(filter(lambda s: entry.name.endswith(s), suffixes))
if len(matching_suffix) == 0: continue if len(matching_suffix) == 0:
continue
matching_suffix = matching_suffix[0] matching_suffix = matching_suffix[0]
mtime = entry.stat().st_mtime mtime = entry.stat().st_mtime
@ -187,12 +215,12 @@ def getModels():
if is_malicious_model(entry.path): if is_malicious_model(entry.path):
raise MaliciousModelException(entry.path) raise MaliciousModelException(entry.path)
known_models[entry.path] = mtime known_models[entry.path] = mtime
tree.append(entry.name[:-len(matching_suffix)]) tree.append(entry.name[: -len(matching_suffix)])
elif entry.is_dir(): elif entry.is_dir():
scan=scan_directory(entry.path, suffixes, directoriesFirst=False) scan = scan_directory(entry.path, suffixes, directoriesFirst=False)
if len(scan) != 0: if len(scan) != 0:
tree.append( (entry.name, scan ) ) tree.append((entry.name, scan))
return tree return tree
def listModels(model_type): def listModels(model_type):
@ -204,21 +232,22 @@ def getModels():
os.makedirs(models_dir) os.makedirs(models_dir)
try: try:
models['options'][model_type] = scan_directory(models_dir, model_extensions) models["options"][model_type] = scan_directory(models_dir, model_extensions)
except MaliciousModelException as e: except MaliciousModelException as e:
models['scan-error'] = e models["scan-error"] = e
# custom models # custom models
listModels(model_type='stable-diffusion') listModels(model_type="stable-diffusion")
listModels(model_type='vae') listModels(model_type="vae")
listModels(model_type='hypernetwork') listModels(model_type="hypernetwork")
listModels(model_type='gfpgan') listModels(model_type="gfpgan")
if models_scanned > 0: log.info(f'[green]Scanned {models_scanned} models. Nothing infected[/]') if models_scanned > 0:
log.info(f"[green]Scanned {models_scanned} models. Nothing infected[/]")
# legacy # legacy
custom_weight_path = os.path.join(app.SD_DIR, 'custom-model.ckpt') custom_weight_path = os.path.join(app.SD_DIR, "custom-model.ckpt")
if os.path.exists(custom_weight_path): if os.path.exists(custom_weight_path):
models['options']['stable-diffusion'].append('custom-model') models["options"]["stable-diffusion"].append("custom-model")
return models return models

View File

@ -12,22 +12,26 @@ from sdkit.generate import generate_images
from sdkit.filter import apply_filters from sdkit.filter import apply_filters
from sdkit.utils import img_to_buffer, img_to_base64_str, latent_samples_to_images, gc from sdkit.utils import img_to_buffer, img_to_base64_str, latent_samples_to_images, gc
context = Context() # thread-local context = Context() # thread-local
''' """
runtime data (bound locally to this thread), for e.g. device, references to loaded models, optimization flags etc runtime data (bound locally to this thread), for e.g. device, references to loaded models, optimization flags etc
''' """
def init(device): def init(device):
''' """
Initializes the fields that will be bound to this runtime's context, and sets the current torch device Initializes the fields that will be bound to this runtime's context, and sets the current torch device
''' """
context.stop_processing = False context.stop_processing = False
context.temp_images = {} context.temp_images = {}
context.partial_x_samples = None context.partial_x_samples = None
device_manager.device_init(context, device) device_manager.device_init(context, device)
def make_images(req: GenerateImageRequest, task_data: TaskData, data_queue: queue.Queue, task_temp_images: list, step_callback):
def make_images(
req: GenerateImageRequest, task_data: TaskData, data_queue: queue.Queue, task_temp_images: list, step_callback
):
context.stop_processing = False context.stop_processing = False
print_task_info(req, task_data) print_task_info(req, task_data)
@ -36,18 +40,24 @@ def make_images(req: GenerateImageRequest, task_data: TaskData, data_queue: queu
res = Response(req, task_data, images=construct_response(images, seeds, task_data, base_seed=req.seed)) res = Response(req, task_data, images=construct_response(images, seeds, task_data, base_seed=req.seed))
res = res.json() res = res.json()
data_queue.put(json.dumps(res)) data_queue.put(json.dumps(res))
log.info('Task completed') log.info("Task completed")
return res return res
def print_task_info(req: GenerateImageRequest, task_data: TaskData):
req_str = pprint.pformat(get_printable_request(req)).replace("[","\[")
task_str = pprint.pformat(task_data.dict()).replace("[","\[")
log.info(f'request: {req_str}')
log.info(f'task data: {task_str}')
def make_images_internal(req: GenerateImageRequest, task_data: TaskData, data_queue: queue.Queue, task_temp_images: list, step_callback): def print_task_info(req: GenerateImageRequest, task_data: TaskData):
images, user_stopped = generate_images_internal(req, task_data, data_queue, task_temp_images, step_callback, task_data.stream_image_progress) req_str = pprint.pformat(get_printable_request(req)).replace("[", "\[")
task_str = pprint.pformat(task_data.dict()).replace("[", "\[")
log.info(f"request: {req_str}")
log.info(f"task data: {task_str}")
def make_images_internal(
req: GenerateImageRequest, task_data: TaskData, data_queue: queue.Queue, task_temp_images: list, step_callback
):
images, user_stopped = generate_images_internal(
req, task_data, data_queue, task_temp_images, step_callback, task_data.stream_image_progress
)
filtered_images = filter_images(task_data, images, user_stopped) filtered_images = filter_images(task_data, images, user_stopped)
if task_data.save_to_disk_path is not None: if task_data.save_to_disk_path is not None:
@ -59,13 +69,22 @@ def make_images_internal(req: GenerateImageRequest, task_data: TaskData, data_qu
else: else:
return images + filtered_images, seeds + seeds return images + filtered_images, seeds + seeds
def generate_images_internal(req: GenerateImageRequest, task_data: TaskData, data_queue: queue.Queue, task_temp_images: list, step_callback, stream_image_progress: bool):
def generate_images_internal(
req: GenerateImageRequest,
task_data: TaskData,
data_queue: queue.Queue,
task_temp_images: list,
step_callback,
stream_image_progress: bool,
):
context.temp_images.clear() context.temp_images.clear()
callback = make_step_callback(req, task_data, data_queue, task_temp_images, step_callback, stream_image_progress) callback = make_step_callback(req, task_data, data_queue, task_temp_images, step_callback, stream_image_progress)
try: try:
if req.init_image is not None: req.sampler_name = 'ddim' if req.init_image is not None:
req.sampler_name = "ddim"
images = generate_images(context, callback=callback, **req.dict()) images = generate_images(context, callback=callback, **req.dict())
user_stopped = False user_stopped = False
@ -75,31 +94,44 @@ def generate_images_internal(req: GenerateImageRequest, task_data: TaskData, dat
if context.partial_x_samples is not None: if context.partial_x_samples is not None:
images = latent_samples_to_images(context, context.partial_x_samples) images = latent_samples_to_images(context, context.partial_x_samples)
finally: finally:
if hasattr(context, 'partial_x_samples') and context.partial_x_samples is not None: if hasattr(context, "partial_x_samples") and context.partial_x_samples is not None:
del context.partial_x_samples del context.partial_x_samples
context.partial_x_samples = None context.partial_x_samples = None
return images, user_stopped return images, user_stopped
def filter_images(task_data: TaskData, images: list, user_stopped): def filter_images(task_data: TaskData, images: list, user_stopped):
if user_stopped or (task_data.use_face_correction is None and task_data.use_upscale is None): if user_stopped or (task_data.use_face_correction is None and task_data.use_upscale is None):
return images return images
filters_to_apply = [] filters_to_apply = []
if task_data.use_face_correction and 'gfpgan' in task_data.use_face_correction.lower(): filters_to_apply.append('gfpgan') if task_data.use_face_correction and "gfpgan" in task_data.use_face_correction.lower():
if task_data.use_upscale and 'realesrgan' in task_data.use_upscale.lower(): filters_to_apply.append('realesrgan') filters_to_apply.append("gfpgan")
if task_data.use_upscale and "realesrgan" in task_data.use_upscale.lower():
filters_to_apply.append("realesrgan")
return apply_filters(context, filters_to_apply, images, scale=task_data.upscale_amount) return apply_filters(context, filters_to_apply, images, scale=task_data.upscale_amount)
def construct_response(images: list, seeds: list, task_data: TaskData, base_seed: int): def construct_response(images: list, seeds: list, task_data: TaskData, base_seed: int):
return [ return [
ResponseImage( ResponseImage(
data=img_to_base64_str(img, task_data.output_format, task_data.output_quality), data=img_to_base64_str(img, task_data.output_format, task_data.output_quality),
seed=seed, seed=seed,
) for img, seed in zip(images, seeds) )
for img, seed in zip(images, seeds)
] ]
def make_step_callback(req: GenerateImageRequest, task_data: TaskData, data_queue: queue.Queue, task_temp_images: list, step_callback, stream_image_progress: bool):
def make_step_callback(
req: GenerateImageRequest,
task_data: TaskData,
data_queue: queue.Queue,
task_temp_images: list,
step_callback,
stream_image_progress: bool,
):
n_steps = req.num_inference_steps if req.init_image is None else int(req.num_inference_steps * req.prompt_strength) n_steps = req.num_inference_steps if req.init_image is None else int(req.num_inference_steps * req.prompt_strength)
last_callback_time = -1 last_callback_time = -1
@ -107,11 +139,11 @@ def make_step_callback(req: GenerateImageRequest, task_data: TaskData, data_queu
partial_images = [] partial_images = []
images = latent_samples_to_images(context, x_samples) images = latent_samples_to_images(context, x_samples)
for i, img in enumerate(images): for i, img in enumerate(images):
buf = img_to_buffer(img, output_format='JPEG') buf = img_to_buffer(img, output_format="JPEG")
context.temp_images[f"{task_data.request_id}/{i}"] = buf context.temp_images[f"{task_data.request_id}/{i}"] = buf
task_temp_images[i] = buf task_temp_images[i] = buf
partial_images.append({'path': f"/image/tmp/{task_data.request_id}/{i}"}) partial_images.append({"path": f"/image/tmp/{task_data.request_id}/{i}"})
del images del images
return partial_images return partial_images
@ -125,7 +157,7 @@ def make_step_callback(req: GenerateImageRequest, task_data: TaskData, data_queu
progress = {"step": i, "step_time": step_time, "total_steps": n_steps} progress = {"step": i, "step_time": step_time, "total_steps": n_steps}
if stream_image_progress and i % 5 == 0: if stream_image_progress and i % 5 == 0:
progress['output'] = update_temp_img(x_samples, task_temp_images) progress["output"] = update_temp_img(x_samples, task_temp_images)
data_queue.put(json.dumps(progress)) data_queue.put(json.dumps(progress))

View File

@ -16,21 +16,25 @@ from easydiffusion import app, model_manager, task_manager
from easydiffusion.types import TaskData, GenerateImageRequest, MergeRequest from easydiffusion.types import TaskData, GenerateImageRequest, MergeRequest
from easydiffusion.utils import log from easydiffusion.utils import log
log.info(f'started in {app.SD_DIR}') log.info(f"started in {app.SD_DIR}")
log.info(f'started at {datetime.datetime.now():%x %X}') log.info(f"started at {datetime.datetime.now():%x %X}")
server_api = FastAPI() server_api = FastAPI()
NOCACHE_HEADERS={"Cache-Control": "no-cache, no-store, must-revalidate", "Pragma": "no-cache", "Expires": "0"} NOCACHE_HEADERS = {"Cache-Control": "no-cache, no-store, must-revalidate", "Pragma": "no-cache", "Expires": "0"}
class NoCacheStaticFiles(StaticFiles): class NoCacheStaticFiles(StaticFiles):
def is_not_modified(self, response_headers, request_headers) -> bool: def is_not_modified(self, response_headers, request_headers) -> bool:
if 'content-type' in response_headers and ('javascript' in response_headers['content-type'] or 'css' in response_headers['content-type']): if "content-type" in response_headers and (
"javascript" in response_headers["content-type"] or "css" in response_headers["content-type"]
):
response_headers.update(NOCACHE_HEADERS) response_headers.update(NOCACHE_HEADERS)
return False return False
return super().is_not_modified(response_headers, request_headers) return super().is_not_modified(response_headers, request_headers)
class SetAppConfigRequest(BaseModel): class SetAppConfigRequest(BaseModel):
update_branch: str = None update_branch: str = None
render_devices: Union[List[str], List[int], str, int] = None render_devices: Union[List[str], List[int], str, int] = None
@ -39,130 +43,142 @@ class SetAppConfigRequest(BaseModel):
listen_to_network: bool = None listen_to_network: bool = None
listen_port: int = None listen_port: int = None
def init(): def init():
server_api.mount('/media', NoCacheStaticFiles(directory=os.path.join(app.SD_UI_DIR, 'media')), name="media") server_api.mount("/media", NoCacheStaticFiles(directory=os.path.join(app.SD_UI_DIR, "media")), name="media")
for plugins_dir, dir_prefix in app.UI_PLUGINS_SOURCES: for plugins_dir, dir_prefix in app.UI_PLUGINS_SOURCES:
server_api.mount(f'/plugins/{dir_prefix}', NoCacheStaticFiles(directory=plugins_dir), name=f"plugins-{dir_prefix}") server_api.mount(
f"/plugins/{dir_prefix}", NoCacheStaticFiles(directory=plugins_dir), name=f"plugins-{dir_prefix}"
)
@server_api.post('/app_config') @server_api.post("/app_config")
async def set_app_config(req : SetAppConfigRequest): async def set_app_config(req: SetAppConfigRequest):
return set_app_config_internal(req) return set_app_config_internal(req)
@server_api.get('/get/{key:path}') @server_api.get("/get/{key:path}")
def read_web_data(key:str=None): def read_web_data(key: str = None):
return read_web_data_internal(key) return read_web_data_internal(key)
@server_api.get('/ping') # Get server and optionally session status. @server_api.get("/ping") # Get server and optionally session status.
def ping(session_id:str=None): def ping(session_id: str = None):
return ping_internal(session_id) return ping_internal(session_id)
@server_api.post('/render') @server_api.post("/render")
def render(req: dict): def render(req: dict):
return render_internal(req) return render_internal(req)
@server_api.post('/model/merge') @server_api.post("/model/merge")
def model_merge(req: dict): def model_merge(req: dict):
print(req) print(req)
return model_merge_internal(req) return model_merge_internal(req)
@server_api.get('/image/stream/{task_id:int}') @server_api.get("/image/stream/{task_id:int}")
def stream(task_id:int): def stream(task_id: int):
return stream_internal(task_id) return stream_internal(task_id)
@server_api.get('/image/stop') @server_api.get("/image/stop")
def stop(task: int): def stop(task: int):
return stop_internal(task) return stop_internal(task)
@server_api.get('/image/tmp/{task_id:int}/{img_id:int}') @server_api.get("/image/tmp/{task_id:int}/{img_id:int}")
def get_image(task_id: int, img_id: int): def get_image(task_id: int, img_id: int):
return get_image_internal(task_id, img_id) return get_image_internal(task_id, img_id)
@server_api.get('/') @server_api.get("/")
def read_root(): def read_root():
return FileResponse(os.path.join(app.SD_UI_DIR, 'index.html'), headers=NOCACHE_HEADERS) return FileResponse(os.path.join(app.SD_UI_DIR, "index.html"), headers=NOCACHE_HEADERS)
@server_api.on_event("shutdown") @server_api.on_event("shutdown")
def shutdown_event(): # Signal render thread to close on shutdown def shutdown_event(): # Signal render thread to close on shutdown
task_manager.current_state_error = SystemExit('Application shutting down.') task_manager.current_state_error = SystemExit("Application shutting down.")
# API implementations # API implementations
def set_app_config_internal(req : SetAppConfigRequest): def set_app_config_internal(req: SetAppConfigRequest):
config = app.getConfig() config = app.getConfig()
if req.update_branch is not None: if req.update_branch is not None:
config['update_branch'] = req.update_branch config["update_branch"] = req.update_branch
if req.render_devices is not None: if req.render_devices is not None:
update_render_devices_in_config(config, req.render_devices) update_render_devices_in_config(config, req.render_devices)
if req.ui_open_browser_on_start is not None: if req.ui_open_browser_on_start is not None:
if 'ui' not in config: if "ui" not in config:
config['ui'] = {} config["ui"] = {}
config['ui']['open_browser_on_start'] = req.ui_open_browser_on_start config["ui"]["open_browser_on_start"] = req.ui_open_browser_on_start
if req.listen_to_network is not None: if req.listen_to_network is not None:
if 'net' not in config: if "net" not in config:
config['net'] = {} config["net"] = {}
config['net']['listen_to_network'] = bool(req.listen_to_network) config["net"]["listen_to_network"] = bool(req.listen_to_network)
if req.listen_port is not None: if req.listen_port is not None:
if 'net' not in config: if "net" not in config:
config['net'] = {} config["net"] = {}
config['net']['listen_port'] = int(req.listen_port) config["net"]["listen_port"] = int(req.listen_port)
try: try:
app.setConfig(config) app.setConfig(config)
if req.render_devices: if req.render_devices:
app.update_render_threads() app.update_render_threads()
return JSONResponse({'status': 'OK'}, headers=NOCACHE_HEADERS) return JSONResponse({"status": "OK"}, headers=NOCACHE_HEADERS)
except Exception as e: except Exception as e:
log.error(traceback.format_exc()) log.error(traceback.format_exc())
raise HTTPException(status_code=500, detail=str(e)) raise HTTPException(status_code=500, detail=str(e))
def update_render_devices_in_config(config, render_devices): def update_render_devices_in_config(config, render_devices):
if render_devices not in ('cpu', 'auto') and not render_devices.startswith('cuda:'): if render_devices not in ("cpu", "auto") and not render_devices.startswith("cuda:"):
raise HTTPException(status_code=400, detail=f'Invalid render device requested: {render_devices}') raise HTTPException(status_code=400, detail=f"Invalid render device requested: {render_devices}")
if render_devices.startswith('cuda:'): if render_devices.startswith("cuda:"):
render_devices = render_devices.split(',') render_devices = render_devices.split(",")
config['render_devices'] = render_devices config["render_devices"] = render_devices
def read_web_data_internal(key:str=None):
if not key: # /get without parameters, stable-diffusion easter egg. def read_web_data_internal(key: str = None):
raise HTTPException(status_code=418, detail="StableDiffusion is drawing a teapot!") # HTTP418 I'm a teapot if not key: # /get without parameters, stable-diffusion easter egg.
elif key == 'app_config': raise HTTPException(status_code=418, detail="StableDiffusion is drawing a teapot!") # HTTP418 I'm a teapot
elif key == "app_config":
return JSONResponse(app.getConfig(), headers=NOCACHE_HEADERS) return JSONResponse(app.getConfig(), headers=NOCACHE_HEADERS)
elif key == 'system_info': elif key == "system_info":
config = app.getConfig() config = app.getConfig()
output_dir = config.get('force_save_path', os.path.join(os.path.expanduser("~"), app.OUTPUT_DIRNAME)) output_dir = config.get("force_save_path", os.path.join(os.path.expanduser("~"), app.OUTPUT_DIRNAME))
system_info = { system_info = {
'devices': task_manager.get_devices(), "devices": task_manager.get_devices(),
'hosts': app.getIPConfig(), "hosts": app.getIPConfig(),
'default_output_dir': output_dir, "default_output_dir": output_dir,
'enforce_output_dir': ('force_save_path' in config), "enforce_output_dir": ("force_save_path" in config),
} }
system_info['devices']['config'] = config.get('render_devices', "auto") system_info["devices"]["config"] = config.get("render_devices", "auto")
return JSONResponse(system_info, headers=NOCACHE_HEADERS) return JSONResponse(system_info, headers=NOCACHE_HEADERS)
elif key == 'models': elif key == "models":
return JSONResponse(model_manager.getModels(), headers=NOCACHE_HEADERS) return JSONResponse(model_manager.getModels(), headers=NOCACHE_HEADERS)
elif key == 'modifiers': return FileResponse(os.path.join(app.SD_UI_DIR, 'modifiers.json'), headers=NOCACHE_HEADERS) elif key == "modifiers":
elif key == 'ui_plugins': return JSONResponse(app.getUIPlugins(), headers=NOCACHE_HEADERS) return FileResponse(os.path.join(app.SD_UI_DIR, "modifiers.json"), headers=NOCACHE_HEADERS)
elif key == "ui_plugins":
return JSONResponse(app.getUIPlugins(), headers=NOCACHE_HEADERS)
else: else:
raise HTTPException(status_code=404, detail=f'Request for unknown {key}') # HTTP404 Not Found raise HTTPException(status_code=404, detail=f"Request for unknown {key}") # HTTP404 Not Found
def ping_internal(session_id:str=None):
if task_manager.is_alive() <= 0: # Check that render threads are alive. def ping_internal(session_id: str = None):
if task_manager.current_state_error: raise HTTPException(status_code=500, detail=str(task_manager.current_state_error)) if task_manager.is_alive() <= 0: # Check that render threads are alive.
raise HTTPException(status_code=500, detail='Render thread is dead.') if task_manager.current_state_error:
if task_manager.current_state_error and not isinstance(task_manager.current_state_error, StopAsyncIteration): raise HTTPException(status_code=500, detail=str(task_manager.current_state_error)) raise HTTPException(status_code=500, detail=str(task_manager.current_state_error))
raise HTTPException(status_code=500, detail="Render thread is dead.")
if task_manager.current_state_error and not isinstance(task_manager.current_state_error, StopAsyncIteration):
raise HTTPException(status_code=500, detail=str(task_manager.current_state_error))
# Alive # Alive
response = {'status': str(task_manager.current_state)} response = {"status": str(task_manager.current_state)}
if session_id: if session_id:
session = task_manager.get_cached_session(session_id, update_ttl=True) session = task_manager.get_cached_session(session_id, update_ttl=True)
response['tasks'] = {id(t): t.status for t in session.tasks} response["tasks"] = {id(t): t.status for t in session.tasks}
response['devices'] = task_manager.get_devices() response["devices"] = task_manager.get_devices()
return JSONResponse(response, headers=NOCACHE_HEADERS) return JSONResponse(response, headers=NOCACHE_HEADERS)
def render_internal(req: dict): def render_internal(req: dict):
try: try:
# separate out the request data into rendering and task-specific data # separate out the request data into rendering and task-specific data
@ -171,80 +187,99 @@ def render_internal(req: dict):
# Overwrite user specified save path # Overwrite user specified save path
config = app.getConfig() config = app.getConfig()
if 'force_save_path' in config: if "force_save_path" in config:
task_data.save_to_disk_path = config['force_save_path'] task_data.save_to_disk_path = config["force_save_path"]
render_req.init_image_mask = req.get('mask') # hack: will rename this in the HTTP API in a future revision render_req.init_image_mask = req.get("mask") # hack: will rename this in the HTTP API in a future revision
app.save_to_config(task_data.use_stable_diffusion_model, task_data.use_vae_model, task_data.use_hypernetwork_model, task_data.vram_usage_level) app.save_to_config(
task_data.use_stable_diffusion_model,
task_data.use_vae_model,
task_data.use_hypernetwork_model,
task_data.vram_usage_level,
)
# enqueue the task # enqueue the task
new_task = task_manager.render(render_req, task_data) new_task = task_manager.render(render_req, task_data)
response = { response = {
'status': str(task_manager.current_state), "status": str(task_manager.current_state),
'queue': len(task_manager.tasks_queue), "queue": len(task_manager.tasks_queue),
'stream': f'/image/stream/{id(new_task)}', "stream": f"/image/stream/{id(new_task)}",
'task': id(new_task) "task": id(new_task),
} }
return JSONResponse(response, headers=NOCACHE_HEADERS) return JSONResponse(response, headers=NOCACHE_HEADERS)
except ChildProcessError as e: # Render thread is dead except ChildProcessError as e: # Render thread is dead
raise HTTPException(status_code=500, detail=f'Rendering thread has died.') # HTTP500 Internal Server Error raise HTTPException(status_code=500, detail=f"Rendering thread has died.") # HTTP500 Internal Server Error
except ConnectionRefusedError as e: # Unstarted task pending limit reached, deny queueing too many. except ConnectionRefusedError as e: # Unstarted task pending limit reached, deny queueing too many.
raise HTTPException(status_code=503, detail=str(e)) # HTTP503 Service Unavailable raise HTTPException(status_code=503, detail=str(e)) # HTTP503 Service Unavailable
except Exception as e: except Exception as e:
log.error(traceback.format_exc()) log.error(traceback.format_exc())
raise HTTPException(status_code=500, detail=str(e)) raise HTTPException(status_code=500, detail=str(e))
def model_merge_internal(req: dict): def model_merge_internal(req: dict):
try: try:
from sdkit.train import merge_models from sdkit.train import merge_models
from easydiffusion.utils.save_utils import filename_regex from easydiffusion.utils.save_utils import filename_regex
mergeReq: MergeRequest = MergeRequest.parse_obj(req) mergeReq: MergeRequest = MergeRequest.parse_obj(req)
merge_models(model_manager.resolve_model_to_use(mergeReq.model0,'stable-diffusion'), merge_models(
model_manager.resolve_model_to_use(mergeReq.model1,'stable-diffusion'), model_manager.resolve_model_to_use(mergeReq.model0, "stable-diffusion"),
mergeReq.ratio, model_manager.resolve_model_to_use(mergeReq.model1, "stable-diffusion"),
os.path.join(app.MODELS_DIR, 'stable-diffusion', filename_regex.sub('_', mergeReq.out_path)), mergeReq.ratio,
mergeReq.use_fp16 os.path.join(app.MODELS_DIR, "stable-diffusion", filename_regex.sub("_", mergeReq.out_path)),
mergeReq.use_fp16,
) )
return JSONResponse({'status':'OK'}, headers=NOCACHE_HEADERS) return JSONResponse({"status": "OK"}, headers=NOCACHE_HEADERS)
except Exception as e: except Exception as e:
log.error(traceback.format_exc()) log.error(traceback.format_exc())
raise HTTPException(status_code=500, detail=str(e)) raise HTTPException(status_code=500, detail=str(e))
def stream_internal(task_id:int):
#TODO Move to WebSockets ?? def stream_internal(task_id: int):
# TODO Move to WebSockets ??
task = task_manager.get_cached_task(task_id, update_ttl=True) task = task_manager.get_cached_task(task_id, update_ttl=True)
if not task: raise HTTPException(status_code=404, detail=f'Request {task_id} not found.') # HTTP404 NotFound if not task:
#if (id(task) != task_id): raise HTTPException(status_code=409, detail=f'Wrong task id received. Expected:{id(task)}, Received:{task_id}') # HTTP409 Conflict raise HTTPException(status_code=404, detail=f"Request {task_id} not found.") # HTTP404 NotFound
# if (id(task) != task_id): raise HTTPException(status_code=409, detail=f'Wrong task id received. Expected:{id(task)}, Received:{task_id}') # HTTP409 Conflict
if task.buffer_queue.empty() and not task.lock.locked(): if task.buffer_queue.empty() and not task.lock.locked():
if task.response: if task.response:
#log.info(f'Session {session_id} sending cached response') # log.info(f'Session {session_id} sending cached response')
return JSONResponse(task.response, headers=NOCACHE_HEADERS) return JSONResponse(task.response, headers=NOCACHE_HEADERS)
raise HTTPException(status_code=425, detail='Too Early, task not started yet.') # HTTP425 Too Early raise HTTPException(status_code=425, detail="Too Early, task not started yet.") # HTTP425 Too Early
#log.info(f'Session {session_id} opened live render stream {id(task.buffer_queue)}') # log.info(f'Session {session_id} opened live render stream {id(task.buffer_queue)}')
return StreamingResponse(task.read_buffer_generator(), media_type='application/json') return StreamingResponse(task.read_buffer_generator(), media_type="application/json")
def stop_internal(task: int): def stop_internal(task: int):
if not task: if not task:
if task_manager.current_state == task_manager.ServerStates.Online or task_manager.current_state == task_manager.ServerStates.Unavailable: if (
raise HTTPException(status_code=409, detail='Not currently running any tasks.') # HTTP409 Conflict task_manager.current_state == task_manager.ServerStates.Online
task_manager.current_state_error = StopAsyncIteration('') or task_manager.current_state == task_manager.ServerStates.Unavailable
return {'OK'} ):
raise HTTPException(status_code=409, detail="Not currently running any tasks.") # HTTP409 Conflict
task_manager.current_state_error = StopAsyncIteration("")
return {"OK"}
task_id = task task_id = task
task = task_manager.get_cached_task(task_id, update_ttl=False) task = task_manager.get_cached_task(task_id, update_ttl=False)
if not task: raise HTTPException(status_code=404, detail=f'Task {task_id} was not found.') # HTTP404 Not Found if not task:
if isinstance(task.error, StopAsyncIteration): raise HTTPException(status_code=409, detail=f'Task {task_id} is already stopped.') # HTTP409 Conflict raise HTTPException(status_code=404, detail=f"Task {task_id} was not found.") # HTTP404 Not Found
task.error = StopAsyncIteration(f'Task {task_id} stop requested.') if isinstance(task.error, StopAsyncIteration):
return {'OK'} raise HTTPException(status_code=409, detail=f"Task {task_id} is already stopped.") # HTTP409 Conflict
task.error = StopAsyncIteration(f"Task {task_id} stop requested.")
return {"OK"}
def get_image_internal(task_id: int, img_id: int): def get_image_internal(task_id: int, img_id: int):
task = task_manager.get_cached_task(task_id, update_ttl=True) task = task_manager.get_cached_task(task_id, update_ttl=True)
if not task: raise HTTPException(status_code=410, detail=f'Task {task_id} could not be found.') # HTTP404 NotFound if not task:
if not task.temp_images[img_id]: raise HTTPException(status_code=425, detail='Too Early, task data is not available yet.') # HTTP425 Too Early raise HTTPException(status_code=410, detail=f"Task {task_id} could not be found.") # HTTP404 NotFound
if not task.temp_images[img_id]:
raise HTTPException(status_code=425, detail="Too Early, task data is not available yet.") # HTTP425 Too Early
try: try:
img_data = task.temp_images[img_id] img_data = task.temp_images[img_id]
img_data.seek(0) img_data.seek(0)
return StreamingResponse(img_data, media_type='image/jpeg') return StreamingResponse(img_data, media_type="image/jpeg")
except KeyError as e: except KeyError as e:
raise HTTPException(status_code=500, detail=str(e)) raise HTTPException(status_code=500, detail=str(e))

View File

@ -7,7 +7,7 @@ Notes:
import json import json
import traceback import traceback
TASK_TTL = 15 * 60 # seconds, Discard last session's task timeout TASK_TTL = 15 * 60 # seconds, Discard last session's task timeout
import torch import torch
import queue, threading, time, weakref import queue, threading, time, weakref
@ -19,71 +19,98 @@ from easydiffusion.utils import log
from sdkit.utils import gc from sdkit.utils import gc
THREAD_NAME_PREFIX = '' THREAD_NAME_PREFIX = ""
ERR_LOCK_FAILED = ' failed to acquire lock within timeout.' ERR_LOCK_FAILED = " failed to acquire lock within timeout."
LOCK_TIMEOUT = 15 # Maximum locking time in seconds before failing a task. LOCK_TIMEOUT = 15 # Maximum locking time in seconds before failing a task.
# It's better to get an exception than a deadlock... ALWAYS use timeout in critical paths. # It's better to get an exception than a deadlock... ALWAYS use timeout in critical paths.
DEVICE_START_TIMEOUT = 60 # seconds - Maximum time to wait for a render device to init. DEVICE_START_TIMEOUT = 60 # seconds - Maximum time to wait for a render device to init.
class SymbolClass(type): # Print nicely formatted Symbol names.
def __repr__(self):
return self.__qualname__
def __str__(self):
return self.__name__
class Symbol(metaclass=SymbolClass):
pass
class SymbolClass(type): # Print nicely formatted Symbol names.
def __repr__(self): return self.__qualname__
def __str__(self): return self.__name__
class Symbol(metaclass=SymbolClass): pass
class ServerStates: class ServerStates:
class Init(Symbol): pass class Init(Symbol):
class LoadingModel(Symbol): pass pass
class Online(Symbol): pass
class Rendering(Symbol): pass
class Unavailable(Symbol): pass
class RenderTask(): # Task with output queue and completion lock. class LoadingModel(Symbol):
pass
class Online(Symbol):
pass
class Rendering(Symbol):
pass
class Unavailable(Symbol):
pass
class RenderTask: # Task with output queue and completion lock.
def __init__(self, req: GenerateImageRequest, task_data: TaskData): def __init__(self, req: GenerateImageRequest, task_data: TaskData):
task_data.request_id = id(self) task_data.request_id = id(self)
self.render_request: GenerateImageRequest = req # Initial Request self.render_request: GenerateImageRequest = req # Initial Request
self.task_data: TaskData = task_data self.task_data: TaskData = task_data
self.response: Any = None # Copy of the last reponse self.response: Any = None # Copy of the last reponse
self.render_device = None # Select the task affinity. (Not used to change active devices). self.render_device = None # Select the task affinity. (Not used to change active devices).
self.temp_images:list = [None] * req.num_outputs * (1 if task_data.show_only_filtered_image else 2) self.temp_images: list = [None] * req.num_outputs * (1 if task_data.show_only_filtered_image else 2)
self.error: Exception = None self.error: Exception = None
self.lock: threading.Lock = threading.Lock() # Locks at task start and unlocks when task is completed self.lock: threading.Lock = threading.Lock() # Locks at task start and unlocks when task is completed
self.buffer_queue: queue.Queue = queue.Queue() # Queue of JSON string segments self.buffer_queue: queue.Queue = queue.Queue() # Queue of JSON string segments
async def read_buffer_generator(self): async def read_buffer_generator(self):
try: try:
while not self.buffer_queue.empty(): while not self.buffer_queue.empty():
res = self.buffer_queue.get(block=False) res = self.buffer_queue.get(block=False)
self.buffer_queue.task_done() self.buffer_queue.task_done()
yield res yield res
except queue.Empty as e: yield except queue.Empty as e:
yield
@property @property
def status(self): def status(self):
if self.lock.locked(): if self.lock.locked():
return 'running' return "running"
if isinstance(self.error, StopAsyncIteration): if isinstance(self.error, StopAsyncIteration):
return 'stopped' return "stopped"
if self.error: if self.error:
return 'error' return "error"
if not self.buffer_queue.empty(): if not self.buffer_queue.empty():
return 'buffer' return "buffer"
if self.response: if self.response:
return 'completed' return "completed"
return 'pending' return "pending"
@property @property
def is_pending(self): def is_pending(self):
return bool(not self.response and not self.error) return bool(not self.response and not self.error)
# Temporary cache to allow to query tasks results for a short time after they are completed. # Temporary cache to allow to query tasks results for a short time after they are completed.
class DataCache(): class DataCache:
def __init__(self): def __init__(self):
self._base = dict() self._base = dict()
self._lock: threading.Lock = threading.Lock() self._lock: threading.Lock = threading.Lock()
def _get_ttl_time(self, ttl: int) -> int: def _get_ttl_time(self, ttl: int) -> int:
return int(time.time()) + ttl return int(time.time()) + ttl
def _is_expired(self, timestamp: int) -> bool: def _is_expired(self, timestamp: int) -> bool:
return int(time.time()) >= timestamp return int(time.time()) >= timestamp
def clean(self) -> None: def clean(self) -> None:
if not self._lock.acquire(blocking=True, timeout=LOCK_TIMEOUT): raise Exception('DataCache.clean' + ERR_LOCK_FAILED) if not self._lock.acquire(blocking=True, timeout=LOCK_TIMEOUT):
raise Exception("DataCache.clean" + ERR_LOCK_FAILED)
try: try:
# Create a list of expired keys to delete # Create a list of expired keys to delete
to_delete = [] to_delete = []
@ -95,20 +122,26 @@ class DataCache():
for key in to_delete: for key in to_delete:
(_, val) = self._base[key] (_, val) = self._base[key]
if isinstance(val, RenderTask): if isinstance(val, RenderTask):
log.debug(f'RenderTask {key} expired. Data removed.') log.debug(f"RenderTask {key} expired. Data removed.")
elif isinstance(val, SessionState): elif isinstance(val, SessionState):
log.debug(f'Session {key} expired. Data removed.') log.debug(f"Session {key} expired. Data removed.")
else: else:
log.debug(f'Key {key} expired. Data removed.') log.debug(f"Key {key} expired. Data removed.")
del self._base[key] del self._base[key]
finally: finally:
self._lock.release() self._lock.release()
def clear(self) -> None: def clear(self) -> None:
if not self._lock.acquire(blocking=True, timeout=LOCK_TIMEOUT): raise Exception('DataCache.clear' + ERR_LOCK_FAILED) if not self._lock.acquire(blocking=True, timeout=LOCK_TIMEOUT):
try: self._base.clear() raise Exception("DataCache.clear" + ERR_LOCK_FAILED)
finally: self._lock.release() try:
self._base.clear()
finally:
self._lock.release()
def delete(self, key: Hashable) -> bool: def delete(self, key: Hashable) -> bool:
if not self._lock.acquire(blocking=True, timeout=LOCK_TIMEOUT): raise Exception('DataCache.delete' + ERR_LOCK_FAILED) if not self._lock.acquire(blocking=True, timeout=LOCK_TIMEOUT):
raise Exception("DataCache.delete" + ERR_LOCK_FAILED)
try: try:
if key not in self._base: if key not in self._base:
return False return False
@ -116,8 +149,10 @@ class DataCache():
return True return True
finally: finally:
self._lock.release() self._lock.release()
def keep(self, key: Hashable, ttl: int) -> bool: def keep(self, key: Hashable, ttl: int) -> bool:
if not self._lock.acquire(blocking=True, timeout=LOCK_TIMEOUT): raise Exception('DataCache.keep' + ERR_LOCK_FAILED) if not self._lock.acquire(blocking=True, timeout=LOCK_TIMEOUT):
raise Exception("DataCache.keep" + ERR_LOCK_FAILED)
try: try:
if key in self._base: if key in self._base:
_, value = self._base.get(key) _, value = self._base.get(key)
@ -126,12 +161,12 @@ class DataCache():
return False return False
finally: finally:
self._lock.release() self._lock.release()
def put(self, key: Hashable, value: Any, ttl: int) -> bool: def put(self, key: Hashable, value: Any, ttl: int) -> bool:
if not self._lock.acquire(blocking=True, timeout=LOCK_TIMEOUT): raise Exception('DataCache.put' + ERR_LOCK_FAILED) if not self._lock.acquire(blocking=True, timeout=LOCK_TIMEOUT):
raise Exception("DataCache.put" + ERR_LOCK_FAILED)
try: try:
self._base[key] = ( self._base[key] = (self._get_ttl_time(ttl), value)
self._get_ttl_time(ttl), value
)
except Exception as e: except Exception as e:
log.error(traceback.format_exc()) log.error(traceback.format_exc())
return False return False
@ -139,35 +174,41 @@ class DataCache():
return True return True
finally: finally:
self._lock.release() self._lock.release()
def tryGet(self, key: Hashable) -> Any: def tryGet(self, key: Hashable) -> Any:
if not self._lock.acquire(blocking=True, timeout=LOCK_TIMEOUT): raise Exception('DataCache.tryGet' + ERR_LOCK_FAILED) if not self._lock.acquire(blocking=True, timeout=LOCK_TIMEOUT):
raise Exception("DataCache.tryGet" + ERR_LOCK_FAILED)
try: try:
ttl, value = self._base.get(key, (None, None)) ttl, value = self._base.get(key, (None, None))
if ttl is not None and self._is_expired(ttl): if ttl is not None and self._is_expired(ttl):
log.debug(f'Session {key} expired. Discarding data.') log.debug(f"Session {key} expired. Discarding data.")
del self._base[key] del self._base[key]
return None return None
return value return value
finally: finally:
self._lock.release() self._lock.release()
manager_lock = threading.RLock() manager_lock = threading.RLock()
render_threads = [] render_threads = []
current_state = ServerStates.Init current_state = ServerStates.Init
current_state_error:Exception = None current_state_error: Exception = None
tasks_queue = [] tasks_queue = []
session_cache = DataCache() session_cache = DataCache()
task_cache = DataCache() task_cache = DataCache()
weak_thread_data = weakref.WeakKeyDictionary() weak_thread_data = weakref.WeakKeyDictionary()
idle_event: threading.Event = threading.Event() idle_event: threading.Event = threading.Event()
class SessionState():
class SessionState:
def __init__(self, id: str): def __init__(self, id: str):
self._id = id self._id = id
self._tasks_ids = [] self._tasks_ids = []
@property @property
def id(self): def id(self):
return self._id return self._id
@property @property
def tasks(self): def tasks(self):
tasks = [] tasks = []
@ -176,6 +217,7 @@ class SessionState():
if task: if task:
tasks.append(task) tasks.append(task)
return tasks return tasks
def put(self, task, ttl=TASK_TTL): def put(self, task, ttl=TASK_TTL):
task_id = id(task) task_id = id(task)
self._tasks_ids.append(task_id) self._tasks_ids.append(task_id)
@ -185,10 +227,12 @@ class SessionState():
self._tasks_ids.pop(0) self._tasks_ids.pop(0)
return True return True
def thread_get_next_task(): def thread_get_next_task():
from easydiffusion import renderer from easydiffusion import renderer
if not manager_lock.acquire(blocking=True, timeout=LOCK_TIMEOUT): if not manager_lock.acquire(blocking=True, timeout=LOCK_TIMEOUT):
log.warn(f'Render thread on device: {renderer.context.device} failed to acquire manager lock.') log.warn(f"Render thread on device: {renderer.context.device} failed to acquire manager lock.")
return None return None
if len(tasks_queue) <= 0: if len(tasks_queue) <= 0:
manager_lock.release() manager_lock.release()
@ -202,10 +246,10 @@ def thread_get_next_task():
continue # requested device alive, skip current one. continue # requested device alive, skip current one.
else: else:
# Requested device is not active, return error to UI. # Requested device is not active, return error to UI.
queued_task.error = Exception(queued_task.render_device + ' is not currently active.') queued_task.error = Exception(queued_task.render_device + " is not currently active.")
task = queued_task task = queued_task
break break
if not queued_task.render_device and renderer.context.device == 'cpu' and is_alive() > 1: if not queued_task.render_device and renderer.context.device == "cpu" and is_alive() > 1:
# not asking for any specific devices, cpu want to grab task but other render devices are alive. # not asking for any specific devices, cpu want to grab task but other render devices are alive.
continue # Skip Tasks, don't run on CPU unless there is nothing else or user asked for it. continue # Skip Tasks, don't run on CPU unless there is nothing else or user asked for it.
task = queued_task task = queued_task
@ -216,17 +260,19 @@ def thread_get_next_task():
finally: finally:
manager_lock.release() manager_lock.release()
def thread_render(device): def thread_render(device):
global current_state, current_state_error global current_state, current_state_error
from easydiffusion import renderer, model_manager from easydiffusion import renderer, model_manager
try: try:
renderer.init(device) renderer.init(device)
weak_thread_data[threading.current_thread()] = { weak_thread_data[threading.current_thread()] = {
'device': renderer.context.device, "device": renderer.context.device,
'device_name': renderer.context.device_name, "device_name": renderer.context.device_name,
'alive': True "alive": True,
} }
current_state = ServerStates.LoadingModel current_state = ServerStates.LoadingModel
@ -235,17 +281,14 @@ def thread_render(device):
current_state = ServerStates.Online current_state = ServerStates.Online
except Exception as e: except Exception as e:
log.error(traceback.format_exc()) log.error(traceback.format_exc())
weak_thread_data[threading.current_thread()] = { weak_thread_data[threading.current_thread()] = {"error": e, "alive": False}
'error': e,
'alive': False
}
return return
while True: while True:
session_cache.clean() session_cache.clean()
task_cache.clean() task_cache.clean()
if not weak_thread_data[threading.current_thread()]['alive']: if not weak_thread_data[threading.current_thread()]["alive"]:
log.info(f'Shutting down thread for device {renderer.context.device}') log.info(f"Shutting down thread for device {renderer.context.device}")
model_manager.unload_all(renderer.context) model_manager.unload_all(renderer.context)
return return
if isinstance(current_state_error, SystemExit): if isinstance(current_state_error, SystemExit):
@ -258,39 +301,47 @@ def thread_render(device):
continue continue
if task.error is not None: if task.error is not None:
log.error(task.error) log.error(task.error)
task.response = {"status": 'failed', "detail": str(task.error)} task.response = {"status": "failed", "detail": str(task.error)}
task.buffer_queue.put(json.dumps(task.response)) task.buffer_queue.put(json.dumps(task.response))
continue continue
if current_state_error: if current_state_error:
task.error = current_state_error task.error = current_state_error
task.response = {"status": 'failed', "detail": str(task.error)} task.response = {"status": "failed", "detail": str(task.error)}
task.buffer_queue.put(json.dumps(task.response)) task.buffer_queue.put(json.dumps(task.response))
continue continue
log.info(f'Session {task.task_data.session_id} starting task {id(task)} on {renderer.context.device_name}') log.info(f"Session {task.task_data.session_id} starting task {id(task)} on {renderer.context.device_name}")
if not task.lock.acquire(blocking=False): raise Exception('Got locked task from queue.') if not task.lock.acquire(blocking=False):
raise Exception("Got locked task from queue.")
try: try:
def step_callback(): def step_callback():
global current_state_error global current_state_error
if isinstance(current_state_error, SystemExit) or isinstance(current_state_error, StopAsyncIteration) or isinstance(task.error, StopAsyncIteration): if (
isinstance(current_state_error, SystemExit)
or isinstance(current_state_error, StopAsyncIteration)
or isinstance(task.error, StopAsyncIteration)
):
renderer.context.stop_processing = True renderer.context.stop_processing = True
if isinstance(current_state_error, StopAsyncIteration): if isinstance(current_state_error, StopAsyncIteration):
task.error = current_state_error task.error = current_state_error
current_state_error = None current_state_error = None
log.info(f'Session {task.task_data.session_id} sent cancel signal for task {id(task)}') log.info(f"Session {task.task_data.session_id} sent cancel signal for task {id(task)}")
current_state = ServerStates.LoadingModel current_state = ServerStates.LoadingModel
model_manager.resolve_model_paths(task.task_data) model_manager.resolve_model_paths(task.task_data)
model_manager.reload_models_if_necessary(renderer.context, task.task_data) model_manager.reload_models_if_necessary(renderer.context, task.task_data)
current_state = ServerStates.Rendering current_state = ServerStates.Rendering
task.response = renderer.make_images(task.render_request, task.task_data, task.buffer_queue, task.temp_images, step_callback) task.response = renderer.make_images(
task.render_request, task.task_data, task.buffer_queue, task.temp_images, step_callback
)
# Before looping back to the generator, mark cache as still alive. # Before looping back to the generator, mark cache as still alive.
task_cache.keep(id(task), TASK_TTL) task_cache.keep(id(task), TASK_TTL)
session_cache.keep(task.task_data.session_id, TASK_TTL) session_cache.keep(task.task_data.session_id, TASK_TTL)
except Exception as e: except Exception as e:
task.error = str(e) task.error = str(e)
task.response = {"status": 'failed', "detail": str(task.error)} task.response = {"status": "failed", "detail": str(task.error)}
task.buffer_queue.put(json.dumps(task.response)) task.buffer_queue.put(json.dumps(task.response))
log.error(traceback.format_exc()) log.error(traceback.format_exc())
finally: finally:
@ -299,21 +350,25 @@ def thread_render(device):
task_cache.keep(id(task), TASK_TTL) task_cache.keep(id(task), TASK_TTL)
session_cache.keep(task.task_data.session_id, TASK_TTL) session_cache.keep(task.task_data.session_id, TASK_TTL)
if isinstance(task.error, StopAsyncIteration): if isinstance(task.error, StopAsyncIteration):
log.info(f'Session {task.task_data.session_id} task {id(task)} cancelled!') log.info(f"Session {task.task_data.session_id} task {id(task)} cancelled!")
elif task.error is not None: elif task.error is not None:
log.info(f'Session {task.task_data.session_id} task {id(task)} failed!') log.info(f"Session {task.task_data.session_id} task {id(task)} failed!")
else: else:
log.info(f'Session {task.task_data.session_id} task {id(task)} completed by {renderer.context.device_name}.') log.info(
f"Session {task.task_data.session_id} task {id(task)} completed by {renderer.context.device_name}."
)
current_state = ServerStates.Online current_state = ServerStates.Online
def get_cached_task(task_id:str, update_ttl:bool=False):
def get_cached_task(task_id: str, update_ttl: bool = False):
# By calling keep before tryGet, wont discard if was expired. # By calling keep before tryGet, wont discard if was expired.
if update_ttl and not task_cache.keep(task_id, TASK_TTL): if update_ttl and not task_cache.keep(task_id, TASK_TTL):
# Failed to keep task, already gone. # Failed to keep task, already gone.
return None return None
return task_cache.tryGet(task_id) return task_cache.tryGet(task_id)
def get_cached_session(session_id:str, update_ttl:bool=False):
def get_cached_session(session_id: str, update_ttl: bool = False):
if update_ttl: if update_ttl:
session_cache.keep(session_id, TASK_TTL) session_cache.keep(session_id, TASK_TTL)
session = session_cache.tryGet(session_id) session = session_cache.tryGet(session_id)
@ -322,64 +377,68 @@ def get_cached_session(session_id:str, update_ttl:bool=False):
session_cache.put(session_id, session, TASK_TTL) session_cache.put(session_id, session, TASK_TTL)
return session return session
def get_devices(): def get_devices():
devices = { devices = {
'all': {}, "all": {},
'active': {}, "active": {},
} }
def get_device_info(device): def get_device_info(device):
if device == 'cpu': if device == "cpu":
return {'name': device_manager.get_processor_name()} return {"name": device_manager.get_processor_name()}
mem_free, mem_total = torch.cuda.mem_get_info(device) mem_free, mem_total = torch.cuda.mem_get_info(device)
mem_free /= float(10**9) mem_free /= float(10**9)
mem_total /= float(10**9) mem_total /= float(10**9)
return { return {
'name': torch.cuda.get_device_name(device), "name": torch.cuda.get_device_name(device),
'mem_free': mem_free, "mem_free": mem_free,
'mem_total': mem_total, "mem_total": mem_total,
'max_vram_usage_level': device_manager.get_max_vram_usage_level(device), "max_vram_usage_level": device_manager.get_max_vram_usage_level(device),
} }
# list the compatible devices # list the compatible devices
gpu_count = torch.cuda.device_count() gpu_count = torch.cuda.device_count()
for device in range(gpu_count): for device in range(gpu_count):
device = f'cuda:{device}' device = f"cuda:{device}"
if not device_manager.is_device_compatible(device): if not device_manager.is_device_compatible(device):
continue continue
devices['all'].update({device: get_device_info(device)}) devices["all"].update({device: get_device_info(device)})
devices['all'].update({'cpu': get_device_info('cpu')}) devices["all"].update({"cpu": get_device_info("cpu")})
# list the activated devices # list the activated devices
if not manager_lock.acquire(blocking=True, timeout=LOCK_TIMEOUT): raise Exception('get_devices' + ERR_LOCK_FAILED) if not manager_lock.acquire(blocking=True, timeout=LOCK_TIMEOUT):
raise Exception("get_devices" + ERR_LOCK_FAILED)
try: try:
for rthread in render_threads: for rthread in render_threads:
if not rthread.is_alive(): if not rthread.is_alive():
continue continue
weak_data = weak_thread_data.get(rthread) weak_data = weak_thread_data.get(rthread)
if not weak_data or not 'device' in weak_data or not 'device_name' in weak_data: if not weak_data or not "device" in weak_data or not "device_name" in weak_data:
continue continue
device = weak_data['device'] device = weak_data["device"]
devices['active'].update({device: get_device_info(device)}) devices["active"].update({device: get_device_info(device)})
finally: finally:
manager_lock.release() manager_lock.release()
return devices return devices
def is_alive(device=None): def is_alive(device=None):
if not manager_lock.acquire(blocking=True, timeout=LOCK_TIMEOUT): raise Exception('is_alive' + ERR_LOCK_FAILED) if not manager_lock.acquire(blocking=True, timeout=LOCK_TIMEOUT):
raise Exception("is_alive" + ERR_LOCK_FAILED)
nbr_alive = 0 nbr_alive = 0
try: try:
for rthread in render_threads: for rthread in render_threads:
if device is not None: if device is not None:
weak_data = weak_thread_data.get(rthread) weak_data = weak_thread_data.get(rthread)
if weak_data is None or not 'device' in weak_data or weak_data['device'] is None: if weak_data is None or not "device" in weak_data or weak_data["device"] is None:
continue continue
thread_device = weak_data['device'] thread_device = weak_data["device"]
if thread_device != device: if thread_device != device:
continue continue
if rthread.is_alive(): if rthread.is_alive():
@ -388,11 +447,13 @@ def is_alive(device=None):
finally: finally:
manager_lock.release() manager_lock.release()
def start_render_thread(device): def start_render_thread(device):
if not manager_lock.acquire(blocking=True, timeout=LOCK_TIMEOUT): raise Exception('start_render_thread' + ERR_LOCK_FAILED) if not manager_lock.acquire(blocking=True, timeout=LOCK_TIMEOUT):
log.info(f'Start new Rendering Thread on device: {device}') raise Exception("start_render_thread" + ERR_LOCK_FAILED)
log.info(f"Start new Rendering Thread on device: {device}")
try: try:
rthread = threading.Thread(target=thread_render, kwargs={'device': device}) rthread = threading.Thread(target=thread_render, kwargs={"device": device})
rthread.daemon = True rthread.daemon = True
rthread.name = THREAD_NAME_PREFIX + device rthread.name = THREAD_NAME_PREFIX + device
rthread.start() rthread.start()
@ -400,8 +461,8 @@ def start_render_thread(device):
finally: finally:
manager_lock.release() manager_lock.release()
timeout = DEVICE_START_TIMEOUT timeout = DEVICE_START_TIMEOUT
while not rthread.is_alive() or not rthread in weak_thread_data or not 'device' in weak_thread_data[rthread]: while not rthread.is_alive() or not rthread in weak_thread_data or not "device" in weak_thread_data[rthread]:
if rthread in weak_thread_data and 'error' in weak_thread_data[rthread]: if rthread in weak_thread_data and "error" in weak_thread_data[rthread]:
log.error(f"{rthread}, {device}, error: {weak_thread_data[rthread]['error']}") log.error(f"{rthread}, {device}, error: {weak_thread_data[rthread]['error']}")
return False return False
if timeout <= 0: if timeout <= 0:
@ -410,25 +471,27 @@ def start_render_thread(device):
time.sleep(1) time.sleep(1)
return True return True
def stop_render_thread(device): def stop_render_thread(device):
try: try:
device_manager.validate_device_id(device, log_prefix='stop_render_thread') device_manager.validate_device_id(device, log_prefix="stop_render_thread")
except: except:
log.error(traceback.format_exc()) log.error(traceback.format_exc())
return False return False
if not manager_lock.acquire(blocking=True, timeout=LOCK_TIMEOUT): raise Exception('stop_render_thread' + ERR_LOCK_FAILED) if not manager_lock.acquire(blocking=True, timeout=LOCK_TIMEOUT):
log.info(f'Stopping Rendering Thread on device: {device}') raise Exception("stop_render_thread" + ERR_LOCK_FAILED)
log.info(f"Stopping Rendering Thread on device: {device}")
try: try:
thread_to_remove = None thread_to_remove = None
for rthread in render_threads: for rthread in render_threads:
weak_data = weak_thread_data.get(rthread) weak_data = weak_thread_data.get(rthread)
if weak_data is None or not 'device' in weak_data or weak_data['device'] is None: if weak_data is None or not "device" in weak_data or weak_data["device"] is None:
continue continue
thread_device = weak_data['device'] thread_device = weak_data["device"]
if thread_device == device: if thread_device == device:
weak_data['alive'] = False weak_data["alive"] = False
thread_to_remove = rthread thread_to_remove = rthread
break break
if thread_to_remove is not None: if thread_to_remove is not None:
@ -439,44 +502,51 @@ def stop_render_thread(device):
return False return False
def update_render_threads(render_devices, active_devices): def update_render_threads(render_devices, active_devices):
devices_to_start, devices_to_stop = device_manager.get_device_delta(render_devices, active_devices) devices_to_start, devices_to_stop = device_manager.get_device_delta(render_devices, active_devices)
log.debug(f'devices_to_start: {devices_to_start}') log.debug(f"devices_to_start: {devices_to_start}")
log.debug(f'devices_to_stop: {devices_to_stop}') log.debug(f"devices_to_stop: {devices_to_stop}")
for device in devices_to_stop: for device in devices_to_stop:
if is_alive(device) <= 0: if is_alive(device) <= 0:
log.debug(f'{device} is not alive') log.debug(f"{device} is not alive")
continue continue
if not stop_render_thread(device): if not stop_render_thread(device):
log.warn(f'{device} could not stop render thread') log.warn(f"{device} could not stop render thread")
for device in devices_to_start: for device in devices_to_start:
if is_alive(device) >= 1: if is_alive(device) >= 1:
log.debug(f'{device} already registered.') log.debug(f"{device} already registered.")
continue continue
if not start_render_thread(device): if not start_render_thread(device):
log.warn(f'{device} failed to start.') log.warn(f"{device} failed to start.")
if is_alive() <= 0: # No running devices, probably invalid user config. if is_alive() <= 0: # No running devices, probably invalid user config.
raise EnvironmentError('ERROR: No active render devices! Please verify the "render_devices" value in config.json') raise EnvironmentError(
'ERROR: No active render devices! Please verify the "render_devices" value in config.json'
)
log.debug(f"active devices: {get_devices()['active']}") log.debug(f"active devices: {get_devices()['active']}")
def shutdown_event(): # Signal render thread to close on shutdown
def shutdown_event(): # Signal render thread to close on shutdown
global current_state_error global current_state_error
current_state_error = SystemExit('Application shutting down.') current_state_error = SystemExit("Application shutting down.")
def render(render_req: GenerateImageRequest, task_data: TaskData): def render(render_req: GenerateImageRequest, task_data: TaskData):
current_thread_count = is_alive() current_thread_count = is_alive()
if current_thread_count <= 0: # Render thread is dead if current_thread_count <= 0: # Render thread is dead
raise ChildProcessError('Rendering thread has died.') raise ChildProcessError("Rendering thread has died.")
# Alive, check if task in cache # Alive, check if task in cache
session = get_cached_session(task_data.session_id, update_ttl=True) session = get_cached_session(task_data.session_id, update_ttl=True)
pending_tasks = list(filter(lambda t: t.is_pending, session.tasks)) pending_tasks = list(filter(lambda t: t.is_pending, session.tasks))
if current_thread_count < len(pending_tasks): if current_thread_count < len(pending_tasks):
raise ConnectionRefusedError(f'Session {task_data.session_id} already has {len(pending_tasks)} pending tasks out of {current_thread_count}.') raise ConnectionRefusedError(
f"Session {task_data.session_id} already has {len(pending_tasks)} pending tasks out of {current_thread_count}."
)
new_task = RenderTask(render_req, task_data) new_task = RenderTask(render_req, task_data)
if session.put(new_task, TASK_TTL): if session.put(new_task, TASK_TTL):
@ -489,4 +559,4 @@ def render(render_req: GenerateImageRequest, task_data: TaskData):
return new_task return new_task
finally: finally:
manager_lock.release() manager_lock.release()
raise RuntimeError('Failed to add task to cache.') raise RuntimeError("Failed to add task to cache.")

View File

@ -1,6 +1,7 @@
from pydantic import BaseModel from pydantic import BaseModel
from typing import Any from typing import Any
class GenerateImageRequest(BaseModel): class GenerateImageRequest(BaseModel):
prompt: str = "" prompt: str = ""
negative_prompt: str = "" negative_prompt: str = ""
@ -18,29 +19,31 @@ class GenerateImageRequest(BaseModel):
prompt_strength: float = 0.8 prompt_strength: float = 0.8
preserve_init_image_color_profile = False preserve_init_image_color_profile = False
sampler_name: str = None # "ddim", "plms", "heun", "euler", "euler_a", "dpm2", "dpm2_a", "lms" sampler_name: str = None # "ddim", "plms", "heun", "euler", "euler_a", "dpm2", "dpm2_a", "lms"
hypernetwork_strength: float = 0 hypernetwork_strength: float = 0
class TaskData(BaseModel): class TaskData(BaseModel):
request_id: str = None request_id: str = None
session_id: str = "session" session_id: str = "session"
save_to_disk_path: str = None save_to_disk_path: str = None
vram_usage_level: str = "balanced" # or "low" or "medium" vram_usage_level: str = "balanced" # or "low" or "medium"
use_face_correction: str = None # or "GFPGANv1.3" use_face_correction: str = None # or "GFPGANv1.3"
use_upscale: str = None # or "RealESRGAN_x4plus" or "RealESRGAN_x4plus_anime_6B" use_upscale: str = None # or "RealESRGAN_x4plus" or "RealESRGAN_x4plus_anime_6B"
upscale_amount: int = 4 # or 2 upscale_amount: int = 4 # or 2
use_stable_diffusion_model: str = "sd-v1-4" use_stable_diffusion_model: str = "sd-v1-4"
# use_stable_diffusion_config: str = "v1-inference" # use_stable_diffusion_config: str = "v1-inference"
use_vae_model: str = None use_vae_model: str = None
use_hypernetwork_model: str = None use_hypernetwork_model: str = None
show_only_filtered_image: bool = False show_only_filtered_image: bool = False
output_format: str = "jpeg" # or "png" output_format: str = "jpeg" # or "png"
output_quality: int = 75 output_quality: int = 75
metadata_output_format: str = "txt" # or "json" metadata_output_format: str = "txt" # or "json"
stream_image_progress: bool = False stream_image_progress: bool = False
class MergeRequest(BaseModel): class MergeRequest(BaseModel):
model0: str = None model0: str = None
model1: str = None model1: str = None
@ -48,8 +51,9 @@ class MergeRequest(BaseModel):
out_path: str = "mix" out_path: str = "mix"
use_fp16 = True use_fp16 = True
class Image: class Image:
data: str # base64 data: str # base64
seed: int seed: int
is_nsfw: bool is_nsfw: bool
path_abs: str = None path_abs: str = None
@ -65,6 +69,7 @@ class Image:
"path_abs": self.path_abs, "path_abs": self.path_abs,
} }
class Response: class Response:
render_request: GenerateImageRequest render_request: GenerateImageRequest
task_data: TaskData task_data: TaskData
@ -80,7 +85,7 @@ class Response:
del self.render_request.init_image_mask del self.render_request.init_image_mask
res = { res = {
"status": 'succeeded', "status": "succeeded",
"render_request": self.render_request.dict(), "render_request": self.render_request.dict(),
"task_data": self.task_data.dict(), "task_data": self.task_data.dict(),
"output": [], "output": [],
@ -91,5 +96,6 @@ class Response:
return res return res
class UserInitiatedStop(Exception): class UserInitiatedStop(Exception):
pass pass

View File

@ -1,8 +1,8 @@
import logging import logging
log = logging.getLogger('easydiffusion') log = logging.getLogger("easydiffusion")
from .save_utils import ( from .save_utils import (
save_images_to_disk, save_images_to_disk,
get_printable_request, get_printable_request,
) )

View File

@ -7,89 +7,126 @@ from easydiffusion.types import TaskData, GenerateImageRequest
from sdkit.utils import save_images, save_dicts from sdkit.utils import save_images, save_dicts
filename_regex = re.compile('[^a-zA-Z0-9._-]') filename_regex = re.compile("[^a-zA-Z0-9._-]")
# keep in sync with `ui/media/js/dnd.js` # keep in sync with `ui/media/js/dnd.js`
TASK_TEXT_MAPPING = { TASK_TEXT_MAPPING = {
'prompt': 'Prompt', "prompt": "Prompt",
'width': 'Width', "width": "Width",
'height': 'Height', "height": "Height",
'seed': 'Seed', "seed": "Seed",
'num_inference_steps': 'Steps', "num_inference_steps": "Steps",
'guidance_scale': 'Guidance Scale', "guidance_scale": "Guidance Scale",
'prompt_strength': 'Prompt Strength', "prompt_strength": "Prompt Strength",
'use_face_correction': 'Use Face Correction', "use_face_correction": "Use Face Correction",
'use_upscale': 'Use Upscaling', "use_upscale": "Use Upscaling",
'upscale_amount': 'Upscale By', "upscale_amount": "Upscale By",
'sampler_name': 'Sampler', "sampler_name": "Sampler",
'negative_prompt': 'Negative Prompt', "negative_prompt": "Negative Prompt",
'use_stable_diffusion_model': 'Stable Diffusion model', "use_stable_diffusion_model": "Stable Diffusion model",
'use_vae_model': 'VAE model', "use_vae_model": "VAE model",
'use_hypernetwork_model': 'Hypernetwork model', "use_hypernetwork_model": "Hypernetwork model",
'hypernetwork_strength': 'Hypernetwork Strength' "hypernetwork_strength": "Hypernetwork Strength",
} }
def save_images_to_disk(images: list, filtered_images: list, req: GenerateImageRequest, task_data: TaskData): def save_images_to_disk(images: list, filtered_images: list, req: GenerateImageRequest, task_data: TaskData):
now = time.time() now = time.time()
save_dir_path = os.path.join(task_data.save_to_disk_path, filename_regex.sub('_', task_data.session_id)) save_dir_path = os.path.join(task_data.save_to_disk_path, filename_regex.sub("_", task_data.session_id))
metadata_entries = get_metadata_entries_for_request(req, task_data) metadata_entries = get_metadata_entries_for_request(req, task_data)
make_filename = make_filename_callback(req, now=now) make_filename = make_filename_callback(req, now=now)
if task_data.show_only_filtered_image or filtered_images is images: if task_data.show_only_filtered_image or filtered_images is images:
save_images(filtered_images, save_dir_path, file_name=make_filename, output_format=task_data.output_format, output_quality=task_data.output_quality) save_images(
if task_data.metadata_output_format.lower() in ['json', 'txt', 'embed']: filtered_images,
save_dicts(metadata_entries, save_dir_path, file_name=make_filename, output_format=task_data.metadata_output_format, file_format=task_data.output_format) save_dir_path,
file_name=make_filename,
output_format=task_data.output_format,
output_quality=task_data.output_quality,
)
if task_data.metadata_output_format.lower() in ["json", "txt", "embed"]:
save_dicts(
metadata_entries,
save_dir_path,
file_name=make_filename,
output_format=task_data.metadata_output_format,
file_format=task_data.output_format,
)
else: else:
make_filter_filename = make_filename_callback(req, now=now, suffix='filtered') make_filter_filename = make_filename_callback(req, now=now, suffix="filtered")
save_images(
images,
save_dir_path,
file_name=make_filename,
output_format=task_data.output_format,
output_quality=task_data.output_quality,
)
save_images(
filtered_images,
save_dir_path,
file_name=make_filter_filename,
output_format=task_data.output_format,
output_quality=task_data.output_quality,
)
if task_data.metadata_output_format.lower() in ["json", "txt", "embed"]:
save_dicts(
metadata_entries,
save_dir_path,
file_name=make_filter_filename,
output_format=task_data.metadata_output_format,
file_format=task_data.output_format,
)
save_images(images, save_dir_path, file_name=make_filename, output_format=task_data.output_format, output_quality=task_data.output_quality)
save_images(filtered_images, save_dir_path, file_name=make_filter_filename, output_format=task_data.output_format, output_quality=task_data.output_quality)
if task_data.metadata_output_format.lower() in ['json', 'txt', 'embed']:
save_dicts(metadata_entries, save_dir_path, file_name=make_filter_filename, output_format=task_data.metadata_output_format, file_format=task_data.output_format)
def get_metadata_entries_for_request(req: GenerateImageRequest, task_data: TaskData): def get_metadata_entries_for_request(req: GenerateImageRequest, task_data: TaskData):
metadata = get_printable_request(req) metadata = get_printable_request(req)
metadata.update({ metadata.update(
'use_stable_diffusion_model': task_data.use_stable_diffusion_model, {
'use_vae_model': task_data.use_vae_model, "use_stable_diffusion_model": task_data.use_stable_diffusion_model,
'use_hypernetwork_model': task_data.use_hypernetwork_model, "use_vae_model": task_data.use_vae_model,
'use_face_correction': task_data.use_face_correction, "use_hypernetwork_model": task_data.use_hypernetwork_model,
'use_upscale': task_data.use_upscale, "use_face_correction": task_data.use_face_correction,
}) "use_upscale": task_data.use_upscale,
if metadata['use_upscale'] is not None: }
metadata['upscale_amount'] = task_data.upscale_amount )
if (task_data.use_hypernetwork_model is None): if metadata["use_upscale"] is not None:
del metadata['hypernetwork_strength'] metadata["upscale_amount"] = task_data.upscale_amount
if task_data.use_hypernetwork_model is None:
del metadata["hypernetwork_strength"]
# if text, format it in the text format expected by the UI # if text, format it in the text format expected by the UI
is_txt_format = (task_data.metadata_output_format.lower() == 'txt') is_txt_format = task_data.metadata_output_format.lower() == "txt"
if is_txt_format: if is_txt_format:
metadata = {TASK_TEXT_MAPPING[key]: val for key, val in metadata.items() if key in TASK_TEXT_MAPPING} metadata = {TASK_TEXT_MAPPING[key]: val for key, val in metadata.items() if key in TASK_TEXT_MAPPING}
entries = [metadata.copy() for _ in range(req.num_outputs)] entries = [metadata.copy() for _ in range(req.num_outputs)]
for i, entry in enumerate(entries): for i, entry in enumerate(entries):
entry['Seed' if is_txt_format else 'seed'] = req.seed + i entry["Seed" if is_txt_format else "seed"] = req.seed + i
return entries return entries
def get_printable_request(req: GenerateImageRequest): def get_printable_request(req: GenerateImageRequest):
metadata = req.dict() metadata = req.dict()
del metadata['init_image'] del metadata["init_image"]
del metadata['init_image_mask'] del metadata["init_image_mask"]
if (req.init_image is None): if req.init_image is None:
del metadata['prompt_strength'] del metadata["prompt_strength"]
return metadata return metadata
def make_filename_callback(req: GenerateImageRequest, suffix=None, now=None): def make_filename_callback(req: GenerateImageRequest, suffix=None, now=None):
if now is None: if now is None:
now = time.time() now = time.time()
def make_filename(i):
img_id = base64.b64encode(int(now+i).to_bytes(8, 'big')).decode() # Generate unique ID based on time.
img_id = img_id.translate({43:None, 47:None, 61:None})[-8:] # Remove + / = and keep last 8 chars.
prompt_flattened = filename_regex.sub('_', req.prompt)[:50] def make_filename(i):
img_id = base64.b64encode(int(now + i).to_bytes(8, "big")).decode() # Generate unique ID based on time.
img_id = img_id.translate({43: None, 47: None, 61: None})[-8:] # Remove + / = and keep last 8 chars.
prompt_flattened = filename_regex.sub("_", req.prompt)[:50]
name = f"{prompt_flattened}_{img_id}" name = f"{prompt_flattened}_{img_id}"
name = name if suffix is None else f'{name}_{suffix}' name = name if suffix is None else f"{name}_{suffix}"
return name return name
return make_filename return make_filename