Fixes for TensorRT

This commit is contained in:
cmdr2 2023-08-01 11:49:30 +05:30
parent 8538a684e7
commit d39e1da183
3 changed files with 35 additions and 3 deletions

View File

@ -63,7 +63,7 @@ class RenderTask(Task):
if ( if (
runtime.set_vram_optimizations(context) runtime.set_vram_optimizations(context)
or self.has_param_changed(context, "clip_skip") or self.has_param_changed(context, "clip_skip")
or self.has_param_changed(context, "convert_to_tensorrt") or self.trt_needs_reload(context)
): ):
models_to_force_reload.append("stable-diffusion") models_to_force_reload.append("stable-diffusion")
@ -92,6 +92,17 @@ class RenderTask(Task):
new_val = self.models_data.model_params.get("stable-diffusion", {}).get(param_name, False) new_val = self.models_data.model_params.get("stable-diffusion", {}).get(param_name, False)
return model["params"].get(param_name) != new_val return model["params"].get(param_name) != new_val
def trt_needs_reload(self, context):
if not self.has_param_changed(context, "convert_to_tensorrt"):
return False
model = context.models["stable-diffusion"]
pipe = model["default"]
if hasattr(pipe.unet, "_allocate_trt_buffers"): # TRT already loaded
return False
return True
def make_images( def make_images(
context, context,
@ -148,6 +159,7 @@ def make_images_internal(
context, context,
req, req,
task_data, task_data,
models_data,
data_queue, data_queue,
task_temp_images, task_temp_images,
step_callback, step_callback,
@ -174,6 +186,7 @@ def generate_images_internal(
context, context,
req: GenerateImageRequest, req: GenerateImageRequest,
task_data: TaskData, task_data: TaskData,
models_data: ModelsData,
data_queue: queue.Queue, data_queue: queue.Queue,
task_temp_images: list, task_temp_images: list,
step_callback, step_callback,
@ -197,6 +210,15 @@ def generate_images_internal(
if req.init_image is not None and not context.test_diffusers: if req.init_image is not None and not context.test_diffusers:
req.sampler_name = "ddim" req.sampler_name = "ddim"
if context.test_diffusers:
pipe = context.models["stable-diffusion"]["default"]
if hasattr(pipe.unet, "_allocate_trt_buffers"):
convert_to_trt = models_data.model_params["stable-diffusion"].get("convert_to_tensorrt", False)
pipe.unet.forward = pipe.unet._trt_forward if convert_to_trt else pipe.unet._non_trt_forward
# pipe.vae.decoder.forward = (
# pipe.vae.decoder._trt_forward if convert_to_trt else pipe.vae.decoder._non_trt_forward
# )
images = generate_images(context, callback=callback, **req.dict()) images = generate_images(context, callback=callback, **req.dict())
user_stopped = False user_stopped = False
except UserInitiatedStop: except UserInitiatedStop:

View File

@ -148,7 +148,7 @@
<a href="https://github.com/easydiffusion/easydiffusion/wiki/Custom-Models" target="_blank"><i class="fa-solid fa-circle-question help-btn"><span class="simple-tooltip top-left">Click to learn more about custom models</span></i></a> <a href="https://github.com/easydiffusion/easydiffusion/wiki/Custom-Models" target="_blank"><i class="fa-solid fa-circle-question help-btn"><span class="simple-tooltip top-left">Click to learn more about custom models</span></i></a>
</td></tr> </td></tr>
<tr class="pl-5 displayNone" id="enable_trt_config"> <tr class="pl-5 displayNone" id="enable_trt_config">
<td><label for="convert_to_tensorrt">Convert to TensorRT:</label></td> <td><label for="convert_to_tensorrt">Enable TensorRT:</label></td>
<td class="diffusers-restart-needed"> <td class="diffusers-restart-needed">
<input id="convert_to_tensorrt" name="convert_to_tensorrt" type="checkbox"> <input id="convert_to_tensorrt" name="convert_to_tensorrt" type="checkbox">
<a href="https://github.com/easydiffusion/easydiffusion/wiki/TensorRT" target="_blank"><i class="fa-solid fa-circle-question help-btn"><span class="simple-tooltip top-left">Click to learn more about TensorRT</span></i></a> <a href="https://github.com/easydiffusion/easydiffusion/wiki/TensorRT" target="_blank"><i class="fa-solid fa-circle-question help-btn"><span class="simple-tooltip top-left">Click to learn more about TensorRT</span></i></a>

View File

@ -1346,9 +1346,19 @@ function createTask(task) {
function getCurrentUserRequest() { function getCurrentUserRequest() {
const numOutputsTotal = parseInt(numOutputsTotalField.value) const numOutputsTotal = parseInt(numOutputsTotalField.value)
const numOutputsParallel = parseInt(numOutputsParallelField.value) let numOutputsParallel = parseInt(numOutputsParallelField.value)
const seed = randomSeedField.checked ? Math.floor(Math.random() * (2 ** 32 - 1)) : parseInt(seedField.value) const seed = randomSeedField.checked ? Math.floor(Math.random() * (2 ** 32 - 1)) : parseInt(seedField.value)
if (
testDiffusers.checked &&
document.getElementById("toggle-tensorrt-install").innerHTML == "Uninstall" &&
document.querySelector("#convert_to_tensorrt").checked
) {
// TRT enabled
numOutputsParallel = 1 // force 1 parallel
}
const newTask = { const newTask = {
batchesDone: 0, batchesDone: 0,
numOutputsTotal: numOutputsTotal, numOutputsTotal: numOutputsTotal,