Make custom VAE an Image Setting, rather than a System Setting; Don't load a VAE into memory by default

This commit is contained in:
cmdr2
2022-11-08 16:54:15 +05:30
parent 67cca3bc00
commit 9bc7521de0
8 changed files with 102 additions and 99 deletions

View File

@ -30,9 +30,6 @@ APP_CONFIG_DEFAULT_MODELS = [
'custom-model', # Check if user has a custom model, use it first.
'sd-v1-4', # Default fallback.
]
APP_CONFIG_DEFAULT_VAE = [
'vae-ft-mse-840000-ema-pruned', # Default fallback.
]
from fastapi import FastAPI, HTTPException
from fastapi.staticfiles import StaticFiles
@ -142,7 +139,7 @@ def setConfig(config):
except Exception as e:
print(traceback.format_exc())
def resolve_model_to_use(model_name:str, model_type:str, model_dir:str, model_extension:str, default_models=[]):
def resolve_model_to_use(model_name:str, model_type:str, model_dir:str, model_extensions:list, default_models=[]):
model_dirs = [os.path.join(MODELS_DIR, model_dir), SD_DIR]
if not model_name: # When None try user configured model.
config = getConfig()
@ -151,43 +148,38 @@ def resolve_model_to_use(model_name:str, model_type:str, model_dir:str, model_ex
if model_name:
# Check models directory
models_dir_path = os.path.join(MODELS_DIR, model_dir, model_name)
if os.path.exists(models_dir_path + model_extension):
return models_dir_path
if os.path.exists(model_name + model_extension):
# Direct Path to file
model_name = os.path.abspath(model_name)
return model_name
for model_extension in model_extensions:
if os.path.exists(models_dir_path + model_extension):
return models_dir_path
if os.path.exists(model_name + model_extension):
# Direct Path to file
model_name = os.path.abspath(model_name)
return model_name
# Default locations
if model_name in default_models:
default_model_path = os.path.join(SD_DIR, model_name)
if os.path.exists(default_model_path + model_extension):
return default_model_path
for model_extension in model_extensions:
if os.path.exists(default_model_path + model_extension):
return default_model_path
# Can't find requested model, check the default paths.
for default_model in default_models:
for model_dir in model_dirs:
default_model_path = os.path.join(model_dir, default_model)
if os.path.exists(default_model_path + model_extension):
if model_name is not None:
print(f'Could not find the configured custom model {model_name}{model_extension}. Using the default one: {default_model_path}{model_extension}')
return default_model_path
for model_extension in model_extensions:
if os.path.exists(default_model_path + model_extension):
if model_name is not None:
print(f'Could not find the configured custom model {model_name}{model_extension}. Using the default one: {default_model_path}{model_extension}')
return default_model_path
raise Exception('No valid models found.')
def resolve_ckpt_to_use(model_name:str=None):
return resolve_model_to_use(model_name, model_type='stable-diffusion', model_dir='stable-diffusion', model_extension='.ckpt', default_models=APP_CONFIG_DEFAULT_MODELS)
return resolve_model_to_use(model_name, model_type='stable-diffusion', model_dir='stable-diffusion', model_extensions=['.ckpt'], default_models=APP_CONFIG_DEFAULT_MODELS)
def resolve_vae_to_use(ckpt_model_path:str=None):
if ckpt_model_path is not None:
if os.path.exists(ckpt_model_path + '.vae.pt'):
return ckpt_model_path
ckpt_model_name = os.path.basename(ckpt_model_path)
model_dirs = [os.path.join(MODELS_DIR, 'stable-diffusion'), SD_DIR]
for model_dir in model_dirs:
default_model_path = os.path.join(model_dir, ckpt_model_name)
if os.path.exists(default_model_path + '.vae.pt'):
return default_model_path
return resolve_model_to_use(model_name=None, model_type='vae', model_dir='stable-diffusion', model_extension='.vae.pt', default_models=APP_CONFIG_DEFAULT_VAE)
def resolve_vae_to_use(model_name:str=None):
try:
return resolve_model_to_use(model_name, model_type='vae', model_dir='vae', model_extensions=['.vae.pt', '.ckpt'], default_models=[])
except:
return None
class SetAppConfigRequest(BaseModel):
update_branch: str = None
@ -213,10 +205,6 @@ async def setAppConfig(req : SetAppConfigRequest):
render_devices.append('GPU:' + req.render_devices)
if len(render_devices) > 0:
config['render_devices'] = render_devices
if req.model_vae:
if 'model' not in config:
config['model'] = {}
config['model']['vae'] = req.model_vae
try:
setConfig(config)
return JSONResponse({'status': 'OK'}, headers=NOCACHE_HEADERS)
@ -236,21 +224,25 @@ def getModels():
},
}
def listModels(models_dirname, model_type, model_extensions):
models_dir = os.path.join(MODELS_DIR, models_dirname)
for file in os.listdir(models_dir):
for model_extension in model_extensions:
if file.endswith(model_extension):
model_name = file[:-len(model_extension)]
models['options'][model_type].append(model_name)
models['options'][model_type] = [*set(models['options'][model_type])] # remove duplicates
# custom models
sd_models_dir = os.path.join(MODELS_DIR, 'stable-diffusion')
for model_type, model_extension in [('stable-diffusion', '.ckpt'), ('vae', '.vae.pt')]:
for file in os.listdir(sd_models_dir):
if file.endswith(model_extension):
model_name = file[:-len(model_extension)]
models['options'][model_type].append(model_name)
listModels(models_dirname='stable-diffusion', model_type='stable-diffusion', model_extensions=['.ckpt'])
listModels(models_dirname='vae', model_type='vae', model_extensions=['.vae.pt', '.ckpt'])
# legacy
custom_weight_path = os.path.join(SD_DIR, 'custom-model.ckpt')
if os.path.exists(custom_weight_path):
models['options']['stable-diffusion'].append('custom-model')
models['active']['vae'] = os.path.basename(task_manager.default_vae_to_load)
return models
def getUIPlugins():
@ -307,12 +299,17 @@ def ping(session_id:str=None):
response['session'] = 'pending'
return JSONResponse(response, headers=NOCACHE_HEADERS)
def save_model_to_config(model_name):
def save_model_to_config(ckpt_model_name, vae_model_name):
config = getConfig()
if 'model' not in config:
config['model'] = {}
config['model']['stable-diffusion'] = model_name
config['model']['stable-diffusion'] = ckpt_model_name
config['model']['vae'] = vae_model_name
if vae_model_name is None or vae_model_name == "":
del config['model']['vae']
setConfig(config)
@app.post('/render')
@ -325,9 +322,9 @@ def render(req : task_manager.ImageRequest):
if req.use_face_correction and task_manager.is_alive(0) <= 0: #TODO Remove when GFPGANer is fixed upstream.
raise HTTPException(status_code=412, detail=f'GFPGANer only works GPU:0, use CUDA_VISIBLE_DEVICES if GFPGANer is needed on a specific GPU.') # HTTP412 Precondition Failed
try:
save_model_to_config(req.use_stable_diffusion_model)
save_model_to_config(req.use_stable_diffusion_model, req.use_vae_model)
req.use_stable_diffusion_model = resolve_ckpt_to_use(req.use_stable_diffusion_model)
req.use_vae_model = resolve_vae_to_use(ckpt_model_path=req.use_stable_diffusion_model)
req.use_vae_model = resolve_vae_to_use(req.use_vae_model)
new_task = task_manager.render(req)
response = {
'status': str(task_manager.current_state),
@ -406,7 +403,7 @@ config = getConfig()
# Start the task_manager
task_manager.default_model_to_load = resolve_ckpt_to_use()
task_manager.default_vae_to_load = resolve_vae_to_use(ckpt_model_path=task_manager.default_model_to_load)
task_manager.default_vae_to_load = resolve_vae_to_use()
if 'render_devices' in config: # Start a new thread for each device.
if isinstance(config['render_devices'], str):
config['render_devices'] = config['render_devices'].split(',')