Move away the remaining model-related code to the model_manager

This commit is contained in:
cmdr2 2022-12-11 20:13:44 +05:30
parent 97919c7e87
commit 096556d8c9
3 changed files with 32 additions and 33 deletions

View File

@ -3,7 +3,7 @@ import logging
import picklescan.scanner
import rich
from sd_internal import app
from sd_internal import app, TaskData
from modules import model_loader
from modules.types import Context
@ -88,6 +88,34 @@ def resolve_model_to_use(model_name:str=None, model_type:str=None):
return None
def reload_models_if_necessary(context: Context, task_data: TaskData):
model_paths_in_req = (
('hypernetwork', task_data.use_hypernetwork_model),
('gfpgan', task_data.use_face_correction),
('realesrgan', task_data.use_upscale),
)
if context.model_paths.get('stable-diffusion') != task_data.use_stable_diffusion_model or context.model_paths.get('vae') != task_data.use_vae_model:
context.model_paths['stable-diffusion'] = task_data.use_stable_diffusion_model
context.model_paths['vae'] = task_data.use_vae_model
model_loader.load_model(context, 'stable-diffusion')
for model_type, model_path_in_req in model_paths_in_req:
if context.model_paths.get(model_type) != model_path_in_req:
context.model_paths[model_type] = model_path_in_req
action_fn = model_loader.unload_model if context.model_paths[model_type] is None else model_loader.load_model
action_fn(context, model_type)
def resolve_model_paths(task_data: TaskData):
task_data.use_stable_diffusion_model = resolve_model_to_use(task_data.use_stable_diffusion_model, model_type='stable-diffusion')
task_data.use_vae_model = resolve_model_to_use(task_data.use_vae_model, model_type='vae')
task_data.use_hypernetwork_model = resolve_model_to_use(task_data.use_hypernetwork_model, model_type='hypernetwork')
if task_data.use_face_correction: task_data.use_face_correction = resolve_model_to_use(task_data.use_face_correction, 'gfpgan')
if task_data.use_upscale: task_data.use_upscale = resolve_model_to_use(task_data.use_upscale, 'gfpgan')
def make_model_folders():
for model_type in KNOWN_MODEL_TYPES:
model_dir_path = os.path.join(app.MODELS_DIR, model_type)

View File

@ -1,4 +1,3 @@
import threading
import queue
import time
import json
@ -8,7 +7,7 @@ import re
import traceback
import logging
from sd_internal import device_manager, model_manager
from sd_internal import device_manager
from sd_internal import TaskData, Response, Image as ResponseImage, UserInitiatedStop
from modules import model_loader, image_generator, image_utils, filters as image_filters
@ -33,26 +32,6 @@ def init(device):
device_manager.device_init(thread_data, device)
def reload_models_if_necessary(task_data: TaskData):
model_paths_in_req = (
('hypernetwork', task_data.use_hypernetwork_model),
('gfpgan', task_data.use_face_correction),
('realesrgan', task_data.use_upscale),
)
if thread_data.model_paths.get('stable-diffusion') != task_data.use_stable_diffusion_model or thread_data.model_paths.get('vae') != task_data.use_vae_model:
thread_data.model_paths['stable-diffusion'] = task_data.use_stable_diffusion_model
thread_data.model_paths['vae'] = task_data.use_vae_model
model_loader.load_model(thread_data, 'stable-diffusion')
for model_type, model_path_in_req in model_paths_in_req:
if thread_data.model_paths.get(model_type) != model_path_in_req:
thread_data.model_paths[model_type] = model_path_in_req
action_fn = model_loader.unload_model if thread_data.model_paths[model_type] is None else model_loader.load_model
action_fn(thread_data, model_type)
def make_images(req: GenerateImageRequest, task_data: TaskData, data_queue: queue.Queue, task_temp_images: list, step_callback):
try:
return _make_images_internal(req, task_data, data_queue, task_temp_images, step_callback)
@ -80,14 +59,6 @@ def _make_images_internal(req: GenerateImageRequest, task_data: TaskData, data_q
return res
def resolve_model_paths(task_data: TaskData):
task_data.use_stable_diffusion_model = model_manager.resolve_model_to_use(task_data.use_stable_diffusion_model, model_type='stable-diffusion')
task_data.use_vae_model = model_manager.resolve_model_to_use(task_data.use_vae_model, model_type='vae')
task_data.use_hypernetwork_model = model_manager.resolve_model_to_use(task_data.use_hypernetwork_model, model_type='hypernetwork')
if task_data.use_face_correction: task_data.use_face_correction = model_manager.resolve_model_to_use(task_data.use_face_correction, 'gfpgan')
if task_data.use_upscale: task_data.use_upscale = model_manager.resolve_model_to_use(task_data.use_upscale, 'gfpgan')
def generate_images(req: GenerateImageRequest, data_queue: queue.Queue, task_temp_images: list, step_callback, stream_image_progress: bool):
log.info(req.to_metadata())
thread_data.temp_images.clear()

View File

@ -277,10 +277,10 @@ def thread_render(device):
log.info(f'Session {task.task_data.session_id} sent cancel signal for task {id(task)}')
current_state = ServerStates.LoadingModel
runtime2.reload_models_if_necessary(task.task_data)
model_manager.resolve_model_paths(task.task_data)
model_manager.reload_models_if_necessary(runtime2.thread_data, task.task_data)
current_state = ServerStates.Rendering
runtime2.resolve_model_paths(task.task_data)
task.response = runtime2.make_images(task.render_request, task.task_data, task.buffer_queue, task.temp_images, step_callback)
# Before looping back to the generator, mark cache as still alive.
task_cache.keep(id(task), TASK_TTL)