forked from extern/easydiffusion
Recursive scanning for models
This commit is contained in:
parent
2d9853f1f4
commit
9532928998
@ -190,6 +190,30 @@ def getModels():
|
||||
}
|
||||
|
||||
models_scanned = 0
|
||||
|
||||
class MaliciousModelException(Exception):
|
||||
"Raised when picklescan reports a problem with a model"
|
||||
pass
|
||||
|
||||
def scan_directory(directory, suffixes):
|
||||
nonlocal models_scanned
|
||||
tree = []
|
||||
for entry in os.scandir(directory):
|
||||
if entry.is_file() and True in [entry.name.endswith(s) for s in suffixes]:
|
||||
mtime = entry.stat().st_mtime
|
||||
mod_time = known_models[entry.path] if entry.path in known_models else -1
|
||||
if mod_time != mtime:
|
||||
models_scanned += 1
|
||||
if is_malicious_model(entry.path):
|
||||
raise MaliciousModelException(entry.path)
|
||||
known_models[entry.path] = mtime
|
||||
tree.append(entry.name.rsplit('.',1)[0])
|
||||
elif entry.is_dir():
|
||||
scan=scan_directory(entry.path, suffixes)
|
||||
if len(scan) != 0:
|
||||
tree.append( (entry.name, scan ) )
|
||||
return tree
|
||||
|
||||
def listModels(model_type):
|
||||
nonlocal models_scanned
|
||||
|
||||
@ -198,26 +222,10 @@ def getModels():
|
||||
if not os.path.exists(models_dir):
|
||||
os.makedirs(models_dir)
|
||||
|
||||
for file in os.listdir(models_dir):
|
||||
for model_extension in model_extensions:
|
||||
if not file.endswith(model_extension):
|
||||
continue
|
||||
|
||||
model_path = os.path.join(models_dir, file)
|
||||
mtime = os.path.getmtime(model_path)
|
||||
mod_time = known_models[model_path] if model_path in known_models else -1
|
||||
if mod_time != mtime:
|
||||
models_scanned += 1
|
||||
if is_malicious_model(model_path):
|
||||
models['scan-error'] = file
|
||||
return
|
||||
known_models[model_path] = mtime
|
||||
|
||||
model_name = file[:-len(model_extension)]
|
||||
models['options'][model_type].append(model_name)
|
||||
|
||||
models['options'][model_type] = [*set(models['options'][model_type])] # remove duplicates
|
||||
models['options'][model_type].sort()
|
||||
try:
|
||||
models['options'][model_type] = scan_directory(models_dir, model_extensions)
|
||||
except MaliciousModelException as e:
|
||||
models['scan-error'] = e
|
||||
|
||||
# custom models
|
||||
listModels(model_type='stable-diffusion')
|
||||
|
@ -1305,15 +1305,21 @@ async function getModels() {
|
||||
|
||||
function createModelOptions(modelField, selectedModel) {
|
||||
return function(modelName) {
|
||||
const modelOption = document.createElement('option')
|
||||
modelOption.value = modelName
|
||||
modelOption.innerText = modelName !== '' ? modelName : 'None'
|
||||
if (typeof(modelName) == 'string') {
|
||||
const modelOption = document.createElement('option')
|
||||
modelOption.value = modelName
|
||||
modelOption.innerText = modelName !== '' ? modelName : 'None'
|
||||
|
||||
if (modelName === selectedModel) {
|
||||
modelOption.selected = true
|
||||
}
|
||||
|
||||
modelField.appendChild(modelOption)
|
||||
if (modelName === selectedModel) {
|
||||
modelOption.selected = true
|
||||
}
|
||||
modelField.appendChild(modelOption)
|
||||
} else {
|
||||
const modelGroup = document.createElement('optgroup')
|
||||
modelGroup.label = modelName[0]
|
||||
modelName[1].forEach( createModelOptions(modelGroup, selectedModel) )
|
||||
modelField.appendChild(modelGroup)
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
|
Loading…
Reference in New Issue
Block a user