Improve responsiveness of UI startup by not waiting for render threads to start up before showing the UI. Errors while starting the render thread will be logged anyway, so there's no need to block the main thread for this

This commit is contained in:
cmdr2 2023-07-08 23:28:47 +05:30
parent 2bd1cceb24
commit 37e8158175
8 changed files with 63 additions and 49 deletions

View File

@ -96,6 +96,8 @@ def init():
# https://pytorch.org/docs/stable/storage.html # https://pytorch.org/docs/stable/storage.html
warnings.filterwarnings("ignore", category=UserWarning, message="TypedStorage is deprecated") warnings.filterwarnings("ignore", category=UserWarning, message="TypedStorage is deprecated")
def init_render_threads():
load_server_plugins() load_server_plugins()
update_render_threads() update_render_threads()
@ -279,6 +281,8 @@ def open_browser():
if ui.get("open_browser_on_start", True): if ui.get("open_browser_on_start", True):
import webbrowser import webbrowser
log.info("Opening browser..")
webbrowser.open(f"http://localhost:{port}") webbrowser.open(f"http://localhost:{port}")
Console().print( Console().print(

View File

@ -52,7 +52,6 @@ def init():
make_model_folders() make_model_folders()
migrate_legacy_model_location() # if necessary migrate_legacy_model_location() # if necessary
download_default_models_if_necessary() download_default_models_if_necessary()
getModels() # run this once, to cache the picklescan results
def load_default_models(context: Context): def load_default_models(context: Context):
@ -310,7 +309,7 @@ def is_malicious_model(file_path):
return False return False
def getModels(): def getModels(scan_for_malicious: bool = True):
models = { models = {
"options": { "options": {
"stable-diffusion": ["sd-v1-4"], "stable-diffusion": ["sd-v1-4"],
@ -343,9 +342,10 @@ def getModels():
mod_time = known_models[entry.path] if entry.path in known_models else -1 mod_time = known_models[entry.path] if entry.path in known_models else -1
if mod_time != mtime: if mod_time != mtime:
models_scanned += 1 models_scanned += 1
if is_malicious_model(entry.path): if scan_for_malicious and is_malicious_model(entry.path):
raise MaliciousModelException(entry.path) raise MaliciousModelException(entry.path)
known_models[entry.path] = mtime if scan_for_malicious:
known_models[entry.path] = mtime
tree.append(entry.name[: -len(matching_suffix)]) tree.append(entry.name[: -len(matching_suffix)])
elif entry.is_dir(): elif entry.is_dir():
scan = scan_directory(entry.path, suffixes, directoriesFirst=False) scan = scan_directory(entry.path, suffixes, directoriesFirst=False)
@ -365,9 +365,10 @@ def getModels():
try: try:
models["options"][model_type] = scan_directory(models_dir, model_extensions) models["options"][model_type] = scan_directory(models_dir, model_extensions)
except MaliciousModelException as e: except MaliciousModelException as e:
models["scan-error"] = e models["scan-error"] = str(e)
log.info(f"[green]Scanning all model folders for models...[/]") if scan_for_malicious:
log.info(f"[green]Scanning all model folders for models...[/]")
# custom models # custom models
listModels(model_type="stable-diffusion") listModels(model_type="stable-diffusion")
listModels(model_type="vae") listModels(model_type="vae")
@ -375,7 +376,7 @@ def getModels():
listModels(model_type="gfpgan") listModels(model_type="gfpgan")
listModels(model_type="lora") listModels(model_type="lora")
if models_scanned > 0: if scan_for_malicious and models_scanned > 0:
log.info(f"[green]Scanned {models_scanned} models. Nothing infected[/]") log.info(f"[green]Scanned {models_scanned} models. Nothing infected[/]")
return models return models

View File

@ -86,8 +86,8 @@ def init():
return set_app_config_internal(req) return set_app_config_internal(req)
@server_api.get("/get/{key:path}") @server_api.get("/get/{key:path}")
def read_web_data(key: str = None): def read_web_data(key: str = None, scan_for_malicious: bool = True):
return read_web_data_internal(key) return read_web_data_internal(key, scan_for_malicious=scan_for_malicious)
@server_api.get("/ping") # Get server and optionally session status. @server_api.get("/ping") # Get server and optionally session status.
def ping(session_id: str = None): def ping(session_id: str = None):
@ -179,7 +179,7 @@ def update_render_devices_in_config(config, render_devices):
config["render_devices"] = render_devices config["render_devices"] = render_devices
def read_web_data_internal(key: str = None): def read_web_data_internal(key: str = None, **kwargs):
if not key: # /get without parameters, stable-diffusion easter egg. if not key: # /get without parameters, stable-diffusion easter egg.
raise HTTPException(status_code=418, detail="StableDiffusion is drawing a teapot!") # HTTP418 I'm a teapot raise HTTPException(status_code=418, detail="StableDiffusion is drawing a teapot!") # HTTP418 I'm a teapot
elif key == "app_config": elif key == "app_config":
@ -198,7 +198,8 @@ def read_web_data_internal(key: str = None):
system_info["devices"]["config"] = config.get("render_devices", "auto") system_info["devices"]["config"] = config.get("render_devices", "auto")
return JSONResponse(system_info, headers=NOCACHE_HEADERS) return JSONResponse(system_info, headers=NOCACHE_HEADERS)
elif key == "models": elif key == "models":
return JSONResponse(model_manager.getModels(), headers=NOCACHE_HEADERS) scan_for_malicious = kwargs.get("scan_for_malicious", True)
return JSONResponse(model_manager.getModels(scan_for_malicious), headers=NOCACHE_HEADERS)
elif key == "modifiers": elif key == "modifiers":
return JSONResponse(app.get_image_modifiers(), headers=NOCACHE_HEADERS) return JSONResponse(app.get_image_modifiers(), headers=NOCACHE_HEADERS)
elif key == "ui_plugins": elif key == "ui_plugins":
@ -334,7 +335,8 @@ def get_image_internal(task_id: int, img_id: int):
except KeyError as e: except KeyError as e:
raise HTTPException(status_code=500, detail=str(e)) raise HTTPException(status_code=500, detail=str(e))
#---- Cloudflare Tunnel ----
# ---- Cloudflare Tunnel ----
class CloudflareTunnel: class CloudflareTunnel:
def __init__(self): def __init__(self):
config = app.getConfig() config = app.getConfig()
@ -357,23 +359,25 @@ class CloudflareTunnel:
else: else:
return None return None
cloudflare = CloudflareTunnel() cloudflare = CloudflareTunnel()
def start_cloudflare_tunnel_internal(req: dict): def start_cloudflare_tunnel_internal(req: dict):
try: try:
cloudflare.start() cloudflare.start()
log.info(f"- Started cloudflare tunnel. Using address: {cloudflare.address}") log.info(f"- Started cloudflare tunnel. Using address: {cloudflare.address}")
return JSONResponse({"address":cloudflare.address}) return JSONResponse({"address": cloudflare.address})
except Exception as e: except Exception as e:
log.error(str(e)) log.error(str(e))
log.error(traceback.format_exc()) log.error(traceback.format_exc())
return HTTPException(status_code=500, detail=str(e)) return HTTPException(status_code=500, detail=str(e))
def stop_cloudflare_tunnel_internal(req: dict): def stop_cloudflare_tunnel_internal(req: dict):
try: try:
cloudflare.stop() cloudflare.stop()
except Exception as e: except Exception as e:
log.error(str(e)) log.error(str(e))
log.error(traceback.format_exc()) log.error(traceback.format_exc())
return HTTPException(status_code=500, detail=str(e)) return HTTPException(status_code=500, detail=str(e))

View File

@ -473,16 +473,16 @@ def start_render_thread(device):
render_threads.append(rthread) render_threads.append(rthread)
finally: finally:
manager_lock.release() manager_lock.release()
timeout = DEVICE_START_TIMEOUT # timeout = DEVICE_START_TIMEOUT
while not rthread.is_alive() or not rthread in weak_thread_data or not "device" in weak_thread_data[rthread]: # while not rthread.is_alive() or not rthread in weak_thread_data or not "device" in weak_thread_data[rthread]:
if rthread in weak_thread_data and "error" in weak_thread_data[rthread]: # if rthread in weak_thread_data and "error" in weak_thread_data[rthread]:
log.error(f"{rthread}, {device}, error: {weak_thread_data[rthread]['error']}") # log.error(f"{rthread}, {device}, error: {weak_thread_data[rthread]['error']}")
return False # return False
if timeout <= 0: # if timeout <= 0:
return False # return False
timeout -= 1 # timeout -= 1
time.sleep(1) # time.sleep(1)
return True # return True
def stop_render_thread(device): def stop_render_thread(device):
@ -535,12 +535,12 @@ def update_render_threads(render_devices, active_devices):
if not start_render_thread(device): if not start_render_thread(device):
log.warn(f"{device} failed to start.") log.warn(f"{device} failed to start.")
if is_alive() <= 0: # No running devices, probably invalid user config. # if is_alive() <= 0: # No running devices, probably invalid user config.
raise EnvironmentError( # raise EnvironmentError(
'ERROR: No active render devices! Please verify the "render_devices" value in config.json' # 'ERROR: No active render devices! Please verify the "render_devices" value in config.json'
) # )
log.debug(f"active devices: {get_devices()['active']}") # log.debug(f"active devices: {get_devices()['active']}")
def shutdown_event(): # Signal render thread to close on shutdown def shutdown_event(): # Signal render thread to close on shutdown

View File

@ -624,7 +624,7 @@
<script> <script>
async function init() { async function init() {
await initSettings() await initSettings()
await getModels() await getModels(false)
await getAppConfig() await getAppConfig()
await loadUIPlugins() await loadUIPlugins()
await loadModifiers() await loadModifiers()
@ -640,6 +640,9 @@ async function init() {
}) })
splashScreen() splashScreen()
// load models again, but scan for malicious this time
await getModels(true)
// playSound() // playSound()
} }

View File

@ -1,5 +1,5 @@
from easydiffusion import model_manager, app, server from easydiffusion import model_manager, app, server
from easydiffusion.server import server_api # required for uvicorn from easydiffusion.server import server_api # required for uvicorn
# Init the app # Init the app
model_manager.init() model_manager.init()
@ -8,3 +8,5 @@ server.init()
# start the browser ui # start the browser ui
app.open_browser() app.open_browser()
app.init_render_threads()

View File

@ -1121,13 +1121,13 @@
return systemInfo.hosts return systemInfo.hosts
} }
async function getModels() { async function getModels(scanForMalicious = true) {
let models = { let models = {
"stable-diffusion": [], "stable-diffusion": [],
vae: [], vae: [],
} }
try { try {
const res = await fetch("/get/models") const res = await fetch("/get/models?scan_for_malicious=" + scanForMalicious)
if (!res.ok) { if (!res.ok) {
console.error("Invalid response fetching models", res.statusText) console.error("Invalid response fetching models", res.statusText)
return models return models

View File

@ -627,9 +627,9 @@ class ModelDropdown {
} }
/* (RE)LOAD THE MODELS */ /* (RE)LOAD THE MODELS */
async function getModels() { async function getModels(scanForMalicious = true) {
try { try {
modelsCache = await SD.getModels() modelsCache = await SD.getModels(scanForMalicious)
modelsOptions = modelsCache["options"] modelsOptions = modelsCache["options"]
if ("scan-error" in modelsCache) { if ("scan-error" in modelsCache) {
// let previewPane = document.getElementById('tab-content-wrapper') // let previewPane = document.getElementById('tab-content-wrapper')
@ -667,4 +667,4 @@ async function getModels() {
} }
// reload models button // reload models button
document.querySelector("#reload-models").addEventListener("click", getModels) document.querySelector("#reload-models").addEventListener("click", () => getModels())