Work-in-progress: refactored the end-to-end codebase. Missing: hypernetworks, turbo config, and SD 2. Not tested yet

cmdr2 2022-12-08 21:39:09 +05:30
parent bad89160cc
commit f4a6910ab4
11 changed files with 239 additions and 128 deletions

View File

@@ -199,12 +199,7 @@ call WHERE uvicorn > .tmp
-if not exist "..\models\stable-diffusion" mkdir "..\models\stable-diffusion"
 if not exist "..\models\vae" mkdir "..\models\vae"
-if not exist "..\models\hypernetwork" mkdir "..\models\hypernetwork"
-echo. > "..\models\stable-diffusion\Put your custom ckpt files here.txt"
-echo. > "..\models\vae\Put your VAE files here.txt"
-echo. > "..\models\hypernetwork\Put your hypernetwork files here.txt"
 @if exist "sd-v1-4.ckpt" (
   for %%I in ("sd-v1-4.ckpt") do if "%%~zI" EQU "4265380512" (

View File

@@ -159,12 +159,7 @@ fi
-mkdir -p "../models/stable-diffusion"
 mkdir -p "../models/vae"
-mkdir -p "../models/hypernetwork"
-echo "" > "../models/stable-diffusion/Put your custom ckpt files here.txt"
-echo "" > "../models/vae/Put your VAE files here.txt"
-echo "" > "../models/hypernetwork/Put your hypernetwork files here.txt"
 if [ -f "sd-v1-4.ckpt" ]; then
     model_size=`find "sd-v1-4.ckpt" -printf "%s"`

View File

@@ -409,7 +409,6 @@
 async function init() {
     await initSettings()
     await getModels()
-    await getDiskPath()
     await getAppConfig()
     await loadUIPlugins()
     await loadModifiers()

View File

@@ -327,20 +327,10 @@ autoPickGPUsField.addEventListener('click', function() {
     gpuSettingEntry.style.display = (this.checked ? 'none' : '')
 })

-async function getDiskPath() {
-    try {
-        var diskPath = getSetting("diskPath")
-        if (diskPath == '' || diskPath == undefined || diskPath == "undefined") {
-            let res = await fetch('/get/output_dir')
-            if (res.status === 200) {
-                res = await res.json()
-                res = res.output_dir
-                setSetting("diskPath", res)
-            }
-        }
-    } catch (e) {
-        console.log('error fetching output dir path', e)
-    }
+async function setDiskPath(defaultDiskPath) {
+    var diskPath = getSetting("diskPath")
+    if (diskPath == '' || diskPath == undefined || diskPath == "undefined") {
+        setSetting("diskPath", defaultDiskPath)
+    }
 }

@@ -415,6 +405,7 @@ async function getSystemInfo() {
         setDeviceInfo(devices)
         setHostInfo(res['hosts'])
+        setDiskPath(res['default_output_dir'])
     } catch (e) {
         console.log('error fetching devices', e)
     }

View File

@@ -105,6 +105,10 @@ class Response:
     request: Request
     images: list

+    def __init__(self, request: Request, images: list):
+        self.request = request
+        self.images = images
+
     def json(self):
         res = {
             "status": 'succeeded',
@@ -116,3 +120,6 @@ class Response:
             res["output"].append(image.json())

         return res
+
+class UserInitiatedStop(Exception):
+    pass
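
For orientation, a minimal sketch of how the new constructor pairs with the existing json() serializer; how the Request and the Image entries get built is an assumption here, since this diff doesn't show it:

# Hypothetical usage; Request/Image construction details are assumptions.
req = Request()   # populated by the API layer elsewhere in sd_internal
res = Response(req, images=[Image(data='data:image/jpeg;base64,...', seed=42)])
res.json()        # {'status': 'succeeded', ..., 'output': [<each image.json()>]}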

View File

@@ -28,11 +28,6 @@ APP_CONFIG_DEFAULTS = {
         'open_browser_on_start': True,
     },
 }

-DEFAULT_MODELS = [
-    # needed to support the legacy installations
-    'custom-model', # Check if user has a custom model, use it first.
-    'sd-v1-4', # Default fallback.
-]

 def init():
     os.makedirs(USER_UI_PLUGINS_DIR, exist_ok=True)

View File

@@ -101,10 +101,8 @@ def device_init(context, device):
     context.device = device

     # Force full precision on 1660 and 1650 NVIDIA cards to avoid creating green images
-    device_name = context.device_name.lower()
-    force_full_precision = (('nvidia' in device_name or 'geforce' in device_name) and (' 1660' in device_name or ' 1650' in device_name)) or ('Quadro T2000' in device_name)
-    if force_full_precision:
-        print('forcing full precision on NVIDIA 16xx cards, to avoid green images. GPU detected: ', context.device_name)
+    if needs_to_force_full_precision(context):
+        print(f'forcing full precision on this GPU, to avoid green images. GPU detected: {context.device_name}')

         # Apply force_full_precision now before models are loaded.
         context.precision = 'full'
@@ -113,6 +111,10 @@ def device_init(context, device):

     return

+def needs_to_force_full_precision(context):
+    device_name = context.device_name.lower()
+    # compare against lowercased names, since device_name has just been lowercased
+    return (('nvidia' in device_name or 'geforce' in device_name) and (' 1660' in device_name or ' 1650' in device_name)) or ('quadro t2000' in device_name)
+
 def validate_device_id(device, log_prefix=''):
     def is_valid():
         if not isinstance(device, str):
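
A quick sketch of the extracted helper in isolation; the context object is a stand-in for the thread-local state and the device names are illustrative:

# Illustrative only: any object with a device_name attribute will do.
from types import SimpleNamespace

needs_to_force_full_precision(SimpleNamespace(device_name='NVIDIA GeForce GTX 1660 Ti'))  # True: 16xx forces full precision
needs_to_force_full_precision(SimpleNamespace(device_name='NVIDIA GeForce RTX 3090'))     # False: autocast stays available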

View File

@@ -1,32 +1,39 @@
 import os

-from sd_internal import app
+from sd_internal import app, device_manager
+from sd_internal import Request

 import picklescan.scanner
 import rich

-STABLE_DIFFUSION_MODEL_EXTENSIONS = ['.ckpt', '.safetensors']
-VAE_MODEL_EXTENSIONS = ['.vae.pt', '.ckpt']
-HYPERNETWORK_MODEL_EXTENSIONS = ['.pt']
-
-default_model_to_load = None
-default_vae_to_load = None
-default_hypernetwork_to_load = None
+KNOWN_MODEL_TYPES = ['stable-diffusion', 'vae', 'hypernetwork', 'gfpgan', 'realesrgan']
+MODEL_EXTENSIONS = {
+    'stable-diffusion': ['.ckpt', '.safetensors'],
+    'vae': ['.vae.pt', '.ckpt'],
+    'hypernetwork': ['.pt'],
+    'gfpgan': ['.pth'],
+    'realesrgan': ['.pth'],
+}
+DEFAULT_MODELS = {
+    'stable-diffusion': [ # needed to support the legacy installations
+        'custom-model', # only one custom model file was supported initially, creatively named 'custom-model'
+        'sd-v1-4', # Default fallback.
+    ],
+    'gfpgan': ['GFPGANv1.3'],
+    'realesrgan': ['RealESRGAN_x4plus'],
+}

 known_models = {}

 def init():
-    global default_model_to_load, default_vae_to_load, default_hypernetwork_to_load
-
-    default_model_to_load = resolve_ckpt_to_use()
-    default_vae_to_load = resolve_vae_to_use()
-    default_hypernetwork_to_load = resolve_hypernetwork_to_use()
-
+    make_model_folders()
     getModels() # run this once, to cache the picklescan results

-def resolve_model_to_use(model_name:str, model_type:str, model_dir:str, model_extensions:list, default_models=[]):
+def resolve_model_to_use(model_name:str, model_type:str):
+    model_extensions = MODEL_EXTENSIONS.get(model_type, [])
+    default_models = DEFAULT_MODELS.get(model_type, [])
+
     config = app.getConfig()

-    model_dirs = [os.path.join(app.MODELS_DIR, model_dir), app.SD_DIR]
+    model_dirs = [os.path.join(app.MODELS_DIR, model_type), app.SD_DIR]
     if not model_name: # When None try user configured model.
         # config = getConfig()
         if 'model' in config and model_type in config['model']:
@@ -39,7 +46,7 @@ def resolve_model_to_use(model_name:str, model_type:str, model_dir:str, model_extensions:list, default_models=[]):
             model_name = 'sd-v1-4'

     # Check models directory
-    models_dir_path = os.path.join(app.MODELS_DIR, model_dir, model_name)
+    models_dir_path = os.path.join(app.MODELS_DIR, model_type, model_name)
     for model_extension in model_extensions:
         if os.path.exists(models_dir_path + model_extension):
             return models_dir_path + model_extension
@@ -66,14 +73,32 @@ def resolve_model_to_use(model_name:str, model_type:str, model_dir:str, model_extensions:list, default_models=[]):
     print(f'No valid models found for model_name: {model_name}')
     return None

-def resolve_ckpt_to_use(model_name:str=None):
-    return resolve_model_to_use(model_name, model_type='stable-diffusion', model_dir='stable-diffusion', model_extensions=STABLE_DIFFUSION_MODEL_EXTENSIONS, default_models=app.APP_CONFIG_DEFAULT_MODELS)
+def resolve_sd_model_to_use(model_name:str=None):
+    return resolve_model_to_use(model_name, model_type='stable-diffusion')

-def resolve_vae_to_use(model_name:str=None):
-    return resolve_model_to_use(model_name, model_type='vae', model_dir='vae', model_extensions=VAE_MODEL_EXTENSIONS, default_models=[])
+def resolve_vae_model_to_use(model_name:str=None):
+    return resolve_model_to_use(model_name, model_type='vae')

-def resolve_hypernetwork_to_use(model_name:str=None):
-    return resolve_model_to_use(model_name, model_type='hypernetwork', model_dir='hypernetwork', model_extensions=HYPERNETWORK_MODEL_EXTENSIONS, default_models=[])
+def resolve_hypernetwork_model_to_use(model_name:str=None):
+    return resolve_model_to_use(model_name, model_type='hypernetwork')
+
+def resolve_gfpgan_model_to_use(model_name:str=None):
+    return resolve_model_to_use(model_name, model_type='gfpgan')
+
+def resolve_realesrgan_model_to_use(model_name:str=None):
+    return resolve_model_to_use(model_name, model_type='realesrgan')
+
+def make_model_folders():
+    for model_type in KNOWN_MODEL_TYPES:
+        model_dir_path = os.path.join(app.MODELS_DIR, model_type)
+        os.makedirs(model_dir_path, exist_ok=True)
+
+        help_file_name = f'Place your {model_type} model files here.txt'
+        help_file_contents = f'Supported extensions: {" or ".join(MODEL_EXTENSIONS.get(model_type))}'
+
+        with open(os.path.join(model_dir_path, help_file_name), 'w') as f: # open for writing; the default read-only mode would fail on f.write
+            f.write(help_file_contents)

 def is_malicious_model(file_path):
     try:
@@ -102,8 +127,9 @@ def getModels():
         },
     }

-    def listModels(models_dirname, model_type, model_extensions):
-        models_dir = os.path.join(app.MODELS_DIR, models_dirname)
+    def listModels(model_type):
+        model_extensions = MODEL_EXTENSIONS.get(model_type, [])
+        models_dir = os.path.join(app.MODELS_DIR, model_type)
         if not os.path.exists(models_dir):
             os.makedirs(models_dir)
@@ -128,9 +154,9 @@ def getModels():
             models['options'][model_type].sort()

     # custom models
-    listModels(models_dirname='stable-diffusion', model_type='stable-diffusion', model_extensions=STABLE_DIFFUSION_MODEL_EXTENSIONS)
-    listModels(models_dirname='vae', model_type='vae', model_extensions=VAE_MODEL_EXTENSIONS)
-    listModels(models_dirname='hypernetwork', model_type='hypernetwork', model_extensions=HYPERNETWORK_MODEL_EXTENSIONS)
+    listModels(model_type='stable-diffusion')
+    listModels(model_type='vae')
+    listModels(model_type='hypernetwork')

     # legacy
     custom_weight_path = os.path.join(app.SD_DIR, 'custom-model.ckpt')
@@ -138,3 +164,19 @@ def getModels():
         models['options']['stable-diffusion'].append('custom-model')

     return models
+
+def is_sd_model_reload_necessary(thread_data, req: Request):
+    needs_model_reload = False
+    if 'stable-diffusion' not in thread_data.models or \
+            thread_data.model_paths['stable-diffusion'] != req.use_stable_diffusion_model or \
+            thread_data.model_paths['vae'] != req.use_vae_model:
+        needs_model_reload = True
+
+    if thread_data.device != 'cpu':
+        if (thread_data.precision == 'autocast' and req.use_full_precision) or \
+                (thread_data.precision == 'full' and not req.use_full_precision and not device_manager.needs_to_force_full_precision(thread_data)):
+            thread_data.precision = 'full' if req.use_full_precision else 'autocast'
+            needs_model_reload = True
+
+    return needs_model_reload
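
To illustrate the consolidated lookup, a hedged sketch of the per-type resolvers; actual return values depend on which files exist under app.MODELS_DIR:

# Hypothetical calls; the returned paths depend on the local models directory.
model_manager.resolve_sd_model_to_use('sd-v1-4')         # e.g. '<MODELS_DIR>/stable-diffusion/sd-v1-4.ckpt'
model_manager.resolve_vae_model_to_use()                 # no name: falls back to the config, then DEFAULT_MODELS
model_manager.resolve_gfpgan_model_to_use('GFPGANv1.3')  # extension comes from MODEL_EXTENSIONS['gfpgan']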

View File

@@ -1,16 +1,23 @@
 import threading
 import queue
+import time
+import json
+import os
+import base64
+import re

 from sd_internal import device_manager, model_manager
-from sd_internal import Request, Response, Image as ResponseImage
-from modules import model_loader, image_generator, image_utils
+from sd_internal import Request, Response, Image as ResponseImage, UserInitiatedStop
+from modules import model_loader, image_generator, image_utils, image_filters

 thread_data = threading.local()
 '''
 runtime data (bound locally to this thread), for e.g. device, references to loaded models, optimization flags etc
 '''

+filename_regex = re.compile('[^a-zA-Z0-9]')
+
 def init(device):
     '''
     Initializes the fields that will be bound to this runtime's thread_data, and sets the current torch device
@@ -28,89 +35,167 @@ def init(device):
     device_manager.device_init(thread_data, device)

-    load_default_models()
+    init_and_load_default_models()

 def destroy():
     model_loader.unload_sd_model(thread_data)
     model_loader.unload_gfpgan_model(thread_data)
     model_loader.unload_realesrgan_model(thread_data)

-def load_default_models():
-    thread_data.model_paths['stable-diffusion'] = model_manager.default_model_to_load
-    thread_data.model_paths['vae'] = model_manager.default_vae_to_load
+def init_and_load_default_models():
+    # init default model paths
+    thread_data.model_paths['stable-diffusion'] = model_manager.resolve_sd_model_to_use()
+    thread_data.model_paths['vae'] = model_manager.resolve_vae_model_to_use()
+    thread_data.model_paths['hypernetwork'] = model_manager.resolve_hypernetwork_model_to_use()
+    thread_data.model_paths['gfpgan'] = model_manager.resolve_gfpgan_model_to_use()
+    thread_data.model_paths['realesrgan'] = model_manager.resolve_realesrgan_model_to_use()
+
+    # load mandatory models
     model_loader.load_sd_model(thread_data)
-def reload_models_if_necessary(req: Request=None):
-    needs_model_reload = False
-    if 'stable-diffusion' not in thread_data.models or thread_data.ckpt_file != req.use_stable_diffusion_model or thread_data.vae_file != req.use_vae_model:
-        thread_data.ckpt_file = req.use_stable_diffusion_model
-        thread_data.vae_file = req.use_vae_model
-        needs_model_reload = True
-
-    if thread_data.device != 'cpu':
-        if (thread_data.precision == 'autocast' and (req.use_full_precision or not thread_data.model_is_half)) or \
-                (thread_data.precision == 'full' and not req.use_full_precision and not thread_data.force_full_precision):
-            thread_data.precision = 'full' if req.use_full_precision else 'autocast'
-            needs_model_reload = True
-
-    return needs_model_reload
-
-    if is_hypernetwork_reload_necessary(task.request):
-        current_state = ServerStates.LoadingModel
-        runtime.reload_hypernetwork()
-
-    if is_model_reload_necessary(task.request):
-        current_state = ServerStates.LoadingModel
-        runtime.reload_model()
+def reload_models_if_necessary(req: Request):
+    if model_manager.is_sd_model_reload_necessary(thread_data, req):
+        thread_data.model_paths['stable-diffusion'] = req.use_stable_diffusion_model
+        thread_data.model_paths['vae'] = req.use_vae_model
+
+        model_loader.load_sd_model(thread_data)
+
+    # if is_hypernetwork_reload_necessary(task.request):
+    #     current_state = ServerStates.LoadingModel
+    #     runtime.reload_hypernetwork()
+
+def make_images(req: Request, data_queue: queue.Queue, task_temp_images: list, step_callback):
+    images, user_stopped = generate_images(req, data_queue, task_temp_images, step_callback)
+    images = apply_filters(req, images, user_stopped)
+
+    save_images(req, images)
+
+    return Response(req, images=construct_response(req, images))
-def load_models():
-    if ckpt_file_path == None:
-        ckpt_file_path = default_model_to_load
-    if vae_file_path == None:
-        vae_file_path = default_vae_to_load
-    if hypernetwork_file_path == None:
-        hypernetwork_file_path = default_hypernetwork_to_load
-
-    if ckpt_file_path == current_model_path and vae_file_path == current_vae_path:
-        return
-
-    current_state = ServerStates.LoadingModel
-    try:
-        from sd_internal import runtime2
-        runtime.thread_data.hypernetwork_file = hypernetwork_file_path
-        runtime.thread_data.ckpt_file = ckpt_file_path
-        runtime.thread_data.vae_file = vae_file_path
-        runtime.load_model_ckpt()
-        runtime.load_hypernetwork()
-        current_model_path = ckpt_file_path
-        current_vae_path = vae_file_path
-        current_hypernetwork_path = hypernetwork_file_path
-        current_state_error = None
-        current_state = ServerStates.Online
-    except Exception as e:
-        current_model_path = None
-        current_vae_path = None
-        current_state_error = e
-        current_state = ServerStates.Unavailable
-        print(traceback.format_exc())
-
-def make_image(req: Request, data_queue: queue.Queue, task_temp_images: list, step_callback):
+def generate_images(req: Request, data_queue: queue.Queue, task_temp_images: list, step_callback):
+    thread_data.temp_images.clear()
+
+    image_generator.on_image_step = make_step_callback(req, data_queue, task_temp_images, step_callback)
+
     try:
         images = image_generator.make_image(context=thread_data, args=get_mk_img_args(req))
+        user_stopped = False
     except UserInitiatedStop:
-        pass
+        images = []
+        user_stopped = True
+        if not hasattr(thread_data, 'partial_x_samples') or thread_data.partial_x_samples is None:
+            return images, user_stopped # nothing partial to decode; return the same tuple shape as the normal path
+        for i in range(req.num_outputs):
+            images.append(image_utils.latent_to_img(thread_data, thread_data.partial_x_samples[i].unsqueeze(0)))
+        del thread_data.partial_x_samples
+    finally:
+        model_loader.gc(thread_data)
+
+    images = [(image, req.seed + i, False) for i, image in enumerate(images)]
+
+    return images, user_stopped
+def apply_filters(req: Request, images: list, user_stopped):
+    if user_stopped or (req.use_face_correction is None and req.use_upscale is None):
+        return images
+
+    filters = []
+    # guard against None: only one of the two filters may be requested
+    if req.use_face_correction is not None and req.use_face_correction.startswith('GFPGAN'): filters.append((image_filters.apply_gfpgan, model_manager.resolve_gfpgan_model_to_use(req.use_face_correction)))
+    if req.use_upscale is not None and req.use_upscale.startswith('RealESRGAN'): filters.append((image_filters.apply_realesrgan, model_manager.resolve_realesrgan_model_to_use(req.use_upscale)))
+
+    filtered_images = []
+    for img, seed, _ in images:
+        for filter_fn, filter_model_path in filters:
+            img = filter_fn(thread_data, img, filter_model_path)
+
+        filtered_images.append((img, seed, True))
+
+    if not req.show_only_filtered_image:
+        filtered_images = images + filtered_images
+
+    return filtered_images
+def save_images(req: Request, images: list):
+    if req.save_to_disk_path is None:
+        return
+
+    def get_image_id(i):
+        img_id = base64.b64encode(int(time.time()+i).to_bytes(8, 'big')).decode() # Generate unique ID based on time.
+        img_id = img_id.translate({43:None, 47:None, 61:None})[-8:] # Remove + / = and keep last 8 chars.
+        return img_id
+
+    def get_image_basepath(i):
+        session_out_path = os.path.join(req.save_to_disk_path, filename_regex.sub('_', req.session_id))
+        os.makedirs(session_out_path, exist_ok=True)
+        prompt_flattened = filename_regex.sub('_', req.prompt)[:50]
+        return os.path.join(session_out_path, f"{prompt_flattened}_{get_image_id(i)}")
+
+    for i, img_data in enumerate(images):
+        img, seed, filtered = img_data
+        img_path = get_image_basepath(i)
+
+        if not filtered or req.show_only_filtered_image:
+            img_metadata_path = img_path + '.txt'
+            metadata = req.json()
+            metadata['seed'] = seed
+            with open(img_metadata_path, 'w', encoding='utf-8') as f:
+                f.write(json.dumps(metadata)) # metadata is a dict; serialize before writing
+
+        img_path += '_filtered' if filtered else ''
+        img_path += '.' + req.output_format
+        img.save(img_path, quality=req.output_quality)
+def construct_response(req: Request, images: list):
+    return [
+        ResponseImage(
+            data=image_utils.img_to_base64_str(img, req.output_format, req.output_quality),
+            seed=seed
+        ) for img, seed, _ in images
+    ]
 def get_mk_img_args(req: Request):
     args = req.json()

-    if req.init_image is not None:
-        args['init_image'] = image_utils.base64_str_to_img(req.init_image)
-
-    if req.mask is not None:
-        args['mask'] = image_utils.base64_str_to_img(req.mask)
+    args['init_image'] = image_utils.base64_str_to_img(req.init_image) if req.init_image is not None else None
+    args['mask'] = image_utils.base64_str_to_img(req.mask) if req.mask is not None else None

     return args
-def on_image_step(x_samples, i):
-    pass
-
-image_generator.on_image_step = on_image_step
+def make_step_callback(req: Request, data_queue: queue.Queue, task_temp_images: list, step_callback):
+    n_steps = req.num_inference_steps if req.init_image is None else int(req.num_inference_steps * req.prompt_strength)
+    last_callback_time = -1
+
+    def update_temp_img(req, x_samples, task_temp_images: list):
+        partial_images = []
+        for i in range(req.num_outputs):
+            img = image_utils.latent_to_img(thread_data, x_samples[i].unsqueeze(0))
+            buf = image_utils.img_to_buffer(img, output_format='JPEG')
+
+            del img
+
+            thread_data.temp_images[f'{req.request_id}/{i}'] = buf
+            task_temp_images[i] = buf
+            partial_images.append({'path': f'/image/tmp/{req.request_id}/{i}'})
+        return partial_images
+
+    def on_image_step(x_samples, i):
+        nonlocal last_callback_time
+
+        thread_data.partial_x_samples = x_samples
+        step_time = time.time() - last_callback_time if last_callback_time != -1 else -1
+        last_callback_time = time.time()
+
+        progress = {"step": i, "step_time": step_time, "total_steps": n_steps}
+
+        if req.stream_image_progress and i % 5 == 0:
+            progress['output'] = update_temp_img(req, x_samples, task_temp_images)
+
+        data_queue.put(json.dumps(progress))
+
+        step_callback()
+
+        if thread_data.stop_processing:
+            raise UserInitiatedStop("User requested that we stop processing")
+
+    return on_image_step
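
A note on the data flowing between the new stages: each list entry is an (image, seed, filtered) tuple. A minimal sketch of the convention, with PIL images standing in for real outputs and illustrative values:

# Illustrative: req.seed == 100 and req.num_outputs == 2
from PIL import Image as PILImage

img0, img1 = PILImage.new('RGB', (64, 64)), PILImage.new('RGB', (64, 64))
images = [(img0, 100, False), (img1, 101, False)]   # what generate_images returns
# apply_filters appends copies tagged filtered=True, save_images uses the flag
# for the '_filtered' filename suffix, and construct_response serializes every
# tuple it receives via img_to_base64_str.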

View File

@@ -324,7 +324,7 @@ def thread_render(device):
             runtime2.reload_models_if_necessary(task.request)

             current_state = ServerStates.Rendering
-            task.response = runtime2.make_image(task.request, task.buffer_queue, task.temp_images, step_callback)
+            task.response = runtime2.make_images(task.request, task.buffer_queue, task.temp_images, step_callback)
         # Before looping back to the generator, mark cache as still alive.
         task_cache.keep(id(task), TASK_TTL)
         session_cache.keep(task.request.session_id, TASK_TTL)

View File

@@ -137,9 +137,9 @@ def ping(session_id:str=None):
 def render(req : task_manager.ImageRequest):
     try:
         app.save_model_to_config(req.use_stable_diffusion_model, req.use_vae_model, req.use_hypernetwork_model)
-        req.use_stable_diffusion_model = model_manager.resolve_ckpt_to_use(req.use_stable_diffusion_model)
-        req.use_vae_model = model_manager.resolve_vae_to_use(req.use_vae_model)
-        req.use_hypernetwork_model = model_manager.resolve_hypernetwork_to_use(req.use_hypernetwork_model)
+        req.use_stable_diffusion_model = model_manager.resolve_sd_model_to_use(req.use_stable_diffusion_model)
+        req.use_vae_model = model_manager.resolve_vae_model_to_use(req.use_vae_model)
+        req.use_hypernetwork_model = model_manager.resolve_hypernetwork_model_to_use(req.use_hypernetwork_model)
         new_task = task_manager.render(req)
         response = {
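
Net effect on the API: clients keep sending bare model names, and the renamed resolvers rewrite them to concrete file paths on the request before it is queued. A sketch with an illustrative path:

# 'sd-v1-4' (name) -> e.g. '<MODELS_DIR>/stable-diffusion/sd-v1-4.ckpt' (resolved path)
req.use_stable_diffusion_model = model_manager.resolve_sd_model_to_use(req.use_stable_diffusion_model)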