replication: context support and propert closing of stale readers

This commit is contained in:
Christian Schwarz 2018-07-08 23:31:46 +02:00
parent 8cca0a8547
commit 1a8d2c5ebe
4 changed files with 54 additions and 44 deletions

View File

@ -10,6 +10,7 @@ import (
"github.com/golang/protobuf/proto"
"bytes"
"os"
"context"
)
type InitialReplPolicy string
@ -31,7 +32,7 @@ func NewSenderEndpoint(fsf zfs.DatasetFilter, fsvf zfs.FilesystemVersionFilter)
return &SenderEndpoint{fsf, fsvf}
}
func (p *SenderEndpoint) ListFilesystems() ([]*replication.Filesystem, error) {
func (p *SenderEndpoint) ListFilesystems(ctx context.Context) ([]*replication.Filesystem, error) {
fss, err := zfs.ZFSListMapping(p.FSFilter)
if err != nil {
return nil, err
@ -46,7 +47,7 @@ func (p *SenderEndpoint) ListFilesystems() ([]*replication.Filesystem, error) {
return rfss, nil
}
func (p *SenderEndpoint) ListFilesystemVersions(fs string) ([]*replication.FilesystemVersion, error) {
func (p *SenderEndpoint) ListFilesystemVersions(ctx context.Context, fs string) ([]*replication.FilesystemVersion, error) {
dp, err := zfs.NewDatasetPath(fs)
if err != nil {
return nil, err
@ -69,7 +70,7 @@ func (p *SenderEndpoint) ListFilesystemVersions(fs string) ([]*replication.Files
return rfsvs, nil
}
func (p *SenderEndpoint) Send(r *replication.SendReq) (*replication.SendRes, io.Reader, error) {
func (p *SenderEndpoint) Send(ctx context.Context, r *replication.SendReq) (*replication.SendRes, io.ReadCloser, error) {
os.Stderr.WriteString("sending " + r.String() + "\n")
dp, err := zfs.NewDatasetPath(r.Filesystem)
if err != nil {
@ -89,7 +90,7 @@ func (p *SenderEndpoint) Send(r *replication.SendReq) (*replication.SendRes, io.
return &replication.SendRes{}, stream, nil
}
func (p *SenderEndpoint) Receive(r *replication.ReceiveReq, sendStream io.Reader) (error) {
func (p *SenderEndpoint) Receive(ctx context.Context, r *replication.ReceiveReq, sendStream io.ReadCloser) (error) {
return fmt.Errorf("sender endpoint does not receive")
}
@ -109,7 +110,7 @@ func NewReceiverEndpoint(fsmap *DatasetMapFilter, fsvf zfs.FilesystemVersionFilt
return &ReceiverEndpoint{fsmapInv, fsmap, fsvf}, nil
}
func (e *ReceiverEndpoint) ListFilesystems() ([]*replication.Filesystem, error) {
func (e *ReceiverEndpoint) ListFilesystems(ctx context.Context) ([]*replication.Filesystem, error) {
filtered, err := zfs.ZFSListMapping(e.fsmapInv.AsFilter())
if err != nil {
return nil, errors.Wrap(err, "error checking client permission")
@ -125,7 +126,7 @@ func (e *ReceiverEndpoint) ListFilesystems() ([]*replication.Filesystem, error)
return fss, nil
}
func (e *ReceiverEndpoint) ListFilesystemVersions(fs string) ([]*replication.FilesystemVersion, error) {
func (e *ReceiverEndpoint) ListFilesystemVersions(ctx context.Context, fs string) ([]*replication.FilesystemVersion, error) {
p, err := zfs.NewDatasetPath(fs)
if err != nil {
return nil, err
@ -151,11 +152,13 @@ func (e *ReceiverEndpoint) ListFilesystemVersions(fs string) ([]*replication.Fil
return rfsvs, nil
}
func (e *ReceiverEndpoint) Send(req *replication.SendReq) (*replication.SendRes, io.Reader, error) {
func (e *ReceiverEndpoint) Send(ctx context.Context, req *replication.SendReq) (*replication.SendRes, io.ReadCloser, error) {
return nil, nil, errors.New("receiver endpoint does not send")
}
func (e *ReceiverEndpoint) Receive(req *replication.ReceiveReq, sendStream io.Reader) error {
func (e *ReceiverEndpoint) Receive(ctx context.Context, req *replication.ReceiveReq, sendStream io.ReadCloser) error {
defer sendStream.Close()
p, err := zfs.NewDatasetPath(req.Filesystem)
if err != nil {
return err
@ -210,7 +213,6 @@ func (e *ReceiverEndpoint) Receive(req *replication.ReceiveReq, sendStream io.Re
os.Stderr.WriteString("receiving...\n")
if err := zfs.ZFSRecv(lp.ToString(), sendStream, args...); err != nil {
// FIXME sendStream is on the wire and contains data, if we don't consume it, wire must be closed
return err
}
return nil
@ -232,19 +234,18 @@ type RemoteEndpoint struct {
*streamrpc.Client
}
func (s RemoteEndpoint) ListFilesystems() ([]*replication.Filesystem, error) {
func (s RemoteEndpoint) ListFilesystems(ctx context.Context) ([]*replication.Filesystem, error) {
req := replication.ListFilesystemReq{}
b, err := proto.Marshal(&req)
if err != nil {
return nil, err
}
rb, rs, err := s.RequestReply(RPCListFilesystems, bytes.NewBuffer(b), nil)
rb, rs, err := s.RequestReply(ctx, RPCListFilesystems, bytes.NewBuffer(b), nil)
if err != nil {
return nil, err
}
if rs != nil {
os.Stderr.WriteString(fmt.Sprintf("%#v\n", rs))
s.Close() // FIXME
rs.Close()
return nil, errors.New("response contains unexpected stream")
}
var res replication.ListFilesystemRes
@ -254,7 +255,7 @@ func (s RemoteEndpoint) ListFilesystems() ([]*replication.Filesystem, error) {
return res.Filesystems, nil
}
func (s RemoteEndpoint) ListFilesystemVersions(fs string) ([]*replication.FilesystemVersion, error) {
func (s RemoteEndpoint) ListFilesystemVersions(ctx context.Context, fs string) ([]*replication.FilesystemVersion, error) {
req := replication.ListFilesystemVersionsReq{
Filesystem: fs,
}
@ -262,12 +263,12 @@ func (s RemoteEndpoint) ListFilesystemVersions(fs string) ([]*replication.Filesy
if err != nil {
return nil, err
}
rb, rs, err := s.RequestReply(RPCListFilesystemVersions, bytes.NewBuffer(b), nil)
rb, rs, err := s.RequestReply(ctx, RPCListFilesystemVersions, bytes.NewBuffer(b), nil)
if err != nil {
return nil, err
}
if rs != nil {
s.Close() // FIXME
rs.Close()
return nil, errors.New("response contains unexpected stream")
}
var res replication.ListFilesystemVersionsRes
@ -277,12 +278,12 @@ func (s RemoteEndpoint) ListFilesystemVersions(fs string) ([]*replication.Filesy
return res.Versions, nil
}
func (s RemoteEndpoint) Send(r *replication.SendReq) (*replication.SendRes, io.Reader, error) {
func (s RemoteEndpoint) Send(ctx context.Context, r *replication.SendReq) (*replication.SendRes, io.ReadCloser, error) {
b, err := proto.Marshal(r)
if err != nil {
return nil, nil, err
}
rb, rs, err := s.RequestReply(RPCSend, bytes.NewBuffer(b), nil)
rb, rs, err := s.RequestReply(ctx, RPCSend, bytes.NewBuffer(b), nil)
if err != nil {
return nil, nil, err
}
@ -291,24 +292,25 @@ func (s RemoteEndpoint) Send(r *replication.SendReq) (*replication.SendRes, io.R
}
var res replication.SendRes
if err := proto.Unmarshal(rb.Bytes(), &res); err != nil {
s.Close() // FIXME
rs.Close()
return nil, nil, err
}
// FIXME make sure the consumer will read the reader until the end...
return &res, rs, nil
}
func (s RemoteEndpoint) Receive(r *replication.ReceiveReq, sendStream io.Reader) (error) {
func (s RemoteEndpoint) Receive(ctx context.Context, r *replication.ReceiveReq, sendStream io.ReadCloser) (error) {
defer sendStream.Close()
b, err := proto.Marshal(r)
if err != nil {
return err
}
rb, rs, err := s.RequestReply(RPCReceive, bytes.NewBuffer(b), sendStream)
rb, rs, err := s.RequestReply(ctx, RPCReceive, bytes.NewBuffer(b), sendStream)
if err != nil {
s.Close() // FIXME
return err
}
if rs != nil {
rs.Close()
return errors.New("response contains unexpected stream")
}
var res replication.ReceiveRes
@ -320,9 +322,16 @@ func (s RemoteEndpoint) Receive(r *replication.ReceiveReq, sendStream io.Reader)
type HandlerAdaptor struct {
ep replication.ReplicationEndpoint
log Logger
}
func (a *HandlerAdaptor) Handle(endpoint string, reqStructured *bytes.Buffer, reqStream io.Reader) (resStructured *bytes.Buffer, resStream io.Reader, err error) {
func (a *HandlerAdaptor) Handle(endpoint string, reqStructured *bytes.Buffer, reqStream io.ReadCloser) (resStructured *bytes.Buffer, resStream io.ReadCloser, err error) {
ctx := context.Background()
if a.log != nil {
// FIXME validate type conversion here?
ctx = context.WithValue(ctx, streamrpc.ContextKeyLogger, a.log)
}
switch endpoint {
case RPCListFilesystems:
@ -330,7 +339,7 @@ func (a *HandlerAdaptor) Handle(endpoint string, reqStructured *bytes.Buffer, re
if err := proto.Unmarshal(reqStructured.Bytes(), &req); err != nil {
return nil, nil, err
}
fsses, err := a.ep.ListFilesystems()
fsses, err := a.ep.ListFilesystems(ctx)
if err != nil {
return nil, nil, err
}
@ -349,7 +358,7 @@ func (a *HandlerAdaptor) Handle(endpoint string, reqStructured *bytes.Buffer, re
if err := proto.Unmarshal(reqStructured.Bytes(), &req); err != nil {
return nil, nil, err
}
fsvs, err := a.ep.ListFilesystemVersions(req.Filesystem)
fsvs, err := a.ep.ListFilesystemVersions(ctx, req.Filesystem)
if err != nil {
return nil, nil, err
}
@ -368,7 +377,7 @@ func (a *HandlerAdaptor) Handle(endpoint string, reqStructured *bytes.Buffer, re
if err := proto.Unmarshal(reqStructured.Bytes(), &req); err != nil {
return nil, nil, err
}
res, sendStream, err := a.ep.Send(&req)
res, sendStream, err := a.ep.Send(ctx, &req)
if err != nil {
return nil, nil, err
}
@ -384,7 +393,7 @@ func (a *HandlerAdaptor) Handle(endpoint string, reqStructured *bytes.Buffer, re
if err := proto.Unmarshal(reqStructured.Bytes(), &req); err != nil {
return nil, nil, err
}
err := a.ep.Receive(&req, reqStream)
err := a.ep.Receive(ctx, &req, reqStream)
if err != nil {
return nil, nil, err
}

View File

@ -7,8 +7,8 @@ import (
type ReplicationEndpoint interface {
// Does not include placeholder filesystems
ListFilesystems() ([]*Filesystem, error)
ListFilesystemVersions(fs string) ([]*FilesystemVersion, error) // fix depS
ListFilesystems(ctx context.Context) ([]*Filesystem, error)
ListFilesystemVersions(ctx context.Context, fs string) ([]*FilesystemVersion, error) // fix depS
Sender
Receiver
}
@ -81,21 +81,22 @@ func Replicate(ctx context.Context, ep EndpointPair, ipr IncrementalPathReplicat
log := ctx.Value(ContextKeyLog).(Logger)
sfss, err := ep.Sender().ListFilesystems()
sfss, err := ep.Sender().ListFilesystems(ctx)
if err != nil {
log.Printf("error listing sender filesystems: %s", err)
return
}
rfss, err := ep.Receiver().ListFilesystems()
rfss, err := ep.Receiver().ListFilesystems(ctx)
if err != nil {
log.Printf("error listing receiver filesystems: %s", err)
return
}
for _, fs := range sfss {
log.Printf("replication fs %s", fs.Path)
sfsvs, err := ep.Sender().ListFilesystemVersions(fs.Path)
log.Printf("replicating %s", fs.Path)
sfsvs, err := ep.Sender().ListFilesystemVersions(ctx, fs.Path)
if err != nil {
log.Printf("sender error %s", err)
continue
@ -115,7 +116,7 @@ func Replicate(ctx context.Context, ep EndpointPair, ipr IncrementalPathReplicat
var rfsvs []*FilesystemVersion
if receiverFSExists {
rfsvs, err = ep.Receiver().ListFilesystemVersions(fs.Path)
rfsvs, err = ep.Receiver().ListFilesystemVersions(ctx, fs.Path)
if err != nil {
log.Printf("receiver error %s", err)
if _, ok := err.(FilteredError); ok {
@ -162,11 +163,11 @@ func Replicate(ctx context.Context, ep EndpointPair, ipr IncrementalPathReplicat
}
type Sender interface {
Send(r *SendReq) (*SendRes, io.Reader, error)
Send(ctx context.Context, r *SendReq) (*SendRes, io.ReadCloser, error)
}
type Receiver interface {
Receive(r *ReceiveReq, sendStream io.Reader) (error)
Receive(ctx context.Context, r *ReceiveReq, sendStream io.ReadCloser) (error)
}
type Copier interface {
@ -211,7 +212,7 @@ func (incrementalPathReplicator) Replicate(ctx context.Context, sender Sender, r
From: path[0].RelName(),
ResumeToken: fs.ResumeToken,
}
sres, sstream, err := sender.Send(sr)
sres, sstream, err := sender.Send(ctx, sr)
if err != nil {
log.Printf("send request failed: %s", err)
// FIXME must close connection...
@ -222,7 +223,7 @@ func (incrementalPathReplicator) Replicate(ctx context.Context, sender Sender, r
Filesystem: fs.Path,
ClearResumeToken: fs.ResumeToken != "" && !sres.UsedResumeToken,
}
err = receiver.Receive(rr, sstream)
err = receiver.Receive(ctx, rr, sstream)
if err != nil {
// FIXME this failure could be due to an unexpected exit of ZFS on the sending side
// FIXME which is transported through the streamrpc protocol, and known to the sendStream.(*streamrpc.streamReader),
@ -250,7 +251,7 @@ incrementalLoop:
To: path[j+1].RelName(),
ResumeToken: rt,
}
sres, sstream, err := sender.Send(sr)
sres, sstream, err := sender.Send(ctx, sr)
if err != nil {
log.Printf("send request failed: %s", err)
// handle and ignore
@ -262,7 +263,7 @@ incrementalLoop:
Filesystem: fs.Path,
ClearResumeToken: rt != "" && !sres.UsedResumeToken,
}
err = receiver.Receive(rr, sstream)
err = receiver.Receive(ctx, rr, sstream)
if err != nil {
log.Printf("receive request failed: %s", err)
// handle and ignore

View File

@ -11,7 +11,7 @@ import (
type IncrementalPathSequenceStep struct {
SendRequest *replication.SendReq
SendResponse *replication.SendRes
SendReader io.Reader
SendReader io.ReadCloser
SendError error
ReceiveRequest *replication.ReceiveReq
ReceiveError error
@ -23,7 +23,7 @@ type MockIncrementalPathRecorder struct {
Pos int
}
func (m *MockIncrementalPathRecorder) Receive(r *replication.ReceiveReq, rs io.Reader) (error) {
func (m *MockIncrementalPathRecorder) Receive(ctx context.Context, r *replication.ReceiveReq, rs io.ReadCloser) (error) {
if m.Pos >= len(m.Sequence) {
m.T.Fatal("unexpected Receive")
}
@ -35,7 +35,7 @@ func (m *MockIncrementalPathRecorder) Receive(r *replication.ReceiveReq, rs io.R
return i.ReceiveError
}
func (m *MockIncrementalPathRecorder) Send(r *replication.SendReq) (*replication.SendRes, io.Reader, error) {
func (m *MockIncrementalPathRecorder) Send(ctx context.Context, r *replication.SendReq) (*replication.SendRes, io.ReadCloser, error) {
if m.Pos >= len(m.Sequence) {
m.T.Fatal("unexpected Send")
}

View File

@ -278,7 +278,7 @@ func absVersion(fs, v string) (full string, err error) {
return fmt.Sprintf("%s%s", fs, v), nil
}
func ZFSSend(fs string, from, to string) (stream io.Reader, err error) {
func ZFSSend(fs string, from, to string) (stream io.ReadCloser, err error) {
fromV, err := absVersion(fs, from)
if err != nil {