rpc: fix data race in timeoutconn

- `timeoutconn` handles state, yet calls to Read/Write make a copy of
  that state (non-pointer receiver), so any outbound calls will not have
  the state updated.

- Even without the copy issue, the renew methods can in edge cases set a
  new deadline _after_ `DisableTimeouts` has been called. Consider the
  following racy behavior:
    1. `renewReadDeadline` is called, checks `renewDeadlinesDisabled`
       (not disabled)
    2. `DisableTimeouts` is called, sets `renewDeadlinesDisabled`
    3. `DisableTimeouts` invokes `c.SetDeadline`
    4. `renewReadDeadline` invokes `c.SetReadDeadline`

To fix the above, the `Conn` receiver was made to be a pointer
everywhere and access to renewDeadlinesDisabled is now guarded
by an RWMutex instead of using atomics.

closes #415
This commit is contained in:
Mathias Fredriksson 2021-01-24 15:53:59 +02:00 committed by Christian Schwarz
parent 48be4032a2
commit d118bcc717
4 changed files with 41 additions and 21 deletions

View File

@ -47,7 +47,7 @@ func (f *FrameHeader) Unmarshal(buf []byte) {
type Conn struct { type Conn struct {
readMtx, writeMtx sync.Mutex readMtx, writeMtx sync.Mutex
nc timeoutconn.Conn nc *timeoutconn.Conn
readNextValid bool readNextValid bool
readNext FrameHeader readNext FrameHeader
nextReadErr error nextReadErr error
@ -55,7 +55,7 @@ type Conn struct {
shutdown shutdownFSM shutdown shutdownFSM
} }
func Wrap(nc timeoutconn.Conn) *Conn { func Wrap(nc *timeoutconn.Conn) *Conn {
return &Conn{ return &Conn{
nc: nc, nc: nc,
// ncBuf: bufio.NewReadWriter(bufio.NewReaderSize(nc, 1<<23), bufio.NewWriterSize(nc, 1<<23)), // ncBuf: bufio.NewReadWriter(bufio.NewReaderSize(nc, 1<<23), bufio.NewWriterSize(nc, 1<<23)),

View File

@ -11,7 +11,7 @@ import (
"errors" "errors"
"io" "io"
"net" "net"
"sync/atomic" "sync"
"syscall" "syscall"
"time" "time"
) )
@ -54,39 +54,60 @@ type Wire interface {
} }
type Conn struct { type Conn struct {
// immutable state
Wire Wire
renewDeadlinesDisabled int32 idleTimeout time.Duration
idleTimeout time.Duration
// mutable state (protected by mtx)
mtx sync.RWMutex
renewDeadlinesDisabled bool
} }
func Wrap(conn Wire, idleTimeout time.Duration) Conn { func Wrap(conn Wire, idleTimeout time.Duration) *Conn {
return Conn{Wire: conn, idleTimeout: idleTimeout} return &Conn{
Wire: conn,
idleTimeout: idleTimeout,
}
} }
// DisableTimeouts disables the idle timeout behavior provided by this package. // DisableTimeouts disables the idle timeout behavior provided by this package.
// Existing deadlines are cleared iff the call is the first call to this method. // Existing deadlines are cleared iff the call is the first call to this method
// or if the previous call produced an error.
func (c *Conn) DisableTimeouts() error { func (c *Conn) DisableTimeouts() error {
if atomic.CompareAndSwapInt32(&c.renewDeadlinesDisabled, 0, 1) { c.mtx.Lock()
return c.SetDeadline(time.Time{}) defer c.mtx.Unlock()
if c.renewDeadlinesDisabled {
return nil
} }
err := c.SetDeadline(time.Time{})
if err != nil {
return err
}
c.renewDeadlinesDisabled = true
return nil return nil
} }
func (c *Conn) renewReadDeadline() error { func (c *Conn) renewReadDeadline() error {
if atomic.LoadInt32(&c.renewDeadlinesDisabled) != 0 { c.mtx.RLock()
defer c.mtx.RUnlock()
if c.renewDeadlinesDisabled {
return nil return nil
} }
return c.SetReadDeadline(time.Now().Add(c.idleTimeout)) return c.SetReadDeadline(time.Now().Add(c.idleTimeout))
} }
func (c *Conn) RenewWriteDeadline() error { func (c *Conn) RenewWriteDeadline() error {
if atomic.LoadInt32(&c.renewDeadlinesDisabled) != 0 { c.mtx.RLock()
defer c.mtx.RUnlock()
if c.renewDeadlinesDisabled {
return nil return nil
} }
return c.SetWriteDeadline(time.Now().Add(c.idleTimeout)) return c.SetWriteDeadline(time.Now().Add(c.idleTimeout))
} }
func (c Conn) Read(p []byte) (n int, err error) { func (c *Conn) Read(p []byte) (n int, err error) {
n = 0 n = 0
err = nil err = nil
restart: restart:
@ -103,7 +124,7 @@ restart:
return n, err return n, err
} }
func (c Conn) Write(p []byte) (n int, err error) { func (c *Conn) Write(p []byte) (n int, err error) {
n = 0 n = 0
restart: restart:
if err := c.RenewWriteDeadline(); err != nil { if err := c.RenewWriteDeadline(); err != nil {
@ -123,7 +144,7 @@ restart:
// but is guaranteed to use the writev system call if the wrapped Wire // but is guaranteed to use the writev system call if the wrapped Wire
// support it. // support it.
// Note the Conn does not support writev through io.Copy(aConn, aNetBuffers). // Note the Conn does not support writev through io.Copy(aConn, aNetBuffers).
func (c Conn) WritevFull(bufs net.Buffers) (n int64, err error) { func (c *Conn) WritevFull(bufs net.Buffers) (n int64, err error) {
n = 0 n = 0
restart: restart:
if err := c.RenewWriteDeadline(); err != nil { if err := c.RenewWriteDeadline(); err != nil {
@ -163,12 +184,12 @@ var _ SyscallConner = (*net.TCPConn)(nil)
// If the connection returned io.EOF, the number of bytes written until // If the connection returned io.EOF, the number of bytes written until
// then + io.EOF is returned. This behavior is different to io.ReadFull // then + io.EOF is returned. This behavior is different to io.ReadFull
// which returns io.ErrUnexpectedEOF. // which returns io.ErrUnexpectedEOF.
func (c Conn) ReadvFull(buffers net.Buffers) (n int64, err error) { func (c *Conn) ReadvFull(buffers net.Buffers) (n int64, err error) {
return c.readv(buffers) return c.readv(buffers)
} }
// invoked by c.readv if readv system call cannot be used // invoked by c.readv if readv system call cannot be used
func (c Conn) readvFallback(nbuffers net.Buffers) (n int64, err error) { func (c *Conn) readvFallback(nbuffers net.Buffers) (n int64, err error) {
buffers := [][]byte(nbuffers) buffers := [][]byte(nbuffers)
for i := range buffers { for i := range buffers {
curBuf := buffers[i] curBuf := buffers[i]

View File

@ -4,7 +4,7 @@ package timeoutconn
import "net" import "net"
func (c Conn) readv(buffers net.Buffers) (n int64, err error) { func (c *Conn) readv(buffers net.Buffers) (n int64, err error) {
// Go does not expose the SYS_READV symbol for Solaris / Illumos - do they have it? // Go does not expose the SYS_READV symbol for Solaris / Illumos - do they have it?
// Anyhow, use the fallback // Anyhow, use the fallback
return c.readvFallback(buffers) return c.readvFallback(buffers)

View File

@ -29,8 +29,7 @@ func buildIovecs(buffers net.Buffers) (totalLen int64, vecs []syscall.Iovec) {
return totalLen, vecs return totalLen, vecs
} }
func (c Conn) readv(buffers net.Buffers) (n int64, err error) { func (c *Conn) readv(buffers net.Buffers) (n int64, err error) {
scc, ok := c.Wire.(SyscallConner) scc, ok := c.Wire.(SyscallConner)
if !ok { if !ok {
return c.readvFallback(buffers) return c.readvFallback(buffers)
@ -62,7 +61,7 @@ func (c Conn) readv(buffers net.Buffers) (n int64, err error) {
return n, nil return n, nil
} }
func (c Conn) doOneReadv(rawConn syscall.RawConn, iovecs *[]syscall.Iovec) (n int64, err error) { func (c *Conn) doOneReadv(rawConn syscall.RawConn, iovecs *[]syscall.Iovec) (n int64, err error) {
rawReadErr := rawConn.Read(func(fd uintptr) (done bool) { rawReadErr := rawConn.Read(func(fd uintptr) (done bool) {
// iovecs, n and err must not be shadowed! // iovecs, n and err must not be shadowed!