forked from extern/easydiffusion
Simplify the API for resolving model paths; Code cleanup
This commit is contained in:
parent
b40fb3a422
commit
8820814002
@ -28,7 +28,7 @@ def init():
|
||||
make_model_folders()
|
||||
getModels() # run this once, to cache the picklescan results
|
||||
|
||||
def resolve_model_to_use(model_name:str, model_type:str):
|
||||
def resolve_model_to_use(model_name:str=None, model_type:str=None):
|
||||
model_extensions = MODEL_EXTENSIONS.get(model_type, [])
|
||||
default_models = DEFAULT_MODELS.get(model_type, [])
|
||||
config = app.getConfig()
|
||||
@ -72,21 +72,6 @@ def resolve_model_to_use(model_name:str, model_type:str):
|
||||
|
||||
return None
|
||||
|
||||
def resolve_sd_model_to_use(model_name:str=None):
|
||||
return resolve_model_to_use(model_name, model_type='stable-diffusion')
|
||||
|
||||
def resolve_vae_model_to_use(model_name:str=None):
|
||||
return resolve_model_to_use(model_name, model_type='vae')
|
||||
|
||||
def resolve_hypernetwork_model_to_use(model_name:str=None):
|
||||
return resolve_model_to_use(model_name, model_type='hypernetwork')
|
||||
|
||||
def resolve_gfpgan_model_to_use(model_name:str=None):
|
||||
return resolve_model_to_use(model_name, model_type='gfpgan')
|
||||
|
||||
def resolve_realesrgan_model_to_use(model_name:str=None):
|
||||
return resolve_model_to_use(model_name, model_type='realesrgan')
|
||||
|
||||
def make_model_folders():
|
||||
for model_type in KNOWN_MODEL_TYPES:
|
||||
model_dir_path = os.path.join(app.MODELS_DIR, model_type)
|
||||
|
@ -39,18 +39,13 @@ def init(device):
|
||||
init_and_load_default_models()
|
||||
|
||||
def destroy():
|
||||
model_loader.unload_model(thread_data, 'stable-diffusion')
|
||||
model_loader.unload_model(thread_data, 'gfpgan')
|
||||
model_loader.unload_model(thread_data, 'realesrgan')
|
||||
model_loader.unload_model(thread_data, 'hypernetwork')
|
||||
for model_type in ('stable-diffusion', 'hypernetwork', 'gfpgan', 'realesrgan'):
|
||||
model_loader.unload_model(thread_data, model_type)
|
||||
|
||||
def init_and_load_default_models():
|
||||
# init default model paths
|
||||
thread_data.model_paths['stable-diffusion'] = model_manager.resolve_sd_model_to_use()
|
||||
thread_data.model_paths['vae'] = model_manager.resolve_vae_model_to_use()
|
||||
thread_data.model_paths['hypernetwork'] = model_manager.resolve_hypernetwork_model_to_use()
|
||||
thread_data.model_paths['gfpgan'] = model_manager.resolve_gfpgan_model_to_use()
|
||||
thread_data.model_paths['realesrgan'] = model_manager.resolve_realesrgan_model_to_use()
|
||||
for model_type in ('stable-diffusion', 'vae', 'hypernetwork', 'gfpgan', 'realesrgan'):
|
||||
thread_data.model_paths[model_type] = model_manager.resolve_model_to_use(model_type=model_type)
|
||||
|
||||
# load mandatory models
|
||||
model_loader.load_model(thread_data, 'stable-diffusion')
|
||||
@ -119,8 +114,8 @@ def apply_filters(req: Request, images: list, user_stopped):
|
||||
return images
|
||||
|
||||
filters = []
|
||||
if req.use_face_correction.startswith('GFPGAN'): filters.append((image_filters.apply_gfpgan, model_manager.resolve_gfpgan_model_to_use(req.use_face_correction)))
|
||||
if req.use_upscale.startswith('RealESRGAN'): filters.append((image_filters.apply_realesrgan, model_manager.resolve_realesrgan_model_to_use(req.use_upscale)))
|
||||
if req.use_face_correction.startswith('GFPGAN'): filters.append((image_filters.apply_gfpgan, model_manager.resolve_model_to_use(req.use_face_correction, model_type='gfpgan')))
|
||||
if req.use_upscale.startswith('RealESRGAN'): filters.append((image_filters.apply_realesrgan, model_manager.resolve_model_to_use(req.use_upscale, model_type='realesrgan')))
|
||||
|
||||
filtered_images = []
|
||||
for img, seed, _ in images:
|
||||
|
@ -11,10 +11,10 @@ TASK_TTL = 15 * 60 # seconds, Discard last session's task timeout
|
||||
|
||||
import torch
|
||||
import queue, threading, time, weakref
|
||||
from typing import Any, Generator, Hashable, Optional, Union
|
||||
from typing import Any, Hashable
|
||||
|
||||
from pydantic import BaseModel
|
||||
from sd_internal import Request, Response, runtime, device_manager
|
||||
from sd_internal import Request, device_manager
|
||||
|
||||
THREAD_NAME_PREFIX = 'Runtime-Render/'
|
||||
ERR_LOCK_FAILED = ' failed to acquire lock within timeout.'
|
||||
|
@ -136,9 +136,9 @@ def ping(session_id:str=None):
|
||||
def render(req : task_manager.ImageRequest):
|
||||
try:
|
||||
app.save_model_to_config(req.use_stable_diffusion_model, req.use_vae_model, req.use_hypernetwork_model)
|
||||
req.use_stable_diffusion_model = model_manager.resolve_sd_model_to_use(req.use_stable_diffusion_model)
|
||||
req.use_vae_model = model_manager.resolve_vae_model_to_use(req.use_vae_model)
|
||||
req.use_hypernetwork_model = model_manager.resolve_hypernetwork_model_to_use(req.use_hypernetwork_model)
|
||||
req.use_stable_diffusion_model = model_manager.resolve_model_to_use(req.use_stable_diffusion_model, model_type='stable-diffusion')
|
||||
req.use_vae_model = model_manager.resolve_model_to_use(req.use_vae_model, model_type='vae')
|
||||
req.use_hypernetwork_model = model_manager.resolve_model_to_use(req.use_hypernetwork_model, model_type='hypernetwork')
|
||||
|
||||
new_task = task_manager.render(req)
|
||||
response = {
|
||||
|
Loading…
Reference in New Issue
Block a user