diff --git a/ui/easydiffusion/tasks/render_images.py b/ui/easydiffusion/tasks/render_images.py index 8df208b6..ab0917d6 100644 --- a/ui/easydiffusion/tasks/render_images.py +++ b/ui/easydiffusion/tasks/render_images.py @@ -93,15 +93,27 @@ class RenderTask(Task): return model["params"].get(param_name) != new_val def trt_needs_reload(self, context): - if not self.has_param_changed(context, "convert_to_tensorrt"): + if not context.test_diffusers: return False + if "stable-diffusion" not in context.models or "params" not in context.models["stable-diffusion"]: + return True model = context.models["stable-diffusion"] - pipe = model["default"] - if hasattr(pipe.unet, "_allocate_trt_buffers"): # TRT already loaded - return False - return True + # curr_convert_to_trt = model["params"].get("convert_to_tensorrt") + new_convert_to_trt = self.models_data.model_params.get("stable-diffusion", {}).get("convert_to_tensorrt", False) + + pipe = model["default"] + is_trt_loaded = hasattr(pipe.unet, "_allocate_trt_buffers") or hasattr( + pipe.unet, "_allocate_trt_buffers_backup" + ) + if new_convert_to_trt and not is_trt_loaded: + return True + + curr_build_config = model["params"].get("trt_build_config") + new_build_config = self.models_data.model_params.get("stable-diffusion", {}).get("trt_build_config", {}) + + return new_convert_to_trt and curr_build_config != new_build_config def make_images( @@ -215,12 +227,20 @@ def generate_images_internal( if context.test_diffusers: pipe = context.models["stable-diffusion"]["default"] + if hasattr(pipe.unet, "_allocate_trt_buffers_backup"): + setattr(pipe.unet, "_allocate_trt_buffers", pipe.unet._allocate_trt_buffers_backup) + delattr(pipe.unet, "_allocate_trt_buffers_backup") + if hasattr(pipe.unet, "_allocate_trt_buffers"): convert_to_trt = models_data.model_params["stable-diffusion"].get("convert_to_tensorrt", False) - pipe.unet.forward = pipe.unet._trt_forward if convert_to_trt else pipe.unet._non_trt_forward - # pipe.vae.decoder.forward = ( - # pipe.vae.decoder._trt_forward if convert_to_trt else pipe.vae.decoder._non_trt_forward - # ) + if convert_to_trt: + pipe.unet.forward = pipe.unet._trt_forward + # pipe.vae.decoder.forward = pipe.vae.decoder._trt_forward + else: + pipe.unet.forward = pipe.unet._non_trt_forward + # pipe.vae.decoder.forward = pipe.vae.decoder._non_trt_forward + setattr(pipe.unet, "_allocate_trt_buffers_backup", pipe.unet._allocate_trt_buffers) + delattr(pipe.unet, "_allocate_trt_buffers") images = generate_images(context, callback=callback, **req.dict()) user_stopped = False