forked from extern/easydiffusion
Format code, PEP8 using Black
This commit is contained in:
parent 0ad08c609d
commit 2eb317c6b6
@@ -16,7 +16,7 @@ from easydiffusion.utils import log
 for handler in logging.root.handlers[:]:
     logging.root.removeHandler(handler)

-LOG_FORMAT = '%(asctime)s.%(msecs)03d %(levelname)s %(threadName)s %(message)s'
+LOG_FORMAT = "%(asctime)s.%(msecs)03d %(levelname)s %(threadName)s %(message)s"
 logging.basicConfig(
     level=logging.INFO,
     format=LOG_FORMAT,
@@ -26,73 +26,76 @@ logging.basicConfig(

 SD_DIR = os.getcwd()

-SD_UI_DIR = os.getenv('SD_UI_PATH', None)
+SD_UI_DIR = os.getenv("SD_UI_PATH", None)
 sys.path.append(os.path.dirname(SD_UI_DIR))

-CONFIG_DIR = os.path.abspath(os.path.join(SD_UI_DIR, '..', 'scripts'))
-MODELS_DIR = os.path.abspath(os.path.join(SD_DIR, '..', 'models'))
+CONFIG_DIR = os.path.abspath(os.path.join(SD_UI_DIR, "..", "scripts"))
+MODELS_DIR = os.path.abspath(os.path.join(SD_DIR, "..", "models"))

-USER_UI_PLUGINS_DIR = os.path.abspath(os.path.join(SD_DIR, '..', 'plugins', 'ui'))
-CORE_UI_PLUGINS_DIR = os.path.abspath(os.path.join(SD_UI_DIR, 'plugins', 'ui'))
-UI_PLUGINS_SOURCES = ((CORE_UI_PLUGINS_DIR, 'core'), (USER_UI_PLUGINS_DIR, 'user'))
+USER_UI_PLUGINS_DIR = os.path.abspath(os.path.join(SD_DIR, "..", "plugins", "ui"))
+CORE_UI_PLUGINS_DIR = os.path.abspath(os.path.join(SD_UI_DIR, "plugins", "ui"))
+UI_PLUGINS_SOURCES = ((CORE_UI_PLUGINS_DIR, "core"), (USER_UI_PLUGINS_DIR, "user"))

 OUTPUT_DIRNAME = "Stable Diffusion UI" # in the user's home folder
-PRESERVE_CONFIG_VARS = ['FORCE_FULL_PRECISION']
+PRESERVE_CONFIG_VARS = ["FORCE_FULL_PRECISION"]
 TASK_TTL = 15 * 60 # Discard last session's task timeout
 APP_CONFIG_DEFAULTS = {
     # auto: selects the cuda device with the most free memory, cuda: use the currently active cuda device.
-    'render_devices': 'auto', # valid entries: 'auto', 'cpu' or 'cuda:N' (where N is a GPU index)
-    'update_branch': 'main',
-    'ui': {
-        'open_browser_on_start': True,
-    }
+    "render_devices": "auto", # valid entries: 'auto', 'cpu' or 'cuda:N' (where N is a GPU index)
+    "update_branch": "main",
+    "ui": {
+        "open_browser_on_start": True,
+    },
 }


 def init():
     os.makedirs(USER_UI_PLUGINS_DIR, exist_ok=True)

     update_render_threads()


 def getConfig(default_val=APP_CONFIG_DEFAULTS):
     try:
-        config_json_path = os.path.join(CONFIG_DIR, 'config.json')
+        config_json_path = os.path.join(CONFIG_DIR, "config.json")
         if not os.path.exists(config_json_path):
             config = default_val
         else:
-            with open(config_json_path, 'r', encoding='utf-8') as f:
+            with open(config_json_path, "r", encoding="utf-8") as f:
                 config = json.load(f)
-        if 'net' not in config:
-            config['net'] = {}
-        if os.getenv('SD_UI_BIND_PORT') is not None:
-            config['net']['listen_port'] = int(os.getenv('SD_UI_BIND_PORT'))
+        if "net" not in config:
+            config["net"] = {}
+        if os.getenv("SD_UI_BIND_PORT") is not None:
+            config["net"]["listen_port"] = int(os.getenv("SD_UI_BIND_PORT"))
         else:
-            config['net']['listen_port'] = 9000
-        if os.getenv('SD_UI_BIND_IP') is not None:
-            config['net']['listen_to_network'] = (os.getenv('SD_UI_BIND_IP') == '0.0.0.0')
+            config["net"]["listen_port"] = 9000
+        if os.getenv("SD_UI_BIND_IP") is not None:
+            config["net"]["listen_to_network"] = os.getenv("SD_UI_BIND_IP") == "0.0.0.0"
         else:
-            config['net']['listen_to_network'] = True
+            config["net"]["listen_to_network"] = True
         return config
     except Exception as e:
         log.warn(traceback.format_exc())
         return default_val


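For reference, the config that `getConfig()` yields when no `config.json` exists mirrors `APP_CONFIG_DEFAULTS`, with a `net` section injected from the environment. A sketch of the effective result using the defaults above (values illustrative, not part of the commit):

```python
# Effective config when config.json is absent and no SD_UI_BIND_* variables are set:
effective_config = {
    "render_devices": "auto",
    "update_branch": "main",
    "ui": {"open_browser_on_start": True},
    "net": {"listen_port": 9000, "listen_to_network": True},  # injected by getConfig()
}
```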
 def setConfig(config):
     try: # config.json
-        config_json_path = os.path.join(CONFIG_DIR, 'config.json')
-        with open(config_json_path, 'w', encoding='utf-8') as f:
+        config_json_path = os.path.join(CONFIG_DIR, "config.json")
+        with open(config_json_path, "w", encoding="utf-8") as f:
             json.dump(config, f)
     except:
         log.error(traceback.format_exc())

     try: # config.bat
-        config_bat_path = os.path.join(CONFIG_DIR, 'config.bat')
+        config_bat_path = os.path.join(CONFIG_DIR, "config.bat")
         config_bat = []

-        if 'update_branch' in config:
+        if "update_branch" in config:
             config_bat.append(f"@set update_branch={config['update_branch']}")

         config_bat.append(f"@set SD_UI_BIND_PORT={config['net']['listen_port']}")
-        bind_ip = '0.0.0.0' if config['net']['listen_to_network'] else '127.0.0.1'
+        bind_ip = "0.0.0.0" if config["net"]["listen_to_network"] else "127.0.0.1"
         config_bat.append(f"@set SD_UI_BIND_IP={bind_ip}")

         # Preserve these variables if they are set
@@ -101,20 +104,20 @@ def setConfig(config):
                 config_bat.append(f"@set {var}={os.getenv(var)}")

         if len(config_bat) > 0:
-            with open(config_bat_path, 'w', encoding='utf-8') as f:
-                f.write('\r\n'.join(config_bat))
+            with open(config_bat_path, "w", encoding="utf-8") as f:
+                f.write("\r\n".join(config_bat))
     except:
         log.error(traceback.format_exc())

     try: # config.sh
-        config_sh_path = os.path.join(CONFIG_DIR, 'config.sh')
-        config_sh = ['#!/bin/bash']
+        config_sh_path = os.path.join(CONFIG_DIR, "config.sh")
+        config_sh = ["#!/bin/bash"]

-        if 'update_branch' in config:
+        if "update_branch" in config:
             config_sh.append(f"export update_branch={config['update_branch']}")

         config_sh.append(f"export SD_UI_BIND_PORT={config['net']['listen_port']}")
-        bind_ip = '0.0.0.0' if config['net']['listen_to_network'] else '127.0.0.1'
+        bind_ip = "0.0.0.0" if config["net"]["listen_to_network"] else "127.0.0.1"
         config_sh.append(f"export SD_UI_BIND_IP={bind_ip}")

         # Preserve these variables if they are set
@@ -123,47 +126,51 @@ def setConfig(config):
                 config_bat.append(f'export {var}="{shlex.quote(os.getenv(var))}"')

         if len(config_sh) > 1:
-            with open(config_sh_path, 'w', encoding='utf-8') as f:
-                f.write('\n'.join(config_sh))
+            with open(config_sh_path, "w", encoding="utf-8") as f:
+                f.write("\n".join(config_sh))
     except:
         log.error(traceback.format_exc())


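To illustrate what `setConfig()` writes, here is roughly the `config.sh` that a given config would produce (a sketch with assumed values; the `config.bat` file is analogous, using `@set` and CRLF line endings):

```python
config = {"update_branch": "main", "net": {"listen_port": 9000, "listen_to_network": False}}

# setConfig(config) would then write a config.sh along these lines:
expected_config_sh = "\n".join(
    [
        "#!/bin/bash",
        "export update_branch=main",
        "export SD_UI_BIND_PORT=9000",
        "export SD_UI_BIND_IP=127.0.0.1",  # 127.0.0.1 because listen_to_network is False
    ]
)
```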
 def save_to_config(ckpt_model_name, vae_model_name, hypernetwork_model_name, vram_usage_level):
     config = getConfig()
-    if 'model' not in config:
-        config['model'] = {}
+    if "model" not in config:
+        config["model"] = {}

-    config['model']['stable-diffusion'] = ckpt_model_name
-    config['model']['vae'] = vae_model_name
-    config['model']['hypernetwork'] = hypernetwork_model_name
+    config["model"]["stable-diffusion"] = ckpt_model_name
+    config["model"]["vae"] = vae_model_name
+    config["model"]["hypernetwork"] = hypernetwork_model_name

     if vae_model_name is None or vae_model_name == "":
-        del config['model']['vae']
+        del config["model"]["vae"]
     if hypernetwork_model_name is None or hypernetwork_model_name == "":
-        del config['model']['hypernetwork']
+        del config["model"]["hypernetwork"]

-    config['vram_usage_level'] = vram_usage_level
+    config["vram_usage_level"] = vram_usage_level

     setConfig(config)


 def update_render_threads():
     config = getConfig()
-    render_devices = config.get('render_devices', 'auto')
-    active_devices = task_manager.get_devices()['active'].keys()
+    render_devices = config.get("render_devices", "auto")
+    active_devices = task_manager.get_devices()["active"].keys()

-    log.debug(f'requesting for render_devices: {render_devices}')
+    log.debug(f"requesting for render_devices: {render_devices}")
     task_manager.update_render_threads(render_devices, active_devices)


 def getUIPlugins():
     plugins = []

     for plugins_dir, dir_prefix in UI_PLUGINS_SOURCES:
         for file in os.listdir(plugins_dir):
-            if file.endswith('.plugin.js'):
-                plugins.append(f'/plugins/{dir_prefix}/{file}')
+            if file.endswith(".plugin.js"):
+                plugins.append(f"/plugins/{dir_prefix}/{file}")

     return plugins


 def getIPConfig():
     try:
         ips = socket.gethostbyname_ex(socket.gethostname())
@@ -173,10 +180,13 @@ def getIPConfig():
         log.exception(e)
         return []


 def open_browser():
     config = getConfig()
-    ui = config.get('ui', {})
-    net = config.get('net', {'listen_port':9000})
-    port = net.get('listen_port', 9000)
-    if ui.get('open_browser_on_start', True):
-        import webbrowser; webbrowser.open(f"http://localhost:{port}")
+    ui = config.get("ui", {})
+    net = config.get("net", {"listen_port": 9000})
+    port = net.get("listen_port", 9000)
+    if ui.get("open_browser_on_start", True):
+        import webbrowser
+
+        webbrowser.open(f"http://localhost:{port}")
@@ -5,45 +5,54 @@ import re

 from easydiffusion.utils import log

-'''
+"""
 Set `FORCE_FULL_PRECISION` in the environment variables, or in `config.bat`/`config.sh` to set full precision (i.e. float32).
 Otherwise the models will load at half-precision (i.e. float16).

 Half-precision is fine most of the time. Full precision is only needed for working around GPU bugs (like NVIDIA 16xx GPUs).
-'''
+"""

-COMPARABLE_GPU_PERCENTILE = 0.65 # if a GPU's free_mem is within this % of the GPU with the most free_mem, it will be picked
+COMPARABLE_GPU_PERCENTILE = (
+    0.65  # if a GPU's free_mem is within this % of the GPU with the most free_mem, it will be picked
+)

 mem_free_threshold = 0


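A minimal illustration of the override described in the docstring above (hypothetical launcher code, not part of this commit):

```python
import os

# needs_to_force_full_precision() only checks for the variable's presence,
# so any value works. This forces float32 weights, e.g. for GTX 16xx GPUs.
os.environ["FORCE_FULL_PRECISION"] = "yes"
```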
 def get_device_delta(render_devices, active_devices):
-    '''
+    """
     render_devices: 'cpu', or 'auto' or ['cuda:N'...]
     active_devices: ['cpu', 'cuda:N'...]
-    '''
+    """

-    if render_devices in ('cpu', 'auto'):
+    if render_devices in ("cpu", "auto"):
         render_devices = [render_devices]
     elif render_devices is not None:
         if isinstance(render_devices, str):
             render_devices = [render_devices]
         if isinstance(render_devices, list) and len(render_devices) > 0:
-            render_devices = list(filter(lambda x: x.startswith('cuda:'), render_devices))
+            render_devices = list(filter(lambda x: x.startswith("cuda:"), render_devices))
             if len(render_devices) == 0:
-                raise Exception('Invalid render_devices value in config.json. Valid: {"render_devices": ["cuda:0", "cuda:1"...]}, or {"render_devices": "cpu"} or {"render_devices": "auto"}')
+                raise Exception(
+                    'Invalid render_devices value in config.json. Valid: {"render_devices": ["cuda:0", "cuda:1"...]}, or {"render_devices": "cpu"} or {"render_devices": "auto"}'
+                )

             render_devices = list(filter(lambda x: is_device_compatible(x), render_devices))
             if len(render_devices) == 0:
-                raise Exception('Sorry, none of the render_devices configured in config.json are compatible with Stable Diffusion')
+                raise Exception(
+                    "Sorry, none of the render_devices configured in config.json are compatible with Stable Diffusion"
+                )
         else:
-            raise Exception('Invalid render_devices value in config.json. Valid: {"render_devices": ["cuda:0", "cuda:1"...]}, or {"render_devices": "cpu"} or {"render_devices": "auto"}')
+            raise Exception(
+                'Invalid render_devices value in config.json. Valid: {"render_devices": ["cuda:0", "cuda:1"...]}, or {"render_devices": "cpu"} or {"render_devices": "auto"}'
+            )
     else:
-        render_devices = ['auto']
+        render_devices = ["auto"]

-    if 'auto' in render_devices:
+    if "auto" in render_devices:
         render_devices = auto_pick_devices(active_devices)
-        if 'cpu' in render_devices:
-            log.warn('WARNING: Could not find a compatible GPU. Using the CPU, but this will be very slow!')
+        if "cpu" in render_devices:
+            log.warn("WARNING: Could not find a compatible GPU. Using the CPU, but this will be very slow!")

     active_devices = set(active_devices)
     render_devices = set(render_devices)
@@ -53,19 +62,21 @@ def get_device_delta(render_devices, active_devices):

     return devices_to_start, devices_to_stop


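A sketch of how `get_device_delta()` might be called (hypothetical values; the exact contents of the returned delta depend on the installed GPUs and the elided lines of this hunk):

```python
render_devices = "auto"       # as read from config.json
active_devices = ["cuda:0"]   # render threads that are already running

devices_to_start, devices_to_stop = get_device_delta(render_devices, active_devices)
# e.g. devices_to_start could contain "cuda:1" if auto_pick_devices() also selects it,
# and devices_to_stop could contain "cuda:0" if that device is no longer wanted.
```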
 def auto_pick_devices(currently_active_devices):
     global mem_free_threshold

-    if not torch.cuda.is_available(): return ['cpu']
+    if not torch.cuda.is_available():
+        return ["cpu"]

     device_count = torch.cuda.device_count()
     if device_count == 1:
-        return ['cuda:0'] if is_device_compatible('cuda:0') else ['cpu']
+        return ["cuda:0"] if is_device_compatible("cuda:0") else ["cpu"]

-    log.debug('Autoselecting GPU. Using most free memory.')
+    log.debug("Autoselecting GPU. Using most free memory.")
     devices = []
     for device in range(device_count):
-        device = f'cuda:{device}'
+        device = f"cuda:{device}"
         if not is_device_compatible(device):
             continue

@@ -73,11 +84,13 @@ def auto_pick_devices(currently_active_devices):
         mem_free /= float(10**9)
         mem_total /= float(10**9)
         device_name = torch.cuda.get_device_name(device)
-        log.debug(f'{device} detected: {device_name} - Memory (free/total): {round(mem_free, 2)}Gb / {round(mem_total, 2)}Gb')
-        devices.append({'device': device, 'device_name': device_name, 'mem_free': mem_free})
+        log.debug(
+            f"{device} detected: {device_name} - Memory (free/total): {round(mem_free, 2)}Gb / {round(mem_total, 2)}Gb"
+        )
+        devices.append({"device": device, "device_name": device_name, "mem_free": mem_free})

-    devices.sort(key=lambda x:x['mem_free'], reverse=True)
-    max_mem_free = devices[0]['mem_free']
+    devices.sort(key=lambda x: x["mem_free"], reverse=True)
+    max_mem_free = devices[0]["mem_free"]
     curr_mem_free_threshold = COMPARABLE_GPU_PERCENTILE * max_mem_free
     mem_free_threshold = max(curr_mem_free_threshold, mem_free_threshold)

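A worked example of the threshold arithmetic above, with made-up free-memory numbers:

```python
COMPARABLE_GPU_PERCENTILE = 0.65

mem_free = {"cuda:0": 10.0, "cuda:1": 7.5, "cuda:2": 3.0}  # GB free, illustrative
threshold = COMPARABLE_GPU_PERCENTILE * max(mem_free.values())  # 0.65 * 10.0 = 6.5 GB

picked = [d for d, free in mem_free.items() if free > threshold]
print(picked)  # ['cuda:0', 'cuda:1'] -- cuda:2 is skipped unless it is already active
```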
@@ -87,23 +100,26 @@ def auto_pick_devices(currently_active_devices):
     # always be very low (since their VRAM contains the model).
     # These already-running devices probably aren't terrible, since they were picked in the past.
     # Worst case, the user can restart the program and that'll get rid of them.
-    devices = list(filter((lambda x: x['mem_free'] > mem_free_threshold or x['device'] in currently_active_devices), devices))
-    devices = list(map(lambda x: x['device'], devices))
+    devices = list(
+        filter((lambda x: x["mem_free"] > mem_free_threshold or x["device"] in currently_active_devices), devices)
+    )
+    devices = list(map(lambda x: x["device"], devices))
     return devices


 def device_init(context, device):
-    '''
+    """
     This function assumes the 'device' has already been verified to be compatible.
     `get_device_delta()` has already filtered out incompatible devices.
-    '''
+    """

-    validate_device_id(device, log_prefix='device_init')
+    validate_device_id(device, log_prefix="device_init")

-    if device == 'cpu':
-        context.device = 'cpu'
+    if device == "cpu":
+        context.device = "cpu"
         context.device_name = get_processor_name()
         context.half_precision = False
-        log.debug(f'Render device CPU available as {context.device_name}')
+        log.debug(f"Render device CPU available as {context.device_name}")
         return

     context.device_name = torch.cuda.get_device_name(device)
@@ -111,7 +127,7 @@ def device_init(context, device):

     # Force full precision on 1660 and 1650 NVIDIA cards to avoid creating green images
     if needs_to_force_full_precision(context):
-        log.warn(f'forcing full precision on this GPU, to avoid green images. GPU detected: {context.device_name}')
+        log.warn(f"forcing full precision on this GPU, to avoid green images. GPU detected: {context.device_name}")
         # Apply force_full_precision now before models are loaded.
         context.half_precision = False

@@ -120,58 +136,74 @@ def device_init(context, device):

     return


 def needs_to_force_full_precision(context):
-    if 'FORCE_FULL_PRECISION' in os.environ:
+    if "FORCE_FULL_PRECISION" in os.environ:
         return True

     device_name = context.device_name.lower()
-    return (('nvidia' in device_name or 'geforce' in device_name or 'quadro' in device_name) and (' 1660' in device_name or ' 1650' in device_name or ' t400' in device_name or ' t550' in device_name or ' t600' in device_name or ' t1000' in device_name or ' t1200' in device_name or ' t2000' in device_name))
+    return ("nvidia" in device_name or "geforce" in device_name or "quadro" in device_name) and (
+        " 1660" in device_name
+        or " 1650" in device_name
+        or " t400" in device_name
+        or " t550" in device_name
+        or " t600" in device_name
+        or " t1000" in device_name
+        or " t1200" in device_name
+        or " t2000" in device_name
+    )


 def get_max_vram_usage_level(device):
-    if device != 'cpu':
+    if device != "cpu":
         _, mem_total = torch.cuda.mem_get_info(device)
         mem_total /= float(10**9)

         if mem_total < 4.5:
-            return 'low'
+            return "low"
         elif mem_total < 6.5:
-            return 'balanced'
+            return "balanced"

-    return 'high'
+    return "high"

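The VRAM tiers above, traced with made-up totals:

```python
# get_max_vram_usage_level() maps total VRAM (in GB) to a usage level:
for mem_total in (3.9, 5.8, 8.0):
    level = "low" if mem_total < 4.5 else ("balanced" if mem_total < 6.5 else "high")
    print(f"{mem_total} GB -> {level}")  # 3.9 -> low, 5.8 -> balanced, 8.0 -> high
```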
-def validate_device_id(device, log_prefix=''):
+
+def validate_device_id(device, log_prefix=""):
     def is_valid():
         if not isinstance(device, str):
             return False
-        if device == 'cpu':
+        if device == "cpu":
             return True
-        if not device.startswith('cuda:') or not device[5:].isnumeric():
+        if not device.startswith("cuda:") or not device[5:].isnumeric():
             return False
         return True

     if not is_valid():
-        raise EnvironmentError(f"{log_prefix}: device id should be 'cpu', or 'cuda:N' (where N is an integer index for the GPU). Got: {device}")
+        raise EnvironmentError(
+            f"{log_prefix}: device id should be 'cpu', or 'cuda:N' (where N is an integer index for the GPU). Got: {device}"
+        )


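Usage sketch for `validate_device_id()`, assuming it is imported from this module:

```python
validate_device_id("cpu", log_prefix="example")       # OK
validate_device_id("cuda:1", log_prefix="example")    # OK
validate_device_id("cuda:abc", log_prefix="example")  # raises EnvironmentError
```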
 def is_device_compatible(device):
-    '''
+    """
     Returns True/False, and prints any compatibility errors
-    '''
+    """
     # static variable "history".
-    is_device_compatible.history = getattr(is_device_compatible, 'history', {})
+    is_device_compatible.history = getattr(is_device_compatible, "history", {})
     try:
-        validate_device_id(device, log_prefix='is_device_compatible')
+        validate_device_id(device, log_prefix="is_device_compatible")
     except:
         log.error(str(e))
         return False

-    if device == 'cpu': return True
+    if device == "cpu":
+        return True
     # Memory check
     try:
         _, mem_total = torch.cuda.mem_get_info(device)
         mem_total /= float(10**9)
         if mem_total < 3.0:
             if is_device_compatible.history.get(device) == None:
-                log.warn(f'GPU {device} with less than 3 GB of VRAM is not compatible with Stable Diffusion')
+                log.warn(f"GPU {device} with less than 3 GB of VRAM is not compatible with Stable Diffusion")
             is_device_compatible.history[device] = 1
             return False
     except RuntimeError as e:
@@ -179,13 +211,15 @@ def is_device_compatible(device):
         return False
     return True


 def get_processor_name():
     try:
         import platform, subprocess

         if platform.system() == "Windows":
             return platform.processor()
         elif platform.system() == "Darwin":
-            os.environ['PATH'] = os.environ['PATH'] + os.pathsep + '/usr/sbin'
+            os.environ["PATH"] = os.environ["PATH"] + os.pathsep + "/usr/sbin"
             command = "sysctl -n machdep.cpu.brand_string"
             return subprocess.check_output(command).strip()
         elif platform.system() == "Linux":
@@ -8,30 +8,32 @@ from sdkit import Context
 from sdkit.models import load_model, unload_model, get_model_info_from_db, scan_model
 from sdkit.utils import hash_file_quick

-KNOWN_MODEL_TYPES = ['stable-diffusion', 'vae', 'hypernetwork', 'gfpgan', 'realesrgan']
+KNOWN_MODEL_TYPES = ["stable-diffusion", "vae", "hypernetwork", "gfpgan", "realesrgan"]
 MODEL_EXTENSIONS = {
-    'stable-diffusion': ['.ckpt', '.safetensors'],
-    'vae': ['.vae.pt', '.ckpt', '.safetensors'],
-    'hypernetwork': ['.pt', '.safetensors'],
-    'gfpgan': ['.pth'],
-    'realesrgan': ['.pth'],
+    "stable-diffusion": [".ckpt", ".safetensors"],
+    "vae": [".vae.pt", ".ckpt", ".safetensors"],
+    "hypernetwork": [".pt", ".safetensors"],
+    "gfpgan": [".pth"],
+    "realesrgan": [".pth"],
 }
 DEFAULT_MODELS = {
-    'stable-diffusion': [ # needed to support the legacy installations
-        'custom-model', # only one custom model file was supported initially, creatively named 'custom-model'
-        'sd-v1-4', # Default fallback.
+    "stable-diffusion": [ # needed to support the legacy installations
+        "custom-model", # only one custom model file was supported initially, creatively named 'custom-model'
+        "sd-v1-4", # Default fallback.
     ],
-    'gfpgan': ['GFPGANv1.3'],
-    'realesrgan': ['RealESRGAN_x4plus'],
+    "gfpgan": ["GFPGANv1.3"],
+    "realesrgan": ["RealESRGAN_x4plus"],
 }
-MODELS_TO_LOAD_ON_START = ['stable-diffusion', 'vae', 'hypernetwork']
+MODELS_TO_LOAD_ON_START = ["stable-diffusion", "vae", "hypernetwork"]

 known_models = {}


 def init():
     make_model_folders()
     getModels() # run this once, to cache the picklescan results


 def load_default_models(context: Context):
     set_vram_optimizations(context)

@@ -41,16 +43,17 @@ def load_default_models(context: Context):
         try:
             load_model(context, model_type)
         except Exception as e:
-            log.error(f'[red]Error while loading {model_type} model: {context.model_paths[model_type]}[/red]')
-            log.error(f'[red]Error: {e}[/red]')
-            log.error(f'[red]Consider removing the model from the model folder.[red]')
+            log.error(f"[red]Error while loading {model_type} model: {context.model_paths[model_type]}[/red]")
+            log.error(f"[red]Error: {e}[/red]")
+            log.error(f"[red]Consider removing the model from the model folder.[red]")


 def unload_all(context: Context):
     for model_type in KNOWN_MODEL_TYPES:
         unload_model(context, model_type)

-def resolve_model_to_use(model_name:str=None, model_type:str=None):
+
+def resolve_model_to_use(model_name: str = None, model_type: str = None):
     model_extensions = MODEL_EXTENSIONS.get(model_type, [])
     default_models = DEFAULT_MODELS.get(model_type, [])
     config = app.getConfig()
@@ -58,8 +61,8 @@ def resolve_model_to_use(model_name:str=None, model_type:str=None):
     model_dirs = [os.path.join(app.MODELS_DIR, model_type), app.SD_DIR]
     if not model_name: # When None try user configured model.
         # config = getConfig()
-        if 'model' in config and model_type in config['model']:
-            model_name = config['model'][model_type]
+        if "model" in config and model_type in config["model"]:
+            model_name = config["model"][model_type]

     if model_name:
         # Check models directory
@@ -84,23 +87,30 @@ def resolve_model_to_use(model_name:str=None, model_type:str=None):
     for model_extension in model_extensions:
         if os.path.exists(default_model_path + model_extension):
             if model_name is not None:
-                log.warn(f'Could not find the configured custom model {model_name}{model_extension}. Using the default one: {default_model_path}{model_extension}')
+                log.warn(
+                    f"Could not find the configured custom model {model_name}{model_extension}. Using the default one: {default_model_path}{model_extension}"
+                )
             return default_model_path + model_extension

     return None


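A hypothetical call, assuming the folder layout created by `make_model_folders()` (the model name here is made up):

```python
path = resolve_model_to_use("my-model", model_type="stable-diffusion")
# -> e.g. ".../models/stable-diffusion/my-model.ckpt" (or .safetensors) if the file exists,
#    otherwise the first DEFAULT_MODELS entry found on disk, otherwise None.
```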
 def reload_models_if_necessary(context: Context, task_data: TaskData):
     model_paths_in_req = {
-        'stable-diffusion': task_data.use_stable_diffusion_model,
-        'vae': task_data.use_vae_model,
-        'hypernetwork': task_data.use_hypernetwork_model,
-        'gfpgan': task_data.use_face_correction,
-        'realesrgan': task_data.use_upscale,
+        "stable-diffusion": task_data.use_stable_diffusion_model,
+        "vae": task_data.use_vae_model,
+        "hypernetwork": task_data.use_hypernetwork_model,
+        "gfpgan": task_data.use_face_correction,
+        "realesrgan": task_data.use_upscale,
     }
-    models_to_reload = {model_type: path for model_type, path in model_paths_in_req.items() if context.model_paths.get(model_type) != path}
+    models_to_reload = {
+        model_type: path
+        for model_type, path in model_paths_in_req.items()
+        if context.model_paths.get(model_type) != path
+    }

     if set_vram_optimizations(context): # reload SD
-        models_to_reload['stable-diffusion'] = model_paths_in_req['stable-diffusion']
+        models_to_reload["stable-diffusion"] = model_paths_in_req["stable-diffusion"]

     for model_type, model_path_in_req in models_to_reload.items():
         context.model_paths[model_type] = model_path_in_req
@@ -108,17 +118,23 @@ def reload_models_if_necessary(context: Context, task_data: TaskData):
         action_fn = unload_model if context.model_paths[model_type] is None else load_model
         action_fn(context, model_type, scan_model=False) # we've scanned them already

-def resolve_model_paths(task_data: TaskData):
-    task_data.use_stable_diffusion_model = resolve_model_to_use(task_data.use_stable_diffusion_model, model_type='stable-diffusion')
-    task_data.use_vae_model = resolve_model_to_use(task_data.use_vae_model, model_type='vae')
-    task_data.use_hypernetwork_model = resolve_model_to_use(task_data.use_hypernetwork_model, model_type='hypernetwork')
-
-    if task_data.use_face_correction: task_data.use_face_correction = resolve_model_to_use(task_data.use_face_correction, 'gfpgan')
-    if task_data.use_upscale: task_data.use_upscale = resolve_model_to_use(task_data.use_upscale, 'realesrgan')
+
+def resolve_model_paths(task_data: TaskData):
+    task_data.use_stable_diffusion_model = resolve_model_to_use(
+        task_data.use_stable_diffusion_model, model_type="stable-diffusion"
+    )
+    task_data.use_vae_model = resolve_model_to_use(task_data.use_vae_model, model_type="vae")
+    task_data.use_hypernetwork_model = resolve_model_to_use(task_data.use_hypernetwork_model, model_type="hypernetwork")
+
+    if task_data.use_face_correction:
+        task_data.use_face_correction = resolve_model_to_use(task_data.use_face_correction, "gfpgan")
+    if task_data.use_upscale:
+        task_data.use_upscale = resolve_model_to_use(task_data.use_upscale, "realesrgan")


 def set_vram_optimizations(context: Context):
     config = app.getConfig()
-    vram_usage_level = config.get('vram_usage_level', 'balanced')
+    vram_usage_level = config.get("vram_usage_level", "balanced")

     if vram_usage_level != context.vram_usage_level:
         context.vram_usage_level = vram_usage_level
@@ -126,42 +142,51 @@ def set_vram_optimizations(context: Context):

     return False


 def make_model_folders():
     for model_type in KNOWN_MODEL_TYPES:
         model_dir_path = os.path.join(app.MODELS_DIR, model_type)

         os.makedirs(model_dir_path, exist_ok=True)

-        help_file_name = f'Place your {model_type} model files here.txt'
+        help_file_name = f"Place your {model_type} model files here.txt"
         help_file_contents = f'Supported extensions: {" or ".join(MODEL_EXTENSIONS.get(model_type))}'

-        with open(os.path.join(model_dir_path, help_file_name), 'w', encoding='utf-8') as f:
+        with open(os.path.join(model_dir_path, help_file_name), "w", encoding="utf-8") as f:
             f.write(help_file_contents)


 def is_malicious_model(file_path):
     try:
         scan_result = scan_model(file_path)
         if scan_result.issues_count > 0 or scan_result.infected_files > 0:
-            log.warn(":warning: [bold red]Scan %s: %d scanned, %d issue, %d infected.[/bold red]" % (file_path, scan_result.scanned_files, scan_result.issues_count, scan_result.infected_files))
+            log.warn(
+                ":warning: [bold red]Scan %s: %d scanned, %d issue, %d infected.[/bold red]"
+                % (file_path, scan_result.scanned_files, scan_result.issues_count, scan_result.infected_files)
+            )
             return True
         else:
-            log.debug("Scan %s: [green]%d scanned, %d issue, %d infected.[/green]" % (file_path, scan_result.scanned_files, scan_result.issues_count, scan_result.infected_files))
+            log.debug(
+                "Scan %s: [green]%d scanned, %d issue, %d infected.[/green]"
+                % (file_path, scan_result.scanned_files, scan_result.issues_count, scan_result.infected_files)
+            )
             return False
     except Exception as e:
-        log.error(f'error while scanning: {file_path}, error: {e}')
+        log.error(f"error while scanning: {file_path}, error: {e}")
         return False


 def getModels():
     models = {
-        'active': {
-            'stable-diffusion': 'sd-v1-4',
-            'vae': '',
-            'hypernetwork': '',
+        "active": {
+            "stable-diffusion": "sd-v1-4",
+            "vae": "",
+            "hypernetwork": "",
         },
-        'options': {
-            'stable-diffusion': ['sd-v1-4'],
-            'vae': [],
-            'hypernetwork': [],
+        "options": {
+            "stable-diffusion": ["sd-v1-4"],
+            "vae": [],
+            "hypernetwork": [],
         },
     }

@@ -171,13 +196,16 @@ def getModels():
         "Raised when picklescan reports a problem with a model"
         pass

-    def scan_directory(directory, suffixes, directoriesFirst:bool=True):
+    def scan_directory(directory, suffixes, directoriesFirst: bool = True):
         nonlocal models_scanned
         tree = []
-        for entry in sorted(os.scandir(directory), key = lambda entry: (entry.is_file() == directoriesFirst, entry.name.lower())):
+        for entry in sorted(
+            os.scandir(directory), key=lambda entry: (entry.is_file() == directoriesFirst, entry.name.lower())
+        ):
             if entry.is_file():
                 matching_suffix = list(filter(lambda s: entry.name.endswith(s), suffixes))
-                if len(matching_suffix) == 0: continue
+                if len(matching_suffix) == 0:
+                    continue
                 matching_suffix = matching_suffix[0]

                 mtime = entry.stat().st_mtime
@@ -187,12 +215,12 @@ def getModels():
                 if is_malicious_model(entry.path):
                     raise MaliciousModelException(entry.path)
                 known_models[entry.path] = mtime
-                tree.append(entry.name[:-len(matching_suffix)])
+                tree.append(entry.name[: -len(matching_suffix)])
             elif entry.is_dir():
-                scan=scan_directory(entry.path, suffixes, directoriesFirst=False)
+                scan = scan_directory(entry.path, suffixes, directoriesFirst=False)

                 if len(scan) != 0:
-                    tree.append( (entry.name, scan ) )
+                    tree.append((entry.name, scan))
         return tree

     def listModels(model_type):
@@ -204,21 +232,22 @@ def getModels():
             os.makedirs(models_dir)

         try:
-            models['options'][model_type] = scan_directory(models_dir, model_extensions)
+            models["options"][model_type] = scan_directory(models_dir, model_extensions)
         except MaliciousModelException as e:
-            models['scan-error'] = e
+            models["scan-error"] = e

     # custom models
-    listModels(model_type='stable-diffusion')
-    listModels(model_type='vae')
-    listModels(model_type='hypernetwork')
-    listModels(model_type='gfpgan')
+    listModels(model_type="stable-diffusion")
+    listModels(model_type="vae")
+    listModels(model_type="hypernetwork")
+    listModels(model_type="gfpgan")

-    if models_scanned > 0: log.info(f'[green]Scanned {models_scanned} models. Nothing infected[/]')
+    if models_scanned > 0:
+        log.info(f"[green]Scanned {models_scanned} models. Nothing infected[/]")

     # legacy
-    custom_weight_path = os.path.join(app.SD_DIR, 'custom-model.ckpt')
+    custom_weight_path = os.path.join(app.SD_DIR, "custom-model.ckpt")
     if os.path.exists(custom_weight_path):
-        models['options']['stable-diffusion'].append('custom-model')
+        models["options"]["stable-diffusion"].append("custom-model")

     return models
@@ -13,21 +13,25 @@ from sdkit.filter import apply_filters
 from sdkit.utils import img_to_buffer, img_to_base64_str, latent_samples_to_images, gc

 context = Context() # thread-local
-'''
+"""
 runtime data (bound locally to this thread), for e.g. device, references to loaded models, optimization flags etc
-'''
+"""


 def init(device):
-    '''
+    """
     Initializes the fields that will be bound to this runtime's context, and sets the current torch device
-    '''
+    """
     context.stop_processing = False
     context.temp_images = {}
     context.partial_x_samples = None

     device_manager.device_init(context, device)

-def make_images(req: GenerateImageRequest, task_data: TaskData, data_queue: queue.Queue, task_temp_images: list, step_callback):
+
+def make_images(
+    req: GenerateImageRequest, task_data: TaskData, data_queue: queue.Queue, task_temp_images: list, step_callback
+):
     context.stop_processing = False
     print_task_info(req, task_data)
@@ -36,18 +40,24 @@ def make_images(req: GenerateImageRequest, task_data: TaskData, data_queue: queu
     res = Response(req, task_data, images=construct_response(images, seeds, task_data, base_seed=req.seed))
     res = res.json()
     data_queue.put(json.dumps(res))
-    log.info('Task completed')
+    log.info("Task completed")

     return res

-def print_task_info(req: GenerateImageRequest, task_data: TaskData):
-    req_str = pprint.pformat(get_printable_request(req)).replace("[","\[")
-    task_str = pprint.pformat(task_data.dict()).replace("[","\[")
-    log.info(f'request: {req_str}')
-    log.info(f'task data: {task_str}')
-
-def make_images_internal(req: GenerateImageRequest, task_data: TaskData, data_queue: queue.Queue, task_temp_images: list, step_callback):
-    images, user_stopped = generate_images_internal(req, task_data, data_queue, task_temp_images, step_callback, task_data.stream_image_progress)
+
+def print_task_info(req: GenerateImageRequest, task_data: TaskData):
+    req_str = pprint.pformat(get_printable_request(req)).replace("[", "\[")
+    task_str = pprint.pformat(task_data.dict()).replace("[", "\[")
+    log.info(f"request: {req_str}")
+    log.info(f"task data: {task_str}")
+
+
+def make_images_internal(
+    req: GenerateImageRequest, task_data: TaskData, data_queue: queue.Queue, task_temp_images: list, step_callback
+):
+    images, user_stopped = generate_images_internal(
+        req, task_data, data_queue, task_temp_images, step_callback, task_data.stream_image_progress
+    )
     filtered_images = filter_images(task_data, images, user_stopped)

     if task_data.save_to_disk_path is not None:
@@ -59,13 +69,22 @@ def make_images_internal(req: GenerateImageRequest, task_data: TaskData, data_qu
     else:
         return images + filtered_images, seeds + seeds

-def generate_images_internal(req: GenerateImageRequest, task_data: TaskData, data_queue: queue.Queue, task_temp_images: list, step_callback, stream_image_progress: bool):
+
+def generate_images_internal(
+    req: GenerateImageRequest,
+    task_data: TaskData,
+    data_queue: queue.Queue,
+    task_temp_images: list,
+    step_callback,
+    stream_image_progress: bool,
+):
     context.temp_images.clear()

     callback = make_step_callback(req, task_data, data_queue, task_temp_images, step_callback, stream_image_progress)

     try:
-        if req.init_image is not None: req.sampler_name = 'ddim'
+        if req.init_image is not None:
+            req.sampler_name = "ddim"

         images = generate_images(context, callback=callback, **req.dict())
         user_stopped = False
@@ -75,31 +94,44 @@ def generate_images_internal(req: GenerateImageRequest, task_data: TaskData, dat
         if context.partial_x_samples is not None:
             images = latent_samples_to_images(context, context.partial_x_samples)
     finally:
-        if hasattr(context, 'partial_x_samples') and context.partial_x_samples is not None:
+        if hasattr(context, "partial_x_samples") and context.partial_x_samples is not None:
             del context.partial_x_samples
             context.partial_x_samples = None

     return images, user_stopped


 def filter_images(task_data: TaskData, images: list, user_stopped):
     if user_stopped or (task_data.use_face_correction is None and task_data.use_upscale is None):
         return images

     filters_to_apply = []
-    if task_data.use_face_correction and 'gfpgan' in task_data.use_face_correction.lower(): filters_to_apply.append('gfpgan')
-    if task_data.use_upscale and 'realesrgan' in task_data.use_upscale.lower(): filters_to_apply.append('realesrgan')
+    if task_data.use_face_correction and "gfpgan" in task_data.use_face_correction.lower():
+        filters_to_apply.append("gfpgan")
+    if task_data.use_upscale and "realesrgan" in task_data.use_upscale.lower():
+        filters_to_apply.append("realesrgan")

     return apply_filters(context, filters_to_apply, images, scale=task_data.upscale_amount)


 def construct_response(images: list, seeds: list, task_data: TaskData, base_seed: int):
     return [
         ResponseImage(
             data=img_to_base64_str(img, task_data.output_format, task_data.output_quality),
             seed=seed,
-        ) for img, seed in zip(images, seeds)
+        )
+        for img, seed in zip(images, seeds)
     ]

-def make_step_callback(req: GenerateImageRequest, task_data: TaskData, data_queue: queue.Queue, task_temp_images: list, step_callback, stream_image_progress: bool):
+
+def make_step_callback(
+    req: GenerateImageRequest,
+    task_data: TaskData,
+    data_queue: queue.Queue,
+    task_temp_images: list,
+    step_callback,
+    stream_image_progress: bool,
+):
     n_steps = req.num_inference_steps if req.init_image is None else int(req.num_inference_steps * req.prompt_strength)
     last_callback_time = -1

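The `n_steps` formula above, traced with made-up request values:

```python
num_inference_steps = 50
prompt_strength = 0.8

# txt2img (no init image): all 50 steps run.
# img2img: only a fraction of the schedule is denoised:
n_steps = int(num_inference_steps * prompt_strength)  # 40
```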
@@ -107,11 +139,11 @@ def make_step_callback(req: GenerateImageRequest, task_data: TaskData, data_queu
         partial_images = []
         images = latent_samples_to_images(context, x_samples)
         for i, img in enumerate(images):
-            buf = img_to_buffer(img, output_format='JPEG')
+            buf = img_to_buffer(img, output_format="JPEG")

             context.temp_images[f"{task_data.request_id}/{i}"] = buf
             task_temp_images[i] = buf
-            partial_images.append({'path': f"/image/tmp/{task_data.request_id}/{i}"})
+            partial_images.append({"path": f"/image/tmp/{task_data.request_id}/{i}"})
         del images
         return partial_images

@@ -125,7 +157,7 @@ def make_step_callback(req: GenerateImageRequest, task_data: TaskData, data_queu
             progress = {"step": i, "step_time": step_time, "total_steps": n_steps}

             if stream_image_progress and i % 5 == 0:
-                progress['output'] = update_temp_img(x_samples, task_temp_images)
+                progress["output"] = update_temp_img(x_samples, task_temp_images)

             data_queue.put(json.dumps(progress))
@@ -16,21 +16,25 @@ from easydiffusion import app, model_manager, task_manager
 from easydiffusion.types import TaskData, GenerateImageRequest, MergeRequest
 from easydiffusion.utils import log

-log.info(f'started in {app.SD_DIR}')
-log.info(f'started at {datetime.datetime.now():%x %X}')
+log.info(f"started in {app.SD_DIR}")
+log.info(f"started at {datetime.datetime.now():%x %X}")

 server_api = FastAPI()

-NOCACHE_HEADERS={"Cache-Control": "no-cache, no-store, must-revalidate", "Pragma": "no-cache", "Expires": "0"}
+NOCACHE_HEADERS = {"Cache-Control": "no-cache, no-store, must-revalidate", "Pragma": "no-cache", "Expires": "0"}


 class NoCacheStaticFiles(StaticFiles):
     def is_not_modified(self, response_headers, request_headers) -> bool:
-        if 'content-type' in response_headers and ('javascript' in response_headers['content-type'] or 'css' in response_headers['content-type']):
+        if "content-type" in response_headers and (
+            "javascript" in response_headers["content-type"] or "css" in response_headers["content-type"]
+        ):
             response_headers.update(NOCACHE_HEADERS)
             return False

         return super().is_not_modified(response_headers, request_headers)


 class SetAppConfigRequest(BaseModel):
     update_branch: str = None
     render_devices: Union[List[str], List[int], str, int] = None
@@ -39,130 +43,142 @@ class SetAppConfigRequest(BaseModel):
     listen_to_network: bool = None
     listen_port: int = None


 def init():
-    server_api.mount('/media', NoCacheStaticFiles(directory=os.path.join(app.SD_UI_DIR, 'media')), name="media")
+    server_api.mount("/media", NoCacheStaticFiles(directory=os.path.join(app.SD_UI_DIR, "media")), name="media")

     for plugins_dir, dir_prefix in app.UI_PLUGINS_SOURCES:
-        server_api.mount(f'/plugins/{dir_prefix}', NoCacheStaticFiles(directory=plugins_dir), name=f"plugins-{dir_prefix}")
+        server_api.mount(
+            f"/plugins/{dir_prefix}", NoCacheStaticFiles(directory=plugins_dir), name=f"plugins-{dir_prefix}"
+        )

-    @server_api.post('/app_config')
-    async def set_app_config(req : SetAppConfigRequest):
+    @server_api.post("/app_config")
+    async def set_app_config(req: SetAppConfigRequest):
         return set_app_config_internal(req)

-    @server_api.get('/get/{key:path}')
-    def read_web_data(key:str=None):
+    @server_api.get("/get/{key:path}")
+    def read_web_data(key: str = None):
         return read_web_data_internal(key)

-    @server_api.get('/ping') # Get server and optionally session status.
-    def ping(session_id:str=None):
+    @server_api.get("/ping") # Get server and optionally session status.
+    def ping(session_id: str = None):
         return ping_internal(session_id)

-    @server_api.post('/render')
+    @server_api.post("/render")
     def render(req: dict):
         return render_internal(req)

-    @server_api.post('/model/merge')
+    @server_api.post("/model/merge")
     def model_merge(req: dict):
         print(req)
         return model_merge_internal(req)

-    @server_api.get('/image/stream/{task_id:int}')
-    def stream(task_id:int):
+    @server_api.get("/image/stream/{task_id:int}")
+    def stream(task_id: int):
         return stream_internal(task_id)

-    @server_api.get('/image/stop')
+    @server_api.get("/image/stop")
     def stop(task: int):
         return stop_internal(task)

-    @server_api.get('/image/tmp/{task_id:int}/{img_id:int}')
+    @server_api.get("/image/tmp/{task_id:int}/{img_id:int}")
     def get_image(task_id: int, img_id: int):
         return get_image_internal(task_id, img_id)

-    @server_api.get('/')
+    @server_api.get("/")
     def read_root():
-        return FileResponse(os.path.join(app.SD_UI_DIR, 'index.html'), headers=NOCACHE_HEADERS)
+        return FileResponse(os.path.join(app.SD_UI_DIR, "index.html"), headers=NOCACHE_HEADERS)

     @server_api.on_event("shutdown")
     def shutdown_event(): # Signal render thread to close on shutdown
-        task_manager.current_state_error = SystemExit('Application shutting down.')
+        task_manager.current_state_error = SystemExit("Application shutting down.")


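A quick way to poke these endpoints once the server is running (hypothetical host and port, matching the defaults):

```python
import requests

base = "http://localhost:9000"
print(requests.get(f"{base}/ping").json())            # server status, e.g. {"status": "Online", ...}
print(requests.get(f"{base}/get/app_config").json())  # the current config
```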
 # API implementations
-def set_app_config_internal(req : SetAppConfigRequest):
+def set_app_config_internal(req: SetAppConfigRequest):
     config = app.getConfig()
     if req.update_branch is not None:
-        config['update_branch'] = req.update_branch
+        config["update_branch"] = req.update_branch
     if req.render_devices is not None:
         update_render_devices_in_config(config, req.render_devices)
     if req.ui_open_browser_on_start is not None:
-        if 'ui' not in config:
-            config['ui'] = {}
-        config['ui']['open_browser_on_start'] = req.ui_open_browser_on_start
+        if "ui" not in config:
+            config["ui"] = {}
+        config["ui"]["open_browser_on_start"] = req.ui_open_browser_on_start
     if req.listen_to_network is not None:
-        if 'net' not in config:
-            config['net'] = {}
-        config['net']['listen_to_network'] = bool(req.listen_to_network)
+        if "net" not in config:
+            config["net"] = {}
+        config["net"]["listen_to_network"] = bool(req.listen_to_network)
     if req.listen_port is not None:
-        if 'net' not in config:
-            config['net'] = {}
-        config['net']['listen_port'] = int(req.listen_port)
+        if "net" not in config:
+            config["net"] = {}
+        config["net"]["listen_port"] = int(req.listen_port)
     try:
         app.setConfig(config)

         if req.render_devices:
             app.update_render_threads()

-        return JSONResponse({'status': 'OK'}, headers=NOCACHE_HEADERS)
+        return JSONResponse({"status": "OK"}, headers=NOCACHE_HEADERS)
     except Exception as e:
         log.error(traceback.format_exc())
         raise HTTPException(status_code=500, detail=str(e))


 def update_render_devices_in_config(config, render_devices):
-    if render_devices not in ('cpu', 'auto') and not render_devices.startswith('cuda:'):
-        raise HTTPException(status_code=400, detail=f'Invalid render device requested: {render_devices}')
+    if render_devices not in ("cpu", "auto") and not render_devices.startswith("cuda:"):
+        raise HTTPException(status_code=400, detail=f"Invalid render device requested: {render_devices}")

-    if render_devices.startswith('cuda:'):
-        render_devices = render_devices.split(',')
+    if render_devices.startswith("cuda:"):
+        render_devices = render_devices.split(",")

-    config['render_devices'] = render_devices
+    config["render_devices"] = render_devices

-def read_web_data_internal(key:str=None):
+
+def read_web_data_internal(key: str = None):
     if not key: # /get without parameters, stable-diffusion easter egg.
         raise HTTPException(status_code=418, detail="StableDiffusion is drawing a teapot!") # HTTP418 I'm a teapot
-    elif key == 'app_config':
+    elif key == "app_config":
         return JSONResponse(app.getConfig(), headers=NOCACHE_HEADERS)
-    elif key == 'system_info':
+    elif key == "system_info":
         config = app.getConfig()

-        output_dir = config.get('force_save_path', os.path.join(os.path.expanduser("~"), app.OUTPUT_DIRNAME))
+        output_dir = config.get("force_save_path", os.path.join(os.path.expanduser("~"), app.OUTPUT_DIRNAME))

         system_info = {
-            'devices': task_manager.get_devices(),
-            'hosts': app.getIPConfig(),
-            'default_output_dir': output_dir,
-            'enforce_output_dir': ('force_save_path' in config),
+            "devices": task_manager.get_devices(),
+            "hosts": app.getIPConfig(),
+            "default_output_dir": output_dir,
+            "enforce_output_dir": ("force_save_path" in config),
         }
-        system_info['devices']['config'] = config.get('render_devices', "auto")
+        system_info["devices"]["config"] = config.get("render_devices", "auto")
         return JSONResponse(system_info, headers=NOCACHE_HEADERS)
-    elif key == 'models':
+    elif key == "models":
         return JSONResponse(model_manager.getModels(), headers=NOCACHE_HEADERS)
-    elif key == 'modifiers': return FileResponse(os.path.join(app.SD_UI_DIR, 'modifiers.json'), headers=NOCACHE_HEADERS)
-    elif key == 'ui_plugins': return JSONResponse(app.getUIPlugins(), headers=NOCACHE_HEADERS)
+    elif key == "modifiers":
+        return FileResponse(os.path.join(app.SD_UI_DIR, "modifiers.json"), headers=NOCACHE_HEADERS)
+    elif key == "ui_plugins":
+        return JSONResponse(app.getUIPlugins(), headers=NOCACHE_HEADERS)
     else:
-        raise HTTPException(status_code=404, detail=f'Request for unknown {key}') # HTTP404 Not Found
+        raise HTTPException(status_code=404, detail=f"Request for unknown {key}") # HTTP404 Not Found

-def ping_internal(session_id:str=None):
+
+def ping_internal(session_id: str = None):
     if task_manager.is_alive() <= 0: # Check that render threads are alive.
-        if task_manager.current_state_error: raise HTTPException(status_code=500, detail=str(task_manager.current_state_error))
-        raise HTTPException(status_code=500, detail='Render thread is dead.')
-    if task_manager.current_state_error and not isinstance(task_manager.current_state_error, StopAsyncIteration): raise HTTPException(status_code=500, detail=str(task_manager.current_state_error))
+        if task_manager.current_state_error:
+            raise HTTPException(status_code=500, detail=str(task_manager.current_state_error))
+        raise HTTPException(status_code=500, detail="Render thread is dead.")
+    if task_manager.current_state_error and not isinstance(task_manager.current_state_error, StopAsyncIteration):
+        raise HTTPException(status_code=500, detail=str(task_manager.current_state_error))
     # Alive
-    response = {'status': str(task_manager.current_state)}
+    response = {"status": str(task_manager.current_state)}
     if session_id:
         session = task_manager.get_cached_session(session_id, update_ttl=True)
-        response['tasks'] = {id(t): t.status for t in session.tasks}
-    response['devices'] = task_manager.get_devices()
+        response["tasks"] = {id(t): t.status for t in session.tasks}
+    response["devices"] = task_manager.get_devices()
     return JSONResponse(response, headers=NOCACHE_HEADERS)


 def render_internal(req: dict):
     try:
         # separate out the request data into rendering and task-specific data
@@ -171,80 +187,99 @@ def render_internal(req: dict):

         # Overwrite user specified save path
         config = app.getConfig()
-        if 'force_save_path' in config:
-            task_data.save_to_disk_path = config['force_save_path']
+        if "force_save_path" in config:
+            task_data.save_to_disk_path = config["force_save_path"]

-        render_req.init_image_mask = req.get('mask') # hack: will rename this in the HTTP API in a future revision
+        render_req.init_image_mask = req.get("mask") # hack: will rename this in the HTTP API in a future revision

-        app.save_to_config(task_data.use_stable_diffusion_model, task_data.use_vae_model, task_data.use_hypernetwork_model, task_data.vram_usage_level)
+        app.save_to_config(
+            task_data.use_stable_diffusion_model,
+            task_data.use_vae_model,
+            task_data.use_hypernetwork_model,
+            task_data.vram_usage_level,
+        )

         # enqueue the task
         new_task = task_manager.render(render_req, task_data)
         response = {
-            'status': str(task_manager.current_state),
-            'queue': len(task_manager.tasks_queue),
-            'stream': f'/image/stream/{id(new_task)}',
-            'task': id(new_task)
+            "status": str(task_manager.current_state),
+            "queue": len(task_manager.tasks_queue),
+            "stream": f"/image/stream/{id(new_task)}",
+            "task": id(new_task),
         }
         return JSONResponse(response, headers=NOCACHE_HEADERS)
     except ChildProcessError as e: # Render thread is dead
-        raise HTTPException(status_code=500, detail=f'Rendering thread has died.') # HTTP500 Internal Server Error
+        raise HTTPException(status_code=500, detail=f"Rendering thread has died.") # HTTP500 Internal Server Error
     except ConnectionRefusedError as e: # Unstarted task pending limit reached, deny queueing too many.
         raise HTTPException(status_code=503, detail=str(e)) # HTTP503 Service Unavailable
     except Exception as e:
         log.error(traceback.format_exc())
         raise HTTPException(status_code=500, detail=str(e))


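For reference, a successful `/render` call returns JSON shaped like the response dict above (the ids are illustrative; `task` is the Python `id()` of the queued task object):

```python
response = {
    "status": "Online",
    "queue": 1,
    "stream": "/image/stream/140234567890",
    "task": 140234567890,
}
```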
 def model_merge_internal(req: dict):
     try:
         from sdkit.train import merge_models
         from easydiffusion.utils.save_utils import filename_regex

         mergeReq: MergeRequest = MergeRequest.parse_obj(req)

-        merge_models(model_manager.resolve_model_to_use(mergeReq.model0,'stable-diffusion'),
-                     model_manager.resolve_model_to_use(mergeReq.model1,'stable-diffusion'),
+        merge_models(
+            model_manager.resolve_model_to_use(mergeReq.model0, "stable-diffusion"),
+            model_manager.resolve_model_to_use(mergeReq.model1, "stable-diffusion"),
             mergeReq.ratio,
-                     os.path.join(app.MODELS_DIR, 'stable-diffusion', filename_regex.sub('_', mergeReq.out_path)),
-                     mergeReq.use_fp16
+            os.path.join(app.MODELS_DIR, "stable-diffusion", filename_regex.sub("_", mergeReq.out_path)),
+            mergeReq.use_fp16,
         )
-        return JSONResponse({'status':'OK'}, headers=NOCACHE_HEADERS)
+        return JSONResponse({"status": "OK"}, headers=NOCACHE_HEADERS)
     except Exception as e:
         log.error(traceback.format_exc())
         raise HTTPException(status_code=500, detail=str(e))

def stream_internal(task_id:int):
|
||||
#TODO Move to WebSockets ??
|
||||
|
||||
def stream_internal(task_id: int):
|
||||
# TODO Move to WebSockets ??
|
||||
task = task_manager.get_cached_task(task_id, update_ttl=True)
|
||||
if not task: raise HTTPException(status_code=404, detail=f'Request {task_id} not found.') # HTTP404 NotFound
|
||||
#if (id(task) != task_id): raise HTTPException(status_code=409, detail=f'Wrong task id received. Expected:{id(task)}, Received:{task_id}') # HTTP409 Conflict
|
||||
if not task:
|
||||
raise HTTPException(status_code=404, detail=f"Request {task_id} not found.") # HTTP404 NotFound
|
||||
# if (id(task) != task_id): raise HTTPException(status_code=409, detail=f'Wrong task id received. Expected:{id(task)}, Received:{task_id}') # HTTP409 Conflict
|
||||
if task.buffer_queue.empty() and not task.lock.locked():
|
||||
if task.response:
|
||||
#log.info(f'Session {session_id} sending cached response')
|
||||
# log.info(f'Session {session_id} sending cached response')
|
||||
return JSONResponse(task.response, headers=NOCACHE_HEADERS)
|
||||
raise HTTPException(status_code=425, detail='Too Early, task not started yet.') # HTTP425 Too Early
|
||||
#log.info(f'Session {session_id} opened live render stream {id(task.buffer_queue)}')
|
||||
return StreamingResponse(task.read_buffer_generator(), media_type='application/json')
|
||||
raise HTTPException(status_code=425, detail="Too Early, task not started yet.") # HTTP425 Too Early
|
||||
# log.info(f'Session {session_id} opened live render stream {id(task.buffer_queue)}')
|
||||
return StreamingResponse(task.read_buffer_generator(), media_type="application/json")


def stop_internal(task: int):
    if not task:
        if task_manager.current_state == task_manager.ServerStates.Online or task_manager.current_state == task_manager.ServerStates.Unavailable:
            raise HTTPException(status_code=409, detail='Not currently running any tasks.') # HTTP409 Conflict
        task_manager.current_state_error = StopAsyncIteration('')
        return {'OK'}
        if (
            task_manager.current_state == task_manager.ServerStates.Online
            or task_manager.current_state == task_manager.ServerStates.Unavailable
        ):
            raise HTTPException(status_code=409, detail="Not currently running any tasks.") # HTTP409 Conflict
        task_manager.current_state_error = StopAsyncIteration("")
        return {"OK"}
    task_id = task
    task = task_manager.get_cached_task(task_id, update_ttl=False)
    if not task: raise HTTPException(status_code=404, detail=f'Task {task_id} was not found.') # HTTP404 Not Found
    if isinstance(task.error, StopAsyncIteration): raise HTTPException(status_code=409, detail=f'Task {task_id} is already stopped.') # HTTP409 Conflict
    task.error = StopAsyncIteration(f'Task {task_id} stop requested.')
    return {'OK'}
    if not task:
        raise HTTPException(status_code=404, detail=f"Task {task_id} was not found.") # HTTP404 Not Found
    if isinstance(task.error, StopAsyncIteration):
        raise HTTPException(status_code=409, detail=f"Task {task_id} is already stopped.") # HTTP409 Conflict
    task.error = StopAsyncIteration(f"Task {task_id} stop requested.")
    return {"OK"}


def get_image_internal(task_id: int, img_id: int):
    task = task_manager.get_cached_task(task_id, update_ttl=True)
    if not task: raise HTTPException(status_code=410, detail=f'Task {task_id} could not be found.') # HTTP404 NotFound
    if not task.temp_images[img_id]: raise HTTPException(status_code=425, detail='Too Early, task data is not available yet.') # HTTP425 Too Early
    if not task:
        raise HTTPException(status_code=410, detail=f"Task {task_id} could not be found.") # HTTP404 NotFound
    if not task.temp_images[img_id]:
        raise HTTPException(status_code=425, detail="Too Early, task data is not available yet.") # HTTP425 Too Early
    try:
        img_data = task.temp_images[img_id]
        img_data.seek(0)
        return StreamingResponse(img_data, media_type='image/jpeg')
        return StreamingResponse(img_data, media_type="image/jpeg")
    except KeyError as e:
        raise HTTPException(status_code=500, detail=str(e))

@ -19,71 +19,98 @@ from easydiffusion.utils import log

from sdkit.utils import gc

THREAD_NAME_PREFIX = ''
ERR_LOCK_FAILED = ' failed to acquire lock within timeout.'
THREAD_NAME_PREFIX = ""
ERR_LOCK_FAILED = " failed to acquire lock within timeout."
LOCK_TIMEOUT = 15 # Maximum locking time in seconds before failing a task.
# It's better to get an exception than a deadlock... ALWAYS use timeout in critical paths.

DEVICE_START_TIMEOUT = 60 # seconds - Maximum time to wait for a render device to init.


class SymbolClass(type): # Print nicely formatted Symbol names.
    def __repr__(self): return self.__qualname__
    def __str__(self): return self.__name__
class Symbol(metaclass=SymbolClass): pass
    def __repr__(self):
        return self.__qualname__

    def __str__(self):
        return self.__name__


class Symbol(metaclass=SymbolClass):
    pass


class ServerStates:
    class Init(Symbol): pass
    class LoadingModel(Symbol): pass
    class Online(Symbol): pass
    class Rendering(Symbol): pass
    class Unavailable(Symbol): pass
    class Init(Symbol):
        pass

class RenderTask(): # Task with output queue and completion lock.
    class LoadingModel(Symbol):
        pass

    class Online(Symbol):
        pass

    class Rendering(Symbol):
        pass

    class Unavailable(Symbol):
        pass

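A self-contained illustration (not part of the commit) of what the SymbolClass metaclass buys: server-state classes print as bare names instead of the default "<class '...'>" form.

    # Minimal sketch: a metaclass that controls how the *class object* prints.
    class SymbolClass(type):
        def __repr__(self):
            return self.__qualname__

        def __str__(self):
            return self.__name__

    class Symbol(metaclass=SymbolClass):
        pass

    class Online(Symbol):
        pass

    print(repr(Online))  # -> "Online" instead of "<class '__main__.Online'>"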
class RenderTask: # Task with output queue and completion lock.
    def __init__(self, req: GenerateImageRequest, task_data: TaskData):
        task_data.request_id = id(self)
        self.render_request: GenerateImageRequest = req # Initial Request
        self.task_data: TaskData = task_data
        self.response: Any = None # Copy of the last response
        self.render_device = None # Select the task affinity. (Not used to change active devices).
        self.temp_images:list = [None] * req.num_outputs * (1 if task_data.show_only_filtered_image else 2)
        self.temp_images: list = [None] * req.num_outputs * (1 if task_data.show_only_filtered_image else 2)
        self.error: Exception = None
        self.lock: threading.Lock = threading.Lock() # Locks at task start and unlocks when task is completed
        self.buffer_queue: queue.Queue = queue.Queue() # Queue of JSON string segments

    async def read_buffer_generator(self):
        try:
            while not self.buffer_queue.empty():
                res = self.buffer_queue.get(block=False)
                self.buffer_queue.task_done()
                yield res
        except queue.Empty as e: yield
        except queue.Empty as e:
            yield
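
A self-contained sketch (not part of the commit) of the non-blocking drain pattern used by read_buffer_generator: pull everything currently queued, acknowledge each item, and yield it to the HTTP stream.

    import asyncio
    import queue

    async def drain(q: queue.Queue):
        # Yield whatever is already buffered, without ever blocking the event loop.
        try:
            while not q.empty():
                res = q.get(block=False)
                q.task_done()
                yield res
        except queue.Empty:
            yield

    q = queue.Queue()
    q.put('{"step": 1}')
    q.put('{"step": 2}')

    async def main():
        async for segment in drain(q):
            print(segment)

    asyncio.run(main())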

    @property
    def status(self):
        if self.lock.locked():
            return 'running'
            return "running"
        if isinstance(self.error, StopAsyncIteration):
            return 'stopped'
            return "stopped"
        if self.error:
            return 'error'
            return "error"
        if not self.buffer_queue.empty():
            return 'buffer'
            return "buffer"
        if self.response:
            return 'completed'
        return 'pending'
            return "completed"
        return "pending"

    @property
    def is_pending(self):
        return bool(not self.response and not self.error)


# Temporary cache that allows querying task results for a short time after they complete.
class DataCache():
class DataCache:
    def __init__(self):
        self._base = dict()
        self._lock: threading.Lock = threading.Lock()

    def _get_ttl_time(self, ttl: int) -> int:
        return int(time.time()) + ttl

    def _is_expired(self, timestamp: int) -> bool:
        return int(time.time()) >= timestamp

    def clean(self) -> None:
        if not self._lock.acquire(blocking=True, timeout=LOCK_TIMEOUT): raise Exception('DataCache.clean' + ERR_LOCK_FAILED)
        if not self._lock.acquire(blocking=True, timeout=LOCK_TIMEOUT):
            raise Exception("DataCache.clean" + ERR_LOCK_FAILED)
        try:
            # Create a list of expired keys to delete
            to_delete = []
@ -95,20 +122,26 @@ class DataCache():
            for key in to_delete:
                (_, val) = self._base[key]
                if isinstance(val, RenderTask):
                    log.debug(f'RenderTask {key} expired. Data removed.')
                    log.debug(f"RenderTask {key} expired. Data removed.")
                elif isinstance(val, SessionState):
                    log.debug(f'Session {key} expired. Data removed.')
                    log.debug(f"Session {key} expired. Data removed.")
                else:
                    log.debug(f'Key {key} expired. Data removed.')
                    log.debug(f"Key {key} expired. Data removed.")
                del self._base[key]
        finally:
            self._lock.release()

    def clear(self) -> None:
        if not self._lock.acquire(blocking=True, timeout=LOCK_TIMEOUT): raise Exception('DataCache.clear' + ERR_LOCK_FAILED)
        try: self._base.clear()
        finally: self._lock.release()
        if not self._lock.acquire(blocking=True, timeout=LOCK_TIMEOUT):
            raise Exception("DataCache.clear" + ERR_LOCK_FAILED)
        try:
            self._base.clear()
        finally:
            self._lock.release()

    def delete(self, key: Hashable) -> bool:
        if not self._lock.acquire(blocking=True, timeout=LOCK_TIMEOUT): raise Exception('DataCache.delete' + ERR_LOCK_FAILED)
        if not self._lock.acquire(blocking=True, timeout=LOCK_TIMEOUT):
            raise Exception("DataCache.delete" + ERR_LOCK_FAILED)
        try:
            if key not in self._base:
                return False
@ -116,8 +149,10 @@ class DataCache():
            return True
        finally:
            self._lock.release()

    def keep(self, key: Hashable, ttl: int) -> bool:
        if not self._lock.acquire(blocking=True, timeout=LOCK_TIMEOUT): raise Exception('DataCache.keep' + ERR_LOCK_FAILED)
        if not self._lock.acquire(blocking=True, timeout=LOCK_TIMEOUT):
            raise Exception("DataCache.keep" + ERR_LOCK_FAILED)
        try:
            if key in self._base:
                _, value = self._base.get(key)
@ -126,12 +161,12 @@ class DataCache():
            return False
        finally:
            self._lock.release()

    def put(self, key: Hashable, value: Any, ttl: int) -> bool:
        if not self._lock.acquire(blocking=True, timeout=LOCK_TIMEOUT): raise Exception('DataCache.put' + ERR_LOCK_FAILED)
        if not self._lock.acquire(blocking=True, timeout=LOCK_TIMEOUT):
            raise Exception("DataCache.put" + ERR_LOCK_FAILED)
        try:
            self._base[key] = (
                self._get_ttl_time(ttl), value
            )
            self._base[key] = (self._get_ttl_time(ttl), value)
        except Exception as e:
            log.error(traceback.format_exc())
            return False
@ -139,35 +174,41 @@ class DataCache():
            return True
        finally:
            self._lock.release()

    def tryGet(self, key: Hashable) -> Any:
        if not self._lock.acquire(blocking=True, timeout=LOCK_TIMEOUT): raise Exception('DataCache.tryGet' + ERR_LOCK_FAILED)
        if not self._lock.acquire(blocking=True, timeout=LOCK_TIMEOUT):
            raise Exception("DataCache.tryGet" + ERR_LOCK_FAILED)
        try:
            ttl, value = self._base.get(key, (None, None))
            if ttl is not None and self._is_expired(ttl):
                log.debug(f'Session {key} expired. Discarding data.')
                log.debug(f"Session {key} expired. Discarding data.")
                del self._base[key]
                return None
            return value
        finally:
            self._lock.release()
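
A self-contained sketch (not part of the commit) of the TTL bookkeeping this cache implements: each entry stores an (expiry_timestamp, value) pair, and reads discard entries whose expiry has passed.

    import time

    cache = {}

    def put(key, value, ttl):
        cache[key] = (int(time.time()) + ttl, value)

    def try_get(key):
        expiry, value = cache.get(key, (None, None))
        if expiry is not None and int(time.time()) >= expiry:
            del cache[key]  # expired entries are removed on read
            return None
        return value

    put("task-1", {"status": "pending"}, ttl=2)
    print(try_get("task-1"))  # -> {'status': 'pending'}
    time.sleep(3)
    print(try_get("task-1"))  # -> None (expired and discarded)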


manager_lock = threading.RLock()
render_threads = []
current_state = ServerStates.Init
current_state_error:Exception = None
current_state_error: Exception = None
tasks_queue = []
session_cache = DataCache()
task_cache = DataCache()
weak_thread_data = weakref.WeakKeyDictionary()
idle_event: threading.Event = threading.Event()

class SessionState():

class SessionState:
    def __init__(self, id: str):
        self._id = id
        self._tasks_ids = []

    @property
    def id(self):
        return self._id

    @property
    def tasks(self):
        tasks = []
@ -176,6 +217,7 @@ class SessionState():
            if task:
                tasks.append(task)
        return tasks

    def put(self, task, ttl=TASK_TTL):
        task_id = id(task)
        self._tasks_ids.append(task_id)
@ -185,10 +227,12 @@ class SessionState():
            self._tasks_ids.pop(0)
        return True
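
A self-contained sketch (not part of the commit) of the bookkeeping in SessionState.put(): task ids are appended, and once the list exceeds a cap the oldest id is dropped. The actual limit sits in the elided hunk; the value below is illustrative.

    tasks_ids = []
    MAX_TASKS = 2  # illustrative cap; the real value lives in the elided hunk

    for task_id in (101, 102, 103):
        tasks_ids.append(task_id)
        if len(tasks_ids) > MAX_TASKS:
            tasks_ids.pop(0)  # forget the oldest task id

    print(tasks_ids)  # -> [102, 103]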


def thread_get_next_task():
    from easydiffusion import renderer

    if not manager_lock.acquire(blocking=True, timeout=LOCK_TIMEOUT):
        log.warn(f'Render thread on device: {renderer.context.device} failed to acquire manager lock.')
        log.warn(f"Render thread on device: {renderer.context.device} failed to acquire manager lock.")
        return None
    if len(tasks_queue) <= 0:
        manager_lock.release()
@ -202,10 +246,10 @@ def thread_get_next_task():
                    continue # requested device alive, skip current one.
                else:
                    # Requested device is not active, return error to UI.
                    queued_task.error = Exception(queued_task.render_device + ' is not currently active.')
                    queued_task.error = Exception(queued_task.render_device + " is not currently active.")
                    task = queued_task
                    break
            if not queued_task.render_device and renderer.context.device == 'cpu' and is_alive() > 1:
            if not queued_task.render_device and renderer.context.device == "cpu" and is_alive() > 1:
                # Not asking for a specific device; the CPU wants to grab the task, but other render devices are alive.
                continue # Skip Tasks, don't run on CPU unless there is nothing else or user asked for it.
            task = queued_task
@ -216,17 +260,19 @@ def thread_get_next_task():
    finally:
        manager_lock.release()


def thread_render(device):
    global current_state, current_state_error

    from easydiffusion import renderer, model_manager

    try:
        renderer.init(device)

        weak_thread_data[threading.current_thread()] = {
            'device': renderer.context.device,
            'device_name': renderer.context.device_name,
            'alive': True
            "device": renderer.context.device,
            "device_name": renderer.context.device_name,
            "alive": True,
        }

        current_state = ServerStates.LoadingModel
@ -235,17 +281,14 @@ def thread_render(device):
        current_state = ServerStates.Online
    except Exception as e:
        log.error(traceback.format_exc())
        weak_thread_data[threading.current_thread()] = {
            'error': e,
            'alive': False
        }
        weak_thread_data[threading.current_thread()] = {"error": e, "alive": False}
        return

    while True:
        session_cache.clean()
        task_cache.clean()
        if not weak_thread_data[threading.current_thread()]['alive']:
            log.info(f'Shutting down thread for device {renderer.context.device}')
        if not weak_thread_data[threading.current_thread()]["alive"]:
            log.info(f"Shutting down thread for device {renderer.context.device}")
            model_manager.unload_all(renderer.context)
            return
        if isinstance(current_state_error, SystemExit):
@ -258,39 +301,47 @@ def thread_render(device):
            continue
        if task.error is not None:
            log.error(task.error)
            task.response = {"status": 'failed', "detail": str(task.error)}
            task.response = {"status": "failed", "detail": str(task.error)}
            task.buffer_queue.put(json.dumps(task.response))
            continue
        if current_state_error:
            task.error = current_state_error
            task.response = {"status": 'failed', "detail": str(task.error)}
            task.response = {"status": "failed", "detail": str(task.error)}
            task.buffer_queue.put(json.dumps(task.response))
            continue
        log.info(f'Session {task.task_data.session_id} starting task {id(task)} on {renderer.context.device_name}')
        if not task.lock.acquire(blocking=False): raise Exception('Got locked task from queue.')
        log.info(f"Session {task.task_data.session_id} starting task {id(task)} on {renderer.context.device_name}")
        if not task.lock.acquire(blocking=False):
            raise Exception("Got locked task from queue.")
        try:

            def step_callback():
                global current_state_error

                if isinstance(current_state_error, SystemExit) or isinstance(current_state_error, StopAsyncIteration) or isinstance(task.error, StopAsyncIteration):
                if (
                    isinstance(current_state_error, SystemExit)
                    or isinstance(current_state_error, StopAsyncIteration)
                    or isinstance(task.error, StopAsyncIteration)
                ):
                    renderer.context.stop_processing = True
                    if isinstance(current_state_error, StopAsyncIteration):
                        task.error = current_state_error
                        current_state_error = None
                        log.info(f'Session {task.task_data.session_id} sent cancel signal for task {id(task)}')
                        log.info(f"Session {task.task_data.session_id} sent cancel signal for task {id(task)}")

            current_state = ServerStates.LoadingModel
            model_manager.resolve_model_paths(task.task_data)
            model_manager.reload_models_if_necessary(renderer.context, task.task_data)

            current_state = ServerStates.Rendering
            task.response = renderer.make_images(task.render_request, task.task_data, task.buffer_queue, task.temp_images, step_callback)
            task.response = renderer.make_images(
                task.render_request, task.task_data, task.buffer_queue, task.temp_images, step_callback
            )
            # Before looping back to the generator, mark cache as still alive.
            task_cache.keep(id(task), TASK_TTL)
            session_cache.keep(task.task_data.session_id, TASK_TTL)
        except Exception as e:
            task.error = str(e)
            task.response = {"status": 'failed', "detail": str(task.error)}
            task.response = {"status": "failed", "detail": str(task.error)}
            task.buffer_queue.put(json.dumps(task.response))
            log.error(traceback.format_exc())
        finally:
@ -299,21 +350,25 @@ def thread_render(device):
            task_cache.keep(id(task), TASK_TTL)
            session_cache.keep(task.task_data.session_id, TASK_TTL)
            if isinstance(task.error, StopAsyncIteration):
                log.info(f'Session {task.task_data.session_id} task {id(task)} cancelled!')
                log.info(f"Session {task.task_data.session_id} task {id(task)} cancelled!")
            elif task.error is not None:
                log.info(f'Session {task.task_data.session_id} task {id(task)} failed!')
                log.info(f"Session {task.task_data.session_id} task {id(task)} failed!")
            else:
                log.info(f'Session {task.task_data.session_id} task {id(task)} completed by {renderer.context.device_name}.')
                log.info(
                    f"Session {task.task_data.session_id} task {id(task)} completed by {renderer.context.device_name}."
                )
        current_state = ServerStates.Online

def get_cached_task(task_id:str, update_ttl:bool=False):

def get_cached_task(task_id: str, update_ttl: bool = False):
    # By calling keep() before tryGet(), the entry won't be discarded even if it had expired.
    if update_ttl and not task_cache.keep(task_id, TASK_TTL):
        # Failed to keep task, already gone.
        return None
    return task_cache.tryGet(task_id)

def get_cached_session(session_id:str, update_ttl:bool=False):

def get_cached_session(session_id: str, update_ttl: bool = False):
    if update_ttl:
        session_cache.keep(session_id, TASK_TTL)
    session = session_cache.tryGet(session_id)
@ -322,64 +377,68 @@ def get_cached_session(session_id:str, update_ttl:bool=False):
        session_cache.put(session_id, session, TASK_TTL)
    return session


def get_devices():
    devices = {
        'all': {},
        'active': {},
        "all": {},
        "active": {},
    }

    def get_device_info(device):
        if device == 'cpu':
            return {'name': device_manager.get_processor_name()}
        if device == "cpu":
            return {"name": device_manager.get_processor_name()}

        mem_free, mem_total = torch.cuda.mem_get_info(device)
        mem_free /= float(10**9)
        mem_total /= float(10**9)

        return {
            'name': torch.cuda.get_device_name(device),
            'mem_free': mem_free,
            'mem_total': mem_total,
            'max_vram_usage_level': device_manager.get_max_vram_usage_level(device),
            "name": torch.cuda.get_device_name(device),
            "mem_free": mem_free,
            "mem_total": mem_total,
            "max_vram_usage_level": device_manager.get_max_vram_usage_level(device),
        }

    # list the compatible devices
    gpu_count = torch.cuda.device_count()
    for device in range(gpu_count):
        device = f'cuda:{device}'
        device = f"cuda:{device}"
        if not device_manager.is_device_compatible(device):
            continue

        devices['all'].update({device: get_device_info(device)})
        devices["all"].update({device: get_device_info(device)})

    devices['all'].update({'cpu': get_device_info('cpu')})
    devices["all"].update({"cpu": get_device_info("cpu")})

    # list the activated devices
    if not manager_lock.acquire(blocking=True, timeout=LOCK_TIMEOUT): raise Exception('get_devices' + ERR_LOCK_FAILED)
    if not manager_lock.acquire(blocking=True, timeout=LOCK_TIMEOUT):
        raise Exception("get_devices" + ERR_LOCK_FAILED)
    try:
        for rthread in render_threads:
            if not rthread.is_alive():
                continue
            weak_data = weak_thread_data.get(rthread)
            if not weak_data or not 'device' in weak_data or not 'device_name' in weak_data:
            if not weak_data or not "device" in weak_data or not "device_name" in weak_data:
                continue
            device = weak_data['device']
            devices['active'].update({device: get_device_info(device)})
            device = weak_data["device"]
            devices["active"].update({device: get_device_info(device)})
    finally:
        manager_lock.release()

    return devices
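
An illustrative sketch (not part of the commit) of the shape this function returns; the device names, memory figures and usage level below are made up, and mem_free/mem_total are in gigabytes after the division by 10**9 above.

    example = {
        "all": {
            "cuda:0": {
                "name": "GeForce RTX 3060",     # made-up device name
                "mem_free": 10.5,               # GB
                "mem_total": 12.0,              # GB
                "max_vram_usage_level": "high", # illustrative value
            },
            "cpu": {"name": "Intel(R) Core(TM) i7"},
        },
        "active": {
            "cuda:0": {"name": "GeForce RTX 3060", "mem_free": 10.5, "mem_total": 12.0, "max_vram_usage_level": "high"},
        },
    }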


def is_alive(device=None):
    if not manager_lock.acquire(blocking=True, timeout=LOCK_TIMEOUT): raise Exception('is_alive' + ERR_LOCK_FAILED)
    if not manager_lock.acquire(blocking=True, timeout=LOCK_TIMEOUT):
        raise Exception("is_alive" + ERR_LOCK_FAILED)
    nbr_alive = 0
    try:
        for rthread in render_threads:
            if device is not None:
                weak_data = weak_thread_data.get(rthread)
                if weak_data is None or not 'device' in weak_data or weak_data['device'] is None:
                if weak_data is None or not "device" in weak_data or weak_data["device"] is None:
                    continue
                thread_device = weak_data['device']
                thread_device = weak_data["device"]
                if thread_device != device:
                    continue
            if rthread.is_alive():
@ -388,11 +447,13 @@ def is_alive(device=None):
    finally:
        manager_lock.release()


def start_render_thread(device):
    if not manager_lock.acquire(blocking=True, timeout=LOCK_TIMEOUT): raise Exception('start_render_thread' + ERR_LOCK_FAILED)
    log.info(f'Start new Rendering Thread on device: {device}')
    if not manager_lock.acquire(blocking=True, timeout=LOCK_TIMEOUT):
        raise Exception("start_render_thread" + ERR_LOCK_FAILED)
    log.info(f"Start new Rendering Thread on device: {device}")
    try:
        rthread = threading.Thread(target=thread_render, kwargs={'device': device})
        rthread = threading.Thread(target=thread_render, kwargs={"device": device})
        rthread.daemon = True
        rthread.name = THREAD_NAME_PREFIX + device
        rthread.start()
@ -400,8 +461,8 @@ def start_render_thread(device):
    finally:
        manager_lock.release()
    timeout = DEVICE_START_TIMEOUT
    while not rthread.is_alive() or not rthread in weak_thread_data or not 'device' in weak_thread_data[rthread]:
        if rthread in weak_thread_data and 'error' in weak_thread_data[rthread]:
    while not rthread.is_alive() or not rthread in weak_thread_data or not "device" in weak_thread_data[rthread]:
        if rthread in weak_thread_data and "error" in weak_thread_data[rthread]:
            log.error(f"{rthread}, {device}, error: {weak_thread_data[rthread]['error']}")
            return False
        if timeout <= 0:
@ -410,25 +471,27 @@ def start_render_thread(device):
        time.sleep(1)
    return True
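
A self-contained sketch (not part of the commit) of the startup handshake above: the spawner never joins the worker; it polls a shared registry until the worker reports its device (success) or an error (failure), bounded by a timeout.

    import threading
    import time

    registry = {}

    def worker():
        # A real render thread would initialize its device here first.
        registry[threading.current_thread()] = {"alive": True, "device": "cuda:0"}

    t = threading.Thread(target=worker, daemon=True)
    t.start()

    timeout = 5
    while t not in registry or "device" not in registry[t]:
        if "error" in registry.get(t, {}):
            raise RuntimeError(registry[t]["error"])
        if timeout <= 0:
            raise TimeoutError("worker did not report in time")
        timeout -= 1
        time.sleep(1)
    print(registry[t])  # -> {'alive': True, 'device': 'cuda:0'}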


def stop_render_thread(device):
    try:
        device_manager.validate_device_id(device, log_prefix='stop_render_thread')
        device_manager.validate_device_id(device, log_prefix="stop_render_thread")
    except:
        log.error(traceback.format_exc())
        return False

    if not manager_lock.acquire(blocking=True, timeout=LOCK_TIMEOUT): raise Exception('stop_render_thread' + ERR_LOCK_FAILED)
    log.info(f'Stopping Rendering Thread on device: {device}')
    if not manager_lock.acquire(blocking=True, timeout=LOCK_TIMEOUT):
        raise Exception("stop_render_thread" + ERR_LOCK_FAILED)
    log.info(f"Stopping Rendering Thread on device: {device}")

    try:
        thread_to_remove = None
        for rthread in render_threads:
            weak_data = weak_thread_data.get(rthread)
            if weak_data is None or not 'device' in weak_data or weak_data['device'] is None:
            if weak_data is None or not "device" in weak_data or weak_data["device"] is None:
                continue
            thread_device = weak_data['device']
            thread_device = weak_data["device"]
            if thread_device == device:
                weak_data['alive'] = False
                weak_data["alive"] = False
                thread_to_remove = rthread
                break
        if thread_to_remove is not None:
@ -439,44 +502,51 @@ def stop_render_thread(device):

    return False


def update_render_threads(render_devices, active_devices):
    devices_to_start, devices_to_stop = device_manager.get_device_delta(render_devices, active_devices)
    log.debug(f'devices_to_start: {devices_to_start}')
    log.debug(f'devices_to_stop: {devices_to_stop}')
    log.debug(f"devices_to_start: {devices_to_start}")
    log.debug(f"devices_to_stop: {devices_to_stop}")

    for device in devices_to_stop:
        if is_alive(device) <= 0:
            log.debug(f'{device} is not alive')
            log.debug(f"{device} is not alive")
            continue
        if not stop_render_thread(device):
            log.warn(f'{device} could not stop render thread')
            log.warn(f"{device} could not stop render thread")

    for device in devices_to_start:
        if is_alive(device) >= 1:
            log.debug(f'{device} already registered.')
            log.debug(f"{device} already registered.")
            continue
        if not start_render_thread(device):
            log.warn(f'{device} failed to start.')
            log.warn(f"{device} failed to start.")

    if is_alive() <= 0: # No running devices, probably invalid user config.
        raise EnvironmentError('ERROR: No active render devices! Please verify the "render_devices" value in config.json')
        raise EnvironmentError(
            'ERROR: No active render devices! Please verify the "render_devices" value in config.json'
        )

    log.debug(f"active devices: {get_devices()['active']}")


def shutdown_event(): # Signal render thread to close on shutdown
    global current_state_error
    current_state_error = SystemExit('Application shutting down.')
    current_state_error = SystemExit("Application shutting down.")


def render(render_req: GenerateImageRequest, task_data: TaskData):
    current_thread_count = is_alive()
    if current_thread_count <= 0: # Render thread is dead
        raise ChildProcessError('Rendering thread has died.')
        raise ChildProcessError("Rendering thread has died.")

    # Alive, check if task in cache
    session = get_cached_session(task_data.session_id, update_ttl=True)
    pending_tasks = list(filter(lambda t: t.is_pending, session.tasks))
    if current_thread_count < len(pending_tasks):
        raise ConnectionRefusedError(f'Session {task_data.session_id} already has {len(pending_tasks)} pending tasks out of {current_thread_count}.')
        raise ConnectionRefusedError(
            f"Session {task_data.session_id} already has {len(pending_tasks)} pending tasks out of {current_thread_count}."
        )

    new_task = RenderTask(render_req, task_data)
    if session.put(new_task, TASK_TTL):
@ -489,4 +559,4 @@ def render(render_req: GenerateImageRequest, task_data: TaskData):
            return new_task
        finally:
            manager_lock.release()
    raise RuntimeError('Failed to add task to cache.')
    raise RuntimeError("Failed to add task to cache.")
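
A self-contained sketch (not part of the commit) of the admission rule above: a session may not hold more pending tasks than there are live render threads.

    current_thread_count = 2              # pretend two render threads are alive
    pending_tasks = ["task-a", "task-b"]  # tasks already queued for this session

    if current_thread_count < len(pending_tasks):
        raise ConnectionRefusedError(
            f"already has {len(pending_tasks)} pending tasks out of {current_thread_count}."
        )
    print("task accepted")  # 2 pending out of 2 threads is still allowed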

@ -1,6 +1,7 @@
from pydantic import BaseModel
from typing import Any


class GenerateImageRequest(BaseModel):
    prompt: str = ""
    negative_prompt: str = ""
@ -21,6 +22,7 @@ class GenerateImageRequest(BaseModel):
    sampler_name: str = None # "ddim", "plms", "heun", "euler", "euler_a", "dpm2", "dpm2_a", "lms"
    hypernetwork_strength: float = 0


class TaskData(BaseModel):
    request_id: str = None
    session_id: str = "session"
@ -41,6 +43,7 @@ class TaskData(BaseModel):
    metadata_output_format: str = "txt" # or "json"
    stream_image_progress: bool = False


class MergeRequest(BaseModel):
    model0: str = None
    model1: str = None
@ -48,6 +51,7 @@ class MergeRequest(BaseModel):
    out_path: str = "mix"
    use_fp16 = True


class Image:
    data: str # base64
    seed: int
@ -65,6 +69,7 @@ class Image:
            "path_abs": self.path_abs,
        }


class Response:
    render_request: GenerateImageRequest
    task_data: TaskData
@ -80,7 +85,7 @@ class Response:
        del self.render_request.init_image_mask

        res = {
            "status": 'succeeded',
            "status": "succeeded",
            "render_request": self.render_request.dict(),
            "task_data": self.task_data.dict(),
            "output": [],
@ -91,5 +96,6 @@ class Response:

        return res


class UserInitiatedStop(Exception):
    pass

@ -1,6 +1,6 @@
import logging

log = logging.getLogger('easydiffusion')
log = logging.getLogger("easydiffusion")

from .save_utils import (
    save_images_to_disk,

@ -7,89 +7,126 @@ from easydiffusion.types import TaskData, GenerateImageRequest

from sdkit.utils import save_images, save_dicts

filename_regex = re.compile('[^a-zA-Z0-9._-]')
filename_regex = re.compile("[^a-zA-Z0-9._-]")

# keep in sync with `ui/media/js/dnd.js`
TASK_TEXT_MAPPING = {
    'prompt': 'Prompt',
    'width': 'Width',
    'height': 'Height',
    'seed': 'Seed',
    'num_inference_steps': 'Steps',
    'guidance_scale': 'Guidance Scale',
    'prompt_strength': 'Prompt Strength',
    'use_face_correction': 'Use Face Correction',
    'use_upscale': 'Use Upscaling',
    'upscale_amount': 'Upscale By',
    'sampler_name': 'Sampler',
    'negative_prompt': 'Negative Prompt',
    'use_stable_diffusion_model': 'Stable Diffusion model',
    'use_vae_model': 'VAE model',
    'use_hypernetwork_model': 'Hypernetwork model',
    'hypernetwork_strength': 'Hypernetwork Strength'
    "prompt": "Prompt",
    "width": "Width",
    "height": "Height",
    "seed": "Seed",
    "num_inference_steps": "Steps",
    "guidance_scale": "Guidance Scale",
    "prompt_strength": "Prompt Strength",
    "use_face_correction": "Use Face Correction",
    "use_upscale": "Use Upscaling",
    "upscale_amount": "Upscale By",
    "sampler_name": "Sampler",
    "negative_prompt": "Negative Prompt",
    "use_stable_diffusion_model": "Stable Diffusion model",
    "use_vae_model": "VAE model",
    "use_hypernetwork_model": "Hypernetwork model",
    "hypernetwork_strength": "Hypernetwork Strength",
}


def save_images_to_disk(images: list, filtered_images: list, req: GenerateImageRequest, task_data: TaskData):
    now = time.time()
    save_dir_path = os.path.join(task_data.save_to_disk_path, filename_regex.sub('_', task_data.session_id))
    save_dir_path = os.path.join(task_data.save_to_disk_path, filename_regex.sub("_", task_data.session_id))
    metadata_entries = get_metadata_entries_for_request(req, task_data)
    make_filename = make_filename_callback(req, now=now)

    if task_data.show_only_filtered_image or filtered_images is images:
        save_images(filtered_images, save_dir_path, file_name=make_filename, output_format=task_data.output_format, output_quality=task_data.output_quality)
        if task_data.metadata_output_format.lower() in ['json', 'txt', 'embed']:
            save_dicts(metadata_entries, save_dir_path, file_name=make_filename, output_format=task_data.metadata_output_format, file_format=task_data.output_format)
        save_images(
            filtered_images,
            save_dir_path,
            file_name=make_filename,
            output_format=task_data.output_format,
            output_quality=task_data.output_quality,
        )
        if task_data.metadata_output_format.lower() in ["json", "txt", "embed"]:
            save_dicts(
                metadata_entries,
                save_dir_path,
                file_name=make_filename,
                output_format=task_data.metadata_output_format,
                file_format=task_data.output_format,
            )
    else:
        make_filter_filename = make_filename_callback(req, now=now, suffix='filtered')
        make_filter_filename = make_filename_callback(req, now=now, suffix="filtered")

        save_images(
            images,
            save_dir_path,
            file_name=make_filename,
            output_format=task_data.output_format,
            output_quality=task_data.output_quality,
        )
        save_images(
            filtered_images,
            save_dir_path,
            file_name=make_filter_filename,
            output_format=task_data.output_format,
            output_quality=task_data.output_quality,
        )
        if task_data.metadata_output_format.lower() in ["json", "txt", "embed"]:
            save_dicts(
                metadata_entries,
                save_dir_path,
                file_name=make_filter_filename,
                output_format=task_data.metadata_output_format,
                file_format=task_data.output_format,
            )

        save_images(images, save_dir_path, file_name=make_filename, output_format=task_data.output_format, output_quality=task_data.output_quality)
        save_images(filtered_images, save_dir_path, file_name=make_filter_filename, output_format=task_data.output_format, output_quality=task_data.output_quality)
        if task_data.metadata_output_format.lower() in ['json', 'txt', 'embed']:
            save_dicts(metadata_entries, save_dir_path, file_name=make_filter_filename, output_format=task_data.metadata_output_format, file_format=task_data.output_format)

def get_metadata_entries_for_request(req: GenerateImageRequest, task_data: TaskData):
    metadata = get_printable_request(req)
    metadata.update({
        'use_stable_diffusion_model': task_data.use_stable_diffusion_model,
        'use_vae_model': task_data.use_vae_model,
        'use_hypernetwork_model': task_data.use_hypernetwork_model,
        'use_face_correction': task_data.use_face_correction,
        'use_upscale': task_data.use_upscale,
    })
    if metadata['use_upscale'] is not None:
        metadata['upscale_amount'] = task_data.upscale_amount
    if (task_data.use_hypernetwork_model is None):
        del metadata['hypernetwork_strength']
    metadata.update(
        {
            "use_stable_diffusion_model": task_data.use_stable_diffusion_model,
            "use_vae_model": task_data.use_vae_model,
            "use_hypernetwork_model": task_data.use_hypernetwork_model,
            "use_face_correction": task_data.use_face_correction,
            "use_upscale": task_data.use_upscale,
        }
    )
    if metadata["use_upscale"] is not None:
        metadata["upscale_amount"] = task_data.upscale_amount
    if task_data.use_hypernetwork_model is None:
        del metadata["hypernetwork_strength"]

    # if text, format it in the text format expected by the UI
    is_txt_format = (task_data.metadata_output_format.lower() == 'txt')
    is_txt_format = task_data.metadata_output_format.lower() == "txt"
    if is_txt_format:
        metadata = {TASK_TEXT_MAPPING[key]: val for key, val in metadata.items() if key in TASK_TEXT_MAPPING}

    entries = [metadata.copy() for _ in range(req.num_outputs)]
    for i, entry in enumerate(entries):
        entry['Seed' if is_txt_format else 'seed'] = req.seed + i
        entry["Seed" if is_txt_format else "seed"] = req.seed + i

    return entries
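
A self-contained sketch (not part of the commit) of the two steps above: remap internal keys to the UI labels for txt output, then give each of the N copies its own seed.

    TASK_TEXT_MAPPING = {"prompt": "Prompt", "seed": "Seed"}  # trimmed-down mapping
    metadata = {"prompt": "a photo of an astronaut", "seed": 42, "internal_field": 1}

    # keys not in the mapping (like internal_field) are dropped for txt output
    metadata = {TASK_TEXT_MAPPING[key]: val for key, val in metadata.items() if key in TASK_TEXT_MAPPING}

    num_outputs = 3
    entries = [metadata.copy() for _ in range(num_outputs)]
    for i, entry in enumerate(entries):
        entry["Seed"] = 42 + i  # each output records its own seed

    print([e["Seed"] for e in entries])  # -> [42, 43, 44]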


def get_printable_request(req: GenerateImageRequest):
    metadata = req.dict()
    del metadata['init_image']
    del metadata['init_image_mask']
    if (req.init_image is None):
        del metadata['prompt_strength']
    del metadata["init_image"]
    del metadata["init_image_mask"]
    if req.init_image is None:
        del metadata["prompt_strength"]
    return metadata


def make_filename_callback(req: GenerateImageRequest, suffix=None, now=None):
    if now is None:
        now = time.time()
    def make_filename(i):
        img_id = base64.b64encode(int(now+i).to_bytes(8, 'big')).decode() # Generate unique ID based on time.
        img_id = img_id.translate({43:None, 47:None, 61:None})[-8:] # Remove + / = and keep last 8 chars.

        prompt_flattened = filename_regex.sub('_', req.prompt)[:50]

    def make_filename(i):
        img_id = base64.b64encode(int(now + i).to_bytes(8, "big")).decode() # Generate unique ID based on time.
        img_id = img_id.translate({43: None, 47: None, 61: None})[-8:] # Remove + / = and keep last 8 chars.

        prompt_flattened = filename_regex.sub("_", req.prompt)[:50]
        name = f"{prompt_flattened}_{img_id}"
        name = name if suffix is None else f'{name}_{suffix}'
        name = name if suffix is None else f"{name}_{suffix}"
        return name

    return make_filename
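
A self-contained sketch (not part of the commit) of the id scheme used above: base64-encode the integer timestamp, strip the '+', '/' and '=' characters (code points 43, 47 and 61), and keep the last 8 characters.

    import base64
    import time

    now = time.time()
    img_id = base64.b64encode(int(now).to_bytes(8, "big")).decode()
    img_id = img_id.translate({43: None, 47: None, 61: None})[-8:]
    # e.g. "AGPxYzk4" -- stable within one second, which is why the real code adds the loop index i to now
    print(img_id)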