diff --git a/ui/server.py b/ui/server.py index ab910bfa..94b7fce8 100644 --- a/ui/server.py +++ b/ui/server.py @@ -23,6 +23,9 @@ USER_UI_PLUGINS_DIR = os.path.abspath(os.path.join(SD_DIR, '..', 'plugins', 'ui' CORE_UI_PLUGINS_DIR = os.path.abspath(os.path.join(SD_UI_DIR, 'plugins', 'ui')) UI_PLUGINS_SOURCES = ((CORE_UI_PLUGINS_DIR, 'core'), (USER_UI_PLUGINS_DIR, 'user')) +STABLE_DIFFUSION_MODEL_EXTENSIONS = ['.ckpt'] +VAE_MODEL_EXTENSIONS = ['.vae.pt', '.ckpt'] + OUTPUT_DIRNAME = "Stable Diffusion UI" # in the user's home folder TASK_TTL = 15 * 60 # Discard last session's task timeout APP_CONFIG_DEFAULTS = { @@ -162,11 +165,11 @@ def resolve_model_to_use(model_name:str, model_type:str, model_dir:str, model_ex raise Exception('No valid models found.') def resolve_ckpt_to_use(model_name:str=None): - return resolve_model_to_use(model_name, model_type='stable-diffusion', model_dir='stable-diffusion', model_extensions=['.ckpt'], default_models=APP_CONFIG_DEFAULT_MODELS) + return resolve_model_to_use(model_name, model_type='stable-diffusion', model_dir='stable-diffusion', model_extensions=STABLE_DIFFUSION_MODEL_EXTENSIONS, default_models=APP_CONFIG_DEFAULT_MODELS) def resolve_vae_to_use(model_name:str=None): try: - return resolve_model_to_use(model_name, model_type='vae', model_dir='vae', model_extensions=['.vae.pt', '.ckpt'], default_models=[]) + return resolve_model_to_use(model_name, model_type='vae', model_dir='vae', model_extensions=VAE_MODEL_EXTENSIONS, default_models=[]) except: return None @@ -198,6 +201,20 @@ async def setAppConfig(req : SetAppConfigRequest): print(traceback.format_exc()) raise HTTPException(status_code=500, detail=str(e)) +def is_malicious_model(file_path): + try: + scan_result = picklescan.scanner.scan_file_path(file_path) + if scan_result.issues_count > 0 or scan_result.infected_files > 0: + rich.print(":warning: [bold red]Scan %s: %d scanned, %d issue, %d infected.[/bold red]" % (file_path, scan_result.scanned_files, scan_result.issues_count, scan_result.infected_files)) + return True + else: + rich.print("Scan %s: [green]%d scanned, %d issue, %d infected.[/green]" % (file_path, scan_result.scanned_files, scan_result.issues_count, scan_result.infected_files)) + return False + except Exception as e: + print('error while scanning', file_path, 'error:', e) + + return False + def getModels(): models = { 'active': { @@ -216,28 +233,23 @@ def getModels(): os.makedirs(models_dir) for file in os.listdir(models_dir): - try: - scan_result = picklescan.scanner.scan_file_path( os.path.join(models_dir, file)) - if ( scan_result.issues_count >0 or scan_result.infected_files >0): - rich.print(":warning: [bold red]Scan %s: %d scanned, %d issue, %d infected.[/bold red]" % ( file, scan_result.scanned_files, scan_result.issues_count, scan_result.infected_files) ) - models['scan-error'] = file - return models - else: - rich.print("Scan %s: [green]%d scanned, %d issue, %d infected.[/green]" % ( file, scan_result.scanned_files, scan_result.issues_count, scan_result.infected_files ) ) - except Exception as e: - print('error while scanning', os.path.join(models_dir, file), 'error:', e) - for model_extension in model_extensions: - if file.endswith(model_extension): - model_name = file[:-len(model_extension)] - models['options'][model_type].append(model_name) + if not file.endswith(model_extension): + continue + + if is_malicious_model(os.path.join(models_dir, file)): + models['scan-error'] = file + return + + model_name = file[:-len(model_extension)] + models['options'][model_type].append(model_name) models['options'][model_type] = [*set(models['options'][model_type])] # remove duplicates models['options'][model_type].sort() # custom models - listModels(models_dirname='stable-diffusion', model_type='stable-diffusion', model_extensions=['.ckpt']) - listModels(models_dirname='vae', model_type='vae', model_extensions=['.vae.pt', '.ckpt']) + listModels(models_dirname='stable-diffusion', model_type='stable-diffusion', model_extensions=STABLE_DIFFUSION_MODEL_EXTENSIONS) + listModels(models_dirname='vae', model_type='vae', model_extensions=VAE_MODEL_EXTENSIONS) # legacy custom_weight_path = os.path.join(SD_DIR, 'custom-model.ckpt')