forked from extern/easydiffusion
Format code, PEP8 using Black
This commit is contained in:
parent
0ad08c609d
commit
2eb317c6b6
@ -7,7 +7,7 @@ import logging
|
|||||||
import shlex
|
import shlex
|
||||||
from rich.logging import RichHandler
|
from rich.logging import RichHandler
|
||||||
|
|
||||||
from sdkit.utils import log as sdkit_log # hack, so we can overwrite the log config
|
from sdkit.utils import log as sdkit_log # hack, so we can overwrite the log config
|
||||||
|
|
||||||
from easydiffusion import task_manager
|
from easydiffusion import task_manager
|
||||||
from easydiffusion.utils import log
|
from easydiffusion.utils import log
|
||||||
@ -16,83 +16,86 @@ from easydiffusion.utils import log
|
|||||||
for handler in logging.root.handlers[:]:
|
for handler in logging.root.handlers[:]:
|
||||||
logging.root.removeHandler(handler)
|
logging.root.removeHandler(handler)
|
||||||
|
|
||||||
LOG_FORMAT = '%(asctime)s.%(msecs)03d %(levelname)s %(threadName)s %(message)s'
|
LOG_FORMAT = "%(asctime)s.%(msecs)03d %(levelname)s %(threadName)s %(message)s"
|
||||||
logging.basicConfig(
|
logging.basicConfig(
|
||||||
level=logging.INFO,
|
level=logging.INFO,
|
||||||
format=LOG_FORMAT,
|
format=LOG_FORMAT,
|
||||||
datefmt="%X",
|
datefmt="%X",
|
||||||
handlers=[RichHandler(markup=True, rich_tracebacks=False, show_time=False, show_level=False)],
|
handlers=[RichHandler(markup=True, rich_tracebacks=False, show_time=False, show_level=False)],
|
||||||
)
|
)
|
||||||
|
|
||||||
SD_DIR = os.getcwd()
|
SD_DIR = os.getcwd()
|
||||||
|
|
||||||
SD_UI_DIR = os.getenv('SD_UI_PATH', None)
|
SD_UI_DIR = os.getenv("SD_UI_PATH", None)
|
||||||
sys.path.append(os.path.dirname(SD_UI_DIR))
|
sys.path.append(os.path.dirname(SD_UI_DIR))
|
||||||
|
|
||||||
CONFIG_DIR = os.path.abspath(os.path.join(SD_UI_DIR, '..', 'scripts'))
|
CONFIG_DIR = os.path.abspath(os.path.join(SD_UI_DIR, "..", "scripts"))
|
||||||
MODELS_DIR = os.path.abspath(os.path.join(SD_DIR, '..', 'models'))
|
MODELS_DIR = os.path.abspath(os.path.join(SD_DIR, "..", "models"))
|
||||||
|
|
||||||
USER_UI_PLUGINS_DIR = os.path.abspath(os.path.join(SD_DIR, '..', 'plugins', 'ui'))
|
USER_UI_PLUGINS_DIR = os.path.abspath(os.path.join(SD_DIR, "..", "plugins", "ui"))
|
||||||
CORE_UI_PLUGINS_DIR = os.path.abspath(os.path.join(SD_UI_DIR, 'plugins', 'ui'))
|
CORE_UI_PLUGINS_DIR = os.path.abspath(os.path.join(SD_UI_DIR, "plugins", "ui"))
|
||||||
UI_PLUGINS_SOURCES = ((CORE_UI_PLUGINS_DIR, 'core'), (USER_UI_PLUGINS_DIR, 'user'))
|
UI_PLUGINS_SOURCES = ((CORE_UI_PLUGINS_DIR, "core"), (USER_UI_PLUGINS_DIR, "user"))
|
||||||
|
|
||||||
OUTPUT_DIRNAME = "Stable Diffusion UI" # in the user's home folder
|
OUTPUT_DIRNAME = "Stable Diffusion UI" # in the user's home folder
|
||||||
PRESERVE_CONFIG_VARS = ['FORCE_FULL_PRECISION']
|
PRESERVE_CONFIG_VARS = ["FORCE_FULL_PRECISION"]
|
||||||
TASK_TTL = 15 * 60 # Discard last session's task timeout
|
TASK_TTL = 15 * 60 # Discard last session's task timeout
|
||||||
APP_CONFIG_DEFAULTS = {
|
APP_CONFIG_DEFAULTS = {
|
||||||
# auto: selects the cuda device with the most free memory, cuda: use the currently active cuda device.
|
# auto: selects the cuda device with the most free memory, cuda: use the currently active cuda device.
|
||||||
'render_devices': 'auto', # valid entries: 'auto', 'cpu' or 'cuda:N' (where N is a GPU index)
|
"render_devices": "auto", # valid entries: 'auto', 'cpu' or 'cuda:N' (where N is a GPU index)
|
||||||
'update_branch': 'main',
|
"update_branch": "main",
|
||||||
'ui': {
|
"ui": {
|
||||||
'open_browser_on_start': True,
|
"open_browser_on_start": True,
|
||||||
}
|
},
|
||||||
}
|
}
|
||||||
|
|
||||||
|
|
||||||
def init():
|
def init():
|
||||||
os.makedirs(USER_UI_PLUGINS_DIR, exist_ok=True)
|
os.makedirs(USER_UI_PLUGINS_DIR, exist_ok=True)
|
||||||
|
|
||||||
update_render_threads()
|
update_render_threads()
|
||||||
|
|
||||||
|
|
||||||
def getConfig(default_val=APP_CONFIG_DEFAULTS):
|
def getConfig(default_val=APP_CONFIG_DEFAULTS):
|
||||||
try:
|
try:
|
||||||
config_json_path = os.path.join(CONFIG_DIR, 'config.json')
|
config_json_path = os.path.join(CONFIG_DIR, "config.json")
|
||||||
if not os.path.exists(config_json_path):
|
if not os.path.exists(config_json_path):
|
||||||
config = default_val
|
config = default_val
|
||||||
else:
|
else:
|
||||||
with open(config_json_path, 'r', encoding='utf-8') as f:
|
with open(config_json_path, "r", encoding="utf-8") as f:
|
||||||
config = json.load(f)
|
config = json.load(f)
|
||||||
if 'net' not in config:
|
if "net" not in config:
|
||||||
config['net'] = {}
|
config["net"] = {}
|
||||||
if os.getenv('SD_UI_BIND_PORT') is not None:
|
if os.getenv("SD_UI_BIND_PORT") is not None:
|
||||||
config['net']['listen_port'] = int(os.getenv('SD_UI_BIND_PORT'))
|
config["net"]["listen_port"] = int(os.getenv("SD_UI_BIND_PORT"))
|
||||||
else:
|
else:
|
||||||
config['net']['listen_port'] = 9000
|
config["net"]["listen_port"] = 9000
|
||||||
if os.getenv('SD_UI_BIND_IP') is not None:
|
if os.getenv("SD_UI_BIND_IP") is not None:
|
||||||
config['net']['listen_to_network'] = (os.getenv('SD_UI_BIND_IP') == '0.0.0.0')
|
config["net"]["listen_to_network"] = os.getenv("SD_UI_BIND_IP") == "0.0.0.0"
|
||||||
else:
|
else:
|
||||||
config['net']['listen_to_network'] = True
|
config["net"]["listen_to_network"] = True
|
||||||
return config
|
return config
|
||||||
except Exception as e:
|
except Exception as e:
|
||||||
log.warn(traceback.format_exc())
|
log.warn(traceback.format_exc())
|
||||||
return default_val
|
return default_val
|
||||||
|
|
||||||
|
|
||||||
def setConfig(config):
|
def setConfig(config):
|
||||||
try: # config.json
|
try: # config.json
|
||||||
config_json_path = os.path.join(CONFIG_DIR, 'config.json')
|
config_json_path = os.path.join(CONFIG_DIR, "config.json")
|
||||||
with open(config_json_path, 'w', encoding='utf-8') as f:
|
with open(config_json_path, "w", encoding="utf-8") as f:
|
||||||
json.dump(config, f)
|
json.dump(config, f)
|
||||||
except:
|
except:
|
||||||
log.error(traceback.format_exc())
|
log.error(traceback.format_exc())
|
||||||
|
|
||||||
try: # config.bat
|
try: # config.bat
|
||||||
config_bat_path = os.path.join(CONFIG_DIR, 'config.bat')
|
config_bat_path = os.path.join(CONFIG_DIR, "config.bat")
|
||||||
config_bat = []
|
config_bat = []
|
||||||
|
|
||||||
if 'update_branch' in config:
|
if "update_branch" in config:
|
||||||
config_bat.append(f"@set update_branch={config['update_branch']}")
|
config_bat.append(f"@set update_branch={config['update_branch']}")
|
||||||
|
|
||||||
config_bat.append(f"@set SD_UI_BIND_PORT={config['net']['listen_port']}")
|
config_bat.append(f"@set SD_UI_BIND_PORT={config['net']['listen_port']}")
|
||||||
bind_ip = '0.0.0.0' if config['net']['listen_to_network'] else '127.0.0.1'
|
bind_ip = "0.0.0.0" if config["net"]["listen_to_network"] else "127.0.0.1"
|
||||||
config_bat.append(f"@set SD_UI_BIND_IP={bind_ip}")
|
config_bat.append(f"@set SD_UI_BIND_IP={bind_ip}")
|
||||||
|
|
||||||
# Preserve these variables if they are set
|
# Preserve these variables if they are set
|
||||||
@ -101,20 +104,20 @@ def setConfig(config):
|
|||||||
config_bat.append(f"@set {var}={os.getenv(var)}")
|
config_bat.append(f"@set {var}={os.getenv(var)}")
|
||||||
|
|
||||||
if len(config_bat) > 0:
|
if len(config_bat) > 0:
|
||||||
with open(config_bat_path, 'w', encoding='utf-8') as f:
|
with open(config_bat_path, "w", encoding="utf-8") as f:
|
||||||
f.write('\r\n'.join(config_bat))
|
f.write("\r\n".join(config_bat))
|
||||||
except:
|
except:
|
||||||
log.error(traceback.format_exc())
|
log.error(traceback.format_exc())
|
||||||
|
|
||||||
try: # config.sh
|
try: # config.sh
|
||||||
config_sh_path = os.path.join(CONFIG_DIR, 'config.sh')
|
config_sh_path = os.path.join(CONFIG_DIR, "config.sh")
|
||||||
config_sh = ['#!/bin/bash']
|
config_sh = ["#!/bin/bash"]
|
||||||
|
|
||||||
if 'update_branch' in config:
|
if "update_branch" in config:
|
||||||
config_sh.append(f"export update_branch={config['update_branch']}")
|
config_sh.append(f"export update_branch={config['update_branch']}")
|
||||||
|
|
||||||
config_sh.append(f"export SD_UI_BIND_PORT={config['net']['listen_port']}")
|
config_sh.append(f"export SD_UI_BIND_PORT={config['net']['listen_port']}")
|
||||||
bind_ip = '0.0.0.0' if config['net']['listen_to_network'] else '127.0.0.1'
|
bind_ip = "0.0.0.0" if config["net"]["listen_to_network"] else "127.0.0.1"
|
||||||
config_sh.append(f"export SD_UI_BIND_IP={bind_ip}")
|
config_sh.append(f"export SD_UI_BIND_IP={bind_ip}")
|
||||||
|
|
||||||
# Preserve these variables if they are set
|
# Preserve these variables if they are set
|
||||||
@ -123,47 +126,51 @@ def setConfig(config):
|
|||||||
config_bat.append(f'export {var}="{shlex.quote(os.getenv(var))}"')
|
config_bat.append(f'export {var}="{shlex.quote(os.getenv(var))}"')
|
||||||
|
|
||||||
if len(config_sh) > 1:
|
if len(config_sh) > 1:
|
||||||
with open(config_sh_path, 'w', encoding='utf-8') as f:
|
with open(config_sh_path, "w", encoding="utf-8") as f:
|
||||||
f.write('\n'.join(config_sh))
|
f.write("\n".join(config_sh))
|
||||||
except:
|
except:
|
||||||
log.error(traceback.format_exc())
|
log.error(traceback.format_exc())
|
||||||
|
|
||||||
|
|
||||||
def save_to_config(ckpt_model_name, vae_model_name, hypernetwork_model_name, vram_usage_level):
|
def save_to_config(ckpt_model_name, vae_model_name, hypernetwork_model_name, vram_usage_level):
|
||||||
config = getConfig()
|
config = getConfig()
|
||||||
if 'model' not in config:
|
if "model" not in config:
|
||||||
config['model'] = {}
|
config["model"] = {}
|
||||||
|
|
||||||
config['model']['stable-diffusion'] = ckpt_model_name
|
config["model"]["stable-diffusion"] = ckpt_model_name
|
||||||
config['model']['vae'] = vae_model_name
|
config["model"]["vae"] = vae_model_name
|
||||||
config['model']['hypernetwork'] = hypernetwork_model_name
|
config["model"]["hypernetwork"] = hypernetwork_model_name
|
||||||
|
|
||||||
if vae_model_name is None or vae_model_name == "":
|
if vae_model_name is None or vae_model_name == "":
|
||||||
del config['model']['vae']
|
del config["model"]["vae"]
|
||||||
if hypernetwork_model_name is None or hypernetwork_model_name == "":
|
if hypernetwork_model_name is None or hypernetwork_model_name == "":
|
||||||
del config['model']['hypernetwork']
|
del config["model"]["hypernetwork"]
|
||||||
|
|
||||||
config['vram_usage_level'] = vram_usage_level
|
config["vram_usage_level"] = vram_usage_level
|
||||||
|
|
||||||
setConfig(config)
|
setConfig(config)
|
||||||
|
|
||||||
|
|
||||||
def update_render_threads():
|
def update_render_threads():
|
||||||
config = getConfig()
|
config = getConfig()
|
||||||
render_devices = config.get('render_devices', 'auto')
|
render_devices = config.get("render_devices", "auto")
|
||||||
active_devices = task_manager.get_devices()['active'].keys()
|
active_devices = task_manager.get_devices()["active"].keys()
|
||||||
|
|
||||||
log.debug(f'requesting for render_devices: {render_devices}')
|
log.debug(f"requesting for render_devices: {render_devices}")
|
||||||
task_manager.update_render_threads(render_devices, active_devices)
|
task_manager.update_render_threads(render_devices, active_devices)
|
||||||
|
|
||||||
|
|
||||||
def getUIPlugins():
|
def getUIPlugins():
|
||||||
plugins = []
|
plugins = []
|
||||||
|
|
||||||
for plugins_dir, dir_prefix in UI_PLUGINS_SOURCES:
|
for plugins_dir, dir_prefix in UI_PLUGINS_SOURCES:
|
||||||
for file in os.listdir(plugins_dir):
|
for file in os.listdir(plugins_dir):
|
||||||
if file.endswith('.plugin.js'):
|
if file.endswith(".plugin.js"):
|
||||||
plugins.append(f'/plugins/{dir_prefix}/{file}')
|
plugins.append(f"/plugins/{dir_prefix}/{file}")
|
||||||
|
|
||||||
return plugins
|
return plugins
|
||||||
|
|
||||||
|
|
||||||
def getIPConfig():
|
def getIPConfig():
|
||||||
try:
|
try:
|
||||||
ips = socket.gethostbyname_ex(socket.gethostname())
|
ips = socket.gethostbyname_ex(socket.gethostname())
|
||||||
@ -173,10 +180,13 @@ def getIPConfig():
|
|||||||
log.exception(e)
|
log.exception(e)
|
||||||
return []
|
return []
|
||||||
|
|
||||||
|
|
||||||
def open_browser():
|
def open_browser():
|
||||||
config = getConfig()
|
config = getConfig()
|
||||||
ui = config.get('ui', {})
|
ui = config.get("ui", {})
|
||||||
net = config.get('net', {'listen_port':9000})
|
net = config.get("net", {"listen_port": 9000})
|
||||||
port = net.get('listen_port', 9000)
|
port = net.get("listen_port", 9000)
|
||||||
if ui.get('open_browser_on_start', True):
|
if ui.get("open_browser_on_start", True):
|
||||||
import webbrowser; webbrowser.open(f"http://localhost:{port}")
|
import webbrowser
|
||||||
|
|
||||||
|
webbrowser.open(f"http://localhost:{port}")
|
||||||
|
@ -5,45 +5,54 @@ import re
|
|||||||
|
|
||||||
from easydiffusion.utils import log
|
from easydiffusion.utils import log
|
||||||
|
|
||||||
'''
|
"""
|
||||||
Set `FORCE_FULL_PRECISION` in the environment variables, or in `config.bat`/`config.sh` to set full precision (i.e. float32).
|
Set `FORCE_FULL_PRECISION` in the environment variables, or in `config.bat`/`config.sh` to set full precision (i.e. float32).
|
||||||
Otherwise the models will load at half-precision (i.e. float16).
|
Otherwise the models will load at half-precision (i.e. float16).
|
||||||
|
|
||||||
Half-precision is fine most of the time. Full precision is only needed for working around GPU bugs (like NVIDIA 16xx GPUs).
|
Half-precision is fine most of the time. Full precision is only needed for working around GPU bugs (like NVIDIA 16xx GPUs).
|
||||||
'''
|
"""
|
||||||
|
|
||||||
COMPARABLE_GPU_PERCENTILE = 0.65 # if a GPU's free_mem is within this % of the GPU with the most free_mem, it will be picked
|
COMPARABLE_GPU_PERCENTILE = (
|
||||||
|
0.65 # if a GPU's free_mem is within this % of the GPU with the most free_mem, it will be picked
|
||||||
|
)
|
||||||
|
|
||||||
mem_free_threshold = 0
|
mem_free_threshold = 0
|
||||||
|
|
||||||
|
|
||||||
def get_device_delta(render_devices, active_devices):
|
def get_device_delta(render_devices, active_devices):
|
||||||
'''
|
"""
|
||||||
render_devices: 'cpu', or 'auto' or ['cuda:N'...]
|
render_devices: 'cpu', or 'auto' or ['cuda:N'...]
|
||||||
active_devices: ['cpu', 'cuda:N'...]
|
active_devices: ['cpu', 'cuda:N'...]
|
||||||
'''
|
"""
|
||||||
|
|
||||||
if render_devices in ('cpu', 'auto'):
|
if render_devices in ("cpu", "auto"):
|
||||||
render_devices = [render_devices]
|
render_devices = [render_devices]
|
||||||
elif render_devices is not None:
|
elif render_devices is not None:
|
||||||
if isinstance(render_devices, str):
|
if isinstance(render_devices, str):
|
||||||
render_devices = [render_devices]
|
render_devices = [render_devices]
|
||||||
if isinstance(render_devices, list) and len(render_devices) > 0:
|
if isinstance(render_devices, list) and len(render_devices) > 0:
|
||||||
render_devices = list(filter(lambda x: x.startswith('cuda:'), render_devices))
|
render_devices = list(filter(lambda x: x.startswith("cuda:"), render_devices))
|
||||||
if len(render_devices) == 0:
|
if len(render_devices) == 0:
|
||||||
raise Exception('Invalid render_devices value in config.json. Valid: {"render_devices": ["cuda:0", "cuda:1"...]}, or {"render_devices": "cpu"} or {"render_devices": "auto"}')
|
raise Exception(
|
||||||
|
'Invalid render_devices value in config.json. Valid: {"render_devices": ["cuda:0", "cuda:1"...]}, or {"render_devices": "cpu"} or {"render_devices": "auto"}'
|
||||||
|
)
|
||||||
|
|
||||||
render_devices = list(filter(lambda x: is_device_compatible(x), render_devices))
|
render_devices = list(filter(lambda x: is_device_compatible(x), render_devices))
|
||||||
if len(render_devices) == 0:
|
if len(render_devices) == 0:
|
||||||
raise Exception('Sorry, none of the render_devices configured in config.json are compatible with Stable Diffusion')
|
raise Exception(
|
||||||
|
"Sorry, none of the render_devices configured in config.json are compatible with Stable Diffusion"
|
||||||
|
)
|
||||||
else:
|
else:
|
||||||
raise Exception('Invalid render_devices value in config.json. Valid: {"render_devices": ["cuda:0", "cuda:1"...]}, or {"render_devices": "cpu"} or {"render_devices": "auto"}')
|
raise Exception(
|
||||||
|
'Invalid render_devices value in config.json. Valid: {"render_devices": ["cuda:0", "cuda:1"...]}, or {"render_devices": "cpu"} or {"render_devices": "auto"}'
|
||||||
|
)
|
||||||
else:
|
else:
|
||||||
render_devices = ['auto']
|
render_devices = ["auto"]
|
||||||
|
|
||||||
if 'auto' in render_devices:
|
if "auto" in render_devices:
|
||||||
render_devices = auto_pick_devices(active_devices)
|
render_devices = auto_pick_devices(active_devices)
|
||||||
if 'cpu' in render_devices:
|
if "cpu" in render_devices:
|
||||||
log.warn('WARNING: Could not find a compatible GPU. Using the CPU, but this will be very slow!')
|
log.warn("WARNING: Could not find a compatible GPU. Using the CPU, but this will be very slow!")
|
||||||
|
|
||||||
active_devices = set(active_devices)
|
active_devices = set(active_devices)
|
||||||
render_devices = set(render_devices)
|
render_devices = set(render_devices)
|
||||||
@ -53,19 +62,21 @@ def get_device_delta(render_devices, active_devices):
|
|||||||
|
|
||||||
return devices_to_start, devices_to_stop
|
return devices_to_start, devices_to_stop
|
||||||
|
|
||||||
|
|
||||||
def auto_pick_devices(currently_active_devices):
|
def auto_pick_devices(currently_active_devices):
|
||||||
global mem_free_threshold
|
global mem_free_threshold
|
||||||
|
|
||||||
if not torch.cuda.is_available(): return ['cpu']
|
if not torch.cuda.is_available():
|
||||||
|
return ["cpu"]
|
||||||
|
|
||||||
device_count = torch.cuda.device_count()
|
device_count = torch.cuda.device_count()
|
||||||
if device_count == 1:
|
if device_count == 1:
|
||||||
return ['cuda:0'] if is_device_compatible('cuda:0') else ['cpu']
|
return ["cuda:0"] if is_device_compatible("cuda:0") else ["cpu"]
|
||||||
|
|
||||||
log.debug('Autoselecting GPU. Using most free memory.')
|
log.debug("Autoselecting GPU. Using most free memory.")
|
||||||
devices = []
|
devices = []
|
||||||
for device in range(device_count):
|
for device in range(device_count):
|
||||||
device = f'cuda:{device}'
|
device = f"cuda:{device}"
|
||||||
if not is_device_compatible(device):
|
if not is_device_compatible(device):
|
||||||
continue
|
continue
|
||||||
|
|
||||||
@ -73,11 +84,13 @@ def auto_pick_devices(currently_active_devices):
|
|||||||
mem_free /= float(10**9)
|
mem_free /= float(10**9)
|
||||||
mem_total /= float(10**9)
|
mem_total /= float(10**9)
|
||||||
device_name = torch.cuda.get_device_name(device)
|
device_name = torch.cuda.get_device_name(device)
|
||||||
log.debug(f'{device} detected: {device_name} - Memory (free/total): {round(mem_free, 2)}Gb / {round(mem_total, 2)}Gb')
|
log.debug(
|
||||||
devices.append({'device': device, 'device_name': device_name, 'mem_free': mem_free})
|
f"{device} detected: {device_name} - Memory (free/total): {round(mem_free, 2)}Gb / {round(mem_total, 2)}Gb"
|
||||||
|
)
|
||||||
|
devices.append({"device": device, "device_name": device_name, "mem_free": mem_free})
|
||||||
|
|
||||||
devices.sort(key=lambda x:x['mem_free'], reverse=True)
|
devices.sort(key=lambda x: x["mem_free"], reverse=True)
|
||||||
max_mem_free = devices[0]['mem_free']
|
max_mem_free = devices[0]["mem_free"]
|
||||||
curr_mem_free_threshold = COMPARABLE_GPU_PERCENTILE * max_mem_free
|
curr_mem_free_threshold = COMPARABLE_GPU_PERCENTILE * max_mem_free
|
||||||
mem_free_threshold = max(curr_mem_free_threshold, mem_free_threshold)
|
mem_free_threshold = max(curr_mem_free_threshold, mem_free_threshold)
|
||||||
|
|
||||||
@ -87,23 +100,26 @@ def auto_pick_devices(currently_active_devices):
|
|||||||
# always be very low (since their VRAM contains the model).
|
# always be very low (since their VRAM contains the model).
|
||||||
# These already-running devices probably aren't terrible, since they were picked in the past.
|
# These already-running devices probably aren't terrible, since they were picked in the past.
|
||||||
# Worst case, the user can restart the program and that'll get rid of them.
|
# Worst case, the user can restart the program and that'll get rid of them.
|
||||||
devices = list(filter((lambda x: x['mem_free'] > mem_free_threshold or x['device'] in currently_active_devices), devices))
|
devices = list(
|
||||||
devices = list(map(lambda x: x['device'], devices))
|
filter((lambda x: x["mem_free"] > mem_free_threshold or x["device"] in currently_active_devices), devices)
|
||||||
|
)
|
||||||
|
devices = list(map(lambda x: x["device"], devices))
|
||||||
return devices
|
return devices
|
||||||
|
|
||||||
|
|
||||||
def device_init(context, device):
|
def device_init(context, device):
|
||||||
'''
|
"""
|
||||||
This function assumes the 'device' has already been verified to be compatible.
|
This function assumes the 'device' has already been verified to be compatible.
|
||||||
`get_device_delta()` has already filtered out incompatible devices.
|
`get_device_delta()` has already filtered out incompatible devices.
|
||||||
'''
|
"""
|
||||||
|
|
||||||
validate_device_id(device, log_prefix='device_init')
|
validate_device_id(device, log_prefix="device_init")
|
||||||
|
|
||||||
if device == 'cpu':
|
if device == "cpu":
|
||||||
context.device = 'cpu'
|
context.device = "cpu"
|
||||||
context.device_name = get_processor_name()
|
context.device_name = get_processor_name()
|
||||||
context.half_precision = False
|
context.half_precision = False
|
||||||
log.debug(f'Render device CPU available as {context.device_name}')
|
log.debug(f"Render device CPU available as {context.device_name}")
|
||||||
return
|
return
|
||||||
|
|
||||||
context.device_name = torch.cuda.get_device_name(device)
|
context.device_name = torch.cuda.get_device_name(device)
|
||||||
@ -111,7 +127,7 @@ def device_init(context, device):
|
|||||||
|
|
||||||
# Force full precision on 1660 and 1650 NVIDIA cards to avoid creating green images
|
# Force full precision on 1660 and 1650 NVIDIA cards to avoid creating green images
|
||||||
if needs_to_force_full_precision(context):
|
if needs_to_force_full_precision(context):
|
||||||
log.warn(f'forcing full precision on this GPU, to avoid green images. GPU detected: {context.device_name}')
|
log.warn(f"forcing full precision on this GPU, to avoid green images. GPU detected: {context.device_name}")
|
||||||
# Apply force_full_precision now before models are loaded.
|
# Apply force_full_precision now before models are loaded.
|
||||||
context.half_precision = False
|
context.half_precision = False
|
||||||
|
|
||||||
@ -120,72 +136,90 @@ def device_init(context, device):
|
|||||||
|
|
||||||
return
|
return
|
||||||
|
|
||||||
|
|
||||||
def needs_to_force_full_precision(context):
|
def needs_to_force_full_precision(context):
|
||||||
if 'FORCE_FULL_PRECISION' in os.environ:
|
if "FORCE_FULL_PRECISION" in os.environ:
|
||||||
return True
|
return True
|
||||||
|
|
||||||
device_name = context.device_name.lower()
|
device_name = context.device_name.lower()
|
||||||
return (('nvidia' in device_name or 'geforce' in device_name or 'quadro' in device_name) and (' 1660' in device_name or ' 1650' in device_name or ' t400' in device_name or ' t550' in device_name or ' t600' in device_name or ' t1000' in device_name or ' t1200' in device_name or ' t2000' in device_name))
|
return ("nvidia" in device_name or "geforce" in device_name or "quadro" in device_name) and (
|
||||||
|
" 1660" in device_name
|
||||||
|
or " 1650" in device_name
|
||||||
|
or " t400" in device_name
|
||||||
|
or " t550" in device_name
|
||||||
|
or " t600" in device_name
|
||||||
|
or " t1000" in device_name
|
||||||
|
or " t1200" in device_name
|
||||||
|
or " t2000" in device_name
|
||||||
|
)
|
||||||
|
|
||||||
|
|
||||||
def get_max_vram_usage_level(device):
|
def get_max_vram_usage_level(device):
|
||||||
if device != 'cpu':
|
if device != "cpu":
|
||||||
_, mem_total = torch.cuda.mem_get_info(device)
|
_, mem_total = torch.cuda.mem_get_info(device)
|
||||||
mem_total /= float(10**9)
|
mem_total /= float(10**9)
|
||||||
|
|
||||||
if mem_total < 4.5:
|
if mem_total < 4.5:
|
||||||
return 'low'
|
return "low"
|
||||||
elif mem_total < 6.5:
|
elif mem_total < 6.5:
|
||||||
return 'balanced'
|
return "balanced"
|
||||||
|
|
||||||
return 'high'
|
return "high"
|
||||||
|
|
||||||
def validate_device_id(device, log_prefix=''):
|
|
||||||
|
def validate_device_id(device, log_prefix=""):
|
||||||
def is_valid():
|
def is_valid():
|
||||||
if not isinstance(device, str):
|
if not isinstance(device, str):
|
||||||
return False
|
return False
|
||||||
if device == 'cpu':
|
if device == "cpu":
|
||||||
return True
|
return True
|
||||||
if not device.startswith('cuda:') or not device[5:].isnumeric():
|
if not device.startswith("cuda:") or not device[5:].isnumeric():
|
||||||
return False
|
return False
|
||||||
return True
|
return True
|
||||||
|
|
||||||
if not is_valid():
|
if not is_valid():
|
||||||
raise EnvironmentError(f"{log_prefix}: device id should be 'cpu', or 'cuda:N' (where N is an integer index for the GPU). Got: {device}")
|
raise EnvironmentError(
|
||||||
|
f"{log_prefix}: device id should be 'cpu', or 'cuda:N' (where N is an integer index for the GPU). Got: {device}"
|
||||||
|
)
|
||||||
|
|
||||||
|
|
||||||
def is_device_compatible(device):
|
def is_device_compatible(device):
|
||||||
'''
|
"""
|
||||||
Returns True/False, and prints any compatibility errors
|
Returns True/False, and prints any compatibility errors
|
||||||
'''
|
"""
|
||||||
# static variable "history".
|
# static variable "history".
|
||||||
is_device_compatible.history = getattr(is_device_compatible, 'history', {})
|
is_device_compatible.history = getattr(is_device_compatible, "history", {})
|
||||||
try:
|
try:
|
||||||
validate_device_id(device, log_prefix='is_device_compatible')
|
validate_device_id(device, log_prefix="is_device_compatible")
|
||||||
except:
|
except:
|
||||||
log.error(str(e))
|
log.error(str(e))
|
||||||
return False
|
return False
|
||||||
|
|
||||||
if device == 'cpu': return True
|
if device == "cpu":
|
||||||
|
return True
|
||||||
# Memory check
|
# Memory check
|
||||||
try:
|
try:
|
||||||
_, mem_total = torch.cuda.mem_get_info(device)
|
_, mem_total = torch.cuda.mem_get_info(device)
|
||||||
mem_total /= float(10**9)
|
mem_total /= float(10**9)
|
||||||
if mem_total < 3.0:
|
if mem_total < 3.0:
|
||||||
if is_device_compatible.history.get(device) == None:
|
if is_device_compatible.history.get(device) == None:
|
||||||
log.warn(f'GPU {device} with less than 3 GB of VRAM is not compatible with Stable Diffusion')
|
log.warn(f"GPU {device} with less than 3 GB of VRAM is not compatible with Stable Diffusion")
|
||||||
is_device_compatible.history[device] = 1
|
is_device_compatible.history[device] = 1
|
||||||
return False
|
return False
|
||||||
except RuntimeError as e:
|
except RuntimeError as e:
|
||||||
log.error(str(e))
|
log.error(str(e))
|
||||||
return False
|
return False
|
||||||
return True
|
return True
|
||||||
|
|
||||||
|
|
||||||
def get_processor_name():
|
def get_processor_name():
|
||||||
try:
|
try:
|
||||||
import platform, subprocess
|
import platform, subprocess
|
||||||
|
|
||||||
if platform.system() == "Windows":
|
if platform.system() == "Windows":
|
||||||
return platform.processor()
|
return platform.processor()
|
||||||
elif platform.system() == "Darwin":
|
elif platform.system() == "Darwin":
|
||||||
os.environ['PATH'] = os.environ['PATH'] + os.pathsep + '/usr/sbin'
|
os.environ["PATH"] = os.environ["PATH"] + os.pathsep + "/usr/sbin"
|
||||||
command = "sysctl -n machdep.cpu.brand_string"
|
command = "sysctl -n machdep.cpu.brand_string"
|
||||||
return subprocess.check_output(command).strip()
|
return subprocess.check_output(command).strip()
|
||||||
elif platform.system() == "Linux":
|
elif platform.system() == "Linux":
|
||||||
|
@ -8,29 +8,31 @@ from sdkit import Context
|
|||||||
from sdkit.models import load_model, unload_model, get_model_info_from_db, scan_model
|
from sdkit.models import load_model, unload_model, get_model_info_from_db, scan_model
|
||||||
from sdkit.utils import hash_file_quick
|
from sdkit.utils import hash_file_quick
|
||||||
|
|
||||||
KNOWN_MODEL_TYPES = ['stable-diffusion', 'vae', 'hypernetwork', 'gfpgan', 'realesrgan']
|
KNOWN_MODEL_TYPES = ["stable-diffusion", "vae", "hypernetwork", "gfpgan", "realesrgan"]
|
||||||
MODEL_EXTENSIONS = {
|
MODEL_EXTENSIONS = {
|
||||||
'stable-diffusion': ['.ckpt', '.safetensors'],
|
"stable-diffusion": [".ckpt", ".safetensors"],
|
||||||
'vae': ['.vae.pt', '.ckpt', '.safetensors'],
|
"vae": [".vae.pt", ".ckpt", ".safetensors"],
|
||||||
'hypernetwork': ['.pt', '.safetensors'],
|
"hypernetwork": [".pt", ".safetensors"],
|
||||||
'gfpgan': ['.pth'],
|
"gfpgan": [".pth"],
|
||||||
'realesrgan': ['.pth'],
|
"realesrgan": [".pth"],
|
||||||
}
|
}
|
||||||
DEFAULT_MODELS = {
|
DEFAULT_MODELS = {
|
||||||
'stable-diffusion': [ # needed to support the legacy installations
|
"stable-diffusion": [ # needed to support the legacy installations
|
||||||
'custom-model', # only one custom model file was supported initially, creatively named 'custom-model'
|
"custom-model", # only one custom model file was supported initially, creatively named 'custom-model'
|
||||||
'sd-v1-4', # Default fallback.
|
"sd-v1-4", # Default fallback.
|
||||||
],
|
],
|
||||||
'gfpgan': ['GFPGANv1.3'],
|
"gfpgan": ["GFPGANv1.3"],
|
||||||
'realesrgan': ['RealESRGAN_x4plus'],
|
"realesrgan": ["RealESRGAN_x4plus"],
|
||||||
}
|
}
|
||||||
MODELS_TO_LOAD_ON_START = ['stable-diffusion', 'vae', 'hypernetwork']
|
MODELS_TO_LOAD_ON_START = ["stable-diffusion", "vae", "hypernetwork"]
|
||||||
|
|
||||||
known_models = {}
|
known_models = {}
|
||||||
|
|
||||||
|
|
||||||
def init():
|
def init():
|
||||||
make_model_folders()
|
make_model_folders()
|
||||||
getModels() # run this once, to cache the picklescan results
|
getModels() # run this once, to cache the picklescan results
|
||||||
|
|
||||||
|
|
||||||
def load_default_models(context: Context):
|
def load_default_models(context: Context):
|
||||||
set_vram_optimizations(context)
|
set_vram_optimizations(context)
|
||||||
@ -39,27 +41,28 @@ def load_default_models(context: Context):
|
|||||||
for model_type in MODELS_TO_LOAD_ON_START:
|
for model_type in MODELS_TO_LOAD_ON_START:
|
||||||
context.model_paths[model_type] = resolve_model_to_use(model_type=model_type)
|
context.model_paths[model_type] = resolve_model_to_use(model_type=model_type)
|
||||||
try:
|
try:
|
||||||
load_model(context, model_type)
|
load_model(context, model_type)
|
||||||
except Exception as e:
|
except Exception as e:
|
||||||
log.error(f'[red]Error while loading {model_type} model: {context.model_paths[model_type]}[/red]')
|
log.error(f"[red]Error while loading {model_type} model: {context.model_paths[model_type]}[/red]")
|
||||||
log.error(f'[red]Error: {e}[/red]')
|
log.error(f"[red]Error: {e}[/red]")
|
||||||
log.error(f'[red]Consider removing the model from the model folder.[red]')
|
log.error(f"[red]Consider removing the model from the model folder.[red]")
|
||||||
|
|
||||||
|
|
||||||
def unload_all(context: Context):
|
def unload_all(context: Context):
|
||||||
for model_type in KNOWN_MODEL_TYPES:
|
for model_type in KNOWN_MODEL_TYPES:
|
||||||
unload_model(context, model_type)
|
unload_model(context, model_type)
|
||||||
|
|
||||||
def resolve_model_to_use(model_name:str=None, model_type:str=None):
|
|
||||||
|
def resolve_model_to_use(model_name: str = None, model_type: str = None):
|
||||||
model_extensions = MODEL_EXTENSIONS.get(model_type, [])
|
model_extensions = MODEL_EXTENSIONS.get(model_type, [])
|
||||||
default_models = DEFAULT_MODELS.get(model_type, [])
|
default_models = DEFAULT_MODELS.get(model_type, [])
|
||||||
config = app.getConfig()
|
config = app.getConfig()
|
||||||
|
|
||||||
model_dirs = [os.path.join(app.MODELS_DIR, model_type), app.SD_DIR]
|
model_dirs = [os.path.join(app.MODELS_DIR, model_type), app.SD_DIR]
|
||||||
if not model_name: # When None try user configured model.
|
if not model_name: # When None try user configured model.
|
||||||
# config = getConfig()
|
# config = getConfig()
|
||||||
if 'model' in config and model_type in config['model']:
|
if "model" in config and model_type in config["model"]:
|
||||||
model_name = config['model'][model_type]
|
model_name = config["model"][model_type]
|
||||||
|
|
||||||
if model_name:
|
if model_name:
|
||||||
# Check models directory
|
# Check models directory
|
||||||
@ -84,41 +87,54 @@ def resolve_model_to_use(model_name:str=None, model_type:str=None):
|
|||||||
for model_extension in model_extensions:
|
for model_extension in model_extensions:
|
||||||
if os.path.exists(default_model_path + model_extension):
|
if os.path.exists(default_model_path + model_extension):
|
||||||
if model_name is not None:
|
if model_name is not None:
|
||||||
log.warn(f'Could not find the configured custom model {model_name}{model_extension}. Using the default one: {default_model_path}{model_extension}')
|
log.warn(
|
||||||
|
f"Could not find the configured custom model {model_name}{model_extension}. Using the default one: {default_model_path}{model_extension}"
|
||||||
|
)
|
||||||
return default_model_path + model_extension
|
return default_model_path + model_extension
|
||||||
|
|
||||||
return None
|
return None
|
||||||
|
|
||||||
|
|
||||||
def reload_models_if_necessary(context: Context, task_data: TaskData):
|
def reload_models_if_necessary(context: Context, task_data: TaskData):
|
||||||
model_paths_in_req = {
|
model_paths_in_req = {
|
||||||
'stable-diffusion': task_data.use_stable_diffusion_model,
|
"stable-diffusion": task_data.use_stable_diffusion_model,
|
||||||
'vae': task_data.use_vae_model,
|
"vae": task_data.use_vae_model,
|
||||||
'hypernetwork': task_data.use_hypernetwork_model,
|
"hypernetwork": task_data.use_hypernetwork_model,
|
||||||
'gfpgan': task_data.use_face_correction,
|
"gfpgan": task_data.use_face_correction,
|
||||||
'realesrgan': task_data.use_upscale,
|
"realesrgan": task_data.use_upscale,
|
||||||
|
}
|
||||||
|
models_to_reload = {
|
||||||
|
model_type: path
|
||||||
|
for model_type, path in model_paths_in_req.items()
|
||||||
|
if context.model_paths.get(model_type) != path
|
||||||
}
|
}
|
||||||
models_to_reload = {model_type: path for model_type, path in model_paths_in_req.items() if context.model_paths.get(model_type) != path}
|
|
||||||
|
|
||||||
if set_vram_optimizations(context): # reload SD
|
if set_vram_optimizations(context): # reload SD
|
||||||
models_to_reload['stable-diffusion'] = model_paths_in_req['stable-diffusion']
|
models_to_reload["stable-diffusion"] = model_paths_in_req["stable-diffusion"]
|
||||||
|
|
||||||
for model_type, model_path_in_req in models_to_reload.items():
|
for model_type, model_path_in_req in models_to_reload.items():
|
||||||
context.model_paths[model_type] = model_path_in_req
|
context.model_paths[model_type] = model_path_in_req
|
||||||
|
|
||||||
action_fn = unload_model if context.model_paths[model_type] is None else load_model
|
action_fn = unload_model if context.model_paths[model_type] is None else load_model
|
||||||
action_fn(context, model_type, scan_model=False) # we've scanned them already
|
action_fn(context, model_type, scan_model=False) # we've scanned them already
|
||||||
|
|
||||||
|
|
||||||
def resolve_model_paths(task_data: TaskData):
|
def resolve_model_paths(task_data: TaskData):
|
||||||
task_data.use_stable_diffusion_model = resolve_model_to_use(task_data.use_stable_diffusion_model, model_type='stable-diffusion')
|
task_data.use_stable_diffusion_model = resolve_model_to_use(
|
||||||
task_data.use_vae_model = resolve_model_to_use(task_data.use_vae_model, model_type='vae')
|
task_data.use_stable_diffusion_model, model_type="stable-diffusion"
|
||||||
task_data.use_hypernetwork_model = resolve_model_to_use(task_data.use_hypernetwork_model, model_type='hypernetwork')
|
)
|
||||||
|
task_data.use_vae_model = resolve_model_to_use(task_data.use_vae_model, model_type="vae")
|
||||||
|
task_data.use_hypernetwork_model = resolve_model_to_use(task_data.use_hypernetwork_model, model_type="hypernetwork")
|
||||||
|
|
||||||
|
if task_data.use_face_correction:
|
||||||
|
task_data.use_face_correction = resolve_model_to_use(task_data.use_face_correction, "gfpgan")
|
||||||
|
if task_data.use_upscale:
|
||||||
|
task_data.use_upscale = resolve_model_to_use(task_data.use_upscale, "realesrgan")
|
||||||
|
|
||||||
if task_data.use_face_correction: task_data.use_face_correction = resolve_model_to_use(task_data.use_face_correction, 'gfpgan')
|
|
||||||
if task_data.use_upscale: task_data.use_upscale = resolve_model_to_use(task_data.use_upscale, 'realesrgan')
|
|
||||||
|
|
||||||
def set_vram_optimizations(context: Context):
|
def set_vram_optimizations(context: Context):
|
||||||
config = app.getConfig()
|
config = app.getConfig()
|
||||||
vram_usage_level = config.get('vram_usage_level', 'balanced')
|
vram_usage_level = config.get("vram_usage_level", "balanced")
|
||||||
|
|
||||||
if vram_usage_level != context.vram_usage_level:
|
if vram_usage_level != context.vram_usage_level:
|
||||||
context.vram_usage_level = vram_usage_level
|
context.vram_usage_level = vram_usage_level
|
||||||
@ -126,42 +142,51 @@ def set_vram_optimizations(context: Context):
|
|||||||
|
|
||||||
return False
|
return False
|
||||||
|
|
||||||
|
|
||||||
def make_model_folders():
|
def make_model_folders():
|
||||||
for model_type in KNOWN_MODEL_TYPES:
|
for model_type in KNOWN_MODEL_TYPES:
|
||||||
model_dir_path = os.path.join(app.MODELS_DIR, model_type)
|
model_dir_path = os.path.join(app.MODELS_DIR, model_type)
|
||||||
|
|
||||||
os.makedirs(model_dir_path, exist_ok=True)
|
os.makedirs(model_dir_path, exist_ok=True)
|
||||||
|
|
||||||
help_file_name = f'Place your {model_type} model files here.txt'
|
help_file_name = f"Place your {model_type} model files here.txt"
|
||||||
help_file_contents = f'Supported extensions: {" or ".join(MODEL_EXTENSIONS.get(model_type))}'
|
help_file_contents = f'Supported extensions: {" or ".join(MODEL_EXTENSIONS.get(model_type))}'
|
||||||
|
|
||||||
with open(os.path.join(model_dir_path, help_file_name), 'w', encoding='utf-8') as f:
|
with open(os.path.join(model_dir_path, help_file_name), "w", encoding="utf-8") as f:
|
||||||
f.write(help_file_contents)
|
f.write(help_file_contents)
|
||||||
|
|
||||||
|
|
||||||
def is_malicious_model(file_path):
|
def is_malicious_model(file_path):
|
||||||
try:
|
try:
|
||||||
scan_result = scan_model(file_path)
|
scan_result = scan_model(file_path)
|
||||||
if scan_result.issues_count > 0 or scan_result.infected_files > 0:
|
if scan_result.issues_count > 0 or scan_result.infected_files > 0:
|
||||||
log.warn(":warning: [bold red]Scan %s: %d scanned, %d issue, %d infected.[/bold red]" % (file_path, scan_result.scanned_files, scan_result.issues_count, scan_result.infected_files))
|
log.warn(
|
||||||
|
":warning: [bold red]Scan %s: %d scanned, %d issue, %d infected.[/bold red]"
|
||||||
|
% (file_path, scan_result.scanned_files, scan_result.issues_count, scan_result.infected_files)
|
||||||
|
)
|
||||||
return True
|
return True
|
||||||
else:
|
else:
|
||||||
log.debug("Scan %s: [green]%d scanned, %d issue, %d infected.[/green]" % (file_path, scan_result.scanned_files, scan_result.issues_count, scan_result.infected_files))
|
log.debug(
|
||||||
|
"Scan %s: [green]%d scanned, %d issue, %d infected.[/green]"
|
||||||
|
% (file_path, scan_result.scanned_files, scan_result.issues_count, scan_result.infected_files)
|
||||||
|
)
|
||||||
return False
|
return False
|
||||||
except Exception as e:
|
except Exception as e:
|
||||||
log.error(f'error while scanning: {file_path}, error: {e}')
|
log.error(f"error while scanning: {file_path}, error: {e}")
|
||||||
return False
|
return False
|
||||||
|
|
||||||
|
|
||||||
def getModels():
|
def getModels():
|
||||||
models = {
|
models = {
|
||||||
'active': {
|
"active": {
|
||||||
'stable-diffusion': 'sd-v1-4',
|
"stable-diffusion": "sd-v1-4",
|
||||||
'vae': '',
|
"vae": "",
|
||||||
'hypernetwork': '',
|
"hypernetwork": "",
|
||||||
},
|
},
|
||||||
'options': {
|
"options": {
|
||||||
'stable-diffusion': ['sd-v1-4'],
|
"stable-diffusion": ["sd-v1-4"],
|
||||||
'vae': [],
|
"vae": [],
|
||||||
'hypernetwork': [],
|
"hypernetwork": [],
|
||||||
},
|
},
|
||||||
}
|
}
|
||||||
|
|
||||||
@ -171,13 +196,16 @@ def getModels():
|
|||||||
"Raised when picklescan reports a problem with a model"
|
"Raised when picklescan reports a problem with a model"
|
||||||
pass
|
pass
|
||||||
|
|
||||||
def scan_directory(directory, suffixes, directoriesFirst:bool=True):
|
def scan_directory(directory, suffixes, directoriesFirst: bool = True):
|
||||||
nonlocal models_scanned
|
nonlocal models_scanned
|
||||||
tree = []
|
tree = []
|
||||||
for entry in sorted(os.scandir(directory), key = lambda entry: (entry.is_file() == directoriesFirst, entry.name.lower())):
|
for entry in sorted(
|
||||||
|
os.scandir(directory), key=lambda entry: (entry.is_file() == directoriesFirst, entry.name.lower())
|
||||||
|
):
|
||||||
if entry.is_file():
|
if entry.is_file():
|
||||||
matching_suffix = list(filter(lambda s: entry.name.endswith(s), suffixes))
|
matching_suffix = list(filter(lambda s: entry.name.endswith(s), suffixes))
|
||||||
if len(matching_suffix) == 0: continue
|
if len(matching_suffix) == 0:
|
||||||
|
continue
|
||||||
matching_suffix = matching_suffix[0]
|
matching_suffix = matching_suffix[0]
|
||||||
|
|
||||||
mtime = entry.stat().st_mtime
|
mtime = entry.stat().st_mtime
|
||||||
@ -187,12 +215,12 @@ def getModels():
|
|||||||
if is_malicious_model(entry.path):
|
if is_malicious_model(entry.path):
|
||||||
raise MaliciousModelException(entry.path)
|
raise MaliciousModelException(entry.path)
|
||||||
known_models[entry.path] = mtime
|
known_models[entry.path] = mtime
|
||||||
tree.append(entry.name[:-len(matching_suffix)])
|
tree.append(entry.name[: -len(matching_suffix)])
|
||||||
elif entry.is_dir():
|
elif entry.is_dir():
|
||||||
scan=scan_directory(entry.path, suffixes, directoriesFirst=False)
|
scan = scan_directory(entry.path, suffixes, directoriesFirst=False)
|
||||||
|
|
||||||
if len(scan) != 0:
|
if len(scan) != 0:
|
||||||
tree.append( (entry.name, scan ) )
|
tree.append((entry.name, scan))
|
||||||
return tree
|
return tree
|
||||||
|
|
||||||
def listModels(model_type):
|
def listModels(model_type):
|
||||||
@ -204,21 +232,22 @@ def getModels():
|
|||||||
os.makedirs(models_dir)
|
os.makedirs(models_dir)
|
||||||
|
|
||||||
try:
|
try:
|
||||||
models['options'][model_type] = scan_directory(models_dir, model_extensions)
|
models["options"][model_type] = scan_directory(models_dir, model_extensions)
|
||||||
except MaliciousModelException as e:
|
except MaliciousModelException as e:
|
||||||
models['scan-error'] = e
|
models["scan-error"] = e
|
||||||
|
|
||||||
# custom models
|
# custom models
|
||||||
listModels(model_type='stable-diffusion')
|
listModels(model_type="stable-diffusion")
|
||||||
listModels(model_type='vae')
|
listModels(model_type="vae")
|
||||||
listModels(model_type='hypernetwork')
|
listModels(model_type="hypernetwork")
|
||||||
listModels(model_type='gfpgan')
|
listModels(model_type="gfpgan")
|
||||||
|
|
||||||
if models_scanned > 0: log.info(f'[green]Scanned {models_scanned} models. Nothing infected[/]')
|
if models_scanned > 0:
|
||||||
|
log.info(f"[green]Scanned {models_scanned} models. Nothing infected[/]")
|
||||||
|
|
||||||
# legacy
|
# legacy
|
||||||
custom_weight_path = os.path.join(app.SD_DIR, 'custom-model.ckpt')
|
custom_weight_path = os.path.join(app.SD_DIR, "custom-model.ckpt")
|
||||||
if os.path.exists(custom_weight_path):
|
if os.path.exists(custom_weight_path):
|
||||||
models['options']['stable-diffusion'].append('custom-model')
|
models["options"]["stable-diffusion"].append("custom-model")
|
||||||
|
|
||||||
return models
|
return models
|
||||||
|
@ -12,22 +12,26 @@ from sdkit.generate import generate_images
|
|||||||
from sdkit.filter import apply_filters
|
from sdkit.filter import apply_filters
|
||||||
from sdkit.utils import img_to_buffer, img_to_base64_str, latent_samples_to_images, gc
|
from sdkit.utils import img_to_buffer, img_to_base64_str, latent_samples_to_images, gc
|
||||||
|
|
||||||
context = Context() # thread-local
|
context = Context() # thread-local
|
||||||
'''
|
"""
|
||||||
runtime data (bound locally to this thread), for e.g. device, references to loaded models, optimization flags etc
|
runtime data (bound locally to this thread), for e.g. device, references to loaded models, optimization flags etc
|
||||||
'''
|
"""
|
||||||
|
|
||||||
|
|
||||||
def init(device):
|
def init(device):
|
||||||
'''
|
"""
|
||||||
Initializes the fields that will be bound to this runtime's context, and sets the current torch device
|
Initializes the fields that will be bound to this runtime's context, and sets the current torch device
|
||||||
'''
|
"""
|
||||||
context.stop_processing = False
|
context.stop_processing = False
|
||||||
context.temp_images = {}
|
context.temp_images = {}
|
||||||
context.partial_x_samples = None
|
context.partial_x_samples = None
|
||||||
|
|
||||||
device_manager.device_init(context, device)
|
device_manager.device_init(context, device)
|
||||||
|
|
||||||
def make_images(req: GenerateImageRequest, task_data: TaskData, data_queue: queue.Queue, task_temp_images: list, step_callback):
|
|
||||||
|
def make_images(
|
||||||
|
req: GenerateImageRequest, task_data: TaskData, data_queue: queue.Queue, task_temp_images: list, step_callback
|
||||||
|
):
|
||||||
context.stop_processing = False
|
context.stop_processing = False
|
||||||
print_task_info(req, task_data)
|
print_task_info(req, task_data)
|
||||||
|
|
||||||
@ -36,18 +40,24 @@ def make_images(req: GenerateImageRequest, task_data: TaskData, data_queue: queu
|
|||||||
res = Response(req, task_data, images=construct_response(images, seeds, task_data, base_seed=req.seed))
|
res = Response(req, task_data, images=construct_response(images, seeds, task_data, base_seed=req.seed))
|
||||||
res = res.json()
|
res = res.json()
|
||||||
data_queue.put(json.dumps(res))
|
data_queue.put(json.dumps(res))
|
||||||
log.info('Task completed')
|
log.info("Task completed")
|
||||||
|
|
||||||
return res
|
return res
|
||||||
|
|
||||||
def print_task_info(req: GenerateImageRequest, task_data: TaskData):
|
|
||||||
req_str = pprint.pformat(get_printable_request(req)).replace("[","\[")
|
|
||||||
task_str = pprint.pformat(task_data.dict()).replace("[","\[")
|
|
||||||
log.info(f'request: {req_str}')
|
|
||||||
log.info(f'task data: {task_str}')
|
|
||||||
|
|
||||||
def make_images_internal(req: GenerateImageRequest, task_data: TaskData, data_queue: queue.Queue, task_temp_images: list, step_callback):
|
def print_task_info(req: GenerateImageRequest, task_data: TaskData):
|
||||||
images, user_stopped = generate_images_internal(req, task_data, data_queue, task_temp_images, step_callback, task_data.stream_image_progress)
|
req_str = pprint.pformat(get_printable_request(req)).replace("[", "\[")
|
||||||
|
task_str = pprint.pformat(task_data.dict()).replace("[", "\[")
|
||||||
|
log.info(f"request: {req_str}")
|
||||||
|
log.info(f"task data: {task_str}")
|
||||||
|
|
||||||
|
|
||||||
|
def make_images_internal(
|
||||||
|
req: GenerateImageRequest, task_data: TaskData, data_queue: queue.Queue, task_temp_images: list, step_callback
|
||||||
|
):
|
||||||
|
images, user_stopped = generate_images_internal(
|
||||||
|
req, task_data, data_queue, task_temp_images, step_callback, task_data.stream_image_progress
|
||||||
|
)
|
||||||
filtered_images = filter_images(task_data, images, user_stopped)
|
filtered_images = filter_images(task_data, images, user_stopped)
|
||||||
|
|
||||||
if task_data.save_to_disk_path is not None:
|
if task_data.save_to_disk_path is not None:
|
||||||
@ -59,13 +69,22 @@ def make_images_internal(req: GenerateImageRequest, task_data: TaskData, data_qu
|
|||||||
else:
|
else:
|
||||||
return images + filtered_images, seeds + seeds
|
return images + filtered_images, seeds + seeds
|
||||||
|
|
||||||
def generate_images_internal(req: GenerateImageRequest, task_data: TaskData, data_queue: queue.Queue, task_temp_images: list, step_callback, stream_image_progress: bool):
|
|
||||||
|
def generate_images_internal(
|
||||||
|
req: GenerateImageRequest,
|
||||||
|
task_data: TaskData,
|
||||||
|
data_queue: queue.Queue,
|
||||||
|
task_temp_images: list,
|
||||||
|
step_callback,
|
||||||
|
stream_image_progress: bool,
|
||||||
|
):
|
||||||
context.temp_images.clear()
|
context.temp_images.clear()
|
||||||
|
|
||||||
callback = make_step_callback(req, task_data, data_queue, task_temp_images, step_callback, stream_image_progress)
|
callback = make_step_callback(req, task_data, data_queue, task_temp_images, step_callback, stream_image_progress)
|
||||||
|
|
||||||
try:
|
try:
|
||||||
if req.init_image is not None: req.sampler_name = 'ddim'
|
if req.init_image is not None:
|
||||||
|
req.sampler_name = "ddim"
|
||||||
|
|
||||||
images = generate_images(context, callback=callback, **req.dict())
|
images = generate_images(context, callback=callback, **req.dict())
|
||||||
user_stopped = False
|
user_stopped = False
|
||||||
@ -75,31 +94,44 @@ def generate_images_internal(req: GenerateImageRequest, task_data: TaskData, dat
|
|||||||
if context.partial_x_samples is not None:
|
if context.partial_x_samples is not None:
|
||||||
images = latent_samples_to_images(context, context.partial_x_samples)
|
images = latent_samples_to_images(context, context.partial_x_samples)
|
||||||
finally:
|
finally:
|
||||||
if hasattr(context, 'partial_x_samples') and context.partial_x_samples is not None:
|
if hasattr(context, "partial_x_samples") and context.partial_x_samples is not None:
|
||||||
del context.partial_x_samples
|
del context.partial_x_samples
|
||||||
context.partial_x_samples = None
|
context.partial_x_samples = None
|
||||||
|
|
||||||
return images, user_stopped
|
return images, user_stopped
|
||||||
|
|
||||||
|
|
||||||
def filter_images(task_data: TaskData, images: list, user_stopped):
|
def filter_images(task_data: TaskData, images: list, user_stopped):
|
||||||
if user_stopped or (task_data.use_face_correction is None and task_data.use_upscale is None):
|
if user_stopped or (task_data.use_face_correction is None and task_data.use_upscale is None):
|
||||||
return images
|
return images
|
||||||
|
|
||||||
filters_to_apply = []
|
filters_to_apply = []
|
||||||
if task_data.use_face_correction and 'gfpgan' in task_data.use_face_correction.lower(): filters_to_apply.append('gfpgan')
|
if task_data.use_face_correction and "gfpgan" in task_data.use_face_correction.lower():
|
||||||
if task_data.use_upscale and 'realesrgan' in task_data.use_upscale.lower(): filters_to_apply.append('realesrgan')
|
filters_to_apply.append("gfpgan")
|
||||||
|
if task_data.use_upscale and "realesrgan" in task_data.use_upscale.lower():
|
||||||
|
filters_to_apply.append("realesrgan")
|
||||||
|
|
||||||
return apply_filters(context, filters_to_apply, images, scale=task_data.upscale_amount)
|
return apply_filters(context, filters_to_apply, images, scale=task_data.upscale_amount)
|
||||||
|
|
||||||
|
|
||||||
def construct_response(images: list, seeds: list, task_data: TaskData, base_seed: int):
|
def construct_response(images: list, seeds: list, task_data: TaskData, base_seed: int):
|
||||||
return [
|
return [
|
||||||
ResponseImage(
|
ResponseImage(
|
||||||
data=img_to_base64_str(img, task_data.output_format, task_data.output_quality),
|
data=img_to_base64_str(img, task_data.output_format, task_data.output_quality),
|
||||||
seed=seed,
|
seed=seed,
|
||||||
) for img, seed in zip(images, seeds)
|
)
|
||||||
|
for img, seed in zip(images, seeds)
|
||||||
]
|
]
|
||||||
|
|
||||||
def make_step_callback(req: GenerateImageRequest, task_data: TaskData, data_queue: queue.Queue, task_temp_images: list, step_callback, stream_image_progress: bool):
|
|
||||||
|
def make_step_callback(
|
||||||
|
req: GenerateImageRequest,
|
||||||
|
task_data: TaskData,
|
||||||
|
data_queue: queue.Queue,
|
||||||
|
task_temp_images: list,
|
||||||
|
step_callback,
|
||||||
|
stream_image_progress: bool,
|
||||||
|
):
|
||||||
n_steps = req.num_inference_steps if req.init_image is None else int(req.num_inference_steps * req.prompt_strength)
|
n_steps = req.num_inference_steps if req.init_image is None else int(req.num_inference_steps * req.prompt_strength)
|
||||||
last_callback_time = -1
|
last_callback_time = -1
|
||||||
|
|
||||||
@ -107,11 +139,11 @@ def make_step_callback(req: GenerateImageRequest, task_data: TaskData, data_queu
|
|||||||
partial_images = []
|
partial_images = []
|
||||||
images = latent_samples_to_images(context, x_samples)
|
images = latent_samples_to_images(context, x_samples)
|
||||||
for i, img in enumerate(images):
|
for i, img in enumerate(images):
|
||||||
buf = img_to_buffer(img, output_format='JPEG')
|
buf = img_to_buffer(img, output_format="JPEG")
|
||||||
|
|
||||||
context.temp_images[f"{task_data.request_id}/{i}"] = buf
|
context.temp_images[f"{task_data.request_id}/{i}"] = buf
|
||||||
task_temp_images[i] = buf
|
task_temp_images[i] = buf
|
||||||
partial_images.append({'path': f"/image/tmp/{task_data.request_id}/{i}"})
|
partial_images.append({"path": f"/image/tmp/{task_data.request_id}/{i}"})
|
||||||
del images
|
del images
|
||||||
return partial_images
|
return partial_images
|
||||||
|
|
||||||
@ -125,7 +157,7 @@ def make_step_callback(req: GenerateImageRequest, task_data: TaskData, data_queu
|
|||||||
progress = {"step": i, "step_time": step_time, "total_steps": n_steps}
|
progress = {"step": i, "step_time": step_time, "total_steps": n_steps}
|
||||||
|
|
||||||
if stream_image_progress and i % 5 == 0:
|
if stream_image_progress and i % 5 == 0:
|
||||||
progress['output'] = update_temp_img(x_samples, task_temp_images)
|
progress["output"] = update_temp_img(x_samples, task_temp_images)
|
||||||
|
|
||||||
data_queue.put(json.dumps(progress))
|
data_queue.put(json.dumps(progress))
|
||||||
|
|
||||||
|
@ -16,21 +16,25 @@ from easydiffusion import app, model_manager, task_manager
|
|||||||
from easydiffusion.types import TaskData, GenerateImageRequest, MergeRequest
|
from easydiffusion.types import TaskData, GenerateImageRequest, MergeRequest
|
||||||
from easydiffusion.utils import log
|
from easydiffusion.utils import log
|
||||||
|
|
||||||
log.info(f'started in {app.SD_DIR}')
|
log.info(f"started in {app.SD_DIR}")
|
||||||
log.info(f'started at {datetime.datetime.now():%x %X}')
|
log.info(f"started at {datetime.datetime.now():%x %X}")
|
||||||
|
|
||||||
server_api = FastAPI()
|
server_api = FastAPI()
|
||||||
|
|
||||||
NOCACHE_HEADERS={"Cache-Control": "no-cache, no-store, must-revalidate", "Pragma": "no-cache", "Expires": "0"}
|
NOCACHE_HEADERS = {"Cache-Control": "no-cache, no-store, must-revalidate", "Pragma": "no-cache", "Expires": "0"}
|
||||||
|
|
||||||
|
|
||||||
class NoCacheStaticFiles(StaticFiles):
|
class NoCacheStaticFiles(StaticFiles):
|
||||||
def is_not_modified(self, response_headers, request_headers) -> bool:
|
def is_not_modified(self, response_headers, request_headers) -> bool:
|
||||||
if 'content-type' in response_headers and ('javascript' in response_headers['content-type'] or 'css' in response_headers['content-type']):
|
if "content-type" in response_headers and (
|
||||||
|
"javascript" in response_headers["content-type"] or "css" in response_headers["content-type"]
|
||||||
|
):
|
||||||
response_headers.update(NOCACHE_HEADERS)
|
response_headers.update(NOCACHE_HEADERS)
|
||||||
return False
|
return False
|
||||||
|
|
||||||
return super().is_not_modified(response_headers, request_headers)
|
return super().is_not_modified(response_headers, request_headers)
|
||||||
|
|
||||||
|
|
||||||
class SetAppConfigRequest(BaseModel):
|
class SetAppConfigRequest(BaseModel):
|
||||||
update_branch: str = None
|
update_branch: str = None
|
||||||
render_devices: Union[List[str], List[int], str, int] = None
|
render_devices: Union[List[str], List[int], str, int] = None
|
||||||
@ -39,130 +43,142 @@ class SetAppConfigRequest(BaseModel):
|
|||||||
listen_to_network: bool = None
|
listen_to_network: bool = None
|
||||||
listen_port: int = None
|
listen_port: int = None
|
||||||
|
|
||||||
|
|
||||||
def init():
|
def init():
|
||||||
server_api.mount('/media', NoCacheStaticFiles(directory=os.path.join(app.SD_UI_DIR, 'media')), name="media")
|
server_api.mount("/media", NoCacheStaticFiles(directory=os.path.join(app.SD_UI_DIR, "media")), name="media")
|
||||||
|
|
||||||
for plugins_dir, dir_prefix in app.UI_PLUGINS_SOURCES:
|
for plugins_dir, dir_prefix in app.UI_PLUGINS_SOURCES:
|
||||||
server_api.mount(f'/plugins/{dir_prefix}', NoCacheStaticFiles(directory=plugins_dir), name=f"plugins-{dir_prefix}")
|
server_api.mount(
|
||||||
|
f"/plugins/{dir_prefix}", NoCacheStaticFiles(directory=plugins_dir), name=f"plugins-{dir_prefix}"
|
||||||
|
)
|
||||||
|
|
||||||
@server_api.post('/app_config')
|
@server_api.post("/app_config")
|
||||||
async def set_app_config(req : SetAppConfigRequest):
|
async def set_app_config(req: SetAppConfigRequest):
|
||||||
return set_app_config_internal(req)
|
return set_app_config_internal(req)
|
||||||
|
|
||||||
@server_api.get('/get/{key:path}')
|
@server_api.get("/get/{key:path}")
|
||||||
def read_web_data(key:str=None):
|
def read_web_data(key: str = None):
|
||||||
return read_web_data_internal(key)
|
return read_web_data_internal(key)
|
||||||
|
|
||||||
@server_api.get('/ping') # Get server and optionally session status.
|
@server_api.get("/ping") # Get server and optionally session status.
|
||||||
def ping(session_id:str=None):
|
def ping(session_id: str = None):
|
||||||
return ping_internal(session_id)
|
return ping_internal(session_id)
|
||||||
|
|
||||||
@server_api.post('/render')
|
@server_api.post("/render")
|
||||||
def render(req: dict):
|
def render(req: dict):
|
||||||
return render_internal(req)
|
return render_internal(req)
|
||||||
|
|
||||||
@server_api.post('/model/merge')
|
@server_api.post("/model/merge")
|
||||||
def model_merge(req: dict):
|
def model_merge(req: dict):
|
||||||
print(req)
|
print(req)
|
||||||
return model_merge_internal(req)
|
return model_merge_internal(req)
|
||||||
|
|
||||||
@server_api.get('/image/stream/{task_id:int}')
|
@server_api.get("/image/stream/{task_id:int}")
|
||||||
def stream(task_id:int):
|
def stream(task_id: int):
|
||||||
return stream_internal(task_id)
|
return stream_internal(task_id)
|
||||||
|
|
||||||
@server_api.get('/image/stop')
|
@server_api.get("/image/stop")
|
||||||
def stop(task: int):
|
def stop(task: int):
|
||||||
return stop_internal(task)
|
return stop_internal(task)
|
||||||
|
|
||||||
@server_api.get('/image/tmp/{task_id:int}/{img_id:int}')
|
@server_api.get("/image/tmp/{task_id:int}/{img_id:int}")
|
||||||
def get_image(task_id: int, img_id: int):
|
def get_image(task_id: int, img_id: int):
|
||||||
return get_image_internal(task_id, img_id)
|
return get_image_internal(task_id, img_id)
|
||||||
|
|
||||||
@server_api.get('/')
|
@server_api.get("/")
|
||||||
def read_root():
|
def read_root():
|
||||||
return FileResponse(os.path.join(app.SD_UI_DIR, 'index.html'), headers=NOCACHE_HEADERS)
|
return FileResponse(os.path.join(app.SD_UI_DIR, "index.html"), headers=NOCACHE_HEADERS)
|
||||||
|
|
||||||
@server_api.on_event("shutdown")
|
@server_api.on_event("shutdown")
|
||||||
def shutdown_event(): # Signal render thread to close on shutdown
|
def shutdown_event(): # Signal render thread to close on shutdown
|
||||||
task_manager.current_state_error = SystemExit('Application shutting down.')
|
task_manager.current_state_error = SystemExit("Application shutting down.")
|
||||||
|
|
||||||
|
|
||||||
# API implementations
|
# API implementations
|
||||||
def set_app_config_internal(req : SetAppConfigRequest):
|
def set_app_config_internal(req: SetAppConfigRequest):
|
||||||
config = app.getConfig()
|
config = app.getConfig()
|
||||||
if req.update_branch is not None:
|
if req.update_branch is not None:
|
||||||
config['update_branch'] = req.update_branch
|
config["update_branch"] = req.update_branch
|
||||||
if req.render_devices is not None:
|
if req.render_devices is not None:
|
||||||
update_render_devices_in_config(config, req.render_devices)
|
update_render_devices_in_config(config, req.render_devices)
|
||||||
if req.ui_open_browser_on_start is not None:
|
if req.ui_open_browser_on_start is not None:
|
||||||
if 'ui' not in config:
|
if "ui" not in config:
|
||||||
config['ui'] = {}
|
config["ui"] = {}
|
||||||
config['ui']['open_browser_on_start'] = req.ui_open_browser_on_start
|
config["ui"]["open_browser_on_start"] = req.ui_open_browser_on_start
|
||||||
if req.listen_to_network is not None:
|
if req.listen_to_network is not None:
|
||||||
if 'net' not in config:
|
if "net" not in config:
|
||||||
config['net'] = {}
|
config["net"] = {}
|
||||||
config['net']['listen_to_network'] = bool(req.listen_to_network)
|
config["net"]["listen_to_network"] = bool(req.listen_to_network)
|
||||||
if req.listen_port is not None:
|
if req.listen_port is not None:
|
||||||
if 'net' not in config:
|
if "net" not in config:
|
||||||
config['net'] = {}
|
config["net"] = {}
|
||||||
config['net']['listen_port'] = int(req.listen_port)
|
config["net"]["listen_port"] = int(req.listen_port)
|
||||||
try:
|
try:
|
||||||
app.setConfig(config)
|
app.setConfig(config)
|
||||||
|
|
||||||
if req.render_devices:
|
if req.render_devices:
|
||||||
app.update_render_threads()
|
app.update_render_threads()
|
||||||
|
|
||||||
return JSONResponse({'status': 'OK'}, headers=NOCACHE_HEADERS)
|
return JSONResponse({"status": "OK"}, headers=NOCACHE_HEADERS)
|
||||||
except Exception as e:
|
except Exception as e:
|
||||||
log.error(traceback.format_exc())
|
log.error(traceback.format_exc())
|
||||||
raise HTTPException(status_code=500, detail=str(e))
|
raise HTTPException(status_code=500, detail=str(e))
|
||||||
|
|
||||||
|
|
||||||
def update_render_devices_in_config(config, render_devices):
|
def update_render_devices_in_config(config, render_devices):
|
||||||
if render_devices not in ('cpu', 'auto') and not render_devices.startswith('cuda:'):
|
if render_devices not in ("cpu", "auto") and not render_devices.startswith("cuda:"):
|
||||||
raise HTTPException(status_code=400, detail=f'Invalid render device requested: {render_devices}')
|
raise HTTPException(status_code=400, detail=f"Invalid render device requested: {render_devices}")
|
||||||
|
|
||||||
if render_devices.startswith('cuda:'):
|
if render_devices.startswith("cuda:"):
|
||||||
render_devices = render_devices.split(',')
|
render_devices = render_devices.split(",")
|
||||||
|
|
||||||
config['render_devices'] = render_devices
|
config["render_devices"] = render_devices
|
||||||
|
|
||||||
def read_web_data_internal(key:str=None):
|
|
||||||
if not key: # /get without parameters, stable-diffusion easter egg.
|
def read_web_data_internal(key: str = None):
|
||||||
raise HTTPException(status_code=418, detail="StableDiffusion is drawing a teapot!") # HTTP418 I'm a teapot
|
if not key: # /get without parameters, stable-diffusion easter egg.
|
||||||
elif key == 'app_config':
|
raise HTTPException(status_code=418, detail="StableDiffusion is drawing a teapot!") # HTTP418 I'm a teapot
|
||||||
|
elif key == "app_config":
|
||||||
return JSONResponse(app.getConfig(), headers=NOCACHE_HEADERS)
|
return JSONResponse(app.getConfig(), headers=NOCACHE_HEADERS)
|
||||||
elif key == 'system_info':
|
elif key == "system_info":
|
||||||
config = app.getConfig()
|
config = app.getConfig()
|
||||||
|
|
||||||
output_dir = config.get('force_save_path', os.path.join(os.path.expanduser("~"), app.OUTPUT_DIRNAME))
|
output_dir = config.get("force_save_path", os.path.join(os.path.expanduser("~"), app.OUTPUT_DIRNAME))
|
||||||
|
|
||||||
system_info = {
|
system_info = {
|
||||||
'devices': task_manager.get_devices(),
|
"devices": task_manager.get_devices(),
|
||||||
'hosts': app.getIPConfig(),
|
"hosts": app.getIPConfig(),
|
||||||
'default_output_dir': output_dir,
|
"default_output_dir": output_dir,
|
||||||
'enforce_output_dir': ('force_save_path' in config),
|
"enforce_output_dir": ("force_save_path" in config),
|
||||||
}
|
}
|
||||||
system_info['devices']['config'] = config.get('render_devices', "auto")
|
system_info["devices"]["config"] = config.get("render_devices", "auto")
|
||||||
return JSONResponse(system_info, headers=NOCACHE_HEADERS)
|
return JSONResponse(system_info, headers=NOCACHE_HEADERS)
|
||||||
elif key == 'models':
|
elif key == "models":
|
||||||
return JSONResponse(model_manager.getModels(), headers=NOCACHE_HEADERS)
|
return JSONResponse(model_manager.getModels(), headers=NOCACHE_HEADERS)
|
||||||
elif key == 'modifiers': return FileResponse(os.path.join(app.SD_UI_DIR, 'modifiers.json'), headers=NOCACHE_HEADERS)
|
elif key == "modifiers":
|
||||||
elif key == 'ui_plugins': return JSONResponse(app.getUIPlugins(), headers=NOCACHE_HEADERS)
|
return FileResponse(os.path.join(app.SD_UI_DIR, "modifiers.json"), headers=NOCACHE_HEADERS)
|
||||||
|
elif key == "ui_plugins":
|
||||||
|
return JSONResponse(app.getUIPlugins(), headers=NOCACHE_HEADERS)
|
||||||
else:
|
else:
|
||||||
raise HTTPException(status_code=404, detail=f'Request for unknown {key}') # HTTP404 Not Found
|
raise HTTPException(status_code=404, detail=f"Request for unknown {key}") # HTTP404 Not Found
|
||||||
|
|
||||||
def ping_internal(session_id:str=None):
|
|
||||||
if task_manager.is_alive() <= 0: # Check that render threads are alive.
|
def ping_internal(session_id: str = None):
|
||||||
if task_manager.current_state_error: raise HTTPException(status_code=500, detail=str(task_manager.current_state_error))
|
if task_manager.is_alive() <= 0: # Check that render threads are alive.
|
||||||
raise HTTPException(status_code=500, detail='Render thread is dead.')
|
if task_manager.current_state_error:
|
||||||
if task_manager.current_state_error and not isinstance(task_manager.current_state_error, StopAsyncIteration): raise HTTPException(status_code=500, detail=str(task_manager.current_state_error))
|
raise HTTPException(status_code=500, detail=str(task_manager.current_state_error))
|
||||||
|
raise HTTPException(status_code=500, detail="Render thread is dead.")
|
||||||
|
if task_manager.current_state_error and not isinstance(task_manager.current_state_error, StopAsyncIteration):
|
||||||
|
raise HTTPException(status_code=500, detail=str(task_manager.current_state_error))
|
||||||
# Alive
|
# Alive
|
||||||
response = {'status': str(task_manager.current_state)}
|
response = {"status": str(task_manager.current_state)}
|
||||||
if session_id:
|
if session_id:
|
||||||
session = task_manager.get_cached_session(session_id, update_ttl=True)
|
session = task_manager.get_cached_session(session_id, update_ttl=True)
|
||||||
response['tasks'] = {id(t): t.status for t in session.tasks}
|
response["tasks"] = {id(t): t.status for t in session.tasks}
|
||||||
response['devices'] = task_manager.get_devices()
|
response["devices"] = task_manager.get_devices()
|
||||||
return JSONResponse(response, headers=NOCACHE_HEADERS)
|
return JSONResponse(response, headers=NOCACHE_HEADERS)
|
||||||
|
|
||||||
|
|
||||||
def render_internal(req: dict):
|
def render_internal(req: dict):
|
||||||
try:
|
try:
|
||||||
# separate out the request data into rendering and task-specific data
|
# separate out the request data into rendering and task-specific data
|
||||||
@ -171,80 +187,99 @@ def render_internal(req: dict):
|
|||||||
|
|
||||||
# Overwrite user specified save path
|
# Overwrite user specified save path
|
||||||
config = app.getConfig()
|
config = app.getConfig()
|
||||||
if 'force_save_path' in config:
|
if "force_save_path" in config:
|
||||||
task_data.save_to_disk_path = config['force_save_path']
|
task_data.save_to_disk_path = config["force_save_path"]
|
||||||
|
|
||||||
render_req.init_image_mask = req.get('mask') # hack: will rename this in the HTTP API in a future revision
|
render_req.init_image_mask = req.get("mask") # hack: will rename this in the HTTP API in a future revision
|
||||||
|
|
||||||
app.save_to_config(task_data.use_stable_diffusion_model, task_data.use_vae_model, task_data.use_hypernetwork_model, task_data.vram_usage_level)
|
app.save_to_config(
|
||||||
|
task_data.use_stable_diffusion_model,
|
||||||
|
task_data.use_vae_model,
|
||||||
|
task_data.use_hypernetwork_model,
|
||||||
|
task_data.vram_usage_level,
|
||||||
|
)
|
||||||
|
|
||||||
# enqueue the task
|
# enqueue the task
|
||||||
new_task = task_manager.render(render_req, task_data)
|
new_task = task_manager.render(render_req, task_data)
|
||||||
response = {
|
response = {
|
||||||
'status': str(task_manager.current_state),
|
"status": str(task_manager.current_state),
|
||||||
'queue': len(task_manager.tasks_queue),
|
"queue": len(task_manager.tasks_queue),
|
||||||
'stream': f'/image/stream/{id(new_task)}',
|
"stream": f"/image/stream/{id(new_task)}",
|
||||||
'task': id(new_task)
|
"task": id(new_task),
|
||||||
}
|
}
|
||||||
return JSONResponse(response, headers=NOCACHE_HEADERS)
|
return JSONResponse(response, headers=NOCACHE_HEADERS)
|
||||||
except ChildProcessError as e: # Render thread is dead
|
except ChildProcessError as e: # Render thread is dead
|
||||||
raise HTTPException(status_code=500, detail=f'Rendering thread has died.') # HTTP500 Internal Server Error
|
raise HTTPException(status_code=500, detail=f"Rendering thread has died.") # HTTP500 Internal Server Error
|
||||||
except ConnectionRefusedError as e: # Unstarted task pending limit reached, deny queueing too many.
|
except ConnectionRefusedError as e: # Unstarted task pending limit reached, deny queueing too many.
|
||||||
raise HTTPException(status_code=503, detail=str(e)) # HTTP503 Service Unavailable
|
raise HTTPException(status_code=503, detail=str(e)) # HTTP503 Service Unavailable
|
||||||
except Exception as e:
|
except Exception as e:
|
||||||
log.error(traceback.format_exc())
|
log.error(traceback.format_exc())
|
||||||
raise HTTPException(status_code=500, detail=str(e))
|
raise HTTPException(status_code=500, detail=str(e))
|
||||||
|
|
||||||
|
|
||||||
def model_merge_internal(req: dict):
|
def model_merge_internal(req: dict):
|
||||||
try:
|
try:
|
||||||
from sdkit.train import merge_models
|
from sdkit.train import merge_models
|
||||||
from easydiffusion.utils.save_utils import filename_regex
|
from easydiffusion.utils.save_utils import filename_regex
|
||||||
|
|
||||||
mergeReq: MergeRequest = MergeRequest.parse_obj(req)
|
mergeReq: MergeRequest = MergeRequest.parse_obj(req)
|
||||||
|
|
||||||
merge_models(model_manager.resolve_model_to_use(mergeReq.model0,'stable-diffusion'),
|
merge_models(
|
||||||
model_manager.resolve_model_to_use(mergeReq.model1,'stable-diffusion'),
|
model_manager.resolve_model_to_use(mergeReq.model0, "stable-diffusion"),
|
||||||
mergeReq.ratio,
|
model_manager.resolve_model_to_use(mergeReq.model1, "stable-diffusion"),
|
||||||
os.path.join(app.MODELS_DIR, 'stable-diffusion', filename_regex.sub('_', mergeReq.out_path)),
|
mergeReq.ratio,
|
||||||
mergeReq.use_fp16
|
os.path.join(app.MODELS_DIR, "stable-diffusion", filename_regex.sub("_", mergeReq.out_path)),
|
||||||
|
mergeReq.use_fp16,
|
||||||
)
|
)
|
||||||
return JSONResponse({'status':'OK'}, headers=NOCACHE_HEADERS)
|
return JSONResponse({"status": "OK"}, headers=NOCACHE_HEADERS)
|
||||||
except Exception as e:
|
except Exception as e:
|
||||||
log.error(traceback.format_exc())
|
log.error(traceback.format_exc())
|
||||||
raise HTTPException(status_code=500, detail=str(e))
|
raise HTTPException(status_code=500, detail=str(e))
|
||||||
|
|
||||||
def stream_internal(task_id:int):
|
|
||||||
#TODO Move to WebSockets ??
|
def stream_internal(task_id: int):
|
||||||
|
# TODO Move to WebSockets ??
|
||||||
task = task_manager.get_cached_task(task_id, update_ttl=True)
|
task = task_manager.get_cached_task(task_id, update_ttl=True)
|
||||||
if not task: raise HTTPException(status_code=404, detail=f'Request {task_id} not found.') # HTTP404 NotFound
|
if not task:
|
||||||
#if (id(task) != task_id): raise HTTPException(status_code=409, detail=f'Wrong task id received. Expected:{id(task)}, Received:{task_id}') # HTTP409 Conflict
|
raise HTTPException(status_code=404, detail=f"Request {task_id} not found.") # HTTP404 NotFound
|
||||||
|
# if (id(task) != task_id): raise HTTPException(status_code=409, detail=f'Wrong task id received. Expected:{id(task)}, Received:{task_id}') # HTTP409 Conflict
|
||||||
if task.buffer_queue.empty() and not task.lock.locked():
|
if task.buffer_queue.empty() and not task.lock.locked():
|
||||||
if task.response:
|
if task.response:
|
||||||
#log.info(f'Session {session_id} sending cached response')
|
# log.info(f'Session {session_id} sending cached response')
|
||||||
return JSONResponse(task.response, headers=NOCACHE_HEADERS)
|
return JSONResponse(task.response, headers=NOCACHE_HEADERS)
|
||||||
raise HTTPException(status_code=425, detail='Too Early, task not started yet.') # HTTP425 Too Early
|
raise HTTPException(status_code=425, detail="Too Early, task not started yet.") # HTTP425 Too Early
|
||||||
#log.info(f'Session {session_id} opened live render stream {id(task.buffer_queue)}')
|
# log.info(f'Session {session_id} opened live render stream {id(task.buffer_queue)}')
|
||||||
return StreamingResponse(task.read_buffer_generator(), media_type='application/json')
|
return StreamingResponse(task.read_buffer_generator(), media_type="application/json")
|
||||||
|
|
||||||
|
|
||||||
def stop_internal(task: int):
|
def stop_internal(task: int):
|
||||||
if not task:
|
if not task:
|
||||||
if task_manager.current_state == task_manager.ServerStates.Online or task_manager.current_state == task_manager.ServerStates.Unavailable:
|
if (
|
||||||
raise HTTPException(status_code=409, detail='Not currently running any tasks.') # HTTP409 Conflict
|
task_manager.current_state == task_manager.ServerStates.Online
|
||||||
task_manager.current_state_error = StopAsyncIteration('')
|
or task_manager.current_state == task_manager.ServerStates.Unavailable
|
||||||
return {'OK'}
|
):
|
||||||
|
raise HTTPException(status_code=409, detail="Not currently running any tasks.") # HTTP409 Conflict
|
||||||
|
task_manager.current_state_error = StopAsyncIteration("")
|
||||||
|
return {"OK"}
|
||||||
task_id = task
|
task_id = task
|
||||||
task = task_manager.get_cached_task(task_id, update_ttl=False)
|
task = task_manager.get_cached_task(task_id, update_ttl=False)
|
||||||
if not task: raise HTTPException(status_code=404, detail=f'Task {task_id} was not found.') # HTTP404 Not Found
|
if not task:
|
||||||
if isinstance(task.error, StopAsyncIteration): raise HTTPException(status_code=409, detail=f'Task {task_id} is already stopped.') # HTTP409 Conflict
|
raise HTTPException(status_code=404, detail=f"Task {task_id} was not found.") # HTTP404 Not Found
|
||||||
task.error = StopAsyncIteration(f'Task {task_id} stop requested.')
|
if isinstance(task.error, StopAsyncIteration):
|
||||||
return {'OK'}
|
raise HTTPException(status_code=409, detail=f"Task {task_id} is already stopped.") # HTTP409 Conflict
|
||||||
|
task.error = StopAsyncIteration(f"Task {task_id} stop requested.")
|
||||||
|
return {"OK"}
|
||||||
|
|
||||||
|
|
||||||
def get_image_internal(task_id: int, img_id: int):
|
def get_image_internal(task_id: int, img_id: int):
|
||||||
task = task_manager.get_cached_task(task_id, update_ttl=True)
|
task = task_manager.get_cached_task(task_id, update_ttl=True)
|
||||||
if not task: raise HTTPException(status_code=410, detail=f'Task {task_id} could not be found.') # HTTP404 NotFound
|
if not task:
|
||||||
if not task.temp_images[img_id]: raise HTTPException(status_code=425, detail='Too Early, task data is not available yet.') # HTTP425 Too Early
|
raise HTTPException(status_code=410, detail=f"Task {task_id} could not be found.") # HTTP404 NotFound
|
||||||
|
if not task.temp_images[img_id]:
|
||||||
|
raise HTTPException(status_code=425, detail="Too Early, task data is not available yet.") # HTTP425 Too Early
|
||||||
try:
|
try:
|
||||||
img_data = task.temp_images[img_id]
|
img_data = task.temp_images[img_id]
|
||||||
img_data.seek(0)
|
img_data.seek(0)
|
||||||
return StreamingResponse(img_data, media_type='image/jpeg')
|
return StreamingResponse(img_data, media_type="image/jpeg")
|
||||||
except KeyError as e:
|
except KeyError as e:
|
||||||
raise HTTPException(status_code=500, detail=str(e))
|
raise HTTPException(status_code=500, detail=str(e))
|
||||||
|
@ -7,7 +7,7 @@ Notes:
|
|||||||
import json
|
import json
|
||||||
import traceback
|
import traceback
|
||||||
|
|
||||||
TASK_TTL = 15 * 60 # seconds, Discard last session's task timeout
|
TASK_TTL = 15 * 60 # seconds, Discard last session's task timeout
|
||||||
|
|
||||||
import torch
|
import torch
|
||||||
import queue, threading, time, weakref
|
import queue, threading, time, weakref
|
||||||
@ -19,71 +19,98 @@ from easydiffusion.utils import log
|
|||||||
|
|
||||||
from sdkit.utils import gc
|
from sdkit.utils import gc
|
||||||
|
|
||||||
THREAD_NAME_PREFIX = ''
|
THREAD_NAME_PREFIX = ""
|
||||||
ERR_LOCK_FAILED = ' failed to acquire lock within timeout.'
|
ERR_LOCK_FAILED = " failed to acquire lock within timeout."
|
||||||
LOCK_TIMEOUT = 15 # Maximum locking time in seconds before failing a task.
|
LOCK_TIMEOUT = 15 # Maximum locking time in seconds before failing a task.
|
||||||
# It's better to get an exception than a deadlock... ALWAYS use timeout in critical paths.
|
# It's better to get an exception than a deadlock... ALWAYS use timeout in critical paths.
|
||||||
|
|
||||||
DEVICE_START_TIMEOUT = 60 # seconds - Maximum time to wait for a render device to init.
|
DEVICE_START_TIMEOUT = 60 # seconds - Maximum time to wait for a render device to init.
|
||||||
|
|
||||||
|
|
||||||
|
class SymbolClass(type): # Print nicely formatted Symbol names.
|
||||||
|
def __repr__(self):
|
||||||
|
return self.__qualname__
|
||||||
|
|
||||||
|
def __str__(self):
|
||||||
|
return self.__name__
|
||||||
|
|
||||||
|
|
||||||
|
class Symbol(metaclass=SymbolClass):
|
||||||
|
pass
|
||||||
|
|
||||||
class SymbolClass(type): # Print nicely formatted Symbol names.
|
|
||||||
def __repr__(self): return self.__qualname__
|
|
||||||
def __str__(self): return self.__name__
|
|
||||||
class Symbol(metaclass=SymbolClass): pass
|
|
||||||
|
|
||||||
class ServerStates:
|
class ServerStates:
|
||||||
class Init(Symbol): pass
|
class Init(Symbol):
|
||||||
class LoadingModel(Symbol): pass
|
pass
|
||||||
class Online(Symbol): pass
|
|
||||||
class Rendering(Symbol): pass
|
|
||||||
class Unavailable(Symbol): pass
|
|
||||||
|
|
||||||
class RenderTask(): # Task with output queue and completion lock.
|
class LoadingModel(Symbol):
|
||||||
|
pass
|
||||||
|
|
||||||
|
class Online(Symbol):
|
||||||
|
pass
|
||||||
|
|
||||||
|
class Rendering(Symbol):
|
||||||
|
pass
|
||||||
|
|
||||||
|
class Unavailable(Symbol):
|
||||||
|
pass
|
||||||
|
|
||||||
|
|
||||||
|
class RenderTask: # Task with output queue and completion lock.
|
||||||
def __init__(self, req: GenerateImageRequest, task_data: TaskData):
|
def __init__(self, req: GenerateImageRequest, task_data: TaskData):
|
||||||
task_data.request_id = id(self)
|
task_data.request_id = id(self)
|
||||||
self.render_request: GenerateImageRequest = req # Initial Request
|
self.render_request: GenerateImageRequest = req # Initial Request
|
||||||
self.task_data: TaskData = task_data
|
self.task_data: TaskData = task_data
|
||||||
self.response: Any = None # Copy of the last reponse
|
self.response: Any = None # Copy of the last reponse
|
||||||
self.render_device = None # Select the task affinity. (Not used to change active devices).
|
self.render_device = None # Select the task affinity. (Not used to change active devices).
|
||||||
self.temp_images:list = [None] * req.num_outputs * (1 if task_data.show_only_filtered_image else 2)
|
self.temp_images: list = [None] * req.num_outputs * (1 if task_data.show_only_filtered_image else 2)
|
||||||
self.error: Exception = None
|
self.error: Exception = None
|
||||||
self.lock: threading.Lock = threading.Lock() # Locks at task start and unlocks when task is completed
|
self.lock: threading.Lock = threading.Lock() # Locks at task start and unlocks when task is completed
|
||||||
self.buffer_queue: queue.Queue = queue.Queue() # Queue of JSON string segments
|
self.buffer_queue: queue.Queue = queue.Queue() # Queue of JSON string segments
|
||||||
|
|
||||||
async def read_buffer_generator(self):
|
async def read_buffer_generator(self):
|
||||||
try:
|
try:
|
||||||
while not self.buffer_queue.empty():
|
while not self.buffer_queue.empty():
|
||||||
res = self.buffer_queue.get(block=False)
|
res = self.buffer_queue.get(block=False)
|
||||||
self.buffer_queue.task_done()
|
self.buffer_queue.task_done()
|
||||||
yield res
|
yield res
|
||||||
except queue.Empty as e: yield
|
except queue.Empty as e:
|
||||||
|
yield
|
||||||
|
|
||||||
@property
|
@property
|
||||||
def status(self):
|
def status(self):
|
||||||
if self.lock.locked():
|
if self.lock.locked():
|
||||||
return 'running'
|
return "running"
|
||||||
if isinstance(self.error, StopAsyncIteration):
|
if isinstance(self.error, StopAsyncIteration):
|
||||||
return 'stopped'
|
return "stopped"
|
||||||
if self.error:
|
if self.error:
|
||||||
return 'error'
|
return "error"
|
||||||
if not self.buffer_queue.empty():
|
if not self.buffer_queue.empty():
|
||||||
return 'buffer'
|
return "buffer"
|
||||||
if self.response:
|
if self.response:
|
||||||
return 'completed'
|
return "completed"
|
||||||
return 'pending'
|
return "pending"
|
||||||
|
|
||||||
@property
|
@property
|
||||||
def is_pending(self):
|
def is_pending(self):
|
||||||
return bool(not self.response and not self.error)
|
return bool(not self.response and not self.error)
|
||||||
|
|
||||||
|
|
||||||
# Temporary cache to allow to query tasks results for a short time after they are completed.
|
# Temporary cache to allow to query tasks results for a short time after they are completed.
|
||||||
class DataCache():
|
class DataCache:
|
||||||
def __init__(self):
|
def __init__(self):
|
||||||
self._base = dict()
|
self._base = dict()
|
||||||
self._lock: threading.Lock = threading.Lock()
|
self._lock: threading.Lock = threading.Lock()
|
||||||
|
|
||||||
def _get_ttl_time(self, ttl: int) -> int:
|
def _get_ttl_time(self, ttl: int) -> int:
|
||||||
return int(time.time()) + ttl
|
return int(time.time()) + ttl
|
||||||
|
|
||||||
def _is_expired(self, timestamp: int) -> bool:
|
def _is_expired(self, timestamp: int) -> bool:
|
||||||
return int(time.time()) >= timestamp
|
return int(time.time()) >= timestamp
|
||||||
|
|
||||||
def clean(self) -> None:
|
def clean(self) -> None:
|
||||||
if not self._lock.acquire(blocking=True, timeout=LOCK_TIMEOUT): raise Exception('DataCache.clean' + ERR_LOCK_FAILED)
|
if not self._lock.acquire(blocking=True, timeout=LOCK_TIMEOUT):
|
||||||
|
raise Exception("DataCache.clean" + ERR_LOCK_FAILED)
|
||||||
try:
|
try:
|
||||||
# Create a list of expired keys to delete
|
# Create a list of expired keys to delete
|
||||||
to_delete = []
|
to_delete = []
|
||||||
@ -95,20 +122,26 @@ class DataCache():
|
|||||||
for key in to_delete:
|
for key in to_delete:
|
||||||
(_, val) = self._base[key]
|
(_, val) = self._base[key]
|
||||||
if isinstance(val, RenderTask):
|
if isinstance(val, RenderTask):
|
||||||
log.debug(f'RenderTask {key} expired. Data removed.')
|
log.debug(f"RenderTask {key} expired. Data removed.")
|
||||||
elif isinstance(val, SessionState):
|
elif isinstance(val, SessionState):
|
||||||
log.debug(f'Session {key} expired. Data removed.')
|
log.debug(f"Session {key} expired. Data removed.")
|
||||||
else:
|
else:
|
||||||
log.debug(f'Key {key} expired. Data removed.')
|
log.debug(f"Key {key} expired. Data removed.")
|
||||||
del self._base[key]
|
del self._base[key]
|
||||||
finally:
|
finally:
|
||||||
self._lock.release()
|
self._lock.release()
|
||||||
|
|
||||||
def clear(self) -> None:
|
def clear(self) -> None:
|
||||||
if not self._lock.acquire(blocking=True, timeout=LOCK_TIMEOUT): raise Exception('DataCache.clear' + ERR_LOCK_FAILED)
|
if not self._lock.acquire(blocking=True, timeout=LOCK_TIMEOUT):
|
||||||
try: self._base.clear()
|
raise Exception("DataCache.clear" + ERR_LOCK_FAILED)
|
||||||
finally: self._lock.release()
|
try:
|
||||||
|
self._base.clear()
|
||||||
|
finally:
|
||||||
|
self._lock.release()
|
||||||
|
|
||||||
def delete(self, key: Hashable) -> bool:
|
def delete(self, key: Hashable) -> bool:
|
||||||
if not self._lock.acquire(blocking=True, timeout=LOCK_TIMEOUT): raise Exception('DataCache.delete' + ERR_LOCK_FAILED)
|
if not self._lock.acquire(blocking=True, timeout=LOCK_TIMEOUT):
|
||||||
|
raise Exception("DataCache.delete" + ERR_LOCK_FAILED)
|
||||||
try:
|
try:
|
||||||
if key not in self._base:
|
if key not in self._base:
|
||||||
return False
|
return False
|
||||||
@ -116,8 +149,10 @@ class DataCache():
|
|||||||
return True
|
return True
|
||||||
finally:
|
finally:
|
||||||
self._lock.release()
|
self._lock.release()
|
||||||
|
|
||||||
def keep(self, key: Hashable, ttl: int) -> bool:
|
def keep(self, key: Hashable, ttl: int) -> bool:
|
||||||
if not self._lock.acquire(blocking=True, timeout=LOCK_TIMEOUT): raise Exception('DataCache.keep' + ERR_LOCK_FAILED)
|
if not self._lock.acquire(blocking=True, timeout=LOCK_TIMEOUT):
|
||||||
|
raise Exception("DataCache.keep" + ERR_LOCK_FAILED)
|
||||||
try:
|
try:
|
||||||
if key in self._base:
|
if key in self._base:
|
||||||
_, value = self._base.get(key)
|
_, value = self._base.get(key)
|
||||||
@ -126,12 +161,12 @@ class DataCache():
|
|||||||
return False
|
return False
|
||||||
finally:
|
finally:
|
||||||
self._lock.release()
|
self._lock.release()
|
||||||
|
|
||||||
def put(self, key: Hashable, value: Any, ttl: int) -> bool:
|
def put(self, key: Hashable, value: Any, ttl: int) -> bool:
|
||||||
if not self._lock.acquire(blocking=True, timeout=LOCK_TIMEOUT): raise Exception('DataCache.put' + ERR_LOCK_FAILED)
|
if not self._lock.acquire(blocking=True, timeout=LOCK_TIMEOUT):
|
||||||
|
raise Exception("DataCache.put" + ERR_LOCK_FAILED)
|
||||||
try:
|
try:
|
||||||
self._base[key] = (
|
self._base[key] = (self._get_ttl_time(ttl), value)
|
||||||
self._get_ttl_time(ttl), value
|
|
||||||
)
|
|
||||||
except Exception as e:
|
except Exception as e:
|
||||||
log.error(traceback.format_exc())
|
log.error(traceback.format_exc())
|
||||||
return False
|
return False
|
||||||
@ -139,35 +174,41 @@ class DataCache():
|
|||||||
return True
|
return True
|
||||||
finally:
|
finally:
|
||||||
self._lock.release()
|
self._lock.release()
|
||||||
|
|
||||||
def tryGet(self, key: Hashable) -> Any:
|
def tryGet(self, key: Hashable) -> Any:
|
||||||
if not self._lock.acquire(blocking=True, timeout=LOCK_TIMEOUT): raise Exception('DataCache.tryGet' + ERR_LOCK_FAILED)
|
if not self._lock.acquire(blocking=True, timeout=LOCK_TIMEOUT):
|
||||||
|
raise Exception("DataCache.tryGet" + ERR_LOCK_FAILED)
|
||||||
try:
|
try:
|
||||||
ttl, value = self._base.get(key, (None, None))
|
ttl, value = self._base.get(key, (None, None))
|
||||||
if ttl is not None and self._is_expired(ttl):
|
if ttl is not None and self._is_expired(ttl):
|
||||||
log.debug(f'Session {key} expired. Discarding data.')
|
log.debug(f"Session {key} expired. Discarding data.")
|
||||||
del self._base[key]
|
del self._base[key]
|
||||||
return None
|
return None
|
||||||
return value
|
return value
|
||||||
finally:
|
finally:
|
||||||
self._lock.release()
|
self._lock.release()
|
||||||
|
|
||||||
|
|
||||||
manager_lock = threading.RLock()
|
manager_lock = threading.RLock()
|
||||||
render_threads = []
|
render_threads = []
|
||||||
current_state = ServerStates.Init
|
current_state = ServerStates.Init
|
||||||
current_state_error:Exception = None
|
current_state_error: Exception = None
|
||||||
tasks_queue = []
|
tasks_queue = []
|
||||||
session_cache = DataCache()
|
session_cache = DataCache()
|
||||||
task_cache = DataCache()
|
task_cache = DataCache()
|
||||||
weak_thread_data = weakref.WeakKeyDictionary()
|
weak_thread_data = weakref.WeakKeyDictionary()
|
||||||
idle_event: threading.Event = threading.Event()
|
idle_event: threading.Event = threading.Event()
|
||||||
|
|
||||||
class SessionState():
|
|
||||||
|
class SessionState:
|
||||||
def __init__(self, id: str):
|
def __init__(self, id: str):
|
||||||
self._id = id
|
self._id = id
|
||||||
self._tasks_ids = []
|
self._tasks_ids = []
|
||||||
|
|
||||||
@property
|
@property
|
||||||
def id(self):
|
def id(self):
|
||||||
return self._id
|
return self._id
|
||||||
|
|
||||||
@property
|
@property
|
||||||
def tasks(self):
|
def tasks(self):
|
||||||
tasks = []
|
tasks = []
|
||||||
@ -176,6 +217,7 @@ class SessionState():
|
|||||||
if task:
|
if task:
|
||||||
tasks.append(task)
|
tasks.append(task)
|
||||||
return tasks
|
return tasks
|
||||||
|
|
||||||
def put(self, task, ttl=TASK_TTL):
|
def put(self, task, ttl=TASK_TTL):
|
||||||
task_id = id(task)
|
task_id = id(task)
|
||||||
self._tasks_ids.append(task_id)
|
self._tasks_ids.append(task_id)
|
||||||
@ -185,10 +227,12 @@ class SessionState():
|
|||||||
self._tasks_ids.pop(0)
|
self._tasks_ids.pop(0)
|
||||||
return True
|
return True
|
||||||
|
|
||||||
|
|
||||||
def thread_get_next_task():
|
def thread_get_next_task():
|
||||||
from easydiffusion import renderer
|
from easydiffusion import renderer
|
||||||
|
|
||||||
if not manager_lock.acquire(blocking=True, timeout=LOCK_TIMEOUT):
|
if not manager_lock.acquire(blocking=True, timeout=LOCK_TIMEOUT):
|
||||||
log.warn(f'Render thread on device: {renderer.context.device} failed to acquire manager lock.')
|
log.warn(f"Render thread on device: {renderer.context.device} failed to acquire manager lock.")
|
||||||
return None
|
return None
|
||||||
if len(tasks_queue) <= 0:
|
if len(tasks_queue) <= 0:
|
||||||
manager_lock.release()
|
manager_lock.release()
|
||||||
@ -202,10 +246,10 @@ def thread_get_next_task():
|
|||||||
continue # requested device alive, skip current one.
|
continue # requested device alive, skip current one.
|
||||||
else:
|
else:
|
||||||
# Requested device is not active, return error to UI.
|
# Requested device is not active, return error to UI.
|
||||||
queued_task.error = Exception(queued_task.render_device + ' is not currently active.')
|
queued_task.error = Exception(queued_task.render_device + " is not currently active.")
|
||||||
task = queued_task
|
task = queued_task
|
||||||
break
|
break
|
||||||
if not queued_task.render_device and renderer.context.device == 'cpu' and is_alive() > 1:
|
if not queued_task.render_device and renderer.context.device == "cpu" and is_alive() > 1:
|
||||||
# not asking for any specific devices, cpu want to grab task but other render devices are alive.
|
# not asking for any specific devices, cpu want to grab task but other render devices are alive.
|
||||||
continue # Skip Tasks, don't run on CPU unless there is nothing else or user asked for it.
|
continue # Skip Tasks, don't run on CPU unless there is nothing else or user asked for it.
|
||||||
task = queued_task
|
task = queued_task
|
||||||
@ -216,17 +260,19 @@ def thread_get_next_task():
|
|||||||
finally:
|
finally:
|
||||||
manager_lock.release()
|
manager_lock.release()
|
||||||
|
|
||||||
|
|
||||||
def thread_render(device):
|
def thread_render(device):
|
||||||
global current_state, current_state_error
|
global current_state, current_state_error
|
||||||
|
|
||||||
from easydiffusion import renderer, model_manager
|
from easydiffusion import renderer, model_manager
|
||||||
|
|
||||||
try:
|
try:
|
||||||
renderer.init(device)
|
renderer.init(device)
|
||||||
|
|
||||||
weak_thread_data[threading.current_thread()] = {
|
weak_thread_data[threading.current_thread()] = {
|
||||||
'device': renderer.context.device,
|
"device": renderer.context.device,
|
||||||
'device_name': renderer.context.device_name,
|
"device_name": renderer.context.device_name,
|
||||||
'alive': True
|
"alive": True,
|
||||||
}
|
}
|
||||||
|
|
||||||
current_state = ServerStates.LoadingModel
|
current_state = ServerStates.LoadingModel
|
||||||
@ -235,17 +281,14 @@ def thread_render(device):
|
|||||||
current_state = ServerStates.Online
|
current_state = ServerStates.Online
|
||||||
except Exception as e:
|
except Exception as e:
|
||||||
log.error(traceback.format_exc())
|
log.error(traceback.format_exc())
|
||||||
weak_thread_data[threading.current_thread()] = {
|
weak_thread_data[threading.current_thread()] = {"error": e, "alive": False}
|
||||||
'error': e,
|
|
||||||
'alive': False
|
|
||||||
}
|
|
||||||
return
|
return
|
||||||
|
|
||||||
while True:
|
while True:
|
||||||
session_cache.clean()
|
session_cache.clean()
|
||||||
task_cache.clean()
|
task_cache.clean()
|
||||||
if not weak_thread_data[threading.current_thread()]['alive']:
|
if not weak_thread_data[threading.current_thread()]["alive"]:
|
||||||
log.info(f'Shutting down thread for device {renderer.context.device}')
|
log.info(f"Shutting down thread for device {renderer.context.device}")
|
||||||
model_manager.unload_all(renderer.context)
|
model_manager.unload_all(renderer.context)
|
||||||
return
|
return
|
||||||
if isinstance(current_state_error, SystemExit):
|
if isinstance(current_state_error, SystemExit):
|
||||||
@ -258,39 +301,47 @@ def thread_render(device):
|
|||||||
continue
|
continue
|
||||||
if task.error is not None:
|
if task.error is not None:
|
||||||
log.error(task.error)
|
log.error(task.error)
|
||||||
task.response = {"status": 'failed', "detail": str(task.error)}
|
task.response = {"status": "failed", "detail": str(task.error)}
|
||||||
task.buffer_queue.put(json.dumps(task.response))
|
task.buffer_queue.put(json.dumps(task.response))
|
||||||
continue
|
continue
|
||||||
if current_state_error:
|
if current_state_error:
|
||||||
task.error = current_state_error
|
task.error = current_state_error
|
||||||
task.response = {"status": 'failed', "detail": str(task.error)}
|
task.response = {"status": "failed", "detail": str(task.error)}
|
||||||
task.buffer_queue.put(json.dumps(task.response))
|
task.buffer_queue.put(json.dumps(task.response))
|
||||||
continue
|
continue
|
||||||
log.info(f'Session {task.task_data.session_id} starting task {id(task)} on {renderer.context.device_name}')
|
log.info(f"Session {task.task_data.session_id} starting task {id(task)} on {renderer.context.device_name}")
|
||||||
if not task.lock.acquire(blocking=False): raise Exception('Got locked task from queue.')
|
if not task.lock.acquire(blocking=False):
|
||||||
|
raise Exception("Got locked task from queue.")
|
||||||
try:
|
try:
|
||||||
|
|
||||||
def step_callback():
|
def step_callback():
|
||||||
global current_state_error
|
global current_state_error
|
||||||
|
|
||||||
if isinstance(current_state_error, SystemExit) or isinstance(current_state_error, StopAsyncIteration) or isinstance(task.error, StopAsyncIteration):
|
if (
|
||||||
|
isinstance(current_state_error, SystemExit)
|
||||||
|
or isinstance(current_state_error, StopAsyncIteration)
|
||||||
|
or isinstance(task.error, StopAsyncIteration)
|
||||||
|
):
|
||||||
renderer.context.stop_processing = True
|
renderer.context.stop_processing = True
|
||||||
if isinstance(current_state_error, StopAsyncIteration):
|
if isinstance(current_state_error, StopAsyncIteration):
|
||||||
task.error = current_state_error
|
task.error = current_state_error
|
||||||
current_state_error = None
|
current_state_error = None
|
||||||
log.info(f'Session {task.task_data.session_id} sent cancel signal for task {id(task)}')
|
log.info(f"Session {task.task_data.session_id} sent cancel signal for task {id(task)}")
|
||||||
|
|
||||||
current_state = ServerStates.LoadingModel
|
current_state = ServerStates.LoadingModel
|
||||||
model_manager.resolve_model_paths(task.task_data)
|
model_manager.resolve_model_paths(task.task_data)
|
||||||
model_manager.reload_models_if_necessary(renderer.context, task.task_data)
|
model_manager.reload_models_if_necessary(renderer.context, task.task_data)
|
||||||
|
|
||||||
current_state = ServerStates.Rendering
|
current_state = ServerStates.Rendering
|
||||||
task.response = renderer.make_images(task.render_request, task.task_data, task.buffer_queue, task.temp_images, step_callback)
|
task.response = renderer.make_images(
|
||||||
|
task.render_request, task.task_data, task.buffer_queue, task.temp_images, step_callback
|
||||||
|
)
|
||||||
# Before looping back to the generator, mark cache as still alive.
|
# Before looping back to the generator, mark cache as still alive.
|
||||||
task_cache.keep(id(task), TASK_TTL)
|
task_cache.keep(id(task), TASK_TTL)
|
||||||
session_cache.keep(task.task_data.session_id, TASK_TTL)
|
session_cache.keep(task.task_data.session_id, TASK_TTL)
|
||||||
except Exception as e:
|
except Exception as e:
|
||||||
task.error = str(e)
|
task.error = str(e)
|
||||||
task.response = {"status": 'failed', "detail": str(task.error)}
|
task.response = {"status": "failed", "detail": str(task.error)}
|
||||||
task.buffer_queue.put(json.dumps(task.response))
|
task.buffer_queue.put(json.dumps(task.response))
|
||||||
log.error(traceback.format_exc())
|
log.error(traceback.format_exc())
|
||||||
finally:
|
finally:
|
||||||
@ -299,21 +350,25 @@ def thread_render(device):
|
|||||||
task_cache.keep(id(task), TASK_TTL)
|
task_cache.keep(id(task), TASK_TTL)
|
||||||
session_cache.keep(task.task_data.session_id, TASK_TTL)
|
session_cache.keep(task.task_data.session_id, TASK_TTL)
|
||||||
if isinstance(task.error, StopAsyncIteration):
|
if isinstance(task.error, StopAsyncIteration):
|
||||||
log.info(f'Session {task.task_data.session_id} task {id(task)} cancelled!')
|
log.info(f"Session {task.task_data.session_id} task {id(task)} cancelled!")
|
||||||
elif task.error is not None:
|
elif task.error is not None:
|
||||||
log.info(f'Session {task.task_data.session_id} task {id(task)} failed!')
|
log.info(f"Session {task.task_data.session_id} task {id(task)} failed!")
|
||||||
else:
|
else:
|
||||||
log.info(f'Session {task.task_data.session_id} task {id(task)} completed by {renderer.context.device_name}.')
|
log.info(
|
||||||
|
f"Session {task.task_data.session_id} task {id(task)} completed by {renderer.context.device_name}."
|
||||||
|
)
|
||||||
current_state = ServerStates.Online
|
current_state = ServerStates.Online
|
||||||
|
|
||||||
def get_cached_task(task_id:str, update_ttl:bool=False):
|
|
||||||
|
def get_cached_task(task_id: str, update_ttl: bool = False):
|
||||||
# By calling keep before tryGet, wont discard if was expired.
|
# By calling keep before tryGet, wont discard if was expired.
|
||||||
if update_ttl and not task_cache.keep(task_id, TASK_TTL):
|
if update_ttl and not task_cache.keep(task_id, TASK_TTL):
|
||||||
# Failed to keep task, already gone.
|
# Failed to keep task, already gone.
|
||||||
return None
|
return None
|
||||||
return task_cache.tryGet(task_id)
|
return task_cache.tryGet(task_id)
|
||||||
|
|
||||||
def get_cached_session(session_id:str, update_ttl:bool=False):
|
|
||||||
|
def get_cached_session(session_id: str, update_ttl: bool = False):
|
||||||
if update_ttl:
|
if update_ttl:
|
||||||
session_cache.keep(session_id, TASK_TTL)
|
session_cache.keep(session_id, TASK_TTL)
|
||||||
session = session_cache.tryGet(session_id)
|
session = session_cache.tryGet(session_id)
|
||||||
@ -322,64 +377,68 @@ def get_cached_session(session_id:str, update_ttl:bool=False):
|
|||||||
session_cache.put(session_id, session, TASK_TTL)
|
session_cache.put(session_id, session, TASK_TTL)
|
||||||
return session
|
return session
|
||||||
|
|
||||||
|
|
||||||
def get_devices():
|
def get_devices():
|
||||||
devices = {
|
devices = {
|
||||||
'all': {},
|
"all": {},
|
||||||
'active': {},
|
"active": {},
|
||||||
}
|
}
|
||||||
|
|
||||||
def get_device_info(device):
|
def get_device_info(device):
|
||||||
if device == 'cpu':
|
if device == "cpu":
|
||||||
return {'name': device_manager.get_processor_name()}
|
return {"name": device_manager.get_processor_name()}
|
||||||
|
|
||||||
mem_free, mem_total = torch.cuda.mem_get_info(device)
|
mem_free, mem_total = torch.cuda.mem_get_info(device)
|
||||||
mem_free /= float(10**9)
|
mem_free /= float(10**9)
|
||||||
mem_total /= float(10**9)
|
mem_total /= float(10**9)
|
||||||
|
|
||||||
return {
|
return {
|
||||||
'name': torch.cuda.get_device_name(device),
|
"name": torch.cuda.get_device_name(device),
|
||||||
'mem_free': mem_free,
|
"mem_free": mem_free,
|
||||||
'mem_total': mem_total,
|
"mem_total": mem_total,
|
||||||
'max_vram_usage_level': device_manager.get_max_vram_usage_level(device),
|
"max_vram_usage_level": device_manager.get_max_vram_usage_level(device),
|
||||||
}
|
}
|
||||||
|
|
||||||
# list the compatible devices
|
# list the compatible devices
|
||||||
gpu_count = torch.cuda.device_count()
|
gpu_count = torch.cuda.device_count()
|
||||||
for device in range(gpu_count):
|
for device in range(gpu_count):
|
||||||
device = f'cuda:{device}'
|
device = f"cuda:{device}"
|
||||||
if not device_manager.is_device_compatible(device):
|
if not device_manager.is_device_compatible(device):
|
||||||
continue
|
continue
|
||||||
|
|
||||||
devices['all'].update({device: get_device_info(device)})
|
devices["all"].update({device: get_device_info(device)})
|
||||||
|
|
||||||
devices['all'].update({'cpu': get_device_info('cpu')})
|
devices["all"].update({"cpu": get_device_info("cpu")})
|
||||||
|
|
||||||
# list the activated devices
|
# list the activated devices
|
||||||
if not manager_lock.acquire(blocking=True, timeout=LOCK_TIMEOUT): raise Exception('get_devices' + ERR_LOCK_FAILED)
|
if not manager_lock.acquire(blocking=True, timeout=LOCK_TIMEOUT):
|
||||||
|
raise Exception("get_devices" + ERR_LOCK_FAILED)
|
||||||
try:
|
try:
|
||||||
for rthread in render_threads:
|
for rthread in render_threads:
|
||||||
if not rthread.is_alive():
|
if not rthread.is_alive():
|
||||||
continue
|
continue
|
||||||
weak_data = weak_thread_data.get(rthread)
|
weak_data = weak_thread_data.get(rthread)
|
||||||
if not weak_data or not 'device' in weak_data or not 'device_name' in weak_data:
|
if not weak_data or not "device" in weak_data or not "device_name" in weak_data:
|
||||||
continue
|
continue
|
||||||
device = weak_data['device']
|
device = weak_data["device"]
|
||||||
devices['active'].update({device: get_device_info(device)})
|
devices["active"].update({device: get_device_info(device)})
|
||||||
finally:
|
finally:
|
||||||
manager_lock.release()
|
manager_lock.release()
|
||||||
|
|
||||||
return devices
|
return devices
|
||||||
|
|
||||||
|
|
||||||
def is_alive(device=None):
|
def is_alive(device=None):
|
||||||
if not manager_lock.acquire(blocking=True, timeout=LOCK_TIMEOUT): raise Exception('is_alive' + ERR_LOCK_FAILED)
|
if not manager_lock.acquire(blocking=True, timeout=LOCK_TIMEOUT):
|
||||||
|
raise Exception("is_alive" + ERR_LOCK_FAILED)
|
||||||
nbr_alive = 0
|
nbr_alive = 0
|
||||||
try:
|
try:
|
||||||
for rthread in render_threads:
|
for rthread in render_threads:
|
||||||
if device is not None:
|
if device is not None:
|
||||||
weak_data = weak_thread_data.get(rthread)
|
weak_data = weak_thread_data.get(rthread)
|
||||||
if weak_data is None or not 'device' in weak_data or weak_data['device'] is None:
|
if weak_data is None or not "device" in weak_data or weak_data["device"] is None:
|
||||||
continue
|
continue
|
||||||
thread_device = weak_data['device']
|
thread_device = weak_data["device"]
|
||||||
if thread_device != device:
|
if thread_device != device:
|
||||||
continue
|
continue
|
||||||
if rthread.is_alive():
|
if rthread.is_alive():
|
||||||
@ -388,11 +447,13 @@ def is_alive(device=None):
|
|||||||
finally:
|
finally:
|
||||||
manager_lock.release()
|
manager_lock.release()
|
||||||
|
|
||||||
|
|
||||||
def start_render_thread(device):
|
def start_render_thread(device):
|
||||||
if not manager_lock.acquire(blocking=True, timeout=LOCK_TIMEOUT): raise Exception('start_render_thread' + ERR_LOCK_FAILED)
|
if not manager_lock.acquire(blocking=True, timeout=LOCK_TIMEOUT):
|
||||||
log.info(f'Start new Rendering Thread on device: {device}')
|
raise Exception("start_render_thread" + ERR_LOCK_FAILED)
|
||||||
|
log.info(f"Start new Rendering Thread on device: {device}")
|
||||||
try:
|
try:
|
||||||
rthread = threading.Thread(target=thread_render, kwargs={'device': device})
|
rthread = threading.Thread(target=thread_render, kwargs={"device": device})
|
||||||
rthread.daemon = True
|
rthread.daemon = True
|
||||||
rthread.name = THREAD_NAME_PREFIX + device
|
rthread.name = THREAD_NAME_PREFIX + device
|
||||||
rthread.start()
|
rthread.start()
|
||||||
@ -400,8 +461,8 @@ def start_render_thread(device):
|
|||||||
finally:
|
finally:
|
||||||
manager_lock.release()
|
manager_lock.release()
|
||||||
timeout = DEVICE_START_TIMEOUT
|
timeout = DEVICE_START_TIMEOUT
|
||||||
while not rthread.is_alive() or not rthread in weak_thread_data or not 'device' in weak_thread_data[rthread]:
|
while not rthread.is_alive() or not rthread in weak_thread_data or not "device" in weak_thread_data[rthread]:
|
||||||
if rthread in weak_thread_data and 'error' in weak_thread_data[rthread]:
|
if rthread in weak_thread_data and "error" in weak_thread_data[rthread]:
|
||||||
log.error(f"{rthread}, {device}, error: {weak_thread_data[rthread]['error']}")
|
log.error(f"{rthread}, {device}, error: {weak_thread_data[rthread]['error']}")
|
||||||
return False
|
return False
|
||||||
if timeout <= 0:
|
if timeout <= 0:
|
||||||
@ -410,25 +471,27 @@ def start_render_thread(device):
|
|||||||
time.sleep(1)
|
time.sleep(1)
|
||||||
return True
|
return True
|
||||||
|
|
||||||
|
|
||||||
def stop_render_thread(device):
|
def stop_render_thread(device):
|
||||||
try:
|
try:
|
||||||
device_manager.validate_device_id(device, log_prefix='stop_render_thread')
|
device_manager.validate_device_id(device, log_prefix="stop_render_thread")
|
||||||
except:
|
except:
|
||||||
log.error(traceback.format_exc())
|
log.error(traceback.format_exc())
|
||||||
return False
|
return False
|
||||||
|
|
||||||
if not manager_lock.acquire(blocking=True, timeout=LOCK_TIMEOUT): raise Exception('stop_render_thread' + ERR_LOCK_FAILED)
|
if not manager_lock.acquire(blocking=True, timeout=LOCK_TIMEOUT):
|
||||||
log.info(f'Stopping Rendering Thread on device: {device}')
|
raise Exception("stop_render_thread" + ERR_LOCK_FAILED)
|
||||||
|
log.info(f"Stopping Rendering Thread on device: {device}")
|
||||||
|
|
||||||
try:
|
try:
|
||||||
thread_to_remove = None
|
thread_to_remove = None
|
||||||
for rthread in render_threads:
|
for rthread in render_threads:
|
||||||
weak_data = weak_thread_data.get(rthread)
|
weak_data = weak_thread_data.get(rthread)
|
||||||
if weak_data is None or not 'device' in weak_data or weak_data['device'] is None:
|
if weak_data is None or not "device" in weak_data or weak_data["device"] is None:
|
||||||
continue
|
continue
|
||||||
thread_device = weak_data['device']
|
thread_device = weak_data["device"]
|
||||||
if thread_device == device:
|
if thread_device == device:
|
||||||
weak_data['alive'] = False
|
weak_data["alive"] = False
|
||||||
thread_to_remove = rthread
|
thread_to_remove = rthread
|
||||||
break
|
break
|
||||||
if thread_to_remove is not None:
|
if thread_to_remove is not None:
|
||||||
@ -439,44 +502,51 @@ def stop_render_thread(device):
|
|||||||
|
|
||||||
return False
|
return False
|
||||||
|
|
||||||
|
|
||||||
def update_render_threads(render_devices, active_devices):
|
def update_render_threads(render_devices, active_devices):
|
||||||
devices_to_start, devices_to_stop = device_manager.get_device_delta(render_devices, active_devices)
|
devices_to_start, devices_to_stop = device_manager.get_device_delta(render_devices, active_devices)
|
||||||
log.debug(f'devices_to_start: {devices_to_start}')
|
log.debug(f"devices_to_start: {devices_to_start}")
|
||||||
log.debug(f'devices_to_stop: {devices_to_stop}')
|
log.debug(f"devices_to_stop: {devices_to_stop}")
|
||||||
|
|
||||||
for device in devices_to_stop:
|
for device in devices_to_stop:
|
||||||
if is_alive(device) <= 0:
|
if is_alive(device) <= 0:
|
||||||
log.debug(f'{device} is not alive')
|
log.debug(f"{device} is not alive")
|
||||||
continue
|
continue
|
||||||
if not stop_render_thread(device):
|
if not stop_render_thread(device):
|
||||||
log.warn(f'{device} could not stop render thread')
|
log.warn(f"{device} could not stop render thread")
|
||||||
|
|
||||||
for device in devices_to_start:
|
for device in devices_to_start:
|
||||||
if is_alive(device) >= 1:
|
if is_alive(device) >= 1:
|
||||||
log.debug(f'{device} already registered.')
|
log.debug(f"{device} already registered.")
|
||||||
continue
|
continue
|
||||||
if not start_render_thread(device):
|
if not start_render_thread(device):
|
||||||
log.warn(f'{device} failed to start.')
|
log.warn(f"{device} failed to start.")
|
||||||
|
|
||||||
if is_alive() <= 0: # No running devices, probably invalid user config.
|
if is_alive() <= 0: # No running devices, probably invalid user config.
|
||||||
raise EnvironmentError('ERROR: No active render devices! Please verify the "render_devices" value in config.json')
|
raise EnvironmentError(
|
||||||
|
'ERROR: No active render devices! Please verify the "render_devices" value in config.json'
|
||||||
|
)
|
||||||
|
|
||||||
log.debug(f"active devices: {get_devices()['active']}")
|
log.debug(f"active devices: {get_devices()['active']}")
|
||||||
|
|
||||||
def shutdown_event(): # Signal render thread to close on shutdown
|
|
||||||
|
def shutdown_event(): # Signal render thread to close on shutdown
|
||||||
global current_state_error
|
global current_state_error
|
||||||
current_state_error = SystemExit('Application shutting down.')
|
current_state_error = SystemExit("Application shutting down.")
|
||||||
|
|
||||||
|
|
||||||
def render(render_req: GenerateImageRequest, task_data: TaskData):
|
def render(render_req: GenerateImageRequest, task_data: TaskData):
|
||||||
current_thread_count = is_alive()
|
current_thread_count = is_alive()
|
||||||
if current_thread_count <= 0: # Render thread is dead
|
if current_thread_count <= 0: # Render thread is dead
|
||||||
raise ChildProcessError('Rendering thread has died.')
|
raise ChildProcessError("Rendering thread has died.")
|
||||||
|
|
||||||
# Alive, check if task in cache
|
# Alive, check if task in cache
|
||||||
session = get_cached_session(task_data.session_id, update_ttl=True)
|
session = get_cached_session(task_data.session_id, update_ttl=True)
|
||||||
pending_tasks = list(filter(lambda t: t.is_pending, session.tasks))
|
pending_tasks = list(filter(lambda t: t.is_pending, session.tasks))
|
||||||
if current_thread_count < len(pending_tasks):
|
if current_thread_count < len(pending_tasks):
|
||||||
raise ConnectionRefusedError(f'Session {task_data.session_id} already has {len(pending_tasks)} pending tasks out of {current_thread_count}.')
|
raise ConnectionRefusedError(
|
||||||
|
f"Session {task_data.session_id} already has {len(pending_tasks)} pending tasks out of {current_thread_count}."
|
||||||
|
)
|
||||||
|
|
||||||
new_task = RenderTask(render_req, task_data)
|
new_task = RenderTask(render_req, task_data)
|
||||||
if session.put(new_task, TASK_TTL):
|
if session.put(new_task, TASK_TTL):
|
||||||
@ -489,4 +559,4 @@ def render(render_req: GenerateImageRequest, task_data: TaskData):
|
|||||||
return new_task
|
return new_task
|
||||||
finally:
|
finally:
|
||||||
manager_lock.release()
|
manager_lock.release()
|
||||||
raise RuntimeError('Failed to add task to cache.')
|
raise RuntimeError("Failed to add task to cache.")
|
||||||
|
@ -1,6 +1,7 @@
|
|||||||
from pydantic import BaseModel
|
from pydantic import BaseModel
|
||||||
from typing import Any
|
from typing import Any
|
||||||
|
|
||||||
|
|
||||||
class GenerateImageRequest(BaseModel):
|
class GenerateImageRequest(BaseModel):
|
||||||
prompt: str = ""
|
prompt: str = ""
|
||||||
negative_prompt: str = ""
|
negative_prompt: str = ""
|
||||||
@ -18,29 +19,31 @@ class GenerateImageRequest(BaseModel):
|
|||||||
prompt_strength: float = 0.8
|
prompt_strength: float = 0.8
|
||||||
preserve_init_image_color_profile = False
|
preserve_init_image_color_profile = False
|
||||||
|
|
||||||
sampler_name: str = None # "ddim", "plms", "heun", "euler", "euler_a", "dpm2", "dpm2_a", "lms"
|
sampler_name: str = None # "ddim", "plms", "heun", "euler", "euler_a", "dpm2", "dpm2_a", "lms"
|
||||||
hypernetwork_strength: float = 0
|
hypernetwork_strength: float = 0
|
||||||
|
|
||||||
|
|
||||||
class TaskData(BaseModel):
|
class TaskData(BaseModel):
|
||||||
request_id: str = None
|
request_id: str = None
|
||||||
session_id: str = "session"
|
session_id: str = "session"
|
||||||
save_to_disk_path: str = None
|
save_to_disk_path: str = None
|
||||||
vram_usage_level: str = "balanced" # or "low" or "medium"
|
vram_usage_level: str = "balanced" # or "low" or "medium"
|
||||||
|
|
||||||
use_face_correction: str = None # or "GFPGANv1.3"
|
use_face_correction: str = None # or "GFPGANv1.3"
|
||||||
use_upscale: str = None # or "RealESRGAN_x4plus" or "RealESRGAN_x4plus_anime_6B"
|
use_upscale: str = None # or "RealESRGAN_x4plus" or "RealESRGAN_x4plus_anime_6B"
|
||||||
upscale_amount: int = 4 # or 2
|
upscale_amount: int = 4 # or 2
|
||||||
use_stable_diffusion_model: str = "sd-v1-4"
|
use_stable_diffusion_model: str = "sd-v1-4"
|
||||||
# use_stable_diffusion_config: str = "v1-inference"
|
# use_stable_diffusion_config: str = "v1-inference"
|
||||||
use_vae_model: str = None
|
use_vae_model: str = None
|
||||||
use_hypernetwork_model: str = None
|
use_hypernetwork_model: str = None
|
||||||
|
|
||||||
show_only_filtered_image: bool = False
|
show_only_filtered_image: bool = False
|
||||||
output_format: str = "jpeg" # or "png"
|
output_format: str = "jpeg" # or "png"
|
||||||
output_quality: int = 75
|
output_quality: int = 75
|
||||||
metadata_output_format: str = "txt" # or "json"
|
metadata_output_format: str = "txt" # or "json"
|
||||||
stream_image_progress: bool = False
|
stream_image_progress: bool = False
|
||||||
|
|
||||||
|
|
||||||
class MergeRequest(BaseModel):
|
class MergeRequest(BaseModel):
|
||||||
model0: str = None
|
model0: str = None
|
||||||
model1: str = None
|
model1: str = None
|
||||||
@ -48,8 +51,9 @@ class MergeRequest(BaseModel):
|
|||||||
out_path: str = "mix"
|
out_path: str = "mix"
|
||||||
use_fp16 = True
|
use_fp16 = True
|
||||||
|
|
||||||
|
|
||||||
class Image:
|
class Image:
|
||||||
data: str # base64
|
data: str # base64
|
||||||
seed: int
|
seed: int
|
||||||
is_nsfw: bool
|
is_nsfw: bool
|
||||||
path_abs: str = None
|
path_abs: str = None
|
||||||
@ -65,6 +69,7 @@ class Image:
|
|||||||
"path_abs": self.path_abs,
|
"path_abs": self.path_abs,
|
||||||
}
|
}
|
||||||
|
|
||||||
|
|
||||||
class Response:
|
class Response:
|
||||||
render_request: GenerateImageRequest
|
render_request: GenerateImageRequest
|
||||||
task_data: TaskData
|
task_data: TaskData
|
||||||
@ -80,7 +85,7 @@ class Response:
|
|||||||
del self.render_request.init_image_mask
|
del self.render_request.init_image_mask
|
||||||
|
|
||||||
res = {
|
res = {
|
||||||
"status": 'succeeded',
|
"status": "succeeded",
|
||||||
"render_request": self.render_request.dict(),
|
"render_request": self.render_request.dict(),
|
||||||
"task_data": self.task_data.dict(),
|
"task_data": self.task_data.dict(),
|
||||||
"output": [],
|
"output": [],
|
||||||
@ -91,5 +96,6 @@ class Response:
|
|||||||
|
|
||||||
return res
|
return res
|
||||||
|
|
||||||
|
|
||||||
class UserInitiatedStop(Exception):
|
class UserInitiatedStop(Exception):
|
||||||
pass
|
pass
|
||||||
|
@ -1,8 +1,8 @@
|
|||||||
import logging
|
import logging
|
||||||
|
|
||||||
log = logging.getLogger('easydiffusion')
|
log = logging.getLogger("easydiffusion")
|
||||||
|
|
||||||
from .save_utils import (
|
from .save_utils import (
|
||||||
save_images_to_disk,
|
save_images_to_disk,
|
||||||
get_printable_request,
|
get_printable_request,
|
||||||
)
|
)
|
||||||
|
@ -7,89 +7,126 @@ from easydiffusion.types import TaskData, GenerateImageRequest
|
|||||||
|
|
||||||
from sdkit.utils import save_images, save_dicts
|
from sdkit.utils import save_images, save_dicts
|
||||||
|
|
||||||
filename_regex = re.compile('[^a-zA-Z0-9._-]')
|
filename_regex = re.compile("[^a-zA-Z0-9._-]")
|
||||||
|
|
||||||
# keep in sync with `ui/media/js/dnd.js`
|
# keep in sync with `ui/media/js/dnd.js`
|
||||||
TASK_TEXT_MAPPING = {
|
TASK_TEXT_MAPPING = {
|
||||||
'prompt': 'Prompt',
|
"prompt": "Prompt",
|
||||||
'width': 'Width',
|
"width": "Width",
|
||||||
'height': 'Height',
|
"height": "Height",
|
||||||
'seed': 'Seed',
|
"seed": "Seed",
|
||||||
'num_inference_steps': 'Steps',
|
"num_inference_steps": "Steps",
|
||||||
'guidance_scale': 'Guidance Scale',
|
"guidance_scale": "Guidance Scale",
|
||||||
'prompt_strength': 'Prompt Strength',
|
"prompt_strength": "Prompt Strength",
|
||||||
'use_face_correction': 'Use Face Correction',
|
"use_face_correction": "Use Face Correction",
|
||||||
'use_upscale': 'Use Upscaling',
|
"use_upscale": "Use Upscaling",
|
||||||
'upscale_amount': 'Upscale By',
|
"upscale_amount": "Upscale By",
|
||||||
'sampler_name': 'Sampler',
|
"sampler_name": "Sampler",
|
||||||
'negative_prompt': 'Negative Prompt',
|
"negative_prompt": "Negative Prompt",
|
||||||
'use_stable_diffusion_model': 'Stable Diffusion model',
|
"use_stable_diffusion_model": "Stable Diffusion model",
|
||||||
'use_vae_model': 'VAE model',
|
"use_vae_model": "VAE model",
|
||||||
'use_hypernetwork_model': 'Hypernetwork model',
|
"use_hypernetwork_model": "Hypernetwork model",
|
||||||
'hypernetwork_strength': 'Hypernetwork Strength'
|
"hypernetwork_strength": "Hypernetwork Strength",
|
||||||
}
|
}
|
||||||
|
|
||||||
|
|
||||||
def save_images_to_disk(images: list, filtered_images: list, req: GenerateImageRequest, task_data: TaskData):
|
def save_images_to_disk(images: list, filtered_images: list, req: GenerateImageRequest, task_data: TaskData):
|
||||||
now = time.time()
|
now = time.time()
|
||||||
save_dir_path = os.path.join(task_data.save_to_disk_path, filename_regex.sub('_', task_data.session_id))
|
save_dir_path = os.path.join(task_data.save_to_disk_path, filename_regex.sub("_", task_data.session_id))
|
||||||
metadata_entries = get_metadata_entries_for_request(req, task_data)
|
metadata_entries = get_metadata_entries_for_request(req, task_data)
|
||||||
make_filename = make_filename_callback(req, now=now)
|
make_filename = make_filename_callback(req, now=now)
|
||||||
|
|
||||||
if task_data.show_only_filtered_image or filtered_images is images:
|
if task_data.show_only_filtered_image or filtered_images is images:
|
||||||
save_images(filtered_images, save_dir_path, file_name=make_filename, output_format=task_data.output_format, output_quality=task_data.output_quality)
|
save_images(
|
||||||
if task_data.metadata_output_format.lower() in ['json', 'txt', 'embed']:
|
filtered_images,
|
||||||
save_dicts(metadata_entries, save_dir_path, file_name=make_filename, output_format=task_data.metadata_output_format, file_format=task_data.output_format)
|
save_dir_path,
|
||||||
|
file_name=make_filename,
|
||||||
|
output_format=task_data.output_format,
|
||||||
|
output_quality=task_data.output_quality,
|
||||||
|
)
|
||||||
|
if task_data.metadata_output_format.lower() in ["json", "txt", "embed"]:
|
||||||
|
save_dicts(
|
||||||
|
metadata_entries,
|
||||||
|
save_dir_path,
|
||||||
|
file_name=make_filename,
|
||||||
|
output_format=task_data.metadata_output_format,
|
||||||
|
file_format=task_data.output_format,
|
||||||
|
)
|
||||||
else:
|
else:
|
||||||
make_filter_filename = make_filename_callback(req, now=now, suffix='filtered')
|
make_filter_filename = make_filename_callback(req, now=now, suffix="filtered")
|
||||||
|
|
||||||
|
save_images(
|
||||||
|
images,
|
||||||
|
save_dir_path,
|
||||||
|
file_name=make_filename,
|
||||||
|
output_format=task_data.output_format,
|
||||||
|
output_quality=task_data.output_quality,
|
||||||
|
)
|
||||||
|
save_images(
|
||||||
|
filtered_images,
|
||||||
|
save_dir_path,
|
||||||
|
file_name=make_filter_filename,
|
||||||
|
output_format=task_data.output_format,
|
||||||
|
output_quality=task_data.output_quality,
|
||||||
|
)
|
||||||
|
if task_data.metadata_output_format.lower() in ["json", "txt", "embed"]:
|
||||||
|
save_dicts(
|
||||||
|
metadata_entries,
|
||||||
|
save_dir_path,
|
||||||
|
file_name=make_filter_filename,
|
||||||
|
output_format=task_data.metadata_output_format,
|
||||||
|
file_format=task_data.output_format,
|
||||||
|
)
|
||||||
|
|
||||||
save_images(images, save_dir_path, file_name=make_filename, output_format=task_data.output_format, output_quality=task_data.output_quality)
|
|
||||||
save_images(filtered_images, save_dir_path, file_name=make_filter_filename, output_format=task_data.output_format, output_quality=task_data.output_quality)
|
|
||||||
if task_data.metadata_output_format.lower() in ['json', 'txt', 'embed']:
|
|
||||||
save_dicts(metadata_entries, save_dir_path, file_name=make_filter_filename, output_format=task_data.metadata_output_format, file_format=task_data.output_format)
|
|
||||||
|
|
||||||
def get_metadata_entries_for_request(req: GenerateImageRequest, task_data: TaskData):
|
def get_metadata_entries_for_request(req: GenerateImageRequest, task_data: TaskData):
|
||||||
metadata = get_printable_request(req)
|
metadata = get_printable_request(req)
|
||||||
metadata.update({
|
metadata.update(
|
||||||
'use_stable_diffusion_model': task_data.use_stable_diffusion_model,
|
{
|
||||||
'use_vae_model': task_data.use_vae_model,
|
"use_stable_diffusion_model": task_data.use_stable_diffusion_model,
|
||||||
'use_hypernetwork_model': task_data.use_hypernetwork_model,
|
"use_vae_model": task_data.use_vae_model,
|
||||||
'use_face_correction': task_data.use_face_correction,
|
"use_hypernetwork_model": task_data.use_hypernetwork_model,
|
||||||
'use_upscale': task_data.use_upscale,
|
"use_face_correction": task_data.use_face_correction,
|
||||||
})
|
"use_upscale": task_data.use_upscale,
|
||||||
if metadata['use_upscale'] is not None:
|
}
|
||||||
metadata['upscale_amount'] = task_data.upscale_amount
|
)
|
||||||
if (task_data.use_hypernetwork_model is None):
|
if metadata["use_upscale"] is not None:
|
||||||
del metadata['hypernetwork_strength']
|
metadata["upscale_amount"] = task_data.upscale_amount
|
||||||
|
if task_data.use_hypernetwork_model is None:
|
||||||
|
del metadata["hypernetwork_strength"]
|
||||||
|
|
||||||
# if text, format it in the text format expected by the UI
|
# if text, format it in the text format expected by the UI
|
||||||
is_txt_format = (task_data.metadata_output_format.lower() == 'txt')
|
is_txt_format = task_data.metadata_output_format.lower() == "txt"
|
||||||
if is_txt_format:
|
if is_txt_format:
|
||||||
metadata = {TASK_TEXT_MAPPING[key]: val for key, val in metadata.items() if key in TASK_TEXT_MAPPING}
|
metadata = {TASK_TEXT_MAPPING[key]: val for key, val in metadata.items() if key in TASK_TEXT_MAPPING}
|
||||||
|
|
||||||
entries = [metadata.copy() for _ in range(req.num_outputs)]
|
entries = [metadata.copy() for _ in range(req.num_outputs)]
|
||||||
for i, entry in enumerate(entries):
|
for i, entry in enumerate(entries):
|
||||||
entry['Seed' if is_txt_format else 'seed'] = req.seed + i
|
entry["Seed" if is_txt_format else "seed"] = req.seed + i
|
||||||
|
|
||||||
return entries
|
return entries
|
||||||
|
|
||||||
|
|
||||||
def get_printable_request(req: GenerateImageRequest):
|
def get_printable_request(req: GenerateImageRequest):
|
||||||
metadata = req.dict()
|
metadata = req.dict()
|
||||||
del metadata['init_image']
|
del metadata["init_image"]
|
||||||
del metadata['init_image_mask']
|
del metadata["init_image_mask"]
|
||||||
if (req.init_image is None):
|
if req.init_image is None:
|
||||||
del metadata['prompt_strength']
|
del metadata["prompt_strength"]
|
||||||
return metadata
|
return metadata
|
||||||
|
|
||||||
|
|
||||||
def make_filename_callback(req: GenerateImageRequest, suffix=None, now=None):
|
def make_filename_callback(req: GenerateImageRequest, suffix=None, now=None):
|
||||||
if now is None:
|
if now is None:
|
||||||
now = time.time()
|
now = time.time()
|
||||||
def make_filename(i):
|
|
||||||
img_id = base64.b64encode(int(now+i).to_bytes(8, 'big')).decode() # Generate unique ID based on time.
|
|
||||||
img_id = img_id.translate({43:None, 47:None, 61:None})[-8:] # Remove + / = and keep last 8 chars.
|
|
||||||
|
|
||||||
prompt_flattened = filename_regex.sub('_', req.prompt)[:50]
|
def make_filename(i):
|
||||||
|
img_id = base64.b64encode(int(now + i).to_bytes(8, "big")).decode() # Generate unique ID based on time.
|
||||||
|
img_id = img_id.translate({43: None, 47: None, 61: None})[-8:] # Remove + / = and keep last 8 chars.
|
||||||
|
|
||||||
|
prompt_flattened = filename_regex.sub("_", req.prompt)[:50]
|
||||||
name = f"{prompt_flattened}_{img_id}"
|
name = f"{prompt_flattened}_{img_id}"
|
||||||
name = name if suffix is None else f'{name}_{suffix}'
|
name = name if suffix is None else f"{name}_{suffix}"
|
||||||
return name
|
return name
|
||||||
|
|
||||||
return make_filename
|
return make_filename
|
||||||
|
Loading…
Reference in New Issue
Block a user