zrepl/rpc/frame_layer.go
2018-02-18 13:33:53 +01:00

303 lines
6.7 KiB
Go

package rpc
import (
"bytes"
"encoding/binary"
"encoding/json"
"fmt"
"io"
"github.com/pkg/errors"
)
// Frame is the fixed-size header that precedes every chunk of payload on the
// connection. readFrame and writeFrame encode it little-endian as Type
// (1 byte), NoMoreFrames (1 byte) and PayloadLength (4 bytes), in that order.
type Frame struct {
	// Type identifies what the payload belongs to (header, data, reset, ...).
	Type FrameType
	// NoMoreFrames marks the last frame of a message; the bridging writer sets
	// it on the final flush, and the bridging reader reports io.EOF after the
	// payload of such a frame has been consumed.
	NoMoreFrames bool
	// PayloadLength is the number of payload bytes that follow this header
	// (at most MAX_PAYLOAD_LENGTH).
	PayloadLength uint32
}
// FrameType identifies the kind of payload carried by a Frame. The values are
// part of the wire format: each is written as a single byte in the frame
// header, so they are spelled out explicitly.
//
//go:generate stringer -type=FrameType
type FrameType uint8

const (
	// FrameTypeHeader frames carry a JSON-encoded Header (see ReadHeader/WriteHeader).
	FrameTypeHeader FrameType = 0x01
	// FrameTypeData frames carry message body bytes (see ReadData/WriteData).
	FrameTypeData FrameType = 0x02
	// FrameTypeTrailer is not referenced by the code in this file.
	FrameTypeTrailer FrameType = 0x03
	// FrameTypeRST is sent by MessageLayer.Close; readFrame reports it as the RST error.
	FrameTypeRST FrameType = 0xff
)
// Status classifies the outcome of a request and is carried in Header.Error.
// Because headers are JSON-encoded, these numbers travel over the wire; they
// are therefore written out explicitly instead of being derived from iota.
//
//go:generate stringer -type=Status
type Status uint64

const (
	StatusOK           Status = 1
	StatusRequestError Status = 2
	StatusServerError  Status = 3
	// Returned when an error occurred but the side at fault cannot be determined
	StatusError Status = 4
)
// Header is the metadata message exchanged ahead of a request or reply body.
// WriteHeader JSON-encodes it (no struct tags, so the Go field names are the
// JSON keys) into FrameTypeHeader frames limited to MAX_HEADER_LENGTH bytes.
type Header struct {
	// Request-only
	// (presumably the name of the handler the request is addressed to — confirm with the server side)
	Endpoint string
	// Data type of body (request & reply)
	DataType DataType
	// Request-only
	// (presumably the body DataType the requester is willing to receive in the reply)
	Accept DataType
	// Reply-only
	// Set by NewErrorHeader. Note that the zero value is not one of the declared Status constants.
	Error Status
	// Reply-only
	// Human-readable error description; NewErrorHeader formats it with fmt.Sprintf.
	ErrorMessage string
}
// NewErrorHeader builds a reply Header reporting status, whose ErrorMessage is
// produced from format and args exactly as fmt.Sprintf would. All other
// fields are left at their zero values.
func NewErrorHeader(status Status, format string, args ...interface{}) (h *Header) {
	msg := fmt.Sprintf(format, args...)
	return &Header{
		Error:        status,
		ErrorMessage: msg,
	}
}
// DataType describes how a message body is encoded (see Header.DataType and
// Header.Accept). The values are transmitted inside the JSON header, so they
// are written out explicitly instead of being derived from iota.
//
//go:generate stringer -type=DataType
type DataType uint8

const (
	DataTypeNone          DataType = 1
	DataTypeControl       DataType = 2
	DataTypeMarshaledJSON DataType = 3
	DataTypeOctets        DataType = 4
)
const (
	// MAX_PAYLOAD_LENGTH is the largest payload a single frame may carry (4 MiB).
	MAX_PAYLOAD_LENGTH = 4 << 20
	// MAX_HEADER_LENGTH bounds the total size of an encoded Header message (4 KiB).
	MAX_HEADER_LENGTH = 4 << 10
)
// frameBridgingReader presents the payloads of consecutive frames of a single
// FrameType, read from a MessageLayer's connection, as one contiguous stream.
type frameBridgingReader struct {
	// l provides the connection and logger.
	l *MessageLayer
	// frameType is the only frame type this reader accepts.
	frameType FrameType
	// < 0 means no limit
	// Remaining number of payload bytes this reader will return before io.EOF.
	bytesLeftToLimit int
	// f is the frame currently being consumed; Read decrements f.PayloadLength
	// as payload bytes are handed out.
	f Frame
}
// NewFrameBridgingReader returns a reader that yields the concatenated
// payloads of consecutive frames of type frameType read from l. At most
// totalLimit bytes are returned; a negative totalLimit disables the limit.
func NewFrameBridgingReader(l *MessageLayer, frameType FrameType, totalLimit int) *frameBridgingReader {
	r := &frameBridgingReader{
		l:                l,
		frameType:        frameType,
		bytesLeftToLimit: totalLimit,
	}
	return r
}
// Read implements io.Reader over the payloads of consecutive frames of type
// r.frameType. Whenever the current frame's payload is exhausted, it reads the
// next frame header from the connection.
//
// It returns io.EOF once a frame with NoMoreFrames set has been fully consumed,
// or once the configured byte limit has been reached. It returns an error if a
// frame of a different type arrives. If the connection ends while payload
// bytes of the current frame are still outstanding, it returns
// io.ErrUnexpectedEOF.
func (r *frameBridgingReader) Read(b []byte) (n int, err error) {
	log := r.l.logger
	if r.bytesLeftToLimit == 0 {
		log.Printf("limit reached, returning EOF")
		return 0, io.EOF
	}
	// Advance to a frame that still has payload left. Looping skips zero-length
	// frames. As a result, an empty final frame (as sent by Close with an empty
	// buffer) yields io.EOF directly instead of a (0, nil) read.
	for r.f.PayloadLength == 0 {
		if r.f.NoMoreFrames {
			log.Printf("no more frames flag set, returning EOF")
			return 0, io.EOF
		}
		log.Printf("reading frame")
		r.f, err = r.l.readFrame()
		if err != nil {
			log.Printf("error reading frame: %+v", err)
			return 0, err
		}
		log.Printf("read frame: %#v", r.f)
		if r.f.Type != r.frameType {
			// Must construct a fresh error here: wrapping the (nil) err, as
			// errors.Wrapf(err, ...) did, returns nil. That turned a protocol
			// violation into a successful empty read.
			return 0, errors.Errorf("expected frame of type %s, got %s", r.frameType, r.f.Type)
		}
	}

	maxread := len(b)
	if maxread > int(r.f.PayloadLength) {
		maxread = int(r.f.PayloadLength)
	}
	if r.bytesLeftToLimit > 0 && maxread > r.bytesLeftToLimit {
		maxread = r.bytesLeftToLimit
	}
	nb, err := r.l.rwc.Read(b[:maxread])
	log.Printf("read %v from rwc\n", nb)
	if nb < 0 {
		panic("should not return negative number of bytes")
	}
	r.f.PayloadLength -= uint32(nb)
	r.bytesLeftToLimit -= nb
	if err == io.EOF && r.f.PayloadLength > 0 {
		// The connection ended in the middle of a frame's payload. A bare io.EOF
		// would let callers mistake the truncated message for a complete one.
		err = io.ErrUnexpectedEOF
	}
	return nb, err
}
// frameBridgingWriter splits a byte stream into frames of a single FrameType
// written to a MessageLayer's connection. Bytes are buffered until a frame is
// full, the byte limit is exhausted, or Close is called.
type frameBridgingWriter struct {
	// l provides the connection frames are written to.
	l *MessageLayer
	// frameType is stamped on every frame this writer emits.
	frameType FrameType
	// < 0 means no limit
	// Remaining number of bytes this writer accepts before Write fails.
	bytesLeftToLimit int
	// payloadLength is the maximum payload per frame; a full buffer is flushed.
	payloadLength int
	// buffer holds the payload of the frame currently being assembled.
	buffer *bytes.Buffer
}
// NewFrameBridgingWriter returns a writer that emits frames of type frameType
// on l, each carrying at most MAX_PAYLOAD_LENGTH payload bytes. At most
// totalLimit bytes are accepted in total; a negative totalLimit disables the
// limit.
func NewFrameBridgingWriter(l *MessageLayer, frameType FrameType, totalLimit int) *frameBridgingWriter {
	// The buffer never holds more than one frame's payload, nor more than the
	// total byte budget. Size it accordingly, so that header writers (limited to
	// MAX_HEADER_LENGTH) do not allocate MAX_PAYLOAD_LENGTH bytes each time.
	// bytes.Buffer still grows on demand, so this only affects allocation.
	bufCap := MAX_PAYLOAD_LENGTH
	if totalLimit >= 0 && totalLimit < bufCap {
		bufCap = totalLimit
	}
	return &frameBridgingWriter{
		l:                l,
		frameType:        frameType,
		bytesLeftToLimit: totalLimit,
		payloadLength:    MAX_PAYLOAD_LENGTH,
		buffer:           bytes.NewBuffer(make([]byte, 0, bufCap)),
	}
}
// Write implements io.Writer. It feeds b into the frame buffer piece by piece,
// letting writeUntilFrameFull flush frames as they fill up. It stops at the
// first error and reports how many bytes of b were accepted before it.
func (w *frameBridgingWriter) Write(b []byte) (n int, err error) {
	for n < len(b) {
		written, werr := w.writeUntilFrameFull(b[n:])
		n += written
		if werr != nil {
			return n, errors.WithStack(werr)
		}
	}
	return n, nil
}
// writeUntilFrameFull buffers as much of b as fits into both the current
// frame and the remaining byte budget, and returns how many bytes it took.
// When the budget hits zero, the frame is flushed as the final frame; when
// only the frame is full, it is flushed as an intermediate frame.
func (w *frameBridgingWriter) writeUntilFrameFull(b []byte) (n int, err error) {
	if len(b) <= 0 {
		return 0, nil
	}
	if w.bytesLeftToLimit == 0 {
		return 0, errors.Errorf("message exceeds max number of allowed bytes")
	}

	n = len(b)
	if room := w.payloadLength - w.buffer.Len(); n > room {
		n = room
	}
	limited := w.bytesLeftToLimit > 0
	if limited && n > w.bytesLeftToLimit {
		n = w.bytesLeftToLimit
	}

	// bytes.Buffer.Write always returns a nil error.
	w.buffer.Write(b[:n])
	w.bytesLeftToLimit -= n

	switch {
	case w.bytesLeftToLimit == 0:
		err = w.flush(true)
	case w.buffer.Len() == w.payloadLength:
		err = w.flush(false)
	}
	return n, err
}
// flush emits one frame whose payload is the current buffer content, then
// empties the buffer. nomore sets the frame's NoMoreFrames flag.
func (w *frameBridgingWriter) flush(nomore bool) (err error) {
	f := Frame{
		Type:          w.frameType,
		NoMoreFrames:  nomore,
		PayloadLength: uint32(w.buffer.Len()),
	}
	if err = w.l.writeFrame(f); err != nil {
		// Previously this error was discarded and the payload was written anyway.
		// The peer would then parse payload bytes as a frame header.
		return errors.WithStack(err)
	}
	if _, err = w.buffer.WriteTo(w.l.rwc); err != nil {
		return errors.WithStack(err)
	}
	return nil
}
// Close terminates the message by emitting a final frame (NoMoreFrames set)
// containing whatever is still buffered. It does not close the underlying
// connection.
func (w *frameBridgingWriter) Close() (err error) {
	if w.bytesLeftToLimit == 0 {
		// The byte budget is exhausted. writeUntilFrameFull has therefore already
		// flushed the final frame (or reported why it could not). A second
		// terminating frame would remain unconsumed on the wire, because a reader
		// with the same limit (as ReadHeader uses) returns io.EOF at the limit
		// without reading further frames. That would desynchronize the next
		// message.
		return nil
	}
	return w.flush(true)
}
// MessageLayer frames header and data messages on top of a byte stream.
type MessageLayer struct {
	// rwc is the underlying connection all frames are read from and written to.
	rwc io.ReadWriteCloser
	// logger receives debug output; NewMessageLayer installs noLogger{}
	// (declared outside this file, presumably a no-op — confirm).
	logger Logger
}
// NewMessageLayer wraps rwc in a MessageLayer whose logger is noLogger{}.
func NewMessageLayer(rwc io.ReadWriteCloser) *MessageLayer {
	l := &MessageLayer{rwc: rwc}
	l.logger = noLogger{}
	return l
}
// Close signals the peer to reset by sending a single FrameTypeRST frame with
// NoMoreFrames set and no payload. A failure to send is logged and returned.
//
// NOTE(review): this does not close l.rwc; presumably the connection's owner
// does that — confirm with callers.
func (l *MessageLayer) Close() (err error) {
	rst := Frame{Type: FrameTypeRST, NoMoreFrames: true}
	err = l.writeFrame(rst)
	if err == nil {
		return nil
	}
	l.logger.Printf("error sending RST frame: %s", err)
	return errors.WithStack(err)
}
// RST is returned by readFrame when the peer sent a FrameTypeRST frame.
var RST = fmt.Errorf("reset frame observed on connection")
func (l *MessageLayer) readFrame() (f Frame, err error) {
err = binary.Read(l.rwc, binary.LittleEndian, &f.Type)
if err != nil {
err = errors.WithStack(err)
return
}
err = binary.Read(l.rwc, binary.LittleEndian, &f.NoMoreFrames)
if err != nil {
err = errors.WithStack(err)
return
}
err = binary.Read(l.rwc, binary.LittleEndian, &f.PayloadLength)
if err != nil {
err = errors.WithStack(err)
return
}
if f.Type == FrameTypeRST {
l.logger.Printf("read RST frame")
err = RST
return
}
if f.PayloadLength > MAX_PAYLOAD_LENGTH {
err = errors.Errorf("frame exceeds max payload length")
return
}
return
}
func (l *MessageLayer) writeFrame(f Frame) (err error) {
err = binary.Write(l.rwc, binary.LittleEndian, &f.Type)
if err != nil {
return errors.WithStack(err)
}
err = binary.Write(l.rwc, binary.LittleEndian, &f.NoMoreFrames)
if err != nil {
return errors.WithStack(err)
}
err = binary.Write(l.rwc, binary.LittleEndian, &f.PayloadLength)
if err != nil {
return errors.WithStack(err)
}
if f.PayloadLength > MAX_PAYLOAD_LENGTH {
err = errors.Errorf("frame exceeds max payload length")
return
}
return
}
func (l *MessageLayer) ReadHeader() (h *Header, err error) {
r := NewFrameBridgingReader(l, FrameTypeHeader, MAX_HEADER_LENGTH)
h = &Header{}
if err = json.NewDecoder(r).Decode(&h); err != nil {
l.logger.Printf("cannot decode marshaled header: %s", err)
return nil, err
}
return h, nil
}
// WriteHeader JSON-encodes h into FrameTypeHeader frames. The encoding is
// limited to MAX_HEADER_LENGTH bytes, and the message is terminated with a
// frame that has NoMoreFrames set.
func (l *MessageLayer) WriteHeader(h *Header) (err error) {
	w := NewFrameBridgingWriter(l, FrameTypeHeader, MAX_HEADER_LENGTH)
	if err = json.NewEncoder(w).Encode(h); err != nil {
		return errors.Wrap(err, "cannot encode header, probably fatal")
	}
	// Close emits the final frame. Ignoring its error (as before) would report
	// success for a header the peer never receives in full.
	if err = w.Close(); err != nil {
		return errors.Wrap(err, "cannot send header")
	}
	return nil
}
// ReadData returns a reader over the payload of the next data message, i.e.
// consecutive FrameTypeData frames up to the one marked NoMoreFrames. No size
// limit is applied.
func (l *MessageLayer) ReadData() (reader io.Reader) {
	reader = NewFrameBridgingReader(l, FrameTypeData, -1)
	return reader
}
// WriteData streams all of source as a data message (FrameTypeData frames
// without a size limit) and terminates it with a NoMoreFrames frame. If
// copying fails, the error is returned and the message is left unterminated.
func (l *MessageLayer) WriteData(source io.Reader) (err error) {
	w := NewFrameBridgingWriter(l, FrameTypeData, -1)
	if _, err = io.Copy(w, source); err != nil {
		return errors.WithStack(err)
	}
	return w.Close()
}