mirror of
https://github.com/easydiffusion/easydiffusion.git
synced 2025-02-23 22:01:25 +01:00
Improve responsiveness of UI startup by not waiting for render threads to start up before showing the UI. Errors while starting the render thread will be logged anyway, so there's no need to block the main thread for this
This commit is contained in:
parent
2bd1cceb24
commit
37e8158175
@ -96,6 +96,8 @@ def init():
|
||||
# https://pytorch.org/docs/stable/storage.html
|
||||
warnings.filterwarnings("ignore", category=UserWarning, message="TypedStorage is deprecated")
|
||||
|
||||
|
||||
def init_render_threads():
|
||||
load_server_plugins()
|
||||
|
||||
update_render_threads()
|
||||
@ -279,6 +281,8 @@ def open_browser():
|
||||
if ui.get("open_browser_on_start", True):
|
||||
import webbrowser
|
||||
|
||||
log.info("Opening browser..")
|
||||
|
||||
webbrowser.open(f"http://localhost:{port}")
|
||||
|
||||
Console().print(
|
||||
|
@ -52,7 +52,6 @@ def init():
|
||||
make_model_folders()
|
||||
migrate_legacy_model_location() # if necessary
|
||||
download_default_models_if_necessary()
|
||||
getModels() # run this once, to cache the picklescan results
|
||||
|
||||
|
||||
def load_default_models(context: Context):
|
||||
@ -310,7 +309,7 @@ def is_malicious_model(file_path):
|
||||
return False
|
||||
|
||||
|
||||
def getModels():
|
||||
def getModels(scan_for_malicious: bool = True):
|
||||
models = {
|
||||
"options": {
|
||||
"stable-diffusion": ["sd-v1-4"],
|
||||
@ -343,9 +342,10 @@ def getModels():
|
||||
mod_time = known_models[entry.path] if entry.path in known_models else -1
|
||||
if mod_time != mtime:
|
||||
models_scanned += 1
|
||||
if is_malicious_model(entry.path):
|
||||
if scan_for_malicious and is_malicious_model(entry.path):
|
||||
raise MaliciousModelException(entry.path)
|
||||
known_models[entry.path] = mtime
|
||||
if scan_for_malicious:
|
||||
known_models[entry.path] = mtime
|
||||
tree.append(entry.name[: -len(matching_suffix)])
|
||||
elif entry.is_dir():
|
||||
scan = scan_directory(entry.path, suffixes, directoriesFirst=False)
|
||||
@ -365,9 +365,10 @@ def getModels():
|
||||
try:
|
||||
models["options"][model_type] = scan_directory(models_dir, model_extensions)
|
||||
except MaliciousModelException as e:
|
||||
models["scan-error"] = e
|
||||
models["scan-error"] = str(e)
|
||||
|
||||
log.info(f"[green]Scanning all model folders for models...[/]")
|
||||
if scan_for_malicious:
|
||||
log.info(f"[green]Scanning all model folders for models...[/]")
|
||||
# custom models
|
||||
listModels(model_type="stable-diffusion")
|
||||
listModels(model_type="vae")
|
||||
@ -375,7 +376,7 @@ def getModels():
|
||||
listModels(model_type="gfpgan")
|
||||
listModels(model_type="lora")
|
||||
|
||||
if models_scanned > 0:
|
||||
if scan_for_malicious and models_scanned > 0:
|
||||
log.info(f"[green]Scanned {models_scanned} models. Nothing infected[/]")
|
||||
|
||||
return models
|
||||
|
@ -86,8 +86,8 @@ def init():
|
||||
return set_app_config_internal(req)
|
||||
|
||||
@server_api.get("/get/{key:path}")
|
||||
def read_web_data(key: str = None):
|
||||
return read_web_data_internal(key)
|
||||
def read_web_data(key: str = None, scan_for_malicious: bool = True):
|
||||
return read_web_data_internal(key, scan_for_malicious=scan_for_malicious)
|
||||
|
||||
@server_api.get("/ping") # Get server and optionally session status.
|
||||
def ping(session_id: str = None):
|
||||
@ -179,7 +179,7 @@ def update_render_devices_in_config(config, render_devices):
|
||||
config["render_devices"] = render_devices
|
||||
|
||||
|
||||
def read_web_data_internal(key: str = None):
|
||||
def read_web_data_internal(key: str = None, **kwargs):
|
||||
if not key: # /get without parameters, stable-diffusion easter egg.
|
||||
raise HTTPException(status_code=418, detail="StableDiffusion is drawing a teapot!") # HTTP418 I'm a teapot
|
||||
elif key == "app_config":
|
||||
@ -198,7 +198,8 @@ def read_web_data_internal(key: str = None):
|
||||
system_info["devices"]["config"] = config.get("render_devices", "auto")
|
||||
return JSONResponse(system_info, headers=NOCACHE_HEADERS)
|
||||
elif key == "models":
|
||||
return JSONResponse(model_manager.getModels(), headers=NOCACHE_HEADERS)
|
||||
scan_for_malicious = kwargs.get("scan_for_malicious", True)
|
||||
return JSONResponse(model_manager.getModels(scan_for_malicious), headers=NOCACHE_HEADERS)
|
||||
elif key == "modifiers":
|
||||
return JSONResponse(app.get_image_modifiers(), headers=NOCACHE_HEADERS)
|
||||
elif key == "ui_plugins":
|
||||
@ -334,7 +335,8 @@ def get_image_internal(task_id: int, img_id: int):
|
||||
except KeyError as e:
|
||||
raise HTTPException(status_code=500, detail=str(e))
|
||||
|
||||
#---- Cloudflare Tunnel ----
|
||||
|
||||
# ---- Cloudflare Tunnel ----
|
||||
class CloudflareTunnel:
|
||||
def __init__(self):
|
||||
config = app.getConfig()
|
||||
@ -357,23 +359,25 @@ class CloudflareTunnel:
|
||||
else:
|
||||
return None
|
||||
|
||||
|
||||
cloudflare = CloudflareTunnel()
|
||||
|
||||
|
||||
def start_cloudflare_tunnel_internal(req: dict):
|
||||
try:
|
||||
cloudflare.start()
|
||||
log.info(f"- Started cloudflare tunnel. Using address: {cloudflare.address}")
|
||||
return JSONResponse({"address":cloudflare.address})
|
||||
except Exception as e:
|
||||
log.error(str(e))
|
||||
log.error(traceback.format_exc())
|
||||
return HTTPException(status_code=500, detail=str(e))
|
||||
try:
|
||||
cloudflare.start()
|
||||
log.info(f"- Started cloudflare tunnel. Using address: {cloudflare.address}")
|
||||
return JSONResponse({"address": cloudflare.address})
|
||||
except Exception as e:
|
||||
log.error(str(e))
|
||||
log.error(traceback.format_exc())
|
||||
return HTTPException(status_code=500, detail=str(e))
|
||||
|
||||
|
||||
def stop_cloudflare_tunnel_internal(req: dict):
|
||||
try:
|
||||
cloudflare.stop()
|
||||
except Exception as e:
|
||||
log.error(str(e))
|
||||
log.error(traceback.format_exc())
|
||||
return HTTPException(status_code=500, detail=str(e))
|
||||
|
||||
try:
|
||||
cloudflare.stop()
|
||||
except Exception as e:
|
||||
log.error(str(e))
|
||||
log.error(traceback.format_exc())
|
||||
return HTTPException(status_code=500, detail=str(e))
|
||||
|
@ -473,16 +473,16 @@ def start_render_thread(device):
|
||||
render_threads.append(rthread)
|
||||
finally:
|
||||
manager_lock.release()
|
||||
timeout = DEVICE_START_TIMEOUT
|
||||
while not rthread.is_alive() or not rthread in weak_thread_data or not "device" in weak_thread_data[rthread]:
|
||||
if rthread in weak_thread_data and "error" in weak_thread_data[rthread]:
|
||||
log.error(f"{rthread}, {device}, error: {weak_thread_data[rthread]['error']}")
|
||||
return False
|
||||
if timeout <= 0:
|
||||
return False
|
||||
timeout -= 1
|
||||
time.sleep(1)
|
||||
return True
|
||||
# timeout = DEVICE_START_TIMEOUT
|
||||
# while not rthread.is_alive() or not rthread in weak_thread_data or not "device" in weak_thread_data[rthread]:
|
||||
# if rthread in weak_thread_data and "error" in weak_thread_data[rthread]:
|
||||
# log.error(f"{rthread}, {device}, error: {weak_thread_data[rthread]['error']}")
|
||||
# return False
|
||||
# if timeout <= 0:
|
||||
# return False
|
||||
# timeout -= 1
|
||||
# time.sleep(1)
|
||||
# return True
|
||||
|
||||
|
||||
def stop_render_thread(device):
|
||||
@ -535,12 +535,12 @@ def update_render_threads(render_devices, active_devices):
|
||||
if not start_render_thread(device):
|
||||
log.warn(f"{device} failed to start.")
|
||||
|
||||
if is_alive() <= 0: # No running devices, probably invalid user config.
|
||||
raise EnvironmentError(
|
||||
'ERROR: No active render devices! Please verify the "render_devices" value in config.json'
|
||||
)
|
||||
# if is_alive() <= 0: # No running devices, probably invalid user config.
|
||||
# raise EnvironmentError(
|
||||
# 'ERROR: No active render devices! Please verify the "render_devices" value in config.json'
|
||||
# )
|
||||
|
||||
log.debug(f"active devices: {get_devices()['active']}")
|
||||
# log.debug(f"active devices: {get_devices()['active']}")
|
||||
|
||||
|
||||
def shutdown_event(): # Signal render thread to close on shutdown
|
||||
|
@ -624,7 +624,7 @@
|
||||
<script>
|
||||
async function init() {
|
||||
await initSettings()
|
||||
await getModels()
|
||||
await getModels(false)
|
||||
await getAppConfig()
|
||||
await loadUIPlugins()
|
||||
await loadModifiers()
|
||||
@ -640,6 +640,9 @@ async function init() {
|
||||
})
|
||||
splashScreen()
|
||||
|
||||
// load models again, but scan for malicious this time
|
||||
await getModels(true)
|
||||
|
||||
// playSound()
|
||||
}
|
||||
|
||||
|
@ -1,5 +1,5 @@
|
||||
from easydiffusion import model_manager, app, server
|
||||
from easydiffusion.server import server_api # required for uvicorn
|
||||
from easydiffusion.server import server_api # required for uvicorn
|
||||
|
||||
# Init the app
|
||||
model_manager.init()
|
||||
@ -8,3 +8,5 @@ server.init()
|
||||
|
||||
# start the browser ui
|
||||
app.open_browser()
|
||||
|
||||
app.init_render_threads()
|
||||
|
@ -1121,13 +1121,13 @@
|
||||
return systemInfo.hosts
|
||||
}
|
||||
|
||||
async function getModels() {
|
||||
async function getModels(scanForMalicious = true) {
|
||||
let models = {
|
||||
"stable-diffusion": [],
|
||||
vae: [],
|
||||
}
|
||||
try {
|
||||
const res = await fetch("/get/models")
|
||||
const res = await fetch("/get/models?scan_for_malicious=" + scanForMalicious)
|
||||
if (!res.ok) {
|
||||
console.error("Invalid response fetching models", res.statusText)
|
||||
return models
|
||||
|
@ -627,9 +627,9 @@ class ModelDropdown {
|
||||
}
|
||||
|
||||
/* (RE)LOAD THE MODELS */
|
||||
async function getModels() {
|
||||
async function getModels(scanForMalicious = true) {
|
||||
try {
|
||||
modelsCache = await SD.getModels()
|
||||
modelsCache = await SD.getModels(scanForMalicious)
|
||||
modelsOptions = modelsCache["options"]
|
||||
if ("scan-error" in modelsCache) {
|
||||
// let previewPane = document.getElementById('tab-content-wrapper')
|
||||
@ -667,4 +667,4 @@ async function getModels() {
|
||||
}
|
||||
|
||||
// reload models button
|
||||
document.querySelector("#reload-models").addEventListener("click", getModels)
|
||||
document.querySelector("#reload-models").addEventListener("click", () => getModels())
|
||||
|
Loading…
Reference in New Issue
Block a user