Merge pull request #530 from madrang/list-models

Scan model once as start, then only if changed.
This commit is contained in:
cmdr2 2022-11-30 13:37:27 +05:30 committed by GitHub
commit efd9a22bb5
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23

View File

@ -7,7 +7,6 @@ import traceback
import sys import sys
import os import os
import picklescan.scanner
import rich import rich
SD_DIR = os.getcwd() SD_DIR = os.getcwd()
@ -235,6 +234,7 @@ async def setAppConfig(req : SetAppConfigRequest):
def is_malicious_model(file_path): def is_malicious_model(file_path):
try: try:
import picklescan.scanner
scan_result = picklescan.scanner.scan_file_path(file_path) scan_result = picklescan.scanner.scan_file_path(file_path)
if scan_result.issues_count > 0 or scan_result.infected_files > 0: if scan_result.issues_count > 0 or scan_result.infected_files > 0:
rich.print(":warning: [bold red]Scan %s: %d scanned, %d issue, %d infected.[/bold red]" % (file_path, scan_result.scanned_files, scan_result.issues_count, scan_result.infected_files)) rich.print(":warning: [bold red]Scan %s: %d scanned, %d issue, %d infected.[/bold red]" % (file_path, scan_result.scanned_files, scan_result.issues_count, scan_result.infected_files))
@ -244,9 +244,9 @@ def is_malicious_model(file_path):
return False return False
except Exception as e: except Exception as e:
print('error while scanning', file_path, 'error:', e) print('error while scanning', file_path, 'error:', e)
return False return False
known_models = {}
def getModels(): def getModels():
models = { models = {
'active': { 'active': {
@ -269,9 +269,14 @@ def getModels():
if not file.endswith(model_extension): if not file.endswith(model_extension):
continue continue
if is_malicious_model(os.path.join(models_dir, file)): model_path = os.path.join(models_dir, file)
models['scan-error'] = file mtime = os.path.getmtime(model_path)
return mod_time = known_models[model_path] if model_path in known_models else -1
if mod_time != mtime:
if is_malicious_model(model_path):
models['scan-error'] = file
return
known_models[model_path] = mtime
model_name = file[:-len(model_extension)] model_name = file[:-len(model_extension)]
models['options'][model_type].append(model_name) models['options'][model_type].append(model_name)
@ -449,6 +454,9 @@ class LogSuppressFilter(logging.Filter):
return True return True
logging.getLogger('uvicorn.access').addFilter(LogSuppressFilter()) logging.getLogger('uvicorn.access').addFilter(LogSuppressFilter())
# Check models and prepare cache for UI open
getModels()
# Start the task_manager # Start the task_manager
task_manager.default_model_to_load = resolve_ckpt_to_use() task_manager.default_model_to_load = resolve_ckpt_to_use()
task_manager.default_vae_to_load = resolve_vae_to_use() task_manager.default_vae_to_load = resolve_vae_to_use()