forked from extern/easydiffusion
Simplify the API for resolving model paths; Code cleanup
This commit is contained in:
parent
b40fb3a422
commit
8820814002
@ -28,7 +28,7 @@ def init():
|
|||||||
make_model_folders()
|
make_model_folders()
|
||||||
getModels() # run this once, to cache the picklescan results
|
getModels() # run this once, to cache the picklescan results
|
||||||
|
|
||||||
def resolve_model_to_use(model_name:str, model_type:str):
|
def resolve_model_to_use(model_name:str=None, model_type:str=None):
|
||||||
model_extensions = MODEL_EXTENSIONS.get(model_type, [])
|
model_extensions = MODEL_EXTENSIONS.get(model_type, [])
|
||||||
default_models = DEFAULT_MODELS.get(model_type, [])
|
default_models = DEFAULT_MODELS.get(model_type, [])
|
||||||
config = app.getConfig()
|
config = app.getConfig()
|
||||||
@ -72,21 +72,6 @@ def resolve_model_to_use(model_name:str, model_type:str):
|
|||||||
|
|
||||||
return None
|
return None
|
||||||
|
|
||||||
def resolve_sd_model_to_use(model_name:str=None):
|
|
||||||
return resolve_model_to_use(model_name, model_type='stable-diffusion')
|
|
||||||
|
|
||||||
def resolve_vae_model_to_use(model_name:str=None):
|
|
||||||
return resolve_model_to_use(model_name, model_type='vae')
|
|
||||||
|
|
||||||
def resolve_hypernetwork_model_to_use(model_name:str=None):
|
|
||||||
return resolve_model_to_use(model_name, model_type='hypernetwork')
|
|
||||||
|
|
||||||
def resolve_gfpgan_model_to_use(model_name:str=None):
|
|
||||||
return resolve_model_to_use(model_name, model_type='gfpgan')
|
|
||||||
|
|
||||||
def resolve_realesrgan_model_to_use(model_name:str=None):
|
|
||||||
return resolve_model_to_use(model_name, model_type='realesrgan')
|
|
||||||
|
|
||||||
def make_model_folders():
|
def make_model_folders():
|
||||||
for model_type in KNOWN_MODEL_TYPES:
|
for model_type in KNOWN_MODEL_TYPES:
|
||||||
model_dir_path = os.path.join(app.MODELS_DIR, model_type)
|
model_dir_path = os.path.join(app.MODELS_DIR, model_type)
|
||||||
|
@ -39,18 +39,13 @@ def init(device):
|
|||||||
init_and_load_default_models()
|
init_and_load_default_models()
|
||||||
|
|
||||||
def destroy():
|
def destroy():
|
||||||
model_loader.unload_model(thread_data, 'stable-diffusion')
|
for model_type in ('stable-diffusion', 'hypernetwork', 'gfpgan', 'realesrgan'):
|
||||||
model_loader.unload_model(thread_data, 'gfpgan')
|
model_loader.unload_model(thread_data, model_type)
|
||||||
model_loader.unload_model(thread_data, 'realesrgan')
|
|
||||||
model_loader.unload_model(thread_data, 'hypernetwork')
|
|
||||||
|
|
||||||
def init_and_load_default_models():
|
def init_and_load_default_models():
|
||||||
# init default model paths
|
# init default model paths
|
||||||
thread_data.model_paths['stable-diffusion'] = model_manager.resolve_sd_model_to_use()
|
for model_type in ('stable-diffusion', 'vae', 'hypernetwork', 'gfpgan', 'realesrgan'):
|
||||||
thread_data.model_paths['vae'] = model_manager.resolve_vae_model_to_use()
|
thread_data.model_paths[model_type] = model_manager.resolve_model_to_use(model_type=model_type)
|
||||||
thread_data.model_paths['hypernetwork'] = model_manager.resolve_hypernetwork_model_to_use()
|
|
||||||
thread_data.model_paths['gfpgan'] = model_manager.resolve_gfpgan_model_to_use()
|
|
||||||
thread_data.model_paths['realesrgan'] = model_manager.resolve_realesrgan_model_to_use()
|
|
||||||
|
|
||||||
# load mandatory models
|
# load mandatory models
|
||||||
model_loader.load_model(thread_data, 'stable-diffusion')
|
model_loader.load_model(thread_data, 'stable-diffusion')
|
||||||
@ -119,8 +114,8 @@ def apply_filters(req: Request, images: list, user_stopped):
|
|||||||
return images
|
return images
|
||||||
|
|
||||||
filters = []
|
filters = []
|
||||||
if req.use_face_correction.startswith('GFPGAN'): filters.append((image_filters.apply_gfpgan, model_manager.resolve_gfpgan_model_to_use(req.use_face_correction)))
|
if req.use_face_correction.startswith('GFPGAN'): filters.append((image_filters.apply_gfpgan, model_manager.resolve_model_to_use(req.use_face_correction, model_type='gfpgan')))
|
||||||
if req.use_upscale.startswith('RealESRGAN'): filters.append((image_filters.apply_realesrgan, model_manager.resolve_realesrgan_model_to_use(req.use_upscale)))
|
if req.use_upscale.startswith('RealESRGAN'): filters.append((image_filters.apply_realesrgan, model_manager.resolve_model_to_use(req.use_upscale, model_type='realesrgan')))
|
||||||
|
|
||||||
filtered_images = []
|
filtered_images = []
|
||||||
for img, seed, _ in images:
|
for img, seed, _ in images:
|
||||||
|
@ -11,10 +11,10 @@ TASK_TTL = 15 * 60 # seconds, Discard last session's task timeout
|
|||||||
|
|
||||||
import torch
|
import torch
|
||||||
import queue, threading, time, weakref
|
import queue, threading, time, weakref
|
||||||
from typing import Any, Generator, Hashable, Optional, Union
|
from typing import Any, Hashable
|
||||||
|
|
||||||
from pydantic import BaseModel
|
from pydantic import BaseModel
|
||||||
from sd_internal import Request, Response, runtime, device_manager
|
from sd_internal import Request, device_manager
|
||||||
|
|
||||||
THREAD_NAME_PREFIX = 'Runtime-Render/'
|
THREAD_NAME_PREFIX = 'Runtime-Render/'
|
||||||
ERR_LOCK_FAILED = ' failed to acquire lock within timeout.'
|
ERR_LOCK_FAILED = ' failed to acquire lock within timeout.'
|
||||||
|
@ -136,9 +136,9 @@ def ping(session_id:str=None):
|
|||||||
def render(req : task_manager.ImageRequest):
|
def render(req : task_manager.ImageRequest):
|
||||||
try:
|
try:
|
||||||
app.save_model_to_config(req.use_stable_diffusion_model, req.use_vae_model, req.use_hypernetwork_model)
|
app.save_model_to_config(req.use_stable_diffusion_model, req.use_vae_model, req.use_hypernetwork_model)
|
||||||
req.use_stable_diffusion_model = model_manager.resolve_sd_model_to_use(req.use_stable_diffusion_model)
|
req.use_stable_diffusion_model = model_manager.resolve_model_to_use(req.use_stable_diffusion_model, model_type='stable-diffusion')
|
||||||
req.use_vae_model = model_manager.resolve_vae_model_to_use(req.use_vae_model)
|
req.use_vae_model = model_manager.resolve_model_to_use(req.use_vae_model, model_type='vae')
|
||||||
req.use_hypernetwork_model = model_manager.resolve_hypernetwork_model_to_use(req.use_hypernetwork_model)
|
req.use_hypernetwork_model = model_manager.resolve_model_to_use(req.use_hypernetwork_model, model_type='hypernetwork')
|
||||||
|
|
||||||
new_task = task_manager.render(req)
|
new_task = task_manager.render(req)
|
||||||
response = {
|
response = {
|
||||||
|
Loading…
Reference in New Issue
Block a user