mirror of
https://github.com/easydiffusion/easydiffusion.git
synced 2024-11-22 16:23:28 +01:00
Work-in-progress model loading
This commit is contained in:
parent
5782966d63
commit
bad89160cc
@ -18,10 +18,6 @@ USER_UI_PLUGINS_DIR = os.path.abspath(os.path.join(SD_DIR, '..', 'plugins', 'ui'
|
||||
CORE_UI_PLUGINS_DIR = os.path.abspath(os.path.join(SD_UI_DIR, 'plugins', 'ui'))
|
||||
UI_PLUGINS_SOURCES = ((CORE_UI_PLUGINS_DIR, 'core'), (USER_UI_PLUGINS_DIR, 'user'))
|
||||
|
||||
STABLE_DIFFUSION_MODEL_EXTENSIONS = ['.ckpt', '.safetensors']
|
||||
VAE_MODEL_EXTENSIONS = ['.vae.pt', '.ckpt']
|
||||
HYPERNETWORK_MODEL_EXTENSIONS = ['.pt']
|
||||
|
||||
OUTPUT_DIRNAME = "Stable Diffusion UI" # in the user's home folder
|
||||
TASK_TTL = 15 * 60 # Discard last session's task timeout
|
||||
APP_CONFIG_DEFAULTS = {
|
||||
|
@ -4,6 +4,10 @@ from sd_internal import app
|
||||
import picklescan.scanner
|
||||
import rich
|
||||
|
||||
STABLE_DIFFUSION_MODEL_EXTENSIONS = ['.ckpt', '.safetensors']
|
||||
VAE_MODEL_EXTENSIONS = ['.vae.pt', '.ckpt']
|
||||
HYPERNETWORK_MODEL_EXTENSIONS = ['.pt']
|
||||
|
||||
default_model_to_load = None
|
||||
default_vae_to_load = None
|
||||
default_hypernetwork_to_load = None
|
||||
@ -59,22 +63,17 @@ def resolve_model_to_use(model_name:str, model_type:str, model_dir:str, model_ex
|
||||
print(f'Could not find the configured custom model {model_name}{model_extension}. Using the default one: {default_model_path}{model_extension}')
|
||||
return default_model_path + model_extension
|
||||
|
||||
raise Exception('No valid models found.')
|
||||
print(f'No valid models found for model_name: {model_name}')
|
||||
return None
|
||||
|
||||
def resolve_ckpt_to_use(model_name:str=None):
|
||||
return resolve_model_to_use(model_name, model_type='stable-diffusion', model_dir='stable-diffusion', model_extensions=app.STABLE_DIFFUSION_MODEL_EXTENSIONS, default_models=app.APP_CONFIG_DEFAULT_MODELS)
|
||||
return resolve_model_to_use(model_name, model_type='stable-diffusion', model_dir='stable-diffusion', model_extensions=STABLE_DIFFUSION_MODEL_EXTENSIONS, default_models=app.APP_CONFIG_DEFAULT_MODELS)
|
||||
|
||||
def resolve_vae_to_use(model_name:str=None):
|
||||
try:
|
||||
return resolve_model_to_use(model_name, model_type='vae', model_dir='vae', model_extensions=app.VAE_MODEL_EXTENSIONS, default_models=[])
|
||||
except:
|
||||
return None
|
||||
return resolve_model_to_use(model_name, model_type='vae', model_dir='vae', model_extensions=VAE_MODEL_EXTENSIONS, default_models=[])
|
||||
|
||||
def resolve_hypernetwork_to_use(model_name:str=None):
|
||||
try:
|
||||
return resolve_model_to_use(model_name, model_type='hypernetwork', model_dir='hypernetwork', model_extensions=app.HYPERNETWORK_MODEL_EXTENSIONS, default_models=[])
|
||||
except:
|
||||
return None
|
||||
return resolve_model_to_use(model_name, model_type='hypernetwork', model_dir='hypernetwork', model_extensions=HYPERNETWORK_MODEL_EXTENSIONS, default_models=[])
|
||||
|
||||
def is_malicious_model(file_path):
|
||||
try:
|
||||
@ -129,9 +128,9 @@ def getModels():
|
||||
models['options'][model_type].sort()
|
||||
|
||||
# custom models
|
||||
listModels(models_dirname='stable-diffusion', model_type='stable-diffusion', model_extensions=app.STABLE_DIFFUSION_MODEL_EXTENSIONS)
|
||||
listModels(models_dirname='vae', model_type='vae', model_extensions=app.VAE_MODEL_EXTENSIONS)
|
||||
listModels(models_dirname='hypernetwork', model_type='hypernetwork', model_extensions=app.HYPERNETWORK_MODEL_EXTENSIONS)
|
||||
listModels(models_dirname='stable-diffusion', model_type='stable-diffusion', model_extensions=STABLE_DIFFUSION_MODEL_EXTENSIONS)
|
||||
listModels(models_dirname='vae', model_type='vae', model_extensions=VAE_MODEL_EXTENSIONS)
|
||||
listModels(models_dirname='hypernetwork', model_type='hypernetwork', model_extensions=HYPERNETWORK_MODEL_EXTENSIONS)
|
||||
|
||||
# legacy
|
||||
custom_weight_path = os.path.join(app.SD_DIR, 'custom-model.ckpt')
|
||||
|
@ -1,7 +1,8 @@
|
||||
import threading
|
||||
import queue
|
||||
|
||||
from sd_internal import device_manager, Request, Response, Image as ResponseImage
|
||||
from sd_internal import device_manager, model_manager
|
||||
from sd_internal import Request, Response, Image as ResponseImage
|
||||
|
||||
from modules import model_loader, image_generator, image_utils
|
||||
|
||||
@ -18,7 +19,7 @@ def init(device):
|
||||
thread_data.temp_images = {}
|
||||
|
||||
thread_data.models = {}
|
||||
thread_data.loaded_model_paths = {}
|
||||
thread_data.model_paths = {}
|
||||
|
||||
thread_data.device = None
|
||||
thread_data.device_name = None
|
||||
@ -27,14 +28,34 @@ def init(device):
|
||||
|
||||
device_manager.device_init(thread_data, device)
|
||||
|
||||
reload_models()
|
||||
load_default_models()
|
||||
|
||||
def destroy():
|
||||
model_loader.unload_sd_model(thread_data)
|
||||
model_loader.unload_gfpgan_model(thread_data)
|
||||
model_loader.unload_realesrgan_model(thread_data)
|
||||
|
||||
def reload_models(req: Request=None):
|
||||
def load_default_models():
|
||||
thread_data.model_paths['stable-diffusion'] = model_manager.default_model_to_load
|
||||
thread_data.model_paths['vae'] = model_manager.default_vae_to_load
|
||||
|
||||
model_loader.load_sd_model(thread_data)
|
||||
|
||||
def reload_models_if_necessary(req: Request=None):
|
||||
needs_model_reload = False
|
||||
if 'stable-diffusion' not in thread_data.models or thread_data.ckpt_file != req.use_stable_diffusion_model or thread_data.vae_file != req.use_vae_model:
|
||||
thread_data.ckpt_file = req.use_stable_diffusion_model
|
||||
thread_data.vae_file = req.use_vae_model
|
||||
needs_model_reload = True
|
||||
|
||||
if thread_data.device != 'cpu':
|
||||
if (thread_data.precision == 'autocast' and (req.use_full_precision or not thread_data.model_is_half)) or \
|
||||
(thread_data.precision == 'full' and not req.use_full_precision and not thread_data.force_full_precision):
|
||||
thread_data.precision = 'full' if req.use_full_precision else 'autocast'
|
||||
needs_model_reload = True
|
||||
|
||||
return needs_model_reload
|
||||
|
||||
if is_hypernetwork_reload_necessary(task.request):
|
||||
current_state = ServerStates.LoadingModel
|
||||
runtime.reload_hypernetwork()
|
||||
|
@ -310,9 +310,6 @@ def thread_render(device):
|
||||
print(f'Session {task.request.session_id} starting task {id(task)} on {runtime2.thread_data.device_name}')
|
||||
if not task.lock.acquire(blocking=False): raise Exception('Got locked task from queue.')
|
||||
try:
|
||||
current_state = ServerStates.LoadingModel
|
||||
runtime2.reload_models(task.request)
|
||||
|
||||
def step_callback():
|
||||
global current_state_error
|
||||
|
||||
@ -323,6 +320,9 @@ def thread_render(device):
|
||||
current_state_error = None
|
||||
print(f'Session {task.request.session_id} sent cancel signal for task {id(task)}')
|
||||
|
||||
current_state = ServerStates.LoadingModel
|
||||
runtime2.reload_models_if_necessary(task.request)
|
||||
|
||||
current_state = ServerStates.Rendering
|
||||
task.response = runtime2.make_image(task.request, task.buffer_queue, task.temp_images, step_callback)
|
||||
# Before looping back to the generator, mark cache as still alive.
|
||||
|
Loading…
Reference in New Issue
Block a user