Merge branch 'main' into feature/mysql-support

This commit is contained in:
bcmmbaga 2025-01-02 14:54:14 +03:00
commit a3fe7bea38
No known key found for this signature in database
GPG Key ID: 511EED5C928AD547
30 changed files with 1001 additions and 235 deletions

View File

@ -19,7 +19,7 @@ jobs:
- name: codespell - name: codespell
uses: codespell-project/actions-codespell@v2 uses: codespell-project/actions-codespell@v2
with: with:
ignore_words_list: erro,clienta,hastable,iif,groupd ignore_words_list: erro,clienta,hastable,iif,groupd,testin
skip: go.mod,go.sum skip: go.mod,go.sum
only_warn: 1 only_warn: 1
golangci: golangci:

View File

@ -197,7 +197,7 @@ func (m *Manager) AllowNetbird() error {
} }
_, err := m.AddPeerFiltering( _, err := m.AddPeerFiltering(
net.ParseIP("0.0.0.0"), net.IP{0, 0, 0, 0},
"all", "all",
nil, nil,
nil, nil,

View File

@ -10,7 +10,6 @@ import (
// BaseConnTrack provides common fields and locking for all connection types // BaseConnTrack provides common fields and locking for all connection types
type BaseConnTrack struct { type BaseConnTrack struct {
sync.RWMutex
SourceIP net.IP SourceIP net.IP
DestIP net.IP DestIP net.IP
SourcePort uint16 SourcePort uint16

View File

@ -62,6 +62,7 @@ type TCPConnKey struct {
type TCPConnTrack struct { type TCPConnTrack struct {
BaseConnTrack BaseConnTrack
State TCPState State TCPState
sync.RWMutex
} }
// TCPTracker manages TCP connection states // TCPTracker manages TCP connection states
@ -131,36 +132,8 @@ func (t *TCPTracker) IsValidInbound(srcIP net.IP, dstIP net.IP, srcPort uint16,
return false return false
} }
// Handle new SYN packets
if flags&TCPSyn != 0 && flags&TCPAck == 0 {
key := makeConnKey(dstIP, srcIP, dstPort, srcPort)
t.mutex.Lock()
if _, exists := t.connections[key]; !exists {
// Use preallocated IPs
srcIPCopy := t.ipPool.Get()
dstIPCopy := t.ipPool.Get()
copyIP(srcIPCopy, dstIP)
copyIP(dstIPCopy, srcIP)
conn := &TCPConnTrack{
BaseConnTrack: BaseConnTrack{
SourceIP: srcIPCopy,
DestIP: dstIPCopy,
SourcePort: dstPort,
DestPort: srcPort,
},
State: TCPStateSynReceived,
}
conn.lastSeen.Store(time.Now().UnixNano())
conn.established.Store(false)
t.connections[key] = conn
}
t.mutex.Unlock()
return true
}
// Look up existing connection
key := makeConnKey(dstIP, srcIP, dstPort, srcPort) key := makeConnKey(dstIP, srcIP, dstPort, srcPort)
t.mutex.RLock() t.mutex.RLock()
conn, exists := t.connections[key] conn, exists := t.connections[key]
t.mutex.RUnlock() t.mutex.RUnlock()
@ -172,8 +145,7 @@ func (t *TCPTracker) IsValidInbound(srcIP net.IP, dstIP net.IP, srcPort uint16,
// Handle RST packets // Handle RST packets
if flags&TCPRst != 0 { if flags&TCPRst != 0 {
conn.Lock() conn.Lock()
isEstablished := conn.IsEstablished() if conn.IsEstablished() || conn.State == TCPStateSynSent || conn.State == TCPStateSynReceived {
if isEstablished || conn.State == TCPStateSynSent || conn.State == TCPStateSynReceived {
conn.State = TCPStateClosed conn.State = TCPStateClosed
conn.SetEstablished(false) conn.SetEstablished(false)
conn.Unlock() conn.Unlock()
@ -183,7 +155,6 @@ func (t *TCPTracker) IsValidInbound(srcIP net.IP, dstIP net.IP, srcPort uint16,
return false return false
} }
// Update state
conn.Lock() conn.Lock()
t.updateState(conn, flags, false) t.updateState(conn, flags, false)
conn.UpdateLastSeen() conn.UpdateLastSeen()
@ -306,6 +277,11 @@ func (t *TCPTracker) isValidStateForFlags(state TCPState, flags uint8) bool {
return flags&TCPFin != 0 || flags&TCPAck != 0 return flags&TCPFin != 0 || flags&TCPAck != 0
case TCPStateLastAck: case TCPStateLastAck:
return flags&TCPAck != 0 return flags&TCPAck != 0
case TCPStateClosed:
// Accept retransmitted ACKs in closed state
// This is important because the final ACK might be lost
// and the peer will retransmit their FIN-ACK
return flags&TCPAck != 0
} }
return false return false
} }

View File

@ -125,11 +125,8 @@ func TestTCPStateMachine(t *testing.T) {
valid := tracker.IsValidInbound(dstIP, srcIP, dstPort, srcPort, TCPRst) valid := tracker.IsValidInbound(dstIP, srcIP, dstPort, srcPort, TCPRst)
require.True(t, valid, "RST should be allowed for established connection") require.True(t, valid, "RST should be allowed for established connection")
// Verify connection is closed // Connection is logically dead but we don't enforce blocking subsequent packets
valid = tracker.IsValidInbound(dstIP, srcIP, dstPort, srcPort, TCPPush|TCPAck) // The connection will be cleaned up by timeout
t.Helper()
require.False(t, valid, "Data should be blocked after RST")
}, },
}, },
{ {

View File

@ -68,17 +68,16 @@ func (c *HandlerChain) AddHandler(pattern string, handler dns.Handler, priority
c.mu.Lock() c.mu.Lock()
defer c.mu.Unlock() defer c.mu.Unlock()
pattern = strings.ToLower(dns.Fqdn(pattern))
origPattern := pattern origPattern := pattern
isWildcard := strings.HasPrefix(pattern, "*.") isWildcard := strings.HasPrefix(pattern, "*.")
if isWildcard { if isWildcard {
pattern = pattern[2:] pattern = pattern[2:]
} }
pattern = dns.Fqdn(pattern)
origPattern = dns.Fqdn(origPattern)
// First remove any existing handler with same original pattern and priority // First remove any existing handler with same pattern (case-insensitive) and priority
for i := len(c.handlers) - 1; i >= 0; i-- { for i := len(c.handlers) - 1; i >= 0; i-- {
if c.handlers[i].OrigPattern == origPattern && c.handlers[i].Priority == priority { if strings.EqualFold(c.handlers[i].OrigPattern, origPattern) && c.handlers[i].Priority == priority {
if c.handlers[i].StopHandler != nil { if c.handlers[i].StopHandler != nil {
c.handlers[i].StopHandler.stop() c.handlers[i].StopHandler.stop()
} }
@ -126,10 +125,10 @@ func (c *HandlerChain) RemoveHandler(pattern string, priority int) {
pattern = dns.Fqdn(pattern) pattern = dns.Fqdn(pattern)
// Find and remove handlers matching both original pattern and priority // Find and remove handlers matching both original pattern (case-insensitive) and priority
for i := len(c.handlers) - 1; i >= 0; i-- { for i := len(c.handlers) - 1; i >= 0; i-- {
entry := c.handlers[i] entry := c.handlers[i]
if entry.OrigPattern == pattern && entry.Priority == priority { if strings.EqualFold(entry.OrigPattern, pattern) && entry.Priority == priority {
if entry.StopHandler != nil { if entry.StopHandler != nil {
entry.StopHandler.stop() entry.StopHandler.stop()
} }
@ -144,9 +143,9 @@ func (c *HandlerChain) HasHandlers(pattern string) bool {
c.mu.RLock() c.mu.RLock()
defer c.mu.RUnlock() defer c.mu.RUnlock()
pattern = dns.Fqdn(pattern) pattern = strings.ToLower(dns.Fqdn(pattern))
for _, entry := range c.handlers { for _, entry := range c.handlers {
if entry.Pattern == pattern { if strings.EqualFold(entry.Pattern, pattern) {
return true return true
} }
} }
@ -158,7 +157,7 @@ func (c *HandlerChain) ServeDNS(w dns.ResponseWriter, r *dns.Msg) {
return return
} }
qname := r.Question[0].Name qname := strings.ToLower(r.Question[0].Name)
log.Tracef("handling DNS request for domain=%s", qname) log.Tracef("handling DNS request for domain=%s", qname)
c.mu.RLock() c.mu.RLock()
@ -187,9 +186,9 @@ func (c *HandlerChain) ServeDNS(w dns.ResponseWriter, r *dns.Msg) {
// If handler wants subdomain matching, allow suffix match // If handler wants subdomain matching, allow suffix match
// Otherwise require exact match // Otherwise require exact match
if entry.MatchSubdomains { if entry.MatchSubdomains {
matched = qname == entry.Pattern || strings.HasSuffix(qname, "."+entry.Pattern) matched = strings.EqualFold(qname, entry.Pattern) || strings.HasSuffix(qname, "."+entry.Pattern)
} else { } else {
matched = qname == entry.Pattern matched = strings.EqualFold(qname, entry.Pattern)
} }
} }

View File

@ -507,5 +507,173 @@ func TestHandlerChain_MultiPriorityHandling(t *testing.T) {
// Test 4: Remove last handler // Test 4: Remove last handler
chain.RemoveHandler(testDomain, nbdns.PriorityDefault) chain.RemoveHandler(testDomain, nbdns.PriorityDefault)
assert.False(t, chain.HasHandlers(testDomain)) assert.False(t, chain.HasHandlers(testDomain))
} }
func TestHandlerChain_CaseSensitivity(t *testing.T) {
tests := []struct {
name string
scenario string
addHandlers []struct {
pattern string
priority int
subdomains bool
shouldMatch bool
}
query string
expectedCalls int
}{
{
name: "case insensitive exact match",
scenario: "handler registered lowercase, query uppercase",
addHandlers: []struct {
pattern string
priority int
subdomains bool
shouldMatch bool
}{
{"example.com.", nbdns.PriorityDefault, false, true},
},
query: "EXAMPLE.COM.",
expectedCalls: 1,
},
{
name: "case insensitive wildcard match",
scenario: "handler registered mixed case wildcard, query different case",
addHandlers: []struct {
pattern string
priority int
subdomains bool
shouldMatch bool
}{
{"*.Example.Com.", nbdns.PriorityDefault, false, true},
},
query: "sub.EXAMPLE.COM.",
expectedCalls: 1,
},
{
name: "multiple handlers different case same domain",
scenario: "second handler should replace first despite case difference",
addHandlers: []struct {
pattern string
priority int
subdomains bool
shouldMatch bool
}{
{"EXAMPLE.COM.", nbdns.PriorityDefault, false, false},
{"example.com.", nbdns.PriorityDefault, false, true},
},
query: "ExAmPlE.cOm.",
expectedCalls: 1,
},
{
name: "subdomain matching case insensitive",
scenario: "handler with MatchSubdomains true should match regardless of case",
addHandlers: []struct {
pattern string
priority int
subdomains bool
shouldMatch bool
}{
{"example.com.", nbdns.PriorityDefault, true, true},
},
query: "SUB.EXAMPLE.COM.",
expectedCalls: 1,
},
{
name: "root zone case insensitive",
scenario: "root zone handler should match regardless of case",
addHandlers: []struct {
pattern string
priority int
subdomains bool
shouldMatch bool
}{
{".", nbdns.PriorityDefault, false, true},
},
query: "EXAMPLE.COM.",
expectedCalls: 1,
},
{
name: "multiple handlers different priority",
scenario: "should call higher priority handler despite case differences",
addHandlers: []struct {
pattern string
priority int
subdomains bool
shouldMatch bool
}{
{"EXAMPLE.COM.", nbdns.PriorityDefault, false, false},
{"example.com.", nbdns.PriorityMatchDomain, false, false},
{"Example.Com.", nbdns.PriorityDNSRoute, false, true},
},
query: "example.com.",
expectedCalls: 1,
},
}
for _, tt := range tests {
t.Run(tt.name, func(t *testing.T) {
chain := nbdns.NewHandlerChain()
handlerCalls := make(map[string]bool) // track which patterns were called
// Add handlers according to test case
for _, h := range tt.addHandlers {
var handler dns.Handler
pattern := h.pattern // capture pattern for closure
if h.subdomains {
subHandler := &nbdns.MockSubdomainHandler{
Subdomains: true,
}
if h.shouldMatch {
subHandler.On("ServeDNS", mock.Anything, mock.Anything).Run(func(args mock.Arguments) {
handlerCalls[pattern] = true
w := args.Get(0).(dns.ResponseWriter)
r := args.Get(1).(*dns.Msg)
resp := new(dns.Msg)
resp.SetRcode(r, dns.RcodeSuccess)
assert.NoError(t, w.WriteMsg(resp))
}).Once()
}
handler = subHandler
} else {
mockHandler := &nbdns.MockHandler{}
if h.shouldMatch {
mockHandler.On("ServeDNS", mock.Anything, mock.Anything).Run(func(args mock.Arguments) {
handlerCalls[pattern] = true
w := args.Get(0).(dns.ResponseWriter)
r := args.Get(1).(*dns.Msg)
resp := new(dns.Msg)
resp.SetRcode(r, dns.RcodeSuccess)
assert.NoError(t, w.WriteMsg(resp))
}).Once()
}
handler = mockHandler
}
chain.AddHandler(pattern, handler, h.priority, nil)
}
// Execute request
r := new(dns.Msg)
r.SetQuestion(tt.query, dns.TypeA)
chain.ServeDNS(&mockResponseWriter{}, r)
// Verify each handler was called exactly as expected
for _, h := range tt.addHandlers {
wasCalled := handlerCalls[h.pattern]
assert.Equal(t, h.shouldMatch, wasCalled,
"Handler for pattern %q was %s when it should%s have been",
h.pattern,
map[bool]string{true: "called", false: "not called"}[wasCalled],
map[bool]string{true: "", false: " not"}[wasCalled == h.shouldMatch])
}
// Verify total number of calls
assert.Equal(t, tt.expectedCalls, len(handlerCalls),
"Wrong number of total handler calls")
})
}
}

View File

@ -83,7 +83,7 @@ func (h *Manager) allowDNSFirewall() error {
IsRange: false, IsRange: false,
Values: []int{ListenPort}, Values: []int{ListenPort},
} }
dnsRules, err := h.firewall.AddPeerFiltering(net.ParseIP("0.0.0.0"), firewall.ProtocolUDP, nil, dport, firewall.RuleDirectionIN, firewall.ActionAccept, "", "") dnsRules, err := h.firewall.AddPeerFiltering(net.IP{0, 0, 0, 0}, firewall.ProtocolUDP, nil, dport, firewall.RuleDirectionIN, firewall.ActionAccept, "", "")
if err != nil { if err != nil {
log.Errorf("failed to add allow DNS router rules, err: %v", err) log.Errorf("failed to add allow DNS router rules, err: %v", err)
return err return err

View File

@ -406,13 +406,9 @@ func (e *Engine) Start() error {
e.firewall, err = firewall.NewFirewall(e.wgInterface, e.stateManager) e.firewall, err = firewall.NewFirewall(e.wgInterface, e.stateManager)
if err != nil { if err != nil {
log.Errorf("failed creating firewall manager: %s", err) log.Errorf("failed creating firewall manager: %s", err)
} } else if e.firewall != nil {
if err := e.initFirewall(err); err != nil {
if e.firewall != nil && e.firewall.IsServerRouteSupported() { return err
err = e.routeManager.EnableServerRouter(e.firewall)
if err != nil {
e.close()
return fmt.Errorf("enable server router: %w", err)
} }
} }
@ -455,6 +451,41 @@ func (e *Engine) Start() error {
return nil return nil
} }
func (e *Engine) initFirewall(error) error {
if e.firewall.IsServerRouteSupported() {
if err := e.routeManager.EnableServerRouter(e.firewall); err != nil {
e.close()
return fmt.Errorf("enable server router: %w", err)
}
}
if e.rpManager == nil || !e.config.RosenpassEnabled {
return nil
}
rosenpassPort := e.rpManager.GetAddress().Port
port := manager.Port{Values: []int{rosenpassPort}}
// this rule is static and will be torn down on engine down by the firewall manager
if _, err := e.firewall.AddPeerFiltering(
net.IP{0, 0, 0, 0},
manager.ProtocolUDP,
nil,
&port,
manager.RuleDirectionIN,
manager.ActionAccept,
"",
"",
); err != nil {
log.Errorf("failed to allow rosenpass interface traffic: %v", err)
return nil
}
log.Infof("rosenpass interface traffic allowed on port %d", rosenpassPort)
return nil
}
// modifyPeers updates peers that have been modified (e.g. IP address has been changed). // modifyPeers updates peers that have been modified (e.g. IP address has been changed).
// It closes the existing connection, removes it from the peerConns map, and creates a new one. // It closes the existing connection, removes it from the peerConns map, and creates a new one.
func (e *Engine) modifyPeers(peersUpdate []*mgmProto.RemotePeerConfig) error { func (e *Engine) modifyPeers(peersUpdate []*mgmProto.RemotePeerConfig) error {

View File

@ -139,10 +139,6 @@ func (s *Server) DebugBundle(_ context.Context, req *proto.DebugBundleRequest) (
s.mutex.Lock() s.mutex.Lock()
defer s.mutex.Unlock() defer s.mutex.Unlock()
if s.logFile == "console" {
return nil, fmt.Errorf("log file is set to console, cannot create debug bundle")
}
bundlePath, err := os.CreateTemp("", "netbird.debug.*.zip") bundlePath, err := os.CreateTemp("", "netbird.debug.*.zip")
if err != nil { if err != nil {
return nil, fmt.Errorf("create zip file: %w", err) return nil, fmt.Errorf("create zip file: %w", err)
@ -185,17 +181,7 @@ func (s *Server) createArchive(bundlePath *os.File, req *proto.DebugBundleReques
} }
if req.GetSystemInfo() { if req.GetSystemInfo() {
if err := s.addRoutes(req, anonymizer, archive); err != nil { s.addSystemInfo(req, anonymizer, archive)
log.Errorf("Failed to add routes to debug bundle: %v", err)
}
if err := s.addInterfaces(req, anonymizer, archive); err != nil {
log.Errorf("Failed to add interfaces to debug bundle: %v", err)
}
if err := s.addFirewallRules(req, anonymizer, archive); err != nil {
log.Errorf("Failed to add firewall rules to debug bundle: %v", err)
}
} }
if err := s.addNetworkMap(req, anonymizer, archive); err != nil { if err := s.addNetworkMap(req, anonymizer, archive); err != nil {
@ -206,8 +192,10 @@ func (s *Server) createArchive(bundlePath *os.File, req *proto.DebugBundleReques
log.Errorf("Failed to add state file to debug bundle: %v", err) log.Errorf("Failed to add state file to debug bundle: %v", err)
} }
if err := s.addLogfile(req, anonymizer, archive); err != nil { if s.logFile != "console" {
return fmt.Errorf("add log file: %w", err) if err := s.addLogfile(req, anonymizer, archive); err != nil {
return fmt.Errorf("add log file: %w", err)
}
} }
if err := archive.Close(); err != nil { if err := archive.Close(); err != nil {
@ -216,6 +204,20 @@ func (s *Server) createArchive(bundlePath *os.File, req *proto.DebugBundleReques
return nil return nil
} }
func (s *Server) addSystemInfo(req *proto.DebugBundleRequest, anonymizer *anonymize.Anonymizer, archive *zip.Writer) {
if err := s.addRoutes(req, anonymizer, archive); err != nil {
log.Errorf("Failed to add routes to debug bundle: %v", err)
}
if err := s.addInterfaces(req, anonymizer, archive); err != nil {
log.Errorf("Failed to add interfaces to debug bundle: %v", err)
}
if err := s.addFirewallRules(req, anonymizer, archive); err != nil {
log.Errorf("Failed to add firewall rules to debug bundle: %v", err)
}
}
func (s *Server) addReadme(req *proto.DebugBundleRequest, archive *zip.Writer) error { func (s *Server) addReadme(req *proto.DebugBundleRequest, archive *zip.Writer) error {
if req.GetAnonymize() { if req.GetAnonymize() {
readmeReader := strings.NewReader(readmeContent) readmeReader := strings.NewReader(readmeContent)

View File

@ -1251,6 +1251,12 @@ func (am *DefaultAccountManager) GetAccountIDFromToken(ctx context.Context, clai
// syncJWTGroups processes the JWT groups for a user, updates the account based on the groups, // syncJWTGroups processes the JWT groups for a user, updates the account based on the groups,
// and propagates changes to peers if group propagation is enabled. // and propagates changes to peers if group propagation is enabled.
func (am *DefaultAccountManager) syncJWTGroups(ctx context.Context, accountID string, claims jwtclaims.AuthorizationClaims) error { func (am *DefaultAccountManager) syncJWTGroups(ctx context.Context, accountID string, claims jwtclaims.AuthorizationClaims) error {
if claim, exists := claims.Raw[jwtclaims.IsToken]; exists {
if isToken, ok := claim.(bool); ok && isToken {
return nil
}
}
settings, err := am.Store.GetAccountSettings(ctx, store.LockingStrengthShare, accountID) settings, err := am.Store.GetAccountSettings(ctx, store.LockingStrengthShare, accountID)
if err != nil { if err != nil {
return err return err

View File

@ -2730,6 +2730,19 @@ func TestAccount_SetJWTGroups(t *testing.T) {
assert.NoError(t, manager.Store.SaveAccount(context.Background(), account), "unable to save account") assert.NoError(t, manager.Store.SaveAccount(context.Background(), account), "unable to save account")
t.Run("skip sync for token auth type", func(t *testing.T) {
claims := jwtclaims.AuthorizationClaims{
UserId: "user1",
Raw: jwt.MapClaims{"groups": []interface{}{"group3"}, "is_token": true},
}
err = manager.syncJWTGroups(context.Background(), "accountID", claims)
assert.NoError(t, err, "unable to sync jwt groups")
user, err := manager.Store.GetUserByUserID(context.Background(), store.LockingStrengthShare, "user1")
assert.NoError(t, err, "unable to get user")
assert.Len(t, user.AutoGroups, 0, "JWT groups should not be synced")
})
t.Run("empty jwt groups", func(t *testing.T) { t.Run("empty jwt groups", func(t *testing.T) {
claims := jwtclaims.AuthorizationClaims{ claims := jwtclaims.AuthorizationClaims{
UserId: "user1", UserId: "user1",
@ -2823,7 +2836,7 @@ func TestAccount_SetJWTGroups(t *testing.T) {
assert.Len(t, user.AutoGroups, 1, "new group should be added") assert.Len(t, user.AutoGroups, 1, "new group should be added")
}) })
t.Run("remove all JWT groups", func(t *testing.T) { t.Run("remove all JWT groups when list is empty", func(t *testing.T) {
claims := jwtclaims.AuthorizationClaims{ claims := jwtclaims.AuthorizationClaims{
UserId: "user1", UserId: "user1",
Raw: jwt.MapClaims{"groups": []interface{}{}}, Raw: jwt.MapClaims{"groups": []interface{}{}},
@ -2834,7 +2847,20 @@ func TestAccount_SetJWTGroups(t *testing.T) {
user, err := manager.Store.GetUserByUserID(context.Background(), store.LockingStrengthShare, "user1") user, err := manager.Store.GetUserByUserID(context.Background(), store.LockingStrengthShare, "user1")
assert.NoError(t, err, "unable to get user") assert.NoError(t, err, "unable to get user")
assert.Len(t, user.AutoGroups, 1, "only non-JWT groups should remain") assert.Len(t, user.AutoGroups, 1, "only non-JWT groups should remain")
assert.Contains(t, user.AutoGroups, "group1", " group1 should still be present") assert.Contains(t, user.AutoGroups, "group1", "group1 should still be present")
})
t.Run("remove all JWT groups when claim does not exist", func(t *testing.T) {
claims := jwtclaims.AuthorizationClaims{
UserId: "user2",
Raw: jwt.MapClaims{},
}
err = manager.syncJWTGroups(context.Background(), "accountID", claims)
assert.NoError(t, err, "unable to sync jwt groups")
user, err := manager.Store.GetUserByUserID(context.Background(), store.LockingStrengthShare, "user2")
assert.NoError(t, err, "unable to get user")
assert.Len(t, user.AutoGroups, 0, "all JWT groups should be removed")
}) })
} }
@ -3038,9 +3064,9 @@ func BenchmarkSyncAndMarkPeer(b *testing.B) {
minMsPerOpCICD float64 minMsPerOpCICD float64
maxMsPerOpCICD float64 maxMsPerOpCICD float64
}{ }{
{"Small", 50, 5, 1, 3, 3, 10}, {"Small", 50, 5, 1, 3, 3, 11},
{"Medium", 500, 100, 7, 13, 10, 70}, {"Medium", 500, 100, 7, 13, 10, 70},
{"Large", 5000, 200, 65, 80, 60, 200}, {"Large", 5000, 200, 65, 80, 60, 220},
{"Small single", 50, 10, 1, 3, 3, 70}, {"Small single", 50, 10, 1, 3, 3, 70},
{"Medium single", 500, 10, 7, 13, 10, 26}, {"Medium single", 500, 10, 7, 13, 10, 26},
{"Large 5", 5000, 15, 65, 80, 60, 200}, {"Large 5", 5000, 15, 65, 80, 60, 200},
@ -3180,7 +3206,7 @@ func BenchmarkLoginPeer_NewPeer(b *testing.B) {
maxMsPerOpCICD float64 maxMsPerOpCICD float64
}{ }{
{"Small", 50, 5, 107, 120, 107, 160}, {"Small", 50, 5, 107, 120, 107, 160},
{"Medium", 500, 100, 105, 140, 105, 190}, {"Medium", 500, 100, 105, 140, 105, 220},
{"Large", 5000, 200, 180, 220, 180, 350}, {"Large", 5000, 200, 180, 220, 180, 350},
{"Small single", 50, 10, 107, 120, 105, 160}, {"Small single", 50, 10, 107, 120, 105, 160},
{"Medium single", 500, 10, 105, 140, 105, 170}, {"Medium single", 500, 10, 105, 140, 105, 170},

View File

@ -474,6 +474,10 @@ func validateDeleteGroup(ctx context.Context, transaction store.Store, group *ty
return status.Errorf(status.InvalidArgument, "deleting group ALL is not allowed") return status.Errorf(status.InvalidArgument, "deleting group ALL is not allowed")
} }
if len(group.Resources) > 0 {
return &GroupLinkError{"network resource", group.Resources[0].ID}
}
if isLinked, linkedRoute := isGroupLinkedToRoute(ctx, transaction, group.AccountID, group.ID); isLinked { if isLinked, linkedRoute := isGroupLinkedToRoute(ctx, transaction, group.AccountID, group.ID); isLinked {
return &GroupLinkError{"route", string(linkedRoute.NetID)} return &GroupLinkError{"route", string(linkedRoute.NetID)}
} }
@ -529,7 +533,10 @@ func isGroupLinkedToRoute(ctx context.Context, transaction store.Store, accountI
} }
for _, r := range routes { for _, r := range routes {
if slices.Contains(r.Groups, groupID) || slices.Contains(r.PeerGroups, groupID) { isLinked := slices.Contains(r.Groups, groupID) ||
slices.Contains(r.PeerGroups, groupID) ||
slices.Contains(r.AccessControlGroups, groupID)
if isLinked {
return true, r return true, r
} }
} }

View File

@ -725,10 +725,6 @@ components:
PolicyRuleMinimum: PolicyRuleMinimum:
type: object type: object
properties: properties:
id:
description: Policy rule ID
type: string
example: ch8i4ug6lnn4g9hqv7mg
name: name:
description: Policy rule name identifier description: Policy rule name identifier
type: string type: string
@ -790,6 +786,31 @@ components:
- end - end
PolicyRuleUpdate: PolicyRuleUpdate:
allOf:
- $ref: '#/components/schemas/PolicyRuleMinimum'
- type: object
properties:
id:
description: Policy rule ID
type: string
example: ch8i4ug6lnn4g9hqv7mg
sources:
description: Policy rule source group IDs
type: array
items:
type: string
example: "ch8i4ug6lnn4g9hqv797"
destinations:
description: Policy rule destination group IDs
type: array
items:
type: string
example: "ch8i4ug6lnn4g9h7v7m0"
required:
- sources
- destinations
PolicyRuleCreate:
allOf: allOf:
- $ref: '#/components/schemas/PolicyRuleMinimum' - $ref: '#/components/schemas/PolicyRuleMinimum'
- type: object - type: object
@ -817,6 +838,10 @@ components:
- $ref: '#/components/schemas/PolicyRuleMinimum' - $ref: '#/components/schemas/PolicyRuleMinimum'
- type: object - type: object
properties: properties:
id:
description: Policy rule ID
type: string
example: ch8i4ug6lnn4g9hqv7mg
sources: sources:
description: Policy rule source group IDs description: Policy rule source group IDs
type: array type: array
@ -836,10 +861,6 @@ components:
PolicyMinimum: PolicyMinimum:
type: object type: object
properties: properties:
id:
description: Policy ID
type: string
example: ch8i4ug6lnn4g9hqv7mg
name: name:
description: Policy name identifier description: Policy name identifier
type: string type: string
@ -854,7 +875,6 @@ components:
example: true example: true
required: required:
- name - name
- description
- enabled - enabled
PolicyUpdate: PolicyUpdate:
allOf: allOf:
@ -874,11 +894,33 @@ components:
$ref: '#/components/schemas/PolicyRuleUpdate' $ref: '#/components/schemas/PolicyRuleUpdate'
required: required:
- rules - rules
PolicyCreate:
allOf:
- $ref: '#/components/schemas/PolicyMinimum'
- type: object
properties:
source_posture_checks:
description: Posture checks ID's applied to policy source groups
type: array
items:
type: string
example: "chacdk86lnnboviihd70"
rules:
description: Policy rule object for policy UI editor
type: array
items:
$ref: '#/components/schemas/PolicyRuleUpdate'
required:
- rules
Policy: Policy:
allOf: allOf:
- $ref: '#/components/schemas/PolicyMinimum' - $ref: '#/components/schemas/PolicyMinimum'
- type: object - type: object
properties: properties:
id:
description: Policy ID
type: string
example: ch8i4ug6lnn4g9hqv7mg
source_posture_checks: source_posture_checks:
description: Posture checks ID's applied to policy source groups description: Posture checks ID's applied to policy source groups
type: array type: array
@ -2463,7 +2505,7 @@ paths:
content: content:
'application/json': 'application/json':
schema: schema:
$ref: '#/components/schemas/PolicyUpdate' $ref: '#/components/schemas/PolicyCreate'
responses: responses:
'200': '200':
description: A Policy object description: A Policy object

View File

@ -879,7 +879,7 @@ type PersonalAccessTokenRequest struct {
// Policy defines model for Policy. // Policy defines model for Policy.
type Policy struct { type Policy struct {
// Description Policy friendly description // Description Policy friendly description
Description string `json:"description"` Description *string `json:"description,omitempty"`
// Enabled Policy status // Enabled Policy status
Enabled bool `json:"enabled"` Enabled bool `json:"enabled"`
@ -897,16 +897,31 @@ type Policy struct {
SourcePostureChecks []string `json:"source_posture_checks"` SourcePostureChecks []string `json:"source_posture_checks"`
} }
// PolicyMinimum defines model for PolicyMinimum. // PolicyCreate defines model for PolicyCreate.
type PolicyMinimum struct { type PolicyCreate struct {
// Description Policy friendly description // Description Policy friendly description
Description string `json:"description"` Description *string `json:"description,omitempty"`
// Enabled Policy status // Enabled Policy status
Enabled bool `json:"enabled"` Enabled bool `json:"enabled"`
// Id Policy ID // Name Policy name identifier
Id *string `json:"id,omitempty"` Name string `json:"name"`
// Rules Policy rule object for policy UI editor
Rules []PolicyRuleUpdate `json:"rules"`
// SourcePostureChecks Posture checks ID's applied to policy source groups
SourcePostureChecks *[]string `json:"source_posture_checks,omitempty"`
}
// PolicyMinimum defines model for PolicyMinimum.
type PolicyMinimum struct {
// Description Policy friendly description
Description *string `json:"description,omitempty"`
// Enabled Policy status
Enabled bool `json:"enabled"`
// Name Policy name identifier // Name Policy name identifier
Name string `json:"name"` Name string `json:"name"`
@ -970,9 +985,6 @@ type PolicyRuleMinimum struct {
// Enabled Policy rule status // Enabled Policy rule status
Enabled bool `json:"enabled"` Enabled bool `json:"enabled"`
// Id Policy rule ID
Id *string `json:"id,omitempty"`
// Name Policy rule name identifier // Name Policy rule name identifier
Name string `json:"name"` Name string `json:"name"`
@ -1039,14 +1051,11 @@ type PolicyRuleUpdateProtocol string
// PolicyUpdate defines model for PolicyUpdate. // PolicyUpdate defines model for PolicyUpdate.
type PolicyUpdate struct { type PolicyUpdate struct {
// Description Policy friendly description // Description Policy friendly description
Description string `json:"description"` Description *string `json:"description,omitempty"`
// Enabled Policy status // Enabled Policy status
Enabled bool `json:"enabled"` Enabled bool `json:"enabled"`
// Id Policy ID
Id *string `json:"id,omitempty"`
// Name Policy name identifier // Name Policy name identifier
Name string `json:"name"` Name string `json:"name"`
@ -1473,7 +1482,7 @@ type PutApiPeersPeerIdJSONRequestBody = PeerRequest
type PostApiPoliciesJSONRequestBody = PolicyUpdate type PostApiPoliciesJSONRequestBody = PolicyUpdate
// PutApiPoliciesPolicyIdJSONRequestBody defines body for PutApiPoliciesPolicyId for application/json ContentType. // PutApiPoliciesPolicyIdJSONRequestBody defines body for PutApiPoliciesPolicyId for application/json ContentType.
type PutApiPoliciesPolicyIdJSONRequestBody = PolicyUpdate type PutApiPoliciesPolicyIdJSONRequestBody = PolicyCreate
// PostApiPostureChecksJSONRequestBody defines body for PostApiPostureChecks for application/json ContentType. // PostApiPostureChecksJSONRequestBody defines body for PostApiPostureChecks for application/json ContentType.
type PostApiPostureChecksJSONRequestBody = PostureCheckUpdate type PostApiPostureChecksJSONRequestBody = PostureCheckUpdate

View File

@ -133,16 +133,21 @@ func (h *handler) savePolicy(w http.ResponseWriter, r *http.Request, accountID s
return return
} }
description := ""
if req.Description != nil {
description = *req.Description
}
policy := &types.Policy{ policy := &types.Policy{
ID: policyID, ID: policyID,
AccountID: accountID, AccountID: accountID,
Name: req.Name, Name: req.Name,
Enabled: req.Enabled, Enabled: req.Enabled,
Description: req.Description, Description: description,
} }
for _, rule := range req.Rules { for _, rule := range req.Rules {
var ruleID string var ruleID string
if rule.Id != nil { if rule.Id != nil && policyID != "" {
ruleID = *rule.Id ruleID = *rule.Id
} }
@ -370,7 +375,7 @@ func toPolicyResponse(groups []*types.Group, policy *types.Policy) *api.Policy {
ap := &api.Policy{ ap := &api.Policy{
Id: &policy.ID, Id: &policy.ID,
Name: policy.Name, Name: policy.Name,
Description: policy.Description, Description: &policy.Description,
Enabled: policy.Enabled, Enabled: policy.Enabled,
SourcePostureChecks: policy.SourcePostureChecks, SourcePostureChecks: policy.SourcePostureChecks,
} }

View File

@ -154,6 +154,7 @@ func TestPoliciesGetPolicy(t *testing.T) {
func TestPoliciesWritePolicy(t *testing.T) { func TestPoliciesWritePolicy(t *testing.T) {
str := func(s string) *string { return &s } str := func(s string) *string { return &s }
emptyString := ""
tt := []struct { tt := []struct {
name string name string
expectedStatus int expectedStatus int
@ -184,8 +185,9 @@ func TestPoliciesWritePolicy(t *testing.T) {
expectedStatus: http.StatusOK, expectedStatus: http.StatusOK,
expectedBody: true, expectedBody: true,
expectedPolicy: &api.Policy{ expectedPolicy: &api.Policy{
Id: str("id-was-set"), Id: str("id-was-set"),
Name: "Default POSTed Policy", Name: "Default POSTed Policy",
Description: &emptyString,
Rules: []api.PolicyRule{ Rules: []api.PolicyRule{
{ {
Id: str("id-was-set"), Id: str("id-was-set"),
@ -232,8 +234,9 @@ func TestPoliciesWritePolicy(t *testing.T) {
expectedStatus: http.StatusOK, expectedStatus: http.StatusOK,
expectedBody: true, expectedBody: true,
expectedPolicy: &api.Policy{ expectedPolicy: &api.Policy{
Id: str("id-existed"), Id: str("id-existed"),
Name: "Default POSTed Policy", Name: "Default POSTed Policy",
Description: &emptyString,
Rules: []api.PolicyRule{ Rules: []api.PolicyRule{
{ {
Id: str("id-existed"), Id: str("id-existed"),

View File

@ -175,6 +175,7 @@ func (m *AuthMiddleware) checkPATFromRequest(w http.ResponseWriter, r *http.Requ
claimMaps[m.audience+jwtclaims.AccountIDSuffix] = account.Id claimMaps[m.audience+jwtclaims.AccountIDSuffix] = account.Id
claimMaps[m.audience+jwtclaims.DomainIDSuffix] = account.Domain claimMaps[m.audience+jwtclaims.DomainIDSuffix] = account.Domain
claimMaps[m.audience+jwtclaims.DomainCategorySuffix] = account.DomainCategory claimMaps[m.audience+jwtclaims.DomainCategorySuffix] = account.DomainCategory
claimMaps[jwtclaims.IsToken] = true
jwtToken := jwt.NewWithClaims(jwt.SigningMethodHS256, claimMaps) jwtToken := jwt.NewWithClaims(jwt.SigningMethodHS256, claimMaps)
newRequest := r.WithContext(context.WithValue(r.Context(), jwtclaims.TokenUserProperty, jwtToken)) //nolint newRequest := r.WithContext(context.WithValue(r.Context(), jwtclaims.TokenUserProperty, jwtToken)) //nolint
// Update the current request with the new context information. // Update the current request with the new context information.

View File

@ -22,6 +22,8 @@ const (
LastLoginSuffix = "nb_last_login" LastLoginSuffix = "nb_last_login"
// Invited claim indicates that an incoming JWT is from a user that just accepted an invitation // Invited claim indicates that an incoming JWT is from a user that just accepted an invitation
Invited = "nb_invited" Invited = "nb_invited"
// IsToken claim indicates that auth type from the user is a token
IsToken = "is_token"
) )
// ExtractClaims Extract function type // ExtractClaims Extract function type

View File

@ -195,6 +195,10 @@ func (w *Worker) generateProperties(ctx context.Context) properties {
groups int groups int
routes int routes int
routesWithRGGroups int routesWithRGGroups int
networks int
networkResources int
networkRouters int
networkRoutersWithPG int
nameservers int nameservers int
uiClient int uiClient int
version string version string
@ -219,6 +223,16 @@ func (w *Worker) generateProperties(ctx context.Context) properties {
} }
groups += len(account.Groups) groups += len(account.Groups)
networks += len(account.Networks)
networkResources += len(account.NetworkResources)
networkRouters += len(account.NetworkRouters)
for _, router := range account.NetworkRouters {
if len(router.PeerGroups) > 0 {
networkRoutersWithPG++
}
}
routes += len(account.Routes) routes += len(account.Routes)
for _, route := range account.Routes { for _, route := range account.Routes {
if len(route.PeerGroups) > 0 { if len(route.PeerGroups) > 0 {
@ -312,6 +326,10 @@ func (w *Worker) generateProperties(ctx context.Context) properties {
metricsProperties["rules_with_src_posture_checks"] = rulesWithSrcPostureChecks metricsProperties["rules_with_src_posture_checks"] = rulesWithSrcPostureChecks
metricsProperties["posture_checks"] = postureChecks metricsProperties["posture_checks"] = postureChecks
metricsProperties["groups"] = groups metricsProperties["groups"] = groups
metricsProperties["networks"] = networks
metricsProperties["network_resources"] = networkResources
metricsProperties["network_routers"] = networkRouters
metricsProperties["network_routers_with_groups"] = networkRoutersWithPG
metricsProperties["routes"] = routes metricsProperties["routes"] = routes
metricsProperties["routes_with_routing_groups"] = routesWithRGGroups metricsProperties["routes_with_routing_groups"] = routesWithRGGroups
metricsProperties["nameservers"] = nameservers metricsProperties["nameservers"] = nameservers

View File

@ -5,6 +5,9 @@ import (
"testing" "testing"
nbdns "github.com/netbirdio/netbird/dns" nbdns "github.com/netbirdio/netbird/dns"
resourceTypes "github.com/netbirdio/netbird/management/server/networks/resources/types"
routerTypes "github.com/netbirdio/netbird/management/server/networks/routers/types"
networkTypes "github.com/netbirdio/netbird/management/server/networks/types"
nbpeer "github.com/netbirdio/netbird/management/server/peer" nbpeer "github.com/netbirdio/netbird/management/server/peer"
"github.com/netbirdio/netbird/management/server/posture" "github.com/netbirdio/netbird/management/server/posture"
"github.com/netbirdio/netbird/management/server/store" "github.com/netbirdio/netbird/management/server/store"
@ -172,6 +175,31 @@ func (mockDatasource) GetAllAccounts(_ context.Context) []*types.Account {
}, },
}, },
}, },
Networks: []*networkTypes.Network{
{
ID: "1",
AccountID: "1",
},
},
NetworkResources: []*resourceTypes.NetworkResource{
{
ID: "1",
AccountID: "1",
NetworkID: "1",
},
{
ID: "2",
AccountID: "1",
NetworkID: "1",
},
},
NetworkRouters: []*routerTypes.NetworkRouter{
{
ID: "1",
AccountID: "1",
NetworkID: "1",
},
},
}, },
} }
} }
@ -200,6 +228,15 @@ func TestGenerateProperties(t *testing.T) {
if properties["routes"] != 2 { if properties["routes"] != 2 {
t.Errorf("expected 2 routes, got %d", properties["routes"]) t.Errorf("expected 2 routes, got %d", properties["routes"])
} }
if properties["networks"] != 1 {
t.Errorf("expected 1 networks, got %d", properties["networks"])
}
if properties["network_resources"] != 2 {
t.Errorf("expected 2 network_resources, got %d", properties["network_resources"])
}
if properties["network_routers"] != 1 {
t.Errorf("expected 1 network_routers, got %d", properties["network_routers"])
}
if properties["rules"] != 4 { if properties["rules"] != 4 {
t.Errorf("expected 4 rules, got %d", properties["rules"]) t.Errorf("expected 4 rules, got %d", properties["rules"])
} }

View File

@ -111,6 +111,7 @@ func (n *NetworkResource) ToRoute(peer *nbpeer.Peer, router *routerTypes.Network
NetID: route.NetID(n.Name), NetID: route.NetID(n.Name),
Description: n.Description, Description: n.Description,
Peer: peer.Key, Peer: peer.Key,
PeerID: peer.ID,
PeerGroups: nil, PeerGroups: nil,
Masquerade: router.Masquerade, Masquerade: router.Masquerade,
Metric: router.Metric, Metric: router.Metric,

View File

@ -932,11 +932,11 @@ func BenchmarkUpdateAccountPeers(b *testing.B) {
}{ }{
{"Small", 50, 5, 90, 120, 90, 120}, {"Small", 50, 5, 90, 120, 90, 120},
{"Medium", 500, 100, 110, 150, 120, 260}, {"Medium", 500, 100, 110, 150, 120, 260},
{"Large", 5000, 200, 800, 1390, 2500, 4600}, {"Large", 5000, 200, 800, 1700, 2500, 5000},
{"Small single", 50, 10, 90, 120, 90, 120}, {"Small single", 50, 10, 90, 120, 90, 120},
{"Medium single", 500, 10, 110, 170, 120, 200}, {"Medium single", 500, 10, 110, 170, 120, 200},
{"Large 5", 5000, 15, 1300, 2100, 5000, 7000}, {"Large 5", 5000, 15, 1300, 2100, 4900, 7000},
{"Extra Large", 2000, 2000, 1300, 2100, 4000, 6000}, {"Extra Large", 2000, 2000, 1300, 2400, 4000, 6400},
} }
log.SetOutput(io.Discard) log.SetOutput(io.Discard)

View File

@ -74,6 +74,19 @@ func TestAccount_getPeersByPolicy(t *testing.T) {
"peerH", "peerH",
}, },
}, },
"GroupWorkstations": {
ID: "GroupWorkstations",
Name: "GroupWorkstations",
Peers: []string{
"peerB",
"peerA",
"peerD",
"peerE",
"peerF",
"peerG",
"peerH",
},
},
"GroupSwarm": { "GroupSwarm": {
ID: "GroupSwarm", ID: "GroupSwarm",
Name: "swarm", Name: "swarm",
@ -127,7 +140,7 @@ func TestAccount_getPeersByPolicy(t *testing.T) {
Action: types.PolicyTrafficActionAccept, Action: types.PolicyTrafficActionAccept,
Sources: []string{ Sources: []string{
"GroupSwarm", "GroupSwarm",
"GroupAll", "GroupWorkstations",
}, },
Destinations: []string{ Destinations: []string{
"GroupSwarm", "GroupSwarm",
@ -159,6 +172,8 @@ func TestAccount_getPeersByPolicy(t *testing.T) {
assert.Contains(t, peers, account.Peers["peerD"]) assert.Contains(t, peers, account.Peers["peerD"])
assert.Contains(t, peers, account.Peers["peerE"]) assert.Contains(t, peers, account.Peers["peerE"])
assert.Contains(t, peers, account.Peers["peerF"]) assert.Contains(t, peers, account.Peers["peerF"])
assert.Contains(t, peers, account.Peers["peerG"])
assert.Contains(t, peers, account.Peers["peerH"])
epectedFirewallRules := []*types.FirewallRule{ epectedFirewallRules := []*types.FirewallRule{
{ {
@ -189,21 +204,6 @@ func TestAccount_getPeersByPolicy(t *testing.T) {
Protocol: "all", Protocol: "all",
Port: "", Port: "",
}, },
{
PeerIP: "100.65.254.139",
Direction: types.FirewallRuleDirectionOUT,
Action: "accept",
Protocol: "all",
Port: "",
},
{
PeerIP: "100.65.254.139",
Direction: types.FirewallRuleDirectionIN,
Action: "accept",
Protocol: "all",
Port: "",
},
{ {
PeerIP: "100.65.62.5", PeerIP: "100.65.62.5",
Direction: types.FirewallRuleDirectionOUT, Direction: types.FirewallRuleDirectionOUT,
@ -280,10 +280,16 @@ func TestAccount_getPeersByPolicy(t *testing.T) {
}, },
} }
assert.Len(t, firewallRules, len(epectedFirewallRules)) assert.Len(t, firewallRules, len(epectedFirewallRules))
slices.SortFunc(epectedFirewallRules, sortFunc())
slices.SortFunc(firewallRules, sortFunc()) for _, rule := range firewallRules {
for i := range firewallRules { contains := false
assert.Equal(t, epectedFirewallRules[i], firewallRules[i]) for _, expectedRule := range epectedFirewallRules {
if rule.IsEqual(expectedRule) {
contains = true
break
}
}
assert.True(t, contains, "rule not found in expected rules %#v", rule)
} }
}) })
} }

View File

@ -364,7 +364,7 @@ func toProtocolRoute(route *route.Route) *proto.Route {
} }
func toProtocolRoutes(routes []*route.Route) []*proto.Route { func toProtocolRoutes(routes []*route.Route) []*proto.Route {
protoRoutes := make([]*proto.Route, 0) protoRoutes := make([]*proto.Route, 0, len(routes))
for _, r := range routes { for _, r := range routes {
protoRoutes = append(protoRoutes, toProtocolRoute(r)) protoRoutes = append(protoRoutes, toProtocolRoute(r))
} }

View File

@ -303,55 +303,47 @@ func (a *Account) GetPeerNetworkMap(
return nm return nm
} }
func (a *Account) addNetworksRoutingPeers(networkResourcesRoutes []*route.Route, peer *nbpeer.Peer, peersToConnect []*nbpeer.Peer, expiredPeers []*nbpeer.Peer, isRouter bool, sourcePeers []string) []*nbpeer.Peer { func (a *Account) addNetworksRoutingPeers(
missingPeers := map[string]struct{}{} networkResourcesRoutes []*route.Route,
for _, r := range networkResourcesRoutes { peer *nbpeer.Peer,
if r.Peer == peer.Key { peersToConnect []*nbpeer.Peer,
continue expiredPeers []*nbpeer.Peer,
} isRouter bool,
sourcePeers map[string]struct{},
) []*nbpeer.Peer {
missing := true networkRoutesPeers := make(map[string]struct{}, len(networkResourcesRoutes))
for _, p := range slices.Concat(peersToConnect, expiredPeers) { for _, r := range networkResourcesRoutes {
if r.Peer == p.Key { networkRoutesPeers[r.PeerID] = struct{}{}
missing = false
break
}
}
if missing {
missingPeers[r.Peer] = struct{}{}
}
} }
if isRouter { delete(sourcePeers, peer.ID)
for _, s := range sourcePeers {
if s == peer.ID {
continue
}
missing := true for _, existingPeer := range peersToConnect {
for _, p := range slices.Concat(peersToConnect, expiredPeers) { delete(sourcePeers, existingPeer.ID)
if s == p.ID { delete(networkRoutesPeers, existingPeer.ID)
missing = false }
break for _, expPeer := range expiredPeers {
} delete(sourcePeers, expPeer.ID)
} delete(networkRoutesPeers, expPeer.ID)
if missing { }
p, ok := a.Peers[s]
if ok { missingPeers := make(map[string]struct{}, len(sourcePeers)+len(networkRoutesPeers))
missingPeers[p.Key] = struct{}{} if isRouter {
} for p := range sourcePeers {
} missingPeers[p] = struct{}{}
} }
} }
for p := range networkRoutesPeers {
missingPeers[p] = struct{}{}
}
for p := range missingPeers { for p := range missingPeers {
for _, p2 := range a.Peers { if missingPeer := a.Peers[p]; missingPeer != nil {
if p2.Key == p { peersToConnect = append(peersToConnect, missingPeer)
peersToConnect = append(peersToConnect, p2)
break
}
} }
} }
return peersToConnect return peersToConnect
} }
@ -1045,37 +1037,32 @@ func (a *Account) connResourcesGenerator(ctx context.Context) (func(*PolicyRule,
// for destination group peers, call this method with an empty list of sourcePostureChecksIDs // for destination group peers, call this method with an empty list of sourcePostureChecksIDs
func (a *Account) getAllPeersFromGroups(ctx context.Context, groups []string, peerID string, sourcePostureChecksIDs []string, validatedPeersMap map[string]struct{}) ([]*nbpeer.Peer, bool) { func (a *Account) getAllPeersFromGroups(ctx context.Context, groups []string, peerID string, sourcePostureChecksIDs []string, validatedPeersMap map[string]struct{}) ([]*nbpeer.Peer, bool) {
peerInGroups := false peerInGroups := false
filteredPeers := make([]*nbpeer.Peer, 0, len(groups)) uniquePeerIDs := a.getUniquePeerIDsFromGroupsIDs(ctx, groups)
for _, g := range groups { filteredPeers := make([]*nbpeer.Peer, 0, len(uniquePeerIDs))
group, ok := a.Groups[g] for _, p := range uniquePeerIDs {
if !ok { peer, ok := a.Peers[p]
if !ok || peer == nil {
continue continue
} }
for _, p := range group.Peers { // validate the peer based on policy posture checks applied
peer, ok := a.Peers[p] isValid := a.validatePostureChecksOnPeer(ctx, sourcePostureChecksIDs, peer.ID)
if !ok || peer == nil { if !isValid {
continue continue
}
// validate the peer based on policy posture checks applied
isValid := a.validatePostureChecksOnPeer(ctx, sourcePostureChecksIDs, peer.ID)
if !isValid {
continue
}
if _, ok := validatedPeersMap[peer.ID]; !ok {
continue
}
if peer.ID == peerID {
peerInGroups = true
continue
}
filteredPeers = append(filteredPeers, peer)
} }
if _, ok := validatedPeersMap[peer.ID]; !ok {
continue
}
if peer.ID == peerID {
peerInGroups = true
continue
}
filteredPeers = append(filteredPeers, peer)
} }
return filteredPeers, peerInGroups return filteredPeers, peerInGroups
} }
@ -1151,7 +1138,7 @@ func (a *Account) getRouteFirewallRules(ctx context.Context, peerID string, poli
continue continue
} }
rulePeers := a.getRulePeers(rule, peerID, distributionPeers, validatedPeersMap) rulePeers := a.getRulePeers(rule, policy.SourcePostureChecks, peerID, distributionPeers, validatedPeersMap)
rules := generateRouteFirewallRules(ctx, route, rule, rulePeers, FirewallRuleDirectionIN) rules := generateRouteFirewallRules(ctx, route, rule, rulePeers, FirewallRuleDirectionIN)
fwRules = append(fwRules, rules...) fwRules = append(fwRules, rules...)
} }
@ -1159,7 +1146,7 @@ func (a *Account) getRouteFirewallRules(ctx context.Context, peerID string, poli
return fwRules return fwRules
} }
func (a *Account) getRulePeers(rule *PolicyRule, peerID string, distributionPeers map[string]struct{}, validatedPeersMap map[string]struct{}) []*nbpeer.Peer { func (a *Account) getRulePeers(rule *PolicyRule, postureChecks []string, peerID string, distributionPeers map[string]struct{}, validatedPeersMap map[string]struct{}) []*nbpeer.Peer {
distPeersWithPolicy := make(map[string]struct{}) distPeersWithPolicy := make(map[string]struct{})
for _, id := range rule.Sources { for _, id := range rule.Sources {
group := a.Groups[id] group := a.Groups[id]
@ -1173,7 +1160,7 @@ func (a *Account) getRulePeers(rule *PolicyRule, peerID string, distributionPeer
} }
_, distPeer := distributionPeers[pID] _, distPeer := distributionPeers[pID]
_, valid := validatedPeersMap[pID] _, valid := validatedPeersMap[pID]
if distPeer && valid { if distPeer && valid && a.validatePostureChecksOnPeer(context.Background(), postureChecks, pID) {
distPeersWithPolicy[pID] = struct{}{} distPeersWithPolicy[pID] = struct{}{}
} }
} }
@ -1271,7 +1258,11 @@ func (a *Account) GetPeerNetworkResourceFirewallRules(ctx context.Context, peer
distributionPeers := getPoliciesSourcePeers(resourceAppliedPolicies, a.Groups) distributionPeers := getPoliciesSourcePeers(resourceAppliedPolicies, a.Groups)
rules := a.getRouteFirewallRules(ctx, peer.ID, resourceAppliedPolicies, route, validatedPeersMap, distributionPeers) rules := a.getRouteFirewallRules(ctx, peer.ID, resourceAppliedPolicies, route, validatedPeersMap, distributionPeers)
routesFirewallRules = append(routesFirewallRules, rules...) for _, rule := range rules {
if len(rule.SourceRanges) > 0 {
routesFirewallRules = append(routesFirewallRules, rule)
}
}
} }
return routesFirewallRules return routesFirewallRules
@ -1303,10 +1294,10 @@ func (a *Account) GetResourcePoliciesMap() map[string][]*Policy {
} }
// GetNetworkResourcesRoutesToSync returns network routes for syncing with a specific peer and its ACL peers. // GetNetworkResourcesRoutesToSync returns network routes for syncing with a specific peer and its ACL peers.
func (a *Account) GetNetworkResourcesRoutesToSync(ctx context.Context, peerID string, resourcePolicies map[string][]*Policy, routers map[string]map[string]*routerTypes.NetworkRouter) (bool, []*route.Route, []string) { func (a *Account) GetNetworkResourcesRoutesToSync(ctx context.Context, peerID string, resourcePolicies map[string][]*Policy, routers map[string]map[string]*routerTypes.NetworkRouter) (bool, []*route.Route, map[string]struct{}) {
var isRoutingPeer bool var isRoutingPeer bool
var routes []*route.Route var routes []*route.Route
var allSourcePeers []string allSourcePeers := make(map[string]struct{}, len(a.Peers))
for _, resource := range a.NetworkResources { for _, resource := range a.NetworkResources {
var addSourcePeers bool var addSourcePeers bool
@ -1319,23 +1310,22 @@ func (a *Account) GetNetworkResourcesRoutesToSync(ctx context.Context, peerID st
} }
} }
addedResourceRoute := false
for _, policy := range resourcePolicies[resource.ID] { for _, policy := range resourcePolicies[resource.ID] {
for _, sourceGroup := range policy.SourceGroups() { peers := a.getUniquePeerIDsFromGroupsIDs(ctx, policy.SourceGroups())
group := a.GetGroup(sourceGroup) if addSourcePeers {
if group == nil { for _, pID := range a.getPostureValidPeers(peers, policy.SourcePostureChecks) {
log.WithContext(ctx).Warnf("policy %s has source group %s that doesn't exist under account %s, will continue map generation without it", policy.ID, sourceGroup, a.Id) allSourcePeers[pID] = struct{}{}
continue
} }
} else if slices.Contains(peers, peerID) && a.validatePostureChecksOnPeer(ctx, policy.SourcePostureChecks, peerID) {
// routing peer should be able to connect with all source peers // add routes for the resource if the peer is in the distribution group
if addSourcePeers { for peerId, router := range networkRoutingPeers {
allSourcePeers = append(allSourcePeers, group.Peers...) routes = append(routes, a.getNetworkResourcesRoutes(resource, peerId, router, resourcePolicies)...)
} else if slices.Contains(group.Peers, peerID) {
// add routes for the resource if the peer is in the distribution group
for peerId, router := range networkRoutingPeers {
routes = append(routes, a.getNetworkResourcesRoutes(resource, peerId, router, resourcePolicies)...)
}
} }
addedResourceRoute = true
}
if addedResourceRoute {
break
} }
} }
} }
@ -1343,6 +1333,42 @@ func (a *Account) GetNetworkResourcesRoutesToSync(ctx context.Context, peerID st
return isRoutingPeer, routes, allSourcePeers return isRoutingPeer, routes, allSourcePeers
} }
func (a *Account) getPostureValidPeers(inputPeers []string, postureChecksIDs []string) []string {
var dest []string
for _, peerID := range inputPeers {
if a.validatePostureChecksOnPeer(context.Background(), postureChecksIDs, peerID) {
dest = append(dest, peerID)
}
}
return dest
}
func (a *Account) getUniquePeerIDsFromGroupsIDs(ctx context.Context, groups []string) []string {
peerIDs := make(map[string]struct{}, len(groups)) // we expect at least one peer per group as initial capacity
for _, groupID := range groups {
group := a.GetGroup(groupID)
if group == nil {
log.WithContext(ctx).Warnf("group %s doesn't exist under account %s, will continue map generation without it", groupID, a.Id)
continue
}
if group.IsGroupAll() || len(groups) == 1 {
return group.Peers
}
for _, peerID := range group.Peers {
peerIDs[peerID] = struct{}{}
}
}
ids := make([]string, 0, len(peerIDs))
for peerID := range peerIDs {
ids = append(ids, peerID)
}
return ids
}
// getNetworkResources filters and returns a list of network resources associated with the given network ID. // getNetworkResources filters and returns a list of network resources associated with the given network ID.
func (a *Account) getNetworkResources(networkID string) []*resourceTypes.NetworkResource { func (a *Account) getNetworkResources(networkID string) []*resourceTypes.NetworkResource {
var resources []*resourceTypes.NetworkResource var resources []*resourceTypes.NetworkResource

View File

@ -1,14 +1,20 @@
package types package types
import ( import (
"context"
"net"
"net/netip"
"slices"
"testing" "testing"
"github.com/stretchr/testify/assert"
"github.com/stretchr/testify/require" "github.com/stretchr/testify/require"
resourceTypes "github.com/netbirdio/netbird/management/server/networks/resources/types" resourceTypes "github.com/netbirdio/netbird/management/server/networks/resources/types"
routerTypes "github.com/netbirdio/netbird/management/server/networks/routers/types" routerTypes "github.com/netbirdio/netbird/management/server/networks/routers/types"
networkTypes "github.com/netbirdio/netbird/management/server/networks/types" networkTypes "github.com/netbirdio/netbird/management/server/networks/types"
nbpeer "github.com/netbirdio/netbird/management/server/peer" nbpeer "github.com/netbirdio/netbird/management/server/peer"
"github.com/netbirdio/netbird/management/server/posture"
"github.com/netbirdio/netbird/route" "github.com/netbirdio/netbird/route"
) )
@ -310,19 +316,19 @@ func Test_GetResourcePoliciesMap(t *testing.T) {
func Test_AddNetworksRoutingPeersAddsMissingPeers(t *testing.T) { func Test_AddNetworksRoutingPeersAddsMissingPeers(t *testing.T) {
account := setupTestAccount() account := setupTestAccount()
peer := &nbpeer.Peer{Key: "peer1"} peer := &nbpeer.Peer{Key: "peer1Key", ID: "peer1"}
networkResourcesRoutes := []*route.Route{ networkResourcesRoutes := []*route.Route{
{Peer: "peer2Key"}, {Peer: "peer2Key", PeerID: "peer2"},
{Peer: "peer3Key"}, {Peer: "peer3Key", PeerID: "peer3"},
} }
peersToConnect := []*nbpeer.Peer{ peersToConnect := []*nbpeer.Peer{
{Key: "peer2Key"}, {Key: "peer2Key", ID: "peer2"},
} }
expiredPeers := []*nbpeer.Peer{ expiredPeers := []*nbpeer.Peer{
{Key: "peer4Key"}, {Key: "peer4Key", ID: "peer4"},
} }
result := account.addNetworksRoutingPeers(networkResourcesRoutes, peer, peersToConnect, expiredPeers, false, []string{}) result := account.addNetworksRoutingPeers(networkResourcesRoutes, peer, peersToConnect, expiredPeers, false, map[string]struct{}{})
require.Len(t, result, 2) require.Len(t, result, 2)
require.Equal(t, "peer2Key", result[0].Key) require.Equal(t, "peer2Key", result[0].Key)
require.Equal(t, "peer3Key", result[1].Key) require.Equal(t, "peer3Key", result[1].Key)
@ -339,7 +345,7 @@ func Test_AddNetworksRoutingPeersIgnoresExistingPeers(t *testing.T) {
} }
expiredPeers := []*nbpeer.Peer{} expiredPeers := []*nbpeer.Peer{}
result := account.addNetworksRoutingPeers(networkResourcesRoutes, peer, peersToConnect, expiredPeers, false, []string{}) result := account.addNetworksRoutingPeers(networkResourcesRoutes, peer, peersToConnect, expiredPeers, false, map[string]struct{}{})
require.Len(t, result, 1) require.Len(t, result, 1)
require.Equal(t, "peer2Key", result[0].Key) require.Equal(t, "peer2Key", result[0].Key)
} }
@ -358,7 +364,7 @@ func Test_AddNetworksRoutingPeersAddsExpiredPeers(t *testing.T) {
{Key: "peer3Key"}, {Key: "peer3Key"},
} }
result := account.addNetworksRoutingPeers(networkResourcesRoutes, peer, peersToConnect, expiredPeers, false, []string{}) result := account.addNetworksRoutingPeers(networkResourcesRoutes, peer, peersToConnect, expiredPeers, false, map[string]struct{}{})
require.Len(t, result, 1) require.Len(t, result, 1)
require.Equal(t, "peer2Key", result[0].Key) require.Equal(t, "peer2Key", result[0].Key)
} }
@ -370,6 +376,382 @@ func Test_AddNetworksRoutingPeersHandlesNoMissingPeers(t *testing.T) {
peersToConnect := []*nbpeer.Peer{} peersToConnect := []*nbpeer.Peer{}
expiredPeers := []*nbpeer.Peer{} expiredPeers := []*nbpeer.Peer{}
result := account.addNetworksRoutingPeers(networkResourcesRoutes, peer, peersToConnect, expiredPeers, false, []string{}) result := account.addNetworksRoutingPeers(networkResourcesRoutes, peer, peersToConnect, expiredPeers, false, map[string]struct{}{})
require.Len(t, result, 0) require.Len(t, result, 0)
} }
const (
accID = "accountID"
network1ID = "network1ID"
group1ID = "group1"
accNetResourcePeer1ID = "peer1"
accNetResourcePeer2ID = "peer2"
accNetResourceRouter1ID = "router1"
accNetResource1ID = "resource1ID"
accNetResourceRestrictPostureCheckID = "restrictPostureCheck"
accNetResourceRelaxedPostureCheckID = "relaxedPostureCheck"
accNetResourceLockedPostureCheckID = "lockedPostureCheck"
accNetResourceLinuxPostureCheckID = "linuxPostureCheck"
)
var (
accNetResourcePeer1IP = net.IP{192, 168, 1, 1}
accNetResourcePeer2IP = net.IP{192, 168, 1, 2}
accNetResourceRouter1IP = net.IP{192, 168, 1, 3}
accNetResourceValidPeers = map[string]struct{}{accNetResourcePeer1ID: {}, accNetResourcePeer2ID: {}}
)
func getBasicAccountsWithResource() *Account {
return &Account{
Id: accID,
Peers: map[string]*nbpeer.Peer{
accNetResourcePeer1ID: {
ID: accNetResourcePeer1ID,
AccountID: accID,
Key: "peer1Key",
IP: accNetResourcePeer1IP,
Meta: nbpeer.PeerSystemMeta{
GoOS: "linux",
WtVersion: "0.35.1",
KernelVersion: "4.4.0",
},
},
accNetResourcePeer2ID: {
ID: accNetResourcePeer2ID,
AccountID: accID,
Key: "peer2Key",
IP: accNetResourcePeer2IP,
Meta: nbpeer.PeerSystemMeta{
GoOS: "windows",
WtVersion: "0.34.1",
KernelVersion: "4.4.0",
},
},
accNetResourceRouter1ID: {
ID: accNetResourceRouter1ID,
AccountID: accID,
Key: "router1Key",
IP: accNetResourceRouter1IP,
Meta: nbpeer.PeerSystemMeta{
GoOS: "linux",
WtVersion: "0.35.1",
KernelVersion: "4.4.0",
},
},
},
Groups: map[string]*Group{
group1ID: {
ID: group1ID,
Peers: []string{accNetResourcePeer1ID, accNetResourcePeer2ID},
},
},
Networks: []*networkTypes.Network{
{
ID: network1ID,
AccountID: accID,
Name: "network1",
},
},
NetworkRouters: []*routerTypes.NetworkRouter{
{
ID: accNetResourceRouter1ID,
NetworkID: network1ID,
AccountID: accID,
Peer: accNetResourceRouter1ID,
PeerGroups: []string{},
Masquerade: false,
Metric: 100,
},
},
NetworkResources: []*resourceTypes.NetworkResource{
{
ID: accNetResource1ID,
AccountID: accID,
NetworkID: network1ID,
Address: "10.10.10.0/24",
Prefix: netip.MustParsePrefix("10.10.10.0/24"),
Type: resourceTypes.NetworkResourceType("subnet"),
},
},
Policies: []*Policy{
{
ID: "policy1ID",
AccountID: accID,
Enabled: true,
Rules: []*PolicyRule{
{
ID: "rule1ID",
Enabled: true,
Sources: []string{group1ID},
DestinationResource: Resource{
ID: accNetResource1ID,
Type: "Host",
},
Protocol: PolicyRuleProtocolTCP,
Ports: []string{"80"},
Action: PolicyTrafficActionAccept,
},
},
SourcePostureChecks: nil,
},
},
PostureChecks: []*posture.Checks{
{
ID: accNetResourceRestrictPostureCheckID,
Name: accNetResourceRestrictPostureCheckID,
Checks: posture.ChecksDefinition{
NBVersionCheck: &posture.NBVersionCheck{
MinVersion: "0.35.0",
},
},
},
{
ID: accNetResourceRelaxedPostureCheckID,
Name: accNetResourceRelaxedPostureCheckID,
Checks: posture.ChecksDefinition{
NBVersionCheck: &posture.NBVersionCheck{
MinVersion: "0.0.1",
},
},
},
{
ID: accNetResourceLockedPostureCheckID,
Name: accNetResourceLockedPostureCheckID,
Checks: posture.ChecksDefinition{
NBVersionCheck: &posture.NBVersionCheck{
MinVersion: "7.7.7",
},
},
},
{
ID: accNetResourceLinuxPostureCheckID,
Name: accNetResourceLinuxPostureCheckID,
Checks: posture.ChecksDefinition{
OSVersionCheck: &posture.OSVersionCheck{
Linux: &posture.MinKernelVersionCheck{
MinKernelVersion: "0.0.0"},
},
},
},
},
}
}
func Test_NetworksNetMapGenWithNoPostureChecks(t *testing.T) {
account := getBasicAccountsWithResource()
// all peers should match the policy
// validate for peer1
isRouter, networkResourcesRoutes, sourcePeers := account.GetNetworkResourcesRoutesToSync(context.Background(), accNetResourcePeer1ID, account.GetResourcePoliciesMap(), account.GetResourceRoutersMap())
assert.False(t, isRouter, "expected router status")
assert.Len(t, networkResourcesRoutes, 1, "expected network resource route don't match")
assert.Len(t, sourcePeers, 0, "expected source peers don't match")
// validate for peer2
isRouter, networkResourcesRoutes, sourcePeers = account.GetNetworkResourcesRoutesToSync(context.Background(), accNetResourcePeer2ID, account.GetResourcePoliciesMap(), account.GetResourceRoutersMap())
assert.False(t, isRouter, "expected router status")
assert.Len(t, networkResourcesRoutes, 1, "expected network resource route don't match")
assert.Len(t, sourcePeers, 0, "expected source peers don't match")
// validate routes for router1
isRouter, networkResourcesRoutes, sourcePeers = account.GetNetworkResourcesRoutesToSync(context.Background(), accNetResourceRouter1ID, account.GetResourcePoliciesMap(), account.GetResourceRoutersMap())
assert.True(t, isRouter, "should be router")
assert.Len(t, networkResourcesRoutes, 1, "expected network resource route don't match")
assert.Len(t, sourcePeers, 2, "expected source peers don't match")
assert.NotNil(t, sourcePeers[accNetResourcePeer1ID], "expected source peers don't match")
assert.NotNil(t, sourcePeers[accNetResourcePeer2ID], "expected source peers don't match")
// validate rules for router1
rules := account.GetPeerNetworkResourceFirewallRules(context.Background(), account.Peers[accNetResourceRouter1ID], accNetResourceValidPeers, networkResourcesRoutes, account.GetResourcePoliciesMap())
assert.Len(t, rules, 1, "expected rules count don't match")
assert.Equal(t, uint16(80), rules[0].Port, "should have port 80")
assert.Equal(t, "tcp", rules[0].Protocol, "should have protocol tcp")
if !slices.Contains(rules[0].SourceRanges, accNetResourcePeer1IP.String()+"/32") {
t.Errorf("%s should have source range of peer1 %s", rules[0].SourceRanges, accNetResourcePeer1IP.String())
}
if !slices.Contains(rules[0].SourceRanges, accNetResourcePeer2IP.String()+"/32") {
t.Errorf("%s should have source range of peer2 %s", rules[0].SourceRanges, accNetResourcePeer2IP.String())
}
}
func Test_NetworksNetMapGenWithPostureChecks(t *testing.T) {
account := getBasicAccountsWithResource()
// should allow peer1 to match the policy
policy := account.Policies[0]
policy.SourcePostureChecks = []string{accNetResourceRestrictPostureCheckID}
// validate for peer1
isRouter, networkResourcesRoutes, sourcePeers := account.GetNetworkResourcesRoutesToSync(context.Background(), accNetResourcePeer1ID, account.GetResourcePoliciesMap(), account.GetResourceRoutersMap())
assert.False(t, isRouter, "expected router status")
assert.Len(t, networkResourcesRoutes, 1, "expected network resource route don't match")
assert.Len(t, sourcePeers, 0, "expected source peers don't match")
// validate for peer2
isRouter, networkResourcesRoutes, sourcePeers = account.GetNetworkResourcesRoutesToSync(context.Background(), accNetResourcePeer2ID, account.GetResourcePoliciesMap(), account.GetResourceRoutersMap())
assert.False(t, isRouter, "expected router status")
assert.Len(t, networkResourcesRoutes, 0, "expected network resource route don't match")
assert.Len(t, sourcePeers, 0, "expected source peers don't match")
// validate routes for router1
isRouter, networkResourcesRoutes, sourcePeers = account.GetNetworkResourcesRoutesToSync(context.Background(), accNetResourceRouter1ID, account.GetResourcePoliciesMap(), account.GetResourceRoutersMap())
assert.True(t, isRouter, "should be router")
assert.Len(t, networkResourcesRoutes, 1, "expected network resource route don't match")
assert.Len(t, sourcePeers, 1, "expected source peers don't match")
assert.NotNil(t, sourcePeers[accNetResourcePeer1ID], "expected source peers don't match")
// validate rules for router1
rules := account.GetPeerNetworkResourceFirewallRules(context.Background(), account.Peers[accNetResourceRouter1ID], accNetResourceValidPeers, networkResourcesRoutes, account.GetResourcePoliciesMap())
assert.Len(t, rules, 1, "expected rules count don't match")
assert.Equal(t, uint16(80), rules[0].Port, "should have port 80")
assert.Equal(t, "tcp", rules[0].Protocol, "should have protocol tcp")
if !slices.Contains(rules[0].SourceRanges, accNetResourcePeer1IP.String()+"/32") {
t.Errorf("%s should have source range of peer1 %s", rules[0].SourceRanges, accNetResourcePeer1IP.String())
}
if slices.Contains(rules[0].SourceRanges, accNetResourcePeer2IP.String()+"/32") {
t.Errorf("%s should not have source range of peer2 %s", rules[0].SourceRanges, accNetResourcePeer2IP.String())
}
}
func Test_NetworksNetMapGenWithNoMatchedPostureChecks(t *testing.T) {
account := getBasicAccountsWithResource()
// should not match any peer
policy := account.Policies[0]
policy.SourcePostureChecks = []string{accNetResourceLockedPostureCheckID}
// validate for peer1
isRouter, networkResourcesRoutes, sourcePeers := account.GetNetworkResourcesRoutesToSync(context.Background(), accNetResourcePeer1ID, account.GetResourcePoliciesMap(), account.GetResourceRoutersMap())
assert.False(t, isRouter, "expected router status")
assert.Len(t, networkResourcesRoutes, 0, "expected network resource route don't match")
assert.Len(t, sourcePeers, 0, "expected source peers don't match")
// validate for peer2
isRouter, networkResourcesRoutes, sourcePeers = account.GetNetworkResourcesRoutesToSync(context.Background(), accNetResourcePeer2ID, account.GetResourcePoliciesMap(), account.GetResourceRoutersMap())
assert.False(t, isRouter, "expected router status")
assert.Len(t, networkResourcesRoutes, 0, "expected network resource route don't match")
assert.Len(t, sourcePeers, 0, "expected source peers don't match")
// validate routes for router1
isRouter, networkResourcesRoutes, sourcePeers = account.GetNetworkResourcesRoutesToSync(context.Background(), accNetResourceRouter1ID, account.GetResourcePoliciesMap(), account.GetResourceRoutersMap())
assert.True(t, isRouter, "should be router")
assert.Len(t, networkResourcesRoutes, 1, "expected network resource route don't match")
assert.Len(t, sourcePeers, 0, "expected source peers don't match")
// validate rules for router1
rules := account.GetPeerNetworkResourceFirewallRules(context.Background(), account.Peers[accNetResourceRouter1ID], accNetResourceValidPeers, networkResourcesRoutes, account.GetResourcePoliciesMap())
assert.Len(t, rules, 0, "expected rules count don't match")
}
func Test_NetworksNetMapGenWithTwoPoliciesAndPostureChecks(t *testing.T) {
account := getBasicAccountsWithResource()
// should allow peer1 to match the policy
policy := account.Policies[0]
policy.SourcePostureChecks = []string{accNetResourceRestrictPostureCheckID}
// should allow peer1 and peer2 to match the policy
newPolicy := &Policy{
ID: "policy2ID",
AccountID: accID,
Enabled: true,
Rules: []*PolicyRule{
{
ID: "policy2ID",
Enabled: true,
Sources: []string{group1ID},
DestinationResource: Resource{
ID: accNetResource1ID,
Type: "Host",
},
Protocol: PolicyRuleProtocolTCP,
Ports: []string{"22"},
Action: PolicyTrafficActionAccept,
},
},
SourcePostureChecks: []string{accNetResourceRelaxedPostureCheckID},
}
account.Policies = append(account.Policies, newPolicy)
// validate for peer1
isRouter, networkResourcesRoutes, sourcePeers := account.GetNetworkResourcesRoutesToSync(context.Background(), accNetResourcePeer1ID, account.GetResourcePoliciesMap(), account.GetResourceRoutersMap())
assert.False(t, isRouter, "expected router status")
assert.Len(t, networkResourcesRoutes, 1, "expected network resource route don't match")
assert.Len(t, sourcePeers, 0, "expected source peers don't match")
// validate for peer2
isRouter, networkResourcesRoutes, sourcePeers = account.GetNetworkResourcesRoutesToSync(context.Background(), accNetResourcePeer2ID, account.GetResourcePoliciesMap(), account.GetResourceRoutersMap())
assert.False(t, isRouter, "expected router status")
assert.Len(t, networkResourcesRoutes, 1, "expected network resource route don't match")
assert.Len(t, sourcePeers, 0, "expected source peers don't match")
// validate routes for router1
isRouter, networkResourcesRoutes, sourcePeers = account.GetNetworkResourcesRoutesToSync(context.Background(), accNetResourceRouter1ID, account.GetResourcePoliciesMap(), account.GetResourceRoutersMap())
assert.True(t, isRouter, "should be router")
assert.Len(t, networkResourcesRoutes, 1, "expected network resource route don't match")
assert.Len(t, sourcePeers, 2, "expected source peers don't match")
assert.NotNil(t, sourcePeers[accNetResourcePeer1ID], "expected source peers don't match")
assert.NotNil(t, sourcePeers[accNetResourcePeer2ID], "expected source peers don't match")
// validate rules for router1
rules := account.GetPeerNetworkResourceFirewallRules(context.Background(), account.Peers[accNetResourceRouter1ID], accNetResourceValidPeers, networkResourcesRoutes, account.GetResourcePoliciesMap())
assert.Len(t, rules, 2, "expected rules count don't match")
assert.Equal(t, uint16(80), rules[0].Port, "should have port 80")
assert.Equal(t, "tcp", rules[0].Protocol, "should have protocol tcp")
if !slices.Contains(rules[0].SourceRanges, accNetResourcePeer1IP.String()+"/32") {
t.Errorf("%s should have source range of peer1 %s", rules[0].SourceRanges, accNetResourcePeer1IP.String())
}
if slices.Contains(rules[0].SourceRanges, accNetResourcePeer2IP.String()+"/32") {
t.Errorf("%s should not have source range of peer2 %s", rules[0].SourceRanges, accNetResourcePeer2IP.String())
}
assert.Equal(t, uint16(22), rules[1].Port, "should have port 22")
assert.Equal(t, "tcp", rules[1].Protocol, "should have protocol tcp")
if !slices.Contains(rules[1].SourceRanges, accNetResourcePeer1IP.String()+"/32") {
t.Errorf("%s should have source range of peer1 %s", rules[1].SourceRanges, accNetResourcePeer1IP.String())
}
if !slices.Contains(rules[1].SourceRanges, accNetResourcePeer2IP.String()+"/32") {
t.Errorf("%s should have source range of peer2 %s", rules[1].SourceRanges, accNetResourcePeer2IP.String())
}
}
func Test_NetworksNetMapGenWithTwoPostureChecks(t *testing.T) {
account := getBasicAccountsWithResource()
// two posture checks should match only the peers that match both checks
policy := account.Policies[0]
policy.SourcePostureChecks = []string{accNetResourceRelaxedPostureCheckID, accNetResourceLinuxPostureCheckID}
// validate for peer1
isRouter, networkResourcesRoutes, sourcePeers := account.GetNetworkResourcesRoutesToSync(context.Background(), accNetResourcePeer1ID, account.GetResourcePoliciesMap(), account.GetResourceRoutersMap())
assert.False(t, isRouter, "expected router status")
assert.Len(t, networkResourcesRoutes, 1, "expected network resource route don't match")
assert.Len(t, sourcePeers, 0, "expected source peers don't match")
// validate for peer2
isRouter, networkResourcesRoutes, sourcePeers = account.GetNetworkResourcesRoutesToSync(context.Background(), accNetResourcePeer2ID, account.GetResourcePoliciesMap(), account.GetResourceRoutersMap())
assert.False(t, isRouter, "expected router status")
assert.Len(t, networkResourcesRoutes, 0, "expected network resource route don't match")
assert.Len(t, sourcePeers, 0, "expected source peers don't match")
// validate routes for router1
isRouter, networkResourcesRoutes, sourcePeers = account.GetNetworkResourcesRoutesToSync(context.Background(), accNetResourceRouter1ID, account.GetResourcePoliciesMap(), account.GetResourceRoutersMap())
assert.True(t, isRouter, "should be router")
assert.Len(t, networkResourcesRoutes, 1, "expected network resource route don't match")
assert.Len(t, sourcePeers, 1, "expected source peers don't match")
assert.NotNil(t, sourcePeers[accNetResourcePeer1ID], "expected source peers don't match")
// validate rules for router1
rules := account.GetPeerNetworkResourceFirewallRules(context.Background(), account.Peers[accNetResourceRouter1ID], accNetResourceValidPeers, networkResourcesRoutes, account.GetResourcePoliciesMap())
assert.Len(t, rules, 1, "expected rules count don't match")
assert.Equal(t, uint16(80), rules[0].Port, "should have port 80")
assert.Equal(t, "tcp", rules[0].Protocol, "should have protocol tcp")
if !slices.Contains(rules[0].SourceRanges, accNetResourcePeer1IP.String()+"/32") {
t.Errorf("%s should have source range of peer1 %s", rules[0].SourceRanges, accNetResourcePeer1IP.String())
}
if slices.Contains(rules[0].SourceRanges, accNetResourcePeer2IP.String()+"/32") {
t.Errorf("%s should not have source range of peer2 %s", rules[0].SourceRanges, accNetResourcePeer2IP.String())
}
}

View File

@ -35,6 +35,15 @@ type FirewallRule struct {
Port string Port string
} }
// IsEqual checks if two firewall rules are equal.
func (r *FirewallRule) IsEqual(other *FirewallRule) bool {
return r.PeerIP == other.PeerIP &&
r.Direction == other.Direction &&
r.Action == other.Action &&
r.Protocol == other.Protocol &&
r.Port == other.Port
}
// generateRouteFirewallRules generates a list of firewall rules for a given route. // generateRouteFirewallRules generates a list of firewall rules for a given route.
func generateRouteFirewallRules(ctx context.Context, route *nbroute.Route, rule *PolicyRule, groupPeers []*nbpeer.Peer, direction int) []*RouteFirewallRule { func generateRouteFirewallRules(ctx context.Context, route *nbroute.Route, rule *PolicyRule, groupPeers []*nbpeer.Peer, direction int) []*RouteFirewallRule {
rulesExists := make(map[string]struct{}) rulesExists := make(map[string]struct{})

View File

@ -117,9 +117,20 @@ func (p *Policy) RuleGroups() []string {
// SourceGroups returns a slice of all unique source groups referenced in the policy's rules. // SourceGroups returns a slice of all unique source groups referenced in the policy's rules.
func (p *Policy) SourceGroups() []string { func (p *Policy) SourceGroups() []string {
groups := make([]string, 0) if len(p.Rules) == 1 {
for _, rule := range p.Rules { return p.Rules[0].Sources
groups = append(groups, rule.Sources...)
} }
return groups groups := make(map[string]struct{}, len(p.Rules))
for _, rule := range p.Rules {
for _, source := range rule.Sources {
groups[source] = struct{}{}
}
}
groupIDs := make([]string, 0, len(groups))
for groupID := range groups {
groupIDs = append(groupIDs, groupID)
}
return groupIDs
} }

View File

@ -95,6 +95,7 @@ type Route struct {
NetID NetID NetID NetID
Description string Description string
Peer string Peer string
PeerID string `gorm:"-"`
PeerGroups []string `gorm:"serializer:json"` PeerGroups []string `gorm:"serializer:json"`
NetworkType NetworkType NetworkType NetworkType
Masquerade bool Masquerade bool
@ -120,6 +121,7 @@ func (r *Route) Copy() *Route {
KeepRoute: r.KeepRoute, KeepRoute: r.KeepRoute,
NetworkType: r.NetworkType, NetworkType: r.NetworkType,
Peer: r.Peer, Peer: r.Peer,
PeerID: r.PeerID,
PeerGroups: slices.Clone(r.PeerGroups), PeerGroups: slices.Clone(r.PeerGroups),
Metric: r.Metric, Metric: r.Metric,
Masquerade: r.Masquerade, Masquerade: r.Masquerade,
@ -146,6 +148,7 @@ func (r *Route) IsEqual(other *Route) bool {
other.KeepRoute == r.KeepRoute && other.KeepRoute == r.KeepRoute &&
other.NetworkType == r.NetworkType && other.NetworkType == r.NetworkType &&
other.Peer == r.Peer && other.Peer == r.Peer &&
other.PeerID == r.PeerID &&
other.Metric == r.Metric && other.Metric == r.Metric &&
other.Masquerade == r.Masquerade && other.Masquerade == r.Masquerade &&
other.Enabled == r.Enabled && other.Enabled == r.Enabled &&