Support custom VAE files; Use vae-ft-mse-840000-ema-pruned as the default VAE, which can be overridden by putting a .vae.pt file inside models/stable-diffusion with the same name as the ckpt model file. The UI / System Settings allows setting the default VAE model to use

This commit is contained in:
cmdr2
2022-10-28 20:06:44 +05:30
parent 79a7cd2938
commit a8c16e39b8
8 changed files with 196 additions and 56 deletions

View File

@ -30,6 +30,9 @@ APP_CONFIG_DEFAULT_MODELS = [
'custom-model', # Check if user has a custom model, use it first.
'sd-v1-4', # Default fallback.
]
APP_CONFIG_DEFAULT_VAE = [
'vae-ft-mse-840000-ema-pruned', # Default fallback.
]
from fastapi import FastAPI, HTTPException
from fastapi.staticfiles import StaticFiles
@ -129,37 +132,57 @@ def setConfig(config):
except Exception as e:
print(traceback.format_exc())
def resolve_model_to_use(model_name:str=None):
def resolve_model_to_use(model_name:str, model_type:str, model_dir:str, model_extension:str, default_models=[]):
model_dirs = [os.path.join(MODELS_DIR, model_dir), SD_DIR]
if not model_name: # When None try user configured model.
config = getConfig()
if 'model' in config and 'stable-diffusion' in config['model']:
model_name = config['model']['stable-diffusion']
if 'model' in config and model_type in config['model']:
model_name = config['model'][model_type]
if model_name:
# Check models directory
models_dir_path = os.path.join(MODELS_DIR, 'stable-diffusion', model_name)
if os.path.exists(models_dir_path + '.ckpt'):
models_dir_path = os.path.join(MODELS_DIR, model_dir, model_name)
if os.path.exists(models_dir_path + model_extension):
return models_dir_path
if os.path.exists(model_name + '.ckpt'):
if os.path.exists(model_name + model_extension):
# Direct Path to file
model_name = os.path.abspath(model_name)
return model_name
# Default locations
if model_name in APP_CONFIG_DEFAULT_MODELS:
if model_name in default_models:
default_model_path = os.path.join(SD_DIR, model_name)
if os.path.exists(default_model_path + '.ckpt'):
if os.path.exists(default_model_path + model_extension):
return default_model_path
# Can't find requested model, check the default paths.
for default_model in APP_CONFIG_DEFAULT_MODELS:
default_model_path = os.path.join(SD_DIR, default_model)
if os.path.exists(default_model_path + '.ckpt'):
if model_name is not None:
print(f'Could not find the configured custom model {model_name}.ckpt. Using the default one: {default_model_path}.ckpt')
return default_model_path
for default_model in default_models:
for model_dir in model_dirs:
default_model_path = os.path.join(model_dir, default_model)
if os.path.exists(default_model_path + model_extension):
if model_name is not None:
print(f'Could not find the configured custom model {model_name}{model_extension}. Using the default one: {default_model_path}{model_extension}')
return default_model_path
raise Exception('No valid models found.')
def resolve_ckpt_to_use(model_name:str=None):
return resolve_model_to_use(model_name, model_type='stable-diffusion', model_dir='stable-diffusion', model_extension='.ckpt', default_models=APP_CONFIG_DEFAULT_MODELS)
def resolve_vae_to_use(ckpt_model_path:str=None):
if ckpt_model_path is not None:
if os.path.exists(ckpt_model_path + '.vae.pt'):
return ckpt_model_path
ckpt_model_name = os.path.basename(ckpt_model_path)
model_dirs = [os.path.join(MODELS_DIR, 'stable-diffusion'), SD_DIR]
for model_dir in model_dirs:
default_model_path = os.path.join(model_dir, ckpt_model_name)
if os.path.exists(default_model_path + '.vae.pt'):
return default_model_path
return resolve_model_to_use(model_name=None, model_type='vae', model_dir='stable-diffusion', model_extension='.vae.pt', default_models=APP_CONFIG_DEFAULT_VAE)
class SetAppConfigRequest(BaseModel):
update_branch: str = None
render_devices: Union[List[str], List[int], str, int] = None
model_vae: str = None
@app.post('/app_config')
async def setAppConfig(req : SetAppConfigRequest):
@ -180,6 +203,10 @@ async def setAppConfig(req : SetAppConfigRequest):
render_devices.append('GPU:' + req.render_devices)
if len(render_devices) > 0:
config['render_devices'] = render_devices
if req.model_vae:
if 'model' not in config:
config['model'] = {}
config['model']['vae'] = req.model_vae
try:
setConfig(config)
return JSONResponse({'status': 'OK'}, headers=NOCACHE_HEADERS)
@ -191,28 +218,28 @@ def getModels():
models = {
'active': {
'stable-diffusion': 'sd-v1-4',
'vae': '',
},
'options': {
'stable-diffusion': ['sd-v1-4'],
'vae': [],
},
}
# custom models
sd_models_dir = os.path.join(MODELS_DIR, 'stable-diffusion')
for file in os.listdir(sd_models_dir):
if file.endswith('.ckpt'):
model_name = os.path.splitext(file)[0]
models['options']['stable-diffusion'].append(model_name)
for model_type, model_extension in [('stable-diffusion', '.ckpt'), ('vae', '.vae.pt')]:
for file in os.listdir(sd_models_dir):
if file.endswith(model_extension):
model_name = file[:-len(model_extension)]
models['options'][model_type].append(model_name)
# legacy
custom_weight_path = os.path.join(SD_DIR, 'custom-model.ckpt')
if os.path.exists(custom_weight_path):
models['active']['stable-diffusion'] = 'custom-model'
models['options']['stable-diffusion'].append('custom-model')
config = getConfig()
if 'model' in config and 'stable-diffusion' in config['model']:
models['active']['stable-diffusion'] = config['model']['stable-diffusion']
models['active']['vae'] = os.path.basename(task_manager.default_vae_to_load)
return models
@ -283,7 +310,8 @@ def render(req : task_manager.ImageRequest):
raise HTTPException(status_code=412, detail=f'GFPGANer only works GPU:0, use CUDA_VISIBLE_DEVICES if GFPGANer is needed on a specific GPU.') # HTTP412 Precondition Failed
try:
save_model_to_config(req.use_stable_diffusion_model)
req.use_stable_diffusion_model = resolve_model_to_use(req.use_stable_diffusion_model)
req.use_stable_diffusion_model = resolve_ckpt_to_use(req.use_stable_diffusion_model)
req.use_vae_model = resolve_vae_to_use(ckpt_model_path=req.use_stable_diffusion_model)
new_task = task_manager.render(req)
response = {
'status': str(task_manager.current_state),
@ -361,7 +389,8 @@ logging.getLogger('uvicorn.access').addFilter(LogSuppressFilter())
config = getConfig()
# Start the task_manager
task_manager.default_model_to_load = resolve_model_to_use()
task_manager.default_model_to_load = resolve_ckpt_to_use()
task_manager.default_vae_to_load = resolve_vae_to_use(ckpt_model_path=task_manager.default_model_to_load)
if 'render_devices' in config: # Start a new thread for each device.
if isinstance(config['render_devices'], str):
config['render_devices'] = config['render_devices'].split(',')