Recursive scanning for models

This commit is contained in:
JeLuF 2023-01-07 19:04:15 +01:00
parent 2d9853f1f4
commit 9532928998
2 changed files with 42 additions and 28 deletions

View File

@ -190,6 +190,30 @@ def getModels():
} }
models_scanned = 0 models_scanned = 0
class MaliciousModelException(Exception):
"Raised when picklescan reports a problem with a model"
pass
def scan_directory(directory, suffixes):
nonlocal models_scanned
tree = []
for entry in os.scandir(directory):
if entry.is_file() and True in [entry.name.endswith(s) for s in suffixes]:
mtime = entry.stat().st_mtime
mod_time = known_models[entry.path] if entry.path in known_models else -1
if mod_time != mtime:
models_scanned += 1
if is_malicious_model(entry.path):
raise MaliciousModelException(entry.path)
known_models[entry.path] = mtime
tree.append(entry.name.rsplit('.',1)[0])
elif entry.is_dir():
scan=scan_directory(entry.path, suffixes)
if len(scan) != 0:
tree.append( (entry.name, scan ) )
return tree
def listModels(model_type): def listModels(model_type):
nonlocal models_scanned nonlocal models_scanned
@ -198,26 +222,10 @@ def getModels():
if not os.path.exists(models_dir): if not os.path.exists(models_dir):
os.makedirs(models_dir) os.makedirs(models_dir)
for file in os.listdir(models_dir): try:
for model_extension in model_extensions: models['options'][model_type] = scan_directory(models_dir, model_extensions)
if not file.endswith(model_extension): except MaliciousModelException as e:
continue models['scan-error'] = e
model_path = os.path.join(models_dir, file)
mtime = os.path.getmtime(model_path)
mod_time = known_models[model_path] if model_path in known_models else -1
if mod_time != mtime:
models_scanned += 1
if is_malicious_model(model_path):
models['scan-error'] = file
return
known_models[model_path] = mtime
model_name = file[:-len(model_extension)]
models['options'][model_type].append(model_name)
models['options'][model_type] = [*set(models['options'][model_type])] # remove duplicates
models['options'][model_type].sort()
# custom models # custom models
listModels(model_type='stable-diffusion') listModels(model_type='stable-diffusion')

View File

@ -1305,6 +1305,7 @@ async function getModels() {
function createModelOptions(modelField, selectedModel) { function createModelOptions(modelField, selectedModel) {
return function(modelName) { return function(modelName) {
if (typeof(modelName) == 'string') {
const modelOption = document.createElement('option') const modelOption = document.createElement('option')
modelOption.value = modelName modelOption.value = modelName
modelOption.innerText = modelName !== '' ? modelName : 'None' modelOption.innerText = modelName !== '' ? modelName : 'None'
@ -1312,8 +1313,13 @@ async function getModels() {
if (modelName === selectedModel) { if (modelName === selectedModel) {
modelOption.selected = true modelOption.selected = true
} }
modelField.appendChild(modelOption) modelField.appendChild(modelOption)
} else {
const modelGroup = document.createElement('optgroup')
modelGroup.label = modelName[0]
modelName[1].forEach( createModelOptions(modelGroup, selectedModel) )
modelField.appendChild(modelGroup)
}
} }
} }