From 2de96d4dc9767f6514a9576a8bdf492f11e43474 Mon Sep 17 00:00:00 2001 From: Marc-Andre Ferland Date: Tue, 22 Nov 2022 11:24:36 -0500 Subject: [PATCH] Scan model once as start, then only if changed. --- ui/server.py | 18 +++++++++++++----- 1 file changed, 13 insertions(+), 5 deletions(-) diff --git a/ui/server.py b/ui/server.py index 61635f18..d03d3113 100644 --- a/ui/server.py +++ b/ui/server.py @@ -7,7 +7,6 @@ import traceback import sys import os -import picklescan.scanner import rich SD_DIR = os.getcwd() @@ -221,6 +220,7 @@ async def setAppConfig(req : SetAppConfigRequest): def is_malicious_model(file_path): try: + import picklescan.scanner scan_result = picklescan.scanner.scan_file_path(file_path) if scan_result.issues_count > 0 or scan_result.infected_files > 0: rich.print(":warning: [bold red]Scan %s: %d scanned, %d issue, %d infected.[/bold red]" % (file_path, scan_result.scanned_files, scan_result.issues_count, scan_result.infected_files)) @@ -230,9 +230,9 @@ def is_malicious_model(file_path): return False except Exception as e: print('error while scanning', file_path, 'error:', e) - return False +known_models = {} def getModels(): models = { 'active': { @@ -255,9 +255,14 @@ def getModels(): if not file.endswith(model_extension): continue - if is_malicious_model(os.path.join(models_dir, file)): - models['scan-error'] = file - return + model_path = os.path.join(models_dir, file) + mtime = os.path.getmtime(model_path) + mod_time = known_models[model_path] if model_path in known_models else -1 + if mod_time != mtime: + if is_malicious_model(model_path): + models['scan-error'] = file + return + known_models[model_path] = mtime model_name = file[:-len(model_extension)] models['options'][model_type].append(model_name) @@ -435,6 +440,9 @@ class LogSuppressFilter(logging.Filter): return True logging.getLogger('uvicorn.access').addFilter(LogSuppressFilter()) +# Check models and prepare cache for UI open +getModels() + # Start the task_manager task_manager.default_model_to_load = resolve_ckpt_to_use() task_manager.default_vae_to_load = resolve_vae_to_use()