mirror of
https://github.com/easydiffusion/easydiffusion.git
synced 2024-11-22 00:03:20 +01:00
Work-in-progress: refactored the end-to-end codebase. Missing: hypernetworks, turbo config, and SD 2. Not tested yet
This commit is contained in:
parent
bad89160cc
commit
f4a6910ab4
@ -199,12 +199,7 @@ call WHERE uvicorn > .tmp
|
||||
|
||||
|
||||
|
||||
if not exist "..\models\stable-diffusion" mkdir "..\models\stable-diffusion"
|
||||
if not exist "..\models\vae" mkdir "..\models\vae"
|
||||
if not exist "..\models\hypernetwork" mkdir "..\models\hypernetwork"
|
||||
echo. > "..\models\stable-diffusion\Put your custom ckpt files here.txt"
|
||||
echo. > "..\models\vae\Put your VAE files here.txt"
|
||||
echo. > "..\models\hypernetwork\Put your hypernetwork files here.txt"
|
||||
|
||||
@if exist "sd-v1-4.ckpt" (
|
||||
for %%I in ("sd-v1-4.ckpt") do if "%%~zI" EQU "4265380512" (
|
||||
|
@ -159,12 +159,7 @@ fi
|
||||
|
||||
|
||||
|
||||
mkdir -p "../models/stable-diffusion"
|
||||
mkdir -p "../models/vae"
|
||||
mkdir -p "../models/hypernetwork"
|
||||
echo "" > "../models/stable-diffusion/Put your custom ckpt files here.txt"
|
||||
echo "" > "../models/vae/Put your VAE files here.txt"
|
||||
echo "" > "../models/hypernetwork/Put your hypernetwork files here.txt"
|
||||
|
||||
if [ -f "sd-v1-4.ckpt" ]; then
|
||||
model_size=`find "sd-v1-4.ckpt" -printf "%s"`
|
||||
|
@ -409,7 +409,6 @@
|
||||
async function init() {
|
||||
await initSettings()
|
||||
await getModels()
|
||||
await getDiskPath()
|
||||
await getAppConfig()
|
||||
await loadUIPlugins()
|
||||
await loadModifiers()
|
||||
|
@ -327,20 +327,10 @@ autoPickGPUsField.addEventListener('click', function() {
|
||||
gpuSettingEntry.style.display = (this.checked ? 'none' : '')
|
||||
})
|
||||
|
||||
async function getDiskPath() {
|
||||
try {
|
||||
var diskPath = getSetting("diskPath")
|
||||
if (diskPath == '' || diskPath == undefined || diskPath == "undefined") {
|
||||
let res = await fetch('/get/output_dir')
|
||||
if (res.status === 200) {
|
||||
res = await res.json()
|
||||
res = res.output_dir
|
||||
|
||||
setSetting("diskPath", res)
|
||||
}
|
||||
}
|
||||
} catch (e) {
|
||||
console.log('error fetching output dir path', e)
|
||||
async function setDiskPath(defaultDiskPath) {
|
||||
var diskPath = getSetting("diskPath")
|
||||
if (diskPath == '' || diskPath == undefined || diskPath == "undefined") {
|
||||
setSetting("diskPath", defaultDiskPath)
|
||||
}
|
||||
}
|
||||
|
||||
@ -415,6 +405,7 @@ async function getSystemInfo() {
|
||||
|
||||
setDeviceInfo(devices)
|
||||
setHostInfo(res['hosts'])
|
||||
setDiskPath(res['default_output_dir'])
|
||||
} catch (e) {
|
||||
console.log('error fetching devices', e)
|
||||
}
|
||||
|
@ -105,6 +105,10 @@ class Response:
|
||||
request: Request
|
||||
images: list
|
||||
|
||||
def __init__(self, request: Request, images: list):
|
||||
self.request = request
|
||||
self.images = images
|
||||
|
||||
def json(self):
|
||||
res = {
|
||||
"status": 'succeeded',
|
||||
@ -116,3 +120,6 @@ class Response:
|
||||
res["output"].append(image.json())
|
||||
|
||||
return res
|
||||
|
||||
class UserInitiatedStop(Exception):
|
||||
pass
|
||||
|
@ -28,11 +28,6 @@ APP_CONFIG_DEFAULTS = {
|
||||
'open_browser_on_start': True,
|
||||
},
|
||||
}
|
||||
DEFAULT_MODELS = [
|
||||
# needed to support the legacy installations
|
||||
'custom-model', # Check if user has a custom model, use it first.
|
||||
'sd-v1-4', # Default fallback.
|
||||
]
|
||||
|
||||
def init():
|
||||
os.makedirs(USER_UI_PLUGINS_DIR, exist_ok=True)
|
||||
|
@ -101,10 +101,8 @@ def device_init(context, device):
|
||||
context.device = device
|
||||
|
||||
# Force full precision on 1660 and 1650 NVIDIA cards to avoid creating green images
|
||||
device_name = context.device_name.lower()
|
||||
force_full_precision = (('nvidia' in device_name or 'geforce' in device_name) and (' 1660' in device_name or ' 1650' in device_name)) or ('Quadro T2000' in device_name)
|
||||
if force_full_precision:
|
||||
print('forcing full precision on NVIDIA 16xx cards, to avoid green images. GPU detected: ', context.device_name)
|
||||
if needs_to_force_full_precision(context.device_name):
|
||||
print(f'forcing full precision on this GPU, to avoid green images. GPU detected: {context.device_name}')
|
||||
# Apply force_full_precision now before models are loaded.
|
||||
context.precision = 'full'
|
||||
|
||||
@ -113,6 +111,10 @@ def device_init(context, device):
|
||||
|
||||
return
|
||||
|
||||
def needs_to_force_full_precision(context):
|
||||
device_name = context.device_name.lower()
|
||||
return (('nvidia' in device_name or 'geforce' in device_name) and (' 1660' in device_name or ' 1650' in device_name)) or ('Quadro T2000' in device_name)
|
||||
|
||||
def validate_device_id(device, log_prefix=''):
|
||||
def is_valid():
|
||||
if not isinstance(device, str):
|
||||
|
@ -1,32 +1,39 @@
|
||||
import os
|
||||
|
||||
from sd_internal import app
|
||||
from sd_internal import app, device_manager
|
||||
from sd_internal import Request
|
||||
import picklescan.scanner
|
||||
import rich
|
||||
|
||||
STABLE_DIFFUSION_MODEL_EXTENSIONS = ['.ckpt', '.safetensors']
|
||||
VAE_MODEL_EXTENSIONS = ['.vae.pt', '.ckpt']
|
||||
HYPERNETWORK_MODEL_EXTENSIONS = ['.pt']
|
||||
|
||||
default_model_to_load = None
|
||||
default_vae_to_load = None
|
||||
default_hypernetwork_to_load = None
|
||||
KNOWN_MODEL_TYPES = ['stable-diffusion', 'vae', 'hypernetwork', 'gfpgan', 'realesrgan']
|
||||
MODEL_EXTENSIONS = {
|
||||
'stable-diffusion': ['.ckpt', '.safetensors'],
|
||||
'vae': ['.vae.pt', '.ckpt'],
|
||||
'hypernetwork': ['.pt'],
|
||||
'gfpgan': ['.pth'],
|
||||
'realesrgan': ['.pth'],
|
||||
}
|
||||
DEFAULT_MODELS = {
|
||||
'stable-diffusion': [ # needed to support the legacy installations
|
||||
'custom-model', # only one custom model file was supported initially, creatively named 'custom-model'
|
||||
'sd-v1-4', # Default fallback.
|
||||
],
|
||||
'gfpgan': ['GFPGANv1.3'],
|
||||
'realesrgan': ['RealESRGAN_x4plus'],
|
||||
}
|
||||
|
||||
known_models = {}
|
||||
|
||||
def init():
|
||||
global default_model_to_load, default_vae_to_load, default_hypernetwork_to_load
|
||||
|
||||
default_model_to_load = resolve_ckpt_to_use()
|
||||
default_vae_to_load = resolve_vae_to_use()
|
||||
default_hypernetwork_to_load = resolve_hypernetwork_to_use()
|
||||
|
||||
make_model_folders()
|
||||
getModels() # run this once, to cache the picklescan results
|
||||
|
||||
def resolve_model_to_use(model_name:str, model_type:str, model_dir:str, model_extensions:list, default_models=[]):
|
||||
def resolve_model_to_use(model_name:str, model_type:str):
|
||||
model_extensions = MODEL_EXTENSIONS.get(model_type, [])
|
||||
default_models = DEFAULT_MODELS.get(model_type, [])
|
||||
config = app.getConfig()
|
||||
|
||||
model_dirs = [os.path.join(app.MODELS_DIR, model_dir), app.SD_DIR]
|
||||
model_dirs = [os.path.join(app.MODELS_DIR, model_type), app.SD_DIR]
|
||||
if not model_name: # When None try user configured model.
|
||||
# config = getConfig()
|
||||
if 'model' in config and model_type in config['model']:
|
||||
@ -39,7 +46,7 @@ def resolve_model_to_use(model_name:str, model_type:str, model_dir:str, model_ex
|
||||
model_name = 'sd-v1-4'
|
||||
|
||||
# Check models directory
|
||||
models_dir_path = os.path.join(app.MODELS_DIR, model_dir, model_name)
|
||||
models_dir_path = os.path.join(app.MODELS_DIR, model_type, model_name)
|
||||
for model_extension in model_extensions:
|
||||
if os.path.exists(models_dir_path + model_extension):
|
||||
return models_dir_path + model_extension
|
||||
@ -66,14 +73,32 @@ def resolve_model_to_use(model_name:str, model_type:str, model_dir:str, model_ex
|
||||
print(f'No valid models found for model_name: {model_name}')
|
||||
return None
|
||||
|
||||
def resolve_ckpt_to_use(model_name:str=None):
|
||||
return resolve_model_to_use(model_name, model_type='stable-diffusion', model_dir='stable-diffusion', model_extensions=STABLE_DIFFUSION_MODEL_EXTENSIONS, default_models=app.APP_CONFIG_DEFAULT_MODELS)
|
||||
def resolve_sd_model_to_use(model_name:str=None):
|
||||
return resolve_model_to_use(model_name, model_type='stable-diffusion')
|
||||
|
||||
def resolve_vae_to_use(model_name:str=None):
|
||||
return resolve_model_to_use(model_name, model_type='vae', model_dir='vae', model_extensions=VAE_MODEL_EXTENSIONS, default_models=[])
|
||||
def resolve_vae_model_to_use(model_name:str=None):
|
||||
return resolve_model_to_use(model_name, model_type='vae')
|
||||
|
||||
def resolve_hypernetwork_to_use(model_name:str=None):
|
||||
return resolve_model_to_use(model_name, model_type='hypernetwork', model_dir='hypernetwork', model_extensions=HYPERNETWORK_MODEL_EXTENSIONS, default_models=[])
|
||||
def resolve_hypernetwork_model_to_use(model_name:str=None):
|
||||
return resolve_model_to_use(model_name, model_type='hypernetwork')
|
||||
|
||||
def resolve_gfpgan_model_to_use(model_name:str=None):
|
||||
return resolve_model_to_use(model_name, model_type='gfpgan')
|
||||
|
||||
def resolve_realesrgan_model_to_use(model_name:str=None):
|
||||
return resolve_model_to_use(model_name, model_type='realesrgan')
|
||||
|
||||
def make_model_folders():
|
||||
for model_type in KNOWN_MODEL_TYPES:
|
||||
model_dir_path = os.path.join(app.MODELS_DIR, model_type)
|
||||
|
||||
os.makedirs(model_dir_path, exist_ok=True)
|
||||
|
||||
help_file_name = f'Place your {model_type} model files here.txt'
|
||||
help_file_contents = f'Supported extensions: {" or ".join(MODEL_EXTENSIONS.get(model_type))}'
|
||||
|
||||
with open(os.path.join(model_dir_path, help_file_name)) as f:
|
||||
f.write(help_file_contents)
|
||||
|
||||
def is_malicious_model(file_path):
|
||||
try:
|
||||
@ -102,8 +127,9 @@ def getModels():
|
||||
},
|
||||
}
|
||||
|
||||
def listModels(models_dirname, model_type, model_extensions):
|
||||
models_dir = os.path.join(app.MODELS_DIR, models_dirname)
|
||||
def listModels(model_type):
|
||||
model_extensions = MODEL_EXTENSIONS.get(model_type, [])
|
||||
models_dir = os.path.join(app.MODELS_DIR, model_type)
|
||||
if not os.path.exists(models_dir):
|
||||
os.makedirs(models_dir)
|
||||
|
||||
@ -128,9 +154,9 @@ def getModels():
|
||||
models['options'][model_type].sort()
|
||||
|
||||
# custom models
|
||||
listModels(models_dirname='stable-diffusion', model_type='stable-diffusion', model_extensions=STABLE_DIFFUSION_MODEL_EXTENSIONS)
|
||||
listModels(models_dirname='vae', model_type='vae', model_extensions=VAE_MODEL_EXTENSIONS)
|
||||
listModels(models_dirname='hypernetwork', model_type='hypernetwork', model_extensions=HYPERNETWORK_MODEL_EXTENSIONS)
|
||||
listModels(model_type='stable-diffusion')
|
||||
listModels(model_type='vae')
|
||||
listModels(model_type='hypernetwork')
|
||||
|
||||
# legacy
|
||||
custom_weight_path = os.path.join(app.SD_DIR, 'custom-model.ckpt')
|
||||
@ -138,3 +164,19 @@ def getModels():
|
||||
models['options']['stable-diffusion'].append('custom-model')
|
||||
|
||||
return models
|
||||
|
||||
def is_sd_model_reload_necessary(thread_data, req: Request):
|
||||
needs_model_reload = False
|
||||
if 'stable-diffusion' not in thread_data.models or \
|
||||
thread_data.model_paths['stable-diffusion'] != req.use_stable_diffusion_model or \
|
||||
thread_data.model_paths['vae'] != req.use_vae_model:
|
||||
|
||||
needs_model_reload = True
|
||||
|
||||
if thread_data.device != 'cpu':
|
||||
if (thread_data.precision == 'autocast' and req.use_full_precision) or \
|
||||
(thread_data.precision == 'full' and not req.use_full_precision and not device_manager.needs_to_force_full_precision(thread_data)):
|
||||
thread_data.precision = 'full' if req.use_full_precision else 'autocast'
|
||||
needs_model_reload = True
|
||||
|
||||
return needs_model_reload
|
||||
|
@ -1,16 +1,23 @@
|
||||
import threading
|
||||
import queue
|
||||
import time
|
||||
import json
|
||||
import os
|
||||
import base64
|
||||
import re
|
||||
|
||||
from sd_internal import device_manager, model_manager
|
||||
from sd_internal import Request, Response, Image as ResponseImage
|
||||
from sd_internal import Request, Response, Image as ResponseImage, UserInitiatedStop
|
||||
|
||||
from modules import model_loader, image_generator, image_utils
|
||||
from modules import model_loader, image_generator, image_utils, image_filters
|
||||
|
||||
thread_data = threading.local()
|
||||
'''
|
||||
runtime data (bound locally to this thread), for e.g. device, references to loaded models, optimization flags etc
|
||||
'''
|
||||
|
||||
filename_regex = re.compile('[^a-zA-Z0-9]')
|
||||
|
||||
def init(device):
|
||||
'''
|
||||
Initializes the fields that will be bound to this runtime's thread_data, and sets the current torch device
|
||||
@ -28,89 +35,167 @@ def init(device):
|
||||
|
||||
device_manager.device_init(thread_data, device)
|
||||
|
||||
load_default_models()
|
||||
init_and_load_default_models()
|
||||
|
||||
def destroy():
|
||||
model_loader.unload_sd_model(thread_data)
|
||||
model_loader.unload_gfpgan_model(thread_data)
|
||||
model_loader.unload_realesrgan_model(thread_data)
|
||||
|
||||
def load_default_models():
|
||||
thread_data.model_paths['stable-diffusion'] = model_manager.default_model_to_load
|
||||
thread_data.model_paths['vae'] = model_manager.default_vae_to_load
|
||||
def init_and_load_default_models():
|
||||
# init default model paths
|
||||
thread_data.model_paths['stable-diffusion'] = model_manager.resolve_sd_model_to_use()
|
||||
thread_data.model_paths['vae'] = model_manager.resolve_vae_model_to_use()
|
||||
thread_data.model_paths['hypernetwork'] = model_manager.resolve_hypernetwork_model_to_use()
|
||||
thread_data.model_paths['gfpgan'] = model_manager.resolve_gfpgan_model_to_use()
|
||||
thread_data.model_paths['realesrgan'] = model_manager.resolve_realesrgan_model_to_use()
|
||||
|
||||
# load mandatory models
|
||||
model_loader.load_sd_model(thread_data)
|
||||
|
||||
def reload_models_if_necessary(req: Request=None):
|
||||
needs_model_reload = False
|
||||
if 'stable-diffusion' not in thread_data.models or thread_data.ckpt_file != req.use_stable_diffusion_model or thread_data.vae_file != req.use_vae_model:
|
||||
thread_data.ckpt_file = req.use_stable_diffusion_model
|
||||
thread_data.vae_file = req.use_vae_model
|
||||
needs_model_reload = True
|
||||
def reload_models_if_necessary(req: Request):
|
||||
if model_manager.is_sd_model_reload_necessary(thread_data, req):
|
||||
thread_data.model_paths['stable-diffusion'] = req.use_stable_diffusion_model
|
||||
thread_data.model_paths['vae'] = req.use_vae_model
|
||||
|
||||
if thread_data.device != 'cpu':
|
||||
if (thread_data.precision == 'autocast' and (req.use_full_precision or not thread_data.model_is_half)) or \
|
||||
(thread_data.precision == 'full' and not req.use_full_precision and not thread_data.force_full_precision):
|
||||
thread_data.precision = 'full' if req.use_full_precision else 'autocast'
|
||||
needs_model_reload = True
|
||||
model_loader.load_sd_model(thread_data)
|
||||
|
||||
return needs_model_reload
|
||||
# if is_hypernetwork_reload_necessary(task.request):
|
||||
# current_state = ServerStates.LoadingModel
|
||||
# runtime.reload_hypernetwork()
|
||||
|
||||
if is_hypernetwork_reload_necessary(task.request):
|
||||
current_state = ServerStates.LoadingModel
|
||||
runtime.reload_hypernetwork()
|
||||
def make_images(req: Request, data_queue: queue.Queue, task_temp_images: list, step_callback):
|
||||
images, user_stopped = generate_images(req, data_queue, task_temp_images, step_callback)
|
||||
images = apply_filters(req, images, user_stopped)
|
||||
|
||||
if is_model_reload_necessary(task.request):
|
||||
current_state = ServerStates.LoadingModel
|
||||
runtime.reload_model()
|
||||
save_images(req, images)
|
||||
|
||||
def load_models():
|
||||
if ckpt_file_path == None:
|
||||
ckpt_file_path = default_model_to_load
|
||||
if vae_file_path == None:
|
||||
vae_file_path = default_vae_to_load
|
||||
if hypernetwork_file_path == None:
|
||||
hypernetwork_file_path = default_hypernetwork_to_load
|
||||
if ckpt_file_path == current_model_path and vae_file_path == current_vae_path:
|
||||
return
|
||||
current_state = ServerStates.LoadingModel
|
||||
try:
|
||||
from sd_internal import runtime2
|
||||
runtime.thread_data.hypernetwork_file = hypernetwork_file_path
|
||||
runtime.thread_data.ckpt_file = ckpt_file_path
|
||||
runtime.thread_data.vae_file = vae_file_path
|
||||
runtime.load_model_ckpt()
|
||||
runtime.load_hypernetwork()
|
||||
current_model_path = ckpt_file_path
|
||||
current_vae_path = vae_file_path
|
||||
current_hypernetwork_path = hypernetwork_file_path
|
||||
current_state_error = None
|
||||
current_state = ServerStates.Online
|
||||
except Exception as e:
|
||||
current_model_path = None
|
||||
current_vae_path = None
|
||||
current_state_error = e
|
||||
current_state = ServerStates.Unavailable
|
||||
print(traceback.format_exc())
|
||||
return Response(req, images=construct_response(req, images))
|
||||
|
||||
def generate_images(req: Request, data_queue: queue.Queue, task_temp_images: list, step_callback):
|
||||
thread_data.temp_images.clear()
|
||||
|
||||
image_generator.on_image_step = make_step_callback(req, data_queue, task_temp_images, step_callback)
|
||||
|
||||
def make_image(req: Request, data_queue: queue.Queue, task_temp_images: list, step_callback):
|
||||
try:
|
||||
images = image_generator.make_image(context=thread_data, args=get_mk_img_args(req))
|
||||
user_stopped = False
|
||||
except UserInitiatedStop:
|
||||
pass
|
||||
images = []
|
||||
user_stopped = True
|
||||
if not hasattr(thread_data, 'partial_x_samples') or thread_data.partial_x_samples is None:
|
||||
return images
|
||||
for i in range(req.num_outputs):
|
||||
images[i] = image_utils.latent_to_img(thread_data, thread_data.partial_x_samples[i].unsqueeze(0))
|
||||
|
||||
del thread_data.partial_x_samples
|
||||
finally:
|
||||
model_loader.gc(thread_data)
|
||||
|
||||
images = [(image, req.seed + i, False) for i, image in enumerate(images)]
|
||||
|
||||
return images, user_stopped
|
||||
|
||||
def apply_filters(req: Request, images: list, user_stopped):
|
||||
if user_stopped or (req.use_face_correction is None and req.use_upscale is None):
|
||||
return images
|
||||
|
||||
filters = []
|
||||
if req.use_face_correction.startswith('GFPGAN'): filters.append((image_filters.apply_gfpgan, model_manager.resolve_gfpgan_model_to_use(req.use_face_correction)))
|
||||
if req.use_upscale.startswith('RealESRGAN'): filters.append((image_filters.apply_realesrgan, model_manager.resolve_realesrgan_model_to_use(req.use_upscale)))
|
||||
|
||||
filtered_images = []
|
||||
for img, seed, _ in images:
|
||||
for filter_fn, filter_model_path in filters:
|
||||
img = filter_fn(thread_data, img, filter_model_path)
|
||||
|
||||
filtered_images.append((img, seed, True))
|
||||
|
||||
if not req.show_only_filtered_image:
|
||||
filtered_images = images + filtered_images
|
||||
|
||||
return filtered_images
|
||||
|
||||
def save_images(req: Request, images: list):
|
||||
if req.save_to_disk_path is None:
|
||||
return
|
||||
|
||||
def get_image_id(i):
|
||||
img_id = base64.b64encode(int(time.time()+i).to_bytes(8, 'big')).decode() # Generate unique ID based on time.
|
||||
img_id = img_id.translate({43:None, 47:None, 61:None})[-8:] # Remove + / = and keep last 8 chars.
|
||||
return img_id
|
||||
|
||||
def get_image_basepath(i):
|
||||
session_out_path = os.path.join(req.save_to_disk_path, filename_regex.sub('_', req.session_id))
|
||||
os.makedirs(session_out_path, exist_ok=True)
|
||||
prompt_flattened = filename_regex.sub('_', req.prompt)[:50]
|
||||
return os.path.join(session_out_path, f"{prompt_flattened}_{get_image_id(i)}")
|
||||
|
||||
for i, img_data in enumerate(images):
|
||||
img, seed, filtered = img_data
|
||||
img_path = get_image_basepath(i)
|
||||
|
||||
if not filtered or req.show_only_filtered_image:
|
||||
img_metadata_path = img_path + '.txt'
|
||||
metadata = req.json()
|
||||
metadata['seed'] = seed
|
||||
with open(img_metadata_path, 'w', encoding='utf-8') as f:
|
||||
f.write(metadata)
|
||||
|
||||
img_path += '_filtered' if filtered else ''
|
||||
img_path += '.' + req.output_format
|
||||
img.save(img_path, quality=req.output_quality)
|
||||
|
||||
def construct_response(req: Request, images: list):
|
||||
return [
|
||||
ResponseImage(
|
||||
data=image_utils.img_to_base64_str(img, req.output_format, req.output_quality),
|
||||
seed=seed
|
||||
) for img, seed, _ in images
|
||||
]
|
||||
|
||||
def get_mk_img_args(req: Request):
|
||||
args = req.json()
|
||||
|
||||
if req.init_image is not None:
|
||||
args['init_image'] = image_utils.base64_str_to_img(req.init_image)
|
||||
|
||||
if req.mask is not None:
|
||||
args['mask'] = image_utils.base64_str_to_img(req.mask)
|
||||
args['init_image'] = image_utils.base64_str_to_img(req.init_image) if req.init_image is not None else None
|
||||
args['mask'] = image_utils.base64_str_to_img(req.mask) if req.mask is not None else None
|
||||
|
||||
return args
|
||||
|
||||
def on_image_step(x_samples, i):
|
||||
pass
|
||||
def make_step_callback(req: Request, data_queue: queue.Queue, task_temp_images: list, step_callback):
|
||||
n_steps = req.num_inference_steps if req.init_image is None else int(req.num_inference_steps * req.prompt_strength)
|
||||
last_callback_time = -1
|
||||
|
||||
image_generator.on_image_step = on_image_step
|
||||
def update_temp_img(req, x_samples, task_temp_images: list):
|
||||
partial_images = []
|
||||
for i in range(req.num_outputs):
|
||||
img = image_utils.latent_to_img(thread_data, x_samples[i].unsqueeze(0))
|
||||
buf = image_utils.img_to_buffer(img, output_format='JPEG')
|
||||
|
||||
del img
|
||||
|
||||
thread_data.temp_images[f'{req.request_id}/{i}'] = buf
|
||||
task_temp_images[i] = buf
|
||||
partial_images.append({'path': f'/image/tmp/{req.request_id}/{i}'})
|
||||
return partial_images
|
||||
|
||||
def on_image_step(x_samples, i):
|
||||
nonlocal last_callback_time
|
||||
|
||||
thread_data.partial_x_samples = x_samples
|
||||
step_time = time.time() - last_callback_time if last_callback_time != -1 else -1
|
||||
last_callback_time = time.time()
|
||||
|
||||
progress = {"step": i, "step_time": step_time, "total_steps": n_steps}
|
||||
|
||||
if req.stream_image_progress and i % 5 == 0:
|
||||
progress['output'] = update_temp_img(req, x_samples, task_temp_images)
|
||||
|
||||
data_queue.put(json.dumps(progress))
|
||||
|
||||
step_callback()
|
||||
|
||||
if thread_data.stop_processing:
|
||||
raise UserInitiatedStop("User requested that we stop processing")
|
||||
|
||||
return on_image_step
|
||||
|
@ -324,7 +324,7 @@ def thread_render(device):
|
||||
runtime2.reload_models_if_necessary(task.request)
|
||||
|
||||
current_state = ServerStates.Rendering
|
||||
task.response = runtime2.make_image(task.request, task.buffer_queue, task.temp_images, step_callback)
|
||||
task.response = runtime2.make_images(task.request, task.buffer_queue, task.temp_images, step_callback)
|
||||
# Before looping back to the generator, mark cache as still alive.
|
||||
task_cache.keep(id(task), TASK_TTL)
|
||||
session_cache.keep(task.request.session_id, TASK_TTL)
|
||||
|
@ -137,9 +137,9 @@ def ping(session_id:str=None):
|
||||
def render(req : task_manager.ImageRequest):
|
||||
try:
|
||||
app.save_model_to_config(req.use_stable_diffusion_model, req.use_vae_model, req.use_hypernetwork_model)
|
||||
req.use_stable_diffusion_model = model_manager.resolve_ckpt_to_use(req.use_stable_diffusion_model)
|
||||
req.use_vae_model = model_manager.resolve_vae_to_use(req.use_vae_model)
|
||||
req.use_hypernetwork_model = model_manager.resolve_hypernetwork_to_use(req.use_hypernetwork_model)
|
||||
req.use_stable_diffusion_model = model_manager.resolve_sd_model_to_use(req.use_stable_diffusion_model)
|
||||
req.use_vae_model = model_manager.resolve_vae_model_to_use(req.use_vae_model)
|
||||
req.use_hypernetwork_model = model_manager.resolve_hypernetwork_model_to_use(req.use_hypernetwork_model)
|
||||
|
||||
new_task = task_manager.render(req)
|
||||
response = {
|
||||
|
Loading…
Reference in New Issue
Block a user