rpc: chunk JSON parts of communication + refactoring

json.Decoder was buffering more of the connection's data than just the JSON it decoded.
=> The Unchunker, reading from the connection afterwards, never saw the bytes
   already sitting in the decoder's buffer and started unchunking mid-stream.

While chaining the decoder's Buffered() reader and the connection using
ChainedReader works, it is still not a clean architecture.

=> Every JSON message is now wrapped in a chunked stream
   (chunked on send, unchunked on receive)
   => no special cases
=> Keep ChainedReader, it might be useful later on...
Christian Schwarz 2017-05-12 20:39:11 +02:00
parent feabf1abcd
commit 74719ad846
4 changed files with 181 additions and 67 deletions
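
The fix the message describes can be pictured with a short sketch (not part of the commit): one JSON message framed as a chunked stream and read back, mirroring the new writeChunkedJSON/readChunkedJSON helpers in the diff below. The import path and the assumption that both NewChunker and NewUnchunker are exported by the util package are taken from the new test file; names here are illustrative only.

```go
// Sketch only: round-trip of one JSON message through the chunked framing.
// Assumes util exports NewChunker and NewUnchunker as used in the diff below;
// the local header type is illustrative, not the real RequestHeader.
package main

import (
	"bytes"
	"encoding/json"
	"fmt"
	"io"

	"github.com/zrepl/zrepl/util"
)

type exampleHeader struct {
	Type string
	Id   string
}

func main() {
	wire := &bytes.Buffer{} // stands in for the ByteStream connection

	// Send side (cf. writeChunkedJSON): encode into a buffer first,
	// then copy it onto the wire through a Chunker.
	var jsonBuf bytes.Buffer
	json.NewEncoder(&jsonBuf).Encode(exampleHeader{Type: "ProtocolVersionRequest", Id: "1"})
	ch := util.NewChunker(&jsonBuf)
	io.Copy(wire, &ch)

	// Receive side (cf. readChunkedJSON): decode through an Unchunker,
	// then Close() it to drain the rest of the chunked stream so the
	// next message starts at a chunk boundary.
	unch := util.NewUnchunker(wire)
	var h exampleHeader
	json.NewDecoder(unch).Decode(&h)
	unch.Close()

	fmt.Printf("%+v\n", h)
}
```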


@@ -1,6 +1,7 @@
package rpc
import (
"bytes"
"encoding/json"
"errors"
"fmt"
@@ -33,19 +34,19 @@ type Logger interface {
const ByteStreamRPCProtocolVersion = 1
type ByteStream interface {
io.ReadWriteCloser
}
type ByteStreamRPC struct {
conn io.ReadWriteCloser
encoder *json.Encoder
decoder *json.Decoder
conn ByteStream
log Logger
}
func ConnectByteStreamRPC(conn io.ReadWriteCloser) (RPCRequester, error) {
// TODO do ssh connection to transport, establish TCP-like communication channel
func ConnectByteStreamRPC(conn ByteStream) (RPCRequester, error) {
rpc := ByteStreamRPC{
conn: conn,
encoder: json.NewEncoder(conn),
decoder: json.NewDecoder(conn),
}
// Assert protocol versions are equal
@@ -57,9 +58,18 @@ func ConnectByteStreamRPC(conn io.ReadWriteCloser) (RPCRequester, error) {
return rpc, nil
}
func ListenByteStreamRPC(conn io.ReadWriteCloser, handler RPCHandler, log Logger) error {
type ByteStreamRPCDecodeJSONError struct {
Type reflect.Type
DecoderErr error
}
// A request consists of two subsequent JSON objects
func (e ByteStreamRPCDecodeJSONError) Error() string {
return fmt.Sprintf("cannot decode %s: %s", e.Type, e.DecoderErr)
}
func ListenByteStreamRPC(conn ByteStream, handler RPCHandler, log Logger) error {
// A request consists of two subsequent chunked JSON objects
// Object 1: RequestHeader => contains type of Request Body
// Object 2: RequestBody, e.g. IncrementalTransferRequest
// A response is always a ResponseHeader followed by bytes to be interpreted
@@ -69,27 +79,47 @@ func ListenByteStreamRPC(conn io.ReadWriteCloser, handler RPCHandler, log Logger
defer conn.Close()
decoder := json.NewDecoder(conn)
encoder := json.NewEncoder(conn)
send := func(r interface{}) {
if err := writeChunkedJSON(conn, r); err != nil {
panic(err)
}
}
sendError := func(id ErrorId, msg string) {
r := ResponseHeader{
ErrorId: id,
ResponseType: RNONE,
Message: msg,
}
log.Printf("sending error response: %#v", r)
if err := writeChunkedJSON(conn, r); err != nil {
log.Printf("error sending error response: %#v", err)
panic(err)
}
}
recv := func(r interface{}) (err error) {
return readChunkedJSON(conn, r)
}
for {
var header RequestHeader = RequestHeader{}
if err := decoder.Decode(&header); err != nil {
respondWithError(encoder, EDecodeHeader, err)
if err := recv(&header); err != nil {
sendError(EDecodeHeader, err.Error())
return conn.Close()
}
switch header.Type {
case RTProtocolVersionRequest:
var rq ByteStreamRPCProtocolVersionRequest
if err := decoder.Decode(&rq); err != nil {
respondWithError(encoder, EDecodeRequestBody, nil)
if err := recv(&rq); err != nil {
sendError(EDecodeRequestBody, err.Error())
return conn.Close()
}
if rq.ClientVersion != ByteStreamRPCProtocolVersion {
respondWithError(encoder, EProtocolVersionMismatch, nil)
sendError(EProtocolVersionMismatch, "")
return conn.Close()
}
@@ -97,70 +127,61 @@ func ListenByteStreamRPC(conn io.ReadWriteCloser, handler RPCHandler, log Logger
RequestId: header.Id,
ResponseType: ROK,
}
if err := encoder.Encode(&r); err != nil {
panic(err)
}
send(&r)
case RTFilesystemRequest:
var rq FilesystemRequest
if err := decoder.Decode(&rq); err != nil {
respondWithError(encoder, EDecodeRequestBody, nil)
if err := recv(&rq); err != nil {
sendError(EDecodeRequestBody, "")
return conn.Close()
}
roots, err := handler.HandleFilesystemRequest(rq)
if err != nil {
respondWithError(encoder, EHandler, err)
sendError(EHandler, err.Error())
return conn.Close()
} else {
r := ResponseHeader{
RequestId: header.Id,
ResponseType: RFilesystems,
}
if err := encoder.Encode(&r); err != nil {
panic(err)
}
if err := encoder.Encode(&roots); err != nil {
panic(err)
}
send(&r)
send(&roots)
}
case RTFilesystemVersionsRequest:
var rq FilesystemVersionsRequest
if err := decoder.Decode(&rq); err != nil {
respondWithError(encoder, EDecodeRequestBody, err)
if err := recv(&rq); err != nil {
sendError(EDecodeRequestBody, err.Error())
return err
}
diff, err := handler.HandleFilesystemVersionsRequest(rq)
if err != nil {
respondWithError(encoder, EHandler, err)
sendError(EHandler, err.Error())
return err
} else {
r := ResponseHeader{
RequestId: header.Id,
ResponseType: RFilesystemDiff,
}
if err := encoder.Encode(&r); err != nil {
panic(err)
}
if err := encoder.Encode(&diff); err != nil {
panic(err)
}
send(&r)
send(&diff)
}
case RTInitialTransferRequest:
var rq InitialTransferRequest
if err := decoder.Decode(&rq); err != nil {
respondWithError(encoder, EDecodeRequestBody, nil)
if err := recv(&rq); err != nil {
sendError(EDecodeRequestBody, "")
return conn.Close()
}
log.Printf("initial transfer request: %#v", rq)
snapReader, err := handler.HandleInitialTransferRequest(rq)
if err != nil {
respondWithError(encoder, EHandler, err)
sendError(EHandler, err.Error())
return conn.Close()
} else {
@@ -168,9 +189,7 @@ func ListenByteStreamRPC(conn io.ReadWriteCloser, handler RPCHandler, log Logger
RequestId: header.Id,
ResponseType: RChunkedStream,
}
if err := encoder.Encode(&r); err != nil {
panic(err)
}
send(&r)
chunker := NewChunker(snapReader)
_, err := io.Copy(conn, &chunker)
@@ -182,23 +201,21 @@ func ListenByteStreamRPC(conn io.ReadWriteCloser, handler RPCHandler, log Logger
case RTIncrementalTransferRequest:
var rq IncrementalTransferRequest
if err := decoder.Decode(&rq); err != nil {
respondWithError(encoder, EDecodeRequestBody, nil)
if err := recv(&rq); err != nil {
sendError(EDecodeRequestBody, "")
return conn.Close()
}
snapReader, err := handler.HandleIncrementalTransferRequest(rq)
if err != nil {
respondWithError(encoder, EHandler, err)
sendError(EHandler, err.Error())
} else {
r := ResponseHeader{
RequestId: header.Id,
ResponseType: RChunkedStream,
}
if err := encoder.Encode(&r); err != nil {
panic(err)
}
send(&r)
chunker := NewChunker(snapReader)
_, err := io.Copy(conn, &chunker)
@@ -208,7 +225,7 @@ func ListenByteStreamRPC(conn io.ReadWriteCloser, handler RPCHandler, log Logger
}
default:
respondWithError(encoder, EUnknownRequestType, nil)
sendError(EUnknownRequestType, "")
return conn.Close()
}
}
@@ -216,17 +233,30 @@ func ListenByteStreamRPC(conn io.ReadWriteCloser, handler RPCHandler, log Logger
return nil
}
func respondWithError(encoder *json.Encoder, id ErrorId, err error) {
func writeChunkedJSON(conn io.Writer, r interface{}) (err error) {
var buf bytes.Buffer
encoder := json.NewEncoder(&buf)
encoder.Encode(r)
ch := NewChunker(&buf)
_, err = io.Copy(conn, &ch)
return
}
r := ResponseHeader{
ErrorId: id,
ResponseType: RNONE,
Message: err.Error(),
func readChunkedJSON(conn io.ReadWriter, r interface{}) (err error) {
unch := NewUnchunker(conn)
dec := json.NewDecoder(unch)
err = dec.Decode(r)
if err != nil {
err = ByteStreamRPCDecodeJSONError{
Type: reflect.TypeOf(r),
DecoderErr: err,
}
if err := encoder.Encode(&r); err != nil {
panic(err)
}
closeErr := unch.Close()
if err == nil && closeErr != nil {
err = closeErr
}
return
}
func inferRequestType(v interface{}) (RequestType, error) {
@@ -264,10 +294,10 @@ func (c ByteStreamRPC) sendRequest(v interface{}) (err error) {
Id: genUUID(),
}
if err = c.encoder.Encode(h); err != nil {
if err = writeChunkedJSON(c.conn, h); err != nil {
return
}
if err = c.encoder.Encode(v); err != nil {
if err = writeChunkedJSON(c.conn, v); err != nil {
return
}
@@ -275,8 +305,9 @@ func (c ByteStreamRPC) sendRequest(v interface{}) (err error) {
}
func (c ByteStreamRPC) expectResponseType(rt ResponseType) (err error) {
var h ResponseHeader
if err = c.decoder.Decode(&h); err != nil {
if err = readChunkedJSON(c.conn, &h); err != nil {
return
}
@@ -317,7 +348,7 @@ func (c ByteStreamRPC) FilesystemRequest(r FilesystemRequest) (roots []zfs.Datas
roots = make([]zfs.DatasetPath, 0)
if err = c.decoder.Decode(&roots); err != nil {
if err = readChunkedJSON(c.conn, &roots); err != nil {
return
}
@@ -330,7 +361,7 @@ func (c ByteStreamRPC) FilesystemVersionsRequest(r FilesystemVersionsRequest) (v
return
}
err = c.decoder.Decode(&versions)
err = readChunkedJSON(c.conn, &versions)
return
}
@@ -339,6 +370,7 @@ func (c ByteStreamRPC) InitialTransferRequest(r InitialTransferRequest) (unchunk
if err = c.sendRequestReceiveHeader(r, RChunkedStream); err != nil {
return
}
unchunker = NewUnchunker(c.conn)
return
}

rpc/rpc_test.go (new file, 33 additions)

@@ -0,0 +1,33 @@
package rpc
import (
"bytes"
"github.com/stretchr/testify/assert"
"github.com/zrepl/zrepl/util"
"io"
"strings"
"testing"
)
func TestByteStreamRPCDecodeJSONError(t *testing.T) {
r := strings.NewReader("{'a':'aber'}")
var chunked bytes.Buffer
ch := util.NewChunker(r)
io.Copy(&chunked, &ch)
type SampleType struct {
A uint
}
var s SampleType
err := readChunkedJSON(&chunked, &s)
assert.NotNil(t, err)
_, ok := err.(ByteStreamRPCDecodeJSONError)
if !ok {
t.Errorf("expected ByteStreamRPCDecodeJSONError, got %t\n", err)
t.Errorf("%s\n", err)
}
}


@@ -26,7 +26,7 @@ func NewUnchunker(conn io.Reader) *Unchunker {
func (c *Unchunker) Read(b []byte) (n int, err error) {
if c.finishErr != nil {
return 0, err
return 0, c.finishErr
}
if c.remainingChunkBytes == 0 {
@@ -68,6 +68,19 @@ func (c *Unchunker) Read(b []byte) (n int, err error) {
}
func (c *Unchunker) Close() (err error) {
buf := make([]byte, 4096)
for err == nil {
_, err = c.Read(buf)
if err == io.EOF {
err = nil
break
}
}
return
}
func min(a, b int) int {
if a < b {
return a

util/io.go (new file, 36 additions)

@@ -0,0 +1,36 @@
package util
import (
"io"
)
type ChainedReader struct {
Readers []io.Reader
curReader int
}
func NewChainedReader(reader ...io.Reader) *ChainedReader {
return &ChainedReader{
Readers: reader,
curReader: 0,
}
}
func (c *ChainedReader) Read(buf []byte) (n int, err error) {
n = 0
for c.curReader < len(c.Readers) {
n, err = c.Readers[c.curReader].Read(buf)
if err == io.EOF {
c.curReader++
continue
}
break
}
if c.curReader == len(c.Readers) {
err = io.EOF // actually, there was no gap
}
return
}
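
For context on the last bullet of the commit message, a brief usage sketch (not part of the commit) of the new ChainedReader: it drains each reader to EOF before moving on to the next, which is what the rejected decoder.Buffered()-plus-connection workaround would have relied on.

```go
// Sketch only: ChainedReader concatenates readers in order, the way the
// rejected workaround would have chained decoder.Buffered() with the conn.
package main

import (
	"fmt"
	"io/ioutil"
	"strings"

	"github.com/zrepl/zrepl/util"
)

func main() {
	buffered := strings.NewReader("bytes already pulled into the decoder's buffer, ")
	conn := strings.NewReader("followed by the rest still on the connection")

	r := util.NewChainedReader(buffered, conn)
	all, _ := ioutil.ReadAll(r) // reads both parts as one contiguous stream
	fmt.Println(string(all))
}
```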