mirror of
https://github.com/netbirdio/netbird.git
synced 2024-11-22 08:03:30 +01:00
Rollback new routing functionality (#1805)
This commit is contained in:
parent
1d1d057e7d
commit
9f32ccd453
3
.github/workflows/golang-test-darwin.yml
vendored
3
.github/workflows/golang-test-darwin.yml
vendored
@ -32,9 +32,6 @@ jobs:
|
|||||||
restore-keys: |
|
restore-keys: |
|
||||||
macos-go-
|
macos-go-
|
||||||
|
|
||||||
- name: Install libpcap
|
|
||||||
run: brew install libpcap
|
|
||||||
|
|
||||||
- name: Install modules
|
- name: Install modules
|
||||||
run: go mod tidy
|
run: go mod tidy
|
||||||
|
|
||||||
|
10
.github/workflows/golang-test-linux.yml
vendored
10
.github/workflows/golang-test-linux.yml
vendored
@ -36,11 +36,7 @@ jobs:
|
|||||||
uses: actions/checkout@v3
|
uses: actions/checkout@v3
|
||||||
|
|
||||||
- name: Install dependencies
|
- name: Install dependencies
|
||||||
run: sudo apt update && sudo apt install -y -q libgtk-3-dev libayatana-appindicator3-dev libgl1-mesa-dev xorg-dev gcc-multilib libpcap-dev
|
run: sudo apt update && sudo apt install -y -q libgtk-3-dev libayatana-appindicator3-dev libgl1-mesa-dev xorg-dev gcc-multilib
|
||||||
|
|
||||||
- name: Install 32-bit libpcap
|
|
||||||
if: matrix.arch == '386'
|
|
||||||
run: sudo dpkg --add-architecture i386 && sudo apt update && sudo apt-get install -y libpcap0.8-dev:i386
|
|
||||||
|
|
||||||
- name: Install modules
|
- name: Install modules
|
||||||
run: go mod tidy
|
run: go mod tidy
|
||||||
@ -71,7 +67,7 @@ jobs:
|
|||||||
uses: actions/checkout@v3
|
uses: actions/checkout@v3
|
||||||
|
|
||||||
- name: Install dependencies
|
- name: Install dependencies
|
||||||
run: sudo apt update && sudo apt install -y -q libgtk-3-dev libayatana-appindicator3-dev libgl1-mesa-dev xorg-dev gcc-multilib libpcap-dev
|
run: sudo apt update && sudo apt install -y -q libgtk-3-dev libayatana-appindicator3-dev libgl1-mesa-dev xorg-dev gcc-multilib
|
||||||
|
|
||||||
- name: Install modules
|
- name: Install modules
|
||||||
run: go mod tidy
|
run: go mod tidy
|
||||||
@ -86,7 +82,7 @@ jobs:
|
|||||||
run: CGO_ENABLED=0 go test -c -o sharedsock-testing.bin ./sharedsock
|
run: CGO_ENABLED=0 go test -c -o sharedsock-testing.bin ./sharedsock
|
||||||
|
|
||||||
- name: Generate RouteManager Test bin
|
- name: Generate RouteManager Test bin
|
||||||
run: CGO_ENABLED=1 go test -c -o routemanager-testing.bin -tags netgo -ldflags '-w -extldflags "-static -ldbus-1 -lpcap"' ./client/internal/routemanager/...
|
run: CGO_ENABLED=0 go test -c -o routemanager-testing.bin ./client/internal/routemanager/...
|
||||||
|
|
||||||
- name: Generate nftables Manager Test bin
|
- name: Generate nftables Manager Test bin
|
||||||
run: CGO_ENABLED=0 go test -c -o nftablesmanager-testing.bin ./client/firewall/nftables/...
|
run: CGO_ENABLED=0 go test -c -o nftablesmanager-testing.bin ./client/firewall/nftables/...
|
||||||
|
2
.github/workflows/golang-test-windows.yml
vendored
2
.github/workflows/golang-test-windows.yml
vendored
@ -46,7 +46,7 @@ jobs:
|
|||||||
- run: PsExec64 -s -w ${{ github.workspace }} C:\hostedtoolcache\windows\go\${{ steps.go.outputs.go-version }}\x64\bin\go.exe env -w GOCACHE=C:\Users\runneradmin\AppData\Local\go-build
|
- run: PsExec64 -s -w ${{ github.workspace }} C:\hostedtoolcache\windows\go\${{ steps.go.outputs.go-version }}\x64\bin\go.exe env -w GOCACHE=C:\Users\runneradmin\AppData\Local\go-build
|
||||||
|
|
||||||
- name: test
|
- name: test
|
||||||
run: PsExec64 -s -w ${{ github.workspace }} cmd.exe /c "C:\hostedtoolcache\windows\go\${{ steps.go.outputs.go-version }}\x64\bin\go.exe test -timeout 10m -p 1 ./... > test-out.txt 2>&1"
|
run: PsExec64 -s -w ${{ github.workspace }} cmd.exe /c "C:\hostedtoolcache\windows\go\${{ steps.go.outputs.go-version }}\x64\bin\go.exe test -timeout 5m -p 1 ./... > test-out.txt 2>&1"
|
||||||
- name: test output
|
- name: test output
|
||||||
if: ${{ always() }}
|
if: ${{ always() }}
|
||||||
run: Get-Content test-out.txt
|
run: Get-Content test-out.txt
|
||||||
|
2
.github/workflows/golangci-lint.yml
vendored
2
.github/workflows/golangci-lint.yml
vendored
@ -40,7 +40,7 @@ jobs:
|
|||||||
cache: false
|
cache: false
|
||||||
- name: Install dependencies
|
- name: Install dependencies
|
||||||
if: matrix.os == 'ubuntu-latest'
|
if: matrix.os == 'ubuntu-latest'
|
||||||
run: sudo apt update && sudo apt install -y -q libgtk-3-dev libayatana-appindicator3-dev libgl1-mesa-dev xorg-dev libpcap-dev
|
run: sudo apt update && sudo apt install -y -q libgtk-3-dev libayatana-appindicator3-dev libgl1-mesa-dev xorg-dev
|
||||||
- name: golangci-lint
|
- name: golangci-lint
|
||||||
uses: golangci/golangci-lint-action@v3
|
uses: golangci/golangci-lint-action@v3
|
||||||
with:
|
with:
|
||||||
|
@ -94,9 +94,6 @@ type Engine struct {
|
|||||||
// peerConns is a map that holds all the peers that are known to this peer
|
// peerConns is a map that holds all the peers that are known to this peer
|
||||||
peerConns map[string]*peer.Conn
|
peerConns map[string]*peer.Conn
|
||||||
|
|
||||||
beforePeerHook peer.BeforeAddPeerHookFunc
|
|
||||||
afterPeerHook peer.AfterRemovePeerHookFunc
|
|
||||||
|
|
||||||
// rpManager is a Rosenpass manager
|
// rpManager is a Rosenpass manager
|
||||||
rpManager *rosenpass.Manager
|
rpManager *rosenpass.Manager
|
||||||
|
|
||||||
@ -264,14 +261,6 @@ func (e *Engine) Start() error {
|
|||||||
e.dnsServer = dnsServer
|
e.dnsServer = dnsServer
|
||||||
|
|
||||||
e.routeManager = routemanager.NewManager(e.ctx, e.config.WgPrivateKey.PublicKey().String(), e.wgInterface, e.statusRecorder, initialRoutes)
|
e.routeManager = routemanager.NewManager(e.ctx, e.config.WgPrivateKey.PublicKey().String(), e.wgInterface, e.statusRecorder, initialRoutes)
|
||||||
beforePeerHook, afterPeerHook, err := e.routeManager.Init()
|
|
||||||
if err != nil {
|
|
||||||
log.Errorf("Failed to initialize route manager: %s", err)
|
|
||||||
} else {
|
|
||||||
e.beforePeerHook = beforePeerHook
|
|
||||||
e.afterPeerHook = afterPeerHook
|
|
||||||
}
|
|
||||||
|
|
||||||
e.routeManager.SetRouteChangeListener(e.mobileDep.NetworkChangeListener)
|
e.routeManager.SetRouteChangeListener(e.mobileDep.NetworkChangeListener)
|
||||||
|
|
||||||
err = e.wgInterfaceCreate()
|
err = e.wgInterfaceCreate()
|
||||||
@ -821,11 +810,6 @@ func (e *Engine) addNewPeer(peerConfig *mgmProto.RemotePeerConfig) error {
|
|||||||
}
|
}
|
||||||
e.peerConns[peerKey] = conn
|
e.peerConns[peerKey] = conn
|
||||||
|
|
||||||
if e.beforePeerHook != nil && e.afterPeerHook != nil {
|
|
||||||
conn.AddBeforeAddPeerHook(e.beforePeerHook)
|
|
||||||
conn.AddAfterRemovePeerHook(e.afterPeerHook)
|
|
||||||
}
|
|
||||||
|
|
||||||
err = e.statusRecorder.AddPeer(peerKey, peerConfig.Fqdn)
|
err = e.statusRecorder.AddPeer(peerKey, peerConfig.Fqdn)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
log.Warnf("error adding peer %s to status recorder, got error: %v", peerKey, err)
|
log.Warnf("error adding peer %s to status recorder, got error: %v", peerKey, err)
|
||||||
|
@ -20,7 +20,6 @@ import (
|
|||||||
"github.com/netbirdio/netbird/iface/bind"
|
"github.com/netbirdio/netbird/iface/bind"
|
||||||
signal "github.com/netbirdio/netbird/signal/client"
|
signal "github.com/netbirdio/netbird/signal/client"
|
||||||
sProto "github.com/netbirdio/netbird/signal/proto"
|
sProto "github.com/netbirdio/netbird/signal/proto"
|
||||||
nbnet "github.com/netbirdio/netbird/util/net"
|
|
||||||
"github.com/netbirdio/netbird/version"
|
"github.com/netbirdio/netbird/version"
|
||||||
)
|
)
|
||||||
|
|
||||||
@ -101,9 +100,6 @@ type IceCredentials struct {
|
|||||||
Pwd string
|
Pwd string
|
||||||
}
|
}
|
||||||
|
|
||||||
type BeforeAddPeerHookFunc func(connID nbnet.ConnectionID, IP net.IP) error
|
|
||||||
type AfterRemovePeerHookFunc func(connID nbnet.ConnectionID) error
|
|
||||||
|
|
||||||
type Conn struct {
|
type Conn struct {
|
||||||
config ConnConfig
|
config ConnConfig
|
||||||
mu sync.Mutex
|
mu sync.Mutex
|
||||||
@ -142,10 +138,6 @@ type Conn struct {
|
|||||||
|
|
||||||
remoteEndpoint *net.UDPAddr
|
remoteEndpoint *net.UDPAddr
|
||||||
remoteConn *ice.Conn
|
remoteConn *ice.Conn
|
||||||
|
|
||||||
connID nbnet.ConnectionID
|
|
||||||
beforeAddPeerHooks []BeforeAddPeerHookFunc
|
|
||||||
afterRemovePeerHooks []AfterRemovePeerHookFunc
|
|
||||||
}
|
}
|
||||||
|
|
||||||
// meta holds meta information about a connection
|
// meta holds meta information about a connection
|
||||||
@ -401,14 +393,6 @@ func isRelayCandidate(candidate ice.Candidate) bool {
|
|||||||
return candidate.Type() == ice.CandidateTypeRelay
|
return candidate.Type() == ice.CandidateTypeRelay
|
||||||
}
|
}
|
||||||
|
|
||||||
func (conn *Conn) AddBeforeAddPeerHook(hook BeforeAddPeerHookFunc) {
|
|
||||||
conn.beforeAddPeerHooks = append(conn.beforeAddPeerHooks, hook)
|
|
||||||
}
|
|
||||||
|
|
||||||
func (conn *Conn) AddAfterRemovePeerHook(hook AfterRemovePeerHookFunc) {
|
|
||||||
conn.afterRemovePeerHooks = append(conn.afterRemovePeerHooks, hook)
|
|
||||||
}
|
|
||||||
|
|
||||||
// configureConnection starts proxying traffic from/to local Wireguard and sets connection status to StatusConnected
|
// configureConnection starts proxying traffic from/to local Wireguard and sets connection status to StatusConnected
|
||||||
func (conn *Conn) configureConnection(remoteConn net.Conn, remoteWgPort int, remoteRosenpassPubKey []byte, remoteRosenpassAddr string) (net.Addr, error) {
|
func (conn *Conn) configureConnection(remoteConn net.Conn, remoteWgPort int, remoteRosenpassPubKey []byte, remoteRosenpassAddr string) (net.Addr, error) {
|
||||||
conn.mu.Lock()
|
conn.mu.Lock()
|
||||||
@ -437,13 +421,6 @@ func (conn *Conn) configureConnection(remoteConn net.Conn, remoteWgPort int, rem
|
|||||||
conn.remoteEndpoint = endpointUdpAddr
|
conn.remoteEndpoint = endpointUdpAddr
|
||||||
log.Debugf("Conn resolved IP for %s: %s", endpoint, endpointUdpAddr.IP)
|
log.Debugf("Conn resolved IP for %s: %s", endpoint, endpointUdpAddr.IP)
|
||||||
|
|
||||||
conn.connID = nbnet.GenerateConnID()
|
|
||||||
for _, hook := range conn.beforeAddPeerHooks {
|
|
||||||
if err := hook(conn.connID, endpointUdpAddr.IP); err != nil {
|
|
||||||
log.Errorf("Before add peer hook failed: %v", err)
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
err = conn.config.WgConfig.WgInterface.UpdatePeer(conn.config.WgConfig.RemoteKey, conn.config.WgConfig.AllowedIps, defaultWgKeepAlive, endpointUdpAddr, conn.config.WgConfig.PreSharedKey)
|
err = conn.config.WgConfig.WgInterface.UpdatePeer(conn.config.WgConfig.RemoteKey, conn.config.WgConfig.AllowedIps, defaultWgKeepAlive, endpointUdpAddr, conn.config.WgConfig.PreSharedKey)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
if conn.wgProxy != nil {
|
if conn.wgProxy != nil {
|
||||||
@ -534,15 +511,6 @@ func (conn *Conn) cleanup() error {
|
|||||||
// todo: is it problem if we try to remove a peer what is never existed?
|
// todo: is it problem if we try to remove a peer what is never existed?
|
||||||
err3 = conn.config.WgConfig.WgInterface.RemovePeer(conn.config.WgConfig.RemoteKey)
|
err3 = conn.config.WgConfig.WgInterface.RemovePeer(conn.config.WgConfig.RemoteKey)
|
||||||
|
|
||||||
if conn.connID != "" {
|
|
||||||
for _, hook := range conn.afterRemovePeerHooks {
|
|
||||||
if err := hook(conn.connID); err != nil {
|
|
||||||
log.Errorf("After remove peer hook failed: %v", err)
|
|
||||||
}
|
|
||||||
}
|
|
||||||
}
|
|
||||||
conn.connID = ""
|
|
||||||
|
|
||||||
if conn.notifyDisconnected != nil {
|
if conn.notifyDisconnected != nil {
|
||||||
conn.notifyDisconnected()
|
conn.notifyDisconnected()
|
||||||
conn.notifyDisconnected = nil
|
conn.notifyDisconnected = nil
|
||||||
|
@ -12,7 +12,6 @@ import (
|
|||||||
log "github.com/sirupsen/logrus"
|
log "github.com/sirupsen/logrus"
|
||||||
|
|
||||||
"github.com/netbirdio/netbird/client/internal/stdnet"
|
"github.com/netbirdio/netbird/client/internal/stdnet"
|
||||||
nbnet "github.com/netbirdio/netbird/util/net"
|
|
||||||
)
|
)
|
||||||
|
|
||||||
// ProbeResult holds the info about the result of a relay probe request
|
// ProbeResult holds the info about the result of a relay probe request
|
||||||
@ -96,13 +95,15 @@ func ProbeTURN(ctx context.Context, uri *stun.URI) (addr string, probeErr error)
|
|||||||
switch uri.Proto {
|
switch uri.Proto {
|
||||||
case stun.ProtoTypeUDP:
|
case stun.ProtoTypeUDP:
|
||||||
var err error
|
var err error
|
||||||
conn, err = nbnet.NewListener().ListenPacket(ctx, "udp", "")
|
listener := &net.ListenConfig{}
|
||||||
|
conn, err = listener.ListenPacket(ctx, "udp", "")
|
||||||
if err != nil {
|
if err != nil {
|
||||||
probeErr = fmt.Errorf("listen: %w", err)
|
probeErr = fmt.Errorf("listen: %w", err)
|
||||||
return
|
return
|
||||||
}
|
}
|
||||||
case stun.ProtoTypeTCP:
|
case stun.ProtoTypeTCP:
|
||||||
tcpConn, err := nbnet.NewDialer().DialContext(ctx, "tcp", turnServerAddr)
|
dialer := &net.Dialer{}
|
||||||
|
tcpConn, err := dialer.DialContext(ctx, "tcp", turnServerAddr)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
probeErr = fmt.Errorf("dial: %w", err)
|
probeErr = fmt.Errorf("dial: %w", err)
|
||||||
return
|
return
|
||||||
|
@ -41,7 +41,6 @@ type clientNetwork struct {
|
|||||||
|
|
||||||
func newClientNetworkWatcher(ctx context.Context, wgInterface *iface.WGIface, statusRecorder *peer.Status, network netip.Prefix) *clientNetwork {
|
func newClientNetworkWatcher(ctx context.Context, wgInterface *iface.WGIface, statusRecorder *peer.Status, network netip.Prefix) *clientNetwork {
|
||||||
ctx, cancel := context.WithCancel(ctx)
|
ctx, cancel := context.WithCancel(ctx)
|
||||||
|
|
||||||
client := &clientNetwork{
|
client := &clientNetwork{
|
||||||
ctx: ctx,
|
ctx: ctx,
|
||||||
stop: cancel,
|
stop: cancel,
|
||||||
@ -73,18 +72,6 @@ func (c *clientNetwork) getRouterPeerStatuses() map[string]routerPeerStatus {
|
|||||||
return routePeerStatuses
|
return routePeerStatuses
|
||||||
}
|
}
|
||||||
|
|
||||||
// getBestRouteFromStatuses determines the most optimal route from the available routes
|
|
||||||
// within a clientNetwork, taking into account peer connection status, route metrics, and
|
|
||||||
// preference for non-relayed and direct connections.
|
|
||||||
//
|
|
||||||
// It follows these prioritization rules:
|
|
||||||
// * Connected peers: Only routes with connected peers are considered.
|
|
||||||
// * Metric: Routes with lower metrics (better) are prioritized.
|
|
||||||
// * Non-relayed: Routes without relays are preferred.
|
|
||||||
// * Direct connections: Routes with direct peer connections are favored.
|
|
||||||
// * Stability: In case of equal scores, the currently active route (if any) is maintained.
|
|
||||||
//
|
|
||||||
// It returns the ID of the selected optimal route.
|
|
||||||
func (c *clientNetwork) getBestRouteFromStatuses(routePeerStatuses map[string]routerPeerStatus) string {
|
func (c *clientNetwork) getBestRouteFromStatuses(routePeerStatuses map[string]routerPeerStatus) string {
|
||||||
chosen := ""
|
chosen := ""
|
||||||
chosenScore := 0
|
chosenScore := 0
|
||||||
@ -171,7 +158,7 @@ func (c *clientNetwork) startPeersStatusChangeWatcher() {
|
|||||||
func (c *clientNetwork) removeRouteFromWireguardPeer(peerKey string) error {
|
func (c *clientNetwork) removeRouteFromWireguardPeer(peerKey string) error {
|
||||||
state, err := c.statusRecorder.GetPeer(peerKey)
|
state, err := c.statusRecorder.GetPeer(peerKey)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return fmt.Errorf("get peer state: %v", err)
|
return err
|
||||||
}
|
}
|
||||||
|
|
||||||
delete(state.Routes, c.network.String())
|
delete(state.Routes, c.network.String())
|
||||||
@ -185,7 +172,7 @@ func (c *clientNetwork) removeRouteFromWireguardPeer(peerKey string) error {
|
|||||||
|
|
||||||
err = c.wgInterface.RemoveAllowedIP(peerKey, c.network.String())
|
err = c.wgInterface.RemoveAllowedIP(peerKey, c.network.String())
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return fmt.Errorf("remove allowed IP %s removed for peer %s, err: %v",
|
return fmt.Errorf("couldn't remove allowed IP %s removed for peer %s, err: %v",
|
||||||
c.network, c.chosenRoute.Peer, err)
|
c.network, c.chosenRoute.Peer, err)
|
||||||
}
|
}
|
||||||
return nil
|
return nil
|
||||||
@ -193,26 +180,30 @@ func (c *clientNetwork) removeRouteFromWireguardPeer(peerKey string) error {
|
|||||||
|
|
||||||
func (c *clientNetwork) removeRouteFromPeerAndSystem() error {
|
func (c *clientNetwork) removeRouteFromPeerAndSystem() error {
|
||||||
if c.chosenRoute != nil {
|
if c.chosenRoute != nil {
|
||||||
if err := removeVPNRoute(c.network, c.wgInterface.Name()); err != nil {
|
err := c.removeRouteFromWireguardPeer(c.chosenRoute.Peer)
|
||||||
return fmt.Errorf("remove route %s from system, err: %v", c.network, err)
|
if err != nil {
|
||||||
|
return err
|
||||||
}
|
}
|
||||||
|
err = removeFromRouteTableIfNonSystem(c.network, c.wgInterface.Address().IP.String())
|
||||||
if err := c.removeRouteFromWireguardPeer(c.chosenRoute.Peer); err != nil {
|
if err != nil {
|
||||||
return fmt.Errorf("remove route: %v", err)
|
return fmt.Errorf("couldn't remove route %s from system, err: %v",
|
||||||
|
c.network, err)
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
return nil
|
return nil
|
||||||
}
|
}
|
||||||
|
|
||||||
func (c *clientNetwork) recalculateRouteAndUpdatePeerAndSystem() error {
|
func (c *clientNetwork) recalculateRouteAndUpdatePeerAndSystem() error {
|
||||||
|
|
||||||
|
var err error
|
||||||
|
|
||||||
routerPeerStatuses := c.getRouterPeerStatuses()
|
routerPeerStatuses := c.getRouterPeerStatuses()
|
||||||
|
|
||||||
chosen := c.getBestRouteFromStatuses(routerPeerStatuses)
|
chosen := c.getBestRouteFromStatuses(routerPeerStatuses)
|
||||||
|
|
||||||
// If no route is chosen, remove the route from the peer and system
|
|
||||||
if chosen == "" {
|
if chosen == "" {
|
||||||
if err := c.removeRouteFromPeerAndSystem(); err != nil {
|
err = c.removeRouteFromPeerAndSystem()
|
||||||
return fmt.Errorf("remove route from peer and system: %v", err)
|
if err != nil {
|
||||||
|
return err
|
||||||
}
|
}
|
||||||
|
|
||||||
c.chosenRoute = nil
|
c.chosenRoute = nil
|
||||||
@ -220,7 +211,6 @@ func (c *clientNetwork) recalculateRouteAndUpdatePeerAndSystem() error {
|
|||||||
return nil
|
return nil
|
||||||
}
|
}
|
||||||
|
|
||||||
// If the chosen route is the same as the current route, do nothing
|
|
||||||
if c.chosenRoute != nil && c.chosenRoute.ID == chosen {
|
if c.chosenRoute != nil && c.chosenRoute.ID == chosen {
|
||||||
if c.chosenRoute.IsEqual(c.routes[chosen]) {
|
if c.chosenRoute.IsEqual(c.routes[chosen]) {
|
||||||
return nil
|
return nil
|
||||||
@ -228,13 +218,13 @@ func (c *clientNetwork) recalculateRouteAndUpdatePeerAndSystem() error {
|
|||||||
}
|
}
|
||||||
|
|
||||||
if c.chosenRoute != nil {
|
if c.chosenRoute != nil {
|
||||||
// If a previous route exists, remove it from the peer
|
err = c.removeRouteFromWireguardPeer(c.chosenRoute.Peer)
|
||||||
if err := c.removeRouteFromWireguardPeer(c.chosenRoute.Peer); err != nil {
|
if err != nil {
|
||||||
return fmt.Errorf("remove route from peer: %v", err)
|
return err
|
||||||
}
|
}
|
||||||
} else {
|
} else {
|
||||||
// otherwise add the route to the system
|
err = addToRouteTableIfNoExists(c.network, c.wgInterface.Address().IP.String())
|
||||||
if err := addVPNRoute(c.network, c.wgInterface.Name()); err != nil {
|
if err != nil {
|
||||||
return fmt.Errorf("route %s couldn't be added for peer %s, err: %v",
|
return fmt.Errorf("route %s couldn't be added for peer %s, err: %v",
|
||||||
c.network.String(), c.wgInterface.Address().IP.String(), err)
|
c.network.String(), c.wgInterface.Address().IP.String(), err)
|
||||||
}
|
}
|
||||||
@ -255,7 +245,8 @@ func (c *clientNetwork) recalculateRouteAndUpdatePeerAndSystem() error {
|
|||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
if err := c.wgInterface.AddAllowedIP(c.chosenRoute.Peer, c.network.String()); err != nil {
|
err = c.wgInterface.AddAllowedIP(c.chosenRoute.Peer, c.network.String())
|
||||||
|
if err != nil {
|
||||||
log.Errorf("couldn't add allowed IP %s added for peer %s, err: %v",
|
log.Errorf("couldn't add allowed IP %s added for peer %s, err: %v",
|
||||||
c.network, c.chosenRoute.Peer, err)
|
c.network, c.chosenRoute.Peer, err)
|
||||||
}
|
}
|
||||||
@ -296,21 +287,21 @@ func (c *clientNetwork) peersStateAndUpdateWatcher() {
|
|||||||
log.Debugf("stopping watcher for network %s", c.network)
|
log.Debugf("stopping watcher for network %s", c.network)
|
||||||
err := c.removeRouteFromPeerAndSystem()
|
err := c.removeRouteFromPeerAndSystem()
|
||||||
if err != nil {
|
if err != nil {
|
||||||
log.Errorf("Couldn't remove route from peer and system for network %s: %v", c.network, err)
|
log.Error(err)
|
||||||
}
|
}
|
||||||
return
|
return
|
||||||
case <-c.peerStateUpdate:
|
case <-c.peerStateUpdate:
|
||||||
err := c.recalculateRouteAndUpdatePeerAndSystem()
|
err := c.recalculateRouteAndUpdatePeerAndSystem()
|
||||||
if err != nil {
|
if err != nil {
|
||||||
log.Errorf("Couldn't recalculate route and update peer and system: %v", err)
|
log.Error(err)
|
||||||
}
|
}
|
||||||
case update := <-c.routeUpdate:
|
case update := <-c.routeUpdate:
|
||||||
if update.updateSerial < c.updateSerial {
|
if update.updateSerial < c.updateSerial {
|
||||||
log.Warnf("Received a routes update with smaller serial number, ignoring it")
|
log.Warnf("received a routes update with smaller serial number, ignoring it")
|
||||||
continue
|
continue
|
||||||
}
|
}
|
||||||
|
|
||||||
log.Debugf("Received a new client network route update for %s", c.network)
|
log.Debugf("received a new client network route update for %s", c.network)
|
||||||
|
|
||||||
c.handleUpdate(update)
|
c.handleUpdate(update)
|
||||||
|
|
||||||
@ -318,7 +309,7 @@ func (c *clientNetwork) peersStateAndUpdateWatcher() {
|
|||||||
|
|
||||||
err := c.recalculateRouteAndUpdatePeerAndSystem()
|
err := c.recalculateRouteAndUpdatePeerAndSystem()
|
||||||
if err != nil {
|
if err != nil {
|
||||||
log.Errorf("Couldn't recalculate route and update peer and system for network %s: %v", c.network, err)
|
log.Error(err)
|
||||||
}
|
}
|
||||||
|
|
||||||
c.startPeersStatusChangeWatcher()
|
c.startPeersStatusChangeWatcher()
|
||||||
|
@ -2,10 +2,6 @@ package routemanager
|
|||||||
|
|
||||||
import (
|
import (
|
||||||
"context"
|
"context"
|
||||||
"fmt"
|
|
||||||
"net"
|
|
||||||
"net/netip"
|
|
||||||
"net/url"
|
|
||||||
"runtime"
|
"runtime"
|
||||||
"sync"
|
"sync"
|
||||||
|
|
||||||
@ -19,14 +15,8 @@ import (
|
|||||||
"github.com/netbirdio/netbird/version"
|
"github.com/netbirdio/netbird/version"
|
||||||
)
|
)
|
||||||
|
|
||||||
var defaultv4 = netip.PrefixFrom(netip.IPv4Unspecified(), 0)
|
|
||||||
|
|
||||||
// nolint:unused
|
|
||||||
var defaultv6 = netip.PrefixFrom(netip.IPv6Unspecified(), 0)
|
|
||||||
|
|
||||||
// Manager is a route manager interface
|
// Manager is a route manager interface
|
||||||
type Manager interface {
|
type Manager interface {
|
||||||
Init() (peer.BeforeAddPeerHookFunc, peer.AfterRemovePeerHookFunc, error)
|
|
||||||
UpdateRoutes(updateSerial uint64, newRoutes []*route.Route) error
|
UpdateRoutes(updateSerial uint64, newRoutes []*route.Route) error
|
||||||
SetRouteChangeListener(listener listener.NetworkChangeListener)
|
SetRouteChangeListener(listener listener.NetworkChangeListener)
|
||||||
InitialRouteRange() []string
|
InitialRouteRange() []string
|
||||||
@ -66,24 +56,6 @@ func NewManager(ctx context.Context, pubKey string, wgInterface *iface.WGIface,
|
|||||||
return dm
|
return dm
|
||||||
}
|
}
|
||||||
|
|
||||||
// Init sets up the routing
|
|
||||||
func (m *DefaultManager) Init() (peer.BeforeAddPeerHookFunc, peer.AfterRemovePeerHookFunc, error) {
|
|
||||||
if err := cleanupRouting(); err != nil {
|
|
||||||
log.Warnf("Failed cleaning up routing: %v", err)
|
|
||||||
}
|
|
||||||
|
|
||||||
mgmtAddress := m.statusRecorder.GetManagementState().URL
|
|
||||||
signalAddress := m.statusRecorder.GetSignalState().URL
|
|
||||||
ips := resolveURLsToIPs([]string{mgmtAddress, signalAddress})
|
|
||||||
|
|
||||||
beforePeerHook, afterPeerHook, err := setupRouting(ips, m.wgInterface)
|
|
||||||
if err != nil {
|
|
||||||
return nil, nil, fmt.Errorf("setup routing: %w", err)
|
|
||||||
}
|
|
||||||
log.Info("Routing setup complete")
|
|
||||||
return beforePeerHook, afterPeerHook, nil
|
|
||||||
}
|
|
||||||
|
|
||||||
func (m *DefaultManager) EnableServerRouter(firewall firewall.Manager) error {
|
func (m *DefaultManager) EnableServerRouter(firewall firewall.Manager) error {
|
||||||
var err error
|
var err error
|
||||||
m.serverRouter, err = newServerRouter(m.ctx, m.wgInterface, firewall, m.statusRecorder)
|
m.serverRouter, err = newServerRouter(m.ctx, m.wgInterface, firewall, m.statusRecorder)
|
||||||
@ -99,15 +71,9 @@ func (m *DefaultManager) Stop() {
|
|||||||
if m.serverRouter != nil {
|
if m.serverRouter != nil {
|
||||||
m.serverRouter.cleanUp()
|
m.serverRouter.cleanUp()
|
||||||
}
|
}
|
||||||
if err := cleanupRouting(); err != nil {
|
|
||||||
log.Errorf("Error cleaning up routing: %v", err)
|
|
||||||
} else {
|
|
||||||
log.Info("Routing cleanup complete")
|
|
||||||
}
|
|
||||||
m.ctx = nil
|
|
||||||
}
|
}
|
||||||
|
|
||||||
// UpdateRoutes compares received routes with existing routes and removes, updates or adds them to the client and server maps
|
// UpdateRoutes compares received routes with existing routes and remove, update or add them to the client and server maps
|
||||||
func (m *DefaultManager) UpdateRoutes(updateSerial uint64, newRoutes []*route.Route) error {
|
func (m *DefaultManager) UpdateRoutes(updateSerial uint64, newRoutes []*route.Route) error {
|
||||||
select {
|
select {
|
||||||
case <-m.ctx.Done():
|
case <-m.ctx.Done():
|
||||||
@ -125,7 +91,7 @@ func (m *DefaultManager) UpdateRoutes(updateSerial uint64, newRoutes []*route.Ro
|
|||||||
if m.serverRouter != nil {
|
if m.serverRouter != nil {
|
||||||
err := m.serverRouter.updateRoutes(newServerRoutesMap)
|
err := m.serverRouter.updateRoutes(newServerRoutesMap)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return fmt.Errorf("update routes: %w", err)
|
return err
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
@ -190,7 +156,11 @@ func (m *DefaultManager) classifiesRoutes(newRoutes []*route.Route) (map[string]
|
|||||||
for _, newRoute := range newRoutes {
|
for _, newRoute := range newRoutes {
|
||||||
networkID := route.GetHAUniqueID(newRoute)
|
networkID := route.GetHAUniqueID(newRoute)
|
||||||
if !ownNetworkIDs[networkID] {
|
if !ownNetworkIDs[networkID] {
|
||||||
if !isPrefixSupported(newRoute.Network) {
|
// if prefix is too small, lets assume is a possible default route which is not yet supported
|
||||||
|
// we skip this route management
|
||||||
|
if newRoute.Network.Bits() < minRangeBits {
|
||||||
|
log.Errorf("this agent version: %s, doesn't support default routes, received %s, skipping this route",
|
||||||
|
version.NetbirdVersion(), newRoute.Network)
|
||||||
continue
|
continue
|
||||||
}
|
}
|
||||||
newClientRoutesIDMap[networkID] = append(newClientRoutesIDMap[networkID], newRoute)
|
newClientRoutesIDMap[networkID] = append(newClientRoutesIDMap[networkID], newRoute)
|
||||||
@ -208,38 +178,3 @@ func (m *DefaultManager) clientRoutes(initialRoutes []*route.Route) []*route.Rou
|
|||||||
}
|
}
|
||||||
return rs
|
return rs
|
||||||
}
|
}
|
||||||
|
|
||||||
func isPrefixSupported(prefix netip.Prefix) bool {
|
|
||||||
switch runtime.GOOS {
|
|
||||||
case "linux", "windows", "darwin":
|
|
||||||
return true
|
|
||||||
}
|
|
||||||
|
|
||||||
// If prefix is too small, lets assume it is a possible default prefix which is not yet supported
|
|
||||||
// we skip this prefix management
|
|
||||||
if prefix.Bits() <= minRangeBits {
|
|
||||||
log.Warnf("This agent version: %s, doesn't support default routes, received %s, skipping this prefix",
|
|
||||||
version.NetbirdVersion(), prefix)
|
|
||||||
return false
|
|
||||||
}
|
|
||||||
return true
|
|
||||||
}
|
|
||||||
|
|
||||||
// resolveURLsToIPs takes a slice of URLs, resolves them to IP addresses and returns a slice of IPs.
|
|
||||||
func resolveURLsToIPs(urls []string) []net.IP {
|
|
||||||
var ips []net.IP
|
|
||||||
for _, rawurl := range urls {
|
|
||||||
u, err := url.Parse(rawurl)
|
|
||||||
if err != nil {
|
|
||||||
log.Errorf("Failed to parse url %s: %v", rawurl, err)
|
|
||||||
continue
|
|
||||||
}
|
|
||||||
ipAddrs, err := net.LookupIP(u.Hostname())
|
|
||||||
if err != nil {
|
|
||||||
log.Errorf("Failed to resolve host %s: %v", u.Hostname(), err)
|
|
||||||
continue
|
|
||||||
}
|
|
||||||
ips = append(ips, ipAddrs...)
|
|
||||||
}
|
|
||||||
return ips
|
|
||||||
}
|
|
||||||
|
@ -35,7 +35,6 @@ func TestManagerUpdateRoutes(t *testing.T) {
|
|||||||
removeSrvRouter bool
|
removeSrvRouter bool
|
||||||
serverRoutesExpected int
|
serverRoutesExpected int
|
||||||
clientNetworkWatchersExpected int
|
clientNetworkWatchersExpected int
|
||||||
clientNetworkWatchersExpectedAllowed int
|
|
||||||
}{
|
}{
|
||||||
{
|
{
|
||||||
name: "Should create 2 client networks",
|
name: "Should create 2 client networks",
|
||||||
@ -203,7 +202,6 @@ func TestManagerUpdateRoutes(t *testing.T) {
|
|||||||
},
|
},
|
||||||
inputSerial: 1,
|
inputSerial: 1,
|
||||||
clientNetworkWatchersExpected: 0,
|
clientNetworkWatchersExpected: 0,
|
||||||
clientNetworkWatchersExpectedAllowed: 1,
|
|
||||||
},
|
},
|
||||||
{
|
{
|
||||||
name: "Remove 1 Client Route",
|
name: "Remove 1 Client Route",
|
||||||
@ -417,10 +415,6 @@ func TestManagerUpdateRoutes(t *testing.T) {
|
|||||||
statusRecorder := peer.NewRecorder("https://mgm")
|
statusRecorder := peer.NewRecorder("https://mgm")
|
||||||
ctx := context.TODO()
|
ctx := context.TODO()
|
||||||
routeManager := NewManager(ctx, localPeerKey, wgInterface, statusRecorder, nil)
|
routeManager := NewManager(ctx, localPeerKey, wgInterface, statusRecorder, nil)
|
||||||
|
|
||||||
_, _, err = routeManager.Init()
|
|
||||||
|
|
||||||
require.NoError(t, err, "should init route manager")
|
|
||||||
defer routeManager.Stop()
|
defer routeManager.Stop()
|
||||||
|
|
||||||
if testCase.removeSrvRouter {
|
if testCase.removeSrvRouter {
|
||||||
@ -435,11 +429,7 @@ func TestManagerUpdateRoutes(t *testing.T) {
|
|||||||
err = routeManager.UpdateRoutes(testCase.inputSerial+uint64(len(testCase.inputInitRoutes)), testCase.inputRoutes)
|
err = routeManager.UpdateRoutes(testCase.inputSerial+uint64(len(testCase.inputInitRoutes)), testCase.inputRoutes)
|
||||||
require.NoError(t, err, "should update routes")
|
require.NoError(t, err, "should update routes")
|
||||||
|
|
||||||
expectedWatchers := testCase.clientNetworkWatchersExpected
|
require.Len(t, routeManager.clientNetworks, testCase.clientNetworkWatchersExpected, "client networks size should match")
|
||||||
if (runtime.GOOS == "linux" || runtime.GOOS == "windows" || runtime.GOOS == "darwin") && testCase.clientNetworkWatchersExpectedAllowed != 0 {
|
|
||||||
expectedWatchers = testCase.clientNetworkWatchersExpectedAllowed
|
|
||||||
}
|
|
||||||
require.Len(t, routeManager.clientNetworks, expectedWatchers, "client networks size should match")
|
|
||||||
|
|
||||||
if runtime.GOOS == "linux" && routeManager.serverRouter != nil {
|
if runtime.GOOS == "linux" && routeManager.serverRouter != nil {
|
||||||
sr := routeManager.serverRouter.(*defaultServerRouter)
|
sr := routeManager.serverRouter.(*defaultServerRouter)
|
||||||
|
@ -6,7 +6,6 @@ import (
|
|||||||
|
|
||||||
firewall "github.com/netbirdio/netbird/client/firewall/manager"
|
firewall "github.com/netbirdio/netbird/client/firewall/manager"
|
||||||
"github.com/netbirdio/netbird/client/internal/listener"
|
"github.com/netbirdio/netbird/client/internal/listener"
|
||||||
"github.com/netbirdio/netbird/client/internal/peer"
|
|
||||||
"github.com/netbirdio/netbird/iface"
|
"github.com/netbirdio/netbird/iface"
|
||||||
"github.com/netbirdio/netbird/route"
|
"github.com/netbirdio/netbird/route"
|
||||||
)
|
)
|
||||||
@ -17,10 +16,6 @@ type MockManager struct {
|
|||||||
StopFunc func()
|
StopFunc func()
|
||||||
}
|
}
|
||||||
|
|
||||||
func (m *MockManager) Init() (peer.BeforeAddPeerHookFunc, peer.AfterRemovePeerHookFunc, error) {
|
|
||||||
return nil, nil, nil
|
|
||||||
}
|
|
||||||
|
|
||||||
// InitialRouteRange mock implementation of InitialRouteRange from Manager interface
|
// InitialRouteRange mock implementation of InitialRouteRange from Manager interface
|
||||||
func (m *MockManager) InitialRouteRange() []string {
|
func (m *MockManager) InitialRouteRange() []string {
|
||||||
return nil
|
return nil
|
||||||
|
@ -1,126 +0,0 @@
|
|||||||
//go:build !android && !ios
|
|
||||||
|
|
||||||
package routemanager
|
|
||||||
|
|
||||||
import (
|
|
||||||
"errors"
|
|
||||||
"fmt"
|
|
||||||
"net/netip"
|
|
||||||
"sync"
|
|
||||||
|
|
||||||
"github.com/hashicorp/go-multierror"
|
|
||||||
log "github.com/sirupsen/logrus"
|
|
||||||
|
|
||||||
nbnet "github.com/netbirdio/netbird/util/net"
|
|
||||||
)
|
|
||||||
|
|
||||||
type ref struct {
|
|
||||||
count int
|
|
||||||
nexthop netip.Addr
|
|
||||||
intf string
|
|
||||||
}
|
|
||||||
|
|
||||||
type RouteManager struct {
|
|
||||||
// refCountMap keeps track of the reference ref for prefixes
|
|
||||||
refCountMap map[netip.Prefix]ref
|
|
||||||
// prefixMap keeps track of the prefixes associated with a connection ID for removal
|
|
||||||
prefixMap map[nbnet.ConnectionID][]netip.Prefix
|
|
||||||
addRoute AddRouteFunc
|
|
||||||
removeRoute RemoveRouteFunc
|
|
||||||
mutex sync.Mutex
|
|
||||||
}
|
|
||||||
|
|
||||||
type AddRouteFunc func(prefix netip.Prefix) (nexthop netip.Addr, intf string, err error)
|
|
||||||
type RemoveRouteFunc func(prefix netip.Prefix, nexthop netip.Addr, intf string) error
|
|
||||||
|
|
||||||
func NewRouteManager(addRoute AddRouteFunc, removeRoute RemoveRouteFunc) *RouteManager {
|
|
||||||
// TODO: read initial routing table into refCountMap
|
|
||||||
return &RouteManager{
|
|
||||||
refCountMap: map[netip.Prefix]ref{},
|
|
||||||
prefixMap: map[nbnet.ConnectionID][]netip.Prefix{},
|
|
||||||
addRoute: addRoute,
|
|
||||||
removeRoute: removeRoute,
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
func (rm *RouteManager) AddRouteRef(connID nbnet.ConnectionID, prefix netip.Prefix) error {
|
|
||||||
rm.mutex.Lock()
|
|
||||||
defer rm.mutex.Unlock()
|
|
||||||
|
|
||||||
ref := rm.refCountMap[prefix]
|
|
||||||
log.Debugf("Increasing route ref count %d for prefix %s", ref.count, prefix)
|
|
||||||
|
|
||||||
// Add route to the system, only if it's a new prefix
|
|
||||||
if ref.count == 0 {
|
|
||||||
log.Debugf("Adding route for prefix %s", prefix)
|
|
||||||
nexthop, intf, err := rm.addRoute(prefix)
|
|
||||||
if errors.Is(err, ErrRouteNotFound) {
|
|
||||||
return nil
|
|
||||||
}
|
|
||||||
if errors.Is(err, ErrRouteNotAllowed) {
|
|
||||||
log.Debugf("Adding route for prefix %s: %s", prefix, err)
|
|
||||||
}
|
|
||||||
if err != nil {
|
|
||||||
return fmt.Errorf("failed to add route for prefix %s: %w", prefix, err)
|
|
||||||
}
|
|
||||||
ref.nexthop = nexthop
|
|
||||||
ref.intf = intf
|
|
||||||
}
|
|
||||||
|
|
||||||
ref.count++
|
|
||||||
rm.refCountMap[prefix] = ref
|
|
||||||
rm.prefixMap[connID] = append(rm.prefixMap[connID], prefix)
|
|
||||||
|
|
||||||
return nil
|
|
||||||
}
|
|
||||||
|
|
||||||
func (rm *RouteManager) RemoveRouteRef(connID nbnet.ConnectionID) error {
|
|
||||||
rm.mutex.Lock()
|
|
||||||
defer rm.mutex.Unlock()
|
|
||||||
|
|
||||||
prefixes, ok := rm.prefixMap[connID]
|
|
||||||
if !ok {
|
|
||||||
log.Debugf("No prefixes found for connection ID %s", connID)
|
|
||||||
return nil
|
|
||||||
}
|
|
||||||
|
|
||||||
var result *multierror.Error
|
|
||||||
for _, prefix := range prefixes {
|
|
||||||
ref := rm.refCountMap[prefix]
|
|
||||||
log.Debugf("Decreasing route ref count %d for prefix %s", ref.count, prefix)
|
|
||||||
if ref.count == 1 {
|
|
||||||
log.Debugf("Removing route for prefix %s", prefix)
|
|
||||||
// TODO: don't fail if the route is not found
|
|
||||||
if err := rm.removeRoute(prefix, ref.nexthop, ref.intf); err != nil {
|
|
||||||
result = multierror.Append(result, fmt.Errorf("remove route for prefix %s: %w", prefix, err))
|
|
||||||
continue
|
|
||||||
}
|
|
||||||
delete(rm.refCountMap, prefix)
|
|
||||||
} else {
|
|
||||||
ref.count--
|
|
||||||
rm.refCountMap[prefix] = ref
|
|
||||||
}
|
|
||||||
}
|
|
||||||
delete(rm.prefixMap, connID)
|
|
||||||
|
|
||||||
return result.ErrorOrNil()
|
|
||||||
}
|
|
||||||
|
|
||||||
// Flush removes all references and routes from the system
|
|
||||||
func (rm *RouteManager) Flush() error {
|
|
||||||
rm.mutex.Lock()
|
|
||||||
defer rm.mutex.Unlock()
|
|
||||||
|
|
||||||
var result *multierror.Error
|
|
||||||
for prefix := range rm.refCountMap {
|
|
||||||
log.Debugf("Removing route for prefix %s", prefix)
|
|
||||||
ref := rm.refCountMap[prefix]
|
|
||||||
if err := rm.removeRoute(prefix, ref.nexthop, ref.intf); err != nil {
|
|
||||||
result = multierror.Append(result, fmt.Errorf("remove route for prefix %s: %w", prefix, err))
|
|
||||||
}
|
|
||||||
}
|
|
||||||
rm.refCountMap = map[netip.Prefix]ref{}
|
|
||||||
rm.prefixMap = map[nbnet.ConnectionID][]netip.Prefix{}
|
|
||||||
|
|
||||||
return result.ErrorOrNil()
|
|
||||||
}
|
|
@ -4,7 +4,6 @@ package routemanager
|
|||||||
|
|
||||||
import (
|
import (
|
||||||
"context"
|
"context"
|
||||||
"fmt"
|
|
||||||
"net/netip"
|
"net/netip"
|
||||||
"sync"
|
"sync"
|
||||||
|
|
||||||
@ -49,7 +48,7 @@ func (m *defaultServerRouter) updateRoutes(routesMap map[string]*route.Route) er
|
|||||||
oldRoute := m.routes[routeID]
|
oldRoute := m.routes[routeID]
|
||||||
err := m.removeFromServerNetwork(oldRoute)
|
err := m.removeFromServerNetwork(oldRoute)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
log.Errorf("Unable to remove route id: %s, network %s, from server, got: %v",
|
log.Errorf("unable to remove route id: %s, network %s, from server, got: %v",
|
||||||
oldRoute.ID, oldRoute.Network, err)
|
oldRoute.ID, oldRoute.Network, err)
|
||||||
}
|
}
|
||||||
delete(m.routes, routeID)
|
delete(m.routes, routeID)
|
||||||
@ -63,7 +62,7 @@ func (m *defaultServerRouter) updateRoutes(routesMap map[string]*route.Route) er
|
|||||||
|
|
||||||
err := m.addToServerNetwork(newRoute)
|
err := m.addToServerNetwork(newRoute)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
log.Errorf("Unable to add route %s from server, got: %v", newRoute.ID, err)
|
log.Errorf("unable to add route %s from server, got: %v", newRoute.ID, err)
|
||||||
continue
|
continue
|
||||||
}
|
}
|
||||||
m.routes[id] = newRoute
|
m.routes[id] = newRoute
|
||||||
@ -82,22 +81,15 @@ func (m *defaultServerRouter) updateRoutes(routesMap map[string]*route.Route) er
|
|||||||
func (m *defaultServerRouter) removeFromServerNetwork(route *route.Route) error {
|
func (m *defaultServerRouter) removeFromServerNetwork(route *route.Route) error {
|
||||||
select {
|
select {
|
||||||
case <-m.ctx.Done():
|
case <-m.ctx.Done():
|
||||||
log.Infof("Not removing from server network because context is done")
|
log.Infof("not removing from server network because context is done")
|
||||||
return m.ctx.Err()
|
return m.ctx.Err()
|
||||||
default:
|
default:
|
||||||
m.mux.Lock()
|
m.mux.Lock()
|
||||||
defer m.mux.Unlock()
|
defer m.mux.Unlock()
|
||||||
|
err := m.firewall.RemoveRoutingRules(routeToRouterPair(m.wgInterface.Address().String(), route))
|
||||||
routerPair, err := routeToRouterPair(m.wgInterface.Address().Masked().String(), route)
|
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return fmt.Errorf("parse prefix: %w", err)
|
return err
|
||||||
}
|
}
|
||||||
|
|
||||||
err = m.firewall.RemoveRoutingRules(routerPair)
|
|
||||||
if err != nil {
|
|
||||||
return fmt.Errorf("remove routing rules: %w", err)
|
|
||||||
}
|
|
||||||
|
|
||||||
delete(m.routes, route.ID)
|
delete(m.routes, route.ID)
|
||||||
|
|
||||||
state := m.statusRecorder.GetLocalPeerState()
|
state := m.statusRecorder.GetLocalPeerState()
|
||||||
@ -111,22 +103,15 @@ func (m *defaultServerRouter) removeFromServerNetwork(route *route.Route) error
|
|||||||
func (m *defaultServerRouter) addToServerNetwork(route *route.Route) error {
|
func (m *defaultServerRouter) addToServerNetwork(route *route.Route) error {
|
||||||
select {
|
select {
|
||||||
case <-m.ctx.Done():
|
case <-m.ctx.Done():
|
||||||
log.Infof("Not adding to server network because context is done")
|
log.Infof("not adding to server network because context is done")
|
||||||
return m.ctx.Err()
|
return m.ctx.Err()
|
||||||
default:
|
default:
|
||||||
m.mux.Lock()
|
m.mux.Lock()
|
||||||
defer m.mux.Unlock()
|
defer m.mux.Unlock()
|
||||||
|
err := m.firewall.InsertRoutingRules(routeToRouterPair(m.wgInterface.Address().String(), route))
|
||||||
routerPair, err := routeToRouterPair(m.wgInterface.Address().Masked().String(), route)
|
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return fmt.Errorf("parse prefix: %w", err)
|
return err
|
||||||
}
|
}
|
||||||
|
|
||||||
err = m.firewall.InsertRoutingRules(routerPair)
|
|
||||||
if err != nil {
|
|
||||||
return fmt.Errorf("insert routing rules: %w", err)
|
|
||||||
}
|
|
||||||
|
|
||||||
m.routes[route.ID] = route
|
m.routes[route.ID] = route
|
||||||
|
|
||||||
state := m.statusRecorder.GetLocalPeerState()
|
state := m.statusRecorder.GetLocalPeerState()
|
||||||
@ -144,33 +129,23 @@ func (m *defaultServerRouter) cleanUp() {
|
|||||||
m.mux.Lock()
|
m.mux.Lock()
|
||||||
defer m.mux.Unlock()
|
defer m.mux.Unlock()
|
||||||
for _, r := range m.routes {
|
for _, r := range m.routes {
|
||||||
routerPair, err := routeToRouterPair(m.wgInterface.Address().Masked().String(), r)
|
err := m.firewall.RemoveRoutingRules(routeToRouterPair(m.wgInterface.Address().String(), r))
|
||||||
if err != nil {
|
if err != nil {
|
||||||
log.Errorf("Failed to convert route to router pair: %v", err)
|
log.Warnf("failed to remove clean up route: %s", r.ID)
|
||||||
continue
|
|
||||||
}
|
|
||||||
|
|
||||||
err = m.firewall.RemoveRoutingRules(routerPair)
|
|
||||||
if err != nil {
|
|
||||||
log.Errorf("Failed to remove cleanup route: %v", err)
|
|
||||||
}
|
|
||||||
|
|
||||||
}
|
}
|
||||||
|
|
||||||
state := m.statusRecorder.GetLocalPeerState()
|
state := m.statusRecorder.GetLocalPeerState()
|
||||||
state.Routes = nil
|
state.Routes = nil
|
||||||
m.statusRecorder.UpdateLocalPeerState(state)
|
m.statusRecorder.UpdateLocalPeerState(state)
|
||||||
}
|
}
|
||||||
|
|
||||||
func routeToRouterPair(source string, route *route.Route) (firewall.RouterPair, error) {
|
|
||||||
parsed, err := netip.ParsePrefix(source)
|
|
||||||
if err != nil {
|
|
||||||
return firewall.RouterPair{}, err
|
|
||||||
}
|
}
|
||||||
|
|
||||||
|
func routeToRouterPair(source string, route *route.Route) firewall.RouterPair {
|
||||||
|
parsed := netip.MustParsePrefix(source).Masked()
|
||||||
return firewall.RouterPair{
|
return firewall.RouterPair{
|
||||||
ID: route.ID,
|
ID: route.ID,
|
||||||
Source: parsed.String(),
|
Source: parsed.String(),
|
||||||
Destination: route.Network.Masked().String(),
|
Destination: route.Network.Masked().String(),
|
||||||
Masquerade: route.Masquerade,
|
Masquerade: route.Masquerade,
|
||||||
}, nil
|
}
|
||||||
}
|
}
|
||||||
|
@ -1,407 +0,0 @@
|
|||||||
//go:build !android && !ios
|
|
||||||
|
|
||||||
package routemanager
|
|
||||||
|
|
||||||
import (
|
|
||||||
"context"
|
|
||||||
"errors"
|
|
||||||
"fmt"
|
|
||||||
"net"
|
|
||||||
"net/netip"
|
|
||||||
|
|
||||||
"github.com/hashicorp/go-multierror"
|
|
||||||
"github.com/libp2p/go-netroute"
|
|
||||||
log "github.com/sirupsen/logrus"
|
|
||||||
|
|
||||||
"github.com/netbirdio/netbird/client/internal/peer"
|
|
||||||
"github.com/netbirdio/netbird/iface"
|
|
||||||
nbnet "github.com/netbirdio/netbird/util/net"
|
|
||||||
)
|
|
||||||
|
|
||||||
var splitDefaultv4_1 = netip.PrefixFrom(netip.IPv4Unspecified(), 1)
|
|
||||||
var splitDefaultv4_2 = netip.PrefixFrom(netip.AddrFrom4([4]byte{128}), 1)
|
|
||||||
var splitDefaultv6_1 = netip.PrefixFrom(netip.IPv6Unspecified(), 1)
|
|
||||||
var splitDefaultv6_2 = netip.PrefixFrom(netip.AddrFrom16([16]byte{0x80}), 1)
|
|
||||||
|
|
||||||
var ErrRouteNotFound = errors.New("route not found")
|
|
||||||
var ErrRouteNotAllowed = errors.New("route not allowed")
|
|
||||||
|
|
||||||
// TODO: fix: for default our wg address now appears as the default gw
|
|
||||||
func addRouteForCurrentDefaultGateway(prefix netip.Prefix) error {
|
|
||||||
addr := netip.IPv4Unspecified()
|
|
||||||
if prefix.Addr().Is6() {
|
|
||||||
addr = netip.IPv6Unspecified()
|
|
||||||
}
|
|
||||||
|
|
||||||
defaultGateway, _, err := getNextHop(addr)
|
|
||||||
if err != nil && !errors.Is(err, ErrRouteNotFound) {
|
|
||||||
return fmt.Errorf("get existing route gateway: %s", err)
|
|
||||||
}
|
|
||||||
|
|
||||||
if !prefix.Contains(defaultGateway) {
|
|
||||||
log.Debugf("Skipping adding a new route for gateway %s because it is not in the network %s", defaultGateway, prefix)
|
|
||||||
return nil
|
|
||||||
}
|
|
||||||
|
|
||||||
gatewayPrefix := netip.PrefixFrom(defaultGateway, 32)
|
|
||||||
if defaultGateway.Is6() {
|
|
||||||
gatewayPrefix = netip.PrefixFrom(defaultGateway, 128)
|
|
||||||
}
|
|
||||||
|
|
||||||
ok, err := existsInRouteTable(gatewayPrefix)
|
|
||||||
if err != nil {
|
|
||||||
return fmt.Errorf("unable to check if there is an existing route for gateway %s. error: %s", gatewayPrefix, err)
|
|
||||||
}
|
|
||||||
|
|
||||||
if ok {
|
|
||||||
log.Debugf("Skipping adding a new route for gateway %s because it already exists", gatewayPrefix)
|
|
||||||
return nil
|
|
||||||
}
|
|
||||||
|
|
||||||
var exitIntf string
|
|
||||||
gatewayHop, intf, err := getNextHop(defaultGateway)
|
|
||||||
if err != nil && !errors.Is(err, ErrRouteNotFound) {
|
|
||||||
return fmt.Errorf("unable to get the next hop for the default gateway address. error: %s", err)
|
|
||||||
}
|
|
||||||
if intf != nil {
|
|
||||||
exitIntf = intf.Name
|
|
||||||
}
|
|
||||||
|
|
||||||
log.Debugf("Adding a new route for gateway %s with next hop %s", gatewayPrefix, gatewayHop)
|
|
||||||
return addToRouteTable(gatewayPrefix, gatewayHop, exitIntf)
|
|
||||||
}
|
|
||||||
|
|
||||||
func getNextHop(ip netip.Addr) (netip.Addr, *net.Interface, error) {
|
|
||||||
r, err := netroute.New()
|
|
||||||
if err != nil {
|
|
||||||
return netip.Addr{}, nil, fmt.Errorf("new netroute: %w", err)
|
|
||||||
}
|
|
||||||
intf, gateway, preferredSrc, err := r.Route(ip.AsSlice())
|
|
||||||
if err != nil {
|
|
||||||
log.Warnf("Failed to get route for %s: %v", ip, err)
|
|
||||||
return netip.Addr{}, nil, ErrRouteNotFound
|
|
||||||
}
|
|
||||||
|
|
||||||
log.Debugf("Route for %s: interface %v, nexthop %v, preferred source %v", ip, intf, gateway, preferredSrc)
|
|
||||||
if gateway == nil {
|
|
||||||
if preferredSrc == nil {
|
|
||||||
return netip.Addr{}, nil, ErrRouteNotFound
|
|
||||||
}
|
|
||||||
log.Debugf("No next hop found for ip %s, using preferred source %s", ip, preferredSrc)
|
|
||||||
|
|
||||||
addr, ok := netip.AddrFromSlice(preferredSrc)
|
|
||||||
if !ok {
|
|
||||||
return netip.Addr{}, nil, fmt.Errorf("failed to parse IP address: %s", preferredSrc)
|
|
||||||
}
|
|
||||||
return addr.Unmap(), intf, nil
|
|
||||||
}
|
|
||||||
|
|
||||||
addr, ok := netip.AddrFromSlice(gateway)
|
|
||||||
if !ok {
|
|
||||||
return netip.Addr{}, nil, fmt.Errorf("failed to parse IP address: %s", gateway)
|
|
||||||
}
|
|
||||||
|
|
||||||
return addr.Unmap(), intf, nil
|
|
||||||
}
|
|
||||||
|
|
||||||
func existsInRouteTable(prefix netip.Prefix) (bool, error) {
|
|
||||||
routes, err := getRoutesFromTable()
|
|
||||||
if err != nil {
|
|
||||||
return false, fmt.Errorf("get routes from table: %w", err)
|
|
||||||
}
|
|
||||||
for _, tableRoute := range routes {
|
|
||||||
if tableRoute == prefix {
|
|
||||||
return true, nil
|
|
||||||
}
|
|
||||||
}
|
|
||||||
return false, nil
|
|
||||||
}
|
|
||||||
|
|
||||||
func isSubRange(prefix netip.Prefix) (bool, error) {
|
|
||||||
routes, err := getRoutesFromTable()
|
|
||||||
if err != nil {
|
|
||||||
return false, fmt.Errorf("get routes from table: %w", err)
|
|
||||||
}
|
|
||||||
for _, tableRoute := range routes {
|
|
||||||
if tableRoute.Bits() > minRangeBits && tableRoute.Contains(prefix.Addr()) && tableRoute.Bits() < prefix.Bits() {
|
|
||||||
return true, nil
|
|
||||||
}
|
|
||||||
}
|
|
||||||
return false, nil
|
|
||||||
}
|
|
||||||
|
|
||||||
// addRouteToNonVPNIntf adds a new route to the routing table for the given prefix and returns the next hop and interface.
|
|
||||||
// If the next hop or interface is pointing to the VPN interface, it will return the initial values.
|
|
||||||
func addRouteToNonVPNIntf(
|
|
||||||
prefix netip.Prefix,
|
|
||||||
vpnIntf *iface.WGIface,
|
|
||||||
initialNextHop netip.Addr,
|
|
||||||
initialIntf *net.Interface,
|
|
||||||
) (netip.Addr, string, error) {
|
|
||||||
addr := prefix.Addr()
|
|
||||||
switch {
|
|
||||||
case addr.IsLoopback(),
|
|
||||||
addr.IsLinkLocalUnicast(),
|
|
||||||
addr.IsLinkLocalMulticast(),
|
|
||||||
addr.IsInterfaceLocalMulticast(),
|
|
||||||
addr.IsUnspecified(),
|
|
||||||
addr.IsMulticast():
|
|
||||||
|
|
||||||
return netip.Addr{}, "", ErrRouteNotAllowed
|
|
||||||
}
|
|
||||||
|
|
||||||
// Determine the exit interface and next hop for the prefix, so we can add a specific route
|
|
||||||
nexthop, intf, err := getNextHop(addr)
|
|
||||||
if err != nil {
|
|
||||||
return netip.Addr{}, "", fmt.Errorf("get next hop: %w", err)
|
|
||||||
}
|
|
||||||
|
|
||||||
log.Debugf("Found next hop %s for prefix %s with interface %v", nexthop, prefix, intf)
|
|
||||||
exitNextHop := nexthop
|
|
||||||
var exitIntf string
|
|
||||||
if intf != nil {
|
|
||||||
exitIntf = intf.Name
|
|
||||||
}
|
|
||||||
|
|
||||||
vpnAddr, ok := netip.AddrFromSlice(vpnIntf.Address().IP)
|
|
||||||
if !ok {
|
|
||||||
return netip.Addr{}, "", fmt.Errorf("failed to convert vpn address to netip.Addr")
|
|
||||||
}
|
|
||||||
|
|
||||||
// if next hop is the VPN address or the interface is the VPN interface, we should use the initial values
|
|
||||||
if exitNextHop == vpnAddr || exitIntf == vpnIntf.Name() {
|
|
||||||
log.Debugf("Route for prefix %s is pointing to the VPN interface", prefix)
|
|
||||||
exitNextHop = initialNextHop
|
|
||||||
if initialIntf != nil {
|
|
||||||
exitIntf = initialIntf.Name
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
log.Debugf("Adding a new route for prefix %s with next hop %s", prefix, exitNextHop)
|
|
||||||
if err := addToRouteTable(prefix, exitNextHop, exitIntf); err != nil {
|
|
||||||
return netip.Addr{}, "", fmt.Errorf("add route to table: %w", err)
|
|
||||||
}
|
|
||||||
|
|
||||||
return exitNextHop, exitIntf, nil
|
|
||||||
}
|
|
||||||
|
|
||||||
// genericAddVPNRoute adds a new route to the vpn interface, it splits the default prefix
|
|
||||||
// in two /1 prefixes to avoid replacing the existing default route
|
|
||||||
func genericAddVPNRoute(prefix netip.Prefix, intf string) error {
|
|
||||||
if prefix == defaultv4 {
|
|
||||||
if err := addToRouteTable(splitDefaultv4_1, netip.Addr{}, intf); err != nil {
|
|
||||||
return err
|
|
||||||
}
|
|
||||||
if err := addToRouteTable(splitDefaultv4_2, netip.Addr{}, intf); err != nil {
|
|
||||||
if err2 := removeFromRouteTable(splitDefaultv4_1, netip.Addr{}, intf); err2 != nil {
|
|
||||||
log.Warnf("Failed to rollback route addition: %s", err2)
|
|
||||||
}
|
|
||||||
return err
|
|
||||||
}
|
|
||||||
|
|
||||||
// TODO: remove once IPv6 is supported on the interface
|
|
||||||
if err := addToRouteTable(splitDefaultv6_1, netip.Addr{}, intf); err != nil {
|
|
||||||
return fmt.Errorf("add unreachable route split 1: %w", err)
|
|
||||||
}
|
|
||||||
if err := addToRouteTable(splitDefaultv6_2, netip.Addr{}, intf); err != nil {
|
|
||||||
if err2 := removeFromRouteTable(splitDefaultv6_1, netip.Addr{}, intf); err2 != nil {
|
|
||||||
log.Warnf("Failed to rollback route addition: %s", err2)
|
|
||||||
}
|
|
||||||
return fmt.Errorf("add unreachable route split 2: %w", err)
|
|
||||||
}
|
|
||||||
|
|
||||||
return nil
|
|
||||||
} else if prefix == defaultv6 {
|
|
||||||
if err := addToRouteTable(splitDefaultv6_1, netip.Addr{}, intf); err != nil {
|
|
||||||
return fmt.Errorf("add unreachable route split 1: %w", err)
|
|
||||||
}
|
|
||||||
if err := addToRouteTable(splitDefaultv6_2, netip.Addr{}, intf); err != nil {
|
|
||||||
if err2 := removeFromRouteTable(splitDefaultv6_1, netip.Addr{}, intf); err2 != nil {
|
|
||||||
log.Warnf("Failed to rollback route addition: %s", err2)
|
|
||||||
}
|
|
||||||
return fmt.Errorf("add unreachable route split 2: %w", err)
|
|
||||||
}
|
|
||||||
|
|
||||||
return nil
|
|
||||||
}
|
|
||||||
|
|
||||||
return addNonExistingRoute(prefix, intf)
|
|
||||||
}
|
|
||||||
|
|
||||||
// addNonExistingRoute adds a new route to the vpn interface if it doesn't exist in the current routing table
|
|
||||||
func addNonExistingRoute(prefix netip.Prefix, intf string) error {
|
|
||||||
ok, err := existsInRouteTable(prefix)
|
|
||||||
if err != nil {
|
|
||||||
return fmt.Errorf("exists in route table: %w", err)
|
|
||||||
}
|
|
||||||
if ok {
|
|
||||||
log.Warnf("Skipping adding a new route for network %s because it already exists", prefix)
|
|
||||||
return nil
|
|
||||||
}
|
|
||||||
|
|
||||||
ok, err = isSubRange(prefix)
|
|
||||||
if err != nil {
|
|
||||||
return fmt.Errorf("sub range: %w", err)
|
|
||||||
}
|
|
||||||
|
|
||||||
if ok {
|
|
||||||
err := addRouteForCurrentDefaultGateway(prefix)
|
|
||||||
if err != nil {
|
|
||||||
log.Warnf("Unable to add route for current default gateway route. Will proceed without it. error: %s", err)
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
return addToRouteTable(prefix, netip.Addr{}, intf)
|
|
||||||
}
|
|
||||||
|
|
||||||
// genericRemoveVPNRoute removes the route from the vpn interface. If a default prefix is given,
|
|
||||||
// it will remove the split /1 prefixes
|
|
||||||
func genericRemoveVPNRoute(prefix netip.Prefix, intf string) error {
|
|
||||||
if prefix == defaultv4 {
|
|
||||||
var result *multierror.Error
|
|
||||||
if err := removeFromRouteTable(splitDefaultv4_1, netip.Addr{}, intf); err != nil {
|
|
||||||
result = multierror.Append(result, err)
|
|
||||||
}
|
|
||||||
if err := removeFromRouteTable(splitDefaultv4_2, netip.Addr{}, intf); err != nil {
|
|
||||||
result = multierror.Append(result, err)
|
|
||||||
}
|
|
||||||
|
|
||||||
// TODO: remove once IPv6 is supported on the interface
|
|
||||||
if err := removeFromRouteTable(splitDefaultv6_1, netip.Addr{}, intf); err != nil {
|
|
||||||
result = multierror.Append(result, err)
|
|
||||||
}
|
|
||||||
if err := removeFromRouteTable(splitDefaultv6_2, netip.Addr{}, intf); err != nil {
|
|
||||||
result = multierror.Append(result, err)
|
|
||||||
}
|
|
||||||
|
|
||||||
return result.ErrorOrNil()
|
|
||||||
} else if prefix == defaultv6 {
|
|
||||||
var result *multierror.Error
|
|
||||||
if err := removeFromRouteTable(splitDefaultv6_1, netip.Addr{}, intf); err != nil {
|
|
||||||
result = multierror.Append(result, err)
|
|
||||||
}
|
|
||||||
if err := removeFromRouteTable(splitDefaultv6_2, netip.Addr{}, intf); err != nil {
|
|
||||||
result = multierror.Append(result, err)
|
|
||||||
}
|
|
||||||
|
|
||||||
return result.ErrorOrNil()
|
|
||||||
}
|
|
||||||
|
|
||||||
return removeFromRouteTable(prefix, netip.Addr{}, intf)
|
|
||||||
}
|
|
||||||
|
|
||||||
func getPrefixFromIP(ip net.IP) (*netip.Prefix, error) {
|
|
||||||
addr, ok := netip.AddrFromSlice(ip)
|
|
||||||
if !ok {
|
|
||||||
return nil, fmt.Errorf("parse IP address: %s", ip)
|
|
||||||
}
|
|
||||||
addr = addr.Unmap()
|
|
||||||
|
|
||||||
var prefixLength int
|
|
||||||
switch {
|
|
||||||
case addr.Is4():
|
|
||||||
prefixLength = 32
|
|
||||||
case addr.Is6():
|
|
||||||
prefixLength = 128
|
|
||||||
default:
|
|
||||||
return nil, fmt.Errorf("invalid IP address: %s", addr)
|
|
||||||
}
|
|
||||||
|
|
||||||
prefix := netip.PrefixFrom(addr, prefixLength)
|
|
||||||
return &prefix, nil
|
|
||||||
}
|
|
||||||
|
|
||||||
func setupRoutingWithRouteManager(routeManager **RouteManager, initAddresses []net.IP, wgIface *iface.WGIface) (peer.BeforeAddPeerHookFunc, peer.AfterRemovePeerHookFunc, error) {
|
|
||||||
initialNextHopV4, initialIntfV4, err := getNextHop(netip.IPv4Unspecified())
|
|
||||||
if err != nil && !errors.Is(err, ErrRouteNotFound) {
|
|
||||||
log.Errorf("Unable to get initial v4 default next hop: %v", err)
|
|
||||||
}
|
|
||||||
initialNextHopV6, initialIntfV6, err := getNextHop(netip.IPv6Unspecified())
|
|
||||||
if err != nil && !errors.Is(err, ErrRouteNotFound) {
|
|
||||||
log.Errorf("Unable to get initial v6 default next hop: %v", err)
|
|
||||||
}
|
|
||||||
|
|
||||||
*routeManager = NewRouteManager(
|
|
||||||
func(prefix netip.Prefix) (netip.Addr, string, error) {
|
|
||||||
addr := prefix.Addr()
|
|
||||||
nexthop, intf := initialNextHopV4, initialIntfV4
|
|
||||||
if addr.Is6() {
|
|
||||||
nexthop, intf = initialNextHopV6, initialIntfV6
|
|
||||||
}
|
|
||||||
return addRouteToNonVPNIntf(prefix, wgIface, nexthop, intf)
|
|
||||||
},
|
|
||||||
removeFromRouteTable,
|
|
||||||
)
|
|
||||||
|
|
||||||
return setupHooks(*routeManager, initAddresses)
|
|
||||||
}
|
|
||||||
|
|
||||||
func cleanupRoutingWithRouteManager(routeManager *RouteManager) error {
|
|
||||||
if routeManager == nil {
|
|
||||||
return nil
|
|
||||||
}
|
|
||||||
|
|
||||||
// TODO: Remove hooks selectively
|
|
||||||
nbnet.RemoveDialerHooks()
|
|
||||||
nbnet.RemoveListenerHooks()
|
|
||||||
|
|
||||||
if err := routeManager.Flush(); err != nil {
|
|
||||||
return fmt.Errorf("flush route manager: %w", err)
|
|
||||||
}
|
|
||||||
|
|
||||||
return nil
|
|
||||||
}
|
|
||||||
|
|
||||||
func setupHooks(routeManager *RouteManager, initAddresses []net.IP) (peer.BeforeAddPeerHookFunc, peer.AfterRemovePeerHookFunc, error) {
|
|
||||||
beforeHook := func(connID nbnet.ConnectionID, ip net.IP) error {
|
|
||||||
prefix, err := getPrefixFromIP(ip)
|
|
||||||
if err != nil {
|
|
||||||
return fmt.Errorf("convert ip to prefix: %w", err)
|
|
||||||
}
|
|
||||||
|
|
||||||
if err := routeManager.AddRouteRef(connID, *prefix); err != nil {
|
|
||||||
return fmt.Errorf("adding route reference: %v", err)
|
|
||||||
}
|
|
||||||
|
|
||||||
return nil
|
|
||||||
}
|
|
||||||
afterHook := func(connID nbnet.ConnectionID) error {
|
|
||||||
if err := routeManager.RemoveRouteRef(connID); err != nil {
|
|
||||||
return fmt.Errorf("remove route reference: %w", err)
|
|
||||||
}
|
|
||||||
|
|
||||||
return nil
|
|
||||||
}
|
|
||||||
|
|
||||||
for _, ip := range initAddresses {
|
|
||||||
if err := beforeHook("init", ip); err != nil {
|
|
||||||
log.Errorf("Failed to add route reference: %v", err)
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
nbnet.AddDialerHook(func(ctx context.Context, connID nbnet.ConnectionID, resolvedIPs []net.IPAddr) error {
|
|
||||||
if ctx.Err() != nil {
|
|
||||||
return ctx.Err()
|
|
||||||
}
|
|
||||||
|
|
||||||
var result *multierror.Error
|
|
||||||
for _, ip := range resolvedIPs {
|
|
||||||
result = multierror.Append(result, beforeHook(connID, ip.IP))
|
|
||||||
}
|
|
||||||
return result.ErrorOrNil()
|
|
||||||
})
|
|
||||||
|
|
||||||
nbnet.AddDialerCloseHook(func(connID nbnet.ConnectionID, conn *net.Conn) error {
|
|
||||||
return afterHook(connID)
|
|
||||||
})
|
|
||||||
|
|
||||||
nbnet.AddListenerWriteHook(func(connID nbnet.ConnectionID, ip *net.IPAddr, data []byte) error {
|
|
||||||
return beforeHook(connID, ip.IP)
|
|
||||||
})
|
|
||||||
|
|
||||||
nbnet.AddListenerCloseHook(func(connID nbnet.ConnectionID, conn net.PacketConn) error {
|
|
||||||
return afterHook(connID)
|
|
||||||
})
|
|
||||||
|
|
||||||
return beforeHook, afterHook, nil
|
|
||||||
}
|
|
@ -1,33 +1,13 @@
|
|||||||
package routemanager
|
package routemanager
|
||||||
|
|
||||||
import (
|
import (
|
||||||
"net"
|
|
||||||
"net/netip"
|
"net/netip"
|
||||||
"runtime"
|
|
||||||
|
|
||||||
log "github.com/sirupsen/logrus"
|
|
||||||
|
|
||||||
"github.com/netbirdio/netbird/client/internal/peer"
|
|
||||||
"github.com/netbirdio/netbird/iface"
|
|
||||||
)
|
)
|
||||||
|
|
||||||
func setupRouting([]net.IP, *iface.WGIface) (peer.BeforeAddPeerHookFunc, peer.AfterRemovePeerHookFunc, error) {
|
func addToRouteTableIfNoExists(prefix netip.Prefix, addr string) error {
|
||||||
return nil, nil, nil
|
|
||||||
}
|
|
||||||
|
|
||||||
func cleanupRouting() error {
|
|
||||||
return nil
|
return nil
|
||||||
}
|
}
|
||||||
|
|
||||||
func enableIPForwarding() error {
|
func removeFromRouteTableIfNonSystem(prefix netip.Prefix, addr string) error {
|
||||||
log.Infof("Enable IP forwarding is not implemented on %s", runtime.GOOS)
|
|
||||||
return nil
|
|
||||||
}
|
|
||||||
|
|
||||||
func addVPNRoute(netip.Prefix, string) error {
|
|
||||||
return nil
|
|
||||||
}
|
|
||||||
|
|
||||||
func removeVPNRoute(netip.Prefix, string) error {
|
|
||||||
return nil
|
return nil
|
||||||
}
|
}
|
||||||
|
@ -1,4 +1,5 @@
|
|||||||
//go:build darwin || dragonfly || freebsd || netbsd || openbsd
|
//go:build darwin || dragonfly || freebsd || netbsd || openbsd
|
||||||
|
// +build darwin dragonfly freebsd netbsd openbsd
|
||||||
|
|
||||||
package routemanager
|
package routemanager
|
||||||
|
|
||||||
|
@ -1,61 +0,0 @@
|
|||||||
//go:build darwin && !ios
|
|
||||||
|
|
||||||
package routemanager
|
|
||||||
|
|
||||||
import (
|
|
||||||
"fmt"
|
|
||||||
"net"
|
|
||||||
"net/netip"
|
|
||||||
"os/exec"
|
|
||||||
"strings"
|
|
||||||
|
|
||||||
log "github.com/sirupsen/logrus"
|
|
||||||
|
|
||||||
"github.com/netbirdio/netbird/client/internal/peer"
|
|
||||||
"github.com/netbirdio/netbird/iface"
|
|
||||||
)
|
|
||||||
|
|
||||||
var routeManager *RouteManager
|
|
||||||
|
|
||||||
func setupRouting(initAddresses []net.IP, wgIface *iface.WGIface) (peer.BeforeAddPeerHookFunc, peer.AfterRemovePeerHookFunc, error) {
|
|
||||||
return setupRoutingWithRouteManager(&routeManager, initAddresses, wgIface)
|
|
||||||
}
|
|
||||||
|
|
||||||
func cleanupRouting() error {
|
|
||||||
return cleanupRoutingWithRouteManager(routeManager)
|
|
||||||
}
|
|
||||||
|
|
||||||
func addToRouteTable(prefix netip.Prefix, nexthop netip.Addr, intf string) error {
|
|
||||||
return routeCmd("add", prefix, nexthop, intf)
|
|
||||||
}
|
|
||||||
|
|
||||||
func removeFromRouteTable(prefix netip.Prefix, nexthop netip.Addr, intf string) error {
|
|
||||||
return routeCmd("delete", prefix, nexthop, intf)
|
|
||||||
}
|
|
||||||
|
|
||||||
func routeCmd(action string, prefix netip.Prefix, nexthop netip.Addr, intf string) error {
|
|
||||||
inet := "-inet"
|
|
||||||
if prefix.Addr().Is6() {
|
|
||||||
inet = "-inet6"
|
|
||||||
// Special case for IPv6 split default route, pointing to the wg interface fails
|
|
||||||
// TODO: Remove once we have IPv6 support on the interface
|
|
||||||
if prefix.Bits() == 1 {
|
|
||||||
intf = "lo0"
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
args := []string{"-n", action, inet, prefix.String()}
|
|
||||||
if nexthop.IsValid() {
|
|
||||||
args = append(args, nexthop.Unmap().String())
|
|
||||||
} else if intf != "" {
|
|
||||||
args = append(args, "-interface", intf)
|
|
||||||
}
|
|
||||||
|
|
||||||
out, err := exec.Command("route", args...).CombinedOutput()
|
|
||||||
log.Tracef("route %s: %s", strings.Join(args, " "), out)
|
|
||||||
|
|
||||||
if err != nil {
|
|
||||||
return fmt.Errorf("failed to %s route for %s: %w", action, prefix, err)
|
|
||||||
}
|
|
||||||
return nil
|
|
||||||
}
|
|
@ -1,100 +0,0 @@
|
|||||||
//go:build !ios
|
|
||||||
|
|
||||||
package routemanager
|
|
||||||
|
|
||||||
import (
|
|
||||||
"fmt"
|
|
||||||
"net"
|
|
||||||
"os/exec"
|
|
||||||
"regexp"
|
|
||||||
"testing"
|
|
||||||
|
|
||||||
"github.com/stretchr/testify/assert"
|
|
||||||
"github.com/stretchr/testify/require"
|
|
||||||
)
|
|
||||||
|
|
||||||
var expectedVPNint = "utun100"
|
|
||||||
var expectedExternalInt = "lo0"
|
|
||||||
var expectedInternalInt = "lo0"
|
|
||||||
|
|
||||||
func init() {
|
|
||||||
testCases = append(testCases, []testCase{
|
|
||||||
{
|
|
||||||
name: "To more specific route without custom dialer via vpn",
|
|
||||||
destination: "10.10.0.2:53",
|
|
||||||
expectedInterface: expectedVPNint,
|
|
||||||
dialer: &net.Dialer{},
|
|
||||||
expectedPacket: createPacketExpectation("100.64.0.1", 12345, "10.10.0.2", 53),
|
|
||||||
},
|
|
||||||
}...)
|
|
||||||
}
|
|
||||||
|
|
||||||
func createAndSetupDummyInterface(t *testing.T, intf string, ipAddressCIDR string) string {
|
|
||||||
t.Helper()
|
|
||||||
|
|
||||||
err := exec.Command("ifconfig", intf, "alias", ipAddressCIDR).Run()
|
|
||||||
require.NoError(t, err, "Failed to create loopback alias")
|
|
||||||
|
|
||||||
t.Cleanup(func() {
|
|
||||||
err := exec.Command("ifconfig", intf, ipAddressCIDR, "-alias").Run()
|
|
||||||
assert.NoError(t, err, "Failed to remove loopback alias")
|
|
||||||
})
|
|
||||||
|
|
||||||
return "lo0"
|
|
||||||
}
|
|
||||||
|
|
||||||
func addDummyRoute(t *testing.T, dstCIDR string, gw net.IP, _ string) {
|
|
||||||
t.Helper()
|
|
||||||
|
|
||||||
var originalNexthop net.IP
|
|
||||||
if dstCIDR == "0.0.0.0/0" {
|
|
||||||
var err error
|
|
||||||
originalNexthop, err = fetchOriginalGateway()
|
|
||||||
if err != nil {
|
|
||||||
t.Logf("Failed to fetch original gateway: %v", err)
|
|
||||||
}
|
|
||||||
|
|
||||||
if output, err := exec.Command("route", "delete", "-net", dstCIDR).CombinedOutput(); err != nil {
|
|
||||||
t.Logf("Failed to delete route: %v, output: %s", err, output)
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
t.Cleanup(func() {
|
|
||||||
if originalNexthop != nil {
|
|
||||||
err := exec.Command("route", "add", "-net", dstCIDR, originalNexthop.String()).Run()
|
|
||||||
assert.NoError(t, err, "Failed to restore original route")
|
|
||||||
}
|
|
||||||
})
|
|
||||||
|
|
||||||
err := exec.Command("route", "add", "-net", dstCIDR, gw.String()).Run()
|
|
||||||
require.NoError(t, err, "Failed to add route")
|
|
||||||
|
|
||||||
t.Cleanup(func() {
|
|
||||||
err := exec.Command("route", "delete", "-net", dstCIDR).Run()
|
|
||||||
assert.NoError(t, err, "Failed to remove route")
|
|
||||||
})
|
|
||||||
}
|
|
||||||
|
|
||||||
func fetchOriginalGateway() (net.IP, error) {
|
|
||||||
output, err := exec.Command("route", "-n", "get", "default").CombinedOutput()
|
|
||||||
if err != nil {
|
|
||||||
return nil, err
|
|
||||||
}
|
|
||||||
|
|
||||||
matches := regexp.MustCompile(`gateway: (\S+)`).FindStringSubmatch(string(output))
|
|
||||||
if len(matches) == 0 {
|
|
||||||
return nil, fmt.Errorf("gateway not found")
|
|
||||||
}
|
|
||||||
|
|
||||||
return net.ParseIP(matches[1]), nil
|
|
||||||
}
|
|
||||||
|
|
||||||
func setupDummyInterfacesAndRoutes(t *testing.T) {
|
|
||||||
t.Helper()
|
|
||||||
|
|
||||||
defaultDummy := createAndSetupDummyInterface(t, expectedExternalInt, "192.168.0.1/24")
|
|
||||||
addDummyRoute(t, "0.0.0.0/0", net.IPv4(192, 168, 0, 1), defaultDummy)
|
|
||||||
|
|
||||||
otherDummy := createAndSetupDummyInterface(t, expectedInternalInt, "192.168.1.1/24")
|
|
||||||
addDummyRoute(t, "10.0.0.0/8", net.IPv4(192, 168, 1, 1), otherDummy)
|
|
||||||
}
|
|
@ -1,33 +1,15 @@
|
|||||||
|
//go:build ios
|
||||||
|
|
||||||
package routemanager
|
package routemanager
|
||||||
|
|
||||||
import (
|
import (
|
||||||
"net"
|
|
||||||
"net/netip"
|
"net/netip"
|
||||||
"runtime"
|
|
||||||
|
|
||||||
log "github.com/sirupsen/logrus"
|
|
||||||
|
|
||||||
"github.com/netbirdio/netbird/client/internal/peer"
|
|
||||||
"github.com/netbirdio/netbird/iface"
|
|
||||||
)
|
)
|
||||||
|
|
||||||
func setupRouting([]net.IP, *iface.WGIface) (peer.BeforeAddPeerHookFunc, peer.AfterRemovePeerHookFunc, error) {
|
func addToRouteTableIfNoExists(prefix netip.Prefix, addr string) error {
|
||||||
return nil, nil, nil
|
|
||||||
}
|
|
||||||
|
|
||||||
func cleanupRouting() error {
|
|
||||||
return nil
|
return nil
|
||||||
}
|
}
|
||||||
|
|
||||||
func enableIPForwarding() error {
|
func removeFromRouteTableIfNonSystem(prefix netip.Prefix, addr string) error {
|
||||||
log.Infof("Enable IP forwarding is not implemented on %s", runtime.GOOS)
|
|
||||||
return nil
|
|
||||||
}
|
|
||||||
|
|
||||||
func addVPNRoute(netip.Prefix, string) error {
|
|
||||||
return nil
|
|
||||||
}
|
|
||||||
|
|
||||||
func removeVPNRoute(netip.Prefix, string) error {
|
|
||||||
return nil
|
return nil
|
||||||
}
|
}
|
||||||
|
@ -3,342 +3,142 @@
|
|||||||
package routemanager
|
package routemanager
|
||||||
|
|
||||||
import (
|
import (
|
||||||
"bufio"
|
|
||||||
"context"
|
|
||||||
"errors"
|
|
||||||
"fmt"
|
|
||||||
"net"
|
"net"
|
||||||
"net/netip"
|
"net/netip"
|
||||||
"os"
|
"os"
|
||||||
"syscall"
|
"syscall"
|
||||||
"time"
|
"unsafe"
|
||||||
|
|
||||||
"github.com/hashicorp/go-multierror"
|
|
||||||
log "github.com/sirupsen/logrus"
|
|
||||||
"github.com/vishvananda/netlink"
|
"github.com/vishvananda/netlink"
|
||||||
|
|
||||||
"github.com/netbirdio/netbird/client/internal/peer"
|
|
||||||
"github.com/netbirdio/netbird/iface"
|
|
||||||
nbnet "github.com/netbirdio/netbird/util/net"
|
|
||||||
)
|
)
|
||||||
|
|
||||||
const (
|
// Pulled from http://man7.org/linux/man-pages/man7/rtnetlink.7.html
|
||||||
// NetbirdVPNTableID is the ID of the custom routing table used by Netbird.
|
// See the section on RTM_NEWROUTE, specifically 'struct rtmsg'.
|
||||||
NetbirdVPNTableID = 0x1BD0
|
type routeInfoInMemory struct {
|
||||||
// NetbirdVPNTableName is the name of the custom routing table used by Netbird.
|
Family byte
|
||||||
NetbirdVPNTableName = "netbird"
|
DstLen byte
|
||||||
|
SrcLen byte
|
||||||
|
TOS byte
|
||||||
|
|
||||||
// rtTablesPath is the path to the file containing the routing table names.
|
Table byte
|
||||||
rtTablesPath = "/etc/iproute2/rt_tables"
|
Protocol byte
|
||||||
|
Scope byte
|
||||||
|
Type byte
|
||||||
|
|
||||||
// ipv4ForwardingPath is the path to the file containing the IP forwarding setting.
|
Flags uint32
|
||||||
ipv4ForwardingPath = "/proc/sys/net/ipv4/ip_forward"
|
|
||||||
)
|
|
||||||
|
|
||||||
var ErrTableIDExists = errors.New("ID exists with different name")
|
|
||||||
|
|
||||||
var routeManager = &RouteManager{}
|
|
||||||
var isLegacy = os.Getenv("NB_USE_LEGACY_ROUTING") == "true"
|
|
||||||
|
|
||||||
type ruleParams struct {
|
|
||||||
fwmark int
|
|
||||||
tableID int
|
|
||||||
family int
|
|
||||||
priority int
|
|
||||||
invert bool
|
|
||||||
suppressPrefix int
|
|
||||||
description string
|
|
||||||
}
|
}
|
||||||
|
|
||||||
func getSetupRules() []ruleParams {
|
const ipv4ForwardingPath = "/proc/sys/net/ipv4/ip_forward"
|
||||||
return []ruleParams{
|
|
||||||
{nbnet.NetbirdFwmark, NetbirdVPNTableID, netlink.FAMILY_V4, -1, true, -1, "rule v4 netbird"},
|
|
||||||
{nbnet.NetbirdFwmark, NetbirdVPNTableID, netlink.FAMILY_V6, -1, true, -1, "rule v6 netbird"},
|
|
||||||
{-1, syscall.RT_TABLE_MAIN, netlink.FAMILY_V4, -1, false, 0, "rule with suppress prefixlen v4"},
|
|
||||||
{-1, syscall.RT_TABLE_MAIN, netlink.FAMILY_V6, -1, false, 0, "rule with suppress prefixlen v6"},
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
// setupRouting establishes the routing configuration for the VPN, including essential rules
|
func addToRouteTable(prefix netip.Prefix, addr string) error {
|
||||||
// to ensure proper traffic flow for management, locally configured routes, and VPN traffic.
|
_, ipNet, err := net.ParseCIDR(prefix.String())
|
||||||
//
|
|
||||||
// Rule 1 (Main Route Precedence): Safeguards locally installed routes by giving them precedence over
|
|
||||||
// potential routes received and configured for the VPN. This rule is skipped for the default route and routes
|
|
||||||
// that are not in the main table.
|
|
||||||
//
|
|
||||||
// Rule 2 (VPN Traffic Routing): Directs all remaining traffic to the 'NetbirdVPNTableID' custom routing table.
|
|
||||||
// This table is where a default route or other specific routes received from the management server are configured,
|
|
||||||
// enabling VPN connectivity.
|
|
||||||
//
|
|
||||||
// The rules are inserted in reverse order, as rules are added from the bottom up in the rule list.
|
|
||||||
func setupRouting(initAddresses []net.IP, wgIface *iface.WGIface) (_ peer.BeforeAddPeerHookFunc, _ peer.AfterRemovePeerHookFunc, err error) {
|
|
||||||
if isLegacy {
|
|
||||||
log.Infof("Using legacy routing setup")
|
|
||||||
return setupRoutingWithRouteManager(&routeManager, initAddresses, wgIface)
|
|
||||||
}
|
|
||||||
|
|
||||||
if err = addRoutingTableName(); err != nil {
|
|
||||||
log.Errorf("Error adding routing table name: %v", err)
|
|
||||||
}
|
|
||||||
|
|
||||||
defer func() {
|
|
||||||
if err != nil {
|
if err != nil {
|
||||||
if cleanErr := cleanupRouting(); cleanErr != nil {
|
return err
|
||||||
log.Errorf("Error cleaning up routing: %v", cleanErr)
|
|
||||||
}
|
|
||||||
}
|
|
||||||
}()
|
|
||||||
|
|
||||||
rules := getSetupRules()
|
|
||||||
for _, rule := range rules {
|
|
||||||
if err := addRule(rule); err != nil {
|
|
||||||
if errors.Is(err, syscall.EOPNOTSUPP) {
|
|
||||||
log.Warnf("Rule operations are not supported, falling back to the legacy routing setup")
|
|
||||||
isLegacy = true
|
|
||||||
return setupRoutingWithRouteManager(&routeManager, initAddresses, wgIface)
|
|
||||||
}
|
|
||||||
return nil, nil, fmt.Errorf("%s: %w", rule.description, err)
|
|
||||||
}
|
|
||||||
}
|
}
|
||||||
|
|
||||||
return nil, nil, nil
|
addrMask := "/32"
|
||||||
|
if prefix.Addr().Unmap().Is6() {
|
||||||
|
addrMask = "/128"
|
||||||
}
|
}
|
||||||
|
|
||||||
// cleanupRouting performs a thorough cleanup of the routing configuration established by 'setupRouting'.
|
ip, _, err := net.ParseCIDR(addr + addrMask)
|
||||||
// It systematically removes the three rules and any associated routing table entries to ensure a clean state.
|
if err != nil {
|
||||||
// The function uses error aggregation to report any errors encountered during the cleanup process.
|
return err
|
||||||
func cleanupRouting() error {
|
|
||||||
if isLegacy {
|
|
||||||
return cleanupRoutingWithRouteManager(routeManager)
|
|
||||||
}
|
}
|
||||||
|
|
||||||
var result *multierror.Error
|
route := &netlink.Route{
|
||||||
|
Scope: netlink.SCOPE_UNIVERSE,
|
||||||
if err := flushRoutes(NetbirdVPNTableID, netlink.FAMILY_V4); err != nil {
|
Dst: ipNet,
|
||||||
result = multierror.Append(result, fmt.Errorf("flush routes v4: %w", err))
|
Gw: ip,
|
||||||
}
|
|
||||||
if err := flushRoutes(NetbirdVPNTableID, netlink.FAMILY_V6); err != nil {
|
|
||||||
result = multierror.Append(result, fmt.Errorf("flush routes v6: %w", err))
|
|
||||||
}
|
}
|
||||||
|
|
||||||
rules := getSetupRules()
|
err = netlink.RouteAdd(route)
|
||||||
for _, rule := range rules {
|
if err != nil {
|
||||||
if err := removeAllRules(rule); err != nil && !errors.Is(err, syscall.EOPNOTSUPP) {
|
return err
|
||||||
result = multierror.Append(result, fmt.Errorf("%s: %w", rule.description, err))
|
|
||||||
}
|
|
||||||
}
|
}
|
||||||
|
|
||||||
return result.ErrorOrNil()
|
|
||||||
}
|
|
||||||
|
|
||||||
func addToRouteTable(prefix netip.Prefix, nexthop netip.Addr, intf string) error {
|
|
||||||
return addRoute(prefix, nexthop, intf, syscall.RT_TABLE_MAIN)
|
|
||||||
}
|
|
||||||
|
|
||||||
func removeFromRouteTable(prefix netip.Prefix, nexthop netip.Addr, intf string) error {
|
|
||||||
return removeRoute(prefix, nexthop, intf, syscall.RT_TABLE_MAIN)
|
|
||||||
}
|
|
||||||
|
|
||||||
func addVPNRoute(prefix netip.Prefix, intf string) error {
|
|
||||||
if isLegacy {
|
|
||||||
return genericAddVPNRoute(prefix, intf)
|
|
||||||
}
|
|
||||||
|
|
||||||
// No need to check if routes exist as main table takes precedence over the VPN table via Rule 1
|
|
||||||
|
|
||||||
// TODO remove this once we have ipv6 support
|
|
||||||
if prefix == defaultv4 {
|
|
||||||
if err := addUnreachableRoute(defaultv6, NetbirdVPNTableID); err != nil {
|
|
||||||
return fmt.Errorf("add blackhole: %w", err)
|
|
||||||
}
|
|
||||||
}
|
|
||||||
if err := addRoute(prefix, netip.Addr{}, intf, NetbirdVPNTableID); err != nil {
|
|
||||||
return fmt.Errorf("add route: %w", err)
|
|
||||||
}
|
|
||||||
return nil
|
return nil
|
||||||
}
|
}
|
||||||
|
|
||||||
func removeVPNRoute(prefix netip.Prefix, intf string) error {
|
func removeFromRouteTable(prefix netip.Prefix, addr string) error {
|
||||||
if isLegacy {
|
_, ipNet, err := net.ParseCIDR(prefix.String())
|
||||||
return genericRemoveVPNRoute(prefix, intf)
|
if err != nil {
|
||||||
|
return err
|
||||||
}
|
}
|
||||||
|
|
||||||
// TODO remove this once we have ipv6 support
|
addrMask := "/32"
|
||||||
if prefix == defaultv4 {
|
if prefix.Addr().Unmap().Is6() {
|
||||||
if err := removeUnreachableRoute(defaultv6, NetbirdVPNTableID); err != nil {
|
addrMask = "/128"
|
||||||
return fmt.Errorf("remove unreachable route: %w", err)
|
|
||||||
}
|
}
|
||||||
|
|
||||||
|
ip, _, err := net.ParseCIDR(addr + addrMask)
|
||||||
|
if err != nil {
|
||||||
|
return err
|
||||||
}
|
}
|
||||||
if err := removeRoute(prefix, netip.Addr{}, intf, NetbirdVPNTableID); err != nil {
|
|
||||||
return fmt.Errorf("remove route: %w", err)
|
route := &netlink.Route{
|
||||||
|
Scope: netlink.SCOPE_UNIVERSE,
|
||||||
|
Dst: ipNet,
|
||||||
|
Gw: ip,
|
||||||
}
|
}
|
||||||
|
|
||||||
|
err = netlink.RouteDel(route)
|
||||||
|
if err != nil {
|
||||||
|
return err
|
||||||
|
}
|
||||||
|
|
||||||
return nil
|
return nil
|
||||||
}
|
}
|
||||||
|
|
||||||
func getRoutesFromTable() ([]netip.Prefix, error) {
|
func getRoutesFromTable() ([]netip.Prefix, error) {
|
||||||
v4Routes, err := getRoutes(syscall.RT_TABLE_MAIN, netlink.FAMILY_V4)
|
tab, err := syscall.NetlinkRIB(syscall.RTM_GETROUTE, syscall.AF_UNSPEC)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return nil, fmt.Errorf("get v4 routes: %w", err)
|
return nil, err
|
||||||
}
|
}
|
||||||
v6Routes, err := getRoutes(syscall.RT_TABLE_MAIN, netlink.FAMILY_V6)
|
msgs, err := syscall.ParseNetlinkMessage(tab)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return nil, fmt.Errorf("get v6 routes: %w", err)
|
return nil, err
|
||||||
|
|
||||||
}
|
}
|
||||||
return append(v4Routes, v6Routes...), nil
|
|
||||||
}
|
|
||||||
|
|
||||||
// getRoutes fetches routes from a specific routing table identified by tableID.
|
|
||||||
func getRoutes(tableID, family int) ([]netip.Prefix, error) {
|
|
||||||
var prefixList []netip.Prefix
|
var prefixList []netip.Prefix
|
||||||
|
loop:
|
||||||
routes, err := netlink.RouteListFiltered(family, &netlink.Route{Table: tableID}, netlink.RT_FILTER_TABLE)
|
for _, m := range msgs {
|
||||||
|
switch m.Header.Type {
|
||||||
|
case syscall.NLMSG_DONE:
|
||||||
|
break loop
|
||||||
|
case syscall.RTM_NEWROUTE:
|
||||||
|
rt := (*routeInfoInMemory)(unsafe.Pointer(&m.Data[0]))
|
||||||
|
msg := m
|
||||||
|
attrs, err := syscall.ParseNetlinkRouteAttr(&msg)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return nil, fmt.Errorf("list routes from table %d: %v", tableID, err)
|
return nil, err
|
||||||
|
}
|
||||||
|
if rt.Family != syscall.AF_INET {
|
||||||
|
continue loop
|
||||||
}
|
}
|
||||||
|
|
||||||
for _, route := range routes {
|
for _, attr := range attrs {
|
||||||
if route.Dst != nil {
|
if attr.Attr.Type == syscall.RTA_DST {
|
||||||
addr, ok := netip.AddrFromSlice(route.Dst.IP)
|
addr, ok := netip.AddrFromSlice(attr.Value)
|
||||||
if !ok {
|
if !ok {
|
||||||
return nil, fmt.Errorf("parse route destination IP: %v", route.Dst.IP)
|
continue
|
||||||
|
}
|
||||||
|
mask := net.CIDRMask(int(rt.DstLen), len(attr.Value)*8)
|
||||||
|
cidr, _ := mask.Size()
|
||||||
|
routePrefix := netip.PrefixFrom(addr, cidr)
|
||||||
|
if routePrefix.IsValid() && routePrefix.Addr().Is4() {
|
||||||
|
prefixList = append(prefixList, routePrefix)
|
||||||
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
ones, _ := route.Dst.Mask.Size()
|
|
||||||
|
|
||||||
prefix := netip.PrefixFrom(addr, ones)
|
|
||||||
if prefix.IsValid() {
|
|
||||||
prefixList = append(prefixList, prefix)
|
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
return prefixList, nil
|
return prefixList, nil
|
||||||
}
|
}
|
||||||
|
|
||||||
// addRoute adds a route to a specific routing table identified by tableID.
|
|
||||||
func addRoute(prefix netip.Prefix, addr netip.Addr, intf string, tableID int) error {
|
|
||||||
route := &netlink.Route{
|
|
||||||
Scope: netlink.SCOPE_UNIVERSE,
|
|
||||||
Table: tableID,
|
|
||||||
Family: getAddressFamily(prefix),
|
|
||||||
}
|
|
||||||
|
|
||||||
_, ipNet, err := net.ParseCIDR(prefix.String())
|
|
||||||
if err != nil {
|
|
||||||
return fmt.Errorf("parse prefix %s: %w", prefix, err)
|
|
||||||
}
|
|
||||||
route.Dst = ipNet
|
|
||||||
|
|
||||||
if err := addNextHop(addr, intf, route); err != nil {
|
|
||||||
return fmt.Errorf("add gateway and device: %w", err)
|
|
||||||
}
|
|
||||||
|
|
||||||
if err := netlink.RouteAdd(route); err != nil && !errors.Is(err, syscall.EEXIST) && !errors.Is(err, syscall.EAFNOSUPPORT) {
|
|
||||||
return fmt.Errorf("netlink add route: %w", err)
|
|
||||||
}
|
|
||||||
|
|
||||||
return nil
|
|
||||||
}
|
|
||||||
|
|
||||||
// addUnreachableRoute adds an unreachable route for the specified IP family and routing table.
|
|
||||||
// ipFamily should be netlink.FAMILY_V4 for IPv4 or netlink.FAMILY_V6 for IPv6.
|
|
||||||
// tableID specifies the routing table to which the unreachable route will be added.
|
|
||||||
func addUnreachableRoute(prefix netip.Prefix, tableID int) error {
|
|
||||||
_, ipNet, err := net.ParseCIDR(prefix.String())
|
|
||||||
if err != nil {
|
|
||||||
return fmt.Errorf("parse prefix %s: %w", prefix, err)
|
|
||||||
}
|
|
||||||
|
|
||||||
route := &netlink.Route{
|
|
||||||
Type: syscall.RTN_UNREACHABLE,
|
|
||||||
Table: tableID,
|
|
||||||
Family: getAddressFamily(prefix),
|
|
||||||
Dst: ipNet,
|
|
||||||
}
|
|
||||||
|
|
||||||
if err := netlink.RouteAdd(route); err != nil && !errors.Is(err, syscall.EEXIST) && !errors.Is(err, syscall.EAFNOSUPPORT) {
|
|
||||||
return fmt.Errorf("netlink add unreachable route: %w", err)
|
|
||||||
}
|
|
||||||
|
|
||||||
return nil
|
|
||||||
}
|
|
||||||
|
|
||||||
func removeUnreachableRoute(prefix netip.Prefix, tableID int) error {
|
|
||||||
_, ipNet, err := net.ParseCIDR(prefix.String())
|
|
||||||
if err != nil {
|
|
||||||
return fmt.Errorf("parse prefix %s: %w", prefix, err)
|
|
||||||
}
|
|
||||||
|
|
||||||
route := &netlink.Route{
|
|
||||||
Type: syscall.RTN_UNREACHABLE,
|
|
||||||
Table: tableID,
|
|
||||||
Family: getAddressFamily(prefix),
|
|
||||||
Dst: ipNet,
|
|
||||||
}
|
|
||||||
|
|
||||||
if err := netlink.RouteDel(route); err != nil && !errors.Is(err, syscall.ESRCH) && !errors.Is(err, syscall.EAFNOSUPPORT) {
|
|
||||||
return fmt.Errorf("netlink remove unreachable route: %w", err)
|
|
||||||
}
|
|
||||||
|
|
||||||
return nil
|
|
||||||
|
|
||||||
}
|
|
||||||
|
|
||||||
// removeRoute removes a route from a specific routing table identified by tableID.
|
|
||||||
func removeRoute(prefix netip.Prefix, addr netip.Addr, intf string, tableID int) error {
|
|
||||||
_, ipNet, err := net.ParseCIDR(prefix.String())
|
|
||||||
if err != nil {
|
|
||||||
return fmt.Errorf("parse prefix %s: %w", prefix, err)
|
|
||||||
}
|
|
||||||
|
|
||||||
route := &netlink.Route{
|
|
||||||
Scope: netlink.SCOPE_UNIVERSE,
|
|
||||||
Table: tableID,
|
|
||||||
Family: getAddressFamily(prefix),
|
|
||||||
Dst: ipNet,
|
|
||||||
}
|
|
||||||
|
|
||||||
if err := addNextHop(addr, intf, route); err != nil {
|
|
||||||
return fmt.Errorf("add gateway and device: %w", err)
|
|
||||||
}
|
|
||||||
|
|
||||||
if err := netlink.RouteDel(route); err != nil && !errors.Is(err, syscall.ESRCH) && !errors.Is(err, syscall.EAFNOSUPPORT) {
|
|
||||||
return fmt.Errorf("netlink remove route: %w", err)
|
|
||||||
}
|
|
||||||
|
|
||||||
return nil
|
|
||||||
}
|
|
||||||
|
|
||||||
func flushRoutes(tableID, family int) error {
|
|
||||||
routes, err := netlink.RouteListFiltered(family, &netlink.Route{Table: tableID}, netlink.RT_FILTER_TABLE)
|
|
||||||
if err != nil {
|
|
||||||
return fmt.Errorf("list routes from table %d: %w", tableID, err)
|
|
||||||
}
|
|
||||||
|
|
||||||
var result *multierror.Error
|
|
||||||
for i := range routes {
|
|
||||||
route := routes[i]
|
|
||||||
// unreachable default routes don't come back with Dst set
|
|
||||||
if route.Gw == nil && route.Src == nil && route.Dst == nil {
|
|
||||||
if family == netlink.FAMILY_V4 {
|
|
||||||
routes[i].Dst = &net.IPNet{IP: net.IPv4zero, Mask: net.CIDRMask(0, 32)}
|
|
||||||
} else {
|
|
||||||
routes[i].Dst = &net.IPNet{IP: net.IPv6zero, Mask: net.CIDRMask(0, 128)}
|
|
||||||
}
|
|
||||||
}
|
|
||||||
if err := netlink.RouteDel(&routes[i]); err != nil && !errors.Is(err, syscall.EAFNOSUPPORT) {
|
|
||||||
result = multierror.Append(result, fmt.Errorf("failed to delete route %v from table %d: %w", routes[i], tableID, err))
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
return result.ErrorOrNil()
|
|
||||||
}
|
|
||||||
|
|
||||||
func enableIPForwarding() error {
|
func enableIPForwarding() error {
|
||||||
bytes, err := os.ReadFile(ipv4ForwardingPath)
|
bytes, err := os.ReadFile(ipv4ForwardingPath)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return fmt.Errorf("read file %s: %w", ipv4ForwardingPath, err)
|
return err
|
||||||
}
|
}
|
||||||
|
|
||||||
// check if it is already enabled
|
// check if it is already enabled
|
||||||
@ -347,162 +147,5 @@ func enableIPForwarding() error {
|
|||||||
return nil
|
return nil
|
||||||
}
|
}
|
||||||
|
|
||||||
//nolint:gosec
|
return os.WriteFile(ipv4ForwardingPath, []byte("1"), 0644) //nolint:gosec
|
||||||
if err := os.WriteFile(ipv4ForwardingPath, []byte("1"), 0644); err != nil {
|
|
||||||
return fmt.Errorf("write file %s: %w", ipv4ForwardingPath, err)
|
|
||||||
}
|
|
||||||
return nil
|
|
||||||
}
|
|
||||||
|
|
||||||
// entryExists checks if the specified ID or name already exists in the rt_tables file
|
|
||||||
// and verifies if existing names start with "netbird_".
|
|
||||||
func entryExists(file *os.File, id int) (bool, error) {
|
|
||||||
if _, err := file.Seek(0, 0); err != nil {
|
|
||||||
return false, fmt.Errorf("seek rt_tables: %w", err)
|
|
||||||
}
|
|
||||||
scanner := bufio.NewScanner(file)
|
|
||||||
for scanner.Scan() {
|
|
||||||
line := scanner.Text()
|
|
||||||
var existingID int
|
|
||||||
var existingName string
|
|
||||||
if _, err := fmt.Sscanf(line, "%d %s\n", &existingID, &existingName); err == nil {
|
|
||||||
if existingID == id {
|
|
||||||
if existingName != NetbirdVPNTableName {
|
|
||||||
return true, ErrTableIDExists
|
|
||||||
}
|
|
||||||
return true, nil
|
|
||||||
}
|
|
||||||
}
|
|
||||||
}
|
|
||||||
if err := scanner.Err(); err != nil {
|
|
||||||
return false, fmt.Errorf("scan rt_tables: %w", err)
|
|
||||||
}
|
|
||||||
return false, nil
|
|
||||||
}
|
|
||||||
|
|
||||||
// addRoutingTableName adds human-readable names for custom routing tables.
|
|
||||||
func addRoutingTableName() error {
|
|
||||||
file, err := os.Open(rtTablesPath)
|
|
||||||
if err != nil {
|
|
||||||
if errors.Is(err, os.ErrNotExist) {
|
|
||||||
return nil
|
|
||||||
}
|
|
||||||
return fmt.Errorf("open rt_tables: %w", err)
|
|
||||||
}
|
|
||||||
defer func() {
|
|
||||||
if err := file.Close(); err != nil {
|
|
||||||
log.Errorf("Error closing rt_tables: %v", err)
|
|
||||||
}
|
|
||||||
}()
|
|
||||||
|
|
||||||
exists, err := entryExists(file, NetbirdVPNTableID)
|
|
||||||
if err != nil {
|
|
||||||
return fmt.Errorf("verify entry %d, %s: %w", NetbirdVPNTableID, NetbirdVPNTableName, err)
|
|
||||||
}
|
|
||||||
if exists {
|
|
||||||
return nil
|
|
||||||
}
|
|
||||||
|
|
||||||
// Reopen the file in append mode to add new entries
|
|
||||||
if err := file.Close(); err != nil {
|
|
||||||
log.Errorf("Error closing rt_tables before appending: %v", err)
|
|
||||||
}
|
|
||||||
file, err = os.OpenFile(rtTablesPath, os.O_WRONLY|os.O_APPEND|os.O_CREATE, 0644)
|
|
||||||
if err != nil {
|
|
||||||
return fmt.Errorf("open rt_tables for appending: %w", err)
|
|
||||||
}
|
|
||||||
|
|
||||||
if _, err := file.WriteString(fmt.Sprintf("\n%d\t%s\n", NetbirdVPNTableID, NetbirdVPNTableName)); err != nil {
|
|
||||||
return fmt.Errorf("append entry to rt_tables: %w", err)
|
|
||||||
}
|
|
||||||
|
|
||||||
return nil
|
|
||||||
}
|
|
||||||
|
|
||||||
// addRule adds a routing rule to a specific routing table identified by tableID.
|
|
||||||
func addRule(params ruleParams) error {
|
|
||||||
rule := netlink.NewRule()
|
|
||||||
rule.Table = params.tableID
|
|
||||||
rule.Mark = params.fwmark
|
|
||||||
rule.Family = params.family
|
|
||||||
rule.Priority = params.priority
|
|
||||||
rule.Invert = params.invert
|
|
||||||
rule.SuppressPrefixlen = params.suppressPrefix
|
|
||||||
|
|
||||||
if err := netlink.RuleAdd(rule); err != nil && !errors.Is(err, syscall.EAFNOSUPPORT) {
|
|
||||||
return fmt.Errorf("add routing rule: %w", err)
|
|
||||||
}
|
|
||||||
|
|
||||||
return nil
|
|
||||||
}
|
|
||||||
|
|
||||||
// removeRule removes a routing rule from a specific routing table identified by tableID.
|
|
||||||
func removeRule(params ruleParams) error {
|
|
||||||
rule := netlink.NewRule()
|
|
||||||
rule.Table = params.tableID
|
|
||||||
rule.Mark = params.fwmark
|
|
||||||
rule.Family = params.family
|
|
||||||
rule.Invert = params.invert
|
|
||||||
rule.Priority = params.priority
|
|
||||||
rule.SuppressPrefixlen = params.suppressPrefix
|
|
||||||
|
|
||||||
if err := netlink.RuleDel(rule); err != nil {
|
|
||||||
return fmt.Errorf("remove routing rule: %w", err)
|
|
||||||
}
|
|
||||||
|
|
||||||
return nil
|
|
||||||
}
|
|
||||||
|
|
||||||
func removeAllRules(params ruleParams) error {
|
|
||||||
ctx, cancel := context.WithTimeout(context.Background(), 2*time.Second)
|
|
||||||
defer cancel()
|
|
||||||
|
|
||||||
done := make(chan error, 1)
|
|
||||||
go func() {
|
|
||||||
for {
|
|
||||||
if ctx.Err() != nil {
|
|
||||||
done <- ctx.Err()
|
|
||||||
return
|
|
||||||
}
|
|
||||||
if err := removeRule(params); err != nil {
|
|
||||||
if errors.Is(err, syscall.ENOENT) || errors.Is(err, syscall.EAFNOSUPPORT) {
|
|
||||||
done <- nil
|
|
||||||
return
|
|
||||||
}
|
|
||||||
done <- err
|
|
||||||
return
|
|
||||||
}
|
|
||||||
}
|
|
||||||
}()
|
|
||||||
|
|
||||||
select {
|
|
||||||
case <-ctx.Done():
|
|
||||||
return ctx.Err()
|
|
||||||
case err := <-done:
|
|
||||||
return err
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
// addNextHop adds the gateway and device to the route.
|
|
||||||
func addNextHop(addr netip.Addr, intf string, route *netlink.Route) error {
|
|
||||||
if addr.IsValid() {
|
|
||||||
route.Gw = addr.AsSlice()
|
|
||||||
}
|
|
||||||
|
|
||||||
if intf != "" {
|
|
||||||
link, err := netlink.LinkByName(intf)
|
|
||||||
if err != nil {
|
|
||||||
return fmt.Errorf("set interface %s: %w", intf, err)
|
|
||||||
}
|
|
||||||
route.LinkIndex = link.Attrs().Index
|
|
||||||
}
|
|
||||||
|
|
||||||
return nil
|
|
||||||
}
|
|
||||||
|
|
||||||
func getAddressFamily(prefix netip.Prefix) int {
|
|
||||||
if prefix.Addr().Is4() {
|
|
||||||
return netlink.FAMILY_V4
|
|
||||||
}
|
|
||||||
return netlink.FAMILY_V6
|
|
||||||
}
|
}
|
||||||
|
@ -1,207 +0,0 @@
|
|||||||
//go:build !android
|
|
||||||
|
|
||||||
package routemanager
|
|
||||||
|
|
||||||
import (
|
|
||||||
"errors"
|
|
||||||
"fmt"
|
|
||||||
"net"
|
|
||||||
"os"
|
|
||||||
"strings"
|
|
||||||
"syscall"
|
|
||||||
"testing"
|
|
||||||
|
|
||||||
"github.com/stretchr/testify/assert"
|
|
||||||
"github.com/stretchr/testify/require"
|
|
||||||
"github.com/vishvananda/netlink"
|
|
||||||
)
|
|
||||||
|
|
||||||
var expectedVPNint = "wgtest0"
|
|
||||||
var expectedLoopbackInt = "lo"
|
|
||||||
var expectedExternalInt = "dummyext0"
|
|
||||||
var expectedInternalInt = "dummyint0"
|
|
||||||
|
|
||||||
func init() {
|
|
||||||
testCases = append(testCases, []testCase{
|
|
||||||
{
|
|
||||||
name: "To more specific route without custom dialer via physical interface",
|
|
||||||
destination: "10.10.0.2:53",
|
|
||||||
expectedInterface: expectedInternalInt,
|
|
||||||
dialer: &net.Dialer{},
|
|
||||||
expectedPacket: createPacketExpectation("192.168.1.1", 12345, "10.10.0.2", 53),
|
|
||||||
},
|
|
||||||
{
|
|
||||||
name: "To more specific route (local) without custom dialer via physical interface",
|
|
||||||
destination: "127.0.10.1:53",
|
|
||||||
expectedInterface: expectedLoopbackInt,
|
|
||||||
dialer: &net.Dialer{},
|
|
||||||
expectedPacket: createPacketExpectation("127.0.0.1", 12345, "127.0.10.1", 53),
|
|
||||||
},
|
|
||||||
}...)
|
|
||||||
}
|
|
||||||
|
|
||||||
func TestEntryExists(t *testing.T) {
|
|
||||||
tempDir := t.TempDir()
|
|
||||||
tempFilePath := fmt.Sprintf("%s/rt_tables", tempDir)
|
|
||||||
|
|
||||||
content := []string{
|
|
||||||
"1000 reserved",
|
|
||||||
fmt.Sprintf("%d %s", NetbirdVPNTableID, NetbirdVPNTableName),
|
|
||||||
"9999 other_table",
|
|
||||||
}
|
|
||||||
require.NoError(t, os.WriteFile(tempFilePath, []byte(strings.Join(content, "\n")), 0644))
|
|
||||||
|
|
||||||
file, err := os.Open(tempFilePath)
|
|
||||||
require.NoError(t, err)
|
|
||||||
defer func() {
|
|
||||||
assert.NoError(t, file.Close())
|
|
||||||
}()
|
|
||||||
|
|
||||||
tests := []struct {
|
|
||||||
name string
|
|
||||||
id int
|
|
||||||
shouldExist bool
|
|
||||||
err error
|
|
||||||
}{
|
|
||||||
{
|
|
||||||
name: "ExistsWithNetbirdPrefix",
|
|
||||||
id: 7120,
|
|
||||||
shouldExist: true,
|
|
||||||
err: nil,
|
|
||||||
},
|
|
||||||
{
|
|
||||||
name: "ExistsWithDifferentName",
|
|
||||||
id: 1000,
|
|
||||||
shouldExist: true,
|
|
||||||
err: ErrTableIDExists,
|
|
||||||
},
|
|
||||||
{
|
|
||||||
name: "DoesNotExist",
|
|
||||||
id: 1234,
|
|
||||||
shouldExist: false,
|
|
||||||
err: nil,
|
|
||||||
},
|
|
||||||
}
|
|
||||||
|
|
||||||
for _, tc := range tests {
|
|
||||||
t.Run(tc.name, func(t *testing.T) {
|
|
||||||
exists, err := entryExists(file, tc.id)
|
|
||||||
if tc.err != nil {
|
|
||||||
assert.ErrorIs(t, err, tc.err)
|
|
||||||
} else {
|
|
||||||
assert.NoError(t, err)
|
|
||||||
}
|
|
||||||
assert.Equal(t, tc.shouldExist, exists)
|
|
||||||
})
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
func createAndSetupDummyInterface(t *testing.T, interfaceName, ipAddressCIDR string) string {
|
|
||||||
t.Helper()
|
|
||||||
|
|
||||||
dummy := &netlink.Dummy{LinkAttrs: netlink.LinkAttrs{Name: interfaceName}}
|
|
||||||
err := netlink.LinkDel(dummy)
|
|
||||||
if err != nil && !errors.Is(err, syscall.EINVAL) {
|
|
||||||
t.Logf("Failed to delete dummy interface: %v", err)
|
|
||||||
}
|
|
||||||
|
|
||||||
err = netlink.LinkAdd(dummy)
|
|
||||||
require.NoError(t, err)
|
|
||||||
|
|
||||||
err = netlink.LinkSetUp(dummy)
|
|
||||||
require.NoError(t, err)
|
|
||||||
|
|
||||||
if ipAddressCIDR != "" {
|
|
||||||
addr, err := netlink.ParseAddr(ipAddressCIDR)
|
|
||||||
require.NoError(t, err)
|
|
||||||
err = netlink.AddrAdd(dummy, addr)
|
|
||||||
require.NoError(t, err)
|
|
||||||
}
|
|
||||||
|
|
||||||
t.Cleanup(func() {
|
|
||||||
err := netlink.LinkDel(dummy)
|
|
||||||
assert.NoError(t, err)
|
|
||||||
})
|
|
||||||
|
|
||||||
return dummy.Name
|
|
||||||
}
|
|
||||||
|
|
||||||
func addDummyRoute(t *testing.T, dstCIDR string, gw net.IP, intf string) {
|
|
||||||
t.Helper()
|
|
||||||
|
|
||||||
_, dstIPNet, err := net.ParseCIDR(dstCIDR)
|
|
||||||
require.NoError(t, err)
|
|
||||||
|
|
||||||
// Handle existing routes with metric 0
|
|
||||||
var originalNexthop net.IP
|
|
||||||
var originalLinkIndex int
|
|
||||||
if dstIPNet.String() == "0.0.0.0/0" {
|
|
||||||
var err error
|
|
||||||
originalNexthop, originalLinkIndex, err = fetchOriginalGateway(netlink.FAMILY_V4)
|
|
||||||
if err != nil && !errors.Is(err, ErrRouteNotFound) {
|
|
||||||
t.Logf("Failed to fetch original gateway: %v", err)
|
|
||||||
}
|
|
||||||
|
|
||||||
if originalNexthop != nil {
|
|
||||||
err = netlink.RouteDel(&netlink.Route{Dst: dstIPNet, Priority: 0})
|
|
||||||
switch {
|
|
||||||
case err != nil && !errors.Is(err, syscall.ESRCH):
|
|
||||||
t.Logf("Failed to delete route: %v", err)
|
|
||||||
case err == nil:
|
|
||||||
t.Cleanup(func() {
|
|
||||||
err := netlink.RouteAdd(&netlink.Route{Dst: dstIPNet, Gw: originalNexthop, LinkIndex: originalLinkIndex, Priority: 0})
|
|
||||||
if err != nil && !errors.Is(err, syscall.EEXIST) {
|
|
||||||
t.Fatalf("Failed to add route: %v", err)
|
|
||||||
}
|
|
||||||
})
|
|
||||||
default:
|
|
||||||
t.Logf("Failed to delete route: %v", err)
|
|
||||||
}
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
link, err := netlink.LinkByName(intf)
|
|
||||||
require.NoError(t, err)
|
|
||||||
linkIndex := link.Attrs().Index
|
|
||||||
|
|
||||||
route := &netlink.Route{
|
|
||||||
Dst: dstIPNet,
|
|
||||||
Gw: gw,
|
|
||||||
LinkIndex: linkIndex,
|
|
||||||
}
|
|
||||||
err = netlink.RouteDel(route)
|
|
||||||
if err != nil && !errors.Is(err, syscall.ESRCH) {
|
|
||||||
t.Logf("Failed to delete route: %v", err)
|
|
||||||
}
|
|
||||||
|
|
||||||
err = netlink.RouteAdd(route)
|
|
||||||
if err != nil && !errors.Is(err, syscall.EEXIST) {
|
|
||||||
t.Fatalf("Failed to add route: %v", err)
|
|
||||||
}
|
|
||||||
require.NoError(t, err)
|
|
||||||
}
|
|
||||||
|
|
||||||
func fetchOriginalGateway(family int) (net.IP, int, error) {
|
|
||||||
routes, err := netlink.RouteList(nil, family)
|
|
||||||
if err != nil {
|
|
||||||
return nil, 0, err
|
|
||||||
}
|
|
||||||
|
|
||||||
for _, route := range routes {
|
|
||||||
if route.Dst == nil && route.Priority == 0 {
|
|
||||||
return route.Gw, route.LinkIndex, nil
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
return nil, 0, ErrRouteNotFound
|
|
||||||
}
|
|
||||||
|
|
||||||
func setupDummyInterfacesAndRoutes(t *testing.T) {
|
|
||||||
t.Helper()
|
|
||||||
|
|
||||||
defaultDummy := createAndSetupDummyInterface(t, "dummyext0", "192.168.0.1/24")
|
|
||||||
addDummyRoute(t, "0.0.0.0/0", net.IPv4(192, 168, 0, 1), defaultDummy)
|
|
||||||
|
|
||||||
otherDummy := createAndSetupDummyInterface(t, "dummyint0", "192.168.1.1/24")
|
|
||||||
addDummyRoute(t, "10.0.0.0/8", net.IPv4(192, 168, 1, 1), otherDummy)
|
|
||||||
}
|
|
120
client/internal/routemanager/systemops_nonandroid.go
Normal file
120
client/internal/routemanager/systemops_nonandroid.go
Normal file
@ -0,0 +1,120 @@
|
|||||||
|
//go:build !android && !ios
|
||||||
|
|
||||||
|
package routemanager
|
||||||
|
|
||||||
|
import (
|
||||||
|
"fmt"
|
||||||
|
"net"
|
||||||
|
"net/netip"
|
||||||
|
|
||||||
|
"github.com/libp2p/go-netroute"
|
||||||
|
log "github.com/sirupsen/logrus"
|
||||||
|
)
|
||||||
|
|
||||||
|
var errRouteNotFound = fmt.Errorf("route not found")
|
||||||
|
|
||||||
|
func addToRouteTableIfNoExists(prefix netip.Prefix, addr string) error {
|
||||||
|
ok, err := existsInRouteTable(prefix)
|
||||||
|
if err != nil {
|
||||||
|
return err
|
||||||
|
}
|
||||||
|
if ok {
|
||||||
|
log.Warnf("skipping adding a new route for network %s because it already exists", prefix)
|
||||||
|
return nil
|
||||||
|
}
|
||||||
|
|
||||||
|
ok, err = isSubRange(prefix)
|
||||||
|
if err != nil {
|
||||||
|
return err
|
||||||
|
}
|
||||||
|
|
||||||
|
if ok {
|
||||||
|
err := addRouteForCurrentDefaultGateway(prefix)
|
||||||
|
if err != nil {
|
||||||
|
log.Warnf("unable to add route for current default gateway route. Will proceed without it. error: %s", err)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
return addToRouteTable(prefix, addr)
|
||||||
|
}
|
||||||
|
|
||||||
|
func addRouteForCurrentDefaultGateway(prefix netip.Prefix) error {
|
||||||
|
defaultGateway, err := getExistingRIBRouteGateway(netip.MustParsePrefix("0.0.0.0/0"))
|
||||||
|
if err != nil && err != errRouteNotFound {
|
||||||
|
return err
|
||||||
|
}
|
||||||
|
|
||||||
|
addr := netip.MustParseAddr(defaultGateway.String())
|
||||||
|
|
||||||
|
if !prefix.Contains(addr) {
|
||||||
|
log.Debugf("skipping adding a new route for gateway %s because it is not in the network %s", addr, prefix)
|
||||||
|
return nil
|
||||||
|
}
|
||||||
|
|
||||||
|
gatewayPrefix := netip.PrefixFrom(addr, 32)
|
||||||
|
|
||||||
|
ok, err := existsInRouteTable(gatewayPrefix)
|
||||||
|
if err != nil {
|
||||||
|
return fmt.Errorf("unable to check if there is an existing route for gateway %s. error: %s", gatewayPrefix, err)
|
||||||
|
}
|
||||||
|
|
||||||
|
if ok {
|
||||||
|
log.Debugf("skipping adding a new route for gateway %s because it already exists", gatewayPrefix)
|
||||||
|
return nil
|
||||||
|
}
|
||||||
|
|
||||||
|
gatewayHop, err := getExistingRIBRouteGateway(gatewayPrefix)
|
||||||
|
if err != nil && err != errRouteNotFound {
|
||||||
|
return fmt.Errorf("unable to get the next hop for the default gateway address. error: %s", err)
|
||||||
|
}
|
||||||
|
log.Debugf("adding a new route for gateway %s with next hop %s", gatewayPrefix, gatewayHop)
|
||||||
|
return addToRouteTable(gatewayPrefix, gatewayHop.String())
|
||||||
|
}
|
||||||
|
|
||||||
|
func existsInRouteTable(prefix netip.Prefix) (bool, error) {
|
||||||
|
routes, err := getRoutesFromTable()
|
||||||
|
if err != nil {
|
||||||
|
return false, err
|
||||||
|
}
|
||||||
|
for _, tableRoute := range routes {
|
||||||
|
if tableRoute == prefix {
|
||||||
|
return true, nil
|
||||||
|
}
|
||||||
|
}
|
||||||
|
return false, nil
|
||||||
|
}
|
||||||
|
|
||||||
|
func isSubRange(prefix netip.Prefix) (bool, error) {
|
||||||
|
routes, err := getRoutesFromTable()
|
||||||
|
if err != nil {
|
||||||
|
return false, err
|
||||||
|
}
|
||||||
|
for _, tableRoute := range routes {
|
||||||
|
if tableRoute.Bits() > minRangeBits && tableRoute.Contains(prefix.Addr()) && tableRoute.Bits() < prefix.Bits() {
|
||||||
|
return true, nil
|
||||||
|
}
|
||||||
|
}
|
||||||
|
return false, nil
|
||||||
|
}
|
||||||
|
|
||||||
|
func removeFromRouteTableIfNonSystem(prefix netip.Prefix, addr string) error {
|
||||||
|
return removeFromRouteTable(prefix, addr)
|
||||||
|
}
|
||||||
|
|
||||||
|
func getExistingRIBRouteGateway(prefix netip.Prefix) (net.IP, error) {
|
||||||
|
r, err := netroute.New()
|
||||||
|
if err != nil {
|
||||||
|
return nil, err
|
||||||
|
}
|
||||||
|
_, gateway, preferredSrc, err := r.Route(prefix.Addr().AsSlice())
|
||||||
|
if err != nil {
|
||||||
|
log.Errorf("getting routes returned an error: %v", err)
|
||||||
|
return nil, errRouteNotFound
|
||||||
|
}
|
||||||
|
|
||||||
|
if gateway == nil {
|
||||||
|
return preferredSrc, nil
|
||||||
|
}
|
||||||
|
|
||||||
|
return gateway, nil
|
||||||
|
}
|
@ -1,32 +1,24 @@
|
|||||||
//go:build !android && !ios
|
//go:build !android
|
||||||
|
|
||||||
package routemanager
|
package routemanager
|
||||||
|
|
||||||
import (
|
import (
|
||||||
"bytes"
|
"bytes"
|
||||||
"context"
|
|
||||||
"fmt"
|
"fmt"
|
||||||
"net"
|
"net"
|
||||||
"net/netip"
|
"net/netip"
|
||||||
"os"
|
"os"
|
||||||
"runtime"
|
|
||||||
"strings"
|
"strings"
|
||||||
"testing"
|
"testing"
|
||||||
|
|
||||||
"github.com/pion/transport/v3/stdnet"
|
"github.com/pion/transport/v3/stdnet"
|
||||||
log "github.com/sirupsen/logrus"
|
log "github.com/sirupsen/logrus"
|
||||||
"github.com/stretchr/testify/assert"
|
|
||||||
"github.com/stretchr/testify/require"
|
"github.com/stretchr/testify/require"
|
||||||
"golang.zx2c4.com/wireguard/wgctrl/wgtypes"
|
"golang.zx2c4.com/wireguard/wgctrl/wgtypes"
|
||||||
|
|
||||||
"github.com/netbirdio/netbird/iface"
|
"github.com/netbirdio/netbird/iface"
|
||||||
)
|
)
|
||||||
|
|
||||||
type dialer interface {
|
|
||||||
Dial(network, address string) (net.Conn, error)
|
|
||||||
DialContext(ctx context.Context, network, address string) (net.Conn, error)
|
|
||||||
}
|
|
||||||
|
|
||||||
func TestAddRemoveRoutes(t *testing.T) {
|
func TestAddRemoveRoutes(t *testing.T) {
|
||||||
testCases := []struct {
|
testCases := []struct {
|
||||||
name string
|
name string
|
||||||
@ -61,30 +53,27 @@ func TestAddRemoveRoutes(t *testing.T) {
|
|||||||
|
|
||||||
err = wgInterface.Create()
|
err = wgInterface.Create()
|
||||||
require.NoError(t, err, "should create testing wireguard interface")
|
require.NoError(t, err, "should create testing wireguard interface")
|
||||||
_, _, err = setupRouting(nil, nil)
|
|
||||||
require.NoError(t, err)
|
|
||||||
t.Cleanup(func() {
|
|
||||||
assert.NoError(t, cleanupRouting())
|
|
||||||
})
|
|
||||||
|
|
||||||
err = genericAddVPNRoute(testCase.prefix, wgInterface.Name())
|
err = addToRouteTableIfNoExists(testCase.prefix, wgInterface.Address().IP.String())
|
||||||
require.NoError(t, err, "genericAddVPNRoute should not return err")
|
require.NoError(t, err, "addToRouteTableIfNoExists should not return err")
|
||||||
|
|
||||||
|
prefixGateway, err := getExistingRIBRouteGateway(testCase.prefix)
|
||||||
|
require.NoError(t, err, "getExistingRIBRouteGateway should not return err")
|
||||||
if testCase.shouldRouteToWireguard {
|
if testCase.shouldRouteToWireguard {
|
||||||
assertWGOutInterface(t, testCase.prefix, wgInterface, false)
|
require.Equal(t, wgInterface.Address().IP.String(), prefixGateway.String(), "route should point to wireguard interface IP")
|
||||||
} else {
|
} else {
|
||||||
assertWGOutInterface(t, testCase.prefix, wgInterface, true)
|
require.NotEqual(t, wgInterface.Address().IP.String(), prefixGateway.String(), "route should point to a different interface")
|
||||||
}
|
}
|
||||||
exists, err := existsInRouteTable(testCase.prefix)
|
exists, err := existsInRouteTable(testCase.prefix)
|
||||||
require.NoError(t, err, "existsInRouteTable should not return err")
|
require.NoError(t, err, "existsInRouteTable should not return err")
|
||||||
if exists && testCase.shouldRouteToWireguard {
|
if exists && testCase.shouldRouteToWireguard {
|
||||||
err = genericRemoveVPNRoute(testCase.prefix, wgInterface.Name())
|
err = removeFromRouteTableIfNonSystem(testCase.prefix, wgInterface.Address().IP.String())
|
||||||
require.NoError(t, err, "genericRemoveVPNRoute should not return err")
|
require.NoError(t, err, "removeFromRouteTableIfNonSystem should not return err")
|
||||||
|
|
||||||
prefixGateway, _, err := getNextHop(testCase.prefix.Addr())
|
prefixGateway, err = getExistingRIBRouteGateway(testCase.prefix)
|
||||||
require.NoError(t, err, "getNextHop should not return err")
|
require.NoError(t, err, "getExistingRIBRouteGateway should not return err")
|
||||||
|
|
||||||
internetGateway, _, err := getNextHop(netip.MustParseAddr("0.0.0.0"))
|
internetGateway, err := getExistingRIBRouteGateway(netip.MustParsePrefix("0.0.0.0/0"))
|
||||||
require.NoError(t, err)
|
require.NoError(t, err)
|
||||||
|
|
||||||
if testCase.shouldBeRemoved {
|
if testCase.shouldBeRemoved {
|
||||||
@ -97,12 +86,12 @@ func TestAddRemoveRoutes(t *testing.T) {
|
|||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
func TestGetNextHop(t *testing.T) {
|
func TestGetExistingRIBRouteGateway(t *testing.T) {
|
||||||
gateway, _, err := getNextHop(netip.MustParseAddr("0.0.0.0"))
|
gateway, err := getExistingRIBRouteGateway(netip.MustParsePrefix("0.0.0.0/0"))
|
||||||
if err != nil {
|
if err != nil {
|
||||||
t.Fatal("shouldn't return error when fetching the gateway: ", err)
|
t.Fatal("shouldn't return error when fetching the gateway: ", err)
|
||||||
}
|
}
|
||||||
if !gateway.IsValid() {
|
if gateway == nil {
|
||||||
t.Fatal("should return a gateway")
|
t.Fatal("should return a gateway")
|
||||||
}
|
}
|
||||||
addresses, err := net.InterfaceAddrs()
|
addresses, err := net.InterfaceAddrs()
|
||||||
@ -124,11 +113,11 @@ func TestGetNextHop(t *testing.T) {
|
|||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
localIP, _, err := getNextHop(testingPrefix.Addr())
|
localIP, err := getExistingRIBRouteGateway(testingPrefix)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
t.Fatal("shouldn't return error: ", err)
|
t.Fatal("shouldn't return error: ", err)
|
||||||
}
|
}
|
||||||
if !localIP.IsValid() {
|
if localIP == nil {
|
||||||
t.Fatal("should return a gateway for local network")
|
t.Fatal("should return a gateway for local network")
|
||||||
}
|
}
|
||||||
if localIP.String() == gateway.String() {
|
if localIP.String() == gateway.String() {
|
||||||
@ -139,8 +128,8 @@ func TestGetNextHop(t *testing.T) {
|
|||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
func TestAddExistAndRemoveRoute(t *testing.T) {
|
func TestAddExistAndRemoveRouteNonAndroid(t *testing.T) {
|
||||||
defaultGateway, _, err := getNextHop(netip.MustParseAddr("0.0.0.0"))
|
defaultGateway, err := getExistingRIBRouteGateway(netip.MustParsePrefix("0.0.0.0/0"))
|
||||||
t.Log("defaultGateway: ", defaultGateway)
|
t.Log("defaultGateway: ", defaultGateway)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
t.Fatal("shouldn't return error when fetching the gateway: ", err)
|
t.Fatal("shouldn't return error when fetching the gateway: ", err)
|
||||||
@ -200,14 +189,16 @@ func TestAddExistAndRemoveRoute(t *testing.T) {
|
|||||||
err = wgInterface.Create()
|
err = wgInterface.Create()
|
||||||
require.NoError(t, err, "should create testing wireguard interface")
|
require.NoError(t, err, "should create testing wireguard interface")
|
||||||
|
|
||||||
|
MockAddr := wgInterface.Address().IP.String()
|
||||||
|
|
||||||
// Prepare the environment
|
// Prepare the environment
|
||||||
if testCase.preExistingPrefix.IsValid() {
|
if testCase.preExistingPrefix.IsValid() {
|
||||||
err := genericAddVPNRoute(testCase.preExistingPrefix, wgInterface.Name())
|
err := addToRouteTableIfNoExists(testCase.preExistingPrefix, MockAddr)
|
||||||
require.NoError(t, err, "should not return err when adding pre-existing route")
|
require.NoError(t, err, "should not return err when adding pre-existing route")
|
||||||
}
|
}
|
||||||
|
|
||||||
// Add the route
|
// Add the route
|
||||||
err = genericAddVPNRoute(testCase.prefix, wgInterface.Name())
|
err = addToRouteTableIfNoExists(testCase.prefix, MockAddr)
|
||||||
require.NoError(t, err, "should not return err when adding route")
|
require.NoError(t, err, "should not return err when adding route")
|
||||||
|
|
||||||
if testCase.shouldAddRoute {
|
if testCase.shouldAddRoute {
|
||||||
@ -217,7 +208,7 @@ func TestAddExistAndRemoveRoute(t *testing.T) {
|
|||||||
require.True(t, ok, "route should exist")
|
require.True(t, ok, "route should exist")
|
||||||
|
|
||||||
// remove route again if added
|
// remove route again if added
|
||||||
err = genericRemoveVPNRoute(testCase.prefix, wgInterface.Name())
|
err = removeFromRouteTableIfNonSystem(testCase.prefix, MockAddr)
|
||||||
require.NoError(t, err, "should not return err")
|
require.NoError(t, err, "should not return err")
|
||||||
}
|
}
|
||||||
|
|
||||||
@ -226,7 +217,6 @@ func TestAddExistAndRemoveRoute(t *testing.T) {
|
|||||||
ok, err := existsInRouteTable(testCase.prefix)
|
ok, err := existsInRouteTable(testCase.prefix)
|
||||||
t.Log("Buffer string: ", buf.String())
|
t.Log("Buffer string: ", buf.String())
|
||||||
require.NoError(t, err, "should not return err")
|
require.NoError(t, err, "should not return err")
|
||||||
|
|
||||||
if !strings.Contains(buf.String(), "because it already exists") {
|
if !strings.Contains(buf.String(), "because it already exists") {
|
||||||
require.False(t, ok, "route should not exist")
|
require.False(t, ok, "route should not exist")
|
||||||
}
|
}
|
||||||
@ -234,6 +224,31 @@ func TestAddExistAndRemoveRoute(t *testing.T) {
|
|||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
|
func TestExistsInRouteTable(t *testing.T) {
|
||||||
|
addresses, err := net.InterfaceAddrs()
|
||||||
|
if err != nil {
|
||||||
|
t.Fatal("shouldn't return error when fetching interface addresses: ", err)
|
||||||
|
}
|
||||||
|
|
||||||
|
var addressPrefixes []netip.Prefix
|
||||||
|
for _, address := range addresses {
|
||||||
|
p := netip.MustParsePrefix(address.String())
|
||||||
|
if p.Addr().Is4() {
|
||||||
|
addressPrefixes = append(addressPrefixes, p.Masked())
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
for _, prefix := range addressPrefixes {
|
||||||
|
exists, err := existsInRouteTable(prefix)
|
||||||
|
if err != nil {
|
||||||
|
t.Fatal("shouldn't return error when checking if address exists in route table: ", err)
|
||||||
|
}
|
||||||
|
if !exists {
|
||||||
|
t.Fatalf("address %s should exist in route table", prefix)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
func TestIsSubRange(t *testing.T) {
|
func TestIsSubRange(t *testing.T) {
|
||||||
addresses, err := net.InterfaceAddrs()
|
addresses, err := net.InterfaceAddrs()
|
||||||
if err != nil {
|
if err != nil {
|
||||||
@ -271,132 +286,3 @@ func TestIsSubRange(t *testing.T) {
|
|||||||
}
|
}
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
func TestExistsInRouteTable(t *testing.T) {
|
|
||||||
addresses, err := net.InterfaceAddrs()
|
|
||||||
if err != nil {
|
|
||||||
t.Fatal("shouldn't return error when fetching interface addresses: ", err)
|
|
||||||
}
|
|
||||||
|
|
||||||
var addressPrefixes []netip.Prefix
|
|
||||||
for _, address := range addresses {
|
|
||||||
p := netip.MustParsePrefix(address.String())
|
|
||||||
if p.Addr().Is6() {
|
|
||||||
continue
|
|
||||||
}
|
|
||||||
// Windows sometimes has hidden interface link local addrs that don't turn up on any interface
|
|
||||||
if runtime.GOOS == "windows" && p.Addr().IsLinkLocalUnicast() {
|
|
||||||
continue
|
|
||||||
}
|
|
||||||
// Linux loopback 127/8 is in the local table, not in the main table and always takes precedence
|
|
||||||
if runtime.GOOS == "linux" && p.Addr().IsLoopback() {
|
|
||||||
continue
|
|
||||||
}
|
|
||||||
|
|
||||||
addressPrefixes = append(addressPrefixes, p.Masked())
|
|
||||||
}
|
|
||||||
|
|
||||||
for _, prefix := range addressPrefixes {
|
|
||||||
exists, err := existsInRouteTable(prefix)
|
|
||||||
if err != nil {
|
|
||||||
t.Fatal("shouldn't return error when checking if address exists in route table: ", err)
|
|
||||||
}
|
|
||||||
if !exists {
|
|
||||||
t.Fatalf("address %s should exist in route table", prefix)
|
|
||||||
}
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
func createWGInterface(t *testing.T, interfaceName, ipAddressCIDR string, listenPort int) *iface.WGIface {
|
|
||||||
t.Helper()
|
|
||||||
|
|
||||||
peerPrivateKey, err := wgtypes.GeneratePrivateKey()
|
|
||||||
require.NoError(t, err)
|
|
||||||
|
|
||||||
newNet, err := stdnet.NewNet()
|
|
||||||
require.NoError(t, err)
|
|
||||||
|
|
||||||
wgInterface, err := iface.NewWGIFace(interfaceName, ipAddressCIDR, listenPort, peerPrivateKey.String(), iface.DefaultMTU, newNet, nil)
|
|
||||||
require.NoError(t, err, "should create testing WireGuard interface")
|
|
||||||
|
|
||||||
err = wgInterface.Create()
|
|
||||||
require.NoError(t, err, "should create testing WireGuard interface")
|
|
||||||
|
|
||||||
t.Cleanup(func() {
|
|
||||||
wgInterface.Close()
|
|
||||||
})
|
|
||||||
|
|
||||||
return wgInterface
|
|
||||||
}
|
|
||||||
|
|
||||||
func setupTestEnv(t *testing.T) {
|
|
||||||
t.Helper()
|
|
||||||
|
|
||||||
setupDummyInterfacesAndRoutes(t)
|
|
||||||
|
|
||||||
wgIface := createWGInterface(t, expectedVPNint, "100.64.0.1/24", 51820)
|
|
||||||
t.Cleanup(func() {
|
|
||||||
assert.NoError(t, wgIface.Close())
|
|
||||||
})
|
|
||||||
|
|
||||||
_, _, err := setupRouting(nil, wgIface)
|
|
||||||
require.NoError(t, err, "setupRouting should not return err")
|
|
||||||
t.Cleanup(func() {
|
|
||||||
assert.NoError(t, cleanupRouting())
|
|
||||||
})
|
|
||||||
|
|
||||||
// default route exists in main table and vpn table
|
|
||||||
err = addVPNRoute(netip.MustParsePrefix("0.0.0.0/0"), wgIface.Name())
|
|
||||||
require.NoError(t, err, "addVPNRoute should not return err")
|
|
||||||
t.Cleanup(func() {
|
|
||||||
err = removeVPNRoute(netip.MustParsePrefix("0.0.0.0/0"), wgIface.Name())
|
|
||||||
assert.NoError(t, err, "removeVPNRoute should not return err")
|
|
||||||
})
|
|
||||||
|
|
||||||
// 10.0.0.0/8 route exists in main table and vpn table
|
|
||||||
err = addVPNRoute(netip.MustParsePrefix("10.0.0.0/8"), wgIface.Name())
|
|
||||||
require.NoError(t, err, "addVPNRoute should not return err")
|
|
||||||
t.Cleanup(func() {
|
|
||||||
err = removeVPNRoute(netip.MustParsePrefix("10.0.0.0/8"), wgIface.Name())
|
|
||||||
assert.NoError(t, err, "removeVPNRoute should not return err")
|
|
||||||
})
|
|
||||||
|
|
||||||
// 10.10.0.0/24 more specific route exists in vpn table
|
|
||||||
err = addVPNRoute(netip.MustParsePrefix("10.10.0.0/24"), wgIface.Name())
|
|
||||||
require.NoError(t, err, "addVPNRoute should not return err")
|
|
||||||
t.Cleanup(func() {
|
|
||||||
err = removeVPNRoute(netip.MustParsePrefix("10.10.0.0/24"), wgIface.Name())
|
|
||||||
assert.NoError(t, err, "removeVPNRoute should not return err")
|
|
||||||
})
|
|
||||||
|
|
||||||
// 127.0.10.0/24 more specific route exists in vpn table
|
|
||||||
err = addVPNRoute(netip.MustParsePrefix("127.0.10.0/24"), wgIface.Name())
|
|
||||||
require.NoError(t, err, "addVPNRoute should not return err")
|
|
||||||
t.Cleanup(func() {
|
|
||||||
err = removeVPNRoute(netip.MustParsePrefix("127.0.10.0/24"), wgIface.Name())
|
|
||||||
assert.NoError(t, err, "removeVPNRoute should not return err")
|
|
||||||
})
|
|
||||||
|
|
||||||
// unique route in vpn table
|
|
||||||
err = addVPNRoute(netip.MustParsePrefix("172.16.0.0/12"), wgIface.Name())
|
|
||||||
require.NoError(t, err, "addVPNRoute should not return err")
|
|
||||||
t.Cleanup(func() {
|
|
||||||
err = removeVPNRoute(netip.MustParsePrefix("172.16.0.0/12"), wgIface.Name())
|
|
||||||
assert.NoError(t, err, "removeVPNRoute should not return err")
|
|
||||||
})
|
|
||||||
}
|
|
||||||
|
|
||||||
func assertWGOutInterface(t *testing.T, prefix netip.Prefix, wgIface *iface.WGIface, invert bool) {
|
|
||||||
t.Helper()
|
|
||||||
if runtime.GOOS == "linux" && prefix.Addr().IsLoopback() {
|
|
||||||
return
|
|
||||||
}
|
|
||||||
|
|
||||||
prefixGateway, _, err := getNextHop(prefix.Addr())
|
|
||||||
require.NoError(t, err, "getNextHop should not return err")
|
|
||||||
if invert {
|
|
||||||
assert.NotEqual(t, wgIface.Address().IP.String(), prefixGateway.String(), "route should not point to wireguard interface IP")
|
|
||||||
} else {
|
|
||||||
assert.Equal(t, wgIface.Address().IP.String(), prefixGateway.String(), "route should point to wireguard interface IP")
|
|
||||||
}
|
|
||||||
}
|
|
@ -1,23 +1,41 @@
|
|||||||
//go:build !linux && !ios
|
//go:build !linux
|
||||||
|
// +build !linux
|
||||||
|
|
||||||
package routemanager
|
package routemanager
|
||||||
|
|
||||||
import (
|
import (
|
||||||
"net/netip"
|
"net/netip"
|
||||||
|
"os/exec"
|
||||||
"runtime"
|
"runtime"
|
||||||
|
|
||||||
log "github.com/sirupsen/logrus"
|
log "github.com/sirupsen/logrus"
|
||||||
)
|
)
|
||||||
|
|
||||||
func enableIPForwarding() error {
|
func addToRouteTable(prefix netip.Prefix, addr string) error {
|
||||||
log.Infof("Enable IP forwarding is not implemented on %s", runtime.GOOS)
|
cmd := exec.Command("route", "add", prefix.String(), addr)
|
||||||
|
out, err := cmd.Output()
|
||||||
|
if err != nil {
|
||||||
|
return err
|
||||||
|
}
|
||||||
|
log.Debugf(string(out))
|
||||||
return nil
|
return nil
|
||||||
}
|
}
|
||||||
|
|
||||||
func addVPNRoute(prefix netip.Prefix, intf string) error {
|
func removeFromRouteTable(prefix netip.Prefix, addr string) error {
|
||||||
return genericAddVPNRoute(prefix, intf)
|
args := []string{"delete", prefix.String()}
|
||||||
|
if runtime.GOOS == "darwin" {
|
||||||
|
args = append(args, addr)
|
||||||
|
}
|
||||||
|
cmd := exec.Command("route", args...)
|
||||||
|
out, err := cmd.Output()
|
||||||
|
if err != nil {
|
||||||
|
return err
|
||||||
|
}
|
||||||
|
log.Debugf(string(out))
|
||||||
|
return nil
|
||||||
}
|
}
|
||||||
|
|
||||||
func removeVPNRoute(prefix netip.Prefix, intf string) error {
|
func enableIPForwarding() error {
|
||||||
return genericRemoveVPNRoute(prefix, intf)
|
log.Infof("enable IP forwarding is not implemented on %s", runtime.GOOS)
|
||||||
|
return nil
|
||||||
}
|
}
|
||||||
|
@ -1,234 +0,0 @@
|
|||||||
//go:build (linux && !android) || (darwin && !ios) || freebsd || openbsd || netbsd || dragonfly
|
|
||||||
|
|
||||||
package routemanager
|
|
||||||
|
|
||||||
import (
|
|
||||||
"fmt"
|
|
||||||
"net"
|
|
||||||
"strings"
|
|
||||||
"testing"
|
|
||||||
"time"
|
|
||||||
|
|
||||||
"github.com/gopacket/gopacket"
|
|
||||||
"github.com/gopacket/gopacket/layers"
|
|
||||||
"github.com/gopacket/gopacket/pcap"
|
|
||||||
"github.com/miekg/dns"
|
|
||||||
"github.com/stretchr/testify/assert"
|
|
||||||
"github.com/stretchr/testify/require"
|
|
||||||
|
|
||||||
nbnet "github.com/netbirdio/netbird/util/net"
|
|
||||||
)
|
|
||||||
|
|
||||||
type PacketExpectation struct {
|
|
||||||
SrcIP net.IP
|
|
||||||
DstIP net.IP
|
|
||||||
SrcPort int
|
|
||||||
DstPort int
|
|
||||||
UDP bool
|
|
||||||
TCP bool
|
|
||||||
}
|
|
||||||
|
|
||||||
type testCase struct {
|
|
||||||
name string
|
|
||||||
destination string
|
|
||||||
expectedInterface string
|
|
||||||
dialer dialer
|
|
||||||
expectedPacket PacketExpectation
|
|
||||||
}
|
|
||||||
|
|
||||||
var testCases = []testCase{
|
|
||||||
{
|
|
||||||
name: "To external host without custom dialer via vpn",
|
|
||||||
destination: "192.0.2.1:53",
|
|
||||||
expectedInterface: expectedVPNint,
|
|
||||||
dialer: &net.Dialer{},
|
|
||||||
expectedPacket: createPacketExpectation("100.64.0.1", 12345, "192.0.2.1", 53),
|
|
||||||
},
|
|
||||||
{
|
|
||||||
name: "To external host with custom dialer via physical interface",
|
|
||||||
destination: "192.0.2.1:53",
|
|
||||||
expectedInterface: expectedExternalInt,
|
|
||||||
dialer: nbnet.NewDialer(),
|
|
||||||
expectedPacket: createPacketExpectation("192.168.0.1", 12345, "192.0.2.1", 53),
|
|
||||||
},
|
|
||||||
|
|
||||||
{
|
|
||||||
name: "To duplicate internal route with custom dialer via physical interface",
|
|
||||||
destination: "10.0.0.2:53",
|
|
||||||
expectedInterface: expectedInternalInt,
|
|
||||||
dialer: nbnet.NewDialer(),
|
|
||||||
expectedPacket: createPacketExpectation("192.168.1.1", 12345, "10.0.0.2", 53),
|
|
||||||
},
|
|
||||||
{
|
|
||||||
name: "To duplicate internal route without custom dialer via physical interface", // local route takes precedence
|
|
||||||
destination: "10.0.0.2:53",
|
|
||||||
expectedInterface: expectedInternalInt,
|
|
||||||
dialer: &net.Dialer{},
|
|
||||||
expectedPacket: createPacketExpectation("192.168.1.1", 12345, "10.0.0.2", 53),
|
|
||||||
},
|
|
||||||
|
|
||||||
{
|
|
||||||
name: "To unique vpn route with custom dialer via physical interface",
|
|
||||||
destination: "172.16.0.2:53",
|
|
||||||
expectedInterface: expectedExternalInt,
|
|
||||||
dialer: nbnet.NewDialer(),
|
|
||||||
expectedPacket: createPacketExpectation("192.168.0.1", 12345, "172.16.0.2", 53),
|
|
||||||
},
|
|
||||||
{
|
|
||||||
name: "To unique vpn route without custom dialer via vpn",
|
|
||||||
destination: "172.16.0.2:53",
|
|
||||||
expectedInterface: expectedVPNint,
|
|
||||||
dialer: &net.Dialer{},
|
|
||||||
expectedPacket: createPacketExpectation("100.64.0.1", 12345, "172.16.0.2", 53),
|
|
||||||
},
|
|
||||||
}
|
|
||||||
|
|
||||||
func TestRouting(t *testing.T) {
|
|
||||||
for _, tc := range testCases {
|
|
||||||
t.Run(tc.name, func(t *testing.T) {
|
|
||||||
setupTestEnv(t)
|
|
||||||
|
|
||||||
filter := createBPFFilter(tc.destination)
|
|
||||||
handle := startPacketCapture(t, tc.expectedInterface, filter)
|
|
||||||
|
|
||||||
sendTestPacket(t, tc.destination, tc.expectedPacket.SrcPort, tc.dialer)
|
|
||||||
|
|
||||||
packetSource := gopacket.NewPacketSource(handle, handle.LinkType())
|
|
||||||
packet, err := packetSource.NextPacket()
|
|
||||||
require.NoError(t, err)
|
|
||||||
|
|
||||||
verifyPacket(t, packet, tc.expectedPacket)
|
|
||||||
})
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
func createPacketExpectation(srcIP string, srcPort int, dstIP string, dstPort int) PacketExpectation {
|
|
||||||
return PacketExpectation{
|
|
||||||
SrcIP: net.ParseIP(srcIP),
|
|
||||||
DstIP: net.ParseIP(dstIP),
|
|
||||||
SrcPort: srcPort,
|
|
||||||
DstPort: dstPort,
|
|
||||||
UDP: true,
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
func startPacketCapture(t *testing.T, intf, filter string) *pcap.Handle {
|
|
||||||
t.Helper()
|
|
||||||
|
|
||||||
inactive, err := pcap.NewInactiveHandle(intf)
|
|
||||||
require.NoError(t, err, "Failed to create inactive pcap handle")
|
|
||||||
defer inactive.CleanUp()
|
|
||||||
|
|
||||||
err = inactive.SetSnapLen(1600)
|
|
||||||
require.NoError(t, err, "Failed to set snap length on inactive handle")
|
|
||||||
|
|
||||||
err = inactive.SetTimeout(time.Second * 10)
|
|
||||||
require.NoError(t, err, "Failed to set timeout on inactive handle")
|
|
||||||
|
|
||||||
err = inactive.SetImmediateMode(true)
|
|
||||||
require.NoError(t, err, "Failed to set immediate mode on inactive handle")
|
|
||||||
|
|
||||||
handle, err := inactive.Activate()
|
|
||||||
require.NoError(t, err, "Failed to activate pcap handle")
|
|
||||||
t.Cleanup(handle.Close)
|
|
||||||
|
|
||||||
err = handle.SetBPFFilter(filter)
|
|
||||||
require.NoError(t, err, "Failed to set BPF filter")
|
|
||||||
|
|
||||||
return handle
|
|
||||||
}
|
|
||||||
|
|
||||||
func sendTestPacket(t *testing.T, destination string, sourcePort int, dialer dialer) {
|
|
||||||
t.Helper()
|
|
||||||
|
|
||||||
if dialer == nil {
|
|
||||||
dialer = &net.Dialer{}
|
|
||||||
}
|
|
||||||
|
|
||||||
if sourcePort != 0 {
|
|
||||||
localUDPAddr := &net.UDPAddr{
|
|
||||||
IP: net.IPv4zero,
|
|
||||||
Port: sourcePort,
|
|
||||||
}
|
|
||||||
switch dialer := dialer.(type) {
|
|
||||||
case *nbnet.Dialer:
|
|
||||||
dialer.LocalAddr = localUDPAddr
|
|
||||||
case *net.Dialer:
|
|
||||||
dialer.LocalAddr = localUDPAddr
|
|
||||||
default:
|
|
||||||
t.Fatal("Unsupported dialer type")
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
msg := new(dns.Msg)
|
|
||||||
msg.Id = dns.Id()
|
|
||||||
msg.RecursionDesired = true
|
|
||||||
msg.Question = []dns.Question{
|
|
||||||
{Name: "example.com.", Qtype: dns.TypeA, Qclass: dns.ClassINET},
|
|
||||||
}
|
|
||||||
|
|
||||||
conn, err := dialer.Dial("udp", destination)
|
|
||||||
require.NoError(t, err, "Failed to dial UDP")
|
|
||||||
defer conn.Close()
|
|
||||||
|
|
||||||
data, err := msg.Pack()
|
|
||||||
require.NoError(t, err, "Failed to pack DNS message")
|
|
||||||
|
|
||||||
_, err = conn.Write(data)
|
|
||||||
if err != nil {
|
|
||||||
if strings.Contains(err.Error(), "required key not available") {
|
|
||||||
t.Logf("Ignoring WireGuard key error: %v", err)
|
|
||||||
return
|
|
||||||
}
|
|
||||||
t.Fatalf("Failed to send DNS query: %v", err)
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
func createBPFFilter(destination string) string {
|
|
||||||
host, port, err := net.SplitHostPort(destination)
|
|
||||||
if err != nil {
|
|
||||||
return fmt.Sprintf("udp and dst host %s and dst port %s", host, port)
|
|
||||||
}
|
|
||||||
return "udp"
|
|
||||||
}
|
|
||||||
|
|
||||||
func verifyPacket(t *testing.T, packet gopacket.Packet, exp PacketExpectation) {
|
|
||||||
t.Helper()
|
|
||||||
|
|
||||||
ipLayer := packet.Layer(layers.LayerTypeIPv4)
|
|
||||||
require.NotNil(t, ipLayer, "Expected IPv4 layer not found in packet")
|
|
||||||
|
|
||||||
ip, ok := ipLayer.(*layers.IPv4)
|
|
||||||
require.True(t, ok, "Failed to cast to IPv4 layer")
|
|
||||||
|
|
||||||
// Convert both source and destination IP addresses to 16-byte representation
|
|
||||||
expectedSrcIP := exp.SrcIP.To16()
|
|
||||||
actualSrcIP := ip.SrcIP.To16()
|
|
||||||
assert.Equal(t, expectedSrcIP, actualSrcIP, "Source IP mismatch")
|
|
||||||
|
|
||||||
expectedDstIP := exp.DstIP.To16()
|
|
||||||
actualDstIP := ip.DstIP.To16()
|
|
||||||
assert.Equal(t, expectedDstIP, actualDstIP, "Destination IP mismatch")
|
|
||||||
|
|
||||||
if exp.UDP {
|
|
||||||
udpLayer := packet.Layer(layers.LayerTypeUDP)
|
|
||||||
require.NotNil(t, udpLayer, "Expected UDP layer not found in packet")
|
|
||||||
|
|
||||||
udp, ok := udpLayer.(*layers.UDP)
|
|
||||||
require.True(t, ok, "Failed to cast to UDP layer")
|
|
||||||
|
|
||||||
assert.Equal(t, layers.UDPPort(exp.SrcPort), udp.SrcPort, "UDP source port mismatch")
|
|
||||||
assert.Equal(t, layers.UDPPort(exp.DstPort), udp.DstPort, "UDP destination port mismatch")
|
|
||||||
}
|
|
||||||
|
|
||||||
if exp.TCP {
|
|
||||||
tcpLayer := packet.Layer(layers.LayerTypeTCP)
|
|
||||||
require.NotNil(t, tcpLayer, "Expected TCP layer not found in packet")
|
|
||||||
|
|
||||||
tcp, ok := tcpLayer.(*layers.TCP)
|
|
||||||
require.True(t, ok, "Failed to cast to TCP layer")
|
|
||||||
|
|
||||||
assert.Equal(t, layers.TCPPort(exp.SrcPort), tcp.SrcPort, "TCP source port mismatch")
|
|
||||||
assert.Equal(t, layers.TCPPort(exp.DstPort), tcp.DstPort, "TCP destination port mismatch")
|
|
||||||
}
|
|
||||||
}
|
|
@ -1,19 +1,13 @@
|
|||||||
//go:build windows
|
//go:build windows
|
||||||
|
// +build windows
|
||||||
|
|
||||||
package routemanager
|
package routemanager
|
||||||
|
|
||||||
import (
|
import (
|
||||||
"fmt"
|
|
||||||
"net"
|
"net"
|
||||||
"net/netip"
|
"net/netip"
|
||||||
"os/exec"
|
|
||||||
"strings"
|
|
||||||
|
|
||||||
log "github.com/sirupsen/logrus"
|
|
||||||
"github.com/yusufpapurcu/wmi"
|
"github.com/yusufpapurcu/wmi"
|
||||||
|
|
||||||
"github.com/netbirdio/netbird/client/internal/peer"
|
|
||||||
"github.com/netbirdio/netbird/iface"
|
|
||||||
)
|
)
|
||||||
|
|
||||||
type Win32_IP4RouteTable struct {
|
type Win32_IP4RouteTable struct {
|
||||||
@ -21,35 +15,23 @@ type Win32_IP4RouteTable struct {
|
|||||||
Mask string
|
Mask string
|
||||||
}
|
}
|
||||||
|
|
||||||
var routeManager *RouteManager
|
|
||||||
|
|
||||||
func setupRouting(initAddresses []net.IP, wgIface *iface.WGIface) (peer.BeforeAddPeerHookFunc, peer.AfterRemovePeerHookFunc, error) {
|
|
||||||
return setupRoutingWithRouteManager(&routeManager, initAddresses, wgIface)
|
|
||||||
}
|
|
||||||
|
|
||||||
func cleanupRouting() error {
|
|
||||||
return cleanupRoutingWithRouteManager(routeManager)
|
|
||||||
}
|
|
||||||
|
|
||||||
func getRoutesFromTable() ([]netip.Prefix, error) {
|
func getRoutesFromTable() ([]netip.Prefix, error) {
|
||||||
var routes []Win32_IP4RouteTable
|
var routes []Win32_IP4RouteTable
|
||||||
query := "SELECT Destination, Mask FROM Win32_IP4RouteTable"
|
query := "SELECT Destination, Mask FROM Win32_IP4RouteTable"
|
||||||
|
|
||||||
err := wmi.Query(query, &routes)
|
err := wmi.Query(query, &routes)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return nil, fmt.Errorf("get routes: %w", err)
|
return nil, err
|
||||||
}
|
}
|
||||||
|
|
||||||
var prefixList []netip.Prefix
|
var prefixList []netip.Prefix
|
||||||
for _, route := range routes {
|
for _, route := range routes {
|
||||||
addr, err := netip.ParseAddr(route.Destination)
|
addr, err := netip.ParseAddr(route.Destination)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
log.Warnf("Unable to parse route destination %s: %v", route.Destination, err)
|
|
||||||
continue
|
continue
|
||||||
}
|
}
|
||||||
maskSlice := net.ParseIP(route.Mask).To4()
|
maskSlice := net.ParseIP(route.Mask).To4()
|
||||||
if maskSlice == nil {
|
if maskSlice == nil {
|
||||||
log.Warnf("Unable to parse route mask %s", route.Mask)
|
|
||||||
continue
|
continue
|
||||||
}
|
}
|
||||||
mask := net.IPv4Mask(maskSlice[0], maskSlice[1], maskSlice[2], maskSlice[3])
|
mask := net.IPv4Mask(maskSlice[0], maskSlice[1], maskSlice[2], maskSlice[3])
|
||||||
@ -62,69 +44,3 @@ func getRoutesFromTable() ([]netip.Prefix, error) {
|
|||||||
}
|
}
|
||||||
return prefixList, nil
|
return prefixList, nil
|
||||||
}
|
}
|
||||||
|
|
||||||
func addRoutePowershell(prefix netip.Prefix, nexthop netip.Addr, intf string) error {
|
|
||||||
destinationPrefix := prefix.String()
|
|
||||||
psCmd := "New-NetRoute"
|
|
||||||
|
|
||||||
addressFamily := "IPv4"
|
|
||||||
if prefix.Addr().Is6() {
|
|
||||||
addressFamily = "IPv6"
|
|
||||||
}
|
|
||||||
|
|
||||||
script := fmt.Sprintf(
|
|
||||||
`%s -AddressFamily "%s" -DestinationPrefix "%s" -InterfaceAlias "%s" -Confirm:$False -ErrorAction Stop`,
|
|
||||||
psCmd, addressFamily, destinationPrefix, intf,
|
|
||||||
)
|
|
||||||
|
|
||||||
if nexthop.IsValid() {
|
|
||||||
script = fmt.Sprintf(
|
|
||||||
`%s -NextHop "%s"`, script, nexthop,
|
|
||||||
)
|
|
||||||
}
|
|
||||||
|
|
||||||
out, err := exec.Command("powershell", "-Command", script).CombinedOutput()
|
|
||||||
log.Tracef("PowerShell add route: %s", string(out))
|
|
||||||
|
|
||||||
if err != nil {
|
|
||||||
return fmt.Errorf("PowerShell add route: %w", err)
|
|
||||||
}
|
|
||||||
|
|
||||||
return nil
|
|
||||||
}
|
|
||||||
|
|
||||||
func addRouteCmd(prefix netip.Prefix, nexthop netip.Addr, _ string) error {
|
|
||||||
args := []string{"add", prefix.String(), nexthop.Unmap().String()}
|
|
||||||
|
|
||||||
out, err := exec.Command("route", args...).CombinedOutput()
|
|
||||||
|
|
||||||
log.Tracef("route %s output: %s", strings.Join(args, " "), out)
|
|
||||||
if err != nil {
|
|
||||||
return fmt.Errorf("route add: %w", err)
|
|
||||||
}
|
|
||||||
|
|
||||||
return nil
|
|
||||||
}
|
|
||||||
|
|
||||||
func addToRouteTable(prefix netip.Prefix, nexthop netip.Addr, intf string) error {
|
|
||||||
// Powershell doesn't support adding routes without an interface but allows to add interface by name
|
|
||||||
if intf != "" {
|
|
||||||
return addRoutePowershell(prefix, nexthop, intf)
|
|
||||||
}
|
|
||||||
return addRouteCmd(prefix, nexthop, intf)
|
|
||||||
}
|
|
||||||
|
|
||||||
func removeFromRouteTable(prefix netip.Prefix, nexthop netip.Addr, _ string) error {
|
|
||||||
args := []string{"delete", prefix.String()}
|
|
||||||
if nexthop.IsValid() {
|
|
||||||
args = append(args, nexthop.Unmap().String())
|
|
||||||
}
|
|
||||||
|
|
||||||
out, err := exec.Command("route", args...).CombinedOutput()
|
|
||||||
log.Tracef("route %s output: %s", strings.Join(args, " "), out)
|
|
||||||
|
|
||||||
if err != nil {
|
|
||||||
return fmt.Errorf("remove route: %w", err)
|
|
||||||
}
|
|
||||||
return nil
|
|
||||||
}
|
|
||||||
|
@ -1,289 +0,0 @@
|
|||||||
package routemanager
|
|
||||||
|
|
||||||
import (
|
|
||||||
"context"
|
|
||||||
"encoding/json"
|
|
||||||
"fmt"
|
|
||||||
"net"
|
|
||||||
"os/exec"
|
|
||||||
"strings"
|
|
||||||
"testing"
|
|
||||||
"time"
|
|
||||||
|
|
||||||
"github.com/stretchr/testify/assert"
|
|
||||||
"github.com/stretchr/testify/require"
|
|
||||||
|
|
||||||
nbnet "github.com/netbirdio/netbird/util/net"
|
|
||||||
)
|
|
||||||
|
|
||||||
var expectedExtInt = "Ethernet1"
|
|
||||||
|
|
||||||
type RouteInfo struct {
|
|
||||||
NextHop string `json:"nexthop"`
|
|
||||||
InterfaceAlias string `json:"interfacealias"`
|
|
||||||
RouteMetric int `json:"routemetric"`
|
|
||||||
}
|
|
||||||
|
|
||||||
type FindNetRouteOutput struct {
|
|
||||||
IPAddress string `json:"IPAddress"`
|
|
||||||
InterfaceIndex int `json:"InterfaceIndex"`
|
|
||||||
InterfaceAlias string `json:"InterfaceAlias"`
|
|
||||||
AddressFamily int `json:"AddressFamily"`
|
|
||||||
NextHop string `json:"NextHop"`
|
|
||||||
DestinationPrefix string `json:"DestinationPrefix"`
|
|
||||||
}
|
|
||||||
|
|
||||||
type testCase struct {
|
|
||||||
name string
|
|
||||||
destination string
|
|
||||||
expectedSourceIP string
|
|
||||||
expectedDestPrefix string
|
|
||||||
expectedNextHop string
|
|
||||||
expectedInterface string
|
|
||||||
dialer dialer
|
|
||||||
}
|
|
||||||
|
|
||||||
var expectedVPNint = "wgtest0"
|
|
||||||
|
|
||||||
var testCases = []testCase{
|
|
||||||
{
|
|
||||||
name: "To external host without custom dialer via vpn",
|
|
||||||
destination: "192.0.2.1:53",
|
|
||||||
expectedSourceIP: "100.64.0.1",
|
|
||||||
expectedDestPrefix: "128.0.0.0/1",
|
|
||||||
expectedNextHop: "0.0.0.0",
|
|
||||||
expectedInterface: "wgtest0",
|
|
||||||
dialer: &net.Dialer{},
|
|
||||||
},
|
|
||||||
{
|
|
||||||
name: "To external host with custom dialer via physical interface",
|
|
||||||
destination: "192.0.2.1:53",
|
|
||||||
expectedDestPrefix: "192.0.2.1/32",
|
|
||||||
expectedInterface: expectedExtInt,
|
|
||||||
dialer: nbnet.NewDialer(),
|
|
||||||
},
|
|
||||||
|
|
||||||
{
|
|
||||||
name: "To duplicate internal route with custom dialer via physical interface",
|
|
||||||
destination: "10.0.0.2:53",
|
|
||||||
expectedDestPrefix: "10.0.0.2/32",
|
|
||||||
expectedInterface: expectedExtInt,
|
|
||||||
dialer: nbnet.NewDialer(),
|
|
||||||
},
|
|
||||||
{
|
|
||||||
name: "To duplicate internal route without custom dialer via physical interface", // local route takes precedence
|
|
||||||
destination: "10.0.0.2:53",
|
|
||||||
expectedSourceIP: "10.0.0.1",
|
|
||||||
expectedDestPrefix: "10.0.0.0/8",
|
|
||||||
expectedNextHop: "0.0.0.0",
|
|
||||||
expectedInterface: "Loopback Pseudo-Interface 1",
|
|
||||||
dialer: &net.Dialer{},
|
|
||||||
},
|
|
||||||
|
|
||||||
{
|
|
||||||
name: "To unique vpn route with custom dialer via physical interface",
|
|
||||||
destination: "172.16.0.2:53",
|
|
||||||
expectedDestPrefix: "172.16.0.2/32",
|
|
||||||
expectedInterface: expectedExtInt,
|
|
||||||
dialer: nbnet.NewDialer(),
|
|
||||||
},
|
|
||||||
{
|
|
||||||
name: "To unique vpn route without custom dialer via vpn",
|
|
||||||
destination: "172.16.0.2:53",
|
|
||||||
expectedSourceIP: "100.64.0.1",
|
|
||||||
expectedDestPrefix: "172.16.0.0/12",
|
|
||||||
expectedNextHop: "0.0.0.0",
|
|
||||||
expectedInterface: "wgtest0",
|
|
||||||
dialer: &net.Dialer{},
|
|
||||||
},
|
|
||||||
|
|
||||||
{
|
|
||||||
name: "To more specific route without custom dialer via vpn interface",
|
|
||||||
destination: "10.10.0.2:53",
|
|
||||||
expectedSourceIP: "100.64.0.1",
|
|
||||||
expectedDestPrefix: "10.10.0.0/24",
|
|
||||||
expectedNextHop: "0.0.0.0",
|
|
||||||
expectedInterface: "wgtest0",
|
|
||||||
dialer: &net.Dialer{},
|
|
||||||
},
|
|
||||||
|
|
||||||
{
|
|
||||||
name: "To more specific route (local) without custom dialer via physical interface",
|
|
||||||
destination: "127.0.10.2:53",
|
|
||||||
expectedSourceIP: "10.0.0.1",
|
|
||||||
expectedDestPrefix: "127.0.0.0/8",
|
|
||||||
expectedNextHop: "0.0.0.0",
|
|
||||||
expectedInterface: "Loopback Pseudo-Interface 1",
|
|
||||||
dialer: &net.Dialer{},
|
|
||||||
},
|
|
||||||
}
|
|
||||||
|
|
||||||
func TestRouting(t *testing.T) {
|
|
||||||
for _, tc := range testCases {
|
|
||||||
t.Run(tc.name, func(t *testing.T) {
|
|
||||||
setupTestEnv(t)
|
|
||||||
|
|
||||||
route, err := fetchOriginalGateway()
|
|
||||||
require.NoError(t, err, "Failed to fetch original gateway")
|
|
||||||
ip, err := fetchInterfaceIP(route.InterfaceAlias)
|
|
||||||
require.NoError(t, err, "Failed to fetch interface IP")
|
|
||||||
|
|
||||||
output := testRoute(t, tc.destination, tc.dialer)
|
|
||||||
if tc.expectedInterface == expectedExtInt {
|
|
||||||
verifyOutput(t, output, ip, tc.expectedDestPrefix, route.NextHop, route.InterfaceAlias)
|
|
||||||
} else {
|
|
||||||
verifyOutput(t, output, tc.expectedSourceIP, tc.expectedDestPrefix, tc.expectedNextHop, tc.expectedInterface)
|
|
||||||
}
|
|
||||||
})
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
// fetchInterfaceIP fetches the IPv4 address of the specified interface.
|
|
||||||
func fetchInterfaceIP(interfaceAlias string) (string, error) {
|
|
||||||
script := fmt.Sprintf(`Get-NetIPAddress -InterfaceAlias "%s" | Where-Object AddressFamily -eq 2 | Select-Object -ExpandProperty IPAddress`, interfaceAlias)
|
|
||||||
out, err := exec.Command("powershell", "-Command", script).Output()
|
|
||||||
if err != nil {
|
|
||||||
return "", fmt.Errorf("failed to execute Get-NetIPAddress: %w", err)
|
|
||||||
}
|
|
||||||
|
|
||||||
ip := strings.TrimSpace(string(out))
|
|
||||||
return ip, nil
|
|
||||||
}
|
|
||||||
|
|
||||||
func testRoute(t *testing.T, destination string, dialer dialer) *FindNetRouteOutput {
|
|
||||||
t.Helper()
|
|
||||||
|
|
||||||
ctx, cancel := context.WithTimeout(context.Background(), 1*time.Second)
|
|
||||||
defer cancel()
|
|
||||||
|
|
||||||
conn, err := dialer.DialContext(ctx, "udp", destination)
|
|
||||||
require.NoError(t, err, "Failed to dial destination")
|
|
||||||
defer func() {
|
|
||||||
err := conn.Close()
|
|
||||||
assert.NoError(t, err, "Failed to close connection")
|
|
||||||
}()
|
|
||||||
|
|
||||||
host, _, err := net.SplitHostPort(destination)
|
|
||||||
require.NoError(t, err)
|
|
||||||
|
|
||||||
script := fmt.Sprintf(`Find-NetRoute -RemoteIPAddress "%s" | Select-Object -Property IPAddress, InterfaceIndex, InterfaceAlias, AddressFamily, NextHop, DestinationPrefix | ConvertTo-Json`, host)
|
|
||||||
|
|
||||||
out, err := exec.Command("powershell", "-Command", script).Output()
|
|
||||||
require.NoError(t, err, "Failed to execute Find-NetRoute")
|
|
||||||
|
|
||||||
var outputs []FindNetRouteOutput
|
|
||||||
err = json.Unmarshal(out, &outputs)
|
|
||||||
require.NoError(t, err, "Failed to parse JSON outputs from Find-NetRoute")
|
|
||||||
|
|
||||||
require.Greater(t, len(outputs), 0, "No route found for destination")
|
|
||||||
combinedOutput := combineOutputs(outputs)
|
|
||||||
|
|
||||||
return combinedOutput
|
|
||||||
}
|
|
||||||
|
|
||||||
func createAndSetupDummyInterface(t *testing.T, interfaceName, ipAddressCIDR string) string {
|
|
||||||
t.Helper()
|
|
||||||
|
|
||||||
ip, ipNet, err := net.ParseCIDR(ipAddressCIDR)
|
|
||||||
require.NoError(t, err)
|
|
||||||
subnetMaskSize, _ := ipNet.Mask.Size()
|
|
||||||
script := fmt.Sprintf(`New-NetIPAddress -InterfaceAlias "%s" -IPAddress "%s" -PrefixLength %d -PolicyStore ActiveStore -Confirm:$False`, interfaceName, ip.String(), subnetMaskSize)
|
|
||||||
_, err = exec.Command("powershell", "-Command", script).CombinedOutput()
|
|
||||||
require.NoError(t, err, "Failed to assign IP address to loopback adapter")
|
|
||||||
|
|
||||||
// Wait for the IP address to be applied
|
|
||||||
ctx, cancel := context.WithTimeout(context.Background(), 3*time.Second)
|
|
||||||
defer cancel()
|
|
||||||
err = waitForIPAddress(ctx, interfaceName, ip.String())
|
|
||||||
require.NoError(t, err, "IP address not applied within timeout")
|
|
||||||
|
|
||||||
t.Cleanup(func() {
|
|
||||||
script = fmt.Sprintf(`Remove-NetIPAddress -InterfaceAlias "%s" -IPAddress "%s" -Confirm:$False`, interfaceName, ip.String())
|
|
||||||
_, err = exec.Command("powershell", "-Command", script).CombinedOutput()
|
|
||||||
require.NoError(t, err, "Failed to remove IP address from loopback adapter")
|
|
||||||
})
|
|
||||||
|
|
||||||
return interfaceName
|
|
||||||
}
|
|
||||||
|
|
||||||
func fetchOriginalGateway() (*RouteInfo, error) {
|
|
||||||
cmd := exec.Command("powershell", "-Command", "Get-NetRoute -DestinationPrefix 0.0.0.0/0 | Select-Object NextHop, RouteMetric, InterfaceAlias | ConvertTo-Json")
|
|
||||||
output, err := cmd.CombinedOutput()
|
|
||||||
if err != nil {
|
|
||||||
return nil, fmt.Errorf("failed to execute Get-NetRoute: %w", err)
|
|
||||||
}
|
|
||||||
|
|
||||||
var routeInfo RouteInfo
|
|
||||||
err = json.Unmarshal(output, &routeInfo)
|
|
||||||
if err != nil {
|
|
||||||
return nil, fmt.Errorf("failed to parse JSON output: %w", err)
|
|
||||||
}
|
|
||||||
|
|
||||||
return &routeInfo, nil
|
|
||||||
}
|
|
||||||
|
|
||||||
func verifyOutput(t *testing.T, output *FindNetRouteOutput, sourceIP, destPrefix, nextHop, intf string) {
|
|
||||||
t.Helper()
|
|
||||||
|
|
||||||
assert.Equal(t, sourceIP, output.IPAddress, "Source IP mismatch")
|
|
||||||
assert.Equal(t, destPrefix, output.DestinationPrefix, "Destination prefix mismatch")
|
|
||||||
assert.Equal(t, nextHop, output.NextHop, "Next hop mismatch")
|
|
||||||
assert.Equal(t, intf, output.InterfaceAlias, "Interface mismatch")
|
|
||||||
}
|
|
||||||
|
|
||||||
func waitForIPAddress(ctx context.Context, interfaceAlias, expectedIPAddress string) error {
|
|
||||||
ticker := time.NewTicker(1 * time.Second)
|
|
||||||
defer ticker.Stop()
|
|
||||||
|
|
||||||
for {
|
|
||||||
select {
|
|
||||||
case <-ctx.Done():
|
|
||||||
return ctx.Err()
|
|
||||||
case <-ticker.C:
|
|
||||||
out, err := exec.Command("powershell", "-Command", fmt.Sprintf(`Get-NetIPAddress -InterfaceAlias "%s" | Select-Object -ExpandProperty IPAddress`, interfaceAlias)).CombinedOutput()
|
|
||||||
if err != nil {
|
|
||||||
return err
|
|
||||||
}
|
|
||||||
|
|
||||||
ipAddresses := strings.Split(strings.TrimSpace(string(out)), "\n")
|
|
||||||
for _, ip := range ipAddresses {
|
|
||||||
if strings.TrimSpace(ip) == expectedIPAddress {
|
|
||||||
return nil
|
|
||||||
}
|
|
||||||
}
|
|
||||||
}
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
func combineOutputs(outputs []FindNetRouteOutput) *FindNetRouteOutput {
|
|
||||||
var combined FindNetRouteOutput
|
|
||||||
|
|
||||||
for _, output := range outputs {
|
|
||||||
if output.IPAddress != "" {
|
|
||||||
combined.IPAddress = output.IPAddress
|
|
||||||
}
|
|
||||||
if output.InterfaceIndex != 0 {
|
|
||||||
combined.InterfaceIndex = output.InterfaceIndex
|
|
||||||
}
|
|
||||||
if output.InterfaceAlias != "" {
|
|
||||||
combined.InterfaceAlias = output.InterfaceAlias
|
|
||||||
}
|
|
||||||
if output.AddressFamily != 0 {
|
|
||||||
combined.AddressFamily = output.AddressFamily
|
|
||||||
}
|
|
||||||
if output.NextHop != "" {
|
|
||||||
combined.NextHop = output.NextHop
|
|
||||||
}
|
|
||||||
if output.DestinationPrefix != "" {
|
|
||||||
combined.DestinationPrefix = output.DestinationPrefix
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
return &combined
|
|
||||||
}
|
|
||||||
|
|
||||||
func setupDummyInterfacesAndRoutes(t *testing.T) {
|
|
||||||
t.Helper()
|
|
||||||
|
|
||||||
createAndSetupDummyInterface(t, "Loopback Pseudo-Interface 1", "10.0.0.1/8")
|
|
||||||
}
|
|
@ -1,24 +0,0 @@
|
|||||||
package stdnet
|
|
||||||
|
|
||||||
import (
|
|
||||||
"net"
|
|
||||||
|
|
||||||
"github.com/pion/transport/v3"
|
|
||||||
|
|
||||||
nbnet "github.com/netbirdio/netbird/util/net"
|
|
||||||
)
|
|
||||||
|
|
||||||
// Dial connects to the address on the named network.
|
|
||||||
func (n *Net) Dial(network, address string) (net.Conn, error) {
|
|
||||||
return nbnet.NewDialer().Dial(network, address)
|
|
||||||
}
|
|
||||||
|
|
||||||
// DialUDP connects to the address on the named UDP network.
|
|
||||||
func (n *Net) DialUDP(network string, laddr, raddr *net.UDPAddr) (transport.UDPConn, error) {
|
|
||||||
return nbnet.DialUDP(network, laddr, raddr)
|
|
||||||
}
|
|
||||||
|
|
||||||
// DialTCP connects to the address on the named TCP network.
|
|
||||||
func (n *Net) DialTCP(network string, laddr, raddr *net.TCPAddr) (transport.TCPConn, error) {
|
|
||||||
return nbnet.DialTCP(network, laddr, raddr)
|
|
||||||
}
|
|
@ -1,20 +0,0 @@
|
|||||||
package stdnet
|
|
||||||
|
|
||||||
import (
|
|
||||||
"context"
|
|
||||||
"net"
|
|
||||||
|
|
||||||
"github.com/pion/transport/v3"
|
|
||||||
|
|
||||||
nbnet "github.com/netbirdio/netbird/util/net"
|
|
||||||
)
|
|
||||||
|
|
||||||
// ListenPacket listens for incoming packets on the given network and address.
|
|
||||||
func (n *Net) ListenPacket(network, address string) (net.PacketConn, error) {
|
|
||||||
return nbnet.NewListener().ListenPacket(context.Background(), network, address)
|
|
||||||
}
|
|
||||||
|
|
||||||
// ListenUDP acts like ListenPacket for UDP networks.
|
|
||||||
func (n *Net) ListenUDP(network string, locAddr *net.UDPAddr) (transport.UDPConn, error) {
|
|
||||||
return nbnet.ListenUDP(network, locAddr)
|
|
||||||
}
|
|
@ -17,7 +17,6 @@ import (
|
|||||||
|
|
||||||
"github.com/netbirdio/netbird/client/internal/ebpf"
|
"github.com/netbirdio/netbird/client/internal/ebpf"
|
||||||
ebpfMgr "github.com/netbirdio/netbird/client/internal/ebpf/manager"
|
ebpfMgr "github.com/netbirdio/netbird/client/internal/ebpf/manager"
|
||||||
nbnet "github.com/netbirdio/netbird/util/net"
|
|
||||||
)
|
)
|
||||||
|
|
||||||
// WGEBPFProxy definition for proxy with EBPF support
|
// WGEBPFProxy definition for proxy with EBPF support
|
||||||
@ -68,7 +67,7 @@ func (p *WGEBPFProxy) Listen() error {
|
|||||||
IP: net.ParseIP("127.0.0.1"),
|
IP: net.ParseIP("127.0.0.1"),
|
||||||
}
|
}
|
||||||
|
|
||||||
conn, err := nbnet.ListenUDP("udp", &addr)
|
conn, err := net.ListenUDP("udp", &addr)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
cErr := p.Free()
|
cErr := p.Free()
|
||||||
if cErr != nil {
|
if cErr != nil {
|
||||||
@ -229,12 +228,6 @@ func (p *WGEBPFProxy) prepareSenderRawSocket() (net.PacketConn, error) {
|
|||||||
return nil, fmt.Errorf("binding to lo interface failed: %w", err)
|
return nil, fmt.Errorf("binding to lo interface failed: %w", err)
|
||||||
}
|
}
|
||||||
|
|
||||||
// Set the fwmark on the socket.
|
|
||||||
err = syscall.SetsockoptInt(fd, syscall.SOL_SOCKET, syscall.SO_MARK, nbnet.NetbirdFwmark)
|
|
||||||
if err != nil {
|
|
||||||
return nil, fmt.Errorf("setting fwmark failed: %w", err)
|
|
||||||
}
|
|
||||||
|
|
||||||
// Convert the file descriptor to a PacketConn.
|
// Convert the file descriptor to a PacketConn.
|
||||||
file := os.NewFile(uintptr(fd), fmt.Sprintf("fd %d", fd))
|
file := os.NewFile(uintptr(fd), fmt.Sprintf("fd %d", fd))
|
||||||
if file == nil {
|
if file == nil {
|
||||||
|
@ -6,8 +6,6 @@ import (
|
|||||||
"net"
|
"net"
|
||||||
|
|
||||||
log "github.com/sirupsen/logrus"
|
log "github.com/sirupsen/logrus"
|
||||||
|
|
||||||
nbnet "github.com/netbirdio/netbird/util/net"
|
|
||||||
)
|
)
|
||||||
|
|
||||||
// WGUserSpaceProxy proxies
|
// WGUserSpaceProxy proxies
|
||||||
@ -35,7 +33,7 @@ func (p *WGUserSpaceProxy) AddTurnConn(remoteConn net.Conn) (net.Addr, error) {
|
|||||||
p.remoteConn = remoteConn
|
p.remoteConn = remoteConn
|
||||||
|
|
||||||
var err error
|
var err error
|
||||||
p.localConn, err = nbnet.NewDialer().Dial("udp", fmt.Sprintf(":%d", p.localWGListenPort))
|
p.localConn, err = net.Dial("udp", fmt.Sprintf(":%d", p.localWGListenPort))
|
||||||
if err != nil {
|
if err != nil {
|
||||||
log.Errorf("failed dialing to local Wireguard port %s", err)
|
log.Errorf("failed dialing to local Wireguard port %s", err)
|
||||||
return nil, err
|
return nil, err
|
||||||
|
2
go.mod
2
go.mod
@ -48,7 +48,6 @@ require (
|
|||||||
github.com/google/gopacket v1.1.19
|
github.com/google/gopacket v1.1.19
|
||||||
github.com/google/martian/v3 v3.0.0
|
github.com/google/martian/v3 v3.0.0
|
||||||
github.com/google/nftables v0.0.0-20220808154552-2eca00135732
|
github.com/google/nftables v0.0.0-20220808154552-2eca00135732
|
||||||
github.com/gopacket/gopacket v1.1.1
|
|
||||||
github.com/grpc-ecosystem/go-grpc-middleware/v2 v2.0.2-0.20240212192251-757544f21357
|
github.com/grpc-ecosystem/go-grpc-middleware/v2 v2.0.2-0.20240212192251-757544f21357
|
||||||
github.com/hashicorp/go-multierror v1.1.1
|
github.com/hashicorp/go-multierror v1.1.1
|
||||||
github.com/hashicorp/go-secure-stdlib/base62 v0.1.2
|
github.com/hashicorp/go-secure-stdlib/base62 v0.1.2
|
||||||
@ -124,6 +123,7 @@ require (
|
|||||||
github.com/google/s2a-go v0.1.4 // indirect
|
github.com/google/s2a-go v0.1.4 // indirect
|
||||||
github.com/googleapis/enterprise-certificate-proxy v0.2.3 // indirect
|
github.com/googleapis/enterprise-certificate-proxy v0.2.3 // indirect
|
||||||
github.com/googleapis/gax-go/v2 v2.10.0 // indirect
|
github.com/googleapis/gax-go/v2 v2.10.0 // indirect
|
||||||
|
github.com/gopacket/gopacket v1.1.1 // indirect
|
||||||
github.com/hashicorp/errwrap v1.0.0 // indirect
|
github.com/hashicorp/errwrap v1.0.0 // indirect
|
||||||
github.com/hashicorp/go-uuid v1.0.2 // indirect
|
github.com/hashicorp/go-uuid v1.0.2 // indirect
|
||||||
github.com/inconshreveable/mousetrap v1.1.0 // indirect
|
github.com/inconshreveable/mousetrap v1.1.0 // indirect
|
||||||
|
@ -10,8 +10,6 @@ import (
|
|||||||
log "github.com/sirupsen/logrus"
|
log "github.com/sirupsen/logrus"
|
||||||
"golang.zx2c4.com/wireguard/wgctrl"
|
"golang.zx2c4.com/wireguard/wgctrl"
|
||||||
"golang.zx2c4.com/wireguard/wgctrl/wgtypes"
|
"golang.zx2c4.com/wireguard/wgctrl/wgtypes"
|
||||||
|
|
||||||
nbnet "github.com/netbirdio/netbird/util/net"
|
|
||||||
)
|
)
|
||||||
|
|
||||||
type wgKernelConfigurer struct {
|
type wgKernelConfigurer struct {
|
||||||
@ -31,7 +29,7 @@ func (c *wgKernelConfigurer) configureInterface(privateKey string, port int) err
|
|||||||
if err != nil {
|
if err != nil {
|
||||||
return err
|
return err
|
||||||
}
|
}
|
||||||
fwmark := nbnet.NetbirdFwmark
|
fwmark := 0
|
||||||
config := wgtypes.Config{
|
config := wgtypes.Config{
|
||||||
PrivateKey: &key,
|
PrivateKey: &key,
|
||||||
ReplacePeers: true,
|
ReplacePeers: true,
|
||||||
|
@ -13,8 +13,6 @@ import (
|
|||||||
log "github.com/sirupsen/logrus"
|
log "github.com/sirupsen/logrus"
|
||||||
"golang.zx2c4.com/wireguard/device"
|
"golang.zx2c4.com/wireguard/device"
|
||||||
"golang.zx2c4.com/wireguard/wgctrl/wgtypes"
|
"golang.zx2c4.com/wireguard/wgctrl/wgtypes"
|
||||||
|
|
||||||
nbnet "github.com/netbirdio/netbird/util/net"
|
|
||||||
)
|
)
|
||||||
|
|
||||||
type wgUSPConfigurer struct {
|
type wgUSPConfigurer struct {
|
||||||
@ -39,7 +37,7 @@ func (c *wgUSPConfigurer) configureInterface(privateKey string, port int) error
|
|||||||
if err != nil {
|
if err != nil {
|
||||||
return err
|
return err
|
||||||
}
|
}
|
||||||
fwmark := getFwmark()
|
fwmark := 0
|
||||||
config := wgtypes.Config{
|
config := wgtypes.Config{
|
||||||
PrivateKey: &key,
|
PrivateKey: &key,
|
||||||
ReplacePeers: true,
|
ReplacePeers: true,
|
||||||
@ -347,10 +345,3 @@ func toWgUserspaceString(wgCfg wgtypes.Config) string {
|
|||||||
}
|
}
|
||||||
return sb.String()
|
return sb.String()
|
||||||
}
|
}
|
||||||
|
|
||||||
func getFwmark() int {
|
|
||||||
if runtime.GOOS == "linux" {
|
|
||||||
return nbnet.NetbirdFwmark
|
|
||||||
}
|
|
||||||
return 0
|
|
||||||
}
|
|
||||||
|
@ -24,7 +24,6 @@ import (
|
|||||||
"github.com/netbirdio/netbird/client/system"
|
"github.com/netbirdio/netbird/client/system"
|
||||||
"github.com/netbirdio/netbird/encryption"
|
"github.com/netbirdio/netbird/encryption"
|
||||||
"github.com/netbirdio/netbird/management/proto"
|
"github.com/netbirdio/netbird/management/proto"
|
||||||
nbgrpc "github.com/netbirdio/netbird/util/grpc"
|
|
||||||
)
|
)
|
||||||
|
|
||||||
const ConnectTimeout = 10 * time.Second
|
const ConnectTimeout = 10 * time.Second
|
||||||
@ -58,7 +57,6 @@ func NewClient(ctx context.Context, addr string, ourPrivateKey wgtypes.Key, tlsE
|
|||||||
mgmCtx,
|
mgmCtx,
|
||||||
addr,
|
addr,
|
||||||
transportOption,
|
transportOption,
|
||||||
nbgrpc.WithCustomDialer(),
|
|
||||||
grpc.WithBlock(),
|
grpc.WithBlock(),
|
||||||
grpc.WithKeepaliveParams(keepalive.ClientParameters{
|
grpc.WithKeepaliveParams(keepalive.ClientParameters{
|
||||||
Time: 30 * time.Second,
|
Time: 30 * time.Second,
|
||||||
|
@ -21,8 +21,6 @@ import (
|
|||||||
log "github.com/sirupsen/logrus"
|
log "github.com/sirupsen/logrus"
|
||||||
"golang.org/x/sync/errgroup"
|
"golang.org/x/sync/errgroup"
|
||||||
"golang.org/x/sys/unix"
|
"golang.org/x/sys/unix"
|
||||||
|
|
||||||
nbnet "github.com/netbirdio/netbird/util/net"
|
|
||||||
)
|
)
|
||||||
|
|
||||||
// ErrSharedSockStopped indicates that shared socket has been stopped
|
// ErrSharedSockStopped indicates that shared socket has been stopped
|
||||||
@ -84,18 +82,10 @@ func Listen(port int, filter BPFFilter) (_ net.PacketConn, err error) {
|
|||||||
return nil, fmt.Errorf("failed to create ipv4 raw socket: %w", err)
|
return nil, fmt.Errorf("failed to create ipv4 raw socket: %w", err)
|
||||||
}
|
}
|
||||||
|
|
||||||
if err = nbnet.SetSocketMark(rawSock.conn4); err != nil {
|
|
||||||
return nil, fmt.Errorf("failed to set SO_MARK on ipv4 socket: %w", err)
|
|
||||||
}
|
|
||||||
|
|
||||||
var sockErr error
|
var sockErr error
|
||||||
rawSock.conn6, sockErr = socket.Socket(unix.AF_INET6, unix.SOCK_RAW, unix.IPPROTO_UDP, "raw_udp6", nil)
|
rawSock.conn6, sockErr = socket.Socket(unix.AF_INET6, unix.SOCK_RAW, unix.IPPROTO_UDP, "raw_udp6", nil)
|
||||||
if sockErr != nil {
|
if sockErr != nil {
|
||||||
log.Errorf("Failed to create ipv6 raw socket: %v", err)
|
log.Errorf("Failed to create ipv6 raw socket: %v", err)
|
||||||
} else {
|
|
||||||
if err = nbnet.SetSocketMark(rawSock.conn6); err != nil {
|
|
||||||
return nil, fmt.Errorf("failed to set SO_MARK on ipv6 socket: %w", err)
|
|
||||||
}
|
|
||||||
}
|
}
|
||||||
|
|
||||||
ipv4Instructions, ipv6Instructions, err := filter.GetInstructions(uint32(rawSock.port))
|
ipv4Instructions, ipv6Instructions, err := filter.GetInstructions(uint32(rawSock.port))
|
||||||
|
@ -23,7 +23,6 @@ import (
|
|||||||
"github.com/netbirdio/netbird/encryption"
|
"github.com/netbirdio/netbird/encryption"
|
||||||
"github.com/netbirdio/netbird/management/client"
|
"github.com/netbirdio/netbird/management/client"
|
||||||
"github.com/netbirdio/netbird/signal/proto"
|
"github.com/netbirdio/netbird/signal/proto"
|
||||||
nbgrpc "github.com/netbirdio/netbird/util/grpc"
|
|
||||||
)
|
)
|
||||||
|
|
||||||
// ConnStateNotifier is a wrapper interface of the status recorder
|
// ConnStateNotifier is a wrapper interface of the status recorder
|
||||||
@ -77,7 +76,6 @@ func NewClient(ctx context.Context, addr string, key wgtypes.Key, tlsEnabled boo
|
|||||||
sigCtx,
|
sigCtx,
|
||||||
addr,
|
addr,
|
||||||
transportOption,
|
transportOption,
|
||||||
nbgrpc.WithCustomDialer(),
|
|
||||||
grpc.WithBlock(),
|
grpc.WithBlock(),
|
||||||
grpc.WithKeepaliveParams(keepalive.ClientParameters{
|
grpc.WithKeepaliveParams(keepalive.ClientParameters{
|
||||||
Time: 30 * time.Second,
|
Time: 30 * time.Second,
|
||||||
|
@ -1,22 +0,0 @@
|
|||||||
package grpc
|
|
||||||
|
|
||||||
import (
|
|
||||||
"context"
|
|
||||||
"net"
|
|
||||||
|
|
||||||
log "github.com/sirupsen/logrus"
|
|
||||||
"google.golang.org/grpc"
|
|
||||||
|
|
||||||
nbnet "github.com/netbirdio/netbird/util/net"
|
|
||||||
)
|
|
||||||
|
|
||||||
func WithCustomDialer() grpc.DialOption {
|
|
||||||
return grpc.WithContextDialer(func(ctx context.Context, addr string) (net.Conn, error) {
|
|
||||||
conn, err := nbnet.NewDialer().DialContext(ctx, "tcp", addr)
|
|
||||||
if err != nil {
|
|
||||||
log.Errorf("Failed to dial: %s", err)
|
|
||||||
return nil, err
|
|
||||||
}
|
|
||||||
return conn, nil
|
|
||||||
})
|
|
||||||
}
|
|
@ -1,21 +0,0 @@
|
|||||||
package net
|
|
||||||
|
|
||||||
import (
|
|
||||||
"net"
|
|
||||||
)
|
|
||||||
|
|
||||||
// Dialer extends the standard net.Dialer with the ability to execute hooks before
|
|
||||||
// and after connections. This can be used to bypass the VPN for connections using this dialer.
|
|
||||||
type Dialer struct {
|
|
||||||
*net.Dialer
|
|
||||||
}
|
|
||||||
|
|
||||||
// NewDialer returns a customized net.Dialer with overridden Control method
|
|
||||||
func NewDialer() *Dialer {
|
|
||||||
dialer := &Dialer{
|
|
||||||
Dialer: &net.Dialer{},
|
|
||||||
}
|
|
||||||
dialer.init()
|
|
||||||
|
|
||||||
return dialer
|
|
||||||
}
|
|
@ -1,163 +0,0 @@
|
|||||||
//go:build !android && !ios
|
|
||||||
|
|
||||||
package net
|
|
||||||
|
|
||||||
import (
|
|
||||||
"context"
|
|
||||||
"fmt"
|
|
||||||
"net"
|
|
||||||
"sync"
|
|
||||||
|
|
||||||
"github.com/hashicorp/go-multierror"
|
|
||||||
log "github.com/sirupsen/logrus"
|
|
||||||
)
|
|
||||||
|
|
||||||
type DialerDialHookFunc func(ctx context.Context, connID ConnectionID, resolvedAddresses []net.IPAddr) error
|
|
||||||
type DialerCloseHookFunc func(connID ConnectionID, conn *net.Conn) error
|
|
||||||
|
|
||||||
var (
|
|
||||||
dialerDialHooksMutex sync.RWMutex
|
|
||||||
dialerDialHooks []DialerDialHookFunc
|
|
||||||
dialerCloseHooksMutex sync.RWMutex
|
|
||||||
dialerCloseHooks []DialerCloseHookFunc
|
|
||||||
)
|
|
||||||
|
|
||||||
// AddDialerHook allows adding a new hook to be executed before dialing.
|
|
||||||
func AddDialerHook(hook DialerDialHookFunc) {
|
|
||||||
dialerDialHooksMutex.Lock()
|
|
||||||
defer dialerDialHooksMutex.Unlock()
|
|
||||||
dialerDialHooks = append(dialerDialHooks, hook)
|
|
||||||
}
|
|
||||||
|
|
||||||
// AddDialerCloseHook allows adding a new hook to be executed on connection close.
|
|
||||||
func AddDialerCloseHook(hook DialerCloseHookFunc) {
|
|
||||||
dialerCloseHooksMutex.Lock()
|
|
||||||
defer dialerCloseHooksMutex.Unlock()
|
|
||||||
dialerCloseHooks = append(dialerCloseHooks, hook)
|
|
||||||
}
|
|
||||||
|
|
||||||
// RemoveDialerHook removes all dialer hooks.
|
|
||||||
func RemoveDialerHooks() {
|
|
||||||
dialerDialHooksMutex.Lock()
|
|
||||||
defer dialerDialHooksMutex.Unlock()
|
|
||||||
dialerDialHooks = nil
|
|
||||||
|
|
||||||
dialerCloseHooksMutex.Lock()
|
|
||||||
defer dialerCloseHooksMutex.Unlock()
|
|
||||||
dialerCloseHooks = nil
|
|
||||||
}
|
|
||||||
|
|
||||||
// DialContext wraps the net.Dialer's DialContext method to use the custom connection
|
|
||||||
func (d *Dialer) DialContext(ctx context.Context, network, address string) (net.Conn, error) {
|
|
||||||
var resolver *net.Resolver
|
|
||||||
if d.Resolver != nil {
|
|
||||||
resolver = d.Resolver
|
|
||||||
}
|
|
||||||
|
|
||||||
connID := GenerateConnID()
|
|
||||||
if dialerDialHooks != nil {
|
|
||||||
if err := calliDialerHooks(ctx, connID, address, resolver); err != nil {
|
|
||||||
log.Errorf("Failed to call dialer hooks: %v", err)
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
conn, err := d.Dialer.DialContext(ctx, network, address)
|
|
||||||
if err != nil {
|
|
||||||
return nil, fmt.Errorf("dial: %w", err)
|
|
||||||
}
|
|
||||||
|
|
||||||
// Wrap the connection in Conn to handle Close with hooks
|
|
||||||
return &Conn{Conn: conn, ID: connID}, nil
|
|
||||||
}
|
|
||||||
|
|
||||||
// Dial wraps the net.Dialer's Dial method to use the custom connection
|
|
||||||
func (d *Dialer) Dial(network, address string) (net.Conn, error) {
|
|
||||||
return d.DialContext(context.Background(), network, address)
|
|
||||||
}
|
|
||||||
|
|
||||||
// Conn wraps a net.Conn to override the Close method
|
|
||||||
type Conn struct {
|
|
||||||
net.Conn
|
|
||||||
ID ConnectionID
|
|
||||||
}
|
|
||||||
|
|
||||||
// Close overrides the net.Conn Close method to execute all registered hooks after closing the connection
|
|
||||||
func (c *Conn) Close() error {
|
|
||||||
err := c.Conn.Close()
|
|
||||||
|
|
||||||
dialerCloseHooksMutex.RLock()
|
|
||||||
defer dialerCloseHooksMutex.RUnlock()
|
|
||||||
|
|
||||||
for _, hook := range dialerCloseHooks {
|
|
||||||
if err := hook(c.ID, &c.Conn); err != nil {
|
|
||||||
log.Errorf("Error executing dialer close hook: %v", err)
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
return err
|
|
||||||
}
|
|
||||||
|
|
||||||
func calliDialerHooks(ctx context.Context, connID ConnectionID, address string, resolver *net.Resolver) error {
|
|
||||||
host, _, err := net.SplitHostPort(address)
|
|
||||||
if err != nil {
|
|
||||||
return fmt.Errorf("split host and port: %w", err)
|
|
||||||
}
|
|
||||||
ips, err := resolver.LookupIPAddr(ctx, host)
|
|
||||||
if err != nil {
|
|
||||||
return fmt.Errorf("failed to resolve address %s: %w", address, err)
|
|
||||||
}
|
|
||||||
|
|
||||||
log.Debugf("Dialer resolved IPs for %s: %v", address, ips)
|
|
||||||
|
|
||||||
var result *multierror.Error
|
|
||||||
|
|
||||||
dialerDialHooksMutex.RLock()
|
|
||||||
defer dialerDialHooksMutex.RUnlock()
|
|
||||||
for _, hook := range dialerDialHooks {
|
|
||||||
if err := hook(ctx, connID, ips); err != nil {
|
|
||||||
result = multierror.Append(result, fmt.Errorf("executing dial hook: %w", err))
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
return result.ErrorOrNil()
|
|
||||||
}
|
|
||||||
|
|
||||||
func DialUDP(network string, laddr, raddr *net.UDPAddr) (*net.UDPConn, error) {
|
|
||||||
dialer := NewDialer()
|
|
||||||
dialer.LocalAddr = laddr
|
|
||||||
|
|
||||||
conn, err := dialer.Dial(network, raddr.String())
|
|
||||||
if err != nil {
|
|
||||||
return nil, fmt.Errorf("dialing UDP %s: %w", raddr.String(), err)
|
|
||||||
}
|
|
||||||
|
|
||||||
udpConn, ok := conn.(*Conn).Conn.(*net.UDPConn)
|
|
||||||
if !ok {
|
|
||||||
if err := conn.Close(); err != nil {
|
|
||||||
log.Errorf("Failed to close connection: %v", err)
|
|
||||||
}
|
|
||||||
return nil, fmt.Errorf("expected UDP connection, got different type: %T", conn)
|
|
||||||
}
|
|
||||||
|
|
||||||
return udpConn, nil
|
|
||||||
}
|
|
||||||
|
|
||||||
func DialTCP(network string, laddr, raddr *net.TCPAddr) (*net.TCPConn, error) {
|
|
||||||
dialer := NewDialer()
|
|
||||||
dialer.LocalAddr = laddr
|
|
||||||
|
|
||||||
conn, err := dialer.Dial(network, raddr.String())
|
|
||||||
if err != nil {
|
|
||||||
return nil, fmt.Errorf("dialing TCP %s: %w", raddr.String(), err)
|
|
||||||
}
|
|
||||||
|
|
||||||
tcpConn, ok := conn.(*Conn).Conn.(*net.TCPConn)
|
|
||||||
if !ok {
|
|
||||||
if err := conn.Close(); err != nil {
|
|
||||||
log.Errorf("Failed to close connection: %v", err)
|
|
||||||
}
|
|
||||||
return nil, fmt.Errorf("expected TCP connection, got different type: %T", conn)
|
|
||||||
}
|
|
||||||
|
|
||||||
return tcpConn, nil
|
|
||||||
}
|
|
@ -1,12 +0,0 @@
|
|||||||
//go:build !android
|
|
||||||
|
|
||||||
package net
|
|
||||||
|
|
||||||
import "syscall"
|
|
||||||
|
|
||||||
// init configures the net.Dialer Control function to set the fwmark on the socket
|
|
||||||
func (d *Dialer) init() {
|
|
||||||
d.Dialer.Control = func(_, _ string, c syscall.RawConn) error {
|
|
||||||
return SetRawSocketMark(c)
|
|
||||||
}
|
|
||||||
}
|
|
@ -1,6 +0,0 @@
|
|||||||
//go:build !linux || android
|
|
||||||
|
|
||||||
package net
|
|
||||||
|
|
||||||
func (d *Dialer) init() {
|
|
||||||
}
|
|
@ -1,21 +0,0 @@
|
|||||||
package net
|
|
||||||
|
|
||||||
import (
|
|
||||||
"net"
|
|
||||||
)
|
|
||||||
|
|
||||||
// ListenerConfig extends the standard net.ListenConfig with the ability to execute hooks before
|
|
||||||
// responding via the socket and after closing. This can be used to bypass the VPN for listeners.
|
|
||||||
type ListenerConfig struct {
|
|
||||||
*net.ListenConfig
|
|
||||||
}
|
|
||||||
|
|
||||||
// NewListener creates a new ListenerConfig instance.
|
|
||||||
func NewListener() *ListenerConfig {
|
|
||||||
listener := &ListenerConfig{
|
|
||||||
ListenConfig: &net.ListenConfig{},
|
|
||||||
}
|
|
||||||
listener.init()
|
|
||||||
|
|
||||||
return listener
|
|
||||||
}
|
|
@ -1,163 +0,0 @@
|
|||||||
//go:build !android && !ios
|
|
||||||
|
|
||||||
package net
|
|
||||||
|
|
||||||
import (
|
|
||||||
"context"
|
|
||||||
"fmt"
|
|
||||||
"net"
|
|
||||||
"sync"
|
|
||||||
|
|
||||||
log "github.com/sirupsen/logrus"
|
|
||||||
)
|
|
||||||
|
|
||||||
// ListenerWriteHookFunc defines the function signature for write hooks for PacketConn.
|
|
||||||
type ListenerWriteHookFunc func(connID ConnectionID, ip *net.IPAddr, data []byte) error
|
|
||||||
|
|
||||||
// ListenerCloseHookFunc defines the function signature for close hooks for PacketConn.
|
|
||||||
type ListenerCloseHookFunc func(connID ConnectionID, conn net.PacketConn) error
|
|
||||||
|
|
||||||
var (
|
|
||||||
listenerWriteHooksMutex sync.RWMutex
|
|
||||||
listenerWriteHooks []ListenerWriteHookFunc
|
|
||||||
listenerCloseHooksMutex sync.RWMutex
|
|
||||||
listenerCloseHooks []ListenerCloseHookFunc
|
|
||||||
)
|
|
||||||
|
|
||||||
// AddListenerWriteHook allows adding a new write hook to be executed before a UDP packet is sent.
|
|
||||||
func AddListenerWriteHook(hook ListenerWriteHookFunc) {
|
|
||||||
listenerWriteHooksMutex.Lock()
|
|
||||||
defer listenerWriteHooksMutex.Unlock()
|
|
||||||
listenerWriteHooks = append(listenerWriteHooks, hook)
|
|
||||||
}
|
|
||||||
|
|
||||||
// AddListenerCloseHook allows adding a new hook to be executed upon closing a UDP connection.
|
|
||||||
func AddListenerCloseHook(hook ListenerCloseHookFunc) {
|
|
||||||
listenerCloseHooksMutex.Lock()
|
|
||||||
defer listenerCloseHooksMutex.Unlock()
|
|
||||||
listenerCloseHooks = append(listenerCloseHooks, hook)
|
|
||||||
}
|
|
||||||
|
|
||||||
// RemoveListenerHooks removes all dialer hooks.
|
|
||||||
func RemoveListenerHooks() {
|
|
||||||
listenerWriteHooksMutex.Lock()
|
|
||||||
defer listenerWriteHooksMutex.Unlock()
|
|
||||||
listenerWriteHooks = nil
|
|
||||||
|
|
||||||
listenerCloseHooksMutex.Lock()
|
|
||||||
defer listenerCloseHooksMutex.Unlock()
|
|
||||||
listenerCloseHooks = nil
|
|
||||||
}
|
|
||||||
|
|
||||||
// ListenPacket listens on the network address and returns a PacketConn
|
|
||||||
// which includes support for write hooks.
|
|
||||||
func (l *ListenerConfig) ListenPacket(ctx context.Context, network, address string) (net.PacketConn, error) {
|
|
||||||
pc, err := l.ListenConfig.ListenPacket(ctx, network, address)
|
|
||||||
if err != nil {
|
|
||||||
return nil, fmt.Errorf("listen packet: %w", err)
|
|
||||||
}
|
|
||||||
connID := GenerateConnID()
|
|
||||||
return &PacketConn{PacketConn: pc, ID: connID, seenAddrs: &sync.Map{}}, nil
|
|
||||||
}
|
|
||||||
|
|
||||||
// PacketConn wraps net.PacketConn to override its WriteTo and Close methods to include hook functionality.
|
|
||||||
type PacketConn struct {
|
|
||||||
net.PacketConn
|
|
||||||
ID ConnectionID
|
|
||||||
seenAddrs *sync.Map
|
|
||||||
}
|
|
||||||
|
|
||||||
// WriteTo writes a packet with payload b to addr, executing registered write hooks beforehand.
|
|
||||||
func (c *PacketConn) WriteTo(b []byte, addr net.Addr) (n int, err error) {
|
|
||||||
callWriteHooks(c.ID, c.seenAddrs, b, addr)
|
|
||||||
return c.PacketConn.WriteTo(b, addr)
|
|
||||||
}
|
|
||||||
|
|
||||||
// Close overrides the net.PacketConn Close method to execute all registered hooks before closing the connection.
|
|
||||||
func (c *PacketConn) Close() error {
|
|
||||||
c.seenAddrs = &sync.Map{}
|
|
||||||
return closeConn(c.ID, c.PacketConn)
|
|
||||||
}
|
|
||||||
|
|
||||||
// UDPConn wraps net.UDPConn to override its WriteTo and Close methods to include hook functionality.
|
|
||||||
type UDPConn struct {
|
|
||||||
*net.UDPConn
|
|
||||||
ID ConnectionID
|
|
||||||
seenAddrs *sync.Map
|
|
||||||
}
|
|
||||||
|
|
||||||
// WriteTo writes a packet with payload b to addr, executing registered write hooks beforehand.
|
|
||||||
func (c *UDPConn) WriteTo(b []byte, addr net.Addr) (n int, err error) {
|
|
||||||
callWriteHooks(c.ID, c.seenAddrs, b, addr)
|
|
||||||
return c.UDPConn.WriteTo(b, addr)
|
|
||||||
}
|
|
||||||
|
|
||||||
// Close overrides the net.UDPConn Close method to execute all registered hooks before closing the connection.
|
|
||||||
func (c *UDPConn) Close() error {
|
|
||||||
c.seenAddrs = &sync.Map{}
|
|
||||||
return closeConn(c.ID, c.UDPConn)
|
|
||||||
}
|
|
||||||
|
|
||||||
func callWriteHooks(id ConnectionID, seenAddrs *sync.Map, b []byte, addr net.Addr) {
|
|
||||||
// Lookup the address in the seenAddrs map to avoid calling the hooks for every write
|
|
||||||
if _, loaded := seenAddrs.LoadOrStore(addr.String(), true); !loaded {
|
|
||||||
ipStr, _, splitErr := net.SplitHostPort(addr.String())
|
|
||||||
if splitErr != nil {
|
|
||||||
log.Errorf("Error splitting IP address and port: %v", splitErr)
|
|
||||||
return
|
|
||||||
}
|
|
||||||
|
|
||||||
ip, err := net.ResolveIPAddr("ip", ipStr)
|
|
||||||
if err != nil {
|
|
||||||
log.Errorf("Error resolving IP address: %v", err)
|
|
||||||
return
|
|
||||||
}
|
|
||||||
log.Debugf("Listener resolved IP for %s: %s", addr, ip)
|
|
||||||
|
|
||||||
func() {
|
|
||||||
listenerWriteHooksMutex.RLock()
|
|
||||||
defer listenerWriteHooksMutex.RUnlock()
|
|
||||||
|
|
||||||
for _, hook := range listenerWriteHooks {
|
|
||||||
if err := hook(id, ip, b); err != nil {
|
|
||||||
log.Errorf("Error executing listener write hook: %v", err)
|
|
||||||
}
|
|
||||||
}
|
|
||||||
}()
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
func closeConn(id ConnectionID, conn net.PacketConn) error {
|
|
||||||
err := conn.Close()
|
|
||||||
|
|
||||||
listenerCloseHooksMutex.RLock()
|
|
||||||
defer listenerCloseHooksMutex.RUnlock()
|
|
||||||
|
|
||||||
for _, hook := range listenerCloseHooks {
|
|
||||||
if err := hook(id, conn); err != nil {
|
|
||||||
log.Errorf("Error executing listener close hook: %v", err)
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
return err
|
|
||||||
}
|
|
||||||
|
|
||||||
// ListenUDP listens on the network address and returns a transport.UDPConn
|
|
||||||
// which includes support for write and close hooks.
|
|
||||||
func ListenUDP(network string, laddr *net.UDPAddr) (*UDPConn, error) {
|
|
||||||
conn, err := NewListener().ListenPacket(context.Background(), network, laddr.String())
|
|
||||||
if err != nil {
|
|
||||||
return nil, fmt.Errorf("listen UDP: %w", err)
|
|
||||||
}
|
|
||||||
|
|
||||||
packetConn := conn.(*PacketConn)
|
|
||||||
udpConn, ok := packetConn.PacketConn.(*net.UDPConn)
|
|
||||||
if !ok {
|
|
||||||
if err := packetConn.Close(); err != nil {
|
|
||||||
log.Errorf("Failed to close connection: %v", err)
|
|
||||||
}
|
|
||||||
return nil, fmt.Errorf("expected UDPConn, got different type: %T", udpConn)
|
|
||||||
}
|
|
||||||
|
|
||||||
return &UDPConn{UDPConn: udpConn, ID: packetConn.ID, seenAddrs: &sync.Map{}}, nil
|
|
||||||
}
|
|
@ -1,14 +0,0 @@
|
|||||||
//go:build !android
|
|
||||||
|
|
||||||
package net
|
|
||||||
|
|
||||||
import (
|
|
||||||
"syscall"
|
|
||||||
)
|
|
||||||
|
|
||||||
// init configures the net.ListenerConfig Control function to set the fwmark on the socket
|
|
||||||
func (l *ListenerConfig) init() {
|
|
||||||
l.ListenConfig.Control = func(_, _ string, c syscall.RawConn) error {
|
|
||||||
return SetRawSocketMark(c)
|
|
||||||
}
|
|
||||||
}
|
|
@ -1,11 +0,0 @@
|
|||||||
//go:build android || ios
|
|
||||||
|
|
||||||
package net
|
|
||||||
|
|
||||||
import (
|
|
||||||
"net"
|
|
||||||
)
|
|
||||||
|
|
||||||
func ListenUDP(network string, laddr *net.UDPAddr) (*net.UDPConn, error) {
|
|
||||||
return net.ListenUDP(network, laddr)
|
|
||||||
}
|
|
@ -1,6 +0,0 @@
|
|||||||
//go:build !linux || android
|
|
||||||
|
|
||||||
package net
|
|
||||||
|
|
||||||
func (l *ListenerConfig) init() {
|
|
||||||
}
|
|
@ -1,17 +0,0 @@
|
|||||||
package net
|
|
||||||
|
|
||||||
import "github.com/google/uuid"
|
|
||||||
|
|
||||||
const (
|
|
||||||
// NetbirdFwmark is the fwmark value used by Netbird via wireguard
|
|
||||||
NetbirdFwmark = 0x1BD00
|
|
||||||
)
|
|
||||||
|
|
||||||
// ConnectionID provides a globally unique identifier for network connections.
|
|
||||||
// It's used to track connections throughout their lifecycle so the close hook can correlate with the dial hook.
|
|
||||||
type ConnectionID string
|
|
||||||
|
|
||||||
// GenerateConnID generates a unique identifier for each connection.
|
|
||||||
func GenerateConnID() ConnectionID {
|
|
||||||
return ConnectionID(uuid.NewString())
|
|
||||||
}
|
|
@ -1,35 +0,0 @@
|
|||||||
//go:build !android
|
|
||||||
|
|
||||||
package net
|
|
||||||
|
|
||||||
import (
|
|
||||||
"fmt"
|
|
||||||
"syscall"
|
|
||||||
)
|
|
||||||
|
|
||||||
// SetSocketMark sets the SO_MARK option on the given socket connection
|
|
||||||
func SetSocketMark(conn syscall.Conn) error {
|
|
||||||
sysconn, err := conn.SyscallConn()
|
|
||||||
if err != nil {
|
|
||||||
return fmt.Errorf("get raw conn: %w", err)
|
|
||||||
}
|
|
||||||
|
|
||||||
return SetRawSocketMark(sysconn)
|
|
||||||
}
|
|
||||||
|
|
||||||
func SetRawSocketMark(conn syscall.RawConn) error {
|
|
||||||
var setErr error
|
|
||||||
|
|
||||||
err := conn.Control(func(fd uintptr) {
|
|
||||||
setErr = syscall.SetsockoptInt(int(fd), syscall.SOL_SOCKET, syscall.SO_MARK, NetbirdFwmark)
|
|
||||||
})
|
|
||||||
if err != nil {
|
|
||||||
return fmt.Errorf("control: %w", err)
|
|
||||||
}
|
|
||||||
|
|
||||||
if setErr != nil {
|
|
||||||
return fmt.Errorf("set SO_MARK: %w", setErr)
|
|
||||||
}
|
|
||||||
|
|
||||||
return nil
|
|
||||||
}
|
|
Loading…
Reference in New Issue
Block a user