mirror of
https://github.com/zrepl/zrepl.git
synced 2025-01-24 15:19:36 +01:00
254a292362
refs #184
317 lines
9.7 KiB
Go
317 lines
9.7 KiB
Go
// package timeoutconn wraps a Wire to provide idle timeouts
|
|
// based on Set{Read,Write}Deadline.
|
|
// Additionally, it exports abstractions for vectored I/O.
|
|
package timeoutconn
|
|
|
|
// NOTE
|
|
// Readv and Writev are not split-off into a separate package
|
|
// because we use raw syscalls, bypassing Conn's Read / Write methods.
|
|
|
|
import (
|
|
"errors"
|
|
"io"
|
|
"net"
|
|
"sync/atomic"
|
|
"syscall"
|
|
"time"
|
|
"unsafe"
|
|
)
|
|
|
|
// Wire is a net.Conn whose write direction can be shut down independently
// of its read direction.
type Wire interface {
	net.Conn

	// CloseWrite announces that this side will not call Write on Wire again.
	// Implementations must fail any Write that is attempted after CloseWrite.
	// Once the peer has read everything that was written to Wire before
	// CloseWrite was called, the peer's Read calls must return io.EOF.
	// The read direction of Wire must stay untouched: the peer must still be
	// able to send, and our side must still be able to receive, data over Wire.
	//
	// CloseWrite may (and most likely will) return before the peer has
	// received all data written to Wire prior to CloseWrite.
	// Moreover, because the network stacks on both sides buffer data, the
	// peer must explicitly acknowledge that the connection may be shut down
	// completely: if we call Close without that acknowledgement, data from the
	// peer that is still in flight towards us may make our side send
	// connection resets to the peer via the specific transport protocol.
	// Such resets (e.g. RST frames) may wipe all connection context on the
	// peer, including the contents of its receive buffers. The resets are
	// therefore in a race with
	//   a) the transmission of data written prior to CloseWrite and
	//   b) the peer application reading from those buffers.
	//
	// The WaitForPeerClose method can be used to wait for connection
	// termination, iff the implementation supports it. Otherwise, the only
	// reliable way to wait until the peer has read all data from Wire (up to
	// io.EOF) is to expect the peer to close the wire at that point as well,
	// and to drain Wire until we read io.EOF, too.
	CloseWrite() error

	// Wait for the peer to close the connection.
	// No data that could otherwise be Read is lost as a consequence of this call.
	// The use case for this API is abortive connection shutdown.
	// To provide any value over draining Wire using io.Read, an implementation
	// will likely use out-of-band messaging mechanisms.
	// TODO WaitForPeerClose() (supported bool, err error)
}
|
|
|
|
type Conn struct {
|
|
Wire
|
|
renewDeadlinesDisabled int32
|
|
idleTimeout time.Duration
|
|
}
|
|
|
|
func Wrap(conn Wire, idleTimeout time.Duration) Conn {
|
|
return Conn{Wire: conn, idleTimeout: idleTimeout}
|
|
}
|
|
|
|
// DisableTimeouts disables the idle timeout behavior provided by this package.
|
|
// Existing deadlines are cleared iff the call is the first call to this method.
|
|
func (c *Conn) DisableTimeouts() error {
|
|
if atomic.CompareAndSwapInt32(&c.renewDeadlinesDisabled, 0, 1) {
|
|
return c.SetDeadline(time.Time{})
|
|
}
|
|
return nil
|
|
}
|
|
|
|
func (c *Conn) renewReadDeadline() error {
|
|
if atomic.LoadInt32(&c.renewDeadlinesDisabled) != 0 {
|
|
return nil
|
|
}
|
|
return c.SetReadDeadline(time.Now().Add(c.idleTimeout))
|
|
}
|
|
|
|
func (c *Conn) RenewWriteDeadline() error {
|
|
if atomic.LoadInt32(&c.renewDeadlinesDisabled) != 0 {
|
|
return nil
|
|
}
|
|
return c.SetWriteDeadline(time.Now().Add(c.idleTimeout))
|
|
}
|
|
|
|
func (c Conn) Read(p []byte) (n int, err error) {
|
|
n = 0
|
|
err = nil
|
|
restart:
|
|
if err := c.renewReadDeadline(); err != nil {
|
|
return n, err
|
|
}
|
|
var nCurRead int
|
|
nCurRead, err = c.Wire.Read(p[n:])
|
|
n += nCurRead
|
|
if netErr, ok := err.(net.Error); ok && netErr.Timeout() && nCurRead > 0 {
|
|
err = nil
|
|
goto restart
|
|
}
|
|
return n, err
|
|
}
|
|
|
|
func (c Conn) Write(p []byte) (n int, err error) {
|
|
n = 0
|
|
restart:
|
|
if err := c.RenewWriteDeadline(); err != nil {
|
|
return n, err
|
|
}
|
|
var nCurWrite int
|
|
nCurWrite, err = c.Wire.Write(p[n:])
|
|
n += nCurWrite
|
|
if netErr, ok := err.(net.Error); ok && netErr.Timeout() && nCurWrite > 0 {
|
|
err = nil
|
|
goto restart
|
|
}
|
|
return n, err
|
|
}
|
|
|
|
// Writes the given buffers to Conn, following the sematincs of io.Copy,
|
|
// but is guaranteed to use the writev system call if the wrapped Wire
|
|
// support it.
|
|
// Note the Conn does not support writev through io.Copy(aConn, aNetBuffers).
|
|
func (c Conn) WritevFull(bufs net.Buffers) (n int64, err error) {
|
|
n = 0
|
|
restart:
|
|
if err := c.RenewWriteDeadline(); err != nil {
|
|
return n, err
|
|
}
|
|
var nCurWrite int64
|
|
nCurWrite, err = io.Copy(c.Wire, &bufs)
|
|
n += nCurWrite
|
|
if netErr, ok := err.(net.Error); ok && netErr.Timeout() && nCurWrite > 0 {
|
|
err = nil
|
|
goto restart
|
|
}
|
|
return n, err
|
|
}
|
|
|
|
// SyscallConnNotSupported is the sentinel error that a SyscallConner may
// return from SyscallConn. ReadvFull reacts to it by using its fallback
// implementation instead of failing.
var SyscallConnNotSupported = errors.New("SyscallConn not supported")
|
|
|
|
// SyscallConner is the interface a Wire must implement to get vectored I/O
// support. For a Wire that does not implement it, a less efficient fallback
// implementation is used instead.
// Rest assured that Go's *net.TCPConn implements this interface.
type SyscallConner interface {
	// SyscallConn may return the sentinel error SyscallConnNotSupported if
	// support for SyscallConn depends on a runtime condition and that
	// condition is not met.
	SyscallConn() (syscall.RawConn, error)
}

// Compile-time check that *net.TCPConn is a SyscallConner.
var _ SyscallConner = (*net.TCPConn)(nil)
|
|
|
|
// buildIovecs converts buffers into iovecs for use with the readv syscall.
//
// totalLen is the sum of the lengths of all buffers. vecs holds one entry
// per non-empty buffer, in order, pointing at the buffer's first byte;
// empty buffers are skipped.
func buildIovecs(buffers net.Buffers) (totalLen int64, vecs []syscall.Iovec) {
	vecs = make([]syscall.Iovec, 0, len(buffers))
	for _, buf := range buffers {
		size := len(buf)
		if size == 0 {
			continue
		}
		totalLen += int64(size)

		var vec syscall.Iovec
		vec.Base = &buf[0]
		// syscall.Iovec.Len has platform-dependent size, thus use SetLen
		vec.SetLen(size)
		vecs = append(vecs, vec)
	}
	return totalLen, vecs
}
|
|
|
|
// Reads the given buffers full:
|
|
// Think of io.ReadvFull, but for net.Buffers + using the readv syscall.
|
|
//
|
|
// If the underlying Wire is not a SyscallConner, a fallback
|
|
// ipmlementation based on repeated Conn.Read invocations is used.
|
|
//
|
|
// If the connection returned io.EOF, the number of bytes up ritten until
|
|
// then + io.EOF is returned. This behavior is different to io.ReadFull
|
|
// which returns io.ErrUnexpectedEOF.
|
|
func (c Conn) ReadvFull(buffers net.Buffers) (n int64, err error) {
|
|
totalLen, iovecs := buildIovecs(buffers)
|
|
if debugReadvNoShortReadsAssertEnable {
|
|
defer debugReadvNoShortReadsAssert(totalLen, n, err)
|
|
}
|
|
scc, ok := c.Wire.(SyscallConner)
|
|
if !ok {
|
|
return c.readvFallback(buffers)
|
|
}
|
|
raw, err := scc.SyscallConn()
|
|
if err == SyscallConnNotSupported {
|
|
return c.readvFallback(buffers)
|
|
}
|
|
if err != nil {
|
|
return 0, err
|
|
}
|
|
n, err = c.readv(raw, iovecs)
|
|
return
|
|
}
|
|
|
|
func (c Conn) readvFallback(nbuffers net.Buffers) (n int64, err error) {
|
|
buffers := [][]byte(nbuffers)
|
|
for i := range buffers {
|
|
curBuf := buffers[i]
|
|
inner:
|
|
for len(curBuf) > 0 {
|
|
if err := c.renewReadDeadline(); err != nil {
|
|
return n, err
|
|
}
|
|
var oneN int
|
|
oneN, err = c.Read(curBuf[:]) // WE WANT NO SHADOWING
|
|
curBuf = curBuf[oneN:]
|
|
n += int64(oneN)
|
|
if err != nil {
|
|
if netErr, ok := err.(net.Error); ok && netErr.Timeout() && oneN > 0 {
|
|
continue inner
|
|
}
|
|
return n, err
|
|
}
|
|
}
|
|
}
|
|
return n, nil
|
|
}
|
|
|
|
func (c Conn) readv(rawConn syscall.RawConn, iovecs []syscall.Iovec) (n int64, err error) {
|
|
for len(iovecs) > 0 {
|
|
if err := c.renewReadDeadline(); err != nil {
|
|
return n, err
|
|
}
|
|
oneN, oneErr := c.doOneReadv(rawConn, &iovecs)
|
|
n += oneN
|
|
if netErr, ok := oneErr.(net.Error); ok && netErr.Timeout() && oneN > 0 { // TODO likely not working
|
|
continue
|
|
} else if oneErr == nil && oneN > 0 {
|
|
continue
|
|
} else {
|
|
return n, oneErr
|
|
}
|
|
}
|
|
return n, nil
|
|
}
|
|
|
|
func (c Conn) doOneReadv(rawConn syscall.RawConn, iovecs *[]syscall.Iovec) (n int64, err error) {
|
|
rawReadErr := rawConn.Read(func(fd uintptr) (done bool) {
|
|
// iovecs, n and err must not be shadowed!
|
|
|
|
// NOTE: unsafe.Pointer safety rules
|
|
// https://tip.golang.org/pkg/unsafe/#Pointer
|
|
//
|
|
// (4) Conversion of a Pointer to a uintptr when calling syscall.Syscall.
|
|
// ...
|
|
// uintptr() conversions must appear within the syscall.Syscall argument list.
|
|
// (even though we are not the escape analysis Likely not )
|
|
thisReadN, _, errno := syscall.Syscall(
|
|
syscall.SYS_READV,
|
|
fd,
|
|
uintptr(unsafe.Pointer(&(*iovecs)[0])),
|
|
uintptr(len(*iovecs)),
|
|
)
|
|
if thisReadN == ^uintptr(0) {
|
|
if errno == syscall.EAGAIN {
|
|
return false
|
|
}
|
|
err = syscall.Errno(errno)
|
|
return true
|
|
}
|
|
if int(thisReadN) < 0 {
|
|
panic("unexpected return value")
|
|
}
|
|
n += int64(thisReadN) // TODO check overflow
|
|
|
|
// shift iovecs forward
|
|
for left := int(thisReadN); left > 0; {
|
|
// conversion to uint does not change value, see TestIovecLenFieldIsMachineUint, and left > 0
|
|
thisIovecConsumedCompletely := uint((*iovecs)[0].Len) <= uint(left)
|
|
if thisIovecConsumedCompletely {
|
|
// Update left, cannot go below 0 due to
|
|
// a) definition of thisIovecConsumedCompletely
|
|
// b) left > 0 due to loop invariant
|
|
// Convertion .Len to int64 is thus also safe now, because it is < left < INT_MAX
|
|
left -= int((*iovecs)[0].Len)
|
|
*iovecs = (*iovecs)[1:]
|
|
} else {
|
|
// trim this iovec to remaining length
|
|
|
|
// NOTE: unsafe.Pointer safety rules
|
|
// https://tip.golang.org/pkg/unsafe/#Pointer
|
|
// (3) Conversion of a Pointer to a uintptr and back, with arithmetic.
|
|
// ...
|
|
// Note that both conversions must appear in the same expression,
|
|
// with only the intervening arithmetic between them:
|
|
(*iovecs)[0].Base = (*byte)(unsafe.Pointer(uintptr(unsafe.Pointer((*iovecs)[0].Base)) + uintptr(left)))
|
|
curVecNewLength := uint((*iovecs)[0].Len) - uint(left) // casts to uint do not change value
|
|
(*iovecs)[0].SetLen(int(curVecNewLength)) // int and uint have the same size, no change of value
|
|
|
|
break // inner
|
|
}
|
|
}
|
|
if thisReadN == 0 {
|
|
err = io.EOF
|
|
return true
|
|
}
|
|
return true
|
|
})
|
|
|
|
if rawReadErr != nil {
|
|
err = rawReadErr
|
|
}
|
|
|
|
return n, err
|
|
}
|