mirror of
https://github.com/zrepl/zrepl.git
synced 2024-11-25 01:44:43 +01:00
rpc: fix data race in timeoutconn
- `timeoutconn` handles state, yet calls to Read/Write make a copy of that state (non-pointer receiver) so any outbound calls will not have the state updated - Even without the copy issue, the renew methods can in edge cases set a new deadline _after_ DisableTimeouts have been called, consider the following racy behavior: 1. `renewReadDeadline` is called, checks `renewDeadlinesDisabled` (not disabled) 2. `DisableTimeouts` is called, sets `renewDeadlinesDisabled` 3. `DisableTimeouts` invokes `c.SetDeadline` 4. `renewReadDeadline` invokes `c.SetReadDeadline` To fix the above, the `Conn` receiver was made to be a pointer everywhere and access to renewDeadlinesDisabled is now guarded by an RWMutex instead of using atomics. closes #415
This commit is contained in:
parent
48be4032a2
commit
d118bcc717
@ -47,7 +47,7 @@ func (f *FrameHeader) Unmarshal(buf []byte) {
|
||||
|
||||
type Conn struct {
|
||||
readMtx, writeMtx sync.Mutex
|
||||
nc timeoutconn.Conn
|
||||
nc *timeoutconn.Conn
|
||||
readNextValid bool
|
||||
readNext FrameHeader
|
||||
nextReadErr error
|
||||
@ -55,7 +55,7 @@ type Conn struct {
|
||||
shutdown shutdownFSM
|
||||
}
|
||||
|
||||
func Wrap(nc timeoutconn.Conn) *Conn {
|
||||
func Wrap(nc *timeoutconn.Conn) *Conn {
|
||||
return &Conn{
|
||||
nc: nc,
|
||||
// ncBuf: bufio.NewReadWriter(bufio.NewReaderSize(nc, 1<<23), bufio.NewWriterSize(nc, 1<<23)),
|
||||
|
@ -11,7 +11,7 @@ import (
|
||||
"errors"
|
||||
"io"
|
||||
"net"
|
||||
"sync/atomic"
|
||||
"sync"
|
||||
"syscall"
|
||||
"time"
|
||||
)
|
||||
@ -54,39 +54,60 @@ type Wire interface {
|
||||
}
|
||||
|
||||
type Conn struct {
|
||||
// immutable state
|
||||
|
||||
Wire
|
||||
renewDeadlinesDisabled int32
|
||||
idleTimeout time.Duration
|
||||
idleTimeout time.Duration
|
||||
|
||||
// mutable state (protected by mtx)
|
||||
|
||||
mtx sync.RWMutex
|
||||
renewDeadlinesDisabled bool
|
||||
}
|
||||
|
||||
func Wrap(conn Wire, idleTimeout time.Duration) Conn {
|
||||
return Conn{Wire: conn, idleTimeout: idleTimeout}
|
||||
func Wrap(conn Wire, idleTimeout time.Duration) *Conn {
|
||||
return &Conn{
|
||||
Wire: conn,
|
||||
idleTimeout: idleTimeout,
|
||||
}
|
||||
}
|
||||
|
||||
// DisableTimeouts disables the idle timeout behavior provided by this package.
|
||||
// Existing deadlines are cleared iff the call is the first call to this method.
|
||||
// Existing deadlines are cleared iff the call is the first call to this method
|
||||
// or if the previous call produced an error.
|
||||
func (c *Conn) DisableTimeouts() error {
|
||||
if atomic.CompareAndSwapInt32(&c.renewDeadlinesDisabled, 0, 1) {
|
||||
return c.SetDeadline(time.Time{})
|
||||
c.mtx.Lock()
|
||||
defer c.mtx.Unlock()
|
||||
if c.renewDeadlinesDisabled {
|
||||
return nil
|
||||
}
|
||||
err := c.SetDeadline(time.Time{})
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
c.renewDeadlinesDisabled = true
|
||||
return nil
|
||||
}
|
||||
|
||||
func (c *Conn) renewReadDeadline() error {
|
||||
if atomic.LoadInt32(&c.renewDeadlinesDisabled) != 0 {
|
||||
c.mtx.RLock()
|
||||
defer c.mtx.RUnlock()
|
||||
if c.renewDeadlinesDisabled {
|
||||
return nil
|
||||
}
|
||||
return c.SetReadDeadline(time.Now().Add(c.idleTimeout))
|
||||
}
|
||||
|
||||
func (c *Conn) RenewWriteDeadline() error {
|
||||
if atomic.LoadInt32(&c.renewDeadlinesDisabled) != 0 {
|
||||
c.mtx.RLock()
|
||||
defer c.mtx.RUnlock()
|
||||
if c.renewDeadlinesDisabled {
|
||||
return nil
|
||||
}
|
||||
return c.SetWriteDeadline(time.Now().Add(c.idleTimeout))
|
||||
}
|
||||
|
||||
func (c Conn) Read(p []byte) (n int, err error) {
|
||||
func (c *Conn) Read(p []byte) (n int, err error) {
|
||||
n = 0
|
||||
err = nil
|
||||
restart:
|
||||
@ -103,7 +124,7 @@ restart:
|
||||
return n, err
|
||||
}
|
||||
|
||||
func (c Conn) Write(p []byte) (n int, err error) {
|
||||
func (c *Conn) Write(p []byte) (n int, err error) {
|
||||
n = 0
|
||||
restart:
|
||||
if err := c.RenewWriteDeadline(); err != nil {
|
||||
@ -123,7 +144,7 @@ restart:
|
||||
// but is guaranteed to use the writev system call if the wrapped Wire
|
||||
// support it.
|
||||
// Note the Conn does not support writev through io.Copy(aConn, aNetBuffers).
|
||||
func (c Conn) WritevFull(bufs net.Buffers) (n int64, err error) {
|
||||
func (c *Conn) WritevFull(bufs net.Buffers) (n int64, err error) {
|
||||
n = 0
|
||||
restart:
|
||||
if err := c.RenewWriteDeadline(); err != nil {
|
||||
@ -163,12 +184,12 @@ var _ SyscallConner = (*net.TCPConn)(nil)
|
||||
// If the connection returned io.EOF, the number of bytes written until
|
||||
// then + io.EOF is returned. This behavior is different to io.ReadFull
|
||||
// which returns io.ErrUnexpectedEOF.
|
||||
func (c Conn) ReadvFull(buffers net.Buffers) (n int64, err error) {
|
||||
func (c *Conn) ReadvFull(buffers net.Buffers) (n int64, err error) {
|
||||
return c.readv(buffers)
|
||||
}
|
||||
|
||||
// invoked by c.readv if readv system call cannot be used
|
||||
func (c Conn) readvFallback(nbuffers net.Buffers) (n int64, err error) {
|
||||
func (c *Conn) readvFallback(nbuffers net.Buffers) (n int64, err error) {
|
||||
buffers := [][]byte(nbuffers)
|
||||
for i := range buffers {
|
||||
curBuf := buffers[i]
|
||||
|
@ -4,7 +4,7 @@ package timeoutconn
|
||||
|
||||
import "net"
|
||||
|
||||
func (c Conn) readv(buffers net.Buffers) (n int64, err error) {
|
||||
func (c *Conn) readv(buffers net.Buffers) (n int64, err error) {
|
||||
// Go does not expose the SYS_READV symbol for Solaris / Illumos - do they have it?
|
||||
// Anyhow, use the fallback
|
||||
return c.readvFallback(buffers)
|
||||
|
@ -29,8 +29,7 @@ func buildIovecs(buffers net.Buffers) (totalLen int64, vecs []syscall.Iovec) {
|
||||
return totalLen, vecs
|
||||
}
|
||||
|
||||
func (c Conn) readv(buffers net.Buffers) (n int64, err error) {
|
||||
|
||||
func (c *Conn) readv(buffers net.Buffers) (n int64, err error) {
|
||||
scc, ok := c.Wire.(SyscallConner)
|
||||
if !ok {
|
||||
return c.readvFallback(buffers)
|
||||
@ -62,7 +61,7 @@ func (c Conn) readv(buffers net.Buffers) (n int64, err error) {
|
||||
return n, nil
|
||||
}
|
||||
|
||||
func (c Conn) doOneReadv(rawConn syscall.RawConn, iovecs *[]syscall.Iovec) (n int64, err error) {
|
||||
func (c *Conn) doOneReadv(rawConn syscall.RawConn, iovecs *[]syscall.Iovec) (n int64, err error) {
|
||||
rawReadErr := rawConn.Read(func(fd uintptr) (done bool) {
|
||||
// iovecs, n and err must not be shadowed!
|
||||
|
||||
|
Loading…
Reference in New Issue
Block a user