mirror of
https://github.com/easydiffusion/easydiffusion.git
synced 2025-08-09 07:45:01 +02:00
First draft for Multi-GPU support
This commit is contained in:
324
ui/server.py
324
ui/server.py
@ -15,14 +15,24 @@ MODELS_DIR = os.path.abspath(os.path.join(SD_DIR, '..', 'models'))
|
||||
|
||||
OUTPUT_DIRNAME = "Stable Diffusion UI" # in the user's home folder
|
||||
TASK_TTL = 15 * 60 # Discard last session's task timeout
|
||||
APP_CONFIG_DEFAULTS = {
|
||||
# auto: selects the cuda device with the most free memory, cuda: use the currently active cuda device.
|
||||
'render_devices': ['auto'], # ['cuda'] or ['CPU', 'GPU:0', 'GPU:1', ...] or ['cpu']
|
||||
'update_branch': 'main',
|
||||
}
|
||||
APP_CONFIG_DEFAULT_MODELS = [
|
||||
# needed to support the legacy installations
|
||||
'custom-model', # Check if user has a custom model, use it first.
|
||||
'sd-v1-4', # Default fallback.
|
||||
]
|
||||
|
||||
from fastapi import FastAPI, HTTPException
|
||||
from fastapi.staticfiles import StaticFiles
|
||||
from starlette.responses import FileResponse, JSONResponse, StreamingResponse
|
||||
from pydantic import BaseModel
|
||||
import logging
|
||||
import queue, threading, time
|
||||
from typing import Any, Generator, Hashable, Optional, Union
|
||||
#import queue, threading, time
|
||||
from typing import Any, Generator, Hashable, List, Optional, Union
|
||||
|
||||
from sd_internal import Request, Response, task_manager
|
||||
|
||||
@ -37,52 +47,173 @@ ACCESS_LOG_SUPPRESS_PATH_PREFIXES = ['/ping', '/image', '/modifier-thumbnails']
|
||||
NOCACHE_HEADERS={"Cache-Control": "no-cache, no-store, must-revalidate", "Pragma": "no-cache", "Expires": "0"}
|
||||
app.mount('/media', StaticFiles(directory=os.path.join(SD_UI_DIR, 'media/')), name="media")
|
||||
|
||||
config_cached = None
|
||||
config_last_mod_time = 0
|
||||
def getConfig(default_val=APP_CONFIG_DEFAULTS):
|
||||
global config_cached, config_last_mod_time
|
||||
try:
|
||||
config_json_path = os.path.join(CONFIG_DIR, 'config.json')
|
||||
if not os.path.exists(config_json_path):
|
||||
return default_val
|
||||
if config_last_mod_time > 0 and config_cached is not None:
|
||||
# Don't read if file was not modified
|
||||
mtime = os.path.getmtime(config_json_path)
|
||||
if mtime <= config_last_mod_time:
|
||||
return config_cached
|
||||
with open(config_json_path, 'r') as f:
|
||||
config_cached = json.load(f)
|
||||
config_last_mod_time = os.path.getmtime(config_json_path)
|
||||
return config_cached
|
||||
except Exception as e:
|
||||
print(str(e))
|
||||
print(traceback.format_exc())
|
||||
return default_val
|
||||
|
||||
def setConfig(config):
|
||||
try: # config.json
|
||||
config_json_path = os.path.join(CONFIG_DIR, 'config.json')
|
||||
with open(config_json_path, 'w') as f:
|
||||
return json.dump(config, f)
|
||||
except:
|
||||
print(traceback.format_exc())
|
||||
|
||||
if 'render_devices' in config:
|
||||
gpu_devices = filter(lambda dev: dev.startswith('GPU:'), config['render_devices'])
|
||||
else:
|
||||
gpu_devices = []
|
||||
|
||||
try: # config.bat
|
||||
config_bat = [
|
||||
f"@set update_branch={config['update_branch']}"
|
||||
]
|
||||
if len(gpu_devices) > 0:
|
||||
config_sh.append(f"@set CUDA_VISIBLE_DEVICES={','.join(gpu_devices)}")
|
||||
config_bat_path = os.path.join(CONFIG_DIR, 'config.bat')
|
||||
with open(config_bat_path, 'w') as f:
|
||||
f.write(f.write('\r\n'.join(config_bat)))
|
||||
except Exception as e:
|
||||
print(traceback.format_exc())
|
||||
|
||||
try: # config.sh
|
||||
config_sh = [
|
||||
'#!/bin/bash'
|
||||
f"export update_branch={config['update_branch']}"
|
||||
]
|
||||
if len(gpu_devices) > 0:
|
||||
config_sh.append(f"CUDA_VISIBLE_DEVICES={','.join(gpu_devices)}")
|
||||
config_sh_path = os.path.join(CONFIG_DIR, 'config.sh')
|
||||
with open(config_sh_path, 'w') as f:
|
||||
f.write('\n'.join(config_sh))
|
||||
except Exception as e:
|
||||
print(traceback.format_exc())
|
||||
|
||||
def resolve_model_to_use(model_name:str=None):
|
||||
if not model_name: # When None try user configured model.
|
||||
config = getConfig()
|
||||
if 'model' in config and 'stable-diffusion' in config['model']:
|
||||
model_name = config['model']['stable-diffusion']
|
||||
if model_name:
|
||||
if os.path.exists(model_name + '.ckpt'):
|
||||
# Direct Path to file
|
||||
return model_name
|
||||
# Check models directory
|
||||
models_dir_path = os.path.join(MODELS_DIR, 'stable-diffusion', model_name)
|
||||
if os.path.exists(models_dir_path + '.ckpt'):
|
||||
return models_dir_path
|
||||
# Default locations
|
||||
if model_name in APP_CONFIG_DEFAULT_MODELS:
|
||||
default_model_path = os.path.join(SD_DIR, model_name)
|
||||
if os.path.exists(default_model_path + '.ckpt'):
|
||||
return default_model_path
|
||||
# Can't find requested model, check the default paths.
|
||||
for default_model in APP_CONFIG_DEFAULT_MODELS:
|
||||
default_model_path = os.path.join(SD_DIR, default_model + '.ckpt')
|
||||
if os.path.exists(default_model_path):
|
||||
print('Could not find the configured custom model at:', model_path + '.ckpt', '. Using the default one:', default_model_path + '.ckpt')
|
||||
return default_model_path
|
||||
raise Exception('No valid models found.')
|
||||
|
||||
class SetAppConfigRequest(BaseModel):
|
||||
update_branch: str = "main"
|
||||
update_branch: str = None
|
||||
render_devices: Union[List[str], List[int], str, int] = None
|
||||
|
||||
# needs to support the legacy installations
|
||||
def get_initial_model_to_load():
|
||||
@app.post('/app_config')
|
||||
async def setAppConfig(req : SetAppConfigRequest):
|
||||
config = getConfig()
|
||||
if req.update_branch:
|
||||
config['update_branch'] = req.update_branch
|
||||
if req.render_devices and hasattr(req.render_devices, "__len__"): # strings, array of strings or numbers.
|
||||
render_devices = []
|
||||
if isinstance(req.render_devices, str):
|
||||
req.render_devices = req.render_devices.split(',')
|
||||
if isinstance(req.render_devices, list):
|
||||
for gpu in req.render_devices:
|
||||
if isinstance(req.render_devices, int):
|
||||
render_devices.append('GPU:' + gpu)
|
||||
else:
|
||||
render_devices.append(gpu)
|
||||
if isinstance(req.render_devices, int):
|
||||
render_devices.append('GPU:' + req.render_devices)
|
||||
if len(render_devices) > 0:
|
||||
config['render_devices'] = render_devices
|
||||
try:
|
||||
setConfig(config)
|
||||
return JSONResponse({'status': 'OK'}, headers=NOCACHE_HEADERS)
|
||||
except Exception as e:
|
||||
print(traceback.format_exc())
|
||||
return HTTPException(status_code=500, detail=str(e))
|
||||
|
||||
def getModels():
|
||||
models = {
|
||||
'active': {
|
||||
'stable-diffusion': 'sd-v1-4',
|
||||
},
|
||||
'options': {
|
||||
'stable-diffusion': ['sd-v1-4'],
|
||||
},
|
||||
}
|
||||
|
||||
# custom models
|
||||
sd_models_dir = os.path.join(MODELS_DIR, 'stable-diffusion')
|
||||
for file in os.listdir(sd_models_dir):
|
||||
if file.endswith('.ckpt'):
|
||||
model_name = os.path.splitext(file)[0]
|
||||
models['options']['stable-diffusion'].append(model_name)
|
||||
|
||||
# legacy
|
||||
custom_weight_path = os.path.join(SD_DIR, 'custom-model.ckpt')
|
||||
ckpt_to_use = "sd-v1-4" if not os.path.exists(custom_weight_path) else "custom-model"
|
||||
|
||||
ckpt_to_use = os.path.join(SD_DIR, ckpt_to_use)
|
||||
if os.path.exists(custom_weight_path):
|
||||
models['active']['stable-diffusion'] = 'custom-model'
|
||||
models['options']['stable-diffusion'].append('custom-model')
|
||||
|
||||
config = getConfig()
|
||||
if 'model' in config and 'stable-diffusion' in config['model']:
|
||||
model_name = config['model']['stable-diffusion']
|
||||
model_path = resolve_model_to_use(model_name)
|
||||
models['active']['stable-diffusion'] = config['model']['stable-diffusion']
|
||||
|
||||
if os.path.exists(model_path + '.ckpt'):
|
||||
ckpt_to_use = model_path
|
||||
else:
|
||||
print('Could not find the configured custom model at:', model_path + '.ckpt', '. Using the default one:', ckpt_to_use + '.ckpt')
|
||||
return ckpt_to_use
|
||||
return models
|
||||
|
||||
def resolve_model_to_use(model_name):
|
||||
if model_name in ('sd-v1-4', 'custom-model'):
|
||||
model_path = os.path.join(MODELS_DIR, 'stable-diffusion', model_name)
|
||||
|
||||
legacy_model_path = os.path.join(SD_DIR, model_name)
|
||||
if not os.path.exists(model_path + '.ckpt') and os.path.exists(legacy_model_path + '.ckpt'):
|
||||
model_path = legacy_model_path
|
||||
@app.get('/get/{key:path}')
|
||||
def read_web_data(key:str=None):
|
||||
if not key: # /get without parameters, stable-diffusion easter egg.
|
||||
return HTTPException(status_code=418, detail="StableDiffusion is drawing a teapot!") # HTTP418 I'm a teapot
|
||||
elif key == 'app_config':
|
||||
config = getConfig(default_val=None)
|
||||
if config is None:
|
||||
return HTTPException(status_code=500, detail="Config file is missing or unreadable")
|
||||
return JSONResponse(config, headers=NOCACHE_HEADERS)
|
||||
elif key == 'models':
|
||||
return JSONResponse(getModels(), headers=NOCACHE_HEADERS)
|
||||
elif key == 'modifiers': return FileResponse(os.path.join(SD_UI_DIR, 'modifiers.json'), headers=NOCACHE_HEADERS)
|
||||
elif key == 'output_dir': return JSONResponse({ 'output_dir': outpath }, headers=NOCACHE_HEADERS)
|
||||
else:
|
||||
model_path = os.path.join(MODELS_DIR, 'stable-diffusion', model_name)
|
||||
return model_path
|
||||
|
||||
@app.on_event("shutdown")
|
||||
def shutdown_event(): # Signal render thread to close on shutdown
|
||||
task_manager.current_state_error = SystemExit('Application shutting down.')
|
||||
|
||||
@app.get('/')
|
||||
def read_root():
|
||||
return FileResponse(os.path.join(SD_UI_DIR, 'index.html'), headers=NOCACHE_HEADERS)
|
||||
return HTTPException(status_code=404, detail=f'Request for unknown {key}') # HTTP404 Not Found
|
||||
|
||||
@app.get('/ping') # Get server and optionally session status.
|
||||
def ping(session_id:str=None):
|
||||
if not task_manager.render_thread.is_alive(): # Render thread is dead.
|
||||
if task_manager.current_state_error: return HTTPException(status_code=500, detail=str(current_state_error))
|
||||
if task_manager.is_alive() <= 0: # Check that render threads are alive.
|
||||
if task_manager.current_state_error: return HTTPException(status_code=500, detail=str(task_manager.current_state_error))
|
||||
return HTTPException(status_code=500, detail='Render thread is dead.')
|
||||
if task_manager.current_state_error and not isinstance(task_manager.current_state_error, StopAsyncIteration): return HTTPException(status_code=500, detail=str(current_state_error))
|
||||
if task_manager.current_state_error and not isinstance(task_manager.current_state_error, StopAsyncIteration): return HTTPException(status_code=500, detail=str(task_manager.current_state_error))
|
||||
# Alive
|
||||
response = {'status': str(task_manager.current_state)}
|
||||
if session_id:
|
||||
@ -119,7 +250,7 @@ def render(req : task_manager.ImageRequest):
|
||||
new_task = task_manager.render(req)
|
||||
response = {
|
||||
'status': str(task_manager.current_state),
|
||||
'queue': task_manager.tasks_queue.qsize(),
|
||||
'queue': len(task_manager.tasks_queue),
|
||||
'stream': f'/image/stream/{req.session_id}/{id(new_task)}',
|
||||
'task': id(new_task)
|
||||
}
|
||||
@ -172,100 +303,13 @@ def get_image(session_id, img_id):
|
||||
except KeyError as e:
|
||||
return HTTPException(status_code=500, detail=str(e))
|
||||
|
||||
@app.post('/app_config')
|
||||
async def setAppConfig(req : SetAppConfigRequest):
|
||||
try:
|
||||
config = {
|
||||
'update_branch': req.update_branch
|
||||
}
|
||||
@app.get('/')
|
||||
def read_root():
|
||||
return FileResponse(os.path.join(SD_UI_DIR, 'index.html'), headers=NOCACHE_HEADERS)
|
||||
|
||||
config_json_str = json.dumps(config)
|
||||
config_bat_str = f'@set update_branch={req.update_branch}'
|
||||
config_sh_str = f'export update_branch={req.update_branch}'
|
||||
|
||||
config_json_path = os.path.join(CONFIG_DIR, 'config.json')
|
||||
config_bat_path = os.path.join(CONFIG_DIR, 'config.bat')
|
||||
config_sh_path = os.path.join(CONFIG_DIR, 'config.sh')
|
||||
|
||||
with open(config_json_path, 'w') as f:
|
||||
f.write(config_json_str)
|
||||
|
||||
with open(config_bat_path, 'w') as f:
|
||||
f.write(config_bat_str)
|
||||
|
||||
with open(config_sh_path, 'w') as f:
|
||||
f.write(config_sh_str)
|
||||
|
||||
return {'OK'}
|
||||
except Exception as e:
|
||||
print(traceback.format_exc())
|
||||
return HTTPException(status_code=500, detail=str(e))
|
||||
|
||||
def getConfig(default_val={}):
|
||||
try:
|
||||
config_json_path = os.path.join(CONFIG_DIR, 'config.json')
|
||||
if not os.path.exists(config_json_path):
|
||||
return default_val
|
||||
with open(config_json_path, 'r') as f:
|
||||
return json.load(f)
|
||||
except Exception as e:
|
||||
print(str(e))
|
||||
print(traceback.format_exc())
|
||||
return default_val
|
||||
|
||||
def setConfig(config):
|
||||
try:
|
||||
config_json_path = os.path.join(CONFIG_DIR, 'config.json')
|
||||
with open(config_json_path, 'w') as f:
|
||||
return json.dump(config, f)
|
||||
except:
|
||||
print(str(e))
|
||||
print(traceback.format_exc())
|
||||
|
||||
def getModels():
|
||||
models = {
|
||||
'active': {
|
||||
'stable-diffusion': 'sd-v1-4',
|
||||
},
|
||||
'options': {
|
||||
'stable-diffusion': ['sd-v1-4'],
|
||||
},
|
||||
}
|
||||
|
||||
# custom models
|
||||
sd_models_dir = os.path.join(MODELS_DIR, 'stable-diffusion')
|
||||
for file in os.listdir(sd_models_dir):
|
||||
if file.endswith('.ckpt'):
|
||||
model_name = os.path.splitext(file)[0]
|
||||
models['options']['stable-diffusion'].append(model_name)
|
||||
|
||||
# legacy
|
||||
custom_weight_path = os.path.join(SD_DIR, 'custom-model.ckpt')
|
||||
if os.path.exists(custom_weight_path):
|
||||
models['active']['stable-diffusion'] = 'custom-model'
|
||||
models['options']['stable-diffusion'].append('custom-model')
|
||||
|
||||
config = getConfig()
|
||||
if 'model' in config and 'stable-diffusion' in config['model']:
|
||||
models['active']['stable-diffusion'] = config['model']['stable-diffusion']
|
||||
|
||||
return models
|
||||
|
||||
@app.get('/get/{key:path}')
|
||||
def read_web_data(key:str=None):
|
||||
if not key: # /get without parameters, stable-diffusion easter egg.
|
||||
return HTTPException(status_code=418, detail="StableDiffusion is drawing a teapot!") # HTTP418 I'm a teapot
|
||||
elif key == 'app_config':
|
||||
config = getConfig(default_val=None)
|
||||
if config is None:
|
||||
return HTTPException(status_code=500, detail="Config file is missing or unreadable")
|
||||
return JSONResponse(config, headers=NOCACHE_HEADERS)
|
||||
elif key == 'models':
|
||||
return JSONResponse(getModels(), headers=NOCACHE_HEADERS)
|
||||
elif key == 'modifiers': return FileResponse(os.path.join(SD_UI_DIR, 'modifiers.json'), headers=NOCACHE_HEADERS)
|
||||
elif key == 'output_dir': return JSONResponse({ 'output_dir': outpath }, headers=NOCACHE_HEADERS)
|
||||
else:
|
||||
return HTTPException(status_code=404, detail=f'Request for unknown {key}') # HTTP404 Not Found
|
||||
@app.on_event("shutdown")
|
||||
def shutdown_event(): # Signal render thread to close on shutdown
|
||||
task_manager.current_state_error = SystemExit('Application shutting down.')
|
||||
|
||||
# don't log certain requests
|
||||
class LogSuppressFilter(logging.Filter):
|
||||
@ -277,8 +321,26 @@ class LogSuppressFilter(logging.Filter):
|
||||
return True
|
||||
logging.getLogger('uvicorn.access').addFilter(LogSuppressFilter())
|
||||
|
||||
task_manager.default_model_to_load = get_initial_model_to_load()
|
||||
task_manager.start_render_thread()
|
||||
config = getConfig()
|
||||
# Start the task_manager
|
||||
task_manager.default_model_to_load = resolve_model_to_use()
|
||||
if 'render_devices' in config: # Start a new thread for each device.
|
||||
if isinstance(config['render_devices'], str):
|
||||
config['render_devices'] = config['render_devices'].split(',')
|
||||
if not isinstance(config['render_devices'], list):
|
||||
raise Exception('Invalid render_devices value in config.')
|
||||
for device in config['render_devices']:
|
||||
task_manager.start_render_thread(device)
|
||||
|
||||
allow_cpu = False
|
||||
if task_manager.is_alive() <= 0: # No running devices, apply defaults.
|
||||
# Select best device GPU device using free memory if more than one device.
|
||||
task_manager.start_render_thread('auto')
|
||||
allow_cpu = True
|
||||
|
||||
# Allow CPU to be used for renders if not already enabled in current config.
|
||||
if task_manager.is_alive('cpu') <= 0 and allow_cpu:
|
||||
task_manager.start_render_thread('cpu')
|
||||
|
||||
# start the browser ui
|
||||
import webbrowser; webbrowser.open('http://localhost:9000')
|
Reference in New Issue
Block a user