mirror of
https://github.com/easydiffusion/easydiffusion.git
synced 2025-05-06 17:24:27 +02:00
Use the simplified model loading API in diffusion-kit; Catch and report exceptions while generating images
This commit is contained in:
parent
27c6113287
commit
16410d90b8
@ -5,6 +5,7 @@ import json
|
|||||||
import os
|
import os
|
||||||
import base64
|
import base64
|
||||||
import re
|
import re
|
||||||
|
import traceback
|
||||||
|
|
||||||
from sd_internal import device_manager, model_manager
|
from sd_internal import device_manager, model_manager
|
||||||
from sd_internal import Request, Response, Image as ResponseImage, UserInitiatedStop
|
from sd_internal import Request, Response, Image as ResponseImage, UserInitiatedStop
|
||||||
@ -38,10 +39,10 @@ def init(device):
|
|||||||
init_and_load_default_models()
|
init_and_load_default_models()
|
||||||
|
|
||||||
def destroy():
|
def destroy():
|
||||||
model_loader.unload_sd_model(thread_data)
|
model_loader.unload_model(thread_data, 'stable-diffusion')
|
||||||
model_loader.unload_gfpgan_model(thread_data)
|
model_loader.unload_model(thread_data, 'gfpgan')
|
||||||
model_loader.unload_realesrgan_model(thread_data)
|
model_loader.unload_model(thread_data, 'realesrgan')
|
||||||
model_loader.unload_hypernetwork_model(thread_data)
|
model_loader.unload_model(thread_data, 'hypernetwork')
|
||||||
|
|
||||||
def init_and_load_default_models():
|
def init_and_load_default_models():
|
||||||
# init default model paths
|
# init default model paths
|
||||||
@ -52,24 +53,36 @@ def init_and_load_default_models():
|
|||||||
thread_data.model_paths['realesrgan'] = model_manager.resolve_realesrgan_model_to_use()
|
thread_data.model_paths['realesrgan'] = model_manager.resolve_realesrgan_model_to_use()
|
||||||
|
|
||||||
# load mandatory models
|
# load mandatory models
|
||||||
model_loader.load_sd_model(thread_data)
|
model_loader.load_model(thread_data, 'stable-diffusion')
|
||||||
|
|
||||||
def reload_models_if_necessary(req: Request):
|
def reload_models_if_necessary(req: Request):
|
||||||
if model_manager.is_sd_model_reload_necessary(thread_data, req):
|
if model_manager.is_sd_model_reload_necessary(thread_data, req):
|
||||||
thread_data.model_paths['stable-diffusion'] = req.use_stable_diffusion_model
|
thread_data.model_paths['stable-diffusion'] = req.use_stable_diffusion_model
|
||||||
thread_data.model_paths['vae'] = req.use_vae_model
|
thread_data.model_paths['vae'] = req.use_vae_model
|
||||||
|
|
||||||
model_loader.load_sd_model(thread_data)
|
model_loader.load_model(thread_data, 'stable-diffusion')
|
||||||
|
|
||||||
if thread_data.model_paths.get('hypernetwork') != req.use_hypernetwork_model:
|
if thread_data.model_paths.get('hypernetwork') != req.use_hypernetwork_model:
|
||||||
thread_data.model_paths['hypernetwork'] = req.use_hypernetwork_model
|
thread_data.model_paths['hypernetwork'] = req.use_hypernetwork_model
|
||||||
|
|
||||||
if thread_data.model_paths['hypernetwork'] is not None:
|
if thread_data.model_paths['hypernetwork'] is not None:
|
||||||
model_loader.load_hypernetwork_model(thread_data)
|
model_loader.load_model(thread_data, 'hypernetwork')
|
||||||
else:
|
else:
|
||||||
model_loader.unload_hypernetwork_model(thread_data)
|
model_loader.unload_model(thread_data, 'hypernetwork')
|
||||||
|
|
||||||
def make_images(req: Request, data_queue: queue.Queue, task_temp_images: list, step_callback):
|
def make_images(req: Request, data_queue: queue.Queue, task_temp_images: list, step_callback):
|
||||||
|
try:
|
||||||
|
return _make_images_internal(req, data_queue, task_temp_images, step_callback)
|
||||||
|
except Exception as e:
|
||||||
|
print(traceback.format_exc())
|
||||||
|
|
||||||
|
data_queue.put(json.dumps({
|
||||||
|
"status": 'failed',
|
||||||
|
"detail": str(e)
|
||||||
|
}))
|
||||||
|
raise e
|
||||||
|
|
||||||
|
def _make_images_internal(req: Request, data_queue: queue.Queue, task_temp_images: list, step_callback):
|
||||||
images, user_stopped = generate_images(req, data_queue, task_temp_images, step_callback)
|
images, user_stopped = generate_images(req, data_queue, task_temp_images, step_callback)
|
||||||
images = apply_filters(req, images, user_stopped)
|
images = apply_filters(req, images, user_stopped)
|
||||||
|
|
||||||
|
Loading…
Reference in New Issue
Block a user