mirror of
https://github.com/netbirdio/netbird.git
synced 2025-08-10 15:48:29 +02:00
[client] Fix bind exclusion routes (#4154)
This commit is contained in:
15
client/iface/bind/control.go
Normal file
15
client/iface/bind/control.go
Normal file
@ -0,0 +1,15 @@
|
|||||||
|
package bind
|
||||||
|
|
||||||
|
import (
|
||||||
|
wireguard "golang.zx2c4.com/wireguard/conn"
|
||||||
|
|
||||||
|
nbnet "github.com/netbirdio/netbird/util/net"
|
||||||
|
)
|
||||||
|
|
||||||
|
// TODO: This is most likely obsolete since the control fns should be called by the wrapped udpconn (ice_bind.go)
|
||||||
|
func init() {
|
||||||
|
listener := nbnet.NewListener()
|
||||||
|
if listener.ListenConfig.Control != nil {
|
||||||
|
*wireguard.ControlFns = append(*wireguard.ControlFns, listener.ListenConfig.Control)
|
||||||
|
}
|
||||||
|
}
|
@ -1,12 +0,0 @@
|
|||||||
package bind
|
|
||||||
|
|
||||||
import (
|
|
||||||
wireguard "golang.zx2c4.com/wireguard/conn"
|
|
||||||
|
|
||||||
nbnet "github.com/netbirdio/netbird/util/net"
|
|
||||||
)
|
|
||||||
|
|
||||||
func init() {
|
|
||||||
// ControlFns is not thread safe and should only be modified during init.
|
|
||||||
*wireguard.ControlFns = append(*wireguard.ControlFns, nbnet.ControlProtectSocket)
|
|
||||||
}
|
|
@ -16,6 +16,7 @@ import (
|
|||||||
wgConn "golang.zx2c4.com/wireguard/conn"
|
wgConn "golang.zx2c4.com/wireguard/conn"
|
||||||
|
|
||||||
"github.com/netbirdio/netbird/client/iface/wgaddr"
|
"github.com/netbirdio/netbird/client/iface/wgaddr"
|
||||||
|
nbnet "github.com/netbirdio/netbird/util/net"
|
||||||
)
|
)
|
||||||
|
|
||||||
type RecvMessage struct {
|
type RecvMessage struct {
|
||||||
@ -153,7 +154,7 @@ func (s *ICEBind) createIPv4ReceiverFn(pc *ipv4.PacketConn, conn *net.UDPConn, r
|
|||||||
|
|
||||||
s.udpMux = NewUniversalUDPMuxDefault(
|
s.udpMux = NewUniversalUDPMuxDefault(
|
||||||
UniversalUDPMuxParams{
|
UniversalUDPMuxParams{
|
||||||
UDPConn: conn,
|
UDPConn: nbnet.WrapUDPConn(conn),
|
||||||
Net: s.transportNet,
|
Net: s.transportNet,
|
||||||
FilterFn: s.filterFn,
|
FilterFn: s.filterFn,
|
||||||
WGAddress: s.address,
|
WGAddress: s.address,
|
||||||
|
@ -296,14 +296,20 @@ func (m *UDPMuxDefault) RemoveConnByUfrag(ufrag string) {
|
|||||||
return
|
return
|
||||||
}
|
}
|
||||||
|
|
||||||
m.addressMapMu.Lock()
|
var allAddresses []string
|
||||||
defer m.addressMapMu.Unlock()
|
|
||||||
|
|
||||||
for _, c := range removedConns {
|
for _, c := range removedConns {
|
||||||
addresses := c.getAddresses()
|
addresses := c.getAddresses()
|
||||||
for _, addr := range addresses {
|
allAddresses = append(allAddresses, addresses...)
|
||||||
delete(m.addressMap, addr)
|
}
|
||||||
}
|
|
||||||
|
m.addressMapMu.Lock()
|
||||||
|
for _, addr := range allAddresses {
|
||||||
|
delete(m.addressMap, addr)
|
||||||
|
}
|
||||||
|
m.addressMapMu.Unlock()
|
||||||
|
|
||||||
|
for _, addr := range allAddresses {
|
||||||
|
m.notifyAddressRemoval(addr)
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
@ -351,14 +357,13 @@ func (m *UDPMuxDefault) registerConnForAddress(conn *udpMuxedConn, addr string)
|
|||||||
}
|
}
|
||||||
|
|
||||||
m.addressMapMu.Lock()
|
m.addressMapMu.Lock()
|
||||||
defer m.addressMapMu.Unlock()
|
|
||||||
|
|
||||||
existing, ok := m.addressMap[addr]
|
existing, ok := m.addressMap[addr]
|
||||||
if !ok {
|
if !ok {
|
||||||
existing = []*udpMuxedConn{}
|
existing = []*udpMuxedConn{}
|
||||||
}
|
}
|
||||||
existing = append(existing, conn)
|
existing = append(existing, conn)
|
||||||
m.addressMap[addr] = existing
|
m.addressMap[addr] = existing
|
||||||
|
m.addressMapMu.Unlock()
|
||||||
|
|
||||||
log.Debugf("ICE: registered %s for %s", addr, conn.params.Key)
|
log.Debugf("ICE: registered %s for %s", addr, conn.params.Key)
|
||||||
}
|
}
|
||||||
@ -386,12 +391,12 @@ func (m *UDPMuxDefault) HandleSTUNMessage(msg *stun.Message, addr net.Addr) erro
|
|||||||
// If you are using the same socket for the Host and SRFLX candidates, it might be that there are more than one
|
// If you are using the same socket for the Host and SRFLX candidates, it might be that there are more than one
|
||||||
// muxed connection - one for the SRFLX candidate and the other one for the HOST one.
|
// muxed connection - one for the SRFLX candidate and the other one for the HOST one.
|
||||||
// We will then forward STUN packets to each of these connections.
|
// We will then forward STUN packets to each of these connections.
|
||||||
m.addressMapMu.Lock()
|
m.addressMapMu.RLock()
|
||||||
var destinationConnList []*udpMuxedConn
|
var destinationConnList []*udpMuxedConn
|
||||||
if storedConns, ok := m.addressMap[addr.String()]; ok {
|
if storedConns, ok := m.addressMap[addr.String()]; ok {
|
||||||
destinationConnList = append(destinationConnList, storedConns...)
|
destinationConnList = append(destinationConnList, storedConns...)
|
||||||
}
|
}
|
||||||
m.addressMapMu.Unlock()
|
m.addressMapMu.RUnlock()
|
||||||
|
|
||||||
var isIPv6 bool
|
var isIPv6 bool
|
||||||
if udpAddr, _ := addr.(*net.UDPAddr); udpAddr != nil && udpAddr.IP.To4() == nil {
|
if udpAddr, _ := addr.(*net.UDPAddr); udpAddr != nil && udpAddr.IP.To4() == nil {
|
||||||
|
21
client/iface/bind/udp_mux_generic.go
Normal file
21
client/iface/bind/udp_mux_generic.go
Normal file
@ -0,0 +1,21 @@
|
|||||||
|
//go:build !ios
|
||||||
|
|
||||||
|
package bind
|
||||||
|
|
||||||
|
import (
|
||||||
|
nbnet "github.com/netbirdio/netbird/util/net"
|
||||||
|
)
|
||||||
|
|
||||||
|
func (m *UDPMuxDefault) notifyAddressRemoval(addr string) {
|
||||||
|
wrapped, ok := m.params.UDPConn.(*UDPConn)
|
||||||
|
if !ok {
|
||||||
|
return
|
||||||
|
}
|
||||||
|
|
||||||
|
nbnetConn, ok := wrapped.GetPacketConn().(*nbnet.UDPConn)
|
||||||
|
if !ok {
|
||||||
|
return
|
||||||
|
}
|
||||||
|
|
||||||
|
nbnetConn.RemoveAddress(addr)
|
||||||
|
}
|
7
client/iface/bind/udp_mux_ios.go
Normal file
7
client/iface/bind/udp_mux_ios.go
Normal file
@ -0,0 +1,7 @@
|
|||||||
|
//go:build ios
|
||||||
|
|
||||||
|
package bind
|
||||||
|
|
||||||
|
func (m *UDPMuxDefault) notifyAddressRemoval(addr string) {
|
||||||
|
// iOS doesn't support nbnet hooks, so this is a no-op
|
||||||
|
}
|
@ -62,7 +62,7 @@ func NewUniversalUDPMuxDefault(params UniversalUDPMuxParams) *UniversalUDPMuxDef
|
|||||||
|
|
||||||
// wrap UDP connection, process server reflexive messages
|
// wrap UDP connection, process server reflexive messages
|
||||||
// before they are passed to the UDPMux connection handler (connWorker)
|
// before they are passed to the UDPMux connection handler (connWorker)
|
||||||
m.params.UDPConn = &udpConn{
|
m.params.UDPConn = &UDPConn{
|
||||||
PacketConn: params.UDPConn,
|
PacketConn: params.UDPConn,
|
||||||
mux: m,
|
mux: m,
|
||||||
logger: params.Logger,
|
logger: params.Logger,
|
||||||
@ -70,7 +70,6 @@ func NewUniversalUDPMuxDefault(params UniversalUDPMuxParams) *UniversalUDPMuxDef
|
|||||||
address: params.WGAddress,
|
address: params.WGAddress,
|
||||||
}
|
}
|
||||||
|
|
||||||
// embed UDPMux
|
|
||||||
udpMuxParams := UDPMuxParams{
|
udpMuxParams := UDPMuxParams{
|
||||||
Logger: params.Logger,
|
Logger: params.Logger,
|
||||||
UDPConn: m.params.UDPConn,
|
UDPConn: m.params.UDPConn,
|
||||||
@ -114,8 +113,8 @@ func (m *UniversalUDPMuxDefault) ReadFromConn(ctx context.Context) {
|
|||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
// udpConn is a wrapper around UDPMux conn that overrides ReadFrom and handles STUN/TURN packets
|
// UDPConn is a wrapper around UDPMux conn that overrides ReadFrom and handles STUN/TURN packets
|
||||||
type udpConn struct {
|
type UDPConn struct {
|
||||||
net.PacketConn
|
net.PacketConn
|
||||||
mux *UniversalUDPMuxDefault
|
mux *UniversalUDPMuxDefault
|
||||||
logger logging.LeveledLogger
|
logger logging.LeveledLogger
|
||||||
@ -125,7 +124,12 @@ type udpConn struct {
|
|||||||
address wgaddr.Address
|
address wgaddr.Address
|
||||||
}
|
}
|
||||||
|
|
||||||
func (u *udpConn) WriteTo(b []byte, addr net.Addr) (int, error) {
|
// GetPacketConn returns the underlying PacketConn
|
||||||
|
func (u *UDPConn) GetPacketConn() net.PacketConn {
|
||||||
|
return u.PacketConn
|
||||||
|
}
|
||||||
|
|
||||||
|
func (u *UDPConn) WriteTo(b []byte, addr net.Addr) (int, error) {
|
||||||
if u.filterFn == nil {
|
if u.filterFn == nil {
|
||||||
return u.PacketConn.WriteTo(b, addr)
|
return u.PacketConn.WriteTo(b, addr)
|
||||||
}
|
}
|
||||||
@ -137,21 +141,21 @@ func (u *udpConn) WriteTo(b []byte, addr net.Addr) (int, error) {
|
|||||||
return u.handleUncachedAddress(b, addr)
|
return u.handleUncachedAddress(b, addr)
|
||||||
}
|
}
|
||||||
|
|
||||||
func (u *udpConn) handleCachedAddress(isRouted bool, b []byte, addr net.Addr) (int, error) {
|
func (u *UDPConn) handleCachedAddress(isRouted bool, b []byte, addr net.Addr) (int, error) {
|
||||||
if isRouted {
|
if isRouted {
|
||||||
return 0, fmt.Errorf("address %s is part of a routed network, refusing to write", addr)
|
return 0, fmt.Errorf("address %s is part of a routed network, refusing to write", addr)
|
||||||
}
|
}
|
||||||
return u.PacketConn.WriteTo(b, addr)
|
return u.PacketConn.WriteTo(b, addr)
|
||||||
}
|
}
|
||||||
|
|
||||||
func (u *udpConn) handleUncachedAddress(b []byte, addr net.Addr) (int, error) {
|
func (u *UDPConn) handleUncachedAddress(b []byte, addr net.Addr) (int, error) {
|
||||||
if err := u.performFilterCheck(addr); err != nil {
|
if err := u.performFilterCheck(addr); err != nil {
|
||||||
return 0, err
|
return 0, err
|
||||||
}
|
}
|
||||||
return u.PacketConn.WriteTo(b, addr)
|
return u.PacketConn.WriteTo(b, addr)
|
||||||
}
|
}
|
||||||
|
|
||||||
func (u *udpConn) performFilterCheck(addr net.Addr) error {
|
func (u *UDPConn) performFilterCheck(addr net.Addr) error {
|
||||||
host, err := getHostFromAddr(addr)
|
host, err := getHostFromAddr(addr)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
log.Errorf("Failed to get host from address %s: %v", addr, err)
|
log.Errorf("Failed to get host from address %s: %v", addr, err)
|
||||||
|
@ -61,7 +61,6 @@ import (
|
|||||||
signal "github.com/netbirdio/netbird/signal/client"
|
signal "github.com/netbirdio/netbird/signal/client"
|
||||||
sProto "github.com/netbirdio/netbird/signal/proto"
|
sProto "github.com/netbirdio/netbird/signal/proto"
|
||||||
"github.com/netbirdio/netbird/util"
|
"github.com/netbirdio/netbird/util"
|
||||||
nbnet "github.com/netbirdio/netbird/util/net"
|
|
||||||
)
|
)
|
||||||
|
|
||||||
// PeerConnectionTimeoutMax is a timeout of an initial connection attempt to a remote peer.
|
// PeerConnectionTimeoutMax is a timeout of an initial connection attempt to a remote peer.
|
||||||
@ -138,9 +137,6 @@ type Engine struct {
|
|||||||
|
|
||||||
connMgr *ConnMgr
|
connMgr *ConnMgr
|
||||||
|
|
||||||
beforePeerHook nbnet.AddHookFunc
|
|
||||||
afterPeerHook nbnet.RemoveHookFunc
|
|
||||||
|
|
||||||
// rpManager is a Rosenpass manager
|
// rpManager is a Rosenpass manager
|
||||||
rpManager *rosenpass.Manager
|
rpManager *rosenpass.Manager
|
||||||
|
|
||||||
@ -409,12 +405,8 @@ func (e *Engine) Start() error {
|
|||||||
DisableClientRoutes: e.config.DisableClientRoutes,
|
DisableClientRoutes: e.config.DisableClientRoutes,
|
||||||
DisableServerRoutes: e.config.DisableServerRoutes,
|
DisableServerRoutes: e.config.DisableServerRoutes,
|
||||||
})
|
})
|
||||||
beforePeerHook, afterPeerHook, err := e.routeManager.Init()
|
if err := e.routeManager.Init(); err != nil {
|
||||||
if err != nil {
|
|
||||||
log.Errorf("Failed to initialize route manager: %s", err)
|
log.Errorf("Failed to initialize route manager: %s", err)
|
||||||
} else {
|
|
||||||
e.beforePeerHook = beforePeerHook
|
|
||||||
e.afterPeerHook = afterPeerHook
|
|
||||||
}
|
}
|
||||||
|
|
||||||
e.routeManager.SetRouteChangeListener(e.mobileDep.NetworkChangeListener)
|
e.routeManager.SetRouteChangeListener(e.mobileDep.NetworkChangeListener)
|
||||||
@ -1261,10 +1253,6 @@ func (e *Engine) addNewPeer(peerConfig *mgmProto.RemotePeerConfig) error {
|
|||||||
return fmt.Errorf("peer already exists: %s", peerKey)
|
return fmt.Errorf("peer already exists: %s", peerKey)
|
||||||
}
|
}
|
||||||
|
|
||||||
if e.beforePeerHook != nil && e.afterPeerHook != nil {
|
|
||||||
conn.AddBeforeAddPeerHook(e.beforePeerHook)
|
|
||||||
conn.AddAfterRemovePeerHook(e.afterPeerHook)
|
|
||||||
}
|
|
||||||
return nil
|
return nil
|
||||||
}
|
}
|
||||||
|
|
||||||
|
@ -400,7 +400,7 @@ func TestEngine_UpdateNetworkMap(t *testing.T) {
|
|||||||
StatusRecorder: engine.statusRecorder,
|
StatusRecorder: engine.statusRecorder,
|
||||||
RelayManager: relayMgr,
|
RelayManager: relayMgr,
|
||||||
})
|
})
|
||||||
_, _, err = engine.routeManager.Init()
|
err = engine.routeManager.Init()
|
||||||
require.NoError(t, err)
|
require.NoError(t, err)
|
||||||
engine.dnsServer = &dns.MockServer{
|
engine.dnsServer = &dns.MockServer{
|
||||||
UpdateDNSServerFunc: func(serial uint64, update nbdns.Config) error { return nil },
|
UpdateDNSServerFunc: func(serial uint64, update nbdns.Config) error { return nil },
|
||||||
|
@ -26,7 +26,6 @@ import (
|
|||||||
"github.com/netbirdio/netbird/client/internal/stdnet"
|
"github.com/netbirdio/netbird/client/internal/stdnet"
|
||||||
relayClient "github.com/netbirdio/netbird/relay/client"
|
relayClient "github.com/netbirdio/netbird/relay/client"
|
||||||
"github.com/netbirdio/netbird/route"
|
"github.com/netbirdio/netbird/route"
|
||||||
nbnet "github.com/netbirdio/netbird/util/net"
|
|
||||||
semaphoregroup "github.com/netbirdio/netbird/util/semaphore-group"
|
semaphoregroup "github.com/netbirdio/netbird/util/semaphore-group"
|
||||||
)
|
)
|
||||||
|
|
||||||
@ -106,10 +105,6 @@ type Conn struct {
|
|||||||
workerRelay *WorkerRelay
|
workerRelay *WorkerRelay
|
||||||
wgWatcherWg sync.WaitGroup
|
wgWatcherWg sync.WaitGroup
|
||||||
|
|
||||||
connIDRelay nbnet.ConnectionID
|
|
||||||
connIDICE nbnet.ConnectionID
|
|
||||||
beforeAddPeerHooks []nbnet.AddHookFunc
|
|
||||||
afterRemovePeerHooks []nbnet.RemoveHookFunc
|
|
||||||
// used to store the remote Rosenpass key for Relayed connection in case of connection update from ice
|
// used to store the remote Rosenpass key for Relayed connection in case of connection update from ice
|
||||||
rosenpassRemoteKey []byte
|
rosenpassRemoteKey []byte
|
||||||
|
|
||||||
@ -267,8 +262,6 @@ func (conn *Conn) Close(signalToRemote bool) {
|
|||||||
conn.Log.Errorf("failed to remove wg endpoint: %v", err)
|
conn.Log.Errorf("failed to remove wg endpoint: %v", err)
|
||||||
}
|
}
|
||||||
|
|
||||||
conn.freeUpConnID()
|
|
||||||
|
|
||||||
if conn.evalStatus() == StatusConnected && conn.onDisconnected != nil {
|
if conn.evalStatus() == StatusConnected && conn.onDisconnected != nil {
|
||||||
conn.onDisconnected(conn.config.WgConfig.RemoteKey)
|
conn.onDisconnected(conn.config.WgConfig.RemoteKey)
|
||||||
}
|
}
|
||||||
@ -293,13 +286,6 @@ func (conn *Conn) OnRemoteCandidate(candidate ice.Candidate, haRoutes route.HAMa
|
|||||||
conn.workerICE.OnRemoteCandidate(candidate, haRoutes)
|
conn.workerICE.OnRemoteCandidate(candidate, haRoutes)
|
||||||
}
|
}
|
||||||
|
|
||||||
func (conn *Conn) AddBeforeAddPeerHook(hook nbnet.AddHookFunc) {
|
|
||||||
conn.beforeAddPeerHooks = append(conn.beforeAddPeerHooks, hook)
|
|
||||||
}
|
|
||||||
func (conn *Conn) AddAfterRemovePeerHook(hook nbnet.RemoveHookFunc) {
|
|
||||||
conn.afterRemovePeerHooks = append(conn.afterRemovePeerHooks, hook)
|
|
||||||
}
|
|
||||||
|
|
||||||
// SetOnConnected sets a handler function to be triggered by Conn when a new connection to a remote peer established
|
// SetOnConnected sets a handler function to be triggered by Conn when a new connection to a remote peer established
|
||||||
func (conn *Conn) SetOnConnected(handler func(remoteWireGuardKey string, remoteRosenpassPubKey []byte, wireGuardIP string, remoteRosenpassAddr string)) {
|
func (conn *Conn) SetOnConnected(handler func(remoteWireGuardKey string, remoteRosenpassPubKey []byte, wireGuardIP string, remoteRosenpassAddr string)) {
|
||||||
conn.onConnected = handler
|
conn.onConnected = handler
|
||||||
@ -387,10 +373,6 @@ func (conn *Conn) onICEConnectionIsReady(priority conntype.ConnPriority, iceConn
|
|||||||
ep = directEp
|
ep = directEp
|
||||||
}
|
}
|
||||||
|
|
||||||
if err := conn.runBeforeAddPeerHooks(ep.IP); err != nil {
|
|
||||||
conn.Log.Errorf("Before add peer hook failed: %v", err)
|
|
||||||
}
|
|
||||||
|
|
||||||
conn.workerRelay.DisableWgWatcher()
|
conn.workerRelay.DisableWgWatcher()
|
||||||
// todo consider to run conn.wgWatcherWg.Wait() here
|
// todo consider to run conn.wgWatcherWg.Wait() here
|
||||||
|
|
||||||
@ -503,10 +485,6 @@ func (conn *Conn) onRelayConnectionIsReady(rci RelayConnInfo) {
|
|||||||
return
|
return
|
||||||
}
|
}
|
||||||
|
|
||||||
if err := conn.runBeforeAddPeerHooks(wgProxy.EndpointAddr().IP); err != nil {
|
|
||||||
conn.Log.Errorf("Before add peer hook failed: %v", err)
|
|
||||||
}
|
|
||||||
|
|
||||||
wgProxy.Work()
|
wgProxy.Work()
|
||||||
if err := conn.configureWGEndpoint(wgProxy.EndpointAddr(), rci.rosenpassPubKey); err != nil {
|
if err := conn.configureWGEndpoint(wgProxy.EndpointAddr(), rci.rosenpassPubKey); err != nil {
|
||||||
if err := wgProxy.CloseConn(); err != nil {
|
if err := wgProxy.CloseConn(); err != nil {
|
||||||
@ -707,36 +685,6 @@ func (conn *Conn) isConnectedOnAllWay() (connected bool) {
|
|||||||
return true
|
return true
|
||||||
}
|
}
|
||||||
|
|
||||||
func (conn *Conn) runBeforeAddPeerHooks(ip net.IP) error {
|
|
||||||
conn.connIDICE = nbnet.GenerateConnID()
|
|
||||||
for _, hook := range conn.beforeAddPeerHooks {
|
|
||||||
if err := hook(conn.connIDICE, ip); err != nil {
|
|
||||||
return err
|
|
||||||
}
|
|
||||||
}
|
|
||||||
return nil
|
|
||||||
}
|
|
||||||
|
|
||||||
func (conn *Conn) freeUpConnID() {
|
|
||||||
if conn.connIDRelay != "" {
|
|
||||||
for _, hook := range conn.afterRemovePeerHooks {
|
|
||||||
if err := hook(conn.connIDRelay); err != nil {
|
|
||||||
conn.Log.Errorf("After remove peer hook failed: %v", err)
|
|
||||||
}
|
|
||||||
}
|
|
||||||
conn.connIDRelay = ""
|
|
||||||
}
|
|
||||||
|
|
||||||
if conn.connIDICE != "" {
|
|
||||||
for _, hook := range conn.afterRemovePeerHooks {
|
|
||||||
if err := hook(conn.connIDICE); err != nil {
|
|
||||||
conn.Log.Errorf("After remove peer hook failed: %v", err)
|
|
||||||
}
|
|
||||||
}
|
|
||||||
conn.connIDICE = ""
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
func (conn *Conn) newProxy(remoteConn net.Conn) (wgproxy.Proxy, error) {
|
func (conn *Conn) newProxy(remoteConn net.Conn) (wgproxy.Proxy, error) {
|
||||||
conn.Log.Debugf("setup proxied WireGuard connection")
|
conn.Log.Debugf("setup proxied WireGuard connection")
|
||||||
udpAddr := &net.UDPAddr{
|
udpAddr := &net.UDPAddr{
|
||||||
|
@ -812,7 +812,7 @@ func TestGetBestrouteFromStatuses(t *testing.T) {
|
|||||||
}
|
}
|
||||||
|
|
||||||
params := common.HandlerParams{
|
params := common.HandlerParams{
|
||||||
Route: &route.Route{Network: netip.MustParsePrefix("192.168.0.0/24")},
|
Route: &route.Route{Network: netip.MustParsePrefix("192.168.0.0/24")},
|
||||||
}
|
}
|
||||||
// create new clientNetwork
|
// create new clientNetwork
|
||||||
client := &Watcher{
|
client := &Watcher{
|
||||||
|
@ -44,7 +44,7 @@ import (
|
|||||||
|
|
||||||
// Manager is a route manager interface
|
// Manager is a route manager interface
|
||||||
type Manager interface {
|
type Manager interface {
|
||||||
Init() (nbnet.AddHookFunc, nbnet.RemoveHookFunc, error)
|
Init() error
|
||||||
UpdateRoutes(updateSerial uint64, serverRoutes map[route.ID]*route.Route, clientRoutes route.HAMap, useNewDNSRoute bool) error
|
UpdateRoutes(updateSerial uint64, serverRoutes map[route.ID]*route.Route, clientRoutes route.HAMap, useNewDNSRoute bool) error
|
||||||
ClassifyRoutes(newRoutes []*route.Route) (map[route.ID]*route.Route, route.HAMap)
|
ClassifyRoutes(newRoutes []*route.Route) (map[route.ID]*route.Route, route.HAMap)
|
||||||
TriggerSelection(route.HAMap)
|
TriggerSelection(route.HAMap)
|
||||||
@ -201,11 +201,11 @@ func (m *DefaultManager) setupRefCounters(useNoop bool) {
|
|||||||
}
|
}
|
||||||
|
|
||||||
// Init sets up the routing
|
// Init sets up the routing
|
||||||
func (m *DefaultManager) Init() (nbnet.AddHookFunc, nbnet.RemoveHookFunc, error) {
|
func (m *DefaultManager) Init() error {
|
||||||
m.routeSelector = m.initSelector()
|
m.routeSelector = m.initSelector()
|
||||||
|
|
||||||
if nbnet.CustomRoutingDisabled() || m.disableClientRoutes {
|
if nbnet.CustomRoutingDisabled() || m.disableClientRoutes {
|
||||||
return nil, nil, nil
|
return nil
|
||||||
}
|
}
|
||||||
|
|
||||||
if err := m.sysOps.CleanupRouting(nil); err != nil {
|
if err := m.sysOps.CleanupRouting(nil); err != nil {
|
||||||
@ -219,13 +219,12 @@ func (m *DefaultManager) Init() (nbnet.AddHookFunc, nbnet.RemoveHookFunc, error)
|
|||||||
|
|
||||||
ips := resolveURLsToIPs(initialAddresses)
|
ips := resolveURLsToIPs(initialAddresses)
|
||||||
|
|
||||||
beforePeerHook, afterPeerHook, err := m.sysOps.SetupRouting(ips, m.stateManager)
|
if err := m.sysOps.SetupRouting(ips, m.stateManager); err != nil {
|
||||||
if err != nil {
|
return fmt.Errorf("setup routing: %w", err)
|
||||||
return nil, nil, fmt.Errorf("setup routing: %w", err)
|
|
||||||
}
|
}
|
||||||
|
|
||||||
log.Info("Routing setup complete")
|
log.Info("Routing setup complete")
|
||||||
return beforePeerHook, afterPeerHook, nil
|
return nil
|
||||||
}
|
}
|
||||||
|
|
||||||
func (m *DefaultManager) initSelector() *routeselector.RouteSelector {
|
func (m *DefaultManager) initSelector() *routeselector.RouteSelector {
|
||||||
|
@ -430,7 +430,7 @@ func TestManagerUpdateRoutes(t *testing.T) {
|
|||||||
StatusRecorder: statusRecorder,
|
StatusRecorder: statusRecorder,
|
||||||
})
|
})
|
||||||
|
|
||||||
_, _, err = routeManager.Init()
|
err = routeManager.Init()
|
||||||
|
|
||||||
require.NoError(t, err, "should init route manager")
|
require.NoError(t, err, "should init route manager")
|
||||||
defer routeManager.Stop(nil)
|
defer routeManager.Stop(nil)
|
||||||
|
@ -9,7 +9,6 @@ import (
|
|||||||
"github.com/netbirdio/netbird/client/internal/routeselector"
|
"github.com/netbirdio/netbird/client/internal/routeselector"
|
||||||
"github.com/netbirdio/netbird/client/internal/statemanager"
|
"github.com/netbirdio/netbird/client/internal/statemanager"
|
||||||
"github.com/netbirdio/netbird/route"
|
"github.com/netbirdio/netbird/route"
|
||||||
"github.com/netbirdio/netbird/util/net"
|
|
||||||
)
|
)
|
||||||
|
|
||||||
// MockManager is the mock instance of a route manager
|
// MockManager is the mock instance of a route manager
|
||||||
@ -23,8 +22,8 @@ type MockManager struct {
|
|||||||
StopFunc func(manager *statemanager.Manager)
|
StopFunc func(manager *statemanager.Manager)
|
||||||
}
|
}
|
||||||
|
|
||||||
func (m *MockManager) Init() (net.AddHookFunc, net.RemoveHookFunc, error) {
|
func (m *MockManager) Init() error {
|
||||||
return nil, nil, nil
|
return nil
|
||||||
}
|
}
|
||||||
|
|
||||||
// InitialRouteRange mock implementation of InitialRouteRange from Manager interface
|
// InitialRouteRange mock implementation of InitialRouteRange from Manager interface
|
||||||
|
@ -33,4 +33,4 @@ func (n *Notifier) OnNewPrefixes(prefixes []netip.Prefix) {
|
|||||||
|
|
||||||
func (n *Notifier) GetInitialRouteRanges() []string {
|
func (n *Notifier) GetInitialRouteRanges() []string {
|
||||||
return []string{}
|
return []string{}
|
||||||
}
|
}
|
||||||
|
@ -6,6 +6,7 @@ import (
|
|||||||
"net/netip"
|
"net/netip"
|
||||||
"sync"
|
"sync"
|
||||||
"sync/atomic"
|
"sync/atomic"
|
||||||
|
"time"
|
||||||
|
|
||||||
"github.com/netbirdio/netbird/client/iface/wgaddr"
|
"github.com/netbirdio/netbird/client/iface/wgaddr"
|
||||||
"github.com/netbirdio/netbird/client/internal/routemanager/notifier"
|
"github.com/netbirdio/netbird/client/internal/routemanager/notifier"
|
||||||
@ -56,6 +57,10 @@ type SysOps struct {
|
|||||||
// seq is an atomic counter for generating unique sequence numbers for route messages
|
// seq is an atomic counter for generating unique sequence numbers for route messages
|
||||||
//nolint:unused // only used on BSD systems
|
//nolint:unused // only used on BSD systems
|
||||||
seq atomic.Uint32
|
seq atomic.Uint32
|
||||||
|
|
||||||
|
localSubnetsCache []*net.IPNet
|
||||||
|
localSubnetsCacheMu sync.RWMutex
|
||||||
|
localSubnetsCacheTime time.Time
|
||||||
}
|
}
|
||||||
|
|
||||||
func NewSysOps(wgInterface wgIface, notifier *notifier.Notifier) *SysOps {
|
func NewSysOps(wgInterface wgIface, notifier *notifier.Notifier) *SysOps {
|
||||||
|
@ -10,11 +10,10 @@ import (
|
|||||||
log "github.com/sirupsen/logrus"
|
log "github.com/sirupsen/logrus"
|
||||||
|
|
||||||
"github.com/netbirdio/netbird/client/internal/statemanager"
|
"github.com/netbirdio/netbird/client/internal/statemanager"
|
||||||
nbnet "github.com/netbirdio/netbird/util/net"
|
|
||||||
)
|
)
|
||||||
|
|
||||||
func (r *SysOps) SetupRouting([]net.IP, *statemanager.Manager) (nbnet.AddHookFunc, nbnet.RemoveHookFunc, error) {
|
func (r *SysOps) SetupRouting([]net.IP, *statemanager.Manager) error {
|
||||||
return nil, nil, nil
|
return nil
|
||||||
}
|
}
|
||||||
|
|
||||||
func (r *SysOps) CleanupRouting(*statemanager.Manager) error {
|
func (r *SysOps) CleanupRouting(*statemanager.Manager) error {
|
||||||
|
@ -10,6 +10,7 @@ import (
|
|||||||
"net/netip"
|
"net/netip"
|
||||||
"runtime"
|
"runtime"
|
||||||
"strconv"
|
"strconv"
|
||||||
|
"time"
|
||||||
|
|
||||||
"github.com/hashicorp/go-multierror"
|
"github.com/hashicorp/go-multierror"
|
||||||
"github.com/libp2p/go-netroute"
|
"github.com/libp2p/go-netroute"
|
||||||
@ -24,6 +25,8 @@ import (
|
|||||||
nbnet "github.com/netbirdio/netbird/util/net"
|
nbnet "github.com/netbirdio/netbird/util/net"
|
||||||
)
|
)
|
||||||
|
|
||||||
|
const localSubnetsCacheTTL = 15 * time.Minute
|
||||||
|
|
||||||
var splitDefaultv4_1 = netip.PrefixFrom(netip.IPv4Unspecified(), 1)
|
var splitDefaultv4_1 = netip.PrefixFrom(netip.IPv4Unspecified(), 1)
|
||||||
var splitDefaultv4_2 = netip.PrefixFrom(netip.AddrFrom4([4]byte{128}), 1)
|
var splitDefaultv4_2 = netip.PrefixFrom(netip.AddrFrom4([4]byte{128}), 1)
|
||||||
var splitDefaultv6_1 = netip.PrefixFrom(netip.IPv6Unspecified(), 1)
|
var splitDefaultv6_1 = netip.PrefixFrom(netip.IPv6Unspecified(), 1)
|
||||||
@ -31,7 +34,7 @@ var splitDefaultv6_2 = netip.PrefixFrom(netip.AddrFrom16([16]byte{0x80}), 1)
|
|||||||
|
|
||||||
var ErrRoutingIsSeparate = errors.New("routing is separate")
|
var ErrRoutingIsSeparate = errors.New("routing is separate")
|
||||||
|
|
||||||
func (r *SysOps) setupRefCounter(initAddresses []net.IP, stateManager *statemanager.Manager) (nbnet.AddHookFunc, nbnet.RemoveHookFunc, error) {
|
func (r *SysOps) setupRefCounter(initAddresses []net.IP, stateManager *statemanager.Manager) error {
|
||||||
stateManager.RegisterState(&ShutdownState{})
|
stateManager.RegisterState(&ShutdownState{})
|
||||||
|
|
||||||
initialNextHopV4, err := GetNextHop(netip.IPv4Unspecified())
|
initialNextHopV4, err := GetNextHop(netip.IPv4Unspecified())
|
||||||
@ -75,7 +78,10 @@ func (r *SysOps) setupRefCounter(initAddresses []net.IP, stateManager *statemana
|
|||||||
|
|
||||||
r.refCounter = refCounter
|
r.refCounter = refCounter
|
||||||
|
|
||||||
return r.setupHooks(initAddresses, stateManager)
|
if err := r.setupHooks(initAddresses, stateManager); err != nil {
|
||||||
|
return fmt.Errorf("setup hooks: %w", err)
|
||||||
|
}
|
||||||
|
return nil
|
||||||
}
|
}
|
||||||
|
|
||||||
// updateState updates state on every change so it will be persisted regularly
|
// updateState updates state on every change so it will be persisted regularly
|
||||||
@ -128,18 +134,14 @@ func (r *SysOps) addRouteToNonVPNIntf(prefix netip.Prefix, vpnIntf wgIface, init
|
|||||||
return Nexthop{}, fmt.Errorf("get next hop: %w", err)
|
return Nexthop{}, fmt.Errorf("get next hop: %w", err)
|
||||||
}
|
}
|
||||||
|
|
||||||
log.Debugf("Found next hop %s for prefix %s with interface %v", nexthop.IP, prefix, nexthop.IP)
|
log.Debugf("Found next hop %s for prefix %s with interface %v", nexthop.IP, prefix, nexthop.Intf)
|
||||||
exitNextHop := Nexthop{
|
exitNextHop := nexthop
|
||||||
IP: nexthop.IP,
|
|
||||||
Intf: nexthop.Intf,
|
|
||||||
}
|
|
||||||
|
|
||||||
vpnAddr := vpnIntf.Address().IP
|
vpnAddr := vpnIntf.Address().IP
|
||||||
|
|
||||||
// if next hop is the VPN address or the interface is the VPN interface, we should use the initial values
|
// if next hop is the VPN address or the interface is the VPN interface, we should use the initial values
|
||||||
if exitNextHop.IP == vpnAddr || exitNextHop.Intf != nil && exitNextHop.Intf.Name == vpnIntf.Name() {
|
if exitNextHop.IP == vpnAddr || exitNextHop.Intf != nil && exitNextHop.Intf.Name == vpnIntf.Name() {
|
||||||
log.Debugf("Route for prefix %s is pointing to the VPN interface, using initial next hop %v", prefix, initialNextHop)
|
log.Debugf("Route for prefix %s is pointing to the VPN interface, using initial next hop %v", prefix, initialNextHop)
|
||||||
|
|
||||||
exitNextHop = initialNextHop
|
exitNextHop = initialNextHop
|
||||||
}
|
}
|
||||||
|
|
||||||
@ -152,12 +154,37 @@ func (r *SysOps) addRouteToNonVPNIntf(prefix netip.Prefix, vpnIntf wgIface, init
|
|||||||
}
|
}
|
||||||
|
|
||||||
func (r *SysOps) isPrefixInLocalSubnets(prefix netip.Prefix) (bool, *net.IPNet) {
|
func (r *SysOps) isPrefixInLocalSubnets(prefix netip.Prefix) (bool, *net.IPNet) {
|
||||||
|
r.localSubnetsCacheMu.RLock()
|
||||||
|
cacheAge := time.Since(r.localSubnetsCacheTime)
|
||||||
|
subnets := r.localSubnetsCache
|
||||||
|
r.localSubnetsCacheMu.RUnlock()
|
||||||
|
|
||||||
|
if cacheAge > localSubnetsCacheTTL || subnets == nil {
|
||||||
|
r.localSubnetsCacheMu.Lock()
|
||||||
|
if time.Since(r.localSubnetsCacheTime) > localSubnetsCacheTTL || r.localSubnetsCache == nil {
|
||||||
|
r.refreshLocalSubnetsCache()
|
||||||
|
}
|
||||||
|
subnets = r.localSubnetsCache
|
||||||
|
r.localSubnetsCacheMu.Unlock()
|
||||||
|
}
|
||||||
|
|
||||||
|
for _, subnet := range subnets {
|
||||||
|
if subnet.Contains(prefix.Addr().AsSlice()) {
|
||||||
|
return true, subnet
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
return false, nil
|
||||||
|
}
|
||||||
|
|
||||||
|
func (r *SysOps) refreshLocalSubnetsCache() {
|
||||||
localInterfaces, err := net.Interfaces()
|
localInterfaces, err := net.Interfaces()
|
||||||
if err != nil {
|
if err != nil {
|
||||||
log.Errorf("Failed to get local interfaces: %v", err)
|
log.Errorf("Failed to get local interfaces: %v", err)
|
||||||
return false, nil
|
return
|
||||||
}
|
}
|
||||||
|
|
||||||
|
var newSubnets []*net.IPNet
|
||||||
for _, intf := range localInterfaces {
|
for _, intf := range localInterfaces {
|
||||||
addrs, err := intf.Addrs()
|
addrs, err := intf.Addrs()
|
||||||
if err != nil {
|
if err != nil {
|
||||||
@ -171,14 +198,12 @@ func (r *SysOps) isPrefixInLocalSubnets(prefix netip.Prefix) (bool, *net.IPNet)
|
|||||||
log.Errorf("Failed to convert address to IPNet: %v", addr)
|
log.Errorf("Failed to convert address to IPNet: %v", addr)
|
||||||
continue
|
continue
|
||||||
}
|
}
|
||||||
|
newSubnets = append(newSubnets, ipnet)
|
||||||
if ipnet.Contains(prefix.Addr().AsSlice()) {
|
|
||||||
return true, ipnet
|
|
||||||
}
|
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
return false, nil
|
r.localSubnetsCache = newSubnets
|
||||||
|
r.localSubnetsCacheTime = time.Now()
|
||||||
}
|
}
|
||||||
|
|
||||||
// genericAddVPNRoute adds a new route to the vpn interface, it splits the default prefix
|
// genericAddVPNRoute adds a new route to the vpn interface, it splits the default prefix
|
||||||
@ -264,7 +289,7 @@ func (r *SysOps) genericRemoveVPNRoute(prefix netip.Prefix, intf *net.Interface)
|
|||||||
return r.removeFromRouteTable(prefix, nextHop)
|
return r.removeFromRouteTable(prefix, nextHop)
|
||||||
}
|
}
|
||||||
|
|
||||||
func (r *SysOps) setupHooks(initAddresses []net.IP, stateManager *statemanager.Manager) (nbnet.AddHookFunc, nbnet.RemoveHookFunc, error) {
|
func (r *SysOps) setupHooks(initAddresses []net.IP, stateManager *statemanager.Manager) error {
|
||||||
beforeHook := func(connID nbnet.ConnectionID, ip net.IP) error {
|
beforeHook := func(connID nbnet.ConnectionID, ip net.IP) error {
|
||||||
prefix, err := util.GetPrefixFromIP(ip)
|
prefix, err := util.GetPrefixFromIP(ip)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
@ -289,9 +314,11 @@ func (r *SysOps) setupHooks(initAddresses []net.IP, stateManager *statemanager.M
|
|||||||
return nil
|
return nil
|
||||||
}
|
}
|
||||||
|
|
||||||
|
var merr *multierror.Error
|
||||||
|
|
||||||
for _, ip := range initAddresses {
|
for _, ip := range initAddresses {
|
||||||
if err := beforeHook("init", ip); err != nil {
|
if err := beforeHook("init", ip); err != nil {
|
||||||
log.Errorf("Failed to add route reference: %v", err)
|
merr = multierror.Append(merr, fmt.Errorf("add initial route for %s: %w", ip, err))
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
@ -300,11 +327,11 @@ func (r *SysOps) setupHooks(initAddresses []net.IP, stateManager *statemanager.M
|
|||||||
return ctx.Err()
|
return ctx.Err()
|
||||||
}
|
}
|
||||||
|
|
||||||
var result *multierror.Error
|
var merr *multierror.Error
|
||||||
for _, ip := range resolvedIPs {
|
for _, ip := range resolvedIPs {
|
||||||
result = multierror.Append(result, beforeHook(connID, ip.IP))
|
merr = multierror.Append(merr, beforeHook(connID, ip.IP))
|
||||||
}
|
}
|
||||||
return nberrors.FormatErrorOrNil(result)
|
return nberrors.FormatErrorOrNil(merr)
|
||||||
})
|
})
|
||||||
|
|
||||||
nbnet.AddDialerCloseHook(func(connID nbnet.ConnectionID, conn *net.Conn) error {
|
nbnet.AddDialerCloseHook(func(connID nbnet.ConnectionID, conn *net.Conn) error {
|
||||||
@ -319,7 +346,16 @@ func (r *SysOps) setupHooks(initAddresses []net.IP, stateManager *statemanager.M
|
|||||||
return afterHook(connID)
|
return afterHook(connID)
|
||||||
})
|
})
|
||||||
|
|
||||||
return beforeHook, afterHook, nil
|
nbnet.AddListenerAddressRemoveHook(func(connID nbnet.ConnectionID, prefix netip.Prefix) error {
|
||||||
|
if _, err := r.refCounter.Decrement(prefix); err != nil {
|
||||||
|
return fmt.Errorf("remove route reference: %w", err)
|
||||||
|
}
|
||||||
|
|
||||||
|
r.updateState(stateManager)
|
||||||
|
return nil
|
||||||
|
})
|
||||||
|
|
||||||
|
return nberrors.FormatErrorOrNil(merr)
|
||||||
}
|
}
|
||||||
|
|
||||||
func GetNextHop(ip netip.Addr) (Nexthop, error) {
|
func GetNextHop(ip netip.Addr) (Nexthop, error) {
|
||||||
|
@ -143,7 +143,7 @@ func TestAddVPNRoute(t *testing.T) {
|
|||||||
wgInterface := createWGInterface(t, fmt.Sprintf("utun53%d", n), "100.65.75.2/24", 33100+n)
|
wgInterface := createWGInterface(t, fmt.Sprintf("utun53%d", n), "100.65.75.2/24", 33100+n)
|
||||||
|
|
||||||
r := NewSysOps(wgInterface, nil)
|
r := NewSysOps(wgInterface, nil)
|
||||||
_, _, err := r.SetupRouting(nil, nil)
|
err := r.SetupRouting(nil, nil)
|
||||||
require.NoError(t, err)
|
require.NoError(t, err)
|
||||||
t.Cleanup(func() {
|
t.Cleanup(func() {
|
||||||
assert.NoError(t, r.CleanupRouting(nil))
|
assert.NoError(t, r.CleanupRouting(nil))
|
||||||
@ -341,7 +341,7 @@ func TestAddRouteToNonVPNIntf(t *testing.T) {
|
|||||||
wgInterface := createWGInterface(t, fmt.Sprintf("utun54%d", n), "100.65.75.2/24", 33200+n)
|
wgInterface := createWGInterface(t, fmt.Sprintf("utun54%d", n), "100.65.75.2/24", 33200+n)
|
||||||
|
|
||||||
r := NewSysOps(wgInterface, nil)
|
r := NewSysOps(wgInterface, nil)
|
||||||
_, _, err := r.SetupRouting(nil, nil)
|
err := r.SetupRouting(nil, nil)
|
||||||
require.NoError(t, err)
|
require.NoError(t, err)
|
||||||
t.Cleanup(func() {
|
t.Cleanup(func() {
|
||||||
assert.NoError(t, r.CleanupRouting(nil))
|
assert.NoError(t, r.CleanupRouting(nil))
|
||||||
@ -484,7 +484,7 @@ func setupTestEnv(t *testing.T) {
|
|||||||
})
|
})
|
||||||
|
|
||||||
r := NewSysOps(wgInterface, nil)
|
r := NewSysOps(wgInterface, nil)
|
||||||
_, _, err := r.SetupRouting(nil, nil)
|
err := r.SetupRouting(nil, nil)
|
||||||
require.NoError(t, err, "setupRouting should not return err")
|
require.NoError(t, err, "setupRouting should not return err")
|
||||||
t.Cleanup(func() {
|
t.Cleanup(func() {
|
||||||
assert.NoError(t, r.CleanupRouting(nil))
|
assert.NoError(t, r.CleanupRouting(nil))
|
||||||
|
@ -10,14 +10,13 @@ import (
|
|||||||
log "github.com/sirupsen/logrus"
|
log "github.com/sirupsen/logrus"
|
||||||
|
|
||||||
"github.com/netbirdio/netbird/client/internal/statemanager"
|
"github.com/netbirdio/netbird/client/internal/statemanager"
|
||||||
nbnet "github.com/netbirdio/netbird/util/net"
|
|
||||||
)
|
)
|
||||||
|
|
||||||
func (r *SysOps) SetupRouting([]net.IP, *statemanager.Manager) (nbnet.AddHookFunc, nbnet.RemoveHookFunc, error) {
|
func (r *SysOps) SetupRouting([]net.IP, *statemanager.Manager) error {
|
||||||
r.mu.Lock()
|
r.mu.Lock()
|
||||||
defer r.mu.Unlock()
|
defer r.mu.Unlock()
|
||||||
r.prefixes = make(map[netip.Prefix]struct{})
|
r.prefixes = make(map[netip.Prefix]struct{})
|
||||||
return nil, nil, nil
|
return nil
|
||||||
}
|
}
|
||||||
|
|
||||||
func (r *SysOps) CleanupRouting(*statemanager.Manager) error {
|
func (r *SysOps) CleanupRouting(*statemanager.Manager) error {
|
||||||
|
@ -72,7 +72,7 @@ func getSetupRules() []ruleParams {
|
|||||||
// Rule 2 (VPN Traffic Routing): Directs all remaining traffic to the 'NetbirdVPNTableID' custom routing table.
|
// Rule 2 (VPN Traffic Routing): Directs all remaining traffic to the 'NetbirdVPNTableID' custom routing table.
|
||||||
// This table is where a default route or other specific routes received from the management server are configured,
|
// This table is where a default route or other specific routes received from the management server are configured,
|
||||||
// enabling VPN connectivity.
|
// enabling VPN connectivity.
|
||||||
func (r *SysOps) SetupRouting(initAddresses []net.IP, stateManager *statemanager.Manager) (_ nbnet.AddHookFunc, _ nbnet.RemoveHookFunc, err error) {
|
func (r *SysOps) SetupRouting(initAddresses []net.IP, stateManager *statemanager.Manager) (err error) {
|
||||||
if !nbnet.AdvancedRouting() {
|
if !nbnet.AdvancedRouting() {
|
||||||
log.Infof("Using legacy routing setup")
|
log.Infof("Using legacy routing setup")
|
||||||
return r.setupRefCounter(initAddresses, stateManager)
|
return r.setupRefCounter(initAddresses, stateManager)
|
||||||
@ -89,7 +89,7 @@ func (r *SysOps) SetupRouting(initAddresses []net.IP, stateManager *statemanager
|
|||||||
rules := getSetupRules()
|
rules := getSetupRules()
|
||||||
for _, rule := range rules {
|
for _, rule := range rules {
|
||||||
if err := addRule(rule); err != nil {
|
if err := addRule(rule); err != nil {
|
||||||
return nil, nil, fmt.Errorf("%s: %w", rule.description, err)
|
return fmt.Errorf("%s: %w", rule.description, err)
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
@ -104,7 +104,7 @@ func (r *SysOps) SetupRouting(initAddresses []net.IP, stateManager *statemanager
|
|||||||
}
|
}
|
||||||
originalSysctl = originalValues
|
originalSysctl = originalValues
|
||||||
|
|
||||||
return nil, nil, nil
|
return nil
|
||||||
}
|
}
|
||||||
|
|
||||||
// CleanupRouting performs a thorough cleanup of the routing configuration established by 'setupRouting'.
|
// CleanupRouting performs a thorough cleanup of the routing configuration established by 'setupRouting'.
|
||||||
|
@ -18,10 +18,9 @@ import (
|
|||||||
"golang.org/x/sys/unix"
|
"golang.org/x/sys/unix"
|
||||||
|
|
||||||
"github.com/netbirdio/netbird/client/internal/statemanager"
|
"github.com/netbirdio/netbird/client/internal/statemanager"
|
||||||
nbnet "github.com/netbirdio/netbird/util/net"
|
|
||||||
)
|
)
|
||||||
|
|
||||||
func (r *SysOps) SetupRouting(initAddresses []net.IP, stateManager *statemanager.Manager) (nbnet.AddHookFunc, nbnet.RemoveHookFunc, error) {
|
func (r *SysOps) SetupRouting(initAddresses []net.IP, stateManager *statemanager.Manager) error {
|
||||||
return r.setupRefCounter(initAddresses, stateManager)
|
return r.setupRefCounter(initAddresses, stateManager)
|
||||||
}
|
}
|
||||||
|
|
||||||
|
@ -19,7 +19,6 @@ import (
|
|||||||
"golang.org/x/sys/windows"
|
"golang.org/x/sys/windows"
|
||||||
|
|
||||||
"github.com/netbirdio/netbird/client/internal/statemanager"
|
"github.com/netbirdio/netbird/client/internal/statemanager"
|
||||||
nbnet "github.com/netbirdio/netbird/util/net"
|
|
||||||
)
|
)
|
||||||
|
|
||||||
const InfiniteLifetime = 0xffffffff
|
const InfiniteLifetime = 0xffffffff
|
||||||
@ -137,7 +136,7 @@ const (
|
|||||||
RouteDeleted
|
RouteDeleted
|
||||||
)
|
)
|
||||||
|
|
||||||
func (r *SysOps) SetupRouting(initAddresses []net.IP, stateManager *statemanager.Manager) (nbnet.AddHookFunc, nbnet.RemoveHookFunc, error) {
|
func (r *SysOps) SetupRouting(initAddresses []net.IP, stateManager *statemanager.Manager) error {
|
||||||
return r.setupRefCounter(initAddresses, stateManager)
|
return r.setupRefCounter(initAddresses, stateManager)
|
||||||
}
|
}
|
||||||
|
|
||||||
|
@ -6,6 +6,7 @@ import (
|
|||||||
"context"
|
"context"
|
||||||
"fmt"
|
"fmt"
|
||||||
"net"
|
"net"
|
||||||
|
"net/netip"
|
||||||
"sync"
|
"sync"
|
||||||
|
|
||||||
log "github.com/sirupsen/logrus"
|
log "github.com/sirupsen/logrus"
|
||||||
@ -17,11 +18,16 @@ type ListenerWriteHookFunc func(connID ConnectionID, ip *net.IPAddr, data []byte
|
|||||||
// ListenerCloseHookFunc defines the function signature for close hooks for PacketConn.
|
// ListenerCloseHookFunc defines the function signature for close hooks for PacketConn.
|
||||||
type ListenerCloseHookFunc func(connID ConnectionID, conn net.PacketConn) error
|
type ListenerCloseHookFunc func(connID ConnectionID, conn net.PacketConn) error
|
||||||
|
|
||||||
|
// ListenerAddressRemoveHookFunc defines the function signature for hooks called when addresses are removed.
|
||||||
|
type ListenerAddressRemoveHookFunc func(connID ConnectionID, prefix netip.Prefix) error
|
||||||
|
|
||||||
var (
|
var (
|
||||||
listenerWriteHooksMutex sync.RWMutex
|
listenerWriteHooksMutex sync.RWMutex
|
||||||
listenerWriteHooks []ListenerWriteHookFunc
|
listenerWriteHooks []ListenerWriteHookFunc
|
||||||
listenerCloseHooksMutex sync.RWMutex
|
listenerCloseHooksMutex sync.RWMutex
|
||||||
listenerCloseHooks []ListenerCloseHookFunc
|
listenerCloseHooks []ListenerCloseHookFunc
|
||||||
|
listenerAddressRemoveHooksMutex sync.RWMutex
|
||||||
|
listenerAddressRemoveHooks []ListenerAddressRemoveHookFunc
|
||||||
)
|
)
|
||||||
|
|
||||||
// AddListenerWriteHook allows adding a new write hook to be executed before a UDP packet is sent.
|
// AddListenerWriteHook allows adding a new write hook to be executed before a UDP packet is sent.
|
||||||
@ -38,7 +44,14 @@ func AddListenerCloseHook(hook ListenerCloseHookFunc) {
|
|||||||
listenerCloseHooks = append(listenerCloseHooks, hook)
|
listenerCloseHooks = append(listenerCloseHooks, hook)
|
||||||
}
|
}
|
||||||
|
|
||||||
// RemoveListenerHooks removes all dialer hooks.
|
// AddListenerAddressRemoveHook allows adding a new hook to be executed when an address is removed.
|
||||||
|
func AddListenerAddressRemoveHook(hook ListenerAddressRemoveHookFunc) {
|
||||||
|
listenerAddressRemoveHooksMutex.Lock()
|
||||||
|
defer listenerAddressRemoveHooksMutex.Unlock()
|
||||||
|
listenerAddressRemoveHooks = append(listenerAddressRemoveHooks, hook)
|
||||||
|
}
|
||||||
|
|
||||||
|
// RemoveListenerHooks removes all listener hooks.
|
||||||
func RemoveListenerHooks() {
|
func RemoveListenerHooks() {
|
||||||
listenerWriteHooksMutex.Lock()
|
listenerWriteHooksMutex.Lock()
|
||||||
defer listenerWriteHooksMutex.Unlock()
|
defer listenerWriteHooksMutex.Unlock()
|
||||||
@ -47,6 +60,10 @@ func RemoveListenerHooks() {
|
|||||||
listenerCloseHooksMutex.Lock()
|
listenerCloseHooksMutex.Lock()
|
||||||
defer listenerCloseHooksMutex.Unlock()
|
defer listenerCloseHooksMutex.Unlock()
|
||||||
listenerCloseHooks = nil
|
listenerCloseHooks = nil
|
||||||
|
|
||||||
|
listenerAddressRemoveHooksMutex.Lock()
|
||||||
|
defer listenerAddressRemoveHooksMutex.Unlock()
|
||||||
|
listenerAddressRemoveHooks = nil
|
||||||
}
|
}
|
||||||
|
|
||||||
// ListenPacket listens on the network address and returns a PacketConn
|
// ListenPacket listens on the network address and returns a PacketConn
|
||||||
@ -61,6 +78,7 @@ func (l *ListenerConfig) ListenPacket(ctx context.Context, network, address stri
|
|||||||
return nil, fmt.Errorf("listen packet: %w", err)
|
return nil, fmt.Errorf("listen packet: %w", err)
|
||||||
}
|
}
|
||||||
connID := GenerateConnID()
|
connID := GenerateConnID()
|
||||||
|
|
||||||
return &PacketConn{PacketConn: pc, ID: connID, seenAddrs: &sync.Map{}}, nil
|
return &PacketConn{PacketConn: pc, ID: connID, seenAddrs: &sync.Map{}}, nil
|
||||||
}
|
}
|
||||||
|
|
||||||
@ -102,6 +120,45 @@ func (c *UDPConn) Close() error {
|
|||||||
return closeConn(c.ID, c.UDPConn)
|
return closeConn(c.ID, c.UDPConn)
|
||||||
}
|
}
|
||||||
|
|
||||||
|
// WrapUDPConn wraps an existing *net.UDPConn with nbnet functionality
|
||||||
|
func WrapUDPConn(conn *net.UDPConn) *UDPConn {
|
||||||
|
return &UDPConn{
|
||||||
|
UDPConn: conn,
|
||||||
|
ID: GenerateConnID(),
|
||||||
|
seenAddrs: &sync.Map{},
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
// RemoveAddress removes an address from the seen cache and triggers removal hooks.
|
||||||
|
func (c *UDPConn) RemoveAddress(addr string) {
|
||||||
|
if _, exists := c.seenAddrs.LoadAndDelete(addr); !exists {
|
||||||
|
return
|
||||||
|
}
|
||||||
|
|
||||||
|
ipStr, _, err := net.SplitHostPort(addr)
|
||||||
|
if err != nil {
|
||||||
|
log.Errorf("Error splitting IP address and port: %v", err)
|
||||||
|
return
|
||||||
|
}
|
||||||
|
|
||||||
|
ipAddr, err := netip.ParseAddr(ipStr)
|
||||||
|
if err != nil {
|
||||||
|
log.Errorf("Error parsing IP address %s: %v", ipStr, err)
|
||||||
|
return
|
||||||
|
}
|
||||||
|
|
||||||
|
prefix := netip.PrefixFrom(ipAddr, ipAddr.BitLen())
|
||||||
|
|
||||||
|
listenerAddressRemoveHooksMutex.RLock()
|
||||||
|
defer listenerAddressRemoveHooksMutex.RUnlock()
|
||||||
|
|
||||||
|
for _, hook := range listenerAddressRemoveHooks {
|
||||||
|
if err := hook(c.ID, prefix); err != nil {
|
||||||
|
log.Errorf("Error executing listener address remove hook: %v", err)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
func callWriteHooks(id ConnectionID, seenAddrs *sync.Map, b []byte, addr net.Addr) {
|
func callWriteHooks(id ConnectionID, seenAddrs *sync.Map, b []byte, addr net.Addr) {
|
||||||
// Lookup the address in the seenAddrs map to avoid calling the hooks for every write
|
// Lookup the address in the seenAddrs map to avoid calling the hooks for every write
|
||||||
if _, loaded := seenAddrs.LoadOrStore(addr.String(), true); !loaded {
|
if _, loaded := seenAddrs.LoadOrStore(addr.String(), true); !loaded {
|
||||||
|
10
util/net/listener_listen_ios.go
Normal file
10
util/net/listener_listen_ios.go
Normal file
@ -0,0 +1,10 @@
|
|||||||
|
package net
|
||||||
|
|
||||||
|
import (
|
||||||
|
"net"
|
||||||
|
)
|
||||||
|
|
||||||
|
// WrapUDPConn on iOS just returns the original connection since iOS handles its own networking
|
||||||
|
func WrapUDPConn(conn *net.UDPConn) *net.UDPConn {
|
||||||
|
return conn
|
||||||
|
}
|
Reference in New Issue
Block a user