diff --git a/rpc/dataconn/timeoutconn/timeoutconn.go b/rpc/dataconn/timeoutconn/timeoutconn.go index 3b57b79..44168d8 100644 --- a/rpc/dataconn/timeoutconn/timeoutconn.go +++ b/rpc/dataconn/timeoutconn/timeoutconn.go @@ -162,10 +162,14 @@ func buildIovecs(buffers net.Buffers) (totalLen int64, vecs []syscall.Iovec) { if len(buffers[i]) == 0 { continue } - vecs = append(vecs, syscall.Iovec{ + + v := syscall.Iovec{ Base: &buffers[i][0], - Len: uint64(len(buffers[i])), - }) + } + // syscall.Iovec.Len has platform-dependent size, thus use SetLen + v.SetLen(len(buffers[i])) + + vecs = append(vecs, v) } return totalLen, vecs } @@ -268,14 +272,22 @@ func (c Conn) doOneReadv(rawConn syscall.RawConn, iovecs *[]syscall.Iovec) (n in if int(thisReadN) < 0 { panic("unexpected return value") } - n += int64(thisReadN) + n += int64(thisReadN) // TODO check overflow + // shift iovecs forward - for left := int64(thisReadN); left > 0; { - curVecNewLength := int64((*iovecs)[0].Len) - left // TODO assert conversion - if curVecNewLength <= 0 { - left -= int64((*iovecs)[0].Len) + for left := int(thisReadN); left > 0; { + // conversion to uint does not change value, see TestIovecLenFieldIsMachineUint, and left > 0 + thisIovecConsumedCompletely := uint((*iovecs)[0].Len) <= uint(left) + if thisIovecConsumedCompletely { + // Update left, cannot go below 0 due to + // a) definition of thisIovecConsumedCompletely + // b) left > 0 due to loop invariant + // Convertion .Len to int64 is thus also safe now, because it is < left < INT_MAX + left -= int((*iovecs)[0].Len) *iovecs = (*iovecs)[1:] } else { + // trim this iovec to remaining length + // NOTE: unsafe.Pointer safety rules // https://tip.golang.org/pkg/unsafe/#Pointer // (3) Conversion of a Pointer to a uintptr and back, with arithmetic. @@ -283,7 +295,9 @@ func (c Conn) doOneReadv(rawConn syscall.RawConn, iovecs *[]syscall.Iovec) (n in // Note that both conversions must appear in the same expression, // with only the intervening arithmetic between them: (*iovecs)[0].Base = (*byte)(unsafe.Pointer(uintptr(unsafe.Pointer((*iovecs)[0].Base)) + uintptr(left))) - (*iovecs)[0].Len = uint64(curVecNewLength) + curVecNewLength := uint((*iovecs)[0].Len) - uint(left) // casts to uint do not change value + (*iovecs)[0].SetLen(int(curVecNewLength)) // int and uint have the same size, no change of value + break // inner } } diff --git a/rpc/dataconn/timeoutconn/timeoutconn_test.go b/rpc/dataconn/timeoutconn/timeoutconn_test.go index b4cce04..ce3043a 100644 --- a/rpc/dataconn/timeoutconn/timeoutconn_test.go +++ b/rpc/dataconn/timeoutconn/timeoutconn_test.go @@ -5,8 +5,10 @@ import ( "io" "net" "sync" + "syscall" "testing" "time" + "unsafe" "github.com/stretchr/testify/assert" "github.com/stretchr/testify/require" @@ -176,3 +178,14 @@ func TestNoPartialWritesDueToDeadline(t *testing.T) { require.True(t, ok) assert.True(t, netErr.Timeout()) } + +func TestIovecLenFieldIsMachineUint(t *testing.T) { + iov := syscall.Iovec{} + _ = iov // make linter happy (unsafe.Sizeof not recognized as usage) + size_t := unsafe.Sizeof(iov.Len) + if size_t != unsafe.Sizeof(uint(23)) { + t.Fatalf("expecting (struct iov)->Len to be sizeof(uint)") + } + // ssize_t is defined to be the signed version of size_t, + // so we know sizeof(ssize_t) == sizeof(int) +}