zrepl/rpc/rpc.go

306 lines
7.1 KiB
Go
Raw Normal View History

2017-04-14 19:26:32 +02:00
package rpc
import (
"encoding/json"
2017-04-16 21:38:31 +02:00
"errors"
"fmt"
2017-04-26 20:25:53 +02:00
. "github.com/zrepl/zrepl/model"
. "github.com/zrepl/zrepl/util"
"io"
2017-04-16 21:38:31 +02:00
"reflect"
)
2017-04-14 19:26:32 +02:00
type RPCRequester interface {
2017-04-16 21:38:31 +02:00
FilesystemRequest(r FilesystemRequest) (roots []Filesystem, err error)
InitialTransferRequest(r InitialTransferRequest) (io.Reader, error)
IncrementalTransferRequest(r IncrementalTransferRequest) (io.Reader, error)
2017-04-14 19:26:32 +02:00
}
type RPCHandler interface {
2017-04-16 21:38:31 +02:00
HandleFilesystemRequest(r FilesystemRequest) (roots []Filesystem, err error)
HandleInitialTransferRequest(r InitialTransferRequest) (io.Reader, error)
HandleIncrementalTransferRequest(r IncrementalTransferRequest) (io.Reader, error)
2017-04-14 19:26:32 +02:00
}
// ByteStreamRPCProtocolVersion is the wire protocol version spoken by
// ByteStreamRPC. The client announces it in a
// ByteStreamRPCProtocolVersionRequest, and ListenByteStreamRPC rejects any
// other value with EProtocolVersionMismatch.
const (
	ByteStreamRPCProtocolVersion = 1
)
2017-04-14 19:26:32 +02:00
// ByteStreamRPC is the client side of the byte stream RPC protocol. Requests
// are written to conn as JSON; response headers are read back as JSON, and
// transfer payloads are read from conn as chunked streams.
type ByteStreamRPC struct {
	// conn is the bidirectional byte stream shared with the server.
	conn io.ReadWriteCloser
	// encoder writes JSON values to conn.
	encoder *json.Encoder
	// decoder reads JSON values from conn.
	decoder *json.Decoder
}
2017-04-16 21:38:31 +02:00
func ConnectByteStreamRPC(conn io.ReadWriteCloser) (RPCRequester, error) {
2017-04-14 19:26:32 +02:00
// TODO do ssh connection to transport, establish TCP-like communication channel
rpc := ByteStreamRPC{
conn: conn,
encoder: json.NewEncoder(conn),
decoder: json.NewDecoder(conn),
2017-04-14 19:26:32 +02:00
}
// Assert protocol versions are equal
2017-04-16 21:38:31 +02:00
err := rpc.ProtocolVersionRequest()
if err != nil {
return nil, err
}
return rpc, nil
2017-04-14 19:26:32 +02:00
}
func ListenByteStreamRPC(conn io.ReadWriteCloser, handler RPCHandler) error {
// A request consists of two subsequent JSON objects
// Object 1: RequestHeader => contains type of Request Body
// Object 2: RequestBody, e.g. IncrementalTransferRequest
// A response is always a ResponseHeader followed by bytes to be interpreted
// as indicated by the ResponseHeader.ResponseType, e.g.
// a) a chunked response
// b) or another JSON object
defer conn.Close()
decoder := json.NewDecoder(conn)
encoder := json.NewEncoder(conn)
for {
var header RequestHeader = RequestHeader{}
if err := decoder.Decode(&header); err != nil {
2017-04-30 17:58:39 +02:00
respondWithError(encoder, EDecodeHeader, err)
return conn.Close()
}
switch header.Type {
case RTProtocolVersionRequest:
var rq ByteStreamRPCProtocolVersionRequest
if err := decoder.Decode(&rq); err != nil {
2017-04-30 17:58:39 +02:00
respondWithError(encoder, EDecodeRequestBody, nil)
return conn.Close()
}
if rq.ClientVersion != ByteStreamRPCProtocolVersion {
2017-04-30 17:58:39 +02:00
respondWithError(encoder, EProtocolVersionMismatch, nil)
return conn.Close()
}
r := ResponseHeader{
RequestId: header.Id,
ResponseType: ROK,
}
if err := encoder.Encode(&r); err != nil {
panic(err)
}
case RTFilesystemRequest:
var rq FilesystemRequest
if err := decoder.Decode(&rq); err != nil {
2017-04-30 17:58:39 +02:00
respondWithError(encoder, EDecodeRequestBody, nil)
return conn.Close()
}
roots, err := handler.HandleFilesystemRequest(rq)
if err != nil {
2017-04-30 17:58:39 +02:00
respondWithError(encoder, EHandler, err)
return conn.Close()
} else {
r := ResponseHeader{
RequestId: header.Id,
ResponseType: RFilesystems,
}
if err := encoder.Encode(&r); err != nil {
panic(err)
}
if err := encoder.Encode(&roots); err != nil {
panic(err)
}
}
case RTInitialTransferRequest:
var rq InitialTransferRequest
if err := decoder.Decode(&rq); err != nil {
2017-04-30 17:58:39 +02:00
respondWithError(encoder, EDecodeRequestBody, nil)
return conn.Close()
}
snapReader, err := handler.HandleInitialTransferRequest(rq)
if err != nil {
2017-04-30 17:58:39 +02:00
respondWithError(encoder, EHandler, err)
return conn.Close()
} else {
chunker := NewChunker(snapReader)
_, err := io.Copy(conn, &chunker)
if err != nil {
panic(err)
}
}
2017-04-16 21:38:31 +02:00
case RTIncrementalTransferRequest:
var rq IncrementalTransferRequest
if err := decoder.Decode(&rq); err != nil {
2017-04-30 17:58:39 +02:00
respondWithError(encoder, EDecodeRequestBody, nil)
return conn.Close()
2017-04-16 21:38:31 +02:00
}
snapReader, err := handler.HandleIncrementalTransferRequest(rq)
if err != nil {
2017-04-30 17:58:39 +02:00
respondWithError(encoder, EHandler, err)
2017-04-16 21:38:31 +02:00
} else {
chunker := NewChunker(snapReader)
_, err := io.Copy(conn, &chunker)
if err != nil {
panic(err)
2017-04-16 21:38:31 +02:00
}
}
default:
2017-04-30 17:58:39 +02:00
respondWithError(encoder, EUnknownRequestType, nil)
return conn.Close()
}
}
2017-04-14 19:26:32 +02:00
return nil
}
2017-04-30 17:58:39 +02:00
func respondWithError(encoder *json.Encoder, id ErrorId, err error) {
r := ResponseHeader{
ErrorId: id,
ResponseType: RNONE,
Message: err.Error(),
}
if err := encoder.Encode(&r); err != nil {
panic(err)
}
2017-04-14 19:26:32 +02:00
}
2017-04-16 21:38:31 +02:00
func inferRequestType(v interface{}) (RequestType, error) {
switch v.(type) {
case ByteStreamRPCProtocolVersionRequest:
2017-04-26 20:25:53 +02:00
return RTProtocolVersionRequest, nil
2017-04-16 21:38:31 +02:00
case FilesystemRequest:
2017-04-26 20:25:53 +02:00
return RTFilesystemRequest, nil
2017-04-16 21:38:31 +02:00
case InitialTransferRequest:
2017-04-26 20:25:53 +02:00
return RTInitialTransferRequest, nil
2017-04-16 21:38:31 +02:00
default:
return 0, errors.New(fmt.Sprintf("cannot infer request type for type '%v'",
2017-04-26 20:25:53 +02:00
reflect.TypeOf(v)))
2017-04-16 21:38:31 +02:00
}
}
// genUUID returns the id placed in every outgoing RequestHeader.
//
// TODO: generate real (e.g. random) UUIDs; for now every id is all-zero.
func genUUID() [16]byte {
	var id [16]byte
	return id
}
func (c ByteStreamRPC) sendRequest(v interface{}) (err error) {
var rt RequestType
if rt, err = inferRequestType(v); err != nil {
return
}
h := RequestHeader{
Type: rt,
2017-04-26 20:25:53 +02:00
Id: genUUID(),
2017-04-16 21:38:31 +02:00
}
if err = c.encoder.Encode(h); err != nil {
return
}
if err = c.encoder.Encode(v); err != nil {
return
}
return
}
// expectResponseType decodes the next ResponseHeader from the connection and
// fails unless its ResponseType equals rt.
//
// When the header does not match and carries a message (as error responses
// built by respondWithError do), the returned error includes that error id and
// message. Previously the server's explanation was discarded.
func (c ByteStreamRPC) expectResponseType(rt ResponseType) (err error) {
	var h ResponseHeader
	if err = c.decoder.Decode(&h); err != nil {
		return err
	}
	if h.ResponseType != rt {
		if h.Message != "" {
			return fmt.Errorf("unexpected response type in response header (got %v, expected %v): error id %v: %s",
				h.ResponseType, rt, h.ErrorId, h.Message)
		}
		return errors.New("unexpected response type in response header")
	}
	return nil
}
// sendRequestReceiveHeader sends request, then reads the response header and
// checks that its ResponseType is rt. Any payload that follows the header is
// left on the connection for the caller to read.
func (c ByteStreamRPC) sendRequestReceiveHeader(request interface{}, rt ResponseType) (err error) {
	err = c.sendRequest(request)
	if err == nil {
		err = c.expectResponseType(rt)
	}
	return err
}
// ProtocolVersionRequest announces this client's ByteStreamRPCProtocolVersion
// to the server. A nil result means the server answered with an ROK header,
// i.e. it accepts this protocol version.
func (c ByteStreamRPC) ProtocolVersionRequest() (err error) {
	return c.sendRequestReceiveHeader(
		ByteStreamRPCProtocolVersionRequest{ClientVersion: ByteStreamRPCProtocolVersion},
		ROK,
	)
}
func (c ByteStreamRPC) FilesystemRequest(r FilesystemRequest) (roots []Filesystem, err error) {
2017-04-16 21:38:31 +02:00
if err = c.sendRequestReceiveHeader(r, RFilesystems); err != nil {
return
}
roots = make([]Filesystem, 0)
if err = c.decoder.Decode(&roots); err != nil {
2017-04-16 21:38:31 +02:00
return
}
return
2017-04-14 19:26:32 +02:00
}
2017-04-16 21:38:31 +02:00
func (c ByteStreamRPC) InitialTransferRequest(r InitialTransferRequest) (unchunker io.Reader, err error) {
if err = c.sendRequestReceiveHeader(r, RChunkedStream); err != nil {
return
}
unchunker = NewUnchunker(c.conn)
return
2017-04-14 19:26:32 +02:00
}
2017-04-16 21:38:31 +02:00
// IncrementalTransferRequest sends r and expects an RChunkedStream response
// header. On success it returns an Unchunker that reads the chunked stream
// from the connection.
//
// NOTE(review): the same json.Decoder read-ahead concern applies here. Bytes
// buffered by c.decoder after the header are invisible to an Unchunker that
// reads c.conn directly.
func (c ByteStreamRPC) IncrementalTransferRequest(r IncrementalTransferRequest) (unchunker io.Reader, err error) {
	if err := c.sendRequestReceiveHeader(r, RChunkedStream); err != nil {
		return nil, err
	}
	return NewUnchunker(c.conn), nil
}
2017-04-14 19:26:32 +02:00
// LocalRPC is an RPCRequester that forwards each request directly to an
// in-process RPCHandler, with no serialization involved.
type LocalRPC struct {
	handler RPCHandler // receives every request made through this LocalRPC
}
2017-04-16 21:38:31 +02:00
func ConnectLocalRPC(handler RPCHandler) RPCRequester {
2017-04-14 19:26:32 +02:00
return LocalRPC{handler}
}
2017-04-16 21:38:31 +02:00
func (c LocalRPC) FilesystemRequest(r FilesystemRequest) (roots []Filesystem, err error) {
2017-04-14 19:26:32 +02:00
return c.handler.HandleFilesystemRequest(r)
}
func (c LocalRPC) InitialTransferRequest(r InitialTransferRequest) (io.Reader, error) {
2017-04-14 19:26:32 +02:00
return c.handler.HandleInitialTransferRequest(r)
}
func (c LocalRPC) IncrementalTransferRequest(r IncrementalTransferRequest) (reader io.Reader, err error) {
reader, err = c.handler.HandleIncrementalTransferRequest(r)
return
2017-04-14 19:26:32 +02:00
}