Rework sticky sockets

This commit is contained in:
Jason A. Donenfeld 2018-04-20 04:05:11 +02:00
parent f5c256affd
commit 5ba84696e2
3 changed files with 172 additions and 291 deletions

View File

@ -1,13 +1,18 @@
/* Copyright 2017 Jason A. Donenfeld <Jason@zx2c4.com>. All Rights Reserved. /* Copyright 2017-2018 Jason A. Donenfeld <Jason@zx2c4.com>. All Rights Reserved.
* *
* This implements userspace semantics of "sticky sockets", modeled after * This implements userspace semantics of "sticky sockets", modeled after
* WireGuard's kernelspace implementation. * WireGuard's kernelspace implementation. This is more or less a straight port
* of the sticky-sockets.c example code:
* https://git.zx2c4.com/WireGuard/tree/contrib/examples/sticky-sockets/sticky-sockets.c
*
* Currently there is no way to achieve this within the net package:
* See e.g. https://github.com/golang/go/issues/17930
* So this code is remains platform dependent.
*/ */
package main package main
import ( import (
"encoding/binary"
"errors" "errors"
"golang.org/x/sys/unix" "golang.org/x/sys/unix"
"net" "net"
@ -15,15 +20,36 @@ import (
"unsafe" "unsafe"
) )
/* Supports source address caching type IPv4Source struct {
* src [4]byte
* Currently there is no way to achieve this within the net package: ifindex int32
* See e.g. https://github.com/golang/go/issues/17930 }
* So this code is remains platform dependent.
*/ type IPv6Source struct {
src [16]byte
//ifindex belongs in dst.ZoneId
}
type NativeEndpoint struct { type NativeEndpoint struct {
src unix.RawSockaddrInet6 dst [unsafe.Sizeof(unix.SockaddrInet6{})]byte
dst unix.RawSockaddrInet6 src [unsafe.Sizeof(IPv6Source{})]byte
isV6 bool
}
func (endpoint *NativeEndpoint) src4() *IPv4Source {
return (*IPv4Source)(unsafe.Pointer(&endpoint.src[0]))
}
func (endpoint *NativeEndpoint) src6() *IPv6Source {
return (*IPv6Source)(unsafe.Pointer(&endpoint.src[0]))
}
func (endpoint *NativeEndpoint) dst4() *unix.SockaddrInet4 {
return (*unix.SockaddrInet4)(unsafe.Pointer(&endpoint.dst[0]))
}
func (endpoint *NativeEndpoint) dst6() *unix.SockaddrInet6 {
return (*unix.SockaddrInet6)(unsafe.Pointer(&endpoint.dst[0]))
} }
type NativeBind struct { type NativeBind struct {
@ -34,22 +60,6 @@ type NativeBind struct {
var _ Endpoint = (*NativeEndpoint)(nil) var _ Endpoint = (*NativeEndpoint)(nil)
var _ Bind = NativeBind{} var _ Bind = NativeBind{}
type IPv4Source struct {
src unix.RawSockaddrInet4
Ifindex int32
}
func htons(val uint16) uint16 {
var out [unsafe.Sizeof(val)]byte
binary.BigEndian.PutUint16(out[:], val)
return *((*uint16)(unsafe.Pointer(&out[0])))
}
func ntohs(val uint16) uint16 {
tmp := ((*[unsafe.Sizeof(val)]byte)(unsafe.Pointer(&val)))
return binary.BigEndian.Uint16((*tmp)[:])
}
func CreateEndpoint(s string) (Endpoint, error) { func CreateEndpoint(s string) (Endpoint, error) {
var end NativeEndpoint var end NativeEndpoint
addr, err := parseEndpoint(s) addr, err := parseEndpoint(s)
@ -59,10 +69,9 @@ func CreateEndpoint(s string) (Endpoint, error) {
ipv4 := addr.IP.To4() ipv4 := addr.IP.To4()
if ipv4 != nil { if ipv4 != nil {
dst := (*unix.RawSockaddrInet4)(unsafe.Pointer(&end.dst)) dst := end.dst4()
dst.Family = unix.AF_INET end.isV6 = false
dst.Port = htons(uint16(addr.Port)) dst.Port = addr.Port
dst.Zero = [8]byte{}
copy(dst.Addr[:], ipv4) copy(dst.Addr[:], ipv4)
end.ClearSrc() end.ClearSrc()
return &end, nil return &end, nil
@ -74,17 +83,16 @@ func CreateEndpoint(s string) (Endpoint, error) {
if err != nil { if err != nil {
return nil, err return nil, err
} }
dst := &end.dst dst := end.dst6()
dst.Family = unix.AF_INET6 end.isV6 = true
dst.Port = htons(uint16(addr.Port)) dst.Port = addr.Port
dst.Flowinfo = 0 dst.ZoneId = zone
dst.Scope_id = zone
copy(dst.Addr[:], ipv6[:]) copy(dst.Addr[:], ipv6[:])
end.ClearSrc() end.ClearSrc()
return &end, nil return &end, nil
} }
return nil, errors.New("Failed to recognize IP address format") return nil, errors.New("Invalid IP address")
} }
func CreateBind(port uint16) (Bind, uint16, error) { func CreateBind(port uint16) (Bind, uint16, error) {
@ -160,86 +168,85 @@ func (bind NativeBind) ReceiveIPv4(buff []byte) (int, Endpoint, error) {
func (bind NativeBind) Send(buff []byte, end Endpoint) error { func (bind NativeBind) Send(buff []byte, end Endpoint) error {
nend := end.(*NativeEndpoint) nend := end.(*NativeEndpoint)
switch nend.dst.Family { if !nend.isV6 {
case unix.AF_INET6:
return send6(bind.sock6, nend, buff)
case unix.AF_INET:
return send4(bind.sock4, nend, buff) return send4(bind.sock4, nend, buff)
default: } else {
return errors.New("Unknown address family of destination") return send6(bind.sock6, nend, buff)
} }
} }
func sockaddrToString(addr unix.RawSockaddrInet6) string { func rawAddrToIP4(addr *unix.SockaddrInet4) net.IP {
var udpAddr net.UDPAddr
switch addr.Family {
case unix.AF_INET6:
udpAddr.Port = int(ntohs(addr.Port))
udpAddr.IP = addr.Addr[:]
return udpAddr.String()
case unix.AF_INET:
ptr := (*unix.RawSockaddrInet4)(unsafe.Pointer(&addr))
udpAddr.Port = int(ntohs(ptr.Port))
udpAddr.IP = net.IPv4(
ptr.Addr[0],
ptr.Addr[1],
ptr.Addr[2],
ptr.Addr[3],
)
return udpAddr.String()
default:
return "<unknown address family>"
}
}
func rawAddrToIP(addr unix.RawSockaddrInet6) net.IP {
switch addr.Family {
case unix.AF_INET6:
return addr.Addr[:]
case unix.AF_INET:
ptr := (*unix.RawSockaddrInet4)(unsafe.Pointer(&addr))
return net.IPv4( return net.IPv4(
ptr.Addr[0], addr.Addr[0],
ptr.Addr[1], addr.Addr[1],
ptr.Addr[2], addr.Addr[2],
ptr.Addr[3], addr.Addr[3],
) )
default: }
return nil
} func rawAddrToIP6(addr *unix.SockaddrInet6) net.IP {
return addr.Addr[:]
} }
func (end *NativeEndpoint) SrcIP() net.IP { func (end *NativeEndpoint) SrcIP() net.IP {
return rawAddrToIP(end.src) if !end.isV6 {
return net.IPv4(
end.src4().src[0],
end.src4().src[1],
end.src4().src[2],
end.src4().src[3],
)
} else {
return end.src6().src[:]
}
} }
func (end *NativeEndpoint) DstIP() net.IP { func (end *NativeEndpoint) DstIP() net.IP {
return rawAddrToIP(end.dst) if !end.isV6 {
return net.IPv4(
end.dst4().Addr[0],
end.dst4().Addr[1],
end.dst4().Addr[2],
end.dst4().Addr[3],
)
} else {
return end.dst6().Addr[:]
}
} }
func (end *NativeEndpoint) DstToBytes() []byte { func (end *NativeEndpoint) DstToBytes() []byte {
ptr := unsafe.Pointer(&end.src) if !end.isV6 {
arr := (*[unix.SizeofSockaddrInet6]byte)(ptr) return (*[unsafe.Offsetof(end.dst4().Addr) + unsafe.Sizeof(end.dst4().Addr)]byte)(unsafe.Pointer(end.dst4()))[:]
return arr[:] } else {
return (*[unsafe.Offsetof(end.dst6().Addr) + unsafe.Sizeof(end.dst6().Addr)]byte)(unsafe.Pointer(end.dst6()))[:]
}
} }
func (end *NativeEndpoint) SrcToString() string { func (end *NativeEndpoint) SrcToString() string {
return sockaddrToString(end.src) return end.SrcIP().String()
} }
func (end *NativeEndpoint) DstToString() string { func (end *NativeEndpoint) DstToString() string {
return sockaddrToString(end.dst) var udpAddr net.UDPAddr
udpAddr.IP = end.DstIP()
if !end.isV6 {
udpAddr.Port = end.dst4().Port
} else {
udpAddr.Port = end.dst6().Port
}
return udpAddr.String()
} }
func (end *NativeEndpoint) ClearDst() { func (end *NativeEndpoint) ClearDst() {
end.dst = unix.RawSockaddrInet6{} for i := range end.dst {
end.dst[i] = 0
}
} }
func (end *NativeEndpoint) ClearSrc() { func (end *NativeEndpoint) ClearSrc() {
end.src = unix.RawSockaddrInet6{} for i := range end.src {
end.src[i] = 0
}
} }
func zoneToUint32(zone string) (uint32, error) { func zoneToUint32(zone string) (uint32, error) {
@ -295,6 +302,7 @@ func create4(port uint16) (int, uint16, error) {
return unix.Bind(fd, &addr) return unix.Bind(fd, &addr)
}(); err != nil { }(); err != nil {
unix.Close(fd) unix.Close(fd)
return -1, 0, err
} }
return fd, uint16(addr.Port), err return fd, uint16(addr.Port), err
@ -353,71 +361,16 @@ func create6(port uint16) (int, uint16, error) {
}(); err != nil { }(); err != nil {
unix.Close(fd) unix.Close(fd)
return -1, 0, err
} }
return fd, uint16(addr.Port), err return fd, uint16(addr.Port), err
} }
func send6(sock int, end *NativeEndpoint, buff []byte) error {
// construct message header
var iovec unix.Iovec
iovec.Base = (*byte)(unsafe.Pointer(&buff[0]))
iovec.SetLen(len(buff))
cmsg := struct {
cmsghdr unix.Cmsghdr
pktinfo unix.Inet6Pktinfo
}{
unix.Cmsghdr{
Level: unix.IPPROTO_IPV6,
Type: unix.IPV6_PKTINFO,
Len: unix.SizeofInet6Pktinfo + unix.SizeofCmsghdr,
},
unix.Inet6Pktinfo{
Addr: end.src.Addr,
Ifindex: end.src.Scope_id,
},
}
msghdr := unix.Msghdr{
Iov: &iovec,
Iovlen: 1,
Name: (*byte)(unsafe.Pointer(&end.dst)),
Namelen: unix.SizeofSockaddrInet6,
Control: (*byte)(unsafe.Pointer(&cmsg)),
}
msghdr.SetControllen(int(unsafe.Sizeof(cmsg)))
_, _, errno := sendmsg(sock, &msghdr, 0)
if errno == 0 {
return nil
}
// clear src and retry
if errno == unix.EINVAL {
end.ClearSrc()
cmsg.pktinfo = unix.Inet6Pktinfo{}
_, _, errno = sendmsg(sock, &msghdr, 0)
}
return errno
}
func send4(sock int, end *NativeEndpoint, buff []byte) error { func send4(sock int, end *NativeEndpoint, buff []byte) error {
// construct message header // construct message header
var iovec unix.Iovec
iovec.Base = (*byte)(unsafe.Pointer(&buff[0]))
iovec.SetLen(len(buff))
src4 := (*IPv4Source)(unsafe.Pointer(&end.src))
cmsg := struct { cmsg := struct {
cmsghdr unix.Cmsghdr cmsghdr unix.Cmsghdr
pktinfo unix.Inet4Pktinfo pktinfo unix.Inet4Pktinfo
@ -428,65 +381,86 @@ func send4(sock int, end *NativeEndpoint, buff []byte) error {
Len: unix.SizeofInet4Pktinfo + unix.SizeofCmsghdr, Len: unix.SizeofInet4Pktinfo + unix.SizeofCmsghdr,
}, },
unix.Inet4Pktinfo{ unix.Inet4Pktinfo{
Spec_dst: src4.src.Addr, Spec_dst: end.src4().src,
Ifindex: src4.Ifindex, Ifindex: end.src4().ifindex,
}, },
} }
msghdr := unix.Msghdr{ _, err := unix.SendmsgN(sock, buff, (*[unsafe.Sizeof(cmsg)]byte)(unsafe.Pointer(&cmsg))[:], end.dst4(), 0)
Iov: &iovec,
Iovlen: 1,
Name: (*byte)(unsafe.Pointer(&end.dst)),
Namelen: unix.SizeofSockaddrInet4,
Control: (*byte)(unsafe.Pointer(&cmsg)),
Flags: 0,
}
msghdr.SetControllen(int(unsafe.Sizeof(cmsg)))
_, _, errno := sendmsg(sock, &msghdr, 0) if err == nil {
// clear source and try again
if errno == unix.EINVAL {
end.ClearSrc()
cmsg.pktinfo = unix.Inet4Pktinfo{}
_, _, errno = sendmsg(sock, &msghdr, 0)
}
// errno = 0 is still an error instance
if errno == 0 {
return nil return nil
} }
return errno // clear src and retry
if err == unix.EINVAL {
end.ClearSrc()
cmsg.pktinfo = unix.Inet4Pktinfo{}
_, err = unix.SendmsgN(sock, buff, (*[unsafe.Sizeof(cmsg)]byte)(unsafe.Pointer(&cmsg))[:], end.dst4(), 0)
}
return err
}
func send6(sock int, end *NativeEndpoint, buff []byte) error {
// construct message header
cmsg := struct {
cmsghdr unix.Cmsghdr
pktinfo unix.Inet6Pktinfo
}{
unix.Cmsghdr{
Level: unix.IPPROTO_IPV6,
Type: unix.IPV6_PKTINFO,
Len: unix.SizeofInet6Pktinfo + unix.SizeofCmsghdr,
},
unix.Inet6Pktinfo{
Addr: end.src6().src,
Ifindex: end.dst6().ZoneId,
},
}
if cmsg.pktinfo.Addr == [16]byte{} {
cmsg.pktinfo.Ifindex = 0
}
_, err := unix.SendmsgN(sock, buff, (*[unsafe.Sizeof(cmsg)]byte)(unsafe.Pointer(&cmsg))[:], end.dst6(), 0)
if err == nil {
return nil
}
// clear src and retry
if err == unix.EINVAL {
end.ClearSrc()
cmsg.pktinfo = unix.Inet6Pktinfo{}
_, err = unix.SendmsgN(sock, buff, (*[unsafe.Sizeof(cmsg)]byte)(unsafe.Pointer(&cmsg))[:], end.dst6(), 0)
}
return err
} }
func receive4(sock int, buff []byte, end *NativeEndpoint) (int, error) { func receive4(sock int, buff []byte, end *NativeEndpoint) (int, error) {
// contruct message header // contruct message header
var iovec unix.Iovec
iovec.Base = (*byte)(unsafe.Pointer(&buff[0]))
iovec.SetLen(len(buff))
var cmsg struct { var cmsg struct {
cmsghdr unix.Cmsghdr cmsghdr unix.Cmsghdr
pktinfo unix.Inet4Pktinfo pktinfo unix.Inet4Pktinfo
} }
var msghdr unix.Msghdr size, _, _, newDst, err := unix.Recvmsg(sock, buff, (*[unsafe.Sizeof(cmsg)]byte)(unsafe.Pointer(&cmsg))[:], 0)
msghdr.Iov = &iovec
msghdr.Iovlen = 1
msghdr.Name = (*byte)(unsafe.Pointer(&end.dst))
msghdr.Namelen = unix.SizeofSockaddrInet4
msghdr.Control = (*byte)(unsafe.Pointer(&cmsg))
msghdr.SetControllen(int(unsafe.Sizeof(cmsg)))
size, _, errno := recvmsg(sock, &msghdr, 0) if err != nil {
return 0, err
}
end.isV6 = false
if errno != 0 { if newDst4, ok := newDst.(*unix.SockaddrInet4); ok {
return 0, errno *end.dst4() = *newDst4
} }
// update source cache // update source cache
@ -494,40 +468,31 @@ func receive4(sock int, buff []byte, end *NativeEndpoint) (int, error) {
if cmsg.cmsghdr.Level == unix.IPPROTO_IP && if cmsg.cmsghdr.Level == unix.IPPROTO_IP &&
cmsg.cmsghdr.Type == unix.IP_PKTINFO && cmsg.cmsghdr.Type == unix.IP_PKTINFO &&
cmsg.cmsghdr.Len >= unix.SizeofInet4Pktinfo { cmsg.cmsghdr.Len >= unix.SizeofInet4Pktinfo {
src4 := (*IPv4Source)(unsafe.Pointer(&end.src)) end.src4().src = cmsg.pktinfo.Spec_dst
src4.src.Family = unix.AF_INET end.src4().ifindex = cmsg.pktinfo.Ifindex
src4.src.Addr = cmsg.pktinfo.Spec_dst
src4.Ifindex = cmsg.pktinfo.Ifindex
} }
return int(size), nil return size, nil
} }
func receive6(sock int, buff []byte, end *NativeEndpoint) (int, error) { func receive6(sock int, buff []byte, end *NativeEndpoint) (int, error) {
// contruct message header // contruct message header
var iovec unix.Iovec
iovec.Base = (*byte)(unsafe.Pointer(&buff[0]))
iovec.SetLen(len(buff))
var cmsg struct { var cmsg struct {
cmsghdr unix.Cmsghdr cmsghdr unix.Cmsghdr
pktinfo unix.Inet6Pktinfo pktinfo unix.Inet6Pktinfo
} }
var msg unix.Msghdr size, _, _, newDst, err := unix.Recvmsg(sock, buff, (*[unsafe.Sizeof(cmsg)]byte)(unsafe.Pointer(&cmsg))[:], 0)
msg.Iov = &iovec
msg.Iovlen = 1
msg.Name = (*byte)(unsafe.Pointer(&end.dst))
msg.Namelen = uint32(unix.SizeofSockaddrInet6)
msg.Control = (*byte)(unsafe.Pointer(&cmsg))
msg.SetControllen(int(unsafe.Sizeof(cmsg)))
size, _, errno := recvmsg(sock, &msg, 0) if err != nil {
return 0, err
}
end.isV6 = true
if errno != 0 { if newDst6, ok := newDst.(*unix.SockaddrInet6); ok {
return 0, errno *end.dst6() = *newDst6
} }
// update source cache // update source cache
@ -535,10 +500,9 @@ func receive6(sock int, buff []byte, end *NativeEndpoint) (int, error) {
if cmsg.cmsghdr.Level == unix.IPPROTO_IPV6 && if cmsg.cmsghdr.Level == unix.IPPROTO_IPV6 &&
cmsg.cmsghdr.Type == unix.IPV6_PKTINFO && cmsg.cmsghdr.Type == unix.IPV6_PKTINFO &&
cmsg.cmsghdr.Len >= unix.SizeofInet6Pktinfo { cmsg.cmsghdr.Len >= unix.SizeofInet6Pktinfo {
end.src.Family = unix.AF_INET6 end.src6().src = cmsg.pktinfo.Addr
end.src.Addr = cmsg.pktinfo.Addr end.dst6().ZoneId = cmsg.pktinfo.Ifindex
end.src.Scope_id = cmsg.pktinfo.Ifindex
} }
return int(size), nil return size, nil
} }

View File

@ -1,30 +0,0 @@
// +build linux,!386
/* Copyright 2018 Jason A. Donenfeld <Jason@zx2c4.com>. All Rights Reserved.
*/
package main
import (
"golang.org/x/sys/unix"
"syscall"
"unsafe"
)
func sendmsg(fd int, msghdr *unix.Msghdr, flags int) (uintptr, uintptr, syscall.Errno) {
return unix.Syscall(
unix.SYS_SENDMSG,
uintptr(fd),
uintptr(unsafe.Pointer(msghdr)),
uintptr(flags),
)
}
func recvmsg(fd int, msghdr *unix.Msghdr, flags int) (uintptr, uintptr, syscall.Errno) {
return unix.Syscall(
unix.SYS_RECVMSG,
uintptr(fd),
uintptr(unsafe.Pointer(msghdr)),
uintptr(flags),
)
}

View File

@ -1,53 +0,0 @@
// +build linux,386
/* Copyright 2018 Jason A. Donenfeld <Jason@zx2c4.com>. All Rights Reserved.
*/
package main
import (
"golang.org/x/sys/unix"
"syscall"
"unsafe"
)
const (
_SENDMSG = 16
_RECVMSG = 17
)
func sendmsg(fd int, msghdr *unix.Msghdr, flags int) (uintptr, uintptr, syscall.Errno) {
args := struct {
fd uintptr
msghdr uintptr
flags uintptr
}{
uintptr(fd),
uintptr(unsafe.Pointer(msghdr)),
uintptr(flags),
}
return unix.Syscall(
unix.SYS_SOCKETCALL,
_SENDMSG,
uintptr(unsafe.Pointer(&args)),
0,
)
}
func recvmsg(fd int, msghdr *unix.Msghdr, flags int) (uintptr, uintptr, syscall.Errno) {
args := struct {
fd uintptr
msghdr uintptr
flags uintptr
}{
uintptr(fd),
uintptr(unsafe.Pointer(msghdr)),
uintptr(flags),
}
return unix.Syscall(
unix.SYS_SOCKETCALL,
_RECVMSG,
uintptr(unsafe.Pointer(&args)),
0,
)
}