Create thread safe peer store (#3028)

Create thread safe peer store
This commit is contained in:
Zoltan Papp 2024-12-11 18:37:10 +01:00 committed by GitHub
parent da0a54c6d6
commit cb44454288
No known key found for this signature in database
GPG Key ID: B5690EEEBB952194
5 changed files with 123 additions and 36 deletions

View File

@ -34,6 +34,7 @@ import (
"github.com/netbirdio/netbird/client/internal/peer"
"github.com/netbirdio/netbird/client/internal/peer/guard"
icemaker "github.com/netbirdio/netbird/client/internal/peer/ice"
"github.com/netbirdio/netbird/client/internal/peerstore"
"github.com/netbirdio/netbird/client/internal/relay"
"github.com/netbirdio/netbird/client/internal/rosenpass"
"github.com/netbirdio/netbird/client/internal/routemanager"
@ -117,7 +118,7 @@ type Engine struct {
// mgmClient is a Management Service client
mgmClient mgm.Client
// peerConns is a map that holds all the peers that are known to this peer
peerConns map[string]*peer.Conn
peerStore *peerstore.Store
beforePeerHook nbnet.AddHookFunc
afterPeerHook nbnet.RemoveHookFunc
@ -231,7 +232,7 @@ func NewEngineWithProbes(
signaler: peer.NewSignaler(signalClient, config.WgPrivateKey),
mgmClient: mgmClient,
relayManager: relayManager,
peerConns: make(map[string]*peer.Conn),
peerStore: peerstore.NewConnStore(),
syncMsgMux: &sync.Mutex{},
config: config,
mobileDep: mobileDep,
@ -383,7 +384,7 @@ func (e *Engine) Start() error {
initialRoutes,
e.stateManager,
dnsServer,
e.peerConns,
e.peerStore,
)
beforePeerHook, afterPeerHook, err := e.routeManager.Init()
if err != nil {
@ -462,8 +463,8 @@ func (e *Engine) modifyPeers(peersUpdate []*mgmProto.RemotePeerConfig) error {
var modified []*mgmProto.RemotePeerConfig
for _, p := range peersUpdate {
peerPubKey := p.GetWgPubKey()
if peerConn, ok := e.peerConns[peerPubKey]; ok {
if peerConn.WgConfig().AllowedIps != strings.Join(p.AllowedIps, ",") {
if allowedIPs, ok := e.peerStore.AllowedIPs(peerPubKey); ok {
if allowedIPs != strings.Join(p.AllowedIps, ",") {
modified = append(modified, p)
continue
}
@ -494,17 +495,12 @@ func (e *Engine) modifyPeers(peersUpdate []*mgmProto.RemotePeerConfig) error {
// removePeers finds and removes peers that do not exist anymore in the network map received from the Management Service.
// It also removes peers that have been modified (e.g. change of IP address). They will be added again in addPeers method.
func (e *Engine) removePeers(peersUpdate []*mgmProto.RemotePeerConfig) error {
currentPeers := make([]string, 0, len(e.peerConns))
for p := range e.peerConns {
currentPeers = append(currentPeers, p)
}
newPeers := make([]string, 0, len(peersUpdate))
for _, p := range peersUpdate {
newPeers = append(newPeers, p.GetWgPubKey())
}
toRemove := util.SliceDiff(currentPeers, newPeers)
toRemove := util.SliceDiff(e.peerStore.PeersPubKey(), newPeers)
for _, p := range toRemove {
err := e.removePeer(p)
@ -518,7 +514,7 @@ func (e *Engine) removePeers(peersUpdate []*mgmProto.RemotePeerConfig) error {
func (e *Engine) removeAllPeers() error {
log.Debugf("removing all peer connections")
for p := range e.peerConns {
for _, p := range e.peerStore.PeersPubKey() {
err := e.removePeer(p)
if err != nil {
return err
@ -542,9 +538,8 @@ func (e *Engine) removePeer(peerKey string) error {
}
}()
conn, exists := e.peerConns[peerKey]
conn, exists := e.peerStore.Remove(peerKey)
if exists {
delete(e.peerConns, peerKey)
conn.Close()
}
return nil
@ -983,12 +978,16 @@ func (e *Engine) addNewPeers(peersUpdate []*mgmProto.RemotePeerConfig) error {
func (e *Engine) addNewPeer(peerConfig *mgmProto.RemotePeerConfig) error {
peerKey := peerConfig.GetWgPubKey()
peerIPs := peerConfig.GetAllowedIps()
if _, ok := e.peerConns[peerKey]; !ok {
if _, ok := e.peerStore.PeerConn(peerKey); !ok {
conn, err := e.createPeerConn(peerKey, strings.Join(peerIPs, ","))
if err != nil {
return fmt.Errorf("create peer connection: %w", err)
}
e.peerConns[peerKey] = conn
if ok := e.peerStore.AddPeerConn(peerKey, conn); !ok {
conn.Close()
return fmt.Errorf("peer already exists: %s", peerKey)
}
if e.beforePeerHook != nil && e.afterPeerHook != nil {
conn.AddBeforeAddPeerHook(e.beforePeerHook)
@ -1077,8 +1076,8 @@ func (e *Engine) receiveSignalEvents() {
e.syncMsgMux.Lock()
defer e.syncMsgMux.Unlock()
conn := e.peerConns[msg.Key]
if conn == nil {
conn, ok := e.peerStore.PeerConn(msg.Key)
if !ok {
return fmt.Errorf("wrongly addressed message %s", msg.Key)
}
@ -1407,9 +1406,8 @@ func (e *Engine) receiveProbeEvents() {
go e.probes.WgProbe.Receive(e.ctx, func() bool {
log.Debug("received wg probe request")
for _, peer := range e.peerConns {
key := peer.GetKey()
wgStats, err := peer.WgConfig().WgInterface.GetStats(key)
for _, key := range e.peerStore.PeersPubKey() {
wgStats, err := e.wgInterface.GetStats(key)
if err != nil {
log.Debugf("failed to get wg stats for peer %s: %s", key, err)
}

View File

@ -0,0 +1,87 @@
package peerstore
import (
"net"
"sync"
"golang.org/x/exp/maps"
"github.com/netbirdio/netbird/client/internal/peer"
)
// Store is a thread-safe store for peer connections.
type Store struct {
peerConns map[string]*peer.Conn
peerConnsMu sync.RWMutex
}
func NewConnStore() *Store {
return &Store{
peerConns: make(map[string]*peer.Conn),
}
}
func (s *Store) AddPeerConn(pubKey string, conn *peer.Conn) bool {
s.peerConnsMu.Lock()
defer s.peerConnsMu.Unlock()
_, ok := s.peerConns[pubKey]
if ok {
return false
}
s.peerConns[pubKey] = conn
return true
}
func (s *Store) Remove(pubKey string) (*peer.Conn, bool) {
s.peerConnsMu.Lock()
defer s.peerConnsMu.Unlock()
p, ok := s.peerConns[pubKey]
if !ok {
return nil, false
}
delete(s.peerConns, pubKey)
return p, true
}
func (s *Store) AllowedIPs(pubKey string) (string, bool) {
s.peerConnsMu.RLock()
defer s.peerConnsMu.RUnlock()
p, ok := s.peerConns[pubKey]
if !ok {
return "", false
}
return p.WgConfig().AllowedIps, true
}
func (s *Store) AllowedIP(pubKey string) (net.IP, bool) {
s.peerConnsMu.RLock()
defer s.peerConnsMu.RUnlock()
p, ok := s.peerConns[pubKey]
if !ok {
return nil, false
}
return p.AllowedIP(), true
}
func (s *Store) PeerConn(pubKey string) (*peer.Conn, bool) {
s.peerConnsMu.RLock()
defer s.peerConnsMu.RUnlock()
p, ok := s.peerConns[pubKey]
if !ok {
return nil, false
}
return p, true
}
func (s *Store) PeersPubKey() []string {
s.peerConnsMu.RLock()
defer s.peerConnsMu.RUnlock()
return maps.Keys(s.peerConns)
}

View File

@ -13,6 +13,7 @@ import (
"github.com/netbirdio/netbird/client/iface"
nbdns "github.com/netbirdio/netbird/client/internal/dns"
"github.com/netbirdio/netbird/client/internal/peer"
"github.com/netbirdio/netbird/client/internal/peerstore"
"github.com/netbirdio/netbird/client/internal/routemanager/dnsinterceptor"
"github.com/netbirdio/netbird/client/internal/routemanager/dynamic"
"github.com/netbirdio/netbird/client/internal/routemanager/refcounter"
@ -65,7 +66,7 @@ func newClientNetworkWatcher(
routeRefCounter *refcounter.RouteRefCounter,
allowedIPsRefCounter *refcounter.AllowedIPsRefCounter,
dnsServer nbdns.Server,
peerConns map[string]*peer.Conn,
peerStore *peerstore.Store,
) *clientNetwork {
ctx, cancel := context.WithCancel(ctx)
@ -86,7 +87,7 @@ func newClientNetworkWatcher(
statusRecorder,
wgInterface,
dnsServer,
peerConns,
peerStore,
),
}
return client
@ -398,7 +399,7 @@ func handlerFromRoute(
statusRecorder *peer.Status,
wgInterface iface.IWGIface,
dnsServer nbdns.Server,
peerConns map[string]*peer.Conn,
peerStore *peerstore.Store,
) RouteHandler {
if rt.IsDynamic() {
if useNewDNSRoute {
@ -408,7 +409,7 @@ func handlerFromRoute(
allowedIPsRefCounter,
statusRecorder,
dnsServer,
peerConns,
peerStore,
)
}
dns := nbdns.NewServiceViaMemory(wgInterface)

View File

@ -17,6 +17,7 @@ import (
nbdns "github.com/netbirdio/netbird/client/internal/dns"
"github.com/netbirdio/netbird/client/internal/dnsfwd"
"github.com/netbirdio/netbird/client/internal/peer"
"github.com/netbirdio/netbird/client/internal/peerstore"
"github.com/netbirdio/netbird/client/internal/routemanager/refcounter"
"github.com/netbirdio/netbird/management/domain"
"github.com/netbirdio/netbird/route"
@ -33,8 +34,7 @@ type DnsInterceptor struct {
dnsServer nbdns.Server
currentPeerKey string
interceptedDomains domainMap
peerConns map[string]*peer.Conn
// TODO: peerConns add lock to sync with engine
peerStore *peerstore.Store
}
func New(
@ -43,7 +43,7 @@ func New(
allowedIPsRefCounter *refcounter.AllowedIPsRefCounter,
statusRecorder *peer.Status,
dnsServer nbdns.Server,
peerConns map[string]*peer.Conn,
peerStore *peerstore.Store,
) *DnsInterceptor {
return &DnsInterceptor{
route: rt,
@ -52,7 +52,7 @@ func New(
statusRecorder: statusRecorder,
dnsServer: dnsServer,
interceptedDomains: make(domainMap),
peerConns: peerConns,
peerStore: peerStore,
}
}
@ -189,11 +189,11 @@ func (d *DnsInterceptor) getUpstreamIP() (net.IP, error) {
d.mu.RLock()
defer d.mu.RUnlock()
peerConn, exists := d.peerConns[d.currentPeerKey]
peerAllowedIP, exists := d.peerStore.AllowedIP(d.currentPeerKey)
if !exists {
return nil, fmt.Errorf("peer connection not found for key: %s", d.currentPeerKey)
}
return peerConn.AllowedIP(), nil
return peerAllowedIP, nil
}
func (d *DnsInterceptor) writeMsg(w dns.ResponseWriter, r *dns.Msg) error {

View File

@ -20,6 +20,7 @@ import (
"github.com/netbirdio/netbird/client/internal/dns"
"github.com/netbirdio/netbird/client/internal/listener"
"github.com/netbirdio/netbird/client/internal/peer"
"github.com/netbirdio/netbird/client/internal/peerstore"
"github.com/netbirdio/netbird/client/internal/routemanager/notifier"
"github.com/netbirdio/netbird/client/internal/routemanager/refcounter"
"github.com/netbirdio/netbird/client/internal/routemanager/systemops"
@ -67,7 +68,7 @@ type DefaultManager struct {
// clientRoutes is the most recent list of clientRoutes received from the Management Service
clientRoutes route.HAMap
dnsServer dns.Server
peerConns map[string]*peer.Conn
peerStore *peerstore.Store
}
func NewManager(
@ -80,7 +81,7 @@ func NewManager(
initialRoutes []*route.Route,
stateManager *statemanager.Manager,
dnsServer dns.Server,
peerConns map[string]*peer.Conn,
peerStore *peerstore.Store,
) *DefaultManager {
mCTX, cancel := context.WithCancel(ctx)
notifier := notifier.NewNotifier()
@ -99,7 +100,7 @@ func NewManager(
notifier: notifier,
stateManager: stateManager,
dnsServer: dnsServer,
peerConns: peerConns,
peerStore: peerStore,
}
dm.routeRefCounter = refcounter.New(
@ -316,7 +317,7 @@ func (m *DefaultManager) TriggerSelection(networks route.HAMap) {
m.routeRefCounter,
m.allowedIPsRefCounter,
m.dnsServer,
m.peerConns,
m.peerStore,
)
m.clientNetworks[id] = clientNetworkWatcher
go clientNetworkWatcher.peersStateAndUpdateWatcher()
@ -346,7 +347,7 @@ func (m *DefaultManager) updateClientNetworks(updateSerial uint64, networks rout
for id, routes := range networks {
clientNetworkWatcher, found := m.clientNetworks[id]
if !found {
clientNetworkWatcher = newClientNetworkWatcher(m.ctx, m.dnsRouteInterval, m.wgInterface, m.statusRecorder, routes[0], m.routeRefCounter, m.allowedIPsRefCounter, m.dnsServer, m.peerConns)
clientNetworkWatcher = newClientNetworkWatcher(m.ctx, m.dnsRouteInterval, m.wgInterface, m.statusRecorder, routes[0], m.routeRefCounter, m.allowedIPsRefCounter, m.dnsServer, m.peerStore)
m.clientNetworks[id] = clientNetworkWatcher
go clientNetworkWatcher.peersStateAndUpdateWatcher()
}