forked from extern/easydiffusion
142 lines
6.0 KiB
Python
142 lines
6.0 KiB
Python
import os
|
|
|
|
from sd_internal import app
|
|
import picklescan.scanner
|
|
import rich
|
|
|
|
default_model_to_load = None
|
|
default_vae_to_load = None
|
|
default_hypernetwork_to_load = None
|
|
|
|
known_models = {}
|
|
|
|
def init():
|
|
global default_model_to_load, default_vae_to_load, default_hypernetwork_to_load
|
|
|
|
default_model_to_load = resolve_ckpt_to_use()
|
|
default_vae_to_load = resolve_vae_to_use()
|
|
default_hypernetwork_to_load = resolve_hypernetwork_to_use()
|
|
|
|
getModels() # run this once, to cache the picklescan results
|
|
|
|
def resolve_model_to_use(model_name:str, model_type:str, model_dir:str, model_extensions:list, default_models=[]):
|
|
config = app.getConfig()
|
|
|
|
model_dirs = [os.path.join(app.MODELS_DIR, model_dir), app.SD_DIR]
|
|
if not model_name: # When None try user configured model.
|
|
# config = getConfig()
|
|
if 'model' in config and model_type in config['model']:
|
|
model_name = config['model'][model_type]
|
|
|
|
if model_name:
|
|
is_sd2 = config.get('test_sd2', False)
|
|
if model_name.startswith('sd2_') and not is_sd2: # temp hack, until SD2 is unified with 1.4
|
|
print('ERROR: Cannot use SD 2.0 models with SD 1.0 code. Using the sd-v1-4 model instead!')
|
|
model_name = 'sd-v1-4'
|
|
|
|
# Check models directory
|
|
models_dir_path = os.path.join(app.MODELS_DIR, model_dir, model_name)
|
|
for model_extension in model_extensions:
|
|
if os.path.exists(models_dir_path + model_extension):
|
|
return models_dir_path + model_extension
|
|
if os.path.exists(model_name + model_extension):
|
|
return os.path.abspath(model_name + model_extension)
|
|
|
|
# Default locations
|
|
if model_name in default_models:
|
|
default_model_path = os.path.join(app.SD_DIR, model_name)
|
|
for model_extension in model_extensions:
|
|
if os.path.exists(default_model_path + model_extension):
|
|
return default_model_path + model_extension
|
|
|
|
# Can't find requested model, check the default paths.
|
|
for default_model in default_models:
|
|
for model_dir in model_dirs:
|
|
default_model_path = os.path.join(model_dir, default_model)
|
|
for model_extension in model_extensions:
|
|
if os.path.exists(default_model_path + model_extension):
|
|
if model_name is not None:
|
|
print(f'Could not find the configured custom model {model_name}{model_extension}. Using the default one: {default_model_path}{model_extension}')
|
|
return default_model_path + model_extension
|
|
|
|
raise Exception('No valid models found.')
|
|
|
|
def resolve_ckpt_to_use(model_name:str=None):
|
|
return resolve_model_to_use(model_name, model_type='stable-diffusion', model_dir='stable-diffusion', model_extensions=app.STABLE_DIFFUSION_MODEL_EXTENSIONS, default_models=app.APP_CONFIG_DEFAULT_MODELS)
|
|
|
|
def resolve_vae_to_use(model_name:str=None):
|
|
try:
|
|
return resolve_model_to_use(model_name, model_type='vae', model_dir='vae', model_extensions=app.VAE_MODEL_EXTENSIONS, default_models=[])
|
|
except:
|
|
return None
|
|
|
|
def resolve_hypernetwork_to_use(model_name:str=None):
|
|
try:
|
|
return resolve_model_to_use(model_name, model_type='hypernetwork', model_dir='hypernetwork', model_extensions=app.HYPERNETWORK_MODEL_EXTENSIONS, default_models=[])
|
|
except:
|
|
return None
|
|
|
|
def is_malicious_model(file_path):
|
|
try:
|
|
scan_result = picklescan.scanner.scan_file_path(file_path)
|
|
if scan_result.issues_count > 0 or scan_result.infected_files > 0:
|
|
rich.print(":warning: [bold red]Scan %s: %d scanned, %d issue, %d infected.[/bold red]" % (file_path, scan_result.scanned_files, scan_result.issues_count, scan_result.infected_files))
|
|
return True
|
|
else:
|
|
rich.print("Scan %s: [green]%d scanned, %d issue, %d infected.[/green]" % (file_path, scan_result.scanned_files, scan_result.issues_count, scan_result.infected_files))
|
|
return False
|
|
except Exception as e:
|
|
print('error while scanning', file_path, 'error:', e)
|
|
return False
|
|
|
|
def getModels():
|
|
models = {
|
|
'active': {
|
|
'stable-diffusion': 'sd-v1-4',
|
|
'vae': '',
|
|
'hypernetwork': '',
|
|
},
|
|
'options': {
|
|
'stable-diffusion': ['sd-v1-4'],
|
|
'vae': [],
|
|
'hypernetwork': [],
|
|
},
|
|
}
|
|
|
|
def listModels(models_dirname, model_type, model_extensions):
|
|
models_dir = os.path.join(app.MODELS_DIR, models_dirname)
|
|
if not os.path.exists(models_dir):
|
|
os.makedirs(models_dir)
|
|
|
|
for file in os.listdir(models_dir):
|
|
for model_extension in model_extensions:
|
|
if not file.endswith(model_extension):
|
|
continue
|
|
|
|
model_path = os.path.join(models_dir, file)
|
|
mtime = os.path.getmtime(model_path)
|
|
mod_time = known_models[model_path] if model_path in known_models else -1
|
|
if mod_time != mtime:
|
|
if is_malicious_model(model_path):
|
|
models['scan-error'] = file
|
|
return
|
|
known_models[model_path] = mtime
|
|
|
|
model_name = file[:-len(model_extension)]
|
|
models['options'][model_type].append(model_name)
|
|
|
|
models['options'][model_type] = [*set(models['options'][model_type])] # remove duplicates
|
|
models['options'][model_type].sort()
|
|
|
|
# custom models
|
|
listModels(models_dirname='stable-diffusion', model_type='stable-diffusion', model_extensions=app.STABLE_DIFFUSION_MODEL_EXTENSIONS)
|
|
listModels(models_dirname='vae', model_type='vae', model_extensions=app.VAE_MODEL_EXTENSIONS)
|
|
listModels(models_dirname='hypernetwork', model_type='hypernetwork', model_extensions=app.HYPERNETWORK_MODEL_EXTENSIONS)
|
|
|
|
# legacy
|
|
custom_weight_path = os.path.join(app.SD_DIR, 'custom-model.ckpt')
|
|
if os.path.exists(custom_weight_path):
|
|
models['options']['stable-diffusion'].append('custom-model')
|
|
|
|
return models
|