From d6ed9c037ed31a1b4a527b3e49f9a1aaafa8db74 Mon Sep 17 00:00:00 2001 From: Viktor Liu <17948409+lixmal@users.noreply.github.com> Date: Mon, 21 Jul 2025 12:13:21 +0200 Subject: [PATCH] [client] Fix bind exclusion routes (#4154) --- client/iface/bind/control.go | 15 ++++ client/iface/bind/control_android.go | 12 --- client/iface/bind/ice_bind.go | 3 +- client/iface/bind/udp_mux.go | 25 +++--- client/iface/bind/udp_mux_generic.go | 21 +++++ client/iface/bind/udp_mux_ios.go | 7 ++ client/iface/bind/udp_mux_universal.go | 20 +++-- client/internal/engine.go | 14 +--- client/internal/engine_test.go | 2 +- client/internal/peer/conn.go | 52 ------------- .../routemanager/client/client_test.go | 2 +- client/internal/routemanager/manager.go | 13 ++-- client/internal/routemanager/manager_test.go | 2 +- client/internal/routemanager/mock.go | 5 +- .../routemanager/notifier/notifier_other.go | 2 +- .../routemanager/systemops/systemops.go | 5 ++ .../systemops/systemops_android.go | 5 +- .../systemops/systemops_generic.go | 76 ++++++++++++++----- .../systemops/systemops_generic_test.go | 6 +- .../routemanager/systemops/systemops_ios.go | 5 +- .../routemanager/systemops/systemops_linux.go | 6 +- .../routemanager/systemops/systemops_unix.go | 3 +- .../systemops/systemops_windows.go | 3 +- util/net/listener_listen.go | 67 ++++++++++++++-- util/net/listener_listen_ios.go | 10 +++ 25 files changed, 230 insertions(+), 151 deletions(-) create mode 100644 client/iface/bind/control.go delete mode 100644 client/iface/bind/control_android.go create mode 100644 client/iface/bind/udp_mux_generic.go create mode 100644 client/iface/bind/udp_mux_ios.go create mode 100644 util/net/listener_listen_ios.go diff --git a/client/iface/bind/control.go b/client/iface/bind/control.go new file mode 100644 index 000000000..89bddf12c --- /dev/null +++ b/client/iface/bind/control.go @@ -0,0 +1,15 @@ +package bind + +import ( + wireguard "golang.zx2c4.com/wireguard/conn" + + nbnet "github.com/netbirdio/netbird/util/net" +) + +// TODO: This is most likely obsolete since the control fns should be called by the wrapped udpconn (ice_bind.go) +func init() { + listener := nbnet.NewListener() + if listener.ListenConfig.Control != nil { + *wireguard.ControlFns = append(*wireguard.ControlFns, listener.ListenConfig.Control) + } +} diff --git a/client/iface/bind/control_android.go b/client/iface/bind/control_android.go deleted file mode 100644 index b8a865e39..000000000 --- a/client/iface/bind/control_android.go +++ /dev/null @@ -1,12 +0,0 @@ -package bind - -import ( - wireguard "golang.zx2c4.com/wireguard/conn" - - nbnet "github.com/netbirdio/netbird/util/net" -) - -func init() { - // ControlFns is not thread safe and should only be modified during init. - *wireguard.ControlFns = append(*wireguard.ControlFns, nbnet.ControlProtectSocket) -} diff --git a/client/iface/bind/ice_bind.go b/client/iface/bind/ice_bind.go index bb7a27279..c3d5ef377 100644 --- a/client/iface/bind/ice_bind.go +++ b/client/iface/bind/ice_bind.go @@ -16,6 +16,7 @@ import ( wgConn "golang.zx2c4.com/wireguard/conn" "github.com/netbirdio/netbird/client/iface/wgaddr" + nbnet "github.com/netbirdio/netbird/util/net" ) type RecvMessage struct { @@ -153,7 +154,7 @@ func (s *ICEBind) createIPv4ReceiverFn(pc *ipv4.PacketConn, conn *net.UDPConn, r s.udpMux = NewUniversalUDPMuxDefault( UniversalUDPMuxParams{ - UDPConn: conn, + UDPConn: nbnet.WrapUDPConn(conn), Net: s.transportNet, FilterFn: s.filterFn, WGAddress: s.address, diff --git a/client/iface/bind/udp_mux.go b/client/iface/bind/udp_mux.go index 0e58499aa..29e5d7937 100644 --- a/client/iface/bind/udp_mux.go +++ b/client/iface/bind/udp_mux.go @@ -296,14 +296,20 @@ func (m *UDPMuxDefault) RemoveConnByUfrag(ufrag string) { return } - m.addressMapMu.Lock() - defer m.addressMapMu.Unlock() - + var allAddresses []string for _, c := range removedConns { addresses := c.getAddresses() - for _, addr := range addresses { - delete(m.addressMap, addr) - } + allAddresses = append(allAddresses, addresses...) + } + + m.addressMapMu.Lock() + for _, addr := range allAddresses { + delete(m.addressMap, addr) + } + m.addressMapMu.Unlock() + + for _, addr := range allAddresses { + m.notifyAddressRemoval(addr) } } @@ -351,14 +357,13 @@ func (m *UDPMuxDefault) registerConnForAddress(conn *udpMuxedConn, addr string) } m.addressMapMu.Lock() - defer m.addressMapMu.Unlock() - existing, ok := m.addressMap[addr] if !ok { existing = []*udpMuxedConn{} } existing = append(existing, conn) m.addressMap[addr] = existing + m.addressMapMu.Unlock() log.Debugf("ICE: registered %s for %s", addr, conn.params.Key) } @@ -386,12 +391,12 @@ func (m *UDPMuxDefault) HandleSTUNMessage(msg *stun.Message, addr net.Addr) erro // If you are using the same socket for the Host and SRFLX candidates, it might be that there are more than one // muxed connection - one for the SRFLX candidate and the other one for the HOST one. // We will then forward STUN packets to each of these connections. - m.addressMapMu.Lock() + m.addressMapMu.RLock() var destinationConnList []*udpMuxedConn if storedConns, ok := m.addressMap[addr.String()]; ok { destinationConnList = append(destinationConnList, storedConns...) } - m.addressMapMu.Unlock() + m.addressMapMu.RUnlock() var isIPv6 bool if udpAddr, _ := addr.(*net.UDPAddr); udpAddr != nil && udpAddr.IP.To4() == nil { diff --git a/client/iface/bind/udp_mux_generic.go b/client/iface/bind/udp_mux_generic.go new file mode 100644 index 000000000..e42d25462 --- /dev/null +++ b/client/iface/bind/udp_mux_generic.go @@ -0,0 +1,21 @@ +//go:build !ios + +package bind + +import ( + nbnet "github.com/netbirdio/netbird/util/net" +) + +func (m *UDPMuxDefault) notifyAddressRemoval(addr string) { + wrapped, ok := m.params.UDPConn.(*UDPConn) + if !ok { + return + } + + nbnetConn, ok := wrapped.GetPacketConn().(*nbnet.UDPConn) + if !ok { + return + } + + nbnetConn.RemoveAddress(addr) +} diff --git a/client/iface/bind/udp_mux_ios.go b/client/iface/bind/udp_mux_ios.go new file mode 100644 index 000000000..15e26d02f --- /dev/null +++ b/client/iface/bind/udp_mux_ios.go @@ -0,0 +1,7 @@ +//go:build ios + +package bind + +func (m *UDPMuxDefault) notifyAddressRemoval(addr string) { + // iOS doesn't support nbnet hooks, so this is a no-op +} \ No newline at end of file diff --git a/client/iface/bind/udp_mux_universal.go b/client/iface/bind/udp_mux_universal.go index 5cc634955..b755a7827 100644 --- a/client/iface/bind/udp_mux_universal.go +++ b/client/iface/bind/udp_mux_universal.go @@ -62,7 +62,7 @@ func NewUniversalUDPMuxDefault(params UniversalUDPMuxParams) *UniversalUDPMuxDef // wrap UDP connection, process server reflexive messages // before they are passed to the UDPMux connection handler (connWorker) - m.params.UDPConn = &udpConn{ + m.params.UDPConn = &UDPConn{ PacketConn: params.UDPConn, mux: m, logger: params.Logger, @@ -70,7 +70,6 @@ func NewUniversalUDPMuxDefault(params UniversalUDPMuxParams) *UniversalUDPMuxDef address: params.WGAddress, } - // embed UDPMux udpMuxParams := UDPMuxParams{ Logger: params.Logger, UDPConn: m.params.UDPConn, @@ -114,8 +113,8 @@ func (m *UniversalUDPMuxDefault) ReadFromConn(ctx context.Context) { } } -// udpConn is a wrapper around UDPMux conn that overrides ReadFrom and handles STUN/TURN packets -type udpConn struct { +// UDPConn is a wrapper around UDPMux conn that overrides ReadFrom and handles STUN/TURN packets +type UDPConn struct { net.PacketConn mux *UniversalUDPMuxDefault logger logging.LeveledLogger @@ -125,7 +124,12 @@ type udpConn struct { address wgaddr.Address } -func (u *udpConn) WriteTo(b []byte, addr net.Addr) (int, error) { +// GetPacketConn returns the underlying PacketConn +func (u *UDPConn) GetPacketConn() net.PacketConn { + return u.PacketConn +} + +func (u *UDPConn) WriteTo(b []byte, addr net.Addr) (int, error) { if u.filterFn == nil { return u.PacketConn.WriteTo(b, addr) } @@ -137,21 +141,21 @@ func (u *udpConn) WriteTo(b []byte, addr net.Addr) (int, error) { return u.handleUncachedAddress(b, addr) } -func (u *udpConn) handleCachedAddress(isRouted bool, b []byte, addr net.Addr) (int, error) { +func (u *UDPConn) handleCachedAddress(isRouted bool, b []byte, addr net.Addr) (int, error) { if isRouted { return 0, fmt.Errorf("address %s is part of a routed network, refusing to write", addr) } return u.PacketConn.WriteTo(b, addr) } -func (u *udpConn) handleUncachedAddress(b []byte, addr net.Addr) (int, error) { +func (u *UDPConn) handleUncachedAddress(b []byte, addr net.Addr) (int, error) { if err := u.performFilterCheck(addr); err != nil { return 0, err } return u.PacketConn.WriteTo(b, addr) } -func (u *udpConn) performFilterCheck(addr net.Addr) error { +func (u *UDPConn) performFilterCheck(addr net.Addr) error { host, err := getHostFromAddr(addr) if err != nil { log.Errorf("Failed to get host from address %s: %v", addr, err) diff --git a/client/internal/engine.go b/client/internal/engine.go index e9772b359..1abb8163d 100644 --- a/client/internal/engine.go +++ b/client/internal/engine.go @@ -61,7 +61,6 @@ import ( signal "github.com/netbirdio/netbird/signal/client" sProto "github.com/netbirdio/netbird/signal/proto" "github.com/netbirdio/netbird/util" - nbnet "github.com/netbirdio/netbird/util/net" ) // PeerConnectionTimeoutMax is a timeout of an initial connection attempt to a remote peer. @@ -138,9 +137,6 @@ type Engine struct { connMgr *ConnMgr - beforePeerHook nbnet.AddHookFunc - afterPeerHook nbnet.RemoveHookFunc - // rpManager is a Rosenpass manager rpManager *rosenpass.Manager @@ -409,12 +405,8 @@ func (e *Engine) Start() error { DisableClientRoutes: e.config.DisableClientRoutes, DisableServerRoutes: e.config.DisableServerRoutes, }) - beforePeerHook, afterPeerHook, err := e.routeManager.Init() - if err != nil { + if err := e.routeManager.Init(); err != nil { log.Errorf("Failed to initialize route manager: %s", err) - } else { - e.beforePeerHook = beforePeerHook - e.afterPeerHook = afterPeerHook } e.routeManager.SetRouteChangeListener(e.mobileDep.NetworkChangeListener) @@ -1261,10 +1253,6 @@ func (e *Engine) addNewPeer(peerConfig *mgmProto.RemotePeerConfig) error { return fmt.Errorf("peer already exists: %s", peerKey) } - if e.beforePeerHook != nil && e.afterPeerHook != nil { - conn.AddBeforeAddPeerHook(e.beforePeerHook) - conn.AddAfterRemovePeerHook(e.afterPeerHook) - } return nil } diff --git a/client/internal/engine_test.go b/client/internal/engine_test.go index e75672ed1..f02138686 100644 --- a/client/internal/engine_test.go +++ b/client/internal/engine_test.go @@ -400,7 +400,7 @@ func TestEngine_UpdateNetworkMap(t *testing.T) { StatusRecorder: engine.statusRecorder, RelayManager: relayMgr, }) - _, _, err = engine.routeManager.Init() + err = engine.routeManager.Init() require.NoError(t, err) engine.dnsServer = &dns.MockServer{ UpdateDNSServerFunc: func(serial uint64, update nbdns.Config) error { return nil }, diff --git a/client/internal/peer/conn.go b/client/internal/peer/conn.go index 7765bb51c..ddd90450d 100644 --- a/client/internal/peer/conn.go +++ b/client/internal/peer/conn.go @@ -26,7 +26,6 @@ import ( "github.com/netbirdio/netbird/client/internal/stdnet" relayClient "github.com/netbirdio/netbird/relay/client" "github.com/netbirdio/netbird/route" - nbnet "github.com/netbirdio/netbird/util/net" semaphoregroup "github.com/netbirdio/netbird/util/semaphore-group" ) @@ -106,10 +105,6 @@ type Conn struct { workerRelay *WorkerRelay wgWatcherWg sync.WaitGroup - connIDRelay nbnet.ConnectionID - connIDICE nbnet.ConnectionID - beforeAddPeerHooks []nbnet.AddHookFunc - afterRemovePeerHooks []nbnet.RemoveHookFunc // used to store the remote Rosenpass key for Relayed connection in case of connection update from ice rosenpassRemoteKey []byte @@ -267,8 +262,6 @@ func (conn *Conn) Close(signalToRemote bool) { conn.Log.Errorf("failed to remove wg endpoint: %v", err) } - conn.freeUpConnID() - if conn.evalStatus() == StatusConnected && conn.onDisconnected != nil { conn.onDisconnected(conn.config.WgConfig.RemoteKey) } @@ -293,13 +286,6 @@ func (conn *Conn) OnRemoteCandidate(candidate ice.Candidate, haRoutes route.HAMa conn.workerICE.OnRemoteCandidate(candidate, haRoutes) } -func (conn *Conn) AddBeforeAddPeerHook(hook nbnet.AddHookFunc) { - conn.beforeAddPeerHooks = append(conn.beforeAddPeerHooks, hook) -} -func (conn *Conn) AddAfterRemovePeerHook(hook nbnet.RemoveHookFunc) { - conn.afterRemovePeerHooks = append(conn.afterRemovePeerHooks, hook) -} - // SetOnConnected sets a handler function to be triggered by Conn when a new connection to a remote peer established func (conn *Conn) SetOnConnected(handler func(remoteWireGuardKey string, remoteRosenpassPubKey []byte, wireGuardIP string, remoteRosenpassAddr string)) { conn.onConnected = handler @@ -387,10 +373,6 @@ func (conn *Conn) onICEConnectionIsReady(priority conntype.ConnPriority, iceConn ep = directEp } - if err := conn.runBeforeAddPeerHooks(ep.IP); err != nil { - conn.Log.Errorf("Before add peer hook failed: %v", err) - } - conn.workerRelay.DisableWgWatcher() // todo consider to run conn.wgWatcherWg.Wait() here @@ -503,10 +485,6 @@ func (conn *Conn) onRelayConnectionIsReady(rci RelayConnInfo) { return } - if err := conn.runBeforeAddPeerHooks(wgProxy.EndpointAddr().IP); err != nil { - conn.Log.Errorf("Before add peer hook failed: %v", err) - } - wgProxy.Work() if err := conn.configureWGEndpoint(wgProxy.EndpointAddr(), rci.rosenpassPubKey); err != nil { if err := wgProxy.CloseConn(); err != nil { @@ -707,36 +685,6 @@ func (conn *Conn) isConnectedOnAllWay() (connected bool) { return true } -func (conn *Conn) runBeforeAddPeerHooks(ip net.IP) error { - conn.connIDICE = nbnet.GenerateConnID() - for _, hook := range conn.beforeAddPeerHooks { - if err := hook(conn.connIDICE, ip); err != nil { - return err - } - } - return nil -} - -func (conn *Conn) freeUpConnID() { - if conn.connIDRelay != "" { - for _, hook := range conn.afterRemovePeerHooks { - if err := hook(conn.connIDRelay); err != nil { - conn.Log.Errorf("After remove peer hook failed: %v", err) - } - } - conn.connIDRelay = "" - } - - if conn.connIDICE != "" { - for _, hook := range conn.afterRemovePeerHooks { - if err := hook(conn.connIDICE); err != nil { - conn.Log.Errorf("After remove peer hook failed: %v", err) - } - } - conn.connIDICE = "" - } -} - func (conn *Conn) newProxy(remoteConn net.Conn) (wgproxy.Proxy, error) { conn.Log.Debugf("setup proxied WireGuard connection") udpAddr := &net.UDPAddr{ diff --git a/client/internal/routemanager/client/client_test.go b/client/internal/routemanager/client/client_test.go index ec8e0e944..850f6691f 100644 --- a/client/internal/routemanager/client/client_test.go +++ b/client/internal/routemanager/client/client_test.go @@ -812,7 +812,7 @@ func TestGetBestrouteFromStatuses(t *testing.T) { } params := common.HandlerParams{ - Route: &route.Route{Network: netip.MustParsePrefix("192.168.0.0/24")}, + Route: &route.Route{Network: netip.MustParsePrefix("192.168.0.0/24")}, } // create new clientNetwork client := &Watcher{ diff --git a/client/internal/routemanager/manager.go b/client/internal/routemanager/manager.go index e0974ab2a..e51778811 100644 --- a/client/internal/routemanager/manager.go +++ b/client/internal/routemanager/manager.go @@ -44,7 +44,7 @@ import ( // Manager is a route manager interface type Manager interface { - Init() (nbnet.AddHookFunc, nbnet.RemoveHookFunc, error) + Init() error UpdateRoutes(updateSerial uint64, serverRoutes map[route.ID]*route.Route, clientRoutes route.HAMap, useNewDNSRoute bool) error ClassifyRoutes(newRoutes []*route.Route) (map[route.ID]*route.Route, route.HAMap) TriggerSelection(route.HAMap) @@ -201,11 +201,11 @@ func (m *DefaultManager) setupRefCounters(useNoop bool) { } // Init sets up the routing -func (m *DefaultManager) Init() (nbnet.AddHookFunc, nbnet.RemoveHookFunc, error) { +func (m *DefaultManager) Init() error { m.routeSelector = m.initSelector() if nbnet.CustomRoutingDisabled() || m.disableClientRoutes { - return nil, nil, nil + return nil } if err := m.sysOps.CleanupRouting(nil); err != nil { @@ -219,13 +219,12 @@ func (m *DefaultManager) Init() (nbnet.AddHookFunc, nbnet.RemoveHookFunc, error) ips := resolveURLsToIPs(initialAddresses) - beforePeerHook, afterPeerHook, err := m.sysOps.SetupRouting(ips, m.stateManager) - if err != nil { - return nil, nil, fmt.Errorf("setup routing: %w", err) + if err := m.sysOps.SetupRouting(ips, m.stateManager); err != nil { + return fmt.Errorf("setup routing: %w", err) } log.Info("Routing setup complete") - return beforePeerHook, afterPeerHook, nil + return nil } func (m *DefaultManager) initSelector() *routeselector.RouteSelector { diff --git a/client/internal/routemanager/manager_test.go b/client/internal/routemanager/manager_test.go index 486ee080a..2f13c2134 100644 --- a/client/internal/routemanager/manager_test.go +++ b/client/internal/routemanager/manager_test.go @@ -430,7 +430,7 @@ func TestManagerUpdateRoutes(t *testing.T) { StatusRecorder: statusRecorder, }) - _, _, err = routeManager.Init() + err = routeManager.Init() require.NoError(t, err, "should init route manager") defer routeManager.Stop(nil) diff --git a/client/internal/routemanager/mock.go b/client/internal/routemanager/mock.go index 4e182f82c..be633c3fa 100644 --- a/client/internal/routemanager/mock.go +++ b/client/internal/routemanager/mock.go @@ -9,7 +9,6 @@ import ( "github.com/netbirdio/netbird/client/internal/routeselector" "github.com/netbirdio/netbird/client/internal/statemanager" "github.com/netbirdio/netbird/route" - "github.com/netbirdio/netbird/util/net" ) // MockManager is the mock instance of a route manager @@ -23,8 +22,8 @@ type MockManager struct { StopFunc func(manager *statemanager.Manager) } -func (m *MockManager) Init() (net.AddHookFunc, net.RemoveHookFunc, error) { - return nil, nil, nil +func (m *MockManager) Init() error { + return nil } // InitialRouteRange mock implementation of InitialRouteRange from Manager interface diff --git a/client/internal/routemanager/notifier/notifier_other.go b/client/internal/routemanager/notifier/notifier_other.go index 77045b839..0521e3dc2 100644 --- a/client/internal/routemanager/notifier/notifier_other.go +++ b/client/internal/routemanager/notifier/notifier_other.go @@ -33,4 +33,4 @@ func (n *Notifier) OnNewPrefixes(prefixes []netip.Prefix) { func (n *Notifier) GetInitialRouteRanges() []string { return []string{} -} \ No newline at end of file +} diff --git a/client/internal/routemanager/systemops/systemops.go b/client/internal/routemanager/systemops/systemops.go index 106c520da..b91348e94 100644 --- a/client/internal/routemanager/systemops/systemops.go +++ b/client/internal/routemanager/systemops/systemops.go @@ -6,6 +6,7 @@ import ( "net/netip" "sync" "sync/atomic" + "time" "github.com/netbirdio/netbird/client/iface/wgaddr" "github.com/netbirdio/netbird/client/internal/routemanager/notifier" @@ -56,6 +57,10 @@ type SysOps struct { // seq is an atomic counter for generating unique sequence numbers for route messages //nolint:unused // only used on BSD systems seq atomic.Uint32 + + localSubnetsCache []*net.IPNet + localSubnetsCacheMu sync.RWMutex + localSubnetsCacheTime time.Time } func NewSysOps(wgInterface wgIface, notifier *notifier.Notifier) *SysOps { diff --git a/client/internal/routemanager/systemops/systemops_android.go b/client/internal/routemanager/systemops/systemops_android.go index ca8aea3fb..a375ce832 100644 --- a/client/internal/routemanager/systemops/systemops_android.go +++ b/client/internal/routemanager/systemops/systemops_android.go @@ -10,11 +10,10 @@ import ( log "github.com/sirupsen/logrus" "github.com/netbirdio/netbird/client/internal/statemanager" - nbnet "github.com/netbirdio/netbird/util/net" ) -func (r *SysOps) SetupRouting([]net.IP, *statemanager.Manager) (nbnet.AddHookFunc, nbnet.RemoveHookFunc, error) { - return nil, nil, nil +func (r *SysOps) SetupRouting([]net.IP, *statemanager.Manager) error { + return nil } func (r *SysOps) CleanupRouting(*statemanager.Manager) error { diff --git a/client/internal/routemanager/systemops/systemops_generic.go b/client/internal/routemanager/systemops/systemops_generic.go index d223a27b2..128afa2a5 100644 --- a/client/internal/routemanager/systemops/systemops_generic.go +++ b/client/internal/routemanager/systemops/systemops_generic.go @@ -10,6 +10,7 @@ import ( "net/netip" "runtime" "strconv" + "time" "github.com/hashicorp/go-multierror" "github.com/libp2p/go-netroute" @@ -24,6 +25,8 @@ import ( nbnet "github.com/netbirdio/netbird/util/net" ) +const localSubnetsCacheTTL = 15 * time.Minute + var splitDefaultv4_1 = netip.PrefixFrom(netip.IPv4Unspecified(), 1) var splitDefaultv4_2 = netip.PrefixFrom(netip.AddrFrom4([4]byte{128}), 1) var splitDefaultv6_1 = netip.PrefixFrom(netip.IPv6Unspecified(), 1) @@ -31,7 +34,7 @@ var splitDefaultv6_2 = netip.PrefixFrom(netip.AddrFrom16([16]byte{0x80}), 1) var ErrRoutingIsSeparate = errors.New("routing is separate") -func (r *SysOps) setupRefCounter(initAddresses []net.IP, stateManager *statemanager.Manager) (nbnet.AddHookFunc, nbnet.RemoveHookFunc, error) { +func (r *SysOps) setupRefCounter(initAddresses []net.IP, stateManager *statemanager.Manager) error { stateManager.RegisterState(&ShutdownState{}) initialNextHopV4, err := GetNextHop(netip.IPv4Unspecified()) @@ -75,7 +78,10 @@ func (r *SysOps) setupRefCounter(initAddresses []net.IP, stateManager *statemana r.refCounter = refCounter - return r.setupHooks(initAddresses, stateManager) + if err := r.setupHooks(initAddresses, stateManager); err != nil { + return fmt.Errorf("setup hooks: %w", err) + } + return nil } // updateState updates state on every change so it will be persisted regularly @@ -128,18 +134,14 @@ func (r *SysOps) addRouteToNonVPNIntf(prefix netip.Prefix, vpnIntf wgIface, init return Nexthop{}, fmt.Errorf("get next hop: %w", err) } - log.Debugf("Found next hop %s for prefix %s with interface %v", nexthop.IP, prefix, nexthop.IP) - exitNextHop := Nexthop{ - IP: nexthop.IP, - Intf: nexthop.Intf, - } + log.Debugf("Found next hop %s for prefix %s with interface %v", nexthop.IP, prefix, nexthop.Intf) + exitNextHop := nexthop vpnAddr := vpnIntf.Address().IP // if next hop is the VPN address or the interface is the VPN interface, we should use the initial values if exitNextHop.IP == vpnAddr || exitNextHop.Intf != nil && exitNextHop.Intf.Name == vpnIntf.Name() { log.Debugf("Route for prefix %s is pointing to the VPN interface, using initial next hop %v", prefix, initialNextHop) - exitNextHop = initialNextHop } @@ -152,12 +154,37 @@ func (r *SysOps) addRouteToNonVPNIntf(prefix netip.Prefix, vpnIntf wgIface, init } func (r *SysOps) isPrefixInLocalSubnets(prefix netip.Prefix) (bool, *net.IPNet) { + r.localSubnetsCacheMu.RLock() + cacheAge := time.Since(r.localSubnetsCacheTime) + subnets := r.localSubnetsCache + r.localSubnetsCacheMu.RUnlock() + + if cacheAge > localSubnetsCacheTTL || subnets == nil { + r.localSubnetsCacheMu.Lock() + if time.Since(r.localSubnetsCacheTime) > localSubnetsCacheTTL || r.localSubnetsCache == nil { + r.refreshLocalSubnetsCache() + } + subnets = r.localSubnetsCache + r.localSubnetsCacheMu.Unlock() + } + + for _, subnet := range subnets { + if subnet.Contains(prefix.Addr().AsSlice()) { + return true, subnet + } + } + + return false, nil +} + +func (r *SysOps) refreshLocalSubnetsCache() { localInterfaces, err := net.Interfaces() if err != nil { log.Errorf("Failed to get local interfaces: %v", err) - return false, nil + return } + var newSubnets []*net.IPNet for _, intf := range localInterfaces { addrs, err := intf.Addrs() if err != nil { @@ -171,14 +198,12 @@ func (r *SysOps) isPrefixInLocalSubnets(prefix netip.Prefix) (bool, *net.IPNet) log.Errorf("Failed to convert address to IPNet: %v", addr) continue } - - if ipnet.Contains(prefix.Addr().AsSlice()) { - return true, ipnet - } + newSubnets = append(newSubnets, ipnet) } } - return false, nil + r.localSubnetsCache = newSubnets + r.localSubnetsCacheTime = time.Now() } // genericAddVPNRoute adds a new route to the vpn interface, it splits the default prefix @@ -264,7 +289,7 @@ func (r *SysOps) genericRemoveVPNRoute(prefix netip.Prefix, intf *net.Interface) return r.removeFromRouteTable(prefix, nextHop) } -func (r *SysOps) setupHooks(initAddresses []net.IP, stateManager *statemanager.Manager) (nbnet.AddHookFunc, nbnet.RemoveHookFunc, error) { +func (r *SysOps) setupHooks(initAddresses []net.IP, stateManager *statemanager.Manager) error { beforeHook := func(connID nbnet.ConnectionID, ip net.IP) error { prefix, err := util.GetPrefixFromIP(ip) if err != nil { @@ -289,9 +314,11 @@ func (r *SysOps) setupHooks(initAddresses []net.IP, stateManager *statemanager.M return nil } + var merr *multierror.Error + for _, ip := range initAddresses { if err := beforeHook("init", ip); err != nil { - log.Errorf("Failed to add route reference: %v", err) + merr = multierror.Append(merr, fmt.Errorf("add initial route for %s: %w", ip, err)) } } @@ -300,11 +327,11 @@ func (r *SysOps) setupHooks(initAddresses []net.IP, stateManager *statemanager.M return ctx.Err() } - var result *multierror.Error + var merr *multierror.Error for _, ip := range resolvedIPs { - result = multierror.Append(result, beforeHook(connID, ip.IP)) + merr = multierror.Append(merr, beforeHook(connID, ip.IP)) } - return nberrors.FormatErrorOrNil(result) + return nberrors.FormatErrorOrNil(merr) }) nbnet.AddDialerCloseHook(func(connID nbnet.ConnectionID, conn *net.Conn) error { @@ -319,7 +346,16 @@ func (r *SysOps) setupHooks(initAddresses []net.IP, stateManager *statemanager.M return afterHook(connID) }) - return beforeHook, afterHook, nil + nbnet.AddListenerAddressRemoveHook(func(connID nbnet.ConnectionID, prefix netip.Prefix) error { + if _, err := r.refCounter.Decrement(prefix); err != nil { + return fmt.Errorf("remove route reference: %w", err) + } + + r.updateState(stateManager) + return nil + }) + + return nberrors.FormatErrorOrNil(merr) } func GetNextHop(ip netip.Addr) (Nexthop, error) { diff --git a/client/internal/routemanager/systemops/systemops_generic_test.go b/client/internal/routemanager/systemops/systemops_generic_test.go index 2a57e6044..c1c1182bc 100644 --- a/client/internal/routemanager/systemops/systemops_generic_test.go +++ b/client/internal/routemanager/systemops/systemops_generic_test.go @@ -143,7 +143,7 @@ func TestAddVPNRoute(t *testing.T) { wgInterface := createWGInterface(t, fmt.Sprintf("utun53%d", n), "100.65.75.2/24", 33100+n) r := NewSysOps(wgInterface, nil) - _, _, err := r.SetupRouting(nil, nil) + err := r.SetupRouting(nil, nil) require.NoError(t, err) t.Cleanup(func() { assert.NoError(t, r.CleanupRouting(nil)) @@ -341,7 +341,7 @@ func TestAddRouteToNonVPNIntf(t *testing.T) { wgInterface := createWGInterface(t, fmt.Sprintf("utun54%d", n), "100.65.75.2/24", 33200+n) r := NewSysOps(wgInterface, nil) - _, _, err := r.SetupRouting(nil, nil) + err := r.SetupRouting(nil, nil) require.NoError(t, err) t.Cleanup(func() { assert.NoError(t, r.CleanupRouting(nil)) @@ -484,7 +484,7 @@ func setupTestEnv(t *testing.T) { }) r := NewSysOps(wgInterface, nil) - _, _, err := r.SetupRouting(nil, nil) + err := r.SetupRouting(nil, nil) require.NoError(t, err, "setupRouting should not return err") t.Cleanup(func() { assert.NoError(t, r.CleanupRouting(nil)) diff --git a/client/internal/routemanager/systemops/systemops_ios.go b/client/internal/routemanager/systemops/systemops_ios.go index bf06f3739..10356eae0 100644 --- a/client/internal/routemanager/systemops/systemops_ios.go +++ b/client/internal/routemanager/systemops/systemops_ios.go @@ -10,14 +10,13 @@ import ( log "github.com/sirupsen/logrus" "github.com/netbirdio/netbird/client/internal/statemanager" - nbnet "github.com/netbirdio/netbird/util/net" ) -func (r *SysOps) SetupRouting([]net.IP, *statemanager.Manager) (nbnet.AddHookFunc, nbnet.RemoveHookFunc, error) { +func (r *SysOps) SetupRouting([]net.IP, *statemanager.Manager) error { r.mu.Lock() defer r.mu.Unlock() r.prefixes = make(map[netip.Prefix]struct{}) - return nil, nil, nil + return nil } func (r *SysOps) CleanupRouting(*statemanager.Manager) error { diff --git a/client/internal/routemanager/systemops/systemops_linux.go b/client/internal/routemanager/systemops/systemops_linux.go index b48cfa242..711f1d758 100644 --- a/client/internal/routemanager/systemops/systemops_linux.go +++ b/client/internal/routemanager/systemops/systemops_linux.go @@ -72,7 +72,7 @@ func getSetupRules() []ruleParams { // Rule 2 (VPN Traffic Routing): Directs all remaining traffic to the 'NetbirdVPNTableID' custom routing table. // This table is where a default route or other specific routes received from the management server are configured, // enabling VPN connectivity. -func (r *SysOps) SetupRouting(initAddresses []net.IP, stateManager *statemanager.Manager) (_ nbnet.AddHookFunc, _ nbnet.RemoveHookFunc, err error) { +func (r *SysOps) SetupRouting(initAddresses []net.IP, stateManager *statemanager.Manager) (err error) { if !nbnet.AdvancedRouting() { log.Infof("Using legacy routing setup") return r.setupRefCounter(initAddresses, stateManager) @@ -89,7 +89,7 @@ func (r *SysOps) SetupRouting(initAddresses []net.IP, stateManager *statemanager rules := getSetupRules() for _, rule := range rules { if err := addRule(rule); err != nil { - return nil, nil, fmt.Errorf("%s: %w", rule.description, err) + return fmt.Errorf("%s: %w", rule.description, err) } } @@ -104,7 +104,7 @@ func (r *SysOps) SetupRouting(initAddresses []net.IP, stateManager *statemanager } originalSysctl = originalValues - return nil, nil, nil + return nil } // CleanupRouting performs a thorough cleanup of the routing configuration established by 'setupRouting'. diff --git a/client/internal/routemanager/systemops/systemops_unix.go b/client/internal/routemanager/systemops/systemops_unix.go index 46e5ca915..f165f7779 100644 --- a/client/internal/routemanager/systemops/systemops_unix.go +++ b/client/internal/routemanager/systemops/systemops_unix.go @@ -18,10 +18,9 @@ import ( "golang.org/x/sys/unix" "github.com/netbirdio/netbird/client/internal/statemanager" - nbnet "github.com/netbirdio/netbird/util/net" ) -func (r *SysOps) SetupRouting(initAddresses []net.IP, stateManager *statemanager.Manager) (nbnet.AddHookFunc, nbnet.RemoveHookFunc, error) { +func (r *SysOps) SetupRouting(initAddresses []net.IP, stateManager *statemanager.Manager) error { return r.setupRefCounter(initAddresses, stateManager) } diff --git a/client/internal/routemanager/systemops/systemops_windows.go b/client/internal/routemanager/systemops/systemops_windows.go index 11eaa435e..7afac9ae5 100644 --- a/client/internal/routemanager/systemops/systemops_windows.go +++ b/client/internal/routemanager/systemops/systemops_windows.go @@ -19,7 +19,6 @@ import ( "golang.org/x/sys/windows" "github.com/netbirdio/netbird/client/internal/statemanager" - nbnet "github.com/netbirdio/netbird/util/net" ) const InfiniteLifetime = 0xffffffff @@ -137,7 +136,7 @@ const ( RouteDeleted ) -func (r *SysOps) SetupRouting(initAddresses []net.IP, stateManager *statemanager.Manager) (nbnet.AddHookFunc, nbnet.RemoveHookFunc, error) { +func (r *SysOps) SetupRouting(initAddresses []net.IP, stateManager *statemanager.Manager) error { return r.setupRefCounter(initAddresses, stateManager) } diff --git a/util/net/listener_listen.go b/util/net/listener_listen.go index efffba40e..dc99fbd68 100644 --- a/util/net/listener_listen.go +++ b/util/net/listener_listen.go @@ -6,6 +6,7 @@ import ( "context" "fmt" "net" + "net/netip" "sync" log "github.com/sirupsen/logrus" @@ -17,11 +18,16 @@ type ListenerWriteHookFunc func(connID ConnectionID, ip *net.IPAddr, data []byte // ListenerCloseHookFunc defines the function signature for close hooks for PacketConn. type ListenerCloseHookFunc func(connID ConnectionID, conn net.PacketConn) error +// ListenerAddressRemoveHookFunc defines the function signature for hooks called when addresses are removed. +type ListenerAddressRemoveHookFunc func(connID ConnectionID, prefix netip.Prefix) error + var ( - listenerWriteHooksMutex sync.RWMutex - listenerWriteHooks []ListenerWriteHookFunc - listenerCloseHooksMutex sync.RWMutex - listenerCloseHooks []ListenerCloseHookFunc + listenerWriteHooksMutex sync.RWMutex + listenerWriteHooks []ListenerWriteHookFunc + listenerCloseHooksMutex sync.RWMutex + listenerCloseHooks []ListenerCloseHookFunc + listenerAddressRemoveHooksMutex sync.RWMutex + listenerAddressRemoveHooks []ListenerAddressRemoveHookFunc ) // AddListenerWriteHook allows adding a new write hook to be executed before a UDP packet is sent. @@ -38,7 +44,14 @@ func AddListenerCloseHook(hook ListenerCloseHookFunc) { listenerCloseHooks = append(listenerCloseHooks, hook) } -// RemoveListenerHooks removes all dialer hooks. +// AddListenerAddressRemoveHook allows adding a new hook to be executed when an address is removed. +func AddListenerAddressRemoveHook(hook ListenerAddressRemoveHookFunc) { + listenerAddressRemoveHooksMutex.Lock() + defer listenerAddressRemoveHooksMutex.Unlock() + listenerAddressRemoveHooks = append(listenerAddressRemoveHooks, hook) +} + +// RemoveListenerHooks removes all listener hooks. func RemoveListenerHooks() { listenerWriteHooksMutex.Lock() defer listenerWriteHooksMutex.Unlock() @@ -47,6 +60,10 @@ func RemoveListenerHooks() { listenerCloseHooksMutex.Lock() defer listenerCloseHooksMutex.Unlock() listenerCloseHooks = nil + + listenerAddressRemoveHooksMutex.Lock() + defer listenerAddressRemoveHooksMutex.Unlock() + listenerAddressRemoveHooks = nil } // ListenPacket listens on the network address and returns a PacketConn @@ -61,6 +78,7 @@ func (l *ListenerConfig) ListenPacket(ctx context.Context, network, address stri return nil, fmt.Errorf("listen packet: %w", err) } connID := GenerateConnID() + return &PacketConn{PacketConn: pc, ID: connID, seenAddrs: &sync.Map{}}, nil } @@ -102,6 +120,45 @@ func (c *UDPConn) Close() error { return closeConn(c.ID, c.UDPConn) } +// WrapUDPConn wraps an existing *net.UDPConn with nbnet functionality +func WrapUDPConn(conn *net.UDPConn) *UDPConn { + return &UDPConn{ + UDPConn: conn, + ID: GenerateConnID(), + seenAddrs: &sync.Map{}, + } +} + +// RemoveAddress removes an address from the seen cache and triggers removal hooks. +func (c *UDPConn) RemoveAddress(addr string) { + if _, exists := c.seenAddrs.LoadAndDelete(addr); !exists { + return + } + + ipStr, _, err := net.SplitHostPort(addr) + if err != nil { + log.Errorf("Error splitting IP address and port: %v", err) + return + } + + ipAddr, err := netip.ParseAddr(ipStr) + if err != nil { + log.Errorf("Error parsing IP address %s: %v", ipStr, err) + return + } + + prefix := netip.PrefixFrom(ipAddr, ipAddr.BitLen()) + + listenerAddressRemoveHooksMutex.RLock() + defer listenerAddressRemoveHooksMutex.RUnlock() + + for _, hook := range listenerAddressRemoveHooks { + if err := hook(c.ID, prefix); err != nil { + log.Errorf("Error executing listener address remove hook: %v", err) + } + } +} + func callWriteHooks(id ConnectionID, seenAddrs *sync.Map, b []byte, addr net.Addr) { // Lookup the address in the seenAddrs map to avoid calling the hooks for every write if _, loaded := seenAddrs.LoadOrStore(addr.String(), true); !loaded { diff --git a/util/net/listener_listen_ios.go b/util/net/listener_listen_ios.go new file mode 100644 index 000000000..3cbd2cd71 --- /dev/null +++ b/util/net/listener_listen_ios.go @@ -0,0 +1,10 @@ +package net + +import ( + "net" +) + +// WrapUDPConn on iOS just returns the original connection since iOS handles its own networking +func WrapUDPConn(conn *net.UDPConn) *net.UDPConn { + return conn +}