forked from extern/easydiffusion
Move the remaining model-related code to the model_manager
commit 096556d8c9 (parent 97919c7e87)
@@ -3,7 +3,7 @@ import logging
 import picklescan.scanner
 import rich

-from sd_internal import app
+from sd_internal import app, TaskData
 from modules import model_loader
 from modules.types import Context

@@ -88,6 +88,34 @@ def resolve_model_to_use(model_name:str=None, model_type:str=None):

     return None

+def reload_models_if_necessary(context: Context, task_data: TaskData):
+    model_paths_in_req = (
+        ('hypernetwork', task_data.use_hypernetwork_model),
+        ('gfpgan', task_data.use_face_correction),
+        ('realesrgan', task_data.use_upscale),
+    )
+
+    if context.model_paths.get('stable-diffusion') != task_data.use_stable_diffusion_model or context.model_paths.get('vae') != task_data.use_vae_model:
+        context.model_paths['stable-diffusion'] = task_data.use_stable_diffusion_model
+        context.model_paths['vae'] = task_data.use_vae_model
+
+        model_loader.load_model(context, 'stable-diffusion')
+
+    for model_type, model_path_in_req in model_paths_in_req:
+        if context.model_paths.get(model_type) != model_path_in_req:
+            context.model_paths[model_type] = model_path_in_req
+
+            action_fn = model_loader.unload_model if context.model_paths[model_type] is None else model_loader.load_model
+            action_fn(context, model_type)
+
+def resolve_model_paths(task_data: TaskData):
+    task_data.use_stable_diffusion_model = resolve_model_to_use(task_data.use_stable_diffusion_model, model_type='stable-diffusion')
+    task_data.use_vae_model = resolve_model_to_use(task_data.use_vae_model, model_type='vae')
+    task_data.use_hypernetwork_model = resolve_model_to_use(task_data.use_hypernetwork_model, model_type='hypernetwork')
+
+    if task_data.use_face_correction: task_data.use_face_correction = resolve_model_to_use(task_data.use_face_correction, 'gfpgan')
+    if task_data.use_upscale: task_data.use_upscale = resolve_model_to_use(task_data.use_upscale, 'gfpgan')
+
 def make_model_folders():
     for model_type in KNOWN_MODEL_TYPES:
         model_dir_path = os.path.join(app.MODELS_DIR, model_type)
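The new reload_models_if_necessary follows a diff-and-reload pattern: compare each model path requested by the task against the path currently recorded in the context, and call the loader only for entries that changed (unloading when the new path is None). The short sketch below is a self-contained illustration of that pattern, not the project's real code: FakeContext, FakeTaskData and the load/unload stubs stand in for Context, TaskData and modules.model_loader, and the sketch maps use_upscale to 'realesrgan' to match the model_paths_in_req tuple above.

# Self-contained sketch of the diff-and-reload pattern; every name here is an
# illustrative stand-in, not the real Context / TaskData / model_loader.
from dataclasses import dataclass, field

@dataclass
class FakeContext:
    model_paths: dict = field(default_factory=dict)

@dataclass
class FakeTaskData:
    use_hypernetwork_model: str | None = None
    use_face_correction: str | None = None
    use_upscale: str | None = None

def load_model(context, model_type):    # stand-in for model_loader.load_model
    print(f'loading {model_type} from {context.model_paths[model_type]}')

def unload_model(context, model_type):  # stand-in for model_loader.unload_model
    print(f'unloading {model_type}')

def reload_if_changed(context, task_data):
    requested = (
        ('hypernetwork', task_data.use_hypernetwork_model),
        ('gfpgan', task_data.use_face_correction),
        ('realesrgan', task_data.use_upscale),
    )
    for model_type, path in requested:
        if context.model_paths.get(model_type) != path:  # touch only what changed
            context.model_paths[model_type] = path
            action = unload_model if path is None else load_model
            action(context, model_type)

ctx = FakeContext()
req = FakeTaskData(use_upscale='/models/realesrgan/RealESRGAN_x4plus.pth')
reload_if_changed(ctx, req)   # loads realesrgan once
reload_if_changed(ctx, req)   # second call with the same request does nothing

Run back to back, the second call is a no-op, which is the behaviour the render thread relies on when consecutive tasks reuse the same models.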
@@ -1,4 +1,3 @@
-import threading
 import queue
 import time
 import json
@@ -8,7 +7,7 @@ import re
 import traceback
 import logging

-from sd_internal import device_manager, model_manager
+from sd_internal import device_manager
 from sd_internal import TaskData, Response, Image as ResponseImage, UserInitiatedStop

 from modules import model_loader, image_generator, image_utils, filters as image_filters
@@ -33,26 +32,6 @@ def init(device):

     device_manager.device_init(thread_data, device)

-def reload_models_if_necessary(task_data: TaskData):
-    model_paths_in_req = (
-        ('hypernetwork', task_data.use_hypernetwork_model),
-        ('gfpgan', task_data.use_face_correction),
-        ('realesrgan', task_data.use_upscale),
-    )
-
-    if thread_data.model_paths.get('stable-diffusion') != task_data.use_stable_diffusion_model or thread_data.model_paths.get('vae') != task_data.use_vae_model:
-        thread_data.model_paths['stable-diffusion'] = task_data.use_stable_diffusion_model
-        thread_data.model_paths['vae'] = task_data.use_vae_model
-
-        model_loader.load_model(thread_data, 'stable-diffusion')
-
-    for model_type, model_path_in_req in model_paths_in_req:
-        if thread_data.model_paths.get(model_type) != model_path_in_req:
-            thread_data.model_paths[model_type] = model_path_in_req
-
-            action_fn = model_loader.unload_model if thread_data.model_paths[model_type] is None else model_loader.load_model
-            action_fn(thread_data, model_type)
-
 def make_images(req: GenerateImageRequest, task_data: TaskData, data_queue: queue.Queue, task_temp_images: list, step_callback):
     try:
         return _make_images_internal(req, task_data, data_queue, task_temp_images, step_callback)
@@ -80,14 +59,6 @@ def _make_images_internal(req: GenerateImageRequest, task_data: TaskData, data_q

     return res

-def resolve_model_paths(task_data: TaskData):
-    task_data.use_stable_diffusion_model = model_manager.resolve_model_to_use(task_data.use_stable_diffusion_model, model_type='stable-diffusion')
-    task_data.use_vae_model = model_manager.resolve_model_to_use(task_data.use_vae_model, model_type='vae')
-    task_data.use_hypernetwork_model = model_manager.resolve_model_to_use(task_data.use_hypernetwork_model, model_type='hypernetwork')
-
-    if task_data.use_face_correction: task_data.use_face_correction = model_manager.resolve_model_to_use(task_data.use_face_correction, 'gfpgan')
-    if task_data.use_upscale: task_data.use_upscale = model_manager.resolve_model_to_use(task_data.use_upscale, 'gfpgan')
-
 def generate_images(req: GenerateImageRequest, data_queue: queue.Queue, task_temp_images: list, step_callback, stream_image_progress: bool):
     log.info(req.to_metadata())
     thread_data.temp_images.clear()
@@ -277,10 +277,10 @@ def thread_render(device):
            log.info(f'Session {task.task_data.session_id} sent cancel signal for task {id(task)}')

        current_state = ServerStates.LoadingModel
-        runtime2.reload_models_if_necessary(task.task_data)
+        model_manager.resolve_model_paths(task.task_data)
+        model_manager.reload_models_if_necessary(runtime2.thread_data, task.task_data)

        current_state = ServerStates.Rendering
-        runtime2.resolve_model_paths(task.task_data)
        task.response = runtime2.make_images(task.render_request, task.task_data, task.buffer_queue, task.temp_images, step_callback)
        # Before looping back to the generator, mark cache as still alive.
        task_cache.keep(id(task), TASK_TTL)
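At the new call site the order matters: model_manager.resolve_model_paths first turns the short model names in the request into paths on disk, and only then does model_manager.reload_models_if_necessary compare those paths against what is already recorded in runtime2.thread_data and load or unload the difference before rendering starts. The sketch below illustrates that two-step sequence; resolve_paths, reload_if_changed, MODELS_DIR and the plain dict standing in for the thread context are all hypothetical stand-ins, not the real model_manager or runtime2 API.

# Sketch of the resolve-then-reload sequence used by the render thread above.
# Every name here is an illustrative stand-in.
import os

MODELS_DIR = '/tmp/models'   # assumed models root, for illustration only

def resolve_paths(requested: dict) -> dict:
    # plays the role of model_manager.resolve_model_paths: short name -> full path
    return {kind: os.path.join(MODELS_DIR, kind, name + '.ckpt')
            for kind, name in requested.items() if name is not None}

def reload_if_changed(context_paths: dict, resolved: dict) -> list:
    # plays the role of model_manager.reload_models_if_necessary: load only what changed
    changed = [kind for kind, path in resolved.items()
               if context_paths.get(kind) != path]
    context_paths.update(resolved)
    return changed

thread_paths = {}   # stands in for runtime2.thread_data.model_paths
request = {'stable-diffusion': 'sd-v1-4', 'vae': 'vae-ft-mse-840000-ema-pruned'}

print(reload_if_changed(thread_paths, resolve_paths(request)))  # first task: both entries load
print(reload_if_changed(thread_paths, resolve_paths(request)))  # same request again: []

Resolving before the comparison keeps the check meaningful: the paths stored in the thread context and the paths derived from the request are in the same form, so a repeated request does not force a reload.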