forked from extern/easydiffusion
Get rid of the ugly copying around (and maintaining) of multiple request-related fields. Split into two objects: task-related fields, and render-related fields. Also remove the ability for request-defined full-precision. Full-precision can now be forced by using a USE_FULL_PRECISION environment variable
This commit is contained in:
parent
d03eed3859
commit
6ce6dc3ff6
@ -37,7 +37,6 @@ const SETTINGS_IDS_LIST = [
|
||||
"diskPath",
|
||||
"sound_toggle",
|
||||
"turbo",
|
||||
"use_full_precision",
|
||||
"confirm_dangerous_actions",
|
||||
"auto_save_settings",
|
||||
"apply_color_correction"
|
||||
@ -278,7 +277,6 @@ function tryLoadOldSettings() {
|
||||
"soundEnabled": "sound_toggle",
|
||||
"saveToDisk": "save_to_disk",
|
||||
"useCPU": "use_cpu",
|
||||
"useFullPrecision": "use_full_precision",
|
||||
"useTurboMode": "turbo",
|
||||
"diskPath": "diskPath",
|
||||
"useFaceCorrection": "use_face_correction",
|
||||
|
@ -217,13 +217,6 @@ const TASK_MAPPING = {
|
||||
readUI: () => turboField.checked,
|
||||
parse: (val) => Boolean(val)
|
||||
},
|
||||
use_full_precision: { name: 'Use Full Precision',
|
||||
setUI: (use_full_precision) => {
|
||||
useFullPrecisionField.checked = use_full_precision
|
||||
},
|
||||
readUI: () => useFullPrecisionField.checked,
|
||||
parse: (val) => Boolean(val)
|
||||
},
|
||||
|
||||
stream_image_progress: { name: 'Stream Image Progress',
|
||||
setUI: (stream_image_progress) => {
|
||||
@ -453,7 +446,6 @@ document.addEventListener("dragover", dragOverHandler)
|
||||
const TASK_REQ_NO_EXPORT = [
|
||||
"use_cpu",
|
||||
"turbo",
|
||||
"use_full_precision",
|
||||
"save_to_disk_path"
|
||||
]
|
||||
const resetSettings = document.getElementById('reset-image-settings')
|
||||
|
@ -728,7 +728,6 @@
|
||||
"stream_image_progress": 'boolean',
|
||||
"show_only_filtered_image": 'boolean',
|
||||
"turbo": 'boolean',
|
||||
"use_full_precision": 'boolean',
|
||||
"output_format": 'string',
|
||||
"output_quality": 'number',
|
||||
}
|
||||
@ -744,7 +743,6 @@
|
||||
"stream_image_progress": true,
|
||||
"show_only_filtered_image": true,
|
||||
"turbo": false,
|
||||
"use_full_precision": false,
|
||||
"output_format": "png",
|
||||
"output_quality": 75,
|
||||
}
|
||||
|
@ -850,7 +850,6 @@ function getCurrentUserRequest() {
|
||||
// allow_nsfw: allowNSFWField.checked,
|
||||
turbo: turboField.checked,
|
||||
//render_device: undefined, // Set device affinity. Prefer this device, but wont activate.
|
||||
use_full_precision: useFullPrecisionField.checked,
|
||||
use_stable_diffusion_model: stableDiffusionModelField.value,
|
||||
use_vae_model: vaeModelField.value,
|
||||
stream_progress_updates: true,
|
||||
|
@ -105,14 +105,6 @@ var PARAMETERS = [
|
||||
note: "to process in parallel",
|
||||
default: false,
|
||||
},
|
||||
{
|
||||
id: "use_full_precision",
|
||||
type: ParameterType.checkbox,
|
||||
label: "Use Full Precision",
|
||||
note: "for GPU-only. warning: this will consume more VRAM",
|
||||
icon: "fa-crosshairs",
|
||||
default: false,
|
||||
},
|
||||
{
|
||||
id: "auto_save_settings",
|
||||
type: ParameterType.checkbox,
|
||||
@ -214,7 +206,6 @@ let turboField = document.querySelector('#turbo')
|
||||
let useCPUField = document.querySelector('#use_cpu')
|
||||
let autoPickGPUsField = document.querySelector('#auto_pick_gpus')
|
||||
let useGPUsField = document.querySelector('#use_gpus')
|
||||
let useFullPrecisionField = document.querySelector('#use_full_precision')
|
||||
let saveToDiskField = document.querySelector('#save_to_disk')
|
||||
let diskPathField = document.querySelector('#diskPath')
|
||||
let listenToNetworkField = document.querySelector("#listen_to_network")
|
||||
|
@ -1,93 +1,22 @@
|
||||
import json
|
||||
from pydantic import BaseModel
|
||||
|
||||
class Request:
|
||||
from modules.types import GenerateImageRequest
|
||||
|
||||
class TaskData(BaseModel):
|
||||
request_id: str = None
|
||||
session_id: str = "session"
|
||||
prompt: str = ""
|
||||
negative_prompt: str = ""
|
||||
init_image: str = None # base64
|
||||
mask: str = None # base64
|
||||
apply_color_correction = False
|
||||
num_outputs: int = 1
|
||||
num_inference_steps: int = 50
|
||||
guidance_scale: float = 7.5
|
||||
width: int = 512
|
||||
height: int = 512
|
||||
seed: int = 42
|
||||
prompt_strength: float = 0.8
|
||||
sampler: str = None # "ddim", "plms", "heun", "euler", "euler_a", "dpm2", "dpm2_a", "lms"
|
||||
# allow_nsfw: bool = False
|
||||
precision: str = "autocast" # or "full"
|
||||
save_to_disk_path: str = None
|
||||
turbo: bool = True
|
||||
use_full_precision: bool = False
|
||||
use_face_correction: str = None # or "GFPGANv1.3"
|
||||
use_upscale: str = None # or "RealESRGAN_x4plus" or "RealESRGAN_x4plus_anime_6B"
|
||||
use_stable_diffusion_model: str = "sd-v1-4"
|
||||
use_vae_model: str = None
|
||||
use_hypernetwork_model: str = None
|
||||
hypernetwork_strength: float = 1
|
||||
show_only_filtered_image: bool = False
|
||||
output_format: str = "jpeg" # or "png"
|
||||
output_quality: int = 75
|
||||
|
||||
stream_progress_updates: bool = False
|
||||
stream_image_progress: bool = False
|
||||
|
||||
def json(self):
|
||||
return {
|
||||
"request_id": self.request_id,
|
||||
"session_id": self.session_id,
|
||||
"prompt": self.prompt,
|
||||
"negative_prompt": self.negative_prompt,
|
||||
"num_outputs": self.num_outputs,
|
||||
"num_inference_steps": self.num_inference_steps,
|
||||
"guidance_scale": self.guidance_scale,
|
||||
"width": self.width,
|
||||
"height": self.height,
|
||||
"seed": self.seed,
|
||||
"prompt_strength": self.prompt_strength,
|
||||
"sampler": self.sampler,
|
||||
"apply_color_correction": self.apply_color_correction,
|
||||
"use_face_correction": self.use_face_correction,
|
||||
"use_upscale": self.use_upscale,
|
||||
"use_stable_diffusion_model": self.use_stable_diffusion_model,
|
||||
"use_vae_model": self.use_vae_model,
|
||||
"use_hypernetwork_model": self.use_hypernetwork_model,
|
||||
"hypernetwork_strength": self.hypernetwork_strength,
|
||||
"output_format": self.output_format,
|
||||
"output_quality": self.output_quality,
|
||||
}
|
||||
|
||||
def __str__(self):
|
||||
return f'''
|
||||
session_id: {self.session_id}
|
||||
prompt: {self.prompt}
|
||||
negative_prompt: {self.negative_prompt}
|
||||
seed: {self.seed}
|
||||
num_inference_steps: {self.num_inference_steps}
|
||||
sampler: {self.sampler}
|
||||
guidance_scale: {self.guidance_scale}
|
||||
w: {self.width}
|
||||
h: {self.height}
|
||||
precision: {self.precision}
|
||||
save_to_disk_path: {self.save_to_disk_path}
|
||||
turbo: {self.turbo}
|
||||
use_full_precision: {self.use_full_precision}
|
||||
apply_color_correction: {self.apply_color_correction}
|
||||
use_face_correction: {self.use_face_correction}
|
||||
use_upscale: {self.use_upscale}
|
||||
use_stable_diffusion_model: {self.use_stable_diffusion_model}
|
||||
use_vae_model: {self.use_vae_model}
|
||||
use_hypernetwork_model: {self.use_hypernetwork_model}
|
||||
hypernetwork_strength: {self.hypernetwork_strength}
|
||||
show_only_filtered_image: {self.show_only_filtered_image}
|
||||
output_format: {self.output_format}
|
||||
output_quality: {self.output_quality}
|
||||
|
||||
stream_progress_updates: {self.stream_progress_updates}
|
||||
stream_image_progress: {self.stream_image_progress}'''
|
||||
|
||||
class Image:
|
||||
data: str # base64
|
||||
seed: int
|
||||
@ -106,17 +35,20 @@ class Image:
|
||||
}
|
||||
|
||||
class Response:
|
||||
request: Request
|
||||
render_request: GenerateImageRequest
|
||||
task_data: TaskData
|
||||
images: list
|
||||
|
||||
def __init__(self, request: Request, images: list):
|
||||
self.request = request
|
||||
def __init__(self, render_request: GenerateImageRequest, task_data: TaskData, images: list):
|
||||
self.render_request = render_request
|
||||
self.task_data = task_data
|
||||
self.images = images
|
||||
|
||||
def json(self):
|
||||
res = {
|
||||
"status": 'succeeded',
|
||||
"request": self.request.json(),
|
||||
"render_request": self.render_request.dict(),
|
||||
"task_data": self.task_data.dict(),
|
||||
"output": [],
|
||||
}
|
||||
|
||||
|
@ -6,6 +6,13 @@ import logging
|
||||
|
||||
log = logging.getLogger()
|
||||
|
||||
'''
|
||||
Set `FORCE_FULL_PRECISION` in the environment variables, or in `config.bat`/`config.sh` to set full precision (i.e. float32).
|
||||
Otherwise the models will load at half-precision (i.e. float16).
|
||||
|
||||
Half-precision is fine most of the time. Full precision is only needed for working around GPU bugs (like NVIDIA 16xx GPUs).
|
||||
'''
|
||||
|
||||
COMPARABLE_GPU_PERCENTILE = 0.65 # if a GPU's free_mem is within this % of the GPU with the most free_mem, it will be picked
|
||||
|
||||
mem_free_threshold = 0
|
||||
@ -96,7 +103,7 @@ def device_init(context, device):
|
||||
if device == 'cpu':
|
||||
context.device = 'cpu'
|
||||
context.device_name = get_processor_name()
|
||||
context.precision = 'full'
|
||||
context.half_precision = False
|
||||
log.debug(f'Render device CPU available as {context.device_name}')
|
||||
return
|
||||
|
||||
@ -107,7 +114,7 @@ def device_init(context, device):
|
||||
if needs_to_force_full_precision(context):
|
||||
log.warn(f'forcing full precision on this GPU, to avoid green images. GPU detected: {context.device_name}')
|
||||
# Apply force_full_precision now before models are loaded.
|
||||
context.precision = 'full'
|
||||
context.half_precision = False
|
||||
|
||||
log.info(f'Setting {device} as active')
|
||||
torch.cuda.device(device)
|
||||
@ -115,6 +122,9 @@ def device_init(context, device):
|
||||
return
|
||||
|
||||
def needs_to_force_full_precision(context):
|
||||
if 'FORCE_FULL_PRECISION' in os.environ:
|
||||
return True
|
||||
|
||||
device_name = context.device_name.lower()
|
||||
return (('nvidia' in device_name or 'geforce' in device_name) and (' 1660' in device_name or ' 1650' in device_name)) or ('Quadro T2000' in device_name)
|
||||
|
||||
|
@ -3,8 +3,7 @@ import logging
|
||||
import picklescan.scanner
|
||||
import rich
|
||||
|
||||
from sd_internal import app, device_manager
|
||||
from sd_internal import Request
|
||||
from sd_internal import app
|
||||
|
||||
log = logging.getLogger()
|
||||
|
||||
@ -157,19 +156,3 @@ def getModels():
|
||||
models['options']['stable-diffusion'].append('custom-model')
|
||||
|
||||
return models
|
||||
|
||||
def is_sd_model_reload_necessary(thread_data, req: Request):
|
||||
needs_model_reload = False
|
||||
if 'stable-diffusion' not in thread_data.models or \
|
||||
thread_data.model_paths['stable-diffusion'] != req.use_stable_diffusion_model or \
|
||||
thread_data.model_paths['vae'] != req.use_vae_model:
|
||||
|
||||
needs_model_reload = True
|
||||
|
||||
if thread_data.device != 'cpu':
|
||||
if (thread_data.precision == 'autocast' and req.use_full_precision) or \
|
||||
(thread_data.precision == 'full' and not req.use_full_precision and not device_manager.needs_to_force_full_precision(thread_data)):
|
||||
thread_data.precision = 'full' if req.use_full_precision else 'autocast'
|
||||
needs_model_reload = True
|
||||
|
||||
return needs_model_reload
|
||||
|
@ -9,13 +9,14 @@ import traceback
|
||||
import logging
|
||||
|
||||
from sd_internal import device_manager, model_manager
|
||||
from sd_internal import Request, Response, Image as ResponseImage, UserInitiatedStop
|
||||
from sd_internal import TaskData, Response, Image as ResponseImage, UserInitiatedStop
|
||||
|
||||
from modules import model_loader, image_generator, image_utils, filters as image_filters
|
||||
from modules.types import Context, GenerateImageRequest
|
||||
|
||||
log = logging.getLogger()
|
||||
|
||||
thread_data = threading.local()
|
||||
thread_data = Context()
|
||||
'''
|
||||
runtime data (bound locally to this thread), for e.g. device, references to loaded models, optimization flags etc
|
||||
'''
|
||||
@ -28,14 +29,7 @@ def init(device):
|
||||
'''
|
||||
thread_data.stop_processing = False
|
||||
thread_data.temp_images = {}
|
||||
|
||||
thread_data.models = {}
|
||||
thread_data.model_paths = {}
|
||||
|
||||
thread_data.device = None
|
||||
thread_data.device_name = None
|
||||
thread_data.precision = 'autocast'
|
||||
thread_data.vram_optimizations = ('TURBO', 'MOVE_MODELS')
|
||||
thread_data.partial_x_samples = None
|
||||
|
||||
device_manager.device_init(thread_data, device)
|
||||
|
||||
@ -51,16 +45,16 @@ def load_default_models():
|
||||
# load mandatory models
|
||||
model_loader.load_model(thread_data, 'stable-diffusion')
|
||||
|
||||
def reload_models_if_necessary(req: Request):
|
||||
def reload_models_if_necessary(task_data: TaskData):
|
||||
model_paths_in_req = (
|
||||
('hypernetwork', req.use_hypernetwork_model),
|
||||
('gfpgan', req.use_face_correction),
|
||||
('realesrgan', req.use_upscale),
|
||||
('hypernetwork', task_data.use_hypernetwork_model),
|
||||
('gfpgan', task_data.use_face_correction),
|
||||
('realesrgan', task_data.use_upscale),
|
||||
)
|
||||
|
||||
if model_manager.is_sd_model_reload_necessary(thread_data, req):
|
||||
thread_data.model_paths['stable-diffusion'] = req.use_stable_diffusion_model
|
||||
thread_data.model_paths['vae'] = req.use_vae_model
|
||||
if thread_data.model_paths.get('stable-diffusion') != task_data.use_stable_diffusion_model or thread_data.model_paths.get('vae') != task_data.use_vae_model:
|
||||
thread_data.model_paths['stable-diffusion'] = task_data.use_stable_diffusion_model
|
||||
thread_data.model_paths['vae'] = task_data.use_vae_model
|
||||
|
||||
model_loader.load_model(thread_data, 'stable-diffusion')
|
||||
|
||||
@ -73,10 +67,17 @@ def reload_models_if_necessary(req: Request):
|
||||
else:
|
||||
model_loader.unload_model(thread_data, model_type)
|
||||
|
||||
def make_images(req: Request, data_queue: queue.Queue, task_temp_images: list, step_callback):
|
||||
def make_images(req: GenerateImageRequest, task_data: TaskData, data_queue: queue.Queue, task_temp_images: list, step_callback):
|
||||
try:
|
||||
log.info(req)
|
||||
return _make_images_internal(req, data_queue, task_temp_images, step_callback)
|
||||
# resolve the model paths to use
|
||||
resolve_model_paths(task_data)
|
||||
|
||||
# convert init image to PIL.Image
|
||||
req.init_image = image_utils.base64_str_to_img(req.init_image) if req.init_image is not None else None
|
||||
req.init_image_mask = image_utils.base64_str_to_img(req.init_image_mask) if req.init_image_mask is not None else None
|
||||
|
||||
# generate
|
||||
return _make_images_internal(req, task_data, data_queue, task_temp_images, step_callback)
|
||||
except Exception as e:
|
||||
log.error(traceback.format_exc())
|
||||
|
||||
@ -86,66 +87,76 @@ def make_images(req: Request, data_queue: queue.Queue, task_temp_images: list, s
|
||||
}))
|
||||
raise e
|
||||
|
||||
def _make_images_internal(req: Request, data_queue: queue.Queue, task_temp_images: list, step_callback):
|
||||
args = req_to_args(req)
|
||||
def _make_images_internal(req: GenerateImageRequest, task_data: TaskData, data_queue: queue.Queue, task_temp_images: list, step_callback):
|
||||
metadata = req.dict()
|
||||
del metadata['init_image']
|
||||
del metadata['init_image_mask']
|
||||
print(metadata)
|
||||
|
||||
images, user_stopped = generate_images(args, data_queue, task_temp_images, step_callback, req.stream_image_progress)
|
||||
images = apply_color_correction(args, images, user_stopped)
|
||||
images = apply_filters(args, images, user_stopped, req.show_only_filtered_image)
|
||||
images, user_stopped = generate_images(req, data_queue, task_temp_images, step_callback, task_data.stream_image_progress)
|
||||
images = apply_color_correction(req, images, user_stopped)
|
||||
images = apply_filters(task_data, images, user_stopped, task_data.show_only_filtered_image)
|
||||
|
||||
if req.save_to_disk_path is not None:
|
||||
out_path = os.path.join(req.save_to_disk_path, filename_regex.sub('_', req.session_id))
|
||||
save_images(images, out_path, metadata=req.json(), show_only_filtered_image=req.show_only_filtered_image)
|
||||
if task_data.save_to_disk_path is not None:
|
||||
out_path = os.path.join(task_data.save_to_disk_path, filename_regex.sub('_', task_data.session_id))
|
||||
save_images(images, out_path, metadata=metadata, show_only_filtered_image=task_data.show_only_filtered_image)
|
||||
|
||||
res = Response(req, images=construct_response(req, images))
|
||||
res = Response(req, task_data, images=construct_response(images))
|
||||
res = res.json()
|
||||
data_queue.put(json.dumps(res))
|
||||
log.info('Task completed')
|
||||
|
||||
return res
|
||||
|
||||
def generate_images(args: dict, data_queue: queue.Queue, task_temp_images: list, step_callback, stream_image_progress: bool):
|
||||
def resolve_model_paths(task_data: TaskData):
|
||||
task_data.use_stable_diffusion_model = model_manager.resolve_model_to_use(task_data.use_stable_diffusion_model, model_type='stable-diffusion')
|
||||
task_data.use_vae_model = model_manager.resolve_model_to_use(task_data.use_vae_model, model_type='vae')
|
||||
task_data.use_hypernetwork_model = model_manager.resolve_model_to_use(task_data.use_hypernetwork_model, model_type='hypernetwork')
|
||||
|
||||
if task_data.use_face_correction: task_data.use_face_correction = model_manager.resolve_model_to_use(task_data.use_face_correction, 'gfpgan')
|
||||
if task_data.use_upscale: task_data.use_upscale = model_manager.resolve_model_to_use(task_data.use_upscale, 'gfpgan')
|
||||
|
||||
def generate_images(req: GenerateImageRequest, data_queue: queue.Queue, task_temp_images: list, step_callback, stream_image_progress: bool):
|
||||
thread_data.temp_images.clear()
|
||||
|
||||
image_generator.on_image_step = make_step_callback(args, data_queue, task_temp_images, step_callback, stream_image_progress)
|
||||
image_generator.on_image_step = make_step_callback(req, data_queue, task_temp_images, step_callback, stream_image_progress)
|
||||
|
||||
try:
|
||||
images = image_generator.make_images(context=thread_data, args=args)
|
||||
images = image_generator.make_images(context=thread_data, req=req)
|
||||
user_stopped = False
|
||||
except UserInitiatedStop:
|
||||
images = []
|
||||
user_stopped = True
|
||||
if not hasattr(thread_data, 'partial_x_samples') or thread_data.partial_x_samples is None:
|
||||
return images
|
||||
for i in range(args['num_outputs']):
|
||||
if thread_data.partial_x_samples is not None:
|
||||
for i in range(req.num_outputs):
|
||||
images[i] = image_utils.latent_to_img(thread_data, thread_data.partial_x_samples[i].unsqueeze(0))
|
||||
|
||||
del thread_data.partial_x_samples
|
||||
thread_data.partial_x_samples = None
|
||||
finally:
|
||||
model_loader.gc(thread_data)
|
||||
|
||||
images = [(image, args['seed'] + i, False) for i, image in enumerate(images)]
|
||||
images = [(image, req.seed + i, False) for i, image in enumerate(images)]
|
||||
|
||||
return images, user_stopped
|
||||
|
||||
def apply_color_correction(args: dict, images: list, user_stopped):
|
||||
if user_stopped or args['init_image'] is None or not args['apply_color_correction']:
|
||||
def apply_color_correction(req: GenerateImageRequest, images: list, user_stopped):
|
||||
if user_stopped or req.init_image is None or not req.apply_color_correction:
|
||||
return images
|
||||
|
||||
for i, img_info in enumerate(images):
|
||||
img, seed, filtered = img_info
|
||||
img = image_utils.apply_color_correction(orig_image=args['init_image'], image_to_correct=img)
|
||||
img = image_utils.apply_color_correction(orig_image=req.init_image, image_to_correct=img)
|
||||
images[i] = (img, seed, filtered)
|
||||
|
||||
return images
|
||||
|
||||
def apply_filters(args: dict, images: list, user_stopped, show_only_filtered_image):
|
||||
if user_stopped or (args['use_face_correction'] is None and args['use_upscale'] is None):
|
||||
def apply_filters(task_data: TaskData, images: list, user_stopped, show_only_filtered_image):
|
||||
if user_stopped or (task_data.use_face_correction is None and task_data.use_upscale is None):
|
||||
return images
|
||||
|
||||
filters = []
|
||||
if 'gfpgan' in args['use_face_correction'].lower(): filters.append(image_filters.apply_gfpgan)
|
||||
if 'realesrgan' in args['use_face_correction'].lower(): filters.append(image_filters.apply_realesrgan)
|
||||
if 'gfpgan' in task_data.use_face_correction.lower(): filters.append(image_filters.apply_gfpgan)
|
||||
if 'realesrgan' in task_data.use_face_correction.lower(): filters.append(image_filters.apply_realesrgan)
|
||||
|
||||
filtered_images = []
|
||||
for img, seed, _ in images:
|
||||
@ -188,37 +199,29 @@ def save_images(images: list, save_to_disk_path, metadata: dict, show_only_filte
|
||||
img_path += '.' + metadata['output_format']
|
||||
img.save(img_path, quality=metadata['output_quality'])
|
||||
|
||||
def construct_response(req: Request, images: list):
|
||||
def construct_response(task_data: TaskData, images: list):
|
||||
return [
|
||||
ResponseImage(
|
||||
data=image_utils.img_to_base64_str(img, req.output_format, req.output_quality),
|
||||
data=image_utils.img_to_base64_str(img, task_data.output_format, task_data.output_quality),
|
||||
seed=seed
|
||||
) for img, seed, _ in images
|
||||
]
|
||||
|
||||
def req_to_args(req: Request):
|
||||
args = req.json()
|
||||
|
||||
args['init_image'] = image_utils.base64_str_to_img(req.init_image) if req.init_image is not None else None
|
||||
args['mask'] = image_utils.base64_str_to_img(req.mask) if req.mask is not None else None
|
||||
|
||||
return args
|
||||
|
||||
def make_step_callback(args: dict, data_queue: queue.Queue, task_temp_images: list, step_callback, stream_image_progress: bool):
|
||||
n_steps = args['num_inference_steps'] if args['init_image'] is None else int(args['num_inference_steps'] * args['prompt_strength'])
|
||||
def make_step_callback(req: GenerateImageRequest, task_data: TaskData, data_queue: queue.Queue, task_temp_images: list, step_callback, stream_image_progress: bool):
|
||||
n_steps = req.num_inference_steps if req.init_image is None else int(req.num_inference_steps * req.prompt_strength)
|
||||
last_callback_time = -1
|
||||
|
||||
def update_temp_img(x_samples, task_temp_images: list):
|
||||
partial_images = []
|
||||
for i in range(args['num_outputs']):
|
||||
for i in range(req.num_outputs):
|
||||
img = image_utils.latent_to_img(thread_data, x_samples[i].unsqueeze(0))
|
||||
buf = image_utils.img_to_buffer(img, output_format='JPEG')
|
||||
|
||||
del img
|
||||
|
||||
thread_data.temp_images[f"{args['request_id']}/{i}"] = buf
|
||||
thread_data.temp_images[f"{task_data.request_id}/{i}"] = buf
|
||||
task_temp_images[i] = buf
|
||||
partial_images.append({'path': f"/image/tmp/{args['request_id']}/{i}"})
|
||||
partial_images.append({'path': f"/image/tmp/{task_data.request_id}/{i}"})
|
||||
return partial_images
|
||||
|
||||
def on_image_step(x_samples, i):
|
||||
|
@ -14,8 +14,8 @@ import torch
|
||||
import queue, threading, time, weakref
|
||||
from typing import Any, Hashable
|
||||
|
||||
from pydantic import BaseModel
|
||||
from sd_internal import Request, device_manager
|
||||
from sd_internal import TaskData, device_manager
|
||||
from modules.types import GenerateImageRequest
|
||||
|
||||
log = logging.getLogger()
|
||||
|
||||
@ -39,9 +39,10 @@ class ServerStates:
|
||||
class Unavailable(Symbol): pass
|
||||
|
||||
class RenderTask(): # Task with output queue and completion lock.
|
||||
def __init__(self, req: Request):
|
||||
def __init__(self, req: GenerateImageRequest, task_data: TaskData):
|
||||
req.request_id = id(self)
|
||||
self.request: Request = req # Initial Request
|
||||
self.render_request: GenerateImageRequest = req # Initial Request
|
||||
self.task_data: TaskData = task_data
|
||||
self.response: Any = None # Copy of the last reponse
|
||||
self.render_device = None # Select the task affinity. (Not used to change active devices).
|
||||
self.temp_images:list = [None] * req.num_outputs * (1 if req.show_only_filtered_image else 2)
|
||||
@ -72,55 +73,6 @@ class RenderTask(): # Task with output queue and completion lock.
|
||||
def is_pending(self):
|
||||
return bool(not self.response and not self.error)
|
||||
|
||||
# defaults from https://huggingface.co/blog/stable_diffusion
|
||||
class ImageRequest(BaseModel):
|
||||
session_id: str = "session"
|
||||
prompt: str = ""
|
||||
negative_prompt: str = ""
|
||||
init_image: str = None # base64
|
||||
mask: str = None # base64
|
||||
apply_color_correction: bool = False
|
||||
num_outputs: int = 1
|
||||
num_inference_steps: int = 50
|
||||
guidance_scale: float = 7.5
|
||||
width: int = 512
|
||||
height: int = 512
|
||||
seed: int = 42
|
||||
prompt_strength: float = 0.8
|
||||
sampler: str = None # "ddim", "plms", "heun", "euler", "euler_a", "dpm2", "dpm2_a", "lms"
|
||||
# allow_nsfw: bool = False
|
||||
save_to_disk_path: str = None
|
||||
turbo: bool = True
|
||||
use_cpu: bool = False ##TODO Remove after UI and plugins transition.
|
||||
render_device: str = None # Select the task affinity. (Not used to change active devices).
|
||||
use_full_precision: bool = False
|
||||
use_face_correction: str = None # or "GFPGANv1.3"
|
||||
use_upscale: str = None # or "RealESRGAN_x4plus" or "RealESRGAN_x4plus_anime_6B"
|
||||
use_stable_diffusion_model: str = "sd-v1-4"
|
||||
use_vae_model: str = None
|
||||
use_hypernetwork_model: str = None
|
||||
hypernetwork_strength: float = None
|
||||
show_only_filtered_image: bool = False
|
||||
output_format: str = "jpeg" # or "png"
|
||||
output_quality: int = 75
|
||||
|
||||
stream_progress_updates: bool = False
|
||||
stream_image_progress: bool = False
|
||||
|
||||
class FilterRequest(BaseModel):
|
||||
session_id: str = "session"
|
||||
model: str = None
|
||||
name: str = ""
|
||||
init_image: str = None # base64
|
||||
width: int = 512
|
||||
height: int = 512
|
||||
save_to_disk_path: str = None
|
||||
turbo: bool = True
|
||||
render_device: str = None
|
||||
use_full_precision: bool = False
|
||||
output_format: str = "jpeg" # or "png"
|
||||
output_quality: int = 75
|
||||
|
||||
# Temporary cache to allow to query tasks results for a short time after they are completed.
|
||||
class DataCache():
|
||||
def __init__(self):
|
||||
@ -311,7 +263,7 @@ def thread_render(device):
|
||||
task.response = {"status": 'failed', "detail": str(task.error)}
|
||||
task.buffer_queue.put(json.dumps(task.response))
|
||||
continue
|
||||
log.info(f'Session {task.request.session_id} starting task {id(task)} on {runtime2.thread_data.device_name}')
|
||||
log.info(f'Session {task.task_data.session_id} starting task {id(task)} on {runtime2.thread_data.device_name}')
|
||||
if not task.lock.acquire(blocking=False): raise Exception('Got locked task from queue.')
|
||||
try:
|
||||
def step_callback():
|
||||
@ -322,16 +274,16 @@ def thread_render(device):
|
||||
if isinstance(current_state_error, StopAsyncIteration):
|
||||
task.error = current_state_error
|
||||
current_state_error = None
|
||||
log.info(f'Session {task.request.session_id} sent cancel signal for task {id(task)}')
|
||||
log.info(f'Session {task.task_data.session_id} sent cancel signal for task {id(task)}')
|
||||
|
||||
current_state = ServerStates.LoadingModel
|
||||
runtime2.reload_models_if_necessary(task.request)
|
||||
runtime2.reload_models_if_necessary(task.task_data)
|
||||
|
||||
current_state = ServerStates.Rendering
|
||||
task.response = runtime2.make_images(task.request, task.buffer_queue, task.temp_images, step_callback)
|
||||
task.response = runtime2.make_images(task.render_request, task.task_data, task.buffer_queue, task.temp_images, step_callback)
|
||||
# Before looping back to the generator, mark cache as still alive.
|
||||
task_cache.keep(id(task), TASK_TTL)
|
||||
session_cache.keep(task.request.session_id, TASK_TTL)
|
||||
session_cache.keep(task.task_data.session_id, TASK_TTL)
|
||||
except Exception as e:
|
||||
task.error = e
|
||||
log.error(traceback.format_exc())
|
||||
@ -340,13 +292,13 @@ def thread_render(device):
|
||||
# Task completed
|
||||
task.lock.release()
|
||||
task_cache.keep(id(task), TASK_TTL)
|
||||
session_cache.keep(task.request.session_id, TASK_TTL)
|
||||
session_cache.keep(task.task_data.session_id, TASK_TTL)
|
||||
if isinstance(task.error, StopAsyncIteration):
|
||||
log.info(f'Session {task.request.session_id} task {id(task)} cancelled!')
|
||||
log.info(f'Session {task.task_data.session_id} task {id(task)} cancelled!')
|
||||
elif task.error is not None:
|
||||
log.info(f'Session {task.request.session_id} task {id(task)} failed!')
|
||||
log.info(f'Session {task.task_data.session_id} task {id(task)} failed!')
|
||||
else:
|
||||
log.info(f'Session {task.request.session_id} task {id(task)} completed by {runtime2.thread_data.device_name}.')
|
||||
log.info(f'Session {task.task_data.session_id} task {id(task)} completed by {runtime2.thread_data.device_name}.')
|
||||
current_state = ServerStates.Online
|
||||
|
||||
def get_cached_task(task_id:str, update_ttl:bool=False):
|
||||
@ -509,53 +461,18 @@ def shutdown_event(): # Signal render thread to close on shutdown
|
||||
global current_state_error
|
||||
current_state_error = SystemExit('Application shutting down.')
|
||||
|
||||
def render(req : ImageRequest):
|
||||
def render(render_req: GenerateImageRequest, task_data: TaskData):
|
||||
current_thread_count = is_alive()
|
||||
if current_thread_count <= 0: # Render thread is dead
|
||||
raise ChildProcessError('Rendering thread has died.')
|
||||
|
||||
# Alive, check if task in cache
|
||||
session = get_cached_session(req.session_id, update_ttl=True)
|
||||
session = get_cached_session(task_data.session_id, update_ttl=True)
|
||||
pending_tasks = list(filter(lambda t: t.is_pending, session.tasks))
|
||||
if current_thread_count < len(pending_tasks):
|
||||
raise ConnectionRefusedError(f'Session {req.session_id} already has {len(pending_tasks)} pending tasks out of {current_thread_count}.')
|
||||
raise ConnectionRefusedError(f'Session {task_data.session_id} already has {len(pending_tasks)} pending tasks out of {current_thread_count}.')
|
||||
|
||||
r = Request()
|
||||
r.session_id = req.session_id
|
||||
r.prompt = req.prompt
|
||||
r.negative_prompt = req.negative_prompt
|
||||
r.init_image = req.init_image
|
||||
r.mask = req.mask
|
||||
r.apply_color_correction = req.apply_color_correction
|
||||
r.num_outputs = req.num_outputs
|
||||
r.num_inference_steps = req.num_inference_steps
|
||||
r.guidance_scale = req.guidance_scale
|
||||
r.width = req.width
|
||||
r.height = req.height
|
||||
r.seed = req.seed
|
||||
r.prompt_strength = req.prompt_strength
|
||||
r.sampler = req.sampler
|
||||
# r.allow_nsfw = req.allow_nsfw
|
||||
r.turbo = req.turbo
|
||||
r.use_full_precision = req.use_full_precision
|
||||
r.save_to_disk_path = req.save_to_disk_path
|
||||
r.use_upscale: str = req.use_upscale
|
||||
r.use_face_correction = req.use_face_correction
|
||||
r.use_stable_diffusion_model = req.use_stable_diffusion_model
|
||||
r.use_vae_model = req.use_vae_model
|
||||
r.use_hypernetwork_model = req.use_hypernetwork_model
|
||||
r.hypernetwork_strength = req.hypernetwork_strength
|
||||
r.show_only_filtered_image = req.show_only_filtered_image
|
||||
r.output_format = req.output_format
|
||||
r.output_quality = req.output_quality
|
||||
|
||||
r.stream_progress_updates = True # the underlying implementation only supports streaming
|
||||
r.stream_image_progress = req.stream_image_progress
|
||||
|
||||
if not req.stream_progress_updates:
|
||||
r.stream_image_progress = False
|
||||
|
||||
new_task = RenderTask(r)
|
||||
new_task = RenderTask(render_req, task_data)
|
||||
if session.put(new_task, TASK_TTL):
|
||||
# Use twice the normal timeout for adding user requests.
|
||||
# Tries to force session.put to fail before tasks_queue.put would.
|
||||
|
31
ui/server.py
31
ui/server.py
@ -14,6 +14,8 @@ from starlette.responses import FileResponse, JSONResponse, StreamingResponse
|
||||
from pydantic import BaseModel
|
||||
|
||||
from sd_internal import app, model_manager, task_manager
|
||||
from sd_internal import TaskData
|
||||
from modules.types import GenerateImageRequest
|
||||
|
||||
log = logging.getLogger()
|
||||
|
||||
@ -22,8 +24,6 @@ log.info(f'started at {datetime.datetime.now():%x %X}')
|
||||
|
||||
server_api = FastAPI()
|
||||
|
||||
# don't show access log entries for URLs that start with the given prefix
|
||||
ACCESS_LOG_SUPPRESS_PATH_PREFIXES = ['/ping', '/image', '/modifier-thumbnails']
|
||||
NOCACHE_HEADERS={"Cache-Control": "no-cache, no-store, must-revalidate", "Pragma": "no-cache", "Expires": "0"}
|
||||
|
||||
class NoCacheStaticFiles(StaticFiles):
|
||||
@ -43,17 +43,6 @@ class SetAppConfigRequest(BaseModel):
|
||||
listen_port: int = None
|
||||
test_sd2: bool = None
|
||||
|
||||
class LogSuppressFilter(logging.Filter):
|
||||
def filter(self, record: logging.LogRecord) -> bool:
|
||||
path = record.getMessage()
|
||||
for prefix in ACCESS_LOG_SUPPRESS_PATH_PREFIXES:
|
||||
if path.find(prefix) != -1:
|
||||
return False
|
||||
return True
|
||||
|
||||
# don't log certain requests
|
||||
logging.getLogger('uvicorn.access').addFilter(LogSuppressFilter())
|
||||
|
||||
server_api.mount('/media', NoCacheStaticFiles(directory=os.path.join(app.SD_UI_DIR, 'media')), name="media")
|
||||
|
||||
for plugins_dir, dir_prefix in app.UI_PLUGINS_SOURCES:
|
||||
@ -137,20 +126,18 @@ def ping(session_id:str=None):
|
||||
return JSONResponse(response, headers=NOCACHE_HEADERS)
|
||||
|
||||
@server_api.post('/render')
|
||||
def render(req : task_manager.ImageRequest):
|
||||
def render(req: dict):
|
||||
try:
|
||||
app.save_model_to_config(req.use_stable_diffusion_model, req.use_vae_model, req.use_hypernetwork_model)
|
||||
# separate out the request data into rendering and task-specific data
|
||||
render_req: GenerateImageRequest = GenerateImageRequest.parse_obj(req)
|
||||
task_data: TaskData = TaskData.parse_obj(req)
|
||||
|
||||
# resolve the model paths to use
|
||||
req.use_stable_diffusion_model = model_manager.resolve_model_to_use(req.use_stable_diffusion_model, model_type='stable-diffusion')
|
||||
req.use_vae_model = model_manager.resolve_model_to_use(req.use_vae_model, model_type='vae')
|
||||
req.use_hypernetwork_model = model_manager.resolve_model_to_use(req.use_hypernetwork_model, model_type='hypernetwork')
|
||||
render_req.init_image_mask = req.get('mask') # hack: will rename this in the HTTP API in a future revision
|
||||
|
||||
if req.use_face_correction: req.use_face_correction = model_manager.resolve_model_to_use(req.use_face_correction, 'gfpgan')
|
||||
if req.use_upscale: req.use_upscale = model_manager.resolve_model_to_use(req.use_upscale, 'gfpgan')
|
||||
app.save_model_to_config(task_data.use_stable_diffusion_model, task_data.use_vae_model, task_data.use_hypernetwork_model)
|
||||
|
||||
# enqueue the task
|
||||
new_task = task_manager.render(req)
|
||||
new_task = task_manager.render(render_req, task_data)
|
||||
response = {
|
||||
'status': str(task_manager.current_state),
|
||||
'queue': len(task_manager.tasks_queue),
|
||||
|
Loading…
Reference in New Issue
Block a user