mirror of
https://github.com/netbirdio/netbird.git
synced 2024-11-29 19:43:57 +01:00
2eeed55c18
This PR adds support for the WireGuard userspace implementation using the Bind interface from wireguard-go. The newly introduced ICEBind struct implements Bind with UDPMux-based structs from pion/ice to handle hole punching using ICE. The core implementation was taken from wireguard-go's StdBind. The result is a single WireGuard port that is used for both host and server-reflexive candidates. Relay candidates are still handled separately and will be integrated in follow-up PRs. ICEBind checks whether each incoming packet is a STUN or a WireGuard packet and routes it to UDPMux (to handle hole punching) or to WireGuard, respectively.
446 lines
11 KiB
Go
446 lines
11 KiB
Go
package bind
|
|
|
|
import (
|
|
"fmt"
|
|
"io"
|
|
"net"
|
|
"strings"
|
|
"sync"
|
|
|
|
"github.com/pion/ice/v2"
|
|
"github.com/pion/stun"
|
|
"github.com/pion/transport/v2/stdnet"
|
|
log "github.com/sirupsen/logrus"
|
|
|
|
"github.com/pion/logging"
|
|
"github.com/pion/transport/v2"
|
|
)
|
|
|
|
/*
|
|
Most of this code was copied from https://github.com/pion/ice and modified to fulfill NetBird's requirements
|
|
*/
|
|
|
|
const (
	// receiveMTU is the number of bytes reserved for one received packet in
	// each pooled receive buffer; the pool adds maxAddrSize on top of it.
	receiveMTU = 8 * 1024
)
|
|
|
|
// UDPMuxDefault is an implementation of the interface
|
|
type UDPMuxDefault struct {
|
|
params UDPMuxParams
|
|
|
|
closedChan chan struct{}
|
|
closeOnce sync.Once
|
|
|
|
// connsIPv4 and connsIPv6 are maps of all udpMuxedConn indexed by ufrag|network|candidateType
|
|
connsIPv4, connsIPv6 map[string]*udpMuxedConn
|
|
|
|
addressMapMu sync.RWMutex
|
|
addressMap map[string][]*udpMuxedConn
|
|
|
|
// buffer pool to recycle buffers for net.UDPAddr encodes/decodes
|
|
pool *sync.Pool
|
|
|
|
mu sync.Mutex
|
|
|
|
// for UDP connection listen at unspecified address
|
|
localAddrsForUnspecified []net.Addr
|
|
}
|
|
|
|
const (
	// maxAddrSize is the extra room, in bytes, added to each pooled receive
	// buffer on top of receiveMTU so an encoded address fits next to the
	// packet.
	maxAddrSize = 512
)
|
|
|
|
// UDPMuxParams are parameters for UDPMux.
|
|
type UDPMuxParams struct {
|
|
Logger logging.LeveledLogger
|
|
UDPConn net.PacketConn
|
|
|
|
// Required for gathering local addresses
|
|
// in case a un UDPConn is passed which does not
|
|
// bind to a specific local address.
|
|
Net transport.Net
|
|
InterfaceFilter func(interfaceName string) bool
|
|
}
|
|
|
|
func localInterfaces(n transport.Net, interfaceFilter func(string) bool, ipFilter func(net.IP) bool, networkTypes []ice.NetworkType, includeLoopback bool) ([]net.IP, error) { //nolint:gocognit
|
|
ips := []net.IP{}
|
|
ifaces, err := n.Interfaces()
|
|
if err != nil {
|
|
return ips, err
|
|
}
|
|
|
|
var IPv4Requested, IPv6Requested bool
|
|
for _, typ := range networkTypes {
|
|
if typ.IsIPv4() {
|
|
IPv4Requested = true
|
|
}
|
|
|
|
if typ.IsIPv6() {
|
|
IPv6Requested = true
|
|
}
|
|
}
|
|
|
|
for _, iface := range ifaces {
|
|
if iface.Flags&net.FlagUp == 0 {
|
|
continue // interface down
|
|
}
|
|
if (iface.Flags&net.FlagLoopback != 0) && !includeLoopback {
|
|
continue // loopback interface
|
|
}
|
|
|
|
if interfaceFilter != nil && !interfaceFilter(iface.Name) {
|
|
continue
|
|
}
|
|
|
|
addrs, err := iface.Addrs()
|
|
if err != nil {
|
|
continue
|
|
}
|
|
|
|
for _, addr := range addrs {
|
|
var ip net.IP
|
|
switch addr := addr.(type) {
|
|
case *net.IPNet:
|
|
ip = addr.IP
|
|
case *net.IPAddr:
|
|
ip = addr.IP
|
|
}
|
|
if ip == nil || (ip.IsLoopback() && !includeLoopback) {
|
|
continue
|
|
}
|
|
|
|
if ipv4 := ip.To4(); ipv4 == nil {
|
|
if !IPv6Requested {
|
|
continue
|
|
} else if !isSupportedIPv6(ip) {
|
|
continue
|
|
}
|
|
} else if !IPv4Requested {
|
|
continue
|
|
}
|
|
|
|
if ipFilter != nil && !ipFilter(ip) {
|
|
continue
|
|
}
|
|
|
|
ips = append(ips, ip)
|
|
}
|
|
}
|
|
return ips, nil
|
|
}
|
|
|
|
// isSupportedIPv6 reports whether ip is a 16-byte IPv6 address that may be
// gathered as a host candidate. It follows the exclusions of
// https://tools.ietf.org/html/rfc8445#section-5.1.1.1 and rejects:
//   - IPv4-compatible IPv6 addresses (first 12 bytes all zero);
//   - IPv6 site-local unicast addresses (fec0::/10);
//   - link-local unicast addresses;
//   - link-local multicast addresses.
func isSupportedIPv6(ip net.IP) bool {
	if len(ip) != net.IPv6len {
		return false
	}

	ipv4Compatible := isZeros(ip[:12])
	siteLocal := ip[0] == 0xfe && ip[1]&0xc0 == 0xc0
	linkLocal := ip.IsLinkLocalUnicast() || ip.IsLinkLocalMulticast()

	return !(ipv4Compatible || siteLocal || linkLocal)
}

// isZeros reports whether every byte of ip is zero. An empty slice counts as
// all zeros.
func isZeros(ip net.IP) bool {
	for _, b := range ip {
		if b != 0 {
			return false
		}
	}
	return true
}
|
|
|
|
// NewUDPMuxDefault creates an implementation of UDPMux
|
|
func NewUDPMuxDefault(params UDPMuxParams) *UDPMuxDefault {
|
|
if params.Logger == nil {
|
|
params.Logger = logging.NewDefaultLoggerFactory().NewLogger("ice")
|
|
}
|
|
|
|
var localAddrsForUnspecified []net.Addr
|
|
if addr, ok := params.UDPConn.LocalAddr().(*net.UDPAddr); !ok {
|
|
params.Logger.Errorf("LocalAddr is not a net.UDPAddr, got %T", params.UDPConn.LocalAddr())
|
|
} else if ok && addr.IP.IsUnspecified() {
|
|
// For unspecified addresses, the correct behavior is to return errListenUnspecified, but
|
|
// it will break the applications that are already using unspecified UDP connection
|
|
// with UDPMuxDefault, so print a warn log and create a local address list for mux.
|
|
params.Logger.Warn("UDPMuxDefault should not listening on unspecified address, use NewMultiUDPMuxFromPort instead")
|
|
var networks []ice.NetworkType
|
|
switch {
|
|
case addr.IP.To4() != nil:
|
|
networks = []ice.NetworkType{ice.NetworkTypeUDP4}
|
|
|
|
case addr.IP.To16() != nil:
|
|
networks = []ice.NetworkType{ice.NetworkTypeUDP4, ice.NetworkTypeUDP6}
|
|
|
|
default:
|
|
params.Logger.Errorf("LocalAddr expected IPV4 or IPV6, got %T", params.UDPConn.LocalAddr())
|
|
}
|
|
if len(networks) > 0 {
|
|
if params.Net == nil {
|
|
var err error
|
|
if params.Net, err = stdnet.NewNet(); err != nil {
|
|
params.Logger.Errorf("failed to get create network: %v", err)
|
|
}
|
|
}
|
|
|
|
ips, err := localInterfaces(params.Net, params.InterfaceFilter, nil, networks, true)
|
|
if err == nil {
|
|
for _, ip := range ips {
|
|
localAddrsForUnspecified = append(localAddrsForUnspecified, &net.UDPAddr{IP: ip, Port: addr.Port})
|
|
}
|
|
} else {
|
|
params.Logger.Errorf("failed to get local interfaces for unspecified addr: %v", err)
|
|
}
|
|
}
|
|
}
|
|
|
|
return &UDPMuxDefault{
|
|
addressMap: map[string][]*udpMuxedConn{},
|
|
params: params,
|
|
connsIPv4: make(map[string]*udpMuxedConn),
|
|
connsIPv6: make(map[string]*udpMuxedConn),
|
|
closedChan: make(chan struct{}, 1),
|
|
pool: &sync.Pool{
|
|
New: func() interface{} {
|
|
// big enough buffer to fit both packet and address
|
|
return newBufferHolder(receiveMTU + maxAddrSize)
|
|
},
|
|
},
|
|
localAddrsForUnspecified: localAddrsForUnspecified,
|
|
}
|
|
}
|
|
|
|
// LocalAddr returns the listening address of this UDPMuxDefault
|
|
func (m *UDPMuxDefault) LocalAddr() net.Addr {
|
|
return m.params.UDPConn.LocalAddr()
|
|
}
|
|
|
|
// GetListenAddresses returns the list of addresses that this mux is listening on
|
|
func (m *UDPMuxDefault) GetListenAddresses() []net.Addr {
|
|
if len(m.localAddrsForUnspecified) > 0 {
|
|
return m.localAddrsForUnspecified
|
|
}
|
|
|
|
return []net.Addr{m.LocalAddr()}
|
|
}
|
|
|
|
// GetConn returns a PacketConn given the connection's ufrag and network address
|
|
// creates the connection if an existing one can't be found
|
|
func (m *UDPMuxDefault) GetConn(ufrag string, addr net.Addr) (net.PacketConn, error) {
|
|
|
|
var isIPv6 bool
|
|
if udpAddr, _ := addr.(*net.UDPAddr); udpAddr != nil && udpAddr.IP.To4() == nil {
|
|
isIPv6 = true
|
|
}
|
|
m.mu.Lock()
|
|
defer m.mu.Unlock()
|
|
|
|
if m.IsClosed() {
|
|
return nil, io.ErrClosedPipe
|
|
}
|
|
|
|
if conn, ok := m.getConn(ufrag, isIPv6); ok {
|
|
return conn, nil
|
|
}
|
|
|
|
c := m.createMuxedConn(ufrag)
|
|
go func() {
|
|
<-c.CloseChannel()
|
|
m.RemoveConnByUfrag(ufrag)
|
|
}()
|
|
|
|
if isIPv6 {
|
|
m.connsIPv6[ufrag] = c
|
|
} else {
|
|
m.connsIPv4[ufrag] = c
|
|
}
|
|
|
|
return c, nil
|
|
}
|
|
|
|
// RemoveConnByUfrag stops and removes the muxed packet connection
|
|
func (m *UDPMuxDefault) RemoveConnByUfrag(ufrag string) {
|
|
removedConns := make([]*udpMuxedConn, 0, 2)
|
|
|
|
// Keep lock section small to avoid deadlock with conn lock
|
|
m.mu.Lock()
|
|
if c, ok := m.connsIPv4[ufrag]; ok {
|
|
delete(m.connsIPv4, ufrag)
|
|
removedConns = append(removedConns, c)
|
|
}
|
|
if c, ok := m.connsIPv6[ufrag]; ok {
|
|
delete(m.connsIPv6, ufrag)
|
|
removedConns = append(removedConns, c)
|
|
}
|
|
m.mu.Unlock()
|
|
|
|
if len(removedConns) == 0 {
|
|
// No need to lock if no connection was found
|
|
return
|
|
}
|
|
|
|
m.addressMapMu.Lock()
|
|
defer m.addressMapMu.Unlock()
|
|
|
|
for _, c := range removedConns {
|
|
addresses := c.getAddresses()
|
|
for _, addr := range addresses {
|
|
if connList, ok := m.addressMap[addr]; ok {
|
|
var newList []*udpMuxedConn
|
|
for _, conn := range connList {
|
|
if conn.params.Key != ufrag {
|
|
newList = append(newList, conn)
|
|
}
|
|
}
|
|
m.addressMap[addr] = newList
|
|
}
|
|
}
|
|
}
|
|
}
|
|
|
|
// IsClosed returns true if the mux had been closed
|
|
func (m *UDPMuxDefault) IsClosed() bool {
|
|
select {
|
|
case <-m.closedChan:
|
|
return true
|
|
default:
|
|
return false
|
|
}
|
|
}
|
|
|
|
// Close the mux, no further connections could be created
|
|
func (m *UDPMuxDefault) Close() error {
|
|
var err error
|
|
m.closeOnce.Do(func() {
|
|
m.mu.Lock()
|
|
defer m.mu.Unlock()
|
|
|
|
for _, c := range m.connsIPv4 {
|
|
_ = c.Close()
|
|
}
|
|
for _, c := range m.connsIPv6 {
|
|
_ = c.Close()
|
|
}
|
|
|
|
m.connsIPv4 = make(map[string]*udpMuxedConn)
|
|
m.connsIPv6 = make(map[string]*udpMuxedConn)
|
|
|
|
close(m.closedChan)
|
|
|
|
_ = m.params.UDPConn.Close()
|
|
})
|
|
return err
|
|
}
|
|
|
|
func (m *UDPMuxDefault) writeTo(buf []byte, rAddr net.Addr) (n int, err error) {
|
|
return m.params.UDPConn.WriteTo(buf, rAddr)
|
|
}
|
|
|
|
func (m *UDPMuxDefault) registerConnForAddress(conn *udpMuxedConn, addr string) {
|
|
if m.IsClosed() {
|
|
return
|
|
}
|
|
|
|
m.addressMapMu.Lock()
|
|
defer m.addressMapMu.Unlock()
|
|
|
|
existing, ok := m.addressMap[addr]
|
|
if !ok {
|
|
existing = []*udpMuxedConn{}
|
|
}
|
|
existing = append(existing, conn)
|
|
m.addressMap[addr] = existing
|
|
|
|
log.Debugf("ICE: registered %s for %s", addr, conn.params.Key)
|
|
}
|
|
|
|
func (m *UDPMuxDefault) createMuxedConn(key string) *udpMuxedConn {
|
|
c := newUDPMuxedConn(&udpMuxedConnParams{
|
|
Mux: m,
|
|
Key: key,
|
|
AddrPool: m.pool,
|
|
LocalAddr: m.LocalAddr(),
|
|
Logger: m.params.Logger,
|
|
})
|
|
return c
|
|
}
|
|
|
|
// HandleSTUNMessage handles STUN packets and forwards them to underlying pion/ice library
|
|
func (m *UDPMuxDefault) HandleSTUNMessage(msg *stun.Message, addr net.Addr) error {
|
|
|
|
remoteAddr, ok := addr.(*net.UDPAddr)
|
|
if !ok {
|
|
return fmt.Errorf("underlying PacketConn did not return a UDPAddr")
|
|
}
|
|
|
|
// If we have already seen this address dispatch to the appropriate destination
|
|
// If you are using the same socket for the Host and SRFLX candidates, it might be that there are more than one
|
|
// muxed connection - one for the SRFLX candidate and the other one for the HOST one.
|
|
// We will then forward STUN packets to each of these connections.
|
|
m.addressMapMu.Lock()
|
|
var destinationConnList []*udpMuxedConn
|
|
if storedConns, ok := m.addressMap[addr.String()]; ok {
|
|
destinationConnList = append(destinationConnList, storedConns...)
|
|
}
|
|
m.addressMapMu.Unlock()
|
|
|
|
var isIPv6 bool
|
|
if udpAddr, _ := addr.(*net.UDPAddr); udpAddr != nil && udpAddr.IP.To4() == nil {
|
|
isIPv6 = true
|
|
}
|
|
|
|
// This block is needed to discover Peer Reflexive Candidates for which we don't know the Endpoint upfront.
|
|
// However, we can take a username attribute from the STUN message which contains ufrag.
|
|
// We can use ufrag to identify the destination conn to route packet to.
|
|
attr, stunAttrErr := msg.Get(stun.AttrUsername)
|
|
if stunAttrErr == nil {
|
|
ufrag := strings.Split(string(attr), ":")[0]
|
|
|
|
m.mu.Lock()
|
|
destinationConn := m.connsIPv4[ufrag]
|
|
if isIPv6 {
|
|
destinationConn = m.connsIPv6[ufrag]
|
|
}
|
|
|
|
if destinationConn != nil {
|
|
exists := false
|
|
for _, conn := range destinationConnList {
|
|
if conn.params.Key == destinationConn.params.Key {
|
|
exists = true
|
|
break
|
|
}
|
|
}
|
|
if !exists {
|
|
destinationConnList = append(destinationConnList, destinationConn)
|
|
}
|
|
}
|
|
m.mu.Unlock()
|
|
}
|
|
|
|
// Forward STUN packets to each destination connections even thought the STUN packet might not belong there.
|
|
// It will be discarded by the further ICE candidate logic if so.
|
|
for _, conn := range destinationConnList {
|
|
if err := conn.writePacket(msg.Raw, remoteAddr); err != nil {
|
|
log.Errorf("could not write packet: %v", err)
|
|
}
|
|
}
|
|
|
|
return nil
|
|
}
|
|
|
|
func (m *UDPMuxDefault) getConn(ufrag string, isIPv6 bool) (val *udpMuxedConn, ok bool) {
|
|
if isIPv6 {
|
|
val, ok = m.connsIPv6[ufrag]
|
|
} else {
|
|
val, ok = m.connsIPv4[ufrag]
|
|
}
|
|
return
|
|
}
|
|
|
|
// bufferHolder wraps a reusable byte buffer. UDPMuxDefault's sync.Pool
// creates instances through newBufferHolder.
type bufferHolder struct {
	buf []byte
}

// newBufferHolder returns a bufferHolder whose buf is a zeroed slice with
// length and capacity equal to size.
func newBufferHolder(size int) *bufferHolder {
	holder := new(bufferHolder)
	holder.buf = make([]byte, size)
	return holder
}
|