Process routes before peers (#2105)

This commit is contained in:
Viktor Liu 2024-06-19 12:12:11 +02:00 committed by GitHub
parent b679404618
commit 61bc092458
No known key found for this signature in database
GPG Key ID: B5690EEEBB952194
22 changed files with 165 additions and 93 deletions

View File

@ -265,7 +265,7 @@ func TestUpdateDNSServer(t *testing.T) {
if err != nil { if err != nil {
t.Fatal(err) t.Fatal(err)
} }
wgIface, err := iface.NewWGIFace(fmt.Sprintf("utun230%d", n), fmt.Sprintf("100.66.100.%d/32", n+1), 33100, privKey.String(), iface.DefaultMTU, newNet, nil) wgIface, err := iface.NewWGIFace(fmt.Sprintf("utun230%d", n), fmt.Sprintf("100.66.100.%d/32", n+1), 33100, privKey.String(), iface.DefaultMTU, newNet, nil, nil)
if err != nil { if err != nil {
t.Fatal(err) t.Fatal(err)
} }
@ -343,7 +343,7 @@ func TestDNSFakeResolverHandleUpdates(t *testing.T) {
} }
privKey, _ := wgtypes.GeneratePrivateKey() privKey, _ := wgtypes.GeneratePrivateKey()
wgIface, err := iface.NewWGIFace("utun2301", "100.66.100.1/32", 33100, privKey.String(), iface.DefaultMTU, newNet, nil) wgIface, err := iface.NewWGIFace("utun2301", "100.66.100.1/32", 33100, privKey.String(), iface.DefaultMTU, newNet, nil, nil)
if err != nil { if err != nil {
t.Errorf("build interface wireguard: %v", err) t.Errorf("build interface wireguard: %v", err)
return return
@ -801,7 +801,7 @@ func createWgInterfaceWithBind(t *testing.T) (*iface.WGIface, error) {
} }
privKey, _ := wgtypes.GeneratePrivateKey() privKey, _ := wgtypes.GeneratePrivateKey()
wgIface, err := iface.NewWGIFace("utun2301", "100.66.100.2/24", 33100, privKey.String(), iface.DefaultMTU, newNet, nil) wgIface, err := iface.NewWGIFace("utun2301", "100.66.100.2/24", 33100, privKey.String(), iface.DefaultMTU, newNet, nil, nil)
if err != nil { if err != nil {
t.Fatalf("build interface wireguard: %v", err) t.Fatalf("build interface wireguard: %v", err)
return nil, err return nil, err

View File

@ -29,6 +29,7 @@ import (
"github.com/netbirdio/netbird/client/internal/relay" "github.com/netbirdio/netbird/client/internal/relay"
"github.com/netbirdio/netbird/client/internal/rosenpass" "github.com/netbirdio/netbird/client/internal/rosenpass"
"github.com/netbirdio/netbird/client/internal/routemanager" "github.com/netbirdio/netbird/client/internal/routemanager"
"github.com/netbirdio/netbird/client/internal/routemanager/systemops"
"github.com/netbirdio/netbird/client/internal/wgproxy" "github.com/netbirdio/netbird/client/internal/wgproxy"
nbssh "github.com/netbirdio/netbird/client/ssh" nbssh "github.com/netbirdio/netbird/client/ssh"
"github.com/netbirdio/netbird/client/system" "github.com/netbirdio/netbird/client/system"
@ -735,6 +736,20 @@ func (e *Engine) updateNetworkMap(networkMap *mgmProto.NetworkMap) error {
return nil return nil
} }
protoRoutes := networkMap.GetRoutes()
if protoRoutes == nil {
protoRoutes = []*mgmProto.Route{}
}
_, clientRoutes, err := e.routeManager.UpdateRoutes(serial, toRoutes(protoRoutes))
if err != nil {
log.Errorf("failed to update clientRoutes, err: %v", err)
}
e.clientRoutesMu.Lock()
e.clientRoutes = clientRoutes
e.clientRoutesMu.Unlock()
log.Debugf("got peers update from Management Service, total peers to connect to = %d", len(networkMap.GetRemotePeers())) log.Debugf("got peers update from Management Service, total peers to connect to = %d", len(networkMap.GetRemotePeers()))
e.updateOfflinePeers(networkMap.GetOfflinePeers()) e.updateOfflinePeers(networkMap.GetOfflinePeers())
@ -776,19 +791,6 @@ func (e *Engine) updateNetworkMap(networkMap *mgmProto.NetworkMap) error {
} }
} }
} }
protoRoutes := networkMap.GetRoutes()
if protoRoutes == nil {
protoRoutes = []*mgmProto.Route{}
}
_, clientRoutes, err := e.routeManager.UpdateRoutes(serial, toRoutes(protoRoutes))
if err != nil {
log.Errorf("failed to update clientRoutes, err: %v", err)
}
e.clientRoutesMu.Lock()
e.clientRoutes = clientRoutes
e.clientRoutesMu.Unlock()
protoDNSConfig := networkMap.GetDNSConfig() protoDNSConfig := networkMap.GetDNSConfig()
if protoDNSConfig == nil { if protoDNSConfig == nil {
@ -1287,7 +1289,7 @@ func (e *Engine) newWgIface() (*iface.WGIface, error) {
default: default:
} }
return iface.NewWGIFace(e.config.WgIfaceName, e.config.WgAddr, e.config.WgPort, e.config.WgPrivateKey.String(), iface.DefaultMTU, transportNet, mArgs) return iface.NewWGIFace(e.config.WgIfaceName, e.config.WgAddr, e.config.WgPort, e.config.WgPrivateKey.String(), iface.DefaultMTU, transportNet, mArgs, e.addrViaRoutes)
} }
func (e *Engine) wgInterfaceCreate() (err error) { func (e *Engine) wgInterfaceCreate() (err error) {
@ -1485,6 +1487,21 @@ func (e *Engine) startNetworkMonitor() {
}() }()
} }
func (e *Engine) addrViaRoutes(addr netip.Addr) (bool, netip.Prefix, error) {
var vpnRoutes []netip.Prefix
for _, routes := range e.GetClientRoutes() {
if len(routes) > 0 && routes[0] != nil {
vpnRoutes = append(vpnRoutes, routes[0].Network)
}
}
if isVpn, prefix := systemops.IsAddrRouted(addr, vpnRoutes); isVpn {
return true, prefix, nil
}
return false, netip.Prefix{}, nil
}
// isChecksEqual checks if two slices of checks are equal. // isChecksEqual checks if two slices of checks are equal.
func isChecksEqual(checks []*mgmProto.Checks, oChecks []*mgmProto.Checks) bool { func isChecksEqual(checks []*mgmProto.Checks, oChecks []*mgmProto.Checks) bool {
return slices.EqualFunc(checks, oChecks, func(checks, oChecks *mgmProto.Checks) bool { return slices.EqualFunc(checks, oChecks, func(checks, oChecks *mgmProto.Checks) bool {

View File

@ -217,7 +217,7 @@ func TestEngine_UpdateNetworkMap(t *testing.T) {
if err != nil { if err != nil {
t.Fatal(err) t.Fatal(err)
} }
engine.wgInterface, err = iface.NewWGIFace("utun102", "100.64.0.1/24", engine.config.WgPort, key.String(), iface.DefaultMTU, newNet, nil) engine.wgInterface, err = iface.NewWGIFace("utun102", "100.64.0.1/24", engine.config.WgPort, key.String(), iface.DefaultMTU, newNet, nil, nil)
if err != nil { if err != nil {
t.Fatal(err) t.Fatal(err)
} }
@ -574,7 +574,7 @@ func TestEngine_UpdateNetworkMapWithRoutes(t *testing.T) {
if err != nil { if err != nil {
t.Fatal(err) t.Fatal(err)
} }
engine.wgInterface, err = iface.NewWGIFace(wgIfaceName, wgAddr, engine.config.WgPort, key.String(), iface.DefaultMTU, newNet, nil) engine.wgInterface, err = iface.NewWGIFace(wgIfaceName, wgAddr, engine.config.WgPort, key.String(), iface.DefaultMTU, newNet, nil, nil)
assert.NoError(t, err, "shouldn't return error") assert.NoError(t, err, "shouldn't return error")
input := struct { input := struct {
inputSerial uint64 inputSerial uint64
@ -745,7 +745,7 @@ func TestEngine_UpdateNetworkMapWithDNSUpdate(t *testing.T) {
if err != nil { if err != nil {
t.Fatal(err) t.Fatal(err)
} }
engine.wgInterface, err = iface.NewWGIFace(wgIfaceName, wgAddr, 33100, key.String(), iface.DefaultMTU, newNet, nil) engine.wgInterface, err = iface.NewWGIFace(wgIfaceName, wgAddr, 33100, key.String(), iface.DefaultMTU, newNet, nil, nil)
assert.NoError(t, err, "shouldn't return error") assert.NoError(t, err, "shouldn't return error")
mockRouteManager := &routemanager.MockManager{ mockRouteManager := &routemanager.MockManager{

View File

@ -4,7 +4,6 @@ import (
"context" "context"
"fmt" "fmt"
"net" "net"
"net/netip"
"runtime" "runtime"
"strings" "strings"
"sync" "sync"
@ -15,7 +14,6 @@ import (
log "github.com/sirupsen/logrus" log "github.com/sirupsen/logrus"
"golang.zx2c4.com/wireguard/wgctrl/wgtypes" "golang.zx2c4.com/wireguard/wgctrl/wgtypes"
"github.com/netbirdio/netbird/client/internal/routemanager/systemops"
"github.com/netbirdio/netbird/client/internal/stdnet" "github.com/netbirdio/netbird/client/internal/stdnet"
"github.com/netbirdio/netbird/client/internal/wgproxy" "github.com/netbirdio/netbird/client/internal/wgproxy"
"github.com/netbirdio/netbird/iface" "github.com/netbirdio/netbird/iface"
@ -763,10 +761,6 @@ func (conn *Conn) OnRemoteCandidate(candidate ice.Candidate, haRoutes route.HAMa
return return
} }
if candidateViaRoutes(candidate, haRoutes) {
return
}
err := conn.agent.AddRemoteCandidate(candidate) err := conn.agent.AddRemoteCandidate(candidate)
if err != nil { if err != nil {
log.Errorf("error while handling remote candidate from peer %s", conn.config.Key) log.Errorf("error while handling remote candidate from peer %s", conn.config.Key)
@ -797,25 +791,3 @@ func extraSrflxCandidate(candidate ice.Candidate) (*ice.CandidateServerReflexive
RelPort: relatedAdd.Port, RelPort: relatedAdd.Port,
}) })
} }
func candidateViaRoutes(candidate ice.Candidate, clientRoutes route.HAMap) bool {
var vpnRoutes []netip.Prefix
for _, routes := range clientRoutes {
if len(routes) > 0 && routes[0] != nil {
vpnRoutes = append(vpnRoutes, routes[0].Network)
}
}
addr, err := netip.ParseAddr(candidate.Address())
if err != nil {
log.Errorf("Failed to parse IP address %s: %v", candidate.Address(), err)
return false
}
if isVpn, prefix := systemops.IsAddrRouted(addr, vpnRoutes); isVpn {
log.Debugf("Ignoring candidate [%s], its address is routed to network %s", candidate.String(), prefix)
return true
}
return false
}

View File

@ -407,7 +407,7 @@ func TestManagerUpdateRoutes(t *testing.T) {
if err != nil { if err != nil {
t.Fatal(err) t.Fatal(err)
} }
wgInterface, err := iface.NewWGIFace(fmt.Sprintf("utun43%d", n), "100.65.65.2/24", 33100, peerPrivateKey.String(), iface.DefaultMTU, newNet, nil) wgInterface, err := iface.NewWGIFace(fmt.Sprintf("utun43%d", n), "100.65.65.2/24", 33100, peerPrivateKey.String(), iface.DefaultMTU, newNet, nil, nil)
require.NoError(t, err, "should create testing WGIface interface") require.NoError(t, err, "should create testing WGIface interface")
defer wgInterface.Close() defer wgInterface.Close()

View File

@ -154,7 +154,8 @@ func (r *SysOps) addRouteToNonVPNIntf(prefix netip.Prefix, vpnIntf *iface.WGIfac
// if next hop is the VPN address or the interface is the VPN interface, we should use the initial values // if next hop is the VPN address or the interface is the VPN interface, we should use the initial values
if exitNextHop.IP == vpnAddr || exitNextHop.Intf != nil && exitNextHop.Intf.Name == vpnIntf.Name() { if exitNextHop.IP == vpnAddr || exitNextHop.Intf != nil && exitNextHop.Intf.Name == vpnIntf.Name() {
log.Debugf("Route for prefix %s is pointing to the VPN interface", prefix) log.Debugf("Route for prefix %s is pointing to the VPN interface, using initial next hop %v", prefix, initialNextHop)
exitNextHop = initialNextHop exitNextHop = initialNextHop
} }

View File

@ -61,7 +61,7 @@ func TestAddRemoveRoutes(t *testing.T) {
if err != nil { if err != nil {
t.Fatal(err) t.Fatal(err)
} }
wgInterface, err := iface.NewWGIFace(fmt.Sprintf("utun53%d", n), "100.65.75.2/24", 33100, peerPrivateKey.String(), iface.DefaultMTU, newNet, nil) wgInterface, err := iface.NewWGIFace(fmt.Sprintf("utun53%d", n), "100.65.75.2/24", 33100, peerPrivateKey.String(), iface.DefaultMTU, newNet, nil, nil)
require.NoError(t, err, "should create testing WGIface interface") require.NoError(t, err, "should create testing WGIface interface")
defer wgInterface.Close() defer wgInterface.Close()
@ -213,7 +213,7 @@ func TestAddExistAndRemoveRoute(t *testing.T) {
if err != nil { if err != nil {
t.Fatal(err) t.Fatal(err)
} }
wgInterface, err := iface.NewWGIFace(fmt.Sprintf("utun53%d", n), "100.65.75.2/24", 33100, peerPrivateKey.String(), iface.DefaultMTU, newNet, nil) wgInterface, err := iface.NewWGIFace(fmt.Sprintf("utun53%d", n), "100.65.75.2/24", 33100, peerPrivateKey.String(), iface.DefaultMTU, newNet, nil, nil)
require.NoError(t, err, "should create testing WGIface interface") require.NoError(t, err, "should create testing WGIface interface")
defer wgInterface.Close() defer wgInterface.Close()
@ -345,7 +345,7 @@ func createWGInterface(t *testing.T, interfaceName, ipAddressCIDR string, listen
newNet, err := stdnet.NewNet() newNet, err := stdnet.NewNet()
require.NoError(t, err) require.NoError(t, err)
wgInterface, err := iface.NewWGIFace(interfaceName, ipAddressCIDR, listenPort, peerPrivateKey.String(), iface.DefaultMTU, newNet, nil) wgInterface, err := iface.NewWGIFace(interfaceName, ipAddressCIDR, listenPort, peerPrivateKey.String(), iface.DefaultMTU, newNet, nil, nil)
require.NoError(t, err, "should create testing WireGuard interface") require.NoError(t, err, "should create testing WireGuard interface")
err = wgInterface.Create() err = wgInterface.Create()

View File

@ -28,11 +28,14 @@ type ICEBind struct {
transportNet transport.Net transportNet transport.Net
udpMux *UniversalUDPMuxDefault udpMux *UniversalUDPMuxDefault
filterFn FilterFn
} }
func NewICEBind(transportNet transport.Net) *ICEBind { func NewICEBind(transportNet transport.Net, filterFn FilterFn) *ICEBind {
ib := &ICEBind{ ib := &ICEBind{
transportNet: transportNet, transportNet: transportNet,
filterFn: filterFn,
} }
rc := receiverCreator{ rc := receiverCreator{
@ -59,8 +62,9 @@ func (s *ICEBind) createIPv4ReceiverFn(ipv4MsgsPool *sync.Pool, pc *ipv4.PacketC
s.udpMux = NewUniversalUDPMuxDefault( s.udpMux = NewUniversalUDPMuxDefault(
UniversalUDPMuxParams{ UniversalUDPMuxParams{
UDPConn: conn, UDPConn: conn,
Net: s.transportNet, Net: s.transportNet,
FilterFn: s.filterFn,
}, },
) )
return func(bufs [][]byte, sizes []int, eps []wgConn.Endpoint) (n int, err error) { return func(bufs [][]byte, sizes []int, eps []wgConn.Endpoint) (n int, err error) {

View File

@ -8,6 +8,8 @@ import (
"context" "context"
"fmt" "fmt"
"net" "net"
"net/netip"
"sync"
"time" "time"
log "github.com/sirupsen/logrus" log "github.com/sirupsen/logrus"
@ -17,6 +19,10 @@ import (
"github.com/pion/transport/v3" "github.com/pion/transport/v3"
) )
// FilterFn is a function that filters out candidates based on the address.
// If it returns true, the address is to be filtered. It also returns the prefix of matching route.
type FilterFn func(address netip.Addr) (bool, netip.Prefix, error)
// UniversalUDPMuxDefault handles STUN and TURN servers packets by wrapping the original UDPConn // UniversalUDPMuxDefault handles STUN and TURN servers packets by wrapping the original UDPConn
// It then passes packets to the UDPMux that does the actual connection muxing. // It then passes packets to the UDPMux that does the actual connection muxing.
type UniversalUDPMuxDefault struct { type UniversalUDPMuxDefault struct {
@ -34,6 +40,7 @@ type UniversalUDPMuxParams struct {
UDPConn net.PacketConn UDPConn net.PacketConn
XORMappedAddrCacheTTL time.Duration XORMappedAddrCacheTTL time.Duration
Net transport.Net Net transport.Net
FilterFn FilterFn
} }
// NewUniversalUDPMuxDefault creates an implementation of UniversalUDPMux embedding UDPMux // NewUniversalUDPMuxDefault creates an implementation of UniversalUDPMux embedding UDPMux
@ -56,6 +63,7 @@ func NewUniversalUDPMuxDefault(params UniversalUDPMuxParams) *UniversalUDPMuxDef
PacketConn: params.UDPConn, PacketConn: params.UDPConn,
mux: m, mux: m,
logger: params.Logger, logger: params.Logger,
filterFn: params.FilterFn,
} }
// embed UDPMux // embed UDPMux
@ -105,8 +113,68 @@ func (m *UniversalUDPMuxDefault) ReadFromConn(ctx context.Context) {
// udpConn is a wrapper around UDPMux conn that overrides ReadFrom and handles STUN/TURN packets // udpConn is a wrapper around UDPMux conn that overrides ReadFrom and handles STUN/TURN packets
type udpConn struct { type udpConn struct {
net.PacketConn net.PacketConn
mux *UniversalUDPMuxDefault mux *UniversalUDPMuxDefault
logger logging.LeveledLogger logger logging.LeveledLogger
filterFn FilterFn
// TODO: reset cache on route changes
addrCache sync.Map
}
func (u *udpConn) WriteTo(b []byte, addr net.Addr) (int, error) {
if u.filterFn == nil {
return u.PacketConn.WriteTo(b, addr)
}
if isRouted, found := u.addrCache.Load(addr.String()); found {
return u.handleCachedAddress(isRouted.(bool), b, addr)
}
return u.handleUncachedAddress(b, addr)
}
func (u *udpConn) handleCachedAddress(isRouted bool, b []byte, addr net.Addr) (int, error) {
if isRouted {
return 0, fmt.Errorf("address %s is part of a routed network, refusing to write", addr)
}
return u.PacketConn.WriteTo(b, addr)
}
func (u *udpConn) handleUncachedAddress(b []byte, addr net.Addr) (int, error) {
if err := u.performFilterCheck(addr); err != nil {
return 0, err
}
return u.PacketConn.WriteTo(b, addr)
}
func (u *udpConn) performFilterCheck(addr net.Addr) error {
host, err := getHostFromAddr(addr)
if err != nil {
log.Errorf("Failed to get host from address %s: %v", addr, err)
return nil
}
a, err := netip.ParseAddr(host)
if err != nil {
log.Errorf("Failed to parse address %s: %v", addr, err)
return nil
}
if isRouted, prefix, err := u.filterFn(a); err != nil {
log.Errorf("Failed to check if address %s is routed: %v", addr, err)
} else {
u.addrCache.Store(addr.String(), isRouted)
if isRouted {
// Extra log, as the error only shows up with ICE logging enabled
log.Infof("Address %s is part of routed network %s, refusing to write", addr, prefix)
return fmt.Errorf("address %s is part of routed network %s, refusing to write", addr, prefix)
}
}
return nil
}
func getHostFromAddr(addr net.Addr) (string, error) {
host, _, err := net.SplitHostPort(addr.String())
return host, err
} }
// GetSharedConn returns the shared udp conn // GetSharedConn returns the shared udp conn

View File

@ -4,17 +4,19 @@ import (
"fmt" "fmt"
"github.com/pion/transport/v3" "github.com/pion/transport/v3"
"github.com/netbirdio/netbird/iface/bind"
) )
// NewWGIFace Creates a new WireGuard interface instance // NewWGIFace Creates a new WireGuard interface instance
func NewWGIFace(iFaceName string, address string, wgPort int, wgPrivKey string, mtu int, transportNet transport.Net, args *MobileIFaceArguments) (*WGIface, error) { func NewWGIFace(iFaceName string, address string, wgPort int, wgPrivKey string, mtu int, transportNet transport.Net, args *MobileIFaceArguments, filterFn bind.FilterFn) (*WGIface, error) {
wgAddress, err := parseWGAddress(address) wgAddress, err := parseWGAddress(address)
if err != nil { if err != nil {
return nil, err return nil, err
} }
wgIFace := &WGIface{ wgIFace := &WGIface{
tun: newTunDevice(wgAddress, wgPort, wgPrivKey, mtu, transportNet, args.TunAdapter), tun: newTunDevice(wgAddress, wgPort, wgPrivKey, mtu, transportNet, args.TunAdapter, filterFn),
userspaceBind: true, userspaceBind: true,
} }
return wgIFace, nil return wgIFace, nil

View File

@ -7,11 +7,12 @@ import (
"github.com/pion/transport/v3" "github.com/pion/transport/v3"
"github.com/netbirdio/netbird/iface/bind"
"github.com/netbirdio/netbird/iface/netstack" "github.com/netbirdio/netbird/iface/netstack"
) )
// NewWGIFace Creates a new WireGuard interface instance // NewWGIFace Creates a new WireGuard interface instance
func NewWGIFace(iFaceName string, address string, wgPort int, wgPrivKey string, mtu int, transportNet transport.Net, args *MobileIFaceArguments) (*WGIface, error) { func NewWGIFace(iFaceName string, address string, wgPort int, wgPrivKey string, mtu int, transportNet transport.Net, _ *MobileIFaceArguments, filterFn bind.FilterFn) (*WGIface, error) {
wgAddress, err := parseWGAddress(address) wgAddress, err := parseWGAddress(address)
if err != nil { if err != nil {
return nil, err return nil, err
@ -22,11 +23,11 @@ func NewWGIFace(iFaceName string, address string, wgPort int, wgPrivKey string,
} }
if netstack.IsEnabled() { if netstack.IsEnabled() {
wgIFace.tun = newTunNetstackDevice(iFaceName, wgAddress, wgPort, wgPrivKey, mtu, transportNet, netstack.ListenAddr()) wgIFace.tun = newTunNetstackDevice(iFaceName, wgAddress, wgPort, wgPrivKey, mtu, transportNet, netstack.ListenAddr(), filterFn)
return wgIFace, nil return wgIFace, nil
} }
wgIFace.tun = newTunDevice(iFaceName, wgAddress, wgPort, wgPrivKey, mtu, transportNet) wgIFace.tun = newTunDevice(iFaceName, wgAddress, wgPort, wgPrivKey, mtu, transportNet, filterFn)
return wgIFace, nil return wgIFace, nil
} }

View File

@ -6,16 +6,18 @@ import (
"fmt" "fmt"
"github.com/pion/transport/v3" "github.com/pion/transport/v3"
"github.com/netbirdio/netbird/iface/bind"
) )
// NewWGIFace Creates a new WireGuard interface instance // NewWGIFace Creates a new WireGuard interface instance
func NewWGIFace(iFaceName string, address string, wgPort int, wgPrivKey string, mtu int, transportNet transport.Net, args *MobileIFaceArguments) (*WGIface, error) { func NewWGIFace(iFaceName string, address string, wgPort int, wgPrivKey string, mtu int, transportNet transport.Net, args *MobileIFaceArguments, filterFn bind.FilterFn) (*WGIface, error) {
wgAddress, err := parseWGAddress(address) wgAddress, err := parseWGAddress(address)
if err != nil { if err != nil {
return nil, err return nil, err
} }
wgIFace := &WGIface{ wgIFace := &WGIface{
tun: newTunDevice(iFaceName, wgAddress, wgPort, wgPrivKey, transportNet, args.TunFd), tun: newTunDevice(iFaceName, wgAddress, wgPort, wgPrivKey, transportNet, args.TunFd, filterFn),
userspaceBind: true, userspaceBind: true,
} }
return wgIFace, nil return wgIFace, nil

View File

@ -41,7 +41,7 @@ func TestWGIface_UpdateAddr(t *testing.T) {
t.Fatal(err) t.Fatal(err)
} }
iface, err := NewWGIFace(ifaceName, addr, wgPort, key, DefaultMTU, newNet, nil) iface, err := NewWGIFace(ifaceName, addr, wgPort, key, DefaultMTU, newNet, nil, nil)
if err != nil { if err != nil {
t.Fatal(err) t.Fatal(err)
} }
@ -114,7 +114,7 @@ func Test_CreateInterface(t *testing.T) {
if err != nil { if err != nil {
t.Fatal(err) t.Fatal(err)
} }
iface, err := NewWGIFace(ifaceName, wgIP, 33100, key, DefaultMTU, newNet, nil) iface, err := NewWGIFace(ifaceName, wgIP, 33100, key, DefaultMTU, newNet, nil, nil)
if err != nil { if err != nil {
t.Fatal(err) t.Fatal(err)
} }
@ -149,7 +149,7 @@ func Test_Close(t *testing.T) {
t.Fatal(err) t.Fatal(err)
} }
iface, err := NewWGIFace(ifaceName, wgIP, wgPort, key, DefaultMTU, newNet, nil) iface, err := NewWGIFace(ifaceName, wgIP, wgPort, key, DefaultMTU, newNet, nil, nil)
if err != nil { if err != nil {
t.Fatal(err) t.Fatal(err)
} }
@ -182,7 +182,7 @@ func Test_ConfigureInterface(t *testing.T) {
if err != nil { if err != nil {
t.Fatal(err) t.Fatal(err)
} }
iface, err := NewWGIFace(ifaceName, wgIP, wgPort, key, DefaultMTU, newNet, nil) iface, err := NewWGIFace(ifaceName, wgIP, wgPort, key, DefaultMTU, newNet, nil, nil)
if err != nil { if err != nil {
t.Fatal(err) t.Fatal(err)
} }
@ -230,7 +230,7 @@ func Test_UpdatePeer(t *testing.T) {
t.Fatal(err) t.Fatal(err)
} }
iface, err := NewWGIFace(ifaceName, wgIP, 33100, key, DefaultMTU, newNet, nil) iface, err := NewWGIFace(ifaceName, wgIP, 33100, key, DefaultMTU, newNet, nil, nil)
if err != nil { if err != nil {
t.Fatal(err) t.Fatal(err)
} }
@ -291,7 +291,7 @@ func Test_RemovePeer(t *testing.T) {
t.Fatal(err) t.Fatal(err)
} }
iface, err := NewWGIFace(ifaceName, wgIP, 33100, key, DefaultMTU, newNet, nil) iface, err := NewWGIFace(ifaceName, wgIP, 33100, key, DefaultMTU, newNet, nil, nil)
if err != nil { if err != nil {
t.Fatal(err) t.Fatal(err)
} }
@ -345,7 +345,7 @@ func Test_ConnectPeers(t *testing.T) {
t.Fatal(err) t.Fatal(err)
} }
iface1, err := NewWGIFace(peer1ifaceName, peer1wgIP, peer1wgPort, peer1Key.String(), DefaultMTU, newNet, nil) iface1, err := NewWGIFace(peer1ifaceName, peer1wgIP, peer1wgPort, peer1Key.String(), DefaultMTU, newNet, nil, nil)
if err != nil { if err != nil {
t.Fatal(err) t.Fatal(err)
} }
@ -368,7 +368,7 @@ func Test_ConnectPeers(t *testing.T) {
if err != nil { if err != nil {
t.Fatal(err) t.Fatal(err)
} }
iface2, err := NewWGIFace(peer2ifaceName, peer2wgIP, peer2wgPort, peer2Key.String(), DefaultMTU, newNet, nil) iface2, err := NewWGIFace(peer2ifaceName, peer2wgIP, peer2wgPort, peer2Key.String(), DefaultMTU, newNet, nil, nil)
if err != nil { if err != nil {
t.Fatal(err) t.Fatal(err)
} }

View File

@ -8,11 +8,12 @@ import (
"github.com/pion/transport/v3" "github.com/pion/transport/v3"
"github.com/netbirdio/netbird/iface/bind"
"github.com/netbirdio/netbird/iface/netstack" "github.com/netbirdio/netbird/iface/netstack"
) )
// NewWGIFace Creates a new WireGuard interface instance // NewWGIFace Creates a new WireGuard interface instance
func NewWGIFace(iFaceName string, address string, wgPort int, wgPrivKey string, mtu int, transportNet transport.Net, args *MobileIFaceArguments) (*WGIface, error) { func NewWGIFace(iFaceName string, address string, wgPort int, wgPrivKey string, mtu int, transportNet transport.Net, args *MobileIFaceArguments, filterFn bind.FilterFn) (*WGIface, error) {
wgAddress, err := parseWGAddress(address) wgAddress, err := parseWGAddress(address)
if err != nil { if err != nil {
return nil, err return nil, err
@ -22,7 +23,7 @@ func NewWGIFace(iFaceName string, address string, wgPort int, wgPrivKey string,
// move the kernel/usp/netstack preference evaluation to upper layer // move the kernel/usp/netstack preference evaluation to upper layer
if netstack.IsEnabled() { if netstack.IsEnabled() {
wgIFace.tun = newTunNetstackDevice(iFaceName, wgAddress, wgPort, wgPrivKey, mtu, transportNet, netstack.ListenAddr()) wgIFace.tun = newTunNetstackDevice(iFaceName, wgAddress, wgPort, wgPrivKey, mtu, transportNet, netstack.ListenAddr(), filterFn)
wgIFace.userspaceBind = true wgIFace.userspaceBind = true
return wgIFace, nil return wgIFace, nil
} }
@ -36,7 +37,7 @@ func NewWGIFace(iFaceName string, address string, wgPort int, wgPrivKey string,
if !tunModuleIsLoaded() { if !tunModuleIsLoaded() {
return nil, fmt.Errorf("couldn't check or load tun module") return nil, fmt.Errorf("couldn't check or load tun module")
} }
wgIFace.tun = newTunUSPDevice(iFaceName, wgAddress, wgPort, wgPrivKey, mtu, transportNet) wgIFace.tun = newTunUSPDevice(iFaceName, wgAddress, wgPort, wgPrivKey, mtu, transportNet, nil)
wgIFace.userspaceBind = true wgIFace.userspaceBind = true
return wgIFace, nil return wgIFace, nil
} }

View File

@ -5,11 +5,12 @@ import (
"github.com/pion/transport/v3" "github.com/pion/transport/v3"
"github.com/netbirdio/netbird/iface/bind"
"github.com/netbirdio/netbird/iface/netstack" "github.com/netbirdio/netbird/iface/netstack"
) )
// NewWGIFace Creates a new WireGuard interface instance // NewWGIFace Creates a new WireGuard interface instance
func NewWGIFace(iFaceName string, address string, wgPort int, wgPrivKey string, mtu int, transportNet transport.Net, args *MobileIFaceArguments) (*WGIface, error) { func NewWGIFace(iFaceName string, address string, wgPort int, wgPrivKey string, mtu int, transportNet transport.Net, args *MobileIFaceArguments, filterFn bind.FilterFn) (*WGIface, error) {
wgAddress, err := parseWGAddress(address) wgAddress, err := parseWGAddress(address)
if err != nil { if err != nil {
return nil, err return nil, err
@ -20,11 +21,11 @@ func NewWGIFace(iFaceName string, address string, wgPort int, wgPrivKey string,
} }
if netstack.IsEnabled() { if netstack.IsEnabled() {
wgIFace.tun = newTunNetstackDevice(iFaceName, wgAddress, wgPort, wgPrivKey, mtu, transportNet, netstack.ListenAddr()) wgIFace.tun = newTunNetstackDevice(iFaceName, wgAddress, wgPort, wgPrivKey, mtu, transportNet, netstack.ListenAddr(), filterFn)
return wgIFace, nil return wgIFace, nil
} }
wgIFace.tun = newTunDevice(iFaceName, wgAddress, wgPort, wgPrivKey, mtu, transportNet) wgIFace.tun = newTunDevice(iFaceName, wgAddress, wgPort, wgPrivKey, mtu, transportNet, filterFn)
return wgIFace, nil return wgIFace, nil
} }

View File

@ -31,13 +31,13 @@ type wgTunDevice struct {
configurer wgConfigurer configurer wgConfigurer
} }
func newTunDevice(address WGAddress, port int, key string, mtu int, transportNet transport.Net, tunAdapter TunAdapter) wgTunDevice { func newTunDevice(address WGAddress, port int, key string, mtu int, transportNet transport.Net, tunAdapter TunAdapter, filterFn bind.FilterFn) wgTunDevice {
return wgTunDevice{ return wgTunDevice{
address: address, address: address,
port: port, port: port,
key: key, key: key,
mtu: mtu, mtu: mtu,
iceBind: bind.NewICEBind(transportNet), iceBind: bind.NewICEBind(transportNet, filterFn),
tunAdapter: tunAdapter, tunAdapter: tunAdapter,
} }
} }

View File

@ -27,14 +27,14 @@ type tunDevice struct {
configurer wgConfigurer configurer wgConfigurer
} }
func newTunDevice(name string, address WGAddress, port int, key string, mtu int, transportNet transport.Net) wgTunDevice { func newTunDevice(name string, address WGAddress, port int, key string, mtu int, transportNet transport.Net, filterFn bind.FilterFn) wgTunDevice {
return &tunDevice{ return &tunDevice{
name: name, name: name,
address: address, address: address,
port: port, port: port,
key: key, key: key,
mtu: mtu, mtu: mtu,
iceBind: bind.NewICEBind(transportNet), iceBind: bind.NewICEBind(transportNet, filterFn),
} }
} }

View File

@ -29,13 +29,13 @@ type tunDevice struct {
configurer wgConfigurer configurer wgConfigurer
} }
func newTunDevice(name string, address WGAddress, port int, key string, transportNet transport.Net, tunFd int) *tunDevice { func newTunDevice(name string, address WGAddress, port int, key string, transportNet transport.Net, tunFd int, filterFn bind.FilterFn) *tunDevice {
return &tunDevice{ return &tunDevice{
name: name, name: name,
address: address, address: address,
port: port, port: port,
key: key, key: key,
iceBind: bind.NewICEBind(transportNet), iceBind: bind.NewICEBind(transportNet, filterFn),
tunFd: tunFd, tunFd: tunFd,
} }
} }

View File

@ -27,6 +27,8 @@ type tunKernelDevice struct {
link *wgLink link *wgLink
udpMuxConn net.PacketConn udpMuxConn net.PacketConn
udpMux *bind.UniversalUDPMuxDefault udpMux *bind.UniversalUDPMuxDefault
filterFn bind.FilterFn
} }
func newTunDevice(name string, address WGAddress, wgPort int, key string, mtu int, transportNet transport.Net) wgTunDevice { func newTunDevice(name string, address WGAddress, wgPort int, key string, mtu int, transportNet transport.Net) wgTunDevice {
@ -96,8 +98,9 @@ func (t *tunKernelDevice) Up() (*bind.UniversalUDPMuxDefault, error) {
return nil, err return nil, err
} }
bindParams := bind.UniversalUDPMuxParams{ bindParams := bind.UniversalUDPMuxParams{
UDPConn: rawSock, UDPConn: rawSock,
Net: t.transportNet, Net: t.transportNet,
FilterFn: t.filterFn,
} }
mux := bind.NewUniversalUDPMuxDefault(bindParams) mux := bind.NewUniversalUDPMuxDefault(bindParams)
go mux.ReadFromConn(t.ctx) go mux.ReadFromConn(t.ctx)

View File

@ -30,7 +30,7 @@ type tunNetstackDevice struct {
configurer wgConfigurer configurer wgConfigurer
} }
func newTunNetstackDevice(name string, address WGAddress, wgPort int, key string, mtu int, transportNet transport.Net, listenAddress string) wgTunDevice { func newTunNetstackDevice(name string, address WGAddress, wgPort int, key string, mtu int, transportNet transport.Net, listenAddress string, filterFn bind.FilterFn) wgTunDevice {
return &tunNetstackDevice{ return &tunNetstackDevice{
name: name, name: name,
address: address, address: address,
@ -38,7 +38,7 @@ func newTunNetstackDevice(name string, address WGAddress, wgPort int, key string
key: key, key: key,
mtu: mtu, mtu: mtu,
listenAddress: listenAddress, listenAddress: listenAddress,
iceBind: bind.NewICEBind(transportNet), iceBind: bind.NewICEBind(transportNet, filterFn),
} }
} }

View File

@ -29,7 +29,7 @@ type tunUSPDevice struct {
configurer wgConfigurer configurer wgConfigurer
} }
func newTunUSPDevice(name string, address WGAddress, port int, key string, mtu int, transportNet transport.Net) wgTunDevice { func newTunUSPDevice(name string, address WGAddress, port int, key string, mtu int, transportNet transport.Net, filterFn bind.FilterFn) wgTunDevice {
log.Infof("using userspace bind mode") log.Infof("using userspace bind mode")
checkUser() checkUser()
@ -40,7 +40,7 @@ func newTunUSPDevice(name string, address WGAddress, port int, key string, mtu i
port: port, port: port,
key: key, key: key,
mtu: mtu, mtu: mtu,
iceBind: bind.NewICEBind(transportNet), iceBind: bind.NewICEBind(transportNet, filterFn),
} }
} }

View File

@ -29,14 +29,14 @@ type tunDevice struct {
configurer wgConfigurer configurer wgConfigurer
} }
func newTunDevice(name string, address WGAddress, port int, key string, mtu int, transportNet transport.Net) wgTunDevice { func newTunDevice(name string, address WGAddress, port int, key string, mtu int, transportNet transport.Net, filterFn bind.FilterFn) wgTunDevice {
return &tunDevice{ return &tunDevice{
name: name, name: name,
address: address, address: address,
port: port, port: port,
key: key, key: key,
mtu: mtu, mtu: mtu,
iceBind: bind.NewICEBind(transportNet), iceBind: bind.NewICEBind(transportNet, filterFn),
} }
} }