mirror of
https://github.com/easydiffusion/easydiffusion.git
synced 2025-04-27 12:58:47 +02:00
Improve responsiveness of UI startup by not waiting for render threads to start up before showing the UI. Errors while starting the render thread will be logged anyway, so there's no need to block the main thread for this
This commit is contained in:
parent
2bd1cceb24
commit
37e8158175
@ -96,6 +96,8 @@ def init():
|
|||||||
# https://pytorch.org/docs/stable/storage.html
|
# https://pytorch.org/docs/stable/storage.html
|
||||||
warnings.filterwarnings("ignore", category=UserWarning, message="TypedStorage is deprecated")
|
warnings.filterwarnings("ignore", category=UserWarning, message="TypedStorage is deprecated")
|
||||||
|
|
||||||
|
|
||||||
|
def init_render_threads():
|
||||||
load_server_plugins()
|
load_server_plugins()
|
||||||
|
|
||||||
update_render_threads()
|
update_render_threads()
|
||||||
@ -279,6 +281,8 @@ def open_browser():
|
|||||||
if ui.get("open_browser_on_start", True):
|
if ui.get("open_browser_on_start", True):
|
||||||
import webbrowser
|
import webbrowser
|
||||||
|
|
||||||
|
log.info("Opening browser..")
|
||||||
|
|
||||||
webbrowser.open(f"http://localhost:{port}")
|
webbrowser.open(f"http://localhost:{port}")
|
||||||
|
|
||||||
Console().print(
|
Console().print(
|
||||||
|
@ -52,7 +52,6 @@ def init():
|
|||||||
make_model_folders()
|
make_model_folders()
|
||||||
migrate_legacy_model_location() # if necessary
|
migrate_legacy_model_location() # if necessary
|
||||||
download_default_models_if_necessary()
|
download_default_models_if_necessary()
|
||||||
getModels() # run this once, to cache the picklescan results
|
|
||||||
|
|
||||||
|
|
||||||
def load_default_models(context: Context):
|
def load_default_models(context: Context):
|
||||||
@ -310,7 +309,7 @@ def is_malicious_model(file_path):
|
|||||||
return False
|
return False
|
||||||
|
|
||||||
|
|
||||||
def getModels():
|
def getModels(scan_for_malicious: bool = True):
|
||||||
models = {
|
models = {
|
||||||
"options": {
|
"options": {
|
||||||
"stable-diffusion": ["sd-v1-4"],
|
"stable-diffusion": ["sd-v1-4"],
|
||||||
@ -343,9 +342,10 @@ def getModels():
|
|||||||
mod_time = known_models[entry.path] if entry.path in known_models else -1
|
mod_time = known_models[entry.path] if entry.path in known_models else -1
|
||||||
if mod_time != mtime:
|
if mod_time != mtime:
|
||||||
models_scanned += 1
|
models_scanned += 1
|
||||||
if is_malicious_model(entry.path):
|
if scan_for_malicious and is_malicious_model(entry.path):
|
||||||
raise MaliciousModelException(entry.path)
|
raise MaliciousModelException(entry.path)
|
||||||
known_models[entry.path] = mtime
|
if scan_for_malicious:
|
||||||
|
known_models[entry.path] = mtime
|
||||||
tree.append(entry.name[: -len(matching_suffix)])
|
tree.append(entry.name[: -len(matching_suffix)])
|
||||||
elif entry.is_dir():
|
elif entry.is_dir():
|
||||||
scan = scan_directory(entry.path, suffixes, directoriesFirst=False)
|
scan = scan_directory(entry.path, suffixes, directoriesFirst=False)
|
||||||
@ -365,9 +365,10 @@ def getModels():
|
|||||||
try:
|
try:
|
||||||
models["options"][model_type] = scan_directory(models_dir, model_extensions)
|
models["options"][model_type] = scan_directory(models_dir, model_extensions)
|
||||||
except MaliciousModelException as e:
|
except MaliciousModelException as e:
|
||||||
models["scan-error"] = e
|
models["scan-error"] = str(e)
|
||||||
|
|
||||||
log.info(f"[green]Scanning all model folders for models...[/]")
|
if scan_for_malicious:
|
||||||
|
log.info(f"[green]Scanning all model folders for models...[/]")
|
||||||
# custom models
|
# custom models
|
||||||
listModels(model_type="stable-diffusion")
|
listModels(model_type="stable-diffusion")
|
||||||
listModels(model_type="vae")
|
listModels(model_type="vae")
|
||||||
@ -375,7 +376,7 @@ def getModels():
|
|||||||
listModels(model_type="gfpgan")
|
listModels(model_type="gfpgan")
|
||||||
listModels(model_type="lora")
|
listModels(model_type="lora")
|
||||||
|
|
||||||
if models_scanned > 0:
|
if scan_for_malicious and models_scanned > 0:
|
||||||
log.info(f"[green]Scanned {models_scanned} models. Nothing infected[/]")
|
log.info(f"[green]Scanned {models_scanned} models. Nothing infected[/]")
|
||||||
|
|
||||||
return models
|
return models
|
||||||
|
@ -86,8 +86,8 @@ def init():
|
|||||||
return set_app_config_internal(req)
|
return set_app_config_internal(req)
|
||||||
|
|
||||||
@server_api.get("/get/{key:path}")
|
@server_api.get("/get/{key:path}")
|
||||||
def read_web_data(key: str = None):
|
def read_web_data(key: str = None, scan_for_malicious: bool = True):
|
||||||
return read_web_data_internal(key)
|
return read_web_data_internal(key, scan_for_malicious=scan_for_malicious)
|
||||||
|
|
||||||
@server_api.get("/ping") # Get server and optionally session status.
|
@server_api.get("/ping") # Get server and optionally session status.
|
||||||
def ping(session_id: str = None):
|
def ping(session_id: str = None):
|
||||||
@ -179,7 +179,7 @@ def update_render_devices_in_config(config, render_devices):
|
|||||||
config["render_devices"] = render_devices
|
config["render_devices"] = render_devices
|
||||||
|
|
||||||
|
|
||||||
def read_web_data_internal(key: str = None):
|
def read_web_data_internal(key: str = None, **kwargs):
|
||||||
if not key: # /get without parameters, stable-diffusion easter egg.
|
if not key: # /get without parameters, stable-diffusion easter egg.
|
||||||
raise HTTPException(status_code=418, detail="StableDiffusion is drawing a teapot!") # HTTP418 I'm a teapot
|
raise HTTPException(status_code=418, detail="StableDiffusion is drawing a teapot!") # HTTP418 I'm a teapot
|
||||||
elif key == "app_config":
|
elif key == "app_config":
|
||||||
@ -198,7 +198,8 @@ def read_web_data_internal(key: str = None):
|
|||||||
system_info["devices"]["config"] = config.get("render_devices", "auto")
|
system_info["devices"]["config"] = config.get("render_devices", "auto")
|
||||||
return JSONResponse(system_info, headers=NOCACHE_HEADERS)
|
return JSONResponse(system_info, headers=NOCACHE_HEADERS)
|
||||||
elif key == "models":
|
elif key == "models":
|
||||||
return JSONResponse(model_manager.getModels(), headers=NOCACHE_HEADERS)
|
scan_for_malicious = kwargs.get("scan_for_malicious", True)
|
||||||
|
return JSONResponse(model_manager.getModels(scan_for_malicious), headers=NOCACHE_HEADERS)
|
||||||
elif key == "modifiers":
|
elif key == "modifiers":
|
||||||
return JSONResponse(app.get_image_modifiers(), headers=NOCACHE_HEADERS)
|
return JSONResponse(app.get_image_modifiers(), headers=NOCACHE_HEADERS)
|
||||||
elif key == "ui_plugins":
|
elif key == "ui_plugins":
|
||||||
@ -334,7 +335,8 @@ def get_image_internal(task_id: int, img_id: int):
|
|||||||
except KeyError as e:
|
except KeyError as e:
|
||||||
raise HTTPException(status_code=500, detail=str(e))
|
raise HTTPException(status_code=500, detail=str(e))
|
||||||
|
|
||||||
#---- Cloudflare Tunnel ----
|
|
||||||
|
# ---- Cloudflare Tunnel ----
|
||||||
class CloudflareTunnel:
|
class CloudflareTunnel:
|
||||||
def __init__(self):
|
def __init__(self):
|
||||||
config = app.getConfig()
|
config = app.getConfig()
|
||||||
@ -357,23 +359,25 @@ class CloudflareTunnel:
|
|||||||
else:
|
else:
|
||||||
return None
|
return None
|
||||||
|
|
||||||
|
|
||||||
cloudflare = CloudflareTunnel()
|
cloudflare = CloudflareTunnel()
|
||||||
|
|
||||||
|
|
||||||
def start_cloudflare_tunnel_internal(req: dict):
|
def start_cloudflare_tunnel_internal(req: dict):
|
||||||
try:
|
try:
|
||||||
cloudflare.start()
|
cloudflare.start()
|
||||||
log.info(f"- Started cloudflare tunnel. Using address: {cloudflare.address}")
|
log.info(f"- Started cloudflare tunnel. Using address: {cloudflare.address}")
|
||||||
return JSONResponse({"address":cloudflare.address})
|
return JSONResponse({"address": cloudflare.address})
|
||||||
except Exception as e:
|
except Exception as e:
|
||||||
log.error(str(e))
|
log.error(str(e))
|
||||||
log.error(traceback.format_exc())
|
log.error(traceback.format_exc())
|
||||||
return HTTPException(status_code=500, detail=str(e))
|
return HTTPException(status_code=500, detail=str(e))
|
||||||
|
|
||||||
|
|
||||||
def stop_cloudflare_tunnel_internal(req: dict):
|
def stop_cloudflare_tunnel_internal(req: dict):
|
||||||
try:
|
try:
|
||||||
cloudflare.stop()
|
cloudflare.stop()
|
||||||
except Exception as e:
|
except Exception as e:
|
||||||
log.error(str(e))
|
log.error(str(e))
|
||||||
log.error(traceback.format_exc())
|
log.error(traceback.format_exc())
|
||||||
return HTTPException(status_code=500, detail=str(e))
|
return HTTPException(status_code=500, detail=str(e))
|
||||||
|
|
||||||
|
@ -473,16 +473,16 @@ def start_render_thread(device):
|
|||||||
render_threads.append(rthread)
|
render_threads.append(rthread)
|
||||||
finally:
|
finally:
|
||||||
manager_lock.release()
|
manager_lock.release()
|
||||||
timeout = DEVICE_START_TIMEOUT
|
# timeout = DEVICE_START_TIMEOUT
|
||||||
while not rthread.is_alive() or not rthread in weak_thread_data or not "device" in weak_thread_data[rthread]:
|
# while not rthread.is_alive() or not rthread in weak_thread_data or not "device" in weak_thread_data[rthread]:
|
||||||
if rthread in weak_thread_data and "error" in weak_thread_data[rthread]:
|
# if rthread in weak_thread_data and "error" in weak_thread_data[rthread]:
|
||||||
log.error(f"{rthread}, {device}, error: {weak_thread_data[rthread]['error']}")
|
# log.error(f"{rthread}, {device}, error: {weak_thread_data[rthread]['error']}")
|
||||||
return False
|
# return False
|
||||||
if timeout <= 0:
|
# if timeout <= 0:
|
||||||
return False
|
# return False
|
||||||
timeout -= 1
|
# timeout -= 1
|
||||||
time.sleep(1)
|
# time.sleep(1)
|
||||||
return True
|
# return True
|
||||||
|
|
||||||
|
|
||||||
def stop_render_thread(device):
|
def stop_render_thread(device):
|
||||||
@ -535,12 +535,12 @@ def update_render_threads(render_devices, active_devices):
|
|||||||
if not start_render_thread(device):
|
if not start_render_thread(device):
|
||||||
log.warn(f"{device} failed to start.")
|
log.warn(f"{device} failed to start.")
|
||||||
|
|
||||||
if is_alive() <= 0: # No running devices, probably invalid user config.
|
# if is_alive() <= 0: # No running devices, probably invalid user config.
|
||||||
raise EnvironmentError(
|
# raise EnvironmentError(
|
||||||
'ERROR: No active render devices! Please verify the "render_devices" value in config.json'
|
# 'ERROR: No active render devices! Please verify the "render_devices" value in config.json'
|
||||||
)
|
# )
|
||||||
|
|
||||||
log.debug(f"active devices: {get_devices()['active']}")
|
# log.debug(f"active devices: {get_devices()['active']}")
|
||||||
|
|
||||||
|
|
||||||
def shutdown_event(): # Signal render thread to close on shutdown
|
def shutdown_event(): # Signal render thread to close on shutdown
|
||||||
|
@ -624,7 +624,7 @@
|
|||||||
<script>
|
<script>
|
||||||
async function init() {
|
async function init() {
|
||||||
await initSettings()
|
await initSettings()
|
||||||
await getModels()
|
await getModels(false)
|
||||||
await getAppConfig()
|
await getAppConfig()
|
||||||
await loadUIPlugins()
|
await loadUIPlugins()
|
||||||
await loadModifiers()
|
await loadModifiers()
|
||||||
@ -640,6 +640,9 @@ async function init() {
|
|||||||
})
|
})
|
||||||
splashScreen()
|
splashScreen()
|
||||||
|
|
||||||
|
// load models again, but scan for malicious this time
|
||||||
|
await getModels(true)
|
||||||
|
|
||||||
// playSound()
|
// playSound()
|
||||||
}
|
}
|
||||||
|
|
||||||
|
@ -1,5 +1,5 @@
|
|||||||
from easydiffusion import model_manager, app, server
|
from easydiffusion import model_manager, app, server
|
||||||
from easydiffusion.server import server_api # required for uvicorn
|
from easydiffusion.server import server_api # required for uvicorn
|
||||||
|
|
||||||
# Init the app
|
# Init the app
|
||||||
model_manager.init()
|
model_manager.init()
|
||||||
@ -8,3 +8,5 @@ server.init()
|
|||||||
|
|
||||||
# start the browser ui
|
# start the browser ui
|
||||||
app.open_browser()
|
app.open_browser()
|
||||||
|
|
||||||
|
app.init_render_threads()
|
||||||
|
@ -1121,13 +1121,13 @@
|
|||||||
return systemInfo.hosts
|
return systemInfo.hosts
|
||||||
}
|
}
|
||||||
|
|
||||||
async function getModels() {
|
async function getModels(scanForMalicious = true) {
|
||||||
let models = {
|
let models = {
|
||||||
"stable-diffusion": [],
|
"stable-diffusion": [],
|
||||||
vae: [],
|
vae: [],
|
||||||
}
|
}
|
||||||
try {
|
try {
|
||||||
const res = await fetch("/get/models")
|
const res = await fetch("/get/models?scan_for_malicious=" + scanForMalicious)
|
||||||
if (!res.ok) {
|
if (!res.ok) {
|
||||||
console.error("Invalid response fetching models", res.statusText)
|
console.error("Invalid response fetching models", res.statusText)
|
||||||
return models
|
return models
|
||||||
|
@ -627,9 +627,9 @@ class ModelDropdown {
|
|||||||
}
|
}
|
||||||
|
|
||||||
/* (RE)LOAD THE MODELS */
|
/* (RE)LOAD THE MODELS */
|
||||||
async function getModels() {
|
async function getModels(scanForMalicious = true) {
|
||||||
try {
|
try {
|
||||||
modelsCache = await SD.getModels()
|
modelsCache = await SD.getModels(scanForMalicious)
|
||||||
modelsOptions = modelsCache["options"]
|
modelsOptions = modelsCache["options"]
|
||||||
if ("scan-error" in modelsCache) {
|
if ("scan-error" in modelsCache) {
|
||||||
// let previewPane = document.getElementById('tab-content-wrapper')
|
// let previewPane = document.getElementById('tab-content-wrapper')
|
||||||
@ -667,4 +667,4 @@ async function getModels() {
|
|||||||
}
|
}
|
||||||
|
|
||||||
// reload models button
|
// reload models button
|
||||||
document.querySelector("#reload-models").addEventListener("click", getModels)
|
document.querySelector("#reload-models").addEventListener("click", () => getModels())
|
||||||
|
Loading…
Reference in New Issue
Block a user