Mirror of https://github.com/easydiffusion/easydiffusion.git (synced 2025-06-20 18:08:00 +02:00)
Simplify the code for VAE loading, and make it faster to load VAEs (because we don't reload the entire SD model each time a VAE changes); Record the error and end the thread if the SD model fails to load during startup
parent a483bd0800
commit 6cd0b530c5
@@ -42,6 +42,8 @@ def load_default_models(context: Context):
 
     # load mandatory models
     model_loader.load_model(context, 'stable-diffusion')
+    model_loader.load_model(context, 'vae')
+    model_loader.load_model(context, 'hypernetwork')
 
 def unload_all(context: Context):
     for model_type in KNOWN_MODEL_TYPES:
@@ -93,17 +95,13 @@ def resolve_model_to_use(model_name:str=None, model_type:str=None):
 
 def reload_models_if_necessary(context: Context, task_data: TaskData):
     model_paths_in_req = (
+        ('stable-diffusion', task_data.use_stable_diffusion_model),
+        ('vae', task_data.use_vae_model),
         ('hypernetwork', task_data.use_hypernetwork_model),
         ('gfpgan', task_data.use_face_correction),
         ('realesrgan', task_data.use_upscale),
     )
 
-    if context.model_paths.get('stable-diffusion') != task_data.use_stable_diffusion_model or context.model_paths.get('vae') != task_data.use_vae_model:
-        context.model_paths['stable-diffusion'] = task_data.use_stable_diffusion_model
-        context.model_paths['vae'] = task_data.use_vae_model
-
-        model_loader.load_model(context, 'stable-diffusion')
-
     for model_type, model_path_in_req in model_paths_in_req:
        if context.model_paths.get(model_type) != model_path_in_req:
            context.model_paths[model_type] = model_path_in_req
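Note: the net effect of this hunk is that a VAE (or hypernetwork) change now goes through the same per-type loop as the auxiliary models, instead of forcing a full Stable Diffusion reload. Below is a minimal, self-contained sketch of that pattern; Context and load_model here are stand-ins for sd_internal's Context and model_loader.load_model, the model names are placeholders, and the sketch assumes the loop goes on to load each changed model type (the hunk above truncates before that call).

# Hypothetical sketch of the "reload only what changed" behaviour, with the
# project's Context and model_loader stubbed out so it runs standalone.

class Context:
    def __init__(self):
        self.model_paths = {}  # model_type -> currently loaded path/name

def load_model(context, model_type):
    # stand-in for model_loader.load_model(): (re)load a single model type
    print(f"loading {model_type}: {context.model_paths[model_type]}")

def reload_models_if_necessary(context, model_paths_in_req):
    # reload only the model types whose requested path differs from what is loaded
    for model_type, model_path_in_req in model_paths_in_req:
        if context.model_paths.get(model_type) != model_path_in_req:
            context.model_paths[model_type] = model_path_in_req
            load_model(context, model_type)

ctx = Context()
reload_models_if_necessary(ctx, [('stable-diffusion', 'sd-model-A'), ('vae', None)])
# Changing only the VAE in the next request reloads only the VAE,
# not the Stable Diffusion checkpoint:
reload_models_if_necessary(ctx, [('stable-diffusion', 'sd-model-A'), ('vae', 'vae-model-B')])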
@@ -222,24 +222,25 @@ def thread_render(device):
     from sd_internal import renderer, model_manager
     try:
         renderer.init(device)
 
+        weak_thread_data[threading.current_thread()] = {
+            'device': renderer.context.device,
+            'device_name': renderer.context.device_name,
+            'alive': True
+        }
+
+        current_state = ServerStates.LoadingModel
+        model_manager.load_default_models(renderer.context)
+
+        current_state = ServerStates.Online
     except Exception as e:
         log.error(traceback.format_exc())
         weak_thread_data[threading.current_thread()] = {
-            'error': e
+            'error': e,
+            'alive': False
         }
         return
 
-    weak_thread_data[threading.current_thread()] = {
-        'device': renderer.context.device,
-        'device_name': renderer.context.device_name,
-        'alive': True
-    }
-
-    current_state = ServerStates.LoadingModel
-    model_manager.load_default_models(renderer.context)
-
-    current_state = ServerStates.Online
     while True:
         session_cache.clean()
         task_cache.clean()
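Note: with this hunk, a failure while loading the default models at startup is recorded in weak_thread_data under 'error' together with 'alive': False, and the thread returns instead of entering its render loop. Below is a hypothetical sketch of how a supervising caller could consume that record; wait_for_render_thread and its polling interval are assumptions, and weak_thread_data is declared locally only to make the sketch self-contained.

# Hypothetical caller-side check, assuming the render thread fills
# weak_thread_data exactly as in the diff above.

import time
import weakref

weak_thread_data = weakref.WeakKeyDictionary()  # thread -> status dict (as in the diff)

def wait_for_render_thread(thread, timeout=60):
    deadline = time.time() + timeout
    while time.time() < deadline:
        data = weak_thread_data.get(thread, {})
        if 'error' in data:          # startup failed; the thread has already returned
            raise RuntimeError(f"render thread failed to start: {data['error']}")
        if data.get('alive'):        # renderer initialized and default models loaded
            return data
        time.sleep(0.5)
    raise TimeoutError("render thread did not report a status in time")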