zrepl/rpc/dataconn/timeoutconn/internal/wireevaluator/wireevaluator_closewrite.go

108 lines
2.2 KiB
Go
Raw Permalink Normal View History

package main
import (
"bytes"
"io"
"io/ioutil"
"log"
"github.com/zrepl/zrepl/transport"
)
// CloseWriteMode selects which end of the connection plays the receiver role,
// i.e. the end that replies with an error message and then calls CloseWrite.
// Valid modes start at 1, so the zero value is not a valid mode.
type CloseWriteMode uint

const (
	// CloseWriteClientSide: the client is the receiver, the server the sender.
	CloseWriteClientSide CloseWriteMode = iota + 1
	// CloseWriteServerSide: the server is the receiver, the client the sender.
	CloseWriteServerSide
)

// CloseWrite is a wire evaluation scenario that exercises half-closing a
// connection (CloseWrite) while the peer is still writing.
type CloseWrite struct{ mode CloseWriteMode }
var (
	// closeWriteTestSendData is the payload the sender writes over and over:
	// the byte pair 0x23 0x42 repeated 1<<24 times (32 MiB per Write).
	closeWriteTestSendData = bytes.Repeat([]byte{0x23, 0x42}, 1<<24)

	// closeWriteErrorMsg is the reply the receiver sends before calling
	// CloseWrite; the sender checks that it receives exactly these bytes.
	closeWriteErrorMsg = []byte{0x0b, 0x0a, 0x0d, 0x0f, 0x00, 0x00, 0x0d}
)
// Client runs the client end of the scenario. In CloseWriteClientSide mode the
// client is the receiver; in CloseWriteServerSide mode it is the sender.
// Any other mode value panics.
func (m CloseWrite) Client(wire transport.Wire) {
	var role func(transport.Wire)
	switch m.mode {
	case CloseWriteClientSide:
		role = m.receiver
	case CloseWriteServerSide:
		role = m.sender
	default:
		panic(m.mode)
	}
	role(wire)
}
// Server runs the server end of the scenario, always taking the role opposite
// to Client: sender in CloseWriteClientSide mode, receiver in
// CloseWriteServerSide mode. Any other mode value panics.
func (m CloseWrite) Server(wire transport.Wire) {
	if m.mode == CloseWriteClientSide {
		m.sender(wire)
		return
	}
	if m.mode == CloseWriteServerSide {
		m.receiver(wire)
		return
	}
	panic(m.mode)
}
// sender writes closeWriteTestSendData to wire in an endless loop from a
// background goroutine. Meanwhile it reads the peer's reply on the calling
// goroutine until EOF. The reply must be exactly closeWriteErrorMsg; a read
// error or any other reply terminates the process via log.Fatalf.
//
// On return, wire is closed (the close error is logged), and sender waits for
// the writer goroutine to exit, so the goroutine does not outlive this call.
func (CloseWrite) sender(wire transport.Wire) {
	writerExited := make(chan struct{})
	go func() {
		// Signal on exit, not on start, so the deferred wait below really
		// joins this goroutine.
		defer close(writerExited)
		for {
			if _, err := wire.Write(closeWriteTestSendData); err != nil {
				return
			}
		}
	}()

	defer func() {
		// Close before waiting: the writer loop only ends once Write fails.
		// NOTE(review): assumes transport.Wire follows net.Conn semantics,
		// where Close makes pending and future Writes return an error.
		closeErr := wire.Close()
		log.Printf("closeErr=%T %s", closeErr, closeErr)
		<-writerExited
	}()

	var respBuf bytes.Buffer
	if _, err := io.Copy(&respBuf, wire); err != nil {
		log.Fatalf("should have received io.EOF, which is masked by io.Copy, got: %s", err)
	}
	if !bytes.Equal(respBuf.Bytes(), closeWriteErrorMsg) {
		log.Fatalf("did not receive error message, got response with len %v:\n%v", respBuf.Len(), respBuf.Bytes())
	}
}
func (CloseWrite) receiver(wire transport.Wire) {
// consume half the test data, then detect an error, send it and CloseWrite
2019-03-22 19:41:12 +01:00
r := io.LimitReader(wire, int64(5*len(closeWriteTestSendData)/3))
_, err := io.Copy(ioutil.Discard, r)
noerror(err)
var errBuf bytes.Buffer
errBuf.Write(closeWriteErrorMsg)
_, err = io.Copy(wire, &errBuf)
noerror(err)
err = wire.CloseWrite()
noerror(err)
// drain wire, as documented in transport.Wire, this is the only way we know the client closed the conn
_, err = io.Copy(ioutil.Discard, wire)
if err != nil {
// io.Copy masks io.EOF to nil, and we expect io.EOF from the client's Close() call
log.Panicf("unexpected error returned from reading conn: %s", err)
}
2019-03-22 19:41:12 +01:00
closeErr := wire.Close()
log.Printf("closeErr=%T %s", closeErr, closeErr)
}