mirror of
https://github.com/easydiffusion/easydiffusion.git
synced 2025-06-20 09:57:49 +02:00
Bug fixes for TRT
This commit is contained in:
parent
801a3dd598
commit
76b7e32125
@ -93,15 +93,27 @@ class RenderTask(Task):
|
|||||||
return model["params"].get(param_name) != new_val
|
return model["params"].get(param_name) != new_val
|
||||||
|
|
||||||
def trt_needs_reload(self, context):
|
def trt_needs_reload(self, context):
|
||||||
if not self.has_param_changed(context, "convert_to_tensorrt"):
|
if not context.test_diffusers:
|
||||||
return False
|
return False
|
||||||
|
if "stable-diffusion" not in context.models or "params" not in context.models["stable-diffusion"]:
|
||||||
|
return True
|
||||||
|
|
||||||
model = context.models["stable-diffusion"]
|
model = context.models["stable-diffusion"]
|
||||||
pipe = model["default"]
|
|
||||||
if hasattr(pipe.unet, "_allocate_trt_buffers"): # TRT already loaded
|
|
||||||
return False
|
|
||||||
|
|
||||||
return True
|
# curr_convert_to_trt = model["params"].get("convert_to_tensorrt")
|
||||||
|
new_convert_to_trt = self.models_data.model_params.get("stable-diffusion", {}).get("convert_to_tensorrt", False)
|
||||||
|
|
||||||
|
pipe = model["default"]
|
||||||
|
is_trt_loaded = hasattr(pipe.unet, "_allocate_trt_buffers") or hasattr(
|
||||||
|
pipe.unet, "_allocate_trt_buffers_backup"
|
||||||
|
)
|
||||||
|
if new_convert_to_trt and not is_trt_loaded:
|
||||||
|
return True
|
||||||
|
|
||||||
|
curr_build_config = model["params"].get("trt_build_config")
|
||||||
|
new_build_config = self.models_data.model_params.get("stable-diffusion", {}).get("trt_build_config", {})
|
||||||
|
|
||||||
|
return new_convert_to_trt and curr_build_config != new_build_config
|
||||||
|
|
||||||
|
|
||||||
def make_images(
|
def make_images(
|
||||||
@ -215,12 +227,20 @@ def generate_images_internal(
|
|||||||
|
|
||||||
if context.test_diffusers:
|
if context.test_diffusers:
|
||||||
pipe = context.models["stable-diffusion"]["default"]
|
pipe = context.models["stable-diffusion"]["default"]
|
||||||
|
if hasattr(pipe.unet, "_allocate_trt_buffers_backup"):
|
||||||
|
setattr(pipe.unet, "_allocate_trt_buffers", pipe.unet._allocate_trt_buffers_backup)
|
||||||
|
delattr(pipe.unet, "_allocate_trt_buffers_backup")
|
||||||
|
|
||||||
if hasattr(pipe.unet, "_allocate_trt_buffers"):
|
if hasattr(pipe.unet, "_allocate_trt_buffers"):
|
||||||
convert_to_trt = models_data.model_params["stable-diffusion"].get("convert_to_tensorrt", False)
|
convert_to_trt = models_data.model_params["stable-diffusion"].get("convert_to_tensorrt", False)
|
||||||
pipe.unet.forward = pipe.unet._trt_forward if convert_to_trt else pipe.unet._non_trt_forward
|
if convert_to_trt:
|
||||||
# pipe.vae.decoder.forward = (
|
pipe.unet.forward = pipe.unet._trt_forward
|
||||||
# pipe.vae.decoder._trt_forward if convert_to_trt else pipe.vae.decoder._non_trt_forward
|
# pipe.vae.decoder.forward = pipe.vae.decoder._trt_forward
|
||||||
# )
|
else:
|
||||||
|
pipe.unet.forward = pipe.unet._non_trt_forward
|
||||||
|
# pipe.vae.decoder.forward = pipe.vae.decoder._non_trt_forward
|
||||||
|
setattr(pipe.unet, "_allocate_trt_buffers_backup", pipe.unet._allocate_trt_buffers)
|
||||||
|
delattr(pipe.unet, "_allocate_trt_buffers")
|
||||||
|
|
||||||
images = generate_images(context, callback=callback, **req.dict())
|
images = generate_images(context, callback=callback, **req.dict())
|
||||||
user_stopped = False
|
user_stopped = False
|
||||||
|
Loading…
x
Reference in New Issue
Block a user