mirror of
https://github.com/easydiffusion/easydiffusion.git
synced 2025-01-21 13:50:12 +01:00
96 lines
3.0 KiB
Python
96 lines
3.0 KiB
Python
import threading
|
|
import queue
|
|
|
|
from sd_internal import device_manager, Request, Response, Image as ResponseImage
|
|
|
|
from modules import model_loader, image_generator, image_utils
|
|
|
|
# Runtime data (bound locally to this thread), e.g. the device, references to
# loaded models, optimization flags etc.
thread_data = threading.local()
|
|
|
|
def init(device):
    '''
    Prepares this thread's runtime: binds the initial fields to thread_data, lets
    device_manager initialize the requested device (setting the current torch
    device), and then reloads the models.
    '''
    # Built on every call so each thread gets its own fresh dict objects.
    initial_fields = {
        'stop_processing': False,
        'temp_images': {},
        'models': {},
        'loaded_model_paths': {},
        'device': None,
        'device_name': None,
        'precision': 'autocast',
        'vram_optimizations': ('TURBO', 'MOVE_MODELS'),
    }
    for field_name, initial_value in initial_fields.items():
        setattr(thread_data, field_name, initial_value)

    device_manager.device_init(thread_data, device)

    reload_models()
|
|
|
|
def destroy():
    '''
    Asks model_loader to unload the SD, GFPGAN and RealESRGAN models, in that
    order, passing this thread's thread_data to each call.
    '''
    context = thread_data

    model_loader.unload_sd_model(context)
    model_loader.unload_gfpgan_model(context)
    model_loader.unload_realesrgan_model(context)
|
|
|
|
def reload_models(req: Request=None):
    '''
    Reloads the hypernetwork and/or the model when the given request requires it.

    `req` is the request that the reload checks are run against. It defaults to
    None, which is how init() calls this function.
    '''
    # NOTE(review): is_hypernetwork_reload_necessary, is_model_reload_necessary,
    # ServerStates and runtime are neither defined nor imported in this module -
    # confirm where they are expected to come from.
    # NOTE(review): current_state is assigned as a plain local below, so the value
    # is not visible outside this function - confirm whether shared state was
    # intended.
    if is_hypernetwork_reload_necessary(req):
        current_state = ServerStates.LoadingModel
        runtime.reload_hypernetwork()

    if is_model_reload_necessary(req):
        current_state = ServerStates.LoadingModel
        runtime.reload_model()
|
|
|
|
def load_models():
    '''
    Loads the model checkpoint and the hypernetwork, and records the outcome.

    Paths that are None fall back to the matching default_*_to_load value. Nothing
    is loaded when the checkpoint and VAE paths equal the current ones. An
    exception raised while loading is not propagated: its traceback is printed
    and the state is set to ServerStates.Unavailable.
    '''
    # Imported locally because this module has no top-level traceback import.
    import traceback

    # NOTE(review): ckpt_file_path, vae_file_path, hypernetwork_file_path and the
    # current_* names are all assigned inside this function, which makes them
    # locals that are read before assignment (UnboundLocalError). They look like
    # they were meant to be parameters and/or module-level state - confirm the
    # intended source before relying on this function.
    # NOTE(review): default_model_to_load, default_vae_to_load,
    # default_hypernetwork_to_load, ServerStates and runtime are neither defined
    # nor imported in this module.
    if ckpt_file_path is None:
        ckpt_file_path = default_model_to_load
    if vae_file_path is None:
        vae_file_path = default_vae_to_load
    if hypernetwork_file_path is None:
        hypernetwork_file_path = default_hypernetwork_to_load

    # The hypernetwork path is not part of this "already loaded" check.
    if ckpt_file_path == current_model_path and vae_file_path == current_vae_path:
        return

    current_state = ServerStates.LoadingModel
    try:
        # NOTE(review): this imports runtime2, but the code below uses `runtime`.
        from sd_internal import runtime2
        runtime.thread_data.hypernetwork_file = hypernetwork_file_path
        runtime.thread_data.ckpt_file = ckpt_file_path
        runtime.thread_data.vae_file = vae_file_path
        runtime.load_model_ckpt()
        runtime.load_hypernetwork()
        current_model_path = ckpt_file_path
        current_vae_path = vae_file_path
        current_hypernetwork_path = hypernetwork_file_path
        current_state_error = None
        current_state = ServerStates.Online
    except Exception as e:
        # Forget the model and VAE paths so they no longer count as loaded.
        current_model_path = None
        current_vae_path = None
        current_state_error = e
        current_state = ServerStates.Unavailable
        print(traceback.format_exc())
|
|
|
|
def make_image(req: Request, data_queue: queue.Queue, task_temp_images: list, step_callback):
    '''
    Runs image_generator.make_image for the given request, using this thread's
    thread_data as the context and get_mk_img_args(req) as the arguments.

    A UserInitiatedStop raised while generating is swallowed. The function always
    returns None; data_queue, task_temp_images and step_callback are not used.
    '''
    # NOTE(review): UserInitiatedStop is neither defined nor imported in this
    # module - confirm where it should come from.
    try:
        mk_img_args = get_mk_img_args(req)
        generated = image_generator.make_image(context=thread_data, args=mk_img_args)
    except UserInitiatedStop:
        # A stop requested by the user is not treated as an error.
        return
|
|
|
|
def get_mk_img_args(req: Request):
    '''
    Builds the image-generation arguments from the request.

    Starts from req.json(). For 'init_image' and 'mask', when the request value
    is not None, the entry is replaced by the result of
    image_utils.base64_str_to_img on that value. Returns the resulting args.
    '''
    args = req.json()

    for image_field in ('init_image', 'mask'):
        encoded = getattr(req, image_field)
        if encoded is None:
            continue
        args[image_field] = image_utils.base64_str_to_img(encoded)

    return args
|
|
|
|
def on_image_step(x_samples, i):
    '''
    Step hook that currently does nothing: both arguments are ignored and None
    is returned.
    '''
    return None
|
|
|
|
# Install this module's on_image_step as image_generator's on_image_step attribute
# (presumably the hook the generator calls on each step - confirm in image_generator).
image_generator.on_image_step = on_image_step
|