[management] Improve mgmt sync performance (#2363)

This commit is contained in:
Viktor Liu 2024-08-07 10:52:31 +02:00 committed by GitHub
parent 54d896846b
commit ac0d5ff9f3
No known key found for this signature in database
GPG Key ID: B5690EEEBB952194
22 changed files with 1005 additions and 174 deletions

View File

@ -11,6 +11,7 @@ import (
"go.opentelemetry.io/otel"
"github.com/netbirdio/netbird/management/server/activity"
"github.com/netbirdio/netbird/management/server/telemetry"
"github.com/netbirdio/netbird/util"
@ -71,6 +72,7 @@ func startSignal(t *testing.T) (*grpc.Server, net.Listener) {
func startManagement(t *testing.T, config *mgmt.Config) (*grpc.Server, net.Listener) {
t.Helper()
lis, err := net.Listen("tcp", ":0")
if err != nil {
t.Fatal(err)
@ -88,7 +90,11 @@ func startManagement(t *testing.T, config *mgmt.Config) (*grpc.Server, net.Liste
return nil, nil
}
iv, _ := integrations.NewIntegratedValidator(context.Background(), eventStore)
accountManager, err := mgmt.BuildManager(context.Background(), store, peersUpdateManager, nil, "", "netbird.selfhosted", eventStore, nil, false, iv)
metrics, err := telemetry.NewDefaultAppMetrics(context.Background())
require.NoError(t, err)
accountManager, err := mgmt.BuildManager(context.Background(), store, peersUpdateManager, nil, "", "netbird.selfhosted", eventStore, nil, false, iv, metrics)
if err != nil {
t.Fatal(err)
}

View File

@ -36,6 +36,7 @@ import (
mgmtProto "github.com/netbirdio/netbird/management/proto"
"github.com/netbirdio/netbird/management/server"
"github.com/netbirdio/netbird/management/server/activity"
"github.com/netbirdio/netbird/management/server/telemetry"
"github.com/netbirdio/netbird/route"
signal "github.com/netbirdio/netbird/signal/client"
"github.com/netbirdio/netbird/signal/proto"
@ -1069,7 +1070,11 @@ func startManagement(t *testing.T, dataDir string) (*grpc.Server, string, error)
return nil, "", err
}
ia, _ := integrations.NewIntegratedValidator(context.Background(), eventStore)
accountManager, err := server.BuildManager(context.Background(), store, peersUpdateManager, nil, "", "netbird.selfhosted", eventStore, nil, false, ia)
metrics, err := telemetry.NewDefaultAppMetrics(context.Background())
require.NoError(t, err)
accountManager, err := server.BuildManager(context.Background(), store, peersUpdateManager, nil, "", "netbird.selfhosted", eventStore, nil, false, ia, metrics)
if err != nil {
return nil, "", err
}

View File

@ -19,6 +19,7 @@ import (
mgmtProto "github.com/netbirdio/netbird/management/proto"
"github.com/netbirdio/netbird/management/server"
"github.com/netbirdio/netbird/management/server/activity"
"github.com/netbirdio/netbird/management/server/telemetry"
"github.com/netbirdio/netbird/signal/proto"
signalServer "github.com/netbirdio/netbird/signal/server"
)
@ -120,7 +121,11 @@ func startManagement(t *testing.T, signalAddr string, counter *int) (*grpc.Serve
return nil, "", err
}
ia, _ := integrations.NewIntegratedValidator(context.Background(), eventStore)
accountManager, err := server.BuildManager(context.Background(), store, peersUpdateManager, nil, "", "netbird.selfhosted", eventStore, nil, false, ia)
metrics, err := telemetry.NewDefaultAppMetrics(context.Background())
require.NoError(t, err)
accountManager, err := server.BuildManager(context.Background(), store, peersUpdateManager, nil, "", "netbird.selfhosted", eventStore, nil, false, ia, metrics)
if err != nil {
return nil, "", err
}

View File

@ -9,7 +9,10 @@ import (
"testing"
"time"
"github.com/stretchr/testify/require"
"github.com/netbirdio/netbird/management/server/activity"
"github.com/netbirdio/netbird/management/server/telemetry"
"github.com/netbirdio/netbird/client/system"
@ -71,7 +74,11 @@ func startManagement(t *testing.T) (*grpc.Server, net.Listener) {
peersUpdateManager := mgmt.NewPeersUpdateManager(nil)
eventStore := &activity.InMemoryEventStore{}
ia, _ := integrations.NewIntegratedValidator(context.Background(), eventStore)
accountManager, err := mgmt.BuildManager(context.Background(), store, peersUpdateManager, nil, "", "netbird.selfhosted", eventStore, nil, false, ia)
metrics, err := telemetry.NewDefaultAppMetrics(context.Background())
require.NoError(t, err)
accountManager, err := mgmt.BuildManager(context.Background(), store, peersUpdateManager, nil, "", "netbird.selfhosted", eventStore, nil, false, ia, metrics)
if err != nil {
t.Fatal(err)
}

View File

@ -190,7 +190,7 @@ var (
return fmt.Errorf("failed to initialize integrated peer validator: %v", err)
}
accountManager, err := server.BuildManager(ctx, store, peersUpdateManager, idpManager, mgmtSingleAccModeDomain,
dnsDomain, eventStore, geo, userDeleteFromIDPEnabled, integratedPeerValidator)
dnsDomain, eventStore, geo, userDeleteFromIDPEnabled, integratedPeerValidator, appMetrics)
if err != nil {
return fmt.Errorf("failed to build default manager: %v", err)
}

View File

@ -18,6 +18,8 @@ import (
"github.com/eko/gocache/v3/cache"
cacheStore "github.com/eko/gocache/v3/store"
"github.com/hashicorp/go-multierror"
"github.com/miekg/dns"
gocache "github.com/patrickmn/go-cache"
"github.com/rs/xid"
log "github.com/sirupsen/logrus"
@ -37,6 +39,7 @@ import (
nbpeer "github.com/netbirdio/netbird/management/server/peer"
"github.com/netbirdio/netbird/management/server/posture"
"github.com/netbirdio/netbird/management/server/status"
"github.com/netbirdio/netbird/management/server/telemetry"
"github.com/netbirdio/netbird/route"
)
@ -170,6 +173,8 @@ type DefaultAccountManager struct {
userDeleteFromIDPEnabled bool
integratedPeerValidator integrated_validator.IntegratedValidator
metrics telemetry.AppMetrics
}
// Settings represents Account settings structure that can be modified via API and Dashboard
@ -401,8 +406,16 @@ func (a *Account) GetGroup(groupID string) *nbgroup.Group {
return a.Groups[groupID]
}
// GetPeerNetworkMap returns a group by ID if exists, nil otherwise
func (a *Account) GetPeerNetworkMap(ctx context.Context, peerID, dnsDomain string, validatedPeersMap map[string]struct{}) *NetworkMap {
// GetPeerNetworkMap returns the networkmap for the given peer ID.
func (a *Account) GetPeerNetworkMap(
ctx context.Context,
peerID string,
peersCustomZone nbdns.CustomZone,
validatedPeersMap map[string]struct{},
metrics *telemetry.AccountManagerMetrics,
) *NetworkMap {
start := time.Now()
peer := a.Peers[peerID]
if peer == nil {
return &NetworkMap{
@ -438,7 +451,7 @@ func (a *Account) GetPeerNetworkMap(ctx context.Context, peerID, dnsDomain strin
if dnsManagementStatus {
var zones []nbdns.CustomZone
peersCustomZone := getPeersCustomZone(ctx, a, dnsDomain)
if peersCustomZone.Domain != "" {
zones = append(zones, peersCustomZone)
}
@ -446,7 +459,7 @@ func (a *Account) GetPeerNetworkMap(ctx context.Context, peerID, dnsDomain strin
dnsUpdate.NameServerGroups = getPeerNSGroups(a, peerID)
}
return &NetworkMap{
nm := &NetworkMap{
Peers: peersToConnect,
Network: a.Network.Copy(),
Routes: routesUpdate,
@ -454,6 +467,60 @@ func (a *Account) GetPeerNetworkMap(ctx context.Context, peerID, dnsDomain strin
OfflinePeers: expiredPeers,
FirewallRules: firewallRules,
}
if metrics != nil {
objectCount := int64(len(peersToConnect) + len(expiredPeers) + len(routesUpdate) + len(firewallRules))
metrics.CountNetworkMapObjects(objectCount)
metrics.CountGetPeerNetworkMapDuration(time.Since(start))
}
return nm
}
func (a *Account) GetPeersCustomZone(ctx context.Context, dnsDomain string) nbdns.CustomZone {
var merr *multierror.Error
if dnsDomain == "" {
log.WithContext(ctx).Error("no dns domain is set, returning empty zone")
return nbdns.CustomZone{}
}
customZone := nbdns.CustomZone{
Domain: dns.Fqdn(dnsDomain),
Records: make([]nbdns.SimpleRecord, 0, len(a.Peers)),
}
domainSuffix := "." + dnsDomain
var sb strings.Builder
for _, peer := range a.Peers {
if peer.DNSLabel == "" {
merr = multierror.Append(merr, fmt.Errorf("peer %s has an empty DNS label", peer.Name))
continue
}
sb.Grow(len(peer.DNSLabel) + len(domainSuffix))
sb.WriteString(peer.DNSLabel)
sb.WriteString(domainSuffix)
customZone.Records = append(customZone.Records, nbdns.SimpleRecord{
Name: sb.String(),
Type: int(dns.TypeA),
Class: nbdns.DefaultClass,
TTL: defaultTTL,
RData: peer.IP.String(),
})
sb.Reset()
}
go func() {
if merr != nil {
log.WithContext(ctx).Errorf("error generating custom zone for account %s: %v", a.Id, merr)
}
}()
return customZone
}
// GetExpiredPeers returns peers that have been expired
@ -871,10 +938,18 @@ func (a *Account) UserGroupsRemoveFromPeers(userID string, groups ...string) {
}
// BuildManager creates a new DefaultAccountManager with a provided Store
func BuildManager(ctx context.Context, store Store, peersUpdateManager *PeersUpdateManager, idpManager idp.Manager,
singleAccountModeDomain string, dnsDomain string, eventStore activity.Store, geo *geolocation.Geolocation,
func BuildManager(
ctx context.Context,
store Store,
peersUpdateManager *PeersUpdateManager,
idpManager idp.Manager,
singleAccountModeDomain string,
dnsDomain string,
eventStore activity.Store,
geo *geolocation.Geolocation,
userDeleteFromIDPEnabled bool,
integratedPeerValidator integrated_validator.IntegratedValidator,
metrics telemetry.AppMetrics,
) (*DefaultAccountManager, error) {
am := &DefaultAccountManager{
Store: store,
@ -889,6 +964,7 @@ func BuildManager(ctx context.Context, store Store, peersUpdateManager *PeersUpd
peerLoginExpiry: NewDefaultScheduler(),
userDeleteFromIDPEnabled: userDeleteFromIDPEnabled,
integratedPeerValidator: integratedPeerValidator,
metrics: metrics,
}
allAccounts := store.GetAllAccounts(ctx)
// enable single account mode only if configured by user and number of existing accounts is not grater than 1

View File

@ -24,6 +24,7 @@ import (
"github.com/netbirdio/netbird/management/server/jwtclaims"
nbpeer "github.com/netbirdio/netbird/management/server/peer"
"github.com/netbirdio/netbird/management/server/posture"
"github.com/netbirdio/netbird/management/server/telemetry"
"github.com/netbirdio/netbird/route"
)
@ -410,7 +411,8 @@ func TestAccount_GetPeerNetworkMap(t *testing.T) {
validatedPeers[p] = struct{}{}
}
networkMap := account.GetPeerNetworkMap(context.Background(), testCase.peerID, "netbird.io", validatedPeers)
customZone := account.GetPeersCustomZone(context.Background(), "netbird.io")
networkMap := account.GetPeerNetworkMap(context.Background(), testCase.peerID, customZone, validatedPeers, nil)
assert.Len(t, networkMap.Peers, len(testCase.expectedPeers))
assert.Len(t, networkMap.OfflinePeers, len(testCase.expectedOfflinePeers))
}
@ -2293,7 +2295,13 @@ func TestAccount_UserGroupsRemoveFromPeers(t *testing.T) {
})
}
func createManager(t *testing.T) (*DefaultAccountManager, error) {
type TB interface {
Cleanup(func())
Helper()
TempDir() string
}
func createManager(t TB) (*DefaultAccountManager, error) {
t.Helper()
store, err := createStore(t)
@ -2302,7 +2310,12 @@ func createManager(t *testing.T) (*DefaultAccountManager, error) {
}
eventStore := &activity.InMemoryEventStore{}
manager, err := BuildManager(context.Background(), store, NewPeersUpdateManager(nil), nil, "", "netbird.cloud", eventStore, nil, false, MocIntegratedValidator{})
metrics, err := telemetry.NewDefaultAppMetrics(context.Background())
if err != nil {
return nil, err
}
manager, err := BuildManager(context.Background(), store, NewPeersUpdateManager(nil), nil, "", "netbird.cloud", eventStore, nil, false, MocIntegratedValidator{}, metrics)
if err != nil {
return nil, err
}
@ -2310,7 +2323,7 @@ func createManager(t *testing.T) (*DefaultAccountManager, error) {
return manager, nil
}
func createStore(t *testing.T) (Store, error) {
func createStore(t TB) (Store, error) {
t.Helper()
dataDir := t.TempDir()
store, cleanUp, err := NewTestStoreFromJson(context.Background(), dataDir)

View File

@ -4,8 +4,8 @@ import (
"context"
"fmt"
"strconv"
"sync"
"github.com/miekg/dns"
log "github.com/sirupsen/logrus"
nbdns "github.com/netbirdio/netbird/dns"
@ -17,6 +17,50 @@ import (
const defaultTTL = 300
// DNSConfigCache is a thread-safe cache for DNS configuration components
type DNSConfigCache struct {
CustomZones sync.Map
NameServerGroups sync.Map
}
// GetCustomZone retrieves a cached custom zone
func (c *DNSConfigCache) GetCustomZone(key string) (*proto.CustomZone, bool) {
if c == nil {
return nil, false
}
if value, ok := c.CustomZones.Load(key); ok {
return value.(*proto.CustomZone), true
}
return nil, false
}
// SetCustomZone stores a custom zone in the cache
func (c *DNSConfigCache) SetCustomZone(key string, value *proto.CustomZone) {
if c == nil {
return
}
c.CustomZones.Store(key, value)
}
// GetNameServerGroup retrieves a cached name server group
func (c *DNSConfigCache) GetNameServerGroup(key string) (*proto.NameServerGroup, bool) {
if c == nil {
return nil, false
}
if value, ok := c.NameServerGroups.Load(key); ok {
return value.(*proto.NameServerGroup), true
}
return nil, false
}
// SetNameServerGroup stores a name server group in the cache
func (c *DNSConfigCache) SetNameServerGroup(key string, value *proto.NameServerGroup) {
if c == nil {
return
}
c.NameServerGroups.Store(key, value)
}
type lookupMap map[string]struct{}
// DNSSettings defines dns settings at the account level
@ -113,69 +157,73 @@ func (am *DefaultAccountManager) SaveDNSSettings(ctx context.Context, accountID
return nil
}
func toProtocolDNSConfig(update nbdns.Config) *proto.DNSConfig {
protoUpdate := &proto.DNSConfig{ServiceEnable: update.ServiceEnable}
// toProtocolDNSConfig converts nbdns.Config to proto.DNSConfig using the cache
func toProtocolDNSConfig(update nbdns.Config, cache *DNSConfigCache) *proto.DNSConfig {
protoUpdate := &proto.DNSConfig{
ServiceEnable: update.ServiceEnable,
CustomZones: make([]*proto.CustomZone, 0, len(update.CustomZones)),
NameServerGroups: make([]*proto.NameServerGroup, 0, len(update.NameServerGroups)),
}
for _, zone := range update.CustomZones {
protoZone := &proto.CustomZone{Domain: zone.Domain}
for _, record := range zone.Records {
protoZone.Records = append(protoZone.Records, &proto.SimpleRecord{
Name: record.Name,
Type: int64(record.Type),
Class: record.Class,
TTL: int64(record.TTL),
RData: record.RData,
})
cacheKey := zone.Domain
if cachedZone, exists := cache.GetCustomZone(cacheKey); exists {
protoUpdate.CustomZones = append(protoUpdate.CustomZones, cachedZone)
} else {
protoZone := convertToProtoCustomZone(zone)
cache.SetCustomZone(cacheKey, protoZone)
protoUpdate.CustomZones = append(protoUpdate.CustomZones, protoZone)
}
protoUpdate.CustomZones = append(protoUpdate.CustomZones, protoZone)
}
for _, nsGroup := range update.NameServerGroups {
protoGroup := &proto.NameServerGroup{
Primary: nsGroup.Primary,
Domains: nsGroup.Domains,
SearchDomainsEnabled: nsGroup.SearchDomainsEnabled,
cacheKey := nsGroup.ID
if cachedGroup, exists := cache.GetNameServerGroup(cacheKey); exists {
protoUpdate.NameServerGroups = append(protoUpdate.NameServerGroups, cachedGroup)
} else {
protoGroup := convertToProtoNameServerGroup(nsGroup)
cache.SetNameServerGroup(cacheKey, protoGroup)
protoUpdate.NameServerGroups = append(protoUpdate.NameServerGroups, protoGroup)
}
for _, ns := range nsGroup.NameServers {
protoNS := &proto.NameServer{
IP: ns.IP.String(),
Port: int64(ns.Port),
NSType: int64(ns.NSType),
}
protoGroup.NameServers = append(protoGroup.NameServers, protoNS)
}
protoUpdate.NameServerGroups = append(protoUpdate.NameServerGroups, protoGroup)
}
return protoUpdate
}
func getPeersCustomZone(ctx context.Context, account *Account, dnsDomain string) nbdns.CustomZone {
if dnsDomain == "" {
log.WithContext(ctx).Errorf("no dns domain is set, returning empty zone")
return nbdns.CustomZone{}
// Helper function to convert nbdns.CustomZone to proto.CustomZone
func convertToProtoCustomZone(zone nbdns.CustomZone) *proto.CustomZone {
protoZone := &proto.CustomZone{
Domain: zone.Domain,
Records: make([]*proto.SimpleRecord, 0, len(zone.Records)),
}
customZone := nbdns.CustomZone{
Domain: dns.Fqdn(dnsDomain),
}
for _, peer := range account.Peers {
if peer.DNSLabel == "" {
log.WithContext(ctx).Errorf("found a peer with empty dns label. It was probably caused by a invalid character in its name. Peer Name: %s", peer.Name)
continue
}
customZone.Records = append(customZone.Records, nbdns.SimpleRecord{
Name: dns.Fqdn(peer.DNSLabel + "." + dnsDomain),
Type: int(dns.TypeA),
Class: nbdns.DefaultClass,
TTL: defaultTTL,
RData: peer.IP.String(),
for _, record := range zone.Records {
protoZone.Records = append(protoZone.Records, &proto.SimpleRecord{
Name: record.Name,
Type: int64(record.Type),
Class: record.Class,
TTL: int64(record.TTL),
RData: record.RData,
})
}
return protoZone
}
return customZone
// Helper function to convert nbdns.NameServerGroup to proto.NameServerGroup
func convertToProtoNameServerGroup(nsGroup *nbdns.NameServerGroup) *proto.NameServerGroup {
protoGroup := &proto.NameServerGroup{
Primary: nsGroup.Primary,
Domains: nsGroup.Domains,
SearchDomainsEnabled: nsGroup.SearchDomainsEnabled,
NameServers: make([]*proto.NameServer, 0, len(nsGroup.NameServers)),
}
for _, ns := range nsGroup.NameServers {
protoGroup.NameServers = append(protoGroup.NameServers, &proto.NameServer{
IP: ns.IP.String(),
Port: int64(ns.Port),
NSType: int64(ns.NSType),
})
}
return protoGroup
}
func getPeerNSGroups(account *Account, peerID string) []*nbdns.NameServerGroup {

View File

@ -2,9 +2,14 @@ package server
import (
"context"
"fmt"
"net/netip"
"reflect"
"testing"
nbdns "github.com/netbirdio/netbird/dns"
"github.com/netbirdio/netbird/management/server/telemetry"
"github.com/stretchr/testify/require"
"github.com/netbirdio/netbird/dns"
@ -195,7 +200,11 @@ func createDNSManager(t *testing.T) (*DefaultAccountManager, error) {
return nil, err
}
eventStore := &activity.InMemoryEventStore{}
return BuildManager(context.Background(), store, NewPeersUpdateManager(nil), nil, "", "netbird.test", eventStore, nil, false, MocIntegratedValidator{})
metrics, err := telemetry.NewDefaultAppMetrics(context.Background())
require.NoError(t, err)
return BuildManager(context.Background(), store, NewPeersUpdateManager(nil), nil, "", "netbird.test", eventStore, nil, false, MocIntegratedValidator{}, metrics)
}
func createDNSStore(t *testing.T) (Store, error) {
@ -320,3 +329,150 @@ func initTestDNSAccount(t *testing.T, am *DefaultAccountManager) (*Account, erro
return am.Store.GetAccount(context.Background(), account.Id)
}
func generateTestData(size int) nbdns.Config {
config := nbdns.Config{
ServiceEnable: true,
CustomZones: make([]nbdns.CustomZone, size),
NameServerGroups: make([]*nbdns.NameServerGroup, size),
}
for i := 0; i < size; i++ {
config.CustomZones[i] = nbdns.CustomZone{
Domain: fmt.Sprintf("domain%d.com", i),
Records: []nbdns.SimpleRecord{
{
Name: fmt.Sprintf("record%d", i),
Type: 1,
Class: "IN",
TTL: 3600,
RData: "192.168.1.1",
},
},
}
config.NameServerGroups[i] = &nbdns.NameServerGroup{
ID: fmt.Sprintf("group%d", i),
Primary: i == 0,
Domains: []string{fmt.Sprintf("domain%d.com", i)},
SearchDomainsEnabled: true,
NameServers: []nbdns.NameServer{
{
IP: netip.MustParseAddr("8.8.8.8"),
Port: 53,
NSType: 1,
},
},
}
}
return config
}
func BenchmarkToProtocolDNSConfig(b *testing.B) {
sizes := []int{10, 100, 1000}
for _, size := range sizes {
testData := generateTestData(size)
b.Run(fmt.Sprintf("WithCache-Size%d", size), func(b *testing.B) {
cache := &DNSConfigCache{}
b.ResetTimer()
for i := 0; i < b.N; i++ {
toProtocolDNSConfig(testData, cache)
}
})
b.Run(fmt.Sprintf("WithoutCache-Size%d", size), func(b *testing.B) {
b.ResetTimer()
for i := 0; i < b.N; i++ {
cache := &DNSConfigCache{}
toProtocolDNSConfig(testData, cache)
}
})
}
}
func TestToProtocolDNSConfigWithCache(t *testing.T) {
var cache DNSConfigCache
// Create two different configs
config1 := nbdns.Config{
ServiceEnable: true,
CustomZones: []nbdns.CustomZone{
{
Domain: "example.com",
Records: []nbdns.SimpleRecord{
{Name: "www", Type: 1, Class: "IN", TTL: 300, RData: "192.168.1.1"},
},
},
},
NameServerGroups: []*nbdns.NameServerGroup{
{
ID: "group1",
Name: "Group 1",
NameServers: []nbdns.NameServer{
{IP: netip.MustParseAddr("8.8.8.8"), Port: 53},
},
},
},
}
config2 := nbdns.Config{
ServiceEnable: true,
CustomZones: []nbdns.CustomZone{
{
Domain: "example.org",
Records: []nbdns.SimpleRecord{
{Name: "mail", Type: 1, Class: "IN", TTL: 300, RData: "192.168.1.2"},
},
},
},
NameServerGroups: []*nbdns.NameServerGroup{
{
ID: "group2",
Name: "Group 2",
NameServers: []nbdns.NameServer{
{IP: netip.MustParseAddr("8.8.4.4"), Port: 53},
},
},
},
}
// First run with config1
result1 := toProtocolDNSConfig(config1, &cache)
// Second run with config2
result2 := toProtocolDNSConfig(config2, &cache)
// Third run with config1 again
result3 := toProtocolDNSConfig(config1, &cache)
// Verify that result1 and result3 are identical
if !reflect.DeepEqual(result1, result3) {
t.Errorf("Results are not identical when run with the same input. Expected %v, got %v", result1, result3)
}
// Verify that result2 is different from result1 and result3
if reflect.DeepEqual(result1, result2) || reflect.DeepEqual(result2, result3) {
t.Errorf("Results should be different for different inputs")
}
// Verify that the cache contains elements from both configs
if _, exists := cache.GetCustomZone("example.com"); !exists {
t.Errorf("Cache should contain custom zone for example.com")
}
if _, exists := cache.GetCustomZone("example.org"); !exists {
t.Errorf("Cache should contain custom zone for example.org")
}
if _, exists := cache.GetNameServerGroup("group1"); !exists {
t.Errorf("Cache should contain name server group 'group1'")
}
if _, exists := cache.GetNameServerGroup("group2"); !exists {
t.Errorf("Cache should contain name server group 'group2'")
}
}

View File

@ -533,53 +533,46 @@ func toPeerConfig(peer *nbpeer.Peer, network *Network, dnsName string) *proto.Pe
}
}
func toRemotePeerConfig(peers []*nbpeer.Peer, dnsName string) []*proto.RemotePeerConfig {
remotePeers := []*proto.RemotePeerConfig{}
for _, rPeer := range peers {
fqdn := rPeer.FQDN(dnsName)
remotePeers = append(remotePeers, &proto.RemotePeerConfig{
WgPubKey: rPeer.Key,
AllowedIps: []string{fmt.Sprintf(AllowedIPsFormat, rPeer.IP)},
SshConfig: &proto.SSHConfig{SshPubKey: []byte(rPeer.SSHKey)},
Fqdn: fqdn,
})
}
return remotePeers
}
func toSyncResponse(ctx context.Context, config *Config, peer *nbpeer.Peer, turnCredentials *TURNCredentials, networkMap *NetworkMap, dnsName string, checks []*posture.Checks) *proto.SyncResponse {
wtConfig := toWiretrusteeConfig(config, turnCredentials)
pConfig := toPeerConfig(peer, networkMap.Network, dnsName)
remotePeers := toRemotePeerConfig(networkMap.Peers, dnsName)
routesUpdate := toProtocolRoutes(networkMap.Routes)
dnsUpdate := toProtocolDNSConfig(networkMap.DNSConfig)
offlinePeers := toRemotePeerConfig(networkMap.OfflinePeers, dnsName)
firewallRules := toProtocolFirewallRules(networkMap.FirewallRules)
return &proto.SyncResponse{
WiretrusteeConfig: wtConfig,
PeerConfig: pConfig,
RemotePeers: remotePeers,
RemotePeersIsEmpty: len(remotePeers) == 0,
func toSyncResponse(ctx context.Context, config *Config, peer *nbpeer.Peer, turnCredentials *TURNCredentials, networkMap *NetworkMap, dnsName string, checks []*posture.Checks, dnsCache *DNSConfigCache) *proto.SyncResponse {
response := &proto.SyncResponse{
WiretrusteeConfig: toWiretrusteeConfig(config, turnCredentials),
PeerConfig: toPeerConfig(peer, networkMap.Network, dnsName),
NetworkMap: &proto.NetworkMap{
Serial: networkMap.Network.CurrentSerial(),
PeerConfig: pConfig,
RemotePeers: remotePeers,
OfflinePeers: offlinePeers,
RemotePeersIsEmpty: len(remotePeers) == 0,
Routes: routesUpdate,
DNSConfig: dnsUpdate,
FirewallRules: firewallRules,
FirewallRulesIsEmpty: len(firewallRules) == 0,
Serial: networkMap.Network.CurrentSerial(),
Routes: toProtocolRoutes(networkMap.Routes),
DNSConfig: toProtocolDNSConfig(networkMap.DNSConfig, dnsCache),
},
Checks: toProtocolChecks(ctx, checks),
}
response.NetworkMap.PeerConfig = response.PeerConfig
allPeers := make([]*proto.RemotePeerConfig, 0, len(networkMap.Peers)+len(networkMap.OfflinePeers))
allPeers = appendRemotePeerConfig(allPeers, networkMap.Peers, dnsName)
response.RemotePeers = allPeers
response.NetworkMap.RemotePeers = allPeers
response.RemotePeersIsEmpty = len(allPeers) == 0
response.NetworkMap.RemotePeersIsEmpty = response.RemotePeersIsEmpty
response.NetworkMap.OfflinePeers = appendRemotePeerConfig(nil, networkMap.OfflinePeers, dnsName)
firewallRules := toProtocolFirewallRules(networkMap.FirewallRules)
response.NetworkMap.FirewallRules = firewallRules
response.NetworkMap.FirewallRulesIsEmpty = len(firewallRules) == 0
return response
}
func appendRemotePeerConfig(dst []*proto.RemotePeerConfig, peers []*nbpeer.Peer, dnsName string) []*proto.RemotePeerConfig {
for _, rPeer := range peers {
dst = append(dst, &proto.RemotePeerConfig{
WgPubKey: rPeer.Key,
AllowedIps: []string{rPeer.IP.String() + "/32"},
SshConfig: &proto.SSHConfig{SshPubKey: []byte(rPeer.SSHKey)},
Fqdn: rPeer.FQDN(dnsName),
})
}
return dst
}
// IsHealthy indicates whether the service is healthy
@ -597,7 +590,7 @@ func (s *GRPCServer) sendInitialSync(ctx context.Context, peerKey wgtypes.Key, p
} else {
turnCredentials = nil
}
plainResp := toSyncResponse(ctx, s.config, peer, turnCredentials, networkMap, s.accountManager.GetDNSDomain(), postureChecks)
plainResp := toSyncResponse(ctx, s.config, peer, turnCredentials, networkMap, s.accountManager.GetDNSDomain(), postureChecks, nil)
encryptedResp, err := encryption.EncryptMessage(peerKey, s.wgKey, plainResp)
if err != nil {

View File

@ -71,7 +71,8 @@ func (h *PeersHandler) getPeer(ctx context.Context, account *server.Account, pee
return
}
netMap := account.GetPeerNetworkMap(ctx, peerID, h.accountManager.GetDNSDomain(), validPeers)
customZone := account.GetPeersCustomZone(ctx, h.accountManager.GetDNSDomain())
netMap := account.GetPeerNetworkMap(ctx, peerID, customZone, validPeers, nil)
accessiblePeers := toAccessiblePeers(netMap, dnsDomain)
_, valid := validPeers[peer.ID]
@ -115,7 +116,9 @@ func (h *PeersHandler) updatePeer(ctx context.Context, account *server.Account,
util.WriteError(ctx, fmt.Errorf("internal error"), w)
return
}
netMap := account.GetPeerNetworkMap(ctx, peerID, h.accountManager.GetDNSDomain(), validPeers)
customZone := account.GetPeersCustomZone(ctx, h.accountManager.GetDNSDomain())
netMap := account.GetPeerNetworkMap(ctx, peerID, customZone, validPeers, nil)
accessiblePeers := toAccessiblePeers(netMap, dnsDomain)
_, valid := validPeers[peer.ID]
@ -194,9 +197,7 @@ func (h *PeersHandler) GetAllPeers(w http.ResponseWriter, r *http.Request) {
}
groupMinimumInfo := toGroupsInfo(account.Groups, peer.ID)
accessiblePeerNumbers, _ := h.accessiblePeersNumber(r.Context(), account, peer.ID)
respBody = append(respBody, toPeerListItemResponse(peerToReturn, groupMinimumInfo, dnsDomain, accessiblePeerNumbers))
respBody = append(respBody, toPeerListItemResponse(peerToReturn, groupMinimumInfo, dnsDomain, 0))
}
validPeersMap, err := h.accountManager.GetValidatedPeers(account)
@ -210,16 +211,6 @@ func (h *PeersHandler) GetAllPeers(w http.ResponseWriter, r *http.Request) {
util.WriteJSONObject(r.Context(), w, respBody)
}
func (h *PeersHandler) accessiblePeersNumber(ctx context.Context, account *server.Account, peerID string) (int, error) {
validatedPeersMap, err := h.accountManager.GetValidatedPeers(account)
if err != nil {
return 0, err
}
netMap := account.GetPeerNetworkMap(ctx, peerID, h.accountManager.GetDNSDomain(), validatedPeersMap)
return len(netMap.Peers) + len(netMap.OfflinePeers), nil
}
func (h *PeersHandler) setApprovalRequiredFlag(respBody []*api.PeerBatch, approvedPeersMap map[string]struct{}) {
for _, peer := range respBody {
_, ok := approvedPeersMap[peer.Id]

View File

@ -20,6 +20,7 @@ import (
"github.com/netbirdio/netbird/formatter"
mgmtProto "github.com/netbirdio/netbird/management/proto"
"github.com/netbirdio/netbird/management/server/activity"
"github.com/netbirdio/netbird/management/server/telemetry"
"github.com/netbirdio/netbird/util"
)
@ -419,8 +420,12 @@ func startManagement(t *testing.T, config *Config) (*grpc.Server, *DefaultAccoun
ctx := context.WithValue(context.Background(), formatter.ExecutionContextKey, formatter.SystemSource) //nolint:staticcheck
metrics, err := telemetry.NewDefaultAppMetrics(context.Background())
require.NoError(t, err)
accountManager, err := BuildManager(ctx, store, peersUpdateManager, nil, "", "netbird.selfhosted",
eventStore, nil, false, MocIntegratedValidator{})
eventStore, nil, false, MocIntegratedValidator{}, metrics)
if err != nil {
return nil, nil, "", err
}

View File

@ -26,6 +26,7 @@ import (
"github.com/netbirdio/netbird/management/server/activity"
"github.com/netbirdio/netbird/management/server/group"
nbpeer "github.com/netbirdio/netbird/management/server/peer"
"github.com/netbirdio/netbird/management/server/telemetry"
"github.com/netbirdio/netbird/util"
)
@ -541,8 +542,13 @@ func startServer(config *server.Config) (*grpc.Server, net.Listener) {
peersUpdateManager := server.NewPeersUpdateManager(nil)
eventStore := &activity.InMemoryEventStore{}
accountManager, err := server.BuildManager(context.Background(), store, peersUpdateManager, nil, "", "netbird.selfhosted",
eventStore, nil, false, MocIntegratedValidator{})
metrics, err := telemetry.NewDefaultAppMetrics(context.Background())
if err != nil {
log.Fatalf("failed creating metrics: %v", err)
}
accountManager, err := server.BuildManager(context.Background(), store, peersUpdateManager, nil, "", "netbird.selfhosted", eventStore, nil, false, MocIntegratedValidator{}, metrics)
if err != nil {
log.Fatalf("failed creating a manager: %v", err)
}

View File

@ -11,6 +11,7 @@ import (
"github.com/netbirdio/netbird/management/server/activity"
nbgroup "github.com/netbirdio/netbird/management/server/group"
nbpeer "github.com/netbirdio/netbird/management/server/peer"
"github.com/netbirdio/netbird/management/server/telemetry"
)
const (
@ -762,7 +763,11 @@ func createNSManager(t *testing.T) (*DefaultAccountManager, error) {
return nil, err
}
eventStore := &activity.InMemoryEventStore{}
return BuildManager(context.Background(), store, NewPeersUpdateManager(nil), nil, "", "netbird.selfhosted", eventStore, nil, false, MocIntegratedValidator{})
metrics, err := telemetry.NewDefaultAppMetrics(context.Background())
require.NoError(t, err)
return BuildManager(context.Background(), store, NewPeersUpdateManager(nil), nil, "", "netbird.selfhosted", eventStore, nil, false, MocIntegratedValidator{}, metrics)
}
func createNSStore(t *testing.T) (Store, error) {

View File

@ -5,6 +5,7 @@ import (
"fmt"
"net"
"strings"
"sync"
"time"
"github.com/rs/xid"
@ -322,7 +323,8 @@ func (am *DefaultAccountManager) GetNetworkMap(ctx context.Context, peerID strin
if err != nil {
return nil, err
}
return account.GetPeerNetworkMap(ctx, peer.ID, am.dnsDomain, validatedPeers), nil
customZone := account.GetPeersCustomZone(ctx, am.dnsDomain)
return account.GetPeerNetworkMap(ctx, peer.ID, customZone, validatedPeers, nil), nil
}
// GetPeerNetwork returns the Network for a given peer
@ -535,7 +537,8 @@ func (am *DefaultAccountManager) AddPeer(ctx context.Context, setupKey, userID s
}
postureChecks := am.getPeerPostureChecks(account, peer)
networkMap := account.GetPeerNetworkMap(ctx, newPeer.ID, am.dnsDomain, approvedPeersMap)
customZone := account.GetPeersCustomZone(ctx, am.dnsDomain)
networkMap := account.GetPeerNetworkMap(ctx, newPeer.ID, customZone, approvedPeersMap, am.metrics.AccountManagerMetrics())
return newPeer, networkMap, postureChecks, nil
}
@ -591,7 +594,8 @@ func (am *DefaultAccountManager) SyncPeer(ctx context.Context, sync PeerSync, ac
}
postureChecks = am.getPeerPostureChecks(account, peer)
return peer, account.GetPeerNetworkMap(ctx, peer.ID, am.dnsDomain, validPeersMap), postureChecks, nil
customZone := account.GetPeersCustomZone(ctx, am.dnsDomain)
return peer, account.GetPeerNetworkMap(ctx, peer.ID, customZone, validPeersMap, am.metrics.AccountManagerMetrics()), postureChecks, nil
}
// LoginPeer logs in or registers a peer.
@ -738,7 +742,8 @@ func (am *DefaultAccountManager) getValidatedPeerWithMap(ctx context.Context, is
}
postureChecks = am.getPeerPostureChecks(account, peer)
return peer, account.GetPeerNetworkMap(ctx, peer.ID, am.dnsDomain, approvedPeersMap), postureChecks, nil
customZone := account.GetPeersCustomZone(ctx, am.dnsDomain)
return peer, account.GetPeerNetworkMap(ctx, peer.ID, customZone, approvedPeersMap, am.metrics.AccountManagerMetrics()), postureChecks, nil
}
func (am *DefaultAccountManager) handleExpiredPeer(ctx context.Context, login PeerLogin, account *Account, peer *nbpeer.Peer) error {
@ -914,22 +919,45 @@ func updatePeerMeta(peer *nbpeer.Peer, meta nbpeer.PeerSystemMeta, account *Acco
// updateAccountPeers updates all peers that belong to an account.
// Should be called when changes have to be synced to peers.
func (am *DefaultAccountManager) updateAccountPeers(ctx context.Context, account *Account) {
start := time.Now()
defer func() {
if am.metrics != nil {
am.metrics.AccountManagerMetrics().CountUpdateAccountPeersDuration(time.Since(start))
}
}()
peers := account.GetPeers()
approvedPeersMap, err := am.GetValidatedPeers(account)
if err != nil {
log.WithContext(ctx).Errorf("failed send out updates to peers, failed to validate peer: %v", err)
log.WithContext(ctx).Errorf("failed to send out updates to peers, failed to validate peer: %v", err)
return
}
var wg sync.WaitGroup
semaphore := make(chan struct{}, 10)
dnsCache := &DNSConfigCache{}
customZone := account.GetPeersCustomZone(ctx, am.dnsDomain)
for _, peer := range peers {
if !am.peersUpdateManager.HasChannel(peer.ID) {
log.WithContext(ctx).Tracef("peer %s doesn't have a channel, skipping network map update", peer.ID)
continue
}
postureChecks := am.getPeerPostureChecks(account, peer)
remotePeerNetworkMap := account.GetPeerNetworkMap(ctx, peer.ID, am.dnsDomain, approvedPeersMap)
update := toSyncResponse(ctx, nil, peer, nil, remotePeerNetworkMap, am.GetDNSDomain(), postureChecks)
am.peersUpdateManager.SendUpdate(ctx, peer.ID, &UpdateMessage{Update: update})
wg.Add(1)
semaphore <- struct{}{}
go func(p *nbpeer.Peer) {
defer wg.Done()
defer func() { <-semaphore }()
postureChecks := am.getPeerPostureChecks(account, p)
remotePeerNetworkMap := account.GetPeerNetworkMap(ctx, p.ID, customZone, approvedPeersMap, am.metrics.AccountManagerMetrics())
update := toSyncResponse(ctx, nil, p, nil, remotePeerNetworkMap, am.GetDNSDomain(), postureChecks, dnsCache)
am.peersUpdateManager.SendUpdate(ctx, p.ID, &UpdateMessage{Update: update})
}(peer)
}
wg.Wait()
}

View File

@ -1,7 +1,6 @@
package peer
import (
"fmt"
"net"
"net/netip"
"slices"
@ -241,7 +240,7 @@ func (p *Peer) FQDN(dnsDomain string) string {
if dnsDomain == "" {
return ""
}
return fmt.Sprintf("%s.%s", p.DNSLabel, dnsDomain)
return p.DNSLabel + "." + dnsDomain
}
// EventMeta returns activity event meta related to the peer

View File

@ -0,0 +1,31 @@
package peer
import (
"fmt"
"testing"
)
// FQDNOld is the original implementation for benchmarking purposes
func (p *Peer) FQDNOld(dnsDomain string) string {
if dnsDomain == "" {
return ""
}
return fmt.Sprintf("%s.%s", p.DNSLabel, dnsDomain)
}
func BenchmarkFQDN(b *testing.B) {
p := &Peer{DNSLabel: "test-peer"}
dnsDomain := "example.com"
b.Run("Old", func(b *testing.B) {
for i := 0; i < b.N; i++ {
p.FQDNOld(dnsDomain)
}
})
b.Run("New", func(b *testing.B) {
for i := 0; i < b.N; i++ {
p.FQDN(dnsDomain)
}
})
}

View File

@ -2,15 +2,26 @@ package server
import (
"context"
"fmt"
"io"
"net"
"net/netip"
"os"
"testing"
"time"
"github.com/rs/xid"
log "github.com/sirupsen/logrus"
"github.com/stretchr/testify/assert"
"golang.zx2c4.com/wireguard/wgctrl/wgtypes"
nbdns "github.com/netbirdio/netbird/dns"
"github.com/netbirdio/netbird/management/domain"
"github.com/netbirdio/netbird/management/proto"
nbgroup "github.com/netbirdio/netbird/management/server/group"
nbpeer "github.com/netbirdio/netbird/management/server/peer"
"github.com/netbirdio/netbird/management/server/posture"
nbroute "github.com/netbirdio/netbird/route"
)
func TestPeer_LoginExpired(t *testing.T) {
@ -633,3 +644,354 @@ func TestDefaultAccountManager_GetPeers(t *testing.T) {
}
}
func setupTestAccountManager(b *testing.B, peers int, groups int) (*DefaultAccountManager, string, string, error) {
b.Helper()
manager, err := createManager(b)
if err != nil {
return nil, "", "", err
}
accountID := "test_account"
adminUser := "account_creator"
regularUser := "regular_user"
account := newAccountWithId(context.Background(), accountID, adminUser, "")
account.Users[regularUser] = &User{
Id: regularUser,
Role: UserRoleUser,
}
// Create peers
for i := 0; i < peers; i++ {
peerKey, _ := wgtypes.GeneratePrivateKey()
peer := &nbpeer.Peer{
ID: fmt.Sprintf("peer-%d", i),
DNSLabel: fmt.Sprintf("peer-%d", i),
Key: peerKey.PublicKey().String(),
IP: net.ParseIP(fmt.Sprintf("100.64.%d.%d", i/256, i%256)),
Status: &nbpeer.PeerStatus{},
UserID: regularUser,
}
account.Peers[peer.ID] = peer
}
// Create groups and policies
account.Policies = make([]*Policy, 0, groups)
for i := 0; i < groups; i++ {
groupID := fmt.Sprintf("group-%d", i)
group := &nbgroup.Group{
ID: groupID,
Name: fmt.Sprintf("Group %d", i),
}
for j := 0; j < peers/groups; j++ {
peerIndex := i*(peers/groups) + j
group.Peers = append(group.Peers, fmt.Sprintf("peer-%d", peerIndex))
}
account.Groups[groupID] = group
// Create a policy for this group
policy := &Policy{
ID: fmt.Sprintf("policy-%d", i),
Name: fmt.Sprintf("Policy for Group %d", i),
Enabled: true,
Rules: []*PolicyRule{
{
ID: fmt.Sprintf("rule-%d", i),
Name: fmt.Sprintf("Rule for Group %d", i),
Enabled: true,
Sources: []string{groupID},
Destinations: []string{groupID},
Bidirectional: true,
Protocol: PolicyRuleProtocolALL,
Action: PolicyTrafficActionAccept,
},
},
}
account.Policies = append(account.Policies, policy)
}
account.PostureChecks = []*posture.Checks{
{
ID: "PostureChecksAll",
Name: "All",
Checks: posture.ChecksDefinition{
NBVersionCheck: &posture.NBVersionCheck{
MinVersion: "0.0.1",
},
},
},
}
err = manager.Store.SaveAccount(context.Background(), account)
if err != nil {
return nil, "", "", err
}
return manager, accountID, regularUser, nil
}
func BenchmarkGetPeers(b *testing.B) {
benchCases := []struct {
name string
peers int
groups int
}{
{"Small", 50, 5},
{"Medium", 500, 10},
{"Large", 5000, 20},
{"Small single", 50, 1},
{"Medium single", 500, 1},
{"Large 5", 5000, 5},
}
log.SetOutput(io.Discard)
defer log.SetOutput(os.Stderr)
for _, bc := range benchCases {
b.Run(bc.name, func(b *testing.B) {
manager, accountID, userID, err := setupTestAccountManager(b, bc.peers, bc.groups)
if err != nil {
b.Fatalf("Failed to setup test account manager: %v", err)
}
b.ResetTimer()
for i := 0; i < b.N; i++ {
_, err := manager.GetPeers(context.Background(), accountID, userID)
if err != nil {
b.Fatalf("GetPeers failed: %v", err)
}
}
})
}
}
func BenchmarkUpdateAccountPeers(b *testing.B) {
benchCases := []struct {
name string
peers int
groups int
}{
{"Small", 50, 5},
{"Medium", 500, 10},
{"Large", 5000, 20},
{"Small single", 50, 1},
{"Medium single", 500, 1},
{"Large 5", 5000, 5},
}
log.SetOutput(io.Discard)
defer log.SetOutput(os.Stderr)
for _, bc := range benchCases {
b.Run(bc.name, func(b *testing.B) {
manager, accountID, _, err := setupTestAccountManager(b, bc.peers, bc.groups)
if err != nil {
b.Fatalf("Failed to setup test account manager: %v", err)
}
ctx := context.Background()
account, err := manager.Store.GetAccount(ctx, accountID)
if err != nil {
b.Fatalf("Failed to get account: %v", err)
}
peerChannels := make(map[string]chan *UpdateMessage)
for peerID := range account.Peers {
peerChannels[peerID] = make(chan *UpdateMessage, channelBufferSize)
}
manager.peersUpdateManager.peerChannels = peerChannels
b.ResetTimer()
start := time.Now()
for i := 0; i < b.N; i++ {
manager.updateAccountPeers(ctx, account)
}
duration := time.Since(start)
b.ReportMetric(float64(duration.Nanoseconds())/float64(b.N)/1e6, "ms/op")
b.ReportMetric(0, "ns/op")
})
}
}
func TestToSyncResponse(t *testing.T) {
_, ipnet, err := net.ParseCIDR("192.168.1.0/24")
if err != nil {
t.Fatal(err)
}
domainList, err := domain.FromStringList([]string{"example.com"})
if err != nil {
t.Fatal(err)
}
config := &Config{
Signal: &Host{
Proto: "https",
URI: "signal.uri",
Username: "",
Password: "",
},
Stuns: []*Host{{URI: "stun.uri", Proto: UDP}},
TURNConfig: &TURNConfig{
Turns: []*Host{{URI: "turn.uri", Proto: UDP, Username: "turn-user", Password: "turn-pass"}},
},
}
peer := &nbpeer.Peer{
IP: net.ParseIP("192.168.1.1"),
SSHEnabled: true,
Key: "peer-key",
DNSLabel: "peer1",
SSHKey: "peer1-ssh-key",
}
turnCredentials := &TURNCredentials{
Username: "turn-user",
Password: "turn-pass",
}
networkMap := &NetworkMap{
Network: &Network{Net: *ipnet, Serial: 1000},
Peers: []*nbpeer.Peer{{IP: net.ParseIP("192.168.1.2"), Key: "peer2-key", DNSLabel: "peer2", SSHEnabled: true, SSHKey: "peer2-ssh-key"}},
OfflinePeers: []*nbpeer.Peer{{IP: net.ParseIP("192.168.1.3"), Key: "peer3-key", DNSLabel: "peer3", SSHEnabled: true, SSHKey: "peer3-ssh-key"}},
Routes: []*nbroute.Route{
{
ID: "route1",
Network: netip.MustParsePrefix("10.0.0.0/24"),
Domains: domainList,
KeepRoute: true,
NetID: "route1",
Peer: "peer1",
NetworkType: 1,
Masquerade: true,
Metric: 9999,
Enabled: true,
},
},
DNSConfig: nbdns.Config{
ServiceEnable: true,
NameServerGroups: []*nbdns.NameServerGroup{
{
NameServers: []nbdns.NameServer{{
IP: netip.MustParseAddr("8.8.8.8"),
NSType: nbdns.UDPNameServerType,
Port: nbdns.DefaultDNSPort,
}},
Primary: true,
Domains: []string{"example.com"},
Enabled: true,
SearchDomainsEnabled: true,
},
{
ID: "ns1",
NameServers: []nbdns.NameServer{{
IP: netip.MustParseAddr("1.1.1.1"),
NSType: nbdns.UDPNameServerType,
Port: nbdns.DefaultDNSPort,
}},
Groups: []string{"group1"},
Primary: true,
Domains: []string{"example.com"},
Enabled: true,
SearchDomainsEnabled: true,
},
},
CustomZones: []nbdns.CustomZone{{Domain: "example.com", Records: []nbdns.SimpleRecord{{Name: "example.com", Type: 1, Class: "IN", TTL: 60, RData: "100.64.0.1"}}}},
},
FirewallRules: []*FirewallRule{
{PeerIP: "192.168.1.2", Direction: firewallRuleDirectionIN, Action: string(PolicyTrafficActionAccept), Protocol: string(PolicyRuleProtocolTCP), Port: "80"},
},
}
dnsName := "example.com"
checks := []*posture.Checks{
{
Checks: posture.ChecksDefinition{
ProcessCheck: &posture.ProcessCheck{
Processes: []posture.Process{{LinuxPath: "/usr/bin/netbird"}},
},
},
},
}
dnsCache := &DNSConfigCache{}
response := toSyncResponse(context.Background(), config, peer, turnCredentials, networkMap, dnsName, checks, dnsCache)
assert.NotNil(t, response)
// assert peer config
assert.Equal(t, "192.168.1.1/24", response.PeerConfig.Address)
assert.Equal(t, "peer1.example.com", response.PeerConfig.Fqdn)
assert.Equal(t, true, response.PeerConfig.SshConfig.SshEnabled)
// assert wiretrustee config
assert.Equal(t, "signal.uri", response.WiretrusteeConfig.Signal.Uri)
assert.Equal(t, proto.HostConfig_HTTPS, response.WiretrusteeConfig.Signal.GetProtocol())
assert.Equal(t, "stun.uri", response.WiretrusteeConfig.Stuns[0].Uri)
assert.Equal(t, "turn.uri", response.WiretrusteeConfig.Turns[0].HostConfig.GetUri())
assert.Equal(t, "turn-user", response.WiretrusteeConfig.Turns[0].User)
assert.Equal(t, "turn-pass", response.WiretrusteeConfig.Turns[0].Password)
// assert RemotePeers
assert.Equal(t, 1, len(response.RemotePeers))
assert.Equal(t, "192.168.1.2/32", response.RemotePeers[0].AllowedIps[0])
assert.Equal(t, "peer2-key", response.RemotePeers[0].WgPubKey)
assert.Equal(t, "peer2.example.com", response.RemotePeers[0].GetFqdn())
assert.Equal(t, false, response.RemotePeers[0].GetSshConfig().GetSshEnabled())
assert.Equal(t, []byte("peer2-ssh-key"), response.RemotePeers[0].GetSshConfig().GetSshPubKey())
// assert network map
assert.Equal(t, uint64(1000), response.NetworkMap.Serial)
assert.Equal(t, "192.168.1.1/24", response.NetworkMap.PeerConfig.Address)
assert.Equal(t, "peer1.example.com", response.NetworkMap.PeerConfig.Fqdn)
assert.Equal(t, true, response.NetworkMap.PeerConfig.SshConfig.SshEnabled)
// assert network map RemotePeers
assert.Equal(t, 1, len(response.NetworkMap.RemotePeers))
assert.Equal(t, "192.168.1.2/32", response.NetworkMap.RemotePeers[0].AllowedIps[0])
assert.Equal(t, "peer2-key", response.NetworkMap.RemotePeers[0].WgPubKey)
assert.Equal(t, "peer2.example.com", response.NetworkMap.RemotePeers[0].GetFqdn())
assert.Equal(t, []byte("peer2-ssh-key"), response.NetworkMap.RemotePeers[0].GetSshConfig().GetSshPubKey())
// assert network map OfflinePeers
assert.Equal(t, 1, len(response.NetworkMap.OfflinePeers))
assert.Equal(t, "192.168.1.3/32", response.NetworkMap.OfflinePeers[0].AllowedIps[0])
assert.Equal(t, "peer3-key", response.NetworkMap.OfflinePeers[0].WgPubKey)
assert.Equal(t, "peer3.example.com", response.NetworkMap.OfflinePeers[0].GetFqdn())
assert.Equal(t, []byte("peer3-ssh-key"), response.NetworkMap.OfflinePeers[0].GetSshConfig().GetSshPubKey())
// assert network map Routes
assert.Equal(t, 1, len(response.NetworkMap.Routes))
assert.Equal(t, "10.0.0.0/24", response.NetworkMap.Routes[0].Network)
assert.Equal(t, "route1", response.NetworkMap.Routes[0].ID)
assert.Equal(t, "peer1", response.NetworkMap.Routes[0].Peer)
assert.Equal(t, "example.com", response.NetworkMap.Routes[0].Domains[0])
assert.Equal(t, true, response.NetworkMap.Routes[0].KeepRoute)
assert.Equal(t, true, response.NetworkMap.Routes[0].Masquerade)
assert.Equal(t, int64(9999), response.NetworkMap.Routes[0].Metric)
assert.Equal(t, int64(1), response.NetworkMap.Routes[0].NetworkType)
assert.Equal(t, "route1", response.NetworkMap.Routes[0].NetID)
// assert network map DNSConfig
assert.Equal(t, true, response.NetworkMap.DNSConfig.ServiceEnable)
assert.Equal(t, 1, len(response.NetworkMap.DNSConfig.CustomZones))
assert.Equal(t, 2, len(response.NetworkMap.DNSConfig.NameServerGroups))
// assert network map DNSConfig.CustomZones
assert.Equal(t, "example.com", response.NetworkMap.DNSConfig.CustomZones[0].Domain)
assert.Equal(t, 1, len(response.NetworkMap.DNSConfig.CustomZones[0].Records))
assert.Equal(t, "example.com", response.NetworkMap.DNSConfig.CustomZones[0].Records[0].Name)
assert.Equal(t, int64(1), response.NetworkMap.DNSConfig.CustomZones[0].Records[0].Type)
assert.Equal(t, "IN", response.NetworkMap.DNSConfig.CustomZones[0].Records[0].Class)
assert.Equal(t, int64(60), response.NetworkMap.DNSConfig.CustomZones[0].Records[0].TTL)
assert.Equal(t, "100.64.0.1", response.NetworkMap.DNSConfig.CustomZones[0].Records[0].RData)
// assert network map DNSConfig.NameServerGroups
assert.Equal(t, true, response.NetworkMap.DNSConfig.NameServerGroups[0].Primary)
assert.Equal(t, true, response.NetworkMap.DNSConfig.NameServerGroups[0].SearchDomainsEnabled)
assert.Equal(t, "example.com", response.NetworkMap.DNSConfig.NameServerGroups[0].Domains[0])
assert.Equal(t, "8.8.8.8", response.NetworkMap.DNSConfig.NameServerGroups[0].NameServers[0].GetIP())
assert.Equal(t, int64(1), response.NetworkMap.DNSConfig.NameServerGroups[0].NameServers[0].GetNSType())
assert.Equal(t, int64(53), response.NetworkMap.DNSConfig.NameServerGroups[0].NameServers[0].GetPort())
// assert network map Firewall
assert.Equal(t, 1, len(response.NetworkMap.FirewallRules))
assert.Equal(t, "192.168.1.2", response.NetworkMap.FirewallRules[0].PeerIP)
assert.Equal(t, proto.FirewallRule_IN, response.NetworkMap.FirewallRules[0].Direction)
assert.Equal(t, proto.FirewallRule_ACCEPT, response.NetworkMap.FirewallRules[0].Action)
assert.Equal(t, proto.FirewallRule_TCP, response.NetworkMap.FirewallRules[0].Protocol)
assert.Equal(t, "80", response.NetworkMap.FirewallRules[0].Port)
// assert posture checks
assert.Equal(t, 1, len(response.Checks))
assert.Equal(t, "/usr/bin/netbird", response.Checks[0].Files[0])
}

View File

@ -213,7 +213,6 @@ type FirewallRule struct {
//
// This function returns the list of peers and firewall rules that are applicable to a given peer.
func (a *Account) getPeerConnectionResources(ctx context.Context, peerID string, validatedPeersMap map[string]struct{}) ([]*nbpeer.Peer, []*FirewallRule) {
generateResources, getAccumulatedResources := a.connResourcesGenerator(ctx)
for _, policy := range a.Policies {
if !policy.Enabled {
@ -225,8 +224,8 @@ func (a *Account) getPeerConnectionResources(ctx context.Context, peerID string,
continue
}
sourcePeers, peerInSources := getAllPeersFromGroups(ctx, a, rule.Sources, peerID, policy.SourcePostureChecks, validatedPeersMap)
destinationPeers, peerInDestinations := getAllPeersFromGroups(ctx, a, rule.Destinations, peerID, nil, validatedPeersMap)
sourcePeers, peerInSources := a.getAllPeersFromGroups(ctx, rule.Sources, peerID, policy.SourcePostureChecks, validatedPeersMap)
destinationPeers, peerInDestinations := a.getAllPeersFromGroups(ctx, rule.Destinations, peerID, nil, validatedPeersMap)
if rule.Bidirectional {
if peerInSources {
@ -290,8 +289,8 @@ func (a *Account) connResourcesGenerator(ctx context.Context) (func(*PolicyRule,
fr.PeerIP = "0.0.0.0"
}
ruleID := (rule.ID + fr.PeerIP + strconv.Itoa(direction) +
fr.Protocol + fr.Action + strings.Join(rule.Ports, ","))
ruleID := rule.ID + fr.PeerIP + strconv.Itoa(direction) +
fr.Protocol + fr.Action + strings.Join(rule.Ports, ",")
if _, ok := rulesExists[ruleID]; ok {
continue
}
@ -491,23 +490,23 @@ func toProtocolFirewallRules(update []*FirewallRule) []*proto.FirewallRule {
//
// Important: Posture checks are applicable only to source group peers,
// for destination group peers, call this method with an empty list of sourcePostureChecksIDs
func getAllPeersFromGroups(ctx context.Context, account *Account, groups []string, peerID string, sourcePostureChecksIDs []string, validatedPeersMap map[string]struct{}) ([]*nbpeer.Peer, bool) {
func (a *Account) getAllPeersFromGroups(ctx context.Context, groups []string, peerID string, sourcePostureChecksIDs []string, validatedPeersMap map[string]struct{}) ([]*nbpeer.Peer, bool) {
peerInGroups := false
filteredPeers := make([]*nbpeer.Peer, 0, len(groups))
for _, g := range groups {
group, ok := account.Groups[g]
group, ok := a.Groups[g]
if !ok {
continue
}
for _, p := range group.Peers {
peer, ok := account.Peers[p]
peer, ok := a.Peers[p]
if !ok || peer == nil {
continue
}
// validate the peer based on policy posture checks applied
isValid := account.validatePostureChecksOnPeer(ctx, sourcePostureChecksIDs, peer.ID)
isValid := a.validatePostureChecksOnPeer(ctx, sourcePostureChecksIDs, peer.ID)
if !isValid {
continue
}
@ -535,7 +534,7 @@ func (a *Account) validatePostureChecksOnPeer(ctx context.Context, sourcePosture
}
for _, postureChecksID := range sourcePostureChecksID {
postureChecks := getPostureChecks(a, postureChecksID)
postureChecks := a.getPostureChecks(postureChecksID)
if postureChecks == nil {
continue
}
@ -553,8 +552,8 @@ func (a *Account) validatePostureChecksOnPeer(ctx context.Context, sourcePosture
return true
}
func getPostureChecks(account *Account, postureChecksID string) *posture.Checks {
for _, postureChecks := range account.PostureChecks {
func (a *Account) getPostureChecks(postureChecksID string) *posture.Checks {
for _, postureChecks := range a.PostureChecks {
if postureChecks.ID == postureChecksID {
return postureChecks
}

View File

@ -13,6 +13,7 @@ import (
"github.com/netbirdio/netbird/management/server/activity"
nbgroup "github.com/netbirdio/netbird/management/server/group"
nbpeer "github.com/netbirdio/netbird/management/server/peer"
"github.com/netbirdio/netbird/management/server/telemetry"
"github.com/netbirdio/netbird/route"
)
@ -1233,7 +1234,11 @@ func createRouterManager(t *testing.T) (*DefaultAccountManager, error) {
return nil, err
}
eventStore := &activity.InMemoryEventStore{}
return BuildManager(context.Background(), store, NewPeersUpdateManager(nil), nil, "", "netbird.selfhosted", eventStore, nil, false, MocIntegratedValidator{})
metrics, err := telemetry.NewDefaultAppMetrics(context.Background())
require.NoError(t, err)
return BuildManager(context.Background(), store, NewPeersUpdateManager(nil), nil, "", "netbird.selfhosted", eventStore, nil, false, MocIntegratedValidator{}, metrics)
}
func createRouterStore(t *testing.T) (Store, error) {

View File

@ -0,0 +1,69 @@
package telemetry
import (
"context"
"time"
"go.opentelemetry.io/otel/metric"
)
// AccountManagerMetrics represents all metrics related to the AccountManager
type AccountManagerMetrics struct {
ctx context.Context
updateAccountPeersDurationMs metric.Float64Histogram
getPeerNetworkMapDurationMs metric.Float64Histogram
networkMapObjectCount metric.Int64Histogram
}
// NewAccountManagerMetrics creates an instance of AccountManagerMetrics
func NewAccountManagerMetrics(ctx context.Context, meter metric.Meter) (*AccountManagerMetrics, error) {
updateAccountPeersDurationMs, err := meter.Float64Histogram("management.account.update.account.peers.duration.ms",
metric.WithUnit("milliseconds"),
metric.WithExplicitBucketBoundaries(
0.5, 1, 2.5, 5, 10, 25, 50, 100, 250, 500, 1000, 2500, 5000, 10000, 30000,
))
if err != nil {
return nil, err
}
getPeerNetworkMapDurationMs, err := meter.Float64Histogram("management.account.get.peer.network.map.duration.ms",
metric.WithUnit("milliseconds"),
metric.WithExplicitBucketBoundaries(
0.1, 0.5, 1, 2.5, 5, 10, 25, 50, 100, 250, 500, 1000,
))
if err != nil {
return nil, err
}
networkMapObjectCount, err := meter.Int64Histogram("management.account.network.map.object.count",
metric.WithUnit("objects"),
metric.WithExplicitBucketBoundaries(
50, 100, 200, 500, 1000, 2500, 5000, 10000,
))
if err != nil {
return nil, err
}
return &AccountManagerMetrics{
ctx: ctx,
getPeerNetworkMapDurationMs: getPeerNetworkMapDurationMs,
updateAccountPeersDurationMs: updateAccountPeersDurationMs,
networkMapObjectCount: networkMapObjectCount,
}, nil
}
// CountUpdateAccountPeersDuration counts the duration of updating account peers
func (metrics *AccountManagerMetrics) CountUpdateAccountPeersDuration(duration time.Duration) {
metrics.updateAccountPeersDurationMs.Record(metrics.ctx, float64(duration.Nanoseconds())/1e6)
}
// CountGetPeerNetworkMapDuration counts the duration of getting the peer network map
func (metrics *AccountManagerMetrics) CountGetPeerNetworkMapDuration(duration time.Duration) {
metrics.getPeerNetworkMapDurationMs.Record(metrics.ctx, float64(duration.Nanoseconds())/1e6)
}
// CountNetworkMapObjects counts the number of network map objects
func (metrics *AccountManagerMetrics) CountNetworkMapObjects(count int64) {
metrics.networkMapObjectCount.Record(metrics.ctx, count)
}

View File

@ -20,14 +20,15 @@ const defaultEndpoint = "/metrics"
// MockAppMetrics mocks the AppMetrics interface
type MockAppMetrics struct {
GetMeterFunc func() metric2.Meter
CloseFunc func() error
ExposeFunc func(ctx context.Context, port int, endpoint string) error
IDPMetricsFunc func() *IDPMetrics
HTTPMiddlewareFunc func() *HTTPMiddleware
GRPCMetricsFunc func() *GRPCMetrics
StoreMetricsFunc func() *StoreMetrics
UpdateChannelMetricsFunc func() *UpdateChannelMetrics
GetMeterFunc func() metric2.Meter
CloseFunc func() error
ExposeFunc func(ctx context.Context, port int, endpoint string) error
IDPMetricsFunc func() *IDPMetrics
HTTPMiddlewareFunc func() *HTTPMiddleware
GRPCMetricsFunc func() *GRPCMetrics
StoreMetricsFunc func() *StoreMetrics
UpdateChannelMetricsFunc func() *UpdateChannelMetrics
AddAccountManagerMetricsFunc func() *AccountManagerMetrics
}
// GetMeter mocks the GetMeter function of the AppMetrics interface
@ -94,6 +95,14 @@ func (mock *MockAppMetrics) UpdateChannelMetrics() *UpdateChannelMetrics {
return nil
}
// AccountManagerMetrics mocks the MockAppMetrics function of the AccountManagerMetrics interface
func (mock *MockAppMetrics) AccountManagerMetrics() *AccountManagerMetrics {
if mock.AddAccountManagerMetricsFunc != nil {
return mock.AddAccountManagerMetricsFunc()
}
return nil
}
// AppMetrics is metrics interface
type AppMetrics interface {
GetMeter() metric2.Meter
@ -104,19 +113,21 @@ type AppMetrics interface {
GRPCMetrics() *GRPCMetrics
StoreMetrics() *StoreMetrics
UpdateChannelMetrics() *UpdateChannelMetrics
AccountManagerMetrics() *AccountManagerMetrics
}
// defaultAppMetrics are core application metrics based on OpenTelemetry https://opentelemetry.io/
type defaultAppMetrics struct {
// Meter can be used by different application parts to create counters and measure things
Meter metric2.Meter
listener net.Listener
ctx context.Context
idpMetrics *IDPMetrics
httpMiddleware *HTTPMiddleware
grpcMetrics *GRPCMetrics
storeMetrics *StoreMetrics
updateChannelMetrics *UpdateChannelMetrics
Meter metric2.Meter
listener net.Listener
ctx context.Context
idpMetrics *IDPMetrics
httpMiddleware *HTTPMiddleware
grpcMetrics *GRPCMetrics
storeMetrics *StoreMetrics
updateChannelMetrics *UpdateChannelMetrics
accountManagerMetrics *AccountManagerMetrics
}
// IDPMetrics returns metrics for the idp package
@ -144,6 +155,11 @@ func (appMetrics *defaultAppMetrics) UpdateChannelMetrics() *UpdateChannelMetric
return appMetrics.updateChannelMetrics
}
// AccountManagerMetrics returns metrics for the account manager
func (appMetrics *defaultAppMetrics) AccountManagerMetrics() *AccountManagerMetrics {
return appMetrics.accountManagerMetrics
}
// Close stop application metrics HTTP handler and closes listener.
func (appMetrics *defaultAppMetrics) Close() error {
if appMetrics.listener == nil {
@ -220,13 +236,19 @@ func NewDefaultAppMetrics(ctx context.Context) (AppMetrics, error) {
return nil, err
}
accountManagerMetrics, err := NewAccountManagerMetrics(ctx, meter)
if err != nil {
return nil, err
}
return &defaultAppMetrics{
Meter: meter,
ctx: ctx,
idpMetrics: idpMetrics,
httpMiddleware: middleware,
grpcMetrics: grpcMetrics,
storeMetrics: storeMetrics,
updateChannelMetrics: updateChannelMetrics,
Meter: meter,
ctx: ctx,
idpMetrics: idpMetrics,
httpMiddleware: middleware,
grpcMetrics: grpcMetrics,
storeMetrics: storeMetrics,
updateChannelMetrics: updateChannelMetrics,
accountManagerMetrics: accountManagerMetrics,
}, nil
}