Merge branch 'main' into feature/mysql-support

This commit is contained in:
bcmmbaga 2025-01-02 14:54:14 +03:00
commit a3fe7bea38
No known key found for this signature in database
GPG Key ID: 511EED5C928AD547
30 changed files with 1001 additions and 235 deletions

View File

@ -19,7 +19,7 @@ jobs:
- name: codespell
uses: codespell-project/actions-codespell@v2
with:
ignore_words_list: erro,clienta,hastable,iif,groupd
ignore_words_list: erro,clienta,hastable,iif,groupd,testin
skip: go.mod,go.sum
only_warn: 1
golangci:

View File

@ -197,7 +197,7 @@ func (m *Manager) AllowNetbird() error {
}
_, err := m.AddPeerFiltering(
net.ParseIP("0.0.0.0"),
net.IP{0, 0, 0, 0},
"all",
nil,
nil,

View File

@ -10,7 +10,6 @@ import (
// BaseConnTrack provides common fields and locking for all connection types
type BaseConnTrack struct {
sync.RWMutex
SourceIP net.IP
DestIP net.IP
SourcePort uint16

View File

@ -62,6 +62,7 @@ type TCPConnKey struct {
type TCPConnTrack struct {
BaseConnTrack
State TCPState
sync.RWMutex
}
// TCPTracker manages TCP connection states
@ -131,36 +132,8 @@ func (t *TCPTracker) IsValidInbound(srcIP net.IP, dstIP net.IP, srcPort uint16,
return false
}
// Handle new SYN packets
if flags&TCPSyn != 0 && flags&TCPAck == 0 {
key := makeConnKey(dstIP, srcIP, dstPort, srcPort)
t.mutex.Lock()
if _, exists := t.connections[key]; !exists {
// Use preallocated IPs
srcIPCopy := t.ipPool.Get()
dstIPCopy := t.ipPool.Get()
copyIP(srcIPCopy, dstIP)
copyIP(dstIPCopy, srcIP)
conn := &TCPConnTrack{
BaseConnTrack: BaseConnTrack{
SourceIP: srcIPCopy,
DestIP: dstIPCopy,
SourcePort: dstPort,
DestPort: srcPort,
},
State: TCPStateSynReceived,
}
conn.lastSeen.Store(time.Now().UnixNano())
conn.established.Store(false)
t.connections[key] = conn
}
t.mutex.Unlock()
return true
}
// Look up existing connection
key := makeConnKey(dstIP, srcIP, dstPort, srcPort)
t.mutex.RLock()
conn, exists := t.connections[key]
t.mutex.RUnlock()
@ -172,8 +145,7 @@ func (t *TCPTracker) IsValidInbound(srcIP net.IP, dstIP net.IP, srcPort uint16,
// Handle RST packets
if flags&TCPRst != 0 {
conn.Lock()
isEstablished := conn.IsEstablished()
if isEstablished || conn.State == TCPStateSynSent || conn.State == TCPStateSynReceived {
if conn.IsEstablished() || conn.State == TCPStateSynSent || conn.State == TCPStateSynReceived {
conn.State = TCPStateClosed
conn.SetEstablished(false)
conn.Unlock()
@ -183,7 +155,6 @@ func (t *TCPTracker) IsValidInbound(srcIP net.IP, dstIP net.IP, srcPort uint16,
return false
}
// Update state
conn.Lock()
t.updateState(conn, flags, false)
conn.UpdateLastSeen()
@ -306,6 +277,11 @@ func (t *TCPTracker) isValidStateForFlags(state TCPState, flags uint8) bool {
return flags&TCPFin != 0 || flags&TCPAck != 0
case TCPStateLastAck:
return flags&TCPAck != 0
case TCPStateClosed:
// Accept retransmitted ACKs in closed state
// This is important because the final ACK might be lost
// and the peer will retransmit their FIN-ACK
return flags&TCPAck != 0
}
return false
}

View File

@ -125,11 +125,8 @@ func TestTCPStateMachine(t *testing.T) {
valid := tracker.IsValidInbound(dstIP, srcIP, dstPort, srcPort, TCPRst)
require.True(t, valid, "RST should be allowed for established connection")
// Verify connection is closed
valid = tracker.IsValidInbound(dstIP, srcIP, dstPort, srcPort, TCPPush|TCPAck)
t.Helper()
require.False(t, valid, "Data should be blocked after RST")
// Connection is logically dead but we don't enforce blocking subsequent packets
// The connection will be cleaned up by timeout
},
},
{

View File

@ -68,17 +68,16 @@ func (c *HandlerChain) AddHandler(pattern string, handler dns.Handler, priority
c.mu.Lock()
defer c.mu.Unlock()
pattern = strings.ToLower(dns.Fqdn(pattern))
origPattern := pattern
isWildcard := strings.HasPrefix(pattern, "*.")
if isWildcard {
pattern = pattern[2:]
}
pattern = dns.Fqdn(pattern)
origPattern = dns.Fqdn(origPattern)
// First remove any existing handler with same original pattern and priority
// First remove any existing handler with same pattern (case-insensitive) and priority
for i := len(c.handlers) - 1; i >= 0; i-- {
if c.handlers[i].OrigPattern == origPattern && c.handlers[i].Priority == priority {
if strings.EqualFold(c.handlers[i].OrigPattern, origPattern) && c.handlers[i].Priority == priority {
if c.handlers[i].StopHandler != nil {
c.handlers[i].StopHandler.stop()
}
@ -126,10 +125,10 @@ func (c *HandlerChain) RemoveHandler(pattern string, priority int) {
pattern = dns.Fqdn(pattern)
// Find and remove handlers matching both original pattern and priority
// Find and remove handlers matching both original pattern (case-insensitive) and priority
for i := len(c.handlers) - 1; i >= 0; i-- {
entry := c.handlers[i]
if entry.OrigPattern == pattern && entry.Priority == priority {
if strings.EqualFold(entry.OrigPattern, pattern) && entry.Priority == priority {
if entry.StopHandler != nil {
entry.StopHandler.stop()
}
@ -144,9 +143,9 @@ func (c *HandlerChain) HasHandlers(pattern string) bool {
c.mu.RLock()
defer c.mu.RUnlock()
pattern = dns.Fqdn(pattern)
pattern = strings.ToLower(dns.Fqdn(pattern))
for _, entry := range c.handlers {
if entry.Pattern == pattern {
if strings.EqualFold(entry.Pattern, pattern) {
return true
}
}
@ -158,7 +157,7 @@ func (c *HandlerChain) ServeDNS(w dns.ResponseWriter, r *dns.Msg) {
return
}
qname := r.Question[0].Name
qname := strings.ToLower(r.Question[0].Name)
log.Tracef("handling DNS request for domain=%s", qname)
c.mu.RLock()
@ -187,9 +186,9 @@ func (c *HandlerChain) ServeDNS(w dns.ResponseWriter, r *dns.Msg) {
// If handler wants subdomain matching, allow suffix match
// Otherwise require exact match
if entry.MatchSubdomains {
matched = qname == entry.Pattern || strings.HasSuffix(qname, "."+entry.Pattern)
matched = strings.EqualFold(qname, entry.Pattern) || strings.HasSuffix(qname, "."+entry.Pattern)
} else {
matched = qname == entry.Pattern
matched = strings.EqualFold(qname, entry.Pattern)
}
}

View File

@ -507,5 +507,173 @@ func TestHandlerChain_MultiPriorityHandling(t *testing.T) {
// Test 4: Remove last handler
chain.RemoveHandler(testDomain, nbdns.PriorityDefault)
assert.False(t, chain.HasHandlers(testDomain))
}
func TestHandlerChain_CaseSensitivity(t *testing.T) {
tests := []struct {
name string
scenario string
addHandlers []struct {
pattern string
priority int
subdomains bool
shouldMatch bool
}
query string
expectedCalls int
}{
{
name: "case insensitive exact match",
scenario: "handler registered lowercase, query uppercase",
addHandlers: []struct {
pattern string
priority int
subdomains bool
shouldMatch bool
}{
{"example.com.", nbdns.PriorityDefault, false, true},
},
query: "EXAMPLE.COM.",
expectedCalls: 1,
},
{
name: "case insensitive wildcard match",
scenario: "handler registered mixed case wildcard, query different case",
addHandlers: []struct {
pattern string
priority int
subdomains bool
shouldMatch bool
}{
{"*.Example.Com.", nbdns.PriorityDefault, false, true},
},
query: "sub.EXAMPLE.COM.",
expectedCalls: 1,
},
{
name: "multiple handlers different case same domain",
scenario: "second handler should replace first despite case difference",
addHandlers: []struct {
pattern string
priority int
subdomains bool
shouldMatch bool
}{
{"EXAMPLE.COM.", nbdns.PriorityDefault, false, false},
{"example.com.", nbdns.PriorityDefault, false, true},
},
query: "ExAmPlE.cOm.",
expectedCalls: 1,
},
{
name: "subdomain matching case insensitive",
scenario: "handler with MatchSubdomains true should match regardless of case",
addHandlers: []struct {
pattern string
priority int
subdomains bool
shouldMatch bool
}{
{"example.com.", nbdns.PriorityDefault, true, true},
},
query: "SUB.EXAMPLE.COM.",
expectedCalls: 1,
},
{
name: "root zone case insensitive",
scenario: "root zone handler should match regardless of case",
addHandlers: []struct {
pattern string
priority int
subdomains bool
shouldMatch bool
}{
{".", nbdns.PriorityDefault, false, true},
},
query: "EXAMPLE.COM.",
expectedCalls: 1,
},
{
name: "multiple handlers different priority",
scenario: "should call higher priority handler despite case differences",
addHandlers: []struct {
pattern string
priority int
subdomains bool
shouldMatch bool
}{
{"EXAMPLE.COM.", nbdns.PriorityDefault, false, false},
{"example.com.", nbdns.PriorityMatchDomain, false, false},
{"Example.Com.", nbdns.PriorityDNSRoute, false, true},
},
query: "example.com.",
expectedCalls: 1,
},
}
for _, tt := range tests {
t.Run(tt.name, func(t *testing.T) {
chain := nbdns.NewHandlerChain()
handlerCalls := make(map[string]bool) // track which patterns were called
// Add handlers according to test case
for _, h := range tt.addHandlers {
var handler dns.Handler
pattern := h.pattern // capture pattern for closure
if h.subdomains {
subHandler := &nbdns.MockSubdomainHandler{
Subdomains: true,
}
if h.shouldMatch {
subHandler.On("ServeDNS", mock.Anything, mock.Anything).Run(func(args mock.Arguments) {
handlerCalls[pattern] = true
w := args.Get(0).(dns.ResponseWriter)
r := args.Get(1).(*dns.Msg)
resp := new(dns.Msg)
resp.SetRcode(r, dns.RcodeSuccess)
assert.NoError(t, w.WriteMsg(resp))
}).Once()
}
handler = subHandler
} else {
mockHandler := &nbdns.MockHandler{}
if h.shouldMatch {
mockHandler.On("ServeDNS", mock.Anything, mock.Anything).Run(func(args mock.Arguments) {
handlerCalls[pattern] = true
w := args.Get(0).(dns.ResponseWriter)
r := args.Get(1).(*dns.Msg)
resp := new(dns.Msg)
resp.SetRcode(r, dns.RcodeSuccess)
assert.NoError(t, w.WriteMsg(resp))
}).Once()
}
handler = mockHandler
}
chain.AddHandler(pattern, handler, h.priority, nil)
}
// Execute request
r := new(dns.Msg)
r.SetQuestion(tt.query, dns.TypeA)
chain.ServeDNS(&mockResponseWriter{}, r)
// Verify each handler was called exactly as expected
for _, h := range tt.addHandlers {
wasCalled := handlerCalls[h.pattern]
assert.Equal(t, h.shouldMatch, wasCalled,
"Handler for pattern %q was %s when it should%s have been",
h.pattern,
map[bool]string{true: "called", false: "not called"}[wasCalled],
map[bool]string{true: "", false: " not"}[wasCalled == h.shouldMatch])
}
// Verify total number of calls
assert.Equal(t, tt.expectedCalls, len(handlerCalls),
"Wrong number of total handler calls")
})
}
}

View File

@ -83,7 +83,7 @@ func (h *Manager) allowDNSFirewall() error {
IsRange: false,
Values: []int{ListenPort},
}
dnsRules, err := h.firewall.AddPeerFiltering(net.ParseIP("0.0.0.0"), firewall.ProtocolUDP, nil, dport, firewall.RuleDirectionIN, firewall.ActionAccept, "", "")
dnsRules, err := h.firewall.AddPeerFiltering(net.IP{0, 0, 0, 0}, firewall.ProtocolUDP, nil, dport, firewall.RuleDirectionIN, firewall.ActionAccept, "", "")
if err != nil {
log.Errorf("failed to add allow DNS router rules, err: %v", err)
return err

View File

@ -406,13 +406,9 @@ func (e *Engine) Start() error {
e.firewall, err = firewall.NewFirewall(e.wgInterface, e.stateManager)
if err != nil {
log.Errorf("failed creating firewall manager: %s", err)
}
if e.firewall != nil && e.firewall.IsServerRouteSupported() {
err = e.routeManager.EnableServerRouter(e.firewall)
if err != nil {
e.close()
return fmt.Errorf("enable server router: %w", err)
} else if e.firewall != nil {
if err := e.initFirewall(err); err != nil {
return err
}
}
@ -455,6 +451,41 @@ func (e *Engine) Start() error {
return nil
}
func (e *Engine) initFirewall(error) error {
if e.firewall.IsServerRouteSupported() {
if err := e.routeManager.EnableServerRouter(e.firewall); err != nil {
e.close()
return fmt.Errorf("enable server router: %w", err)
}
}
if e.rpManager == nil || !e.config.RosenpassEnabled {
return nil
}
rosenpassPort := e.rpManager.GetAddress().Port
port := manager.Port{Values: []int{rosenpassPort}}
// this rule is static and will be torn down on engine down by the firewall manager
if _, err := e.firewall.AddPeerFiltering(
net.IP{0, 0, 0, 0},
manager.ProtocolUDP,
nil,
&port,
manager.RuleDirectionIN,
manager.ActionAccept,
"",
"",
); err != nil {
log.Errorf("failed to allow rosenpass interface traffic: %v", err)
return nil
}
log.Infof("rosenpass interface traffic allowed on port %d", rosenpassPort)
return nil
}
// modifyPeers updates peers that have been modified (e.g. IP address has been changed).
// It closes the existing connection, removes it from the peerConns map, and creates a new one.
func (e *Engine) modifyPeers(peersUpdate []*mgmProto.RemotePeerConfig) error {

View File

@ -139,10 +139,6 @@ func (s *Server) DebugBundle(_ context.Context, req *proto.DebugBundleRequest) (
s.mutex.Lock()
defer s.mutex.Unlock()
if s.logFile == "console" {
return nil, fmt.Errorf("log file is set to console, cannot create debug bundle")
}
bundlePath, err := os.CreateTemp("", "netbird.debug.*.zip")
if err != nil {
return nil, fmt.Errorf("create zip file: %w", err)
@ -185,17 +181,7 @@ func (s *Server) createArchive(bundlePath *os.File, req *proto.DebugBundleReques
}
if req.GetSystemInfo() {
if err := s.addRoutes(req, anonymizer, archive); err != nil {
log.Errorf("Failed to add routes to debug bundle: %v", err)
}
if err := s.addInterfaces(req, anonymizer, archive); err != nil {
log.Errorf("Failed to add interfaces to debug bundle: %v", err)
}
if err := s.addFirewallRules(req, anonymizer, archive); err != nil {
log.Errorf("Failed to add firewall rules to debug bundle: %v", err)
}
s.addSystemInfo(req, anonymizer, archive)
}
if err := s.addNetworkMap(req, anonymizer, archive); err != nil {
@ -206,8 +192,10 @@ func (s *Server) createArchive(bundlePath *os.File, req *proto.DebugBundleReques
log.Errorf("Failed to add state file to debug bundle: %v", err)
}
if err := s.addLogfile(req, anonymizer, archive); err != nil {
return fmt.Errorf("add log file: %w", err)
if s.logFile != "console" {
if err := s.addLogfile(req, anonymizer, archive); err != nil {
return fmt.Errorf("add log file: %w", err)
}
}
if err := archive.Close(); err != nil {
@ -216,6 +204,20 @@ func (s *Server) createArchive(bundlePath *os.File, req *proto.DebugBundleReques
return nil
}
func (s *Server) addSystemInfo(req *proto.DebugBundleRequest, anonymizer *anonymize.Anonymizer, archive *zip.Writer) {
if err := s.addRoutes(req, anonymizer, archive); err != nil {
log.Errorf("Failed to add routes to debug bundle: %v", err)
}
if err := s.addInterfaces(req, anonymizer, archive); err != nil {
log.Errorf("Failed to add interfaces to debug bundle: %v", err)
}
if err := s.addFirewallRules(req, anonymizer, archive); err != nil {
log.Errorf("Failed to add firewall rules to debug bundle: %v", err)
}
}
func (s *Server) addReadme(req *proto.DebugBundleRequest, archive *zip.Writer) error {
if req.GetAnonymize() {
readmeReader := strings.NewReader(readmeContent)

View File

@ -1251,6 +1251,12 @@ func (am *DefaultAccountManager) GetAccountIDFromToken(ctx context.Context, clai
// syncJWTGroups processes the JWT groups for a user, updates the account based on the groups,
// and propagates changes to peers if group propagation is enabled.
func (am *DefaultAccountManager) syncJWTGroups(ctx context.Context, accountID string, claims jwtclaims.AuthorizationClaims) error {
if claim, exists := claims.Raw[jwtclaims.IsToken]; exists {
if isToken, ok := claim.(bool); ok && isToken {
return nil
}
}
settings, err := am.Store.GetAccountSettings(ctx, store.LockingStrengthShare, accountID)
if err != nil {
return err

View File

@ -2730,6 +2730,19 @@ func TestAccount_SetJWTGroups(t *testing.T) {
assert.NoError(t, manager.Store.SaveAccount(context.Background(), account), "unable to save account")
t.Run("skip sync for token auth type", func(t *testing.T) {
claims := jwtclaims.AuthorizationClaims{
UserId: "user1",
Raw: jwt.MapClaims{"groups": []interface{}{"group3"}, "is_token": true},
}
err = manager.syncJWTGroups(context.Background(), "accountID", claims)
assert.NoError(t, err, "unable to sync jwt groups")
user, err := manager.Store.GetUserByUserID(context.Background(), store.LockingStrengthShare, "user1")
assert.NoError(t, err, "unable to get user")
assert.Len(t, user.AutoGroups, 0, "JWT groups should not be synced")
})
t.Run("empty jwt groups", func(t *testing.T) {
claims := jwtclaims.AuthorizationClaims{
UserId: "user1",
@ -2823,7 +2836,7 @@ func TestAccount_SetJWTGroups(t *testing.T) {
assert.Len(t, user.AutoGroups, 1, "new group should be added")
})
t.Run("remove all JWT groups", func(t *testing.T) {
t.Run("remove all JWT groups when list is empty", func(t *testing.T) {
claims := jwtclaims.AuthorizationClaims{
UserId: "user1",
Raw: jwt.MapClaims{"groups": []interface{}{}},
@ -2834,7 +2847,20 @@ func TestAccount_SetJWTGroups(t *testing.T) {
user, err := manager.Store.GetUserByUserID(context.Background(), store.LockingStrengthShare, "user1")
assert.NoError(t, err, "unable to get user")
assert.Len(t, user.AutoGroups, 1, "only non-JWT groups should remain")
assert.Contains(t, user.AutoGroups, "group1", " group1 should still be present")
assert.Contains(t, user.AutoGroups, "group1", "group1 should still be present")
})
t.Run("remove all JWT groups when claim does not exist", func(t *testing.T) {
claims := jwtclaims.AuthorizationClaims{
UserId: "user2",
Raw: jwt.MapClaims{},
}
err = manager.syncJWTGroups(context.Background(), "accountID", claims)
assert.NoError(t, err, "unable to sync jwt groups")
user, err := manager.Store.GetUserByUserID(context.Background(), store.LockingStrengthShare, "user2")
assert.NoError(t, err, "unable to get user")
assert.Len(t, user.AutoGroups, 0, "all JWT groups should be removed")
})
}
@ -3038,9 +3064,9 @@ func BenchmarkSyncAndMarkPeer(b *testing.B) {
minMsPerOpCICD float64
maxMsPerOpCICD float64
}{
{"Small", 50, 5, 1, 3, 3, 10},
{"Small", 50, 5, 1, 3, 3, 11},
{"Medium", 500, 100, 7, 13, 10, 70},
{"Large", 5000, 200, 65, 80, 60, 200},
{"Large", 5000, 200, 65, 80, 60, 220},
{"Small single", 50, 10, 1, 3, 3, 70},
{"Medium single", 500, 10, 7, 13, 10, 26},
{"Large 5", 5000, 15, 65, 80, 60, 200},
@ -3180,7 +3206,7 @@ func BenchmarkLoginPeer_NewPeer(b *testing.B) {
maxMsPerOpCICD float64
}{
{"Small", 50, 5, 107, 120, 107, 160},
{"Medium", 500, 100, 105, 140, 105, 190},
{"Medium", 500, 100, 105, 140, 105, 220},
{"Large", 5000, 200, 180, 220, 180, 350},
{"Small single", 50, 10, 107, 120, 105, 160},
{"Medium single", 500, 10, 105, 140, 105, 170},

View File

@ -474,6 +474,10 @@ func validateDeleteGroup(ctx context.Context, transaction store.Store, group *ty
return status.Errorf(status.InvalidArgument, "deleting group ALL is not allowed")
}
if len(group.Resources) > 0 {
return &GroupLinkError{"network resource", group.Resources[0].ID}
}
if isLinked, linkedRoute := isGroupLinkedToRoute(ctx, transaction, group.AccountID, group.ID); isLinked {
return &GroupLinkError{"route", string(linkedRoute.NetID)}
}
@ -529,7 +533,10 @@ func isGroupLinkedToRoute(ctx context.Context, transaction store.Store, accountI
}
for _, r := range routes {
if slices.Contains(r.Groups, groupID) || slices.Contains(r.PeerGroups, groupID) {
isLinked := slices.Contains(r.Groups, groupID) ||
slices.Contains(r.PeerGroups, groupID) ||
slices.Contains(r.AccessControlGroups, groupID)
if isLinked {
return true, r
}
}

View File

@ -725,10 +725,6 @@ components:
PolicyRuleMinimum:
type: object
properties:
id:
description: Policy rule ID
type: string
example: ch8i4ug6lnn4g9hqv7mg
name:
description: Policy rule name identifier
type: string
@ -790,6 +786,31 @@ components:
- end
PolicyRuleUpdate:
allOf:
- $ref: '#/components/schemas/PolicyRuleMinimum'
- type: object
properties:
id:
description: Policy rule ID
type: string
example: ch8i4ug6lnn4g9hqv7mg
sources:
description: Policy rule source group IDs
type: array
items:
type: string
example: "ch8i4ug6lnn4g9hqv797"
destinations:
description: Policy rule destination group IDs
type: array
items:
type: string
example: "ch8i4ug6lnn4g9h7v7m0"
required:
- sources
- destinations
PolicyRuleCreate:
allOf:
- $ref: '#/components/schemas/PolicyRuleMinimum'
- type: object
@ -817,6 +838,10 @@ components:
- $ref: '#/components/schemas/PolicyRuleMinimum'
- type: object
properties:
id:
description: Policy rule ID
type: string
example: ch8i4ug6lnn4g9hqv7mg
sources:
description: Policy rule source group IDs
type: array
@ -836,10 +861,6 @@ components:
PolicyMinimum:
type: object
properties:
id:
description: Policy ID
type: string
example: ch8i4ug6lnn4g9hqv7mg
name:
description: Policy name identifier
type: string
@ -854,7 +875,6 @@ components:
example: true
required:
- name
- description
- enabled
PolicyUpdate:
allOf:
@ -874,11 +894,33 @@ components:
$ref: '#/components/schemas/PolicyRuleUpdate'
required:
- rules
PolicyCreate:
allOf:
- $ref: '#/components/schemas/PolicyMinimum'
- type: object
properties:
source_posture_checks:
description: Posture checks ID's applied to policy source groups
type: array
items:
type: string
example: "chacdk86lnnboviihd70"
rules:
description: Policy rule object for policy UI editor
type: array
items:
$ref: '#/components/schemas/PolicyRuleUpdate'
required:
- rules
Policy:
allOf:
- $ref: '#/components/schemas/PolicyMinimum'
- type: object
properties:
id:
description: Policy ID
type: string
example: ch8i4ug6lnn4g9hqv7mg
source_posture_checks:
description: Posture checks ID's applied to policy source groups
type: array
@ -2463,7 +2505,7 @@ paths:
content:
'application/json':
schema:
$ref: '#/components/schemas/PolicyUpdate'
$ref: '#/components/schemas/PolicyCreate'
responses:
'200':
description: A Policy object

View File

@ -879,7 +879,7 @@ type PersonalAccessTokenRequest struct {
// Policy defines model for Policy.
type Policy struct {
// Description Policy friendly description
Description string `json:"description"`
Description *string `json:"description,omitempty"`
// Enabled Policy status
Enabled bool `json:"enabled"`
@ -897,16 +897,31 @@ type Policy struct {
SourcePostureChecks []string `json:"source_posture_checks"`
}
// PolicyMinimum defines model for PolicyMinimum.
type PolicyMinimum struct {
// PolicyCreate defines model for PolicyCreate.
type PolicyCreate struct {
// Description Policy friendly description
Description string `json:"description"`
Description *string `json:"description,omitempty"`
// Enabled Policy status
Enabled bool `json:"enabled"`
// Id Policy ID
Id *string `json:"id,omitempty"`
// Name Policy name identifier
Name string `json:"name"`
// Rules Policy rule object for policy UI editor
Rules []PolicyRuleUpdate `json:"rules"`
// SourcePostureChecks Posture checks ID's applied to policy source groups
SourcePostureChecks *[]string `json:"source_posture_checks,omitempty"`
}
// PolicyMinimum defines model for PolicyMinimum.
type PolicyMinimum struct {
// Description Policy friendly description
Description *string `json:"description,omitempty"`
// Enabled Policy status
Enabled bool `json:"enabled"`
// Name Policy name identifier
Name string `json:"name"`
@ -970,9 +985,6 @@ type PolicyRuleMinimum struct {
// Enabled Policy rule status
Enabled bool `json:"enabled"`
// Id Policy rule ID
Id *string `json:"id,omitempty"`
// Name Policy rule name identifier
Name string `json:"name"`
@ -1039,14 +1051,11 @@ type PolicyRuleUpdateProtocol string
// PolicyUpdate defines model for PolicyUpdate.
type PolicyUpdate struct {
// Description Policy friendly description
Description string `json:"description"`
Description *string `json:"description,omitempty"`
// Enabled Policy status
Enabled bool `json:"enabled"`
// Id Policy ID
Id *string `json:"id,omitempty"`
// Name Policy name identifier
Name string `json:"name"`
@ -1473,7 +1482,7 @@ type PutApiPeersPeerIdJSONRequestBody = PeerRequest
type PostApiPoliciesJSONRequestBody = PolicyUpdate
// PutApiPoliciesPolicyIdJSONRequestBody defines body for PutApiPoliciesPolicyId for application/json ContentType.
type PutApiPoliciesPolicyIdJSONRequestBody = PolicyUpdate
type PutApiPoliciesPolicyIdJSONRequestBody = PolicyCreate
// PostApiPostureChecksJSONRequestBody defines body for PostApiPostureChecks for application/json ContentType.
type PostApiPostureChecksJSONRequestBody = PostureCheckUpdate

View File

@ -133,16 +133,21 @@ func (h *handler) savePolicy(w http.ResponseWriter, r *http.Request, accountID s
return
}
description := ""
if req.Description != nil {
description = *req.Description
}
policy := &types.Policy{
ID: policyID,
AccountID: accountID,
Name: req.Name,
Enabled: req.Enabled,
Description: req.Description,
Description: description,
}
for _, rule := range req.Rules {
var ruleID string
if rule.Id != nil {
if rule.Id != nil && policyID != "" {
ruleID = *rule.Id
}
@ -370,7 +375,7 @@ func toPolicyResponse(groups []*types.Group, policy *types.Policy) *api.Policy {
ap := &api.Policy{
Id: &policy.ID,
Name: policy.Name,
Description: policy.Description,
Description: &policy.Description,
Enabled: policy.Enabled,
SourcePostureChecks: policy.SourcePostureChecks,
}

View File

@ -154,6 +154,7 @@ func TestPoliciesGetPolicy(t *testing.T) {
func TestPoliciesWritePolicy(t *testing.T) {
str := func(s string) *string { return &s }
emptyString := ""
tt := []struct {
name string
expectedStatus int
@ -184,8 +185,9 @@ func TestPoliciesWritePolicy(t *testing.T) {
expectedStatus: http.StatusOK,
expectedBody: true,
expectedPolicy: &api.Policy{
Id: str("id-was-set"),
Name: "Default POSTed Policy",
Id: str("id-was-set"),
Name: "Default POSTed Policy",
Description: &emptyString,
Rules: []api.PolicyRule{
{
Id: str("id-was-set"),
@ -232,8 +234,9 @@ func TestPoliciesWritePolicy(t *testing.T) {
expectedStatus: http.StatusOK,
expectedBody: true,
expectedPolicy: &api.Policy{
Id: str("id-existed"),
Name: "Default POSTed Policy",
Id: str("id-existed"),
Name: "Default POSTed Policy",
Description: &emptyString,
Rules: []api.PolicyRule{
{
Id: str("id-existed"),

View File

@ -175,6 +175,7 @@ func (m *AuthMiddleware) checkPATFromRequest(w http.ResponseWriter, r *http.Requ
claimMaps[m.audience+jwtclaims.AccountIDSuffix] = account.Id
claimMaps[m.audience+jwtclaims.DomainIDSuffix] = account.Domain
claimMaps[m.audience+jwtclaims.DomainCategorySuffix] = account.DomainCategory
claimMaps[jwtclaims.IsToken] = true
jwtToken := jwt.NewWithClaims(jwt.SigningMethodHS256, claimMaps)
newRequest := r.WithContext(context.WithValue(r.Context(), jwtclaims.TokenUserProperty, jwtToken)) //nolint
// Update the current request with the new context information.

View File

@ -22,6 +22,8 @@ const (
LastLoginSuffix = "nb_last_login"
// Invited claim indicates that an incoming JWT is from a user that just accepted an invitation
Invited = "nb_invited"
// IsToken claim indicates that auth type from the user is a token
IsToken = "is_token"
)
// ExtractClaims Extract function type

View File

@ -195,6 +195,10 @@ func (w *Worker) generateProperties(ctx context.Context) properties {
groups int
routes int
routesWithRGGroups int
networks int
networkResources int
networkRouters int
networkRoutersWithPG int
nameservers int
uiClient int
version string
@ -219,6 +223,16 @@ func (w *Worker) generateProperties(ctx context.Context) properties {
}
groups += len(account.Groups)
networks += len(account.Networks)
networkResources += len(account.NetworkResources)
networkRouters += len(account.NetworkRouters)
for _, router := range account.NetworkRouters {
if len(router.PeerGroups) > 0 {
networkRoutersWithPG++
}
}
routes += len(account.Routes)
for _, route := range account.Routes {
if len(route.PeerGroups) > 0 {
@ -312,6 +326,10 @@ func (w *Worker) generateProperties(ctx context.Context) properties {
metricsProperties["rules_with_src_posture_checks"] = rulesWithSrcPostureChecks
metricsProperties["posture_checks"] = postureChecks
metricsProperties["groups"] = groups
metricsProperties["networks"] = networks
metricsProperties["network_resources"] = networkResources
metricsProperties["network_routers"] = networkRouters
metricsProperties["network_routers_with_groups"] = networkRoutersWithPG
metricsProperties["routes"] = routes
metricsProperties["routes_with_routing_groups"] = routesWithRGGroups
metricsProperties["nameservers"] = nameservers

View File

@ -5,6 +5,9 @@ import (
"testing"
nbdns "github.com/netbirdio/netbird/dns"
resourceTypes "github.com/netbirdio/netbird/management/server/networks/resources/types"
routerTypes "github.com/netbirdio/netbird/management/server/networks/routers/types"
networkTypes "github.com/netbirdio/netbird/management/server/networks/types"
nbpeer "github.com/netbirdio/netbird/management/server/peer"
"github.com/netbirdio/netbird/management/server/posture"
"github.com/netbirdio/netbird/management/server/store"
@ -172,6 +175,31 @@ func (mockDatasource) GetAllAccounts(_ context.Context) []*types.Account {
},
},
},
Networks: []*networkTypes.Network{
{
ID: "1",
AccountID: "1",
},
},
NetworkResources: []*resourceTypes.NetworkResource{
{
ID: "1",
AccountID: "1",
NetworkID: "1",
},
{
ID: "2",
AccountID: "1",
NetworkID: "1",
},
},
NetworkRouters: []*routerTypes.NetworkRouter{
{
ID: "1",
AccountID: "1",
NetworkID: "1",
},
},
},
}
}
@ -200,6 +228,15 @@ func TestGenerateProperties(t *testing.T) {
if properties["routes"] != 2 {
t.Errorf("expected 2 routes, got %d", properties["routes"])
}
if properties["networks"] != 1 {
t.Errorf("expected 1 networks, got %d", properties["networks"])
}
if properties["network_resources"] != 2 {
t.Errorf("expected 2 network_resources, got %d", properties["network_resources"])
}
if properties["network_routers"] != 1 {
t.Errorf("expected 1 network_routers, got %d", properties["network_routers"])
}
if properties["rules"] != 4 {
t.Errorf("expected 4 rules, got %d", properties["rules"])
}

View File

@ -111,6 +111,7 @@ func (n *NetworkResource) ToRoute(peer *nbpeer.Peer, router *routerTypes.Network
NetID: route.NetID(n.Name),
Description: n.Description,
Peer: peer.Key,
PeerID: peer.ID,
PeerGroups: nil,
Masquerade: router.Masquerade,
Metric: router.Metric,

View File

@ -932,11 +932,11 @@ func BenchmarkUpdateAccountPeers(b *testing.B) {
}{
{"Small", 50, 5, 90, 120, 90, 120},
{"Medium", 500, 100, 110, 150, 120, 260},
{"Large", 5000, 200, 800, 1390, 2500, 4600},
{"Large", 5000, 200, 800, 1700, 2500, 5000},
{"Small single", 50, 10, 90, 120, 90, 120},
{"Medium single", 500, 10, 110, 170, 120, 200},
{"Large 5", 5000, 15, 1300, 2100, 5000, 7000},
{"Extra Large", 2000, 2000, 1300, 2100, 4000, 6000},
{"Large 5", 5000, 15, 1300, 2100, 4900, 7000},
{"Extra Large", 2000, 2000, 1300, 2400, 4000, 6400},
}
log.SetOutput(io.Discard)

View File

@ -74,6 +74,19 @@ func TestAccount_getPeersByPolicy(t *testing.T) {
"peerH",
},
},
"GroupWorkstations": {
ID: "GroupWorkstations",
Name: "GroupWorkstations",
Peers: []string{
"peerB",
"peerA",
"peerD",
"peerE",
"peerF",
"peerG",
"peerH",
},
},
"GroupSwarm": {
ID: "GroupSwarm",
Name: "swarm",
@ -127,7 +140,7 @@ func TestAccount_getPeersByPolicy(t *testing.T) {
Action: types.PolicyTrafficActionAccept,
Sources: []string{
"GroupSwarm",
"GroupAll",
"GroupWorkstations",
},
Destinations: []string{
"GroupSwarm",
@ -159,6 +172,8 @@ func TestAccount_getPeersByPolicy(t *testing.T) {
assert.Contains(t, peers, account.Peers["peerD"])
assert.Contains(t, peers, account.Peers["peerE"])
assert.Contains(t, peers, account.Peers["peerF"])
assert.Contains(t, peers, account.Peers["peerG"])
assert.Contains(t, peers, account.Peers["peerH"])
epectedFirewallRules := []*types.FirewallRule{
{
@ -189,21 +204,6 @@ func TestAccount_getPeersByPolicy(t *testing.T) {
Protocol: "all",
Port: "",
},
{
PeerIP: "100.65.254.139",
Direction: types.FirewallRuleDirectionOUT,
Action: "accept",
Protocol: "all",
Port: "",
},
{
PeerIP: "100.65.254.139",
Direction: types.FirewallRuleDirectionIN,
Action: "accept",
Protocol: "all",
Port: "",
},
{
PeerIP: "100.65.62.5",
Direction: types.FirewallRuleDirectionOUT,
@ -280,10 +280,16 @@ func TestAccount_getPeersByPolicy(t *testing.T) {
},
}
assert.Len(t, firewallRules, len(epectedFirewallRules))
slices.SortFunc(epectedFirewallRules, sortFunc())
slices.SortFunc(firewallRules, sortFunc())
for i := range firewallRules {
assert.Equal(t, epectedFirewallRules[i], firewallRules[i])
for _, rule := range firewallRules {
contains := false
for _, expectedRule := range epectedFirewallRules {
if rule.IsEqual(expectedRule) {
contains = true
break
}
}
assert.True(t, contains, "rule not found in expected rules %#v", rule)
}
})
}

View File

@ -364,7 +364,7 @@ func toProtocolRoute(route *route.Route) *proto.Route {
}
func toProtocolRoutes(routes []*route.Route) []*proto.Route {
protoRoutes := make([]*proto.Route, 0)
protoRoutes := make([]*proto.Route, 0, len(routes))
for _, r := range routes {
protoRoutes = append(protoRoutes, toProtocolRoute(r))
}

View File

@ -303,55 +303,47 @@ func (a *Account) GetPeerNetworkMap(
return nm
}
func (a *Account) addNetworksRoutingPeers(networkResourcesRoutes []*route.Route, peer *nbpeer.Peer, peersToConnect []*nbpeer.Peer, expiredPeers []*nbpeer.Peer, isRouter bool, sourcePeers []string) []*nbpeer.Peer {
missingPeers := map[string]struct{}{}
for _, r := range networkResourcesRoutes {
if r.Peer == peer.Key {
continue
}
func (a *Account) addNetworksRoutingPeers(
networkResourcesRoutes []*route.Route,
peer *nbpeer.Peer,
peersToConnect []*nbpeer.Peer,
expiredPeers []*nbpeer.Peer,
isRouter bool,
sourcePeers map[string]struct{},
) []*nbpeer.Peer {
missing := true
for _, p := range slices.Concat(peersToConnect, expiredPeers) {
if r.Peer == p.Key {
missing = false
break
}
}
if missing {
missingPeers[r.Peer] = struct{}{}
}
networkRoutesPeers := make(map[string]struct{}, len(networkResourcesRoutes))
for _, r := range networkResourcesRoutes {
networkRoutesPeers[r.PeerID] = struct{}{}
}
if isRouter {
for _, s := range sourcePeers {
if s == peer.ID {
continue
}
delete(sourcePeers, peer.ID)
missing := true
for _, p := range slices.Concat(peersToConnect, expiredPeers) {
if s == p.ID {
missing = false
break
}
}
if missing {
p, ok := a.Peers[s]
if ok {
missingPeers[p.Key] = struct{}{}
}
}
for _, existingPeer := range peersToConnect {
delete(sourcePeers, existingPeer.ID)
delete(networkRoutesPeers, existingPeer.ID)
}
for _, expPeer := range expiredPeers {
delete(sourcePeers, expPeer.ID)
delete(networkRoutesPeers, expPeer.ID)
}
missingPeers := make(map[string]struct{}, len(sourcePeers)+len(networkRoutesPeers))
if isRouter {
for p := range sourcePeers {
missingPeers[p] = struct{}{}
}
}
for p := range networkRoutesPeers {
missingPeers[p] = struct{}{}
}
for p := range missingPeers {
for _, p2 := range a.Peers {
if p2.Key == p {
peersToConnect = append(peersToConnect, p2)
break
}
if missingPeer := a.Peers[p]; missingPeer != nil {
peersToConnect = append(peersToConnect, missingPeer)
}
}
return peersToConnect
}
@ -1045,37 +1037,32 @@ func (a *Account) connResourcesGenerator(ctx context.Context) (func(*PolicyRule,
// for destination group peers, call this method with an empty list of sourcePostureChecksIDs
func (a *Account) getAllPeersFromGroups(ctx context.Context, groups []string, peerID string, sourcePostureChecksIDs []string, validatedPeersMap map[string]struct{}) ([]*nbpeer.Peer, bool) {
peerInGroups := false
filteredPeers := make([]*nbpeer.Peer, 0, len(groups))
for _, g := range groups {
group, ok := a.Groups[g]
if !ok {
uniquePeerIDs := a.getUniquePeerIDsFromGroupsIDs(ctx, groups)
filteredPeers := make([]*nbpeer.Peer, 0, len(uniquePeerIDs))
for _, p := range uniquePeerIDs {
peer, ok := a.Peers[p]
if !ok || peer == nil {
continue
}
for _, p := range group.Peers {
peer, ok := a.Peers[p]
if !ok || peer == nil {
continue
}
// validate the peer based on policy posture checks applied
isValid := a.validatePostureChecksOnPeer(ctx, sourcePostureChecksIDs, peer.ID)
if !isValid {
continue
}
if _, ok := validatedPeersMap[peer.ID]; !ok {
continue
}
if peer.ID == peerID {
peerInGroups = true
continue
}
filteredPeers = append(filteredPeers, peer)
// validate the peer based on policy posture checks applied
isValid := a.validatePostureChecksOnPeer(ctx, sourcePostureChecksIDs, peer.ID)
if !isValid {
continue
}
if _, ok := validatedPeersMap[peer.ID]; !ok {
continue
}
if peer.ID == peerID {
peerInGroups = true
continue
}
filteredPeers = append(filteredPeers, peer)
}
return filteredPeers, peerInGroups
}
@ -1151,7 +1138,7 @@ func (a *Account) getRouteFirewallRules(ctx context.Context, peerID string, poli
continue
}
rulePeers := a.getRulePeers(rule, peerID, distributionPeers, validatedPeersMap)
rulePeers := a.getRulePeers(rule, policy.SourcePostureChecks, peerID, distributionPeers, validatedPeersMap)
rules := generateRouteFirewallRules(ctx, route, rule, rulePeers, FirewallRuleDirectionIN)
fwRules = append(fwRules, rules...)
}
@ -1159,7 +1146,7 @@ func (a *Account) getRouteFirewallRules(ctx context.Context, peerID string, poli
return fwRules
}
func (a *Account) getRulePeers(rule *PolicyRule, peerID string, distributionPeers map[string]struct{}, validatedPeersMap map[string]struct{}) []*nbpeer.Peer {
func (a *Account) getRulePeers(rule *PolicyRule, postureChecks []string, peerID string, distributionPeers map[string]struct{}, validatedPeersMap map[string]struct{}) []*nbpeer.Peer {
distPeersWithPolicy := make(map[string]struct{})
for _, id := range rule.Sources {
group := a.Groups[id]
@ -1173,7 +1160,7 @@ func (a *Account) getRulePeers(rule *PolicyRule, peerID string, distributionPeer
}
_, distPeer := distributionPeers[pID]
_, valid := validatedPeersMap[pID]
if distPeer && valid {
if distPeer && valid && a.validatePostureChecksOnPeer(context.Background(), postureChecks, pID) {
distPeersWithPolicy[pID] = struct{}{}
}
}
@ -1271,7 +1258,11 @@ func (a *Account) GetPeerNetworkResourceFirewallRules(ctx context.Context, peer
distributionPeers := getPoliciesSourcePeers(resourceAppliedPolicies, a.Groups)
rules := a.getRouteFirewallRules(ctx, peer.ID, resourceAppliedPolicies, route, validatedPeersMap, distributionPeers)
routesFirewallRules = append(routesFirewallRules, rules...)
for _, rule := range rules {
if len(rule.SourceRanges) > 0 {
routesFirewallRules = append(routesFirewallRules, rule)
}
}
}
return routesFirewallRules
@ -1303,10 +1294,10 @@ func (a *Account) GetResourcePoliciesMap() map[string][]*Policy {
}
// GetNetworkResourcesRoutesToSync returns network routes for syncing with a specific peer and its ACL peers.
func (a *Account) GetNetworkResourcesRoutesToSync(ctx context.Context, peerID string, resourcePolicies map[string][]*Policy, routers map[string]map[string]*routerTypes.NetworkRouter) (bool, []*route.Route, []string) {
func (a *Account) GetNetworkResourcesRoutesToSync(ctx context.Context, peerID string, resourcePolicies map[string][]*Policy, routers map[string]map[string]*routerTypes.NetworkRouter) (bool, []*route.Route, map[string]struct{}) {
var isRoutingPeer bool
var routes []*route.Route
var allSourcePeers []string
allSourcePeers := make(map[string]struct{}, len(a.Peers))
for _, resource := range a.NetworkResources {
var addSourcePeers bool
@ -1319,23 +1310,22 @@ func (a *Account) GetNetworkResourcesRoutesToSync(ctx context.Context, peerID st
}
}
addedResourceRoute := false
for _, policy := range resourcePolicies[resource.ID] {
for _, sourceGroup := range policy.SourceGroups() {
group := a.GetGroup(sourceGroup)
if group == nil {
log.WithContext(ctx).Warnf("policy %s has source group %s that doesn't exist under account %s, will continue map generation without it", policy.ID, sourceGroup, a.Id)
continue
peers := a.getUniquePeerIDsFromGroupsIDs(ctx, policy.SourceGroups())
if addSourcePeers {
for _, pID := range a.getPostureValidPeers(peers, policy.SourcePostureChecks) {
allSourcePeers[pID] = struct{}{}
}
// routing peer should be able to connect with all source peers
if addSourcePeers {
allSourcePeers = append(allSourcePeers, group.Peers...)
} else if slices.Contains(group.Peers, peerID) {
// add routes for the resource if the peer is in the distribution group
for peerId, router := range networkRoutingPeers {
routes = append(routes, a.getNetworkResourcesRoutes(resource, peerId, router, resourcePolicies)...)
}
} else if slices.Contains(peers, peerID) && a.validatePostureChecksOnPeer(ctx, policy.SourcePostureChecks, peerID) {
// add routes for the resource if the peer is in the distribution group
for peerId, router := range networkRoutingPeers {
routes = append(routes, a.getNetworkResourcesRoutes(resource, peerId, router, resourcePolicies)...)
}
addedResourceRoute = true
}
if addedResourceRoute {
break
}
}
}
@ -1343,6 +1333,42 @@ func (a *Account) GetNetworkResourcesRoutesToSync(ctx context.Context, peerID st
return isRoutingPeer, routes, allSourcePeers
}
func (a *Account) getPostureValidPeers(inputPeers []string, postureChecksIDs []string) []string {
var dest []string
for _, peerID := range inputPeers {
if a.validatePostureChecksOnPeer(context.Background(), postureChecksIDs, peerID) {
dest = append(dest, peerID)
}
}
return dest
}
func (a *Account) getUniquePeerIDsFromGroupsIDs(ctx context.Context, groups []string) []string {
peerIDs := make(map[string]struct{}, len(groups)) // we expect at least one peer per group as initial capacity
for _, groupID := range groups {
group := a.GetGroup(groupID)
if group == nil {
log.WithContext(ctx).Warnf("group %s doesn't exist under account %s, will continue map generation without it", groupID, a.Id)
continue
}
if group.IsGroupAll() || len(groups) == 1 {
return group.Peers
}
for _, peerID := range group.Peers {
peerIDs[peerID] = struct{}{}
}
}
ids := make([]string, 0, len(peerIDs))
for peerID := range peerIDs {
ids = append(ids, peerID)
}
return ids
}
// getNetworkResources filters and returns a list of network resources associated with the given network ID.
func (a *Account) getNetworkResources(networkID string) []*resourceTypes.NetworkResource {
var resources []*resourceTypes.NetworkResource

View File

@ -1,14 +1,20 @@
package types
import (
"context"
"net"
"net/netip"
"slices"
"testing"
"github.com/stretchr/testify/assert"
"github.com/stretchr/testify/require"
resourceTypes "github.com/netbirdio/netbird/management/server/networks/resources/types"
routerTypes "github.com/netbirdio/netbird/management/server/networks/routers/types"
networkTypes "github.com/netbirdio/netbird/management/server/networks/types"
nbpeer "github.com/netbirdio/netbird/management/server/peer"
"github.com/netbirdio/netbird/management/server/posture"
"github.com/netbirdio/netbird/route"
)
@ -310,19 +316,19 @@ func Test_GetResourcePoliciesMap(t *testing.T) {
func Test_AddNetworksRoutingPeersAddsMissingPeers(t *testing.T) {
account := setupTestAccount()
peer := &nbpeer.Peer{Key: "peer1"}
peer := &nbpeer.Peer{Key: "peer1Key", ID: "peer1"}
networkResourcesRoutes := []*route.Route{
{Peer: "peer2Key"},
{Peer: "peer3Key"},
{Peer: "peer2Key", PeerID: "peer2"},
{Peer: "peer3Key", PeerID: "peer3"},
}
peersToConnect := []*nbpeer.Peer{
{Key: "peer2Key"},
{Key: "peer2Key", ID: "peer2"},
}
expiredPeers := []*nbpeer.Peer{
{Key: "peer4Key"},
{Key: "peer4Key", ID: "peer4"},
}
result := account.addNetworksRoutingPeers(networkResourcesRoutes, peer, peersToConnect, expiredPeers, false, []string{})
result := account.addNetworksRoutingPeers(networkResourcesRoutes, peer, peersToConnect, expiredPeers, false, map[string]struct{}{})
require.Len(t, result, 2)
require.Equal(t, "peer2Key", result[0].Key)
require.Equal(t, "peer3Key", result[1].Key)
@ -339,7 +345,7 @@ func Test_AddNetworksRoutingPeersIgnoresExistingPeers(t *testing.T) {
}
expiredPeers := []*nbpeer.Peer{}
result := account.addNetworksRoutingPeers(networkResourcesRoutes, peer, peersToConnect, expiredPeers, false, []string{})
result := account.addNetworksRoutingPeers(networkResourcesRoutes, peer, peersToConnect, expiredPeers, false, map[string]struct{}{})
require.Len(t, result, 1)
require.Equal(t, "peer2Key", result[0].Key)
}
@ -358,7 +364,7 @@ func Test_AddNetworksRoutingPeersAddsExpiredPeers(t *testing.T) {
{Key: "peer3Key"},
}
result := account.addNetworksRoutingPeers(networkResourcesRoutes, peer, peersToConnect, expiredPeers, false, []string{})
result := account.addNetworksRoutingPeers(networkResourcesRoutes, peer, peersToConnect, expiredPeers, false, map[string]struct{}{})
require.Len(t, result, 1)
require.Equal(t, "peer2Key", result[0].Key)
}
@ -370,6 +376,382 @@ func Test_AddNetworksRoutingPeersHandlesNoMissingPeers(t *testing.T) {
peersToConnect := []*nbpeer.Peer{}
expiredPeers := []*nbpeer.Peer{}
result := account.addNetworksRoutingPeers(networkResourcesRoutes, peer, peersToConnect, expiredPeers, false, []string{})
result := account.addNetworksRoutingPeers(networkResourcesRoutes, peer, peersToConnect, expiredPeers, false, map[string]struct{}{})
require.Len(t, result, 0)
}
const (
accID = "accountID"
network1ID = "network1ID"
group1ID = "group1"
accNetResourcePeer1ID = "peer1"
accNetResourcePeer2ID = "peer2"
accNetResourceRouter1ID = "router1"
accNetResource1ID = "resource1ID"
accNetResourceRestrictPostureCheckID = "restrictPostureCheck"
accNetResourceRelaxedPostureCheckID = "relaxedPostureCheck"
accNetResourceLockedPostureCheckID = "lockedPostureCheck"
accNetResourceLinuxPostureCheckID = "linuxPostureCheck"
)
var (
accNetResourcePeer1IP = net.IP{192, 168, 1, 1}
accNetResourcePeer2IP = net.IP{192, 168, 1, 2}
accNetResourceRouter1IP = net.IP{192, 168, 1, 3}
accNetResourceValidPeers = map[string]struct{}{accNetResourcePeer1ID: {}, accNetResourcePeer2ID: {}}
)
func getBasicAccountsWithResource() *Account {
return &Account{
Id: accID,
Peers: map[string]*nbpeer.Peer{
accNetResourcePeer1ID: {
ID: accNetResourcePeer1ID,
AccountID: accID,
Key: "peer1Key",
IP: accNetResourcePeer1IP,
Meta: nbpeer.PeerSystemMeta{
GoOS: "linux",
WtVersion: "0.35.1",
KernelVersion: "4.4.0",
},
},
accNetResourcePeer2ID: {
ID: accNetResourcePeer2ID,
AccountID: accID,
Key: "peer2Key",
IP: accNetResourcePeer2IP,
Meta: nbpeer.PeerSystemMeta{
GoOS: "windows",
WtVersion: "0.34.1",
KernelVersion: "4.4.0",
},
},
accNetResourceRouter1ID: {
ID: accNetResourceRouter1ID,
AccountID: accID,
Key: "router1Key",
IP: accNetResourceRouter1IP,
Meta: nbpeer.PeerSystemMeta{
GoOS: "linux",
WtVersion: "0.35.1",
KernelVersion: "4.4.0",
},
},
},
Groups: map[string]*Group{
group1ID: {
ID: group1ID,
Peers: []string{accNetResourcePeer1ID, accNetResourcePeer2ID},
},
},
Networks: []*networkTypes.Network{
{
ID: network1ID,
AccountID: accID,
Name: "network1",
},
},
NetworkRouters: []*routerTypes.NetworkRouter{
{
ID: accNetResourceRouter1ID,
NetworkID: network1ID,
AccountID: accID,
Peer: accNetResourceRouter1ID,
PeerGroups: []string{},
Masquerade: false,
Metric: 100,
},
},
NetworkResources: []*resourceTypes.NetworkResource{
{
ID: accNetResource1ID,
AccountID: accID,
NetworkID: network1ID,
Address: "10.10.10.0/24",
Prefix: netip.MustParsePrefix("10.10.10.0/24"),
Type: resourceTypes.NetworkResourceType("subnet"),
},
},
Policies: []*Policy{
{
ID: "policy1ID",
AccountID: accID,
Enabled: true,
Rules: []*PolicyRule{
{
ID: "rule1ID",
Enabled: true,
Sources: []string{group1ID},
DestinationResource: Resource{
ID: accNetResource1ID,
Type: "Host",
},
Protocol: PolicyRuleProtocolTCP,
Ports: []string{"80"},
Action: PolicyTrafficActionAccept,
},
},
SourcePostureChecks: nil,
},
},
PostureChecks: []*posture.Checks{
{
ID: accNetResourceRestrictPostureCheckID,
Name: accNetResourceRestrictPostureCheckID,
Checks: posture.ChecksDefinition{
NBVersionCheck: &posture.NBVersionCheck{
MinVersion: "0.35.0",
},
},
},
{
ID: accNetResourceRelaxedPostureCheckID,
Name: accNetResourceRelaxedPostureCheckID,
Checks: posture.ChecksDefinition{
NBVersionCheck: &posture.NBVersionCheck{
MinVersion: "0.0.1",
},
},
},
{
ID: accNetResourceLockedPostureCheckID,
Name: accNetResourceLockedPostureCheckID,
Checks: posture.ChecksDefinition{
NBVersionCheck: &posture.NBVersionCheck{
MinVersion: "7.7.7",
},
},
},
{
ID: accNetResourceLinuxPostureCheckID,
Name: accNetResourceLinuxPostureCheckID,
Checks: posture.ChecksDefinition{
OSVersionCheck: &posture.OSVersionCheck{
Linux: &posture.MinKernelVersionCheck{
MinKernelVersion: "0.0.0"},
},
},
},
},
}
}
func Test_NetworksNetMapGenWithNoPostureChecks(t *testing.T) {
account := getBasicAccountsWithResource()
// all peers should match the policy
// validate for peer1
isRouter, networkResourcesRoutes, sourcePeers := account.GetNetworkResourcesRoutesToSync(context.Background(), accNetResourcePeer1ID, account.GetResourcePoliciesMap(), account.GetResourceRoutersMap())
assert.False(t, isRouter, "expected router status")
assert.Len(t, networkResourcesRoutes, 1, "expected network resource route don't match")
assert.Len(t, sourcePeers, 0, "expected source peers don't match")
// validate for peer2
isRouter, networkResourcesRoutes, sourcePeers = account.GetNetworkResourcesRoutesToSync(context.Background(), accNetResourcePeer2ID, account.GetResourcePoliciesMap(), account.GetResourceRoutersMap())
assert.False(t, isRouter, "expected router status")
assert.Len(t, networkResourcesRoutes, 1, "expected network resource route don't match")
assert.Len(t, sourcePeers, 0, "expected source peers don't match")
// validate routes for router1
isRouter, networkResourcesRoutes, sourcePeers = account.GetNetworkResourcesRoutesToSync(context.Background(), accNetResourceRouter1ID, account.GetResourcePoliciesMap(), account.GetResourceRoutersMap())
assert.True(t, isRouter, "should be router")
assert.Len(t, networkResourcesRoutes, 1, "expected network resource route don't match")
assert.Len(t, sourcePeers, 2, "expected source peers don't match")
assert.NotNil(t, sourcePeers[accNetResourcePeer1ID], "expected source peers don't match")
assert.NotNil(t, sourcePeers[accNetResourcePeer2ID], "expected source peers don't match")
// validate rules for router1
rules := account.GetPeerNetworkResourceFirewallRules(context.Background(), account.Peers[accNetResourceRouter1ID], accNetResourceValidPeers, networkResourcesRoutes, account.GetResourcePoliciesMap())
assert.Len(t, rules, 1, "expected rules count don't match")
assert.Equal(t, uint16(80), rules[0].Port, "should have port 80")
assert.Equal(t, "tcp", rules[0].Protocol, "should have protocol tcp")
if !slices.Contains(rules[0].SourceRanges, accNetResourcePeer1IP.String()+"/32") {
t.Errorf("%s should have source range of peer1 %s", rules[0].SourceRanges, accNetResourcePeer1IP.String())
}
if !slices.Contains(rules[0].SourceRanges, accNetResourcePeer2IP.String()+"/32") {
t.Errorf("%s should have source range of peer2 %s", rules[0].SourceRanges, accNetResourcePeer2IP.String())
}
}
func Test_NetworksNetMapGenWithPostureChecks(t *testing.T) {
account := getBasicAccountsWithResource()
// should allow peer1 to match the policy
policy := account.Policies[0]
policy.SourcePostureChecks = []string{accNetResourceRestrictPostureCheckID}
// validate for peer1
isRouter, networkResourcesRoutes, sourcePeers := account.GetNetworkResourcesRoutesToSync(context.Background(), accNetResourcePeer1ID, account.GetResourcePoliciesMap(), account.GetResourceRoutersMap())
assert.False(t, isRouter, "expected router status")
assert.Len(t, networkResourcesRoutes, 1, "expected network resource route don't match")
assert.Len(t, sourcePeers, 0, "expected source peers don't match")
// validate for peer2
isRouter, networkResourcesRoutes, sourcePeers = account.GetNetworkResourcesRoutesToSync(context.Background(), accNetResourcePeer2ID, account.GetResourcePoliciesMap(), account.GetResourceRoutersMap())
assert.False(t, isRouter, "expected router status")
assert.Len(t, networkResourcesRoutes, 0, "expected network resource route don't match")
assert.Len(t, sourcePeers, 0, "expected source peers don't match")
// validate routes for router1
isRouter, networkResourcesRoutes, sourcePeers = account.GetNetworkResourcesRoutesToSync(context.Background(), accNetResourceRouter1ID, account.GetResourcePoliciesMap(), account.GetResourceRoutersMap())
assert.True(t, isRouter, "should be router")
assert.Len(t, networkResourcesRoutes, 1, "expected network resource route don't match")
assert.Len(t, sourcePeers, 1, "expected source peers don't match")
assert.NotNil(t, sourcePeers[accNetResourcePeer1ID], "expected source peers don't match")
// validate rules for router1
rules := account.GetPeerNetworkResourceFirewallRules(context.Background(), account.Peers[accNetResourceRouter1ID], accNetResourceValidPeers, networkResourcesRoutes, account.GetResourcePoliciesMap())
assert.Len(t, rules, 1, "expected rules count don't match")
assert.Equal(t, uint16(80), rules[0].Port, "should have port 80")
assert.Equal(t, "tcp", rules[0].Protocol, "should have protocol tcp")
if !slices.Contains(rules[0].SourceRanges, accNetResourcePeer1IP.String()+"/32") {
t.Errorf("%s should have source range of peer1 %s", rules[0].SourceRanges, accNetResourcePeer1IP.String())
}
if slices.Contains(rules[0].SourceRanges, accNetResourcePeer2IP.String()+"/32") {
t.Errorf("%s should not have source range of peer2 %s", rules[0].SourceRanges, accNetResourcePeer2IP.String())
}
}
func Test_NetworksNetMapGenWithNoMatchedPostureChecks(t *testing.T) {
account := getBasicAccountsWithResource()
// should not match any peer
policy := account.Policies[0]
policy.SourcePostureChecks = []string{accNetResourceLockedPostureCheckID}
// validate for peer1
isRouter, networkResourcesRoutes, sourcePeers := account.GetNetworkResourcesRoutesToSync(context.Background(), accNetResourcePeer1ID, account.GetResourcePoliciesMap(), account.GetResourceRoutersMap())
assert.False(t, isRouter, "expected router status")
assert.Len(t, networkResourcesRoutes, 0, "expected network resource route don't match")
assert.Len(t, sourcePeers, 0, "expected source peers don't match")
// validate for peer2
isRouter, networkResourcesRoutes, sourcePeers = account.GetNetworkResourcesRoutesToSync(context.Background(), accNetResourcePeer2ID, account.GetResourcePoliciesMap(), account.GetResourceRoutersMap())
assert.False(t, isRouter, "expected router status")
assert.Len(t, networkResourcesRoutes, 0, "expected network resource route don't match")
assert.Len(t, sourcePeers, 0, "expected source peers don't match")
// validate routes for router1
isRouter, networkResourcesRoutes, sourcePeers = account.GetNetworkResourcesRoutesToSync(context.Background(), accNetResourceRouter1ID, account.GetResourcePoliciesMap(), account.GetResourceRoutersMap())
assert.True(t, isRouter, "should be router")
assert.Len(t, networkResourcesRoutes, 1, "expected network resource route don't match")
assert.Len(t, sourcePeers, 0, "expected source peers don't match")
// validate rules for router1
rules := account.GetPeerNetworkResourceFirewallRules(context.Background(), account.Peers[accNetResourceRouter1ID], accNetResourceValidPeers, networkResourcesRoutes, account.GetResourcePoliciesMap())
assert.Len(t, rules, 0, "expected rules count don't match")
}
func Test_NetworksNetMapGenWithTwoPoliciesAndPostureChecks(t *testing.T) {
account := getBasicAccountsWithResource()
// should allow peer1 to match the policy
policy := account.Policies[0]
policy.SourcePostureChecks = []string{accNetResourceRestrictPostureCheckID}
// should allow peer1 and peer2 to match the policy
newPolicy := &Policy{
ID: "policy2ID",
AccountID: accID,
Enabled: true,
Rules: []*PolicyRule{
{
ID: "policy2ID",
Enabled: true,
Sources: []string{group1ID},
DestinationResource: Resource{
ID: accNetResource1ID,
Type: "Host",
},
Protocol: PolicyRuleProtocolTCP,
Ports: []string{"22"},
Action: PolicyTrafficActionAccept,
},
},
SourcePostureChecks: []string{accNetResourceRelaxedPostureCheckID},
}
account.Policies = append(account.Policies, newPolicy)
// validate for peer1
isRouter, networkResourcesRoutes, sourcePeers := account.GetNetworkResourcesRoutesToSync(context.Background(), accNetResourcePeer1ID, account.GetResourcePoliciesMap(), account.GetResourceRoutersMap())
assert.False(t, isRouter, "expected router status")
assert.Len(t, networkResourcesRoutes, 1, "expected network resource route don't match")
assert.Len(t, sourcePeers, 0, "expected source peers don't match")
// validate for peer2
isRouter, networkResourcesRoutes, sourcePeers = account.GetNetworkResourcesRoutesToSync(context.Background(), accNetResourcePeer2ID, account.GetResourcePoliciesMap(), account.GetResourceRoutersMap())
assert.False(t, isRouter, "expected router status")
assert.Len(t, networkResourcesRoutes, 1, "expected network resource route don't match")
assert.Len(t, sourcePeers, 0, "expected source peers don't match")
// validate routes for router1
isRouter, networkResourcesRoutes, sourcePeers = account.GetNetworkResourcesRoutesToSync(context.Background(), accNetResourceRouter1ID, account.GetResourcePoliciesMap(), account.GetResourceRoutersMap())
assert.True(t, isRouter, "should be router")
assert.Len(t, networkResourcesRoutes, 1, "expected network resource route don't match")
assert.Len(t, sourcePeers, 2, "expected source peers don't match")
assert.NotNil(t, sourcePeers[accNetResourcePeer1ID], "expected source peers don't match")
assert.NotNil(t, sourcePeers[accNetResourcePeer2ID], "expected source peers don't match")
// validate rules for router1
rules := account.GetPeerNetworkResourceFirewallRules(context.Background(), account.Peers[accNetResourceRouter1ID], accNetResourceValidPeers, networkResourcesRoutes, account.GetResourcePoliciesMap())
assert.Len(t, rules, 2, "expected rules count don't match")
assert.Equal(t, uint16(80), rules[0].Port, "should have port 80")
assert.Equal(t, "tcp", rules[0].Protocol, "should have protocol tcp")
if !slices.Contains(rules[0].SourceRanges, accNetResourcePeer1IP.String()+"/32") {
t.Errorf("%s should have source range of peer1 %s", rules[0].SourceRanges, accNetResourcePeer1IP.String())
}
if slices.Contains(rules[0].SourceRanges, accNetResourcePeer2IP.String()+"/32") {
t.Errorf("%s should not have source range of peer2 %s", rules[0].SourceRanges, accNetResourcePeer2IP.String())
}
assert.Equal(t, uint16(22), rules[1].Port, "should have port 22")
assert.Equal(t, "tcp", rules[1].Protocol, "should have protocol tcp")
if !slices.Contains(rules[1].SourceRanges, accNetResourcePeer1IP.String()+"/32") {
t.Errorf("%s should have source range of peer1 %s", rules[1].SourceRanges, accNetResourcePeer1IP.String())
}
if !slices.Contains(rules[1].SourceRanges, accNetResourcePeer2IP.String()+"/32") {
t.Errorf("%s should have source range of peer2 %s", rules[1].SourceRanges, accNetResourcePeer2IP.String())
}
}
func Test_NetworksNetMapGenWithTwoPostureChecks(t *testing.T) {
account := getBasicAccountsWithResource()
// two posture checks should match only the peers that match both checks
policy := account.Policies[0]
policy.SourcePostureChecks = []string{accNetResourceRelaxedPostureCheckID, accNetResourceLinuxPostureCheckID}
// validate for peer1
isRouter, networkResourcesRoutes, sourcePeers := account.GetNetworkResourcesRoutesToSync(context.Background(), accNetResourcePeer1ID, account.GetResourcePoliciesMap(), account.GetResourceRoutersMap())
assert.False(t, isRouter, "expected router status")
assert.Len(t, networkResourcesRoutes, 1, "expected network resource route don't match")
assert.Len(t, sourcePeers, 0, "expected source peers don't match")
// validate for peer2
isRouter, networkResourcesRoutes, sourcePeers = account.GetNetworkResourcesRoutesToSync(context.Background(), accNetResourcePeer2ID, account.GetResourcePoliciesMap(), account.GetResourceRoutersMap())
assert.False(t, isRouter, "expected router status")
assert.Len(t, networkResourcesRoutes, 0, "expected network resource route don't match")
assert.Len(t, sourcePeers, 0, "expected source peers don't match")
// validate routes for router1
isRouter, networkResourcesRoutes, sourcePeers = account.GetNetworkResourcesRoutesToSync(context.Background(), accNetResourceRouter1ID, account.GetResourcePoliciesMap(), account.GetResourceRoutersMap())
assert.True(t, isRouter, "should be router")
assert.Len(t, networkResourcesRoutes, 1, "expected network resource route don't match")
assert.Len(t, sourcePeers, 1, "expected source peers don't match")
assert.NotNil(t, sourcePeers[accNetResourcePeer1ID], "expected source peers don't match")
// validate rules for router1
rules := account.GetPeerNetworkResourceFirewallRules(context.Background(), account.Peers[accNetResourceRouter1ID], accNetResourceValidPeers, networkResourcesRoutes, account.GetResourcePoliciesMap())
assert.Len(t, rules, 1, "expected rules count don't match")
assert.Equal(t, uint16(80), rules[0].Port, "should have port 80")
assert.Equal(t, "tcp", rules[0].Protocol, "should have protocol tcp")
if !slices.Contains(rules[0].SourceRanges, accNetResourcePeer1IP.String()+"/32") {
t.Errorf("%s should have source range of peer1 %s", rules[0].SourceRanges, accNetResourcePeer1IP.String())
}
if slices.Contains(rules[0].SourceRanges, accNetResourcePeer2IP.String()+"/32") {
t.Errorf("%s should not have source range of peer2 %s", rules[0].SourceRanges, accNetResourcePeer2IP.String())
}
}

View File

@ -35,6 +35,15 @@ type FirewallRule struct {
Port string
}
// IsEqual checks if two firewall rules are equal.
func (r *FirewallRule) IsEqual(other *FirewallRule) bool {
return r.PeerIP == other.PeerIP &&
r.Direction == other.Direction &&
r.Action == other.Action &&
r.Protocol == other.Protocol &&
r.Port == other.Port
}
// generateRouteFirewallRules generates a list of firewall rules for a given route.
func generateRouteFirewallRules(ctx context.Context, route *nbroute.Route, rule *PolicyRule, groupPeers []*nbpeer.Peer, direction int) []*RouteFirewallRule {
rulesExists := make(map[string]struct{})

View File

@ -117,9 +117,20 @@ func (p *Policy) RuleGroups() []string {
// SourceGroups returns a slice of all unique source groups referenced in the policy's rules.
func (p *Policy) SourceGroups() []string {
groups := make([]string, 0)
for _, rule := range p.Rules {
groups = append(groups, rule.Sources...)
if len(p.Rules) == 1 {
return p.Rules[0].Sources
}
return groups
groups := make(map[string]struct{}, len(p.Rules))
for _, rule := range p.Rules {
for _, source := range rule.Sources {
groups[source] = struct{}{}
}
}
groupIDs := make([]string, 0, len(groups))
for groupID := range groups {
groupIDs = append(groupIDs, groupID)
}
return groupIDs
}

View File

@ -95,6 +95,7 @@ type Route struct {
NetID NetID
Description string
Peer string
PeerID string `gorm:"-"`
PeerGroups []string `gorm:"serializer:json"`
NetworkType NetworkType
Masquerade bool
@ -120,6 +121,7 @@ func (r *Route) Copy() *Route {
KeepRoute: r.KeepRoute,
NetworkType: r.NetworkType,
Peer: r.Peer,
PeerID: r.PeerID,
PeerGroups: slices.Clone(r.PeerGroups),
Metric: r.Metric,
Masquerade: r.Masquerade,
@ -146,6 +148,7 @@ func (r *Route) IsEqual(other *Route) bool {
other.KeepRoute == r.KeepRoute &&
other.NetworkType == r.NetworkType &&
other.Peer == r.Peer &&
other.PeerID == r.PeerID &&
other.Metric == r.Metric &&
other.Masquerade == r.Masquerade &&
other.Enabled == r.Enabled &&