First draft for Multi-GPU support

This commit is contained in:
Marc-Andre Ferland
2022-10-16 21:41:39 -04:00
parent 2edc06c662
commit 7c72608e1c
3 changed files with 584 additions and 376 deletions

View File

@ -15,14 +15,24 @@ MODELS_DIR = os.path.abspath(os.path.join(SD_DIR, '..', 'models'))
# Folder name used for outputs; it lives in the user's home folder.
OUTPUT_DIRNAME = "Stable Diffusion UI"

# Discard last session's task timeout (presumably seconds -- confirm against task_manager).
TASK_TTL = 15 * 60

# Values used when config.json does not provide a configuration.
APP_CONFIG_DEFAULTS = {
    # 'auto' selects the cuda device with the most free memory,
    # 'cuda' uses the currently active cuda device.
    # Accepted shapes: ['cuda'] or ['CPU', 'GPU:0', 'GPU:1', ...] or ['cpu']
    'render_devices': ['auto'],
    'update_branch': 'main',
}

# Model names kept for legacy installations, listed in lookup order.
APP_CONFIG_DEFAULT_MODELS = [
    # A user-supplied custom model is checked first.
    'custom-model',
    # Default fallback.
    'sd-v1-4',
]
from fastapi import FastAPI, HTTPException
from fastapi.staticfiles import StaticFiles
from starlette.responses import FileResponse, JSONResponse, StreamingResponse
from pydantic import BaseModel
import logging
import queue, threading, time
from typing import Any, Generator, Hashable, Optional, Union
#import queue, threading, time
from typing import Any, Generator, Hashable, List, Optional, Union
from sd_internal import Request, Response, task_manager
@ -37,52 +47,173 @@ ACCESS_LOG_SUPPRESS_PATH_PREFIXES = ['/ping', '/image', '/modifier-thumbnails']
# Headers attached to responses that must not be cached.
NOCACHE_HEADERS = {
    "Cache-Control": "no-cache, no-store, must-revalidate",
    "Pragma": "no-cache",
    "Expires": "0",
}

# Serve the UI's media folder as static files under /media.
app.mount(
    '/media',
    StaticFiles(directory=os.path.join(SD_UI_DIR, 'media/')),
    name="media",
)

# getConfig() cache: last parsed config.json and the mtime recorded for it.
config_cached = None
config_last_mod_time = 0
def getConfig(default_val=APP_CONFIG_DEFAULTS):
    """Return the parsed config.json, re-reading it only when the file changed.

    The parsed content is cached in the module globals ``config_cached`` and
    ``config_last_mod_time``. The returned object is the cached (or default)
    object itself, not a copy.

    Args:
        default_val: Value returned when config.json does not exist or cannot
            be read/parsed.

    Returns:
        The parsed JSON content, or ``default_val`` on a missing file or error.
    """
    global config_cached, config_last_mod_time
    try:
        config_json_path = os.path.join(CONFIG_DIR, 'config.json')
        if not os.path.exists(config_json_path):
            return default_val
        # Sample the mtime BEFORE reading. If the file is rewritten while it
        # is being read, the recorded mtime is then older than the file's and
        # the next call reloads it, instead of pinning stale content.
        mtime = os.path.getmtime(config_json_path)
        cache_is_valid = config_last_mod_time > 0 and config_cached is not None
        if cache_is_valid and mtime <= config_last_mod_time:
            # Don't read if file was not modified
            return config_cached
        with open(config_json_path, 'r') as f:
            config_cached = json.load(f)
        config_last_mod_time = mtime
        return config_cached
    except Exception as e:
        # Best effort: report the problem and fall back to the default.
        print(str(e))
        print(traceback.format_exc())
        return default_val
def setConfig(config):
    """Persist the app config to config.json, config.bat and config.sh.

    Each file is written independently and on a best-effort basis: a failure
    is printed (traceback) and does not prevent the other files from being
    written. If ``config`` has no 'update_branch' key, the two script files
    are not written (the KeyError is reported like any other failure).

    Args:
        config: dict with the app configuration. 'render_devices' may be a
            list of device names or a comma-separated string; entries of the
            form 'GPU:<n>' are exported through CUDA_VISIBLE_DEVICES.

    Returns:
        None.
    """
    try:  # config.json
        # Serialize first so a non-serializable config cannot truncate the file.
        config_json = json.dumps(config)
        config_json_path = os.path.join(CONFIG_DIR, 'config.json')
        with open(config_json_path, 'w') as f:
            f.write(config_json)
    except Exception:
        print(traceback.format_exc())

    render_devices = config.get('render_devices') or []
    if isinstance(render_devices, str):
        render_devices = render_devices.split(',')
    # CUDA_VISIBLE_DEVICES takes bare device indices, so drop the 'GPU:' prefix.
    # NOTE(review): restricting visible devices renumbers them inside the
    # process -- confirm the render threads expect that.
    gpu_devices = [
        dev[len('GPU:'):]
        for dev in render_devices
        if isinstance(dev, str) and dev.startswith('GPU:')
    ]

    try:  # config.bat
        config_bat = [
            f"@set update_branch={config['update_branch']}",
        ]
        if gpu_devices:
            config_bat.append(f"@set CUDA_VISIBLE_DEVICES={','.join(gpu_devices)}")
        config_bat_path = os.path.join(CONFIG_DIR, 'config.bat')
        # newline='' writes the explicit '\r\n' untouched on every platform.
        with open(config_bat_path, 'w', newline='') as f:
            f.write('\r\n'.join(config_bat))
    except Exception:
        print(traceback.format_exc())

    try:  # config.sh
        config_sh = [
            '#!/bin/bash',
            f"export update_branch={config['update_branch']}",
        ]
        if gpu_devices:
            config_sh.append(f"export CUDA_VISIBLE_DEVICES={','.join(gpu_devices)}")
        config_sh_path = os.path.join(CONFIG_DIR, 'config.sh')
        # newline='\n' keeps LF line endings even when written on Windows.
        with open(config_sh_path, 'w', newline='\n') as f:
            f.write('\n'.join(config_sh))
    except Exception:
        print(traceback.format_exc())
def resolve_model_to_use(model_name:str=None):
    """Resolve a model name to the path of its checkpoint, without the '.ckpt' suffix.

    Lookup order for a given (or user-configured) name:
      1. ``model_name`` taken as a direct path,
      2. the 'stable-diffusion' folder inside MODELS_DIR,
      3. SD_DIR, only for names listed in APP_CONFIG_DEFAULT_MODELS.
    If none of these exist, the first model of APP_CONFIG_DEFAULT_MODELS found
    in SD_DIR is used instead.

    Args:
        model_name: Model name or path without extension. When empty/None the
            model from the user config (config['model']['stable-diffusion'])
            is tried.

    Returns:
        Path to the model file without the '.ckpt' extension.

    Raises:
        Exception: when neither the requested nor a default model exists.
    """
    if not model_name:  # When None try user configured model.
        config = getConfig()
        if 'model' in config and 'stable-diffusion' in config['model']:
            model_name = config['model']['stable-diffusion']

    if model_name:
        candidates = [
            model_name,  # Direct path to file
            os.path.join(MODELS_DIR, 'stable-diffusion', model_name),  # Models directory
        ]
        if model_name in APP_CONFIG_DEFAULT_MODELS:
            candidates.append(os.path.join(SD_DIR, model_name))  # Default locations
        for candidate in candidates:
            if os.path.exists(candidate + '.ckpt'):
                return candidate

    # Can't find requested model, check the default paths.
    for default_model in APP_CONFIG_DEFAULT_MODELS:
        # Keep the extension out of the returned path, like every branch above.
        default_model_path = os.path.join(SD_DIR, default_model)
        if os.path.exists(default_model_path + '.ckpt'):
            if model_name:
                print('Could not find the configured custom model at:', model_name + '.ckpt', '. Using the default one:', default_model_path + '.ckpt')
            return default_model_path
    raise Exception('No valid models found.')
class SetAppConfigRequest(BaseModel):
    """Request body of POST /app_config; fields left as None are not changed."""
    # Branch used for updates. Declared once: the earlier '"main"' default was
    # shadowed by a second declaration of the same field.
    update_branch: Optional[str] = None
    # Devices to render on: a list of names or indices, a comma-separated
    # string, or a single index.
    render_devices: Optional[Union[List[str], List[int], str, int]] = None
# needs to support the legacy installations
def get_initial_model_to_load():
@app.post('/app_config')
async def setAppConfig(req : SetAppConfigRequest):
    """Update the stored app config with the values provided in the request.

    Args:
        req: 'update_branch' replaces the stored branch when non-empty.
            'render_devices' may be a list, a comma-separated string or a
            single integer; integers (and all-digit strings) are stored as
            'GPU:<n>', other names are stored as given.

    Returns:
        JSONResponse ``{'status': 'OK'}`` with no-cache headers.

    Raises:
        HTTPException: status 500 when saving the configuration raises.
    """
    # Work on a copy so the dict handed out by getConfig() (its cache or the
    # shared defaults) is not modified in place.
    config = dict(getConfig())
    if req.update_branch:
        config['update_branch'] = req.update_branch

    requested = req.render_devices
    if requested is not None:  # 'is not None' so that device index 0 is accepted.
        if isinstance(requested, str):
            requested = requested.split(',')
        elif isinstance(requested, int):
            requested = [requested]
        render_devices = []
        for device in requested:
            if isinstance(device, int):
                render_devices.append(f'GPU:{device}')
                continue
            device = str(device).strip()
            if device.isdigit():
                # A bare number is a GPU index; request validation may hand
                # list items over as strings.
                device = f'GPU:{device}'
            if device:
                render_devices.append(device)
        if len(render_devices) > 0:
            config['render_devices'] = render_devices

    try:
        setConfig(config)
        return JSONResponse({'status': 'OK'}, headers=NOCACHE_HEADERS)
    except Exception as e:
        print(traceback.format_exc())
        # Raise (not return) so the client receives an actual error status.
        raise HTTPException(status_code=500, detail=str(e)) from e
def getModels():
    """List the selectable stable-diffusion models and the active one.

    Returns:
        dict of the form
        ``{'active': {'stable-diffusion': name}, 'options': {'stable-diffusion': [names]}}``.
        'sd-v1-4' is always listed; every '*.ckpt' file in the
        'stable-diffusion' folder of MODELS_DIR and a legacy
        'custom-model.ckpt' in SD_DIR are added. The active model is
        'sd-v1-4', overridden by the legacy custom model when present, then by
        config['model']['stable-diffusion'] when configured.
    """
    models = {
        'active': {
            'stable-diffusion': 'sd-v1-4',
        },
        'options': {
            'stable-diffusion': ['sd-v1-4'],
        },
    }
    options = models['options']['stable-diffusion']

    def add_option(name):
        # A model can be present in several locations; list it only once.
        if name not in options:
            options.append(name)

    # custom models
    sd_models_dir = os.path.join(MODELS_DIR, 'stable-diffusion')
    if os.path.isdir(sd_models_dir):  # the folder may not exist yet
        for filename in os.listdir(sd_models_dir):
            if filename.endswith('.ckpt'):
                add_option(os.path.splitext(filename)[0])

    # legacy
    custom_weight_path = os.path.join(SD_DIR, 'custom-model.ckpt')
    if os.path.exists(custom_weight_path):
        models['active']['stable-diffusion'] = 'custom-model'
        add_option('custom-model')

    config = getConfig()
    if 'model' in config and 'stable-diffusion' in config['model']:
        models['active']['stable-diffusion'] = config['model']['stable-diffusion']

    return models
# NOTE(review): a definition named resolve_model_to_use with an optional
# model_name parameter also appears earlier in this file; this one looks like
# the previous version of it -- TODO confirm which one is meant to remain.
# Purpose of the visible lines: build the path (without '.ckpt') of one of the
# two built-in model names, preferring the models directory over SD_DIR.
def resolve_model_to_use(model_name):
if model_name in ('sd-v1-4', 'custom-model'):
model_path = os.path.join(MODELS_DIR, 'stable-diffusion', model_name)
legacy_model_path = os.path.join(SD_DIR, model_name)
# Use the SD_DIR copy only when the models directory has no such file but SD_DIR does.
if not os.path.exists(model_path + '.ckpt') and os.path.exists(legacy_model_path + '.ckpt'):
model_path = legacy_model_path
# GET /get/<key>: returns the data selected by `key`:
#   'app_config' -> content of getConfig(), 'models' -> getModels(),
#   'modifiers'  -> the modifiers.json file, 'output_dir' -> the value of `outpath`.
# NOTE(review): the HTTPException objects below are returned, not raised --
# confirm that this yields the intended HTTP status codes.
@app.get('/get/{key:path}')
def read_web_data(key:str=None):
if not key: # /get without parameters, stable-diffusion easter egg.
return HTTPException(status_code=418, detail="StableDiffusion is drawing a teapot!") # HTTP418 I'm a teapot
elif key == 'app_config':
# default_val=None lets a missing or unreadable config be told apart from a real one.
config = getConfig(default_val=None)
if config is None:
return HTTPException(status_code=500, detail="Config file is missing or unreadable")
return JSONResponse(config, headers=NOCACHE_HEADERS)
elif key == 'models':
return JSONResponse(getModels(), headers=NOCACHE_HEADERS)
elif key == 'modifiers': return FileResponse(os.path.join(SD_UI_DIR, 'modifiers.json'), headers=NOCACHE_HEADERS)
elif key == 'output_dir': return JSONResponse({ 'output_dir': outpath }, headers=NOCACHE_HEADERS)
else:
# NOTE(review): the next two lines use model_name/model_path, which this
# handler never defines; they look like the tail of the single-argument
# resolve_model_to_use above rather than part of this handler -- TODO confirm.
model_path = os.path.join(MODELS_DIR, 'stable-diffusion', model_name)
return model_path
@app.on_event("shutdown")
def shutdown_event():
    """Signal render thread to close on shutdown."""
    shutdown_signal = SystemExit('Application shutting down.')
    task_manager.current_state_error = shutdown_signal
# GET /: serves the UI's index.html with no-cache headers.
@app.get('/')
def read_root():
return FileResponse(os.path.join(SD_UI_DIR, 'index.html'), headers=NOCACHE_HEADERS)
# NOTE(review): this return uses `key`, which read_root does not define;
# presumably it is the fallback branch of read_web_data -- TODO confirm.
return HTTPException(status_code=404, detail=f'Request for unknown {key}') # HTTP404 Not Found
@app.get('/ping') # Get server and optionally session status.
def ping(session_id:str=None):
if not task_manager.render_thread.is_alive(): # Render thread is dead.
if task_manager.current_state_error: return HTTPException(status_code=500, detail=str(current_state_error))
if task_manager.is_alive() <= 0: # Check that render threads are alive.
if task_manager.current_state_error: return HTTPException(status_code=500, detail=str(task_manager.current_state_error))
return HTTPException(status_code=500, detail='Render thread is dead.')
if task_manager.current_state_error and not isinstance(task_manager.current_state_error, StopAsyncIteration): return HTTPException(status_code=500, detail=str(current_state_error))
if task_manager.current_state_error and not isinstance(task_manager.current_state_error, StopAsyncIteration): return HTTPException(status_code=500, detail=str(task_manager.current_state_error))
# Alive
response = {'status': str(task_manager.current_state)}
if session_id:
@ -119,7 +250,7 @@ def render(req : task_manager.ImageRequest):
new_task = task_manager.render(req)
response = {
'status': str(task_manager.current_state),
'queue': task_manager.tasks_queue.qsize(),
'queue': len(task_manager.tasks_queue),
'stream': f'/image/stream/{req.session_id}/{id(new_task)}',
'task': id(new_task)
}
@ -172,100 +303,13 @@ def get_image(session_id, img_id):
except KeyError as e:
return HTTPException(status_code=500, detail=str(e))
# POST /app_config: builds a config holding only 'update_branch' and writes it
# to config.json, config.bat and config.sh inside CONFIG_DIR.
# NOTE(review): another setAppConfig bound to the same route appears earlier in
# this file, and a read_root definition is interleaved with the lines below;
# this looks like the previous version of the handler -- TODO confirm.
@app.post('/app_config')
async def setAppConfig(req : SetAppConfigRequest):
try:
config = {
'update_branch': req.update_branch
}
# GET /: serves the UI's index.html with no-cache headers.
@app.get('/')
def read_root():
return FileResponse(os.path.join(SD_UI_DIR, 'index.html'), headers=NOCACHE_HEADERS)
# The same branch name is rendered once per target format.
config_json_str = json.dumps(config)
config_bat_str = f'@set update_branch={req.update_branch}'
config_sh_str = f'export update_branch={req.update_branch}'
config_json_path = os.path.join(CONFIG_DIR, 'config.json')
config_bat_path = os.path.join(CONFIG_DIR, 'config.bat')
config_sh_path = os.path.join(CONFIG_DIR, 'config.sh')
with open(config_json_path, 'w') as f:
f.write(config_json_str)
with open(config_bat_path, 'w') as f:
f.write(config_bat_str)
with open(config_sh_path, 'w') as f:
f.write(config_sh_str)
# NOTE(review): {'OK'} is a set literal, not a dict -- confirm the intended response body.
return {'OK'}
except Exception as e:
print(traceback.format_exc())
# NOTE(review): the HTTPException is returned, not raised -- confirm the resulting status code.
return HTTPException(status_code=500, detail=str(e))
# Marks "no default supplied" so that None stays usable as an explicit default.
_GETCONFIG_UNSET = object()


def getConfig(default_val=_GETCONFIG_UNSET):
    """Return the parsed content of config.json from CONFIG_DIR.

    Args:
        default_val: Value returned when the file is missing or cannot be
            read/parsed. When omitted, a new empty dict is returned on every
            call, so callers may modify the result without affecting later
            calls.

    Returns:
        The parsed JSON content, or the default on a missing file or error.
    """
    if default_val is _GETCONFIG_UNSET:
        default_val = {}
    try:
        config_json_path = os.path.join(CONFIG_DIR, 'config.json')
        if not os.path.exists(config_json_path):
            return default_val
        with open(config_json_path, 'r') as f:
            return json.load(f)
    except Exception as e:
        # Best effort: report the problem and fall back to the default.
        print(str(e))
        print(traceback.format_exc())
        return default_val
def setConfig(config):
    """Write ``config`` as JSON to config.json inside CONFIG_DIR (best effort).

    Failures are printed (message and traceback) and not propagated.

    Args:
        config: JSON-serializable configuration.

    Returns:
        None.
    """
    try:
        # Serialize first so a non-serializable config cannot truncate the file.
        config_json = json.dumps(config)
        config_json_path = os.path.join(CONFIG_DIR, 'config.json')
        with open(config_json_path, 'w') as f:
            f.write(config_json)
    except Exception as e:
        # Bind the exception: the previous bare 'except' printed an unbound 'e'.
        print(str(e))
        print(traceback.format_exc())
def getModels():
    """List the selectable stable-diffusion models and the active one.

    Returns:
        dict of the form
        ``{'active': {'stable-diffusion': name}, 'options': {'stable-diffusion': [names]}}``.
        'sd-v1-4' is always listed; every '*.ckpt' file in the
        'stable-diffusion' folder of MODELS_DIR and a legacy
        'custom-model.ckpt' in SD_DIR are added. The active model is
        'sd-v1-4', overridden by the legacy custom model when present, then by
        config['model']['stable-diffusion'] when configured.
    """
    models = {
        'active': {
            'stable-diffusion': 'sd-v1-4',
        },
        'options': {
            'stable-diffusion': ['sd-v1-4'],
        },
    }
    options = models['options']['stable-diffusion']

    def add_option(name):
        # A model can be present in several locations; list it only once.
        if name not in options:
            options.append(name)

    # custom models
    sd_models_dir = os.path.join(MODELS_DIR, 'stable-diffusion')
    if os.path.isdir(sd_models_dir):  # the folder may not exist yet
        for filename in os.listdir(sd_models_dir):
            if filename.endswith('.ckpt'):
                add_option(os.path.splitext(filename)[0])

    # legacy
    custom_weight_path = os.path.join(SD_DIR, 'custom-model.ckpt')
    if os.path.exists(custom_weight_path):
        models['active']['stable-diffusion'] = 'custom-model'
        add_option('custom-model')

    config = getConfig()
    if 'model' in config and 'stable-diffusion' in config['model']:
        models['active']['stable-diffusion'] = config['model']['stable-diffusion']

    return models
@app.get('/get/{key:path}')
def read_web_data(key:str=None):
    """Serve the UI data selected by ``key``.

    Args:
        key: one of 'app_config', 'models', 'modifiers' or 'output_dir'.

    Returns:
        JSONResponse with the app config, the model list or the output
        directory, or a FileResponse with modifiers.json; all with no-cache
        headers.

    Raises:
        HTTPException: 418 when no key is given, 500 when the config file is
            missing or unreadable, 404 for an unknown key.
    """
    # Errors are raised (not returned) so the client receives the real status code.
    if not key:  # /get without parameters, stable-diffusion easter egg.
        raise HTTPException(status_code=418, detail="StableDiffusion is drawing a teapot!")  # HTTP418 I'm a teapot
    if key == 'app_config':
        # default_val=None lets a missing/unreadable config be told apart from a real one.
        config = getConfig(default_val=None)
        if config is None:
            raise HTTPException(status_code=500, detail="Config file is missing or unreadable")
        return JSONResponse(config, headers=NOCACHE_HEADERS)
    if key == 'models':
        return JSONResponse(getModels(), headers=NOCACHE_HEADERS)
    if key == 'modifiers':
        return FileResponse(os.path.join(SD_UI_DIR, 'modifiers.json'), headers=NOCACHE_HEADERS)
    if key == 'output_dir':
        return JSONResponse({ 'output_dir': outpath }, headers=NOCACHE_HEADERS)
    raise HTTPException(status_code=404, detail=f'Request for unknown {key}')  # HTTP404 Not Found
@app.on_event("shutdown")
def shutdown_event():
    """Signal render thread to close on shutdown."""
    shutdown_signal = SystemExit('Application shutting down.')
    task_manager.current_state_error = shutdown_signal
# don't log certain requests
class LogSuppressFilter(logging.Filter):
@ -277,8 +321,26 @@ class LogSuppressFilter(logging.Filter):
return True
# Attach the suppress filter to uvicorn's access logger.
logging.getLogger('uvicorn.access').addFilter(LogSuppressFilter())
# NOTE(review): get_initial_model_to_load has only a bare `def` line in this
# file, and start_render_thread is called again per device further down; the
# next two lines look like leftovers of the single-thread startup -- TODO confirm.
task_manager.default_model_to_load = get_initial_model_to_load()
task_manager.start_render_thread()
config = getConfig()
# Start the task_manager
task_manager.default_model_to_load = resolve_model_to_use()
if 'render_devices' in config: # Start a new thread for each device.
# A comma-separated string is accepted as well as a list.
if isinstance(config['render_devices'], str):
config['render_devices'] = config['render_devices'].split(',')
if not isinstance(config['render_devices'], list):
raise Exception('Invalid render_devices value in config.')
for device in config['render_devices']:
task_manager.start_render_thread(device)
# The CPU thread is only added when the defaults below had to be applied.
allow_cpu = False
if task_manager.is_alive() <= 0: # No running devices, apply defaults.
# Select best device GPU device using free memory if more than one device.
task_manager.start_render_thread('auto')
allow_cpu = True
# Allow CPU to be used for renders if not already enabled in current config.
if task_manager.is_alive('cpu') <= 0 and allow_cpu:
task_manager.start_render_thread('cpu')
# start the browser ui
import webbrowser; webbrowser.open('http://localhost:9000')