Improve responsiveness of UI startup by not waiting for render threads to start up before showing the UI. Errors while starting the render thread will be logged anyway, so there's no need to block the main thread for this

This commit is contained in:
cmdr2 2023-07-08 23:28:47 +05:30
parent 2bd1cceb24
commit 37e8158175
8 changed files with 63 additions and 49 deletions

View File

@ -96,6 +96,8 @@ def init():
# https://pytorch.org/docs/stable/storage.html
warnings.filterwarnings("ignore", category=UserWarning, message="TypedStorage is deprecated")
def init_render_threads():
load_server_plugins()
update_render_threads()
@ -279,6 +281,8 @@ def open_browser():
if ui.get("open_browser_on_start", True):
import webbrowser
log.info("Opening browser..")
webbrowser.open(f"http://localhost:{port}")
Console().print(

View File

@ -52,7 +52,6 @@ def init():
make_model_folders()
migrate_legacy_model_location() # if necessary
download_default_models_if_necessary()
getModels() # run this once, to cache the picklescan results
def load_default_models(context: Context):
@ -310,7 +309,7 @@ def is_malicious_model(file_path):
return False
def getModels():
def getModels(scan_for_malicious: bool = True):
models = {
"options": {
"stable-diffusion": ["sd-v1-4"],
@ -343,9 +342,10 @@ def getModels():
mod_time = known_models[entry.path] if entry.path in known_models else -1
if mod_time != mtime:
models_scanned += 1
if is_malicious_model(entry.path):
if scan_for_malicious and is_malicious_model(entry.path):
raise MaliciousModelException(entry.path)
known_models[entry.path] = mtime
if scan_for_malicious:
known_models[entry.path] = mtime
tree.append(entry.name[: -len(matching_suffix)])
elif entry.is_dir():
scan = scan_directory(entry.path, suffixes, directoriesFirst=False)
@ -365,9 +365,10 @@ def getModels():
try:
models["options"][model_type] = scan_directory(models_dir, model_extensions)
except MaliciousModelException as e:
models["scan-error"] = e
models["scan-error"] = str(e)
log.info(f"[green]Scanning all model folders for models...[/]")
if scan_for_malicious:
log.info(f"[green]Scanning all model folders for models...[/]")
# custom models
listModels(model_type="stable-diffusion")
listModels(model_type="vae")
@ -375,7 +376,7 @@ def getModels():
listModels(model_type="gfpgan")
listModels(model_type="lora")
if models_scanned > 0:
if scan_for_malicious and models_scanned > 0:
log.info(f"[green]Scanned {models_scanned} models. Nothing infected[/]")
return models

View File

@ -86,8 +86,8 @@ def init():
return set_app_config_internal(req)
@server_api.get("/get/{key:path}")
def read_web_data(key: str = None):
return read_web_data_internal(key)
def read_web_data(key: str = None, scan_for_malicious: bool = True):
return read_web_data_internal(key, scan_for_malicious=scan_for_malicious)
@server_api.get("/ping") # Get server and optionally session status.
def ping(session_id: str = None):
@ -179,7 +179,7 @@ def update_render_devices_in_config(config, render_devices):
config["render_devices"] = render_devices
def read_web_data_internal(key: str = None):
def read_web_data_internal(key: str = None, **kwargs):
if not key: # /get without parameters, stable-diffusion easter egg.
raise HTTPException(status_code=418, detail="StableDiffusion is drawing a teapot!") # HTTP418 I'm a teapot
elif key == "app_config":
@ -198,7 +198,8 @@ def read_web_data_internal(key: str = None):
system_info["devices"]["config"] = config.get("render_devices", "auto")
return JSONResponse(system_info, headers=NOCACHE_HEADERS)
elif key == "models":
return JSONResponse(model_manager.getModels(), headers=NOCACHE_HEADERS)
scan_for_malicious = kwargs.get("scan_for_malicious", True)
return JSONResponse(model_manager.getModels(scan_for_malicious), headers=NOCACHE_HEADERS)
elif key == "modifiers":
return JSONResponse(app.get_image_modifiers(), headers=NOCACHE_HEADERS)
elif key == "ui_plugins":
@ -334,7 +335,8 @@ def get_image_internal(task_id: int, img_id: int):
except KeyError as e:
raise HTTPException(status_code=500, detail=str(e))
#---- Cloudflare Tunnel ----
# ---- Cloudflare Tunnel ----
class CloudflareTunnel:
def __init__(self):
config = app.getConfig()
@ -357,23 +359,25 @@ class CloudflareTunnel:
else:
return None
cloudflare = CloudflareTunnel()
def start_cloudflare_tunnel_internal(req: dict):
try:
cloudflare.start()
log.info(f"- Started cloudflare tunnel. Using address: {cloudflare.address}")
return JSONResponse({"address":cloudflare.address})
except Exception as e:
log.error(str(e))
log.error(traceback.format_exc())
return HTTPException(status_code=500, detail=str(e))
try:
cloudflare.start()
log.info(f"- Started cloudflare tunnel. Using address: {cloudflare.address}")
return JSONResponse({"address": cloudflare.address})
except Exception as e:
log.error(str(e))
log.error(traceback.format_exc())
return HTTPException(status_code=500, detail=str(e))
def stop_cloudflare_tunnel_internal(req: dict):
try:
cloudflare.stop()
except Exception as e:
log.error(str(e))
log.error(traceback.format_exc())
return HTTPException(status_code=500, detail=str(e))
try:
cloudflare.stop()
except Exception as e:
log.error(str(e))
log.error(traceback.format_exc())
return HTTPException(status_code=500, detail=str(e))

View File

@ -473,16 +473,16 @@ def start_render_thread(device):
render_threads.append(rthread)
finally:
manager_lock.release()
timeout = DEVICE_START_TIMEOUT
while not rthread.is_alive() or not rthread in weak_thread_data or not "device" in weak_thread_data[rthread]:
if rthread in weak_thread_data and "error" in weak_thread_data[rthread]:
log.error(f"{rthread}, {device}, error: {weak_thread_data[rthread]['error']}")
return False
if timeout <= 0:
return False
timeout -= 1
time.sleep(1)
return True
# timeout = DEVICE_START_TIMEOUT
# while not rthread.is_alive() or not rthread in weak_thread_data or not "device" in weak_thread_data[rthread]:
# if rthread in weak_thread_data and "error" in weak_thread_data[rthread]:
# log.error(f"{rthread}, {device}, error: {weak_thread_data[rthread]['error']}")
# return False
# if timeout <= 0:
# return False
# timeout -= 1
# time.sleep(1)
# return True
def stop_render_thread(device):
@ -535,12 +535,12 @@ def update_render_threads(render_devices, active_devices):
if not start_render_thread(device):
log.warn(f"{device} failed to start.")
if is_alive() <= 0: # No running devices, probably invalid user config.
raise EnvironmentError(
'ERROR: No active render devices! Please verify the "render_devices" value in config.json'
)
# if is_alive() <= 0: # No running devices, probably invalid user config.
# raise EnvironmentError(
# 'ERROR: No active render devices! Please verify the "render_devices" value in config.json'
# )
log.debug(f"active devices: {get_devices()['active']}")
# log.debug(f"active devices: {get_devices()['active']}")
def shutdown_event(): # Signal render thread to close on shutdown

View File

@ -624,7 +624,7 @@
<script>
async function init() {
await initSettings()
await getModels()
await getModels(false)
await getAppConfig()
await loadUIPlugins()
await loadModifiers()
@ -640,6 +640,9 @@ async function init() {
})
splashScreen()
// load models again, but scan for malicious this time
await getModels(true)
// playSound()
}

View File

@ -1,5 +1,5 @@
from easydiffusion import model_manager, app, server
from easydiffusion.server import server_api # required for uvicorn
from easydiffusion.server import server_api # required for uvicorn
# Init the app
model_manager.init()
@ -8,3 +8,5 @@ server.init()
# start the browser ui
app.open_browser()
app.init_render_threads()

View File

@ -1121,13 +1121,13 @@
return systemInfo.hosts
}
async function getModels() {
async function getModels(scanForMalicious = true) {
let models = {
"stable-diffusion": [],
vae: [],
}
try {
const res = await fetch("/get/models")
const res = await fetch("/get/models?scan_for_malicious=" + scanForMalicious)
if (!res.ok) {
console.error("Invalid response fetching models", res.statusText)
return models

View File

@ -627,9 +627,9 @@ class ModelDropdown {
}
/* (RE)LOAD THE MODELS */
async function getModels() {
async function getModels(scanForMalicious = true) {
try {
modelsCache = await SD.getModels()
modelsCache = await SD.getModels(scanForMalicious)
modelsOptions = modelsCache["options"]
if ("scan-error" in modelsCache) {
// let previewPane = document.getElementById('tab-content-wrapper')
@ -667,4 +667,4 @@ async function getModels() {
}
// reload models button
document.querySelector("#reload-models").addEventListener("click", getModels)
document.querySelector("#reload-models").addEventListener("click", () => getModels())