forked from extern/easydiffusion
Scan only the model files (check by extension), minor refactoring of the
scanning code
This commit is contained in:
parent
87c6a54634
commit
f7af259576
44
ui/server.py
44
ui/server.py
@ -23,6 +23,9 @@ USER_UI_PLUGINS_DIR = os.path.abspath(os.path.join(SD_DIR, '..', 'plugins', 'ui'
|
|||||||
CORE_UI_PLUGINS_DIR = os.path.abspath(os.path.join(SD_UI_DIR, 'plugins', 'ui'))
|
CORE_UI_PLUGINS_DIR = os.path.abspath(os.path.join(SD_UI_DIR, 'plugins', 'ui'))
|
||||||
UI_PLUGINS_SOURCES = ((CORE_UI_PLUGINS_DIR, 'core'), (USER_UI_PLUGINS_DIR, 'user'))
|
UI_PLUGINS_SOURCES = ((CORE_UI_PLUGINS_DIR, 'core'), (USER_UI_PLUGINS_DIR, 'user'))
|
||||||
|
|
||||||
|
STABLE_DIFFUSION_MODEL_EXTENSIONS = ['.ckpt']
|
||||||
|
VAE_MODEL_EXTENSIONS = ['.vae.pt', '.ckpt']
|
||||||
|
|
||||||
OUTPUT_DIRNAME = "Stable Diffusion UI" # in the user's home folder
|
OUTPUT_DIRNAME = "Stable Diffusion UI" # in the user's home folder
|
||||||
TASK_TTL = 15 * 60 # Discard last session's task timeout
|
TASK_TTL = 15 * 60 # Discard last session's task timeout
|
||||||
APP_CONFIG_DEFAULTS = {
|
APP_CONFIG_DEFAULTS = {
|
||||||
@ -162,11 +165,11 @@ def resolve_model_to_use(model_name:str, model_type:str, model_dir:str, model_ex
|
|||||||
raise Exception('No valid models found.')
|
raise Exception('No valid models found.')
|
||||||
|
|
||||||
def resolve_ckpt_to_use(model_name:str=None):
|
def resolve_ckpt_to_use(model_name:str=None):
|
||||||
return resolve_model_to_use(model_name, model_type='stable-diffusion', model_dir='stable-diffusion', model_extensions=['.ckpt'], default_models=APP_CONFIG_DEFAULT_MODELS)
|
return resolve_model_to_use(model_name, model_type='stable-diffusion', model_dir='stable-diffusion', model_extensions=STABLE_DIFFUSION_MODEL_EXTENSIONS, default_models=APP_CONFIG_DEFAULT_MODELS)
|
||||||
|
|
||||||
def resolve_vae_to_use(model_name:str=None):
|
def resolve_vae_to_use(model_name:str=None):
|
||||||
try:
|
try:
|
||||||
return resolve_model_to_use(model_name, model_type='vae', model_dir='vae', model_extensions=['.vae.pt', '.ckpt'], default_models=[])
|
return resolve_model_to_use(model_name, model_type='vae', model_dir='vae', model_extensions=VAE_MODEL_EXTENSIONS, default_models=[])
|
||||||
except:
|
except:
|
||||||
return None
|
return None
|
||||||
|
|
||||||
@ -198,6 +201,20 @@ async def setAppConfig(req : SetAppConfigRequest):
|
|||||||
print(traceback.format_exc())
|
print(traceback.format_exc())
|
||||||
raise HTTPException(status_code=500, detail=str(e))
|
raise HTTPException(status_code=500, detail=str(e))
|
||||||
|
|
||||||
|
def is_malicious_model(file_path):
|
||||||
|
try:
|
||||||
|
scan_result = picklescan.scanner.scan_file_path(file_path)
|
||||||
|
if scan_result.issues_count > 0 or scan_result.infected_files > 0:
|
||||||
|
rich.print(":warning: [bold red]Scan %s: %d scanned, %d issue, %d infected.[/bold red]" % (file_path, scan_result.scanned_files, scan_result.issues_count, scan_result.infected_files))
|
||||||
|
return True
|
||||||
|
else:
|
||||||
|
rich.print("Scan %s: [green]%d scanned, %d issue, %d infected.[/green]" % (file_path, scan_result.scanned_files, scan_result.issues_count, scan_result.infected_files))
|
||||||
|
return False
|
||||||
|
except Exception as e:
|
||||||
|
print('error while scanning', file_path, 'error:', e)
|
||||||
|
|
||||||
|
return False
|
||||||
|
|
||||||
def getModels():
|
def getModels():
|
||||||
models = {
|
models = {
|
||||||
'active': {
|
'active': {
|
||||||
@ -216,19 +233,14 @@ def getModels():
|
|||||||
os.makedirs(models_dir)
|
os.makedirs(models_dir)
|
||||||
|
|
||||||
for file in os.listdir(models_dir):
|
for file in os.listdir(models_dir):
|
||||||
try:
|
|
||||||
scan_result = picklescan.scanner.scan_file_path( os.path.join(models_dir, file))
|
|
||||||
if ( scan_result.issues_count >0 or scan_result.infected_files >0):
|
|
||||||
rich.print(":warning: [bold red]Scan %s: %d scanned, %d issue, %d infected.[/bold red]" % ( file, scan_result.scanned_files, scan_result.issues_count, scan_result.infected_files) )
|
|
||||||
models['scan-error'] = file
|
|
||||||
return models
|
|
||||||
else:
|
|
||||||
rich.print("Scan %s: [green]%d scanned, %d issue, %d infected.[/green]" % ( file, scan_result.scanned_files, scan_result.issues_count, scan_result.infected_files ) )
|
|
||||||
except Exception as e:
|
|
||||||
print('error while scanning', os.path.join(models_dir, file), 'error:', e)
|
|
||||||
|
|
||||||
for model_extension in model_extensions:
|
for model_extension in model_extensions:
|
||||||
if file.endswith(model_extension):
|
if not file.endswith(model_extension):
|
||||||
|
continue
|
||||||
|
|
||||||
|
if is_malicious_model(os.path.join(models_dir, file)):
|
||||||
|
models['scan-error'] = file
|
||||||
|
return
|
||||||
|
|
||||||
model_name = file[:-len(model_extension)]
|
model_name = file[:-len(model_extension)]
|
||||||
models['options'][model_type].append(model_name)
|
models['options'][model_type].append(model_name)
|
||||||
|
|
||||||
@ -236,8 +248,8 @@ def getModels():
|
|||||||
models['options'][model_type].sort()
|
models['options'][model_type].sort()
|
||||||
|
|
||||||
# custom models
|
# custom models
|
||||||
listModels(models_dirname='stable-diffusion', model_type='stable-diffusion', model_extensions=['.ckpt'])
|
listModels(models_dirname='stable-diffusion', model_type='stable-diffusion', model_extensions=STABLE_DIFFUSION_MODEL_EXTENSIONS)
|
||||||
listModels(models_dirname='vae', model_type='vae', model_extensions=['.vae.pt', '.ckpt'])
|
listModels(models_dirname='vae', model_type='vae', model_extensions=VAE_MODEL_EXTENSIONS)
|
||||||
|
|
||||||
# legacy
|
# legacy
|
||||||
custom_weight_path = os.path.join(SD_DIR, 'custom-model.ckpt')
|
custom_weight_path = os.path.join(SD_DIR, 'custom-model.ckpt')
|
||||||
|
Loading…
Reference in New Issue
Block a user