rpc: fix data race in timeoutconn

- `timeoutconn.Conn` holds state, yet calls to Read/Write make a copy of
  that state (non-pointer receiver), so the methods they invoke operate
  on the copy and will not see the state updated

- Even without the copy issue, the renew methods can in edge cases set a
  new deadline _after_ `DisableTimeouts` has been called. Consider the
  following racy behavior:
    1. `renewReadDeadline` is called, checks `renewDeadlinesDisabled`
       (not disabled)
    2. `DisableTimeouts` is called, sets `renewDeadlinesDisabled`
    3. `DisableTimeouts` invokes `c.SetDeadline`
    4. `renewReadDeadline` invokes `c.SetReadDeadline`

To fix the above, the `Conn` receiver was made to be a pointer
everywhere and access to renewDeadlinesDisabled is now guarded
by an RWMutex instead of using atomics.

closes #415
This commit is contained in:
Mathias Fredriksson 2021-01-24 15:53:59 +02:00 committed by Christian Schwarz
parent 161ab1fee6
commit 4539ccf79b
4 changed files with 41 additions and 21 deletions

View File

@ -47,7 +47,7 @@ func (f *FrameHeader) Unmarshal(buf []byte) {
type Conn struct {
readMtx, writeMtx sync.Mutex
nc timeoutconn.Conn
nc *timeoutconn.Conn
readNextValid bool
readNext FrameHeader
nextReadErr error
@ -55,7 +55,7 @@ type Conn struct {
shutdown shutdownFSM
}
func Wrap(nc timeoutconn.Conn) *Conn {
func Wrap(nc *timeoutconn.Conn) *Conn {
return &Conn{
nc: nc,
// ncBuf: bufio.NewReadWriter(bufio.NewReaderSize(nc, 1<<23), bufio.NewWriterSize(nc, 1<<23)),

View File

@ -11,7 +11,7 @@ import (
"errors"
"io"
"net"
"sync/atomic"
"sync"
"syscall"
"time"
)
@ -54,39 +54,60 @@ type Wire interface {
}
type Conn struct {
// immutable state
Wire
renewDeadlinesDisabled int32
idleTimeout time.Duration
idleTimeout time.Duration
// mutable state (protected by mtx)
mtx sync.RWMutex
renewDeadlinesDisabled bool
}
func Wrap(conn Wire, idleTimeout time.Duration) Conn {
return Conn{Wire: conn, idleTimeout: idleTimeout}
func Wrap(conn Wire, idleTimeout time.Duration) *Conn {
return &Conn{
Wire: conn,
idleTimeout: idleTimeout,
}
}
// DisableTimeouts disables the idle timeout behavior provided by this package.
// Existing deadlines are cleared iff the call is the first call to this method.
// Existing deadlines are cleared iff the call is the first call to this method
// or if the previous call produced an error.
func (c *Conn) DisableTimeouts() error {
if atomic.CompareAndSwapInt32(&c.renewDeadlinesDisabled, 0, 1) {
return c.SetDeadline(time.Time{})
c.mtx.Lock()
defer c.mtx.Unlock()
if c.renewDeadlinesDisabled {
return nil
}
err := c.SetDeadline(time.Time{})
if err != nil {
return err
}
c.renewDeadlinesDisabled = true
return nil
}
func (c *Conn) renewReadDeadline() error {
if atomic.LoadInt32(&c.renewDeadlinesDisabled) != 0 {
c.mtx.RLock()
defer c.mtx.RUnlock()
if c.renewDeadlinesDisabled {
return nil
}
return c.SetReadDeadline(time.Now().Add(c.idleTimeout))
}
func (c *Conn) RenewWriteDeadline() error {
if atomic.LoadInt32(&c.renewDeadlinesDisabled) != 0 {
c.mtx.RLock()
defer c.mtx.RUnlock()
if c.renewDeadlinesDisabled {
return nil
}
return c.SetWriteDeadline(time.Now().Add(c.idleTimeout))
}
func (c Conn) Read(p []byte) (n int, err error) {
func (c *Conn) Read(p []byte) (n int, err error) {
n = 0
err = nil
restart:
@ -103,7 +124,7 @@ restart:
return n, err
}
func (c Conn) Write(p []byte) (n int, err error) {
func (c *Conn) Write(p []byte) (n int, err error) {
n = 0
restart:
if err := c.RenewWriteDeadline(); err != nil {
@ -123,7 +144,7 @@ restart:
// but is guaranteed to use the writev system call if the wrapped Wire
// support it.
// Note the Conn does not support writev through io.Copy(aConn, aNetBuffers).
func (c Conn) WritevFull(bufs net.Buffers) (n int64, err error) {
func (c *Conn) WritevFull(bufs net.Buffers) (n int64, err error) {
n = 0
restart:
if err := c.RenewWriteDeadline(); err != nil {
@ -163,12 +184,12 @@ var _ SyscallConner = (*net.TCPConn)(nil)
// If the connection returned io.EOF, the number of bytes written until
// then + io.EOF is returned. This behavior is different to io.ReadFull
// which returns io.ErrUnexpectedEOF.
func (c Conn) ReadvFull(buffers net.Buffers) (n int64, err error) {
func (c *Conn) ReadvFull(buffers net.Buffers) (n int64, err error) {
return c.readv(buffers)
}
// invoked by c.readv if readv system call cannot be used
func (c Conn) readvFallback(nbuffers net.Buffers) (n int64, err error) {
func (c *Conn) readvFallback(nbuffers net.Buffers) (n int64, err error) {
buffers := [][]byte(nbuffers)
for i := range buffers {
curBuf := buffers[i]

View File

@ -4,7 +4,7 @@ package timeoutconn
import "net"
func (c Conn) readv(buffers net.Buffers) (n int64, err error) {
func (c *Conn) readv(buffers net.Buffers) (n int64, err error) {
// Go does not expose the SYS_READV symbol for Solaris / Illumos - do they have it?
// Anyhow, use the fallback
return c.readvFallback(buffers)

View File

@ -29,8 +29,7 @@ func buildIovecs(buffers net.Buffers) (totalLen int64, vecs []syscall.Iovec) {
return totalLen, vecs
}
func (c Conn) readv(buffers net.Buffers) (n int64, err error) {
func (c *Conn) readv(buffers net.Buffers) (n int64, err error) {
scc, ok := c.Wire.(SyscallConner)
if !ok {
return c.readvFallback(buffers)
@ -62,7 +61,7 @@ func (c Conn) readv(buffers net.Buffers) (n int64, err error) {
return n, nil
}
func (c Conn) doOneReadv(rawConn syscall.RawConn, iovecs *[]syscall.Iovec) (n int64, err error) {
func (c *Conn) doOneReadv(rawConn syscall.RawConn, iovecs *[]syscall.Iovec) (n int64, err error) {
rawReadErr := rawConn.Read(func(fd uintptr) (done bool) {
// iovecs, n and err must not be shadowed!