mirror of
https://github.com/easydiffusion/easydiffusion.git
synced 2025-01-14 18:28:18 +01:00
Fixes for TensorRT
This commit is contained in:
parent
8538a684e7
commit
d39e1da183
@ -63,7 +63,7 @@ class RenderTask(Task):
|
|||||||
if (
|
if (
|
||||||
runtime.set_vram_optimizations(context)
|
runtime.set_vram_optimizations(context)
|
||||||
or self.has_param_changed(context, "clip_skip")
|
or self.has_param_changed(context, "clip_skip")
|
||||||
or self.has_param_changed(context, "convert_to_tensorrt")
|
or self.trt_needs_reload(context)
|
||||||
):
|
):
|
||||||
models_to_force_reload.append("stable-diffusion")
|
models_to_force_reload.append("stable-diffusion")
|
||||||
|
|
||||||
@ -92,6 +92,17 @@ class RenderTask(Task):
|
|||||||
new_val = self.models_data.model_params.get("stable-diffusion", {}).get(param_name, False)
|
new_val = self.models_data.model_params.get("stable-diffusion", {}).get(param_name, False)
|
||||||
return model["params"].get(param_name) != new_val
|
return model["params"].get(param_name) != new_val
|
||||||
|
|
||||||
|
def trt_needs_reload(self, context):
|
||||||
|
if not self.has_param_changed(context, "convert_to_tensorrt"):
|
||||||
|
return False
|
||||||
|
|
||||||
|
model = context.models["stable-diffusion"]
|
||||||
|
pipe = model["default"]
|
||||||
|
if hasattr(pipe.unet, "_allocate_trt_buffers"): # TRT already loaded
|
||||||
|
return False
|
||||||
|
|
||||||
|
return True
|
||||||
|
|
||||||
|
|
||||||
def make_images(
|
def make_images(
|
||||||
context,
|
context,
|
||||||
@ -148,6 +159,7 @@ def make_images_internal(
|
|||||||
context,
|
context,
|
||||||
req,
|
req,
|
||||||
task_data,
|
task_data,
|
||||||
|
models_data,
|
||||||
data_queue,
|
data_queue,
|
||||||
task_temp_images,
|
task_temp_images,
|
||||||
step_callback,
|
step_callback,
|
||||||
@ -174,6 +186,7 @@ def generate_images_internal(
|
|||||||
context,
|
context,
|
||||||
req: GenerateImageRequest,
|
req: GenerateImageRequest,
|
||||||
task_data: TaskData,
|
task_data: TaskData,
|
||||||
|
models_data: ModelsData,
|
||||||
data_queue: queue.Queue,
|
data_queue: queue.Queue,
|
||||||
task_temp_images: list,
|
task_temp_images: list,
|
||||||
step_callback,
|
step_callback,
|
||||||
@ -197,6 +210,15 @@ def generate_images_internal(
|
|||||||
if req.init_image is not None and not context.test_diffusers:
|
if req.init_image is not None and not context.test_diffusers:
|
||||||
req.sampler_name = "ddim"
|
req.sampler_name = "ddim"
|
||||||
|
|
||||||
|
if context.test_diffusers:
|
||||||
|
pipe = context.models["stable-diffusion"]["default"]
|
||||||
|
if hasattr(pipe.unet, "_allocate_trt_buffers"):
|
||||||
|
convert_to_trt = models_data.model_params["stable-diffusion"].get("convert_to_tensorrt", False)
|
||||||
|
pipe.unet.forward = pipe.unet._trt_forward if convert_to_trt else pipe.unet._non_trt_forward
|
||||||
|
# pipe.vae.decoder.forward = (
|
||||||
|
# pipe.vae.decoder._trt_forward if convert_to_trt else pipe.vae.decoder._non_trt_forward
|
||||||
|
# )
|
||||||
|
|
||||||
images = generate_images(context, callback=callback, **req.dict())
|
images = generate_images(context, callback=callback, **req.dict())
|
||||||
user_stopped = False
|
user_stopped = False
|
||||||
except UserInitiatedStop:
|
except UserInitiatedStop:
|
||||||
|
@ -148,7 +148,7 @@
|
|||||||
<a href="https://github.com/easydiffusion/easydiffusion/wiki/Custom-Models" target="_blank"><i class="fa-solid fa-circle-question help-btn"><span class="simple-tooltip top-left">Click to learn more about custom models</span></i></a>
|
<a href="https://github.com/easydiffusion/easydiffusion/wiki/Custom-Models" target="_blank"><i class="fa-solid fa-circle-question help-btn"><span class="simple-tooltip top-left">Click to learn more about custom models</span></i></a>
|
||||||
</td></tr>
|
</td></tr>
|
||||||
<tr class="pl-5 displayNone" id="enable_trt_config">
|
<tr class="pl-5 displayNone" id="enable_trt_config">
|
||||||
<td><label for="convert_to_tensorrt">Convert to TensorRT:</label></td>
|
<td><label for="convert_to_tensorrt">Enable TensorRT:</label></td>
|
||||||
<td class="diffusers-restart-needed">
|
<td class="diffusers-restart-needed">
|
||||||
<input id="convert_to_tensorrt" name="convert_to_tensorrt" type="checkbox">
|
<input id="convert_to_tensorrt" name="convert_to_tensorrt" type="checkbox">
|
||||||
<a href="https://github.com/easydiffusion/easydiffusion/wiki/TensorRT" target="_blank"><i class="fa-solid fa-circle-question help-btn"><span class="simple-tooltip top-left">Click to learn more about TensorRT</span></i></a>
|
<a href="https://github.com/easydiffusion/easydiffusion/wiki/TensorRT" target="_blank"><i class="fa-solid fa-circle-question help-btn"><span class="simple-tooltip top-left">Click to learn more about TensorRT</span></i></a>
|
||||||
|
@ -1346,9 +1346,19 @@ function createTask(task) {
|
|||||||
|
|
||||||
function getCurrentUserRequest() {
|
function getCurrentUserRequest() {
|
||||||
const numOutputsTotal = parseInt(numOutputsTotalField.value)
|
const numOutputsTotal = parseInt(numOutputsTotalField.value)
|
||||||
const numOutputsParallel = parseInt(numOutputsParallelField.value)
|
let numOutputsParallel = parseInt(numOutputsParallelField.value)
|
||||||
const seed = randomSeedField.checked ? Math.floor(Math.random() * (2 ** 32 - 1)) : parseInt(seedField.value)
|
const seed = randomSeedField.checked ? Math.floor(Math.random() * (2 ** 32 - 1)) : parseInt(seedField.value)
|
||||||
|
|
||||||
|
if (
|
||||||
|
testDiffusers.checked &&
|
||||||
|
document.getElementById("toggle-tensorrt-install").innerHTML == "Uninstall" &&
|
||||||
|
document.querySelector("#convert_to_tensorrt").checked
|
||||||
|
) {
|
||||||
|
// TRT enabled
|
||||||
|
|
||||||
|
numOutputsParallel = 1 // force 1 parallel
|
||||||
|
}
|
||||||
|
|
||||||
const newTask = {
|
const newTask = {
|
||||||
batchesDone: 0,
|
batchesDone: 0,
|
||||||
numOutputsTotal: numOutputsTotal,
|
numOutputsTotal: numOutputsTotal,
|
||||||
|
Loading…
Reference in New Issue
Block a user