mirror of
https://github.com/netbirdio/netbird.git
synced 2025-06-20 09:47:49 +02:00
Merge branch 'main' into userspace-router
This commit is contained in:
commit
c3c6afa37b
2
.github/workflows/golangci-lint.yml
vendored
2
.github/workflows/golangci-lint.yml
vendored
@ -19,7 +19,7 @@ jobs:
|
|||||||
- name: codespell
|
- name: codespell
|
||||||
uses: codespell-project/actions-codespell@v2
|
uses: codespell-project/actions-codespell@v2
|
||||||
with:
|
with:
|
||||||
ignore_words_list: erro,clienta,hastable,iif,groupd
|
ignore_words_list: erro,clienta,hastable,iif,groupd,testin
|
||||||
skip: go.mod,go.sum
|
skip: go.mod,go.sum
|
||||||
only_warn: 1
|
only_warn: 1
|
||||||
golangci:
|
golangci:
|
||||||
|
@ -197,7 +197,7 @@ func (m *Manager) AllowNetbird() error {
|
|||||||
}
|
}
|
||||||
|
|
||||||
_, err := m.AddPeerFiltering(
|
_, err := m.AddPeerFiltering(
|
||||||
net.ParseIP("0.0.0.0"),
|
net.IP{0, 0, 0, 0},
|
||||||
"all",
|
"all",
|
||||||
nil,
|
nil,
|
||||||
nil,
|
nil,
|
||||||
|
@ -68,17 +68,16 @@ func (c *HandlerChain) AddHandler(pattern string, handler dns.Handler, priority
|
|||||||
c.mu.Lock()
|
c.mu.Lock()
|
||||||
defer c.mu.Unlock()
|
defer c.mu.Unlock()
|
||||||
|
|
||||||
|
pattern = strings.ToLower(dns.Fqdn(pattern))
|
||||||
origPattern := pattern
|
origPattern := pattern
|
||||||
isWildcard := strings.HasPrefix(pattern, "*.")
|
isWildcard := strings.HasPrefix(pattern, "*.")
|
||||||
if isWildcard {
|
if isWildcard {
|
||||||
pattern = pattern[2:]
|
pattern = pattern[2:]
|
||||||
}
|
}
|
||||||
pattern = dns.Fqdn(pattern)
|
|
||||||
origPattern = dns.Fqdn(origPattern)
|
|
||||||
|
|
||||||
// First remove any existing handler with same original pattern and priority
|
// First remove any existing handler with same pattern (case-insensitive) and priority
|
||||||
for i := len(c.handlers) - 1; i >= 0; i-- {
|
for i := len(c.handlers) - 1; i >= 0; i-- {
|
||||||
if c.handlers[i].OrigPattern == origPattern && c.handlers[i].Priority == priority {
|
if strings.EqualFold(c.handlers[i].OrigPattern, origPattern) && c.handlers[i].Priority == priority {
|
||||||
if c.handlers[i].StopHandler != nil {
|
if c.handlers[i].StopHandler != nil {
|
||||||
c.handlers[i].StopHandler.stop()
|
c.handlers[i].StopHandler.stop()
|
||||||
}
|
}
|
||||||
@ -126,10 +125,10 @@ func (c *HandlerChain) RemoveHandler(pattern string, priority int) {
|
|||||||
|
|
||||||
pattern = dns.Fqdn(pattern)
|
pattern = dns.Fqdn(pattern)
|
||||||
|
|
||||||
// Find and remove handlers matching both original pattern and priority
|
// Find and remove handlers matching both original pattern (case-insensitive) and priority
|
||||||
for i := len(c.handlers) - 1; i >= 0; i-- {
|
for i := len(c.handlers) - 1; i >= 0; i-- {
|
||||||
entry := c.handlers[i]
|
entry := c.handlers[i]
|
||||||
if entry.OrigPattern == pattern && entry.Priority == priority {
|
if strings.EqualFold(entry.OrigPattern, pattern) && entry.Priority == priority {
|
||||||
if entry.StopHandler != nil {
|
if entry.StopHandler != nil {
|
||||||
entry.StopHandler.stop()
|
entry.StopHandler.stop()
|
||||||
}
|
}
|
||||||
@ -144,9 +143,9 @@ func (c *HandlerChain) HasHandlers(pattern string) bool {
|
|||||||
c.mu.RLock()
|
c.mu.RLock()
|
||||||
defer c.mu.RUnlock()
|
defer c.mu.RUnlock()
|
||||||
|
|
||||||
pattern = dns.Fqdn(pattern)
|
pattern = strings.ToLower(dns.Fqdn(pattern))
|
||||||
for _, entry := range c.handlers {
|
for _, entry := range c.handlers {
|
||||||
if entry.Pattern == pattern {
|
if strings.EqualFold(entry.Pattern, pattern) {
|
||||||
return true
|
return true
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
@ -158,7 +157,7 @@ func (c *HandlerChain) ServeDNS(w dns.ResponseWriter, r *dns.Msg) {
|
|||||||
return
|
return
|
||||||
}
|
}
|
||||||
|
|
||||||
qname := r.Question[0].Name
|
qname := strings.ToLower(r.Question[0].Name)
|
||||||
log.Tracef("handling DNS request for domain=%s", qname)
|
log.Tracef("handling DNS request for domain=%s", qname)
|
||||||
|
|
||||||
c.mu.RLock()
|
c.mu.RLock()
|
||||||
@ -187,9 +186,9 @@ func (c *HandlerChain) ServeDNS(w dns.ResponseWriter, r *dns.Msg) {
|
|||||||
// If handler wants subdomain matching, allow suffix match
|
// If handler wants subdomain matching, allow suffix match
|
||||||
// Otherwise require exact match
|
// Otherwise require exact match
|
||||||
if entry.MatchSubdomains {
|
if entry.MatchSubdomains {
|
||||||
matched = qname == entry.Pattern || strings.HasSuffix(qname, "."+entry.Pattern)
|
matched = strings.EqualFold(qname, entry.Pattern) || strings.HasSuffix(qname, "."+entry.Pattern)
|
||||||
} else {
|
} else {
|
||||||
matched = qname == entry.Pattern
|
matched = strings.EqualFold(qname, entry.Pattern)
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
|
@ -507,5 +507,173 @@ func TestHandlerChain_MultiPriorityHandling(t *testing.T) {
|
|||||||
|
|
||||||
// Test 4: Remove last handler
|
// Test 4: Remove last handler
|
||||||
chain.RemoveHandler(testDomain, nbdns.PriorityDefault)
|
chain.RemoveHandler(testDomain, nbdns.PriorityDefault)
|
||||||
|
|
||||||
assert.False(t, chain.HasHandlers(testDomain))
|
assert.False(t, chain.HasHandlers(testDomain))
|
||||||
}
|
}
|
||||||
|
|
||||||
|
func TestHandlerChain_CaseSensitivity(t *testing.T) {
|
||||||
|
tests := []struct {
|
||||||
|
name string
|
||||||
|
scenario string
|
||||||
|
addHandlers []struct {
|
||||||
|
pattern string
|
||||||
|
priority int
|
||||||
|
subdomains bool
|
||||||
|
shouldMatch bool
|
||||||
|
}
|
||||||
|
query string
|
||||||
|
expectedCalls int
|
||||||
|
}{
|
||||||
|
{
|
||||||
|
name: "case insensitive exact match",
|
||||||
|
scenario: "handler registered lowercase, query uppercase",
|
||||||
|
addHandlers: []struct {
|
||||||
|
pattern string
|
||||||
|
priority int
|
||||||
|
subdomains bool
|
||||||
|
shouldMatch bool
|
||||||
|
}{
|
||||||
|
{"example.com.", nbdns.PriorityDefault, false, true},
|
||||||
|
},
|
||||||
|
query: "EXAMPLE.COM.",
|
||||||
|
expectedCalls: 1,
|
||||||
|
},
|
||||||
|
{
|
||||||
|
name: "case insensitive wildcard match",
|
||||||
|
scenario: "handler registered mixed case wildcard, query different case",
|
||||||
|
addHandlers: []struct {
|
||||||
|
pattern string
|
||||||
|
priority int
|
||||||
|
subdomains bool
|
||||||
|
shouldMatch bool
|
||||||
|
}{
|
||||||
|
{"*.Example.Com.", nbdns.PriorityDefault, false, true},
|
||||||
|
},
|
||||||
|
query: "sub.EXAMPLE.COM.",
|
||||||
|
expectedCalls: 1,
|
||||||
|
},
|
||||||
|
{
|
||||||
|
name: "multiple handlers different case same domain",
|
||||||
|
scenario: "second handler should replace first despite case difference",
|
||||||
|
addHandlers: []struct {
|
||||||
|
pattern string
|
||||||
|
priority int
|
||||||
|
subdomains bool
|
||||||
|
shouldMatch bool
|
||||||
|
}{
|
||||||
|
{"EXAMPLE.COM.", nbdns.PriorityDefault, false, false},
|
||||||
|
{"example.com.", nbdns.PriorityDefault, false, true},
|
||||||
|
},
|
||||||
|
query: "ExAmPlE.cOm.",
|
||||||
|
expectedCalls: 1,
|
||||||
|
},
|
||||||
|
{
|
||||||
|
name: "subdomain matching case insensitive",
|
||||||
|
scenario: "handler with MatchSubdomains true should match regardless of case",
|
||||||
|
addHandlers: []struct {
|
||||||
|
pattern string
|
||||||
|
priority int
|
||||||
|
subdomains bool
|
||||||
|
shouldMatch bool
|
||||||
|
}{
|
||||||
|
{"example.com.", nbdns.PriorityDefault, true, true},
|
||||||
|
},
|
||||||
|
query: "SUB.EXAMPLE.COM.",
|
||||||
|
expectedCalls: 1,
|
||||||
|
},
|
||||||
|
{
|
||||||
|
name: "root zone case insensitive",
|
||||||
|
scenario: "root zone handler should match regardless of case",
|
||||||
|
addHandlers: []struct {
|
||||||
|
pattern string
|
||||||
|
priority int
|
||||||
|
subdomains bool
|
||||||
|
shouldMatch bool
|
||||||
|
}{
|
||||||
|
{".", nbdns.PriorityDefault, false, true},
|
||||||
|
},
|
||||||
|
query: "EXAMPLE.COM.",
|
||||||
|
expectedCalls: 1,
|
||||||
|
},
|
||||||
|
{
|
||||||
|
name: "multiple handlers different priority",
|
||||||
|
scenario: "should call higher priority handler despite case differences",
|
||||||
|
addHandlers: []struct {
|
||||||
|
pattern string
|
||||||
|
priority int
|
||||||
|
subdomains bool
|
||||||
|
shouldMatch bool
|
||||||
|
}{
|
||||||
|
{"EXAMPLE.COM.", nbdns.PriorityDefault, false, false},
|
||||||
|
{"example.com.", nbdns.PriorityMatchDomain, false, false},
|
||||||
|
{"Example.Com.", nbdns.PriorityDNSRoute, false, true},
|
||||||
|
},
|
||||||
|
query: "example.com.",
|
||||||
|
expectedCalls: 1,
|
||||||
|
},
|
||||||
|
}
|
||||||
|
|
||||||
|
for _, tt := range tests {
|
||||||
|
t.Run(tt.name, func(t *testing.T) {
|
||||||
|
chain := nbdns.NewHandlerChain()
|
||||||
|
handlerCalls := make(map[string]bool) // track which patterns were called
|
||||||
|
|
||||||
|
// Add handlers according to test case
|
||||||
|
for _, h := range tt.addHandlers {
|
||||||
|
var handler dns.Handler
|
||||||
|
pattern := h.pattern // capture pattern for closure
|
||||||
|
|
||||||
|
if h.subdomains {
|
||||||
|
subHandler := &nbdns.MockSubdomainHandler{
|
||||||
|
Subdomains: true,
|
||||||
|
}
|
||||||
|
if h.shouldMatch {
|
||||||
|
subHandler.On("ServeDNS", mock.Anything, mock.Anything).Run(func(args mock.Arguments) {
|
||||||
|
handlerCalls[pattern] = true
|
||||||
|
w := args.Get(0).(dns.ResponseWriter)
|
||||||
|
r := args.Get(1).(*dns.Msg)
|
||||||
|
resp := new(dns.Msg)
|
||||||
|
resp.SetRcode(r, dns.RcodeSuccess)
|
||||||
|
assert.NoError(t, w.WriteMsg(resp))
|
||||||
|
}).Once()
|
||||||
|
}
|
||||||
|
handler = subHandler
|
||||||
|
} else {
|
||||||
|
mockHandler := &nbdns.MockHandler{}
|
||||||
|
if h.shouldMatch {
|
||||||
|
mockHandler.On("ServeDNS", mock.Anything, mock.Anything).Run(func(args mock.Arguments) {
|
||||||
|
handlerCalls[pattern] = true
|
||||||
|
w := args.Get(0).(dns.ResponseWriter)
|
||||||
|
r := args.Get(1).(*dns.Msg)
|
||||||
|
resp := new(dns.Msg)
|
||||||
|
resp.SetRcode(r, dns.RcodeSuccess)
|
||||||
|
assert.NoError(t, w.WriteMsg(resp))
|
||||||
|
}).Once()
|
||||||
|
}
|
||||||
|
handler = mockHandler
|
||||||
|
}
|
||||||
|
|
||||||
|
chain.AddHandler(pattern, handler, h.priority, nil)
|
||||||
|
}
|
||||||
|
|
||||||
|
// Execute request
|
||||||
|
r := new(dns.Msg)
|
||||||
|
r.SetQuestion(tt.query, dns.TypeA)
|
||||||
|
chain.ServeDNS(&mockResponseWriter{}, r)
|
||||||
|
|
||||||
|
// Verify each handler was called exactly as expected
|
||||||
|
for _, h := range tt.addHandlers {
|
||||||
|
wasCalled := handlerCalls[h.pattern]
|
||||||
|
assert.Equal(t, h.shouldMatch, wasCalled,
|
||||||
|
"Handler for pattern %q was %s when it should%s have been",
|
||||||
|
h.pattern,
|
||||||
|
map[bool]string{true: "called", false: "not called"}[wasCalled],
|
||||||
|
map[bool]string{true: "", false: " not"}[wasCalled == h.shouldMatch])
|
||||||
|
}
|
||||||
|
|
||||||
|
// Verify total number of calls
|
||||||
|
assert.Equal(t, tt.expectedCalls, len(handlerCalls),
|
||||||
|
"Wrong number of total handler calls")
|
||||||
|
})
|
||||||
|
}
|
||||||
|
}
|
||||||
|
@ -83,7 +83,7 @@ func (h *Manager) allowDNSFirewall() error {
|
|||||||
IsRange: false,
|
IsRange: false,
|
||||||
Values: []int{ListenPort},
|
Values: []int{ListenPort},
|
||||||
}
|
}
|
||||||
dnsRules, err := h.firewall.AddPeerFiltering(net.ParseIP("0.0.0.0"), firewall.ProtocolUDP, nil, dport, firewall.RuleDirectionIN, firewall.ActionAccept, "", "")
|
dnsRules, err := h.firewall.AddPeerFiltering(net.IP{0, 0, 0, 0}, firewall.ProtocolUDP, nil, dport, firewall.RuleDirectionIN, firewall.ActionAccept, "", "")
|
||||||
if err != nil {
|
if err != nil {
|
||||||
log.Errorf("failed to add allow DNS router rules, err: %v", err)
|
log.Errorf("failed to add allow DNS router rules, err: %v", err)
|
||||||
return err
|
return err
|
||||||
|
@ -410,13 +410,9 @@ func (e *Engine) Start() error {
|
|||||||
e.firewall, err = firewall.NewFirewall(e.wgInterface, e.stateManager)
|
e.firewall, err = firewall.NewFirewall(e.wgInterface, e.stateManager)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
log.Errorf("failed creating firewall manager: %s", err)
|
log.Errorf("failed creating firewall manager: %s", err)
|
||||||
}
|
} else if e.firewall != nil {
|
||||||
|
if err := e.initFirewall(err); err != nil {
|
||||||
if e.firewall != nil && e.firewall.IsServerRouteSupported() {
|
return err
|
||||||
err = e.routeManager.EnableServerRouter(e.firewall)
|
|
||||||
if err != nil {
|
|
||||||
e.close()
|
|
||||||
return fmt.Errorf("enable server router: %w", err)
|
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
@ -459,6 +455,41 @@ func (e *Engine) Start() error {
|
|||||||
return nil
|
return nil
|
||||||
}
|
}
|
||||||
|
|
||||||
|
func (e *Engine) initFirewall(error) error {
|
||||||
|
if e.firewall.IsServerRouteSupported() {
|
||||||
|
if err := e.routeManager.EnableServerRouter(e.firewall); err != nil {
|
||||||
|
e.close()
|
||||||
|
return fmt.Errorf("enable server router: %w", err)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
if e.rpManager == nil || !e.config.RosenpassEnabled {
|
||||||
|
return nil
|
||||||
|
}
|
||||||
|
|
||||||
|
rosenpassPort := e.rpManager.GetAddress().Port
|
||||||
|
port := manager.Port{Values: []int{rosenpassPort}}
|
||||||
|
|
||||||
|
// this rule is static and will be torn down on engine down by the firewall manager
|
||||||
|
if _, err := e.firewall.AddPeerFiltering(
|
||||||
|
net.IP{0, 0, 0, 0},
|
||||||
|
manager.ProtocolUDP,
|
||||||
|
nil,
|
||||||
|
&port,
|
||||||
|
manager.RuleDirectionIN,
|
||||||
|
manager.ActionAccept,
|
||||||
|
"",
|
||||||
|
"",
|
||||||
|
); err != nil {
|
||||||
|
log.Errorf("failed to allow rosenpass interface traffic: %v", err)
|
||||||
|
return nil
|
||||||
|
}
|
||||||
|
|
||||||
|
log.Infof("rosenpass interface traffic allowed on port %d", rosenpassPort)
|
||||||
|
|
||||||
|
return nil
|
||||||
|
}
|
||||||
|
|
||||||
// modifyPeers updates peers that have been modified (e.g. IP address has been changed).
|
// modifyPeers updates peers that have been modified (e.g. IP address has been changed).
|
||||||
// It closes the existing connection, removes it from the peerConns map, and creates a new one.
|
// It closes the existing connection, removes it from the peerConns map, and creates a new one.
|
||||||
func (e *Engine) modifyPeers(peersUpdate []*mgmProto.RemotePeerConfig) error {
|
func (e *Engine) modifyPeers(peersUpdate []*mgmProto.RemotePeerConfig) error {
|
||||||
|
@ -42,7 +42,7 @@ import (
|
|||||||
nbContext "github.com/netbirdio/netbird/management/server/context"
|
nbContext "github.com/netbirdio/netbird/management/server/context"
|
||||||
"github.com/netbirdio/netbird/management/server/geolocation"
|
"github.com/netbirdio/netbird/management/server/geolocation"
|
||||||
"github.com/netbirdio/netbird/management/server/groups"
|
"github.com/netbirdio/netbird/management/server/groups"
|
||||||
httpapi "github.com/netbirdio/netbird/management/server/http"
|
nbhttp "github.com/netbirdio/netbird/management/server/http"
|
||||||
"github.com/netbirdio/netbird/management/server/http/configs"
|
"github.com/netbirdio/netbird/management/server/http/configs"
|
||||||
"github.com/netbirdio/netbird/management/server/idp"
|
"github.com/netbirdio/netbird/management/server/idp"
|
||||||
"github.com/netbirdio/netbird/management/server/jwtclaims"
|
"github.com/netbirdio/netbird/management/server/jwtclaims"
|
||||||
@ -281,7 +281,7 @@ var (
|
|||||||
routersManager := routers.NewManager(store, permissionsManager, accountManager)
|
routersManager := routers.NewManager(store, permissionsManager, accountManager)
|
||||||
networksManager := networks.NewManager(store, permissionsManager, resourcesManager, routersManager, accountManager)
|
networksManager := networks.NewManager(store, permissionsManager, resourcesManager, routersManager, accountManager)
|
||||||
|
|
||||||
httpAPIHandler, err := httpapi.APIHandler(ctx, accountManager, networksManager, resourcesManager, routersManager, groupsManager, geo, *jwtValidator, appMetrics, httpAPIAuthCfg, integratedPeerValidator)
|
httpAPIHandler, err := nbhttp.NewAPIHandler(ctx, accountManager, networksManager, resourcesManager, routersManager, groupsManager, geo, jwtValidator, appMetrics, httpAPIAuthCfg, integratedPeerValidator)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return fmt.Errorf("failed creating HTTP API handler: %v", err)
|
return fmt.Errorf("failed creating HTTP API handler: %v", err)
|
||||||
}
|
}
|
||||||
|
@ -161,7 +161,7 @@ type DefaultAccountManager struct {
|
|||||||
externalCacheManager ExternalCacheManager
|
externalCacheManager ExternalCacheManager
|
||||||
ctx context.Context
|
ctx context.Context
|
||||||
eventStore activity.Store
|
eventStore activity.Store
|
||||||
geo *geolocation.Geolocation
|
geo geolocation.Geolocation
|
||||||
|
|
||||||
requestBuffer *AccountRequestBuffer
|
requestBuffer *AccountRequestBuffer
|
||||||
|
|
||||||
@ -244,7 +244,7 @@ func BuildManager(
|
|||||||
singleAccountModeDomain string,
|
singleAccountModeDomain string,
|
||||||
dnsDomain string,
|
dnsDomain string,
|
||||||
eventStore activity.Store,
|
eventStore activity.Store,
|
||||||
geo *geolocation.Geolocation,
|
geo geolocation.Geolocation,
|
||||||
userDeleteFromIDPEnabled bool,
|
userDeleteFromIDPEnabled bool,
|
||||||
integratedPeerValidator integrated_validator.IntegratedValidator,
|
integratedPeerValidator integrated_validator.IntegratedValidator,
|
||||||
metrics telemetry.AppMetrics,
|
metrics telemetry.AppMetrics,
|
||||||
@ -1252,6 +1252,12 @@ func (am *DefaultAccountManager) GetAccountIDFromToken(ctx context.Context, clai
|
|||||||
// syncJWTGroups processes the JWT groups for a user, updates the account based on the groups,
|
// syncJWTGroups processes the JWT groups for a user, updates the account based on the groups,
|
||||||
// and propagates changes to peers if group propagation is enabled.
|
// and propagates changes to peers if group propagation is enabled.
|
||||||
func (am *DefaultAccountManager) syncJWTGroups(ctx context.Context, accountID string, claims jwtclaims.AuthorizationClaims) error {
|
func (am *DefaultAccountManager) syncJWTGroups(ctx context.Context, accountID string, claims jwtclaims.AuthorizationClaims) error {
|
||||||
|
if claim, exists := claims.Raw[jwtclaims.IsToken]; exists {
|
||||||
|
if isToken, ok := claim.(bool); ok && isToken {
|
||||||
|
return nil
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
settings, err := am.Store.GetAccountSettings(ctx, store.LockingStrengthShare, accountID)
|
settings, err := am.Store.GetAccountSettings(ctx, store.LockingStrengthShare, accountID)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return err
|
return err
|
||||||
|
@ -27,7 +27,6 @@ import (
|
|||||||
"golang.zx2c4.com/wireguard/wgctrl/wgtypes"
|
"golang.zx2c4.com/wireguard/wgctrl/wgtypes"
|
||||||
|
|
||||||
nbdns "github.com/netbirdio/netbird/dns"
|
nbdns "github.com/netbirdio/netbird/dns"
|
||||||
"github.com/netbirdio/netbird/management/server/account"
|
|
||||||
"github.com/netbirdio/netbird/management/server/activity"
|
"github.com/netbirdio/netbird/management/server/activity"
|
||||||
"github.com/netbirdio/netbird/management/server/jwtclaims"
|
"github.com/netbirdio/netbird/management/server/jwtclaims"
|
||||||
nbpeer "github.com/netbirdio/netbird/management/server/peer"
|
nbpeer "github.com/netbirdio/netbird/management/server/peer"
|
||||||
@ -38,47 +37,6 @@ import (
|
|||||||
"github.com/netbirdio/netbird/route"
|
"github.com/netbirdio/netbird/route"
|
||||||
)
|
)
|
||||||
|
|
||||||
type MocIntegratedValidator struct {
|
|
||||||
ValidatePeerFunc func(_ context.Context, update *nbpeer.Peer, peer *nbpeer.Peer, userID string, accountID string, dnsDomain string, peersGroup []string, extraSettings *account.ExtraSettings) (*nbpeer.Peer, bool, error)
|
|
||||||
}
|
|
||||||
|
|
||||||
func (a MocIntegratedValidator) ValidateExtraSettings(_ context.Context, newExtraSettings *account.ExtraSettings, oldExtraSettings *account.ExtraSettings, peers map[string]*nbpeer.Peer, userID string, accountID string) error {
|
|
||||||
return nil
|
|
||||||
}
|
|
||||||
|
|
||||||
func (a MocIntegratedValidator) ValidatePeer(_ context.Context, update *nbpeer.Peer, peer *nbpeer.Peer, userID string, accountID string, dnsDomain string, peersGroup []string, extraSettings *account.ExtraSettings) (*nbpeer.Peer, bool, error) {
|
|
||||||
if a.ValidatePeerFunc != nil {
|
|
||||||
return a.ValidatePeerFunc(context.Background(), update, peer, userID, accountID, dnsDomain, peersGroup, extraSettings)
|
|
||||||
}
|
|
||||||
return update, false, nil
|
|
||||||
}
|
|
||||||
func (a MocIntegratedValidator) GetValidatedPeers(accountID string, groups map[string]*types.Group, peers map[string]*nbpeer.Peer, extraSettings *account.ExtraSettings) (map[string]struct{}, error) {
|
|
||||||
validatedPeers := make(map[string]struct{})
|
|
||||||
for _, peer := range peers {
|
|
||||||
validatedPeers[peer.ID] = struct{}{}
|
|
||||||
}
|
|
||||||
return validatedPeers, nil
|
|
||||||
}
|
|
||||||
|
|
||||||
func (MocIntegratedValidator) PreparePeer(_ context.Context, accountID string, peer *nbpeer.Peer, peersGroup []string, extraSettings *account.ExtraSettings) *nbpeer.Peer {
|
|
||||||
return peer
|
|
||||||
}
|
|
||||||
|
|
||||||
func (MocIntegratedValidator) IsNotValidPeer(_ context.Context, accountID string, peer *nbpeer.Peer, peersGroup []string, extraSettings *account.ExtraSettings) (bool, bool, error) {
|
|
||||||
return false, false, nil
|
|
||||||
}
|
|
||||||
|
|
||||||
func (MocIntegratedValidator) PeerDeleted(_ context.Context, _, _ string) error {
|
|
||||||
return nil
|
|
||||||
}
|
|
||||||
|
|
||||||
func (MocIntegratedValidator) SetPeerInvalidationListener(func(accountID string)) {
|
|
||||||
|
|
||||||
}
|
|
||||||
|
|
||||||
func (MocIntegratedValidator) Stop(_ context.Context) {
|
|
||||||
}
|
|
||||||
|
|
||||||
func verifyCanAddPeerToAccount(t *testing.T, manager AccountManager, account *types.Account, userID string) {
|
func verifyCanAddPeerToAccount(t *testing.T, manager AccountManager, account *types.Account, userID string) {
|
||||||
t.Helper()
|
t.Helper()
|
||||||
peer := &nbpeer.Peer{
|
peer := &nbpeer.Peer{
|
||||||
@ -2729,6 +2687,19 @@ func TestAccount_SetJWTGroups(t *testing.T) {
|
|||||||
|
|
||||||
assert.NoError(t, manager.Store.SaveAccount(context.Background(), account), "unable to save account")
|
assert.NoError(t, manager.Store.SaveAccount(context.Background(), account), "unable to save account")
|
||||||
|
|
||||||
|
t.Run("skip sync for token auth type", func(t *testing.T) {
|
||||||
|
claims := jwtclaims.AuthorizationClaims{
|
||||||
|
UserId: "user1",
|
||||||
|
Raw: jwt.MapClaims{"groups": []interface{}{"group3"}, "is_token": true},
|
||||||
|
}
|
||||||
|
err = manager.syncJWTGroups(context.Background(), "accountID", claims)
|
||||||
|
assert.NoError(t, err, "unable to sync jwt groups")
|
||||||
|
|
||||||
|
user, err := manager.Store.GetUserByUserID(context.Background(), store.LockingStrengthShare, "user1")
|
||||||
|
assert.NoError(t, err, "unable to get user")
|
||||||
|
assert.Len(t, user.AutoGroups, 0, "JWT groups should not be synced")
|
||||||
|
})
|
||||||
|
|
||||||
t.Run("empty jwt groups", func(t *testing.T) {
|
t.Run("empty jwt groups", func(t *testing.T) {
|
||||||
claims := jwtclaims.AuthorizationClaims{
|
claims := jwtclaims.AuthorizationClaims{
|
||||||
UserId: "user1",
|
UserId: "user1",
|
||||||
@ -2822,7 +2793,7 @@ func TestAccount_SetJWTGroups(t *testing.T) {
|
|||||||
assert.Len(t, user.AutoGroups, 1, "new group should be added")
|
assert.Len(t, user.AutoGroups, 1, "new group should be added")
|
||||||
})
|
})
|
||||||
|
|
||||||
t.Run("remove all JWT groups", func(t *testing.T) {
|
t.Run("remove all JWT groups when list is empty", func(t *testing.T) {
|
||||||
claims := jwtclaims.AuthorizationClaims{
|
claims := jwtclaims.AuthorizationClaims{
|
||||||
UserId: "user1",
|
UserId: "user1",
|
||||||
Raw: jwt.MapClaims{"groups": []interface{}{}},
|
Raw: jwt.MapClaims{"groups": []interface{}{}},
|
||||||
@ -2835,6 +2806,19 @@ func TestAccount_SetJWTGroups(t *testing.T) {
|
|||||||
assert.Len(t, user.AutoGroups, 1, "only non-JWT groups should remain")
|
assert.Len(t, user.AutoGroups, 1, "only non-JWT groups should remain")
|
||||||
assert.Contains(t, user.AutoGroups, "group1", "group1 should still be present")
|
assert.Contains(t, user.AutoGroups, "group1", "group1 should still be present")
|
||||||
})
|
})
|
||||||
|
|
||||||
|
t.Run("remove all JWT groups when claim does not exist", func(t *testing.T) {
|
||||||
|
claims := jwtclaims.AuthorizationClaims{
|
||||||
|
UserId: "user2",
|
||||||
|
Raw: jwt.MapClaims{},
|
||||||
|
}
|
||||||
|
err = manager.syncJWTGroups(context.Background(), "accountID", claims)
|
||||||
|
assert.NoError(t, err, "unable to sync jwt groups")
|
||||||
|
|
||||||
|
user, err := manager.Store.GetUserByUserID(context.Background(), store.LockingStrengthShare, "user2")
|
||||||
|
assert.NoError(t, err, "unable to get user")
|
||||||
|
assert.Len(t, user.AutoGroups, 0, "all JWT groups should be removed")
|
||||||
|
})
|
||||||
}
|
}
|
||||||
|
|
||||||
func TestAccount_UserGroupsAddToPeers(t *testing.T) {
|
func TestAccount_UserGroupsAddToPeers(t *testing.T) {
|
||||||
@ -3037,9 +3021,9 @@ func BenchmarkSyncAndMarkPeer(b *testing.B) {
|
|||||||
minMsPerOpCICD float64
|
minMsPerOpCICD float64
|
||||||
maxMsPerOpCICD float64
|
maxMsPerOpCICD float64
|
||||||
}{
|
}{
|
||||||
{"Small", 50, 5, 1, 3, 3, 10},
|
{"Small", 50, 5, 1, 3, 3, 11},
|
||||||
{"Medium", 500, 100, 7, 13, 10, 70},
|
{"Medium", 500, 100, 7, 13, 10, 70},
|
||||||
{"Large", 5000, 200, 65, 80, 60, 200},
|
{"Large", 5000, 200, 65, 80, 60, 220},
|
||||||
{"Small single", 50, 10, 1, 3, 3, 70},
|
{"Small single", 50, 10, 1, 3, 3, 70},
|
||||||
{"Medium single", 500, 10, 7, 13, 10, 26},
|
{"Medium single", 500, 10, 7, 13, 10, 26},
|
||||||
{"Large 5", 5000, 15, 65, 80, 60, 200},
|
{"Large 5", 5000, 15, 65, 80, 60, 200},
|
||||||
@ -3179,7 +3163,7 @@ func BenchmarkLoginPeer_NewPeer(b *testing.B) {
|
|||||||
maxMsPerOpCICD float64
|
maxMsPerOpCICD float64
|
||||||
}{
|
}{
|
||||||
{"Small", 50, 5, 107, 120, 107, 160},
|
{"Small", 50, 5, 107, 120, 107, 160},
|
||||||
{"Medium", 500, 100, 105, 140, 105, 190},
|
{"Medium", 500, 100, 105, 140, 105, 220},
|
||||||
{"Large", 5000, 200, 180, 220, 180, 350},
|
{"Large", 5000, 200, 180, 220, 180, 350},
|
||||||
{"Small single", 50, 10, 107, 120, 105, 160},
|
{"Small single", 50, 10, 107, 120, 105, 160},
|
||||||
{"Medium single", 500, 10, 105, 140, 105, 170},
|
{"Medium single", 500, 10, 105, 140, 105, 170},
|
||||||
|
@ -14,7 +14,14 @@ import (
|
|||||||
log "github.com/sirupsen/logrus"
|
log "github.com/sirupsen/logrus"
|
||||||
)
|
)
|
||||||
|
|
||||||
type Geolocation struct {
|
type Geolocation interface {
|
||||||
|
Lookup(ip net.IP) (*Record, error)
|
||||||
|
GetAllCountries() ([]Country, error)
|
||||||
|
GetCitiesByCountry(countryISOCode string) ([]City, error)
|
||||||
|
Stop() error
|
||||||
|
}
|
||||||
|
|
||||||
|
type geolocationImpl struct {
|
||||||
mmdbPath string
|
mmdbPath string
|
||||||
mux sync.RWMutex
|
mux sync.RWMutex
|
||||||
db *maxminddb.Reader
|
db *maxminddb.Reader
|
||||||
@ -54,7 +61,7 @@ const (
|
|||||||
geonamesdbPattern = "geonames_*.db"
|
geonamesdbPattern = "geonames_*.db"
|
||||||
)
|
)
|
||||||
|
|
||||||
func NewGeolocation(ctx context.Context, dataDir string, autoUpdate bool) (*Geolocation, error) {
|
func NewGeolocation(ctx context.Context, dataDir string, autoUpdate bool) (Geolocation, error) {
|
||||||
mmdbGlobPattern := filepath.Join(dataDir, mmdbPattern)
|
mmdbGlobPattern := filepath.Join(dataDir, mmdbPattern)
|
||||||
mmdbFile, err := getDatabaseFilename(ctx, geoLiteCityTarGZURL, mmdbGlobPattern, autoUpdate)
|
mmdbFile, err := getDatabaseFilename(ctx, geoLiteCityTarGZURL, mmdbGlobPattern, autoUpdate)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
@ -86,7 +93,7 @@ func NewGeolocation(ctx context.Context, dataDir string, autoUpdate bool) (*Geol
|
|||||||
return nil, err
|
return nil, err
|
||||||
}
|
}
|
||||||
|
|
||||||
geo := &Geolocation{
|
geo := &geolocationImpl{
|
||||||
mmdbPath: mmdbPath,
|
mmdbPath: mmdbPath,
|
||||||
mux: sync.RWMutex{},
|
mux: sync.RWMutex{},
|
||||||
db: db,
|
db: db,
|
||||||
@ -113,7 +120,7 @@ func openDB(mmdbPath string) (*maxminddb.Reader, error) {
|
|||||||
return db, nil
|
return db, nil
|
||||||
}
|
}
|
||||||
|
|
||||||
func (gl *Geolocation) Lookup(ip net.IP) (*Record, error) {
|
func (gl *geolocationImpl) Lookup(ip net.IP) (*Record, error) {
|
||||||
gl.mux.RLock()
|
gl.mux.RLock()
|
||||||
defer gl.mux.RUnlock()
|
defer gl.mux.RUnlock()
|
||||||
|
|
||||||
@ -127,7 +134,7 @@ func (gl *Geolocation) Lookup(ip net.IP) (*Record, error) {
|
|||||||
}
|
}
|
||||||
|
|
||||||
// GetAllCountries retrieves a list of all countries.
|
// GetAllCountries retrieves a list of all countries.
|
||||||
func (gl *Geolocation) GetAllCountries() ([]Country, error) {
|
func (gl *geolocationImpl) GetAllCountries() ([]Country, error) {
|
||||||
allCountries, err := gl.locationDB.GetAllCountries()
|
allCountries, err := gl.locationDB.GetAllCountries()
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return nil, err
|
return nil, err
|
||||||
@ -143,7 +150,7 @@ func (gl *Geolocation) GetAllCountries() ([]Country, error) {
|
|||||||
}
|
}
|
||||||
|
|
||||||
// GetCitiesByCountry retrieves a list of cities in a specific country based on the country's ISO code.
|
// GetCitiesByCountry retrieves a list of cities in a specific country based on the country's ISO code.
|
||||||
func (gl *Geolocation) GetCitiesByCountry(countryISOCode string) ([]City, error) {
|
func (gl *geolocationImpl) GetCitiesByCountry(countryISOCode string) ([]City, error) {
|
||||||
allCities, err := gl.locationDB.GetCitiesByCountry(countryISOCode)
|
allCities, err := gl.locationDB.GetCitiesByCountry(countryISOCode)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return nil, err
|
return nil, err
|
||||||
@ -158,7 +165,7 @@ func (gl *Geolocation) GetCitiesByCountry(countryISOCode string) ([]City, error)
|
|||||||
return cities, nil
|
return cities, nil
|
||||||
}
|
}
|
||||||
|
|
||||||
func (gl *Geolocation) Stop() error {
|
func (gl *geolocationImpl) Stop() error {
|
||||||
close(gl.stopCh)
|
close(gl.stopCh)
|
||||||
if gl.db != nil {
|
if gl.db != nil {
|
||||||
if err := gl.db.Close(); err != nil {
|
if err := gl.db.Close(); err != nil {
|
||||||
@ -259,3 +266,21 @@ func cleanupMaxMindDatabases(ctx context.Context, dataDir string, mmdbFile strin
|
|||||||
}
|
}
|
||||||
return nil
|
return nil
|
||||||
}
|
}
|
||||||
|
|
||||||
|
type Mock struct{}
|
||||||
|
|
||||||
|
func (g *Mock) Lookup(ip net.IP) (*Record, error) {
|
||||||
|
return &Record{}, nil
|
||||||
|
}
|
||||||
|
|
||||||
|
func (g *Mock) GetAllCountries() ([]Country, error) {
|
||||||
|
return []Country{}, nil
|
||||||
|
}
|
||||||
|
|
||||||
|
func (g *Mock) GetCitiesByCountry(countryISOCode string) ([]City, error) {
|
||||||
|
return []City{}, nil
|
||||||
|
}
|
||||||
|
|
||||||
|
func (g *Mock) Stop() error {
|
||||||
|
return nil
|
||||||
|
}
|
||||||
|
@ -24,7 +24,7 @@ func TestGeoLite_Lookup(t *testing.T) {
|
|||||||
db, err := openDB(filename)
|
db, err := openDB(filename)
|
||||||
assert.NoError(t, err)
|
assert.NoError(t, err)
|
||||||
|
|
||||||
geo := &Geolocation{
|
geo := &geolocationImpl{
|
||||||
mux: sync.RWMutex{},
|
mux: sync.RWMutex{},
|
||||||
db: db,
|
db: db,
|
||||||
stopCh: make(chan struct{}),
|
stopCh: make(chan struct{}),
|
||||||
|
@ -474,6 +474,10 @@ func validateDeleteGroup(ctx context.Context, transaction store.Store, group *ty
|
|||||||
return status.Errorf(status.InvalidArgument, "deleting group ALL is not allowed")
|
return status.Errorf(status.InvalidArgument, "deleting group ALL is not allowed")
|
||||||
}
|
}
|
||||||
|
|
||||||
|
if len(group.Resources) > 0 {
|
||||||
|
return &GroupLinkError{"network resource", group.Resources[0].ID}
|
||||||
|
}
|
||||||
|
|
||||||
if isLinked, linkedRoute := isGroupLinkedToRoute(ctx, transaction, group.AccountID, group.ID); isLinked {
|
if isLinked, linkedRoute := isGroupLinkedToRoute(ctx, transaction, group.AccountID, group.ID); isLinked {
|
||||||
return &GroupLinkError{"route", string(linkedRoute.NetID)}
|
return &GroupLinkError{"route", string(linkedRoute.NetID)}
|
||||||
}
|
}
|
||||||
@ -529,7 +533,10 @@ func isGroupLinkedToRoute(ctx context.Context, transaction store.Store, accountI
|
|||||||
}
|
}
|
||||||
|
|
||||||
for _, r := range routes {
|
for _, r := range routes {
|
||||||
if slices.Contains(r.Groups, groupID) || slices.Contains(r.PeerGroups, groupID) {
|
isLinked := slices.Contains(r.Groups, groupID) ||
|
||||||
|
slices.Contains(r.PeerGroups, groupID) ||
|
||||||
|
slices.Contains(r.AccessControlGroups, groupID)
|
||||||
|
if isLinked {
|
||||||
return true, r
|
return true, r
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
@ -38,7 +38,7 @@ type GRPCServer struct {
|
|||||||
peersUpdateManager *PeersUpdateManager
|
peersUpdateManager *PeersUpdateManager
|
||||||
config *Config
|
config *Config
|
||||||
secretsManager SecretsManager
|
secretsManager SecretsManager
|
||||||
jwtValidator *jwtclaims.JWTValidator
|
jwtValidator jwtclaims.JWTValidator
|
||||||
jwtClaimsExtractor *jwtclaims.ClaimsExtractor
|
jwtClaimsExtractor *jwtclaims.ClaimsExtractor
|
||||||
appMetrics telemetry.AppMetrics
|
appMetrics telemetry.AppMetrics
|
||||||
ephemeralManager *EphemeralManager
|
ephemeralManager *EphemeralManager
|
||||||
@ -61,7 +61,7 @@ func NewServer(
|
|||||||
return nil, err
|
return nil, err
|
||||||
}
|
}
|
||||||
|
|
||||||
var jwtValidator *jwtclaims.JWTValidator
|
var jwtValidator jwtclaims.JWTValidator
|
||||||
|
|
||||||
if config.HttpConfig != nil && config.HttpConfig.AuthIssuer != "" && config.HttpConfig.AuthAudience != "" && validateURL(config.HttpConfig.AuthKeysLocation) {
|
if config.HttpConfig != nil && config.HttpConfig.AuthIssuer != "" && config.HttpConfig.AuthAudience != "" && validateURL(config.HttpConfig.AuthKeysLocation) {
|
||||||
jwtValidator, err = jwtclaims.NewJWTValidator(
|
jwtValidator, err = jwtclaims.NewJWTValidator(
|
||||||
|
@ -725,10 +725,6 @@ components:
|
|||||||
PolicyRuleMinimum:
|
PolicyRuleMinimum:
|
||||||
type: object
|
type: object
|
||||||
properties:
|
properties:
|
||||||
id:
|
|
||||||
description: Policy rule ID
|
|
||||||
type: string
|
|
||||||
example: ch8i4ug6lnn4g9hqv7mg
|
|
||||||
name:
|
name:
|
||||||
description: Policy rule name identifier
|
description: Policy rule name identifier
|
||||||
type: string
|
type: string
|
||||||
@ -790,6 +786,31 @@ components:
|
|||||||
- end
|
- end
|
||||||
|
|
||||||
PolicyRuleUpdate:
|
PolicyRuleUpdate:
|
||||||
|
allOf:
|
||||||
|
- $ref: '#/components/schemas/PolicyRuleMinimum'
|
||||||
|
- type: object
|
||||||
|
properties:
|
||||||
|
id:
|
||||||
|
description: Policy rule ID
|
||||||
|
type: string
|
||||||
|
example: ch8i4ug6lnn4g9hqv7mg
|
||||||
|
sources:
|
||||||
|
description: Policy rule source group IDs
|
||||||
|
type: array
|
||||||
|
items:
|
||||||
|
type: string
|
||||||
|
example: "ch8i4ug6lnn4g9hqv797"
|
||||||
|
destinations:
|
||||||
|
description: Policy rule destination group IDs
|
||||||
|
type: array
|
||||||
|
items:
|
||||||
|
type: string
|
||||||
|
example: "ch8i4ug6lnn4g9h7v7m0"
|
||||||
|
required:
|
||||||
|
- sources
|
||||||
|
- destinations
|
||||||
|
|
||||||
|
PolicyRuleCreate:
|
||||||
allOf:
|
allOf:
|
||||||
- $ref: '#/components/schemas/PolicyRuleMinimum'
|
- $ref: '#/components/schemas/PolicyRuleMinimum'
|
||||||
- type: object
|
- type: object
|
||||||
@ -817,6 +838,10 @@ components:
|
|||||||
- $ref: '#/components/schemas/PolicyRuleMinimum'
|
- $ref: '#/components/schemas/PolicyRuleMinimum'
|
||||||
- type: object
|
- type: object
|
||||||
properties:
|
properties:
|
||||||
|
id:
|
||||||
|
description: Policy rule ID
|
||||||
|
type: string
|
||||||
|
example: ch8i4ug6lnn4g9hqv7mg
|
||||||
sources:
|
sources:
|
||||||
description: Policy rule source group IDs
|
description: Policy rule source group IDs
|
||||||
type: array
|
type: array
|
||||||
@ -836,10 +861,6 @@ components:
|
|||||||
PolicyMinimum:
|
PolicyMinimum:
|
||||||
type: object
|
type: object
|
||||||
properties:
|
properties:
|
||||||
id:
|
|
||||||
description: Policy ID
|
|
||||||
type: string
|
|
||||||
example: ch8i4ug6lnn4g9hqv7mg
|
|
||||||
name:
|
name:
|
||||||
description: Policy name identifier
|
description: Policy name identifier
|
||||||
type: string
|
type: string
|
||||||
@ -854,7 +875,6 @@ components:
|
|||||||
example: true
|
example: true
|
||||||
required:
|
required:
|
||||||
- name
|
- name
|
||||||
- description
|
|
||||||
- enabled
|
- enabled
|
||||||
PolicyUpdate:
|
PolicyUpdate:
|
||||||
allOf:
|
allOf:
|
||||||
@ -874,11 +894,33 @@ components:
|
|||||||
$ref: '#/components/schemas/PolicyRuleUpdate'
|
$ref: '#/components/schemas/PolicyRuleUpdate'
|
||||||
required:
|
required:
|
||||||
- rules
|
- rules
|
||||||
|
PolicyCreate:
|
||||||
|
allOf:
|
||||||
|
- $ref: '#/components/schemas/PolicyMinimum'
|
||||||
|
- type: object
|
||||||
|
properties:
|
||||||
|
source_posture_checks:
|
||||||
|
description: Posture checks ID's applied to policy source groups
|
||||||
|
type: array
|
||||||
|
items:
|
||||||
|
type: string
|
||||||
|
example: "chacdk86lnnboviihd70"
|
||||||
|
rules:
|
||||||
|
description: Policy rule object for policy UI editor
|
||||||
|
type: array
|
||||||
|
items:
|
||||||
|
$ref: '#/components/schemas/PolicyRuleUpdate'
|
||||||
|
required:
|
||||||
|
- rules
|
||||||
Policy:
|
Policy:
|
||||||
allOf:
|
allOf:
|
||||||
- $ref: '#/components/schemas/PolicyMinimum'
|
- $ref: '#/components/schemas/PolicyMinimum'
|
||||||
- type: object
|
- type: object
|
||||||
properties:
|
properties:
|
||||||
|
id:
|
||||||
|
description: Policy ID
|
||||||
|
type: string
|
||||||
|
example: ch8i4ug6lnn4g9hqv7mg
|
||||||
source_posture_checks:
|
source_posture_checks:
|
||||||
description: Posture checks ID's applied to policy source groups
|
description: Posture checks ID's applied to policy source groups
|
||||||
type: array
|
type: array
|
||||||
@ -2463,7 +2505,7 @@ paths:
|
|||||||
content:
|
content:
|
||||||
'application/json':
|
'application/json':
|
||||||
schema:
|
schema:
|
||||||
$ref: '#/components/schemas/PolicyUpdate'
|
$ref: '#/components/schemas/PolicyCreate'
|
||||||
responses:
|
responses:
|
||||||
'200':
|
'200':
|
||||||
description: A Policy object
|
description: A Policy object
|
||||||
|
@ -879,7 +879,7 @@ type PersonalAccessTokenRequest struct {
|
|||||||
// Policy defines model for Policy.
|
// Policy defines model for Policy.
|
||||||
type Policy struct {
|
type Policy struct {
|
||||||
// Description Policy friendly description
|
// Description Policy friendly description
|
||||||
Description string `json:"description"`
|
Description *string `json:"description,omitempty"`
|
||||||
|
|
||||||
// Enabled Policy status
|
// Enabled Policy status
|
||||||
Enabled bool `json:"enabled"`
|
Enabled bool `json:"enabled"`
|
||||||
@ -897,16 +897,31 @@ type Policy struct {
|
|||||||
SourcePostureChecks []string `json:"source_posture_checks"`
|
SourcePostureChecks []string `json:"source_posture_checks"`
|
||||||
}
|
}
|
||||||
|
|
||||||
// PolicyMinimum defines model for PolicyMinimum.
|
// PolicyCreate defines model for PolicyCreate.
|
||||||
type PolicyMinimum struct {
|
type PolicyCreate struct {
|
||||||
// Description Policy friendly description
|
// Description Policy friendly description
|
||||||
Description string `json:"description"`
|
Description *string `json:"description,omitempty"`
|
||||||
|
|
||||||
// Enabled Policy status
|
// Enabled Policy status
|
||||||
Enabled bool `json:"enabled"`
|
Enabled bool `json:"enabled"`
|
||||||
|
|
||||||
// Id Policy ID
|
// Name Policy name identifier
|
||||||
Id *string `json:"id,omitempty"`
|
Name string `json:"name"`
|
||||||
|
|
||||||
|
// Rules Policy rule object for policy UI editor
|
||||||
|
Rules []PolicyRuleUpdate `json:"rules"`
|
||||||
|
|
||||||
|
// SourcePostureChecks Posture checks ID's applied to policy source groups
|
||||||
|
SourcePostureChecks *[]string `json:"source_posture_checks,omitempty"`
|
||||||
|
}
|
||||||
|
|
||||||
|
// PolicyMinimum defines model for PolicyMinimum.
|
||||||
|
type PolicyMinimum struct {
|
||||||
|
// Description Policy friendly description
|
||||||
|
Description *string `json:"description,omitempty"`
|
||||||
|
|
||||||
|
// Enabled Policy status
|
||||||
|
Enabled bool `json:"enabled"`
|
||||||
|
|
||||||
// Name Policy name identifier
|
// Name Policy name identifier
|
||||||
Name string `json:"name"`
|
Name string `json:"name"`
|
||||||
@ -970,9 +985,6 @@ type PolicyRuleMinimum struct {
|
|||||||
// Enabled Policy rule status
|
// Enabled Policy rule status
|
||||||
Enabled bool `json:"enabled"`
|
Enabled bool `json:"enabled"`
|
||||||
|
|
||||||
// Id Policy rule ID
|
|
||||||
Id *string `json:"id,omitempty"`
|
|
||||||
|
|
||||||
// Name Policy rule name identifier
|
// Name Policy rule name identifier
|
||||||
Name string `json:"name"`
|
Name string `json:"name"`
|
||||||
|
|
||||||
@ -1039,14 +1051,11 @@ type PolicyRuleUpdateProtocol string
|
|||||||
// PolicyUpdate defines model for PolicyUpdate.
|
// PolicyUpdate defines model for PolicyUpdate.
|
||||||
type PolicyUpdate struct {
|
type PolicyUpdate struct {
|
||||||
// Description Policy friendly description
|
// Description Policy friendly description
|
||||||
Description string `json:"description"`
|
Description *string `json:"description,omitempty"`
|
||||||
|
|
||||||
// Enabled Policy status
|
// Enabled Policy status
|
||||||
Enabled bool `json:"enabled"`
|
Enabled bool `json:"enabled"`
|
||||||
|
|
||||||
// Id Policy ID
|
|
||||||
Id *string `json:"id,omitempty"`
|
|
||||||
|
|
||||||
// Name Policy name identifier
|
// Name Policy name identifier
|
||||||
Name string `json:"name"`
|
Name string `json:"name"`
|
||||||
|
|
||||||
@ -1473,7 +1482,7 @@ type PutApiPeersPeerIdJSONRequestBody = PeerRequest
|
|||||||
type PostApiPoliciesJSONRequestBody = PolicyUpdate
|
type PostApiPoliciesJSONRequestBody = PolicyUpdate
|
||||||
|
|
||||||
// PutApiPoliciesPolicyIdJSONRequestBody defines body for PutApiPoliciesPolicyId for application/json ContentType.
|
// PutApiPoliciesPolicyIdJSONRequestBody defines body for PutApiPoliciesPolicyId for application/json ContentType.
|
||||||
type PutApiPoliciesPolicyIdJSONRequestBody = PolicyUpdate
|
type PutApiPoliciesPolicyIdJSONRequestBody = PolicyCreate
|
||||||
|
|
||||||
// PostApiPostureChecksJSONRequestBody defines body for PostApiPostureChecks for application/json ContentType.
|
// PostApiPostureChecksJSONRequestBody defines body for PostApiPostureChecks for application/json ContentType.
|
||||||
type PostApiPostureChecksJSONRequestBody = PostureCheckUpdate
|
type PostApiPostureChecksJSONRequestBody = PostureCheckUpdate
|
||||||
|
@ -35,15 +35,8 @@ import (
|
|||||||
|
|
||||||
const apiPrefix = "/api"
|
const apiPrefix = "/api"
|
||||||
|
|
||||||
type apiHandler struct {
|
// NewAPIHandler creates the Management service HTTP API handler registering all the available endpoints.
|
||||||
Router *mux.Router
|
func NewAPIHandler(ctx context.Context, accountManager s.AccountManager, networksManager nbnetworks.Manager, resourceManager resources.Manager, routerManager routers.Manager, groupsManager nbgroups.Manager, LocationManager geolocation.Geolocation, jwtValidator jwtclaims.JWTValidator, appMetrics telemetry.AppMetrics, authCfg configs.AuthCfg, integratedValidator integrated_validator.IntegratedValidator) (http.Handler, error) {
|
||||||
AccountManager s.AccountManager
|
|
||||||
geolocationManager *geolocation.Geolocation
|
|
||||||
AuthCfg configs.AuthCfg
|
|
||||||
}
|
|
||||||
|
|
||||||
// APIHandler creates the Management service HTTP API handler registering all the available endpoints.
|
|
||||||
func APIHandler(ctx context.Context, accountManager s.AccountManager, networksManager nbnetworks.Manager, resourceManager resources.Manager, routerManager routers.Manager, groupsManager nbgroups.Manager, LocationManager *geolocation.Geolocation, jwtValidator jwtclaims.JWTValidator, appMetrics telemetry.AppMetrics, authCfg configs.AuthCfg, integratedValidator integrated_validator.IntegratedValidator) (http.Handler, error) {
|
|
||||||
claimsExtractor := jwtclaims.NewClaimsExtractor(
|
claimsExtractor := jwtclaims.NewClaimsExtractor(
|
||||||
jwtclaims.WithAudience(authCfg.Audience),
|
jwtclaims.WithAudience(authCfg.Audience),
|
||||||
jwtclaims.WithUserIDClaim(authCfg.UserIDClaim),
|
jwtclaims.WithUserIDClaim(authCfg.UserIDClaim),
|
||||||
@ -78,27 +71,20 @@ func APIHandler(ctx context.Context, accountManager s.AccountManager, networksMa
|
|||||||
router := rootRouter.PathPrefix(prefix).Subrouter()
|
router := rootRouter.PathPrefix(prefix).Subrouter()
|
||||||
router.Use(metricsMiddleware.Handler, corsMiddleware.Handler, authMiddleware.Handler, acMiddleware.Handler)
|
router.Use(metricsMiddleware.Handler, corsMiddleware.Handler, authMiddleware.Handler, acMiddleware.Handler)
|
||||||
|
|
||||||
api := apiHandler{
|
if _, err := integrations.RegisterHandlers(ctx, prefix, router, accountManager, claimsExtractor, integratedValidator, appMetrics.GetMeter()); err != nil {
|
||||||
Router: router,
|
|
||||||
AccountManager: accountManager,
|
|
||||||
geolocationManager: LocationManager,
|
|
||||||
AuthCfg: authCfg,
|
|
||||||
}
|
|
||||||
|
|
||||||
if _, err := integrations.RegisterHandlers(ctx, prefix, api.Router, accountManager, claimsExtractor, integratedValidator, appMetrics.GetMeter()); err != nil {
|
|
||||||
return nil, fmt.Errorf("register integrations endpoints: %w", err)
|
return nil, fmt.Errorf("register integrations endpoints: %w", err)
|
||||||
}
|
}
|
||||||
|
|
||||||
accounts.AddEndpoints(api.AccountManager, authCfg, router)
|
accounts.AddEndpoints(accountManager, authCfg, router)
|
||||||
peers.AddEndpoints(api.AccountManager, authCfg, router)
|
peers.AddEndpoints(accountManager, authCfg, router)
|
||||||
users.AddEndpoints(api.AccountManager, authCfg, router)
|
users.AddEndpoints(accountManager, authCfg, router)
|
||||||
setup_keys.AddEndpoints(api.AccountManager, authCfg, router)
|
setup_keys.AddEndpoints(accountManager, authCfg, router)
|
||||||
policies.AddEndpoints(api.AccountManager, api.geolocationManager, authCfg, router)
|
policies.AddEndpoints(accountManager, LocationManager, authCfg, router)
|
||||||
groups.AddEndpoints(api.AccountManager, authCfg, router)
|
groups.AddEndpoints(accountManager, authCfg, router)
|
||||||
routes.AddEndpoints(api.AccountManager, authCfg, router)
|
routes.AddEndpoints(accountManager, authCfg, router)
|
||||||
dns.AddEndpoints(api.AccountManager, authCfg, router)
|
dns.AddEndpoints(accountManager, authCfg, router)
|
||||||
events.AddEndpoints(api.AccountManager, authCfg, router)
|
events.AddEndpoints(accountManager, authCfg, router)
|
||||||
networks.AddEndpoints(networksManager, resourceManager, routerManager, groupsManager, api.AccountManager, api.AccountManager.GetAccountIDFromToken, authCfg, router)
|
networks.AddEndpoints(networksManager, resourceManager, routerManager, groupsManager, accountManager, accountManager.GetAccountIDFromToken, authCfg, router)
|
||||||
|
|
||||||
return rootRouter, nil
|
return rootRouter, nil
|
||||||
}
|
}
|
||||||
|
@ -22,18 +22,18 @@ var (
|
|||||||
// geolocationsHandler is a handler that returns locations.
|
// geolocationsHandler is a handler that returns locations.
|
||||||
type geolocationsHandler struct {
|
type geolocationsHandler struct {
|
||||||
accountManager server.AccountManager
|
accountManager server.AccountManager
|
||||||
geolocationManager *geolocation.Geolocation
|
geolocationManager geolocation.Geolocation
|
||||||
claimsExtractor *jwtclaims.ClaimsExtractor
|
claimsExtractor *jwtclaims.ClaimsExtractor
|
||||||
}
|
}
|
||||||
|
|
||||||
func addLocationsEndpoint(accountManager server.AccountManager, locationManager *geolocation.Geolocation, authCfg configs.AuthCfg, router *mux.Router) {
|
func addLocationsEndpoint(accountManager server.AccountManager, locationManager geolocation.Geolocation, authCfg configs.AuthCfg, router *mux.Router) {
|
||||||
locationHandler := newGeolocationsHandlerHandler(accountManager, locationManager, authCfg)
|
locationHandler := newGeolocationsHandlerHandler(accountManager, locationManager, authCfg)
|
||||||
router.HandleFunc("/locations/countries", locationHandler.getAllCountries).Methods("GET", "OPTIONS")
|
router.HandleFunc("/locations/countries", locationHandler.getAllCountries).Methods("GET", "OPTIONS")
|
||||||
router.HandleFunc("/locations/countries/{country}/cities", locationHandler.getCitiesByCountry).Methods("GET", "OPTIONS")
|
router.HandleFunc("/locations/countries/{country}/cities", locationHandler.getCitiesByCountry).Methods("GET", "OPTIONS")
|
||||||
}
|
}
|
||||||
|
|
||||||
// newGeolocationsHandlerHandler creates a new Geolocations handler
|
// newGeolocationsHandlerHandler creates a new Geolocations handler
|
||||||
func newGeolocationsHandlerHandler(accountManager server.AccountManager, geolocationManager *geolocation.Geolocation, authCfg configs.AuthCfg) *geolocationsHandler {
|
func newGeolocationsHandlerHandler(accountManager server.AccountManager, geolocationManager geolocation.Geolocation, authCfg configs.AuthCfg) *geolocationsHandler {
|
||||||
return &geolocationsHandler{
|
return &geolocationsHandler{
|
||||||
accountManager: accountManager,
|
accountManager: accountManager,
|
||||||
geolocationManager: geolocationManager,
|
geolocationManager: geolocationManager,
|
||||||
|
@ -23,7 +23,7 @@ type handler struct {
|
|||||||
claimsExtractor *jwtclaims.ClaimsExtractor
|
claimsExtractor *jwtclaims.ClaimsExtractor
|
||||||
}
|
}
|
||||||
|
|
||||||
func AddEndpoints(accountManager server.AccountManager, locationManager *geolocation.Geolocation, authCfg configs.AuthCfg, router *mux.Router) {
|
func AddEndpoints(accountManager server.AccountManager, locationManager geolocation.Geolocation, authCfg configs.AuthCfg, router *mux.Router) {
|
||||||
policiesHandler := newHandler(accountManager, authCfg)
|
policiesHandler := newHandler(accountManager, authCfg)
|
||||||
router.HandleFunc("/policies", policiesHandler.getAllPolicies).Methods("GET", "OPTIONS")
|
router.HandleFunc("/policies", policiesHandler.getAllPolicies).Methods("GET", "OPTIONS")
|
||||||
router.HandleFunc("/policies", policiesHandler.createPolicy).Methods("POST", "OPTIONS")
|
router.HandleFunc("/policies", policiesHandler.createPolicy).Methods("POST", "OPTIONS")
|
||||||
@ -133,16 +133,21 @@ func (h *handler) savePolicy(w http.ResponseWriter, r *http.Request, accountID s
|
|||||||
return
|
return
|
||||||
}
|
}
|
||||||
|
|
||||||
|
description := ""
|
||||||
|
if req.Description != nil {
|
||||||
|
description = *req.Description
|
||||||
|
}
|
||||||
|
|
||||||
policy := &types.Policy{
|
policy := &types.Policy{
|
||||||
ID: policyID,
|
ID: policyID,
|
||||||
AccountID: accountID,
|
AccountID: accountID,
|
||||||
Name: req.Name,
|
Name: req.Name,
|
||||||
Enabled: req.Enabled,
|
Enabled: req.Enabled,
|
||||||
Description: req.Description,
|
Description: description,
|
||||||
}
|
}
|
||||||
for _, rule := range req.Rules {
|
for _, rule := range req.Rules {
|
||||||
var ruleID string
|
var ruleID string
|
||||||
if rule.Id != nil {
|
if rule.Id != nil && policyID != "" {
|
||||||
ruleID = *rule.Id
|
ruleID = *rule.Id
|
||||||
}
|
}
|
||||||
|
|
||||||
@ -370,7 +375,7 @@ func toPolicyResponse(groups []*types.Group, policy *types.Policy) *api.Policy {
|
|||||||
ap := &api.Policy{
|
ap := &api.Policy{
|
||||||
Id: &policy.ID,
|
Id: &policy.ID,
|
||||||
Name: policy.Name,
|
Name: policy.Name,
|
||||||
Description: policy.Description,
|
Description: &policy.Description,
|
||||||
Enabled: policy.Enabled,
|
Enabled: policy.Enabled,
|
||||||
SourcePostureChecks: policy.SourcePostureChecks,
|
SourcePostureChecks: policy.SourcePostureChecks,
|
||||||
}
|
}
|
||||||
|
@ -154,6 +154,7 @@ func TestPoliciesGetPolicy(t *testing.T) {
|
|||||||
|
|
||||||
func TestPoliciesWritePolicy(t *testing.T) {
|
func TestPoliciesWritePolicy(t *testing.T) {
|
||||||
str := func(s string) *string { return &s }
|
str := func(s string) *string { return &s }
|
||||||
|
emptyString := ""
|
||||||
tt := []struct {
|
tt := []struct {
|
||||||
name string
|
name string
|
||||||
expectedStatus int
|
expectedStatus int
|
||||||
@ -186,6 +187,7 @@ func TestPoliciesWritePolicy(t *testing.T) {
|
|||||||
expectedPolicy: &api.Policy{
|
expectedPolicy: &api.Policy{
|
||||||
Id: str("id-was-set"),
|
Id: str("id-was-set"),
|
||||||
Name: "Default POSTed Policy",
|
Name: "Default POSTed Policy",
|
||||||
|
Description: &emptyString,
|
||||||
Rules: []api.PolicyRule{
|
Rules: []api.PolicyRule{
|
||||||
{
|
{
|
||||||
Id: str("id-was-set"),
|
Id: str("id-was-set"),
|
||||||
@ -234,6 +236,7 @@ func TestPoliciesWritePolicy(t *testing.T) {
|
|||||||
expectedPolicy: &api.Policy{
|
expectedPolicy: &api.Policy{
|
||||||
Id: str("id-existed"),
|
Id: str("id-existed"),
|
||||||
Name: "Default POSTed Policy",
|
Name: "Default POSTed Policy",
|
||||||
|
Description: &emptyString,
|
||||||
Rules: []api.PolicyRule{
|
Rules: []api.PolicyRule{
|
||||||
{
|
{
|
||||||
Id: str("id-existed"),
|
Id: str("id-existed"),
|
||||||
|
@ -19,11 +19,11 @@ import (
|
|||||||
// postureChecksHandler is a handler that returns posture checks of the account.
|
// postureChecksHandler is a handler that returns posture checks of the account.
|
||||||
type postureChecksHandler struct {
|
type postureChecksHandler struct {
|
||||||
accountManager server.AccountManager
|
accountManager server.AccountManager
|
||||||
geolocationManager *geolocation.Geolocation
|
geolocationManager geolocation.Geolocation
|
||||||
claimsExtractor *jwtclaims.ClaimsExtractor
|
claimsExtractor *jwtclaims.ClaimsExtractor
|
||||||
}
|
}
|
||||||
|
|
||||||
func addPostureCheckEndpoint(accountManager server.AccountManager, locationManager *geolocation.Geolocation, authCfg configs.AuthCfg, router *mux.Router) {
|
func addPostureCheckEndpoint(accountManager server.AccountManager, locationManager geolocation.Geolocation, authCfg configs.AuthCfg, router *mux.Router) {
|
||||||
postureCheckHandler := newPostureChecksHandler(accountManager, locationManager, authCfg)
|
postureCheckHandler := newPostureChecksHandler(accountManager, locationManager, authCfg)
|
||||||
router.HandleFunc("/posture-checks", postureCheckHandler.getAllPostureChecks).Methods("GET", "OPTIONS")
|
router.HandleFunc("/posture-checks", postureCheckHandler.getAllPostureChecks).Methods("GET", "OPTIONS")
|
||||||
router.HandleFunc("/posture-checks", postureCheckHandler.createPostureCheck).Methods("POST", "OPTIONS")
|
router.HandleFunc("/posture-checks", postureCheckHandler.createPostureCheck).Methods("POST", "OPTIONS")
|
||||||
@ -34,7 +34,7 @@ func addPostureCheckEndpoint(accountManager server.AccountManager, locationManag
|
|||||||
}
|
}
|
||||||
|
|
||||||
// newPostureChecksHandler creates a new PostureChecks handler
|
// newPostureChecksHandler creates a new PostureChecks handler
|
||||||
func newPostureChecksHandler(accountManager server.AccountManager, geolocationManager *geolocation.Geolocation, authCfg configs.AuthCfg) *postureChecksHandler {
|
func newPostureChecksHandler(accountManager server.AccountManager, geolocationManager geolocation.Geolocation, authCfg configs.AuthCfg) *postureChecksHandler {
|
||||||
return &postureChecksHandler{
|
return &postureChecksHandler{
|
||||||
accountManager: accountManager,
|
accountManager: accountManager,
|
||||||
geolocationManager: geolocationManager,
|
geolocationManager: geolocationManager,
|
||||||
|
@ -70,7 +70,7 @@ func initPostureChecksTestData(postureChecks ...*posture.Checks) *postureChecksH
|
|||||||
return claims.AccountId, claims.UserId, nil
|
return claims.AccountId, claims.UserId, nil
|
||||||
},
|
},
|
||||||
},
|
},
|
||||||
geolocationManager: &geolocation.Geolocation{},
|
geolocationManager: &geolocation.Mock{},
|
||||||
claimsExtractor: jwtclaims.NewClaimsExtractor(
|
claimsExtractor: jwtclaims.NewClaimsExtractor(
|
||||||
jwtclaims.WithFromRequestContext(func(r *http.Request) jwtclaims.AuthorizationClaims {
|
jwtclaims.WithFromRequestContext(func(r *http.Request) jwtclaims.AuthorizationClaims {
|
||||||
return jwtclaims.AuthorizationClaims{
|
return jwtclaims.AuthorizationClaims{
|
||||||
|
@ -93,7 +93,7 @@ func (h *handler) createSetupKey(w http.ResponseWriter, r *http.Request) {
|
|||||||
return
|
return
|
||||||
}
|
}
|
||||||
|
|
||||||
apiSetupKeys := toResponseBody(setupKey)
|
apiSetupKeys := ToResponseBody(setupKey)
|
||||||
// for the creation we need to send the plain key
|
// for the creation we need to send the plain key
|
||||||
apiSetupKeys.Key = setupKey.Key
|
apiSetupKeys.Key = setupKey.Key
|
||||||
|
|
||||||
@ -183,7 +183,7 @@ func (h *handler) getAllSetupKeys(w http.ResponseWriter, r *http.Request) {
|
|||||||
|
|
||||||
apiSetupKeys := make([]*api.SetupKey, 0)
|
apiSetupKeys := make([]*api.SetupKey, 0)
|
||||||
for _, key := range setupKeys {
|
for _, key := range setupKeys {
|
||||||
apiSetupKeys = append(apiSetupKeys, toResponseBody(key))
|
apiSetupKeys = append(apiSetupKeys, ToResponseBody(key))
|
||||||
}
|
}
|
||||||
|
|
||||||
util.WriteJSONObject(r.Context(), w, apiSetupKeys)
|
util.WriteJSONObject(r.Context(), w, apiSetupKeys)
|
||||||
@ -216,14 +216,14 @@ func (h *handler) deleteSetupKey(w http.ResponseWriter, r *http.Request) {
|
|||||||
func writeSuccess(ctx context.Context, w http.ResponseWriter, key *types.SetupKey) {
|
func writeSuccess(ctx context.Context, w http.ResponseWriter, key *types.SetupKey) {
|
||||||
w.Header().Set("Content-Type", "application/json")
|
w.Header().Set("Content-Type", "application/json")
|
||||||
w.WriteHeader(200)
|
w.WriteHeader(200)
|
||||||
err := json.NewEncoder(w).Encode(toResponseBody(key))
|
err := json.NewEncoder(w).Encode(ToResponseBody(key))
|
||||||
if err != nil {
|
if err != nil {
|
||||||
util.WriteError(ctx, err, w)
|
util.WriteError(ctx, err, w)
|
||||||
return
|
return
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
func toResponseBody(key *types.SetupKey) *api.SetupKey {
|
func ToResponseBody(key *types.SetupKey) *api.SetupKey {
|
||||||
var state string
|
var state string
|
||||||
switch {
|
switch {
|
||||||
case key.IsExpired():
|
case key.IsExpired():
|
||||||
|
@ -26,7 +26,6 @@ const (
|
|||||||
newSetupKeyName = "New Setup Key"
|
newSetupKeyName = "New Setup Key"
|
||||||
updatedSetupKeyName = "KKKey"
|
updatedSetupKeyName = "KKKey"
|
||||||
notFoundSetupKeyID = "notFoundSetupKeyID"
|
notFoundSetupKeyID = "notFoundSetupKeyID"
|
||||||
testAccountID = "test_id"
|
|
||||||
)
|
)
|
||||||
|
|
||||||
func initSetupKeysTestMetaData(defaultKey *types.SetupKey, newKey *types.SetupKey, updatedSetupKey *types.SetupKey,
|
func initSetupKeysTestMetaData(defaultKey *types.SetupKey, newKey *types.SetupKey, updatedSetupKey *types.SetupKey,
|
||||||
@ -81,7 +80,7 @@ func initSetupKeysTestMetaData(defaultKey *types.SetupKey, newKey *types.SetupKe
|
|||||||
return jwtclaims.AuthorizationClaims{
|
return jwtclaims.AuthorizationClaims{
|
||||||
UserId: user.Id,
|
UserId: user.Id,
|
||||||
Domain: "hotmail.com",
|
Domain: "hotmail.com",
|
||||||
AccountId: testAccountID,
|
AccountId: "testAccountId",
|
||||||
}
|
}
|
||||||
}),
|
}),
|
||||||
),
|
),
|
||||||
@ -102,7 +101,7 @@ func TestSetupKeysHandlers(t *testing.T) {
|
|||||||
updatedDefaultSetupKey.Name = updatedSetupKeyName
|
updatedDefaultSetupKey.Name = updatedSetupKeyName
|
||||||
updatedDefaultSetupKey.Revoked = true
|
updatedDefaultSetupKey.Revoked = true
|
||||||
|
|
||||||
expectedNewKey := toResponseBody(newSetupKey)
|
expectedNewKey := ToResponseBody(newSetupKey)
|
||||||
expectedNewKey.Key = plainKey
|
expectedNewKey.Key = plainKey
|
||||||
tt := []struct {
|
tt := []struct {
|
||||||
name string
|
name string
|
||||||
@ -120,7 +119,7 @@ func TestSetupKeysHandlers(t *testing.T) {
|
|||||||
requestPath: "/api/setup-keys",
|
requestPath: "/api/setup-keys",
|
||||||
expectedStatus: http.StatusOK,
|
expectedStatus: http.StatusOK,
|
||||||
expectedBody: true,
|
expectedBody: true,
|
||||||
expectedSetupKeys: []*api.SetupKey{toResponseBody(defaultSetupKey)},
|
expectedSetupKeys: []*api.SetupKey{ToResponseBody(defaultSetupKey)},
|
||||||
},
|
},
|
||||||
{
|
{
|
||||||
name: "Get Existing Setup Key",
|
name: "Get Existing Setup Key",
|
||||||
@ -128,7 +127,7 @@ func TestSetupKeysHandlers(t *testing.T) {
|
|||||||
requestPath: "/api/setup-keys/" + existingSetupKeyID,
|
requestPath: "/api/setup-keys/" + existingSetupKeyID,
|
||||||
expectedStatus: http.StatusOK,
|
expectedStatus: http.StatusOK,
|
||||||
expectedBody: true,
|
expectedBody: true,
|
||||||
expectedSetupKey: toResponseBody(defaultSetupKey),
|
expectedSetupKey: ToResponseBody(defaultSetupKey),
|
||||||
},
|
},
|
||||||
{
|
{
|
||||||
name: "Get Not Existing Setup Key",
|
name: "Get Not Existing Setup Key",
|
||||||
@ -159,7 +158,7 @@ func TestSetupKeysHandlers(t *testing.T) {
|
|||||||
))),
|
))),
|
||||||
expectedStatus: http.StatusOK,
|
expectedStatus: http.StatusOK,
|
||||||
expectedBody: true,
|
expectedBody: true,
|
||||||
expectedSetupKey: toResponseBody(updatedDefaultSetupKey),
|
expectedSetupKey: ToResponseBody(updatedDefaultSetupKey),
|
||||||
},
|
},
|
||||||
{
|
{
|
||||||
name: "Delete Setup Key",
|
name: "Delete Setup Key",
|
||||||
@ -228,7 +227,7 @@ func TestSetupKeysHandlers(t *testing.T) {
|
|||||||
func assertKeys(t *testing.T, got *api.SetupKey, expected *api.SetupKey) {
|
func assertKeys(t *testing.T, got *api.SetupKey, expected *api.SetupKey) {
|
||||||
t.Helper()
|
t.Helper()
|
||||||
// this comparison is done manually because when converting to JSON dates formatted differently
|
// this comparison is done manually because when converting to JSON dates formatted differently
|
||||||
// assert.Equal(t, got.UpdatedAt, tc.expectedSetupKey.UpdatedAt) //doesn't work
|
// assert.Equal(t, got.UpdatedAt, tc.expectedResponse.UpdatedAt) //doesn't work
|
||||||
assert.WithinDurationf(t, got.UpdatedAt, expected.UpdatedAt, 0, "")
|
assert.WithinDurationf(t, got.UpdatedAt, expected.UpdatedAt, 0, "")
|
||||||
assert.WithinDurationf(t, got.Expires, expected.Expires, 0, "")
|
assert.WithinDurationf(t, got.Expires, expected.Expires, 0, "")
|
||||||
assert.Equal(t, got.Name, expected.Name)
|
assert.Equal(t, got.Name, expected.Name)
|
||||||
|
@ -175,6 +175,7 @@ func (m *AuthMiddleware) checkPATFromRequest(w http.ResponseWriter, r *http.Requ
|
|||||||
claimMaps[m.audience+jwtclaims.AccountIDSuffix] = account.Id
|
claimMaps[m.audience+jwtclaims.AccountIDSuffix] = account.Id
|
||||||
claimMaps[m.audience+jwtclaims.DomainIDSuffix] = account.Domain
|
claimMaps[m.audience+jwtclaims.DomainIDSuffix] = account.Domain
|
||||||
claimMaps[m.audience+jwtclaims.DomainCategorySuffix] = account.DomainCategory
|
claimMaps[m.audience+jwtclaims.DomainCategorySuffix] = account.DomainCategory
|
||||||
|
claimMaps[jwtclaims.IsToken] = true
|
||||||
jwtToken := jwt.NewWithClaims(jwt.SigningMethodHS256, claimMaps)
|
jwtToken := jwt.NewWithClaims(jwt.SigningMethodHS256, claimMaps)
|
||||||
newRequest := r.WithContext(context.WithValue(r.Context(), jwtclaims.TokenUserProperty, jwtToken)) //nolint
|
newRequest := r.WithContext(context.WithValue(r.Context(), jwtclaims.TokenUserProperty, jwtToken)) //nolint
|
||||||
// Update the current request with the new context information.
|
// Update the current request with the new context information.
|
||||||
|
@ -0,0 +1,226 @@
|
|||||||
|
package benchmarks
|
||||||
|
|
||||||
|
import (
|
||||||
|
"encoding/json"
|
||||||
|
"io"
|
||||||
|
"net/http"
|
||||||
|
"net/http/httptest"
|
||||||
|
"os"
|
||||||
|
"strconv"
|
||||||
|
"testing"
|
||||||
|
"time"
|
||||||
|
|
||||||
|
log "github.com/sirupsen/logrus"
|
||||||
|
"github.com/stretchr/testify/assert"
|
||||||
|
|
||||||
|
"github.com/netbirdio/netbird/management/server"
|
||||||
|
"github.com/netbirdio/netbird/management/server/http/api"
|
||||||
|
"github.com/netbirdio/netbird/management/server/http/testing/testing_tools"
|
||||||
|
)
|
||||||
|
|
||||||
|
// Map to store peers, groups, users, and setupKeys by name
|
||||||
|
var benchCasesSetupKeys = map[string]testing_tools.BenchmarkCase{
|
||||||
|
"Setup Keys - XS": {Peers: 10000, Groups: 10000, Users: 10000, SetupKeys: 5},
|
||||||
|
"Setup Keys - S": {Peers: 5, Groups: 5, Users: 5, SetupKeys: 100},
|
||||||
|
"Setup Keys - M": {Peers: 100, Groups: 20, Users: 20, SetupKeys: 1000},
|
||||||
|
"Setup Keys - L": {Peers: 5, Groups: 5, Users: 5, SetupKeys: 5000},
|
||||||
|
"Peers - L": {Peers: 10000, Groups: 5, Users: 5, SetupKeys: 5000},
|
||||||
|
"Groups - L": {Peers: 5, Groups: 10000, Users: 5, SetupKeys: 5000},
|
||||||
|
"Users - L": {Peers: 5, Groups: 5, Users: 10000, SetupKeys: 5000},
|
||||||
|
"Setup Keys - XL": {Peers: 500, Groups: 50, Users: 100, SetupKeys: 25000},
|
||||||
|
}
|
||||||
|
|
||||||
|
func BenchmarkCreateSetupKey(b *testing.B) {
|
||||||
|
var expectedMetrics = map[string]testing_tools.PerformanceMetrics{
|
||||||
|
"Setup Keys - XS": {MinMsPerOpLocal: 0.5, MaxMsPerOpLocal: 2, MinMsPerOpCICD: 1, MaxMsPerOpCICD: 16},
|
||||||
|
"Setup Keys - S": {MinMsPerOpLocal: 0.5, MaxMsPerOpLocal: 2, MinMsPerOpCICD: 1, MaxMsPerOpCICD: 16},
|
||||||
|
"Setup Keys - M": {MinMsPerOpLocal: 0.5, MaxMsPerOpLocal: 2, MinMsPerOpCICD: 1, MaxMsPerOpCICD: 16},
|
||||||
|
"Setup Keys - L": {MinMsPerOpLocal: 0.5, MaxMsPerOpLocal: 2, MinMsPerOpCICD: 1, MaxMsPerOpCICD: 16},
|
||||||
|
"Peers - L": {MinMsPerOpLocal: 0.5, MaxMsPerOpLocal: 2, MinMsPerOpCICD: 1, MaxMsPerOpCICD: 16},
|
||||||
|
"Groups - L": {MinMsPerOpLocal: 0.5, MaxMsPerOpLocal: 2, MinMsPerOpCICD: 1, MaxMsPerOpCICD: 16},
|
||||||
|
"Users - L": {MinMsPerOpLocal: 0.5, MaxMsPerOpLocal: 2, MinMsPerOpCICD: 1, MaxMsPerOpCICD: 16},
|
||||||
|
"Setup Keys - XL": {MinMsPerOpLocal: 0.5, MaxMsPerOpLocal: 2, MinMsPerOpCICD: 1, MaxMsPerOpCICD: 16},
|
||||||
|
}
|
||||||
|
|
||||||
|
log.SetOutput(io.Discard)
|
||||||
|
defer log.SetOutput(os.Stderr)
|
||||||
|
|
||||||
|
recorder := httptest.NewRecorder()
|
||||||
|
|
||||||
|
for name, bc := range benchCasesSetupKeys {
|
||||||
|
b.Run(name, func(b *testing.B) {
|
||||||
|
apiHandler, am, _ := testing_tools.BuildApiBlackBoxWithDBState(b, "../testdata/setup_keys.sql", nil)
|
||||||
|
testing_tools.PopulateTestData(b, am.(*server.DefaultAccountManager), bc.Peers, bc.Groups, bc.Users, bc.SetupKeys)
|
||||||
|
|
||||||
|
b.ResetTimer()
|
||||||
|
start := time.Now()
|
||||||
|
for i := 0; i < b.N; i++ {
|
||||||
|
requestBody := api.CreateSetupKeyRequest{
|
||||||
|
AutoGroups: []string{testing_tools.TestGroupId},
|
||||||
|
ExpiresIn: testing_tools.ExpiresIn,
|
||||||
|
Name: testing_tools.NewKeyName + strconv.Itoa(i),
|
||||||
|
Type: "reusable",
|
||||||
|
UsageLimit: 0,
|
||||||
|
}
|
||||||
|
|
||||||
|
// the time marshal will be recorded as well but for our use case that is ok
|
||||||
|
body, err := json.Marshal(requestBody)
|
||||||
|
assert.NoError(b, err)
|
||||||
|
|
||||||
|
req := testing_tools.BuildRequest(b, body, http.MethodPost, "/api/setup-keys", testing_tools.TestAdminId)
|
||||||
|
apiHandler.ServeHTTP(recorder, req)
|
||||||
|
}
|
||||||
|
|
||||||
|
testing_tools.EvaluateBenchmarkResults(b, name, time.Since(start), expectedMetrics[name], recorder)
|
||||||
|
})
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func BenchmarkUpdateSetupKey(b *testing.B) {
|
||||||
|
var expectedMetrics = map[string]testing_tools.PerformanceMetrics{
|
||||||
|
"Setup Keys - XS": {MinMsPerOpLocal: 0.5, MaxMsPerOpLocal: 3, MinMsPerOpCICD: 2, MaxMsPerOpCICD: 19},
|
||||||
|
"Setup Keys - S": {MinMsPerOpLocal: 0.5, MaxMsPerOpLocal: 3, MinMsPerOpCICD: 2, MaxMsPerOpCICD: 19},
|
||||||
|
"Setup Keys - M": {MinMsPerOpLocal: 0.5, MaxMsPerOpLocal: 3, MinMsPerOpCICD: 2, MaxMsPerOpCICD: 19},
|
||||||
|
"Setup Keys - L": {MinMsPerOpLocal: 0.5, MaxMsPerOpLocal: 3, MinMsPerOpCICD: 2, MaxMsPerOpCICD: 19},
|
||||||
|
"Peers - L": {MinMsPerOpLocal: 0.5, MaxMsPerOpLocal: 3, MinMsPerOpCICD: 2, MaxMsPerOpCICD: 19},
|
||||||
|
"Groups - L": {MinMsPerOpLocal: 0.5, MaxMsPerOpLocal: 3, MinMsPerOpCICD: 2, MaxMsPerOpCICD: 19},
|
||||||
|
"Users - L": {MinMsPerOpLocal: 0.5, MaxMsPerOpLocal: 3, MinMsPerOpCICD: 2, MaxMsPerOpCICD: 19},
|
||||||
|
"Setup Keys - XL": {MinMsPerOpLocal: 0.5, MaxMsPerOpLocal: 3, MinMsPerOpCICD: 2, MaxMsPerOpCICD: 19},
|
||||||
|
}
|
||||||
|
|
||||||
|
log.SetOutput(io.Discard)
|
||||||
|
defer log.SetOutput(os.Stderr)
|
||||||
|
|
||||||
|
recorder := httptest.NewRecorder()
|
||||||
|
|
||||||
|
for name, bc := range benchCasesSetupKeys {
|
||||||
|
b.Run(name, func(b *testing.B) {
|
||||||
|
apiHandler, am, _ := testing_tools.BuildApiBlackBoxWithDBState(b, "../testdata/setup_keys.sql", nil)
|
||||||
|
testing_tools.PopulateTestData(b, am.(*server.DefaultAccountManager), bc.Peers, bc.Groups, bc.Users, bc.SetupKeys)
|
||||||
|
|
||||||
|
b.ResetTimer()
|
||||||
|
start := time.Now()
|
||||||
|
for i := 0; i < b.N; i++ {
|
||||||
|
groupId := testing_tools.TestGroupId
|
||||||
|
if i%2 == 0 {
|
||||||
|
groupId = testing_tools.NewGroupId
|
||||||
|
}
|
||||||
|
requestBody := api.SetupKeyRequest{
|
||||||
|
AutoGroups: []string{groupId},
|
||||||
|
Revoked: false,
|
||||||
|
}
|
||||||
|
|
||||||
|
// the time marshal will be recorded as well but for our use case that is ok
|
||||||
|
body, err := json.Marshal(requestBody)
|
||||||
|
assert.NoError(b, err)
|
||||||
|
|
||||||
|
req := testing_tools.BuildRequest(b, body, http.MethodPut, "/api/setup-keys/"+testing_tools.TestKeyId, testing_tools.TestAdminId)
|
||||||
|
apiHandler.ServeHTTP(recorder, req)
|
||||||
|
}
|
||||||
|
|
||||||
|
testing_tools.EvaluateBenchmarkResults(b, name, time.Since(start), expectedMetrics[name], recorder)
|
||||||
|
})
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func BenchmarkGetOneSetupKey(b *testing.B) {
|
||||||
|
var expectedMetrics = map[string]testing_tools.PerformanceMetrics{
|
||||||
|
"Setup Keys - XS": {MinMsPerOpLocal: 0.5, MaxMsPerOpLocal: 2, MinMsPerOpCICD: 1, MaxMsPerOpCICD: 16},
|
||||||
|
"Setup Keys - S": {MinMsPerOpLocal: 0.5, MaxMsPerOpLocal: 2, MinMsPerOpCICD: 1, MaxMsPerOpCICD: 16},
|
||||||
|
"Setup Keys - M": {MinMsPerOpLocal: 0.5, MaxMsPerOpLocal: 2, MinMsPerOpCICD: 1, MaxMsPerOpCICD: 16},
|
||||||
|
"Setup Keys - L": {MinMsPerOpLocal: 0.5, MaxMsPerOpLocal: 2, MinMsPerOpCICD: 1, MaxMsPerOpCICD: 16},
|
||||||
|
"Peers - L": {MinMsPerOpLocal: 0.5, MaxMsPerOpLocal: 2, MinMsPerOpCICD: 1, MaxMsPerOpCICD: 16},
|
||||||
|
"Groups - L": {MinMsPerOpLocal: 0.5, MaxMsPerOpLocal: 2, MinMsPerOpCICD: 1, MaxMsPerOpCICD: 16},
|
||||||
|
"Users - L": {MinMsPerOpLocal: 0.5, MaxMsPerOpLocal: 2, MinMsPerOpCICD: 1, MaxMsPerOpCICD: 16},
|
||||||
|
"Setup Keys - XL": {MinMsPerOpLocal: 0.5, MaxMsPerOpLocal: 2, MinMsPerOpCICD: 1, MaxMsPerOpCICD: 16},
|
||||||
|
}
|
||||||
|
|
||||||
|
log.SetOutput(io.Discard)
|
||||||
|
defer log.SetOutput(os.Stderr)
|
||||||
|
|
||||||
|
recorder := httptest.NewRecorder()
|
||||||
|
|
||||||
|
for name, bc := range benchCasesSetupKeys {
|
||||||
|
b.Run(name, func(b *testing.B) {
|
||||||
|
apiHandler, am, _ := testing_tools.BuildApiBlackBoxWithDBState(b, "../testdata/setup_keys.sql", nil)
|
||||||
|
testing_tools.PopulateTestData(b, am.(*server.DefaultAccountManager), bc.Peers, bc.Groups, bc.Users, bc.SetupKeys)
|
||||||
|
|
||||||
|
b.ResetTimer()
|
||||||
|
start := time.Now()
|
||||||
|
for i := 0; i < b.N; i++ {
|
||||||
|
req := testing_tools.BuildRequest(b, nil, http.MethodGet, "/api/setup-keys/"+testing_tools.TestKeyId, testing_tools.TestAdminId)
|
||||||
|
apiHandler.ServeHTTP(recorder, req)
|
||||||
|
}
|
||||||
|
|
||||||
|
testing_tools.EvaluateBenchmarkResults(b, name, time.Since(start), expectedMetrics[name], recorder)
|
||||||
|
})
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func BenchmarkGetAllSetupKeys(b *testing.B) {
|
||||||
|
var expectedMetrics = map[string]testing_tools.PerformanceMetrics{
|
||||||
|
"Setup Keys - XS": {MinMsPerOpLocal: 0.5, MaxMsPerOpLocal: 2, MinMsPerOpCICD: 1, MaxMsPerOpCICD: 12},
|
||||||
|
"Setup Keys - S": {MinMsPerOpLocal: 0.5, MaxMsPerOpLocal: 2, MinMsPerOpCICD: 1, MaxMsPerOpCICD: 15},
|
||||||
|
"Setup Keys - M": {MinMsPerOpLocal: 5, MaxMsPerOpLocal: 10, MinMsPerOpCICD: 5, MaxMsPerOpCICD: 40},
|
||||||
|
"Setup Keys - L": {MinMsPerOpLocal: 30, MaxMsPerOpLocal: 50, MinMsPerOpCICD: 30, MaxMsPerOpCICD: 150},
|
||||||
|
"Peers - L": {MinMsPerOpLocal: 30, MaxMsPerOpLocal: 50, MinMsPerOpCICD: 30, MaxMsPerOpCICD: 150},
|
||||||
|
"Groups - L": {MinMsPerOpLocal: 30, MaxMsPerOpLocal: 50, MinMsPerOpCICD: 30, MaxMsPerOpCICD: 150},
|
||||||
|
"Users - L": {MinMsPerOpLocal: 30, MaxMsPerOpLocal: 50, MinMsPerOpCICD: 30, MaxMsPerOpCICD: 150},
|
||||||
|
"Setup Keys - XL": {MinMsPerOpLocal: 140, MaxMsPerOpLocal: 220, MinMsPerOpCICD: 150, MaxMsPerOpCICD: 500},
|
||||||
|
}
|
||||||
|
|
||||||
|
log.SetOutput(io.Discard)
|
||||||
|
defer log.SetOutput(os.Stderr)
|
||||||
|
|
||||||
|
recorder := httptest.NewRecorder()
|
||||||
|
|
||||||
|
for name, bc := range benchCasesSetupKeys {
|
||||||
|
b.Run(name, func(b *testing.B) {
|
||||||
|
apiHandler, am, _ := testing_tools.BuildApiBlackBoxWithDBState(b, "../testdata/setup_keys.sql", nil)
|
||||||
|
testing_tools.PopulateTestData(b, am.(*server.DefaultAccountManager), bc.Peers, bc.Groups, bc.Users, bc.SetupKeys)
|
||||||
|
|
||||||
|
b.ResetTimer()
|
||||||
|
start := time.Now()
|
||||||
|
for i := 0; i < b.N; i++ {
|
||||||
|
req := testing_tools.BuildRequest(b, nil, http.MethodGet, "/api/setup-keys", testing_tools.TestAdminId)
|
||||||
|
apiHandler.ServeHTTP(recorder, req)
|
||||||
|
}
|
||||||
|
|
||||||
|
testing_tools.EvaluateBenchmarkResults(b, name, time.Since(start), expectedMetrics[name], recorder)
|
||||||
|
})
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func BenchmarkDeleteSetupKey(b *testing.B) {
|
||||||
|
var expectedMetrics = map[string]testing_tools.PerformanceMetrics{
|
||||||
|
"Setup Keys - XS": {MinMsPerOpLocal: 0.5, MaxMsPerOpLocal: 2, MinMsPerOpCICD: 1, MaxMsPerOpCICD: 16},
|
||||||
|
"Setup Keys - S": {MinMsPerOpLocal: 0.5, MaxMsPerOpLocal: 2, MinMsPerOpCICD: 1, MaxMsPerOpCICD: 16},
|
||||||
|
"Setup Keys - M": {MinMsPerOpLocal: 0.5, MaxMsPerOpLocal: 2, MinMsPerOpCICD: 1, MaxMsPerOpCICD: 16},
|
||||||
|
"Setup Keys - L": {MinMsPerOpLocal: 0.5, MaxMsPerOpLocal: 2, MinMsPerOpCICD: 1, MaxMsPerOpCICD: 16},
|
||||||
|
"Peers - L": {MinMsPerOpLocal: 0.5, MaxMsPerOpLocal: 2, MinMsPerOpCICD: 1, MaxMsPerOpCICD: 16},
|
||||||
|
"Groups - L": {MinMsPerOpLocal: 0.5, MaxMsPerOpLocal: 2, MinMsPerOpCICD: 1, MaxMsPerOpCICD: 16},
|
||||||
|
"Users - L": {MinMsPerOpLocal: 0.5, MaxMsPerOpLocal: 2, MinMsPerOpCICD: 1, MaxMsPerOpCICD: 16},
|
||||||
|
"Setup Keys - XL": {MinMsPerOpLocal: 0.5, MaxMsPerOpLocal: 2, MinMsPerOpCICD: 1, MaxMsPerOpCICD: 16},
|
||||||
|
}
|
||||||
|
|
||||||
|
log.SetOutput(io.Discard)
|
||||||
|
defer log.SetOutput(os.Stderr)
|
||||||
|
|
||||||
|
recorder := httptest.NewRecorder()
|
||||||
|
|
||||||
|
for name, bc := range benchCasesSetupKeys {
|
||||||
|
b.Run(name, func(b *testing.B) {
|
||||||
|
apiHandler, am, _ := testing_tools.BuildApiBlackBoxWithDBState(b, "../testdata/setup_keys.sql", nil)
|
||||||
|
testing_tools.PopulateTestData(b, am.(*server.DefaultAccountManager), bc.Peers, bc.Groups, bc.Users, 1000)
|
||||||
|
|
||||||
|
b.ResetTimer()
|
||||||
|
start := time.Now()
|
||||||
|
for i := 0; i < b.N; i++ {
|
||||||
|
req := testing_tools.BuildRequest(b, nil, http.MethodDelete, "/api/setup-keys/"+"oldkey-"+strconv.Itoa(i), testing_tools.TestAdminId)
|
||||||
|
apiHandler.ServeHTTP(recorder, req)
|
||||||
|
}
|
||||||
|
|
||||||
|
testing_tools.EvaluateBenchmarkResults(b, name, time.Since(start), expectedMetrics[name], recorder)
|
||||||
|
})
|
||||||
|
}
|
||||||
|
}
|
File diff suppressed because it is too large
Load Diff
24
management/server/http/testing/testdata/setup_keys.sql
vendored
Normal file
24
management/server/http/testing/testdata/setup_keys.sql
vendored
Normal file
@ -0,0 +1,24 @@
|
|||||||
|
CREATE TABLE `accounts` (`id` text,`created_by` text,`created_at` datetime,`domain` text,`domain_category` text,`is_domain_primary_account` numeric,`network_identifier` text,`network_net` text,`network_dns` text,`network_serial` integer,`dns_settings_disabled_management_groups` text,`settings_peer_login_expiration_enabled` numeric,`settings_peer_login_expiration` integer,`settings_regular_users_view_blocked` numeric,`settings_groups_propagation_enabled` numeric,`settings_jwt_groups_enabled` numeric,`settings_jwt_groups_claim_name` text,`settings_jwt_allow_groups` text,`settings_extra_peer_approval_enabled` numeric,`settings_extra_integrated_validator_groups` text,PRIMARY KEY (`id`));
|
||||||
|
CREATE TABLE `users` (`id` text,`account_id` text,`role` text,`is_service_user` numeric,`non_deletable` numeric,`service_user_name` text,`auto_groups` text,`blocked` numeric,`last_login` datetime,`created_at` datetime,`issued` text DEFAULT "api",`integration_ref_id` integer,`integration_ref_integration_type` text,PRIMARY KEY (`id`),CONSTRAINT `fk_accounts_users_g` FOREIGN KEY (`account_id`) REFERENCES `accounts`(`id`));
|
||||||
|
CREATE TABLE `peers` (`id` text,`account_id` text,`key` text,`setup_key` text,`ip` text,`meta_hostname` text,`meta_go_os` text,`meta_kernel` text,`meta_core` text,`meta_platform` text,`meta_os` text,`meta_os_version` text,`meta_wt_version` text,`meta_ui_version` text,`meta_kernel_version` text,`meta_network_addresses` text,`meta_system_serial_number` text,`meta_system_product_name` text,`meta_system_manufacturer` text,`meta_environment` text,`meta_files` text,`name` text,`dns_label` text,`peer_status_last_seen` datetime,`peer_status_connected` numeric,`peer_status_login_expired` numeric,`peer_status_requires_approval` numeric,`user_id` text,`ssh_key` text,`ssh_enabled` numeric,`login_expiration_enabled` numeric,`last_login` datetime,`created_at` datetime,`ephemeral` numeric,`location_connection_ip` text,`location_country_code` text,`location_city_name` text,`location_geo_name_id` integer,PRIMARY KEY (`id`),CONSTRAINT `fk_accounts_peers_g` FOREIGN KEY (`account_id`) REFERENCES `accounts`(`id`));
|
||||||
|
CREATE TABLE `groups` (`id` text,`account_id` text,`name` text,`issued` text,`peers` text,`integration_ref_id` integer,`integration_ref_integration_type` text,PRIMARY KEY (`id`),CONSTRAINT `fk_accounts_groups_g` FOREIGN KEY (`account_id`) REFERENCES `accounts`(`id`));
|
||||||
|
|
||||||
|
INSERT INTO accounts VALUES('testAccountId','','2024-10-02 16:01:38.000000000+00:00','test.com','private',1,'testNetworkIdentifier','{"IP":"100.64.0.0","Mask":"//8AAA=="}','',0,'[]',0,86400000000000,0,0,0,'',NULL,NULL,NULL);
|
||||||
|
INSERT INTO users VALUES('testUserId','testAccountId','user',0,0,'','[]',0,'0001-01-01 00:00:00+00:00','2024-10-02 16:01:38.000000000+00:00','api',0,'');
|
||||||
|
INSERT INTO users VALUES('testAdminId','testAccountId','admin',0,0,'','[]',0,'0001-01-01 00:00:00+00:00','2024-10-02 16:01:38.000000000+00:00','api',0,'');
|
||||||
|
INSERT INTO users VALUES('testOwnerId','testAccountId','owner',0,0,'','[]',0,'0001-01-01 00:00:00+00:00','2024-10-02 16:01:38.000000000+00:00','api',0,'');
|
||||||
|
INSERT INTO users VALUES('testServiceUserId','testAccountId','user',1,0,'','[]',0,'0001-01-01 00:00:00+00:00','2024-10-02 16:01:38.000000000+00:00','api',0,'');
|
||||||
|
INSERT INTO users VALUES('testServiceAdminId','testAccountId','admin',1,0,'','[]',0,'0001-01-01 00:00:00+00:00','2024-10-02 16:01:38.000000000+00:00','api',0,'');
|
||||||
|
INSERT INTO users VALUES('blockedUserId','testAccountId','admin',0,0,'','[]',1,'0001-01-01 00:00:00+00:00','2024-10-02 16:01:38.000000000+00:00','api',0,'');
|
||||||
|
INSERT INTO users VALUES('otherUserId','otherAccountId','admin',0,0,'','[]',0,'0001-01-01 00:00:00+00:00','2024-10-02 16:01:38.000000000+00:00','api',0,'');
|
||||||
|
INSERT INTO peers VALUES('testPeerId','testAccountId','5rvhvriKJZ3S9oxYToVj5TzDM9u9y8cxg7htIMWlYAg=','72546A29-6BC8-4311-BCFC-9CDBF33F1A48','"100.64.114.31"','f2a34f6a4731','linux','Linux','11','unknown','Debian GNU/Linux','','0.12.0','','',NULL,'','','','{"Cloud":"","Platform":""}',NULL,'f2a34f6a4731','f2a34f6a4731','2023-03-02 09:21:02.189035775+01:00',0,0,0,'','ssh-ed25519 AAAAC3NzaC1lZDI1NTE5AAAAILzUUSYG/LGnV8zarb2SGN+tib/PZ+M7cL4WtTzUrTpk',0,1,'2023-03-01 19:48:19.817799698+01:00','2024-10-02 17:00:32.527947+02:00',0,'""','','',0);
|
||||||
|
INSERT INTO "groups" VALUES('testGroupId','testAccountId','testGroupName','api','[]',0,'');
|
||||||
|
INSERT INTO "groups" VALUES('newGroupId','testAccountId','newGroupName','api','[]',0,'');
|
||||||
|
|
||||||
|
|
||||||
|
CREATE TABLE `setup_keys` (`id` text,`account_id` text,`key` text,`key_secret` text,`name` text,`type` text,`created_at` datetime,`expires_at` datetime,`updated_at` datetime,`revoked` numeric,`used_times` integer,`last_used` datetime,`auto_groups` text,`usage_limit` integer,`ephemeral` numeric,PRIMARY KEY (`id`),CONSTRAINT `fk_accounts_setup_keys_g` FOREIGN KEY (`account_id`) REFERENCES `accounts`(`id`));
|
||||||
|
|
||||||
|
INSERT INTO setup_keys VALUES('testKeyId','testAccountId','testKey','testK****','existingKey','one-off','2021-08-19 20:46:20.000000000+00:00','2321-09-18 20:46:20.000000000+00:00','2021-08-19 20:46:20.000000000+00:000',0,0,'0001-01-01 00:00:00+00:00','["testGroupId"]',1,0);
|
||||||
|
INSERT INTO setup_keys VALUES('revokedKeyId','testAccountId','revokedKey','testK****','existingKey','reusable','2021-08-19 20:46:20.000000000+00:00','2321-09-18 20:46:20.000000000+00:00','2021-08-19 20:46:20.000000000+00:00',1,0,'0001-01-01 00:00:00+00:00','["testGroupId"]',3,0);
|
||||||
|
INSERT INTO setup_keys VALUES('expiredKeyId','testAccountId','expiredKey','testK****','existingKey','reusable','2021-08-19 20:46:20.000000000+00:00','1921-09-18 20:46:20.000000000+00:00','2021-08-19 20:46:20.000000000+00:00',0,1,'0001-01-01 00:00:00+00:00','["testGroupId"]',5,1);
|
||||||
|
|
307
management/server/http/testing/testing_tools/tools.go
Normal file
307
management/server/http/testing/testing_tools/tools.go
Normal file
@ -0,0 +1,307 @@
|
|||||||
|
package testing_tools
|
||||||
|
|
||||||
|
import (
|
||||||
|
"bytes"
|
||||||
|
"context"
|
||||||
|
"fmt"
|
||||||
|
"io"
|
||||||
|
"net"
|
||||||
|
"net/http"
|
||||||
|
"net/http/httptest"
|
||||||
|
"os"
|
||||||
|
"strconv"
|
||||||
|
"testing"
|
||||||
|
"time"
|
||||||
|
|
||||||
|
"github.com/stretchr/testify/assert"
|
||||||
|
"golang.zx2c4.com/wireguard/wgctrl/wgtypes"
|
||||||
|
|
||||||
|
"github.com/netbirdio/netbird/management/server"
|
||||||
|
"github.com/netbirdio/netbird/management/server/activity"
|
||||||
|
"github.com/netbirdio/netbird/management/server/geolocation"
|
||||||
|
"github.com/netbirdio/netbird/management/server/groups"
|
||||||
|
nbhttp "github.com/netbirdio/netbird/management/server/http"
|
||||||
|
"github.com/netbirdio/netbird/management/server/http/configs"
|
||||||
|
"github.com/netbirdio/netbird/management/server/jwtclaims"
|
||||||
|
"github.com/netbirdio/netbird/management/server/networks"
|
||||||
|
"github.com/netbirdio/netbird/management/server/networks/resources"
|
||||||
|
"github.com/netbirdio/netbird/management/server/networks/routers"
|
||||||
|
nbpeer "github.com/netbirdio/netbird/management/server/peer"
|
||||||
|
"github.com/netbirdio/netbird/management/server/posture"
|
||||||
|
"github.com/netbirdio/netbird/management/server/store"
|
||||||
|
"github.com/netbirdio/netbird/management/server/telemetry"
|
||||||
|
"github.com/netbirdio/netbird/management/server/types"
|
||||||
|
)
|
||||||
|
|
||||||
|
const (
|
||||||
|
TestAccountId = "testAccountId"
|
||||||
|
TestPeerId = "testPeerId"
|
||||||
|
TestGroupId = "testGroupId"
|
||||||
|
TestKeyId = "testKeyId"
|
||||||
|
|
||||||
|
TestUserId = "testUserId"
|
||||||
|
TestAdminId = "testAdminId"
|
||||||
|
TestOwnerId = "testOwnerId"
|
||||||
|
TestServiceUserId = "testServiceUserId"
|
||||||
|
TestServiceAdminId = "testServiceAdminId"
|
||||||
|
BlockedUserId = "blockedUserId"
|
||||||
|
OtherUserId = "otherUserId"
|
||||||
|
InvalidToken = "invalidToken"
|
||||||
|
|
||||||
|
NewKeyName = "newKey"
|
||||||
|
NewGroupId = "newGroupId"
|
||||||
|
ExpiresIn = 3600
|
||||||
|
RevokedKeyId = "revokedKeyId"
|
||||||
|
ExpiredKeyId = "expiredKeyId"
|
||||||
|
|
||||||
|
ExistingKeyName = "existingKey"
|
||||||
|
)
|
||||||
|
|
||||||
|
type TB interface {
|
||||||
|
Cleanup(func())
|
||||||
|
Helper()
|
||||||
|
Errorf(format string, args ...any)
|
||||||
|
Fatalf(format string, args ...any)
|
||||||
|
TempDir() string
|
||||||
|
}
|
||||||
|
|
||||||
|
// BenchmarkCase defines a single benchmark test case
|
||||||
|
type BenchmarkCase struct {
|
||||||
|
Peers int
|
||||||
|
Groups int
|
||||||
|
Users int
|
||||||
|
SetupKeys int
|
||||||
|
}
|
||||||
|
|
||||||
|
// PerformanceMetrics holds the performance expectations
|
||||||
|
type PerformanceMetrics struct {
|
||||||
|
MinMsPerOpLocal float64
|
||||||
|
MaxMsPerOpLocal float64
|
||||||
|
MinMsPerOpCICD float64
|
||||||
|
MaxMsPerOpCICD float64
|
||||||
|
}
|
||||||
|
|
||||||
|
func BuildApiBlackBoxWithDBState(t TB, sqlFile string, expectedPeerUpdate *server.UpdateMessage) (http.Handler, server.AccountManager, chan struct{}) {
|
||||||
|
store, cleanup, err := store.NewTestStoreFromSQL(context.Background(), sqlFile, t.TempDir())
|
||||||
|
if err != nil {
|
||||||
|
t.Fatalf("Failed to create test store: %v", err)
|
||||||
|
}
|
||||||
|
t.Cleanup(cleanup)
|
||||||
|
|
||||||
|
metrics, err := telemetry.NewDefaultAppMetrics(context.Background())
|
||||||
|
if err != nil {
|
||||||
|
t.Fatalf("Failed to create metrics: %v", err)
|
||||||
|
}
|
||||||
|
|
||||||
|
peersUpdateManager := server.NewPeersUpdateManager(nil)
|
||||||
|
updMsg := peersUpdateManager.CreateChannel(context.Background(), TestPeerId)
|
||||||
|
done := make(chan struct{})
|
||||||
|
go func() {
|
||||||
|
if expectedPeerUpdate != nil {
|
||||||
|
peerShouldReceiveUpdate(t, updMsg, expectedPeerUpdate)
|
||||||
|
} else {
|
||||||
|
peerShouldNotReceiveUpdate(t, updMsg)
|
||||||
|
}
|
||||||
|
close(done)
|
||||||
|
}()
|
||||||
|
|
||||||
|
geoMock := &geolocation.Mock{}
|
||||||
|
validatorMock := server.MocIntegratedValidator{}
|
||||||
|
am, err := server.BuildManager(context.Background(), store, peersUpdateManager, nil, "", "", &activity.InMemoryEventStore{}, geoMock, false, validatorMock, metrics)
|
||||||
|
if err != nil {
|
||||||
|
t.Fatalf("Failed to create manager: %v", err)
|
||||||
|
}
|
||||||
|
|
||||||
|
networksManagerMock := networks.NewManagerMock()
|
||||||
|
resourcesManagerMock := resources.NewManagerMock()
|
||||||
|
routersManagerMock := routers.NewManagerMock()
|
||||||
|
groupsManagerMock := groups.NewManagerMock()
|
||||||
|
apiHandler, err := nbhttp.NewAPIHandler(context.Background(), am, networksManagerMock, resourcesManagerMock, routersManagerMock, groupsManagerMock, geoMock, &jwtclaims.JwtValidatorMock{}, metrics, configs.AuthCfg{}, validatorMock)
|
||||||
|
if err != nil {
|
||||||
|
t.Fatalf("Failed to create API handler: %v", err)
|
||||||
|
}
|
||||||
|
|
||||||
|
return apiHandler, am, done
|
||||||
|
}
|
||||||
|
|
||||||
|
func peerShouldNotReceiveUpdate(t TB, updateMessage <-chan *server.UpdateMessage) {
|
||||||
|
t.Helper()
|
||||||
|
select {
|
||||||
|
case msg := <-updateMessage:
|
||||||
|
t.Errorf("Unexpected message received: %+v", msg)
|
||||||
|
case <-time.After(500 * time.Millisecond):
|
||||||
|
return
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func peerShouldReceiveUpdate(t TB, updateMessage <-chan *server.UpdateMessage, expected *server.UpdateMessage) {
|
||||||
|
t.Helper()
|
||||||
|
|
||||||
|
select {
|
||||||
|
case msg := <-updateMessage:
|
||||||
|
if msg == nil {
|
||||||
|
t.Errorf("Received nil update message, expected valid message")
|
||||||
|
}
|
||||||
|
assert.Equal(t, expected, msg)
|
||||||
|
case <-time.After(500 * time.Millisecond):
|
||||||
|
t.Errorf("Timed out waiting for update message")
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func BuildRequest(t TB, requestBody []byte, requestType, requestPath, user string) *http.Request {
|
||||||
|
t.Helper()
|
||||||
|
|
||||||
|
req := httptest.NewRequest(requestType, requestPath, bytes.NewBuffer(requestBody))
|
||||||
|
req.Header.Set("Authorization", "Bearer "+user)
|
||||||
|
|
||||||
|
return req
|
||||||
|
}
|
||||||
|
|
||||||
|
func ReadResponse(t *testing.T, recorder *httptest.ResponseRecorder, expectedStatus int, expectResponse bool) ([]byte, bool) {
|
||||||
|
t.Helper()
|
||||||
|
|
||||||
|
res := recorder.Result()
|
||||||
|
defer res.Body.Close()
|
||||||
|
|
||||||
|
content, err := io.ReadAll(res.Body)
|
||||||
|
if err != nil {
|
||||||
|
t.Fatalf("Failed to read response body: %v", err)
|
||||||
|
}
|
||||||
|
|
||||||
|
if !expectResponse {
|
||||||
|
return nil, false
|
||||||
|
}
|
||||||
|
|
||||||
|
if status := recorder.Code; status != expectedStatus {
|
||||||
|
t.Fatalf("handler returned wrong status code: got %v want %v, content: %s",
|
||||||
|
status, expectedStatus, string(content))
|
||||||
|
}
|
||||||
|
|
||||||
|
return content, expectedStatus == http.StatusOK
|
||||||
|
}
|
||||||
|
|
||||||
|
func PopulateTestData(b *testing.B, am *server.DefaultAccountManager, peers, groups, users, setupKeys int) {
|
||||||
|
b.Helper()
|
||||||
|
|
||||||
|
ctx := context.Background()
|
||||||
|
account, err := am.GetAccount(ctx, TestAccountId)
|
||||||
|
if err != nil {
|
||||||
|
b.Fatalf("Failed to get account: %v", err)
|
||||||
|
}
|
||||||
|
|
||||||
|
// Create peers
|
||||||
|
for i := 0; i < peers; i++ {
|
||||||
|
peerKey, _ := wgtypes.GeneratePrivateKey()
|
||||||
|
peer := &nbpeer.Peer{
|
||||||
|
ID: fmt.Sprintf("oldpeer-%d", i),
|
||||||
|
DNSLabel: fmt.Sprintf("oldpeer-%d", i),
|
||||||
|
Key: peerKey.PublicKey().String(),
|
||||||
|
IP: net.ParseIP(fmt.Sprintf("100.64.%d.%d", i/256, i%256)),
|
||||||
|
Status: &nbpeer.PeerStatus{},
|
||||||
|
UserID: TestUserId,
|
||||||
|
}
|
||||||
|
account.Peers[peer.ID] = peer
|
||||||
|
}
|
||||||
|
|
||||||
|
// Create users
|
||||||
|
for i := 0; i < users; i++ {
|
||||||
|
user := &types.User{
|
||||||
|
Id: fmt.Sprintf("olduser-%d", i),
|
||||||
|
AccountID: account.Id,
|
||||||
|
Role: types.UserRoleUser,
|
||||||
|
}
|
||||||
|
account.Users[user.Id] = user
|
||||||
|
}
|
||||||
|
|
||||||
|
for i := 0; i < setupKeys; i++ {
|
||||||
|
key := &types.SetupKey{
|
||||||
|
Id: fmt.Sprintf("oldkey-%d", i),
|
||||||
|
AccountID: account.Id,
|
||||||
|
AutoGroups: []string{"someGroupID"},
|
||||||
|
ExpiresAt: time.Now().Add(ExpiresIn * time.Second),
|
||||||
|
Name: NewKeyName + strconv.Itoa(i),
|
||||||
|
Type: "reusable",
|
||||||
|
UsageLimit: 0,
|
||||||
|
}
|
||||||
|
account.SetupKeys[key.Id] = key
|
||||||
|
}
|
||||||
|
|
||||||
|
// Create groups and policies
|
||||||
|
account.Policies = make([]*types.Policy, 0, groups)
|
||||||
|
for i := 0; i < groups; i++ {
|
||||||
|
groupID := fmt.Sprintf("group-%d", i)
|
||||||
|
group := &types.Group{
|
||||||
|
ID: groupID,
|
||||||
|
Name: fmt.Sprintf("Group %d", i),
|
||||||
|
}
|
||||||
|
for j := 0; j < peers/groups; j++ {
|
||||||
|
peerIndex := i*(peers/groups) + j
|
||||||
|
group.Peers = append(group.Peers, fmt.Sprintf("peer-%d", peerIndex))
|
||||||
|
}
|
||||||
|
account.Groups[groupID] = group
|
||||||
|
|
||||||
|
// Create a policy for this group
|
||||||
|
policy := &types.Policy{
|
||||||
|
ID: fmt.Sprintf("policy-%d", i),
|
||||||
|
Name: fmt.Sprintf("Policy for Group %d", i),
|
||||||
|
Enabled: true,
|
||||||
|
Rules: []*types.PolicyRule{
|
||||||
|
{
|
||||||
|
ID: fmt.Sprintf("rule-%d", i),
|
||||||
|
Name: fmt.Sprintf("Rule for Group %d", i),
|
||||||
|
Enabled: true,
|
||||||
|
Sources: []string{groupID},
|
||||||
|
Destinations: []string{groupID},
|
||||||
|
Bidirectional: true,
|
||||||
|
Protocol: types.PolicyRuleProtocolALL,
|
||||||
|
Action: types.PolicyTrafficActionAccept,
|
||||||
|
},
|
||||||
|
},
|
||||||
|
}
|
||||||
|
account.Policies = append(account.Policies, policy)
|
||||||
|
}
|
||||||
|
|
||||||
|
account.PostureChecks = []*posture.Checks{
|
||||||
|
{
|
||||||
|
ID: "PostureChecksAll",
|
||||||
|
Name: "All",
|
||||||
|
Checks: posture.ChecksDefinition{
|
||||||
|
NBVersionCheck: &posture.NBVersionCheck{
|
||||||
|
MinVersion: "0.0.1",
|
||||||
|
},
|
||||||
|
},
|
||||||
|
},
|
||||||
|
}
|
||||||
|
|
||||||
|
err = am.Store.SaveAccount(context.Background(), account)
|
||||||
|
if err != nil {
|
||||||
|
b.Fatalf("Failed to save account: %v", err)
|
||||||
|
}
|
||||||
|
|
||||||
|
}
|
||||||
|
|
||||||
|
func EvaluateBenchmarkResults(b *testing.B, name string, duration time.Duration, perfMetrics PerformanceMetrics, recorder *httptest.ResponseRecorder) {
|
||||||
|
b.Helper()
|
||||||
|
|
||||||
|
if recorder.Code != http.StatusOK {
|
||||||
|
b.Fatalf("Benchmark %s failed: unexpected status code %d", name, recorder.Code)
|
||||||
|
}
|
||||||
|
|
||||||
|
msPerOp := float64(duration.Nanoseconds()) / float64(b.N) / 1e6
|
||||||
|
b.ReportMetric(msPerOp, "ms/op")
|
||||||
|
|
||||||
|
minExpected := perfMetrics.MinMsPerOpLocal
|
||||||
|
maxExpected := perfMetrics.MaxMsPerOpLocal
|
||||||
|
if os.Getenv("CI") == "true" {
|
||||||
|
minExpected = perfMetrics.MinMsPerOpCICD
|
||||||
|
maxExpected = perfMetrics.MaxMsPerOpCICD
|
||||||
|
}
|
||||||
|
|
||||||
|
if msPerOp < minExpected {
|
||||||
|
b.Fatalf("Benchmark %s failed: too fast (%.2f ms/op, minimum %.2f ms/op)", name, msPerOp, minExpected)
|
||||||
|
}
|
||||||
|
|
||||||
|
if msPerOp > maxExpected {
|
||||||
|
b.Fatalf("Benchmark %s failed: too slow (%.2f ms/op, maximum %.2f ms/op)", name, msPerOp, maxExpected)
|
||||||
|
}
|
||||||
|
}
|
@ -7,6 +7,7 @@ import (
|
|||||||
log "github.com/sirupsen/logrus"
|
log "github.com/sirupsen/logrus"
|
||||||
|
|
||||||
"github.com/netbirdio/netbird/management/server/account"
|
"github.com/netbirdio/netbird/management/server/account"
|
||||||
|
nbpeer "github.com/netbirdio/netbird/management/server/peer"
|
||||||
"github.com/netbirdio/netbird/management/server/store"
|
"github.com/netbirdio/netbird/management/server/store"
|
||||||
"github.com/netbirdio/netbird/management/server/types"
|
"github.com/netbirdio/netbird/management/server/types"
|
||||||
)
|
)
|
||||||
@ -78,3 +79,45 @@ func (am *DefaultAccountManager) GroupValidation(ctx context.Context, accountID
|
|||||||
func (am *DefaultAccountManager) GetValidatedPeers(account *types.Account) (map[string]struct{}, error) {
|
func (am *DefaultAccountManager) GetValidatedPeers(account *types.Account) (map[string]struct{}, error) {
|
||||||
return am.integratedPeerValidator.GetValidatedPeers(account.Id, account.Groups, account.Peers, account.Settings.Extra)
|
return am.integratedPeerValidator.GetValidatedPeers(account.Id, account.Groups, account.Peers, account.Settings.Extra)
|
||||||
}
|
}
|
||||||
|
|
||||||
|
type MocIntegratedValidator struct {
|
||||||
|
ValidatePeerFunc func(_ context.Context, update *nbpeer.Peer, peer *nbpeer.Peer, userID string, accountID string, dnsDomain string, peersGroup []string, extraSettings *account.ExtraSettings) (*nbpeer.Peer, bool, error)
|
||||||
|
}
|
||||||
|
|
||||||
|
func (a MocIntegratedValidator) ValidateExtraSettings(_ context.Context, newExtraSettings *account.ExtraSettings, oldExtraSettings *account.ExtraSettings, peers map[string]*nbpeer.Peer, userID string, accountID string) error {
|
||||||
|
return nil
|
||||||
|
}
|
||||||
|
|
||||||
|
func (a MocIntegratedValidator) ValidatePeer(_ context.Context, update *nbpeer.Peer, peer *nbpeer.Peer, userID string, accountID string, dnsDomain string, peersGroup []string, extraSettings *account.ExtraSettings) (*nbpeer.Peer, bool, error) {
|
||||||
|
if a.ValidatePeerFunc != nil {
|
||||||
|
return a.ValidatePeerFunc(context.Background(), update, peer, userID, accountID, dnsDomain, peersGroup, extraSettings)
|
||||||
|
}
|
||||||
|
return update, false, nil
|
||||||
|
}
|
||||||
|
func (a MocIntegratedValidator) GetValidatedPeers(accountID string, groups map[string]*types.Group, peers map[string]*nbpeer.Peer, extraSettings *account.ExtraSettings) (map[string]struct{}, error) {
|
||||||
|
validatedPeers := make(map[string]struct{})
|
||||||
|
for _, peer := range peers {
|
||||||
|
validatedPeers[peer.ID] = struct{}{}
|
||||||
|
}
|
||||||
|
return validatedPeers, nil
|
||||||
|
}
|
||||||
|
|
||||||
|
func (MocIntegratedValidator) PreparePeer(_ context.Context, accountID string, peer *nbpeer.Peer, peersGroup []string, extraSettings *account.ExtraSettings) *nbpeer.Peer {
|
||||||
|
return peer
|
||||||
|
}
|
||||||
|
|
||||||
|
func (MocIntegratedValidator) IsNotValidPeer(_ context.Context, accountID string, peer *nbpeer.Peer, peersGroup []string, extraSettings *account.ExtraSettings) (bool, bool, error) {
|
||||||
|
return false, false, nil
|
||||||
|
}
|
||||||
|
|
||||||
|
func (MocIntegratedValidator) PeerDeleted(_ context.Context, _, _ string) error {
|
||||||
|
return nil
|
||||||
|
}
|
||||||
|
|
||||||
|
func (MocIntegratedValidator) SetPeerInvalidationListener(func(accountID string)) {
|
||||||
|
// just a dummy
|
||||||
|
}
|
||||||
|
|
||||||
|
func (MocIntegratedValidator) Stop(_ context.Context) {
|
||||||
|
// just a dummy
|
||||||
|
}
|
||||||
|
@ -22,6 +22,8 @@ const (
|
|||||||
LastLoginSuffix = "nb_last_login"
|
LastLoginSuffix = "nb_last_login"
|
||||||
// Invited claim indicates that an incoming JWT is from a user that just accepted an invitation
|
// Invited claim indicates that an incoming JWT is from a user that just accepted an invitation
|
||||||
Invited = "nb_invited"
|
Invited = "nb_invited"
|
||||||
|
// IsToken claim indicates that auth type from the user is a token
|
||||||
|
IsToken = "is_token"
|
||||||
)
|
)
|
||||||
|
|
||||||
// ExtractClaims Extract function type
|
// ExtractClaims Extract function type
|
||||||
|
@ -72,15 +72,19 @@ type JSONWebKey struct {
|
|||||||
X5c []string `json:"x5c"`
|
X5c []string `json:"x5c"`
|
||||||
}
|
}
|
||||||
|
|
||||||
// JWTValidator struct to handle token validation and parsing
|
type JWTValidator interface {
|
||||||
type JWTValidator struct {
|
ValidateAndParse(ctx context.Context, token string) (*jwt.Token, error)
|
||||||
|
}
|
||||||
|
|
||||||
|
// jwtValidatorImpl struct to handle token validation and parsing
|
||||||
|
type jwtValidatorImpl struct {
|
||||||
options Options
|
options Options
|
||||||
}
|
}
|
||||||
|
|
||||||
var keyNotFound = errors.New("unable to find appropriate key")
|
var keyNotFound = errors.New("unable to find appropriate key")
|
||||||
|
|
||||||
// NewJWTValidator constructor
|
// NewJWTValidator constructor
|
||||||
func NewJWTValidator(ctx context.Context, issuer string, audienceList []string, keysLocation string, idpSignkeyRefreshEnabled bool) (*JWTValidator, error) {
|
func NewJWTValidator(ctx context.Context, issuer string, audienceList []string, keysLocation string, idpSignkeyRefreshEnabled bool) (JWTValidator, error) {
|
||||||
keys, err := getPemKeys(ctx, keysLocation)
|
keys, err := getPemKeys(ctx, keysLocation)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return nil, err
|
return nil, err
|
||||||
@ -146,13 +150,13 @@ func NewJWTValidator(ctx context.Context, issuer string, audienceList []string,
|
|||||||
options.UserProperty = "user"
|
options.UserProperty = "user"
|
||||||
}
|
}
|
||||||
|
|
||||||
return &JWTValidator{
|
return &jwtValidatorImpl{
|
||||||
options: options,
|
options: options,
|
||||||
}, nil
|
}, nil
|
||||||
}
|
}
|
||||||
|
|
||||||
// ValidateAndParse validates the token and returns the parsed token
|
// ValidateAndParse validates the token and returns the parsed token
|
||||||
func (m *JWTValidator) ValidateAndParse(ctx context.Context, token string) (*jwt.Token, error) {
|
func (m *jwtValidatorImpl) ValidateAndParse(ctx context.Context, token string) (*jwt.Token, error) {
|
||||||
// If the token is empty...
|
// If the token is empty...
|
||||||
if token == "" {
|
if token == "" {
|
||||||
// Check if it was required
|
// Check if it was required
|
||||||
@ -318,3 +322,28 @@ func getMaxAgeFromCacheHeader(ctx context.Context, cacheControl string) int {
|
|||||||
|
|
||||||
return 0
|
return 0
|
||||||
}
|
}
|
||||||
|
|
||||||
|
type JwtValidatorMock struct{}
|
||||||
|
|
||||||
|
func (j *JwtValidatorMock) ValidateAndParse(ctx context.Context, token string) (*jwt.Token, error) {
|
||||||
|
claimMaps := jwt.MapClaims{}
|
||||||
|
|
||||||
|
switch token {
|
||||||
|
case "testUserId", "testAdminId", "testOwnerId", "testServiceUserId", "testServiceAdminId", "blockedUserId":
|
||||||
|
claimMaps[UserIDClaim] = token
|
||||||
|
claimMaps[AccountIDSuffix] = "testAccountId"
|
||||||
|
claimMaps[DomainIDSuffix] = "test.com"
|
||||||
|
claimMaps[DomainCategorySuffix] = "private"
|
||||||
|
case "otherUserId":
|
||||||
|
claimMaps[UserIDClaim] = "otherUserId"
|
||||||
|
claimMaps[AccountIDSuffix] = "otherAccountId"
|
||||||
|
claimMaps[DomainIDSuffix] = "other.com"
|
||||||
|
claimMaps[DomainCategorySuffix] = "private"
|
||||||
|
case "invalidToken":
|
||||||
|
return nil, errors.New("invalid token")
|
||||||
|
}
|
||||||
|
|
||||||
|
jwtToken := jwt.NewWithClaims(jwt.SigningMethodHS256, claimMaps)
|
||||||
|
return jwtToken, nil
|
||||||
|
}
|
||||||
|
|
||||||
|
@ -21,13 +21,10 @@ import (
|
|||||||
"github.com/netbirdio/netbird/encryption"
|
"github.com/netbirdio/netbird/encryption"
|
||||||
mgmtProto "github.com/netbirdio/netbird/management/proto"
|
mgmtProto "github.com/netbirdio/netbird/management/proto"
|
||||||
"github.com/netbirdio/netbird/management/server"
|
"github.com/netbirdio/netbird/management/server"
|
||||||
"github.com/netbirdio/netbird/management/server/account"
|
|
||||||
"github.com/netbirdio/netbird/management/server/activity"
|
"github.com/netbirdio/netbird/management/server/activity"
|
||||||
nbpeer "github.com/netbirdio/netbird/management/server/peer"
|
|
||||||
"github.com/netbirdio/netbird/management/server/settings"
|
"github.com/netbirdio/netbird/management/server/settings"
|
||||||
"github.com/netbirdio/netbird/management/server/store"
|
"github.com/netbirdio/netbird/management/server/store"
|
||||||
"github.com/netbirdio/netbird/management/server/telemetry"
|
"github.com/netbirdio/netbird/management/server/telemetry"
|
||||||
"github.com/netbirdio/netbird/management/server/types"
|
|
||||||
"github.com/netbirdio/netbird/util"
|
"github.com/netbirdio/netbird/util"
|
||||||
)
|
)
|
||||||
|
|
||||||
@ -448,43 +445,6 @@ var _ = Describe("Management service", func() {
|
|||||||
})
|
})
|
||||||
})
|
})
|
||||||
|
|
||||||
type MocIntegratedValidator struct {
|
|
||||||
}
|
|
||||||
|
|
||||||
func (a MocIntegratedValidator) ValidateExtraSettings(_ context.Context, newExtraSettings *account.ExtraSettings, oldExtraSettings *account.ExtraSettings, peers map[string]*nbpeer.Peer, userID string, accountID string) error {
|
|
||||||
return nil
|
|
||||||
}
|
|
||||||
|
|
||||||
func (a MocIntegratedValidator) ValidatePeer(_ context.Context, update *nbpeer.Peer, peer *nbpeer.Peer, userID string, accountID string, dnsDomain string, peersGroup []string, extraSettings *account.ExtraSettings) (*nbpeer.Peer, bool, error) {
|
|
||||||
return update, false, nil
|
|
||||||
}
|
|
||||||
|
|
||||||
func (a MocIntegratedValidator) GetValidatedPeers(accountID string, groups map[string]*types.Group, peers map[string]*nbpeer.Peer, extraSettings *account.ExtraSettings) (map[string]struct{}, error) {
|
|
||||||
validatedPeers := make(map[string]struct{})
|
|
||||||
for p := range peers {
|
|
||||||
validatedPeers[p] = struct{}{}
|
|
||||||
}
|
|
||||||
return validatedPeers, nil
|
|
||||||
}
|
|
||||||
|
|
||||||
func (MocIntegratedValidator) PreparePeer(_ context.Context, accountID string, peer *nbpeer.Peer, peersGroup []string, extraSettings *account.ExtraSettings) *nbpeer.Peer {
|
|
||||||
return peer
|
|
||||||
}
|
|
||||||
|
|
||||||
func (MocIntegratedValidator) IsNotValidPeer(_ context.Context, accountID string, peer *nbpeer.Peer, peersGroup []string, extraSettings *account.ExtraSettings) (bool, bool, error) {
|
|
||||||
return false, false, nil
|
|
||||||
}
|
|
||||||
|
|
||||||
func (MocIntegratedValidator) PeerDeleted(_ context.Context, _, _ string) error {
|
|
||||||
return nil
|
|
||||||
}
|
|
||||||
|
|
||||||
func (MocIntegratedValidator) SetPeerInvalidationListener(func(accountID string)) {
|
|
||||||
|
|
||||||
}
|
|
||||||
|
|
||||||
func (MocIntegratedValidator) Stop(_ context.Context) {}
|
|
||||||
|
|
||||||
func loginPeerWithValidSetupKey(serverPubKey wgtypes.Key, key wgtypes.Key, client mgmtProto.ManagementServiceClient) *mgmtProto.LoginResponse {
|
func loginPeerWithValidSetupKey(serverPubKey wgtypes.Key, key wgtypes.Key, client mgmtProto.ManagementServiceClient) *mgmtProto.LoginResponse {
|
||||||
defer GinkgoRecover()
|
defer GinkgoRecover()
|
||||||
|
|
||||||
@ -547,7 +507,7 @@ func startServer(config *server.Config, dataDir string, testFile string) (*grpc.
|
|||||||
log.Fatalf("failed creating metrics: %v", err)
|
log.Fatalf("failed creating metrics: %v", err)
|
||||||
}
|
}
|
||||||
|
|
||||||
accountManager, err := server.BuildManager(context.Background(), store, peersUpdateManager, nil, "", "netbird.selfhosted", eventStore, nil, false, MocIntegratedValidator{}, metrics)
|
accountManager, err := server.BuildManager(context.Background(), store, peersUpdateManager, nil, "", "netbird.selfhosted", eventStore, nil, false, server.MocIntegratedValidator{}, metrics)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
log.Fatalf("failed creating a manager: %v", err)
|
log.Fatalf("failed creating a manager: %v", err)
|
||||||
}
|
}
|
||||||
|
@ -195,6 +195,10 @@ func (w *Worker) generateProperties(ctx context.Context) properties {
|
|||||||
groups int
|
groups int
|
||||||
routes int
|
routes int
|
||||||
routesWithRGGroups int
|
routesWithRGGroups int
|
||||||
|
networks int
|
||||||
|
networkResources int
|
||||||
|
networkRouters int
|
||||||
|
networkRoutersWithPG int
|
||||||
nameservers int
|
nameservers int
|
||||||
uiClient int
|
uiClient int
|
||||||
version string
|
version string
|
||||||
@ -219,6 +223,16 @@ func (w *Worker) generateProperties(ctx context.Context) properties {
|
|||||||
}
|
}
|
||||||
|
|
||||||
groups += len(account.Groups)
|
groups += len(account.Groups)
|
||||||
|
networks += len(account.Networks)
|
||||||
|
networkResources += len(account.NetworkResources)
|
||||||
|
|
||||||
|
networkRouters += len(account.NetworkRouters)
|
||||||
|
for _, router := range account.NetworkRouters {
|
||||||
|
if len(router.PeerGroups) > 0 {
|
||||||
|
networkRoutersWithPG++
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
routes += len(account.Routes)
|
routes += len(account.Routes)
|
||||||
for _, route := range account.Routes {
|
for _, route := range account.Routes {
|
||||||
if len(route.PeerGroups) > 0 {
|
if len(route.PeerGroups) > 0 {
|
||||||
@ -312,6 +326,10 @@ func (w *Worker) generateProperties(ctx context.Context) properties {
|
|||||||
metricsProperties["rules_with_src_posture_checks"] = rulesWithSrcPostureChecks
|
metricsProperties["rules_with_src_posture_checks"] = rulesWithSrcPostureChecks
|
||||||
metricsProperties["posture_checks"] = postureChecks
|
metricsProperties["posture_checks"] = postureChecks
|
||||||
metricsProperties["groups"] = groups
|
metricsProperties["groups"] = groups
|
||||||
|
metricsProperties["networks"] = networks
|
||||||
|
metricsProperties["network_resources"] = networkResources
|
||||||
|
metricsProperties["network_routers"] = networkRouters
|
||||||
|
metricsProperties["network_routers_with_groups"] = networkRoutersWithPG
|
||||||
metricsProperties["routes"] = routes
|
metricsProperties["routes"] = routes
|
||||||
metricsProperties["routes_with_routing_groups"] = routesWithRGGroups
|
metricsProperties["routes_with_routing_groups"] = routesWithRGGroups
|
||||||
metricsProperties["nameservers"] = nameservers
|
metricsProperties["nameservers"] = nameservers
|
||||||
|
@ -5,6 +5,9 @@ import (
|
|||||||
"testing"
|
"testing"
|
||||||
|
|
||||||
nbdns "github.com/netbirdio/netbird/dns"
|
nbdns "github.com/netbirdio/netbird/dns"
|
||||||
|
resourceTypes "github.com/netbirdio/netbird/management/server/networks/resources/types"
|
||||||
|
routerTypes "github.com/netbirdio/netbird/management/server/networks/routers/types"
|
||||||
|
networkTypes "github.com/netbirdio/netbird/management/server/networks/types"
|
||||||
nbpeer "github.com/netbirdio/netbird/management/server/peer"
|
nbpeer "github.com/netbirdio/netbird/management/server/peer"
|
||||||
"github.com/netbirdio/netbird/management/server/posture"
|
"github.com/netbirdio/netbird/management/server/posture"
|
||||||
"github.com/netbirdio/netbird/management/server/store"
|
"github.com/netbirdio/netbird/management/server/store"
|
||||||
@ -172,6 +175,31 @@ func (mockDatasource) GetAllAccounts(_ context.Context) []*types.Account {
|
|||||||
},
|
},
|
||||||
},
|
},
|
||||||
},
|
},
|
||||||
|
Networks: []*networkTypes.Network{
|
||||||
|
{
|
||||||
|
ID: "1",
|
||||||
|
AccountID: "1",
|
||||||
|
},
|
||||||
|
},
|
||||||
|
NetworkResources: []*resourceTypes.NetworkResource{
|
||||||
|
{
|
||||||
|
ID: "1",
|
||||||
|
AccountID: "1",
|
||||||
|
NetworkID: "1",
|
||||||
|
},
|
||||||
|
{
|
||||||
|
ID: "2",
|
||||||
|
AccountID: "1",
|
||||||
|
NetworkID: "1",
|
||||||
|
},
|
||||||
|
},
|
||||||
|
NetworkRouters: []*routerTypes.NetworkRouter{
|
||||||
|
{
|
||||||
|
ID: "1",
|
||||||
|
AccountID: "1",
|
||||||
|
NetworkID: "1",
|
||||||
|
},
|
||||||
|
},
|
||||||
},
|
},
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
@ -200,6 +228,15 @@ func TestGenerateProperties(t *testing.T) {
|
|||||||
if properties["routes"] != 2 {
|
if properties["routes"] != 2 {
|
||||||
t.Errorf("expected 2 routes, got %d", properties["routes"])
|
t.Errorf("expected 2 routes, got %d", properties["routes"])
|
||||||
}
|
}
|
||||||
|
if properties["networks"] != 1 {
|
||||||
|
t.Errorf("expected 1 networks, got %d", properties["networks"])
|
||||||
|
}
|
||||||
|
if properties["network_resources"] != 2 {
|
||||||
|
t.Errorf("expected 2 network_resources, got %d", properties["network_resources"])
|
||||||
|
}
|
||||||
|
if properties["network_routers"] != 1 {
|
||||||
|
t.Errorf("expected 1 network_routers, got %d", properties["network_routers"])
|
||||||
|
}
|
||||||
if properties["rules"] != 4 {
|
if properties["rules"] != 4 {
|
||||||
t.Errorf("expected 4 rules, got %d", properties["rules"])
|
t.Errorf("expected 4 rules, got %d", properties["rules"])
|
||||||
}
|
}
|
||||||
|
@ -32,6 +32,9 @@ type managerImpl struct {
|
|||||||
routersManager routers.Manager
|
routersManager routers.Manager
|
||||||
}
|
}
|
||||||
|
|
||||||
|
type mockManager struct {
|
||||||
|
}
|
||||||
|
|
||||||
func NewManager(store store.Store, permissionsManager permissions.Manager, resourceManager resources.Manager, routersManager routers.Manager, accountManager s.AccountManager) Manager {
|
func NewManager(store store.Store, permissionsManager permissions.Manager, resourceManager resources.Manager, routersManager routers.Manager, accountManager s.AccountManager) Manager {
|
||||||
return &managerImpl{
|
return &managerImpl{
|
||||||
store: store,
|
store: store,
|
||||||
@ -185,3 +188,27 @@ func (m *managerImpl) DeleteNetwork(ctx context.Context, accountID, userID, netw
|
|||||||
|
|
||||||
return nil
|
return nil
|
||||||
}
|
}
|
||||||
|
|
||||||
|
func NewManagerMock() Manager {
|
||||||
|
return &mockManager{}
|
||||||
|
}
|
||||||
|
|
||||||
|
func (m *mockManager) GetAllNetworks(ctx context.Context, accountID, userID string) ([]*types.Network, error) {
|
||||||
|
return []*types.Network{}, nil
|
||||||
|
}
|
||||||
|
|
||||||
|
func (m *mockManager) CreateNetwork(ctx context.Context, userID string, network *types.Network) (*types.Network, error) {
|
||||||
|
return network, nil
|
||||||
|
}
|
||||||
|
|
||||||
|
func (m *mockManager) GetNetwork(ctx context.Context, accountID, userID, networkID string) (*types.Network, error) {
|
||||||
|
return &types.Network{}, nil
|
||||||
|
}
|
||||||
|
|
||||||
|
func (m *mockManager) UpdateNetwork(ctx context.Context, userID string, network *types.Network) (*types.Network, error) {
|
||||||
|
return network, nil
|
||||||
|
}
|
||||||
|
|
||||||
|
func (m *mockManager) DeleteNetwork(ctx context.Context, accountID, userID, networkID string) error {
|
||||||
|
return nil
|
||||||
|
}
|
||||||
|
@ -34,6 +34,9 @@ type managerImpl struct {
|
|||||||
accountManager s.AccountManager
|
accountManager s.AccountManager
|
||||||
}
|
}
|
||||||
|
|
||||||
|
type mockManager struct {
|
||||||
|
}
|
||||||
|
|
||||||
func NewManager(store store.Store, permissionsManager permissions.Manager, groupsManager groups.Manager, accountManager s.AccountManager) Manager {
|
func NewManager(store store.Store, permissionsManager permissions.Manager, groupsManager groups.Manager, accountManager s.AccountManager) Manager {
|
||||||
return &managerImpl{
|
return &managerImpl{
|
||||||
store: store,
|
store: store,
|
||||||
@ -381,3 +384,39 @@ func (m *managerImpl) DeleteResourceInTransaction(ctx context.Context, transacti
|
|||||||
|
|
||||||
return eventsToStore, nil
|
return eventsToStore, nil
|
||||||
}
|
}
|
||||||
|
|
||||||
|
func NewManagerMock() Manager {
|
||||||
|
return &mockManager{}
|
||||||
|
}
|
||||||
|
|
||||||
|
func (m *mockManager) GetAllResourcesInNetwork(ctx context.Context, accountID, userID, networkID string) ([]*types.NetworkResource, error) {
|
||||||
|
return []*types.NetworkResource{}, nil
|
||||||
|
}
|
||||||
|
|
||||||
|
func (m *mockManager) GetAllResourcesInAccount(ctx context.Context, accountID, userID string) ([]*types.NetworkResource, error) {
|
||||||
|
return []*types.NetworkResource{}, nil
|
||||||
|
}
|
||||||
|
|
||||||
|
func (m *mockManager) GetAllResourceIDsInAccount(ctx context.Context, accountID, userID string) (map[string][]string, error) {
|
||||||
|
return map[string][]string{}, nil
|
||||||
|
}
|
||||||
|
|
||||||
|
func (m *mockManager) CreateResource(ctx context.Context, userID string, resource *types.NetworkResource) (*types.NetworkResource, error) {
|
||||||
|
return &types.NetworkResource{}, nil
|
||||||
|
}
|
||||||
|
|
||||||
|
func (m *mockManager) GetResource(ctx context.Context, accountID, userID, networkID, resourceID string) (*types.NetworkResource, error) {
|
||||||
|
return &types.NetworkResource{}, nil
|
||||||
|
}
|
||||||
|
|
||||||
|
func (m *mockManager) UpdateResource(ctx context.Context, userID string, resource *types.NetworkResource) (*types.NetworkResource, error) {
|
||||||
|
return &types.NetworkResource{}, nil
|
||||||
|
}
|
||||||
|
|
||||||
|
func (m *mockManager) DeleteResource(ctx context.Context, accountID, userID, networkID, resourceID string) error {
|
||||||
|
return nil
|
||||||
|
}
|
||||||
|
|
||||||
|
func (m *mockManager) DeleteResourceInTransaction(ctx context.Context, transaction store.Store, accountID, userID, networkID, resourceID string) ([]func(), error) {
|
||||||
|
return []func(){}, nil
|
||||||
|
}
|
||||||
|
@ -111,6 +111,7 @@ func (n *NetworkResource) ToRoute(peer *nbpeer.Peer, router *routerTypes.Network
|
|||||||
NetID: route.NetID(n.Name),
|
NetID: route.NetID(n.Name),
|
||||||
Description: n.Description,
|
Description: n.Description,
|
||||||
Peer: peer.Key,
|
Peer: peer.Key,
|
||||||
|
PeerID: peer.ID,
|
||||||
PeerGroups: nil,
|
PeerGroups: nil,
|
||||||
Masquerade: router.Masquerade,
|
Masquerade: router.Masquerade,
|
||||||
Metric: router.Metric,
|
Metric: router.Metric,
|
||||||
|
@ -932,11 +932,11 @@ func BenchmarkUpdateAccountPeers(b *testing.B) {
|
|||||||
}{
|
}{
|
||||||
{"Small", 50, 5, 90, 120, 90, 120},
|
{"Small", 50, 5, 90, 120, 90, 120},
|
||||||
{"Medium", 500, 100, 110, 150, 120, 260},
|
{"Medium", 500, 100, 110, 150, 120, 260},
|
||||||
{"Large", 5000, 200, 800, 1390, 2500, 4600},
|
{"Large", 5000, 200, 800, 1700, 2500, 5000},
|
||||||
{"Small single", 50, 10, 90, 120, 90, 120},
|
{"Small single", 50, 10, 90, 120, 90, 120},
|
||||||
{"Medium single", 500, 10, 110, 170, 120, 200},
|
{"Medium single", 500, 10, 110, 170, 120, 200},
|
||||||
{"Large 5", 5000, 15, 1300, 2100, 5000, 7000},
|
{"Large 5", 5000, 15, 1300, 2100, 4900, 7000},
|
||||||
{"Extra Large", 2000, 2000, 1300, 2100, 4000, 6000},
|
{"Extra Large", 2000, 2000, 1300, 2400, 4000, 6400},
|
||||||
}
|
}
|
||||||
|
|
||||||
log.SetOutput(io.Discard)
|
log.SetOutput(io.Discard)
|
||||||
|
@ -74,6 +74,19 @@ func TestAccount_getPeersByPolicy(t *testing.T) {
|
|||||||
"peerH",
|
"peerH",
|
||||||
},
|
},
|
||||||
},
|
},
|
||||||
|
"GroupWorkstations": {
|
||||||
|
ID: "GroupWorkstations",
|
||||||
|
Name: "GroupWorkstations",
|
||||||
|
Peers: []string{
|
||||||
|
"peerB",
|
||||||
|
"peerA",
|
||||||
|
"peerD",
|
||||||
|
"peerE",
|
||||||
|
"peerF",
|
||||||
|
"peerG",
|
||||||
|
"peerH",
|
||||||
|
},
|
||||||
|
},
|
||||||
"GroupSwarm": {
|
"GroupSwarm": {
|
||||||
ID: "GroupSwarm",
|
ID: "GroupSwarm",
|
||||||
Name: "swarm",
|
Name: "swarm",
|
||||||
@ -127,7 +140,7 @@ func TestAccount_getPeersByPolicy(t *testing.T) {
|
|||||||
Action: types.PolicyTrafficActionAccept,
|
Action: types.PolicyTrafficActionAccept,
|
||||||
Sources: []string{
|
Sources: []string{
|
||||||
"GroupSwarm",
|
"GroupSwarm",
|
||||||
"GroupAll",
|
"GroupWorkstations",
|
||||||
},
|
},
|
||||||
Destinations: []string{
|
Destinations: []string{
|
||||||
"GroupSwarm",
|
"GroupSwarm",
|
||||||
@ -159,6 +172,8 @@ func TestAccount_getPeersByPolicy(t *testing.T) {
|
|||||||
assert.Contains(t, peers, account.Peers["peerD"])
|
assert.Contains(t, peers, account.Peers["peerD"])
|
||||||
assert.Contains(t, peers, account.Peers["peerE"])
|
assert.Contains(t, peers, account.Peers["peerE"])
|
||||||
assert.Contains(t, peers, account.Peers["peerF"])
|
assert.Contains(t, peers, account.Peers["peerF"])
|
||||||
|
assert.Contains(t, peers, account.Peers["peerG"])
|
||||||
|
assert.Contains(t, peers, account.Peers["peerH"])
|
||||||
|
|
||||||
epectedFirewallRules := []*types.FirewallRule{
|
epectedFirewallRules := []*types.FirewallRule{
|
||||||
{
|
{
|
||||||
@ -189,21 +204,6 @@ func TestAccount_getPeersByPolicy(t *testing.T) {
|
|||||||
Protocol: "all",
|
Protocol: "all",
|
||||||
Port: "",
|
Port: "",
|
||||||
},
|
},
|
||||||
{
|
|
||||||
PeerIP: "100.65.254.139",
|
|
||||||
Direction: types.FirewallRuleDirectionOUT,
|
|
||||||
Action: "accept",
|
|
||||||
Protocol: "all",
|
|
||||||
Port: "",
|
|
||||||
},
|
|
||||||
{
|
|
||||||
PeerIP: "100.65.254.139",
|
|
||||||
Direction: types.FirewallRuleDirectionIN,
|
|
||||||
Action: "accept",
|
|
||||||
Protocol: "all",
|
|
||||||
Port: "",
|
|
||||||
},
|
|
||||||
|
|
||||||
{
|
{
|
||||||
PeerIP: "100.65.62.5",
|
PeerIP: "100.65.62.5",
|
||||||
Direction: types.FirewallRuleDirectionOUT,
|
Direction: types.FirewallRuleDirectionOUT,
|
||||||
@ -280,10 +280,16 @@ func TestAccount_getPeersByPolicy(t *testing.T) {
|
|||||||
},
|
},
|
||||||
}
|
}
|
||||||
assert.Len(t, firewallRules, len(epectedFirewallRules))
|
assert.Len(t, firewallRules, len(epectedFirewallRules))
|
||||||
slices.SortFunc(epectedFirewallRules, sortFunc())
|
|
||||||
slices.SortFunc(firewallRules, sortFunc())
|
for _, rule := range firewallRules {
|
||||||
for i := range firewallRules {
|
contains := false
|
||||||
assert.Equal(t, epectedFirewallRules[i], firewallRules[i])
|
for _, expectedRule := range epectedFirewallRules {
|
||||||
|
if rule.IsEqual(expectedRule) {
|
||||||
|
contains = true
|
||||||
|
break
|
||||||
|
}
|
||||||
|
}
|
||||||
|
assert.True(t, contains, "rule not found in expected rules %#v", rule)
|
||||||
}
|
}
|
||||||
})
|
})
|
||||||
}
|
}
|
||||||
|
@ -364,7 +364,7 @@ func toProtocolRoute(route *route.Route) *proto.Route {
|
|||||||
}
|
}
|
||||||
|
|
||||||
func toProtocolRoutes(routes []*route.Route) []*proto.Route {
|
func toProtocolRoutes(routes []*route.Route) []*proto.Route {
|
||||||
protoRoutes := make([]*proto.Route, 0)
|
protoRoutes := make([]*proto.Route, 0, len(routes))
|
||||||
for _, r := range routes {
|
for _, r := range routes {
|
||||||
protoRoutes = append(protoRoutes, toProtocolRoute(r))
|
protoRoutes = append(protoRoutes, toProtocolRoute(r))
|
||||||
}
|
}
|
||||||
|
@ -75,7 +75,7 @@ func (am *DefaultAccountManager) CreateSetupKey(ctx context.Context, accountID s
|
|||||||
|
|
||||||
err = am.Store.ExecuteInTransaction(ctx, func(transaction store.Store) error {
|
err = am.Store.ExecuteInTransaction(ctx, func(transaction store.Store) error {
|
||||||
if err = validateSetupKeyAutoGroups(ctx, transaction, accountID, autoGroups); err != nil {
|
if err = validateSetupKeyAutoGroups(ctx, transaction, accountID, autoGroups); err != nil {
|
||||||
return err
|
return status.Errorf(status.InvalidArgument, "invalid auto groups: %v", err)
|
||||||
}
|
}
|
||||||
|
|
||||||
setupKey, plainKey = types.GenerateSetupKey(keyName, keyType, expiresIn, autoGroups, usageLimit, ephemeral)
|
setupKey, plainKey = types.GenerateSetupKey(keyName, keyType, expiresIn, autoGroups, usageLimit, ephemeral)
|
||||||
@ -132,7 +132,7 @@ func (am *DefaultAccountManager) SaveSetupKey(ctx context.Context, accountID str
|
|||||||
|
|
||||||
err = am.Store.ExecuteInTransaction(ctx, func(transaction store.Store) error {
|
err = am.Store.ExecuteInTransaction(ctx, func(transaction store.Store) error {
|
||||||
if err = validateSetupKeyAutoGroups(ctx, transaction, accountID, keyToSave.AutoGroups); err != nil {
|
if err = validateSetupKeyAutoGroups(ctx, transaction, accountID, keyToSave.AutoGroups); err != nil {
|
||||||
return err
|
return status.Errorf(status.InvalidArgument, "invalid auto groups: %v", err)
|
||||||
}
|
}
|
||||||
|
|
||||||
oldKey, err = transaction.GetSetupKeyByID(ctx, store.LockingStrengthShare, accountID, keyToSave.Id)
|
oldKey, err = transaction.GetSetupKeyByID(ctx, store.LockingStrengthShare, accountID, keyToSave.Id)
|
||||||
|
@ -303,55 +303,47 @@ func (a *Account) GetPeerNetworkMap(
|
|||||||
return nm
|
return nm
|
||||||
}
|
}
|
||||||
|
|
||||||
func (a *Account) addNetworksRoutingPeers(networkResourcesRoutes []*route.Route, peer *nbpeer.Peer, peersToConnect []*nbpeer.Peer, expiredPeers []*nbpeer.Peer, isRouter bool, sourcePeers []string) []*nbpeer.Peer {
|
func (a *Account) addNetworksRoutingPeers(
|
||||||
missingPeers := map[string]struct{}{}
|
networkResourcesRoutes []*route.Route,
|
||||||
|
peer *nbpeer.Peer,
|
||||||
|
peersToConnect []*nbpeer.Peer,
|
||||||
|
expiredPeers []*nbpeer.Peer,
|
||||||
|
isRouter bool,
|
||||||
|
sourcePeers map[string]struct{},
|
||||||
|
) []*nbpeer.Peer {
|
||||||
|
|
||||||
|
networkRoutesPeers := make(map[string]struct{}, len(networkResourcesRoutes))
|
||||||
for _, r := range networkResourcesRoutes {
|
for _, r := range networkResourcesRoutes {
|
||||||
if r.Peer == peer.Key {
|
networkRoutesPeers[r.PeerID] = struct{}{}
|
||||||
continue
|
|
||||||
}
|
}
|
||||||
|
|
||||||
missing := true
|
delete(sourcePeers, peer.ID)
|
||||||
for _, p := range slices.Concat(peersToConnect, expiredPeers) {
|
|
||||||
if r.Peer == p.Key {
|
for _, existingPeer := range peersToConnect {
|
||||||
missing = false
|
delete(sourcePeers, existingPeer.ID)
|
||||||
break
|
delete(networkRoutesPeers, existingPeer.ID)
|
||||||
}
|
|
||||||
}
|
|
||||||
if missing {
|
|
||||||
missingPeers[r.Peer] = struct{}{}
|
|
||||||
}
|
}
|
||||||
|
for _, expPeer := range expiredPeers {
|
||||||
|
delete(sourcePeers, expPeer.ID)
|
||||||
|
delete(networkRoutesPeers, expPeer.ID)
|
||||||
}
|
}
|
||||||
|
|
||||||
|
missingPeers := make(map[string]struct{}, len(sourcePeers)+len(networkRoutesPeers))
|
||||||
if isRouter {
|
if isRouter {
|
||||||
for _, s := range sourcePeers {
|
for p := range sourcePeers {
|
||||||
if s == peer.ID {
|
missingPeers[p] = struct{}{}
|
||||||
continue
|
|
||||||
}
|
|
||||||
|
|
||||||
missing := true
|
|
||||||
for _, p := range slices.Concat(peersToConnect, expiredPeers) {
|
|
||||||
if s == p.ID {
|
|
||||||
missing = false
|
|
||||||
break
|
|
||||||
}
|
|
||||||
}
|
|
||||||
if missing {
|
|
||||||
p, ok := a.Peers[s]
|
|
||||||
if ok {
|
|
||||||
missingPeers[p.Key] = struct{}{}
|
|
||||||
}
|
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
for p := range networkRoutesPeers {
|
||||||
|
missingPeers[p] = struct{}{}
|
||||||
}
|
}
|
||||||
|
|
||||||
for p := range missingPeers {
|
for p := range missingPeers {
|
||||||
for _, p2 := range a.Peers {
|
if missingPeer := a.Peers[p]; missingPeer != nil {
|
||||||
if p2.Key == p {
|
peersToConnect = append(peersToConnect, missingPeer)
|
||||||
peersToConnect = append(peersToConnect, p2)
|
|
||||||
break
|
|
||||||
}
|
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
return peersToConnect
|
return peersToConnect
|
||||||
}
|
}
|
||||||
|
|
||||||
@ -1045,14 +1037,9 @@ func (a *Account) connResourcesGenerator(ctx context.Context) (func(*PolicyRule,
|
|||||||
// for destination group peers, call this method with an empty list of sourcePostureChecksIDs
|
// for destination group peers, call this method with an empty list of sourcePostureChecksIDs
|
||||||
func (a *Account) getAllPeersFromGroups(ctx context.Context, groups []string, peerID string, sourcePostureChecksIDs []string, validatedPeersMap map[string]struct{}) ([]*nbpeer.Peer, bool) {
|
func (a *Account) getAllPeersFromGroups(ctx context.Context, groups []string, peerID string, sourcePostureChecksIDs []string, validatedPeersMap map[string]struct{}) ([]*nbpeer.Peer, bool) {
|
||||||
peerInGroups := false
|
peerInGroups := false
|
||||||
filteredPeers := make([]*nbpeer.Peer, 0, len(groups))
|
uniquePeerIDs := a.getUniquePeerIDsFromGroupsIDs(ctx, groups)
|
||||||
for _, g := range groups {
|
filteredPeers := make([]*nbpeer.Peer, 0, len(uniquePeerIDs))
|
||||||
group, ok := a.Groups[g]
|
for _, p := range uniquePeerIDs {
|
||||||
if !ok {
|
|
||||||
continue
|
|
||||||
}
|
|
||||||
|
|
||||||
for _, p := range group.Peers {
|
|
||||||
peer, ok := a.Peers[p]
|
peer, ok := a.Peers[p]
|
||||||
if !ok || peer == nil {
|
if !ok || peer == nil {
|
||||||
continue
|
continue
|
||||||
@ -1075,7 +1062,7 @@ func (a *Account) getAllPeersFromGroups(ctx context.Context, groups []string, pe
|
|||||||
|
|
||||||
filteredPeers = append(filteredPeers, peer)
|
filteredPeers = append(filteredPeers, peer)
|
||||||
}
|
}
|
||||||
}
|
|
||||||
return filteredPeers, peerInGroups
|
return filteredPeers, peerInGroups
|
||||||
}
|
}
|
||||||
|
|
||||||
@ -1151,7 +1138,7 @@ func (a *Account) getRouteFirewallRules(ctx context.Context, peerID string, poli
|
|||||||
continue
|
continue
|
||||||
}
|
}
|
||||||
|
|
||||||
rulePeers := a.getRulePeers(rule, peerID, distributionPeers, validatedPeersMap)
|
rulePeers := a.getRulePeers(rule, policy.SourcePostureChecks, peerID, distributionPeers, validatedPeersMap)
|
||||||
rules := generateRouteFirewallRules(ctx, route, rule, rulePeers, FirewallRuleDirectionIN)
|
rules := generateRouteFirewallRules(ctx, route, rule, rulePeers, FirewallRuleDirectionIN)
|
||||||
fwRules = append(fwRules, rules...)
|
fwRules = append(fwRules, rules...)
|
||||||
}
|
}
|
||||||
@ -1159,7 +1146,7 @@ func (a *Account) getRouteFirewallRules(ctx context.Context, peerID string, poli
|
|||||||
return fwRules
|
return fwRules
|
||||||
}
|
}
|
||||||
|
|
||||||
func (a *Account) getRulePeers(rule *PolicyRule, peerID string, distributionPeers map[string]struct{}, validatedPeersMap map[string]struct{}) []*nbpeer.Peer {
|
func (a *Account) getRulePeers(rule *PolicyRule, postureChecks []string, peerID string, distributionPeers map[string]struct{}, validatedPeersMap map[string]struct{}) []*nbpeer.Peer {
|
||||||
distPeersWithPolicy := make(map[string]struct{})
|
distPeersWithPolicy := make(map[string]struct{})
|
||||||
for _, id := range rule.Sources {
|
for _, id := range rule.Sources {
|
||||||
group := a.Groups[id]
|
group := a.Groups[id]
|
||||||
@ -1173,7 +1160,7 @@ func (a *Account) getRulePeers(rule *PolicyRule, peerID string, distributionPeer
|
|||||||
}
|
}
|
||||||
_, distPeer := distributionPeers[pID]
|
_, distPeer := distributionPeers[pID]
|
||||||
_, valid := validatedPeersMap[pID]
|
_, valid := validatedPeersMap[pID]
|
||||||
if distPeer && valid {
|
if distPeer && valid && a.validatePostureChecksOnPeer(context.Background(), postureChecks, pID) {
|
||||||
distPeersWithPolicy[pID] = struct{}{}
|
distPeersWithPolicy[pID] = struct{}{}
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
@ -1271,7 +1258,11 @@ func (a *Account) GetPeerNetworkResourceFirewallRules(ctx context.Context, peer
|
|||||||
distributionPeers := getPoliciesSourcePeers(resourceAppliedPolicies, a.Groups)
|
distributionPeers := getPoliciesSourcePeers(resourceAppliedPolicies, a.Groups)
|
||||||
|
|
||||||
rules := a.getRouteFirewallRules(ctx, peer.ID, resourceAppliedPolicies, route, validatedPeersMap, distributionPeers)
|
rules := a.getRouteFirewallRules(ctx, peer.ID, resourceAppliedPolicies, route, validatedPeersMap, distributionPeers)
|
||||||
routesFirewallRules = append(routesFirewallRules, rules...)
|
for _, rule := range rules {
|
||||||
|
if len(rule.SourceRanges) > 0 {
|
||||||
|
routesFirewallRules = append(routesFirewallRules, rule)
|
||||||
|
}
|
||||||
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
return routesFirewallRules
|
return routesFirewallRules
|
||||||
@ -1303,10 +1294,10 @@ func (a *Account) GetResourcePoliciesMap() map[string][]*Policy {
|
|||||||
}
|
}
|
||||||
|
|
||||||
// GetNetworkResourcesRoutesToSync returns network routes for syncing with a specific peer and its ACL peers.
|
// GetNetworkResourcesRoutesToSync returns network routes for syncing with a specific peer and its ACL peers.
|
||||||
func (a *Account) GetNetworkResourcesRoutesToSync(ctx context.Context, peerID string, resourcePolicies map[string][]*Policy, routers map[string]map[string]*routerTypes.NetworkRouter) (bool, []*route.Route, []string) {
|
func (a *Account) GetNetworkResourcesRoutesToSync(ctx context.Context, peerID string, resourcePolicies map[string][]*Policy, routers map[string]map[string]*routerTypes.NetworkRouter) (bool, []*route.Route, map[string]struct{}) {
|
||||||
var isRoutingPeer bool
|
var isRoutingPeer bool
|
||||||
var routes []*route.Route
|
var routes []*route.Route
|
||||||
var allSourcePeers []string
|
allSourcePeers := make(map[string]struct{}, len(a.Peers))
|
||||||
|
|
||||||
for _, resource := range a.NetworkResources {
|
for _, resource := range a.NetworkResources {
|
||||||
var addSourcePeers bool
|
var addSourcePeers bool
|
||||||
@ -1319,23 +1310,22 @@ func (a *Account) GetNetworkResourcesRoutesToSync(ctx context.Context, peerID st
|
|||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
|
addedResourceRoute := false
|
||||||
for _, policy := range resourcePolicies[resource.ID] {
|
for _, policy := range resourcePolicies[resource.ID] {
|
||||||
for _, sourceGroup := range policy.SourceGroups() {
|
peers := a.getUniquePeerIDsFromGroupsIDs(ctx, policy.SourceGroups())
|
||||||
group := a.GetGroup(sourceGroup)
|
|
||||||
if group == nil {
|
|
||||||
log.WithContext(ctx).Warnf("policy %s has source group %s that doesn't exist under account %s, will continue map generation without it", policy.ID, sourceGroup, a.Id)
|
|
||||||
continue
|
|
||||||
}
|
|
||||||
|
|
||||||
// routing peer should be able to connect with all source peers
|
|
||||||
if addSourcePeers {
|
if addSourcePeers {
|
||||||
allSourcePeers = append(allSourcePeers, group.Peers...)
|
for _, pID := range a.getPostureValidPeers(peers, policy.SourcePostureChecks) {
|
||||||
} else if slices.Contains(group.Peers, peerID) {
|
allSourcePeers[pID] = struct{}{}
|
||||||
|
}
|
||||||
|
} else if slices.Contains(peers, peerID) && a.validatePostureChecksOnPeer(ctx, policy.SourcePostureChecks, peerID) {
|
||||||
// add routes for the resource if the peer is in the distribution group
|
// add routes for the resource if the peer is in the distribution group
|
||||||
for peerId, router := range networkRoutingPeers {
|
for peerId, router := range networkRoutingPeers {
|
||||||
routes = append(routes, a.getNetworkResourcesRoutes(resource, peerId, router, resourcePolicies)...)
|
routes = append(routes, a.getNetworkResourcesRoutes(resource, peerId, router, resourcePolicies)...)
|
||||||
}
|
}
|
||||||
|
addedResourceRoute = true
|
||||||
}
|
}
|
||||||
|
if addedResourceRoute {
|
||||||
|
break
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
@ -1343,6 +1333,42 @@ func (a *Account) GetNetworkResourcesRoutesToSync(ctx context.Context, peerID st
|
|||||||
return isRoutingPeer, routes, allSourcePeers
|
return isRoutingPeer, routes, allSourcePeers
|
||||||
}
|
}
|
||||||
|
|
||||||
|
func (a *Account) getPostureValidPeers(inputPeers []string, postureChecksIDs []string) []string {
|
||||||
|
var dest []string
|
||||||
|
for _, peerID := range inputPeers {
|
||||||
|
if a.validatePostureChecksOnPeer(context.Background(), postureChecksIDs, peerID) {
|
||||||
|
dest = append(dest, peerID)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
return dest
|
||||||
|
}
|
||||||
|
|
||||||
|
func (a *Account) getUniquePeerIDsFromGroupsIDs(ctx context.Context, groups []string) []string {
|
||||||
|
peerIDs := make(map[string]struct{}, len(groups)) // we expect at least one peer per group as initial capacity
|
||||||
|
for _, groupID := range groups {
|
||||||
|
group := a.GetGroup(groupID)
|
||||||
|
if group == nil {
|
||||||
|
log.WithContext(ctx).Warnf("group %s doesn't exist under account %s, will continue map generation without it", groupID, a.Id)
|
||||||
|
continue
|
||||||
|
}
|
||||||
|
|
||||||
|
if group.IsGroupAll() || len(groups) == 1 {
|
||||||
|
return group.Peers
|
||||||
|
}
|
||||||
|
|
||||||
|
for _, peerID := range group.Peers {
|
||||||
|
peerIDs[peerID] = struct{}{}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
ids := make([]string, 0, len(peerIDs))
|
||||||
|
for peerID := range peerIDs {
|
||||||
|
ids = append(ids, peerID)
|
||||||
|
}
|
||||||
|
|
||||||
|
return ids
|
||||||
|
}
|
||||||
|
|
||||||
// getNetworkResources filters and returns a list of network resources associated with the given network ID.
|
// getNetworkResources filters and returns a list of network resources associated with the given network ID.
|
||||||
func (a *Account) getNetworkResources(networkID string) []*resourceTypes.NetworkResource {
|
func (a *Account) getNetworkResources(networkID string) []*resourceTypes.NetworkResource {
|
||||||
var resources []*resourceTypes.NetworkResource
|
var resources []*resourceTypes.NetworkResource
|
||||||
|
@ -1,14 +1,20 @@
|
|||||||
package types
|
package types
|
||||||
|
|
||||||
import (
|
import (
|
||||||
|
"context"
|
||||||
|
"net"
|
||||||
|
"net/netip"
|
||||||
|
"slices"
|
||||||
"testing"
|
"testing"
|
||||||
|
|
||||||
|
"github.com/stretchr/testify/assert"
|
||||||
"github.com/stretchr/testify/require"
|
"github.com/stretchr/testify/require"
|
||||||
|
|
||||||
resourceTypes "github.com/netbirdio/netbird/management/server/networks/resources/types"
|
resourceTypes "github.com/netbirdio/netbird/management/server/networks/resources/types"
|
||||||
routerTypes "github.com/netbirdio/netbird/management/server/networks/routers/types"
|
routerTypes "github.com/netbirdio/netbird/management/server/networks/routers/types"
|
||||||
networkTypes "github.com/netbirdio/netbird/management/server/networks/types"
|
networkTypes "github.com/netbirdio/netbird/management/server/networks/types"
|
||||||
nbpeer "github.com/netbirdio/netbird/management/server/peer"
|
nbpeer "github.com/netbirdio/netbird/management/server/peer"
|
||||||
|
"github.com/netbirdio/netbird/management/server/posture"
|
||||||
"github.com/netbirdio/netbird/route"
|
"github.com/netbirdio/netbird/route"
|
||||||
)
|
)
|
||||||
|
|
||||||
@ -310,19 +316,19 @@ func Test_GetResourcePoliciesMap(t *testing.T) {
|
|||||||
|
|
||||||
func Test_AddNetworksRoutingPeersAddsMissingPeers(t *testing.T) {
|
func Test_AddNetworksRoutingPeersAddsMissingPeers(t *testing.T) {
|
||||||
account := setupTestAccount()
|
account := setupTestAccount()
|
||||||
peer := &nbpeer.Peer{Key: "peer1"}
|
peer := &nbpeer.Peer{Key: "peer1Key", ID: "peer1"}
|
||||||
networkResourcesRoutes := []*route.Route{
|
networkResourcesRoutes := []*route.Route{
|
||||||
{Peer: "peer2Key"},
|
{Peer: "peer2Key", PeerID: "peer2"},
|
||||||
{Peer: "peer3Key"},
|
{Peer: "peer3Key", PeerID: "peer3"},
|
||||||
}
|
}
|
||||||
peersToConnect := []*nbpeer.Peer{
|
peersToConnect := []*nbpeer.Peer{
|
||||||
{Key: "peer2Key"},
|
{Key: "peer2Key", ID: "peer2"},
|
||||||
}
|
}
|
||||||
expiredPeers := []*nbpeer.Peer{
|
expiredPeers := []*nbpeer.Peer{
|
||||||
{Key: "peer4Key"},
|
{Key: "peer4Key", ID: "peer4"},
|
||||||
}
|
}
|
||||||
|
|
||||||
result := account.addNetworksRoutingPeers(networkResourcesRoutes, peer, peersToConnect, expiredPeers, false, []string{})
|
result := account.addNetworksRoutingPeers(networkResourcesRoutes, peer, peersToConnect, expiredPeers, false, map[string]struct{}{})
|
||||||
require.Len(t, result, 2)
|
require.Len(t, result, 2)
|
||||||
require.Equal(t, "peer2Key", result[0].Key)
|
require.Equal(t, "peer2Key", result[0].Key)
|
||||||
require.Equal(t, "peer3Key", result[1].Key)
|
require.Equal(t, "peer3Key", result[1].Key)
|
||||||
@ -339,7 +345,7 @@ func Test_AddNetworksRoutingPeersIgnoresExistingPeers(t *testing.T) {
|
|||||||
}
|
}
|
||||||
expiredPeers := []*nbpeer.Peer{}
|
expiredPeers := []*nbpeer.Peer{}
|
||||||
|
|
||||||
result := account.addNetworksRoutingPeers(networkResourcesRoutes, peer, peersToConnect, expiredPeers, false, []string{})
|
result := account.addNetworksRoutingPeers(networkResourcesRoutes, peer, peersToConnect, expiredPeers, false, map[string]struct{}{})
|
||||||
require.Len(t, result, 1)
|
require.Len(t, result, 1)
|
||||||
require.Equal(t, "peer2Key", result[0].Key)
|
require.Equal(t, "peer2Key", result[0].Key)
|
||||||
}
|
}
|
||||||
@ -358,7 +364,7 @@ func Test_AddNetworksRoutingPeersAddsExpiredPeers(t *testing.T) {
|
|||||||
{Key: "peer3Key"},
|
{Key: "peer3Key"},
|
||||||
}
|
}
|
||||||
|
|
||||||
result := account.addNetworksRoutingPeers(networkResourcesRoutes, peer, peersToConnect, expiredPeers, false, []string{})
|
result := account.addNetworksRoutingPeers(networkResourcesRoutes, peer, peersToConnect, expiredPeers, false, map[string]struct{}{})
|
||||||
require.Len(t, result, 1)
|
require.Len(t, result, 1)
|
||||||
require.Equal(t, "peer2Key", result[0].Key)
|
require.Equal(t, "peer2Key", result[0].Key)
|
||||||
}
|
}
|
||||||
@ -370,6 +376,382 @@ func Test_AddNetworksRoutingPeersHandlesNoMissingPeers(t *testing.T) {
|
|||||||
peersToConnect := []*nbpeer.Peer{}
|
peersToConnect := []*nbpeer.Peer{}
|
||||||
expiredPeers := []*nbpeer.Peer{}
|
expiredPeers := []*nbpeer.Peer{}
|
||||||
|
|
||||||
result := account.addNetworksRoutingPeers(networkResourcesRoutes, peer, peersToConnect, expiredPeers, false, []string{})
|
result := account.addNetworksRoutingPeers(networkResourcesRoutes, peer, peersToConnect, expiredPeers, false, map[string]struct{}{})
|
||||||
require.Len(t, result, 0)
|
require.Len(t, result, 0)
|
||||||
}
|
}
|
||||||
|
|
||||||
|
const (
|
||||||
|
accID = "accountID"
|
||||||
|
network1ID = "network1ID"
|
||||||
|
group1ID = "group1"
|
||||||
|
accNetResourcePeer1ID = "peer1"
|
||||||
|
accNetResourcePeer2ID = "peer2"
|
||||||
|
accNetResourceRouter1ID = "router1"
|
||||||
|
accNetResource1ID = "resource1ID"
|
||||||
|
accNetResourceRestrictPostureCheckID = "restrictPostureCheck"
|
||||||
|
accNetResourceRelaxedPostureCheckID = "relaxedPostureCheck"
|
||||||
|
accNetResourceLockedPostureCheckID = "lockedPostureCheck"
|
||||||
|
accNetResourceLinuxPostureCheckID = "linuxPostureCheck"
|
||||||
|
)
|
||||||
|
|
||||||
|
var (
|
||||||
|
accNetResourcePeer1IP = net.IP{192, 168, 1, 1}
|
||||||
|
accNetResourcePeer2IP = net.IP{192, 168, 1, 2}
|
||||||
|
accNetResourceRouter1IP = net.IP{192, 168, 1, 3}
|
||||||
|
accNetResourceValidPeers = map[string]struct{}{accNetResourcePeer1ID: {}, accNetResourcePeer2ID: {}}
|
||||||
|
)
|
||||||
|
|
||||||
|
func getBasicAccountsWithResource() *Account {
|
||||||
|
return &Account{
|
||||||
|
Id: accID,
|
||||||
|
Peers: map[string]*nbpeer.Peer{
|
||||||
|
accNetResourcePeer1ID: {
|
||||||
|
ID: accNetResourcePeer1ID,
|
||||||
|
AccountID: accID,
|
||||||
|
Key: "peer1Key",
|
||||||
|
IP: accNetResourcePeer1IP,
|
||||||
|
Meta: nbpeer.PeerSystemMeta{
|
||||||
|
GoOS: "linux",
|
||||||
|
WtVersion: "0.35.1",
|
||||||
|
KernelVersion: "4.4.0",
|
||||||
|
},
|
||||||
|
},
|
||||||
|
accNetResourcePeer2ID: {
|
||||||
|
ID: accNetResourcePeer2ID,
|
||||||
|
AccountID: accID,
|
||||||
|
Key: "peer2Key",
|
||||||
|
IP: accNetResourcePeer2IP,
|
||||||
|
Meta: nbpeer.PeerSystemMeta{
|
||||||
|
GoOS: "windows",
|
||||||
|
WtVersion: "0.34.1",
|
||||||
|
KernelVersion: "4.4.0",
|
||||||
|
},
|
||||||
|
},
|
||||||
|
accNetResourceRouter1ID: {
|
||||||
|
ID: accNetResourceRouter1ID,
|
||||||
|
AccountID: accID,
|
||||||
|
Key: "router1Key",
|
||||||
|
IP: accNetResourceRouter1IP,
|
||||||
|
Meta: nbpeer.PeerSystemMeta{
|
||||||
|
GoOS: "linux",
|
||||||
|
WtVersion: "0.35.1",
|
||||||
|
KernelVersion: "4.4.0",
|
||||||
|
},
|
||||||
|
},
|
||||||
|
},
|
||||||
|
Groups: map[string]*Group{
|
||||||
|
group1ID: {
|
||||||
|
ID: group1ID,
|
||||||
|
Peers: []string{accNetResourcePeer1ID, accNetResourcePeer2ID},
|
||||||
|
},
|
||||||
|
},
|
||||||
|
Networks: []*networkTypes.Network{
|
||||||
|
{
|
||||||
|
ID: network1ID,
|
||||||
|
AccountID: accID,
|
||||||
|
Name: "network1",
|
||||||
|
},
|
||||||
|
},
|
||||||
|
NetworkRouters: []*routerTypes.NetworkRouter{
|
||||||
|
{
|
||||||
|
ID: accNetResourceRouter1ID,
|
||||||
|
NetworkID: network1ID,
|
||||||
|
AccountID: accID,
|
||||||
|
Peer: accNetResourceRouter1ID,
|
||||||
|
PeerGroups: []string{},
|
||||||
|
Masquerade: false,
|
||||||
|
Metric: 100,
|
||||||
|
},
|
||||||
|
},
|
||||||
|
NetworkResources: []*resourceTypes.NetworkResource{
|
||||||
|
{
|
||||||
|
ID: accNetResource1ID,
|
||||||
|
AccountID: accID,
|
||||||
|
NetworkID: network1ID,
|
||||||
|
Address: "10.10.10.0/24",
|
||||||
|
Prefix: netip.MustParsePrefix("10.10.10.0/24"),
|
||||||
|
Type: resourceTypes.NetworkResourceType("subnet"),
|
||||||
|
},
|
||||||
|
},
|
||||||
|
Policies: []*Policy{
|
||||||
|
{
|
||||||
|
ID: "policy1ID",
|
||||||
|
AccountID: accID,
|
||||||
|
Enabled: true,
|
||||||
|
Rules: []*PolicyRule{
|
||||||
|
{
|
||||||
|
ID: "rule1ID",
|
||||||
|
Enabled: true,
|
||||||
|
Sources: []string{group1ID},
|
||||||
|
DestinationResource: Resource{
|
||||||
|
ID: accNetResource1ID,
|
||||||
|
Type: "Host",
|
||||||
|
},
|
||||||
|
Protocol: PolicyRuleProtocolTCP,
|
||||||
|
Ports: []string{"80"},
|
||||||
|
Action: PolicyTrafficActionAccept,
|
||||||
|
},
|
||||||
|
},
|
||||||
|
SourcePostureChecks: nil,
|
||||||
|
},
|
||||||
|
},
|
||||||
|
PostureChecks: []*posture.Checks{
|
||||||
|
{
|
||||||
|
ID: accNetResourceRestrictPostureCheckID,
|
||||||
|
Name: accNetResourceRestrictPostureCheckID,
|
||||||
|
Checks: posture.ChecksDefinition{
|
||||||
|
NBVersionCheck: &posture.NBVersionCheck{
|
||||||
|
MinVersion: "0.35.0",
|
||||||
|
},
|
||||||
|
},
|
||||||
|
},
|
||||||
|
{
|
||||||
|
ID: accNetResourceRelaxedPostureCheckID,
|
||||||
|
Name: accNetResourceRelaxedPostureCheckID,
|
||||||
|
Checks: posture.ChecksDefinition{
|
||||||
|
NBVersionCheck: &posture.NBVersionCheck{
|
||||||
|
MinVersion: "0.0.1",
|
||||||
|
},
|
||||||
|
},
|
||||||
|
},
|
||||||
|
{
|
||||||
|
ID: accNetResourceLockedPostureCheckID,
|
||||||
|
Name: accNetResourceLockedPostureCheckID,
|
||||||
|
Checks: posture.ChecksDefinition{
|
||||||
|
NBVersionCheck: &posture.NBVersionCheck{
|
||||||
|
MinVersion: "7.7.7",
|
||||||
|
},
|
||||||
|
},
|
||||||
|
},
|
||||||
|
{
|
||||||
|
ID: accNetResourceLinuxPostureCheckID,
|
||||||
|
Name: accNetResourceLinuxPostureCheckID,
|
||||||
|
Checks: posture.ChecksDefinition{
|
||||||
|
OSVersionCheck: &posture.OSVersionCheck{
|
||||||
|
Linux: &posture.MinKernelVersionCheck{
|
||||||
|
MinKernelVersion: "0.0.0"},
|
||||||
|
},
|
||||||
|
},
|
||||||
|
},
|
||||||
|
},
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func Test_NetworksNetMapGenWithNoPostureChecks(t *testing.T) {
|
||||||
|
account := getBasicAccountsWithResource()
|
||||||
|
|
||||||
|
// all peers should match the policy
|
||||||
|
|
||||||
|
// validate for peer1
|
||||||
|
isRouter, networkResourcesRoutes, sourcePeers := account.GetNetworkResourcesRoutesToSync(context.Background(), accNetResourcePeer1ID, account.GetResourcePoliciesMap(), account.GetResourceRoutersMap())
|
||||||
|
assert.False(t, isRouter, "expected router status")
|
||||||
|
assert.Len(t, networkResourcesRoutes, 1, "expected network resource route don't match")
|
||||||
|
assert.Len(t, sourcePeers, 0, "expected source peers don't match")
|
||||||
|
|
||||||
|
// validate for peer2
|
||||||
|
isRouter, networkResourcesRoutes, sourcePeers = account.GetNetworkResourcesRoutesToSync(context.Background(), accNetResourcePeer2ID, account.GetResourcePoliciesMap(), account.GetResourceRoutersMap())
|
||||||
|
assert.False(t, isRouter, "expected router status")
|
||||||
|
assert.Len(t, networkResourcesRoutes, 1, "expected network resource route don't match")
|
||||||
|
assert.Len(t, sourcePeers, 0, "expected source peers don't match")
|
||||||
|
|
||||||
|
// validate routes for router1
|
||||||
|
isRouter, networkResourcesRoutes, sourcePeers = account.GetNetworkResourcesRoutesToSync(context.Background(), accNetResourceRouter1ID, account.GetResourcePoliciesMap(), account.GetResourceRoutersMap())
|
||||||
|
assert.True(t, isRouter, "should be router")
|
||||||
|
assert.Len(t, networkResourcesRoutes, 1, "expected network resource route don't match")
|
||||||
|
assert.Len(t, sourcePeers, 2, "expected source peers don't match")
|
||||||
|
assert.NotNil(t, sourcePeers[accNetResourcePeer1ID], "expected source peers don't match")
|
||||||
|
assert.NotNil(t, sourcePeers[accNetResourcePeer2ID], "expected source peers don't match")
|
||||||
|
|
||||||
|
// validate rules for router1
|
||||||
|
rules := account.GetPeerNetworkResourceFirewallRules(context.Background(), account.Peers[accNetResourceRouter1ID], accNetResourceValidPeers, networkResourcesRoutes, account.GetResourcePoliciesMap())
|
||||||
|
assert.Len(t, rules, 1, "expected rules count don't match")
|
||||||
|
assert.Equal(t, uint16(80), rules[0].Port, "should have port 80")
|
||||||
|
assert.Equal(t, "tcp", rules[0].Protocol, "should have protocol tcp")
|
||||||
|
if !slices.Contains(rules[0].SourceRanges, accNetResourcePeer1IP.String()+"/32") {
|
||||||
|
t.Errorf("%s should have source range of peer1 %s", rules[0].SourceRanges, accNetResourcePeer1IP.String())
|
||||||
|
}
|
||||||
|
if !slices.Contains(rules[0].SourceRanges, accNetResourcePeer2IP.String()+"/32") {
|
||||||
|
t.Errorf("%s should have source range of peer2 %s", rules[0].SourceRanges, accNetResourcePeer2IP.String())
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func Test_NetworksNetMapGenWithPostureChecks(t *testing.T) {
|
||||||
|
account := getBasicAccountsWithResource()
|
||||||
|
|
||||||
|
// should allow peer1 to match the policy
|
||||||
|
policy := account.Policies[0]
|
||||||
|
policy.SourcePostureChecks = []string{accNetResourceRestrictPostureCheckID}
|
||||||
|
|
||||||
|
// validate for peer1
|
||||||
|
isRouter, networkResourcesRoutes, sourcePeers := account.GetNetworkResourcesRoutesToSync(context.Background(), accNetResourcePeer1ID, account.GetResourcePoliciesMap(), account.GetResourceRoutersMap())
|
||||||
|
assert.False(t, isRouter, "expected router status")
|
||||||
|
assert.Len(t, networkResourcesRoutes, 1, "expected network resource route don't match")
|
||||||
|
assert.Len(t, sourcePeers, 0, "expected source peers don't match")
|
||||||
|
|
||||||
|
// validate for peer2
|
||||||
|
isRouter, networkResourcesRoutes, sourcePeers = account.GetNetworkResourcesRoutesToSync(context.Background(), accNetResourcePeer2ID, account.GetResourcePoliciesMap(), account.GetResourceRoutersMap())
|
||||||
|
assert.False(t, isRouter, "expected router status")
|
||||||
|
assert.Len(t, networkResourcesRoutes, 0, "expected network resource route don't match")
|
||||||
|
assert.Len(t, sourcePeers, 0, "expected source peers don't match")
|
||||||
|
|
||||||
|
// validate routes for router1
|
||||||
|
isRouter, networkResourcesRoutes, sourcePeers = account.GetNetworkResourcesRoutesToSync(context.Background(), accNetResourceRouter1ID, account.GetResourcePoliciesMap(), account.GetResourceRoutersMap())
|
||||||
|
assert.True(t, isRouter, "should be router")
|
||||||
|
assert.Len(t, networkResourcesRoutes, 1, "expected network resource route don't match")
|
||||||
|
assert.Len(t, sourcePeers, 1, "expected source peers don't match")
|
||||||
|
assert.NotNil(t, sourcePeers[accNetResourcePeer1ID], "expected source peers don't match")
|
||||||
|
|
||||||
|
// validate rules for router1
|
||||||
|
rules := account.GetPeerNetworkResourceFirewallRules(context.Background(), account.Peers[accNetResourceRouter1ID], accNetResourceValidPeers, networkResourcesRoutes, account.GetResourcePoliciesMap())
|
||||||
|
assert.Len(t, rules, 1, "expected rules count don't match")
|
||||||
|
assert.Equal(t, uint16(80), rules[0].Port, "should have port 80")
|
||||||
|
assert.Equal(t, "tcp", rules[0].Protocol, "should have protocol tcp")
|
||||||
|
if !slices.Contains(rules[0].SourceRanges, accNetResourcePeer1IP.String()+"/32") {
|
||||||
|
t.Errorf("%s should have source range of peer1 %s", rules[0].SourceRanges, accNetResourcePeer1IP.String())
|
||||||
|
}
|
||||||
|
if slices.Contains(rules[0].SourceRanges, accNetResourcePeer2IP.String()+"/32") {
|
||||||
|
t.Errorf("%s should not have source range of peer2 %s", rules[0].SourceRanges, accNetResourcePeer2IP.String())
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func Test_NetworksNetMapGenWithNoMatchedPostureChecks(t *testing.T) {
|
||||||
|
account := getBasicAccountsWithResource()
|
||||||
|
|
||||||
|
// should not match any peer
|
||||||
|
policy := account.Policies[0]
|
||||||
|
policy.SourcePostureChecks = []string{accNetResourceLockedPostureCheckID}
|
||||||
|
|
||||||
|
// validate for peer1
|
||||||
|
isRouter, networkResourcesRoutes, sourcePeers := account.GetNetworkResourcesRoutesToSync(context.Background(), accNetResourcePeer1ID, account.GetResourcePoliciesMap(), account.GetResourceRoutersMap())
|
||||||
|
assert.False(t, isRouter, "expected router status")
|
||||||
|
assert.Len(t, networkResourcesRoutes, 0, "expected network resource route don't match")
|
||||||
|
assert.Len(t, sourcePeers, 0, "expected source peers don't match")
|
||||||
|
|
||||||
|
// validate for peer2
|
||||||
|
isRouter, networkResourcesRoutes, sourcePeers = account.GetNetworkResourcesRoutesToSync(context.Background(), accNetResourcePeer2ID, account.GetResourcePoliciesMap(), account.GetResourceRoutersMap())
|
||||||
|
assert.False(t, isRouter, "expected router status")
|
||||||
|
assert.Len(t, networkResourcesRoutes, 0, "expected network resource route don't match")
|
||||||
|
assert.Len(t, sourcePeers, 0, "expected source peers don't match")
|
||||||
|
|
||||||
|
// validate routes for router1
|
||||||
|
isRouter, networkResourcesRoutes, sourcePeers = account.GetNetworkResourcesRoutesToSync(context.Background(), accNetResourceRouter1ID, account.GetResourcePoliciesMap(), account.GetResourceRoutersMap())
|
||||||
|
assert.True(t, isRouter, "should be router")
|
||||||
|
assert.Len(t, networkResourcesRoutes, 1, "expected network resource route don't match")
|
||||||
|
assert.Len(t, sourcePeers, 0, "expected source peers don't match")
|
||||||
|
|
||||||
|
// validate rules for router1
|
||||||
|
rules := account.GetPeerNetworkResourceFirewallRules(context.Background(), account.Peers[accNetResourceRouter1ID], accNetResourceValidPeers, networkResourcesRoutes, account.GetResourcePoliciesMap())
|
||||||
|
assert.Len(t, rules, 0, "expected rules count don't match")
|
||||||
|
}
|
||||||
|
|
||||||
|
func Test_NetworksNetMapGenWithTwoPoliciesAndPostureChecks(t *testing.T) {
|
||||||
|
account := getBasicAccountsWithResource()
|
||||||
|
|
||||||
|
// should allow peer1 to match the policy
|
||||||
|
policy := account.Policies[0]
|
||||||
|
policy.SourcePostureChecks = []string{accNetResourceRestrictPostureCheckID}
|
||||||
|
|
||||||
|
// should allow peer1 and peer2 to match the policy
|
||||||
|
newPolicy := &Policy{
|
||||||
|
ID: "policy2ID",
|
||||||
|
AccountID: accID,
|
||||||
|
Enabled: true,
|
||||||
|
Rules: []*PolicyRule{
|
||||||
|
{
|
||||||
|
ID: "policy2ID",
|
||||||
|
Enabled: true,
|
||||||
|
Sources: []string{group1ID},
|
||||||
|
DestinationResource: Resource{
|
||||||
|
ID: accNetResource1ID,
|
||||||
|
Type: "Host",
|
||||||
|
},
|
||||||
|
Protocol: PolicyRuleProtocolTCP,
|
||||||
|
Ports: []string{"22"},
|
||||||
|
Action: PolicyTrafficActionAccept,
|
||||||
|
},
|
||||||
|
},
|
||||||
|
SourcePostureChecks: []string{accNetResourceRelaxedPostureCheckID},
|
||||||
|
}
|
||||||
|
|
||||||
|
account.Policies = append(account.Policies, newPolicy)
|
||||||
|
|
||||||
|
// validate for peer1
|
||||||
|
isRouter, networkResourcesRoutes, sourcePeers := account.GetNetworkResourcesRoutesToSync(context.Background(), accNetResourcePeer1ID, account.GetResourcePoliciesMap(), account.GetResourceRoutersMap())
|
||||||
|
assert.False(t, isRouter, "expected router status")
|
||||||
|
assert.Len(t, networkResourcesRoutes, 1, "expected network resource route don't match")
|
||||||
|
assert.Len(t, sourcePeers, 0, "expected source peers don't match")
|
||||||
|
|
||||||
|
// validate for peer2
|
||||||
|
isRouter, networkResourcesRoutes, sourcePeers = account.GetNetworkResourcesRoutesToSync(context.Background(), accNetResourcePeer2ID, account.GetResourcePoliciesMap(), account.GetResourceRoutersMap())
|
||||||
|
assert.False(t, isRouter, "expected router status")
|
||||||
|
assert.Len(t, networkResourcesRoutes, 1, "expected network resource route don't match")
|
||||||
|
assert.Len(t, sourcePeers, 0, "expected source peers don't match")
|
||||||
|
|
||||||
|
// validate routes for router1
|
||||||
|
isRouter, networkResourcesRoutes, sourcePeers = account.GetNetworkResourcesRoutesToSync(context.Background(), accNetResourceRouter1ID, account.GetResourcePoliciesMap(), account.GetResourceRoutersMap())
|
||||||
|
assert.True(t, isRouter, "should be router")
|
||||||
|
assert.Len(t, networkResourcesRoutes, 1, "expected network resource route don't match")
|
||||||
|
assert.Len(t, sourcePeers, 2, "expected source peers don't match")
|
||||||
|
assert.NotNil(t, sourcePeers[accNetResourcePeer1ID], "expected source peers don't match")
|
||||||
|
assert.NotNil(t, sourcePeers[accNetResourcePeer2ID], "expected source peers don't match")
|
||||||
|
|
||||||
|
// validate rules for router1
|
||||||
|
rules := account.GetPeerNetworkResourceFirewallRules(context.Background(), account.Peers[accNetResourceRouter1ID], accNetResourceValidPeers, networkResourcesRoutes, account.GetResourcePoliciesMap())
|
||||||
|
assert.Len(t, rules, 2, "expected rules count don't match")
|
||||||
|
assert.Equal(t, uint16(80), rules[0].Port, "should have port 80")
|
||||||
|
assert.Equal(t, "tcp", rules[0].Protocol, "should have protocol tcp")
|
||||||
|
if !slices.Contains(rules[0].SourceRanges, accNetResourcePeer1IP.String()+"/32") {
|
||||||
|
t.Errorf("%s should have source range of peer1 %s", rules[0].SourceRanges, accNetResourcePeer1IP.String())
|
||||||
|
}
|
||||||
|
if slices.Contains(rules[0].SourceRanges, accNetResourcePeer2IP.String()+"/32") {
|
||||||
|
t.Errorf("%s should not have source range of peer2 %s", rules[0].SourceRanges, accNetResourcePeer2IP.String())
|
||||||
|
}
|
||||||
|
|
||||||
|
assert.Equal(t, uint16(22), rules[1].Port, "should have port 22")
|
||||||
|
assert.Equal(t, "tcp", rules[1].Protocol, "should have protocol tcp")
|
||||||
|
if !slices.Contains(rules[1].SourceRanges, accNetResourcePeer1IP.String()+"/32") {
|
||||||
|
t.Errorf("%s should have source range of peer1 %s", rules[1].SourceRanges, accNetResourcePeer1IP.String())
|
||||||
|
}
|
||||||
|
if !slices.Contains(rules[1].SourceRanges, accNetResourcePeer2IP.String()+"/32") {
|
||||||
|
t.Errorf("%s should have source range of peer2 %s", rules[1].SourceRanges, accNetResourcePeer2IP.String())
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func Test_NetworksNetMapGenWithTwoPostureChecks(t *testing.T) {
|
||||||
|
account := getBasicAccountsWithResource()
|
||||||
|
|
||||||
|
// two posture checks should match only the peers that match both checks
|
||||||
|
policy := account.Policies[0]
|
||||||
|
policy.SourcePostureChecks = []string{accNetResourceRelaxedPostureCheckID, accNetResourceLinuxPostureCheckID}
|
||||||
|
|
||||||
|
// validate for peer1
|
||||||
|
isRouter, networkResourcesRoutes, sourcePeers := account.GetNetworkResourcesRoutesToSync(context.Background(), accNetResourcePeer1ID, account.GetResourcePoliciesMap(), account.GetResourceRoutersMap())
|
||||||
|
assert.False(t, isRouter, "expected router status")
|
||||||
|
assert.Len(t, networkResourcesRoutes, 1, "expected network resource route don't match")
|
||||||
|
assert.Len(t, sourcePeers, 0, "expected source peers don't match")
|
||||||
|
|
||||||
|
// validate for peer2
|
||||||
|
isRouter, networkResourcesRoutes, sourcePeers = account.GetNetworkResourcesRoutesToSync(context.Background(), accNetResourcePeer2ID, account.GetResourcePoliciesMap(), account.GetResourceRoutersMap())
|
||||||
|
assert.False(t, isRouter, "expected router status")
|
||||||
|
assert.Len(t, networkResourcesRoutes, 0, "expected network resource route don't match")
|
||||||
|
assert.Len(t, sourcePeers, 0, "expected source peers don't match")
|
||||||
|
|
||||||
|
// validate routes for router1
|
||||||
|
isRouter, networkResourcesRoutes, sourcePeers = account.GetNetworkResourcesRoutesToSync(context.Background(), accNetResourceRouter1ID, account.GetResourcePoliciesMap(), account.GetResourceRoutersMap())
|
||||||
|
assert.True(t, isRouter, "should be router")
|
||||||
|
assert.Len(t, networkResourcesRoutes, 1, "expected network resource route don't match")
|
||||||
|
assert.Len(t, sourcePeers, 1, "expected source peers don't match")
|
||||||
|
assert.NotNil(t, sourcePeers[accNetResourcePeer1ID], "expected source peers don't match")
|
||||||
|
|
||||||
|
// validate rules for router1
|
||||||
|
rules := account.GetPeerNetworkResourceFirewallRules(context.Background(), account.Peers[accNetResourceRouter1ID], accNetResourceValidPeers, networkResourcesRoutes, account.GetResourcePoliciesMap())
|
||||||
|
assert.Len(t, rules, 1, "expected rules count don't match")
|
||||||
|
assert.Equal(t, uint16(80), rules[0].Port, "should have port 80")
|
||||||
|
assert.Equal(t, "tcp", rules[0].Protocol, "should have protocol tcp")
|
||||||
|
if !slices.Contains(rules[0].SourceRanges, accNetResourcePeer1IP.String()+"/32") {
|
||||||
|
t.Errorf("%s should have source range of peer1 %s", rules[0].SourceRanges, accNetResourcePeer1IP.String())
|
||||||
|
}
|
||||||
|
if slices.Contains(rules[0].SourceRanges, accNetResourcePeer2IP.String()+"/32") {
|
||||||
|
t.Errorf("%s should not have source range of peer2 %s", rules[0].SourceRanges, accNetResourcePeer2IP.String())
|
||||||
|
}
|
||||||
|
}
|
||||||
|
@ -35,6 +35,15 @@ type FirewallRule struct {
|
|||||||
Port string
|
Port string
|
||||||
}
|
}
|
||||||
|
|
||||||
|
// IsEqual checks if two firewall rules are equal.
|
||||||
|
func (r *FirewallRule) IsEqual(other *FirewallRule) bool {
|
||||||
|
return r.PeerIP == other.PeerIP &&
|
||||||
|
r.Direction == other.Direction &&
|
||||||
|
r.Action == other.Action &&
|
||||||
|
r.Protocol == other.Protocol &&
|
||||||
|
r.Port == other.Port
|
||||||
|
}
|
||||||
|
|
||||||
// generateRouteFirewallRules generates a list of firewall rules for a given route.
|
// generateRouteFirewallRules generates a list of firewall rules for a given route.
|
||||||
func generateRouteFirewallRules(ctx context.Context, route *nbroute.Route, rule *PolicyRule, groupPeers []*nbpeer.Peer, direction int) []*RouteFirewallRule {
|
func generateRouteFirewallRules(ctx context.Context, route *nbroute.Route, rule *PolicyRule, groupPeers []*nbpeer.Peer, direction int) []*RouteFirewallRule {
|
||||||
rulesExists := make(map[string]struct{})
|
rulesExists := make(map[string]struct{})
|
||||||
|
@ -117,9 +117,20 @@ func (p *Policy) RuleGroups() []string {
|
|||||||
|
|
||||||
// SourceGroups returns a slice of all unique source groups referenced in the policy's rules.
|
// SourceGroups returns a slice of all unique source groups referenced in the policy's rules.
|
||||||
func (p *Policy) SourceGroups() []string {
|
func (p *Policy) SourceGroups() []string {
|
||||||
groups := make([]string, 0)
|
if len(p.Rules) == 1 {
|
||||||
|
return p.Rules[0].Sources
|
||||||
|
}
|
||||||
|
groups := make(map[string]struct{}, len(p.Rules))
|
||||||
for _, rule := range p.Rules {
|
for _, rule := range p.Rules {
|
||||||
groups = append(groups, rule.Sources...)
|
for _, source := range rule.Sources {
|
||||||
|
groups[source] = struct{}{}
|
||||||
}
|
}
|
||||||
return groups
|
}
|
||||||
|
|
||||||
|
groupIDs := make([]string, 0, len(groups))
|
||||||
|
for groupID := range groups {
|
||||||
|
groupIDs = append(groupIDs, groupID)
|
||||||
|
}
|
||||||
|
|
||||||
|
return groupIDs
|
||||||
}
|
}
|
||||||
|
@ -95,6 +95,7 @@ type Route struct {
|
|||||||
NetID NetID
|
NetID NetID
|
||||||
Description string
|
Description string
|
||||||
Peer string
|
Peer string
|
||||||
|
PeerID string `gorm:"-"`
|
||||||
PeerGroups []string `gorm:"serializer:json"`
|
PeerGroups []string `gorm:"serializer:json"`
|
||||||
NetworkType NetworkType
|
NetworkType NetworkType
|
||||||
Masquerade bool
|
Masquerade bool
|
||||||
@ -120,6 +121,7 @@ func (r *Route) Copy() *Route {
|
|||||||
KeepRoute: r.KeepRoute,
|
KeepRoute: r.KeepRoute,
|
||||||
NetworkType: r.NetworkType,
|
NetworkType: r.NetworkType,
|
||||||
Peer: r.Peer,
|
Peer: r.Peer,
|
||||||
|
PeerID: r.PeerID,
|
||||||
PeerGroups: slices.Clone(r.PeerGroups),
|
PeerGroups: slices.Clone(r.PeerGroups),
|
||||||
Metric: r.Metric,
|
Metric: r.Metric,
|
||||||
Masquerade: r.Masquerade,
|
Masquerade: r.Masquerade,
|
||||||
@ -146,6 +148,7 @@ func (r *Route) IsEqual(other *Route) bool {
|
|||||||
other.KeepRoute == r.KeepRoute &&
|
other.KeepRoute == r.KeepRoute &&
|
||||||
other.NetworkType == r.NetworkType &&
|
other.NetworkType == r.NetworkType &&
|
||||||
other.Peer == r.Peer &&
|
other.Peer == r.Peer &&
|
||||||
|
other.PeerID == r.PeerID &&
|
||||||
other.Metric == r.Metric &&
|
other.Metric == r.Metric &&
|
||||||
other.Masquerade == r.Masquerade &&
|
other.Masquerade == r.Masquerade &&
|
||||||
other.Enabled == r.Enabled &&
|
other.Enabled == r.Enabled &&
|
||||||
|
Loading…
x
Reference in New Issue
Block a user