Support an arbitrary number of custom models, placed in the models/stable-diffusion folder. Shows an option in the UI to select which model to use

This commit is contained in:
cmdr2
2022-10-06 14:28:02 +05:30
parent 703f987825
commit 201a053025
7 changed files with 165 additions and 15 deletions

View File

@ -4,13 +4,14 @@ import traceback
import sys
import os
SCRIPT_DIR = os.getcwd()
print('started in ', SCRIPT_DIR)
SD_DIR = os.getcwd()
print('started in ', SD_DIR)
SD_UI_DIR = os.getenv('SD_UI_PATH', None)
sys.path.append(os.path.dirname(SD_UI_DIR))
CONFIG_DIR = os.path.join(SD_UI_DIR, '..', 'scripts')
CONFIG_DIR = os.path.abspath(os.path.join(SD_UI_DIR, '..', 'scripts'))
MODELS_DIR = os.path.abspath(os.path.join(SD_DIR, '..', 'models'))
OUTPUT_DIRNAME = "Stable Diffusion UI" # in the user's home folder
@ -57,6 +58,7 @@ class ImageRequest(BaseModel):
use_full_precision: bool = False
use_face_correction: str = None # or "GFPGANv1.3"
use_upscale: str = None # or "RealESRGAN_x4plus" or "RealESRGAN_x4plus_anime_6B"
use_stable_diffusion_model: str = "sd-v1-4"
show_only_filtered_image: bool = False
stream_progress_updates: bool = False
@ -85,9 +87,7 @@ async def ping():
from sd_internal import runtime
custom_weight_path = os.path.join(SCRIPT_DIR, 'custom-model.ckpt')
ckpt_to_use = "sd-v1-4" if not os.path.exists(custom_weight_path) else "custom-model"
runtime.load_model_ckpt(ckpt_to_use=ckpt_to_use)
runtime.load_model_ckpt(ckpt_to_use=get_initial_model_to_load())
model_loaded = True
model_is_loading = False
@ -97,6 +97,46 @@ async def ping():
print(traceback.format_exc())
return HTTPException(status_code=500, detail=str(e))
# needs to support the legacy installations
def get_initial_model_to_load():
custom_weight_path = os.path.join(SD_DIR, 'custom-model.ckpt')
ckpt_to_use = "sd-v1-4" if not os.path.exists(custom_weight_path) else "custom-model"
ckpt_to_use = os.path.join(SD_DIR, ckpt_to_use)
config = getConfig()
if 'model' in config and 'stable-diffusion' in config['model']:
model_name = config['model']['stable-diffusion']
model_path = resolve_model_to_use(model_name)
if os.path.exists(model_path + '.ckpt'):
ckpt_to_use = model_path
else:
print('Could not find the configured custom model at:', model_path + '.ckpt', '. Using the default one:', ckpt_to_use + '.ckpt')
return ckpt_to_use
def resolve_model_to_use(model_name):
if model_name in ('sd-v1-4', 'custom-model'):
model_path = os.path.join(MODELS_DIR, 'stable-diffusion', model_name)
legacy_model_path = os.path.join(SD_DIR, model_name)
if not os.path.exists(model_path + '.ckpt') and os.path.exists(legacy_model_path + '.ckpt'):
model_path = legacy_model_path
else:
model_path = os.path.join(MODELS_DIR, 'stable-diffusion', model_name)
return model_path
def save_model_to_config(model_name):
config = getConfig()
if 'model' not in config:
config['model'] = {}
config['model']['stable-diffusion'] = model_name
setConfig(config)
@app.post('/image')
def image(req : ImageRequest):
from sd_internal import runtime
@ -127,6 +167,10 @@ def image(req : ImageRequest):
r.stream_progress_updates = True # the underlying implementation only supports streaming
r.stream_image_progress = req.stream_image_progress
r.use_stable_diffusion_model = resolve_model_to_use(req.use_stable_diffusion_model)
save_model_to_config(req.use_stable_diffusion_model)
try:
if not req.stream_progress_updates:
r.stream_image_progress = False
@ -205,13 +249,62 @@ def getAppConfig():
return HTTPException(status_code=500, detail="No config file")
with open(config_json_path, 'r') as f:
config_json_str = f.read()
config = json.loads(config_json_str)
return config
return json.load(f)
except Exception as e:
print(traceback.format_exc())
return HTTPException(status_code=500, detail=str(e))
def getConfig():
try:
config_json_path = os.path.join(CONFIG_DIR, 'config.json')
if not os.path.exists(config_json_path):
return {}
with open(config_json_path, 'r') as f:
return json.load(f)
except Exception as e:
return {}
def setConfig(config):
try:
config_json_path = os.path.join(CONFIG_DIR, 'config.json')
with open(config_json_path, 'w') as f:
return json.dump(config, f)
except:
print(traceback.format_exc())
@app.get('/models')
def getModels():
models = {
'active': {
'stable-diffusion': 'sd-v1-4',
},
'options': {
'stable-diffusion': ['sd-v1-4'],
},
}
# custom models
sd_models_dir = os.path.join(MODELS_DIR, 'stable-diffusion')
for file in os.listdir(sd_models_dir):
if file.endswith('.ckpt'):
model_name = os.path.splitext(file)[0]
models['options']['stable-diffusion'].append(model_name)
# legacy
custom_weight_path = os.path.join(SD_DIR, 'custom-model.ckpt')
if os.path.exists(custom_weight_path):
models['active']['stable-diffusion'] = 'custom-model'
models['options']['stable-diffusion'].append('custom-model')
config = getConfig()
if 'model' in config and 'stable-diffusion' in config['model']:
models['active']['stable-diffusion'] = config['model']['stable-diffusion']
return models
@app.get('/modifiers.json')
def read_modifiers():
headers = {"Cache-Control": "no-cache, no-store, must-revalidate", "Pragma": "no-cache", "Expires": "0"}