mirror of
https://github.com/easydiffusion/easydiffusion.git
synced 2024-11-22 16:23:28 +01:00
Use the simplified model loading API in diffusion-kit; Catch and report exceptions while generating images
This commit is contained in:
parent
27c6113287
commit
16410d90b8
@ -5,6 +5,7 @@ import json
|
||||
import os
|
||||
import base64
|
||||
import re
|
||||
import traceback
|
||||
|
||||
from sd_internal import device_manager, model_manager
|
||||
from sd_internal import Request, Response, Image as ResponseImage, UserInitiatedStop
|
||||
@ -38,10 +39,10 @@ def init(device):
|
||||
init_and_load_default_models()
|
||||
|
||||
def destroy():
|
||||
model_loader.unload_sd_model(thread_data)
|
||||
model_loader.unload_gfpgan_model(thread_data)
|
||||
model_loader.unload_realesrgan_model(thread_data)
|
||||
model_loader.unload_hypernetwork_model(thread_data)
|
||||
model_loader.unload_model(thread_data, 'stable-diffusion')
|
||||
model_loader.unload_model(thread_data, 'gfpgan')
|
||||
model_loader.unload_model(thread_data, 'realesrgan')
|
||||
model_loader.unload_model(thread_data, 'hypernetwork')
|
||||
|
||||
def init_and_load_default_models():
|
||||
# init default model paths
|
||||
@ -52,24 +53,36 @@ def init_and_load_default_models():
|
||||
thread_data.model_paths['realesrgan'] = model_manager.resolve_realesrgan_model_to_use()
|
||||
|
||||
# load mandatory models
|
||||
model_loader.load_sd_model(thread_data)
|
||||
model_loader.load_model(thread_data, 'stable-diffusion')
|
||||
|
||||
def reload_models_if_necessary(req: Request):
|
||||
if model_manager.is_sd_model_reload_necessary(thread_data, req):
|
||||
thread_data.model_paths['stable-diffusion'] = req.use_stable_diffusion_model
|
||||
thread_data.model_paths['vae'] = req.use_vae_model
|
||||
|
||||
model_loader.load_sd_model(thread_data)
|
||||
model_loader.load_model(thread_data, 'stable-diffusion')
|
||||
|
||||
if thread_data.model_paths.get('hypernetwork') != req.use_hypernetwork_model:
|
||||
thread_data.model_paths['hypernetwork'] = req.use_hypernetwork_model
|
||||
|
||||
if thread_data.model_paths['hypernetwork'] is not None:
|
||||
model_loader.load_hypernetwork_model(thread_data)
|
||||
model_loader.load_model(thread_data, 'hypernetwork')
|
||||
else:
|
||||
model_loader.unload_hypernetwork_model(thread_data)
|
||||
model_loader.unload_model(thread_data, 'hypernetwork')
|
||||
|
||||
def make_images(req: Request, data_queue: queue.Queue, task_temp_images: list, step_callback):
|
||||
try:
|
||||
return _make_images_internal(req, data_queue, task_temp_images, step_callback)
|
||||
except Exception as e:
|
||||
print(traceback.format_exc())
|
||||
|
||||
data_queue.put(json.dumps({
|
||||
"status": 'failed',
|
||||
"detail": str(e)
|
||||
}))
|
||||
raise e
|
||||
|
||||
def _make_images_internal(req: Request, data_queue: queue.Queue, task_temp_images: list, step_callback):
|
||||
images, user_stopped = generate_images(req, data_queue, task_temp_images, step_callback)
|
||||
images = apply_filters(req, images, user_stopped)
|
||||
|
||||
|
Loading…
Reference in New Issue
Block a user