mirror of
https://github.com/easydiffusion/easydiffusion.git
synced 2025-01-27 00:29:38 +01:00
Use the sdkit model scan; Disable scan-per-load since we scan them before allowing them to be invoked
This commit is contained in:
parent
d8b79d8b5c
commit
d8543d1358
@ -1,12 +1,11 @@
|
||||
import os
|
||||
import picklescan.scanner
|
||||
|
||||
from easydiffusion import app, device_manager
|
||||
from easydiffusion.types import TaskData
|
||||
from easydiffusion.utils import log
|
||||
|
||||
from sdkit import Context
|
||||
from sdkit.models import load_model, unload_model, get_known_model_info
|
||||
from sdkit.models import load_model, unload_model, get_known_model_info, scan_model
|
||||
from sdkit.utils import hash_file_quick
|
||||
|
||||
KNOWN_MODEL_TYPES = ['stable-diffusion', 'vae', 'hypernetwork', 'gfpgan', 'realesrgan']
|
||||
@ -45,7 +44,7 @@ def load_default_models(context: Context):
|
||||
set_vram_optimizations(context)
|
||||
|
||||
# load mandatory models
|
||||
load_model(context, 'stable-diffusion')
|
||||
load_model(context, 'stable-diffusion', scan_model=False) # we've scanned them already
|
||||
load_model(context, 'vae')
|
||||
load_model(context, 'hypernetwork')
|
||||
|
||||
@ -113,7 +112,7 @@ def reload_models_if_necessary(context: Context, task_data: TaskData):
|
||||
context.model_paths[model_type] = model_path_in_req
|
||||
|
||||
action_fn = unload_model if context.model_paths[model_type] is None else load_model
|
||||
action_fn(context, model_type)
|
||||
action_fn(context, model_type, scan_model=False) # we've scanned them already
|
||||
|
||||
def resolve_model_paths(task_data: TaskData):
|
||||
task_data.use_stable_diffusion_model = resolve_model_to_use(task_data.use_stable_diffusion_model, model_type='stable-diffusion')
|
||||
@ -157,7 +156,7 @@ def make_model_folders():
|
||||
|
||||
def is_malicious_model(file_path):
|
||||
try:
|
||||
scan_result = picklescan.scanner.scan_file_path(file_path)
|
||||
scan_result = scan_model(file_path)
|
||||
if scan_result.issues_count > 0 or scan_result.infected_files > 0:
|
||||
log.warn(":warning: [bold red]Scan %s: %d scanned, %d issue, %d infected.[/bold red]" % (file_path, scan_result.scanned_files, scan_result.issues_count, scan_result.infected_files))
|
||||
return True
|
||||
|
Loading…
Reference in New Issue
Block a user