mirror of
https://github.com/easydiffusion/easydiffusion.git
synced 2025-01-14 18:28:18 +01:00
Fixes for TensorRT
This commit is contained in:
parent
8538a684e7
commit
d39e1da183
@ -63,7 +63,7 @@ class RenderTask(Task):
|
||||
if (
|
||||
runtime.set_vram_optimizations(context)
|
||||
or self.has_param_changed(context, "clip_skip")
|
||||
or self.has_param_changed(context, "convert_to_tensorrt")
|
||||
or self.trt_needs_reload(context)
|
||||
):
|
||||
models_to_force_reload.append("stable-diffusion")
|
||||
|
||||
@ -92,6 +92,17 @@ class RenderTask(Task):
|
||||
new_val = self.models_data.model_params.get("stable-diffusion", {}).get(param_name, False)
|
||||
return model["params"].get(param_name) != new_val
|
||||
|
||||
def trt_needs_reload(self, context):
|
||||
if not self.has_param_changed(context, "convert_to_tensorrt"):
|
||||
return False
|
||||
|
||||
model = context.models["stable-diffusion"]
|
||||
pipe = model["default"]
|
||||
if hasattr(pipe.unet, "_allocate_trt_buffers"): # TRT already loaded
|
||||
return False
|
||||
|
||||
return True
|
||||
|
||||
|
||||
def make_images(
|
||||
context,
|
||||
@ -148,6 +159,7 @@ def make_images_internal(
|
||||
context,
|
||||
req,
|
||||
task_data,
|
||||
models_data,
|
||||
data_queue,
|
||||
task_temp_images,
|
||||
step_callback,
|
||||
@ -174,6 +186,7 @@ def generate_images_internal(
|
||||
context,
|
||||
req: GenerateImageRequest,
|
||||
task_data: TaskData,
|
||||
models_data: ModelsData,
|
||||
data_queue: queue.Queue,
|
||||
task_temp_images: list,
|
||||
step_callback,
|
||||
@ -197,6 +210,15 @@ def generate_images_internal(
|
||||
if req.init_image is not None and not context.test_diffusers:
|
||||
req.sampler_name = "ddim"
|
||||
|
||||
if context.test_diffusers:
|
||||
pipe = context.models["stable-diffusion"]["default"]
|
||||
if hasattr(pipe.unet, "_allocate_trt_buffers"):
|
||||
convert_to_trt = models_data.model_params["stable-diffusion"].get("convert_to_tensorrt", False)
|
||||
pipe.unet.forward = pipe.unet._trt_forward if convert_to_trt else pipe.unet._non_trt_forward
|
||||
# pipe.vae.decoder.forward = (
|
||||
# pipe.vae.decoder._trt_forward if convert_to_trt else pipe.vae.decoder._non_trt_forward
|
||||
# )
|
||||
|
||||
images = generate_images(context, callback=callback, **req.dict())
|
||||
user_stopped = False
|
||||
except UserInitiatedStop:
|
||||
|
@ -148,7 +148,7 @@
|
||||
<a href="https://github.com/easydiffusion/easydiffusion/wiki/Custom-Models" target="_blank"><i class="fa-solid fa-circle-question help-btn"><span class="simple-tooltip top-left">Click to learn more about custom models</span></i></a>
|
||||
</td></tr>
|
||||
<tr class="pl-5 displayNone" id="enable_trt_config">
|
||||
<td><label for="convert_to_tensorrt">Convert to TensorRT:</label></td>
|
||||
<td><label for="convert_to_tensorrt">Enable TensorRT:</label></td>
|
||||
<td class="diffusers-restart-needed">
|
||||
<input id="convert_to_tensorrt" name="convert_to_tensorrt" type="checkbox">
|
||||
<a href="https://github.com/easydiffusion/easydiffusion/wiki/TensorRT" target="_blank"><i class="fa-solid fa-circle-question help-btn"><span class="simple-tooltip top-left">Click to learn more about TensorRT</span></i></a>
|
||||
|
@ -1346,9 +1346,19 @@ function createTask(task) {
|
||||
|
||||
function getCurrentUserRequest() {
|
||||
const numOutputsTotal = parseInt(numOutputsTotalField.value)
|
||||
const numOutputsParallel = parseInt(numOutputsParallelField.value)
|
||||
let numOutputsParallel = parseInt(numOutputsParallelField.value)
|
||||
const seed = randomSeedField.checked ? Math.floor(Math.random() * (2 ** 32 - 1)) : parseInt(seedField.value)
|
||||
|
||||
if (
|
||||
testDiffusers.checked &&
|
||||
document.getElementById("toggle-tensorrt-install").innerHTML == "Uninstall" &&
|
||||
document.querySelector("#convert_to_tensorrt").checked
|
||||
) {
|
||||
// TRT enabled
|
||||
|
||||
numOutputsParallel = 1 // force 1 parallel
|
||||
}
|
||||
|
||||
const newTask = {
|
||||
batchesDone: 0,
|
||||
numOutputsTotal: numOutputsTotal,
|
||||
|
Loading…
Reference in New Issue
Block a user