From 39329e12a1f54cc1770a010279fa4902ccd51693 Mon Sep 17 00:00:00 2001
From: Viktor Liu <17948409+lixmal@users.noreply.github.com>
Date: Wed, 13 Nov 2024 13:46:00 +0100
Subject: [PATCH 01/39] [client] Improve state write timeout and abort work
early on timeout (#2882)
* Improve state write timeout and abort work early on timeout
* Don't block on initial persist state
---
client/firewall/iptables/manager_linux.go | 8 +++++---
client/firewall/nftables/manager_linux.go | 8 +++++---
client/internal/config.go | 6 +++---
client/internal/dns/server.go | 10 +++++++---
client/internal/statemanager/manager.go | 20 +++++---------------
client/internal/statemanager/path.go | 22 +++++-----------------
management/server/file_store.go | 2 +-
util/file.go | 23 ++++++++++++++++++-----
util/file_test.go | 7 ++++---
9 files changed, 53 insertions(+), 53 deletions(-)
diff --git a/client/firewall/iptables/manager_linux.go b/client/firewall/iptables/manager_linux.go
index a59bd2c60..adb8f20ef 100644
--- a/client/firewall/iptables/manager_linux.go
+++ b/client/firewall/iptables/manager_linux.go
@@ -83,9 +83,11 @@ func (m *Manager) Init(stateManager *statemanager.Manager) error {
}
// persist early to ensure cleanup of chains
- if err := stateManager.PersistState(context.Background()); err != nil {
- log.Errorf("failed to persist state: %v", err)
- }
+ go func() {
+ if err := stateManager.PersistState(context.Background()); err != nil {
+ log.Errorf("failed to persist state: %v", err)
+ }
+ }()
return nil
}
diff --git a/client/firewall/nftables/manager_linux.go b/client/firewall/nftables/manager_linux.go
index ea8912f27..3f8fac249 100644
--- a/client/firewall/nftables/manager_linux.go
+++ b/client/firewall/nftables/manager_linux.go
@@ -99,9 +99,11 @@ func (m *Manager) Init(stateManager *statemanager.Manager) error {
}
// persist early
- if err := stateManager.PersistState(context.Background()); err != nil {
- log.Errorf("failed to persist state: %v", err)
- }
+ go func() {
+ if err := stateManager.PersistState(context.Background()); err != nil {
+ log.Errorf("failed to persist state: %v", err)
+ }
+ }()
return nil
}
diff --git a/client/internal/config.go b/client/internal/config.go
index ee54c6380..ce87835cd 100644
--- a/client/internal/config.go
+++ b/client/internal/config.go
@@ -164,7 +164,7 @@ func UpdateOrCreateConfig(input ConfigInput) (*Config, error) {
if err != nil {
return nil, err
}
- err = util.WriteJsonWithRestrictedPermission(input.ConfigPath, cfg)
+ err = util.WriteJsonWithRestrictedPermission(context.Background(), input.ConfigPath, cfg)
return cfg, err
}
@@ -185,7 +185,7 @@ func CreateInMemoryConfig(input ConfigInput) (*Config, error) {
// WriteOutConfig write put the prepared config to the given path
func WriteOutConfig(path string, config *Config) error {
- return util.WriteJson(path, config)
+ return util.WriteJson(context.Background(), path, config)
}
// createNewConfig creates a new config generating a new Wireguard key and saving to file
@@ -215,7 +215,7 @@ func update(input ConfigInput) (*Config, error) {
}
if updated {
- if err := util.WriteJson(input.ConfigPath, config); err != nil {
+ if err := util.WriteJson(context.Background(), input.ConfigPath, config); err != nil {
return nil, err
}
}
diff --git a/client/internal/dns/server.go b/client/internal/dns/server.go
index 929e1e60c..6c4dccae7 100644
--- a/client/internal/dns/server.go
+++ b/client/internal/dns/server.go
@@ -326,9 +326,13 @@ func (s *DefaultServer) applyConfiguration(update nbdns.Config) error {
// persist dns state right away
ctx, cancel := context.WithTimeout(s.ctx, 3*time.Second)
defer cancel()
- if err := s.stateManager.PersistState(ctx); err != nil {
- log.Errorf("Failed to persist dns state: %v", err)
- }
+
+ // don't block
+ go func() {
+ if err := s.stateManager.PersistState(ctx); err != nil {
+ log.Errorf("Failed to persist dns state: %v", err)
+ }
+ }()
if s.searchDomainNotifier != nil {
s.searchDomainNotifier.onNewSearchDomains(s.SearchDomains())
diff --git a/client/internal/statemanager/manager.go b/client/internal/statemanager/manager.go
index a5a14f807..580ccdfc7 100644
--- a/client/internal/statemanager/manager.go
+++ b/client/internal/statemanager/manager.go
@@ -16,6 +16,7 @@ import (
"golang.org/x/exp/maps"
nberrors "github.com/netbirdio/netbird/client/errors"
+ "github.com/netbirdio/netbird/util"
)
// State interface defines the methods that all state types must implement
@@ -178,25 +179,14 @@ func (m *Manager) PersistState(ctx context.Context) error {
return nil
}
- ctx, cancel := context.WithTimeout(ctx, 3*time.Second)
+ ctx, cancel := context.WithTimeout(ctx, 5*time.Second)
defer cancel()
done := make(chan error, 1)
+ start := time.Now()
go func() {
- data, err := json.MarshalIndent(m.states, "", " ")
- if err != nil {
- done <- fmt.Errorf("marshal states: %w", err)
- return
- }
-
- // nolint:gosec
- if err := os.WriteFile(m.filePath, data, 0640); err != nil {
- done <- fmt.Errorf("write state file: %w", err)
- return
- }
-
- done <- nil
+ done <- util.WriteJsonWithRestrictedPermission(ctx, m.filePath, m.states)
}()
select {
@@ -208,7 +198,7 @@ func (m *Manager) PersistState(ctx context.Context) error {
}
}
- log.Debugf("persisted shutdown states: %v", maps.Keys(m.dirty))
+ log.Debugf("persisted shutdown states: %v, took %v", maps.Keys(m.dirty), time.Since(start))
clear(m.dirty)
diff --git a/client/internal/statemanager/path.go b/client/internal/statemanager/path.go
index 96d6a9f12..6cfd79a12 100644
--- a/client/internal/statemanager/path.go
+++ b/client/internal/statemanager/path.go
@@ -4,32 +4,20 @@ import (
"os"
"path/filepath"
"runtime"
-
- log "github.com/sirupsen/logrus"
)
// GetDefaultStatePath returns the path to the state file based on the operating system
-// It returns an empty string if the path cannot be determined. It also creates the directory if it does not exist.
+// It returns an empty string if the path cannot be determined.
func GetDefaultStatePath() string {
- var path string
-
switch runtime.GOOS {
case "windows":
- path = filepath.Join(os.Getenv("PROGRAMDATA"), "Netbird", "state.json")
+ return filepath.Join(os.Getenv("PROGRAMDATA"), "Netbird", "state.json")
case "darwin", "linux":
- path = "/var/lib/netbird/state.json"
+ return "/var/lib/netbird/state.json"
case "freebsd", "openbsd", "netbsd", "dragonfly":
- path = "/var/db/netbird/state.json"
- // ios/android don't need state
- default:
- return ""
+ return "/var/db/netbird/state.json"
}
- dir := filepath.Dir(path)
- if err := os.MkdirAll(dir, 0755); err != nil {
- log.Errorf("Error creating directory %s: %v. Continuing without state support.", dir, err)
- return ""
- }
+ return ""
- return path
}
diff --git a/management/server/file_store.go b/management/server/file_store.go
index 561e133ce..f375fb990 100644
--- a/management/server/file_store.go
+++ b/management/server/file_store.go
@@ -223,7 +223,7 @@ func restore(ctx context.Context, file string) (*FileStore, error) {
// It is recommended to call it with locking FileStore.mux
func (s *FileStore) persist(ctx context.Context, file string) error {
start := time.Now()
- err := util.WriteJson(file, s)
+ err := util.WriteJson(context.Background(), file, s)
if err != nil {
return err
}
diff --git a/util/file.go b/util/file.go
index ecaecd222..4641cc1b8 100644
--- a/util/file.go
+++ b/util/file.go
@@ -15,7 +15,7 @@ import (
)
// WriteJsonWithRestrictedPermission writes JSON config object to a file. Enforces permission on the parent directory
-func WriteJsonWithRestrictedPermission(file string, obj interface{}) error {
+func WriteJsonWithRestrictedPermission(ctx context.Context, file string, obj interface{}) error {
configDir, configFileName, err := prepareConfigFileDir(file)
if err != nil {
return err
@@ -26,18 +26,18 @@ func WriteJsonWithRestrictedPermission(file string, obj interface{}) error {
return err
}
- return writeJson(file, obj, configDir, configFileName)
+ return writeJson(ctx, file, obj, configDir, configFileName)
}
// WriteJson writes JSON config object to a file creating parent directories if required
// The output JSON is pretty-formatted
-func WriteJson(file string, obj interface{}) error {
+func WriteJson(ctx context.Context, file string, obj interface{}) error {
configDir, configFileName, err := prepareConfigFileDir(file)
if err != nil {
return err
}
- return writeJson(file, obj, configDir, configFileName)
+ return writeJson(ctx, file, obj, configDir, configFileName)
}
// DirectWriteJson writes JSON config object to a file creating parent directories if required without creating a temporary file
@@ -79,7 +79,11 @@ func DirectWriteJson(ctx context.Context, file string, obj interface{}) error {
return nil
}
-func writeJson(file string, obj interface{}, configDir string, configFileName string) error {
+func writeJson(ctx context.Context, file string, obj interface{}, configDir string, configFileName string) error {
+ // Check context before expensive operations
+ if ctx.Err() != nil {
+ return ctx.Err()
+ }
// make it pretty
bs, err := json.MarshalIndent(obj, "", " ")
@@ -87,6 +91,10 @@ func writeJson(file string, obj interface{}, configDir string, configFileName st
return err
}
+ if ctx.Err() != nil {
+ return ctx.Err()
+ }
+
tempFile, err := os.CreateTemp(configDir, ".*"+configFileName)
if err != nil {
return err
@@ -111,6 +119,11 @@ func writeJson(file string, obj interface{}, configDir string, configFileName st
return err
}
+ // Check context again
+ if ctx.Err() != nil {
+ return ctx.Err()
+ }
+
err = os.Rename(tempFileName, file)
if err != nil {
return err
diff --git a/util/file_test.go b/util/file_test.go
index 566d8eda6..f8c9dfabb 100644
--- a/util/file_test.go
+++ b/util/file_test.go
@@ -1,6 +1,7 @@
package util
import (
+ "context"
"crypto/md5"
"encoding/hex"
"io"
@@ -39,7 +40,7 @@ func TestConfigJSON(t *testing.T) {
t.Run(tt.name, func(t *testing.T) {
tmpDir := t.TempDir()
- err := WriteJson(tmpDir+"/testconfig.json", tt.config)
+ err := WriteJson(context.Background(), tmpDir+"/testconfig.json", tt.config)
require.NoError(t, err)
read, err := ReadJson(tmpDir+"/testconfig.json", &TestConfig{})
@@ -73,7 +74,7 @@ func TestCopyFileContents(t *testing.T) {
src := tmpDir + "/copytest_src"
dst := tmpDir + "/copytest_dst"
- err := WriteJson(src, tt.srcContent)
+ err := WriteJson(context.Background(), src, tt.srcContent)
require.NoError(t, err)
err = CopyFileContents(src, dst)
@@ -127,7 +128,7 @@ func TestHandleConfigFileWithoutFullPath(t *testing.T) {
_ = os.Remove(cfgFile)
}()
- err := WriteJson(cfgFile, tt.config)
+ err := WriteJson(context.Background(), cfgFile, tt.config)
require.NoError(t, err)
read, err := ReadJson(cfgFile, &TestConfig{})
From ed047ec9dda048120edf4f074162a27136ac3cd6 Mon Sep 17 00:00:00 2001
From: bcmmbaga
Date: Wed, 13 Nov 2024 16:16:30 +0300
Subject: [PATCH 02/39] Add account locking and merge group deletion methods
Signed-off-by: bcmmbaga
---
management/server/group.go | 66 ++++++++++------------------------
management/server/sql_store.go | 2 +-
2 files changed, 20 insertions(+), 48 deletions(-)
diff --git a/management/server/group.go b/management/server/group.go
index 57960e7f9..154a33b13 100644
--- a/management/server/group.go
+++ b/management/server/group.go
@@ -215,48 +215,9 @@ func difference(a, b []string) []string {
// DeleteGroup object of the peers.
func (am *DefaultAccountManager) DeleteGroup(ctx context.Context, accountID, userID, groupID string) error {
- user, err := am.Store.GetUserByUserID(ctx, LockingStrengthShare, userID)
- if err != nil {
- return err
- }
-
- if user.AccountID != accountID {
- return status.NewUserNotPartOfAccountError()
- }
-
- if user.IsRegularUser() {
- return status.NewAdminPermissionError()
- }
-
- var group *nbgroup.Group
-
- err = am.Store.ExecuteInTransaction(ctx, func(transaction Store) error {
- group, err = transaction.GetGroupByID(ctx, LockingStrengthShare, accountID, groupID)
- if err != nil {
- return err
- }
-
- if group.IsGroupAll() {
- return status.Errorf(status.InvalidArgument, "deleting group ALL is not allowed")
- }
-
- if err = validateDeleteGroup(ctx, transaction, group, userID); err != nil {
- return err
- }
-
- if err = transaction.IncrementNetworkSerial(ctx, LockingStrengthUpdate, accountID); err != nil {
- return err
- }
-
- return transaction.DeleteGroup(ctx, LockingStrengthUpdate, accountID, groupID)
- })
- if err != nil {
- return err
- }
-
- am.StoreEvent(ctx, userID, groupID, accountID, activity.GroupDeleted, group.EventMeta())
-
- return nil
+ unlock := am.Store.AcquireWriteLockByUID(ctx, accountID)
+ defer unlock()
+ return am.DeleteGroups(ctx, accountID, userID, []string{groupID})
}
// DeleteGroups deletes groups from an account.
@@ -285,13 +246,14 @@ func (am *DefaultAccountManager) DeleteGroups(ctx context.Context, accountID, us
err = am.Store.ExecuteInTransaction(ctx, func(transaction Store) error {
for _, groupID := range groupIDs {
- group, err := transaction.GetGroupByID(ctx, LockingStrengthShare, accountID, groupID)
+ group, err := transaction.GetGroupByID(ctx, LockingStrengthUpdate, accountID, groupID)
if err != nil {
+ allErrors = errors.Join(allErrors, err)
continue
}
if err := validateDeleteGroup(ctx, transaction, group, userID); err != nil {
- allErrors = errors.Join(allErrors, fmt.Errorf("failed to delete group %s: %w", groupID, err))
+ allErrors = errors.Join(allErrors, err)
continue
}
@@ -318,12 +280,15 @@ func (am *DefaultAccountManager) DeleteGroups(ctx context.Context, accountID, us
// GroupAddPeer appends peer to the group
func (am *DefaultAccountManager) GroupAddPeer(ctx context.Context, accountID, groupID, peerID string) error {
+ unlock := am.Store.AcquireWriteLockByUID(ctx, accountID)
+ defer unlock()
+
var group *nbgroup.Group
var updateAccountPeers bool
var err error
err = am.Store.ExecuteInTransaction(ctx, func(transaction Store) error {
- group, err = transaction.GetGroupByID(context.Background(), LockingStrengthShare, accountID, groupID)
+ group, err = transaction.GetGroupByID(context.Background(), LockingStrengthUpdate, accountID, groupID)
if err != nil {
return err
}
@@ -356,12 +321,15 @@ func (am *DefaultAccountManager) GroupAddPeer(ctx context.Context, accountID, gr
// GroupDeletePeer removes peer from the group
func (am *DefaultAccountManager) GroupDeletePeer(ctx context.Context, accountID, groupID, peerID string) error {
+ unlock := am.Store.AcquireWriteLockByUID(ctx, accountID)
+ defer unlock()
+
var group *nbgroup.Group
var updateAccountPeers bool
var err error
err = am.Store.ExecuteInTransaction(ctx, func(transaction Store) error {
- group, err = transaction.GetGroupByID(context.Background(), LockingStrengthShare, accountID, groupID)
+ group, err = transaction.GetGroupByID(context.Background(), LockingStrengthUpdate, accountID, groupID)
if err != nil {
return err
}
@@ -430,13 +398,17 @@ func validateDeleteGroup(ctx context.Context, transaction Store, group *nbgroup.
if group.Issued == nbgroup.GroupIssuedIntegration {
executingUser, err := transaction.GetUserByUserID(ctx, LockingStrengthShare, userID)
if err != nil {
- return status.Errorf(status.NotFound, "user not found")
+ return err
}
if executingUser.Role != UserRoleAdmin || !executingUser.IsServiceUser {
return status.Errorf(status.PermissionDenied, "only service users with admin power can delete integration group")
}
}
+ if group.IsGroupAll() {
+ return status.Errorf(status.InvalidArgument, "deleting group ALL is not allowed")
+ }
+
if isLinked, linkedRoute := isGroupLinkedToRoute(ctx, transaction, group.AccountID, group.ID); isLinked {
return &GroupLinkError{"route", string(linkedRoute.NetID)}
}
diff --git a/management/server/sql_store.go b/management/server/sql_store.go
index 7c741d35c..0ebda6440 100644
--- a/management/server/sql_store.go
+++ b/management/server/sql_store.go
@@ -1278,7 +1278,7 @@ func (s *SqlStore) DeleteGroups(ctx context.Context, strength LockingStrength, a
Delete(&nbgroup.Group{}, accountAndIDsQueryCondition, accountID, groupIDs)
if result.Error != nil {
log.WithContext(ctx).Errorf("failed to delete groups from store: %v", result.Error)
- return status.Errorf(status.Internal, "failed to delete groups from store: %v", result.Error)
+ return status.Errorf(status.Internal, "failed to delete groups from store")
}
return nil
From a4d905ffe77881b682a4798d5564b89860404a0a Mon Sep 17 00:00:00 2001
From: bcmmbaga
Date: Wed, 13 Nov 2024 16:56:22 +0300
Subject: [PATCH 03/39] Fix tests
Signed-off-by: bcmmbaga
---
management/server/group_test.go | 2 +-
1 file changed, 1 insertion(+), 1 deletion(-)
diff --git a/management/server/group_test.go b/management/server/group_test.go
index 89184e819..59094a23e 100644
--- a/management/server/group_test.go
+++ b/management/server/group_test.go
@@ -208,7 +208,7 @@ func TestDefaultAccountManager_DeleteGroups(t *testing.T) {
{
name: "delete non-existent group",
groupIDs: []string{"non-existent-group"},
- expectedDeleted: []string{"non-existent-group"},
+ expectedReasons: []string{"group: non-existent-group not found"},
},
{
name: "delete multiple groups with mixed results",
From b48afd92fdc2080cd0b4bf7bd4c046684b76338a Mon Sep 17 00:00:00 2001
From: Zoltan Papp
Date: Wed, 13 Nov 2024 15:02:51 +0100
Subject: [PATCH 04/39] [relay-server] Always close ws conn when work thread
exit (#2879)
Close ws conn when work thread exit
---
relay/server/peer.go | 17 ++++++++++++-----
1 file changed, 12 insertions(+), 5 deletions(-)
diff --git a/relay/server/peer.go b/relay/server/peer.go
index c909c35d5..f65fb786a 100644
--- a/relay/server/peer.go
+++ b/relay/server/peer.go
@@ -16,6 +16,8 @@ import (
const (
bufferSize = 8820
+
+ errCloseConn = "failed to close connection to peer: %s"
)
// Peer represents a peer connection
@@ -46,6 +48,12 @@ func NewPeer(metrics *metrics.Metrics, id []byte, conn net.Conn, store *Store) *
// It manages the protocol (healthcheck, transport, close). Read the message and determine the message type and handle
// the message accordingly.
func (p *Peer) Work() {
+ defer func() {
+ if err := p.conn.Close(); err != nil && !errors.Is(err, net.ErrClosed) {
+ p.log.Errorf(errCloseConn, err)
+ }
+ }()
+
ctx, cancel := context.WithCancel(context.Background())
defer cancel()
@@ -97,7 +105,7 @@ func (p *Peer) handleMsgType(ctx context.Context, msgType messages.MsgType, hc *
case messages.MsgTypeClose:
p.log.Infof("peer exited gracefully")
if err := p.conn.Close(); err != nil {
- log.Errorf("failed to close connection to peer: %s", err)
+ log.Errorf(errCloseConn, err)
}
default:
p.log.Warnf("received unexpected message type: %s", msgType)
@@ -121,9 +129,8 @@ func (p *Peer) CloseGracefully(ctx context.Context) {
p.log.Errorf("failed to send close message to peer: %s", p.String())
}
- err = p.conn.Close()
- if err != nil {
- p.log.Errorf("failed to close connection to peer: %s", err)
+ if err := p.conn.Close(); err != nil {
+ p.log.Errorf(errCloseConn, err)
}
}
@@ -132,7 +139,7 @@ func (p *Peer) Close() {
defer p.connMu.Unlock()
if err := p.conn.Close(); err != nil {
- p.log.Errorf("failed to close connection to peer: %s", err)
+ p.log.Errorf(errCloseConn, err)
}
}
From 6886691213aa622b1f0f7440c83a3a4c4831cf3d Mon Sep 17 00:00:00 2001
From: Maycon Santos
Date: Wed, 13 Nov 2024 15:21:33 +0100
Subject: [PATCH 05/39] Update route calculation tests (#2884)
- Add two new test cases for p2p and relay routes with same latency
- Add extra statuses generation
---
client/internal/routemanager/client_test.go | 98 +++++++++++++++++++++
1 file changed, 98 insertions(+)
diff --git a/client/internal/routemanager/client_test.go b/client/internal/routemanager/client_test.go
index 583156e4d..56fcf1613 100644
--- a/client/internal/routemanager/client_test.go
+++ b/client/internal/routemanager/client_test.go
@@ -1,6 +1,7 @@
package routemanager
import (
+ "fmt"
"net/netip"
"testing"
"time"
@@ -227,6 +228,64 @@ func TestGetBestrouteFromStatuses(t *testing.T) {
currentRoute: "route1",
expectedRouteID: "route1",
},
+ {
+ name: "relayed routes with latency 0 should maintain previous choice",
+ statuses: map[route.ID]routerPeerStatus{
+ "route1": {
+ connected: true,
+ relayed: true,
+ latency: 0 * time.Millisecond,
+ },
+ "route2": {
+ connected: true,
+ relayed: true,
+ latency: 0 * time.Millisecond,
+ },
+ },
+ existingRoutes: map[route.ID]*route.Route{
+ "route1": {
+ ID: "route1",
+ Metric: route.MaxMetric,
+ Peer: "peer1",
+ },
+ "route2": {
+ ID: "route2",
+ Metric: route.MaxMetric,
+ Peer: "peer2",
+ },
+ },
+ currentRoute: "route1",
+ expectedRouteID: "route1",
+ },
+ {
+ name: "p2p routes with latency 0 should maintain previous choice",
+ statuses: map[route.ID]routerPeerStatus{
+ "route1": {
+ connected: true,
+ relayed: false,
+ latency: 0 * time.Millisecond,
+ },
+ "route2": {
+ connected: true,
+ relayed: false,
+ latency: 0 * time.Millisecond,
+ },
+ },
+ existingRoutes: map[route.ID]*route.Route{
+ "route1": {
+ ID: "route1",
+ Metric: route.MaxMetric,
+ Peer: "peer1",
+ },
+ "route2": {
+ ID: "route2",
+ Metric: route.MaxMetric,
+ Peer: "peer2",
+ },
+ },
+ currentRoute: "route1",
+ expectedRouteID: "route1",
+ },
{
name: "current route with bad score should be changed to route with better score",
statuses: map[route.ID]routerPeerStatus{
@@ -287,6 +346,45 @@ func TestGetBestrouteFromStatuses(t *testing.T) {
},
}
+ // fill the test data with random routes
+ for _, tc := range testCases {
+ for i := 0; i < 50; i++ {
+ dummyRoute := &route.Route{
+ ID: route.ID(fmt.Sprintf("dummy_p1_%d", i)),
+ Metric: route.MinMetric,
+ Peer: fmt.Sprintf("dummy_p1_%d", i),
+ }
+ tc.existingRoutes[dummyRoute.ID] = dummyRoute
+ }
+ for i := 0; i < 50; i++ {
+ dummyRoute := &route.Route{
+ ID: route.ID(fmt.Sprintf("dummy_p2_%d", i)),
+ Metric: route.MinMetric,
+ Peer: fmt.Sprintf("dummy_p1_%d", i),
+ }
+ tc.existingRoutes[dummyRoute.ID] = dummyRoute
+ }
+
+ for i := 0; i < 50; i++ {
+ id := route.ID(fmt.Sprintf("dummy_p1_%d", i))
+ dummyStatus := routerPeerStatus{
+ connected: false,
+ relayed: true,
+ latency: 0,
+ }
+ tc.statuses[id] = dummyStatus
+ }
+ for i := 0; i < 50; i++ {
+ id := route.ID(fmt.Sprintf("dummy_p2_%d", i))
+ dummyStatus := routerPeerStatus{
+ connected: false,
+ relayed: true,
+ latency: 0,
+ }
+ tc.statuses[id] = dummyStatus
+ }
+ }
+
for _, tc := range testCases {
t.Run(tc.name, func(t *testing.T) {
currentRoute := &route.Route{
From be78efbd429c0b7180efe36fe4a826436e5a675b Mon Sep 17 00:00:00 2001
From: Viktor Liu <17948409+lixmal@users.noreply.github.com>
Date: Thu, 14 Nov 2024 20:15:16 +0100
Subject: [PATCH 06/39] [client] Handle panic on nil wg interface (#2891)
---
client/internal/engine.go | 7 +++++--
1 file changed, 5 insertions(+), 2 deletions(-)
diff --git a/client/internal/engine.go b/client/internal/engine.go
index 190d795cd..0f3a5d28a 100644
--- a/client/internal/engine.go
+++ b/client/internal/engine.go
@@ -38,7 +38,6 @@ import (
"github.com/netbirdio/netbird/client/internal/routemanager/systemops"
"github.com/netbirdio/netbird/client/internal/statemanager"
-
nbssh "github.com/netbirdio/netbird/client/ssh"
"github.com/netbirdio/netbird/client/system"
nbdns "github.com/netbirdio/netbird/dns"
@@ -171,7 +170,7 @@ type Engine struct {
relayManager *relayClient.Manager
stateManager *statemanager.Manager
- srWatcher *guard.SRWatcher
+ srWatcher *guard.SRWatcher
}
// Peer is an instance of the Connection Peer
@@ -641,6 +640,10 @@ func (e *Engine) updateSSH(sshConf *mgmProto.SSHConfig) error {
}
func (e *Engine) updateConfig(conf *mgmProto.PeerConfig) error {
+ if e.wgInterface == nil {
+ return errors.New("wireguard interface is not initialized")
+ }
+
if e.wgInterface.Address().String() != conf.Address {
oldAddr := e.wgInterface.Address().String()
log.Debugf("updating peer address from %s to %s", oldAddr, conf.Address)
From 44e799c687ed1fd5e6a658aff3a06ac1594cec69 Mon Sep 17 00:00:00 2001
From: Pascal Fischer <32096965+pascal-fischer@users.noreply.github.com>
Date: Fri, 15 Nov 2024 11:16:16 +0100
Subject: [PATCH 07/39] [management] Fix limited peer view groups (#2894)
---
management/server/group.go | 12 ++++--------
management/server/http/peers_handler.go | 20 ++++++++++++++++----
2 files changed, 20 insertions(+), 12 deletions(-)
diff --git a/management/server/group.go b/management/server/group.go
index b2ec88cc0..7b4f07948 100644
--- a/management/server/group.go
+++ b/management/server/group.go
@@ -6,11 +6,12 @@ import (
"fmt"
"slices"
- nbdns "github.com/netbirdio/netbird/dns"
- "github.com/netbirdio/netbird/route"
"github.com/rs/xid"
log "github.com/sirupsen/logrus"
+ nbdns "github.com/netbirdio/netbird/dns"
+ "github.com/netbirdio/netbird/route"
+
"github.com/netbirdio/netbird/management/server/activity"
nbgroup "github.com/netbirdio/netbird/management/server/group"
"github.com/netbirdio/netbird/management/server/status"
@@ -27,17 +28,12 @@ func (e *GroupLinkError) Error() string {
// CheckGroupPermissions validates if a user has the necessary permissions to view groups
func (am *DefaultAccountManager) CheckGroupPermissions(ctx context.Context, accountID, userID string) error {
- settings, err := am.Store.GetAccountSettings(ctx, LockingStrengthShare, accountID)
- if err != nil {
- return err
- }
-
user, err := am.Store.GetUserByUserID(ctx, LockingStrengthShare, userID)
if err != nil {
return err
}
- if (!user.IsAdminOrServiceUser() && settings.RegularUsersViewBlocked) || user.AccountID != accountID {
+ if !user.IsAdminOrServiceUser() || user.AccountID != accountID {
return status.Errorf(status.PermissionDenied, "groups are blocked for users")
}
diff --git a/management/server/http/peers_handler.go b/management/server/http/peers_handler.go
index a5856a0e4..f5027cd77 100644
--- a/management/server/http/peers_handler.go
+++ b/management/server/http/peers_handler.go
@@ -184,14 +184,26 @@ func (h *PeersHandler) GetAllPeers(w http.ResponseWriter, r *http.Request) {
dnsDomain := h.accountManager.GetDNSDomain()
- respBody := make([]*api.PeerBatch, 0, len(account.Peers))
- for _, peer := range account.Peers {
+ peers, err := h.accountManager.GetPeers(r.Context(), accountID, userID)
+ if err != nil {
+ util.WriteError(r.Context(), err, w)
+ return
+ }
+
+ groupsMap := map[string]*nbgroup.Group{}
+ groups, _ := h.accountManager.GetAllGroups(r.Context(), accountID, userID)
+ for _, group := range groups {
+ groupsMap[group.ID] = group
+ }
+
+ respBody := make([]*api.PeerBatch, 0, len(peers))
+ for _, peer := range peers {
peerToReturn, err := h.checkPeerStatus(peer)
if err != nil {
util.WriteError(r.Context(), err, w)
return
}
- groupMinimumInfo := toGroupsInfo(account.Groups, peer.ID)
+ groupMinimumInfo := toGroupsInfo(groupsMap, peer.ID)
respBody = append(respBody, toPeerListItemResponse(peerToReturn, groupMinimumInfo, dnsDomain, 0))
}
@@ -304,7 +316,7 @@ func peerToAccessiblePeer(peer *nbpeer.Peer, dnsDomain string) api.AccessiblePee
}
func toGroupsInfo(groups map[string]*nbgroup.Group, peerID string) []api.GroupMinimum {
- var groupsInfo []api.GroupMinimum
+ groupsInfo := []api.GroupMinimum{}
groupsChecked := make(map[string]struct{})
for _, group := range groups {
_, ok := groupsChecked[group.ID]
From 4ef3890bf757f149da7bce3588f62bfbf85b9d8c Mon Sep 17 00:00:00 2001
From: bcmmbaga
Date: Fri, 15 Nov 2024 17:48:00 +0300
Subject: [PATCH 08/39] Fix typo
Signed-off-by: bcmmbaga
---
management/server/dns.go | 2 +-
1 file changed, 1 insertion(+), 1 deletion(-)
diff --git a/management/server/dns.go b/management/server/dns.go
index be7caea4e..8df211b0b 100644
--- a/management/server/dns.go
+++ b/management/server/dns.go
@@ -161,7 +161,7 @@ func (am *DefaultAccountManager) SaveDNSSettings(ctx context.Context, accountID
return nil
}
-// prepareGroupEvents prepares a list of event functions to be stored.
+// prepareDNSSettingsEvents prepares a list of event functions to be stored.
func (am *DefaultAccountManager) prepareDNSSettingsEvents(ctx context.Context, transaction Store, accountID, userID string, addedGroups, removedGroups []string) []func() {
var eventsToStore []func()
From 4aee3c9e33d6c733d38ac2b6e6287ea78e5ac591 Mon Sep 17 00:00:00 2001
From: Pascal Fischer <32096965+pascal-fischer@users.noreply.github.com>
Date: Fri, 15 Nov 2024 16:59:03 +0100
Subject: [PATCH 09/39] [client/management] add peer lock to peer meta update
and fix isEqual func (#2840)
---
client/internal/engine.go | 12 +++++
client/internal/engine_test.go | 93 ++++++++++++++++++++++++++++++++++
management/server/account.go | 5 +-
management/server/peer.go | 3 ++
4 files changed, 112 insertions(+), 1 deletion(-)
diff --git a/client/internal/engine.go b/client/internal/engine.go
index 0f3a5d28a..d4a3a561a 100644
--- a/client/internal/engine.go
+++ b/client/internal/engine.go
@@ -11,6 +11,7 @@ import (
"reflect"
"runtime"
"slices"
+ "sort"
"strings"
"sync"
"sync/atomic"
@@ -1484,6 +1485,17 @@ func (e *Engine) stopDNSServer() {
// isChecksEqual checks if two slices of checks are equal.
func isChecksEqual(checks []*mgmProto.Checks, oChecks []*mgmProto.Checks) bool {
+ for _, check := range checks {
+ sort.Slice(check.Files, func(i, j int) bool {
+ return check.Files[i] < check.Files[j]
+ })
+ }
+ for _, oCheck := range oChecks {
+ sort.Slice(oCheck.Files, func(i, j int) bool {
+ return oCheck.Files[i] < oCheck.Files[j]
+ })
+ }
+
return slices.EqualFunc(checks, oChecks, func(checks, oChecks *mgmProto.Checks) bool {
return slices.Equal(checks.Files, oChecks.Files)
})
diff --git a/client/internal/engine_test.go b/client/internal/engine_test.go
index 0018af6df..b6c6186ea 100644
--- a/client/internal/engine_test.go
+++ b/client/internal/engine_test.go
@@ -1006,6 +1006,99 @@ func Test_ParseNATExternalIPMappings(t *testing.T) {
}
}
+func Test_CheckFilesEqual(t *testing.T) {
+ testCases := []struct {
+ name string
+ inputChecks1 []*mgmtProto.Checks
+ inputChecks2 []*mgmtProto.Checks
+ expectedBool bool
+ }{
+ {
+ name: "Equal Files In Equal Order Should Return True",
+ inputChecks1: []*mgmtProto.Checks{
+ {
+ Files: []string{
+ "testfile1",
+ "testfile2",
+ },
+ },
+ },
+ inputChecks2: []*mgmtProto.Checks{
+ {
+ Files: []string{
+ "testfile1",
+ "testfile2",
+ },
+ },
+ },
+ expectedBool: true,
+ },
+ {
+ name: "Equal Files In Reverse Order Should Return True",
+ inputChecks1: []*mgmtProto.Checks{
+ {
+ Files: []string{
+ "testfile1",
+ "testfile2",
+ },
+ },
+ },
+ inputChecks2: []*mgmtProto.Checks{
+ {
+ Files: []string{
+ "testfile2",
+ "testfile1",
+ },
+ },
+ },
+ expectedBool: true,
+ },
+ {
+ name: "Unequal Files Should Return False",
+ inputChecks1: []*mgmtProto.Checks{
+ {
+ Files: []string{
+ "testfile1",
+ "testfile2",
+ },
+ },
+ },
+ inputChecks2: []*mgmtProto.Checks{
+ {
+ Files: []string{
+ "testfile1",
+ "testfile3",
+ },
+ },
+ },
+ expectedBool: false,
+ },
+ {
+ name: "Compared With Empty Should Return False",
+ inputChecks1: []*mgmtProto.Checks{
+ {
+ Files: []string{
+ "testfile1",
+ "testfile2",
+ },
+ },
+ },
+ inputChecks2: []*mgmtProto.Checks{
+ {
+ Files: []string{},
+ },
+ },
+ expectedBool: false,
+ },
+ }
+ for _, testCase := range testCases {
+ t.Run(testCase.name, func(t *testing.T) {
+ result := isChecksEqual(testCase.inputChecks1, testCase.inputChecks2)
+ assert.Equal(t, testCase.expectedBool, result, "result should match expected bool")
+ })
+ }
+}
+
func createEngine(ctx context.Context, cancel context.CancelFunc, setupKey string, i int, mgmtAddr string, signalAddr string) (*Engine, error) {
key, err := wgtypes.GeneratePrivateKey()
if err != nil {
diff --git a/management/server/account.go b/management/server/account.go
index bf6039229..4afadb4e9 100644
--- a/management/server/account.go
+++ b/management/server/account.go
@@ -2319,7 +2319,7 @@ func (am *DefaultAccountManager) OnPeerDisconnected(ctx context.Context, account
err = am.MarkPeerConnected(ctx, peerPubKey, false, nil, account)
if err != nil {
- log.WithContext(ctx).Warnf("failed marking peer as connected %s %v", peerPubKey, err)
+ log.WithContext(ctx).Warnf("failed marking peer as disconnected %s %v", peerPubKey, err)
}
return nil
@@ -2335,6 +2335,9 @@ func (am *DefaultAccountManager) SyncPeerMeta(ctx context.Context, peerPubKey st
unlock := am.Store.AcquireReadLockByUID(ctx, accountID)
defer unlock()
+ unlockPeer := am.Store.AcquireWriteLockByUID(ctx, peerPubKey)
+ defer unlockPeer()
+
account, err := am.Store.GetAccount(ctx, accountID)
if err != nil {
return err
diff --git a/management/server/peer.go b/management/server/peer.go
index 9784650de..87b5b0e4e 100644
--- a/management/server/peer.go
+++ b/management/server/peer.go
@@ -168,6 +168,8 @@ func (am *DefaultAccountManager) updatePeerStatusAndLocation(ctx context.Context
account.UpdatePeer(peer)
+ log.WithContext(ctx).Tracef("saving peer status for peer %s is connected: %t", peer.ID, connected)
+
err := am.Store.SavePeerStatus(account.Id, peer.ID, *newStatus)
if err != nil {
return false, fmt.Errorf("failed to save peer status: %w", err)
@@ -657,6 +659,7 @@ func (am *DefaultAccountManager) SyncPeer(ctx context.Context, sync PeerSync, ac
updated := peer.UpdateMetaIfNew(sync.Meta)
if updated {
+ log.WithContext(ctx).Tracef("peer %s metadata updated", peer.ID)
err = am.Store.SavePeer(ctx, account.Id, peer)
if err != nil {
return nil, nil, nil, fmt.Errorf("failed to save peer: %w", err)
From d9b691b8a56bb8bdd11f02364e0dbf59ae0580bc Mon Sep 17 00:00:00 2001
From: Pascal Fischer <32096965+pascal-fischer@users.noreply.github.com>
Date: Fri, 15 Nov 2024 17:00:06 +0100
Subject: [PATCH 10/39] [management] Limit the setup-key update operation
(#2841)
---
management/server/http/api/openapi.yml | 25 ---------------
management/server/http/api/types.gen.go | 15 ---------
management/server/http/setupkeys_handler.go | 6 ----
management/server/setupkey.go | 12 ++++---
management/server/setupkey_test.go | 35 ++++++++++++++++++---
5 files changed, 38 insertions(+), 55 deletions(-)
diff --git a/management/server/http/api/openapi.yml b/management/server/http/api/openapi.yml
index 9b4592ccf..bfb375277 100644
--- a/management/server/http/api/openapi.yml
+++ b/management/server/http/api/openapi.yml
@@ -521,19 +521,6 @@ components:
SetupKeyRequest:
type: object
properties:
- name:
- description: Setup Key name
- type: string
- example: Default key
- type:
- description: Setup key type, one-off for single time usage and reusable
- type: string
- example: reusable
- expires_in:
- description: Expiration time in seconds, 0 will mean the key never expires
- type: integer
- minimum: 0
- example: 86400
revoked:
description: Setup key revocation status
type: boolean
@@ -544,21 +531,9 @@ components:
items:
type: string
example: "ch8i4ug6lnn4g9hqv7m0"
- usage_limit:
- description: A number of times this key can be used. The value of 0 indicates the unlimited usage.
- type: integer
- example: 0
- ephemeral:
- description: Indicate that the peer will be ephemeral or not
- type: boolean
- example: true
required:
- - name
- - type
- - expires_in
- revoked
- auto_groups
- - usage_limit
CreateSetupKeyRequest:
type: object
properties:
diff --git a/management/server/http/api/types.gen.go b/management/server/http/api/types.gen.go
index c1ef1ba21..f219c4574 100644
--- a/management/server/http/api/types.gen.go
+++ b/management/server/http/api/types.gen.go
@@ -1098,23 +1098,8 @@ type SetupKeyRequest struct {
// AutoGroups List of group IDs to auto-assign to peers registered with this key
AutoGroups []string `json:"auto_groups"`
- // Ephemeral Indicate that the peer will be ephemeral or not
- Ephemeral *bool `json:"ephemeral,omitempty"`
-
- // ExpiresIn Expiration time in seconds, 0 will mean the key never expires
- ExpiresIn int `json:"expires_in"`
-
- // Name Setup Key name
- Name string `json:"name"`
-
// Revoked Setup key revocation status
Revoked bool `json:"revoked"`
-
- // Type Setup key type, one-off for single time usage and reusable
- Type string `json:"type"`
-
- // UsageLimit A number of times this key can be used. The value of 0 indicates the unlimited usage.
- UsageLimit int `json:"usage_limit"`
}
// User defines model for User.
diff --git a/management/server/http/setupkeys_handler.go b/management/server/http/setupkeys_handler.go
index 31859f59b..9ba5977bb 100644
--- a/management/server/http/setupkeys_handler.go
+++ b/management/server/http/setupkeys_handler.go
@@ -137,11 +137,6 @@ func (h *SetupKeysHandler) UpdateSetupKey(w http.ResponseWriter, r *http.Request
return
}
- if req.Name == "" {
- util.WriteError(r.Context(), status.Errorf(status.InvalidArgument, "setup key name field is invalid: %s", req.Name), w)
- return
- }
-
if req.AutoGroups == nil {
util.WriteError(r.Context(), status.Errorf(status.InvalidArgument, "setup key AutoGroups field is invalid"), w)
return
@@ -150,7 +145,6 @@ func (h *SetupKeysHandler) UpdateSetupKey(w http.ResponseWriter, r *http.Request
newKey := &server.SetupKey{}
newKey.AutoGroups = req.AutoGroups
newKey.Revoked = req.Revoked
- newKey.Name = req.Name
newKey.Id = keyID
newKey, err = h.accountManager.SaveSetupKey(r.Context(), accountID, newKey, userID)
diff --git a/management/server/setupkey.go b/management/server/setupkey.go
index 554c66ba4..960532bf9 100644
--- a/management/server/setupkey.go
+++ b/management/server/setupkey.go
@@ -12,9 +12,10 @@ import (
"unicode/utf8"
"github.com/google/uuid"
+ log "github.com/sirupsen/logrus"
+
"github.com/netbirdio/netbird/management/server/activity"
"github.com/netbirdio/netbird/management/server/status"
- log "github.com/sirupsen/logrus"
)
const (
@@ -276,7 +277,7 @@ func (am *DefaultAccountManager) CreateSetupKey(ctx context.Context, accountID s
// SaveSetupKey saves the provided SetupKey to the database overriding the existing one.
// Due to the unique nature of a SetupKey certain properties must not be overwritten
// (e.g. the key itself, creation date, ID, etc).
-// These properties are overwritten: Name, AutoGroups, Revoked. The rest is copied from the existing key.
+// These properties are overwritten: AutoGroups, Revoked (only from false to true), and the UpdatedAt. The rest is copied from the existing key.
func (am *DefaultAccountManager) SaveSetupKey(ctx context.Context, accountID string, keyToSave *SetupKey, userID string) (*SetupKey, error) {
if keyToSave == nil {
return nil, status.Errorf(status.InvalidArgument, "provided setup key to update is nil")
@@ -312,9 +313,12 @@ func (am *DefaultAccountManager) SaveSetupKey(ctx context.Context, accountID str
return err
}
- // only auto groups, revoked status, and name can be updated for now
+ if oldKey.Revoked && !keyToSave.Revoked {
+ return status.Errorf(status.InvalidArgument, "can't un-revoke a revoked setup key")
+ }
+
+ // only auto groups, revoked status (from false to true) can be updated
newKey = oldKey.Copy()
- newKey.Name = keyToSave.Name
newKey.AutoGroups = keyToSave.AutoGroups
newKey.Revoked = keyToSave.Revoked
newKey.UpdatedAt = time.Now().UTC()
diff --git a/management/server/setupkey_test.go b/management/server/setupkey_test.go
index 2ed8aef95..94ed022fa 100644
--- a/management/server/setupkey_test.go
+++ b/management/server/setupkey_test.go
@@ -56,11 +56,9 @@ func TestDefaultAccountManager_SaveSetupKey(t *testing.T) {
}
autoGroups := []string{"group_1", "group_2"}
- newKeyName := "my-new-test-key"
revoked := true
newKey, err := manager.SaveSetupKey(context.Background(), account.Id, &SetupKey{
Id: key.Id,
- Name: newKeyName,
Revoked: revoked,
AutoGroups: autoGroups,
}, userID)
@@ -68,7 +66,7 @@ func TestDefaultAccountManager_SaveSetupKey(t *testing.T) {
t.Fatal(err)
}
- assertKey(t, newKey, newKeyName, revoked, "reusable", 0, key.CreatedAt, key.ExpiresAt,
+ assertKey(t, newKey, keyName, revoked, "reusable", 0, key.CreatedAt, key.ExpiresAt,
key.Id, time.Now().UTC(), autoGroups, true)
// check the corresponding events that should have been generated
@@ -76,7 +74,7 @@ func TestDefaultAccountManager_SaveSetupKey(t *testing.T) {
assert.NotNil(t, ev)
assert.Equal(t, account.Id, ev.AccountID)
- assert.Equal(t, newKeyName, ev.Meta["name"])
+ assert.Equal(t, keyName, ev.Meta["name"])
assert.Equal(t, fmt.Sprint(key.Type), fmt.Sprint(ev.Meta["type"]))
assert.NotEmpty(t, ev.Meta["key"])
assert.Equal(t, userID, ev.InitiatorID)
@@ -89,7 +87,6 @@ func TestDefaultAccountManager_SaveSetupKey(t *testing.T) {
autoGroups = append(autoGroups, groupAll.ID)
_, err = manager.SaveSetupKey(context.Background(), account.Id, &SetupKey{
Id: key.Id,
- Name: newKeyName,
Revoked: revoked,
AutoGroups: autoGroups,
}, userID)
@@ -449,3 +446,31 @@ func TestSetupKeyAccountPeersUpdate(t *testing.T) {
}
})
}
+
+func TestDefaultAccountManager_CreateSetupKey_ShouldNotAllowToUpdateRevokedKey(t *testing.T) {
+ manager, err := createManager(t)
+ if err != nil {
+ t.Fatal(err)
+ }
+
+ userID := "testingUser"
+ account, err := manager.GetOrCreateAccountByUser(context.Background(), userID, "")
+ if err != nil {
+ t.Fatal(err)
+ }
+
+ key, err := manager.CreateSetupKey(context.Background(), account.Id, "testName", SetupKeyReusable, time.Hour, nil, SetupKeyUnlimitedUsage, userID, false)
+ assert.NoError(t, err)
+
+ // revoke the key
+ updateKey := key.Copy()
+ updateKey.Revoked = true
+ _, err = manager.SaveSetupKey(context.Background(), account.Id, updateKey, userID)
+ assert.NoError(t, err)
+
+ // re-activate revoked key
+ updateKey.Revoked = false
+ _, err = manager.SaveSetupKey(context.Background(), account.Id, updateKey, userID)
+ assert.Error(t, err, "should not allow to update revoked key")
+
+}
From 51c1ec283cb9d9dacc9ec18ab6d98b64d954d362 Mon Sep 17 00:00:00 2001
From: bcmmbaga
Date: Fri, 15 Nov 2024 19:34:57 +0300
Subject: [PATCH 11/39] Add locks and remove log
Signed-off-by: bcmmbaga
---
management/server/posture_checks.go | 8 ++++++--
1 file changed, 6 insertions(+), 2 deletions(-)
diff --git a/management/server/posture_checks.go b/management/server/posture_checks.go
index d7b5a79a2..59e726c41 100644
--- a/management/server/posture_checks.go
+++ b/management/server/posture_checks.go
@@ -9,7 +9,6 @@ import (
"github.com/netbirdio/netbird/management/server/posture"
"github.com/netbirdio/netbird/management/server/status"
"github.com/rs/xid"
- log "github.com/sirupsen/logrus"
"golang.org/x/exp/maps"
)
@@ -32,6 +31,9 @@ func (am *DefaultAccountManager) GetPostureChecks(ctx context.Context, accountID
// SavePostureChecks saves a posture check.
func (am *DefaultAccountManager) SavePostureChecks(ctx context.Context, accountID, userID string, postureChecks *posture.Checks) (*posture.Checks, error) {
+ unlock := am.Store.AcquireWriteLockByUID(ctx, accountID)
+ defer unlock()
+
user, err := am.Store.GetUserByUserID(ctx, LockingStrengthShare, userID)
if err != nil {
return nil, err
@@ -85,6 +87,9 @@ func (am *DefaultAccountManager) SavePostureChecks(ctx context.Context, accountI
// DeletePostureChecks deletes a posture check by ID.
func (am *DefaultAccountManager) DeletePostureChecks(ctx context.Context, accountID, postureChecksID, userID string) error {
+ unlock := am.Store.AcquireWriteLockByUID(ctx, accountID)
+ defer unlock()
+
user, err := am.Store.GetUserByUserID(ctx, LockingStrengthShare, userID)
if err != nil {
return err
@@ -267,7 +272,6 @@ func isPeerInPolicySourceGroups(ctx context.Context, transaction Store, accountI
for _, sourceGroup := range rule.Sources {
group, err := transaction.GetGroupByID(ctx, LockingStrengthShare, accountID, sourceGroup)
if err != nil {
- log.WithContext(ctx).Debugf("failed to check peer in policy source group: %v", err)
return false, fmt.Errorf("failed to check peer in policy source group: %w", err)
}
From 12f442439a64a4b76e8b9c7e93c0aa5de74d213c Mon Sep 17 00:00:00 2001
From: Bethuel Mmbaga
Date: Fri, 15 Nov 2024 20:09:32 +0300
Subject: [PATCH 12/39] [management] Refactor group to use store methods
(#2867)
* Refactor setup key handling to use store methods
Signed-off-by: bcmmbaga
* add lock to get account groups
Signed-off-by: bcmmbaga
* add check for regular user
Signed-off-by: bcmmbaga
* get only required groups for auto-group validation
Signed-off-by: bcmmbaga
* add account lock and return auto groups map on validation
Signed-off-by: bcmmbaga
* refactor account peers update
Signed-off-by: bcmmbaga
* Refactor groups to use store methods
Signed-off-by: bcmmbaga
* refactor GetGroupByID and add NewGroupNotFoundError
Signed-off-by: bcmmbaga
* fix tests
Signed-off-by: bcmmbaga
* Add AddPeer and RemovePeer methods to Group struct
Signed-off-by: bcmmbaga
* Preserve store engine in SqlStore transactions
Signed-off-by: bcmmbaga
* Run groups ops in transaction
Signed-off-by: bcmmbaga
* fix missing group removed from setup key activity
Signed-off-by: bcmmbaga
* fix merge
Signed-off-by: bcmmbaga
* fix merge
Signed-off-by: bcmmbaga
* fix sonar
Signed-off-by: bcmmbaga
* Change setup key log level to debug for missing group
Signed-off-by: bcmmbaga
* Retrieve modified peers once for group events
Signed-off-by: bcmmbaga
* Add tests
Signed-off-by: bcmmbaga
* Add account locking and merge group deletion methods
Signed-off-by: bcmmbaga
* Fix tests
Signed-off-by: bcmmbaga
---------
Signed-off-by: bcmmbaga
---
management/server/account.go | 29 +-
management/server/account_test.go | 12 +-
management/server/dns.go | 2 +-
management/server/group.go | 514 ++++++++++--------
management/server/group/group.go | 27 +
management/server/group/group_test.go | 90 +++
management/server/group_test.go | 2 +-
management/server/integrated_validator.go | 27 +-
management/server/mock_server/account_mock.go | 9 -
management/server/nameserver.go | 6 +-
management/server/peer.go | 41 +-
management/server/peer_test.go | 2 +-
management/server/policy.go | 4 +-
management/server/posture_checks.go | 2 +-
management/server/route.go | 6 +-
management/server/route_test.go | 2 +-
management/server/setupkey.go | 6 +-
management/server/sql_store.go | 120 +++-
management/server/sql_store_test.go | 285 +++++++++-
management/server/status/error.go | 11 +-
management/server/store.go | 8 +-
management/server/user.go | 8 +-
22 files changed, 878 insertions(+), 335 deletions(-)
create mode 100644 management/server/group/group_test.go
diff --git a/management/server/account.go b/management/server/account.go
index 4afadb4e9..1bd8a99a9 100644
--- a/management/server/account.go
+++ b/management/server/account.go
@@ -110,7 +110,6 @@ type AccountManager interface {
SaveGroups(ctx context.Context, accountID, userID string, newGroups []*nbgroup.Group) error
DeleteGroup(ctx context.Context, accountId, userId, groupID string) error
DeleteGroups(ctx context.Context, accountId, userId string, groupIDs []string) error
- ListGroups(ctx context.Context, accountId string) ([]*nbgroup.Group, error)
GroupAddPeer(ctx context.Context, accountId, groupID, peerID string) error
GroupDeletePeer(ctx context.Context, accountId, groupID, peerID string) error
GetPolicy(ctx context.Context, accountID, policyID, userID string) (*Policy, error)
@@ -1435,7 +1434,7 @@ func isNil(i idp.Manager) bool {
// addAccountIDToIDPAppMeta update user's app metadata in idp manager
func (am *DefaultAccountManager) addAccountIDToIDPAppMeta(ctx context.Context, userID string, accountID string) error {
if !isNil(am.idpManager) {
- accountUsers, err := am.Store.GetAccountUsers(ctx, accountID)
+ accountUsers, err := am.Store.GetAccountUsers(ctx, LockingStrengthShare, accountID)
if err != nil {
return err
}
@@ -2083,7 +2082,7 @@ func (am *DefaultAccountManager) syncJWTGroups(ctx context.Context, accountID st
return fmt.Errorf("error saving groups: %w", err)
}
- if err = transaction.IncrementNetworkSerial(ctx, accountID); err != nil {
+ if err = transaction.IncrementNetworkSerial(ctx, LockingStrengthUpdate, accountID); err != nil {
return fmt.Errorf("error incrementing network serial: %w", err)
}
}
@@ -2101,7 +2100,7 @@ func (am *DefaultAccountManager) syncJWTGroups(ctx context.Context, accountID st
}
for _, g := range addNewGroups {
- group, err := am.Store.GetGroupByID(ctx, LockingStrengthShare, g, accountID)
+ group, err := am.Store.GetGroupByID(ctx, LockingStrengthShare, accountID, g)
if err != nil {
log.WithContext(ctx).Debugf("group %s not found while saving user activity event of account %s", g, accountID)
} else {
@@ -2114,7 +2113,7 @@ func (am *DefaultAccountManager) syncJWTGroups(ctx context.Context, accountID st
}
for _, g := range removeOldGroups {
- group, err := am.Store.GetGroupByID(ctx, LockingStrengthShare, g, accountID)
+ group, err := am.Store.GetGroupByID(ctx, LockingStrengthShare, accountID, g)
if err != nil {
log.WithContext(ctx).Debugf("group %s not found while saving user activity event of account %s", g, accountID)
} else {
@@ -2127,14 +2126,19 @@ func (am *DefaultAccountManager) syncJWTGroups(ctx context.Context, accountID st
}
if settings.GroupsPropagationEnabled {
- account, err := am.requestBuffer.GetAccountWithBackpressure(ctx, accountID)
+ removedGroupAffectsPeers, err := areGroupChangesAffectPeers(ctx, am.Store, accountID, removeOldGroups)
if err != nil {
- return status.NewGetAccountError(err)
+ return err
}
- if areGroupChangesAffectPeers(account, addNewGroups) || areGroupChangesAffectPeers(account, removeOldGroups) {
+ newGroupsAffectsPeers, err := areGroupChangesAffectPeers(ctx, am.Store, accountID, addNewGroups)
+ if err != nil {
+ return err
+ }
+
+ if removedGroupAffectsPeers || newGroupsAffectsPeers {
log.WithContext(ctx).Tracef("user %s: JWT group membership changed, updating account peers", claims.UserId)
- am.updateAccountPeers(ctx, account)
+ am.updateAccountPeers(ctx, accountID)
}
}
@@ -2401,12 +2405,7 @@ func (am *DefaultAccountManager) CheckUserAccessByJWTGroups(ctx context.Context,
func (am *DefaultAccountManager) onPeersInvalidated(ctx context.Context, accountID string) {
log.WithContext(ctx).Debugf("validated peers has been invalidated for account %s", accountID)
- updatedAccount, err := am.Store.GetAccount(ctx, accountID)
- if err != nil {
- log.WithContext(ctx).Errorf("failed to get account %s: %v", accountID, err)
- return
- }
- am.updateAccountPeers(ctx, updatedAccount)
+ am.updateAccountPeers(ctx, accountID)
}
func (am *DefaultAccountManager) FindExistingPostureCheck(accountID string, checks *posture.ChecksDefinition) (*posture.Checks, error) {
diff --git a/management/server/account_test.go b/management/server/account_test.go
index fdf004a3b..97e0d45f0 100644
--- a/management/server/account_test.go
+++ b/management/server/account_test.go
@@ -1413,11 +1413,13 @@ func TestAccountManager_NetworkUpdates_DeleteGroup(t *testing.T) {
updMsg := manager.peersUpdateManager.CreateChannel(context.Background(), peer1.ID)
defer manager.peersUpdateManager.CloseChannel(context.Background(), peer1.ID)
- group := group.Group{
+ err := manager.SaveGroup(context.Background(), account.Id, userID, &group.Group{
ID: "groupA",
Name: "GroupA",
Peers: []string{peer1.ID, peer2.ID, peer3.ID},
- }
+ })
+
+ require.NoError(t, err, "failed to save group")
policy := Policy{
Enabled: true,
@@ -1460,7 +1462,7 @@ func TestAccountManager_NetworkUpdates_DeleteGroup(t *testing.T) {
return
}
- if err := manager.DeleteGroup(context.Background(), account.Id, "", group.ID); err != nil {
+ if err := manager.DeleteGroup(context.Background(), account.Id, userID, "groupA"); err != nil {
t.Errorf("delete group: %v", err)
return
}
@@ -2714,7 +2716,7 @@ func TestAccount_SetJWTGroups(t *testing.T) {
assert.NoError(t, err, "unable to get user")
assert.Len(t, user.AutoGroups, 0)
- group1, err := manager.Store.GetGroupByID(context.Background(), LockingStrengthShare, "group1", "accountID")
+ group1, err := manager.Store.GetGroupByID(context.Background(), LockingStrengthShare, "accountID", "group1")
assert.NoError(t, err, "unable to get group")
assert.Equal(t, group1.Issued, group.GroupIssuedAPI, "group should be api issued")
})
@@ -2734,7 +2736,7 @@ func TestAccount_SetJWTGroups(t *testing.T) {
assert.NoError(t, err, "unable to get user")
assert.Len(t, user.AutoGroups, 1)
- group1, err := manager.Store.GetGroupByID(context.Background(), LockingStrengthShare, "group1", "accountID")
+ group1, err := manager.Store.GetGroupByID(context.Background(), LockingStrengthShare, "accountID", "group1")
assert.NoError(t, err, "unable to get group")
assert.Equal(t, group1.Issued, group.GroupIssuedAPI, "group should be api issued")
})
diff --git a/management/server/dns.go b/management/server/dns.go
index 256b8b125..4551be5ab 100644
--- a/management/server/dns.go
+++ b/management/server/dns.go
@@ -146,7 +146,7 @@ func (am *DefaultAccountManager) SaveDNSSettings(ctx context.Context, accountID
}
if anyGroupHasPeers(account, addedGroups) || anyGroupHasPeers(account, removedGroups) {
- am.updateAccountPeers(ctx, account)
+ am.updateAccountPeers(ctx, accountID)
}
return nil
diff --git a/management/server/group.go b/management/server/group.go
index 7b4f07948..a36213f04 100644
--- a/management/server/group.go
+++ b/management/server/group.go
@@ -33,8 +33,12 @@ func (am *DefaultAccountManager) CheckGroupPermissions(ctx context.Context, acco
return err
}
- if !user.IsAdminOrServiceUser() || user.AccountID != accountID {
- return status.Errorf(status.PermissionDenied, "groups are blocked for users")
+ if user.AccountID != accountID {
+ return status.NewUserNotPartOfAccountError()
+ }
+
+ if user.IsRegularUser() {
+ return status.NewAdminPermissionError()
}
return nil
@@ -45,8 +49,7 @@ func (am *DefaultAccountManager) GetGroup(ctx context.Context, accountID, groupI
if err := am.CheckGroupPermissions(ctx, accountID, userID); err != nil {
return nil, err
}
-
- return am.Store.GetGroupByID(ctx, LockingStrengthShare, groupID, accountID)
+ return am.Store.GetGroupByID(ctx, LockingStrengthShare, accountID, groupID)
}
// GetAllGroups returns all groups in an account
@@ -54,13 +57,12 @@ func (am *DefaultAccountManager) GetAllGroups(ctx context.Context, accountID, us
if err := am.CheckGroupPermissions(ctx, accountID, userID); err != nil {
return nil, err
}
-
return am.Store.GetAccountGroups(ctx, LockingStrengthShare, accountID)
}
// GetGroupByName filters all groups in an account by name and returns the one with the most peers
func (am *DefaultAccountManager) GetGroupByName(ctx context.Context, groupName, accountID string) (*nbgroup.Group, error) {
- return am.Store.GetGroupByName(ctx, LockingStrengthShare, groupName, accountID)
+ return am.Store.GetGroupByName(ctx, LockingStrengthShare, accountID, groupName)
}
// SaveGroup object of the peers
@@ -73,79 +75,74 @@ func (am *DefaultAccountManager) SaveGroup(ctx context.Context, accountID, userI
// SaveGroups adds new groups to the account.
// Note: This function does not acquire the global lock.
// It is the caller's responsibility to ensure proper locking is in place before invoking this method.
-func (am *DefaultAccountManager) SaveGroups(ctx context.Context, accountID, userID string, newGroups []*nbgroup.Group) error {
- account, err := am.Store.GetAccount(ctx, accountID)
+func (am *DefaultAccountManager) SaveGroups(ctx context.Context, accountID, userID string, groups []*nbgroup.Group) error {
+ user, err := am.Store.GetUserByUserID(ctx, LockingStrengthShare, userID)
if err != nil {
return err
}
+ if user.AccountID != accountID {
+ return status.NewUserNotPartOfAccountError()
+ }
+
+ if user.IsRegularUser() {
+ return status.NewAdminPermissionError()
+ }
+
var eventsToStore []func()
+ var groupsToSave []*nbgroup.Group
+ var updateAccountPeers bool
- for _, newGroup := range newGroups {
- if newGroup.ID == "" && newGroup.Issued != nbgroup.GroupIssuedAPI {
- return status.Errorf(status.InvalidArgument, "%s group without ID set", newGroup.Issued)
- }
-
- if newGroup.ID == "" && newGroup.Issued == nbgroup.GroupIssuedAPI {
- existingGroup, err := account.FindGroupByName(newGroup.Name)
- if err != nil {
- s, ok := status.FromError(err)
- if !ok || s.ErrorType != status.NotFound {
- return err
- }
+ err = am.Store.ExecuteInTransaction(ctx, func(transaction Store) error {
+ groupIDs := make([]string, 0, len(groups))
+ for _, newGroup := range groups {
+ if err = validateNewGroup(ctx, transaction, accountID, newGroup); err != nil {
+ return err
}
- // Avoid duplicate groups only for the API issued groups.
- // Integration or JWT groups can be duplicated as they are coming from the IdP that we don't have control of.
- if existingGroup != nil {
- return status.Errorf(status.AlreadyExists, "group with name %s already exists", newGroup.Name)
- }
+ newGroup.AccountID = accountID
+ groupsToSave = append(groupsToSave, newGroup)
+ groupIDs = append(groupIDs, newGroup.ID)
- newGroup.ID = xid.New().String()
+ events := am.prepareGroupEvents(ctx, transaction, accountID, userID, newGroup)
+ eventsToStore = append(eventsToStore, events...)
}
- for _, peerID := range newGroup.Peers {
- if account.Peers[peerID] == nil {
- return status.Errorf(status.InvalidArgument, "peer with ID \"%s\" not found", peerID)
- }
+ updateAccountPeers, err = areGroupChangesAffectPeers(ctx, transaction, accountID, groupIDs)
+ if err != nil {
+ return err
}
- oldGroup := account.Groups[newGroup.ID]
- account.Groups[newGroup.ID] = newGroup
+ if err = transaction.IncrementNetworkSerial(ctx, LockingStrengthUpdate, accountID); err != nil {
+ return err
+ }
- events := am.prepareGroupEvents(ctx, userID, accountID, newGroup, oldGroup, account)
- eventsToStore = append(eventsToStore, events...)
- }
-
- newGroupIDs := make([]string, 0, len(newGroups))
- for _, newGroup := range newGroups {
- newGroupIDs = append(newGroupIDs, newGroup.ID)
- }
-
- account.Network.IncSerial()
- if err = am.Store.SaveAccount(ctx, account); err != nil {
+ return transaction.SaveGroups(ctx, LockingStrengthUpdate, groupsToSave)
+ })
+ if err != nil {
return err
}
- if areGroupChangesAffectPeers(account, newGroupIDs) {
- am.updateAccountPeers(ctx, account)
- }
-
for _, storeEvent := range eventsToStore {
storeEvent()
}
+ if updateAccountPeers {
+ am.updateAccountPeers(ctx, accountID)
+ }
+
return nil
}
// prepareGroupEvents prepares a list of event functions to be stored.
-func (am *DefaultAccountManager) prepareGroupEvents(ctx context.Context, userID string, accountID string, newGroup, oldGroup *nbgroup.Group, account *Account) []func() {
+func (am *DefaultAccountManager) prepareGroupEvents(ctx context.Context, transaction Store, accountID, userID string, newGroup *nbgroup.Group) []func() {
var eventsToStore []func()
addedPeers := make([]string, 0)
removedPeers := make([]string, 0)
- if oldGroup != nil {
+ oldGroup, err := transaction.GetGroupByID(ctx, LockingStrengthShare, accountID, newGroup.ID)
+ if err == nil && oldGroup != nil {
addedPeers = difference(newGroup.Peers, oldGroup.Peers)
removedPeers = difference(oldGroup.Peers, newGroup.Peers)
} else {
@@ -155,35 +152,42 @@ func (am *DefaultAccountManager) prepareGroupEvents(ctx context.Context, userID
})
}
- for _, p := range addedPeers {
- peer := account.Peers[p]
- if peer == nil {
- log.WithContext(ctx).Errorf("peer %s not found under account %s while saving group", p, accountID)
+ modifiedPeers := slices.Concat(addedPeers, removedPeers)
+ peers, err := transaction.GetPeersByIDs(ctx, LockingStrengthShare, accountID, modifiedPeers)
+ if err != nil {
+ log.WithContext(ctx).Debugf("failed to get peers for group events: %v", err)
+ return nil
+ }
+
+ for _, peerID := range addedPeers {
+ peer, ok := peers[peerID]
+ if !ok {
+ log.WithContext(ctx).Debugf("skipped adding peer: %s GroupAddedToPeer activity: peer not found in store", peerID)
continue
}
- peerCopy := peer // copy to avoid closure issues
+
eventsToStore = append(eventsToStore, func() {
- am.StoreEvent(ctx, userID, peerCopy.ID, accountID, activity.GroupAddedToPeer,
- map[string]any{
- "group": newGroup.Name, "group_id": newGroup.ID, "peer_ip": peerCopy.IP.String(),
- "peer_fqdn": peerCopy.FQDN(am.GetDNSDomain()),
- })
+ meta := map[string]any{
+ "group": newGroup.Name, "group_id": newGroup.ID,
+ "peer_ip": peer.IP.String(), "peer_fqdn": peer.FQDN(am.GetDNSDomain()),
+ }
+ am.StoreEvent(ctx, userID, peer.ID, accountID, activity.GroupAddedToPeer, meta)
})
}
- for _, p := range removedPeers {
- peer := account.Peers[p]
- if peer == nil {
- log.WithContext(ctx).Errorf("peer %s not found under account %s while saving group", p, accountID)
+ for _, peerID := range removedPeers {
+ peer, ok := peers[peerID]
+ if !ok {
+ log.WithContext(ctx).Debugf("skipped adding peer: %s GroupRemovedFromPeer activity: peer not found in store", peerID)
continue
}
- peerCopy := peer // copy to avoid closure issues
+
eventsToStore = append(eventsToStore, func() {
- am.StoreEvent(ctx, userID, peerCopy.ID, accountID, activity.GroupRemovedFromPeer,
- map[string]any{
- "group": newGroup.Name, "group_id": newGroup.ID, "peer_ip": peerCopy.IP.String(),
- "peer_fqdn": peerCopy.FQDN(am.GetDNSDomain()),
- })
+ meta := map[string]any{
+ "group": newGroup.Name, "group_id": newGroup.ID,
+ "peer_ip": peer.IP.String(), "peer_fqdn": peer.FQDN(am.GetDNSDomain()),
+ }
+ am.StoreEvent(ctx, userID, peer.ID, accountID, activity.GroupRemovedFromPeer, meta)
})
}
@@ -206,42 +210,10 @@ func difference(a, b []string) []string {
}
// DeleteGroup object of the peers.
-func (am *DefaultAccountManager) DeleteGroup(ctx context.Context, accountId, userId, groupID string) error {
- unlock := am.Store.AcquireWriteLockByUID(ctx, accountId)
+func (am *DefaultAccountManager) DeleteGroup(ctx context.Context, accountID, userID, groupID string) error {
+ unlock := am.Store.AcquireWriteLockByUID(ctx, accountID)
defer unlock()
-
- account, err := am.Store.GetAccount(ctx, accountId)
- if err != nil {
- return err
- }
-
- group, ok := account.Groups[groupID]
- if !ok {
- return nil
- }
-
- allGroup, err := account.GetGroupAll()
- if err != nil {
- return err
- }
-
- if allGroup.ID == groupID {
- return status.Errorf(status.InvalidArgument, "deleting group ALL is not allowed")
- }
-
- if err = validateDeleteGroup(account, group, userId); err != nil {
- return err
- }
- delete(account.Groups, groupID)
-
- account.Network.IncSerial()
- if err = am.Store.SaveAccount(ctx, account); err != nil {
- return err
- }
-
- am.StoreEvent(ctx, userId, groupID, accountId, activity.GroupDeleted, group.EventMeta())
-
- return nil
+ return am.DeleteGroups(ctx, accountID, userID, []string{groupID})
}
// DeleteGroups deletes groups from an account.
@@ -250,93 +222,94 @@ func (am *DefaultAccountManager) DeleteGroup(ctx context.Context, accountId, use
//
// If an error occurs while deleting a group, the function skips it and continues deleting other groups.
// Errors are collected and returned at the end.
-func (am *DefaultAccountManager) DeleteGroups(ctx context.Context, accountId, userId string, groupIDs []string) error {
- account, err := am.Store.GetAccount(ctx, accountId)
+func (am *DefaultAccountManager) DeleteGroups(ctx context.Context, accountID, userID string, groupIDs []string) error {
+ user, err := am.Store.GetUserByUserID(ctx, LockingStrengthShare, userID)
if err != nil {
return err
}
+ if user.AccountID != accountID {
+ return status.NewUserNotPartOfAccountError()
+ }
+
+ if user.IsRegularUser() {
+ return status.NewAdminPermissionError()
+ }
+
var allErrors error
+ var groupIDsToDelete []string
+ var deletedGroups []*nbgroup.Group
- deletedGroups := make([]*nbgroup.Group, 0, len(groupIDs))
- for _, groupID := range groupIDs {
- group, ok := account.Groups[groupID]
- if !ok {
- continue
+ err = am.Store.ExecuteInTransaction(ctx, func(transaction Store) error {
+ for _, groupID := range groupIDs {
+ group, err := transaction.GetGroupByID(ctx, LockingStrengthUpdate, accountID, groupID)
+ if err != nil {
+ allErrors = errors.Join(allErrors, err)
+ continue
+ }
+
+ if err := validateDeleteGroup(ctx, transaction, group, userID); err != nil {
+ allErrors = errors.Join(allErrors, err)
+ continue
+ }
+
+ groupIDsToDelete = append(groupIDsToDelete, groupID)
+ deletedGroups = append(deletedGroups, group)
}
- if err := validateDeleteGroup(account, group, userId); err != nil {
- allErrors = errors.Join(allErrors, fmt.Errorf("failed to delete group %s: %w", groupID, err))
- continue
+ if err = transaction.IncrementNetworkSerial(ctx, LockingStrengthUpdate, accountID); err != nil {
+ return err
}
- delete(account.Groups, groupID)
- deletedGroups = append(deletedGroups, group)
- }
-
- account.Network.IncSerial()
- if err = am.Store.SaveAccount(ctx, account); err != nil {
+ return transaction.DeleteGroups(ctx, LockingStrengthUpdate, accountID, groupIDsToDelete)
+ })
+ if err != nil {
return err
}
- for _, g := range deletedGroups {
- am.StoreEvent(ctx, userId, g.ID, accountId, activity.GroupDeleted, g.EventMeta())
+ for _, group := range deletedGroups {
+ am.StoreEvent(ctx, userID, group.ID, accountID, activity.GroupDeleted, group.EventMeta())
}
return allErrors
}
-// ListGroups objects of the peers
-func (am *DefaultAccountManager) ListGroups(ctx context.Context, accountID string) ([]*nbgroup.Group, error) {
- unlock := am.Store.AcquireWriteLockByUID(ctx, accountID)
- defer unlock()
-
- account, err := am.Store.GetAccount(ctx, accountID)
- if err != nil {
- return nil, err
- }
-
- groups := make([]*nbgroup.Group, 0, len(account.Groups))
- for _, item := range account.Groups {
- groups = append(groups, item)
- }
-
- return groups, nil
-}
-
// GroupAddPeer appends peer to the group
func (am *DefaultAccountManager) GroupAddPeer(ctx context.Context, accountID, groupID, peerID string) error {
unlock := am.Store.AcquireWriteLockByUID(ctx, accountID)
defer unlock()
- account, err := am.Store.GetAccount(ctx, accountID)
+ var group *nbgroup.Group
+ var updateAccountPeers bool
+ var err error
+
+ err = am.Store.ExecuteInTransaction(ctx, func(transaction Store) error {
+ group, err = transaction.GetGroupByID(context.Background(), LockingStrengthUpdate, accountID, groupID)
+ if err != nil {
+ return err
+ }
+
+ if updated := group.AddPeer(peerID); !updated {
+ return nil
+ }
+
+ updateAccountPeers, err = areGroupChangesAffectPeers(ctx, transaction, accountID, []string{groupID})
+ if err != nil {
+ return err
+ }
+
+ if err = transaction.IncrementNetworkSerial(ctx, LockingStrengthUpdate, accountID); err != nil {
+ return err
+ }
+
+ return transaction.SaveGroup(ctx, LockingStrengthUpdate, group)
+ })
if err != nil {
return err
}
- group, ok := account.Groups[groupID]
- if !ok {
- return status.Errorf(status.NotFound, "group with ID %s not found", groupID)
- }
-
- add := true
- for _, itemID := range group.Peers {
- if itemID == peerID {
- add = false
- break
- }
- }
- if add {
- group.Peers = append(group.Peers, peerID)
- }
-
- account.Network.IncSerial()
- if err = am.Store.SaveAccount(ctx, account); err != nil {
- return err
- }
-
- if areGroupChangesAffectPeers(account, []string{group.ID}) {
- am.updateAccountPeers(ctx, account)
+ if updateAccountPeers {
+ am.updateAccountPeers(ctx, accountID)
}
return nil
@@ -347,90 +320,162 @@ func (am *DefaultAccountManager) GroupDeletePeer(ctx context.Context, accountID,
unlock := am.Store.AcquireWriteLockByUID(ctx, accountID)
defer unlock()
- account, err := am.Store.GetAccount(ctx, accountID)
+ var group *nbgroup.Group
+ var updateAccountPeers bool
+ var err error
+
+ err = am.Store.ExecuteInTransaction(ctx, func(transaction Store) error {
+ group, err = transaction.GetGroupByID(context.Background(), LockingStrengthUpdate, accountID, groupID)
+ if err != nil {
+ return err
+ }
+
+ if updated := group.RemovePeer(peerID); !updated {
+ return nil
+ }
+
+ updateAccountPeers, err = areGroupChangesAffectPeers(ctx, transaction, accountID, []string{groupID})
+ if err != nil {
+ return err
+ }
+
+ if err = transaction.IncrementNetworkSerial(ctx, LockingStrengthUpdate, accountID); err != nil {
+ return err
+ }
+
+ return transaction.SaveGroup(ctx, LockingStrengthUpdate, group)
+ })
if err != nil {
return err
}
- group, ok := account.Groups[groupID]
- if !ok {
- return status.Errorf(status.NotFound, "group with ID %s not found", groupID)
- }
-
- account.Network.IncSerial()
- for i, itemID := range group.Peers {
- if itemID == peerID {
- group.Peers = append(group.Peers[:i], group.Peers[i+1:]...)
- if err := am.Store.SaveAccount(ctx, account); err != nil {
- return err
- }
- }
- }
-
- if areGroupChangesAffectPeers(account, []string{group.ID}) {
- am.updateAccountPeers(ctx, account)
+ if updateAccountPeers {
+ am.updateAccountPeers(ctx, accountID)
}
return nil
}
-func validateDeleteGroup(account *Account, group *nbgroup.Group, userID string) error {
+// validateNewGroup validates the new group for existence and required fields.
+func validateNewGroup(ctx context.Context, transaction Store, accountID string, newGroup *nbgroup.Group) error {
+ if newGroup.ID == "" && newGroup.Issued != nbgroup.GroupIssuedAPI {
+ return status.Errorf(status.InvalidArgument, "%s group without ID set", newGroup.Issued)
+ }
+
+ if newGroup.ID == "" && newGroup.Issued == nbgroup.GroupIssuedAPI {
+ existingGroup, err := transaction.GetGroupByName(ctx, LockingStrengthShare, accountID, newGroup.Name)
+ if err != nil {
+ if s, ok := status.FromError(err); !ok || s.Type() != status.NotFound {
+ return err
+ }
+ }
+
+ // Prevent duplicate groups for API-issued groups.
+ // Integration or JWT groups can be duplicated as they are coming from the IdP that we don't have control of.
+ if existingGroup != nil {
+ return status.Errorf(status.AlreadyExists, "group with name %s already exists", newGroup.Name)
+ }
+
+ newGroup.ID = xid.New().String()
+ }
+
+ for _, peerID := range newGroup.Peers {
+ _, err := transaction.GetPeerByID(ctx, LockingStrengthShare, accountID, peerID)
+ if err != nil {
+ return status.Errorf(status.InvalidArgument, "peer with ID \"%s\" not found", peerID)
+ }
+ }
+
+ return nil
+}
+
+func validateDeleteGroup(ctx context.Context, transaction Store, group *nbgroup.Group, userID string) error {
// disable a deleting integration group if the initiator is not an admin service user
if group.Issued == nbgroup.GroupIssuedIntegration {
- executingUser := account.Users[userID]
- if executingUser == nil {
- return status.Errorf(status.NotFound, "user not found")
+ executingUser, err := transaction.GetUserByUserID(ctx, LockingStrengthShare, userID)
+ if err != nil {
+ return err
}
if executingUser.Role != UserRoleAdmin || !executingUser.IsServiceUser {
return status.Errorf(status.PermissionDenied, "only service users with admin power can delete integration group")
}
}
- if isLinked, linkedRoute := isGroupLinkedToRoute(account.Routes, group.ID); isLinked {
+ if group.IsGroupAll() {
+ return status.Errorf(status.InvalidArgument, "deleting group ALL is not allowed")
+ }
+
+ if isLinked, linkedRoute := isGroupLinkedToRoute(ctx, transaction, group.AccountID, group.ID); isLinked {
return &GroupLinkError{"route", string(linkedRoute.NetID)}
}
- if isLinked, linkedDns := isGroupLinkedToDns(account.NameServerGroups, group.ID); isLinked {
+ if isLinked, linkedDns := isGroupLinkedToDns(ctx, transaction, group.AccountID, group.ID); isLinked {
return &GroupLinkError{"name server groups", linkedDns.Name}
}
- if isLinked, linkedPolicy := isGroupLinkedToPolicy(account.Policies, group.ID); isLinked {
+ if isLinked, linkedPolicy := isGroupLinkedToPolicy(ctx, transaction, group.AccountID, group.ID); isLinked {
return &GroupLinkError{"policy", linkedPolicy.Name}
}
- if isLinked, linkedSetupKey := isGroupLinkedToSetupKey(account.SetupKeys, group.ID); isLinked {
+ if isLinked, linkedSetupKey := isGroupLinkedToSetupKey(ctx, transaction, group.AccountID, group.ID); isLinked {
return &GroupLinkError{"setup key", linkedSetupKey.Name}
}
- if isLinked, linkedUser := isGroupLinkedToUser(account.Users, group.ID); isLinked {
+ if isLinked, linkedUser := isGroupLinkedToUser(ctx, transaction, group.AccountID, group.ID); isLinked {
return &GroupLinkError{"user", linkedUser.Id}
}
- if slices.Contains(account.DNSSettings.DisabledManagementGroups, group.ID) {
+ return checkGroupLinkedToSettings(ctx, transaction, group)
+}
+
+// checkGroupLinkedToSettings verifies if a group is linked to any settings in the account.
+func checkGroupLinkedToSettings(ctx context.Context, transaction Store, group *nbgroup.Group) error {
+ dnsSettings, err := transaction.GetAccountDNSSettings(ctx, LockingStrengthShare, group.AccountID)
+ if err != nil {
+ return err
+ }
+
+ if slices.Contains(dnsSettings.DisabledManagementGroups, group.ID) {
return &GroupLinkError{"disabled DNS management groups", group.Name}
}
- if account.Settings.Extra != nil {
- if slices.Contains(account.Settings.Extra.IntegratedValidatorGroups, group.ID) {
- return &GroupLinkError{"integrated validator", group.Name}
- }
+ settings, err := transaction.GetAccountSettings(ctx, LockingStrengthShare, group.AccountID)
+ if err != nil {
+ return err
+ }
+
+ if settings.Extra != nil && slices.Contains(settings.Extra.IntegratedValidatorGroups, group.ID) {
+ return &GroupLinkError{"integrated validator", group.Name}
}
return nil
}
// isGroupLinkedToRoute checks if a group is linked to any route in the account.
-func isGroupLinkedToRoute(routes map[route.ID]*route.Route, groupID string) (bool, *route.Route) {
+func isGroupLinkedToRoute(ctx context.Context, transaction Store, accountID string, groupID string) (bool, *route.Route) {
+ routes, err := transaction.GetAccountRoutes(ctx, LockingStrengthShare, accountID)
+ if err != nil {
+ log.WithContext(ctx).Errorf("error retrieving routes while checking group linkage: %v", err)
+ return false, nil
+ }
+
for _, r := range routes {
if slices.Contains(r.Groups, groupID) || slices.Contains(r.PeerGroups, groupID) {
return true, r
}
}
+
return false, nil
}
// isGroupLinkedToPolicy checks if a group is linked to any policy in the account.
-func isGroupLinkedToPolicy(policies []*Policy, groupID string) (bool, *Policy) {
+func isGroupLinkedToPolicy(ctx context.Context, transaction Store, accountID string, groupID string) (bool, *Policy) {
+ policies, err := transaction.GetAccountPolicies(ctx, LockingStrengthShare, accountID)
+ if err != nil {
+ log.WithContext(ctx).Errorf("error retrieving policies while checking group linkage: %v", err)
+ return false, nil
+ }
+
for _, policy := range policies {
for _, rule := range policy.Rules {
if slices.Contains(rule.Sources, groupID) || slices.Contains(rule.Destinations, groupID) {
@@ -442,7 +487,13 @@ func isGroupLinkedToPolicy(policies []*Policy, groupID string) (bool, *Policy) {
}
// isGroupLinkedToDns checks if a group is linked to any nameserver group in the account.
-func isGroupLinkedToDns(nameServerGroups map[string]*nbdns.NameServerGroup, groupID string) (bool, *nbdns.NameServerGroup) {
+func isGroupLinkedToDns(ctx context.Context, transaction Store, accountID string, groupID string) (bool, *nbdns.NameServerGroup) {
+ nameServerGroups, err := transaction.GetAccountNameServerGroups(ctx, LockingStrengthShare, accountID)
+ if err != nil {
+ log.WithContext(ctx).Errorf("error retrieving name server groups while checking group linkage: %v", err)
+ return false, nil
+ }
+
for _, dns := range nameServerGroups {
for _, g := range dns.Groups {
if g == groupID {
@@ -450,11 +501,18 @@ func isGroupLinkedToDns(nameServerGroups map[string]*nbdns.NameServerGroup, grou
}
}
}
+
return false, nil
}
// isGroupLinkedToSetupKey checks if a group is linked to any setup key in the account.
-func isGroupLinkedToSetupKey(setupKeys map[string]*SetupKey, groupID string) (bool, *SetupKey) {
+func isGroupLinkedToSetupKey(ctx context.Context, transaction Store, accountID string, groupID string) (bool, *SetupKey) {
+ setupKeys, err := transaction.GetAccountSetupKeys(ctx, LockingStrengthShare, accountID)
+ if err != nil {
+ log.WithContext(ctx).Errorf("error retrieving setup keys while checking group linkage: %v", err)
+ return false, nil
+ }
+
for _, setupKey := range setupKeys {
if slices.Contains(setupKey.AutoGroups, groupID) {
return true, setupKey
@@ -464,7 +522,13 @@ func isGroupLinkedToSetupKey(setupKeys map[string]*SetupKey, groupID string) (bo
}
// isGroupLinkedToUser checks if a group is linked to any user in the account.
-func isGroupLinkedToUser(users map[string]*User, groupID string) (bool, *User) {
+func isGroupLinkedToUser(ctx context.Context, transaction Store, accountID string, groupID string) (bool, *User) {
+ users, err := transaction.GetAccountUsers(ctx, LockingStrengthShare, accountID)
+ if err != nil {
+ log.WithContext(ctx).Errorf("error retrieving users while checking group linkage: %v", err)
+ return false, nil
+ }
+
for _, user := range users {
if slices.Contains(user.AutoGroups, groupID) {
return true, user
@@ -473,6 +537,35 @@ func isGroupLinkedToUser(users map[string]*User, groupID string) (bool, *User) {
return false, nil
}
+// areGroupChangesAffectPeers checks if any changes to the specified groups will affect peers.
+func areGroupChangesAffectPeers(ctx context.Context, transaction Store, accountID string, groupIDs []string) (bool, error) {
+ if len(groupIDs) == 0 {
+ return false, nil
+ }
+
+ dnsSettings, err := transaction.GetAccountDNSSettings(ctx, LockingStrengthShare, accountID)
+ if err != nil {
+ return false, err
+ }
+
+ for _, groupID := range groupIDs {
+ if slices.Contains(dnsSettings.DisabledManagementGroups, groupID) {
+ return true, nil
+ }
+ if linked, _ := isGroupLinkedToDns(ctx, transaction, accountID, groupID); linked {
+ return true, nil
+ }
+ if linked, _ := isGroupLinkedToPolicy(ctx, transaction, accountID, groupID); linked {
+ return true, nil
+ }
+ if linked, _ := isGroupLinkedToRoute(ctx, transaction, accountID, groupID); linked {
+ return true, nil
+ }
+ }
+
+ return false, nil
+}
+
// anyGroupHasPeers checks if any of the given groups in the account have peers.
func anyGroupHasPeers(account *Account, groupIDs []string) bool {
for _, groupID := range groupIDs {
@@ -482,22 +575,3 @@ func anyGroupHasPeers(account *Account, groupIDs []string) bool {
}
return false
}
-
-func areGroupChangesAffectPeers(account *Account, groupIDs []string) bool {
- for _, groupID := range groupIDs {
- if slices.Contains(account.DNSSettings.DisabledManagementGroups, groupID) {
- return true
- }
- if linked, _ := isGroupLinkedToDns(account.NameServerGroups, groupID); linked {
- return true
- }
- if linked, _ := isGroupLinkedToPolicy(account.Policies, groupID); linked {
- return true
- }
- if linked, _ := isGroupLinkedToRoute(account.Routes, groupID); linked {
- return true
- }
- }
-
- return false
-}
diff --git a/management/server/group/group.go b/management/server/group/group.go
index e98e5ecc4..24c60d3ce 100644
--- a/management/server/group/group.go
+++ b/management/server/group/group.go
@@ -54,3 +54,30 @@ func (g *Group) HasPeers() bool {
func (g *Group) IsGroupAll() bool {
return g.Name == "All"
}
+
+// AddPeer adds peerID to Peers if not present, returning true if added.
+func (g *Group) AddPeer(peerID string) bool {
+ if peerID == "" {
+ return false
+ }
+
+ for _, itemID := range g.Peers {
+ if itemID == peerID {
+ return false
+ }
+ }
+
+ g.Peers = append(g.Peers, peerID)
+ return true
+}
+
+// RemovePeer removes peerID from Peers if present, returning true if removed.
+func (g *Group) RemovePeer(peerID string) bool {
+ for i, itemID := range g.Peers {
+ if itemID == peerID {
+ g.Peers = append(g.Peers[:i], g.Peers[i+1:]...)
+ return true
+ }
+ }
+ return false
+}
diff --git a/management/server/group/group_test.go b/management/server/group/group_test.go
new file mode 100644
index 000000000..cb002f8d9
--- /dev/null
+++ b/management/server/group/group_test.go
@@ -0,0 +1,90 @@
+package group
+
+import (
+ "testing"
+
+ "github.com/stretchr/testify/assert"
+)
+
+func TestAddPeer(t *testing.T) {
+ t.Run("add new peer to empty slice", func(t *testing.T) {
+ group := &Group{Peers: []string{}}
+ peerID := "peer1"
+ assert.True(t, group.AddPeer(peerID))
+ assert.Contains(t, group.Peers, peerID)
+ })
+
+ t.Run("add new peer to nil slice", func(t *testing.T) {
+ group := &Group{Peers: nil}
+ peerID := "peer1"
+ assert.True(t, group.AddPeer(peerID))
+ assert.Contains(t, group.Peers, peerID)
+ })
+
+ t.Run("add new peer to non-empty slice", func(t *testing.T) {
+ group := &Group{Peers: []string{"peer1", "peer2"}}
+ peerID := "peer3"
+ assert.True(t, group.AddPeer(peerID))
+ assert.Contains(t, group.Peers, peerID)
+ })
+
+ t.Run("add duplicate peer", func(t *testing.T) {
+ group := &Group{Peers: []string{"peer1", "peer2"}}
+ peerID := "peer1"
+ assert.False(t, group.AddPeer(peerID))
+ assert.Equal(t, 2, len(group.Peers))
+ })
+
+ t.Run("add empty peer", func(t *testing.T) {
+ group := &Group{Peers: []string{"peer1", "peer2"}}
+ peerID := ""
+ assert.False(t, group.AddPeer(peerID))
+ assert.Equal(t, 2, len(group.Peers))
+ })
+}
+
+func TestRemovePeer(t *testing.T) {
+ t.Run("remove existing peer from slice", func(t *testing.T) {
+ group := &Group{Peers: []string{"peer1", "peer2", "peer3"}}
+ peerID := "peer2"
+ assert.True(t, group.RemovePeer(peerID))
+ assert.NotContains(t, group.Peers, peerID)
+ assert.Equal(t, 2, len(group.Peers))
+ })
+
+ t.Run("remove peer from empty slice", func(t *testing.T) {
+ group := &Group{Peers: []string{}}
+ peerID := "peer1"
+ assert.False(t, group.RemovePeer(peerID))
+ assert.Equal(t, 0, len(group.Peers))
+ })
+
+ t.Run("remove peer from nil slice", func(t *testing.T) {
+ group := &Group{Peers: nil}
+ peerID := "peer1"
+ assert.False(t, group.RemovePeer(peerID))
+ assert.Nil(t, group.Peers)
+ })
+
+ t.Run("remove non-existent peer", func(t *testing.T) {
+ group := &Group{Peers: []string{"peer1", "peer2"}}
+ peerID := "peer3"
+ assert.False(t, group.RemovePeer(peerID))
+ assert.Equal(t, 2, len(group.Peers))
+ })
+
+ t.Run("remove peer from single-item slice", func(t *testing.T) {
+ group := &Group{Peers: []string{"peer1"}}
+ peerID := "peer1"
+ assert.True(t, group.RemovePeer(peerID))
+ assert.Equal(t, 0, len(group.Peers))
+ assert.NotContains(t, group.Peers, peerID)
+ })
+
+ t.Run("remove empty peer", func(t *testing.T) {
+ group := &Group{Peers: []string{"peer1", "peer2"}}
+ peerID := ""
+ assert.False(t, group.RemovePeer(peerID))
+ assert.Equal(t, 2, len(group.Peers))
+ })
+}
diff --git a/management/server/group_test.go b/management/server/group_test.go
index 89184e819..59094a23e 100644
--- a/management/server/group_test.go
+++ b/management/server/group_test.go
@@ -208,7 +208,7 @@ func TestDefaultAccountManager_DeleteGroups(t *testing.T) {
{
name: "delete non-existent group",
groupIDs: []string{"non-existent-group"},
- expectedDeleted: []string{"non-existent-group"},
+ expectedReasons: []string{"group: non-existent-group not found"},
},
{
name: "delete multiple groups with mixed results",
diff --git a/management/server/integrated_validator.go b/management/server/integrated_validator.go
index 99e6b204c..0c70b702a 100644
--- a/management/server/integrated_validator.go
+++ b/management/server/integrated_validator.go
@@ -52,25 +52,22 @@ func (am *DefaultAccountManager) UpdateIntegratedValidatorGroups(ctx context.Con
return am.Store.SaveAccount(ctx, a)
}
-func (am *DefaultAccountManager) GroupValidation(ctx context.Context, accountId string, groups []string) (bool, error) {
- if len(groups) == 0 {
+func (am *DefaultAccountManager) GroupValidation(ctx context.Context, accountID string, groupIDs []string) (bool, error) {
+ if len(groupIDs) == 0 {
return true, nil
}
- accountsGroups, err := am.ListGroups(ctx, accountId)
- if err != nil {
- return false, err
- }
- for _, group := range groups {
- var found bool
- for _, accountGroup := range accountsGroups {
- if accountGroup.ID == group {
- found = true
- break
+
+ err := am.Store.ExecuteInTransaction(ctx, func(transaction Store) error {
+ for _, groupID := range groupIDs {
+ _, err := transaction.GetGroupByID(context.Background(), LockingStrengthShare, accountID, groupID)
+ if err != nil {
+ return err
}
}
- if !found {
- return false, nil
- }
+ return nil
+ })
+ if err != nil {
+ return false, err
}
return true, nil
diff --git a/management/server/mock_server/account_mock.go b/management/server/mock_server/account_mock.go
index d7139bb2a..aa6a47b15 100644
--- a/management/server/mock_server/account_mock.go
+++ b/management/server/mock_server/account_mock.go
@@ -45,7 +45,6 @@ type MockAccountManager struct {
SaveGroupsFunc func(ctx context.Context, accountID, userID string, groups []*group.Group) error
DeleteGroupFunc func(ctx context.Context, accountID, userId, groupID string) error
DeleteGroupsFunc func(ctx context.Context, accountId, userId string, groupIDs []string) error
- ListGroupsFunc func(ctx context.Context, accountID string) ([]*group.Group, error)
GroupAddPeerFunc func(ctx context.Context, accountID, groupID, peerID string) error
GroupDeletePeerFunc func(ctx context.Context, accountID, groupID, peerID string) error
DeleteRuleFunc func(ctx context.Context, accountID, ruleID, userID string) error
@@ -354,14 +353,6 @@ func (am *MockAccountManager) DeleteGroups(ctx context.Context, accountId, userI
return status.Errorf(codes.Unimplemented, "method DeleteGroups is not implemented")
}
-// ListGroups mock implementation of ListGroups from server.AccountManager interface
-func (am *MockAccountManager) ListGroups(ctx context.Context, accountID string) ([]*group.Group, error) {
- if am.ListGroupsFunc != nil {
- return am.ListGroupsFunc(ctx, accountID)
- }
- return nil, status.Errorf(codes.Unimplemented, "method ListGroups is not implemented")
-}
-
// GroupAddPeer mock implementation of GroupAddPeer from server.AccountManager interface
func (am *MockAccountManager) GroupAddPeer(ctx context.Context, accountID, groupID, peerID string) error {
if am.GroupAddPeerFunc != nil {
diff --git a/management/server/nameserver.go b/management/server/nameserver.go
index 5ebd263dc..957008714 100644
--- a/management/server/nameserver.go
+++ b/management/server/nameserver.go
@@ -71,7 +71,7 @@ func (am *DefaultAccountManager) CreateNameServerGroup(ctx context.Context, acco
}
if anyGroupHasPeers(account, newNSGroup.Groups) {
- am.updateAccountPeers(ctx, account)
+ am.updateAccountPeers(ctx, accountID)
}
am.StoreEvent(ctx, userID, newNSGroup.ID, accountID, activity.NameserverGroupCreated, newNSGroup.EventMeta())
@@ -106,7 +106,7 @@ func (am *DefaultAccountManager) SaveNameServerGroup(ctx context.Context, accoun
}
if areNameServerGroupChangesAffectPeers(account, nsGroupToSave, oldNSGroup) {
- am.updateAccountPeers(ctx, account)
+ am.updateAccountPeers(ctx, accountID)
}
am.StoreEvent(ctx, userID, nsGroupToSave.ID, accountID, activity.NameserverGroupUpdated, nsGroupToSave.EventMeta())
@@ -136,7 +136,7 @@ func (am *DefaultAccountManager) DeleteNameServerGroup(ctx context.Context, acco
}
if anyGroupHasPeers(account, nsGroup.Groups) {
- am.updateAccountPeers(ctx, account)
+ am.updateAccountPeers(ctx, accountID)
}
am.StoreEvent(ctx, userID, nsGroup.ID, accountID, activity.NameserverGroupDeleted, nsGroup.EventMeta())
diff --git a/management/server/peer.go b/management/server/peer.go
index 87b5b0e4e..1405dead8 100644
--- a/management/server/peer.go
+++ b/management/server/peer.go
@@ -133,7 +133,7 @@ func (am *DefaultAccountManager) MarkPeerConnected(ctx context.Context, peerPubK
if expired {
// we need to update other peers because when peer login expires all other peers are notified to disconnect from
// the expired one. Here we notify them that connection is now allowed again.
- am.updateAccountPeers(ctx, account)
+ am.updateAccountPeers(ctx, account.Id)
}
return nil
@@ -271,7 +271,7 @@ func (am *DefaultAccountManager) UpdatePeer(ctx context.Context, accountID, user
}
if peerLabelUpdated || requiresPeerUpdates {
- am.updateAccountPeers(ctx, account)
+ am.updateAccountPeers(ctx, accountID)
}
return peer, nil
@@ -335,7 +335,10 @@ func (am *DefaultAccountManager) DeletePeer(ctx context.Context, accountID, peer
return err
}
- updateAccountPeers := isPeerInActiveGroup(account, peerID)
+ updateAccountPeers, err := am.isPeerInActiveGroup(ctx, account, peerID)
+ if err != nil {
+ return err
+ }
err = am.deletePeers(ctx, account, []string{peerID}, userID)
if err != nil {
@@ -348,7 +351,7 @@ func (am *DefaultAccountManager) DeletePeer(ctx context.Context, accountID, peer
}
if updateAccountPeers {
- am.updateAccountPeers(ctx, account)
+ am.updateAccountPeers(ctx, accountID)
}
return nil
@@ -555,7 +558,7 @@ func (am *DefaultAccountManager) AddPeer(ctx context.Context, setupKey, userID s
return fmt.Errorf("failed to add peer to account: %w", err)
}
- err = transaction.IncrementNetworkSerial(ctx, accountID)
+ err = transaction.IncrementNetworkSerial(ctx, LockingStrengthUpdate, accountID)
if err != nil {
return fmt.Errorf("failed to increment network serial: %w", err)
}
@@ -598,10 +601,15 @@ func (am *DefaultAccountManager) AddPeer(ctx context.Context, setupKey, userID s
if err != nil {
return nil, nil, nil, fmt.Errorf("error getting all group ID: %w", err)
}
-
groupsToAdd = append(groupsToAdd, allGroup.ID)
- if areGroupChangesAffectPeers(account, groupsToAdd) {
- am.updateAccountPeers(ctx, account)
+
+ newGroupsAffectsPeers, err := areGroupChangesAffectPeers(ctx, am.Store, accountID, groupsToAdd)
+ if err != nil {
+ return nil, nil, nil, err
+ }
+
+ if newGroupsAffectsPeers {
+ am.updateAccountPeers(ctx, accountID)
}
approvedPeersMap, err := am.GetValidatedPeers(account)
@@ -666,7 +674,7 @@ func (am *DefaultAccountManager) SyncPeer(ctx context.Context, sync PeerSync, ac
}
if sync.UpdateAccountPeers {
- am.updateAccountPeers(ctx, account)
+ am.updateAccountPeers(ctx, account.Id)
}
}
@@ -685,7 +693,7 @@ func (am *DefaultAccountManager) SyncPeer(ctx context.Context, sync PeerSync, ac
}
if isStatusChanged {
- am.updateAccountPeers(ctx, account)
+ am.updateAccountPeers(ctx, account.Id)
}
validPeersMap, err := am.GetValidatedPeers(account)
@@ -816,7 +824,7 @@ func (am *DefaultAccountManager) LoginPeer(ctx context.Context, login PeerLogin)
}
if updateRemotePeers || isStatusChanged {
- am.updateAccountPeers(ctx, account)
+ am.updateAccountPeers(ctx, accountID)
}
return am.getValidatedPeerWithMap(ctx, isRequiresApproval, account, peer)
@@ -979,7 +987,7 @@ func (am *DefaultAccountManager) GetPeer(ctx context.Context, accountID, peerID,
// updateAccountPeers updates all peers that belong to an account.
// Should be called when changes have to be synced to peers.
-func (am *DefaultAccountManager) updateAccountPeers(ctx context.Context, account *Account) {
+func (am *DefaultAccountManager) updateAccountPeers(ctx context.Context, accountID string) {
start := time.Now()
defer func() {
if am.metrics != nil {
@@ -987,6 +995,11 @@ func (am *DefaultAccountManager) updateAccountPeers(ctx context.Context, account
}
}()
+ account, err := am.requestBuffer.GetAccountWithBackpressure(ctx, accountID)
+ if err != nil {
+ log.WithContext(ctx).Errorf("failed to send out updates to peers: %v", err)
+ return
+ }
peers := account.GetPeers()
approvedPeersMap, err := am.GetValidatedPeers(account)
@@ -1033,12 +1046,12 @@ func ConvertSliceToMap(existingLabels []string) map[string]struct{} {
// IsPeerInActiveGroup checks if the given peer is part of a group that is used
// in an active DNS, route, or ACL configuration.
-func isPeerInActiveGroup(account *Account, peerID string) bool {
+func (am *DefaultAccountManager) isPeerInActiveGroup(ctx context.Context, account *Account, peerID string) (bool, error) {
peerGroupIDs := make([]string, 0)
for _, group := range account.Groups {
if slices.Contains(group.Peers, peerID) {
peerGroupIDs = append(peerGroupIDs, group.ID)
}
}
- return areGroupChangesAffectPeers(account, peerGroupIDs)
+ return areGroupChangesAffectPeers(ctx, am.Store, account.Id, peerGroupIDs)
}
diff --git a/management/server/peer_test.go b/management/server/peer_test.go
index 78885ea1b..4e2dcb2c3 100644
--- a/management/server/peer_test.go
+++ b/management/server/peer_test.go
@@ -877,7 +877,7 @@ func BenchmarkUpdateAccountPeers(b *testing.B) {
start := time.Now()
for i := 0; i < b.N; i++ {
- manager.updateAccountPeers(ctx, account)
+ manager.updateAccountPeers(ctx, account.Id)
}
duration := time.Since(start)
diff --git a/management/server/policy.go b/management/server/policy.go
index 43a925f88..8a5733f01 100644
--- a/management/server/policy.go
+++ b/management/server/policy.go
@@ -377,7 +377,7 @@ func (am *DefaultAccountManager) SavePolicy(ctx context.Context, accountID, user
am.StoreEvent(ctx, userID, policy.ID, accountID, action, policy.EventMeta())
if updateAccountPeers {
- am.updateAccountPeers(ctx, account)
+ am.updateAccountPeers(ctx, accountID)
}
return nil
@@ -406,7 +406,7 @@ func (am *DefaultAccountManager) DeletePolicy(ctx context.Context, accountID, po
am.StoreEvent(ctx, userID, policy.ID, accountID, activity.PolicyRemoved, policy.EventMeta())
if anyGroupHasPeers(account, policy.ruleGroups()) {
- am.updateAccountPeers(ctx, account)
+ am.updateAccountPeers(ctx, accountID)
}
return nil
diff --git a/management/server/posture_checks.go b/management/server/posture_checks.go
index 2dccd8f59..096cff3f5 100644
--- a/management/server/posture_checks.go
+++ b/management/server/posture_checks.go
@@ -69,7 +69,7 @@ func (am *DefaultAccountManager) SavePostureChecks(ctx context.Context, accountI
am.StoreEvent(ctx, userID, postureChecks.ID, accountID, action, postureChecks.EventMeta())
if arePostureCheckChangesAffectingPeers(account, postureChecks.ID, exists) {
- am.updateAccountPeers(ctx, account)
+ am.updateAccountPeers(ctx, accountID)
}
return nil
diff --git a/management/server/route.go b/management/server/route.go
index 1cf00b37c..dcf2cb0d3 100644
--- a/management/server/route.go
+++ b/management/server/route.go
@@ -238,7 +238,7 @@ func (am *DefaultAccountManager) CreateRoute(ctx context.Context, accountID stri
}
if isRouteChangeAffectPeers(account, &newRoute) {
- am.updateAccountPeers(ctx, account)
+ am.updateAccountPeers(ctx, accountID)
}
am.StoreEvent(ctx, userID, string(newRoute.ID), accountID, activity.RouteCreated, newRoute.EventMeta())
@@ -324,7 +324,7 @@ func (am *DefaultAccountManager) SaveRoute(ctx context.Context, accountID, userI
}
if isRouteChangeAffectPeers(account, oldRoute) || isRouteChangeAffectPeers(account, routeToSave) {
- am.updateAccountPeers(ctx, account)
+ am.updateAccountPeers(ctx, accountID)
}
am.StoreEvent(ctx, userID, string(routeToSave.ID), accountID, activity.RouteUpdated, routeToSave.EventMeta())
@@ -356,7 +356,7 @@ func (am *DefaultAccountManager) DeleteRoute(ctx context.Context, accountID stri
am.StoreEvent(ctx, userID, string(routy.ID), accountID, activity.RouteRemoved, routy.EventMeta())
if isRouteChangeAffectPeers(account, routy) {
- am.updateAccountPeers(ctx, account)
+ am.updateAccountPeers(ctx, accountID)
}
return nil
diff --git a/management/server/route_test.go b/management/server/route_test.go
index 4893e19b9..5c848f68c 100644
--- a/management/server/route_test.go
+++ b/management/server/route_test.go
@@ -1091,7 +1091,7 @@ func TestGetNetworkMap_RouteSyncPeerGroups(t *testing.T) {
require.NoError(t, err)
assert.Len(t, peer4Routes.Routes, 1, "HA route should have 1 server route")
- groups, err := am.ListGroups(context.Background(), account.Id)
+ groups, err := am.Store.GetAccountGroups(context.Background(), LockingStrengthShare, account.Id)
require.NoError(t, err)
var groupHA1, groupHA2 *nbgroup.Group
for _, group := range groups {
diff --git a/management/server/setupkey.go b/management/server/setupkey.go
index 960532bf9..cae0dfecb 100644
--- a/management/server/setupkey.go
+++ b/management/server/setupkey.go
@@ -453,14 +453,14 @@ func (am *DefaultAccountManager) prepareSetupKeyEvents(ctx context.Context, tran
modifiedGroups := slices.Concat(addedGroups, removedGroups)
groups, err := transaction.GetGroupsByIDs(ctx, LockingStrengthShare, accountID, modifiedGroups)
if err != nil {
- log.WithContext(ctx).Errorf("issue getting groups for setup key events: %v", err)
+ log.WithContext(ctx).Debugf("failed to get groups for setup key events: %v", err)
return nil
}
for _, g := range removedGroups {
group, ok := groups[g]
if !ok {
- log.WithContext(ctx).Debugf("skipped adding group: %s GroupRemovedFromSetupKey activity: %v", g, err)
+ log.WithContext(ctx).Debugf("skipped adding group: %s GroupRemovedFromSetupKey activity: group not found", g)
continue
}
@@ -473,7 +473,7 @@ func (am *DefaultAccountManager) prepareSetupKeyEvents(ctx context.Context, tran
for _, g := range addedGroups {
group, ok := groups[g]
if !ok {
- log.WithContext(ctx).Debugf("skipped adding group: %s GroupAddedToSetupKey activity: %v", g, err)
+ log.WithContext(ctx).Debugf("skipped adding group: %s GroupAddedToSetupKey activity: group not found", g)
continue
}
diff --git a/management/server/sql_store.go b/management/server/sql_store.go
index 9dd3e778d..0ebda6440 100644
--- a/management/server/sql_store.go
+++ b/management/server/sql_store.go
@@ -33,12 +33,13 @@ import (
)
const (
- storeSqliteFileName = "store.db"
- idQueryCondition = "id = ?"
- keyQueryCondition = "key = ?"
- accountAndIDQueryCondition = "account_id = ? and id = ?"
- accountIDCondition = "account_id = ?"
- peerNotFoundFMT = "peer %s not found"
+ storeSqliteFileName = "store.db"
+ idQueryCondition = "id = ?"
+ keyQueryCondition = "key = ?"
+ accountAndIDQueryCondition = "account_id = ? and id = ?"
+ accountAndIDsQueryCondition = "account_id = ? AND id IN ?"
+ accountIDCondition = "account_id = ?"
+ peerNotFoundFMT = "peer %s not found"
)
// SqlStore represents an account storage backed by a Sql DB persisted to disk
@@ -555,9 +556,9 @@ func (s *SqlStore) GetUserByUserID(ctx context.Context, lockStrength LockingStre
return &user, nil
}
-func (s *SqlStore) GetAccountUsers(ctx context.Context, accountID string) ([]*User, error) {
+func (s *SqlStore) GetAccountUsers(ctx context.Context, lockStrength LockingStrength, accountID string) ([]*User, error) {
var users []*User
- result := s.db.Find(&users, accountIDCondition, accountID)
+ result := s.db.Clauses(clause.Locking{Strength: string(lockStrength)}).Find(&users, accountIDCondition, accountID)
if result.Error != nil {
if errors.Is(result.Error, gorm.ErrRecordNotFound) {
return nil, status.Errorf(status.NotFound, "accountID not found: index lookup failed")
@@ -857,7 +858,6 @@ func (s *SqlStore) SaveUserLastLogin(ctx context.Context, accountID, userID stri
if errors.Is(result.Error, gorm.ErrRecordNotFound) {
return status.NewUserNotFoundError(userID)
}
-
return status.NewGetUserFromStoreError()
}
user.LastLogin = lastLogin
@@ -1045,7 +1045,7 @@ func (s *SqlStore) AddPeerToGroup(ctx context.Context, accountId string, peerId
result := s.db.Where(accountAndIDQueryCondition, accountId, groupID).First(&group)
if result.Error != nil {
if errors.Is(result.Error, gorm.ErrRecordNotFound) {
- return status.Errorf(status.NotFound, "group not found for account")
+ return status.NewGroupNotFoundError(groupID)
}
return status.Errorf(status.Internal, "issue finding group: %s", result.Error)
@@ -1079,10 +1079,45 @@ func (s *SqlStore) AddPeerToAccount(ctx context.Context, peer *nbpeer.Peer) erro
return nil
}
-func (s *SqlStore) IncrementNetworkSerial(ctx context.Context, accountId string) error {
- result := s.db.Model(&Account{}).Where(idQueryCondition, accountId).Update("network_serial", gorm.Expr("network_serial + 1"))
+// GetPeerByID retrieves a peer by its ID and account ID.
+func (s *SqlStore) GetPeerByID(ctx context.Context, lockStrength LockingStrength, accountID, peerID string) (*nbpeer.Peer, error) {
+ var peer *nbpeer.Peer
+ result := s.db.Clauses(clause.Locking{Strength: string(lockStrength)}).
+ First(&peer, accountAndIDQueryCondition, accountID, peerID)
if result.Error != nil {
- return status.Errorf(status.Internal, "issue incrementing network serial count: %s", result.Error)
+ if errors.Is(result.Error, gorm.ErrRecordNotFound) {
+ return nil, status.Errorf(status.NotFound, "peer not found")
+ }
+ log.WithContext(ctx).Errorf("failed to get peer from store: %s", result.Error)
+ return nil, status.Errorf(status.Internal, "failed to get peer from store")
+ }
+
+ return peer, nil
+}
+
+// GetPeersByIDs retrieves peers by their IDs and account ID.
+func (s *SqlStore) GetPeersByIDs(ctx context.Context, lockStrength LockingStrength, accountID string, peerIDs []string) (map[string]*nbpeer.Peer, error) {
+ var peers []*nbpeer.Peer
+ result := s.db.Clauses(clause.Locking{Strength: string(lockStrength)}).Find(&peers, accountAndIDsQueryCondition, accountID, peerIDs)
+ if result.Error != nil {
+ log.WithContext(ctx).Errorf("failed to get peers by ID's from the store: %s", result.Error)
+ return nil, status.Errorf(status.Internal, "failed to get peers by ID's from the store")
+ }
+
+ peersMap := make(map[string]*nbpeer.Peer)
+ for _, peer := range peers {
+ peersMap[peer.ID] = peer
+ }
+
+ return peersMap, nil
+}
+
+func (s *SqlStore) IncrementNetworkSerial(ctx context.Context, lockStrength LockingStrength, accountId string) error {
+ result := s.db.Clauses(clause.Locking{Strength: string(lockStrength)}).
+ Model(&Account{}).Where(idQueryCondition, accountId).Update("network_serial", gorm.Expr("network_serial + 1"))
+ if result.Error != nil {
+ log.WithContext(ctx).Errorf("failed to increment network serial count in store: %v", result.Error)
+ return status.Errorf(status.Internal, "failed to increment network serial count in store")
}
return nil
}
@@ -1103,7 +1138,8 @@ func (s *SqlStore) ExecuteInTransaction(ctx context.Context, operation func(stor
func (s *SqlStore) withTx(tx *gorm.DB) Store {
return &SqlStore{
- db: tx,
+ db: tx,
+ storeEngine: s.storeEngine,
}
}
@@ -1155,12 +1191,22 @@ func (s *SqlStore) GetAccountDomainAndCategory(ctx context.Context, lockStrength
}
// GetGroupByID retrieves a group by ID and account ID.
-func (s *SqlStore) GetGroupByID(ctx context.Context, lockStrength LockingStrength, groupID, accountID string) (*nbgroup.Group, error) {
- return getRecordByID[nbgroup.Group](s.db.Preload(clause.Associations), lockStrength, groupID, accountID)
+func (s *SqlStore) GetGroupByID(ctx context.Context, lockStrength LockingStrength, accountID, groupID string) (*nbgroup.Group, error) {
+ var group *nbgroup.Group
+ result := s.db.Clauses(clause.Locking{Strength: string(lockStrength)}).First(&group, accountAndIDQueryCondition, accountID, groupID)
+ if err := result.Error; err != nil {
+ if errors.Is(err, gorm.ErrRecordNotFound) {
+ return nil, status.NewGroupNotFoundError(groupID)
+ }
+ log.WithContext(ctx).Errorf("failed to get group from store: %s", err)
+ return nil, status.Errorf(status.Internal, "failed to get group from store")
+ }
+
+ return group, nil
}
// GetGroupByName retrieves a group by name and account ID.
-func (s *SqlStore) GetGroupByName(ctx context.Context, lockStrength LockingStrength, groupName, accountID string) (*nbgroup.Group, error) {
+func (s *SqlStore) GetGroupByName(ctx context.Context, lockStrength LockingStrength, accountID, groupName string) (*nbgroup.Group, error) {
var group nbgroup.Group
// TODO: This fix is accepted for now, but if we need to handle this more frequently
@@ -1172,12 +1218,13 @@ func (s *SqlStore) GetGroupByName(ctx context.Context, lockStrength LockingStren
query = query.Order("json_array_length(peers) DESC")
}
- result := query.First(&group, "name = ? and account_id = ?", groupName, accountID)
+ result := query.First(&group, "account_id = ? AND name = ?", accountID, groupName)
if err := result.Error; err != nil {
if errors.Is(result.Error, gorm.ErrRecordNotFound) {
- return nil, status.Errorf(status.NotFound, "group not found")
+ return nil, status.NewGroupNotFoundError(groupName)
}
- return nil, status.Errorf(status.Internal, "failed to get group from store: %s", result.Error)
+ log.WithContext(ctx).Errorf("failed to get group by name from store: %v", result.Error)
+ return nil, status.Errorf(status.Internal, "failed to get group by name from store")
}
return &group, nil
}
@@ -1185,7 +1232,7 @@ func (s *SqlStore) GetGroupByName(ctx context.Context, lockStrength LockingStren
// GetGroupsByIDs retrieves groups by their IDs and account ID.
func (s *SqlStore) GetGroupsByIDs(ctx context.Context, lockStrength LockingStrength, accountID string, groupIDs []string) (map[string]*nbgroup.Group, error) {
var groups []*nbgroup.Group
- result := s.db.Clauses(clause.Locking{Strength: string(lockStrength)}).Find(&groups, "account_id = ? AND id in ?", accountID, groupIDs)
+ result := s.db.Clauses(clause.Locking{Strength: string(lockStrength)}).Find(&groups, accountAndIDsQueryCondition, accountID, groupIDs)
if result.Error != nil {
log.WithContext(ctx).Errorf("failed to get groups by ID's from the store: %s", result.Error)
return nil, status.Errorf(status.Internal, "failed to get groups by ID's from the store")
@@ -1203,11 +1250,40 @@ func (s *SqlStore) GetGroupsByIDs(ctx context.Context, lockStrength LockingStren
func (s *SqlStore) SaveGroup(ctx context.Context, lockStrength LockingStrength, group *nbgroup.Group) error {
result := s.db.Clauses(clause.Locking{Strength: string(lockStrength)}).Save(group)
if result.Error != nil {
- return status.Errorf(status.Internal, "failed to save group to store: %v", result.Error)
+ log.WithContext(ctx).Errorf("failed to save group to store: %v", result.Error)
+ return status.Errorf(status.Internal, "failed to save group to store")
}
return nil
}
+// DeleteGroup deletes a group from the database.
+func (s *SqlStore) DeleteGroup(ctx context.Context, lockStrength LockingStrength, accountID, groupID string) error {
+ result := s.db.Clauses(clause.Locking{Strength: string(lockStrength)}).
+ Delete(&nbgroup.Group{}, accountAndIDQueryCondition, accountID, groupID)
+ if err := result.Error; err != nil {
+ log.WithContext(ctx).Errorf("failed to delete group from store: %s", result.Error)
+ return status.Errorf(status.Internal, "failed to delete group from store")
+ }
+
+ if result.RowsAffected == 0 {
+ return status.NewGroupNotFoundError(groupID)
+ }
+
+ return nil
+}
+
+// DeleteGroups deletes groups from the database.
+func (s *SqlStore) DeleteGroups(ctx context.Context, strength LockingStrength, accountID string, groupIDs []string) error {
+ result := s.db.Clauses(clause.Locking{Strength: string(strength)}).
+ Delete(&nbgroup.Group{}, accountAndIDsQueryCondition, accountID, groupIDs)
+ if result.Error != nil {
+ log.WithContext(ctx).Errorf("failed to delete groups from store: %v", result.Error)
+ return status.Errorf(status.Internal, "failed to delete groups from store")
+ }
+
+ return nil
+}
+
// GetAccountPolicies retrieves policies for an account.
func (s *SqlStore) GetAccountPolicies(ctx context.Context, lockStrength LockingStrength, accountID string) ([]*Policy, error) {
return getRecords[*Policy](s.db.Preload(clause.Associations), lockStrength, accountID)
diff --git a/management/server/sql_store_test.go b/management/server/sql_store_test.go
index 3f3b2a453..114da1ee6 100644
--- a/management/server/sql_store_test.go
+++ b/management/server/sql_store_test.go
@@ -14,11 +14,10 @@ import (
"time"
"github.com/google/uuid"
- "github.com/stretchr/testify/assert"
- "github.com/stretchr/testify/require"
-
nbdns "github.com/netbirdio/netbird/dns"
nbgroup "github.com/netbirdio/netbird/management/server/group"
+ "github.com/stretchr/testify/assert"
+ "github.com/stretchr/testify/require"
route2 "github.com/netbirdio/netbird/route"
@@ -1181,7 +1180,7 @@ func TestSqlite_CreateAndGetObjectInTransaction(t *testing.T) {
t.Fatal("failed to save group")
return err
}
- group, err = transaction.GetGroupByID(context.Background(), LockingStrengthUpdate, group.ID, group.AccountID)
+ group, err = transaction.GetGroupByID(context.Background(), LockingStrengthUpdate, group.AccountID, group.ID)
if err != nil {
t.Fatal("failed to get group")
return err
@@ -1201,7 +1200,7 @@ func TestSqlite_GetAccoundUsers(t *testing.T) {
accountID := "bf1c8084-ba50-4ce7-9439-34653001fc3b"
account, err := store.GetAccount(context.Background(), accountID)
require.NoError(t, err)
- users, err := store.GetAccountUsers(context.Background(), accountID)
+ users, err := store.GetAccountUsers(context.Background(), LockingStrengthShare, accountID)
require.NoError(t, err)
require.Len(t, users, len(account.Users))
}
@@ -1260,9 +1259,9 @@ func TestSqlite_GetGroupByName(t *testing.T) {
}
accountID := "bf1c8084-ba50-4ce7-9439-34653001fc3b"
- group, err := store.GetGroupByName(context.Background(), LockingStrengthShare, "All", accountID)
+ group, err := store.GetGroupByName(context.Background(), LockingStrengthShare, accountID, "All")
require.NoError(t, err)
- require.Equal(t, "All", group.Name)
+ require.True(t, group.IsGroupAll())
}
func Test_DeleteSetupKeySuccessfully(t *testing.T) {
@@ -1293,3 +1292,275 @@ func Test_DeleteSetupKeyFailsForNonExistingKey(t *testing.T) {
err = store.DeleteSetupKey(context.Background(), LockingStrengthUpdate, accountID, nonExistingKeyID)
require.Error(t, err)
}
+
+func TestSqlStore_GetGroupsByIDs(t *testing.T) {
+ store, cleanup, err := NewTestStoreFromSQL(context.Background(), "testdata/extended-store.sql", t.TempDir())
+ t.Cleanup(cleanup)
+ require.NoError(t, err)
+
+ accountID := "bf1c8084-ba50-4ce7-9439-34653001fc3b"
+
+ tests := []struct {
+ name string
+ groupIDs []string
+ expectedCount int
+ }{
+ {
+ name: "retrieve existing groups by existing IDs",
+ groupIDs: []string{"cfefqs706sqkneg59g4g", "cfefqs706sqkneg59g3g"},
+ expectedCount: 2,
+ },
+ {
+ name: "empty group IDs list",
+ groupIDs: []string{},
+ expectedCount: 0,
+ },
+ {
+ name: "non-existing group IDs",
+ groupIDs: []string{"nonexistent1", "nonexistent2"},
+ expectedCount: 0,
+ },
+ {
+ name: "mixed existing and non-existing group IDs",
+ groupIDs: []string{"cfefqs706sqkneg59g4g", "nonexistent"},
+ expectedCount: 1,
+ },
+ }
+
+ for _, tt := range tests {
+ t.Run(tt.name, func(t *testing.T) {
+ groups, err := store.GetGroupsByIDs(context.Background(), LockingStrengthShare, accountID, tt.groupIDs)
+ require.NoError(t, err)
+ require.Len(t, groups, tt.expectedCount)
+ })
+ }
+}
+
+func TestSqlStore_SaveGroup(t *testing.T) {
+ store, cleanup, err := NewTestStoreFromSQL(context.Background(), "testdata/extended-store.sql", t.TempDir())
+ t.Cleanup(cleanup)
+ require.NoError(t, err)
+
+ accountID := "bf1c8084-ba50-4ce7-9439-34653001fc3b"
+
+ group := &nbgroup.Group{
+ ID: "group-id",
+ AccountID: accountID,
+ Issued: "api",
+ Peers: []string{"peer1", "peer2"},
+ }
+ err = store.SaveGroup(context.Background(), LockingStrengthUpdate, group)
+ require.NoError(t, err)
+
+ savedGroup, err := store.GetGroupByID(context.Background(), LockingStrengthShare, accountID, "group-id")
+ require.NoError(t, err)
+ require.Equal(t, savedGroup, group)
+}
+
+func TestSqlStore_SaveGroups(t *testing.T) {
+ store, cleanup, err := NewTestStoreFromSQL(context.Background(), "testdata/extended-store.sql", t.TempDir())
+ t.Cleanup(cleanup)
+ require.NoError(t, err)
+
+ accountID := "bf1c8084-ba50-4ce7-9439-34653001fc3b"
+
+ groups := []*nbgroup.Group{
+ {
+ ID: "group-1",
+ AccountID: accountID,
+ Issued: "api",
+ Peers: []string{"peer1", "peer2"},
+ },
+ {
+ ID: "group-2",
+ AccountID: accountID,
+ Issued: "integration",
+ Peers: []string{"peer3", "peer4"},
+ },
+ }
+ err = store.SaveGroups(context.Background(), LockingStrengthUpdate, groups)
+ require.NoError(t, err)
+}
+
+func TestSqlStore_DeleteGroup(t *testing.T) {
+ store, cleanup, err := NewTestStoreFromSQL(context.Background(), "testdata/extended-store.sql", t.TempDir())
+ t.Cleanup(cleanup)
+ require.NoError(t, err)
+
+ accountID := "bf1c8084-ba50-4ce7-9439-34653001fc3b"
+
+ tests := []struct {
+ name string
+ groupID string
+ expectError bool
+ }{
+ {
+ name: "delete existing group",
+ groupID: "cfefqs706sqkneg59g4g",
+ expectError: false,
+ },
+ {
+ name: "delete non-existing group",
+ groupID: "non-existing-group-id",
+ expectError: true,
+ },
+ {
+ name: "delete with empty group ID",
+ groupID: "",
+ expectError: true,
+ },
+ }
+
+ for _, tt := range tests {
+ t.Run(tt.name, func(t *testing.T) {
+ err := store.DeleteGroup(context.Background(), LockingStrengthUpdate, accountID, tt.groupID)
+ if tt.expectError {
+ require.Error(t, err)
+ sErr, ok := status.FromError(err)
+ require.True(t, ok)
+ require.Equal(t, sErr.Type(), status.NotFound)
+ } else {
+ require.NoError(t, err)
+
+ group, err := store.GetGroupByID(context.Background(), LockingStrengthShare, accountID, tt.groupID)
+ require.Error(t, err)
+ require.Nil(t, group)
+ }
+ })
+ }
+}
+
+func TestSqlStore_DeleteGroups(t *testing.T) {
+ store, cleanup, err := NewTestStoreFromSQL(context.Background(), "testdata/extended-store.sql", t.TempDir())
+ t.Cleanup(cleanup)
+ require.NoError(t, err)
+
+ accountID := "bf1c8084-ba50-4ce7-9439-34653001fc3b"
+
+ tests := []struct {
+ name string
+ groupIDs []string
+ expectError bool
+ }{
+ {
+ name: "delete multiple existing groups",
+ groupIDs: []string{"cfefqs706sqkneg59g4g", "cfefqs706sqkneg59g3g"},
+ expectError: false,
+ },
+ {
+ name: "delete non-existing groups",
+ groupIDs: []string{"non-existing-id-1", "non-existing-id-2"},
+ expectError: false,
+ },
+ {
+ name: "delete with empty group IDs list",
+ groupIDs: []string{},
+ expectError: false,
+ },
+ }
+
+ for _, tt := range tests {
+ t.Run(tt.name, func(t *testing.T) {
+ err := store.DeleteGroups(context.Background(), LockingStrengthUpdate, accountID, tt.groupIDs)
+ if tt.expectError {
+ require.Error(t, err)
+ } else {
+ require.NoError(t, err)
+
+ for _, groupID := range tt.groupIDs {
+ group, err := store.GetGroupByID(context.Background(), LockingStrengthShare, accountID, groupID)
+ require.Error(t, err)
+ require.Nil(t, group)
+ }
+ }
+ })
+ }
+}
+
+func TestSqlStore_GetPeerByID(t *testing.T) {
+ store, cleanup, err := NewTestStoreFromSQL(context.Background(), "testdata/store_policy_migrate.sql", t.TempDir())
+ t.Cleanup(cleanup)
+ require.NoError(t, err)
+
+ accountID := "bf1c8084-ba50-4ce7-9439-34653001fc3b"
+ tests := []struct {
+ name string
+ peerID string
+ expectError bool
+ }{
+ {
+ name: "retrieve existing peer",
+ peerID: "cfefqs706sqkneg59g4g",
+ expectError: false,
+ },
+ {
+ name: "retrieve non-existing peer",
+ peerID: "non-existing",
+ expectError: true,
+ },
+ {
+ name: "retrieve with empty peer ID",
+ peerID: "",
+ expectError: true,
+ },
+ }
+
+ for _, tt := range tests {
+ t.Run(tt.name, func(t *testing.T) {
+ peer, err := store.GetPeerByID(context.Background(), LockingStrengthShare, accountID, tt.peerID)
+ if tt.expectError {
+ require.Error(t, err)
+ sErr, ok := status.FromError(err)
+ require.True(t, ok)
+ require.Equal(t, sErr.Type(), status.NotFound)
+ require.Nil(t, peer)
+ } else {
+ require.NoError(t, err)
+ require.NotNil(t, peer)
+ require.Equal(t, tt.peerID, peer.ID)
+ }
+ })
+ }
+}
+
+func TestSqlStore_GetPeersByIDs(t *testing.T) {
+ store, cleanup, err := NewTestStoreFromSQL(context.Background(), "testdata/store_policy_migrate.sql", t.TempDir())
+ t.Cleanup(cleanup)
+ require.NoError(t, err)
+
+ accountID := "bf1c8084-ba50-4ce7-9439-34653001fc3b"
+ tests := []struct {
+ name string
+ peerIDs []string
+ expectedCount int
+ }{
+ {
+ name: "retrieve existing peers by existing IDs",
+ peerIDs: []string{"cfefqs706sqkneg59g4g", "cfeg6sf06sqkneg59g50"},
+ expectedCount: 2,
+ },
+ {
+ name: "empty peer IDs list",
+ peerIDs: []string{},
+ expectedCount: 0,
+ },
+ {
+ name: "non-existing peer IDs",
+ peerIDs: []string{"nonexistent1", "nonexistent2"},
+ expectedCount: 0,
+ },
+ {
+ name: "mixed existing and non-existing peer IDs",
+ peerIDs: []string{"cfeg6sf06sqkneg59g50", "nonexistent"},
+ expectedCount: 1,
+ },
+ }
+
+ for _, tt := range tests {
+ t.Run(tt.name, func(t *testing.T) {
+ peers, err := store.GetPeersByIDs(context.Background(), LockingStrengthShare, accountID, tt.peerIDs)
+ require.NoError(t, err)
+ require.Len(t, peers, tt.expectedCount)
+ })
+ }
+}
diff --git a/management/server/status/error.go b/management/server/status/error.go
index f1f3f16e6..8b6d0077b 100644
--- a/management/server/status/error.go
+++ b/management/server/status/error.go
@@ -3,7 +3,6 @@ package status
import (
"errors"
"fmt"
- "time"
)
const (
@@ -126,11 +125,6 @@ func NewAdminPermissionError() error {
return Errorf(PermissionDenied, "admin role required to perform this action")
}
-// NewStoreContextCanceledError creates a new Error with Internal type for a canceled store context
-func NewStoreContextCanceledError(duration time.Duration) error {
- return Errorf(Internal, "store access: context canceled after %v", duration)
-}
-
// NewInvalidKeyIDError creates a new Error with InvalidArgument type for an issue getting a setup key
func NewInvalidKeyIDError() error {
return Errorf(InvalidArgument, "invalid key ID")
@@ -140,3 +134,8 @@ func NewInvalidKeyIDError() error {
func NewGetAccountError(err error) error {
return Errorf(Internal, "error getting account: %s", err)
}
+
+// NewGroupNotFoundError creates a new Error with NotFound type for a missing group
+func NewGroupNotFoundError(groupID string) error {
+ return Errorf(NotFound, "group: %s not found", groupID)
+}
diff --git a/management/server/store.go b/management/server/store.go
index c7c066629..71b0d457b 100644
--- a/management/server/store.go
+++ b/management/server/store.go
@@ -62,7 +62,7 @@ type Store interface {
GetUserByTokenID(ctx context.Context, tokenID string) (*User, error)
GetUserByUserID(ctx context.Context, lockStrength LockingStrength, userID string) (*User, error)
- GetAccountUsers(ctx context.Context, accountID string) ([]*User, error)
+ GetAccountUsers(ctx context.Context, lockStrength LockingStrength, accountID string) ([]*User, error)
SaveUsers(accountID string, users map[string]*User) error
SaveUser(ctx context.Context, lockStrength LockingStrength, user *User) error
SaveUserLastLogin(ctx context.Context, accountID, userID string, lastLogin time.Time) error
@@ -76,6 +76,8 @@ type Store interface {
GetGroupsByIDs(ctx context.Context, lockStrength LockingStrength, accountID string, groupIDs []string) (map[string]*nbgroup.Group, error)
SaveGroups(ctx context.Context, lockStrength LockingStrength, groups []*nbgroup.Group) error
SaveGroup(ctx context.Context, lockStrength LockingStrength, group *nbgroup.Group) error
+ DeleteGroup(ctx context.Context, lockStrength LockingStrength, accountID, groupID string) error
+ DeleteGroups(ctx context.Context, strength LockingStrength, accountID string, groupIDs []string) error
GetAccountPolicies(ctx context.Context, lockStrength LockingStrength, accountID string) ([]*Policy, error)
GetPolicyByID(ctx context.Context, lockStrength LockingStrength, policyID string, accountID string) (*Policy, error)
@@ -90,6 +92,8 @@ type Store interface {
AddPeerToAccount(ctx context.Context, peer *nbpeer.Peer) error
GetPeerByPeerPubKey(ctx context.Context, lockStrength LockingStrength, peerKey string) (*nbpeer.Peer, error)
GetUserPeers(ctx context.Context, lockStrength LockingStrength, accountID, userID string) ([]*nbpeer.Peer, error)
+ GetPeerByID(ctx context.Context, lockStrength LockingStrength, accountID string, peerID string) (*nbpeer.Peer, error)
+ GetPeersByIDs(ctx context.Context, lockStrength LockingStrength, accountID string, peerIDs []string) (map[string]*nbpeer.Peer, error)
SavePeer(ctx context.Context, accountID string, peer *nbpeer.Peer) error
SavePeerStatus(accountID, peerID string, status nbpeer.PeerStatus) error
SavePeerLocation(accountID string, peer *nbpeer.Peer) error
@@ -108,7 +112,7 @@ type Store interface {
GetNameServerGroupByID(ctx context.Context, lockStrength LockingStrength, nameServerGroupID string, accountID string) (*dns.NameServerGroup, error)
GetTakenIPs(ctx context.Context, lockStrength LockingStrength, accountId string) ([]net.IP, error)
- IncrementNetworkSerial(ctx context.Context, accountId string) error
+ IncrementNetworkSerial(ctx context.Context, lockStrength LockingStrength, accountId string) error
GetAccountNetwork(ctx context.Context, lockStrength LockingStrength, accountId string) (*Network, error)
GetInstallationID() string
diff --git a/management/server/user.go b/management/server/user.go
index 5e0d9d034..74062112a 100644
--- a/management/server/user.go
+++ b/management/server/user.go
@@ -494,7 +494,7 @@ func (am *DefaultAccountManager) deleteRegularUser(ctx context.Context, account
am.StoreEvent(ctx, initiatorUserID, targetUserID, account.Id, activity.UserDeleted, meta)
if updateAccountPeers {
- am.updateAccountPeers(ctx, account)
+ am.updateAccountPeers(ctx, account.Id)
}
return nil
@@ -835,7 +835,7 @@ func (am *DefaultAccountManager) SaveOrAddUsers(ctx context.Context, accountID,
}
if account.Settings.GroupsPropagationEnabled && areUsersLinkedToPeers(account, userIDs) {
- am.updateAccountPeers(ctx, account)
+ am.updateAccountPeers(ctx, account.Id)
}
for _, storeEvent := range eventsToStore {
@@ -1132,7 +1132,7 @@ func (am *DefaultAccountManager) expireAndUpdatePeers(ctx context.Context, accou
if len(peerIDs) != 0 {
// this will trigger peer disconnect from the management service
am.peersUpdateManager.CloseChannels(ctx, peerIDs)
- am.updateAccountPeers(ctx, account)
+ am.updateAccountPeers(ctx, account.Id)
}
return nil
}
@@ -1240,7 +1240,7 @@ func (am *DefaultAccountManager) DeleteRegularUsers(ctx context.Context, account
}
if updateAccountPeers {
- am.updateAccountPeers(ctx, account)
+ am.updateAccountPeers(ctx, accountID)
}
for targetUserID, meta := range deletedUsersMeta {
From a1c5287b7c7a6152cb3cff319dea278ba1cfd3c1 Mon Sep 17 00:00:00 2001
From: =?UTF-8?q?=C4=B0smail?=
Date: Fri, 15 Nov 2024 20:21:27 +0300
Subject: [PATCH 13/39] Fix the Inactivity Expiration problem. (#2865)
---
management/server/account.go | 29 +++++++++++++++++------------
1 file changed, 17 insertions(+), 12 deletions(-)
diff --git a/management/server/account.go b/management/server/account.go
index 1bd8a99a9..95c93a22b 100644
--- a/management/server/account.go
+++ b/management/server/account.go
@@ -1186,20 +1186,25 @@ func (am *DefaultAccountManager) UpdateAccountSettings(ctx context.Context, acco
}
func (am *DefaultAccountManager) handleInactivityExpirationSettings(ctx context.Context, account *Account, oldSettings, newSettings *Settings, userID, accountID string) error {
- if oldSettings.PeerInactivityExpirationEnabled != newSettings.PeerInactivityExpirationEnabled {
- event := activity.AccountPeerInactivityExpirationEnabled
- if !newSettings.PeerInactivityExpirationEnabled {
- event = activity.AccountPeerInactivityExpirationDisabled
- am.peerInactivityExpiry.Cancel(ctx, []string{accountID})
- } else {
+
+ if newSettings.PeerInactivityExpirationEnabled {
+ if oldSettings.PeerInactivityExpiration != newSettings.PeerInactivityExpiration {
+ oldSettings.PeerInactivityExpiration = newSettings.PeerInactivityExpiration
+
+ am.StoreEvent(ctx, userID, accountID, accountID, activity.AccountPeerInactivityExpirationDurationUpdated, nil)
am.checkAndSchedulePeerInactivityExpiration(ctx, account)
}
- am.StoreEvent(ctx, userID, accountID, accountID, event, nil)
- }
-
- if oldSettings.PeerInactivityExpiration != newSettings.PeerInactivityExpiration {
- am.StoreEvent(ctx, userID, accountID, accountID, activity.AccountPeerInactivityExpirationDurationUpdated, nil)
- am.checkAndSchedulePeerInactivityExpiration(ctx, account)
+ } else {
+ if oldSettings.PeerInactivityExpirationEnabled != newSettings.PeerInactivityExpirationEnabled {
+ event := activity.AccountPeerInactivityExpirationEnabled
+ if !newSettings.PeerInactivityExpirationEnabled {
+ event = activity.AccountPeerInactivityExpirationDisabled
+ am.peerInactivityExpiry.Cancel(ctx, []string{accountID})
+ } else {
+ am.checkAndSchedulePeerInactivityExpiration(ctx, account)
+ }
+ am.StoreEvent(ctx, userID, accountID, accountID, event, nil)
+ }
}
return nil
From 121dfda915cbaa17e8a16af1f96a71927fc0ece1 Mon Sep 17 00:00:00 2001
From: Viktor Liu <17948409+lixmal@users.noreply.github.com>
Date: Fri, 15 Nov 2024 20:05:26 +0100
Subject: [PATCH 14/39] [client] Fix state manager race conditions (#2890)
---
client/internal/dns/server.go | 20 +++----
client/internal/engine.go | 2 +-
.../routemanager/refcounter/refcounter.go | 58 +++++++++----------
.../internal/routemanager/systemops/state.go | 23 ++++----
.../systemops/systemops_generic.go | 20 +------
client/internal/statemanager/manager.go | 40 +++++++++----
util/file.go | 55 +++++++++++++-----
7 files changed, 118 insertions(+), 100 deletions(-)
diff --git a/client/internal/dns/server.go b/client/internal/dns/server.go
index 6c4dccae7..f0277319c 100644
--- a/client/internal/dns/server.go
+++ b/client/internal/dns/server.go
@@ -7,7 +7,6 @@ import (
"runtime"
"strings"
"sync"
- "time"
"github.com/miekg/dns"
"github.com/mitchellh/hashstructure/v2"
@@ -323,13 +322,9 @@ func (s *DefaultServer) applyConfiguration(update nbdns.Config) error {
log.Error(err)
}
- // persist dns state right away
- ctx, cancel := context.WithTimeout(s.ctx, 3*time.Second)
- defer cancel()
-
- // don't block
go func() {
- if err := s.stateManager.PersistState(ctx); err != nil {
+ // persist dns state right away
+ if err := s.stateManager.PersistState(s.ctx); err != nil {
log.Errorf("Failed to persist dns state: %v", err)
}
}()
@@ -537,12 +532,11 @@ func (s *DefaultServer) upstreamCallbacks(
l.Errorf("Failed to apply nameserver deactivation on the host: %v", err)
}
- // persist dns state right away
- ctx, cancel := context.WithTimeout(s.ctx, 3*time.Second)
- defer cancel()
- if err := s.stateManager.PersistState(ctx); err != nil {
- l.Errorf("Failed to persist dns state: %v", err)
- }
+ go func() {
+ if err := s.stateManager.PersistState(s.ctx); err != nil {
+ l.Errorf("Failed to persist dns state: %v", err)
+ }
+ }()
if runtime.GOOS == "android" && nsGroup.Primary && len(s.hostsDNSHolder.get()) > 0 {
s.addHostRootZone()
diff --git a/client/internal/engine.go b/client/internal/engine.go
index d4a3a561a..1c912220c 100644
--- a/client/internal/engine.go
+++ b/client/internal/engine.go
@@ -297,7 +297,7 @@ func (e *Engine) Stop() error {
if err := e.stateManager.Stop(ctx); err != nil {
return fmt.Errorf("failed to stop state manager: %w", err)
}
- if err := e.stateManager.PersistState(ctx); err != nil {
+ if err := e.stateManager.PersistState(context.Background()); err != nil {
log.Errorf("failed to persist state: %v", err)
}
diff --git a/client/internal/routemanager/refcounter/refcounter.go b/client/internal/routemanager/refcounter/refcounter.go
index 0e230ef40..f2f0a169d 100644
--- a/client/internal/routemanager/refcounter/refcounter.go
+++ b/client/internal/routemanager/refcounter/refcounter.go
@@ -47,10 +47,9 @@ type RemoveFunc[Key, O any] func(key Key, out O) error
type Counter[Key comparable, I, O any] struct {
// refCountMap keeps track of the reference Ref for keys
refCountMap map[Key]Ref[O]
- refCountMu sync.Mutex
+ mu sync.Mutex
// idMap keeps track of the keys associated with an ID for removal
idMap map[string][]Key
- idMu sync.Mutex
add AddFunc[Key, I, O]
remove RemoveFunc[Key, O]
}
@@ -75,10 +74,8 @@ func New[Key comparable, I, O any](add AddFunc[Key, I, O], remove RemoveFunc[Key
func (rm *Counter[Key, I, O]) LoadData(
existingCounter *Counter[Key, I, O],
) {
- rm.refCountMu.Lock()
- defer rm.refCountMu.Unlock()
- rm.idMu.Lock()
- defer rm.idMu.Unlock()
+ rm.mu.Lock()
+ defer rm.mu.Unlock()
rm.refCountMap = existingCounter.refCountMap
rm.idMap = existingCounter.idMap
@@ -87,8 +84,8 @@ func (rm *Counter[Key, I, O]) LoadData(
// Get retrieves the current reference count and associated data for a key.
// If the key doesn't exist, it returns a zero value Ref and false.
func (rm *Counter[Key, I, O]) Get(key Key) (Ref[O], bool) {
- rm.refCountMu.Lock()
- defer rm.refCountMu.Unlock()
+ rm.mu.Lock()
+ defer rm.mu.Unlock()
ref, ok := rm.refCountMap[key]
return ref, ok
@@ -97,9 +94,13 @@ func (rm *Counter[Key, I, O]) Get(key Key) (Ref[O], bool) {
// Increment increments the reference count for the given key.
// If this is the first reference to the key, the AddFunc is called.
func (rm *Counter[Key, I, O]) Increment(key Key, in I) (Ref[O], error) {
- rm.refCountMu.Lock()
- defer rm.refCountMu.Unlock()
+ rm.mu.Lock()
+ defer rm.mu.Unlock()
+ return rm.increment(key, in)
+}
+
+func (rm *Counter[Key, I, O]) increment(key Key, in I) (Ref[O], error) {
ref := rm.refCountMap[key]
logCallerF("Increasing ref count [%d -> %d] for key %v with In [%v] Out [%v]", ref.Count, ref.Count+1, key, in, ref.Out)
@@ -126,10 +127,10 @@ func (rm *Counter[Key, I, O]) Increment(key Key, in I) (Ref[O], error) {
// IncrementWithID increments the reference count for the given key and groups it under the given ID.
// If this is the first reference to the key, the AddFunc is called.
func (rm *Counter[Key, I, O]) IncrementWithID(id string, key Key, in I) (Ref[O], error) {
- rm.idMu.Lock()
- defer rm.idMu.Unlock()
+ rm.mu.Lock()
+ defer rm.mu.Unlock()
- ref, err := rm.Increment(key, in)
+ ref, err := rm.increment(key, in)
if err != nil {
return ref, fmt.Errorf("with ID: %w", err)
}
@@ -141,9 +142,12 @@ func (rm *Counter[Key, I, O]) IncrementWithID(id string, key Key, in I) (Ref[O],
// Decrement decrements the reference count for the given key.
// If the reference count reaches 0, the RemoveFunc is called.
func (rm *Counter[Key, I, O]) Decrement(key Key) (Ref[O], error) {
- rm.refCountMu.Lock()
- defer rm.refCountMu.Unlock()
+ rm.mu.Lock()
+ defer rm.mu.Unlock()
+ return rm.decrement(key)
+}
+func (rm *Counter[Key, I, O]) decrement(key Key) (Ref[O], error) {
ref, ok := rm.refCountMap[key]
if !ok {
logCallerF("No reference found for key %v", key)
@@ -168,12 +172,12 @@ func (rm *Counter[Key, I, O]) Decrement(key Key) (Ref[O], error) {
// DecrementWithID decrements the reference count for all keys associated with the given ID.
// If the reference count reaches 0, the RemoveFunc is called.
func (rm *Counter[Key, I, O]) DecrementWithID(id string) error {
- rm.idMu.Lock()
- defer rm.idMu.Unlock()
+ rm.mu.Lock()
+ defer rm.mu.Unlock()
var merr *multierror.Error
for _, key := range rm.idMap[id] {
- if _, err := rm.Decrement(key); err != nil {
+ if _, err := rm.decrement(key); err != nil {
merr = multierror.Append(merr, err)
}
}
@@ -184,10 +188,8 @@ func (rm *Counter[Key, I, O]) DecrementWithID(id string) error {
// Flush removes all references and calls RemoveFunc for each key.
func (rm *Counter[Key, I, O]) Flush() error {
- rm.refCountMu.Lock()
- defer rm.refCountMu.Unlock()
- rm.idMu.Lock()
- defer rm.idMu.Unlock()
+ rm.mu.Lock()
+ defer rm.mu.Unlock()
var merr *multierror.Error
for key := range rm.refCountMap {
@@ -206,10 +208,8 @@ func (rm *Counter[Key, I, O]) Flush() error {
// Clear removes all references without calling RemoveFunc.
func (rm *Counter[Key, I, O]) Clear() {
- rm.refCountMu.Lock()
- defer rm.refCountMu.Unlock()
- rm.idMu.Lock()
- defer rm.idMu.Unlock()
+ rm.mu.Lock()
+ defer rm.mu.Unlock()
clear(rm.refCountMap)
clear(rm.idMap)
@@ -217,10 +217,8 @@ func (rm *Counter[Key, I, O]) Clear() {
// MarshalJSON implements the json.Marshaler interface for Counter.
func (rm *Counter[Key, I, O]) MarshalJSON() ([]byte, error) {
- rm.refCountMu.Lock()
- defer rm.refCountMu.Unlock()
- rm.idMu.Lock()
- defer rm.idMu.Unlock()
+ rm.mu.Lock()
+ defer rm.mu.Unlock()
return json.Marshal(struct {
RefCountMap map[Key]Ref[O] `json:"refCountMap"`
diff --git a/client/internal/routemanager/systemops/state.go b/client/internal/routemanager/systemops/state.go
index 425908922..8e158711e 100644
--- a/client/internal/routemanager/systemops/state.go
+++ b/client/internal/routemanager/systemops/state.go
@@ -2,31 +2,28 @@ package systemops
import (
"net/netip"
- "sync"
"github.com/netbirdio/netbird/client/internal/routemanager/refcounter"
)
-type ShutdownState struct {
- Counter *ExclusionCounter `json:"counter,omitempty"`
- mu sync.RWMutex
-}
+type ShutdownState ExclusionCounter
func (s *ShutdownState) Name() string {
return "route_state"
}
func (s *ShutdownState) Cleanup() error {
- s.mu.RLock()
- defer s.mu.RUnlock()
-
- if s.Counter == nil {
- return nil
- }
-
sysops := NewSysOps(nil, nil)
sysops.refCounter = refcounter.New[netip.Prefix, struct{}, Nexthop](nil, sysops.removeFromRouteTable)
- sysops.refCounter.LoadData(s.Counter)
+ sysops.refCounter.LoadData((*ExclusionCounter)(s))
return sysops.refCounter.Flush()
}
+
+func (s *ShutdownState) MarshalJSON() ([]byte, error) {
+ return (*ExclusionCounter)(s).MarshalJSON()
+}
+
+func (s *ShutdownState) UnmarshalJSON(data []byte) error {
+ return (*ExclusionCounter)(s).UnmarshalJSON(data)
+}
diff --git a/client/internal/routemanager/systemops/systemops_generic.go b/client/internal/routemanager/systemops/systemops_generic.go
index 4ff34aa51..f8b3ebbb8 100644
--- a/client/internal/routemanager/systemops/systemops_generic.go
+++ b/client/internal/routemanager/systemops/systemops_generic.go
@@ -62,7 +62,7 @@ func (r *SysOps) setupRefCounter(initAddresses []net.IP, stateManager *statemana
return nexthop, err
},
func(prefix netip.Prefix, nexthop Nexthop) error {
- // remove from state even if we have trouble removing it from the route table
+ // update state even if we have trouble removing it from the route table
// it could be already gone
r.updateState(stateManager)
@@ -75,12 +75,9 @@ func (r *SysOps) setupRefCounter(initAddresses []net.IP, stateManager *statemana
return r.setupHooks(initAddresses)
}
+// updateState updates state on every change so it will be persisted regularly
func (r *SysOps) updateState(stateManager *statemanager.Manager) {
- state := getState(stateManager)
-
- state.Counter = r.refCounter
-
- if err := stateManager.UpdateState(state); err != nil {
+ if err := stateManager.UpdateState((*ShutdownState)(r.refCounter)); err != nil {
log.Errorf("failed to update state: %v", err)
}
}
@@ -532,14 +529,3 @@ func isVpnRoute(addr netip.Addr, vpnRoutes []netip.Prefix, localRoutes []netip.P
// Return true if the longest matching prefix is from vpnRoutes
return isVpn, longestPrefix
}
-
-func getState(stateManager *statemanager.Manager) *ShutdownState {
- var shutdownState *ShutdownState
- if state := stateManager.GetState(shutdownState); state != nil {
- shutdownState = state.(*ShutdownState)
- } else {
- shutdownState = &ShutdownState{}
- }
-
- return shutdownState
-}
diff --git a/client/internal/statemanager/manager.go b/client/internal/statemanager/manager.go
index 580ccdfc7..da6dd022f 100644
--- a/client/internal/statemanager/manager.go
+++ b/client/internal/statemanager/manager.go
@@ -74,15 +74,15 @@ func (m *Manager) Stop(ctx context.Context) error {
m.mu.Lock()
defer m.mu.Unlock()
- if m.cancel != nil {
- m.cancel()
+ if m.cancel == nil {
+ return nil
+ }
+ m.cancel()
- select {
- case <-ctx.Done():
- return ctx.Err()
- case <-m.done:
- return nil
- }
+ select {
+ case <-ctx.Done():
+ return ctx.Err()
+ case <-m.done:
}
return nil
@@ -179,14 +179,18 @@ func (m *Manager) PersistState(ctx context.Context) error {
return nil
}
+ bs, err := marshalWithPanicRecovery(m.states)
+ if err != nil {
+ return fmt.Errorf("marshal states: %w", err)
+ }
+
ctx, cancel := context.WithTimeout(ctx, 5*time.Second)
defer cancel()
done := make(chan error, 1)
-
start := time.Now()
go func() {
- done <- util.WriteJsonWithRestrictedPermission(ctx, m.filePath, m.states)
+ done <- util.WriteBytesWithRestrictedPermission(ctx, m.filePath, bs)
}()
select {
@@ -286,3 +290,19 @@ func (m *Manager) PerformCleanup() error {
return nberrors.FormatErrorOrNil(merr)
}
+
+func marshalWithPanicRecovery(v any) ([]byte, error) {
+ var bs []byte
+ var err error
+
+ func() {
+ defer func() {
+ if r := recover(); r != nil {
+ err = fmt.Errorf("panic during marshal: %v", r)
+ }
+ }()
+ bs, err = json.Marshal(v)
+ }()
+
+ return bs, err
+}
diff --git a/util/file.go b/util/file.go
index 4641cc1b8..f7de7ede2 100644
--- a/util/file.go
+++ b/util/file.go
@@ -4,6 +4,7 @@ import (
"bytes"
"context"
"encoding/json"
+ "errors"
"fmt"
"io"
"os"
@@ -14,6 +15,19 @@ import (
log "github.com/sirupsen/logrus"
)
+func WriteBytesWithRestrictedPermission(ctx context.Context, file string, bs []byte) error {
+ configDir, configFileName, err := prepareConfigFileDir(file)
+ if err != nil {
+ return fmt.Errorf("prepare config file dir: %w", err)
+ }
+
+ if err = EnforcePermission(file); err != nil {
+ return fmt.Errorf("enforce permission: %w", err)
+ }
+
+ return writeBytes(ctx, file, err, configDir, configFileName, bs)
+}
+
// WriteJsonWithRestrictedPermission writes JSON config object to a file. Enforces permission on the parent directory
func WriteJsonWithRestrictedPermission(ctx context.Context, file string, obj interface{}) error {
configDir, configFileName, err := prepareConfigFileDir(file)
@@ -82,29 +96,44 @@ func DirectWriteJson(ctx context.Context, file string, obj interface{}) error {
func writeJson(ctx context.Context, file string, obj interface{}, configDir string, configFileName string) error {
// Check context before expensive operations
if ctx.Err() != nil {
- return ctx.Err()
+ return fmt.Errorf("write json start: %w", ctx.Err())
}
// make it pretty
bs, err := json.MarshalIndent(obj, "", " ")
if err != nil {
- return err
+ return fmt.Errorf("marshal: %w", err)
}
+ return writeBytes(ctx, file, err, configDir, configFileName, bs)
+}
+
+func writeBytes(ctx context.Context, file string, err error, configDir string, configFileName string, bs []byte) error {
if ctx.Err() != nil {
- return ctx.Err()
+ return fmt.Errorf("write bytes start: %w", ctx.Err())
}
tempFile, err := os.CreateTemp(configDir, ".*"+configFileName)
if err != nil {
- return err
+ return fmt.Errorf("create temp: %w", err)
}
tempFileName := tempFile.Name()
- // closing file ops as windows doesn't allow to move it
- err = tempFile.Close()
+
+ if deadline, ok := ctx.Deadline(); ok {
+ if err := tempFile.SetDeadline(deadline); err != nil && !errors.Is(err, os.ErrNoDeadline) {
+ log.Warnf("failed to set deadline: %v", err)
+ }
+ }
+
+ _, err = tempFile.Write(bs)
if err != nil {
- return err
+ _ = tempFile.Close()
+ return fmt.Errorf("write: %w", err)
+ }
+
+ if err = tempFile.Close(); err != nil {
+ return fmt.Errorf("close %s: %w", tempFileName, err)
}
defer func() {
@@ -114,19 +143,13 @@ func writeJson(ctx context.Context, file string, obj interface{}, configDir stri
}
}()
- err = os.WriteFile(tempFileName, bs, 0600)
- if err != nil {
- return err
- }
-
// Check context again
if ctx.Err() != nil {
- return ctx.Err()
+ return fmt.Errorf("after temp file: %w", ctx.Err())
}
- err = os.Rename(tempFileName, file)
- if err != nil {
- return err
+ if err = os.Rename(tempFileName, file); err != nil {
+ return fmt.Errorf("move %s to %s: %w", tempFileName, file, err)
}
return nil
From 582bb587140789884c8905e7ae8c8c3e732077dc Mon Sep 17 00:00:00 2001
From: Viktor Liu <17948409+lixmal@users.noreply.github.com>
Date: Fri, 15 Nov 2024 22:55:33 +0100
Subject: [PATCH 15/39] Move state updates outside the refcounter (#2897)
---
.../systemops/systemops_generic.go | 18 +++++++-----------
1 file changed, 7 insertions(+), 11 deletions(-)
diff --git a/client/internal/routemanager/systemops/systemops_generic.go b/client/internal/routemanager/systemops/systemops_generic.go
index f8b3ebbb8..3038c3ec5 100644
--- a/client/internal/routemanager/systemops/systemops_generic.go
+++ b/client/internal/routemanager/systemops/systemops_generic.go
@@ -57,22 +57,14 @@ func (r *SysOps) setupRefCounter(initAddresses []net.IP, stateManager *statemana
return nexthop, refcounter.ErrIgnore
}
- r.updateState(stateManager)
-
return nexthop, err
},
- func(prefix netip.Prefix, nexthop Nexthop) error {
- // update state even if we have trouble removing it from the route table
- // it could be already gone
- r.updateState(stateManager)
-
- return r.removeFromRouteTable(prefix, nexthop)
- },
+ r.removeFromRouteTable,
)
r.refCounter = refCounter
- return r.setupHooks(initAddresses)
+ return r.setupHooks(initAddresses, stateManager)
}
// updateState updates state on every change so it will be persisted regularly
@@ -333,7 +325,7 @@ func (r *SysOps) genericRemoveVPNRoute(prefix netip.Prefix, intf *net.Interface)
return r.removeFromRouteTable(prefix, nextHop)
}
-func (r *SysOps) setupHooks(initAddresses []net.IP) (nbnet.AddHookFunc, nbnet.RemoveHookFunc, error) {
+func (r *SysOps) setupHooks(initAddresses []net.IP, stateManager *statemanager.Manager) (nbnet.AddHookFunc, nbnet.RemoveHookFunc, error) {
beforeHook := func(connID nbnet.ConnectionID, ip net.IP) error {
prefix, err := util.GetPrefixFromIP(ip)
if err != nil {
@@ -344,6 +336,8 @@ func (r *SysOps) setupHooks(initAddresses []net.IP) (nbnet.AddHookFunc, nbnet.Re
return fmt.Errorf("adding route reference: %v", err)
}
+ r.updateState(stateManager)
+
return nil
}
afterHook := func(connID nbnet.ConnectionID) error {
@@ -351,6 +345,8 @@ func (r *SysOps) setupHooks(initAddresses []net.IP) (nbnet.AddHookFunc, nbnet.Re
return fmt.Errorf("remove route reference: %w", err)
}
+ r.updateState(stateManager)
+
return nil
}
From a7d5c522033beef9174898b5e43a907b282f7341 Mon Sep 17 00:00:00 2001
From: Viktor Liu <17948409+lixmal@users.noreply.github.com>
Date: Fri, 15 Nov 2024 22:59:49 +0100
Subject: [PATCH 16/39] Fix error state race on mgmt connection error (#2892)
---
client/internal/connect.go | 3 ++-
1 file changed, 2 insertions(+), 1 deletion(-)
diff --git a/client/internal/connect.go b/client/internal/connect.go
index dff44f1d2..f76aa066b 100644
--- a/client/internal/connect.go
+++ b/client/internal/connect.go
@@ -157,7 +157,8 @@ func (c *ConnectClient) run(mobileDependency MobileDependency, probes *ProbeHold
engineCtx, cancel := context.WithCancel(c.ctx)
defer func() {
- c.statusRecorder.MarkManagementDisconnected(state.err)
+ _, err := state.Status()
+ c.statusRecorder.MarkManagementDisconnected(err)
c.statusRecorder.CleanLocalPeerState()
cancel()
}()
From ec543f89fb819b4aae28850b370f0c06f05f7f96 Mon Sep 17 00:00:00 2001
From: Kursat Aktas
Date: Sat, 16 Nov 2024 17:45:31 +0300
Subject: [PATCH 17/39] Introducing NetBird Guru on Gurubase.io (#2778)
---
README.md | 4 ++++
1 file changed, 4 insertions(+)
diff --git a/README.md b/README.md
index 270c9ad87..a2d7f3897 100644
--- a/README.md
+++ b/README.md
@@ -19,6 +19,10 @@
+
+
+
+
From 65a94f695f09063d505f49a6b2496f7a2e4d48b4 Mon Sep 17 00:00:00 2001
From: Maycon Santos
Date: Mon, 18 Nov 2024 12:55:02 +0100
Subject: [PATCH 18/39] use google domain for tests (#2902)
---
client/internal/dns/server_test.go | 4 ++--
1 file changed, 2 insertions(+), 2 deletions(-)
diff --git a/client/internal/dns/server_test.go b/client/internal/dns/server_test.go
index 21f1f1b7d..eab9f4ecb 100644
--- a/client/internal/dns/server_test.go
+++ b/client/internal/dns/server_test.go
@@ -782,7 +782,7 @@ func TestDNSPermanent_matchOnly(t *testing.T) {
Port: 53,
},
},
- Domains: []string{"customdomain.com"},
+ Domains: []string{"google.com"},
Primary: false,
},
},
@@ -804,7 +804,7 @@ func TestDNSPermanent_matchOnly(t *testing.T) {
if ips[0] != zoneRecords[0].RData {
t.Fatalf("invalid zone record: %v", err)
}
- _, err = resolver.LookupHost(context.Background(), "customdomain.com")
+ _, err = resolver.LookupHost(context.Background(), "google.com")
if err != nil {
t.Errorf("failed to resolve: %s", err)
}
From ec6438e643c3ef0b0388c05c2fad5beed99fb4fe Mon Sep 17 00:00:00 2001
From: bcmmbaga
Date: Mon, 18 Nov 2024 17:12:13 +0300
Subject: [PATCH 19/39] Use update strength and simplify check
Signed-off-by: bcmmbaga
---
management/server/policy.go | 4 +---
1 file changed, 1 insertion(+), 3 deletions(-)
diff --git a/management/server/policy.go b/management/server/policy.go
index 6dcb96316..693ae2872 100644
--- a/management/server/policy.go
+++ b/management/server/policy.go
@@ -435,7 +435,7 @@ func (am *DefaultAccountManager) DeletePolicy(ctx context.Context, accountID, po
var updateAccountPeers bool
err = am.Store.ExecuteInTransaction(ctx, func(transaction Store) error {
- policy, err = transaction.GetPolicyByID(ctx, LockingStrengthShare, accountID, policyID)
+ policy, err = transaction.GetPolicyByID(ctx, LockingStrengthUpdate, accountID, policyID)
if err != nil {
return err
}
@@ -502,8 +502,6 @@ func arePolicyChangesAffectPeers(ctx context.Context, transaction Store, account
if hasPeers {
return true, nil
}
-
- return anyGroupHasPeers(ctx, transaction, policy.AccountID, policy.ruleGroups())
}
return anyGroupHasPeers(ctx, transaction, policy.AccountID, policy.ruleGroups())
From 78fab877c07ee50aae95ed037218ec41f0bf2489 Mon Sep 17 00:00:00 2001
From: Maycon Santos
Date: Mon, 18 Nov 2024 15:31:53 +0100
Subject: [PATCH 20/39] [misc] Update signing pipeline version (#2900)
---
.github/workflows/release.yml | 2 +-
1 file changed, 1 insertion(+), 1 deletion(-)
diff --git a/.github/workflows/release.yml b/.github/workflows/release.yml
index 14e383a27..183cdb02c 100644
--- a/.github/workflows/release.yml
+++ b/.github/workflows/release.yml
@@ -9,7 +9,7 @@ on:
pull_request:
env:
- SIGN_PIPE_VER: "v0.0.16"
+ SIGN_PIPE_VER: "v0.0.17"
GORELEASER_VER: "v2.3.2"
PRODUCT_NAME: "NetBird"
COPYRIGHT: "Wiretrustee UG (haftungsbeschreankt)"
From df98c67ac8100ac75995fecfaa32e76691db36c0 Mon Sep 17 00:00:00 2001
From: bcmmbaga
Date: Mon, 18 Nov 2024 18:46:52 +0300
Subject: [PATCH 21/39] prevent changing ruleID when not empty
Signed-off-by: bcmmbaga
---
management/server/http/policies_handler.go | 7 ++++++-
management/server/sql_store_test.go | 2 ++
2 files changed, 8 insertions(+), 1 deletion(-)
diff --git a/management/server/http/policies_handler.go b/management/server/http/policies_handler.go
index 8255e4896..ca256a183 100644
--- a/management/server/http/policies_handler.go
+++ b/management/server/http/policies_handler.go
@@ -128,8 +128,13 @@ func (h *Policies) savePolicy(w http.ResponseWriter, r *http.Request, accountID
Description: req.Description,
}
for _, rule := range req.Rules {
+ ruleID := policyID // TODO: when policy can contain multiple rules, need refactor
+ if rule.Id != nil {
+ ruleID = *rule.Id
+ }
+
pr := server.PolicyRule{
- ID: policyID, // TODO: when policy can contain multiple rules, need refactor
+ ID: ruleID,
PolicyID: policyID,
Name: rule.Name,
Destinations: rule.Destinations,
diff --git a/management/server/sql_store_test.go b/management/server/sql_store_test.go
index 8931008d7..c05793fc6 100644
--- a/management/server/sql_store_test.go
+++ b/management/server/sql_store_test.go
@@ -1832,6 +1832,8 @@ func TestSqlStore_SavePolicy(t *testing.T) {
policy.Enabled = false
policy.Description = "policy"
+ policy.Rules[0].Sources = []string{"group"}
+ policy.Rules[0].Ports = []string{"80", "443"}
err = store.SavePolicy(context.Background(), LockingStrengthUpdate, policy)
require.NoError(t, err)
From b60e2c32614615d5f7c44d95794338ade30d9287 Mon Sep 17 00:00:00 2001
From: bcmmbaga
Date: Mon, 18 Nov 2024 22:48:38 +0300
Subject: [PATCH 22/39] prevent duplicate rules during updates
Signed-off-by: bcmmbaga
---
management/server/http/policies_handler.go | 2 +-
management/server/policy.go | 2 +-
2 files changed, 2 insertions(+), 2 deletions(-)
diff --git a/management/server/http/policies_handler.go b/management/server/http/policies_handler.go
index ca256a183..eff9092d4 100644
--- a/management/server/http/policies_handler.go
+++ b/management/server/http/policies_handler.go
@@ -128,7 +128,7 @@ func (h *Policies) savePolicy(w http.ResponseWriter, r *http.Request, accountID
Description: req.Description,
}
for _, rule := range req.Rules {
- ruleID := policyID // TODO: when policy can contain multiple rules, need refactor
+ var ruleID string
if rule.Id != nil {
ruleID = *rule.Id
}
diff --git a/management/server/policy.go b/management/server/policy.go
index 693ae2872..2d3abc3f1 100644
--- a/management/server/policy.go
+++ b/management/server/policy.go
@@ -532,7 +532,7 @@ func validatePolicy(ctx context.Context, transaction Store, accountID string, po
for i, rule := range policy.Rules {
ruleCopy := rule.Copy()
if ruleCopy.ID == "" {
- ruleCopy.ID = xid.New().String()
+ ruleCopy.ID = policy.ID // TODO: when policy can contain multiple rules, need refactor
ruleCopy.PolicyID = policy.ID
}
From 52ea2e84e9aa3c03fe43c5098ab50c94ff2e0818 Mon Sep 17 00:00:00 2001
From: Pascal Fischer <32096965+pascal-fischer@users.noreply.github.com>
Date: Tue, 19 Nov 2024 00:04:50 +0100
Subject: [PATCH 23/39] [management] Add transaction metrics and exclude
getAccount time from peers update (#2904)
---
management/server/peer.go | 11 ++++++-----
management/server/sql_store.go | 11 ++++++++++-
management/server/telemetry/store_metrics.go | 12 ++++++++++++
3 files changed, 28 insertions(+), 6 deletions(-)
diff --git a/management/server/peer.go b/management/server/peer.go
index 1405dead8..8c45e45c9 100644
--- a/management/server/peer.go
+++ b/management/server/peer.go
@@ -988,6 +988,12 @@ func (am *DefaultAccountManager) GetPeer(ctx context.Context, accountID, peerID,
// updateAccountPeers updates all peers that belong to an account.
// Should be called when changes have to be synced to peers.
func (am *DefaultAccountManager) updateAccountPeers(ctx context.Context, accountID string) {
+ account, err := am.requestBuffer.GetAccountWithBackpressure(ctx, accountID)
+ if err != nil {
+ log.WithContext(ctx).Errorf("failed to send out updates to peers: %v", err)
+ return
+ }
+
start := time.Now()
defer func() {
if am.metrics != nil {
@@ -995,11 +1001,6 @@ func (am *DefaultAccountManager) updateAccountPeers(ctx context.Context, account
}
}()
- account, err := am.requestBuffer.GetAccountWithBackpressure(ctx, accountID)
- if err != nil {
- log.WithContext(ctx).Errorf("failed to send out updates to peers: %v", err)
- return
- }
peers := account.GetPeers()
approvedPeersMap, err := am.GetValidatedPeers(account)
diff --git a/management/server/sql_store.go b/management/server/sql_store.go
index 0ebda6440..278f5443d 100644
--- a/management/server/sql_store.go
+++ b/management/server/sql_store.go
@@ -1123,6 +1123,7 @@ func (s *SqlStore) IncrementNetworkSerial(ctx context.Context, lockStrength Lock
}
func (s *SqlStore) ExecuteInTransaction(ctx context.Context, operation func(store Store) error) error {
+ startTime := time.Now()
tx := s.db.Begin()
if tx.Error != nil {
return tx.Error
@@ -1133,7 +1134,15 @@ func (s *SqlStore) ExecuteInTransaction(ctx context.Context, operation func(stor
tx.Rollback()
return err
}
- return tx.Commit().Error
+
+ err = tx.Commit().Error
+
+ log.WithContext(ctx).Tracef("transaction took %v", time.Since(startTime))
+ if s.metrics != nil {
+ s.metrics.StoreMetrics().CountTransactionDuration(time.Since(startTime))
+ }
+
+ return err
}
func (s *SqlStore) withTx(tx *gorm.DB) Store {
diff --git a/management/server/telemetry/store_metrics.go b/management/server/telemetry/store_metrics.go
index b038c3d36..bb3745b5a 100644
--- a/management/server/telemetry/store_metrics.go
+++ b/management/server/telemetry/store_metrics.go
@@ -13,6 +13,7 @@ type StoreMetrics struct {
globalLockAcquisitionDurationMs metric.Int64Histogram
persistenceDurationMicro metric.Int64Histogram
persistenceDurationMs metric.Int64Histogram
+ transactionDurationMs metric.Int64Histogram
ctx context.Context
}
@@ -40,11 +41,17 @@ func NewStoreMetrics(ctx context.Context, meter metric.Meter) (*StoreMetrics, er
return nil, err
}
+ transactionDurationMs, err := meter.Int64Histogram("management.store.transaction.duration.ms")
+ if err != nil {
+ return nil, err
+ }
+
return &StoreMetrics{
globalLockAcquisitionDurationMicro: globalLockAcquisitionDurationMicro,
globalLockAcquisitionDurationMs: globalLockAcquisitionDurationMs,
persistenceDurationMicro: persistenceDurationMicro,
persistenceDurationMs: persistenceDurationMs,
+ transactionDurationMs: transactionDurationMs,
ctx: ctx,
}, nil
}
@@ -60,3 +67,8 @@ func (metrics *StoreMetrics) CountPersistenceDuration(duration time.Duration) {
metrics.persistenceDurationMicro.Record(metrics.ctx, duration.Microseconds())
metrics.persistenceDurationMs.Record(metrics.ctx, duration.Milliseconds())
}
+
+// CountTransactionDuration counts the duration of a store persistence operation
+func (metrics *StoreMetrics) CountTransactionDuration(duration time.Duration) {
+ metrics.transactionDurationMs.Record(metrics.ctx, duration.Milliseconds())
+}
From eb5d0569ae0ce829a312e13ab3e9757b9cdf019f Mon Sep 17 00:00:00 2001
From: "Krzysztof Nazarewski (kdn)"
Date: Tue, 19 Nov 2024 14:14:58 +0100
Subject: [PATCH 24/39] [client] Add NB_SKIP_SOCKET_MARK & fix crash instead of
returing an error (#2899)
* dialer: fix crash instead of returning error
* add NB_SKIP_SOCKET_MARK
---
.../routemanager/systemops/systemops_linux.go | 2 +-
util/grpc/dialer.go | 9 +++++++--
util/net/dialer_nonios.go | 2 +-
util/net/net_linux.go | 12 ++++++++++++
4 files changed, 21 insertions(+), 4 deletions(-)
diff --git a/client/internal/routemanager/systemops/systemops_linux.go b/client/internal/routemanager/systemops/systemops_linux.go
index 0124fd95e..71a0f26ae 100644
--- a/client/internal/routemanager/systemops/systemops_linux.go
+++ b/client/internal/routemanager/systemops/systemops_linux.go
@@ -55,7 +55,7 @@ type ruleParams struct {
// isLegacy determines whether to use the legacy routing setup
func isLegacy() bool {
- return os.Getenv("NB_USE_LEGACY_ROUTING") == "true" || nbnet.CustomRoutingDisabled()
+ return os.Getenv("NB_USE_LEGACY_ROUTING") == "true" || nbnet.CustomRoutingDisabled() || os.Getenv(nbnet.EnvSkipSocketMark) == "true"
}
// setIsLegacy sets the legacy routing setup
diff --git a/util/grpc/dialer.go b/util/grpc/dialer.go
index 57ab8fd55..4fbffe342 100644
--- a/util/grpc/dialer.go
+++ b/util/grpc/dialer.go
@@ -3,6 +3,9 @@ package grpc
import (
"context"
"crypto/tls"
+ "fmt"
+ "google.golang.org/grpc/codes"
+ "google.golang.org/grpc/status"
"net"
"os/user"
"runtime"
@@ -23,20 +26,22 @@ func WithCustomDialer() grpc.DialOption {
if runtime.GOOS == "linux" {
currentUser, err := user.Current()
if err != nil {
- log.Fatalf("failed to get current user: %v", err)
+ return nil, status.Errorf(codes.FailedPrecondition, "failed to get current user: %v", err)
}
// the custom dialer requires root permissions which are not required for use cases run as non-root
if currentUser.Uid != "0" {
+ log.Debug("Not running as root, using standard dialer")
dialer := &net.Dialer{}
return dialer.DialContext(ctx, "tcp", addr)
}
}
+ log.Debug("Using nbnet.NewDialer()")
conn, err := nbnet.NewDialer().DialContext(ctx, "tcp", addr)
if err != nil {
log.Errorf("Failed to dial: %s", err)
- return nil, err
+ return nil, fmt.Errorf("nbnet.NewDialer().DialContext: %w", err)
}
return conn, nil
})
diff --git a/util/net/dialer_nonios.go b/util/net/dialer_nonios.go
index 4032a75c0..34004a368 100644
--- a/util/net/dialer_nonios.go
+++ b/util/net/dialer_nonios.go
@@ -69,7 +69,7 @@ func (d *Dialer) DialContext(ctx context.Context, network, address string) (net.
conn, err := d.Dialer.DialContext(ctx, network, address)
if err != nil {
- return nil, fmt.Errorf("dial: %w", err)
+ return nil, fmt.Errorf("d.Dialer.DialContext: %w", err)
}
// Wrap the connection in Conn to handle Close with hooks
diff --git a/util/net/net_linux.go b/util/net/net_linux.go
index 954545eb5..98f49af8d 100644
--- a/util/net/net_linux.go
+++ b/util/net/net_linux.go
@@ -4,9 +4,14 @@ package net
import (
"fmt"
+ "os"
"syscall"
+
+ log "github.com/sirupsen/logrus"
)
+const EnvSkipSocketMark = "NB_SKIP_SOCKET_MARK"
+
// SetSocketMark sets the SO_MARK option on the given socket connection
func SetSocketMark(conn syscall.Conn) error {
sysconn, err := conn.SyscallConn()
@@ -36,6 +41,13 @@ func SetRawSocketMark(conn syscall.RawConn) error {
func SetSocketOpt(fd int) error {
if CustomRoutingDisabled() {
+ log.Infof("Custom routing is disabled, skipping SO_MARK")
+ return nil
+ }
+
+ // Check for the new environment variable
+ if skipSocketMark := os.Getenv(EnvSkipSocketMark); skipSocketMark == "true" {
+ log.Info("NB_SKIP_SOCKET_MARK is set to true, skipping SO_MARK")
return nil
}
From 5dd6a08ea6926cb1cb87a3301524b9031f4db61a Mon Sep 17 00:00:00 2001
From: Pascal Fischer <32096965+pascal-fischer@users.noreply.github.com>
Date: Tue, 19 Nov 2024 17:25:49 +0100
Subject: [PATCH 25/39] link peer meta update back to account object (#2911)
---
management/server/peer.go | 1 +
1 file changed, 1 insertion(+)
diff --git a/management/server/peer.go b/management/server/peer.go
index 8c45e45c9..901e4815d 100644
--- a/management/server/peer.go
+++ b/management/server/peer.go
@@ -667,6 +667,7 @@ func (am *DefaultAccountManager) SyncPeer(ctx context.Context, sync PeerSync, ac
updated := peer.UpdateMetaIfNew(sync.Meta)
if updated {
+ account.Peers[peer.ID] = peer
log.WithContext(ctx).Tracef("peer %s metadata updated", peer.ID)
err = am.Store.SavePeer(ctx, account.Id, peer)
if err != nil {
From f66bbcc54c65f0856679fed1b50298e97c0ec7af Mon Sep 17 00:00:00 2001
From: Pascal Fischer <32096965+pascal-fischer@users.noreply.github.com>
Date: Tue, 19 Nov 2024 18:13:26 +0100
Subject: [PATCH 26/39] [management] Add metric for peer meta update (#2913)
---
management/server/peer.go | 2 ++
.../server/telemetry/accountmanager_metrics.go | 12 ++++++++++++
2 files changed, 14 insertions(+)
diff --git a/management/server/peer.go b/management/server/peer.go
index 901e4815d..beb833dba 100644
--- a/management/server/peer.go
+++ b/management/server/peer.go
@@ -667,6 +667,7 @@ func (am *DefaultAccountManager) SyncPeer(ctx context.Context, sync PeerSync, ac
updated := peer.UpdateMetaIfNew(sync.Meta)
if updated {
+ am.metrics.AccountManagerMetrics().CountPeerMetUpdate()
account.Peers[peer.ID] = peer
log.WithContext(ctx).Tracef("peer %s metadata updated", peer.ID)
err = am.Store.SavePeer(ctx, account.Id, peer)
@@ -801,6 +802,7 @@ func (am *DefaultAccountManager) LoginPeer(ctx context.Context, login PeerLogin)
updated := peer.UpdateMetaIfNew(login.Meta)
if updated {
+ am.metrics.AccountManagerMetrics().CountPeerMetUpdate()
shouldStorePeer = true
}
diff --git a/management/server/telemetry/accountmanager_metrics.go b/management/server/telemetry/accountmanager_metrics.go
index e4bb4e3c3..4a5a31e2d 100644
--- a/management/server/telemetry/accountmanager_metrics.go
+++ b/management/server/telemetry/accountmanager_metrics.go
@@ -13,6 +13,7 @@ type AccountManagerMetrics struct {
updateAccountPeersDurationMs metric.Float64Histogram
getPeerNetworkMapDurationMs metric.Float64Histogram
networkMapObjectCount metric.Int64Histogram
+ peerMetaUpdateCount metric.Int64Counter
}
// NewAccountManagerMetrics creates an instance of AccountManagerMetrics
@@ -44,11 +45,17 @@ func NewAccountManagerMetrics(ctx context.Context, meter metric.Meter) (*Account
return nil, err
}
+ peerMetaUpdateCount, err := meter.Int64Counter("management.account.peer.meta.update.counter", metric.WithUnit("1"))
+ if err != nil {
+ return nil, err
+ }
+
return &AccountManagerMetrics{
ctx: ctx,
getPeerNetworkMapDurationMs: getPeerNetworkMapDurationMs,
updateAccountPeersDurationMs: updateAccountPeersDurationMs,
networkMapObjectCount: networkMapObjectCount,
+ peerMetaUpdateCount: peerMetaUpdateCount,
}, nil
}
@@ -67,3 +74,8 @@ func (metrics *AccountManagerMetrics) CountGetPeerNetworkMapDuration(duration ti
func (metrics *AccountManagerMetrics) CountNetworkMapObjects(count int64) {
metrics.networkMapObjectCount.Record(metrics.ctx, count)
}
+
+// CountPeerMetUpdate counts the number of peer meta updates
+func (metrics *AccountManagerMetrics) CountPeerMetUpdate() {
+ metrics.peerMetaUpdateCount.Add(metrics.ctx, 1)
+}
From aa575d6f445e74f34f8353a4c413adc209c56f4b Mon Sep 17 00:00:00 2001
From: Pascal Fischer <32096965+pascal-fischer@users.noreply.github.com>
Date: Thu, 21 Nov 2024 15:10:34 +0100
Subject: [PATCH 27/39] [management] Add activity events to group propagation
flow (#2916)
---
management/server/account.go | 38 ++++++++++-
management/server/activity/codes.go | 6 ++
management/server/user.go | 97 ++++++++++++++++++++++-------
3 files changed, 116 insertions(+), 25 deletions(-)
diff --git a/management/server/account.go b/management/server/account.go
index 95c93a22b..0ab123655 100644
--- a/management/server/account.go
+++ b/management/server/account.go
@@ -965,7 +965,9 @@ func (am *DefaultAccountManager) getJWTGroupsChanges(user *User, groups []*nbgro
}
// UserGroupsAddToPeers adds groups to all peers of user
-func (a *Account) UserGroupsAddToPeers(userID string, groups ...string) {
+func (a *Account) UserGroupsAddToPeers(userID string, groups ...string) map[string][]string {
+ groupUpdates := make(map[string][]string)
+
userPeers := make(map[string]struct{})
for pid, peer := range a.Peers {
if peer.UserID == userID {
@@ -979,6 +981,8 @@ func (a *Account) UserGroupsAddToPeers(userID string, groups ...string) {
continue
}
+ oldPeers := group.Peers
+
groupPeers := make(map[string]struct{})
for _, pid := range group.Peers {
groupPeers[pid] = struct{}{}
@@ -992,16 +996,25 @@ func (a *Account) UserGroupsAddToPeers(userID string, groups ...string) {
for pid := range groupPeers {
group.Peers = append(group.Peers, pid)
}
+
+ groupUpdates[gid] = difference(group.Peers, oldPeers)
}
+
+ return groupUpdates
}
// UserGroupsRemoveFromPeers removes groups from all peers of user
-func (a *Account) UserGroupsRemoveFromPeers(userID string, groups ...string) {
+func (a *Account) UserGroupsRemoveFromPeers(userID string, groups ...string) map[string][]string {
+ groupUpdates := make(map[string][]string)
+
for _, gid := range groups {
group, ok := a.Groups[gid]
if !ok || group.Name == "All" {
continue
}
+
+ oldPeers := group.Peers
+
update := make([]string, 0, len(group.Peers))
for _, pid := range group.Peers {
peer, ok := a.Peers[pid]
@@ -1013,7 +1026,10 @@ func (a *Account) UserGroupsRemoveFromPeers(userID string, groups ...string) {
}
}
group.Peers = update
+ groupUpdates[gid] = difference(oldPeers, group.Peers)
}
+
+ return groupUpdates
}
// BuildManager creates a new DefaultAccountManager with a provided Store
@@ -1175,6 +1191,11 @@ func (am *DefaultAccountManager) UpdateAccountSettings(ctx context.Context, acco
return nil, err
}
+ err = am.handleGroupsPropagationSettings(ctx, oldSettings, newSettings, userID, accountID)
+ if err != nil {
+ return nil, fmt.Errorf("groups propagation failed: %w", err)
+ }
+
updatedAccount := account.UpdateSettings(newSettings)
err = am.Store.SaveAccount(ctx, account)
@@ -1185,6 +1206,19 @@ func (am *DefaultAccountManager) UpdateAccountSettings(ctx context.Context, acco
return updatedAccount, nil
}
+func (am *DefaultAccountManager) handleGroupsPropagationSettings(ctx context.Context, oldSettings, newSettings *Settings, userID, accountID string) error {
+ if oldSettings.GroupsPropagationEnabled != newSettings.GroupsPropagationEnabled {
+ if newSettings.GroupsPropagationEnabled {
+ am.StoreEvent(ctx, userID, accountID, accountID, activity.UserGroupPropagationEnabled, nil)
+ // Todo: retroactively add user groups to all peers
+ } else {
+ am.StoreEvent(ctx, userID, accountID, accountID, activity.UserGroupPropagationDisabled, nil)
+ }
+ }
+
+ return nil
+}
+
func (am *DefaultAccountManager) handleInactivityExpirationSettings(ctx context.Context, account *Account, oldSettings, newSettings *Settings, userID, accountID string) error {
if newSettings.PeerInactivityExpirationEnabled {
diff --git a/management/server/activity/codes.go b/management/server/activity/codes.go
index 603260dbc..4c57d65fb 100644
--- a/management/server/activity/codes.go
+++ b/management/server/activity/codes.go
@@ -148,6 +148,9 @@ const (
AccountPeerInactivityExpirationDurationUpdated Activity = 67
SetupKeyDeleted Activity = 68
+
+ UserGroupPropagationEnabled Activity = 69
+ UserGroupPropagationDisabled Activity = 70
)
var activityMap = map[Activity]Code{
@@ -222,6 +225,9 @@ var activityMap = map[Activity]Code{
AccountPeerInactivityExpirationDisabled: {"Account peer inactivity expiration disabled", "account.peer.inactivity.expiration.disable"},
AccountPeerInactivityExpirationDurationUpdated: {"Account peer inactivity expiration duration updated", "account.peer.inactivity.expiration.update"},
SetupKeyDeleted: {"Setup key deleted", "setupkey.delete"},
+
+ UserGroupPropagationEnabled: {"User group propagation enabled", "account.setting.group.propagation.enable"},
+ UserGroupPropagationDisabled: {"User group propagation disabled", "account.setting.group.propagation.disable"},
}
// StringCode returns a string code of the activity
diff --git a/management/server/user.go b/management/server/user.go
index 74062112a..edb5e6fd3 100644
--- a/management/server/user.go
+++ b/management/server/user.go
@@ -805,15 +805,20 @@ func (am *DefaultAccountManager) SaveOrAddUsers(ctx context.Context, accountID,
expiredPeers = append(expiredPeers, blockedPeers...)
}
+ peerGroupsAdded := make(map[string][]string)
+ peerGroupsRemoved := make(map[string][]string)
if update.AutoGroups != nil && account.Settings.GroupsPropagationEnabled {
removedGroups := difference(oldUser.AutoGroups, update.AutoGroups)
// need force update all auto groups in any case they will not be duplicated
- account.UserGroupsAddToPeers(oldUser.Id, update.AutoGroups...)
- account.UserGroupsRemoveFromPeers(oldUser.Id, removedGroups...)
+ peerGroupsAdded = account.UserGroupsAddToPeers(oldUser.Id, update.AutoGroups...)
+ peerGroupsRemoved = account.UserGroupsRemoveFromPeers(oldUser.Id, removedGroups...)
}
- events := am.prepareUserUpdateEvents(ctx, initiatorUser.Id, oldUser, newUser, account, transferredOwnerRole)
- eventsToStore = append(eventsToStore, events...)
+ userUpdateEvents := am.prepareUserUpdateEvents(ctx, initiatorUser.Id, oldUser, newUser, account, transferredOwnerRole)
+ eventsToStore = append(eventsToStore, userUpdateEvents...)
+
+ userGroupsEvents := am.prepareUserGroupsEvents(ctx, initiatorUser.Id, oldUser, newUser, account, peerGroupsAdded, peerGroupsRemoved)
+ eventsToStore = append(eventsToStore, userGroupsEvents...)
updatedUserInfo, err := getUserInfo(ctx, am, newUser, account)
if err != nil {
@@ -872,32 +877,78 @@ func (am *DefaultAccountManager) prepareUserUpdateEvents(ctx context.Context, in
})
}
+ return eventsToStore
+}
+
+func (am *DefaultAccountManager) prepareUserGroupsEvents(ctx context.Context, initiatorUserID string, oldUser, newUser *User, account *Account, peerGroupsAdded, peerGroupsRemoved map[string][]string) []func() {
+ var eventsToStore []func()
if newUser.AutoGroups != nil {
removedGroups := difference(oldUser.AutoGroups, newUser.AutoGroups)
addedGroups := difference(newUser.AutoGroups, oldUser.AutoGroups)
- for _, g := range removedGroups {
- group := account.GetGroup(g)
- if group != nil {
- eventsToStore = append(eventsToStore, func() {
- am.StoreEvent(ctx, initiatorUserID, oldUser.Id, account.Id, activity.GroupRemovedFromUser,
- map[string]any{"group": group.Name, "group_id": group.ID, "is_service_user": newUser.IsServiceUser, "user_name": newUser.ServiceUserName})
- })
- } else {
- log.WithContext(ctx).Errorf("group %s not found while saving user activity event of account %s", g, account.Id)
- }
- }
- for _, g := range addedGroups {
- group := account.GetGroup(g)
- if group != nil {
- eventsToStore = append(eventsToStore, func() {
- am.StoreEvent(ctx, initiatorUserID, oldUser.Id, account.Id, activity.GroupAddedToUser,
- map[string]any{"group": group.Name, "group_id": group.ID, "is_service_user": newUser.IsServiceUser, "user_name": newUser.ServiceUserName})
- })
- }
+ removedEvents := am.handleGroupRemovedFromUser(ctx, initiatorUserID, oldUser, newUser, account, removedGroups, peerGroupsRemoved)
+ eventsToStore = append(eventsToStore, removedEvents...)
+
+ addedEvents := am.handleGroupAddedToUser(ctx, initiatorUserID, oldUser, newUser, account, addedGroups, peerGroupsAdded)
+ eventsToStore = append(eventsToStore, addedEvents...)
+ }
+ return eventsToStore
+}
+
+func (am *DefaultAccountManager) handleGroupAddedToUser(ctx context.Context, initiatorUserID string, oldUser, newUser *User, account *Account, addedGroups []string, peerGroupsAdded map[string][]string) []func() {
+ var eventsToStore []func()
+ for _, g := range addedGroups {
+ group := account.GetGroup(g)
+ if group != nil {
+ eventsToStore = append(eventsToStore, func() {
+ am.StoreEvent(ctx, initiatorUserID, oldUser.Id, account.Id, activity.GroupAddedToUser,
+ map[string]any{"group": group.Name, "group_id": group.ID, "is_service_user": newUser.IsServiceUser, "user_name": newUser.ServiceUserName})
+ })
}
}
+ for groupID, peerIDs := range peerGroupsAdded {
+ group := account.GetGroup(groupID)
+ for _, peerID := range peerIDs {
+ peer := account.GetPeer(peerID)
+ eventsToStore = append(eventsToStore, func() {
+ meta := map[string]any{
+ "group": group.Name, "group_id": group.ID,
+ "peer_ip": peer.IP.String(), "peer_fqdn": peer.FQDN(am.GetDNSDomain()),
+ }
+ am.StoreEvent(ctx, activity.SystemInitiator, peer.ID, account.Id, activity.GroupAddedToPeer, meta)
+ })
+ }
+ }
+ return eventsToStore
+}
+func (am *DefaultAccountManager) handleGroupRemovedFromUser(ctx context.Context, initiatorUserID string, oldUser, newUser *User, account *Account, removedGroups []string, peerGroupsRemoved map[string][]string) []func() {
+ var eventsToStore []func()
+ for _, g := range removedGroups {
+ group := account.GetGroup(g)
+ if group != nil {
+ eventsToStore = append(eventsToStore, func() {
+ am.StoreEvent(ctx, initiatorUserID, oldUser.Id, account.Id, activity.GroupRemovedFromUser,
+ map[string]any{"group": group.Name, "group_id": group.ID, "is_service_user": newUser.IsServiceUser, "user_name": newUser.ServiceUserName})
+ })
+
+ } else {
+ log.WithContext(ctx).Errorf("group %s not found while saving user activity event of account %s", g, account.Id)
+ }
+ }
+ for groupID, peerIDs := range peerGroupsRemoved {
+ group := account.GetGroup(groupID)
+ for _, peerID := range peerIDs {
+ peer := account.GetPeer(peerID)
+ eventsToStore = append(eventsToStore, func() {
+ meta := map[string]any{
+ "group": group.Name, "group_id": group.ID,
+ "peer_ip": peer.IP.String(), "peer_fqdn": peer.FQDN(am.GetDNSDomain()),
+ }
+ am.StoreEvent(ctx, activity.SystemInitiator, peer.ID, account.Id, activity.GroupRemovedFromPeer, meta)
+ })
+ }
+ }
return eventsToStore
}
From 1bbabf70b057c2384a077742b9e6760161e153aa Mon Sep 17 00:00:00 2001
From: Viktor Liu <17948409+lixmal@users.noreply.github.com>
Date: Thu, 21 Nov 2024 16:53:37 +0100
Subject: [PATCH 28/39] [client] Fix allow netbird rule verdict (#2925)
* Fix allow netbird rule verdict
* Fix chain name
---
client/firewall/nftables/manager_linux.go | 8 +++++---
1 file changed, 5 insertions(+), 3 deletions(-)
diff --git a/client/firewall/nftables/manager_linux.go b/client/firewall/nftables/manager_linux.go
index 3f8fac249..8e1aa0d80 100644
--- a/client/firewall/nftables/manager_linux.go
+++ b/client/firewall/nftables/manager_linux.go
@@ -199,7 +199,7 @@ func (m *Manager) AllowNetbird() error {
var chain *nftables.Chain
for _, c := range chains {
- if c.Table.Name == tableNameFilter && c.Name == chainNameForward {
+ if c.Table.Name == tableNameFilter && c.Name == chainNameInput {
chain = c
break
}
@@ -276,7 +276,7 @@ func (m *Manager) resetNetbirdInputRules() error {
func (m *Manager) deleteNetbirdInputRules(chains []*nftables.Chain) {
for _, c := range chains {
- if c.Table.Name == "filter" && c.Name == "INPUT" {
+ if c.Table.Name == tableNameFilter && c.Name == chainNameInput {
rules, err := m.rConn.GetRules(c.Table, c)
if err != nil {
log.Errorf("get rules for chain %q: %v", c.Name, err)
@@ -351,7 +351,9 @@ func (m *Manager) applyAllowNetbirdRules(chain *nftables.Chain) {
Register: 1,
Data: ifname(m.wgIface.Name()),
},
- &expr.Verdict{},
+ &expr.Verdict{
+ Kind: expr.VerdictAccept,
+ },
},
UserData: []byte(allowNetbirdInputRuleID),
}
From 9db1932664557da94ff64bfc03f5acdcf30667ea Mon Sep 17 00:00:00 2001
From: Pascal Fischer <32096965+pascal-fischer@users.noreply.github.com>
Date: Fri, 22 Nov 2024 10:15:51 +0100
Subject: [PATCH 29/39] [management] Fix getSetupKey call (#2927)
---
management/server/http/api/openapi.yml | 30 +++++++--
management/server/http/api/types.gen.go | 89 ++++++++++++++++++++++++-
management/server/setupkey.go | 2 +-
management/server/setupkey_test.go | 43 ++++++++----
4 files changed, 144 insertions(+), 20 deletions(-)
diff --git a/management/server/http/api/openapi.yml b/management/server/http/api/openapi.yml
index bfb375277..2e084f6e4 100644
--- a/management/server/http/api/openapi.yml
+++ b/management/server/http/api/openapi.yml
@@ -439,17 +439,13 @@ components:
example: 5
required:
- accessible_peers_count
- SetupKey:
+ SetupKeyBase:
type: object
properties:
id:
description: Setup Key ID
type: string
example: 2531583362
- key:
- description: Setup Key value
- type: string
- example: A616097E-FCF0-48FA-9354-CA4A61142761
name:
description: Setup key name identifier
type: string
@@ -518,6 +514,28 @@ components:
- updated_at
- usage_limit
- ephemeral
+ SetupKeyClear:
+ allOf:
+ - $ref: '#/components/schemas/SetupKeyBase'
+ - type: object
+ properties:
+ key:
+ description: Setup Key as plain text
+ type: string
+ example: A616097E-FCF0-48FA-9354-CA4A61142761
+ required:
+ - key
+ SetupKey:
+ allOf:
+ - $ref: '#/components/schemas/SetupKeyBase'
+ - type: object
+ properties:
+ key:
+ description: Setup Key as secret
+ type: string
+ example: A6160****
+ required:
+ - key
SetupKeyRequest:
type: object
properties:
@@ -1918,7 +1936,7 @@ paths:
content:
application/json:
schema:
- $ref: '#/components/schemas/SetupKey'
+ $ref: '#/components/schemas/SetupKeyClear'
'400':
"$ref": "#/components/responses/bad_request"
'401':
diff --git a/management/server/http/api/types.gen.go b/management/server/http/api/types.gen.go
index f219c4574..321395d25 100644
--- a/management/server/http/api/types.gen.go
+++ b/management/server/http/api/types.gen.go
@@ -1062,7 +1062,94 @@ type SetupKey struct {
// Id Setup Key ID
Id string `json:"id"`
- // Key Setup Key value
+ // Key Setup Key as secret
+ Key string `json:"key"`
+
+ // LastUsed Setup key last usage date
+ LastUsed time.Time `json:"last_used"`
+
+ // Name Setup key name identifier
+ Name string `json:"name"`
+
+ // Revoked Setup key revocation status
+ Revoked bool `json:"revoked"`
+
+ // State Setup key status, "valid", "overused","expired" or "revoked"
+ State string `json:"state"`
+
+ // Type Setup key type, one-off for single time usage and reusable
+ Type string `json:"type"`
+
+ // UpdatedAt Setup key last update date
+ UpdatedAt time.Time `json:"updated_at"`
+
+ // UsageLimit A number of times this key can be used. The value of 0 indicates the unlimited usage.
+ UsageLimit int `json:"usage_limit"`
+
+ // UsedTimes Usage count of setup key
+ UsedTimes int `json:"used_times"`
+
+ // Valid Setup key validity status
+ Valid bool `json:"valid"`
+}
+
+// SetupKeyBase defines model for SetupKeyBase.
+type SetupKeyBase struct {
+ // AutoGroups List of group IDs to auto-assign to peers registered with this key
+ AutoGroups []string `json:"auto_groups"`
+
+ // Ephemeral Indicate that the peer will be ephemeral or not
+ Ephemeral bool `json:"ephemeral"`
+
+ // Expires Setup Key expiration date
+ Expires time.Time `json:"expires"`
+
+ // Id Setup Key ID
+ Id string `json:"id"`
+
+ // LastUsed Setup key last usage date
+ LastUsed time.Time `json:"last_used"`
+
+ // Name Setup key name identifier
+ Name string `json:"name"`
+
+ // Revoked Setup key revocation status
+ Revoked bool `json:"revoked"`
+
+ // State Setup key status, "valid", "overused","expired" or "revoked"
+ State string `json:"state"`
+
+ // Type Setup key type, one-off for single time usage and reusable
+ Type string `json:"type"`
+
+ // UpdatedAt Setup key last update date
+ UpdatedAt time.Time `json:"updated_at"`
+
+ // UsageLimit A number of times this key can be used. The value of 0 indicates the unlimited usage.
+ UsageLimit int `json:"usage_limit"`
+
+ // UsedTimes Usage count of setup key
+ UsedTimes int `json:"used_times"`
+
+ // Valid Setup key validity status
+ Valid bool `json:"valid"`
+}
+
+// SetupKeyClear defines model for SetupKeyClear.
+type SetupKeyClear struct {
+ // AutoGroups List of group IDs to auto-assign to peers registered with this key
+ AutoGroups []string `json:"auto_groups"`
+
+ // Ephemeral Indicate that the peer will be ephemeral or not
+ Ephemeral bool `json:"ephemeral"`
+
+ // Expires Setup Key expiration date
+ Expires time.Time `json:"expires"`
+
+ // Id Setup Key ID
+ Id string `json:"id"`
+
+ // Key Setup Key as plain text
Key string `json:"key"`
// LastUsed Setup key last usage date
diff --git a/management/server/setupkey.go b/management/server/setupkey.go
index cae0dfecb..ef431d3ad 100644
--- a/management/server/setupkey.go
+++ b/management/server/setupkey.go
@@ -379,7 +379,7 @@ func (am *DefaultAccountManager) GetSetupKey(ctx context.Context, accountID, use
return nil, status.NewAdminPermissionError()
}
- setupKey, err := am.Store.GetSetupKeyByID(ctx, LockingStrengthShare, keyID, accountID)
+ setupKey, err := am.Store.GetSetupKeyByID(ctx, LockingStrengthShare, accountID, keyID)
if err != nil {
return nil, err
}
diff --git a/management/server/setupkey_test.go b/management/server/setupkey_test.go
index 94ed022fa..7c8200706 100644
--- a/management/server/setupkey_test.go
+++ b/management/server/setupkey_test.go
@@ -210,22 +210,41 @@ func TestGetSetupKeys(t *testing.T) {
t.Fatal(err)
}
- err = manager.SaveGroup(context.Background(), account.Id, userID, &nbgroup.Group{
- ID: "group_1",
- Name: "group_name_1",
- Peers: []string{},
- })
+ plainKey, err := manager.CreateSetupKey(context.Background(), account.Id, "key1", SetupKeyReusable, time.Hour, nil, SetupKeyUnlimitedUsage, userID, false)
if err != nil {
t.Fatal(err)
}
- err = manager.SaveGroup(context.Background(), account.Id, userID, &nbgroup.Group{
- ID: "group_2",
- Name: "group_name_2",
- Peers: []string{},
- })
- if err != nil {
- t.Fatal(err)
+ type testCase struct {
+ name string
+ keyId string
+ expectedFailure bool
+ }
+
+ testCase1 := testCase{
+ name: "Should get existing Setup Key",
+ keyId: plainKey.Id,
+ expectedFailure: false,
+ }
+ testCase2 := testCase{
+ name: "Should fail to get non-existent Setup Key",
+ keyId: "some key",
+ expectedFailure: true,
+ }
+
+ for _, tCase := range []testCase{testCase1, testCase2} {
+ t.Run(tCase.name, func(t *testing.T) {
+ key, err := manager.GetSetupKey(context.Background(), account.Id, userID, tCase.keyId)
+
+ if tCase.expectedFailure {
+ if err == nil {
+ t.Fatal("expected to fail")
+ }
+ return
+ }
+
+ assert.NotEqual(t, plainKey.Key, key.Key)
+ })
}
}
From 2a5cb1649402d42f588d374f6a775c62e92a5522 Mon Sep 17 00:00:00 2001
From: Zoltan Papp
Date: Fri, 22 Nov 2024 18:12:34 +0100
Subject: [PATCH 30/39] [relay] Refactor initial Relay connection (#2800)
Can support firewalls with restricted WS rules
allow to run engine without Relay servers
keep up to date Relay address changes
---
client/internal/connect.go | 3 +-
client/internal/engine.go | 10 ++-
client/internal/peer/status.go | 20 ++---
client/internal/peer/worker_ice.go | 14 +--
relay/client/client.go | 6 +-
relay/client/client_test.go | 2 +-
relay/client/guard.go | 91 +++++++++++++++----
relay/client/manager.go | 140 +++++++++++++++++++----------
relay/client/picker.go | 16 ++--
relay/client/picker_test.go | 5 +-
10 files changed, 211 insertions(+), 96 deletions(-)
diff --git a/client/internal/connect.go b/client/internal/connect.go
index f76aa066b..8c2ad4aa1 100644
--- a/client/internal/connect.go
+++ b/client/internal/connect.go
@@ -232,6 +232,7 @@ func (c *ConnectClient) run(mobileDependency MobileDependency, probes *ProbeHold
relayURLs, token := parseRelayInfo(loginResp)
relayManager := relayClient.NewManager(engineCtx, relayURLs, myPrivateKey.PublicKey().String())
+ c.statusRecorder.SetRelayMgr(relayManager)
if len(relayURLs) > 0 {
if token != nil {
if err := relayManager.UpdateToken(token); err != nil {
@@ -242,9 +243,7 @@ func (c *ConnectClient) run(mobileDependency MobileDependency, probes *ProbeHold
log.Infof("connecting to the Relay service(s): %s", strings.Join(relayURLs, ", "))
if err = relayManager.Serve(); err != nil {
log.Error(err)
- return wrapErr(err)
}
- c.statusRecorder.SetRelayMgr(relayManager)
}
peerConfig := loginResp.GetPeerConfig()
diff --git a/client/internal/engine.go b/client/internal/engine.go
index 1c912220c..dc4499e17 100644
--- a/client/internal/engine.go
+++ b/client/internal/engine.go
@@ -538,6 +538,7 @@ func (e *Engine) handleSync(update *mgmProto.SyncResponse) error {
relayMsg := wCfg.GetRelay()
if relayMsg != nil {
+ // when we receive token we expect valid address list too
c := &auth.Token{
Payload: relayMsg.GetTokenPayload(),
Signature: relayMsg.GetTokenSignature(),
@@ -546,9 +547,16 @@ func (e *Engine) handleSync(update *mgmProto.SyncResponse) error {
log.Errorf("failed to update relay token: %v", err)
return fmt.Errorf("update relay token: %w", err)
}
+
+ e.relayManager.UpdateServerURLs(relayMsg.Urls)
+
+ // Just in case the agent started with an MGM server where the relay was disabled but was later enabled.
+ // We can ignore all errors because the guard will manage the reconnection retries.
+ _ = e.relayManager.Serve()
+ } else {
+ e.relayManager.UpdateServerURLs(nil)
}
- // todo update relay address in the relay manager
// todo update signal
}
diff --git a/client/internal/peer/status.go b/client/internal/peer/status.go
index 0444dc60b..74e2ee82c 100644
--- a/client/internal/peer/status.go
+++ b/client/internal/peer/status.go
@@ -676,25 +676,23 @@ func (d *Status) GetRelayStates() []relay.ProbeResult {
// extend the list of stun, turn servers with relay address
relayStates := slices.Clone(d.relayStates)
- var relayState relay.ProbeResult
-
// if the server connection is not established then we will use the general address
// in case of connection we will use the instance specific address
instanceAddr, err := d.relayMgr.RelayInstanceAddress()
if err != nil {
// TODO add their status
- if errors.Is(err, relayClient.ErrRelayClientNotConnected) {
- for _, r := range d.relayMgr.ServerURLs() {
- relayStates = append(relayStates, relay.ProbeResult{
- URI: r,
- })
- }
- return relayStates
+ for _, r := range d.relayMgr.ServerURLs() {
+ relayStates = append(relayStates, relay.ProbeResult{
+ URI: r,
+ Err: err,
+ })
}
- relayState.Err = err
+ return relayStates
}
- relayState.URI = instanceAddr
+ relayState := relay.ProbeResult{
+ URI: instanceAddr,
+ }
return append(relayStates, relayState)
}
diff --git a/client/internal/peer/worker_ice.go b/client/internal/peer/worker_ice.go
index 4c67cb781..7ce4797c3 100644
--- a/client/internal/peer/worker_ice.go
+++ b/client/internal/peer/worker_ice.go
@@ -46,8 +46,6 @@ type WorkerICE struct {
hasRelayOnLocally bool
conn WorkerICECallbacks
- selectedPriority ConnPriority
-
agent *ice.Agent
muxAgent sync.Mutex
@@ -95,10 +93,8 @@ func (w *WorkerICE) OnNewOffer(remoteOfferAnswer *OfferAnswer) {
var preferredCandidateTypes []ice.CandidateType
if w.hasRelayOnLocally && remoteOfferAnswer.RelaySrvAddress != "" {
- w.selectedPriority = connPriorityICEP2P
preferredCandidateTypes = icemaker.CandidateTypesP2P()
} else {
- w.selectedPriority = connPriorityICETurn
preferredCandidateTypes = icemaker.CandidateTypes()
}
@@ -159,7 +155,7 @@ func (w *WorkerICE) OnNewOffer(remoteOfferAnswer *OfferAnswer) {
RelayedOnLocal: isRelayCandidate(pair.Local),
}
w.log.Debugf("on ICE conn read to use ready")
- go w.conn.OnConnReady(w.selectedPriority, ci)
+ go w.conn.OnConnReady(selectedPriority(pair), ci)
}
// OnRemoteCandidate Handles ICE connection Candidate provided by the remote peer.
@@ -394,3 +390,11 @@ func isRelayed(pair *ice.CandidatePair) bool {
}
return false
}
+
+func selectedPriority(pair *ice.CandidatePair) ConnPriority {
+ if isRelayed(pair) {
+ return connPriorityICETurn
+ } else {
+ return connPriorityICEP2P
+ }
+}
diff --git a/relay/client/client.go b/relay/client/client.go
index 154c1787f..db5252f50 100644
--- a/relay/client/client.go
+++ b/relay/client/client.go
@@ -140,7 +140,7 @@ type Client struct {
instanceURL *RelayAddr
muInstanceURL sync.Mutex
- onDisconnectListener func()
+ onDisconnectListener func(string)
onConnectedListener func()
listenerMutex sync.Mutex
}
@@ -233,7 +233,7 @@ func (c *Client) ServerInstanceURL() (string, error) {
}
// SetOnDisconnectListener sets a function that will be called when the connection to the relay server is closed.
-func (c *Client) SetOnDisconnectListener(fn func()) {
+func (c *Client) SetOnDisconnectListener(fn func(string)) {
c.listenerMutex.Lock()
defer c.listenerMutex.Unlock()
c.onDisconnectListener = fn
@@ -554,7 +554,7 @@ func (c *Client) notifyDisconnected() {
if c.onDisconnectListener == nil {
return
}
- go c.onDisconnectListener()
+ go c.onDisconnectListener(c.connectionURL)
}
func (c *Client) notifyConnected() {
diff --git a/relay/client/client_test.go b/relay/client/client_test.go
index ef28203e9..7ddfba4c6 100644
--- a/relay/client/client_test.go
+++ b/relay/client/client_test.go
@@ -551,7 +551,7 @@ func TestCloseByServer(t *testing.T) {
}
disconnected := make(chan struct{})
- relayClient.SetOnDisconnectListener(func() {
+ relayClient.SetOnDisconnectListener(func(_ string) {
log.Infof("client disconnected")
close(disconnected)
})
diff --git a/relay/client/guard.go b/relay/client/guard.go
index d6b6b0da5..b971363a8 100644
--- a/relay/client/guard.go
+++ b/relay/client/guard.go
@@ -4,65 +4,120 @@ import (
"context"
"time"
+ "github.com/cenkalti/backoff/v4"
log "github.com/sirupsen/logrus"
)
var (
- reconnectingTimeout = 5 * time.Second
+ reconnectingTimeout = 60 * time.Second
)
// Guard manage the reconnection tries to the Relay server in case of disconnection event.
type Guard struct {
- ctx context.Context
- relayClient *Client
+ // OnNewRelayClient is a channel that is used to notify the relay client about a new relay client instance.
+ OnNewRelayClient chan *Client
+ serverPicker *ServerPicker
}
// NewGuard creates a new guard for the relay client.
-func NewGuard(context context.Context, relayClient *Client) *Guard {
+func NewGuard(sp *ServerPicker) *Guard {
g := &Guard{
- ctx: context,
- relayClient: relayClient,
+ OnNewRelayClient: make(chan *Client, 1),
+ serverPicker: sp,
}
return g
}
-// OnDisconnected is called when the relay client is disconnected from the relay server. It will trigger the reconnection
+// StartReconnectTrys is called when the relay client is disconnected from the relay server.
+// It attempts to reconnect to the relay server. The function first tries a quick reconnect
+// to the same server that was used before, if the server URL is still valid. If the quick
+// reconnect fails, it starts a ticker to periodically attempt server picking until it
+// succeeds or the context is done.
+//
+// Parameters:
+// - ctx: The context to control the lifecycle of the reconnection attempts.
+// - relayClient: The relay client instance that was disconnected.
// todo prevent multiple reconnection instances. In the current usage it should not happen, but it is better to prevent
-func (g *Guard) OnDisconnected() {
- if g.quickReconnect() {
+func (g *Guard) StartReconnectTrys(ctx context.Context, relayClient *Client) {
+ if relayClient == nil {
+ goto RETRY
+ }
+ if g.isServerURLStillValid(relayClient) && g.quickReconnect(ctx, relayClient) {
return
}
- ticker := time.NewTicker(reconnectingTimeout)
+RETRY:
+ ticker := exponentTicker(ctx)
defer ticker.Stop()
for {
select {
case <-ticker.C:
- err := g.relayClient.Connect()
- if err != nil {
- log.Errorf("failed to reconnect to relay server: %s", err)
+ if err := g.retry(ctx); err != nil {
+ log.Errorf("failed to pick new Relay server: %s", err)
continue
}
return
- case <-g.ctx.Done():
+ case <-ctx.Done():
return
}
}
}
-func (g *Guard) quickReconnect() bool {
- ctx, cancel := context.WithTimeout(g.ctx, 1500*time.Millisecond)
+func (g *Guard) retry(ctx context.Context) error {
+ log.Infof("try to pick up a new Relay server")
+ relayClient, err := g.serverPicker.PickServer(ctx)
+ if err != nil {
+ return err
+ }
+
+ // prevent to work with a deprecated Relay client instance
+ g.drainRelayClientChan()
+
+ g.OnNewRelayClient <- relayClient
+ return nil
+}
+
+func (g *Guard) quickReconnect(parentCtx context.Context, rc *Client) bool {
+ ctx, cancel := context.WithTimeout(parentCtx, 1500*time.Millisecond)
defer cancel()
<-ctx.Done()
- if g.ctx.Err() != nil {
+ if parentCtx.Err() != nil {
return false
}
+ log.Infof("try to reconnect to Relay server: %s", rc.connectionURL)
- if err := g.relayClient.Connect(); err != nil {
+ if err := rc.Connect(); err != nil {
log.Errorf("failed to reconnect to relay server: %s", err)
return false
}
return true
}
+
+func (g *Guard) drainRelayClientChan() {
+ select {
+ case <-g.OnNewRelayClient:
+ default:
+ }
+}
+
+func (g *Guard) isServerURLStillValid(rc *Client) bool {
+ for _, url := range g.serverPicker.ServerURLs.Load().([]string) {
+ if url == rc.connectionURL {
+ return true
+ }
+ }
+ return false
+}
+
+func exponentTicker(ctx context.Context) *backoff.Ticker {
+ bo := backoff.WithContext(&backoff.ExponentialBackOff{
+ InitialInterval: 2 * time.Second,
+ Multiplier: 2,
+ MaxInterval: reconnectingTimeout,
+ Clock: backoff.SystemClock,
+ }, ctx)
+
+ return backoff.NewTicker(bo)
+}
diff --git a/relay/client/manager.go b/relay/client/manager.go
index b14a7701b..d847bb879 100644
--- a/relay/client/manager.go
+++ b/relay/client/manager.go
@@ -57,12 +57,15 @@ type ManagerService interface {
// relay servers will be closed if there is no active connection. Periodically the manager will check if there is any
// unused relay connection and close it.
type Manager struct {
- ctx context.Context
- serverURLs []string
- peerID string
- tokenStore *relayAuth.TokenStore
+ ctx context.Context
+ peerID string
+ running bool
+ tokenStore *relayAuth.TokenStore
+ serverPicker *ServerPicker
- relayClient *Client
+ relayClient *Client
+ // the guard logic can overwrite the relayClient variable, this mutex protect the usage of the variable
+ relayClientMu sync.Mutex
reconnectGuard *Guard
relayClients map[string]*RelayTrack
@@ -76,48 +79,54 @@ type Manager struct {
// NewManager creates a new manager instance.
// The serverURL address can be empty. In this case, the manager will not serve.
func NewManager(ctx context.Context, serverURLs []string, peerID string) *Manager {
- return &Manager{
- ctx: ctx,
- serverURLs: serverURLs,
- peerID: peerID,
- tokenStore: &relayAuth.TokenStore{},
+ tokenStore := &relayAuth.TokenStore{}
+
+ m := &Manager{
+ ctx: ctx,
+ peerID: peerID,
+ tokenStore: tokenStore,
+ serverPicker: &ServerPicker{
+ TokenStore: tokenStore,
+ PeerID: peerID,
+ },
relayClients: make(map[string]*RelayTrack),
onDisconnectedListeners: make(map[string]*list.List),
}
+ m.serverPicker.ServerURLs.Store(serverURLs)
+ m.reconnectGuard = NewGuard(m.serverPicker)
+ return m
}
-// Serve starts the manager. It will establish a connection to the relay server and start the relay cleanup loop for
-// the unused relay connections. The manager will automatically reconnect to the relay server in case of disconnection.
+// Serve starts the manager, attempting to establish a connection with the relay server.
+// If the connection fails, it will keep trying to reconnect in the background.
+// Additionally, it starts a cleanup loop to remove unused relay connections.
+// The manager will automatically reconnect to the relay server in case of disconnection.
func (m *Manager) Serve() error {
- if m.relayClient != nil {
+ if m.running {
return fmt.Errorf("manager already serving")
}
- log.Debugf("starting relay client manager with %v relay servers", m.serverURLs)
+ m.running = true
+ log.Debugf("starting relay client manager with %v relay servers", m.serverPicker.ServerURLs.Load())
- sp := ServerPicker{
- TokenStore: m.tokenStore,
- PeerID: m.peerID,
- }
-
- client, err := sp.PickServer(m.ctx, m.serverURLs)
+ client, err := m.serverPicker.PickServer(m.ctx)
if err != nil {
- return err
+ go m.reconnectGuard.StartReconnectTrys(m.ctx, nil)
+ } else {
+ m.storeClient(client)
}
- m.relayClient = client
- m.reconnectGuard = NewGuard(m.ctx, m.relayClient)
- m.relayClient.SetOnConnectedListener(m.onServerConnected)
- m.relayClient.SetOnDisconnectListener(func() {
- m.onServerDisconnected(client.connectionURL)
- })
- m.startCleanupLoop()
- return nil
+ go m.listenGuardEvent(m.ctx)
+ go m.startCleanupLoop()
+ return err
}
// OpenConn opens a connection to the given peer key. If the peer is on the same relay server, the connection will be
// established via the relay server. If the peer is on a different relay server, the manager will establish a new
// connection to the relay server. It returns back with a net.Conn what represent the remote peer connection.
func (m *Manager) OpenConn(serverAddress, peerKey string) (net.Conn, error) {
+ m.relayClientMu.Lock()
+ defer m.relayClientMu.Unlock()
+
if m.relayClient == nil {
return nil, ErrRelayClientNotConnected
}
@@ -146,6 +155,9 @@ func (m *Manager) OpenConn(serverAddress, peerKey string) (net.Conn, error) {
// Ready returns true if the home Relay client is connected to the relay server.
func (m *Manager) Ready() bool {
+ m.relayClientMu.Lock()
+ defer m.relayClientMu.Unlock()
+
if m.relayClient == nil {
return false
}
@@ -159,6 +171,13 @@ func (m *Manager) SetOnReconnectedListener(f func()) {
// AddCloseListener adds a listener to the given server instance address. The listener will be called if the connection
// closed.
func (m *Manager) AddCloseListener(serverAddress string, onClosedListener OnServerCloseListener) error {
+ m.relayClientMu.Lock()
+ defer m.relayClientMu.Unlock()
+
+ if m.relayClient == nil {
+ return ErrRelayClientNotConnected
+ }
+
foreign, err := m.isForeignServer(serverAddress)
if err != nil {
return err
@@ -177,6 +196,9 @@ func (m *Manager) AddCloseListener(serverAddress string, onClosedListener OnServ
// RelayInstanceAddress returns the address of the permanent relay server. It could change if the network connection is
// lost. This address will be sent to the target peer to choose the common relay server for the communication.
func (m *Manager) RelayInstanceAddress() (string, error) {
+ m.relayClientMu.Lock()
+ defer m.relayClientMu.Unlock()
+
if m.relayClient == nil {
return "", ErrRelayClientNotConnected
}
@@ -185,13 +207,18 @@ func (m *Manager) RelayInstanceAddress() (string, error) {
// ServerURLs returns the addresses of the relay servers.
func (m *Manager) ServerURLs() []string {
- return m.serverURLs
+ return m.serverPicker.ServerURLs.Load().([]string)
}
// HasRelayAddress returns true if the manager is serving. With this method can check if the peer can communicate with
// Relay service.
func (m *Manager) HasRelayAddress() bool {
- return len(m.serverURLs) > 0
+ return len(m.serverPicker.ServerURLs.Load().([]string)) > 0
+}
+
+func (m *Manager) UpdateServerURLs(serverURLs []string) {
+ log.Infof("update relay server URLs: %v", serverURLs)
+ m.serverPicker.ServerURLs.Store(serverURLs)
}
// UpdateToken updates the token in the token store.
@@ -245,9 +272,7 @@ func (m *Manager) openConnVia(serverAddress, peerKey string) (net.Conn, error) {
return nil, err
}
// if connection closed then delete the relay client from the list
- relayClient.SetOnDisconnectListener(func() {
- m.onServerDisconnected(serverAddress)
- })
+ relayClient.SetOnDisconnectListener(m.onServerDisconnected)
rt.relayClient = relayClient
rt.Unlock()
@@ -265,14 +290,37 @@ func (m *Manager) onServerConnected() {
go m.onReconnectedListenerFn()
}
+// onServerDisconnected start to reconnection for home server only
func (m *Manager) onServerDisconnected(serverAddress string) {
+ m.relayClientMu.Lock()
if serverAddress == m.relayClient.connectionURL {
- go m.reconnectGuard.OnDisconnected()
+ go m.reconnectGuard.StartReconnectTrys(m.ctx, m.relayClient)
}
+ m.relayClientMu.Unlock()
m.notifyOnDisconnectListeners(serverAddress)
}
+func (m *Manager) listenGuardEvent(ctx context.Context) {
+ for {
+ select {
+ case rc := <-m.reconnectGuard.OnNewRelayClient:
+ m.storeClient(rc)
+ case <-ctx.Done():
+ return
+ }
+ }
+}
+
+func (m *Manager) storeClient(client *Client) {
+ m.relayClientMu.Lock()
+ defer m.relayClientMu.Unlock()
+
+ m.relayClient = client
+ m.relayClient.SetOnConnectedListener(m.onServerConnected)
+ m.relayClient.SetOnDisconnectListener(m.onServerDisconnected)
+}
+
func (m *Manager) isForeignServer(address string) (bool, error) {
rAddr, err := m.relayClient.ServerInstanceURL()
if err != nil {
@@ -282,22 +330,16 @@ func (m *Manager) isForeignServer(address string) (bool, error) {
}
func (m *Manager) startCleanupLoop() {
- if m.ctx.Err() != nil {
- return
- }
-
ticker := time.NewTicker(relayCleanupInterval)
- go func() {
- defer ticker.Stop()
- for {
- select {
- case <-m.ctx.Done():
- return
- case <-ticker.C:
- m.cleanUpUnusedRelays()
- }
+ defer ticker.Stop()
+ for {
+ select {
+ case <-m.ctx.Done():
+ return
+ case <-ticker.C:
+ m.cleanUpUnusedRelays()
}
- }()
+ }
}
func (m *Manager) cleanUpUnusedRelays() {
diff --git a/relay/client/picker.go b/relay/client/picker.go
index 13b0547aa..eb5062dbb 100644
--- a/relay/client/picker.go
+++ b/relay/client/picker.go
@@ -4,6 +4,7 @@ import (
"context"
"errors"
"fmt"
+ "sync/atomic"
"time"
log "github.com/sirupsen/logrus"
@@ -12,10 +13,13 @@ import (
)
const (
- connectionTimeout = 30 * time.Second
maxConcurrentServers = 7
)
+var (
+ connectionTimeout = 30 * time.Second
+)
+
type connResult struct {
RelayClient *Client
Url string
@@ -24,20 +28,22 @@ type connResult struct {
type ServerPicker struct {
TokenStore *auth.TokenStore
+ ServerURLs atomic.Value
PeerID string
}
-func (sp *ServerPicker) PickServer(parentCtx context.Context, urls []string) (*Client, error) {
+func (sp *ServerPicker) PickServer(parentCtx context.Context) (*Client, error) {
ctx, cancel := context.WithTimeout(parentCtx, connectionTimeout)
defer cancel()
- totalServers := len(urls)
+ totalServers := len(sp.ServerURLs.Load().([]string))
connResultChan := make(chan connResult, totalServers)
successChan := make(chan connResult, 1)
concurrentLimiter := make(chan struct{}, maxConcurrentServers)
- for _, url := range urls {
+ log.Debugf("pick server from list: %v", sp.ServerURLs.Load().([]string))
+ for _, url := range sp.ServerURLs.Load().([]string) {
// todo check if we have a successful connection so we do not need to connect to other servers
concurrentLimiter <- struct{}{}
go func(url string) {
@@ -78,7 +84,7 @@ func (sp *ServerPicker) processConnResults(resultChan chan connResult, successCh
for numOfResults := 0; numOfResults < cap(resultChan); numOfResults++ {
cr := <-resultChan
if cr.Err != nil {
- log.Debugf("failed to connect to Relay server: %s: %v", cr.Url, cr.Err)
+ log.Tracef("failed to connect to Relay server: %s: %v", cr.Url, cr.Err)
continue
}
log.Infof("connected to Relay server: %s", cr.Url)
diff --git a/relay/client/picker_test.go b/relay/client/picker_test.go
index 4800e05ba..20a03e64d 100644
--- a/relay/client/picker_test.go
+++ b/relay/client/picker_test.go
@@ -7,16 +7,19 @@ import (
)
func TestServerPicker_UnavailableServers(t *testing.T) {
+ connectionTimeout = 5 * time.Second
+
sp := ServerPicker{
TokenStore: nil,
PeerID: "test",
}
+ sp.ServerURLs.Store([]string{"rel://dummy1", "rel://dummy2"})
ctx, cancel := context.WithTimeout(context.Background(), connectionTimeout+1)
defer cancel()
go func() {
- _, err := sp.PickServer(ctx, []string{"rel://dummy1", "rel://dummy2"})
+ _, err := sp.PickServer(ctx)
if err == nil {
t.Error(err)
}
From 05c4aa7c2cad19f3679bf2548f8dc296dd1e043b Mon Sep 17 00:00:00 2001
From: Maycon Santos
Date: Fri, 22 Nov 2024 18:50:47 +0100
Subject: [PATCH 31/39] [misc] Renew slack link (#2938)
---
README.md | 4 ++--
1 file changed, 2 insertions(+), 2 deletions(-)
diff --git a/README.md b/README.md
index a2d7f3897..e7925ae09 100644
--- a/README.md
+++ b/README.md
@@ -17,7 +17,7 @@
-
+
@@ -34,7 +34,7 @@
See Documentation
- Join our Slack channel
+ Join our Slack channel
From 56cecf849ea4f6092b0dc9b421126da399bffa95 Mon Sep 17 00:00:00 2001
From: Maycon Santos
Date: Fri, 22 Nov 2024 20:40:30 +0100
Subject: [PATCH 32/39] Import time package (#2940)
---
relay/client/picker_test.go | 1 +
1 file changed, 1 insertion(+)
diff --git a/relay/client/picker_test.go b/relay/client/picker_test.go
index 20a03e64d..28167c5ce 100644
--- a/relay/client/picker_test.go
+++ b/relay/client/picker_test.go
@@ -4,6 +4,7 @@ import (
"context"
"errors"
"testing"
+ "time"
)
func TestServerPicker_UnavailableServers(t *testing.T) {
From 940d0c48c69198803b4cd88125214c5cb666bf2b Mon Sep 17 00:00:00 2001
From: Viktor Liu <17948409+lixmal@users.noreply.github.com>
Date: Mon, 25 Nov 2024 15:11:31 +0100
Subject: [PATCH 33/39] [client] Don't return error in userspace mode without
firewall (#2924)
---
client/firewall/uspfilter/uspfilter.go | 2 +-
1 file changed, 1 insertion(+), 1 deletion(-)
diff --git a/client/firewall/uspfilter/uspfilter.go b/client/firewall/uspfilter/uspfilter.go
index af5dc6733..fb726395b 100644
--- a/client/firewall/uspfilter/uspfilter.go
+++ b/client/firewall/uspfilter/uspfilter.go
@@ -239,7 +239,7 @@ func (m *Manager) DeletePeerRule(rule firewall.Rule) error {
// SetLegacyManagement doesn't need to be implemented for this manager
func (m *Manager) SetLegacyManagement(isLegacy bool) error {
if m.nativeFirewall == nil {
- return errRouteNotSupported
+ return nil
}
return m.nativeFirewall.SetLegacyManagement(isLegacy)
}
From 0ecd5f211850d50dcf7179adbf3ae0246705f0a3 Mon Sep 17 00:00:00 2001
From: Viktor Liu <17948409+lixmal@users.noreply.github.com>
Date: Mon, 25 Nov 2024 15:11:56 +0100
Subject: [PATCH 34/39] [client] Test nftables for incompatible iptables rules
(#2948)
---
.../firewall/nftables/manager_linux_test.go | 104 ++++++++++++++++++
1 file changed, 104 insertions(+)
diff --git a/client/firewall/nftables/manager_linux_test.go b/client/firewall/nftables/manager_linux_test.go
index 77f4f0306..33fdc4b3d 100644
--- a/client/firewall/nftables/manager_linux_test.go
+++ b/client/firewall/nftables/manager_linux_test.go
@@ -1,9 +1,11 @@
package nftables
import (
+ "bytes"
"fmt"
"net"
"net/netip"
+ "os/exec"
"testing"
"time"
@@ -225,3 +227,105 @@ func TestNFtablesCreatePerformance(t *testing.T) {
})
}
}
+
+func runIptablesSave(t *testing.T) (string, string) {
+ t.Helper()
+ var stdout, stderr bytes.Buffer
+ cmd := exec.Command("iptables-save")
+ cmd.Stdout = &stdout
+ cmd.Stderr = &stderr
+
+ err := cmd.Run()
+ require.NoError(t, err, "iptables-save failed to run")
+
+ return stdout.String(), stderr.String()
+}
+
+func verifyIptablesOutput(t *testing.T, stdout, stderr string) {
+ t.Helper()
+ // Check for any incompatibility warnings
+ require.NotContains(t,
+ stderr,
+ "incompatible",
+ "iptables-save produced compatibility warning. Full stderr: %s",
+ stderr,
+ )
+
+ // Verify standard tables are present
+ expectedTables := []string{
+ "*filter",
+ "*nat",
+ "*mangle",
+ }
+
+ for _, table := range expectedTables {
+ require.Contains(t,
+ stdout,
+ table,
+ "iptables-save output missing expected table: %s\nFull stdout: %s",
+ table,
+ stdout,
+ )
+ }
+}
+
+func TestNftablesManagerCompatibilityWithIptables(t *testing.T) {
+ if check() != NFTABLES {
+ t.Skip("nftables not supported on this system")
+ }
+
+ if _, err := exec.LookPath("iptables-save"); err != nil {
+ t.Skipf("iptables-save not available on this system: %v", err)
+ }
+
+ // First ensure iptables-nft tables exist by running iptables-save
+ stdout, stderr := runIptablesSave(t)
+ verifyIptablesOutput(t, stdout, stderr)
+
+ manager, err := Create(ifaceMock)
+ require.NoError(t, err, "failed to create manager")
+ require.NoError(t, manager.Init(nil))
+
+ t.Cleanup(func() {
+ err := manager.Reset(nil)
+ require.NoError(t, err, "failed to reset manager state")
+
+ // Verify iptables output after reset
+ stdout, stderr := runIptablesSave(t)
+ verifyIptablesOutput(t, stdout, stderr)
+ })
+
+ ip := net.ParseIP("100.96.0.1")
+ _, err = manager.AddPeerFiltering(
+ ip,
+ fw.ProtocolTCP,
+ nil,
+ &fw.Port{Values: []int{80}},
+ fw.RuleDirectionIN,
+ fw.ActionAccept,
+ "",
+ "test rule",
+ )
+ require.NoError(t, err, "failed to add peer filtering rule")
+
+ _, err = manager.AddRouteFiltering(
+ []netip.Prefix{netip.MustParsePrefix("192.168.2.0/24")},
+ netip.MustParsePrefix("10.1.0.0/24"),
+ fw.ProtocolTCP,
+ nil,
+ &fw.Port{Values: []int{443}},
+ fw.ActionAccept,
+ )
+ require.NoError(t, err, "failed to add route filtering rule")
+
+ pair := fw.RouterPair{
+ Source: netip.MustParsePrefix("192.168.1.0/24"),
+ Destination: netip.MustParsePrefix("10.0.0.0/24"),
+ Masquerade: true,
+ }
+ err = manager.AddNatRule(pair)
+ require.NoError(t, err, "failed to add NAT rule")
+
+ stdout, stderr = runIptablesSave(t)
+ verifyIptablesOutput(t, stdout, stderr)
+}
From f1625b32bdd6c1fe0abb73756835947b99d1a97f Mon Sep 17 00:00:00 2001
From: Viktor Liu <17948409+lixmal@users.noreply.github.com>
Date: Mon, 25 Nov 2024 15:12:16 +0100
Subject: [PATCH 35/39] [client] Set up sysctl and routing table name only if
routing rules are available (#2933)
---
.../routemanager/systemops/systemops_linux.go | 22 +++++++++----------
1 file changed, 11 insertions(+), 11 deletions(-)
diff --git a/client/internal/routemanager/systemops/systemops_linux.go b/client/internal/routemanager/systemops/systemops_linux.go
index 71a0f26ae..ac4fd5c71 100644
--- a/client/internal/routemanager/systemops/systemops_linux.go
+++ b/client/internal/routemanager/systemops/systemops_linux.go
@@ -92,17 +92,6 @@ func (r *SysOps) SetupRouting(initAddresses []net.IP, stateManager *statemanager
return r.setupRefCounter(initAddresses, stateManager)
}
- if err = addRoutingTableName(); err != nil {
- log.Errorf("Error adding routing table name: %v", err)
- }
-
- originalValues, err := sysctl.Setup(r.wgInterface)
- if err != nil {
- log.Errorf("Error setting up sysctl: %v", err)
- sysctlFailed = true
- }
- originalSysctl = originalValues
-
defer func() {
if err != nil {
if cleanErr := r.CleanupRouting(stateManager); cleanErr != nil {
@@ -123,6 +112,17 @@ func (r *SysOps) SetupRouting(initAddresses []net.IP, stateManager *statemanager
}
}
+ if err = addRoutingTableName(); err != nil {
+ log.Errorf("Error adding routing table name: %v", err)
+ }
+
+ originalValues, err := sysctl.Setup(r.wgInterface)
+ if err != nil {
+ log.Errorf("Error setting up sysctl: %v", err)
+ sysctlFailed = true
+ }
+ originalSysctl = originalValues
+
return nil, nil, nil
}
From 9810386937edd1665109479f745e6faaf21db840 Mon Sep 17 00:00:00 2001
From: Viktor Liu <17948409+lixmal@users.noreply.github.com>
Date: Mon, 25 Nov 2024 15:19:56 +0100
Subject: [PATCH 36/39] [client] Allow routing to fallback to exclusion routes
if rules are not supported (#2909)
---
client/internal/routemanager/systemops/systemops_linux.go | 4 ++--
1 file changed, 2 insertions(+), 2 deletions(-)
diff --git a/client/internal/routemanager/systemops/systemops_linux.go b/client/internal/routemanager/systemops/systemops_linux.go
index ac4fd5c71..1d629d6e9 100644
--- a/client/internal/routemanager/systemops/systemops_linux.go
+++ b/client/internal/routemanager/systemops/systemops_linux.go
@@ -450,7 +450,7 @@ func addRule(params ruleParams) error {
rule.Invert = params.invert
rule.SuppressPrefixlen = params.suppressPrefix
- if err := netlink.RuleAdd(rule); err != nil && !errors.Is(err, syscall.EEXIST) && !errors.Is(err, syscall.EAFNOSUPPORT) {
+ if err := netlink.RuleAdd(rule); err != nil && !errors.Is(err, syscall.EEXIST) {
return fmt.Errorf("add routing rule: %w", err)
}
@@ -467,7 +467,7 @@ func removeRule(params ruleParams) error {
rule.Priority = params.priority
rule.SuppressPrefixlen = params.suppressPrefix
- if err := netlink.RuleDel(rule); err != nil && !errors.Is(err, syscall.ENOENT) && !errors.Is(err, syscall.EAFNOSUPPORT) {
+ if err := netlink.RuleDel(rule); err != nil && !errors.Is(err, syscall.ENOENT) {
return fmt.Errorf("remove routing rule: %w", err)
}
From ca12bc6953b8ba0e8645b90dc2ef8d50c8c38a01 Mon Sep 17 00:00:00 2001
From: Bethuel Mmbaga
Date: Mon, 25 Nov 2024 18:26:24 +0300
Subject: [PATCH 37/39] [management] Refactor posture check to use store
methods (#2874)
---
management/server/account.go | 2 +-
management/server/dns.go | 2 +-
management/server/group.go | 19 +-
.../server/http/posture_checks_handler.go | 3 +-
.../http/posture_checks_handler_test.go | 6 +-
management/server/mock_server/account_mock.go | 6 +-
management/server/nameserver.go | 10 +-
management/server/peer.go | 25 +-
management/server/policy.go | 6 +-
management/server/posture/checks.go | 6 -
management/server/posture_checks.go | 337 +++++++++++-------
management/server/posture_checks_test.go | 221 +++++++-----
management/server/route.go | 10 +-
management/server/sql_store.go | 51 ++-
management/server/sql_store_test.go | 135 +++++++
management/server/status/error.go | 5 +
management/server/store.go | 4 +-
management/server/testdata/extended-store.sql | 1 +
18 files changed, 589 insertions(+), 260 deletions(-)
diff --git a/management/server/account.go b/management/server/account.go
index 0ab123655..9fb56c855 100644
--- a/management/server/account.go
+++ b/management/server/account.go
@@ -139,7 +139,7 @@ type AccountManager interface {
HasConnectedChannel(peerID string) bool
GetExternalCacheManager() ExternalCacheManager
GetPostureChecks(ctx context.Context, accountID, postureChecksID, userID string) (*posture.Checks, error)
- SavePostureChecks(ctx context.Context, accountID, userID string, postureChecks *posture.Checks) error
+ SavePostureChecks(ctx context.Context, accountID, userID string, postureChecks *posture.Checks) (*posture.Checks, error)
DeletePostureChecks(ctx context.Context, accountID, postureChecksID, userID string) error
ListPostureChecks(ctx context.Context, accountID, userID string) ([]*posture.Checks, error)
GetIdpManager() idp.Manager
diff --git a/management/server/dns.go b/management/server/dns.go
index 4551be5ab..e52be6016 100644
--- a/management/server/dns.go
+++ b/management/server/dns.go
@@ -145,7 +145,7 @@ func (am *DefaultAccountManager) SaveDNSSettings(ctx context.Context, accountID
am.StoreEvent(ctx, userID, accountID, accountID, activity.GroupRemovedFromDisabledManagementGroups, meta)
}
- if anyGroupHasPeers(account, addedGroups) || anyGroupHasPeers(account, removedGroups) {
+ if am.anyGroupHasPeers(account, addedGroups) || am.anyGroupHasPeers(account, removedGroups) {
am.updateAccountPeers(ctx, accountID)
}
diff --git a/management/server/group.go b/management/server/group.go
index a36213f04..7b307cf1a 100644
--- a/management/server/group.go
+++ b/management/server/group.go
@@ -566,8 +566,7 @@ func areGroupChangesAffectPeers(ctx context.Context, transaction Store, accountI
return false, nil
}
-// anyGroupHasPeers checks if any of the given groups in the account have peers.
-func anyGroupHasPeers(account *Account, groupIDs []string) bool {
+func (am *DefaultAccountManager) anyGroupHasPeers(account *Account, groupIDs []string) bool {
for _, groupID := range groupIDs {
if group, exists := account.Groups[groupID]; exists && group.HasPeers() {
return true
@@ -575,3 +574,19 @@ func anyGroupHasPeers(account *Account, groupIDs []string) bool {
}
return false
}
+
+// anyGroupHasPeers checks if any of the given groups in the account have peers.
+func anyGroupHasPeers(ctx context.Context, transaction Store, accountID string, groupIDs []string) (bool, error) {
+ groups, err := transaction.GetGroupsByIDs(ctx, LockingStrengthShare, accountID, groupIDs)
+ if err != nil {
+ return false, err
+ }
+
+ for _, group := range groups {
+ if group.HasPeers() {
+ return true, nil
+ }
+ }
+
+ return false, nil
+}
diff --git a/management/server/http/posture_checks_handler.go b/management/server/http/posture_checks_handler.go
index 1d020e9bc..2c8204292 100644
--- a/management/server/http/posture_checks_handler.go
+++ b/management/server/http/posture_checks_handler.go
@@ -169,7 +169,8 @@ func (p *PostureChecksHandler) savePostureChecks(w http.ResponseWriter, r *http.
return
}
- if err := p.accountManager.SavePostureChecks(r.Context(), accountID, userID, postureChecks); err != nil {
+ postureChecks, err = p.accountManager.SavePostureChecks(r.Context(), accountID, userID, postureChecks)
+ if err != nil {
util.WriteError(r.Context(), err, w)
return
}
diff --git a/management/server/http/posture_checks_handler_test.go b/management/server/http/posture_checks_handler_test.go
index 02f0f0d83..f400cec81 100644
--- a/management/server/http/posture_checks_handler_test.go
+++ b/management/server/http/posture_checks_handler_test.go
@@ -40,15 +40,15 @@ func initPostureChecksTestData(postureChecks ...*posture.Checks) *PostureChecksH
}
return p, nil
},
- SavePostureChecksFunc: func(_ context.Context, accountID, userID string, postureChecks *posture.Checks) error {
+ SavePostureChecksFunc: func(_ context.Context, accountID, userID string, postureChecks *posture.Checks) (*posture.Checks, error) {
postureChecks.ID = "postureCheck"
testPostureChecks[postureChecks.ID] = postureChecks
if err := postureChecks.Validate(); err != nil {
- return status.Errorf(status.InvalidArgument, err.Error()) //nolint
+ return nil, status.Errorf(status.InvalidArgument, err.Error()) //nolint
}
- return nil
+ return postureChecks, nil
},
DeletePostureChecksFunc: func(_ context.Context, accountID, postureChecksID, userID string) error {
_, ok := testPostureChecks[postureChecksID]
diff --git a/management/server/mock_server/account_mock.go b/management/server/mock_server/account_mock.go
index aa6a47b15..673ed33bb 100644
--- a/management/server/mock_server/account_mock.go
+++ b/management/server/mock_server/account_mock.go
@@ -96,7 +96,7 @@ type MockAccountManager struct {
HasConnectedChannelFunc func(peerID string) bool
GetExternalCacheManagerFunc func() server.ExternalCacheManager
GetPostureChecksFunc func(ctx context.Context, accountID, postureChecksID, userID string) (*posture.Checks, error)
- SavePostureChecksFunc func(ctx context.Context, accountID, userID string, postureChecks *posture.Checks) error
+ SavePostureChecksFunc func(ctx context.Context, accountID, userID string, postureChecks *posture.Checks) (*posture.Checks, error)
DeletePostureChecksFunc func(ctx context.Context, accountID, postureChecksID, userID string) error
ListPostureChecksFunc func(ctx context.Context, accountID, userID string) ([]*posture.Checks, error)
GetIdpManagerFunc func() idp.Manager
@@ -730,11 +730,11 @@ func (am *MockAccountManager) GetPostureChecks(ctx context.Context, accountID, p
}
// SavePostureChecks mocks SavePostureChecks of the AccountManager interface
-func (am *MockAccountManager) SavePostureChecks(ctx context.Context, accountID, userID string, postureChecks *posture.Checks) error {
+func (am *MockAccountManager) SavePostureChecks(ctx context.Context, accountID, userID string, postureChecks *posture.Checks) (*posture.Checks, error) {
if am.SavePostureChecksFunc != nil {
return am.SavePostureChecksFunc(ctx, accountID, userID, postureChecks)
}
- return status.Errorf(codes.Unimplemented, "method SavePostureChecks is not implemented")
+ return nil, status.Errorf(codes.Unimplemented, "method SavePostureChecks is not implemented")
}
// DeletePostureChecks mocks DeletePostureChecks of the AccountManager interface
diff --git a/management/server/nameserver.go b/management/server/nameserver.go
index 957008714..9119a3dec 100644
--- a/management/server/nameserver.go
+++ b/management/server/nameserver.go
@@ -70,7 +70,7 @@ func (am *DefaultAccountManager) CreateNameServerGroup(ctx context.Context, acco
return nil, err
}
- if anyGroupHasPeers(account, newNSGroup.Groups) {
+ if am.anyGroupHasPeers(account, newNSGroup.Groups) {
am.updateAccountPeers(ctx, accountID)
}
am.StoreEvent(ctx, userID, newNSGroup.ID, accountID, activity.NameserverGroupCreated, newNSGroup.EventMeta())
@@ -105,7 +105,7 @@ func (am *DefaultAccountManager) SaveNameServerGroup(ctx context.Context, accoun
return err
}
- if areNameServerGroupChangesAffectPeers(account, nsGroupToSave, oldNSGroup) {
+ if am.areNameServerGroupChangesAffectPeers(account, nsGroupToSave, oldNSGroup) {
am.updateAccountPeers(ctx, accountID)
}
am.StoreEvent(ctx, userID, nsGroupToSave.ID, accountID, activity.NameserverGroupUpdated, nsGroupToSave.EventMeta())
@@ -135,7 +135,7 @@ func (am *DefaultAccountManager) DeleteNameServerGroup(ctx context.Context, acco
return err
}
- if anyGroupHasPeers(account, nsGroup.Groups) {
+ if am.anyGroupHasPeers(account, nsGroup.Groups) {
am.updateAccountPeers(ctx, accountID)
}
am.StoreEvent(ctx, userID, nsGroup.ID, accountID, activity.NameserverGroupDeleted, nsGroup.EventMeta())
@@ -279,9 +279,9 @@ func validateDomain(domain string) error {
}
// areNameServerGroupChangesAffectPeers checks if the changes in the nameserver group affect the peers.
-func areNameServerGroupChangesAffectPeers(account *Account, newNSGroup, oldNSGroup *nbdns.NameServerGroup) bool {
+func (am *DefaultAccountManager) areNameServerGroupChangesAffectPeers(account *Account, newNSGroup, oldNSGroup *nbdns.NameServerGroup) bool {
if !newNSGroup.Enabled && !oldNSGroup.Enabled {
return false
}
- return anyGroupHasPeers(account, newNSGroup.Groups) || anyGroupHasPeers(account, oldNSGroup.Groups)
+ return am.anyGroupHasPeers(account, newNSGroup.Groups) || am.anyGroupHasPeers(account, oldNSGroup.Groups)
}
diff --git a/management/server/peer.go b/management/server/peer.go
index beb833dba..dcb47af3b 100644
--- a/management/server/peer.go
+++ b/management/server/peer.go
@@ -617,7 +617,11 @@ func (am *DefaultAccountManager) AddPeer(ctx context.Context, setupKey, userID s
return nil, nil, nil, err
}
- postureChecks := am.getPeerPostureChecks(account, newPeer)
+ postureChecks, err := am.getPeerPostureChecks(ctx, account.Id, newPeer.ID)
+ if err != nil {
+ return nil, nil, nil, err
+ }
+
customZone := account.GetPeersCustomZone(ctx, am.dnsDomain)
networkMap := account.GetPeerNetworkMap(ctx, newPeer.ID, customZone, approvedPeersMap, am.metrics.AccountManagerMetrics())
return newPeer, networkMap, postureChecks, nil
@@ -702,7 +706,11 @@ func (am *DefaultAccountManager) SyncPeer(ctx context.Context, sync PeerSync, ac
if err != nil {
return nil, nil, nil, fmt.Errorf("failed to get validated peers: %w", err)
}
- postureChecks = am.getPeerPostureChecks(account, peer)
+
+ postureChecks, err = am.getPeerPostureChecks(ctx, account.Id, peer.ID)
+ if err != nil {
+ return nil, nil, nil, err
+ }
customZone := account.GetPeersCustomZone(ctx, am.dnsDomain)
return peer, account.GetPeerNetworkMap(ctx, peer.ID, customZone, validPeersMap, am.metrics.AccountManagerMetrics()), postureChecks, nil
@@ -876,7 +884,11 @@ func (am *DefaultAccountManager) getValidatedPeerWithMap(ctx context.Context, is
if err != nil {
return nil, nil, nil, err
}
- postureChecks = am.getPeerPostureChecks(account, peer)
+
+ postureChecks, err = am.getPeerPostureChecks(ctx, account.Id, peer.ID)
+ if err != nil {
+ return nil, nil, nil, err
+ }
customZone := account.GetPeersCustomZone(ctx, am.dnsDomain)
return peer, account.GetPeerNetworkMap(ctx, peer.ID, customZone, approvedPeersMap, am.metrics.AccountManagerMetrics()), postureChecks, nil
@@ -1030,7 +1042,12 @@ func (am *DefaultAccountManager) updateAccountPeers(ctx context.Context, account
defer wg.Done()
defer func() { <-semaphore }()
- postureChecks := am.getPeerPostureChecks(account, p)
+ postureChecks, err := am.getPeerPostureChecks(ctx, account.Id, p.ID)
+ if err != nil {
+ log.WithContext(ctx).Errorf("failed to send out updates to peers, failed to get peer: %s posture checks: %v", p.ID, err)
+ return
+ }
+
remotePeerNetworkMap := account.GetPeerNetworkMap(ctx, p.ID, customZone, approvedPeersMap, am.metrics.AccountManagerMetrics())
update := toSyncResponse(ctx, nil, p, nil, nil, remotePeerNetworkMap, am.GetDNSDomain(), postureChecks, dnsCache)
am.peersUpdateManager.SendUpdate(ctx, p.ID, &UpdateMessage{Update: update, NetworkMap: remotePeerNetworkMap})
diff --git a/management/server/policy.go b/management/server/policy.go
index 8a5733f01..c7872591d 100644
--- a/management/server/policy.go
+++ b/management/server/policy.go
@@ -405,7 +405,7 @@ func (am *DefaultAccountManager) DeletePolicy(ctx context.Context, accountID, po
am.StoreEvent(ctx, userID, policy.ID, accountID, activity.PolicyRemoved, policy.EventMeta())
- if anyGroupHasPeers(account, policy.ruleGroups()) {
+ if am.anyGroupHasPeers(account, policy.ruleGroups()) {
am.updateAccountPeers(ctx, accountID)
}
@@ -469,7 +469,7 @@ func (am *DefaultAccountManager) savePolicy(account *Account, policyToSave *Poli
if !policyToSave.Enabled && !oldPolicy.Enabled {
return false, nil
}
- updateAccountPeers := anyGroupHasPeers(account, oldPolicy.ruleGroups()) || anyGroupHasPeers(account, policyToSave.ruleGroups())
+ updateAccountPeers := am.anyGroupHasPeers(account, oldPolicy.ruleGroups()) || am.anyGroupHasPeers(account, policyToSave.ruleGroups())
return updateAccountPeers, nil
}
@@ -477,7 +477,7 @@ func (am *DefaultAccountManager) savePolicy(account *Account, policyToSave *Poli
// Add the new policy to the account
account.Policies = append(account.Policies, policyToSave)
- return anyGroupHasPeers(account, policyToSave.ruleGroups()), nil
+ return am.anyGroupHasPeers(account, policyToSave.ruleGroups()), nil
}
func toProtocolFirewallRules(rules []*FirewallRule) []*proto.FirewallRule {
diff --git a/management/server/posture/checks.go b/management/server/posture/checks.go
index f2739dddf..b2f308d76 100644
--- a/management/server/posture/checks.go
+++ b/management/server/posture/checks.go
@@ -7,8 +7,6 @@ import (
"regexp"
"github.com/hashicorp/go-version"
- "github.com/rs/xid"
-
"github.com/netbirdio/netbird/management/server/http/api"
nbpeer "github.com/netbirdio/netbird/management/server/peer"
"github.com/netbirdio/netbird/management/server/status"
@@ -172,10 +170,6 @@ func NewChecksFromAPIPostureCheckUpdate(source api.PostureCheckUpdate, postureCh
}
func buildPostureCheck(postureChecksID string, name string, description string, checks api.Checks) (*Checks, error) {
- if postureChecksID == "" {
- postureChecksID = xid.New().String()
- }
-
postureChecks := Checks{
ID: postureChecksID,
Name: name,
diff --git a/management/server/posture_checks.go b/management/server/posture_checks.go
index 096cff3f5..59e726c41 100644
--- a/management/server/posture_checks.go
+++ b/management/server/posture_checks.go
@@ -2,16 +2,14 @@ package server
import (
"context"
+ "fmt"
"slices"
"github.com/netbirdio/netbird/management/server/activity"
- nbpeer "github.com/netbirdio/netbird/management/server/peer"
"github.com/netbirdio/netbird/management/server/posture"
"github.com/netbirdio/netbird/management/server/status"
-)
-
-const (
- errMsgPostureAdminOnly = "only users with admin power are allowed to view posture checks"
+ "github.com/rs/xid"
+ "golang.org/x/exp/maps"
)
func (am *DefaultAccountManager) GetPostureChecks(ctx context.Context, accountID, postureChecksID, userID string) (*posture.Checks, error) {
@@ -20,219 +18,284 @@ func (am *DefaultAccountManager) GetPostureChecks(ctx context.Context, accountID
return nil, err
}
- if !user.HasAdminPower() || user.AccountID != accountID {
- return nil, status.Errorf(status.PermissionDenied, errMsgPostureAdminOnly)
- }
-
- return am.Store.GetPostureChecksByID(ctx, LockingStrengthShare, postureChecksID, accountID)
-}
-
-func (am *DefaultAccountManager) SavePostureChecks(ctx context.Context, accountID, userID string, postureChecks *posture.Checks) error {
- unlock := am.Store.AcquireWriteLockByUID(ctx, accountID)
- defer unlock()
-
- account, err := am.Store.GetAccount(ctx, accountID)
- if err != nil {
- return err
- }
-
- user, err := account.FindUser(userID)
- if err != nil {
- return err
+ if user.AccountID != accountID {
+ return nil, status.NewUserNotPartOfAccountError()
}
if !user.HasAdminPower() {
- return status.Errorf(status.PermissionDenied, errMsgPostureAdminOnly)
+ return nil, status.NewAdminPermissionError()
}
- if err := postureChecks.Validate(); err != nil {
- return status.Errorf(status.InvalidArgument, err.Error()) //nolint
+ return am.Store.GetPostureChecksByID(ctx, LockingStrengthShare, accountID, postureChecksID)
+}
+
+// SavePostureChecks saves a posture check.
+func (am *DefaultAccountManager) SavePostureChecks(ctx context.Context, accountID, userID string, postureChecks *posture.Checks) (*posture.Checks, error) {
+ unlock := am.Store.AcquireWriteLockByUID(ctx, accountID)
+ defer unlock()
+
+ user, err := am.Store.GetUserByUserID(ctx, LockingStrengthShare, userID)
+ if err != nil {
+ return nil, err
}
- exists, uniqName := am.savePostureChecks(account, postureChecks)
-
- // we do not allow create new posture checks with non uniq name
- if !exists && !uniqName {
- return status.Errorf(status.PreconditionFailed, "Posture check name should be unique")
+ if user.AccountID != accountID {
+ return nil, status.NewUserNotPartOfAccountError()
}
- action := activity.PostureCheckCreated
- if exists {
- action = activity.PostureCheckUpdated
- account.Network.IncSerial()
+ if !user.HasAdminPower() {
+ return nil, status.NewAdminPermissionError()
}
- if err = am.Store.SaveAccount(ctx, account); err != nil {
- return err
+ var updateAccountPeers bool
+ var isUpdate = postureChecks.ID != ""
+ var action = activity.PostureCheckCreated
+
+ err = am.Store.ExecuteInTransaction(ctx, func(transaction Store) error {
+ if err = validatePostureChecks(ctx, transaction, accountID, postureChecks); err != nil {
+ return err
+ }
+
+ if isUpdate {
+ updateAccountPeers, err = arePostureCheckChangesAffectPeers(ctx, transaction, accountID, postureChecks.ID)
+ if err != nil {
+ return err
+ }
+
+ if err = transaction.IncrementNetworkSerial(ctx, LockingStrengthUpdate, accountID); err != nil {
+ return err
+ }
+
+ action = activity.PostureCheckUpdated
+ }
+
+ postureChecks.AccountID = accountID
+ return transaction.SavePostureChecks(ctx, LockingStrengthUpdate, postureChecks)
+ })
+ if err != nil {
+ return nil, err
}
am.StoreEvent(ctx, userID, postureChecks.ID, accountID, action, postureChecks.EventMeta())
- if arePostureCheckChangesAffectingPeers(account, postureChecks.ID, exists) {
+ if updateAccountPeers {
am.updateAccountPeers(ctx, accountID)
}
- return nil
+ return postureChecks, nil
}
+// DeletePostureChecks deletes a posture check by ID.
func (am *DefaultAccountManager) DeletePostureChecks(ctx context.Context, accountID, postureChecksID, userID string) error {
unlock := am.Store.AcquireWriteLockByUID(ctx, accountID)
defer unlock()
- account, err := am.Store.GetAccount(ctx, accountID)
+ user, err := am.Store.GetUserByUserID(ctx, LockingStrengthShare, userID)
if err != nil {
return err
}
- user, err := account.FindUser(userID)
- if err != nil {
- return err
+ if user.AccountID != accountID {
+ return status.NewUserNotPartOfAccountError()
}
if !user.HasAdminPower() {
- return status.Errorf(status.PermissionDenied, errMsgPostureAdminOnly)
+ return status.NewAdminPermissionError()
}
- postureChecks, err := am.deletePostureChecks(account, postureChecksID)
+ var postureChecks *posture.Checks
+
+ err = am.Store.ExecuteInTransaction(ctx, func(transaction Store) error {
+ postureChecks, err = transaction.GetPostureChecksByID(ctx, LockingStrengthShare, accountID, postureChecksID)
+ if err != nil {
+ return err
+ }
+
+ if err = isPostureCheckLinkedToPolicy(ctx, transaction, postureChecksID, accountID); err != nil {
+ return err
+ }
+
+ if err = transaction.IncrementNetworkSerial(ctx, LockingStrengthUpdate, accountID); err != nil {
+ return err
+ }
+
+ return transaction.DeletePostureChecks(ctx, LockingStrengthUpdate, accountID, postureChecksID)
+ })
if err != nil {
return err
}
- if err = am.Store.SaveAccount(ctx, account); err != nil {
- return err
- }
-
am.StoreEvent(ctx, userID, postureChecks.ID, accountID, activity.PostureCheckDeleted, postureChecks.EventMeta())
return nil
}
+// ListPostureChecks returns a list of posture checks.
func (am *DefaultAccountManager) ListPostureChecks(ctx context.Context, accountID, userID string) ([]*posture.Checks, error) {
user, err := am.Store.GetUserByUserID(ctx, LockingStrengthShare, userID)
if err != nil {
return nil, err
}
- if !user.HasAdminPower() || user.AccountID != accountID {
- return nil, status.Errorf(status.PermissionDenied, errMsgPostureAdminOnly)
+ if user.AccountID != accountID {
+ return nil, status.NewUserNotPartOfAccountError()
+ }
+
+ if !user.HasAdminPower() {
+ return nil, status.NewAdminPermissionError()
}
return am.Store.GetAccountPostureChecks(ctx, LockingStrengthShare, accountID)
}
-func (am *DefaultAccountManager) savePostureChecks(account *Account, postureChecks *posture.Checks) (exists, uniqName bool) {
- uniqName = true
- for i, p := range account.PostureChecks {
- if !exists && p.ID == postureChecks.ID {
- account.PostureChecks[i] = postureChecks
- exists = true
- }
- if p.Name == postureChecks.Name {
- uniqName = false
- }
- }
- if !exists {
- account.PostureChecks = append(account.PostureChecks, postureChecks)
- }
- return
-}
-
-func (am *DefaultAccountManager) deletePostureChecks(account *Account, postureChecksID string) (*posture.Checks, error) {
- postureChecksIdx := -1
- for i, postureChecks := range account.PostureChecks {
- if postureChecks.ID == postureChecksID {
- postureChecksIdx = i
- break
- }
- }
- if postureChecksIdx < 0 {
- return nil, status.Errorf(status.NotFound, "posture checks with ID %s doesn't exist", postureChecksID)
- }
-
- // Check if posture check is linked to any policy
- if isLinked, linkedPolicy := isPostureCheckLinkedToPolicy(account, postureChecksID); isLinked {
- return nil, status.Errorf(status.PreconditionFailed, "posture checks have been linked to policy: %s", linkedPolicy.Name)
- }
-
- postureChecks := account.PostureChecks[postureChecksIdx]
- account.PostureChecks = append(account.PostureChecks[:postureChecksIdx], account.PostureChecks[postureChecksIdx+1:]...)
-
- return postureChecks, nil
-}
-
// getPeerPostureChecks returns the posture checks applied for a given peer.
-func (am *DefaultAccountManager) getPeerPostureChecks(account *Account, peer *nbpeer.Peer) []*posture.Checks {
- peerPostureChecks := make(map[string]posture.Checks)
+func (am *DefaultAccountManager) getPeerPostureChecks(ctx context.Context, accountID string, peerID string) ([]*posture.Checks, error) {
+ peerPostureChecks := make(map[string]*posture.Checks)
- if len(account.PostureChecks) == 0 {
+ err := am.Store.ExecuteInTransaction(ctx, func(transaction Store) error {
+ postureChecks, err := transaction.GetAccountPostureChecks(ctx, LockingStrengthShare, accountID)
+ if err != nil {
+ return err
+ }
+
+ if len(postureChecks) == 0 {
+ return nil
+ }
+
+ policies, err := transaction.GetAccountPolicies(ctx, LockingStrengthShare, accountID)
+ if err != nil {
+ return err
+ }
+
+ for _, policy := range policies {
+ if !policy.Enabled {
+ continue
+ }
+
+ if err = addPolicyPostureChecks(ctx, transaction, accountID, peerID, policy, peerPostureChecks); err != nil {
+ return err
+ }
+ }
+
+ return nil
+ })
+ if err != nil {
+ return nil, err
+ }
+
+ return maps.Values(peerPostureChecks), nil
+}
+
+// arePostureCheckChangesAffectPeers checks if the changes in posture checks are affecting peers.
+func arePostureCheckChangesAffectPeers(ctx context.Context, transaction Store, accountID, postureCheckID string) (bool, error) {
+ policies, err := transaction.GetAccountPolicies(ctx, LockingStrengthShare, accountID)
+ if err != nil {
+ return false, err
+ }
+
+ for _, policy := range policies {
+ if slices.Contains(policy.SourcePostureChecks, postureCheckID) {
+ hasPeers, err := anyGroupHasPeers(ctx, transaction, accountID, policy.ruleGroups())
+ if err != nil {
+ return false, err
+ }
+
+ if hasPeers {
+ return true, nil
+ }
+ }
+ }
+
+ return false, nil
+}
+
+// validatePostureChecks validates the posture checks.
+func validatePostureChecks(ctx context.Context, transaction Store, accountID string, postureChecks *posture.Checks) error {
+ if err := postureChecks.Validate(); err != nil {
+ return status.Errorf(status.InvalidArgument, err.Error()) //nolint
+ }
+
+ // If the posture check already has an ID, verify its existence in the store.
+ if postureChecks.ID != "" {
+ if _, err := transaction.GetPostureChecksByID(ctx, LockingStrengthShare, accountID, postureChecks.ID); err != nil {
+ return err
+ }
return nil
}
- for _, policy := range account.Policies {
- if !policy.Enabled {
- continue
- }
+ // For new posture checks, ensure no duplicates by name.
+ checks, err := transaction.GetAccountPostureChecks(ctx, LockingStrengthShare, accountID)
+ if err != nil {
+ return err
+ }
- if isPeerInPolicySourceGroups(peer.ID, account, policy) {
- addPolicyPostureChecks(account, policy, peerPostureChecks)
+ for _, check := range checks {
+ if check.Name == postureChecks.Name && check.ID != postureChecks.ID {
+ return status.Errorf(status.InvalidArgument, "posture checks with name %s already exists", postureChecks.Name)
}
}
- postureChecksList := make([]*posture.Checks, 0, len(peerPostureChecks))
- for _, check := range peerPostureChecks {
- checkCopy := check
- postureChecksList = append(postureChecksList, &checkCopy)
+ postureChecks.ID = xid.New().String()
+
+ return nil
+}
+
+// addPolicyPostureChecks adds posture checks from a policy to the peer posture checks map if the peer is in the policy's source groups.
+func addPolicyPostureChecks(ctx context.Context, transaction Store, accountID, peerID string, policy *Policy, peerPostureChecks map[string]*posture.Checks) error {
+ isInGroup, err := isPeerInPolicySourceGroups(ctx, transaction, accountID, peerID, policy)
+ if err != nil {
+ return err
}
- return postureChecksList
+ if !isInGroup {
+ return nil
+ }
+
+ for _, sourcePostureCheckID := range policy.SourcePostureChecks {
+ postureCheck, err := transaction.GetPostureChecksByID(ctx, LockingStrengthShare, accountID, sourcePostureCheckID)
+ if err != nil {
+ return err
+ }
+ peerPostureChecks[sourcePostureCheckID] = postureCheck
+ }
+
+ return nil
}
// isPeerInPolicySourceGroups checks if a peer is present in any of the policy rule source groups.
-func isPeerInPolicySourceGroups(peerID string, account *Account, policy *Policy) bool {
+func isPeerInPolicySourceGroups(ctx context.Context, transaction Store, accountID, peerID string, policy *Policy) (bool, error) {
for _, rule := range policy.Rules {
if !rule.Enabled {
continue
}
for _, sourceGroup := range rule.Sources {
- group, ok := account.Groups[sourceGroup]
- if ok && slices.Contains(group.Peers, peerID) {
- return true
+ group, err := transaction.GetGroupByID(ctx, LockingStrengthShare, accountID, sourceGroup)
+ if err != nil {
+ return false, fmt.Errorf("failed to check peer in policy source group: %w", err)
+ }
+
+ if slices.Contains(group.Peers, peerID) {
+ return true, nil
}
}
}
- return false
-}
-
-func addPolicyPostureChecks(account *Account, policy *Policy, peerPostureChecks map[string]posture.Checks) {
- for _, sourcePostureCheckID := range policy.SourcePostureChecks {
- for _, postureCheck := range account.PostureChecks {
- if postureCheck.ID == sourcePostureCheckID {
- peerPostureChecks[sourcePostureCheckID] = *postureCheck
- }
- }
- }
-}
-
-func isPostureCheckLinkedToPolicy(account *Account, postureChecksID string) (bool, *Policy) {
- for _, policy := range account.Policies {
- if slices.Contains(policy.SourcePostureChecks, postureChecksID) {
- return true, policy
- }
- }
return false, nil
}
-// arePostureCheckChangesAffectingPeers checks if the changes in posture checks are affecting peers.
-func arePostureCheckChangesAffectingPeers(account *Account, postureCheckID string, exists bool) bool {
- if !exists {
- return false
+// isPostureCheckLinkedToPolicy checks whether the posture check is linked to any account policy.
+func isPostureCheckLinkedToPolicy(ctx context.Context, transaction Store, postureChecksID, accountID string) error {
+ policies, err := transaction.GetAccountPolicies(ctx, LockingStrengthShare, accountID)
+ if err != nil {
+ return err
}
- isLinked, linkedPolicy := isPostureCheckLinkedToPolicy(account, postureCheckID)
- if !isLinked {
- return false
+ for _, policy := range policies {
+ if slices.Contains(policy.SourcePostureChecks, postureChecksID) {
+ return status.Errorf(status.PreconditionFailed, "posture checks have been linked to policy: %s", policy.Name)
+ }
}
- return anyGroupHasPeers(account, linkedPolicy.ruleGroups())
+
+ return nil
}
diff --git a/management/server/posture_checks_test.go b/management/server/posture_checks_test.go
index c63538b9d..3c5c5fc79 100644
--- a/management/server/posture_checks_test.go
+++ b/management/server/posture_checks_test.go
@@ -7,6 +7,7 @@ import (
"github.com/rs/xid"
"github.com/stretchr/testify/assert"
+ "github.com/stretchr/testify/require"
"github.com/netbirdio/netbird/management/server/group"
@@ -16,7 +17,6 @@ import (
const (
adminUserID = "adminUserID"
regularUserID = "regularUserID"
- postureCheckID = "existing-id"
postureCheckName = "Existing check"
)
@@ -33,7 +33,7 @@ func TestDefaultAccountManager_PostureCheck(t *testing.T) {
t.Run("Generic posture check flow", func(t *testing.T) {
// regular users can not create checks
- err := am.SavePostureChecks(context.Background(), account.Id, regularUserID, &posture.Checks{})
+ _, err = am.SavePostureChecks(context.Background(), account.Id, regularUserID, &posture.Checks{})
assert.Error(t, err)
// regular users cannot list check
@@ -41,8 +41,7 @@ func TestDefaultAccountManager_PostureCheck(t *testing.T) {
assert.Error(t, err)
// should be possible to create posture check with uniq name
- err = am.SavePostureChecks(context.Background(), account.Id, adminUserID, &posture.Checks{
- ID: postureCheckID,
+ postureCheck, err := am.SavePostureChecks(context.Background(), account.Id, adminUserID, &posture.Checks{
Name: postureCheckName,
Checks: posture.ChecksDefinition{
NBVersionCheck: &posture.NBVersionCheck{
@@ -58,8 +57,7 @@ func TestDefaultAccountManager_PostureCheck(t *testing.T) {
assert.Len(t, checks, 1)
// should not be possible to create posture check with non uniq name
- err = am.SavePostureChecks(context.Background(), account.Id, adminUserID, &posture.Checks{
- ID: "new-id",
+ _, err = am.SavePostureChecks(context.Background(), account.Id, adminUserID, &posture.Checks{
Name: postureCheckName,
Checks: posture.ChecksDefinition{
GeoLocationCheck: &posture.GeoLocationCheck{
@@ -74,23 +72,20 @@ func TestDefaultAccountManager_PostureCheck(t *testing.T) {
assert.Error(t, err)
// admins can update posture checks
- err = am.SavePostureChecks(context.Background(), account.Id, adminUserID, &posture.Checks{
- ID: postureCheckID,
- Name: postureCheckName,
- Checks: posture.ChecksDefinition{
- NBVersionCheck: &posture.NBVersionCheck{
- MinVersion: "0.27.0",
- },
+ postureCheck.Checks = posture.ChecksDefinition{
+ NBVersionCheck: &posture.NBVersionCheck{
+ MinVersion: "0.27.0",
},
- })
+ }
+ _, err = am.SavePostureChecks(context.Background(), account.Id, adminUserID, postureCheck)
assert.NoError(t, err)
// users should not be able to delete posture checks
- err = am.DeletePostureChecks(context.Background(), account.Id, postureCheckID, regularUserID)
+ err = am.DeletePostureChecks(context.Background(), account.Id, postureCheck.ID, regularUserID)
assert.Error(t, err)
// admin should be able to delete posture checks
- err = am.DeletePostureChecks(context.Background(), account.Id, postureCheckID, adminUserID)
+ err = am.DeletePostureChecks(context.Background(), account.Id, postureCheck.ID, adminUserID)
assert.NoError(t, err)
checks, err = am.ListPostureChecks(context.Background(), account.Id, adminUserID)
assert.NoError(t, err)
@@ -150,9 +145,22 @@ func TestPostureCheckAccountPeersUpdate(t *testing.T) {
manager.peersUpdateManager.CloseChannel(context.Background(), peer1.ID)
})
- postureCheck := posture.Checks{
- ID: "postureCheck",
- Name: "postureCheck",
+ postureCheckA := &posture.Checks{
+ Name: "postureCheckA",
+ AccountID: account.Id,
+ Checks: posture.ChecksDefinition{
+ ProcessCheck: &posture.ProcessCheck{
+ Processes: []posture.Process{
+ {LinuxPath: "/usr/bin/netbird", MacPath: "/usr/local/bin/netbird"},
+ },
+ },
+ },
+ }
+ postureCheckA, err = manager.SavePostureChecks(context.Background(), account.Id, userID, postureCheckA)
+ require.NoError(t, err)
+
+ postureCheckB := &posture.Checks{
+ Name: "postureCheckB",
AccountID: account.Id,
Checks: posture.ChecksDefinition{
NBVersionCheck: &posture.NBVersionCheck{
@@ -169,7 +177,7 @@ func TestPostureCheckAccountPeersUpdate(t *testing.T) {
close(done)
}()
- err := manager.SavePostureChecks(context.Background(), account.Id, userID, &postureCheck)
+ postureCheckB, err = manager.SavePostureChecks(context.Background(), account.Id, userID, postureCheckB)
assert.NoError(t, err)
select {
@@ -187,12 +195,12 @@ func TestPostureCheckAccountPeersUpdate(t *testing.T) {
close(done)
}()
- postureCheck.Checks = posture.ChecksDefinition{
+ postureCheckB.Checks = posture.ChecksDefinition{
NBVersionCheck: &posture.NBVersionCheck{
MinVersion: "0.29.0",
},
}
- err := manager.SavePostureChecks(context.Background(), account.Id, userID, &postureCheck)
+ _, err = manager.SavePostureChecks(context.Background(), account.Id, userID, postureCheckB)
assert.NoError(t, err)
select {
@@ -215,7 +223,7 @@ func TestPostureCheckAccountPeersUpdate(t *testing.T) {
Action: PolicyTrafficActionAccept,
},
},
- SourcePostureChecks: []string{postureCheck.ID},
+ SourcePostureChecks: []string{postureCheckB.ID},
}
// Linking posture check to policy should trigger update account peers and send peer update
@@ -238,7 +246,7 @@ func TestPostureCheckAccountPeersUpdate(t *testing.T) {
// Updating linked posture checks should update account peers and send peer update
t.Run("updating linked to posture check with peers", func(t *testing.T) {
- postureCheck.Checks = posture.ChecksDefinition{
+ postureCheckB.Checks = posture.ChecksDefinition{
NBVersionCheck: &posture.NBVersionCheck{
MinVersion: "0.29.0",
},
@@ -255,7 +263,7 @@ func TestPostureCheckAccountPeersUpdate(t *testing.T) {
close(done)
}()
- err := manager.SavePostureChecks(context.Background(), account.Id, userID, &postureCheck)
+ _, err = manager.SavePostureChecks(context.Background(), account.Id, userID, postureCheckB)
assert.NoError(t, err)
select {
@@ -293,7 +301,7 @@ func TestPostureCheckAccountPeersUpdate(t *testing.T) {
close(done)
}()
- err := manager.DeletePostureChecks(context.Background(), account.Id, "postureCheck", userID)
+ err := manager.DeletePostureChecks(context.Background(), account.Id, postureCheckA.ID, userID)
assert.NoError(t, err)
select {
@@ -303,7 +311,7 @@ func TestPostureCheckAccountPeersUpdate(t *testing.T) {
}
})
- err = manager.SavePostureChecks(context.Background(), account.Id, userID, &postureCheck)
+ _, err = manager.SavePostureChecks(context.Background(), account.Id, userID, postureCheckB)
assert.NoError(t, err)
// Updating linked posture check to policy with no peers should not trigger account peers update and not send peer update
@@ -321,7 +329,7 @@ func TestPostureCheckAccountPeersUpdate(t *testing.T) {
Action: PolicyTrafficActionAccept,
},
},
- SourcePostureChecks: []string{postureCheck.ID},
+ SourcePostureChecks: []string{postureCheckB.ID},
}
err = manager.SavePolicy(context.Background(), account.Id, userID, &policy, false)
assert.NoError(t, err)
@@ -332,12 +340,12 @@ func TestPostureCheckAccountPeersUpdate(t *testing.T) {
close(done)
}()
- postureCheck.Checks = posture.ChecksDefinition{
+ postureCheckB.Checks = posture.ChecksDefinition{
NBVersionCheck: &posture.NBVersionCheck{
MinVersion: "0.29.0",
},
}
- err := manager.SavePostureChecks(context.Background(), account.Id, userID, &postureCheck)
+ _, err = manager.SavePostureChecks(context.Background(), account.Id, userID, postureCheckB)
assert.NoError(t, err)
select {
@@ -367,7 +375,7 @@ func TestPostureCheckAccountPeersUpdate(t *testing.T) {
Action: PolicyTrafficActionAccept,
},
},
- SourcePostureChecks: []string{postureCheck.ID},
+ SourcePostureChecks: []string{postureCheckB.ID},
}
err = manager.SavePolicy(context.Background(), account.Id, userID, &policy, true)
@@ -379,12 +387,12 @@ func TestPostureCheckAccountPeersUpdate(t *testing.T) {
close(done)
}()
- postureCheck.Checks = posture.ChecksDefinition{
+ postureCheckB.Checks = posture.ChecksDefinition{
NBVersionCheck: &posture.NBVersionCheck{
MinVersion: "0.29.0",
},
}
- err := manager.SavePostureChecks(context.Background(), account.Id, userID, &postureCheck)
+ _, err = manager.SavePostureChecks(context.Background(), account.Id, userID, postureCheckB)
assert.NoError(t, err)
select {
@@ -409,7 +417,7 @@ func TestPostureCheckAccountPeersUpdate(t *testing.T) {
Action: PolicyTrafficActionAccept,
},
},
- SourcePostureChecks: []string{postureCheck.ID},
+ SourcePostureChecks: []string{postureCheckB.ID},
}
err = manager.SavePolicy(context.Background(), account.Id, userID, &policy, true)
assert.NoError(t, err)
@@ -420,7 +428,7 @@ func TestPostureCheckAccountPeersUpdate(t *testing.T) {
close(done)
}()
- postureCheck.Checks = posture.ChecksDefinition{
+ postureCheckB.Checks = posture.ChecksDefinition{
ProcessCheck: &posture.ProcessCheck{
Processes: []posture.Process{
{
@@ -429,7 +437,7 @@ func TestPostureCheckAccountPeersUpdate(t *testing.T) {
},
},
}
- err := manager.SavePostureChecks(context.Background(), account.Id, userID, &postureCheck)
+ _, err = manager.SavePostureChecks(context.Background(), account.Id, userID, postureCheckB)
assert.NoError(t, err)
select {
@@ -440,80 +448,123 @@ func TestPostureCheckAccountPeersUpdate(t *testing.T) {
})
}
-func TestArePostureCheckChangesAffectingPeers(t *testing.T) {
- account := &Account{
- Policies: []*Policy{
- {
- ID: "policyA",
- Rules: []*PolicyRule{
- {
- Enabled: true,
- Sources: []string{"groupA"},
- Destinations: []string{"groupA"},
- },
- },
- SourcePostureChecks: []string{"checkA"},
- },
- },
- Groups: map[string]*group.Group{
- "groupA": {
- ID: "groupA",
- Peers: []string{"peer1"},
- },
- "groupB": {
- ID: "groupB",
- Peers: []string{},
- },
- },
- PostureChecks: []*posture.Checks{
- {
- ID: "checkA",
- },
- {
- ID: "checkB",
- },
- },
+func TestArePostureCheckChangesAffectPeers(t *testing.T) {
+ manager, err := createManager(t)
+ require.NoError(t, err, "failed to create account manager")
+
+ account, err := initTestPostureChecksAccount(manager)
+ require.NoError(t, err, "failed to init testing account")
+
+ groupA := &group.Group{
+ ID: "groupA",
+ AccountID: account.Id,
+ Peers: []string{"peer1"},
}
+ groupB := &group.Group{
+ ID: "groupB",
+ AccountID: account.Id,
+ Peers: []string{},
+ }
+ err = manager.Store.SaveGroups(context.Background(), LockingStrengthUpdate, []*group.Group{groupA, groupB})
+ require.NoError(t, err, "failed to save groups")
+
+ postureCheckA := &posture.Checks{
+ Name: "checkA",
+ AccountID: account.Id,
+ Checks: posture.ChecksDefinition{
+ NBVersionCheck: &posture.NBVersionCheck{MinVersion: "0.33.1"},
+ },
+ }
+ postureCheckA, err = manager.SavePostureChecks(context.Background(), account.Id, adminUserID, postureCheckA)
+ require.NoError(t, err, "failed to save postureCheckA")
+
+ postureCheckB := &posture.Checks{
+ Name: "checkB",
+ AccountID: account.Id,
+ Checks: posture.ChecksDefinition{
+ NBVersionCheck: &posture.NBVersionCheck{MinVersion: "0.33.1"},
+ },
+ }
+ postureCheckB, err = manager.SavePostureChecks(context.Background(), account.Id, adminUserID, postureCheckB)
+ require.NoError(t, err, "failed to save postureCheckB")
+
+ policy := &Policy{
+ ID: "policyA",
+ AccountID: account.Id,
+ Rules: []*PolicyRule{
+ {
+ ID: "ruleA",
+ PolicyID: "policyA",
+ Enabled: true,
+ Sources: []string{"groupA"},
+ Destinations: []string{"groupA"},
+ },
+ },
+ SourcePostureChecks: []string{postureCheckA.ID},
+ }
+
+ err = manager.SavePolicy(context.Background(), account.Id, userID, policy, false)
+ require.NoError(t, err, "failed to save policy")
+
t.Run("posture check exists and is linked to policy with peers", func(t *testing.T) {
- result := arePostureCheckChangesAffectingPeers(account, "checkA", true)
+ result, err := arePostureCheckChangesAffectPeers(context.Background(), manager.Store, account.Id, postureCheckA.ID)
+ require.NoError(t, err)
assert.True(t, result)
})
t.Run("posture check exists but is not linked to any policy", func(t *testing.T) {
- result := arePostureCheckChangesAffectingPeers(account, "checkB", true)
+ result, err := arePostureCheckChangesAffectPeers(context.Background(), manager.Store, account.Id, postureCheckB.ID)
+ require.NoError(t, err)
assert.False(t, result)
})
t.Run("posture check does not exist", func(t *testing.T) {
- result := arePostureCheckChangesAffectingPeers(account, "unknown", false)
+ result, err := arePostureCheckChangesAffectPeers(context.Background(), manager.Store, account.Id, "unknown")
+ require.NoError(t, err)
assert.False(t, result)
})
t.Run("posture check is linked to policy with no peers in source groups", func(t *testing.T) {
- account.Policies[0].Rules[0].Sources = []string{"groupB"}
- account.Policies[0].Rules[0].Destinations = []string{"groupA"}
- result := arePostureCheckChangesAffectingPeers(account, "checkA", true)
+ policy.Rules[0].Sources = []string{"groupB"}
+ policy.Rules[0].Destinations = []string{"groupA"}
+ err = manager.SavePolicy(context.Background(), account.Id, userID, policy, true)
+ require.NoError(t, err, "failed to update policy")
+
+ result, err := arePostureCheckChangesAffectPeers(context.Background(), manager.Store, account.Id, postureCheckA.ID)
+ require.NoError(t, err)
assert.True(t, result)
})
t.Run("posture check is linked to policy with no peers in destination groups", func(t *testing.T) {
- account.Policies[0].Rules[0].Sources = []string{"groupA"}
- account.Policies[0].Rules[0].Destinations = []string{"groupB"}
- result := arePostureCheckChangesAffectingPeers(account, "checkA", true)
+ policy.Rules[0].Sources = []string{"groupA"}
+ policy.Rules[0].Destinations = []string{"groupB"}
+ err = manager.SavePolicy(context.Background(), account.Id, userID, policy, true)
+ require.NoError(t, err, "failed to update policy")
+
+ result, err := arePostureCheckChangesAffectPeers(context.Background(), manager.Store, account.Id, postureCheckA.ID)
+ require.NoError(t, err)
assert.True(t, result)
})
- t.Run("posture check is linked to policy with non-existent group", func(t *testing.T) {
- account.Policies[0].Rules[0].Sources = []string{"nonExistentGroup"}
- account.Policies[0].Rules[0].Destinations = []string{"nonExistentGroup"}
- result := arePostureCheckChangesAffectingPeers(account, "checkA", true)
+ t.Run("posture check is linked to policy but no peers in groups", func(t *testing.T) {
+ groupA.Peers = []string{}
+ err = manager.Store.SaveGroup(context.Background(), LockingStrengthUpdate, groupA)
+ require.NoError(t, err, "failed to save groups")
+
+ result, err := arePostureCheckChangesAffectPeers(context.Background(), manager.Store, account.Id, postureCheckA.ID)
+ require.NoError(t, err)
assert.False(t, result)
})
- t.Run("posture check is linked to policy but no peers in groups", func(t *testing.T) {
- account.Groups["groupA"].Peers = []string{}
- result := arePostureCheckChangesAffectingPeers(account, "checkA", true)
+ t.Run("posture check is linked to policy with non-existent group", func(t *testing.T) {
+ policy.Rules[0].Sources = []string{"nonExistentGroup"}
+ policy.Rules[0].Destinations = []string{"nonExistentGroup"}
+ err = manager.SavePolicy(context.Background(), account.Id, userID, policy, true)
+ require.NoError(t, err, "failed to update policy")
+
+ result, err := arePostureCheckChangesAffectPeers(context.Background(), manager.Store, account.Id, postureCheckA.ID)
+ require.NoError(t, err)
assert.False(t, result)
})
}
diff --git a/management/server/route.go b/management/server/route.go
index dcf2cb0d3..ecb562645 100644
--- a/management/server/route.go
+++ b/management/server/route.go
@@ -237,7 +237,7 @@ func (am *DefaultAccountManager) CreateRoute(ctx context.Context, accountID stri
return nil, err
}
- if isRouteChangeAffectPeers(account, &newRoute) {
+ if am.isRouteChangeAffectPeers(account, &newRoute) {
am.updateAccountPeers(ctx, accountID)
}
@@ -323,7 +323,7 @@ func (am *DefaultAccountManager) SaveRoute(ctx context.Context, accountID, userI
return err
}
- if isRouteChangeAffectPeers(account, oldRoute) || isRouteChangeAffectPeers(account, routeToSave) {
+ if am.isRouteChangeAffectPeers(account, oldRoute) || am.isRouteChangeAffectPeers(account, routeToSave) {
am.updateAccountPeers(ctx, accountID)
}
@@ -355,7 +355,7 @@ func (am *DefaultAccountManager) DeleteRoute(ctx context.Context, accountID stri
am.StoreEvent(ctx, userID, string(routy.ID), accountID, activity.RouteRemoved, routy.EventMeta())
- if isRouteChangeAffectPeers(account, routy) {
+ if am.isRouteChangeAffectPeers(account, routy) {
am.updateAccountPeers(ctx, accountID)
}
@@ -651,6 +651,6 @@ func getProtoPortInfo(rule *RouteFirewallRule) *proto.PortInfo {
// isRouteChangeAffectPeers checks if a given route affects peers by determining
// if it has a routing peer, distribution, or peer groups that include peers
-func isRouteChangeAffectPeers(account *Account, route *route.Route) bool {
- return anyGroupHasPeers(account, route.Groups) || anyGroupHasPeers(account, route.PeerGroups) || route.Peer != ""
+func (am *DefaultAccountManager) isRouteChangeAffectPeers(account *Account, route *route.Route) bool {
+ return am.anyGroupHasPeers(account, route.Groups) || am.anyGroupHasPeers(account, route.PeerGroups) || route.Peer != ""
}
diff --git a/management/server/sql_store.go b/management/server/sql_store.go
index 278f5443d..47c17bb92 100644
--- a/management/server/sql_store.go
+++ b/management/server/sql_store.go
@@ -1305,12 +1305,57 @@ func (s *SqlStore) GetPolicyByID(ctx context.Context, lockStrength LockingStreng
// GetAccountPostureChecks retrieves posture checks for an account.
func (s *SqlStore) GetAccountPostureChecks(ctx context.Context, lockStrength LockingStrength, accountID string) ([]*posture.Checks, error) {
- return getRecords[*posture.Checks](s.db, lockStrength, accountID)
+ var postureChecks []*posture.Checks
+ result := s.db.Clauses(clause.Locking{Strength: string(lockStrength)}).Find(&postureChecks, accountIDCondition, accountID)
+ if result.Error != nil {
+ log.WithContext(ctx).Errorf("failed to get posture checks from store: %s", result.Error)
+ return nil, status.Errorf(status.Internal, "failed to get posture checks from store")
+ }
+
+ return postureChecks, nil
}
// GetPostureChecksByID retrieves posture checks by their ID and account ID.
-func (s *SqlStore) GetPostureChecksByID(ctx context.Context, lockStrength LockingStrength, postureCheckID string, accountID string) (*posture.Checks, error) {
- return getRecordByID[posture.Checks](s.db, lockStrength, postureCheckID, accountID)
+func (s *SqlStore) GetPostureChecksByID(ctx context.Context, lockStrength LockingStrength, accountID, postureChecksID string) (*posture.Checks, error) {
+ var postureCheck *posture.Checks
+ result := s.db.Clauses(clause.Locking{Strength: string(lockStrength)}).
+ First(&postureCheck, accountAndIDQueryCondition, accountID, postureChecksID)
+ if result.Error != nil {
+ if errors.Is(result.Error, gorm.ErrRecordNotFound) {
+ return nil, status.NewPostureChecksNotFoundError(postureChecksID)
+ }
+ log.WithContext(ctx).Errorf("failed to get posture check from store: %s", result.Error)
+ return nil, status.Errorf(status.Internal, "failed to get posture check from store")
+ }
+
+ return postureCheck, nil
+}
+
+// SavePostureChecks saves a posture checks to the database.
+func (s *SqlStore) SavePostureChecks(ctx context.Context, lockStrength LockingStrength, postureCheck *posture.Checks) error {
+ result := s.db.Clauses(clause.Locking{Strength: string(lockStrength)}).Save(postureCheck)
+ if result.Error != nil {
+ log.WithContext(ctx).Errorf("failed to save posture checks to store: %s", result.Error)
+ return status.Errorf(status.Internal, "failed to save posture checks to store")
+ }
+
+ return nil
+}
+
+// DeletePostureChecks deletes a posture checks from the database.
+func (s *SqlStore) DeletePostureChecks(ctx context.Context, lockStrength LockingStrength, accountID, postureChecksID string) error {
+ result := s.db.Clauses(clause.Locking{Strength: string(lockStrength)}).
+ Delete(&posture.Checks{}, accountAndIDQueryCondition, accountID, postureChecksID)
+ if result.Error != nil {
+ log.WithContext(ctx).Errorf("failed to delete posture checks from store: %s", result.Error)
+ return status.Errorf(status.Internal, "failed to delete posture checks from store")
+ }
+
+ if result.RowsAffected == 0 {
+ return status.NewPostureChecksNotFoundError(postureChecksID)
+ }
+
+ return nil
}
// GetAccountRoutes retrieves network routes for an account.
diff --git a/management/server/sql_store_test.go b/management/server/sql_store_test.go
index 114da1ee6..de939e8d0 100644
--- a/management/server/sql_store_test.go
+++ b/management/server/sql_store_test.go
@@ -16,6 +16,7 @@ import (
"github.com/google/uuid"
nbdns "github.com/netbirdio/netbird/dns"
nbgroup "github.com/netbirdio/netbird/management/server/group"
+ "github.com/netbirdio/netbird/management/server/posture"
"github.com/stretchr/testify/assert"
"github.com/stretchr/testify/require"
@@ -1564,3 +1565,137 @@ func TestSqlStore_GetPeersByIDs(t *testing.T) {
})
}
}
+
+func TestSqlStore_GetPostureChecksByID(t *testing.T) {
+ store, cleanup, err := NewTestStoreFromSQL(context.Background(), "testdata/extended-store.sql", t.TempDir())
+ t.Cleanup(cleanup)
+ require.NoError(t, err)
+
+ accountID := "bf1c8084-ba50-4ce7-9439-34653001fc3b"
+ tests := []struct {
+ name string
+ postureChecksID string
+ expectError bool
+ }{
+ {
+ name: "retrieve existing posture checks",
+ postureChecksID: "csplshq7qv948l48f7t0",
+ expectError: false,
+ },
+ {
+ name: "retrieve non-existing posture checks",
+ postureChecksID: "non-existing",
+ expectError: true,
+ },
+ {
+ name: "retrieve with empty posture checks ID",
+ postureChecksID: "",
+ expectError: true,
+ },
+ }
+
+ for _, tt := range tests {
+ t.Run(tt.name, func(t *testing.T) {
+ postureChecks, err := store.GetPostureChecksByID(context.Background(), LockingStrengthShare, accountID, tt.postureChecksID)
+ if tt.expectError {
+ require.Error(t, err)
+ sErr, ok := status.FromError(err)
+ require.True(t, ok)
+ require.Equal(t, sErr.Type(), status.NotFound)
+ require.Nil(t, postureChecks)
+ } else {
+ require.NoError(t, err)
+ require.NotNil(t, postureChecks)
+ require.Equal(t, tt.postureChecksID, postureChecks.ID)
+ }
+ })
+ }
+}
+
+func TestSqlStore_SavePostureChecks(t *testing.T) {
+ store, cleanup, err := NewTestStoreFromSQL(context.Background(), "testdata/extended-store.sql", t.TempDir())
+ t.Cleanup(cleanup)
+ require.NoError(t, err)
+
+ accountID := "bf1c8084-ba50-4ce7-9439-34653001fc3b"
+
+ postureChecks := &posture.Checks{
+ ID: "posture-checks-id",
+ AccountID: accountID,
+ Checks: posture.ChecksDefinition{
+ NBVersionCheck: &posture.NBVersionCheck{
+ MinVersion: "0.31.0",
+ },
+ OSVersionCheck: &posture.OSVersionCheck{
+ Ios: &posture.MinVersionCheck{
+ MinVersion: "13.0.1",
+ },
+ Linux: &posture.MinKernelVersionCheck{
+ MinKernelVersion: "5.3.3-dev",
+ },
+ },
+ GeoLocationCheck: &posture.GeoLocationCheck{
+ Locations: []posture.Location{
+ {
+ CountryCode: "DE",
+ CityName: "Berlin",
+ },
+ },
+ Action: posture.CheckActionAllow,
+ },
+ },
+ }
+ err = store.SavePostureChecks(context.Background(), LockingStrengthUpdate, postureChecks)
+ require.NoError(t, err)
+
+ savePostureChecks, err := store.GetPostureChecksByID(context.Background(), LockingStrengthShare, accountID, "posture-checks-id")
+ require.NoError(t, err)
+ require.Equal(t, savePostureChecks, postureChecks)
+}
+
+func TestSqlStore_DeletePostureChecks(t *testing.T) {
+ store, cleanup, err := NewTestStoreFromSQL(context.Background(), "testdata/extended-store.sql", t.TempDir())
+ t.Cleanup(cleanup)
+ require.NoError(t, err)
+
+ accountID := "bf1c8084-ba50-4ce7-9439-34653001fc3b"
+
+ tests := []struct {
+ name string
+ postureChecksID string
+ expectError bool
+ }{
+ {
+ name: "delete existing posture checks",
+ postureChecksID: "csplshq7qv948l48f7t0",
+ expectError: false,
+ },
+ {
+ name: "delete non-existing posture checks",
+ postureChecksID: "non-existing-posture-checks-id",
+ expectError: true,
+ },
+ {
+ name: "delete with empty posture checks ID",
+ postureChecksID: "",
+ expectError: true,
+ },
+ }
+
+ for _, tt := range tests {
+ t.Run(tt.name, func(t *testing.T) {
+ err = store.DeletePostureChecks(context.Background(), LockingStrengthUpdate, accountID, tt.postureChecksID)
+ if tt.expectError {
+ require.Error(t, err)
+ sErr, ok := status.FromError(err)
+ require.True(t, ok)
+ require.Equal(t, sErr.Type(), status.NotFound)
+ } else {
+ require.NoError(t, err)
+ group, err := store.GetPostureChecksByID(context.Background(), LockingStrengthShare, accountID, tt.postureChecksID)
+ require.Error(t, err)
+ require.Nil(t, group)
+ }
+ })
+ }
+}
diff --git a/management/server/status/error.go b/management/server/status/error.go
index 8b6d0077b..44391e1f1 100644
--- a/management/server/status/error.go
+++ b/management/server/status/error.go
@@ -139,3 +139,8 @@ func NewGetAccountError(err error) error {
func NewGroupNotFoundError(groupID string) error {
return Errorf(NotFound, "group: %s not found", groupID)
}
+
+// NewPostureChecksNotFoundError creates a new Error with NotFound type for a missing posture checks
+func NewPostureChecksNotFoundError(postureChecksID string) error {
+ return Errorf(NotFound, "posture checks: %s not found", postureChecksID)
+}
diff --git a/management/server/store.go b/management/server/store.go
index 71b0d457b..03b5821e7 100644
--- a/management/server/store.go
+++ b/management/server/store.go
@@ -84,7 +84,9 @@ type Store interface {
GetPostureCheckByChecksDefinition(accountID string, checks *posture.ChecksDefinition) (*posture.Checks, error)
GetAccountPostureChecks(ctx context.Context, lockStrength LockingStrength, accountID string) ([]*posture.Checks, error)
- GetPostureChecksByID(ctx context.Context, lockStrength LockingStrength, postureCheckID string, accountID string) (*posture.Checks, error)
+ GetPostureChecksByID(ctx context.Context, lockStrength LockingStrength, accountID, postureCheckID string) (*posture.Checks, error)
+ SavePostureChecks(ctx context.Context, lockStrength LockingStrength, postureCheck *posture.Checks) error
+ DeletePostureChecks(ctx context.Context, lockStrength LockingStrength, accountID, postureChecksID string) error
GetPeerLabelsInAccount(ctx context.Context, lockStrength LockingStrength, accountId string) ([]string, error)
AddPeerToAllGroup(ctx context.Context, accountID string, peerID string) error
diff --git a/management/server/testdata/extended-store.sql b/management/server/testdata/extended-store.sql
index b522741e7..1646ff4da 100644
--- a/management/server/testdata/extended-store.sql
+++ b/management/server/testdata/extended-store.sql
@@ -34,4 +34,5 @@ INSERT INTO personal_access_tokens VALUES('9dj38s35-63fb-11ec-90d6-0242ac120003'
INSERT INTO "groups" VALUES('cfefqs706sqkneg59g4g','bf1c8084-ba50-4ce7-9439-34653001fc3b','All','api','[]',0,'');
INSERT INTO "groups" VALUES('cfefqs706sqkneg59g3g','bf1c8084-ba50-4ce7-9439-34653001fc3b','AwesomeGroup1','api','[]',0,'');
INSERT INTO "groups" VALUES('cfefqs706sqkneg59g2g','bf1c8084-ba50-4ce7-9439-34653001fc3b','AwesomeGroup2','api','[]',0,'');
+INSERT INTO posture_checks VALUES('csplshq7qv948l48f7t0','NetBird Version > 0.32.0','','bf1c8084-ba50-4ce7-9439-34653001fc3b','{"NBVersionCheck":{"MinVersion":"0.31.0"}}');
INSERT INTO installations VALUES(1,'');
From f118d81d3219b78b689206523e4159fcb495fa12 Mon Sep 17 00:00:00 2001
From: Bethuel Mmbaga
Date: Tue, 26 Nov 2024 12:46:05 +0300
Subject: [PATCH 38/39] [management] Refactor policy to use store methods
(#2878)
---
management/server/account.go | 2 +-
management/server/account_test.go | 57 ++--
management/server/group_test.go | 5 +-
management/server/http/policies_handler.go | 26 +-
.../server/http/policies_handler_test.go | 4 +-
management/server/mock_server/account_mock.go | 8 +-
management/server/peer_test.go | 31 +-
management/server/policy.go | 319 +++++++++++-------
management/server/policy_test.go | 165 +++------
management/server/posture_checks_test.go | 43 +--
management/server/route_test.go | 3 +-
management/server/setupkey_test.go | 5 +-
management/server/sql_store.go | 82 ++++-
management/server/sql_store_test.go | 158 +++++++++
management/server/status/error.go | 5 +
management/server/store.go | 6 +-
management/server/testdata/extended-store.sql | 1 +
management/server/user_test.go | 5 +-
18 files changed, 576 insertions(+), 349 deletions(-)
diff --git a/management/server/account.go b/management/server/account.go
index 9fb56c855..fbe6fcc1a 100644
--- a/management/server/account.go
+++ b/management/server/account.go
@@ -113,7 +113,7 @@ type AccountManager interface {
GroupAddPeer(ctx context.Context, accountId, groupID, peerID string) error
GroupDeletePeer(ctx context.Context, accountId, groupID, peerID string) error
GetPolicy(ctx context.Context, accountID, policyID, userID string) (*Policy, error)
- SavePolicy(ctx context.Context, accountID, userID string, policy *Policy, isUpdate bool) error
+ SavePolicy(ctx context.Context, accountID, userID string, policy *Policy) (*Policy, error)
DeletePolicy(ctx context.Context, accountID, policyID, userID string) error
ListPolicies(ctx context.Context, accountID, userID string) ([]*Policy, error)
GetRoute(ctx context.Context, accountID string, routeID route.ID, userID string) (*route.Route, error)
diff --git a/management/server/account_test.go b/management/server/account_test.go
index 97e0d45f0..c8c2d5941 100644
--- a/management/server/account_test.go
+++ b/management/server/account_test.go
@@ -1238,8 +1238,7 @@ func TestAccountManager_NetworkUpdates_SaveGroup(t *testing.T) {
return
}
- policy := Policy{
- ID: "policy",
+ _, err := manager.SavePolicy(context.Background(), account.Id, userID, &Policy{
Enabled: true,
Rules: []*PolicyRule{
{
@@ -1250,8 +1249,7 @@ func TestAccountManager_NetworkUpdates_SaveGroup(t *testing.T) {
Action: PolicyTrafficActionAccept,
},
},
- }
- err := manager.SavePolicy(context.Background(), account.Id, userID, &policy, false)
+ })
require.NoError(t, err)
updMsg := manager.peersUpdateManager.CreateChannel(context.Background(), peer1.ID)
@@ -1320,19 +1318,6 @@ func TestAccountManager_NetworkUpdates_SavePolicy(t *testing.T) {
updMsg := manager.peersUpdateManager.CreateChannel(context.Background(), peer1.ID)
defer manager.peersUpdateManager.CloseChannel(context.Background(), peer1.ID)
- policy := Policy{
- Enabled: true,
- Rules: []*PolicyRule{
- {
- Enabled: true,
- Sources: []string{"groupA"},
- Destinations: []string{"groupA"},
- Bidirectional: true,
- Action: PolicyTrafficActionAccept,
- },
- },
- }
-
wg := sync.WaitGroup{}
wg.Add(1)
go func() {
@@ -1345,7 +1330,19 @@ func TestAccountManager_NetworkUpdates_SavePolicy(t *testing.T) {
}
}()
- if err := manager.SavePolicy(context.Background(), account.Id, userID, &policy, false); err != nil {
+ _, err := manager.SavePolicy(context.Background(), account.Id, userID, &Policy{
+ Enabled: true,
+ Rules: []*PolicyRule{
+ {
+ Enabled: true,
+ Sources: []string{"groupA"},
+ Destinations: []string{"groupA"},
+ Bidirectional: true,
+ Action: PolicyTrafficActionAccept,
+ },
+ },
+ })
+ if err != nil {
t.Errorf("delete default rule: %v", err)
return
}
@@ -1366,7 +1363,7 @@ func TestAccountManager_NetworkUpdates_DeletePeer(t *testing.T) {
return
}
- policy := Policy{
+ _, err := manager.SavePolicy(context.Background(), account.Id, userID, &Policy{
Enabled: true,
Rules: []*PolicyRule{
{
@@ -1377,9 +1374,8 @@ func TestAccountManager_NetworkUpdates_DeletePeer(t *testing.T) {
Action: PolicyTrafficActionAccept,
},
},
- }
-
- if err := manager.SavePolicy(context.Background(), account.Id, userID, &policy, false); err != nil {
+ })
+ if err != nil {
t.Errorf("save policy: %v", err)
return
}
@@ -1421,7 +1417,12 @@ func TestAccountManager_NetworkUpdates_DeleteGroup(t *testing.T) {
require.NoError(t, err, "failed to save group")
- policy := Policy{
+ if err := manager.DeletePolicy(context.Background(), account.Id, account.Policies[0].ID, userID); err != nil {
+ t.Errorf("delete default rule: %v", err)
+ return
+ }
+
+ policy, err := manager.SavePolicy(context.Background(), account.Id, userID, &Policy{
Enabled: true,
Rules: []*PolicyRule{
{
@@ -1432,14 +1433,8 @@ func TestAccountManager_NetworkUpdates_DeleteGroup(t *testing.T) {
Action: PolicyTrafficActionAccept,
},
},
- }
-
- if err := manager.DeletePolicy(context.Background(), account.Id, account.Policies[0].ID, userID); err != nil {
- t.Errorf("delete default rule: %v", err)
- return
- }
-
- if err := manager.SavePolicy(context.Background(), account.Id, userID, &policy, false); err != nil {
+ })
+ if err != nil {
t.Errorf("save policy: %v", err)
return
}
diff --git a/management/server/group_test.go b/management/server/group_test.go
index 59094a23e..ec017fc57 100644
--- a/management/server/group_test.go
+++ b/management/server/group_test.go
@@ -500,8 +500,7 @@ func TestGroupAccountPeersUpdate(t *testing.T) {
})
// adding a group to policy
- err = manager.SavePolicy(context.Background(), account.Id, userID, &Policy{
- ID: "policy",
+ _, err = manager.SavePolicy(context.Background(), account.Id, userID, &Policy{
Enabled: true,
Rules: []*PolicyRule{
{
@@ -512,7 +511,7 @@ func TestGroupAccountPeersUpdate(t *testing.T) {
Action: PolicyTrafficActionAccept,
},
},
- }, false)
+ })
assert.NoError(t, err)
// Saving a group linked to policy should update account peers and send peer update
diff --git a/management/server/http/policies_handler.go b/management/server/http/policies_handler.go
index 73f3803b5..eff9092d4 100644
--- a/management/server/http/policies_handler.go
+++ b/management/server/http/policies_handler.go
@@ -6,10 +6,8 @@ import (
"strconv"
"github.com/gorilla/mux"
- nbgroup "github.com/netbirdio/netbird/management/server/group"
- "github.com/rs/xid"
-
"github.com/netbirdio/netbird/management/server"
+ nbgroup "github.com/netbirdio/netbird/management/server/group"
"github.com/netbirdio/netbird/management/server/http/api"
"github.com/netbirdio/netbird/management/server/http/util"
"github.com/netbirdio/netbird/management/server/jwtclaims"
@@ -122,21 +120,22 @@ func (h *Policies) savePolicy(w http.ResponseWriter, r *http.Request, accountID
return
}
- isUpdate := policyID != ""
-
- if policyID == "" {
- policyID = xid.New().String()
- }
-
- policy := server.Policy{
+ policy := &server.Policy{
ID: policyID,
+ AccountID: accountID,
Name: req.Name,
Enabled: req.Enabled,
Description: req.Description,
}
for _, rule := range req.Rules {
+ var ruleID string
+ if rule.Id != nil {
+ ruleID = *rule.Id
+ }
+
pr := server.PolicyRule{
- ID: policyID, // TODO: when policy can contain multiple rules, need refactor
+ ID: ruleID,
+ PolicyID: policyID,
Name: rule.Name,
Destinations: rule.Destinations,
Sources: rule.Sources,
@@ -225,7 +224,8 @@ func (h *Policies) savePolicy(w http.ResponseWriter, r *http.Request, accountID
policy.SourcePostureChecks = *req.SourcePostureChecks
}
- if err := h.accountManager.SavePolicy(r.Context(), accountID, userID, &policy, isUpdate); err != nil {
+ policy, err := h.accountManager.SavePolicy(r.Context(), accountID, userID, policy)
+ if err != nil {
util.WriteError(r.Context(), err, w)
return
}
@@ -236,7 +236,7 @@ func (h *Policies) savePolicy(w http.ResponseWriter, r *http.Request, accountID
return
}
- resp := toPolicyResponse(allGroups, &policy)
+ resp := toPolicyResponse(allGroups, policy)
if len(resp.Rules) == 0 {
util.WriteError(r.Context(), status.Errorf(status.Internal, "no rules in the policy"), w)
return
diff --git a/management/server/http/policies_handler_test.go b/management/server/http/policies_handler_test.go
index 228ebcbce..f8a897eb2 100644
--- a/management/server/http/policies_handler_test.go
+++ b/management/server/http/policies_handler_test.go
@@ -38,12 +38,12 @@ func initPoliciesTestData(policies ...*server.Policy) *Policies {
}
return policy, nil
},
- SavePolicyFunc: func(_ context.Context, _, _ string, policy *server.Policy, _ bool) error {
+ SavePolicyFunc: func(_ context.Context, _, _ string, policy *server.Policy) (*server.Policy, error) {
if !strings.HasPrefix(policy.ID, "id-") {
policy.ID = "id-was-set"
policy.Rules[0].ID = "id-was-set"
}
- return nil
+ return policy, nil
},
GetAllGroupsFunc: func(ctx context.Context, accountID, userID string) ([]*nbgroup.Group, error) {
return []*nbgroup.Group{{ID: "F"}, {ID: "G"}}, nil
diff --git a/management/server/mock_server/account_mock.go b/management/server/mock_server/account_mock.go
index 673ed33bb..46a4fbc1f 100644
--- a/management/server/mock_server/account_mock.go
+++ b/management/server/mock_server/account_mock.go
@@ -49,7 +49,7 @@ type MockAccountManager struct {
GroupDeletePeerFunc func(ctx context.Context, accountID, groupID, peerID string) error
DeleteRuleFunc func(ctx context.Context, accountID, ruleID, userID string) error
GetPolicyFunc func(ctx context.Context, accountID, policyID, userID string) (*server.Policy, error)
- SavePolicyFunc func(ctx context.Context, accountID, userID string, policy *server.Policy, isUpdate bool) error
+ SavePolicyFunc func(ctx context.Context, accountID, userID string, policy *server.Policy) (*server.Policy, error)
DeletePolicyFunc func(ctx context.Context, accountID, policyID, userID string) error
ListPoliciesFunc func(ctx context.Context, accountID, userID string) ([]*server.Policy, error)
GetUsersFromAccountFunc func(ctx context.Context, accountID, userID string) ([]*server.UserInfo, error)
@@ -386,11 +386,11 @@ func (am *MockAccountManager) GetPolicy(ctx context.Context, accountID, policyID
}
// SavePolicy mock implementation of SavePolicy from server.AccountManager interface
-func (am *MockAccountManager) SavePolicy(ctx context.Context, accountID, userID string, policy *server.Policy, isUpdate bool) error {
+func (am *MockAccountManager) SavePolicy(ctx context.Context, accountID, userID string, policy *server.Policy) (*server.Policy, error) {
if am.SavePolicyFunc != nil {
- return am.SavePolicyFunc(ctx, accountID, userID, policy, isUpdate)
+ return am.SavePolicyFunc(ctx, accountID, userID, policy)
}
- return status.Errorf(codes.Unimplemented, "method SavePolicy is not implemented")
+ return nil, status.Errorf(codes.Unimplemented, "method SavePolicy is not implemented")
}
// DeletePolicy mock implementation of DeletePolicy from server.AccountManager interface
diff --git a/management/server/peer_test.go b/management/server/peer_test.go
index 4e2dcb2c3..e410fa892 100644
--- a/management/server/peer_test.go
+++ b/management/server/peer_test.go
@@ -283,14 +283,12 @@ func TestAccountManager_GetNetworkMapWithPolicy(t *testing.T) {
var (
group1 nbgroup.Group
group2 nbgroup.Group
- policy Policy
)
group1.ID = xid.New().String()
group2.ID = xid.New().String()
group1.Name = "src"
group2.Name = "dst"
- policy.ID = xid.New().String()
group1.Peers = append(group1.Peers, peer1.ID)
group2.Peers = append(group2.Peers, peer2.ID)
@@ -305,18 +303,20 @@ func TestAccountManager_GetNetworkMapWithPolicy(t *testing.T) {
return
}
- policy.Name = "test"
- policy.Enabled = true
- policy.Rules = []*PolicyRule{
- {
- Enabled: true,
- Sources: []string{group1.ID},
- Destinations: []string{group2.ID},
- Bidirectional: true,
- Action: PolicyTrafficActionAccept,
+ policy := &Policy{
+ Name: "test",
+ Enabled: true,
+ Rules: []*PolicyRule{
+ {
+ Enabled: true,
+ Sources: []string{group1.ID},
+ Destinations: []string{group2.ID},
+ Bidirectional: true,
+ Action: PolicyTrafficActionAccept,
+ },
},
}
- err = manager.SavePolicy(context.Background(), account.Id, userID, &policy, false)
+ policy, err = manager.SavePolicy(context.Background(), account.Id, userID, policy)
if err != nil {
t.Errorf("expecting rule to be added, got failure %v", err)
return
@@ -364,7 +364,7 @@ func TestAccountManager_GetNetworkMapWithPolicy(t *testing.T) {
}
policy.Enabled = false
- err = manager.SavePolicy(context.Background(), account.Id, userID, &policy, true)
+ _, err = manager.SavePolicy(context.Background(), account.Id, userID, policy)
if err != nil {
t.Errorf("expecting rule to be added, got failure %v", err)
return
@@ -1445,8 +1445,7 @@ func TestPeerAccountPeersUpdate(t *testing.T) {
// Adding peer to group linked with policy should update account peers and send peer update
t.Run("adding peer to group linked with policy", func(t *testing.T) {
- err = manager.SavePolicy(context.Background(), account.Id, userID, &Policy{
- ID: "policy",
+ _, err = manager.SavePolicy(context.Background(), account.Id, userID, &Policy{
Enabled: true,
Rules: []*PolicyRule{
{
@@ -1457,7 +1456,7 @@ func TestPeerAccountPeersUpdate(t *testing.T) {
Action: PolicyTrafficActionAccept,
},
},
- }, false)
+ })
require.NoError(t, err)
done := make(chan struct{})
diff --git a/management/server/policy.go b/management/server/policy.go
index c7872591d..2d3abc3f1 100644
--- a/management/server/policy.go
+++ b/management/server/policy.go
@@ -3,13 +3,13 @@ package server
import (
"context"
_ "embed"
- "slices"
"strconv"
"strings"
+ "github.com/netbirdio/netbird/management/proto"
+ "github.com/rs/xid"
log "github.com/sirupsen/logrus"
- "github.com/netbirdio/netbird/management/proto"
"github.com/netbirdio/netbird/management/server/activity"
nbgroup "github.com/netbirdio/netbird/management/server/group"
nbpeer "github.com/netbirdio/netbird/management/server/peer"
@@ -125,6 +125,7 @@ type PolicyRule struct {
func (pm *PolicyRule) Copy() *PolicyRule {
rule := &PolicyRule{
ID: pm.ID,
+ PolicyID: pm.PolicyID,
Name: pm.Name,
Description: pm.Description,
Enabled: pm.Enabled,
@@ -171,6 +172,7 @@ type Policy struct {
func (p *Policy) Copy() *Policy {
c := &Policy{
ID: p.ID,
+ AccountID: p.AccountID,
Name: p.Name,
Description: p.Description,
Enabled: p.Enabled,
@@ -343,157 +345,207 @@ func (am *DefaultAccountManager) GetPolicy(ctx context.Context, accountID, polic
return nil, err
}
- if !user.IsAdminOrServiceUser() || user.AccountID != accountID {
- return nil, status.Errorf(status.PermissionDenied, "only users with admin power are allowed to view policies")
+ if user.AccountID != accountID {
+ return nil, status.NewUserNotPartOfAccountError()
}
- return am.Store.GetPolicyByID(ctx, LockingStrengthShare, policyID, accountID)
+ if user.IsRegularUser() {
+ return nil, status.NewAdminPermissionError()
+ }
+
+ return am.Store.GetPolicyByID(ctx, LockingStrengthShare, accountID, policyID)
}
// SavePolicy in the store
-func (am *DefaultAccountManager) SavePolicy(ctx context.Context, accountID, userID string, policy *Policy, isUpdate bool) error {
+func (am *DefaultAccountManager) SavePolicy(ctx context.Context, accountID, userID string, policy *Policy) (*Policy, error) {
unlock := am.Store.AcquireWriteLockByUID(ctx, accountID)
defer unlock()
- account, err := am.Store.GetAccount(ctx, accountID)
+ user, err := am.Store.GetUserByUserID(ctx, LockingStrengthShare, userID)
if err != nil {
- return err
+ return nil, err
}
- updateAccountPeers, err := am.savePolicy(account, policy, isUpdate)
+ if user.AccountID != accountID {
+ return nil, status.NewUserNotPartOfAccountError()
+ }
+
+ if user.IsRegularUser() {
+ return nil, status.NewAdminPermissionError()
+ }
+
+ var isUpdate = policy.ID != ""
+ var updateAccountPeers bool
+ var action = activity.PolicyAdded
+
+ err = am.Store.ExecuteInTransaction(ctx, func(transaction Store) error {
+ if err = validatePolicy(ctx, transaction, accountID, policy); err != nil {
+ return err
+ }
+
+ updateAccountPeers, err = arePolicyChangesAffectPeers(ctx, transaction, accountID, policy, isUpdate)
+ if err != nil {
+ return err
+ }
+
+ if err = transaction.IncrementNetworkSerial(ctx, LockingStrengthUpdate, accountID); err != nil {
+ return err
+ }
+
+ saveFunc := transaction.CreatePolicy
+ if isUpdate {
+ action = activity.PolicyUpdated
+ saveFunc = transaction.SavePolicy
+ }
+
+ return saveFunc(ctx, LockingStrengthUpdate, policy)
+ })
if err != nil {
- return err
+ return nil, err
}
- account.Network.IncSerial()
- if err = am.Store.SaveAccount(ctx, account); err != nil {
- return err
- }
-
- action := activity.PolicyAdded
- if isUpdate {
- action = activity.PolicyUpdated
- }
am.StoreEvent(ctx, userID, policy.ID, accountID, action, policy.EventMeta())
if updateAccountPeers {
am.updateAccountPeers(ctx, accountID)
}
+ return policy, nil
+}
+
+// DeletePolicy from the store
+func (am *DefaultAccountManager) DeletePolicy(ctx context.Context, accountID, policyID, userID string) error {
+ unlock := am.Store.AcquireWriteLockByUID(ctx, accountID)
+ defer unlock()
+
+ user, err := am.Store.GetUserByUserID(ctx, LockingStrengthShare, userID)
+ if err != nil {
+ return err
+ }
+
+ if user.AccountID != accountID {
+ return status.NewUserNotPartOfAccountError()
+ }
+
+ if user.IsRegularUser() {
+ return status.NewAdminPermissionError()
+ }
+
+ var policy *Policy
+ var updateAccountPeers bool
+
+ err = am.Store.ExecuteInTransaction(ctx, func(transaction Store) error {
+ policy, err = transaction.GetPolicyByID(ctx, LockingStrengthUpdate, accountID, policyID)
+ if err != nil {
+ return err
+ }
+
+ updateAccountPeers, err = arePolicyChangesAffectPeers(ctx, transaction, accountID, policy, false)
+ if err != nil {
+ return err
+ }
+
+ if err = transaction.IncrementNetworkSerial(ctx, LockingStrengthUpdate, accountID); err != nil {
+ return err
+ }
+
+ return transaction.DeletePolicy(ctx, LockingStrengthUpdate, accountID, policyID)
+ })
+ if err != nil {
+ return err
+ }
+
+ am.StoreEvent(ctx, userID, policyID, accountID, activity.PolicyRemoved, policy.EventMeta())
+
+ if updateAccountPeers {
+ am.updateAccountPeers(ctx, accountID)
+ }
+
return nil
}
-// DeletePolicy from the store
-func (am *DefaultAccountManager) DeletePolicy(ctx context.Context, accountID, policyID, userID string) error {
- unlock := am.Store.AcquireWriteLockByUID(ctx, accountID)
- defer unlock()
-
- account, err := am.Store.GetAccount(ctx, accountID)
- if err != nil {
- return err
- }
-
- policy, err := am.deletePolicy(account, policyID)
- if err != nil {
- return err
- }
-
- account.Network.IncSerial()
- if err = am.Store.SaveAccount(ctx, account); err != nil {
- return err
- }
-
- am.StoreEvent(ctx, userID, policy.ID, accountID, activity.PolicyRemoved, policy.EventMeta())
-
- if am.anyGroupHasPeers(account, policy.ruleGroups()) {
- am.updateAccountPeers(ctx, accountID)
- }
-
- return nil
-}
-
-// ListPolicies from the store
+// ListPolicies from the store.
func (am *DefaultAccountManager) ListPolicies(ctx context.Context, accountID, userID string) ([]*Policy, error) {
user, err := am.Store.GetUserByUserID(ctx, LockingStrengthShare, userID)
if err != nil {
return nil, err
}
- if !user.IsAdminOrServiceUser() || user.AccountID != accountID {
- return nil, status.Errorf(status.PermissionDenied, "only users with admin power are allowed to view policies")
+ if user.AccountID != accountID {
+ return nil, status.NewUserNotPartOfAccountError()
+ }
+
+ if user.IsRegularUser() {
+ return nil, status.NewAdminPermissionError()
}
return am.Store.GetAccountPolicies(ctx, LockingStrengthShare, accountID)
}
-func (am *DefaultAccountManager) deletePolicy(account *Account, policyID string) (*Policy, error) {
- policyIdx := -1
- for i, policy := range account.Policies {
- if policy.ID == policyID {
- policyIdx = i
- break
- }
- }
- if policyIdx < 0 {
- return nil, status.Errorf(status.NotFound, "rule with ID %s doesn't exist", policyID)
- }
-
- policy := account.Policies[policyIdx]
- account.Policies = append(account.Policies[:policyIdx], account.Policies[policyIdx+1:]...)
- return policy, nil
-}
-
-// savePolicy saves or updates a policy in the given account.
-// If isUpdate is true, the function updates the existing policy; otherwise, it adds a new policy.
-func (am *DefaultAccountManager) savePolicy(account *Account, policyToSave *Policy, isUpdate bool) (bool, error) {
- for index, rule := range policyToSave.Rules {
- rule.Sources = filterValidGroupIDs(account, rule.Sources)
- rule.Destinations = filterValidGroupIDs(account, rule.Destinations)
- policyToSave.Rules[index] = rule
- }
-
- if policyToSave.SourcePostureChecks != nil {
- policyToSave.SourcePostureChecks = filterValidPostureChecks(account, policyToSave.SourcePostureChecks)
- }
-
+// arePolicyChangesAffectPeers checks if changes to a policy will affect any associated peers.
+func arePolicyChangesAffectPeers(ctx context.Context, transaction Store, accountID string, policy *Policy, isUpdate bool) (bool, error) {
if isUpdate {
- policyIdx := slices.IndexFunc(account.Policies, func(policy *Policy) bool { return policy.ID == policyToSave.ID })
- if policyIdx < 0 {
- return false, status.Errorf(status.NotFound, "couldn't find policy id %s", policyToSave.ID)
+ existingPolicy, err := transaction.GetPolicyByID(ctx, LockingStrengthShare, accountID, policy.ID)
+ if err != nil {
+ return false, err
}
- oldPolicy := account.Policies[policyIdx]
- // Update the existing policy
- account.Policies[policyIdx] = policyToSave
-
- if !policyToSave.Enabled && !oldPolicy.Enabled {
+ if !policy.Enabled && !existingPolicy.Enabled {
return false, nil
}
- updateAccountPeers := am.anyGroupHasPeers(account, oldPolicy.ruleGroups()) || am.anyGroupHasPeers(account, policyToSave.ruleGroups())
- return updateAccountPeers, nil
- }
+ hasPeers, err := anyGroupHasPeers(ctx, transaction, policy.AccountID, existingPolicy.ruleGroups())
+ if err != nil {
+ return false, err
+ }
- // Add the new policy to the account
- account.Policies = append(account.Policies, policyToSave)
-
- return am.anyGroupHasPeers(account, policyToSave.ruleGroups()), nil
-}
-
-func toProtocolFirewallRules(rules []*FirewallRule) []*proto.FirewallRule {
- result := make([]*proto.FirewallRule, len(rules))
- for i := range rules {
- rule := rules[i]
-
- result[i] = &proto.FirewallRule{
- PeerIP: rule.PeerIP,
- Direction: getProtoDirection(rule.Direction),
- Action: getProtoAction(rule.Action),
- Protocol: getProtoProtocol(rule.Protocol),
- Port: rule.Port,
+ if hasPeers {
+ return true, nil
}
}
- return result
+
+ return anyGroupHasPeers(ctx, transaction, policy.AccountID, policy.ruleGroups())
+}
+
+// validatePolicy validates the policy and its rules.
+func validatePolicy(ctx context.Context, transaction Store, accountID string, policy *Policy) error {
+ if policy.ID != "" {
+ _, err := transaction.GetPolicyByID(ctx, LockingStrengthShare, accountID, policy.ID)
+ if err != nil {
+ return err
+ }
+ } else {
+ policy.ID = xid.New().String()
+ policy.AccountID = accountID
+ }
+
+ groups, err := transaction.GetGroupsByIDs(ctx, LockingStrengthShare, accountID, policy.ruleGroups())
+ if err != nil {
+ return err
+ }
+
+ postureChecks, err := transaction.GetPostureChecksByIDs(ctx, LockingStrengthShare, accountID, policy.SourcePostureChecks)
+ if err != nil {
+ return err
+ }
+
+ for i, rule := range policy.Rules {
+ ruleCopy := rule.Copy()
+ if ruleCopy.ID == "" {
+ ruleCopy.ID = policy.ID // TODO: when policy can contain multiple rules, need refactor
+ ruleCopy.PolicyID = policy.ID
+ }
+
+ ruleCopy.Sources = getValidGroupIDs(groups, ruleCopy.Sources)
+ ruleCopy.Destinations = getValidGroupIDs(groups, ruleCopy.Destinations)
+ policy.Rules[i] = ruleCopy
+ }
+
+ if policy.SourcePostureChecks != nil {
+ policy.SourcePostureChecks = getValidPostureCheckIDs(postureChecks, policy.SourcePostureChecks)
+ }
+
+ return nil
}
// getAllPeersFromGroups for given peer ID and list of groups
@@ -574,27 +626,42 @@ func (a *Account) getPostureChecks(postureChecksID string) *posture.Checks {
return nil
}
-// filterValidPostureChecks filters and returns the posture check IDs from the given list
-// that are valid within the provided account.
-func filterValidPostureChecks(account *Account, postureChecksIds []string) []string {
- result := make([]string, 0, len(postureChecksIds))
+// getValidPostureCheckIDs filters and returns only the valid posture check IDs from the provided list.
+func getValidPostureCheckIDs(postureChecks map[string]*posture.Checks, postureChecksIds []string) []string {
+ validIDs := make([]string, 0, len(postureChecksIds))
for _, id := range postureChecksIds {
- for _, postureCheck := range account.PostureChecks {
- if id == postureCheck.ID {
- result = append(result, id)
- continue
- }
+ if _, exists := postureChecks[id]; exists {
+ validIDs = append(validIDs, id)
}
}
- return result
+
+ return validIDs
}
-// filterValidGroupIDs filters a list of group IDs and returns only the ones present in the account's group map.
-func filterValidGroupIDs(account *Account, groupIDs []string) []string {
- result := make([]string, 0, len(groupIDs))
- for _, groupID := range groupIDs {
- if _, exists := account.Groups[groupID]; exists {
- result = append(result, groupID)
+// getValidGroupIDs filters and returns only the valid group IDs from the provided list.
+func getValidGroupIDs(groups map[string]*nbgroup.Group, groupIDs []string) []string {
+ validIDs := make([]string, 0, len(groupIDs))
+ for _, id := range groupIDs {
+ if _, exists := groups[id]; exists {
+ validIDs = append(validIDs, id)
+ }
+ }
+
+ return validIDs
+}
+
+// toProtocolFirewallRules converts the firewall rules to the protocol firewall rules.
+func toProtocolFirewallRules(rules []*FirewallRule) []*proto.FirewallRule {
+ result := make([]*proto.FirewallRule, len(rules))
+ for i := range rules {
+ rule := rules[i]
+
+ result[i] = &proto.FirewallRule{
+ PeerIP: rule.PeerIP,
+ Direction: getProtoDirection(rule.Direction),
+ Action: getProtoAction(rule.Action),
+ Protocol: getProtoProtocol(rule.Protocol),
+ Port: rule.Port,
}
}
return result
diff --git a/management/server/policy_test.go b/management/server/policy_test.go
index e7f0f9cd2..62d80f46e 100644
--- a/management/server/policy_test.go
+++ b/management/server/policy_test.go
@@ -7,7 +7,6 @@ import (
"testing"
"time"
- "github.com/rs/xid"
"github.com/stretchr/testify/assert"
"golang.org/x/exp/slices"
@@ -859,14 +858,23 @@ func TestPolicyAccountPeersUpdate(t *testing.T) {
manager.peersUpdateManager.CloseChannel(context.Background(), peer1.ID)
})
+ var policyWithGroupRulesNoPeers *Policy
+ var policyWithDestinationPeersOnly *Policy
+ var policyWithSourceAndDestinationPeers *Policy
+
// Saving policy with rule groups with no peers should not update account's peers and not send peer update
t.Run("saving policy with rule groups with no peers", func(t *testing.T) {
- policy := Policy{
- ID: "policy-rule-groups-no-peers",
- Enabled: true,
+ done := make(chan struct{})
+ go func() {
+ peerShouldNotReceiveUpdate(t, updMsg)
+ close(done)
+ }()
+
+ policyWithGroupRulesNoPeers, err = manager.SavePolicy(context.Background(), account.Id, userID, &Policy{
+ AccountID: account.Id,
+ Enabled: true,
Rules: []*PolicyRule{
{
- ID: xid.New().String(),
Enabled: true,
Sources: []string{"groupB"},
Destinations: []string{"groupC"},
@@ -874,15 +882,7 @@ func TestPolicyAccountPeersUpdate(t *testing.T) {
Action: PolicyTrafficActionAccept,
},
},
- }
-
- done := make(chan struct{})
- go func() {
- peerShouldNotReceiveUpdate(t, updMsg)
- close(done)
- }()
-
- err := manager.SavePolicy(context.Background(), account.Id, userID, &policy, false)
+ })
assert.NoError(t, err)
select {
@@ -895,12 +895,17 @@ func TestPolicyAccountPeersUpdate(t *testing.T) {
// Saving policy with source group containing peers, but destination group without peers should
// update account's peers and send peer update
t.Run("saving policy where source has peers but destination does not", func(t *testing.T) {
- policy := Policy{
- ID: "policy-source-has-peers-destination-none",
- Enabled: true,
+ done := make(chan struct{})
+ go func() {
+ peerShouldReceiveUpdate(t, updMsg)
+ close(done)
+ }()
+
+ _, err = manager.SavePolicy(context.Background(), account.Id, userID, &Policy{
+ AccountID: account.Id,
+ Enabled: true,
Rules: []*PolicyRule{
{
- ID: xid.New().String(),
Enabled: true,
Sources: []string{"groupA"},
Destinations: []string{"groupB"},
@@ -909,15 +914,7 @@ func TestPolicyAccountPeersUpdate(t *testing.T) {
Action: PolicyTrafficActionAccept,
},
},
- }
-
- done := make(chan struct{})
- go func() {
- peerShouldReceiveUpdate(t, updMsg)
- close(done)
- }()
-
- err := manager.SavePolicy(context.Background(), account.Id, userID, &policy, false)
+ })
assert.NoError(t, err)
select {
@@ -930,13 +927,18 @@ func TestPolicyAccountPeersUpdate(t *testing.T) {
// Saving policy with destination group containing peers, but source group without peers should
// update account's peers and send peer update
t.Run("saving policy where destination has peers but source does not", func(t *testing.T) {
- policy := Policy{
- ID: "policy-destination-has-peers-source-none",
- Enabled: true,
+ done := make(chan struct{})
+ go func() {
+ peerShouldReceiveUpdate(t, updMsg)
+ close(done)
+ }()
+
+ policyWithDestinationPeersOnly, err = manager.SavePolicy(context.Background(), account.Id, userID, &Policy{
+ AccountID: account.Id,
+ Enabled: true,
Rules: []*PolicyRule{
{
- ID: xid.New().String(),
- Enabled: false,
+ Enabled: true,
Sources: []string{"groupC"},
Destinations: []string{"groupD"},
Bidirectional: true,
@@ -944,15 +946,7 @@ func TestPolicyAccountPeersUpdate(t *testing.T) {
Action: PolicyTrafficActionAccept,
},
},
- }
-
- done := make(chan struct{})
- go func() {
- peerShouldReceiveUpdate(t, updMsg)
- close(done)
- }()
-
- err := manager.SavePolicy(context.Background(), account.Id, userID, &policy, false)
+ })
assert.NoError(t, err)
select {
@@ -965,12 +959,17 @@ func TestPolicyAccountPeersUpdate(t *testing.T) {
// Saving policy with destination and source groups containing peers should update account's peers
// and send peer update
t.Run("saving policy with source and destination groups with peers", func(t *testing.T) {
- policy := Policy{
- ID: "policy-source-destination-peers",
- Enabled: true,
+ done := make(chan struct{})
+ go func() {
+ peerShouldReceiveUpdate(t, updMsg)
+ close(done)
+ }()
+
+ policyWithSourceAndDestinationPeers, err = manager.SavePolicy(context.Background(), account.Id, userID, &Policy{
+ AccountID: account.Id,
+ Enabled: true,
Rules: []*PolicyRule{
{
- ID: xid.New().String(),
Enabled: true,
Sources: []string{"groupA"},
Destinations: []string{"groupD"},
@@ -978,15 +977,7 @@ func TestPolicyAccountPeersUpdate(t *testing.T) {
Action: PolicyTrafficActionAccept,
},
},
- }
-
- done := make(chan struct{})
- go func() {
- peerShouldReceiveUpdate(t, updMsg)
- close(done)
- }()
-
- err := manager.SavePolicy(context.Background(), account.Id, userID, &policy, false)
+ })
assert.NoError(t, err)
select {
@@ -999,28 +990,14 @@ func TestPolicyAccountPeersUpdate(t *testing.T) {
// Disabling policy with destination and source groups containing peers should update account's peers
// and send peer update
t.Run("disabling policy with source and destination groups with peers", func(t *testing.T) {
- policy := Policy{
- ID: "policy-source-destination-peers",
- Enabled: false,
- Rules: []*PolicyRule{
- {
- ID: xid.New().String(),
- Enabled: true,
- Sources: []string{"groupA"},
- Destinations: []string{"groupD"},
- Bidirectional: true,
- Action: PolicyTrafficActionAccept,
- },
- },
- }
-
done := make(chan struct{})
go func() {
peerShouldReceiveUpdate(t, updMsg)
close(done)
}()
- err := manager.SavePolicy(context.Background(), account.Id, userID, &policy, true)
+ policyWithSourceAndDestinationPeers.Enabled = false
+ policyWithSourceAndDestinationPeers, err = manager.SavePolicy(context.Background(), account.Id, userID, policyWithSourceAndDestinationPeers)
assert.NoError(t, err)
select {
@@ -1033,29 +1010,15 @@ func TestPolicyAccountPeersUpdate(t *testing.T) {
// Updating disabled policy with destination and source groups containing peers should not update account's peers
// or send peer update
t.Run("updating disabled policy with source and destination groups with peers", func(t *testing.T) {
- policy := Policy{
- ID: "policy-source-destination-peers",
- Description: "updated description",
- Enabled: false,
- Rules: []*PolicyRule{
- {
- ID: xid.New().String(),
- Enabled: true,
- Sources: []string{"groupA"},
- Destinations: []string{"groupA"},
- Bidirectional: true,
- Action: PolicyTrafficActionAccept,
- },
- },
- }
-
done := make(chan struct{})
go func() {
peerShouldNotReceiveUpdate(t, updMsg)
close(done)
}()
- err := manager.SavePolicy(context.Background(), account.Id, userID, &policy, true)
+ policyWithSourceAndDestinationPeers.Description = "updated description"
+ policyWithSourceAndDestinationPeers.Rules[0].Destinations = []string{"groupA"}
+ policyWithSourceAndDestinationPeers, err = manager.SavePolicy(context.Background(), account.Id, userID, policyWithSourceAndDestinationPeers)
assert.NoError(t, err)
select {
@@ -1068,28 +1031,14 @@ func TestPolicyAccountPeersUpdate(t *testing.T) {
// Enabling policy with destination and source groups containing peers should update account's peers
// and send peer update
t.Run("enabling policy with source and destination groups with peers", func(t *testing.T) {
- policy := Policy{
- ID: "policy-source-destination-peers",
- Enabled: true,
- Rules: []*PolicyRule{
- {
- ID: xid.New().String(),
- Enabled: true,
- Sources: []string{"groupA"},
- Destinations: []string{"groupD"},
- Bidirectional: true,
- Action: PolicyTrafficActionAccept,
- },
- },
- }
-
done := make(chan struct{})
go func() {
peerShouldReceiveUpdate(t, updMsg)
close(done)
}()
- err := manager.SavePolicy(context.Background(), account.Id, userID, &policy, true)
+ policyWithSourceAndDestinationPeers.Enabled = true
+ policyWithSourceAndDestinationPeers, err = manager.SavePolicy(context.Background(), account.Id, userID, policyWithSourceAndDestinationPeers)
assert.NoError(t, err)
select {
@@ -1101,15 +1050,13 @@ func TestPolicyAccountPeersUpdate(t *testing.T) {
// Deleting policy should trigger account peers update and send peer update
t.Run("deleting policy with source and destination groups with peers", func(t *testing.T) {
- policyID := "policy-source-destination-peers"
-
done := make(chan struct{})
go func() {
peerShouldReceiveUpdate(t, updMsg)
close(done)
}()
- err := manager.DeletePolicy(context.Background(), account.Id, policyID, userID)
+ err := manager.DeletePolicy(context.Background(), account.Id, policyWithSourceAndDestinationPeers.ID, userID)
assert.NoError(t, err)
select {
@@ -1123,14 +1070,13 @@ func TestPolicyAccountPeersUpdate(t *testing.T) {
// Deleting policy with destination group containing peers, but source group without peers should
// update account's peers and send peer update
t.Run("deleting policy where destination has peers but source does not", func(t *testing.T) {
- policyID := "policy-destination-has-peers-source-none"
done := make(chan struct{})
go func() {
peerShouldReceiveUpdate(t, updMsg)
close(done)
}()
- err := manager.DeletePolicy(context.Background(), account.Id, policyID, userID)
+ err := manager.DeletePolicy(context.Background(), account.Id, policyWithDestinationPeersOnly.ID, userID)
assert.NoError(t, err)
select {
@@ -1142,14 +1088,13 @@ func TestPolicyAccountPeersUpdate(t *testing.T) {
// Deleting policy with no peers in groups should not update account's peers and not send peer update
t.Run("deleting policy with no peers in groups", func(t *testing.T) {
- policyID := "policy-rule-groups-no-peers"
done := make(chan struct{})
go func() {
peerShouldNotReceiveUpdate(t, updMsg)
close(done)
}()
- err := manager.DeletePolicy(context.Background(), account.Id, policyID, userID)
+ err := manager.DeletePolicy(context.Background(), account.Id, policyWithGroupRulesNoPeers.ID, userID)
assert.NoError(t, err)
select {
diff --git a/management/server/posture_checks_test.go b/management/server/posture_checks_test.go
index 3c5c5fc79..93e5741cf 100644
--- a/management/server/posture_checks_test.go
+++ b/management/server/posture_checks_test.go
@@ -5,7 +5,6 @@ import (
"testing"
"time"
- "github.com/rs/xid"
"github.com/stretchr/testify/assert"
"github.com/stretchr/testify/require"
@@ -210,12 +209,10 @@ func TestPostureCheckAccountPeersUpdate(t *testing.T) {
}
})
- policy := Policy{
- ID: "policyA",
+ policy := &Policy{
Enabled: true,
Rules: []*PolicyRule{
{
- ID: xid.New().String(),
Enabled: true,
Sources: []string{"groupA"},
Destinations: []string{"groupA"},
@@ -234,7 +231,7 @@ func TestPostureCheckAccountPeersUpdate(t *testing.T) {
close(done)
}()
- err := manager.SavePolicy(context.Background(), account.Id, userID, &policy, false)
+ policy, err = manager.SavePolicy(context.Background(), account.Id, userID, policy)
assert.NoError(t, err)
select {
@@ -282,8 +279,7 @@ func TestPostureCheckAccountPeersUpdate(t *testing.T) {
}()
policy.SourcePostureChecks = []string{}
-
- err := manager.SavePolicy(context.Background(), account.Id, userID, &policy, true)
+ _, err := manager.SavePolicy(context.Background(), account.Id, userID, policy)
assert.NoError(t, err)
select {
@@ -316,12 +312,10 @@ func TestPostureCheckAccountPeersUpdate(t *testing.T) {
// Updating linked posture check to policy with no peers should not trigger account peers update and not send peer update
t.Run("updating linked posture check to policy with no peers", func(t *testing.T) {
- policy = Policy{
- ID: "policyB",
+ _, err = manager.SavePolicy(context.Background(), account.Id, userID, &Policy{
Enabled: true,
Rules: []*PolicyRule{
{
- ID: xid.New().String(),
Enabled: true,
Sources: []string{"groupB"},
Destinations: []string{"groupC"},
@@ -330,8 +324,7 @@ func TestPostureCheckAccountPeersUpdate(t *testing.T) {
},
},
SourcePostureChecks: []string{postureCheckB.ID},
- }
- err = manager.SavePolicy(context.Background(), account.Id, userID, &policy, false)
+ })
assert.NoError(t, err)
done := make(chan struct{})
@@ -362,12 +355,11 @@ func TestPostureCheckAccountPeersUpdate(t *testing.T) {
t.Cleanup(func() {
manager.peersUpdateManager.CloseChannel(context.Background(), peer2.ID)
})
- policy = Policy{
- ID: "policyB",
+
+ _, err = manager.SavePolicy(context.Background(), account.Id, userID, &Policy{
Enabled: true,
Rules: []*PolicyRule{
{
- ID: xid.New().String(),
Enabled: true,
Sources: []string{"groupB"},
Destinations: []string{"groupA"},
@@ -376,9 +368,7 @@ func TestPostureCheckAccountPeersUpdate(t *testing.T) {
},
},
SourcePostureChecks: []string{postureCheckB.ID},
- }
-
- err = manager.SavePolicy(context.Background(), account.Id, userID, &policy, true)
+ })
assert.NoError(t, err)
done := make(chan struct{})
@@ -405,8 +395,7 @@ func TestPostureCheckAccountPeersUpdate(t *testing.T) {
// Updating linked client posture check to policy where source has peers but destination does not,
// should trigger account peers update and send peer update
t.Run("updating linked posture check to policy where source has peers but destination does not", func(t *testing.T) {
- policy = Policy{
- ID: "policyB",
+ _, err = manager.SavePolicy(context.Background(), account.Id, userID, &Policy{
Enabled: true,
Rules: []*PolicyRule{
{
@@ -418,8 +407,7 @@ func TestPostureCheckAccountPeersUpdate(t *testing.T) {
},
},
SourcePostureChecks: []string{postureCheckB.ID},
- }
- err = manager.SavePolicy(context.Background(), account.Id, userID, &policy, true)
+ })
assert.NoError(t, err)
done := make(chan struct{})
@@ -490,12 +478,9 @@ func TestArePostureCheckChangesAffectPeers(t *testing.T) {
require.NoError(t, err, "failed to save postureCheckB")
policy := &Policy{
- ID: "policyA",
AccountID: account.Id,
Rules: []*PolicyRule{
{
- ID: "ruleA",
- PolicyID: "policyA",
Enabled: true,
Sources: []string{"groupA"},
Destinations: []string{"groupA"},
@@ -504,7 +489,7 @@ func TestArePostureCheckChangesAffectPeers(t *testing.T) {
SourcePostureChecks: []string{postureCheckA.ID},
}
- err = manager.SavePolicy(context.Background(), account.Id, userID, policy, false)
+ policy, err = manager.SavePolicy(context.Background(), account.Id, adminUserID, policy)
require.NoError(t, err, "failed to save policy")
t.Run("posture check exists and is linked to policy with peers", func(t *testing.T) {
@@ -528,7 +513,7 @@ func TestArePostureCheckChangesAffectPeers(t *testing.T) {
t.Run("posture check is linked to policy with no peers in source groups", func(t *testing.T) {
policy.Rules[0].Sources = []string{"groupB"}
policy.Rules[0].Destinations = []string{"groupA"}
- err = manager.SavePolicy(context.Background(), account.Id, userID, policy, true)
+ _, err = manager.SavePolicy(context.Background(), account.Id, adminUserID, policy)
require.NoError(t, err, "failed to update policy")
result, err := arePostureCheckChangesAffectPeers(context.Background(), manager.Store, account.Id, postureCheckA.ID)
@@ -539,7 +524,7 @@ func TestArePostureCheckChangesAffectPeers(t *testing.T) {
t.Run("posture check is linked to policy with no peers in destination groups", func(t *testing.T) {
policy.Rules[0].Sources = []string{"groupA"}
policy.Rules[0].Destinations = []string{"groupB"}
- err = manager.SavePolicy(context.Background(), account.Id, userID, policy, true)
+ _, err = manager.SavePolicy(context.Background(), account.Id, adminUserID, policy)
require.NoError(t, err, "failed to update policy")
result, err := arePostureCheckChangesAffectPeers(context.Background(), manager.Store, account.Id, postureCheckA.ID)
@@ -560,7 +545,7 @@ func TestArePostureCheckChangesAffectPeers(t *testing.T) {
t.Run("posture check is linked to policy with non-existent group", func(t *testing.T) {
policy.Rules[0].Sources = []string{"nonExistentGroup"}
policy.Rules[0].Destinations = []string{"nonExistentGroup"}
- err = manager.SavePolicy(context.Background(), account.Id, userID, policy, true)
+ _, err = manager.SavePolicy(context.Background(), account.Id, adminUserID, policy)
require.NoError(t, err, "failed to update policy")
result, err := arePostureCheckChangesAffectPeers(context.Background(), manager.Store, account.Id, postureCheckA.ID)
diff --git a/management/server/route_test.go b/management/server/route_test.go
index 5c848f68c..108f791e0 100644
--- a/management/server/route_test.go
+++ b/management/server/route_test.go
@@ -1214,12 +1214,11 @@ func TestGetNetworkMap_RouteSync(t *testing.T) {
defaultRule := rules[0]
newPolicy := defaultRule.Copy()
- newPolicy.ID = xid.New().String()
newPolicy.Name = "peer1 only"
newPolicy.Rules[0].Sources = []string{newGroup.ID}
newPolicy.Rules[0].Destinations = []string{newGroup.ID}
- err = am.SavePolicy(context.Background(), account.Id, userID, newPolicy, false)
+ _, err = am.SavePolicy(context.Background(), account.Id, userID, newPolicy)
require.NoError(t, err)
err = am.DeletePolicy(context.Background(), account.Id, defaultRule.ID, userID)
diff --git a/management/server/setupkey_test.go b/management/server/setupkey_test.go
index 7c8200706..614547c60 100644
--- a/management/server/setupkey_test.go
+++ b/management/server/setupkey_test.go
@@ -406,8 +406,7 @@ func TestSetupKeyAccountPeersUpdate(t *testing.T) {
})
assert.NoError(t, err)
- policy := Policy{
- ID: "policy",
+ policy := &Policy{
Enabled: true,
Rules: []*PolicyRule{
{
@@ -419,7 +418,7 @@ func TestSetupKeyAccountPeersUpdate(t *testing.T) {
},
},
}
- err = manager.SavePolicy(context.Background(), account.Id, userID, &policy, false)
+ _, err = manager.SavePolicy(context.Background(), account.Id, userID, policy)
require.NoError(t, err)
updMsg := manager.peersUpdateManager.CreateChannel(context.Background(), peer1.ID)
diff --git a/management/server/sql_store.go b/management/server/sql_store.go
index 47c17bb92..9a24857d1 100644
--- a/management/server/sql_store.go
+++ b/management/server/sql_store.go
@@ -1243,8 +1243,8 @@ func (s *SqlStore) GetGroupsByIDs(ctx context.Context, lockStrength LockingStren
var groups []*nbgroup.Group
result := s.db.Clauses(clause.Locking{Strength: string(lockStrength)}).Find(&groups, accountAndIDsQueryCondition, accountID, groupIDs)
if result.Error != nil {
- log.WithContext(ctx).Errorf("failed to get groups by ID's from the store: %s", result.Error)
- return nil, status.Errorf(status.Internal, "failed to get groups by ID's from the store")
+ log.WithContext(ctx).Errorf("failed to get groups by ID's from store: %s", result.Error)
+ return nil, status.Errorf(status.Internal, "failed to get groups by ID's from store")
}
groupsMap := make(map[string]*nbgroup.Group)
@@ -1295,12 +1295,67 @@ func (s *SqlStore) DeleteGroups(ctx context.Context, strength LockingStrength, a
// GetAccountPolicies retrieves policies for an account.
func (s *SqlStore) GetAccountPolicies(ctx context.Context, lockStrength LockingStrength, accountID string) ([]*Policy, error) {
- return getRecords[*Policy](s.db.Preload(clause.Associations), lockStrength, accountID)
+ var policies []*Policy
+ result := s.db.Clauses(clause.Locking{Strength: string(lockStrength)}).
+ Preload(clause.Associations).Find(&policies, accountIDCondition, accountID)
+ if err := result.Error; err != nil {
+ log.WithContext(ctx).Errorf("failed to get policies from the store: %s", result.Error)
+ return nil, status.Errorf(status.Internal, "failed to get policies from store")
+ }
+
+ return policies, nil
}
// GetPolicyByID retrieves a policy by its ID and account ID.
-func (s *SqlStore) GetPolicyByID(ctx context.Context, lockStrength LockingStrength, policyID string, accountID string) (*Policy, error) {
- return getRecordByID[Policy](s.db.Preload(clause.Associations), lockStrength, policyID, accountID)
+func (s *SqlStore) GetPolicyByID(ctx context.Context, lockStrength LockingStrength, accountID, policyID string) (*Policy, error) {
+ var policy *Policy
+ result := s.db.Clauses(clause.Locking{Strength: string(lockStrength)}).Preload(clause.Associations).
+ First(&policy, accountAndIDQueryCondition, accountID, policyID)
+ if err := result.Error; err != nil {
+ if errors.Is(err, gorm.ErrRecordNotFound) {
+ return nil, status.NewPolicyNotFoundError(policyID)
+ }
+ log.WithContext(ctx).Errorf("failed to get policy from store: %s", err)
+ return nil, status.Errorf(status.Internal, "failed to get policy from store")
+ }
+
+ return policy, nil
+}
+
+func (s *SqlStore) CreatePolicy(ctx context.Context, lockStrength LockingStrength, policy *Policy) error {
+ result := s.db.Clauses(clause.Locking{Strength: string(lockStrength)}).Create(policy)
+ if result.Error != nil {
+ log.WithContext(ctx).Errorf("failed to create policy in store: %s", result.Error)
+ return status.Errorf(status.Internal, "failed to create policy in store")
+ }
+
+ return nil
+}
+
+// SavePolicy saves a policy to the database.
+func (s *SqlStore) SavePolicy(ctx context.Context, lockStrength LockingStrength, policy *Policy) error {
+ result := s.db.Session(&gorm.Session{FullSaveAssociations: true}).
+ Clauses(clause.Locking{Strength: string(lockStrength)}).Save(policy)
+ if err := result.Error; err != nil {
+ log.WithContext(ctx).Errorf("failed to save policy to the store: %s", err)
+ return status.Errorf(status.Internal, "failed to save policy to store")
+ }
+ return nil
+}
+
+func (s *SqlStore) DeletePolicy(ctx context.Context, lockStrength LockingStrength, accountID, policyID string) error {
+ result := s.db.Clauses(clause.Locking{Strength: string(lockStrength)}).
+ Delete(&Policy{}, accountAndIDQueryCondition, accountID, policyID)
+ if err := result.Error; err != nil {
+ log.WithContext(ctx).Errorf("failed to delete policy from store: %s", err)
+ return status.Errorf(status.Internal, "failed to delete policy from store")
+ }
+
+ if result.RowsAffected == 0 {
+ return status.NewPolicyNotFoundError(policyID)
+ }
+
+ return nil
}
// GetAccountPostureChecks retrieves posture checks for an account.
@@ -1331,6 +1386,23 @@ func (s *SqlStore) GetPostureChecksByID(ctx context.Context, lockStrength Lockin
return postureCheck, nil
}
+// GetPostureChecksByIDs retrieves posture checks by their IDs and account ID.
+func (s *SqlStore) GetPostureChecksByIDs(ctx context.Context, lockStrength LockingStrength, accountID string, postureChecksIDs []string) (map[string]*posture.Checks, error) {
+ var postureChecks []*posture.Checks
+ result := s.db.Clauses(clause.Locking{Strength: string(lockStrength)}).Find(&postureChecks, accountAndIDsQueryCondition, accountID, postureChecksIDs)
+ if result.Error != nil {
+ log.WithContext(ctx).Errorf("failed to get posture checks by ID's from store: %s", result.Error)
+ return nil, status.Errorf(status.Internal, "failed to get posture checks by ID's from store")
+ }
+
+ postureChecksMap := make(map[string]*posture.Checks)
+ for _, postureCheck := range postureChecks {
+ postureChecksMap[postureCheck.ID] = postureCheck
+ }
+
+ return postureChecksMap, nil
+}
+
// SavePostureChecks saves a posture checks to the database.
func (s *SqlStore) SavePostureChecks(ctx context.Context, lockStrength LockingStrength, postureCheck *posture.Checks) error {
result := s.db.Clauses(clause.Locking{Strength: string(lockStrength)}).Save(postureCheck)
diff --git a/management/server/sql_store_test.go b/management/server/sql_store_test.go
index de939e8d0..c05793fc6 100644
--- a/management/server/sql_store_test.go
+++ b/management/server/sql_store_test.go
@@ -1612,6 +1612,49 @@ func TestSqlStore_GetPostureChecksByID(t *testing.T) {
}
}
+func TestSqlStore_GetPostureChecksByIDs(t *testing.T) {
+ store, cleanup, err := NewTestStoreFromSQL(context.Background(), "testdata/extended-store.sql", t.TempDir())
+ t.Cleanup(cleanup)
+ require.NoError(t, err)
+
+ accountID := "bf1c8084-ba50-4ce7-9439-34653001fc3b"
+
+ tests := []struct {
+ name string
+ postureCheckIDs []string
+ expectedCount int
+ }{
+ {
+ name: "retrieve existing posture checks by existing IDs",
+ postureCheckIDs: []string{"csplshq7qv948l48f7t0", "cspnllq7qv95uq1r4k90"},
+ expectedCount: 2,
+ },
+ {
+ name: "empty posture check IDs list",
+ postureCheckIDs: []string{},
+ expectedCount: 0,
+ },
+ {
+ name: "non-existing posture check IDs",
+ postureCheckIDs: []string{"nonexistent1", "nonexistent2"},
+ expectedCount: 0,
+ },
+ {
+ name: "mixed existing and non-existing posture check IDs",
+ postureCheckIDs: []string{"cspnllq7qv95uq1r4k90", "nonexistent"},
+ expectedCount: 1,
+ },
+ }
+
+ for _, tt := range tests {
+ t.Run(tt.name, func(t *testing.T) {
+ groups, err := store.GetPostureChecksByIDs(context.Background(), LockingStrengthShare, accountID, tt.postureCheckIDs)
+ require.NoError(t, err)
+ require.Len(t, groups, tt.expectedCount)
+ })
+ }
+}
+
func TestSqlStore_SavePostureChecks(t *testing.T) {
store, cleanup, err := NewTestStoreFromSQL(context.Background(), "testdata/extended-store.sql", t.TempDir())
t.Cleanup(cleanup)
@@ -1699,3 +1742,118 @@ func TestSqlStore_DeletePostureChecks(t *testing.T) {
})
}
}
+
+func TestSqlStore_GetPolicyByID(t *testing.T) {
+ store, cleanup, err := NewTestStoreFromSQL(context.Background(), "testdata/store.sql", t.TempDir())
+ t.Cleanup(cleanup)
+ require.NoError(t, err)
+
+ accountID := "bf1c8084-ba50-4ce7-9439-34653001fc3b"
+ tests := []struct {
+ name string
+ policyID string
+ expectError bool
+ }{
+ {
+ name: "retrieve existing policy",
+ policyID: "cs1tnh0hhcjnqoiuebf0",
+ expectError: false,
+ },
+ {
+ name: "retrieve non-existing policy checks",
+ policyID: "non-existing",
+ expectError: true,
+ },
+ {
+ name: "retrieve with empty policy ID",
+ policyID: "",
+ expectError: true,
+ },
+ }
+
+ for _, tt := range tests {
+ t.Run(tt.name, func(t *testing.T) {
+ policy, err := store.GetPolicyByID(context.Background(), LockingStrengthShare, accountID, tt.policyID)
+ if tt.expectError {
+ require.Error(t, err)
+ sErr, ok := status.FromError(err)
+ require.True(t, ok)
+ require.Equal(t, sErr.Type(), status.NotFound)
+ require.Nil(t, policy)
+ } else {
+ require.NoError(t, err)
+ require.NotNil(t, policy)
+ require.Equal(t, tt.policyID, policy.ID)
+ }
+ })
+ }
+}
+
+func TestSqlStore_CreatePolicy(t *testing.T) {
+ store, cleanup, err := NewTestStoreFromSQL(context.Background(), "testdata/store.sql", t.TempDir())
+ t.Cleanup(cleanup)
+ require.NoError(t, err)
+
+ accountID := "bf1c8084-ba50-4ce7-9439-34653001fc3b"
+
+ policy := &Policy{
+ ID: "policy-id",
+ AccountID: accountID,
+ Enabled: true,
+ Rules: []*PolicyRule{
+ {
+ Enabled: true,
+ Sources: []string{"groupA"},
+ Destinations: []string{"groupC"},
+ Bidirectional: true,
+ Action: PolicyTrafficActionAccept,
+ },
+ },
+ }
+ err = store.CreatePolicy(context.Background(), LockingStrengthUpdate, policy)
+ require.NoError(t, err)
+
+ savePolicy, err := store.GetPolicyByID(context.Background(), LockingStrengthShare, accountID, policy.ID)
+ require.NoError(t, err)
+ require.Equal(t, savePolicy, policy)
+
+}
+
+func TestSqlStore_SavePolicy(t *testing.T) {
+ store, cleanup, err := NewTestStoreFromSQL(context.Background(), "testdata/store.sql", t.TempDir())
+ t.Cleanup(cleanup)
+ require.NoError(t, err)
+
+ accountID := "bf1c8084-ba50-4ce7-9439-34653001fc3b"
+ policyID := "cs1tnh0hhcjnqoiuebf0"
+
+ policy, err := store.GetPolicyByID(context.Background(), LockingStrengthShare, accountID, policyID)
+ require.NoError(t, err)
+
+ policy.Enabled = false
+ policy.Description = "policy"
+ policy.Rules[0].Sources = []string{"group"}
+ policy.Rules[0].Ports = []string{"80", "443"}
+ err = store.SavePolicy(context.Background(), LockingStrengthUpdate, policy)
+ require.NoError(t, err)
+
+ savePolicy, err := store.GetPolicyByID(context.Background(), LockingStrengthShare, accountID, policy.ID)
+ require.NoError(t, err)
+ require.Equal(t, savePolicy, policy)
+}
+
+func TestSqlStore_DeletePolicy(t *testing.T) {
+ store, cleanup, err := NewTestStoreFromSQL(context.Background(), "testdata/store.sql", t.TempDir())
+ t.Cleanup(cleanup)
+ require.NoError(t, err)
+
+ accountID := "bf1c8084-ba50-4ce7-9439-34653001fc3b"
+ policyID := "cs1tnh0hhcjnqoiuebf0"
+
+ err = store.DeletePolicy(context.Background(), LockingStrengthShare, accountID, policyID)
+ require.NoError(t, err)
+
+ policy, err := store.GetPolicyByID(context.Background(), LockingStrengthShare, accountID, policyID)
+ require.Error(t, err)
+ require.Nil(t, policy)
+}
diff --git a/management/server/status/error.go b/management/server/status/error.go
index 44391e1f1..0fff53559 100644
--- a/management/server/status/error.go
+++ b/management/server/status/error.go
@@ -144,3 +144,8 @@ func NewGroupNotFoundError(groupID string) error {
func NewPostureChecksNotFoundError(postureChecksID string) error {
return Errorf(NotFound, "posture checks: %s not found", postureChecksID)
}
+
+// NewPolicyNotFoundError creates a new Error with NotFound type for a missing policy
+func NewPolicyNotFoundError(policyID string) error {
+ return Errorf(NotFound, "policy: %s not found", policyID)
+}
diff --git a/management/server/store.go b/management/server/store.go
index 03b5821e7..ba61d552d 100644
--- a/management/server/store.go
+++ b/management/server/store.go
@@ -80,11 +80,15 @@ type Store interface {
DeleteGroups(ctx context.Context, strength LockingStrength, accountID string, groupIDs []string) error
GetAccountPolicies(ctx context.Context, lockStrength LockingStrength, accountID string) ([]*Policy, error)
- GetPolicyByID(ctx context.Context, lockStrength LockingStrength, policyID string, accountID string) (*Policy, error)
+ GetPolicyByID(ctx context.Context, lockStrength LockingStrength, accountID, policyID string) (*Policy, error)
+ CreatePolicy(ctx context.Context, lockStrength LockingStrength, policy *Policy) error
+ SavePolicy(ctx context.Context, lockStrength LockingStrength, policy *Policy) error
+ DeletePolicy(ctx context.Context, lockStrength LockingStrength, accountID, policyID string) error
GetPostureCheckByChecksDefinition(accountID string, checks *posture.ChecksDefinition) (*posture.Checks, error)
GetAccountPostureChecks(ctx context.Context, lockStrength LockingStrength, accountID string) ([]*posture.Checks, error)
GetPostureChecksByID(ctx context.Context, lockStrength LockingStrength, accountID, postureCheckID string) (*posture.Checks, error)
+ GetPostureChecksByIDs(ctx context.Context, lockStrength LockingStrength, accountID string, postureChecksIDs []string) (map[string]*posture.Checks, error)
SavePostureChecks(ctx context.Context, lockStrength LockingStrength, postureCheck *posture.Checks) error
DeletePostureChecks(ctx context.Context, lockStrength LockingStrength, accountID, postureChecksID string) error
diff --git a/management/server/testdata/extended-store.sql b/management/server/testdata/extended-store.sql
index 1646ff4da..37db27316 100644
--- a/management/server/testdata/extended-store.sql
+++ b/management/server/testdata/extended-store.sql
@@ -35,4 +35,5 @@ INSERT INTO "groups" VALUES('cfefqs706sqkneg59g4g','bf1c8084-ba50-4ce7-9439-3465
INSERT INTO "groups" VALUES('cfefqs706sqkneg59g3g','bf1c8084-ba50-4ce7-9439-34653001fc3b','AwesomeGroup1','api','[]',0,'');
INSERT INTO "groups" VALUES('cfefqs706sqkneg59g2g','bf1c8084-ba50-4ce7-9439-34653001fc3b','AwesomeGroup2','api','[]',0,'');
INSERT INTO posture_checks VALUES('csplshq7qv948l48f7t0','NetBird Version > 0.32.0','','bf1c8084-ba50-4ce7-9439-34653001fc3b','{"NBVersionCheck":{"MinVersion":"0.31.0"}}');
+INSERT INTO posture_checks VALUES('cspnllq7qv95uq1r4k90','Allow Berlin and Deny local network 172.16.1.0/24','','bf1c8084-ba50-4ce7-9439-34653001fc3b','{"GeoLocationCheck":{"Locations":[{"CountryCode":"DE","CityName":"Berlin"}],"Action":"allow"},"PeerNetworkRangeCheck":{"Action":"deny","Ranges":["172.16.1.0/24"]}}');
INSERT INTO installations VALUES(1,'');
diff --git a/management/server/user_test.go b/management/server/user_test.go
index d4f560a54..498017afa 100644
--- a/management/server/user_test.go
+++ b/management/server/user_test.go
@@ -1279,8 +1279,7 @@ func TestUserAccountPeersUpdate(t *testing.T) {
})
require.NoError(t, err)
- policy := Policy{
- ID: "policy",
+ policy := &Policy{
Enabled: true,
Rules: []*PolicyRule{
{
@@ -1292,7 +1291,7 @@ func TestUserAccountPeersUpdate(t *testing.T) {
},
},
}
- err = manager.SavePolicy(context.Background(), account.Id, userID, &policy, false)
+ _, err = manager.SavePolicy(context.Background(), account.Id, userID, policy)
require.NoError(t, err)
updMsg := manager.peersUpdateManager.CreateChannel(context.Background(), peer1.ID)
From 0e48a772ff8de32cb6710f8d2f8c36a6bd468551 Mon Sep 17 00:00:00 2001
From: Bethuel Mmbaga
Date: Tue, 26 Nov 2024 15:43:05 +0300
Subject: [PATCH 39/39] [management] Refactor DNS settings to use store methods
(#2883)
---
management/server/dns.go | 164 ++++++++++++++++++++--------
management/server/sql_store.go | 21 +++-
management/server/sql_store_test.go | 63 +++++++++++
management/server/store.go | 1 +
4 files changed, 204 insertions(+), 45 deletions(-)
diff --git a/management/server/dns.go b/management/server/dns.go
index e52be6016..8df211b0b 100644
--- a/management/server/dns.go
+++ b/management/server/dns.go
@@ -3,6 +3,7 @@ package server
import (
"context"
"fmt"
+ "slices"
"strconv"
"sync"
@@ -85,8 +86,12 @@ func (am *DefaultAccountManager) GetDNSSettings(ctx context.Context, accountID s
return nil, err
}
- if !user.IsAdminOrServiceUser() || user.AccountID != accountID {
- return nil, status.Errorf(status.PermissionDenied, "only users with admin power are allowed to view DNS settings")
+ if user.AccountID != accountID {
+ return nil, status.NewUserNotPartOfAccountError()
+ }
+
+ if user.IsRegularUser() {
+ return nil, status.NewAdminPermissionError()
}
return am.Store.GetAccountDNSSettings(ctx, LockingStrengthShare, accountID)
@@ -94,64 +99,137 @@ func (am *DefaultAccountManager) GetDNSSettings(ctx context.Context, accountID s
// SaveDNSSettings validates a user role and updates the account's DNS settings
func (am *DefaultAccountManager) SaveDNSSettings(ctx context.Context, accountID string, userID string, dnsSettingsToSave *DNSSettings) error {
- unlock := am.Store.AcquireWriteLockByUID(ctx, accountID)
- defer unlock()
-
- account, err := am.Store.GetAccount(ctx, accountID)
- if err != nil {
- return err
- }
-
- user, err := account.FindUser(userID)
- if err != nil {
- return err
- }
-
- if !user.HasAdminPower() {
- return status.Errorf(status.PermissionDenied, "only users with admin power are allowed to update DNS settings")
- }
-
if dnsSettingsToSave == nil {
return status.Errorf(status.InvalidArgument, "the dns settings provided are nil")
}
- if len(dnsSettingsToSave.DisabledManagementGroups) != 0 {
- err = validateGroups(dnsSettingsToSave.DisabledManagementGroups, account.Groups)
- if err != nil {
- return err
- }
- }
-
- oldSettings := account.DNSSettings.Copy()
- account.DNSSettings = dnsSettingsToSave.Copy()
-
- addedGroups := difference(dnsSettingsToSave.DisabledManagementGroups, oldSettings.DisabledManagementGroups)
- removedGroups := difference(oldSettings.DisabledManagementGroups, dnsSettingsToSave.DisabledManagementGroups)
-
- account.Network.IncSerial()
- if err = am.Store.SaveAccount(ctx, account); err != nil {
+ user, err := am.Store.GetUserByUserID(ctx, LockingStrengthShare, userID)
+ if err != nil {
return err
}
- for _, id := range addedGroups {
- group := account.GetGroup(id)
- meta := map[string]any{"group": group.Name, "group_id": group.ID}
- am.StoreEvent(ctx, userID, accountID, accountID, activity.GroupAddedToDisabledManagementGroups, meta)
+ if user.AccountID != accountID {
+ return status.NewUserNotPartOfAccountError()
}
- for _, id := range removedGroups {
- group := account.GetGroup(id)
- meta := map[string]any{"group": group.Name, "group_id": group.ID}
- am.StoreEvent(ctx, userID, accountID, accountID, activity.GroupRemovedFromDisabledManagementGroups, meta)
+ if !user.HasAdminPower() {
+ return status.NewAdminPermissionError()
}
- if am.anyGroupHasPeers(account, addedGroups) || am.anyGroupHasPeers(account, removedGroups) {
+ var updateAccountPeers bool
+ var eventsToStore []func()
+
+ err = am.Store.ExecuteInTransaction(ctx, func(transaction Store) error {
+ if err = validateDNSSettings(ctx, transaction, accountID, dnsSettingsToSave); err != nil {
+ return err
+ }
+
+ oldSettings, err := transaction.GetAccountDNSSettings(ctx, LockingStrengthUpdate, accountID)
+ if err != nil {
+ return err
+ }
+
+ addedGroups := difference(dnsSettingsToSave.DisabledManagementGroups, oldSettings.DisabledManagementGroups)
+ removedGroups := difference(oldSettings.DisabledManagementGroups, dnsSettingsToSave.DisabledManagementGroups)
+
+ updateAccountPeers, err = areDNSSettingChangesAffectPeers(ctx, transaction, accountID, addedGroups, removedGroups)
+ if err != nil {
+ return err
+ }
+
+ events := am.prepareDNSSettingsEvents(ctx, transaction, accountID, userID, addedGroups, removedGroups)
+ eventsToStore = append(eventsToStore, events...)
+
+ if err = transaction.IncrementNetworkSerial(ctx, LockingStrengthUpdate, accountID); err != nil {
+ return err
+ }
+
+ return transaction.SaveDNSSettings(ctx, LockingStrengthUpdate, accountID, dnsSettingsToSave)
+ })
+ if err != nil {
+ return err
+ }
+
+ for _, storeEvent := range eventsToStore {
+ storeEvent()
+ }
+
+ if updateAccountPeers {
am.updateAccountPeers(ctx, accountID)
}
return nil
}
+// prepareDNSSettingsEvents prepares a list of event functions to be stored.
+func (am *DefaultAccountManager) prepareDNSSettingsEvents(ctx context.Context, transaction Store, accountID, userID string, addedGroups, removedGroups []string) []func() {
+ var eventsToStore []func()
+
+ modifiedGroups := slices.Concat(addedGroups, removedGroups)
+ groups, err := transaction.GetGroupsByIDs(ctx, LockingStrengthShare, accountID, modifiedGroups)
+ if err != nil {
+ log.WithContext(ctx).Debugf("failed to get groups for dns settings events: %v", err)
+ return nil
+ }
+
+ for _, groupID := range addedGroups {
+ group, ok := groups[groupID]
+ if !ok {
+ log.WithContext(ctx).Debugf("skipped adding group: %s GroupAddedToDisabledManagementGroups activity", groupID)
+ continue
+ }
+
+ eventsToStore = append(eventsToStore, func() {
+ meta := map[string]any{"group": group.Name, "group_id": group.ID}
+ am.StoreEvent(ctx, userID, accountID, accountID, activity.GroupAddedToDisabledManagementGroups, meta)
+ })
+
+ }
+
+ for _, groupID := range removedGroups {
+ group, ok := groups[groupID]
+ if !ok {
+ log.WithContext(ctx).Debugf("skipped adding group: %s GroupRemovedFromDisabledManagementGroups activity", groupID)
+ continue
+ }
+
+ eventsToStore = append(eventsToStore, func() {
+ meta := map[string]any{"group": group.Name, "group_id": group.ID}
+ am.StoreEvent(ctx, userID, accountID, accountID, activity.GroupRemovedFromDisabledManagementGroups, meta)
+ })
+ }
+
+ return eventsToStore
+}
+
+// areDNSSettingChangesAffectPeers checks if the DNS settings changes affect any peers.
+func areDNSSettingChangesAffectPeers(ctx context.Context, transaction Store, accountID string, addedGroups, removedGroups []string) (bool, error) {
+ hasPeers, err := anyGroupHasPeers(ctx, transaction, accountID, addedGroups)
+ if err != nil {
+ return false, err
+ }
+
+ if hasPeers {
+ return true, nil
+ }
+
+ return anyGroupHasPeers(ctx, transaction, accountID, removedGroups)
+}
+
+// validateDNSSettings validates the DNS settings.
+func validateDNSSettings(ctx context.Context, transaction Store, accountID string, settings *DNSSettings) error {
+ if len(settings.DisabledManagementGroups) == 0 {
+ return nil
+ }
+
+ groups, err := transaction.GetGroupsByIDs(ctx, LockingStrengthShare, accountID, settings.DisabledManagementGroups)
+ if err != nil {
+ return err
+ }
+
+ return validateGroups(settings.DisabledManagementGroups, groups)
+}
+
// toProtocolDNSConfig converts nbdns.Config to proto.DNSConfig using the cache
func toProtocolDNSConfig(update nbdns.Config, cache *DNSConfigCache) *proto.DNSConfig {
protoUpdate := &proto.DNSConfig{
diff --git a/management/server/sql_store.go b/management/server/sql_store.go
index 9a24857d1..f58ceb1ad 100644
--- a/management/server/sql_store.go
+++ b/management/server/sql_store.go
@@ -1162,9 +1162,10 @@ func (s *SqlStore) GetAccountDNSSettings(ctx context.Context, lockStrength Locki
First(&accountDNSSettings, idQueryCondition, accountID)
if result.Error != nil {
if errors.Is(result.Error, gorm.ErrRecordNotFound) {
- return nil, status.Errorf(status.NotFound, "dns settings not found")
+ return nil, status.NewAccountNotFoundError(accountID)
}
- return nil, status.Errorf(status.Internal, "failed to get dns settings from store: %v", result.Error)
+ log.WithContext(ctx).Errorf("failed to get dns settings from store: %v", result.Error)
+ return nil, status.Errorf(status.Internal, "failed to get dns settings from store")
}
return &accountDNSSettings.DNSSettings, nil
}
@@ -1537,3 +1538,19 @@ func getRecordByID[T any](db *gorm.DB, lockStrength LockingStrength, recordID, a
}
return &record, nil
}
+
+// SaveDNSSettings saves the DNS settings to the store.
+func (s *SqlStore) SaveDNSSettings(ctx context.Context, lockStrength LockingStrength, accountID string, settings *DNSSettings) error {
+ result := s.db.Clauses(clause.Locking{Strength: string(lockStrength)}).Model(&Account{}).
+ Where(idQueryCondition, accountID).Updates(&AccountDNSSettings{DNSSettings: *settings})
+ if result.Error != nil {
+ log.WithContext(ctx).Errorf("failed to save dns settings to store: %v", result.Error)
+ return status.Errorf(status.Internal, "failed to save dns settings to store")
+ }
+
+ if result.RowsAffected == 0 {
+ return status.NewAccountNotFoundError(accountID)
+ }
+
+ return nil
+}
diff --git a/management/server/sql_store_test.go b/management/server/sql_store_test.go
index c05793fc6..df5294d73 100644
--- a/management/server/sql_store_test.go
+++ b/management/server/sql_store_test.go
@@ -1857,3 +1857,66 @@ func TestSqlStore_DeletePolicy(t *testing.T) {
require.Error(t, err)
require.Nil(t, policy)
}
+
+func TestSqlStore_GetDNSSettings(t *testing.T) {
+ store, cleanup, err := NewTestStoreFromSQL(context.Background(), "testdata/store.sql", t.TempDir())
+ t.Cleanup(cleanup)
+ require.NoError(t, err)
+
+ tests := []struct {
+ name string
+ accountID string
+ expectError bool
+ }{
+ {
+ name: "retrieve existing account dns settings",
+ accountID: "bf1c8084-ba50-4ce7-9439-34653001fc3b",
+ expectError: false,
+ },
+ {
+ name: "retrieve non-existing account dns settings",
+ accountID: "non-existing",
+ expectError: true,
+ },
+ {
+ name: "retrieve dns settings with empty account ID",
+ accountID: "",
+ expectError: true,
+ },
+ }
+
+ for _, tt := range tests {
+ t.Run(tt.name, func(t *testing.T) {
+ dnsSettings, err := store.GetAccountDNSSettings(context.Background(), LockingStrengthShare, tt.accountID)
+ if tt.expectError {
+ require.Error(t, err)
+ sErr, ok := status.FromError(err)
+ require.True(t, ok)
+ require.Equal(t, sErr.Type(), status.NotFound)
+ require.Nil(t, dnsSettings)
+ } else {
+ require.NoError(t, err)
+ require.NotNil(t, dnsSettings)
+ }
+ })
+ }
+}
+
+func TestSqlStore_SaveDNSSettings(t *testing.T) {
+ store, cleanup, err := NewTestStoreFromSQL(context.Background(), "testdata/store.sql", t.TempDir())
+ t.Cleanup(cleanup)
+ require.NoError(t, err)
+
+ accountID := "bf1c8084-ba50-4ce7-9439-34653001fc3b"
+
+ dnsSettings, err := store.GetAccountDNSSettings(context.Background(), LockingStrengthShare, accountID)
+ require.NoError(t, err)
+
+ dnsSettings.DisabledManagementGroups = []string{"groupA", "groupB"}
+ err = store.SaveDNSSettings(context.Background(), LockingStrengthUpdate, accountID, dnsSettings)
+ require.NoError(t, err)
+
+ saveDNSSettings, err := store.GetAccountDNSSettings(context.Background(), LockingStrengthShare, accountID)
+ require.NoError(t, err)
+ require.Equal(t, saveDNSSettings, dnsSettings)
+}
diff --git a/management/server/store.go b/management/server/store.go
index ba61d552d..cca014b52 100644
--- a/management/server/store.go
+++ b/management/server/store.go
@@ -59,6 +59,7 @@ type Store interface {
SaveAccount(ctx context.Context, account *Account) error
DeleteAccount(ctx context.Context, account *Account) error
UpdateAccountDomainAttributes(ctx context.Context, accountID string, domain string, category string, isPrimaryDomain bool) error
+ SaveDNSSettings(ctx context.Context, lockStrength LockingStrength, accountID string, settings *DNSSettings) error
GetUserByTokenID(ctx context.Context, tokenID string) (*User, error)
GetUserByUserID(ctx context.Context, lockStrength LockingStrength, userID string) (*User, error)