mirror of
https://github.com/easydiffusion/easydiffusion.git
synced 2024-12-27 09:28:56 +01:00
Bug fixes for TRT
This commit is contained in:
parent
801a3dd598
commit
76b7e32125
@ -93,15 +93,27 @@ class RenderTask(Task):
|
||||
return model["params"].get(param_name) != new_val
|
||||
|
||||
def trt_needs_reload(self, context):
|
||||
if not self.has_param_changed(context, "convert_to_tensorrt"):
|
||||
if not context.test_diffusers:
|
||||
return False
|
||||
if "stable-diffusion" not in context.models or "params" not in context.models["stable-diffusion"]:
|
||||
return True
|
||||
|
||||
model = context.models["stable-diffusion"]
|
||||
pipe = model["default"]
|
||||
if hasattr(pipe.unet, "_allocate_trt_buffers"): # TRT already loaded
|
||||
return False
|
||||
|
||||
return True
|
||||
# curr_convert_to_trt = model["params"].get("convert_to_tensorrt")
|
||||
new_convert_to_trt = self.models_data.model_params.get("stable-diffusion", {}).get("convert_to_tensorrt", False)
|
||||
|
||||
pipe = model["default"]
|
||||
is_trt_loaded = hasattr(pipe.unet, "_allocate_trt_buffers") or hasattr(
|
||||
pipe.unet, "_allocate_trt_buffers_backup"
|
||||
)
|
||||
if new_convert_to_trt and not is_trt_loaded:
|
||||
return True
|
||||
|
||||
curr_build_config = model["params"].get("trt_build_config")
|
||||
new_build_config = self.models_data.model_params.get("stable-diffusion", {}).get("trt_build_config", {})
|
||||
|
||||
return new_convert_to_trt and curr_build_config != new_build_config
|
||||
|
||||
|
||||
def make_images(
|
||||
@ -215,12 +227,20 @@ def generate_images_internal(
|
||||
|
||||
if context.test_diffusers:
|
||||
pipe = context.models["stable-diffusion"]["default"]
|
||||
if hasattr(pipe.unet, "_allocate_trt_buffers_backup"):
|
||||
setattr(pipe.unet, "_allocate_trt_buffers", pipe.unet._allocate_trt_buffers_backup)
|
||||
delattr(pipe.unet, "_allocate_trt_buffers_backup")
|
||||
|
||||
if hasattr(pipe.unet, "_allocate_trt_buffers"):
|
||||
convert_to_trt = models_data.model_params["stable-diffusion"].get("convert_to_tensorrt", False)
|
||||
pipe.unet.forward = pipe.unet._trt_forward if convert_to_trt else pipe.unet._non_trt_forward
|
||||
# pipe.vae.decoder.forward = (
|
||||
# pipe.vae.decoder._trt_forward if convert_to_trt else pipe.vae.decoder._non_trt_forward
|
||||
# )
|
||||
if convert_to_trt:
|
||||
pipe.unet.forward = pipe.unet._trt_forward
|
||||
# pipe.vae.decoder.forward = pipe.vae.decoder._trt_forward
|
||||
else:
|
||||
pipe.unet.forward = pipe.unet._non_trt_forward
|
||||
# pipe.vae.decoder.forward = pipe.vae.decoder._non_trt_forward
|
||||
setattr(pipe.unet, "_allocate_trt_buffers_backup", pipe.unet._allocate_trt_buffers)
|
||||
delattr(pipe.unet, "_allocate_trt_buffers")
|
||||
|
||||
images = generate_images(context, callback=callback, **req.dict())
|
||||
user_stopped = False
|
||||
|
Loading…
Reference in New Issue
Block a user