diff --git a/ui/easydiffusion/app.py b/ui/easydiffusion/app.py index d556dd6f..d245bf39 100644 --- a/ui/easydiffusion/app.py +++ b/ui/easydiffusion/app.py @@ -27,15 +27,21 @@ logging.basicConfig( SD_DIR = os.getcwd() SD_UI_DIR = os.getenv("SD_UI_PATH", None) -sys.path.append(os.path.dirname(SD_UI_DIR)) CONFIG_DIR = os.path.abspath(os.path.join(SD_UI_DIR, "..", "scripts")) MODELS_DIR = os.path.abspath(os.path.join(SD_DIR, "..", "models")) -USER_UI_PLUGINS_DIR = os.path.abspath(os.path.join(SD_DIR, "..", "plugins", "ui")) -CORE_UI_PLUGINS_DIR = os.path.abspath(os.path.join(SD_UI_DIR, "plugins", "ui")) +USER_PLUGINS_DIR = os.path.abspath(os.path.join(SD_DIR, "..", "plugins")) +CORE_PLUGINS_DIR = os.path.abspath(os.path.join(SD_UI_DIR, "plugins")) + +USER_UI_PLUGINS_DIR = os.path.join(USER_PLUGINS_DIR, "ui") +CORE_UI_PLUGINS_DIR = os.path.join(CORE_PLUGINS_DIR, "ui") +USER_SERVER_PLUGINS_DIR = os.path.join(USER_PLUGINS_DIR, "server") UI_PLUGINS_SOURCES = ((CORE_UI_PLUGINS_DIR, "core"), (USER_UI_PLUGINS_DIR, "user")) +sys.path.append(os.path.dirname(SD_UI_DIR)) +sys.path.append(USER_SERVER_PLUGINS_DIR) + OUTPUT_DIRNAME = "Stable Diffusion UI" # in the user's home folder PRESERVE_CONFIG_VARS = ["FORCE_FULL_PRECISION"] TASK_TTL = 15 * 60 # Discard last session's task timeout @@ -51,6 +57,9 @@ APP_CONFIG_DEFAULTS = { def init(): os.makedirs(USER_UI_PLUGINS_DIR, exist_ok=True) + os.makedirs(USER_SERVER_PLUGINS_DIR, exist_ok=True) + + load_server_plugins() update_render_threads() @@ -171,6 +180,41 @@ def getUIPlugins(): return plugins +def load_server_plugins(): + if not os.path.exists(USER_SERVER_PLUGINS_DIR): + return + + import importlib + + def load_plugin(file): + mod_path = file.replace(".py", "") + return importlib.import_module(mod_path) + + def apply_plugin(file, plugin): + if hasattr(plugin, "get_cond_and_uncond"): + import sdkit.generate.image_generator + + sdkit.generate.image_generator.get_cond_and_uncond = plugin.get_cond_and_uncond + log.info(f"Overridden get_cond_and_uncond with the one in the server plugin: {file}") + + for file in os.listdir(USER_SERVER_PLUGINS_DIR): + file_path = os.path.join(USER_SERVER_PLUGINS_DIR, file) + if (not os.path.isdir(file_path) and not file_path.endswith("_plugin.py")) or ( + os.path.isdir(file_path) and not file_path.endswith("_plugin") + ): + continue + + try: + log.info(f"Loading server plugin: {file}") + mod = load_plugin(file) + + log.info(f"Applying server plugin: {file}") + apply_plugin(file, mod) + except: + log.warn(f"Error while loading a server plugin") + log.warn(traceback.format_exc()) + + def getIPConfig(): try: ips = socket.gethostbyname_ex(socket.gethostname())