forked from extern/easydiffusion
Work-in-progress: refactored the end-to-end codebase. Missing: hypernetworks, turbo config, and SD 2. Not tested yet
This commit is contained in:
parent
bad89160cc
commit
f4a6910ab4
@ -199,12 +199,7 @@ call WHERE uvicorn > .tmp
|
|||||||
|
|
||||||
|
|
||||||
|
|
||||||
if not exist "..\models\stable-diffusion" mkdir "..\models\stable-diffusion"
|
|
||||||
if not exist "..\models\vae" mkdir "..\models\vae"
|
if not exist "..\models\vae" mkdir "..\models\vae"
|
||||||
if not exist "..\models\hypernetwork" mkdir "..\models\hypernetwork"
|
|
||||||
echo. > "..\models\stable-diffusion\Put your custom ckpt files here.txt"
|
|
||||||
echo. > "..\models\vae\Put your VAE files here.txt"
|
|
||||||
echo. > "..\models\hypernetwork\Put your hypernetwork files here.txt"
|
|
||||||
|
|
||||||
@if exist "sd-v1-4.ckpt" (
|
@if exist "sd-v1-4.ckpt" (
|
||||||
for %%I in ("sd-v1-4.ckpt") do if "%%~zI" EQU "4265380512" (
|
for %%I in ("sd-v1-4.ckpt") do if "%%~zI" EQU "4265380512" (
|
||||||
|
@ -159,12 +159,7 @@ fi
|
|||||||
|
|
||||||
|
|
||||||
|
|
||||||
mkdir -p "../models/stable-diffusion"
|
|
||||||
mkdir -p "../models/vae"
|
mkdir -p "../models/vae"
|
||||||
mkdir -p "../models/hypernetwork"
|
|
||||||
echo "" > "../models/stable-diffusion/Put your custom ckpt files here.txt"
|
|
||||||
echo "" > "../models/vae/Put your VAE files here.txt"
|
|
||||||
echo "" > "../models/hypernetwork/Put your hypernetwork files here.txt"
|
|
||||||
|
|
||||||
if [ -f "sd-v1-4.ckpt" ]; then
|
if [ -f "sd-v1-4.ckpt" ]; then
|
||||||
model_size=`find "sd-v1-4.ckpt" -printf "%s"`
|
model_size=`find "sd-v1-4.ckpt" -printf "%s"`
|
||||||
|
@ -409,7 +409,6 @@
|
|||||||
async function init() {
|
async function init() {
|
||||||
await initSettings()
|
await initSettings()
|
||||||
await getModels()
|
await getModels()
|
||||||
await getDiskPath()
|
|
||||||
await getAppConfig()
|
await getAppConfig()
|
||||||
await loadUIPlugins()
|
await loadUIPlugins()
|
||||||
await loadModifiers()
|
await loadModifiers()
|
||||||
|
@ -327,20 +327,10 @@ autoPickGPUsField.addEventListener('click', function() {
|
|||||||
gpuSettingEntry.style.display = (this.checked ? 'none' : '')
|
gpuSettingEntry.style.display = (this.checked ? 'none' : '')
|
||||||
})
|
})
|
||||||
|
|
||||||
async function getDiskPath() {
|
async function setDiskPath(defaultDiskPath) {
|
||||||
try {
|
var diskPath = getSetting("diskPath")
|
||||||
var diskPath = getSetting("diskPath")
|
if (diskPath == '' || diskPath == undefined || diskPath == "undefined") {
|
||||||
if (diskPath == '' || diskPath == undefined || diskPath == "undefined") {
|
setSetting("diskPath", defaultDiskPath)
|
||||||
let res = await fetch('/get/output_dir')
|
|
||||||
if (res.status === 200) {
|
|
||||||
res = await res.json()
|
|
||||||
res = res.output_dir
|
|
||||||
|
|
||||||
setSetting("diskPath", res)
|
|
||||||
}
|
|
||||||
}
|
|
||||||
} catch (e) {
|
|
||||||
console.log('error fetching output dir path', e)
|
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
@ -415,6 +405,7 @@ async function getSystemInfo() {
|
|||||||
|
|
||||||
setDeviceInfo(devices)
|
setDeviceInfo(devices)
|
||||||
setHostInfo(res['hosts'])
|
setHostInfo(res['hosts'])
|
||||||
|
setDiskPath(res['default_output_dir'])
|
||||||
} catch (e) {
|
} catch (e) {
|
||||||
console.log('error fetching devices', e)
|
console.log('error fetching devices', e)
|
||||||
}
|
}
|
||||||
|
@ -105,6 +105,10 @@ class Response:
|
|||||||
request: Request
|
request: Request
|
||||||
images: list
|
images: list
|
||||||
|
|
||||||
|
def __init__(self, request: Request, images: list):
|
||||||
|
self.request = request
|
||||||
|
self.images = images
|
||||||
|
|
||||||
def json(self):
|
def json(self):
|
||||||
res = {
|
res = {
|
||||||
"status": 'succeeded',
|
"status": 'succeeded',
|
||||||
@ -116,3 +120,6 @@ class Response:
|
|||||||
res["output"].append(image.json())
|
res["output"].append(image.json())
|
||||||
|
|
||||||
return res
|
return res
|
||||||
|
|
||||||
|
class UserInitiatedStop(Exception):
|
||||||
|
pass
|
||||||
|
@ -28,11 +28,6 @@ APP_CONFIG_DEFAULTS = {
|
|||||||
'open_browser_on_start': True,
|
'open_browser_on_start': True,
|
||||||
},
|
},
|
||||||
}
|
}
|
||||||
DEFAULT_MODELS = [
|
|
||||||
# needed to support the legacy installations
|
|
||||||
'custom-model', # Check if user has a custom model, use it first.
|
|
||||||
'sd-v1-4', # Default fallback.
|
|
||||||
]
|
|
||||||
|
|
||||||
def init():
|
def init():
|
||||||
os.makedirs(USER_UI_PLUGINS_DIR, exist_ok=True)
|
os.makedirs(USER_UI_PLUGINS_DIR, exist_ok=True)
|
||||||
|
@ -101,10 +101,8 @@ def device_init(context, device):
|
|||||||
context.device = device
|
context.device = device
|
||||||
|
|
||||||
# Force full precision on 1660 and 1650 NVIDIA cards to avoid creating green images
|
# Force full precision on 1660 and 1650 NVIDIA cards to avoid creating green images
|
||||||
device_name = context.device_name.lower()
|
if needs_to_force_full_precision(context.device_name):
|
||||||
force_full_precision = (('nvidia' in device_name or 'geforce' in device_name) and (' 1660' in device_name or ' 1650' in device_name)) or ('Quadro T2000' in device_name)
|
print(f'forcing full precision on this GPU, to avoid green images. GPU detected: {context.device_name}')
|
||||||
if force_full_precision:
|
|
||||||
print('forcing full precision on NVIDIA 16xx cards, to avoid green images. GPU detected: ', context.device_name)
|
|
||||||
# Apply force_full_precision now before models are loaded.
|
# Apply force_full_precision now before models are loaded.
|
||||||
context.precision = 'full'
|
context.precision = 'full'
|
||||||
|
|
||||||
@ -113,6 +111,10 @@ def device_init(context, device):
|
|||||||
|
|
||||||
return
|
return
|
||||||
|
|
||||||
|
def needs_to_force_full_precision(context):
|
||||||
|
device_name = context.device_name.lower()
|
||||||
|
return (('nvidia' in device_name or 'geforce' in device_name) and (' 1660' in device_name or ' 1650' in device_name)) or ('Quadro T2000' in device_name)
|
||||||
|
|
||||||
def validate_device_id(device, log_prefix=''):
|
def validate_device_id(device, log_prefix=''):
|
||||||
def is_valid():
|
def is_valid():
|
||||||
if not isinstance(device, str):
|
if not isinstance(device, str):
|
||||||
|
@ -1,32 +1,39 @@
|
|||||||
import os
|
import os
|
||||||
|
|
||||||
from sd_internal import app
|
from sd_internal import app, device_manager
|
||||||
|
from sd_internal import Request
|
||||||
import picklescan.scanner
|
import picklescan.scanner
|
||||||
import rich
|
import rich
|
||||||
|
|
||||||
STABLE_DIFFUSION_MODEL_EXTENSIONS = ['.ckpt', '.safetensors']
|
KNOWN_MODEL_TYPES = ['stable-diffusion', 'vae', 'hypernetwork', 'gfpgan', 'realesrgan']
|
||||||
VAE_MODEL_EXTENSIONS = ['.vae.pt', '.ckpt']
|
MODEL_EXTENSIONS = {
|
||||||
HYPERNETWORK_MODEL_EXTENSIONS = ['.pt']
|
'stable-diffusion': ['.ckpt', '.safetensors'],
|
||||||
|
'vae': ['.vae.pt', '.ckpt'],
|
||||||
default_model_to_load = None
|
'hypernetwork': ['.pt'],
|
||||||
default_vae_to_load = None
|
'gfpgan': ['.pth'],
|
||||||
default_hypernetwork_to_load = None
|
'realesrgan': ['.pth'],
|
||||||
|
}
|
||||||
|
DEFAULT_MODELS = {
|
||||||
|
'stable-diffusion': [ # needed to support the legacy installations
|
||||||
|
'custom-model', # only one custom model file was supported initially, creatively named 'custom-model'
|
||||||
|
'sd-v1-4', # Default fallback.
|
||||||
|
],
|
||||||
|
'gfpgan': ['GFPGANv1.3'],
|
||||||
|
'realesrgan': ['RealESRGAN_x4plus'],
|
||||||
|
}
|
||||||
|
|
||||||
known_models = {}
|
known_models = {}
|
||||||
|
|
||||||
def init():
|
def init():
|
||||||
global default_model_to_load, default_vae_to_load, default_hypernetwork_to_load
|
make_model_folders()
|
||||||
|
|
||||||
default_model_to_load = resolve_ckpt_to_use()
|
|
||||||
default_vae_to_load = resolve_vae_to_use()
|
|
||||||
default_hypernetwork_to_load = resolve_hypernetwork_to_use()
|
|
||||||
|
|
||||||
getModels() # run this once, to cache the picklescan results
|
getModels() # run this once, to cache the picklescan results
|
||||||
|
|
||||||
def resolve_model_to_use(model_name:str, model_type:str, model_dir:str, model_extensions:list, default_models=[]):
|
def resolve_model_to_use(model_name:str, model_type:str):
|
||||||
|
model_extensions = MODEL_EXTENSIONS.get(model_type, [])
|
||||||
|
default_models = DEFAULT_MODELS.get(model_type, [])
|
||||||
config = app.getConfig()
|
config = app.getConfig()
|
||||||
|
|
||||||
model_dirs = [os.path.join(app.MODELS_DIR, model_dir), app.SD_DIR]
|
model_dirs = [os.path.join(app.MODELS_DIR, model_type), app.SD_DIR]
|
||||||
if not model_name: # When None try user configured model.
|
if not model_name: # When None try user configured model.
|
||||||
# config = getConfig()
|
# config = getConfig()
|
||||||
if 'model' in config and model_type in config['model']:
|
if 'model' in config and model_type in config['model']:
|
||||||
@ -39,7 +46,7 @@ def resolve_model_to_use(model_name:str, model_type:str, model_dir:str, model_ex
|
|||||||
model_name = 'sd-v1-4'
|
model_name = 'sd-v1-4'
|
||||||
|
|
||||||
# Check models directory
|
# Check models directory
|
||||||
models_dir_path = os.path.join(app.MODELS_DIR, model_dir, model_name)
|
models_dir_path = os.path.join(app.MODELS_DIR, model_type, model_name)
|
||||||
for model_extension in model_extensions:
|
for model_extension in model_extensions:
|
||||||
if os.path.exists(models_dir_path + model_extension):
|
if os.path.exists(models_dir_path + model_extension):
|
||||||
return models_dir_path + model_extension
|
return models_dir_path + model_extension
|
||||||
@ -66,14 +73,32 @@ def resolve_model_to_use(model_name:str, model_type:str, model_dir:str, model_ex
|
|||||||
print(f'No valid models found for model_name: {model_name}')
|
print(f'No valid models found for model_name: {model_name}')
|
||||||
return None
|
return None
|
||||||
|
|
||||||
def resolve_ckpt_to_use(model_name:str=None):
|
def resolve_sd_model_to_use(model_name:str=None):
|
||||||
return resolve_model_to_use(model_name, model_type='stable-diffusion', model_dir='stable-diffusion', model_extensions=STABLE_DIFFUSION_MODEL_EXTENSIONS, default_models=app.APP_CONFIG_DEFAULT_MODELS)
|
return resolve_model_to_use(model_name, model_type='stable-diffusion')
|
||||||
|
|
||||||
def resolve_vae_to_use(model_name:str=None):
|
def resolve_vae_model_to_use(model_name:str=None):
|
||||||
return resolve_model_to_use(model_name, model_type='vae', model_dir='vae', model_extensions=VAE_MODEL_EXTENSIONS, default_models=[])
|
return resolve_model_to_use(model_name, model_type='vae')
|
||||||
|
|
||||||
def resolve_hypernetwork_to_use(model_name:str=None):
|
def resolve_hypernetwork_model_to_use(model_name:str=None):
|
||||||
return resolve_model_to_use(model_name, model_type='hypernetwork', model_dir='hypernetwork', model_extensions=HYPERNETWORK_MODEL_EXTENSIONS, default_models=[])
|
return resolve_model_to_use(model_name, model_type='hypernetwork')
|
||||||
|
|
||||||
|
def resolve_gfpgan_model_to_use(model_name:str=None):
|
||||||
|
return resolve_model_to_use(model_name, model_type='gfpgan')
|
||||||
|
|
||||||
|
def resolve_realesrgan_model_to_use(model_name:str=None):
|
||||||
|
return resolve_model_to_use(model_name, model_type='realesrgan')
|
||||||
|
|
||||||
|
def make_model_folders():
|
||||||
|
for model_type in KNOWN_MODEL_TYPES:
|
||||||
|
model_dir_path = os.path.join(app.MODELS_DIR, model_type)
|
||||||
|
|
||||||
|
os.makedirs(model_dir_path, exist_ok=True)
|
||||||
|
|
||||||
|
help_file_name = f'Place your {model_type} model files here.txt'
|
||||||
|
help_file_contents = f'Supported extensions: {" or ".join(MODEL_EXTENSIONS.get(model_type))}'
|
||||||
|
|
||||||
|
with open(os.path.join(model_dir_path, help_file_name)) as f:
|
||||||
|
f.write(help_file_contents)
|
||||||
|
|
||||||
def is_malicious_model(file_path):
|
def is_malicious_model(file_path):
|
||||||
try:
|
try:
|
||||||
@ -102,8 +127,9 @@ def getModels():
|
|||||||
},
|
},
|
||||||
}
|
}
|
||||||
|
|
||||||
def listModels(models_dirname, model_type, model_extensions):
|
def listModels(model_type):
|
||||||
models_dir = os.path.join(app.MODELS_DIR, models_dirname)
|
model_extensions = MODEL_EXTENSIONS.get(model_type, [])
|
||||||
|
models_dir = os.path.join(app.MODELS_DIR, model_type)
|
||||||
if not os.path.exists(models_dir):
|
if not os.path.exists(models_dir):
|
||||||
os.makedirs(models_dir)
|
os.makedirs(models_dir)
|
||||||
|
|
||||||
@ -128,9 +154,9 @@ def getModels():
|
|||||||
models['options'][model_type].sort()
|
models['options'][model_type].sort()
|
||||||
|
|
||||||
# custom models
|
# custom models
|
||||||
listModels(models_dirname='stable-diffusion', model_type='stable-diffusion', model_extensions=STABLE_DIFFUSION_MODEL_EXTENSIONS)
|
listModels(model_type='stable-diffusion')
|
||||||
listModels(models_dirname='vae', model_type='vae', model_extensions=VAE_MODEL_EXTENSIONS)
|
listModels(model_type='vae')
|
||||||
listModels(models_dirname='hypernetwork', model_type='hypernetwork', model_extensions=HYPERNETWORK_MODEL_EXTENSIONS)
|
listModels(model_type='hypernetwork')
|
||||||
|
|
||||||
# legacy
|
# legacy
|
||||||
custom_weight_path = os.path.join(app.SD_DIR, 'custom-model.ckpt')
|
custom_weight_path = os.path.join(app.SD_DIR, 'custom-model.ckpt')
|
||||||
@ -138,3 +164,19 @@ def getModels():
|
|||||||
models['options']['stable-diffusion'].append('custom-model')
|
models['options']['stable-diffusion'].append('custom-model')
|
||||||
|
|
||||||
return models
|
return models
|
||||||
|
|
||||||
|
def is_sd_model_reload_necessary(thread_data, req: Request):
|
||||||
|
needs_model_reload = False
|
||||||
|
if 'stable-diffusion' not in thread_data.models or \
|
||||||
|
thread_data.model_paths['stable-diffusion'] != req.use_stable_diffusion_model or \
|
||||||
|
thread_data.model_paths['vae'] != req.use_vae_model:
|
||||||
|
|
||||||
|
needs_model_reload = True
|
||||||
|
|
||||||
|
if thread_data.device != 'cpu':
|
||||||
|
if (thread_data.precision == 'autocast' and req.use_full_precision) or \
|
||||||
|
(thread_data.precision == 'full' and not req.use_full_precision and not device_manager.needs_to_force_full_precision(thread_data)):
|
||||||
|
thread_data.precision = 'full' if req.use_full_precision else 'autocast'
|
||||||
|
needs_model_reload = True
|
||||||
|
|
||||||
|
return needs_model_reload
|
||||||
|
@ -1,16 +1,23 @@
|
|||||||
import threading
|
import threading
|
||||||
import queue
|
import queue
|
||||||
|
import time
|
||||||
|
import json
|
||||||
|
import os
|
||||||
|
import base64
|
||||||
|
import re
|
||||||
|
|
||||||
from sd_internal import device_manager, model_manager
|
from sd_internal import device_manager, model_manager
|
||||||
from sd_internal import Request, Response, Image as ResponseImage
|
from sd_internal import Request, Response, Image as ResponseImage, UserInitiatedStop
|
||||||
|
|
||||||
from modules import model_loader, image_generator, image_utils
|
from modules import model_loader, image_generator, image_utils, image_filters
|
||||||
|
|
||||||
thread_data = threading.local()
|
thread_data = threading.local()
|
||||||
'''
|
'''
|
||||||
runtime data (bound locally to this thread), for e.g. device, references to loaded models, optimization flags etc
|
runtime data (bound locally to this thread), for e.g. device, references to loaded models, optimization flags etc
|
||||||
'''
|
'''
|
||||||
|
|
||||||
|
filename_regex = re.compile('[^a-zA-Z0-9]')
|
||||||
|
|
||||||
def init(device):
|
def init(device):
|
||||||
'''
|
'''
|
||||||
Initializes the fields that will be bound to this runtime's thread_data, and sets the current torch device
|
Initializes the fields that will be bound to this runtime's thread_data, and sets the current torch device
|
||||||
@ -28,89 +35,167 @@ def init(device):
|
|||||||
|
|
||||||
device_manager.device_init(thread_data, device)
|
device_manager.device_init(thread_data, device)
|
||||||
|
|
||||||
load_default_models()
|
init_and_load_default_models()
|
||||||
|
|
||||||
def destroy():
|
def destroy():
|
||||||
model_loader.unload_sd_model(thread_data)
|
model_loader.unload_sd_model(thread_data)
|
||||||
model_loader.unload_gfpgan_model(thread_data)
|
model_loader.unload_gfpgan_model(thread_data)
|
||||||
model_loader.unload_realesrgan_model(thread_data)
|
model_loader.unload_realesrgan_model(thread_data)
|
||||||
|
|
||||||
def load_default_models():
|
def init_and_load_default_models():
|
||||||
thread_data.model_paths['stable-diffusion'] = model_manager.default_model_to_load
|
# init default model paths
|
||||||
thread_data.model_paths['vae'] = model_manager.default_vae_to_load
|
thread_data.model_paths['stable-diffusion'] = model_manager.resolve_sd_model_to_use()
|
||||||
|
thread_data.model_paths['vae'] = model_manager.resolve_vae_model_to_use()
|
||||||
|
thread_data.model_paths['hypernetwork'] = model_manager.resolve_hypernetwork_model_to_use()
|
||||||
|
thread_data.model_paths['gfpgan'] = model_manager.resolve_gfpgan_model_to_use()
|
||||||
|
thread_data.model_paths['realesrgan'] = model_manager.resolve_realesrgan_model_to_use()
|
||||||
|
|
||||||
|
# load mandatory models
|
||||||
model_loader.load_sd_model(thread_data)
|
model_loader.load_sd_model(thread_data)
|
||||||
|
|
||||||
def reload_models_if_necessary(req: Request=None):
|
def reload_models_if_necessary(req: Request):
|
||||||
needs_model_reload = False
|
if model_manager.is_sd_model_reload_necessary(thread_data, req):
|
||||||
if 'stable-diffusion' not in thread_data.models or thread_data.ckpt_file != req.use_stable_diffusion_model or thread_data.vae_file != req.use_vae_model:
|
thread_data.model_paths['stable-diffusion'] = req.use_stable_diffusion_model
|
||||||
thread_data.ckpt_file = req.use_stable_diffusion_model
|
thread_data.model_paths['vae'] = req.use_vae_model
|
||||||
thread_data.vae_file = req.use_vae_model
|
|
||||||
needs_model_reload = True
|
|
||||||
|
|
||||||
if thread_data.device != 'cpu':
|
model_loader.load_sd_model(thread_data)
|
||||||
if (thread_data.precision == 'autocast' and (req.use_full_precision or not thread_data.model_is_half)) or \
|
|
||||||
(thread_data.precision == 'full' and not req.use_full_precision and not thread_data.force_full_precision):
|
|
||||||
thread_data.precision = 'full' if req.use_full_precision else 'autocast'
|
|
||||||
needs_model_reload = True
|
|
||||||
|
|
||||||
return needs_model_reload
|
# if is_hypernetwork_reload_necessary(task.request):
|
||||||
|
# current_state = ServerStates.LoadingModel
|
||||||
|
# runtime.reload_hypernetwork()
|
||||||
|
|
||||||
if is_hypernetwork_reload_necessary(task.request):
|
def make_images(req: Request, data_queue: queue.Queue, task_temp_images: list, step_callback):
|
||||||
current_state = ServerStates.LoadingModel
|
images, user_stopped = generate_images(req, data_queue, task_temp_images, step_callback)
|
||||||
runtime.reload_hypernetwork()
|
images = apply_filters(req, images, user_stopped)
|
||||||
|
|
||||||
if is_model_reload_necessary(task.request):
|
save_images(req, images)
|
||||||
current_state = ServerStates.LoadingModel
|
|
||||||
runtime.reload_model()
|
|
||||||
|
|
||||||
def load_models():
|
return Response(req, images=construct_response(req, images))
|
||||||
if ckpt_file_path == None:
|
|
||||||
ckpt_file_path = default_model_to_load
|
def generate_images(req: Request, data_queue: queue.Queue, task_temp_images: list, step_callback):
|
||||||
if vae_file_path == None:
|
thread_data.temp_images.clear()
|
||||||
vae_file_path = default_vae_to_load
|
|
||||||
if hypernetwork_file_path == None:
|
image_generator.on_image_step = make_step_callback(req, data_queue, task_temp_images, step_callback)
|
||||||
hypernetwork_file_path = default_hypernetwork_to_load
|
|
||||||
if ckpt_file_path == current_model_path and vae_file_path == current_vae_path:
|
|
||||||
return
|
|
||||||
current_state = ServerStates.LoadingModel
|
|
||||||
try:
|
|
||||||
from sd_internal import runtime2
|
|
||||||
runtime.thread_data.hypernetwork_file = hypernetwork_file_path
|
|
||||||
runtime.thread_data.ckpt_file = ckpt_file_path
|
|
||||||
runtime.thread_data.vae_file = vae_file_path
|
|
||||||
runtime.load_model_ckpt()
|
|
||||||
runtime.load_hypernetwork()
|
|
||||||
current_model_path = ckpt_file_path
|
|
||||||
current_vae_path = vae_file_path
|
|
||||||
current_hypernetwork_path = hypernetwork_file_path
|
|
||||||
current_state_error = None
|
|
||||||
current_state = ServerStates.Online
|
|
||||||
except Exception as e:
|
|
||||||
current_model_path = None
|
|
||||||
current_vae_path = None
|
|
||||||
current_state_error = e
|
|
||||||
current_state = ServerStates.Unavailable
|
|
||||||
print(traceback.format_exc())
|
|
||||||
|
|
||||||
def make_image(req: Request, data_queue: queue.Queue, task_temp_images: list, step_callback):
|
|
||||||
try:
|
try:
|
||||||
images = image_generator.make_image(context=thread_data, args=get_mk_img_args(req))
|
images = image_generator.make_image(context=thread_data, args=get_mk_img_args(req))
|
||||||
|
user_stopped = False
|
||||||
except UserInitiatedStop:
|
except UserInitiatedStop:
|
||||||
pass
|
images = []
|
||||||
|
user_stopped = True
|
||||||
|
if not hasattr(thread_data, 'partial_x_samples') or thread_data.partial_x_samples is None:
|
||||||
|
return images
|
||||||
|
for i in range(req.num_outputs):
|
||||||
|
images[i] = image_utils.latent_to_img(thread_data, thread_data.partial_x_samples[i].unsqueeze(0))
|
||||||
|
|
||||||
|
del thread_data.partial_x_samples
|
||||||
|
finally:
|
||||||
|
model_loader.gc(thread_data)
|
||||||
|
|
||||||
|
images = [(image, req.seed + i, False) for i, image in enumerate(images)]
|
||||||
|
|
||||||
|
return images, user_stopped
|
||||||
|
|
||||||
|
def apply_filters(req: Request, images: list, user_stopped):
|
||||||
|
if user_stopped or (req.use_face_correction is None and req.use_upscale is None):
|
||||||
|
return images
|
||||||
|
|
||||||
|
filters = []
|
||||||
|
if req.use_face_correction.startswith('GFPGAN'): filters.append((image_filters.apply_gfpgan, model_manager.resolve_gfpgan_model_to_use(req.use_face_correction)))
|
||||||
|
if req.use_upscale.startswith('RealESRGAN'): filters.append((image_filters.apply_realesrgan, model_manager.resolve_realesrgan_model_to_use(req.use_upscale)))
|
||||||
|
|
||||||
|
filtered_images = []
|
||||||
|
for img, seed, _ in images:
|
||||||
|
for filter_fn, filter_model_path in filters:
|
||||||
|
img = filter_fn(thread_data, img, filter_model_path)
|
||||||
|
|
||||||
|
filtered_images.append((img, seed, True))
|
||||||
|
|
||||||
|
if not req.show_only_filtered_image:
|
||||||
|
filtered_images = images + filtered_images
|
||||||
|
|
||||||
|
return filtered_images
|
||||||
|
|
||||||
|
def save_images(req: Request, images: list):
|
||||||
|
if req.save_to_disk_path is None:
|
||||||
|
return
|
||||||
|
|
||||||
|
def get_image_id(i):
|
||||||
|
img_id = base64.b64encode(int(time.time()+i).to_bytes(8, 'big')).decode() # Generate unique ID based on time.
|
||||||
|
img_id = img_id.translate({43:None, 47:None, 61:None})[-8:] # Remove + / = and keep last 8 chars.
|
||||||
|
return img_id
|
||||||
|
|
||||||
|
def get_image_basepath(i):
|
||||||
|
session_out_path = os.path.join(req.save_to_disk_path, filename_regex.sub('_', req.session_id))
|
||||||
|
os.makedirs(session_out_path, exist_ok=True)
|
||||||
|
prompt_flattened = filename_regex.sub('_', req.prompt)[:50]
|
||||||
|
return os.path.join(session_out_path, f"{prompt_flattened}_{get_image_id(i)}")
|
||||||
|
|
||||||
|
for i, img_data in enumerate(images):
|
||||||
|
img, seed, filtered = img_data
|
||||||
|
img_path = get_image_basepath(i)
|
||||||
|
|
||||||
|
if not filtered or req.show_only_filtered_image:
|
||||||
|
img_metadata_path = img_path + '.txt'
|
||||||
|
metadata = req.json()
|
||||||
|
metadata['seed'] = seed
|
||||||
|
with open(img_metadata_path, 'w', encoding='utf-8') as f:
|
||||||
|
f.write(metadata)
|
||||||
|
|
||||||
|
img_path += '_filtered' if filtered else ''
|
||||||
|
img_path += '.' + req.output_format
|
||||||
|
img.save(img_path, quality=req.output_quality)
|
||||||
|
|
||||||
|
def construct_response(req: Request, images: list):
|
||||||
|
return [
|
||||||
|
ResponseImage(
|
||||||
|
data=image_utils.img_to_base64_str(img, req.output_format, req.output_quality),
|
||||||
|
seed=seed
|
||||||
|
) for img, seed, _ in images
|
||||||
|
]
|
||||||
|
|
||||||
def get_mk_img_args(req: Request):
|
def get_mk_img_args(req: Request):
|
||||||
args = req.json()
|
args = req.json()
|
||||||
|
|
||||||
if req.init_image is not None:
|
args['init_image'] = image_utils.base64_str_to_img(req.init_image) if req.init_image is not None else None
|
||||||
args['init_image'] = image_utils.base64_str_to_img(req.init_image)
|
args['mask'] = image_utils.base64_str_to_img(req.mask) if req.mask is not None else None
|
||||||
|
|
||||||
if req.mask is not None:
|
|
||||||
args['mask'] = image_utils.base64_str_to_img(req.mask)
|
|
||||||
|
|
||||||
return args
|
return args
|
||||||
|
|
||||||
def on_image_step(x_samples, i):
|
def make_step_callback(req: Request, data_queue: queue.Queue, task_temp_images: list, step_callback):
|
||||||
pass
|
n_steps = req.num_inference_steps if req.init_image is None else int(req.num_inference_steps * req.prompt_strength)
|
||||||
|
last_callback_time = -1
|
||||||
|
|
||||||
image_generator.on_image_step = on_image_step
|
def update_temp_img(req, x_samples, task_temp_images: list):
|
||||||
|
partial_images = []
|
||||||
|
for i in range(req.num_outputs):
|
||||||
|
img = image_utils.latent_to_img(thread_data, x_samples[i].unsqueeze(0))
|
||||||
|
buf = image_utils.img_to_buffer(img, output_format='JPEG')
|
||||||
|
|
||||||
|
del img
|
||||||
|
|
||||||
|
thread_data.temp_images[f'{req.request_id}/{i}'] = buf
|
||||||
|
task_temp_images[i] = buf
|
||||||
|
partial_images.append({'path': f'/image/tmp/{req.request_id}/{i}'})
|
||||||
|
return partial_images
|
||||||
|
|
||||||
|
def on_image_step(x_samples, i):
|
||||||
|
nonlocal last_callback_time
|
||||||
|
|
||||||
|
thread_data.partial_x_samples = x_samples
|
||||||
|
step_time = time.time() - last_callback_time if last_callback_time != -1 else -1
|
||||||
|
last_callback_time = time.time()
|
||||||
|
|
||||||
|
progress = {"step": i, "step_time": step_time, "total_steps": n_steps}
|
||||||
|
|
||||||
|
if req.stream_image_progress and i % 5 == 0:
|
||||||
|
progress['output'] = update_temp_img(req, x_samples, task_temp_images)
|
||||||
|
|
||||||
|
data_queue.put(json.dumps(progress))
|
||||||
|
|
||||||
|
step_callback()
|
||||||
|
|
||||||
|
if thread_data.stop_processing:
|
||||||
|
raise UserInitiatedStop("User requested that we stop processing")
|
||||||
|
|
||||||
|
return on_image_step
|
||||||
|
@ -324,7 +324,7 @@ def thread_render(device):
|
|||||||
runtime2.reload_models_if_necessary(task.request)
|
runtime2.reload_models_if_necessary(task.request)
|
||||||
|
|
||||||
current_state = ServerStates.Rendering
|
current_state = ServerStates.Rendering
|
||||||
task.response = runtime2.make_image(task.request, task.buffer_queue, task.temp_images, step_callback)
|
task.response = runtime2.make_images(task.request, task.buffer_queue, task.temp_images, step_callback)
|
||||||
# Before looping back to the generator, mark cache as still alive.
|
# Before looping back to the generator, mark cache as still alive.
|
||||||
task_cache.keep(id(task), TASK_TTL)
|
task_cache.keep(id(task), TASK_TTL)
|
||||||
session_cache.keep(task.request.session_id, TASK_TTL)
|
session_cache.keep(task.request.session_id, TASK_TTL)
|
||||||
|
@ -137,9 +137,9 @@ def ping(session_id:str=None):
|
|||||||
def render(req : task_manager.ImageRequest):
|
def render(req : task_manager.ImageRequest):
|
||||||
try:
|
try:
|
||||||
app.save_model_to_config(req.use_stable_diffusion_model, req.use_vae_model, req.use_hypernetwork_model)
|
app.save_model_to_config(req.use_stable_diffusion_model, req.use_vae_model, req.use_hypernetwork_model)
|
||||||
req.use_stable_diffusion_model = model_manager.resolve_ckpt_to_use(req.use_stable_diffusion_model)
|
req.use_stable_diffusion_model = model_manager.resolve_sd_model_to_use(req.use_stable_diffusion_model)
|
||||||
req.use_vae_model = model_manager.resolve_vae_to_use(req.use_vae_model)
|
req.use_vae_model = model_manager.resolve_vae_model_to_use(req.use_vae_model)
|
||||||
req.use_hypernetwork_model = model_manager.resolve_hypernetwork_to_use(req.use_hypernetwork_model)
|
req.use_hypernetwork_model = model_manager.resolve_hypernetwork_model_to_use(req.use_hypernetwork_model)
|
||||||
|
|
||||||
new_task = task_manager.render(req)
|
new_task = task_manager.render(req)
|
||||||
response = {
|
response = {
|
||||||
|
Loading…
Reference in New Issue
Block a user