From 4539ccf79b4304391520ec44600639dbc558d9b6 Mon Sep 17 00:00:00 2001 From: Mathias Fredriksson Date: Sun, 24 Jan 2021 15:53:59 +0200 Subject: [PATCH] rpc: fix data race in timeoutconn - `timeoutconn` handles state, yet calls to Read/Write make a copy of that state (non-pointer receiver) so any outbound calls will not have the state updated - Even without the copy issue, the renew methods can in edge cases set a new deadline _after_ DisableTimeouts have been called, consider the following racy behavior: 1. `renewReadDeadline` is called, checks `renewDeadlinesDisabled` (not disabled) 2. `DisableTimeouts` is called, sets `renewDeadlinesDisabled` 3. `DisableTimeouts` invokes `c.SetDeadline` 4. `renewReadDeadline` invokes `c.SetReadDeadline` To fix the above, the `Conn` receiver was made to be a pointer everywhere and access to renewDeadlinesDisabled is now guarded by an RWMutex instead of using atomics. closes #415 --- rpc/dataconn/frameconn/frameconn.go | 4 +- rpc/dataconn/timeoutconn/timeoutconn.go | 51 +++++++++++++------ .../timeoutconn_readv_unsupported.go | 2 +- rpc/dataconn/timeoutconn/timoutconn_readv.go | 5 +- 4 files changed, 41 insertions(+), 21 deletions(-) diff --git a/rpc/dataconn/frameconn/frameconn.go b/rpc/dataconn/frameconn/frameconn.go index 281b7d1..ff26eeb 100644 --- a/rpc/dataconn/frameconn/frameconn.go +++ b/rpc/dataconn/frameconn/frameconn.go @@ -47,7 +47,7 @@ func (f *FrameHeader) Unmarshal(buf []byte) { type Conn struct { readMtx, writeMtx sync.Mutex - nc timeoutconn.Conn + nc *timeoutconn.Conn readNextValid bool readNext FrameHeader nextReadErr error @@ -55,7 +55,7 @@ type Conn struct { shutdown shutdownFSM } -func Wrap(nc timeoutconn.Conn) *Conn { +func Wrap(nc *timeoutconn.Conn) *Conn { return &Conn{ nc: nc, // ncBuf: bufio.NewReadWriter(bufio.NewReaderSize(nc, 1<<23), bufio.NewWriterSize(nc, 1<<23)), diff --git a/rpc/dataconn/timeoutconn/timeoutconn.go b/rpc/dataconn/timeoutconn/timeoutconn.go index ca7a979..1de9ede 100644 --- a/rpc/dataconn/timeoutconn/timeoutconn.go +++ b/rpc/dataconn/timeoutconn/timeoutconn.go @@ -11,7 +11,7 @@ import ( "errors" "io" "net" - "sync/atomic" + "sync" "syscall" "time" ) @@ -54,39 +54,60 @@ type Wire interface { } type Conn struct { + // immutable state + Wire - renewDeadlinesDisabled int32 - idleTimeout time.Duration + idleTimeout time.Duration + + // mutable state (protected by mtx) + + mtx sync.RWMutex + renewDeadlinesDisabled bool } -func Wrap(conn Wire, idleTimeout time.Duration) Conn { - return Conn{Wire: conn, idleTimeout: idleTimeout} +func Wrap(conn Wire, idleTimeout time.Duration) *Conn { + return &Conn{ + Wire: conn, + idleTimeout: idleTimeout, + } } // DisableTimeouts disables the idle timeout behavior provided by this package. -// Existing deadlines are cleared iff the call is the first call to this method. +// Existing deadlines are cleared iff the call is the first call to this method +// or if the previous call produced an error. func (c *Conn) DisableTimeouts() error { - if atomic.CompareAndSwapInt32(&c.renewDeadlinesDisabled, 0, 1) { - return c.SetDeadline(time.Time{}) + c.mtx.Lock() + defer c.mtx.Unlock() + if c.renewDeadlinesDisabled { + return nil } + err := c.SetDeadline(time.Time{}) + if err != nil { + return err + } + c.renewDeadlinesDisabled = true return nil } func (c *Conn) renewReadDeadline() error { - if atomic.LoadInt32(&c.renewDeadlinesDisabled) != 0 { + c.mtx.RLock() + defer c.mtx.RUnlock() + if c.renewDeadlinesDisabled { return nil } return c.SetReadDeadline(time.Now().Add(c.idleTimeout)) } func (c *Conn) RenewWriteDeadline() error { - if atomic.LoadInt32(&c.renewDeadlinesDisabled) != 0 { + c.mtx.RLock() + defer c.mtx.RUnlock() + if c.renewDeadlinesDisabled { return nil } return c.SetWriteDeadline(time.Now().Add(c.idleTimeout)) } -func (c Conn) Read(p []byte) (n int, err error) { +func (c *Conn) Read(p []byte) (n int, err error) { n = 0 err = nil restart: @@ -103,7 +124,7 @@ restart: return n, err } -func (c Conn) Write(p []byte) (n int, err error) { +func (c *Conn) Write(p []byte) (n int, err error) { n = 0 restart: if err := c.RenewWriteDeadline(); err != nil { @@ -123,7 +144,7 @@ restart: // but is guaranteed to use the writev system call if the wrapped Wire // support it. // Note the Conn does not support writev through io.Copy(aConn, aNetBuffers). -func (c Conn) WritevFull(bufs net.Buffers) (n int64, err error) { +func (c *Conn) WritevFull(bufs net.Buffers) (n int64, err error) { n = 0 restart: if err := c.RenewWriteDeadline(); err != nil { @@ -163,12 +184,12 @@ var _ SyscallConner = (*net.TCPConn)(nil) // If the connection returned io.EOF, the number of bytes written until // then + io.EOF is returned. This behavior is different to io.ReadFull // which returns io.ErrUnexpectedEOF. -func (c Conn) ReadvFull(buffers net.Buffers) (n int64, err error) { +func (c *Conn) ReadvFull(buffers net.Buffers) (n int64, err error) { return c.readv(buffers) } // invoked by c.readv if readv system call cannot be used -func (c Conn) readvFallback(nbuffers net.Buffers) (n int64, err error) { +func (c *Conn) readvFallback(nbuffers net.Buffers) (n int64, err error) { buffers := [][]byte(nbuffers) for i := range buffers { curBuf := buffers[i] diff --git a/rpc/dataconn/timeoutconn/timeoutconn_readv_unsupported.go b/rpc/dataconn/timeoutconn/timeoutconn_readv_unsupported.go index 5cc169b..7b15afd 100644 --- a/rpc/dataconn/timeoutconn/timeoutconn_readv_unsupported.go +++ b/rpc/dataconn/timeoutconn/timeoutconn_readv_unsupported.go @@ -4,7 +4,7 @@ package timeoutconn import "net" -func (c Conn) readv(buffers net.Buffers) (n int64, err error) { +func (c *Conn) readv(buffers net.Buffers) (n int64, err error) { // Go does not expose the SYS_READV symbol for Solaris / Illumos - do they have it? // Anyhow, use the fallback return c.readvFallback(buffers) diff --git a/rpc/dataconn/timeoutconn/timoutconn_readv.go b/rpc/dataconn/timeoutconn/timoutconn_readv.go index c940573..b2ba354 100644 --- a/rpc/dataconn/timeoutconn/timoutconn_readv.go +++ b/rpc/dataconn/timeoutconn/timoutconn_readv.go @@ -29,8 +29,7 @@ func buildIovecs(buffers net.Buffers) (totalLen int64, vecs []syscall.Iovec) { return totalLen, vecs } -func (c Conn) readv(buffers net.Buffers) (n int64, err error) { - +func (c *Conn) readv(buffers net.Buffers) (n int64, err error) { scc, ok := c.Wire.(SyscallConner) if !ok { return c.readvFallback(buffers) @@ -62,7 +61,7 @@ func (c Conn) readv(buffers net.Buffers) (n int64, err error) { return n, nil } -func (c Conn) doOneReadv(rawConn syscall.RawConn, iovecs *[]syscall.Iovec) (n int64, err error) { +func (c *Conn) doOneReadv(rawConn syscall.RawConn, iovecs *[]syscall.Iovec) (n int64, err error) { rawReadErr := rawConn.Read(func(fd uintptr) (done bool) { // iovecs, n and err must not be shadowed!