Get rid of the ugly copying around (and maintaining) of multiple request-related fields. Split the request into two objects: task-related fields (TaskData) and render-related fields (GenerateImageRequest). Also remove the ability for a request to demand full precision; full precision can now be forced with the FORCE_FULL_PRECISION environment variable

This commit is contained in:
cmdr2 2022-12-11 18:16:29 +05:30
parent d03eed3859
commit 6ce6dc3ff6
11 changed files with 115 additions and 305 deletions
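
The split works because both new objects are pydantic models parsed from the same request body, and pydantic (v1) ignores extra keys by default, so each model silently keeps only its own fields. A minimal sketch of the idea, using illustrative field subsets rather than the full models:

    from pydantic import BaseModel

    class GenerateImageRequest(BaseModel):  # render-related fields (subset)
        prompt: str = ""
        seed: int = 42
        num_inference_steps: int = 50

    class TaskData(BaseModel):  # task-related fields (subset)
        session_id: str = "session"
        save_to_disk_path: str = None
        use_stable_diffusion_model: str = "sd-v1-4"

    # one incoming JSON body, parsed twice; each model keeps only its own fields
    payload = {"prompt": "an astronaut", "seed": 7, "session_id": "abc"}
    render_req = GenerateImageRequest.parse_obj(payload)
    task_data = TaskData.parse_obj(payload)

This mirrors what the new `/render` handler does below, and is what makes the old field-by-field copying in `task_manager.render()` unnecessary.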

View File

@@ -37,7 +37,6 @@ const SETTINGS_IDS_LIST = [
     "diskPath",
     "sound_toggle",
     "turbo",
-    "use_full_precision",
     "confirm_dangerous_actions",
     "auto_save_settings",
     "apply_color_correction"
@@ -278,7 +277,6 @@ function tryLoadOldSettings() {
         "soundEnabled": "sound_toggle",
         "saveToDisk": "save_to_disk",
         "useCPU": "use_cpu",
-        "useFullPrecision": "use_full_precision",
         "useTurboMode": "turbo",
         "diskPath": "diskPath",
         "useFaceCorrection": "use_face_correction",

View File

@@ -217,13 +217,6 @@ const TASK_MAPPING = {
         readUI: () => turboField.checked,
         parse: (val) => Boolean(val)
     },
-    use_full_precision: { name: 'Use Full Precision',
-        setUI: (use_full_precision) => {
-            useFullPrecisionField.checked = use_full_precision
-        },
-        readUI: () => useFullPrecisionField.checked,
-        parse: (val) => Boolean(val)
-    },
     stream_image_progress: { name: 'Stream Image Progress',
         setUI: (stream_image_progress) => {
@@ -453,7 +446,6 @@ document.addEventListener("dragover", dragOverHandler)
 const TASK_REQ_NO_EXPORT = [
     "use_cpu",
     "turbo",
-    "use_full_precision",
     "save_to_disk_path"
 ]
 const resetSettings = document.getElementById('reset-image-settings')

View File

@@ -728,7 +728,6 @@
     "stream_image_progress": 'boolean',
     "show_only_filtered_image": 'boolean',
     "turbo": 'boolean',
-    "use_full_precision": 'boolean',
     "output_format": 'string',
     "output_quality": 'number',
 }
@@ -744,7 +743,6 @@
     "stream_image_progress": true,
     "show_only_filtered_image": true,
     "turbo": false,
-    "use_full_precision": false,
     "output_format": "png",
     "output_quality": 75,
 }

View File

@@ -850,7 +850,6 @@ function getCurrentUserRequest() {
         // allow_nsfw: allowNSFWField.checked,
         turbo: turboField.checked,
         //render_device: undefined, // Set device affinity. Prefer this device, but wont activate.
-        use_full_precision: useFullPrecisionField.checked,
         use_stable_diffusion_model: stableDiffusionModelField.value,
         use_vae_model: vaeModelField.value,
         stream_progress_updates: true,

View File

@@ -105,14 +105,6 @@ var PARAMETERS = [
         note: "to process in parallel",
         default: false,
     },
-    {
-        id: "use_full_precision",
-        type: ParameterType.checkbox,
-        label: "Use Full Precision",
-        note: "for GPU-only. warning: this will consume more VRAM",
-        icon: "fa-crosshairs",
-        default: false,
-    },
     {
         id: "auto_save_settings",
         type: ParameterType.checkbox,
@@ -214,7 +206,6 @@ let turboField = document.querySelector('#turbo')
 let useCPUField = document.querySelector('#use_cpu')
 let autoPickGPUsField = document.querySelector('#auto_pick_gpus')
 let useGPUsField = document.querySelector('#use_gpus')
-let useFullPrecisionField = document.querySelector('#use_full_precision')
 let saveToDiskField = document.querySelector('#save_to_disk')
 let diskPathField = document.querySelector('#diskPath')
 let listenToNetworkField = document.querySelector("#listen_to_network")

View File

@@ -1,93 +1,22 @@
-import json
+from pydantic import BaseModel
 
-class Request:
+from modules.types import GenerateImageRequest
+
+class TaskData(BaseModel):
     request_id: str = None
     session_id: str = "session"
-    prompt: str = ""
-    negative_prompt: str = ""
-    init_image: str = None # base64
-    mask: str = None # base64
-    apply_color_correction = False
-    num_outputs: int = 1
-    num_inference_steps: int = 50
-    guidance_scale: float = 7.5
-    width: int = 512
-    height: int = 512
-    seed: int = 42
-    prompt_strength: float = 0.8
-    sampler: str = None # "ddim", "plms", "heun", "euler", "euler_a", "dpm2", "dpm2_a", "lms"
-    # allow_nsfw: bool = False
-    precision: str = "autocast" # or "full"
     save_to_disk_path: str = None
     turbo: bool = True
-    use_full_precision: bool = False
     use_face_correction: str = None # or "GFPGANv1.3"
     use_upscale: str = None # or "RealESRGAN_x4plus" or "RealESRGAN_x4plus_anime_6B"
     use_stable_diffusion_model: str = "sd-v1-4"
     use_vae_model: str = None
     use_hypernetwork_model: str = None
-    hypernetwork_strength: float = 1
     show_only_filtered_image: bool = False
     output_format: str = "jpeg" # or "png"
     output_quality: int = 75
-    stream_progress_updates: bool = False
     stream_image_progress: bool = False
-
-    def json(self):
-        return {
-            "request_id": self.request_id,
-            "session_id": self.session_id,
-            "prompt": self.prompt,
-            "negative_prompt": self.negative_prompt,
-            "num_outputs": self.num_outputs,
-            "num_inference_steps": self.num_inference_steps,
-            "guidance_scale": self.guidance_scale,
-            "width": self.width,
-            "height": self.height,
-            "seed": self.seed,
-            "prompt_strength": self.prompt_strength,
-            "sampler": self.sampler,
-            "apply_color_correction": self.apply_color_correction,
-            "use_face_correction": self.use_face_correction,
-            "use_upscale": self.use_upscale,
-            "use_stable_diffusion_model": self.use_stable_diffusion_model,
-            "use_vae_model": self.use_vae_model,
-            "use_hypernetwork_model": self.use_hypernetwork_model,
-            "hypernetwork_strength": self.hypernetwork_strength,
-            "output_format": self.output_format,
-            "output_quality": self.output_quality,
-        }
-
-    def __str__(self):
-        return f'''
-    session_id: {self.session_id}
-    prompt: {self.prompt}
-    negative_prompt: {self.negative_prompt}
-    seed: {self.seed}
-    num_inference_steps: {self.num_inference_steps}
-    sampler: {self.sampler}
-    guidance_scale: {self.guidance_scale}
-    w: {self.width}
-    h: {self.height}
-    precision: {self.precision}
-    save_to_disk_path: {self.save_to_disk_path}
-    turbo: {self.turbo}
-    use_full_precision: {self.use_full_precision}
-    apply_color_correction: {self.apply_color_correction}
-    use_face_correction: {self.use_face_correction}
-    use_upscale: {self.use_upscale}
-    use_stable_diffusion_model: {self.use_stable_diffusion_model}
-    use_vae_model: {self.use_vae_model}
-    use_hypernetwork_model: {self.use_hypernetwork_model}
-    hypernetwork_strength: {self.hypernetwork_strength}
-    show_only_filtered_image: {self.show_only_filtered_image}
-    output_format: {self.output_format}
-    output_quality: {self.output_quality}
-    stream_progress_updates: {self.stream_progress_updates}
-    stream_image_progress: {self.stream_image_progress}'''
 
 class Image:
     data: str # base64
     seed: int
@@ -106,17 +35,20 @@ class Image:
         }
 
 class Response:
-    request: Request
+    render_request: GenerateImageRequest
+    task_data: TaskData
     images: list
 
-    def __init__(self, request: Request, images: list):
-        self.request = request
+    def __init__(self, render_request: GenerateImageRequest, task_data: TaskData, images: list):
+        self.render_request = render_request
+        self.task_data = task_data
         self.images = images
 
     def json(self):
        res = {
            "status": 'succeeded',
-            "request": self.request.json(),
+            "render_request": self.render_request.dict(),
+            "task_data": self.task_data.dict(),
            "output": [],
        }
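
A side effect of this split, visible in the hunk above: the streamed response JSON now nests two dicts where the old single `request` blob used to be. A sketch of the new payload shape, with illustrative (truncated) values:

    # illustrative payload produced by Response.json() after this change
    res = {
        "status": "succeeded",
        "render_request": {"prompt": "an astronaut", "seed": 7},      # GenerateImageRequest.dict(), truncated
        "task_data": {"session_id": "abc", "output_format": "jpeg"},  # TaskData.dict(), truncated
        "output": [{"data": "<base64>", "seed": 7}],                  # one entry per generated image
    }

Clients that previously read the `request` key will need to look at `render_request` and `task_data` instead.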

View File

@@ -6,6 +6,13 @@ import logging
 
 log = logging.getLogger()
 
+'''
+Set `FORCE_FULL_PRECISION` in the environment variables, or in `config.bat`/`config.sh` to set full precision (i.e. float32).
+
+Otherwise the models will load at half-precision (i.e. float16).
+
+Half-precision is fine most of the time. Full precision is only needed for working around GPU bugs (like NVIDIA 16xx GPUs).
+'''
 
 COMPARABLE_GPU_PERCENTILE = 0.65 # if a GPU's free_mem is within this % of the GPU with the most free_mem, it will be picked
 
 mem_free_threshold = 0
@@ -96,7 +103,7 @@ def device_init(context, device):
     if device == 'cpu':
         context.device = 'cpu'
         context.device_name = get_processor_name()
-        context.precision = 'full'
+        context.half_precision = False
         log.debug(f'Render device CPU available as {context.device_name}')
         return
@@ -107,7 +114,7 @@ def device_init(context, device):
     if needs_to_force_full_precision(context):
         log.warn(f'forcing full precision on this GPU, to avoid green images. GPU detected: {context.device_name}')
         # Apply force_full_precision now before models are loaded.
-        context.precision = 'full'
+        context.half_precision = False
 
     log.info(f'Setting {device} as active')
     torch.cuda.device(device)
@@ -115,6 +122,9 @@ def device_init(context, device):
     return
 
 def needs_to_force_full_precision(context):
+    if 'FORCE_FULL_PRECISION' in os.environ:
+        return True
+
     device_name = context.device_name.lower()
     return (('nvidia' in device_name or 'geforce' in device_name) and (' 1660' in device_name or ' 1650' in device_name)) or ('Quadro T2000' in device_name)
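
Since the new check is a plain membership test on `os.environ`, any value (even an empty string) enables it, and it takes precedence over the GPU-model heuristic. A small self-contained illustration of the gate, with a minimal `Context` stand-in and the GPU heuristic trimmed for brevity:

    import os

    class Context:  # minimal stand-in for modules.types.Context
        device_name = 'NVIDIA GeForce RTX 3060'

    def needs_to_force_full_precision(context):
        if 'FORCE_FULL_PRECISION' in os.environ:  # presence alone wins
            return True
        name = context.device_name.lower()
        # heuristic for GPUs with broken half-precision (green/black images)
        return ('nvidia' in name or 'geforce' in name) and (' 1660' in name or ' 1650' in name)

    print(needs_to_force_full_precision(Context()))  # False: RTX 3060 handles float16 fine
    os.environ['FORCE_FULL_PRECISION'] = ''          # even an empty value counts
    print(needs_to_force_full_precision(Context()))  # True: env var overrides the heuristic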

View File

@@ -3,8 +3,7 @@ import logging
 import picklescan.scanner
 import rich
 
-from sd_internal import app, device_manager
-from sd_internal import Request
+from sd_internal import app
 
 log = logging.getLogger()
@@ -157,19 +156,3 @@ def getModels():
         models['options']['stable-diffusion'].append('custom-model')
 
     return models
-
-def is_sd_model_reload_necessary(thread_data, req: Request):
-    needs_model_reload = False
-    if 'stable-diffusion' not in thread_data.models or \
-            thread_data.model_paths['stable-diffusion'] != req.use_stable_diffusion_model or \
-            thread_data.model_paths['vae'] != req.use_vae_model:
-        needs_model_reload = True
-
-    if thread_data.device != 'cpu':
-        if (thread_data.precision == 'autocast' and req.use_full_precision) or \
-                (thread_data.precision == 'full' and not req.use_full_precision and not device_manager.needs_to_force_full_precision(thread_data)):
-            thread_data.precision = 'full' if req.use_full_precision else 'autocast'
-            needs_model_reload = True
-
-    return needs_model_reload

View File

@@ -9,13 +9,14 @@ import traceback
 import logging
 
 from sd_internal import device_manager, model_manager
-from sd_internal import Request, Response, Image as ResponseImage, UserInitiatedStop
+from sd_internal import TaskData, Response, Image as ResponseImage, UserInitiatedStop
 from modules import model_loader, image_generator, image_utils, filters as image_filters
+from modules.types import Context, GenerateImageRequest
 
 log = logging.getLogger()
 
-thread_data = threading.local()
+thread_data = Context()
 '''
 runtime data (bound locally to this thread), for e.g. device, references to loaded models, optimization flags etc
 '''
@@ -28,14 +29,7 @@ def init(device):
     '''
     thread_data.stop_processing = False
     thread_data.temp_images = {}
-    thread_data.partial_x_samples = None
-
-    thread_data.models = {}
-    thread_data.model_paths = {}
-
-    thread_data.device = None
-    thread_data.device_name = None
-    thread_data.precision = 'autocast'
-    thread_data.vram_optimizations = ('TURBO', 'MOVE_MODELS')
 
     device_manager.device_init(thread_data, device)
@@ -51,16 +45,16 @@ def load_default_models():
     # load mandatory models
     model_loader.load_model(thread_data, 'stable-diffusion')
 
-def reload_models_if_necessary(req: Request):
+def reload_models_if_necessary(task_data: TaskData):
     model_paths_in_req = (
-        ('hypernetwork', req.use_hypernetwork_model),
-        ('gfpgan', req.use_face_correction),
-        ('realesrgan', req.use_upscale),
+        ('hypernetwork', task_data.use_hypernetwork_model),
+        ('gfpgan', task_data.use_face_correction),
+        ('realesrgan', task_data.use_upscale),
     )
 
-    if model_manager.is_sd_model_reload_necessary(thread_data, req):
-        thread_data.model_paths['stable-diffusion'] = req.use_stable_diffusion_model
-        thread_data.model_paths['vae'] = req.use_vae_model
+    if thread_data.model_paths.get('stable-diffusion') != task_data.use_stable_diffusion_model or thread_data.model_paths.get('vae') != task_data.use_vae_model:
+        thread_data.model_paths['stable-diffusion'] = task_data.use_stable_diffusion_model
+        thread_data.model_paths['vae'] = task_data.use_vae_model
 
         model_loader.load_model(thread_data, 'stable-diffusion')
@@ -73,10 +67,17 @@ def reload_models_if_necessary(req: Request):
         else:
             model_loader.unload_model(thread_data, model_type)
 
-def make_images(req: Request, data_queue: queue.Queue, task_temp_images: list, step_callback):
+def make_images(req: GenerateImageRequest, task_data: TaskData, data_queue: queue.Queue, task_temp_images: list, step_callback):
     try:
-        log.info(req)
-        return _make_images_internal(req, data_queue, task_temp_images, step_callback)
+        # resolve the model paths to use
+        resolve_model_paths(task_data)
+
+        # convert init image to PIL.Image
+        req.init_image = image_utils.base64_str_to_img(req.init_image) if req.init_image is not None else None
+        req.init_image_mask = image_utils.base64_str_to_img(req.init_image_mask) if req.init_image_mask is not None else None
+
+        # generate
+        return _make_images_internal(req, task_data, data_queue, task_temp_images, step_callback)
     except Exception as e:
         log.error(traceback.format_exc())
@@ -86,66 +87,76 @@ def make_images(req: Request, data_queue: queue.Queue, task_temp_images: list, step_callback):
         }))
         raise e
 
-def _make_images_internal(req: Request, data_queue: queue.Queue, task_temp_images: list, step_callback):
-    args = req_to_args(req)
+def _make_images_internal(req: GenerateImageRequest, task_data: TaskData, data_queue: queue.Queue, task_temp_images: list, step_callback):
+    metadata = req.dict()
+    del metadata['init_image']
+    del metadata['init_image_mask']
+    print(metadata)
 
-    images, user_stopped = generate_images(args, data_queue, task_temp_images, step_callback, req.stream_image_progress)
-    images = apply_color_correction(args, images, user_stopped)
-    images = apply_filters(args, images, user_stopped, req.show_only_filtered_image)
+    images, user_stopped = generate_images(req, data_queue, task_temp_images, step_callback, task_data.stream_image_progress)
+    images = apply_color_correction(req, images, user_stopped)
+    images = apply_filters(task_data, images, user_stopped, task_data.show_only_filtered_image)
 
-    if req.save_to_disk_path is not None:
-        out_path = os.path.join(req.save_to_disk_path, filename_regex.sub('_', req.session_id))
-        save_images(images, out_path, metadata=req.json(), show_only_filtered_image=req.show_only_filtered_image)
+    if task_data.save_to_disk_path is not None:
+        out_path = os.path.join(task_data.save_to_disk_path, filename_regex.sub('_', task_data.session_id))
+        save_images(images, out_path, metadata=metadata, show_only_filtered_image=task_data.show_only_filtered_image)
 
-    res = Response(req, images=construct_response(req, images))
+    res = Response(req, task_data, images=construct_response(images))
     res = res.json()
     data_queue.put(json.dumps(res))
     log.info('Task completed')
 
     return res
 
-def generate_images(args: dict, data_queue: queue.Queue, task_temp_images: list, step_callback, stream_image_progress: bool):
+def resolve_model_paths(task_data: TaskData):
+    task_data.use_stable_diffusion_model = model_manager.resolve_model_to_use(task_data.use_stable_diffusion_model, model_type='stable-diffusion')
+    task_data.use_vae_model = model_manager.resolve_model_to_use(task_data.use_vae_model, model_type='vae')
+    task_data.use_hypernetwork_model = model_manager.resolve_model_to_use(task_data.use_hypernetwork_model, model_type='hypernetwork')
+
+    if task_data.use_face_correction: task_data.use_face_correction = model_manager.resolve_model_to_use(task_data.use_face_correction, 'gfpgan')
+    if task_data.use_upscale: task_data.use_upscale = model_manager.resolve_model_to_use(task_data.use_upscale, 'gfpgan')
+
+def generate_images(req: GenerateImageRequest, data_queue: queue.Queue, task_temp_images: list, step_callback, stream_image_progress: bool):
     thread_data.temp_images.clear()
 
-    image_generator.on_image_step = make_step_callback(args, data_queue, task_temp_images, step_callback, stream_image_progress)
+    image_generator.on_image_step = make_step_callback(req, data_queue, task_temp_images, step_callback, stream_image_progress)
 
     try:
-        images = image_generator.make_images(context=thread_data, args=args)
+        images = image_generator.make_images(context=thread_data, req=req)
         user_stopped = False
     except UserInitiatedStop:
         images = []
         user_stopped = True
-        if not hasattr(thread_data, 'partial_x_samples') or thread_data.partial_x_samples is None:
-            return images
-        for i in range(args['num_outputs']):
-            images[i] = image_utils.latent_to_img(thread_data, thread_data.partial_x_samples[i].unsqueeze(0))
-        thread_data.partial_x_samples = None
+        if thread_data.partial_x_samples is not None:
+            for i in range(req.num_outputs):
+                images[i] = image_utils.latent_to_img(thread_data, thread_data.partial_x_samples[i].unsqueeze(0))
+            del thread_data.partial_x_samples
     finally:
         model_loader.gc(thread_data)
 
-    images = [(image, args['seed'] + i, False) for i, image in enumerate(images)]
+    images = [(image, req.seed + i, False) for i, image in enumerate(images)]
 
     return images, user_stopped
 
-def apply_color_correction(args: dict, images: list, user_stopped):
-    if user_stopped or args['init_image'] is None or not args['apply_color_correction']:
+def apply_color_correction(req: GenerateImageRequest, images: list, user_stopped):
+    if user_stopped or req.init_image is None or not req.apply_color_correction:
         return images
 
     for i, img_info in enumerate(images):
         img, seed, filtered = img_info
-        img = image_utils.apply_color_correction(orig_image=args['init_image'], image_to_correct=img)
+        img = image_utils.apply_color_correction(orig_image=req.init_image, image_to_correct=img)
         images[i] = (img, seed, filtered)
 
     return images
 
-def apply_filters(args: dict, images: list, user_stopped, show_only_filtered_image):
-    if user_stopped or (args['use_face_correction'] is None and args['use_upscale'] is None):
+def apply_filters(task_data: TaskData, images: list, user_stopped, show_only_filtered_image):
    if user_stopped or (task_data.use_face_correction is None and task_data.use_upscale is None):
         return images
 
     filters = []
-    if 'gfpgan' in args['use_face_correction'].lower(): filters.append(image_filters.apply_gfpgan)
-    if 'realesrgan' in args['use_face_correction'].lower(): filters.append(image_filters.apply_realesrgan)
+    if 'gfpgan' in task_data.use_face_correction.lower(): filters.append(image_filters.apply_gfpgan)
+    if 'realesrgan' in task_data.use_face_correction.lower(): filters.append(image_filters.apply_realesrgan)
 
     filtered_images = []
     for img, seed, _ in images:
@@ -188,37 +199,29 @@ def save_images(images: list, save_to_disk_path, metadata: dict, show_only_filtered_image
             img_path += '.' + metadata['output_format']
             img.save(img_path, quality=metadata['output_quality'])
 
-def construct_response(req: Request, images: list):
+def construct_response(task_data: TaskData, images: list):
     return [
         ResponseImage(
-            data=image_utils.img_to_base64_str(img, req.output_format, req.output_quality),
+            data=image_utils.img_to_base64_str(img, task_data.output_format, task_data.output_quality),
             seed=seed
         ) for img, seed, _ in images
     ]
 
-def req_to_args(req: Request):
-    args = req.json()
-    args['init_image'] = image_utils.base64_str_to_img(req.init_image) if req.init_image is not None else None
-    args['mask'] = image_utils.base64_str_to_img(req.mask) if req.mask is not None else None
-    return args
-
-def make_step_callback(args: dict, data_queue: queue.Queue, task_temp_images: list, step_callback, stream_image_progress: bool):
-    n_steps = args['num_inference_steps'] if args['init_image'] is None else int(args['num_inference_steps'] * args['prompt_strength'])
+def make_step_callback(req: GenerateImageRequest, task_data: TaskData, data_queue: queue.Queue, task_temp_images: list, step_callback, stream_image_progress: bool):
+    n_steps = req.num_inference_steps if req.init_image is None else int(req.num_inference_steps * req.prompt_strength)
     last_callback_time = -1
 
     def update_temp_img(x_samples, task_temp_images: list):
         partial_images = []
-        for i in range(args['num_outputs']):
+        for i in range(req.num_outputs):
             img = image_utils.latent_to_img(thread_data, x_samples[i].unsqueeze(0))
             buf = image_utils.img_to_buffer(img, output_format='JPEG')
             del img
 
-            thread_data.temp_images[f"{args['request_id']}/{i}"] = buf
+            thread_data.temp_images[f"{task_data.request_id}/{i}"] = buf
             task_temp_images[i] = buf
-            partial_images.append({'path': f"/image/tmp/{args['request_id']}/{i}"})
+            partial_images.append({'path': f"/image/tmp/{task_data.request_id}/{i}"})
         return partial_images
 
     def on_image_step(x_samples, i):
def on_image_step(x_samples, i): def on_image_step(x_samples, i):

View File

@@ -14,8 +14,8 @@ import torch
 import queue, threading, time, weakref
 from typing import Any, Hashable
 
-from pydantic import BaseModel
-from sd_internal import Request, device_manager
+from sd_internal import TaskData, device_manager
+from modules.types import GenerateImageRequest
 
 log = logging.getLogger()
@@ -39,9 +39,10 @@ class ServerStates:
     class Unavailable(Symbol): pass
 
 class RenderTask(): # Task with output queue and completion lock.
-    def __init__(self, req: Request):
+    def __init__(self, req: GenerateImageRequest, task_data: TaskData):
         req.request_id = id(self)
-        self.request: Request = req # Initial Request
+        self.render_request: GenerateImageRequest = req # Initial Request
+        self.task_data: TaskData = task_data
         self.response: Any = None # Copy of the last reponse
         self.render_device = None # Select the task affinity. (Not used to change active devices).
         self.temp_images:list = [None] * req.num_outputs * (1 if req.show_only_filtered_image else 2)
@@ -72,55 +73,6 @@ class RenderTask(): # Task with output queue and completion lock.
     def is_pending(self):
         return bool(not self.response and not self.error)
 
-# defaults from https://huggingface.co/blog/stable_diffusion
-class ImageRequest(BaseModel):
-    session_id: str = "session"
-    prompt: str = ""
-    negative_prompt: str = ""
-    init_image: str = None # base64
-    mask: str = None # base64
-    apply_color_correction: bool = False
-    num_outputs: int = 1
-    num_inference_steps: int = 50
-    guidance_scale: float = 7.5
-    width: int = 512
-    height: int = 512
-    seed: int = 42
-    prompt_strength: float = 0.8
-    sampler: str = None # "ddim", "plms", "heun", "euler", "euler_a", "dpm2", "dpm2_a", "lms"
-    # allow_nsfw: bool = False
-    save_to_disk_path: str = None
-    turbo: bool = True
-    use_cpu: bool = False ##TODO Remove after UI and plugins transition.
-    render_device: str = None # Select the task affinity. (Not used to change active devices).
-    use_full_precision: bool = False
-    use_face_correction: str = None # or "GFPGANv1.3"
-    use_upscale: str = None # or "RealESRGAN_x4plus" or "RealESRGAN_x4plus_anime_6B"
-    use_stable_diffusion_model: str = "sd-v1-4"
-    use_vae_model: str = None
-    use_hypernetwork_model: str = None
-    hypernetwork_strength: float = None
-    show_only_filtered_image: bool = False
-    output_format: str = "jpeg" # or "png"
-    output_quality: int = 75
-    stream_progress_updates: bool = False
-    stream_image_progress: bool = False
-
-class FilterRequest(BaseModel):
-    session_id: str = "session"
-    model: str = None
-    name: str = ""
-    init_image: str = None # base64
-    width: int = 512
-    height: int = 512
-    save_to_disk_path: str = None
-    turbo: bool = True
-    render_device: str = None
-    use_full_precision: bool = False
-    output_format: str = "jpeg" # or "png"
-    output_quality: int = 75
-
 # Temporary cache to allow to query tasks results for a short time after they are completed.
 class DataCache():
     def __init__(self):
@@ -311,7 +263,7 @@ def thread_render(device):
             task.response = {"status": 'failed', "detail": str(task.error)}
             task.buffer_queue.put(json.dumps(task.response))
             continue
-        log.info(f'Session {task.request.session_id} starting task {id(task)} on {runtime2.thread_data.device_name}')
+        log.info(f'Session {task.task_data.session_id} starting task {id(task)} on {runtime2.thread_data.device_name}')
         if not task.lock.acquire(blocking=False): raise Exception('Got locked task from queue.')
         try:
             def step_callback():
@@ -322,16 +274,16 @@ def thread_render(device):
                 if isinstance(current_state_error, StopAsyncIteration):
                     task.error = current_state_error
                     current_state_error = None
-                    log.info(f'Session {task.request.session_id} sent cancel signal for task {id(task)}')
+                    log.info(f'Session {task.task_data.session_id} sent cancel signal for task {id(task)}')
 
             current_state = ServerStates.LoadingModel
-            runtime2.reload_models_if_necessary(task.request)
+            runtime2.reload_models_if_necessary(task.task_data)
 
             current_state = ServerStates.Rendering
-            task.response = runtime2.make_images(task.request, task.buffer_queue, task.temp_images, step_callback)
+            task.response = runtime2.make_images(task.render_request, task.task_data, task.buffer_queue, task.temp_images, step_callback)
             # Before looping back to the generator, mark cache as still alive.
             task_cache.keep(id(task), TASK_TTL)
-            session_cache.keep(task.request.session_id, TASK_TTL)
+            session_cache.keep(task.task_data.session_id, TASK_TTL)
         except Exception as e:
             task.error = e
             log.error(traceback.format_exc())
@@ -340,13 +292,13 @@ def thread_render(device):
             # Task completed
             task.lock.release()
         task_cache.keep(id(task), TASK_TTL)
-        session_cache.keep(task.request.session_id, TASK_TTL)
+        session_cache.keep(task.task_data.session_id, TASK_TTL)
         if isinstance(task.error, StopAsyncIteration):
-            log.info(f'Session {task.request.session_id} task {id(task)} cancelled!')
+            log.info(f'Session {task.task_data.session_id} task {id(task)} cancelled!')
         elif task.error is not None:
-            log.info(f'Session {task.request.session_id} task {id(task)} failed!')
+            log.info(f'Session {task.task_data.session_id} task {id(task)} failed!')
         else:
-            log.info(f'Session {task.request.session_id} task {id(task)} completed by {runtime2.thread_data.device_name}.')
+            log.info(f'Session {task.task_data.session_id} task {id(task)} completed by {runtime2.thread_data.device_name}.')
         current_state = ServerStates.Online
@@ -509,53 +461,18 @@ def shutdown_event(): # Signal render thread to close on shutdown
     global current_state_error
     current_state_error = SystemExit('Application shutting down.')
 
-def render(req : ImageRequest):
+def render(render_req: GenerateImageRequest, task_data: TaskData):
     current_thread_count = is_alive()
     if current_thread_count <= 0: # Render thread is dead
         raise ChildProcessError('Rendering thread has died.')
 
     # Alive, check if task in cache
-    session = get_cached_session(req.session_id, update_ttl=True)
+    session = get_cached_session(task_data.session_id, update_ttl=True)
     pending_tasks = list(filter(lambda t: t.is_pending, session.tasks))
     if current_thread_count < len(pending_tasks):
-        raise ConnectionRefusedError(f'Session {req.session_id} already has {len(pending_tasks)} pending tasks out of {current_thread_count}.')
+        raise ConnectionRefusedError(f'Session {task_data.session_id} already has {len(pending_tasks)} pending tasks out of {current_thread_count}.')
 
-    r = Request()
-    r.session_id = req.session_id
-    r.prompt = req.prompt
-    r.negative_prompt = req.negative_prompt
-    r.init_image = req.init_image
-    r.mask = req.mask
-    r.apply_color_correction = req.apply_color_correction
-    r.num_outputs = req.num_outputs
-    r.num_inference_steps = req.num_inference_steps
-    r.guidance_scale = req.guidance_scale
-    r.width = req.width
-    r.height = req.height
-    r.seed = req.seed
-    r.prompt_strength = req.prompt_strength
-    r.sampler = req.sampler
-    # r.allow_nsfw = req.allow_nsfw
-    r.turbo = req.turbo
-    r.use_full_precision = req.use_full_precision
-    r.save_to_disk_path = req.save_to_disk_path
-    r.use_upscale: str = req.use_upscale
-    r.use_face_correction = req.use_face_correction
-    r.use_stable_diffusion_model = req.use_stable_diffusion_model
-    r.use_vae_model = req.use_vae_model
-    r.use_hypernetwork_model = req.use_hypernetwork_model
-    r.hypernetwork_strength = req.hypernetwork_strength
-    r.show_only_filtered_image = req.show_only_filtered_image
-    r.output_format = req.output_format
-    r.output_quality = req.output_quality
-
-    r.stream_progress_updates = True # the underlying implementation only supports streaming
-    r.stream_image_progress = req.stream_image_progress
-
-    if not req.stream_progress_updates:
-        r.stream_image_progress = False
-
-    new_task = RenderTask(r)
+    new_task = RenderTask(render_req, task_data)
     if session.put(new_task, TASK_TTL):
         # Use twice the normal timeout for adding user requests.
         # Tries to force session.put to fail before tasks_queue.put would.

View File

@@ -14,6 +14,8 @@ from starlette.responses import FileResponse, JSONResponse, StreamingResponse
 from pydantic import BaseModel
 
 from sd_internal import app, model_manager, task_manager
+from sd_internal import TaskData
+from modules.types import GenerateImageRequest
 
 log = logging.getLogger()
@@ -22,8 +24,6 @@ log.info(f'started at {datetime.datetime.now():%x %X}')
 
 server_api = FastAPI()
 
-# don't show access log entries for URLs that start with the given prefix
-ACCESS_LOG_SUPPRESS_PATH_PREFIXES = ['/ping', '/image', '/modifier-thumbnails']
-
 NOCACHE_HEADERS={"Cache-Control": "no-cache, no-store, must-revalidate", "Pragma": "no-cache", "Expires": "0"}
 
 class NoCacheStaticFiles(StaticFiles):
@@ -43,17 +43,6 @@ class SetAppConfigRequest(BaseModel):
     listen_port: int = None
     test_sd2: bool = None
 
-class LogSuppressFilter(logging.Filter):
-    def filter(self, record: logging.LogRecord) -> bool:
-        path = record.getMessage()
-        for prefix in ACCESS_LOG_SUPPRESS_PATH_PREFIXES:
-            if path.find(prefix) != -1:
-                return False
-        return True
-
-# don't log certain requests
-logging.getLogger('uvicorn.access').addFilter(LogSuppressFilter())
-
 server_api.mount('/media', NoCacheStaticFiles(directory=os.path.join(app.SD_UI_DIR, 'media')), name="media")
 
 for plugins_dir, dir_prefix in app.UI_PLUGINS_SOURCES:
@@ -137,20 +126,18 @@ def ping(session_id:str=None):
     return JSONResponse(response, headers=NOCACHE_HEADERS)
 
 @server_api.post('/render')
-def render(req : task_manager.ImageRequest):
+def render(req: dict):
     try:
-        app.save_model_to_config(req.use_stable_diffusion_model, req.use_vae_model, req.use_hypernetwork_model)
-
-        # resolve the model paths to use
-        req.use_stable_diffusion_model = model_manager.resolve_model_to_use(req.use_stable_diffusion_model, model_type='stable-diffusion')
-        req.use_vae_model = model_manager.resolve_model_to_use(req.use_vae_model, model_type='vae')
-        req.use_hypernetwork_model = model_manager.resolve_model_to_use(req.use_hypernetwork_model, model_type='hypernetwork')
-
-        if req.use_face_correction: req.use_face_correction = model_manager.resolve_model_to_use(req.use_face_correction, 'gfpgan')
-        if req.use_upscale: req.use_upscale = model_manager.resolve_model_to_use(req.use_upscale, 'gfpgan')
+        # separate out the request data into rendering and task-specific data
+        render_req: GenerateImageRequest = GenerateImageRequest.parse_obj(req)
+        task_data: TaskData = TaskData.parse_obj(req)
+
+        render_req.init_image_mask = req.get('mask') # hack: will rename this in the HTTP API in a future revision
+
+        app.save_model_to_config(task_data.use_stable_diffusion_model, task_data.use_vae_model, task_data.use_hypernetwork_model)
 
         # enqueue the task
-        new_task = task_manager.render(req)
+        new_task = task_manager.render(render_req, task_data)
 
         response = {
             'status': str(task_manager.current_state),
             'queue': len(task_manager.tasks_queue),