diff --git a/client/internal/engine.go b/client/internal/engine.go index c36b110d3..ca580bd32 100644 --- a/client/internal/engine.go +++ b/client/internal/engine.go @@ -34,6 +34,7 @@ import ( "github.com/netbirdio/netbird/client/internal/peer" "github.com/netbirdio/netbird/client/internal/peer/guard" icemaker "github.com/netbirdio/netbird/client/internal/peer/ice" + "github.com/netbirdio/netbird/client/internal/peerstore" "github.com/netbirdio/netbird/client/internal/relay" "github.com/netbirdio/netbird/client/internal/rosenpass" "github.com/netbirdio/netbird/client/internal/routemanager" @@ -117,7 +118,7 @@ type Engine struct { // mgmClient is a Management Service client mgmClient mgm.Client // peerConns is a map that holds all the peers that are known to this peer - peerConns map[string]*peer.Conn + peerStore *peerstore.Store beforePeerHook nbnet.AddHookFunc afterPeerHook nbnet.RemoveHookFunc @@ -231,7 +232,7 @@ func NewEngineWithProbes( signaler: peer.NewSignaler(signalClient, config.WgPrivateKey), mgmClient: mgmClient, relayManager: relayManager, - peerConns: make(map[string]*peer.Conn), + peerStore: peerstore.NewConnStore(), syncMsgMux: &sync.Mutex{}, config: config, mobileDep: mobileDep, @@ -383,7 +384,7 @@ func (e *Engine) Start() error { initialRoutes, e.stateManager, dnsServer, - e.peerConns, + e.peerStore, ) beforePeerHook, afterPeerHook, err := e.routeManager.Init() if err != nil { @@ -462,8 +463,8 @@ func (e *Engine) modifyPeers(peersUpdate []*mgmProto.RemotePeerConfig) error { var modified []*mgmProto.RemotePeerConfig for _, p := range peersUpdate { peerPubKey := p.GetWgPubKey() - if peerConn, ok := e.peerConns[peerPubKey]; ok { - if peerConn.WgConfig().AllowedIps != strings.Join(p.AllowedIps, ",") { + if allowedIPs, ok := e.peerStore.AllowedIPs(peerPubKey); ok { + if allowedIPs != strings.Join(p.AllowedIps, ",") { modified = append(modified, p) continue } @@ -494,17 +495,12 @@ func (e *Engine) modifyPeers(peersUpdate []*mgmProto.RemotePeerConfig) error { // removePeers finds and removes peers that do not exist anymore in the network map received from the Management Service. // It also removes peers that have been modified (e.g. change of IP address). They will be added again in addPeers method. func (e *Engine) removePeers(peersUpdate []*mgmProto.RemotePeerConfig) error { - currentPeers := make([]string, 0, len(e.peerConns)) - for p := range e.peerConns { - currentPeers = append(currentPeers, p) - } - newPeers := make([]string, 0, len(peersUpdate)) for _, p := range peersUpdate { newPeers = append(newPeers, p.GetWgPubKey()) } - toRemove := util.SliceDiff(currentPeers, newPeers) + toRemove := util.SliceDiff(e.peerStore.PeersPubKey(), newPeers) for _, p := range toRemove { err := e.removePeer(p) @@ -518,7 +514,7 @@ func (e *Engine) removePeers(peersUpdate []*mgmProto.RemotePeerConfig) error { func (e *Engine) removeAllPeers() error { log.Debugf("removing all peer connections") - for p := range e.peerConns { + for _, p := range e.peerStore.PeersPubKey() { err := e.removePeer(p) if err != nil { return err @@ -542,9 +538,8 @@ func (e *Engine) removePeer(peerKey string) error { } }() - conn, exists := e.peerConns[peerKey] + conn, exists := e.peerStore.Remove(peerKey) if exists { - delete(e.peerConns, peerKey) conn.Close() } return nil @@ -983,12 +978,16 @@ func (e *Engine) addNewPeers(peersUpdate []*mgmProto.RemotePeerConfig) error { func (e *Engine) addNewPeer(peerConfig *mgmProto.RemotePeerConfig) error { peerKey := peerConfig.GetWgPubKey() peerIPs := peerConfig.GetAllowedIps() - if _, ok := e.peerConns[peerKey]; !ok { + if _, ok := e.peerStore.PeerConn(peerKey); !ok { conn, err := e.createPeerConn(peerKey, strings.Join(peerIPs, ",")) if err != nil { return fmt.Errorf("create peer connection: %w", err) } - e.peerConns[peerKey] = conn + + if ok := e.peerStore.AddPeerConn(peerKey, conn); !ok { + conn.Close() + return fmt.Errorf("peer already exists: %s", peerKey) + } if e.beforePeerHook != nil && e.afterPeerHook != nil { conn.AddBeforeAddPeerHook(e.beforePeerHook) @@ -1077,8 +1076,8 @@ func (e *Engine) receiveSignalEvents() { e.syncMsgMux.Lock() defer e.syncMsgMux.Unlock() - conn := e.peerConns[msg.Key] - if conn == nil { + conn, ok := e.peerStore.PeerConn(msg.Key) + if !ok { return fmt.Errorf("wrongly addressed message %s", msg.Key) } @@ -1407,9 +1406,8 @@ func (e *Engine) receiveProbeEvents() { go e.probes.WgProbe.Receive(e.ctx, func() bool { log.Debug("received wg probe request") - for _, peer := range e.peerConns { - key := peer.GetKey() - wgStats, err := peer.WgConfig().WgInterface.GetStats(key) + for _, key := range e.peerStore.PeersPubKey() { + wgStats, err := e.wgInterface.GetStats(key) if err != nil { log.Debugf("failed to get wg stats for peer %s: %s", key, err) } diff --git a/client/internal/peerstore/store.go b/client/internal/peerstore/store.go new file mode 100644 index 000000000..6b3385ff5 --- /dev/null +++ b/client/internal/peerstore/store.go @@ -0,0 +1,87 @@ +package peerstore + +import ( + "net" + "sync" + + "golang.org/x/exp/maps" + + "github.com/netbirdio/netbird/client/internal/peer" +) + +// Store is a thread-safe store for peer connections. +type Store struct { + peerConns map[string]*peer.Conn + peerConnsMu sync.RWMutex +} + +func NewConnStore() *Store { + return &Store{ + peerConns: make(map[string]*peer.Conn), + } +} + +func (s *Store) AddPeerConn(pubKey string, conn *peer.Conn) bool { + s.peerConnsMu.Lock() + defer s.peerConnsMu.Unlock() + + _, ok := s.peerConns[pubKey] + if ok { + return false + } + + s.peerConns[pubKey] = conn + return true +} + +func (s *Store) Remove(pubKey string) (*peer.Conn, bool) { + s.peerConnsMu.Lock() + defer s.peerConnsMu.Unlock() + + p, ok := s.peerConns[pubKey] + if !ok { + return nil, false + } + delete(s.peerConns, pubKey) + return p, true +} + +func (s *Store) AllowedIPs(pubKey string) (string, bool) { + s.peerConnsMu.RLock() + defer s.peerConnsMu.RUnlock() + + p, ok := s.peerConns[pubKey] + if !ok { + return "", false + } + return p.WgConfig().AllowedIps, true +} + +func (s *Store) AllowedIP(pubKey string) (net.IP, bool) { + s.peerConnsMu.RLock() + defer s.peerConnsMu.RUnlock() + + p, ok := s.peerConns[pubKey] + if !ok { + return nil, false + } + return p.AllowedIP(), true +} + +func (s *Store) PeerConn(pubKey string) (*peer.Conn, bool) { + s.peerConnsMu.RLock() + defer s.peerConnsMu.RUnlock() + + p, ok := s.peerConns[pubKey] + if !ok { + return nil, false + } + return p, true +} + +func (s *Store) PeersPubKey() []string { + s.peerConnsMu.RLock() + defer s.peerConnsMu.RUnlock() + + return maps.Keys(s.peerConns) +} diff --git a/client/internal/routemanager/client.go b/client/internal/routemanager/client.go index 4a5cc50f3..b7fc5b15d 100644 --- a/client/internal/routemanager/client.go +++ b/client/internal/routemanager/client.go @@ -13,6 +13,7 @@ import ( "github.com/netbirdio/netbird/client/iface" nbdns "github.com/netbirdio/netbird/client/internal/dns" "github.com/netbirdio/netbird/client/internal/peer" + "github.com/netbirdio/netbird/client/internal/peerstore" "github.com/netbirdio/netbird/client/internal/routemanager/dnsinterceptor" "github.com/netbirdio/netbird/client/internal/routemanager/dynamic" "github.com/netbirdio/netbird/client/internal/routemanager/refcounter" @@ -65,7 +66,7 @@ func newClientNetworkWatcher( routeRefCounter *refcounter.RouteRefCounter, allowedIPsRefCounter *refcounter.AllowedIPsRefCounter, dnsServer nbdns.Server, - peerConns map[string]*peer.Conn, + peerStore *peerstore.Store, ) *clientNetwork { ctx, cancel := context.WithCancel(ctx) @@ -86,7 +87,7 @@ func newClientNetworkWatcher( statusRecorder, wgInterface, dnsServer, - peerConns, + peerStore, ), } return client @@ -398,7 +399,7 @@ func handlerFromRoute( statusRecorder *peer.Status, wgInterface iface.IWGIface, dnsServer nbdns.Server, - peerConns map[string]*peer.Conn, + peerStore *peerstore.Store, ) RouteHandler { if rt.IsDynamic() { if useNewDNSRoute { @@ -408,7 +409,7 @@ func handlerFromRoute( allowedIPsRefCounter, statusRecorder, dnsServer, - peerConns, + peerStore, ) } dns := nbdns.NewServiceViaMemory(wgInterface) diff --git a/client/internal/routemanager/dnsinterceptor/handler.go b/client/internal/routemanager/dnsinterceptor/handler.go index 7e67dbc68..991ceda95 100644 --- a/client/internal/routemanager/dnsinterceptor/handler.go +++ b/client/internal/routemanager/dnsinterceptor/handler.go @@ -17,6 +17,7 @@ import ( nbdns "github.com/netbirdio/netbird/client/internal/dns" "github.com/netbirdio/netbird/client/internal/dnsfwd" "github.com/netbirdio/netbird/client/internal/peer" + "github.com/netbirdio/netbird/client/internal/peerstore" "github.com/netbirdio/netbird/client/internal/routemanager/refcounter" "github.com/netbirdio/netbird/management/domain" "github.com/netbirdio/netbird/route" @@ -33,8 +34,7 @@ type DnsInterceptor struct { dnsServer nbdns.Server currentPeerKey string interceptedDomains domainMap - peerConns map[string]*peer.Conn - // TODO: peerConns add lock to sync with engine + peerStore *peerstore.Store } func New( @@ -43,7 +43,7 @@ func New( allowedIPsRefCounter *refcounter.AllowedIPsRefCounter, statusRecorder *peer.Status, dnsServer nbdns.Server, - peerConns map[string]*peer.Conn, + peerStore *peerstore.Store, ) *DnsInterceptor { return &DnsInterceptor{ route: rt, @@ -52,7 +52,7 @@ func New( statusRecorder: statusRecorder, dnsServer: dnsServer, interceptedDomains: make(domainMap), - peerConns: peerConns, + peerStore: peerStore, } } @@ -189,11 +189,11 @@ func (d *DnsInterceptor) getUpstreamIP() (net.IP, error) { d.mu.RLock() defer d.mu.RUnlock() - peerConn, exists := d.peerConns[d.currentPeerKey] + peerAllowedIP, exists := d.peerStore.AllowedIP(d.currentPeerKey) if !exists { return nil, fmt.Errorf("peer connection not found for key: %s", d.currentPeerKey) } - return peerConn.AllowedIP(), nil + return peerAllowedIP, nil } func (d *DnsInterceptor) writeMsg(w dns.ResponseWriter, r *dns.Msg) error { diff --git a/client/internal/routemanager/manager.go b/client/internal/routemanager/manager.go index ab2ce0361..30899bc1d 100644 --- a/client/internal/routemanager/manager.go +++ b/client/internal/routemanager/manager.go @@ -20,6 +20,7 @@ import ( "github.com/netbirdio/netbird/client/internal/dns" "github.com/netbirdio/netbird/client/internal/listener" "github.com/netbirdio/netbird/client/internal/peer" + "github.com/netbirdio/netbird/client/internal/peerstore" "github.com/netbirdio/netbird/client/internal/routemanager/notifier" "github.com/netbirdio/netbird/client/internal/routemanager/refcounter" "github.com/netbirdio/netbird/client/internal/routemanager/systemops" @@ -67,7 +68,7 @@ type DefaultManager struct { // clientRoutes is the most recent list of clientRoutes received from the Management Service clientRoutes route.HAMap dnsServer dns.Server - peerConns map[string]*peer.Conn + peerStore *peerstore.Store } func NewManager( @@ -80,7 +81,7 @@ func NewManager( initialRoutes []*route.Route, stateManager *statemanager.Manager, dnsServer dns.Server, - peerConns map[string]*peer.Conn, + peerStore *peerstore.Store, ) *DefaultManager { mCTX, cancel := context.WithCancel(ctx) notifier := notifier.NewNotifier() @@ -99,7 +100,7 @@ func NewManager( notifier: notifier, stateManager: stateManager, dnsServer: dnsServer, - peerConns: peerConns, + peerStore: peerStore, } dm.routeRefCounter = refcounter.New( @@ -316,7 +317,7 @@ func (m *DefaultManager) TriggerSelection(networks route.HAMap) { m.routeRefCounter, m.allowedIPsRefCounter, m.dnsServer, - m.peerConns, + m.peerStore, ) m.clientNetworks[id] = clientNetworkWatcher go clientNetworkWatcher.peersStateAndUpdateWatcher() @@ -346,7 +347,7 @@ func (m *DefaultManager) updateClientNetworks(updateSerial uint64, networks rout for id, routes := range networks { clientNetworkWatcher, found := m.clientNetworks[id] if !found { - clientNetworkWatcher = newClientNetworkWatcher(m.ctx, m.dnsRouteInterval, m.wgInterface, m.statusRecorder, routes[0], m.routeRefCounter, m.allowedIPsRefCounter, m.dnsServer, m.peerConns) + clientNetworkWatcher = newClientNetworkWatcher(m.ctx, m.dnsRouteInterval, m.wgInterface, m.statusRecorder, routes[0], m.routeRefCounter, m.allowedIPsRefCounter, m.dnsServer, m.peerStore) m.clientNetworks[id] = clientNetworkWatcher go clientNetworkWatcher.peersStateAndUpdateWatcher() }