Work in progress: refactored the end-to-end codebase. Still missing: hypernetwork support, the turbo config, and SD 2.x support. Not yet tested.

This commit is contained in:
cmdr2 2022-12-08 21:39:09 +05:30
parent bad89160cc
commit f4a6910ab4
11 changed files with 239 additions and 128 deletions

View File

@ -199,12 +199,7 @@ call WHERE uvicorn > .tmp
if not exist "..\models\stable-diffusion" mkdir "..\models\stable-diffusion"
if not exist "..\models\vae" mkdir "..\models\vae"
if not exist "..\models\hypernetwork" mkdir "..\models\hypernetwork"
echo. > "..\models\stable-diffusion\Put your custom ckpt files here.txt"
echo. > "..\models\vae\Put your VAE files here.txt"
echo. > "..\models\hypernetwork\Put your hypernetwork files here.txt"
@if exist "sd-v1-4.ckpt" (
for %%I in ("sd-v1-4.ckpt") do if "%%~zI" EQU "4265380512" (

View File

@ -159,12 +159,7 @@ fi
mkdir -p "../models/stable-diffusion"
mkdir -p "../models/vae"
mkdir -p "../models/hypernetwork"
echo "" > "../models/stable-diffusion/Put your custom ckpt files here.txt"
echo "" > "../models/vae/Put your VAE files here.txt"
echo "" > "../models/hypernetwork/Put your hypernetwork files here.txt"
if [ -f "sd-v1-4.ckpt" ]; then
model_size=`find "sd-v1-4.ckpt" -printf "%s"`

View File

@ -409,7 +409,6 @@
async function init() {
await initSettings()
await getModels()
await getDiskPath()
await getAppConfig()
await loadUIPlugins()
await loadModifiers()

View File

@ -327,20 +327,10 @@ autoPickGPUsField.addEventListener('click', function() {
gpuSettingEntry.style.display = (this.checked ? 'none' : '')
})
async function getDiskPath() {
try {
var diskPath = getSetting("diskPath")
if (diskPath == '' || diskPath == undefined || diskPath == "undefined") {
let res = await fetch('/get/output_dir')
if (res.status === 200) {
res = await res.json()
res = res.output_dir
setSetting("diskPath", res)
}
}
} catch (e) {
console.log('error fetching output dir path', e)
async function setDiskPath(defaultDiskPath) {
    // Seed the "diskPath" setting with the server-provided default output
    // directory, but only when the user hasn't chosen a path yet.
    var currentPath = getSetting("diskPath")
    var pathIsUnset = (currentPath == '' || currentPath == undefined || currentPath == "undefined")
    if (pathIsUnset) {
        setSetting("diskPath", defaultDiskPath)
    }
}
@ -415,6 +405,7 @@ async function getSystemInfo() {
setDeviceInfo(devices)
setHostInfo(res['hosts'])
setDiskPath(res['default_output_dir'])
} catch (e) {
console.log('error fetching devices', e)
}

View File

@ -105,6 +105,10 @@ class Response:
request: Request
images: list
def __init__(self, request: Request, images: list):
self.request = request
self.images = images
def json(self):
res = {
"status": 'succeeded',
@ -116,3 +120,6 @@ class Response:
res["output"].append(image.json())
return res
class UserInitiatedStop(Exception):
    '''Raised during image generation when the user requests that processing stop.'''

View File

@ -28,11 +28,6 @@ APP_CONFIG_DEFAULTS = {
'open_browser_on_start': True,
},
}
DEFAULT_MODELS = [
# needed to support the legacy installations
'custom-model', # Check if user has a custom model, use it first.
'sd-v1-4', # Default fallback.
]
def init():
os.makedirs(USER_UI_PLUGINS_DIR, exist_ok=True)

View File

@ -101,10 +101,8 @@ def device_init(context, device):
context.device = device
# Force full precision on 1660 and 1650 NVIDIA cards to avoid creating green images
device_name = context.device_name.lower()
force_full_precision = (('nvidia' in device_name or 'geforce' in device_name) and (' 1660' in device_name or ' 1650' in device_name)) or ('Quadro T2000' in device_name)
if force_full_precision:
print('forcing full precision on NVIDIA 16xx cards, to avoid green images. GPU detected: ', context.device_name)
if needs_to_force_full_precision(context.device_name):
print(f'forcing full precision on this GPU, to avoid green images. GPU detected: {context.device_name}')
# Apply force_full_precision now before models are loaded.
context.precision = 'full'
@ -113,6 +111,10 @@ def device_init(context, device):
return
def needs_to_force_full_precision(context):
    '''
    Returns True if this GPU is known to produce broken (green) images at half
    precision, so the caller should force full precision before loading models.

    Accepts either a device context (anything with a `device_name` attribute)
    or a plain device-name string, since call sites pass both forms.
    '''
    device_name = context if isinstance(context, str) else context.device_name
    device_name = device_name.lower()

    # NVIDIA 16xx cards render green images at half precision.
    is_nvidia_16xx = ('nvidia' in device_name or 'geforce' in device_name) and \
        (' 1660' in device_name or ' 1650' in device_name)

    # Bug fix: the name is lowercased above, so the original mixed-case
    # 'Quadro T2000' comparison could never match.
    return is_nvidia_16xx or ('quadro t2000' in device_name)
def validate_device_id(device, log_prefix=''):
def is_valid():
if not isinstance(device, str):

View File

@ -1,32 +1,39 @@
import os
from sd_internal import app
from sd_internal import app, device_manager
from sd_internal import Request
import picklescan.scanner
import rich
STABLE_DIFFUSION_MODEL_EXTENSIONS = ['.ckpt', '.safetensors']
VAE_MODEL_EXTENSIONS = ['.vae.pt', '.ckpt']
HYPERNETWORK_MODEL_EXTENSIONS = ['.pt']
default_model_to_load = None
default_vae_to_load = None
default_hypernetwork_to_load = None
KNOWN_MODEL_TYPES = ['stable-diffusion', 'vae', 'hypernetwork', 'gfpgan', 'realesrgan']
MODEL_EXTENSIONS = {
'stable-diffusion': ['.ckpt', '.safetensors'],
'vae': ['.vae.pt', '.ckpt'],
'hypernetwork': ['.pt'],
'gfpgan': ['.pth'],
'realesrgan': ['.pth'],
}
DEFAULT_MODELS = {
'stable-diffusion': [ # needed to support the legacy installations
'custom-model', # only one custom model file was supported initially, creatively named 'custom-model'
'sd-v1-4', # Default fallback.
],
'gfpgan': ['GFPGANv1.3'],
'realesrgan': ['RealESRGAN_x4plus'],
}
known_models = {}
def init():
global default_model_to_load, default_vae_to_load, default_hypernetwork_to_load
default_model_to_load = resolve_ckpt_to_use()
default_vae_to_load = resolve_vae_to_use()
default_hypernetwork_to_load = resolve_hypernetwork_to_use()
make_model_folders()
getModels() # run this once, to cache the picklescan results
def resolve_model_to_use(model_name:str, model_type:str, model_dir:str, model_extensions:list, default_models=[]):
def resolve_model_to_use(model_name:str, model_type:str):
model_extensions = MODEL_EXTENSIONS.get(model_type, [])
default_models = DEFAULT_MODELS.get(model_type, [])
config = app.getConfig()
model_dirs = [os.path.join(app.MODELS_DIR, model_dir), app.SD_DIR]
model_dirs = [os.path.join(app.MODELS_DIR, model_type), app.SD_DIR]
if not model_name: # When None try user configured model.
# config = getConfig()
if 'model' in config and model_type in config['model']:
@ -39,7 +46,7 @@ def resolve_model_to_use(model_name:str, model_type:str, model_dir:str, model_ex
model_name = 'sd-v1-4'
# Check models directory
models_dir_path = os.path.join(app.MODELS_DIR, model_dir, model_name)
models_dir_path = os.path.join(app.MODELS_DIR, model_type, model_name)
for model_extension in model_extensions:
if os.path.exists(models_dir_path + model_extension):
return models_dir_path + model_extension
@ -66,14 +73,32 @@ def resolve_model_to_use(model_name:str, model_type:str, model_dir:str, model_ex
print(f'No valid models found for model_name: {model_name}')
return None
def resolve_ckpt_to_use(model_name:str=None):
return resolve_model_to_use(model_name, model_type='stable-diffusion', model_dir='stable-diffusion', model_extensions=STABLE_DIFFUSION_MODEL_EXTENSIONS, default_models=app.APP_CONFIG_DEFAULT_MODELS)
def resolve_sd_model_to_use(model_name:str=None):
    '''Returns the stable-diffusion checkpoint path for `model_name` (when None, the configured/default model is used).'''
    return resolve_model_to_use(model_name, 'stable-diffusion')
def resolve_vae_to_use(model_name:str=None):
return resolve_model_to_use(model_name, model_type='vae', model_dir='vae', model_extensions=VAE_MODEL_EXTENSIONS, default_models=[])
def resolve_vae_model_to_use(model_name:str=None):
    '''Returns the VAE model path for `model_name` (when None, the configured model is used, if any).'''
    return resolve_model_to_use(model_name, 'vae')
def resolve_hypernetwork_to_use(model_name:str=None):
return resolve_model_to_use(model_name, model_type='hypernetwork', model_dir='hypernetwork', model_extensions=HYPERNETWORK_MODEL_EXTENSIONS, default_models=[])
def resolve_hypernetwork_model_to_use(model_name:str=None):
    '''Returns the hypernetwork model path for `model_name` (when None, the configured model is used, if any).'''
    return resolve_model_to_use(model_name, 'hypernetwork')
def resolve_gfpgan_model_to_use(model_name:str=None):
    '''Returns the GFPGAN (face correction) model path for `model_name` (when None, the default GFPGANv1.3 entry applies).'''
    return resolve_model_to_use(model_name, 'gfpgan')
def resolve_realesrgan_model_to_use(model_name:str=None):
    '''Returns the RealESRGAN (upscaler) model path for `model_name` (when None, the default RealESRGAN_x4plus entry applies).'''
    return resolve_model_to_use(model_name, 'realesrgan')
def make_model_folders():
    '''
    Creates the per-model-type folders under the models dir (if missing) and
    writes a small help file into each, listing the supported file extensions.
    '''
    for model_type in KNOWN_MODEL_TYPES:
        model_dir_path = os.path.join(app.MODELS_DIR, model_type)
        os.makedirs(model_dir_path, exist_ok=True)

        help_file_name = f'Place your {model_type} model files here.txt'
        help_file_contents = f'Supported extensions: {" or ".join(MODEL_EXTENSIONS.get(model_type, []))}'

        # Bug fix: open() defaults to read mode, which made f.write() raise
        # (and open itself fail when the file didn't exist). Open for writing.
        with open(os.path.join(model_dir_path, help_file_name), 'w', encoding='utf-8') as f:
            f.write(help_file_contents)
def is_malicious_model(file_path):
try:
@ -102,8 +127,9 @@ def getModels():
},
}
def listModels(models_dirname, model_type, model_extensions):
models_dir = os.path.join(app.MODELS_DIR, models_dirname)
def listModels(model_type):
model_extensions = MODEL_EXTENSIONS.get(model_type, [])
models_dir = os.path.join(app.MODELS_DIR, model_type)
if not os.path.exists(models_dir):
os.makedirs(models_dir)
@ -128,9 +154,9 @@ def getModels():
models['options'][model_type].sort()
# custom models
listModels(models_dirname='stable-diffusion', model_type='stable-diffusion', model_extensions=STABLE_DIFFUSION_MODEL_EXTENSIONS)
listModels(models_dirname='vae', model_type='vae', model_extensions=VAE_MODEL_EXTENSIONS)
listModels(models_dirname='hypernetwork', model_type='hypernetwork', model_extensions=HYPERNETWORK_MODEL_EXTENSIONS)
listModels(model_type='stable-diffusion')
listModels(model_type='vae')
listModels(model_type='hypernetwork')
# legacy
custom_weight_path = os.path.join(app.SD_DIR, 'custom-model.ckpt')
@ -138,3 +164,19 @@ def getModels():
models['options']['stable-diffusion'].append('custom-model')
return models
def is_sd_model_reload_necessary(thread_data, req: Request):
    '''
    Decides whether the stable-diffusion model (and VAE) must be reloaded to
    serve this request. Note: on non-CPU devices this also updates
    thread_data.precision when the requested precision differs, since that
    change itself requires a reload.
    '''
    sd_is_loaded = 'stable-diffusion' in thread_data.models
    paths_unchanged = sd_is_loaded and \
        thread_data.model_paths['stable-diffusion'] == req.use_stable_diffusion_model and \
        thread_data.model_paths['vae'] == req.use_vae_model
    needs_model_reload = not paths_unchanged

    if thread_data.device != 'cpu':
        wants_full_precision = req.use_full_precision
        if thread_data.precision == 'autocast' and wants_full_precision:
            precision_changed = True
        elif thread_data.precision == 'full' and not wants_full_precision and \
                not device_manager.needs_to_force_full_precision(thread_data):
            precision_changed = True
        else:
            precision_changed = False

        if precision_changed:
            thread_data.precision = 'full' if wants_full_precision else 'autocast'
            needs_model_reload = True

    return needs_model_reload

View File

@ -1,16 +1,23 @@
import threading
import queue
import time
import json
import os
import base64
import re
from sd_internal import device_manager, model_manager
from sd_internal import Request, Response, Image as ResponseImage
from sd_internal import Request, Response, Image as ResponseImage, UserInitiatedStop
from modules import model_loader, image_generator, image_utils
from modules import model_loader, image_generator, image_utils, image_filters
thread_data = threading.local()
'''
runtime data (bound locally to this thread), for e.g. device, references to loaded models, optimization flags etc
'''
filename_regex = re.compile('[^a-zA-Z0-9]')
def init(device):
'''
Initializes the fields that will be bound to this runtime's thread_data, and sets the current torch device
@ -28,89 +35,167 @@ def init(device):
device_manager.device_init(thread_data, device)
load_default_models()
init_and_load_default_models()
def destroy():
model_loader.unload_sd_model(thread_data)
model_loader.unload_gfpgan_model(thread_data)
model_loader.unload_realesrgan_model(thread_data)
def load_default_models():
thread_data.model_paths['stable-diffusion'] = model_manager.default_model_to_load
thread_data.model_paths['vae'] = model_manager.default_vae_to_load
def init_and_load_default_models():
    '''
    Resolves the default file path for each model type into
    thread_data.model_paths, then loads the mandatory stable-diffusion model.
    '''
    default_resolvers = {
        'stable-diffusion': model_manager.resolve_sd_model_to_use,
        'vae': model_manager.resolve_vae_model_to_use,
        'hypernetwork': model_manager.resolve_hypernetwork_model_to_use,
        'gfpgan': model_manager.resolve_gfpgan_model_to_use,
        'realesrgan': model_manager.resolve_realesrgan_model_to_use,
    }
    for model_type, resolve in default_resolvers.items():
        thread_data.model_paths[model_type] = resolve()

    # Only the stable-diffusion model is loaded up-front; it is mandatory.
    model_loader.load_sd_model(thread_data)
def reload_models_if_necessary(req: Request=None):
needs_model_reload = False
if 'stable-diffusion' not in thread_data.models or thread_data.ckpt_file != req.use_stable_diffusion_model or thread_data.vae_file != req.use_vae_model:
thread_data.ckpt_file = req.use_stable_diffusion_model
thread_data.vae_file = req.use_vae_model
needs_model_reload = True
def reload_models_if_necessary(req: Request):
if model_manager.is_sd_model_reload_necessary(thread_data, req):
thread_data.model_paths['stable-diffusion'] = req.use_stable_diffusion_model
thread_data.model_paths['vae'] = req.use_vae_model
if thread_data.device != 'cpu':
if (thread_data.precision == 'autocast' and (req.use_full_precision or not thread_data.model_is_half)) or \
(thread_data.precision == 'full' and not req.use_full_precision and not thread_data.force_full_precision):
thread_data.precision = 'full' if req.use_full_precision else 'autocast'
needs_model_reload = True
model_loader.load_sd_model(thread_data)
return needs_model_reload
# if is_hypernetwork_reload_necessary(task.request):
# current_state = ServerStates.LoadingModel
# runtime.reload_hypernetwork()
if is_hypernetwork_reload_necessary(task.request):
current_state = ServerStates.LoadingModel
runtime.reload_hypernetwork()
def make_images(req: Request, data_queue: queue.Queue, task_temp_images: list, step_callback):
images, user_stopped = generate_images(req, data_queue, task_temp_images, step_callback)
images = apply_filters(req, images, user_stopped)
if is_model_reload_necessary(task.request):
current_state = ServerStates.LoadingModel
runtime.reload_model()
save_images(req, images)
def load_models():
if ckpt_file_path == None:
ckpt_file_path = default_model_to_load
if vae_file_path == None:
vae_file_path = default_vae_to_load
if hypernetwork_file_path == None:
hypernetwork_file_path = default_hypernetwork_to_load
if ckpt_file_path == current_model_path and vae_file_path == current_vae_path:
return
current_state = ServerStates.LoadingModel
try:
from sd_internal import runtime2
runtime.thread_data.hypernetwork_file = hypernetwork_file_path
runtime.thread_data.ckpt_file = ckpt_file_path
runtime.thread_data.vae_file = vae_file_path
runtime.load_model_ckpt()
runtime.load_hypernetwork()
current_model_path = ckpt_file_path
current_vae_path = vae_file_path
current_hypernetwork_path = hypernetwork_file_path
current_state_error = None
current_state = ServerStates.Online
except Exception as e:
current_model_path = None
current_vae_path = None
current_state_error = e
current_state = ServerStates.Unavailable
print(traceback.format_exc())
return Response(req, images=construct_response(req, images))
def generate_images(req: Request, data_queue: queue.Queue, task_temp_images: list, step_callback):
thread_data.temp_images.clear()
image_generator.on_image_step = make_step_callback(req, data_queue, task_temp_images, step_callback)
def make_image(req: Request, data_queue: queue.Queue, task_temp_images: list, step_callback):
try:
images = image_generator.make_image(context=thread_data, args=get_mk_img_args(req))
user_stopped = False
except UserInitiatedStop:
pass
images = []
user_stopped = True
if not hasattr(thread_data, 'partial_x_samples') or thread_data.partial_x_samples is None:
return images
for i in range(req.num_outputs):
images[i] = image_utils.latent_to_img(thread_data, thread_data.partial_x_samples[i].unsqueeze(0))
del thread_data.partial_x_samples
finally:
model_loader.gc(thread_data)
images = [(image, req.seed + i, False) for i, image in enumerate(images)]
return images, user_stopped
def apply_filters(req: Request, images: list, user_stopped):
    '''
    Applies the requested post-processing filters (GFPGAN face correction
    and/or RealESRGAN upscaling) to each generated image.

    `images` is a list of (img, seed, filtered) tuples; returns the same
    shape. Returns the input unchanged when the user stopped the task or no
    filter was requested. When req.show_only_filtered_image is falsy, the
    unfiltered originals are kept alongside the filtered copies.
    '''
    if user_stopped or (req.use_face_correction is None and req.use_upscale is None):
        return images

    filters = []
    # Bug fix: each field must be None-checked individually — previously a
    # request with only one of the two filters set crashed with
    # AttributeError: 'NoneType' object has no attribute 'startswith'.
    if req.use_face_correction is not None and req.use_face_correction.startswith('GFPGAN'):
        filters.append((image_filters.apply_gfpgan, model_manager.resolve_gfpgan_model_to_use(req.use_face_correction)))
    if req.use_upscale is not None and req.use_upscale.startswith('RealESRGAN'):
        filters.append((image_filters.apply_realesrgan, model_manager.resolve_realesrgan_model_to_use(req.use_upscale)))

    filtered_images = []
    for img, seed, _ in images:
        for filter_fn, filter_model_path in filters:
            img = filter_fn(thread_data, img, filter_model_path)
        filtered_images.append((img, seed, True))

    if not req.show_only_filtered_image:
        filtered_images = images + filtered_images

    return filtered_images
def save_images(req: Request, images: list):
    '''
    Writes the generated images to req.save_to_disk_path, plus a .txt
    metadata sidecar for each unfiltered image. No-op when that path is None.

    `images` is a list of (img, seed, filtered) tuples.
    '''
    if req.save_to_disk_path is None:
        return

    def get_image_id(i):
        img_id = base64.b64encode(int(time.time()+i).to_bytes(8, 'big')).decode() # Generate unique ID based on time.
        img_id = img_id.translate({43:None, 47:None, 61:None})[-8:] # Remove + / = and keep last 8 chars.
        return img_id

    def get_image_basepath(i):
        session_out_path = os.path.join(req.save_to_disk_path, filename_regex.sub('_', req.session_id))
        os.makedirs(session_out_path, exist_ok=True)
        prompt_flattened = filename_regex.sub('_', req.prompt)[:50]
        return os.path.join(session_out_path, f"{prompt_flattened}_{get_image_id(i)}")

    for i, img_data in enumerate(images):
        img, seed, filtered = img_data
        img_path = get_image_basepath(i)

        if not filtered or req.show_only_filtered_image:
            img_metadata_path = img_path + '.txt'
            metadata = req.json()
            metadata['seed'] = seed
            with open(img_metadata_path, 'w', encoding='utf-8') as f:
                # Bug fix: req.json() returns a dict; file.write() needs a
                # string, so serialize it before writing.
                f.write(json.dumps(metadata))

        img_path += '_filtered' if filtered else ''
        img_path += '.' + req.output_format
        img.save(img_path, quality=req.output_quality)
def construct_response(req: Request, images: list):
    '''Converts (img, seed, filtered) tuples into ResponseImage objects carrying base64-encoded image data.'''
    response_images = []
    for img, seed, _ in images:
        encoded = image_utils.img_to_base64_str(img, req.output_format, req.output_quality)
        response_images.append(ResponseImage(data=encoded, seed=seed))
    return response_images
def get_mk_img_args(req: Request):
args = req.json()
if req.init_image is not None:
args['init_image'] = image_utils.base64_str_to_img(req.init_image)
if req.mask is not None:
args['mask'] = image_utils.base64_str_to_img(req.mask)
args['init_image'] = image_utils.base64_str_to_img(req.init_image) if req.init_image is not None else None
args['mask'] = image_utils.base64_str_to_img(req.mask) if req.mask is not None else None
return args
def on_image_step(x_samples, i):
pass
def make_step_callback(req: Request, data_queue: queue.Queue, task_temp_images: list, step_callback):
n_steps = req.num_inference_steps if req.init_image is None else int(req.num_inference_steps * req.prompt_strength)
last_callback_time = -1
image_generator.on_image_step = on_image_step
def update_temp_img(req, x_samples, task_temp_images: list):
partial_images = []
for i in range(req.num_outputs):
img = image_utils.latent_to_img(thread_data, x_samples[i].unsqueeze(0))
buf = image_utils.img_to_buffer(img, output_format='JPEG')
del img
thread_data.temp_images[f'{req.request_id}/{i}'] = buf
task_temp_images[i] = buf
partial_images.append({'path': f'/image/tmp/{req.request_id}/{i}'})
return partial_images
def on_image_step(x_samples, i):
nonlocal last_callback_time
thread_data.partial_x_samples = x_samples
step_time = time.time() - last_callback_time if last_callback_time != -1 else -1
last_callback_time = time.time()
progress = {"step": i, "step_time": step_time, "total_steps": n_steps}
if req.stream_image_progress and i % 5 == 0:
progress['output'] = update_temp_img(req, x_samples, task_temp_images)
data_queue.put(json.dumps(progress))
step_callback()
if thread_data.stop_processing:
raise UserInitiatedStop("User requested that we stop processing")
return on_image_step

View File

@ -324,7 +324,7 @@ def thread_render(device):
runtime2.reload_models_if_necessary(task.request)
current_state = ServerStates.Rendering
task.response = runtime2.make_image(task.request, task.buffer_queue, task.temp_images, step_callback)
task.response = runtime2.make_images(task.request, task.buffer_queue, task.temp_images, step_callback)
# Before looping back to the generator, mark cache as still alive.
task_cache.keep(id(task), TASK_TTL)
session_cache.keep(task.request.session_id, TASK_TTL)

View File

@ -137,9 +137,9 @@ def ping(session_id:str=None):
def render(req : task_manager.ImageRequest):
try:
app.save_model_to_config(req.use_stable_diffusion_model, req.use_vae_model, req.use_hypernetwork_model)
req.use_stable_diffusion_model = model_manager.resolve_ckpt_to_use(req.use_stable_diffusion_model)
req.use_vae_model = model_manager.resolve_vae_to_use(req.use_vae_model)
req.use_hypernetwork_model = model_manager.resolve_hypernetwork_to_use(req.use_hypernetwork_model)
req.use_stable_diffusion_model = model_manager.resolve_sd_model_to_use(req.use_stable_diffusion_model)
req.use_vae_model = model_manager.resolve_vae_model_to_use(req.use_vae_model)
req.use_hypernetwork_model = model_manager.resolve_hypernetwork_model_to_use(req.use_hypernetwork_model)
new_task = task_manager.render(req)
response = {