diff --git a/cmd/config_job_local.go b/cmd/config_job_local.go index 4cb7091..c168c18 100644 --- a/cmd/config_job_local.go +++ b/cmd/config_job_local.go @@ -104,9 +104,9 @@ func (j *LocalJob) JobStart(ctx context.Context) { // We can pay this small performance penalty for now. wildcardMapFilter := NewDatasetMapFilter(1, false) wildcardMapFilter.Add("<", "<") - sender := &endpoint.SenderEndpoint{wildcardMapFilter, NewPrefixFilter(j.SnapshotPrefix)} + sender := &endpoint.Sender{wildcardMapFilter, NewPrefixFilter(j.SnapshotPrefix)} - receiver, err := endpoint.NewReceiverEndpoint(j.Mapping, NewPrefixFilter(j.SnapshotPrefix)) + receiver, err := endpoint.NewReceiver(j.Mapping, NewPrefixFilter(j.SnapshotPrefix)) if err != nil { rootLog.WithError(err).Error("unexpected error setting up local handler") } diff --git a/cmd/config_job_pull.go b/cmd/config_job_pull.go index 6fda29e..255e47c 100644 --- a/cmd/config_job_pull.go +++ b/cmd/config_job_pull.go @@ -8,7 +8,6 @@ import ( "time" "context" - "fmt" "github.com/mitchellh/mapstructure" "github.com/pkg/errors" @@ -176,9 +175,9 @@ func (j *PullJob) doRun(ctx context.Context) { j.task.Enter("pull") - sender := endpoint.RemoteEndpoint{client} + sender := endpoint.NewRemote(client) - puller, err := endpoint.NewReceiverEndpoint( + puller, err := endpoint.NewReceiver( j.Mapping, NewPrefixFilter(j.SnapshotPrefix), ) @@ -188,10 +187,9 @@ func (j *PullJob) doRun(ctx context.Context) { return } - ctx = replication.WithLogger(ctx, replicationLogAdaptor{j.task.Log().WithField("subsystem", "replication")}) ctx = streamrpc.ContextWithLogger(ctx, streamrpcLogAdaptor{j.task.Log().WithField("subsystem", "rpc.protocol")}) - ctx = context.WithValue(ctx, contextKeyLog, j.task.Log().WithField("subsystem", "rpc.endpoint")) + ctx = endpoint.WithLogger(ctx, j.task.Log().WithField("subsystem", "rpc.endpoint")) j.rep = replication.NewReplication() j.rep.Drive(ctx, sender, puller) @@ -229,28 +227,3 @@ func (j *PullJob) Pruner(task *Task, side PrunePolicySide, dryRun bool) (p Prune } return } - -func closeRPCWithTimeout(task *Task, remote endpoint.RemoteEndpoint, timeout time.Duration, goodbye string) { - - task.Log().Info("closing rpc connection") - - ch := make(chan error) - go func() { - remote.Close() - ch <- nil - close(ch) - }() - - var err error - select { - case <-time.After(timeout): - err = fmt.Errorf("timeout exceeded (%s)", timeout) - case closeRequestErr := <-ch: - err = closeRequestErr - } - - if err != nil { - task.Log().WithError(err).Error("error closing connection") - } - return -} diff --git a/cmd/config_job_source.go b/cmd/config_job_source.go index 5a4405a..b91179d 100644 --- a/cmd/config_job_source.go +++ b/cmd/config_job_source.go @@ -212,12 +212,12 @@ func (j *SourceJob) handleConnection(conn net.Conn, task *Task) { task.Log().Info("handling client connection") - senderEP := endpoint.NewSenderEndpoint(j.Filesystems, NewPrefixFilter(j.SnapshotPrefix)) + senderEP := endpoint.NewSender(j.Filesystems, NewPrefixFilter(j.SnapshotPrefix)) ctx := context.Background() ctx = context.WithValue(ctx, contextKeyLog, task.Log().WithField("subsystem", "rpc.endpoint")) ctx = streamrpc.ContextWithLogger(ctx, streamrpcLogAdaptor{task.Log().WithField("subsystem", "rpc.protocol")}) - handler := endpoint.NewHandlerAdaptor(senderEP) + handler := endpoint.NewHandler(senderEP) if err := streamrpc.ServeConn(ctx, conn, STREAMRPC_CONFIG, handler.Handle); err != nil { task.Log().WithError(err).Error("error serving connection") } else { diff --git a/cmd/endpoint/endpoint.go b/cmd/endpoint/endpoint.go index 9ecbf36..9f31412 100644 --- a/cmd/endpoint/endpoint.go +++ b/cmd/endpoint/endpoint.go @@ -1,7 +1,7 @@ +// Package endpoint implements replication endpoints for use with package replication. package endpoint import ( - "fmt" "github.com/zrepl/zrepl/replication/pdu" "github.com/problame/go-streamrpc" "github.com/zrepl/zrepl/zfs" @@ -24,17 +24,17 @@ const ( // FIXME: remove this const DEFAULT_INITIAL_REPL_POLICY = InitialReplPolicyMostRecent -// SenderEndpoint implements replication.ReplicationEndpoint for a sending side -type SenderEndpoint struct { +// Sender implements replication.ReplicationEndpoint for a sending side +type Sender struct { FSFilter zfs.DatasetFilter FilesystemVersionFilter zfs.FilesystemVersionFilter } -func NewSenderEndpoint(fsf zfs.DatasetFilter, fsvf zfs.FilesystemVersionFilter) *SenderEndpoint { - return &SenderEndpoint{fsf, fsvf} +func NewSender(fsf zfs.DatasetFilter, fsvf zfs.FilesystemVersionFilter) *Sender { + return &Sender{fsf, fsvf} } -func (p *SenderEndpoint) ListFilesystems(ctx context.Context) ([]*pdu.Filesystem, error) { +func (p *Sender) ListFilesystems(ctx context.Context) ([]*pdu.Filesystem, error) { fss, err := zfs.ZFSListMapping(p.FSFilter) if err != nil { return nil, err @@ -49,7 +49,7 @@ func (p *SenderEndpoint) ListFilesystems(ctx context.Context) ([]*pdu.Filesystem return rfss, nil } -func (p *SenderEndpoint) ListFilesystemVersions(ctx context.Context, fs string) ([]*pdu.FilesystemVersion, error) { +func (p *Sender) ListFilesystemVersions(ctx context.Context, fs string) ([]*pdu.FilesystemVersion, error) { dp, err := zfs.NewDatasetPath(fs) if err != nil { return nil, err @@ -72,7 +72,7 @@ func (p *SenderEndpoint) ListFilesystemVersions(ctx context.Context, fs string) return rfsvs, nil } -func (p *SenderEndpoint) Send(ctx context.Context, r *pdu.SendReq) (*pdu.SendRes, io.ReadCloser, error) { +func (p *Sender) Send(ctx context.Context, r *pdu.SendReq) (*pdu.SendRes, io.ReadCloser, error) { dp, err := zfs.NewDatasetPath(r.Filesystem) if err != nil { return nil, nil, err @@ -91,10 +91,6 @@ func (p *SenderEndpoint) Send(ctx context.Context, r *pdu.SendReq) (*pdu.SendRes return &pdu.SendRes{}, stream, nil } -func (p *SenderEndpoint) Receive(ctx context.Context, r *pdu.ReceiveReq, sendStream io.ReadCloser) (error) { - return fmt.Errorf("sender endpoint does not receive") -} - type FSFilter interface { Filter(path *zfs.DatasetPath) (pass bool, err error) } @@ -107,22 +103,22 @@ type FSMap interface { AsFilter() (FSFilter) } -// ReceiverEndpoint implements replication.ReplicationEndpoint for a receiving side -type ReceiverEndpoint struct { +// Receiver implements replication.ReplicationEndpoint for a receiving side +type Receiver struct { fsmapInv FSMap fsmap FSMap fsvf zfs.FilesystemVersionFilter } -func NewReceiverEndpoint(fsmap FSMap, fsvf zfs.FilesystemVersionFilter) (*ReceiverEndpoint, error) { +func NewReceiver(fsmap FSMap, fsvf zfs.FilesystemVersionFilter) (*Receiver, error) { fsmapInv, err := fsmap.Invert() if err != nil { return nil, err } - return &ReceiverEndpoint{fsmapInv, fsmap, fsvf}, nil + return &Receiver{fsmapInv, fsmap, fsvf}, nil } -func (e *ReceiverEndpoint) ListFilesystems(ctx context.Context) ([]*pdu.Filesystem, error) { +func (e *Receiver) ListFilesystems(ctx context.Context) ([]*pdu.Filesystem, error) { filtered, err := zfs.ZFSListMapping(e.fsmapInv.AsFilter()) if err != nil { return nil, errors.Wrap(err, "error checking client permission") @@ -138,7 +134,7 @@ func (e *ReceiverEndpoint) ListFilesystems(ctx context.Context) ([]*pdu.Filesyst return fss, nil } -func (e *ReceiverEndpoint) ListFilesystemVersions(ctx context.Context, fs string) ([]*pdu.FilesystemVersion, error) { +func (e *Receiver) ListFilesystemVersions(ctx context.Context, fs string) ([]*pdu.FilesystemVersion, error) { p, err := zfs.NewDatasetPath(fs) if err != nil { return nil, err @@ -164,11 +160,7 @@ func (e *ReceiverEndpoint) ListFilesystemVersions(ctx context.Context, fs string return rfsvs, nil } -func (e *ReceiverEndpoint) Send(ctx context.Context, req *pdu.SendReq) (*pdu.SendRes, io.ReadCloser, error) { - return nil, nil, errors.New("receiver endpoint does not send") -} - -func (e *ReceiverEndpoint) Receive(ctx context.Context, req *pdu.ReceiveReq, sendStream io.ReadCloser) error { +func (e *Receiver) Receive(ctx context.Context, req *pdu.ReceiveReq, sendStream io.ReadCloser) error { defer sendStream.Close() p, err := zfs.NewDatasetPath(req.Filesystem) @@ -246,17 +238,22 @@ const ( RPCSend = "Send" ) -type RemoteEndpoint struct { - *streamrpc.Client +// Remote implements an endpoint stub that uses streamrpc as a transport. +type Remote struct { + c *streamrpc.Client } -func (s RemoteEndpoint) ListFilesystems(ctx context.Context) ([]*pdu.Filesystem, error) { +func NewRemote(c *streamrpc.Client) Remote { + return Remote{c} +} + +func (s Remote) ListFilesystems(ctx context.Context) ([]*pdu.Filesystem, error) { req := pdu.ListFilesystemReq{} b, err := proto.Marshal(&req) if err != nil { return nil, err } - rb, rs, err := s.RequestReply(ctx, RPCListFilesystems, bytes.NewBuffer(b), nil) + rb, rs, err := s.c.RequestReply(ctx, RPCListFilesystems, bytes.NewBuffer(b), nil) if err != nil { return nil, err } @@ -271,7 +268,7 @@ func (s RemoteEndpoint) ListFilesystems(ctx context.Context) ([]*pdu.Filesystem, return res.Filesystems, nil } -func (s RemoteEndpoint) ListFilesystemVersions(ctx context.Context, fs string) ([]*pdu.FilesystemVersion, error) { +func (s Remote) ListFilesystemVersions(ctx context.Context, fs string) ([]*pdu.FilesystemVersion, error) { req := pdu.ListFilesystemVersionsReq{ Filesystem: fs, } @@ -279,7 +276,7 @@ func (s RemoteEndpoint) ListFilesystemVersions(ctx context.Context, fs string) ( if err != nil { return nil, err } - rb, rs, err := s.RequestReply(ctx, RPCListFilesystemVersions, bytes.NewBuffer(b), nil) + rb, rs, err := s.c.RequestReply(ctx, RPCListFilesystemVersions, bytes.NewBuffer(b), nil) if err != nil { return nil, err } @@ -294,12 +291,12 @@ func (s RemoteEndpoint) ListFilesystemVersions(ctx context.Context, fs string) ( return res.Versions, nil } -func (s RemoteEndpoint) Send(ctx context.Context, r *pdu.SendReq) (*pdu.SendRes, io.ReadCloser, error) { +func (s Remote) Send(ctx context.Context, r *pdu.SendReq) (*pdu.SendRes, io.ReadCloser, error) { b, err := proto.Marshal(r) if err != nil { return nil, nil, err } - rb, rs, err := s.RequestReply(ctx, RPCSend, bytes.NewBuffer(b), nil) + rb, rs, err := s.c.RequestReply(ctx, RPCSend, bytes.NewBuffer(b), nil) if err != nil { return nil, nil, err } @@ -315,13 +312,13 @@ func (s RemoteEndpoint) Send(ctx context.Context, r *pdu.SendReq) (*pdu.SendRes, return &res, rs, nil } -func (s RemoteEndpoint) Receive(ctx context.Context, r *pdu.ReceiveReq, sendStream io.ReadCloser) (error) { +func (s Remote) Receive(ctx context.Context, r *pdu.ReceiveReq, sendStream io.ReadCloser) (error) { defer sendStream.Close() b, err := proto.Marshal(r) if err != nil { return err } - rb, rs, err := s.RequestReply(ctx, RPCReceive, bytes.NewBuffer(b), sendStream) + rb, rs, err := s.c.RequestReply(ctx, RPCReceive, bytes.NewBuffer(b), sendStream) if err != nil { return err } @@ -336,15 +333,16 @@ func (s RemoteEndpoint) Receive(ctx context.Context, r *pdu.ReceiveReq, sendStre return nil } -type HandlerAdaptor struct { +// Handler implements the server-side streamrpc.HandlerFunc for a Remote endpoint stub. +type Handler struct { ep replication.Endpoint } -func NewHandlerAdaptor(ep replication.Endpoint) HandlerAdaptor { - return HandlerAdaptor{ep} +func NewHandler(ep replication.Endpoint) Handler { + return Handler{ep} } -func (a *HandlerAdaptor) Handle(ctx context.Context, endpoint string, reqStructured *bytes.Buffer, reqStream io.ReadCloser) (resStructured *bytes.Buffer, resStream io.ReadCloser, err error) { +func (a *Handler) Handle(ctx context.Context, endpoint string, reqStructured *bytes.Buffer, reqStream io.ReadCloser) (resStructured *bytes.Buffer, resStream io.ReadCloser, err error) { switch endpoint { case RPCListFilesystems: