mirror of
https://github.com/easydiffusion/easydiffusion.git
synced 2025-02-18 03:11:10 +01:00
Support server-side plugins. Currently supports overriding the get_cond_and_uncond function
This commit is contained in:
parent
4754743c84
commit
e73e820237
@ -27,15 +27,21 @@ logging.basicConfig(
|
|||||||
SD_DIR = os.getcwd()
|
SD_DIR = os.getcwd()
|
||||||
|
|
||||||
SD_UI_DIR = os.getenv("SD_UI_PATH", None)
|
SD_UI_DIR = os.getenv("SD_UI_PATH", None)
|
||||||
sys.path.append(os.path.dirname(SD_UI_DIR))
|
|
||||||
|
|
||||||
CONFIG_DIR = os.path.abspath(os.path.join(SD_UI_DIR, "..", "scripts"))
|
CONFIG_DIR = os.path.abspath(os.path.join(SD_UI_DIR, "..", "scripts"))
|
||||||
MODELS_DIR = os.path.abspath(os.path.join(SD_DIR, "..", "models"))
|
MODELS_DIR = os.path.abspath(os.path.join(SD_DIR, "..", "models"))
|
||||||
|
|
||||||
USER_UI_PLUGINS_DIR = os.path.abspath(os.path.join(SD_DIR, "..", "plugins", "ui"))
|
USER_PLUGINS_DIR = os.path.abspath(os.path.join(SD_DIR, "..", "plugins"))
|
||||||
CORE_UI_PLUGINS_DIR = os.path.abspath(os.path.join(SD_UI_DIR, "plugins", "ui"))
|
CORE_PLUGINS_DIR = os.path.abspath(os.path.join(SD_UI_DIR, "plugins"))
|
||||||
|
|
||||||
|
USER_UI_PLUGINS_DIR = os.path.join(USER_PLUGINS_DIR, "ui")
|
||||||
|
CORE_UI_PLUGINS_DIR = os.path.join(CORE_PLUGINS_DIR, "ui")
|
||||||
|
USER_SERVER_PLUGINS_DIR = os.path.join(USER_PLUGINS_DIR, "server")
|
||||||
UI_PLUGINS_SOURCES = ((CORE_UI_PLUGINS_DIR, "core"), (USER_UI_PLUGINS_DIR, "user"))
|
UI_PLUGINS_SOURCES = ((CORE_UI_PLUGINS_DIR, "core"), (USER_UI_PLUGINS_DIR, "user"))
|
||||||
|
|
||||||
|
sys.path.append(os.path.dirname(SD_UI_DIR))
|
||||||
|
sys.path.append(USER_SERVER_PLUGINS_DIR)
|
||||||
|
|
||||||
OUTPUT_DIRNAME = "Stable Diffusion UI" # in the user's home folder
|
OUTPUT_DIRNAME = "Stable Diffusion UI" # in the user's home folder
|
||||||
PRESERVE_CONFIG_VARS = ["FORCE_FULL_PRECISION"]
|
PRESERVE_CONFIG_VARS = ["FORCE_FULL_PRECISION"]
|
||||||
TASK_TTL = 15 * 60 # Discard last session's task timeout
|
TASK_TTL = 15 * 60 # Discard last session's task timeout
|
||||||
@ -51,6 +57,9 @@ APP_CONFIG_DEFAULTS = {
|
|||||||
|
|
||||||
def init():
|
def init():
|
||||||
os.makedirs(USER_UI_PLUGINS_DIR, exist_ok=True)
|
os.makedirs(USER_UI_PLUGINS_DIR, exist_ok=True)
|
||||||
|
os.makedirs(USER_SERVER_PLUGINS_DIR, exist_ok=True)
|
||||||
|
|
||||||
|
load_server_plugins()
|
||||||
|
|
||||||
update_render_threads()
|
update_render_threads()
|
||||||
|
|
||||||
@ -171,6 +180,41 @@ def getUIPlugins():
|
|||||||
return plugins
|
return plugins
|
||||||
|
|
||||||
|
|
||||||
|
def load_server_plugins():
|
||||||
|
if not os.path.exists(USER_SERVER_PLUGINS_DIR):
|
||||||
|
return
|
||||||
|
|
||||||
|
import importlib
|
||||||
|
|
||||||
|
def load_plugin(file):
|
||||||
|
mod_path = file.replace(".py", "")
|
||||||
|
return importlib.import_module(mod_path)
|
||||||
|
|
||||||
|
def apply_plugin(file, plugin):
|
||||||
|
if hasattr(plugin, "get_cond_and_uncond"):
|
||||||
|
import sdkit.generate.image_generator
|
||||||
|
|
||||||
|
sdkit.generate.image_generator.get_cond_and_uncond = plugin.get_cond_and_uncond
|
||||||
|
log.info(f"Overridden get_cond_and_uncond with the one in the server plugin: {file}")
|
||||||
|
|
||||||
|
for file in os.listdir(USER_SERVER_PLUGINS_DIR):
|
||||||
|
file_path = os.path.join(USER_SERVER_PLUGINS_DIR, file)
|
||||||
|
if (not os.path.isdir(file_path) and not file_path.endswith("_plugin.py")) or (
|
||||||
|
os.path.isdir(file_path) and not file_path.endswith("_plugin")
|
||||||
|
):
|
||||||
|
continue
|
||||||
|
|
||||||
|
try:
|
||||||
|
log.info(f"Loading server plugin: {file}")
|
||||||
|
mod = load_plugin(file)
|
||||||
|
|
||||||
|
log.info(f"Applying server plugin: {file}")
|
||||||
|
apply_plugin(file, mod)
|
||||||
|
except:
|
||||||
|
log.warn(f"Error while loading a server plugin")
|
||||||
|
log.warn(traceback.format_exc())
|
||||||
|
|
||||||
|
|
||||||
def getIPConfig():
|
def getIPConfig():
|
||||||
try:
|
try:
|
||||||
ips = socket.gethostbyname_ex(socket.gethostname())
|
ips = socket.gethostbyname_ex(socket.gethostname())
|
||||||
|
Loading…
Reference in New Issue
Block a user