From 39329e12a1f54cc1770a010279fa4902ccd51693 Mon Sep 17 00:00:00 2001 From: Viktor Liu <17948409+lixmal@users.noreply.github.com> Date: Wed, 13 Nov 2024 13:46:00 +0100 Subject: [PATCH 01/39] [client] Improve state write timeout and abort work early on timeout (#2882) * Improve state write timeout and abort work early on timeout * Don't block on initial persist state --- client/firewall/iptables/manager_linux.go | 8 +++++--- client/firewall/nftables/manager_linux.go | 8 +++++--- client/internal/config.go | 6 +++--- client/internal/dns/server.go | 10 +++++++--- client/internal/statemanager/manager.go | 20 +++++--------------- client/internal/statemanager/path.go | 22 +++++----------------- management/server/file_store.go | 2 +- util/file.go | 23 ++++++++++++++++++----- util/file_test.go | 7 ++++--- 9 files changed, 53 insertions(+), 53 deletions(-) diff --git a/client/firewall/iptables/manager_linux.go b/client/firewall/iptables/manager_linux.go index a59bd2c60..adb8f20ef 100644 --- a/client/firewall/iptables/manager_linux.go +++ b/client/firewall/iptables/manager_linux.go @@ -83,9 +83,11 @@ func (m *Manager) Init(stateManager *statemanager.Manager) error { } // persist early to ensure cleanup of chains - if err := stateManager.PersistState(context.Background()); err != nil { - log.Errorf("failed to persist state: %v", err) - } + go func() { + if err := stateManager.PersistState(context.Background()); err != nil { + log.Errorf("failed to persist state: %v", err) + } + }() return nil } diff --git a/client/firewall/nftables/manager_linux.go b/client/firewall/nftables/manager_linux.go index ea8912f27..3f8fac249 100644 --- a/client/firewall/nftables/manager_linux.go +++ b/client/firewall/nftables/manager_linux.go @@ -99,9 +99,11 @@ func (m *Manager) Init(stateManager *statemanager.Manager) error { } // persist early - if err := stateManager.PersistState(context.Background()); err != nil { - log.Errorf("failed to persist state: %v", err) - } + go func() { + if err := stateManager.PersistState(context.Background()); err != nil { + log.Errorf("failed to persist state: %v", err) + } + }() return nil } diff --git a/client/internal/config.go b/client/internal/config.go index ee54c6380..ce87835cd 100644 --- a/client/internal/config.go +++ b/client/internal/config.go @@ -164,7 +164,7 @@ func UpdateOrCreateConfig(input ConfigInput) (*Config, error) { if err != nil { return nil, err } - err = util.WriteJsonWithRestrictedPermission(input.ConfigPath, cfg) + err = util.WriteJsonWithRestrictedPermission(context.Background(), input.ConfigPath, cfg) return cfg, err } @@ -185,7 +185,7 @@ func CreateInMemoryConfig(input ConfigInput) (*Config, error) { // WriteOutConfig write put the prepared config to the given path func WriteOutConfig(path string, config *Config) error { - return util.WriteJson(path, config) + return util.WriteJson(context.Background(), path, config) } // createNewConfig creates a new config generating a new Wireguard key and saving to file @@ -215,7 +215,7 @@ func update(input ConfigInput) (*Config, error) { } if updated { - if err := util.WriteJson(input.ConfigPath, config); err != nil { + if err := util.WriteJson(context.Background(), input.ConfigPath, config); err != nil { return nil, err } } diff --git a/client/internal/dns/server.go b/client/internal/dns/server.go index 929e1e60c..6c4dccae7 100644 --- a/client/internal/dns/server.go +++ b/client/internal/dns/server.go @@ -326,9 +326,13 @@ func (s *DefaultServer) applyConfiguration(update nbdns.Config) error { // persist dns state right away ctx, cancel := context.WithTimeout(s.ctx, 3*time.Second) defer cancel() - if err := s.stateManager.PersistState(ctx); err != nil { - log.Errorf("Failed to persist dns state: %v", err) - } + + // don't block + go func() { + if err := s.stateManager.PersistState(ctx); err != nil { + log.Errorf("Failed to persist dns state: %v", err) + } + }() if s.searchDomainNotifier != nil { s.searchDomainNotifier.onNewSearchDomains(s.SearchDomains()) diff --git a/client/internal/statemanager/manager.go b/client/internal/statemanager/manager.go index a5a14f807..580ccdfc7 100644 --- a/client/internal/statemanager/manager.go +++ b/client/internal/statemanager/manager.go @@ -16,6 +16,7 @@ import ( "golang.org/x/exp/maps" nberrors "github.com/netbirdio/netbird/client/errors" + "github.com/netbirdio/netbird/util" ) // State interface defines the methods that all state types must implement @@ -178,25 +179,14 @@ func (m *Manager) PersistState(ctx context.Context) error { return nil } - ctx, cancel := context.WithTimeout(ctx, 3*time.Second) + ctx, cancel := context.WithTimeout(ctx, 5*time.Second) defer cancel() done := make(chan error, 1) + start := time.Now() go func() { - data, err := json.MarshalIndent(m.states, "", " ") - if err != nil { - done <- fmt.Errorf("marshal states: %w", err) - return - } - - // nolint:gosec - if err := os.WriteFile(m.filePath, data, 0640); err != nil { - done <- fmt.Errorf("write state file: %w", err) - return - } - - done <- nil + done <- util.WriteJsonWithRestrictedPermission(ctx, m.filePath, m.states) }() select { @@ -208,7 +198,7 @@ func (m *Manager) PersistState(ctx context.Context) error { } } - log.Debugf("persisted shutdown states: %v", maps.Keys(m.dirty)) + log.Debugf("persisted shutdown states: %v, took %v", maps.Keys(m.dirty), time.Since(start)) clear(m.dirty) diff --git a/client/internal/statemanager/path.go b/client/internal/statemanager/path.go index 96d6a9f12..6cfd79a12 100644 --- a/client/internal/statemanager/path.go +++ b/client/internal/statemanager/path.go @@ -4,32 +4,20 @@ import ( "os" "path/filepath" "runtime" - - log "github.com/sirupsen/logrus" ) // GetDefaultStatePath returns the path to the state file based on the operating system -// It returns an empty string if the path cannot be determined. It also creates the directory if it does not exist. +// It returns an empty string if the path cannot be determined. func GetDefaultStatePath() string { - var path string - switch runtime.GOOS { case "windows": - path = filepath.Join(os.Getenv("PROGRAMDATA"), "Netbird", "state.json") + return filepath.Join(os.Getenv("PROGRAMDATA"), "Netbird", "state.json") case "darwin", "linux": - path = "/var/lib/netbird/state.json" + return "/var/lib/netbird/state.json" case "freebsd", "openbsd", "netbsd", "dragonfly": - path = "/var/db/netbird/state.json" - // ios/android don't need state - default: - return "" + return "/var/db/netbird/state.json" } - dir := filepath.Dir(path) - if err := os.MkdirAll(dir, 0755); err != nil { - log.Errorf("Error creating directory %s: %v. Continuing without state support.", dir, err) - return "" - } + return "" - return path } diff --git a/management/server/file_store.go b/management/server/file_store.go index 561e133ce..f375fb990 100644 --- a/management/server/file_store.go +++ b/management/server/file_store.go @@ -223,7 +223,7 @@ func restore(ctx context.Context, file string) (*FileStore, error) { // It is recommended to call it with locking FileStore.mux func (s *FileStore) persist(ctx context.Context, file string) error { start := time.Now() - err := util.WriteJson(file, s) + err := util.WriteJson(context.Background(), file, s) if err != nil { return err } diff --git a/util/file.go b/util/file.go index ecaecd222..4641cc1b8 100644 --- a/util/file.go +++ b/util/file.go @@ -15,7 +15,7 @@ import ( ) // WriteJsonWithRestrictedPermission writes JSON config object to a file. Enforces permission on the parent directory -func WriteJsonWithRestrictedPermission(file string, obj interface{}) error { +func WriteJsonWithRestrictedPermission(ctx context.Context, file string, obj interface{}) error { configDir, configFileName, err := prepareConfigFileDir(file) if err != nil { return err @@ -26,18 +26,18 @@ func WriteJsonWithRestrictedPermission(file string, obj interface{}) error { return err } - return writeJson(file, obj, configDir, configFileName) + return writeJson(ctx, file, obj, configDir, configFileName) } // WriteJson writes JSON config object to a file creating parent directories if required // The output JSON is pretty-formatted -func WriteJson(file string, obj interface{}) error { +func WriteJson(ctx context.Context, file string, obj interface{}) error { configDir, configFileName, err := prepareConfigFileDir(file) if err != nil { return err } - return writeJson(file, obj, configDir, configFileName) + return writeJson(ctx, file, obj, configDir, configFileName) } // DirectWriteJson writes JSON config object to a file creating parent directories if required without creating a temporary file @@ -79,7 +79,11 @@ func DirectWriteJson(ctx context.Context, file string, obj interface{}) error { return nil } -func writeJson(file string, obj interface{}, configDir string, configFileName string) error { +func writeJson(ctx context.Context, file string, obj interface{}, configDir string, configFileName string) error { + // Check context before expensive operations + if ctx.Err() != nil { + return ctx.Err() + } // make it pretty bs, err := json.MarshalIndent(obj, "", " ") @@ -87,6 +91,10 @@ func writeJson(file string, obj interface{}, configDir string, configFileName st return err } + if ctx.Err() != nil { + return ctx.Err() + } + tempFile, err := os.CreateTemp(configDir, ".*"+configFileName) if err != nil { return err @@ -111,6 +119,11 @@ func writeJson(file string, obj interface{}, configDir string, configFileName st return err } + // Check context again + if ctx.Err() != nil { + return ctx.Err() + } + err = os.Rename(tempFileName, file) if err != nil { return err diff --git a/util/file_test.go b/util/file_test.go index 566d8eda6..f8c9dfabb 100644 --- a/util/file_test.go +++ b/util/file_test.go @@ -1,6 +1,7 @@ package util import ( + "context" "crypto/md5" "encoding/hex" "io" @@ -39,7 +40,7 @@ func TestConfigJSON(t *testing.T) { t.Run(tt.name, func(t *testing.T) { tmpDir := t.TempDir() - err := WriteJson(tmpDir+"/testconfig.json", tt.config) + err := WriteJson(context.Background(), tmpDir+"/testconfig.json", tt.config) require.NoError(t, err) read, err := ReadJson(tmpDir+"/testconfig.json", &TestConfig{}) @@ -73,7 +74,7 @@ func TestCopyFileContents(t *testing.T) { src := tmpDir + "/copytest_src" dst := tmpDir + "/copytest_dst" - err := WriteJson(src, tt.srcContent) + err := WriteJson(context.Background(), src, tt.srcContent) require.NoError(t, err) err = CopyFileContents(src, dst) @@ -127,7 +128,7 @@ func TestHandleConfigFileWithoutFullPath(t *testing.T) { _ = os.Remove(cfgFile) }() - err := WriteJson(cfgFile, tt.config) + err := WriteJson(context.Background(), cfgFile, tt.config) require.NoError(t, err) read, err := ReadJson(cfgFile, &TestConfig{}) From ed047ec9dda048120edf4f074162a27136ac3cd6 Mon Sep 17 00:00:00 2001 From: bcmmbaga Date: Wed, 13 Nov 2024 16:16:30 +0300 Subject: [PATCH 02/39] Add account locking and merge group deletion methods Signed-off-by: bcmmbaga --- management/server/group.go | 66 ++++++++++------------------------ management/server/sql_store.go | 2 +- 2 files changed, 20 insertions(+), 48 deletions(-) diff --git a/management/server/group.go b/management/server/group.go index 57960e7f9..154a33b13 100644 --- a/management/server/group.go +++ b/management/server/group.go @@ -215,48 +215,9 @@ func difference(a, b []string) []string { // DeleteGroup object of the peers. func (am *DefaultAccountManager) DeleteGroup(ctx context.Context, accountID, userID, groupID string) error { - user, err := am.Store.GetUserByUserID(ctx, LockingStrengthShare, userID) - if err != nil { - return err - } - - if user.AccountID != accountID { - return status.NewUserNotPartOfAccountError() - } - - if user.IsRegularUser() { - return status.NewAdminPermissionError() - } - - var group *nbgroup.Group - - err = am.Store.ExecuteInTransaction(ctx, func(transaction Store) error { - group, err = transaction.GetGroupByID(ctx, LockingStrengthShare, accountID, groupID) - if err != nil { - return err - } - - if group.IsGroupAll() { - return status.Errorf(status.InvalidArgument, "deleting group ALL is not allowed") - } - - if err = validateDeleteGroup(ctx, transaction, group, userID); err != nil { - return err - } - - if err = transaction.IncrementNetworkSerial(ctx, LockingStrengthUpdate, accountID); err != nil { - return err - } - - return transaction.DeleteGroup(ctx, LockingStrengthUpdate, accountID, groupID) - }) - if err != nil { - return err - } - - am.StoreEvent(ctx, userID, groupID, accountID, activity.GroupDeleted, group.EventMeta()) - - return nil + unlock := am.Store.AcquireWriteLockByUID(ctx, accountID) + defer unlock() + return am.DeleteGroups(ctx, accountID, userID, []string{groupID}) } // DeleteGroups deletes groups from an account. @@ -285,13 +246,14 @@ func (am *DefaultAccountManager) DeleteGroups(ctx context.Context, accountID, us err = am.Store.ExecuteInTransaction(ctx, func(transaction Store) error { for _, groupID := range groupIDs { - group, err := transaction.GetGroupByID(ctx, LockingStrengthShare, accountID, groupID) + group, err := transaction.GetGroupByID(ctx, LockingStrengthUpdate, accountID, groupID) if err != nil { + allErrors = errors.Join(allErrors, err) continue } if err := validateDeleteGroup(ctx, transaction, group, userID); err != nil { - allErrors = errors.Join(allErrors, fmt.Errorf("failed to delete group %s: %w", groupID, err)) + allErrors = errors.Join(allErrors, err) continue } @@ -318,12 +280,15 @@ func (am *DefaultAccountManager) DeleteGroups(ctx context.Context, accountID, us // GroupAddPeer appends peer to the group func (am *DefaultAccountManager) GroupAddPeer(ctx context.Context, accountID, groupID, peerID string) error { + unlock := am.Store.AcquireWriteLockByUID(ctx, accountID) + defer unlock() + var group *nbgroup.Group var updateAccountPeers bool var err error err = am.Store.ExecuteInTransaction(ctx, func(transaction Store) error { - group, err = transaction.GetGroupByID(context.Background(), LockingStrengthShare, accountID, groupID) + group, err = transaction.GetGroupByID(context.Background(), LockingStrengthUpdate, accountID, groupID) if err != nil { return err } @@ -356,12 +321,15 @@ func (am *DefaultAccountManager) GroupAddPeer(ctx context.Context, accountID, gr // GroupDeletePeer removes peer from the group func (am *DefaultAccountManager) GroupDeletePeer(ctx context.Context, accountID, groupID, peerID string) error { + unlock := am.Store.AcquireWriteLockByUID(ctx, accountID) + defer unlock() + var group *nbgroup.Group var updateAccountPeers bool var err error err = am.Store.ExecuteInTransaction(ctx, func(transaction Store) error { - group, err = transaction.GetGroupByID(context.Background(), LockingStrengthShare, accountID, groupID) + group, err = transaction.GetGroupByID(context.Background(), LockingStrengthUpdate, accountID, groupID) if err != nil { return err } @@ -430,13 +398,17 @@ func validateDeleteGroup(ctx context.Context, transaction Store, group *nbgroup. if group.Issued == nbgroup.GroupIssuedIntegration { executingUser, err := transaction.GetUserByUserID(ctx, LockingStrengthShare, userID) if err != nil { - return status.Errorf(status.NotFound, "user not found") + return err } if executingUser.Role != UserRoleAdmin || !executingUser.IsServiceUser { return status.Errorf(status.PermissionDenied, "only service users with admin power can delete integration group") } } + if group.IsGroupAll() { + return status.Errorf(status.InvalidArgument, "deleting group ALL is not allowed") + } + if isLinked, linkedRoute := isGroupLinkedToRoute(ctx, transaction, group.AccountID, group.ID); isLinked { return &GroupLinkError{"route", string(linkedRoute.NetID)} } diff --git a/management/server/sql_store.go b/management/server/sql_store.go index 7c741d35c..0ebda6440 100644 --- a/management/server/sql_store.go +++ b/management/server/sql_store.go @@ -1278,7 +1278,7 @@ func (s *SqlStore) DeleteGroups(ctx context.Context, strength LockingStrength, a Delete(&nbgroup.Group{}, accountAndIDsQueryCondition, accountID, groupIDs) if result.Error != nil { log.WithContext(ctx).Errorf("failed to delete groups from store: %v", result.Error) - return status.Errorf(status.Internal, "failed to delete groups from store: %v", result.Error) + return status.Errorf(status.Internal, "failed to delete groups from store") } return nil From a4d905ffe77881b682a4798d5564b89860404a0a Mon Sep 17 00:00:00 2001 From: bcmmbaga Date: Wed, 13 Nov 2024 16:56:22 +0300 Subject: [PATCH 03/39] Fix tests Signed-off-by: bcmmbaga --- management/server/group_test.go | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/management/server/group_test.go b/management/server/group_test.go index 89184e819..59094a23e 100644 --- a/management/server/group_test.go +++ b/management/server/group_test.go @@ -208,7 +208,7 @@ func TestDefaultAccountManager_DeleteGroups(t *testing.T) { { name: "delete non-existent group", groupIDs: []string{"non-existent-group"}, - expectedDeleted: []string{"non-existent-group"}, + expectedReasons: []string{"group: non-existent-group not found"}, }, { name: "delete multiple groups with mixed results", From b48afd92fdc2080cd0b4bf7bd4c046684b76338a Mon Sep 17 00:00:00 2001 From: Zoltan Papp Date: Wed, 13 Nov 2024 15:02:51 +0100 Subject: [PATCH 04/39] [relay-server] Always close ws conn when work thread exit (#2879) Close ws conn when work thread exit --- relay/server/peer.go | 17 ++++++++++++----- 1 file changed, 12 insertions(+), 5 deletions(-) diff --git a/relay/server/peer.go b/relay/server/peer.go index c909c35d5..f65fb786a 100644 --- a/relay/server/peer.go +++ b/relay/server/peer.go @@ -16,6 +16,8 @@ import ( const ( bufferSize = 8820 + + errCloseConn = "failed to close connection to peer: %s" ) // Peer represents a peer connection @@ -46,6 +48,12 @@ func NewPeer(metrics *metrics.Metrics, id []byte, conn net.Conn, store *Store) * // It manages the protocol (healthcheck, transport, close). Read the message and determine the message type and handle // the message accordingly. func (p *Peer) Work() { + defer func() { + if err := p.conn.Close(); err != nil && !errors.Is(err, net.ErrClosed) { + p.log.Errorf(errCloseConn, err) + } + }() + ctx, cancel := context.WithCancel(context.Background()) defer cancel() @@ -97,7 +105,7 @@ func (p *Peer) handleMsgType(ctx context.Context, msgType messages.MsgType, hc * case messages.MsgTypeClose: p.log.Infof("peer exited gracefully") if err := p.conn.Close(); err != nil { - log.Errorf("failed to close connection to peer: %s", err) + log.Errorf(errCloseConn, err) } default: p.log.Warnf("received unexpected message type: %s", msgType) @@ -121,9 +129,8 @@ func (p *Peer) CloseGracefully(ctx context.Context) { p.log.Errorf("failed to send close message to peer: %s", p.String()) } - err = p.conn.Close() - if err != nil { - p.log.Errorf("failed to close connection to peer: %s", err) + if err := p.conn.Close(); err != nil { + p.log.Errorf(errCloseConn, err) } } @@ -132,7 +139,7 @@ func (p *Peer) Close() { defer p.connMu.Unlock() if err := p.conn.Close(); err != nil { - p.log.Errorf("failed to close connection to peer: %s", err) + p.log.Errorf(errCloseConn, err) } } From 6886691213aa622b1f0f7440c83a3a4c4831cf3d Mon Sep 17 00:00:00 2001 From: Maycon Santos Date: Wed, 13 Nov 2024 15:21:33 +0100 Subject: [PATCH 05/39] Update route calculation tests (#2884) - Add two new test cases for p2p and relay routes with same latency - Add extra statuses generation --- client/internal/routemanager/client_test.go | 98 +++++++++++++++++++++ 1 file changed, 98 insertions(+) diff --git a/client/internal/routemanager/client_test.go b/client/internal/routemanager/client_test.go index 583156e4d..56fcf1613 100644 --- a/client/internal/routemanager/client_test.go +++ b/client/internal/routemanager/client_test.go @@ -1,6 +1,7 @@ package routemanager import ( + "fmt" "net/netip" "testing" "time" @@ -227,6 +228,64 @@ func TestGetBestrouteFromStatuses(t *testing.T) { currentRoute: "route1", expectedRouteID: "route1", }, + { + name: "relayed routes with latency 0 should maintain previous choice", + statuses: map[route.ID]routerPeerStatus{ + "route1": { + connected: true, + relayed: true, + latency: 0 * time.Millisecond, + }, + "route2": { + connected: true, + relayed: true, + latency: 0 * time.Millisecond, + }, + }, + existingRoutes: map[route.ID]*route.Route{ + "route1": { + ID: "route1", + Metric: route.MaxMetric, + Peer: "peer1", + }, + "route2": { + ID: "route2", + Metric: route.MaxMetric, + Peer: "peer2", + }, + }, + currentRoute: "route1", + expectedRouteID: "route1", + }, + { + name: "p2p routes with latency 0 should maintain previous choice", + statuses: map[route.ID]routerPeerStatus{ + "route1": { + connected: true, + relayed: false, + latency: 0 * time.Millisecond, + }, + "route2": { + connected: true, + relayed: false, + latency: 0 * time.Millisecond, + }, + }, + existingRoutes: map[route.ID]*route.Route{ + "route1": { + ID: "route1", + Metric: route.MaxMetric, + Peer: "peer1", + }, + "route2": { + ID: "route2", + Metric: route.MaxMetric, + Peer: "peer2", + }, + }, + currentRoute: "route1", + expectedRouteID: "route1", + }, { name: "current route with bad score should be changed to route with better score", statuses: map[route.ID]routerPeerStatus{ @@ -287,6 +346,45 @@ func TestGetBestrouteFromStatuses(t *testing.T) { }, } + // fill the test data with random routes + for _, tc := range testCases { + for i := 0; i < 50; i++ { + dummyRoute := &route.Route{ + ID: route.ID(fmt.Sprintf("dummy_p1_%d", i)), + Metric: route.MinMetric, + Peer: fmt.Sprintf("dummy_p1_%d", i), + } + tc.existingRoutes[dummyRoute.ID] = dummyRoute + } + for i := 0; i < 50; i++ { + dummyRoute := &route.Route{ + ID: route.ID(fmt.Sprintf("dummy_p2_%d", i)), + Metric: route.MinMetric, + Peer: fmt.Sprintf("dummy_p1_%d", i), + } + tc.existingRoutes[dummyRoute.ID] = dummyRoute + } + + for i := 0; i < 50; i++ { + id := route.ID(fmt.Sprintf("dummy_p1_%d", i)) + dummyStatus := routerPeerStatus{ + connected: false, + relayed: true, + latency: 0, + } + tc.statuses[id] = dummyStatus + } + for i := 0; i < 50; i++ { + id := route.ID(fmt.Sprintf("dummy_p2_%d", i)) + dummyStatus := routerPeerStatus{ + connected: false, + relayed: true, + latency: 0, + } + tc.statuses[id] = dummyStatus + } + } + for _, tc := range testCases { t.Run(tc.name, func(t *testing.T) { currentRoute := &route.Route{ From be78efbd429c0b7180efe36fe4a826436e5a675b Mon Sep 17 00:00:00 2001 From: Viktor Liu <17948409+lixmal@users.noreply.github.com> Date: Thu, 14 Nov 2024 20:15:16 +0100 Subject: [PATCH 06/39] [client] Handle panic on nil wg interface (#2891) --- client/internal/engine.go | 7 +++++-- 1 file changed, 5 insertions(+), 2 deletions(-) diff --git a/client/internal/engine.go b/client/internal/engine.go index 190d795cd..0f3a5d28a 100644 --- a/client/internal/engine.go +++ b/client/internal/engine.go @@ -38,7 +38,6 @@ import ( "github.com/netbirdio/netbird/client/internal/routemanager/systemops" "github.com/netbirdio/netbird/client/internal/statemanager" - nbssh "github.com/netbirdio/netbird/client/ssh" "github.com/netbirdio/netbird/client/system" nbdns "github.com/netbirdio/netbird/dns" @@ -171,7 +170,7 @@ type Engine struct { relayManager *relayClient.Manager stateManager *statemanager.Manager - srWatcher *guard.SRWatcher + srWatcher *guard.SRWatcher } // Peer is an instance of the Connection Peer @@ -641,6 +640,10 @@ func (e *Engine) updateSSH(sshConf *mgmProto.SSHConfig) error { } func (e *Engine) updateConfig(conf *mgmProto.PeerConfig) error { + if e.wgInterface == nil { + return errors.New("wireguard interface is not initialized") + } + if e.wgInterface.Address().String() != conf.Address { oldAddr := e.wgInterface.Address().String() log.Debugf("updating peer address from %s to %s", oldAddr, conf.Address) From 44e799c687ed1fd5e6a658aff3a06ac1594cec69 Mon Sep 17 00:00:00 2001 From: Pascal Fischer <32096965+pascal-fischer@users.noreply.github.com> Date: Fri, 15 Nov 2024 11:16:16 +0100 Subject: [PATCH 07/39] [management] Fix limited peer view groups (#2894) --- management/server/group.go | 12 ++++-------- management/server/http/peers_handler.go | 20 ++++++++++++++++---- 2 files changed, 20 insertions(+), 12 deletions(-) diff --git a/management/server/group.go b/management/server/group.go index b2ec88cc0..7b4f07948 100644 --- a/management/server/group.go +++ b/management/server/group.go @@ -6,11 +6,12 @@ import ( "fmt" "slices" - nbdns "github.com/netbirdio/netbird/dns" - "github.com/netbirdio/netbird/route" "github.com/rs/xid" log "github.com/sirupsen/logrus" + nbdns "github.com/netbirdio/netbird/dns" + "github.com/netbirdio/netbird/route" + "github.com/netbirdio/netbird/management/server/activity" nbgroup "github.com/netbirdio/netbird/management/server/group" "github.com/netbirdio/netbird/management/server/status" @@ -27,17 +28,12 @@ func (e *GroupLinkError) Error() string { // CheckGroupPermissions validates if a user has the necessary permissions to view groups func (am *DefaultAccountManager) CheckGroupPermissions(ctx context.Context, accountID, userID string) error { - settings, err := am.Store.GetAccountSettings(ctx, LockingStrengthShare, accountID) - if err != nil { - return err - } - user, err := am.Store.GetUserByUserID(ctx, LockingStrengthShare, userID) if err != nil { return err } - if (!user.IsAdminOrServiceUser() && settings.RegularUsersViewBlocked) || user.AccountID != accountID { + if !user.IsAdminOrServiceUser() || user.AccountID != accountID { return status.Errorf(status.PermissionDenied, "groups are blocked for users") } diff --git a/management/server/http/peers_handler.go b/management/server/http/peers_handler.go index a5856a0e4..f5027cd77 100644 --- a/management/server/http/peers_handler.go +++ b/management/server/http/peers_handler.go @@ -184,14 +184,26 @@ func (h *PeersHandler) GetAllPeers(w http.ResponseWriter, r *http.Request) { dnsDomain := h.accountManager.GetDNSDomain() - respBody := make([]*api.PeerBatch, 0, len(account.Peers)) - for _, peer := range account.Peers { + peers, err := h.accountManager.GetPeers(r.Context(), accountID, userID) + if err != nil { + util.WriteError(r.Context(), err, w) + return + } + + groupsMap := map[string]*nbgroup.Group{} + groups, _ := h.accountManager.GetAllGroups(r.Context(), accountID, userID) + for _, group := range groups { + groupsMap[group.ID] = group + } + + respBody := make([]*api.PeerBatch, 0, len(peers)) + for _, peer := range peers { peerToReturn, err := h.checkPeerStatus(peer) if err != nil { util.WriteError(r.Context(), err, w) return } - groupMinimumInfo := toGroupsInfo(account.Groups, peer.ID) + groupMinimumInfo := toGroupsInfo(groupsMap, peer.ID) respBody = append(respBody, toPeerListItemResponse(peerToReturn, groupMinimumInfo, dnsDomain, 0)) } @@ -304,7 +316,7 @@ func peerToAccessiblePeer(peer *nbpeer.Peer, dnsDomain string) api.AccessiblePee } func toGroupsInfo(groups map[string]*nbgroup.Group, peerID string) []api.GroupMinimum { - var groupsInfo []api.GroupMinimum + groupsInfo := []api.GroupMinimum{} groupsChecked := make(map[string]struct{}) for _, group := range groups { _, ok := groupsChecked[group.ID] From 4ef3890bf757f149da7bce3588f62bfbf85b9d8c Mon Sep 17 00:00:00 2001 From: bcmmbaga Date: Fri, 15 Nov 2024 17:48:00 +0300 Subject: [PATCH 08/39] Fix typo Signed-off-by: bcmmbaga --- management/server/dns.go | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/management/server/dns.go b/management/server/dns.go index be7caea4e..8df211b0b 100644 --- a/management/server/dns.go +++ b/management/server/dns.go @@ -161,7 +161,7 @@ func (am *DefaultAccountManager) SaveDNSSettings(ctx context.Context, accountID return nil } -// prepareGroupEvents prepares a list of event functions to be stored. +// prepareDNSSettingsEvents prepares a list of event functions to be stored. func (am *DefaultAccountManager) prepareDNSSettingsEvents(ctx context.Context, transaction Store, accountID, userID string, addedGroups, removedGroups []string) []func() { var eventsToStore []func() From 4aee3c9e33d6c733d38ac2b6e6287ea78e5ac591 Mon Sep 17 00:00:00 2001 From: Pascal Fischer <32096965+pascal-fischer@users.noreply.github.com> Date: Fri, 15 Nov 2024 16:59:03 +0100 Subject: [PATCH 09/39] [client/management] add peer lock to peer meta update and fix isEqual func (#2840) --- client/internal/engine.go | 12 +++++ client/internal/engine_test.go | 93 ++++++++++++++++++++++++++++++++++ management/server/account.go | 5 +- management/server/peer.go | 3 ++ 4 files changed, 112 insertions(+), 1 deletion(-) diff --git a/client/internal/engine.go b/client/internal/engine.go index 0f3a5d28a..d4a3a561a 100644 --- a/client/internal/engine.go +++ b/client/internal/engine.go @@ -11,6 +11,7 @@ import ( "reflect" "runtime" "slices" + "sort" "strings" "sync" "sync/atomic" @@ -1484,6 +1485,17 @@ func (e *Engine) stopDNSServer() { // isChecksEqual checks if two slices of checks are equal. func isChecksEqual(checks []*mgmProto.Checks, oChecks []*mgmProto.Checks) bool { + for _, check := range checks { + sort.Slice(check.Files, func(i, j int) bool { + return check.Files[i] < check.Files[j] + }) + } + for _, oCheck := range oChecks { + sort.Slice(oCheck.Files, func(i, j int) bool { + return oCheck.Files[i] < oCheck.Files[j] + }) + } + return slices.EqualFunc(checks, oChecks, func(checks, oChecks *mgmProto.Checks) bool { return slices.Equal(checks.Files, oChecks.Files) }) diff --git a/client/internal/engine_test.go b/client/internal/engine_test.go index 0018af6df..b6c6186ea 100644 --- a/client/internal/engine_test.go +++ b/client/internal/engine_test.go @@ -1006,6 +1006,99 @@ func Test_ParseNATExternalIPMappings(t *testing.T) { } } +func Test_CheckFilesEqual(t *testing.T) { + testCases := []struct { + name string + inputChecks1 []*mgmtProto.Checks + inputChecks2 []*mgmtProto.Checks + expectedBool bool + }{ + { + name: "Equal Files In Equal Order Should Return True", + inputChecks1: []*mgmtProto.Checks{ + { + Files: []string{ + "testfile1", + "testfile2", + }, + }, + }, + inputChecks2: []*mgmtProto.Checks{ + { + Files: []string{ + "testfile1", + "testfile2", + }, + }, + }, + expectedBool: true, + }, + { + name: "Equal Files In Reverse Order Should Return True", + inputChecks1: []*mgmtProto.Checks{ + { + Files: []string{ + "testfile1", + "testfile2", + }, + }, + }, + inputChecks2: []*mgmtProto.Checks{ + { + Files: []string{ + "testfile2", + "testfile1", + }, + }, + }, + expectedBool: true, + }, + { + name: "Unequal Files Should Return False", + inputChecks1: []*mgmtProto.Checks{ + { + Files: []string{ + "testfile1", + "testfile2", + }, + }, + }, + inputChecks2: []*mgmtProto.Checks{ + { + Files: []string{ + "testfile1", + "testfile3", + }, + }, + }, + expectedBool: false, + }, + { + name: "Compared With Empty Should Return False", + inputChecks1: []*mgmtProto.Checks{ + { + Files: []string{ + "testfile1", + "testfile2", + }, + }, + }, + inputChecks2: []*mgmtProto.Checks{ + { + Files: []string{}, + }, + }, + expectedBool: false, + }, + } + for _, testCase := range testCases { + t.Run(testCase.name, func(t *testing.T) { + result := isChecksEqual(testCase.inputChecks1, testCase.inputChecks2) + assert.Equal(t, testCase.expectedBool, result, "result should match expected bool") + }) + } +} + func createEngine(ctx context.Context, cancel context.CancelFunc, setupKey string, i int, mgmtAddr string, signalAddr string) (*Engine, error) { key, err := wgtypes.GeneratePrivateKey() if err != nil { diff --git a/management/server/account.go b/management/server/account.go index bf6039229..4afadb4e9 100644 --- a/management/server/account.go +++ b/management/server/account.go @@ -2319,7 +2319,7 @@ func (am *DefaultAccountManager) OnPeerDisconnected(ctx context.Context, account err = am.MarkPeerConnected(ctx, peerPubKey, false, nil, account) if err != nil { - log.WithContext(ctx).Warnf("failed marking peer as connected %s %v", peerPubKey, err) + log.WithContext(ctx).Warnf("failed marking peer as disconnected %s %v", peerPubKey, err) } return nil @@ -2335,6 +2335,9 @@ func (am *DefaultAccountManager) SyncPeerMeta(ctx context.Context, peerPubKey st unlock := am.Store.AcquireReadLockByUID(ctx, accountID) defer unlock() + unlockPeer := am.Store.AcquireWriteLockByUID(ctx, peerPubKey) + defer unlockPeer() + account, err := am.Store.GetAccount(ctx, accountID) if err != nil { return err diff --git a/management/server/peer.go b/management/server/peer.go index 9784650de..87b5b0e4e 100644 --- a/management/server/peer.go +++ b/management/server/peer.go @@ -168,6 +168,8 @@ func (am *DefaultAccountManager) updatePeerStatusAndLocation(ctx context.Context account.UpdatePeer(peer) + log.WithContext(ctx).Tracef("saving peer status for peer %s is connected: %t", peer.ID, connected) + err := am.Store.SavePeerStatus(account.Id, peer.ID, *newStatus) if err != nil { return false, fmt.Errorf("failed to save peer status: %w", err) @@ -657,6 +659,7 @@ func (am *DefaultAccountManager) SyncPeer(ctx context.Context, sync PeerSync, ac updated := peer.UpdateMetaIfNew(sync.Meta) if updated { + log.WithContext(ctx).Tracef("peer %s metadata updated", peer.ID) err = am.Store.SavePeer(ctx, account.Id, peer) if err != nil { return nil, nil, nil, fmt.Errorf("failed to save peer: %w", err) From d9b691b8a56bb8bdd11f02364e0dbf59ae0580bc Mon Sep 17 00:00:00 2001 From: Pascal Fischer <32096965+pascal-fischer@users.noreply.github.com> Date: Fri, 15 Nov 2024 17:00:06 +0100 Subject: [PATCH 10/39] [management] Limit the setup-key update operation (#2841) --- management/server/http/api/openapi.yml | 25 --------------- management/server/http/api/types.gen.go | 15 --------- management/server/http/setupkeys_handler.go | 6 ---- management/server/setupkey.go | 12 ++++--- management/server/setupkey_test.go | 35 ++++++++++++++++++--- 5 files changed, 38 insertions(+), 55 deletions(-) diff --git a/management/server/http/api/openapi.yml b/management/server/http/api/openapi.yml index 9b4592ccf..bfb375277 100644 --- a/management/server/http/api/openapi.yml +++ b/management/server/http/api/openapi.yml @@ -521,19 +521,6 @@ components: SetupKeyRequest: type: object properties: - name: - description: Setup Key name - type: string - example: Default key - type: - description: Setup key type, one-off for single time usage and reusable - type: string - example: reusable - expires_in: - description: Expiration time in seconds, 0 will mean the key never expires - type: integer - minimum: 0 - example: 86400 revoked: description: Setup key revocation status type: boolean @@ -544,21 +531,9 @@ components: items: type: string example: "ch8i4ug6lnn4g9hqv7m0" - usage_limit: - description: A number of times this key can be used. The value of 0 indicates the unlimited usage. - type: integer - example: 0 - ephemeral: - description: Indicate that the peer will be ephemeral or not - type: boolean - example: true required: - - name - - type - - expires_in - revoked - auto_groups - - usage_limit CreateSetupKeyRequest: type: object properties: diff --git a/management/server/http/api/types.gen.go b/management/server/http/api/types.gen.go index c1ef1ba21..f219c4574 100644 --- a/management/server/http/api/types.gen.go +++ b/management/server/http/api/types.gen.go @@ -1098,23 +1098,8 @@ type SetupKeyRequest struct { // AutoGroups List of group IDs to auto-assign to peers registered with this key AutoGroups []string `json:"auto_groups"` - // Ephemeral Indicate that the peer will be ephemeral or not - Ephemeral *bool `json:"ephemeral,omitempty"` - - // ExpiresIn Expiration time in seconds, 0 will mean the key never expires - ExpiresIn int `json:"expires_in"` - - // Name Setup Key name - Name string `json:"name"` - // Revoked Setup key revocation status Revoked bool `json:"revoked"` - - // Type Setup key type, one-off for single time usage and reusable - Type string `json:"type"` - - // UsageLimit A number of times this key can be used. The value of 0 indicates the unlimited usage. - UsageLimit int `json:"usage_limit"` } // User defines model for User. diff --git a/management/server/http/setupkeys_handler.go b/management/server/http/setupkeys_handler.go index 31859f59b..9ba5977bb 100644 --- a/management/server/http/setupkeys_handler.go +++ b/management/server/http/setupkeys_handler.go @@ -137,11 +137,6 @@ func (h *SetupKeysHandler) UpdateSetupKey(w http.ResponseWriter, r *http.Request return } - if req.Name == "" { - util.WriteError(r.Context(), status.Errorf(status.InvalidArgument, "setup key name field is invalid: %s", req.Name), w) - return - } - if req.AutoGroups == nil { util.WriteError(r.Context(), status.Errorf(status.InvalidArgument, "setup key AutoGroups field is invalid"), w) return @@ -150,7 +145,6 @@ func (h *SetupKeysHandler) UpdateSetupKey(w http.ResponseWriter, r *http.Request newKey := &server.SetupKey{} newKey.AutoGroups = req.AutoGroups newKey.Revoked = req.Revoked - newKey.Name = req.Name newKey.Id = keyID newKey, err = h.accountManager.SaveSetupKey(r.Context(), accountID, newKey, userID) diff --git a/management/server/setupkey.go b/management/server/setupkey.go index 554c66ba4..960532bf9 100644 --- a/management/server/setupkey.go +++ b/management/server/setupkey.go @@ -12,9 +12,10 @@ import ( "unicode/utf8" "github.com/google/uuid" + log "github.com/sirupsen/logrus" + "github.com/netbirdio/netbird/management/server/activity" "github.com/netbirdio/netbird/management/server/status" - log "github.com/sirupsen/logrus" ) const ( @@ -276,7 +277,7 @@ func (am *DefaultAccountManager) CreateSetupKey(ctx context.Context, accountID s // SaveSetupKey saves the provided SetupKey to the database overriding the existing one. // Due to the unique nature of a SetupKey certain properties must not be overwritten // (e.g. the key itself, creation date, ID, etc). -// These properties are overwritten: Name, AutoGroups, Revoked. The rest is copied from the existing key. +// These properties are overwritten: AutoGroups, Revoked (only from false to true), and the UpdatedAt. The rest is copied from the existing key. func (am *DefaultAccountManager) SaveSetupKey(ctx context.Context, accountID string, keyToSave *SetupKey, userID string) (*SetupKey, error) { if keyToSave == nil { return nil, status.Errorf(status.InvalidArgument, "provided setup key to update is nil") @@ -312,9 +313,12 @@ func (am *DefaultAccountManager) SaveSetupKey(ctx context.Context, accountID str return err } - // only auto groups, revoked status, and name can be updated for now + if oldKey.Revoked && !keyToSave.Revoked { + return status.Errorf(status.InvalidArgument, "can't un-revoke a revoked setup key") + } + + // only auto groups, revoked status (from false to true) can be updated newKey = oldKey.Copy() - newKey.Name = keyToSave.Name newKey.AutoGroups = keyToSave.AutoGroups newKey.Revoked = keyToSave.Revoked newKey.UpdatedAt = time.Now().UTC() diff --git a/management/server/setupkey_test.go b/management/server/setupkey_test.go index 2ed8aef95..94ed022fa 100644 --- a/management/server/setupkey_test.go +++ b/management/server/setupkey_test.go @@ -56,11 +56,9 @@ func TestDefaultAccountManager_SaveSetupKey(t *testing.T) { } autoGroups := []string{"group_1", "group_2"} - newKeyName := "my-new-test-key" revoked := true newKey, err := manager.SaveSetupKey(context.Background(), account.Id, &SetupKey{ Id: key.Id, - Name: newKeyName, Revoked: revoked, AutoGroups: autoGroups, }, userID) @@ -68,7 +66,7 @@ func TestDefaultAccountManager_SaveSetupKey(t *testing.T) { t.Fatal(err) } - assertKey(t, newKey, newKeyName, revoked, "reusable", 0, key.CreatedAt, key.ExpiresAt, + assertKey(t, newKey, keyName, revoked, "reusable", 0, key.CreatedAt, key.ExpiresAt, key.Id, time.Now().UTC(), autoGroups, true) // check the corresponding events that should have been generated @@ -76,7 +74,7 @@ func TestDefaultAccountManager_SaveSetupKey(t *testing.T) { assert.NotNil(t, ev) assert.Equal(t, account.Id, ev.AccountID) - assert.Equal(t, newKeyName, ev.Meta["name"]) + assert.Equal(t, keyName, ev.Meta["name"]) assert.Equal(t, fmt.Sprint(key.Type), fmt.Sprint(ev.Meta["type"])) assert.NotEmpty(t, ev.Meta["key"]) assert.Equal(t, userID, ev.InitiatorID) @@ -89,7 +87,6 @@ func TestDefaultAccountManager_SaveSetupKey(t *testing.T) { autoGroups = append(autoGroups, groupAll.ID) _, err = manager.SaveSetupKey(context.Background(), account.Id, &SetupKey{ Id: key.Id, - Name: newKeyName, Revoked: revoked, AutoGroups: autoGroups, }, userID) @@ -449,3 +446,31 @@ func TestSetupKeyAccountPeersUpdate(t *testing.T) { } }) } + +func TestDefaultAccountManager_CreateSetupKey_ShouldNotAllowToUpdateRevokedKey(t *testing.T) { + manager, err := createManager(t) + if err != nil { + t.Fatal(err) + } + + userID := "testingUser" + account, err := manager.GetOrCreateAccountByUser(context.Background(), userID, "") + if err != nil { + t.Fatal(err) + } + + key, err := manager.CreateSetupKey(context.Background(), account.Id, "testName", SetupKeyReusable, time.Hour, nil, SetupKeyUnlimitedUsage, userID, false) + assert.NoError(t, err) + + // revoke the key + updateKey := key.Copy() + updateKey.Revoked = true + _, err = manager.SaveSetupKey(context.Background(), account.Id, updateKey, userID) + assert.NoError(t, err) + + // re-activate revoked key + updateKey.Revoked = false + _, err = manager.SaveSetupKey(context.Background(), account.Id, updateKey, userID) + assert.Error(t, err, "should not allow to update revoked key") + +} From 51c1ec283cb9d9dacc9ec18ab6d98b64d954d362 Mon Sep 17 00:00:00 2001 From: bcmmbaga Date: Fri, 15 Nov 2024 19:34:57 +0300 Subject: [PATCH 11/39] Add locks and remove log Signed-off-by: bcmmbaga --- management/server/posture_checks.go | 8 ++++++-- 1 file changed, 6 insertions(+), 2 deletions(-) diff --git a/management/server/posture_checks.go b/management/server/posture_checks.go index d7b5a79a2..59e726c41 100644 --- a/management/server/posture_checks.go +++ b/management/server/posture_checks.go @@ -9,7 +9,6 @@ import ( "github.com/netbirdio/netbird/management/server/posture" "github.com/netbirdio/netbird/management/server/status" "github.com/rs/xid" - log "github.com/sirupsen/logrus" "golang.org/x/exp/maps" ) @@ -32,6 +31,9 @@ func (am *DefaultAccountManager) GetPostureChecks(ctx context.Context, accountID // SavePostureChecks saves a posture check. func (am *DefaultAccountManager) SavePostureChecks(ctx context.Context, accountID, userID string, postureChecks *posture.Checks) (*posture.Checks, error) { + unlock := am.Store.AcquireWriteLockByUID(ctx, accountID) + defer unlock() + user, err := am.Store.GetUserByUserID(ctx, LockingStrengthShare, userID) if err != nil { return nil, err @@ -85,6 +87,9 @@ func (am *DefaultAccountManager) SavePostureChecks(ctx context.Context, accountI // DeletePostureChecks deletes a posture check by ID. func (am *DefaultAccountManager) DeletePostureChecks(ctx context.Context, accountID, postureChecksID, userID string) error { + unlock := am.Store.AcquireWriteLockByUID(ctx, accountID) + defer unlock() + user, err := am.Store.GetUserByUserID(ctx, LockingStrengthShare, userID) if err != nil { return err @@ -267,7 +272,6 @@ func isPeerInPolicySourceGroups(ctx context.Context, transaction Store, accountI for _, sourceGroup := range rule.Sources { group, err := transaction.GetGroupByID(ctx, LockingStrengthShare, accountID, sourceGroup) if err != nil { - log.WithContext(ctx).Debugf("failed to check peer in policy source group: %v", err) return false, fmt.Errorf("failed to check peer in policy source group: %w", err) } From 12f442439a64a4b76e8b9c7e93c0aa5de74d213c Mon Sep 17 00:00:00 2001 From: Bethuel Mmbaga Date: Fri, 15 Nov 2024 20:09:32 +0300 Subject: [PATCH 12/39] [management] Refactor group to use store methods (#2867) * Refactor setup key handling to use store methods Signed-off-by: bcmmbaga * add lock to get account groups Signed-off-by: bcmmbaga * add check for regular user Signed-off-by: bcmmbaga * get only required groups for auto-group validation Signed-off-by: bcmmbaga * add account lock and return auto groups map on validation Signed-off-by: bcmmbaga * refactor account peers update Signed-off-by: bcmmbaga * Refactor groups to use store methods Signed-off-by: bcmmbaga * refactor GetGroupByID and add NewGroupNotFoundError Signed-off-by: bcmmbaga * fix tests Signed-off-by: bcmmbaga * Add AddPeer and RemovePeer methods to Group struct Signed-off-by: bcmmbaga * Preserve store engine in SqlStore transactions Signed-off-by: bcmmbaga * Run groups ops in transaction Signed-off-by: bcmmbaga * fix missing group removed from setup key activity Signed-off-by: bcmmbaga * fix merge Signed-off-by: bcmmbaga * fix merge Signed-off-by: bcmmbaga * fix sonar Signed-off-by: bcmmbaga * Change setup key log level to debug for missing group Signed-off-by: bcmmbaga * Retrieve modified peers once for group events Signed-off-by: bcmmbaga * Add tests Signed-off-by: bcmmbaga * Add account locking and merge group deletion methods Signed-off-by: bcmmbaga * Fix tests Signed-off-by: bcmmbaga --------- Signed-off-by: bcmmbaga --- management/server/account.go | 29 +- management/server/account_test.go | 12 +- management/server/dns.go | 2 +- management/server/group.go | 514 ++++++++++-------- management/server/group/group.go | 27 + management/server/group/group_test.go | 90 +++ management/server/group_test.go | 2 +- management/server/integrated_validator.go | 27 +- management/server/mock_server/account_mock.go | 9 - management/server/nameserver.go | 6 +- management/server/peer.go | 41 +- management/server/peer_test.go | 2 +- management/server/policy.go | 4 +- management/server/posture_checks.go | 2 +- management/server/route.go | 6 +- management/server/route_test.go | 2 +- management/server/setupkey.go | 6 +- management/server/sql_store.go | 120 +++- management/server/sql_store_test.go | 285 +++++++++- management/server/status/error.go | 11 +- management/server/store.go | 8 +- management/server/user.go | 8 +- 22 files changed, 878 insertions(+), 335 deletions(-) create mode 100644 management/server/group/group_test.go diff --git a/management/server/account.go b/management/server/account.go index 4afadb4e9..1bd8a99a9 100644 --- a/management/server/account.go +++ b/management/server/account.go @@ -110,7 +110,6 @@ type AccountManager interface { SaveGroups(ctx context.Context, accountID, userID string, newGroups []*nbgroup.Group) error DeleteGroup(ctx context.Context, accountId, userId, groupID string) error DeleteGroups(ctx context.Context, accountId, userId string, groupIDs []string) error - ListGroups(ctx context.Context, accountId string) ([]*nbgroup.Group, error) GroupAddPeer(ctx context.Context, accountId, groupID, peerID string) error GroupDeletePeer(ctx context.Context, accountId, groupID, peerID string) error GetPolicy(ctx context.Context, accountID, policyID, userID string) (*Policy, error) @@ -1435,7 +1434,7 @@ func isNil(i idp.Manager) bool { // addAccountIDToIDPAppMeta update user's app metadata in idp manager func (am *DefaultAccountManager) addAccountIDToIDPAppMeta(ctx context.Context, userID string, accountID string) error { if !isNil(am.idpManager) { - accountUsers, err := am.Store.GetAccountUsers(ctx, accountID) + accountUsers, err := am.Store.GetAccountUsers(ctx, LockingStrengthShare, accountID) if err != nil { return err } @@ -2083,7 +2082,7 @@ func (am *DefaultAccountManager) syncJWTGroups(ctx context.Context, accountID st return fmt.Errorf("error saving groups: %w", err) } - if err = transaction.IncrementNetworkSerial(ctx, accountID); err != nil { + if err = transaction.IncrementNetworkSerial(ctx, LockingStrengthUpdate, accountID); err != nil { return fmt.Errorf("error incrementing network serial: %w", err) } } @@ -2101,7 +2100,7 @@ func (am *DefaultAccountManager) syncJWTGroups(ctx context.Context, accountID st } for _, g := range addNewGroups { - group, err := am.Store.GetGroupByID(ctx, LockingStrengthShare, g, accountID) + group, err := am.Store.GetGroupByID(ctx, LockingStrengthShare, accountID, g) if err != nil { log.WithContext(ctx).Debugf("group %s not found while saving user activity event of account %s", g, accountID) } else { @@ -2114,7 +2113,7 @@ func (am *DefaultAccountManager) syncJWTGroups(ctx context.Context, accountID st } for _, g := range removeOldGroups { - group, err := am.Store.GetGroupByID(ctx, LockingStrengthShare, g, accountID) + group, err := am.Store.GetGroupByID(ctx, LockingStrengthShare, accountID, g) if err != nil { log.WithContext(ctx).Debugf("group %s not found while saving user activity event of account %s", g, accountID) } else { @@ -2127,14 +2126,19 @@ func (am *DefaultAccountManager) syncJWTGroups(ctx context.Context, accountID st } if settings.GroupsPropagationEnabled { - account, err := am.requestBuffer.GetAccountWithBackpressure(ctx, accountID) + removedGroupAffectsPeers, err := areGroupChangesAffectPeers(ctx, am.Store, accountID, removeOldGroups) if err != nil { - return status.NewGetAccountError(err) + return err } - if areGroupChangesAffectPeers(account, addNewGroups) || areGroupChangesAffectPeers(account, removeOldGroups) { + newGroupsAffectsPeers, err := areGroupChangesAffectPeers(ctx, am.Store, accountID, addNewGroups) + if err != nil { + return err + } + + if removedGroupAffectsPeers || newGroupsAffectsPeers { log.WithContext(ctx).Tracef("user %s: JWT group membership changed, updating account peers", claims.UserId) - am.updateAccountPeers(ctx, account) + am.updateAccountPeers(ctx, accountID) } } @@ -2401,12 +2405,7 @@ func (am *DefaultAccountManager) CheckUserAccessByJWTGroups(ctx context.Context, func (am *DefaultAccountManager) onPeersInvalidated(ctx context.Context, accountID string) { log.WithContext(ctx).Debugf("validated peers has been invalidated for account %s", accountID) - updatedAccount, err := am.Store.GetAccount(ctx, accountID) - if err != nil { - log.WithContext(ctx).Errorf("failed to get account %s: %v", accountID, err) - return - } - am.updateAccountPeers(ctx, updatedAccount) + am.updateAccountPeers(ctx, accountID) } func (am *DefaultAccountManager) FindExistingPostureCheck(accountID string, checks *posture.ChecksDefinition) (*posture.Checks, error) { diff --git a/management/server/account_test.go b/management/server/account_test.go index fdf004a3b..97e0d45f0 100644 --- a/management/server/account_test.go +++ b/management/server/account_test.go @@ -1413,11 +1413,13 @@ func TestAccountManager_NetworkUpdates_DeleteGroup(t *testing.T) { updMsg := manager.peersUpdateManager.CreateChannel(context.Background(), peer1.ID) defer manager.peersUpdateManager.CloseChannel(context.Background(), peer1.ID) - group := group.Group{ + err := manager.SaveGroup(context.Background(), account.Id, userID, &group.Group{ ID: "groupA", Name: "GroupA", Peers: []string{peer1.ID, peer2.ID, peer3.ID}, - } + }) + + require.NoError(t, err, "failed to save group") policy := Policy{ Enabled: true, @@ -1460,7 +1462,7 @@ func TestAccountManager_NetworkUpdates_DeleteGroup(t *testing.T) { return } - if err := manager.DeleteGroup(context.Background(), account.Id, "", group.ID); err != nil { + if err := manager.DeleteGroup(context.Background(), account.Id, userID, "groupA"); err != nil { t.Errorf("delete group: %v", err) return } @@ -2714,7 +2716,7 @@ func TestAccount_SetJWTGroups(t *testing.T) { assert.NoError(t, err, "unable to get user") assert.Len(t, user.AutoGroups, 0) - group1, err := manager.Store.GetGroupByID(context.Background(), LockingStrengthShare, "group1", "accountID") + group1, err := manager.Store.GetGroupByID(context.Background(), LockingStrengthShare, "accountID", "group1") assert.NoError(t, err, "unable to get group") assert.Equal(t, group1.Issued, group.GroupIssuedAPI, "group should be api issued") }) @@ -2734,7 +2736,7 @@ func TestAccount_SetJWTGroups(t *testing.T) { assert.NoError(t, err, "unable to get user") assert.Len(t, user.AutoGroups, 1) - group1, err := manager.Store.GetGroupByID(context.Background(), LockingStrengthShare, "group1", "accountID") + group1, err := manager.Store.GetGroupByID(context.Background(), LockingStrengthShare, "accountID", "group1") assert.NoError(t, err, "unable to get group") assert.Equal(t, group1.Issued, group.GroupIssuedAPI, "group should be api issued") }) diff --git a/management/server/dns.go b/management/server/dns.go index 256b8b125..4551be5ab 100644 --- a/management/server/dns.go +++ b/management/server/dns.go @@ -146,7 +146,7 @@ func (am *DefaultAccountManager) SaveDNSSettings(ctx context.Context, accountID } if anyGroupHasPeers(account, addedGroups) || anyGroupHasPeers(account, removedGroups) { - am.updateAccountPeers(ctx, account) + am.updateAccountPeers(ctx, accountID) } return nil diff --git a/management/server/group.go b/management/server/group.go index 7b4f07948..a36213f04 100644 --- a/management/server/group.go +++ b/management/server/group.go @@ -33,8 +33,12 @@ func (am *DefaultAccountManager) CheckGroupPermissions(ctx context.Context, acco return err } - if !user.IsAdminOrServiceUser() || user.AccountID != accountID { - return status.Errorf(status.PermissionDenied, "groups are blocked for users") + if user.AccountID != accountID { + return status.NewUserNotPartOfAccountError() + } + + if user.IsRegularUser() { + return status.NewAdminPermissionError() } return nil @@ -45,8 +49,7 @@ func (am *DefaultAccountManager) GetGroup(ctx context.Context, accountID, groupI if err := am.CheckGroupPermissions(ctx, accountID, userID); err != nil { return nil, err } - - return am.Store.GetGroupByID(ctx, LockingStrengthShare, groupID, accountID) + return am.Store.GetGroupByID(ctx, LockingStrengthShare, accountID, groupID) } // GetAllGroups returns all groups in an account @@ -54,13 +57,12 @@ func (am *DefaultAccountManager) GetAllGroups(ctx context.Context, accountID, us if err := am.CheckGroupPermissions(ctx, accountID, userID); err != nil { return nil, err } - return am.Store.GetAccountGroups(ctx, LockingStrengthShare, accountID) } // GetGroupByName filters all groups in an account by name and returns the one with the most peers func (am *DefaultAccountManager) GetGroupByName(ctx context.Context, groupName, accountID string) (*nbgroup.Group, error) { - return am.Store.GetGroupByName(ctx, LockingStrengthShare, groupName, accountID) + return am.Store.GetGroupByName(ctx, LockingStrengthShare, accountID, groupName) } // SaveGroup object of the peers @@ -73,79 +75,74 @@ func (am *DefaultAccountManager) SaveGroup(ctx context.Context, accountID, userI // SaveGroups adds new groups to the account. // Note: This function does not acquire the global lock. // It is the caller's responsibility to ensure proper locking is in place before invoking this method. -func (am *DefaultAccountManager) SaveGroups(ctx context.Context, accountID, userID string, newGroups []*nbgroup.Group) error { - account, err := am.Store.GetAccount(ctx, accountID) +func (am *DefaultAccountManager) SaveGroups(ctx context.Context, accountID, userID string, groups []*nbgroup.Group) error { + user, err := am.Store.GetUserByUserID(ctx, LockingStrengthShare, userID) if err != nil { return err } + if user.AccountID != accountID { + return status.NewUserNotPartOfAccountError() + } + + if user.IsRegularUser() { + return status.NewAdminPermissionError() + } + var eventsToStore []func() + var groupsToSave []*nbgroup.Group + var updateAccountPeers bool - for _, newGroup := range newGroups { - if newGroup.ID == "" && newGroup.Issued != nbgroup.GroupIssuedAPI { - return status.Errorf(status.InvalidArgument, "%s group without ID set", newGroup.Issued) - } - - if newGroup.ID == "" && newGroup.Issued == nbgroup.GroupIssuedAPI { - existingGroup, err := account.FindGroupByName(newGroup.Name) - if err != nil { - s, ok := status.FromError(err) - if !ok || s.ErrorType != status.NotFound { - return err - } + err = am.Store.ExecuteInTransaction(ctx, func(transaction Store) error { + groupIDs := make([]string, 0, len(groups)) + for _, newGroup := range groups { + if err = validateNewGroup(ctx, transaction, accountID, newGroup); err != nil { + return err } - // Avoid duplicate groups only for the API issued groups. - // Integration or JWT groups can be duplicated as they are coming from the IdP that we don't have control of. - if existingGroup != nil { - return status.Errorf(status.AlreadyExists, "group with name %s already exists", newGroup.Name) - } + newGroup.AccountID = accountID + groupsToSave = append(groupsToSave, newGroup) + groupIDs = append(groupIDs, newGroup.ID) - newGroup.ID = xid.New().String() + events := am.prepareGroupEvents(ctx, transaction, accountID, userID, newGroup) + eventsToStore = append(eventsToStore, events...) } - for _, peerID := range newGroup.Peers { - if account.Peers[peerID] == nil { - return status.Errorf(status.InvalidArgument, "peer with ID \"%s\" not found", peerID) - } + updateAccountPeers, err = areGroupChangesAffectPeers(ctx, transaction, accountID, groupIDs) + if err != nil { + return err } - oldGroup := account.Groups[newGroup.ID] - account.Groups[newGroup.ID] = newGroup + if err = transaction.IncrementNetworkSerial(ctx, LockingStrengthUpdate, accountID); err != nil { + return err + } - events := am.prepareGroupEvents(ctx, userID, accountID, newGroup, oldGroup, account) - eventsToStore = append(eventsToStore, events...) - } - - newGroupIDs := make([]string, 0, len(newGroups)) - for _, newGroup := range newGroups { - newGroupIDs = append(newGroupIDs, newGroup.ID) - } - - account.Network.IncSerial() - if err = am.Store.SaveAccount(ctx, account); err != nil { + return transaction.SaveGroups(ctx, LockingStrengthUpdate, groupsToSave) + }) + if err != nil { return err } - if areGroupChangesAffectPeers(account, newGroupIDs) { - am.updateAccountPeers(ctx, account) - } - for _, storeEvent := range eventsToStore { storeEvent() } + if updateAccountPeers { + am.updateAccountPeers(ctx, accountID) + } + return nil } // prepareGroupEvents prepares a list of event functions to be stored. -func (am *DefaultAccountManager) prepareGroupEvents(ctx context.Context, userID string, accountID string, newGroup, oldGroup *nbgroup.Group, account *Account) []func() { +func (am *DefaultAccountManager) prepareGroupEvents(ctx context.Context, transaction Store, accountID, userID string, newGroup *nbgroup.Group) []func() { var eventsToStore []func() addedPeers := make([]string, 0) removedPeers := make([]string, 0) - if oldGroup != nil { + oldGroup, err := transaction.GetGroupByID(ctx, LockingStrengthShare, accountID, newGroup.ID) + if err == nil && oldGroup != nil { addedPeers = difference(newGroup.Peers, oldGroup.Peers) removedPeers = difference(oldGroup.Peers, newGroup.Peers) } else { @@ -155,35 +152,42 @@ func (am *DefaultAccountManager) prepareGroupEvents(ctx context.Context, userID }) } - for _, p := range addedPeers { - peer := account.Peers[p] - if peer == nil { - log.WithContext(ctx).Errorf("peer %s not found under account %s while saving group", p, accountID) + modifiedPeers := slices.Concat(addedPeers, removedPeers) + peers, err := transaction.GetPeersByIDs(ctx, LockingStrengthShare, accountID, modifiedPeers) + if err != nil { + log.WithContext(ctx).Debugf("failed to get peers for group events: %v", err) + return nil + } + + for _, peerID := range addedPeers { + peer, ok := peers[peerID] + if !ok { + log.WithContext(ctx).Debugf("skipped adding peer: %s GroupAddedToPeer activity: peer not found in store", peerID) continue } - peerCopy := peer // copy to avoid closure issues + eventsToStore = append(eventsToStore, func() { - am.StoreEvent(ctx, userID, peerCopy.ID, accountID, activity.GroupAddedToPeer, - map[string]any{ - "group": newGroup.Name, "group_id": newGroup.ID, "peer_ip": peerCopy.IP.String(), - "peer_fqdn": peerCopy.FQDN(am.GetDNSDomain()), - }) + meta := map[string]any{ + "group": newGroup.Name, "group_id": newGroup.ID, + "peer_ip": peer.IP.String(), "peer_fqdn": peer.FQDN(am.GetDNSDomain()), + } + am.StoreEvent(ctx, userID, peer.ID, accountID, activity.GroupAddedToPeer, meta) }) } - for _, p := range removedPeers { - peer := account.Peers[p] - if peer == nil { - log.WithContext(ctx).Errorf("peer %s not found under account %s while saving group", p, accountID) + for _, peerID := range removedPeers { + peer, ok := peers[peerID] + if !ok { + log.WithContext(ctx).Debugf("skipped adding peer: %s GroupRemovedFromPeer activity: peer not found in store", peerID) continue } - peerCopy := peer // copy to avoid closure issues + eventsToStore = append(eventsToStore, func() { - am.StoreEvent(ctx, userID, peerCopy.ID, accountID, activity.GroupRemovedFromPeer, - map[string]any{ - "group": newGroup.Name, "group_id": newGroup.ID, "peer_ip": peerCopy.IP.String(), - "peer_fqdn": peerCopy.FQDN(am.GetDNSDomain()), - }) + meta := map[string]any{ + "group": newGroup.Name, "group_id": newGroup.ID, + "peer_ip": peer.IP.String(), "peer_fqdn": peer.FQDN(am.GetDNSDomain()), + } + am.StoreEvent(ctx, userID, peer.ID, accountID, activity.GroupRemovedFromPeer, meta) }) } @@ -206,42 +210,10 @@ func difference(a, b []string) []string { } // DeleteGroup object of the peers. -func (am *DefaultAccountManager) DeleteGroup(ctx context.Context, accountId, userId, groupID string) error { - unlock := am.Store.AcquireWriteLockByUID(ctx, accountId) +func (am *DefaultAccountManager) DeleteGroup(ctx context.Context, accountID, userID, groupID string) error { + unlock := am.Store.AcquireWriteLockByUID(ctx, accountID) defer unlock() - - account, err := am.Store.GetAccount(ctx, accountId) - if err != nil { - return err - } - - group, ok := account.Groups[groupID] - if !ok { - return nil - } - - allGroup, err := account.GetGroupAll() - if err != nil { - return err - } - - if allGroup.ID == groupID { - return status.Errorf(status.InvalidArgument, "deleting group ALL is not allowed") - } - - if err = validateDeleteGroup(account, group, userId); err != nil { - return err - } - delete(account.Groups, groupID) - - account.Network.IncSerial() - if err = am.Store.SaveAccount(ctx, account); err != nil { - return err - } - - am.StoreEvent(ctx, userId, groupID, accountId, activity.GroupDeleted, group.EventMeta()) - - return nil + return am.DeleteGroups(ctx, accountID, userID, []string{groupID}) } // DeleteGroups deletes groups from an account. @@ -250,93 +222,94 @@ func (am *DefaultAccountManager) DeleteGroup(ctx context.Context, accountId, use // // If an error occurs while deleting a group, the function skips it and continues deleting other groups. // Errors are collected and returned at the end. -func (am *DefaultAccountManager) DeleteGroups(ctx context.Context, accountId, userId string, groupIDs []string) error { - account, err := am.Store.GetAccount(ctx, accountId) +func (am *DefaultAccountManager) DeleteGroups(ctx context.Context, accountID, userID string, groupIDs []string) error { + user, err := am.Store.GetUserByUserID(ctx, LockingStrengthShare, userID) if err != nil { return err } + if user.AccountID != accountID { + return status.NewUserNotPartOfAccountError() + } + + if user.IsRegularUser() { + return status.NewAdminPermissionError() + } + var allErrors error + var groupIDsToDelete []string + var deletedGroups []*nbgroup.Group - deletedGroups := make([]*nbgroup.Group, 0, len(groupIDs)) - for _, groupID := range groupIDs { - group, ok := account.Groups[groupID] - if !ok { - continue + err = am.Store.ExecuteInTransaction(ctx, func(transaction Store) error { + for _, groupID := range groupIDs { + group, err := transaction.GetGroupByID(ctx, LockingStrengthUpdate, accountID, groupID) + if err != nil { + allErrors = errors.Join(allErrors, err) + continue + } + + if err := validateDeleteGroup(ctx, transaction, group, userID); err != nil { + allErrors = errors.Join(allErrors, err) + continue + } + + groupIDsToDelete = append(groupIDsToDelete, groupID) + deletedGroups = append(deletedGroups, group) } - if err := validateDeleteGroup(account, group, userId); err != nil { - allErrors = errors.Join(allErrors, fmt.Errorf("failed to delete group %s: %w", groupID, err)) - continue + if err = transaction.IncrementNetworkSerial(ctx, LockingStrengthUpdate, accountID); err != nil { + return err } - delete(account.Groups, groupID) - deletedGroups = append(deletedGroups, group) - } - - account.Network.IncSerial() - if err = am.Store.SaveAccount(ctx, account); err != nil { + return transaction.DeleteGroups(ctx, LockingStrengthUpdate, accountID, groupIDsToDelete) + }) + if err != nil { return err } - for _, g := range deletedGroups { - am.StoreEvent(ctx, userId, g.ID, accountId, activity.GroupDeleted, g.EventMeta()) + for _, group := range deletedGroups { + am.StoreEvent(ctx, userID, group.ID, accountID, activity.GroupDeleted, group.EventMeta()) } return allErrors } -// ListGroups objects of the peers -func (am *DefaultAccountManager) ListGroups(ctx context.Context, accountID string) ([]*nbgroup.Group, error) { - unlock := am.Store.AcquireWriteLockByUID(ctx, accountID) - defer unlock() - - account, err := am.Store.GetAccount(ctx, accountID) - if err != nil { - return nil, err - } - - groups := make([]*nbgroup.Group, 0, len(account.Groups)) - for _, item := range account.Groups { - groups = append(groups, item) - } - - return groups, nil -} - // GroupAddPeer appends peer to the group func (am *DefaultAccountManager) GroupAddPeer(ctx context.Context, accountID, groupID, peerID string) error { unlock := am.Store.AcquireWriteLockByUID(ctx, accountID) defer unlock() - account, err := am.Store.GetAccount(ctx, accountID) + var group *nbgroup.Group + var updateAccountPeers bool + var err error + + err = am.Store.ExecuteInTransaction(ctx, func(transaction Store) error { + group, err = transaction.GetGroupByID(context.Background(), LockingStrengthUpdate, accountID, groupID) + if err != nil { + return err + } + + if updated := group.AddPeer(peerID); !updated { + return nil + } + + updateAccountPeers, err = areGroupChangesAffectPeers(ctx, transaction, accountID, []string{groupID}) + if err != nil { + return err + } + + if err = transaction.IncrementNetworkSerial(ctx, LockingStrengthUpdate, accountID); err != nil { + return err + } + + return transaction.SaveGroup(ctx, LockingStrengthUpdate, group) + }) if err != nil { return err } - group, ok := account.Groups[groupID] - if !ok { - return status.Errorf(status.NotFound, "group with ID %s not found", groupID) - } - - add := true - for _, itemID := range group.Peers { - if itemID == peerID { - add = false - break - } - } - if add { - group.Peers = append(group.Peers, peerID) - } - - account.Network.IncSerial() - if err = am.Store.SaveAccount(ctx, account); err != nil { - return err - } - - if areGroupChangesAffectPeers(account, []string{group.ID}) { - am.updateAccountPeers(ctx, account) + if updateAccountPeers { + am.updateAccountPeers(ctx, accountID) } return nil @@ -347,90 +320,162 @@ func (am *DefaultAccountManager) GroupDeletePeer(ctx context.Context, accountID, unlock := am.Store.AcquireWriteLockByUID(ctx, accountID) defer unlock() - account, err := am.Store.GetAccount(ctx, accountID) + var group *nbgroup.Group + var updateAccountPeers bool + var err error + + err = am.Store.ExecuteInTransaction(ctx, func(transaction Store) error { + group, err = transaction.GetGroupByID(context.Background(), LockingStrengthUpdate, accountID, groupID) + if err != nil { + return err + } + + if updated := group.RemovePeer(peerID); !updated { + return nil + } + + updateAccountPeers, err = areGroupChangesAffectPeers(ctx, transaction, accountID, []string{groupID}) + if err != nil { + return err + } + + if err = transaction.IncrementNetworkSerial(ctx, LockingStrengthUpdate, accountID); err != nil { + return err + } + + return transaction.SaveGroup(ctx, LockingStrengthUpdate, group) + }) if err != nil { return err } - group, ok := account.Groups[groupID] - if !ok { - return status.Errorf(status.NotFound, "group with ID %s not found", groupID) - } - - account.Network.IncSerial() - for i, itemID := range group.Peers { - if itemID == peerID { - group.Peers = append(group.Peers[:i], group.Peers[i+1:]...) - if err := am.Store.SaveAccount(ctx, account); err != nil { - return err - } - } - } - - if areGroupChangesAffectPeers(account, []string{group.ID}) { - am.updateAccountPeers(ctx, account) + if updateAccountPeers { + am.updateAccountPeers(ctx, accountID) } return nil } -func validateDeleteGroup(account *Account, group *nbgroup.Group, userID string) error { +// validateNewGroup validates the new group for existence and required fields. +func validateNewGroup(ctx context.Context, transaction Store, accountID string, newGroup *nbgroup.Group) error { + if newGroup.ID == "" && newGroup.Issued != nbgroup.GroupIssuedAPI { + return status.Errorf(status.InvalidArgument, "%s group without ID set", newGroup.Issued) + } + + if newGroup.ID == "" && newGroup.Issued == nbgroup.GroupIssuedAPI { + existingGroup, err := transaction.GetGroupByName(ctx, LockingStrengthShare, accountID, newGroup.Name) + if err != nil { + if s, ok := status.FromError(err); !ok || s.Type() != status.NotFound { + return err + } + } + + // Prevent duplicate groups for API-issued groups. + // Integration or JWT groups can be duplicated as they are coming from the IdP that we don't have control of. + if existingGroup != nil { + return status.Errorf(status.AlreadyExists, "group with name %s already exists", newGroup.Name) + } + + newGroup.ID = xid.New().String() + } + + for _, peerID := range newGroup.Peers { + _, err := transaction.GetPeerByID(ctx, LockingStrengthShare, accountID, peerID) + if err != nil { + return status.Errorf(status.InvalidArgument, "peer with ID \"%s\" not found", peerID) + } + } + + return nil +} + +func validateDeleteGroup(ctx context.Context, transaction Store, group *nbgroup.Group, userID string) error { // disable a deleting integration group if the initiator is not an admin service user if group.Issued == nbgroup.GroupIssuedIntegration { - executingUser := account.Users[userID] - if executingUser == nil { - return status.Errorf(status.NotFound, "user not found") + executingUser, err := transaction.GetUserByUserID(ctx, LockingStrengthShare, userID) + if err != nil { + return err } if executingUser.Role != UserRoleAdmin || !executingUser.IsServiceUser { return status.Errorf(status.PermissionDenied, "only service users with admin power can delete integration group") } } - if isLinked, linkedRoute := isGroupLinkedToRoute(account.Routes, group.ID); isLinked { + if group.IsGroupAll() { + return status.Errorf(status.InvalidArgument, "deleting group ALL is not allowed") + } + + if isLinked, linkedRoute := isGroupLinkedToRoute(ctx, transaction, group.AccountID, group.ID); isLinked { return &GroupLinkError{"route", string(linkedRoute.NetID)} } - if isLinked, linkedDns := isGroupLinkedToDns(account.NameServerGroups, group.ID); isLinked { + if isLinked, linkedDns := isGroupLinkedToDns(ctx, transaction, group.AccountID, group.ID); isLinked { return &GroupLinkError{"name server groups", linkedDns.Name} } - if isLinked, linkedPolicy := isGroupLinkedToPolicy(account.Policies, group.ID); isLinked { + if isLinked, linkedPolicy := isGroupLinkedToPolicy(ctx, transaction, group.AccountID, group.ID); isLinked { return &GroupLinkError{"policy", linkedPolicy.Name} } - if isLinked, linkedSetupKey := isGroupLinkedToSetupKey(account.SetupKeys, group.ID); isLinked { + if isLinked, linkedSetupKey := isGroupLinkedToSetupKey(ctx, transaction, group.AccountID, group.ID); isLinked { return &GroupLinkError{"setup key", linkedSetupKey.Name} } - if isLinked, linkedUser := isGroupLinkedToUser(account.Users, group.ID); isLinked { + if isLinked, linkedUser := isGroupLinkedToUser(ctx, transaction, group.AccountID, group.ID); isLinked { return &GroupLinkError{"user", linkedUser.Id} } - if slices.Contains(account.DNSSettings.DisabledManagementGroups, group.ID) { + return checkGroupLinkedToSettings(ctx, transaction, group) +} + +// checkGroupLinkedToSettings verifies if a group is linked to any settings in the account. +func checkGroupLinkedToSettings(ctx context.Context, transaction Store, group *nbgroup.Group) error { + dnsSettings, err := transaction.GetAccountDNSSettings(ctx, LockingStrengthShare, group.AccountID) + if err != nil { + return err + } + + if slices.Contains(dnsSettings.DisabledManagementGroups, group.ID) { return &GroupLinkError{"disabled DNS management groups", group.Name} } - if account.Settings.Extra != nil { - if slices.Contains(account.Settings.Extra.IntegratedValidatorGroups, group.ID) { - return &GroupLinkError{"integrated validator", group.Name} - } + settings, err := transaction.GetAccountSettings(ctx, LockingStrengthShare, group.AccountID) + if err != nil { + return err + } + + if settings.Extra != nil && slices.Contains(settings.Extra.IntegratedValidatorGroups, group.ID) { + return &GroupLinkError{"integrated validator", group.Name} } return nil } // isGroupLinkedToRoute checks if a group is linked to any route in the account. -func isGroupLinkedToRoute(routes map[route.ID]*route.Route, groupID string) (bool, *route.Route) { +func isGroupLinkedToRoute(ctx context.Context, transaction Store, accountID string, groupID string) (bool, *route.Route) { + routes, err := transaction.GetAccountRoutes(ctx, LockingStrengthShare, accountID) + if err != nil { + log.WithContext(ctx).Errorf("error retrieving routes while checking group linkage: %v", err) + return false, nil + } + for _, r := range routes { if slices.Contains(r.Groups, groupID) || slices.Contains(r.PeerGroups, groupID) { return true, r } } + return false, nil } // isGroupLinkedToPolicy checks if a group is linked to any policy in the account. -func isGroupLinkedToPolicy(policies []*Policy, groupID string) (bool, *Policy) { +func isGroupLinkedToPolicy(ctx context.Context, transaction Store, accountID string, groupID string) (bool, *Policy) { + policies, err := transaction.GetAccountPolicies(ctx, LockingStrengthShare, accountID) + if err != nil { + log.WithContext(ctx).Errorf("error retrieving policies while checking group linkage: %v", err) + return false, nil + } + for _, policy := range policies { for _, rule := range policy.Rules { if slices.Contains(rule.Sources, groupID) || slices.Contains(rule.Destinations, groupID) { @@ -442,7 +487,13 @@ func isGroupLinkedToPolicy(policies []*Policy, groupID string) (bool, *Policy) { } // isGroupLinkedToDns checks if a group is linked to any nameserver group in the account. -func isGroupLinkedToDns(nameServerGroups map[string]*nbdns.NameServerGroup, groupID string) (bool, *nbdns.NameServerGroup) { +func isGroupLinkedToDns(ctx context.Context, transaction Store, accountID string, groupID string) (bool, *nbdns.NameServerGroup) { + nameServerGroups, err := transaction.GetAccountNameServerGroups(ctx, LockingStrengthShare, accountID) + if err != nil { + log.WithContext(ctx).Errorf("error retrieving name server groups while checking group linkage: %v", err) + return false, nil + } + for _, dns := range nameServerGroups { for _, g := range dns.Groups { if g == groupID { @@ -450,11 +501,18 @@ func isGroupLinkedToDns(nameServerGroups map[string]*nbdns.NameServerGroup, grou } } } + return false, nil } // isGroupLinkedToSetupKey checks if a group is linked to any setup key in the account. -func isGroupLinkedToSetupKey(setupKeys map[string]*SetupKey, groupID string) (bool, *SetupKey) { +func isGroupLinkedToSetupKey(ctx context.Context, transaction Store, accountID string, groupID string) (bool, *SetupKey) { + setupKeys, err := transaction.GetAccountSetupKeys(ctx, LockingStrengthShare, accountID) + if err != nil { + log.WithContext(ctx).Errorf("error retrieving setup keys while checking group linkage: %v", err) + return false, nil + } + for _, setupKey := range setupKeys { if slices.Contains(setupKey.AutoGroups, groupID) { return true, setupKey @@ -464,7 +522,13 @@ func isGroupLinkedToSetupKey(setupKeys map[string]*SetupKey, groupID string) (bo } // isGroupLinkedToUser checks if a group is linked to any user in the account. -func isGroupLinkedToUser(users map[string]*User, groupID string) (bool, *User) { +func isGroupLinkedToUser(ctx context.Context, transaction Store, accountID string, groupID string) (bool, *User) { + users, err := transaction.GetAccountUsers(ctx, LockingStrengthShare, accountID) + if err != nil { + log.WithContext(ctx).Errorf("error retrieving users while checking group linkage: %v", err) + return false, nil + } + for _, user := range users { if slices.Contains(user.AutoGroups, groupID) { return true, user @@ -473,6 +537,35 @@ func isGroupLinkedToUser(users map[string]*User, groupID string) (bool, *User) { return false, nil } +// areGroupChangesAffectPeers checks if any changes to the specified groups will affect peers. +func areGroupChangesAffectPeers(ctx context.Context, transaction Store, accountID string, groupIDs []string) (bool, error) { + if len(groupIDs) == 0 { + return false, nil + } + + dnsSettings, err := transaction.GetAccountDNSSettings(ctx, LockingStrengthShare, accountID) + if err != nil { + return false, err + } + + for _, groupID := range groupIDs { + if slices.Contains(dnsSettings.DisabledManagementGroups, groupID) { + return true, nil + } + if linked, _ := isGroupLinkedToDns(ctx, transaction, accountID, groupID); linked { + return true, nil + } + if linked, _ := isGroupLinkedToPolicy(ctx, transaction, accountID, groupID); linked { + return true, nil + } + if linked, _ := isGroupLinkedToRoute(ctx, transaction, accountID, groupID); linked { + return true, nil + } + } + + return false, nil +} + // anyGroupHasPeers checks if any of the given groups in the account have peers. func anyGroupHasPeers(account *Account, groupIDs []string) bool { for _, groupID := range groupIDs { @@ -482,22 +575,3 @@ func anyGroupHasPeers(account *Account, groupIDs []string) bool { } return false } - -func areGroupChangesAffectPeers(account *Account, groupIDs []string) bool { - for _, groupID := range groupIDs { - if slices.Contains(account.DNSSettings.DisabledManagementGroups, groupID) { - return true - } - if linked, _ := isGroupLinkedToDns(account.NameServerGroups, groupID); linked { - return true - } - if linked, _ := isGroupLinkedToPolicy(account.Policies, groupID); linked { - return true - } - if linked, _ := isGroupLinkedToRoute(account.Routes, groupID); linked { - return true - } - } - - return false -} diff --git a/management/server/group/group.go b/management/server/group/group.go index e98e5ecc4..24c60d3ce 100644 --- a/management/server/group/group.go +++ b/management/server/group/group.go @@ -54,3 +54,30 @@ func (g *Group) HasPeers() bool { func (g *Group) IsGroupAll() bool { return g.Name == "All" } + +// AddPeer adds peerID to Peers if not present, returning true if added. +func (g *Group) AddPeer(peerID string) bool { + if peerID == "" { + return false + } + + for _, itemID := range g.Peers { + if itemID == peerID { + return false + } + } + + g.Peers = append(g.Peers, peerID) + return true +} + +// RemovePeer removes peerID from Peers if present, returning true if removed. +func (g *Group) RemovePeer(peerID string) bool { + for i, itemID := range g.Peers { + if itemID == peerID { + g.Peers = append(g.Peers[:i], g.Peers[i+1:]...) + return true + } + } + return false +} diff --git a/management/server/group/group_test.go b/management/server/group/group_test.go new file mode 100644 index 000000000..cb002f8d9 --- /dev/null +++ b/management/server/group/group_test.go @@ -0,0 +1,90 @@ +package group + +import ( + "testing" + + "github.com/stretchr/testify/assert" +) + +func TestAddPeer(t *testing.T) { + t.Run("add new peer to empty slice", func(t *testing.T) { + group := &Group{Peers: []string{}} + peerID := "peer1" + assert.True(t, group.AddPeer(peerID)) + assert.Contains(t, group.Peers, peerID) + }) + + t.Run("add new peer to nil slice", func(t *testing.T) { + group := &Group{Peers: nil} + peerID := "peer1" + assert.True(t, group.AddPeer(peerID)) + assert.Contains(t, group.Peers, peerID) + }) + + t.Run("add new peer to non-empty slice", func(t *testing.T) { + group := &Group{Peers: []string{"peer1", "peer2"}} + peerID := "peer3" + assert.True(t, group.AddPeer(peerID)) + assert.Contains(t, group.Peers, peerID) + }) + + t.Run("add duplicate peer", func(t *testing.T) { + group := &Group{Peers: []string{"peer1", "peer2"}} + peerID := "peer1" + assert.False(t, group.AddPeer(peerID)) + assert.Equal(t, 2, len(group.Peers)) + }) + + t.Run("add empty peer", func(t *testing.T) { + group := &Group{Peers: []string{"peer1", "peer2"}} + peerID := "" + assert.False(t, group.AddPeer(peerID)) + assert.Equal(t, 2, len(group.Peers)) + }) +} + +func TestRemovePeer(t *testing.T) { + t.Run("remove existing peer from slice", func(t *testing.T) { + group := &Group{Peers: []string{"peer1", "peer2", "peer3"}} + peerID := "peer2" + assert.True(t, group.RemovePeer(peerID)) + assert.NotContains(t, group.Peers, peerID) + assert.Equal(t, 2, len(group.Peers)) + }) + + t.Run("remove peer from empty slice", func(t *testing.T) { + group := &Group{Peers: []string{}} + peerID := "peer1" + assert.False(t, group.RemovePeer(peerID)) + assert.Equal(t, 0, len(group.Peers)) + }) + + t.Run("remove peer from nil slice", func(t *testing.T) { + group := &Group{Peers: nil} + peerID := "peer1" + assert.False(t, group.RemovePeer(peerID)) + assert.Nil(t, group.Peers) + }) + + t.Run("remove non-existent peer", func(t *testing.T) { + group := &Group{Peers: []string{"peer1", "peer2"}} + peerID := "peer3" + assert.False(t, group.RemovePeer(peerID)) + assert.Equal(t, 2, len(group.Peers)) + }) + + t.Run("remove peer from single-item slice", func(t *testing.T) { + group := &Group{Peers: []string{"peer1"}} + peerID := "peer1" + assert.True(t, group.RemovePeer(peerID)) + assert.Equal(t, 0, len(group.Peers)) + assert.NotContains(t, group.Peers, peerID) + }) + + t.Run("remove empty peer", func(t *testing.T) { + group := &Group{Peers: []string{"peer1", "peer2"}} + peerID := "" + assert.False(t, group.RemovePeer(peerID)) + assert.Equal(t, 2, len(group.Peers)) + }) +} diff --git a/management/server/group_test.go b/management/server/group_test.go index 89184e819..59094a23e 100644 --- a/management/server/group_test.go +++ b/management/server/group_test.go @@ -208,7 +208,7 @@ func TestDefaultAccountManager_DeleteGroups(t *testing.T) { { name: "delete non-existent group", groupIDs: []string{"non-existent-group"}, - expectedDeleted: []string{"non-existent-group"}, + expectedReasons: []string{"group: non-existent-group not found"}, }, { name: "delete multiple groups with mixed results", diff --git a/management/server/integrated_validator.go b/management/server/integrated_validator.go index 99e6b204c..0c70b702a 100644 --- a/management/server/integrated_validator.go +++ b/management/server/integrated_validator.go @@ -52,25 +52,22 @@ func (am *DefaultAccountManager) UpdateIntegratedValidatorGroups(ctx context.Con return am.Store.SaveAccount(ctx, a) } -func (am *DefaultAccountManager) GroupValidation(ctx context.Context, accountId string, groups []string) (bool, error) { - if len(groups) == 0 { +func (am *DefaultAccountManager) GroupValidation(ctx context.Context, accountID string, groupIDs []string) (bool, error) { + if len(groupIDs) == 0 { return true, nil } - accountsGroups, err := am.ListGroups(ctx, accountId) - if err != nil { - return false, err - } - for _, group := range groups { - var found bool - for _, accountGroup := range accountsGroups { - if accountGroup.ID == group { - found = true - break + + err := am.Store.ExecuteInTransaction(ctx, func(transaction Store) error { + for _, groupID := range groupIDs { + _, err := transaction.GetGroupByID(context.Background(), LockingStrengthShare, accountID, groupID) + if err != nil { + return err } } - if !found { - return false, nil - } + return nil + }) + if err != nil { + return false, err } return true, nil diff --git a/management/server/mock_server/account_mock.go b/management/server/mock_server/account_mock.go index d7139bb2a..aa6a47b15 100644 --- a/management/server/mock_server/account_mock.go +++ b/management/server/mock_server/account_mock.go @@ -45,7 +45,6 @@ type MockAccountManager struct { SaveGroupsFunc func(ctx context.Context, accountID, userID string, groups []*group.Group) error DeleteGroupFunc func(ctx context.Context, accountID, userId, groupID string) error DeleteGroupsFunc func(ctx context.Context, accountId, userId string, groupIDs []string) error - ListGroupsFunc func(ctx context.Context, accountID string) ([]*group.Group, error) GroupAddPeerFunc func(ctx context.Context, accountID, groupID, peerID string) error GroupDeletePeerFunc func(ctx context.Context, accountID, groupID, peerID string) error DeleteRuleFunc func(ctx context.Context, accountID, ruleID, userID string) error @@ -354,14 +353,6 @@ func (am *MockAccountManager) DeleteGroups(ctx context.Context, accountId, userI return status.Errorf(codes.Unimplemented, "method DeleteGroups is not implemented") } -// ListGroups mock implementation of ListGroups from server.AccountManager interface -func (am *MockAccountManager) ListGroups(ctx context.Context, accountID string) ([]*group.Group, error) { - if am.ListGroupsFunc != nil { - return am.ListGroupsFunc(ctx, accountID) - } - return nil, status.Errorf(codes.Unimplemented, "method ListGroups is not implemented") -} - // GroupAddPeer mock implementation of GroupAddPeer from server.AccountManager interface func (am *MockAccountManager) GroupAddPeer(ctx context.Context, accountID, groupID, peerID string) error { if am.GroupAddPeerFunc != nil { diff --git a/management/server/nameserver.go b/management/server/nameserver.go index 5ebd263dc..957008714 100644 --- a/management/server/nameserver.go +++ b/management/server/nameserver.go @@ -71,7 +71,7 @@ func (am *DefaultAccountManager) CreateNameServerGroup(ctx context.Context, acco } if anyGroupHasPeers(account, newNSGroup.Groups) { - am.updateAccountPeers(ctx, account) + am.updateAccountPeers(ctx, accountID) } am.StoreEvent(ctx, userID, newNSGroup.ID, accountID, activity.NameserverGroupCreated, newNSGroup.EventMeta()) @@ -106,7 +106,7 @@ func (am *DefaultAccountManager) SaveNameServerGroup(ctx context.Context, accoun } if areNameServerGroupChangesAffectPeers(account, nsGroupToSave, oldNSGroup) { - am.updateAccountPeers(ctx, account) + am.updateAccountPeers(ctx, accountID) } am.StoreEvent(ctx, userID, nsGroupToSave.ID, accountID, activity.NameserverGroupUpdated, nsGroupToSave.EventMeta()) @@ -136,7 +136,7 @@ func (am *DefaultAccountManager) DeleteNameServerGroup(ctx context.Context, acco } if anyGroupHasPeers(account, nsGroup.Groups) { - am.updateAccountPeers(ctx, account) + am.updateAccountPeers(ctx, accountID) } am.StoreEvent(ctx, userID, nsGroup.ID, accountID, activity.NameserverGroupDeleted, nsGroup.EventMeta()) diff --git a/management/server/peer.go b/management/server/peer.go index 87b5b0e4e..1405dead8 100644 --- a/management/server/peer.go +++ b/management/server/peer.go @@ -133,7 +133,7 @@ func (am *DefaultAccountManager) MarkPeerConnected(ctx context.Context, peerPubK if expired { // we need to update other peers because when peer login expires all other peers are notified to disconnect from // the expired one. Here we notify them that connection is now allowed again. - am.updateAccountPeers(ctx, account) + am.updateAccountPeers(ctx, account.Id) } return nil @@ -271,7 +271,7 @@ func (am *DefaultAccountManager) UpdatePeer(ctx context.Context, accountID, user } if peerLabelUpdated || requiresPeerUpdates { - am.updateAccountPeers(ctx, account) + am.updateAccountPeers(ctx, accountID) } return peer, nil @@ -335,7 +335,10 @@ func (am *DefaultAccountManager) DeletePeer(ctx context.Context, accountID, peer return err } - updateAccountPeers := isPeerInActiveGroup(account, peerID) + updateAccountPeers, err := am.isPeerInActiveGroup(ctx, account, peerID) + if err != nil { + return err + } err = am.deletePeers(ctx, account, []string{peerID}, userID) if err != nil { @@ -348,7 +351,7 @@ func (am *DefaultAccountManager) DeletePeer(ctx context.Context, accountID, peer } if updateAccountPeers { - am.updateAccountPeers(ctx, account) + am.updateAccountPeers(ctx, accountID) } return nil @@ -555,7 +558,7 @@ func (am *DefaultAccountManager) AddPeer(ctx context.Context, setupKey, userID s return fmt.Errorf("failed to add peer to account: %w", err) } - err = transaction.IncrementNetworkSerial(ctx, accountID) + err = transaction.IncrementNetworkSerial(ctx, LockingStrengthUpdate, accountID) if err != nil { return fmt.Errorf("failed to increment network serial: %w", err) } @@ -598,10 +601,15 @@ func (am *DefaultAccountManager) AddPeer(ctx context.Context, setupKey, userID s if err != nil { return nil, nil, nil, fmt.Errorf("error getting all group ID: %w", err) } - groupsToAdd = append(groupsToAdd, allGroup.ID) - if areGroupChangesAffectPeers(account, groupsToAdd) { - am.updateAccountPeers(ctx, account) + + newGroupsAffectsPeers, err := areGroupChangesAffectPeers(ctx, am.Store, accountID, groupsToAdd) + if err != nil { + return nil, nil, nil, err + } + + if newGroupsAffectsPeers { + am.updateAccountPeers(ctx, accountID) } approvedPeersMap, err := am.GetValidatedPeers(account) @@ -666,7 +674,7 @@ func (am *DefaultAccountManager) SyncPeer(ctx context.Context, sync PeerSync, ac } if sync.UpdateAccountPeers { - am.updateAccountPeers(ctx, account) + am.updateAccountPeers(ctx, account.Id) } } @@ -685,7 +693,7 @@ func (am *DefaultAccountManager) SyncPeer(ctx context.Context, sync PeerSync, ac } if isStatusChanged { - am.updateAccountPeers(ctx, account) + am.updateAccountPeers(ctx, account.Id) } validPeersMap, err := am.GetValidatedPeers(account) @@ -816,7 +824,7 @@ func (am *DefaultAccountManager) LoginPeer(ctx context.Context, login PeerLogin) } if updateRemotePeers || isStatusChanged { - am.updateAccountPeers(ctx, account) + am.updateAccountPeers(ctx, accountID) } return am.getValidatedPeerWithMap(ctx, isRequiresApproval, account, peer) @@ -979,7 +987,7 @@ func (am *DefaultAccountManager) GetPeer(ctx context.Context, accountID, peerID, // updateAccountPeers updates all peers that belong to an account. // Should be called when changes have to be synced to peers. -func (am *DefaultAccountManager) updateAccountPeers(ctx context.Context, account *Account) { +func (am *DefaultAccountManager) updateAccountPeers(ctx context.Context, accountID string) { start := time.Now() defer func() { if am.metrics != nil { @@ -987,6 +995,11 @@ func (am *DefaultAccountManager) updateAccountPeers(ctx context.Context, account } }() + account, err := am.requestBuffer.GetAccountWithBackpressure(ctx, accountID) + if err != nil { + log.WithContext(ctx).Errorf("failed to send out updates to peers: %v", err) + return + } peers := account.GetPeers() approvedPeersMap, err := am.GetValidatedPeers(account) @@ -1033,12 +1046,12 @@ func ConvertSliceToMap(existingLabels []string) map[string]struct{} { // IsPeerInActiveGroup checks if the given peer is part of a group that is used // in an active DNS, route, or ACL configuration. -func isPeerInActiveGroup(account *Account, peerID string) bool { +func (am *DefaultAccountManager) isPeerInActiveGroup(ctx context.Context, account *Account, peerID string) (bool, error) { peerGroupIDs := make([]string, 0) for _, group := range account.Groups { if slices.Contains(group.Peers, peerID) { peerGroupIDs = append(peerGroupIDs, group.ID) } } - return areGroupChangesAffectPeers(account, peerGroupIDs) + return areGroupChangesAffectPeers(ctx, am.Store, account.Id, peerGroupIDs) } diff --git a/management/server/peer_test.go b/management/server/peer_test.go index 78885ea1b..4e2dcb2c3 100644 --- a/management/server/peer_test.go +++ b/management/server/peer_test.go @@ -877,7 +877,7 @@ func BenchmarkUpdateAccountPeers(b *testing.B) { start := time.Now() for i := 0; i < b.N; i++ { - manager.updateAccountPeers(ctx, account) + manager.updateAccountPeers(ctx, account.Id) } duration := time.Since(start) diff --git a/management/server/policy.go b/management/server/policy.go index 43a925f88..8a5733f01 100644 --- a/management/server/policy.go +++ b/management/server/policy.go @@ -377,7 +377,7 @@ func (am *DefaultAccountManager) SavePolicy(ctx context.Context, accountID, user am.StoreEvent(ctx, userID, policy.ID, accountID, action, policy.EventMeta()) if updateAccountPeers { - am.updateAccountPeers(ctx, account) + am.updateAccountPeers(ctx, accountID) } return nil @@ -406,7 +406,7 @@ func (am *DefaultAccountManager) DeletePolicy(ctx context.Context, accountID, po am.StoreEvent(ctx, userID, policy.ID, accountID, activity.PolicyRemoved, policy.EventMeta()) if anyGroupHasPeers(account, policy.ruleGroups()) { - am.updateAccountPeers(ctx, account) + am.updateAccountPeers(ctx, accountID) } return nil diff --git a/management/server/posture_checks.go b/management/server/posture_checks.go index 2dccd8f59..096cff3f5 100644 --- a/management/server/posture_checks.go +++ b/management/server/posture_checks.go @@ -69,7 +69,7 @@ func (am *DefaultAccountManager) SavePostureChecks(ctx context.Context, accountI am.StoreEvent(ctx, userID, postureChecks.ID, accountID, action, postureChecks.EventMeta()) if arePostureCheckChangesAffectingPeers(account, postureChecks.ID, exists) { - am.updateAccountPeers(ctx, account) + am.updateAccountPeers(ctx, accountID) } return nil diff --git a/management/server/route.go b/management/server/route.go index 1cf00b37c..dcf2cb0d3 100644 --- a/management/server/route.go +++ b/management/server/route.go @@ -238,7 +238,7 @@ func (am *DefaultAccountManager) CreateRoute(ctx context.Context, accountID stri } if isRouteChangeAffectPeers(account, &newRoute) { - am.updateAccountPeers(ctx, account) + am.updateAccountPeers(ctx, accountID) } am.StoreEvent(ctx, userID, string(newRoute.ID), accountID, activity.RouteCreated, newRoute.EventMeta()) @@ -324,7 +324,7 @@ func (am *DefaultAccountManager) SaveRoute(ctx context.Context, accountID, userI } if isRouteChangeAffectPeers(account, oldRoute) || isRouteChangeAffectPeers(account, routeToSave) { - am.updateAccountPeers(ctx, account) + am.updateAccountPeers(ctx, accountID) } am.StoreEvent(ctx, userID, string(routeToSave.ID), accountID, activity.RouteUpdated, routeToSave.EventMeta()) @@ -356,7 +356,7 @@ func (am *DefaultAccountManager) DeleteRoute(ctx context.Context, accountID stri am.StoreEvent(ctx, userID, string(routy.ID), accountID, activity.RouteRemoved, routy.EventMeta()) if isRouteChangeAffectPeers(account, routy) { - am.updateAccountPeers(ctx, account) + am.updateAccountPeers(ctx, accountID) } return nil diff --git a/management/server/route_test.go b/management/server/route_test.go index 4893e19b9..5c848f68c 100644 --- a/management/server/route_test.go +++ b/management/server/route_test.go @@ -1091,7 +1091,7 @@ func TestGetNetworkMap_RouteSyncPeerGroups(t *testing.T) { require.NoError(t, err) assert.Len(t, peer4Routes.Routes, 1, "HA route should have 1 server route") - groups, err := am.ListGroups(context.Background(), account.Id) + groups, err := am.Store.GetAccountGroups(context.Background(), LockingStrengthShare, account.Id) require.NoError(t, err) var groupHA1, groupHA2 *nbgroup.Group for _, group := range groups { diff --git a/management/server/setupkey.go b/management/server/setupkey.go index 960532bf9..cae0dfecb 100644 --- a/management/server/setupkey.go +++ b/management/server/setupkey.go @@ -453,14 +453,14 @@ func (am *DefaultAccountManager) prepareSetupKeyEvents(ctx context.Context, tran modifiedGroups := slices.Concat(addedGroups, removedGroups) groups, err := transaction.GetGroupsByIDs(ctx, LockingStrengthShare, accountID, modifiedGroups) if err != nil { - log.WithContext(ctx).Errorf("issue getting groups for setup key events: %v", err) + log.WithContext(ctx).Debugf("failed to get groups for setup key events: %v", err) return nil } for _, g := range removedGroups { group, ok := groups[g] if !ok { - log.WithContext(ctx).Debugf("skipped adding group: %s GroupRemovedFromSetupKey activity: %v", g, err) + log.WithContext(ctx).Debugf("skipped adding group: %s GroupRemovedFromSetupKey activity: group not found", g) continue } @@ -473,7 +473,7 @@ func (am *DefaultAccountManager) prepareSetupKeyEvents(ctx context.Context, tran for _, g := range addedGroups { group, ok := groups[g] if !ok { - log.WithContext(ctx).Debugf("skipped adding group: %s GroupAddedToSetupKey activity: %v", g, err) + log.WithContext(ctx).Debugf("skipped adding group: %s GroupAddedToSetupKey activity: group not found", g) continue } diff --git a/management/server/sql_store.go b/management/server/sql_store.go index 9dd3e778d..0ebda6440 100644 --- a/management/server/sql_store.go +++ b/management/server/sql_store.go @@ -33,12 +33,13 @@ import ( ) const ( - storeSqliteFileName = "store.db" - idQueryCondition = "id = ?" - keyQueryCondition = "key = ?" - accountAndIDQueryCondition = "account_id = ? and id = ?" - accountIDCondition = "account_id = ?" - peerNotFoundFMT = "peer %s not found" + storeSqliteFileName = "store.db" + idQueryCondition = "id = ?" + keyQueryCondition = "key = ?" + accountAndIDQueryCondition = "account_id = ? and id = ?" + accountAndIDsQueryCondition = "account_id = ? AND id IN ?" + accountIDCondition = "account_id = ?" + peerNotFoundFMT = "peer %s not found" ) // SqlStore represents an account storage backed by a Sql DB persisted to disk @@ -555,9 +556,9 @@ func (s *SqlStore) GetUserByUserID(ctx context.Context, lockStrength LockingStre return &user, nil } -func (s *SqlStore) GetAccountUsers(ctx context.Context, accountID string) ([]*User, error) { +func (s *SqlStore) GetAccountUsers(ctx context.Context, lockStrength LockingStrength, accountID string) ([]*User, error) { var users []*User - result := s.db.Find(&users, accountIDCondition, accountID) + result := s.db.Clauses(clause.Locking{Strength: string(lockStrength)}).Find(&users, accountIDCondition, accountID) if result.Error != nil { if errors.Is(result.Error, gorm.ErrRecordNotFound) { return nil, status.Errorf(status.NotFound, "accountID not found: index lookup failed") @@ -857,7 +858,6 @@ func (s *SqlStore) SaveUserLastLogin(ctx context.Context, accountID, userID stri if errors.Is(result.Error, gorm.ErrRecordNotFound) { return status.NewUserNotFoundError(userID) } - return status.NewGetUserFromStoreError() } user.LastLogin = lastLogin @@ -1045,7 +1045,7 @@ func (s *SqlStore) AddPeerToGroup(ctx context.Context, accountId string, peerId result := s.db.Where(accountAndIDQueryCondition, accountId, groupID).First(&group) if result.Error != nil { if errors.Is(result.Error, gorm.ErrRecordNotFound) { - return status.Errorf(status.NotFound, "group not found for account") + return status.NewGroupNotFoundError(groupID) } return status.Errorf(status.Internal, "issue finding group: %s", result.Error) @@ -1079,10 +1079,45 @@ func (s *SqlStore) AddPeerToAccount(ctx context.Context, peer *nbpeer.Peer) erro return nil } -func (s *SqlStore) IncrementNetworkSerial(ctx context.Context, accountId string) error { - result := s.db.Model(&Account{}).Where(idQueryCondition, accountId).Update("network_serial", gorm.Expr("network_serial + 1")) +// GetPeerByID retrieves a peer by its ID and account ID. +func (s *SqlStore) GetPeerByID(ctx context.Context, lockStrength LockingStrength, accountID, peerID string) (*nbpeer.Peer, error) { + var peer *nbpeer.Peer + result := s.db.Clauses(clause.Locking{Strength: string(lockStrength)}). + First(&peer, accountAndIDQueryCondition, accountID, peerID) if result.Error != nil { - return status.Errorf(status.Internal, "issue incrementing network serial count: %s", result.Error) + if errors.Is(result.Error, gorm.ErrRecordNotFound) { + return nil, status.Errorf(status.NotFound, "peer not found") + } + log.WithContext(ctx).Errorf("failed to get peer from store: %s", result.Error) + return nil, status.Errorf(status.Internal, "failed to get peer from store") + } + + return peer, nil +} + +// GetPeersByIDs retrieves peers by their IDs and account ID. +func (s *SqlStore) GetPeersByIDs(ctx context.Context, lockStrength LockingStrength, accountID string, peerIDs []string) (map[string]*nbpeer.Peer, error) { + var peers []*nbpeer.Peer + result := s.db.Clauses(clause.Locking{Strength: string(lockStrength)}).Find(&peers, accountAndIDsQueryCondition, accountID, peerIDs) + if result.Error != nil { + log.WithContext(ctx).Errorf("failed to get peers by ID's from the store: %s", result.Error) + return nil, status.Errorf(status.Internal, "failed to get peers by ID's from the store") + } + + peersMap := make(map[string]*nbpeer.Peer) + for _, peer := range peers { + peersMap[peer.ID] = peer + } + + return peersMap, nil +} + +func (s *SqlStore) IncrementNetworkSerial(ctx context.Context, lockStrength LockingStrength, accountId string) error { + result := s.db.Clauses(clause.Locking{Strength: string(lockStrength)}). + Model(&Account{}).Where(idQueryCondition, accountId).Update("network_serial", gorm.Expr("network_serial + 1")) + if result.Error != nil { + log.WithContext(ctx).Errorf("failed to increment network serial count in store: %v", result.Error) + return status.Errorf(status.Internal, "failed to increment network serial count in store") } return nil } @@ -1103,7 +1138,8 @@ func (s *SqlStore) ExecuteInTransaction(ctx context.Context, operation func(stor func (s *SqlStore) withTx(tx *gorm.DB) Store { return &SqlStore{ - db: tx, + db: tx, + storeEngine: s.storeEngine, } } @@ -1155,12 +1191,22 @@ func (s *SqlStore) GetAccountDomainAndCategory(ctx context.Context, lockStrength } // GetGroupByID retrieves a group by ID and account ID. -func (s *SqlStore) GetGroupByID(ctx context.Context, lockStrength LockingStrength, groupID, accountID string) (*nbgroup.Group, error) { - return getRecordByID[nbgroup.Group](s.db.Preload(clause.Associations), lockStrength, groupID, accountID) +func (s *SqlStore) GetGroupByID(ctx context.Context, lockStrength LockingStrength, accountID, groupID string) (*nbgroup.Group, error) { + var group *nbgroup.Group + result := s.db.Clauses(clause.Locking{Strength: string(lockStrength)}).First(&group, accountAndIDQueryCondition, accountID, groupID) + if err := result.Error; err != nil { + if errors.Is(err, gorm.ErrRecordNotFound) { + return nil, status.NewGroupNotFoundError(groupID) + } + log.WithContext(ctx).Errorf("failed to get group from store: %s", err) + return nil, status.Errorf(status.Internal, "failed to get group from store") + } + + return group, nil } // GetGroupByName retrieves a group by name and account ID. -func (s *SqlStore) GetGroupByName(ctx context.Context, lockStrength LockingStrength, groupName, accountID string) (*nbgroup.Group, error) { +func (s *SqlStore) GetGroupByName(ctx context.Context, lockStrength LockingStrength, accountID, groupName string) (*nbgroup.Group, error) { var group nbgroup.Group // TODO: This fix is accepted for now, but if we need to handle this more frequently @@ -1172,12 +1218,13 @@ func (s *SqlStore) GetGroupByName(ctx context.Context, lockStrength LockingStren query = query.Order("json_array_length(peers) DESC") } - result := query.First(&group, "name = ? and account_id = ?", groupName, accountID) + result := query.First(&group, "account_id = ? AND name = ?", accountID, groupName) if err := result.Error; err != nil { if errors.Is(result.Error, gorm.ErrRecordNotFound) { - return nil, status.Errorf(status.NotFound, "group not found") + return nil, status.NewGroupNotFoundError(groupName) } - return nil, status.Errorf(status.Internal, "failed to get group from store: %s", result.Error) + log.WithContext(ctx).Errorf("failed to get group by name from store: %v", result.Error) + return nil, status.Errorf(status.Internal, "failed to get group by name from store") } return &group, nil } @@ -1185,7 +1232,7 @@ func (s *SqlStore) GetGroupByName(ctx context.Context, lockStrength LockingStren // GetGroupsByIDs retrieves groups by their IDs and account ID. func (s *SqlStore) GetGroupsByIDs(ctx context.Context, lockStrength LockingStrength, accountID string, groupIDs []string) (map[string]*nbgroup.Group, error) { var groups []*nbgroup.Group - result := s.db.Clauses(clause.Locking{Strength: string(lockStrength)}).Find(&groups, "account_id = ? AND id in ?", accountID, groupIDs) + result := s.db.Clauses(clause.Locking{Strength: string(lockStrength)}).Find(&groups, accountAndIDsQueryCondition, accountID, groupIDs) if result.Error != nil { log.WithContext(ctx).Errorf("failed to get groups by ID's from the store: %s", result.Error) return nil, status.Errorf(status.Internal, "failed to get groups by ID's from the store") @@ -1203,11 +1250,40 @@ func (s *SqlStore) GetGroupsByIDs(ctx context.Context, lockStrength LockingStren func (s *SqlStore) SaveGroup(ctx context.Context, lockStrength LockingStrength, group *nbgroup.Group) error { result := s.db.Clauses(clause.Locking{Strength: string(lockStrength)}).Save(group) if result.Error != nil { - return status.Errorf(status.Internal, "failed to save group to store: %v", result.Error) + log.WithContext(ctx).Errorf("failed to save group to store: %v", result.Error) + return status.Errorf(status.Internal, "failed to save group to store") } return nil } +// DeleteGroup deletes a group from the database. +func (s *SqlStore) DeleteGroup(ctx context.Context, lockStrength LockingStrength, accountID, groupID string) error { + result := s.db.Clauses(clause.Locking{Strength: string(lockStrength)}). + Delete(&nbgroup.Group{}, accountAndIDQueryCondition, accountID, groupID) + if err := result.Error; err != nil { + log.WithContext(ctx).Errorf("failed to delete group from store: %s", result.Error) + return status.Errorf(status.Internal, "failed to delete group from store") + } + + if result.RowsAffected == 0 { + return status.NewGroupNotFoundError(groupID) + } + + return nil +} + +// DeleteGroups deletes groups from the database. +func (s *SqlStore) DeleteGroups(ctx context.Context, strength LockingStrength, accountID string, groupIDs []string) error { + result := s.db.Clauses(clause.Locking{Strength: string(strength)}). + Delete(&nbgroup.Group{}, accountAndIDsQueryCondition, accountID, groupIDs) + if result.Error != nil { + log.WithContext(ctx).Errorf("failed to delete groups from store: %v", result.Error) + return status.Errorf(status.Internal, "failed to delete groups from store") + } + + return nil +} + // GetAccountPolicies retrieves policies for an account. func (s *SqlStore) GetAccountPolicies(ctx context.Context, lockStrength LockingStrength, accountID string) ([]*Policy, error) { return getRecords[*Policy](s.db.Preload(clause.Associations), lockStrength, accountID) diff --git a/management/server/sql_store_test.go b/management/server/sql_store_test.go index 3f3b2a453..114da1ee6 100644 --- a/management/server/sql_store_test.go +++ b/management/server/sql_store_test.go @@ -14,11 +14,10 @@ import ( "time" "github.com/google/uuid" - "github.com/stretchr/testify/assert" - "github.com/stretchr/testify/require" - nbdns "github.com/netbirdio/netbird/dns" nbgroup "github.com/netbirdio/netbird/management/server/group" + "github.com/stretchr/testify/assert" + "github.com/stretchr/testify/require" route2 "github.com/netbirdio/netbird/route" @@ -1181,7 +1180,7 @@ func TestSqlite_CreateAndGetObjectInTransaction(t *testing.T) { t.Fatal("failed to save group") return err } - group, err = transaction.GetGroupByID(context.Background(), LockingStrengthUpdate, group.ID, group.AccountID) + group, err = transaction.GetGroupByID(context.Background(), LockingStrengthUpdate, group.AccountID, group.ID) if err != nil { t.Fatal("failed to get group") return err @@ -1201,7 +1200,7 @@ func TestSqlite_GetAccoundUsers(t *testing.T) { accountID := "bf1c8084-ba50-4ce7-9439-34653001fc3b" account, err := store.GetAccount(context.Background(), accountID) require.NoError(t, err) - users, err := store.GetAccountUsers(context.Background(), accountID) + users, err := store.GetAccountUsers(context.Background(), LockingStrengthShare, accountID) require.NoError(t, err) require.Len(t, users, len(account.Users)) } @@ -1260,9 +1259,9 @@ func TestSqlite_GetGroupByName(t *testing.T) { } accountID := "bf1c8084-ba50-4ce7-9439-34653001fc3b" - group, err := store.GetGroupByName(context.Background(), LockingStrengthShare, "All", accountID) + group, err := store.GetGroupByName(context.Background(), LockingStrengthShare, accountID, "All") require.NoError(t, err) - require.Equal(t, "All", group.Name) + require.True(t, group.IsGroupAll()) } func Test_DeleteSetupKeySuccessfully(t *testing.T) { @@ -1293,3 +1292,275 @@ func Test_DeleteSetupKeyFailsForNonExistingKey(t *testing.T) { err = store.DeleteSetupKey(context.Background(), LockingStrengthUpdate, accountID, nonExistingKeyID) require.Error(t, err) } + +func TestSqlStore_GetGroupsByIDs(t *testing.T) { + store, cleanup, err := NewTestStoreFromSQL(context.Background(), "testdata/extended-store.sql", t.TempDir()) + t.Cleanup(cleanup) + require.NoError(t, err) + + accountID := "bf1c8084-ba50-4ce7-9439-34653001fc3b" + + tests := []struct { + name string + groupIDs []string + expectedCount int + }{ + { + name: "retrieve existing groups by existing IDs", + groupIDs: []string{"cfefqs706sqkneg59g4g", "cfefqs706sqkneg59g3g"}, + expectedCount: 2, + }, + { + name: "empty group IDs list", + groupIDs: []string{}, + expectedCount: 0, + }, + { + name: "non-existing group IDs", + groupIDs: []string{"nonexistent1", "nonexistent2"}, + expectedCount: 0, + }, + { + name: "mixed existing and non-existing group IDs", + groupIDs: []string{"cfefqs706sqkneg59g4g", "nonexistent"}, + expectedCount: 1, + }, + } + + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + groups, err := store.GetGroupsByIDs(context.Background(), LockingStrengthShare, accountID, tt.groupIDs) + require.NoError(t, err) + require.Len(t, groups, tt.expectedCount) + }) + } +} + +func TestSqlStore_SaveGroup(t *testing.T) { + store, cleanup, err := NewTestStoreFromSQL(context.Background(), "testdata/extended-store.sql", t.TempDir()) + t.Cleanup(cleanup) + require.NoError(t, err) + + accountID := "bf1c8084-ba50-4ce7-9439-34653001fc3b" + + group := &nbgroup.Group{ + ID: "group-id", + AccountID: accountID, + Issued: "api", + Peers: []string{"peer1", "peer2"}, + } + err = store.SaveGroup(context.Background(), LockingStrengthUpdate, group) + require.NoError(t, err) + + savedGroup, err := store.GetGroupByID(context.Background(), LockingStrengthShare, accountID, "group-id") + require.NoError(t, err) + require.Equal(t, savedGroup, group) +} + +func TestSqlStore_SaveGroups(t *testing.T) { + store, cleanup, err := NewTestStoreFromSQL(context.Background(), "testdata/extended-store.sql", t.TempDir()) + t.Cleanup(cleanup) + require.NoError(t, err) + + accountID := "bf1c8084-ba50-4ce7-9439-34653001fc3b" + + groups := []*nbgroup.Group{ + { + ID: "group-1", + AccountID: accountID, + Issued: "api", + Peers: []string{"peer1", "peer2"}, + }, + { + ID: "group-2", + AccountID: accountID, + Issued: "integration", + Peers: []string{"peer3", "peer4"}, + }, + } + err = store.SaveGroups(context.Background(), LockingStrengthUpdate, groups) + require.NoError(t, err) +} + +func TestSqlStore_DeleteGroup(t *testing.T) { + store, cleanup, err := NewTestStoreFromSQL(context.Background(), "testdata/extended-store.sql", t.TempDir()) + t.Cleanup(cleanup) + require.NoError(t, err) + + accountID := "bf1c8084-ba50-4ce7-9439-34653001fc3b" + + tests := []struct { + name string + groupID string + expectError bool + }{ + { + name: "delete existing group", + groupID: "cfefqs706sqkneg59g4g", + expectError: false, + }, + { + name: "delete non-existing group", + groupID: "non-existing-group-id", + expectError: true, + }, + { + name: "delete with empty group ID", + groupID: "", + expectError: true, + }, + } + + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + err := store.DeleteGroup(context.Background(), LockingStrengthUpdate, accountID, tt.groupID) + if tt.expectError { + require.Error(t, err) + sErr, ok := status.FromError(err) + require.True(t, ok) + require.Equal(t, sErr.Type(), status.NotFound) + } else { + require.NoError(t, err) + + group, err := store.GetGroupByID(context.Background(), LockingStrengthShare, accountID, tt.groupID) + require.Error(t, err) + require.Nil(t, group) + } + }) + } +} + +func TestSqlStore_DeleteGroups(t *testing.T) { + store, cleanup, err := NewTestStoreFromSQL(context.Background(), "testdata/extended-store.sql", t.TempDir()) + t.Cleanup(cleanup) + require.NoError(t, err) + + accountID := "bf1c8084-ba50-4ce7-9439-34653001fc3b" + + tests := []struct { + name string + groupIDs []string + expectError bool + }{ + { + name: "delete multiple existing groups", + groupIDs: []string{"cfefqs706sqkneg59g4g", "cfefqs706sqkneg59g3g"}, + expectError: false, + }, + { + name: "delete non-existing groups", + groupIDs: []string{"non-existing-id-1", "non-existing-id-2"}, + expectError: false, + }, + { + name: "delete with empty group IDs list", + groupIDs: []string{}, + expectError: false, + }, + } + + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + err := store.DeleteGroups(context.Background(), LockingStrengthUpdate, accountID, tt.groupIDs) + if tt.expectError { + require.Error(t, err) + } else { + require.NoError(t, err) + + for _, groupID := range tt.groupIDs { + group, err := store.GetGroupByID(context.Background(), LockingStrengthShare, accountID, groupID) + require.Error(t, err) + require.Nil(t, group) + } + } + }) + } +} + +func TestSqlStore_GetPeerByID(t *testing.T) { + store, cleanup, err := NewTestStoreFromSQL(context.Background(), "testdata/store_policy_migrate.sql", t.TempDir()) + t.Cleanup(cleanup) + require.NoError(t, err) + + accountID := "bf1c8084-ba50-4ce7-9439-34653001fc3b" + tests := []struct { + name string + peerID string + expectError bool + }{ + { + name: "retrieve existing peer", + peerID: "cfefqs706sqkneg59g4g", + expectError: false, + }, + { + name: "retrieve non-existing peer", + peerID: "non-existing", + expectError: true, + }, + { + name: "retrieve with empty peer ID", + peerID: "", + expectError: true, + }, + } + + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + peer, err := store.GetPeerByID(context.Background(), LockingStrengthShare, accountID, tt.peerID) + if tt.expectError { + require.Error(t, err) + sErr, ok := status.FromError(err) + require.True(t, ok) + require.Equal(t, sErr.Type(), status.NotFound) + require.Nil(t, peer) + } else { + require.NoError(t, err) + require.NotNil(t, peer) + require.Equal(t, tt.peerID, peer.ID) + } + }) + } +} + +func TestSqlStore_GetPeersByIDs(t *testing.T) { + store, cleanup, err := NewTestStoreFromSQL(context.Background(), "testdata/store_policy_migrate.sql", t.TempDir()) + t.Cleanup(cleanup) + require.NoError(t, err) + + accountID := "bf1c8084-ba50-4ce7-9439-34653001fc3b" + tests := []struct { + name string + peerIDs []string + expectedCount int + }{ + { + name: "retrieve existing peers by existing IDs", + peerIDs: []string{"cfefqs706sqkneg59g4g", "cfeg6sf06sqkneg59g50"}, + expectedCount: 2, + }, + { + name: "empty peer IDs list", + peerIDs: []string{}, + expectedCount: 0, + }, + { + name: "non-existing peer IDs", + peerIDs: []string{"nonexistent1", "nonexistent2"}, + expectedCount: 0, + }, + { + name: "mixed existing and non-existing peer IDs", + peerIDs: []string{"cfeg6sf06sqkneg59g50", "nonexistent"}, + expectedCount: 1, + }, + } + + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + peers, err := store.GetPeersByIDs(context.Background(), LockingStrengthShare, accountID, tt.peerIDs) + require.NoError(t, err) + require.Len(t, peers, tt.expectedCount) + }) + } +} diff --git a/management/server/status/error.go b/management/server/status/error.go index f1f3f16e6..8b6d0077b 100644 --- a/management/server/status/error.go +++ b/management/server/status/error.go @@ -3,7 +3,6 @@ package status import ( "errors" "fmt" - "time" ) const ( @@ -126,11 +125,6 @@ func NewAdminPermissionError() error { return Errorf(PermissionDenied, "admin role required to perform this action") } -// NewStoreContextCanceledError creates a new Error with Internal type for a canceled store context -func NewStoreContextCanceledError(duration time.Duration) error { - return Errorf(Internal, "store access: context canceled after %v", duration) -} - // NewInvalidKeyIDError creates a new Error with InvalidArgument type for an issue getting a setup key func NewInvalidKeyIDError() error { return Errorf(InvalidArgument, "invalid key ID") @@ -140,3 +134,8 @@ func NewInvalidKeyIDError() error { func NewGetAccountError(err error) error { return Errorf(Internal, "error getting account: %s", err) } + +// NewGroupNotFoundError creates a new Error with NotFound type for a missing group +func NewGroupNotFoundError(groupID string) error { + return Errorf(NotFound, "group: %s not found", groupID) +} diff --git a/management/server/store.go b/management/server/store.go index c7c066629..71b0d457b 100644 --- a/management/server/store.go +++ b/management/server/store.go @@ -62,7 +62,7 @@ type Store interface { GetUserByTokenID(ctx context.Context, tokenID string) (*User, error) GetUserByUserID(ctx context.Context, lockStrength LockingStrength, userID string) (*User, error) - GetAccountUsers(ctx context.Context, accountID string) ([]*User, error) + GetAccountUsers(ctx context.Context, lockStrength LockingStrength, accountID string) ([]*User, error) SaveUsers(accountID string, users map[string]*User) error SaveUser(ctx context.Context, lockStrength LockingStrength, user *User) error SaveUserLastLogin(ctx context.Context, accountID, userID string, lastLogin time.Time) error @@ -76,6 +76,8 @@ type Store interface { GetGroupsByIDs(ctx context.Context, lockStrength LockingStrength, accountID string, groupIDs []string) (map[string]*nbgroup.Group, error) SaveGroups(ctx context.Context, lockStrength LockingStrength, groups []*nbgroup.Group) error SaveGroup(ctx context.Context, lockStrength LockingStrength, group *nbgroup.Group) error + DeleteGroup(ctx context.Context, lockStrength LockingStrength, accountID, groupID string) error + DeleteGroups(ctx context.Context, strength LockingStrength, accountID string, groupIDs []string) error GetAccountPolicies(ctx context.Context, lockStrength LockingStrength, accountID string) ([]*Policy, error) GetPolicyByID(ctx context.Context, lockStrength LockingStrength, policyID string, accountID string) (*Policy, error) @@ -90,6 +92,8 @@ type Store interface { AddPeerToAccount(ctx context.Context, peer *nbpeer.Peer) error GetPeerByPeerPubKey(ctx context.Context, lockStrength LockingStrength, peerKey string) (*nbpeer.Peer, error) GetUserPeers(ctx context.Context, lockStrength LockingStrength, accountID, userID string) ([]*nbpeer.Peer, error) + GetPeerByID(ctx context.Context, lockStrength LockingStrength, accountID string, peerID string) (*nbpeer.Peer, error) + GetPeersByIDs(ctx context.Context, lockStrength LockingStrength, accountID string, peerIDs []string) (map[string]*nbpeer.Peer, error) SavePeer(ctx context.Context, accountID string, peer *nbpeer.Peer) error SavePeerStatus(accountID, peerID string, status nbpeer.PeerStatus) error SavePeerLocation(accountID string, peer *nbpeer.Peer) error @@ -108,7 +112,7 @@ type Store interface { GetNameServerGroupByID(ctx context.Context, lockStrength LockingStrength, nameServerGroupID string, accountID string) (*dns.NameServerGroup, error) GetTakenIPs(ctx context.Context, lockStrength LockingStrength, accountId string) ([]net.IP, error) - IncrementNetworkSerial(ctx context.Context, accountId string) error + IncrementNetworkSerial(ctx context.Context, lockStrength LockingStrength, accountId string) error GetAccountNetwork(ctx context.Context, lockStrength LockingStrength, accountId string) (*Network, error) GetInstallationID() string diff --git a/management/server/user.go b/management/server/user.go index 5e0d9d034..74062112a 100644 --- a/management/server/user.go +++ b/management/server/user.go @@ -494,7 +494,7 @@ func (am *DefaultAccountManager) deleteRegularUser(ctx context.Context, account am.StoreEvent(ctx, initiatorUserID, targetUserID, account.Id, activity.UserDeleted, meta) if updateAccountPeers { - am.updateAccountPeers(ctx, account) + am.updateAccountPeers(ctx, account.Id) } return nil @@ -835,7 +835,7 @@ func (am *DefaultAccountManager) SaveOrAddUsers(ctx context.Context, accountID, } if account.Settings.GroupsPropagationEnabled && areUsersLinkedToPeers(account, userIDs) { - am.updateAccountPeers(ctx, account) + am.updateAccountPeers(ctx, account.Id) } for _, storeEvent := range eventsToStore { @@ -1132,7 +1132,7 @@ func (am *DefaultAccountManager) expireAndUpdatePeers(ctx context.Context, accou if len(peerIDs) != 0 { // this will trigger peer disconnect from the management service am.peersUpdateManager.CloseChannels(ctx, peerIDs) - am.updateAccountPeers(ctx, account) + am.updateAccountPeers(ctx, account.Id) } return nil } @@ -1240,7 +1240,7 @@ func (am *DefaultAccountManager) DeleteRegularUsers(ctx context.Context, account } if updateAccountPeers { - am.updateAccountPeers(ctx, account) + am.updateAccountPeers(ctx, accountID) } for targetUserID, meta := range deletedUsersMeta { From a1c5287b7c7a6152cb3cff319dea278ba1cfd3c1 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?=C4=B0smail?= Date: Fri, 15 Nov 2024 20:21:27 +0300 Subject: [PATCH 13/39] Fix the Inactivity Expiration problem. (#2865) --- management/server/account.go | 29 +++++++++++++++++------------ 1 file changed, 17 insertions(+), 12 deletions(-) diff --git a/management/server/account.go b/management/server/account.go index 1bd8a99a9..95c93a22b 100644 --- a/management/server/account.go +++ b/management/server/account.go @@ -1186,20 +1186,25 @@ func (am *DefaultAccountManager) UpdateAccountSettings(ctx context.Context, acco } func (am *DefaultAccountManager) handleInactivityExpirationSettings(ctx context.Context, account *Account, oldSettings, newSettings *Settings, userID, accountID string) error { - if oldSettings.PeerInactivityExpirationEnabled != newSettings.PeerInactivityExpirationEnabled { - event := activity.AccountPeerInactivityExpirationEnabled - if !newSettings.PeerInactivityExpirationEnabled { - event = activity.AccountPeerInactivityExpirationDisabled - am.peerInactivityExpiry.Cancel(ctx, []string{accountID}) - } else { + + if newSettings.PeerInactivityExpirationEnabled { + if oldSettings.PeerInactivityExpiration != newSettings.PeerInactivityExpiration { + oldSettings.PeerInactivityExpiration = newSettings.PeerInactivityExpiration + + am.StoreEvent(ctx, userID, accountID, accountID, activity.AccountPeerInactivityExpirationDurationUpdated, nil) am.checkAndSchedulePeerInactivityExpiration(ctx, account) } - am.StoreEvent(ctx, userID, accountID, accountID, event, nil) - } - - if oldSettings.PeerInactivityExpiration != newSettings.PeerInactivityExpiration { - am.StoreEvent(ctx, userID, accountID, accountID, activity.AccountPeerInactivityExpirationDurationUpdated, nil) - am.checkAndSchedulePeerInactivityExpiration(ctx, account) + } else { + if oldSettings.PeerInactivityExpirationEnabled != newSettings.PeerInactivityExpirationEnabled { + event := activity.AccountPeerInactivityExpirationEnabled + if !newSettings.PeerInactivityExpirationEnabled { + event = activity.AccountPeerInactivityExpirationDisabled + am.peerInactivityExpiry.Cancel(ctx, []string{accountID}) + } else { + am.checkAndSchedulePeerInactivityExpiration(ctx, account) + } + am.StoreEvent(ctx, userID, accountID, accountID, event, nil) + } } return nil From 121dfda915cbaa17e8a16af1f96a71927fc0ece1 Mon Sep 17 00:00:00 2001 From: Viktor Liu <17948409+lixmal@users.noreply.github.com> Date: Fri, 15 Nov 2024 20:05:26 +0100 Subject: [PATCH 14/39] [client] Fix state manager race conditions (#2890) --- client/internal/dns/server.go | 20 +++---- client/internal/engine.go | 2 +- .../routemanager/refcounter/refcounter.go | 58 +++++++++---------- .../internal/routemanager/systemops/state.go | 23 ++++---- .../systemops/systemops_generic.go | 20 +------ client/internal/statemanager/manager.go | 40 +++++++++---- util/file.go | 55 +++++++++++++----- 7 files changed, 118 insertions(+), 100 deletions(-) diff --git a/client/internal/dns/server.go b/client/internal/dns/server.go index 6c4dccae7..f0277319c 100644 --- a/client/internal/dns/server.go +++ b/client/internal/dns/server.go @@ -7,7 +7,6 @@ import ( "runtime" "strings" "sync" - "time" "github.com/miekg/dns" "github.com/mitchellh/hashstructure/v2" @@ -323,13 +322,9 @@ func (s *DefaultServer) applyConfiguration(update nbdns.Config) error { log.Error(err) } - // persist dns state right away - ctx, cancel := context.WithTimeout(s.ctx, 3*time.Second) - defer cancel() - - // don't block go func() { - if err := s.stateManager.PersistState(ctx); err != nil { + // persist dns state right away + if err := s.stateManager.PersistState(s.ctx); err != nil { log.Errorf("Failed to persist dns state: %v", err) } }() @@ -537,12 +532,11 @@ func (s *DefaultServer) upstreamCallbacks( l.Errorf("Failed to apply nameserver deactivation on the host: %v", err) } - // persist dns state right away - ctx, cancel := context.WithTimeout(s.ctx, 3*time.Second) - defer cancel() - if err := s.stateManager.PersistState(ctx); err != nil { - l.Errorf("Failed to persist dns state: %v", err) - } + go func() { + if err := s.stateManager.PersistState(s.ctx); err != nil { + l.Errorf("Failed to persist dns state: %v", err) + } + }() if runtime.GOOS == "android" && nsGroup.Primary && len(s.hostsDNSHolder.get()) > 0 { s.addHostRootZone() diff --git a/client/internal/engine.go b/client/internal/engine.go index d4a3a561a..1c912220c 100644 --- a/client/internal/engine.go +++ b/client/internal/engine.go @@ -297,7 +297,7 @@ func (e *Engine) Stop() error { if err := e.stateManager.Stop(ctx); err != nil { return fmt.Errorf("failed to stop state manager: %w", err) } - if err := e.stateManager.PersistState(ctx); err != nil { + if err := e.stateManager.PersistState(context.Background()); err != nil { log.Errorf("failed to persist state: %v", err) } diff --git a/client/internal/routemanager/refcounter/refcounter.go b/client/internal/routemanager/refcounter/refcounter.go index 0e230ef40..f2f0a169d 100644 --- a/client/internal/routemanager/refcounter/refcounter.go +++ b/client/internal/routemanager/refcounter/refcounter.go @@ -47,10 +47,9 @@ type RemoveFunc[Key, O any] func(key Key, out O) error type Counter[Key comparable, I, O any] struct { // refCountMap keeps track of the reference Ref for keys refCountMap map[Key]Ref[O] - refCountMu sync.Mutex + mu sync.Mutex // idMap keeps track of the keys associated with an ID for removal idMap map[string][]Key - idMu sync.Mutex add AddFunc[Key, I, O] remove RemoveFunc[Key, O] } @@ -75,10 +74,8 @@ func New[Key comparable, I, O any](add AddFunc[Key, I, O], remove RemoveFunc[Key func (rm *Counter[Key, I, O]) LoadData( existingCounter *Counter[Key, I, O], ) { - rm.refCountMu.Lock() - defer rm.refCountMu.Unlock() - rm.idMu.Lock() - defer rm.idMu.Unlock() + rm.mu.Lock() + defer rm.mu.Unlock() rm.refCountMap = existingCounter.refCountMap rm.idMap = existingCounter.idMap @@ -87,8 +84,8 @@ func (rm *Counter[Key, I, O]) LoadData( // Get retrieves the current reference count and associated data for a key. // If the key doesn't exist, it returns a zero value Ref and false. func (rm *Counter[Key, I, O]) Get(key Key) (Ref[O], bool) { - rm.refCountMu.Lock() - defer rm.refCountMu.Unlock() + rm.mu.Lock() + defer rm.mu.Unlock() ref, ok := rm.refCountMap[key] return ref, ok @@ -97,9 +94,13 @@ func (rm *Counter[Key, I, O]) Get(key Key) (Ref[O], bool) { // Increment increments the reference count for the given key. // If this is the first reference to the key, the AddFunc is called. func (rm *Counter[Key, I, O]) Increment(key Key, in I) (Ref[O], error) { - rm.refCountMu.Lock() - defer rm.refCountMu.Unlock() + rm.mu.Lock() + defer rm.mu.Unlock() + return rm.increment(key, in) +} + +func (rm *Counter[Key, I, O]) increment(key Key, in I) (Ref[O], error) { ref := rm.refCountMap[key] logCallerF("Increasing ref count [%d -> %d] for key %v with In [%v] Out [%v]", ref.Count, ref.Count+1, key, in, ref.Out) @@ -126,10 +127,10 @@ func (rm *Counter[Key, I, O]) Increment(key Key, in I) (Ref[O], error) { // IncrementWithID increments the reference count for the given key and groups it under the given ID. // If this is the first reference to the key, the AddFunc is called. func (rm *Counter[Key, I, O]) IncrementWithID(id string, key Key, in I) (Ref[O], error) { - rm.idMu.Lock() - defer rm.idMu.Unlock() + rm.mu.Lock() + defer rm.mu.Unlock() - ref, err := rm.Increment(key, in) + ref, err := rm.increment(key, in) if err != nil { return ref, fmt.Errorf("with ID: %w", err) } @@ -141,9 +142,12 @@ func (rm *Counter[Key, I, O]) IncrementWithID(id string, key Key, in I) (Ref[O], // Decrement decrements the reference count for the given key. // If the reference count reaches 0, the RemoveFunc is called. func (rm *Counter[Key, I, O]) Decrement(key Key) (Ref[O], error) { - rm.refCountMu.Lock() - defer rm.refCountMu.Unlock() + rm.mu.Lock() + defer rm.mu.Unlock() + return rm.decrement(key) +} +func (rm *Counter[Key, I, O]) decrement(key Key) (Ref[O], error) { ref, ok := rm.refCountMap[key] if !ok { logCallerF("No reference found for key %v", key) @@ -168,12 +172,12 @@ func (rm *Counter[Key, I, O]) Decrement(key Key) (Ref[O], error) { // DecrementWithID decrements the reference count for all keys associated with the given ID. // If the reference count reaches 0, the RemoveFunc is called. func (rm *Counter[Key, I, O]) DecrementWithID(id string) error { - rm.idMu.Lock() - defer rm.idMu.Unlock() + rm.mu.Lock() + defer rm.mu.Unlock() var merr *multierror.Error for _, key := range rm.idMap[id] { - if _, err := rm.Decrement(key); err != nil { + if _, err := rm.decrement(key); err != nil { merr = multierror.Append(merr, err) } } @@ -184,10 +188,8 @@ func (rm *Counter[Key, I, O]) DecrementWithID(id string) error { // Flush removes all references and calls RemoveFunc for each key. func (rm *Counter[Key, I, O]) Flush() error { - rm.refCountMu.Lock() - defer rm.refCountMu.Unlock() - rm.idMu.Lock() - defer rm.idMu.Unlock() + rm.mu.Lock() + defer rm.mu.Unlock() var merr *multierror.Error for key := range rm.refCountMap { @@ -206,10 +208,8 @@ func (rm *Counter[Key, I, O]) Flush() error { // Clear removes all references without calling RemoveFunc. func (rm *Counter[Key, I, O]) Clear() { - rm.refCountMu.Lock() - defer rm.refCountMu.Unlock() - rm.idMu.Lock() - defer rm.idMu.Unlock() + rm.mu.Lock() + defer rm.mu.Unlock() clear(rm.refCountMap) clear(rm.idMap) @@ -217,10 +217,8 @@ func (rm *Counter[Key, I, O]) Clear() { // MarshalJSON implements the json.Marshaler interface for Counter. func (rm *Counter[Key, I, O]) MarshalJSON() ([]byte, error) { - rm.refCountMu.Lock() - defer rm.refCountMu.Unlock() - rm.idMu.Lock() - defer rm.idMu.Unlock() + rm.mu.Lock() + defer rm.mu.Unlock() return json.Marshal(struct { RefCountMap map[Key]Ref[O] `json:"refCountMap"` diff --git a/client/internal/routemanager/systemops/state.go b/client/internal/routemanager/systemops/state.go index 425908922..8e158711e 100644 --- a/client/internal/routemanager/systemops/state.go +++ b/client/internal/routemanager/systemops/state.go @@ -2,31 +2,28 @@ package systemops import ( "net/netip" - "sync" "github.com/netbirdio/netbird/client/internal/routemanager/refcounter" ) -type ShutdownState struct { - Counter *ExclusionCounter `json:"counter,omitempty"` - mu sync.RWMutex -} +type ShutdownState ExclusionCounter func (s *ShutdownState) Name() string { return "route_state" } func (s *ShutdownState) Cleanup() error { - s.mu.RLock() - defer s.mu.RUnlock() - - if s.Counter == nil { - return nil - } - sysops := NewSysOps(nil, nil) sysops.refCounter = refcounter.New[netip.Prefix, struct{}, Nexthop](nil, sysops.removeFromRouteTable) - sysops.refCounter.LoadData(s.Counter) + sysops.refCounter.LoadData((*ExclusionCounter)(s)) return sysops.refCounter.Flush() } + +func (s *ShutdownState) MarshalJSON() ([]byte, error) { + return (*ExclusionCounter)(s).MarshalJSON() +} + +func (s *ShutdownState) UnmarshalJSON(data []byte) error { + return (*ExclusionCounter)(s).UnmarshalJSON(data) +} diff --git a/client/internal/routemanager/systemops/systemops_generic.go b/client/internal/routemanager/systemops/systemops_generic.go index 4ff34aa51..f8b3ebbb8 100644 --- a/client/internal/routemanager/systemops/systemops_generic.go +++ b/client/internal/routemanager/systemops/systemops_generic.go @@ -62,7 +62,7 @@ func (r *SysOps) setupRefCounter(initAddresses []net.IP, stateManager *statemana return nexthop, err }, func(prefix netip.Prefix, nexthop Nexthop) error { - // remove from state even if we have trouble removing it from the route table + // update state even if we have trouble removing it from the route table // it could be already gone r.updateState(stateManager) @@ -75,12 +75,9 @@ func (r *SysOps) setupRefCounter(initAddresses []net.IP, stateManager *statemana return r.setupHooks(initAddresses) } +// updateState updates state on every change so it will be persisted regularly func (r *SysOps) updateState(stateManager *statemanager.Manager) { - state := getState(stateManager) - - state.Counter = r.refCounter - - if err := stateManager.UpdateState(state); err != nil { + if err := stateManager.UpdateState((*ShutdownState)(r.refCounter)); err != nil { log.Errorf("failed to update state: %v", err) } } @@ -532,14 +529,3 @@ func isVpnRoute(addr netip.Addr, vpnRoutes []netip.Prefix, localRoutes []netip.P // Return true if the longest matching prefix is from vpnRoutes return isVpn, longestPrefix } - -func getState(stateManager *statemanager.Manager) *ShutdownState { - var shutdownState *ShutdownState - if state := stateManager.GetState(shutdownState); state != nil { - shutdownState = state.(*ShutdownState) - } else { - shutdownState = &ShutdownState{} - } - - return shutdownState -} diff --git a/client/internal/statemanager/manager.go b/client/internal/statemanager/manager.go index 580ccdfc7..da6dd022f 100644 --- a/client/internal/statemanager/manager.go +++ b/client/internal/statemanager/manager.go @@ -74,15 +74,15 @@ func (m *Manager) Stop(ctx context.Context) error { m.mu.Lock() defer m.mu.Unlock() - if m.cancel != nil { - m.cancel() + if m.cancel == nil { + return nil + } + m.cancel() - select { - case <-ctx.Done(): - return ctx.Err() - case <-m.done: - return nil - } + select { + case <-ctx.Done(): + return ctx.Err() + case <-m.done: } return nil @@ -179,14 +179,18 @@ func (m *Manager) PersistState(ctx context.Context) error { return nil } + bs, err := marshalWithPanicRecovery(m.states) + if err != nil { + return fmt.Errorf("marshal states: %w", err) + } + ctx, cancel := context.WithTimeout(ctx, 5*time.Second) defer cancel() done := make(chan error, 1) - start := time.Now() go func() { - done <- util.WriteJsonWithRestrictedPermission(ctx, m.filePath, m.states) + done <- util.WriteBytesWithRestrictedPermission(ctx, m.filePath, bs) }() select { @@ -286,3 +290,19 @@ func (m *Manager) PerformCleanup() error { return nberrors.FormatErrorOrNil(merr) } + +func marshalWithPanicRecovery(v any) ([]byte, error) { + var bs []byte + var err error + + func() { + defer func() { + if r := recover(); r != nil { + err = fmt.Errorf("panic during marshal: %v", r) + } + }() + bs, err = json.Marshal(v) + }() + + return bs, err +} diff --git a/util/file.go b/util/file.go index 4641cc1b8..f7de7ede2 100644 --- a/util/file.go +++ b/util/file.go @@ -4,6 +4,7 @@ import ( "bytes" "context" "encoding/json" + "errors" "fmt" "io" "os" @@ -14,6 +15,19 @@ import ( log "github.com/sirupsen/logrus" ) +func WriteBytesWithRestrictedPermission(ctx context.Context, file string, bs []byte) error { + configDir, configFileName, err := prepareConfigFileDir(file) + if err != nil { + return fmt.Errorf("prepare config file dir: %w", err) + } + + if err = EnforcePermission(file); err != nil { + return fmt.Errorf("enforce permission: %w", err) + } + + return writeBytes(ctx, file, err, configDir, configFileName, bs) +} + // WriteJsonWithRestrictedPermission writes JSON config object to a file. Enforces permission on the parent directory func WriteJsonWithRestrictedPermission(ctx context.Context, file string, obj interface{}) error { configDir, configFileName, err := prepareConfigFileDir(file) @@ -82,29 +96,44 @@ func DirectWriteJson(ctx context.Context, file string, obj interface{}) error { func writeJson(ctx context.Context, file string, obj interface{}, configDir string, configFileName string) error { // Check context before expensive operations if ctx.Err() != nil { - return ctx.Err() + return fmt.Errorf("write json start: %w", ctx.Err()) } // make it pretty bs, err := json.MarshalIndent(obj, "", " ") if err != nil { - return err + return fmt.Errorf("marshal: %w", err) } + return writeBytes(ctx, file, err, configDir, configFileName, bs) +} + +func writeBytes(ctx context.Context, file string, err error, configDir string, configFileName string, bs []byte) error { if ctx.Err() != nil { - return ctx.Err() + return fmt.Errorf("write bytes start: %w", ctx.Err()) } tempFile, err := os.CreateTemp(configDir, ".*"+configFileName) if err != nil { - return err + return fmt.Errorf("create temp: %w", err) } tempFileName := tempFile.Name() - // closing file ops as windows doesn't allow to move it - err = tempFile.Close() + + if deadline, ok := ctx.Deadline(); ok { + if err := tempFile.SetDeadline(deadline); err != nil && !errors.Is(err, os.ErrNoDeadline) { + log.Warnf("failed to set deadline: %v", err) + } + } + + _, err = tempFile.Write(bs) if err != nil { - return err + _ = tempFile.Close() + return fmt.Errorf("write: %w", err) + } + + if err = tempFile.Close(); err != nil { + return fmt.Errorf("close %s: %w", tempFileName, err) } defer func() { @@ -114,19 +143,13 @@ func writeJson(ctx context.Context, file string, obj interface{}, configDir stri } }() - err = os.WriteFile(tempFileName, bs, 0600) - if err != nil { - return err - } - // Check context again if ctx.Err() != nil { - return ctx.Err() + return fmt.Errorf("after temp file: %w", ctx.Err()) } - err = os.Rename(tempFileName, file) - if err != nil { - return err + if err = os.Rename(tempFileName, file); err != nil { + return fmt.Errorf("move %s to %s: %w", tempFileName, file, err) } return nil From 582bb587140789884c8905e7ae8c8c3e732077dc Mon Sep 17 00:00:00 2001 From: Viktor Liu <17948409+lixmal@users.noreply.github.com> Date: Fri, 15 Nov 2024 22:55:33 +0100 Subject: [PATCH 15/39] Move state updates outside the refcounter (#2897) --- .../systemops/systemops_generic.go | 18 +++++++----------- 1 file changed, 7 insertions(+), 11 deletions(-) diff --git a/client/internal/routemanager/systemops/systemops_generic.go b/client/internal/routemanager/systemops/systemops_generic.go index f8b3ebbb8..3038c3ec5 100644 --- a/client/internal/routemanager/systemops/systemops_generic.go +++ b/client/internal/routemanager/systemops/systemops_generic.go @@ -57,22 +57,14 @@ func (r *SysOps) setupRefCounter(initAddresses []net.IP, stateManager *statemana return nexthop, refcounter.ErrIgnore } - r.updateState(stateManager) - return nexthop, err }, - func(prefix netip.Prefix, nexthop Nexthop) error { - // update state even if we have trouble removing it from the route table - // it could be already gone - r.updateState(stateManager) - - return r.removeFromRouteTable(prefix, nexthop) - }, + r.removeFromRouteTable, ) r.refCounter = refCounter - return r.setupHooks(initAddresses) + return r.setupHooks(initAddresses, stateManager) } // updateState updates state on every change so it will be persisted regularly @@ -333,7 +325,7 @@ func (r *SysOps) genericRemoveVPNRoute(prefix netip.Prefix, intf *net.Interface) return r.removeFromRouteTable(prefix, nextHop) } -func (r *SysOps) setupHooks(initAddresses []net.IP) (nbnet.AddHookFunc, nbnet.RemoveHookFunc, error) { +func (r *SysOps) setupHooks(initAddresses []net.IP, stateManager *statemanager.Manager) (nbnet.AddHookFunc, nbnet.RemoveHookFunc, error) { beforeHook := func(connID nbnet.ConnectionID, ip net.IP) error { prefix, err := util.GetPrefixFromIP(ip) if err != nil { @@ -344,6 +336,8 @@ func (r *SysOps) setupHooks(initAddresses []net.IP) (nbnet.AddHookFunc, nbnet.Re return fmt.Errorf("adding route reference: %v", err) } + r.updateState(stateManager) + return nil } afterHook := func(connID nbnet.ConnectionID) error { @@ -351,6 +345,8 @@ func (r *SysOps) setupHooks(initAddresses []net.IP) (nbnet.AddHookFunc, nbnet.Re return fmt.Errorf("remove route reference: %w", err) } + r.updateState(stateManager) + return nil } From a7d5c522033beef9174898b5e43a907b282f7341 Mon Sep 17 00:00:00 2001 From: Viktor Liu <17948409+lixmal@users.noreply.github.com> Date: Fri, 15 Nov 2024 22:59:49 +0100 Subject: [PATCH 16/39] Fix error state race on mgmt connection error (#2892) --- client/internal/connect.go | 3 ++- 1 file changed, 2 insertions(+), 1 deletion(-) diff --git a/client/internal/connect.go b/client/internal/connect.go index dff44f1d2..f76aa066b 100644 --- a/client/internal/connect.go +++ b/client/internal/connect.go @@ -157,7 +157,8 @@ func (c *ConnectClient) run(mobileDependency MobileDependency, probes *ProbeHold engineCtx, cancel := context.WithCancel(c.ctx) defer func() { - c.statusRecorder.MarkManagementDisconnected(state.err) + _, err := state.Status() + c.statusRecorder.MarkManagementDisconnected(err) c.statusRecorder.CleanLocalPeerState() cancel() }() From ec543f89fb819b4aae28850b370f0c06f05f7f96 Mon Sep 17 00:00:00 2001 From: Kursat Aktas Date: Sat, 16 Nov 2024 17:45:31 +0300 Subject: [PATCH 17/39] Introducing NetBird Guru on Gurubase.io (#2778) --- README.md | 4 ++++ 1 file changed, 4 insertions(+) diff --git a/README.md b/README.md index 270c9ad87..a2d7f3897 100644 --- a/README.md +++ b/README.md @@ -19,6 +19,10 @@
+ +
+ +

From 65a94f695f09063d505f49a6b2496f7a2e4d48b4 Mon Sep 17 00:00:00 2001 From: Maycon Santos Date: Mon, 18 Nov 2024 12:55:02 +0100 Subject: [PATCH 18/39] use google domain for tests (#2902) --- client/internal/dns/server_test.go | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/client/internal/dns/server_test.go b/client/internal/dns/server_test.go index 21f1f1b7d..eab9f4ecb 100644 --- a/client/internal/dns/server_test.go +++ b/client/internal/dns/server_test.go @@ -782,7 +782,7 @@ func TestDNSPermanent_matchOnly(t *testing.T) { Port: 53, }, }, - Domains: []string{"customdomain.com"}, + Domains: []string{"google.com"}, Primary: false, }, }, @@ -804,7 +804,7 @@ func TestDNSPermanent_matchOnly(t *testing.T) { if ips[0] != zoneRecords[0].RData { t.Fatalf("invalid zone record: %v", err) } - _, err = resolver.LookupHost(context.Background(), "customdomain.com") + _, err = resolver.LookupHost(context.Background(), "google.com") if err != nil { t.Errorf("failed to resolve: %s", err) } From ec6438e643c3ef0b0388c05c2fad5beed99fb4fe Mon Sep 17 00:00:00 2001 From: bcmmbaga Date: Mon, 18 Nov 2024 17:12:13 +0300 Subject: [PATCH 19/39] Use update strength and simplify check Signed-off-by: bcmmbaga --- management/server/policy.go | 4 +--- 1 file changed, 1 insertion(+), 3 deletions(-) diff --git a/management/server/policy.go b/management/server/policy.go index 6dcb96316..693ae2872 100644 --- a/management/server/policy.go +++ b/management/server/policy.go @@ -435,7 +435,7 @@ func (am *DefaultAccountManager) DeletePolicy(ctx context.Context, accountID, po var updateAccountPeers bool err = am.Store.ExecuteInTransaction(ctx, func(transaction Store) error { - policy, err = transaction.GetPolicyByID(ctx, LockingStrengthShare, accountID, policyID) + policy, err = transaction.GetPolicyByID(ctx, LockingStrengthUpdate, accountID, policyID) if err != nil { return err } @@ -502,8 +502,6 @@ func arePolicyChangesAffectPeers(ctx context.Context, transaction Store, account if hasPeers { return true, nil } - - return anyGroupHasPeers(ctx, transaction, policy.AccountID, policy.ruleGroups()) } return anyGroupHasPeers(ctx, transaction, policy.AccountID, policy.ruleGroups()) From 78fab877c07ee50aae95ed037218ec41f0bf2489 Mon Sep 17 00:00:00 2001 From: Maycon Santos Date: Mon, 18 Nov 2024 15:31:53 +0100 Subject: [PATCH 20/39] [misc] Update signing pipeline version (#2900) --- .github/workflows/release.yml | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/.github/workflows/release.yml b/.github/workflows/release.yml index 14e383a27..183cdb02c 100644 --- a/.github/workflows/release.yml +++ b/.github/workflows/release.yml @@ -9,7 +9,7 @@ on: pull_request: env: - SIGN_PIPE_VER: "v0.0.16" + SIGN_PIPE_VER: "v0.0.17" GORELEASER_VER: "v2.3.2" PRODUCT_NAME: "NetBird" COPYRIGHT: "Wiretrustee UG (haftungsbeschreankt)" From df98c67ac8100ac75995fecfaa32e76691db36c0 Mon Sep 17 00:00:00 2001 From: bcmmbaga Date: Mon, 18 Nov 2024 18:46:52 +0300 Subject: [PATCH 21/39] prevent changing ruleID when not empty Signed-off-by: bcmmbaga --- management/server/http/policies_handler.go | 7 ++++++- management/server/sql_store_test.go | 2 ++ 2 files changed, 8 insertions(+), 1 deletion(-) diff --git a/management/server/http/policies_handler.go b/management/server/http/policies_handler.go index 8255e4896..ca256a183 100644 --- a/management/server/http/policies_handler.go +++ b/management/server/http/policies_handler.go @@ -128,8 +128,13 @@ func (h *Policies) savePolicy(w http.ResponseWriter, r *http.Request, accountID Description: req.Description, } for _, rule := range req.Rules { + ruleID := policyID // TODO: when policy can contain multiple rules, need refactor + if rule.Id != nil { + ruleID = *rule.Id + } + pr := server.PolicyRule{ - ID: policyID, // TODO: when policy can contain multiple rules, need refactor + ID: ruleID, PolicyID: policyID, Name: rule.Name, Destinations: rule.Destinations, diff --git a/management/server/sql_store_test.go b/management/server/sql_store_test.go index 8931008d7..c05793fc6 100644 --- a/management/server/sql_store_test.go +++ b/management/server/sql_store_test.go @@ -1832,6 +1832,8 @@ func TestSqlStore_SavePolicy(t *testing.T) { policy.Enabled = false policy.Description = "policy" + policy.Rules[0].Sources = []string{"group"} + policy.Rules[0].Ports = []string{"80", "443"} err = store.SavePolicy(context.Background(), LockingStrengthUpdate, policy) require.NoError(t, err) From b60e2c32614615d5f7c44d95794338ade30d9287 Mon Sep 17 00:00:00 2001 From: bcmmbaga Date: Mon, 18 Nov 2024 22:48:38 +0300 Subject: [PATCH 22/39] prevent duplicate rules during updates Signed-off-by: bcmmbaga --- management/server/http/policies_handler.go | 2 +- management/server/policy.go | 2 +- 2 files changed, 2 insertions(+), 2 deletions(-) diff --git a/management/server/http/policies_handler.go b/management/server/http/policies_handler.go index ca256a183..eff9092d4 100644 --- a/management/server/http/policies_handler.go +++ b/management/server/http/policies_handler.go @@ -128,7 +128,7 @@ func (h *Policies) savePolicy(w http.ResponseWriter, r *http.Request, accountID Description: req.Description, } for _, rule := range req.Rules { - ruleID := policyID // TODO: when policy can contain multiple rules, need refactor + var ruleID string if rule.Id != nil { ruleID = *rule.Id } diff --git a/management/server/policy.go b/management/server/policy.go index 693ae2872..2d3abc3f1 100644 --- a/management/server/policy.go +++ b/management/server/policy.go @@ -532,7 +532,7 @@ func validatePolicy(ctx context.Context, transaction Store, accountID string, po for i, rule := range policy.Rules { ruleCopy := rule.Copy() if ruleCopy.ID == "" { - ruleCopy.ID = xid.New().String() + ruleCopy.ID = policy.ID // TODO: when policy can contain multiple rules, need refactor ruleCopy.PolicyID = policy.ID } From 52ea2e84e9aa3c03fe43c5098ab50c94ff2e0818 Mon Sep 17 00:00:00 2001 From: Pascal Fischer <32096965+pascal-fischer@users.noreply.github.com> Date: Tue, 19 Nov 2024 00:04:50 +0100 Subject: [PATCH 23/39] [management] Add transaction metrics and exclude getAccount time from peers update (#2904) --- management/server/peer.go | 11 ++++++----- management/server/sql_store.go | 11 ++++++++++- management/server/telemetry/store_metrics.go | 12 ++++++++++++ 3 files changed, 28 insertions(+), 6 deletions(-) diff --git a/management/server/peer.go b/management/server/peer.go index 1405dead8..8c45e45c9 100644 --- a/management/server/peer.go +++ b/management/server/peer.go @@ -988,6 +988,12 @@ func (am *DefaultAccountManager) GetPeer(ctx context.Context, accountID, peerID, // updateAccountPeers updates all peers that belong to an account. // Should be called when changes have to be synced to peers. func (am *DefaultAccountManager) updateAccountPeers(ctx context.Context, accountID string) { + account, err := am.requestBuffer.GetAccountWithBackpressure(ctx, accountID) + if err != nil { + log.WithContext(ctx).Errorf("failed to send out updates to peers: %v", err) + return + } + start := time.Now() defer func() { if am.metrics != nil { @@ -995,11 +1001,6 @@ func (am *DefaultAccountManager) updateAccountPeers(ctx context.Context, account } }() - account, err := am.requestBuffer.GetAccountWithBackpressure(ctx, accountID) - if err != nil { - log.WithContext(ctx).Errorf("failed to send out updates to peers: %v", err) - return - } peers := account.GetPeers() approvedPeersMap, err := am.GetValidatedPeers(account) diff --git a/management/server/sql_store.go b/management/server/sql_store.go index 0ebda6440..278f5443d 100644 --- a/management/server/sql_store.go +++ b/management/server/sql_store.go @@ -1123,6 +1123,7 @@ func (s *SqlStore) IncrementNetworkSerial(ctx context.Context, lockStrength Lock } func (s *SqlStore) ExecuteInTransaction(ctx context.Context, operation func(store Store) error) error { + startTime := time.Now() tx := s.db.Begin() if tx.Error != nil { return tx.Error @@ -1133,7 +1134,15 @@ func (s *SqlStore) ExecuteInTransaction(ctx context.Context, operation func(stor tx.Rollback() return err } - return tx.Commit().Error + + err = tx.Commit().Error + + log.WithContext(ctx).Tracef("transaction took %v", time.Since(startTime)) + if s.metrics != nil { + s.metrics.StoreMetrics().CountTransactionDuration(time.Since(startTime)) + } + + return err } func (s *SqlStore) withTx(tx *gorm.DB) Store { diff --git a/management/server/telemetry/store_metrics.go b/management/server/telemetry/store_metrics.go index b038c3d36..bb3745b5a 100644 --- a/management/server/telemetry/store_metrics.go +++ b/management/server/telemetry/store_metrics.go @@ -13,6 +13,7 @@ type StoreMetrics struct { globalLockAcquisitionDurationMs metric.Int64Histogram persistenceDurationMicro metric.Int64Histogram persistenceDurationMs metric.Int64Histogram + transactionDurationMs metric.Int64Histogram ctx context.Context } @@ -40,11 +41,17 @@ func NewStoreMetrics(ctx context.Context, meter metric.Meter) (*StoreMetrics, er return nil, err } + transactionDurationMs, err := meter.Int64Histogram("management.store.transaction.duration.ms") + if err != nil { + return nil, err + } + return &StoreMetrics{ globalLockAcquisitionDurationMicro: globalLockAcquisitionDurationMicro, globalLockAcquisitionDurationMs: globalLockAcquisitionDurationMs, persistenceDurationMicro: persistenceDurationMicro, persistenceDurationMs: persistenceDurationMs, + transactionDurationMs: transactionDurationMs, ctx: ctx, }, nil } @@ -60,3 +67,8 @@ func (metrics *StoreMetrics) CountPersistenceDuration(duration time.Duration) { metrics.persistenceDurationMicro.Record(metrics.ctx, duration.Microseconds()) metrics.persistenceDurationMs.Record(metrics.ctx, duration.Milliseconds()) } + +// CountTransactionDuration counts the duration of a store persistence operation +func (metrics *StoreMetrics) CountTransactionDuration(duration time.Duration) { + metrics.transactionDurationMs.Record(metrics.ctx, duration.Milliseconds()) +} From eb5d0569ae0ce829a312e13ab3e9757b9cdf019f Mon Sep 17 00:00:00 2001 From: "Krzysztof Nazarewski (kdn)" Date: Tue, 19 Nov 2024 14:14:58 +0100 Subject: [PATCH 24/39] [client] Add NB_SKIP_SOCKET_MARK & fix crash instead of returing an error (#2899) * dialer: fix crash instead of returning error * add NB_SKIP_SOCKET_MARK --- .../routemanager/systemops/systemops_linux.go | 2 +- util/grpc/dialer.go | 9 +++++++-- util/net/dialer_nonios.go | 2 +- util/net/net_linux.go | 12 ++++++++++++ 4 files changed, 21 insertions(+), 4 deletions(-) diff --git a/client/internal/routemanager/systemops/systemops_linux.go b/client/internal/routemanager/systemops/systemops_linux.go index 0124fd95e..71a0f26ae 100644 --- a/client/internal/routemanager/systemops/systemops_linux.go +++ b/client/internal/routemanager/systemops/systemops_linux.go @@ -55,7 +55,7 @@ type ruleParams struct { // isLegacy determines whether to use the legacy routing setup func isLegacy() bool { - return os.Getenv("NB_USE_LEGACY_ROUTING") == "true" || nbnet.CustomRoutingDisabled() + return os.Getenv("NB_USE_LEGACY_ROUTING") == "true" || nbnet.CustomRoutingDisabled() || os.Getenv(nbnet.EnvSkipSocketMark) == "true" } // setIsLegacy sets the legacy routing setup diff --git a/util/grpc/dialer.go b/util/grpc/dialer.go index 57ab8fd55..4fbffe342 100644 --- a/util/grpc/dialer.go +++ b/util/grpc/dialer.go @@ -3,6 +3,9 @@ package grpc import ( "context" "crypto/tls" + "fmt" + "google.golang.org/grpc/codes" + "google.golang.org/grpc/status" "net" "os/user" "runtime" @@ -23,20 +26,22 @@ func WithCustomDialer() grpc.DialOption { if runtime.GOOS == "linux" { currentUser, err := user.Current() if err != nil { - log.Fatalf("failed to get current user: %v", err) + return nil, status.Errorf(codes.FailedPrecondition, "failed to get current user: %v", err) } // the custom dialer requires root permissions which are not required for use cases run as non-root if currentUser.Uid != "0" { + log.Debug("Not running as root, using standard dialer") dialer := &net.Dialer{} return dialer.DialContext(ctx, "tcp", addr) } } + log.Debug("Using nbnet.NewDialer()") conn, err := nbnet.NewDialer().DialContext(ctx, "tcp", addr) if err != nil { log.Errorf("Failed to dial: %s", err) - return nil, err + return nil, fmt.Errorf("nbnet.NewDialer().DialContext: %w", err) } return conn, nil }) diff --git a/util/net/dialer_nonios.go b/util/net/dialer_nonios.go index 4032a75c0..34004a368 100644 --- a/util/net/dialer_nonios.go +++ b/util/net/dialer_nonios.go @@ -69,7 +69,7 @@ func (d *Dialer) DialContext(ctx context.Context, network, address string) (net. conn, err := d.Dialer.DialContext(ctx, network, address) if err != nil { - return nil, fmt.Errorf("dial: %w", err) + return nil, fmt.Errorf("d.Dialer.DialContext: %w", err) } // Wrap the connection in Conn to handle Close with hooks diff --git a/util/net/net_linux.go b/util/net/net_linux.go index 954545eb5..98f49af8d 100644 --- a/util/net/net_linux.go +++ b/util/net/net_linux.go @@ -4,9 +4,14 @@ package net import ( "fmt" + "os" "syscall" + + log "github.com/sirupsen/logrus" ) +const EnvSkipSocketMark = "NB_SKIP_SOCKET_MARK" + // SetSocketMark sets the SO_MARK option on the given socket connection func SetSocketMark(conn syscall.Conn) error { sysconn, err := conn.SyscallConn() @@ -36,6 +41,13 @@ func SetRawSocketMark(conn syscall.RawConn) error { func SetSocketOpt(fd int) error { if CustomRoutingDisabled() { + log.Infof("Custom routing is disabled, skipping SO_MARK") + return nil + } + + // Check for the new environment variable + if skipSocketMark := os.Getenv(EnvSkipSocketMark); skipSocketMark == "true" { + log.Info("NB_SKIP_SOCKET_MARK is set to true, skipping SO_MARK") return nil } From 5dd6a08ea6926cb1cb87a3301524b9031f4db61a Mon Sep 17 00:00:00 2001 From: Pascal Fischer <32096965+pascal-fischer@users.noreply.github.com> Date: Tue, 19 Nov 2024 17:25:49 +0100 Subject: [PATCH 25/39] link peer meta update back to account object (#2911) --- management/server/peer.go | 1 + 1 file changed, 1 insertion(+) diff --git a/management/server/peer.go b/management/server/peer.go index 8c45e45c9..901e4815d 100644 --- a/management/server/peer.go +++ b/management/server/peer.go @@ -667,6 +667,7 @@ func (am *DefaultAccountManager) SyncPeer(ctx context.Context, sync PeerSync, ac updated := peer.UpdateMetaIfNew(sync.Meta) if updated { + account.Peers[peer.ID] = peer log.WithContext(ctx).Tracef("peer %s metadata updated", peer.ID) err = am.Store.SavePeer(ctx, account.Id, peer) if err != nil { From f66bbcc54c65f0856679fed1b50298e97c0ec7af Mon Sep 17 00:00:00 2001 From: Pascal Fischer <32096965+pascal-fischer@users.noreply.github.com> Date: Tue, 19 Nov 2024 18:13:26 +0100 Subject: [PATCH 26/39] [management] Add metric for peer meta update (#2913) --- management/server/peer.go | 2 ++ .../server/telemetry/accountmanager_metrics.go | 12 ++++++++++++ 2 files changed, 14 insertions(+) diff --git a/management/server/peer.go b/management/server/peer.go index 901e4815d..beb833dba 100644 --- a/management/server/peer.go +++ b/management/server/peer.go @@ -667,6 +667,7 @@ func (am *DefaultAccountManager) SyncPeer(ctx context.Context, sync PeerSync, ac updated := peer.UpdateMetaIfNew(sync.Meta) if updated { + am.metrics.AccountManagerMetrics().CountPeerMetUpdate() account.Peers[peer.ID] = peer log.WithContext(ctx).Tracef("peer %s metadata updated", peer.ID) err = am.Store.SavePeer(ctx, account.Id, peer) @@ -801,6 +802,7 @@ func (am *DefaultAccountManager) LoginPeer(ctx context.Context, login PeerLogin) updated := peer.UpdateMetaIfNew(login.Meta) if updated { + am.metrics.AccountManagerMetrics().CountPeerMetUpdate() shouldStorePeer = true } diff --git a/management/server/telemetry/accountmanager_metrics.go b/management/server/telemetry/accountmanager_metrics.go index e4bb4e3c3..4a5a31e2d 100644 --- a/management/server/telemetry/accountmanager_metrics.go +++ b/management/server/telemetry/accountmanager_metrics.go @@ -13,6 +13,7 @@ type AccountManagerMetrics struct { updateAccountPeersDurationMs metric.Float64Histogram getPeerNetworkMapDurationMs metric.Float64Histogram networkMapObjectCount metric.Int64Histogram + peerMetaUpdateCount metric.Int64Counter } // NewAccountManagerMetrics creates an instance of AccountManagerMetrics @@ -44,11 +45,17 @@ func NewAccountManagerMetrics(ctx context.Context, meter metric.Meter) (*Account return nil, err } + peerMetaUpdateCount, err := meter.Int64Counter("management.account.peer.meta.update.counter", metric.WithUnit("1")) + if err != nil { + return nil, err + } + return &AccountManagerMetrics{ ctx: ctx, getPeerNetworkMapDurationMs: getPeerNetworkMapDurationMs, updateAccountPeersDurationMs: updateAccountPeersDurationMs, networkMapObjectCount: networkMapObjectCount, + peerMetaUpdateCount: peerMetaUpdateCount, }, nil } @@ -67,3 +74,8 @@ func (metrics *AccountManagerMetrics) CountGetPeerNetworkMapDuration(duration ti func (metrics *AccountManagerMetrics) CountNetworkMapObjects(count int64) { metrics.networkMapObjectCount.Record(metrics.ctx, count) } + +// CountPeerMetUpdate counts the number of peer meta updates +func (metrics *AccountManagerMetrics) CountPeerMetUpdate() { + metrics.peerMetaUpdateCount.Add(metrics.ctx, 1) +} From aa575d6f445e74f34f8353a4c413adc209c56f4b Mon Sep 17 00:00:00 2001 From: Pascal Fischer <32096965+pascal-fischer@users.noreply.github.com> Date: Thu, 21 Nov 2024 15:10:34 +0100 Subject: [PATCH 27/39] [management] Add activity events to group propagation flow (#2916) --- management/server/account.go | 38 ++++++++++- management/server/activity/codes.go | 6 ++ management/server/user.go | 97 ++++++++++++++++++++++------- 3 files changed, 116 insertions(+), 25 deletions(-) diff --git a/management/server/account.go b/management/server/account.go index 95c93a22b..0ab123655 100644 --- a/management/server/account.go +++ b/management/server/account.go @@ -965,7 +965,9 @@ func (am *DefaultAccountManager) getJWTGroupsChanges(user *User, groups []*nbgro } // UserGroupsAddToPeers adds groups to all peers of user -func (a *Account) UserGroupsAddToPeers(userID string, groups ...string) { +func (a *Account) UserGroupsAddToPeers(userID string, groups ...string) map[string][]string { + groupUpdates := make(map[string][]string) + userPeers := make(map[string]struct{}) for pid, peer := range a.Peers { if peer.UserID == userID { @@ -979,6 +981,8 @@ func (a *Account) UserGroupsAddToPeers(userID string, groups ...string) { continue } + oldPeers := group.Peers + groupPeers := make(map[string]struct{}) for _, pid := range group.Peers { groupPeers[pid] = struct{}{} @@ -992,16 +996,25 @@ func (a *Account) UserGroupsAddToPeers(userID string, groups ...string) { for pid := range groupPeers { group.Peers = append(group.Peers, pid) } + + groupUpdates[gid] = difference(group.Peers, oldPeers) } + + return groupUpdates } // UserGroupsRemoveFromPeers removes groups from all peers of user -func (a *Account) UserGroupsRemoveFromPeers(userID string, groups ...string) { +func (a *Account) UserGroupsRemoveFromPeers(userID string, groups ...string) map[string][]string { + groupUpdates := make(map[string][]string) + for _, gid := range groups { group, ok := a.Groups[gid] if !ok || group.Name == "All" { continue } + + oldPeers := group.Peers + update := make([]string, 0, len(group.Peers)) for _, pid := range group.Peers { peer, ok := a.Peers[pid] @@ -1013,7 +1026,10 @@ func (a *Account) UserGroupsRemoveFromPeers(userID string, groups ...string) { } } group.Peers = update + groupUpdates[gid] = difference(oldPeers, group.Peers) } + + return groupUpdates } // BuildManager creates a new DefaultAccountManager with a provided Store @@ -1175,6 +1191,11 @@ func (am *DefaultAccountManager) UpdateAccountSettings(ctx context.Context, acco return nil, err } + err = am.handleGroupsPropagationSettings(ctx, oldSettings, newSettings, userID, accountID) + if err != nil { + return nil, fmt.Errorf("groups propagation failed: %w", err) + } + updatedAccount := account.UpdateSettings(newSettings) err = am.Store.SaveAccount(ctx, account) @@ -1185,6 +1206,19 @@ func (am *DefaultAccountManager) UpdateAccountSettings(ctx context.Context, acco return updatedAccount, nil } +func (am *DefaultAccountManager) handleGroupsPropagationSettings(ctx context.Context, oldSettings, newSettings *Settings, userID, accountID string) error { + if oldSettings.GroupsPropagationEnabled != newSettings.GroupsPropagationEnabled { + if newSettings.GroupsPropagationEnabled { + am.StoreEvent(ctx, userID, accountID, accountID, activity.UserGroupPropagationEnabled, nil) + // Todo: retroactively add user groups to all peers + } else { + am.StoreEvent(ctx, userID, accountID, accountID, activity.UserGroupPropagationDisabled, nil) + } + } + + return nil +} + func (am *DefaultAccountManager) handleInactivityExpirationSettings(ctx context.Context, account *Account, oldSettings, newSettings *Settings, userID, accountID string) error { if newSettings.PeerInactivityExpirationEnabled { diff --git a/management/server/activity/codes.go b/management/server/activity/codes.go index 603260dbc..4c57d65fb 100644 --- a/management/server/activity/codes.go +++ b/management/server/activity/codes.go @@ -148,6 +148,9 @@ const ( AccountPeerInactivityExpirationDurationUpdated Activity = 67 SetupKeyDeleted Activity = 68 + + UserGroupPropagationEnabled Activity = 69 + UserGroupPropagationDisabled Activity = 70 ) var activityMap = map[Activity]Code{ @@ -222,6 +225,9 @@ var activityMap = map[Activity]Code{ AccountPeerInactivityExpirationDisabled: {"Account peer inactivity expiration disabled", "account.peer.inactivity.expiration.disable"}, AccountPeerInactivityExpirationDurationUpdated: {"Account peer inactivity expiration duration updated", "account.peer.inactivity.expiration.update"}, SetupKeyDeleted: {"Setup key deleted", "setupkey.delete"}, + + UserGroupPropagationEnabled: {"User group propagation enabled", "account.setting.group.propagation.enable"}, + UserGroupPropagationDisabled: {"User group propagation disabled", "account.setting.group.propagation.disable"}, } // StringCode returns a string code of the activity diff --git a/management/server/user.go b/management/server/user.go index 74062112a..edb5e6fd3 100644 --- a/management/server/user.go +++ b/management/server/user.go @@ -805,15 +805,20 @@ func (am *DefaultAccountManager) SaveOrAddUsers(ctx context.Context, accountID, expiredPeers = append(expiredPeers, blockedPeers...) } + peerGroupsAdded := make(map[string][]string) + peerGroupsRemoved := make(map[string][]string) if update.AutoGroups != nil && account.Settings.GroupsPropagationEnabled { removedGroups := difference(oldUser.AutoGroups, update.AutoGroups) // need force update all auto groups in any case they will not be duplicated - account.UserGroupsAddToPeers(oldUser.Id, update.AutoGroups...) - account.UserGroupsRemoveFromPeers(oldUser.Id, removedGroups...) + peerGroupsAdded = account.UserGroupsAddToPeers(oldUser.Id, update.AutoGroups...) + peerGroupsRemoved = account.UserGroupsRemoveFromPeers(oldUser.Id, removedGroups...) } - events := am.prepareUserUpdateEvents(ctx, initiatorUser.Id, oldUser, newUser, account, transferredOwnerRole) - eventsToStore = append(eventsToStore, events...) + userUpdateEvents := am.prepareUserUpdateEvents(ctx, initiatorUser.Id, oldUser, newUser, account, transferredOwnerRole) + eventsToStore = append(eventsToStore, userUpdateEvents...) + + userGroupsEvents := am.prepareUserGroupsEvents(ctx, initiatorUser.Id, oldUser, newUser, account, peerGroupsAdded, peerGroupsRemoved) + eventsToStore = append(eventsToStore, userGroupsEvents...) updatedUserInfo, err := getUserInfo(ctx, am, newUser, account) if err != nil { @@ -872,32 +877,78 @@ func (am *DefaultAccountManager) prepareUserUpdateEvents(ctx context.Context, in }) } + return eventsToStore +} + +func (am *DefaultAccountManager) prepareUserGroupsEvents(ctx context.Context, initiatorUserID string, oldUser, newUser *User, account *Account, peerGroupsAdded, peerGroupsRemoved map[string][]string) []func() { + var eventsToStore []func() if newUser.AutoGroups != nil { removedGroups := difference(oldUser.AutoGroups, newUser.AutoGroups) addedGroups := difference(newUser.AutoGroups, oldUser.AutoGroups) - for _, g := range removedGroups { - group := account.GetGroup(g) - if group != nil { - eventsToStore = append(eventsToStore, func() { - am.StoreEvent(ctx, initiatorUserID, oldUser.Id, account.Id, activity.GroupRemovedFromUser, - map[string]any{"group": group.Name, "group_id": group.ID, "is_service_user": newUser.IsServiceUser, "user_name": newUser.ServiceUserName}) - }) - } else { - log.WithContext(ctx).Errorf("group %s not found while saving user activity event of account %s", g, account.Id) - } - } - for _, g := range addedGroups { - group := account.GetGroup(g) - if group != nil { - eventsToStore = append(eventsToStore, func() { - am.StoreEvent(ctx, initiatorUserID, oldUser.Id, account.Id, activity.GroupAddedToUser, - map[string]any{"group": group.Name, "group_id": group.ID, "is_service_user": newUser.IsServiceUser, "user_name": newUser.ServiceUserName}) - }) - } + removedEvents := am.handleGroupRemovedFromUser(ctx, initiatorUserID, oldUser, newUser, account, removedGroups, peerGroupsRemoved) + eventsToStore = append(eventsToStore, removedEvents...) + + addedEvents := am.handleGroupAddedToUser(ctx, initiatorUserID, oldUser, newUser, account, addedGroups, peerGroupsAdded) + eventsToStore = append(eventsToStore, addedEvents...) + } + return eventsToStore +} + +func (am *DefaultAccountManager) handleGroupAddedToUser(ctx context.Context, initiatorUserID string, oldUser, newUser *User, account *Account, addedGroups []string, peerGroupsAdded map[string][]string) []func() { + var eventsToStore []func() + for _, g := range addedGroups { + group := account.GetGroup(g) + if group != nil { + eventsToStore = append(eventsToStore, func() { + am.StoreEvent(ctx, initiatorUserID, oldUser.Id, account.Id, activity.GroupAddedToUser, + map[string]any{"group": group.Name, "group_id": group.ID, "is_service_user": newUser.IsServiceUser, "user_name": newUser.ServiceUserName}) + }) } } + for groupID, peerIDs := range peerGroupsAdded { + group := account.GetGroup(groupID) + for _, peerID := range peerIDs { + peer := account.GetPeer(peerID) + eventsToStore = append(eventsToStore, func() { + meta := map[string]any{ + "group": group.Name, "group_id": group.ID, + "peer_ip": peer.IP.String(), "peer_fqdn": peer.FQDN(am.GetDNSDomain()), + } + am.StoreEvent(ctx, activity.SystemInitiator, peer.ID, account.Id, activity.GroupAddedToPeer, meta) + }) + } + } + return eventsToStore +} +func (am *DefaultAccountManager) handleGroupRemovedFromUser(ctx context.Context, initiatorUserID string, oldUser, newUser *User, account *Account, removedGroups []string, peerGroupsRemoved map[string][]string) []func() { + var eventsToStore []func() + for _, g := range removedGroups { + group := account.GetGroup(g) + if group != nil { + eventsToStore = append(eventsToStore, func() { + am.StoreEvent(ctx, initiatorUserID, oldUser.Id, account.Id, activity.GroupRemovedFromUser, + map[string]any{"group": group.Name, "group_id": group.ID, "is_service_user": newUser.IsServiceUser, "user_name": newUser.ServiceUserName}) + }) + + } else { + log.WithContext(ctx).Errorf("group %s not found while saving user activity event of account %s", g, account.Id) + } + } + for groupID, peerIDs := range peerGroupsRemoved { + group := account.GetGroup(groupID) + for _, peerID := range peerIDs { + peer := account.GetPeer(peerID) + eventsToStore = append(eventsToStore, func() { + meta := map[string]any{ + "group": group.Name, "group_id": group.ID, + "peer_ip": peer.IP.String(), "peer_fqdn": peer.FQDN(am.GetDNSDomain()), + } + am.StoreEvent(ctx, activity.SystemInitiator, peer.ID, account.Id, activity.GroupRemovedFromPeer, meta) + }) + } + } return eventsToStore } From 1bbabf70b057c2384a077742b9e6760161e153aa Mon Sep 17 00:00:00 2001 From: Viktor Liu <17948409+lixmal@users.noreply.github.com> Date: Thu, 21 Nov 2024 16:53:37 +0100 Subject: [PATCH 28/39] [client] Fix allow netbird rule verdict (#2925) * Fix allow netbird rule verdict * Fix chain name --- client/firewall/nftables/manager_linux.go | 8 +++++--- 1 file changed, 5 insertions(+), 3 deletions(-) diff --git a/client/firewall/nftables/manager_linux.go b/client/firewall/nftables/manager_linux.go index 3f8fac249..8e1aa0d80 100644 --- a/client/firewall/nftables/manager_linux.go +++ b/client/firewall/nftables/manager_linux.go @@ -199,7 +199,7 @@ func (m *Manager) AllowNetbird() error { var chain *nftables.Chain for _, c := range chains { - if c.Table.Name == tableNameFilter && c.Name == chainNameForward { + if c.Table.Name == tableNameFilter && c.Name == chainNameInput { chain = c break } @@ -276,7 +276,7 @@ func (m *Manager) resetNetbirdInputRules() error { func (m *Manager) deleteNetbirdInputRules(chains []*nftables.Chain) { for _, c := range chains { - if c.Table.Name == "filter" && c.Name == "INPUT" { + if c.Table.Name == tableNameFilter && c.Name == chainNameInput { rules, err := m.rConn.GetRules(c.Table, c) if err != nil { log.Errorf("get rules for chain %q: %v", c.Name, err) @@ -351,7 +351,9 @@ func (m *Manager) applyAllowNetbirdRules(chain *nftables.Chain) { Register: 1, Data: ifname(m.wgIface.Name()), }, - &expr.Verdict{}, + &expr.Verdict{ + Kind: expr.VerdictAccept, + }, }, UserData: []byte(allowNetbirdInputRuleID), } From 9db1932664557da94ff64bfc03f5acdcf30667ea Mon Sep 17 00:00:00 2001 From: Pascal Fischer <32096965+pascal-fischer@users.noreply.github.com> Date: Fri, 22 Nov 2024 10:15:51 +0100 Subject: [PATCH 29/39] [management] Fix getSetupKey call (#2927) --- management/server/http/api/openapi.yml | 30 +++++++-- management/server/http/api/types.gen.go | 89 ++++++++++++++++++++++++- management/server/setupkey.go | 2 +- management/server/setupkey_test.go | 43 ++++++++---- 4 files changed, 144 insertions(+), 20 deletions(-) diff --git a/management/server/http/api/openapi.yml b/management/server/http/api/openapi.yml index bfb375277..2e084f6e4 100644 --- a/management/server/http/api/openapi.yml +++ b/management/server/http/api/openapi.yml @@ -439,17 +439,13 @@ components: example: 5 required: - accessible_peers_count - SetupKey: + SetupKeyBase: type: object properties: id: description: Setup Key ID type: string example: 2531583362 - key: - description: Setup Key value - type: string - example: A616097E-FCF0-48FA-9354-CA4A61142761 name: description: Setup key name identifier type: string @@ -518,6 +514,28 @@ components: - updated_at - usage_limit - ephemeral + SetupKeyClear: + allOf: + - $ref: '#/components/schemas/SetupKeyBase' + - type: object + properties: + key: + description: Setup Key as plain text + type: string + example: A616097E-FCF0-48FA-9354-CA4A61142761 + required: + - key + SetupKey: + allOf: + - $ref: '#/components/schemas/SetupKeyBase' + - type: object + properties: + key: + description: Setup Key as secret + type: string + example: A6160**** + required: + - key SetupKeyRequest: type: object properties: @@ -1918,7 +1936,7 @@ paths: content: application/json: schema: - $ref: '#/components/schemas/SetupKey' + $ref: '#/components/schemas/SetupKeyClear' '400': "$ref": "#/components/responses/bad_request" '401': diff --git a/management/server/http/api/types.gen.go b/management/server/http/api/types.gen.go index f219c4574..321395d25 100644 --- a/management/server/http/api/types.gen.go +++ b/management/server/http/api/types.gen.go @@ -1062,7 +1062,94 @@ type SetupKey struct { // Id Setup Key ID Id string `json:"id"` - // Key Setup Key value + // Key Setup Key as secret + Key string `json:"key"` + + // LastUsed Setup key last usage date + LastUsed time.Time `json:"last_used"` + + // Name Setup key name identifier + Name string `json:"name"` + + // Revoked Setup key revocation status + Revoked bool `json:"revoked"` + + // State Setup key status, "valid", "overused","expired" or "revoked" + State string `json:"state"` + + // Type Setup key type, one-off for single time usage and reusable + Type string `json:"type"` + + // UpdatedAt Setup key last update date + UpdatedAt time.Time `json:"updated_at"` + + // UsageLimit A number of times this key can be used. The value of 0 indicates the unlimited usage. + UsageLimit int `json:"usage_limit"` + + // UsedTimes Usage count of setup key + UsedTimes int `json:"used_times"` + + // Valid Setup key validity status + Valid bool `json:"valid"` +} + +// SetupKeyBase defines model for SetupKeyBase. +type SetupKeyBase struct { + // AutoGroups List of group IDs to auto-assign to peers registered with this key + AutoGroups []string `json:"auto_groups"` + + // Ephemeral Indicate that the peer will be ephemeral or not + Ephemeral bool `json:"ephemeral"` + + // Expires Setup Key expiration date + Expires time.Time `json:"expires"` + + // Id Setup Key ID + Id string `json:"id"` + + // LastUsed Setup key last usage date + LastUsed time.Time `json:"last_used"` + + // Name Setup key name identifier + Name string `json:"name"` + + // Revoked Setup key revocation status + Revoked bool `json:"revoked"` + + // State Setup key status, "valid", "overused","expired" or "revoked" + State string `json:"state"` + + // Type Setup key type, one-off for single time usage and reusable + Type string `json:"type"` + + // UpdatedAt Setup key last update date + UpdatedAt time.Time `json:"updated_at"` + + // UsageLimit A number of times this key can be used. The value of 0 indicates the unlimited usage. + UsageLimit int `json:"usage_limit"` + + // UsedTimes Usage count of setup key + UsedTimes int `json:"used_times"` + + // Valid Setup key validity status + Valid bool `json:"valid"` +} + +// SetupKeyClear defines model for SetupKeyClear. +type SetupKeyClear struct { + // AutoGroups List of group IDs to auto-assign to peers registered with this key + AutoGroups []string `json:"auto_groups"` + + // Ephemeral Indicate that the peer will be ephemeral or not + Ephemeral bool `json:"ephemeral"` + + // Expires Setup Key expiration date + Expires time.Time `json:"expires"` + + // Id Setup Key ID + Id string `json:"id"` + + // Key Setup Key as plain text Key string `json:"key"` // LastUsed Setup key last usage date diff --git a/management/server/setupkey.go b/management/server/setupkey.go index cae0dfecb..ef431d3ad 100644 --- a/management/server/setupkey.go +++ b/management/server/setupkey.go @@ -379,7 +379,7 @@ func (am *DefaultAccountManager) GetSetupKey(ctx context.Context, accountID, use return nil, status.NewAdminPermissionError() } - setupKey, err := am.Store.GetSetupKeyByID(ctx, LockingStrengthShare, keyID, accountID) + setupKey, err := am.Store.GetSetupKeyByID(ctx, LockingStrengthShare, accountID, keyID) if err != nil { return nil, err } diff --git a/management/server/setupkey_test.go b/management/server/setupkey_test.go index 94ed022fa..7c8200706 100644 --- a/management/server/setupkey_test.go +++ b/management/server/setupkey_test.go @@ -210,22 +210,41 @@ func TestGetSetupKeys(t *testing.T) { t.Fatal(err) } - err = manager.SaveGroup(context.Background(), account.Id, userID, &nbgroup.Group{ - ID: "group_1", - Name: "group_name_1", - Peers: []string{}, - }) + plainKey, err := manager.CreateSetupKey(context.Background(), account.Id, "key1", SetupKeyReusable, time.Hour, nil, SetupKeyUnlimitedUsage, userID, false) if err != nil { t.Fatal(err) } - err = manager.SaveGroup(context.Background(), account.Id, userID, &nbgroup.Group{ - ID: "group_2", - Name: "group_name_2", - Peers: []string{}, - }) - if err != nil { - t.Fatal(err) + type testCase struct { + name string + keyId string + expectedFailure bool + } + + testCase1 := testCase{ + name: "Should get existing Setup Key", + keyId: plainKey.Id, + expectedFailure: false, + } + testCase2 := testCase{ + name: "Should fail to get non-existent Setup Key", + keyId: "some key", + expectedFailure: true, + } + + for _, tCase := range []testCase{testCase1, testCase2} { + t.Run(tCase.name, func(t *testing.T) { + key, err := manager.GetSetupKey(context.Background(), account.Id, userID, tCase.keyId) + + if tCase.expectedFailure { + if err == nil { + t.Fatal("expected to fail") + } + return + } + + assert.NotEqual(t, plainKey.Key, key.Key) + }) } } From 2a5cb1649402d42f588d374f6a775c62e92a5522 Mon Sep 17 00:00:00 2001 From: Zoltan Papp Date: Fri, 22 Nov 2024 18:12:34 +0100 Subject: [PATCH 30/39] [relay] Refactor initial Relay connection (#2800) Can support firewalls with restricted WS rules allow to run engine without Relay servers keep up to date Relay address changes --- client/internal/connect.go | 3 +- client/internal/engine.go | 10 ++- client/internal/peer/status.go | 20 ++--- client/internal/peer/worker_ice.go | 14 +-- relay/client/client.go | 6 +- relay/client/client_test.go | 2 +- relay/client/guard.go | 91 +++++++++++++++---- relay/client/manager.go | 140 +++++++++++++++++++---------- relay/client/picker.go | 16 ++-- relay/client/picker_test.go | 5 +- 10 files changed, 211 insertions(+), 96 deletions(-) diff --git a/client/internal/connect.go b/client/internal/connect.go index f76aa066b..8c2ad4aa1 100644 --- a/client/internal/connect.go +++ b/client/internal/connect.go @@ -232,6 +232,7 @@ func (c *ConnectClient) run(mobileDependency MobileDependency, probes *ProbeHold relayURLs, token := parseRelayInfo(loginResp) relayManager := relayClient.NewManager(engineCtx, relayURLs, myPrivateKey.PublicKey().String()) + c.statusRecorder.SetRelayMgr(relayManager) if len(relayURLs) > 0 { if token != nil { if err := relayManager.UpdateToken(token); err != nil { @@ -242,9 +243,7 @@ func (c *ConnectClient) run(mobileDependency MobileDependency, probes *ProbeHold log.Infof("connecting to the Relay service(s): %s", strings.Join(relayURLs, ", ")) if err = relayManager.Serve(); err != nil { log.Error(err) - return wrapErr(err) } - c.statusRecorder.SetRelayMgr(relayManager) } peerConfig := loginResp.GetPeerConfig() diff --git a/client/internal/engine.go b/client/internal/engine.go index 1c912220c..dc4499e17 100644 --- a/client/internal/engine.go +++ b/client/internal/engine.go @@ -538,6 +538,7 @@ func (e *Engine) handleSync(update *mgmProto.SyncResponse) error { relayMsg := wCfg.GetRelay() if relayMsg != nil { + // when we receive token we expect valid address list too c := &auth.Token{ Payload: relayMsg.GetTokenPayload(), Signature: relayMsg.GetTokenSignature(), @@ -546,9 +547,16 @@ func (e *Engine) handleSync(update *mgmProto.SyncResponse) error { log.Errorf("failed to update relay token: %v", err) return fmt.Errorf("update relay token: %w", err) } + + e.relayManager.UpdateServerURLs(relayMsg.Urls) + + // Just in case the agent started with an MGM server where the relay was disabled but was later enabled. + // We can ignore all errors because the guard will manage the reconnection retries. + _ = e.relayManager.Serve() + } else { + e.relayManager.UpdateServerURLs(nil) } - // todo update relay address in the relay manager // todo update signal } diff --git a/client/internal/peer/status.go b/client/internal/peer/status.go index 0444dc60b..74e2ee82c 100644 --- a/client/internal/peer/status.go +++ b/client/internal/peer/status.go @@ -676,25 +676,23 @@ func (d *Status) GetRelayStates() []relay.ProbeResult { // extend the list of stun, turn servers with relay address relayStates := slices.Clone(d.relayStates) - var relayState relay.ProbeResult - // if the server connection is not established then we will use the general address // in case of connection we will use the instance specific address instanceAddr, err := d.relayMgr.RelayInstanceAddress() if err != nil { // TODO add their status - if errors.Is(err, relayClient.ErrRelayClientNotConnected) { - for _, r := range d.relayMgr.ServerURLs() { - relayStates = append(relayStates, relay.ProbeResult{ - URI: r, - }) - } - return relayStates + for _, r := range d.relayMgr.ServerURLs() { + relayStates = append(relayStates, relay.ProbeResult{ + URI: r, + Err: err, + }) } - relayState.Err = err + return relayStates } - relayState.URI = instanceAddr + relayState := relay.ProbeResult{ + URI: instanceAddr, + } return append(relayStates, relayState) } diff --git a/client/internal/peer/worker_ice.go b/client/internal/peer/worker_ice.go index 4c67cb781..7ce4797c3 100644 --- a/client/internal/peer/worker_ice.go +++ b/client/internal/peer/worker_ice.go @@ -46,8 +46,6 @@ type WorkerICE struct { hasRelayOnLocally bool conn WorkerICECallbacks - selectedPriority ConnPriority - agent *ice.Agent muxAgent sync.Mutex @@ -95,10 +93,8 @@ func (w *WorkerICE) OnNewOffer(remoteOfferAnswer *OfferAnswer) { var preferredCandidateTypes []ice.CandidateType if w.hasRelayOnLocally && remoteOfferAnswer.RelaySrvAddress != "" { - w.selectedPriority = connPriorityICEP2P preferredCandidateTypes = icemaker.CandidateTypesP2P() } else { - w.selectedPriority = connPriorityICETurn preferredCandidateTypes = icemaker.CandidateTypes() } @@ -159,7 +155,7 @@ func (w *WorkerICE) OnNewOffer(remoteOfferAnswer *OfferAnswer) { RelayedOnLocal: isRelayCandidate(pair.Local), } w.log.Debugf("on ICE conn read to use ready") - go w.conn.OnConnReady(w.selectedPriority, ci) + go w.conn.OnConnReady(selectedPriority(pair), ci) } // OnRemoteCandidate Handles ICE connection Candidate provided by the remote peer. @@ -394,3 +390,11 @@ func isRelayed(pair *ice.CandidatePair) bool { } return false } + +func selectedPriority(pair *ice.CandidatePair) ConnPriority { + if isRelayed(pair) { + return connPriorityICETurn + } else { + return connPriorityICEP2P + } +} diff --git a/relay/client/client.go b/relay/client/client.go index 154c1787f..db5252f50 100644 --- a/relay/client/client.go +++ b/relay/client/client.go @@ -140,7 +140,7 @@ type Client struct { instanceURL *RelayAddr muInstanceURL sync.Mutex - onDisconnectListener func() + onDisconnectListener func(string) onConnectedListener func() listenerMutex sync.Mutex } @@ -233,7 +233,7 @@ func (c *Client) ServerInstanceURL() (string, error) { } // SetOnDisconnectListener sets a function that will be called when the connection to the relay server is closed. -func (c *Client) SetOnDisconnectListener(fn func()) { +func (c *Client) SetOnDisconnectListener(fn func(string)) { c.listenerMutex.Lock() defer c.listenerMutex.Unlock() c.onDisconnectListener = fn @@ -554,7 +554,7 @@ func (c *Client) notifyDisconnected() { if c.onDisconnectListener == nil { return } - go c.onDisconnectListener() + go c.onDisconnectListener(c.connectionURL) } func (c *Client) notifyConnected() { diff --git a/relay/client/client_test.go b/relay/client/client_test.go index ef28203e9..7ddfba4c6 100644 --- a/relay/client/client_test.go +++ b/relay/client/client_test.go @@ -551,7 +551,7 @@ func TestCloseByServer(t *testing.T) { } disconnected := make(chan struct{}) - relayClient.SetOnDisconnectListener(func() { + relayClient.SetOnDisconnectListener(func(_ string) { log.Infof("client disconnected") close(disconnected) }) diff --git a/relay/client/guard.go b/relay/client/guard.go index d6b6b0da5..b971363a8 100644 --- a/relay/client/guard.go +++ b/relay/client/guard.go @@ -4,65 +4,120 @@ import ( "context" "time" + "github.com/cenkalti/backoff/v4" log "github.com/sirupsen/logrus" ) var ( - reconnectingTimeout = 5 * time.Second + reconnectingTimeout = 60 * time.Second ) // Guard manage the reconnection tries to the Relay server in case of disconnection event. type Guard struct { - ctx context.Context - relayClient *Client + // OnNewRelayClient is a channel that is used to notify the relay client about a new relay client instance. + OnNewRelayClient chan *Client + serverPicker *ServerPicker } // NewGuard creates a new guard for the relay client. -func NewGuard(context context.Context, relayClient *Client) *Guard { +func NewGuard(sp *ServerPicker) *Guard { g := &Guard{ - ctx: context, - relayClient: relayClient, + OnNewRelayClient: make(chan *Client, 1), + serverPicker: sp, } return g } -// OnDisconnected is called when the relay client is disconnected from the relay server. It will trigger the reconnection +// StartReconnectTrys is called when the relay client is disconnected from the relay server. +// It attempts to reconnect to the relay server. The function first tries a quick reconnect +// to the same server that was used before, if the server URL is still valid. If the quick +// reconnect fails, it starts a ticker to periodically attempt server picking until it +// succeeds or the context is done. +// +// Parameters: +// - ctx: The context to control the lifecycle of the reconnection attempts. +// - relayClient: The relay client instance that was disconnected. // todo prevent multiple reconnection instances. In the current usage it should not happen, but it is better to prevent -func (g *Guard) OnDisconnected() { - if g.quickReconnect() { +func (g *Guard) StartReconnectTrys(ctx context.Context, relayClient *Client) { + if relayClient == nil { + goto RETRY + } + if g.isServerURLStillValid(relayClient) && g.quickReconnect(ctx, relayClient) { return } - ticker := time.NewTicker(reconnectingTimeout) +RETRY: + ticker := exponentTicker(ctx) defer ticker.Stop() for { select { case <-ticker.C: - err := g.relayClient.Connect() - if err != nil { - log.Errorf("failed to reconnect to relay server: %s", err) + if err := g.retry(ctx); err != nil { + log.Errorf("failed to pick new Relay server: %s", err) continue } return - case <-g.ctx.Done(): + case <-ctx.Done(): return } } } -func (g *Guard) quickReconnect() bool { - ctx, cancel := context.WithTimeout(g.ctx, 1500*time.Millisecond) +func (g *Guard) retry(ctx context.Context) error { + log.Infof("try to pick up a new Relay server") + relayClient, err := g.serverPicker.PickServer(ctx) + if err != nil { + return err + } + + // prevent to work with a deprecated Relay client instance + g.drainRelayClientChan() + + g.OnNewRelayClient <- relayClient + return nil +} + +func (g *Guard) quickReconnect(parentCtx context.Context, rc *Client) bool { + ctx, cancel := context.WithTimeout(parentCtx, 1500*time.Millisecond) defer cancel() <-ctx.Done() - if g.ctx.Err() != nil { + if parentCtx.Err() != nil { return false } + log.Infof("try to reconnect to Relay server: %s", rc.connectionURL) - if err := g.relayClient.Connect(); err != nil { + if err := rc.Connect(); err != nil { log.Errorf("failed to reconnect to relay server: %s", err) return false } return true } + +func (g *Guard) drainRelayClientChan() { + select { + case <-g.OnNewRelayClient: + default: + } +} + +func (g *Guard) isServerURLStillValid(rc *Client) bool { + for _, url := range g.serverPicker.ServerURLs.Load().([]string) { + if url == rc.connectionURL { + return true + } + } + return false +} + +func exponentTicker(ctx context.Context) *backoff.Ticker { + bo := backoff.WithContext(&backoff.ExponentialBackOff{ + InitialInterval: 2 * time.Second, + Multiplier: 2, + MaxInterval: reconnectingTimeout, + Clock: backoff.SystemClock, + }, ctx) + + return backoff.NewTicker(bo) +} diff --git a/relay/client/manager.go b/relay/client/manager.go index b14a7701b..d847bb879 100644 --- a/relay/client/manager.go +++ b/relay/client/manager.go @@ -57,12 +57,15 @@ type ManagerService interface { // relay servers will be closed if there is no active connection. Periodically the manager will check if there is any // unused relay connection and close it. type Manager struct { - ctx context.Context - serverURLs []string - peerID string - tokenStore *relayAuth.TokenStore + ctx context.Context + peerID string + running bool + tokenStore *relayAuth.TokenStore + serverPicker *ServerPicker - relayClient *Client + relayClient *Client + // the guard logic can overwrite the relayClient variable, this mutex protect the usage of the variable + relayClientMu sync.Mutex reconnectGuard *Guard relayClients map[string]*RelayTrack @@ -76,48 +79,54 @@ type Manager struct { // NewManager creates a new manager instance. // The serverURL address can be empty. In this case, the manager will not serve. func NewManager(ctx context.Context, serverURLs []string, peerID string) *Manager { - return &Manager{ - ctx: ctx, - serverURLs: serverURLs, - peerID: peerID, - tokenStore: &relayAuth.TokenStore{}, + tokenStore := &relayAuth.TokenStore{} + + m := &Manager{ + ctx: ctx, + peerID: peerID, + tokenStore: tokenStore, + serverPicker: &ServerPicker{ + TokenStore: tokenStore, + PeerID: peerID, + }, relayClients: make(map[string]*RelayTrack), onDisconnectedListeners: make(map[string]*list.List), } + m.serverPicker.ServerURLs.Store(serverURLs) + m.reconnectGuard = NewGuard(m.serverPicker) + return m } -// Serve starts the manager. It will establish a connection to the relay server and start the relay cleanup loop for -// the unused relay connections. The manager will automatically reconnect to the relay server in case of disconnection. +// Serve starts the manager, attempting to establish a connection with the relay server. +// If the connection fails, it will keep trying to reconnect in the background. +// Additionally, it starts a cleanup loop to remove unused relay connections. +// The manager will automatically reconnect to the relay server in case of disconnection. func (m *Manager) Serve() error { - if m.relayClient != nil { + if m.running { return fmt.Errorf("manager already serving") } - log.Debugf("starting relay client manager with %v relay servers", m.serverURLs) + m.running = true + log.Debugf("starting relay client manager with %v relay servers", m.serverPicker.ServerURLs.Load()) - sp := ServerPicker{ - TokenStore: m.tokenStore, - PeerID: m.peerID, - } - - client, err := sp.PickServer(m.ctx, m.serverURLs) + client, err := m.serverPicker.PickServer(m.ctx) if err != nil { - return err + go m.reconnectGuard.StartReconnectTrys(m.ctx, nil) + } else { + m.storeClient(client) } - m.relayClient = client - m.reconnectGuard = NewGuard(m.ctx, m.relayClient) - m.relayClient.SetOnConnectedListener(m.onServerConnected) - m.relayClient.SetOnDisconnectListener(func() { - m.onServerDisconnected(client.connectionURL) - }) - m.startCleanupLoop() - return nil + go m.listenGuardEvent(m.ctx) + go m.startCleanupLoop() + return err } // OpenConn opens a connection to the given peer key. If the peer is on the same relay server, the connection will be // established via the relay server. If the peer is on a different relay server, the manager will establish a new // connection to the relay server. It returns back with a net.Conn what represent the remote peer connection. func (m *Manager) OpenConn(serverAddress, peerKey string) (net.Conn, error) { + m.relayClientMu.Lock() + defer m.relayClientMu.Unlock() + if m.relayClient == nil { return nil, ErrRelayClientNotConnected } @@ -146,6 +155,9 @@ func (m *Manager) OpenConn(serverAddress, peerKey string) (net.Conn, error) { // Ready returns true if the home Relay client is connected to the relay server. func (m *Manager) Ready() bool { + m.relayClientMu.Lock() + defer m.relayClientMu.Unlock() + if m.relayClient == nil { return false } @@ -159,6 +171,13 @@ func (m *Manager) SetOnReconnectedListener(f func()) { // AddCloseListener adds a listener to the given server instance address. The listener will be called if the connection // closed. func (m *Manager) AddCloseListener(serverAddress string, onClosedListener OnServerCloseListener) error { + m.relayClientMu.Lock() + defer m.relayClientMu.Unlock() + + if m.relayClient == nil { + return ErrRelayClientNotConnected + } + foreign, err := m.isForeignServer(serverAddress) if err != nil { return err @@ -177,6 +196,9 @@ func (m *Manager) AddCloseListener(serverAddress string, onClosedListener OnServ // RelayInstanceAddress returns the address of the permanent relay server. It could change if the network connection is // lost. This address will be sent to the target peer to choose the common relay server for the communication. func (m *Manager) RelayInstanceAddress() (string, error) { + m.relayClientMu.Lock() + defer m.relayClientMu.Unlock() + if m.relayClient == nil { return "", ErrRelayClientNotConnected } @@ -185,13 +207,18 @@ func (m *Manager) RelayInstanceAddress() (string, error) { // ServerURLs returns the addresses of the relay servers. func (m *Manager) ServerURLs() []string { - return m.serverURLs + return m.serverPicker.ServerURLs.Load().([]string) } // HasRelayAddress returns true if the manager is serving. With this method can check if the peer can communicate with // Relay service. func (m *Manager) HasRelayAddress() bool { - return len(m.serverURLs) > 0 + return len(m.serverPicker.ServerURLs.Load().([]string)) > 0 +} + +func (m *Manager) UpdateServerURLs(serverURLs []string) { + log.Infof("update relay server URLs: %v", serverURLs) + m.serverPicker.ServerURLs.Store(serverURLs) } // UpdateToken updates the token in the token store. @@ -245,9 +272,7 @@ func (m *Manager) openConnVia(serverAddress, peerKey string) (net.Conn, error) { return nil, err } // if connection closed then delete the relay client from the list - relayClient.SetOnDisconnectListener(func() { - m.onServerDisconnected(serverAddress) - }) + relayClient.SetOnDisconnectListener(m.onServerDisconnected) rt.relayClient = relayClient rt.Unlock() @@ -265,14 +290,37 @@ func (m *Manager) onServerConnected() { go m.onReconnectedListenerFn() } +// onServerDisconnected start to reconnection for home server only func (m *Manager) onServerDisconnected(serverAddress string) { + m.relayClientMu.Lock() if serverAddress == m.relayClient.connectionURL { - go m.reconnectGuard.OnDisconnected() + go m.reconnectGuard.StartReconnectTrys(m.ctx, m.relayClient) } + m.relayClientMu.Unlock() m.notifyOnDisconnectListeners(serverAddress) } +func (m *Manager) listenGuardEvent(ctx context.Context) { + for { + select { + case rc := <-m.reconnectGuard.OnNewRelayClient: + m.storeClient(rc) + case <-ctx.Done(): + return + } + } +} + +func (m *Manager) storeClient(client *Client) { + m.relayClientMu.Lock() + defer m.relayClientMu.Unlock() + + m.relayClient = client + m.relayClient.SetOnConnectedListener(m.onServerConnected) + m.relayClient.SetOnDisconnectListener(m.onServerDisconnected) +} + func (m *Manager) isForeignServer(address string) (bool, error) { rAddr, err := m.relayClient.ServerInstanceURL() if err != nil { @@ -282,22 +330,16 @@ func (m *Manager) isForeignServer(address string) (bool, error) { } func (m *Manager) startCleanupLoop() { - if m.ctx.Err() != nil { - return - } - ticker := time.NewTicker(relayCleanupInterval) - go func() { - defer ticker.Stop() - for { - select { - case <-m.ctx.Done(): - return - case <-ticker.C: - m.cleanUpUnusedRelays() - } + defer ticker.Stop() + for { + select { + case <-m.ctx.Done(): + return + case <-ticker.C: + m.cleanUpUnusedRelays() } - }() + } } func (m *Manager) cleanUpUnusedRelays() { diff --git a/relay/client/picker.go b/relay/client/picker.go index 13b0547aa..eb5062dbb 100644 --- a/relay/client/picker.go +++ b/relay/client/picker.go @@ -4,6 +4,7 @@ import ( "context" "errors" "fmt" + "sync/atomic" "time" log "github.com/sirupsen/logrus" @@ -12,10 +13,13 @@ import ( ) const ( - connectionTimeout = 30 * time.Second maxConcurrentServers = 7 ) +var ( + connectionTimeout = 30 * time.Second +) + type connResult struct { RelayClient *Client Url string @@ -24,20 +28,22 @@ type connResult struct { type ServerPicker struct { TokenStore *auth.TokenStore + ServerURLs atomic.Value PeerID string } -func (sp *ServerPicker) PickServer(parentCtx context.Context, urls []string) (*Client, error) { +func (sp *ServerPicker) PickServer(parentCtx context.Context) (*Client, error) { ctx, cancel := context.WithTimeout(parentCtx, connectionTimeout) defer cancel() - totalServers := len(urls) + totalServers := len(sp.ServerURLs.Load().([]string)) connResultChan := make(chan connResult, totalServers) successChan := make(chan connResult, 1) concurrentLimiter := make(chan struct{}, maxConcurrentServers) - for _, url := range urls { + log.Debugf("pick server from list: %v", sp.ServerURLs.Load().([]string)) + for _, url := range sp.ServerURLs.Load().([]string) { // todo check if we have a successful connection so we do not need to connect to other servers concurrentLimiter <- struct{}{} go func(url string) { @@ -78,7 +84,7 @@ func (sp *ServerPicker) processConnResults(resultChan chan connResult, successCh for numOfResults := 0; numOfResults < cap(resultChan); numOfResults++ { cr := <-resultChan if cr.Err != nil { - log.Debugf("failed to connect to Relay server: %s: %v", cr.Url, cr.Err) + log.Tracef("failed to connect to Relay server: %s: %v", cr.Url, cr.Err) continue } log.Infof("connected to Relay server: %s", cr.Url) diff --git a/relay/client/picker_test.go b/relay/client/picker_test.go index 4800e05ba..20a03e64d 100644 --- a/relay/client/picker_test.go +++ b/relay/client/picker_test.go @@ -7,16 +7,19 @@ import ( ) func TestServerPicker_UnavailableServers(t *testing.T) { + connectionTimeout = 5 * time.Second + sp := ServerPicker{ TokenStore: nil, PeerID: "test", } + sp.ServerURLs.Store([]string{"rel://dummy1", "rel://dummy2"}) ctx, cancel := context.WithTimeout(context.Background(), connectionTimeout+1) defer cancel() go func() { - _, err := sp.PickServer(ctx, []string{"rel://dummy1", "rel://dummy2"}) + _, err := sp.PickServer(ctx) if err == nil { t.Error(err) } From 05c4aa7c2cad19f3679bf2548f8dc296dd1e043b Mon Sep 17 00:00:00 2001 From: Maycon Santos Date: Fri, 22 Nov 2024 18:50:47 +0100 Subject: [PATCH 31/39] [misc] Renew slack link (#2938) --- README.md | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/README.md b/README.md index a2d7f3897..e7925ae09 100644 --- a/README.md +++ b/README.md @@ -17,7 +17,7 @@
- +
@@ -34,7 +34,7 @@
See Documentation
- Join our Slack channel + Join our Slack channel
From 56cecf849ea4f6092b0dc9b421126da399bffa95 Mon Sep 17 00:00:00 2001 From: Maycon Santos Date: Fri, 22 Nov 2024 20:40:30 +0100 Subject: [PATCH 32/39] Import time package (#2940) --- relay/client/picker_test.go | 1 + 1 file changed, 1 insertion(+) diff --git a/relay/client/picker_test.go b/relay/client/picker_test.go index 20a03e64d..28167c5ce 100644 --- a/relay/client/picker_test.go +++ b/relay/client/picker_test.go @@ -4,6 +4,7 @@ import ( "context" "errors" "testing" + "time" ) func TestServerPicker_UnavailableServers(t *testing.T) { From 940d0c48c69198803b4cd88125214c5cb666bf2b Mon Sep 17 00:00:00 2001 From: Viktor Liu <17948409+lixmal@users.noreply.github.com> Date: Mon, 25 Nov 2024 15:11:31 +0100 Subject: [PATCH 33/39] [client] Don't return error in userspace mode without firewall (#2924) --- client/firewall/uspfilter/uspfilter.go | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/client/firewall/uspfilter/uspfilter.go b/client/firewall/uspfilter/uspfilter.go index af5dc6733..fb726395b 100644 --- a/client/firewall/uspfilter/uspfilter.go +++ b/client/firewall/uspfilter/uspfilter.go @@ -239,7 +239,7 @@ func (m *Manager) DeletePeerRule(rule firewall.Rule) error { // SetLegacyManagement doesn't need to be implemented for this manager func (m *Manager) SetLegacyManagement(isLegacy bool) error { if m.nativeFirewall == nil { - return errRouteNotSupported + return nil } return m.nativeFirewall.SetLegacyManagement(isLegacy) } From 0ecd5f211850d50dcf7179adbf3ae0246705f0a3 Mon Sep 17 00:00:00 2001 From: Viktor Liu <17948409+lixmal@users.noreply.github.com> Date: Mon, 25 Nov 2024 15:11:56 +0100 Subject: [PATCH 34/39] [client] Test nftables for incompatible iptables rules (#2948) --- .../firewall/nftables/manager_linux_test.go | 104 ++++++++++++++++++ 1 file changed, 104 insertions(+) diff --git a/client/firewall/nftables/manager_linux_test.go b/client/firewall/nftables/manager_linux_test.go index 77f4f0306..33fdc4b3d 100644 --- a/client/firewall/nftables/manager_linux_test.go +++ b/client/firewall/nftables/manager_linux_test.go @@ -1,9 +1,11 @@ package nftables import ( + "bytes" "fmt" "net" "net/netip" + "os/exec" "testing" "time" @@ -225,3 +227,105 @@ func TestNFtablesCreatePerformance(t *testing.T) { }) } } + +func runIptablesSave(t *testing.T) (string, string) { + t.Helper() + var stdout, stderr bytes.Buffer + cmd := exec.Command("iptables-save") + cmd.Stdout = &stdout + cmd.Stderr = &stderr + + err := cmd.Run() + require.NoError(t, err, "iptables-save failed to run") + + return stdout.String(), stderr.String() +} + +func verifyIptablesOutput(t *testing.T, stdout, stderr string) { + t.Helper() + // Check for any incompatibility warnings + require.NotContains(t, + stderr, + "incompatible", + "iptables-save produced compatibility warning. Full stderr: %s", + stderr, + ) + + // Verify standard tables are present + expectedTables := []string{ + "*filter", + "*nat", + "*mangle", + } + + for _, table := range expectedTables { + require.Contains(t, + stdout, + table, + "iptables-save output missing expected table: %s\nFull stdout: %s", + table, + stdout, + ) + } +} + +func TestNftablesManagerCompatibilityWithIptables(t *testing.T) { + if check() != NFTABLES { + t.Skip("nftables not supported on this system") + } + + if _, err := exec.LookPath("iptables-save"); err != nil { + t.Skipf("iptables-save not available on this system: %v", err) + } + + // First ensure iptables-nft tables exist by running iptables-save + stdout, stderr := runIptablesSave(t) + verifyIptablesOutput(t, stdout, stderr) + + manager, err := Create(ifaceMock) + require.NoError(t, err, "failed to create manager") + require.NoError(t, manager.Init(nil)) + + t.Cleanup(func() { + err := manager.Reset(nil) + require.NoError(t, err, "failed to reset manager state") + + // Verify iptables output after reset + stdout, stderr := runIptablesSave(t) + verifyIptablesOutput(t, stdout, stderr) + }) + + ip := net.ParseIP("100.96.0.1") + _, err = manager.AddPeerFiltering( + ip, + fw.ProtocolTCP, + nil, + &fw.Port{Values: []int{80}}, + fw.RuleDirectionIN, + fw.ActionAccept, + "", + "test rule", + ) + require.NoError(t, err, "failed to add peer filtering rule") + + _, err = manager.AddRouteFiltering( + []netip.Prefix{netip.MustParsePrefix("192.168.2.0/24")}, + netip.MustParsePrefix("10.1.0.0/24"), + fw.ProtocolTCP, + nil, + &fw.Port{Values: []int{443}}, + fw.ActionAccept, + ) + require.NoError(t, err, "failed to add route filtering rule") + + pair := fw.RouterPair{ + Source: netip.MustParsePrefix("192.168.1.0/24"), + Destination: netip.MustParsePrefix("10.0.0.0/24"), + Masquerade: true, + } + err = manager.AddNatRule(pair) + require.NoError(t, err, "failed to add NAT rule") + + stdout, stderr = runIptablesSave(t) + verifyIptablesOutput(t, stdout, stderr) +} From f1625b32bdd6c1fe0abb73756835947b99d1a97f Mon Sep 17 00:00:00 2001 From: Viktor Liu <17948409+lixmal@users.noreply.github.com> Date: Mon, 25 Nov 2024 15:12:16 +0100 Subject: [PATCH 35/39] [client] Set up sysctl and routing table name only if routing rules are available (#2933) --- .../routemanager/systemops/systemops_linux.go | 22 +++++++++---------- 1 file changed, 11 insertions(+), 11 deletions(-) diff --git a/client/internal/routemanager/systemops/systemops_linux.go b/client/internal/routemanager/systemops/systemops_linux.go index 71a0f26ae..ac4fd5c71 100644 --- a/client/internal/routemanager/systemops/systemops_linux.go +++ b/client/internal/routemanager/systemops/systemops_linux.go @@ -92,17 +92,6 @@ func (r *SysOps) SetupRouting(initAddresses []net.IP, stateManager *statemanager return r.setupRefCounter(initAddresses, stateManager) } - if err = addRoutingTableName(); err != nil { - log.Errorf("Error adding routing table name: %v", err) - } - - originalValues, err := sysctl.Setup(r.wgInterface) - if err != nil { - log.Errorf("Error setting up sysctl: %v", err) - sysctlFailed = true - } - originalSysctl = originalValues - defer func() { if err != nil { if cleanErr := r.CleanupRouting(stateManager); cleanErr != nil { @@ -123,6 +112,17 @@ func (r *SysOps) SetupRouting(initAddresses []net.IP, stateManager *statemanager } } + if err = addRoutingTableName(); err != nil { + log.Errorf("Error adding routing table name: %v", err) + } + + originalValues, err := sysctl.Setup(r.wgInterface) + if err != nil { + log.Errorf("Error setting up sysctl: %v", err) + sysctlFailed = true + } + originalSysctl = originalValues + return nil, nil, nil } From 9810386937edd1665109479f745e6faaf21db840 Mon Sep 17 00:00:00 2001 From: Viktor Liu <17948409+lixmal@users.noreply.github.com> Date: Mon, 25 Nov 2024 15:19:56 +0100 Subject: [PATCH 36/39] [client] Allow routing to fallback to exclusion routes if rules are not supported (#2909) --- client/internal/routemanager/systemops/systemops_linux.go | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/client/internal/routemanager/systemops/systemops_linux.go b/client/internal/routemanager/systemops/systemops_linux.go index ac4fd5c71..1d629d6e9 100644 --- a/client/internal/routemanager/systemops/systemops_linux.go +++ b/client/internal/routemanager/systemops/systemops_linux.go @@ -450,7 +450,7 @@ func addRule(params ruleParams) error { rule.Invert = params.invert rule.SuppressPrefixlen = params.suppressPrefix - if err := netlink.RuleAdd(rule); err != nil && !errors.Is(err, syscall.EEXIST) && !errors.Is(err, syscall.EAFNOSUPPORT) { + if err := netlink.RuleAdd(rule); err != nil && !errors.Is(err, syscall.EEXIST) { return fmt.Errorf("add routing rule: %w", err) } @@ -467,7 +467,7 @@ func removeRule(params ruleParams) error { rule.Priority = params.priority rule.SuppressPrefixlen = params.suppressPrefix - if err := netlink.RuleDel(rule); err != nil && !errors.Is(err, syscall.ENOENT) && !errors.Is(err, syscall.EAFNOSUPPORT) { + if err := netlink.RuleDel(rule); err != nil && !errors.Is(err, syscall.ENOENT) { return fmt.Errorf("remove routing rule: %w", err) } From ca12bc6953b8ba0e8645b90dc2ef8d50c8c38a01 Mon Sep 17 00:00:00 2001 From: Bethuel Mmbaga Date: Mon, 25 Nov 2024 18:26:24 +0300 Subject: [PATCH 37/39] [management] Refactor posture check to use store methods (#2874) --- management/server/account.go | 2 +- management/server/dns.go | 2 +- management/server/group.go | 19 +- .../server/http/posture_checks_handler.go | 3 +- .../http/posture_checks_handler_test.go | 6 +- management/server/mock_server/account_mock.go | 6 +- management/server/nameserver.go | 10 +- management/server/peer.go | 25 +- management/server/policy.go | 6 +- management/server/posture/checks.go | 6 - management/server/posture_checks.go | 337 +++++++++++------- management/server/posture_checks_test.go | 221 +++++++----- management/server/route.go | 10 +- management/server/sql_store.go | 51 ++- management/server/sql_store_test.go | 135 +++++++ management/server/status/error.go | 5 + management/server/store.go | 4 +- management/server/testdata/extended-store.sql | 1 + 18 files changed, 589 insertions(+), 260 deletions(-) diff --git a/management/server/account.go b/management/server/account.go index 0ab123655..9fb56c855 100644 --- a/management/server/account.go +++ b/management/server/account.go @@ -139,7 +139,7 @@ type AccountManager interface { HasConnectedChannel(peerID string) bool GetExternalCacheManager() ExternalCacheManager GetPostureChecks(ctx context.Context, accountID, postureChecksID, userID string) (*posture.Checks, error) - SavePostureChecks(ctx context.Context, accountID, userID string, postureChecks *posture.Checks) error + SavePostureChecks(ctx context.Context, accountID, userID string, postureChecks *posture.Checks) (*posture.Checks, error) DeletePostureChecks(ctx context.Context, accountID, postureChecksID, userID string) error ListPostureChecks(ctx context.Context, accountID, userID string) ([]*posture.Checks, error) GetIdpManager() idp.Manager diff --git a/management/server/dns.go b/management/server/dns.go index 4551be5ab..e52be6016 100644 --- a/management/server/dns.go +++ b/management/server/dns.go @@ -145,7 +145,7 @@ func (am *DefaultAccountManager) SaveDNSSettings(ctx context.Context, accountID am.StoreEvent(ctx, userID, accountID, accountID, activity.GroupRemovedFromDisabledManagementGroups, meta) } - if anyGroupHasPeers(account, addedGroups) || anyGroupHasPeers(account, removedGroups) { + if am.anyGroupHasPeers(account, addedGroups) || am.anyGroupHasPeers(account, removedGroups) { am.updateAccountPeers(ctx, accountID) } diff --git a/management/server/group.go b/management/server/group.go index a36213f04..7b307cf1a 100644 --- a/management/server/group.go +++ b/management/server/group.go @@ -566,8 +566,7 @@ func areGroupChangesAffectPeers(ctx context.Context, transaction Store, accountI return false, nil } -// anyGroupHasPeers checks if any of the given groups in the account have peers. -func anyGroupHasPeers(account *Account, groupIDs []string) bool { +func (am *DefaultAccountManager) anyGroupHasPeers(account *Account, groupIDs []string) bool { for _, groupID := range groupIDs { if group, exists := account.Groups[groupID]; exists && group.HasPeers() { return true @@ -575,3 +574,19 @@ func anyGroupHasPeers(account *Account, groupIDs []string) bool { } return false } + +// anyGroupHasPeers checks if any of the given groups in the account have peers. +func anyGroupHasPeers(ctx context.Context, transaction Store, accountID string, groupIDs []string) (bool, error) { + groups, err := transaction.GetGroupsByIDs(ctx, LockingStrengthShare, accountID, groupIDs) + if err != nil { + return false, err + } + + for _, group := range groups { + if group.HasPeers() { + return true, nil + } + } + + return false, nil +} diff --git a/management/server/http/posture_checks_handler.go b/management/server/http/posture_checks_handler.go index 1d020e9bc..2c8204292 100644 --- a/management/server/http/posture_checks_handler.go +++ b/management/server/http/posture_checks_handler.go @@ -169,7 +169,8 @@ func (p *PostureChecksHandler) savePostureChecks(w http.ResponseWriter, r *http. return } - if err := p.accountManager.SavePostureChecks(r.Context(), accountID, userID, postureChecks); err != nil { + postureChecks, err = p.accountManager.SavePostureChecks(r.Context(), accountID, userID, postureChecks) + if err != nil { util.WriteError(r.Context(), err, w) return } diff --git a/management/server/http/posture_checks_handler_test.go b/management/server/http/posture_checks_handler_test.go index 02f0f0d83..f400cec81 100644 --- a/management/server/http/posture_checks_handler_test.go +++ b/management/server/http/posture_checks_handler_test.go @@ -40,15 +40,15 @@ func initPostureChecksTestData(postureChecks ...*posture.Checks) *PostureChecksH } return p, nil }, - SavePostureChecksFunc: func(_ context.Context, accountID, userID string, postureChecks *posture.Checks) error { + SavePostureChecksFunc: func(_ context.Context, accountID, userID string, postureChecks *posture.Checks) (*posture.Checks, error) { postureChecks.ID = "postureCheck" testPostureChecks[postureChecks.ID] = postureChecks if err := postureChecks.Validate(); err != nil { - return status.Errorf(status.InvalidArgument, err.Error()) //nolint + return nil, status.Errorf(status.InvalidArgument, err.Error()) //nolint } - return nil + return postureChecks, nil }, DeletePostureChecksFunc: func(_ context.Context, accountID, postureChecksID, userID string) error { _, ok := testPostureChecks[postureChecksID] diff --git a/management/server/mock_server/account_mock.go b/management/server/mock_server/account_mock.go index aa6a47b15..673ed33bb 100644 --- a/management/server/mock_server/account_mock.go +++ b/management/server/mock_server/account_mock.go @@ -96,7 +96,7 @@ type MockAccountManager struct { HasConnectedChannelFunc func(peerID string) bool GetExternalCacheManagerFunc func() server.ExternalCacheManager GetPostureChecksFunc func(ctx context.Context, accountID, postureChecksID, userID string) (*posture.Checks, error) - SavePostureChecksFunc func(ctx context.Context, accountID, userID string, postureChecks *posture.Checks) error + SavePostureChecksFunc func(ctx context.Context, accountID, userID string, postureChecks *posture.Checks) (*posture.Checks, error) DeletePostureChecksFunc func(ctx context.Context, accountID, postureChecksID, userID string) error ListPostureChecksFunc func(ctx context.Context, accountID, userID string) ([]*posture.Checks, error) GetIdpManagerFunc func() idp.Manager @@ -730,11 +730,11 @@ func (am *MockAccountManager) GetPostureChecks(ctx context.Context, accountID, p } // SavePostureChecks mocks SavePostureChecks of the AccountManager interface -func (am *MockAccountManager) SavePostureChecks(ctx context.Context, accountID, userID string, postureChecks *posture.Checks) error { +func (am *MockAccountManager) SavePostureChecks(ctx context.Context, accountID, userID string, postureChecks *posture.Checks) (*posture.Checks, error) { if am.SavePostureChecksFunc != nil { return am.SavePostureChecksFunc(ctx, accountID, userID, postureChecks) } - return status.Errorf(codes.Unimplemented, "method SavePostureChecks is not implemented") + return nil, status.Errorf(codes.Unimplemented, "method SavePostureChecks is not implemented") } // DeletePostureChecks mocks DeletePostureChecks of the AccountManager interface diff --git a/management/server/nameserver.go b/management/server/nameserver.go index 957008714..9119a3dec 100644 --- a/management/server/nameserver.go +++ b/management/server/nameserver.go @@ -70,7 +70,7 @@ func (am *DefaultAccountManager) CreateNameServerGroup(ctx context.Context, acco return nil, err } - if anyGroupHasPeers(account, newNSGroup.Groups) { + if am.anyGroupHasPeers(account, newNSGroup.Groups) { am.updateAccountPeers(ctx, accountID) } am.StoreEvent(ctx, userID, newNSGroup.ID, accountID, activity.NameserverGroupCreated, newNSGroup.EventMeta()) @@ -105,7 +105,7 @@ func (am *DefaultAccountManager) SaveNameServerGroup(ctx context.Context, accoun return err } - if areNameServerGroupChangesAffectPeers(account, nsGroupToSave, oldNSGroup) { + if am.areNameServerGroupChangesAffectPeers(account, nsGroupToSave, oldNSGroup) { am.updateAccountPeers(ctx, accountID) } am.StoreEvent(ctx, userID, nsGroupToSave.ID, accountID, activity.NameserverGroupUpdated, nsGroupToSave.EventMeta()) @@ -135,7 +135,7 @@ func (am *DefaultAccountManager) DeleteNameServerGroup(ctx context.Context, acco return err } - if anyGroupHasPeers(account, nsGroup.Groups) { + if am.anyGroupHasPeers(account, nsGroup.Groups) { am.updateAccountPeers(ctx, accountID) } am.StoreEvent(ctx, userID, nsGroup.ID, accountID, activity.NameserverGroupDeleted, nsGroup.EventMeta()) @@ -279,9 +279,9 @@ func validateDomain(domain string) error { } // areNameServerGroupChangesAffectPeers checks if the changes in the nameserver group affect the peers. -func areNameServerGroupChangesAffectPeers(account *Account, newNSGroup, oldNSGroup *nbdns.NameServerGroup) bool { +func (am *DefaultAccountManager) areNameServerGroupChangesAffectPeers(account *Account, newNSGroup, oldNSGroup *nbdns.NameServerGroup) bool { if !newNSGroup.Enabled && !oldNSGroup.Enabled { return false } - return anyGroupHasPeers(account, newNSGroup.Groups) || anyGroupHasPeers(account, oldNSGroup.Groups) + return am.anyGroupHasPeers(account, newNSGroup.Groups) || am.anyGroupHasPeers(account, oldNSGroup.Groups) } diff --git a/management/server/peer.go b/management/server/peer.go index beb833dba..dcb47af3b 100644 --- a/management/server/peer.go +++ b/management/server/peer.go @@ -617,7 +617,11 @@ func (am *DefaultAccountManager) AddPeer(ctx context.Context, setupKey, userID s return nil, nil, nil, err } - postureChecks := am.getPeerPostureChecks(account, newPeer) + postureChecks, err := am.getPeerPostureChecks(ctx, account.Id, newPeer.ID) + if err != nil { + return nil, nil, nil, err + } + customZone := account.GetPeersCustomZone(ctx, am.dnsDomain) networkMap := account.GetPeerNetworkMap(ctx, newPeer.ID, customZone, approvedPeersMap, am.metrics.AccountManagerMetrics()) return newPeer, networkMap, postureChecks, nil @@ -702,7 +706,11 @@ func (am *DefaultAccountManager) SyncPeer(ctx context.Context, sync PeerSync, ac if err != nil { return nil, nil, nil, fmt.Errorf("failed to get validated peers: %w", err) } - postureChecks = am.getPeerPostureChecks(account, peer) + + postureChecks, err = am.getPeerPostureChecks(ctx, account.Id, peer.ID) + if err != nil { + return nil, nil, nil, err + } customZone := account.GetPeersCustomZone(ctx, am.dnsDomain) return peer, account.GetPeerNetworkMap(ctx, peer.ID, customZone, validPeersMap, am.metrics.AccountManagerMetrics()), postureChecks, nil @@ -876,7 +884,11 @@ func (am *DefaultAccountManager) getValidatedPeerWithMap(ctx context.Context, is if err != nil { return nil, nil, nil, err } - postureChecks = am.getPeerPostureChecks(account, peer) + + postureChecks, err = am.getPeerPostureChecks(ctx, account.Id, peer.ID) + if err != nil { + return nil, nil, nil, err + } customZone := account.GetPeersCustomZone(ctx, am.dnsDomain) return peer, account.GetPeerNetworkMap(ctx, peer.ID, customZone, approvedPeersMap, am.metrics.AccountManagerMetrics()), postureChecks, nil @@ -1030,7 +1042,12 @@ func (am *DefaultAccountManager) updateAccountPeers(ctx context.Context, account defer wg.Done() defer func() { <-semaphore }() - postureChecks := am.getPeerPostureChecks(account, p) + postureChecks, err := am.getPeerPostureChecks(ctx, account.Id, p.ID) + if err != nil { + log.WithContext(ctx).Errorf("failed to send out updates to peers, failed to get peer: %s posture checks: %v", p.ID, err) + return + } + remotePeerNetworkMap := account.GetPeerNetworkMap(ctx, p.ID, customZone, approvedPeersMap, am.metrics.AccountManagerMetrics()) update := toSyncResponse(ctx, nil, p, nil, nil, remotePeerNetworkMap, am.GetDNSDomain(), postureChecks, dnsCache) am.peersUpdateManager.SendUpdate(ctx, p.ID, &UpdateMessage{Update: update, NetworkMap: remotePeerNetworkMap}) diff --git a/management/server/policy.go b/management/server/policy.go index 8a5733f01..c7872591d 100644 --- a/management/server/policy.go +++ b/management/server/policy.go @@ -405,7 +405,7 @@ func (am *DefaultAccountManager) DeletePolicy(ctx context.Context, accountID, po am.StoreEvent(ctx, userID, policy.ID, accountID, activity.PolicyRemoved, policy.EventMeta()) - if anyGroupHasPeers(account, policy.ruleGroups()) { + if am.anyGroupHasPeers(account, policy.ruleGroups()) { am.updateAccountPeers(ctx, accountID) } @@ -469,7 +469,7 @@ func (am *DefaultAccountManager) savePolicy(account *Account, policyToSave *Poli if !policyToSave.Enabled && !oldPolicy.Enabled { return false, nil } - updateAccountPeers := anyGroupHasPeers(account, oldPolicy.ruleGroups()) || anyGroupHasPeers(account, policyToSave.ruleGroups()) + updateAccountPeers := am.anyGroupHasPeers(account, oldPolicy.ruleGroups()) || am.anyGroupHasPeers(account, policyToSave.ruleGroups()) return updateAccountPeers, nil } @@ -477,7 +477,7 @@ func (am *DefaultAccountManager) savePolicy(account *Account, policyToSave *Poli // Add the new policy to the account account.Policies = append(account.Policies, policyToSave) - return anyGroupHasPeers(account, policyToSave.ruleGroups()), nil + return am.anyGroupHasPeers(account, policyToSave.ruleGroups()), nil } func toProtocolFirewallRules(rules []*FirewallRule) []*proto.FirewallRule { diff --git a/management/server/posture/checks.go b/management/server/posture/checks.go index f2739dddf..b2f308d76 100644 --- a/management/server/posture/checks.go +++ b/management/server/posture/checks.go @@ -7,8 +7,6 @@ import ( "regexp" "github.com/hashicorp/go-version" - "github.com/rs/xid" - "github.com/netbirdio/netbird/management/server/http/api" nbpeer "github.com/netbirdio/netbird/management/server/peer" "github.com/netbirdio/netbird/management/server/status" @@ -172,10 +170,6 @@ func NewChecksFromAPIPostureCheckUpdate(source api.PostureCheckUpdate, postureCh } func buildPostureCheck(postureChecksID string, name string, description string, checks api.Checks) (*Checks, error) { - if postureChecksID == "" { - postureChecksID = xid.New().String() - } - postureChecks := Checks{ ID: postureChecksID, Name: name, diff --git a/management/server/posture_checks.go b/management/server/posture_checks.go index 096cff3f5..59e726c41 100644 --- a/management/server/posture_checks.go +++ b/management/server/posture_checks.go @@ -2,16 +2,14 @@ package server import ( "context" + "fmt" "slices" "github.com/netbirdio/netbird/management/server/activity" - nbpeer "github.com/netbirdio/netbird/management/server/peer" "github.com/netbirdio/netbird/management/server/posture" "github.com/netbirdio/netbird/management/server/status" -) - -const ( - errMsgPostureAdminOnly = "only users with admin power are allowed to view posture checks" + "github.com/rs/xid" + "golang.org/x/exp/maps" ) func (am *DefaultAccountManager) GetPostureChecks(ctx context.Context, accountID, postureChecksID, userID string) (*posture.Checks, error) { @@ -20,219 +18,284 @@ func (am *DefaultAccountManager) GetPostureChecks(ctx context.Context, accountID return nil, err } - if !user.HasAdminPower() || user.AccountID != accountID { - return nil, status.Errorf(status.PermissionDenied, errMsgPostureAdminOnly) - } - - return am.Store.GetPostureChecksByID(ctx, LockingStrengthShare, postureChecksID, accountID) -} - -func (am *DefaultAccountManager) SavePostureChecks(ctx context.Context, accountID, userID string, postureChecks *posture.Checks) error { - unlock := am.Store.AcquireWriteLockByUID(ctx, accountID) - defer unlock() - - account, err := am.Store.GetAccount(ctx, accountID) - if err != nil { - return err - } - - user, err := account.FindUser(userID) - if err != nil { - return err + if user.AccountID != accountID { + return nil, status.NewUserNotPartOfAccountError() } if !user.HasAdminPower() { - return status.Errorf(status.PermissionDenied, errMsgPostureAdminOnly) + return nil, status.NewAdminPermissionError() } - if err := postureChecks.Validate(); err != nil { - return status.Errorf(status.InvalidArgument, err.Error()) //nolint + return am.Store.GetPostureChecksByID(ctx, LockingStrengthShare, accountID, postureChecksID) +} + +// SavePostureChecks saves a posture check. +func (am *DefaultAccountManager) SavePostureChecks(ctx context.Context, accountID, userID string, postureChecks *posture.Checks) (*posture.Checks, error) { + unlock := am.Store.AcquireWriteLockByUID(ctx, accountID) + defer unlock() + + user, err := am.Store.GetUserByUserID(ctx, LockingStrengthShare, userID) + if err != nil { + return nil, err } - exists, uniqName := am.savePostureChecks(account, postureChecks) - - // we do not allow create new posture checks with non uniq name - if !exists && !uniqName { - return status.Errorf(status.PreconditionFailed, "Posture check name should be unique") + if user.AccountID != accountID { + return nil, status.NewUserNotPartOfAccountError() } - action := activity.PostureCheckCreated - if exists { - action = activity.PostureCheckUpdated - account.Network.IncSerial() + if !user.HasAdminPower() { + return nil, status.NewAdminPermissionError() } - if err = am.Store.SaveAccount(ctx, account); err != nil { - return err + var updateAccountPeers bool + var isUpdate = postureChecks.ID != "" + var action = activity.PostureCheckCreated + + err = am.Store.ExecuteInTransaction(ctx, func(transaction Store) error { + if err = validatePostureChecks(ctx, transaction, accountID, postureChecks); err != nil { + return err + } + + if isUpdate { + updateAccountPeers, err = arePostureCheckChangesAffectPeers(ctx, transaction, accountID, postureChecks.ID) + if err != nil { + return err + } + + if err = transaction.IncrementNetworkSerial(ctx, LockingStrengthUpdate, accountID); err != nil { + return err + } + + action = activity.PostureCheckUpdated + } + + postureChecks.AccountID = accountID + return transaction.SavePostureChecks(ctx, LockingStrengthUpdate, postureChecks) + }) + if err != nil { + return nil, err } am.StoreEvent(ctx, userID, postureChecks.ID, accountID, action, postureChecks.EventMeta()) - if arePostureCheckChangesAffectingPeers(account, postureChecks.ID, exists) { + if updateAccountPeers { am.updateAccountPeers(ctx, accountID) } - return nil + return postureChecks, nil } +// DeletePostureChecks deletes a posture check by ID. func (am *DefaultAccountManager) DeletePostureChecks(ctx context.Context, accountID, postureChecksID, userID string) error { unlock := am.Store.AcquireWriteLockByUID(ctx, accountID) defer unlock() - account, err := am.Store.GetAccount(ctx, accountID) + user, err := am.Store.GetUserByUserID(ctx, LockingStrengthShare, userID) if err != nil { return err } - user, err := account.FindUser(userID) - if err != nil { - return err + if user.AccountID != accountID { + return status.NewUserNotPartOfAccountError() } if !user.HasAdminPower() { - return status.Errorf(status.PermissionDenied, errMsgPostureAdminOnly) + return status.NewAdminPermissionError() } - postureChecks, err := am.deletePostureChecks(account, postureChecksID) + var postureChecks *posture.Checks + + err = am.Store.ExecuteInTransaction(ctx, func(transaction Store) error { + postureChecks, err = transaction.GetPostureChecksByID(ctx, LockingStrengthShare, accountID, postureChecksID) + if err != nil { + return err + } + + if err = isPostureCheckLinkedToPolicy(ctx, transaction, postureChecksID, accountID); err != nil { + return err + } + + if err = transaction.IncrementNetworkSerial(ctx, LockingStrengthUpdate, accountID); err != nil { + return err + } + + return transaction.DeletePostureChecks(ctx, LockingStrengthUpdate, accountID, postureChecksID) + }) if err != nil { return err } - if err = am.Store.SaveAccount(ctx, account); err != nil { - return err - } - am.StoreEvent(ctx, userID, postureChecks.ID, accountID, activity.PostureCheckDeleted, postureChecks.EventMeta()) return nil } +// ListPostureChecks returns a list of posture checks. func (am *DefaultAccountManager) ListPostureChecks(ctx context.Context, accountID, userID string) ([]*posture.Checks, error) { user, err := am.Store.GetUserByUserID(ctx, LockingStrengthShare, userID) if err != nil { return nil, err } - if !user.HasAdminPower() || user.AccountID != accountID { - return nil, status.Errorf(status.PermissionDenied, errMsgPostureAdminOnly) + if user.AccountID != accountID { + return nil, status.NewUserNotPartOfAccountError() + } + + if !user.HasAdminPower() { + return nil, status.NewAdminPermissionError() } return am.Store.GetAccountPostureChecks(ctx, LockingStrengthShare, accountID) } -func (am *DefaultAccountManager) savePostureChecks(account *Account, postureChecks *posture.Checks) (exists, uniqName bool) { - uniqName = true - for i, p := range account.PostureChecks { - if !exists && p.ID == postureChecks.ID { - account.PostureChecks[i] = postureChecks - exists = true - } - if p.Name == postureChecks.Name { - uniqName = false - } - } - if !exists { - account.PostureChecks = append(account.PostureChecks, postureChecks) - } - return -} - -func (am *DefaultAccountManager) deletePostureChecks(account *Account, postureChecksID string) (*posture.Checks, error) { - postureChecksIdx := -1 - for i, postureChecks := range account.PostureChecks { - if postureChecks.ID == postureChecksID { - postureChecksIdx = i - break - } - } - if postureChecksIdx < 0 { - return nil, status.Errorf(status.NotFound, "posture checks with ID %s doesn't exist", postureChecksID) - } - - // Check if posture check is linked to any policy - if isLinked, linkedPolicy := isPostureCheckLinkedToPolicy(account, postureChecksID); isLinked { - return nil, status.Errorf(status.PreconditionFailed, "posture checks have been linked to policy: %s", linkedPolicy.Name) - } - - postureChecks := account.PostureChecks[postureChecksIdx] - account.PostureChecks = append(account.PostureChecks[:postureChecksIdx], account.PostureChecks[postureChecksIdx+1:]...) - - return postureChecks, nil -} - // getPeerPostureChecks returns the posture checks applied for a given peer. -func (am *DefaultAccountManager) getPeerPostureChecks(account *Account, peer *nbpeer.Peer) []*posture.Checks { - peerPostureChecks := make(map[string]posture.Checks) +func (am *DefaultAccountManager) getPeerPostureChecks(ctx context.Context, accountID string, peerID string) ([]*posture.Checks, error) { + peerPostureChecks := make(map[string]*posture.Checks) - if len(account.PostureChecks) == 0 { + err := am.Store.ExecuteInTransaction(ctx, func(transaction Store) error { + postureChecks, err := transaction.GetAccountPostureChecks(ctx, LockingStrengthShare, accountID) + if err != nil { + return err + } + + if len(postureChecks) == 0 { + return nil + } + + policies, err := transaction.GetAccountPolicies(ctx, LockingStrengthShare, accountID) + if err != nil { + return err + } + + for _, policy := range policies { + if !policy.Enabled { + continue + } + + if err = addPolicyPostureChecks(ctx, transaction, accountID, peerID, policy, peerPostureChecks); err != nil { + return err + } + } + + return nil + }) + if err != nil { + return nil, err + } + + return maps.Values(peerPostureChecks), nil +} + +// arePostureCheckChangesAffectPeers checks if the changes in posture checks are affecting peers. +func arePostureCheckChangesAffectPeers(ctx context.Context, transaction Store, accountID, postureCheckID string) (bool, error) { + policies, err := transaction.GetAccountPolicies(ctx, LockingStrengthShare, accountID) + if err != nil { + return false, err + } + + for _, policy := range policies { + if slices.Contains(policy.SourcePostureChecks, postureCheckID) { + hasPeers, err := anyGroupHasPeers(ctx, transaction, accountID, policy.ruleGroups()) + if err != nil { + return false, err + } + + if hasPeers { + return true, nil + } + } + } + + return false, nil +} + +// validatePostureChecks validates the posture checks. +func validatePostureChecks(ctx context.Context, transaction Store, accountID string, postureChecks *posture.Checks) error { + if err := postureChecks.Validate(); err != nil { + return status.Errorf(status.InvalidArgument, err.Error()) //nolint + } + + // If the posture check already has an ID, verify its existence in the store. + if postureChecks.ID != "" { + if _, err := transaction.GetPostureChecksByID(ctx, LockingStrengthShare, accountID, postureChecks.ID); err != nil { + return err + } return nil } - for _, policy := range account.Policies { - if !policy.Enabled { - continue - } + // For new posture checks, ensure no duplicates by name. + checks, err := transaction.GetAccountPostureChecks(ctx, LockingStrengthShare, accountID) + if err != nil { + return err + } - if isPeerInPolicySourceGroups(peer.ID, account, policy) { - addPolicyPostureChecks(account, policy, peerPostureChecks) + for _, check := range checks { + if check.Name == postureChecks.Name && check.ID != postureChecks.ID { + return status.Errorf(status.InvalidArgument, "posture checks with name %s already exists", postureChecks.Name) } } - postureChecksList := make([]*posture.Checks, 0, len(peerPostureChecks)) - for _, check := range peerPostureChecks { - checkCopy := check - postureChecksList = append(postureChecksList, &checkCopy) + postureChecks.ID = xid.New().String() + + return nil +} + +// addPolicyPostureChecks adds posture checks from a policy to the peer posture checks map if the peer is in the policy's source groups. +func addPolicyPostureChecks(ctx context.Context, transaction Store, accountID, peerID string, policy *Policy, peerPostureChecks map[string]*posture.Checks) error { + isInGroup, err := isPeerInPolicySourceGroups(ctx, transaction, accountID, peerID, policy) + if err != nil { + return err } - return postureChecksList + if !isInGroup { + return nil + } + + for _, sourcePostureCheckID := range policy.SourcePostureChecks { + postureCheck, err := transaction.GetPostureChecksByID(ctx, LockingStrengthShare, accountID, sourcePostureCheckID) + if err != nil { + return err + } + peerPostureChecks[sourcePostureCheckID] = postureCheck + } + + return nil } // isPeerInPolicySourceGroups checks if a peer is present in any of the policy rule source groups. -func isPeerInPolicySourceGroups(peerID string, account *Account, policy *Policy) bool { +func isPeerInPolicySourceGroups(ctx context.Context, transaction Store, accountID, peerID string, policy *Policy) (bool, error) { for _, rule := range policy.Rules { if !rule.Enabled { continue } for _, sourceGroup := range rule.Sources { - group, ok := account.Groups[sourceGroup] - if ok && slices.Contains(group.Peers, peerID) { - return true + group, err := transaction.GetGroupByID(ctx, LockingStrengthShare, accountID, sourceGroup) + if err != nil { + return false, fmt.Errorf("failed to check peer in policy source group: %w", err) + } + + if slices.Contains(group.Peers, peerID) { + return true, nil } } } - return false -} - -func addPolicyPostureChecks(account *Account, policy *Policy, peerPostureChecks map[string]posture.Checks) { - for _, sourcePostureCheckID := range policy.SourcePostureChecks { - for _, postureCheck := range account.PostureChecks { - if postureCheck.ID == sourcePostureCheckID { - peerPostureChecks[sourcePostureCheckID] = *postureCheck - } - } - } -} - -func isPostureCheckLinkedToPolicy(account *Account, postureChecksID string) (bool, *Policy) { - for _, policy := range account.Policies { - if slices.Contains(policy.SourcePostureChecks, postureChecksID) { - return true, policy - } - } return false, nil } -// arePostureCheckChangesAffectingPeers checks if the changes in posture checks are affecting peers. -func arePostureCheckChangesAffectingPeers(account *Account, postureCheckID string, exists bool) bool { - if !exists { - return false +// isPostureCheckLinkedToPolicy checks whether the posture check is linked to any account policy. +func isPostureCheckLinkedToPolicy(ctx context.Context, transaction Store, postureChecksID, accountID string) error { + policies, err := transaction.GetAccountPolicies(ctx, LockingStrengthShare, accountID) + if err != nil { + return err } - isLinked, linkedPolicy := isPostureCheckLinkedToPolicy(account, postureCheckID) - if !isLinked { - return false + for _, policy := range policies { + if slices.Contains(policy.SourcePostureChecks, postureChecksID) { + return status.Errorf(status.PreconditionFailed, "posture checks have been linked to policy: %s", policy.Name) + } } - return anyGroupHasPeers(account, linkedPolicy.ruleGroups()) + + return nil } diff --git a/management/server/posture_checks_test.go b/management/server/posture_checks_test.go index c63538b9d..3c5c5fc79 100644 --- a/management/server/posture_checks_test.go +++ b/management/server/posture_checks_test.go @@ -7,6 +7,7 @@ import ( "github.com/rs/xid" "github.com/stretchr/testify/assert" + "github.com/stretchr/testify/require" "github.com/netbirdio/netbird/management/server/group" @@ -16,7 +17,6 @@ import ( const ( adminUserID = "adminUserID" regularUserID = "regularUserID" - postureCheckID = "existing-id" postureCheckName = "Existing check" ) @@ -33,7 +33,7 @@ func TestDefaultAccountManager_PostureCheck(t *testing.T) { t.Run("Generic posture check flow", func(t *testing.T) { // regular users can not create checks - err := am.SavePostureChecks(context.Background(), account.Id, regularUserID, &posture.Checks{}) + _, err = am.SavePostureChecks(context.Background(), account.Id, regularUserID, &posture.Checks{}) assert.Error(t, err) // regular users cannot list check @@ -41,8 +41,7 @@ func TestDefaultAccountManager_PostureCheck(t *testing.T) { assert.Error(t, err) // should be possible to create posture check with uniq name - err = am.SavePostureChecks(context.Background(), account.Id, adminUserID, &posture.Checks{ - ID: postureCheckID, + postureCheck, err := am.SavePostureChecks(context.Background(), account.Id, adminUserID, &posture.Checks{ Name: postureCheckName, Checks: posture.ChecksDefinition{ NBVersionCheck: &posture.NBVersionCheck{ @@ -58,8 +57,7 @@ func TestDefaultAccountManager_PostureCheck(t *testing.T) { assert.Len(t, checks, 1) // should not be possible to create posture check with non uniq name - err = am.SavePostureChecks(context.Background(), account.Id, adminUserID, &posture.Checks{ - ID: "new-id", + _, err = am.SavePostureChecks(context.Background(), account.Id, adminUserID, &posture.Checks{ Name: postureCheckName, Checks: posture.ChecksDefinition{ GeoLocationCheck: &posture.GeoLocationCheck{ @@ -74,23 +72,20 @@ func TestDefaultAccountManager_PostureCheck(t *testing.T) { assert.Error(t, err) // admins can update posture checks - err = am.SavePostureChecks(context.Background(), account.Id, adminUserID, &posture.Checks{ - ID: postureCheckID, - Name: postureCheckName, - Checks: posture.ChecksDefinition{ - NBVersionCheck: &posture.NBVersionCheck{ - MinVersion: "0.27.0", - }, + postureCheck.Checks = posture.ChecksDefinition{ + NBVersionCheck: &posture.NBVersionCheck{ + MinVersion: "0.27.0", }, - }) + } + _, err = am.SavePostureChecks(context.Background(), account.Id, adminUserID, postureCheck) assert.NoError(t, err) // users should not be able to delete posture checks - err = am.DeletePostureChecks(context.Background(), account.Id, postureCheckID, regularUserID) + err = am.DeletePostureChecks(context.Background(), account.Id, postureCheck.ID, regularUserID) assert.Error(t, err) // admin should be able to delete posture checks - err = am.DeletePostureChecks(context.Background(), account.Id, postureCheckID, adminUserID) + err = am.DeletePostureChecks(context.Background(), account.Id, postureCheck.ID, adminUserID) assert.NoError(t, err) checks, err = am.ListPostureChecks(context.Background(), account.Id, adminUserID) assert.NoError(t, err) @@ -150,9 +145,22 @@ func TestPostureCheckAccountPeersUpdate(t *testing.T) { manager.peersUpdateManager.CloseChannel(context.Background(), peer1.ID) }) - postureCheck := posture.Checks{ - ID: "postureCheck", - Name: "postureCheck", + postureCheckA := &posture.Checks{ + Name: "postureCheckA", + AccountID: account.Id, + Checks: posture.ChecksDefinition{ + ProcessCheck: &posture.ProcessCheck{ + Processes: []posture.Process{ + {LinuxPath: "/usr/bin/netbird", MacPath: "/usr/local/bin/netbird"}, + }, + }, + }, + } + postureCheckA, err = manager.SavePostureChecks(context.Background(), account.Id, userID, postureCheckA) + require.NoError(t, err) + + postureCheckB := &posture.Checks{ + Name: "postureCheckB", AccountID: account.Id, Checks: posture.ChecksDefinition{ NBVersionCheck: &posture.NBVersionCheck{ @@ -169,7 +177,7 @@ func TestPostureCheckAccountPeersUpdate(t *testing.T) { close(done) }() - err := manager.SavePostureChecks(context.Background(), account.Id, userID, &postureCheck) + postureCheckB, err = manager.SavePostureChecks(context.Background(), account.Id, userID, postureCheckB) assert.NoError(t, err) select { @@ -187,12 +195,12 @@ func TestPostureCheckAccountPeersUpdate(t *testing.T) { close(done) }() - postureCheck.Checks = posture.ChecksDefinition{ + postureCheckB.Checks = posture.ChecksDefinition{ NBVersionCheck: &posture.NBVersionCheck{ MinVersion: "0.29.0", }, } - err := manager.SavePostureChecks(context.Background(), account.Id, userID, &postureCheck) + _, err = manager.SavePostureChecks(context.Background(), account.Id, userID, postureCheckB) assert.NoError(t, err) select { @@ -215,7 +223,7 @@ func TestPostureCheckAccountPeersUpdate(t *testing.T) { Action: PolicyTrafficActionAccept, }, }, - SourcePostureChecks: []string{postureCheck.ID}, + SourcePostureChecks: []string{postureCheckB.ID}, } // Linking posture check to policy should trigger update account peers and send peer update @@ -238,7 +246,7 @@ func TestPostureCheckAccountPeersUpdate(t *testing.T) { // Updating linked posture checks should update account peers and send peer update t.Run("updating linked to posture check with peers", func(t *testing.T) { - postureCheck.Checks = posture.ChecksDefinition{ + postureCheckB.Checks = posture.ChecksDefinition{ NBVersionCheck: &posture.NBVersionCheck{ MinVersion: "0.29.0", }, @@ -255,7 +263,7 @@ func TestPostureCheckAccountPeersUpdate(t *testing.T) { close(done) }() - err := manager.SavePostureChecks(context.Background(), account.Id, userID, &postureCheck) + _, err = manager.SavePostureChecks(context.Background(), account.Id, userID, postureCheckB) assert.NoError(t, err) select { @@ -293,7 +301,7 @@ func TestPostureCheckAccountPeersUpdate(t *testing.T) { close(done) }() - err := manager.DeletePostureChecks(context.Background(), account.Id, "postureCheck", userID) + err := manager.DeletePostureChecks(context.Background(), account.Id, postureCheckA.ID, userID) assert.NoError(t, err) select { @@ -303,7 +311,7 @@ func TestPostureCheckAccountPeersUpdate(t *testing.T) { } }) - err = manager.SavePostureChecks(context.Background(), account.Id, userID, &postureCheck) + _, err = manager.SavePostureChecks(context.Background(), account.Id, userID, postureCheckB) assert.NoError(t, err) // Updating linked posture check to policy with no peers should not trigger account peers update and not send peer update @@ -321,7 +329,7 @@ func TestPostureCheckAccountPeersUpdate(t *testing.T) { Action: PolicyTrafficActionAccept, }, }, - SourcePostureChecks: []string{postureCheck.ID}, + SourcePostureChecks: []string{postureCheckB.ID}, } err = manager.SavePolicy(context.Background(), account.Id, userID, &policy, false) assert.NoError(t, err) @@ -332,12 +340,12 @@ func TestPostureCheckAccountPeersUpdate(t *testing.T) { close(done) }() - postureCheck.Checks = posture.ChecksDefinition{ + postureCheckB.Checks = posture.ChecksDefinition{ NBVersionCheck: &posture.NBVersionCheck{ MinVersion: "0.29.0", }, } - err := manager.SavePostureChecks(context.Background(), account.Id, userID, &postureCheck) + _, err = manager.SavePostureChecks(context.Background(), account.Id, userID, postureCheckB) assert.NoError(t, err) select { @@ -367,7 +375,7 @@ func TestPostureCheckAccountPeersUpdate(t *testing.T) { Action: PolicyTrafficActionAccept, }, }, - SourcePostureChecks: []string{postureCheck.ID}, + SourcePostureChecks: []string{postureCheckB.ID}, } err = manager.SavePolicy(context.Background(), account.Id, userID, &policy, true) @@ -379,12 +387,12 @@ func TestPostureCheckAccountPeersUpdate(t *testing.T) { close(done) }() - postureCheck.Checks = posture.ChecksDefinition{ + postureCheckB.Checks = posture.ChecksDefinition{ NBVersionCheck: &posture.NBVersionCheck{ MinVersion: "0.29.0", }, } - err := manager.SavePostureChecks(context.Background(), account.Id, userID, &postureCheck) + _, err = manager.SavePostureChecks(context.Background(), account.Id, userID, postureCheckB) assert.NoError(t, err) select { @@ -409,7 +417,7 @@ func TestPostureCheckAccountPeersUpdate(t *testing.T) { Action: PolicyTrafficActionAccept, }, }, - SourcePostureChecks: []string{postureCheck.ID}, + SourcePostureChecks: []string{postureCheckB.ID}, } err = manager.SavePolicy(context.Background(), account.Id, userID, &policy, true) assert.NoError(t, err) @@ -420,7 +428,7 @@ func TestPostureCheckAccountPeersUpdate(t *testing.T) { close(done) }() - postureCheck.Checks = posture.ChecksDefinition{ + postureCheckB.Checks = posture.ChecksDefinition{ ProcessCheck: &posture.ProcessCheck{ Processes: []posture.Process{ { @@ -429,7 +437,7 @@ func TestPostureCheckAccountPeersUpdate(t *testing.T) { }, }, } - err := manager.SavePostureChecks(context.Background(), account.Id, userID, &postureCheck) + _, err = manager.SavePostureChecks(context.Background(), account.Id, userID, postureCheckB) assert.NoError(t, err) select { @@ -440,80 +448,123 @@ func TestPostureCheckAccountPeersUpdate(t *testing.T) { }) } -func TestArePostureCheckChangesAffectingPeers(t *testing.T) { - account := &Account{ - Policies: []*Policy{ - { - ID: "policyA", - Rules: []*PolicyRule{ - { - Enabled: true, - Sources: []string{"groupA"}, - Destinations: []string{"groupA"}, - }, - }, - SourcePostureChecks: []string{"checkA"}, - }, - }, - Groups: map[string]*group.Group{ - "groupA": { - ID: "groupA", - Peers: []string{"peer1"}, - }, - "groupB": { - ID: "groupB", - Peers: []string{}, - }, - }, - PostureChecks: []*posture.Checks{ - { - ID: "checkA", - }, - { - ID: "checkB", - }, - }, +func TestArePostureCheckChangesAffectPeers(t *testing.T) { + manager, err := createManager(t) + require.NoError(t, err, "failed to create account manager") + + account, err := initTestPostureChecksAccount(manager) + require.NoError(t, err, "failed to init testing account") + + groupA := &group.Group{ + ID: "groupA", + AccountID: account.Id, + Peers: []string{"peer1"}, } + groupB := &group.Group{ + ID: "groupB", + AccountID: account.Id, + Peers: []string{}, + } + err = manager.Store.SaveGroups(context.Background(), LockingStrengthUpdate, []*group.Group{groupA, groupB}) + require.NoError(t, err, "failed to save groups") + + postureCheckA := &posture.Checks{ + Name: "checkA", + AccountID: account.Id, + Checks: posture.ChecksDefinition{ + NBVersionCheck: &posture.NBVersionCheck{MinVersion: "0.33.1"}, + }, + } + postureCheckA, err = manager.SavePostureChecks(context.Background(), account.Id, adminUserID, postureCheckA) + require.NoError(t, err, "failed to save postureCheckA") + + postureCheckB := &posture.Checks{ + Name: "checkB", + AccountID: account.Id, + Checks: posture.ChecksDefinition{ + NBVersionCheck: &posture.NBVersionCheck{MinVersion: "0.33.1"}, + }, + } + postureCheckB, err = manager.SavePostureChecks(context.Background(), account.Id, adminUserID, postureCheckB) + require.NoError(t, err, "failed to save postureCheckB") + + policy := &Policy{ + ID: "policyA", + AccountID: account.Id, + Rules: []*PolicyRule{ + { + ID: "ruleA", + PolicyID: "policyA", + Enabled: true, + Sources: []string{"groupA"}, + Destinations: []string{"groupA"}, + }, + }, + SourcePostureChecks: []string{postureCheckA.ID}, + } + + err = manager.SavePolicy(context.Background(), account.Id, userID, policy, false) + require.NoError(t, err, "failed to save policy") + t.Run("posture check exists and is linked to policy with peers", func(t *testing.T) { - result := arePostureCheckChangesAffectingPeers(account, "checkA", true) + result, err := arePostureCheckChangesAffectPeers(context.Background(), manager.Store, account.Id, postureCheckA.ID) + require.NoError(t, err) assert.True(t, result) }) t.Run("posture check exists but is not linked to any policy", func(t *testing.T) { - result := arePostureCheckChangesAffectingPeers(account, "checkB", true) + result, err := arePostureCheckChangesAffectPeers(context.Background(), manager.Store, account.Id, postureCheckB.ID) + require.NoError(t, err) assert.False(t, result) }) t.Run("posture check does not exist", func(t *testing.T) { - result := arePostureCheckChangesAffectingPeers(account, "unknown", false) + result, err := arePostureCheckChangesAffectPeers(context.Background(), manager.Store, account.Id, "unknown") + require.NoError(t, err) assert.False(t, result) }) t.Run("posture check is linked to policy with no peers in source groups", func(t *testing.T) { - account.Policies[0].Rules[0].Sources = []string{"groupB"} - account.Policies[0].Rules[0].Destinations = []string{"groupA"} - result := arePostureCheckChangesAffectingPeers(account, "checkA", true) + policy.Rules[0].Sources = []string{"groupB"} + policy.Rules[0].Destinations = []string{"groupA"} + err = manager.SavePolicy(context.Background(), account.Id, userID, policy, true) + require.NoError(t, err, "failed to update policy") + + result, err := arePostureCheckChangesAffectPeers(context.Background(), manager.Store, account.Id, postureCheckA.ID) + require.NoError(t, err) assert.True(t, result) }) t.Run("posture check is linked to policy with no peers in destination groups", func(t *testing.T) { - account.Policies[0].Rules[0].Sources = []string{"groupA"} - account.Policies[0].Rules[0].Destinations = []string{"groupB"} - result := arePostureCheckChangesAffectingPeers(account, "checkA", true) + policy.Rules[0].Sources = []string{"groupA"} + policy.Rules[0].Destinations = []string{"groupB"} + err = manager.SavePolicy(context.Background(), account.Id, userID, policy, true) + require.NoError(t, err, "failed to update policy") + + result, err := arePostureCheckChangesAffectPeers(context.Background(), manager.Store, account.Id, postureCheckA.ID) + require.NoError(t, err) assert.True(t, result) }) - t.Run("posture check is linked to policy with non-existent group", func(t *testing.T) { - account.Policies[0].Rules[0].Sources = []string{"nonExistentGroup"} - account.Policies[0].Rules[0].Destinations = []string{"nonExistentGroup"} - result := arePostureCheckChangesAffectingPeers(account, "checkA", true) + t.Run("posture check is linked to policy but no peers in groups", func(t *testing.T) { + groupA.Peers = []string{} + err = manager.Store.SaveGroup(context.Background(), LockingStrengthUpdate, groupA) + require.NoError(t, err, "failed to save groups") + + result, err := arePostureCheckChangesAffectPeers(context.Background(), manager.Store, account.Id, postureCheckA.ID) + require.NoError(t, err) assert.False(t, result) }) - t.Run("posture check is linked to policy but no peers in groups", func(t *testing.T) { - account.Groups["groupA"].Peers = []string{} - result := arePostureCheckChangesAffectingPeers(account, "checkA", true) + t.Run("posture check is linked to policy with non-existent group", func(t *testing.T) { + policy.Rules[0].Sources = []string{"nonExistentGroup"} + policy.Rules[0].Destinations = []string{"nonExistentGroup"} + err = manager.SavePolicy(context.Background(), account.Id, userID, policy, true) + require.NoError(t, err, "failed to update policy") + + result, err := arePostureCheckChangesAffectPeers(context.Background(), manager.Store, account.Id, postureCheckA.ID) + require.NoError(t, err) assert.False(t, result) }) } diff --git a/management/server/route.go b/management/server/route.go index dcf2cb0d3..ecb562645 100644 --- a/management/server/route.go +++ b/management/server/route.go @@ -237,7 +237,7 @@ func (am *DefaultAccountManager) CreateRoute(ctx context.Context, accountID stri return nil, err } - if isRouteChangeAffectPeers(account, &newRoute) { + if am.isRouteChangeAffectPeers(account, &newRoute) { am.updateAccountPeers(ctx, accountID) } @@ -323,7 +323,7 @@ func (am *DefaultAccountManager) SaveRoute(ctx context.Context, accountID, userI return err } - if isRouteChangeAffectPeers(account, oldRoute) || isRouteChangeAffectPeers(account, routeToSave) { + if am.isRouteChangeAffectPeers(account, oldRoute) || am.isRouteChangeAffectPeers(account, routeToSave) { am.updateAccountPeers(ctx, accountID) } @@ -355,7 +355,7 @@ func (am *DefaultAccountManager) DeleteRoute(ctx context.Context, accountID stri am.StoreEvent(ctx, userID, string(routy.ID), accountID, activity.RouteRemoved, routy.EventMeta()) - if isRouteChangeAffectPeers(account, routy) { + if am.isRouteChangeAffectPeers(account, routy) { am.updateAccountPeers(ctx, accountID) } @@ -651,6 +651,6 @@ func getProtoPortInfo(rule *RouteFirewallRule) *proto.PortInfo { // isRouteChangeAffectPeers checks if a given route affects peers by determining // if it has a routing peer, distribution, or peer groups that include peers -func isRouteChangeAffectPeers(account *Account, route *route.Route) bool { - return anyGroupHasPeers(account, route.Groups) || anyGroupHasPeers(account, route.PeerGroups) || route.Peer != "" +func (am *DefaultAccountManager) isRouteChangeAffectPeers(account *Account, route *route.Route) bool { + return am.anyGroupHasPeers(account, route.Groups) || am.anyGroupHasPeers(account, route.PeerGroups) || route.Peer != "" } diff --git a/management/server/sql_store.go b/management/server/sql_store.go index 278f5443d..47c17bb92 100644 --- a/management/server/sql_store.go +++ b/management/server/sql_store.go @@ -1305,12 +1305,57 @@ func (s *SqlStore) GetPolicyByID(ctx context.Context, lockStrength LockingStreng // GetAccountPostureChecks retrieves posture checks for an account. func (s *SqlStore) GetAccountPostureChecks(ctx context.Context, lockStrength LockingStrength, accountID string) ([]*posture.Checks, error) { - return getRecords[*posture.Checks](s.db, lockStrength, accountID) + var postureChecks []*posture.Checks + result := s.db.Clauses(clause.Locking{Strength: string(lockStrength)}).Find(&postureChecks, accountIDCondition, accountID) + if result.Error != nil { + log.WithContext(ctx).Errorf("failed to get posture checks from store: %s", result.Error) + return nil, status.Errorf(status.Internal, "failed to get posture checks from store") + } + + return postureChecks, nil } // GetPostureChecksByID retrieves posture checks by their ID and account ID. -func (s *SqlStore) GetPostureChecksByID(ctx context.Context, lockStrength LockingStrength, postureCheckID string, accountID string) (*posture.Checks, error) { - return getRecordByID[posture.Checks](s.db, lockStrength, postureCheckID, accountID) +func (s *SqlStore) GetPostureChecksByID(ctx context.Context, lockStrength LockingStrength, accountID, postureChecksID string) (*posture.Checks, error) { + var postureCheck *posture.Checks + result := s.db.Clauses(clause.Locking{Strength: string(lockStrength)}). + First(&postureCheck, accountAndIDQueryCondition, accountID, postureChecksID) + if result.Error != nil { + if errors.Is(result.Error, gorm.ErrRecordNotFound) { + return nil, status.NewPostureChecksNotFoundError(postureChecksID) + } + log.WithContext(ctx).Errorf("failed to get posture check from store: %s", result.Error) + return nil, status.Errorf(status.Internal, "failed to get posture check from store") + } + + return postureCheck, nil +} + +// SavePostureChecks saves a posture checks to the database. +func (s *SqlStore) SavePostureChecks(ctx context.Context, lockStrength LockingStrength, postureCheck *posture.Checks) error { + result := s.db.Clauses(clause.Locking{Strength: string(lockStrength)}).Save(postureCheck) + if result.Error != nil { + log.WithContext(ctx).Errorf("failed to save posture checks to store: %s", result.Error) + return status.Errorf(status.Internal, "failed to save posture checks to store") + } + + return nil +} + +// DeletePostureChecks deletes a posture checks from the database. +func (s *SqlStore) DeletePostureChecks(ctx context.Context, lockStrength LockingStrength, accountID, postureChecksID string) error { + result := s.db.Clauses(clause.Locking{Strength: string(lockStrength)}). + Delete(&posture.Checks{}, accountAndIDQueryCondition, accountID, postureChecksID) + if result.Error != nil { + log.WithContext(ctx).Errorf("failed to delete posture checks from store: %s", result.Error) + return status.Errorf(status.Internal, "failed to delete posture checks from store") + } + + if result.RowsAffected == 0 { + return status.NewPostureChecksNotFoundError(postureChecksID) + } + + return nil } // GetAccountRoutes retrieves network routes for an account. diff --git a/management/server/sql_store_test.go b/management/server/sql_store_test.go index 114da1ee6..de939e8d0 100644 --- a/management/server/sql_store_test.go +++ b/management/server/sql_store_test.go @@ -16,6 +16,7 @@ import ( "github.com/google/uuid" nbdns "github.com/netbirdio/netbird/dns" nbgroup "github.com/netbirdio/netbird/management/server/group" + "github.com/netbirdio/netbird/management/server/posture" "github.com/stretchr/testify/assert" "github.com/stretchr/testify/require" @@ -1564,3 +1565,137 @@ func TestSqlStore_GetPeersByIDs(t *testing.T) { }) } } + +func TestSqlStore_GetPostureChecksByID(t *testing.T) { + store, cleanup, err := NewTestStoreFromSQL(context.Background(), "testdata/extended-store.sql", t.TempDir()) + t.Cleanup(cleanup) + require.NoError(t, err) + + accountID := "bf1c8084-ba50-4ce7-9439-34653001fc3b" + tests := []struct { + name string + postureChecksID string + expectError bool + }{ + { + name: "retrieve existing posture checks", + postureChecksID: "csplshq7qv948l48f7t0", + expectError: false, + }, + { + name: "retrieve non-existing posture checks", + postureChecksID: "non-existing", + expectError: true, + }, + { + name: "retrieve with empty posture checks ID", + postureChecksID: "", + expectError: true, + }, + } + + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + postureChecks, err := store.GetPostureChecksByID(context.Background(), LockingStrengthShare, accountID, tt.postureChecksID) + if tt.expectError { + require.Error(t, err) + sErr, ok := status.FromError(err) + require.True(t, ok) + require.Equal(t, sErr.Type(), status.NotFound) + require.Nil(t, postureChecks) + } else { + require.NoError(t, err) + require.NotNil(t, postureChecks) + require.Equal(t, tt.postureChecksID, postureChecks.ID) + } + }) + } +} + +func TestSqlStore_SavePostureChecks(t *testing.T) { + store, cleanup, err := NewTestStoreFromSQL(context.Background(), "testdata/extended-store.sql", t.TempDir()) + t.Cleanup(cleanup) + require.NoError(t, err) + + accountID := "bf1c8084-ba50-4ce7-9439-34653001fc3b" + + postureChecks := &posture.Checks{ + ID: "posture-checks-id", + AccountID: accountID, + Checks: posture.ChecksDefinition{ + NBVersionCheck: &posture.NBVersionCheck{ + MinVersion: "0.31.0", + }, + OSVersionCheck: &posture.OSVersionCheck{ + Ios: &posture.MinVersionCheck{ + MinVersion: "13.0.1", + }, + Linux: &posture.MinKernelVersionCheck{ + MinKernelVersion: "5.3.3-dev", + }, + }, + GeoLocationCheck: &posture.GeoLocationCheck{ + Locations: []posture.Location{ + { + CountryCode: "DE", + CityName: "Berlin", + }, + }, + Action: posture.CheckActionAllow, + }, + }, + } + err = store.SavePostureChecks(context.Background(), LockingStrengthUpdate, postureChecks) + require.NoError(t, err) + + savePostureChecks, err := store.GetPostureChecksByID(context.Background(), LockingStrengthShare, accountID, "posture-checks-id") + require.NoError(t, err) + require.Equal(t, savePostureChecks, postureChecks) +} + +func TestSqlStore_DeletePostureChecks(t *testing.T) { + store, cleanup, err := NewTestStoreFromSQL(context.Background(), "testdata/extended-store.sql", t.TempDir()) + t.Cleanup(cleanup) + require.NoError(t, err) + + accountID := "bf1c8084-ba50-4ce7-9439-34653001fc3b" + + tests := []struct { + name string + postureChecksID string + expectError bool + }{ + { + name: "delete existing posture checks", + postureChecksID: "csplshq7qv948l48f7t0", + expectError: false, + }, + { + name: "delete non-existing posture checks", + postureChecksID: "non-existing-posture-checks-id", + expectError: true, + }, + { + name: "delete with empty posture checks ID", + postureChecksID: "", + expectError: true, + }, + } + + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + err = store.DeletePostureChecks(context.Background(), LockingStrengthUpdate, accountID, tt.postureChecksID) + if tt.expectError { + require.Error(t, err) + sErr, ok := status.FromError(err) + require.True(t, ok) + require.Equal(t, sErr.Type(), status.NotFound) + } else { + require.NoError(t, err) + group, err := store.GetPostureChecksByID(context.Background(), LockingStrengthShare, accountID, tt.postureChecksID) + require.Error(t, err) + require.Nil(t, group) + } + }) + } +} diff --git a/management/server/status/error.go b/management/server/status/error.go index 8b6d0077b..44391e1f1 100644 --- a/management/server/status/error.go +++ b/management/server/status/error.go @@ -139,3 +139,8 @@ func NewGetAccountError(err error) error { func NewGroupNotFoundError(groupID string) error { return Errorf(NotFound, "group: %s not found", groupID) } + +// NewPostureChecksNotFoundError creates a new Error with NotFound type for a missing posture checks +func NewPostureChecksNotFoundError(postureChecksID string) error { + return Errorf(NotFound, "posture checks: %s not found", postureChecksID) +} diff --git a/management/server/store.go b/management/server/store.go index 71b0d457b..03b5821e7 100644 --- a/management/server/store.go +++ b/management/server/store.go @@ -84,7 +84,9 @@ type Store interface { GetPostureCheckByChecksDefinition(accountID string, checks *posture.ChecksDefinition) (*posture.Checks, error) GetAccountPostureChecks(ctx context.Context, lockStrength LockingStrength, accountID string) ([]*posture.Checks, error) - GetPostureChecksByID(ctx context.Context, lockStrength LockingStrength, postureCheckID string, accountID string) (*posture.Checks, error) + GetPostureChecksByID(ctx context.Context, lockStrength LockingStrength, accountID, postureCheckID string) (*posture.Checks, error) + SavePostureChecks(ctx context.Context, lockStrength LockingStrength, postureCheck *posture.Checks) error + DeletePostureChecks(ctx context.Context, lockStrength LockingStrength, accountID, postureChecksID string) error GetPeerLabelsInAccount(ctx context.Context, lockStrength LockingStrength, accountId string) ([]string, error) AddPeerToAllGroup(ctx context.Context, accountID string, peerID string) error diff --git a/management/server/testdata/extended-store.sql b/management/server/testdata/extended-store.sql index b522741e7..1646ff4da 100644 --- a/management/server/testdata/extended-store.sql +++ b/management/server/testdata/extended-store.sql @@ -34,4 +34,5 @@ INSERT INTO personal_access_tokens VALUES('9dj38s35-63fb-11ec-90d6-0242ac120003' INSERT INTO "groups" VALUES('cfefqs706sqkneg59g4g','bf1c8084-ba50-4ce7-9439-34653001fc3b','All','api','[]',0,''); INSERT INTO "groups" VALUES('cfefqs706sqkneg59g3g','bf1c8084-ba50-4ce7-9439-34653001fc3b','AwesomeGroup1','api','[]',0,''); INSERT INTO "groups" VALUES('cfefqs706sqkneg59g2g','bf1c8084-ba50-4ce7-9439-34653001fc3b','AwesomeGroup2','api','[]',0,''); +INSERT INTO posture_checks VALUES('csplshq7qv948l48f7t0','NetBird Version > 0.32.0','','bf1c8084-ba50-4ce7-9439-34653001fc3b','{"NBVersionCheck":{"MinVersion":"0.31.0"}}'); INSERT INTO installations VALUES(1,''); From f118d81d3219b78b689206523e4159fcb495fa12 Mon Sep 17 00:00:00 2001 From: Bethuel Mmbaga Date: Tue, 26 Nov 2024 12:46:05 +0300 Subject: [PATCH 38/39] [management] Refactor policy to use store methods (#2878) --- management/server/account.go | 2 +- management/server/account_test.go | 57 ++-- management/server/group_test.go | 5 +- management/server/http/policies_handler.go | 26 +- .../server/http/policies_handler_test.go | 4 +- management/server/mock_server/account_mock.go | 8 +- management/server/peer_test.go | 31 +- management/server/policy.go | 319 +++++++++++------- management/server/policy_test.go | 165 +++------ management/server/posture_checks_test.go | 43 +-- management/server/route_test.go | 3 +- management/server/setupkey_test.go | 5 +- management/server/sql_store.go | 82 ++++- management/server/sql_store_test.go | 158 +++++++++ management/server/status/error.go | 5 + management/server/store.go | 6 +- management/server/testdata/extended-store.sql | 1 + management/server/user_test.go | 5 +- 18 files changed, 576 insertions(+), 349 deletions(-) diff --git a/management/server/account.go b/management/server/account.go index 9fb56c855..fbe6fcc1a 100644 --- a/management/server/account.go +++ b/management/server/account.go @@ -113,7 +113,7 @@ type AccountManager interface { GroupAddPeer(ctx context.Context, accountId, groupID, peerID string) error GroupDeletePeer(ctx context.Context, accountId, groupID, peerID string) error GetPolicy(ctx context.Context, accountID, policyID, userID string) (*Policy, error) - SavePolicy(ctx context.Context, accountID, userID string, policy *Policy, isUpdate bool) error + SavePolicy(ctx context.Context, accountID, userID string, policy *Policy) (*Policy, error) DeletePolicy(ctx context.Context, accountID, policyID, userID string) error ListPolicies(ctx context.Context, accountID, userID string) ([]*Policy, error) GetRoute(ctx context.Context, accountID string, routeID route.ID, userID string) (*route.Route, error) diff --git a/management/server/account_test.go b/management/server/account_test.go index 97e0d45f0..c8c2d5941 100644 --- a/management/server/account_test.go +++ b/management/server/account_test.go @@ -1238,8 +1238,7 @@ func TestAccountManager_NetworkUpdates_SaveGroup(t *testing.T) { return } - policy := Policy{ - ID: "policy", + _, err := manager.SavePolicy(context.Background(), account.Id, userID, &Policy{ Enabled: true, Rules: []*PolicyRule{ { @@ -1250,8 +1249,7 @@ func TestAccountManager_NetworkUpdates_SaveGroup(t *testing.T) { Action: PolicyTrafficActionAccept, }, }, - } - err := manager.SavePolicy(context.Background(), account.Id, userID, &policy, false) + }) require.NoError(t, err) updMsg := manager.peersUpdateManager.CreateChannel(context.Background(), peer1.ID) @@ -1320,19 +1318,6 @@ func TestAccountManager_NetworkUpdates_SavePolicy(t *testing.T) { updMsg := manager.peersUpdateManager.CreateChannel(context.Background(), peer1.ID) defer manager.peersUpdateManager.CloseChannel(context.Background(), peer1.ID) - policy := Policy{ - Enabled: true, - Rules: []*PolicyRule{ - { - Enabled: true, - Sources: []string{"groupA"}, - Destinations: []string{"groupA"}, - Bidirectional: true, - Action: PolicyTrafficActionAccept, - }, - }, - } - wg := sync.WaitGroup{} wg.Add(1) go func() { @@ -1345,7 +1330,19 @@ func TestAccountManager_NetworkUpdates_SavePolicy(t *testing.T) { } }() - if err := manager.SavePolicy(context.Background(), account.Id, userID, &policy, false); err != nil { + _, err := manager.SavePolicy(context.Background(), account.Id, userID, &Policy{ + Enabled: true, + Rules: []*PolicyRule{ + { + Enabled: true, + Sources: []string{"groupA"}, + Destinations: []string{"groupA"}, + Bidirectional: true, + Action: PolicyTrafficActionAccept, + }, + }, + }) + if err != nil { t.Errorf("delete default rule: %v", err) return } @@ -1366,7 +1363,7 @@ func TestAccountManager_NetworkUpdates_DeletePeer(t *testing.T) { return } - policy := Policy{ + _, err := manager.SavePolicy(context.Background(), account.Id, userID, &Policy{ Enabled: true, Rules: []*PolicyRule{ { @@ -1377,9 +1374,8 @@ func TestAccountManager_NetworkUpdates_DeletePeer(t *testing.T) { Action: PolicyTrafficActionAccept, }, }, - } - - if err := manager.SavePolicy(context.Background(), account.Id, userID, &policy, false); err != nil { + }) + if err != nil { t.Errorf("save policy: %v", err) return } @@ -1421,7 +1417,12 @@ func TestAccountManager_NetworkUpdates_DeleteGroup(t *testing.T) { require.NoError(t, err, "failed to save group") - policy := Policy{ + if err := manager.DeletePolicy(context.Background(), account.Id, account.Policies[0].ID, userID); err != nil { + t.Errorf("delete default rule: %v", err) + return + } + + policy, err := manager.SavePolicy(context.Background(), account.Id, userID, &Policy{ Enabled: true, Rules: []*PolicyRule{ { @@ -1432,14 +1433,8 @@ func TestAccountManager_NetworkUpdates_DeleteGroup(t *testing.T) { Action: PolicyTrafficActionAccept, }, }, - } - - if err := manager.DeletePolicy(context.Background(), account.Id, account.Policies[0].ID, userID); err != nil { - t.Errorf("delete default rule: %v", err) - return - } - - if err := manager.SavePolicy(context.Background(), account.Id, userID, &policy, false); err != nil { + }) + if err != nil { t.Errorf("save policy: %v", err) return } diff --git a/management/server/group_test.go b/management/server/group_test.go index 59094a23e..ec017fc57 100644 --- a/management/server/group_test.go +++ b/management/server/group_test.go @@ -500,8 +500,7 @@ func TestGroupAccountPeersUpdate(t *testing.T) { }) // adding a group to policy - err = manager.SavePolicy(context.Background(), account.Id, userID, &Policy{ - ID: "policy", + _, err = manager.SavePolicy(context.Background(), account.Id, userID, &Policy{ Enabled: true, Rules: []*PolicyRule{ { @@ -512,7 +511,7 @@ func TestGroupAccountPeersUpdate(t *testing.T) { Action: PolicyTrafficActionAccept, }, }, - }, false) + }) assert.NoError(t, err) // Saving a group linked to policy should update account peers and send peer update diff --git a/management/server/http/policies_handler.go b/management/server/http/policies_handler.go index 73f3803b5..eff9092d4 100644 --- a/management/server/http/policies_handler.go +++ b/management/server/http/policies_handler.go @@ -6,10 +6,8 @@ import ( "strconv" "github.com/gorilla/mux" - nbgroup "github.com/netbirdio/netbird/management/server/group" - "github.com/rs/xid" - "github.com/netbirdio/netbird/management/server" + nbgroup "github.com/netbirdio/netbird/management/server/group" "github.com/netbirdio/netbird/management/server/http/api" "github.com/netbirdio/netbird/management/server/http/util" "github.com/netbirdio/netbird/management/server/jwtclaims" @@ -122,21 +120,22 @@ func (h *Policies) savePolicy(w http.ResponseWriter, r *http.Request, accountID return } - isUpdate := policyID != "" - - if policyID == "" { - policyID = xid.New().String() - } - - policy := server.Policy{ + policy := &server.Policy{ ID: policyID, + AccountID: accountID, Name: req.Name, Enabled: req.Enabled, Description: req.Description, } for _, rule := range req.Rules { + var ruleID string + if rule.Id != nil { + ruleID = *rule.Id + } + pr := server.PolicyRule{ - ID: policyID, // TODO: when policy can contain multiple rules, need refactor + ID: ruleID, + PolicyID: policyID, Name: rule.Name, Destinations: rule.Destinations, Sources: rule.Sources, @@ -225,7 +224,8 @@ func (h *Policies) savePolicy(w http.ResponseWriter, r *http.Request, accountID policy.SourcePostureChecks = *req.SourcePostureChecks } - if err := h.accountManager.SavePolicy(r.Context(), accountID, userID, &policy, isUpdate); err != nil { + policy, err := h.accountManager.SavePolicy(r.Context(), accountID, userID, policy) + if err != nil { util.WriteError(r.Context(), err, w) return } @@ -236,7 +236,7 @@ func (h *Policies) savePolicy(w http.ResponseWriter, r *http.Request, accountID return } - resp := toPolicyResponse(allGroups, &policy) + resp := toPolicyResponse(allGroups, policy) if len(resp.Rules) == 0 { util.WriteError(r.Context(), status.Errorf(status.Internal, "no rules in the policy"), w) return diff --git a/management/server/http/policies_handler_test.go b/management/server/http/policies_handler_test.go index 228ebcbce..f8a897eb2 100644 --- a/management/server/http/policies_handler_test.go +++ b/management/server/http/policies_handler_test.go @@ -38,12 +38,12 @@ func initPoliciesTestData(policies ...*server.Policy) *Policies { } return policy, nil }, - SavePolicyFunc: func(_ context.Context, _, _ string, policy *server.Policy, _ bool) error { + SavePolicyFunc: func(_ context.Context, _, _ string, policy *server.Policy) (*server.Policy, error) { if !strings.HasPrefix(policy.ID, "id-") { policy.ID = "id-was-set" policy.Rules[0].ID = "id-was-set" } - return nil + return policy, nil }, GetAllGroupsFunc: func(ctx context.Context, accountID, userID string) ([]*nbgroup.Group, error) { return []*nbgroup.Group{{ID: "F"}, {ID: "G"}}, nil diff --git a/management/server/mock_server/account_mock.go b/management/server/mock_server/account_mock.go index 673ed33bb..46a4fbc1f 100644 --- a/management/server/mock_server/account_mock.go +++ b/management/server/mock_server/account_mock.go @@ -49,7 +49,7 @@ type MockAccountManager struct { GroupDeletePeerFunc func(ctx context.Context, accountID, groupID, peerID string) error DeleteRuleFunc func(ctx context.Context, accountID, ruleID, userID string) error GetPolicyFunc func(ctx context.Context, accountID, policyID, userID string) (*server.Policy, error) - SavePolicyFunc func(ctx context.Context, accountID, userID string, policy *server.Policy, isUpdate bool) error + SavePolicyFunc func(ctx context.Context, accountID, userID string, policy *server.Policy) (*server.Policy, error) DeletePolicyFunc func(ctx context.Context, accountID, policyID, userID string) error ListPoliciesFunc func(ctx context.Context, accountID, userID string) ([]*server.Policy, error) GetUsersFromAccountFunc func(ctx context.Context, accountID, userID string) ([]*server.UserInfo, error) @@ -386,11 +386,11 @@ func (am *MockAccountManager) GetPolicy(ctx context.Context, accountID, policyID } // SavePolicy mock implementation of SavePolicy from server.AccountManager interface -func (am *MockAccountManager) SavePolicy(ctx context.Context, accountID, userID string, policy *server.Policy, isUpdate bool) error { +func (am *MockAccountManager) SavePolicy(ctx context.Context, accountID, userID string, policy *server.Policy) (*server.Policy, error) { if am.SavePolicyFunc != nil { - return am.SavePolicyFunc(ctx, accountID, userID, policy, isUpdate) + return am.SavePolicyFunc(ctx, accountID, userID, policy) } - return status.Errorf(codes.Unimplemented, "method SavePolicy is not implemented") + return nil, status.Errorf(codes.Unimplemented, "method SavePolicy is not implemented") } // DeletePolicy mock implementation of DeletePolicy from server.AccountManager interface diff --git a/management/server/peer_test.go b/management/server/peer_test.go index 4e2dcb2c3..e410fa892 100644 --- a/management/server/peer_test.go +++ b/management/server/peer_test.go @@ -283,14 +283,12 @@ func TestAccountManager_GetNetworkMapWithPolicy(t *testing.T) { var ( group1 nbgroup.Group group2 nbgroup.Group - policy Policy ) group1.ID = xid.New().String() group2.ID = xid.New().String() group1.Name = "src" group2.Name = "dst" - policy.ID = xid.New().String() group1.Peers = append(group1.Peers, peer1.ID) group2.Peers = append(group2.Peers, peer2.ID) @@ -305,18 +303,20 @@ func TestAccountManager_GetNetworkMapWithPolicy(t *testing.T) { return } - policy.Name = "test" - policy.Enabled = true - policy.Rules = []*PolicyRule{ - { - Enabled: true, - Sources: []string{group1.ID}, - Destinations: []string{group2.ID}, - Bidirectional: true, - Action: PolicyTrafficActionAccept, + policy := &Policy{ + Name: "test", + Enabled: true, + Rules: []*PolicyRule{ + { + Enabled: true, + Sources: []string{group1.ID}, + Destinations: []string{group2.ID}, + Bidirectional: true, + Action: PolicyTrafficActionAccept, + }, }, } - err = manager.SavePolicy(context.Background(), account.Id, userID, &policy, false) + policy, err = manager.SavePolicy(context.Background(), account.Id, userID, policy) if err != nil { t.Errorf("expecting rule to be added, got failure %v", err) return @@ -364,7 +364,7 @@ func TestAccountManager_GetNetworkMapWithPolicy(t *testing.T) { } policy.Enabled = false - err = manager.SavePolicy(context.Background(), account.Id, userID, &policy, true) + _, err = manager.SavePolicy(context.Background(), account.Id, userID, policy) if err != nil { t.Errorf("expecting rule to be added, got failure %v", err) return @@ -1445,8 +1445,7 @@ func TestPeerAccountPeersUpdate(t *testing.T) { // Adding peer to group linked with policy should update account peers and send peer update t.Run("adding peer to group linked with policy", func(t *testing.T) { - err = manager.SavePolicy(context.Background(), account.Id, userID, &Policy{ - ID: "policy", + _, err = manager.SavePolicy(context.Background(), account.Id, userID, &Policy{ Enabled: true, Rules: []*PolicyRule{ { @@ -1457,7 +1456,7 @@ func TestPeerAccountPeersUpdate(t *testing.T) { Action: PolicyTrafficActionAccept, }, }, - }, false) + }) require.NoError(t, err) done := make(chan struct{}) diff --git a/management/server/policy.go b/management/server/policy.go index c7872591d..2d3abc3f1 100644 --- a/management/server/policy.go +++ b/management/server/policy.go @@ -3,13 +3,13 @@ package server import ( "context" _ "embed" - "slices" "strconv" "strings" + "github.com/netbirdio/netbird/management/proto" + "github.com/rs/xid" log "github.com/sirupsen/logrus" - "github.com/netbirdio/netbird/management/proto" "github.com/netbirdio/netbird/management/server/activity" nbgroup "github.com/netbirdio/netbird/management/server/group" nbpeer "github.com/netbirdio/netbird/management/server/peer" @@ -125,6 +125,7 @@ type PolicyRule struct { func (pm *PolicyRule) Copy() *PolicyRule { rule := &PolicyRule{ ID: pm.ID, + PolicyID: pm.PolicyID, Name: pm.Name, Description: pm.Description, Enabled: pm.Enabled, @@ -171,6 +172,7 @@ type Policy struct { func (p *Policy) Copy() *Policy { c := &Policy{ ID: p.ID, + AccountID: p.AccountID, Name: p.Name, Description: p.Description, Enabled: p.Enabled, @@ -343,157 +345,207 @@ func (am *DefaultAccountManager) GetPolicy(ctx context.Context, accountID, polic return nil, err } - if !user.IsAdminOrServiceUser() || user.AccountID != accountID { - return nil, status.Errorf(status.PermissionDenied, "only users with admin power are allowed to view policies") + if user.AccountID != accountID { + return nil, status.NewUserNotPartOfAccountError() } - return am.Store.GetPolicyByID(ctx, LockingStrengthShare, policyID, accountID) + if user.IsRegularUser() { + return nil, status.NewAdminPermissionError() + } + + return am.Store.GetPolicyByID(ctx, LockingStrengthShare, accountID, policyID) } // SavePolicy in the store -func (am *DefaultAccountManager) SavePolicy(ctx context.Context, accountID, userID string, policy *Policy, isUpdate bool) error { +func (am *DefaultAccountManager) SavePolicy(ctx context.Context, accountID, userID string, policy *Policy) (*Policy, error) { unlock := am.Store.AcquireWriteLockByUID(ctx, accountID) defer unlock() - account, err := am.Store.GetAccount(ctx, accountID) + user, err := am.Store.GetUserByUserID(ctx, LockingStrengthShare, userID) if err != nil { - return err + return nil, err } - updateAccountPeers, err := am.savePolicy(account, policy, isUpdate) + if user.AccountID != accountID { + return nil, status.NewUserNotPartOfAccountError() + } + + if user.IsRegularUser() { + return nil, status.NewAdminPermissionError() + } + + var isUpdate = policy.ID != "" + var updateAccountPeers bool + var action = activity.PolicyAdded + + err = am.Store.ExecuteInTransaction(ctx, func(transaction Store) error { + if err = validatePolicy(ctx, transaction, accountID, policy); err != nil { + return err + } + + updateAccountPeers, err = arePolicyChangesAffectPeers(ctx, transaction, accountID, policy, isUpdate) + if err != nil { + return err + } + + if err = transaction.IncrementNetworkSerial(ctx, LockingStrengthUpdate, accountID); err != nil { + return err + } + + saveFunc := transaction.CreatePolicy + if isUpdate { + action = activity.PolicyUpdated + saveFunc = transaction.SavePolicy + } + + return saveFunc(ctx, LockingStrengthUpdate, policy) + }) if err != nil { - return err + return nil, err } - account.Network.IncSerial() - if err = am.Store.SaveAccount(ctx, account); err != nil { - return err - } - - action := activity.PolicyAdded - if isUpdate { - action = activity.PolicyUpdated - } am.StoreEvent(ctx, userID, policy.ID, accountID, action, policy.EventMeta()) if updateAccountPeers { am.updateAccountPeers(ctx, accountID) } + return policy, nil +} + +// DeletePolicy from the store +func (am *DefaultAccountManager) DeletePolicy(ctx context.Context, accountID, policyID, userID string) error { + unlock := am.Store.AcquireWriteLockByUID(ctx, accountID) + defer unlock() + + user, err := am.Store.GetUserByUserID(ctx, LockingStrengthShare, userID) + if err != nil { + return err + } + + if user.AccountID != accountID { + return status.NewUserNotPartOfAccountError() + } + + if user.IsRegularUser() { + return status.NewAdminPermissionError() + } + + var policy *Policy + var updateAccountPeers bool + + err = am.Store.ExecuteInTransaction(ctx, func(transaction Store) error { + policy, err = transaction.GetPolicyByID(ctx, LockingStrengthUpdate, accountID, policyID) + if err != nil { + return err + } + + updateAccountPeers, err = arePolicyChangesAffectPeers(ctx, transaction, accountID, policy, false) + if err != nil { + return err + } + + if err = transaction.IncrementNetworkSerial(ctx, LockingStrengthUpdate, accountID); err != nil { + return err + } + + return transaction.DeletePolicy(ctx, LockingStrengthUpdate, accountID, policyID) + }) + if err != nil { + return err + } + + am.StoreEvent(ctx, userID, policyID, accountID, activity.PolicyRemoved, policy.EventMeta()) + + if updateAccountPeers { + am.updateAccountPeers(ctx, accountID) + } + return nil } -// DeletePolicy from the store -func (am *DefaultAccountManager) DeletePolicy(ctx context.Context, accountID, policyID, userID string) error { - unlock := am.Store.AcquireWriteLockByUID(ctx, accountID) - defer unlock() - - account, err := am.Store.GetAccount(ctx, accountID) - if err != nil { - return err - } - - policy, err := am.deletePolicy(account, policyID) - if err != nil { - return err - } - - account.Network.IncSerial() - if err = am.Store.SaveAccount(ctx, account); err != nil { - return err - } - - am.StoreEvent(ctx, userID, policy.ID, accountID, activity.PolicyRemoved, policy.EventMeta()) - - if am.anyGroupHasPeers(account, policy.ruleGroups()) { - am.updateAccountPeers(ctx, accountID) - } - - return nil -} - -// ListPolicies from the store +// ListPolicies from the store. func (am *DefaultAccountManager) ListPolicies(ctx context.Context, accountID, userID string) ([]*Policy, error) { user, err := am.Store.GetUserByUserID(ctx, LockingStrengthShare, userID) if err != nil { return nil, err } - if !user.IsAdminOrServiceUser() || user.AccountID != accountID { - return nil, status.Errorf(status.PermissionDenied, "only users with admin power are allowed to view policies") + if user.AccountID != accountID { + return nil, status.NewUserNotPartOfAccountError() + } + + if user.IsRegularUser() { + return nil, status.NewAdminPermissionError() } return am.Store.GetAccountPolicies(ctx, LockingStrengthShare, accountID) } -func (am *DefaultAccountManager) deletePolicy(account *Account, policyID string) (*Policy, error) { - policyIdx := -1 - for i, policy := range account.Policies { - if policy.ID == policyID { - policyIdx = i - break - } - } - if policyIdx < 0 { - return nil, status.Errorf(status.NotFound, "rule with ID %s doesn't exist", policyID) - } - - policy := account.Policies[policyIdx] - account.Policies = append(account.Policies[:policyIdx], account.Policies[policyIdx+1:]...) - return policy, nil -} - -// savePolicy saves or updates a policy in the given account. -// If isUpdate is true, the function updates the existing policy; otherwise, it adds a new policy. -func (am *DefaultAccountManager) savePolicy(account *Account, policyToSave *Policy, isUpdate bool) (bool, error) { - for index, rule := range policyToSave.Rules { - rule.Sources = filterValidGroupIDs(account, rule.Sources) - rule.Destinations = filterValidGroupIDs(account, rule.Destinations) - policyToSave.Rules[index] = rule - } - - if policyToSave.SourcePostureChecks != nil { - policyToSave.SourcePostureChecks = filterValidPostureChecks(account, policyToSave.SourcePostureChecks) - } - +// arePolicyChangesAffectPeers checks if changes to a policy will affect any associated peers. +func arePolicyChangesAffectPeers(ctx context.Context, transaction Store, accountID string, policy *Policy, isUpdate bool) (bool, error) { if isUpdate { - policyIdx := slices.IndexFunc(account.Policies, func(policy *Policy) bool { return policy.ID == policyToSave.ID }) - if policyIdx < 0 { - return false, status.Errorf(status.NotFound, "couldn't find policy id %s", policyToSave.ID) + existingPolicy, err := transaction.GetPolicyByID(ctx, LockingStrengthShare, accountID, policy.ID) + if err != nil { + return false, err } - oldPolicy := account.Policies[policyIdx] - // Update the existing policy - account.Policies[policyIdx] = policyToSave - - if !policyToSave.Enabled && !oldPolicy.Enabled { + if !policy.Enabled && !existingPolicy.Enabled { return false, nil } - updateAccountPeers := am.anyGroupHasPeers(account, oldPolicy.ruleGroups()) || am.anyGroupHasPeers(account, policyToSave.ruleGroups()) - return updateAccountPeers, nil - } + hasPeers, err := anyGroupHasPeers(ctx, transaction, policy.AccountID, existingPolicy.ruleGroups()) + if err != nil { + return false, err + } - // Add the new policy to the account - account.Policies = append(account.Policies, policyToSave) - - return am.anyGroupHasPeers(account, policyToSave.ruleGroups()), nil -} - -func toProtocolFirewallRules(rules []*FirewallRule) []*proto.FirewallRule { - result := make([]*proto.FirewallRule, len(rules)) - for i := range rules { - rule := rules[i] - - result[i] = &proto.FirewallRule{ - PeerIP: rule.PeerIP, - Direction: getProtoDirection(rule.Direction), - Action: getProtoAction(rule.Action), - Protocol: getProtoProtocol(rule.Protocol), - Port: rule.Port, + if hasPeers { + return true, nil } } - return result + + return anyGroupHasPeers(ctx, transaction, policy.AccountID, policy.ruleGroups()) +} + +// validatePolicy validates the policy and its rules. +func validatePolicy(ctx context.Context, transaction Store, accountID string, policy *Policy) error { + if policy.ID != "" { + _, err := transaction.GetPolicyByID(ctx, LockingStrengthShare, accountID, policy.ID) + if err != nil { + return err + } + } else { + policy.ID = xid.New().String() + policy.AccountID = accountID + } + + groups, err := transaction.GetGroupsByIDs(ctx, LockingStrengthShare, accountID, policy.ruleGroups()) + if err != nil { + return err + } + + postureChecks, err := transaction.GetPostureChecksByIDs(ctx, LockingStrengthShare, accountID, policy.SourcePostureChecks) + if err != nil { + return err + } + + for i, rule := range policy.Rules { + ruleCopy := rule.Copy() + if ruleCopy.ID == "" { + ruleCopy.ID = policy.ID // TODO: when policy can contain multiple rules, need refactor + ruleCopy.PolicyID = policy.ID + } + + ruleCopy.Sources = getValidGroupIDs(groups, ruleCopy.Sources) + ruleCopy.Destinations = getValidGroupIDs(groups, ruleCopy.Destinations) + policy.Rules[i] = ruleCopy + } + + if policy.SourcePostureChecks != nil { + policy.SourcePostureChecks = getValidPostureCheckIDs(postureChecks, policy.SourcePostureChecks) + } + + return nil } // getAllPeersFromGroups for given peer ID and list of groups @@ -574,27 +626,42 @@ func (a *Account) getPostureChecks(postureChecksID string) *posture.Checks { return nil } -// filterValidPostureChecks filters and returns the posture check IDs from the given list -// that are valid within the provided account. -func filterValidPostureChecks(account *Account, postureChecksIds []string) []string { - result := make([]string, 0, len(postureChecksIds)) +// getValidPostureCheckIDs filters and returns only the valid posture check IDs from the provided list. +func getValidPostureCheckIDs(postureChecks map[string]*posture.Checks, postureChecksIds []string) []string { + validIDs := make([]string, 0, len(postureChecksIds)) for _, id := range postureChecksIds { - for _, postureCheck := range account.PostureChecks { - if id == postureCheck.ID { - result = append(result, id) - continue - } + if _, exists := postureChecks[id]; exists { + validIDs = append(validIDs, id) } } - return result + + return validIDs } -// filterValidGroupIDs filters a list of group IDs and returns only the ones present in the account's group map. -func filterValidGroupIDs(account *Account, groupIDs []string) []string { - result := make([]string, 0, len(groupIDs)) - for _, groupID := range groupIDs { - if _, exists := account.Groups[groupID]; exists { - result = append(result, groupID) +// getValidGroupIDs filters and returns only the valid group IDs from the provided list. +func getValidGroupIDs(groups map[string]*nbgroup.Group, groupIDs []string) []string { + validIDs := make([]string, 0, len(groupIDs)) + for _, id := range groupIDs { + if _, exists := groups[id]; exists { + validIDs = append(validIDs, id) + } + } + + return validIDs +} + +// toProtocolFirewallRules converts the firewall rules to the protocol firewall rules. +func toProtocolFirewallRules(rules []*FirewallRule) []*proto.FirewallRule { + result := make([]*proto.FirewallRule, len(rules)) + for i := range rules { + rule := rules[i] + + result[i] = &proto.FirewallRule{ + PeerIP: rule.PeerIP, + Direction: getProtoDirection(rule.Direction), + Action: getProtoAction(rule.Action), + Protocol: getProtoProtocol(rule.Protocol), + Port: rule.Port, } } return result diff --git a/management/server/policy_test.go b/management/server/policy_test.go index e7f0f9cd2..62d80f46e 100644 --- a/management/server/policy_test.go +++ b/management/server/policy_test.go @@ -7,7 +7,6 @@ import ( "testing" "time" - "github.com/rs/xid" "github.com/stretchr/testify/assert" "golang.org/x/exp/slices" @@ -859,14 +858,23 @@ func TestPolicyAccountPeersUpdate(t *testing.T) { manager.peersUpdateManager.CloseChannel(context.Background(), peer1.ID) }) + var policyWithGroupRulesNoPeers *Policy + var policyWithDestinationPeersOnly *Policy + var policyWithSourceAndDestinationPeers *Policy + // Saving policy with rule groups with no peers should not update account's peers and not send peer update t.Run("saving policy with rule groups with no peers", func(t *testing.T) { - policy := Policy{ - ID: "policy-rule-groups-no-peers", - Enabled: true, + done := make(chan struct{}) + go func() { + peerShouldNotReceiveUpdate(t, updMsg) + close(done) + }() + + policyWithGroupRulesNoPeers, err = manager.SavePolicy(context.Background(), account.Id, userID, &Policy{ + AccountID: account.Id, + Enabled: true, Rules: []*PolicyRule{ { - ID: xid.New().String(), Enabled: true, Sources: []string{"groupB"}, Destinations: []string{"groupC"}, @@ -874,15 +882,7 @@ func TestPolicyAccountPeersUpdate(t *testing.T) { Action: PolicyTrafficActionAccept, }, }, - } - - done := make(chan struct{}) - go func() { - peerShouldNotReceiveUpdate(t, updMsg) - close(done) - }() - - err := manager.SavePolicy(context.Background(), account.Id, userID, &policy, false) + }) assert.NoError(t, err) select { @@ -895,12 +895,17 @@ func TestPolicyAccountPeersUpdate(t *testing.T) { // Saving policy with source group containing peers, but destination group without peers should // update account's peers and send peer update t.Run("saving policy where source has peers but destination does not", func(t *testing.T) { - policy := Policy{ - ID: "policy-source-has-peers-destination-none", - Enabled: true, + done := make(chan struct{}) + go func() { + peerShouldReceiveUpdate(t, updMsg) + close(done) + }() + + _, err = manager.SavePolicy(context.Background(), account.Id, userID, &Policy{ + AccountID: account.Id, + Enabled: true, Rules: []*PolicyRule{ { - ID: xid.New().String(), Enabled: true, Sources: []string{"groupA"}, Destinations: []string{"groupB"}, @@ -909,15 +914,7 @@ func TestPolicyAccountPeersUpdate(t *testing.T) { Action: PolicyTrafficActionAccept, }, }, - } - - done := make(chan struct{}) - go func() { - peerShouldReceiveUpdate(t, updMsg) - close(done) - }() - - err := manager.SavePolicy(context.Background(), account.Id, userID, &policy, false) + }) assert.NoError(t, err) select { @@ -930,13 +927,18 @@ func TestPolicyAccountPeersUpdate(t *testing.T) { // Saving policy with destination group containing peers, but source group without peers should // update account's peers and send peer update t.Run("saving policy where destination has peers but source does not", func(t *testing.T) { - policy := Policy{ - ID: "policy-destination-has-peers-source-none", - Enabled: true, + done := make(chan struct{}) + go func() { + peerShouldReceiveUpdate(t, updMsg) + close(done) + }() + + policyWithDestinationPeersOnly, err = manager.SavePolicy(context.Background(), account.Id, userID, &Policy{ + AccountID: account.Id, + Enabled: true, Rules: []*PolicyRule{ { - ID: xid.New().String(), - Enabled: false, + Enabled: true, Sources: []string{"groupC"}, Destinations: []string{"groupD"}, Bidirectional: true, @@ -944,15 +946,7 @@ func TestPolicyAccountPeersUpdate(t *testing.T) { Action: PolicyTrafficActionAccept, }, }, - } - - done := make(chan struct{}) - go func() { - peerShouldReceiveUpdate(t, updMsg) - close(done) - }() - - err := manager.SavePolicy(context.Background(), account.Id, userID, &policy, false) + }) assert.NoError(t, err) select { @@ -965,12 +959,17 @@ func TestPolicyAccountPeersUpdate(t *testing.T) { // Saving policy with destination and source groups containing peers should update account's peers // and send peer update t.Run("saving policy with source and destination groups with peers", func(t *testing.T) { - policy := Policy{ - ID: "policy-source-destination-peers", - Enabled: true, + done := make(chan struct{}) + go func() { + peerShouldReceiveUpdate(t, updMsg) + close(done) + }() + + policyWithSourceAndDestinationPeers, err = manager.SavePolicy(context.Background(), account.Id, userID, &Policy{ + AccountID: account.Id, + Enabled: true, Rules: []*PolicyRule{ { - ID: xid.New().String(), Enabled: true, Sources: []string{"groupA"}, Destinations: []string{"groupD"}, @@ -978,15 +977,7 @@ func TestPolicyAccountPeersUpdate(t *testing.T) { Action: PolicyTrafficActionAccept, }, }, - } - - done := make(chan struct{}) - go func() { - peerShouldReceiveUpdate(t, updMsg) - close(done) - }() - - err := manager.SavePolicy(context.Background(), account.Id, userID, &policy, false) + }) assert.NoError(t, err) select { @@ -999,28 +990,14 @@ func TestPolicyAccountPeersUpdate(t *testing.T) { // Disabling policy with destination and source groups containing peers should update account's peers // and send peer update t.Run("disabling policy with source and destination groups with peers", func(t *testing.T) { - policy := Policy{ - ID: "policy-source-destination-peers", - Enabled: false, - Rules: []*PolicyRule{ - { - ID: xid.New().String(), - Enabled: true, - Sources: []string{"groupA"}, - Destinations: []string{"groupD"}, - Bidirectional: true, - Action: PolicyTrafficActionAccept, - }, - }, - } - done := make(chan struct{}) go func() { peerShouldReceiveUpdate(t, updMsg) close(done) }() - err := manager.SavePolicy(context.Background(), account.Id, userID, &policy, true) + policyWithSourceAndDestinationPeers.Enabled = false + policyWithSourceAndDestinationPeers, err = manager.SavePolicy(context.Background(), account.Id, userID, policyWithSourceAndDestinationPeers) assert.NoError(t, err) select { @@ -1033,29 +1010,15 @@ func TestPolicyAccountPeersUpdate(t *testing.T) { // Updating disabled policy with destination and source groups containing peers should not update account's peers // or send peer update t.Run("updating disabled policy with source and destination groups with peers", func(t *testing.T) { - policy := Policy{ - ID: "policy-source-destination-peers", - Description: "updated description", - Enabled: false, - Rules: []*PolicyRule{ - { - ID: xid.New().String(), - Enabled: true, - Sources: []string{"groupA"}, - Destinations: []string{"groupA"}, - Bidirectional: true, - Action: PolicyTrafficActionAccept, - }, - }, - } - done := make(chan struct{}) go func() { peerShouldNotReceiveUpdate(t, updMsg) close(done) }() - err := manager.SavePolicy(context.Background(), account.Id, userID, &policy, true) + policyWithSourceAndDestinationPeers.Description = "updated description" + policyWithSourceAndDestinationPeers.Rules[0].Destinations = []string{"groupA"} + policyWithSourceAndDestinationPeers, err = manager.SavePolicy(context.Background(), account.Id, userID, policyWithSourceAndDestinationPeers) assert.NoError(t, err) select { @@ -1068,28 +1031,14 @@ func TestPolicyAccountPeersUpdate(t *testing.T) { // Enabling policy with destination and source groups containing peers should update account's peers // and send peer update t.Run("enabling policy with source and destination groups with peers", func(t *testing.T) { - policy := Policy{ - ID: "policy-source-destination-peers", - Enabled: true, - Rules: []*PolicyRule{ - { - ID: xid.New().String(), - Enabled: true, - Sources: []string{"groupA"}, - Destinations: []string{"groupD"}, - Bidirectional: true, - Action: PolicyTrafficActionAccept, - }, - }, - } - done := make(chan struct{}) go func() { peerShouldReceiveUpdate(t, updMsg) close(done) }() - err := manager.SavePolicy(context.Background(), account.Id, userID, &policy, true) + policyWithSourceAndDestinationPeers.Enabled = true + policyWithSourceAndDestinationPeers, err = manager.SavePolicy(context.Background(), account.Id, userID, policyWithSourceAndDestinationPeers) assert.NoError(t, err) select { @@ -1101,15 +1050,13 @@ func TestPolicyAccountPeersUpdate(t *testing.T) { // Deleting policy should trigger account peers update and send peer update t.Run("deleting policy with source and destination groups with peers", func(t *testing.T) { - policyID := "policy-source-destination-peers" - done := make(chan struct{}) go func() { peerShouldReceiveUpdate(t, updMsg) close(done) }() - err := manager.DeletePolicy(context.Background(), account.Id, policyID, userID) + err := manager.DeletePolicy(context.Background(), account.Id, policyWithSourceAndDestinationPeers.ID, userID) assert.NoError(t, err) select { @@ -1123,14 +1070,13 @@ func TestPolicyAccountPeersUpdate(t *testing.T) { // Deleting policy with destination group containing peers, but source group without peers should // update account's peers and send peer update t.Run("deleting policy where destination has peers but source does not", func(t *testing.T) { - policyID := "policy-destination-has-peers-source-none" done := make(chan struct{}) go func() { peerShouldReceiveUpdate(t, updMsg) close(done) }() - err := manager.DeletePolicy(context.Background(), account.Id, policyID, userID) + err := manager.DeletePolicy(context.Background(), account.Id, policyWithDestinationPeersOnly.ID, userID) assert.NoError(t, err) select { @@ -1142,14 +1088,13 @@ func TestPolicyAccountPeersUpdate(t *testing.T) { // Deleting policy with no peers in groups should not update account's peers and not send peer update t.Run("deleting policy with no peers in groups", func(t *testing.T) { - policyID := "policy-rule-groups-no-peers" done := make(chan struct{}) go func() { peerShouldNotReceiveUpdate(t, updMsg) close(done) }() - err := manager.DeletePolicy(context.Background(), account.Id, policyID, userID) + err := manager.DeletePolicy(context.Background(), account.Id, policyWithGroupRulesNoPeers.ID, userID) assert.NoError(t, err) select { diff --git a/management/server/posture_checks_test.go b/management/server/posture_checks_test.go index 3c5c5fc79..93e5741cf 100644 --- a/management/server/posture_checks_test.go +++ b/management/server/posture_checks_test.go @@ -5,7 +5,6 @@ import ( "testing" "time" - "github.com/rs/xid" "github.com/stretchr/testify/assert" "github.com/stretchr/testify/require" @@ -210,12 +209,10 @@ func TestPostureCheckAccountPeersUpdate(t *testing.T) { } }) - policy := Policy{ - ID: "policyA", + policy := &Policy{ Enabled: true, Rules: []*PolicyRule{ { - ID: xid.New().String(), Enabled: true, Sources: []string{"groupA"}, Destinations: []string{"groupA"}, @@ -234,7 +231,7 @@ func TestPostureCheckAccountPeersUpdate(t *testing.T) { close(done) }() - err := manager.SavePolicy(context.Background(), account.Id, userID, &policy, false) + policy, err = manager.SavePolicy(context.Background(), account.Id, userID, policy) assert.NoError(t, err) select { @@ -282,8 +279,7 @@ func TestPostureCheckAccountPeersUpdate(t *testing.T) { }() policy.SourcePostureChecks = []string{} - - err := manager.SavePolicy(context.Background(), account.Id, userID, &policy, true) + _, err := manager.SavePolicy(context.Background(), account.Id, userID, policy) assert.NoError(t, err) select { @@ -316,12 +312,10 @@ func TestPostureCheckAccountPeersUpdate(t *testing.T) { // Updating linked posture check to policy with no peers should not trigger account peers update and not send peer update t.Run("updating linked posture check to policy with no peers", func(t *testing.T) { - policy = Policy{ - ID: "policyB", + _, err = manager.SavePolicy(context.Background(), account.Id, userID, &Policy{ Enabled: true, Rules: []*PolicyRule{ { - ID: xid.New().String(), Enabled: true, Sources: []string{"groupB"}, Destinations: []string{"groupC"}, @@ -330,8 +324,7 @@ func TestPostureCheckAccountPeersUpdate(t *testing.T) { }, }, SourcePostureChecks: []string{postureCheckB.ID}, - } - err = manager.SavePolicy(context.Background(), account.Id, userID, &policy, false) + }) assert.NoError(t, err) done := make(chan struct{}) @@ -362,12 +355,11 @@ func TestPostureCheckAccountPeersUpdate(t *testing.T) { t.Cleanup(func() { manager.peersUpdateManager.CloseChannel(context.Background(), peer2.ID) }) - policy = Policy{ - ID: "policyB", + + _, err = manager.SavePolicy(context.Background(), account.Id, userID, &Policy{ Enabled: true, Rules: []*PolicyRule{ { - ID: xid.New().String(), Enabled: true, Sources: []string{"groupB"}, Destinations: []string{"groupA"}, @@ -376,9 +368,7 @@ func TestPostureCheckAccountPeersUpdate(t *testing.T) { }, }, SourcePostureChecks: []string{postureCheckB.ID}, - } - - err = manager.SavePolicy(context.Background(), account.Id, userID, &policy, true) + }) assert.NoError(t, err) done := make(chan struct{}) @@ -405,8 +395,7 @@ func TestPostureCheckAccountPeersUpdate(t *testing.T) { // Updating linked client posture check to policy where source has peers but destination does not, // should trigger account peers update and send peer update t.Run("updating linked posture check to policy where source has peers but destination does not", func(t *testing.T) { - policy = Policy{ - ID: "policyB", + _, err = manager.SavePolicy(context.Background(), account.Id, userID, &Policy{ Enabled: true, Rules: []*PolicyRule{ { @@ -418,8 +407,7 @@ func TestPostureCheckAccountPeersUpdate(t *testing.T) { }, }, SourcePostureChecks: []string{postureCheckB.ID}, - } - err = manager.SavePolicy(context.Background(), account.Id, userID, &policy, true) + }) assert.NoError(t, err) done := make(chan struct{}) @@ -490,12 +478,9 @@ func TestArePostureCheckChangesAffectPeers(t *testing.T) { require.NoError(t, err, "failed to save postureCheckB") policy := &Policy{ - ID: "policyA", AccountID: account.Id, Rules: []*PolicyRule{ { - ID: "ruleA", - PolicyID: "policyA", Enabled: true, Sources: []string{"groupA"}, Destinations: []string{"groupA"}, @@ -504,7 +489,7 @@ func TestArePostureCheckChangesAffectPeers(t *testing.T) { SourcePostureChecks: []string{postureCheckA.ID}, } - err = manager.SavePolicy(context.Background(), account.Id, userID, policy, false) + policy, err = manager.SavePolicy(context.Background(), account.Id, adminUserID, policy) require.NoError(t, err, "failed to save policy") t.Run("posture check exists and is linked to policy with peers", func(t *testing.T) { @@ -528,7 +513,7 @@ func TestArePostureCheckChangesAffectPeers(t *testing.T) { t.Run("posture check is linked to policy with no peers in source groups", func(t *testing.T) { policy.Rules[0].Sources = []string{"groupB"} policy.Rules[0].Destinations = []string{"groupA"} - err = manager.SavePolicy(context.Background(), account.Id, userID, policy, true) + _, err = manager.SavePolicy(context.Background(), account.Id, adminUserID, policy) require.NoError(t, err, "failed to update policy") result, err := arePostureCheckChangesAffectPeers(context.Background(), manager.Store, account.Id, postureCheckA.ID) @@ -539,7 +524,7 @@ func TestArePostureCheckChangesAffectPeers(t *testing.T) { t.Run("posture check is linked to policy with no peers in destination groups", func(t *testing.T) { policy.Rules[0].Sources = []string{"groupA"} policy.Rules[0].Destinations = []string{"groupB"} - err = manager.SavePolicy(context.Background(), account.Id, userID, policy, true) + _, err = manager.SavePolicy(context.Background(), account.Id, adminUserID, policy) require.NoError(t, err, "failed to update policy") result, err := arePostureCheckChangesAffectPeers(context.Background(), manager.Store, account.Id, postureCheckA.ID) @@ -560,7 +545,7 @@ func TestArePostureCheckChangesAffectPeers(t *testing.T) { t.Run("posture check is linked to policy with non-existent group", func(t *testing.T) { policy.Rules[0].Sources = []string{"nonExistentGroup"} policy.Rules[0].Destinations = []string{"nonExistentGroup"} - err = manager.SavePolicy(context.Background(), account.Id, userID, policy, true) + _, err = manager.SavePolicy(context.Background(), account.Id, adminUserID, policy) require.NoError(t, err, "failed to update policy") result, err := arePostureCheckChangesAffectPeers(context.Background(), manager.Store, account.Id, postureCheckA.ID) diff --git a/management/server/route_test.go b/management/server/route_test.go index 5c848f68c..108f791e0 100644 --- a/management/server/route_test.go +++ b/management/server/route_test.go @@ -1214,12 +1214,11 @@ func TestGetNetworkMap_RouteSync(t *testing.T) { defaultRule := rules[0] newPolicy := defaultRule.Copy() - newPolicy.ID = xid.New().String() newPolicy.Name = "peer1 only" newPolicy.Rules[0].Sources = []string{newGroup.ID} newPolicy.Rules[0].Destinations = []string{newGroup.ID} - err = am.SavePolicy(context.Background(), account.Id, userID, newPolicy, false) + _, err = am.SavePolicy(context.Background(), account.Id, userID, newPolicy) require.NoError(t, err) err = am.DeletePolicy(context.Background(), account.Id, defaultRule.ID, userID) diff --git a/management/server/setupkey_test.go b/management/server/setupkey_test.go index 7c8200706..614547c60 100644 --- a/management/server/setupkey_test.go +++ b/management/server/setupkey_test.go @@ -406,8 +406,7 @@ func TestSetupKeyAccountPeersUpdate(t *testing.T) { }) assert.NoError(t, err) - policy := Policy{ - ID: "policy", + policy := &Policy{ Enabled: true, Rules: []*PolicyRule{ { @@ -419,7 +418,7 @@ func TestSetupKeyAccountPeersUpdate(t *testing.T) { }, }, } - err = manager.SavePolicy(context.Background(), account.Id, userID, &policy, false) + _, err = manager.SavePolicy(context.Background(), account.Id, userID, policy) require.NoError(t, err) updMsg := manager.peersUpdateManager.CreateChannel(context.Background(), peer1.ID) diff --git a/management/server/sql_store.go b/management/server/sql_store.go index 47c17bb92..9a24857d1 100644 --- a/management/server/sql_store.go +++ b/management/server/sql_store.go @@ -1243,8 +1243,8 @@ func (s *SqlStore) GetGroupsByIDs(ctx context.Context, lockStrength LockingStren var groups []*nbgroup.Group result := s.db.Clauses(clause.Locking{Strength: string(lockStrength)}).Find(&groups, accountAndIDsQueryCondition, accountID, groupIDs) if result.Error != nil { - log.WithContext(ctx).Errorf("failed to get groups by ID's from the store: %s", result.Error) - return nil, status.Errorf(status.Internal, "failed to get groups by ID's from the store") + log.WithContext(ctx).Errorf("failed to get groups by ID's from store: %s", result.Error) + return nil, status.Errorf(status.Internal, "failed to get groups by ID's from store") } groupsMap := make(map[string]*nbgroup.Group) @@ -1295,12 +1295,67 @@ func (s *SqlStore) DeleteGroups(ctx context.Context, strength LockingStrength, a // GetAccountPolicies retrieves policies for an account. func (s *SqlStore) GetAccountPolicies(ctx context.Context, lockStrength LockingStrength, accountID string) ([]*Policy, error) { - return getRecords[*Policy](s.db.Preload(clause.Associations), lockStrength, accountID) + var policies []*Policy + result := s.db.Clauses(clause.Locking{Strength: string(lockStrength)}). + Preload(clause.Associations).Find(&policies, accountIDCondition, accountID) + if err := result.Error; err != nil { + log.WithContext(ctx).Errorf("failed to get policies from the store: %s", result.Error) + return nil, status.Errorf(status.Internal, "failed to get policies from store") + } + + return policies, nil } // GetPolicyByID retrieves a policy by its ID and account ID. -func (s *SqlStore) GetPolicyByID(ctx context.Context, lockStrength LockingStrength, policyID string, accountID string) (*Policy, error) { - return getRecordByID[Policy](s.db.Preload(clause.Associations), lockStrength, policyID, accountID) +func (s *SqlStore) GetPolicyByID(ctx context.Context, lockStrength LockingStrength, accountID, policyID string) (*Policy, error) { + var policy *Policy + result := s.db.Clauses(clause.Locking{Strength: string(lockStrength)}).Preload(clause.Associations). + First(&policy, accountAndIDQueryCondition, accountID, policyID) + if err := result.Error; err != nil { + if errors.Is(err, gorm.ErrRecordNotFound) { + return nil, status.NewPolicyNotFoundError(policyID) + } + log.WithContext(ctx).Errorf("failed to get policy from store: %s", err) + return nil, status.Errorf(status.Internal, "failed to get policy from store") + } + + return policy, nil +} + +func (s *SqlStore) CreatePolicy(ctx context.Context, lockStrength LockingStrength, policy *Policy) error { + result := s.db.Clauses(clause.Locking{Strength: string(lockStrength)}).Create(policy) + if result.Error != nil { + log.WithContext(ctx).Errorf("failed to create policy in store: %s", result.Error) + return status.Errorf(status.Internal, "failed to create policy in store") + } + + return nil +} + +// SavePolicy saves a policy to the database. +func (s *SqlStore) SavePolicy(ctx context.Context, lockStrength LockingStrength, policy *Policy) error { + result := s.db.Session(&gorm.Session{FullSaveAssociations: true}). + Clauses(clause.Locking{Strength: string(lockStrength)}).Save(policy) + if err := result.Error; err != nil { + log.WithContext(ctx).Errorf("failed to save policy to the store: %s", err) + return status.Errorf(status.Internal, "failed to save policy to store") + } + return nil +} + +func (s *SqlStore) DeletePolicy(ctx context.Context, lockStrength LockingStrength, accountID, policyID string) error { + result := s.db.Clauses(clause.Locking{Strength: string(lockStrength)}). + Delete(&Policy{}, accountAndIDQueryCondition, accountID, policyID) + if err := result.Error; err != nil { + log.WithContext(ctx).Errorf("failed to delete policy from store: %s", err) + return status.Errorf(status.Internal, "failed to delete policy from store") + } + + if result.RowsAffected == 0 { + return status.NewPolicyNotFoundError(policyID) + } + + return nil } // GetAccountPostureChecks retrieves posture checks for an account. @@ -1331,6 +1386,23 @@ func (s *SqlStore) GetPostureChecksByID(ctx context.Context, lockStrength Lockin return postureCheck, nil } +// GetPostureChecksByIDs retrieves posture checks by their IDs and account ID. +func (s *SqlStore) GetPostureChecksByIDs(ctx context.Context, lockStrength LockingStrength, accountID string, postureChecksIDs []string) (map[string]*posture.Checks, error) { + var postureChecks []*posture.Checks + result := s.db.Clauses(clause.Locking{Strength: string(lockStrength)}).Find(&postureChecks, accountAndIDsQueryCondition, accountID, postureChecksIDs) + if result.Error != nil { + log.WithContext(ctx).Errorf("failed to get posture checks by ID's from store: %s", result.Error) + return nil, status.Errorf(status.Internal, "failed to get posture checks by ID's from store") + } + + postureChecksMap := make(map[string]*posture.Checks) + for _, postureCheck := range postureChecks { + postureChecksMap[postureCheck.ID] = postureCheck + } + + return postureChecksMap, nil +} + // SavePostureChecks saves a posture checks to the database. func (s *SqlStore) SavePostureChecks(ctx context.Context, lockStrength LockingStrength, postureCheck *posture.Checks) error { result := s.db.Clauses(clause.Locking{Strength: string(lockStrength)}).Save(postureCheck) diff --git a/management/server/sql_store_test.go b/management/server/sql_store_test.go index de939e8d0..c05793fc6 100644 --- a/management/server/sql_store_test.go +++ b/management/server/sql_store_test.go @@ -1612,6 +1612,49 @@ func TestSqlStore_GetPostureChecksByID(t *testing.T) { } } +func TestSqlStore_GetPostureChecksByIDs(t *testing.T) { + store, cleanup, err := NewTestStoreFromSQL(context.Background(), "testdata/extended-store.sql", t.TempDir()) + t.Cleanup(cleanup) + require.NoError(t, err) + + accountID := "bf1c8084-ba50-4ce7-9439-34653001fc3b" + + tests := []struct { + name string + postureCheckIDs []string + expectedCount int + }{ + { + name: "retrieve existing posture checks by existing IDs", + postureCheckIDs: []string{"csplshq7qv948l48f7t0", "cspnllq7qv95uq1r4k90"}, + expectedCount: 2, + }, + { + name: "empty posture check IDs list", + postureCheckIDs: []string{}, + expectedCount: 0, + }, + { + name: "non-existing posture check IDs", + postureCheckIDs: []string{"nonexistent1", "nonexistent2"}, + expectedCount: 0, + }, + { + name: "mixed existing and non-existing posture check IDs", + postureCheckIDs: []string{"cspnllq7qv95uq1r4k90", "nonexistent"}, + expectedCount: 1, + }, + } + + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + groups, err := store.GetPostureChecksByIDs(context.Background(), LockingStrengthShare, accountID, tt.postureCheckIDs) + require.NoError(t, err) + require.Len(t, groups, tt.expectedCount) + }) + } +} + func TestSqlStore_SavePostureChecks(t *testing.T) { store, cleanup, err := NewTestStoreFromSQL(context.Background(), "testdata/extended-store.sql", t.TempDir()) t.Cleanup(cleanup) @@ -1699,3 +1742,118 @@ func TestSqlStore_DeletePostureChecks(t *testing.T) { }) } } + +func TestSqlStore_GetPolicyByID(t *testing.T) { + store, cleanup, err := NewTestStoreFromSQL(context.Background(), "testdata/store.sql", t.TempDir()) + t.Cleanup(cleanup) + require.NoError(t, err) + + accountID := "bf1c8084-ba50-4ce7-9439-34653001fc3b" + tests := []struct { + name string + policyID string + expectError bool + }{ + { + name: "retrieve existing policy", + policyID: "cs1tnh0hhcjnqoiuebf0", + expectError: false, + }, + { + name: "retrieve non-existing policy checks", + policyID: "non-existing", + expectError: true, + }, + { + name: "retrieve with empty policy ID", + policyID: "", + expectError: true, + }, + } + + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + policy, err := store.GetPolicyByID(context.Background(), LockingStrengthShare, accountID, tt.policyID) + if tt.expectError { + require.Error(t, err) + sErr, ok := status.FromError(err) + require.True(t, ok) + require.Equal(t, sErr.Type(), status.NotFound) + require.Nil(t, policy) + } else { + require.NoError(t, err) + require.NotNil(t, policy) + require.Equal(t, tt.policyID, policy.ID) + } + }) + } +} + +func TestSqlStore_CreatePolicy(t *testing.T) { + store, cleanup, err := NewTestStoreFromSQL(context.Background(), "testdata/store.sql", t.TempDir()) + t.Cleanup(cleanup) + require.NoError(t, err) + + accountID := "bf1c8084-ba50-4ce7-9439-34653001fc3b" + + policy := &Policy{ + ID: "policy-id", + AccountID: accountID, + Enabled: true, + Rules: []*PolicyRule{ + { + Enabled: true, + Sources: []string{"groupA"}, + Destinations: []string{"groupC"}, + Bidirectional: true, + Action: PolicyTrafficActionAccept, + }, + }, + } + err = store.CreatePolicy(context.Background(), LockingStrengthUpdate, policy) + require.NoError(t, err) + + savePolicy, err := store.GetPolicyByID(context.Background(), LockingStrengthShare, accountID, policy.ID) + require.NoError(t, err) + require.Equal(t, savePolicy, policy) + +} + +func TestSqlStore_SavePolicy(t *testing.T) { + store, cleanup, err := NewTestStoreFromSQL(context.Background(), "testdata/store.sql", t.TempDir()) + t.Cleanup(cleanup) + require.NoError(t, err) + + accountID := "bf1c8084-ba50-4ce7-9439-34653001fc3b" + policyID := "cs1tnh0hhcjnqoiuebf0" + + policy, err := store.GetPolicyByID(context.Background(), LockingStrengthShare, accountID, policyID) + require.NoError(t, err) + + policy.Enabled = false + policy.Description = "policy" + policy.Rules[0].Sources = []string{"group"} + policy.Rules[0].Ports = []string{"80", "443"} + err = store.SavePolicy(context.Background(), LockingStrengthUpdate, policy) + require.NoError(t, err) + + savePolicy, err := store.GetPolicyByID(context.Background(), LockingStrengthShare, accountID, policy.ID) + require.NoError(t, err) + require.Equal(t, savePolicy, policy) +} + +func TestSqlStore_DeletePolicy(t *testing.T) { + store, cleanup, err := NewTestStoreFromSQL(context.Background(), "testdata/store.sql", t.TempDir()) + t.Cleanup(cleanup) + require.NoError(t, err) + + accountID := "bf1c8084-ba50-4ce7-9439-34653001fc3b" + policyID := "cs1tnh0hhcjnqoiuebf0" + + err = store.DeletePolicy(context.Background(), LockingStrengthShare, accountID, policyID) + require.NoError(t, err) + + policy, err := store.GetPolicyByID(context.Background(), LockingStrengthShare, accountID, policyID) + require.Error(t, err) + require.Nil(t, policy) +} diff --git a/management/server/status/error.go b/management/server/status/error.go index 44391e1f1..0fff53559 100644 --- a/management/server/status/error.go +++ b/management/server/status/error.go @@ -144,3 +144,8 @@ func NewGroupNotFoundError(groupID string) error { func NewPostureChecksNotFoundError(postureChecksID string) error { return Errorf(NotFound, "posture checks: %s not found", postureChecksID) } + +// NewPolicyNotFoundError creates a new Error with NotFound type for a missing policy +func NewPolicyNotFoundError(policyID string) error { + return Errorf(NotFound, "policy: %s not found", policyID) +} diff --git a/management/server/store.go b/management/server/store.go index 03b5821e7..ba61d552d 100644 --- a/management/server/store.go +++ b/management/server/store.go @@ -80,11 +80,15 @@ type Store interface { DeleteGroups(ctx context.Context, strength LockingStrength, accountID string, groupIDs []string) error GetAccountPolicies(ctx context.Context, lockStrength LockingStrength, accountID string) ([]*Policy, error) - GetPolicyByID(ctx context.Context, lockStrength LockingStrength, policyID string, accountID string) (*Policy, error) + GetPolicyByID(ctx context.Context, lockStrength LockingStrength, accountID, policyID string) (*Policy, error) + CreatePolicy(ctx context.Context, lockStrength LockingStrength, policy *Policy) error + SavePolicy(ctx context.Context, lockStrength LockingStrength, policy *Policy) error + DeletePolicy(ctx context.Context, lockStrength LockingStrength, accountID, policyID string) error GetPostureCheckByChecksDefinition(accountID string, checks *posture.ChecksDefinition) (*posture.Checks, error) GetAccountPostureChecks(ctx context.Context, lockStrength LockingStrength, accountID string) ([]*posture.Checks, error) GetPostureChecksByID(ctx context.Context, lockStrength LockingStrength, accountID, postureCheckID string) (*posture.Checks, error) + GetPostureChecksByIDs(ctx context.Context, lockStrength LockingStrength, accountID string, postureChecksIDs []string) (map[string]*posture.Checks, error) SavePostureChecks(ctx context.Context, lockStrength LockingStrength, postureCheck *posture.Checks) error DeletePostureChecks(ctx context.Context, lockStrength LockingStrength, accountID, postureChecksID string) error diff --git a/management/server/testdata/extended-store.sql b/management/server/testdata/extended-store.sql index 1646ff4da..37db27316 100644 --- a/management/server/testdata/extended-store.sql +++ b/management/server/testdata/extended-store.sql @@ -35,4 +35,5 @@ INSERT INTO "groups" VALUES('cfefqs706sqkneg59g4g','bf1c8084-ba50-4ce7-9439-3465 INSERT INTO "groups" VALUES('cfefqs706sqkneg59g3g','bf1c8084-ba50-4ce7-9439-34653001fc3b','AwesomeGroup1','api','[]',0,''); INSERT INTO "groups" VALUES('cfefqs706sqkneg59g2g','bf1c8084-ba50-4ce7-9439-34653001fc3b','AwesomeGroup2','api','[]',0,''); INSERT INTO posture_checks VALUES('csplshq7qv948l48f7t0','NetBird Version > 0.32.0','','bf1c8084-ba50-4ce7-9439-34653001fc3b','{"NBVersionCheck":{"MinVersion":"0.31.0"}}'); +INSERT INTO posture_checks VALUES('cspnllq7qv95uq1r4k90','Allow Berlin and Deny local network 172.16.1.0/24','','bf1c8084-ba50-4ce7-9439-34653001fc3b','{"GeoLocationCheck":{"Locations":[{"CountryCode":"DE","CityName":"Berlin"}],"Action":"allow"},"PeerNetworkRangeCheck":{"Action":"deny","Ranges":["172.16.1.0/24"]}}'); INSERT INTO installations VALUES(1,''); diff --git a/management/server/user_test.go b/management/server/user_test.go index d4f560a54..498017afa 100644 --- a/management/server/user_test.go +++ b/management/server/user_test.go @@ -1279,8 +1279,7 @@ func TestUserAccountPeersUpdate(t *testing.T) { }) require.NoError(t, err) - policy := Policy{ - ID: "policy", + policy := &Policy{ Enabled: true, Rules: []*PolicyRule{ { @@ -1292,7 +1291,7 @@ func TestUserAccountPeersUpdate(t *testing.T) { }, }, } - err = manager.SavePolicy(context.Background(), account.Id, userID, &policy, false) + _, err = manager.SavePolicy(context.Background(), account.Id, userID, policy) require.NoError(t, err) updMsg := manager.peersUpdateManager.CreateChannel(context.Background(), peer1.ID) From 0e48a772ff8de32cb6710f8d2f8c36a6bd468551 Mon Sep 17 00:00:00 2001 From: Bethuel Mmbaga Date: Tue, 26 Nov 2024 15:43:05 +0300 Subject: [PATCH 39/39] [management] Refactor DNS settings to use store methods (#2883) --- management/server/dns.go | 164 ++++++++++++++++++++-------- management/server/sql_store.go | 21 +++- management/server/sql_store_test.go | 63 +++++++++++ management/server/store.go | 1 + 4 files changed, 204 insertions(+), 45 deletions(-) diff --git a/management/server/dns.go b/management/server/dns.go index e52be6016..8df211b0b 100644 --- a/management/server/dns.go +++ b/management/server/dns.go @@ -3,6 +3,7 @@ package server import ( "context" "fmt" + "slices" "strconv" "sync" @@ -85,8 +86,12 @@ func (am *DefaultAccountManager) GetDNSSettings(ctx context.Context, accountID s return nil, err } - if !user.IsAdminOrServiceUser() || user.AccountID != accountID { - return nil, status.Errorf(status.PermissionDenied, "only users with admin power are allowed to view DNS settings") + if user.AccountID != accountID { + return nil, status.NewUserNotPartOfAccountError() + } + + if user.IsRegularUser() { + return nil, status.NewAdminPermissionError() } return am.Store.GetAccountDNSSettings(ctx, LockingStrengthShare, accountID) @@ -94,64 +99,137 @@ func (am *DefaultAccountManager) GetDNSSettings(ctx context.Context, accountID s // SaveDNSSettings validates a user role and updates the account's DNS settings func (am *DefaultAccountManager) SaveDNSSettings(ctx context.Context, accountID string, userID string, dnsSettingsToSave *DNSSettings) error { - unlock := am.Store.AcquireWriteLockByUID(ctx, accountID) - defer unlock() - - account, err := am.Store.GetAccount(ctx, accountID) - if err != nil { - return err - } - - user, err := account.FindUser(userID) - if err != nil { - return err - } - - if !user.HasAdminPower() { - return status.Errorf(status.PermissionDenied, "only users with admin power are allowed to update DNS settings") - } - if dnsSettingsToSave == nil { return status.Errorf(status.InvalidArgument, "the dns settings provided are nil") } - if len(dnsSettingsToSave.DisabledManagementGroups) != 0 { - err = validateGroups(dnsSettingsToSave.DisabledManagementGroups, account.Groups) - if err != nil { - return err - } - } - - oldSettings := account.DNSSettings.Copy() - account.DNSSettings = dnsSettingsToSave.Copy() - - addedGroups := difference(dnsSettingsToSave.DisabledManagementGroups, oldSettings.DisabledManagementGroups) - removedGroups := difference(oldSettings.DisabledManagementGroups, dnsSettingsToSave.DisabledManagementGroups) - - account.Network.IncSerial() - if err = am.Store.SaveAccount(ctx, account); err != nil { + user, err := am.Store.GetUserByUserID(ctx, LockingStrengthShare, userID) + if err != nil { return err } - for _, id := range addedGroups { - group := account.GetGroup(id) - meta := map[string]any{"group": group.Name, "group_id": group.ID} - am.StoreEvent(ctx, userID, accountID, accountID, activity.GroupAddedToDisabledManagementGroups, meta) + if user.AccountID != accountID { + return status.NewUserNotPartOfAccountError() } - for _, id := range removedGroups { - group := account.GetGroup(id) - meta := map[string]any{"group": group.Name, "group_id": group.ID} - am.StoreEvent(ctx, userID, accountID, accountID, activity.GroupRemovedFromDisabledManagementGroups, meta) + if !user.HasAdminPower() { + return status.NewAdminPermissionError() } - if am.anyGroupHasPeers(account, addedGroups) || am.anyGroupHasPeers(account, removedGroups) { + var updateAccountPeers bool + var eventsToStore []func() + + err = am.Store.ExecuteInTransaction(ctx, func(transaction Store) error { + if err = validateDNSSettings(ctx, transaction, accountID, dnsSettingsToSave); err != nil { + return err + } + + oldSettings, err := transaction.GetAccountDNSSettings(ctx, LockingStrengthUpdate, accountID) + if err != nil { + return err + } + + addedGroups := difference(dnsSettingsToSave.DisabledManagementGroups, oldSettings.DisabledManagementGroups) + removedGroups := difference(oldSettings.DisabledManagementGroups, dnsSettingsToSave.DisabledManagementGroups) + + updateAccountPeers, err = areDNSSettingChangesAffectPeers(ctx, transaction, accountID, addedGroups, removedGroups) + if err != nil { + return err + } + + events := am.prepareDNSSettingsEvents(ctx, transaction, accountID, userID, addedGroups, removedGroups) + eventsToStore = append(eventsToStore, events...) + + if err = transaction.IncrementNetworkSerial(ctx, LockingStrengthUpdate, accountID); err != nil { + return err + } + + return transaction.SaveDNSSettings(ctx, LockingStrengthUpdate, accountID, dnsSettingsToSave) + }) + if err != nil { + return err + } + + for _, storeEvent := range eventsToStore { + storeEvent() + } + + if updateAccountPeers { am.updateAccountPeers(ctx, accountID) } return nil } +// prepareDNSSettingsEvents prepares a list of event functions to be stored. +func (am *DefaultAccountManager) prepareDNSSettingsEvents(ctx context.Context, transaction Store, accountID, userID string, addedGroups, removedGroups []string) []func() { + var eventsToStore []func() + + modifiedGroups := slices.Concat(addedGroups, removedGroups) + groups, err := transaction.GetGroupsByIDs(ctx, LockingStrengthShare, accountID, modifiedGroups) + if err != nil { + log.WithContext(ctx).Debugf("failed to get groups for dns settings events: %v", err) + return nil + } + + for _, groupID := range addedGroups { + group, ok := groups[groupID] + if !ok { + log.WithContext(ctx).Debugf("skipped adding group: %s GroupAddedToDisabledManagementGroups activity", groupID) + continue + } + + eventsToStore = append(eventsToStore, func() { + meta := map[string]any{"group": group.Name, "group_id": group.ID} + am.StoreEvent(ctx, userID, accountID, accountID, activity.GroupAddedToDisabledManagementGroups, meta) + }) + + } + + for _, groupID := range removedGroups { + group, ok := groups[groupID] + if !ok { + log.WithContext(ctx).Debugf("skipped adding group: %s GroupRemovedFromDisabledManagementGroups activity", groupID) + continue + } + + eventsToStore = append(eventsToStore, func() { + meta := map[string]any{"group": group.Name, "group_id": group.ID} + am.StoreEvent(ctx, userID, accountID, accountID, activity.GroupRemovedFromDisabledManagementGroups, meta) + }) + } + + return eventsToStore +} + +// areDNSSettingChangesAffectPeers checks if the DNS settings changes affect any peers. +func areDNSSettingChangesAffectPeers(ctx context.Context, transaction Store, accountID string, addedGroups, removedGroups []string) (bool, error) { + hasPeers, err := anyGroupHasPeers(ctx, transaction, accountID, addedGroups) + if err != nil { + return false, err + } + + if hasPeers { + return true, nil + } + + return anyGroupHasPeers(ctx, transaction, accountID, removedGroups) +} + +// validateDNSSettings validates the DNS settings. +func validateDNSSettings(ctx context.Context, transaction Store, accountID string, settings *DNSSettings) error { + if len(settings.DisabledManagementGroups) == 0 { + return nil + } + + groups, err := transaction.GetGroupsByIDs(ctx, LockingStrengthShare, accountID, settings.DisabledManagementGroups) + if err != nil { + return err + } + + return validateGroups(settings.DisabledManagementGroups, groups) +} + // toProtocolDNSConfig converts nbdns.Config to proto.DNSConfig using the cache func toProtocolDNSConfig(update nbdns.Config, cache *DNSConfigCache) *proto.DNSConfig { protoUpdate := &proto.DNSConfig{ diff --git a/management/server/sql_store.go b/management/server/sql_store.go index 9a24857d1..f58ceb1ad 100644 --- a/management/server/sql_store.go +++ b/management/server/sql_store.go @@ -1162,9 +1162,10 @@ func (s *SqlStore) GetAccountDNSSettings(ctx context.Context, lockStrength Locki First(&accountDNSSettings, idQueryCondition, accountID) if result.Error != nil { if errors.Is(result.Error, gorm.ErrRecordNotFound) { - return nil, status.Errorf(status.NotFound, "dns settings not found") + return nil, status.NewAccountNotFoundError(accountID) } - return nil, status.Errorf(status.Internal, "failed to get dns settings from store: %v", result.Error) + log.WithContext(ctx).Errorf("failed to get dns settings from store: %v", result.Error) + return nil, status.Errorf(status.Internal, "failed to get dns settings from store") } return &accountDNSSettings.DNSSettings, nil } @@ -1537,3 +1538,19 @@ func getRecordByID[T any](db *gorm.DB, lockStrength LockingStrength, recordID, a } return &record, nil } + +// SaveDNSSettings saves the DNS settings to the store. +func (s *SqlStore) SaveDNSSettings(ctx context.Context, lockStrength LockingStrength, accountID string, settings *DNSSettings) error { + result := s.db.Clauses(clause.Locking{Strength: string(lockStrength)}).Model(&Account{}). + Where(idQueryCondition, accountID).Updates(&AccountDNSSettings{DNSSettings: *settings}) + if result.Error != nil { + log.WithContext(ctx).Errorf("failed to save dns settings to store: %v", result.Error) + return status.Errorf(status.Internal, "failed to save dns settings to store") + } + + if result.RowsAffected == 0 { + return status.NewAccountNotFoundError(accountID) + } + + return nil +} diff --git a/management/server/sql_store_test.go b/management/server/sql_store_test.go index c05793fc6..df5294d73 100644 --- a/management/server/sql_store_test.go +++ b/management/server/sql_store_test.go @@ -1857,3 +1857,66 @@ func TestSqlStore_DeletePolicy(t *testing.T) { require.Error(t, err) require.Nil(t, policy) } + +func TestSqlStore_GetDNSSettings(t *testing.T) { + store, cleanup, err := NewTestStoreFromSQL(context.Background(), "testdata/store.sql", t.TempDir()) + t.Cleanup(cleanup) + require.NoError(t, err) + + tests := []struct { + name string + accountID string + expectError bool + }{ + { + name: "retrieve existing account dns settings", + accountID: "bf1c8084-ba50-4ce7-9439-34653001fc3b", + expectError: false, + }, + { + name: "retrieve non-existing account dns settings", + accountID: "non-existing", + expectError: true, + }, + { + name: "retrieve dns settings with empty account ID", + accountID: "", + expectError: true, + }, + } + + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + dnsSettings, err := store.GetAccountDNSSettings(context.Background(), LockingStrengthShare, tt.accountID) + if tt.expectError { + require.Error(t, err) + sErr, ok := status.FromError(err) + require.True(t, ok) + require.Equal(t, sErr.Type(), status.NotFound) + require.Nil(t, dnsSettings) + } else { + require.NoError(t, err) + require.NotNil(t, dnsSettings) + } + }) + } +} + +func TestSqlStore_SaveDNSSettings(t *testing.T) { + store, cleanup, err := NewTestStoreFromSQL(context.Background(), "testdata/store.sql", t.TempDir()) + t.Cleanup(cleanup) + require.NoError(t, err) + + accountID := "bf1c8084-ba50-4ce7-9439-34653001fc3b" + + dnsSettings, err := store.GetAccountDNSSettings(context.Background(), LockingStrengthShare, accountID) + require.NoError(t, err) + + dnsSettings.DisabledManagementGroups = []string{"groupA", "groupB"} + err = store.SaveDNSSettings(context.Background(), LockingStrengthUpdate, accountID, dnsSettings) + require.NoError(t, err) + + saveDNSSettings, err := store.GetAccountDNSSettings(context.Background(), LockingStrengthShare, accountID) + require.NoError(t, err) + require.Equal(t, saveDNSSettings, dnsSettings) +} diff --git a/management/server/store.go b/management/server/store.go index ba61d552d..cca014b52 100644 --- a/management/server/store.go +++ b/management/server/store.go @@ -59,6 +59,7 @@ type Store interface { SaveAccount(ctx context.Context, account *Account) error DeleteAccount(ctx context.Context, account *Account) error UpdateAccountDomainAttributes(ctx context.Context, accountID string, domain string, category string, isPrimaryDomain bool) error + SaveDNSSettings(ctx context.Context, lockStrength LockingStrength, accountID string, settings *DNSSettings) error GetUserByTokenID(ctx context.Context, tokenID string) (*User, error) GetUserByUserID(ctx context.Context, lockStrength LockingStrength, userID string) (*User, error)