rpc: chunk JSON parts of communication + refactoring

JSONDecoder was buffering more connection data than just the JSON it decoded.
=> The Unchunker was unaware of that buffered data and just started unchunking
   from the connection, missing the bytes the decoder had already consumed.

Chaining JSONDecoder.Buffered() and the connection using ChainedReader works,
but it is still not a clean architecture.

=> Every JSON message is now wrapped in its own chunked stream
   (chunked on the sending side, unchunked on the receiving side)
   => no special cases
=> Keep ChainedReader, it might be useful later on...
Christian Schwarz 2017-05-12 20:39:11 +02:00
parent feabf1abcd
commit 74719ad846
4 changed files with 181 additions and 67 deletions
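
The core of this change is that every JSON object now travels inside its own chunked stream, so the JSON decoder can never buffer bytes that belong to the following message. The Chunker/Unchunker wire format itself is not part of this diff; the sketch below illustrates the idea with a simple length-prefix frame standing in for it (writeFramedJSON/readFramedJSON are hypothetical names, not the helpers added by this commit).

    package main

    import (
        "bytes"
        "encoding/binary"
        "encoding/json"
        "fmt"
        "io"
    )

    // writeFramedJSON encodes v and writes it as one length-prefixed frame.
    // The length prefix stands in for the chunk headers of the real Chunker.
    func writeFramedJSON(w io.Writer, v interface{}) error {
        var buf bytes.Buffer
        if err := json.NewEncoder(&buf).Encode(v); err != nil {
            return err
        }
        if err := binary.Write(w, binary.BigEndian, uint32(buf.Len())); err != nil {
            return err
        }
        _, err := io.Copy(w, &buf)
        return err
    }

    // readFramedJSON reads exactly one frame and decodes the JSON inside it.
    // The decoder is limited to the frame, so it cannot consume bytes of the
    // next message, which was the original problem with a bare json.Decoder.
    func readFramedJSON(r io.Reader, v interface{}) error {
        var n uint32
        if err := binary.Read(r, binary.BigEndian, &n); err != nil {
            return err
        }
        frame := io.LimitReader(r, int64(n))
        if err := json.NewDecoder(frame).Decode(v); err != nil {
            return err
        }
        // Drain whatever the decoder left unread so r sits at the next frame.
        _, err := io.Copy(io.Discard, frame)
        return err
    }

    func main() {
        type header struct{ Type string }
        type body struct{ Filesystem string }

        var wire bytes.Buffer
        writeFramedJSON(&wire, header{Type: "FilesystemRequest"})
        writeFramedJSON(&wire, body{Filesystem: "zroot/var"})

        var h header
        var b body
        readFramedJSON(&wire, &h)
        readFramedJSON(&wire, &b)
        fmt.Println(h, b) // {FilesystemRequest} {zroot/var}
    }

The writeChunkedJSON/readChunkedJSON helpers introduced in the diff below do the equivalent with NewChunker/NewUnchunker from this repository.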


@@ -1,6 +1,7 @@
 package rpc
 
 import (
+    "bytes"
     "encoding/json"
     "errors"
     "fmt"
@@ -33,19 +34,19 @@ type Logger interface {
 const ByteStreamRPCProtocolVersion = 1
 
+type ByteStream interface {
+    io.ReadWriteCloser
+}
+
 type ByteStreamRPC struct {
-    conn io.ReadWriteCloser
-    encoder *json.Encoder
-    decoder *json.Decoder
+    conn ByteStream
     log Logger
 }
 
-func ConnectByteStreamRPC(conn io.ReadWriteCloser) (RPCRequester, error) {
+func ConnectByteStreamRPC(conn ByteStream) (RPCRequester, error) {
     // TODO do ssh connection to transport, establish TCP-like communication channel
     rpc := ByteStreamRPC{
-        conn: conn,
-        encoder: json.NewEncoder(conn),
-        decoder: json.NewDecoder(conn),
+        conn: conn,
     }
 
     // Assert protocol versions are equal
@@ -57,9 +58,18 @@ func ConnectByteStreamRPC(conn io.ReadWriteCloser) (RPCRequester, error) {
     return rpc, nil
 }
 
-func ListenByteStreamRPC(conn io.ReadWriteCloser, handler RPCHandler, log Logger) error {
+type ByteStreamRPCDecodeJSONError struct {
+    Type reflect.Type
+    DecoderErr error
+}
+
+func (e ByteStreamRPCDecodeJSONError) Error() string {
+    return fmt.Sprintf("cannot decode %s: %s", e.Type, e.DecoderErr)
+}
+
+func ListenByteStreamRPC(conn ByteStream, handler RPCHandler, log Logger) error {
 
-    // A request consists of two subsequent JSON objects
+    // A request consists of two subsequent chunked JSON objects
     // Object 1: RequestHeader => contains type of Request Body
     // Object 2: RequestBody, e.g. IncrementalTransferRequest
     // A response is always a ResponseHeader followed by bytes to be interpreted
@@ -69,27 +79,47 @@ func ListenByteStreamRPC(conn io.ReadWriteCloser, handler RPCHandler, log Logger
     defer conn.Close()
 
-    decoder := json.NewDecoder(conn)
-    encoder := json.NewEncoder(conn)
+    send := func(r interface{}) {
+        if err := writeChunkedJSON(conn, r); err != nil {
+            panic(err)
+        }
+    }
+
+    sendError := func(id ErrorId, msg string) {
+        r := ResponseHeader{
+            ErrorId: id,
+            ResponseType: RNONE,
+            Message: msg,
+        }
+        log.Printf("sending error response: %#v", r)
+        if err := writeChunkedJSON(conn, r); err != nil {
+            log.Printf("error sending error response: %#v", err)
+            panic(err)
+        }
+    }
+
+    recv := func(r interface{}) (err error) {
+        return readChunkedJSON(conn, r)
+    }
 
     for {
 
        var header RequestHeader = RequestHeader{}
-        if err := decoder.Decode(&header); err != nil {
-            respondWithError(encoder, EDecodeHeader, err)
+        if err := recv(&header); err != nil {
+            sendError(EDecodeHeader, err.Error())
             return conn.Close()
        }
 
        switch header.Type {
        case RTProtocolVersionRequest:
            var rq ByteStreamRPCProtocolVersionRequest
-            if err := decoder.Decode(&rq); err != nil {
-                respondWithError(encoder, EDecodeRequestBody, nil)
+            if err := recv(&rq); err != nil {
+                sendError(EDecodeRequestBody, err.Error())
                return conn.Close()
            }
 
            if rq.ClientVersion != ByteStreamRPCProtocolVersion {
-                respondWithError(encoder, EProtocolVersionMismatch, nil)
+                sendError(EProtocolVersionMismatch, "")
                return conn.Close()
            }
@@ -97,70 +127,61 @@ func ListenByteStreamRPC(conn io.ReadWriteCloser, handler RPCHandler, log Logger
                RequestId: header.Id,
                ResponseType: ROK,
            }
-            if err := encoder.Encode(&r); err != nil {
-                panic(err)
-            }
+            send(&r)
 
        case RTFilesystemRequest:
 
            var rq FilesystemRequest
-            if err := decoder.Decode(&rq); err != nil {
-                respondWithError(encoder, EDecodeRequestBody, nil)
+            if err := recv(&rq); err != nil {
+                sendError(EDecodeRequestBody, "")
                return conn.Close()
            }
 
            roots, err := handler.HandleFilesystemRequest(rq)
            if err != nil {
-                respondWithError(encoder, EHandler, err)
+                sendError(EHandler, err.Error())
                return conn.Close()
            } else {
                r := ResponseHeader{
                    RequestId: header.Id,
                    ResponseType: RFilesystems,
                }
-                if err := encoder.Encode(&r); err != nil {
-                    panic(err)
-                }
-                if err := encoder.Encode(&roots); err != nil {
-                    panic(err)
-                }
+                send(&r)
+                send(&roots)
            }
 
        case RTFilesystemVersionsRequest:
 
            var rq FilesystemVersionsRequest
-            if err := decoder.Decode(&rq); err != nil {
-                respondWithError(encoder, EDecodeRequestBody, err)
+            if err := recv(&rq); err != nil {
+                sendError(EDecodeRequestBody, err.Error())
                return err
            }
 
            diff, err := handler.HandleFilesystemVersionsRequest(rq)
            if err != nil {
-                respondWithError(encoder, EHandler, err)
+                sendError(EHandler, err.Error())
                return err
            } else {
                r := ResponseHeader{
                    RequestId: header.Id,
                    ResponseType: RFilesystemDiff,
                }
-                if err := encoder.Encode(&r); err != nil {
-                    panic(err)
-                }
-                if err := encoder.Encode(&diff); err != nil {
-                    panic(err)
-                }
+                send(&r)
+                send(&diff)
            }
 
        case RTInitialTransferRequest:
            var rq InitialTransferRequest
-            if err := decoder.Decode(&rq); err != nil {
-                respondWithError(encoder, EDecodeRequestBody, nil)
+            if err := recv(&rq); err != nil {
                sendError(EDecodeRequestBody, "")
                return conn.Close()
            }
+            log.Printf("initial transfer request: %#v", rq)
 
            snapReader, err := handler.HandleInitialTransferRequest(rq)
            if err != nil {
-                respondWithError(encoder, EHandler, err)
+                sendError(EHandler, err.Error())
                return conn.Close()
            } else {
@@ -168,9 +189,7 @@ func ListenByteStreamRPC(conn io.ReadWriteCloser, handler RPCHandler, log Logger
                    RequestId: header.Id,
                    ResponseType: RChunkedStream,
                }
-                if err := encoder.Encode(&r); err != nil {
-                    panic(err)
-                }
+                send(&r)
 
                chunker := NewChunker(snapReader)
                _, err := io.Copy(conn, &chunker)
@@ -182,23 +201,21 @@ func ListenByteStreamRPC(conn io.ReadWriteCloser, handler RPCHandler, log Logger
        case RTIncrementalTransferRequest:
 
            var rq IncrementalTransferRequest
-            if err := decoder.Decode(&rq); err != nil {
-                respondWithError(encoder, EDecodeRequestBody, nil)
+            if err := recv(&rq); err != nil {
+                sendError(EDecodeRequestBody, "")
                return conn.Close()
            }
 
            snapReader, err := handler.HandleIncrementalTransferRequest(rq)
            if err != nil {
-                respondWithError(encoder, EHandler, err)
+                sendError(EHandler, err.Error())
            } else {
                r := ResponseHeader{
                    RequestId: header.Id,
                    ResponseType: RChunkedStream,
                }
-                if err := encoder.Encode(&r); err != nil {
-                    panic(err)
-                }
+                send(&r)
 
                chunker := NewChunker(snapReader)
                _, err := io.Copy(conn, &chunker)
@@ -208,7 +225,7 @@ func ListenByteStreamRPC(conn io.ReadWriteCloser, handler RPCHandler, log Logger
            }
 
        default:
-            respondWithError(encoder, EUnknownRequestType, nil)
+            sendError(EUnknownRequestType, "")
            return conn.Close()
        }
    }
@@ -216,17 +233,30 @@ func ListenByteStreamRPC(conn io.ReadWriteCloser, handler RPCHandler, log Logger
    return nil
 }
 
-func respondWithError(encoder *json.Encoder, id ErrorId, err error) {
-
-    r := ResponseHeader{
-        ErrorId: id,
-        ResponseType: RNONE,
-        Message: err.Error(),
-    }
-    if err := encoder.Encode(&r); err != nil {
-        panic(err)
-    }
+func writeChunkedJSON(conn io.Writer, r interface{}) (err error) {
+    var buf bytes.Buffer
+    encoder := json.NewEncoder(&buf)
+    encoder.Encode(r)
+    ch := NewChunker(&buf)
+    _, err = io.Copy(conn, &ch)
+    return
+}
+
+func readChunkedJSON(conn io.ReadWriter, r interface{}) (err error) {
+    unch := NewUnchunker(conn)
+    dec := json.NewDecoder(unch)
+    err = dec.Decode(r)
+    if err != nil {
+        err = ByteStreamRPCDecodeJSONError{
+            Type: reflect.TypeOf(r),
+            DecoderErr: err,
+        }
+    }
+    closeErr := unch.Close()
+    if err == nil && closeErr != nil {
+        err = closeErr
+    }
+    return
 }
 
 func inferRequestType(v interface{}) (RequestType, error) {
@@ -264,10 +294,10 @@ func (c ByteStreamRPC) sendRequest(v interface{}) (err error) {
        Id: genUUID(),
    }
 
-    if err = c.encoder.Encode(h); err != nil {
+    if err = writeChunkedJSON(c.conn, h); err != nil {
        return
    }
-    if err = c.encoder.Encode(v); err != nil {
+    if err = writeChunkedJSON(c.conn, v); err != nil {
        return
    }
@@ -275,8 +305,9 @@ func (c ByteStreamRPC) sendRequest(v interface{}) (err error) {
 }
 
 func (c ByteStreamRPC) expectResponseType(rt ResponseType) (err error) {
    var h ResponseHeader
-    if err = c.decoder.Decode(&h); err != nil {
+    if err = readChunkedJSON(c.conn, &h); err != nil {
        return
    }
@@ -317,7 +348,7 @@ func (c ByteStreamRPC) FilesystemRequest(r FilesystemRequest) (roots []zfs.Datas
    roots = make([]zfs.DatasetPath, 0)
-    if err = c.decoder.Decode(&roots); err != nil {
+    if err = readChunkedJSON(c.conn, &roots); err != nil {
        return
    }
@@ -330,7 +361,7 @@ func (c ByteStreamRPC) FilesystemVersionsRequest(r FilesystemVersionsRequest) (v
        return
    }
 
-    err = c.decoder.Decode(&versions)
+    err = readChunkedJSON(c.conn, &versions)
    return
 }
@@ -339,6 +370,7 @@ func (c ByteStreamRPC) InitialTransferRequest(r InitialTransferRequest) (unchunk
    if err = c.sendRequestReceiveHeader(r, RChunkedStream); err != nil {
        return
    }
+
    unchunker = NewUnchunker(c.conn)
    return
 }
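
One detail of the listener above worth calling out: the per-connection helpers are captured once in send/sendError/recv closures, so each case of the request switch shrinks to a few lines instead of repeating encoder/decoder plumbing. A minimal, self-contained sketch of that pattern follows; writeMsg/readMsg are hypothetical stand-ins for writeChunkedJSON/readChunkedJSON, and the header types are reduced to the bare minimum.

    package main

    import (
        "bytes"
        "encoding/json"
        "fmt"
        "io"
    )

    // Hypothetical stand-ins for writeChunkedJSON/readChunkedJSON.
    func writeMsg(w io.Writer, v interface{}) error {
        return json.NewEncoder(w).Encode(v)
    }

    func readMsg(r io.Reader, v interface{}) error {
        return json.NewDecoder(r).Decode(v)
    }

    type requestHeader struct{ Type string }

    type responseHeader struct {
        ErrorId int
        Message string
    }

    func serveOne(conn io.ReadWriter) {
        send := func(v interface{}) {
            if err := writeMsg(conn, v); err != nil {
                panic(err) // as in the diff: a failed send aborts the handler
            }
        }
        sendError := func(id int, msg string) {
            send(responseHeader{ErrorId: id, Message: msg})
        }
        recv := func(v interface{}) error { return readMsg(conn, v) }

        var h requestHeader
        if err := recv(&h); err != nil {
            sendError(1, err.Error())
            return
        }
        switch h.Type {
        case "ProtocolVersionRequest":
            send(responseHeader{Message: "OK"})
        default:
            sendError(2, "unknown request type")
        }
    }

    func main() {
        conn := &bytes.Buffer{}
        writeMsg(conn, requestHeader{Type: "ProtocolVersionRequest"})
        serveOne(conn)
        fmt.Print(conn.String()) // {"ErrorId":0,"Message":"OK"}
    }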

rpc/rpc_test.go (new file, 33 lines)

@@ -0,0 +1,33 @@
+package rpc
+
+import (
+    "bytes"
+    "github.com/stretchr/testify/assert"
+    "github.com/zrepl/zrepl/util"
+    "io"
+    "strings"
+    "testing"
+)
+
+func TestByteStreamRPCDecodeJSONError(t *testing.T) {
+
+    r := strings.NewReader("{'a':'aber'}")
+
+    var chunked bytes.Buffer
+    ch := util.NewChunker(r)
+    io.Copy(&chunked, &ch)
+
+    type SampleType struct {
+        A uint
+    }
+    var s SampleType
+    err := readChunkedJSON(&chunked, &s)
+    assert.NotNil(t, err)
+
+    _, ok := err.(ByteStreamRPCDecodeJSONError)
+    if !ok {
+        t.Errorf("expected ByteStreamRPCDecodeJSONError, got %t\n", err)
+        t.Errorf("%s\n", err)
+    }
+}


@@ -26,7 +26,7 @@ func NewUnchunker(conn io.Reader) *Unchunker {
 
 func (c *Unchunker) Read(b []byte) (n int, err error) {
    if c.finishErr != nil {
-        return 0, err
+        return 0, c.finishErr
    }
 
    if c.remainingChunkBytes == 0 {
@@ -68,6 +68,19 @@ func (c *Unchunker) Read(b []byte) (n int, err error) {
 }
 
+func (c *Unchunker) Close() (err error) {
+    buf := make([]byte, 4096)
+    for err == nil {
+        _, err = c.Read(buf)
+        if err == io.EOF {
+            err = nil
+            break
+        }
+    }
+    return
+}
+
 func min(a, b int) int {
    if a < b {
        return a

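The new Unchunker.Close above drains the rest of the current chunk stream, so the underlying connection ends up positioned at the start of the next message even when the JSON decoder stopped reading early. The loop is effectively a drain-until-EOF; a standard-library equivalent is sketched below purely for illustration, it is not what the commit uses.

    package main

    import (
        "fmt"
        "io"
        "strings"
    )

    // drainUntilEOF consumes everything left in r, mirroring what
    // Unchunker.Close achieves with its own read loop.
    func drainUntilEOF(r io.Reader) error {
        _, err := io.Copy(io.Discard, r)
        return err
    }

    func main() {
        leftover := strings.NewReader("bytes the JSON decoder did not consume")
        fmt.Println(drainUntilEOF(leftover)) // <nil>
    }
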
util/io.go (new file, 36 lines)

@@ -0,0 +1,36 @@
+package util
+
+import (
+    "io"
+)
+
+type ChainedReader struct {
+    Readers []io.Reader
+    curReader int
+}
+
+func NewChainedReader(reader ...io.Reader) *ChainedReader {
+    return &ChainedReader{
+        Readers: reader,
+        curReader: 0,
+    }
+}
+
+func (c *ChainedReader) Read(buf []byte) (n int, err error) {
+
+    n = 0
+
+    for c.curReader < len(c.Readers) {
+        n, err = c.Readers[c.curReader].Read(buf)
+        if err == io.EOF {
+            c.curReader++
+            continue
+        }
+        break
+    }
+
+    if c.curReader == len(c.Readers) {
+        err = io.EOF // actually, there was no gap
+    }
+
+    return
+}
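
ChainedReader stays in the tree even though per-message chunking removed its immediate use. It concatenates readers much like the standard library's io.MultiReader, and the workaround described in the commit message (put json.Decoder.Buffered() back in front of the connection) would look roughly like the sketch below, which uses io.MultiReader as a stand-in for ChainedReader and a made-up header/payload pair.

    package main

    import (
        "bytes"
        "encoding/json"
        "fmt"
        "io"
    )

    func main() {
        // Two back-to-back messages on the "connection": a JSON header
        // followed by raw payload bytes.
        conn := bytes.NewBufferString(`{"type":"chunked"}RAW-PAYLOAD`)

        var hdr struct{ Type string }
        dec := json.NewDecoder(conn)
        dec.Decode(&hdr)

        // The decoder read ahead: dec.Buffered() holds the bytes it pulled
        // from conn beyond the JSON object. Chaining it back in front of conn
        // (the repo's ChainedReader would serve the same purpose as the
        // stdlib's io.MultiReader used here) recovers the remaining stream.
        rest := io.MultiReader(dec.Buffered(), conn)
        payload, _ := io.ReadAll(rest)
        fmt.Println(hdr.Type, string(payload)) // chunked RAW-PAYLOAD
    }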