diff --git a/ui/easydiffusion/model_manager.py b/ui/easydiffusion/model_manager.py index fefcc916..22eb7201 100644 --- a/ui/easydiffusion/model_manager.py +++ b/ui/easydiffusion/model_manager.py @@ -190,6 +190,30 @@ def getModels(): } models_scanned = 0 + + class MaliciousModelException(Exception): + "Raised when picklescan reports a problem with a model" + pass + + def scan_directory(directory, suffixes): + nonlocal models_scanned + tree = [] + for entry in os.scandir(directory): + if entry.is_file() and True in [entry.name.endswith(s) for s in suffixes]: + mtime = entry.stat().st_mtime + mod_time = known_models[entry.path] if entry.path in known_models else -1 + if mod_time != mtime: + models_scanned += 1 + if is_malicious_model(entry.path): + raise MaliciousModelException(entry.path) + known_models[entry.path] = mtime + tree.append(entry.name.rsplit('.',1)[0]) + elif entry.is_dir(): + scan=scan_directory(entry.path, suffixes) + if len(scan) != 0: + tree.append( (entry.name, scan ) ) + return tree + def listModels(model_type): nonlocal models_scanned @@ -198,26 +222,10 @@ def getModels(): if not os.path.exists(models_dir): os.makedirs(models_dir) - for file in os.listdir(models_dir): - for model_extension in model_extensions: - if not file.endswith(model_extension): - continue - - model_path = os.path.join(models_dir, file) - mtime = os.path.getmtime(model_path) - mod_time = known_models[model_path] if model_path in known_models else -1 - if mod_time != mtime: - models_scanned += 1 - if is_malicious_model(model_path): - models['scan-error'] = file - return - known_models[model_path] = mtime - - model_name = file[:-len(model_extension)] - models['options'][model_type].append(model_name) - - models['options'][model_type] = [*set(models['options'][model_type])] # remove duplicates - models['options'][model_type].sort() + try: + models['options'][model_type] = scan_directory(models_dir, model_extensions) + except MaliciousModelException as e: + models['scan-error'] = e # custom models listModels(model_type='stable-diffusion') diff --git a/ui/media/js/main.js b/ui/media/js/main.js index 7d70233e..10e5797d 100644 --- a/ui/media/js/main.js +++ b/ui/media/js/main.js @@ -1305,15 +1305,21 @@ async function getModels() { function createModelOptions(modelField, selectedModel) { return function(modelName) { - const modelOption = document.createElement('option') - modelOption.value = modelName - modelOption.innerText = modelName !== '' ? modelName : 'None' + if (typeof(modelName) == 'string') { + const modelOption = document.createElement('option') + modelOption.value = modelName + modelOption.innerText = modelName !== '' ? modelName : 'None' - if (modelName === selectedModel) { - modelOption.selected = true - } - - modelField.appendChild(modelOption) + if (modelName === selectedModel) { + modelOption.selected = true + } + modelField.appendChild(modelOption) + } else { + const modelGroup = document.createElement('optgroup') + modelGroup.label = modelName[0] + modelName[1].forEach( createModelOptions(modelGroup, selectedModel) ) + modelField.appendChild(modelGroup) + } } }