diff --git a/endpoint/endpoint.go b/endpoint/endpoint.go index 50bb409..19355f8 100644 --- a/endpoint/endpoint.go +++ b/endpoint/endpoint.go @@ -198,7 +198,11 @@ func (s *Sender) Send(ctx context.Context, r *pdu.SendReq) (*pdu.SendRes, zfs.St if err != nil { return nil, nil, err } - defer guard.Release() + defer func(guardp **semaphore.AcquireGuard) { + if *guardp != nil { + (*guardp).Release() + } + }(&guard) si, err := zfs.ZFSSendDry(ctx, sendArgs) if err != nil { @@ -255,6 +259,13 @@ func (s *Sender) Send(ctx context.Context, r *pdu.SendReq) (*pdu.SendRes, zfs.St if err != nil { return nil, nil, errors.Wrap(err, "zfs send failed") } + + // defer releasing guard until streamCopier is closed + streamCopier.SetPostCloseCallback(func(_ error) { + guard.Release() + }) + guard = nil + return res, streamCopier, nil } diff --git a/zfs/zfs.go b/zfs/zfs.go index b101288..2a4a104 100644 --- a/zfs/zfs.go +++ b/zfs/zfs.go @@ -350,7 +350,8 @@ func (a ZFSSendArgs) buildCommonSendArgs() ([]string, error) { } type ReadCloserCopier struct { - recorder readErrRecorder + recorder readErrRecorder + postCloseCallback func(closeErr error) } type readErrRecorder struct { @@ -402,8 +403,17 @@ func (c *ReadCloserCopier) Read(p []byte) (n int, err error) { return c.recorder.Read(p) } +// caller must ensure that this function is not executing concurrently to Close +func (c *ReadCloserCopier) SetPostCloseCallback(callback func(closeErr error)) { + c.postCloseCallback = callback +} + func (c *ReadCloserCopier) Close() error { - return c.recorder.ReadCloser.Close() + err := c.recorder.Close() + if c.postCloseCallback != nil { + c.postCloseCallback(err) + } + return err } func pipeWithCapacityHint(capacity int) (r, w *os.File, err error) {