mirror of
https://github.com/easydiffusion/easydiffusion.git
synced 2025-01-26 08:08:58 +01:00
Support server-side plugins. Currently supports overriding the get_cond_and_uncond function
This commit is contained in:
parent
4754743c84
commit
e73e820237
@ -27,15 +27,21 @@ logging.basicConfig(
|
||||
SD_DIR = os.getcwd()
|
||||
|
||||
SD_UI_DIR = os.getenv("SD_UI_PATH", None)
|
||||
sys.path.append(os.path.dirname(SD_UI_DIR))
|
||||
|
||||
CONFIG_DIR = os.path.abspath(os.path.join(SD_UI_DIR, "..", "scripts"))
|
||||
MODELS_DIR = os.path.abspath(os.path.join(SD_DIR, "..", "models"))
|
||||
|
||||
USER_UI_PLUGINS_DIR = os.path.abspath(os.path.join(SD_DIR, "..", "plugins", "ui"))
|
||||
CORE_UI_PLUGINS_DIR = os.path.abspath(os.path.join(SD_UI_DIR, "plugins", "ui"))
|
||||
USER_PLUGINS_DIR = os.path.abspath(os.path.join(SD_DIR, "..", "plugins"))
|
||||
CORE_PLUGINS_DIR = os.path.abspath(os.path.join(SD_UI_DIR, "plugins"))
|
||||
|
||||
USER_UI_PLUGINS_DIR = os.path.join(USER_PLUGINS_DIR, "ui")
|
||||
CORE_UI_PLUGINS_DIR = os.path.join(CORE_PLUGINS_DIR, "ui")
|
||||
USER_SERVER_PLUGINS_DIR = os.path.join(USER_PLUGINS_DIR, "server")
|
||||
UI_PLUGINS_SOURCES = ((CORE_UI_PLUGINS_DIR, "core"), (USER_UI_PLUGINS_DIR, "user"))
|
||||
|
||||
sys.path.append(os.path.dirname(SD_UI_DIR))
|
||||
sys.path.append(USER_SERVER_PLUGINS_DIR)
|
||||
|
||||
OUTPUT_DIRNAME = "Stable Diffusion UI" # in the user's home folder
|
||||
PRESERVE_CONFIG_VARS = ["FORCE_FULL_PRECISION"]
|
||||
TASK_TTL = 15 * 60 # Discard last session's task timeout
|
||||
@ -51,6 +57,9 @@ APP_CONFIG_DEFAULTS = {
|
||||
|
||||
def init():
|
||||
os.makedirs(USER_UI_PLUGINS_DIR, exist_ok=True)
|
||||
os.makedirs(USER_SERVER_PLUGINS_DIR, exist_ok=True)
|
||||
|
||||
load_server_plugins()
|
||||
|
||||
update_render_threads()
|
||||
|
||||
@ -171,6 +180,41 @@ def getUIPlugins():
|
||||
return plugins
|
||||
|
||||
|
||||
def load_server_plugins():
|
||||
if not os.path.exists(USER_SERVER_PLUGINS_DIR):
|
||||
return
|
||||
|
||||
import importlib
|
||||
|
||||
def load_plugin(file):
|
||||
mod_path = file.replace(".py", "")
|
||||
return importlib.import_module(mod_path)
|
||||
|
||||
def apply_plugin(file, plugin):
|
||||
if hasattr(plugin, "get_cond_and_uncond"):
|
||||
import sdkit.generate.image_generator
|
||||
|
||||
sdkit.generate.image_generator.get_cond_and_uncond = plugin.get_cond_and_uncond
|
||||
log.info(f"Overridden get_cond_and_uncond with the one in the server plugin: {file}")
|
||||
|
||||
for file in os.listdir(USER_SERVER_PLUGINS_DIR):
|
||||
file_path = os.path.join(USER_SERVER_PLUGINS_DIR, file)
|
||||
if (not os.path.isdir(file_path) and not file_path.endswith("_plugin.py")) or (
|
||||
os.path.isdir(file_path) and not file_path.endswith("_plugin")
|
||||
):
|
||||
continue
|
||||
|
||||
try:
|
||||
log.info(f"Loading server plugin: {file}")
|
||||
mod = load_plugin(file)
|
||||
|
||||
log.info(f"Applying server plugin: {file}")
|
||||
apply_plugin(file, mod)
|
||||
except:
|
||||
log.warn(f"Error while loading a server plugin")
|
||||
log.warn(traceback.format_exc())
|
||||
|
||||
|
||||
def getIPConfig():
|
||||
try:
|
||||
ips = socket.gethostbyname_ex(socket.gethostname())
|
||||
|
Loading…
Reference in New Issue
Block a user