conn: introduce new package that splits out the Bind and Endpoint types

The sticky socket code stays in the device package for now,
as it reaches deeply into the peer list.

This is the first step in an effort to split some code out of
the very busy device package.

Signed-off-by: David Crawshaw <crawshaw@tailscale.com>
This commit is contained in:
David Crawshaw 2019-11-07 11:13:05 -05:00 committed by Jason A. Donenfeld
parent 6aefb61355
commit 203554620d
15 changed files with 562 additions and 452 deletions

View File

@ -3,11 +3,10 @@
* Copyright (C) 2017-2019 WireGuard LLC. All Rights Reserved. * Copyright (C) 2017-2019 WireGuard LLC. All Rights Reserved.
*/ */
package device package conn
import ( import (
"encoding/binary" "encoding/binary"
"errors"
"unsafe" "unsafe"
"golang.org/x/sys/windows" "golang.org/x/sys/windows"
@ -18,17 +17,13 @@ const (
sockoptIPV6_UNICAST_IF = 31 sockoptIPV6_UNICAST_IF = 31
) )
func (device *Device) BindSocketToInterface4(interfaceIndex uint32, blackhole bool) error { func (bind *nativeBind) BindSocketToInterface4(interfaceIndex uint32, blackhole bool) error {
/* MSDN says for IPv4 this needs to be in net byte order, so that it's like an IP address with leading zeros. */ /* MSDN says for IPv4 this needs to be in net byte order, so that it's like an IP address with leading zeros. */
bytes := make([]byte, 4) bytes := make([]byte, 4)
binary.BigEndian.PutUint32(bytes, interfaceIndex) binary.BigEndian.PutUint32(bytes, interfaceIndex)
interfaceIndex = *(*uint32)(unsafe.Pointer(&bytes[0])) interfaceIndex = *(*uint32)(unsafe.Pointer(&bytes[0]))
if device.net.bind == nil { sysconn, err := bind.ipv4.SyscallConn()
return errors.New("Bind is not yet initialized")
}
sysconn, err := device.net.bind.(*nativeBind).ipv4.SyscallConn()
if err != nil { if err != nil {
return err return err
} }
@ -41,12 +36,12 @@ func (device *Device) BindSocketToInterface4(interfaceIndex uint32, blackhole bo
if err != nil { if err != nil {
return err return err
} }
device.net.bind.(*nativeBind).blackhole4 = blackhole bind.blackhole4 = blackhole
return nil return nil
} }
func (device *Device) BindSocketToInterface6(interfaceIndex uint32, blackhole bool) error { func (bind *nativeBind) BindSocketToInterface6(interfaceIndex uint32, blackhole bool) error {
sysconn, err := device.net.bind.(*nativeBind).ipv6.SyscallConn() sysconn, err := bind.ipv6.SyscallConn()
if err != nil { if err != nil {
return err return err
} }
@ -59,6 +54,6 @@ func (device *Device) BindSocketToInterface6(interfaceIndex uint32, blackhole bo
if err != nil { if err != nil {
return err return err
} }
device.net.bind.(*nativeBind).blackhole6 = blackhole bind.blackhole6 = blackhole
return nil return nil
} }

101
conn/conn.go Normal file
View File

@ -0,0 +1,101 @@
/* SPDX-License-Identifier: MIT
*
* Copyright (C) 2017-2019 WireGuard LLC. All Rights Reserved.
*/
// Package conn implements WireGuard's network connections.
package conn
import (
"errors"
"net"
"strings"
)
// A Bind listens on a port for both IPv6 and IPv4 UDP traffic.
type Bind interface {
// LastMark reports the last mark set for this Bind.
LastMark() uint32
// SetMark sets the mark for each packet sent through this Bind.
// This mark is passed to the kernel as the socket option SO_MARK.
SetMark(mark uint32) error
// ReceiveIPv6 reads an IPv6 UDP packet into b.
//
// It reports the number of bytes read, n,
// the packet source address ep,
// and any error.
ReceiveIPv6(buff []byte) (n int, ep Endpoint, err error)
// ReceiveIPv4 reads an IPv4 UDP packet into b.
//
// It reports the number of bytes read, n,
// the packet source address ep,
// and any error.
ReceiveIPv4(b []byte) (n int, ep Endpoint, err error)
// Send writes a packet b to address ep.
Send(b []byte, ep Endpoint) error
// Close closes the Bind connection.
Close() error
}
// CreateBind creates a Bind bound to a port.
//
// The value actualPort reports the actual port number the Bind
// object gets bound to.
func CreateBind(port uint16) (b Bind, actualPort uint16, err error) {
return createBind(port)
}
// BindToInterface is implemented by Bind objects that support being
// tied to a single network interface.
type BindToInterface interface {
BindToInterface4(interfaceIndex uint32, blackhole bool) error
BindToInterface6(interfaceIndex uint32, blackhole bool) error
}
// An Endpoint maintains the source/destination caching for a peer.
//
// dst : the remote address of a peer ("endpoint" in uapi terminology)
// src : the local address from which datagrams originate going to the peer
type Endpoint interface {
ClearSrc() // clears the source address
SrcToString() string // returns the local source address (ip:port)
DstToString() string // returns the destination address (ip:port)
DstToBytes() []byte // used for mac2 cookie calculations
DstIP() net.IP
SrcIP() net.IP
}
func parseEndpoint(s string) (*net.UDPAddr, error) {
// ensure that the host is an IP address
host, _, err := net.SplitHostPort(s)
if err != nil {
return nil, err
}
if i := strings.LastIndexByte(host, '%'); i > 0 && strings.IndexByte(host, ':') >= 0 {
// Remove the scope, if any. ResolveUDPAddr below will use it, but here we're just
// trying to make sure with a small sanity test that this is a real IP address and
// not something that's likely to incur DNS lookups.
host = host[:i]
}
if ip := net.ParseIP(host); ip == nil {
return nil, errors.New("Failed to parse IP address: " + host)
}
// parse address and port
addr, err := net.ResolveUDPAddr("udp", s)
if err != nil {
return nil, err
}
ip4 := addr.IP.To4()
if ip4 != nil {
addr.IP = ip4
}
return addr, err
}

View File

@ -5,7 +5,7 @@
* Copyright (C) 2017-2019 WireGuard LLC. All Rights Reserved. * Copyright (C) 2017-2019 WireGuard LLC. All Rights Reserved.
*/ */
package device package conn
import ( import (
"net" "net"
@ -67,16 +67,13 @@ func (e *NativeEndpoint) SrcToString() string {
} }
func listenNet(network string, port int) (*net.UDPConn, int, error) { func listenNet(network string, port int) (*net.UDPConn, int, error) {
// listen
conn, err := net.ListenUDP(network, &net.UDPAddr{Port: port}) conn, err := net.ListenUDP(network, &net.UDPAddr{Port: port})
if err != nil { if err != nil {
return nil, 0, err return nil, 0, err
} }
// retrieve port // Retrieve port.
// TODO(crawshaw): under what circumstances is this necessary?
laddr := conn.LocalAddr() laddr := conn.LocalAddr()
uaddr, err := net.ResolveUDPAddr( uaddr, err := net.ResolveUDPAddr(
laddr.Network(), laddr.Network(),
@ -100,7 +97,7 @@ func extractErrno(err error) error {
return syscallErr.Err return syscallErr.Err
} }
func CreateBind(uport uint16, device *Device) (Bind, uint16, error) { func createBind(uport uint16) (Bind, uint16, error) {
var err error var err error
var bind nativeBind var bind nativeBind
@ -135,6 +132,8 @@ func (bind *nativeBind) Close() error {
return err2 return err2
} }
func (bind *nativeBind) LastMark() uint32 { return 0 }
func (bind *nativeBind) ReceiveIPv4(buff []byte) (int, Endpoint, error) { func (bind *nativeBind) ReceiveIPv4(buff []byte) (int, Endpoint, error) {
if bind.ipv4 == nil { if bind.ipv4 == nil {
return 0, nil, syscall.EAFNOSUPPORT return 0, nil, syscall.EAFNOSUPPORT

View File

@ -3,18 +3,9 @@
/* SPDX-License-Identifier: MIT /* SPDX-License-Identifier: MIT
* *
* Copyright (C) 2017-2019 WireGuard LLC. All Rights Reserved. * Copyright (C) 2017-2019 WireGuard LLC. All Rights Reserved.
*
* This implements userspace semantics of "sticky sockets", modeled after
* WireGuard's kernelspace implementation. This is more or less a straight port
* of the sticky-sockets.c example code:
* https://git.zx2c4.com/wireguard-tools/tree/contrib/sticky-sockets/sticky-sockets.c
*
* Currently there is no way to achieve this within the net package:
* See e.g. https://github.com/golang/go/issues/17930
* So this code is remains platform dependent.
*/ */
package device package conn
import ( import (
"errors" "errors"
@ -25,7 +16,6 @@ import (
"unsafe" "unsafe"
"golang.org/x/sys/unix" "golang.org/x/sys/unix"
"golang.zx2c4.com/wireguard/rwcancel"
) )
const ( const (
@ -33,8 +23,8 @@ const (
) )
type IPv4Source struct { type IPv4Source struct {
src [4]byte Src [4]byte
ifindex int32 Ifindex int32
} }
type IPv6Source struct { type IPv6Source struct {
@ -49,6 +39,10 @@ type NativeEndpoint struct {
isV6 bool isV6 bool
} }
func (endpoint *NativeEndpoint) Src4() *IPv4Source { return endpoint.src4() }
func (endpoint *NativeEndpoint) Dst4() *unix.SockaddrInet4 { return endpoint.dst4() }
func (endpoint *NativeEndpoint) IsV6() bool { return endpoint.isV6 }
func (endpoint *NativeEndpoint) src4() *IPv4Source { func (endpoint *NativeEndpoint) src4() *IPv4Source {
return (*IPv4Source)(unsafe.Pointer(&endpoint.src[0])) return (*IPv4Source)(unsafe.Pointer(&endpoint.src[0]))
} }
@ -66,11 +60,9 @@ func (endpoint *NativeEndpoint) dst6() *unix.SockaddrInet6 {
} }
type nativeBind struct { type nativeBind struct {
sock4 int sock4 int
sock6 int sock6 int
netlinkSock int lastMark uint32
netlinkCancel *rwcancel.RWCancel
lastMark uint32
} }
var _ Endpoint = (*NativeEndpoint)(nil) var _ Endpoint = (*NativeEndpoint)(nil)
@ -111,59 +103,25 @@ func CreateEndpoint(s string) (Endpoint, error) {
return nil, errors.New("Invalid IP address") return nil, errors.New("Invalid IP address")
} }
func createNetlinkRouteSocket() (int, error) { func createBind(port uint16) (Bind, uint16, error) {
sock, err := unix.Socket(unix.AF_NETLINK, unix.SOCK_RAW, unix.NETLINK_ROUTE)
if err != nil {
return -1, err
}
saddr := &unix.SockaddrNetlink{
Family: unix.AF_NETLINK,
Groups: unix.RTMGRP_IPV4_ROUTE,
}
err = unix.Bind(sock, saddr)
if err != nil {
unix.Close(sock)
return -1, err
}
return sock, nil
}
func CreateBind(port uint16, device *Device) (*nativeBind, uint16, error) {
var err error var err error
var bind nativeBind var bind nativeBind
var newPort uint16 var newPort uint16
bind.netlinkSock, err = createNetlinkRouteSocket() // Attempt ipv6 bind, update port if successful.
if err != nil {
return nil, 0, err
}
bind.netlinkCancel, err = rwcancel.NewRWCancel(bind.netlinkSock)
if err != nil {
unix.Close(bind.netlinkSock)
return nil, 0, err
}
go bind.routineRouteListener(device)
// attempt ipv6 bind, update port if successful
bind.sock6, newPort, err = create6(port) bind.sock6, newPort, err = create6(port)
if err != nil { if err != nil {
if err != syscall.EAFNOSUPPORT { if err != syscall.EAFNOSUPPORT {
bind.netlinkCancel.Cancel()
return nil, 0, err return nil, 0, err
} }
} else { } else {
port = newPort port = newPort
} }
// attempt ipv4 bind, update port if successful // Attempt ipv4 bind, update port if successful.
bind.sock4, newPort, err = create4(port) bind.sock4, newPort, err = create4(port)
if err != nil { if err != nil {
if err != syscall.EAFNOSUPPORT { if err != syscall.EAFNOSUPPORT {
bind.netlinkCancel.Cancel()
unix.Close(bind.sock6) unix.Close(bind.sock6)
return nil, 0, err return nil, 0, err
} }
@ -178,6 +136,10 @@ func CreateBind(port uint16, device *Device) (*nativeBind, uint16, error) {
return &bind, port, nil return &bind, port, nil
} }
func (bind *nativeBind) LastMark() uint32 {
return bind.lastMark
}
func (bind *nativeBind) SetMark(value uint32) error { func (bind *nativeBind) SetMark(value uint32) error {
if bind.sock6 != -1 { if bind.sock6 != -1 {
err := unix.SetsockoptInt( err := unix.SetsockoptInt(
@ -216,22 +178,18 @@ func closeUnblock(fd int) error {
} }
func (bind *nativeBind) Close() error { func (bind *nativeBind) Close() error {
var err1, err2, err3 error var err1, err2 error
if bind.sock6 != -1 { if bind.sock6 != -1 {
err1 = closeUnblock(bind.sock6) err1 = closeUnblock(bind.sock6)
} }
if bind.sock4 != -1 { if bind.sock4 != -1 {
err2 = closeUnblock(bind.sock4) err2 = closeUnblock(bind.sock4)
} }
err3 = bind.netlinkCancel.Cancel()
if err1 != nil { if err1 != nil {
return err1 return err1
} }
if err2 != nil { return err2
return err2
}
return err3
} }
func (bind *nativeBind) ReceiveIPv6(buff []byte) (int, Endpoint, error) { func (bind *nativeBind) ReceiveIPv6(buff []byte) (int, Endpoint, error) {
@ -278,10 +236,10 @@ func (bind *nativeBind) Send(buff []byte, end Endpoint) error {
func (end *NativeEndpoint) SrcIP() net.IP { func (end *NativeEndpoint) SrcIP() net.IP {
if !end.isV6 { if !end.isV6 {
return net.IPv4( return net.IPv4(
end.src4().src[0], end.src4().Src[0],
end.src4().src[1], end.src4().Src[1],
end.src4().src[2], end.src4().Src[2],
end.src4().src[3], end.src4().Src[3],
) )
} else { } else {
return end.src6().src[:] return end.src6().src[:]
@ -478,8 +436,8 @@ func send4(sock int, end *NativeEndpoint, buff []byte) error {
Len: unix.SizeofInet4Pktinfo + unix.SizeofCmsghdr, Len: unix.SizeofInet4Pktinfo + unix.SizeofCmsghdr,
}, },
unix.Inet4Pktinfo{ unix.Inet4Pktinfo{
Spec_dst: end.src4().src, Spec_dst: end.src4().Src,
Ifindex: end.src4().ifindex, Ifindex: end.src4().Ifindex,
}, },
} }
@ -573,8 +531,8 @@ func receive4(sock int, buff []byte, end *NativeEndpoint) (int, error) {
if cmsg.cmsghdr.Level == unix.IPPROTO_IP && if cmsg.cmsghdr.Level == unix.IPPROTO_IP &&
cmsg.cmsghdr.Type == unix.IP_PKTINFO && cmsg.cmsghdr.Type == unix.IP_PKTINFO &&
cmsg.cmsghdr.Len >= unix.SizeofInet4Pktinfo { cmsg.cmsghdr.Len >= unix.SizeofInet4Pktinfo {
end.src4().src = cmsg.pktinfo.Spec_dst end.src4().Src = cmsg.pktinfo.Spec_dst
end.src4().ifindex = cmsg.pktinfo.Ifindex end.src4().Ifindex = cmsg.pktinfo.Ifindex
} }
return size, nil return size, nil
@ -611,156 +569,3 @@ func receive6(sock int, buff []byte, end *NativeEndpoint) (int, error) {
return size, nil return size, nil
} }
func (bind *nativeBind) routineRouteListener(device *Device) {
type peerEndpointPtr struct {
peer *Peer
endpoint *Endpoint
}
var reqPeer map[uint32]peerEndpointPtr
var reqPeerLock sync.Mutex
defer unix.Close(bind.netlinkSock)
for msg := make([]byte, 1<<16); ; {
var err error
var msgn int
for {
msgn, _, _, _, err = unix.Recvmsg(bind.netlinkSock, msg[:], nil, 0)
if err == nil || !rwcancel.RetryAfterError(err) {
break
}
if !bind.netlinkCancel.ReadyRead() {
return
}
}
if err != nil {
return
}
for remain := msg[:msgn]; len(remain) >= unix.SizeofNlMsghdr; {
hdr := *(*unix.NlMsghdr)(unsafe.Pointer(&remain[0]))
if uint(hdr.Len) > uint(len(remain)) {
break
}
switch hdr.Type {
case unix.RTM_NEWROUTE, unix.RTM_DELROUTE:
if hdr.Seq <= MaxPeers && hdr.Seq > 0 {
if uint(len(remain)) < uint(hdr.Len) {
break
}
if hdr.Len > unix.SizeofNlMsghdr+unix.SizeofRtMsg {
attr := remain[unix.SizeofNlMsghdr+unix.SizeofRtMsg:]
for {
if uint(len(attr)) < uint(unix.SizeofRtAttr) {
break
}
attrhdr := *(*unix.RtAttr)(unsafe.Pointer(&attr[0]))
if attrhdr.Len < unix.SizeofRtAttr || uint(len(attr)) < uint(attrhdr.Len) {
break
}
if attrhdr.Type == unix.RTA_OIF && attrhdr.Len == unix.SizeofRtAttr+4 {
ifidx := *(*uint32)(unsafe.Pointer(&attr[unix.SizeofRtAttr]))
reqPeerLock.Lock()
if reqPeer == nil {
reqPeerLock.Unlock()
break
}
pePtr, ok := reqPeer[hdr.Seq]
reqPeerLock.Unlock()
if !ok {
break
}
pePtr.peer.Lock()
if &pePtr.peer.endpoint != pePtr.endpoint {
pePtr.peer.Unlock()
break
}
if uint32(pePtr.peer.endpoint.(*NativeEndpoint).src4().ifindex) == ifidx {
pePtr.peer.Unlock()
break
}
pePtr.peer.endpoint.(*NativeEndpoint).ClearSrc()
pePtr.peer.Unlock()
}
attr = attr[attrhdr.Len:]
}
}
break
}
reqPeerLock.Lock()
reqPeer = make(map[uint32]peerEndpointPtr)
reqPeerLock.Unlock()
go func() {
device.peers.RLock()
i := uint32(1)
for _, peer := range device.peers.keyMap {
peer.RLock()
if peer.endpoint == nil || peer.endpoint.(*NativeEndpoint) == nil {
peer.RUnlock()
continue
}
if peer.endpoint.(*NativeEndpoint).isV6 || peer.endpoint.(*NativeEndpoint).src4().ifindex == 0 {
peer.RUnlock()
break
}
nlmsg := struct {
hdr unix.NlMsghdr
msg unix.RtMsg
dsthdr unix.RtAttr
dst [4]byte
srchdr unix.RtAttr
src [4]byte
markhdr unix.RtAttr
mark uint32
}{
unix.NlMsghdr{
Type: uint16(unix.RTM_GETROUTE),
Flags: unix.NLM_F_REQUEST,
Seq: i,
},
unix.RtMsg{
Family: unix.AF_INET,
Dst_len: 32,
Src_len: 32,
},
unix.RtAttr{
Len: 8,
Type: unix.RTA_DST,
},
peer.endpoint.(*NativeEndpoint).dst4().Addr,
unix.RtAttr{
Len: 8,
Type: unix.RTA_SRC,
},
peer.endpoint.(*NativeEndpoint).src4().src,
unix.RtAttr{
Len: 8,
Type: unix.RTA_MARK,
},
uint32(bind.lastMark),
}
nlmsg.hdr.Len = uint32(unsafe.Sizeof(nlmsg))
reqPeerLock.Lock()
reqPeer[i] = peerEndpointPtr{
peer: peer,
endpoint: &peer.endpoint,
}
reqPeerLock.Unlock()
peer.RUnlock()
i++
_, err := bind.netlinkCancel.Write((*[unsafe.Sizeof(nlmsg)]byte)(unsafe.Pointer(&nlmsg))[:])
if err != nil {
break
}
}
device.peers.RUnlock()
}()
}
remain = remain[hdr.Len:]
}
}
}

View File

@ -5,7 +5,7 @@
* Copyright (C) 2017-2019 WireGuard LLC. All Rights Reserved. * Copyright (C) 2017-2019 WireGuard LLC. All Rights Reserved.
*/ */
package device package conn
func (bind *nativeBind) SetMark(mark uint32) error { func (bind *nativeBind) SetMark(mark uint32) error {
return nil return nil

View File

@ -5,7 +5,7 @@
* Copyright (C) 2017-2019 WireGuard LLC. All Rights Reserved. * Copyright (C) 2017-2019 WireGuard LLC. All Rights Reserved.
*/ */
package device package conn
import ( import (
"runtime" "runtime"

View File

@ -5,11 +5,15 @@
package device package device
import "errors" import (
"errors"
"golang.zx2c4.com/wireguard/conn"
)
type DummyDatagram struct { type DummyDatagram struct {
msg []byte msg []byte
endpoint Endpoint endpoint conn.Endpoint
world bool // better type world bool // better type
} }
@ -25,7 +29,7 @@ func (b *DummyBind) SetMark(v uint32) error {
return nil return nil
} }
func (b *DummyBind) ReceiveIPv6(buff []byte) (int, Endpoint, error) { func (b *DummyBind) ReceiveIPv6(buff []byte) (int, conn.Endpoint, error) {
datagram, ok := <-b.in6 datagram, ok := <-b.in6
if !ok { if !ok {
return 0, nil, errors.New("closed") return 0, nil, errors.New("closed")
@ -34,7 +38,7 @@ func (b *DummyBind) ReceiveIPv6(buff []byte) (int, Endpoint, error) {
return len(datagram.msg), datagram.endpoint, nil return len(datagram.msg), datagram.endpoint, nil
} }
func (b *DummyBind) ReceiveIPv4(buff []byte) (int, Endpoint, error) { func (b *DummyBind) ReceiveIPv4(buff []byte) (int, conn.Endpoint, error) {
datagram, ok := <-b.in4 datagram, ok := <-b.in4
if !ok { if !ok {
return 0, nil, errors.New("closed") return 0, nil, errors.New("closed")
@ -50,6 +54,6 @@ func (b *DummyBind) Close() error {
return nil return nil
} }
func (b *DummyBind) Send(buff []byte, end Endpoint) error { func (b *DummyBind) Send(buff []byte, end conn.Endpoint) error {
return nil return nil
} }

36
device/bindsocketshim.go Normal file
View File

@ -0,0 +1,36 @@
/* SPDX-License-Identifier: MIT
*
* Copyright (C) 2017-2019 WireGuard LLC. All Rights Reserved.
*/
package device
import (
"errors"
"golang.zx2c4.com/wireguard/conn"
)
// TODO(crawshaw): this method is a compatibility shim. Replace with direct use of conn.
func (device *Device) BindSocketToInterface4(interfaceIndex uint32, blackhole bool) error {
if device.net.bind == nil {
return errors.New("Bind is not yet initialized")
}
if iface, ok := device.net.bind.(conn.BindToInterface); ok {
return iface.BindToInterface4(interfaceIndex, blackhole)
}
return nil
}
// TODO(crawshaw): this method is a compatibility shim. Replace with direct use of conn.
func (device *Device) BindSocketToInterface6(interfaceIndex uint32, blackhole bool) error {
if device.net.bind == nil {
return errors.New("Bind is not yet initialized")
}
if iface, ok := device.net.bind.(conn.BindToInterface); ok {
return iface.BindToInterface6(interfaceIndex, blackhole)
}
return nil
}

View File

@ -1,187 +0,0 @@
/* SPDX-License-Identifier: MIT
*
* Copyright (C) 2017-2019 WireGuard LLC. All Rights Reserved.
*/
package device
import (
"errors"
"net"
"strings"
"golang.org/x/net/ipv4"
"golang.org/x/net/ipv6"
)
const (
ConnRoutineNumber = 2
)
/* A Bind handles listening on a port for both IPv6 and IPv4 UDP traffic
*/
type Bind interface {
SetMark(value uint32) error
ReceiveIPv6(buff []byte) (int, Endpoint, error)
ReceiveIPv4(buff []byte) (int, Endpoint, error)
Send(buff []byte, end Endpoint) error
Close() error
}
/* An Endpoint maintains the source/destination caching for a peer
*
* dst : the remote address of a peer ("endpoint" in uapi terminology)
* src : the local address from which datagrams originate going to the peer
*/
type Endpoint interface {
ClearSrc() // clears the source address
SrcToString() string // returns the local source address (ip:port)
DstToString() string // returns the destination address (ip:port)
DstToBytes() []byte // used for mac2 cookie calculations
DstIP() net.IP
SrcIP() net.IP
}
func parseEndpoint(s string) (*net.UDPAddr, error) {
// ensure that the host is an IP address
host, _, err := net.SplitHostPort(s)
if err != nil {
return nil, err
}
if i := strings.LastIndexByte(host, '%'); i > 0 && strings.IndexByte(host, ':') >= 0 {
// Remove the scope, if any. ResolveUDPAddr below will use it, but here we're just
// trying to make sure with a small sanity test that this is a real IP address and
// not something that's likely to incur DNS lookups.
host = host[:i]
}
if ip := net.ParseIP(host); ip == nil {
return nil, errors.New("Failed to parse IP address: " + host)
}
// parse address and port
addr, err := net.ResolveUDPAddr("udp", s)
if err != nil {
return nil, err
}
ip4 := addr.IP.To4()
if ip4 != nil {
addr.IP = ip4
}
return addr, err
}
func unsafeCloseBind(device *Device) error {
var err error
netc := &device.net
if netc.bind != nil {
err = netc.bind.Close()
netc.bind = nil
}
netc.stopping.Wait()
return err
}
func (device *Device) BindSetMark(mark uint32) error {
device.net.Lock()
defer device.net.Unlock()
// check if modified
if device.net.fwmark == mark {
return nil
}
// update fwmark on existing bind
device.net.fwmark = mark
if device.isUp.Get() && device.net.bind != nil {
if err := device.net.bind.SetMark(mark); err != nil {
return err
}
}
// clear cached source addresses
device.peers.RLock()
for _, peer := range device.peers.keyMap {
peer.Lock()
defer peer.Unlock()
if peer.endpoint != nil {
peer.endpoint.ClearSrc()
}
}
device.peers.RUnlock()
return nil
}
func (device *Device) BindUpdate() error {
device.net.Lock()
defer device.net.Unlock()
// close existing sockets
if err := unsafeCloseBind(device); err != nil {
return err
}
// open new sockets
if device.isUp.Get() {
// bind to new port
var err error
netc := &device.net
netc.bind, netc.port, err = CreateBind(netc.port, device)
if err != nil {
netc.bind = nil
netc.port = 0
return err
}
// set fwmark
if netc.fwmark != 0 {
err = netc.bind.SetMark(netc.fwmark)
if err != nil {
return err
}
}
// clear cached source addresses
device.peers.RLock()
for _, peer := range device.peers.keyMap {
peer.Lock()
defer peer.Unlock()
if peer.endpoint != nil {
peer.endpoint.ClearSrc()
}
}
device.peers.RUnlock()
// start receiving routines
device.net.starting.Add(ConnRoutineNumber)
device.net.stopping.Add(ConnRoutineNumber)
go device.RoutineReceiveIncoming(ipv4.Version, netc.bind)
go device.RoutineReceiveIncoming(ipv6.Version, netc.bind)
device.net.starting.Wait()
device.log.Debug.Println("UDP bind has been updated")
}
return nil
}
func (device *Device) BindClose() error {
device.net.Lock()
err := unsafeCloseBind(device)
device.net.Unlock()
return err
}

View File

@ -11,15 +11,14 @@ import (
"sync/atomic" "sync/atomic"
"time" "time"
"golang.org/x/net/ipv4"
"golang.org/x/net/ipv6"
"golang.zx2c4.com/wireguard/conn"
"golang.zx2c4.com/wireguard/ratelimiter" "golang.zx2c4.com/wireguard/ratelimiter"
"golang.zx2c4.com/wireguard/rwcancel"
"golang.zx2c4.com/wireguard/tun" "golang.zx2c4.com/wireguard/tun"
) )
const (
DeviceRoutineNumberPerCPU = 3
DeviceRoutineNumberAdditional = 2
)
type Device struct { type Device struct {
isUp AtomicBool // device is (going) up isUp AtomicBool // device is (going) up
isClosed AtomicBool // device is closed? (acting as guard) isClosed AtomicBool // device is closed? (acting as guard)
@ -39,9 +38,10 @@ type Device struct {
starting sync.WaitGroup starting sync.WaitGroup
stopping sync.WaitGroup stopping sync.WaitGroup
sync.RWMutex sync.RWMutex
bind Bind // bind interface bind conn.Bind // bind interface
port uint16 // listening port netlinkCancel *rwcancel.RWCancel
fwmark uint32 // mark value (0 = disabled) port uint16 // listening port
fwmark uint32 // mark value (0 = disabled)
} }
staticIdentity struct { staticIdentity struct {
@ -299,14 +299,16 @@ func NewDevice(tunDevice tun.Device, logger *Logger) *Device {
cpus := runtime.NumCPU() cpus := runtime.NumCPU()
device.state.starting.Wait() device.state.starting.Wait()
device.state.stopping.Wait() device.state.stopping.Wait()
device.state.stopping.Add(DeviceRoutineNumberPerCPU*cpus + DeviceRoutineNumberAdditional)
device.state.starting.Add(DeviceRoutineNumberPerCPU*cpus + DeviceRoutineNumberAdditional)
for i := 0; i < cpus; i += 1 { for i := 0; i < cpus; i += 1 {
device.state.starting.Add(3)
device.state.stopping.Add(3)
go device.RoutineEncryption() go device.RoutineEncryption()
go device.RoutineDecryption() go device.RoutineDecryption()
go device.RoutineHandshake() go device.RoutineHandshake()
} }
device.state.starting.Add(2)
device.state.stopping.Add(2)
go device.RoutineReadFromTUN() go device.RoutineReadFromTUN()
go device.RoutineTUNEventReader() go device.RoutineTUNEventReader()
@ -413,3 +415,127 @@ func (device *Device) SendKeepalivesToPeersWithCurrentKeypair() {
} }
device.peers.RUnlock() device.peers.RUnlock()
} }
func unsafeCloseBind(device *Device) error {
var err error
netc := &device.net
if netc.netlinkCancel != nil {
netc.netlinkCancel.Cancel()
}
if netc.bind != nil {
err = netc.bind.Close()
netc.bind = nil
}
netc.stopping.Wait()
return err
}
func (device *Device) BindSetMark(mark uint32) error {
device.net.Lock()
defer device.net.Unlock()
// check if modified
if device.net.fwmark == mark {
return nil
}
// update fwmark on existing bind
device.net.fwmark = mark
if device.isUp.Get() && device.net.bind != nil {
if err := device.net.bind.SetMark(mark); err != nil {
return err
}
}
// clear cached source addresses
device.peers.RLock()
for _, peer := range device.peers.keyMap {
peer.Lock()
defer peer.Unlock()
if peer.endpoint != nil {
peer.endpoint.ClearSrc()
}
}
device.peers.RUnlock()
return nil
}
func (device *Device) BindUpdate() error {
device.net.Lock()
defer device.net.Unlock()
// close existing sockets
if err := unsafeCloseBind(device); err != nil {
return err
}
// open new sockets
if device.isUp.Get() {
// bind to new port
var err error
netc := &device.net
netc.bind, netc.port, err = conn.CreateBind(netc.port)
if err != nil {
netc.bind = nil
netc.port = 0
return err
}
netc.netlinkCancel, err = device.startRouteListener(netc.bind)
if err != nil {
netc.bind.Close()
netc.bind = nil
netc.port = 0
return err
}
// set fwmark
if netc.fwmark != 0 {
err = netc.bind.SetMark(netc.fwmark)
if err != nil {
return err
}
}
// clear cached source addresses
device.peers.RLock()
for _, peer := range device.peers.keyMap {
peer.Lock()
defer peer.Unlock()
if peer.endpoint != nil {
peer.endpoint.ClearSrc()
}
}
device.peers.RUnlock()
// start receiving routines
device.net.starting.Add(2)
device.net.stopping.Add(2)
go device.RoutineReceiveIncoming(ipv4.Version, netc.bind)
go device.RoutineReceiveIncoming(ipv6.Version, netc.bind)
device.net.starting.Wait()
device.log.Debug.Println("UDP bind has been updated")
}
return nil
}
func (device *Device) BindClose() error {
device.net.Lock()
err := unsafeCloseBind(device)
device.net.Unlock()
return err
}

View File

@ -12,6 +12,8 @@ import (
"sync" "sync"
"sync/atomic" "sync/atomic"
"time" "time"
"golang.zx2c4.com/wireguard/conn"
) )
const ( const (
@ -24,7 +26,7 @@ type Peer struct {
keypairs Keypairs keypairs Keypairs
handshake Handshake handshake Handshake
device *Device device *Device
endpoint Endpoint endpoint conn.Endpoint
persistentKeepaliveInterval uint16 persistentKeepaliveInterval uint16
// These fields are accessed with atomic operations, which must be // These fields are accessed with atomic operations, which must be
@ -290,7 +292,7 @@ func (peer *Peer) Stop() {
var RoamingDisabled bool var RoamingDisabled bool
func (peer *Peer) SetEndpointFromPacket(endpoint Endpoint) { func (peer *Peer) SetEndpointFromPacket(endpoint conn.Endpoint) {
if RoamingDisabled { if RoamingDisabled {
return return
} }

View File

@ -17,12 +17,13 @@ import (
"golang.org/x/crypto/chacha20poly1305" "golang.org/x/crypto/chacha20poly1305"
"golang.org/x/net/ipv4" "golang.org/x/net/ipv4"
"golang.org/x/net/ipv6" "golang.org/x/net/ipv6"
"golang.zx2c4.com/wireguard/conn"
) )
type QueueHandshakeElement struct { type QueueHandshakeElement struct {
msgType uint32 msgType uint32
packet []byte packet []byte
endpoint Endpoint endpoint conn.Endpoint
buffer *[MaxMessageSize]byte buffer *[MaxMessageSize]byte
} }
@ -33,7 +34,7 @@ type QueueInboundElement struct {
packet []byte packet []byte
counter uint64 counter uint64
keypair *Keypair keypair *Keypair
endpoint Endpoint endpoint conn.Endpoint
} }
func (elem *QueueInboundElement) Drop() { func (elem *QueueInboundElement) Drop() {
@ -90,7 +91,7 @@ func (peer *Peer) keepKeyFreshReceiving() {
* Every time the bind is updated a new routine is started for * Every time the bind is updated a new routine is started for
* IPv4 and IPv6 (separately) * IPv4 and IPv6 (separately)
*/ */
func (device *Device) RoutineReceiveIncoming(IP int, bind Bind) { func (device *Device) RoutineReceiveIncoming(IP int, bind conn.Bind) {
logDebug := device.log.Debug logDebug := device.log.Debug
defer func() { defer func() {
@ -108,7 +109,7 @@ func (device *Device) RoutineReceiveIncoming(IP int, bind Bind) {
var ( var (
err error err error
size int size int
endpoint Endpoint endpoint conn.Endpoint
) )
for { for {

12
device/sticky_default.go Normal file
View File

@ -0,0 +1,12 @@
// +build !linux
package device
import (
"golang.zx2c4.com/wireguard/conn"
"golang.zx2c4.com/wireguard/rwcancel"
)
func (device *Device) startRouteListener(bind conn.Bind) (*rwcancel.RWCancel, error) {
return nil, nil
}

215
device/sticky_linux.go Normal file
View File

@ -0,0 +1,215 @@
/* SPDX-License-Identifier: MIT
*
* Copyright (C) 2017-2019 WireGuard LLC. All Rights Reserved.
*
* This implements userspace semantics of "sticky sockets", modeled after
* WireGuard's kernelspace implementation. This is more or less a straight port
* of the sticky-sockets.c example code:
* https://git.zx2c4.com/WireGuard/tree/contrib/examples/sticky-sockets/sticky-sockets.c
*
* Currently there is no way to achieve this within the net package:
* See e.g. https://github.com/golang/go/issues/17930
* So this code is remains platform dependent.
*/
package device
import (
"sync"
"unsafe"
"golang.org/x/sys/unix"
"golang.zx2c4.com/wireguard/conn"
"golang.zx2c4.com/wireguard/rwcancel"
)
func (device *Device) startRouteListener(bind conn.Bind) (*rwcancel.RWCancel, error) {
netlinkSock, err := createNetlinkRouteSocket()
if err != nil {
return nil, err
}
netlinkCancel, err := rwcancel.NewRWCancel(netlinkSock)
if err != nil {
unix.Close(netlinkSock)
return nil, err
}
go device.routineRouteListener(bind, netlinkSock, netlinkCancel)
return netlinkCancel, nil
}
func (device *Device) routineRouteListener(bind conn.Bind, netlinkSock int, netlinkCancel *rwcancel.RWCancel) {
type peerEndpointPtr struct {
peer *Peer
endpoint *conn.Endpoint
}
var reqPeer map[uint32]peerEndpointPtr
var reqPeerLock sync.Mutex
defer unix.Close(netlinkSock)
for msg := make([]byte, 1<<16); ; {
var err error
var msgn int
for {
msgn, _, _, _, err = unix.Recvmsg(netlinkSock, msg[:], nil, 0)
if err == nil || !rwcancel.RetryAfterError(err) {
break
}
if !netlinkCancel.ReadyRead() {
return
}
}
if err != nil {
return
}
for remain := msg[:msgn]; len(remain) >= unix.SizeofNlMsghdr; {
hdr := *(*unix.NlMsghdr)(unsafe.Pointer(&remain[0]))
if uint(hdr.Len) > uint(len(remain)) {
break
}
switch hdr.Type {
case unix.RTM_NEWROUTE, unix.RTM_DELROUTE:
if hdr.Seq <= MaxPeers && hdr.Seq > 0 {
if uint(len(remain)) < uint(hdr.Len) {
break
}
if hdr.Len > unix.SizeofNlMsghdr+unix.SizeofRtMsg {
attr := remain[unix.SizeofNlMsghdr+unix.SizeofRtMsg:]
for {
if uint(len(attr)) < uint(unix.SizeofRtAttr) {
break
}
attrhdr := *(*unix.RtAttr)(unsafe.Pointer(&attr[0]))
if attrhdr.Len < unix.SizeofRtAttr || uint(len(attr)) < uint(attrhdr.Len) {
break
}
if attrhdr.Type == unix.RTA_OIF && attrhdr.Len == unix.SizeofRtAttr+4 {
ifidx := *(*uint32)(unsafe.Pointer(&attr[unix.SizeofRtAttr]))
reqPeerLock.Lock()
if reqPeer == nil {
reqPeerLock.Unlock()
break
}
pePtr, ok := reqPeer[hdr.Seq]
reqPeerLock.Unlock()
if !ok {
break
}
pePtr.peer.Lock()
if &pePtr.peer.endpoint != pePtr.endpoint {
pePtr.peer.Unlock()
break
}
if uint32(pePtr.peer.endpoint.(*conn.NativeEndpoint).Src4().Ifindex) == ifidx {
pePtr.peer.Unlock()
break
}
pePtr.peer.endpoint.(*conn.NativeEndpoint).ClearSrc()
pePtr.peer.Unlock()
}
attr = attr[attrhdr.Len:]
}
}
break
}
reqPeerLock.Lock()
reqPeer = make(map[uint32]peerEndpointPtr)
reqPeerLock.Unlock()
go func() {
device.peers.RLock()
i := uint32(1)
for _, peer := range device.peers.keyMap {
peer.RLock()
if peer.endpoint == nil {
peer.RUnlock()
continue
}
nativeEP, _ := peer.endpoint.(*conn.NativeEndpoint)
if nativeEP == nil {
peer.RUnlock()
continue
}
if nativeEP.IsV6() || nativeEP.Src4().Ifindex == 0 {
peer.RUnlock()
break
}
nlmsg := struct {
hdr unix.NlMsghdr
msg unix.RtMsg
dsthdr unix.RtAttr
dst [4]byte
srchdr unix.RtAttr
src [4]byte
markhdr unix.RtAttr
mark uint32
}{
unix.NlMsghdr{
Type: uint16(unix.RTM_GETROUTE),
Flags: unix.NLM_F_REQUEST,
Seq: i,
},
unix.RtMsg{
Family: unix.AF_INET,
Dst_len: 32,
Src_len: 32,
},
unix.RtAttr{
Len: 8,
Type: unix.RTA_DST,
},
nativeEP.Dst4().Addr,
unix.RtAttr{
Len: 8,
Type: unix.RTA_SRC,
},
nativeEP.Src4().Src,
unix.RtAttr{
Len: 8,
Type: unix.RTA_MARK,
},
uint32(bind.LastMark()),
}
nlmsg.hdr.Len = uint32(unsafe.Sizeof(nlmsg))
reqPeerLock.Lock()
reqPeer[i] = peerEndpointPtr{
peer: peer,
endpoint: &peer.endpoint,
}
reqPeerLock.Unlock()
peer.RUnlock()
i++
_, err := netlinkCancel.Write((*[unsafe.Sizeof(nlmsg)]byte)(unsafe.Pointer(&nlmsg))[:])
if err != nil {
break
}
}
device.peers.RUnlock()
}()
}
remain = remain[hdr.Len:]
}
}
}
func createNetlinkRouteSocket() (int, error) {
sock, err := unix.Socket(unix.AF_NETLINK, unix.SOCK_RAW, unix.NETLINK_ROUTE)
if err != nil {
return -1, err
}
saddr := &unix.SockaddrNetlink{
Family: unix.AF_NETLINK,
Groups: uint32(1 << (unix.RTNLGRP_IPV4_ROUTE - 1)),
}
err = unix.Bind(sock, saddr)
if err != nil {
unix.Close(sock)
return -1, err
}
return sock, nil
}

View File

@ -15,6 +15,7 @@ import (
"sync/atomic" "sync/atomic"
"time" "time"
"golang.zx2c4.com/wireguard/conn"
"golang.zx2c4.com/wireguard/ipc" "golang.zx2c4.com/wireguard/ipc"
) )
@ -306,7 +307,7 @@ func (device *Device) IpcSetOperation(socket *bufio.Reader) *IPCError {
err := func() error { err := func() error {
peer.Lock() peer.Lock()
defer peer.Unlock() defer peer.Unlock()
endpoint, err := CreateEndpoint(value) endpoint, err := conn.CreateEndpoint(value)
if err != nil { if err != nil {
return err return err
} }