Work-in-progress model loading

This commit is contained in:
cmdr2 2022-12-08 13:50:46 +05:30
parent 5782966d63
commit bad89160cc
4 changed files with 40 additions and 24 deletions

View File

@ -18,10 +18,6 @@ USER_UI_PLUGINS_DIR = os.path.abspath(os.path.join(SD_DIR, '..', 'plugins', 'ui'
CORE_UI_PLUGINS_DIR = os.path.abspath(os.path.join(SD_UI_DIR, 'plugins', 'ui'))
UI_PLUGINS_SOURCES = ((CORE_UI_PLUGINS_DIR, 'core'), (USER_UI_PLUGINS_DIR, 'user'))
STABLE_DIFFUSION_MODEL_EXTENSIONS = ['.ckpt', '.safetensors']
VAE_MODEL_EXTENSIONS = ['.vae.pt', '.ckpt']
HYPERNETWORK_MODEL_EXTENSIONS = ['.pt']
OUTPUT_DIRNAME = "Stable Diffusion UI" # in the user's home folder
TASK_TTL = 15 * 60 # Discard last session's task timeout
APP_CONFIG_DEFAULTS = {

View File

@ -4,6 +4,10 @@ from sd_internal import app
import picklescan.scanner
import rich
STABLE_DIFFUSION_MODEL_EXTENSIONS = ['.ckpt', '.safetensors']
VAE_MODEL_EXTENSIONS = ['.vae.pt', '.ckpt']
HYPERNETWORK_MODEL_EXTENSIONS = ['.pt']
default_model_to_load = None
default_vae_to_load = None
default_hypernetwork_to_load = None
@ -59,22 +63,17 @@ def resolve_model_to_use(model_name:str, model_type:str, model_dir:str, model_ex
print(f'Could not find the configured custom model {model_name}{model_extension}. Using the default one: {default_model_path}{model_extension}')
return default_model_path + model_extension
raise Exception('No valid models found.')
print(f'No valid models found for model_name: {model_name}')
return None
def resolve_ckpt_to_use(model_name:str=None):
return resolve_model_to_use(model_name, model_type='stable-diffusion', model_dir='stable-diffusion', model_extensions=app.STABLE_DIFFUSION_MODEL_EXTENSIONS, default_models=app.APP_CONFIG_DEFAULT_MODELS)
return resolve_model_to_use(model_name, model_type='stable-diffusion', model_dir='stable-diffusion', model_extensions=STABLE_DIFFUSION_MODEL_EXTENSIONS, default_models=app.APP_CONFIG_DEFAULT_MODELS)
def resolve_vae_to_use(model_name:str=None):
try:
return resolve_model_to_use(model_name, model_type='vae', model_dir='vae', model_extensions=app.VAE_MODEL_EXTENSIONS, default_models=[])
except:
return None
return resolve_model_to_use(model_name, model_type='vae', model_dir='vae', model_extensions=VAE_MODEL_EXTENSIONS, default_models=[])
def resolve_hypernetwork_to_use(model_name:str=None):
try:
return resolve_model_to_use(model_name, model_type='hypernetwork', model_dir='hypernetwork', model_extensions=app.HYPERNETWORK_MODEL_EXTENSIONS, default_models=[])
except:
return None
return resolve_model_to_use(model_name, model_type='hypernetwork', model_dir='hypernetwork', model_extensions=HYPERNETWORK_MODEL_EXTENSIONS, default_models=[])
def is_malicious_model(file_path):
try:
@ -129,9 +128,9 @@ def getModels():
models['options'][model_type].sort()
# custom models
listModels(models_dirname='stable-diffusion', model_type='stable-diffusion', model_extensions=app.STABLE_DIFFUSION_MODEL_EXTENSIONS)
listModels(models_dirname='vae', model_type='vae', model_extensions=app.VAE_MODEL_EXTENSIONS)
listModels(models_dirname='hypernetwork', model_type='hypernetwork', model_extensions=app.HYPERNETWORK_MODEL_EXTENSIONS)
listModels(models_dirname='stable-diffusion', model_type='stable-diffusion', model_extensions=STABLE_DIFFUSION_MODEL_EXTENSIONS)
listModels(models_dirname='vae', model_type='vae', model_extensions=VAE_MODEL_EXTENSIONS)
listModels(models_dirname='hypernetwork', model_type='hypernetwork', model_extensions=HYPERNETWORK_MODEL_EXTENSIONS)
# legacy
custom_weight_path = os.path.join(app.SD_DIR, 'custom-model.ckpt')

View File

@ -1,7 +1,8 @@
import threading
import queue
from sd_internal import device_manager, Request, Response, Image as ResponseImage
from sd_internal import device_manager, model_manager
from sd_internal import Request, Response, Image as ResponseImage
from modules import model_loader, image_generator, image_utils
@ -18,7 +19,7 @@ def init(device):
thread_data.temp_images = {}
thread_data.models = {}
thread_data.loaded_model_paths = {}
thread_data.model_paths = {}
thread_data.device = None
thread_data.device_name = None
@ -27,14 +28,34 @@ def init(device):
device_manager.device_init(thread_data, device)
reload_models()
load_default_models()
def destroy():
model_loader.unload_sd_model(thread_data)
model_loader.unload_gfpgan_model(thread_data)
model_loader.unload_realesrgan_model(thread_data)
def reload_models(req: Request=None):
def load_default_models():
thread_data.model_paths['stable-diffusion'] = model_manager.default_model_to_load
thread_data.model_paths['vae'] = model_manager.default_vae_to_load
model_loader.load_sd_model(thread_data)
def reload_models_if_necessary(req: Request=None):
needs_model_reload = False
if 'stable-diffusion' not in thread_data.models or thread_data.ckpt_file != req.use_stable_diffusion_model or thread_data.vae_file != req.use_vae_model:
thread_data.ckpt_file = req.use_stable_diffusion_model
thread_data.vae_file = req.use_vae_model
needs_model_reload = True
if thread_data.device != 'cpu':
if (thread_data.precision == 'autocast' and (req.use_full_precision or not thread_data.model_is_half)) or \
(thread_data.precision == 'full' and not req.use_full_precision and not thread_data.force_full_precision):
thread_data.precision = 'full' if req.use_full_precision else 'autocast'
needs_model_reload = True
return needs_model_reload
if is_hypernetwork_reload_necessary(task.request):
current_state = ServerStates.LoadingModel
runtime.reload_hypernetwork()

View File

@ -310,9 +310,6 @@ def thread_render(device):
print(f'Session {task.request.session_id} starting task {id(task)} on {runtime2.thread_data.device_name}')
if not task.lock.acquire(blocking=False): raise Exception('Got locked task from queue.')
try:
current_state = ServerStates.LoadingModel
runtime2.reload_models(task.request)
def step_callback():
global current_state_error
@ -323,6 +320,9 @@ def thread_render(device):
current_state_error = None
print(f'Session {task.request.session_id} sent cancel signal for task {id(task)}')
current_state = ServerStates.LoadingModel
runtime2.reload_models_if_necessary(task.request)
current_state = ServerStates.Rendering
task.response = runtime2.make_image(task.request, task.buffer_queue, task.temp_images, step_callback)
# Before looping back to the generator, mark cache as still alive.