Bug fixes for TRT

This commit is contained in:
cmdr2 2023-08-02 16:37:05 +05:30
parent 801a3dd598
commit 76b7e32125

View File

@ -93,15 +93,27 @@ class RenderTask(Task):
return model["params"].get(param_name) != new_val
def trt_needs_reload(self, context):
if not self.has_param_changed(context, "convert_to_tensorrt"):
if not context.test_diffusers:
return False
if "stable-diffusion" not in context.models or "params" not in context.models["stable-diffusion"]:
return True
model = context.models["stable-diffusion"]
pipe = model["default"]
if hasattr(pipe.unet, "_allocate_trt_buffers"): # TRT already loaded
return False
return True
# curr_convert_to_trt = model["params"].get("convert_to_tensorrt")
new_convert_to_trt = self.models_data.model_params.get("stable-diffusion", {}).get("convert_to_tensorrt", False)
pipe = model["default"]
is_trt_loaded = hasattr(pipe.unet, "_allocate_trt_buffers") or hasattr(
pipe.unet, "_allocate_trt_buffers_backup"
)
if new_convert_to_trt and not is_trt_loaded:
return True
curr_build_config = model["params"].get("trt_build_config")
new_build_config = self.models_data.model_params.get("stable-diffusion", {}).get("trt_build_config", {})
return new_convert_to_trt and curr_build_config != new_build_config
def make_images(
@ -215,12 +227,20 @@ def generate_images_internal(
if context.test_diffusers:
pipe = context.models["stable-diffusion"]["default"]
if hasattr(pipe.unet, "_allocate_trt_buffers_backup"):
setattr(pipe.unet, "_allocate_trt_buffers", pipe.unet._allocate_trt_buffers_backup)
delattr(pipe.unet, "_allocate_trt_buffers_backup")
if hasattr(pipe.unet, "_allocate_trt_buffers"):
convert_to_trt = models_data.model_params["stable-diffusion"].get("convert_to_tensorrt", False)
pipe.unet.forward = pipe.unet._trt_forward if convert_to_trt else pipe.unet._non_trt_forward
# pipe.vae.decoder.forward = (
# pipe.vae.decoder._trt_forward if convert_to_trt else pipe.vae.decoder._non_trt_forward
# )
if convert_to_trt:
pipe.unet.forward = pipe.unet._trt_forward
# pipe.vae.decoder.forward = pipe.vae.decoder._trt_forward
else:
pipe.unet.forward = pipe.unet._non_trt_forward
# pipe.vae.decoder.forward = pipe.vae.decoder._non_trt_forward
setattr(pipe.unet, "_allocate_trt_buffers_backup", pipe.unet._allocate_trt_buffers)
delattr(pipe.unet, "_allocate_trt_buffers")
images = generate_images(context, callback=callback, **req.dict())
user_stopped = False