forked from extern/easydiffusion

Get rid of the ugly copying around (and maintaining) of multiple request-related fields. Split them into two objects: task-related fields and render-related fields. Also remove the ability for request-defined full precision; full precision can now be forced with the FORCE_FULL_PRECISION environment variable.

This commit is contained in:
parent d03eed3859
commit 6ce6dc3ff6
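The core of the split, sketched for illustration (field lists abridged; the real models are in the diffs below): pydantic ignores unknown keys by default, so the single flat payload the UI already sends can be parsed into both objects without copying fields one by one.

    from pydantic import BaseModel

    class GenerateImageRequest(BaseModel):   # render-related fields (abridged)
        prompt: str = ""
        seed: int = 42
        num_outputs: int = 1

    class TaskData(BaseModel):               # task-related fields (abridged)
        session_id: str = "session"
        use_stable_diffusion_model: str = "sd-v1-4"
        output_format: str = "jpeg"

    # one flat request body parses into both objects:
    payload = {"prompt": "an astronaut", "seed": 7, "session_id": "abc"}
    render_req = GenerateImageRequest.parse_obj(payload)
    task_data = TaskData.parse_obj(payload)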
@@ -37,7 +37,6 @@ const SETTINGS_IDS_LIST = [
     "diskPath",
     "sound_toggle",
     "turbo",
-    "use_full_precision",
     "confirm_dangerous_actions",
     "auto_save_settings",
     "apply_color_correction"
@@ -278,7 +277,6 @@ function tryLoadOldSettings() {
     "soundEnabled": "sound_toggle",
     "saveToDisk": "save_to_disk",
     "useCPU": "use_cpu",
-    "useFullPrecision": "use_full_precision",
     "useTurboMode": "turbo",
     "diskPath": "diskPath",
     "useFaceCorrection": "use_face_correction",
@@ -217,13 +217,6 @@ const TASK_MAPPING = {
         readUI: () => turboField.checked,
         parse: (val) => Boolean(val)
     },
-    use_full_precision: { name: 'Use Full Precision',
-        setUI: (use_full_precision) => {
-            useFullPrecisionField.checked = use_full_precision
-        },
-        readUI: () => useFullPrecisionField.checked,
-        parse: (val) => Boolean(val)
-    },

     stream_image_progress: { name: 'Stream Image Progress',
         setUI: (stream_image_progress) => {
@@ -453,7 +446,6 @@ document.addEventListener("dragover", dragOverHandler)
 const TASK_REQ_NO_EXPORT = [
     "use_cpu",
     "turbo",
-    "use_full_precision",
     "save_to_disk_path"
 ]
 const resetSettings = document.getElementById('reset-image-settings')
@@ -728,7 +728,6 @@
     "stream_image_progress": 'boolean',
     "show_only_filtered_image": 'boolean',
     "turbo": 'boolean',
-    "use_full_precision": 'boolean',
     "output_format": 'string',
     "output_quality": 'number',
 }
@@ -744,7 +743,6 @@
     "stream_image_progress": true,
     "show_only_filtered_image": true,
     "turbo": false,
-    "use_full_precision": false,
     "output_format": "png",
     "output_quality": 75,
 }
@@ -850,7 +850,6 @@ function getCurrentUserRequest() {
         // allow_nsfw: allowNSFWField.checked,
         turbo: turboField.checked,
         //render_device: undefined, // Set device affinity. Prefer this device, but wont activate.
-        use_full_precision: useFullPrecisionField.checked,
         use_stable_diffusion_model: stableDiffusionModelField.value,
         use_vae_model: vaeModelField.value,
         stream_progress_updates: true,
@@ -105,14 +105,6 @@ var PARAMETERS = [
         note: "to process in parallel",
         default: false,
     },
-    {
-        id: "use_full_precision",
-        type: ParameterType.checkbox,
-        label: "Use Full Precision",
-        note: "for GPU-only. warning: this will consume more VRAM",
-        icon: "fa-crosshairs",
-        default: false,
-    },
     {
         id: "auto_save_settings",
         type: ParameterType.checkbox,
@@ -214,7 +206,6 @@ let turboField = document.querySelector('#turbo')
 let useCPUField = document.querySelector('#use_cpu')
 let autoPickGPUsField = document.querySelector('#auto_pick_gpus')
 let useGPUsField = document.querySelector('#use_gpus')
-let useFullPrecisionField = document.querySelector('#use_full_precision')
 let saveToDiskField = document.querySelector('#save_to_disk')
 let diskPathField = document.querySelector('#diskPath')
 let listenToNetworkField = document.querySelector("#listen_to_network")
@@ -1,93 +1,22 @@
-import json
+from pydantic import BaseModel

-class Request:
+from modules.types import GenerateImageRequest
+
+class TaskData(BaseModel):
     request_id: str = None
     session_id: str = "session"
-    prompt: str = ""
-    negative_prompt: str = ""
-    init_image: str = None # base64
-    mask: str = None # base64
-    apply_color_correction = False
-    num_outputs: int = 1
-    num_inference_steps: int = 50
-    guidance_scale: float = 7.5
-    width: int = 512
-    height: int = 512
-    seed: int = 42
-    prompt_strength: float = 0.8
-    sampler: str = None # "ddim", "plms", "heun", "euler", "euler_a", "dpm2", "dpm2_a", "lms"
-    # allow_nsfw: bool = False
-    precision: str = "autocast" # or "full"
     save_to_disk_path: str = None
     turbo: bool = True
-    use_full_precision: bool = False
     use_face_correction: str = None # or "GFPGANv1.3"
     use_upscale: str = None # or "RealESRGAN_x4plus" or "RealESRGAN_x4plus_anime_6B"
     use_stable_diffusion_model: str = "sd-v1-4"
     use_vae_model: str = None
     use_hypernetwork_model: str = None
-    hypernetwork_strength: float = 1
     show_only_filtered_image: bool = False
     output_format: str = "jpeg" # or "png"
     output_quality: int = 75

-    stream_progress_updates: bool = False
     stream_image_progress: bool = False
-
-    def json(self):
-        return {
-            "request_id": self.request_id,
-            "session_id": self.session_id,
-            "prompt": self.prompt,
-            "negative_prompt": self.negative_prompt,
-            "num_outputs": self.num_outputs,
-            "num_inference_steps": self.num_inference_steps,
-            "guidance_scale": self.guidance_scale,
-            "width": self.width,
-            "height": self.height,
-            "seed": self.seed,
-            "prompt_strength": self.prompt_strength,
-            "sampler": self.sampler,
-            "apply_color_correction": self.apply_color_correction,
-            "use_face_correction": self.use_face_correction,
-            "use_upscale": self.use_upscale,
-            "use_stable_diffusion_model": self.use_stable_diffusion_model,
-            "use_vae_model": self.use_vae_model,
-            "use_hypernetwork_model": self.use_hypernetwork_model,
-            "hypernetwork_strength": self.hypernetwork_strength,
-            "output_format": self.output_format,
-            "output_quality": self.output_quality,
-        }
-
-    def __str__(self):
-        return f'''
-    session_id: {self.session_id}
-    prompt: {self.prompt}
-    negative_prompt: {self.negative_prompt}
-    seed: {self.seed}
-    num_inference_steps: {self.num_inference_steps}
-    sampler: {self.sampler}
-    guidance_scale: {self.guidance_scale}
-    w: {self.width}
-    h: {self.height}
-    precision: {self.precision}
-    save_to_disk_path: {self.save_to_disk_path}
-    turbo: {self.turbo}
-    use_full_precision: {self.use_full_precision}
-    apply_color_correction: {self.apply_color_correction}
-    use_face_correction: {self.use_face_correction}
-    use_upscale: {self.use_upscale}
-    use_stable_diffusion_model: {self.use_stable_diffusion_model}
-    use_vae_model: {self.use_vae_model}
-    use_hypernetwork_model: {self.use_hypernetwork_model}
-    hypernetwork_strength: {self.hypernetwork_strength}
-    show_only_filtered_image: {self.show_only_filtered_image}
-    output_format: {self.output_format}
-    output_quality: {self.output_quality}
-
-    stream_progress_updates: {self.stream_progress_updates}
-    stream_image_progress: {self.stream_image_progress}'''
-
 class Image:
     data: str # base64
     seed: int
@@ -106,17 +35,20 @@ class Image:
     }

 class Response:
-    request: Request
+    render_request: GenerateImageRequest
+    task_data: TaskData
     images: list

-    def __init__(self, request: Request, images: list):
-        self.request = request
+    def __init__(self, render_request: GenerateImageRequest, task_data: TaskData, images: list):
+        self.render_request = render_request
+        self.task_data = task_data
         self.images = images

     def json(self):
         res = {
             "status": 'succeeded',
-            "request": self.request.json(),
+            "render_request": self.render_request.dict(),
+            "task_data": self.task_data.dict(),
             "output": [],
         }
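For illustration, a sketch of what the new Response serializes to (assuming render_req and task_data objects as in the sketch near the top; values elided):

    res = Response(render_req, task_data, images=[])
    res.json()
    # -> {'status': 'succeeded',
    #     'render_request': {...},   # GenerateImageRequest.dict()
    #     'task_data': {...},        # TaskData.dict()
    #     'output': []}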
@@ -6,6 +6,13 @@ import logging

 log = logging.getLogger()

+'''
+Set `FORCE_FULL_PRECISION` in the environment variables, or in `config.bat`/`config.sh` to set full precision (i.e. float32).
+Otherwise the models will load at half-precision (i.e. float16).
+
+Half-precision is fine most of the time. Full precision is only needed for working around GPU bugs (like NVIDIA 16xx GPUs).
+'''
+
 COMPARABLE_GPU_PERCENTILE = 0.65 # if a GPU's free_mem is within this % of the GPU with the most free_mem, it will be picked

 mem_free_threshold = 0
@@ -96,7 +103,7 @@ def device_init(context, device):
     if device == 'cpu':
         context.device = 'cpu'
         context.device_name = get_processor_name()
-        context.precision = 'full'
+        context.half_precision = False
         log.debug(f'Render device CPU available as {context.device_name}')
         return

@@ -107,7 +114,7 @@ def device_init(context, device):
     if needs_to_force_full_precision(context):
         log.warn(f'forcing full precision on this GPU, to avoid green images. GPU detected: {context.device_name}')
         # Apply force_full_precision now before models are loaded.
-        context.precision = 'full'
+        context.half_precision = False

     log.info(f'Setting {device} as active')
     torch.cuda.device(device)
@@ -115,6 +122,9 @@ def device_init(context, device):
     return

 def needs_to_force_full_precision(context):
+    if 'FORCE_FULL_PRECISION' in os.environ:
+        return True
+
     device_name = context.device_name.lower()
     return (('nvidia' in device_name or 'geforce' in device_name) and (' 1660' in device_name or ' 1650' in device_name)) or ('Quadro T2000' in device_name)

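The check is presence-based: needs_to_force_full_precision only tests that the variable exists, so any value forces full precision. A minimal sketch, assuming the standard launcher scripts:

    import os

    # in config.sh/config.bat: export FORCE_FULL_PRECISION=1 (any value works),
    # or set it in-process before device_init() runs:
    os.environ['FORCE_FULL_PRECISION'] = '1'   # presence alone forces float32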
@@ -3,8 +3,7 @@ import logging
 import picklescan.scanner
 import rich

-from sd_internal import app, device_manager
-from sd_internal import Request
+from sd_internal import app

 log = logging.getLogger()

@@ -157,19 +156,3 @@ def getModels():
     models['options']['stable-diffusion'].append('custom-model')

     return models
-
-def is_sd_model_reload_necessary(thread_data, req: Request):
-    needs_model_reload = False
-    if 'stable-diffusion' not in thread_data.models or \
-            thread_data.model_paths['stable-diffusion'] != req.use_stable_diffusion_model or \
-            thread_data.model_paths['vae'] != req.use_vae_model:
-
-        needs_model_reload = True
-
-    if thread_data.device != 'cpu':
-        if (thread_data.precision == 'autocast' and req.use_full_precision) or \
-            (thread_data.precision == 'full' and not req.use_full_precision and not device_manager.needs_to_force_full_precision(thread_data)):
-            thread_data.precision = 'full' if req.use_full_precision else 'autocast'
-            needs_model_reload = True
-
-    return needs_model_reload
@@ -9,13 +9,14 @@ import traceback
 import logging

 from sd_internal import device_manager, model_manager
-from sd_internal import Request, Response, Image as ResponseImage, UserInitiatedStop
+from sd_internal import TaskData, Response, Image as ResponseImage, UserInitiatedStop

 from modules import model_loader, image_generator, image_utils, filters as image_filters
+from modules.types import Context, GenerateImageRequest

 log = logging.getLogger()

-thread_data = threading.local()
+thread_data = Context()
 '''
 runtime data (bound locally to this thread), for e.g. device, references to loaded models, optimization flags etc
 '''
@@ -28,14 +29,7 @@ def init(device):
     '''
     thread_data.stop_processing = False
     thread_data.temp_images = {}
-    thread_data.models = {}
-    thread_data.model_paths = {}
-
-    thread_data.device = None
-    thread_data.device_name = None
-    thread_data.precision = 'autocast'
-    thread_data.vram_optimizations = ('TURBO', 'MOVE_MODELS')
+    thread_data.partial_x_samples = None

     device_manager.device_init(thread_data, device)

@@ -51,16 +45,16 @@ def load_default_models():
     # load mandatory models
     model_loader.load_model(thread_data, 'stable-diffusion')

-def reload_models_if_necessary(req: Request):
+def reload_models_if_necessary(task_data: TaskData):
     model_paths_in_req = (
-        ('hypernetwork', req.use_hypernetwork_model),
-        ('gfpgan', req.use_face_correction),
-        ('realesrgan', req.use_upscale),
+        ('hypernetwork', task_data.use_hypernetwork_model),
+        ('gfpgan', task_data.use_face_correction),
+        ('realesrgan', task_data.use_upscale),
     )

-    if model_manager.is_sd_model_reload_necessary(thread_data, req):
-        thread_data.model_paths['stable-diffusion'] = req.use_stable_diffusion_model
-        thread_data.model_paths['vae'] = req.use_vae_model
+    if thread_data.model_paths.get('stable-diffusion') != task_data.use_stable_diffusion_model or thread_data.model_paths.get('vae') != task_data.use_vae_model:
+        thread_data.model_paths['stable-diffusion'] = task_data.use_stable_diffusion_model
+        thread_data.model_paths['vae'] = task_data.use_vae_model

         model_loader.load_model(thread_data, 'stable-diffusion')

@@ -73,10 +67,17 @@ def reload_models_if_necessary(req: Request):
     else:
         model_loader.unload_model(thread_data, model_type)

-def make_images(req: Request, data_queue: queue.Queue, task_temp_images: list, step_callback):
+def make_images(req: GenerateImageRequest, task_data: TaskData, data_queue: queue.Queue, task_temp_images: list, step_callback):
     try:
-        log.info(req)
-        return _make_images_internal(req, data_queue, task_temp_images, step_callback)
+        # resolve the model paths to use
+        resolve_model_paths(task_data)
+
+        # convert init image to PIL.Image
+        req.init_image = image_utils.base64_str_to_img(req.init_image) if req.init_image is not None else None
+        req.init_image_mask = image_utils.base64_str_to_img(req.init_image_mask) if req.init_image_mask is not None else None
+
+        # generate
+        return _make_images_internal(req, task_data, data_queue, task_temp_images, step_callback)
     except Exception as e:
         log.error(traceback.format_exc())

@@ -86,66 +87,76 @@ def make_images(req: Request, data_queue: queue.Queue, task_temp_images: list, step_callback):
         }))
         raise e

-def _make_images_internal(req: Request, data_queue: queue.Queue, task_temp_images: list, step_callback):
-    args = req_to_args(req)
+def _make_images_internal(req: GenerateImageRequest, task_data: TaskData, data_queue: queue.Queue, task_temp_images: list, step_callback):
+    metadata = req.dict()
+    del metadata['init_image']
+    del metadata['init_image_mask']
+    print(metadata)

-    images, user_stopped = generate_images(args, data_queue, task_temp_images, step_callback, req.stream_image_progress)
-    images = apply_color_correction(args, images, user_stopped)
-    images = apply_filters(args, images, user_stopped, req.show_only_filtered_image)
+    images, user_stopped = generate_images(req, data_queue, task_temp_images, step_callback, task_data.stream_image_progress)
+    images = apply_color_correction(req, images, user_stopped)
+    images = apply_filters(task_data, images, user_stopped, task_data.show_only_filtered_image)

-    if req.save_to_disk_path is not None:
-        out_path = os.path.join(req.save_to_disk_path, filename_regex.sub('_', req.session_id))
-        save_images(images, out_path, metadata=req.json(), show_only_filtered_image=req.show_only_filtered_image)
+    if task_data.save_to_disk_path is not None:
+        out_path = os.path.join(task_data.save_to_disk_path, filename_regex.sub('_', task_data.session_id))
+        save_images(images, out_path, metadata=metadata, show_only_filtered_image=task_data.show_only_filtered_image)

-    res = Response(req, images=construct_response(req, images))
+    res = Response(req, task_data, images=construct_response(images))
     res = res.json()
     data_queue.put(json.dumps(res))
     log.info('Task completed')

     return res

-def generate_images(args: dict, data_queue: queue.Queue, task_temp_images: list, step_callback, stream_image_progress: bool):
+def resolve_model_paths(task_data: TaskData):
+    task_data.use_stable_diffusion_model = model_manager.resolve_model_to_use(task_data.use_stable_diffusion_model, model_type='stable-diffusion')
+    task_data.use_vae_model = model_manager.resolve_model_to_use(task_data.use_vae_model, model_type='vae')
+    task_data.use_hypernetwork_model = model_manager.resolve_model_to_use(task_data.use_hypernetwork_model, model_type='hypernetwork')
+
+    if task_data.use_face_correction: task_data.use_face_correction = model_manager.resolve_model_to_use(task_data.use_face_correction, 'gfpgan')
+    if task_data.use_upscale: task_data.use_upscale = model_manager.resolve_model_to_use(task_data.use_upscale, 'gfpgan')
+
+def generate_images(req: GenerateImageRequest, data_queue: queue.Queue, task_temp_images: list, step_callback, stream_image_progress: bool):
     thread_data.temp_images.clear()

-    image_generator.on_image_step = make_step_callback(args, data_queue, task_temp_images, step_callback, stream_image_progress)
+    image_generator.on_image_step = make_step_callback(req, data_queue, task_temp_images, step_callback, stream_image_progress)

     try:
-        images = image_generator.make_images(context=thread_data, args=args)
+        images = image_generator.make_images(context=thread_data, req=req)
         user_stopped = False
     except UserInitiatedStop:
         images = []
         user_stopped = True
-        if not hasattr(thread_data, 'partial_x_samples') or thread_data.partial_x_samples is None:
-            return images
-        for i in range(args['num_outputs']):
-            images[i] = image_utils.latent_to_img(thread_data, thread_data.partial_x_samples[i].unsqueeze(0))
-        del thread_data.partial_x_samples
+        if thread_data.partial_x_samples is not None:
+            for i in range(req.num_outputs):
+                images[i] = image_utils.latent_to_img(thread_data, thread_data.partial_x_samples[i].unsqueeze(0))
+
+            thread_data.partial_x_samples = None
     finally:
         model_loader.gc(thread_data)

-    images = [(image, args['seed'] + i, False) for i, image in enumerate(images)]
+    images = [(image, req.seed + i, False) for i, image in enumerate(images)]

     return images, user_stopped

-def apply_color_correction(args: dict, images: list, user_stopped):
-    if user_stopped or args['init_image'] is None or not args['apply_color_correction']:
+def apply_color_correction(req: GenerateImageRequest, images: list, user_stopped):
+    if user_stopped or req.init_image is None or not req.apply_color_correction:
         return images

     for i, img_info in enumerate(images):
         img, seed, filtered = img_info
-        img = image_utils.apply_color_correction(orig_image=args['init_image'], image_to_correct=img)
+        img = image_utils.apply_color_correction(orig_image=req.init_image, image_to_correct=img)
         images[i] = (img, seed, filtered)

     return images

-def apply_filters(args: dict, images: list, user_stopped, show_only_filtered_image):
-    if user_stopped or (args['use_face_correction'] is None and args['use_upscale'] is None):
+def apply_filters(task_data: TaskData, images: list, user_stopped, show_only_filtered_image):
+    if user_stopped or (task_data.use_face_correction is None and task_data.use_upscale is None):
         return images

     filters = []
-    if 'gfpgan' in args['use_face_correction'].lower(): filters.append(image_filters.apply_gfpgan)
-    if 'realesrgan' in args['use_face_correction'].lower(): filters.append(image_filters.apply_realesrgan)
+    if 'gfpgan' in task_data.use_face_correction.lower(): filters.append(image_filters.apply_gfpgan)
+    if 'realesrgan' in task_data.use_face_correction.lower(): filters.append(image_filters.apply_realesrgan)

     filtered_images = []
     for img, seed, _ in images:
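Note that resolve_model_paths() mutates the TaskData in place: the use_* fields arrive as model names from the client and leave as resolved filesystem paths. A sketch (the resolved path is illustrative):

    task_data = TaskData(use_stable_diffusion_model='sd-v1-4')
    resolve_model_paths(task_data)
    task_data.use_stable_diffusion_model   # e.g. '.../models/stable-diffusion/sd-v1-4.ckpt'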
@@ -188,37 +199,29 @@ def save_images(images: list, save_to_disk_path, metadata: dict, show_only_filtered_image):
         img_path += '.' + metadata['output_format']
         img.save(img_path, quality=metadata['output_quality'])

-def construct_response(req: Request, images: list):
+def construct_response(task_data: TaskData, images: list):
     return [
         ResponseImage(
-            data=image_utils.img_to_base64_str(img, req.output_format, req.output_quality),
+            data=image_utils.img_to_base64_str(img, task_data.output_format, task_data.output_quality),
             seed=seed
         ) for img, seed, _ in images
     ]

-def req_to_args(req: Request):
-    args = req.json()
-
-    args['init_image'] = image_utils.base64_str_to_img(req.init_image) if req.init_image is not None else None
-    args['mask'] = image_utils.base64_str_to_img(req.mask) if req.mask is not None else None
-
-    return args
-
-def make_step_callback(args: dict, data_queue: queue.Queue, task_temp_images: list, step_callback, stream_image_progress: bool):
-    n_steps = args['num_inference_steps'] if args['init_image'] is None else int(args['num_inference_steps'] * args['prompt_strength'])
+def make_step_callback(req: GenerateImageRequest, task_data: TaskData, data_queue: queue.Queue, task_temp_images: list, step_callback, stream_image_progress: bool):
+    n_steps = req.num_inference_steps if req.init_image is None else int(req.num_inference_steps * req.prompt_strength)
     last_callback_time = -1

     def update_temp_img(x_samples, task_temp_images: list):
         partial_images = []
-        for i in range(args['num_outputs']):
+        for i in range(req.num_outputs):
             img = image_utils.latent_to_img(thread_data, x_samples[i].unsqueeze(0))
             buf = image_utils.img_to_buffer(img, output_format='JPEG')

             del img

-            thread_data.temp_images[f"{args['request_id']}/{i}"] = buf
+            thread_data.temp_images[f"{task_data.request_id}/{i}"] = buf
             task_temp_images[i] = buf
-            partial_images.append({'path': f"/image/tmp/{args['request_id']}/{i}"})
+            partial_images.append({'path': f"/image/tmp/{task_data.request_id}/{i}"})
         return partial_images

     def on_image_step(x_samples, i):
@@ -14,8 +14,8 @@ import torch
 import queue, threading, time, weakref
 from typing import Any, Hashable

-from pydantic import BaseModel
-from sd_internal import Request, device_manager
+from sd_internal import TaskData, device_manager
+from modules.types import GenerateImageRequest

 log = logging.getLogger()

@@ -39,9 +39,10 @@ class ServerStates:
     class Unavailable(Symbol): pass

 class RenderTask(): # Task with output queue and completion lock.
-    def __init__(self, req: Request):
+    def __init__(self, req: GenerateImageRequest, task_data: TaskData):
         req.request_id = id(self)
-        self.request: Request = req # Initial Request
+        self.render_request: GenerateImageRequest = req # Initial Request
+        self.task_data: TaskData = task_data
         self.response: Any = None # Copy of the last reponse
         self.render_device = None # Select the task affinity. (Not used to change active devices).
         self.temp_images:list = [None] * req.num_outputs * (1 if req.show_only_filtered_image else 2)
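A task now carries the two objects side by side instead of one merged Request; a minimal sketch:

    task = RenderTask(render_req, task_data)
    task.render_request.seed        # render-related field
    task.task_data.session_id       # task-related field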
@@ -72,55 +73,6 @@ class RenderTask(): # Task with output queue and completion lock.
     def is_pending(self):
         return bool(not self.response and not self.error)

-# defaults from https://huggingface.co/blog/stable_diffusion
-class ImageRequest(BaseModel):
-    session_id: str = "session"
-    prompt: str = ""
-    negative_prompt: str = ""
-    init_image: str = None # base64
-    mask: str = None # base64
-    apply_color_correction: bool = False
-    num_outputs: int = 1
-    num_inference_steps: int = 50
-    guidance_scale: float = 7.5
-    width: int = 512
-    height: int = 512
-    seed: int = 42
-    prompt_strength: float = 0.8
-    sampler: str = None # "ddim", "plms", "heun", "euler", "euler_a", "dpm2", "dpm2_a", "lms"
-    # allow_nsfw: bool = False
-    save_to_disk_path: str = None
-    turbo: bool = True
-    use_cpu: bool = False ##TODO Remove after UI and plugins transition.
-    render_device: str = None # Select the task affinity. (Not used to change active devices).
-    use_full_precision: bool = False
-    use_face_correction: str = None # or "GFPGANv1.3"
-    use_upscale: str = None # or "RealESRGAN_x4plus" or "RealESRGAN_x4plus_anime_6B"
-    use_stable_diffusion_model: str = "sd-v1-4"
-    use_vae_model: str = None
-    use_hypernetwork_model: str = None
-    hypernetwork_strength: float = None
-    show_only_filtered_image: bool = False
-    output_format: str = "jpeg" # or "png"
-    output_quality: int = 75
-
-    stream_progress_updates: bool = False
-    stream_image_progress: bool = False
-
-class FilterRequest(BaseModel):
-    session_id: str = "session"
-    model: str = None
-    name: str = ""
-    init_image: str = None # base64
-    width: int = 512
-    height: int = 512
-    save_to_disk_path: str = None
-    turbo: bool = True
-    render_device: str = None
-    use_full_precision: bool = False
-    output_format: str = "jpeg" # or "png"
-    output_quality: int = 75
-
 # Temporary cache to allow to query tasks results for a short time after they are completed.
 class DataCache():
     def __init__(self):
@@ -311,7 +263,7 @@ def thread_render(device):
             task.response = {"status": 'failed', "detail": str(task.error)}
             task.buffer_queue.put(json.dumps(task.response))
             continue
-        log.info(f'Session {task.request.session_id} starting task {id(task)} on {runtime2.thread_data.device_name}')
+        log.info(f'Session {task.task_data.session_id} starting task {id(task)} on {runtime2.thread_data.device_name}')
         if not task.lock.acquire(blocking=False): raise Exception('Got locked task from queue.')
         try:
             def step_callback():
@@ -322,16 +274,16 @@ def thread_render(device):
                 if isinstance(current_state_error, StopAsyncIteration):
                     task.error = current_state_error
                     current_state_error = None
-                    log.info(f'Session {task.request.session_id} sent cancel signal for task {id(task)}')
+                    log.info(f'Session {task.task_data.session_id} sent cancel signal for task {id(task)}')

             current_state = ServerStates.LoadingModel
-            runtime2.reload_models_if_necessary(task.request)
+            runtime2.reload_models_if_necessary(task.task_data)

             current_state = ServerStates.Rendering
-            task.response = runtime2.make_images(task.request, task.buffer_queue, task.temp_images, step_callback)
+            task.response = runtime2.make_images(task.render_request, task.task_data, task.buffer_queue, task.temp_images, step_callback)
             # Before looping back to the generator, mark cache as still alive.
             task_cache.keep(id(task), TASK_TTL)
-            session_cache.keep(task.request.session_id, TASK_TTL)
+            session_cache.keep(task.task_data.session_id, TASK_TTL)
         except Exception as e:
             task.error = e
             log.error(traceback.format_exc())
@@ -340,13 +292,13 @@ def thread_render(device):
         # Task completed
         task.lock.release()
         task_cache.keep(id(task), TASK_TTL)
-        session_cache.keep(task.request.session_id, TASK_TTL)
+        session_cache.keep(task.task_data.session_id, TASK_TTL)
         if isinstance(task.error, StopAsyncIteration):
-            log.info(f'Session {task.request.session_id} task {id(task)} cancelled!')
+            log.info(f'Session {task.task_data.session_id} task {id(task)} cancelled!')
         elif task.error is not None:
-            log.info(f'Session {task.request.session_id} task {id(task)} failed!')
+            log.info(f'Session {task.task_data.session_id} task {id(task)} failed!')
         else:
-            log.info(f'Session {task.request.session_id} task {id(task)} completed by {runtime2.thread_data.device_name}.')
+            log.info(f'Session {task.task_data.session_id} task {id(task)} completed by {runtime2.thread_data.device_name}.')
         current_state = ServerStates.Online

 def get_cached_task(task_id:str, update_ttl:bool=False):
@@ -509,53 +461,18 @@ def shutdown_event(): # Signal render thread to close on shutdown
     global current_state_error
     current_state_error = SystemExit('Application shutting down.')

-def render(req : ImageRequest):
+def render(render_req: GenerateImageRequest, task_data: TaskData):
     current_thread_count = is_alive()
     if current_thread_count <= 0: # Render thread is dead
         raise ChildProcessError('Rendering thread has died.')

     # Alive, check if task in cache
-    session = get_cached_session(req.session_id, update_ttl=True)
+    session = get_cached_session(task_data.session_id, update_ttl=True)
     pending_tasks = list(filter(lambda t: t.is_pending, session.tasks))
     if current_thread_count < len(pending_tasks):
-        raise ConnectionRefusedError(f'Session {req.session_id} already has {len(pending_tasks)} pending tasks out of {current_thread_count}.')
+        raise ConnectionRefusedError(f'Session {task_data.session_id} already has {len(pending_tasks)} pending tasks out of {current_thread_count}.')

-    r = Request()
-    r.session_id = req.session_id
-    r.prompt = req.prompt
-    r.negative_prompt = req.negative_prompt
-    r.init_image = req.init_image
-    r.mask = req.mask
-    r.apply_color_correction = req.apply_color_correction
-    r.num_outputs = req.num_outputs
-    r.num_inference_steps = req.num_inference_steps
-    r.guidance_scale = req.guidance_scale
-    r.width = req.width
-    r.height = req.height
-    r.seed = req.seed
-    r.prompt_strength = req.prompt_strength
-    r.sampler = req.sampler
-    # r.allow_nsfw = req.allow_nsfw
-    r.turbo = req.turbo
-    r.use_full_precision = req.use_full_precision
-    r.save_to_disk_path = req.save_to_disk_path
-    r.use_upscale: str = req.use_upscale
-    r.use_face_correction = req.use_face_correction
-    r.use_stable_diffusion_model = req.use_stable_diffusion_model
-    r.use_vae_model = req.use_vae_model
-    r.use_hypernetwork_model = req.use_hypernetwork_model
-    r.hypernetwork_strength = req.hypernetwork_strength
-    r.show_only_filtered_image = req.show_only_filtered_image
-    r.output_format = req.output_format
-    r.output_quality = req.output_quality
-
-    r.stream_progress_updates = True # the underlying implementation only supports streaming
-    r.stream_image_progress = req.stream_image_progress
-
-    if not req.stream_progress_updates:
-        r.stream_image_progress = False
-
-    new_task = RenderTask(r)
+    new_task = RenderTask(render_req, task_data)
     if session.put(new_task, TASK_TTL):
         # Use twice the normal timeout for adding user requests.
         # Tries to force session.put to fail before tasks_queue.put would.
ui/server.py: 31 changes
@@ -14,6 +14,8 @@ from starlette.responses import FileResponse, JSONResponse, StreamingResponse
 from pydantic import BaseModel

 from sd_internal import app, model_manager, task_manager
+from sd_internal import TaskData
+from modules.types import GenerateImageRequest

 log = logging.getLogger()

@@ -22,8 +24,6 @@ log.info(f'started at {datetime.datetime.now():%x %X}')

 server_api = FastAPI()

-# don't show access log entries for URLs that start with the given prefix
-ACCESS_LOG_SUPPRESS_PATH_PREFIXES = ['/ping', '/image', '/modifier-thumbnails']
 NOCACHE_HEADERS={"Cache-Control": "no-cache, no-store, must-revalidate", "Pragma": "no-cache", "Expires": "0"}

 class NoCacheStaticFiles(StaticFiles):
@@ -43,17 +43,6 @@ class SetAppConfigRequest(BaseModel):
     listen_port: int = None
     test_sd2: bool = None

-class LogSuppressFilter(logging.Filter):
-    def filter(self, record: logging.LogRecord) -> bool:
-        path = record.getMessage()
-        for prefix in ACCESS_LOG_SUPPRESS_PATH_PREFIXES:
-            if path.find(prefix) != -1:
-                return False
-        return True
-
-# don't log certain requests
-logging.getLogger('uvicorn.access').addFilter(LogSuppressFilter())
-
 server_api.mount('/media', NoCacheStaticFiles(directory=os.path.join(app.SD_UI_DIR, 'media')), name="media")

 for plugins_dir, dir_prefix in app.UI_PLUGINS_SOURCES:
@@ -137,20 +126,18 @@ def ping(session_id:str=None):
     return JSONResponse(response, headers=NOCACHE_HEADERS)

 @server_api.post('/render')
-def render(req : task_manager.ImageRequest):
+def render(req: dict):
     try:
-        app.save_model_to_config(req.use_stable_diffusion_model, req.use_vae_model, req.use_hypernetwork_model)
+        # separate out the request data into rendering and task-specific data
+        render_req: GenerateImageRequest = GenerateImageRequest.parse_obj(req)
+        task_data: TaskData = TaskData.parse_obj(req)

-        # resolve the model paths to use
-        req.use_stable_diffusion_model = model_manager.resolve_model_to_use(req.use_stable_diffusion_model, model_type='stable-diffusion')
-        req.use_vae_model = model_manager.resolve_model_to_use(req.use_vae_model, model_type='vae')
-        req.use_hypernetwork_model = model_manager.resolve_model_to_use(req.use_hypernetwork_model, model_type='hypernetwork')
+        render_req.init_image_mask = req.get('mask') # hack: will rename this in the HTTP API in a future revision

-        if req.use_face_correction: req.use_face_correction = model_manager.resolve_model_to_use(req.use_face_correction, 'gfpgan')
-        if req.use_upscale: req.use_upscale = model_manager.resolve_model_to_use(req.use_upscale, 'gfpgan')
+        app.save_model_to_config(task_data.use_stable_diffusion_model, task_data.use_vae_model, task_data.use_hypernetwork_model)

         # enqueue the task
-        new_task = task_manager.render(req)
+        new_task = task_manager.render(render_req, task_data)
         response = {
             'status': str(task_manager.current_state),
             'queue': len(task_manager.tasks_queue),
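Clients still send one flat JSON body; the server performs the split. A hedged example (port and field values are illustrative, assuming a default local install):

    import requests

    payload = {
        'prompt': 'an astronaut riding a horse',
        'seed': 42,
        'session_id': 'demo',
        'use_stable_diffusion_model': 'sd-v1-4',
    }
    r = requests.post('http://localhost:9000/render', json=payload)
    print(r.json())   # {'status': ..., 'queue': ..., ...}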