Create thread safe peer store (#3028)

Create thread safe peer store
This commit is contained in:
Zoltan Papp 2024-12-11 18:37:10 +01:00 committed by GitHub
parent da0a54c6d6
commit cb44454288
No known key found for this signature in database
GPG Key ID: B5690EEEBB952194
5 changed files with 123 additions and 36 deletions

View File

@ -34,6 +34,7 @@ import (
"github.com/netbirdio/netbird/client/internal/peer" "github.com/netbirdio/netbird/client/internal/peer"
"github.com/netbirdio/netbird/client/internal/peer/guard" "github.com/netbirdio/netbird/client/internal/peer/guard"
icemaker "github.com/netbirdio/netbird/client/internal/peer/ice" icemaker "github.com/netbirdio/netbird/client/internal/peer/ice"
"github.com/netbirdio/netbird/client/internal/peerstore"
"github.com/netbirdio/netbird/client/internal/relay" "github.com/netbirdio/netbird/client/internal/relay"
"github.com/netbirdio/netbird/client/internal/rosenpass" "github.com/netbirdio/netbird/client/internal/rosenpass"
"github.com/netbirdio/netbird/client/internal/routemanager" "github.com/netbirdio/netbird/client/internal/routemanager"
@ -117,7 +118,7 @@ type Engine struct {
// mgmClient is a Management Service client // mgmClient is a Management Service client
mgmClient mgm.Client mgmClient mgm.Client
// peerConns is a map that holds all the peers that are known to this peer // peerConns is a map that holds all the peers that are known to this peer
peerConns map[string]*peer.Conn peerStore *peerstore.Store
beforePeerHook nbnet.AddHookFunc beforePeerHook nbnet.AddHookFunc
afterPeerHook nbnet.RemoveHookFunc afterPeerHook nbnet.RemoveHookFunc
@ -231,7 +232,7 @@ func NewEngineWithProbes(
signaler: peer.NewSignaler(signalClient, config.WgPrivateKey), signaler: peer.NewSignaler(signalClient, config.WgPrivateKey),
mgmClient: mgmClient, mgmClient: mgmClient,
relayManager: relayManager, relayManager: relayManager,
peerConns: make(map[string]*peer.Conn), peerStore: peerstore.NewConnStore(),
syncMsgMux: &sync.Mutex{}, syncMsgMux: &sync.Mutex{},
config: config, config: config,
mobileDep: mobileDep, mobileDep: mobileDep,
@ -383,7 +384,7 @@ func (e *Engine) Start() error {
initialRoutes, initialRoutes,
e.stateManager, e.stateManager,
dnsServer, dnsServer,
e.peerConns, e.peerStore,
) )
beforePeerHook, afterPeerHook, err := e.routeManager.Init() beforePeerHook, afterPeerHook, err := e.routeManager.Init()
if err != nil { if err != nil {
@ -462,8 +463,8 @@ func (e *Engine) modifyPeers(peersUpdate []*mgmProto.RemotePeerConfig) error {
var modified []*mgmProto.RemotePeerConfig var modified []*mgmProto.RemotePeerConfig
for _, p := range peersUpdate { for _, p := range peersUpdate {
peerPubKey := p.GetWgPubKey() peerPubKey := p.GetWgPubKey()
if peerConn, ok := e.peerConns[peerPubKey]; ok { if allowedIPs, ok := e.peerStore.AllowedIPs(peerPubKey); ok {
if peerConn.WgConfig().AllowedIps != strings.Join(p.AllowedIps, ",") { if allowedIPs != strings.Join(p.AllowedIps, ",") {
modified = append(modified, p) modified = append(modified, p)
continue continue
} }
@ -494,17 +495,12 @@ func (e *Engine) modifyPeers(peersUpdate []*mgmProto.RemotePeerConfig) error {
// removePeers finds and removes peers that do not exist anymore in the network map received from the Management Service. // removePeers finds and removes peers that do not exist anymore in the network map received from the Management Service.
// It also removes peers that have been modified (e.g. change of IP address). They will be added again in addPeers method. // It also removes peers that have been modified (e.g. change of IP address). They will be added again in addPeers method.
func (e *Engine) removePeers(peersUpdate []*mgmProto.RemotePeerConfig) error { func (e *Engine) removePeers(peersUpdate []*mgmProto.RemotePeerConfig) error {
currentPeers := make([]string, 0, len(e.peerConns))
for p := range e.peerConns {
currentPeers = append(currentPeers, p)
}
newPeers := make([]string, 0, len(peersUpdate)) newPeers := make([]string, 0, len(peersUpdate))
for _, p := range peersUpdate { for _, p := range peersUpdate {
newPeers = append(newPeers, p.GetWgPubKey()) newPeers = append(newPeers, p.GetWgPubKey())
} }
toRemove := util.SliceDiff(currentPeers, newPeers) toRemove := util.SliceDiff(e.peerStore.PeersPubKey(), newPeers)
for _, p := range toRemove { for _, p := range toRemove {
err := e.removePeer(p) err := e.removePeer(p)
@ -518,7 +514,7 @@ func (e *Engine) removePeers(peersUpdate []*mgmProto.RemotePeerConfig) error {
func (e *Engine) removeAllPeers() error { func (e *Engine) removeAllPeers() error {
log.Debugf("removing all peer connections") log.Debugf("removing all peer connections")
for p := range e.peerConns { for _, p := range e.peerStore.PeersPubKey() {
err := e.removePeer(p) err := e.removePeer(p)
if err != nil { if err != nil {
return err return err
@ -542,9 +538,8 @@ func (e *Engine) removePeer(peerKey string) error {
} }
}() }()
conn, exists := e.peerConns[peerKey] conn, exists := e.peerStore.Remove(peerKey)
if exists { if exists {
delete(e.peerConns, peerKey)
conn.Close() conn.Close()
} }
return nil return nil
@ -983,12 +978,16 @@ func (e *Engine) addNewPeers(peersUpdate []*mgmProto.RemotePeerConfig) error {
func (e *Engine) addNewPeer(peerConfig *mgmProto.RemotePeerConfig) error { func (e *Engine) addNewPeer(peerConfig *mgmProto.RemotePeerConfig) error {
peerKey := peerConfig.GetWgPubKey() peerKey := peerConfig.GetWgPubKey()
peerIPs := peerConfig.GetAllowedIps() peerIPs := peerConfig.GetAllowedIps()
if _, ok := e.peerConns[peerKey]; !ok { if _, ok := e.peerStore.PeerConn(peerKey); !ok {
conn, err := e.createPeerConn(peerKey, strings.Join(peerIPs, ",")) conn, err := e.createPeerConn(peerKey, strings.Join(peerIPs, ","))
if err != nil { if err != nil {
return fmt.Errorf("create peer connection: %w", err) return fmt.Errorf("create peer connection: %w", err)
} }
e.peerConns[peerKey] = conn
if ok := e.peerStore.AddPeerConn(peerKey, conn); !ok {
conn.Close()
return fmt.Errorf("peer already exists: %s", peerKey)
}
if e.beforePeerHook != nil && e.afterPeerHook != nil { if e.beforePeerHook != nil && e.afterPeerHook != nil {
conn.AddBeforeAddPeerHook(e.beforePeerHook) conn.AddBeforeAddPeerHook(e.beforePeerHook)
@ -1077,8 +1076,8 @@ func (e *Engine) receiveSignalEvents() {
e.syncMsgMux.Lock() e.syncMsgMux.Lock()
defer e.syncMsgMux.Unlock() defer e.syncMsgMux.Unlock()
conn := e.peerConns[msg.Key] conn, ok := e.peerStore.PeerConn(msg.Key)
if conn == nil { if !ok {
return fmt.Errorf("wrongly addressed message %s", msg.Key) return fmt.Errorf("wrongly addressed message %s", msg.Key)
} }
@ -1407,9 +1406,8 @@ func (e *Engine) receiveProbeEvents() {
go e.probes.WgProbe.Receive(e.ctx, func() bool { go e.probes.WgProbe.Receive(e.ctx, func() bool {
log.Debug("received wg probe request") log.Debug("received wg probe request")
for _, peer := range e.peerConns { for _, key := range e.peerStore.PeersPubKey() {
key := peer.GetKey() wgStats, err := e.wgInterface.GetStats(key)
wgStats, err := peer.WgConfig().WgInterface.GetStats(key)
if err != nil { if err != nil {
log.Debugf("failed to get wg stats for peer %s: %s", key, err) log.Debugf("failed to get wg stats for peer %s: %s", key, err)
} }

View File

@ -0,0 +1,87 @@
package peerstore
import (
"net"
"sync"
"golang.org/x/exp/maps"
"github.com/netbirdio/netbird/client/internal/peer"
)
// Store is a thread-safe store for peer connections.
type Store struct {
peerConns map[string]*peer.Conn
peerConnsMu sync.RWMutex
}
func NewConnStore() *Store {
return &Store{
peerConns: make(map[string]*peer.Conn),
}
}
func (s *Store) AddPeerConn(pubKey string, conn *peer.Conn) bool {
s.peerConnsMu.Lock()
defer s.peerConnsMu.Unlock()
_, ok := s.peerConns[pubKey]
if ok {
return false
}
s.peerConns[pubKey] = conn
return true
}
func (s *Store) Remove(pubKey string) (*peer.Conn, bool) {
s.peerConnsMu.Lock()
defer s.peerConnsMu.Unlock()
p, ok := s.peerConns[pubKey]
if !ok {
return nil, false
}
delete(s.peerConns, pubKey)
return p, true
}
func (s *Store) AllowedIPs(pubKey string) (string, bool) {
s.peerConnsMu.RLock()
defer s.peerConnsMu.RUnlock()
p, ok := s.peerConns[pubKey]
if !ok {
return "", false
}
return p.WgConfig().AllowedIps, true
}
func (s *Store) AllowedIP(pubKey string) (net.IP, bool) {
s.peerConnsMu.RLock()
defer s.peerConnsMu.RUnlock()
p, ok := s.peerConns[pubKey]
if !ok {
return nil, false
}
return p.AllowedIP(), true
}
func (s *Store) PeerConn(pubKey string) (*peer.Conn, bool) {
s.peerConnsMu.RLock()
defer s.peerConnsMu.RUnlock()
p, ok := s.peerConns[pubKey]
if !ok {
return nil, false
}
return p, true
}
func (s *Store) PeersPubKey() []string {
s.peerConnsMu.RLock()
defer s.peerConnsMu.RUnlock()
return maps.Keys(s.peerConns)
}

View File

@ -13,6 +13,7 @@ import (
"github.com/netbirdio/netbird/client/iface" "github.com/netbirdio/netbird/client/iface"
nbdns "github.com/netbirdio/netbird/client/internal/dns" nbdns "github.com/netbirdio/netbird/client/internal/dns"
"github.com/netbirdio/netbird/client/internal/peer" "github.com/netbirdio/netbird/client/internal/peer"
"github.com/netbirdio/netbird/client/internal/peerstore"
"github.com/netbirdio/netbird/client/internal/routemanager/dnsinterceptor" "github.com/netbirdio/netbird/client/internal/routemanager/dnsinterceptor"
"github.com/netbirdio/netbird/client/internal/routemanager/dynamic" "github.com/netbirdio/netbird/client/internal/routemanager/dynamic"
"github.com/netbirdio/netbird/client/internal/routemanager/refcounter" "github.com/netbirdio/netbird/client/internal/routemanager/refcounter"
@ -65,7 +66,7 @@ func newClientNetworkWatcher(
routeRefCounter *refcounter.RouteRefCounter, routeRefCounter *refcounter.RouteRefCounter,
allowedIPsRefCounter *refcounter.AllowedIPsRefCounter, allowedIPsRefCounter *refcounter.AllowedIPsRefCounter,
dnsServer nbdns.Server, dnsServer nbdns.Server,
peerConns map[string]*peer.Conn, peerStore *peerstore.Store,
) *clientNetwork { ) *clientNetwork {
ctx, cancel := context.WithCancel(ctx) ctx, cancel := context.WithCancel(ctx)
@ -86,7 +87,7 @@ func newClientNetworkWatcher(
statusRecorder, statusRecorder,
wgInterface, wgInterface,
dnsServer, dnsServer,
peerConns, peerStore,
), ),
} }
return client return client
@ -398,7 +399,7 @@ func handlerFromRoute(
statusRecorder *peer.Status, statusRecorder *peer.Status,
wgInterface iface.IWGIface, wgInterface iface.IWGIface,
dnsServer nbdns.Server, dnsServer nbdns.Server,
peerConns map[string]*peer.Conn, peerStore *peerstore.Store,
) RouteHandler { ) RouteHandler {
if rt.IsDynamic() { if rt.IsDynamic() {
if useNewDNSRoute { if useNewDNSRoute {
@ -408,7 +409,7 @@ func handlerFromRoute(
allowedIPsRefCounter, allowedIPsRefCounter,
statusRecorder, statusRecorder,
dnsServer, dnsServer,
peerConns, peerStore,
) )
} }
dns := nbdns.NewServiceViaMemory(wgInterface) dns := nbdns.NewServiceViaMemory(wgInterface)

View File

@ -17,6 +17,7 @@ import (
nbdns "github.com/netbirdio/netbird/client/internal/dns" nbdns "github.com/netbirdio/netbird/client/internal/dns"
"github.com/netbirdio/netbird/client/internal/dnsfwd" "github.com/netbirdio/netbird/client/internal/dnsfwd"
"github.com/netbirdio/netbird/client/internal/peer" "github.com/netbirdio/netbird/client/internal/peer"
"github.com/netbirdio/netbird/client/internal/peerstore"
"github.com/netbirdio/netbird/client/internal/routemanager/refcounter" "github.com/netbirdio/netbird/client/internal/routemanager/refcounter"
"github.com/netbirdio/netbird/management/domain" "github.com/netbirdio/netbird/management/domain"
"github.com/netbirdio/netbird/route" "github.com/netbirdio/netbird/route"
@ -33,8 +34,7 @@ type DnsInterceptor struct {
dnsServer nbdns.Server dnsServer nbdns.Server
currentPeerKey string currentPeerKey string
interceptedDomains domainMap interceptedDomains domainMap
peerConns map[string]*peer.Conn peerStore *peerstore.Store
// TODO: peerConns add lock to sync with engine
} }
func New( func New(
@ -43,7 +43,7 @@ func New(
allowedIPsRefCounter *refcounter.AllowedIPsRefCounter, allowedIPsRefCounter *refcounter.AllowedIPsRefCounter,
statusRecorder *peer.Status, statusRecorder *peer.Status,
dnsServer nbdns.Server, dnsServer nbdns.Server,
peerConns map[string]*peer.Conn, peerStore *peerstore.Store,
) *DnsInterceptor { ) *DnsInterceptor {
return &DnsInterceptor{ return &DnsInterceptor{
route: rt, route: rt,
@ -52,7 +52,7 @@ func New(
statusRecorder: statusRecorder, statusRecorder: statusRecorder,
dnsServer: dnsServer, dnsServer: dnsServer,
interceptedDomains: make(domainMap), interceptedDomains: make(domainMap),
peerConns: peerConns, peerStore: peerStore,
} }
} }
@ -189,11 +189,11 @@ func (d *DnsInterceptor) getUpstreamIP() (net.IP, error) {
d.mu.RLock() d.mu.RLock()
defer d.mu.RUnlock() defer d.mu.RUnlock()
peerConn, exists := d.peerConns[d.currentPeerKey] peerAllowedIP, exists := d.peerStore.AllowedIP(d.currentPeerKey)
if !exists { if !exists {
return nil, fmt.Errorf("peer connection not found for key: %s", d.currentPeerKey) return nil, fmt.Errorf("peer connection not found for key: %s", d.currentPeerKey)
} }
return peerConn.AllowedIP(), nil return peerAllowedIP, nil
} }
func (d *DnsInterceptor) writeMsg(w dns.ResponseWriter, r *dns.Msg) error { func (d *DnsInterceptor) writeMsg(w dns.ResponseWriter, r *dns.Msg) error {

View File

@ -20,6 +20,7 @@ import (
"github.com/netbirdio/netbird/client/internal/dns" "github.com/netbirdio/netbird/client/internal/dns"
"github.com/netbirdio/netbird/client/internal/listener" "github.com/netbirdio/netbird/client/internal/listener"
"github.com/netbirdio/netbird/client/internal/peer" "github.com/netbirdio/netbird/client/internal/peer"
"github.com/netbirdio/netbird/client/internal/peerstore"
"github.com/netbirdio/netbird/client/internal/routemanager/notifier" "github.com/netbirdio/netbird/client/internal/routemanager/notifier"
"github.com/netbirdio/netbird/client/internal/routemanager/refcounter" "github.com/netbirdio/netbird/client/internal/routemanager/refcounter"
"github.com/netbirdio/netbird/client/internal/routemanager/systemops" "github.com/netbirdio/netbird/client/internal/routemanager/systemops"
@ -67,7 +68,7 @@ type DefaultManager struct {
// clientRoutes is the most recent list of clientRoutes received from the Management Service // clientRoutes is the most recent list of clientRoutes received from the Management Service
clientRoutes route.HAMap clientRoutes route.HAMap
dnsServer dns.Server dnsServer dns.Server
peerConns map[string]*peer.Conn peerStore *peerstore.Store
} }
func NewManager( func NewManager(
@ -80,7 +81,7 @@ func NewManager(
initialRoutes []*route.Route, initialRoutes []*route.Route,
stateManager *statemanager.Manager, stateManager *statemanager.Manager,
dnsServer dns.Server, dnsServer dns.Server,
peerConns map[string]*peer.Conn, peerStore *peerstore.Store,
) *DefaultManager { ) *DefaultManager {
mCTX, cancel := context.WithCancel(ctx) mCTX, cancel := context.WithCancel(ctx)
notifier := notifier.NewNotifier() notifier := notifier.NewNotifier()
@ -99,7 +100,7 @@ func NewManager(
notifier: notifier, notifier: notifier,
stateManager: stateManager, stateManager: stateManager,
dnsServer: dnsServer, dnsServer: dnsServer,
peerConns: peerConns, peerStore: peerStore,
} }
dm.routeRefCounter = refcounter.New( dm.routeRefCounter = refcounter.New(
@ -316,7 +317,7 @@ func (m *DefaultManager) TriggerSelection(networks route.HAMap) {
m.routeRefCounter, m.routeRefCounter,
m.allowedIPsRefCounter, m.allowedIPsRefCounter,
m.dnsServer, m.dnsServer,
m.peerConns, m.peerStore,
) )
m.clientNetworks[id] = clientNetworkWatcher m.clientNetworks[id] = clientNetworkWatcher
go clientNetworkWatcher.peersStateAndUpdateWatcher() go clientNetworkWatcher.peersStateAndUpdateWatcher()
@ -346,7 +347,7 @@ func (m *DefaultManager) updateClientNetworks(updateSerial uint64, networks rout
for id, routes := range networks { for id, routes := range networks {
clientNetworkWatcher, found := m.clientNetworks[id] clientNetworkWatcher, found := m.clientNetworks[id]
if !found { if !found {
clientNetworkWatcher = newClientNetworkWatcher(m.ctx, m.dnsRouteInterval, m.wgInterface, m.statusRecorder, routes[0], m.routeRefCounter, m.allowedIPsRefCounter, m.dnsServer, m.peerConns) clientNetworkWatcher = newClientNetworkWatcher(m.ctx, m.dnsRouteInterval, m.wgInterface, m.statusRecorder, routes[0], m.routeRefCounter, m.allowedIPsRefCounter, m.dnsServer, m.peerStore)
m.clientNetworks[id] = clientNetworkWatcher m.clientNetworks[id] = clientNetworkWatcher
go clientNetworkWatcher.peersStateAndUpdateWatcher() go clientNetworkWatcher.peersStateAndUpdateWatcher()
} }