mirror of
https://github.com/zrepl/zrepl.git
synced 2025-06-19 17:27:46 +02:00
chunking: rewrite to handle EOF events correctly
bonus: some tests asserting the chunking protocol is adhered to
This commit is contained in:
parent
9b871fb7c0
commit
61c263b91d
122
util/chunking.go
122
util/chunking.go
@ -13,6 +13,7 @@ type Unchunker struct {
|
|||||||
ChunkCount int
|
ChunkCount int
|
||||||
in io.Reader
|
in io.Reader
|
||||||
remainingChunkBytes uint32
|
remainingChunkBytes uint32
|
||||||
|
finishErr error
|
||||||
}
|
}
|
||||||
|
|
||||||
func NewUnchunker(conn io.Reader) *Unchunker {
|
func NewUnchunker(conn io.Reader) *Unchunker {
|
||||||
@ -24,17 +25,23 @@ func NewUnchunker(conn io.Reader) *Unchunker {
|
|||||||
|
|
||||||
func (c *Unchunker) Read(b []byte) (n int, err error) {
|
func (c *Unchunker) Read(b []byte) (n int, err error) {
|
||||||
|
|
||||||
|
if c.finishErr != nil {
|
||||||
|
return 0, err
|
||||||
|
}
|
||||||
|
|
||||||
if c.remainingChunkBytes == 0 {
|
if c.remainingChunkBytes == 0 {
|
||||||
|
|
||||||
var nextChunkLen uint32
|
var nextChunkLen uint32
|
||||||
err = binary.Read(c.in, ChunkHeaderByteOrder, &nextChunkLen)
|
err = binary.Read(c.in, ChunkHeaderByteOrder, &nextChunkLen)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
|
c.finishErr = err // can't handle this
|
||||||
return
|
return
|
||||||
}
|
}
|
||||||
|
|
||||||
// A chunk of len 0 indicates end of stream
|
// A chunk of len 0 indicates end of stream
|
||||||
if nextChunkLen == 0 {
|
if nextChunkLen == 0 {
|
||||||
return 0, io.EOF
|
c.finishErr = io.EOF
|
||||||
|
return 0, c.finishErr
|
||||||
}
|
}
|
||||||
|
|
||||||
c.remainingChunkBytes = nextChunkLen
|
c.remainingChunkBytes = nextChunkLen
|
||||||
@ -42,17 +49,18 @@ func (c *Unchunker) Read(b []byte) (n int, err error) {
|
|||||||
|
|
||||||
}
|
}
|
||||||
|
|
||||||
|
if c.remainingChunkBytes <= 0 {
|
||||||
|
panic("internal inconsistency: c.remainingChunkBytes must be > 0")
|
||||||
|
}
|
||||||
|
if len(b) <= 0 {
|
||||||
|
panic("cannot read into buffer of length 0")
|
||||||
|
}
|
||||||
|
|
||||||
maxRead := min(int(c.remainingChunkBytes), len(b))
|
maxRead := min(int(c.remainingChunkBytes), len(b))
|
||||||
if maxRead < 0 {
|
|
||||||
panic("Cannot read negative amount of bytes")
|
|
||||||
}
|
|
||||||
if maxRead == 0 {
|
|
||||||
return 0, nil
|
|
||||||
}
|
|
||||||
|
|
||||||
n, err = c.in.Read(b[0:maxRead])
|
n, err = c.in.Read(b[0:maxRead])
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return n, err
|
return
|
||||||
}
|
}
|
||||||
c.remainingChunkBytes -= uint32(n)
|
c.remainingChunkBytes -= uint32(n)
|
||||||
|
|
||||||
@ -71,9 +79,12 @@ func min(a, b int) int {
|
|||||||
type Chunker struct {
|
type Chunker struct {
|
||||||
ChunkCount int
|
ChunkCount int
|
||||||
in io.Reader
|
in io.Reader
|
||||||
|
inEOF bool
|
||||||
remainingChunkBytes int
|
remainingChunkBytes int
|
||||||
|
payloadBufLen int
|
||||||
payloadBuf []byte
|
payloadBuf []byte
|
||||||
headerBuf *bytes.Buffer
|
headerBuf *bytes.Buffer
|
||||||
|
finalHeaderBuffered bool
|
||||||
}
|
}
|
||||||
|
|
||||||
func NewChunker(conn io.Reader) Chunker {
|
func NewChunker(conn io.Reader) Chunker {
|
||||||
@ -87,6 +98,7 @@ func NewChunkerSized(conn io.Reader, chunkSize uint32) Chunker {
|
|||||||
return Chunker{
|
return Chunker{
|
||||||
in: conn,
|
in: conn,
|
||||||
remainingChunkBytes: 0,
|
remainingChunkBytes: 0,
|
||||||
|
payloadBufLen: 0,
|
||||||
payloadBuf: buf,
|
payloadBuf: buf,
|
||||||
headerBuf: &bytes.Buffer{},
|
headerBuf: &bytes.Buffer{},
|
||||||
}
|
}
|
||||||
@ -95,37 +107,85 @@ func NewChunkerSized(conn io.Reader, chunkSize uint32) Chunker {
|
|||||||
|
|
||||||
func (c *Chunker) Read(b []byte) (n int, err error) {
|
func (c *Chunker) Read(b []byte) (n int, err error) {
|
||||||
|
|
||||||
//fmt.Printf("chunker: c.remainingChunkBytes: %d len(b): %d\n", c.remainingChunkBytes, len(b))
|
if len(b) == 0 {
|
||||||
|
panic("unexpected empty output buffer")
|
||||||
|
}
|
||||||
|
|
||||||
|
if c.inEOF && c.finalHeaderBuffered && c.headerBuf.Len() == 0 { // all bufs empty and no more bytes to expect
|
||||||
|
return 0, io.EOF
|
||||||
|
}
|
||||||
|
|
||||||
n = 0
|
n = 0
|
||||||
if c.remainingChunkBytes == 0 {
|
|
||||||
|
if c.headerBuf.Len() > 0 { // first drain the header buf
|
||||||
|
nh, err := c.headerBuf.Read(b[n:])
|
||||||
|
if nh > 0 {
|
||||||
|
n += nh
|
||||||
|
}
|
||||||
|
if nh == 0 || (err != nil && err != io.EOF) {
|
||||||
|
panic("unexpected behavior: in-memory buffer should not throw errors")
|
||||||
|
}
|
||||||
|
if c.headerBuf.Len() != 0 {
|
||||||
|
return n, nil // finish writing the header before we can proceed with payload
|
||||||
|
}
|
||||||
|
if c.finalHeaderBuffered {
|
||||||
|
// we just wrote the final header
|
||||||
|
return n, io.EOF
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
if c.remainingChunkBytes > 0 { // then drain the payload buf
|
||||||
|
|
||||||
|
npl := copy(b[n:], c.payloadBuf[c.payloadBufLen-c.remainingChunkBytes:c.payloadBufLen])
|
||||||
|
c.remainingChunkBytes -= npl
|
||||||
|
if c.remainingChunkBytes < 0 {
|
||||||
|
panic("unexpected behavior, copy() should not copy more than max(cap(), len())")
|
||||||
|
}
|
||||||
|
n += npl
|
||||||
|
}
|
||||||
|
|
||||||
|
if c.remainingChunkBytes == 0 && !c.inEOF { // fillup bufs
|
||||||
|
|
||||||
newPayloadLen, err := c.in.Read(c.payloadBuf)
|
newPayloadLen, err := c.in.Read(c.payloadBuf)
|
||||||
|
|
||||||
if newPayloadLen == 0 {
|
if newPayloadLen > 0 {
|
||||||
return 0, io.EOF
|
c.payloadBufLen = newPayloadLen
|
||||||
} else if err != nil {
|
c.remainingChunkBytes = newPayloadLen
|
||||||
return newPayloadLen, err
|
|
||||||
}
|
}
|
||||||
|
if err == io.EOF {
|
||||||
c.remainingChunkBytes = newPayloadLen
|
c.inEOF = true
|
||||||
|
} else if err != nil {
|
||||||
// Write chunk header
|
|
||||||
c.headerBuf.Reset()
|
|
||||||
nextChunkLen := uint32(newPayloadLen)
|
|
||||||
headerLen := binary.Size(nextChunkLen)
|
|
||||||
err = binary.Write(c.headerBuf, ChunkHeaderByteOrder, nextChunkLen)
|
|
||||||
if err != nil {
|
|
||||||
return n, err
|
return n, err
|
||||||
}
|
}
|
||||||
copy(b[0:headerLen], c.headerBuf.Bytes())
|
if newPayloadLen == 0 { // apparently, this happens with some Readers
|
||||||
n += headerLen
|
c.finalHeaderBuffered = true
|
||||||
c.ChunkCount++
|
}
|
||||||
|
|
||||||
|
// Fill header buf
|
||||||
|
{
|
||||||
|
c.headerBuf.Reset()
|
||||||
|
nextChunkLen := uint32(newPayloadLen)
|
||||||
|
err := binary.Write(c.headerBuf, ChunkHeaderByteOrder, nextChunkLen)
|
||||||
|
if err != nil {
|
||||||
|
panic("unexpected error, write to in-memory buffer should not throw error")
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
if c.headerBuf.Len() == 0 {
|
||||||
|
panic("unexpected empty header buf")
|
||||||
|
}
|
||||||
|
|
||||||
|
} else if c.remainingChunkBytes == 0 && c.inEOF && !c.finalHeaderBuffered {
|
||||||
|
|
||||||
|
c.headerBuf.Reset()
|
||||||
|
err := binary.Write(c.headerBuf, ChunkHeaderByteOrder, uint32(0))
|
||||||
|
if err != nil {
|
||||||
|
panic("unexpected error, write to in-memory buffer should not throw error [2]")
|
||||||
|
}
|
||||||
|
c.finalHeaderBuffered = true
|
||||||
|
|
||||||
}
|
}
|
||||||
|
|
||||||
remainingBuf := b[n:]
|
return
|
||||||
n2 := copy(remainingBuf, c.payloadBuf[:c.remainingChunkBytes])
|
|
||||||
//fmt.Printf("chunker: written: %d\n", n+int(n2))
|
|
||||||
c.remainingChunkBytes -= n2
|
|
||||||
return n + int(n2), err
|
|
||||||
}
|
}
|
||||||
|
73
util/chunking_test.go
Normal file
73
util/chunking_test.go
Normal file
@ -0,0 +1,73 @@
|
|||||||
|
package chunking
|
||||||
|
|
||||||
|
import (
|
||||||
|
"bytes"
|
||||||
|
"encoding/binary"
|
||||||
|
"github.com/stretchr/testify/assert"
|
||||||
|
"io"
|
||||||
|
"reflect"
|
||||||
|
"testing"
|
||||||
|
"testing/quick"
|
||||||
|
)
|
||||||
|
|
||||||
|
func TestUnchunker(t *testing.T) {
|
||||||
|
|
||||||
|
buf := bytes.Buffer{}
|
||||||
|
binary.Write(&buf, ChunkHeaderByteOrder, uint32(2))
|
||||||
|
buf.WriteByte(0xca)
|
||||||
|
buf.WriteByte(0xfe)
|
||||||
|
binary.Write(&buf, ChunkHeaderByteOrder, uint32(0))
|
||||||
|
buf.WriteByte(0xff) // sentinel, should not be read
|
||||||
|
|
||||||
|
un := NewUnchunker(&buf)
|
||||||
|
|
||||||
|
recv := bytes.Buffer{}
|
||||||
|
n, err := io.Copy(&recv, un)
|
||||||
|
assert.Nil(t, err)
|
||||||
|
assert.Equal(t, int64(2), n)
|
||||||
|
assert.Equal(t, []byte{0xca, 0xfe}, recv.Bytes())
|
||||||
|
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestChunker(t *testing.T) {
|
||||||
|
|
||||||
|
buf := bytes.Buffer{}
|
||||||
|
buf.WriteByte(0xca)
|
||||||
|
buf.WriteByte(0xfe)
|
||||||
|
|
||||||
|
ch := NewChunker(&buf)
|
||||||
|
|
||||||
|
chunked := bytes.Buffer{}
|
||||||
|
n, err := io.Copy(&chunked, &ch)
|
||||||
|
assert.Nil(t, err)
|
||||||
|
assert.Equal(t, int64(4+2+4), n)
|
||||||
|
assert.Equal(t, []byte{0x2, 0x0, 0x0, 0x0, 0xca, 0xfe, 0x0, 0x0, 0x0, 0x0}, chunked.Bytes())
|
||||||
|
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestUnchunkerUnchunksChunker(t *testing.T) {
|
||||||
|
|
||||||
|
f := func(b []byte) bool {
|
||||||
|
|
||||||
|
buf := bytes.NewBuffer(b)
|
||||||
|
ch := NewChunker(buf)
|
||||||
|
unch := NewUnchunker(&ch)
|
||||||
|
var tx bytes.Buffer
|
||||||
|
_, err := io.Copy(&tx, unch)
|
||||||
|
if err != nil {
|
||||||
|
return false
|
||||||
|
}
|
||||||
|
|
||||||
|
return reflect.DeepEqual(b, tx.Bytes())
|
||||||
|
}
|
||||||
|
|
||||||
|
cfg := quick.Config{
|
||||||
|
MaxCount: 3 * int(ChunkBufSize),
|
||||||
|
MaxCountScale: 2.0,
|
||||||
|
}
|
||||||
|
|
||||||
|
if err := quick.Check(f, &cfg); err != nil {
|
||||||
|
t.Error(err)
|
||||||
|
}
|
||||||
|
|
||||||
|
}
|
Loading…
x
Reference in New Issue
Block a user