Get rid of the ugly copying around (and maintaining) of multiple request-related fields. Split the request into two objects: task-related fields and render-related fields. Also remove the ability for a request to choose full precision; full precision can now be forced by setting the FORCE_FULL_PRECISION environment variable.

cmdr2 2022-12-11 18:16:29 +05:30
parent d03eed3859
commit 6ce6dc3ff6
11 changed files with 115 additions and 305 deletions
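
The heart of the change: instead of hand-copying every field from the HTTP-facing ImageRequest into an internal Request object, the server now parses the same flat JSON payload twice, once per pydantic model. A minimal sketch of the pattern (field sets abbreviated; it relies on pydantic v1's default of silently ignoring unknown keys):

from pydantic import BaseModel

# abbreviated stand-ins for the two real models introduced by this commit
class GenerateImageRequest(BaseModel):  # render-related fields
    prompt: str = ""
    seed: int = 42

class TaskData(BaseModel):  # task-related fields
    session_id: str = "session"
    output_format: str = "jpeg"

payload = {"prompt": "astronaut", "seed": 7, "session_id": "abc", "output_format": "png"}
render_req = GenerateImageRequest.parse_obj(payload)  # picks up prompt and seed
task_data = TaskData.parse_obj(payload)               # picks up session_id and output_format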

View File

@ -37,7 +37,6 @@ const SETTINGS_IDS_LIST = [
"diskPath",
"sound_toggle",
"turbo",
"use_full_precision",
"confirm_dangerous_actions",
"auto_save_settings",
"apply_color_correction"
@ -278,7 +277,6 @@ function tryLoadOldSettings() {
"soundEnabled": "sound_toggle",
"saveToDisk": "save_to_disk",
"useCPU": "use_cpu",
"useFullPrecision": "use_full_precision",
"useTurboMode": "turbo",
"diskPath": "diskPath",
"useFaceCorrection": "use_face_correction",

View File

@ -217,13 +217,6 @@ const TASK_MAPPING = {
readUI: () => turboField.checked,
parse: (val) => Boolean(val)
},
use_full_precision: { name: 'Use Full Precision',
setUI: (use_full_precision) => {
useFullPrecisionField.checked = use_full_precision
},
readUI: () => useFullPrecisionField.checked,
parse: (val) => Boolean(val)
},
stream_image_progress: { name: 'Stream Image Progress',
setUI: (stream_image_progress) => {
@ -453,7 +446,6 @@ document.addEventListener("dragover", dragOverHandler)
const TASK_REQ_NO_EXPORT = [
"use_cpu",
"turbo",
"use_full_precision",
"save_to_disk_path"
]
const resetSettings = document.getElementById('reset-image-settings')

View File

@ -728,7 +728,6 @@
"stream_image_progress": 'boolean',
"show_only_filtered_image": 'boolean',
"turbo": 'boolean',
"use_full_precision": 'boolean',
"output_format": 'string',
"output_quality": 'number',
}
@ -744,7 +743,6 @@
"stream_image_progress": true,
"show_only_filtered_image": true,
"turbo": false,
"use_full_precision": false,
"output_format": "png",
"output_quality": 75,
}

View File

@ -850,7 +850,6 @@ function getCurrentUserRequest() {
// allow_nsfw: allowNSFWField.checked,
turbo: turboField.checked,
//render_device: undefined, // Set device affinity. Prefer this device, but won't activate.
use_full_precision: useFullPrecisionField.checked,
use_stable_diffusion_model: stableDiffusionModelField.value,
use_vae_model: vaeModelField.value,
stream_progress_updates: true,

View File

@ -105,14 +105,6 @@ var PARAMETERS = [
note: "to process in parallel",
default: false,
},
{
id: "use_full_precision",
type: ParameterType.checkbox,
label: "Use Full Precision",
note: "for GPU-only. warning: this will consume more VRAM",
icon: "fa-crosshairs",
default: false,
},
{
id: "auto_save_settings",
type: ParameterType.checkbox,
@ -214,7 +206,6 @@ let turboField = document.querySelector('#turbo')
let useCPUField = document.querySelector('#use_cpu')
let autoPickGPUsField = document.querySelector('#auto_pick_gpus')
let useGPUsField = document.querySelector('#use_gpus')
let useFullPrecisionField = document.querySelector('#use_full_precision')
let saveToDiskField = document.querySelector('#save_to_disk')
let diskPathField = document.querySelector('#diskPath')
let listenToNetworkField = document.querySelector("#listen_to_network")

View File

@ -1,93 +1,22 @@
import json
from pydantic import BaseModel
class Request:
from modules.types import GenerateImageRequest
class TaskData(BaseModel):
request_id: str = None
session_id: str = "session"
prompt: str = ""
negative_prompt: str = ""
init_image: str = None # base64
mask: str = None # base64
apply_color_correction = False
num_outputs: int = 1
num_inference_steps: int = 50
guidance_scale: float = 7.5
width: int = 512
height: int = 512
seed: int = 42
prompt_strength: float = 0.8
sampler: str = None # "ddim", "plms", "heun", "euler", "euler_a", "dpm2", "dpm2_a", "lms"
# allow_nsfw: bool = False
precision: str = "autocast" # or "full"
save_to_disk_path: str = None
turbo: bool = True
use_full_precision: bool = False
use_face_correction: str = None # or "GFPGANv1.3"
use_upscale: str = None # or "RealESRGAN_x4plus" or "RealESRGAN_x4plus_anime_6B"
use_stable_diffusion_model: str = "sd-v1-4"
use_vae_model: str = None
use_hypernetwork_model: str = None
hypernetwork_strength: float = 1
show_only_filtered_image: bool = False
output_format: str = "jpeg" # or "png"
output_quality: int = 75
stream_progress_updates: bool = False
stream_image_progress: bool = False
def json(self):
return {
"request_id": self.request_id,
"session_id": self.session_id,
"prompt": self.prompt,
"negative_prompt": self.negative_prompt,
"num_outputs": self.num_outputs,
"num_inference_steps": self.num_inference_steps,
"guidance_scale": self.guidance_scale,
"width": self.width,
"height": self.height,
"seed": self.seed,
"prompt_strength": self.prompt_strength,
"sampler": self.sampler,
"apply_color_correction": self.apply_color_correction,
"use_face_correction": self.use_face_correction,
"use_upscale": self.use_upscale,
"use_stable_diffusion_model": self.use_stable_diffusion_model,
"use_vae_model": self.use_vae_model,
"use_hypernetwork_model": self.use_hypernetwork_model,
"hypernetwork_strength": self.hypernetwork_strength,
"output_format": self.output_format,
"output_quality": self.output_quality,
}
def __str__(self):
return f'''
session_id: {self.session_id}
prompt: {self.prompt}
negative_prompt: {self.negative_prompt}
seed: {self.seed}
num_inference_steps: {self.num_inference_steps}
sampler: {self.sampler}
guidance_scale: {self.guidance_scale}
w: {self.width}
h: {self.height}
precision: {self.precision}
save_to_disk_path: {self.save_to_disk_path}
turbo: {self.turbo}
use_full_precision: {self.use_full_precision}
apply_color_correction: {self.apply_color_correction}
use_face_correction: {self.use_face_correction}
use_upscale: {self.use_upscale}
use_stable_diffusion_model: {self.use_stable_diffusion_model}
use_vae_model: {self.use_vae_model}
use_hypernetwork_model: {self.use_hypernetwork_model}
hypernetwork_strength: {self.hypernetwork_strength}
show_only_filtered_image: {self.show_only_filtered_image}
output_format: {self.output_format}
output_quality: {self.output_quality}
stream_progress_updates: {self.stream_progress_updates}
stream_image_progress: {self.stream_image_progress}'''
class Image:
data: str # base64
seed: int
@ -106,17 +35,20 @@ class Image:
}
class Response:
request: Request
render_request: GenerateImageRequest
task_data: TaskData
images: list
def __init__(self, request: Request, images: list):
self.request = request
def __init__(self, render_request: GenerateImageRequest, task_data: TaskData, images: list):
self.render_request = render_request
self.task_data = task_data
self.images = images
def json(self):
res = {
"status": 'succeeded',
"request": self.request.json(),
"render_request": self.render_request.dict(),
"task_data": self.task_data.dict(),
"output": [],
}
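
For reference, a rough sketch of the wire format the new Response.json() produces — the output entries come from Image.json(), which sits outside this hunk, so their exact shape is an assumption:

# illustrative only; see the caveat above about Image.json()
example_response = {
    "status": "succeeded",
    "render_request": {"prompt": "astronaut", "seed": 7},        # GenerateImageRequest.dict()
    "task_data": {"session_id": "abc", "output_format": "png"},  # TaskData.dict()
    "output": [{"data": "<base64 image>", "seed": 7}],           # assumed per-image shape
}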

View File

@ -6,6 +6,13 @@ import logging
log = logging.getLogger()
'''
Set `FORCE_FULL_PRECISION` in the environment variables, or in `config.bat`/`config.sh` to set full precision (i.e. float32).
Otherwise the models will load at half-precision (i.e. float16).
Half-precision is fine most of the time. Full precision is only needed to work around bugs on certain GPUs (like the NVIDIA 16xx series, which renders green images at half precision).
'''
COMPARABLE_GPU_PERCENTILE = 0.65 # if a GPU's free_mem is within this % of the GPU with the most free_mem, it will be picked
mem_free_threshold = 0
@ -96,7 +103,7 @@ def device_init(context, device):
if device == 'cpu':
context.device = 'cpu'
context.device_name = get_processor_name()
context.precision = 'full'
context.half_precision = False
log.debug(f'Render device CPU available as {context.device_name}')
return
@ -107,7 +114,7 @@ def device_init(context, device):
if needs_to_force_full_precision(context):
log.warn(f'forcing full precision on this GPU, to avoid green images. GPU detected: {context.device_name}')
# Apply force_full_precision now before models are loaded.
context.precision = 'full'
context.half_precision = False
log.info(f'Setting {device} as active')
torch.cuda.device(device)
@ -115,6 +122,9 @@ def device_init(context, device):
return
def needs_to_force_full_precision(context):
if 'FORCE_FULL_PRECISION' in os.environ:
return True
device_name = context.device_name.lower()
return (('nvidia' in device_name or 'geforce' in device_name) and (' 1660' in device_name or ' 1650' in device_name)) or ('quadro t2000' in device_name) # device_name is lowercased above, so compare in lowercase
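
With the `precision` string gone, everything downstream keys off the boolean `context.half_precision`. A minimal sketch of how a model loader might translate the flag into a torch dtype — the policy mirrors device_init() and needs_to_force_full_precision() above, but the helper itself is an assumption:

import os
import torch

def pick_torch_dtype(device: str, device_name: str) -> torch.dtype:
    # hypothetical helper mirroring the commit's policy: full precision (float32)
    # on CPU, when FORCE_FULL_PRECISION is set, or on GPUs known to render
    # green images at half precision; float16 everywhere else
    name = device_name.lower()
    buggy_gpu = (('nvidia' in name or 'geforce' in name)
                 and (' 1660' in name or ' 1650' in name)) or 'quadro t2000' in name
    if device == 'cpu' or 'FORCE_FULL_PRECISION' in os.environ or buggy_gpu:
        return torch.float32
    return torch.float16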

View File

@ -3,8 +3,7 @@ import logging
import picklescan.scanner
import rich
from sd_internal import app, device_manager
from sd_internal import Request
from sd_internal import app
log = logging.getLogger()
@ -157,19 +156,3 @@ def getModels():
models['options']['stable-diffusion'].append('custom-model')
return models
def is_sd_model_reload_necessary(thread_data, req: Request):
needs_model_reload = False
if 'stable-diffusion' not in thread_data.models or \
thread_data.model_paths['stable-diffusion'] != req.use_stable_diffusion_model or \
thread_data.model_paths['vae'] != req.use_vae_model:
needs_model_reload = True
if thread_data.device != 'cpu':
if (thread_data.precision == 'autocast' and req.use_full_precision) or \
(thread_data.precision == 'full' and not req.use_full_precision and not device_manager.needs_to_force_full_precision(thread_data)):
thread_data.precision = 'full' if req.use_full_precision else 'autocast'
needs_model_reload = True
return needs_model_reload
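
runtime2.resolve_model_paths() (further down) now funnels every model slot through model_manager.resolve_model_to_use(), whose implementation is outside this diff. A purely hypothetical sketch of what such a resolver might do — the directory layout, extensions, and fallback are assumptions for illustration, not the project's code:

import os

def resolve_model_to_use(model_name: str, model_type: str) -> str:
    # hypothetical: map a bare model name to an on-disk checkpoint path
    if not model_name:
        return None  # e.g. use_vae_model defaults to None
    for ext in ('.ckpt', '.pt'):
        path = os.path.join('models', model_type, model_name + ext)
        if os.path.exists(path):
            return path
    return model_name  # assume the caller already passed a usable path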

View File

@ -9,13 +9,14 @@ import traceback
import logging
from sd_internal import device_manager, model_manager
from sd_internal import Request, Response, Image as ResponseImage, UserInitiatedStop
from sd_internal import TaskData, Response, Image as ResponseImage, UserInitiatedStop
from modules import model_loader, image_generator, image_utils, filters as image_filters
from modules.types import Context, GenerateImageRequest
log = logging.getLogger()
thread_data = threading.local()
thread_data = Context()
'''
runtime data (bound locally to this thread), e.g. device, references to loaded models, optimization flags, etc.
'''
@ -28,14 +29,7 @@ def init(device):
'''
thread_data.stop_processing = False
thread_data.temp_images = {}
thread_data.models = {}
thread_data.model_paths = {}
thread_data.device = None
thread_data.device_name = None
thread_data.precision = 'autocast'
thread_data.vram_optimizations = ('TURBO', 'MOVE_MODELS')
thread_data.partial_x_samples = None
device_manager.device_init(thread_data, device)
@ -51,16 +45,16 @@ def load_default_models():
# load mandatory models
model_loader.load_model(thread_data, 'stable-diffusion')
def reload_models_if_necessary(req: Request):
def reload_models_if_necessary(task_data: TaskData):
model_paths_in_req = (
('hypernetwork', req.use_hypernetwork_model),
('gfpgan', req.use_face_correction),
('realesrgan', req.use_upscale),
('hypernetwork', task_data.use_hypernetwork_model),
('gfpgan', task_data.use_face_correction),
('realesrgan', task_data.use_upscale),
)
if model_manager.is_sd_model_reload_necessary(thread_data, req):
thread_data.model_paths['stable-diffusion'] = req.use_stable_diffusion_model
thread_data.model_paths['vae'] = req.use_vae_model
if thread_data.model_paths.get('stable-diffusion') != task_data.use_stable_diffusion_model or thread_data.model_paths.get('vae') != task_data.use_vae_model:
thread_data.model_paths['stable-diffusion'] = task_data.use_stable_diffusion_model
thread_data.model_paths['vae'] = task_data.use_vae_model
model_loader.load_model(thread_data, 'stable-diffusion')
@ -73,10 +67,17 @@ def reload_models_if_necessary(req: Request):
else:
model_loader.unload_model(thread_data, model_type)
def make_images(req: Request, data_queue: queue.Queue, task_temp_images: list, step_callback):
def make_images(req: GenerateImageRequest, task_data: TaskData, data_queue: queue.Queue, task_temp_images: list, step_callback):
try:
log.info(req)
return _make_images_internal(req, data_queue, task_temp_images, step_callback)
# resolve the model paths to use
resolve_model_paths(task_data)
# convert init image to PIL.Image
req.init_image = image_utils.base64_str_to_img(req.init_image) if req.init_image is not None else None
req.init_image_mask = image_utils.base64_str_to_img(req.init_image_mask) if req.init_image_mask is not None else None
# generate
return _make_images_internal(req, task_data, data_queue, task_temp_images, step_callback)
except Exception as e:
log.error(traceback.format_exc())
@ -86,66 +87,76 @@ def make_images(req: Request, data_queue: queue.Queue, task_temp_images: list, s
}))
raise e
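
make_images() now decodes the base64 init image and mask once, up front, so the rest of the pipeline deals only in PIL images. image_utils.base64_str_to_img() is defined outside this diff; a plausible sketch of such a helper, assuming data-URL-prefixed base64 strings from the browser:

import base64, io
from PIL import Image

def base64_str_to_img(img_str: str) -> Image.Image:
    # hypothetical sketch: strip an optional "data:image/png;base64," prefix,
    # then decode the remaining base64 payload into a PIL image
    if ',' in img_str:
        img_str = img_str.split(',', 1)[1]
    return Image.open(io.BytesIO(base64.b64decode(img_str)))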
def _make_images_internal(req: Request, data_queue: queue.Queue, task_temp_images: list, step_callback):
args = req_to_args(req)
def _make_images_internal(req: GenerateImageRequest, task_data: TaskData, data_queue: queue.Queue, task_temp_images: list, step_callback):
metadata = {**req.dict(), **task_data.dict()} # merge both objects: the saved metadata needs fields from each (e.g. seed and output_format)
del metadata['init_image']
del metadata['init_image_mask']
log.debug(metadata)
images, user_stopped = generate_images(args, data_queue, task_temp_images, step_callback, req.stream_image_progress)
images = apply_color_correction(args, images, user_stopped)
images = apply_filters(args, images, user_stopped, req.show_only_filtered_image)
images, user_stopped = generate_images(req, task_data, data_queue, task_temp_images, step_callback, task_data.stream_image_progress)
images = apply_color_correction(req, images, user_stopped)
images = apply_filters(task_data, images, user_stopped, task_data.show_only_filtered_image)
if req.save_to_disk_path is not None:
out_path = os.path.join(req.save_to_disk_path, filename_regex.sub('_', req.session_id))
save_images(images, out_path, metadata=req.json(), show_only_filtered_image=req.show_only_filtered_image)
if task_data.save_to_disk_path is not None:
out_path = os.path.join(task_data.save_to_disk_path, filename_regex.sub('_', task_data.session_id))
save_images(images, out_path, metadata=metadata, show_only_filtered_image=task_data.show_only_filtered_image)
res = Response(req, images=construct_response(req, images))
res = Response(req, task_data, images=construct_response(task_data, images))
res = res.json()
data_queue.put(json.dumps(res))
log.info('Task completed')
return res
def generate_images(args: dict, data_queue: queue.Queue, task_temp_images: list, step_callback, stream_image_progress: bool):
def resolve_model_paths(task_data: TaskData):
task_data.use_stable_diffusion_model = model_manager.resolve_model_to_use(task_data.use_stable_diffusion_model, model_type='stable-diffusion')
task_data.use_vae_model = model_manager.resolve_model_to_use(task_data.use_vae_model, model_type='vae')
task_data.use_hypernetwork_model = model_manager.resolve_model_to_use(task_data.use_hypernetwork_model, model_type='hypernetwork')
if task_data.use_face_correction: task_data.use_face_correction = model_manager.resolve_model_to_use(task_data.use_face_correction, 'gfpgan')
if task_data.use_upscale: task_data.use_upscale = model_manager.resolve_model_to_use(task_data.use_upscale, 'realesrgan')
def generate_images(req: GenerateImageRequest, task_data: TaskData, data_queue: queue.Queue, task_temp_images: list, step_callback, stream_image_progress: bool):
thread_data.temp_images.clear()
image_generator.on_image_step = make_step_callback(args, data_queue, task_temp_images, step_callback, stream_image_progress)
image_generator.on_image_step = make_step_callback(req, task_data, data_queue, task_temp_images, step_callback, stream_image_progress)
try:
images = image_generator.make_images(context=thread_data, args=args)
images = image_generator.make_images(context=thread_data, req=req)
user_stopped = False
except UserInitiatedStop:
images = []
user_stopped = True
if not hasattr(thread_data, 'partial_x_samples') or thread_data.partial_x_samples is None:
return images
for i in range(args['num_outputs']):
if thread_data.partial_x_samples is not None:
for i in range(req.num_outputs):
images.append(image_utils.latent_to_img(thread_data, thread_data.partial_x_samples[i].unsqueeze(0))) # images is empty here, so append rather than index
del thread_data.partial_x_samples
thread_data.partial_x_samples = None
finally:
model_loader.gc(thread_data)
images = [(image, args['seed'] + i, False) for i, image in enumerate(images)]
images = [(image, req.seed + i, False) for i, image in enumerate(images)]
return images, user_stopped
def apply_color_correction(args: dict, images: list, user_stopped):
if user_stopped or args['init_image'] is None or not args['apply_color_correction']:
def apply_color_correction(req: GenerateImageRequest, images: list, user_stopped):
if user_stopped or req.init_image is None or not req.apply_color_correction:
return images
for i, img_info in enumerate(images):
img, seed, filtered = img_info
img = image_utils.apply_color_correction(orig_image=args['init_image'], image_to_correct=img)
img = image_utils.apply_color_correction(orig_image=req.init_image, image_to_correct=img)
images[i] = (img, seed, filtered)
return images
def apply_filters(args: dict, images: list, user_stopped, show_only_filtered_image):
if user_stopped or (args['use_face_correction'] is None and args['use_upscale'] is None):
def apply_filters(task_data: TaskData, images: list, user_stopped, show_only_filtered_image):
if user_stopped or (task_data.use_face_correction is None and task_data.use_upscale is None):
return images
filters = []
if 'gfpgan' in args['use_face_correction'].lower(): filters.append(image_filters.apply_gfpgan)
if 'realesrgan' in args['use_face_correction'].lower(): filters.append(image_filters.apply_realesrgan)
if task_data.use_face_correction and 'gfpgan' in task_data.use_face_correction.lower(): filters.append(image_filters.apply_gfpgan)
if task_data.use_upscale and 'realesrgan' in task_data.use_upscale.lower(): filters.append(image_filters.apply_realesrgan)
filtered_images = []
for img, seed, _ in images:
@ -188,37 +199,29 @@ def save_images(images: list, save_to_disk_path, metadata: dict, show_only_filte
img_path += '.' + metadata['output_format']
img.save(img_path, quality=metadata['output_quality'])
def construct_response(req: Request, images: list):
def construct_response(task_data: TaskData, images: list):
return [
ResponseImage(
data=image_utils.img_to_base64_str(img, req.output_format, req.output_quality),
data=image_utils.img_to_base64_str(img, task_data.output_format, task_data.output_quality),
seed=seed
) for img, seed, _ in images
]
def req_to_args(req: Request):
args = req.json()
args['init_image'] = image_utils.base64_str_to_img(req.init_image) if req.init_image is not None else None
args['mask'] = image_utils.base64_str_to_img(req.mask) if req.mask is not None else None
return args
def make_step_callback(args: dict, data_queue: queue.Queue, task_temp_images: list, step_callback, stream_image_progress: bool):
n_steps = args['num_inference_steps'] if args['init_image'] is None else int(args['num_inference_steps'] * args['prompt_strength'])
def make_step_callback(req: GenerateImageRequest, task_data: TaskData, data_queue: queue.Queue, task_temp_images: list, step_callback, stream_image_progress: bool):
n_steps = req.num_inference_steps if req.init_image is None else int(req.num_inference_steps * req.prompt_strength)
last_callback_time = -1
def update_temp_img(x_samples, task_temp_images: list):
partial_images = []
for i in range(args['num_outputs']):
for i in range(req.num_outputs):
img = image_utils.latent_to_img(thread_data, x_samples[i].unsqueeze(0))
buf = image_utils.img_to_buffer(img, output_format='JPEG')
del img
thread_data.temp_images[f"{args['request_id']}/{i}"] = buf
thread_data.temp_images[f"{task_data.request_id}/{i}"] = buf
task_temp_images[i] = buf
partial_images.append({'path': f"/image/tmp/{args['request_id']}/{i}"})
partial_images.append({'path': f"/image/tmp/{task_data.request_id}/{i}"})
return partial_images
def on_image_step(x_samples, i):
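
update_temp_img() above publishes one JPEG buffer per output under /image/tmp/<request_id>/<i> (the path it returns in partial_images). While stream_image_progress is enabled, a client could poll those paths; an illustrative sketch, not part of this commit:

import requests  # illustrative only; the endpoint path is taken from the code above

def fetch_partial_image(base_url: str, request_id: str, index: int) -> bytes:
    resp = requests.get(f"{base_url}/image/tmp/{request_id}/{index}")
    resp.raise_for_status()
    return resp.content  # raw JPEG bytes of the in-progress output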

View File

@ -14,8 +14,8 @@ import torch
import queue, threading, time, weakref
from typing import Any, Hashable
from pydantic import BaseModel
from sd_internal import Request, device_manager
from sd_internal import TaskData, device_manager
from modules.types import GenerateImageRequest
log = logging.getLogger()
@ -39,9 +39,10 @@ class ServerStates:
class Unavailable(Symbol): pass
class RenderTask(): # Task with output queue and completion lock.
def __init__(self, req: Request):
def __init__(self, req: GenerateImageRequest, task_data: TaskData):
task_data.request_id = id(self)
self.request: Request = req # Initial Request
self.render_request: GenerateImageRequest = req # Initial Request
self.task_data: TaskData = task_data
self.response: Any = None # Copy of the last response
self.render_device = None # Select the task affinity. (Not used to change active devices).
self.temp_images: list = [None] * req.num_outputs * (1 if task_data.show_only_filtered_image else 2)
@ -72,55 +73,6 @@ class RenderTask(): # Task with output queue and completion lock.
def is_pending(self):
return bool(not self.response and not self.error)
# defaults from https://huggingface.co/blog/stable_diffusion
class ImageRequest(BaseModel):
session_id: str = "session"
prompt: str = ""
negative_prompt: str = ""
init_image: str = None # base64
mask: str = None # base64
apply_color_correction: bool = False
num_outputs: int = 1
num_inference_steps: int = 50
guidance_scale: float = 7.5
width: int = 512
height: int = 512
seed: int = 42
prompt_strength: float = 0.8
sampler: str = None # "ddim", "plms", "heun", "euler", "euler_a", "dpm2", "dpm2_a", "lms"
# allow_nsfw: bool = False
save_to_disk_path: str = None
turbo: bool = True
use_cpu: bool = False ##TODO Remove after UI and plugins transition.
render_device: str = None # Select the task affinity. (Not used to change active devices).
use_full_precision: bool = False
use_face_correction: str = None # or "GFPGANv1.3"
use_upscale: str = None # or "RealESRGAN_x4plus" or "RealESRGAN_x4plus_anime_6B"
use_stable_diffusion_model: str = "sd-v1-4"
use_vae_model: str = None
use_hypernetwork_model: str = None
hypernetwork_strength: float = None
show_only_filtered_image: bool = False
output_format: str = "jpeg" # or "png"
output_quality: int = 75
stream_progress_updates: bool = False
stream_image_progress: bool = False
class FilterRequest(BaseModel):
session_id: str = "session"
model: str = None
name: str = ""
init_image: str = None # base64
width: int = 512
height: int = 512
save_to_disk_path: str = None
turbo: bool = True
render_device: str = None
use_full_precision: bool = False
output_format: str = "jpeg" # or "png"
output_quality: int = 75
# Temporary cache to allow to query tasks results for a short time after they are completed.
class DataCache():
def __init__(self):
@ -311,7 +263,7 @@ def thread_render(device):
task.response = {"status": 'failed', "detail": str(task.error)}
task.buffer_queue.put(json.dumps(task.response))
continue
log.info(f'Session {task.request.session_id} starting task {id(task)} on {runtime2.thread_data.device_name}')
log.info(f'Session {task.task_data.session_id} starting task {id(task)} on {runtime2.thread_data.device_name}')
if not task.lock.acquire(blocking=False): raise Exception('Got locked task from queue.')
try:
def step_callback():
@ -322,16 +274,16 @@ def thread_render(device):
if isinstance(current_state_error, StopAsyncIteration):
task.error = current_state_error
current_state_error = None
log.info(f'Session {task.request.session_id} sent cancel signal for task {id(task)}')
log.info(f'Session {task.task_data.session_id} sent cancel signal for task {id(task)}')
current_state = ServerStates.LoadingModel
runtime2.reload_models_if_necessary(task.request)
runtime2.reload_models_if_necessary(task.task_data)
current_state = ServerStates.Rendering
task.response = runtime2.make_images(task.request, task.buffer_queue, task.temp_images, step_callback)
task.response = runtime2.make_images(task.render_request, task.task_data, task.buffer_queue, task.temp_images, step_callback)
# Before looping back to the generator, mark cache as still alive.
task_cache.keep(id(task), TASK_TTL)
session_cache.keep(task.request.session_id, TASK_TTL)
session_cache.keep(task.task_data.session_id, TASK_TTL)
except Exception as e:
task.error = e
log.error(traceback.format_exc())
@ -340,13 +292,13 @@ def thread_render(device):
# Task completed
task.lock.release()
task_cache.keep(id(task), TASK_TTL)
session_cache.keep(task.request.session_id, TASK_TTL)
session_cache.keep(task.task_data.session_id, TASK_TTL)
if isinstance(task.error, StopAsyncIteration):
log.info(f'Session {task.request.session_id} task {id(task)} cancelled!')
log.info(f'Session {task.task_data.session_id} task {id(task)} cancelled!')
elif task.error is not None:
log.info(f'Session {task.request.session_id} task {id(task)} failed!')
log.info(f'Session {task.task_data.session_id} task {id(task)} failed!')
else:
log.info(f'Session {task.request.session_id} task {id(task)} completed by {runtime2.thread_data.device_name}.')
log.info(f'Session {task.task_data.session_id} task {id(task)} completed by {runtime2.thread_data.device_name}.')
current_state = ServerStates.Online
def get_cached_task(task_id:str, update_ttl:bool=False):
@ -509,53 +461,18 @@ def shutdown_event(): # Signal render thread to close on shutdown
global current_state_error
current_state_error = SystemExit('Application shutting down.')
def render(req : ImageRequest):
def render(render_req: GenerateImageRequest, task_data: TaskData):
current_thread_count = is_alive()
if current_thread_count <= 0: # Render thread is dead
raise ChildProcessError('Rendering thread has died.')
# Alive, check if task in cache
session = get_cached_session(req.session_id, update_ttl=True)
session = get_cached_session(task_data.session_id, update_ttl=True)
pending_tasks = list(filter(lambda t: t.is_pending, session.tasks))
if current_thread_count < len(pending_tasks):
raise ConnectionRefusedError(f'Session {req.session_id} already has {len(pending_tasks)} pending tasks out of {current_thread_count}.')
raise ConnectionRefusedError(f'Session {task_data.session_id} already has {len(pending_tasks)} pending tasks out of {current_thread_count}.')
r = Request()
r.session_id = req.session_id
r.prompt = req.prompt
r.negative_prompt = req.negative_prompt
r.init_image = req.init_image
r.mask = req.mask
r.apply_color_correction = req.apply_color_correction
r.num_outputs = req.num_outputs
r.num_inference_steps = req.num_inference_steps
r.guidance_scale = req.guidance_scale
r.width = req.width
r.height = req.height
r.seed = req.seed
r.prompt_strength = req.prompt_strength
r.sampler = req.sampler
# r.allow_nsfw = req.allow_nsfw
r.turbo = req.turbo
r.use_full_precision = req.use_full_precision
r.save_to_disk_path = req.save_to_disk_path
r.use_upscale: str = req.use_upscale
r.use_face_correction = req.use_face_correction
r.use_stable_diffusion_model = req.use_stable_diffusion_model
r.use_vae_model = req.use_vae_model
r.use_hypernetwork_model = req.use_hypernetwork_model
r.hypernetwork_strength = req.hypernetwork_strength
r.show_only_filtered_image = req.show_only_filtered_image
r.output_format = req.output_format
r.output_quality = req.output_quality
r.stream_progress_updates = True # the underlying implementation only supports streaming
r.stream_image_progress = req.stream_image_progress
if not req.stream_progress_updates:
r.stream_image_progress = False
new_task = RenderTask(r)
new_task = RenderTask(render_req, task_data)
if session.put(new_task, TASK_TTL):
# Use twice the normal timeout for adding user requests.
# Tries to force session.put to fail before tasks_queue.put would.

View File

@ -14,6 +14,8 @@ from starlette.responses import FileResponse, JSONResponse, StreamingResponse
from pydantic import BaseModel
from sd_internal import app, model_manager, task_manager
from sd_internal import TaskData
from modules.types import GenerateImageRequest
log = logging.getLogger()
@ -22,8 +24,6 @@ log.info(f'started at {datetime.datetime.now():%x %X}')
server_api = FastAPI()
# don't show access log entries for URLs that start with the given prefix
ACCESS_LOG_SUPPRESS_PATH_PREFIXES = ['/ping', '/image', '/modifier-thumbnails']
NOCACHE_HEADERS={"Cache-Control": "no-cache, no-store, must-revalidate", "Pragma": "no-cache", "Expires": "0"}
class NoCacheStaticFiles(StaticFiles):
@ -43,17 +43,6 @@ class SetAppConfigRequest(BaseModel):
listen_port: int = None
test_sd2: bool = None
class LogSuppressFilter(logging.Filter):
def filter(self, record: logging.LogRecord) -> bool:
path = record.getMessage()
for prefix in ACCESS_LOG_SUPPRESS_PATH_PREFIXES:
if path.find(prefix) != -1:
return False
return True
# don't log certain requests
logging.getLogger('uvicorn.access').addFilter(LogSuppressFilter())
server_api.mount('/media', NoCacheStaticFiles(directory=os.path.join(app.SD_UI_DIR, 'media')), name="media")
for plugins_dir, dir_prefix in app.UI_PLUGINS_SOURCES:
@ -137,20 +126,18 @@ def ping(session_id:str=None):
return JSONResponse(response, headers=NOCACHE_HEADERS)
@server_api.post('/render')
def render(req : task_manager.ImageRequest):
def render(req: dict):
try:
app.save_model_to_config(req.use_stable_diffusion_model, req.use_vae_model, req.use_hypernetwork_model)
# separate out the request data into rendering and task-specific data
render_req: GenerateImageRequest = GenerateImageRequest.parse_obj(req)
task_data: TaskData = TaskData.parse_obj(req)
# resolve the model paths to use
req.use_stable_diffusion_model = model_manager.resolve_model_to_use(req.use_stable_diffusion_model, model_type='stable-diffusion')
req.use_vae_model = model_manager.resolve_model_to_use(req.use_vae_model, model_type='vae')
req.use_hypernetwork_model = model_manager.resolve_model_to_use(req.use_hypernetwork_model, model_type='hypernetwork')
render_req.init_image_mask = req.get('mask') # hack: will rename this in the HTTP API in a future revision
if req.use_face_correction: req.use_face_correction = model_manager.resolve_model_to_use(req.use_face_correction, 'gfpgan')
if req.use_upscale: req.use_upscale = model_manager.resolve_model_to_use(req.use_upscale, 'gfpgan')
app.save_model_to_config(task_data.use_stable_diffusion_model, task_data.use_vae_model, task_data.use_hypernetwork_model)
# enqueue the task
new_task = task_manager.render(req)
new_task = task_manager.render(render_req, task_data)
response = {
'status': str(task_manager.current_state),
'queue': len(task_manager.tasks_queue),
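
Taken together: a client still posts one flat JSON object to /render, and the server now splits it into the two models itself. A hedged smoke-test sketch — the port is the project's usual default, assumed here, and the response keys follow the handler above:

import requests

payload = {
    # render-related fields -> GenerateImageRequest
    "prompt": "an astronaut riding a horse",
    "seed": 42,
    "width": 512,
    "height": 512,
    # task-related fields -> TaskData, sent in the same flat object
    "session_id": "demo",
    "output_format": "png",
    "stream_progress_updates": True,
}
r = requests.post("http://localhost:9000/render", json=payload)  # default port assumed
print(r.json()["status"])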