Simplify the API for resolving model paths; Code cleanup

This commit is contained in:
cmdr2 2022-12-09 15:45:36 +05:30
parent b40fb3a422
commit 8820814002
4 changed files with 12 additions and 32 deletions

View File

@ -28,7 +28,7 @@ def init():
make_model_folders()
getModels() # run this once, to cache the picklescan results
def resolve_model_to_use(model_name:str, model_type:str):
def resolve_model_to_use(model_name:str=None, model_type:str=None):
model_extensions = MODEL_EXTENSIONS.get(model_type, [])
default_models = DEFAULT_MODELS.get(model_type, [])
config = app.getConfig()
@ -72,21 +72,6 @@ def resolve_model_to_use(model_name:str, model_type:str):
return None
def resolve_sd_model_to_use(model_name:str=None):
return resolve_model_to_use(model_name, model_type='stable-diffusion')
def resolve_vae_model_to_use(model_name:str=None):
return resolve_model_to_use(model_name, model_type='vae')
def resolve_hypernetwork_model_to_use(model_name:str=None):
return resolve_model_to_use(model_name, model_type='hypernetwork')
def resolve_gfpgan_model_to_use(model_name:str=None):
return resolve_model_to_use(model_name, model_type='gfpgan')
def resolve_realesrgan_model_to_use(model_name:str=None):
return resolve_model_to_use(model_name, model_type='realesrgan')
def make_model_folders():
for model_type in KNOWN_MODEL_TYPES:
model_dir_path = os.path.join(app.MODELS_DIR, model_type)

View File

@ -39,18 +39,13 @@ def init(device):
init_and_load_default_models()
def destroy():
model_loader.unload_model(thread_data, 'stable-diffusion')
model_loader.unload_model(thread_data, 'gfpgan')
model_loader.unload_model(thread_data, 'realesrgan')
model_loader.unload_model(thread_data, 'hypernetwork')
for model_type in ('stable-diffusion', 'hypernetwork', 'gfpgan', 'realesrgan'):
model_loader.unload_model(thread_data, model_type)
def init_and_load_default_models():
# init default model paths
thread_data.model_paths['stable-diffusion'] = model_manager.resolve_sd_model_to_use()
thread_data.model_paths['vae'] = model_manager.resolve_vae_model_to_use()
thread_data.model_paths['hypernetwork'] = model_manager.resolve_hypernetwork_model_to_use()
thread_data.model_paths['gfpgan'] = model_manager.resolve_gfpgan_model_to_use()
thread_data.model_paths['realesrgan'] = model_manager.resolve_realesrgan_model_to_use()
for model_type in ('stable-diffusion', 'vae', 'hypernetwork', 'gfpgan', 'realesrgan'):
thread_data.model_paths[model_type] = model_manager.resolve_model_to_use(model_type=model_type)
# load mandatory models
model_loader.load_model(thread_data, 'stable-diffusion')
@ -119,8 +114,8 @@ def apply_filters(req: Request, images: list, user_stopped):
return images
filters = []
if req.use_face_correction.startswith('GFPGAN'): filters.append((image_filters.apply_gfpgan, model_manager.resolve_gfpgan_model_to_use(req.use_face_correction)))
if req.use_upscale.startswith('RealESRGAN'): filters.append((image_filters.apply_realesrgan, model_manager.resolve_realesrgan_model_to_use(req.use_upscale)))
if req.use_face_correction.startswith('GFPGAN'): filters.append((image_filters.apply_gfpgan, model_manager.resolve_model_to_use(req.use_face_correction, model_type='gfpgan')))
if req.use_upscale.startswith('RealESRGAN'): filters.append((image_filters.apply_realesrgan, model_manager.resolve_model_to_use(req.use_upscale, model_type='realesrgan')))
filtered_images = []
for img, seed, _ in images:

View File

@ -11,10 +11,10 @@ TASK_TTL = 15 * 60 # seconds, Discard last session's task timeout
import torch
import queue, threading, time, weakref
from typing import Any, Generator, Hashable, Optional, Union
from typing import Any, Hashable
from pydantic import BaseModel
from sd_internal import Request, Response, runtime, device_manager
from sd_internal import Request, device_manager
THREAD_NAME_PREFIX = 'Runtime-Render/'
ERR_LOCK_FAILED = ' failed to acquire lock within timeout.'

View File

@ -136,9 +136,9 @@ def ping(session_id:str=None):
def render(req : task_manager.ImageRequest):
try:
app.save_model_to_config(req.use_stable_diffusion_model, req.use_vae_model, req.use_hypernetwork_model)
req.use_stable_diffusion_model = model_manager.resolve_sd_model_to_use(req.use_stable_diffusion_model)
req.use_vae_model = model_manager.resolve_vae_model_to_use(req.use_vae_model)
req.use_hypernetwork_model = model_manager.resolve_hypernetwork_model_to_use(req.use_hypernetwork_model)
req.use_stable_diffusion_model = model_manager.resolve_model_to_use(req.use_stable_diffusion_model, model_type='stable-diffusion')
req.use_vae_model = model_manager.resolve_model_to_use(req.use_vae_model, model_type='vae')
req.use_hypernetwork_model = model_manager.resolve_model_to_use(req.use_hypernetwork_model, model_type='hypernetwork')
new_task = task_manager.render(req)
response = {