mirror of
https://github.com/netbirdio/netbird.git
synced 2025-02-10 07:19:53 +01:00
Merge branch 'main' into feature/mysql-support
This commit is contained in:
commit
a3fe7bea38
2
.github/workflows/golangci-lint.yml
vendored
2
.github/workflows/golangci-lint.yml
vendored
@ -19,7 +19,7 @@ jobs:
|
||||
- name: codespell
|
||||
uses: codespell-project/actions-codespell@v2
|
||||
with:
|
||||
ignore_words_list: erro,clienta,hastable,iif,groupd
|
||||
ignore_words_list: erro,clienta,hastable,iif,groupd,testin
|
||||
skip: go.mod,go.sum
|
||||
only_warn: 1
|
||||
golangci:
|
||||
|
@ -197,7 +197,7 @@ func (m *Manager) AllowNetbird() error {
|
||||
}
|
||||
|
||||
_, err := m.AddPeerFiltering(
|
||||
net.ParseIP("0.0.0.0"),
|
||||
net.IP{0, 0, 0, 0},
|
||||
"all",
|
||||
nil,
|
||||
nil,
|
||||
|
@ -10,7 +10,6 @@ import (
|
||||
|
||||
// BaseConnTrack provides common fields and locking for all connection types
|
||||
type BaseConnTrack struct {
|
||||
sync.RWMutex
|
||||
SourceIP net.IP
|
||||
DestIP net.IP
|
||||
SourcePort uint16
|
||||
|
@ -62,6 +62,7 @@ type TCPConnKey struct {
|
||||
type TCPConnTrack struct {
|
||||
BaseConnTrack
|
||||
State TCPState
|
||||
sync.RWMutex
|
||||
}
|
||||
|
||||
// TCPTracker manages TCP connection states
|
||||
@ -131,36 +132,8 @@ func (t *TCPTracker) IsValidInbound(srcIP net.IP, dstIP net.IP, srcPort uint16,
|
||||
return false
|
||||
}
|
||||
|
||||
// Handle new SYN packets
|
||||
if flags&TCPSyn != 0 && flags&TCPAck == 0 {
|
||||
key := makeConnKey(dstIP, srcIP, dstPort, srcPort)
|
||||
t.mutex.Lock()
|
||||
if _, exists := t.connections[key]; !exists {
|
||||
// Use preallocated IPs
|
||||
srcIPCopy := t.ipPool.Get()
|
||||
dstIPCopy := t.ipPool.Get()
|
||||
copyIP(srcIPCopy, dstIP)
|
||||
copyIP(dstIPCopy, srcIP)
|
||||
|
||||
conn := &TCPConnTrack{
|
||||
BaseConnTrack: BaseConnTrack{
|
||||
SourceIP: srcIPCopy,
|
||||
DestIP: dstIPCopy,
|
||||
SourcePort: dstPort,
|
||||
DestPort: srcPort,
|
||||
},
|
||||
State: TCPStateSynReceived,
|
||||
}
|
||||
conn.lastSeen.Store(time.Now().UnixNano())
|
||||
conn.established.Store(false)
|
||||
t.connections[key] = conn
|
||||
}
|
||||
t.mutex.Unlock()
|
||||
return true
|
||||
}
|
||||
|
||||
// Look up existing connection
|
||||
key := makeConnKey(dstIP, srcIP, dstPort, srcPort)
|
||||
|
||||
t.mutex.RLock()
|
||||
conn, exists := t.connections[key]
|
||||
t.mutex.RUnlock()
|
||||
@ -172,8 +145,7 @@ func (t *TCPTracker) IsValidInbound(srcIP net.IP, dstIP net.IP, srcPort uint16,
|
||||
// Handle RST packets
|
||||
if flags&TCPRst != 0 {
|
||||
conn.Lock()
|
||||
isEstablished := conn.IsEstablished()
|
||||
if isEstablished || conn.State == TCPStateSynSent || conn.State == TCPStateSynReceived {
|
||||
if conn.IsEstablished() || conn.State == TCPStateSynSent || conn.State == TCPStateSynReceived {
|
||||
conn.State = TCPStateClosed
|
||||
conn.SetEstablished(false)
|
||||
conn.Unlock()
|
||||
@ -183,7 +155,6 @@ func (t *TCPTracker) IsValidInbound(srcIP net.IP, dstIP net.IP, srcPort uint16,
|
||||
return false
|
||||
}
|
||||
|
||||
// Update state
|
||||
conn.Lock()
|
||||
t.updateState(conn, flags, false)
|
||||
conn.UpdateLastSeen()
|
||||
@ -306,6 +277,11 @@ func (t *TCPTracker) isValidStateForFlags(state TCPState, flags uint8) bool {
|
||||
return flags&TCPFin != 0 || flags&TCPAck != 0
|
||||
case TCPStateLastAck:
|
||||
return flags&TCPAck != 0
|
||||
case TCPStateClosed:
|
||||
// Accept retransmitted ACKs in closed state
|
||||
// This is important because the final ACK might be lost
|
||||
// and the peer will retransmit their FIN-ACK
|
||||
return flags&TCPAck != 0
|
||||
}
|
||||
return false
|
||||
}
|
||||
|
@ -125,11 +125,8 @@ func TestTCPStateMachine(t *testing.T) {
|
||||
valid := tracker.IsValidInbound(dstIP, srcIP, dstPort, srcPort, TCPRst)
|
||||
require.True(t, valid, "RST should be allowed for established connection")
|
||||
|
||||
// Verify connection is closed
|
||||
valid = tracker.IsValidInbound(dstIP, srcIP, dstPort, srcPort, TCPPush|TCPAck)
|
||||
t.Helper()
|
||||
|
||||
require.False(t, valid, "Data should be blocked after RST")
|
||||
// Connection is logically dead but we don't enforce blocking subsequent packets
|
||||
// The connection will be cleaned up by timeout
|
||||
},
|
||||
},
|
||||
{
|
||||
|
@ -68,17 +68,16 @@ func (c *HandlerChain) AddHandler(pattern string, handler dns.Handler, priority
|
||||
c.mu.Lock()
|
||||
defer c.mu.Unlock()
|
||||
|
||||
pattern = strings.ToLower(dns.Fqdn(pattern))
|
||||
origPattern := pattern
|
||||
isWildcard := strings.HasPrefix(pattern, "*.")
|
||||
if isWildcard {
|
||||
pattern = pattern[2:]
|
||||
}
|
||||
pattern = dns.Fqdn(pattern)
|
||||
origPattern = dns.Fqdn(origPattern)
|
||||
|
||||
// First remove any existing handler with same original pattern and priority
|
||||
// First remove any existing handler with same pattern (case-insensitive) and priority
|
||||
for i := len(c.handlers) - 1; i >= 0; i-- {
|
||||
if c.handlers[i].OrigPattern == origPattern && c.handlers[i].Priority == priority {
|
||||
if strings.EqualFold(c.handlers[i].OrigPattern, origPattern) && c.handlers[i].Priority == priority {
|
||||
if c.handlers[i].StopHandler != nil {
|
||||
c.handlers[i].StopHandler.stop()
|
||||
}
|
||||
@ -126,10 +125,10 @@ func (c *HandlerChain) RemoveHandler(pattern string, priority int) {
|
||||
|
||||
pattern = dns.Fqdn(pattern)
|
||||
|
||||
// Find and remove handlers matching both original pattern and priority
|
||||
// Find and remove handlers matching both original pattern (case-insensitive) and priority
|
||||
for i := len(c.handlers) - 1; i >= 0; i-- {
|
||||
entry := c.handlers[i]
|
||||
if entry.OrigPattern == pattern && entry.Priority == priority {
|
||||
if strings.EqualFold(entry.OrigPattern, pattern) && entry.Priority == priority {
|
||||
if entry.StopHandler != nil {
|
||||
entry.StopHandler.stop()
|
||||
}
|
||||
@ -144,9 +143,9 @@ func (c *HandlerChain) HasHandlers(pattern string) bool {
|
||||
c.mu.RLock()
|
||||
defer c.mu.RUnlock()
|
||||
|
||||
pattern = dns.Fqdn(pattern)
|
||||
pattern = strings.ToLower(dns.Fqdn(pattern))
|
||||
for _, entry := range c.handlers {
|
||||
if entry.Pattern == pattern {
|
||||
if strings.EqualFold(entry.Pattern, pattern) {
|
||||
return true
|
||||
}
|
||||
}
|
||||
@ -158,7 +157,7 @@ func (c *HandlerChain) ServeDNS(w dns.ResponseWriter, r *dns.Msg) {
|
||||
return
|
||||
}
|
||||
|
||||
qname := r.Question[0].Name
|
||||
qname := strings.ToLower(r.Question[0].Name)
|
||||
log.Tracef("handling DNS request for domain=%s", qname)
|
||||
|
||||
c.mu.RLock()
|
||||
@ -187,9 +186,9 @@ func (c *HandlerChain) ServeDNS(w dns.ResponseWriter, r *dns.Msg) {
|
||||
// If handler wants subdomain matching, allow suffix match
|
||||
// Otherwise require exact match
|
||||
if entry.MatchSubdomains {
|
||||
matched = qname == entry.Pattern || strings.HasSuffix(qname, "."+entry.Pattern)
|
||||
matched = strings.EqualFold(qname, entry.Pattern) || strings.HasSuffix(qname, "."+entry.Pattern)
|
||||
} else {
|
||||
matched = qname == entry.Pattern
|
||||
matched = strings.EqualFold(qname, entry.Pattern)
|
||||
}
|
||||
}
|
||||
|
||||
|
@ -507,5 +507,173 @@ func TestHandlerChain_MultiPriorityHandling(t *testing.T) {
|
||||
|
||||
// Test 4: Remove last handler
|
||||
chain.RemoveHandler(testDomain, nbdns.PriorityDefault)
|
||||
|
||||
assert.False(t, chain.HasHandlers(testDomain))
|
||||
}
|
||||
|
||||
func TestHandlerChain_CaseSensitivity(t *testing.T) {
|
||||
tests := []struct {
|
||||
name string
|
||||
scenario string
|
||||
addHandlers []struct {
|
||||
pattern string
|
||||
priority int
|
||||
subdomains bool
|
||||
shouldMatch bool
|
||||
}
|
||||
query string
|
||||
expectedCalls int
|
||||
}{
|
||||
{
|
||||
name: "case insensitive exact match",
|
||||
scenario: "handler registered lowercase, query uppercase",
|
||||
addHandlers: []struct {
|
||||
pattern string
|
||||
priority int
|
||||
subdomains bool
|
||||
shouldMatch bool
|
||||
}{
|
||||
{"example.com.", nbdns.PriorityDefault, false, true},
|
||||
},
|
||||
query: "EXAMPLE.COM.",
|
||||
expectedCalls: 1,
|
||||
},
|
||||
{
|
||||
name: "case insensitive wildcard match",
|
||||
scenario: "handler registered mixed case wildcard, query different case",
|
||||
addHandlers: []struct {
|
||||
pattern string
|
||||
priority int
|
||||
subdomains bool
|
||||
shouldMatch bool
|
||||
}{
|
||||
{"*.Example.Com.", nbdns.PriorityDefault, false, true},
|
||||
},
|
||||
query: "sub.EXAMPLE.COM.",
|
||||
expectedCalls: 1,
|
||||
},
|
||||
{
|
||||
name: "multiple handlers different case same domain",
|
||||
scenario: "second handler should replace first despite case difference",
|
||||
addHandlers: []struct {
|
||||
pattern string
|
||||
priority int
|
||||
subdomains bool
|
||||
shouldMatch bool
|
||||
}{
|
||||
{"EXAMPLE.COM.", nbdns.PriorityDefault, false, false},
|
||||
{"example.com.", nbdns.PriorityDefault, false, true},
|
||||
},
|
||||
query: "ExAmPlE.cOm.",
|
||||
expectedCalls: 1,
|
||||
},
|
||||
{
|
||||
name: "subdomain matching case insensitive",
|
||||
scenario: "handler with MatchSubdomains true should match regardless of case",
|
||||
addHandlers: []struct {
|
||||
pattern string
|
||||
priority int
|
||||
subdomains bool
|
||||
shouldMatch bool
|
||||
}{
|
||||
{"example.com.", nbdns.PriorityDefault, true, true},
|
||||
},
|
||||
query: "SUB.EXAMPLE.COM.",
|
||||
expectedCalls: 1,
|
||||
},
|
||||
{
|
||||
name: "root zone case insensitive",
|
||||
scenario: "root zone handler should match regardless of case",
|
||||
addHandlers: []struct {
|
||||
pattern string
|
||||
priority int
|
||||
subdomains bool
|
||||
shouldMatch bool
|
||||
}{
|
||||
{".", nbdns.PriorityDefault, false, true},
|
||||
},
|
||||
query: "EXAMPLE.COM.",
|
||||
expectedCalls: 1,
|
||||
},
|
||||
{
|
||||
name: "multiple handlers different priority",
|
||||
scenario: "should call higher priority handler despite case differences",
|
||||
addHandlers: []struct {
|
||||
pattern string
|
||||
priority int
|
||||
subdomains bool
|
||||
shouldMatch bool
|
||||
}{
|
||||
{"EXAMPLE.COM.", nbdns.PriorityDefault, false, false},
|
||||
{"example.com.", nbdns.PriorityMatchDomain, false, false},
|
||||
{"Example.Com.", nbdns.PriorityDNSRoute, false, true},
|
||||
},
|
||||
query: "example.com.",
|
||||
expectedCalls: 1,
|
||||
},
|
||||
}
|
||||
|
||||
for _, tt := range tests {
|
||||
t.Run(tt.name, func(t *testing.T) {
|
||||
chain := nbdns.NewHandlerChain()
|
||||
handlerCalls := make(map[string]bool) // track which patterns were called
|
||||
|
||||
// Add handlers according to test case
|
||||
for _, h := range tt.addHandlers {
|
||||
var handler dns.Handler
|
||||
pattern := h.pattern // capture pattern for closure
|
||||
|
||||
if h.subdomains {
|
||||
subHandler := &nbdns.MockSubdomainHandler{
|
||||
Subdomains: true,
|
||||
}
|
||||
if h.shouldMatch {
|
||||
subHandler.On("ServeDNS", mock.Anything, mock.Anything).Run(func(args mock.Arguments) {
|
||||
handlerCalls[pattern] = true
|
||||
w := args.Get(0).(dns.ResponseWriter)
|
||||
r := args.Get(1).(*dns.Msg)
|
||||
resp := new(dns.Msg)
|
||||
resp.SetRcode(r, dns.RcodeSuccess)
|
||||
assert.NoError(t, w.WriteMsg(resp))
|
||||
}).Once()
|
||||
}
|
||||
handler = subHandler
|
||||
} else {
|
||||
mockHandler := &nbdns.MockHandler{}
|
||||
if h.shouldMatch {
|
||||
mockHandler.On("ServeDNS", mock.Anything, mock.Anything).Run(func(args mock.Arguments) {
|
||||
handlerCalls[pattern] = true
|
||||
w := args.Get(0).(dns.ResponseWriter)
|
||||
r := args.Get(1).(*dns.Msg)
|
||||
resp := new(dns.Msg)
|
||||
resp.SetRcode(r, dns.RcodeSuccess)
|
||||
assert.NoError(t, w.WriteMsg(resp))
|
||||
}).Once()
|
||||
}
|
||||
handler = mockHandler
|
||||
}
|
||||
|
||||
chain.AddHandler(pattern, handler, h.priority, nil)
|
||||
}
|
||||
|
||||
// Execute request
|
||||
r := new(dns.Msg)
|
||||
r.SetQuestion(tt.query, dns.TypeA)
|
||||
chain.ServeDNS(&mockResponseWriter{}, r)
|
||||
|
||||
// Verify each handler was called exactly as expected
|
||||
for _, h := range tt.addHandlers {
|
||||
wasCalled := handlerCalls[h.pattern]
|
||||
assert.Equal(t, h.shouldMatch, wasCalled,
|
||||
"Handler for pattern %q was %s when it should%s have been",
|
||||
h.pattern,
|
||||
map[bool]string{true: "called", false: "not called"}[wasCalled],
|
||||
map[bool]string{true: "", false: " not"}[wasCalled == h.shouldMatch])
|
||||
}
|
||||
|
||||
// Verify total number of calls
|
||||
assert.Equal(t, tt.expectedCalls, len(handlerCalls),
|
||||
"Wrong number of total handler calls")
|
||||
})
|
||||
}
|
||||
}
|
||||
|
@ -83,7 +83,7 @@ func (h *Manager) allowDNSFirewall() error {
|
||||
IsRange: false,
|
||||
Values: []int{ListenPort},
|
||||
}
|
||||
dnsRules, err := h.firewall.AddPeerFiltering(net.ParseIP("0.0.0.0"), firewall.ProtocolUDP, nil, dport, firewall.RuleDirectionIN, firewall.ActionAccept, "", "")
|
||||
dnsRules, err := h.firewall.AddPeerFiltering(net.IP{0, 0, 0, 0}, firewall.ProtocolUDP, nil, dport, firewall.RuleDirectionIN, firewall.ActionAccept, "", "")
|
||||
if err != nil {
|
||||
log.Errorf("failed to add allow DNS router rules, err: %v", err)
|
||||
return err
|
||||
|
@ -406,13 +406,9 @@ func (e *Engine) Start() error {
|
||||
e.firewall, err = firewall.NewFirewall(e.wgInterface, e.stateManager)
|
||||
if err != nil {
|
||||
log.Errorf("failed creating firewall manager: %s", err)
|
||||
}
|
||||
|
||||
if e.firewall != nil && e.firewall.IsServerRouteSupported() {
|
||||
err = e.routeManager.EnableServerRouter(e.firewall)
|
||||
if err != nil {
|
||||
e.close()
|
||||
return fmt.Errorf("enable server router: %w", err)
|
||||
} else if e.firewall != nil {
|
||||
if err := e.initFirewall(err); err != nil {
|
||||
return err
|
||||
}
|
||||
}
|
||||
|
||||
@ -455,6 +451,41 @@ func (e *Engine) Start() error {
|
||||
return nil
|
||||
}
|
||||
|
||||
func (e *Engine) initFirewall(error) error {
|
||||
if e.firewall.IsServerRouteSupported() {
|
||||
if err := e.routeManager.EnableServerRouter(e.firewall); err != nil {
|
||||
e.close()
|
||||
return fmt.Errorf("enable server router: %w", err)
|
||||
}
|
||||
}
|
||||
|
||||
if e.rpManager == nil || !e.config.RosenpassEnabled {
|
||||
return nil
|
||||
}
|
||||
|
||||
rosenpassPort := e.rpManager.GetAddress().Port
|
||||
port := manager.Port{Values: []int{rosenpassPort}}
|
||||
|
||||
// this rule is static and will be torn down on engine down by the firewall manager
|
||||
if _, err := e.firewall.AddPeerFiltering(
|
||||
net.IP{0, 0, 0, 0},
|
||||
manager.ProtocolUDP,
|
||||
nil,
|
||||
&port,
|
||||
manager.RuleDirectionIN,
|
||||
manager.ActionAccept,
|
||||
"",
|
||||
"",
|
||||
); err != nil {
|
||||
log.Errorf("failed to allow rosenpass interface traffic: %v", err)
|
||||
return nil
|
||||
}
|
||||
|
||||
log.Infof("rosenpass interface traffic allowed on port %d", rosenpassPort)
|
||||
|
||||
return nil
|
||||
}
|
||||
|
||||
// modifyPeers updates peers that have been modified (e.g. IP address has been changed).
|
||||
// It closes the existing connection, removes it from the peerConns map, and creates a new one.
|
||||
func (e *Engine) modifyPeers(peersUpdate []*mgmProto.RemotePeerConfig) error {
|
||||
|
@ -139,10 +139,6 @@ func (s *Server) DebugBundle(_ context.Context, req *proto.DebugBundleRequest) (
|
||||
s.mutex.Lock()
|
||||
defer s.mutex.Unlock()
|
||||
|
||||
if s.logFile == "console" {
|
||||
return nil, fmt.Errorf("log file is set to console, cannot create debug bundle")
|
||||
}
|
||||
|
||||
bundlePath, err := os.CreateTemp("", "netbird.debug.*.zip")
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("create zip file: %w", err)
|
||||
@ -185,17 +181,7 @@ func (s *Server) createArchive(bundlePath *os.File, req *proto.DebugBundleReques
|
||||
}
|
||||
|
||||
if req.GetSystemInfo() {
|
||||
if err := s.addRoutes(req, anonymizer, archive); err != nil {
|
||||
log.Errorf("Failed to add routes to debug bundle: %v", err)
|
||||
}
|
||||
|
||||
if err := s.addInterfaces(req, anonymizer, archive); err != nil {
|
||||
log.Errorf("Failed to add interfaces to debug bundle: %v", err)
|
||||
}
|
||||
|
||||
if err := s.addFirewallRules(req, anonymizer, archive); err != nil {
|
||||
log.Errorf("Failed to add firewall rules to debug bundle: %v", err)
|
||||
}
|
||||
s.addSystemInfo(req, anonymizer, archive)
|
||||
}
|
||||
|
||||
if err := s.addNetworkMap(req, anonymizer, archive); err != nil {
|
||||
@ -206,8 +192,10 @@ func (s *Server) createArchive(bundlePath *os.File, req *proto.DebugBundleReques
|
||||
log.Errorf("Failed to add state file to debug bundle: %v", err)
|
||||
}
|
||||
|
||||
if err := s.addLogfile(req, anonymizer, archive); err != nil {
|
||||
return fmt.Errorf("add log file: %w", err)
|
||||
if s.logFile != "console" {
|
||||
if err := s.addLogfile(req, anonymizer, archive); err != nil {
|
||||
return fmt.Errorf("add log file: %w", err)
|
||||
}
|
||||
}
|
||||
|
||||
if err := archive.Close(); err != nil {
|
||||
@ -216,6 +204,20 @@ func (s *Server) createArchive(bundlePath *os.File, req *proto.DebugBundleReques
|
||||
return nil
|
||||
}
|
||||
|
||||
func (s *Server) addSystemInfo(req *proto.DebugBundleRequest, anonymizer *anonymize.Anonymizer, archive *zip.Writer) {
|
||||
if err := s.addRoutes(req, anonymizer, archive); err != nil {
|
||||
log.Errorf("Failed to add routes to debug bundle: %v", err)
|
||||
}
|
||||
|
||||
if err := s.addInterfaces(req, anonymizer, archive); err != nil {
|
||||
log.Errorf("Failed to add interfaces to debug bundle: %v", err)
|
||||
}
|
||||
|
||||
if err := s.addFirewallRules(req, anonymizer, archive); err != nil {
|
||||
log.Errorf("Failed to add firewall rules to debug bundle: %v", err)
|
||||
}
|
||||
}
|
||||
|
||||
func (s *Server) addReadme(req *proto.DebugBundleRequest, archive *zip.Writer) error {
|
||||
if req.GetAnonymize() {
|
||||
readmeReader := strings.NewReader(readmeContent)
|
||||
|
@ -1251,6 +1251,12 @@ func (am *DefaultAccountManager) GetAccountIDFromToken(ctx context.Context, clai
|
||||
// syncJWTGroups processes the JWT groups for a user, updates the account based on the groups,
|
||||
// and propagates changes to peers if group propagation is enabled.
|
||||
func (am *DefaultAccountManager) syncJWTGroups(ctx context.Context, accountID string, claims jwtclaims.AuthorizationClaims) error {
|
||||
if claim, exists := claims.Raw[jwtclaims.IsToken]; exists {
|
||||
if isToken, ok := claim.(bool); ok && isToken {
|
||||
return nil
|
||||
}
|
||||
}
|
||||
|
||||
settings, err := am.Store.GetAccountSettings(ctx, store.LockingStrengthShare, accountID)
|
||||
if err != nil {
|
||||
return err
|
||||
|
@ -2730,6 +2730,19 @@ func TestAccount_SetJWTGroups(t *testing.T) {
|
||||
|
||||
assert.NoError(t, manager.Store.SaveAccount(context.Background(), account), "unable to save account")
|
||||
|
||||
t.Run("skip sync for token auth type", func(t *testing.T) {
|
||||
claims := jwtclaims.AuthorizationClaims{
|
||||
UserId: "user1",
|
||||
Raw: jwt.MapClaims{"groups": []interface{}{"group3"}, "is_token": true},
|
||||
}
|
||||
err = manager.syncJWTGroups(context.Background(), "accountID", claims)
|
||||
assert.NoError(t, err, "unable to sync jwt groups")
|
||||
|
||||
user, err := manager.Store.GetUserByUserID(context.Background(), store.LockingStrengthShare, "user1")
|
||||
assert.NoError(t, err, "unable to get user")
|
||||
assert.Len(t, user.AutoGroups, 0, "JWT groups should not be synced")
|
||||
})
|
||||
|
||||
t.Run("empty jwt groups", func(t *testing.T) {
|
||||
claims := jwtclaims.AuthorizationClaims{
|
||||
UserId: "user1",
|
||||
@ -2823,7 +2836,7 @@ func TestAccount_SetJWTGroups(t *testing.T) {
|
||||
assert.Len(t, user.AutoGroups, 1, "new group should be added")
|
||||
})
|
||||
|
||||
t.Run("remove all JWT groups", func(t *testing.T) {
|
||||
t.Run("remove all JWT groups when list is empty", func(t *testing.T) {
|
||||
claims := jwtclaims.AuthorizationClaims{
|
||||
UserId: "user1",
|
||||
Raw: jwt.MapClaims{"groups": []interface{}{}},
|
||||
@ -2834,7 +2847,20 @@ func TestAccount_SetJWTGroups(t *testing.T) {
|
||||
user, err := manager.Store.GetUserByUserID(context.Background(), store.LockingStrengthShare, "user1")
|
||||
assert.NoError(t, err, "unable to get user")
|
||||
assert.Len(t, user.AutoGroups, 1, "only non-JWT groups should remain")
|
||||
assert.Contains(t, user.AutoGroups, "group1", " group1 should still be present")
|
||||
assert.Contains(t, user.AutoGroups, "group1", "group1 should still be present")
|
||||
})
|
||||
|
||||
t.Run("remove all JWT groups when claim does not exist", func(t *testing.T) {
|
||||
claims := jwtclaims.AuthorizationClaims{
|
||||
UserId: "user2",
|
||||
Raw: jwt.MapClaims{},
|
||||
}
|
||||
err = manager.syncJWTGroups(context.Background(), "accountID", claims)
|
||||
assert.NoError(t, err, "unable to sync jwt groups")
|
||||
|
||||
user, err := manager.Store.GetUserByUserID(context.Background(), store.LockingStrengthShare, "user2")
|
||||
assert.NoError(t, err, "unable to get user")
|
||||
assert.Len(t, user.AutoGroups, 0, "all JWT groups should be removed")
|
||||
})
|
||||
}
|
||||
|
||||
@ -3038,9 +3064,9 @@ func BenchmarkSyncAndMarkPeer(b *testing.B) {
|
||||
minMsPerOpCICD float64
|
||||
maxMsPerOpCICD float64
|
||||
}{
|
||||
{"Small", 50, 5, 1, 3, 3, 10},
|
||||
{"Small", 50, 5, 1, 3, 3, 11},
|
||||
{"Medium", 500, 100, 7, 13, 10, 70},
|
||||
{"Large", 5000, 200, 65, 80, 60, 200},
|
||||
{"Large", 5000, 200, 65, 80, 60, 220},
|
||||
{"Small single", 50, 10, 1, 3, 3, 70},
|
||||
{"Medium single", 500, 10, 7, 13, 10, 26},
|
||||
{"Large 5", 5000, 15, 65, 80, 60, 200},
|
||||
@ -3180,7 +3206,7 @@ func BenchmarkLoginPeer_NewPeer(b *testing.B) {
|
||||
maxMsPerOpCICD float64
|
||||
}{
|
||||
{"Small", 50, 5, 107, 120, 107, 160},
|
||||
{"Medium", 500, 100, 105, 140, 105, 190},
|
||||
{"Medium", 500, 100, 105, 140, 105, 220},
|
||||
{"Large", 5000, 200, 180, 220, 180, 350},
|
||||
{"Small single", 50, 10, 107, 120, 105, 160},
|
||||
{"Medium single", 500, 10, 105, 140, 105, 170},
|
||||
|
@ -474,6 +474,10 @@ func validateDeleteGroup(ctx context.Context, transaction store.Store, group *ty
|
||||
return status.Errorf(status.InvalidArgument, "deleting group ALL is not allowed")
|
||||
}
|
||||
|
||||
if len(group.Resources) > 0 {
|
||||
return &GroupLinkError{"network resource", group.Resources[0].ID}
|
||||
}
|
||||
|
||||
if isLinked, linkedRoute := isGroupLinkedToRoute(ctx, transaction, group.AccountID, group.ID); isLinked {
|
||||
return &GroupLinkError{"route", string(linkedRoute.NetID)}
|
||||
}
|
||||
@ -529,7 +533,10 @@ func isGroupLinkedToRoute(ctx context.Context, transaction store.Store, accountI
|
||||
}
|
||||
|
||||
for _, r := range routes {
|
||||
if slices.Contains(r.Groups, groupID) || slices.Contains(r.PeerGroups, groupID) {
|
||||
isLinked := slices.Contains(r.Groups, groupID) ||
|
||||
slices.Contains(r.PeerGroups, groupID) ||
|
||||
slices.Contains(r.AccessControlGroups, groupID)
|
||||
if isLinked {
|
||||
return true, r
|
||||
}
|
||||
}
|
||||
|
@ -725,10 +725,6 @@ components:
|
||||
PolicyRuleMinimum:
|
||||
type: object
|
||||
properties:
|
||||
id:
|
||||
description: Policy rule ID
|
||||
type: string
|
||||
example: ch8i4ug6lnn4g9hqv7mg
|
||||
name:
|
||||
description: Policy rule name identifier
|
||||
type: string
|
||||
@ -790,6 +786,31 @@ components:
|
||||
- end
|
||||
|
||||
PolicyRuleUpdate:
|
||||
allOf:
|
||||
- $ref: '#/components/schemas/PolicyRuleMinimum'
|
||||
- type: object
|
||||
properties:
|
||||
id:
|
||||
description: Policy rule ID
|
||||
type: string
|
||||
example: ch8i4ug6lnn4g9hqv7mg
|
||||
sources:
|
||||
description: Policy rule source group IDs
|
||||
type: array
|
||||
items:
|
||||
type: string
|
||||
example: "ch8i4ug6lnn4g9hqv797"
|
||||
destinations:
|
||||
description: Policy rule destination group IDs
|
||||
type: array
|
||||
items:
|
||||
type: string
|
||||
example: "ch8i4ug6lnn4g9h7v7m0"
|
||||
required:
|
||||
- sources
|
||||
- destinations
|
||||
|
||||
PolicyRuleCreate:
|
||||
allOf:
|
||||
- $ref: '#/components/schemas/PolicyRuleMinimum'
|
||||
- type: object
|
||||
@ -817,6 +838,10 @@ components:
|
||||
- $ref: '#/components/schemas/PolicyRuleMinimum'
|
||||
- type: object
|
||||
properties:
|
||||
id:
|
||||
description: Policy rule ID
|
||||
type: string
|
||||
example: ch8i4ug6lnn4g9hqv7mg
|
||||
sources:
|
||||
description: Policy rule source group IDs
|
||||
type: array
|
||||
@ -836,10 +861,6 @@ components:
|
||||
PolicyMinimum:
|
||||
type: object
|
||||
properties:
|
||||
id:
|
||||
description: Policy ID
|
||||
type: string
|
||||
example: ch8i4ug6lnn4g9hqv7mg
|
||||
name:
|
||||
description: Policy name identifier
|
||||
type: string
|
||||
@ -854,7 +875,6 @@ components:
|
||||
example: true
|
||||
required:
|
||||
- name
|
||||
- description
|
||||
- enabled
|
||||
PolicyUpdate:
|
||||
allOf:
|
||||
@ -874,11 +894,33 @@ components:
|
||||
$ref: '#/components/schemas/PolicyRuleUpdate'
|
||||
required:
|
||||
- rules
|
||||
PolicyCreate:
|
||||
allOf:
|
||||
- $ref: '#/components/schemas/PolicyMinimum'
|
||||
- type: object
|
||||
properties:
|
||||
source_posture_checks:
|
||||
description: Posture checks ID's applied to policy source groups
|
||||
type: array
|
||||
items:
|
||||
type: string
|
||||
example: "chacdk86lnnboviihd70"
|
||||
rules:
|
||||
description: Policy rule object for policy UI editor
|
||||
type: array
|
||||
items:
|
||||
$ref: '#/components/schemas/PolicyRuleUpdate'
|
||||
required:
|
||||
- rules
|
||||
Policy:
|
||||
allOf:
|
||||
- $ref: '#/components/schemas/PolicyMinimum'
|
||||
- type: object
|
||||
properties:
|
||||
id:
|
||||
description: Policy ID
|
||||
type: string
|
||||
example: ch8i4ug6lnn4g9hqv7mg
|
||||
source_posture_checks:
|
||||
description: Posture checks ID's applied to policy source groups
|
||||
type: array
|
||||
@ -2463,7 +2505,7 @@ paths:
|
||||
content:
|
||||
'application/json':
|
||||
schema:
|
||||
$ref: '#/components/schemas/PolicyUpdate'
|
||||
$ref: '#/components/schemas/PolicyCreate'
|
||||
responses:
|
||||
'200':
|
||||
description: A Policy object
|
||||
|
@ -879,7 +879,7 @@ type PersonalAccessTokenRequest struct {
|
||||
// Policy defines model for Policy.
|
||||
type Policy struct {
|
||||
// Description Policy friendly description
|
||||
Description string `json:"description"`
|
||||
Description *string `json:"description,omitempty"`
|
||||
|
||||
// Enabled Policy status
|
||||
Enabled bool `json:"enabled"`
|
||||
@ -897,16 +897,31 @@ type Policy struct {
|
||||
SourcePostureChecks []string `json:"source_posture_checks"`
|
||||
}
|
||||
|
||||
// PolicyMinimum defines model for PolicyMinimum.
|
||||
type PolicyMinimum struct {
|
||||
// PolicyCreate defines model for PolicyCreate.
|
||||
type PolicyCreate struct {
|
||||
// Description Policy friendly description
|
||||
Description string `json:"description"`
|
||||
Description *string `json:"description,omitempty"`
|
||||
|
||||
// Enabled Policy status
|
||||
Enabled bool `json:"enabled"`
|
||||
|
||||
// Id Policy ID
|
||||
Id *string `json:"id,omitempty"`
|
||||
// Name Policy name identifier
|
||||
Name string `json:"name"`
|
||||
|
||||
// Rules Policy rule object for policy UI editor
|
||||
Rules []PolicyRuleUpdate `json:"rules"`
|
||||
|
||||
// SourcePostureChecks Posture checks ID's applied to policy source groups
|
||||
SourcePostureChecks *[]string `json:"source_posture_checks,omitempty"`
|
||||
}
|
||||
|
||||
// PolicyMinimum defines model for PolicyMinimum.
|
||||
type PolicyMinimum struct {
|
||||
// Description Policy friendly description
|
||||
Description *string `json:"description,omitempty"`
|
||||
|
||||
// Enabled Policy status
|
||||
Enabled bool `json:"enabled"`
|
||||
|
||||
// Name Policy name identifier
|
||||
Name string `json:"name"`
|
||||
@ -970,9 +985,6 @@ type PolicyRuleMinimum struct {
|
||||
// Enabled Policy rule status
|
||||
Enabled bool `json:"enabled"`
|
||||
|
||||
// Id Policy rule ID
|
||||
Id *string `json:"id,omitempty"`
|
||||
|
||||
// Name Policy rule name identifier
|
||||
Name string `json:"name"`
|
||||
|
||||
@ -1039,14 +1051,11 @@ type PolicyRuleUpdateProtocol string
|
||||
// PolicyUpdate defines model for PolicyUpdate.
|
||||
type PolicyUpdate struct {
|
||||
// Description Policy friendly description
|
||||
Description string `json:"description"`
|
||||
Description *string `json:"description,omitempty"`
|
||||
|
||||
// Enabled Policy status
|
||||
Enabled bool `json:"enabled"`
|
||||
|
||||
// Id Policy ID
|
||||
Id *string `json:"id,omitempty"`
|
||||
|
||||
// Name Policy name identifier
|
||||
Name string `json:"name"`
|
||||
|
||||
@ -1473,7 +1482,7 @@ type PutApiPeersPeerIdJSONRequestBody = PeerRequest
|
||||
type PostApiPoliciesJSONRequestBody = PolicyUpdate
|
||||
|
||||
// PutApiPoliciesPolicyIdJSONRequestBody defines body for PutApiPoliciesPolicyId for application/json ContentType.
|
||||
type PutApiPoliciesPolicyIdJSONRequestBody = PolicyUpdate
|
||||
type PutApiPoliciesPolicyIdJSONRequestBody = PolicyCreate
|
||||
|
||||
// PostApiPostureChecksJSONRequestBody defines body for PostApiPostureChecks for application/json ContentType.
|
||||
type PostApiPostureChecksJSONRequestBody = PostureCheckUpdate
|
||||
|
@ -133,16 +133,21 @@ func (h *handler) savePolicy(w http.ResponseWriter, r *http.Request, accountID s
|
||||
return
|
||||
}
|
||||
|
||||
description := ""
|
||||
if req.Description != nil {
|
||||
description = *req.Description
|
||||
}
|
||||
|
||||
policy := &types.Policy{
|
||||
ID: policyID,
|
||||
AccountID: accountID,
|
||||
Name: req.Name,
|
||||
Enabled: req.Enabled,
|
||||
Description: req.Description,
|
||||
Description: description,
|
||||
}
|
||||
for _, rule := range req.Rules {
|
||||
var ruleID string
|
||||
if rule.Id != nil {
|
||||
if rule.Id != nil && policyID != "" {
|
||||
ruleID = *rule.Id
|
||||
}
|
||||
|
||||
@ -370,7 +375,7 @@ func toPolicyResponse(groups []*types.Group, policy *types.Policy) *api.Policy {
|
||||
ap := &api.Policy{
|
||||
Id: &policy.ID,
|
||||
Name: policy.Name,
|
||||
Description: policy.Description,
|
||||
Description: &policy.Description,
|
||||
Enabled: policy.Enabled,
|
||||
SourcePostureChecks: policy.SourcePostureChecks,
|
||||
}
|
||||
|
@ -154,6 +154,7 @@ func TestPoliciesGetPolicy(t *testing.T) {
|
||||
|
||||
func TestPoliciesWritePolicy(t *testing.T) {
|
||||
str := func(s string) *string { return &s }
|
||||
emptyString := ""
|
||||
tt := []struct {
|
||||
name string
|
||||
expectedStatus int
|
||||
@ -184,8 +185,9 @@ func TestPoliciesWritePolicy(t *testing.T) {
|
||||
expectedStatus: http.StatusOK,
|
||||
expectedBody: true,
|
||||
expectedPolicy: &api.Policy{
|
||||
Id: str("id-was-set"),
|
||||
Name: "Default POSTed Policy",
|
||||
Id: str("id-was-set"),
|
||||
Name: "Default POSTed Policy",
|
||||
Description: &emptyString,
|
||||
Rules: []api.PolicyRule{
|
||||
{
|
||||
Id: str("id-was-set"),
|
||||
@ -232,8 +234,9 @@ func TestPoliciesWritePolicy(t *testing.T) {
|
||||
expectedStatus: http.StatusOK,
|
||||
expectedBody: true,
|
||||
expectedPolicy: &api.Policy{
|
||||
Id: str("id-existed"),
|
||||
Name: "Default POSTed Policy",
|
||||
Id: str("id-existed"),
|
||||
Name: "Default POSTed Policy",
|
||||
Description: &emptyString,
|
||||
Rules: []api.PolicyRule{
|
||||
{
|
||||
Id: str("id-existed"),
|
||||
|
@ -175,6 +175,7 @@ func (m *AuthMiddleware) checkPATFromRequest(w http.ResponseWriter, r *http.Requ
|
||||
claimMaps[m.audience+jwtclaims.AccountIDSuffix] = account.Id
|
||||
claimMaps[m.audience+jwtclaims.DomainIDSuffix] = account.Domain
|
||||
claimMaps[m.audience+jwtclaims.DomainCategorySuffix] = account.DomainCategory
|
||||
claimMaps[jwtclaims.IsToken] = true
|
||||
jwtToken := jwt.NewWithClaims(jwt.SigningMethodHS256, claimMaps)
|
||||
newRequest := r.WithContext(context.WithValue(r.Context(), jwtclaims.TokenUserProperty, jwtToken)) //nolint
|
||||
// Update the current request with the new context information.
|
||||
|
@ -22,6 +22,8 @@ const (
|
||||
LastLoginSuffix = "nb_last_login"
|
||||
// Invited claim indicates that an incoming JWT is from a user that just accepted an invitation
|
||||
Invited = "nb_invited"
|
||||
// IsToken claim indicates that auth type from the user is a token
|
||||
IsToken = "is_token"
|
||||
)
|
||||
|
||||
// ExtractClaims Extract function type
|
||||
|
@ -195,6 +195,10 @@ func (w *Worker) generateProperties(ctx context.Context) properties {
|
||||
groups int
|
||||
routes int
|
||||
routesWithRGGroups int
|
||||
networks int
|
||||
networkResources int
|
||||
networkRouters int
|
||||
networkRoutersWithPG int
|
||||
nameservers int
|
||||
uiClient int
|
||||
version string
|
||||
@ -219,6 +223,16 @@ func (w *Worker) generateProperties(ctx context.Context) properties {
|
||||
}
|
||||
|
||||
groups += len(account.Groups)
|
||||
networks += len(account.Networks)
|
||||
networkResources += len(account.NetworkResources)
|
||||
|
||||
networkRouters += len(account.NetworkRouters)
|
||||
for _, router := range account.NetworkRouters {
|
||||
if len(router.PeerGroups) > 0 {
|
||||
networkRoutersWithPG++
|
||||
}
|
||||
}
|
||||
|
||||
routes += len(account.Routes)
|
||||
for _, route := range account.Routes {
|
||||
if len(route.PeerGroups) > 0 {
|
||||
@ -312,6 +326,10 @@ func (w *Worker) generateProperties(ctx context.Context) properties {
|
||||
metricsProperties["rules_with_src_posture_checks"] = rulesWithSrcPostureChecks
|
||||
metricsProperties["posture_checks"] = postureChecks
|
||||
metricsProperties["groups"] = groups
|
||||
metricsProperties["networks"] = networks
|
||||
metricsProperties["network_resources"] = networkResources
|
||||
metricsProperties["network_routers"] = networkRouters
|
||||
metricsProperties["network_routers_with_groups"] = networkRoutersWithPG
|
||||
metricsProperties["routes"] = routes
|
||||
metricsProperties["routes_with_routing_groups"] = routesWithRGGroups
|
||||
metricsProperties["nameservers"] = nameservers
|
||||
|
@ -5,6 +5,9 @@ import (
|
||||
"testing"
|
||||
|
||||
nbdns "github.com/netbirdio/netbird/dns"
|
||||
resourceTypes "github.com/netbirdio/netbird/management/server/networks/resources/types"
|
||||
routerTypes "github.com/netbirdio/netbird/management/server/networks/routers/types"
|
||||
networkTypes "github.com/netbirdio/netbird/management/server/networks/types"
|
||||
nbpeer "github.com/netbirdio/netbird/management/server/peer"
|
||||
"github.com/netbirdio/netbird/management/server/posture"
|
||||
"github.com/netbirdio/netbird/management/server/store"
|
||||
@ -172,6 +175,31 @@ func (mockDatasource) GetAllAccounts(_ context.Context) []*types.Account {
|
||||
},
|
||||
},
|
||||
},
|
||||
Networks: []*networkTypes.Network{
|
||||
{
|
||||
ID: "1",
|
||||
AccountID: "1",
|
||||
},
|
||||
},
|
||||
NetworkResources: []*resourceTypes.NetworkResource{
|
||||
{
|
||||
ID: "1",
|
||||
AccountID: "1",
|
||||
NetworkID: "1",
|
||||
},
|
||||
{
|
||||
ID: "2",
|
||||
AccountID: "1",
|
||||
NetworkID: "1",
|
||||
},
|
||||
},
|
||||
NetworkRouters: []*routerTypes.NetworkRouter{
|
||||
{
|
||||
ID: "1",
|
||||
AccountID: "1",
|
||||
NetworkID: "1",
|
||||
},
|
||||
},
|
||||
},
|
||||
}
|
||||
}
|
||||
@ -200,6 +228,15 @@ func TestGenerateProperties(t *testing.T) {
|
||||
if properties["routes"] != 2 {
|
||||
t.Errorf("expected 2 routes, got %d", properties["routes"])
|
||||
}
|
||||
if properties["networks"] != 1 {
|
||||
t.Errorf("expected 1 networks, got %d", properties["networks"])
|
||||
}
|
||||
if properties["network_resources"] != 2 {
|
||||
t.Errorf("expected 2 network_resources, got %d", properties["network_resources"])
|
||||
}
|
||||
if properties["network_routers"] != 1 {
|
||||
t.Errorf("expected 1 network_routers, got %d", properties["network_routers"])
|
||||
}
|
||||
if properties["rules"] != 4 {
|
||||
t.Errorf("expected 4 rules, got %d", properties["rules"])
|
||||
}
|
||||
|
@ -111,6 +111,7 @@ func (n *NetworkResource) ToRoute(peer *nbpeer.Peer, router *routerTypes.Network
|
||||
NetID: route.NetID(n.Name),
|
||||
Description: n.Description,
|
||||
Peer: peer.Key,
|
||||
PeerID: peer.ID,
|
||||
PeerGroups: nil,
|
||||
Masquerade: router.Masquerade,
|
||||
Metric: router.Metric,
|
||||
|
@ -932,11 +932,11 @@ func BenchmarkUpdateAccountPeers(b *testing.B) {
|
||||
}{
|
||||
{"Small", 50, 5, 90, 120, 90, 120},
|
||||
{"Medium", 500, 100, 110, 150, 120, 260},
|
||||
{"Large", 5000, 200, 800, 1390, 2500, 4600},
|
||||
{"Large", 5000, 200, 800, 1700, 2500, 5000},
|
||||
{"Small single", 50, 10, 90, 120, 90, 120},
|
||||
{"Medium single", 500, 10, 110, 170, 120, 200},
|
||||
{"Large 5", 5000, 15, 1300, 2100, 5000, 7000},
|
||||
{"Extra Large", 2000, 2000, 1300, 2100, 4000, 6000},
|
||||
{"Large 5", 5000, 15, 1300, 2100, 4900, 7000},
|
||||
{"Extra Large", 2000, 2000, 1300, 2400, 4000, 6400},
|
||||
}
|
||||
|
||||
log.SetOutput(io.Discard)
|
||||
|
@ -74,6 +74,19 @@ func TestAccount_getPeersByPolicy(t *testing.T) {
|
||||
"peerH",
|
||||
},
|
||||
},
|
||||
"GroupWorkstations": {
|
||||
ID: "GroupWorkstations",
|
||||
Name: "GroupWorkstations",
|
||||
Peers: []string{
|
||||
"peerB",
|
||||
"peerA",
|
||||
"peerD",
|
||||
"peerE",
|
||||
"peerF",
|
||||
"peerG",
|
||||
"peerH",
|
||||
},
|
||||
},
|
||||
"GroupSwarm": {
|
||||
ID: "GroupSwarm",
|
||||
Name: "swarm",
|
||||
@ -127,7 +140,7 @@ func TestAccount_getPeersByPolicy(t *testing.T) {
|
||||
Action: types.PolicyTrafficActionAccept,
|
||||
Sources: []string{
|
||||
"GroupSwarm",
|
||||
"GroupAll",
|
||||
"GroupWorkstations",
|
||||
},
|
||||
Destinations: []string{
|
||||
"GroupSwarm",
|
||||
@ -159,6 +172,8 @@ func TestAccount_getPeersByPolicy(t *testing.T) {
|
||||
assert.Contains(t, peers, account.Peers["peerD"])
|
||||
assert.Contains(t, peers, account.Peers["peerE"])
|
||||
assert.Contains(t, peers, account.Peers["peerF"])
|
||||
assert.Contains(t, peers, account.Peers["peerG"])
|
||||
assert.Contains(t, peers, account.Peers["peerH"])
|
||||
|
||||
epectedFirewallRules := []*types.FirewallRule{
|
||||
{
|
||||
@ -189,21 +204,6 @@ func TestAccount_getPeersByPolicy(t *testing.T) {
|
||||
Protocol: "all",
|
||||
Port: "",
|
||||
},
|
||||
{
|
||||
PeerIP: "100.65.254.139",
|
||||
Direction: types.FirewallRuleDirectionOUT,
|
||||
Action: "accept",
|
||||
Protocol: "all",
|
||||
Port: "",
|
||||
},
|
||||
{
|
||||
PeerIP: "100.65.254.139",
|
||||
Direction: types.FirewallRuleDirectionIN,
|
||||
Action: "accept",
|
||||
Protocol: "all",
|
||||
Port: "",
|
||||
},
|
||||
|
||||
{
|
||||
PeerIP: "100.65.62.5",
|
||||
Direction: types.FirewallRuleDirectionOUT,
|
||||
@ -280,10 +280,16 @@ func TestAccount_getPeersByPolicy(t *testing.T) {
|
||||
},
|
||||
}
|
||||
assert.Len(t, firewallRules, len(epectedFirewallRules))
|
||||
slices.SortFunc(epectedFirewallRules, sortFunc())
|
||||
slices.SortFunc(firewallRules, sortFunc())
|
||||
for i := range firewallRules {
|
||||
assert.Equal(t, epectedFirewallRules[i], firewallRules[i])
|
||||
|
||||
for _, rule := range firewallRules {
|
||||
contains := false
|
||||
for _, expectedRule := range epectedFirewallRules {
|
||||
if rule.IsEqual(expectedRule) {
|
||||
contains = true
|
||||
break
|
||||
}
|
||||
}
|
||||
assert.True(t, contains, "rule not found in expected rules %#v", rule)
|
||||
}
|
||||
})
|
||||
}
|
||||
|
@ -364,7 +364,7 @@ func toProtocolRoute(route *route.Route) *proto.Route {
|
||||
}
|
||||
|
||||
func toProtocolRoutes(routes []*route.Route) []*proto.Route {
|
||||
protoRoutes := make([]*proto.Route, 0)
|
||||
protoRoutes := make([]*proto.Route, 0, len(routes))
|
||||
for _, r := range routes {
|
||||
protoRoutes = append(protoRoutes, toProtocolRoute(r))
|
||||
}
|
||||
|
@ -303,55 +303,47 @@ func (a *Account) GetPeerNetworkMap(
|
||||
return nm
|
||||
}
|
||||
|
||||
func (a *Account) addNetworksRoutingPeers(networkResourcesRoutes []*route.Route, peer *nbpeer.Peer, peersToConnect []*nbpeer.Peer, expiredPeers []*nbpeer.Peer, isRouter bool, sourcePeers []string) []*nbpeer.Peer {
|
||||
missingPeers := map[string]struct{}{}
|
||||
for _, r := range networkResourcesRoutes {
|
||||
if r.Peer == peer.Key {
|
||||
continue
|
||||
}
|
||||
func (a *Account) addNetworksRoutingPeers(
|
||||
networkResourcesRoutes []*route.Route,
|
||||
peer *nbpeer.Peer,
|
||||
peersToConnect []*nbpeer.Peer,
|
||||
expiredPeers []*nbpeer.Peer,
|
||||
isRouter bool,
|
||||
sourcePeers map[string]struct{},
|
||||
) []*nbpeer.Peer {
|
||||
|
||||
missing := true
|
||||
for _, p := range slices.Concat(peersToConnect, expiredPeers) {
|
||||
if r.Peer == p.Key {
|
||||
missing = false
|
||||
break
|
||||
}
|
||||
}
|
||||
if missing {
|
||||
missingPeers[r.Peer] = struct{}{}
|
||||
}
|
||||
networkRoutesPeers := make(map[string]struct{}, len(networkResourcesRoutes))
|
||||
for _, r := range networkResourcesRoutes {
|
||||
networkRoutesPeers[r.PeerID] = struct{}{}
|
||||
}
|
||||
|
||||
if isRouter {
|
||||
for _, s := range sourcePeers {
|
||||
if s == peer.ID {
|
||||
continue
|
||||
}
|
||||
delete(sourcePeers, peer.ID)
|
||||
|
||||
missing := true
|
||||
for _, p := range slices.Concat(peersToConnect, expiredPeers) {
|
||||
if s == p.ID {
|
||||
missing = false
|
||||
break
|
||||
}
|
||||
}
|
||||
if missing {
|
||||
p, ok := a.Peers[s]
|
||||
if ok {
|
||||
missingPeers[p.Key] = struct{}{}
|
||||
}
|
||||
}
|
||||
for _, existingPeer := range peersToConnect {
|
||||
delete(sourcePeers, existingPeer.ID)
|
||||
delete(networkRoutesPeers, existingPeer.ID)
|
||||
}
|
||||
for _, expPeer := range expiredPeers {
|
||||
delete(sourcePeers, expPeer.ID)
|
||||
delete(networkRoutesPeers, expPeer.ID)
|
||||
}
|
||||
|
||||
missingPeers := make(map[string]struct{}, len(sourcePeers)+len(networkRoutesPeers))
|
||||
if isRouter {
|
||||
for p := range sourcePeers {
|
||||
missingPeers[p] = struct{}{}
|
||||
}
|
||||
}
|
||||
for p := range networkRoutesPeers {
|
||||
missingPeers[p] = struct{}{}
|
||||
}
|
||||
|
||||
for p := range missingPeers {
|
||||
for _, p2 := range a.Peers {
|
||||
if p2.Key == p {
|
||||
peersToConnect = append(peersToConnect, p2)
|
||||
break
|
||||
}
|
||||
if missingPeer := a.Peers[p]; missingPeer != nil {
|
||||
peersToConnect = append(peersToConnect, missingPeer)
|
||||
}
|
||||
}
|
||||
|
||||
return peersToConnect
|
||||
}
|
||||
|
||||
@ -1045,37 +1037,32 @@ func (a *Account) connResourcesGenerator(ctx context.Context) (func(*PolicyRule,
|
||||
// for destination group peers, call this method with an empty list of sourcePostureChecksIDs
|
||||
func (a *Account) getAllPeersFromGroups(ctx context.Context, groups []string, peerID string, sourcePostureChecksIDs []string, validatedPeersMap map[string]struct{}) ([]*nbpeer.Peer, bool) {
|
||||
peerInGroups := false
|
||||
filteredPeers := make([]*nbpeer.Peer, 0, len(groups))
|
||||
for _, g := range groups {
|
||||
group, ok := a.Groups[g]
|
||||
if !ok {
|
||||
uniquePeerIDs := a.getUniquePeerIDsFromGroupsIDs(ctx, groups)
|
||||
filteredPeers := make([]*nbpeer.Peer, 0, len(uniquePeerIDs))
|
||||
for _, p := range uniquePeerIDs {
|
||||
peer, ok := a.Peers[p]
|
||||
if !ok || peer == nil {
|
||||
continue
|
||||
}
|
||||
|
||||
for _, p := range group.Peers {
|
||||
peer, ok := a.Peers[p]
|
||||
if !ok || peer == nil {
|
||||
continue
|
||||
}
|
||||
|
||||
// validate the peer based on policy posture checks applied
|
||||
isValid := a.validatePostureChecksOnPeer(ctx, sourcePostureChecksIDs, peer.ID)
|
||||
if !isValid {
|
||||
continue
|
||||
}
|
||||
|
||||
if _, ok := validatedPeersMap[peer.ID]; !ok {
|
||||
continue
|
||||
}
|
||||
|
||||
if peer.ID == peerID {
|
||||
peerInGroups = true
|
||||
continue
|
||||
}
|
||||
|
||||
filteredPeers = append(filteredPeers, peer)
|
||||
// validate the peer based on policy posture checks applied
|
||||
isValid := a.validatePostureChecksOnPeer(ctx, sourcePostureChecksIDs, peer.ID)
|
||||
if !isValid {
|
||||
continue
|
||||
}
|
||||
|
||||
if _, ok := validatedPeersMap[peer.ID]; !ok {
|
||||
continue
|
||||
}
|
||||
|
||||
if peer.ID == peerID {
|
||||
peerInGroups = true
|
||||
continue
|
||||
}
|
||||
|
||||
filteredPeers = append(filteredPeers, peer)
|
||||
}
|
||||
|
||||
return filteredPeers, peerInGroups
|
||||
}
|
||||
|
||||
@ -1151,7 +1138,7 @@ func (a *Account) getRouteFirewallRules(ctx context.Context, peerID string, poli
|
||||
continue
|
||||
}
|
||||
|
||||
rulePeers := a.getRulePeers(rule, peerID, distributionPeers, validatedPeersMap)
|
||||
rulePeers := a.getRulePeers(rule, policy.SourcePostureChecks, peerID, distributionPeers, validatedPeersMap)
|
||||
rules := generateRouteFirewallRules(ctx, route, rule, rulePeers, FirewallRuleDirectionIN)
|
||||
fwRules = append(fwRules, rules...)
|
||||
}
|
||||
@ -1159,7 +1146,7 @@ func (a *Account) getRouteFirewallRules(ctx context.Context, peerID string, poli
|
||||
return fwRules
|
||||
}
|
||||
|
||||
func (a *Account) getRulePeers(rule *PolicyRule, peerID string, distributionPeers map[string]struct{}, validatedPeersMap map[string]struct{}) []*nbpeer.Peer {
|
||||
func (a *Account) getRulePeers(rule *PolicyRule, postureChecks []string, peerID string, distributionPeers map[string]struct{}, validatedPeersMap map[string]struct{}) []*nbpeer.Peer {
|
||||
distPeersWithPolicy := make(map[string]struct{})
|
||||
for _, id := range rule.Sources {
|
||||
group := a.Groups[id]
|
||||
@ -1173,7 +1160,7 @@ func (a *Account) getRulePeers(rule *PolicyRule, peerID string, distributionPeer
|
||||
}
|
||||
_, distPeer := distributionPeers[pID]
|
||||
_, valid := validatedPeersMap[pID]
|
||||
if distPeer && valid {
|
||||
if distPeer && valid && a.validatePostureChecksOnPeer(context.Background(), postureChecks, pID) {
|
||||
distPeersWithPolicy[pID] = struct{}{}
|
||||
}
|
||||
}
|
||||
@ -1271,7 +1258,11 @@ func (a *Account) GetPeerNetworkResourceFirewallRules(ctx context.Context, peer
|
||||
distributionPeers := getPoliciesSourcePeers(resourceAppliedPolicies, a.Groups)
|
||||
|
||||
rules := a.getRouteFirewallRules(ctx, peer.ID, resourceAppliedPolicies, route, validatedPeersMap, distributionPeers)
|
||||
routesFirewallRules = append(routesFirewallRules, rules...)
|
||||
for _, rule := range rules {
|
||||
if len(rule.SourceRanges) > 0 {
|
||||
routesFirewallRules = append(routesFirewallRules, rule)
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
return routesFirewallRules
|
||||
@ -1303,10 +1294,10 @@ func (a *Account) GetResourcePoliciesMap() map[string][]*Policy {
|
||||
}
|
||||
|
||||
// GetNetworkResourcesRoutesToSync returns network routes for syncing with a specific peer and its ACL peers.
|
||||
func (a *Account) GetNetworkResourcesRoutesToSync(ctx context.Context, peerID string, resourcePolicies map[string][]*Policy, routers map[string]map[string]*routerTypes.NetworkRouter) (bool, []*route.Route, []string) {
|
||||
func (a *Account) GetNetworkResourcesRoutesToSync(ctx context.Context, peerID string, resourcePolicies map[string][]*Policy, routers map[string]map[string]*routerTypes.NetworkRouter) (bool, []*route.Route, map[string]struct{}) {
|
||||
var isRoutingPeer bool
|
||||
var routes []*route.Route
|
||||
var allSourcePeers []string
|
||||
allSourcePeers := make(map[string]struct{}, len(a.Peers))
|
||||
|
||||
for _, resource := range a.NetworkResources {
|
||||
var addSourcePeers bool
|
||||
@ -1319,23 +1310,22 @@ func (a *Account) GetNetworkResourcesRoutesToSync(ctx context.Context, peerID st
|
||||
}
|
||||
}
|
||||
|
||||
addedResourceRoute := false
|
||||
for _, policy := range resourcePolicies[resource.ID] {
|
||||
for _, sourceGroup := range policy.SourceGroups() {
|
||||
group := a.GetGroup(sourceGroup)
|
||||
if group == nil {
|
||||
log.WithContext(ctx).Warnf("policy %s has source group %s that doesn't exist under account %s, will continue map generation without it", policy.ID, sourceGroup, a.Id)
|
||||
continue
|
||||
peers := a.getUniquePeerIDsFromGroupsIDs(ctx, policy.SourceGroups())
|
||||
if addSourcePeers {
|
||||
for _, pID := range a.getPostureValidPeers(peers, policy.SourcePostureChecks) {
|
||||
allSourcePeers[pID] = struct{}{}
|
||||
}
|
||||
|
||||
// routing peer should be able to connect with all source peers
|
||||
if addSourcePeers {
|
||||
allSourcePeers = append(allSourcePeers, group.Peers...)
|
||||
} else if slices.Contains(group.Peers, peerID) {
|
||||
// add routes for the resource if the peer is in the distribution group
|
||||
for peerId, router := range networkRoutingPeers {
|
||||
routes = append(routes, a.getNetworkResourcesRoutes(resource, peerId, router, resourcePolicies)...)
|
||||
}
|
||||
} else if slices.Contains(peers, peerID) && a.validatePostureChecksOnPeer(ctx, policy.SourcePostureChecks, peerID) {
|
||||
// add routes for the resource if the peer is in the distribution group
|
||||
for peerId, router := range networkRoutingPeers {
|
||||
routes = append(routes, a.getNetworkResourcesRoutes(resource, peerId, router, resourcePolicies)...)
|
||||
}
|
||||
addedResourceRoute = true
|
||||
}
|
||||
if addedResourceRoute {
|
||||
break
|
||||
}
|
||||
}
|
||||
}
|
||||
@ -1343,6 +1333,42 @@ func (a *Account) GetNetworkResourcesRoutesToSync(ctx context.Context, peerID st
|
||||
return isRoutingPeer, routes, allSourcePeers
|
||||
}
|
||||
|
||||
func (a *Account) getPostureValidPeers(inputPeers []string, postureChecksIDs []string) []string {
|
||||
var dest []string
|
||||
for _, peerID := range inputPeers {
|
||||
if a.validatePostureChecksOnPeer(context.Background(), postureChecksIDs, peerID) {
|
||||
dest = append(dest, peerID)
|
||||
}
|
||||
}
|
||||
return dest
|
||||
}
|
||||
|
||||
func (a *Account) getUniquePeerIDsFromGroupsIDs(ctx context.Context, groups []string) []string {
|
||||
peerIDs := make(map[string]struct{}, len(groups)) // we expect at least one peer per group as initial capacity
|
||||
for _, groupID := range groups {
|
||||
group := a.GetGroup(groupID)
|
||||
if group == nil {
|
||||
log.WithContext(ctx).Warnf("group %s doesn't exist under account %s, will continue map generation without it", groupID, a.Id)
|
||||
continue
|
||||
}
|
||||
|
||||
if group.IsGroupAll() || len(groups) == 1 {
|
||||
return group.Peers
|
||||
}
|
||||
|
||||
for _, peerID := range group.Peers {
|
||||
peerIDs[peerID] = struct{}{}
|
||||
}
|
||||
}
|
||||
|
||||
ids := make([]string, 0, len(peerIDs))
|
||||
for peerID := range peerIDs {
|
||||
ids = append(ids, peerID)
|
||||
}
|
||||
|
||||
return ids
|
||||
}
|
||||
|
||||
// getNetworkResources filters and returns a list of network resources associated with the given network ID.
|
||||
func (a *Account) getNetworkResources(networkID string) []*resourceTypes.NetworkResource {
|
||||
var resources []*resourceTypes.NetworkResource
|
||||
|
@ -1,14 +1,20 @@
|
||||
package types
|
||||
|
||||
import (
|
||||
"context"
|
||||
"net"
|
||||
"net/netip"
|
||||
"slices"
|
||||
"testing"
|
||||
|
||||
"github.com/stretchr/testify/assert"
|
||||
"github.com/stretchr/testify/require"
|
||||
|
||||
resourceTypes "github.com/netbirdio/netbird/management/server/networks/resources/types"
|
||||
routerTypes "github.com/netbirdio/netbird/management/server/networks/routers/types"
|
||||
networkTypes "github.com/netbirdio/netbird/management/server/networks/types"
|
||||
nbpeer "github.com/netbirdio/netbird/management/server/peer"
|
||||
"github.com/netbirdio/netbird/management/server/posture"
|
||||
"github.com/netbirdio/netbird/route"
|
||||
)
|
||||
|
||||
@ -310,19 +316,19 @@ func Test_GetResourcePoliciesMap(t *testing.T) {
|
||||
|
||||
func Test_AddNetworksRoutingPeersAddsMissingPeers(t *testing.T) {
|
||||
account := setupTestAccount()
|
||||
peer := &nbpeer.Peer{Key: "peer1"}
|
||||
peer := &nbpeer.Peer{Key: "peer1Key", ID: "peer1"}
|
||||
networkResourcesRoutes := []*route.Route{
|
||||
{Peer: "peer2Key"},
|
||||
{Peer: "peer3Key"},
|
||||
{Peer: "peer2Key", PeerID: "peer2"},
|
||||
{Peer: "peer3Key", PeerID: "peer3"},
|
||||
}
|
||||
peersToConnect := []*nbpeer.Peer{
|
||||
{Key: "peer2Key"},
|
||||
{Key: "peer2Key", ID: "peer2"},
|
||||
}
|
||||
expiredPeers := []*nbpeer.Peer{
|
||||
{Key: "peer4Key"},
|
||||
{Key: "peer4Key", ID: "peer4"},
|
||||
}
|
||||
|
||||
result := account.addNetworksRoutingPeers(networkResourcesRoutes, peer, peersToConnect, expiredPeers, false, []string{})
|
||||
result := account.addNetworksRoutingPeers(networkResourcesRoutes, peer, peersToConnect, expiredPeers, false, map[string]struct{}{})
|
||||
require.Len(t, result, 2)
|
||||
require.Equal(t, "peer2Key", result[0].Key)
|
||||
require.Equal(t, "peer3Key", result[1].Key)
|
||||
@ -339,7 +345,7 @@ func Test_AddNetworksRoutingPeersIgnoresExistingPeers(t *testing.T) {
|
||||
}
|
||||
expiredPeers := []*nbpeer.Peer{}
|
||||
|
||||
result := account.addNetworksRoutingPeers(networkResourcesRoutes, peer, peersToConnect, expiredPeers, false, []string{})
|
||||
result := account.addNetworksRoutingPeers(networkResourcesRoutes, peer, peersToConnect, expiredPeers, false, map[string]struct{}{})
|
||||
require.Len(t, result, 1)
|
||||
require.Equal(t, "peer2Key", result[0].Key)
|
||||
}
|
||||
@ -358,7 +364,7 @@ func Test_AddNetworksRoutingPeersAddsExpiredPeers(t *testing.T) {
|
||||
{Key: "peer3Key"},
|
||||
}
|
||||
|
||||
result := account.addNetworksRoutingPeers(networkResourcesRoutes, peer, peersToConnect, expiredPeers, false, []string{})
|
||||
result := account.addNetworksRoutingPeers(networkResourcesRoutes, peer, peersToConnect, expiredPeers, false, map[string]struct{}{})
|
||||
require.Len(t, result, 1)
|
||||
require.Equal(t, "peer2Key", result[0].Key)
|
||||
}
|
||||
@ -370,6 +376,382 @@ func Test_AddNetworksRoutingPeersHandlesNoMissingPeers(t *testing.T) {
|
||||
peersToConnect := []*nbpeer.Peer{}
|
||||
expiredPeers := []*nbpeer.Peer{}
|
||||
|
||||
result := account.addNetworksRoutingPeers(networkResourcesRoutes, peer, peersToConnect, expiredPeers, false, []string{})
|
||||
result := account.addNetworksRoutingPeers(networkResourcesRoutes, peer, peersToConnect, expiredPeers, false, map[string]struct{}{})
|
||||
require.Len(t, result, 0)
|
||||
}
|
||||
|
||||
const (
|
||||
accID = "accountID"
|
||||
network1ID = "network1ID"
|
||||
group1ID = "group1"
|
||||
accNetResourcePeer1ID = "peer1"
|
||||
accNetResourcePeer2ID = "peer2"
|
||||
accNetResourceRouter1ID = "router1"
|
||||
accNetResource1ID = "resource1ID"
|
||||
accNetResourceRestrictPostureCheckID = "restrictPostureCheck"
|
||||
accNetResourceRelaxedPostureCheckID = "relaxedPostureCheck"
|
||||
accNetResourceLockedPostureCheckID = "lockedPostureCheck"
|
||||
accNetResourceLinuxPostureCheckID = "linuxPostureCheck"
|
||||
)
|
||||
|
||||
var (
|
||||
accNetResourcePeer1IP = net.IP{192, 168, 1, 1}
|
||||
accNetResourcePeer2IP = net.IP{192, 168, 1, 2}
|
||||
accNetResourceRouter1IP = net.IP{192, 168, 1, 3}
|
||||
accNetResourceValidPeers = map[string]struct{}{accNetResourcePeer1ID: {}, accNetResourcePeer2ID: {}}
|
||||
)
|
||||
|
||||
func getBasicAccountsWithResource() *Account {
|
||||
return &Account{
|
||||
Id: accID,
|
||||
Peers: map[string]*nbpeer.Peer{
|
||||
accNetResourcePeer1ID: {
|
||||
ID: accNetResourcePeer1ID,
|
||||
AccountID: accID,
|
||||
Key: "peer1Key",
|
||||
IP: accNetResourcePeer1IP,
|
||||
Meta: nbpeer.PeerSystemMeta{
|
||||
GoOS: "linux",
|
||||
WtVersion: "0.35.1",
|
||||
KernelVersion: "4.4.0",
|
||||
},
|
||||
},
|
||||
accNetResourcePeer2ID: {
|
||||
ID: accNetResourcePeer2ID,
|
||||
AccountID: accID,
|
||||
Key: "peer2Key",
|
||||
IP: accNetResourcePeer2IP,
|
||||
Meta: nbpeer.PeerSystemMeta{
|
||||
GoOS: "windows",
|
||||
WtVersion: "0.34.1",
|
||||
KernelVersion: "4.4.0",
|
||||
},
|
||||
},
|
||||
accNetResourceRouter1ID: {
|
||||
ID: accNetResourceRouter1ID,
|
||||
AccountID: accID,
|
||||
Key: "router1Key",
|
||||
IP: accNetResourceRouter1IP,
|
||||
Meta: nbpeer.PeerSystemMeta{
|
||||
GoOS: "linux",
|
||||
WtVersion: "0.35.1",
|
||||
KernelVersion: "4.4.0",
|
||||
},
|
||||
},
|
||||
},
|
||||
Groups: map[string]*Group{
|
||||
group1ID: {
|
||||
ID: group1ID,
|
||||
Peers: []string{accNetResourcePeer1ID, accNetResourcePeer2ID},
|
||||
},
|
||||
},
|
||||
Networks: []*networkTypes.Network{
|
||||
{
|
||||
ID: network1ID,
|
||||
AccountID: accID,
|
||||
Name: "network1",
|
||||
},
|
||||
},
|
||||
NetworkRouters: []*routerTypes.NetworkRouter{
|
||||
{
|
||||
ID: accNetResourceRouter1ID,
|
||||
NetworkID: network1ID,
|
||||
AccountID: accID,
|
||||
Peer: accNetResourceRouter1ID,
|
||||
PeerGroups: []string{},
|
||||
Masquerade: false,
|
||||
Metric: 100,
|
||||
},
|
||||
},
|
||||
NetworkResources: []*resourceTypes.NetworkResource{
|
||||
{
|
||||
ID: accNetResource1ID,
|
||||
AccountID: accID,
|
||||
NetworkID: network1ID,
|
||||
Address: "10.10.10.0/24",
|
||||
Prefix: netip.MustParsePrefix("10.10.10.0/24"),
|
||||
Type: resourceTypes.NetworkResourceType("subnet"),
|
||||
},
|
||||
},
|
||||
Policies: []*Policy{
|
||||
{
|
||||
ID: "policy1ID",
|
||||
AccountID: accID,
|
||||
Enabled: true,
|
||||
Rules: []*PolicyRule{
|
||||
{
|
||||
ID: "rule1ID",
|
||||
Enabled: true,
|
||||
Sources: []string{group1ID},
|
||||
DestinationResource: Resource{
|
||||
ID: accNetResource1ID,
|
||||
Type: "Host",
|
||||
},
|
||||
Protocol: PolicyRuleProtocolTCP,
|
||||
Ports: []string{"80"},
|
||||
Action: PolicyTrafficActionAccept,
|
||||
},
|
||||
},
|
||||
SourcePostureChecks: nil,
|
||||
},
|
||||
},
|
||||
PostureChecks: []*posture.Checks{
|
||||
{
|
||||
ID: accNetResourceRestrictPostureCheckID,
|
||||
Name: accNetResourceRestrictPostureCheckID,
|
||||
Checks: posture.ChecksDefinition{
|
||||
NBVersionCheck: &posture.NBVersionCheck{
|
||||
MinVersion: "0.35.0",
|
||||
},
|
||||
},
|
||||
},
|
||||
{
|
||||
ID: accNetResourceRelaxedPostureCheckID,
|
||||
Name: accNetResourceRelaxedPostureCheckID,
|
||||
Checks: posture.ChecksDefinition{
|
||||
NBVersionCheck: &posture.NBVersionCheck{
|
||||
MinVersion: "0.0.1",
|
||||
},
|
||||
},
|
||||
},
|
||||
{
|
||||
ID: accNetResourceLockedPostureCheckID,
|
||||
Name: accNetResourceLockedPostureCheckID,
|
||||
Checks: posture.ChecksDefinition{
|
||||
NBVersionCheck: &posture.NBVersionCheck{
|
||||
MinVersion: "7.7.7",
|
||||
},
|
||||
},
|
||||
},
|
||||
{
|
||||
ID: accNetResourceLinuxPostureCheckID,
|
||||
Name: accNetResourceLinuxPostureCheckID,
|
||||
Checks: posture.ChecksDefinition{
|
||||
OSVersionCheck: &posture.OSVersionCheck{
|
||||
Linux: &posture.MinKernelVersionCheck{
|
||||
MinKernelVersion: "0.0.0"},
|
||||
},
|
||||
},
|
||||
},
|
||||
},
|
||||
}
|
||||
}
|
||||
|
||||
func Test_NetworksNetMapGenWithNoPostureChecks(t *testing.T) {
|
||||
account := getBasicAccountsWithResource()
|
||||
|
||||
// all peers should match the policy
|
||||
|
||||
// validate for peer1
|
||||
isRouter, networkResourcesRoutes, sourcePeers := account.GetNetworkResourcesRoutesToSync(context.Background(), accNetResourcePeer1ID, account.GetResourcePoliciesMap(), account.GetResourceRoutersMap())
|
||||
assert.False(t, isRouter, "expected router status")
|
||||
assert.Len(t, networkResourcesRoutes, 1, "expected network resource route don't match")
|
||||
assert.Len(t, sourcePeers, 0, "expected source peers don't match")
|
||||
|
||||
// validate for peer2
|
||||
isRouter, networkResourcesRoutes, sourcePeers = account.GetNetworkResourcesRoutesToSync(context.Background(), accNetResourcePeer2ID, account.GetResourcePoliciesMap(), account.GetResourceRoutersMap())
|
||||
assert.False(t, isRouter, "expected router status")
|
||||
assert.Len(t, networkResourcesRoutes, 1, "expected network resource route don't match")
|
||||
assert.Len(t, sourcePeers, 0, "expected source peers don't match")
|
||||
|
||||
// validate routes for router1
|
||||
isRouter, networkResourcesRoutes, sourcePeers = account.GetNetworkResourcesRoutesToSync(context.Background(), accNetResourceRouter1ID, account.GetResourcePoliciesMap(), account.GetResourceRoutersMap())
|
||||
assert.True(t, isRouter, "should be router")
|
||||
assert.Len(t, networkResourcesRoutes, 1, "expected network resource route don't match")
|
||||
assert.Len(t, sourcePeers, 2, "expected source peers don't match")
|
||||
assert.NotNil(t, sourcePeers[accNetResourcePeer1ID], "expected source peers don't match")
|
||||
assert.NotNil(t, sourcePeers[accNetResourcePeer2ID], "expected source peers don't match")
|
||||
|
||||
// validate rules for router1
|
||||
rules := account.GetPeerNetworkResourceFirewallRules(context.Background(), account.Peers[accNetResourceRouter1ID], accNetResourceValidPeers, networkResourcesRoutes, account.GetResourcePoliciesMap())
|
||||
assert.Len(t, rules, 1, "expected rules count don't match")
|
||||
assert.Equal(t, uint16(80), rules[0].Port, "should have port 80")
|
||||
assert.Equal(t, "tcp", rules[0].Protocol, "should have protocol tcp")
|
||||
if !slices.Contains(rules[0].SourceRanges, accNetResourcePeer1IP.String()+"/32") {
|
||||
t.Errorf("%s should have source range of peer1 %s", rules[0].SourceRanges, accNetResourcePeer1IP.String())
|
||||
}
|
||||
if !slices.Contains(rules[0].SourceRanges, accNetResourcePeer2IP.String()+"/32") {
|
||||
t.Errorf("%s should have source range of peer2 %s", rules[0].SourceRanges, accNetResourcePeer2IP.String())
|
||||
}
|
||||
}
|
||||
|
||||
func Test_NetworksNetMapGenWithPostureChecks(t *testing.T) {
|
||||
account := getBasicAccountsWithResource()
|
||||
|
||||
// should allow peer1 to match the policy
|
||||
policy := account.Policies[0]
|
||||
policy.SourcePostureChecks = []string{accNetResourceRestrictPostureCheckID}
|
||||
|
||||
// validate for peer1
|
||||
isRouter, networkResourcesRoutes, sourcePeers := account.GetNetworkResourcesRoutesToSync(context.Background(), accNetResourcePeer1ID, account.GetResourcePoliciesMap(), account.GetResourceRoutersMap())
|
||||
assert.False(t, isRouter, "expected router status")
|
||||
assert.Len(t, networkResourcesRoutes, 1, "expected network resource route don't match")
|
||||
assert.Len(t, sourcePeers, 0, "expected source peers don't match")
|
||||
|
||||
// validate for peer2
|
||||
isRouter, networkResourcesRoutes, sourcePeers = account.GetNetworkResourcesRoutesToSync(context.Background(), accNetResourcePeer2ID, account.GetResourcePoliciesMap(), account.GetResourceRoutersMap())
|
||||
assert.False(t, isRouter, "expected router status")
|
||||
assert.Len(t, networkResourcesRoutes, 0, "expected network resource route don't match")
|
||||
assert.Len(t, sourcePeers, 0, "expected source peers don't match")
|
||||
|
||||
// validate routes for router1
|
||||
isRouter, networkResourcesRoutes, sourcePeers = account.GetNetworkResourcesRoutesToSync(context.Background(), accNetResourceRouter1ID, account.GetResourcePoliciesMap(), account.GetResourceRoutersMap())
|
||||
assert.True(t, isRouter, "should be router")
|
||||
assert.Len(t, networkResourcesRoutes, 1, "expected network resource route don't match")
|
||||
assert.Len(t, sourcePeers, 1, "expected source peers don't match")
|
||||
assert.NotNil(t, sourcePeers[accNetResourcePeer1ID], "expected source peers don't match")
|
||||
|
||||
// validate rules for router1
|
||||
rules := account.GetPeerNetworkResourceFirewallRules(context.Background(), account.Peers[accNetResourceRouter1ID], accNetResourceValidPeers, networkResourcesRoutes, account.GetResourcePoliciesMap())
|
||||
assert.Len(t, rules, 1, "expected rules count don't match")
|
||||
assert.Equal(t, uint16(80), rules[0].Port, "should have port 80")
|
||||
assert.Equal(t, "tcp", rules[0].Protocol, "should have protocol tcp")
|
||||
if !slices.Contains(rules[0].SourceRanges, accNetResourcePeer1IP.String()+"/32") {
|
||||
t.Errorf("%s should have source range of peer1 %s", rules[0].SourceRanges, accNetResourcePeer1IP.String())
|
||||
}
|
||||
if slices.Contains(rules[0].SourceRanges, accNetResourcePeer2IP.String()+"/32") {
|
||||
t.Errorf("%s should not have source range of peer2 %s", rules[0].SourceRanges, accNetResourcePeer2IP.String())
|
||||
}
|
||||
}
|
||||
|
||||
func Test_NetworksNetMapGenWithNoMatchedPostureChecks(t *testing.T) {
|
||||
account := getBasicAccountsWithResource()
|
||||
|
||||
// should not match any peer
|
||||
policy := account.Policies[0]
|
||||
policy.SourcePostureChecks = []string{accNetResourceLockedPostureCheckID}
|
||||
|
||||
// validate for peer1
|
||||
isRouter, networkResourcesRoutes, sourcePeers := account.GetNetworkResourcesRoutesToSync(context.Background(), accNetResourcePeer1ID, account.GetResourcePoliciesMap(), account.GetResourceRoutersMap())
|
||||
assert.False(t, isRouter, "expected router status")
|
||||
assert.Len(t, networkResourcesRoutes, 0, "expected network resource route don't match")
|
||||
assert.Len(t, sourcePeers, 0, "expected source peers don't match")
|
||||
|
||||
// validate for peer2
|
||||
isRouter, networkResourcesRoutes, sourcePeers = account.GetNetworkResourcesRoutesToSync(context.Background(), accNetResourcePeer2ID, account.GetResourcePoliciesMap(), account.GetResourceRoutersMap())
|
||||
assert.False(t, isRouter, "expected router status")
|
||||
assert.Len(t, networkResourcesRoutes, 0, "expected network resource route don't match")
|
||||
assert.Len(t, sourcePeers, 0, "expected source peers don't match")
|
||||
|
||||
// validate routes for router1
|
||||
isRouter, networkResourcesRoutes, sourcePeers = account.GetNetworkResourcesRoutesToSync(context.Background(), accNetResourceRouter1ID, account.GetResourcePoliciesMap(), account.GetResourceRoutersMap())
|
||||
assert.True(t, isRouter, "should be router")
|
||||
assert.Len(t, networkResourcesRoutes, 1, "expected network resource route don't match")
|
||||
assert.Len(t, sourcePeers, 0, "expected source peers don't match")
|
||||
|
||||
// validate rules for router1
|
||||
rules := account.GetPeerNetworkResourceFirewallRules(context.Background(), account.Peers[accNetResourceRouter1ID], accNetResourceValidPeers, networkResourcesRoutes, account.GetResourcePoliciesMap())
|
||||
assert.Len(t, rules, 0, "expected rules count don't match")
|
||||
}
|
||||
|
||||
func Test_NetworksNetMapGenWithTwoPoliciesAndPostureChecks(t *testing.T) {
|
||||
account := getBasicAccountsWithResource()
|
||||
|
||||
// should allow peer1 to match the policy
|
||||
policy := account.Policies[0]
|
||||
policy.SourcePostureChecks = []string{accNetResourceRestrictPostureCheckID}
|
||||
|
||||
// should allow peer1 and peer2 to match the policy
|
||||
newPolicy := &Policy{
|
||||
ID: "policy2ID",
|
||||
AccountID: accID,
|
||||
Enabled: true,
|
||||
Rules: []*PolicyRule{
|
||||
{
|
||||
ID: "policy2ID",
|
||||
Enabled: true,
|
||||
Sources: []string{group1ID},
|
||||
DestinationResource: Resource{
|
||||
ID: accNetResource1ID,
|
||||
Type: "Host",
|
||||
},
|
||||
Protocol: PolicyRuleProtocolTCP,
|
||||
Ports: []string{"22"},
|
||||
Action: PolicyTrafficActionAccept,
|
||||
},
|
||||
},
|
||||
SourcePostureChecks: []string{accNetResourceRelaxedPostureCheckID},
|
||||
}
|
||||
|
||||
account.Policies = append(account.Policies, newPolicy)
|
||||
|
||||
// validate for peer1
|
||||
isRouter, networkResourcesRoutes, sourcePeers := account.GetNetworkResourcesRoutesToSync(context.Background(), accNetResourcePeer1ID, account.GetResourcePoliciesMap(), account.GetResourceRoutersMap())
|
||||
assert.False(t, isRouter, "expected router status")
|
||||
assert.Len(t, networkResourcesRoutes, 1, "expected network resource route don't match")
|
||||
assert.Len(t, sourcePeers, 0, "expected source peers don't match")
|
||||
|
||||
// validate for peer2
|
||||
isRouter, networkResourcesRoutes, sourcePeers = account.GetNetworkResourcesRoutesToSync(context.Background(), accNetResourcePeer2ID, account.GetResourcePoliciesMap(), account.GetResourceRoutersMap())
|
||||
assert.False(t, isRouter, "expected router status")
|
||||
assert.Len(t, networkResourcesRoutes, 1, "expected network resource route don't match")
|
||||
assert.Len(t, sourcePeers, 0, "expected source peers don't match")
|
||||
|
||||
// validate routes for router1
|
||||
isRouter, networkResourcesRoutes, sourcePeers = account.GetNetworkResourcesRoutesToSync(context.Background(), accNetResourceRouter1ID, account.GetResourcePoliciesMap(), account.GetResourceRoutersMap())
|
||||
assert.True(t, isRouter, "should be router")
|
||||
assert.Len(t, networkResourcesRoutes, 1, "expected network resource route don't match")
|
||||
assert.Len(t, sourcePeers, 2, "expected source peers don't match")
|
||||
assert.NotNil(t, sourcePeers[accNetResourcePeer1ID], "expected source peers don't match")
|
||||
assert.NotNil(t, sourcePeers[accNetResourcePeer2ID], "expected source peers don't match")
|
||||
|
||||
// validate rules for router1
|
||||
rules := account.GetPeerNetworkResourceFirewallRules(context.Background(), account.Peers[accNetResourceRouter1ID], accNetResourceValidPeers, networkResourcesRoutes, account.GetResourcePoliciesMap())
|
||||
assert.Len(t, rules, 2, "expected rules count don't match")
|
||||
assert.Equal(t, uint16(80), rules[0].Port, "should have port 80")
|
||||
assert.Equal(t, "tcp", rules[0].Protocol, "should have protocol tcp")
|
||||
if !slices.Contains(rules[0].SourceRanges, accNetResourcePeer1IP.String()+"/32") {
|
||||
t.Errorf("%s should have source range of peer1 %s", rules[0].SourceRanges, accNetResourcePeer1IP.String())
|
||||
}
|
||||
if slices.Contains(rules[0].SourceRanges, accNetResourcePeer2IP.String()+"/32") {
|
||||
t.Errorf("%s should not have source range of peer2 %s", rules[0].SourceRanges, accNetResourcePeer2IP.String())
|
||||
}
|
||||
|
||||
assert.Equal(t, uint16(22), rules[1].Port, "should have port 22")
|
||||
assert.Equal(t, "tcp", rules[1].Protocol, "should have protocol tcp")
|
||||
if !slices.Contains(rules[1].SourceRanges, accNetResourcePeer1IP.String()+"/32") {
|
||||
t.Errorf("%s should have source range of peer1 %s", rules[1].SourceRanges, accNetResourcePeer1IP.String())
|
||||
}
|
||||
if !slices.Contains(rules[1].SourceRanges, accNetResourcePeer2IP.String()+"/32") {
|
||||
t.Errorf("%s should have source range of peer2 %s", rules[1].SourceRanges, accNetResourcePeer2IP.String())
|
||||
}
|
||||
}
|
||||
|
||||
func Test_NetworksNetMapGenWithTwoPostureChecks(t *testing.T) {
|
||||
account := getBasicAccountsWithResource()
|
||||
|
||||
// two posture checks should match only the peers that match both checks
|
||||
policy := account.Policies[0]
|
||||
policy.SourcePostureChecks = []string{accNetResourceRelaxedPostureCheckID, accNetResourceLinuxPostureCheckID}
|
||||
|
||||
// validate for peer1
|
||||
isRouter, networkResourcesRoutes, sourcePeers := account.GetNetworkResourcesRoutesToSync(context.Background(), accNetResourcePeer1ID, account.GetResourcePoliciesMap(), account.GetResourceRoutersMap())
|
||||
assert.False(t, isRouter, "expected router status")
|
||||
assert.Len(t, networkResourcesRoutes, 1, "expected network resource route don't match")
|
||||
assert.Len(t, sourcePeers, 0, "expected source peers don't match")
|
||||
|
||||
// validate for peer2
|
||||
isRouter, networkResourcesRoutes, sourcePeers = account.GetNetworkResourcesRoutesToSync(context.Background(), accNetResourcePeer2ID, account.GetResourcePoliciesMap(), account.GetResourceRoutersMap())
|
||||
assert.False(t, isRouter, "expected router status")
|
||||
assert.Len(t, networkResourcesRoutes, 0, "expected network resource route don't match")
|
||||
assert.Len(t, sourcePeers, 0, "expected source peers don't match")
|
||||
|
||||
// validate routes for router1
|
||||
isRouter, networkResourcesRoutes, sourcePeers = account.GetNetworkResourcesRoutesToSync(context.Background(), accNetResourceRouter1ID, account.GetResourcePoliciesMap(), account.GetResourceRoutersMap())
|
||||
assert.True(t, isRouter, "should be router")
|
||||
assert.Len(t, networkResourcesRoutes, 1, "expected network resource route don't match")
|
||||
assert.Len(t, sourcePeers, 1, "expected source peers don't match")
|
||||
assert.NotNil(t, sourcePeers[accNetResourcePeer1ID], "expected source peers don't match")
|
||||
|
||||
// validate rules for router1
|
||||
rules := account.GetPeerNetworkResourceFirewallRules(context.Background(), account.Peers[accNetResourceRouter1ID], accNetResourceValidPeers, networkResourcesRoutes, account.GetResourcePoliciesMap())
|
||||
assert.Len(t, rules, 1, "expected rules count don't match")
|
||||
assert.Equal(t, uint16(80), rules[0].Port, "should have port 80")
|
||||
assert.Equal(t, "tcp", rules[0].Protocol, "should have protocol tcp")
|
||||
if !slices.Contains(rules[0].SourceRanges, accNetResourcePeer1IP.String()+"/32") {
|
||||
t.Errorf("%s should have source range of peer1 %s", rules[0].SourceRanges, accNetResourcePeer1IP.String())
|
||||
}
|
||||
if slices.Contains(rules[0].SourceRanges, accNetResourcePeer2IP.String()+"/32") {
|
||||
t.Errorf("%s should not have source range of peer2 %s", rules[0].SourceRanges, accNetResourcePeer2IP.String())
|
||||
}
|
||||
}
|
||||
|
@ -35,6 +35,15 @@ type FirewallRule struct {
|
||||
Port string
|
||||
}
|
||||
|
||||
// IsEqual checks if two firewall rules are equal.
|
||||
func (r *FirewallRule) IsEqual(other *FirewallRule) bool {
|
||||
return r.PeerIP == other.PeerIP &&
|
||||
r.Direction == other.Direction &&
|
||||
r.Action == other.Action &&
|
||||
r.Protocol == other.Protocol &&
|
||||
r.Port == other.Port
|
||||
}
|
||||
|
||||
// generateRouteFirewallRules generates a list of firewall rules for a given route.
|
||||
func generateRouteFirewallRules(ctx context.Context, route *nbroute.Route, rule *PolicyRule, groupPeers []*nbpeer.Peer, direction int) []*RouteFirewallRule {
|
||||
rulesExists := make(map[string]struct{})
|
||||
|
@ -117,9 +117,20 @@ func (p *Policy) RuleGroups() []string {
|
||||
|
||||
// SourceGroups returns a slice of all unique source groups referenced in the policy's rules.
|
||||
func (p *Policy) SourceGroups() []string {
|
||||
groups := make([]string, 0)
|
||||
for _, rule := range p.Rules {
|
||||
groups = append(groups, rule.Sources...)
|
||||
if len(p.Rules) == 1 {
|
||||
return p.Rules[0].Sources
|
||||
}
|
||||
return groups
|
||||
groups := make(map[string]struct{}, len(p.Rules))
|
||||
for _, rule := range p.Rules {
|
||||
for _, source := range rule.Sources {
|
||||
groups[source] = struct{}{}
|
||||
}
|
||||
}
|
||||
|
||||
groupIDs := make([]string, 0, len(groups))
|
||||
for groupID := range groups {
|
||||
groupIDs = append(groupIDs, groupID)
|
||||
}
|
||||
|
||||
return groupIDs
|
||||
}
|
||||
|
@ -95,6 +95,7 @@ type Route struct {
|
||||
NetID NetID
|
||||
Description string
|
||||
Peer string
|
||||
PeerID string `gorm:"-"`
|
||||
PeerGroups []string `gorm:"serializer:json"`
|
||||
NetworkType NetworkType
|
||||
Masquerade bool
|
||||
@ -120,6 +121,7 @@ func (r *Route) Copy() *Route {
|
||||
KeepRoute: r.KeepRoute,
|
||||
NetworkType: r.NetworkType,
|
||||
Peer: r.Peer,
|
||||
PeerID: r.PeerID,
|
||||
PeerGroups: slices.Clone(r.PeerGroups),
|
||||
Metric: r.Metric,
|
||||
Masquerade: r.Masquerade,
|
||||
@ -146,6 +148,7 @@ func (r *Route) IsEqual(other *Route) bool {
|
||||
other.KeepRoute == r.KeepRoute &&
|
||||
other.NetworkType == r.NetworkType &&
|
||||
other.Peer == r.Peer &&
|
||||
other.PeerID == r.PeerID &&
|
||||
other.Metric == r.Metric &&
|
||||
other.Masquerade == r.Masquerade &&
|
||||
other.Enabled == r.Enabled &&
|
||||
|
Loading…
Reference in New Issue
Block a user