Merge branch 'beta' into serverip

This commit is contained in:
cmdr2
2022-11-30 14:00:12 +05:30
committed by GitHub
19 changed files with 815 additions and 376 deletions

View File

@ -117,6 +117,8 @@ def setConfig(config):
bind_ip = '0.0.0.0' if config['net']['listen_to_network'] else '127.0.0.1'
config_bat.append(f"@set SD_UI_BIND_IP={bind_ip}")
config_bat.append(f"@set test_sd2={'Y' if config.get('test_sd2', False) else 'N'}")
if len(config_bat) > 0:
with open(config_bat_path, 'w', encoding='utf-8') as f:
f.write('\r\n'.join(config_bat))
@ -134,6 +136,8 @@ def setConfig(config):
bind_ip = '0.0.0.0' if config['net']['listen_to_network'] else '127.0.0.1'
config_sh.append(f"export SD_UI_BIND_IP={bind_ip}")
config_sh.append(f"export test_sd2=\"{'Y' if config.get('test_sd2', False) else 'N'}\"")
if len(config_sh) > 1:
with open(config_sh_path, 'w', encoding='utf-8') as f:
f.write('\n'.join(config_sh))
@ -141,12 +145,19 @@ def setConfig(config):
print(traceback.format_exc())
def resolve_model_to_use(model_name:str, model_type:str, model_dir:str, model_extensions:list, default_models=[]):
config = getConfig()
model_dirs = [os.path.join(MODELS_DIR, model_dir), SD_DIR]
if not model_name: # When None try user configured model.
config = getConfig()
# config = getConfig()
if 'model' in config and model_type in config['model']:
model_name = config['model'][model_type]
if model_name:
is_sd2 = config.get('test_sd2', False)
if model_name.startswith('sd2_') and not is_sd2: # temp hack, until SD2 is unified with 1.4
print('ERROR: Cannot use SD 2.0 models with SD 1.0 code. Using the sd-v1-4 model instead!')
model_name = 'sd-v1-4'
# Check models directory
models_dir_path = os.path.join(MODELS_DIR, model_dir, model_name)
for model_extension in model_extensions:
@ -189,6 +200,7 @@ class SetAppConfigRequest(BaseModel):
ui_open_browser_on_start: bool = None
listen_to_network: bool = None
listen_port: int = None
test_sd2: bool = None
@app.post('/app_config')
async def setAppConfig(req : SetAppConfigRequest):
@ -209,6 +221,8 @@ async def setAppConfig(req : SetAppConfigRequest):
if 'net' not in config:
config['net'] = {}
config['net']['listen_port'] = int(req.listen_port)
if req.test_sd2 is not None:
config['test_sd2'] = req.test_sd2
try:
setConfig(config)
@ -231,9 +245,9 @@ def is_malicious_model(file_path):
return False
except Exception as e:
print('error while scanning', file_path, 'error:', e)
return False
known_models = {}
def getModels():
models = {
'active': {
@ -256,9 +270,14 @@ def getModels():
if not file.endswith(model_extension):
continue
if is_malicious_model(os.path.join(models_dir, file)):
models['scan-error'] = file
return
model_path = os.path.join(models_dir, file)
mtime = os.path.getmtime(model_path)
mod_time = known_models[model_path] if model_path in known_models else -1
if mod_time != mtime:
if is_malicious_model(model_path):
models['scan-error'] = file
return
known_models[model_path] = mtime
model_name = file[:-len(model_extension)]
models['options'][model_type].append(model_name)
@ -440,6 +459,9 @@ class LogSuppressFilter(logging.Filter):
return True
logging.getLogger('uvicorn.access').addFilter(LogSuppressFilter())
# Check models and prepare cache for UI open
getModels()
# Start the task_manager
task_manager.default_model_to_load = resolve_ckpt_to_use()
task_manager.default_vae_to_load = resolve_vae_to_use()