mirror of
https://github.com/netbirdio/netbird.git
synced 2024-11-25 17:43:38 +01:00
fd67892cb4
Refactor the flat code structure
370 lines
11 KiB
Go
370 lines
11 KiB
Go
package bind
|
|
|
|
/*
|
|
Most of this code was copied from https://github.com/pion/ice and modified to fulfill NetBird's requirements.
|
|
*/
|
|
|
|
import (
	"context"
	"errors"
	"fmt"
	"net"
	"net/netip"
	"sync"
	"time"

	log "github.com/sirupsen/logrus"

	"github.com/pion/logging"
	"github.com/pion/stun/v2"
	"github.com/pion/transport/v3"
)
|
|
|
|
// FilterFn decides whether a candidate address has to be filtered out.
//
// It reports true when the given address is to be filtered, together with the
// prefix of the route that matched the address. A non-nil error means the
// check itself could not be completed.
type FilterFn func(address netip.Addr) (isRouted bool, prefix netip.Prefix, err error)
|
|
|
|
// UniversalUDPMuxDefault handles STUN and TURN servers packets by wrapping the original UDPConn
|
|
// It then passes packets to the UDPMux that does the actual connection muxing.
|
|
type UniversalUDPMuxDefault struct {
|
|
*UDPMuxDefault
|
|
params UniversalUDPMuxParams
|
|
|
|
// since we have a shared socket, for srflx candidates it makes sense to have a shared mapped address across all the agents
|
|
// stun.XORMappedAddress indexed by the STUN server addr
|
|
xorMappedMap map[string]*xorMapped
|
|
}
|
|
|
|
// UniversalUDPMuxParams are parameters for UniversalUDPMux server reflexive.
|
|
type UniversalUDPMuxParams struct {
|
|
Logger logging.LeveledLogger
|
|
UDPConn net.PacketConn
|
|
XORMappedAddrCacheTTL time.Duration
|
|
Net transport.Net
|
|
FilterFn FilterFn
|
|
}
|
|
|
|
// NewUniversalUDPMuxDefault creates an implementation of UniversalUDPMux embedding UDPMux
|
|
func NewUniversalUDPMuxDefault(params UniversalUDPMuxParams) *UniversalUDPMuxDefault {
|
|
if params.Logger == nil {
|
|
params.Logger = logging.NewDefaultLoggerFactory().NewLogger("ice")
|
|
}
|
|
if params.XORMappedAddrCacheTTL == 0 {
|
|
params.XORMappedAddrCacheTTL = time.Second * 25
|
|
}
|
|
|
|
m := &UniversalUDPMuxDefault{
|
|
params: params,
|
|
xorMappedMap: make(map[string]*xorMapped),
|
|
}
|
|
|
|
// wrap UDP connection, process server reflexive messages
|
|
// before they are passed to the UDPMux connection handler (connWorker)
|
|
m.params.UDPConn = &udpConn{
|
|
PacketConn: params.UDPConn,
|
|
mux: m,
|
|
logger: params.Logger,
|
|
filterFn: params.FilterFn,
|
|
}
|
|
|
|
// embed UDPMux
|
|
udpMuxParams := UDPMuxParams{
|
|
Logger: params.Logger,
|
|
UDPConn: m.params.UDPConn,
|
|
Net: m.params.Net,
|
|
}
|
|
m.UDPMuxDefault = NewUDPMuxDefault(udpMuxParams)
|
|
|
|
return m
|
|
}
|
|
|
|
// ReadFromConn reads from the m.params.UDPConn provided upon the creation. It expects STUN packets only, however, will
|
|
// just ignore other packets printing an warning message.
|
|
// It is a blocking method, consider running in a go routine.
|
|
func (m *UniversalUDPMuxDefault) ReadFromConn(ctx context.Context) {
|
|
buf := make([]byte, 1500)
|
|
for {
|
|
select {
|
|
case <-ctx.Done():
|
|
log.Debugf("stopped reading from the UDPConn due to finished context")
|
|
return
|
|
default:
|
|
n, a, err := m.params.UDPConn.ReadFrom(buf)
|
|
if err != nil {
|
|
log.Errorf("error while reading packet: %s", err)
|
|
continue
|
|
}
|
|
msg := &stun.Message{
|
|
Raw: append([]byte{}, buf[:n]...),
|
|
}
|
|
err = msg.Decode()
|
|
if err != nil {
|
|
log.Warnf("error while parsing STUN message. The packet doesn't seem to be a STUN packet: %s", err)
|
|
continue
|
|
}
|
|
|
|
err = m.HandleSTUNMessage(msg, a)
|
|
if err != nil {
|
|
log.Errorf("error while handling STUn message: %s", err)
|
|
}
|
|
}
|
|
}
|
|
}
|
|
|
|
// udpConn is a wrapper around UDPMux conn that overrides ReadFrom and handles STUN/TURN packets
|
|
type udpConn struct {
|
|
net.PacketConn
|
|
mux *UniversalUDPMuxDefault
|
|
logger logging.LeveledLogger
|
|
filterFn FilterFn
|
|
// TODO: reset cache on route changes
|
|
addrCache sync.Map
|
|
}
|
|
|
|
func (u *udpConn) WriteTo(b []byte, addr net.Addr) (int, error) {
|
|
if u.filterFn == nil {
|
|
return u.PacketConn.WriteTo(b, addr)
|
|
}
|
|
|
|
if isRouted, found := u.addrCache.Load(addr.String()); found {
|
|
return u.handleCachedAddress(isRouted.(bool), b, addr)
|
|
}
|
|
|
|
return u.handleUncachedAddress(b, addr)
|
|
}
|
|
|
|
func (u *udpConn) handleCachedAddress(isRouted bool, b []byte, addr net.Addr) (int, error) {
|
|
if isRouted {
|
|
return 0, fmt.Errorf("address %s is part of a routed network, refusing to write", addr)
|
|
}
|
|
return u.PacketConn.WriteTo(b, addr)
|
|
}
|
|
|
|
func (u *udpConn) handleUncachedAddress(b []byte, addr net.Addr) (int, error) {
|
|
if err := u.performFilterCheck(addr); err != nil {
|
|
return 0, err
|
|
}
|
|
return u.PacketConn.WriteTo(b, addr)
|
|
}
|
|
|
|
func (u *udpConn) performFilterCheck(addr net.Addr) error {
|
|
host, err := getHostFromAddr(addr)
|
|
if err != nil {
|
|
log.Errorf("Failed to get host from address %s: %v", addr, err)
|
|
return nil
|
|
}
|
|
|
|
a, err := netip.ParseAddr(host)
|
|
if err != nil {
|
|
log.Errorf("Failed to parse address %s: %v", addr, err)
|
|
return nil
|
|
}
|
|
|
|
if isRouted, prefix, err := u.filterFn(a); err != nil {
|
|
log.Errorf("Failed to check if address %s is routed: %v", addr, err)
|
|
} else {
|
|
u.addrCache.Store(addr.String(), isRouted)
|
|
if isRouted {
|
|
// Extra log, as the error only shows up with ICE logging enabled
|
|
log.Infof("Address %s is part of routed network %s, refusing to write", addr, prefix)
|
|
return fmt.Errorf("address %s is part of routed network %s, refusing to write", addr, prefix)
|
|
}
|
|
}
|
|
return nil
|
|
}
|
|
|
|
// getHostFromAddr returns the host part of addr's "host:port" string form.
// The error from net.SplitHostPort is passed through unchanged.
func getHostFromAddr(addr net.Addr) (string, error) {
	hostPort := addr.String()
	host, _, err := net.SplitHostPort(hostPort)
	return host, err
}
|
|
|
|
// GetSharedConn returns the shared udp conn
|
|
func (m *UniversalUDPMuxDefault) GetSharedConn() net.PacketConn {
|
|
return m.params.UDPConn
|
|
}
|
|
|
|
// GetListenAddresses returns the listen addr of this UDP
|
|
func (m *UniversalUDPMuxDefault) GetListenAddresses() []net.Addr {
|
|
return []net.Addr{m.LocalAddr()}
|
|
}
|
|
|
|
// GetRelayedAddr creates relayed connection to the given TURN service and returns the relayed addr.
|
|
// Not implemented yet.
|
|
func (m *UniversalUDPMuxDefault) GetRelayedAddr(turnAddr net.Addr, deadline time.Duration) (*net.Addr, error) {
|
|
return nil, fmt.Errorf("not implemented yet")
|
|
}
|
|
|
|
// GetConnForURL add uniques to the muxed connection by concatenating ufrag and URL (e.g. STUN URL) to be able to support multiple STUN/TURN servers
|
|
// and return a unique connection per server.
|
|
func (m *UniversalUDPMuxDefault) GetConnForURL(ufrag string, url string, addr net.Addr) (net.PacketConn, error) {
|
|
return m.UDPMuxDefault.GetConn(fmt.Sprintf("%s%s", ufrag, url), addr)
|
|
}
|
|
|
|
// HandleSTUNMessage discovers STUN packets that carry a XOR mapped address from a STUN server.
|
|
// All other STUN packets will be forwarded to the UDPMux
|
|
func (m *UniversalUDPMuxDefault) HandleSTUNMessage(msg *stun.Message, addr net.Addr) error {
|
|
|
|
udpAddr, ok := addr.(*net.UDPAddr)
|
|
if !ok {
|
|
// message about this err will be logged in the UDPMux
|
|
return nil
|
|
}
|
|
|
|
if m.isXORMappedResponse(msg, udpAddr.String()) {
|
|
err := m.handleXORMappedResponse(udpAddr, msg)
|
|
if err != nil {
|
|
log.Debugf("%s: %v", fmt.Errorf("failed to get XOR-MAPPED-ADDRESS response"), err)
|
|
return nil
|
|
}
|
|
return nil
|
|
}
|
|
return m.UDPMuxDefault.HandleSTUNMessage(msg, addr)
|
|
}
|
|
|
|
// isXORMappedResponse indicates whether the message is a XORMappedAddress and is coming from the known STUN server.
|
|
func (m *UniversalUDPMuxDefault) isXORMappedResponse(msg *stun.Message, stunAddr string) bool {
|
|
m.mu.Lock()
|
|
defer m.mu.Unlock()
|
|
// check first if it is a STUN server address because remote peer can also send similar messages but as a BindingSuccess
|
|
_, ok := m.xorMappedMap[stunAddr]
|
|
_, err := msg.Get(stun.AttrXORMappedAddress)
|
|
return err == nil && ok
|
|
}
|
|
|
|
// handleXORMappedResponse parses response from the STUN server, extracts XORMappedAddress attribute
|
|
// and set the mapped address for the server
|
|
func (m *UniversalUDPMuxDefault) handleXORMappedResponse(stunAddr *net.UDPAddr, msg *stun.Message) error {
|
|
m.mu.Lock()
|
|
defer m.mu.Unlock()
|
|
|
|
mappedAddr, ok := m.xorMappedMap[stunAddr.String()]
|
|
if !ok {
|
|
return fmt.Errorf("no XOR address mapping")
|
|
}
|
|
|
|
var addr stun.XORMappedAddress
|
|
if err := addr.GetFrom(msg); err != nil {
|
|
return err
|
|
}
|
|
|
|
m.xorMappedMap[stunAddr.String()] = mappedAddr
|
|
mappedAddr.SetAddr(&addr)
|
|
|
|
return nil
|
|
}
|
|
|
|
// GetXORMappedAddr returns *stun.XORMappedAddress if already present for a given STUN server.
|
|
// Makes a STUN binding request to discover mapped address otherwise.
|
|
// Blocks until the stun.XORMappedAddress has been discovered or deadline.
|
|
// Method is safe for concurrent use.
|
|
func (m *UniversalUDPMuxDefault) GetXORMappedAddr(serverAddr net.Addr, deadline time.Duration) (*stun.XORMappedAddress, error) {
|
|
m.mu.Lock()
|
|
mappedAddr, ok := m.xorMappedMap[serverAddr.String()]
|
|
// if we already have a mapping for this STUN server (address already received)
|
|
// and if it is not too old we return it without making a new request to STUN server
|
|
if ok {
|
|
if mappedAddr.expired() {
|
|
mappedAddr.closeWaiters()
|
|
delete(m.xorMappedMap, serverAddr.String())
|
|
ok = false
|
|
} else if mappedAddr.pending() {
|
|
ok = false
|
|
}
|
|
}
|
|
m.mu.Unlock()
|
|
if ok {
|
|
return mappedAddr.addr, nil
|
|
}
|
|
|
|
// otherwise, make a STUN request to discover the address
|
|
// or wait for already sent request to complete
|
|
waitAddrReceived, err := m.sendSTUN(serverAddr)
|
|
if err != nil {
|
|
return nil, fmt.Errorf("%s: %s", "failed to send STUN packet", err)
|
|
}
|
|
|
|
// block until response was handled by the connWorker routine and XORMappedAddress was updated
|
|
select {
|
|
case <-waitAddrReceived:
|
|
// when channel closed, addr was obtained
|
|
var addr *stun.XORMappedAddress
|
|
m.mu.Lock()
|
|
// A very odd case that mappedAddr is nil.
|
|
// Can happen when the deadline property is larger than params.XORMappedAddrCacheTTL.
|
|
// Or when we don't receive a response to our m.sendSTUN request (the response is handled asynchronously) and
|
|
// the XORMapped expires meanwhile triggering a closure of the waitAddrReceived channel.
|
|
// We protect the code from panic here.
|
|
if mappedAddr, ok := m.xorMappedMap[serverAddr.String()]; ok {
|
|
addr = mappedAddr.addr
|
|
}
|
|
m.mu.Unlock()
|
|
if addr == nil {
|
|
return nil, fmt.Errorf("no XOR address mapping")
|
|
}
|
|
return addr, nil
|
|
case <-time.After(deadline):
|
|
return nil, fmt.Errorf("timeout while waiting for XORMappedAddr")
|
|
}
|
|
}
|
|
|
|
// sendSTUN sends a STUN request via UDP conn.
|
|
//
|
|
// The returned channel is closed when the STUN response has been received.
|
|
// Method is safe for concurrent use.
|
|
func (m *UniversalUDPMuxDefault) sendSTUN(serverAddr net.Addr) (chan struct{}, error) {
|
|
m.mu.Lock()
|
|
defer m.mu.Unlock()
|
|
|
|
// if record present in the map, we already sent a STUN request,
|
|
// just wait when waitAddrReceived will be closed
|
|
addrMap, ok := m.xorMappedMap[serverAddr.String()]
|
|
if !ok {
|
|
addrMap = &xorMapped{
|
|
expiresAt: time.Now().Add(m.params.XORMappedAddrCacheTTL),
|
|
waitAddrReceived: make(chan struct{}),
|
|
}
|
|
m.xorMappedMap[serverAddr.String()] = addrMap
|
|
}
|
|
|
|
req, err := stun.Build(stun.BindingRequest, stun.TransactionID)
|
|
if err != nil {
|
|
return nil, err
|
|
}
|
|
|
|
if _, err = m.params.UDPConn.WriteTo(req.Raw, serverAddr); err != nil {
|
|
return nil, err
|
|
}
|
|
|
|
return addrMap.waitAddrReceived, nil
|
|
}
|
|
|
|
type xorMapped struct {
|
|
addr *stun.XORMappedAddress
|
|
waitAddrReceived chan struct{}
|
|
expiresAt time.Time
|
|
}
|
|
|
|
func (a *xorMapped) closeWaiters() {
|
|
select {
|
|
case <-a.waitAddrReceived:
|
|
// notify was close, ok, that means we received duplicate response
|
|
// just exit
|
|
break
|
|
default:
|
|
// notify that twe have a new addr
|
|
close(a.waitAddrReceived)
|
|
}
|
|
}
|
|
|
|
func (a *xorMapped) pending() bool {
|
|
return a.addr == nil
|
|
}
|
|
|
|
func (a *xorMapped) expired() bool {
|
|
return a.expiresAt.Before(time.Now())
|
|
}
|
|
|
|
func (a *xorMapped) SetAddr(addr *stun.XORMappedAddress) {
|
|
a.addr = addr
|
|
a.closeWaiters()
|
|
}
|