Work-in-progress refactor of the backend, to move most of the logic to diffusion-kit and keeping this as a UI around that engine. Does not work yet.

This commit is contained in:
cmdr2 2022-12-07 22:15:35 +05:30
parent bfdf487d52
commit fb6a7e04f5
7 changed files with 484 additions and 417 deletions

View File

@ -40,7 +40,6 @@ class Request:
"num_outputs": self.num_outputs, "num_outputs": self.num_outputs,
"num_inference_steps": self.num_inference_steps, "num_inference_steps": self.num_inference_steps,
"guidance_scale": self.guidance_scale, "guidance_scale": self.guidance_scale,
"hypernetwork_strengtgh": self.guidance_scale,
"width": self.width, "width": self.width,
"height": self.height, "height": self.height,
"seed": self.seed, "seed": self.seed,

156
ui/sd_internal/app.py Normal file
View File

@ -0,0 +1,156 @@
import os
import socket
import sys
import json
import traceback
from sd_internal import task_manager
SD_DIR = os.getcwd()
SD_UI_DIR = os.getenv('SD_UI_PATH', None)
sys.path.append(os.path.dirname(SD_UI_DIR))
CONFIG_DIR = os.path.abspath(os.path.join(SD_UI_DIR, '..', 'scripts'))
MODELS_DIR = os.path.abspath(os.path.join(SD_DIR, '..', 'models'))
USER_UI_PLUGINS_DIR = os.path.abspath(os.path.join(SD_DIR, '..', 'plugins', 'ui'))
CORE_UI_PLUGINS_DIR = os.path.abspath(os.path.join(SD_UI_DIR, 'plugins', 'ui'))
UI_PLUGINS_SOURCES = ((CORE_UI_PLUGINS_DIR, 'core'), (USER_UI_PLUGINS_DIR, 'user'))
STABLE_DIFFUSION_MODEL_EXTENSIONS = ['.ckpt', '.safetensors']
VAE_MODEL_EXTENSIONS = ['.vae.pt', '.ckpt']
HYPERNETWORK_MODEL_EXTENSIONS = ['.pt']
OUTPUT_DIRNAME = "Stable Diffusion UI" # in the user's home folder
TASK_TTL = 15 * 60 # Discard last session's task timeout
APP_CONFIG_DEFAULTS = {
# auto: selects the cuda device with the most free memory, cuda: use the currently active cuda device.
'render_devices': 'auto', # valid entries: 'auto', 'cpu' or 'cuda:N' (where N is a GPU index)
'update_branch': 'main',
'ui': {
'open_browser_on_start': True,
},
}
DEFAULT_MODELS = [
# needed to support the legacy installations
'custom-model', # Check if user has a custom model, use it first.
'sd-v1-4', # Default fallback.
]
def init():
os.makedirs(USER_UI_PLUGINS_DIR, exist_ok=True)
update_render_threads()
def getConfig(default_val=APP_CONFIG_DEFAULTS):
try:
config_json_path = os.path.join(CONFIG_DIR, 'config.json')
if not os.path.exists(config_json_path):
return default_val
with open(config_json_path, 'r', encoding='utf-8') as f:
config = json.load(f)
if 'net' not in config:
config['net'] = {}
if os.getenv('SD_UI_BIND_PORT') is not None:
config['net']['listen_port'] = int(os.getenv('SD_UI_BIND_PORT'))
if os.getenv('SD_UI_BIND_IP') is not None:
config['net']['listen_to_network'] = (os.getenv('SD_UI_BIND_IP') == '0.0.0.0')
return config
except Exception as e:
print(str(e))
print(traceback.format_exc())
return default_val
def setConfig(config):
try: # config.json
config_json_path = os.path.join(CONFIG_DIR, 'config.json')
with open(config_json_path, 'w', encoding='utf-8') as f:
json.dump(config, f)
except:
print(traceback.format_exc())
try: # config.bat
config_bat_path = os.path.join(CONFIG_DIR, 'config.bat')
config_bat = []
if 'update_branch' in config:
config_bat.append(f"@set update_branch={config['update_branch']}")
config_bat.append(f"@set SD_UI_BIND_PORT={config['net']['listen_port']}")
bind_ip = '0.0.0.0' if config['net']['listen_to_network'] else '127.0.0.1'
config_bat.append(f"@set SD_UI_BIND_IP={bind_ip}")
config_bat.append(f"@set test_sd2={'Y' if config.get('test_sd2', False) else 'N'}")
if len(config_bat) > 0:
with open(config_bat_path, 'w', encoding='utf-8') as f:
f.write('\r\n'.join(config_bat))
except:
print(traceback.format_exc())
try: # config.sh
config_sh_path = os.path.join(CONFIG_DIR, 'config.sh')
config_sh = ['#!/bin/bash']
if 'update_branch' in config:
config_sh.append(f"export update_branch={config['update_branch']}")
config_sh.append(f"export SD_UI_BIND_PORT={config['net']['listen_port']}")
bind_ip = '0.0.0.0' if config['net']['listen_to_network'] else '127.0.0.1'
config_sh.append(f"export SD_UI_BIND_IP={bind_ip}")
config_sh.append(f"export test_sd2=\"{'Y' if config.get('test_sd2', False) else 'N'}\"")
if len(config_sh) > 1:
with open(config_sh_path, 'w', encoding='utf-8') as f:
f.write('\n'.join(config_sh))
except:
print(traceback.format_exc())
def save_model_to_config(ckpt_model_name, vae_model_name, hypernetwork_model_name):
config = getConfig()
if 'model' not in config:
config['model'] = {}
config['model']['stable-diffusion'] = ckpt_model_name
config['model']['vae'] = vae_model_name
config['model']['hypernetwork'] = hypernetwork_model_name
if vae_model_name is None or vae_model_name == "":
del config['model']['vae']
if hypernetwork_model_name is None or hypernetwork_model_name == "":
del config['model']['hypernetwork']
setConfig(config)
def update_render_threads():
config = getConfig()
render_devices = config.get('render_devices', 'auto')
active_devices = task_manager.get_devices()['active'].keys()
print('requesting for render_devices', render_devices)
task_manager.update_render_threads(render_devices, active_devices)
def getUIPlugins():
plugins = []
for plugins_dir, dir_prefix in UI_PLUGINS_SOURCES:
for file in os.listdir(plugins_dir):
if file.endswith('.plugin.js'):
plugins.append(f'/plugins/{dir_prefix}/{file}')
return plugins
def getIPConfig():
ips = socket.gethostbyname_ex(socket.gethostname())
ips[2].append(ips[0])
return ips[2]
def open_browser():
config = getConfig()
ui = config.get('ui', {})
net = config.get('net', {'listen_port':9000})
port = net.get('listen_port', 9000)
if ui.get('open_browser_on_start', True):
import webbrowser; webbrowser.open(f"http://localhost:{port}")

View File

@ -82,7 +82,7 @@ def auto_pick_devices(currently_active_devices):
devices = list(map(lambda x: x['device'], devices)) devices = list(map(lambda x: x['device'], devices))
return devices return devices
def device_init(thread_data, device): def device_init(context, device):
''' '''
This function assumes the 'device' has already been verified to be compatible. This function assumes the 'device' has already been verified to be compatible.
`get_device_delta()` has already filtered out incompatible devices. `get_device_delta()` has already filtered out incompatible devices.
@ -91,21 +91,22 @@ def device_init(thread_data, device):
validate_device_id(device, log_prefix='device_init') validate_device_id(device, log_prefix='device_init')
if device == 'cpu': if device == 'cpu':
thread_data.device = 'cpu' context.device = 'cpu'
thread_data.device_name = get_processor_name() context.device_name = get_processor_name()
print('Render device CPU available as', thread_data.device_name) context.precision = 'full'
print('Render device CPU available as', context.device_name)
return return
thread_data.device_name = torch.cuda.get_device_name(device) context.device_name = torch.cuda.get_device_name(device)
thread_data.device = device context.device = device
# Force full precision on 1660 and 1650 NVIDIA cards to avoid creating green images # Force full precision on 1660 and 1650 NVIDIA cards to avoid creating green images
device_name = thread_data.device_name.lower() device_name = context.device_name.lower()
thread_data.force_full_precision = (('nvidia' in device_name or 'geforce' in device_name) and (' 1660' in device_name or ' 1650' in device_name)) or ('Quadro T2000' in device_name) force_full_precision = (('nvidia' in device_name or 'geforce' in device_name) and (' 1660' in device_name or ' 1650' in device_name)) or ('Quadro T2000' in device_name)
if thread_data.force_full_precision: if force_full_precision:
print('forcing full precision on NVIDIA 16xx cards, to avoid green images. GPU detected: ', thread_data.device_name) print('forcing full precision on NVIDIA 16xx cards, to avoid green images. GPU detected: ', context.device_name)
# Apply force_full_precision now before models are loaded. # Apply force_full_precision now before models are loaded.
thread_data.precision = 'full' context.precision = 'full'
print(f'Setting {device} as active') print(f'Setting {device} as active')
torch.cuda.device(device) torch.cuda.device(device)

View File

@ -0,0 +1,141 @@
import os
from sd_internal import app
import picklescan.scanner
import rich
default_model_to_load = None
default_vae_to_load = None
default_hypernetwork_to_load = None
known_models = {}
def init():
global default_model_to_load, default_vae_to_load, default_hypernetwork_to_load
default_model_to_load = resolve_ckpt_to_use()
default_vae_to_load = resolve_vae_to_use()
default_hypernetwork_to_load = resolve_hypernetwork_to_use()
getModels() # run this once, to cache the picklescan results
def resolve_model_to_use(model_name:str, model_type:str, model_dir:str, model_extensions:list, default_models=[]):
config = app.getConfig()
model_dirs = [os.path.join(app.MODELS_DIR, model_dir), app.SD_DIR]
if not model_name: # When None try user configured model.
# config = getConfig()
if 'model' in config and model_type in config['model']:
model_name = config['model'][model_type]
if model_name:
is_sd2 = config.get('test_sd2', False)
if model_name.startswith('sd2_') and not is_sd2: # temp hack, until SD2 is unified with 1.4
print('ERROR: Cannot use SD 2.0 models with SD 1.0 code. Using the sd-v1-4 model instead!')
model_name = 'sd-v1-4'
# Check models directory
models_dir_path = os.path.join(app.MODELS_DIR, model_dir, model_name)
for model_extension in model_extensions:
if os.path.exists(models_dir_path + model_extension):
return models_dir_path + model_extension
if os.path.exists(model_name + model_extension):
return os.path.abspath(model_name + model_extension)
# Default locations
if model_name in default_models:
default_model_path = os.path.join(app.SD_DIR, model_name)
for model_extension in model_extensions:
if os.path.exists(default_model_path + model_extension):
return default_model_path + model_extension
# Can't find requested model, check the default paths.
for default_model in default_models:
for model_dir in model_dirs:
default_model_path = os.path.join(model_dir, default_model)
for model_extension in model_extensions:
if os.path.exists(default_model_path + model_extension):
if model_name is not None:
print(f'Could not find the configured custom model {model_name}{model_extension}. Using the default one: {default_model_path}{model_extension}')
return default_model_path + model_extension
raise Exception('No valid models found.')
def resolve_ckpt_to_use(model_name:str=None):
return resolve_model_to_use(model_name, model_type='stable-diffusion', model_dir='stable-diffusion', model_extensions=app.STABLE_DIFFUSION_MODEL_EXTENSIONS, default_models=app.APP_CONFIG_DEFAULT_MODELS)
def resolve_vae_to_use(model_name:str=None):
try:
return resolve_model_to_use(model_name, model_type='vae', model_dir='vae', model_extensions=app.VAE_MODEL_EXTENSIONS, default_models=[])
except:
return None
def resolve_hypernetwork_to_use(model_name:str=None):
try:
return resolve_model_to_use(model_name, model_type='hypernetwork', model_dir='hypernetwork', model_extensions=app.HYPERNETWORK_MODEL_EXTENSIONS, default_models=[])
except:
return None
def is_malicious_model(file_path):
try:
scan_result = picklescan.scanner.scan_file_path(file_path)
if scan_result.issues_count > 0 or scan_result.infected_files > 0:
rich.print(":warning: [bold red]Scan %s: %d scanned, %d issue, %d infected.[/bold red]" % (file_path, scan_result.scanned_files, scan_result.issues_count, scan_result.infected_files))
return True
else:
rich.print("Scan %s: [green]%d scanned, %d issue, %d infected.[/green]" % (file_path, scan_result.scanned_files, scan_result.issues_count, scan_result.infected_files))
return False
except Exception as e:
print('error while scanning', file_path, 'error:', e)
return False
def getModels():
models = {
'active': {
'stable-diffusion': 'sd-v1-4',
'vae': '',
'hypernetwork': '',
},
'options': {
'stable-diffusion': ['sd-v1-4'],
'vae': [],
'hypernetwork': [],
},
}
def listModels(models_dirname, model_type, model_extensions):
models_dir = os.path.join(app.MODELS_DIR, models_dirname)
if not os.path.exists(models_dir):
os.makedirs(models_dir)
for file in os.listdir(models_dir):
for model_extension in model_extensions:
if not file.endswith(model_extension):
continue
model_path = os.path.join(models_dir, file)
mtime = os.path.getmtime(model_path)
mod_time = known_models[model_path] if model_path in known_models else -1
if mod_time != mtime:
if is_malicious_model(model_path):
models['scan-error'] = file
return
known_models[model_path] = mtime
model_name = file[:-len(model_extension)]
models['options'][model_type].append(model_name)
models['options'][model_type] = [*set(models['options'][model_type])] # remove duplicates
models['options'][model_type].sort()
# custom models
listModels(models_dirname='stable-diffusion', model_type='stable-diffusion', model_extensions=app.STABLE_DIFFUSION_MODEL_EXTENSIONS)
listModels(models_dirname='vae', model_type='vae', model_extensions=app.VAE_MODEL_EXTENSIONS)
listModels(models_dirname='hypernetwork', model_type='hypernetwork', model_extensions=app.HYPERNETWORK_MODEL_EXTENSIONS)
# legacy
custom_weight_path = os.path.join(app.SD_DIR, 'custom-model.ckpt')
if os.path.exists(custom_weight_path):
models['options']['stable-diffusion'].append('custom-model')
return models

View File

@ -0,0 +1,95 @@
import threading
import queue
from sd_internal import device_manager, Request, Response, Image as ResponseImage
from modules import model_loader, image_generator, image_utils
thread_data = threading.local()
'''
runtime data (bound locally to this thread), for e.g. device, references to loaded models, optimization flags etc
'''
def init(device):
'''
Initializes the fields that will be bound to this runtime's thread_data, and sets the current torch device
'''
thread_data.stop_processing = False
thread_data.temp_images = {}
thread_data.models = {}
thread_data.loaded_model_paths = {}
thread_data.device = None
thread_data.device_name = None
thread_data.precision = 'autocast'
thread_data.vram_optimizations = ('TURBO', 'MOVE_MODELS')
device_manager.device_init(thread_data, device)
reload_models()
def destroy():
model_loader.unload_sd_model(thread_data)
model_loader.unload_gfpgan_model(thread_data)
model_loader.unload_realesrgan_model(thread_data)
def reload_models(req: Request=None):
if is_hypernetwork_reload_necessary(task.request):
current_state = ServerStates.LoadingModel
runtime.reload_hypernetwork()
if is_model_reload_necessary(task.request):
current_state = ServerStates.LoadingModel
runtime.reload_model()
def load_models():
if ckpt_file_path == None:
ckpt_file_path = default_model_to_load
if vae_file_path == None:
vae_file_path = default_vae_to_load
if hypernetwork_file_path == None:
hypernetwork_file_path = default_hypernetwork_to_load
if ckpt_file_path == current_model_path and vae_file_path == current_vae_path:
return
current_state = ServerStates.LoadingModel
try:
from sd_internal import runtime2
runtime.thread_data.hypernetwork_file = hypernetwork_file_path
runtime.thread_data.ckpt_file = ckpt_file_path
runtime.thread_data.vae_file = vae_file_path
runtime.load_model_ckpt()
runtime.load_hypernetwork()
current_model_path = ckpt_file_path
current_vae_path = vae_file_path
current_hypernetwork_path = hypernetwork_file_path
current_state_error = None
current_state = ServerStates.Online
except Exception as e:
current_model_path = None
current_vae_path = None
current_state_error = e
current_state = ServerStates.Unavailable
print(traceback.format_exc())
def make_image(req: Request, data_queue: queue.Queue, task_temp_images: list, step_callback):
try:
images = image_generator.make_image(context=thread_data, args=get_mk_img_args(req))
except UserInitiatedStop:
pass
def get_mk_img_args(req: Request):
args = req.json()
if req.init_image is not None:
args['init_image'] = image_utils.base64_str_to_img(req.init_image)
if req.mask is not None:
args['mask'] = image_utils.base64_str_to_img(req.mask)
return args
def on_image_step(x_samples, i):
pass
image_generator.on_image_step = on_image_step

View File

@ -177,50 +177,14 @@ manager_lock = threading.RLock()
render_threads = [] render_threads = []
current_state = ServerStates.Init current_state = ServerStates.Init
current_state_error:Exception = None current_state_error:Exception = None
current_model_path = None
current_vae_path = None
current_hypernetwork_path = None
tasks_queue = [] tasks_queue = []
task_cache = TaskCache() task_cache = TaskCache()
default_model_to_load = None
default_vae_to_load = None
default_hypernetwork_to_load = None
weak_thread_data = weakref.WeakKeyDictionary() weak_thread_data = weakref.WeakKeyDictionary()
def preload_model(ckpt_file_path=None, vae_file_path=None, hypernetwork_file_path=None):
global current_state, current_state_error, current_model_path, current_vae_path, current_hypernetwork_path
if ckpt_file_path == None:
ckpt_file_path = default_model_to_load
if vae_file_path == None:
vae_file_path = default_vae_to_load
if hypernetwork_file_path == None:
hypernetwork_file_path = default_hypernetwork_to_load
if ckpt_file_path == current_model_path and vae_file_path == current_vae_path:
return
current_state = ServerStates.LoadingModel
try:
from . import runtime
runtime.thread_data.hypernetwork_file = hypernetwork_file_path
runtime.thread_data.ckpt_file = ckpt_file_path
runtime.thread_data.vae_file = vae_file_path
runtime.load_model_ckpt()
runtime.load_hypernetwork()
current_model_path = ckpt_file_path
current_vae_path = vae_file_path
current_hypernetwork_path = hypernetwork_file_path
current_state_error = None
current_state = ServerStates.Online
except Exception as e:
current_model_path = None
current_vae_path = None
current_state_error = e
current_state = ServerStates.Unavailable
print(traceback.format_exc())
def thread_get_next_task(): def thread_get_next_task():
from . import runtime from sd_internal import runtime2
if not manager_lock.acquire(blocking=True, timeout=LOCK_TIMEOUT): if not manager_lock.acquire(blocking=True, timeout=LOCK_TIMEOUT):
print('Render thread on device', runtime.thread_data.device, 'failed to acquire manager lock.') print('Render thread on device', runtime2.thread_data.device, 'failed to acquire manager lock.')
return None return None
if len(tasks_queue) <= 0: if len(tasks_queue) <= 0:
manager_lock.release() manager_lock.release()
@ -228,7 +192,7 @@ def thread_get_next_task():
task = None task = None
try: # Select a render task. try: # Select a render task.
for queued_task in tasks_queue: for queued_task in tasks_queue:
if queued_task.render_device and runtime.thread_data.device != queued_task.render_device: if queued_task.render_device and runtime2.thread_data.device != queued_task.render_device:
# Is asking for a specific render device. # Is asking for a specific render device.
if is_alive(queued_task.render_device) > 0: if is_alive(queued_task.render_device) > 0:
continue # requested device alive, skip current one. continue # requested device alive, skip current one.
@ -237,7 +201,7 @@ def thread_get_next_task():
queued_task.error = Exception(queued_task.render_device + ' is not currently active.') queued_task.error = Exception(queued_task.render_device + ' is not currently active.')
task = queued_task task = queued_task
break break
if not queued_task.render_device and runtime.thread_data.device == 'cpu' and is_alive() > 1: if not queued_task.render_device and runtime2.thread_data.device == 'cpu' and is_alive() > 1:
# not asking for any specific devices, cpu want to grab task but other render devices are alive. # not asking for any specific devices, cpu want to grab task but other render devices are alive.
continue # Skip Tasks, don't run on CPU unless there is nothing else or user asked for it. continue # Skip Tasks, don't run on CPU unless there is nothing else or user asked for it.
task = queued_task task = queued_task
@ -249,30 +213,31 @@ def thread_get_next_task():
manager_lock.release() manager_lock.release()
def thread_render(device): def thread_render(device):
global current_state, current_state_error, current_model_path, current_vae_path, current_hypernetwork_path global current_state, current_state_error
from . import runtime
from sd_internal import runtime2
try: try:
runtime.thread_init(device) runtime2.init(device)
except Exception as e: except Exception as e:
print(traceback.format_exc()) print(traceback.format_exc())
weak_thread_data[threading.current_thread()] = { weak_thread_data[threading.current_thread()] = {
'error': e 'error': e
} }
return return
weak_thread_data[threading.current_thread()] = { weak_thread_data[threading.current_thread()] = {
'device': runtime.thread_data.device, 'device': runtime2.thread_data.device,
'device_name': runtime.thread_data.device_name, 'device_name': runtime2.thread_data.device_name,
'alive': True 'alive': True
} }
if runtime.thread_data.device != 'cpu' or is_alive() == 1:
preload_model() current_state = ServerStates.Online
current_state = ServerStates.Online
while True: while True:
task_cache.clean() task_cache.clean()
if not weak_thread_data[threading.current_thread()]['alive']: if not weak_thread_data[threading.current_thread()]['alive']:
print(f'Shutting down thread for device {runtime.thread_data.device}') print(f'Shutting down thread for device {runtime2.thread_data.device}')
runtime.unload_models() runtime2.destroy()
runtime.unload_filters()
return return
if isinstance(current_state_error, SystemExit): if isinstance(current_state_error, SystemExit):
current_state = ServerStates.Unavailable current_state = ServerStates.Unavailable
@ -291,24 +256,17 @@ def thread_render(device):
task.response = {"status": 'failed', "detail": str(task.error)} task.response = {"status": 'failed', "detail": str(task.error)}
task.buffer_queue.put(json.dumps(task.response)) task.buffer_queue.put(json.dumps(task.response))
continue continue
print(f'Session {task.request.session_id} starting task {id(task)} on {runtime.thread_data.device_name}') print(f'Session {task.request.session_id} starting task {id(task)} on {runtime2.thread_data.device_name}')
if not task.lock.acquire(blocking=False): raise Exception('Got locked task from queue.') if not task.lock.acquire(blocking=False): raise Exception('Got locked task from queue.')
try: try:
if runtime.is_hypernetwork_reload_necessary(task.request): current_state = ServerStates.LoadingModel
runtime.reload_hypernetwork() runtime2.reload_models(task.request)
current_hypernetwork_path = task.request.use_hypernetwork_model
if runtime.is_model_reload_necessary(task.request):
current_state = ServerStates.LoadingModel
runtime.reload_model()
current_model_path = task.request.use_stable_diffusion_model
current_vae_path = task.request.use_vae_model
def step_callback(): def step_callback():
global current_state_error global current_state_error
if isinstance(current_state_error, SystemExit) or isinstance(current_state_error, StopAsyncIteration) or isinstance(task.error, StopAsyncIteration): if isinstance(current_state_error, SystemExit) or isinstance(current_state_error, StopAsyncIteration) or isinstance(task.error, StopAsyncIteration):
runtime.thread_data.stop_processing = True runtime2.thread_data.stop_processing = True
if isinstance(current_state_error, StopAsyncIteration): if isinstance(current_state_error, StopAsyncIteration):
task.error = current_state_error task.error = current_state_error
current_state_error = None current_state_error = None
@ -317,7 +275,7 @@ def thread_render(device):
task_cache.keep(task.request.session_id, TASK_TTL) task_cache.keep(task.request.session_id, TASK_TTL)
current_state = ServerStates.Rendering current_state = ServerStates.Rendering
task.response = runtime.mk_img(task.request, task.buffer_queue, task.temp_images, step_callback) task.response = runtime2.make_image(task.request, task.buffer_queue, task.temp_images, step_callback)
except Exception as e: except Exception as e:
task.error = e task.error = e
print(traceback.format_exc()) print(traceback.format_exc())
@ -331,7 +289,7 @@ def thread_render(device):
elif task.error is not None: elif task.error is not None:
print(f'Session {task.request.session_id} task {id(task)} failed!') print(f'Session {task.request.session_id} task {id(task)} failed!')
else: else:
print(f'Session {task.request.session_id} task {id(task)} completed by {runtime.thread_data.device_name}.') print(f'Session {task.request.session_id} task {id(task)} completed by {runtime2.thread_data.device_name}.')
current_state = ServerStates.Online current_state = ServerStates.Online
def get_cached_task(session_id:str, update_ttl:bool=False): def get_cached_task(session_id:str, update_ttl:bool=False):
@ -493,8 +451,7 @@ def render(req : ImageRequest):
if task and not task.response and not task.error and not task.lock.locked(): if task and not task.response and not task.error and not task.lock.locked():
# Unstarted task pending, deny queueing more than one. # Unstarted task pending, deny queueing more than one.
raise ConnectionRefusedError(f'Session {req.session_id} has an already pending task.') raise ConnectionRefusedError(f'Session {req.session_id} has an already pending task.')
#
from . import runtime
r = Request() r = Request()
r.session_id = req.session_id r.session_id = req.session_id
r.prompt = req.prompt r.prompt = req.prompt

View File

@ -2,64 +2,21 @@
Notes: Notes:
async endpoints always run on the main thread. Without they run on the thread pool. async endpoints always run on the main thread. Without they run on the thread pool.
""" """
import json
import traceback
import sys
import os import os
import socket import traceback
import picklescan.scanner import logging
import rich from typing import List, Union
SD_DIR = os.getcwd()
print('started in ', SD_DIR)
SD_UI_DIR = os.getenv('SD_UI_PATH', None)
sys.path.append(os.path.dirname(SD_UI_DIR))
CONFIG_DIR = os.path.abspath(os.path.join(SD_UI_DIR, '..', 'scripts'))
MODELS_DIR = os.path.abspath(os.path.join(SD_DIR, '..', 'models'))
USER_UI_PLUGINS_DIR = os.path.abspath(os.path.join(SD_DIR, '..', 'plugins', 'ui'))
CORE_UI_PLUGINS_DIR = os.path.abspath(os.path.join(SD_UI_DIR, 'plugins', 'ui'))
UI_PLUGINS_SOURCES = ((CORE_UI_PLUGINS_DIR, 'core'), (USER_UI_PLUGINS_DIR, 'user'))
STABLE_DIFFUSION_MODEL_EXTENSIONS = ['.ckpt', '.safetensors']
VAE_MODEL_EXTENSIONS = ['.vae.pt', '.ckpt']
HYPERNETWORK_MODEL_EXTENSIONS = ['.pt']
OUTPUT_DIRNAME = "Stable Diffusion UI" # in the user's home folder
TASK_TTL = 15 * 60 # Discard last session's task timeout
APP_CONFIG_DEFAULTS = {
# auto: selects the cuda device with the most free memory, cuda: use the currently active cuda device.
'render_devices': 'auto', # valid entries: 'auto', 'cpu' or 'cuda:N' (where N is a GPU index)
'update_branch': 'main',
'ui': {
'open_browser_on_start': True,
},
}
APP_CONFIG_DEFAULT_MODELS = [
# needed to support the legacy installations
'custom-model', # Check if user has a custom model, use it first.
'sd-v1-4', # Default fallback.
]
from fastapi import FastAPI, HTTPException from fastapi import FastAPI, HTTPException
from fastapi.staticfiles import StaticFiles from fastapi.staticfiles import StaticFiles
from starlette.responses import FileResponse, JSONResponse, StreamingResponse from starlette.responses import FileResponse, JSONResponse, StreamingResponse
from pydantic import BaseModel from pydantic import BaseModel
import logging
#import queue, threading, time
from typing import Any, Generator, Hashable, List, Optional, Union
from sd_internal import Request, Response, task_manager from sd_internal import app, model_manager, task_manager
app = FastAPI() print('started in ', app.SD_DIR)
modifiers_cache = None server_api = FastAPI()
outpath = os.path.join(os.path.expanduser("~"), OUTPUT_DIRNAME)
os.makedirs(USER_UI_PLUGINS_DIR, exist_ok=True)
# don't show access log entries for URLs that start with the given prefix # don't show access log entries for URLs that start with the given prefix
ACCESS_LOG_SUPPRESS_PATH_PREFIXES = ['/ping', '/image', '/modifier-thumbnails'] ACCESS_LOG_SUPPRESS_PATH_PREFIXES = ['/ping', '/image', '/modifier-thumbnails']
@ -74,132 +31,6 @@ class NoCacheStaticFiles(StaticFiles):
return super().is_not_modified(response_headers, request_headers) return super().is_not_modified(response_headers, request_headers)
app.mount('/media', NoCacheStaticFiles(directory=os.path.join(SD_UI_DIR, 'media')), name="media")
for plugins_dir, dir_prefix in UI_PLUGINS_SOURCES:
app.mount(f'/plugins/{dir_prefix}', NoCacheStaticFiles(directory=plugins_dir), name=f"plugins-{dir_prefix}")
def getConfig(default_val=APP_CONFIG_DEFAULTS):
try:
config_json_path = os.path.join(CONFIG_DIR, 'config.json')
if not os.path.exists(config_json_path):
return default_val
with open(config_json_path, 'r', encoding='utf-8') as f:
config = json.load(f)
if 'net' not in config:
config['net'] = {}
if os.getenv('SD_UI_BIND_PORT') is not None:
config['net']['listen_port'] = int(os.getenv('SD_UI_BIND_PORT'))
if os.getenv('SD_UI_BIND_IP') is not None:
config['net']['listen_to_network'] = ( os.getenv('SD_UI_BIND_IP') == '0.0.0.0' )
return config
except Exception as e:
print(str(e))
print(traceback.format_exc())
return default_val
def setConfig(config):
print( json.dumps(config) )
try: # config.json
config_json_path = os.path.join(CONFIG_DIR, 'config.json')
with open(config_json_path, 'w', encoding='utf-8') as f:
json.dump(config, f)
except:
print(traceback.format_exc())
try: # config.bat
config_bat_path = os.path.join(CONFIG_DIR, 'config.bat')
config_bat = []
if 'update_branch' in config:
config_bat.append(f"@set update_branch={config['update_branch']}")
config_bat.append(f"@set SD_UI_BIND_PORT={config['net']['listen_port']}")
bind_ip = '0.0.0.0' if config['net']['listen_to_network'] else '127.0.0.1'
config_bat.append(f"@set SD_UI_BIND_IP={bind_ip}")
config_bat.append(f"@set test_sd2={'Y' if config.get('test_sd2', False) else 'N'}")
if len(config_bat) > 0:
with open(config_bat_path, 'w', encoding='utf-8') as f:
f.write('\r\n'.join(config_bat))
except:
print(traceback.format_exc())
try: # config.sh
config_sh_path = os.path.join(CONFIG_DIR, 'config.sh')
config_sh = ['#!/bin/bash']
if 'update_branch' in config:
config_sh.append(f"export update_branch={config['update_branch']}")
config_sh.append(f"export SD_UI_BIND_PORT={config['net']['listen_port']}")
bind_ip = '0.0.0.0' if config['net']['listen_to_network'] else '127.0.0.1'
config_sh.append(f"export SD_UI_BIND_IP={bind_ip}")
config_sh.append(f"export test_sd2=\"{'Y' if config.get('test_sd2', False) else 'N'}\"")
if len(config_sh) > 1:
with open(config_sh_path, 'w', encoding='utf-8') as f:
f.write('\n'.join(config_sh))
except:
print(traceback.format_exc())
def resolve_model_to_use(model_name:str, model_type:str, model_dir:str, model_extensions:list, default_models=[]):
config = getConfig()
model_dirs = [os.path.join(MODELS_DIR, model_dir), SD_DIR]
if not model_name: # When None try user configured model.
# config = getConfig()
if 'model' in config and model_type in config['model']:
model_name = config['model'][model_type]
if model_name:
is_sd2 = config.get('test_sd2', False)
if model_name.startswith('sd2_') and not is_sd2: # temp hack, until SD2 is unified with 1.4
print('ERROR: Cannot use SD 2.0 models with SD 1.0 code. Using the sd-v1-4 model instead!')
model_name = 'sd-v1-4'
# Check models directory
models_dir_path = os.path.join(MODELS_DIR, model_dir, model_name)
for model_extension in model_extensions:
if os.path.exists(models_dir_path + model_extension):
return models_dir_path
if os.path.exists(model_name + model_extension):
# Direct Path to file
model_name = os.path.abspath(model_name)
return model_name
# Default locations
if model_name in default_models:
default_model_path = os.path.join(SD_DIR, model_name)
for model_extension in model_extensions:
if os.path.exists(default_model_path + model_extension):
return default_model_path
# Can't find requested model, check the default paths.
for default_model in default_models:
for model_dir in model_dirs:
default_model_path = os.path.join(model_dir, default_model)
for model_extension in model_extensions:
if os.path.exists(default_model_path + model_extension):
if model_name is not None:
print(f'Could not find the configured custom model {model_name}{model_extension}. Using the default one: {default_model_path}{model_extension}')
return default_model_path
raise Exception('No valid models found.')
def resolve_ckpt_to_use(model_name:str=None):
return resolve_model_to_use(model_name, model_type='stable-diffusion', model_dir='stable-diffusion', model_extensions=STABLE_DIFFUSION_MODEL_EXTENSIONS, default_models=APP_CONFIG_DEFAULT_MODELS)
def resolve_vae_to_use(model_name:str=None):
try:
return resolve_model_to_use(model_name, model_type='vae', model_dir='vae', model_extensions=VAE_MODEL_EXTENSIONS, default_models=[])
except:
return None
def resolve_hypernetwork_to_use(model_name:str=None):
try:
return resolve_model_to_use(model_name, model_type='hypernetwork', model_dir='hypernetwork', model_extensions=HYPERNETWORK_MODEL_EXTENSIONS, default_models=[])
except:
return None
class SetAppConfigRequest(BaseModel): class SetAppConfigRequest(BaseModel):
update_branch: str = None update_branch: str = None
render_devices: Union[List[str], List[int], str, int] = None render_devices: Union[List[str], List[int], str, int] = None
@ -209,9 +40,25 @@ class SetAppConfigRequest(BaseModel):
listen_port: int = None listen_port: int = None
test_sd2: bool = None test_sd2: bool = None
@app.post('/app_config') class LogSuppressFilter(logging.Filter):
def filter(self, record: logging.LogRecord) -> bool:
path = record.getMessage()
for prefix in ACCESS_LOG_SUPPRESS_PATH_PREFIXES:
if path.find(prefix) != -1:
return False
return True
# don't log certain requests
logging.getLogger('uvicorn.access').addFilter(LogSuppressFilter())
server_api.mount('/media', NoCacheStaticFiles(directory=os.path.join(app.SD_UI_DIR, 'media')), name="media")
for plugins_dir, dir_prefix in app.UI_PLUGINS_SOURCES:
app.mount(f'/plugins/{dir_prefix}', NoCacheStaticFiles(directory=plugins_dir), name=f"plugins-{dir_prefix}")
@server_api.post('/app_config')
async def setAppConfig(req : SetAppConfigRequest): async def setAppConfig(req : SetAppConfigRequest):
config = getConfig() config = app.getConfig()
if req.update_branch is not None: if req.update_branch is not None:
config['update_branch'] = req.update_branch config['update_branch'] = req.update_branch
if req.render_devices is not None: if req.render_devices is not None:
@ -231,121 +78,48 @@ async def setAppConfig(req : SetAppConfigRequest):
if req.test_sd2 is not None: if req.test_sd2 is not None:
config['test_sd2'] = req.test_sd2 config['test_sd2'] = req.test_sd2
try: try:
setConfig(config) app.setConfig(config)
if req.render_devices: if req.render_devices:
update_render_threads() app.update_render_threads()
return JSONResponse({'status': 'OK'}, headers=NOCACHE_HEADERS) return JSONResponse({'status': 'OK'}, headers=NOCACHE_HEADERS)
except Exception as e: except Exception as e:
print(traceback.format_exc()) print(traceback.format_exc())
raise HTTPException(status_code=500, detail=str(e)) raise HTTPException(status_code=500, detail=str(e))
def is_malicious_model(file_path): def update_render_devices_in_config(config, render_devices):
try: if render_devices not in ('cpu', 'auto') and not render_devices.startswith('cuda:'):
scan_result = picklescan.scanner.scan_file_path(file_path) raise HTTPException(status_code=400, detail=f'Invalid render device requested: {render_devices}')
if scan_result.issues_count > 0 or scan_result.infected_files > 0:
rich.print(":warning: [bold red]Scan %s: %d scanned, %d issue, %d infected.[/bold red]" % (file_path, scan_result.scanned_files, scan_result.issues_count, scan_result.infected_files))
return True
else:
rich.print("Scan %s: [green]%d scanned, %d issue, %d infected.[/green]" % (file_path, scan_result.scanned_files, scan_result.issues_count, scan_result.infected_files))
return False
except Exception as e:
print('error while scanning', file_path, 'error:', e)
return False
known_models = {} if render_devices.startswith('cuda:'):
def getModels(): render_devices = render_devices.split(',')
models = {
'active': {
'stable-diffusion': 'sd-v1-4',
'vae': '',
'hypernetwork': '',
},
'options': {
'stable-diffusion': ['sd-v1-4'],
'vae': [],
'hypernetwork': [],
},
}
def listModels(models_dirname, model_type, model_extensions): config['render_devices'] = render_devices
models_dir = os.path.join(MODELS_DIR, models_dirname)
if not os.path.exists(models_dir):
os.makedirs(models_dir)
for file in os.listdir(models_dir): @server_api.get('/get/{key:path}')
for model_extension in model_extensions:
if not file.endswith(model_extension):
continue
model_path = os.path.join(models_dir, file)
mtime = os.path.getmtime(model_path)
mod_time = known_models[model_path] if model_path in known_models else -1
if mod_time != mtime:
if is_malicious_model(model_path):
models['scan-error'] = file
return
known_models[model_path] = mtime
model_name = file[:-len(model_extension)]
models['options'][model_type].append(model_name)
models['options'][model_type] = [*set(models['options'][model_type])] # remove duplicates
models['options'][model_type].sort()
# custom models
listModels(models_dirname='stable-diffusion', model_type='stable-diffusion', model_extensions=STABLE_DIFFUSION_MODEL_EXTENSIONS)
listModels(models_dirname='vae', model_type='vae', model_extensions=VAE_MODEL_EXTENSIONS)
listModels(models_dirname='hypernetwork', model_type='hypernetwork', model_extensions=HYPERNETWORK_MODEL_EXTENSIONS)
# legacy
custom_weight_path = os.path.join(SD_DIR, 'custom-model.ckpt')
if os.path.exists(custom_weight_path):
models['options']['stable-diffusion'].append('custom-model')
return models
def getUIPlugins():
plugins = []
for plugins_dir, dir_prefix in UI_PLUGINS_SOURCES:
for file in os.listdir(plugins_dir):
if file.endswith('.plugin.js'):
plugins.append(f'/plugins/{dir_prefix}/{file}')
return plugins
def getIPConfig():
ips = socket.gethostbyname_ex(socket.gethostname())
ips[2].append(ips[0])
return ips[2]
@app.get('/get/{key:path}')
def read_web_data(key:str=None): def read_web_data(key:str=None):
if not key: # /get without parameters, stable-diffusion easter egg. if not key: # /get without parameters, stable-diffusion easter egg.
raise HTTPException(status_code=418, detail="StableDiffusion is drawing a teapot!") # HTTP418 I'm a teapot raise HTTPException(status_code=418, detail="StableDiffusion is drawing a teapot!") # HTTP418 I'm a teapot
elif key == 'app_config': elif key == 'app_config':
config = getConfig(default_val=None) return JSONResponse(app.getConfig(), headers=NOCACHE_HEADERS)
if config is None:
config = APP_CONFIG_DEFAULTS
return JSONResponse(config, headers=NOCACHE_HEADERS)
elif key == 'system_info': elif key == 'system_info':
config = getConfig() config = app.getConfig()
system_info = { system_info = {
'devices': task_manager.get_devices(), 'devices': task_manager.get_devices(),
'hosts': getIPConfig(), 'hosts': app.getIPConfig(),
'default_output_dir': os.path.join(os.path.expanduser("~"), app.OUTPUT_DIRNAME),
} }
system_info['devices']['config'] = config.get('render_devices', "auto") system_info['devices']['config'] = config.get('render_devices', "auto")
return JSONResponse(system_info, headers=NOCACHE_HEADERS) return JSONResponse(system_info, headers=NOCACHE_HEADERS)
elif key == 'models': elif key == 'models':
return JSONResponse(getModels(), headers=NOCACHE_HEADERS) return JSONResponse(model_manager.getModels(), headers=NOCACHE_HEADERS)
elif key == 'modifiers': return FileResponse(os.path.join(SD_UI_DIR, 'modifiers.json'), headers=NOCACHE_HEADERS) elif key == 'modifiers': return FileResponse(os.path.join(app.SD_UI_DIR, 'modifiers.json'), headers=NOCACHE_HEADERS)
elif key == 'output_dir': return JSONResponse({ 'output_dir': outpath }, headers=NOCACHE_HEADERS) elif key == 'ui_plugins': return JSONResponse(app.getUIPlugins(), headers=NOCACHE_HEADERS)
elif key == 'ui_plugins': return JSONResponse(getUIPlugins(), headers=NOCACHE_HEADERS)
else: else:
raise HTTPException(status_code=404, detail=f'Request for unknown {key}') # HTTP404 Not Found raise HTTPException(status_code=404, detail=f'Request for unknown {key}') # HTTP404 Not Found
@app.get('/ping') # Get server and optionally session status. @server_api.get('/ping') # Get server and optionally session status.
def ping(session_id:str=None): def ping(session_id:str=None):
if task_manager.is_alive() <= 0: # Check that render threads are alive. if task_manager.is_alive() <= 0: # Check that render threads are alive.
if task_manager.current_state_error: raise HTTPException(status_code=500, detail=str(task_manager.current_state_error)) if task_manager.current_state_error: raise HTTPException(status_code=500, detail=str(task_manager.current_state_error))
@ -372,38 +146,14 @@ def ping(session_id:str=None):
response['devices'] = task_manager.get_devices() response['devices'] = task_manager.get_devices()
return JSONResponse(response, headers=NOCACHE_HEADERS) return JSONResponse(response, headers=NOCACHE_HEADERS)
def save_model_to_config(ckpt_model_name, vae_model_name, hypernetwork_model_name): @server_api.post('/render')
config = getConfig()
if 'model' not in config:
config['model'] = {}
config['model']['stable-diffusion'] = ckpt_model_name
config['model']['vae'] = vae_model_name
config['model']['hypernetwork'] = hypernetwork_model_name
if vae_model_name is None or vae_model_name == "":
del config['model']['vae']
if hypernetwork_model_name is None or hypernetwork_model_name == "":
del config['model']['hypernetwork']
setConfig(config)
def update_render_devices_in_config(config, render_devices):
if render_devices not in ('cpu', 'auto') and not render_devices.startswith('cuda:'):
raise HTTPException(status_code=400, detail=f'Invalid render device requested: {render_devices}')
if render_devices.startswith('cuda:'):
render_devices = render_devices.split(',')
config['render_devices'] = render_devices
@app.post('/render')
def render(req : task_manager.ImageRequest): def render(req : task_manager.ImageRequest):
try: try:
save_model_to_config(req.use_stable_diffusion_model, req.use_vae_model, req.use_hypernetwork_model) app.save_model_to_config(req.use_stable_diffusion_model, req.use_vae_model, req.use_hypernetwork_model)
req.use_stable_diffusion_model = resolve_ckpt_to_use(req.use_stable_diffusion_model) req.use_stable_diffusion_model = model_manager.resolve_ckpt_to_use(req.use_stable_diffusion_model)
req.use_vae_model = resolve_vae_to_use(req.use_vae_model) req.use_vae_model = model_manager.resolve_vae_to_use(req.use_vae_model)
req.use_hypernetwork_model = resolve_hypernetwork_to_use(req.use_hypernetwork_model) req.use_hypernetwork_model = model_manager.resolve_hypernetwork_to_use(req.use_hypernetwork_model)
new_task = task_manager.render(req) new_task = task_manager.render(req)
response = { response = {
'status': str(task_manager.current_state), 'status': str(task_manager.current_state),
@ -419,7 +169,7 @@ def render(req : task_manager.ImageRequest):
except Exception as e: except Exception as e:
raise HTTPException(status_code=500, detail=str(e)) raise HTTPException(status_code=500, detail=str(e))
@app.get('/image/stream/{session_id:str}/{task_id:int}') @server_api.get('/image/stream/{session_id:str}/{task_id:int}')
def stream(session_id:str, task_id:int): def stream(session_id:str, task_id:int):
#TODO Move to WebSockets ?? #TODO Move to WebSockets ??
task = task_manager.get_cached_task(session_id, update_ttl=True) task = task_manager.get_cached_task(session_id, update_ttl=True)
@ -433,7 +183,7 @@ def stream(session_id:str, task_id:int):
#print(f'Session {session_id} opened live render stream {id(task.buffer_queue)}') #print(f'Session {session_id} opened live render stream {id(task.buffer_queue)}')
return StreamingResponse(task.read_buffer_generator(), media_type='application/json') return StreamingResponse(task.read_buffer_generator(), media_type='application/json')
@app.get('/image/stop') @server_api.get('/image/stop')
def stop(session_id:str=None): def stop(session_id:str=None):
if not session_id: if not session_id:
if task_manager.current_state == task_manager.ServerStates.Online or task_manager.current_state == task_manager.ServerStates.Unavailable: if task_manager.current_state == task_manager.ServerStates.Online or task_manager.current_state == task_manager.ServerStates.Unavailable:
@ -446,7 +196,7 @@ def stop(session_id:str=None):
task.error = StopAsyncIteration('') task.error = StopAsyncIteration('')
return {'OK'} return {'OK'}
@app.get('/image/tmp/{session_id}/{img_id:int}') @server_api.get('/image/tmp/{session_id}/{img_id:int}')
def get_image(session_id, img_id): def get_image(session_id, img_id):
task = task_manager.get_cached_task(session_id, update_ttl=True) task = task_manager.get_cached_task(session_id, update_ttl=True)
if not task: raise HTTPException(status_code=410, detail=f'Session {session_id} has not submitted a task.') # HTTP410 Gone if not task: raise HTTPException(status_code=410, detail=f'Session {session_id} has not submitted a task.') # HTTP410 Gone
@ -458,49 +208,17 @@ def get_image(session_id, img_id):
except KeyError as e: except KeyError as e:
raise HTTPException(status_code=500, detail=str(e)) raise HTTPException(status_code=500, detail=str(e))
@app.get('/') @server_api.get('/')
def read_root(): def read_root():
return FileResponse(os.path.join(SD_UI_DIR, 'index.html'), headers=NOCACHE_HEADERS) return FileResponse(os.path.join(app.SD_UI_DIR, 'index.html'), headers=NOCACHE_HEADERS)
@app.on_event("shutdown") @server_api.on_event("shutdown")
def shutdown_event(): # Signal render thread to close on shutdown def shutdown_event(): # Signal render thread to close on shutdown
task_manager.current_state_error = SystemExit('Application shutting down.') task_manager.current_state_error = SystemExit('Application shutting down.')
# don't log certain requests # Init the app
class LogSuppressFilter(logging.Filter): model_manager.init()
def filter(self, record: logging.LogRecord) -> bool: app.init()
path = record.getMessage()
for prefix in ACCESS_LOG_SUPPRESS_PATH_PREFIXES:
if path.find(prefix) != -1:
return False
return True
logging.getLogger('uvicorn.access').addFilter(LogSuppressFilter())
# Check models and prepare cache for UI open
getModels()
# Start the task_manager
task_manager.default_model_to_load = resolve_ckpt_to_use()
task_manager.default_vae_to_load = resolve_vae_to_use()
task_manager.default_hypernetwork_to_load = resolve_hypernetwork_to_use()
def update_render_threads():
config = getConfig()
render_devices = config.get('render_devices', 'auto')
active_devices = task_manager.get_devices()['active'].keys()
print('requesting for render_devices', render_devices)
task_manager.update_render_threads(render_devices, active_devices)
update_render_threads()
# start the browser ui # start the browser ui
def open_browser(): app.open_browser()
config = getConfig()
ui = config.get('ui', {})
net = config.get('net', {'listen_port':9000})
port = net.get('listen_port', 9000)
if ui.get('open_browser_on_start', True):
import webbrowser; webbrowser.open(f"http://localhost:{port}")
open_browser()