reimplement io.ReadWriteCloser-based RPC mechanism

The existing ByteStreamRPC requires writing RPC stub + server code
for each RPC endpoint. This does not scale well.

Goal: adding a new RPC call should

- not require writing an RPC stub / handler
- not require modifications to the RPC lib

The wire format is inspired by HTTP2, the API by net/rpc.

Frames are used for framing messages, i.e. a message is made up of
multiple frames which are glued together using a frame-bridging
reader / writer. This roughly corresponds to HTTP2 streams, although
we are content with just one stream at any time, which eliminates the
need for flow control, etc.

Frames are typed using a header. The two most important types are
'Header' and 'Data'.
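
For illustration, this is the frame header as it travels on the wire,
mirroring the Frame struct and FrameType constants in rpc/frame_layer.go
below (the fields are written one after another with encoding/binary,
little-endian):

        type FrameType uint8

        const (
                FrameTypeHeader  FrameType = 0x01
                FrameTypeData    FrameType = 0x02
                FrameTypeTrailer FrameType = 0x03
                FrameTypeRST     FrameType = 0xff
        )

        type Frame struct {
                Type          FrameType // 1 byte
                NoMoreFrames  bool      // 1 byte; true on the last frame of a message
                PayloadLength uint32    // 4 bytes; at most MAX_PAYLOAD_LENGTH (4 MiB)
        }

PayloadLength bytes of payload follow each frame header.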

The RPC protocol is built on top of this:

- Client sends a header         => multiple frames of type 'header'
- Client sends request body     => multiple frames of type 'data'
- Server reads a header         => multiple frames of type 'header'
- Server reads request body     => multiple frames of type 'data'
- Server sends response header  => ...
- Server sends response body    => ...

An RPC header is serialized JSON and always has the same structure.
The body is of the type specified in the header.
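
For example, a request header for the FilesystemRequest endpoint with a
JSON body and a JSON reply expected would serialize roughly as follows
(DataTypeMarshaledJSON = 2, see rpc/frame_layer.go below; Error and
ErrorMessage are only meaningful in replies):

        {"Endpoint":"FilesystemRequest","DataType":2,"Accept":2,"Error":0,"ErrorMessage":""}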

The RPC server and client use some semi-fancy reflection techniques to
automatically infer the data type of the request/response body from
the method signature of the server handler or the client's call
parameters, respectively.
This boils down to a special case for io.Reader, which is just dumped
into a series of data frames as efficiently as possible.
All other types are (de)serialized using encoding/json.
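
Concretely, a handler takes a pointer to the request body and a pointer
to the response body and returns only an error; an *io.Reader parameter
selects raw data frames (signatures taken from the handler and client
code in this commit):

        // JSON request body, JSON response body:
        func (h Handler) HandleFilesystemRequest(r *FilesystemRequest, roots *[]*zfs.DatasetPath) (err error)

        // JSON request body, response body streamed as raw data frames:
        func (h Handler) HandleInitialTransferRequest(r *InitialTransferRequest, stream *io.Reader) (err error)

        server.RegisterEndpoint("FilesystemRequest", handler.HandleFilesystemRequest)

        // client side: the type of 'out' determines the Accept data type
        var roots []*zfs.DatasetPath
        err := client.Call("FilesystemRequest", &FilesystemRequest{}, &roots)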

The RPC layer and frame layer log some arbitrary messages that proved
useful during debugging. By default, they log to a no-op logger, which
should not have a big impact on performance.
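
Logging can be enabled per connection via SetLogger (from rpc/client.go
and rpc/server.go; the second parameter additionally enables logging in
the message layer):

        client.SetLogger(logger, true)
        server.SetLogger(logger, false)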

pprof analysis shows the implementation spends its CPU time
        60% waiting for syscalls
        30% in memmove
        10% ...

On an Intel(R) Core(TM) i7-6600U CPU @ 2.60GHz, Linux 4.12, the
implementation achieved ~3.6GiB/s.

Future optimization may include splice(2) / vmsplice(2) on Linux,
although this doesn't fit so well with the heavy use of io.Reader /
io.Writer throughout the codebase.

The existing hackaround for local calls was re-implemented to fit the
new interfaces of RPCServer and RPCClient.
The 'R'PC method invocation is a bit slower because reflection is
involved in between, but otherwise performance should be no different.
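
A minimal sketch of the local path, as implemented in rpc/local.go and
the LocalTransport below:

        local := rpc.NewLocalRPC()
        registerEndpoints(local, handler)
        // same RPCClient interface as the remote case;
        // dispatch happens via reflect.Value.Call on the registered handler
        var fss []*zfs.DatasetPath
        err := local.Call("FilesystemRequest", &FilesystemRequest{}, &fss)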

The RPC code currently does not support multipart requests and thus does
not support the equivalent of a POST.

Thus, the switch to the new rpc code had the following fallout:

- Move request objects + constants from rpc package to main app code
- Sacrifice the hacky 'push = pull me' way of doing push
  -> need to further extend RPC to support multipart requests or
     something to implement this properly with additional interfaces
  -> should be done after replication is abstracted better than
     separate algorithms for doPull() and doPush()
Christian Schwarz, 2017-08-19 22:37:14 +02:00
commit 6ab05ee1fa (parent e5b713ce5b)
12 changed files with 938 additions and 815 deletions

(modified file)

@@ -32,10 +32,9 @@ type Remote struct {
 }
 type Transport interface {
-	Connect(rpcLog Logger) (rpc.RPCRequester, error)
+	Connect(rpcLog Logger) (rpc.RPCClient, error)
 }
 type LocalTransport struct {
-	Handler rpc.RPCHandler
 }
 type SSHTransport struct {
 	Host string
@@ -53,14 +52,14 @@ type Push struct {
 	JobName string // for use with jobrun package
 	To *Remote
 	Filter zfs.DatasetFilter
-	InitialReplPolicy rpc.InitialReplPolicy
+	InitialReplPolicy InitialReplPolicy
 	RepeatStrategy jobrun.RepeatStrategy
 }
 type Pull struct {
 	JobName string // for use with jobrun package
 	From *Remote
 	Mapping DatasetMapFilter
-	InitialReplPolicy rpc.InitialReplPolicy
+	InitialReplPolicy InitialReplPolicy
 	RepeatStrategy jobrun.RepeatStrategy
 }
@@ -161,8 +160,8 @@ func parseRemotes(v interface{}) (remotes map[string]*Remote, err error) {
 	remotes = make(map[string]*Remote, len(asMap))
 	for name, p := range asMap {
-		if name == rpc.LOCAL_TRANSPORT_IDENTITY {
-			err = errors.New(fmt.Sprintf("remote name '%s' reserved for local pulls", rpc.LOCAL_TRANSPORT_IDENTITY))
+		if name == LOCAL_TRANSPORT_IDENTITY {
+			err = errors.New(fmt.Sprintf("remote name '%s' reserved for local pulls", LOCAL_TRANSPORT_IDENTITY))
 			return
 		}
@@ -238,7 +237,7 @@ func parsePushs(v interface{}, rl remoteLookup) (p map[string]*Push, err error)
 		return
 	}
-	if push.InitialReplPolicy, err = parseInitialReplPolicy(e.InitialReplPolicy, rpc.DEFAULT_INITIAL_REPL_POLICY); err != nil {
+	if push.InitialReplPolicy, err = parseInitialReplPolicy(e.InitialReplPolicy, DEFAULT_INITIAL_REPL_POLICY); err != nil {
 		return
 	}
@@ -276,9 +275,9 @@ func parsePulls(v interface{}, rl remoteLookup) (p map[string]*Pull, err error)
 	var fromRemote *Remote
-	if e.From == rpc.LOCAL_TRANSPORT_IDENTITY {
+	if e.From == LOCAL_TRANSPORT_IDENTITY {
 		fromRemote = &Remote{
-			Name: rpc.LOCAL_TRANSPORT_IDENTITY,
+			Name: LOCAL_TRANSPORT_IDENTITY,
 			Transport: LocalTransport{},
 		}
 	} else {
@@ -296,7 +295,7 @@ func parsePulls(v interface{}, rl remoteLookup) (p map[string]*Pull, err error)
 	if pull.Mapping, err = parseDatasetMapFilter(e.Mapping, false); err != nil {
 		return
 	}
-	if pull.InitialReplPolicy, err = parseInitialReplPolicy(e.InitialReplPolicy, rpc.DEFAULT_INITIAL_REPL_POLICY); err != nil {
+	if pull.InitialReplPolicy, err = parseInitialReplPolicy(e.InitialReplPolicy, DEFAULT_INITIAL_REPL_POLICY); err != nil {
 		return
 	}
 	if pull.RepeatStrategy, err = parseRepeatStrategy(e.Repeat); err != nil {
@@ -309,7 +308,7 @@ func parsePulls(v interface{}, rl remoteLookup) (p map[string]*Pull, err error)
 	return
 }
-func parseInitialReplPolicy(v interface{}, defaultPolicy rpc.InitialReplPolicy) (p rpc.InitialReplPolicy, err error) {
+func parseInitialReplPolicy(v interface{}, defaultPolicy InitialReplPolicy) (p InitialReplPolicy, err error) {
 	s, ok := v.(string)
 	if !ok {
 		goto err
@@ -319,9 +318,9 @@ func parseInitialReplPolicy(v interface{}, defaultPolicy rpc.InitialReplPolicy)
 	case s == "":
 		p = defaultPolicy
 	case s == "most_recent":
-		p = rpc.InitialReplPolicyMostRecent
+		p = InitialReplPolicyMostRecent
 	case s == "all":
-		p = rpc.InitialReplPolicyAll
+		p = InitialReplPolicyAll
 	default:
 		goto err
 	}
@@ -434,7 +433,7 @@ func parseDatasetMapFilter(mi interface{}, filterOnly bool) (f DatasetMapFilter,
 	return
 }
-func (t SSHTransport) Connect(rpcLog Logger) (r rpc.RPCRequester, err error) {
+func (t SSHTransport) Connect(rpcLog Logger) (r rpc.RPCClient, err error) {
 	var stream io.ReadWriteCloser
 	var rpcTransport sshbytestream.SSHTransport
 	if err = copier.Copy(&rpcTransport, t); err != nil {
@@ -447,18 +446,23 @@ func (t SSHTransport) Connect(rpcLog Logger) (r rpc.RPCRequester, err error) {
 	if err != nil {
 		return
 	}
-	return rpc.ConnectByteStreamRPC(stream, rpcLog)
+	client := rpc.NewClient(stream)
+	return client, nil
 }
-func (t LocalTransport) Connect(rpcLog Logger) (r rpc.RPCRequester, err error) {
-	if t.Handler == nil {
-		panic("local transport with uninitialized handler")
-	}
-	return rpc.ConnectLocalRPC(t.Handler), nil
-}
-func (t *LocalTransport) SetHandler(handler rpc.RPCHandler) {
-	t.Handler = handler
+func (t LocalTransport) Connect(rpcLog Logger) (r rpc.RPCClient, err error) {
+	local := rpc.NewLocalRPC()
+	handler := Handler{
+		Logger: log,
+		// Allow access to any dataset since we control what mapping
+		// is passed to the pull routine.
+		// All local datasets will be passed to its Map() function,
+		// but only those for which a mapping exists will actually be pulled.
+		// We can pay this small performance penalty for now.
+		PullACL: localPullACL{},
+	}
+	registerEndpoints(local, handler)
+	return local, nil
 }
 func parsePrunes(m interface{}) (rets map[string]*Prune, err error) {
(modified file)

@@ -2,38 +2,79 @@ package cmd
 import (
 	"fmt"
+	"io"
 	"github.com/zrepl/zrepl/rpc"
 	"github.com/zrepl/zrepl/zfs"
-	"io"
 )
 type DatasetMapping interface {
 	Map(source *zfs.DatasetPath) (target *zfs.DatasetPath, err error)
 }
+type FilesystemRequest struct {
+	Roots []string // may be nil, indicating interest in all filesystems
+}
+type FilesystemVersionsRequest struct {
+	Filesystem *zfs.DatasetPath
+}
+type InitialTransferRequest struct {
+	Filesystem *zfs.DatasetPath
+	FilesystemVersion zfs.FilesystemVersion
+}
+type IncrementalTransferRequest struct {
+	Filesystem *zfs.DatasetPath
+	From zfs.FilesystemVersion
+	To zfs.FilesystemVersion
+}
 type Handler struct {
 	Logger Logger
 	PullACL zfs.DatasetFilter
 	SinkMappingFunc func(clientIdentity string) (mapping DatasetMapping, err error)
 }
-func (h Handler) HandleFilesystemRequest(r rpc.FilesystemRequest) (roots []*zfs.DatasetPath, err error) {
+func registerEndpoints(server rpc.RPCServer, handler Handler) (err error) {
+	err = server.RegisterEndpoint("FilesystemRequest", handler.HandleFilesystemRequest)
+	if err != nil {
+		panic(err)
+	}
+	err = server.RegisterEndpoint("FilesystemVersionsRequest", handler.HandleFilesystemVersionsRequest)
+	if err != nil {
+		panic(err)
+	}
+	err = server.RegisterEndpoint("InitialTransferRequest", handler.HandleInitialTransferRequest)
+	if err != nil {
+		panic(err)
+	}
+	err = server.RegisterEndpoint("IncrementalTransferRequest", handler.HandleIncrementalTransferRequest)
+	if err != nil {
+		panic(err)
+	}
+	return nil
+}
+func (h Handler) HandleFilesystemRequest(r *FilesystemRequest, roots *[]*zfs.DatasetPath) (err error) {
 	h.Logger.Printf("handling fsr: %#v", r)
 	h.Logger.Printf("using PullACL: %#v", h.PullACL)
-	if roots, err = zfs.ZFSListMapping(h.PullACL); err != nil {
+	allowed, err := zfs.ZFSListMapping(h.PullACL)
+	if err != nil {
 		h.Logger.Printf("handle fsr err: %v\n", err)
 		return
 	}
-	h.Logger.Printf("returning: %#v", roots)
+	h.Logger.Printf("returning: %#v", allowed)
+	*roots = allowed
 	return
 }
-func (h Handler) HandleFilesystemVersionsRequest(r rpc.FilesystemVersionsRequest) (versions []zfs.FilesystemVersion, err error) {
+func (h Handler) HandleFilesystemVersionsRequest(r *FilesystemVersionsRequest, versions *[]zfs.FilesystemVersion) (err error) {
 	h.Logger.Printf("handling filesystem versions request: %#v", r)
@@ -43,17 +84,20 @@ func (h Handler) HandleFilesystemVersionsRequest(r rpc.FilesystemVersionsRequest
 	}
 	// find our versions
-	if versions, err = zfs.ZFSListFilesystemVersions(r.Filesystem, nil); err != nil {
+	vs, err := zfs.ZFSListFilesystemVersions(r.Filesystem, nil)
+	if err != nil {
 		h.Logger.Printf("our versions error: %#v\n", err)
 		return
 	}
-	h.Logger.Printf("our versions: %#v\n", versions)
+	h.Logger.Printf("our versions: %#v\n", vs)
+	*versions = vs
 	return
 }
-func (h Handler) HandleInitialTransferRequest(r rpc.InitialTransferRequest) (stream io.Reader, err error) {
+func (h Handler) HandleInitialTransferRequest(r *InitialTransferRequest, stream *io.Reader) (err error) {
 	h.Logger.Printf("handling initial transfer request: %#v", r)
 	if err = h.pullACLCheck(r.Filesystem); err != nil {
@@ -62,15 +106,17 @@ func (h Handler) HandleInitialTransferRequest(r rpc.InitialTransferRequest) (str
 	h.Logger.Printf("invoking zfs send")
-	if stream, err = zfs.ZFSSend(r.Filesystem, &r.FilesystemVersion, nil); err != nil {
+	s, err := zfs.ZFSSend(r.Filesystem, &r.FilesystemVersion, nil)
+	if err != nil {
 		h.Logger.Printf("error sending filesystem: %#v", err)
 	}
+	*stream = s
 	return
 }
-func (h Handler) HandleIncrementalTransferRequest(r rpc.IncrementalTransferRequest) (stream io.Reader, err error) {
+func (h Handler) HandleIncrementalTransferRequest(r *IncrementalTransferRequest, stream *io.Reader) (err error) {
 	h.Logger.Printf("handling incremental transfer request: %#v", r)
 	if err = h.pullACLCheck(r.Filesystem); err != nil {
@@ -79,47 +125,16 @@ func (h Handler) HandleIncrementalTransferRequest(r rpc.IncrementalTransferReque
 	h.Logger.Printf("invoking zfs send")
-	if stream, err = zfs.ZFSSend(r.Filesystem, &r.From, &r.To); err != nil {
+	s, err := zfs.ZFSSend(r.Filesystem, &r.From, &r.To)
+	if err != nil {
 		h.Logger.Printf("error sending filesystem: %#v", err)
 	}
+	*stream = s
 	return
 }
-func (h Handler) HandlePullMeRequest(r rpc.PullMeRequest, clientIdentity string, client rpc.RPCRequester) (err error) {
-	// Check if we have a sink for this request
-	// Use that mapping to do what happens in doPull
-	h.Logger.Printf("handling PullMeRequest: %#v", r)
-	var sinkMapping DatasetMapping
-	sinkMapping, err = h.SinkMappingFunc(clientIdentity)
-	if err != nil {
-		h.Logger.Printf("no sink mapping for client identity '%s', denying PullMeRequest", clientIdentity)
-		err = fmt.Errorf("no sink for client identity '%s'", clientIdentity)
-		return
-	}
-	h.Logger.Printf("doing pull...")
-	err = doPull(PullContext{
-		Remote: client,
-		Log: h.Logger,
-		Mapping: sinkMapping,
-		InitialReplPolicy: r.InitialReplPolicy,
-	})
-	if err != nil {
-		h.Logger.Printf("PullMeRequest failed with error: %s", err)
-		return
-	}
-	h.Logger.Printf("finished handling PullMeRequest: %#v", r)
-	return
-}
 func (h Handler) pullACLCheck(p *zfs.DatasetPath) (err error) {
 	var allowed bool
 	allowed, err = h.PullACL.Filter(p)

(modified file)

@@ -148,24 +148,20 @@ func (a localPullACL) Filter(p *zfs.DatasetPath) (pass bool, err error) {
 	return true, nil
 }
+const LOCAL_TRANSPORT_IDENTITY string = "local"
+const DEFAULT_INITIAL_REPL_POLICY = InitialReplPolicyMostRecent
+type InitialReplPolicy string
+const (
+	InitialReplPolicyMostRecent InitialReplPolicy = "most_recent"
+	InitialReplPolicyAll InitialReplPolicy = "all"
+)
 func jobPull(pull *Pull, log jobrun.Logger) (err error) {
-	if lt, ok := pull.From.Transport.(LocalTransport); ok {
-		lt.SetHandler(Handler{
-			Logger: log,
-			// Allow access to any dataset since we control what mapping
-			// is passed to the pull routine.
-			// All local datasets will be passed to its Map() function,
-			// but only those for which a mapping exists will actually be pulled.
-			// We can pay this small performance penalty for now.
-			PullACL: localPullACL{},
-		})
-		pull.From.Transport = lt
-		log.Printf("fixing up local transport: %#v", pull.From.Transport)
-	}
-	var remote rpc.RPCRequester
+	var remote rpc.RPCClient
 	if remote, err = pull.From.Transport.Connect(log); err != nil {
 		return
@@ -182,7 +178,7 @@ func jobPush(push *Push, log jobrun.Logger) (err error) {
 		panic("no support for local pushs")
 	}
-	var remote rpc.RPCRequester
+	var remote rpc.RPCClient
 	if remote, err = push.To.Transport.Connect(log); err != nil {
 		return err
 	}
@@ -197,27 +193,19 @@ func jobPush(push *Push, log jobrun.Logger) (err error) {
 	}
 	log.Printf("handler: %#v", handler)
-	r := rpc.PullMeRequest{
-		InitialReplPolicy: push.InitialReplPolicy,
-	}
-	log.Printf("doing PullMeRequest: %#v", r)
-	if err = remote.PullMeRequest(r, handler); err != nil {
-		log.Printf("PullMeRequest failed: %s", err)
-		return
-	}
+	panic("no support for push atm")
 	log.Printf("push job finished")
 	return
 }
-func closeRPCWithTimeout(log Logger, remote rpc.RPCRequester, timeout time.Duration, goodbye string) {
+func closeRPCWithTimeout(log Logger, remote rpc.RPCClient, timeout time.Duration, goodbye string) {
 	log.Printf("closing rpc connection")
 	ch := make(chan error)
 	go func() {
-		ch <- remote.CloseRequest(rpc.CloseRequest{goodbye})
+		ch <- remote.Close()
 		close(ch)
 	}()
@@ -231,19 +219,15 @@ func closeRPCWithTimeout(log Logger, remote rpc.RPCRequester, timeout time.Durat
 	if err != nil {
 		log.Printf("error closing connection: %s", err)
-		err = remote.ForceClose()
-		if err != nil {
-			log.Printf("error force-closing connection: %s", err)
-		}
 	}
 	return
 }
 type PullContext struct {
-	Remote rpc.RPCRequester
+	Remote rpc.RPCClient
 	Log Logger
 	Mapping DatasetMapping
-	InitialReplPolicy rpc.InitialReplPolicy
+	InitialReplPolicy InitialReplPolicy
 }
 func doPull(pull PullContext) (err error) {
@@ -252,9 +236,9 @@ func doPull(pull PullContext) (err error) {
 	log := pull.Log
 	log.Printf("requesting remote filesystem list")
-	fsr := rpc.FilesystemRequest{}
+	fsr := FilesystemRequest{}
 	var remoteFilesystems []*zfs.DatasetPath
-	if remoteFilesystems, err = remote.FilesystemRequest(fsr); err != nil {
+	if err = remote.Call("FilesystemRequest", &fsr, &remoteFilesystems); err != nil {
 		return
 	}
@@ -335,11 +319,11 @@ func doPull(pull PullContext) (err error) {
 		}
 		log("requesting remote filesystem versions")
-		var theirVersions []zfs.FilesystemVersion
-		theirVersions, err = remote.FilesystemVersionsRequest(rpc.FilesystemVersionsRequest{
+		r := FilesystemVersionsRequest{
 			Filesystem: m.Remote,
-		})
-		if err != nil {
+		}
+		var theirVersions []zfs.FilesystemVersion
+		if err = remote.Call("FilesystemVersionsRequest", &r, &theirVersions); err != nil {
 			log("error requesting remote filesystem versions: %s", err)
 			log("stopping replication for all filesystems mapped as children of %s", m.Local.ToString())
 			return false
@@ -358,7 +342,7 @@ func doPull(pull PullContext) (err error) {
 		log("performing initial sync, following policy: '%s'", pull.InitialReplPolicy)
-		if pull.InitialReplPolicy != rpc.InitialReplPolicyMostRecent {
+		if pull.InitialReplPolicy != InitialReplPolicyMostRecent {
 			panic(fmt.Sprintf("policy '%s' not implemented", pull.InitialReplPolicy))
 		}
@@ -374,7 +358,7 @@ func doPull(pull PullContext) (err error) {
 			return false
 		}
-		r := rpc.InitialTransferRequest{
+		r := InitialTransferRequest{
 			Filesystem: m.Remote,
 			FilesystemVersion: snapsOnly[len(snapsOnly)-1],
 		}
@@ -382,7 +366,8 @@ func doPull(pull PullContext) (err error) {
 		log("requesting snapshot stream for %s", r.FilesystemVersion)
 		var stream io.Reader
-		if stream, err = remote.InitialTransferRequest(r); err != nil {
+		if err = remote.Call("InitialTransferRequest", &r, &stream); err != nil {
 			log("error requesting initial transfer: %s", err)
 			return false
 		}
@@ -434,13 +419,13 @@ func doPull(pull PullContext) (err error) {
 		}
 		log("requesting incremental snapshot stream")
-		r := rpc.IncrementalTransferRequest{
+		r := IncrementalTransferRequest{
 			Filesystem: m.Remote,
 			From: from,
 			To: to,
 		}
 		var stream io.Reader
-		if stream, err = remote.IncrementalTransferRequest(r); err != nil {
+		if err = remote.Call("IncrementalTransferRequest", &r, &stream); err != nil {
 			log("error requesting incremental snapshot stream: %s", err)
 			return false
 		}

(modified file)

@@ -2,12 +2,13 @@ package cmd
 import (
 	"fmt"
-	"github.com/spf13/cobra"
-	"github.com/zrepl/zrepl/rpc"
-	"github.com/zrepl/zrepl/sshbytestream"
 	"io"
 	golog "log"
 	"os"
+	"github.com/spf13/cobra"
+	"github.com/zrepl/zrepl/rpc"
+	"github.com/zrepl/zrepl/sshbytestream"
 )
 var StdinserverCmd = &cobra.Command{
@@ -62,8 +63,10 @@ func cmdStdinServer(cmd *cobra.Command, args []string) {
 		PullACL: pullACL,
 	}
-	if err = rpc.ListenByteStreamRPC(sshByteStream, identity, handler, sinkLogger); err != nil {
-		log.Printf("listenbytestreamerror: %#v\n", err)
+	server := rpc.NewServer(sshByteStream)
+	registerEndpoints(server, handler)
+	if err = server.Serve(); err != nil {
+		log.Printf("error serving connection: %s", err)
 		os.Exit(1)
 	}

rpc/client.go (new file)

@@ -0,0 +1,111 @@
package rpc
import (
"bytes"
"encoding/json"
"io"
"reflect"
"github.com/pkg/errors"
)
type Client struct {
ml *MessageLayer
logger Logger
}
func NewClient(rwc io.ReadWriteCloser) *Client {
return &Client{NewMessageLayer(rwc), noLogger{}}
}
func (c *Client) SetLogger(logger Logger, logMessageLayer bool) {
c.logger = logger
if logMessageLayer {
c.ml.logger = logger
} else {
c.ml.logger = noLogger{}
}
}
func (c *Client) Close() (err error) {
err = c.ml.HangUp()
if err == RST {
return nil
}
return err
}
func (c *Client) recvResponse() (h *Header, err error) {
h, err = c.ml.ReadHeader()
if err != nil {
return nil, errors.Wrap(err, "cannot read header")
}
// TODO validate
return
}
func (c *Client) writeRequest(h *Header) (err error) {
// TODO validate
err = c.ml.WriteHeader(h)
if err != nil {
return errors.Wrap(err, "cannot write header")
}
return
}
func (c *Client) Call(endpoint string, in, out interface{}) (err error) {
var accept DataType
{
outType := reflect.TypeOf(out)
if typeIsIOReaderPtr(outType) {
accept = DataTypeOctets
} else {
accept = DataTypeMarshaledJSON
}
}
h := Header{
Endpoint: endpoint,
DataType: DataTypeMarshaledJSON,
Accept: accept,
}
if err = c.writeRequest(&h); err != nil {
return err
}
var buf bytes.Buffer
if err = json.NewEncoder(&buf).Encode(in); err != nil {
panic("cannot encode 'in' parameter")
}
if err = c.ml.WriteData(&buf); err != nil {
return err
}
rh, err := c.recvResponse()
if err != nil {
return err
}
if rh.Error != StatusOK {
return &RPCError{rh}
}
rd := c.ml.ReadData()
switch accept {
case DataTypeOctets:
c.logger.Printf("setting out to ML data reader")
outPtr := out.(*io.Reader) // we checked that above
*outPtr = rd
case DataTypeMarshaledJSON:
c.logger.Printf("decoding marshaled json")
if err = json.NewDecoder(c.ml.ReadData()).Decode(out); err != nil {
return errors.Wrap(err, "cannot decode marshaled reply")
}
default:
panic("implementation error") // accept is controlled by us
}
return
}

rpc/frame_layer.go (new file)

@@ -0,0 +1,301 @@
package rpc
import (
"bytes"
"encoding/binary"
"encoding/json"
"fmt"
"io"
"github.com/pkg/errors"
)
type Frame struct {
Type FrameType
NoMoreFrames bool
PayloadLength uint32
}
//go:generate stringer -type=FrameType
type FrameType uint8
const (
FrameTypeHeader FrameType = 0x01
FrameTypeData FrameType = 0x02
FrameTypeTrailer FrameType = 0x03
FrameTypeRST FrameType = 0xff
)
type Status uint64
const (
StatusOK Status = 1 + iota
StatusRequestError
StatusServerError
// Returned when an error occurred but the side at fault cannot be determined
StatusError
)
type Header struct {
// Request-only
Endpoint string
// Data type of body (request & reply)
DataType DataType
// Request-only
Accept DataType
// Reply-only
Error Status
// Reply-only
ErrorMessage string
}
func NewErrorHeader(status Status, format string, args ...interface{}) (h *Header) {
h = &Header{}
h.Error = status
h.ErrorMessage = fmt.Sprintf(format, args...)
return
}
type DataType uint8
const (
DataTypeNone DataType = 1 + iota
DataTypeMarshaledJSON
DataTypeOctets
)
const (
MAX_PAYLOAD_LENGTH = 4 * 1024 * 1024
MAX_HEADER_LENGTH = 4 * 1024
)
type frameBridgingReader struct {
l *MessageLayer
frameType FrameType
// < 0 means no limit
bytesLeftToLimit int
f Frame
}
func NewFrameBridgingReader(l *MessageLayer, frameType FrameType, totalLimit int) *frameBridgingReader {
return &frameBridgingReader{l, frameType, totalLimit, Frame{}}
}
func (r *frameBridgingReader) Read(b []byte) (n int, err error) {
if r.bytesLeftToLimit == 0 {
return 0, io.EOF
}
log := r.l.logger
if r.f.PayloadLength == 0 {
if r.f.NoMoreFrames {
err = io.EOF
return
}
log.Printf("reading frame")
r.f, err = r.l.readFrame()
if err != nil {
return 0, err
}
log.Printf("read frame: %#v", r.f)
if r.f.Type != r.frameType {
err = errors.Wrapf(err, "expected frame of type %s", r.frameType)
return 0, err
}
}
maxread := len(b)
if maxread > int(r.f.PayloadLength) {
maxread = int(r.f.PayloadLength)
}
if r.bytesLeftToLimit > 0 && maxread > r.bytesLeftToLimit {
maxread = r.bytesLeftToLimit
}
nb, err := r.l.rwc.Read(b[:maxread])
log.Printf("read %v from rwc\n", nb)
if nb < 0 {
panic("should not return negative number of bytes")
}
r.f.PayloadLength -= uint32(nb)
r.bytesLeftToLimit -= nb
return nb, err // TODO io.EOF for maxread = r.f.PayloadLength ?
}
type frameBridgingWriter struct {
l *MessageLayer
frameType FrameType
// < 0 means no limit
bytesLeftToLimit int
payloadLength int
buffer *bytes.Buffer
}
func NewFrameBridgingWriter(l *MessageLayer, frameType FrameType, totalLimit int) *frameBridgingWriter {
return &frameBridgingWriter{l, frameType, totalLimit, MAX_PAYLOAD_LENGTH, bytes.NewBuffer(make([]byte, 0, MAX_PAYLOAD_LENGTH))}
}
func (w *frameBridgingWriter) Write(b []byte) (n int, err error) {
for n = 0; n < len(b); {
i, err := w.writeUntilFrameFull(b[n:])
n += i
if err != nil {
return n, errors.WithStack(err)
}
}
return
}
func (w *frameBridgingWriter) writeUntilFrameFull(b []byte) (n int, err error) {
if len(b) <= 0 {
return
}
if w.bytesLeftToLimit == 0 {
err = errors.Errorf("exceeded limit of total %v bytes for this message")
return
}
maxwrite := len(b)
remainingInFrame := w.payloadLength - w.buffer.Len()
if maxwrite > remainingInFrame {
maxwrite = remainingInFrame
}
if w.bytesLeftToLimit > 0 && maxwrite > w.bytesLeftToLimit {
maxwrite = w.bytesLeftToLimit
}
w.buffer.Write(b[:maxwrite])
w.bytesLeftToLimit -= maxwrite
n = maxwrite
if w.bytesLeftToLimit == 0 {
err = w.flush(true)
} else if w.buffer.Len() == w.payloadLength {
err = w.flush(false)
}
return
}
func (w *frameBridgingWriter) flush(nomore bool) (err error) {
f := Frame{w.frameType, nomore, uint32(w.buffer.Len())}
err = w.l.writeFrame(f)
if err != nil {
errors.WithStack(err)
}
_, err = w.buffer.WriteTo(w.l.rwc)
return
}
func (w *frameBridgingWriter) Close() (err error) {
return w.flush(true)
}
type MessageLayer struct {
rwc io.ReadWriteCloser
logger Logger
}
func NewMessageLayer(rwc io.ReadWriteCloser) *MessageLayer {
return &MessageLayer{rwc, noLogger{}}
}
// Always returns an error, RST error if no error occurred while sending RST frame
func (l *MessageLayer) HangUp() (err error) {
l.logger.Printf("hanging up")
f := Frame{
Type: FrameTypeRST,
NoMoreFrames: true,
}
rstFrameError := l.writeFrame(f)
closeErr := l.rwc.Close()
if rstFrameError != nil {
return errors.WithStack(rstFrameError)
} else if closeErr != nil {
return errors.WithStack(closeErr)
} else {
return RST
}
}
var RST error = fmt.Errorf("reset frame observed on connection")
func (l *MessageLayer) readFrame() (f Frame, err error) {
err = binary.Read(l.rwc, binary.LittleEndian, &f.Type)
if err != nil {
err = errors.WithStack(err)
return
}
err = binary.Read(l.rwc, binary.LittleEndian, &f.NoMoreFrames)
if err != nil {
err = errors.WithStack(err)
return
}
err = binary.Read(l.rwc, binary.LittleEndian, &f.PayloadLength)
if err != nil {
err = errors.WithStack(err)
return
}
if f.Type == FrameTypeRST {
err = RST
return
}
if f.PayloadLength > MAX_PAYLOAD_LENGTH {
err = errors.Errorf("frame exceeds max payload length")
return
}
return
}
func (l *MessageLayer) writeFrame(f Frame) (err error) {
err = binary.Write(l.rwc, binary.LittleEndian, &f.Type)
if err != nil {
return errors.WithStack(err)
}
err = binary.Write(l.rwc, binary.LittleEndian, &f.NoMoreFrames)
if err != nil {
return errors.WithStack(err)
}
err = binary.Write(l.rwc, binary.LittleEndian, &f.PayloadLength)
if err != nil {
return errors.WithStack(err)
}
if f.PayloadLength > MAX_PAYLOAD_LENGTH {
err = errors.Errorf("frame exceeds max payload length")
return
}
return
}
func (l *MessageLayer) ReadHeader() (h *Header, err error) {
r := NewFrameBridgingReader(l, FrameTypeHeader, MAX_HEADER_LENGTH)
h = &Header{}
if err = json.NewDecoder(r).Decode(&h); err != nil {
l.logger.Printf("cannot decode marshaled header: %s", err)
return nil, err
}
return h, nil
}
func (l *MessageLayer) WriteHeader(h *Header) (err error) {
w := NewFrameBridgingWriter(l, FrameTypeHeader, MAX_HEADER_LENGTH)
err = json.NewEncoder(w).Encode(h)
if err != nil {
return errors.Wrap(err, "cannot encode header, probably fatal")
}
w.Close()
return
}
func (l *MessageLayer) ReadData() (reader io.Reader) {
r := NewFrameBridgingReader(l, FrameTypeData, -1)
return r
}
func (l *MessageLayer) WriteData(source io.Reader) (err error) {
w := NewFrameBridgingWriter(l, FrameTypeData, -1)
_, err = io.Copy(w, source)
if err != nil {
return errors.WithStack(err)
}
err = w.Close()
return
}

rpc/local.go (new file)

@@ -0,0 +1,64 @@
package rpc
import (
"github.com/pkg/errors"
"reflect"
)
type LocalRPC struct {
endpoints map[string]reflect.Value
}
func NewLocalRPC() *LocalRPC {
return &LocalRPC{make(map[string]reflect.Value, 0)}
}
func (s *LocalRPC) RegisterEndpoint(name string, handler interface{}) (err error) {
_, ok := s.endpoints[name]
if ok {
return errors.Errorf("already set up an endpoint for '%s'", name)
}
ep, err := makeEndpointDescr(handler)
if err != nil {
return err
}
s.endpoints[name] = ep.handler
return nil
}
func (s *LocalRPC) Serve() (err error) {
panic("local cannot serve")
return nil
}
func (c *LocalRPC) Call(endpoint string, in, out interface{}) (err error) {
ep, ok := c.endpoints[endpoint]
if !ok {
panic("implementation error: implementation should not call local RPC without knowing which endpoints exist")
}
args := []reflect.Value{reflect.ValueOf(in), reflect.ValueOf(out)}
if err = checkRPCParamTypes(args[0].Type(), args[1].Type()); err != nil {
return
}
rets := ep.Call(args)
if len(rets) != 1 {
panic("implementation error: endpoints must have one error ")
}
if err = checkRPCReturnType(rets[0].Type()); err != nil {
panic(err)
}
err = nil
if !rets[0].IsNil() {
err = rets[0].Interface().(error) // we checked that above
}
return
}
func (c *LocalRPC) Close() (err error) {
return nil
}

(deleted file)

@@ -1,545 +0,0 @@
package rpc
import (
"bytes"
"encoding/json"
"errors"
"fmt"
. "github.com/zrepl/zrepl/util"
"github.com/zrepl/zrepl/zfs"
"io"
"reflect"
"time"
)
type RPCRequester interface {
FilesystemRequest(r FilesystemRequest) (roots []*zfs.DatasetPath, err error)
FilesystemVersionsRequest(r FilesystemVersionsRequest) (versions []zfs.FilesystemVersion, err error)
InitialTransferRequest(r InitialTransferRequest) (io.Reader, error)
IncrementalTransferRequest(r IncrementalTransferRequest) (io.Reader, error)
PullMeRequest(r PullMeRequest, handler RPCHandler) (err error)
CloseRequest(r CloseRequest) (err error)
ForceClose() (err error)
}
type RPCHandler interface {
HandleFilesystemRequest(r FilesystemRequest) (roots []*zfs.DatasetPath, err error)
// returned versions ordered by birthtime, oldest first
HandleFilesystemVersionsRequest(r FilesystemVersionsRequest) (versions []zfs.FilesystemVersion, err error)
HandleInitialTransferRequest(r InitialTransferRequest) (io.Reader, error)
HandleIncrementalTransferRequest(r IncrementalTransferRequest) (io.Reader, error)
// invert roles, i.e. handler becomes server and performs the requested pull using the client connection
HandlePullMeRequest(r PullMeRequest, clientIdentity string, client RPCRequester) (err error)
}
type Logger interface {
Printf(format string, args ...interface{})
}
const ByteStreamRPCProtocolVersion = 1
type ByteStream interface {
io.ReadWriteCloser
}
type ByteStreamRPC struct {
conn ByteStream
log Logger
clientIdentity string
}
func ConnectByteStreamRPC(conn ByteStream, log Logger) (RPCRequester, error) {
rpc := ByteStreamRPC{
conn: conn,
log: log,
}
// Assert protocol versions are equal
err := rpc.ProtocolVersionRequest()
if err != nil {
return nil, err
}
return rpc, nil
}
type ByteStreamRPCDecodeJSONError struct {
Type reflect.Type
DecoderErr error
}
func (e ByteStreamRPCDecodeJSONError) Error() string {
return fmt.Sprintf("cannot decode %s: %s", e.Type, e.DecoderErr)
}
func ListenByteStreamRPC(conn ByteStream, clientIdentity string, handler RPCHandler, log Logger) error {
c := ByteStreamRPC{
conn: conn,
log: log,
clientIdentity: clientIdentity,
}
return c.serverLoop(handler)
}
func (c ByteStreamRPC) serverLoop(handler RPCHandler) error {
// A request consists of two subsequent chunked JSON objects
// Object 1: RequestHeader => contains type of Request Body
// Object 2: RequestBody, e.g. IncrementalTransferRequest
// A response is always a ResponseHeader followed by bytes to be interpreted
// as indicated by the ResponseHeader.ResponseType, e.g.
// a) a chunked response
// b) or another JSON object
conn := c.conn
log := c.log
defer func() {
panicObj := recover()
// if we just exited, we don't want to close the connection (PullMeRequest depends on this)
log.Printf("exiting server loop, panic object %#v", panicObj)
if panicObj != nil {
conn.Close()
}
}()
send := func(r interface{}) {
if err := writeChunkedJSON(conn, r); err != nil {
panic(err)
}
}
sendError := func(id ErrorId, msg string) {
r := ResponseHeader{
ErrorId: id,
ResponseType: RNONE,
Message: msg,
}
log.Printf("sending error response: %#v", r)
if err := writeChunkedJSON(conn, r); err != nil {
log.Printf("error sending error response: %#v", err)
panic(err)
}
}
recv := func(r interface{}) (err error) {
return readChunkedJSON(conn, r)
}
for {
var header RequestHeader = RequestHeader{}
if err := recv(&header); err != nil {
sendError(EDecodeHeader, err.Error())
return conn.Close()
}
switch header.Type {
case RTProtocolVersionRequest:
var rq ByteStreamRPCProtocolVersionRequest
if err := recv(&rq); err != nil {
sendError(EDecodeRequestBody, err.Error())
return conn.Close()
}
if rq.ClientVersion != ByteStreamRPCProtocolVersion {
sendError(EProtocolVersionMismatch, "")
return conn.Close()
}
r := ResponseHeader{
RequestId: header.Id,
ResponseType: ROK,
}
send(&r)
case RTCloseRequest:
var rq CloseRequest
if err := recv(&rq); err != nil {
sendError(EDecodeRequestBody, err.Error())
return conn.Close()
}
log.Printf("close request with goodbye: %s", rq.Goodbye)
send(&ResponseHeader{
RequestId: header.Id,
ResponseType: ROK,
})
return conn.Close()
case RTFilesystemRequest:
var rq FilesystemRequest
if err := recv(&rq); err != nil {
sendError(EDecodeRequestBody, "")
return conn.Close()
}
roots, err := handler.HandleFilesystemRequest(rq)
if err != nil {
sendError(EHandler, err.Error())
return conn.Close()
} else {
r := ResponseHeader{
RequestId: header.Id,
ResponseType: RFilesystems,
}
send(&r)
send(&roots)
}
case RTFilesystemVersionsRequest:
var rq FilesystemVersionsRequest
if err := recv(&rq); err != nil {
sendError(EDecodeRequestBody, err.Error())
return err
}
diff, err := handler.HandleFilesystemVersionsRequest(rq)
if err != nil {
sendError(EHandler, err.Error())
return err
} else {
r := ResponseHeader{
RequestId: header.Id,
ResponseType: RFilesystemDiff,
}
send(&r)
send(&diff)
}
log.Printf("finished FilesystemVersionReqeust")
case RTInitialTransferRequest:
var rq InitialTransferRequest
if err := recv(&rq); err != nil {
sendError(EDecodeRequestBody, "")
return conn.Close()
}
log.Printf("initial transfer request: %#v", rq)
snapReader, err := handler.HandleInitialTransferRequest(rq)
if err != nil {
sendError(EHandler, err.Error())
return conn.Close()
} else {
r := ResponseHeader{
RequestId: header.Id,
ResponseType: RChunkedStream,
}
send(&r)
chunker := NewChunker(snapReader)
watcher := IOProgressWatcher{Reader: &chunker}
watcher.KickOff(1*time.Second, func(p IOProgress) {
log.Printf("progress sending initial snapshot stream: %v bytes sent", p.TotalRX)
})
_, err := io.Copy(conn, &watcher)
if err != nil {
log.Printf("error sending initial snapshot stream: %s", err)
panic(err)
}
log.Printf("finished sending initial snapshot stream: total %v bytes sent", watcher.Progress().TotalRX)
}
case RTIncrementalTransferRequest:
var rq IncrementalTransferRequest
if err := recv(&rq); err != nil {
sendError(EDecodeRequestBody, "")
return conn.Close()
}
snapReader, err := handler.HandleIncrementalTransferRequest(rq)
if err != nil {
sendError(EHandler, err.Error())
} else {
r := ResponseHeader{
RequestId: header.Id,
ResponseType: RChunkedStream,
}
send(&r)
chunker := NewChunker(snapReader)
watcher := IOProgressWatcher{Reader: &chunker}
watcher.KickOff(1*time.Second, func(p IOProgress) {
log.Printf("progress sending incremental snapshot stream: %v bytes sent", p.TotalRX)
})
_, err := io.Copy(conn, &watcher)
if err != nil {
panic(err)
}
log.Printf("finished sending incremental snapshot stream: total %v bytes sent", watcher.Progress().TotalRX)
}
case RTPullMeRequest:
var rq PullMeRequest
if err := recv(&rq); err != nil {
sendError(EDecodeRequestBody, err.Error())
return conn.Close()
}
if rq.Finished {
// we are the client that sent a PullMeRequest with Finished = false
// and then entered this server loop
log.Printf("PullMeRequest.Finished == true, exiting server loop")
send(ResponseHeader{
RequestId: header.Id,
ResponseType: ROK,
})
return nil
}
// We are a server receiving a PullMeRequest from a client
log.Printf("confirming PullMeRequest")
send(ResponseHeader{
RequestId: header.Id,
ResponseType: ROK,
})
log.Printf("pulling from client '%s', expecting client is in server loop", c.clientIdentity)
if c.clientIdentity == "" || c.clientIdentity == LOCAL_TRANSPORT_IDENTITY {
err := fmt.Errorf("client has bad name: '%s'", c.clientIdentity)
log.Printf(err.Error())
panic(err)
}
pullErr := handler.HandlePullMeRequest(rq, c.clientIdentity, c)
if pullErr != nil {
log.Printf("pulling failed with error: %s", pullErr)
panic(pullErr)
}
log.Printf("finished handling PullMeRequest, sending Finished = true")
req := PullMeRequest{Finished: true}
c.sendRequestReceiveHeader(req, ROK)
default:
sendError(EUnknownRequestType, "")
return conn.Close()
}
}
return nil
}
func writeChunkedJSON(conn io.Writer, r interface{}) (err error) {
var buf bytes.Buffer
encoder := json.NewEncoder(&buf)
encoder.Encode(r)
ch := NewChunker(&buf)
_, err = io.Copy(conn, &ch)
return
}
func readChunkedJSON(conn io.ReadWriter, r interface{}) (err error) {
unch := NewUnchunker(conn)
dec := json.NewDecoder(unch)
err = dec.Decode(r)
if err != nil {
err = ByteStreamRPCDecodeJSONError{
Type: reflect.TypeOf(r),
DecoderErr: err,
}
}
closeErr := unch.Close()
if err == nil && closeErr != nil {
err = closeErr
}
return
}
func inferRequestType(v interface{}) (RequestType, error) {
switch v.(type) {
case ByteStreamRPCProtocolVersionRequest:
return RTProtocolVersionRequest, nil
case FilesystemRequest:
return RTFilesystemRequest, nil
case FilesystemVersionsRequest:
return RTFilesystemVersionsRequest, nil
case InitialTransferRequest:
return RTInitialTransferRequest, nil
case IncrementalTransferRequest:
return RTIncrementalTransferRequest, nil
case PullMeRequest:
return RTPullMeRequest, nil
case CloseRequest:
return RTCloseRequest, nil
default:
return 0, errors.New(fmt.Sprintf("cannot infer request type for type '%v'",
reflect.TypeOf(v)))
}
}
func genUUID() [16]byte {
return [16]byte{} // TODO
}
func (c ByteStreamRPC) sendRequest(v interface{}) (err error) {
var rt RequestType
if rt, err = inferRequestType(v); err != nil {
return
}
h := RequestHeader{
Type: rt,
Id: genUUID(),
}
if err = writeChunkedJSON(c.conn, h); err != nil {
return
}
if err = writeChunkedJSON(c.conn, v); err != nil {
return
}
return
}
func (c ByteStreamRPC) expectResponseType(rt ResponseType) (err error) {
var h ResponseHeader
if err = readChunkedJSON(c.conn, &h); err != nil {
return
}
if h.ResponseType != rt {
return errors.New(fmt.Sprintf("unexpected response type in response header: got %#v, expected %#v. response header: %#v",
h.ResponseType, rt, h))
}
return
}
func (c ByteStreamRPC) sendRequestReceiveHeader(request interface{}, rt ResponseType) (err error) {
if err = c.sendRequest(request); err != nil {
return err
}
if err = c.expectResponseType(rt); err != nil {
return err
}
return nil
}
func (c ByteStreamRPC) ProtocolVersionRequest() (err error) {
b := ByteStreamRPCProtocolVersionRequest{
ClientVersion: ByteStreamRPCProtocolVersion,
}
// OK response means the remote side can cope with our protocol version
return c.sendRequestReceiveHeader(b, ROK)
}
func (c ByteStreamRPC) FilesystemRequest(r FilesystemRequest) (roots []*zfs.DatasetPath, err error) {
if err = c.sendRequestReceiveHeader(r, RFilesystems); err != nil {
return
}
roots = make([]*zfs.DatasetPath, 0)
if err = readChunkedJSON(c.conn, &roots); err != nil {
return
}
return
}
func (c ByteStreamRPC) FilesystemVersionsRequest(r FilesystemVersionsRequest) (versions []zfs.FilesystemVersion, err error) {
if err = c.sendRequestReceiveHeader(r, RFilesystemDiff); err != nil {
return
}
err = readChunkedJSON(c.conn, &versions)
return
}
func (c ByteStreamRPC) InitialTransferRequest(r InitialTransferRequest) (unchunker io.Reader, err error) {
if err = c.sendRequestReceiveHeader(r, RChunkedStream); err != nil {
return
}
unchunker = NewUnchunker(c.conn)
return
}
func (c ByteStreamRPC) IncrementalTransferRequest(r IncrementalTransferRequest) (unchunker io.Reader, err error) {
if err = c.sendRequestReceiveHeader(r, RChunkedStream); err != nil {
return
}
unchunker = NewUnchunker(c.conn)
return
}
func (c ByteStreamRPC) PullMeRequest(r PullMeRequest, handler RPCHandler) (err error) {
err = c.sendRequestReceiveHeader(r, ROK)
return c.serverLoop(handler)
}
func (c ByteStreamRPC) CloseRequest(r CloseRequest) (err error) {
if err = c.sendRequestReceiveHeader(r, ROK); err != nil {
return
}
err = c.conn.Close()
return
}
func (c ByteStreamRPC) ForceClose() (err error) {
return c.conn.Close()
}
type LocalRPC struct {
handler RPCHandler
}
func ConnectLocalRPC(handler RPCHandler) RPCRequester {
return LocalRPC{handler}
}
func (c LocalRPC) FilesystemRequest(r FilesystemRequest) (roots []*zfs.DatasetPath, err error) {
return c.handler.HandleFilesystemRequest(r)
}
func (c LocalRPC) FilesystemVersionsRequest(r FilesystemVersionsRequest) (versions []zfs.FilesystemVersion, err error) {
return c.handler.HandleFilesystemVersionsRequest(r)
}
func (c LocalRPC) InitialTransferRequest(r InitialTransferRequest) (io.Reader, error) {
return c.handler.HandleInitialTransferRequest(r)
}
func (c LocalRPC) IncrementalTransferRequest(r IncrementalTransferRequest) (reader io.Reader, err error) {
reader, err = c.handler.HandleIncrementalTransferRequest(r)
return
}
func (c LocalRPC) PullMeRequest(r PullMeRequest, handler RPCHandler) (err error) {
// The config syntactically only allows local Pulls, hence this is never called
// In theory, the following line should work:
// return handler.HandlePullMeRequest(r, LOCAL_TRANSPORT_IDENTITY, c)
panic("internal inconsistency: local pull me request unsupported")
}
func (c LocalRPC) CloseRequest(r CloseRequest) error { return nil }
func (c LocalRPC) ForceClose() error { return nil }

(deleted file)

@@ -1,33 +0,0 @@
package rpc
import (
"bytes"
"github.com/stretchr/testify/assert"
"github.com/zrepl/zrepl/util"
"io"
"strings"
"testing"
)
func TestByteStreamRPCDecodeJSONError(t *testing.T) {
r := strings.NewReader("{'a':'aber'}")
var chunked bytes.Buffer
ch := util.NewChunker(r)
io.Copy(&chunked, &ch)
type SampleType struct {
A uint
}
var s SampleType
err := readChunkedJSON(&chunked, &s)
assert.NotNil(t, err)
_, ok := err.(ByteStreamRPCDecodeJSONError)
if !ok {
t.Errorf("expected ByteStreamRPCDecodeJSONError, got %t\n", err)
t.Errorf("%s\n", err)
}
}

rpc/server.go (new file)

@@ -0,0 +1,226 @@
package rpc
import (
"bytes"
"encoding/json"
"io"
"reflect"
"github.com/pkg/errors"
)
type Server struct {
ml *MessageLayer
logger Logger
endpoints map[string]endpointDescr
}
type typeMap struct {
local reflect.Type
proto DataType
}
type endpointDescr struct {
inType typeMap
outType typeMap
handler reflect.Value
}
type MarshaledJSONEndpoint func(bodyJSON interface{})
func NewServer(rwc io.ReadWriteCloser) *Server {
ml := NewMessageLayer(rwc)
return &Server{
ml, noLogger{}, make(map[string]endpointDescr),
}
}
func (s *Server) SetLogger(logger Logger, logMessageLayer bool) {
s.logger = logger
if logMessageLayer {
s.ml.logger = logger
} else {
s.ml.logger = noLogger{}
}
}
func (s *Server) RegisterEndpoint(name string, handler interface{}) (err error) {
_, ok := s.endpoints[name]
if ok {
return errors.Errorf("already set up an endpoint for '%s'", name)
}
s.endpoints[name], err = makeEndpointDescr(handler)
return
}
func checkResponseHeader(h *Header) (err error) {
var statusNotSet Status
if h.Error == statusNotSet {
return errors.Errorf("status has zero-value")
}
return nil
}
func (s *Server) writeResponse(h *Header) (err error) {
// TODO validate
return s.ml.WriteHeader(h)
}
func (s *Server) recvRequest() (h *Header, err error) {
h, err = s.ml.ReadHeader()
if err != nil {
s.logger.Printf("error reading header: %s", err)
return nil, err
}
s.logger.Printf("validating request")
err = nil // TODO validate
if err == nil {
return h, nil
}
s.logger.Printf("request validation error: %s", err)
r := NewErrorHeader(StatusRequestError, "%s", err)
return nil, s.writeResponse(r)
}
var doneServeNext error = errors.New("this should not cause a HangUp() in the server")
var ProtocolError error = errors.New("protocol error, server should hang up")
// Serve the connection until failure or the client hangs up
func (s *Server) Serve() (err error) {
for {
err = s.ServeRequest()
if err == nil {
continue
}
if err == doneServeNext {
s.logger.Printf("subroutine returned pseudo-error indicating early-exit")
continue
}
s.logger.Printf("hanging up after ServeRequest returned error: %s", err)
s.ml.HangUp()
return err
}
}
// Serve a single request
// * wait for request to come in
// * call handler
// * reply
//
// The connection is left open, the next bytes on the conn should be
// the next request header.
//
// Returns an err != nil if the error is bad enough to hang up on the client.
// Examples: protocol version mismatches, protocol errors in general, ...
// Non-Examples: a handler error
func (s *Server) ServeRequest() (err error) {
ml := s.ml
s.logger.Printf("reading header")
h, err := s.recvRequest()
if err != nil {
return err
}
ep, ok := s.endpoints[h.Endpoint]
if !ok {
r := NewErrorHeader(StatusRequestError, "unregistered endpoint %s", h.Endpoint)
return s.writeResponse(r)
}
if ep.inType.proto != h.DataType {
r := NewErrorHeader(StatusRequestError, "wrong DataType for endpoint %s (has %s, you provided %s)", h.Endpoint, ep.inType.proto, h.DataType)
return s.writeResponse(r)
}
if ep.outType.proto != h.Accept {
r := NewErrorHeader(StatusRequestError, "wrong Accept for endpoint %s (has %s, you provided %s)", h.Endpoint, ep.outType.proto, h.Accept)
return s.writeResponse(r)
}
dr := ml.ReadData()
// Determine inval
var inval reflect.Value
switch ep.inType.proto {
case DataTypeMarshaledJSON:
// Unmarshal input
inval = reflect.New(ep.inType.local.Elem())
invalIface := inval.Interface()
err = json.NewDecoder(dr).Decode(invalIface)
if err != nil {
r := NewErrorHeader(StatusRequestError, "cannot decode marshaled JSON: %s", err)
return s.writeResponse(r)
}
case DataTypeOctets:
// Take data as is
inval = reflect.ValueOf(dr)
default:
panic("not implemented")
}
outval := reflect.New(ep.outType.local.Elem()) // outval is a double pointer
s.logger.Printf("before handler, inval=%v outval=%v", inval, outval)
// Call the handler
errs := ep.handler.Call([]reflect.Value{inval, outval})
if !errs[0].IsNil() {
he := errs[0].Interface().(error) // we checked that before...
s.logger.Printf("handler returned error: %s", err)
r := NewErrorHeader(StatusError, "%s", he.Error())
return s.writeResponse(r)
}
switch ep.outType.proto {
case DataTypeMarshaledJSON:
var dataBuf bytes.Buffer
// Marshal output
err = json.NewEncoder(&dataBuf).Encode(outval.Interface())
if err != nil {
r := NewErrorHeader(StatusServerError, "cannot marshal response: %s", err)
return s.writeResponse(r)
}
replyHeader := Header{
Error: StatusOK,
DataType: ep.outType.proto,
}
if err = s.writeResponse(&replyHeader); err != nil {
return err
}
if err = ml.WriteData(&dataBuf); err != nil {
return
}
case DataTypeOctets:
h := Header{
Error: StatusOK,
DataType: DataTypeOctets,
}
if err = s.writeResponse(&h); err != nil {
return
}
reader := outval.Interface().(*io.Reader) // we checked that when adding the endpoint
err = ml.WriteData(*reader)
if err != nil {
return err
}
}
return nil
}

rpc/shared.go (new file)

@@ -0,0 +1,111 @@
package rpc
import (
"fmt"
"github.com/pkg/errors"
"io"
"reflect"
)
type RPCServer interface {
Serve() (err error)
RegisterEndpoint(name string, handler interface{}) (err error)
}
type RPCClient interface {
Call(endpoint string, in, out interface{}) (err error)
Close() (err error)
}
type Logger interface {
Printf(format string, args ...interface{})
}
type noLogger struct{}
func (l noLogger) Printf(format string, args ...interface{}) {}
func typeIsIOReader(t reflect.Type) bool {
return t == reflect.TypeOf((*io.Reader)(nil)).Elem()
}
func typeIsIOReaderPtr(t reflect.Type) bool {
return t == reflect.TypeOf((*io.Reader)(nil))
}
// An error returned by the Client if the response indicated a status code other than StatusOK
type RPCError struct {
ResponseHeader *Header
}
func (e *RPCError) Error() string {
return fmt.Sprintf("%s: %s", e.ResponseHeader.Error, e.ResponseHeader.ErrorMessage)
}
type RPCProtoError struct {
Message string
UnderlyingError error
}
func (e *RPCProtoError) Error() string {
return e.Message
}
func checkRPCParamTypes(in, out reflect.Type) (err error) {
if !(in.Kind() == reflect.Ptr || typeIsIOReader(in)) {
err = errors.Errorf("input parameter must be a pointer or an io.Reader, is of kind %s, type %s", in.Kind(), in)
return
}
if !(out.Kind() == reflect.Ptr) {
err = errors.Errorf("second input parameter (the non-error output parameter) must be a pointer or an *io.Reader")
return
}
return nil
}
func checkRPCReturnType(rt reflect.Type) (err error) {
errInterfaceType := reflect.TypeOf((*error)(nil)).Elem()
if !rt.Implements(errInterfaceType) {
err = errors.Errorf("handler must return an error")
return
}
return nil
}
func makeEndpointDescr(handler interface{}) (descr endpointDescr, err error) {
ht := reflect.TypeOf(handler)
if ht.Kind() != reflect.Func {
err = errors.Errorf("handler must be of kind reflect.Func")
return
}
if ht.NumIn() != 2 || ht.NumOut() != 1 {
err = errors.Errorf("handler must have exactly two input parameters and one output parameter")
return
}
if err = checkRPCParamTypes(ht.In(0), ht.In(1)); err != nil {
return
}
if err = checkRPCReturnType(ht.Out(0)); err != nil {
return
}
descr.handler = reflect.ValueOf(handler)
descr.inType.local = ht.In(0)
descr.outType.local = ht.In(1)
if typeIsIOReader(ht.In(0)) {
descr.inType.proto = DataTypeOctets
} else {
descr.inType.proto = DataTypeMarshaledJSON
}
if typeIsIOReaderPtr(ht.In(1)) {
descr.outType.proto = DataTypeOctets
} else {
descr.outType.proto = DataTypeMarshaledJSON
}
return
}

(deleted file)

@@ -1,119 +0,0 @@
package rpc
import (
"encoding/json"
"io"
"github.com/zrepl/zrepl/zfs"
)
var _ json.Marshaler = &zfs.DatasetPath{}
var _ json.Unmarshaler = &zfs.DatasetPath{}
type RequestId [16]byte
type RequestType uint8
const (
RTProtocolVersionRequest RequestType = 0x01
RTFilesystemRequest = 0x10
RTFilesystemVersionsRequest = 0x11
RTInitialTransferRequest = 0x12
RTIncrementalTransferRequest = 0x13
RTPullMeRequest = 0x20
RTCloseRequest = 0xf0
)
type RequestHeader struct {
Type RequestType
Id [16]byte // UUID
}
type FilesystemRequest struct {
Roots []string // may be nil, indicating interest in all filesystems
}
type FilesystemVersionsRequest struct {
Filesystem *zfs.DatasetPath
}
type InitialTransferRequest struct {
Filesystem *zfs.DatasetPath
FilesystemVersion zfs.FilesystemVersion
}
func (r InitialTransferRequest) Respond(snapshotReader io.Reader) {
}
type IncrementalTransferRequest struct {
Filesystem *zfs.DatasetPath
From zfs.FilesystemVersion
To zfs.FilesystemVersion
}
func (r IncrementalTransferRequest) Respond(snapshotReader io.Reader) {
}
type ByteStreamRPCProtocolVersionRequest struct {
ClientVersion uint8
}
const LOCAL_TRANSPORT_IDENTITY string = "local"
const DEFAULT_INITIAL_REPL_POLICY = InitialReplPolicyMostRecent
type InitialReplPolicy string
const (
InitialReplPolicyMostRecent InitialReplPolicy = "most_recent"
InitialReplPolicyAll InitialReplPolicy = "all"
)
type PullMeRequest struct {
// if true, the other fields are undefined
Finished bool
InitialReplPolicy InitialReplPolicy
}
type CloseRequest struct {
Goodbye string
}
type ErrorId uint8
const (
ENoError ErrorId = 0
EDecodeHeader = 1
EUnknownRequestType = 2
EDecodeRequestBody = 3
EProtocolVersionMismatch = 4
EHandler = 5
)
type ResponseType uint8
const (
RNONE ResponseType = 0x0
ROK = 0x1
RFilesystems = 0x10
RFilesystemDiff = 0x11
RChunkedStream = 0x20
)
type ResponseHeader struct {
RequestId RequestId
ErrorId ErrorId
Message string
ResponseType ResponseType
}
func NewByteStreamRPCProtocolVersionRequest() ByteStreamRPCProtocolVersionRequest {
return ByteStreamRPCProtocolVersionRequest{
ClientVersion: ByteStreamRPCProtocolVersion,
}
}
func newUUID() [16]byte {
return [16]byte{}
}