From 1a8d2c5ebe9acef8a12b89032446d445b030eed4 Mon Sep 17 00:00:00 2001 From: Christian Schwarz Date: Sun, 8 Jul 2018 23:31:46 +0200 Subject: [PATCH] replication: context support and propert closing of stale readers --- cmd/replication.go | 63 ++++++++++++++++------------- cmd/replication/replication.go | 27 +++++++------ cmd/replication/replication_test.go | 6 +-- zfs/zfs.go | 2 +- 4 files changed, 54 insertions(+), 44 deletions(-) diff --git a/cmd/replication.go b/cmd/replication.go index ee0e47a..d333056 100644 --- a/cmd/replication.go +++ b/cmd/replication.go @@ -10,6 +10,7 @@ import ( "github.com/golang/protobuf/proto" "bytes" "os" + "context" ) type InitialReplPolicy string @@ -31,7 +32,7 @@ func NewSenderEndpoint(fsf zfs.DatasetFilter, fsvf zfs.FilesystemVersionFilter) return &SenderEndpoint{fsf, fsvf} } -func (p *SenderEndpoint) ListFilesystems() ([]*replication.Filesystem, error) { +func (p *SenderEndpoint) ListFilesystems(ctx context.Context) ([]*replication.Filesystem, error) { fss, err := zfs.ZFSListMapping(p.FSFilter) if err != nil { return nil, err @@ -46,7 +47,7 @@ func (p *SenderEndpoint) ListFilesystems() ([]*replication.Filesystem, error) { return rfss, nil } -func (p *SenderEndpoint) ListFilesystemVersions(fs string) ([]*replication.FilesystemVersion, error) { +func (p *SenderEndpoint) ListFilesystemVersions(ctx context.Context, fs string) ([]*replication.FilesystemVersion, error) { dp, err := zfs.NewDatasetPath(fs) if err != nil { return nil, err @@ -69,7 +70,7 @@ func (p *SenderEndpoint) ListFilesystemVersions(fs string) ([]*replication.Files return rfsvs, nil } -func (p *SenderEndpoint) Send(r *replication.SendReq) (*replication.SendRes, io.Reader, error) { +func (p *SenderEndpoint) Send(ctx context.Context, r *replication.SendReq) (*replication.SendRes, io.ReadCloser, error) { os.Stderr.WriteString("sending " + r.String() + "\n") dp, err := zfs.NewDatasetPath(r.Filesystem) if err != nil { @@ -89,7 +90,7 @@ func (p *SenderEndpoint) Send(r *replication.SendReq) (*replication.SendRes, io. return &replication.SendRes{}, stream, nil } -func (p *SenderEndpoint) Receive(r *replication.ReceiveReq, sendStream io.Reader) (error) { +func (p *SenderEndpoint) Receive(ctx context.Context, r *replication.ReceiveReq, sendStream io.ReadCloser) (error) { return fmt.Errorf("sender endpoint does not receive") } @@ -109,7 +110,7 @@ func NewReceiverEndpoint(fsmap *DatasetMapFilter, fsvf zfs.FilesystemVersionFilt return &ReceiverEndpoint{fsmapInv, fsmap, fsvf}, nil } -func (e *ReceiverEndpoint) ListFilesystems() ([]*replication.Filesystem, error) { +func (e *ReceiverEndpoint) ListFilesystems(ctx context.Context) ([]*replication.Filesystem, error) { filtered, err := zfs.ZFSListMapping(e.fsmapInv.AsFilter()) if err != nil { return nil, errors.Wrap(err, "error checking client permission") @@ -125,7 +126,7 @@ func (e *ReceiverEndpoint) ListFilesystems() ([]*replication.Filesystem, error) return fss, nil } -func (e *ReceiverEndpoint) ListFilesystemVersions(fs string) ([]*replication.FilesystemVersion, error) { +func (e *ReceiverEndpoint) ListFilesystemVersions(ctx context.Context, fs string) ([]*replication.FilesystemVersion, error) { p, err := zfs.NewDatasetPath(fs) if err != nil { return nil, err @@ -151,11 +152,13 @@ func (e *ReceiverEndpoint) ListFilesystemVersions(fs string) ([]*replication.Fil return rfsvs, nil } -func (e *ReceiverEndpoint) Send(req *replication.SendReq) (*replication.SendRes, io.Reader, error) { +func (e *ReceiverEndpoint) Send(ctx context.Context, req *replication.SendReq) (*replication.SendRes, io.ReadCloser, error) { return nil, nil, errors.New("receiver endpoint does not send") } -func (e *ReceiverEndpoint) Receive(req *replication.ReceiveReq, sendStream io.Reader) error { +func (e *ReceiverEndpoint) Receive(ctx context.Context, req *replication.ReceiveReq, sendStream io.ReadCloser) error { + defer sendStream.Close() + p, err := zfs.NewDatasetPath(req.Filesystem) if err != nil { return err @@ -210,7 +213,6 @@ func (e *ReceiverEndpoint) Receive(req *replication.ReceiveReq, sendStream io.Re os.Stderr.WriteString("receiving...\n") if err := zfs.ZFSRecv(lp.ToString(), sendStream, args...); err != nil { - // FIXME sendStream is on the wire and contains data, if we don't consume it, wire must be closed return err } return nil @@ -232,19 +234,18 @@ type RemoteEndpoint struct { *streamrpc.Client } -func (s RemoteEndpoint) ListFilesystems() ([]*replication.Filesystem, error) { +func (s RemoteEndpoint) ListFilesystems(ctx context.Context) ([]*replication.Filesystem, error) { req := replication.ListFilesystemReq{} b, err := proto.Marshal(&req) if err != nil { return nil, err } - rb, rs, err := s.RequestReply(RPCListFilesystems, bytes.NewBuffer(b), nil) + rb, rs, err := s.RequestReply(ctx, RPCListFilesystems, bytes.NewBuffer(b), nil) if err != nil { return nil, err } if rs != nil { - os.Stderr.WriteString(fmt.Sprintf("%#v\n", rs)) - s.Close() // FIXME + rs.Close() return nil, errors.New("response contains unexpected stream") } var res replication.ListFilesystemRes @@ -254,7 +255,7 @@ func (s RemoteEndpoint) ListFilesystems() ([]*replication.Filesystem, error) { return res.Filesystems, nil } -func (s RemoteEndpoint) ListFilesystemVersions(fs string) ([]*replication.FilesystemVersion, error) { +func (s RemoteEndpoint) ListFilesystemVersions(ctx context.Context, fs string) ([]*replication.FilesystemVersion, error) { req := replication.ListFilesystemVersionsReq{ Filesystem: fs, } @@ -262,12 +263,12 @@ func (s RemoteEndpoint) ListFilesystemVersions(fs string) ([]*replication.Filesy if err != nil { return nil, err } - rb, rs, err := s.RequestReply(RPCListFilesystemVersions, bytes.NewBuffer(b), nil) + rb, rs, err := s.RequestReply(ctx, RPCListFilesystemVersions, bytes.NewBuffer(b), nil) if err != nil { return nil, err } if rs != nil { - s.Close() // FIXME + rs.Close() return nil, errors.New("response contains unexpected stream") } var res replication.ListFilesystemVersionsRes @@ -277,12 +278,12 @@ func (s RemoteEndpoint) ListFilesystemVersions(fs string) ([]*replication.Filesy return res.Versions, nil } -func (s RemoteEndpoint) Send(r *replication.SendReq) (*replication.SendRes, io.Reader, error) { +func (s RemoteEndpoint) Send(ctx context.Context, r *replication.SendReq) (*replication.SendRes, io.ReadCloser, error) { b, err := proto.Marshal(r) if err != nil { return nil, nil, err } - rb, rs, err := s.RequestReply(RPCSend, bytes.NewBuffer(b), nil) + rb, rs, err := s.RequestReply(ctx, RPCSend, bytes.NewBuffer(b), nil) if err != nil { return nil, nil, err } @@ -291,24 +292,25 @@ func (s RemoteEndpoint) Send(r *replication.SendReq) (*replication.SendRes, io.R } var res replication.SendRes if err := proto.Unmarshal(rb.Bytes(), &res); err != nil { - s.Close() // FIXME + rs.Close() return nil, nil, err } // FIXME make sure the consumer will read the reader until the end... return &res, rs, nil } -func (s RemoteEndpoint) Receive(r *replication.ReceiveReq, sendStream io.Reader) (error) { +func (s RemoteEndpoint) Receive(ctx context.Context, r *replication.ReceiveReq, sendStream io.ReadCloser) (error) { + defer sendStream.Close() b, err := proto.Marshal(r) if err != nil { return err } - rb, rs, err := s.RequestReply(RPCReceive, bytes.NewBuffer(b), sendStream) + rb, rs, err := s.RequestReply(ctx, RPCReceive, bytes.NewBuffer(b), sendStream) if err != nil { - s.Close() // FIXME return err } if rs != nil { + rs.Close() return errors.New("response contains unexpected stream") } var res replication.ReceiveRes @@ -320,9 +322,16 @@ func (s RemoteEndpoint) Receive(r *replication.ReceiveReq, sendStream io.Reader) type HandlerAdaptor struct { ep replication.ReplicationEndpoint + log Logger } -func (a *HandlerAdaptor) Handle(endpoint string, reqStructured *bytes.Buffer, reqStream io.Reader) (resStructured *bytes.Buffer, resStream io.Reader, err error) { +func (a *HandlerAdaptor) Handle(endpoint string, reqStructured *bytes.Buffer, reqStream io.ReadCloser) (resStructured *bytes.Buffer, resStream io.ReadCloser, err error) { + + ctx := context.Background() + if a.log != nil { + // FIXME validate type conversion here? + ctx = context.WithValue(ctx, streamrpc.ContextKeyLogger, a.log) + } switch endpoint { case RPCListFilesystems: @@ -330,7 +339,7 @@ func (a *HandlerAdaptor) Handle(endpoint string, reqStructured *bytes.Buffer, re if err := proto.Unmarshal(reqStructured.Bytes(), &req); err != nil { return nil, nil, err } - fsses, err := a.ep.ListFilesystems() + fsses, err := a.ep.ListFilesystems(ctx) if err != nil { return nil, nil, err } @@ -349,7 +358,7 @@ func (a *HandlerAdaptor) Handle(endpoint string, reqStructured *bytes.Buffer, re if err := proto.Unmarshal(reqStructured.Bytes(), &req); err != nil { return nil, nil, err } - fsvs, err := a.ep.ListFilesystemVersions(req.Filesystem) + fsvs, err := a.ep.ListFilesystemVersions(ctx, req.Filesystem) if err != nil { return nil, nil, err } @@ -368,7 +377,7 @@ func (a *HandlerAdaptor) Handle(endpoint string, reqStructured *bytes.Buffer, re if err := proto.Unmarshal(reqStructured.Bytes(), &req); err != nil { return nil, nil, err } - res, sendStream, err := a.ep.Send(&req) + res, sendStream, err := a.ep.Send(ctx, &req) if err != nil { return nil, nil, err } @@ -384,7 +393,7 @@ func (a *HandlerAdaptor) Handle(endpoint string, reqStructured *bytes.Buffer, re if err := proto.Unmarshal(reqStructured.Bytes(), &req); err != nil { return nil, nil, err } - err := a.ep.Receive(&req, reqStream) + err := a.ep.Receive(ctx, &req, reqStream) if err != nil { return nil, nil, err } diff --git a/cmd/replication/replication.go b/cmd/replication/replication.go index 6b3f57b..fb3b154 100644 --- a/cmd/replication/replication.go +++ b/cmd/replication/replication.go @@ -7,8 +7,8 @@ import ( type ReplicationEndpoint interface { // Does not include placeholder filesystems - ListFilesystems() ([]*Filesystem, error) - ListFilesystemVersions(fs string) ([]*FilesystemVersion, error) // fix depS + ListFilesystems(ctx context.Context) ([]*Filesystem, error) + ListFilesystemVersions(ctx context.Context, fs string) ([]*FilesystemVersion, error) // fix depS Sender Receiver } @@ -81,21 +81,22 @@ func Replicate(ctx context.Context, ep EndpointPair, ipr IncrementalPathReplicat log := ctx.Value(ContextKeyLog).(Logger) - sfss, err := ep.Sender().ListFilesystems() + sfss, err := ep.Sender().ListFilesystems(ctx) if err != nil { log.Printf("error listing sender filesystems: %s", err) return } - rfss, err := ep.Receiver().ListFilesystems() + rfss, err := ep.Receiver().ListFilesystems(ctx) if err != nil { log.Printf("error listing receiver filesystems: %s", err) return } for _, fs := range sfss { - log.Printf("replication fs %s", fs.Path) - sfsvs, err := ep.Sender().ListFilesystemVersions(fs.Path) + log.Printf("replicating %s", fs.Path) + + sfsvs, err := ep.Sender().ListFilesystemVersions(ctx, fs.Path) if err != nil { log.Printf("sender error %s", err) continue @@ -115,7 +116,7 @@ func Replicate(ctx context.Context, ep EndpointPair, ipr IncrementalPathReplicat var rfsvs []*FilesystemVersion if receiverFSExists { - rfsvs, err = ep.Receiver().ListFilesystemVersions(fs.Path) + rfsvs, err = ep.Receiver().ListFilesystemVersions(ctx, fs.Path) if err != nil { log.Printf("receiver error %s", err) if _, ok := err.(FilteredError); ok { @@ -162,11 +163,11 @@ func Replicate(ctx context.Context, ep EndpointPair, ipr IncrementalPathReplicat } type Sender interface { - Send(r *SendReq) (*SendRes, io.Reader, error) + Send(ctx context.Context, r *SendReq) (*SendRes, io.ReadCloser, error) } type Receiver interface { - Receive(r *ReceiveReq, sendStream io.Reader) (error) + Receive(ctx context.Context, r *ReceiveReq, sendStream io.ReadCloser) (error) } type Copier interface { @@ -211,7 +212,7 @@ func (incrementalPathReplicator) Replicate(ctx context.Context, sender Sender, r From: path[0].RelName(), ResumeToken: fs.ResumeToken, } - sres, sstream, err := sender.Send(sr) + sres, sstream, err := sender.Send(ctx, sr) if err != nil { log.Printf("send request failed: %s", err) // FIXME must close connection... @@ -222,7 +223,7 @@ func (incrementalPathReplicator) Replicate(ctx context.Context, sender Sender, r Filesystem: fs.Path, ClearResumeToken: fs.ResumeToken != "" && !sres.UsedResumeToken, } - err = receiver.Receive(rr, sstream) + err = receiver.Receive(ctx, rr, sstream) if err != nil { // FIXME this failure could be due to an unexpected exit of ZFS on the sending side // FIXME which is transported through the streamrpc protocol, and known to the sendStream.(*streamrpc.streamReader), @@ -250,7 +251,7 @@ incrementalLoop: To: path[j+1].RelName(), ResumeToken: rt, } - sres, sstream, err := sender.Send(sr) + sres, sstream, err := sender.Send(ctx, sr) if err != nil { log.Printf("send request failed: %s", err) // handle and ignore @@ -262,7 +263,7 @@ incrementalLoop: Filesystem: fs.Path, ClearResumeToken: rt != "" && !sres.UsedResumeToken, } - err = receiver.Receive(rr, sstream) + err = receiver.Receive(ctx, rr, sstream) if err != nil { log.Printf("receive request failed: %s", err) // handle and ignore diff --git a/cmd/replication/replication_test.go b/cmd/replication/replication_test.go index 0334d02..2e6072b 100644 --- a/cmd/replication/replication_test.go +++ b/cmd/replication/replication_test.go @@ -11,7 +11,7 @@ import ( type IncrementalPathSequenceStep struct { SendRequest *replication.SendReq SendResponse *replication.SendRes - SendReader io.Reader + SendReader io.ReadCloser SendError error ReceiveRequest *replication.ReceiveReq ReceiveError error @@ -23,7 +23,7 @@ type MockIncrementalPathRecorder struct { Pos int } -func (m *MockIncrementalPathRecorder) Receive(r *replication.ReceiveReq, rs io.Reader) (error) { +func (m *MockIncrementalPathRecorder) Receive(ctx context.Context, r *replication.ReceiveReq, rs io.ReadCloser) (error) { if m.Pos >= len(m.Sequence) { m.T.Fatal("unexpected Receive") } @@ -35,7 +35,7 @@ func (m *MockIncrementalPathRecorder) Receive(r *replication.ReceiveReq, rs io.R return i.ReceiveError } -func (m *MockIncrementalPathRecorder) Send(r *replication.SendReq) (*replication.SendRes, io.Reader, error) { +func (m *MockIncrementalPathRecorder) Send(ctx context.Context, r *replication.SendReq) (*replication.SendRes, io.ReadCloser, error) { if m.Pos >= len(m.Sequence) { m.T.Fatal("unexpected Send") } diff --git a/zfs/zfs.go b/zfs/zfs.go index b9557fe..2b11da3 100644 --- a/zfs/zfs.go +++ b/zfs/zfs.go @@ -278,7 +278,7 @@ func absVersion(fs, v string) (full string, err error) { return fmt.Sprintf("%s%s", fs, v), nil } -func ZFSSend(fs string, from, to string) (stream io.Reader, err error) { +func ZFSSend(fs string, from, to string) (stream io.ReadCloser, err error) { fromV, err := absVersion(fs, from) if err != nil {