Use the simplified model loading API in diffusion-kit; Catch and report exceptions while generating images

This commit is contained in:
cmdr2 2022-12-09 15:21:49 +05:30
parent 27c6113287
commit 16410d90b8

View File

@ -5,6 +5,7 @@ import json
import os
import base64
import re
import traceback
from sd_internal import device_manager, model_manager
from sd_internal import Request, Response, Image as ResponseImage, UserInitiatedStop
@ -38,10 +39,10 @@ def init(device):
init_and_load_default_models()
def destroy():
model_loader.unload_sd_model(thread_data)
model_loader.unload_gfpgan_model(thread_data)
model_loader.unload_realesrgan_model(thread_data)
model_loader.unload_hypernetwork_model(thread_data)
model_loader.unload_model(thread_data, 'stable-diffusion')
model_loader.unload_model(thread_data, 'gfpgan')
model_loader.unload_model(thread_data, 'realesrgan')
model_loader.unload_model(thread_data, 'hypernetwork')
def init_and_load_default_models():
# init default model paths
@ -52,24 +53,36 @@ def init_and_load_default_models():
thread_data.model_paths['realesrgan'] = model_manager.resolve_realesrgan_model_to_use()
# load mandatory models
model_loader.load_sd_model(thread_data)
model_loader.load_model(thread_data, 'stable-diffusion')
def reload_models_if_necessary(req: Request):
if model_manager.is_sd_model_reload_necessary(thread_data, req):
thread_data.model_paths['stable-diffusion'] = req.use_stable_diffusion_model
thread_data.model_paths['vae'] = req.use_vae_model
model_loader.load_sd_model(thread_data)
model_loader.load_model(thread_data, 'stable-diffusion')
if thread_data.model_paths.get('hypernetwork') != req.use_hypernetwork_model:
thread_data.model_paths['hypernetwork'] = req.use_hypernetwork_model
if thread_data.model_paths['hypernetwork'] is not None:
model_loader.load_hypernetwork_model(thread_data)
model_loader.load_model(thread_data, 'hypernetwork')
else:
model_loader.unload_hypernetwork_model(thread_data)
model_loader.unload_model(thread_data, 'hypernetwork')
def make_images(req: Request, data_queue: queue.Queue, task_temp_images: list, step_callback):
try:
return _make_images_internal(req, data_queue, task_temp_images, step_callback)
except Exception as e:
print(traceback.format_exc())
data_queue.put(json.dumps({
"status": 'failed',
"detail": str(e)
}))
raise e
def _make_images_internal(req: Request, data_queue: queue.Queue, task_temp_images: list, step_callback):
images, user_stopped = generate_images(req, data_queue, task_temp_images, step_callback)
images = apply_filters(req, images, user_stopped)