chunking: rewrite to handle EOF events correctly

bonus: some tests asserting the chunking protocol is adhered to
This commit is contained in:
Christian Schwarz 2017-05-06 23:41:51 +02:00
parent 9b871fb7c0
commit 61c263b91d
2 changed files with 164 additions and 31 deletions

View File

@ -13,6 +13,7 @@ type Unchunker struct {
ChunkCount int
in io.Reader
remainingChunkBytes uint32
finishErr error
}
func NewUnchunker(conn io.Reader) *Unchunker {
@ -24,17 +25,23 @@ func NewUnchunker(conn io.Reader) *Unchunker {
func (c *Unchunker) Read(b []byte) (n int, err error) {
if c.finishErr != nil {
return 0, err
}
if c.remainingChunkBytes == 0 {
var nextChunkLen uint32
err = binary.Read(c.in, ChunkHeaderByteOrder, &nextChunkLen)
if err != nil {
c.finishErr = err // can't handle this
return
}
// A chunk of len 0 indicates end of stream
if nextChunkLen == 0 {
return 0, io.EOF
c.finishErr = io.EOF
return 0, c.finishErr
}
c.remainingChunkBytes = nextChunkLen
@ -42,17 +49,18 @@ func (c *Unchunker) Read(b []byte) (n int, err error) {
}
if c.remainingChunkBytes <= 0 {
panic("internal inconsistency: c.remainingChunkBytes must be > 0")
}
if len(b) <= 0 {
panic("cannot read into buffer of length 0")
}
maxRead := min(int(c.remainingChunkBytes), len(b))
if maxRead < 0 {
panic("Cannot read negative amount of bytes")
}
if maxRead == 0 {
return 0, nil
}
n, err = c.in.Read(b[0:maxRead])
if err != nil {
return n, err
return
}
c.remainingChunkBytes -= uint32(n)
@ -71,9 +79,12 @@ func min(a, b int) int {
// Chunker wraps an io.Reader and, through its Read method, emits the wrapped
// data as a sequence of chunks. Each chunk is a uint32 length header encoded
// with ChunkHeaderByteOrder followed by the payload bytes; a header carrying
// length 0 is emitted last to mark the end of the stream.
type Chunker struct {
// NOTE(review): presumably the number of chunks produced so far; confirm
// where it is incremented in the post-commit version of Read.
ChunkCount int
// in is the underlying reader that payload bytes are pulled from.
in io.Reader
// inEOF is set once in.Read has reported io.EOF; no further reads are
// issued against in after that.
inEOF bool
// remainingChunkBytes counts the bytes of the current payload that have
// not yet been copied into a caller's buffer.
remainingChunkBytes int
// payloadBufLen is the number of valid bytes in payloadBuf, i.e. the size
// of the most recent non-empty read from in.
payloadBufLen int
// payloadBuf is the scratch buffer that in.Read fills with chunk payload.
payloadBuf []byte
// headerBuf holds the encoded header of the current chunk until it has
// been fully handed out to the caller.
headerBuf *bytes.Buffer
// finalHeaderBuffered is set once the terminating zero-length header has
// been placed in headerBuf; after it is drained Read returns io.EOF.
finalHeaderBuffered bool
}
func NewChunker(conn io.Reader) Chunker {
@ -87,6 +98,7 @@ func NewChunkerSized(conn io.Reader, chunkSize uint32) Chunker {
return Chunker{
in: conn,
remainingChunkBytes: 0,
payloadBufLen: 0,
payloadBuf: buf,
headerBuf: &bytes.Buffer{},
}
@ -95,37 +107,85 @@ func NewChunkerSized(conn io.Reader, chunkSize uint32) Chunker {
func (c *Chunker) Read(b []byte) (n int, err error) {
//fmt.Printf("chunker: c.remainingChunkBytes: %d len(b): %d\n", c.remainingChunkBytes, len(b))
if len(b) == 0 {
panic("unexpected empty output buffer")
}
if c.inEOF && c.finalHeaderBuffered && c.headerBuf.Len() == 0 { // all bufs empty and no more bytes to expect
return 0, io.EOF
}
n = 0
if c.remainingChunkBytes == 0 {
if c.headerBuf.Len() > 0 { // first drain the header buf
nh, err := c.headerBuf.Read(b[n:])
if nh > 0 {
n += nh
}
if nh == 0 || (err != nil && err != io.EOF) {
panic("unexpected behavior: in-memory buffer should not throw errors")
}
if c.headerBuf.Len() != 0 {
return n, nil // finish writing the header before we can proceed with payload
}
if c.finalHeaderBuffered {
// we just wrote the final header
return n, io.EOF
}
}
if c.remainingChunkBytes > 0 { // then drain the payload buf
npl := copy(b[n:], c.payloadBuf[c.payloadBufLen-c.remainingChunkBytes:c.payloadBufLen])
c.remainingChunkBytes -= npl
if c.remainingChunkBytes < 0 {
panic("unexpected behavior, copy() should not copy more than max(cap(), len())")
}
n += npl
}
if c.remainingChunkBytes == 0 && !c.inEOF { // fillup bufs
newPayloadLen, err := c.in.Read(c.payloadBuf)
if newPayloadLen == 0 {
return 0, io.EOF
} else if err != nil {
return newPayloadLen, err
if newPayloadLen > 0 {
c.payloadBufLen = newPayloadLen
c.remainingChunkBytes = newPayloadLen
}
c.remainingChunkBytes = newPayloadLen
// Write chunk header
c.headerBuf.Reset()
nextChunkLen := uint32(newPayloadLen)
headerLen := binary.Size(nextChunkLen)
err = binary.Write(c.headerBuf, ChunkHeaderByteOrder, nextChunkLen)
if err != nil {
if err == io.EOF {
c.inEOF = true
} else if err != nil {
return n, err
}
copy(b[0:headerLen], c.headerBuf.Bytes())
n += headerLen
c.ChunkCount++
if newPayloadLen == 0 { // apparently, this happens with some Readers
c.finalHeaderBuffered = true
}
// Fill header buf
{
c.headerBuf.Reset()
nextChunkLen := uint32(newPayloadLen)
err := binary.Write(c.headerBuf, ChunkHeaderByteOrder, nextChunkLen)
if err != nil {
panic("unexpected error, write to in-memory buffer should not throw error")
}
}
if c.headerBuf.Len() == 0 {
panic("unexpected empty header buf")
}
} else if c.remainingChunkBytes == 0 && c.inEOF && !c.finalHeaderBuffered {
c.headerBuf.Reset()
err := binary.Write(c.headerBuf, ChunkHeaderByteOrder, uint32(0))
if err != nil {
panic("unexpected error, write to in-memory buffer should not throw error [2]")
}
c.finalHeaderBuffered = true
}
remainingBuf := b[n:]
n2 := copy(remainingBuf, c.payloadBuf[:c.remainingChunkBytes])
//fmt.Printf("chunker: written: %d\n", n+int(n2))
c.remainingChunkBytes -= n2
return n + int(n2), err
return
}

73
util/chunking_test.go Normal file
View File

@ -0,0 +1,73 @@
package chunking
import (
"bytes"
"encoding/binary"
"github.com/stretchr/testify/assert"
"io"
"reflect"
"testing"
"testing/quick"
)
func TestUnchunker(t *testing.T) {
buf := bytes.Buffer{}
binary.Write(&buf, ChunkHeaderByteOrder, uint32(2))
buf.WriteByte(0xca)
buf.WriteByte(0xfe)
binary.Write(&buf, ChunkHeaderByteOrder, uint32(0))
buf.WriteByte(0xff) // sentinel, should not be read
un := NewUnchunker(&buf)
recv := bytes.Buffer{}
n, err := io.Copy(&recv, un)
assert.Nil(t, err)
assert.Equal(t, int64(2), n)
assert.Equal(t, []byte{0xca, 0xfe}, recv.Bytes())
}
// TestChunker runs a two-byte payload through a Chunker and checks the exact
// bytes that come out: a 4-byte header carrying the payload length, the
// payload itself, and a trailing 4-byte all-zero header.
func TestChunker(t *testing.T) {
	var src bytes.Buffer
	src.Write([]byte{0xca, 0xfe})

	chunker := NewChunker(&src)

	var wire bytes.Buffer
	written, copyErr := io.Copy(&wire, &chunker)

	assert.Nil(t, copyErr)
	assert.Equal(t, int64(4+2+4), written)

	want := []byte{
		0x2, 0x0, 0x0, 0x0, // header of the only data chunk
		0xca, 0xfe, // payload
		0x0, 0x0, 0x0, 0x0, // terminating header
	}
	assert.Equal(t, want, wire.Bytes())
}
// TestUnchunkerUnchunksChunker is a property test: for byte slices generated
// by testing/quick, piping the data through a Chunker and then an Unchunker
// must succeed and reproduce the input.
func TestUnchunkerUnchunksChunker(t *testing.T) {
	roundTrips := func(payload []byte) bool {
		chunker := NewChunker(bytes.NewBuffer(payload))
		unchunker := NewUnchunker(&chunker)

		var out bytes.Buffer
		if _, err := io.Copy(&out, unchunker); err != nil {
			return false
		}
		return reflect.DeepEqual(payload, out.Bytes())
	}

	cfg := &quick.Config{
		MaxCount:      3 * int(ChunkBufSize),
		MaxCountScale: 2.0,
	}
	if err := quick.Check(roundTrips, cfg); err != nil {
		t.Error(err)
	}
}