replication: context support and propert closing of stale readers

This commit is contained in:
Christian Schwarz 2018-07-08 23:31:46 +02:00
parent 8cca0a8547
commit 1a8d2c5ebe
4 changed files with 54 additions and 44 deletions

View File

@ -10,6 +10,7 @@ import (
"github.com/golang/protobuf/proto" "github.com/golang/protobuf/proto"
"bytes" "bytes"
"os" "os"
"context"
) )
type InitialReplPolicy string type InitialReplPolicy string
@ -31,7 +32,7 @@ func NewSenderEndpoint(fsf zfs.DatasetFilter, fsvf zfs.FilesystemVersionFilter)
return &SenderEndpoint{fsf, fsvf} return &SenderEndpoint{fsf, fsvf}
} }
func (p *SenderEndpoint) ListFilesystems() ([]*replication.Filesystem, error) { func (p *SenderEndpoint) ListFilesystems(ctx context.Context) ([]*replication.Filesystem, error) {
fss, err := zfs.ZFSListMapping(p.FSFilter) fss, err := zfs.ZFSListMapping(p.FSFilter)
if err != nil { if err != nil {
return nil, err return nil, err
@ -46,7 +47,7 @@ func (p *SenderEndpoint) ListFilesystems() ([]*replication.Filesystem, error) {
return rfss, nil return rfss, nil
} }
func (p *SenderEndpoint) ListFilesystemVersions(fs string) ([]*replication.FilesystemVersion, error) { func (p *SenderEndpoint) ListFilesystemVersions(ctx context.Context, fs string) ([]*replication.FilesystemVersion, error) {
dp, err := zfs.NewDatasetPath(fs) dp, err := zfs.NewDatasetPath(fs)
if err != nil { if err != nil {
return nil, err return nil, err
@ -69,7 +70,7 @@ func (p *SenderEndpoint) ListFilesystemVersions(fs string) ([]*replication.Files
return rfsvs, nil return rfsvs, nil
} }
func (p *SenderEndpoint) Send(r *replication.SendReq) (*replication.SendRes, io.Reader, error) { func (p *SenderEndpoint) Send(ctx context.Context, r *replication.SendReq) (*replication.SendRes, io.ReadCloser, error) {
os.Stderr.WriteString("sending " + r.String() + "\n") os.Stderr.WriteString("sending " + r.String() + "\n")
dp, err := zfs.NewDatasetPath(r.Filesystem) dp, err := zfs.NewDatasetPath(r.Filesystem)
if err != nil { if err != nil {
@ -89,7 +90,7 @@ func (p *SenderEndpoint) Send(r *replication.SendReq) (*replication.SendRes, io.
return &replication.SendRes{}, stream, nil return &replication.SendRes{}, stream, nil
} }
func (p *SenderEndpoint) Receive(r *replication.ReceiveReq, sendStream io.Reader) (error) { func (p *SenderEndpoint) Receive(ctx context.Context, r *replication.ReceiveReq, sendStream io.ReadCloser) (error) {
return fmt.Errorf("sender endpoint does not receive") return fmt.Errorf("sender endpoint does not receive")
} }
@ -109,7 +110,7 @@ func NewReceiverEndpoint(fsmap *DatasetMapFilter, fsvf zfs.FilesystemVersionFilt
return &ReceiverEndpoint{fsmapInv, fsmap, fsvf}, nil return &ReceiverEndpoint{fsmapInv, fsmap, fsvf}, nil
} }
func (e *ReceiverEndpoint) ListFilesystems() ([]*replication.Filesystem, error) { func (e *ReceiverEndpoint) ListFilesystems(ctx context.Context) ([]*replication.Filesystem, error) {
filtered, err := zfs.ZFSListMapping(e.fsmapInv.AsFilter()) filtered, err := zfs.ZFSListMapping(e.fsmapInv.AsFilter())
if err != nil { if err != nil {
return nil, errors.Wrap(err, "error checking client permission") return nil, errors.Wrap(err, "error checking client permission")
@ -125,7 +126,7 @@ func (e *ReceiverEndpoint) ListFilesystems() ([]*replication.Filesystem, error)
return fss, nil return fss, nil
} }
func (e *ReceiverEndpoint) ListFilesystemVersions(fs string) ([]*replication.FilesystemVersion, error) { func (e *ReceiverEndpoint) ListFilesystemVersions(ctx context.Context, fs string) ([]*replication.FilesystemVersion, error) {
p, err := zfs.NewDatasetPath(fs) p, err := zfs.NewDatasetPath(fs)
if err != nil { if err != nil {
return nil, err return nil, err
@ -151,11 +152,13 @@ func (e *ReceiverEndpoint) ListFilesystemVersions(fs string) ([]*replication.Fil
return rfsvs, nil return rfsvs, nil
} }
func (e *ReceiverEndpoint) Send(req *replication.SendReq) (*replication.SendRes, io.Reader, error) { func (e *ReceiverEndpoint) Send(ctx context.Context, req *replication.SendReq) (*replication.SendRes, io.ReadCloser, error) {
return nil, nil, errors.New("receiver endpoint does not send") return nil, nil, errors.New("receiver endpoint does not send")
} }
func (e *ReceiverEndpoint) Receive(req *replication.ReceiveReq, sendStream io.Reader) error { func (e *ReceiverEndpoint) Receive(ctx context.Context, req *replication.ReceiveReq, sendStream io.ReadCloser) error {
defer sendStream.Close()
p, err := zfs.NewDatasetPath(req.Filesystem) p, err := zfs.NewDatasetPath(req.Filesystem)
if err != nil { if err != nil {
return err return err
@ -210,7 +213,6 @@ func (e *ReceiverEndpoint) Receive(req *replication.ReceiveReq, sendStream io.Re
os.Stderr.WriteString("receiving...\n") os.Stderr.WriteString("receiving...\n")
if err := zfs.ZFSRecv(lp.ToString(), sendStream, args...); err != nil { if err := zfs.ZFSRecv(lp.ToString(), sendStream, args...); err != nil {
// FIXME sendStream is on the wire and contains data, if we don't consume it, wire must be closed
return err return err
} }
return nil return nil
@ -232,19 +234,18 @@ type RemoteEndpoint struct {
*streamrpc.Client *streamrpc.Client
} }
func (s RemoteEndpoint) ListFilesystems() ([]*replication.Filesystem, error) { func (s RemoteEndpoint) ListFilesystems(ctx context.Context) ([]*replication.Filesystem, error) {
req := replication.ListFilesystemReq{} req := replication.ListFilesystemReq{}
b, err := proto.Marshal(&req) b, err := proto.Marshal(&req)
if err != nil { if err != nil {
return nil, err return nil, err
} }
rb, rs, err := s.RequestReply(RPCListFilesystems, bytes.NewBuffer(b), nil) rb, rs, err := s.RequestReply(ctx, RPCListFilesystems, bytes.NewBuffer(b), nil)
if err != nil { if err != nil {
return nil, err return nil, err
} }
if rs != nil { if rs != nil {
os.Stderr.WriteString(fmt.Sprintf("%#v\n", rs)) rs.Close()
s.Close() // FIXME
return nil, errors.New("response contains unexpected stream") return nil, errors.New("response contains unexpected stream")
} }
var res replication.ListFilesystemRes var res replication.ListFilesystemRes
@ -254,7 +255,7 @@ func (s RemoteEndpoint) ListFilesystems() ([]*replication.Filesystem, error) {
return res.Filesystems, nil return res.Filesystems, nil
} }
func (s RemoteEndpoint) ListFilesystemVersions(fs string) ([]*replication.FilesystemVersion, error) { func (s RemoteEndpoint) ListFilesystemVersions(ctx context.Context, fs string) ([]*replication.FilesystemVersion, error) {
req := replication.ListFilesystemVersionsReq{ req := replication.ListFilesystemVersionsReq{
Filesystem: fs, Filesystem: fs,
} }
@ -262,12 +263,12 @@ func (s RemoteEndpoint) ListFilesystemVersions(fs string) ([]*replication.Filesy
if err != nil { if err != nil {
return nil, err return nil, err
} }
rb, rs, err := s.RequestReply(RPCListFilesystemVersions, bytes.NewBuffer(b), nil) rb, rs, err := s.RequestReply(ctx, RPCListFilesystemVersions, bytes.NewBuffer(b), nil)
if err != nil { if err != nil {
return nil, err return nil, err
} }
if rs != nil { if rs != nil {
s.Close() // FIXME rs.Close()
return nil, errors.New("response contains unexpected stream") return nil, errors.New("response contains unexpected stream")
} }
var res replication.ListFilesystemVersionsRes var res replication.ListFilesystemVersionsRes
@ -277,12 +278,12 @@ func (s RemoteEndpoint) ListFilesystemVersions(fs string) ([]*replication.Filesy
return res.Versions, nil return res.Versions, nil
} }
func (s RemoteEndpoint) Send(r *replication.SendReq) (*replication.SendRes, io.Reader, error) { func (s RemoteEndpoint) Send(ctx context.Context, r *replication.SendReq) (*replication.SendRes, io.ReadCloser, error) {
b, err := proto.Marshal(r) b, err := proto.Marshal(r)
if err != nil { if err != nil {
return nil, nil, err return nil, nil, err
} }
rb, rs, err := s.RequestReply(RPCSend, bytes.NewBuffer(b), nil) rb, rs, err := s.RequestReply(ctx, RPCSend, bytes.NewBuffer(b), nil)
if err != nil { if err != nil {
return nil, nil, err return nil, nil, err
} }
@ -291,24 +292,25 @@ func (s RemoteEndpoint) Send(r *replication.SendReq) (*replication.SendRes, io.R
} }
var res replication.SendRes var res replication.SendRes
if err := proto.Unmarshal(rb.Bytes(), &res); err != nil { if err := proto.Unmarshal(rb.Bytes(), &res); err != nil {
s.Close() // FIXME rs.Close()
return nil, nil, err return nil, nil, err
} }
// FIXME make sure the consumer will read the reader until the end... // FIXME make sure the consumer will read the reader until the end...
return &res, rs, nil return &res, rs, nil
} }
func (s RemoteEndpoint) Receive(r *replication.ReceiveReq, sendStream io.Reader) (error) { func (s RemoteEndpoint) Receive(ctx context.Context, r *replication.ReceiveReq, sendStream io.ReadCloser) (error) {
defer sendStream.Close()
b, err := proto.Marshal(r) b, err := proto.Marshal(r)
if err != nil { if err != nil {
return err return err
} }
rb, rs, err := s.RequestReply(RPCReceive, bytes.NewBuffer(b), sendStream) rb, rs, err := s.RequestReply(ctx, RPCReceive, bytes.NewBuffer(b), sendStream)
if err != nil { if err != nil {
s.Close() // FIXME
return err return err
} }
if rs != nil { if rs != nil {
rs.Close()
return errors.New("response contains unexpected stream") return errors.New("response contains unexpected stream")
} }
var res replication.ReceiveRes var res replication.ReceiveRes
@ -320,9 +322,16 @@ func (s RemoteEndpoint) Receive(r *replication.ReceiveReq, sendStream io.Reader)
type HandlerAdaptor struct { type HandlerAdaptor struct {
ep replication.ReplicationEndpoint ep replication.ReplicationEndpoint
log Logger
} }
func (a *HandlerAdaptor) Handle(endpoint string, reqStructured *bytes.Buffer, reqStream io.Reader) (resStructured *bytes.Buffer, resStream io.Reader, err error) { func (a *HandlerAdaptor) Handle(endpoint string, reqStructured *bytes.Buffer, reqStream io.ReadCloser) (resStructured *bytes.Buffer, resStream io.ReadCloser, err error) {
ctx := context.Background()
if a.log != nil {
// FIXME validate type conversion here?
ctx = context.WithValue(ctx, streamrpc.ContextKeyLogger, a.log)
}
switch endpoint { switch endpoint {
case RPCListFilesystems: case RPCListFilesystems:
@ -330,7 +339,7 @@ func (a *HandlerAdaptor) Handle(endpoint string, reqStructured *bytes.Buffer, re
if err := proto.Unmarshal(reqStructured.Bytes(), &req); err != nil { if err := proto.Unmarshal(reqStructured.Bytes(), &req); err != nil {
return nil, nil, err return nil, nil, err
} }
fsses, err := a.ep.ListFilesystems() fsses, err := a.ep.ListFilesystems(ctx)
if err != nil { if err != nil {
return nil, nil, err return nil, nil, err
} }
@ -349,7 +358,7 @@ func (a *HandlerAdaptor) Handle(endpoint string, reqStructured *bytes.Buffer, re
if err := proto.Unmarshal(reqStructured.Bytes(), &req); err != nil { if err := proto.Unmarshal(reqStructured.Bytes(), &req); err != nil {
return nil, nil, err return nil, nil, err
} }
fsvs, err := a.ep.ListFilesystemVersions(req.Filesystem) fsvs, err := a.ep.ListFilesystemVersions(ctx, req.Filesystem)
if err != nil { if err != nil {
return nil, nil, err return nil, nil, err
} }
@ -368,7 +377,7 @@ func (a *HandlerAdaptor) Handle(endpoint string, reqStructured *bytes.Buffer, re
if err := proto.Unmarshal(reqStructured.Bytes(), &req); err != nil { if err := proto.Unmarshal(reqStructured.Bytes(), &req); err != nil {
return nil, nil, err return nil, nil, err
} }
res, sendStream, err := a.ep.Send(&req) res, sendStream, err := a.ep.Send(ctx, &req)
if err != nil { if err != nil {
return nil, nil, err return nil, nil, err
} }
@ -384,7 +393,7 @@ func (a *HandlerAdaptor) Handle(endpoint string, reqStructured *bytes.Buffer, re
if err := proto.Unmarshal(reqStructured.Bytes(), &req); err != nil { if err := proto.Unmarshal(reqStructured.Bytes(), &req); err != nil {
return nil, nil, err return nil, nil, err
} }
err := a.ep.Receive(&req, reqStream) err := a.ep.Receive(ctx, &req, reqStream)
if err != nil { if err != nil {
return nil, nil, err return nil, nil, err
} }

View File

@ -7,8 +7,8 @@ import (
type ReplicationEndpoint interface { type ReplicationEndpoint interface {
// Does not include placeholder filesystems // Does not include placeholder filesystems
ListFilesystems() ([]*Filesystem, error) ListFilesystems(ctx context.Context) ([]*Filesystem, error)
ListFilesystemVersions(fs string) ([]*FilesystemVersion, error) // fix depS ListFilesystemVersions(ctx context.Context, fs string) ([]*FilesystemVersion, error) // fix depS
Sender Sender
Receiver Receiver
} }
@ -81,21 +81,22 @@ func Replicate(ctx context.Context, ep EndpointPair, ipr IncrementalPathReplicat
log := ctx.Value(ContextKeyLog).(Logger) log := ctx.Value(ContextKeyLog).(Logger)
sfss, err := ep.Sender().ListFilesystems() sfss, err := ep.Sender().ListFilesystems(ctx)
if err != nil { if err != nil {
log.Printf("error listing sender filesystems: %s", err) log.Printf("error listing sender filesystems: %s", err)
return return
} }
rfss, err := ep.Receiver().ListFilesystems() rfss, err := ep.Receiver().ListFilesystems(ctx)
if err != nil { if err != nil {
log.Printf("error listing receiver filesystems: %s", err) log.Printf("error listing receiver filesystems: %s", err)
return return
} }
for _, fs := range sfss { for _, fs := range sfss {
log.Printf("replication fs %s", fs.Path) log.Printf("replicating %s", fs.Path)
sfsvs, err := ep.Sender().ListFilesystemVersions(fs.Path)
sfsvs, err := ep.Sender().ListFilesystemVersions(ctx, fs.Path)
if err != nil { if err != nil {
log.Printf("sender error %s", err) log.Printf("sender error %s", err)
continue continue
@ -115,7 +116,7 @@ func Replicate(ctx context.Context, ep EndpointPair, ipr IncrementalPathReplicat
var rfsvs []*FilesystemVersion var rfsvs []*FilesystemVersion
if receiverFSExists { if receiverFSExists {
rfsvs, err = ep.Receiver().ListFilesystemVersions(fs.Path) rfsvs, err = ep.Receiver().ListFilesystemVersions(ctx, fs.Path)
if err != nil { if err != nil {
log.Printf("receiver error %s", err) log.Printf("receiver error %s", err)
if _, ok := err.(FilteredError); ok { if _, ok := err.(FilteredError); ok {
@ -162,11 +163,11 @@ func Replicate(ctx context.Context, ep EndpointPair, ipr IncrementalPathReplicat
} }
type Sender interface { type Sender interface {
Send(r *SendReq) (*SendRes, io.Reader, error) Send(ctx context.Context, r *SendReq) (*SendRes, io.ReadCloser, error)
} }
type Receiver interface { type Receiver interface {
Receive(r *ReceiveReq, sendStream io.Reader) (error) Receive(ctx context.Context, r *ReceiveReq, sendStream io.ReadCloser) (error)
} }
type Copier interface { type Copier interface {
@ -211,7 +212,7 @@ func (incrementalPathReplicator) Replicate(ctx context.Context, sender Sender, r
From: path[0].RelName(), From: path[0].RelName(),
ResumeToken: fs.ResumeToken, ResumeToken: fs.ResumeToken,
} }
sres, sstream, err := sender.Send(sr) sres, sstream, err := sender.Send(ctx, sr)
if err != nil { if err != nil {
log.Printf("send request failed: %s", err) log.Printf("send request failed: %s", err)
// FIXME must close connection... // FIXME must close connection...
@ -222,7 +223,7 @@ func (incrementalPathReplicator) Replicate(ctx context.Context, sender Sender, r
Filesystem: fs.Path, Filesystem: fs.Path,
ClearResumeToken: fs.ResumeToken != "" && !sres.UsedResumeToken, ClearResumeToken: fs.ResumeToken != "" && !sres.UsedResumeToken,
} }
err = receiver.Receive(rr, sstream) err = receiver.Receive(ctx, rr, sstream)
if err != nil { if err != nil {
// FIXME this failure could be due to an unexpected exit of ZFS on the sending side // FIXME this failure could be due to an unexpected exit of ZFS on the sending side
// FIXME which is transported through the streamrpc protocol, and known to the sendStream.(*streamrpc.streamReader), // FIXME which is transported through the streamrpc protocol, and known to the sendStream.(*streamrpc.streamReader),
@ -250,7 +251,7 @@ incrementalLoop:
To: path[j+1].RelName(), To: path[j+1].RelName(),
ResumeToken: rt, ResumeToken: rt,
} }
sres, sstream, err := sender.Send(sr) sres, sstream, err := sender.Send(ctx, sr)
if err != nil { if err != nil {
log.Printf("send request failed: %s", err) log.Printf("send request failed: %s", err)
// handle and ignore // handle and ignore
@ -262,7 +263,7 @@ incrementalLoop:
Filesystem: fs.Path, Filesystem: fs.Path,
ClearResumeToken: rt != "" && !sres.UsedResumeToken, ClearResumeToken: rt != "" && !sres.UsedResumeToken,
} }
err = receiver.Receive(rr, sstream) err = receiver.Receive(ctx, rr, sstream)
if err != nil { if err != nil {
log.Printf("receive request failed: %s", err) log.Printf("receive request failed: %s", err)
// handle and ignore // handle and ignore

View File

@ -11,7 +11,7 @@ import (
type IncrementalPathSequenceStep struct { type IncrementalPathSequenceStep struct {
SendRequest *replication.SendReq SendRequest *replication.SendReq
SendResponse *replication.SendRes SendResponse *replication.SendRes
SendReader io.Reader SendReader io.ReadCloser
SendError error SendError error
ReceiveRequest *replication.ReceiveReq ReceiveRequest *replication.ReceiveReq
ReceiveError error ReceiveError error
@ -23,7 +23,7 @@ type MockIncrementalPathRecorder struct {
Pos int Pos int
} }
func (m *MockIncrementalPathRecorder) Receive(r *replication.ReceiveReq, rs io.Reader) (error) { func (m *MockIncrementalPathRecorder) Receive(ctx context.Context, r *replication.ReceiveReq, rs io.ReadCloser) (error) {
if m.Pos >= len(m.Sequence) { if m.Pos >= len(m.Sequence) {
m.T.Fatal("unexpected Receive") m.T.Fatal("unexpected Receive")
} }
@ -35,7 +35,7 @@ func (m *MockIncrementalPathRecorder) Receive(r *replication.ReceiveReq, rs io.R
return i.ReceiveError return i.ReceiveError
} }
func (m *MockIncrementalPathRecorder) Send(r *replication.SendReq) (*replication.SendRes, io.Reader, error) { func (m *MockIncrementalPathRecorder) Send(ctx context.Context, r *replication.SendReq) (*replication.SendRes, io.ReadCloser, error) {
if m.Pos >= len(m.Sequence) { if m.Pos >= len(m.Sequence) {
m.T.Fatal("unexpected Send") m.T.Fatal("unexpected Send")
} }

View File

@ -278,7 +278,7 @@ func absVersion(fs, v string) (full string, err error) {
return fmt.Sprintf("%s%s", fs, v), nil return fmt.Sprintf("%s%s", fs, v), nil
} }
func ZFSSend(fs string, from, to string) (stream io.Reader, err error) { func ZFSSend(fs string, from, to string) (stream io.ReadCloser, err error) {
fromV, err := absVersion(fs, from) fromV, err := absVersion(fs, from)
if err != nil { if err != nil {