mirror of
https://github.com/zrepl/zrepl.git
synced 2024-11-25 01:44:43 +01:00
rpc: chunk JSON parts of communication + refactoring
JSONDecoder was buffering more of connection data than just the JSON. => Unchunker didn't bother and just started unchunking. While chaining JSONDecoder.Buffered() and the connection using ChainedReader works, it's still not a clean architecture. => Every JSON message is now wrapped in a chunked stream (chunked and unchunked) => no special-cases => Keep ChainedReader, might be useful later on...
This commit is contained in:
parent
feabf1abcd
commit
74719ad846
164
rpc/rpc.go
164
rpc/rpc.go
@ -1,6 +1,7 @@
|
||||
package rpc
|
||||
|
||||
import (
|
||||
"bytes"
|
||||
"encoding/json"
|
||||
"errors"
|
||||
"fmt"
|
||||
@ -33,19 +34,19 @@ type Logger interface {
|
||||
|
||||
const ByteStreamRPCProtocolVersion = 1
|
||||
|
||||
type ByteStreamRPC struct {
|
||||
conn io.ReadWriteCloser
|
||||
encoder *json.Encoder
|
||||
decoder *json.Decoder
|
||||
log Logger
|
||||
type ByteStream interface {
|
||||
io.ReadWriteCloser
|
||||
}
|
||||
|
||||
func ConnectByteStreamRPC(conn io.ReadWriteCloser) (RPCRequester, error) {
|
||||
// TODO do ssh connection to transport, establish TCP-like communication channel
|
||||
type ByteStreamRPC struct {
|
||||
conn ByteStream
|
||||
log Logger
|
||||
}
|
||||
|
||||
func ConnectByteStreamRPC(conn ByteStream) (RPCRequester, error) {
|
||||
|
||||
rpc := ByteStreamRPC{
|
||||
conn: conn,
|
||||
encoder: json.NewEncoder(conn),
|
||||
decoder: json.NewDecoder(conn),
|
||||
conn: conn,
|
||||
}
|
||||
|
||||
// Assert protocol versions are equal
|
||||
@ -57,9 +58,18 @@ func ConnectByteStreamRPC(conn io.ReadWriteCloser) (RPCRequester, error) {
|
||||
return rpc, nil
|
||||
}
|
||||
|
||||
func ListenByteStreamRPC(conn io.ReadWriteCloser, handler RPCHandler, log Logger) error {
|
||||
type ByteStreamRPCDecodeJSONError struct {
|
||||
Type reflect.Type
|
||||
DecoderErr error
|
||||
}
|
||||
|
||||
// A request consists of two subsequent JSON objects
|
||||
func (e ByteStreamRPCDecodeJSONError) Error() string {
|
||||
return fmt.Sprintf("cannot decode %s: %s", e.Type, e.DecoderErr)
|
||||
}
|
||||
|
||||
func ListenByteStreamRPC(conn ByteStream, handler RPCHandler, log Logger) error {
|
||||
|
||||
// A request consists of two subsequent chunked JSON objects
|
||||
// Object 1: RequestHeader => contains type of Request Body
|
||||
// Object 2: RequestBody, e.g. IncrementalTransferRequest
|
||||
// A response is always a ResponseHeader followed by bytes to be interpreted
|
||||
@ -69,27 +79,47 @@ func ListenByteStreamRPC(conn io.ReadWriteCloser, handler RPCHandler, log Logger
|
||||
|
||||
defer conn.Close()
|
||||
|
||||
decoder := json.NewDecoder(conn)
|
||||
encoder := json.NewEncoder(conn)
|
||||
send := func(r interface{}) {
|
||||
if err := writeChunkedJSON(conn, r); err != nil {
|
||||
panic(err)
|
||||
}
|
||||
}
|
||||
|
||||
sendError := func(id ErrorId, msg string) {
|
||||
r := ResponseHeader{
|
||||
ErrorId: id,
|
||||
ResponseType: RNONE,
|
||||
Message: msg,
|
||||
}
|
||||
log.Printf("sending error response: %#v", r)
|
||||
if err := writeChunkedJSON(conn, r); err != nil {
|
||||
log.Printf("error sending error response: %#v", err)
|
||||
panic(err)
|
||||
}
|
||||
}
|
||||
|
||||
recv := func(r interface{}) (err error) {
|
||||
return readChunkedJSON(conn, r)
|
||||
}
|
||||
|
||||
for {
|
||||
|
||||
var header RequestHeader = RequestHeader{}
|
||||
if err := decoder.Decode(&header); err != nil {
|
||||
respondWithError(encoder, EDecodeHeader, err)
|
||||
if err := recv(&header); err != nil {
|
||||
sendError(EDecodeHeader, err.Error())
|
||||
return conn.Close()
|
||||
}
|
||||
|
||||
switch header.Type {
|
||||
case RTProtocolVersionRequest:
|
||||
var rq ByteStreamRPCProtocolVersionRequest
|
||||
if err := decoder.Decode(&rq); err != nil {
|
||||
respondWithError(encoder, EDecodeRequestBody, nil)
|
||||
if err := recv(&rq); err != nil {
|
||||
sendError(EDecodeRequestBody, err.Error())
|
||||
return conn.Close()
|
||||
}
|
||||
|
||||
if rq.ClientVersion != ByteStreamRPCProtocolVersion {
|
||||
respondWithError(encoder, EProtocolVersionMismatch, nil)
|
||||
sendError(EProtocolVersionMismatch, "")
|
||||
return conn.Close()
|
||||
}
|
||||
|
||||
@ -97,70 +127,61 @@ func ListenByteStreamRPC(conn io.ReadWriteCloser, handler RPCHandler, log Logger
|
||||
RequestId: header.Id,
|
||||
ResponseType: ROK,
|
||||
}
|
||||
if err := encoder.Encode(&r); err != nil {
|
||||
panic(err)
|
||||
}
|
||||
send(&r)
|
||||
|
||||
case RTFilesystemRequest:
|
||||
|
||||
var rq FilesystemRequest
|
||||
if err := decoder.Decode(&rq); err != nil {
|
||||
respondWithError(encoder, EDecodeRequestBody, nil)
|
||||
if err := recv(&rq); err != nil {
|
||||
sendError(EDecodeRequestBody, "")
|
||||
return conn.Close()
|
||||
}
|
||||
|
||||
roots, err := handler.HandleFilesystemRequest(rq)
|
||||
if err != nil {
|
||||
respondWithError(encoder, EHandler, err)
|
||||
sendError(EHandler, err.Error())
|
||||
return conn.Close()
|
||||
} else {
|
||||
r := ResponseHeader{
|
||||
RequestId: header.Id,
|
||||
ResponseType: RFilesystems,
|
||||
}
|
||||
if err := encoder.Encode(&r); err != nil {
|
||||
panic(err)
|
||||
}
|
||||
if err := encoder.Encode(&roots); err != nil {
|
||||
panic(err)
|
||||
}
|
||||
send(&r)
|
||||
send(&roots)
|
||||
}
|
||||
|
||||
case RTFilesystemVersionsRequest:
|
||||
|
||||
var rq FilesystemVersionsRequest
|
||||
if err := decoder.Decode(&rq); err != nil {
|
||||
respondWithError(encoder, EDecodeRequestBody, err)
|
||||
if err := recv(&rq); err != nil {
|
||||
sendError(EDecodeRequestBody, err.Error())
|
||||
return err
|
||||
}
|
||||
|
||||
diff, err := handler.HandleFilesystemVersionsRequest(rq)
|
||||
if err != nil {
|
||||
respondWithError(encoder, EHandler, err)
|
||||
sendError(EHandler, err.Error())
|
||||
return err
|
||||
} else {
|
||||
r := ResponseHeader{
|
||||
RequestId: header.Id,
|
||||
ResponseType: RFilesystemDiff,
|
||||
}
|
||||
if err := encoder.Encode(&r); err != nil {
|
||||
panic(err)
|
||||
}
|
||||
if err := encoder.Encode(&diff); err != nil {
|
||||
panic(err)
|
||||
}
|
||||
send(&r)
|
||||
send(&diff)
|
||||
}
|
||||
|
||||
case RTInitialTransferRequest:
|
||||
var rq InitialTransferRequest
|
||||
if err := decoder.Decode(&rq); err != nil {
|
||||
respondWithError(encoder, EDecodeRequestBody, nil)
|
||||
if err := recv(&rq); err != nil {
|
||||
sendError(EDecodeRequestBody, "")
|
||||
return conn.Close()
|
||||
}
|
||||
log.Printf("initial transfer request: %#v", rq)
|
||||
|
||||
snapReader, err := handler.HandleInitialTransferRequest(rq)
|
||||
if err != nil {
|
||||
respondWithError(encoder, EHandler, err)
|
||||
sendError(EHandler, err.Error())
|
||||
return conn.Close()
|
||||
} else {
|
||||
|
||||
@ -168,9 +189,7 @@ func ListenByteStreamRPC(conn io.ReadWriteCloser, handler RPCHandler, log Logger
|
||||
RequestId: header.Id,
|
||||
ResponseType: RChunkedStream,
|
||||
}
|
||||
if err := encoder.Encode(&r); err != nil {
|
||||
panic(err)
|
||||
}
|
||||
send(&r)
|
||||
|
||||
chunker := NewChunker(snapReader)
|
||||
_, err := io.Copy(conn, &chunker)
|
||||
@ -182,23 +201,21 @@ func ListenByteStreamRPC(conn io.ReadWriteCloser, handler RPCHandler, log Logger
|
||||
case RTIncrementalTransferRequest:
|
||||
|
||||
var rq IncrementalTransferRequest
|
||||
if err := decoder.Decode(&rq); err != nil {
|
||||
respondWithError(encoder, EDecodeRequestBody, nil)
|
||||
if err := recv(&rq); err != nil {
|
||||
sendError(EDecodeRequestBody, "")
|
||||
return conn.Close()
|
||||
}
|
||||
|
||||
snapReader, err := handler.HandleIncrementalTransferRequest(rq)
|
||||
if err != nil {
|
||||
respondWithError(encoder, EHandler, err)
|
||||
sendError(EHandler, err.Error())
|
||||
} else {
|
||||
|
||||
r := ResponseHeader{
|
||||
RequestId: header.Id,
|
||||
ResponseType: RChunkedStream,
|
||||
}
|
||||
if err := encoder.Encode(&r); err != nil {
|
||||
panic(err)
|
||||
}
|
||||
send(&r)
|
||||
|
||||
chunker := NewChunker(snapReader)
|
||||
_, err := io.Copy(conn, &chunker)
|
||||
@ -208,7 +225,7 @@ func ListenByteStreamRPC(conn io.ReadWriteCloser, handler RPCHandler, log Logger
|
||||
}
|
||||
|
||||
default:
|
||||
respondWithError(encoder, EUnknownRequestType, nil)
|
||||
sendError(EUnknownRequestType, "")
|
||||
return conn.Close()
|
||||
}
|
||||
}
|
||||
@ -216,17 +233,30 @@ func ListenByteStreamRPC(conn io.ReadWriteCloser, handler RPCHandler, log Logger
|
||||
return nil
|
||||
}
|
||||
|
||||
func respondWithError(encoder *json.Encoder, id ErrorId, err error) {
|
||||
func writeChunkedJSON(conn io.Writer, r interface{}) (err error) {
|
||||
var buf bytes.Buffer
|
||||
encoder := json.NewEncoder(&buf)
|
||||
encoder.Encode(r)
|
||||
ch := NewChunker(&buf)
|
||||
_, err = io.Copy(conn, &ch)
|
||||
return
|
||||
}
|
||||
|
||||
r := ResponseHeader{
|
||||
ErrorId: id,
|
||||
ResponseType: RNONE,
|
||||
Message: err.Error(),
|
||||
func readChunkedJSON(conn io.ReadWriter, r interface{}) (err error) {
|
||||
unch := NewUnchunker(conn)
|
||||
dec := json.NewDecoder(unch)
|
||||
err = dec.Decode(r)
|
||||
if err != nil {
|
||||
err = ByteStreamRPCDecodeJSONError{
|
||||
Type: reflect.TypeOf(r),
|
||||
DecoderErr: err,
|
||||
}
|
||||
}
|
||||
if err := encoder.Encode(&r); err != nil {
|
||||
panic(err)
|
||||
closeErr := unch.Close()
|
||||
if err == nil && closeErr != nil {
|
||||
err = closeErr
|
||||
}
|
||||
|
||||
return
|
||||
}
|
||||
|
||||
func inferRequestType(v interface{}) (RequestType, error) {
|
||||
@ -264,10 +294,10 @@ func (c ByteStreamRPC) sendRequest(v interface{}) (err error) {
|
||||
Id: genUUID(),
|
||||
}
|
||||
|
||||
if err = c.encoder.Encode(h); err != nil {
|
||||
if err = writeChunkedJSON(c.conn, h); err != nil {
|
||||
return
|
||||
}
|
||||
if err = c.encoder.Encode(v); err != nil {
|
||||
if err = writeChunkedJSON(c.conn, v); err != nil {
|
||||
return
|
||||
}
|
||||
|
||||
@ -275,8 +305,9 @@ func (c ByteStreamRPC) sendRequest(v interface{}) (err error) {
|
||||
}
|
||||
|
||||
func (c ByteStreamRPC) expectResponseType(rt ResponseType) (err error) {
|
||||
|
||||
var h ResponseHeader
|
||||
if err = c.decoder.Decode(&h); err != nil {
|
||||
if err = readChunkedJSON(c.conn, &h); err != nil {
|
||||
return
|
||||
}
|
||||
|
||||
@ -317,7 +348,7 @@ func (c ByteStreamRPC) FilesystemRequest(r FilesystemRequest) (roots []zfs.Datas
|
||||
|
||||
roots = make([]zfs.DatasetPath, 0)
|
||||
|
||||
if err = c.decoder.Decode(&roots); err != nil {
|
||||
if err = readChunkedJSON(c.conn, &roots); err != nil {
|
||||
return
|
||||
}
|
||||
|
||||
@ -330,7 +361,7 @@ func (c ByteStreamRPC) FilesystemVersionsRequest(r FilesystemVersionsRequest) (v
|
||||
return
|
||||
}
|
||||
|
||||
err = c.decoder.Decode(&versions)
|
||||
err = readChunkedJSON(c.conn, &versions)
|
||||
return
|
||||
}
|
||||
|
||||
@ -339,6 +370,7 @@ func (c ByteStreamRPC) InitialTransferRequest(r InitialTransferRequest) (unchunk
|
||||
if err = c.sendRequestReceiveHeader(r, RChunkedStream); err != nil {
|
||||
return
|
||||
}
|
||||
|
||||
unchunker = NewUnchunker(c.conn)
|
||||
return
|
||||
}
|
||||
|
33
rpc/rpc_test.go
Normal file
33
rpc/rpc_test.go
Normal file
@ -0,0 +1,33 @@
|
||||
package rpc
|
||||
|
||||
import (
|
||||
"bytes"
|
||||
"github.com/stretchr/testify/assert"
|
||||
"github.com/zrepl/zrepl/util"
|
||||
"io"
|
||||
"strings"
|
||||
"testing"
|
||||
)
|
||||
|
||||
func TestByteStreamRPCDecodeJSONError(t *testing.T) {
|
||||
|
||||
r := strings.NewReader("{'a':'aber'}")
|
||||
|
||||
var chunked bytes.Buffer
|
||||
ch := util.NewChunker(r)
|
||||
io.Copy(&chunked, &ch)
|
||||
|
||||
type SampleType struct {
|
||||
A uint
|
||||
}
|
||||
var s SampleType
|
||||
err := readChunkedJSON(&chunked, &s)
|
||||
assert.NotNil(t, err)
|
||||
|
||||
_, ok := err.(ByteStreamRPCDecodeJSONError)
|
||||
if !ok {
|
||||
t.Errorf("expected ByteStreamRPCDecodeJSONError, got %t\n", err)
|
||||
t.Errorf("%s\n", err)
|
||||
}
|
||||
|
||||
}
|
@ -26,7 +26,7 @@ func NewUnchunker(conn io.Reader) *Unchunker {
|
||||
func (c *Unchunker) Read(b []byte) (n int, err error) {
|
||||
|
||||
if c.finishErr != nil {
|
||||
return 0, err
|
||||
return 0, c.finishErr
|
||||
}
|
||||
|
||||
if c.remainingChunkBytes == 0 {
|
||||
@ -68,6 +68,19 @@ func (c *Unchunker) Read(b []byte) (n int, err error) {
|
||||
|
||||
}
|
||||
|
||||
func (c *Unchunker) Close() (err error) {
|
||||
|
||||
buf := make([]byte, 4096)
|
||||
for err == nil {
|
||||
_, err = c.Read(buf)
|
||||
if err == io.EOF {
|
||||
err = nil
|
||||
break
|
||||
}
|
||||
}
|
||||
return
|
||||
}
|
||||
|
||||
func min(a, b int) int {
|
||||
if a < b {
|
||||
return a
|
||||
|
36
util/io.go
Normal file
36
util/io.go
Normal file
@ -0,0 +1,36 @@
|
||||
package util
|
||||
|
||||
import (
|
||||
"io"
|
||||
)
|
||||
|
||||
type ChainedReader struct {
|
||||
Readers []io.Reader
|
||||
curReader int
|
||||
}
|
||||
|
||||
func NewChainedReader(reader ...io.Reader) *ChainedReader {
|
||||
return &ChainedReader{
|
||||
Readers: reader,
|
||||
curReader: 0,
|
||||
}
|
||||
}
|
||||
|
||||
func (c *ChainedReader) Read(buf []byte) (n int, err error) {
|
||||
|
||||
n = 0
|
||||
|
||||
for c.curReader < len(c.Readers) {
|
||||
n, err = c.Readers[c.curReader].Read(buf)
|
||||
if err == io.EOF {
|
||||
c.curReader++
|
||||
continue
|
||||
}
|
||||
break
|
||||
}
|
||||
if c.curReader == len(c.Readers) {
|
||||
err = io.EOF // actually, there was no gap
|
||||
}
|
||||
|
||||
return
|
||||
}
|
Loading…
Reference in New Issue
Block a user