Merge branch 'main' into userspace-router

This commit is contained in:
Viktor Liu 2025-01-02 16:25:04 +01:00
commit c3c6afa37b
46 changed files with 2891 additions and 330 deletions

View File

@ -19,7 +19,7 @@ jobs:
- name: codespell - name: codespell
uses: codespell-project/actions-codespell@v2 uses: codespell-project/actions-codespell@v2
with: with:
ignore_words_list: erro,clienta,hastable,iif,groupd ignore_words_list: erro,clienta,hastable,iif,groupd,testin
skip: go.mod,go.sum skip: go.mod,go.sum
only_warn: 1 only_warn: 1
golangci: golangci:

View File

@ -197,7 +197,7 @@ func (m *Manager) AllowNetbird() error {
} }
_, err := m.AddPeerFiltering( _, err := m.AddPeerFiltering(
net.ParseIP("0.0.0.0"), net.IP{0, 0, 0, 0},
"all", "all",
nil, nil,
nil, nil,

View File

@ -68,17 +68,16 @@ func (c *HandlerChain) AddHandler(pattern string, handler dns.Handler, priority
c.mu.Lock() c.mu.Lock()
defer c.mu.Unlock() defer c.mu.Unlock()
pattern = strings.ToLower(dns.Fqdn(pattern))
origPattern := pattern origPattern := pattern
isWildcard := strings.HasPrefix(pattern, "*.") isWildcard := strings.HasPrefix(pattern, "*.")
if isWildcard { if isWildcard {
pattern = pattern[2:] pattern = pattern[2:]
} }
pattern = dns.Fqdn(pattern)
origPattern = dns.Fqdn(origPattern)
// First remove any existing handler with same original pattern and priority // First remove any existing handler with same pattern (case-insensitive) and priority
for i := len(c.handlers) - 1; i >= 0; i-- { for i := len(c.handlers) - 1; i >= 0; i-- {
if c.handlers[i].OrigPattern == origPattern && c.handlers[i].Priority == priority { if strings.EqualFold(c.handlers[i].OrigPattern, origPattern) && c.handlers[i].Priority == priority {
if c.handlers[i].StopHandler != nil { if c.handlers[i].StopHandler != nil {
c.handlers[i].StopHandler.stop() c.handlers[i].StopHandler.stop()
} }
@ -126,10 +125,10 @@ func (c *HandlerChain) RemoveHandler(pattern string, priority int) {
pattern = dns.Fqdn(pattern) pattern = dns.Fqdn(pattern)
// Find and remove handlers matching both original pattern and priority // Find and remove handlers matching both original pattern (case-insensitive) and priority
for i := len(c.handlers) - 1; i >= 0; i-- { for i := len(c.handlers) - 1; i >= 0; i-- {
entry := c.handlers[i] entry := c.handlers[i]
if entry.OrigPattern == pattern && entry.Priority == priority { if strings.EqualFold(entry.OrigPattern, pattern) && entry.Priority == priority {
if entry.StopHandler != nil { if entry.StopHandler != nil {
entry.StopHandler.stop() entry.StopHandler.stop()
} }
@ -144,9 +143,9 @@ func (c *HandlerChain) HasHandlers(pattern string) bool {
c.mu.RLock() c.mu.RLock()
defer c.mu.RUnlock() defer c.mu.RUnlock()
pattern = dns.Fqdn(pattern) pattern = strings.ToLower(dns.Fqdn(pattern))
for _, entry := range c.handlers { for _, entry := range c.handlers {
if entry.Pattern == pattern { if strings.EqualFold(entry.Pattern, pattern) {
return true return true
} }
} }
@ -158,7 +157,7 @@ func (c *HandlerChain) ServeDNS(w dns.ResponseWriter, r *dns.Msg) {
return return
} }
qname := r.Question[0].Name qname := strings.ToLower(r.Question[0].Name)
log.Tracef("handling DNS request for domain=%s", qname) log.Tracef("handling DNS request for domain=%s", qname)
c.mu.RLock() c.mu.RLock()
@ -187,9 +186,9 @@ func (c *HandlerChain) ServeDNS(w dns.ResponseWriter, r *dns.Msg) {
// If handler wants subdomain matching, allow suffix match // If handler wants subdomain matching, allow suffix match
// Otherwise require exact match // Otherwise require exact match
if entry.MatchSubdomains { if entry.MatchSubdomains {
matched = qname == entry.Pattern || strings.HasSuffix(qname, "."+entry.Pattern) matched = strings.EqualFold(qname, entry.Pattern) || strings.HasSuffix(qname, "."+entry.Pattern)
} else { } else {
matched = qname == entry.Pattern matched = strings.EqualFold(qname, entry.Pattern)
} }
} }

View File

@ -507,5 +507,173 @@ func TestHandlerChain_MultiPriorityHandling(t *testing.T) {
// Test 4: Remove last handler // Test 4: Remove last handler
chain.RemoveHandler(testDomain, nbdns.PriorityDefault) chain.RemoveHandler(testDomain, nbdns.PriorityDefault)
assert.False(t, chain.HasHandlers(testDomain)) assert.False(t, chain.HasHandlers(testDomain))
} }
func TestHandlerChain_CaseSensitivity(t *testing.T) {
tests := []struct {
name string
scenario string
addHandlers []struct {
pattern string
priority int
subdomains bool
shouldMatch bool
}
query string
expectedCalls int
}{
{
name: "case insensitive exact match",
scenario: "handler registered lowercase, query uppercase",
addHandlers: []struct {
pattern string
priority int
subdomains bool
shouldMatch bool
}{
{"example.com.", nbdns.PriorityDefault, false, true},
},
query: "EXAMPLE.COM.",
expectedCalls: 1,
},
{
name: "case insensitive wildcard match",
scenario: "handler registered mixed case wildcard, query different case",
addHandlers: []struct {
pattern string
priority int
subdomains bool
shouldMatch bool
}{
{"*.Example.Com.", nbdns.PriorityDefault, false, true},
},
query: "sub.EXAMPLE.COM.",
expectedCalls: 1,
},
{
name: "multiple handlers different case same domain",
scenario: "second handler should replace first despite case difference",
addHandlers: []struct {
pattern string
priority int
subdomains bool
shouldMatch bool
}{
{"EXAMPLE.COM.", nbdns.PriorityDefault, false, false},
{"example.com.", nbdns.PriorityDefault, false, true},
},
query: "ExAmPlE.cOm.",
expectedCalls: 1,
},
{
name: "subdomain matching case insensitive",
scenario: "handler with MatchSubdomains true should match regardless of case",
addHandlers: []struct {
pattern string
priority int
subdomains bool
shouldMatch bool
}{
{"example.com.", nbdns.PriorityDefault, true, true},
},
query: "SUB.EXAMPLE.COM.",
expectedCalls: 1,
},
{
name: "root zone case insensitive",
scenario: "root zone handler should match regardless of case",
addHandlers: []struct {
pattern string
priority int
subdomains bool
shouldMatch bool
}{
{".", nbdns.PriorityDefault, false, true},
},
query: "EXAMPLE.COM.",
expectedCalls: 1,
},
{
name: "multiple handlers different priority",
scenario: "should call higher priority handler despite case differences",
addHandlers: []struct {
pattern string
priority int
subdomains bool
shouldMatch bool
}{
{"EXAMPLE.COM.", nbdns.PriorityDefault, false, false},
{"example.com.", nbdns.PriorityMatchDomain, false, false},
{"Example.Com.", nbdns.PriorityDNSRoute, false, true},
},
query: "example.com.",
expectedCalls: 1,
},
}
for _, tt := range tests {
t.Run(tt.name, func(t *testing.T) {
chain := nbdns.NewHandlerChain()
handlerCalls := make(map[string]bool) // track which patterns were called
// Add handlers according to test case
for _, h := range tt.addHandlers {
var handler dns.Handler
pattern := h.pattern // capture pattern for closure
if h.subdomains {
subHandler := &nbdns.MockSubdomainHandler{
Subdomains: true,
}
if h.shouldMatch {
subHandler.On("ServeDNS", mock.Anything, mock.Anything).Run(func(args mock.Arguments) {
handlerCalls[pattern] = true
w := args.Get(0).(dns.ResponseWriter)
r := args.Get(1).(*dns.Msg)
resp := new(dns.Msg)
resp.SetRcode(r, dns.RcodeSuccess)
assert.NoError(t, w.WriteMsg(resp))
}).Once()
}
handler = subHandler
} else {
mockHandler := &nbdns.MockHandler{}
if h.shouldMatch {
mockHandler.On("ServeDNS", mock.Anything, mock.Anything).Run(func(args mock.Arguments) {
handlerCalls[pattern] = true
w := args.Get(0).(dns.ResponseWriter)
r := args.Get(1).(*dns.Msg)
resp := new(dns.Msg)
resp.SetRcode(r, dns.RcodeSuccess)
assert.NoError(t, w.WriteMsg(resp))
}).Once()
}
handler = mockHandler
}
chain.AddHandler(pattern, handler, h.priority, nil)
}
// Execute request
r := new(dns.Msg)
r.SetQuestion(tt.query, dns.TypeA)
chain.ServeDNS(&mockResponseWriter{}, r)
// Verify each handler was called exactly as expected
for _, h := range tt.addHandlers {
wasCalled := handlerCalls[h.pattern]
assert.Equal(t, h.shouldMatch, wasCalled,
"Handler for pattern %q was %s when it should%s have been",
h.pattern,
map[bool]string{true: "called", false: "not called"}[wasCalled],
map[bool]string{true: "", false: " not"}[wasCalled == h.shouldMatch])
}
// Verify total number of calls
assert.Equal(t, tt.expectedCalls, len(handlerCalls),
"Wrong number of total handler calls")
})
}
}

View File

@ -83,7 +83,7 @@ func (h *Manager) allowDNSFirewall() error {
IsRange: false, IsRange: false,
Values: []int{ListenPort}, Values: []int{ListenPort},
} }
dnsRules, err := h.firewall.AddPeerFiltering(net.ParseIP("0.0.0.0"), firewall.ProtocolUDP, nil, dport, firewall.RuleDirectionIN, firewall.ActionAccept, "", "") dnsRules, err := h.firewall.AddPeerFiltering(net.IP{0, 0, 0, 0}, firewall.ProtocolUDP, nil, dport, firewall.RuleDirectionIN, firewall.ActionAccept, "", "")
if err != nil { if err != nil {
log.Errorf("failed to add allow DNS router rules, err: %v", err) log.Errorf("failed to add allow DNS router rules, err: %v", err)
return err return err

View File

@ -410,13 +410,9 @@ func (e *Engine) Start() error {
e.firewall, err = firewall.NewFirewall(e.wgInterface, e.stateManager) e.firewall, err = firewall.NewFirewall(e.wgInterface, e.stateManager)
if err != nil { if err != nil {
log.Errorf("failed creating firewall manager: %s", err) log.Errorf("failed creating firewall manager: %s", err)
} } else if e.firewall != nil {
if err := e.initFirewall(err); err != nil {
if e.firewall != nil && e.firewall.IsServerRouteSupported() { return err
err = e.routeManager.EnableServerRouter(e.firewall)
if err != nil {
e.close()
return fmt.Errorf("enable server router: %w", err)
} }
} }
@ -459,6 +455,41 @@ func (e *Engine) Start() error {
return nil return nil
} }
func (e *Engine) initFirewall(error) error {
if e.firewall.IsServerRouteSupported() {
if err := e.routeManager.EnableServerRouter(e.firewall); err != nil {
e.close()
return fmt.Errorf("enable server router: %w", err)
}
}
if e.rpManager == nil || !e.config.RosenpassEnabled {
return nil
}
rosenpassPort := e.rpManager.GetAddress().Port
port := manager.Port{Values: []int{rosenpassPort}}
// this rule is static and will be torn down on engine down by the firewall manager
if _, err := e.firewall.AddPeerFiltering(
net.IP{0, 0, 0, 0},
manager.ProtocolUDP,
nil,
&port,
manager.RuleDirectionIN,
manager.ActionAccept,
"",
"",
); err != nil {
log.Errorf("failed to allow rosenpass interface traffic: %v", err)
return nil
}
log.Infof("rosenpass interface traffic allowed on port %d", rosenpassPort)
return nil
}
// modifyPeers updates peers that have been modified (e.g. IP address has been changed). // modifyPeers updates peers that have been modified (e.g. IP address has been changed).
// It closes the existing connection, removes it from the peerConns map, and creates a new one. // It closes the existing connection, removes it from the peerConns map, and creates a new one.
func (e *Engine) modifyPeers(peersUpdate []*mgmProto.RemotePeerConfig) error { func (e *Engine) modifyPeers(peersUpdate []*mgmProto.RemotePeerConfig) error {

View File

@ -42,7 +42,7 @@ import (
nbContext "github.com/netbirdio/netbird/management/server/context" nbContext "github.com/netbirdio/netbird/management/server/context"
"github.com/netbirdio/netbird/management/server/geolocation" "github.com/netbirdio/netbird/management/server/geolocation"
"github.com/netbirdio/netbird/management/server/groups" "github.com/netbirdio/netbird/management/server/groups"
httpapi "github.com/netbirdio/netbird/management/server/http" nbhttp "github.com/netbirdio/netbird/management/server/http"
"github.com/netbirdio/netbird/management/server/http/configs" "github.com/netbirdio/netbird/management/server/http/configs"
"github.com/netbirdio/netbird/management/server/idp" "github.com/netbirdio/netbird/management/server/idp"
"github.com/netbirdio/netbird/management/server/jwtclaims" "github.com/netbirdio/netbird/management/server/jwtclaims"
@ -281,7 +281,7 @@ var (
routersManager := routers.NewManager(store, permissionsManager, accountManager) routersManager := routers.NewManager(store, permissionsManager, accountManager)
networksManager := networks.NewManager(store, permissionsManager, resourcesManager, routersManager, accountManager) networksManager := networks.NewManager(store, permissionsManager, resourcesManager, routersManager, accountManager)
httpAPIHandler, err := httpapi.APIHandler(ctx, accountManager, networksManager, resourcesManager, routersManager, groupsManager, geo, *jwtValidator, appMetrics, httpAPIAuthCfg, integratedPeerValidator) httpAPIHandler, err := nbhttp.NewAPIHandler(ctx, accountManager, networksManager, resourcesManager, routersManager, groupsManager, geo, jwtValidator, appMetrics, httpAPIAuthCfg, integratedPeerValidator)
if err != nil { if err != nil {
return fmt.Errorf("failed creating HTTP API handler: %v", err) return fmt.Errorf("failed creating HTTP API handler: %v", err)
} }

View File

@ -161,7 +161,7 @@ type DefaultAccountManager struct {
externalCacheManager ExternalCacheManager externalCacheManager ExternalCacheManager
ctx context.Context ctx context.Context
eventStore activity.Store eventStore activity.Store
geo *geolocation.Geolocation geo geolocation.Geolocation
requestBuffer *AccountRequestBuffer requestBuffer *AccountRequestBuffer
@ -244,7 +244,7 @@ func BuildManager(
singleAccountModeDomain string, singleAccountModeDomain string,
dnsDomain string, dnsDomain string,
eventStore activity.Store, eventStore activity.Store,
geo *geolocation.Geolocation, geo geolocation.Geolocation,
userDeleteFromIDPEnabled bool, userDeleteFromIDPEnabled bool,
integratedPeerValidator integrated_validator.IntegratedValidator, integratedPeerValidator integrated_validator.IntegratedValidator,
metrics telemetry.AppMetrics, metrics telemetry.AppMetrics,
@ -1252,6 +1252,12 @@ func (am *DefaultAccountManager) GetAccountIDFromToken(ctx context.Context, clai
// syncJWTGroups processes the JWT groups for a user, updates the account based on the groups, // syncJWTGroups processes the JWT groups for a user, updates the account based on the groups,
// and propagates changes to peers if group propagation is enabled. // and propagates changes to peers if group propagation is enabled.
func (am *DefaultAccountManager) syncJWTGroups(ctx context.Context, accountID string, claims jwtclaims.AuthorizationClaims) error { func (am *DefaultAccountManager) syncJWTGroups(ctx context.Context, accountID string, claims jwtclaims.AuthorizationClaims) error {
if claim, exists := claims.Raw[jwtclaims.IsToken]; exists {
if isToken, ok := claim.(bool); ok && isToken {
return nil
}
}
settings, err := am.Store.GetAccountSettings(ctx, store.LockingStrengthShare, accountID) settings, err := am.Store.GetAccountSettings(ctx, store.LockingStrengthShare, accountID)
if err != nil { if err != nil {
return err return err

View File

@ -27,7 +27,6 @@ import (
"golang.zx2c4.com/wireguard/wgctrl/wgtypes" "golang.zx2c4.com/wireguard/wgctrl/wgtypes"
nbdns "github.com/netbirdio/netbird/dns" nbdns "github.com/netbirdio/netbird/dns"
"github.com/netbirdio/netbird/management/server/account"
"github.com/netbirdio/netbird/management/server/activity" "github.com/netbirdio/netbird/management/server/activity"
"github.com/netbirdio/netbird/management/server/jwtclaims" "github.com/netbirdio/netbird/management/server/jwtclaims"
nbpeer "github.com/netbirdio/netbird/management/server/peer" nbpeer "github.com/netbirdio/netbird/management/server/peer"
@ -38,47 +37,6 @@ import (
"github.com/netbirdio/netbird/route" "github.com/netbirdio/netbird/route"
) )
type MocIntegratedValidator struct {
ValidatePeerFunc func(_ context.Context, update *nbpeer.Peer, peer *nbpeer.Peer, userID string, accountID string, dnsDomain string, peersGroup []string, extraSettings *account.ExtraSettings) (*nbpeer.Peer, bool, error)
}
func (a MocIntegratedValidator) ValidateExtraSettings(_ context.Context, newExtraSettings *account.ExtraSettings, oldExtraSettings *account.ExtraSettings, peers map[string]*nbpeer.Peer, userID string, accountID string) error {
return nil
}
func (a MocIntegratedValidator) ValidatePeer(_ context.Context, update *nbpeer.Peer, peer *nbpeer.Peer, userID string, accountID string, dnsDomain string, peersGroup []string, extraSettings *account.ExtraSettings) (*nbpeer.Peer, bool, error) {
if a.ValidatePeerFunc != nil {
return a.ValidatePeerFunc(context.Background(), update, peer, userID, accountID, dnsDomain, peersGroup, extraSettings)
}
return update, false, nil
}
func (a MocIntegratedValidator) GetValidatedPeers(accountID string, groups map[string]*types.Group, peers map[string]*nbpeer.Peer, extraSettings *account.ExtraSettings) (map[string]struct{}, error) {
validatedPeers := make(map[string]struct{})
for _, peer := range peers {
validatedPeers[peer.ID] = struct{}{}
}
return validatedPeers, nil
}
func (MocIntegratedValidator) PreparePeer(_ context.Context, accountID string, peer *nbpeer.Peer, peersGroup []string, extraSettings *account.ExtraSettings) *nbpeer.Peer {
return peer
}
func (MocIntegratedValidator) IsNotValidPeer(_ context.Context, accountID string, peer *nbpeer.Peer, peersGroup []string, extraSettings *account.ExtraSettings) (bool, bool, error) {
return false, false, nil
}
func (MocIntegratedValidator) PeerDeleted(_ context.Context, _, _ string) error {
return nil
}
func (MocIntegratedValidator) SetPeerInvalidationListener(func(accountID string)) {
}
func (MocIntegratedValidator) Stop(_ context.Context) {
}
func verifyCanAddPeerToAccount(t *testing.T, manager AccountManager, account *types.Account, userID string) { func verifyCanAddPeerToAccount(t *testing.T, manager AccountManager, account *types.Account, userID string) {
t.Helper() t.Helper()
peer := &nbpeer.Peer{ peer := &nbpeer.Peer{
@ -2729,6 +2687,19 @@ func TestAccount_SetJWTGroups(t *testing.T) {
assert.NoError(t, manager.Store.SaveAccount(context.Background(), account), "unable to save account") assert.NoError(t, manager.Store.SaveAccount(context.Background(), account), "unable to save account")
t.Run("skip sync for token auth type", func(t *testing.T) {
claims := jwtclaims.AuthorizationClaims{
UserId: "user1",
Raw: jwt.MapClaims{"groups": []interface{}{"group3"}, "is_token": true},
}
err = manager.syncJWTGroups(context.Background(), "accountID", claims)
assert.NoError(t, err, "unable to sync jwt groups")
user, err := manager.Store.GetUserByUserID(context.Background(), store.LockingStrengthShare, "user1")
assert.NoError(t, err, "unable to get user")
assert.Len(t, user.AutoGroups, 0, "JWT groups should not be synced")
})
t.Run("empty jwt groups", func(t *testing.T) { t.Run("empty jwt groups", func(t *testing.T) {
claims := jwtclaims.AuthorizationClaims{ claims := jwtclaims.AuthorizationClaims{
UserId: "user1", UserId: "user1",
@ -2822,7 +2793,7 @@ func TestAccount_SetJWTGroups(t *testing.T) {
assert.Len(t, user.AutoGroups, 1, "new group should be added") assert.Len(t, user.AutoGroups, 1, "new group should be added")
}) })
t.Run("remove all JWT groups", func(t *testing.T) { t.Run("remove all JWT groups when list is empty", func(t *testing.T) {
claims := jwtclaims.AuthorizationClaims{ claims := jwtclaims.AuthorizationClaims{
UserId: "user1", UserId: "user1",
Raw: jwt.MapClaims{"groups": []interface{}{}}, Raw: jwt.MapClaims{"groups": []interface{}{}},
@ -2833,7 +2804,20 @@ func TestAccount_SetJWTGroups(t *testing.T) {
user, err := manager.Store.GetUserByUserID(context.Background(), store.LockingStrengthShare, "user1") user, err := manager.Store.GetUserByUserID(context.Background(), store.LockingStrengthShare, "user1")
assert.NoError(t, err, "unable to get user") assert.NoError(t, err, "unable to get user")
assert.Len(t, user.AutoGroups, 1, "only non-JWT groups should remain") assert.Len(t, user.AutoGroups, 1, "only non-JWT groups should remain")
assert.Contains(t, user.AutoGroups, "group1", " group1 should still be present") assert.Contains(t, user.AutoGroups, "group1", "group1 should still be present")
})
t.Run("remove all JWT groups when claim does not exist", func(t *testing.T) {
claims := jwtclaims.AuthorizationClaims{
UserId: "user2",
Raw: jwt.MapClaims{},
}
err = manager.syncJWTGroups(context.Background(), "accountID", claims)
assert.NoError(t, err, "unable to sync jwt groups")
user, err := manager.Store.GetUserByUserID(context.Background(), store.LockingStrengthShare, "user2")
assert.NoError(t, err, "unable to get user")
assert.Len(t, user.AutoGroups, 0, "all JWT groups should be removed")
}) })
} }
@ -3037,9 +3021,9 @@ func BenchmarkSyncAndMarkPeer(b *testing.B) {
minMsPerOpCICD float64 minMsPerOpCICD float64
maxMsPerOpCICD float64 maxMsPerOpCICD float64
}{ }{
{"Small", 50, 5, 1, 3, 3, 10}, {"Small", 50, 5, 1, 3, 3, 11},
{"Medium", 500, 100, 7, 13, 10, 70}, {"Medium", 500, 100, 7, 13, 10, 70},
{"Large", 5000, 200, 65, 80, 60, 200}, {"Large", 5000, 200, 65, 80, 60, 220},
{"Small single", 50, 10, 1, 3, 3, 70}, {"Small single", 50, 10, 1, 3, 3, 70},
{"Medium single", 500, 10, 7, 13, 10, 26}, {"Medium single", 500, 10, 7, 13, 10, 26},
{"Large 5", 5000, 15, 65, 80, 60, 200}, {"Large 5", 5000, 15, 65, 80, 60, 200},
@ -3179,7 +3163,7 @@ func BenchmarkLoginPeer_NewPeer(b *testing.B) {
maxMsPerOpCICD float64 maxMsPerOpCICD float64
}{ }{
{"Small", 50, 5, 107, 120, 107, 160}, {"Small", 50, 5, 107, 120, 107, 160},
{"Medium", 500, 100, 105, 140, 105, 190}, {"Medium", 500, 100, 105, 140, 105, 220},
{"Large", 5000, 200, 180, 220, 180, 350}, {"Large", 5000, 200, 180, 220, 180, 350},
{"Small single", 50, 10, 107, 120, 105, 160}, {"Small single", 50, 10, 107, 120, 105, 160},
{"Medium single", 500, 10, 105, 140, 105, 170}, {"Medium single", 500, 10, 105, 140, 105, 170},

View File

@ -14,7 +14,14 @@ import (
log "github.com/sirupsen/logrus" log "github.com/sirupsen/logrus"
) )
type Geolocation struct { type Geolocation interface {
Lookup(ip net.IP) (*Record, error)
GetAllCountries() ([]Country, error)
GetCitiesByCountry(countryISOCode string) ([]City, error)
Stop() error
}
type geolocationImpl struct {
mmdbPath string mmdbPath string
mux sync.RWMutex mux sync.RWMutex
db *maxminddb.Reader db *maxminddb.Reader
@ -54,7 +61,7 @@ const (
geonamesdbPattern = "geonames_*.db" geonamesdbPattern = "geonames_*.db"
) )
func NewGeolocation(ctx context.Context, dataDir string, autoUpdate bool) (*Geolocation, error) { func NewGeolocation(ctx context.Context, dataDir string, autoUpdate bool) (Geolocation, error) {
mmdbGlobPattern := filepath.Join(dataDir, mmdbPattern) mmdbGlobPattern := filepath.Join(dataDir, mmdbPattern)
mmdbFile, err := getDatabaseFilename(ctx, geoLiteCityTarGZURL, mmdbGlobPattern, autoUpdate) mmdbFile, err := getDatabaseFilename(ctx, geoLiteCityTarGZURL, mmdbGlobPattern, autoUpdate)
if err != nil { if err != nil {
@ -86,7 +93,7 @@ func NewGeolocation(ctx context.Context, dataDir string, autoUpdate bool) (*Geol
return nil, err return nil, err
} }
geo := &Geolocation{ geo := &geolocationImpl{
mmdbPath: mmdbPath, mmdbPath: mmdbPath,
mux: sync.RWMutex{}, mux: sync.RWMutex{},
db: db, db: db,
@ -113,7 +120,7 @@ func openDB(mmdbPath string) (*maxminddb.Reader, error) {
return db, nil return db, nil
} }
func (gl *Geolocation) Lookup(ip net.IP) (*Record, error) { func (gl *geolocationImpl) Lookup(ip net.IP) (*Record, error) {
gl.mux.RLock() gl.mux.RLock()
defer gl.mux.RUnlock() defer gl.mux.RUnlock()
@ -127,7 +134,7 @@ func (gl *Geolocation) Lookup(ip net.IP) (*Record, error) {
} }
// GetAllCountries retrieves a list of all countries. // GetAllCountries retrieves a list of all countries.
func (gl *Geolocation) GetAllCountries() ([]Country, error) { func (gl *geolocationImpl) GetAllCountries() ([]Country, error) {
allCountries, err := gl.locationDB.GetAllCountries() allCountries, err := gl.locationDB.GetAllCountries()
if err != nil { if err != nil {
return nil, err return nil, err
@ -143,7 +150,7 @@ func (gl *Geolocation) GetAllCountries() ([]Country, error) {
} }
// GetCitiesByCountry retrieves a list of cities in a specific country based on the country's ISO code. // GetCitiesByCountry retrieves a list of cities in a specific country based on the country's ISO code.
func (gl *Geolocation) GetCitiesByCountry(countryISOCode string) ([]City, error) { func (gl *geolocationImpl) GetCitiesByCountry(countryISOCode string) ([]City, error) {
allCities, err := gl.locationDB.GetCitiesByCountry(countryISOCode) allCities, err := gl.locationDB.GetCitiesByCountry(countryISOCode)
if err != nil { if err != nil {
return nil, err return nil, err
@ -158,7 +165,7 @@ func (gl *Geolocation) GetCitiesByCountry(countryISOCode string) ([]City, error)
return cities, nil return cities, nil
} }
func (gl *Geolocation) Stop() error { func (gl *geolocationImpl) Stop() error {
close(gl.stopCh) close(gl.stopCh)
if gl.db != nil { if gl.db != nil {
if err := gl.db.Close(); err != nil { if err := gl.db.Close(); err != nil {
@ -259,3 +266,21 @@ func cleanupMaxMindDatabases(ctx context.Context, dataDir string, mmdbFile strin
} }
return nil return nil
} }
type Mock struct{}
func (g *Mock) Lookup(ip net.IP) (*Record, error) {
return &Record{}, nil
}
func (g *Mock) GetAllCountries() ([]Country, error) {
return []Country{}, nil
}
func (g *Mock) GetCitiesByCountry(countryISOCode string) ([]City, error) {
return []City{}, nil
}
func (g *Mock) Stop() error {
return nil
}

View File

@ -24,7 +24,7 @@ func TestGeoLite_Lookup(t *testing.T) {
db, err := openDB(filename) db, err := openDB(filename)
assert.NoError(t, err) assert.NoError(t, err)
geo := &Geolocation{ geo := &geolocationImpl{
mux: sync.RWMutex{}, mux: sync.RWMutex{},
db: db, db: db,
stopCh: make(chan struct{}), stopCh: make(chan struct{}),

View File

@ -474,6 +474,10 @@ func validateDeleteGroup(ctx context.Context, transaction store.Store, group *ty
return status.Errorf(status.InvalidArgument, "deleting group ALL is not allowed") return status.Errorf(status.InvalidArgument, "deleting group ALL is not allowed")
} }
if len(group.Resources) > 0 {
return &GroupLinkError{"network resource", group.Resources[0].ID}
}
if isLinked, linkedRoute := isGroupLinkedToRoute(ctx, transaction, group.AccountID, group.ID); isLinked { if isLinked, linkedRoute := isGroupLinkedToRoute(ctx, transaction, group.AccountID, group.ID); isLinked {
return &GroupLinkError{"route", string(linkedRoute.NetID)} return &GroupLinkError{"route", string(linkedRoute.NetID)}
} }
@ -529,7 +533,10 @@ func isGroupLinkedToRoute(ctx context.Context, transaction store.Store, accountI
} }
for _, r := range routes { for _, r := range routes {
if slices.Contains(r.Groups, groupID) || slices.Contains(r.PeerGroups, groupID) { isLinked := slices.Contains(r.Groups, groupID) ||
slices.Contains(r.PeerGroups, groupID) ||
slices.Contains(r.AccessControlGroups, groupID)
if isLinked {
return true, r return true, r
} }
} }

View File

@ -38,7 +38,7 @@ type GRPCServer struct {
peersUpdateManager *PeersUpdateManager peersUpdateManager *PeersUpdateManager
config *Config config *Config
secretsManager SecretsManager secretsManager SecretsManager
jwtValidator *jwtclaims.JWTValidator jwtValidator jwtclaims.JWTValidator
jwtClaimsExtractor *jwtclaims.ClaimsExtractor jwtClaimsExtractor *jwtclaims.ClaimsExtractor
appMetrics telemetry.AppMetrics appMetrics telemetry.AppMetrics
ephemeralManager *EphemeralManager ephemeralManager *EphemeralManager
@ -61,7 +61,7 @@ func NewServer(
return nil, err return nil, err
} }
var jwtValidator *jwtclaims.JWTValidator var jwtValidator jwtclaims.JWTValidator
if config.HttpConfig != nil && config.HttpConfig.AuthIssuer != "" && config.HttpConfig.AuthAudience != "" && validateURL(config.HttpConfig.AuthKeysLocation) { if config.HttpConfig != nil && config.HttpConfig.AuthIssuer != "" && config.HttpConfig.AuthAudience != "" && validateURL(config.HttpConfig.AuthKeysLocation) {
jwtValidator, err = jwtclaims.NewJWTValidator( jwtValidator, err = jwtclaims.NewJWTValidator(

View File

@ -725,10 +725,6 @@ components:
PolicyRuleMinimum: PolicyRuleMinimum:
type: object type: object
properties: properties:
id:
description: Policy rule ID
type: string
example: ch8i4ug6lnn4g9hqv7mg
name: name:
description: Policy rule name identifier description: Policy rule name identifier
type: string type: string
@ -790,6 +786,31 @@ components:
- end - end
PolicyRuleUpdate: PolicyRuleUpdate:
allOf:
- $ref: '#/components/schemas/PolicyRuleMinimum'
- type: object
properties:
id:
description: Policy rule ID
type: string
example: ch8i4ug6lnn4g9hqv7mg
sources:
description: Policy rule source group IDs
type: array
items:
type: string
example: "ch8i4ug6lnn4g9hqv797"
destinations:
description: Policy rule destination group IDs
type: array
items:
type: string
example: "ch8i4ug6lnn4g9h7v7m0"
required:
- sources
- destinations
PolicyRuleCreate:
allOf: allOf:
- $ref: '#/components/schemas/PolicyRuleMinimum' - $ref: '#/components/schemas/PolicyRuleMinimum'
- type: object - type: object
@ -817,6 +838,10 @@ components:
- $ref: '#/components/schemas/PolicyRuleMinimum' - $ref: '#/components/schemas/PolicyRuleMinimum'
- type: object - type: object
properties: properties:
id:
description: Policy rule ID
type: string
example: ch8i4ug6lnn4g9hqv7mg
sources: sources:
description: Policy rule source group IDs description: Policy rule source group IDs
type: array type: array
@ -836,10 +861,6 @@ components:
PolicyMinimum: PolicyMinimum:
type: object type: object
properties: properties:
id:
description: Policy ID
type: string
example: ch8i4ug6lnn4g9hqv7mg
name: name:
description: Policy name identifier description: Policy name identifier
type: string type: string
@ -854,7 +875,6 @@ components:
example: true example: true
required: required:
- name - name
- description
- enabled - enabled
PolicyUpdate: PolicyUpdate:
allOf: allOf:
@ -874,11 +894,33 @@ components:
$ref: '#/components/schemas/PolicyRuleUpdate' $ref: '#/components/schemas/PolicyRuleUpdate'
required: required:
- rules - rules
PolicyCreate:
allOf:
- $ref: '#/components/schemas/PolicyMinimum'
- type: object
properties:
source_posture_checks:
description: Posture checks ID's applied to policy source groups
type: array
items:
type: string
example: "chacdk86lnnboviihd70"
rules:
description: Policy rule object for policy UI editor
type: array
items:
$ref: '#/components/schemas/PolicyRuleUpdate'
required:
- rules
Policy: Policy:
allOf: allOf:
- $ref: '#/components/schemas/PolicyMinimum' - $ref: '#/components/schemas/PolicyMinimum'
- type: object - type: object
properties: properties:
id:
description: Policy ID
type: string
example: ch8i4ug6lnn4g9hqv7mg
source_posture_checks: source_posture_checks:
description: Posture checks ID's applied to policy source groups description: Posture checks ID's applied to policy source groups
type: array type: array
@ -2463,7 +2505,7 @@ paths:
content: content:
'application/json': 'application/json':
schema: schema:
$ref: '#/components/schemas/PolicyUpdate' $ref: '#/components/schemas/PolicyCreate'
responses: responses:
'200': '200':
description: A Policy object description: A Policy object

View File

@ -879,7 +879,7 @@ type PersonalAccessTokenRequest struct {
// Policy defines model for Policy. // Policy defines model for Policy.
type Policy struct { type Policy struct {
// Description Policy friendly description // Description Policy friendly description
Description string `json:"description"` Description *string `json:"description,omitempty"`
// Enabled Policy status // Enabled Policy status
Enabled bool `json:"enabled"` Enabled bool `json:"enabled"`
@ -897,16 +897,31 @@ type Policy struct {
SourcePostureChecks []string `json:"source_posture_checks"` SourcePostureChecks []string `json:"source_posture_checks"`
} }
// PolicyMinimum defines model for PolicyMinimum. // PolicyCreate defines model for PolicyCreate.
type PolicyMinimum struct { type PolicyCreate struct {
// Description Policy friendly description // Description Policy friendly description
Description string `json:"description"` Description *string `json:"description,omitempty"`
// Enabled Policy status // Enabled Policy status
Enabled bool `json:"enabled"` Enabled bool `json:"enabled"`
// Id Policy ID // Name Policy name identifier
Id *string `json:"id,omitempty"` Name string `json:"name"`
// Rules Policy rule object for policy UI editor
Rules []PolicyRuleUpdate `json:"rules"`
// SourcePostureChecks Posture checks ID's applied to policy source groups
SourcePostureChecks *[]string `json:"source_posture_checks,omitempty"`
}
// PolicyMinimum defines model for PolicyMinimum.
type PolicyMinimum struct {
// Description Policy friendly description
Description *string `json:"description,omitempty"`
// Enabled Policy status
Enabled bool `json:"enabled"`
// Name Policy name identifier // Name Policy name identifier
Name string `json:"name"` Name string `json:"name"`
@ -970,9 +985,6 @@ type PolicyRuleMinimum struct {
// Enabled Policy rule status // Enabled Policy rule status
Enabled bool `json:"enabled"` Enabled bool `json:"enabled"`
// Id Policy rule ID
Id *string `json:"id,omitempty"`
// Name Policy rule name identifier // Name Policy rule name identifier
Name string `json:"name"` Name string `json:"name"`
@ -1039,14 +1051,11 @@ type PolicyRuleUpdateProtocol string
// PolicyUpdate defines model for PolicyUpdate. // PolicyUpdate defines model for PolicyUpdate.
type PolicyUpdate struct { type PolicyUpdate struct {
// Description Policy friendly description // Description Policy friendly description
Description string `json:"description"` Description *string `json:"description,omitempty"`
// Enabled Policy status // Enabled Policy status
Enabled bool `json:"enabled"` Enabled bool `json:"enabled"`
// Id Policy ID
Id *string `json:"id,omitempty"`
// Name Policy name identifier // Name Policy name identifier
Name string `json:"name"` Name string `json:"name"`
@ -1473,7 +1482,7 @@ type PutApiPeersPeerIdJSONRequestBody = PeerRequest
type PostApiPoliciesJSONRequestBody = PolicyUpdate type PostApiPoliciesJSONRequestBody = PolicyUpdate
// PutApiPoliciesPolicyIdJSONRequestBody defines body for PutApiPoliciesPolicyId for application/json ContentType. // PutApiPoliciesPolicyIdJSONRequestBody defines body for PutApiPoliciesPolicyId for application/json ContentType.
type PutApiPoliciesPolicyIdJSONRequestBody = PolicyUpdate type PutApiPoliciesPolicyIdJSONRequestBody = PolicyCreate
// PostApiPostureChecksJSONRequestBody defines body for PostApiPostureChecks for application/json ContentType. // PostApiPostureChecksJSONRequestBody defines body for PostApiPostureChecks for application/json ContentType.
type PostApiPostureChecksJSONRequestBody = PostureCheckUpdate type PostApiPostureChecksJSONRequestBody = PostureCheckUpdate

View File

@ -35,15 +35,8 @@ import (
const apiPrefix = "/api" const apiPrefix = "/api"
type apiHandler struct { // NewAPIHandler creates the Management service HTTP API handler registering all the available endpoints.
Router *mux.Router func NewAPIHandler(ctx context.Context, accountManager s.AccountManager, networksManager nbnetworks.Manager, resourceManager resources.Manager, routerManager routers.Manager, groupsManager nbgroups.Manager, LocationManager geolocation.Geolocation, jwtValidator jwtclaims.JWTValidator, appMetrics telemetry.AppMetrics, authCfg configs.AuthCfg, integratedValidator integrated_validator.IntegratedValidator) (http.Handler, error) {
AccountManager s.AccountManager
geolocationManager *geolocation.Geolocation
AuthCfg configs.AuthCfg
}
// APIHandler creates the Management service HTTP API handler registering all the available endpoints.
func APIHandler(ctx context.Context, accountManager s.AccountManager, networksManager nbnetworks.Manager, resourceManager resources.Manager, routerManager routers.Manager, groupsManager nbgroups.Manager, LocationManager *geolocation.Geolocation, jwtValidator jwtclaims.JWTValidator, appMetrics telemetry.AppMetrics, authCfg configs.AuthCfg, integratedValidator integrated_validator.IntegratedValidator) (http.Handler, error) {
claimsExtractor := jwtclaims.NewClaimsExtractor( claimsExtractor := jwtclaims.NewClaimsExtractor(
jwtclaims.WithAudience(authCfg.Audience), jwtclaims.WithAudience(authCfg.Audience),
jwtclaims.WithUserIDClaim(authCfg.UserIDClaim), jwtclaims.WithUserIDClaim(authCfg.UserIDClaim),
@ -78,27 +71,20 @@ func APIHandler(ctx context.Context, accountManager s.AccountManager, networksMa
router := rootRouter.PathPrefix(prefix).Subrouter() router := rootRouter.PathPrefix(prefix).Subrouter()
router.Use(metricsMiddleware.Handler, corsMiddleware.Handler, authMiddleware.Handler, acMiddleware.Handler) router.Use(metricsMiddleware.Handler, corsMiddleware.Handler, authMiddleware.Handler, acMiddleware.Handler)
api := apiHandler{ if _, err := integrations.RegisterHandlers(ctx, prefix, router, accountManager, claimsExtractor, integratedValidator, appMetrics.GetMeter()); err != nil {
Router: router,
AccountManager: accountManager,
geolocationManager: LocationManager,
AuthCfg: authCfg,
}
if _, err := integrations.RegisterHandlers(ctx, prefix, api.Router, accountManager, claimsExtractor, integratedValidator, appMetrics.GetMeter()); err != nil {
return nil, fmt.Errorf("register integrations endpoints: %w", err) return nil, fmt.Errorf("register integrations endpoints: %w", err)
} }
accounts.AddEndpoints(api.AccountManager, authCfg, router) accounts.AddEndpoints(accountManager, authCfg, router)
peers.AddEndpoints(api.AccountManager, authCfg, router) peers.AddEndpoints(accountManager, authCfg, router)
users.AddEndpoints(api.AccountManager, authCfg, router) users.AddEndpoints(accountManager, authCfg, router)
setup_keys.AddEndpoints(api.AccountManager, authCfg, router) setup_keys.AddEndpoints(accountManager, authCfg, router)
policies.AddEndpoints(api.AccountManager, api.geolocationManager, authCfg, router) policies.AddEndpoints(accountManager, LocationManager, authCfg, router)
groups.AddEndpoints(api.AccountManager, authCfg, router) groups.AddEndpoints(accountManager, authCfg, router)
routes.AddEndpoints(api.AccountManager, authCfg, router) routes.AddEndpoints(accountManager, authCfg, router)
dns.AddEndpoints(api.AccountManager, authCfg, router) dns.AddEndpoints(accountManager, authCfg, router)
events.AddEndpoints(api.AccountManager, authCfg, router) events.AddEndpoints(accountManager, authCfg, router)
networks.AddEndpoints(networksManager, resourceManager, routerManager, groupsManager, api.AccountManager, api.AccountManager.GetAccountIDFromToken, authCfg, router) networks.AddEndpoints(networksManager, resourceManager, routerManager, groupsManager, accountManager, accountManager.GetAccountIDFromToken, authCfg, router)
return rootRouter, nil return rootRouter, nil
} }

View File

@ -22,18 +22,18 @@ var (
// geolocationsHandler is a handler that returns locations. // geolocationsHandler is a handler that returns locations.
type geolocationsHandler struct { type geolocationsHandler struct {
accountManager server.AccountManager accountManager server.AccountManager
geolocationManager *geolocation.Geolocation geolocationManager geolocation.Geolocation
claimsExtractor *jwtclaims.ClaimsExtractor claimsExtractor *jwtclaims.ClaimsExtractor
} }
func addLocationsEndpoint(accountManager server.AccountManager, locationManager *geolocation.Geolocation, authCfg configs.AuthCfg, router *mux.Router) { func addLocationsEndpoint(accountManager server.AccountManager, locationManager geolocation.Geolocation, authCfg configs.AuthCfg, router *mux.Router) {
locationHandler := newGeolocationsHandlerHandler(accountManager, locationManager, authCfg) locationHandler := newGeolocationsHandlerHandler(accountManager, locationManager, authCfg)
router.HandleFunc("/locations/countries", locationHandler.getAllCountries).Methods("GET", "OPTIONS") router.HandleFunc("/locations/countries", locationHandler.getAllCountries).Methods("GET", "OPTIONS")
router.HandleFunc("/locations/countries/{country}/cities", locationHandler.getCitiesByCountry).Methods("GET", "OPTIONS") router.HandleFunc("/locations/countries/{country}/cities", locationHandler.getCitiesByCountry).Methods("GET", "OPTIONS")
} }
// newGeolocationsHandlerHandler creates a new Geolocations handler // newGeolocationsHandlerHandler creates a new Geolocations handler
func newGeolocationsHandlerHandler(accountManager server.AccountManager, geolocationManager *geolocation.Geolocation, authCfg configs.AuthCfg) *geolocationsHandler { func newGeolocationsHandlerHandler(accountManager server.AccountManager, geolocationManager geolocation.Geolocation, authCfg configs.AuthCfg) *geolocationsHandler {
return &geolocationsHandler{ return &geolocationsHandler{
accountManager: accountManager, accountManager: accountManager,
geolocationManager: geolocationManager, geolocationManager: geolocationManager,

View File

@ -23,7 +23,7 @@ type handler struct {
claimsExtractor *jwtclaims.ClaimsExtractor claimsExtractor *jwtclaims.ClaimsExtractor
} }
func AddEndpoints(accountManager server.AccountManager, locationManager *geolocation.Geolocation, authCfg configs.AuthCfg, router *mux.Router) { func AddEndpoints(accountManager server.AccountManager, locationManager geolocation.Geolocation, authCfg configs.AuthCfg, router *mux.Router) {
policiesHandler := newHandler(accountManager, authCfg) policiesHandler := newHandler(accountManager, authCfg)
router.HandleFunc("/policies", policiesHandler.getAllPolicies).Methods("GET", "OPTIONS") router.HandleFunc("/policies", policiesHandler.getAllPolicies).Methods("GET", "OPTIONS")
router.HandleFunc("/policies", policiesHandler.createPolicy).Methods("POST", "OPTIONS") router.HandleFunc("/policies", policiesHandler.createPolicy).Methods("POST", "OPTIONS")
@ -133,16 +133,21 @@ func (h *handler) savePolicy(w http.ResponseWriter, r *http.Request, accountID s
return return
} }
description := ""
if req.Description != nil {
description = *req.Description
}
policy := &types.Policy{ policy := &types.Policy{
ID: policyID, ID: policyID,
AccountID: accountID, AccountID: accountID,
Name: req.Name, Name: req.Name,
Enabled: req.Enabled, Enabled: req.Enabled,
Description: req.Description, Description: description,
} }
for _, rule := range req.Rules { for _, rule := range req.Rules {
var ruleID string var ruleID string
if rule.Id != nil { if rule.Id != nil && policyID != "" {
ruleID = *rule.Id ruleID = *rule.Id
} }
@ -370,7 +375,7 @@ func toPolicyResponse(groups []*types.Group, policy *types.Policy) *api.Policy {
ap := &api.Policy{ ap := &api.Policy{
Id: &policy.ID, Id: &policy.ID,
Name: policy.Name, Name: policy.Name,
Description: policy.Description, Description: &policy.Description,
Enabled: policy.Enabled, Enabled: policy.Enabled,
SourcePostureChecks: policy.SourcePostureChecks, SourcePostureChecks: policy.SourcePostureChecks,
} }

View File

@ -154,6 +154,7 @@ func TestPoliciesGetPolicy(t *testing.T) {
func TestPoliciesWritePolicy(t *testing.T) { func TestPoliciesWritePolicy(t *testing.T) {
str := func(s string) *string { return &s } str := func(s string) *string { return &s }
emptyString := ""
tt := []struct { tt := []struct {
name string name string
expectedStatus int expectedStatus int
@ -186,6 +187,7 @@ func TestPoliciesWritePolicy(t *testing.T) {
expectedPolicy: &api.Policy{ expectedPolicy: &api.Policy{
Id: str("id-was-set"), Id: str("id-was-set"),
Name: "Default POSTed Policy", Name: "Default POSTed Policy",
Description: &emptyString,
Rules: []api.PolicyRule{ Rules: []api.PolicyRule{
{ {
Id: str("id-was-set"), Id: str("id-was-set"),
@ -234,6 +236,7 @@ func TestPoliciesWritePolicy(t *testing.T) {
expectedPolicy: &api.Policy{ expectedPolicy: &api.Policy{
Id: str("id-existed"), Id: str("id-existed"),
Name: "Default POSTed Policy", Name: "Default POSTed Policy",
Description: &emptyString,
Rules: []api.PolicyRule{ Rules: []api.PolicyRule{
{ {
Id: str("id-existed"), Id: str("id-existed"),

View File

@ -19,11 +19,11 @@ import (
// postureChecksHandler is a handler that returns posture checks of the account. // postureChecksHandler is a handler that returns posture checks of the account.
type postureChecksHandler struct { type postureChecksHandler struct {
accountManager server.AccountManager accountManager server.AccountManager
geolocationManager *geolocation.Geolocation geolocationManager geolocation.Geolocation
claimsExtractor *jwtclaims.ClaimsExtractor claimsExtractor *jwtclaims.ClaimsExtractor
} }
func addPostureCheckEndpoint(accountManager server.AccountManager, locationManager *geolocation.Geolocation, authCfg configs.AuthCfg, router *mux.Router) { func addPostureCheckEndpoint(accountManager server.AccountManager, locationManager geolocation.Geolocation, authCfg configs.AuthCfg, router *mux.Router) {
postureCheckHandler := newPostureChecksHandler(accountManager, locationManager, authCfg) postureCheckHandler := newPostureChecksHandler(accountManager, locationManager, authCfg)
router.HandleFunc("/posture-checks", postureCheckHandler.getAllPostureChecks).Methods("GET", "OPTIONS") router.HandleFunc("/posture-checks", postureCheckHandler.getAllPostureChecks).Methods("GET", "OPTIONS")
router.HandleFunc("/posture-checks", postureCheckHandler.createPostureCheck).Methods("POST", "OPTIONS") router.HandleFunc("/posture-checks", postureCheckHandler.createPostureCheck).Methods("POST", "OPTIONS")
@ -34,7 +34,7 @@ func addPostureCheckEndpoint(accountManager server.AccountManager, locationManag
} }
// newPostureChecksHandler creates a new PostureChecks handler // newPostureChecksHandler creates a new PostureChecks handler
func newPostureChecksHandler(accountManager server.AccountManager, geolocationManager *geolocation.Geolocation, authCfg configs.AuthCfg) *postureChecksHandler { func newPostureChecksHandler(accountManager server.AccountManager, geolocationManager geolocation.Geolocation, authCfg configs.AuthCfg) *postureChecksHandler {
return &postureChecksHandler{ return &postureChecksHandler{
accountManager: accountManager, accountManager: accountManager,
geolocationManager: geolocationManager, geolocationManager: geolocationManager,

View File

@ -70,7 +70,7 @@ func initPostureChecksTestData(postureChecks ...*posture.Checks) *postureChecksH
return claims.AccountId, claims.UserId, nil return claims.AccountId, claims.UserId, nil
}, },
}, },
geolocationManager: &geolocation.Geolocation{}, geolocationManager: &geolocation.Mock{},
claimsExtractor: jwtclaims.NewClaimsExtractor( claimsExtractor: jwtclaims.NewClaimsExtractor(
jwtclaims.WithFromRequestContext(func(r *http.Request) jwtclaims.AuthorizationClaims { jwtclaims.WithFromRequestContext(func(r *http.Request) jwtclaims.AuthorizationClaims {
return jwtclaims.AuthorizationClaims{ return jwtclaims.AuthorizationClaims{

View File

@ -93,7 +93,7 @@ func (h *handler) createSetupKey(w http.ResponseWriter, r *http.Request) {
return return
} }
apiSetupKeys := toResponseBody(setupKey) apiSetupKeys := ToResponseBody(setupKey)
// for the creation we need to send the plain key // for the creation we need to send the plain key
apiSetupKeys.Key = setupKey.Key apiSetupKeys.Key = setupKey.Key
@ -183,7 +183,7 @@ func (h *handler) getAllSetupKeys(w http.ResponseWriter, r *http.Request) {
apiSetupKeys := make([]*api.SetupKey, 0) apiSetupKeys := make([]*api.SetupKey, 0)
for _, key := range setupKeys { for _, key := range setupKeys {
apiSetupKeys = append(apiSetupKeys, toResponseBody(key)) apiSetupKeys = append(apiSetupKeys, ToResponseBody(key))
} }
util.WriteJSONObject(r.Context(), w, apiSetupKeys) util.WriteJSONObject(r.Context(), w, apiSetupKeys)
@ -216,14 +216,14 @@ func (h *handler) deleteSetupKey(w http.ResponseWriter, r *http.Request) {
func writeSuccess(ctx context.Context, w http.ResponseWriter, key *types.SetupKey) { func writeSuccess(ctx context.Context, w http.ResponseWriter, key *types.SetupKey) {
w.Header().Set("Content-Type", "application/json") w.Header().Set("Content-Type", "application/json")
w.WriteHeader(200) w.WriteHeader(200)
err := json.NewEncoder(w).Encode(toResponseBody(key)) err := json.NewEncoder(w).Encode(ToResponseBody(key))
if err != nil { if err != nil {
util.WriteError(ctx, err, w) util.WriteError(ctx, err, w)
return return
} }
} }
func toResponseBody(key *types.SetupKey) *api.SetupKey { func ToResponseBody(key *types.SetupKey) *api.SetupKey {
var state string var state string
switch { switch {
case key.IsExpired(): case key.IsExpired():

View File

@ -26,7 +26,6 @@ const (
newSetupKeyName = "New Setup Key" newSetupKeyName = "New Setup Key"
updatedSetupKeyName = "KKKey" updatedSetupKeyName = "KKKey"
notFoundSetupKeyID = "notFoundSetupKeyID" notFoundSetupKeyID = "notFoundSetupKeyID"
testAccountID = "test_id"
) )
func initSetupKeysTestMetaData(defaultKey *types.SetupKey, newKey *types.SetupKey, updatedSetupKey *types.SetupKey, func initSetupKeysTestMetaData(defaultKey *types.SetupKey, newKey *types.SetupKey, updatedSetupKey *types.SetupKey,
@ -81,7 +80,7 @@ func initSetupKeysTestMetaData(defaultKey *types.SetupKey, newKey *types.SetupKe
return jwtclaims.AuthorizationClaims{ return jwtclaims.AuthorizationClaims{
UserId: user.Id, UserId: user.Id,
Domain: "hotmail.com", Domain: "hotmail.com",
AccountId: testAccountID, AccountId: "testAccountId",
} }
}), }),
), ),
@ -102,7 +101,7 @@ func TestSetupKeysHandlers(t *testing.T) {
updatedDefaultSetupKey.Name = updatedSetupKeyName updatedDefaultSetupKey.Name = updatedSetupKeyName
updatedDefaultSetupKey.Revoked = true updatedDefaultSetupKey.Revoked = true
expectedNewKey := toResponseBody(newSetupKey) expectedNewKey := ToResponseBody(newSetupKey)
expectedNewKey.Key = plainKey expectedNewKey.Key = plainKey
tt := []struct { tt := []struct {
name string name string
@ -120,7 +119,7 @@ func TestSetupKeysHandlers(t *testing.T) {
requestPath: "/api/setup-keys", requestPath: "/api/setup-keys",
expectedStatus: http.StatusOK, expectedStatus: http.StatusOK,
expectedBody: true, expectedBody: true,
expectedSetupKeys: []*api.SetupKey{toResponseBody(defaultSetupKey)}, expectedSetupKeys: []*api.SetupKey{ToResponseBody(defaultSetupKey)},
}, },
{ {
name: "Get Existing Setup Key", name: "Get Existing Setup Key",
@ -128,7 +127,7 @@ func TestSetupKeysHandlers(t *testing.T) {
requestPath: "/api/setup-keys/" + existingSetupKeyID, requestPath: "/api/setup-keys/" + existingSetupKeyID,
expectedStatus: http.StatusOK, expectedStatus: http.StatusOK,
expectedBody: true, expectedBody: true,
expectedSetupKey: toResponseBody(defaultSetupKey), expectedSetupKey: ToResponseBody(defaultSetupKey),
}, },
{ {
name: "Get Not Existing Setup Key", name: "Get Not Existing Setup Key",
@ -159,7 +158,7 @@ func TestSetupKeysHandlers(t *testing.T) {
))), ))),
expectedStatus: http.StatusOK, expectedStatus: http.StatusOK,
expectedBody: true, expectedBody: true,
expectedSetupKey: toResponseBody(updatedDefaultSetupKey), expectedSetupKey: ToResponseBody(updatedDefaultSetupKey),
}, },
{ {
name: "Delete Setup Key", name: "Delete Setup Key",
@ -228,7 +227,7 @@ func TestSetupKeysHandlers(t *testing.T) {
func assertKeys(t *testing.T, got *api.SetupKey, expected *api.SetupKey) { func assertKeys(t *testing.T, got *api.SetupKey, expected *api.SetupKey) {
t.Helper() t.Helper()
// this comparison is done manually because when converting to JSON dates formatted differently // this comparison is done manually because when converting to JSON dates formatted differently
// assert.Equal(t, got.UpdatedAt, tc.expectedSetupKey.UpdatedAt) //doesn't work // assert.Equal(t, got.UpdatedAt, tc.expectedResponse.UpdatedAt) //doesn't work
assert.WithinDurationf(t, got.UpdatedAt, expected.UpdatedAt, 0, "") assert.WithinDurationf(t, got.UpdatedAt, expected.UpdatedAt, 0, "")
assert.WithinDurationf(t, got.Expires, expected.Expires, 0, "") assert.WithinDurationf(t, got.Expires, expected.Expires, 0, "")
assert.Equal(t, got.Name, expected.Name) assert.Equal(t, got.Name, expected.Name)

View File

@ -175,6 +175,7 @@ func (m *AuthMiddleware) checkPATFromRequest(w http.ResponseWriter, r *http.Requ
claimMaps[m.audience+jwtclaims.AccountIDSuffix] = account.Id claimMaps[m.audience+jwtclaims.AccountIDSuffix] = account.Id
claimMaps[m.audience+jwtclaims.DomainIDSuffix] = account.Domain claimMaps[m.audience+jwtclaims.DomainIDSuffix] = account.Domain
claimMaps[m.audience+jwtclaims.DomainCategorySuffix] = account.DomainCategory claimMaps[m.audience+jwtclaims.DomainCategorySuffix] = account.DomainCategory
claimMaps[jwtclaims.IsToken] = true
jwtToken := jwt.NewWithClaims(jwt.SigningMethodHS256, claimMaps) jwtToken := jwt.NewWithClaims(jwt.SigningMethodHS256, claimMaps)
newRequest := r.WithContext(context.WithValue(r.Context(), jwtclaims.TokenUserProperty, jwtToken)) //nolint newRequest := r.WithContext(context.WithValue(r.Context(), jwtclaims.TokenUserProperty, jwtToken)) //nolint
// Update the current request with the new context information. // Update the current request with the new context information.

View File

@ -0,0 +1,226 @@
package benchmarks
import (
"encoding/json"
"io"
"net/http"
"net/http/httptest"
"os"
"strconv"
"testing"
"time"
log "github.com/sirupsen/logrus"
"github.com/stretchr/testify/assert"
"github.com/netbirdio/netbird/management/server"
"github.com/netbirdio/netbird/management/server/http/api"
"github.com/netbirdio/netbird/management/server/http/testing/testing_tools"
)
// Map to store peers, groups, users, and setupKeys by name
var benchCasesSetupKeys = map[string]testing_tools.BenchmarkCase{
"Setup Keys - XS": {Peers: 10000, Groups: 10000, Users: 10000, SetupKeys: 5},
"Setup Keys - S": {Peers: 5, Groups: 5, Users: 5, SetupKeys: 100},
"Setup Keys - M": {Peers: 100, Groups: 20, Users: 20, SetupKeys: 1000},
"Setup Keys - L": {Peers: 5, Groups: 5, Users: 5, SetupKeys: 5000},
"Peers - L": {Peers: 10000, Groups: 5, Users: 5, SetupKeys: 5000},
"Groups - L": {Peers: 5, Groups: 10000, Users: 5, SetupKeys: 5000},
"Users - L": {Peers: 5, Groups: 5, Users: 10000, SetupKeys: 5000},
"Setup Keys - XL": {Peers: 500, Groups: 50, Users: 100, SetupKeys: 25000},
}
func BenchmarkCreateSetupKey(b *testing.B) {
var expectedMetrics = map[string]testing_tools.PerformanceMetrics{
"Setup Keys - XS": {MinMsPerOpLocal: 0.5, MaxMsPerOpLocal: 2, MinMsPerOpCICD: 1, MaxMsPerOpCICD: 16},
"Setup Keys - S": {MinMsPerOpLocal: 0.5, MaxMsPerOpLocal: 2, MinMsPerOpCICD: 1, MaxMsPerOpCICD: 16},
"Setup Keys - M": {MinMsPerOpLocal: 0.5, MaxMsPerOpLocal: 2, MinMsPerOpCICD: 1, MaxMsPerOpCICD: 16},
"Setup Keys - L": {MinMsPerOpLocal: 0.5, MaxMsPerOpLocal: 2, MinMsPerOpCICD: 1, MaxMsPerOpCICD: 16},
"Peers - L": {MinMsPerOpLocal: 0.5, MaxMsPerOpLocal: 2, MinMsPerOpCICD: 1, MaxMsPerOpCICD: 16},
"Groups - L": {MinMsPerOpLocal: 0.5, MaxMsPerOpLocal: 2, MinMsPerOpCICD: 1, MaxMsPerOpCICD: 16},
"Users - L": {MinMsPerOpLocal: 0.5, MaxMsPerOpLocal: 2, MinMsPerOpCICD: 1, MaxMsPerOpCICD: 16},
"Setup Keys - XL": {MinMsPerOpLocal: 0.5, MaxMsPerOpLocal: 2, MinMsPerOpCICD: 1, MaxMsPerOpCICD: 16},
}
log.SetOutput(io.Discard)
defer log.SetOutput(os.Stderr)
recorder := httptest.NewRecorder()
for name, bc := range benchCasesSetupKeys {
b.Run(name, func(b *testing.B) {
apiHandler, am, _ := testing_tools.BuildApiBlackBoxWithDBState(b, "../testdata/setup_keys.sql", nil)
testing_tools.PopulateTestData(b, am.(*server.DefaultAccountManager), bc.Peers, bc.Groups, bc.Users, bc.SetupKeys)
b.ResetTimer()
start := time.Now()
for i := 0; i < b.N; i++ {
requestBody := api.CreateSetupKeyRequest{
AutoGroups: []string{testing_tools.TestGroupId},
ExpiresIn: testing_tools.ExpiresIn,
Name: testing_tools.NewKeyName + strconv.Itoa(i),
Type: "reusable",
UsageLimit: 0,
}
// the time marshal will be recorded as well but for our use case that is ok
body, err := json.Marshal(requestBody)
assert.NoError(b, err)
req := testing_tools.BuildRequest(b, body, http.MethodPost, "/api/setup-keys", testing_tools.TestAdminId)
apiHandler.ServeHTTP(recorder, req)
}
testing_tools.EvaluateBenchmarkResults(b, name, time.Since(start), expectedMetrics[name], recorder)
})
}
}
func BenchmarkUpdateSetupKey(b *testing.B) {
var expectedMetrics = map[string]testing_tools.PerformanceMetrics{
"Setup Keys - XS": {MinMsPerOpLocal: 0.5, MaxMsPerOpLocal: 3, MinMsPerOpCICD: 2, MaxMsPerOpCICD: 19},
"Setup Keys - S": {MinMsPerOpLocal: 0.5, MaxMsPerOpLocal: 3, MinMsPerOpCICD: 2, MaxMsPerOpCICD: 19},
"Setup Keys - M": {MinMsPerOpLocal: 0.5, MaxMsPerOpLocal: 3, MinMsPerOpCICD: 2, MaxMsPerOpCICD: 19},
"Setup Keys - L": {MinMsPerOpLocal: 0.5, MaxMsPerOpLocal: 3, MinMsPerOpCICD: 2, MaxMsPerOpCICD: 19},
"Peers - L": {MinMsPerOpLocal: 0.5, MaxMsPerOpLocal: 3, MinMsPerOpCICD: 2, MaxMsPerOpCICD: 19},
"Groups - L": {MinMsPerOpLocal: 0.5, MaxMsPerOpLocal: 3, MinMsPerOpCICD: 2, MaxMsPerOpCICD: 19},
"Users - L": {MinMsPerOpLocal: 0.5, MaxMsPerOpLocal: 3, MinMsPerOpCICD: 2, MaxMsPerOpCICD: 19},
"Setup Keys - XL": {MinMsPerOpLocal: 0.5, MaxMsPerOpLocal: 3, MinMsPerOpCICD: 2, MaxMsPerOpCICD: 19},
}
log.SetOutput(io.Discard)
defer log.SetOutput(os.Stderr)
recorder := httptest.NewRecorder()
for name, bc := range benchCasesSetupKeys {
b.Run(name, func(b *testing.B) {
apiHandler, am, _ := testing_tools.BuildApiBlackBoxWithDBState(b, "../testdata/setup_keys.sql", nil)
testing_tools.PopulateTestData(b, am.(*server.DefaultAccountManager), bc.Peers, bc.Groups, bc.Users, bc.SetupKeys)
b.ResetTimer()
start := time.Now()
for i := 0; i < b.N; i++ {
groupId := testing_tools.TestGroupId
if i%2 == 0 {
groupId = testing_tools.NewGroupId
}
requestBody := api.SetupKeyRequest{
AutoGroups: []string{groupId},
Revoked: false,
}
// the time marshal will be recorded as well but for our use case that is ok
body, err := json.Marshal(requestBody)
assert.NoError(b, err)
req := testing_tools.BuildRequest(b, body, http.MethodPut, "/api/setup-keys/"+testing_tools.TestKeyId, testing_tools.TestAdminId)
apiHandler.ServeHTTP(recorder, req)
}
testing_tools.EvaluateBenchmarkResults(b, name, time.Since(start), expectedMetrics[name], recorder)
})
}
}
func BenchmarkGetOneSetupKey(b *testing.B) {
var expectedMetrics = map[string]testing_tools.PerformanceMetrics{
"Setup Keys - XS": {MinMsPerOpLocal: 0.5, MaxMsPerOpLocal: 2, MinMsPerOpCICD: 1, MaxMsPerOpCICD: 16},
"Setup Keys - S": {MinMsPerOpLocal: 0.5, MaxMsPerOpLocal: 2, MinMsPerOpCICD: 1, MaxMsPerOpCICD: 16},
"Setup Keys - M": {MinMsPerOpLocal: 0.5, MaxMsPerOpLocal: 2, MinMsPerOpCICD: 1, MaxMsPerOpCICD: 16},
"Setup Keys - L": {MinMsPerOpLocal: 0.5, MaxMsPerOpLocal: 2, MinMsPerOpCICD: 1, MaxMsPerOpCICD: 16},
"Peers - L": {MinMsPerOpLocal: 0.5, MaxMsPerOpLocal: 2, MinMsPerOpCICD: 1, MaxMsPerOpCICD: 16},
"Groups - L": {MinMsPerOpLocal: 0.5, MaxMsPerOpLocal: 2, MinMsPerOpCICD: 1, MaxMsPerOpCICD: 16},
"Users - L": {MinMsPerOpLocal: 0.5, MaxMsPerOpLocal: 2, MinMsPerOpCICD: 1, MaxMsPerOpCICD: 16},
"Setup Keys - XL": {MinMsPerOpLocal: 0.5, MaxMsPerOpLocal: 2, MinMsPerOpCICD: 1, MaxMsPerOpCICD: 16},
}
log.SetOutput(io.Discard)
defer log.SetOutput(os.Stderr)
recorder := httptest.NewRecorder()
for name, bc := range benchCasesSetupKeys {
b.Run(name, func(b *testing.B) {
apiHandler, am, _ := testing_tools.BuildApiBlackBoxWithDBState(b, "../testdata/setup_keys.sql", nil)
testing_tools.PopulateTestData(b, am.(*server.DefaultAccountManager), bc.Peers, bc.Groups, bc.Users, bc.SetupKeys)
b.ResetTimer()
start := time.Now()
for i := 0; i < b.N; i++ {
req := testing_tools.BuildRequest(b, nil, http.MethodGet, "/api/setup-keys/"+testing_tools.TestKeyId, testing_tools.TestAdminId)
apiHandler.ServeHTTP(recorder, req)
}
testing_tools.EvaluateBenchmarkResults(b, name, time.Since(start), expectedMetrics[name], recorder)
})
}
}
func BenchmarkGetAllSetupKeys(b *testing.B) {
var expectedMetrics = map[string]testing_tools.PerformanceMetrics{
"Setup Keys - XS": {MinMsPerOpLocal: 0.5, MaxMsPerOpLocal: 2, MinMsPerOpCICD: 1, MaxMsPerOpCICD: 12},
"Setup Keys - S": {MinMsPerOpLocal: 0.5, MaxMsPerOpLocal: 2, MinMsPerOpCICD: 1, MaxMsPerOpCICD: 15},
"Setup Keys - M": {MinMsPerOpLocal: 5, MaxMsPerOpLocal: 10, MinMsPerOpCICD: 5, MaxMsPerOpCICD: 40},
"Setup Keys - L": {MinMsPerOpLocal: 30, MaxMsPerOpLocal: 50, MinMsPerOpCICD: 30, MaxMsPerOpCICD: 150},
"Peers - L": {MinMsPerOpLocal: 30, MaxMsPerOpLocal: 50, MinMsPerOpCICD: 30, MaxMsPerOpCICD: 150},
"Groups - L": {MinMsPerOpLocal: 30, MaxMsPerOpLocal: 50, MinMsPerOpCICD: 30, MaxMsPerOpCICD: 150},
"Users - L": {MinMsPerOpLocal: 30, MaxMsPerOpLocal: 50, MinMsPerOpCICD: 30, MaxMsPerOpCICD: 150},
"Setup Keys - XL": {MinMsPerOpLocal: 140, MaxMsPerOpLocal: 220, MinMsPerOpCICD: 150, MaxMsPerOpCICD: 500},
}
log.SetOutput(io.Discard)
defer log.SetOutput(os.Stderr)
recorder := httptest.NewRecorder()
for name, bc := range benchCasesSetupKeys {
b.Run(name, func(b *testing.B) {
apiHandler, am, _ := testing_tools.BuildApiBlackBoxWithDBState(b, "../testdata/setup_keys.sql", nil)
testing_tools.PopulateTestData(b, am.(*server.DefaultAccountManager), bc.Peers, bc.Groups, bc.Users, bc.SetupKeys)
b.ResetTimer()
start := time.Now()
for i := 0; i < b.N; i++ {
req := testing_tools.BuildRequest(b, nil, http.MethodGet, "/api/setup-keys", testing_tools.TestAdminId)
apiHandler.ServeHTTP(recorder, req)
}
testing_tools.EvaluateBenchmarkResults(b, name, time.Since(start), expectedMetrics[name], recorder)
})
}
}
func BenchmarkDeleteSetupKey(b *testing.B) {
var expectedMetrics = map[string]testing_tools.PerformanceMetrics{
"Setup Keys - XS": {MinMsPerOpLocal: 0.5, MaxMsPerOpLocal: 2, MinMsPerOpCICD: 1, MaxMsPerOpCICD: 16},
"Setup Keys - S": {MinMsPerOpLocal: 0.5, MaxMsPerOpLocal: 2, MinMsPerOpCICD: 1, MaxMsPerOpCICD: 16},
"Setup Keys - M": {MinMsPerOpLocal: 0.5, MaxMsPerOpLocal: 2, MinMsPerOpCICD: 1, MaxMsPerOpCICD: 16},
"Setup Keys - L": {MinMsPerOpLocal: 0.5, MaxMsPerOpLocal: 2, MinMsPerOpCICD: 1, MaxMsPerOpCICD: 16},
"Peers - L": {MinMsPerOpLocal: 0.5, MaxMsPerOpLocal: 2, MinMsPerOpCICD: 1, MaxMsPerOpCICD: 16},
"Groups - L": {MinMsPerOpLocal: 0.5, MaxMsPerOpLocal: 2, MinMsPerOpCICD: 1, MaxMsPerOpCICD: 16},
"Users - L": {MinMsPerOpLocal: 0.5, MaxMsPerOpLocal: 2, MinMsPerOpCICD: 1, MaxMsPerOpCICD: 16},
"Setup Keys - XL": {MinMsPerOpLocal: 0.5, MaxMsPerOpLocal: 2, MinMsPerOpCICD: 1, MaxMsPerOpCICD: 16},
}
log.SetOutput(io.Discard)
defer log.SetOutput(os.Stderr)
recorder := httptest.NewRecorder()
for name, bc := range benchCasesSetupKeys {
b.Run(name, func(b *testing.B) {
apiHandler, am, _ := testing_tools.BuildApiBlackBoxWithDBState(b, "../testdata/setup_keys.sql", nil)
testing_tools.PopulateTestData(b, am.(*server.DefaultAccountManager), bc.Peers, bc.Groups, bc.Users, 1000)
b.ResetTimer()
start := time.Now()
for i := 0; i < b.N; i++ {
req := testing_tools.BuildRequest(b, nil, http.MethodDelete, "/api/setup-keys/"+"oldkey-"+strconv.Itoa(i), testing_tools.TestAdminId)
apiHandler.ServeHTTP(recorder, req)
}
testing_tools.EvaluateBenchmarkResults(b, name, time.Since(start), expectedMetrics[name], recorder)
})
}
}

File diff suppressed because it is too large Load Diff

View File

@ -0,0 +1,24 @@
CREATE TABLE `accounts` (`id` text,`created_by` text,`created_at` datetime,`domain` text,`domain_category` text,`is_domain_primary_account` numeric,`network_identifier` text,`network_net` text,`network_dns` text,`network_serial` integer,`dns_settings_disabled_management_groups` text,`settings_peer_login_expiration_enabled` numeric,`settings_peer_login_expiration` integer,`settings_regular_users_view_blocked` numeric,`settings_groups_propagation_enabled` numeric,`settings_jwt_groups_enabled` numeric,`settings_jwt_groups_claim_name` text,`settings_jwt_allow_groups` text,`settings_extra_peer_approval_enabled` numeric,`settings_extra_integrated_validator_groups` text,PRIMARY KEY (`id`));
CREATE TABLE `users` (`id` text,`account_id` text,`role` text,`is_service_user` numeric,`non_deletable` numeric,`service_user_name` text,`auto_groups` text,`blocked` numeric,`last_login` datetime,`created_at` datetime,`issued` text DEFAULT "api",`integration_ref_id` integer,`integration_ref_integration_type` text,PRIMARY KEY (`id`),CONSTRAINT `fk_accounts_users_g` FOREIGN KEY (`account_id`) REFERENCES `accounts`(`id`));
CREATE TABLE `peers` (`id` text,`account_id` text,`key` text,`setup_key` text,`ip` text,`meta_hostname` text,`meta_go_os` text,`meta_kernel` text,`meta_core` text,`meta_platform` text,`meta_os` text,`meta_os_version` text,`meta_wt_version` text,`meta_ui_version` text,`meta_kernel_version` text,`meta_network_addresses` text,`meta_system_serial_number` text,`meta_system_product_name` text,`meta_system_manufacturer` text,`meta_environment` text,`meta_files` text,`name` text,`dns_label` text,`peer_status_last_seen` datetime,`peer_status_connected` numeric,`peer_status_login_expired` numeric,`peer_status_requires_approval` numeric,`user_id` text,`ssh_key` text,`ssh_enabled` numeric,`login_expiration_enabled` numeric,`last_login` datetime,`created_at` datetime,`ephemeral` numeric,`location_connection_ip` text,`location_country_code` text,`location_city_name` text,`location_geo_name_id` integer,PRIMARY KEY (`id`),CONSTRAINT `fk_accounts_peers_g` FOREIGN KEY (`account_id`) REFERENCES `accounts`(`id`));
CREATE TABLE `groups` (`id` text,`account_id` text,`name` text,`issued` text,`peers` text,`integration_ref_id` integer,`integration_ref_integration_type` text,PRIMARY KEY (`id`),CONSTRAINT `fk_accounts_groups_g` FOREIGN KEY (`account_id`) REFERENCES `accounts`(`id`));
INSERT INTO accounts VALUES('testAccountId','','2024-10-02 16:01:38.000000000+00:00','test.com','private',1,'testNetworkIdentifier','{"IP":"100.64.0.0","Mask":"//8AAA=="}','',0,'[]',0,86400000000000,0,0,0,'',NULL,NULL,NULL);
INSERT INTO users VALUES('testUserId','testAccountId','user',0,0,'','[]',0,'0001-01-01 00:00:00+00:00','2024-10-02 16:01:38.000000000+00:00','api',0,'');
INSERT INTO users VALUES('testAdminId','testAccountId','admin',0,0,'','[]',0,'0001-01-01 00:00:00+00:00','2024-10-02 16:01:38.000000000+00:00','api',0,'');
INSERT INTO users VALUES('testOwnerId','testAccountId','owner',0,0,'','[]',0,'0001-01-01 00:00:00+00:00','2024-10-02 16:01:38.000000000+00:00','api',0,'');
INSERT INTO users VALUES('testServiceUserId','testAccountId','user',1,0,'','[]',0,'0001-01-01 00:00:00+00:00','2024-10-02 16:01:38.000000000+00:00','api',0,'');
INSERT INTO users VALUES('testServiceAdminId','testAccountId','admin',1,0,'','[]',0,'0001-01-01 00:00:00+00:00','2024-10-02 16:01:38.000000000+00:00','api',0,'');
INSERT INTO users VALUES('blockedUserId','testAccountId','admin',0,0,'','[]',1,'0001-01-01 00:00:00+00:00','2024-10-02 16:01:38.000000000+00:00','api',0,'');
INSERT INTO users VALUES('otherUserId','otherAccountId','admin',0,0,'','[]',0,'0001-01-01 00:00:00+00:00','2024-10-02 16:01:38.000000000+00:00','api',0,'');
INSERT INTO peers VALUES('testPeerId','testAccountId','5rvhvriKJZ3S9oxYToVj5TzDM9u9y8cxg7htIMWlYAg=','72546A29-6BC8-4311-BCFC-9CDBF33F1A48','"100.64.114.31"','f2a34f6a4731','linux','Linux','11','unknown','Debian GNU/Linux','','0.12.0','','',NULL,'','','','{"Cloud":"","Platform":""}',NULL,'f2a34f6a4731','f2a34f6a4731','2023-03-02 09:21:02.189035775+01:00',0,0,0,'','ssh-ed25519 AAAAC3NzaC1lZDI1NTE5AAAAILzUUSYG/LGnV8zarb2SGN+tib/PZ+M7cL4WtTzUrTpk',0,1,'2023-03-01 19:48:19.817799698+01:00','2024-10-02 17:00:32.527947+02:00',0,'""','','',0);
INSERT INTO "groups" VALUES('testGroupId','testAccountId','testGroupName','api','[]',0,'');
INSERT INTO "groups" VALUES('newGroupId','testAccountId','newGroupName','api','[]',0,'');
CREATE TABLE `setup_keys` (`id` text,`account_id` text,`key` text,`key_secret` text,`name` text,`type` text,`created_at` datetime,`expires_at` datetime,`updated_at` datetime,`revoked` numeric,`used_times` integer,`last_used` datetime,`auto_groups` text,`usage_limit` integer,`ephemeral` numeric,PRIMARY KEY (`id`),CONSTRAINT `fk_accounts_setup_keys_g` FOREIGN KEY (`account_id`) REFERENCES `accounts`(`id`));
INSERT INTO setup_keys VALUES('testKeyId','testAccountId','testKey','testK****','existingKey','one-off','2021-08-19 20:46:20.000000000+00:00','2321-09-18 20:46:20.000000000+00:00','2021-08-19 20:46:20.000000000+00:000',0,0,'0001-01-01 00:00:00+00:00','["testGroupId"]',1,0);
INSERT INTO setup_keys VALUES('revokedKeyId','testAccountId','revokedKey','testK****','existingKey','reusable','2021-08-19 20:46:20.000000000+00:00','2321-09-18 20:46:20.000000000+00:00','2021-08-19 20:46:20.000000000+00:00',1,0,'0001-01-01 00:00:00+00:00','["testGroupId"]',3,0);
INSERT INTO setup_keys VALUES('expiredKeyId','testAccountId','expiredKey','testK****','existingKey','reusable','2021-08-19 20:46:20.000000000+00:00','1921-09-18 20:46:20.000000000+00:00','2021-08-19 20:46:20.000000000+00:00',0,1,'0001-01-01 00:00:00+00:00','["testGroupId"]',5,1);

View File

@ -0,0 +1,307 @@
package testing_tools
import (
"bytes"
"context"
"fmt"
"io"
"net"
"net/http"
"net/http/httptest"
"os"
"strconv"
"testing"
"time"
"github.com/stretchr/testify/assert"
"golang.zx2c4.com/wireguard/wgctrl/wgtypes"
"github.com/netbirdio/netbird/management/server"
"github.com/netbirdio/netbird/management/server/activity"
"github.com/netbirdio/netbird/management/server/geolocation"
"github.com/netbirdio/netbird/management/server/groups"
nbhttp "github.com/netbirdio/netbird/management/server/http"
"github.com/netbirdio/netbird/management/server/http/configs"
"github.com/netbirdio/netbird/management/server/jwtclaims"
"github.com/netbirdio/netbird/management/server/networks"
"github.com/netbirdio/netbird/management/server/networks/resources"
"github.com/netbirdio/netbird/management/server/networks/routers"
nbpeer "github.com/netbirdio/netbird/management/server/peer"
"github.com/netbirdio/netbird/management/server/posture"
"github.com/netbirdio/netbird/management/server/store"
"github.com/netbirdio/netbird/management/server/telemetry"
"github.com/netbirdio/netbird/management/server/types"
)
const (
TestAccountId = "testAccountId"
TestPeerId = "testPeerId"
TestGroupId = "testGroupId"
TestKeyId = "testKeyId"
TestUserId = "testUserId"
TestAdminId = "testAdminId"
TestOwnerId = "testOwnerId"
TestServiceUserId = "testServiceUserId"
TestServiceAdminId = "testServiceAdminId"
BlockedUserId = "blockedUserId"
OtherUserId = "otherUserId"
InvalidToken = "invalidToken"
NewKeyName = "newKey"
NewGroupId = "newGroupId"
ExpiresIn = 3600
RevokedKeyId = "revokedKeyId"
ExpiredKeyId = "expiredKeyId"
ExistingKeyName = "existingKey"
)
type TB interface {
Cleanup(func())
Helper()
Errorf(format string, args ...any)
Fatalf(format string, args ...any)
TempDir() string
}
// BenchmarkCase defines a single benchmark test case
type BenchmarkCase struct {
Peers int
Groups int
Users int
SetupKeys int
}
// PerformanceMetrics holds the performance expectations
type PerformanceMetrics struct {
MinMsPerOpLocal float64
MaxMsPerOpLocal float64
MinMsPerOpCICD float64
MaxMsPerOpCICD float64
}
func BuildApiBlackBoxWithDBState(t TB, sqlFile string, expectedPeerUpdate *server.UpdateMessage) (http.Handler, server.AccountManager, chan struct{}) {
store, cleanup, err := store.NewTestStoreFromSQL(context.Background(), sqlFile, t.TempDir())
if err != nil {
t.Fatalf("Failed to create test store: %v", err)
}
t.Cleanup(cleanup)
metrics, err := telemetry.NewDefaultAppMetrics(context.Background())
if err != nil {
t.Fatalf("Failed to create metrics: %v", err)
}
peersUpdateManager := server.NewPeersUpdateManager(nil)
updMsg := peersUpdateManager.CreateChannel(context.Background(), TestPeerId)
done := make(chan struct{})
go func() {
if expectedPeerUpdate != nil {
peerShouldReceiveUpdate(t, updMsg, expectedPeerUpdate)
} else {
peerShouldNotReceiveUpdate(t, updMsg)
}
close(done)
}()
geoMock := &geolocation.Mock{}
validatorMock := server.MocIntegratedValidator{}
am, err := server.BuildManager(context.Background(), store, peersUpdateManager, nil, "", "", &activity.InMemoryEventStore{}, geoMock, false, validatorMock, metrics)
if err != nil {
t.Fatalf("Failed to create manager: %v", err)
}
networksManagerMock := networks.NewManagerMock()
resourcesManagerMock := resources.NewManagerMock()
routersManagerMock := routers.NewManagerMock()
groupsManagerMock := groups.NewManagerMock()
apiHandler, err := nbhttp.NewAPIHandler(context.Background(), am, networksManagerMock, resourcesManagerMock, routersManagerMock, groupsManagerMock, geoMock, &jwtclaims.JwtValidatorMock{}, metrics, configs.AuthCfg{}, validatorMock)
if err != nil {
t.Fatalf("Failed to create API handler: %v", err)
}
return apiHandler, am, done
}
func peerShouldNotReceiveUpdate(t TB, updateMessage <-chan *server.UpdateMessage) {
t.Helper()
select {
case msg := <-updateMessage:
t.Errorf("Unexpected message received: %+v", msg)
case <-time.After(500 * time.Millisecond):
return
}
}
func peerShouldReceiveUpdate(t TB, updateMessage <-chan *server.UpdateMessage, expected *server.UpdateMessage) {
t.Helper()
select {
case msg := <-updateMessage:
if msg == nil {
t.Errorf("Received nil update message, expected valid message")
}
assert.Equal(t, expected, msg)
case <-time.After(500 * time.Millisecond):
t.Errorf("Timed out waiting for update message")
}
}
func BuildRequest(t TB, requestBody []byte, requestType, requestPath, user string) *http.Request {
t.Helper()
req := httptest.NewRequest(requestType, requestPath, bytes.NewBuffer(requestBody))
req.Header.Set("Authorization", "Bearer "+user)
return req
}
func ReadResponse(t *testing.T, recorder *httptest.ResponseRecorder, expectedStatus int, expectResponse bool) ([]byte, bool) {
t.Helper()
res := recorder.Result()
defer res.Body.Close()
content, err := io.ReadAll(res.Body)
if err != nil {
t.Fatalf("Failed to read response body: %v", err)
}
if !expectResponse {
return nil, false
}
if status := recorder.Code; status != expectedStatus {
t.Fatalf("handler returned wrong status code: got %v want %v, content: %s",
status, expectedStatus, string(content))
}
return content, expectedStatus == http.StatusOK
}
func PopulateTestData(b *testing.B, am *server.DefaultAccountManager, peers, groups, users, setupKeys int) {
b.Helper()
ctx := context.Background()
account, err := am.GetAccount(ctx, TestAccountId)
if err != nil {
b.Fatalf("Failed to get account: %v", err)
}
// Create peers
for i := 0; i < peers; i++ {
peerKey, _ := wgtypes.GeneratePrivateKey()
peer := &nbpeer.Peer{
ID: fmt.Sprintf("oldpeer-%d", i),
DNSLabel: fmt.Sprintf("oldpeer-%d", i),
Key: peerKey.PublicKey().String(),
IP: net.ParseIP(fmt.Sprintf("100.64.%d.%d", i/256, i%256)),
Status: &nbpeer.PeerStatus{},
UserID: TestUserId,
}
account.Peers[peer.ID] = peer
}
// Create users
for i := 0; i < users; i++ {
user := &types.User{
Id: fmt.Sprintf("olduser-%d", i),
AccountID: account.Id,
Role: types.UserRoleUser,
}
account.Users[user.Id] = user
}
for i := 0; i < setupKeys; i++ {
key := &types.SetupKey{
Id: fmt.Sprintf("oldkey-%d", i),
AccountID: account.Id,
AutoGroups: []string{"someGroupID"},
ExpiresAt: time.Now().Add(ExpiresIn * time.Second),
Name: NewKeyName + strconv.Itoa(i),
Type: "reusable",
UsageLimit: 0,
}
account.SetupKeys[key.Id] = key
}
// Create groups and policies
account.Policies = make([]*types.Policy, 0, groups)
for i := 0; i < groups; i++ {
groupID := fmt.Sprintf("group-%d", i)
group := &types.Group{
ID: groupID,
Name: fmt.Sprintf("Group %d", i),
}
for j := 0; j < peers/groups; j++ {
peerIndex := i*(peers/groups) + j
group.Peers = append(group.Peers, fmt.Sprintf("peer-%d", peerIndex))
}
account.Groups[groupID] = group
// Create a policy for this group
policy := &types.Policy{
ID: fmt.Sprintf("policy-%d", i),
Name: fmt.Sprintf("Policy for Group %d", i),
Enabled: true,
Rules: []*types.PolicyRule{
{
ID: fmt.Sprintf("rule-%d", i),
Name: fmt.Sprintf("Rule for Group %d", i),
Enabled: true,
Sources: []string{groupID},
Destinations: []string{groupID},
Bidirectional: true,
Protocol: types.PolicyRuleProtocolALL,
Action: types.PolicyTrafficActionAccept,
},
},
}
account.Policies = append(account.Policies, policy)
}
account.PostureChecks = []*posture.Checks{
{
ID: "PostureChecksAll",
Name: "All",
Checks: posture.ChecksDefinition{
NBVersionCheck: &posture.NBVersionCheck{
MinVersion: "0.0.1",
},
},
},
}
err = am.Store.SaveAccount(context.Background(), account)
if err != nil {
b.Fatalf("Failed to save account: %v", err)
}
}
func EvaluateBenchmarkResults(b *testing.B, name string, duration time.Duration, perfMetrics PerformanceMetrics, recorder *httptest.ResponseRecorder) {
b.Helper()
if recorder.Code != http.StatusOK {
b.Fatalf("Benchmark %s failed: unexpected status code %d", name, recorder.Code)
}
msPerOp := float64(duration.Nanoseconds()) / float64(b.N) / 1e6
b.ReportMetric(msPerOp, "ms/op")
minExpected := perfMetrics.MinMsPerOpLocal
maxExpected := perfMetrics.MaxMsPerOpLocal
if os.Getenv("CI") == "true" {
minExpected = perfMetrics.MinMsPerOpCICD
maxExpected = perfMetrics.MaxMsPerOpCICD
}
if msPerOp < minExpected {
b.Fatalf("Benchmark %s failed: too fast (%.2f ms/op, minimum %.2f ms/op)", name, msPerOp, minExpected)
}
if msPerOp > maxExpected {
b.Fatalf("Benchmark %s failed: too slow (%.2f ms/op, maximum %.2f ms/op)", name, msPerOp, maxExpected)
}
}

View File

@ -7,6 +7,7 @@ import (
log "github.com/sirupsen/logrus" log "github.com/sirupsen/logrus"
"github.com/netbirdio/netbird/management/server/account" "github.com/netbirdio/netbird/management/server/account"
nbpeer "github.com/netbirdio/netbird/management/server/peer"
"github.com/netbirdio/netbird/management/server/store" "github.com/netbirdio/netbird/management/server/store"
"github.com/netbirdio/netbird/management/server/types" "github.com/netbirdio/netbird/management/server/types"
) )
@ -78,3 +79,45 @@ func (am *DefaultAccountManager) GroupValidation(ctx context.Context, accountID
func (am *DefaultAccountManager) GetValidatedPeers(account *types.Account) (map[string]struct{}, error) { func (am *DefaultAccountManager) GetValidatedPeers(account *types.Account) (map[string]struct{}, error) {
return am.integratedPeerValidator.GetValidatedPeers(account.Id, account.Groups, account.Peers, account.Settings.Extra) return am.integratedPeerValidator.GetValidatedPeers(account.Id, account.Groups, account.Peers, account.Settings.Extra)
} }
type MocIntegratedValidator struct {
ValidatePeerFunc func(_ context.Context, update *nbpeer.Peer, peer *nbpeer.Peer, userID string, accountID string, dnsDomain string, peersGroup []string, extraSettings *account.ExtraSettings) (*nbpeer.Peer, bool, error)
}
func (a MocIntegratedValidator) ValidateExtraSettings(_ context.Context, newExtraSettings *account.ExtraSettings, oldExtraSettings *account.ExtraSettings, peers map[string]*nbpeer.Peer, userID string, accountID string) error {
return nil
}
func (a MocIntegratedValidator) ValidatePeer(_ context.Context, update *nbpeer.Peer, peer *nbpeer.Peer, userID string, accountID string, dnsDomain string, peersGroup []string, extraSettings *account.ExtraSettings) (*nbpeer.Peer, bool, error) {
if a.ValidatePeerFunc != nil {
return a.ValidatePeerFunc(context.Background(), update, peer, userID, accountID, dnsDomain, peersGroup, extraSettings)
}
return update, false, nil
}
func (a MocIntegratedValidator) GetValidatedPeers(accountID string, groups map[string]*types.Group, peers map[string]*nbpeer.Peer, extraSettings *account.ExtraSettings) (map[string]struct{}, error) {
validatedPeers := make(map[string]struct{})
for _, peer := range peers {
validatedPeers[peer.ID] = struct{}{}
}
return validatedPeers, nil
}
func (MocIntegratedValidator) PreparePeer(_ context.Context, accountID string, peer *nbpeer.Peer, peersGroup []string, extraSettings *account.ExtraSettings) *nbpeer.Peer {
return peer
}
func (MocIntegratedValidator) IsNotValidPeer(_ context.Context, accountID string, peer *nbpeer.Peer, peersGroup []string, extraSettings *account.ExtraSettings) (bool, bool, error) {
return false, false, nil
}
func (MocIntegratedValidator) PeerDeleted(_ context.Context, _, _ string) error {
return nil
}
func (MocIntegratedValidator) SetPeerInvalidationListener(func(accountID string)) {
// just a dummy
}
func (MocIntegratedValidator) Stop(_ context.Context) {
// just a dummy
}

View File

@ -22,6 +22,8 @@ const (
LastLoginSuffix = "nb_last_login" LastLoginSuffix = "nb_last_login"
// Invited claim indicates that an incoming JWT is from a user that just accepted an invitation // Invited claim indicates that an incoming JWT is from a user that just accepted an invitation
Invited = "nb_invited" Invited = "nb_invited"
// IsToken claim indicates that auth type from the user is a token
IsToken = "is_token"
) )
// ExtractClaims Extract function type // ExtractClaims Extract function type

View File

@ -72,15 +72,19 @@ type JSONWebKey struct {
X5c []string `json:"x5c"` X5c []string `json:"x5c"`
} }
// JWTValidator struct to handle token validation and parsing type JWTValidator interface {
type JWTValidator struct { ValidateAndParse(ctx context.Context, token string) (*jwt.Token, error)
}
// jwtValidatorImpl struct to handle token validation and parsing
type jwtValidatorImpl struct {
options Options options Options
} }
var keyNotFound = errors.New("unable to find appropriate key") var keyNotFound = errors.New("unable to find appropriate key")
// NewJWTValidator constructor // NewJWTValidator constructor
func NewJWTValidator(ctx context.Context, issuer string, audienceList []string, keysLocation string, idpSignkeyRefreshEnabled bool) (*JWTValidator, error) { func NewJWTValidator(ctx context.Context, issuer string, audienceList []string, keysLocation string, idpSignkeyRefreshEnabled bool) (JWTValidator, error) {
keys, err := getPemKeys(ctx, keysLocation) keys, err := getPemKeys(ctx, keysLocation)
if err != nil { if err != nil {
return nil, err return nil, err
@ -146,13 +150,13 @@ func NewJWTValidator(ctx context.Context, issuer string, audienceList []string,
options.UserProperty = "user" options.UserProperty = "user"
} }
return &JWTValidator{ return &jwtValidatorImpl{
options: options, options: options,
}, nil }, nil
} }
// ValidateAndParse validates the token and returns the parsed token // ValidateAndParse validates the token and returns the parsed token
func (m *JWTValidator) ValidateAndParse(ctx context.Context, token string) (*jwt.Token, error) { func (m *jwtValidatorImpl) ValidateAndParse(ctx context.Context, token string) (*jwt.Token, error) {
// If the token is empty... // If the token is empty...
if token == "" { if token == "" {
// Check if it was required // Check if it was required
@ -318,3 +322,28 @@ func getMaxAgeFromCacheHeader(ctx context.Context, cacheControl string) int {
return 0 return 0
} }
type JwtValidatorMock struct{}
func (j *JwtValidatorMock) ValidateAndParse(ctx context.Context, token string) (*jwt.Token, error) {
claimMaps := jwt.MapClaims{}
switch token {
case "testUserId", "testAdminId", "testOwnerId", "testServiceUserId", "testServiceAdminId", "blockedUserId":
claimMaps[UserIDClaim] = token
claimMaps[AccountIDSuffix] = "testAccountId"
claimMaps[DomainIDSuffix] = "test.com"
claimMaps[DomainCategorySuffix] = "private"
case "otherUserId":
claimMaps[UserIDClaim] = "otherUserId"
claimMaps[AccountIDSuffix] = "otherAccountId"
claimMaps[DomainIDSuffix] = "other.com"
claimMaps[DomainCategorySuffix] = "private"
case "invalidToken":
return nil, errors.New("invalid token")
}
jwtToken := jwt.NewWithClaims(jwt.SigningMethodHS256, claimMaps)
return jwtToken, nil
}

View File

@ -21,13 +21,10 @@ import (
"github.com/netbirdio/netbird/encryption" "github.com/netbirdio/netbird/encryption"
mgmtProto "github.com/netbirdio/netbird/management/proto" mgmtProto "github.com/netbirdio/netbird/management/proto"
"github.com/netbirdio/netbird/management/server" "github.com/netbirdio/netbird/management/server"
"github.com/netbirdio/netbird/management/server/account"
"github.com/netbirdio/netbird/management/server/activity" "github.com/netbirdio/netbird/management/server/activity"
nbpeer "github.com/netbirdio/netbird/management/server/peer"
"github.com/netbirdio/netbird/management/server/settings" "github.com/netbirdio/netbird/management/server/settings"
"github.com/netbirdio/netbird/management/server/store" "github.com/netbirdio/netbird/management/server/store"
"github.com/netbirdio/netbird/management/server/telemetry" "github.com/netbirdio/netbird/management/server/telemetry"
"github.com/netbirdio/netbird/management/server/types"
"github.com/netbirdio/netbird/util" "github.com/netbirdio/netbird/util"
) )
@ -448,43 +445,6 @@ var _ = Describe("Management service", func() {
}) })
}) })
type MocIntegratedValidator struct {
}
func (a MocIntegratedValidator) ValidateExtraSettings(_ context.Context, newExtraSettings *account.ExtraSettings, oldExtraSettings *account.ExtraSettings, peers map[string]*nbpeer.Peer, userID string, accountID string) error {
return nil
}
func (a MocIntegratedValidator) ValidatePeer(_ context.Context, update *nbpeer.Peer, peer *nbpeer.Peer, userID string, accountID string, dnsDomain string, peersGroup []string, extraSettings *account.ExtraSettings) (*nbpeer.Peer, bool, error) {
return update, false, nil
}
func (a MocIntegratedValidator) GetValidatedPeers(accountID string, groups map[string]*types.Group, peers map[string]*nbpeer.Peer, extraSettings *account.ExtraSettings) (map[string]struct{}, error) {
validatedPeers := make(map[string]struct{})
for p := range peers {
validatedPeers[p] = struct{}{}
}
return validatedPeers, nil
}
func (MocIntegratedValidator) PreparePeer(_ context.Context, accountID string, peer *nbpeer.Peer, peersGroup []string, extraSettings *account.ExtraSettings) *nbpeer.Peer {
return peer
}
func (MocIntegratedValidator) IsNotValidPeer(_ context.Context, accountID string, peer *nbpeer.Peer, peersGroup []string, extraSettings *account.ExtraSettings) (bool, bool, error) {
return false, false, nil
}
func (MocIntegratedValidator) PeerDeleted(_ context.Context, _, _ string) error {
return nil
}
func (MocIntegratedValidator) SetPeerInvalidationListener(func(accountID string)) {
}
func (MocIntegratedValidator) Stop(_ context.Context) {}
func loginPeerWithValidSetupKey(serverPubKey wgtypes.Key, key wgtypes.Key, client mgmtProto.ManagementServiceClient) *mgmtProto.LoginResponse { func loginPeerWithValidSetupKey(serverPubKey wgtypes.Key, key wgtypes.Key, client mgmtProto.ManagementServiceClient) *mgmtProto.LoginResponse {
defer GinkgoRecover() defer GinkgoRecover()
@ -547,7 +507,7 @@ func startServer(config *server.Config, dataDir string, testFile string) (*grpc.
log.Fatalf("failed creating metrics: %v", err) log.Fatalf("failed creating metrics: %v", err)
} }
accountManager, err := server.BuildManager(context.Background(), store, peersUpdateManager, nil, "", "netbird.selfhosted", eventStore, nil, false, MocIntegratedValidator{}, metrics) accountManager, err := server.BuildManager(context.Background(), store, peersUpdateManager, nil, "", "netbird.selfhosted", eventStore, nil, false, server.MocIntegratedValidator{}, metrics)
if err != nil { if err != nil {
log.Fatalf("failed creating a manager: %v", err) log.Fatalf("failed creating a manager: %v", err)
} }

View File

@ -195,6 +195,10 @@ func (w *Worker) generateProperties(ctx context.Context) properties {
groups int groups int
routes int routes int
routesWithRGGroups int routesWithRGGroups int
networks int
networkResources int
networkRouters int
networkRoutersWithPG int
nameservers int nameservers int
uiClient int uiClient int
version string version string
@ -219,6 +223,16 @@ func (w *Worker) generateProperties(ctx context.Context) properties {
} }
groups += len(account.Groups) groups += len(account.Groups)
networks += len(account.Networks)
networkResources += len(account.NetworkResources)
networkRouters += len(account.NetworkRouters)
for _, router := range account.NetworkRouters {
if len(router.PeerGroups) > 0 {
networkRoutersWithPG++
}
}
routes += len(account.Routes) routes += len(account.Routes)
for _, route := range account.Routes { for _, route := range account.Routes {
if len(route.PeerGroups) > 0 { if len(route.PeerGroups) > 0 {
@ -312,6 +326,10 @@ func (w *Worker) generateProperties(ctx context.Context) properties {
metricsProperties["rules_with_src_posture_checks"] = rulesWithSrcPostureChecks metricsProperties["rules_with_src_posture_checks"] = rulesWithSrcPostureChecks
metricsProperties["posture_checks"] = postureChecks metricsProperties["posture_checks"] = postureChecks
metricsProperties["groups"] = groups metricsProperties["groups"] = groups
metricsProperties["networks"] = networks
metricsProperties["network_resources"] = networkResources
metricsProperties["network_routers"] = networkRouters
metricsProperties["network_routers_with_groups"] = networkRoutersWithPG
metricsProperties["routes"] = routes metricsProperties["routes"] = routes
metricsProperties["routes_with_routing_groups"] = routesWithRGGroups metricsProperties["routes_with_routing_groups"] = routesWithRGGroups
metricsProperties["nameservers"] = nameservers metricsProperties["nameservers"] = nameservers

View File

@ -5,6 +5,9 @@ import (
"testing" "testing"
nbdns "github.com/netbirdio/netbird/dns" nbdns "github.com/netbirdio/netbird/dns"
resourceTypes "github.com/netbirdio/netbird/management/server/networks/resources/types"
routerTypes "github.com/netbirdio/netbird/management/server/networks/routers/types"
networkTypes "github.com/netbirdio/netbird/management/server/networks/types"
nbpeer "github.com/netbirdio/netbird/management/server/peer" nbpeer "github.com/netbirdio/netbird/management/server/peer"
"github.com/netbirdio/netbird/management/server/posture" "github.com/netbirdio/netbird/management/server/posture"
"github.com/netbirdio/netbird/management/server/store" "github.com/netbirdio/netbird/management/server/store"
@ -172,6 +175,31 @@ func (mockDatasource) GetAllAccounts(_ context.Context) []*types.Account {
}, },
}, },
}, },
Networks: []*networkTypes.Network{
{
ID: "1",
AccountID: "1",
},
},
NetworkResources: []*resourceTypes.NetworkResource{
{
ID: "1",
AccountID: "1",
NetworkID: "1",
},
{
ID: "2",
AccountID: "1",
NetworkID: "1",
},
},
NetworkRouters: []*routerTypes.NetworkRouter{
{
ID: "1",
AccountID: "1",
NetworkID: "1",
},
},
}, },
} }
} }
@ -200,6 +228,15 @@ func TestGenerateProperties(t *testing.T) {
if properties["routes"] != 2 { if properties["routes"] != 2 {
t.Errorf("expected 2 routes, got %d", properties["routes"]) t.Errorf("expected 2 routes, got %d", properties["routes"])
} }
if properties["networks"] != 1 {
t.Errorf("expected 1 networks, got %d", properties["networks"])
}
if properties["network_resources"] != 2 {
t.Errorf("expected 2 network_resources, got %d", properties["network_resources"])
}
if properties["network_routers"] != 1 {
t.Errorf("expected 1 network_routers, got %d", properties["network_routers"])
}
if properties["rules"] != 4 { if properties["rules"] != 4 {
t.Errorf("expected 4 rules, got %d", properties["rules"]) t.Errorf("expected 4 rules, got %d", properties["rules"])
} }

View File

@ -32,6 +32,9 @@ type managerImpl struct {
routersManager routers.Manager routersManager routers.Manager
} }
type mockManager struct {
}
func NewManager(store store.Store, permissionsManager permissions.Manager, resourceManager resources.Manager, routersManager routers.Manager, accountManager s.AccountManager) Manager { func NewManager(store store.Store, permissionsManager permissions.Manager, resourceManager resources.Manager, routersManager routers.Manager, accountManager s.AccountManager) Manager {
return &managerImpl{ return &managerImpl{
store: store, store: store,
@ -185,3 +188,27 @@ func (m *managerImpl) DeleteNetwork(ctx context.Context, accountID, userID, netw
return nil return nil
} }
func NewManagerMock() Manager {
return &mockManager{}
}
func (m *mockManager) GetAllNetworks(ctx context.Context, accountID, userID string) ([]*types.Network, error) {
return []*types.Network{}, nil
}
func (m *mockManager) CreateNetwork(ctx context.Context, userID string, network *types.Network) (*types.Network, error) {
return network, nil
}
func (m *mockManager) GetNetwork(ctx context.Context, accountID, userID, networkID string) (*types.Network, error) {
return &types.Network{}, nil
}
func (m *mockManager) UpdateNetwork(ctx context.Context, userID string, network *types.Network) (*types.Network, error) {
return network, nil
}
func (m *mockManager) DeleteNetwork(ctx context.Context, accountID, userID, networkID string) error {
return nil
}

View File

@ -34,6 +34,9 @@ type managerImpl struct {
accountManager s.AccountManager accountManager s.AccountManager
} }
type mockManager struct {
}
func NewManager(store store.Store, permissionsManager permissions.Manager, groupsManager groups.Manager, accountManager s.AccountManager) Manager { func NewManager(store store.Store, permissionsManager permissions.Manager, groupsManager groups.Manager, accountManager s.AccountManager) Manager {
return &managerImpl{ return &managerImpl{
store: store, store: store,
@ -381,3 +384,39 @@ func (m *managerImpl) DeleteResourceInTransaction(ctx context.Context, transacti
return eventsToStore, nil return eventsToStore, nil
} }
func NewManagerMock() Manager {
return &mockManager{}
}
func (m *mockManager) GetAllResourcesInNetwork(ctx context.Context, accountID, userID, networkID string) ([]*types.NetworkResource, error) {
return []*types.NetworkResource{}, nil
}
func (m *mockManager) GetAllResourcesInAccount(ctx context.Context, accountID, userID string) ([]*types.NetworkResource, error) {
return []*types.NetworkResource{}, nil
}
func (m *mockManager) GetAllResourceIDsInAccount(ctx context.Context, accountID, userID string) (map[string][]string, error) {
return map[string][]string{}, nil
}
func (m *mockManager) CreateResource(ctx context.Context, userID string, resource *types.NetworkResource) (*types.NetworkResource, error) {
return &types.NetworkResource{}, nil
}
func (m *mockManager) GetResource(ctx context.Context, accountID, userID, networkID, resourceID string) (*types.NetworkResource, error) {
return &types.NetworkResource{}, nil
}
func (m *mockManager) UpdateResource(ctx context.Context, userID string, resource *types.NetworkResource) (*types.NetworkResource, error) {
return &types.NetworkResource{}, nil
}
func (m *mockManager) DeleteResource(ctx context.Context, accountID, userID, networkID, resourceID string) error {
return nil
}
func (m *mockManager) DeleteResourceInTransaction(ctx context.Context, transaction store.Store, accountID, userID, networkID, resourceID string) ([]func(), error) {
return []func(){}, nil
}

View File

@ -111,6 +111,7 @@ func (n *NetworkResource) ToRoute(peer *nbpeer.Peer, router *routerTypes.Network
NetID: route.NetID(n.Name), NetID: route.NetID(n.Name),
Description: n.Description, Description: n.Description,
Peer: peer.Key, Peer: peer.Key,
PeerID: peer.ID,
PeerGroups: nil, PeerGroups: nil,
Masquerade: router.Masquerade, Masquerade: router.Masquerade,
Metric: router.Metric, Metric: router.Metric,

View File

@ -932,11 +932,11 @@ func BenchmarkUpdateAccountPeers(b *testing.B) {
}{ }{
{"Small", 50, 5, 90, 120, 90, 120}, {"Small", 50, 5, 90, 120, 90, 120},
{"Medium", 500, 100, 110, 150, 120, 260}, {"Medium", 500, 100, 110, 150, 120, 260},
{"Large", 5000, 200, 800, 1390, 2500, 4600}, {"Large", 5000, 200, 800, 1700, 2500, 5000},
{"Small single", 50, 10, 90, 120, 90, 120}, {"Small single", 50, 10, 90, 120, 90, 120},
{"Medium single", 500, 10, 110, 170, 120, 200}, {"Medium single", 500, 10, 110, 170, 120, 200},
{"Large 5", 5000, 15, 1300, 2100, 5000, 7000}, {"Large 5", 5000, 15, 1300, 2100, 4900, 7000},
{"Extra Large", 2000, 2000, 1300, 2100, 4000, 6000}, {"Extra Large", 2000, 2000, 1300, 2400, 4000, 6400},
} }
log.SetOutput(io.Discard) log.SetOutput(io.Discard)

View File

@ -74,6 +74,19 @@ func TestAccount_getPeersByPolicy(t *testing.T) {
"peerH", "peerH",
}, },
}, },
"GroupWorkstations": {
ID: "GroupWorkstations",
Name: "GroupWorkstations",
Peers: []string{
"peerB",
"peerA",
"peerD",
"peerE",
"peerF",
"peerG",
"peerH",
},
},
"GroupSwarm": { "GroupSwarm": {
ID: "GroupSwarm", ID: "GroupSwarm",
Name: "swarm", Name: "swarm",
@ -127,7 +140,7 @@ func TestAccount_getPeersByPolicy(t *testing.T) {
Action: types.PolicyTrafficActionAccept, Action: types.PolicyTrafficActionAccept,
Sources: []string{ Sources: []string{
"GroupSwarm", "GroupSwarm",
"GroupAll", "GroupWorkstations",
}, },
Destinations: []string{ Destinations: []string{
"GroupSwarm", "GroupSwarm",
@ -159,6 +172,8 @@ func TestAccount_getPeersByPolicy(t *testing.T) {
assert.Contains(t, peers, account.Peers["peerD"]) assert.Contains(t, peers, account.Peers["peerD"])
assert.Contains(t, peers, account.Peers["peerE"]) assert.Contains(t, peers, account.Peers["peerE"])
assert.Contains(t, peers, account.Peers["peerF"]) assert.Contains(t, peers, account.Peers["peerF"])
assert.Contains(t, peers, account.Peers["peerG"])
assert.Contains(t, peers, account.Peers["peerH"])
epectedFirewallRules := []*types.FirewallRule{ epectedFirewallRules := []*types.FirewallRule{
{ {
@ -189,21 +204,6 @@ func TestAccount_getPeersByPolicy(t *testing.T) {
Protocol: "all", Protocol: "all",
Port: "", Port: "",
}, },
{
PeerIP: "100.65.254.139",
Direction: types.FirewallRuleDirectionOUT,
Action: "accept",
Protocol: "all",
Port: "",
},
{
PeerIP: "100.65.254.139",
Direction: types.FirewallRuleDirectionIN,
Action: "accept",
Protocol: "all",
Port: "",
},
{ {
PeerIP: "100.65.62.5", PeerIP: "100.65.62.5",
Direction: types.FirewallRuleDirectionOUT, Direction: types.FirewallRuleDirectionOUT,
@ -280,10 +280,16 @@ func TestAccount_getPeersByPolicy(t *testing.T) {
}, },
} }
assert.Len(t, firewallRules, len(epectedFirewallRules)) assert.Len(t, firewallRules, len(epectedFirewallRules))
slices.SortFunc(epectedFirewallRules, sortFunc())
slices.SortFunc(firewallRules, sortFunc()) for _, rule := range firewallRules {
for i := range firewallRules { contains := false
assert.Equal(t, epectedFirewallRules[i], firewallRules[i]) for _, expectedRule := range epectedFirewallRules {
if rule.IsEqual(expectedRule) {
contains = true
break
}
}
assert.True(t, contains, "rule not found in expected rules %#v", rule)
} }
}) })
} }

View File

@ -364,7 +364,7 @@ func toProtocolRoute(route *route.Route) *proto.Route {
} }
func toProtocolRoutes(routes []*route.Route) []*proto.Route { func toProtocolRoutes(routes []*route.Route) []*proto.Route {
protoRoutes := make([]*proto.Route, 0) protoRoutes := make([]*proto.Route, 0, len(routes))
for _, r := range routes { for _, r := range routes {
protoRoutes = append(protoRoutes, toProtocolRoute(r)) protoRoutes = append(protoRoutes, toProtocolRoute(r))
} }

View File

@ -75,7 +75,7 @@ func (am *DefaultAccountManager) CreateSetupKey(ctx context.Context, accountID s
err = am.Store.ExecuteInTransaction(ctx, func(transaction store.Store) error { err = am.Store.ExecuteInTransaction(ctx, func(transaction store.Store) error {
if err = validateSetupKeyAutoGroups(ctx, transaction, accountID, autoGroups); err != nil { if err = validateSetupKeyAutoGroups(ctx, transaction, accountID, autoGroups); err != nil {
return err return status.Errorf(status.InvalidArgument, "invalid auto groups: %v", err)
} }
setupKey, plainKey = types.GenerateSetupKey(keyName, keyType, expiresIn, autoGroups, usageLimit, ephemeral) setupKey, plainKey = types.GenerateSetupKey(keyName, keyType, expiresIn, autoGroups, usageLimit, ephemeral)
@ -132,7 +132,7 @@ func (am *DefaultAccountManager) SaveSetupKey(ctx context.Context, accountID str
err = am.Store.ExecuteInTransaction(ctx, func(transaction store.Store) error { err = am.Store.ExecuteInTransaction(ctx, func(transaction store.Store) error {
if err = validateSetupKeyAutoGroups(ctx, transaction, accountID, keyToSave.AutoGroups); err != nil { if err = validateSetupKeyAutoGroups(ctx, transaction, accountID, keyToSave.AutoGroups); err != nil {
return err return status.Errorf(status.InvalidArgument, "invalid auto groups: %v", err)
} }
oldKey, err = transaction.GetSetupKeyByID(ctx, store.LockingStrengthShare, accountID, keyToSave.Id) oldKey, err = transaction.GetSetupKeyByID(ctx, store.LockingStrengthShare, accountID, keyToSave.Id)

View File

@ -303,55 +303,47 @@ func (a *Account) GetPeerNetworkMap(
return nm return nm
} }
func (a *Account) addNetworksRoutingPeers(networkResourcesRoutes []*route.Route, peer *nbpeer.Peer, peersToConnect []*nbpeer.Peer, expiredPeers []*nbpeer.Peer, isRouter bool, sourcePeers []string) []*nbpeer.Peer { func (a *Account) addNetworksRoutingPeers(
missingPeers := map[string]struct{}{} networkResourcesRoutes []*route.Route,
peer *nbpeer.Peer,
peersToConnect []*nbpeer.Peer,
expiredPeers []*nbpeer.Peer,
isRouter bool,
sourcePeers map[string]struct{},
) []*nbpeer.Peer {
networkRoutesPeers := make(map[string]struct{}, len(networkResourcesRoutes))
for _, r := range networkResourcesRoutes { for _, r := range networkResourcesRoutes {
if r.Peer == peer.Key { networkRoutesPeers[r.PeerID] = struct{}{}
continue
} }
missing := true delete(sourcePeers, peer.ID)
for _, p := range slices.Concat(peersToConnect, expiredPeers) {
if r.Peer == p.Key { for _, existingPeer := range peersToConnect {
missing = false delete(sourcePeers, existingPeer.ID)
break delete(networkRoutesPeers, existingPeer.ID)
}
}
if missing {
missingPeers[r.Peer] = struct{}{}
} }
for _, expPeer := range expiredPeers {
delete(sourcePeers, expPeer.ID)
delete(networkRoutesPeers, expPeer.ID)
} }
missingPeers := make(map[string]struct{}, len(sourcePeers)+len(networkRoutesPeers))
if isRouter { if isRouter {
for _, s := range sourcePeers { for p := range sourcePeers {
if s == peer.ID { missingPeers[p] = struct{}{}
continue
}
missing := true
for _, p := range slices.Concat(peersToConnect, expiredPeers) {
if s == p.ID {
missing = false
break
}
}
if missing {
p, ok := a.Peers[s]
if ok {
missingPeers[p.Key] = struct{}{}
}
} }
} }
for p := range networkRoutesPeers {
missingPeers[p] = struct{}{}
} }
for p := range missingPeers { for p := range missingPeers {
for _, p2 := range a.Peers { if missingPeer := a.Peers[p]; missingPeer != nil {
if p2.Key == p { peersToConnect = append(peersToConnect, missingPeer)
peersToConnect = append(peersToConnect, p2)
break
}
} }
} }
return peersToConnect return peersToConnect
} }
@ -1045,14 +1037,9 @@ func (a *Account) connResourcesGenerator(ctx context.Context) (func(*PolicyRule,
// for destination group peers, call this method with an empty list of sourcePostureChecksIDs // for destination group peers, call this method with an empty list of sourcePostureChecksIDs
func (a *Account) getAllPeersFromGroups(ctx context.Context, groups []string, peerID string, sourcePostureChecksIDs []string, validatedPeersMap map[string]struct{}) ([]*nbpeer.Peer, bool) { func (a *Account) getAllPeersFromGroups(ctx context.Context, groups []string, peerID string, sourcePostureChecksIDs []string, validatedPeersMap map[string]struct{}) ([]*nbpeer.Peer, bool) {
peerInGroups := false peerInGroups := false
filteredPeers := make([]*nbpeer.Peer, 0, len(groups)) uniquePeerIDs := a.getUniquePeerIDsFromGroupsIDs(ctx, groups)
for _, g := range groups { filteredPeers := make([]*nbpeer.Peer, 0, len(uniquePeerIDs))
group, ok := a.Groups[g] for _, p := range uniquePeerIDs {
if !ok {
continue
}
for _, p := range group.Peers {
peer, ok := a.Peers[p] peer, ok := a.Peers[p]
if !ok || peer == nil { if !ok || peer == nil {
continue continue
@ -1075,7 +1062,7 @@ func (a *Account) getAllPeersFromGroups(ctx context.Context, groups []string, pe
filteredPeers = append(filteredPeers, peer) filteredPeers = append(filteredPeers, peer)
} }
}
return filteredPeers, peerInGroups return filteredPeers, peerInGroups
} }
@ -1151,7 +1138,7 @@ func (a *Account) getRouteFirewallRules(ctx context.Context, peerID string, poli
continue continue
} }
rulePeers := a.getRulePeers(rule, peerID, distributionPeers, validatedPeersMap) rulePeers := a.getRulePeers(rule, policy.SourcePostureChecks, peerID, distributionPeers, validatedPeersMap)
rules := generateRouteFirewallRules(ctx, route, rule, rulePeers, FirewallRuleDirectionIN) rules := generateRouteFirewallRules(ctx, route, rule, rulePeers, FirewallRuleDirectionIN)
fwRules = append(fwRules, rules...) fwRules = append(fwRules, rules...)
} }
@ -1159,7 +1146,7 @@ func (a *Account) getRouteFirewallRules(ctx context.Context, peerID string, poli
return fwRules return fwRules
} }
func (a *Account) getRulePeers(rule *PolicyRule, peerID string, distributionPeers map[string]struct{}, validatedPeersMap map[string]struct{}) []*nbpeer.Peer { func (a *Account) getRulePeers(rule *PolicyRule, postureChecks []string, peerID string, distributionPeers map[string]struct{}, validatedPeersMap map[string]struct{}) []*nbpeer.Peer {
distPeersWithPolicy := make(map[string]struct{}) distPeersWithPolicy := make(map[string]struct{})
for _, id := range rule.Sources { for _, id := range rule.Sources {
group := a.Groups[id] group := a.Groups[id]
@ -1173,7 +1160,7 @@ func (a *Account) getRulePeers(rule *PolicyRule, peerID string, distributionPeer
} }
_, distPeer := distributionPeers[pID] _, distPeer := distributionPeers[pID]
_, valid := validatedPeersMap[pID] _, valid := validatedPeersMap[pID]
if distPeer && valid { if distPeer && valid && a.validatePostureChecksOnPeer(context.Background(), postureChecks, pID) {
distPeersWithPolicy[pID] = struct{}{} distPeersWithPolicy[pID] = struct{}{}
} }
} }
@ -1271,7 +1258,11 @@ func (a *Account) GetPeerNetworkResourceFirewallRules(ctx context.Context, peer
distributionPeers := getPoliciesSourcePeers(resourceAppliedPolicies, a.Groups) distributionPeers := getPoliciesSourcePeers(resourceAppliedPolicies, a.Groups)
rules := a.getRouteFirewallRules(ctx, peer.ID, resourceAppliedPolicies, route, validatedPeersMap, distributionPeers) rules := a.getRouteFirewallRules(ctx, peer.ID, resourceAppliedPolicies, route, validatedPeersMap, distributionPeers)
routesFirewallRules = append(routesFirewallRules, rules...) for _, rule := range rules {
if len(rule.SourceRanges) > 0 {
routesFirewallRules = append(routesFirewallRules, rule)
}
}
} }
return routesFirewallRules return routesFirewallRules
@ -1303,10 +1294,10 @@ func (a *Account) GetResourcePoliciesMap() map[string][]*Policy {
} }
// GetNetworkResourcesRoutesToSync returns network routes for syncing with a specific peer and its ACL peers. // GetNetworkResourcesRoutesToSync returns network routes for syncing with a specific peer and its ACL peers.
func (a *Account) GetNetworkResourcesRoutesToSync(ctx context.Context, peerID string, resourcePolicies map[string][]*Policy, routers map[string]map[string]*routerTypes.NetworkRouter) (bool, []*route.Route, []string) { func (a *Account) GetNetworkResourcesRoutesToSync(ctx context.Context, peerID string, resourcePolicies map[string][]*Policy, routers map[string]map[string]*routerTypes.NetworkRouter) (bool, []*route.Route, map[string]struct{}) {
var isRoutingPeer bool var isRoutingPeer bool
var routes []*route.Route var routes []*route.Route
var allSourcePeers []string allSourcePeers := make(map[string]struct{}, len(a.Peers))
for _, resource := range a.NetworkResources { for _, resource := range a.NetworkResources {
var addSourcePeers bool var addSourcePeers bool
@ -1319,23 +1310,22 @@ func (a *Account) GetNetworkResourcesRoutesToSync(ctx context.Context, peerID st
} }
} }
addedResourceRoute := false
for _, policy := range resourcePolicies[resource.ID] { for _, policy := range resourcePolicies[resource.ID] {
for _, sourceGroup := range policy.SourceGroups() { peers := a.getUniquePeerIDsFromGroupsIDs(ctx, policy.SourceGroups())
group := a.GetGroup(sourceGroup)
if group == nil {
log.WithContext(ctx).Warnf("policy %s has source group %s that doesn't exist under account %s, will continue map generation without it", policy.ID, sourceGroup, a.Id)
continue
}
// routing peer should be able to connect with all source peers
if addSourcePeers { if addSourcePeers {
allSourcePeers = append(allSourcePeers, group.Peers...) for _, pID := range a.getPostureValidPeers(peers, policy.SourcePostureChecks) {
} else if slices.Contains(group.Peers, peerID) { allSourcePeers[pID] = struct{}{}
}
} else if slices.Contains(peers, peerID) && a.validatePostureChecksOnPeer(ctx, policy.SourcePostureChecks, peerID) {
// add routes for the resource if the peer is in the distribution group // add routes for the resource if the peer is in the distribution group
for peerId, router := range networkRoutingPeers { for peerId, router := range networkRoutingPeers {
routes = append(routes, a.getNetworkResourcesRoutes(resource, peerId, router, resourcePolicies)...) routes = append(routes, a.getNetworkResourcesRoutes(resource, peerId, router, resourcePolicies)...)
} }
addedResourceRoute = true
} }
if addedResourceRoute {
break
} }
} }
} }
@ -1343,6 +1333,42 @@ func (a *Account) GetNetworkResourcesRoutesToSync(ctx context.Context, peerID st
return isRoutingPeer, routes, allSourcePeers return isRoutingPeer, routes, allSourcePeers
} }
func (a *Account) getPostureValidPeers(inputPeers []string, postureChecksIDs []string) []string {
var dest []string
for _, peerID := range inputPeers {
if a.validatePostureChecksOnPeer(context.Background(), postureChecksIDs, peerID) {
dest = append(dest, peerID)
}
}
return dest
}
func (a *Account) getUniquePeerIDsFromGroupsIDs(ctx context.Context, groups []string) []string {
peerIDs := make(map[string]struct{}, len(groups)) // we expect at least one peer per group as initial capacity
for _, groupID := range groups {
group := a.GetGroup(groupID)
if group == nil {
log.WithContext(ctx).Warnf("group %s doesn't exist under account %s, will continue map generation without it", groupID, a.Id)
continue
}
if group.IsGroupAll() || len(groups) == 1 {
return group.Peers
}
for _, peerID := range group.Peers {
peerIDs[peerID] = struct{}{}
}
}
ids := make([]string, 0, len(peerIDs))
for peerID := range peerIDs {
ids = append(ids, peerID)
}
return ids
}
// getNetworkResources filters and returns a list of network resources associated with the given network ID. // getNetworkResources filters and returns a list of network resources associated with the given network ID.
func (a *Account) getNetworkResources(networkID string) []*resourceTypes.NetworkResource { func (a *Account) getNetworkResources(networkID string) []*resourceTypes.NetworkResource {
var resources []*resourceTypes.NetworkResource var resources []*resourceTypes.NetworkResource

View File

@ -1,14 +1,20 @@
package types package types
import ( import (
"context"
"net"
"net/netip"
"slices"
"testing" "testing"
"github.com/stretchr/testify/assert"
"github.com/stretchr/testify/require" "github.com/stretchr/testify/require"
resourceTypes "github.com/netbirdio/netbird/management/server/networks/resources/types" resourceTypes "github.com/netbirdio/netbird/management/server/networks/resources/types"
routerTypes "github.com/netbirdio/netbird/management/server/networks/routers/types" routerTypes "github.com/netbirdio/netbird/management/server/networks/routers/types"
networkTypes "github.com/netbirdio/netbird/management/server/networks/types" networkTypes "github.com/netbirdio/netbird/management/server/networks/types"
nbpeer "github.com/netbirdio/netbird/management/server/peer" nbpeer "github.com/netbirdio/netbird/management/server/peer"
"github.com/netbirdio/netbird/management/server/posture"
"github.com/netbirdio/netbird/route" "github.com/netbirdio/netbird/route"
) )
@ -310,19 +316,19 @@ func Test_GetResourcePoliciesMap(t *testing.T) {
func Test_AddNetworksRoutingPeersAddsMissingPeers(t *testing.T) { func Test_AddNetworksRoutingPeersAddsMissingPeers(t *testing.T) {
account := setupTestAccount() account := setupTestAccount()
peer := &nbpeer.Peer{Key: "peer1"} peer := &nbpeer.Peer{Key: "peer1Key", ID: "peer1"}
networkResourcesRoutes := []*route.Route{ networkResourcesRoutes := []*route.Route{
{Peer: "peer2Key"}, {Peer: "peer2Key", PeerID: "peer2"},
{Peer: "peer3Key"}, {Peer: "peer3Key", PeerID: "peer3"},
} }
peersToConnect := []*nbpeer.Peer{ peersToConnect := []*nbpeer.Peer{
{Key: "peer2Key"}, {Key: "peer2Key", ID: "peer2"},
} }
expiredPeers := []*nbpeer.Peer{ expiredPeers := []*nbpeer.Peer{
{Key: "peer4Key"}, {Key: "peer4Key", ID: "peer4"},
} }
result := account.addNetworksRoutingPeers(networkResourcesRoutes, peer, peersToConnect, expiredPeers, false, []string{}) result := account.addNetworksRoutingPeers(networkResourcesRoutes, peer, peersToConnect, expiredPeers, false, map[string]struct{}{})
require.Len(t, result, 2) require.Len(t, result, 2)
require.Equal(t, "peer2Key", result[0].Key) require.Equal(t, "peer2Key", result[0].Key)
require.Equal(t, "peer3Key", result[1].Key) require.Equal(t, "peer3Key", result[1].Key)
@ -339,7 +345,7 @@ func Test_AddNetworksRoutingPeersIgnoresExistingPeers(t *testing.T) {
} }
expiredPeers := []*nbpeer.Peer{} expiredPeers := []*nbpeer.Peer{}
result := account.addNetworksRoutingPeers(networkResourcesRoutes, peer, peersToConnect, expiredPeers, false, []string{}) result := account.addNetworksRoutingPeers(networkResourcesRoutes, peer, peersToConnect, expiredPeers, false, map[string]struct{}{})
require.Len(t, result, 1) require.Len(t, result, 1)
require.Equal(t, "peer2Key", result[0].Key) require.Equal(t, "peer2Key", result[0].Key)
} }
@ -358,7 +364,7 @@ func Test_AddNetworksRoutingPeersAddsExpiredPeers(t *testing.T) {
{Key: "peer3Key"}, {Key: "peer3Key"},
} }
result := account.addNetworksRoutingPeers(networkResourcesRoutes, peer, peersToConnect, expiredPeers, false, []string{}) result := account.addNetworksRoutingPeers(networkResourcesRoutes, peer, peersToConnect, expiredPeers, false, map[string]struct{}{})
require.Len(t, result, 1) require.Len(t, result, 1)
require.Equal(t, "peer2Key", result[0].Key) require.Equal(t, "peer2Key", result[0].Key)
} }
@ -370,6 +376,382 @@ func Test_AddNetworksRoutingPeersHandlesNoMissingPeers(t *testing.T) {
peersToConnect := []*nbpeer.Peer{} peersToConnect := []*nbpeer.Peer{}
expiredPeers := []*nbpeer.Peer{} expiredPeers := []*nbpeer.Peer{}
result := account.addNetworksRoutingPeers(networkResourcesRoutes, peer, peersToConnect, expiredPeers, false, []string{}) result := account.addNetworksRoutingPeers(networkResourcesRoutes, peer, peersToConnect, expiredPeers, false, map[string]struct{}{})
require.Len(t, result, 0) require.Len(t, result, 0)
} }
const (
accID = "accountID"
network1ID = "network1ID"
group1ID = "group1"
accNetResourcePeer1ID = "peer1"
accNetResourcePeer2ID = "peer2"
accNetResourceRouter1ID = "router1"
accNetResource1ID = "resource1ID"
accNetResourceRestrictPostureCheckID = "restrictPostureCheck"
accNetResourceRelaxedPostureCheckID = "relaxedPostureCheck"
accNetResourceLockedPostureCheckID = "lockedPostureCheck"
accNetResourceLinuxPostureCheckID = "linuxPostureCheck"
)
var (
accNetResourcePeer1IP = net.IP{192, 168, 1, 1}
accNetResourcePeer2IP = net.IP{192, 168, 1, 2}
accNetResourceRouter1IP = net.IP{192, 168, 1, 3}
accNetResourceValidPeers = map[string]struct{}{accNetResourcePeer1ID: {}, accNetResourcePeer2ID: {}}
)
func getBasicAccountsWithResource() *Account {
return &Account{
Id: accID,
Peers: map[string]*nbpeer.Peer{
accNetResourcePeer1ID: {
ID: accNetResourcePeer1ID,
AccountID: accID,
Key: "peer1Key",
IP: accNetResourcePeer1IP,
Meta: nbpeer.PeerSystemMeta{
GoOS: "linux",
WtVersion: "0.35.1",
KernelVersion: "4.4.0",
},
},
accNetResourcePeer2ID: {
ID: accNetResourcePeer2ID,
AccountID: accID,
Key: "peer2Key",
IP: accNetResourcePeer2IP,
Meta: nbpeer.PeerSystemMeta{
GoOS: "windows",
WtVersion: "0.34.1",
KernelVersion: "4.4.0",
},
},
accNetResourceRouter1ID: {
ID: accNetResourceRouter1ID,
AccountID: accID,
Key: "router1Key",
IP: accNetResourceRouter1IP,
Meta: nbpeer.PeerSystemMeta{
GoOS: "linux",
WtVersion: "0.35.1",
KernelVersion: "4.4.0",
},
},
},
Groups: map[string]*Group{
group1ID: {
ID: group1ID,
Peers: []string{accNetResourcePeer1ID, accNetResourcePeer2ID},
},
},
Networks: []*networkTypes.Network{
{
ID: network1ID,
AccountID: accID,
Name: "network1",
},
},
NetworkRouters: []*routerTypes.NetworkRouter{
{
ID: accNetResourceRouter1ID,
NetworkID: network1ID,
AccountID: accID,
Peer: accNetResourceRouter1ID,
PeerGroups: []string{},
Masquerade: false,
Metric: 100,
},
},
NetworkResources: []*resourceTypes.NetworkResource{
{
ID: accNetResource1ID,
AccountID: accID,
NetworkID: network1ID,
Address: "10.10.10.0/24",
Prefix: netip.MustParsePrefix("10.10.10.0/24"),
Type: resourceTypes.NetworkResourceType("subnet"),
},
},
Policies: []*Policy{
{
ID: "policy1ID",
AccountID: accID,
Enabled: true,
Rules: []*PolicyRule{
{
ID: "rule1ID",
Enabled: true,
Sources: []string{group1ID},
DestinationResource: Resource{
ID: accNetResource1ID,
Type: "Host",
},
Protocol: PolicyRuleProtocolTCP,
Ports: []string{"80"},
Action: PolicyTrafficActionAccept,
},
},
SourcePostureChecks: nil,
},
},
PostureChecks: []*posture.Checks{
{
ID: accNetResourceRestrictPostureCheckID,
Name: accNetResourceRestrictPostureCheckID,
Checks: posture.ChecksDefinition{
NBVersionCheck: &posture.NBVersionCheck{
MinVersion: "0.35.0",
},
},
},
{
ID: accNetResourceRelaxedPostureCheckID,
Name: accNetResourceRelaxedPostureCheckID,
Checks: posture.ChecksDefinition{
NBVersionCheck: &posture.NBVersionCheck{
MinVersion: "0.0.1",
},
},
},
{
ID: accNetResourceLockedPostureCheckID,
Name: accNetResourceLockedPostureCheckID,
Checks: posture.ChecksDefinition{
NBVersionCheck: &posture.NBVersionCheck{
MinVersion: "7.7.7",
},
},
},
{
ID: accNetResourceLinuxPostureCheckID,
Name: accNetResourceLinuxPostureCheckID,
Checks: posture.ChecksDefinition{
OSVersionCheck: &posture.OSVersionCheck{
Linux: &posture.MinKernelVersionCheck{
MinKernelVersion: "0.0.0"},
},
},
},
},
}
}
func Test_NetworksNetMapGenWithNoPostureChecks(t *testing.T) {
account := getBasicAccountsWithResource()
// all peers should match the policy
// validate for peer1
isRouter, networkResourcesRoutes, sourcePeers := account.GetNetworkResourcesRoutesToSync(context.Background(), accNetResourcePeer1ID, account.GetResourcePoliciesMap(), account.GetResourceRoutersMap())
assert.False(t, isRouter, "expected router status")
assert.Len(t, networkResourcesRoutes, 1, "expected network resource route don't match")
assert.Len(t, sourcePeers, 0, "expected source peers don't match")
// validate for peer2
isRouter, networkResourcesRoutes, sourcePeers = account.GetNetworkResourcesRoutesToSync(context.Background(), accNetResourcePeer2ID, account.GetResourcePoliciesMap(), account.GetResourceRoutersMap())
assert.False(t, isRouter, "expected router status")
assert.Len(t, networkResourcesRoutes, 1, "expected network resource route don't match")
assert.Len(t, sourcePeers, 0, "expected source peers don't match")
// validate routes for router1
isRouter, networkResourcesRoutes, sourcePeers = account.GetNetworkResourcesRoutesToSync(context.Background(), accNetResourceRouter1ID, account.GetResourcePoliciesMap(), account.GetResourceRoutersMap())
assert.True(t, isRouter, "should be router")
assert.Len(t, networkResourcesRoutes, 1, "expected network resource route don't match")
assert.Len(t, sourcePeers, 2, "expected source peers don't match")
assert.NotNil(t, sourcePeers[accNetResourcePeer1ID], "expected source peers don't match")
assert.NotNil(t, sourcePeers[accNetResourcePeer2ID], "expected source peers don't match")
// validate rules for router1
rules := account.GetPeerNetworkResourceFirewallRules(context.Background(), account.Peers[accNetResourceRouter1ID], accNetResourceValidPeers, networkResourcesRoutes, account.GetResourcePoliciesMap())
assert.Len(t, rules, 1, "expected rules count don't match")
assert.Equal(t, uint16(80), rules[0].Port, "should have port 80")
assert.Equal(t, "tcp", rules[0].Protocol, "should have protocol tcp")
if !slices.Contains(rules[0].SourceRanges, accNetResourcePeer1IP.String()+"/32") {
t.Errorf("%s should have source range of peer1 %s", rules[0].SourceRanges, accNetResourcePeer1IP.String())
}
if !slices.Contains(rules[0].SourceRanges, accNetResourcePeer2IP.String()+"/32") {
t.Errorf("%s should have source range of peer2 %s", rules[0].SourceRanges, accNetResourcePeer2IP.String())
}
}
func Test_NetworksNetMapGenWithPostureChecks(t *testing.T) {
account := getBasicAccountsWithResource()
// should allow peer1 to match the policy
policy := account.Policies[0]
policy.SourcePostureChecks = []string{accNetResourceRestrictPostureCheckID}
// validate for peer1
isRouter, networkResourcesRoutes, sourcePeers := account.GetNetworkResourcesRoutesToSync(context.Background(), accNetResourcePeer1ID, account.GetResourcePoliciesMap(), account.GetResourceRoutersMap())
assert.False(t, isRouter, "expected router status")
assert.Len(t, networkResourcesRoutes, 1, "expected network resource route don't match")
assert.Len(t, sourcePeers, 0, "expected source peers don't match")
// validate for peer2
isRouter, networkResourcesRoutes, sourcePeers = account.GetNetworkResourcesRoutesToSync(context.Background(), accNetResourcePeer2ID, account.GetResourcePoliciesMap(), account.GetResourceRoutersMap())
assert.False(t, isRouter, "expected router status")
assert.Len(t, networkResourcesRoutes, 0, "expected network resource route don't match")
assert.Len(t, sourcePeers, 0, "expected source peers don't match")
// validate routes for router1
isRouter, networkResourcesRoutes, sourcePeers = account.GetNetworkResourcesRoutesToSync(context.Background(), accNetResourceRouter1ID, account.GetResourcePoliciesMap(), account.GetResourceRoutersMap())
assert.True(t, isRouter, "should be router")
assert.Len(t, networkResourcesRoutes, 1, "expected network resource route don't match")
assert.Len(t, sourcePeers, 1, "expected source peers don't match")
assert.NotNil(t, sourcePeers[accNetResourcePeer1ID], "expected source peers don't match")
// validate rules for router1
rules := account.GetPeerNetworkResourceFirewallRules(context.Background(), account.Peers[accNetResourceRouter1ID], accNetResourceValidPeers, networkResourcesRoutes, account.GetResourcePoliciesMap())
assert.Len(t, rules, 1, "expected rules count don't match")
assert.Equal(t, uint16(80), rules[0].Port, "should have port 80")
assert.Equal(t, "tcp", rules[0].Protocol, "should have protocol tcp")
if !slices.Contains(rules[0].SourceRanges, accNetResourcePeer1IP.String()+"/32") {
t.Errorf("%s should have source range of peer1 %s", rules[0].SourceRanges, accNetResourcePeer1IP.String())
}
if slices.Contains(rules[0].SourceRanges, accNetResourcePeer2IP.String()+"/32") {
t.Errorf("%s should not have source range of peer2 %s", rules[0].SourceRanges, accNetResourcePeer2IP.String())
}
}
func Test_NetworksNetMapGenWithNoMatchedPostureChecks(t *testing.T) {
account := getBasicAccountsWithResource()
// should not match any peer
policy := account.Policies[0]
policy.SourcePostureChecks = []string{accNetResourceLockedPostureCheckID}
// validate for peer1
isRouter, networkResourcesRoutes, sourcePeers := account.GetNetworkResourcesRoutesToSync(context.Background(), accNetResourcePeer1ID, account.GetResourcePoliciesMap(), account.GetResourceRoutersMap())
assert.False(t, isRouter, "expected router status")
assert.Len(t, networkResourcesRoutes, 0, "expected network resource route don't match")
assert.Len(t, sourcePeers, 0, "expected source peers don't match")
// validate for peer2
isRouter, networkResourcesRoutes, sourcePeers = account.GetNetworkResourcesRoutesToSync(context.Background(), accNetResourcePeer2ID, account.GetResourcePoliciesMap(), account.GetResourceRoutersMap())
assert.False(t, isRouter, "expected router status")
assert.Len(t, networkResourcesRoutes, 0, "expected network resource route don't match")
assert.Len(t, sourcePeers, 0, "expected source peers don't match")
// validate routes for router1
isRouter, networkResourcesRoutes, sourcePeers = account.GetNetworkResourcesRoutesToSync(context.Background(), accNetResourceRouter1ID, account.GetResourcePoliciesMap(), account.GetResourceRoutersMap())
assert.True(t, isRouter, "should be router")
assert.Len(t, networkResourcesRoutes, 1, "expected network resource route don't match")
assert.Len(t, sourcePeers, 0, "expected source peers don't match")
// validate rules for router1
rules := account.GetPeerNetworkResourceFirewallRules(context.Background(), account.Peers[accNetResourceRouter1ID], accNetResourceValidPeers, networkResourcesRoutes, account.GetResourcePoliciesMap())
assert.Len(t, rules, 0, "expected rules count don't match")
}
func Test_NetworksNetMapGenWithTwoPoliciesAndPostureChecks(t *testing.T) {
account := getBasicAccountsWithResource()
// should allow peer1 to match the policy
policy := account.Policies[0]
policy.SourcePostureChecks = []string{accNetResourceRestrictPostureCheckID}
// should allow peer1 and peer2 to match the policy
newPolicy := &Policy{
ID: "policy2ID",
AccountID: accID,
Enabled: true,
Rules: []*PolicyRule{
{
ID: "policy2ID",
Enabled: true,
Sources: []string{group1ID},
DestinationResource: Resource{
ID: accNetResource1ID,
Type: "Host",
},
Protocol: PolicyRuleProtocolTCP,
Ports: []string{"22"},
Action: PolicyTrafficActionAccept,
},
},
SourcePostureChecks: []string{accNetResourceRelaxedPostureCheckID},
}
account.Policies = append(account.Policies, newPolicy)
// validate for peer1
isRouter, networkResourcesRoutes, sourcePeers := account.GetNetworkResourcesRoutesToSync(context.Background(), accNetResourcePeer1ID, account.GetResourcePoliciesMap(), account.GetResourceRoutersMap())
assert.False(t, isRouter, "expected router status")
assert.Len(t, networkResourcesRoutes, 1, "expected network resource route don't match")
assert.Len(t, sourcePeers, 0, "expected source peers don't match")
// validate for peer2
isRouter, networkResourcesRoutes, sourcePeers = account.GetNetworkResourcesRoutesToSync(context.Background(), accNetResourcePeer2ID, account.GetResourcePoliciesMap(), account.GetResourceRoutersMap())
assert.False(t, isRouter, "expected router status")
assert.Len(t, networkResourcesRoutes, 1, "expected network resource route don't match")
assert.Len(t, sourcePeers, 0, "expected source peers don't match")
// validate routes for router1
isRouter, networkResourcesRoutes, sourcePeers = account.GetNetworkResourcesRoutesToSync(context.Background(), accNetResourceRouter1ID, account.GetResourcePoliciesMap(), account.GetResourceRoutersMap())
assert.True(t, isRouter, "should be router")
assert.Len(t, networkResourcesRoutes, 1, "expected network resource route don't match")
assert.Len(t, sourcePeers, 2, "expected source peers don't match")
assert.NotNil(t, sourcePeers[accNetResourcePeer1ID], "expected source peers don't match")
assert.NotNil(t, sourcePeers[accNetResourcePeer2ID], "expected source peers don't match")
// validate rules for router1
rules := account.GetPeerNetworkResourceFirewallRules(context.Background(), account.Peers[accNetResourceRouter1ID], accNetResourceValidPeers, networkResourcesRoutes, account.GetResourcePoliciesMap())
assert.Len(t, rules, 2, "expected rules count don't match")
assert.Equal(t, uint16(80), rules[0].Port, "should have port 80")
assert.Equal(t, "tcp", rules[0].Protocol, "should have protocol tcp")
if !slices.Contains(rules[0].SourceRanges, accNetResourcePeer1IP.String()+"/32") {
t.Errorf("%s should have source range of peer1 %s", rules[0].SourceRanges, accNetResourcePeer1IP.String())
}
if slices.Contains(rules[0].SourceRanges, accNetResourcePeer2IP.String()+"/32") {
t.Errorf("%s should not have source range of peer2 %s", rules[0].SourceRanges, accNetResourcePeer2IP.String())
}
assert.Equal(t, uint16(22), rules[1].Port, "should have port 22")
assert.Equal(t, "tcp", rules[1].Protocol, "should have protocol tcp")
if !slices.Contains(rules[1].SourceRanges, accNetResourcePeer1IP.String()+"/32") {
t.Errorf("%s should have source range of peer1 %s", rules[1].SourceRanges, accNetResourcePeer1IP.String())
}
if !slices.Contains(rules[1].SourceRanges, accNetResourcePeer2IP.String()+"/32") {
t.Errorf("%s should have source range of peer2 %s", rules[1].SourceRanges, accNetResourcePeer2IP.String())
}
}
func Test_NetworksNetMapGenWithTwoPostureChecks(t *testing.T) {
account := getBasicAccountsWithResource()
// two posture checks should match only the peers that match both checks
policy := account.Policies[0]
policy.SourcePostureChecks = []string{accNetResourceRelaxedPostureCheckID, accNetResourceLinuxPostureCheckID}
// validate for peer1
isRouter, networkResourcesRoutes, sourcePeers := account.GetNetworkResourcesRoutesToSync(context.Background(), accNetResourcePeer1ID, account.GetResourcePoliciesMap(), account.GetResourceRoutersMap())
assert.False(t, isRouter, "expected router status")
assert.Len(t, networkResourcesRoutes, 1, "expected network resource route don't match")
assert.Len(t, sourcePeers, 0, "expected source peers don't match")
// validate for peer2
isRouter, networkResourcesRoutes, sourcePeers = account.GetNetworkResourcesRoutesToSync(context.Background(), accNetResourcePeer2ID, account.GetResourcePoliciesMap(), account.GetResourceRoutersMap())
assert.False(t, isRouter, "expected router status")
assert.Len(t, networkResourcesRoutes, 0, "expected network resource route don't match")
assert.Len(t, sourcePeers, 0, "expected source peers don't match")
// validate routes for router1
isRouter, networkResourcesRoutes, sourcePeers = account.GetNetworkResourcesRoutesToSync(context.Background(), accNetResourceRouter1ID, account.GetResourcePoliciesMap(), account.GetResourceRoutersMap())
assert.True(t, isRouter, "should be router")
assert.Len(t, networkResourcesRoutes, 1, "expected network resource route don't match")
assert.Len(t, sourcePeers, 1, "expected source peers don't match")
assert.NotNil(t, sourcePeers[accNetResourcePeer1ID], "expected source peers don't match")
// validate rules for router1
rules := account.GetPeerNetworkResourceFirewallRules(context.Background(), account.Peers[accNetResourceRouter1ID], accNetResourceValidPeers, networkResourcesRoutes, account.GetResourcePoliciesMap())
assert.Len(t, rules, 1, "expected rules count don't match")
assert.Equal(t, uint16(80), rules[0].Port, "should have port 80")
assert.Equal(t, "tcp", rules[0].Protocol, "should have protocol tcp")
if !slices.Contains(rules[0].SourceRanges, accNetResourcePeer1IP.String()+"/32") {
t.Errorf("%s should have source range of peer1 %s", rules[0].SourceRanges, accNetResourcePeer1IP.String())
}
if slices.Contains(rules[0].SourceRanges, accNetResourcePeer2IP.String()+"/32") {
t.Errorf("%s should not have source range of peer2 %s", rules[0].SourceRanges, accNetResourcePeer2IP.String())
}
}

View File

@ -35,6 +35,15 @@ type FirewallRule struct {
Port string Port string
} }
// IsEqual checks if two firewall rules are equal.
func (r *FirewallRule) IsEqual(other *FirewallRule) bool {
return r.PeerIP == other.PeerIP &&
r.Direction == other.Direction &&
r.Action == other.Action &&
r.Protocol == other.Protocol &&
r.Port == other.Port
}
// generateRouteFirewallRules generates a list of firewall rules for a given route. // generateRouteFirewallRules generates a list of firewall rules for a given route.
func generateRouteFirewallRules(ctx context.Context, route *nbroute.Route, rule *PolicyRule, groupPeers []*nbpeer.Peer, direction int) []*RouteFirewallRule { func generateRouteFirewallRules(ctx context.Context, route *nbroute.Route, rule *PolicyRule, groupPeers []*nbpeer.Peer, direction int) []*RouteFirewallRule {
rulesExists := make(map[string]struct{}) rulesExists := make(map[string]struct{})

View File

@ -117,9 +117,20 @@ func (p *Policy) RuleGroups() []string {
// SourceGroups returns a slice of all unique source groups referenced in the policy's rules. // SourceGroups returns a slice of all unique source groups referenced in the policy's rules.
func (p *Policy) SourceGroups() []string { func (p *Policy) SourceGroups() []string {
groups := make([]string, 0) if len(p.Rules) == 1 {
for _, rule := range p.Rules { return p.Rules[0].Sources
groups = append(groups, rule.Sources...)
} }
return groups groups := make(map[string]struct{}, len(p.Rules))
for _, rule := range p.Rules {
for _, source := range rule.Sources {
groups[source] = struct{}{}
}
}
groupIDs := make([]string, 0, len(groups))
for groupID := range groups {
groupIDs = append(groupIDs, groupID)
}
return groupIDs
} }

View File

@ -95,6 +95,7 @@ type Route struct {
NetID NetID NetID NetID
Description string Description string
Peer string Peer string
PeerID string `gorm:"-"`
PeerGroups []string `gorm:"serializer:json"` PeerGroups []string `gorm:"serializer:json"`
NetworkType NetworkType NetworkType NetworkType
Masquerade bool Masquerade bool
@ -120,6 +121,7 @@ func (r *Route) Copy() *Route {
KeepRoute: r.KeepRoute, KeepRoute: r.KeepRoute,
NetworkType: r.NetworkType, NetworkType: r.NetworkType,
Peer: r.Peer, Peer: r.Peer,
PeerID: r.PeerID,
PeerGroups: slices.Clone(r.PeerGroups), PeerGroups: slices.Clone(r.PeerGroups),
Metric: r.Metric, Metric: r.Metric,
Masquerade: r.Masquerade, Masquerade: r.Masquerade,
@ -146,6 +148,7 @@ func (r *Route) IsEqual(other *Route) bool {
other.KeepRoute == r.KeepRoute && other.KeepRoute == r.KeepRoute &&
other.NetworkType == r.NetworkType && other.NetworkType == r.NetworkType &&
other.Peer == r.Peer && other.Peer == r.Peer &&
other.PeerID == r.PeerID &&
other.Metric == r.Metric && other.Metric == r.Metric &&
other.Masquerade == r.Masquerade && other.Masquerade == r.Masquerade &&
other.Enabled == r.Enabled && other.Enabled == r.Enabled &&