mirror of
https://github.com/easydiffusion/easydiffusion.git
synced 2025-02-16 18:32:25 +01:00
Move away the remaining model-related code to the model_manager
This commit is contained in:
parent
97919c7e87
commit
096556d8c9
@ -3,7 +3,7 @@ import logging
|
||||
import picklescan.scanner
|
||||
import rich
|
||||
|
||||
from sd_internal import app
|
||||
from sd_internal import app, TaskData
|
||||
from modules import model_loader
|
||||
from modules.types import Context
|
||||
|
||||
@ -88,6 +88,34 @@ def resolve_model_to_use(model_name:str=None, model_type:str=None):
|
||||
|
||||
return None
|
||||
|
||||
def reload_models_if_necessary(context: Context, task_data: TaskData):
|
||||
model_paths_in_req = (
|
||||
('hypernetwork', task_data.use_hypernetwork_model),
|
||||
('gfpgan', task_data.use_face_correction),
|
||||
('realesrgan', task_data.use_upscale),
|
||||
)
|
||||
|
||||
if context.model_paths.get('stable-diffusion') != task_data.use_stable_diffusion_model or context.model_paths.get('vae') != task_data.use_vae_model:
|
||||
context.model_paths['stable-diffusion'] = task_data.use_stable_diffusion_model
|
||||
context.model_paths['vae'] = task_data.use_vae_model
|
||||
|
||||
model_loader.load_model(context, 'stable-diffusion')
|
||||
|
||||
for model_type, model_path_in_req in model_paths_in_req:
|
||||
if context.model_paths.get(model_type) != model_path_in_req:
|
||||
context.model_paths[model_type] = model_path_in_req
|
||||
|
||||
action_fn = model_loader.unload_model if context.model_paths[model_type] is None else model_loader.load_model
|
||||
action_fn(context, model_type)
|
||||
|
||||
def resolve_model_paths(task_data: TaskData):
|
||||
task_data.use_stable_diffusion_model = resolve_model_to_use(task_data.use_stable_diffusion_model, model_type='stable-diffusion')
|
||||
task_data.use_vae_model = resolve_model_to_use(task_data.use_vae_model, model_type='vae')
|
||||
task_data.use_hypernetwork_model = resolve_model_to_use(task_data.use_hypernetwork_model, model_type='hypernetwork')
|
||||
|
||||
if task_data.use_face_correction: task_data.use_face_correction = resolve_model_to_use(task_data.use_face_correction, 'gfpgan')
|
||||
if task_data.use_upscale: task_data.use_upscale = resolve_model_to_use(task_data.use_upscale, 'gfpgan')
|
||||
|
||||
def make_model_folders():
|
||||
for model_type in KNOWN_MODEL_TYPES:
|
||||
model_dir_path = os.path.join(app.MODELS_DIR, model_type)
|
||||
|
@ -1,4 +1,3 @@
|
||||
import threading
|
||||
import queue
|
||||
import time
|
||||
import json
|
||||
@ -8,7 +7,7 @@ import re
|
||||
import traceback
|
||||
import logging
|
||||
|
||||
from sd_internal import device_manager, model_manager
|
||||
from sd_internal import device_manager
|
||||
from sd_internal import TaskData, Response, Image as ResponseImage, UserInitiatedStop
|
||||
|
||||
from modules import model_loader, image_generator, image_utils, filters as image_filters
|
||||
@ -33,26 +32,6 @@ def init(device):
|
||||
|
||||
device_manager.device_init(thread_data, device)
|
||||
|
||||
def reload_models_if_necessary(task_data: TaskData):
|
||||
model_paths_in_req = (
|
||||
('hypernetwork', task_data.use_hypernetwork_model),
|
||||
('gfpgan', task_data.use_face_correction),
|
||||
('realesrgan', task_data.use_upscale),
|
||||
)
|
||||
|
||||
if thread_data.model_paths.get('stable-diffusion') != task_data.use_stable_diffusion_model or thread_data.model_paths.get('vae') != task_data.use_vae_model:
|
||||
thread_data.model_paths['stable-diffusion'] = task_data.use_stable_diffusion_model
|
||||
thread_data.model_paths['vae'] = task_data.use_vae_model
|
||||
|
||||
model_loader.load_model(thread_data, 'stable-diffusion')
|
||||
|
||||
for model_type, model_path_in_req in model_paths_in_req:
|
||||
if thread_data.model_paths.get(model_type) != model_path_in_req:
|
||||
thread_data.model_paths[model_type] = model_path_in_req
|
||||
|
||||
action_fn = model_loader.unload_model if thread_data.model_paths[model_type] is None else model_loader.load_model
|
||||
action_fn(thread_data, model_type)
|
||||
|
||||
def make_images(req: GenerateImageRequest, task_data: TaskData, data_queue: queue.Queue, task_temp_images: list, step_callback):
|
||||
try:
|
||||
return _make_images_internal(req, task_data, data_queue, task_temp_images, step_callback)
|
||||
@ -80,14 +59,6 @@ def _make_images_internal(req: GenerateImageRequest, task_data: TaskData, data_q
|
||||
|
||||
return res
|
||||
|
||||
def resolve_model_paths(task_data: TaskData):
|
||||
task_data.use_stable_diffusion_model = model_manager.resolve_model_to_use(task_data.use_stable_diffusion_model, model_type='stable-diffusion')
|
||||
task_data.use_vae_model = model_manager.resolve_model_to_use(task_data.use_vae_model, model_type='vae')
|
||||
task_data.use_hypernetwork_model = model_manager.resolve_model_to_use(task_data.use_hypernetwork_model, model_type='hypernetwork')
|
||||
|
||||
if task_data.use_face_correction: task_data.use_face_correction = model_manager.resolve_model_to_use(task_data.use_face_correction, 'gfpgan')
|
||||
if task_data.use_upscale: task_data.use_upscale = model_manager.resolve_model_to_use(task_data.use_upscale, 'gfpgan')
|
||||
|
||||
def generate_images(req: GenerateImageRequest, data_queue: queue.Queue, task_temp_images: list, step_callback, stream_image_progress: bool):
|
||||
log.info(req.to_metadata())
|
||||
thread_data.temp_images.clear()
|
||||
|
@ -277,10 +277,10 @@ def thread_render(device):
|
||||
log.info(f'Session {task.task_data.session_id} sent cancel signal for task {id(task)}')
|
||||
|
||||
current_state = ServerStates.LoadingModel
|
||||
runtime2.reload_models_if_necessary(task.task_data)
|
||||
model_manager.resolve_model_paths(task.task_data)
|
||||
model_manager.reload_models_if_necessary(runtime2.thread_data, task.task_data)
|
||||
|
||||
current_state = ServerStates.Rendering
|
||||
runtime2.resolve_model_paths(task.task_data)
|
||||
task.response = runtime2.make_images(task.render_request, task.task_data, task.buffer_queue, task.temp_images, step_callback)
|
||||
# Before looping back to the generator, mark cache as still alive.
|
||||
task_cache.keep(id(task), TASK_TTL)
|
||||
|
Loading…
Reference in New Issue
Block a user