mirror of
https://github.com/zrepl/zrepl.git
synced 2024-11-22 00:13:52 +01:00
chunking: rewrite to handle EOF events correctly
bonus: some tests asserting the chunking protocol is adhered to
This commit is contained in:
parent
9b871fb7c0
commit
61c263b91d
122
util/chunking.go
122
util/chunking.go
@ -13,6 +13,7 @@ type Unchunker struct {
|
||||
ChunkCount int
|
||||
in io.Reader
|
||||
remainingChunkBytes uint32
|
||||
finishErr error
|
||||
}
|
||||
|
||||
func NewUnchunker(conn io.Reader) *Unchunker {
|
||||
@ -24,17 +25,23 @@ func NewUnchunker(conn io.Reader) *Unchunker {
|
||||
|
||||
func (c *Unchunker) Read(b []byte) (n int, err error) {
|
||||
|
||||
if c.finishErr != nil {
|
||||
return 0, err
|
||||
}
|
||||
|
||||
if c.remainingChunkBytes == 0 {
|
||||
|
||||
var nextChunkLen uint32
|
||||
err = binary.Read(c.in, ChunkHeaderByteOrder, &nextChunkLen)
|
||||
if err != nil {
|
||||
c.finishErr = err // can't handle this
|
||||
return
|
||||
}
|
||||
|
||||
// A chunk of len 0 indicates end of stream
|
||||
if nextChunkLen == 0 {
|
||||
return 0, io.EOF
|
||||
c.finishErr = io.EOF
|
||||
return 0, c.finishErr
|
||||
}
|
||||
|
||||
c.remainingChunkBytes = nextChunkLen
|
||||
@ -42,17 +49,18 @@ func (c *Unchunker) Read(b []byte) (n int, err error) {
|
||||
|
||||
}
|
||||
|
||||
if c.remainingChunkBytes <= 0 {
|
||||
panic("internal inconsistency: c.remainingChunkBytes must be > 0")
|
||||
}
|
||||
if len(b) <= 0 {
|
||||
panic("cannot read into buffer of length 0")
|
||||
}
|
||||
|
||||
maxRead := min(int(c.remainingChunkBytes), len(b))
|
||||
if maxRead < 0 {
|
||||
panic("Cannot read negative amount of bytes")
|
||||
}
|
||||
if maxRead == 0 {
|
||||
return 0, nil
|
||||
}
|
||||
|
||||
n, err = c.in.Read(b[0:maxRead])
|
||||
if err != nil {
|
||||
return n, err
|
||||
return
|
||||
}
|
||||
c.remainingChunkBytes -= uint32(n)
|
||||
|
||||
@ -71,9 +79,12 @@ func min(a, b int) int {
|
||||
type Chunker struct {
|
||||
ChunkCount int
|
||||
in io.Reader
|
||||
inEOF bool
|
||||
remainingChunkBytes int
|
||||
payloadBufLen int
|
||||
payloadBuf []byte
|
||||
headerBuf *bytes.Buffer
|
||||
finalHeaderBuffered bool
|
||||
}
|
||||
|
||||
func NewChunker(conn io.Reader) Chunker {
|
||||
@ -87,6 +98,7 @@ func NewChunkerSized(conn io.Reader, chunkSize uint32) Chunker {
|
||||
return Chunker{
|
||||
in: conn,
|
||||
remainingChunkBytes: 0,
|
||||
payloadBufLen: 0,
|
||||
payloadBuf: buf,
|
||||
headerBuf: &bytes.Buffer{},
|
||||
}
|
||||
@ -95,37 +107,85 @@ func NewChunkerSized(conn io.Reader, chunkSize uint32) Chunker {
|
||||
|
||||
func (c *Chunker) Read(b []byte) (n int, err error) {
|
||||
|
||||
//fmt.Printf("chunker: c.remainingChunkBytes: %d len(b): %d\n", c.remainingChunkBytes, len(b))
|
||||
if len(b) == 0 {
|
||||
panic("unexpected empty output buffer")
|
||||
}
|
||||
|
||||
if c.inEOF && c.finalHeaderBuffered && c.headerBuf.Len() == 0 { // all bufs empty and no more bytes to expect
|
||||
return 0, io.EOF
|
||||
}
|
||||
|
||||
n = 0
|
||||
if c.remainingChunkBytes == 0 {
|
||||
|
||||
if c.headerBuf.Len() > 0 { // first drain the header buf
|
||||
nh, err := c.headerBuf.Read(b[n:])
|
||||
if nh > 0 {
|
||||
n += nh
|
||||
}
|
||||
if nh == 0 || (err != nil && err != io.EOF) {
|
||||
panic("unexpected behavior: in-memory buffer should not throw errors")
|
||||
}
|
||||
if c.headerBuf.Len() != 0 {
|
||||
return n, nil // finish writing the header before we can proceed with payload
|
||||
}
|
||||
if c.finalHeaderBuffered {
|
||||
// we just wrote the final header
|
||||
return n, io.EOF
|
||||
}
|
||||
}
|
||||
|
||||
if c.remainingChunkBytes > 0 { // then drain the payload buf
|
||||
|
||||
npl := copy(b[n:], c.payloadBuf[c.payloadBufLen-c.remainingChunkBytes:c.payloadBufLen])
|
||||
c.remainingChunkBytes -= npl
|
||||
if c.remainingChunkBytes < 0 {
|
||||
panic("unexpected behavior, copy() should not copy more than max(cap(), len())")
|
||||
}
|
||||
n += npl
|
||||
}
|
||||
|
||||
if c.remainingChunkBytes == 0 && !c.inEOF { // fillup bufs
|
||||
|
||||
newPayloadLen, err := c.in.Read(c.payloadBuf)
|
||||
|
||||
if newPayloadLen == 0 {
|
||||
return 0, io.EOF
|
||||
} else if err != nil {
|
||||
return newPayloadLen, err
|
||||
if newPayloadLen > 0 {
|
||||
c.payloadBufLen = newPayloadLen
|
||||
c.remainingChunkBytes = newPayloadLen
|
||||
}
|
||||
|
||||
c.remainingChunkBytes = newPayloadLen
|
||||
|
||||
// Write chunk header
|
||||
c.headerBuf.Reset()
|
||||
nextChunkLen := uint32(newPayloadLen)
|
||||
headerLen := binary.Size(nextChunkLen)
|
||||
err = binary.Write(c.headerBuf, ChunkHeaderByteOrder, nextChunkLen)
|
||||
if err != nil {
|
||||
if err == io.EOF {
|
||||
c.inEOF = true
|
||||
} else if err != nil {
|
||||
return n, err
|
||||
}
|
||||
copy(b[0:headerLen], c.headerBuf.Bytes())
|
||||
n += headerLen
|
||||
c.ChunkCount++
|
||||
if newPayloadLen == 0 { // apparently, this happens with some Readers
|
||||
c.finalHeaderBuffered = true
|
||||
}
|
||||
|
||||
// Fill header buf
|
||||
{
|
||||
c.headerBuf.Reset()
|
||||
nextChunkLen := uint32(newPayloadLen)
|
||||
err := binary.Write(c.headerBuf, ChunkHeaderByteOrder, nextChunkLen)
|
||||
if err != nil {
|
||||
panic("unexpected error, write to in-memory buffer should not throw error")
|
||||
}
|
||||
}
|
||||
|
||||
if c.headerBuf.Len() == 0 {
|
||||
panic("unexpected empty header buf")
|
||||
}
|
||||
|
||||
} else if c.remainingChunkBytes == 0 && c.inEOF && !c.finalHeaderBuffered {
|
||||
|
||||
c.headerBuf.Reset()
|
||||
err := binary.Write(c.headerBuf, ChunkHeaderByteOrder, uint32(0))
|
||||
if err != nil {
|
||||
panic("unexpected error, write to in-memory buffer should not throw error [2]")
|
||||
}
|
||||
c.finalHeaderBuffered = true
|
||||
|
||||
}
|
||||
|
||||
remainingBuf := b[n:]
|
||||
n2 := copy(remainingBuf, c.payloadBuf[:c.remainingChunkBytes])
|
||||
//fmt.Printf("chunker: written: %d\n", n+int(n2))
|
||||
c.remainingChunkBytes -= n2
|
||||
return n + int(n2), err
|
||||
return
|
||||
|
||||
}
|
||||
|
73
util/chunking_test.go
Normal file
73
util/chunking_test.go
Normal file
@ -0,0 +1,73 @@
|
||||
package chunking
|
||||
|
||||
import (
|
||||
"bytes"
|
||||
"encoding/binary"
|
||||
"github.com/stretchr/testify/assert"
|
||||
"io"
|
||||
"reflect"
|
||||
"testing"
|
||||
"testing/quick"
|
||||
)
|
||||
|
||||
func TestUnchunker(t *testing.T) {
|
||||
|
||||
buf := bytes.Buffer{}
|
||||
binary.Write(&buf, ChunkHeaderByteOrder, uint32(2))
|
||||
buf.WriteByte(0xca)
|
||||
buf.WriteByte(0xfe)
|
||||
binary.Write(&buf, ChunkHeaderByteOrder, uint32(0))
|
||||
buf.WriteByte(0xff) // sentinel, should not be read
|
||||
|
||||
un := NewUnchunker(&buf)
|
||||
|
||||
recv := bytes.Buffer{}
|
||||
n, err := io.Copy(&recv, un)
|
||||
assert.Nil(t, err)
|
||||
assert.Equal(t, int64(2), n)
|
||||
assert.Equal(t, []byte{0xca, 0xfe}, recv.Bytes())
|
||||
|
||||
}
|
||||
|
||||
func TestChunker(t *testing.T) {
|
||||
|
||||
buf := bytes.Buffer{}
|
||||
buf.WriteByte(0xca)
|
||||
buf.WriteByte(0xfe)
|
||||
|
||||
ch := NewChunker(&buf)
|
||||
|
||||
chunked := bytes.Buffer{}
|
||||
n, err := io.Copy(&chunked, &ch)
|
||||
assert.Nil(t, err)
|
||||
assert.Equal(t, int64(4+2+4), n)
|
||||
assert.Equal(t, []byte{0x2, 0x0, 0x0, 0x0, 0xca, 0xfe, 0x0, 0x0, 0x0, 0x0}, chunked.Bytes())
|
||||
|
||||
}
|
||||
|
||||
func TestUnchunkerUnchunksChunker(t *testing.T) {
|
||||
|
||||
f := func(b []byte) bool {
|
||||
|
||||
buf := bytes.NewBuffer(b)
|
||||
ch := NewChunker(buf)
|
||||
unch := NewUnchunker(&ch)
|
||||
var tx bytes.Buffer
|
||||
_, err := io.Copy(&tx, unch)
|
||||
if err != nil {
|
||||
return false
|
||||
}
|
||||
|
||||
return reflect.DeepEqual(b, tx.Bytes())
|
||||
}
|
||||
|
||||
cfg := quick.Config{
|
||||
MaxCount: 3 * int(ChunkBufSize),
|
||||
MaxCountScale: 2.0,
|
||||
}
|
||||
|
||||
if err := quick.Check(f, &cfg); err != nil {
|
||||
t.Error(err)
|
||||
}
|
||||
|
||||
}
|
Loading…
Reference in New Issue
Block a user