easydiffusion/ui/sd_internal/model_manager.py

208 lines
8.3 KiB
Python
Raw Normal View History

import os
2022-12-09 17:00:18 +01:00
import logging
import picklescan.scanner
import rich
from sd_internal import app, TaskData
from diffusionkit import model_loader
from diffusionkit.types import Context
2022-12-09 17:00:18 +01:00
log = logging.getLogger()  # root logger (no name given)

# File extensions accepted for each model type; lookups try them in this order.
MODEL_EXTENSIONS = {
    'stable-diffusion': ['.ckpt', '.safetensors'],
    'vae': ['.vae.pt', '.ckpt'],
    'hypernetwork': ['.pt'],
    'gfpgan': ['.pth'],
    'realesrgan': ['.pth'],
}

# Every model type managed here, in the same order as MODEL_EXTENSIONS.
# Each one gets its own sub-folder of MODELS_DIR.
KNOWN_MODEL_TYPES = list(MODEL_EXTENSIONS)

# Fallback model names per type, tried in order when a requested model is missing.
DEFAULT_MODELS = {
    'stable-diffusion': [
        # 'custom-model' supports legacy installations: initially only one custom
        # model file was supported, and it was named 'custom-model'.
        'custom-model',
        'sd-v1-4',  # the default fallback
    ],
    'gfpgan': ['GFPGANv1.3'],
    'realesrgan': ['RealESRGAN_x4plus'],
}

# Cache: model file path -> mtime recorded at its last picklescan, so
# unchanged files are not rescanned.
known_models = dict()
def init():
    """Create the per-type model folders, then list the models once.

    The first getModels() call runs picklescan on every model file and caches
    each file's mtime, so later listings only rescan changed files.
    """
    make_model_folders()
    getModels()
2022-12-11 15:28:12 +01:00
def load_default_models(context: Context):
    """Resolve default paths for all known model types and load the core models.

    Every type in KNOWN_MODEL_TYPES gets a path from resolve_model_to_use()
    (possibly None). TURBO is then switched off, and the Stable Diffusion,
    VAE and hypernetwork models are loaded into *context*.
    """
    for kind in KNOWN_MODEL_TYPES:
        context.model_paths[kind] = resolve_model_to_use(model_type=kind)

    # TURBO starts disabled (this should be read from the config eventually).
    context.vram_optimizations -= {'TURBO'}

    # Mandatory models only; gfpgan / realesrgan are not loaded here.
    for kind in ('stable-diffusion', 'vae', 'hypernetwork'):
        model_loader.load_model(context, kind)
2022-12-11 15:28:12 +01:00
def unload_all(context: Context):
    """Ask model_loader to unload every known model type from *context*."""
    for kind in KNOWN_MODEL_TYPES:
        model_loader.unload_model(context, kind)
def resolve_model_to_use(model_name:str=None, model_type:str=None):
    """Return the file path of the model to use for *model_type*, or None.

    Lookup order:
      1. ``model_name``. If it is empty, the name configured under
         ``config['model'][model_type]`` is used instead. For each extension in
         ``MODEL_EXTENSIONS[model_type]``, two places are tried:
         ``MODELS_DIR/<model_type>/<name><ext>``, then ``<name><ext>`` as given
         (the latter is returned as an absolute path).
      2. If the name is one of ``DEFAULT_MODELS[model_type]``: ``SD_DIR/<name><ext>``.
      3. Otherwise the first default model found in ``MODELS_DIR/<model_type>``
         or ``SD_DIR``. A warning is logged when a specific name had been
         requested.

    Names starting with 'sd2_' are replaced by 'sd-v1-4' unless
    ``config['test_sd2']`` is truthy.
    """
    model_extensions = MODEL_EXTENSIONS.get(model_type, [])
    default_models = DEFAULT_MODELS.get(model_type, [])
    config = app.getConfig()

    model_dirs = [os.path.join(app.MODELS_DIR, model_type), app.SD_DIR]
    if not model_name:  # When None/empty, try the user-configured model.
        if 'model' in config and model_type in config['model']:
            model_name = config['model'][model_type]

    if model_name:
        is_sd2 = config.get('test_sd2', False)
        if model_name.startswith('sd2_') and not is_sd2:  # temp hack, until SD2 is unified with 1.4
            log.error('ERROR: Cannot use SD 2.0 models with SD 1.0 code. Using the sd-v1-4 model instead!')
            model_name = 'sd-v1-4'

        # Check the per-type models directory, then the name as a plain path.
        models_dir_path = os.path.join(app.MODELS_DIR, model_type, model_name)
        for model_extension in model_extensions:
            if os.path.exists(models_dir_path + model_extension):
                return models_dir_path + model_extension
            if os.path.exists(model_name + model_extension):
                return os.path.abspath(model_name + model_extension)

        # Default models may also live directly in SD_DIR (legacy layout).
        if model_name in default_models:
            default_model_path = os.path.join(app.SD_DIR, model_name)
            for model_extension in model_extensions:
                if os.path.exists(default_model_path + model_extension):
                    return default_model_path + model_extension

    # Can't find the requested model; fall back to the first default that exists.
    for default_model in default_models:
        for model_dir in model_dirs:
            default_model_path = os.path.join(model_dir, default_model)
            for model_extension in model_extensions:
                if os.path.exists(default_model_path + model_extension):
                    # Only warn when a real (non-empty) name had been requested.
                    if model_name:
                        log.warning(f'Could not find the configured custom model {model_name}{model_extension}. Using the default one: {default_model_path}{model_extension}')
                    return default_model_path + model_extension

    return None
def reload_models_if_necessary(context: Context, task_data: TaskData):
    """Bring the models in *context* in line with the paths requested by *task_data*.

    For each model type whose requested path differs from the one recorded in
    ``context.model_paths``:
      - the new path is recorded;
      - the model is loaded, or unloaded when the requested path is None.
    """
    requested = {
        'stable-diffusion': task_data.use_stable_diffusion_model,
        'vae': task_data.use_vae_model,
        'hypernetwork': task_data.use_hypernetwork_model,
        'gfpgan': task_data.use_face_correction,
        'realesrgan': task_data.use_upscale,
    }

    for model_type, wanted_path in requested.items():
        if context.model_paths.get(model_type) == wanted_path:
            continue  # already in the requested state

        context.model_paths[model_type] = wanted_path
        if wanted_path is None:
            model_loader.unload_model(context, model_type)
        else:
            model_loader.load_model(context, model_type)
def resolve_model_paths(task_data: TaskData):
    """Replace the model names in *task_data* with resolved file paths, in place.

    The Stable Diffusion, VAE and hypernetwork entries are always resolved,
    falling back to configured or default models. Face correction and upscaling
    are resolved only when the task requested them (truthy value).
    """
    task_data.use_stable_diffusion_model = resolve_model_to_use(task_data.use_stable_diffusion_model, model_type='stable-diffusion')
    task_data.use_vae_model = resolve_model_to_use(task_data.use_vae_model, model_type='vae')
    task_data.use_hypernetwork_model = resolve_model_to_use(task_data.use_hypernetwork_model, model_type='hypernetwork')

    if task_data.use_face_correction:
        task_data.use_face_correction = resolve_model_to_use(task_data.use_face_correction, 'gfpgan')
    if task_data.use_upscale:
        # Upscalers are 'realesrgan' models. Resolving them as 'gfpgan' searched
        # the wrong folder and fell back to the GFPGAN default model.
        task_data.use_upscale = resolve_model_to_use(task_data.use_upscale, 'realesrgan')
2022-12-11 16:12:31 +01:00
def set_vram_optimizations(context: Context, task_data: TaskData):
    """Enable the 'TURBO' VRAM optimization iff ``task_data.turbo`` is truthy."""
    if task_data.turbo:
        context.vram_optimizations.add('TURBO')
    else:
        # discard() rather than remove(): 'TURBO' is normally absent at this point
        # (load_default_models strips it), and remove() would raise KeyError.
        context.vram_optimizations.discard('TURBO')
2022-12-11 16:12:31 +01:00
def make_model_folders():
    """Create ``MODELS_DIR/<type>`` for every known model type, plus a help file.

    The help file, named 'Place your <type> model files here.txt', lists the
    supported extensions. It is overwritten on every call.
    """
    for model_type in KNOWN_MODEL_TYPES:
        folder = os.path.join(app.MODELS_DIR, model_type)
        os.makedirs(folder, exist_ok=True)

        extensions = " or ".join(MODEL_EXTENSIONS.get(model_type))
        help_path = os.path.join(folder, f'Place your {model_type} model files here.txt')
        with open(help_path, 'w', encoding='utf-8') as help_file:
            help_file.write(f'Supported extensions: {extensions}')
def is_malicious_model(file_path):
    """Scan *file_path* with picklescan and report whether it looks malicious.

    Returns True when picklescan reports any issue or infected file, and False
    otherwise. If scanning raises, the error is logged and False is returned,
    so the file is NOT flagged.
    """
    try:
        scan_result = picklescan.scanner.scan_file_path(file_path)
        if scan_result.issues_count > 0 or scan_result.infected_files > 0:
            log.warning(":warning: [bold red]Scan %s: %d scanned, %d issue, %d infected.[/bold red]",
                        file_path, scan_result.scanned_files, scan_result.issues_count, scan_result.infected_files)
            return True
        log.debug("Scan %s: [green]%d scanned, %d issue, %d infected.[/green]",
                  file_path, scan_result.scanned_files, scan_result.issues_count, scan_result.infected_files)
        return False
    except Exception as e:
        # NOTE(review): this fails open. A model whose scan errors out is reported
        # as safe and can still be listed/loaded. Confirm this is intended.
        log.error(f'error while scanning: {file_path}, error: {e}')
        return False
def getModels():
    """Build the model listing for the UI, picklescanning new or changed files.

    Returns a dict with:
      - 'active':  default selection per type ('sd-v1-4' for stable-diffusion,
        '' for vae and hypernetwork);
      - 'options': per type, the model names (file name minus extension) found
        in ``MODELS_DIR/<type>``, de-duplicated and sorted. Stable Diffusion also
        gets 'custom-model' when the legacy ``SD_DIR/custom-model.ckpt`` exists;
      - 'scan-error': present only when picklescan flagged a file. It holds that
        file's name, and listing of that type stops at the flagged file.

    A file is rescanned only when its mtime differs from the one cached in
    ``known_models``.
    """
    models = {
        'active': {
            'stable-diffusion': 'sd-v1-4',
            'vae': '',
            'hypernetwork': '',
        },
        'options': {
            'stable-diffusion': ['sd-v1-4'],
            'vae': [],
            'hypernetwork': [],
        },
    }

    models_scanned = 0

    def listModels(model_type):
        # Append the models of one type to models['options'][model_type].
        nonlocal models_scanned

        model_extensions = MODEL_EXTENSIONS.get(model_type, [])
        models_dir = os.path.join(app.MODELS_DIR, model_type)
        if not os.path.exists(models_dir):
            os.makedirs(models_dir)

        for file in os.listdir(models_dir):
            for model_extension in model_extensions:
                if not file.endswith(model_extension):
                    continue

                model_path = os.path.join(models_dir, file)
                mtime = os.path.getmtime(model_path)
                mod_time = known_models[model_path] if model_path in known_models else -1
                if mod_time != mtime:
                    models_scanned += 1
                    if is_malicious_model(model_path):
                        models['scan-error'] = file
                        return
                known_models[model_path] = mtime

                model_name = file[:-len(model_extension)]
                models['options'][model_type].append(model_name)

        models['options'][model_type] = [*set(models['options'][model_type])]  # remove duplicates
        models['options'][model_type].sort()

    # custom models
    for model_type in ('stable-diffusion', 'vae', 'hypernetwork'):
        listModels(model_type=model_type)

    # Only claim a clean scan when no file was flagged. is_malicious_model has
    # already logged the warning for a flagged file.
    if models_scanned > 0 and 'scan-error' not in models:
        log.info(f'[green]Scanned {models_scanned} models. Nothing infected[/]')

    # legacy
    custom_weight_path = os.path.join(app.SD_DIR, 'custom-model.ckpt')
    sd_options = models['options']['stable-diffusion']
    if os.path.exists(custom_weight_path) and 'custom-model' not in sd_options:
        sd_options.append('custom-model')

    return models