mirror of
https://github.com/netbirdio/netbird.git
synced 2025-06-20 01:38:41 +02:00
[management] Improve mgmt sync performance (#2363)
This commit is contained in:
parent
54d896846b
commit
ac0d5ff9f3
@ -11,6 +11,7 @@ import (
|
|||||||
"go.opentelemetry.io/otel"
|
"go.opentelemetry.io/otel"
|
||||||
|
|
||||||
"github.com/netbirdio/netbird/management/server/activity"
|
"github.com/netbirdio/netbird/management/server/activity"
|
||||||
|
"github.com/netbirdio/netbird/management/server/telemetry"
|
||||||
|
|
||||||
"github.com/netbirdio/netbird/util"
|
"github.com/netbirdio/netbird/util"
|
||||||
|
|
||||||
@ -71,6 +72,7 @@ func startSignal(t *testing.T) (*grpc.Server, net.Listener) {
|
|||||||
|
|
||||||
func startManagement(t *testing.T, config *mgmt.Config) (*grpc.Server, net.Listener) {
|
func startManagement(t *testing.T, config *mgmt.Config) (*grpc.Server, net.Listener) {
|
||||||
t.Helper()
|
t.Helper()
|
||||||
|
|
||||||
lis, err := net.Listen("tcp", ":0")
|
lis, err := net.Listen("tcp", ":0")
|
||||||
if err != nil {
|
if err != nil {
|
||||||
t.Fatal(err)
|
t.Fatal(err)
|
||||||
@ -88,7 +90,11 @@ func startManagement(t *testing.T, config *mgmt.Config) (*grpc.Server, net.Liste
|
|||||||
return nil, nil
|
return nil, nil
|
||||||
}
|
}
|
||||||
iv, _ := integrations.NewIntegratedValidator(context.Background(), eventStore)
|
iv, _ := integrations.NewIntegratedValidator(context.Background(), eventStore)
|
||||||
accountManager, err := mgmt.BuildManager(context.Background(), store, peersUpdateManager, nil, "", "netbird.selfhosted", eventStore, nil, false, iv)
|
|
||||||
|
metrics, err := telemetry.NewDefaultAppMetrics(context.Background())
|
||||||
|
require.NoError(t, err)
|
||||||
|
|
||||||
|
accountManager, err := mgmt.BuildManager(context.Background(), store, peersUpdateManager, nil, "", "netbird.selfhosted", eventStore, nil, false, iv, metrics)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
t.Fatal(err)
|
t.Fatal(err)
|
||||||
}
|
}
|
||||||
|
@ -36,6 +36,7 @@ import (
|
|||||||
mgmtProto "github.com/netbirdio/netbird/management/proto"
|
mgmtProto "github.com/netbirdio/netbird/management/proto"
|
||||||
"github.com/netbirdio/netbird/management/server"
|
"github.com/netbirdio/netbird/management/server"
|
||||||
"github.com/netbirdio/netbird/management/server/activity"
|
"github.com/netbirdio/netbird/management/server/activity"
|
||||||
|
"github.com/netbirdio/netbird/management/server/telemetry"
|
||||||
"github.com/netbirdio/netbird/route"
|
"github.com/netbirdio/netbird/route"
|
||||||
signal "github.com/netbirdio/netbird/signal/client"
|
signal "github.com/netbirdio/netbird/signal/client"
|
||||||
"github.com/netbirdio/netbird/signal/proto"
|
"github.com/netbirdio/netbird/signal/proto"
|
||||||
@ -1069,7 +1070,11 @@ func startManagement(t *testing.T, dataDir string) (*grpc.Server, string, error)
|
|||||||
return nil, "", err
|
return nil, "", err
|
||||||
}
|
}
|
||||||
ia, _ := integrations.NewIntegratedValidator(context.Background(), eventStore)
|
ia, _ := integrations.NewIntegratedValidator(context.Background(), eventStore)
|
||||||
accountManager, err := server.BuildManager(context.Background(), store, peersUpdateManager, nil, "", "netbird.selfhosted", eventStore, nil, false, ia)
|
|
||||||
|
metrics, err := telemetry.NewDefaultAppMetrics(context.Background())
|
||||||
|
require.NoError(t, err)
|
||||||
|
|
||||||
|
accountManager, err := server.BuildManager(context.Background(), store, peersUpdateManager, nil, "", "netbird.selfhosted", eventStore, nil, false, ia, metrics)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return nil, "", err
|
return nil, "", err
|
||||||
}
|
}
|
||||||
|
@ -19,6 +19,7 @@ import (
|
|||||||
mgmtProto "github.com/netbirdio/netbird/management/proto"
|
mgmtProto "github.com/netbirdio/netbird/management/proto"
|
||||||
"github.com/netbirdio/netbird/management/server"
|
"github.com/netbirdio/netbird/management/server"
|
||||||
"github.com/netbirdio/netbird/management/server/activity"
|
"github.com/netbirdio/netbird/management/server/activity"
|
||||||
|
"github.com/netbirdio/netbird/management/server/telemetry"
|
||||||
"github.com/netbirdio/netbird/signal/proto"
|
"github.com/netbirdio/netbird/signal/proto"
|
||||||
signalServer "github.com/netbirdio/netbird/signal/server"
|
signalServer "github.com/netbirdio/netbird/signal/server"
|
||||||
)
|
)
|
||||||
@ -120,7 +121,11 @@ func startManagement(t *testing.T, signalAddr string, counter *int) (*grpc.Serve
|
|||||||
return nil, "", err
|
return nil, "", err
|
||||||
}
|
}
|
||||||
ia, _ := integrations.NewIntegratedValidator(context.Background(), eventStore)
|
ia, _ := integrations.NewIntegratedValidator(context.Background(), eventStore)
|
||||||
accountManager, err := server.BuildManager(context.Background(), store, peersUpdateManager, nil, "", "netbird.selfhosted", eventStore, nil, false, ia)
|
|
||||||
|
metrics, err := telemetry.NewDefaultAppMetrics(context.Background())
|
||||||
|
require.NoError(t, err)
|
||||||
|
|
||||||
|
accountManager, err := server.BuildManager(context.Background(), store, peersUpdateManager, nil, "", "netbird.selfhosted", eventStore, nil, false, ia, metrics)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return nil, "", err
|
return nil, "", err
|
||||||
}
|
}
|
||||||
|
@ -9,7 +9,10 @@ import (
|
|||||||
"testing"
|
"testing"
|
||||||
"time"
|
"time"
|
||||||
|
|
||||||
|
"github.com/stretchr/testify/require"
|
||||||
|
|
||||||
"github.com/netbirdio/netbird/management/server/activity"
|
"github.com/netbirdio/netbird/management/server/activity"
|
||||||
|
"github.com/netbirdio/netbird/management/server/telemetry"
|
||||||
|
|
||||||
"github.com/netbirdio/netbird/client/system"
|
"github.com/netbirdio/netbird/client/system"
|
||||||
|
|
||||||
@ -71,7 +74,11 @@ func startManagement(t *testing.T) (*grpc.Server, net.Listener) {
|
|||||||
peersUpdateManager := mgmt.NewPeersUpdateManager(nil)
|
peersUpdateManager := mgmt.NewPeersUpdateManager(nil)
|
||||||
eventStore := &activity.InMemoryEventStore{}
|
eventStore := &activity.InMemoryEventStore{}
|
||||||
ia, _ := integrations.NewIntegratedValidator(context.Background(), eventStore)
|
ia, _ := integrations.NewIntegratedValidator(context.Background(), eventStore)
|
||||||
accountManager, err := mgmt.BuildManager(context.Background(), store, peersUpdateManager, nil, "", "netbird.selfhosted", eventStore, nil, false, ia)
|
|
||||||
|
metrics, err := telemetry.NewDefaultAppMetrics(context.Background())
|
||||||
|
require.NoError(t, err)
|
||||||
|
|
||||||
|
accountManager, err := mgmt.BuildManager(context.Background(), store, peersUpdateManager, nil, "", "netbird.selfhosted", eventStore, nil, false, ia, metrics)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
t.Fatal(err)
|
t.Fatal(err)
|
||||||
}
|
}
|
||||||
|
@ -190,7 +190,7 @@ var (
|
|||||||
return fmt.Errorf("failed to initialize integrated peer validator: %v", err)
|
return fmt.Errorf("failed to initialize integrated peer validator: %v", err)
|
||||||
}
|
}
|
||||||
accountManager, err := server.BuildManager(ctx, store, peersUpdateManager, idpManager, mgmtSingleAccModeDomain,
|
accountManager, err := server.BuildManager(ctx, store, peersUpdateManager, idpManager, mgmtSingleAccModeDomain,
|
||||||
dnsDomain, eventStore, geo, userDeleteFromIDPEnabled, integratedPeerValidator)
|
dnsDomain, eventStore, geo, userDeleteFromIDPEnabled, integratedPeerValidator, appMetrics)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return fmt.Errorf("failed to build default manager: %v", err)
|
return fmt.Errorf("failed to build default manager: %v", err)
|
||||||
}
|
}
|
||||||
|
@ -18,6 +18,8 @@ import (
|
|||||||
|
|
||||||
"github.com/eko/gocache/v3/cache"
|
"github.com/eko/gocache/v3/cache"
|
||||||
cacheStore "github.com/eko/gocache/v3/store"
|
cacheStore "github.com/eko/gocache/v3/store"
|
||||||
|
"github.com/hashicorp/go-multierror"
|
||||||
|
"github.com/miekg/dns"
|
||||||
gocache "github.com/patrickmn/go-cache"
|
gocache "github.com/patrickmn/go-cache"
|
||||||
"github.com/rs/xid"
|
"github.com/rs/xid"
|
||||||
log "github.com/sirupsen/logrus"
|
log "github.com/sirupsen/logrus"
|
||||||
@ -37,6 +39,7 @@ import (
|
|||||||
nbpeer "github.com/netbirdio/netbird/management/server/peer"
|
nbpeer "github.com/netbirdio/netbird/management/server/peer"
|
||||||
"github.com/netbirdio/netbird/management/server/posture"
|
"github.com/netbirdio/netbird/management/server/posture"
|
||||||
"github.com/netbirdio/netbird/management/server/status"
|
"github.com/netbirdio/netbird/management/server/status"
|
||||||
|
"github.com/netbirdio/netbird/management/server/telemetry"
|
||||||
"github.com/netbirdio/netbird/route"
|
"github.com/netbirdio/netbird/route"
|
||||||
)
|
)
|
||||||
|
|
||||||
@ -170,6 +173,8 @@ type DefaultAccountManager struct {
|
|||||||
userDeleteFromIDPEnabled bool
|
userDeleteFromIDPEnabled bool
|
||||||
|
|
||||||
integratedPeerValidator integrated_validator.IntegratedValidator
|
integratedPeerValidator integrated_validator.IntegratedValidator
|
||||||
|
|
||||||
|
metrics telemetry.AppMetrics
|
||||||
}
|
}
|
||||||
|
|
||||||
// Settings represents Account settings structure that can be modified via API and Dashboard
|
// Settings represents Account settings structure that can be modified via API and Dashboard
|
||||||
@ -401,8 +406,16 @@ func (a *Account) GetGroup(groupID string) *nbgroup.Group {
|
|||||||
return a.Groups[groupID]
|
return a.Groups[groupID]
|
||||||
}
|
}
|
||||||
|
|
||||||
// GetPeerNetworkMap returns a group by ID if exists, nil otherwise
|
// GetPeerNetworkMap returns the networkmap for the given peer ID.
|
||||||
func (a *Account) GetPeerNetworkMap(ctx context.Context, peerID, dnsDomain string, validatedPeersMap map[string]struct{}) *NetworkMap {
|
func (a *Account) GetPeerNetworkMap(
|
||||||
|
ctx context.Context,
|
||||||
|
peerID string,
|
||||||
|
peersCustomZone nbdns.CustomZone,
|
||||||
|
validatedPeersMap map[string]struct{},
|
||||||
|
metrics *telemetry.AccountManagerMetrics,
|
||||||
|
) *NetworkMap {
|
||||||
|
start := time.Now()
|
||||||
|
|
||||||
peer := a.Peers[peerID]
|
peer := a.Peers[peerID]
|
||||||
if peer == nil {
|
if peer == nil {
|
||||||
return &NetworkMap{
|
return &NetworkMap{
|
||||||
@ -438,7 +451,7 @@ func (a *Account) GetPeerNetworkMap(ctx context.Context, peerID, dnsDomain strin
|
|||||||
|
|
||||||
if dnsManagementStatus {
|
if dnsManagementStatus {
|
||||||
var zones []nbdns.CustomZone
|
var zones []nbdns.CustomZone
|
||||||
peersCustomZone := getPeersCustomZone(ctx, a, dnsDomain)
|
|
||||||
if peersCustomZone.Domain != "" {
|
if peersCustomZone.Domain != "" {
|
||||||
zones = append(zones, peersCustomZone)
|
zones = append(zones, peersCustomZone)
|
||||||
}
|
}
|
||||||
@ -446,7 +459,7 @@ func (a *Account) GetPeerNetworkMap(ctx context.Context, peerID, dnsDomain strin
|
|||||||
dnsUpdate.NameServerGroups = getPeerNSGroups(a, peerID)
|
dnsUpdate.NameServerGroups = getPeerNSGroups(a, peerID)
|
||||||
}
|
}
|
||||||
|
|
||||||
return &NetworkMap{
|
nm := &NetworkMap{
|
||||||
Peers: peersToConnect,
|
Peers: peersToConnect,
|
||||||
Network: a.Network.Copy(),
|
Network: a.Network.Copy(),
|
||||||
Routes: routesUpdate,
|
Routes: routesUpdate,
|
||||||
@ -454,6 +467,60 @@ func (a *Account) GetPeerNetworkMap(ctx context.Context, peerID, dnsDomain strin
|
|||||||
OfflinePeers: expiredPeers,
|
OfflinePeers: expiredPeers,
|
||||||
FirewallRules: firewallRules,
|
FirewallRules: firewallRules,
|
||||||
}
|
}
|
||||||
|
|
||||||
|
if metrics != nil {
|
||||||
|
objectCount := int64(len(peersToConnect) + len(expiredPeers) + len(routesUpdate) + len(firewallRules))
|
||||||
|
metrics.CountNetworkMapObjects(objectCount)
|
||||||
|
metrics.CountGetPeerNetworkMapDuration(time.Since(start))
|
||||||
|
}
|
||||||
|
|
||||||
|
return nm
|
||||||
|
}
|
||||||
|
|
||||||
|
func (a *Account) GetPeersCustomZone(ctx context.Context, dnsDomain string) nbdns.CustomZone {
|
||||||
|
var merr *multierror.Error
|
||||||
|
|
||||||
|
if dnsDomain == "" {
|
||||||
|
log.WithContext(ctx).Error("no dns domain is set, returning empty zone")
|
||||||
|
return nbdns.CustomZone{}
|
||||||
|
}
|
||||||
|
|
||||||
|
customZone := nbdns.CustomZone{
|
||||||
|
Domain: dns.Fqdn(dnsDomain),
|
||||||
|
Records: make([]nbdns.SimpleRecord, 0, len(a.Peers)),
|
||||||
|
}
|
||||||
|
|
||||||
|
domainSuffix := "." + dnsDomain
|
||||||
|
|
||||||
|
var sb strings.Builder
|
||||||
|
for _, peer := range a.Peers {
|
||||||
|
if peer.DNSLabel == "" {
|
||||||
|
merr = multierror.Append(merr, fmt.Errorf("peer %s has an empty DNS label", peer.Name))
|
||||||
|
continue
|
||||||
|
}
|
||||||
|
|
||||||
|
sb.Grow(len(peer.DNSLabel) + len(domainSuffix))
|
||||||
|
sb.WriteString(peer.DNSLabel)
|
||||||
|
sb.WriteString(domainSuffix)
|
||||||
|
|
||||||
|
customZone.Records = append(customZone.Records, nbdns.SimpleRecord{
|
||||||
|
Name: sb.String(),
|
||||||
|
Type: int(dns.TypeA),
|
||||||
|
Class: nbdns.DefaultClass,
|
||||||
|
TTL: defaultTTL,
|
||||||
|
RData: peer.IP.String(),
|
||||||
|
})
|
||||||
|
|
||||||
|
sb.Reset()
|
||||||
|
}
|
||||||
|
|
||||||
|
go func() {
|
||||||
|
if merr != nil {
|
||||||
|
log.WithContext(ctx).Errorf("error generating custom zone for account %s: %v", a.Id, merr)
|
||||||
|
}
|
||||||
|
}()
|
||||||
|
|
||||||
|
return customZone
|
||||||
}
|
}
|
||||||
|
|
||||||
// GetExpiredPeers returns peers that have been expired
|
// GetExpiredPeers returns peers that have been expired
|
||||||
@ -871,10 +938,18 @@ func (a *Account) UserGroupsRemoveFromPeers(userID string, groups ...string) {
|
|||||||
}
|
}
|
||||||
|
|
||||||
// BuildManager creates a new DefaultAccountManager with a provided Store
|
// BuildManager creates a new DefaultAccountManager with a provided Store
|
||||||
func BuildManager(ctx context.Context, store Store, peersUpdateManager *PeersUpdateManager, idpManager idp.Manager,
|
func BuildManager(
|
||||||
singleAccountModeDomain string, dnsDomain string, eventStore activity.Store, geo *geolocation.Geolocation,
|
ctx context.Context,
|
||||||
|
store Store,
|
||||||
|
peersUpdateManager *PeersUpdateManager,
|
||||||
|
idpManager idp.Manager,
|
||||||
|
singleAccountModeDomain string,
|
||||||
|
dnsDomain string,
|
||||||
|
eventStore activity.Store,
|
||||||
|
geo *geolocation.Geolocation,
|
||||||
userDeleteFromIDPEnabled bool,
|
userDeleteFromIDPEnabled bool,
|
||||||
integratedPeerValidator integrated_validator.IntegratedValidator,
|
integratedPeerValidator integrated_validator.IntegratedValidator,
|
||||||
|
metrics telemetry.AppMetrics,
|
||||||
) (*DefaultAccountManager, error) {
|
) (*DefaultAccountManager, error) {
|
||||||
am := &DefaultAccountManager{
|
am := &DefaultAccountManager{
|
||||||
Store: store,
|
Store: store,
|
||||||
@ -889,6 +964,7 @@ func BuildManager(ctx context.Context, store Store, peersUpdateManager *PeersUpd
|
|||||||
peerLoginExpiry: NewDefaultScheduler(),
|
peerLoginExpiry: NewDefaultScheduler(),
|
||||||
userDeleteFromIDPEnabled: userDeleteFromIDPEnabled,
|
userDeleteFromIDPEnabled: userDeleteFromIDPEnabled,
|
||||||
integratedPeerValidator: integratedPeerValidator,
|
integratedPeerValidator: integratedPeerValidator,
|
||||||
|
metrics: metrics,
|
||||||
}
|
}
|
||||||
allAccounts := store.GetAllAccounts(ctx)
|
allAccounts := store.GetAllAccounts(ctx)
|
||||||
// enable single account mode only if configured by user and number of existing accounts is not grater than 1
|
// enable single account mode only if configured by user and number of existing accounts is not grater than 1
|
||||||
|
@ -24,6 +24,7 @@ import (
|
|||||||
"github.com/netbirdio/netbird/management/server/jwtclaims"
|
"github.com/netbirdio/netbird/management/server/jwtclaims"
|
||||||
nbpeer "github.com/netbirdio/netbird/management/server/peer"
|
nbpeer "github.com/netbirdio/netbird/management/server/peer"
|
||||||
"github.com/netbirdio/netbird/management/server/posture"
|
"github.com/netbirdio/netbird/management/server/posture"
|
||||||
|
"github.com/netbirdio/netbird/management/server/telemetry"
|
||||||
"github.com/netbirdio/netbird/route"
|
"github.com/netbirdio/netbird/route"
|
||||||
)
|
)
|
||||||
|
|
||||||
@ -410,7 +411,8 @@ func TestAccount_GetPeerNetworkMap(t *testing.T) {
|
|||||||
validatedPeers[p] = struct{}{}
|
validatedPeers[p] = struct{}{}
|
||||||
}
|
}
|
||||||
|
|
||||||
networkMap := account.GetPeerNetworkMap(context.Background(), testCase.peerID, "netbird.io", validatedPeers)
|
customZone := account.GetPeersCustomZone(context.Background(), "netbird.io")
|
||||||
|
networkMap := account.GetPeerNetworkMap(context.Background(), testCase.peerID, customZone, validatedPeers, nil)
|
||||||
assert.Len(t, networkMap.Peers, len(testCase.expectedPeers))
|
assert.Len(t, networkMap.Peers, len(testCase.expectedPeers))
|
||||||
assert.Len(t, networkMap.OfflinePeers, len(testCase.expectedOfflinePeers))
|
assert.Len(t, networkMap.OfflinePeers, len(testCase.expectedOfflinePeers))
|
||||||
}
|
}
|
||||||
@ -2293,7 +2295,13 @@ func TestAccount_UserGroupsRemoveFromPeers(t *testing.T) {
|
|||||||
})
|
})
|
||||||
}
|
}
|
||||||
|
|
||||||
func createManager(t *testing.T) (*DefaultAccountManager, error) {
|
type TB interface {
|
||||||
|
Cleanup(func())
|
||||||
|
Helper()
|
||||||
|
TempDir() string
|
||||||
|
}
|
||||||
|
|
||||||
|
func createManager(t TB) (*DefaultAccountManager, error) {
|
||||||
t.Helper()
|
t.Helper()
|
||||||
|
|
||||||
store, err := createStore(t)
|
store, err := createStore(t)
|
||||||
@ -2302,7 +2310,12 @@ func createManager(t *testing.T) (*DefaultAccountManager, error) {
|
|||||||
}
|
}
|
||||||
eventStore := &activity.InMemoryEventStore{}
|
eventStore := &activity.InMemoryEventStore{}
|
||||||
|
|
||||||
manager, err := BuildManager(context.Background(), store, NewPeersUpdateManager(nil), nil, "", "netbird.cloud", eventStore, nil, false, MocIntegratedValidator{})
|
metrics, err := telemetry.NewDefaultAppMetrics(context.Background())
|
||||||
|
if err != nil {
|
||||||
|
return nil, err
|
||||||
|
}
|
||||||
|
|
||||||
|
manager, err := BuildManager(context.Background(), store, NewPeersUpdateManager(nil), nil, "", "netbird.cloud", eventStore, nil, false, MocIntegratedValidator{}, metrics)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return nil, err
|
return nil, err
|
||||||
}
|
}
|
||||||
@ -2310,7 +2323,7 @@ func createManager(t *testing.T) (*DefaultAccountManager, error) {
|
|||||||
return manager, nil
|
return manager, nil
|
||||||
}
|
}
|
||||||
|
|
||||||
func createStore(t *testing.T) (Store, error) {
|
func createStore(t TB) (Store, error) {
|
||||||
t.Helper()
|
t.Helper()
|
||||||
dataDir := t.TempDir()
|
dataDir := t.TempDir()
|
||||||
store, cleanUp, err := NewTestStoreFromJson(context.Background(), dataDir)
|
store, cleanUp, err := NewTestStoreFromJson(context.Background(), dataDir)
|
||||||
|
@ -4,8 +4,8 @@ import (
|
|||||||
"context"
|
"context"
|
||||||
"fmt"
|
"fmt"
|
||||||
"strconv"
|
"strconv"
|
||||||
|
"sync"
|
||||||
|
|
||||||
"github.com/miekg/dns"
|
|
||||||
log "github.com/sirupsen/logrus"
|
log "github.com/sirupsen/logrus"
|
||||||
|
|
||||||
nbdns "github.com/netbirdio/netbird/dns"
|
nbdns "github.com/netbirdio/netbird/dns"
|
||||||
@ -17,6 +17,50 @@ import (
|
|||||||
|
|
||||||
const defaultTTL = 300
|
const defaultTTL = 300
|
||||||
|
|
||||||
|
// DNSConfigCache is a thread-safe cache for DNS configuration components
|
||||||
|
type DNSConfigCache struct {
|
||||||
|
CustomZones sync.Map
|
||||||
|
NameServerGroups sync.Map
|
||||||
|
}
|
||||||
|
|
||||||
|
// GetCustomZone retrieves a cached custom zone
|
||||||
|
func (c *DNSConfigCache) GetCustomZone(key string) (*proto.CustomZone, bool) {
|
||||||
|
if c == nil {
|
||||||
|
return nil, false
|
||||||
|
}
|
||||||
|
if value, ok := c.CustomZones.Load(key); ok {
|
||||||
|
return value.(*proto.CustomZone), true
|
||||||
|
}
|
||||||
|
return nil, false
|
||||||
|
}
|
||||||
|
|
||||||
|
// SetCustomZone stores a custom zone in the cache
|
||||||
|
func (c *DNSConfigCache) SetCustomZone(key string, value *proto.CustomZone) {
|
||||||
|
if c == nil {
|
||||||
|
return
|
||||||
|
}
|
||||||
|
c.CustomZones.Store(key, value)
|
||||||
|
}
|
||||||
|
|
||||||
|
// GetNameServerGroup retrieves a cached name server group
|
||||||
|
func (c *DNSConfigCache) GetNameServerGroup(key string) (*proto.NameServerGroup, bool) {
|
||||||
|
if c == nil {
|
||||||
|
return nil, false
|
||||||
|
}
|
||||||
|
if value, ok := c.NameServerGroups.Load(key); ok {
|
||||||
|
return value.(*proto.NameServerGroup), true
|
||||||
|
}
|
||||||
|
return nil, false
|
||||||
|
}
|
||||||
|
|
||||||
|
// SetNameServerGroup stores a name server group in the cache
|
||||||
|
func (c *DNSConfigCache) SetNameServerGroup(key string, value *proto.NameServerGroup) {
|
||||||
|
if c == nil {
|
||||||
|
return
|
||||||
|
}
|
||||||
|
c.NameServerGroups.Store(key, value)
|
||||||
|
}
|
||||||
|
|
||||||
type lookupMap map[string]struct{}
|
type lookupMap map[string]struct{}
|
||||||
|
|
||||||
// DNSSettings defines dns settings at the account level
|
// DNSSettings defines dns settings at the account level
|
||||||
@ -113,69 +157,73 @@ func (am *DefaultAccountManager) SaveDNSSettings(ctx context.Context, accountID
|
|||||||
return nil
|
return nil
|
||||||
}
|
}
|
||||||
|
|
||||||
func toProtocolDNSConfig(update nbdns.Config) *proto.DNSConfig {
|
// toProtocolDNSConfig converts nbdns.Config to proto.DNSConfig using the cache
|
||||||
protoUpdate := &proto.DNSConfig{ServiceEnable: update.ServiceEnable}
|
func toProtocolDNSConfig(update nbdns.Config, cache *DNSConfigCache) *proto.DNSConfig {
|
||||||
|
protoUpdate := &proto.DNSConfig{
|
||||||
|
ServiceEnable: update.ServiceEnable,
|
||||||
|
CustomZones: make([]*proto.CustomZone, 0, len(update.CustomZones)),
|
||||||
|
NameServerGroups: make([]*proto.NameServerGroup, 0, len(update.NameServerGroups)),
|
||||||
|
}
|
||||||
|
|
||||||
for _, zone := range update.CustomZones {
|
for _, zone := range update.CustomZones {
|
||||||
protoZone := &proto.CustomZone{Domain: zone.Domain}
|
cacheKey := zone.Domain
|
||||||
for _, record := range zone.Records {
|
if cachedZone, exists := cache.GetCustomZone(cacheKey); exists {
|
||||||
protoZone.Records = append(protoZone.Records, &proto.SimpleRecord{
|
protoUpdate.CustomZones = append(protoUpdate.CustomZones, cachedZone)
|
||||||
Name: record.Name,
|
} else {
|
||||||
Type: int64(record.Type),
|
protoZone := convertToProtoCustomZone(zone)
|
||||||
Class: record.Class,
|
cache.SetCustomZone(cacheKey, protoZone)
|
||||||
TTL: int64(record.TTL),
|
protoUpdate.CustomZones = append(protoUpdate.CustomZones, protoZone)
|
||||||
RData: record.RData,
|
|
||||||
})
|
|
||||||
}
|
}
|
||||||
protoUpdate.CustomZones = append(protoUpdate.CustomZones, protoZone)
|
|
||||||
}
|
}
|
||||||
|
|
||||||
for _, nsGroup := range update.NameServerGroups {
|
for _, nsGroup := range update.NameServerGroups {
|
||||||
protoGroup := &proto.NameServerGroup{
|
cacheKey := nsGroup.ID
|
||||||
Primary: nsGroup.Primary,
|
if cachedGroup, exists := cache.GetNameServerGroup(cacheKey); exists {
|
||||||
Domains: nsGroup.Domains,
|
protoUpdate.NameServerGroups = append(protoUpdate.NameServerGroups, cachedGroup)
|
||||||
SearchDomainsEnabled: nsGroup.SearchDomainsEnabled,
|
} else {
|
||||||
|
protoGroup := convertToProtoNameServerGroup(nsGroup)
|
||||||
|
cache.SetNameServerGroup(cacheKey, protoGroup)
|
||||||
|
protoUpdate.NameServerGroups = append(protoUpdate.NameServerGroups, protoGroup)
|
||||||
}
|
}
|
||||||
for _, ns := range nsGroup.NameServers {
|
|
||||||
protoNS := &proto.NameServer{
|
|
||||||
IP: ns.IP.String(),
|
|
||||||
Port: int64(ns.Port),
|
|
||||||
NSType: int64(ns.NSType),
|
|
||||||
}
|
|
||||||
protoGroup.NameServers = append(protoGroup.NameServers, protoNS)
|
|
||||||
}
|
|
||||||
protoUpdate.NameServerGroups = append(protoUpdate.NameServerGroups, protoGroup)
|
|
||||||
}
|
}
|
||||||
|
|
||||||
return protoUpdate
|
return protoUpdate
|
||||||
}
|
}
|
||||||
|
|
||||||
func getPeersCustomZone(ctx context.Context, account *Account, dnsDomain string) nbdns.CustomZone {
|
// Helper function to convert nbdns.CustomZone to proto.CustomZone
|
||||||
if dnsDomain == "" {
|
func convertToProtoCustomZone(zone nbdns.CustomZone) *proto.CustomZone {
|
||||||
log.WithContext(ctx).Errorf("no dns domain is set, returning empty zone")
|
protoZone := &proto.CustomZone{
|
||||||
return nbdns.CustomZone{}
|
Domain: zone.Domain,
|
||||||
|
Records: make([]*proto.SimpleRecord, 0, len(zone.Records)),
|
||||||
}
|
}
|
||||||
|
for _, record := range zone.Records {
|
||||||
customZone := nbdns.CustomZone{
|
protoZone.Records = append(protoZone.Records, &proto.SimpleRecord{
|
||||||
Domain: dns.Fqdn(dnsDomain),
|
Name: record.Name,
|
||||||
}
|
Type: int64(record.Type),
|
||||||
|
Class: record.Class,
|
||||||
for _, peer := range account.Peers {
|
TTL: int64(record.TTL),
|
||||||
if peer.DNSLabel == "" {
|
RData: record.RData,
|
||||||
log.WithContext(ctx).Errorf("found a peer with empty dns label. It was probably caused by a invalid character in its name. Peer Name: %s", peer.Name)
|
|
||||||
continue
|
|
||||||
}
|
|
||||||
|
|
||||||
customZone.Records = append(customZone.Records, nbdns.SimpleRecord{
|
|
||||||
Name: dns.Fqdn(peer.DNSLabel + "." + dnsDomain),
|
|
||||||
Type: int(dns.TypeA),
|
|
||||||
Class: nbdns.DefaultClass,
|
|
||||||
TTL: defaultTTL,
|
|
||||||
RData: peer.IP.String(),
|
|
||||||
})
|
})
|
||||||
}
|
}
|
||||||
|
return protoZone
|
||||||
|
}
|
||||||
|
|
||||||
return customZone
|
// Helper function to convert nbdns.NameServerGroup to proto.NameServerGroup
|
||||||
|
func convertToProtoNameServerGroup(nsGroup *nbdns.NameServerGroup) *proto.NameServerGroup {
|
||||||
|
protoGroup := &proto.NameServerGroup{
|
||||||
|
Primary: nsGroup.Primary,
|
||||||
|
Domains: nsGroup.Domains,
|
||||||
|
SearchDomainsEnabled: nsGroup.SearchDomainsEnabled,
|
||||||
|
NameServers: make([]*proto.NameServer, 0, len(nsGroup.NameServers)),
|
||||||
|
}
|
||||||
|
for _, ns := range nsGroup.NameServers {
|
||||||
|
protoGroup.NameServers = append(protoGroup.NameServers, &proto.NameServer{
|
||||||
|
IP: ns.IP.String(),
|
||||||
|
Port: int64(ns.Port),
|
||||||
|
NSType: int64(ns.NSType),
|
||||||
|
})
|
||||||
|
}
|
||||||
|
return protoGroup
|
||||||
}
|
}
|
||||||
|
|
||||||
func getPeerNSGroups(account *Account, peerID string) []*nbdns.NameServerGroup {
|
func getPeerNSGroups(account *Account, peerID string) []*nbdns.NameServerGroup {
|
||||||
|
@ -2,9 +2,14 @@ package server
|
|||||||
|
|
||||||
import (
|
import (
|
||||||
"context"
|
"context"
|
||||||
|
"fmt"
|
||||||
"net/netip"
|
"net/netip"
|
||||||
|
"reflect"
|
||||||
"testing"
|
"testing"
|
||||||
|
|
||||||
|
nbdns "github.com/netbirdio/netbird/dns"
|
||||||
|
"github.com/netbirdio/netbird/management/server/telemetry"
|
||||||
|
|
||||||
"github.com/stretchr/testify/require"
|
"github.com/stretchr/testify/require"
|
||||||
|
|
||||||
"github.com/netbirdio/netbird/dns"
|
"github.com/netbirdio/netbird/dns"
|
||||||
@ -195,7 +200,11 @@ func createDNSManager(t *testing.T) (*DefaultAccountManager, error) {
|
|||||||
return nil, err
|
return nil, err
|
||||||
}
|
}
|
||||||
eventStore := &activity.InMemoryEventStore{}
|
eventStore := &activity.InMemoryEventStore{}
|
||||||
return BuildManager(context.Background(), store, NewPeersUpdateManager(nil), nil, "", "netbird.test", eventStore, nil, false, MocIntegratedValidator{})
|
|
||||||
|
metrics, err := telemetry.NewDefaultAppMetrics(context.Background())
|
||||||
|
require.NoError(t, err)
|
||||||
|
|
||||||
|
return BuildManager(context.Background(), store, NewPeersUpdateManager(nil), nil, "", "netbird.test", eventStore, nil, false, MocIntegratedValidator{}, metrics)
|
||||||
}
|
}
|
||||||
|
|
||||||
func createDNSStore(t *testing.T) (Store, error) {
|
func createDNSStore(t *testing.T) (Store, error) {
|
||||||
@ -320,3 +329,150 @@ func initTestDNSAccount(t *testing.T, am *DefaultAccountManager) (*Account, erro
|
|||||||
|
|
||||||
return am.Store.GetAccount(context.Background(), account.Id)
|
return am.Store.GetAccount(context.Background(), account.Id)
|
||||||
}
|
}
|
||||||
|
|
||||||
|
func generateTestData(size int) nbdns.Config {
|
||||||
|
config := nbdns.Config{
|
||||||
|
ServiceEnable: true,
|
||||||
|
CustomZones: make([]nbdns.CustomZone, size),
|
||||||
|
NameServerGroups: make([]*nbdns.NameServerGroup, size),
|
||||||
|
}
|
||||||
|
|
||||||
|
for i := 0; i < size; i++ {
|
||||||
|
config.CustomZones[i] = nbdns.CustomZone{
|
||||||
|
Domain: fmt.Sprintf("domain%d.com", i),
|
||||||
|
Records: []nbdns.SimpleRecord{
|
||||||
|
{
|
||||||
|
Name: fmt.Sprintf("record%d", i),
|
||||||
|
Type: 1,
|
||||||
|
Class: "IN",
|
||||||
|
TTL: 3600,
|
||||||
|
RData: "192.168.1.1",
|
||||||
|
},
|
||||||
|
},
|
||||||
|
}
|
||||||
|
|
||||||
|
config.NameServerGroups[i] = &nbdns.NameServerGroup{
|
||||||
|
ID: fmt.Sprintf("group%d", i),
|
||||||
|
Primary: i == 0,
|
||||||
|
Domains: []string{fmt.Sprintf("domain%d.com", i)},
|
||||||
|
SearchDomainsEnabled: true,
|
||||||
|
NameServers: []nbdns.NameServer{
|
||||||
|
{
|
||||||
|
IP: netip.MustParseAddr("8.8.8.8"),
|
||||||
|
Port: 53,
|
||||||
|
NSType: 1,
|
||||||
|
},
|
||||||
|
},
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
return config
|
||||||
|
}
|
||||||
|
|
||||||
|
func BenchmarkToProtocolDNSConfig(b *testing.B) {
|
||||||
|
sizes := []int{10, 100, 1000}
|
||||||
|
|
||||||
|
for _, size := range sizes {
|
||||||
|
testData := generateTestData(size)
|
||||||
|
|
||||||
|
b.Run(fmt.Sprintf("WithCache-Size%d", size), func(b *testing.B) {
|
||||||
|
cache := &DNSConfigCache{}
|
||||||
|
|
||||||
|
b.ResetTimer()
|
||||||
|
for i := 0; i < b.N; i++ {
|
||||||
|
toProtocolDNSConfig(testData, cache)
|
||||||
|
}
|
||||||
|
})
|
||||||
|
|
||||||
|
b.Run(fmt.Sprintf("WithoutCache-Size%d", size), func(b *testing.B) {
|
||||||
|
b.ResetTimer()
|
||||||
|
for i := 0; i < b.N; i++ {
|
||||||
|
cache := &DNSConfigCache{}
|
||||||
|
toProtocolDNSConfig(testData, cache)
|
||||||
|
}
|
||||||
|
})
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestToProtocolDNSConfigWithCache(t *testing.T) {
|
||||||
|
var cache DNSConfigCache
|
||||||
|
|
||||||
|
// Create two different configs
|
||||||
|
config1 := nbdns.Config{
|
||||||
|
ServiceEnable: true,
|
||||||
|
CustomZones: []nbdns.CustomZone{
|
||||||
|
{
|
||||||
|
Domain: "example.com",
|
||||||
|
Records: []nbdns.SimpleRecord{
|
||||||
|
{Name: "www", Type: 1, Class: "IN", TTL: 300, RData: "192.168.1.1"},
|
||||||
|
},
|
||||||
|
},
|
||||||
|
},
|
||||||
|
NameServerGroups: []*nbdns.NameServerGroup{
|
||||||
|
{
|
||||||
|
ID: "group1",
|
||||||
|
Name: "Group 1",
|
||||||
|
NameServers: []nbdns.NameServer{
|
||||||
|
{IP: netip.MustParseAddr("8.8.8.8"), Port: 53},
|
||||||
|
},
|
||||||
|
},
|
||||||
|
},
|
||||||
|
}
|
||||||
|
|
||||||
|
config2 := nbdns.Config{
|
||||||
|
ServiceEnable: true,
|
||||||
|
CustomZones: []nbdns.CustomZone{
|
||||||
|
{
|
||||||
|
Domain: "example.org",
|
||||||
|
Records: []nbdns.SimpleRecord{
|
||||||
|
{Name: "mail", Type: 1, Class: "IN", TTL: 300, RData: "192.168.1.2"},
|
||||||
|
},
|
||||||
|
},
|
||||||
|
},
|
||||||
|
NameServerGroups: []*nbdns.NameServerGroup{
|
||||||
|
{
|
||||||
|
ID: "group2",
|
||||||
|
Name: "Group 2",
|
||||||
|
NameServers: []nbdns.NameServer{
|
||||||
|
{IP: netip.MustParseAddr("8.8.4.4"), Port: 53},
|
||||||
|
},
|
||||||
|
},
|
||||||
|
},
|
||||||
|
}
|
||||||
|
|
||||||
|
// First run with config1
|
||||||
|
result1 := toProtocolDNSConfig(config1, &cache)
|
||||||
|
|
||||||
|
// Second run with config2
|
||||||
|
result2 := toProtocolDNSConfig(config2, &cache)
|
||||||
|
|
||||||
|
// Third run with config1 again
|
||||||
|
result3 := toProtocolDNSConfig(config1, &cache)
|
||||||
|
|
||||||
|
// Verify that result1 and result3 are identical
|
||||||
|
if !reflect.DeepEqual(result1, result3) {
|
||||||
|
t.Errorf("Results are not identical when run with the same input. Expected %v, got %v", result1, result3)
|
||||||
|
}
|
||||||
|
|
||||||
|
// Verify that result2 is different from result1 and result3
|
||||||
|
if reflect.DeepEqual(result1, result2) || reflect.DeepEqual(result2, result3) {
|
||||||
|
t.Errorf("Results should be different for different inputs")
|
||||||
|
}
|
||||||
|
|
||||||
|
// Verify that the cache contains elements from both configs
|
||||||
|
if _, exists := cache.GetCustomZone("example.com"); !exists {
|
||||||
|
t.Errorf("Cache should contain custom zone for example.com")
|
||||||
|
}
|
||||||
|
|
||||||
|
if _, exists := cache.GetCustomZone("example.org"); !exists {
|
||||||
|
t.Errorf("Cache should contain custom zone for example.org")
|
||||||
|
}
|
||||||
|
|
||||||
|
if _, exists := cache.GetNameServerGroup("group1"); !exists {
|
||||||
|
t.Errorf("Cache should contain name server group 'group1'")
|
||||||
|
}
|
||||||
|
|
||||||
|
if _, exists := cache.GetNameServerGroup("group2"); !exists {
|
||||||
|
t.Errorf("Cache should contain name server group 'group2'")
|
||||||
|
}
|
||||||
|
}
|
||||||
|
@ -533,53 +533,46 @@ func toPeerConfig(peer *nbpeer.Peer, network *Network, dnsName string) *proto.Pe
|
|||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
func toRemotePeerConfig(peers []*nbpeer.Peer, dnsName string) []*proto.RemotePeerConfig {
|
func toSyncResponse(ctx context.Context, config *Config, peer *nbpeer.Peer, turnCredentials *TURNCredentials, networkMap *NetworkMap, dnsName string, checks []*posture.Checks, dnsCache *DNSConfigCache) *proto.SyncResponse {
|
||||||
remotePeers := []*proto.RemotePeerConfig{}
|
response := &proto.SyncResponse{
|
||||||
for _, rPeer := range peers {
|
WiretrusteeConfig: toWiretrusteeConfig(config, turnCredentials),
|
||||||
fqdn := rPeer.FQDN(dnsName)
|
PeerConfig: toPeerConfig(peer, networkMap.Network, dnsName),
|
||||||
remotePeers = append(remotePeers, &proto.RemotePeerConfig{
|
|
||||||
WgPubKey: rPeer.Key,
|
|
||||||
AllowedIps: []string{fmt.Sprintf(AllowedIPsFormat, rPeer.IP)},
|
|
||||||
SshConfig: &proto.SSHConfig{SshPubKey: []byte(rPeer.SSHKey)},
|
|
||||||
Fqdn: fqdn,
|
|
||||||
})
|
|
||||||
}
|
|
||||||
return remotePeers
|
|
||||||
}
|
|
||||||
|
|
||||||
func toSyncResponse(ctx context.Context, config *Config, peer *nbpeer.Peer, turnCredentials *TURNCredentials, networkMap *NetworkMap, dnsName string, checks []*posture.Checks) *proto.SyncResponse {
|
|
||||||
wtConfig := toWiretrusteeConfig(config, turnCredentials)
|
|
||||||
|
|
||||||
pConfig := toPeerConfig(peer, networkMap.Network, dnsName)
|
|
||||||
|
|
||||||
remotePeers := toRemotePeerConfig(networkMap.Peers, dnsName)
|
|
||||||
|
|
||||||
routesUpdate := toProtocolRoutes(networkMap.Routes)
|
|
||||||
|
|
||||||
dnsUpdate := toProtocolDNSConfig(networkMap.DNSConfig)
|
|
||||||
|
|
||||||
offlinePeers := toRemotePeerConfig(networkMap.OfflinePeers, dnsName)
|
|
||||||
|
|
||||||
firewallRules := toProtocolFirewallRules(networkMap.FirewallRules)
|
|
||||||
|
|
||||||
return &proto.SyncResponse{
|
|
||||||
WiretrusteeConfig: wtConfig,
|
|
||||||
PeerConfig: pConfig,
|
|
||||||
RemotePeers: remotePeers,
|
|
||||||
RemotePeersIsEmpty: len(remotePeers) == 0,
|
|
||||||
NetworkMap: &proto.NetworkMap{
|
NetworkMap: &proto.NetworkMap{
|
||||||
Serial: networkMap.Network.CurrentSerial(),
|
Serial: networkMap.Network.CurrentSerial(),
|
||||||
PeerConfig: pConfig,
|
Routes: toProtocolRoutes(networkMap.Routes),
|
||||||
RemotePeers: remotePeers,
|
DNSConfig: toProtocolDNSConfig(networkMap.DNSConfig, dnsCache),
|
||||||
OfflinePeers: offlinePeers,
|
|
||||||
RemotePeersIsEmpty: len(remotePeers) == 0,
|
|
||||||
Routes: routesUpdate,
|
|
||||||
DNSConfig: dnsUpdate,
|
|
||||||
FirewallRules: firewallRules,
|
|
||||||
FirewallRulesIsEmpty: len(firewallRules) == 0,
|
|
||||||
},
|
},
|
||||||
Checks: toProtocolChecks(ctx, checks),
|
Checks: toProtocolChecks(ctx, checks),
|
||||||
}
|
}
|
||||||
|
|
||||||
|
response.NetworkMap.PeerConfig = response.PeerConfig
|
||||||
|
|
||||||
|
allPeers := make([]*proto.RemotePeerConfig, 0, len(networkMap.Peers)+len(networkMap.OfflinePeers))
|
||||||
|
allPeers = appendRemotePeerConfig(allPeers, networkMap.Peers, dnsName)
|
||||||
|
response.RemotePeers = allPeers
|
||||||
|
response.NetworkMap.RemotePeers = allPeers
|
||||||
|
response.RemotePeersIsEmpty = len(allPeers) == 0
|
||||||
|
response.NetworkMap.RemotePeersIsEmpty = response.RemotePeersIsEmpty
|
||||||
|
|
||||||
|
response.NetworkMap.OfflinePeers = appendRemotePeerConfig(nil, networkMap.OfflinePeers, dnsName)
|
||||||
|
|
||||||
|
firewallRules := toProtocolFirewallRules(networkMap.FirewallRules)
|
||||||
|
response.NetworkMap.FirewallRules = firewallRules
|
||||||
|
response.NetworkMap.FirewallRulesIsEmpty = len(firewallRules) == 0
|
||||||
|
|
||||||
|
return response
|
||||||
|
}
|
||||||
|
|
||||||
|
func appendRemotePeerConfig(dst []*proto.RemotePeerConfig, peers []*nbpeer.Peer, dnsName string) []*proto.RemotePeerConfig {
|
||||||
|
for _, rPeer := range peers {
|
||||||
|
dst = append(dst, &proto.RemotePeerConfig{
|
||||||
|
WgPubKey: rPeer.Key,
|
||||||
|
AllowedIps: []string{rPeer.IP.String() + "/32"},
|
||||||
|
SshConfig: &proto.SSHConfig{SshPubKey: []byte(rPeer.SSHKey)},
|
||||||
|
Fqdn: rPeer.FQDN(dnsName),
|
||||||
|
})
|
||||||
|
}
|
||||||
|
return dst
|
||||||
}
|
}
|
||||||
|
|
||||||
// IsHealthy indicates whether the service is healthy
|
// IsHealthy indicates whether the service is healthy
|
||||||
@ -597,7 +590,7 @@ func (s *GRPCServer) sendInitialSync(ctx context.Context, peerKey wgtypes.Key, p
|
|||||||
} else {
|
} else {
|
||||||
turnCredentials = nil
|
turnCredentials = nil
|
||||||
}
|
}
|
||||||
plainResp := toSyncResponse(ctx, s.config, peer, turnCredentials, networkMap, s.accountManager.GetDNSDomain(), postureChecks)
|
plainResp := toSyncResponse(ctx, s.config, peer, turnCredentials, networkMap, s.accountManager.GetDNSDomain(), postureChecks, nil)
|
||||||
|
|
||||||
encryptedResp, err := encryption.EncryptMessage(peerKey, s.wgKey, plainResp)
|
encryptedResp, err := encryption.EncryptMessage(peerKey, s.wgKey, plainResp)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
|
@ -71,7 +71,8 @@ func (h *PeersHandler) getPeer(ctx context.Context, account *server.Account, pee
|
|||||||
return
|
return
|
||||||
}
|
}
|
||||||
|
|
||||||
netMap := account.GetPeerNetworkMap(ctx, peerID, h.accountManager.GetDNSDomain(), validPeers)
|
customZone := account.GetPeersCustomZone(ctx, h.accountManager.GetDNSDomain())
|
||||||
|
netMap := account.GetPeerNetworkMap(ctx, peerID, customZone, validPeers, nil)
|
||||||
accessiblePeers := toAccessiblePeers(netMap, dnsDomain)
|
accessiblePeers := toAccessiblePeers(netMap, dnsDomain)
|
||||||
|
|
||||||
_, valid := validPeers[peer.ID]
|
_, valid := validPeers[peer.ID]
|
||||||
@ -115,7 +116,9 @@ func (h *PeersHandler) updatePeer(ctx context.Context, account *server.Account,
|
|||||||
util.WriteError(ctx, fmt.Errorf("internal error"), w)
|
util.WriteError(ctx, fmt.Errorf("internal error"), w)
|
||||||
return
|
return
|
||||||
}
|
}
|
||||||
netMap := account.GetPeerNetworkMap(ctx, peerID, h.accountManager.GetDNSDomain(), validPeers)
|
|
||||||
|
customZone := account.GetPeersCustomZone(ctx, h.accountManager.GetDNSDomain())
|
||||||
|
netMap := account.GetPeerNetworkMap(ctx, peerID, customZone, validPeers, nil)
|
||||||
accessiblePeers := toAccessiblePeers(netMap, dnsDomain)
|
accessiblePeers := toAccessiblePeers(netMap, dnsDomain)
|
||||||
|
|
||||||
_, valid := validPeers[peer.ID]
|
_, valid := validPeers[peer.ID]
|
||||||
@ -194,9 +197,7 @@ func (h *PeersHandler) GetAllPeers(w http.ResponseWriter, r *http.Request) {
|
|||||||
}
|
}
|
||||||
groupMinimumInfo := toGroupsInfo(account.Groups, peer.ID)
|
groupMinimumInfo := toGroupsInfo(account.Groups, peer.ID)
|
||||||
|
|
||||||
accessiblePeerNumbers, _ := h.accessiblePeersNumber(r.Context(), account, peer.ID)
|
respBody = append(respBody, toPeerListItemResponse(peerToReturn, groupMinimumInfo, dnsDomain, 0))
|
||||||
|
|
||||||
respBody = append(respBody, toPeerListItemResponse(peerToReturn, groupMinimumInfo, dnsDomain, accessiblePeerNumbers))
|
|
||||||
}
|
}
|
||||||
|
|
||||||
validPeersMap, err := h.accountManager.GetValidatedPeers(account)
|
validPeersMap, err := h.accountManager.GetValidatedPeers(account)
|
||||||
@ -210,16 +211,6 @@ func (h *PeersHandler) GetAllPeers(w http.ResponseWriter, r *http.Request) {
|
|||||||
util.WriteJSONObject(r.Context(), w, respBody)
|
util.WriteJSONObject(r.Context(), w, respBody)
|
||||||
}
|
}
|
||||||
|
|
||||||
func (h *PeersHandler) accessiblePeersNumber(ctx context.Context, account *server.Account, peerID string) (int, error) {
|
|
||||||
validatedPeersMap, err := h.accountManager.GetValidatedPeers(account)
|
|
||||||
if err != nil {
|
|
||||||
return 0, err
|
|
||||||
}
|
|
||||||
|
|
||||||
netMap := account.GetPeerNetworkMap(ctx, peerID, h.accountManager.GetDNSDomain(), validatedPeersMap)
|
|
||||||
return len(netMap.Peers) + len(netMap.OfflinePeers), nil
|
|
||||||
}
|
|
||||||
|
|
||||||
func (h *PeersHandler) setApprovalRequiredFlag(respBody []*api.PeerBatch, approvedPeersMap map[string]struct{}) {
|
func (h *PeersHandler) setApprovalRequiredFlag(respBody []*api.PeerBatch, approvedPeersMap map[string]struct{}) {
|
||||||
for _, peer := range respBody {
|
for _, peer := range respBody {
|
||||||
_, ok := approvedPeersMap[peer.Id]
|
_, ok := approvedPeersMap[peer.Id]
|
||||||
|
@ -20,6 +20,7 @@ import (
|
|||||||
"github.com/netbirdio/netbird/formatter"
|
"github.com/netbirdio/netbird/formatter"
|
||||||
mgmtProto "github.com/netbirdio/netbird/management/proto"
|
mgmtProto "github.com/netbirdio/netbird/management/proto"
|
||||||
"github.com/netbirdio/netbird/management/server/activity"
|
"github.com/netbirdio/netbird/management/server/activity"
|
||||||
|
"github.com/netbirdio/netbird/management/server/telemetry"
|
||||||
"github.com/netbirdio/netbird/util"
|
"github.com/netbirdio/netbird/util"
|
||||||
)
|
)
|
||||||
|
|
||||||
@ -419,8 +420,12 @@ func startManagement(t *testing.T, config *Config) (*grpc.Server, *DefaultAccoun
|
|||||||
|
|
||||||
ctx := context.WithValue(context.Background(), formatter.ExecutionContextKey, formatter.SystemSource) //nolint:staticcheck
|
ctx := context.WithValue(context.Background(), formatter.ExecutionContextKey, formatter.SystemSource) //nolint:staticcheck
|
||||||
|
|
||||||
|
metrics, err := telemetry.NewDefaultAppMetrics(context.Background())
|
||||||
|
require.NoError(t, err)
|
||||||
|
|
||||||
accountManager, err := BuildManager(ctx, store, peersUpdateManager, nil, "", "netbird.selfhosted",
|
accountManager, err := BuildManager(ctx, store, peersUpdateManager, nil, "", "netbird.selfhosted",
|
||||||
eventStore, nil, false, MocIntegratedValidator{})
|
eventStore, nil, false, MocIntegratedValidator{}, metrics)
|
||||||
|
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return nil, nil, "", err
|
return nil, nil, "", err
|
||||||
}
|
}
|
||||||
|
@ -26,6 +26,7 @@ import (
|
|||||||
"github.com/netbirdio/netbird/management/server/activity"
|
"github.com/netbirdio/netbird/management/server/activity"
|
||||||
"github.com/netbirdio/netbird/management/server/group"
|
"github.com/netbirdio/netbird/management/server/group"
|
||||||
nbpeer "github.com/netbirdio/netbird/management/server/peer"
|
nbpeer "github.com/netbirdio/netbird/management/server/peer"
|
||||||
|
"github.com/netbirdio/netbird/management/server/telemetry"
|
||||||
"github.com/netbirdio/netbird/util"
|
"github.com/netbirdio/netbird/util"
|
||||||
)
|
)
|
||||||
|
|
||||||
@ -541,8 +542,13 @@ func startServer(config *server.Config) (*grpc.Server, net.Listener) {
|
|||||||
|
|
||||||
peersUpdateManager := server.NewPeersUpdateManager(nil)
|
peersUpdateManager := server.NewPeersUpdateManager(nil)
|
||||||
eventStore := &activity.InMemoryEventStore{}
|
eventStore := &activity.InMemoryEventStore{}
|
||||||
accountManager, err := server.BuildManager(context.Background(), store, peersUpdateManager, nil, "", "netbird.selfhosted",
|
|
||||||
eventStore, nil, false, MocIntegratedValidator{})
|
metrics, err := telemetry.NewDefaultAppMetrics(context.Background())
|
||||||
|
if err != nil {
|
||||||
|
log.Fatalf("failed creating metrics: %v", err)
|
||||||
|
}
|
||||||
|
|
||||||
|
accountManager, err := server.BuildManager(context.Background(), store, peersUpdateManager, nil, "", "netbird.selfhosted", eventStore, nil, false, MocIntegratedValidator{}, metrics)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
log.Fatalf("failed creating a manager: %v", err)
|
log.Fatalf("failed creating a manager: %v", err)
|
||||||
}
|
}
|
||||||
|
@ -11,6 +11,7 @@ import (
|
|||||||
"github.com/netbirdio/netbird/management/server/activity"
|
"github.com/netbirdio/netbird/management/server/activity"
|
||||||
nbgroup "github.com/netbirdio/netbird/management/server/group"
|
nbgroup "github.com/netbirdio/netbird/management/server/group"
|
||||||
nbpeer "github.com/netbirdio/netbird/management/server/peer"
|
nbpeer "github.com/netbirdio/netbird/management/server/peer"
|
||||||
|
"github.com/netbirdio/netbird/management/server/telemetry"
|
||||||
)
|
)
|
||||||
|
|
||||||
const (
|
const (
|
||||||
@ -762,7 +763,11 @@ func createNSManager(t *testing.T) (*DefaultAccountManager, error) {
|
|||||||
return nil, err
|
return nil, err
|
||||||
}
|
}
|
||||||
eventStore := &activity.InMemoryEventStore{}
|
eventStore := &activity.InMemoryEventStore{}
|
||||||
return BuildManager(context.Background(), store, NewPeersUpdateManager(nil), nil, "", "netbird.selfhosted", eventStore, nil, false, MocIntegratedValidator{})
|
|
||||||
|
metrics, err := telemetry.NewDefaultAppMetrics(context.Background())
|
||||||
|
require.NoError(t, err)
|
||||||
|
|
||||||
|
return BuildManager(context.Background(), store, NewPeersUpdateManager(nil), nil, "", "netbird.selfhosted", eventStore, nil, false, MocIntegratedValidator{}, metrics)
|
||||||
}
|
}
|
||||||
|
|
||||||
func createNSStore(t *testing.T) (Store, error) {
|
func createNSStore(t *testing.T) (Store, error) {
|
||||||
|
@ -5,6 +5,7 @@ import (
|
|||||||
"fmt"
|
"fmt"
|
||||||
"net"
|
"net"
|
||||||
"strings"
|
"strings"
|
||||||
|
"sync"
|
||||||
"time"
|
"time"
|
||||||
|
|
||||||
"github.com/rs/xid"
|
"github.com/rs/xid"
|
||||||
@ -322,7 +323,8 @@ func (am *DefaultAccountManager) GetNetworkMap(ctx context.Context, peerID strin
|
|||||||
if err != nil {
|
if err != nil {
|
||||||
return nil, err
|
return nil, err
|
||||||
}
|
}
|
||||||
return account.GetPeerNetworkMap(ctx, peer.ID, am.dnsDomain, validatedPeers), nil
|
customZone := account.GetPeersCustomZone(ctx, am.dnsDomain)
|
||||||
|
return account.GetPeerNetworkMap(ctx, peer.ID, customZone, validatedPeers, nil), nil
|
||||||
}
|
}
|
||||||
|
|
||||||
// GetPeerNetwork returns the Network for a given peer
|
// GetPeerNetwork returns the Network for a given peer
|
||||||
@ -535,7 +537,8 @@ func (am *DefaultAccountManager) AddPeer(ctx context.Context, setupKey, userID s
|
|||||||
}
|
}
|
||||||
|
|
||||||
postureChecks := am.getPeerPostureChecks(account, peer)
|
postureChecks := am.getPeerPostureChecks(account, peer)
|
||||||
networkMap := account.GetPeerNetworkMap(ctx, newPeer.ID, am.dnsDomain, approvedPeersMap)
|
customZone := account.GetPeersCustomZone(ctx, am.dnsDomain)
|
||||||
|
networkMap := account.GetPeerNetworkMap(ctx, newPeer.ID, customZone, approvedPeersMap, am.metrics.AccountManagerMetrics())
|
||||||
return newPeer, networkMap, postureChecks, nil
|
return newPeer, networkMap, postureChecks, nil
|
||||||
}
|
}
|
||||||
|
|
||||||
@ -591,7 +594,8 @@ func (am *DefaultAccountManager) SyncPeer(ctx context.Context, sync PeerSync, ac
|
|||||||
}
|
}
|
||||||
postureChecks = am.getPeerPostureChecks(account, peer)
|
postureChecks = am.getPeerPostureChecks(account, peer)
|
||||||
|
|
||||||
return peer, account.GetPeerNetworkMap(ctx, peer.ID, am.dnsDomain, validPeersMap), postureChecks, nil
|
customZone := account.GetPeersCustomZone(ctx, am.dnsDomain)
|
||||||
|
return peer, account.GetPeerNetworkMap(ctx, peer.ID, customZone, validPeersMap, am.metrics.AccountManagerMetrics()), postureChecks, nil
|
||||||
}
|
}
|
||||||
|
|
||||||
// LoginPeer logs in or registers a peer.
|
// LoginPeer logs in or registers a peer.
|
||||||
@ -738,7 +742,8 @@ func (am *DefaultAccountManager) getValidatedPeerWithMap(ctx context.Context, is
|
|||||||
}
|
}
|
||||||
postureChecks = am.getPeerPostureChecks(account, peer)
|
postureChecks = am.getPeerPostureChecks(account, peer)
|
||||||
|
|
||||||
return peer, account.GetPeerNetworkMap(ctx, peer.ID, am.dnsDomain, approvedPeersMap), postureChecks, nil
|
customZone := account.GetPeersCustomZone(ctx, am.dnsDomain)
|
||||||
|
return peer, account.GetPeerNetworkMap(ctx, peer.ID, customZone, approvedPeersMap, am.metrics.AccountManagerMetrics()), postureChecks, nil
|
||||||
}
|
}
|
||||||
|
|
||||||
func (am *DefaultAccountManager) handleExpiredPeer(ctx context.Context, login PeerLogin, account *Account, peer *nbpeer.Peer) error {
|
func (am *DefaultAccountManager) handleExpiredPeer(ctx context.Context, login PeerLogin, account *Account, peer *nbpeer.Peer) error {
|
||||||
@ -914,22 +919,45 @@ func updatePeerMeta(peer *nbpeer.Peer, meta nbpeer.PeerSystemMeta, account *Acco
|
|||||||
// updateAccountPeers updates all peers that belong to an account.
|
// updateAccountPeers updates all peers that belong to an account.
|
||||||
// Should be called when changes have to be synced to peers.
|
// Should be called when changes have to be synced to peers.
|
||||||
func (am *DefaultAccountManager) updateAccountPeers(ctx context.Context, account *Account) {
|
func (am *DefaultAccountManager) updateAccountPeers(ctx context.Context, account *Account) {
|
||||||
|
start := time.Now()
|
||||||
|
defer func() {
|
||||||
|
if am.metrics != nil {
|
||||||
|
am.metrics.AccountManagerMetrics().CountUpdateAccountPeersDuration(time.Since(start))
|
||||||
|
}
|
||||||
|
}()
|
||||||
|
|
||||||
peers := account.GetPeers()
|
peers := account.GetPeers()
|
||||||
|
|
||||||
approvedPeersMap, err := am.GetValidatedPeers(account)
|
approvedPeersMap, err := am.GetValidatedPeers(account)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
log.WithContext(ctx).Errorf("failed send out updates to peers, failed to validate peer: %v", err)
|
log.WithContext(ctx).Errorf("failed to send out updates to peers, failed to validate peer: %v", err)
|
||||||
return
|
return
|
||||||
}
|
}
|
||||||
|
|
||||||
|
var wg sync.WaitGroup
|
||||||
|
semaphore := make(chan struct{}, 10)
|
||||||
|
|
||||||
|
dnsCache := &DNSConfigCache{}
|
||||||
|
customZone := account.GetPeersCustomZone(ctx, am.dnsDomain)
|
||||||
|
|
||||||
for _, peer := range peers {
|
for _, peer := range peers {
|
||||||
if !am.peersUpdateManager.HasChannel(peer.ID) {
|
if !am.peersUpdateManager.HasChannel(peer.ID) {
|
||||||
log.WithContext(ctx).Tracef("peer %s doesn't have a channel, skipping network map update", peer.ID)
|
log.WithContext(ctx).Tracef("peer %s doesn't have a channel, skipping network map update", peer.ID)
|
||||||
continue
|
continue
|
||||||
}
|
}
|
||||||
|
|
||||||
postureChecks := am.getPeerPostureChecks(account, peer)
|
wg.Add(1)
|
||||||
remotePeerNetworkMap := account.GetPeerNetworkMap(ctx, peer.ID, am.dnsDomain, approvedPeersMap)
|
semaphore <- struct{}{}
|
||||||
update := toSyncResponse(ctx, nil, peer, nil, remotePeerNetworkMap, am.GetDNSDomain(), postureChecks)
|
go func(p *nbpeer.Peer) {
|
||||||
am.peersUpdateManager.SendUpdate(ctx, peer.ID, &UpdateMessage{Update: update})
|
defer wg.Done()
|
||||||
|
defer func() { <-semaphore }()
|
||||||
|
|
||||||
|
postureChecks := am.getPeerPostureChecks(account, p)
|
||||||
|
remotePeerNetworkMap := account.GetPeerNetworkMap(ctx, p.ID, customZone, approvedPeersMap, am.metrics.AccountManagerMetrics())
|
||||||
|
update := toSyncResponse(ctx, nil, p, nil, remotePeerNetworkMap, am.GetDNSDomain(), postureChecks, dnsCache)
|
||||||
|
am.peersUpdateManager.SendUpdate(ctx, p.ID, &UpdateMessage{Update: update})
|
||||||
|
}(peer)
|
||||||
}
|
}
|
||||||
|
|
||||||
|
wg.Wait()
|
||||||
}
|
}
|
||||||
|
@ -1,7 +1,6 @@
|
|||||||
package peer
|
package peer
|
||||||
|
|
||||||
import (
|
import (
|
||||||
"fmt"
|
|
||||||
"net"
|
"net"
|
||||||
"net/netip"
|
"net/netip"
|
||||||
"slices"
|
"slices"
|
||||||
@ -241,7 +240,7 @@ func (p *Peer) FQDN(dnsDomain string) string {
|
|||||||
if dnsDomain == "" {
|
if dnsDomain == "" {
|
||||||
return ""
|
return ""
|
||||||
}
|
}
|
||||||
return fmt.Sprintf("%s.%s", p.DNSLabel, dnsDomain)
|
return p.DNSLabel + "." + dnsDomain
|
||||||
}
|
}
|
||||||
|
|
||||||
// EventMeta returns activity event meta related to the peer
|
// EventMeta returns activity event meta related to the peer
|
||||||
|
31
management/server/peer/peer_test.go
Normal file
31
management/server/peer/peer_test.go
Normal file
@ -0,0 +1,31 @@
|
|||||||
|
package peer
|
||||||
|
|
||||||
|
import (
|
||||||
|
"fmt"
|
||||||
|
"testing"
|
||||||
|
)
|
||||||
|
|
||||||
|
// FQDNOld is the original implementation for benchmarking purposes
|
||||||
|
func (p *Peer) FQDNOld(dnsDomain string) string {
|
||||||
|
if dnsDomain == "" {
|
||||||
|
return ""
|
||||||
|
}
|
||||||
|
return fmt.Sprintf("%s.%s", p.DNSLabel, dnsDomain)
|
||||||
|
}
|
||||||
|
|
||||||
|
func BenchmarkFQDN(b *testing.B) {
|
||||||
|
p := &Peer{DNSLabel: "test-peer"}
|
||||||
|
dnsDomain := "example.com"
|
||||||
|
|
||||||
|
b.Run("Old", func(b *testing.B) {
|
||||||
|
for i := 0; i < b.N; i++ {
|
||||||
|
p.FQDNOld(dnsDomain)
|
||||||
|
}
|
||||||
|
})
|
||||||
|
|
||||||
|
b.Run("New", func(b *testing.B) {
|
||||||
|
for i := 0; i < b.N; i++ {
|
||||||
|
p.FQDN(dnsDomain)
|
||||||
|
}
|
||||||
|
})
|
||||||
|
}
|
@ -2,15 +2,26 @@ package server
|
|||||||
|
|
||||||
import (
|
import (
|
||||||
"context"
|
"context"
|
||||||
|
"fmt"
|
||||||
|
"io"
|
||||||
|
"net"
|
||||||
|
"net/netip"
|
||||||
|
"os"
|
||||||
"testing"
|
"testing"
|
||||||
"time"
|
"time"
|
||||||
|
|
||||||
"github.com/rs/xid"
|
"github.com/rs/xid"
|
||||||
|
log "github.com/sirupsen/logrus"
|
||||||
"github.com/stretchr/testify/assert"
|
"github.com/stretchr/testify/assert"
|
||||||
"golang.zx2c4.com/wireguard/wgctrl/wgtypes"
|
"golang.zx2c4.com/wireguard/wgctrl/wgtypes"
|
||||||
|
|
||||||
|
nbdns "github.com/netbirdio/netbird/dns"
|
||||||
|
"github.com/netbirdio/netbird/management/domain"
|
||||||
|
"github.com/netbirdio/netbird/management/proto"
|
||||||
nbgroup "github.com/netbirdio/netbird/management/server/group"
|
nbgroup "github.com/netbirdio/netbird/management/server/group"
|
||||||
nbpeer "github.com/netbirdio/netbird/management/server/peer"
|
nbpeer "github.com/netbirdio/netbird/management/server/peer"
|
||||||
|
"github.com/netbirdio/netbird/management/server/posture"
|
||||||
|
nbroute "github.com/netbirdio/netbird/route"
|
||||||
)
|
)
|
||||||
|
|
||||||
func TestPeer_LoginExpired(t *testing.T) {
|
func TestPeer_LoginExpired(t *testing.T) {
|
||||||
@ -633,3 +644,354 @@ func TestDefaultAccountManager_GetPeers(t *testing.T) {
|
|||||||
}
|
}
|
||||||
|
|
||||||
}
|
}
|
||||||
|
|
||||||
|
func setupTestAccountManager(b *testing.B, peers int, groups int) (*DefaultAccountManager, string, string, error) {
|
||||||
|
b.Helper()
|
||||||
|
|
||||||
|
manager, err := createManager(b)
|
||||||
|
if err != nil {
|
||||||
|
return nil, "", "", err
|
||||||
|
}
|
||||||
|
|
||||||
|
accountID := "test_account"
|
||||||
|
adminUser := "account_creator"
|
||||||
|
regularUser := "regular_user"
|
||||||
|
|
||||||
|
account := newAccountWithId(context.Background(), accountID, adminUser, "")
|
||||||
|
account.Users[regularUser] = &User{
|
||||||
|
Id: regularUser,
|
||||||
|
Role: UserRoleUser,
|
||||||
|
}
|
||||||
|
|
||||||
|
// Create peers
|
||||||
|
for i := 0; i < peers; i++ {
|
||||||
|
peerKey, _ := wgtypes.GeneratePrivateKey()
|
||||||
|
peer := &nbpeer.Peer{
|
||||||
|
ID: fmt.Sprintf("peer-%d", i),
|
||||||
|
DNSLabel: fmt.Sprintf("peer-%d", i),
|
||||||
|
Key: peerKey.PublicKey().String(),
|
||||||
|
IP: net.ParseIP(fmt.Sprintf("100.64.%d.%d", i/256, i%256)),
|
||||||
|
Status: &nbpeer.PeerStatus{},
|
||||||
|
UserID: regularUser,
|
||||||
|
}
|
||||||
|
account.Peers[peer.ID] = peer
|
||||||
|
}
|
||||||
|
|
||||||
|
// Create groups and policies
|
||||||
|
account.Policies = make([]*Policy, 0, groups)
|
||||||
|
for i := 0; i < groups; i++ {
|
||||||
|
groupID := fmt.Sprintf("group-%d", i)
|
||||||
|
group := &nbgroup.Group{
|
||||||
|
ID: groupID,
|
||||||
|
Name: fmt.Sprintf("Group %d", i),
|
||||||
|
}
|
||||||
|
for j := 0; j < peers/groups; j++ {
|
||||||
|
peerIndex := i*(peers/groups) + j
|
||||||
|
group.Peers = append(group.Peers, fmt.Sprintf("peer-%d", peerIndex))
|
||||||
|
}
|
||||||
|
account.Groups[groupID] = group
|
||||||
|
|
||||||
|
// Create a policy for this group
|
||||||
|
policy := &Policy{
|
||||||
|
ID: fmt.Sprintf("policy-%d", i),
|
||||||
|
Name: fmt.Sprintf("Policy for Group %d", i),
|
||||||
|
Enabled: true,
|
||||||
|
Rules: []*PolicyRule{
|
||||||
|
{
|
||||||
|
ID: fmt.Sprintf("rule-%d", i),
|
||||||
|
Name: fmt.Sprintf("Rule for Group %d", i),
|
||||||
|
Enabled: true,
|
||||||
|
Sources: []string{groupID},
|
||||||
|
Destinations: []string{groupID},
|
||||||
|
Bidirectional: true,
|
||||||
|
Protocol: PolicyRuleProtocolALL,
|
||||||
|
Action: PolicyTrafficActionAccept,
|
||||||
|
},
|
||||||
|
},
|
||||||
|
}
|
||||||
|
account.Policies = append(account.Policies, policy)
|
||||||
|
}
|
||||||
|
|
||||||
|
account.PostureChecks = []*posture.Checks{
|
||||||
|
{
|
||||||
|
ID: "PostureChecksAll",
|
||||||
|
Name: "All",
|
||||||
|
Checks: posture.ChecksDefinition{
|
||||||
|
NBVersionCheck: &posture.NBVersionCheck{
|
||||||
|
MinVersion: "0.0.1",
|
||||||
|
},
|
||||||
|
},
|
||||||
|
},
|
||||||
|
}
|
||||||
|
|
||||||
|
err = manager.Store.SaveAccount(context.Background(), account)
|
||||||
|
if err != nil {
|
||||||
|
return nil, "", "", err
|
||||||
|
}
|
||||||
|
|
||||||
|
return manager, accountID, regularUser, nil
|
||||||
|
}
|
||||||
|
|
||||||
|
func BenchmarkGetPeers(b *testing.B) {
|
||||||
|
benchCases := []struct {
|
||||||
|
name string
|
||||||
|
peers int
|
||||||
|
groups int
|
||||||
|
}{
|
||||||
|
{"Small", 50, 5},
|
||||||
|
{"Medium", 500, 10},
|
||||||
|
{"Large", 5000, 20},
|
||||||
|
{"Small single", 50, 1},
|
||||||
|
{"Medium single", 500, 1},
|
||||||
|
{"Large 5", 5000, 5},
|
||||||
|
}
|
||||||
|
|
||||||
|
log.SetOutput(io.Discard)
|
||||||
|
defer log.SetOutput(os.Stderr)
|
||||||
|
for _, bc := range benchCases {
|
||||||
|
b.Run(bc.name, func(b *testing.B) {
|
||||||
|
manager, accountID, userID, err := setupTestAccountManager(b, bc.peers, bc.groups)
|
||||||
|
if err != nil {
|
||||||
|
b.Fatalf("Failed to setup test account manager: %v", err)
|
||||||
|
}
|
||||||
|
|
||||||
|
b.ResetTimer()
|
||||||
|
for i := 0; i < b.N; i++ {
|
||||||
|
_, err := manager.GetPeers(context.Background(), accountID, userID)
|
||||||
|
if err != nil {
|
||||||
|
b.Fatalf("GetPeers failed: %v", err)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
})
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func BenchmarkUpdateAccountPeers(b *testing.B) {
|
||||||
|
benchCases := []struct {
|
||||||
|
name string
|
||||||
|
peers int
|
||||||
|
groups int
|
||||||
|
}{
|
||||||
|
{"Small", 50, 5},
|
||||||
|
{"Medium", 500, 10},
|
||||||
|
{"Large", 5000, 20},
|
||||||
|
{"Small single", 50, 1},
|
||||||
|
{"Medium single", 500, 1},
|
||||||
|
{"Large 5", 5000, 5},
|
||||||
|
}
|
||||||
|
|
||||||
|
log.SetOutput(io.Discard)
|
||||||
|
defer log.SetOutput(os.Stderr)
|
||||||
|
|
||||||
|
for _, bc := range benchCases {
|
||||||
|
b.Run(bc.name, func(b *testing.B) {
|
||||||
|
manager, accountID, _, err := setupTestAccountManager(b, bc.peers, bc.groups)
|
||||||
|
if err != nil {
|
||||||
|
b.Fatalf("Failed to setup test account manager: %v", err)
|
||||||
|
}
|
||||||
|
|
||||||
|
ctx := context.Background()
|
||||||
|
|
||||||
|
account, err := manager.Store.GetAccount(ctx, accountID)
|
||||||
|
if err != nil {
|
||||||
|
b.Fatalf("Failed to get account: %v", err)
|
||||||
|
}
|
||||||
|
|
||||||
|
peerChannels := make(map[string]chan *UpdateMessage)
|
||||||
|
|
||||||
|
for peerID := range account.Peers {
|
||||||
|
peerChannels[peerID] = make(chan *UpdateMessage, channelBufferSize)
|
||||||
|
}
|
||||||
|
|
||||||
|
manager.peersUpdateManager.peerChannels = peerChannels
|
||||||
|
|
||||||
|
b.ResetTimer()
|
||||||
|
start := time.Now()
|
||||||
|
|
||||||
|
for i := 0; i < b.N; i++ {
|
||||||
|
manager.updateAccountPeers(ctx, account)
|
||||||
|
}
|
||||||
|
|
||||||
|
duration := time.Since(start)
|
||||||
|
b.ReportMetric(float64(duration.Nanoseconds())/float64(b.N)/1e6, "ms/op")
|
||||||
|
b.ReportMetric(0, "ns/op")
|
||||||
|
})
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestToSyncResponse(t *testing.T) {
|
||||||
|
_, ipnet, err := net.ParseCIDR("192.168.1.0/24")
|
||||||
|
if err != nil {
|
||||||
|
t.Fatal(err)
|
||||||
|
}
|
||||||
|
domainList, err := domain.FromStringList([]string{"example.com"})
|
||||||
|
if err != nil {
|
||||||
|
t.Fatal(err)
|
||||||
|
}
|
||||||
|
|
||||||
|
config := &Config{
|
||||||
|
Signal: &Host{
|
||||||
|
Proto: "https",
|
||||||
|
URI: "signal.uri",
|
||||||
|
Username: "",
|
||||||
|
Password: "",
|
||||||
|
},
|
||||||
|
Stuns: []*Host{{URI: "stun.uri", Proto: UDP}},
|
||||||
|
TURNConfig: &TURNConfig{
|
||||||
|
Turns: []*Host{{URI: "turn.uri", Proto: UDP, Username: "turn-user", Password: "turn-pass"}},
|
||||||
|
},
|
||||||
|
}
|
||||||
|
peer := &nbpeer.Peer{
|
||||||
|
IP: net.ParseIP("192.168.1.1"),
|
||||||
|
SSHEnabled: true,
|
||||||
|
Key: "peer-key",
|
||||||
|
DNSLabel: "peer1",
|
||||||
|
SSHKey: "peer1-ssh-key",
|
||||||
|
}
|
||||||
|
turnCredentials := &TURNCredentials{
|
||||||
|
Username: "turn-user",
|
||||||
|
Password: "turn-pass",
|
||||||
|
}
|
||||||
|
networkMap := &NetworkMap{
|
||||||
|
Network: &Network{Net: *ipnet, Serial: 1000},
|
||||||
|
Peers: []*nbpeer.Peer{{IP: net.ParseIP("192.168.1.2"), Key: "peer2-key", DNSLabel: "peer2", SSHEnabled: true, SSHKey: "peer2-ssh-key"}},
|
||||||
|
OfflinePeers: []*nbpeer.Peer{{IP: net.ParseIP("192.168.1.3"), Key: "peer3-key", DNSLabel: "peer3", SSHEnabled: true, SSHKey: "peer3-ssh-key"}},
|
||||||
|
Routes: []*nbroute.Route{
|
||||||
|
{
|
||||||
|
ID: "route1",
|
||||||
|
Network: netip.MustParsePrefix("10.0.0.0/24"),
|
||||||
|
Domains: domainList,
|
||||||
|
KeepRoute: true,
|
||||||
|
NetID: "route1",
|
||||||
|
Peer: "peer1",
|
||||||
|
NetworkType: 1,
|
||||||
|
Masquerade: true,
|
||||||
|
Metric: 9999,
|
||||||
|
Enabled: true,
|
||||||
|
},
|
||||||
|
},
|
||||||
|
DNSConfig: nbdns.Config{
|
||||||
|
ServiceEnable: true,
|
||||||
|
NameServerGroups: []*nbdns.NameServerGroup{
|
||||||
|
{
|
||||||
|
NameServers: []nbdns.NameServer{{
|
||||||
|
IP: netip.MustParseAddr("8.8.8.8"),
|
||||||
|
NSType: nbdns.UDPNameServerType,
|
||||||
|
Port: nbdns.DefaultDNSPort,
|
||||||
|
}},
|
||||||
|
Primary: true,
|
||||||
|
Domains: []string{"example.com"},
|
||||||
|
Enabled: true,
|
||||||
|
SearchDomainsEnabled: true,
|
||||||
|
},
|
||||||
|
{
|
||||||
|
ID: "ns1",
|
||||||
|
NameServers: []nbdns.NameServer{{
|
||||||
|
IP: netip.MustParseAddr("1.1.1.1"),
|
||||||
|
NSType: nbdns.UDPNameServerType,
|
||||||
|
Port: nbdns.DefaultDNSPort,
|
||||||
|
}},
|
||||||
|
Groups: []string{"group1"},
|
||||||
|
Primary: true,
|
||||||
|
Domains: []string{"example.com"},
|
||||||
|
Enabled: true,
|
||||||
|
SearchDomainsEnabled: true,
|
||||||
|
},
|
||||||
|
},
|
||||||
|
CustomZones: []nbdns.CustomZone{{Domain: "example.com", Records: []nbdns.SimpleRecord{{Name: "example.com", Type: 1, Class: "IN", TTL: 60, RData: "100.64.0.1"}}}},
|
||||||
|
},
|
||||||
|
FirewallRules: []*FirewallRule{
|
||||||
|
{PeerIP: "192.168.1.2", Direction: firewallRuleDirectionIN, Action: string(PolicyTrafficActionAccept), Protocol: string(PolicyRuleProtocolTCP), Port: "80"},
|
||||||
|
},
|
||||||
|
}
|
||||||
|
dnsName := "example.com"
|
||||||
|
checks := []*posture.Checks{
|
||||||
|
{
|
||||||
|
Checks: posture.ChecksDefinition{
|
||||||
|
ProcessCheck: &posture.ProcessCheck{
|
||||||
|
Processes: []posture.Process{{LinuxPath: "/usr/bin/netbird"}},
|
||||||
|
},
|
||||||
|
},
|
||||||
|
},
|
||||||
|
}
|
||||||
|
dnsCache := &DNSConfigCache{}
|
||||||
|
|
||||||
|
response := toSyncResponse(context.Background(), config, peer, turnCredentials, networkMap, dnsName, checks, dnsCache)
|
||||||
|
|
||||||
|
assert.NotNil(t, response)
|
||||||
|
// assert peer config
|
||||||
|
assert.Equal(t, "192.168.1.1/24", response.PeerConfig.Address)
|
||||||
|
assert.Equal(t, "peer1.example.com", response.PeerConfig.Fqdn)
|
||||||
|
assert.Equal(t, true, response.PeerConfig.SshConfig.SshEnabled)
|
||||||
|
// assert wiretrustee config
|
||||||
|
assert.Equal(t, "signal.uri", response.WiretrusteeConfig.Signal.Uri)
|
||||||
|
assert.Equal(t, proto.HostConfig_HTTPS, response.WiretrusteeConfig.Signal.GetProtocol())
|
||||||
|
assert.Equal(t, "stun.uri", response.WiretrusteeConfig.Stuns[0].Uri)
|
||||||
|
assert.Equal(t, "turn.uri", response.WiretrusteeConfig.Turns[0].HostConfig.GetUri())
|
||||||
|
assert.Equal(t, "turn-user", response.WiretrusteeConfig.Turns[0].User)
|
||||||
|
assert.Equal(t, "turn-pass", response.WiretrusteeConfig.Turns[0].Password)
|
||||||
|
// assert RemotePeers
|
||||||
|
assert.Equal(t, 1, len(response.RemotePeers))
|
||||||
|
assert.Equal(t, "192.168.1.2/32", response.RemotePeers[0].AllowedIps[0])
|
||||||
|
assert.Equal(t, "peer2-key", response.RemotePeers[0].WgPubKey)
|
||||||
|
assert.Equal(t, "peer2.example.com", response.RemotePeers[0].GetFqdn())
|
||||||
|
assert.Equal(t, false, response.RemotePeers[0].GetSshConfig().GetSshEnabled())
|
||||||
|
assert.Equal(t, []byte("peer2-ssh-key"), response.RemotePeers[0].GetSshConfig().GetSshPubKey())
|
||||||
|
// assert network map
|
||||||
|
assert.Equal(t, uint64(1000), response.NetworkMap.Serial)
|
||||||
|
assert.Equal(t, "192.168.1.1/24", response.NetworkMap.PeerConfig.Address)
|
||||||
|
assert.Equal(t, "peer1.example.com", response.NetworkMap.PeerConfig.Fqdn)
|
||||||
|
assert.Equal(t, true, response.NetworkMap.PeerConfig.SshConfig.SshEnabled)
|
||||||
|
// assert network map RemotePeers
|
||||||
|
assert.Equal(t, 1, len(response.NetworkMap.RemotePeers))
|
||||||
|
assert.Equal(t, "192.168.1.2/32", response.NetworkMap.RemotePeers[0].AllowedIps[0])
|
||||||
|
assert.Equal(t, "peer2-key", response.NetworkMap.RemotePeers[0].WgPubKey)
|
||||||
|
assert.Equal(t, "peer2.example.com", response.NetworkMap.RemotePeers[0].GetFqdn())
|
||||||
|
assert.Equal(t, []byte("peer2-ssh-key"), response.NetworkMap.RemotePeers[0].GetSshConfig().GetSshPubKey())
|
||||||
|
// assert network map OfflinePeers
|
||||||
|
assert.Equal(t, 1, len(response.NetworkMap.OfflinePeers))
|
||||||
|
assert.Equal(t, "192.168.1.3/32", response.NetworkMap.OfflinePeers[0].AllowedIps[0])
|
||||||
|
assert.Equal(t, "peer3-key", response.NetworkMap.OfflinePeers[0].WgPubKey)
|
||||||
|
assert.Equal(t, "peer3.example.com", response.NetworkMap.OfflinePeers[0].GetFqdn())
|
||||||
|
assert.Equal(t, []byte("peer3-ssh-key"), response.NetworkMap.OfflinePeers[0].GetSshConfig().GetSshPubKey())
|
||||||
|
// assert network map Routes
|
||||||
|
assert.Equal(t, 1, len(response.NetworkMap.Routes))
|
||||||
|
assert.Equal(t, "10.0.0.0/24", response.NetworkMap.Routes[0].Network)
|
||||||
|
assert.Equal(t, "route1", response.NetworkMap.Routes[0].ID)
|
||||||
|
assert.Equal(t, "peer1", response.NetworkMap.Routes[0].Peer)
|
||||||
|
assert.Equal(t, "example.com", response.NetworkMap.Routes[0].Domains[0])
|
||||||
|
assert.Equal(t, true, response.NetworkMap.Routes[0].KeepRoute)
|
||||||
|
assert.Equal(t, true, response.NetworkMap.Routes[0].Masquerade)
|
||||||
|
assert.Equal(t, int64(9999), response.NetworkMap.Routes[0].Metric)
|
||||||
|
assert.Equal(t, int64(1), response.NetworkMap.Routes[0].NetworkType)
|
||||||
|
assert.Equal(t, "route1", response.NetworkMap.Routes[0].NetID)
|
||||||
|
// assert network map DNSConfig
|
||||||
|
assert.Equal(t, true, response.NetworkMap.DNSConfig.ServiceEnable)
|
||||||
|
assert.Equal(t, 1, len(response.NetworkMap.DNSConfig.CustomZones))
|
||||||
|
assert.Equal(t, 2, len(response.NetworkMap.DNSConfig.NameServerGroups))
|
||||||
|
// assert network map DNSConfig.CustomZones
|
||||||
|
assert.Equal(t, "example.com", response.NetworkMap.DNSConfig.CustomZones[0].Domain)
|
||||||
|
assert.Equal(t, 1, len(response.NetworkMap.DNSConfig.CustomZones[0].Records))
|
||||||
|
assert.Equal(t, "example.com", response.NetworkMap.DNSConfig.CustomZones[0].Records[0].Name)
|
||||||
|
assert.Equal(t, int64(1), response.NetworkMap.DNSConfig.CustomZones[0].Records[0].Type)
|
||||||
|
assert.Equal(t, "IN", response.NetworkMap.DNSConfig.CustomZones[0].Records[0].Class)
|
||||||
|
assert.Equal(t, int64(60), response.NetworkMap.DNSConfig.CustomZones[0].Records[0].TTL)
|
||||||
|
assert.Equal(t, "100.64.0.1", response.NetworkMap.DNSConfig.CustomZones[0].Records[0].RData)
|
||||||
|
// assert network map DNSConfig.NameServerGroups
|
||||||
|
assert.Equal(t, true, response.NetworkMap.DNSConfig.NameServerGroups[0].Primary)
|
||||||
|
assert.Equal(t, true, response.NetworkMap.DNSConfig.NameServerGroups[0].SearchDomainsEnabled)
|
||||||
|
assert.Equal(t, "example.com", response.NetworkMap.DNSConfig.NameServerGroups[0].Domains[0])
|
||||||
|
assert.Equal(t, "8.8.8.8", response.NetworkMap.DNSConfig.NameServerGroups[0].NameServers[0].GetIP())
|
||||||
|
assert.Equal(t, int64(1), response.NetworkMap.DNSConfig.NameServerGroups[0].NameServers[0].GetNSType())
|
||||||
|
assert.Equal(t, int64(53), response.NetworkMap.DNSConfig.NameServerGroups[0].NameServers[0].GetPort())
|
||||||
|
// assert network map Firewall
|
||||||
|
assert.Equal(t, 1, len(response.NetworkMap.FirewallRules))
|
||||||
|
assert.Equal(t, "192.168.1.2", response.NetworkMap.FirewallRules[0].PeerIP)
|
||||||
|
assert.Equal(t, proto.FirewallRule_IN, response.NetworkMap.FirewallRules[0].Direction)
|
||||||
|
assert.Equal(t, proto.FirewallRule_ACCEPT, response.NetworkMap.FirewallRules[0].Action)
|
||||||
|
assert.Equal(t, proto.FirewallRule_TCP, response.NetworkMap.FirewallRules[0].Protocol)
|
||||||
|
assert.Equal(t, "80", response.NetworkMap.FirewallRules[0].Port)
|
||||||
|
// assert posture checks
|
||||||
|
assert.Equal(t, 1, len(response.Checks))
|
||||||
|
assert.Equal(t, "/usr/bin/netbird", response.Checks[0].Files[0])
|
||||||
|
}
|
||||||
|
@ -213,7 +213,6 @@ type FirewallRule struct {
|
|||||||
//
|
//
|
||||||
// This function returns the list of peers and firewall rules that are applicable to a given peer.
|
// This function returns the list of peers and firewall rules that are applicable to a given peer.
|
||||||
func (a *Account) getPeerConnectionResources(ctx context.Context, peerID string, validatedPeersMap map[string]struct{}) ([]*nbpeer.Peer, []*FirewallRule) {
|
func (a *Account) getPeerConnectionResources(ctx context.Context, peerID string, validatedPeersMap map[string]struct{}) ([]*nbpeer.Peer, []*FirewallRule) {
|
||||||
|
|
||||||
generateResources, getAccumulatedResources := a.connResourcesGenerator(ctx)
|
generateResources, getAccumulatedResources := a.connResourcesGenerator(ctx)
|
||||||
for _, policy := range a.Policies {
|
for _, policy := range a.Policies {
|
||||||
if !policy.Enabled {
|
if !policy.Enabled {
|
||||||
@ -225,8 +224,8 @@ func (a *Account) getPeerConnectionResources(ctx context.Context, peerID string,
|
|||||||
continue
|
continue
|
||||||
}
|
}
|
||||||
|
|
||||||
sourcePeers, peerInSources := getAllPeersFromGroups(ctx, a, rule.Sources, peerID, policy.SourcePostureChecks, validatedPeersMap)
|
sourcePeers, peerInSources := a.getAllPeersFromGroups(ctx, rule.Sources, peerID, policy.SourcePostureChecks, validatedPeersMap)
|
||||||
destinationPeers, peerInDestinations := getAllPeersFromGroups(ctx, a, rule.Destinations, peerID, nil, validatedPeersMap)
|
destinationPeers, peerInDestinations := a.getAllPeersFromGroups(ctx, rule.Destinations, peerID, nil, validatedPeersMap)
|
||||||
|
|
||||||
if rule.Bidirectional {
|
if rule.Bidirectional {
|
||||||
if peerInSources {
|
if peerInSources {
|
||||||
@ -290,8 +289,8 @@ func (a *Account) connResourcesGenerator(ctx context.Context) (func(*PolicyRule,
|
|||||||
fr.PeerIP = "0.0.0.0"
|
fr.PeerIP = "0.0.0.0"
|
||||||
}
|
}
|
||||||
|
|
||||||
ruleID := (rule.ID + fr.PeerIP + strconv.Itoa(direction) +
|
ruleID := rule.ID + fr.PeerIP + strconv.Itoa(direction) +
|
||||||
fr.Protocol + fr.Action + strings.Join(rule.Ports, ","))
|
fr.Protocol + fr.Action + strings.Join(rule.Ports, ",")
|
||||||
if _, ok := rulesExists[ruleID]; ok {
|
if _, ok := rulesExists[ruleID]; ok {
|
||||||
continue
|
continue
|
||||||
}
|
}
|
||||||
@ -491,23 +490,23 @@ func toProtocolFirewallRules(update []*FirewallRule) []*proto.FirewallRule {
|
|||||||
//
|
//
|
||||||
// Important: Posture checks are applicable only to source group peers,
|
// Important: Posture checks are applicable only to source group peers,
|
||||||
// for destination group peers, call this method with an empty list of sourcePostureChecksIDs
|
// for destination group peers, call this method with an empty list of sourcePostureChecksIDs
|
||||||
func getAllPeersFromGroups(ctx context.Context, account *Account, groups []string, peerID string, sourcePostureChecksIDs []string, validatedPeersMap map[string]struct{}) ([]*nbpeer.Peer, bool) {
|
func (a *Account) getAllPeersFromGroups(ctx context.Context, groups []string, peerID string, sourcePostureChecksIDs []string, validatedPeersMap map[string]struct{}) ([]*nbpeer.Peer, bool) {
|
||||||
peerInGroups := false
|
peerInGroups := false
|
||||||
filteredPeers := make([]*nbpeer.Peer, 0, len(groups))
|
filteredPeers := make([]*nbpeer.Peer, 0, len(groups))
|
||||||
for _, g := range groups {
|
for _, g := range groups {
|
||||||
group, ok := account.Groups[g]
|
group, ok := a.Groups[g]
|
||||||
if !ok {
|
if !ok {
|
||||||
continue
|
continue
|
||||||
}
|
}
|
||||||
|
|
||||||
for _, p := range group.Peers {
|
for _, p := range group.Peers {
|
||||||
peer, ok := account.Peers[p]
|
peer, ok := a.Peers[p]
|
||||||
if !ok || peer == nil {
|
if !ok || peer == nil {
|
||||||
continue
|
continue
|
||||||
}
|
}
|
||||||
|
|
||||||
// validate the peer based on policy posture checks applied
|
// validate the peer based on policy posture checks applied
|
||||||
isValid := account.validatePostureChecksOnPeer(ctx, sourcePostureChecksIDs, peer.ID)
|
isValid := a.validatePostureChecksOnPeer(ctx, sourcePostureChecksIDs, peer.ID)
|
||||||
if !isValid {
|
if !isValid {
|
||||||
continue
|
continue
|
||||||
}
|
}
|
||||||
@ -535,7 +534,7 @@ func (a *Account) validatePostureChecksOnPeer(ctx context.Context, sourcePosture
|
|||||||
}
|
}
|
||||||
|
|
||||||
for _, postureChecksID := range sourcePostureChecksID {
|
for _, postureChecksID := range sourcePostureChecksID {
|
||||||
postureChecks := getPostureChecks(a, postureChecksID)
|
postureChecks := a.getPostureChecks(postureChecksID)
|
||||||
if postureChecks == nil {
|
if postureChecks == nil {
|
||||||
continue
|
continue
|
||||||
}
|
}
|
||||||
@ -553,8 +552,8 @@ func (a *Account) validatePostureChecksOnPeer(ctx context.Context, sourcePosture
|
|||||||
return true
|
return true
|
||||||
}
|
}
|
||||||
|
|
||||||
func getPostureChecks(account *Account, postureChecksID string) *posture.Checks {
|
func (a *Account) getPostureChecks(postureChecksID string) *posture.Checks {
|
||||||
for _, postureChecks := range account.PostureChecks {
|
for _, postureChecks := range a.PostureChecks {
|
||||||
if postureChecks.ID == postureChecksID {
|
if postureChecks.ID == postureChecksID {
|
||||||
return postureChecks
|
return postureChecks
|
||||||
}
|
}
|
||||||
|
@ -13,6 +13,7 @@ import (
|
|||||||
"github.com/netbirdio/netbird/management/server/activity"
|
"github.com/netbirdio/netbird/management/server/activity"
|
||||||
nbgroup "github.com/netbirdio/netbird/management/server/group"
|
nbgroup "github.com/netbirdio/netbird/management/server/group"
|
||||||
nbpeer "github.com/netbirdio/netbird/management/server/peer"
|
nbpeer "github.com/netbirdio/netbird/management/server/peer"
|
||||||
|
"github.com/netbirdio/netbird/management/server/telemetry"
|
||||||
"github.com/netbirdio/netbird/route"
|
"github.com/netbirdio/netbird/route"
|
||||||
)
|
)
|
||||||
|
|
||||||
@ -1233,7 +1234,11 @@ func createRouterManager(t *testing.T) (*DefaultAccountManager, error) {
|
|||||||
return nil, err
|
return nil, err
|
||||||
}
|
}
|
||||||
eventStore := &activity.InMemoryEventStore{}
|
eventStore := &activity.InMemoryEventStore{}
|
||||||
return BuildManager(context.Background(), store, NewPeersUpdateManager(nil), nil, "", "netbird.selfhosted", eventStore, nil, false, MocIntegratedValidator{})
|
|
||||||
|
metrics, err := telemetry.NewDefaultAppMetrics(context.Background())
|
||||||
|
require.NoError(t, err)
|
||||||
|
|
||||||
|
return BuildManager(context.Background(), store, NewPeersUpdateManager(nil), nil, "", "netbird.selfhosted", eventStore, nil, false, MocIntegratedValidator{}, metrics)
|
||||||
}
|
}
|
||||||
|
|
||||||
func createRouterStore(t *testing.T) (Store, error) {
|
func createRouterStore(t *testing.T) (Store, error) {
|
||||||
|
69
management/server/telemetry/accountmanager_metrics.go
Normal file
69
management/server/telemetry/accountmanager_metrics.go
Normal file
@ -0,0 +1,69 @@
|
|||||||
|
package telemetry
|
||||||
|
|
||||||
|
import (
|
||||||
|
"context"
|
||||||
|
"time"
|
||||||
|
|
||||||
|
"go.opentelemetry.io/otel/metric"
|
||||||
|
)
|
||||||
|
|
||||||
|
// AccountManagerMetrics represents all metrics related to the AccountManager
|
||||||
|
type AccountManagerMetrics struct {
|
||||||
|
ctx context.Context
|
||||||
|
updateAccountPeersDurationMs metric.Float64Histogram
|
||||||
|
getPeerNetworkMapDurationMs metric.Float64Histogram
|
||||||
|
networkMapObjectCount metric.Int64Histogram
|
||||||
|
}
|
||||||
|
|
||||||
|
// NewAccountManagerMetrics creates an instance of AccountManagerMetrics
|
||||||
|
func NewAccountManagerMetrics(ctx context.Context, meter metric.Meter) (*AccountManagerMetrics, error) {
|
||||||
|
updateAccountPeersDurationMs, err := meter.Float64Histogram("management.account.update.account.peers.duration.ms",
|
||||||
|
metric.WithUnit("milliseconds"),
|
||||||
|
metric.WithExplicitBucketBoundaries(
|
||||||
|
0.5, 1, 2.5, 5, 10, 25, 50, 100, 250, 500, 1000, 2500, 5000, 10000, 30000,
|
||||||
|
))
|
||||||
|
if err != nil {
|
||||||
|
return nil, err
|
||||||
|
}
|
||||||
|
|
||||||
|
getPeerNetworkMapDurationMs, err := meter.Float64Histogram("management.account.get.peer.network.map.duration.ms",
|
||||||
|
metric.WithUnit("milliseconds"),
|
||||||
|
metric.WithExplicitBucketBoundaries(
|
||||||
|
0.1, 0.5, 1, 2.5, 5, 10, 25, 50, 100, 250, 500, 1000,
|
||||||
|
))
|
||||||
|
if err != nil {
|
||||||
|
return nil, err
|
||||||
|
}
|
||||||
|
|
||||||
|
networkMapObjectCount, err := meter.Int64Histogram("management.account.network.map.object.count",
|
||||||
|
metric.WithUnit("objects"),
|
||||||
|
metric.WithExplicitBucketBoundaries(
|
||||||
|
50, 100, 200, 500, 1000, 2500, 5000, 10000,
|
||||||
|
))
|
||||||
|
if err != nil {
|
||||||
|
return nil, err
|
||||||
|
}
|
||||||
|
|
||||||
|
return &AccountManagerMetrics{
|
||||||
|
ctx: ctx,
|
||||||
|
getPeerNetworkMapDurationMs: getPeerNetworkMapDurationMs,
|
||||||
|
updateAccountPeersDurationMs: updateAccountPeersDurationMs,
|
||||||
|
networkMapObjectCount: networkMapObjectCount,
|
||||||
|
}, nil
|
||||||
|
|
||||||
|
}
|
||||||
|
|
||||||
|
// CountUpdateAccountPeersDuration counts the duration of updating account peers
|
||||||
|
func (metrics *AccountManagerMetrics) CountUpdateAccountPeersDuration(duration time.Duration) {
|
||||||
|
metrics.updateAccountPeersDurationMs.Record(metrics.ctx, float64(duration.Nanoseconds())/1e6)
|
||||||
|
}
|
||||||
|
|
||||||
|
// CountGetPeerNetworkMapDuration counts the duration of getting the peer network map
|
||||||
|
func (metrics *AccountManagerMetrics) CountGetPeerNetworkMapDuration(duration time.Duration) {
|
||||||
|
metrics.getPeerNetworkMapDurationMs.Record(metrics.ctx, float64(duration.Nanoseconds())/1e6)
|
||||||
|
}
|
||||||
|
|
||||||
|
// CountNetworkMapObjects counts the number of network map objects
|
||||||
|
func (metrics *AccountManagerMetrics) CountNetworkMapObjects(count int64) {
|
||||||
|
metrics.networkMapObjectCount.Record(metrics.ctx, count)
|
||||||
|
}
|
@ -20,14 +20,15 @@ const defaultEndpoint = "/metrics"
|
|||||||
|
|
||||||
// MockAppMetrics mocks the AppMetrics interface
|
// MockAppMetrics mocks the AppMetrics interface
|
||||||
type MockAppMetrics struct {
|
type MockAppMetrics struct {
|
||||||
GetMeterFunc func() metric2.Meter
|
GetMeterFunc func() metric2.Meter
|
||||||
CloseFunc func() error
|
CloseFunc func() error
|
||||||
ExposeFunc func(ctx context.Context, port int, endpoint string) error
|
ExposeFunc func(ctx context.Context, port int, endpoint string) error
|
||||||
IDPMetricsFunc func() *IDPMetrics
|
IDPMetricsFunc func() *IDPMetrics
|
||||||
HTTPMiddlewareFunc func() *HTTPMiddleware
|
HTTPMiddlewareFunc func() *HTTPMiddleware
|
||||||
GRPCMetricsFunc func() *GRPCMetrics
|
GRPCMetricsFunc func() *GRPCMetrics
|
||||||
StoreMetricsFunc func() *StoreMetrics
|
StoreMetricsFunc func() *StoreMetrics
|
||||||
UpdateChannelMetricsFunc func() *UpdateChannelMetrics
|
UpdateChannelMetricsFunc func() *UpdateChannelMetrics
|
||||||
|
AddAccountManagerMetricsFunc func() *AccountManagerMetrics
|
||||||
}
|
}
|
||||||
|
|
||||||
// GetMeter mocks the GetMeter function of the AppMetrics interface
|
// GetMeter mocks the GetMeter function of the AppMetrics interface
|
||||||
@ -94,6 +95,14 @@ func (mock *MockAppMetrics) UpdateChannelMetrics() *UpdateChannelMetrics {
|
|||||||
return nil
|
return nil
|
||||||
}
|
}
|
||||||
|
|
||||||
|
// AccountManagerMetrics mocks the MockAppMetrics function of the AccountManagerMetrics interface
|
||||||
|
func (mock *MockAppMetrics) AccountManagerMetrics() *AccountManagerMetrics {
|
||||||
|
if mock.AddAccountManagerMetricsFunc != nil {
|
||||||
|
return mock.AddAccountManagerMetricsFunc()
|
||||||
|
}
|
||||||
|
return nil
|
||||||
|
}
|
||||||
|
|
||||||
// AppMetrics is metrics interface
|
// AppMetrics is metrics interface
|
||||||
type AppMetrics interface {
|
type AppMetrics interface {
|
||||||
GetMeter() metric2.Meter
|
GetMeter() metric2.Meter
|
||||||
@ -104,19 +113,21 @@ type AppMetrics interface {
|
|||||||
GRPCMetrics() *GRPCMetrics
|
GRPCMetrics() *GRPCMetrics
|
||||||
StoreMetrics() *StoreMetrics
|
StoreMetrics() *StoreMetrics
|
||||||
UpdateChannelMetrics() *UpdateChannelMetrics
|
UpdateChannelMetrics() *UpdateChannelMetrics
|
||||||
|
AccountManagerMetrics() *AccountManagerMetrics
|
||||||
}
|
}
|
||||||
|
|
||||||
// defaultAppMetrics are core application metrics based on OpenTelemetry https://opentelemetry.io/
|
// defaultAppMetrics are core application metrics based on OpenTelemetry https://opentelemetry.io/
|
||||||
type defaultAppMetrics struct {
|
type defaultAppMetrics struct {
|
||||||
// Meter can be used by different application parts to create counters and measure things
|
// Meter can be used by different application parts to create counters and measure things
|
||||||
Meter metric2.Meter
|
Meter metric2.Meter
|
||||||
listener net.Listener
|
listener net.Listener
|
||||||
ctx context.Context
|
ctx context.Context
|
||||||
idpMetrics *IDPMetrics
|
idpMetrics *IDPMetrics
|
||||||
httpMiddleware *HTTPMiddleware
|
httpMiddleware *HTTPMiddleware
|
||||||
grpcMetrics *GRPCMetrics
|
grpcMetrics *GRPCMetrics
|
||||||
storeMetrics *StoreMetrics
|
storeMetrics *StoreMetrics
|
||||||
updateChannelMetrics *UpdateChannelMetrics
|
updateChannelMetrics *UpdateChannelMetrics
|
||||||
|
accountManagerMetrics *AccountManagerMetrics
|
||||||
}
|
}
|
||||||
|
|
||||||
// IDPMetrics returns metrics for the idp package
|
// IDPMetrics returns metrics for the idp package
|
||||||
@ -144,6 +155,11 @@ func (appMetrics *defaultAppMetrics) UpdateChannelMetrics() *UpdateChannelMetric
|
|||||||
return appMetrics.updateChannelMetrics
|
return appMetrics.updateChannelMetrics
|
||||||
}
|
}
|
||||||
|
|
||||||
|
// AccountManagerMetrics returns metrics for the account manager
|
||||||
|
func (appMetrics *defaultAppMetrics) AccountManagerMetrics() *AccountManagerMetrics {
|
||||||
|
return appMetrics.accountManagerMetrics
|
||||||
|
}
|
||||||
|
|
||||||
// Close stop application metrics HTTP handler and closes listener.
|
// Close stop application metrics HTTP handler and closes listener.
|
||||||
func (appMetrics *defaultAppMetrics) Close() error {
|
func (appMetrics *defaultAppMetrics) Close() error {
|
||||||
if appMetrics.listener == nil {
|
if appMetrics.listener == nil {
|
||||||
@ -220,13 +236,19 @@ func NewDefaultAppMetrics(ctx context.Context) (AppMetrics, error) {
|
|||||||
return nil, err
|
return nil, err
|
||||||
}
|
}
|
||||||
|
|
||||||
|
accountManagerMetrics, err := NewAccountManagerMetrics(ctx, meter)
|
||||||
|
if err != nil {
|
||||||
|
return nil, err
|
||||||
|
}
|
||||||
|
|
||||||
return &defaultAppMetrics{
|
return &defaultAppMetrics{
|
||||||
Meter: meter,
|
Meter: meter,
|
||||||
ctx: ctx,
|
ctx: ctx,
|
||||||
idpMetrics: idpMetrics,
|
idpMetrics: idpMetrics,
|
||||||
httpMiddleware: middleware,
|
httpMiddleware: middleware,
|
||||||
grpcMetrics: grpcMetrics,
|
grpcMetrics: grpcMetrics,
|
||||||
storeMetrics: storeMetrics,
|
storeMetrics: storeMetrics,
|
||||||
updateChannelMetrics: updateChannelMetrics,
|
updateChannelMetrics: updateChannelMetrics,
|
||||||
|
accountManagerMetrics: accountManagerMetrics,
|
||||||
}, nil
|
}, nil
|
||||||
}
|
}
|
||||||
|
Loading…
x
Reference in New Issue
Block a user