zrepl/internal/rpc/dataconn/timeoutconn/timeoutconn.go
2024-10-18 19:21:17 +02:00

211 lines
6.4 KiB
Go

// package timeoutconn wraps a Wire to provide idle timeouts
// based on Set{Read,Write}Deadline.
// Additionally, it exports abstractions for vectored I/O.
package timeoutconn
// NOTE
// Readv and Writev are not split-off into a separate package
// because we use raw syscalls, bypassing Conn's Read / Write methods.
import (
"errors"
"io"
"net"
"sync"
"syscall"
"time"
)
// Wire is the connection type wrapped by this package: a net.Conn that can
// additionally shut down its write direction through CloseWrite.
type Wire interface {
net.Conn
// CloseWrite indicates that no further Write calls will be made to Wire.
// The implementation must return an error in case of Write calls after CloseWrite.
// On the peer's side, after it has read all data written to Wire prior to the
// call to CloseWrite on our side, the peer's Read calls must return io.EOF.
// CloseWrite must not affect the read-direction of Wire: specifically, the
// peer must continue to be able to send, and our side must continue to be
// able to receive data over Wire.
//
// Note that CloseWrite may (and most likely will) return sooner than the
// peer having received all data written to Wire prior to CloseWrite.
// Note further that buffering happening in the network stacks on either side
// mandates an explicit acknowledgement from the peer that the connection may
// be fully shut down: If we call Close without such acknowledgement, any data
// from the peer to us that was already in flight may cause connection resets
// to be sent from us to the peer via the specific transport protocol. Those
// resets (e.g. RST frames) may erase all connection context on the peer,
// including data in its receive buffers. Thus, those resets are in a race with
// a) transmission of data written prior to CloseWrite and
// b) the peer application reading from those buffers.
//
// The WaitForPeerClose method (not part of this interface yet, see the TODO
// below) can be used to wait for connection termination, iff the
// implementation supports it. If it does not, the only reliable way to wait
// for a peer to have read all data from Wire (until io.EOF) is to expect it
// to close the wire at that point as well, and to drain Wire until we also
// read io.EOF.
CloseWrite() error
// The comment below documents a method that is planned but not declared:
//
// Wait for the peer to close the connection.
// No data that could otherwise be Read is lost as a consequence of this call.
// The use case for this API is abortive connection shutdown.
// To provide any value over draining Wire using io.Read, an implementation
// will likely use out-of-band messaging mechanisms.
// TODO WaitForPeerClose() (supported bool, err error)
}
// Conn wraps a Wire and renews the Wire's read / write deadlines to
// "now + idleTimeout" around I/O, until that behavior is switched off
// through DisableTimeouts.
type Conn struct {
// immutable state
Wire
// idleTimeout is the duration added to time.Now() whenever a deadline is renewed.
idleTimeout time.Duration
// mutable state (protected by mtx)
mtx sync.RWMutex
// renewDeadlinesDisabled is set by DisableTimeouts; while it is true,
// renewReadDeadline and RenewWriteDeadline do nothing.
renewDeadlinesDisabled bool
}
// Wrap returns a new Conn that performs its I/O on conn and uses idleTimeout
// as the duration by which deadlines are pushed into the future on renewal.
// Deadline renewal is enabled on the returned Conn.
func Wrap(conn Wire, idleTimeout time.Duration) *Conn {
	wrapped := new(Conn)
	wrapped.Wire = conn
	wrapped.idleTimeout = idleTimeout
	return wrapped
}
// DisableTimeouts turns off the idle timeout behavior provided by this package.
//
// Deadlines already set on the connection are cleared (by setting the zero
// time.Time as deadline) only if timeouts are not disabled yet, i.e. on the
// first call or after a previous call failed. If clearing the deadlines
// fails, that error is returned and timeouts stay enabled.
func (c *Conn) DisableTimeouts() error {
	// Exclusive lock: keeps concurrent deadline renewals out while the flag flips.
	c.mtx.Lock()
	defer c.mtx.Unlock()

	if !c.renewDeadlinesDisabled {
		if err := c.SetDeadline(time.Time{}); err != nil {
			return err
		}
		c.renewDeadlinesDisabled = true
	}
	return nil
}
// renewReadDeadline moves the read deadline to idleTimeout from now.
// It does nothing and returns nil once timeouts have been disabled.
func (c *Conn) renewReadDeadline() error {
	// Shared lock: renewals may run concurrently with each other, but not
	// with DisableTimeouts, which takes the exclusive lock.
	c.mtx.RLock()
	defer c.mtx.RUnlock()

	if !c.renewDeadlinesDisabled {
		deadline := time.Now().Add(c.idleTimeout)
		return c.SetReadDeadline(deadline)
	}
	return nil
}
// RenewWriteDeadline moves the write deadline to idleTimeout from now.
// It does nothing and returns nil once timeouts have been disabled.
func (c *Conn) RenewWriteDeadline() error {
	// Shared lock: renewals may run concurrently with each other, but not
	// with DisableTimeouts, which takes the exclusive lock.
	c.mtx.RLock()
	defer c.mtx.RUnlock()

	if !c.renewDeadlinesDisabled {
		deadline := time.Now().Add(c.idleTimeout)
		return c.SetWriteDeadline(deadline)
	}
	return nil
}
// Read reads from the wrapped Wire into p. The read deadline is renewed
// before every read attempt on the Wire.
//
// If a read attempt fails with a timeout but still transferred data, the
// connection was not idle: the deadline is renewed and reading continues into
// the remainder of p. Every other outcome (success, a non-timeout error, or a
// timeout without any progress) ends the call.
//
// Read returns the total number of bytes stored in p across all attempts and
// the error of the last attempt (or of the failed deadline renewal).
func (c *Conn) Read(p []byte) (n int, _ error) {
	for {
		if err := c.renewReadDeadline(); err != nil {
			return n, err
		}

		nCurRead, err := c.Wire.Read(p[n:])
		n += nCurRead

		// errors.As (rather than a direct type assertion) also recognizes
		// timeouts that a Wire implementation wrapped in another error.
		var netErr net.Error
		if nCurRead > 0 && errors.As(err, &netErr) && netErr.Timeout() {
			continue
		}
		return n, err
	}
}
// Write writes p to the wrapped Wire. The write deadline is renewed before
// every write attempt on the Wire.
//
// If a write attempt fails with a timeout but still transferred data, the
// connection was not idle: the deadline is renewed and writing continues with
// the remainder of p. Every other outcome (success, a non-timeout error, or a
// timeout without any progress) ends the call.
//
// Write returns the total number of bytes of p written across all attempts
// and the error of the last attempt (or of the failed deadline renewal).
func (c *Conn) Write(p []byte) (n int, _ error) {
	for {
		if err := c.RenewWriteDeadline(); err != nil {
			return n, err
		}

		nCurWrite, err := c.Wire.Write(p[n:])
		n += nCurWrite

		// errors.As (rather than a direct type assertion) also recognizes
		// timeouts that a Wire implementation wrapped in another error.
		var netErr net.Error
		if nCurWrite > 0 && errors.As(err, &netErr) && netErr.Timeout() {
			continue
		}
		return n, err
	}
}
// WritevFull writes the given buffers to Conn, following the semantics of
// io.Copy, but is guaranteed to use the writev system call if the wrapped
// Wire supports it.
// Note that Conn does not support writev through io.Copy(aConn, aNetBuffers).
//
// The write deadline is renewed before every copy attempt. If an attempt
// fails with a timeout but still transferred data, the connection was not
// idle: the deadline is renewed and copying continues. Every other outcome
// ends the call.
//
// WritevFull returns the total number of bytes written across all attempts
// and the error of the last attempt (or of the failed deadline renewal).
func (c *Conn) WritevFull(bufs net.Buffers) (n int64, _ error) {
	for {
		if err := c.RenewWriteDeadline(); err != nil {
			return n, err
		}

		// Copy to the Wire itself, not to c. net.Buffers.WriteTo consumes
		// what it has written from bufs, so a retry resumes with the
		// unwritten remainder.
		nCurWrite, err := io.Copy(c.Wire, &bufs)
		n += nCurWrite

		// errors.As (rather than a direct type assertion) also recognizes
		// timeouts that a Wire implementation wrapped in another error.
		var netErr net.Error
		if nCurWrite > 0 && errors.As(err, &netErr) && netErr.Timeout() {
			continue
		}
		return n, err
	}
}
// SyscallConnNotSupported is the sentinel error that a SyscallConner may
// return from SyscallConn if its support for SyscallConn depends on a runtime
// condition and that condition is not met.
var SyscallConnNotSupported = errors.New("SyscallConn not supported")
// SyscallConner is the interface that must be implemented for vectored I/O
// support. If the wrapped Wire does not implement it, a less efficient
// fallback implementation is used.
// Rest assured that Go's *net.TCPConn implements this interface.
type SyscallConner interface {
// SyscallConn returns the raw connection used for vectored I/O.
// The sentinel error value SyscallConnNotSupported can be returned
// if the support for SyscallConn depends on runtime conditions and
// that runtime condition is not met.
SyscallConn() (syscall.RawConn, error)
}
// Compile-time check that *net.TCPConn satisfies SyscallConner.
var _ SyscallConner = (*net.TCPConn)(nil)
// ReadvFull reads the given buffers full:
// think of io.ReadFull, but for net.Buffers and using the readv syscall.
//
// If the underlying Wire is not a SyscallConner, a fallback
// implementation based on repeated Conn.Read invocations is used.
//
// If the connection returned io.EOF, the number of bytes read until
// then + io.EOF is returned. This behavior is different from io.ReadFull,
// which returns io.ErrUnexpectedEOF.
//
// The actual work is delegated to c.readv.
func (c *Conn) ReadvFull(buffers net.Buffers) (n int64, err error) {
	n, err = c.readv(buffers)
	return n, err
}
// readvFallback is invoked by c.readv if the readv system call cannot be used.
//
// It fills the buffers one after another through repeated c.Read calls and
// returns the total number of bytes read. The first error that is not a
// "timeout with progress" aborts the operation and is returned together with
// the bytes read so far (this includes io.EOF). If all buffers were filled,
// the returned error is nil.
func (c *Conn) readvFallback(nbuffers net.Buffers) (n int64, err error) {
	for _, curBuf := range nbuffers {
		for len(curBuf) > 0 {
			// No explicit deadline renewal here: c.Read renews the read
			// deadline as its first action and reports a renewal failure
			// as (0, err), which is passed on unchanged below.
			oneN, readErr := c.Read(curBuf)
			curBuf = curBuf[oneN:]
			n += int64(oneN)
			if readErr == nil {
				continue
			}

			// A timeout that still delivered data means the connection was
			// not idle, so keep filling the current buffer. errors.As also
			// recognizes timeouts wrapped in another error.
			var netErr net.Error
			if oneN > 0 && errors.As(readErr, &netErr) && netErr.Timeout() {
				continue
			}
			return n, readErr
		}
	}
	return n, nil
}