From 39329e12a1f54cc1770a010279fa4902ccd51693 Mon Sep 17 00:00:00 2001 From: Viktor Liu <17948409+lixmal@users.noreply.github.com> Date: Wed, 13 Nov 2024 13:46:00 +0100 Subject: [PATCH 001/159] [client] Improve state write timeout and abort work early on timeout (#2882) * Improve state write timeout and abort work early on timeout * Don't block on initial persist state --- client/firewall/iptables/manager_linux.go | 8 +++++--- client/firewall/nftables/manager_linux.go | 8 +++++--- client/internal/config.go | 6 +++--- client/internal/dns/server.go | 10 +++++++--- client/internal/statemanager/manager.go | 20 +++++--------------- client/internal/statemanager/path.go | 22 +++++----------------- management/server/file_store.go | 2 +- util/file.go | 23 ++++++++++++++++++----- util/file_test.go | 7 ++++--- 9 files changed, 53 insertions(+), 53 deletions(-) diff --git a/client/firewall/iptables/manager_linux.go b/client/firewall/iptables/manager_linux.go index a59bd2c60..adb8f20ef 100644 --- a/client/firewall/iptables/manager_linux.go +++ b/client/firewall/iptables/manager_linux.go @@ -83,9 +83,11 @@ func (m *Manager) Init(stateManager *statemanager.Manager) error { } // persist early to ensure cleanup of chains - if err := stateManager.PersistState(context.Background()); err != nil { - log.Errorf("failed to persist state: %v", err) - } + go func() { + if err := stateManager.PersistState(context.Background()); err != nil { + log.Errorf("failed to persist state: %v", err) + } + }() return nil } diff --git a/client/firewall/nftables/manager_linux.go b/client/firewall/nftables/manager_linux.go index ea8912f27..3f8fac249 100644 --- a/client/firewall/nftables/manager_linux.go +++ b/client/firewall/nftables/manager_linux.go @@ -99,9 +99,11 @@ func (m *Manager) Init(stateManager *statemanager.Manager) error { } // persist early - if err := stateManager.PersistState(context.Background()); err != nil { - log.Errorf("failed to persist state: %v", err) - } + go func() { + if err := stateManager.PersistState(context.Background()); err != nil { + log.Errorf("failed to persist state: %v", err) + } + }() return nil } diff --git a/client/internal/config.go b/client/internal/config.go index ee54c6380..ce87835cd 100644 --- a/client/internal/config.go +++ b/client/internal/config.go @@ -164,7 +164,7 @@ func UpdateOrCreateConfig(input ConfigInput) (*Config, error) { if err != nil { return nil, err } - err = util.WriteJsonWithRestrictedPermission(input.ConfigPath, cfg) + err = util.WriteJsonWithRestrictedPermission(context.Background(), input.ConfigPath, cfg) return cfg, err } @@ -185,7 +185,7 @@ func CreateInMemoryConfig(input ConfigInput) (*Config, error) { // WriteOutConfig write put the prepared config to the given path func WriteOutConfig(path string, config *Config) error { - return util.WriteJson(path, config) + return util.WriteJson(context.Background(), path, config) } // createNewConfig creates a new config generating a new Wireguard key and saving to file @@ -215,7 +215,7 @@ func update(input ConfigInput) (*Config, error) { } if updated { - if err := util.WriteJson(input.ConfigPath, config); err != nil { + if err := util.WriteJson(context.Background(), input.ConfigPath, config); err != nil { return nil, err } } diff --git a/client/internal/dns/server.go b/client/internal/dns/server.go index 929e1e60c..6c4dccae7 100644 --- a/client/internal/dns/server.go +++ b/client/internal/dns/server.go @@ -326,9 +326,13 @@ func (s *DefaultServer) applyConfiguration(update nbdns.Config) error { // persist dns state right 
away ctx, cancel := context.WithTimeout(s.ctx, 3*time.Second) defer cancel() - if err := s.stateManager.PersistState(ctx); err != nil { - log.Errorf("Failed to persist dns state: %v", err) - } + + // don't block + go func() { + if err := s.stateManager.PersistState(ctx); err != nil { + log.Errorf("Failed to persist dns state: %v", err) + } + }() if s.searchDomainNotifier != nil { s.searchDomainNotifier.onNewSearchDomains(s.SearchDomains()) diff --git a/client/internal/statemanager/manager.go b/client/internal/statemanager/manager.go index a5a14f807..580ccdfc7 100644 --- a/client/internal/statemanager/manager.go +++ b/client/internal/statemanager/manager.go @@ -16,6 +16,7 @@ import ( "golang.org/x/exp/maps" nberrors "github.com/netbirdio/netbird/client/errors" + "github.com/netbirdio/netbird/util" ) // State interface defines the methods that all state types must implement @@ -178,25 +179,14 @@ func (m *Manager) PersistState(ctx context.Context) error { return nil } - ctx, cancel := context.WithTimeout(ctx, 3*time.Second) + ctx, cancel := context.WithTimeout(ctx, 5*time.Second) defer cancel() done := make(chan error, 1) + start := time.Now() go func() { - data, err := json.MarshalIndent(m.states, "", " ") - if err != nil { - done <- fmt.Errorf("marshal states: %w", err) - return - } - - // nolint:gosec - if err := os.WriteFile(m.filePath, data, 0640); err != nil { - done <- fmt.Errorf("write state file: %w", err) - return - } - - done <- nil + done <- util.WriteJsonWithRestrictedPermission(ctx, m.filePath, m.states) }() select { @@ -208,7 +198,7 @@ func (m *Manager) PersistState(ctx context.Context) error { } } - log.Debugf("persisted shutdown states: %v", maps.Keys(m.dirty)) + log.Debugf("persisted shutdown states: %v, took %v", maps.Keys(m.dirty), time.Since(start)) clear(m.dirty) diff --git a/client/internal/statemanager/path.go b/client/internal/statemanager/path.go index 96d6a9f12..6cfd79a12 100644 --- a/client/internal/statemanager/path.go +++ b/client/internal/statemanager/path.go @@ -4,32 +4,20 @@ import ( "os" "path/filepath" "runtime" - - log "github.com/sirupsen/logrus" ) // GetDefaultStatePath returns the path to the state file based on the operating system -// It returns an empty string if the path cannot be determined. It also creates the directory if it does not exist. +// It returns an empty string if the path cannot be determined. func GetDefaultStatePath() string { - var path string - switch runtime.GOOS { case "windows": - path = filepath.Join(os.Getenv("PROGRAMDATA"), "Netbird", "state.json") + return filepath.Join(os.Getenv("PROGRAMDATA"), "Netbird", "state.json") case "darwin", "linux": - path = "/var/lib/netbird/state.json" + return "/var/lib/netbird/state.json" case "freebsd", "openbsd", "netbsd", "dragonfly": - path = "/var/db/netbird/state.json" - // ios/android don't need state - default: - return "" + return "/var/db/netbird/state.json" } - dir := filepath.Dir(path) - if err := os.MkdirAll(dir, 0755); err != nil { - log.Errorf("Error creating directory %s: %v. 
Continuing without state support.", dir, err) - return "" - } + return "" - return path } diff --git a/management/server/file_store.go b/management/server/file_store.go index 561e133ce..f375fb990 100644 --- a/management/server/file_store.go +++ b/management/server/file_store.go @@ -223,7 +223,7 @@ func restore(ctx context.Context, file string) (*FileStore, error) { // It is recommended to call it with locking FileStore.mux func (s *FileStore) persist(ctx context.Context, file string) error { start := time.Now() - err := util.WriteJson(file, s) + err := util.WriteJson(context.Background(), file, s) if err != nil { return err } diff --git a/util/file.go b/util/file.go index ecaecd222..4641cc1b8 100644 --- a/util/file.go +++ b/util/file.go @@ -15,7 +15,7 @@ import ( ) // WriteJsonWithRestrictedPermission writes JSON config object to a file. Enforces permission on the parent directory -func WriteJsonWithRestrictedPermission(file string, obj interface{}) error { +func WriteJsonWithRestrictedPermission(ctx context.Context, file string, obj interface{}) error { configDir, configFileName, err := prepareConfigFileDir(file) if err != nil { return err @@ -26,18 +26,18 @@ func WriteJsonWithRestrictedPermission(file string, obj interface{}) error { return err } - return writeJson(file, obj, configDir, configFileName) + return writeJson(ctx, file, obj, configDir, configFileName) } // WriteJson writes JSON config object to a file creating parent directories if required // The output JSON is pretty-formatted -func WriteJson(file string, obj interface{}) error { +func WriteJson(ctx context.Context, file string, obj interface{}) error { configDir, configFileName, err := prepareConfigFileDir(file) if err != nil { return err } - return writeJson(file, obj, configDir, configFileName) + return writeJson(ctx, file, obj, configDir, configFileName) } // DirectWriteJson writes JSON config object to a file creating parent directories if required without creating a temporary file @@ -79,7 +79,11 @@ func DirectWriteJson(ctx context.Context, file string, obj interface{}) error { return nil } -func writeJson(file string, obj interface{}, configDir string, configFileName string) error { +func writeJson(ctx context.Context, file string, obj interface{}, configDir string, configFileName string) error { + // Check context before expensive operations + if ctx.Err() != nil { + return ctx.Err() + } // make it pretty bs, err := json.MarshalIndent(obj, "", " ") @@ -87,6 +91,10 @@ func writeJson(file string, obj interface{}, configDir string, configFileName st return err } + if ctx.Err() != nil { + return ctx.Err() + } + tempFile, err := os.CreateTemp(configDir, ".*"+configFileName) if err != nil { return err @@ -111,6 +119,11 @@ func writeJson(file string, obj interface{}, configDir string, configFileName st return err } + // Check context again + if ctx.Err() != nil { + return ctx.Err() + } + err = os.Rename(tempFileName, file) if err != nil { return err diff --git a/util/file_test.go b/util/file_test.go index 566d8eda6..f8c9dfabb 100644 --- a/util/file_test.go +++ b/util/file_test.go @@ -1,6 +1,7 @@ package util import ( + "context" "crypto/md5" "encoding/hex" "io" @@ -39,7 +40,7 @@ func TestConfigJSON(t *testing.T) { t.Run(tt.name, func(t *testing.T) { tmpDir := t.TempDir() - err := WriteJson(tmpDir+"/testconfig.json", tt.config) + err := WriteJson(context.Background(), tmpDir+"/testconfig.json", tt.config) require.NoError(t, err) read, err := ReadJson(tmpDir+"/testconfig.json", &TestConfig{}) @@ -73,7 +74,7 @@ func 
TestCopyFileContents(t *testing.T) { src := tmpDir + "/copytest_src" dst := tmpDir + "/copytest_dst" - err := WriteJson(src, tt.srcContent) + err := WriteJson(context.Background(), src, tt.srcContent) require.NoError(t, err) err = CopyFileContents(src, dst) @@ -127,7 +128,7 @@ func TestHandleConfigFileWithoutFullPath(t *testing.T) { _ = os.Remove(cfgFile) }() - err := WriteJson(cfgFile, tt.config) + err := WriteJson(context.Background(), cfgFile, tt.config) require.NoError(t, err) read, err := ReadJson(cfgFile, &TestConfig{}) From ed047ec9dda048120edf4f074162a27136ac3cd6 Mon Sep 17 00:00:00 2001 From: bcmmbaga Date: Wed, 13 Nov 2024 16:16:30 +0300 Subject: [PATCH 002/159] Add account locking and merge group deletion methods Signed-off-by: bcmmbaga --- management/server/group.go | 66 ++++++++++------------------------ management/server/sql_store.go | 2 +- 2 files changed, 20 insertions(+), 48 deletions(-) diff --git a/management/server/group.go b/management/server/group.go index 57960e7f9..154a33b13 100644 --- a/management/server/group.go +++ b/management/server/group.go @@ -215,48 +215,9 @@ func difference(a, b []string) []string { // DeleteGroup object of the peers. func (am *DefaultAccountManager) DeleteGroup(ctx context.Context, accountID, userID, groupID string) error { - user, err := am.Store.GetUserByUserID(ctx, LockingStrengthShare, userID) - if err != nil { - return err - } - - if user.AccountID != accountID { - return status.NewUserNotPartOfAccountError() - } - - if user.IsRegularUser() { - return status.NewAdminPermissionError() - } - - var group *nbgroup.Group - - err = am.Store.ExecuteInTransaction(ctx, func(transaction Store) error { - group, err = transaction.GetGroupByID(ctx, LockingStrengthShare, accountID, groupID) - if err != nil { - return err - } - - if group.IsGroupAll() { - return status.Errorf(status.InvalidArgument, "deleting group ALL is not allowed") - } - - if err = validateDeleteGroup(ctx, transaction, group, userID); err != nil { - return err - } - - if err = transaction.IncrementNetworkSerial(ctx, LockingStrengthUpdate, accountID); err != nil { - return err - } - - return transaction.DeleteGroup(ctx, LockingStrengthUpdate, accountID, groupID) - }) - if err != nil { - return err - } - - am.StoreEvent(ctx, userID, groupID, accountID, activity.GroupDeleted, group.EventMeta()) - - return nil + unlock := am.Store.AcquireWriteLockByUID(ctx, accountID) + defer unlock() + return am.DeleteGroups(ctx, accountID, userID, []string{groupID}) } // DeleteGroups deletes groups from an account. 
@@ -285,13 +246,14 @@ func (am *DefaultAccountManager) DeleteGroups(ctx context.Context, accountID, us err = am.Store.ExecuteInTransaction(ctx, func(transaction Store) error { for _, groupID := range groupIDs { - group, err := transaction.GetGroupByID(ctx, LockingStrengthShare, accountID, groupID) + group, err := transaction.GetGroupByID(ctx, LockingStrengthUpdate, accountID, groupID) if err != nil { + allErrors = errors.Join(allErrors, err) continue } if err := validateDeleteGroup(ctx, transaction, group, userID); err != nil { - allErrors = errors.Join(allErrors, fmt.Errorf("failed to delete group %s: %w", groupID, err)) + allErrors = errors.Join(allErrors, err) continue } @@ -318,12 +280,15 @@ func (am *DefaultAccountManager) DeleteGroups(ctx context.Context, accountID, us // GroupAddPeer appends peer to the group func (am *DefaultAccountManager) GroupAddPeer(ctx context.Context, accountID, groupID, peerID string) error { + unlock := am.Store.AcquireWriteLockByUID(ctx, accountID) + defer unlock() + var group *nbgroup.Group var updateAccountPeers bool var err error err = am.Store.ExecuteInTransaction(ctx, func(transaction Store) error { - group, err = transaction.GetGroupByID(context.Background(), LockingStrengthShare, accountID, groupID) + group, err = transaction.GetGroupByID(context.Background(), LockingStrengthUpdate, accountID, groupID) if err != nil { return err } @@ -356,12 +321,15 @@ func (am *DefaultAccountManager) GroupAddPeer(ctx context.Context, accountID, gr // GroupDeletePeer removes peer from the group func (am *DefaultAccountManager) GroupDeletePeer(ctx context.Context, accountID, groupID, peerID string) error { + unlock := am.Store.AcquireWriteLockByUID(ctx, accountID) + defer unlock() + var group *nbgroup.Group var updateAccountPeers bool var err error err = am.Store.ExecuteInTransaction(ctx, func(transaction Store) error { - group, err = transaction.GetGroupByID(context.Background(), LockingStrengthShare, accountID, groupID) + group, err = transaction.GetGroupByID(context.Background(), LockingStrengthUpdate, accountID, groupID) if err != nil { return err } @@ -430,13 +398,17 @@ func validateDeleteGroup(ctx context.Context, transaction Store, group *nbgroup. 
if group.Issued == nbgroup.GroupIssuedIntegration { executingUser, err := transaction.GetUserByUserID(ctx, LockingStrengthShare, userID) if err != nil { - return status.Errorf(status.NotFound, "user not found") + return err } if executingUser.Role != UserRoleAdmin || !executingUser.IsServiceUser { return status.Errorf(status.PermissionDenied, "only service users with admin power can delete integration group") } } + if group.IsGroupAll() { + return status.Errorf(status.InvalidArgument, "deleting group ALL is not allowed") + } + if isLinked, linkedRoute := isGroupLinkedToRoute(ctx, transaction, group.AccountID, group.ID); isLinked { return &GroupLinkError{"route", string(linkedRoute.NetID)} } diff --git a/management/server/sql_store.go b/management/server/sql_store.go index 7c741d35c..0ebda6440 100644 --- a/management/server/sql_store.go +++ b/management/server/sql_store.go @@ -1278,7 +1278,7 @@ func (s *SqlStore) DeleteGroups(ctx context.Context, strength LockingStrength, a Delete(&nbgroup.Group{}, accountAndIDsQueryCondition, accountID, groupIDs) if result.Error != nil { log.WithContext(ctx).Errorf("failed to delete groups from store: %v", result.Error) - return status.Errorf(status.Internal, "failed to delete groups from store: %v", result.Error) + return status.Errorf(status.Internal, "failed to delete groups from store") } return nil From a4d905ffe77881b682a4798d5564b89860404a0a Mon Sep 17 00:00:00 2001 From: bcmmbaga Date: Wed, 13 Nov 2024 16:56:22 +0300 Subject: [PATCH 003/159] Fix tests Signed-off-by: bcmmbaga --- management/server/group_test.go | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/management/server/group_test.go b/management/server/group_test.go index 89184e819..59094a23e 100644 --- a/management/server/group_test.go +++ b/management/server/group_test.go @@ -208,7 +208,7 @@ func TestDefaultAccountManager_DeleteGroups(t *testing.T) { { name: "delete non-existent group", groupIDs: []string{"non-existent-group"}, - expectedDeleted: []string{"non-existent-group"}, + expectedReasons: []string{"group: non-existent-group not found"}, }, { name: "delete multiple groups with mixed results", From b48afd92fdc2080cd0b4bf7bd4c046684b76338a Mon Sep 17 00:00:00 2001 From: Zoltan Papp Date: Wed, 13 Nov 2024 15:02:51 +0100 Subject: [PATCH 004/159] [relay-server] Always close ws conn when work thread exit (#2879) Close ws conn when work thread exit --- relay/server/peer.go | 17 ++++++++++++----- 1 file changed, 12 insertions(+), 5 deletions(-) diff --git a/relay/server/peer.go b/relay/server/peer.go index c909c35d5..f65fb786a 100644 --- a/relay/server/peer.go +++ b/relay/server/peer.go @@ -16,6 +16,8 @@ import ( const ( bufferSize = 8820 + + errCloseConn = "failed to close connection to peer: %s" ) // Peer represents a peer connection @@ -46,6 +48,12 @@ func NewPeer(metrics *metrics.Metrics, id []byte, conn net.Conn, store *Store) * // It manages the protocol (healthcheck, transport, close). Read the message and determine the message type and handle // the message accordingly. 
func (p *Peer) Work() { + defer func() { + if err := p.conn.Close(); err != nil && !errors.Is(err, net.ErrClosed) { + p.log.Errorf(errCloseConn, err) + } + }() + ctx, cancel := context.WithCancel(context.Background()) defer cancel() @@ -97,7 +105,7 @@ func (p *Peer) handleMsgType(ctx context.Context, msgType messages.MsgType, hc * case messages.MsgTypeClose: p.log.Infof("peer exited gracefully") if err := p.conn.Close(); err != nil { - log.Errorf("failed to close connection to peer: %s", err) + log.Errorf(errCloseConn, err) } default: p.log.Warnf("received unexpected message type: %s", msgType) @@ -121,9 +129,8 @@ func (p *Peer) CloseGracefully(ctx context.Context) { p.log.Errorf("failed to send close message to peer: %s", p.String()) } - err = p.conn.Close() - if err != nil { - p.log.Errorf("failed to close connection to peer: %s", err) + if err := p.conn.Close(); err != nil { + p.log.Errorf(errCloseConn, err) } } @@ -132,7 +139,7 @@ func (p *Peer) Close() { defer p.connMu.Unlock() if err := p.conn.Close(); err != nil { - p.log.Errorf("failed to close connection to peer: %s", err) + p.log.Errorf(errCloseConn, err) } } From 6886691213aa622b1f0f7440c83a3a4c4831cf3d Mon Sep 17 00:00:00 2001 From: Maycon Santos Date: Wed, 13 Nov 2024 15:21:33 +0100 Subject: [PATCH 005/159] Update route calculation tests (#2884) - Add two new test cases for p2p and relay routes with same latency - Add extra statuses generation --- client/internal/routemanager/client_test.go | 98 +++++++++++++++++++++ 1 file changed, 98 insertions(+) diff --git a/client/internal/routemanager/client_test.go b/client/internal/routemanager/client_test.go index 583156e4d..56fcf1613 100644 --- a/client/internal/routemanager/client_test.go +++ b/client/internal/routemanager/client_test.go @@ -1,6 +1,7 @@ package routemanager import ( + "fmt" "net/netip" "testing" "time" @@ -227,6 +228,64 @@ func TestGetBestrouteFromStatuses(t *testing.T) { currentRoute: "route1", expectedRouteID: "route1", }, + { + name: "relayed routes with latency 0 should maintain previous choice", + statuses: map[route.ID]routerPeerStatus{ + "route1": { + connected: true, + relayed: true, + latency: 0 * time.Millisecond, + }, + "route2": { + connected: true, + relayed: true, + latency: 0 * time.Millisecond, + }, + }, + existingRoutes: map[route.ID]*route.Route{ + "route1": { + ID: "route1", + Metric: route.MaxMetric, + Peer: "peer1", + }, + "route2": { + ID: "route2", + Metric: route.MaxMetric, + Peer: "peer2", + }, + }, + currentRoute: "route1", + expectedRouteID: "route1", + }, + { + name: "p2p routes with latency 0 should maintain previous choice", + statuses: map[route.ID]routerPeerStatus{ + "route1": { + connected: true, + relayed: false, + latency: 0 * time.Millisecond, + }, + "route2": { + connected: true, + relayed: false, + latency: 0 * time.Millisecond, + }, + }, + existingRoutes: map[route.ID]*route.Route{ + "route1": { + ID: "route1", + Metric: route.MaxMetric, + Peer: "peer1", + }, + "route2": { + ID: "route2", + Metric: route.MaxMetric, + Peer: "peer2", + }, + }, + currentRoute: "route1", + expectedRouteID: "route1", + }, { name: "current route with bad score should be changed to route with better score", statuses: map[route.ID]routerPeerStatus{ @@ -287,6 +346,45 @@ func TestGetBestrouteFromStatuses(t *testing.T) { }, } + // fill the test data with random routes + for _, tc := range testCases { + for i := 0; i < 50; i++ { + dummyRoute := &route.Route{ + ID: route.ID(fmt.Sprintf("dummy_p1_%d", i)), + Metric: route.MinMetric, + Peer: 
fmt.Sprintf("dummy_p1_%d", i), + } + tc.existingRoutes[dummyRoute.ID] = dummyRoute + } + for i := 0; i < 50; i++ { + dummyRoute := &route.Route{ + ID: route.ID(fmt.Sprintf("dummy_p2_%d", i)), + Metric: route.MinMetric, + Peer: fmt.Sprintf("dummy_p1_%d", i), + } + tc.existingRoutes[dummyRoute.ID] = dummyRoute + } + + for i := 0; i < 50; i++ { + id := route.ID(fmt.Sprintf("dummy_p1_%d", i)) + dummyStatus := routerPeerStatus{ + connected: false, + relayed: true, + latency: 0, + } + tc.statuses[id] = dummyStatus + } + for i := 0; i < 50; i++ { + id := route.ID(fmt.Sprintf("dummy_p2_%d", i)) + dummyStatus := routerPeerStatus{ + connected: false, + relayed: true, + latency: 0, + } + tc.statuses[id] = dummyStatus + } + } + for _, tc := range testCases { t.Run(tc.name, func(t *testing.T) { currentRoute := &route.Route{ From be78efbd429c0b7180efe36fe4a826436e5a675b Mon Sep 17 00:00:00 2001 From: Viktor Liu <17948409+lixmal@users.noreply.github.com> Date: Thu, 14 Nov 2024 20:15:16 +0100 Subject: [PATCH 006/159] [client] Handle panic on nil wg interface (#2891) --- client/internal/engine.go | 7 +++++-- 1 file changed, 5 insertions(+), 2 deletions(-) diff --git a/client/internal/engine.go b/client/internal/engine.go index 190d795cd..0f3a5d28a 100644 --- a/client/internal/engine.go +++ b/client/internal/engine.go @@ -38,7 +38,6 @@ import ( "github.com/netbirdio/netbird/client/internal/routemanager/systemops" "github.com/netbirdio/netbird/client/internal/statemanager" - nbssh "github.com/netbirdio/netbird/client/ssh" "github.com/netbirdio/netbird/client/system" nbdns "github.com/netbirdio/netbird/dns" @@ -171,7 +170,7 @@ type Engine struct { relayManager *relayClient.Manager stateManager *statemanager.Manager - srWatcher *guard.SRWatcher + srWatcher *guard.SRWatcher } // Peer is an instance of the Connection Peer @@ -641,6 +640,10 @@ func (e *Engine) updateSSH(sshConf *mgmProto.SSHConfig) error { } func (e *Engine) updateConfig(conf *mgmProto.PeerConfig) error { + if e.wgInterface == nil { + return errors.New("wireguard interface is not initialized") + } + if e.wgInterface.Address().String() != conf.Address { oldAddr := e.wgInterface.Address().String() log.Debugf("updating peer address from %s to %s", oldAddr, conf.Address) From 44e799c687ed1fd5e6a658aff3a06ac1594cec69 Mon Sep 17 00:00:00 2001 From: Pascal Fischer <32096965+pascal-fischer@users.noreply.github.com> Date: Fri, 15 Nov 2024 11:16:16 +0100 Subject: [PATCH 007/159] [management] Fix limited peer view groups (#2894) --- management/server/group.go | 12 ++++-------- management/server/http/peers_handler.go | 20 ++++++++++++++++---- 2 files changed, 20 insertions(+), 12 deletions(-) diff --git a/management/server/group.go b/management/server/group.go index b2ec88cc0..7b4f07948 100644 --- a/management/server/group.go +++ b/management/server/group.go @@ -6,11 +6,12 @@ import ( "fmt" "slices" - nbdns "github.com/netbirdio/netbird/dns" - "github.com/netbirdio/netbird/route" "github.com/rs/xid" log "github.com/sirupsen/logrus" + nbdns "github.com/netbirdio/netbird/dns" + "github.com/netbirdio/netbird/route" + "github.com/netbirdio/netbird/management/server/activity" nbgroup "github.com/netbirdio/netbird/management/server/group" "github.com/netbirdio/netbird/management/server/status" @@ -27,17 +28,12 @@ func (e *GroupLinkError) Error() string { // CheckGroupPermissions validates if a user has the necessary permissions to view groups func (am *DefaultAccountManager) CheckGroupPermissions(ctx context.Context, accountID, userID string) error { - settings, 
err := am.Store.GetAccountSettings(ctx, LockingStrengthShare, accountID) - if err != nil { - return err - } - user, err := am.Store.GetUserByUserID(ctx, LockingStrengthShare, userID) if err != nil { return err } - if (!user.IsAdminOrServiceUser() && settings.RegularUsersViewBlocked) || user.AccountID != accountID { + if !user.IsAdminOrServiceUser() || user.AccountID != accountID { return status.Errorf(status.PermissionDenied, "groups are blocked for users") } diff --git a/management/server/http/peers_handler.go b/management/server/http/peers_handler.go index a5856a0e4..f5027cd77 100644 --- a/management/server/http/peers_handler.go +++ b/management/server/http/peers_handler.go @@ -184,14 +184,26 @@ func (h *PeersHandler) GetAllPeers(w http.ResponseWriter, r *http.Request) { dnsDomain := h.accountManager.GetDNSDomain() - respBody := make([]*api.PeerBatch, 0, len(account.Peers)) - for _, peer := range account.Peers { + peers, err := h.accountManager.GetPeers(r.Context(), accountID, userID) + if err != nil { + util.WriteError(r.Context(), err, w) + return + } + + groupsMap := map[string]*nbgroup.Group{} + groups, _ := h.accountManager.GetAllGroups(r.Context(), accountID, userID) + for _, group := range groups { + groupsMap[group.ID] = group + } + + respBody := make([]*api.PeerBatch, 0, len(peers)) + for _, peer := range peers { peerToReturn, err := h.checkPeerStatus(peer) if err != nil { util.WriteError(r.Context(), err, w) return } - groupMinimumInfo := toGroupsInfo(account.Groups, peer.ID) + groupMinimumInfo := toGroupsInfo(groupsMap, peer.ID) respBody = append(respBody, toPeerListItemResponse(peerToReturn, groupMinimumInfo, dnsDomain, 0)) } @@ -304,7 +316,7 @@ func peerToAccessiblePeer(peer *nbpeer.Peer, dnsDomain string) api.AccessiblePee } func toGroupsInfo(groups map[string]*nbgroup.Group, peerID string) []api.GroupMinimum { - var groupsInfo []api.GroupMinimum + groupsInfo := []api.GroupMinimum{} groupsChecked := make(map[string]struct{}) for _, group := range groups { _, ok := groupsChecked[group.ID] From 4ef3890bf757f149da7bce3588f62bfbf85b9d8c Mon Sep 17 00:00:00 2001 From: bcmmbaga Date: Fri, 15 Nov 2024 17:48:00 +0300 Subject: [PATCH 008/159] Fix typo Signed-off-by: bcmmbaga --- management/server/dns.go | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/management/server/dns.go b/management/server/dns.go index be7caea4e..8df211b0b 100644 --- a/management/server/dns.go +++ b/management/server/dns.go @@ -161,7 +161,7 @@ func (am *DefaultAccountManager) SaveDNSSettings(ctx context.Context, accountID return nil } -// prepareGroupEvents prepares a list of event functions to be stored. +// prepareDNSSettingsEvents prepares a list of event functions to be stored. 
func (am *DefaultAccountManager) prepareDNSSettingsEvents(ctx context.Context, transaction Store, accountID, userID string, addedGroups, removedGroups []string) []func() { var eventsToStore []func() From 4aee3c9e33d6c733d38ac2b6e6287ea78e5ac591 Mon Sep 17 00:00:00 2001 From: Pascal Fischer <32096965+pascal-fischer@users.noreply.github.com> Date: Fri, 15 Nov 2024 16:59:03 +0100 Subject: [PATCH 009/159] [client/management] add peer lock to peer meta update and fix isEqual func (#2840) --- client/internal/engine.go | 12 +++++ client/internal/engine_test.go | 93 ++++++++++++++++++++++++++++++++++ management/server/account.go | 5 +- management/server/peer.go | 3 ++ 4 files changed, 112 insertions(+), 1 deletion(-) diff --git a/client/internal/engine.go b/client/internal/engine.go index 0f3a5d28a..d4a3a561a 100644 --- a/client/internal/engine.go +++ b/client/internal/engine.go @@ -11,6 +11,7 @@ import ( "reflect" "runtime" "slices" + "sort" "strings" "sync" "sync/atomic" @@ -1484,6 +1485,17 @@ func (e *Engine) stopDNSServer() { // isChecksEqual checks if two slices of checks are equal. func isChecksEqual(checks []*mgmProto.Checks, oChecks []*mgmProto.Checks) bool { + for _, check := range checks { + sort.Slice(check.Files, func(i, j int) bool { + return check.Files[i] < check.Files[j] + }) + } + for _, oCheck := range oChecks { + sort.Slice(oCheck.Files, func(i, j int) bool { + return oCheck.Files[i] < oCheck.Files[j] + }) + } + return slices.EqualFunc(checks, oChecks, func(checks, oChecks *mgmProto.Checks) bool { return slices.Equal(checks.Files, oChecks.Files) }) diff --git a/client/internal/engine_test.go b/client/internal/engine_test.go index 0018af6df..b6c6186ea 100644 --- a/client/internal/engine_test.go +++ b/client/internal/engine_test.go @@ -1006,6 +1006,99 @@ func Test_ParseNATExternalIPMappings(t *testing.T) { } } +func Test_CheckFilesEqual(t *testing.T) { + testCases := []struct { + name string + inputChecks1 []*mgmtProto.Checks + inputChecks2 []*mgmtProto.Checks + expectedBool bool + }{ + { + name: "Equal Files In Equal Order Should Return True", + inputChecks1: []*mgmtProto.Checks{ + { + Files: []string{ + "testfile1", + "testfile2", + }, + }, + }, + inputChecks2: []*mgmtProto.Checks{ + { + Files: []string{ + "testfile1", + "testfile2", + }, + }, + }, + expectedBool: true, + }, + { + name: "Equal Files In Reverse Order Should Return True", + inputChecks1: []*mgmtProto.Checks{ + { + Files: []string{ + "testfile1", + "testfile2", + }, + }, + }, + inputChecks2: []*mgmtProto.Checks{ + { + Files: []string{ + "testfile2", + "testfile1", + }, + }, + }, + expectedBool: true, + }, + { + name: "Unequal Files Should Return False", + inputChecks1: []*mgmtProto.Checks{ + { + Files: []string{ + "testfile1", + "testfile2", + }, + }, + }, + inputChecks2: []*mgmtProto.Checks{ + { + Files: []string{ + "testfile1", + "testfile3", + }, + }, + }, + expectedBool: false, + }, + { + name: "Compared With Empty Should Return False", + inputChecks1: []*mgmtProto.Checks{ + { + Files: []string{ + "testfile1", + "testfile2", + }, + }, + }, + inputChecks2: []*mgmtProto.Checks{ + { + Files: []string{}, + }, + }, + expectedBool: false, + }, + } + for _, testCase := range testCases { + t.Run(testCase.name, func(t *testing.T) { + result := isChecksEqual(testCase.inputChecks1, testCase.inputChecks2) + assert.Equal(t, testCase.expectedBool, result, "result should match expected bool") + }) + } +} + func createEngine(ctx context.Context, cancel context.CancelFunc, setupKey string, i int, mgmtAddr string, signalAddr 
string) (*Engine, error) { key, err := wgtypes.GeneratePrivateKey() if err != nil { diff --git a/management/server/account.go b/management/server/account.go index bf6039229..4afadb4e9 100644 --- a/management/server/account.go +++ b/management/server/account.go @@ -2319,7 +2319,7 @@ func (am *DefaultAccountManager) OnPeerDisconnected(ctx context.Context, account err = am.MarkPeerConnected(ctx, peerPubKey, false, nil, account) if err != nil { - log.WithContext(ctx).Warnf("failed marking peer as connected %s %v", peerPubKey, err) + log.WithContext(ctx).Warnf("failed marking peer as disconnected %s %v", peerPubKey, err) } return nil @@ -2335,6 +2335,9 @@ func (am *DefaultAccountManager) SyncPeerMeta(ctx context.Context, peerPubKey st unlock := am.Store.AcquireReadLockByUID(ctx, accountID) defer unlock() + unlockPeer := am.Store.AcquireWriteLockByUID(ctx, peerPubKey) + defer unlockPeer() + account, err := am.Store.GetAccount(ctx, accountID) if err != nil { return err diff --git a/management/server/peer.go b/management/server/peer.go index 9784650de..87b5b0e4e 100644 --- a/management/server/peer.go +++ b/management/server/peer.go @@ -168,6 +168,8 @@ func (am *DefaultAccountManager) updatePeerStatusAndLocation(ctx context.Context account.UpdatePeer(peer) + log.WithContext(ctx).Tracef("saving peer status for peer %s is connected: %t", peer.ID, connected) + err := am.Store.SavePeerStatus(account.Id, peer.ID, *newStatus) if err != nil { return false, fmt.Errorf("failed to save peer status: %w", err) @@ -657,6 +659,7 @@ func (am *DefaultAccountManager) SyncPeer(ctx context.Context, sync PeerSync, ac updated := peer.UpdateMetaIfNew(sync.Meta) if updated { + log.WithContext(ctx).Tracef("peer %s metadata updated", peer.ID) err = am.Store.SavePeer(ctx, account.Id, peer) if err != nil { return nil, nil, nil, fmt.Errorf("failed to save peer: %w", err) From d9b691b8a56bb8bdd11f02364e0dbf59ae0580bc Mon Sep 17 00:00:00 2001 From: Pascal Fischer <32096965+pascal-fischer@users.noreply.github.com> Date: Fri, 15 Nov 2024 17:00:06 +0100 Subject: [PATCH 010/159] [management] Limit the setup-key update operation (#2841) --- management/server/http/api/openapi.yml | 25 --------------- management/server/http/api/types.gen.go | 15 --------- management/server/http/setupkeys_handler.go | 6 ---- management/server/setupkey.go | 12 ++++--- management/server/setupkey_test.go | 35 ++++++++++++++++++--- 5 files changed, 38 insertions(+), 55 deletions(-) diff --git a/management/server/http/api/openapi.yml b/management/server/http/api/openapi.yml index 9b4592ccf..bfb375277 100644 --- a/management/server/http/api/openapi.yml +++ b/management/server/http/api/openapi.yml @@ -521,19 +521,6 @@ components: SetupKeyRequest: type: object properties: - name: - description: Setup Key name - type: string - example: Default key - type: - description: Setup key type, one-off for single time usage and reusable - type: string - example: reusable - expires_in: - description: Expiration time in seconds, 0 will mean the key never expires - type: integer - minimum: 0 - example: 86400 revoked: description: Setup key revocation status type: boolean @@ -544,21 +531,9 @@ components: items: type: string example: "ch8i4ug6lnn4g9hqv7m0" - usage_limit: - description: A number of times this key can be used. The value of 0 indicates the unlimited usage. 
- type: integer - example: 0 - ephemeral: - description: Indicate that the peer will be ephemeral or not - type: boolean - example: true required: - - name - - type - - expires_in - revoked - auto_groups - - usage_limit CreateSetupKeyRequest: type: object properties: diff --git a/management/server/http/api/types.gen.go b/management/server/http/api/types.gen.go index c1ef1ba21..f219c4574 100644 --- a/management/server/http/api/types.gen.go +++ b/management/server/http/api/types.gen.go @@ -1098,23 +1098,8 @@ type SetupKeyRequest struct { // AutoGroups List of group IDs to auto-assign to peers registered with this key AutoGroups []string `json:"auto_groups"` - // Ephemeral Indicate that the peer will be ephemeral or not - Ephemeral *bool `json:"ephemeral,omitempty"` - - // ExpiresIn Expiration time in seconds, 0 will mean the key never expires - ExpiresIn int `json:"expires_in"` - - // Name Setup Key name - Name string `json:"name"` - // Revoked Setup key revocation status Revoked bool `json:"revoked"` - - // Type Setup key type, one-off for single time usage and reusable - Type string `json:"type"` - - // UsageLimit A number of times this key can be used. The value of 0 indicates the unlimited usage. - UsageLimit int `json:"usage_limit"` } // User defines model for User. diff --git a/management/server/http/setupkeys_handler.go b/management/server/http/setupkeys_handler.go index 31859f59b..9ba5977bb 100644 --- a/management/server/http/setupkeys_handler.go +++ b/management/server/http/setupkeys_handler.go @@ -137,11 +137,6 @@ func (h *SetupKeysHandler) UpdateSetupKey(w http.ResponseWriter, r *http.Request return } - if req.Name == "" { - util.WriteError(r.Context(), status.Errorf(status.InvalidArgument, "setup key name field is invalid: %s", req.Name), w) - return - } - if req.AutoGroups == nil { util.WriteError(r.Context(), status.Errorf(status.InvalidArgument, "setup key AutoGroups field is invalid"), w) return @@ -150,7 +145,6 @@ func (h *SetupKeysHandler) UpdateSetupKey(w http.ResponseWriter, r *http.Request newKey := &server.SetupKey{} newKey.AutoGroups = req.AutoGroups newKey.Revoked = req.Revoked - newKey.Name = req.Name newKey.Id = keyID newKey, err = h.accountManager.SaveSetupKey(r.Context(), accountID, newKey, userID) diff --git a/management/server/setupkey.go b/management/server/setupkey.go index 554c66ba4..960532bf9 100644 --- a/management/server/setupkey.go +++ b/management/server/setupkey.go @@ -12,9 +12,10 @@ import ( "unicode/utf8" "github.com/google/uuid" + log "github.com/sirupsen/logrus" + "github.com/netbirdio/netbird/management/server/activity" "github.com/netbirdio/netbird/management/server/status" - log "github.com/sirupsen/logrus" ) const ( @@ -276,7 +277,7 @@ func (am *DefaultAccountManager) CreateSetupKey(ctx context.Context, accountID s // SaveSetupKey saves the provided SetupKey to the database overriding the existing one. // Due to the unique nature of a SetupKey certain properties must not be overwritten // (e.g. the key itself, creation date, ID, etc). -// These properties are overwritten: Name, AutoGroups, Revoked. The rest is copied from the existing key. +// These properties are overwritten: AutoGroups, Revoked (only from false to true), and the UpdatedAt. The rest is copied from the existing key. 
func (am *DefaultAccountManager) SaveSetupKey(ctx context.Context, accountID string, keyToSave *SetupKey, userID string) (*SetupKey, error) { if keyToSave == nil { return nil, status.Errorf(status.InvalidArgument, "provided setup key to update is nil") @@ -312,9 +313,12 @@ func (am *DefaultAccountManager) SaveSetupKey(ctx context.Context, accountID str return err } - // only auto groups, revoked status, and name can be updated for now + if oldKey.Revoked && !keyToSave.Revoked { + return status.Errorf(status.InvalidArgument, "can't un-revoke a revoked setup key") + } + + // only auto groups, revoked status (from false to true) can be updated newKey = oldKey.Copy() - newKey.Name = keyToSave.Name newKey.AutoGroups = keyToSave.AutoGroups newKey.Revoked = keyToSave.Revoked newKey.UpdatedAt = time.Now().UTC() diff --git a/management/server/setupkey_test.go b/management/server/setupkey_test.go index 2ed8aef95..94ed022fa 100644 --- a/management/server/setupkey_test.go +++ b/management/server/setupkey_test.go @@ -56,11 +56,9 @@ func TestDefaultAccountManager_SaveSetupKey(t *testing.T) { } autoGroups := []string{"group_1", "group_2"} - newKeyName := "my-new-test-key" revoked := true newKey, err := manager.SaveSetupKey(context.Background(), account.Id, &SetupKey{ Id: key.Id, - Name: newKeyName, Revoked: revoked, AutoGroups: autoGroups, }, userID) @@ -68,7 +66,7 @@ func TestDefaultAccountManager_SaveSetupKey(t *testing.T) { t.Fatal(err) } - assertKey(t, newKey, newKeyName, revoked, "reusable", 0, key.CreatedAt, key.ExpiresAt, + assertKey(t, newKey, keyName, revoked, "reusable", 0, key.CreatedAt, key.ExpiresAt, key.Id, time.Now().UTC(), autoGroups, true) // check the corresponding events that should have been generated @@ -76,7 +74,7 @@ func TestDefaultAccountManager_SaveSetupKey(t *testing.T) { assert.NotNil(t, ev) assert.Equal(t, account.Id, ev.AccountID) - assert.Equal(t, newKeyName, ev.Meta["name"]) + assert.Equal(t, keyName, ev.Meta["name"]) assert.Equal(t, fmt.Sprint(key.Type), fmt.Sprint(ev.Meta["type"])) assert.NotEmpty(t, ev.Meta["key"]) assert.Equal(t, userID, ev.InitiatorID) @@ -89,7 +87,6 @@ func TestDefaultAccountManager_SaveSetupKey(t *testing.T) { autoGroups = append(autoGroups, groupAll.ID) _, err = manager.SaveSetupKey(context.Background(), account.Id, &SetupKey{ Id: key.Id, - Name: newKeyName, Revoked: revoked, AutoGroups: autoGroups, }, userID) @@ -449,3 +446,31 @@ func TestSetupKeyAccountPeersUpdate(t *testing.T) { } }) } + +func TestDefaultAccountManager_CreateSetupKey_ShouldNotAllowToUpdateRevokedKey(t *testing.T) { + manager, err := createManager(t) + if err != nil { + t.Fatal(err) + } + + userID := "testingUser" + account, err := manager.GetOrCreateAccountByUser(context.Background(), userID, "") + if err != nil { + t.Fatal(err) + } + + key, err := manager.CreateSetupKey(context.Background(), account.Id, "testName", SetupKeyReusable, time.Hour, nil, SetupKeyUnlimitedUsage, userID, false) + assert.NoError(t, err) + + // revoke the key + updateKey := key.Copy() + updateKey.Revoked = true + _, err = manager.SaveSetupKey(context.Background(), account.Id, updateKey, userID) + assert.NoError(t, err) + + // re-activate revoked key + updateKey.Revoked = false + _, err = manager.SaveSetupKey(context.Background(), account.Id, updateKey, userID) + assert.Error(t, err, "should not allow to update revoked key") + +} From 51c1ec283cb9d9dacc9ec18ab6d98b64d954d362 Mon Sep 17 00:00:00 2001 From: bcmmbaga Date: Fri, 15 Nov 2024 19:34:57 +0300 Subject: [PATCH 011/159] Add locks and remove log 
Signed-off-by: bcmmbaga --- management/server/posture_checks.go | 8 ++++++-- 1 file changed, 6 insertions(+), 2 deletions(-) diff --git a/management/server/posture_checks.go b/management/server/posture_checks.go index d7b5a79a2..59e726c41 100644 --- a/management/server/posture_checks.go +++ b/management/server/posture_checks.go @@ -9,7 +9,6 @@ import ( "github.com/netbirdio/netbird/management/server/posture" "github.com/netbirdio/netbird/management/server/status" "github.com/rs/xid" - log "github.com/sirupsen/logrus" "golang.org/x/exp/maps" ) @@ -32,6 +31,9 @@ func (am *DefaultAccountManager) GetPostureChecks(ctx context.Context, accountID // SavePostureChecks saves a posture check. func (am *DefaultAccountManager) SavePostureChecks(ctx context.Context, accountID, userID string, postureChecks *posture.Checks) (*posture.Checks, error) { + unlock := am.Store.AcquireWriteLockByUID(ctx, accountID) + defer unlock() + user, err := am.Store.GetUserByUserID(ctx, LockingStrengthShare, userID) if err != nil { return nil, err @@ -85,6 +87,9 @@ func (am *DefaultAccountManager) SavePostureChecks(ctx context.Context, accountI // DeletePostureChecks deletes a posture check by ID. func (am *DefaultAccountManager) DeletePostureChecks(ctx context.Context, accountID, postureChecksID, userID string) error { + unlock := am.Store.AcquireWriteLockByUID(ctx, accountID) + defer unlock() + user, err := am.Store.GetUserByUserID(ctx, LockingStrengthShare, userID) if err != nil { return err @@ -267,7 +272,6 @@ func isPeerInPolicySourceGroups(ctx context.Context, transaction Store, accountI for _, sourceGroup := range rule.Sources { group, err := transaction.GetGroupByID(ctx, LockingStrengthShare, accountID, sourceGroup) if err != nil { - log.WithContext(ctx).Debugf("failed to check peer in policy source group: %v", err) return false, fmt.Errorf("failed to check peer in policy source group: %w", err) } From 12f442439a64a4b76e8b9c7e93c0aa5de74d213c Mon Sep 17 00:00:00 2001 From: Bethuel Mmbaga Date: Fri, 15 Nov 2024 20:09:32 +0300 Subject: [PATCH 012/159] [management] Refactor group to use store methods (#2867) * Refactor setup key handling to use store methods Signed-off-by: bcmmbaga * add lock to get account groups Signed-off-by: bcmmbaga * add check for regular user Signed-off-by: bcmmbaga * get only required groups for auto-group validation Signed-off-by: bcmmbaga * add account lock and return auto groups map on validation Signed-off-by: bcmmbaga * refactor account peers update Signed-off-by: bcmmbaga * Refactor groups to use store methods Signed-off-by: bcmmbaga * refactor GetGroupByID and add NewGroupNotFoundError Signed-off-by: bcmmbaga * fix tests Signed-off-by: bcmmbaga * Add AddPeer and RemovePeer methods to Group struct Signed-off-by: bcmmbaga * Preserve store engine in SqlStore transactions Signed-off-by: bcmmbaga * Run groups ops in transaction Signed-off-by: bcmmbaga * fix missing group removed from setup key activity Signed-off-by: bcmmbaga * fix merge Signed-off-by: bcmmbaga * fix merge Signed-off-by: bcmmbaga * fix sonar Signed-off-by: bcmmbaga * Change setup key log level to debug for missing group Signed-off-by: bcmmbaga * Retrieve modified peers once for group events Signed-off-by: bcmmbaga * Add tests Signed-off-by: bcmmbaga * Add account locking and merge group deletion methods Signed-off-by: bcmmbaga * Fix tests Signed-off-by: bcmmbaga --------- Signed-off-by: bcmmbaga --- management/server/account.go | 29 +- management/server/account_test.go | 12 +- management/server/dns.go | 2 +- 
management/server/group.go | 514 ++++++++++-------- management/server/group/group.go | 27 + management/server/group/group_test.go | 90 +++ management/server/group_test.go | 2 +- management/server/integrated_validator.go | 27 +- management/server/mock_server/account_mock.go | 9 - management/server/nameserver.go | 6 +- management/server/peer.go | 41 +- management/server/peer_test.go | 2 +- management/server/policy.go | 4 +- management/server/posture_checks.go | 2 +- management/server/route.go | 6 +- management/server/route_test.go | 2 +- management/server/setupkey.go | 6 +- management/server/sql_store.go | 120 +++- management/server/sql_store_test.go | 285 +++++++++- management/server/status/error.go | 11 +- management/server/store.go | 8 +- management/server/user.go | 8 +- 22 files changed, 878 insertions(+), 335 deletions(-) create mode 100644 management/server/group/group_test.go diff --git a/management/server/account.go b/management/server/account.go index 4afadb4e9..1bd8a99a9 100644 --- a/management/server/account.go +++ b/management/server/account.go @@ -110,7 +110,6 @@ type AccountManager interface { SaveGroups(ctx context.Context, accountID, userID string, newGroups []*nbgroup.Group) error DeleteGroup(ctx context.Context, accountId, userId, groupID string) error DeleteGroups(ctx context.Context, accountId, userId string, groupIDs []string) error - ListGroups(ctx context.Context, accountId string) ([]*nbgroup.Group, error) GroupAddPeer(ctx context.Context, accountId, groupID, peerID string) error GroupDeletePeer(ctx context.Context, accountId, groupID, peerID string) error GetPolicy(ctx context.Context, accountID, policyID, userID string) (*Policy, error) @@ -1435,7 +1434,7 @@ func isNil(i idp.Manager) bool { // addAccountIDToIDPAppMeta update user's app metadata in idp manager func (am *DefaultAccountManager) addAccountIDToIDPAppMeta(ctx context.Context, userID string, accountID string) error { if !isNil(am.idpManager) { - accountUsers, err := am.Store.GetAccountUsers(ctx, accountID) + accountUsers, err := am.Store.GetAccountUsers(ctx, LockingStrengthShare, accountID) if err != nil { return err } @@ -2083,7 +2082,7 @@ func (am *DefaultAccountManager) syncJWTGroups(ctx context.Context, accountID st return fmt.Errorf("error saving groups: %w", err) } - if err = transaction.IncrementNetworkSerial(ctx, accountID); err != nil { + if err = transaction.IncrementNetworkSerial(ctx, LockingStrengthUpdate, accountID); err != nil { return fmt.Errorf("error incrementing network serial: %w", err) } } @@ -2101,7 +2100,7 @@ func (am *DefaultAccountManager) syncJWTGroups(ctx context.Context, accountID st } for _, g := range addNewGroups { - group, err := am.Store.GetGroupByID(ctx, LockingStrengthShare, g, accountID) + group, err := am.Store.GetGroupByID(ctx, LockingStrengthShare, accountID, g) if err != nil { log.WithContext(ctx).Debugf("group %s not found while saving user activity event of account %s", g, accountID) } else { @@ -2114,7 +2113,7 @@ func (am *DefaultAccountManager) syncJWTGroups(ctx context.Context, accountID st } for _, g := range removeOldGroups { - group, err := am.Store.GetGroupByID(ctx, LockingStrengthShare, g, accountID) + group, err := am.Store.GetGroupByID(ctx, LockingStrengthShare, accountID, g) if err != nil { log.WithContext(ctx).Debugf("group %s not found while saving user activity event of account %s", g, accountID) } else { @@ -2127,14 +2126,19 @@ func (am *DefaultAccountManager) syncJWTGroups(ctx context.Context, accountID st } if settings.GroupsPropagationEnabled { - 
account, err := am.requestBuffer.GetAccountWithBackpressure(ctx, accountID) + removedGroupAffectsPeers, err := areGroupChangesAffectPeers(ctx, am.Store, accountID, removeOldGroups) if err != nil { - return status.NewGetAccountError(err) + return err } - if areGroupChangesAffectPeers(account, addNewGroups) || areGroupChangesAffectPeers(account, removeOldGroups) { + newGroupsAffectsPeers, err := areGroupChangesAffectPeers(ctx, am.Store, accountID, addNewGroups) + if err != nil { + return err + } + + if removedGroupAffectsPeers || newGroupsAffectsPeers { log.WithContext(ctx).Tracef("user %s: JWT group membership changed, updating account peers", claims.UserId) - am.updateAccountPeers(ctx, account) + am.updateAccountPeers(ctx, accountID) } } @@ -2401,12 +2405,7 @@ func (am *DefaultAccountManager) CheckUserAccessByJWTGroups(ctx context.Context, func (am *DefaultAccountManager) onPeersInvalidated(ctx context.Context, accountID string) { log.WithContext(ctx).Debugf("validated peers has been invalidated for account %s", accountID) - updatedAccount, err := am.Store.GetAccount(ctx, accountID) - if err != nil { - log.WithContext(ctx).Errorf("failed to get account %s: %v", accountID, err) - return - } - am.updateAccountPeers(ctx, updatedAccount) + am.updateAccountPeers(ctx, accountID) } func (am *DefaultAccountManager) FindExistingPostureCheck(accountID string, checks *posture.ChecksDefinition) (*posture.Checks, error) { diff --git a/management/server/account_test.go b/management/server/account_test.go index fdf004a3b..97e0d45f0 100644 --- a/management/server/account_test.go +++ b/management/server/account_test.go @@ -1413,11 +1413,13 @@ func TestAccountManager_NetworkUpdates_DeleteGroup(t *testing.T) { updMsg := manager.peersUpdateManager.CreateChannel(context.Background(), peer1.ID) defer manager.peersUpdateManager.CloseChannel(context.Background(), peer1.ID) - group := group.Group{ + err := manager.SaveGroup(context.Background(), account.Id, userID, &group.Group{ ID: "groupA", Name: "GroupA", Peers: []string{peer1.ID, peer2.ID, peer3.ID}, - } + }) + + require.NoError(t, err, "failed to save group") policy := Policy{ Enabled: true, @@ -1460,7 +1462,7 @@ func TestAccountManager_NetworkUpdates_DeleteGroup(t *testing.T) { return } - if err := manager.DeleteGroup(context.Background(), account.Id, "", group.ID); err != nil { + if err := manager.DeleteGroup(context.Background(), account.Id, userID, "groupA"); err != nil { t.Errorf("delete group: %v", err) return } @@ -2714,7 +2716,7 @@ func TestAccount_SetJWTGroups(t *testing.T) { assert.NoError(t, err, "unable to get user") assert.Len(t, user.AutoGroups, 0) - group1, err := manager.Store.GetGroupByID(context.Background(), LockingStrengthShare, "group1", "accountID") + group1, err := manager.Store.GetGroupByID(context.Background(), LockingStrengthShare, "accountID", "group1") assert.NoError(t, err, "unable to get group") assert.Equal(t, group1.Issued, group.GroupIssuedAPI, "group should be api issued") }) @@ -2734,7 +2736,7 @@ func TestAccount_SetJWTGroups(t *testing.T) { assert.NoError(t, err, "unable to get user") assert.Len(t, user.AutoGroups, 1) - group1, err := manager.Store.GetGroupByID(context.Background(), LockingStrengthShare, "group1", "accountID") + group1, err := manager.Store.GetGroupByID(context.Background(), LockingStrengthShare, "accountID", "group1") assert.NoError(t, err, "unable to get group") assert.Equal(t, group1.Issued, group.GroupIssuedAPI, "group should be api issued") }) diff --git a/management/server/dns.go 
b/management/server/dns.go index 256b8b125..4551be5ab 100644 --- a/management/server/dns.go +++ b/management/server/dns.go @@ -146,7 +146,7 @@ func (am *DefaultAccountManager) SaveDNSSettings(ctx context.Context, accountID } if anyGroupHasPeers(account, addedGroups) || anyGroupHasPeers(account, removedGroups) { - am.updateAccountPeers(ctx, account) + am.updateAccountPeers(ctx, accountID) } return nil diff --git a/management/server/group.go b/management/server/group.go index 7b4f07948..a36213f04 100644 --- a/management/server/group.go +++ b/management/server/group.go @@ -33,8 +33,12 @@ func (am *DefaultAccountManager) CheckGroupPermissions(ctx context.Context, acco return err } - if !user.IsAdminOrServiceUser() || user.AccountID != accountID { - return status.Errorf(status.PermissionDenied, "groups are blocked for users") + if user.AccountID != accountID { + return status.NewUserNotPartOfAccountError() + } + + if user.IsRegularUser() { + return status.NewAdminPermissionError() } return nil @@ -45,8 +49,7 @@ func (am *DefaultAccountManager) GetGroup(ctx context.Context, accountID, groupI if err := am.CheckGroupPermissions(ctx, accountID, userID); err != nil { return nil, err } - - return am.Store.GetGroupByID(ctx, LockingStrengthShare, groupID, accountID) + return am.Store.GetGroupByID(ctx, LockingStrengthShare, accountID, groupID) } // GetAllGroups returns all groups in an account @@ -54,13 +57,12 @@ func (am *DefaultAccountManager) GetAllGroups(ctx context.Context, accountID, us if err := am.CheckGroupPermissions(ctx, accountID, userID); err != nil { return nil, err } - return am.Store.GetAccountGroups(ctx, LockingStrengthShare, accountID) } // GetGroupByName filters all groups in an account by name and returns the one with the most peers func (am *DefaultAccountManager) GetGroupByName(ctx context.Context, groupName, accountID string) (*nbgroup.Group, error) { - return am.Store.GetGroupByName(ctx, LockingStrengthShare, groupName, accountID) + return am.Store.GetGroupByName(ctx, LockingStrengthShare, accountID, groupName) } // SaveGroup object of the peers @@ -73,79 +75,74 @@ func (am *DefaultAccountManager) SaveGroup(ctx context.Context, accountID, userI // SaveGroups adds new groups to the account. // Note: This function does not acquire the global lock. // It is the caller's responsibility to ensure proper locking is in place before invoking this method. 
-func (am *DefaultAccountManager) SaveGroups(ctx context.Context, accountID, userID string, newGroups []*nbgroup.Group) error { - account, err := am.Store.GetAccount(ctx, accountID) +func (am *DefaultAccountManager) SaveGroups(ctx context.Context, accountID, userID string, groups []*nbgroup.Group) error { + user, err := am.Store.GetUserByUserID(ctx, LockingStrengthShare, userID) if err != nil { return err } + if user.AccountID != accountID { + return status.NewUserNotPartOfAccountError() + } + + if user.IsRegularUser() { + return status.NewAdminPermissionError() + } + var eventsToStore []func() + var groupsToSave []*nbgroup.Group + var updateAccountPeers bool - for _, newGroup := range newGroups { - if newGroup.ID == "" && newGroup.Issued != nbgroup.GroupIssuedAPI { - return status.Errorf(status.InvalidArgument, "%s group without ID set", newGroup.Issued) - } - - if newGroup.ID == "" && newGroup.Issued == nbgroup.GroupIssuedAPI { - existingGroup, err := account.FindGroupByName(newGroup.Name) - if err != nil { - s, ok := status.FromError(err) - if !ok || s.ErrorType != status.NotFound { - return err - } + err = am.Store.ExecuteInTransaction(ctx, func(transaction Store) error { + groupIDs := make([]string, 0, len(groups)) + for _, newGroup := range groups { + if err = validateNewGroup(ctx, transaction, accountID, newGroup); err != nil { + return err } - // Avoid duplicate groups only for the API issued groups. - // Integration or JWT groups can be duplicated as they are coming from the IdP that we don't have control of. - if existingGroup != nil { - return status.Errorf(status.AlreadyExists, "group with name %s already exists", newGroup.Name) - } + newGroup.AccountID = accountID + groupsToSave = append(groupsToSave, newGroup) + groupIDs = append(groupIDs, newGroup.ID) - newGroup.ID = xid.New().String() + events := am.prepareGroupEvents(ctx, transaction, accountID, userID, newGroup) + eventsToStore = append(eventsToStore, events...) } - for _, peerID := range newGroup.Peers { - if account.Peers[peerID] == nil { - return status.Errorf(status.InvalidArgument, "peer with ID \"%s\" not found", peerID) - } + updateAccountPeers, err = areGroupChangesAffectPeers(ctx, transaction, accountID, groupIDs) + if err != nil { + return err } - oldGroup := account.Groups[newGroup.ID] - account.Groups[newGroup.ID] = newGroup + if err = transaction.IncrementNetworkSerial(ctx, LockingStrengthUpdate, accountID); err != nil { + return err + } - events := am.prepareGroupEvents(ctx, userID, accountID, newGroup, oldGroup, account) - eventsToStore = append(eventsToStore, events...) - } - - newGroupIDs := make([]string, 0, len(newGroups)) - for _, newGroup := range newGroups { - newGroupIDs = append(newGroupIDs, newGroup.ID) - } - - account.Network.IncSerial() - if err = am.Store.SaveAccount(ctx, account); err != nil { + return transaction.SaveGroups(ctx, LockingStrengthUpdate, groupsToSave) + }) + if err != nil { return err } - if areGroupChangesAffectPeers(account, newGroupIDs) { - am.updateAccountPeers(ctx, account) - } - for _, storeEvent := range eventsToStore { storeEvent() } + if updateAccountPeers { + am.updateAccountPeers(ctx, accountID) + } + return nil } // prepareGroupEvents prepares a list of event functions to be stored. 
-func (am *DefaultAccountManager) prepareGroupEvents(ctx context.Context, userID string, accountID string, newGroup, oldGroup *nbgroup.Group, account *Account) []func() { +func (am *DefaultAccountManager) prepareGroupEvents(ctx context.Context, transaction Store, accountID, userID string, newGroup *nbgroup.Group) []func() { var eventsToStore []func() addedPeers := make([]string, 0) removedPeers := make([]string, 0) - if oldGroup != nil { + oldGroup, err := transaction.GetGroupByID(ctx, LockingStrengthShare, accountID, newGroup.ID) + if err == nil && oldGroup != nil { addedPeers = difference(newGroup.Peers, oldGroup.Peers) removedPeers = difference(oldGroup.Peers, newGroup.Peers) } else { @@ -155,35 +152,42 @@ func (am *DefaultAccountManager) prepareGroupEvents(ctx context.Context, userID }) } - for _, p := range addedPeers { - peer := account.Peers[p] - if peer == nil { - log.WithContext(ctx).Errorf("peer %s not found under account %s while saving group", p, accountID) + modifiedPeers := slices.Concat(addedPeers, removedPeers) + peers, err := transaction.GetPeersByIDs(ctx, LockingStrengthShare, accountID, modifiedPeers) + if err != nil { + log.WithContext(ctx).Debugf("failed to get peers for group events: %v", err) + return nil + } + + for _, peerID := range addedPeers { + peer, ok := peers[peerID] + if !ok { + log.WithContext(ctx).Debugf("skipped adding peer: %s GroupAddedToPeer activity: peer not found in store", peerID) continue } - peerCopy := peer // copy to avoid closure issues + eventsToStore = append(eventsToStore, func() { - am.StoreEvent(ctx, userID, peerCopy.ID, accountID, activity.GroupAddedToPeer, - map[string]any{ - "group": newGroup.Name, "group_id": newGroup.ID, "peer_ip": peerCopy.IP.String(), - "peer_fqdn": peerCopy.FQDN(am.GetDNSDomain()), - }) + meta := map[string]any{ + "group": newGroup.Name, "group_id": newGroup.ID, + "peer_ip": peer.IP.String(), "peer_fqdn": peer.FQDN(am.GetDNSDomain()), + } + am.StoreEvent(ctx, userID, peer.ID, accountID, activity.GroupAddedToPeer, meta) }) } - for _, p := range removedPeers { - peer := account.Peers[p] - if peer == nil { - log.WithContext(ctx).Errorf("peer %s not found under account %s while saving group", p, accountID) + for _, peerID := range removedPeers { + peer, ok := peers[peerID] + if !ok { + log.WithContext(ctx).Debugf("skipped adding peer: %s GroupRemovedFromPeer activity: peer not found in store", peerID) continue } - peerCopy := peer // copy to avoid closure issues + eventsToStore = append(eventsToStore, func() { - am.StoreEvent(ctx, userID, peerCopy.ID, accountID, activity.GroupRemovedFromPeer, - map[string]any{ - "group": newGroup.Name, "group_id": newGroup.ID, "peer_ip": peerCopy.IP.String(), - "peer_fqdn": peerCopy.FQDN(am.GetDNSDomain()), - }) + meta := map[string]any{ + "group": newGroup.Name, "group_id": newGroup.ID, + "peer_ip": peer.IP.String(), "peer_fqdn": peer.FQDN(am.GetDNSDomain()), + } + am.StoreEvent(ctx, userID, peer.ID, accountID, activity.GroupRemovedFromPeer, meta) }) } @@ -206,42 +210,10 @@ func difference(a, b []string) []string { } // DeleteGroup object of the peers. 
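One detail of the hunk above: the removed "peerCopy := peer" lines are no longer needed because the value each event closure captures (peer, ok := peers[peerID]) is declared inside the loop body, so every iteration gets its own variable. A small standalone illustration of that capture rule (example code, not from the patch):

    package main

    import "fmt"

    func main() {
        peers := map[string]string{"p1": "10.0.0.1", "p2": "10.0.0.2"}

        var events []func()
        for _, peerID := range []string{"p1", "p2", "p3"} {
            ip, ok := peers[peerID] // fresh variables every iteration: safe to capture
            if !ok {
                continue // unknown ID is skipped, mirroring the loops above
            }
            events = append(events, func() { fmt.Println(ip) })
        }

        for _, e := range events {
            e() // prints 10.0.0.1, then 10.0.0.2
        }
    }
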
-func (am *DefaultAccountManager) DeleteGroup(ctx context.Context, accountId, userId, groupID string) error { - unlock := am.Store.AcquireWriteLockByUID(ctx, accountId) +func (am *DefaultAccountManager) DeleteGroup(ctx context.Context, accountID, userID, groupID string) error { + unlock := am.Store.AcquireWriteLockByUID(ctx, accountID) defer unlock() - - account, err := am.Store.GetAccount(ctx, accountId) - if err != nil { - return err - } - - group, ok := account.Groups[groupID] - if !ok { - return nil - } - - allGroup, err := account.GetGroupAll() - if err != nil { - return err - } - - if allGroup.ID == groupID { - return status.Errorf(status.InvalidArgument, "deleting group ALL is not allowed") - } - - if err = validateDeleteGroup(account, group, userId); err != nil { - return err - } - delete(account.Groups, groupID) - - account.Network.IncSerial() - if err = am.Store.SaveAccount(ctx, account); err != nil { - return err - } - - am.StoreEvent(ctx, userId, groupID, accountId, activity.GroupDeleted, group.EventMeta()) - - return nil + return am.DeleteGroups(ctx, accountID, userID, []string{groupID}) } // DeleteGroups deletes groups from an account. @@ -250,93 +222,94 @@ func (am *DefaultAccountManager) DeleteGroup(ctx context.Context, accountId, use // // If an error occurs while deleting a group, the function skips it and continues deleting other groups. // Errors are collected and returned at the end. -func (am *DefaultAccountManager) DeleteGroups(ctx context.Context, accountId, userId string, groupIDs []string) error { - account, err := am.Store.GetAccount(ctx, accountId) +func (am *DefaultAccountManager) DeleteGroups(ctx context.Context, accountID, userID string, groupIDs []string) error { + user, err := am.Store.GetUserByUserID(ctx, LockingStrengthShare, userID) if err != nil { return err } + if user.AccountID != accountID { + return status.NewUserNotPartOfAccountError() + } + + if user.IsRegularUser() { + return status.NewAdminPermissionError() + } + var allErrors error + var groupIDsToDelete []string + var deletedGroups []*nbgroup.Group - deletedGroups := make([]*nbgroup.Group, 0, len(groupIDs)) - for _, groupID := range groupIDs { - group, ok := account.Groups[groupID] - if !ok { - continue + err = am.Store.ExecuteInTransaction(ctx, func(transaction Store) error { + for _, groupID := range groupIDs { + group, err := transaction.GetGroupByID(ctx, LockingStrengthUpdate, accountID, groupID) + if err != nil { + allErrors = errors.Join(allErrors, err) + continue + } + + if err := validateDeleteGroup(ctx, transaction, group, userID); err != nil { + allErrors = errors.Join(allErrors, err) + continue + } + + groupIDsToDelete = append(groupIDsToDelete, groupID) + deletedGroups = append(deletedGroups, group) } - if err := validateDeleteGroup(account, group, userId); err != nil { - allErrors = errors.Join(allErrors, fmt.Errorf("failed to delete group %s: %w", groupID, err)) - continue + if err = transaction.IncrementNetworkSerial(ctx, LockingStrengthUpdate, accountID); err != nil { + return err } - delete(account.Groups, groupID) - deletedGroups = append(deletedGroups, group) - } - - account.Network.IncSerial() - if err = am.Store.SaveAccount(ctx, account); err != nil { + return transaction.DeleteGroups(ctx, LockingStrengthUpdate, accountID, groupIDsToDelete) + }) + if err != nil { return err } - for _, g := range deletedGroups { - am.StoreEvent(ctx, userId, g.ID, accountId, activity.GroupDeleted, g.EventMeta()) + for _, group := range deletedGroups { + am.StoreEvent(ctx, userID, group.ID, 
accountID, activity.GroupDeleted, group.EventMeta()) } return allErrors } -// ListGroups objects of the peers -func (am *DefaultAccountManager) ListGroups(ctx context.Context, accountID string) ([]*nbgroup.Group, error) { - unlock := am.Store.AcquireWriteLockByUID(ctx, accountID) - defer unlock() - - account, err := am.Store.GetAccount(ctx, accountID) - if err != nil { - return nil, err - } - - groups := make([]*nbgroup.Group, 0, len(account.Groups)) - for _, item := range account.Groups { - groups = append(groups, item) - } - - return groups, nil -} - // GroupAddPeer appends peer to the group func (am *DefaultAccountManager) GroupAddPeer(ctx context.Context, accountID, groupID, peerID string) error { unlock := am.Store.AcquireWriteLockByUID(ctx, accountID) defer unlock() - account, err := am.Store.GetAccount(ctx, accountID) + var group *nbgroup.Group + var updateAccountPeers bool + var err error + + err = am.Store.ExecuteInTransaction(ctx, func(transaction Store) error { + group, err = transaction.GetGroupByID(context.Background(), LockingStrengthUpdate, accountID, groupID) + if err != nil { + return err + } + + if updated := group.AddPeer(peerID); !updated { + return nil + } + + updateAccountPeers, err = areGroupChangesAffectPeers(ctx, transaction, accountID, []string{groupID}) + if err != nil { + return err + } + + if err = transaction.IncrementNetworkSerial(ctx, LockingStrengthUpdate, accountID); err != nil { + return err + } + + return transaction.SaveGroup(ctx, LockingStrengthUpdate, group) + }) if err != nil { return err } - group, ok := account.Groups[groupID] - if !ok { - return status.Errorf(status.NotFound, "group with ID %s not found", groupID) - } - - add := true - for _, itemID := range group.Peers { - if itemID == peerID { - add = false - break - } - } - if add { - group.Peers = append(group.Peers, peerID) - } - - account.Network.IncSerial() - if err = am.Store.SaveAccount(ctx, account); err != nil { - return err - } - - if areGroupChangesAffectPeers(account, []string{group.ID}) { - am.updateAccountPeers(ctx, account) + if updateAccountPeers { + am.updateAccountPeers(ctx, accountID) } return nil @@ -347,90 +320,162 @@ func (am *DefaultAccountManager) GroupDeletePeer(ctx context.Context, accountID, unlock := am.Store.AcquireWriteLockByUID(ctx, accountID) defer unlock() - account, err := am.Store.GetAccount(ctx, accountID) + var group *nbgroup.Group + var updateAccountPeers bool + var err error + + err = am.Store.ExecuteInTransaction(ctx, func(transaction Store) error { + group, err = transaction.GetGroupByID(context.Background(), LockingStrengthUpdate, accountID, groupID) + if err != nil { + return err + } + + if updated := group.RemovePeer(peerID); !updated { + return nil + } + + updateAccountPeers, err = areGroupChangesAffectPeers(ctx, transaction, accountID, []string{groupID}) + if err != nil { + return err + } + + if err = transaction.IncrementNetworkSerial(ctx, LockingStrengthUpdate, accountID); err != nil { + return err + } + + return transaction.SaveGroup(ctx, LockingStrengthUpdate, group) + }) if err != nil { return err } - group, ok := account.Groups[groupID] - if !ok { - return status.Errorf(status.NotFound, "group with ID %s not found", groupID) - } - - account.Network.IncSerial() - for i, itemID := range group.Peers { - if itemID == peerID { - group.Peers = append(group.Peers[:i], group.Peers[i+1:]...) 
- if err := am.Store.SaveAccount(ctx, account); err != nil { - return err - } - } - } - - if areGroupChangesAffectPeers(account, []string{group.ID}) { - am.updateAccountPeers(ctx, account) + if updateAccountPeers { + am.updateAccountPeers(ctx, accountID) } return nil } -func validateDeleteGroup(account *Account, group *nbgroup.Group, userID string) error { +// validateNewGroup validates the new group for existence and required fields. +func validateNewGroup(ctx context.Context, transaction Store, accountID string, newGroup *nbgroup.Group) error { + if newGroup.ID == "" && newGroup.Issued != nbgroup.GroupIssuedAPI { + return status.Errorf(status.InvalidArgument, "%s group without ID set", newGroup.Issued) + } + + if newGroup.ID == "" && newGroup.Issued == nbgroup.GroupIssuedAPI { + existingGroup, err := transaction.GetGroupByName(ctx, LockingStrengthShare, accountID, newGroup.Name) + if err != nil { + if s, ok := status.FromError(err); !ok || s.Type() != status.NotFound { + return err + } + } + + // Prevent duplicate groups for API-issued groups. + // Integration or JWT groups can be duplicated as they are coming from the IdP that we don't have control of. + if existingGroup != nil { + return status.Errorf(status.AlreadyExists, "group with name %s already exists", newGroup.Name) + } + + newGroup.ID = xid.New().String() + } + + for _, peerID := range newGroup.Peers { + _, err := transaction.GetPeerByID(ctx, LockingStrengthShare, accountID, peerID) + if err != nil { + return status.Errorf(status.InvalidArgument, "peer with ID \"%s\" not found", peerID) + } + } + + return nil +} + +func validateDeleteGroup(ctx context.Context, transaction Store, group *nbgroup.Group, userID string) error { // disable a deleting integration group if the initiator is not an admin service user if group.Issued == nbgroup.GroupIssuedIntegration { - executingUser := account.Users[userID] - if executingUser == nil { - return status.Errorf(status.NotFound, "user not found") + executingUser, err := transaction.GetUserByUserID(ctx, LockingStrengthShare, userID) + if err != nil { + return err } if executingUser.Role != UserRoleAdmin || !executingUser.IsServiceUser { return status.Errorf(status.PermissionDenied, "only service users with admin power can delete integration group") } } - if isLinked, linkedRoute := isGroupLinkedToRoute(account.Routes, group.ID); isLinked { + if group.IsGroupAll() { + return status.Errorf(status.InvalidArgument, "deleting group ALL is not allowed") + } + + if isLinked, linkedRoute := isGroupLinkedToRoute(ctx, transaction, group.AccountID, group.ID); isLinked { return &GroupLinkError{"route", string(linkedRoute.NetID)} } - if isLinked, linkedDns := isGroupLinkedToDns(account.NameServerGroups, group.ID); isLinked { + if isLinked, linkedDns := isGroupLinkedToDns(ctx, transaction, group.AccountID, group.ID); isLinked { return &GroupLinkError{"name server groups", linkedDns.Name} } - if isLinked, linkedPolicy := isGroupLinkedToPolicy(account.Policies, group.ID); isLinked { + if isLinked, linkedPolicy := isGroupLinkedToPolicy(ctx, transaction, group.AccountID, group.ID); isLinked { return &GroupLinkError{"policy", linkedPolicy.Name} } - if isLinked, linkedSetupKey := isGroupLinkedToSetupKey(account.SetupKeys, group.ID); isLinked { + if isLinked, linkedSetupKey := isGroupLinkedToSetupKey(ctx, transaction, group.AccountID, group.ID); isLinked { return &GroupLinkError{"setup key", linkedSetupKey.Name} } - if isLinked, linkedUser := isGroupLinkedToUser(account.Users, group.ID); isLinked { + if 
isLinked, linkedUser := isGroupLinkedToUser(ctx, transaction, group.AccountID, group.ID); isLinked { return &GroupLinkError{"user", linkedUser.Id} } - if slices.Contains(account.DNSSettings.DisabledManagementGroups, group.ID) { + return checkGroupLinkedToSettings(ctx, transaction, group) +} + +// checkGroupLinkedToSettings verifies if a group is linked to any settings in the account. +func checkGroupLinkedToSettings(ctx context.Context, transaction Store, group *nbgroup.Group) error { + dnsSettings, err := transaction.GetAccountDNSSettings(ctx, LockingStrengthShare, group.AccountID) + if err != nil { + return err + } + + if slices.Contains(dnsSettings.DisabledManagementGroups, group.ID) { return &GroupLinkError{"disabled DNS management groups", group.Name} } - if account.Settings.Extra != nil { - if slices.Contains(account.Settings.Extra.IntegratedValidatorGroups, group.ID) { - return &GroupLinkError{"integrated validator", group.Name} - } + settings, err := transaction.GetAccountSettings(ctx, LockingStrengthShare, group.AccountID) + if err != nil { + return err + } + + if settings.Extra != nil && slices.Contains(settings.Extra.IntegratedValidatorGroups, group.ID) { + return &GroupLinkError{"integrated validator", group.Name} } return nil } // isGroupLinkedToRoute checks if a group is linked to any route in the account. -func isGroupLinkedToRoute(routes map[route.ID]*route.Route, groupID string) (bool, *route.Route) { +func isGroupLinkedToRoute(ctx context.Context, transaction Store, accountID string, groupID string) (bool, *route.Route) { + routes, err := transaction.GetAccountRoutes(ctx, LockingStrengthShare, accountID) + if err != nil { + log.WithContext(ctx).Errorf("error retrieving routes while checking group linkage: %v", err) + return false, nil + } + for _, r := range routes { if slices.Contains(r.Groups, groupID) || slices.Contains(r.PeerGroups, groupID) { return true, r } } + return false, nil } // isGroupLinkedToPolicy checks if a group is linked to any policy in the account. -func isGroupLinkedToPolicy(policies []*Policy, groupID string) (bool, *Policy) { +func isGroupLinkedToPolicy(ctx context.Context, transaction Store, accountID string, groupID string) (bool, *Policy) { + policies, err := transaction.GetAccountPolicies(ctx, LockingStrengthShare, accountID) + if err != nil { + log.WithContext(ctx).Errorf("error retrieving policies while checking group linkage: %v", err) + return false, nil + } + for _, policy := range policies { for _, rule := range policy.Rules { if slices.Contains(rule.Sources, groupID) || slices.Contains(rule.Destinations, groupID) { @@ -442,7 +487,13 @@ func isGroupLinkedToPolicy(policies []*Policy, groupID string) (bool, *Policy) { } // isGroupLinkedToDns checks if a group is linked to any nameserver group in the account. 
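validateDeleteGroup surfaces an in-use group as a GroupLinkError, and DeleteGroups above joins those errors and carries on with the remaining IDs. A hypothetical caller that wants to distinguish "still referenced" from a store failure could unwrap the joined error like this (illustrative sketch only, not part of the patch):

    if err := am.DeleteGroups(ctx, accountID, userID, groupIDs); err != nil {
        var linkErr *GroupLinkError
        if errors.As(err, &linkErr) {
            // at least one group is still referenced by a route, policy,
            // nameserver group, setup key, user or DNS setting
        }
        return err
    }
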
-func isGroupLinkedToDns(nameServerGroups map[string]*nbdns.NameServerGroup, groupID string) (bool, *nbdns.NameServerGroup) { +func isGroupLinkedToDns(ctx context.Context, transaction Store, accountID string, groupID string) (bool, *nbdns.NameServerGroup) { + nameServerGroups, err := transaction.GetAccountNameServerGroups(ctx, LockingStrengthShare, accountID) + if err != nil { + log.WithContext(ctx).Errorf("error retrieving name server groups while checking group linkage: %v", err) + return false, nil + } + for _, dns := range nameServerGroups { for _, g := range dns.Groups { if g == groupID { @@ -450,11 +501,18 @@ func isGroupLinkedToDns(nameServerGroups map[string]*nbdns.NameServerGroup, grou } } } + return false, nil } // isGroupLinkedToSetupKey checks if a group is linked to any setup key in the account. -func isGroupLinkedToSetupKey(setupKeys map[string]*SetupKey, groupID string) (bool, *SetupKey) { +func isGroupLinkedToSetupKey(ctx context.Context, transaction Store, accountID string, groupID string) (bool, *SetupKey) { + setupKeys, err := transaction.GetAccountSetupKeys(ctx, LockingStrengthShare, accountID) + if err != nil { + log.WithContext(ctx).Errorf("error retrieving setup keys while checking group linkage: %v", err) + return false, nil + } + for _, setupKey := range setupKeys { if slices.Contains(setupKey.AutoGroups, groupID) { return true, setupKey @@ -464,7 +522,13 @@ func isGroupLinkedToSetupKey(setupKeys map[string]*SetupKey, groupID string) (bo } // isGroupLinkedToUser checks if a group is linked to any user in the account. -func isGroupLinkedToUser(users map[string]*User, groupID string) (bool, *User) { +func isGroupLinkedToUser(ctx context.Context, transaction Store, accountID string, groupID string) (bool, *User) { + users, err := transaction.GetAccountUsers(ctx, LockingStrengthShare, accountID) + if err != nil { + log.WithContext(ctx).Errorf("error retrieving users while checking group linkage: %v", err) + return false, nil + } + for _, user := range users { if slices.Contains(user.AutoGroups, groupID) { return true, user @@ -473,6 +537,35 @@ func isGroupLinkedToUser(users map[string]*User, groupID string) (bool, *User) { return false, nil } +// areGroupChangesAffectPeers checks if any changes to the specified groups will affect peers. +func areGroupChangesAffectPeers(ctx context.Context, transaction Store, accountID string, groupIDs []string) (bool, error) { + if len(groupIDs) == 0 { + return false, nil + } + + dnsSettings, err := transaction.GetAccountDNSSettings(ctx, LockingStrengthShare, accountID) + if err != nil { + return false, err + } + + for _, groupID := range groupIDs { + if slices.Contains(dnsSettings.DisabledManagementGroups, groupID) { + return true, nil + } + if linked, _ := isGroupLinkedToDns(ctx, transaction, accountID, groupID); linked { + return true, nil + } + if linked, _ := isGroupLinkedToPolicy(ctx, transaction, accountID, groupID); linked { + return true, nil + } + if linked, _ := isGroupLinkedToRoute(ctx, transaction, accountID, groupID); linked { + return true, nil + } + } + + return false, nil +} + // anyGroupHasPeers checks if any of the given groups in the account have peers. 
func anyGroupHasPeers(account *Account, groupIDs []string) bool { for _, groupID := range groupIDs { @@ -482,22 +575,3 @@ func anyGroupHasPeers(account *Account, groupIDs []string) bool { } return false } - -func areGroupChangesAffectPeers(account *Account, groupIDs []string) bool { - for _, groupID := range groupIDs { - if slices.Contains(account.DNSSettings.DisabledManagementGroups, groupID) { - return true - } - if linked, _ := isGroupLinkedToDns(account.NameServerGroups, groupID); linked { - return true - } - if linked, _ := isGroupLinkedToPolicy(account.Policies, groupID); linked { - return true - } - if linked, _ := isGroupLinkedToRoute(account.Routes, groupID); linked { - return true - } - } - - return false -} diff --git a/management/server/group/group.go b/management/server/group/group.go index e98e5ecc4..24c60d3ce 100644 --- a/management/server/group/group.go +++ b/management/server/group/group.go @@ -54,3 +54,30 @@ func (g *Group) HasPeers() bool { func (g *Group) IsGroupAll() bool { return g.Name == "All" } + +// AddPeer adds peerID to Peers if not present, returning true if added. +func (g *Group) AddPeer(peerID string) bool { + if peerID == "" { + return false + } + + for _, itemID := range g.Peers { + if itemID == peerID { + return false + } + } + + g.Peers = append(g.Peers, peerID) + return true +} + +// RemovePeer removes peerID from Peers if present, returning true if removed. +func (g *Group) RemovePeer(peerID string) bool { + for i, itemID := range g.Peers { + if itemID == peerID { + g.Peers = append(g.Peers[:i], g.Peers[i+1:]...) + return true + } + } + return false +} diff --git a/management/server/group/group_test.go b/management/server/group/group_test.go new file mode 100644 index 000000000..cb002f8d9 --- /dev/null +++ b/management/server/group/group_test.go @@ -0,0 +1,90 @@ +package group + +import ( + "testing" + + "github.com/stretchr/testify/assert" +) + +func TestAddPeer(t *testing.T) { + t.Run("add new peer to empty slice", func(t *testing.T) { + group := &Group{Peers: []string{}} + peerID := "peer1" + assert.True(t, group.AddPeer(peerID)) + assert.Contains(t, group.Peers, peerID) + }) + + t.Run("add new peer to nil slice", func(t *testing.T) { + group := &Group{Peers: nil} + peerID := "peer1" + assert.True(t, group.AddPeer(peerID)) + assert.Contains(t, group.Peers, peerID) + }) + + t.Run("add new peer to non-empty slice", func(t *testing.T) { + group := &Group{Peers: []string{"peer1", "peer2"}} + peerID := "peer3" + assert.True(t, group.AddPeer(peerID)) + assert.Contains(t, group.Peers, peerID) + }) + + t.Run("add duplicate peer", func(t *testing.T) { + group := &Group{Peers: []string{"peer1", "peer2"}} + peerID := "peer1" + assert.False(t, group.AddPeer(peerID)) + assert.Equal(t, 2, len(group.Peers)) + }) + + t.Run("add empty peer", func(t *testing.T) { + group := &Group{Peers: []string{"peer1", "peer2"}} + peerID := "" + assert.False(t, group.AddPeer(peerID)) + assert.Equal(t, 2, len(group.Peers)) + }) +} + +func TestRemovePeer(t *testing.T) { + t.Run("remove existing peer from slice", func(t *testing.T) { + group := &Group{Peers: []string{"peer1", "peer2", "peer3"}} + peerID := "peer2" + assert.True(t, group.RemovePeer(peerID)) + assert.NotContains(t, group.Peers, peerID) + assert.Equal(t, 2, len(group.Peers)) + }) + + t.Run("remove peer from empty slice", func(t *testing.T) { + group := &Group{Peers: []string{}} + peerID := "peer1" + assert.False(t, group.RemovePeer(peerID)) + assert.Equal(t, 0, len(group.Peers)) + }) + + t.Run("remove peer from nil 
slice", func(t *testing.T) { + group := &Group{Peers: nil} + peerID := "peer1" + assert.False(t, group.RemovePeer(peerID)) + assert.Nil(t, group.Peers) + }) + + t.Run("remove non-existent peer", func(t *testing.T) { + group := &Group{Peers: []string{"peer1", "peer2"}} + peerID := "peer3" + assert.False(t, group.RemovePeer(peerID)) + assert.Equal(t, 2, len(group.Peers)) + }) + + t.Run("remove peer from single-item slice", func(t *testing.T) { + group := &Group{Peers: []string{"peer1"}} + peerID := "peer1" + assert.True(t, group.RemovePeer(peerID)) + assert.Equal(t, 0, len(group.Peers)) + assert.NotContains(t, group.Peers, peerID) + }) + + t.Run("remove empty peer", func(t *testing.T) { + group := &Group{Peers: []string{"peer1", "peer2"}} + peerID := "" + assert.False(t, group.RemovePeer(peerID)) + assert.Equal(t, 2, len(group.Peers)) + }) +} diff --git a/management/server/group_test.go b/management/server/group_test.go index 89184e819..59094a23e 100644 --- a/management/server/group_test.go +++ b/management/server/group_test.go @@ -208,7 +208,7 @@ func TestDefaultAccountManager_DeleteGroups(t *testing.T) { { name: "delete non-existent group", groupIDs: []string{"non-existent-group"}, - expectedDeleted: []string{"non-existent-group"}, + expectedReasons: []string{"group: non-existent-group not found"}, }, { name: "delete multiple groups with mixed results", diff --git a/management/server/integrated_validator.go b/management/server/integrated_validator.go index 99e6b204c..0c70b702a 100644 --- a/management/server/integrated_validator.go +++ b/management/server/integrated_validator.go @@ -52,25 +52,22 @@ func (am *DefaultAccountManager) UpdateIntegratedValidatorGroups(ctx context.Con return am.Store.SaveAccount(ctx, a) } -func (am *DefaultAccountManager) GroupValidation(ctx context.Context, accountId string, groups []string) (bool, error) { - if len(groups) == 0 { +func (am *DefaultAccountManager) GroupValidation(ctx context.Context, accountID string, groupIDs []string) (bool, error) { + if len(groupIDs) == 0 { return true, nil } - accountsGroups, err := am.ListGroups(ctx, accountId) - if err != nil { - return false, err - } - for _, group := range groups { - var found bool - for _, accountGroup := range accountsGroups { - if accountGroup.ID == group { - found = true - break + + err := am.Store.ExecuteInTransaction(ctx, func(transaction Store) error { + for _, groupID := range groupIDs { + _, err := transaction.GetGroupByID(context.Background(), LockingStrengthShare, accountID, groupID) + if err != nil { + return err } } - if !found { - return false, nil - } + return nil + }) + if err != nil { + return false, err } return true, nil diff --git a/management/server/mock_server/account_mock.go b/management/server/mock_server/account_mock.go index d7139bb2a..aa6a47b15 100644 --- a/management/server/mock_server/account_mock.go +++ b/management/server/mock_server/account_mock.go @@ -45,7 +45,6 @@ type MockAccountManager struct { SaveGroupsFunc func(ctx context.Context, accountID, userID string, groups []*group.Group) error DeleteGroupFunc func(ctx context.Context, accountID, userId, groupID string) error DeleteGroupsFunc func(ctx context.Context, accountId, userId string, groupIDs []string) error - ListGroupsFunc func(ctx context.Context, accountID string) ([]*group.Group, error) GroupAddPeerFunc func(ctx context.Context, accountID, groupID, peerID string) error GroupDeletePeerFunc func(ctx context.Context, accountID, groupID, peerID string) error DeleteRuleFunc func(ctx context.Context, accountID, 
ruleID, userID string) error @@ -354,14 +353,6 @@ func (am *MockAccountManager) DeleteGroups(ctx context.Context, accountId, userI return status.Errorf(codes.Unimplemented, "method DeleteGroups is not implemented") } -// ListGroups mock implementation of ListGroups from server.AccountManager interface -func (am *MockAccountManager) ListGroups(ctx context.Context, accountID string) ([]*group.Group, error) { - if am.ListGroupsFunc != nil { - return am.ListGroupsFunc(ctx, accountID) - } - return nil, status.Errorf(codes.Unimplemented, "method ListGroups is not implemented") -} - // GroupAddPeer mock implementation of GroupAddPeer from server.AccountManager interface func (am *MockAccountManager) GroupAddPeer(ctx context.Context, accountID, groupID, peerID string) error { if am.GroupAddPeerFunc != nil { diff --git a/management/server/nameserver.go b/management/server/nameserver.go index 5ebd263dc..957008714 100644 --- a/management/server/nameserver.go +++ b/management/server/nameserver.go @@ -71,7 +71,7 @@ func (am *DefaultAccountManager) CreateNameServerGroup(ctx context.Context, acco } if anyGroupHasPeers(account, newNSGroup.Groups) { - am.updateAccountPeers(ctx, account) + am.updateAccountPeers(ctx, accountID) } am.StoreEvent(ctx, userID, newNSGroup.ID, accountID, activity.NameserverGroupCreated, newNSGroup.EventMeta()) @@ -106,7 +106,7 @@ func (am *DefaultAccountManager) SaveNameServerGroup(ctx context.Context, accoun } if areNameServerGroupChangesAffectPeers(account, nsGroupToSave, oldNSGroup) { - am.updateAccountPeers(ctx, account) + am.updateAccountPeers(ctx, accountID) } am.StoreEvent(ctx, userID, nsGroupToSave.ID, accountID, activity.NameserverGroupUpdated, nsGroupToSave.EventMeta()) @@ -136,7 +136,7 @@ func (am *DefaultAccountManager) DeleteNameServerGroup(ctx context.Context, acco } if anyGroupHasPeers(account, nsGroup.Groups) { - am.updateAccountPeers(ctx, account) + am.updateAccountPeers(ctx, accountID) } am.StoreEvent(ctx, userID, nsGroup.ID, accountID, activity.NameserverGroupDeleted, nsGroup.EventMeta()) diff --git a/management/server/peer.go b/management/server/peer.go index 87b5b0e4e..1405dead8 100644 --- a/management/server/peer.go +++ b/management/server/peer.go @@ -133,7 +133,7 @@ func (am *DefaultAccountManager) MarkPeerConnected(ctx context.Context, peerPubK if expired { // we need to update other peers because when peer login expires all other peers are notified to disconnect from // the expired one. Here we notify them that connection is now allowed again. 
- am.updateAccountPeers(ctx, account) + am.updateAccountPeers(ctx, account.Id) } return nil @@ -271,7 +271,7 @@ func (am *DefaultAccountManager) UpdatePeer(ctx context.Context, accountID, user } if peerLabelUpdated || requiresPeerUpdates { - am.updateAccountPeers(ctx, account) + am.updateAccountPeers(ctx, accountID) } return peer, nil @@ -335,7 +335,10 @@ func (am *DefaultAccountManager) DeletePeer(ctx context.Context, accountID, peer return err } - updateAccountPeers := isPeerInActiveGroup(account, peerID) + updateAccountPeers, err := am.isPeerInActiveGroup(ctx, account, peerID) + if err != nil { + return err + } err = am.deletePeers(ctx, account, []string{peerID}, userID) if err != nil { @@ -348,7 +351,7 @@ func (am *DefaultAccountManager) DeletePeer(ctx context.Context, accountID, peer } if updateAccountPeers { - am.updateAccountPeers(ctx, account) + am.updateAccountPeers(ctx, accountID) } return nil @@ -555,7 +558,7 @@ func (am *DefaultAccountManager) AddPeer(ctx context.Context, setupKey, userID s return fmt.Errorf("failed to add peer to account: %w", err) } - err = transaction.IncrementNetworkSerial(ctx, accountID) + err = transaction.IncrementNetworkSerial(ctx, LockingStrengthUpdate, accountID) if err != nil { return fmt.Errorf("failed to increment network serial: %w", err) } @@ -598,10 +601,15 @@ func (am *DefaultAccountManager) AddPeer(ctx context.Context, setupKey, userID s if err != nil { return nil, nil, nil, fmt.Errorf("error getting all group ID: %w", err) } - groupsToAdd = append(groupsToAdd, allGroup.ID) - if areGroupChangesAffectPeers(account, groupsToAdd) { - am.updateAccountPeers(ctx, account) + + newGroupsAffectsPeers, err := areGroupChangesAffectPeers(ctx, am.Store, accountID, groupsToAdd) + if err != nil { + return nil, nil, nil, err + } + + if newGroupsAffectsPeers { + am.updateAccountPeers(ctx, accountID) } approvedPeersMap, err := am.GetValidatedPeers(account) @@ -666,7 +674,7 @@ func (am *DefaultAccountManager) SyncPeer(ctx context.Context, sync PeerSync, ac } if sync.UpdateAccountPeers { - am.updateAccountPeers(ctx, account) + am.updateAccountPeers(ctx, account.Id) } } @@ -685,7 +693,7 @@ func (am *DefaultAccountManager) SyncPeer(ctx context.Context, sync PeerSync, ac } if isStatusChanged { - am.updateAccountPeers(ctx, account) + am.updateAccountPeers(ctx, account.Id) } validPeersMap, err := am.GetValidatedPeers(account) @@ -816,7 +824,7 @@ func (am *DefaultAccountManager) LoginPeer(ctx context.Context, login PeerLogin) } if updateRemotePeers || isStatusChanged { - am.updateAccountPeers(ctx, account) + am.updateAccountPeers(ctx, accountID) } return am.getValidatedPeerWithMap(ctx, isRequiresApproval, account, peer) @@ -979,7 +987,7 @@ func (am *DefaultAccountManager) GetPeer(ctx context.Context, accountID, peerID, // updateAccountPeers updates all peers that belong to an account. // Should be called when changes have to be synced to peers. 
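Most of the call-site churn in this commit comes from the signature change in the hunk that follows: updateAccountPeers now takes only the account ID and loads the account itself through the request buffer, instead of requiring every caller to hold a loaded *Account. Condensed shape of the updated function (abbreviated from the hunk below):

    func (am *DefaultAccountManager) updateAccountPeers(ctx context.Context, accountID string) {
        account, err := am.requestBuffer.GetAccountWithBackpressure(ctx, accountID)
        if err != nil {
            log.WithContext(ctx).Errorf("failed to send out updates to peers: %v", err)
            return
        }
        // ... compute network maps for account.GetPeers() and push the updates
    }
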
-func (am *DefaultAccountManager) updateAccountPeers(ctx context.Context, account *Account) { +func (am *DefaultAccountManager) updateAccountPeers(ctx context.Context, accountID string) { start := time.Now() defer func() { if am.metrics != nil { @@ -987,6 +995,11 @@ func (am *DefaultAccountManager) updateAccountPeers(ctx context.Context, account } }() + account, err := am.requestBuffer.GetAccountWithBackpressure(ctx, accountID) + if err != nil { + log.WithContext(ctx).Errorf("failed to send out updates to peers: %v", err) + return + } peers := account.GetPeers() approvedPeersMap, err := am.GetValidatedPeers(account) @@ -1033,12 +1046,12 @@ func ConvertSliceToMap(existingLabels []string) map[string]struct{} { // IsPeerInActiveGroup checks if the given peer is part of a group that is used // in an active DNS, route, or ACL configuration. -func isPeerInActiveGroup(account *Account, peerID string) bool { +func (am *DefaultAccountManager) isPeerInActiveGroup(ctx context.Context, account *Account, peerID string) (bool, error) { peerGroupIDs := make([]string, 0) for _, group := range account.Groups { if slices.Contains(group.Peers, peerID) { peerGroupIDs = append(peerGroupIDs, group.ID) } } - return areGroupChangesAffectPeers(account, peerGroupIDs) + return areGroupChangesAffectPeers(ctx, am.Store, account.Id, peerGroupIDs) } diff --git a/management/server/peer_test.go b/management/server/peer_test.go index 78885ea1b..4e2dcb2c3 100644 --- a/management/server/peer_test.go +++ b/management/server/peer_test.go @@ -877,7 +877,7 @@ func BenchmarkUpdateAccountPeers(b *testing.B) { start := time.Now() for i := 0; i < b.N; i++ { - manager.updateAccountPeers(ctx, account) + manager.updateAccountPeers(ctx, account.Id) } duration := time.Since(start) diff --git a/management/server/policy.go b/management/server/policy.go index 43a925f88..8a5733f01 100644 --- a/management/server/policy.go +++ b/management/server/policy.go @@ -377,7 +377,7 @@ func (am *DefaultAccountManager) SavePolicy(ctx context.Context, accountID, user am.StoreEvent(ctx, userID, policy.ID, accountID, action, policy.EventMeta()) if updateAccountPeers { - am.updateAccountPeers(ctx, account) + am.updateAccountPeers(ctx, accountID) } return nil @@ -406,7 +406,7 @@ func (am *DefaultAccountManager) DeletePolicy(ctx context.Context, accountID, po am.StoreEvent(ctx, userID, policy.ID, accountID, activity.PolicyRemoved, policy.EventMeta()) if anyGroupHasPeers(account, policy.ruleGroups()) { - am.updateAccountPeers(ctx, account) + am.updateAccountPeers(ctx, accountID) } return nil diff --git a/management/server/posture_checks.go b/management/server/posture_checks.go index 2dccd8f59..096cff3f5 100644 --- a/management/server/posture_checks.go +++ b/management/server/posture_checks.go @@ -69,7 +69,7 @@ func (am *DefaultAccountManager) SavePostureChecks(ctx context.Context, accountI am.StoreEvent(ctx, userID, postureChecks.ID, accountID, action, postureChecks.EventMeta()) if arePostureCheckChangesAffectingPeers(account, postureChecks.ID, exists) { - am.updateAccountPeers(ctx, account) + am.updateAccountPeers(ctx, accountID) } return nil diff --git a/management/server/route.go b/management/server/route.go index 1cf00b37c..dcf2cb0d3 100644 --- a/management/server/route.go +++ b/management/server/route.go @@ -238,7 +238,7 @@ func (am *DefaultAccountManager) CreateRoute(ctx context.Context, accountID stri } if isRouteChangeAffectPeers(account, &newRoute) { - am.updateAccountPeers(ctx, account) + am.updateAccountPeers(ctx, accountID) } am.StoreEvent(ctx, 
userID, string(newRoute.ID), accountID, activity.RouteCreated, newRoute.EventMeta()) @@ -324,7 +324,7 @@ func (am *DefaultAccountManager) SaveRoute(ctx context.Context, accountID, userI } if isRouteChangeAffectPeers(account, oldRoute) || isRouteChangeAffectPeers(account, routeToSave) { - am.updateAccountPeers(ctx, account) + am.updateAccountPeers(ctx, accountID) } am.StoreEvent(ctx, userID, string(routeToSave.ID), accountID, activity.RouteUpdated, routeToSave.EventMeta()) @@ -356,7 +356,7 @@ func (am *DefaultAccountManager) DeleteRoute(ctx context.Context, accountID stri am.StoreEvent(ctx, userID, string(routy.ID), accountID, activity.RouteRemoved, routy.EventMeta()) if isRouteChangeAffectPeers(account, routy) { - am.updateAccountPeers(ctx, account) + am.updateAccountPeers(ctx, accountID) } return nil diff --git a/management/server/route_test.go b/management/server/route_test.go index 4893e19b9..5c848f68c 100644 --- a/management/server/route_test.go +++ b/management/server/route_test.go @@ -1091,7 +1091,7 @@ func TestGetNetworkMap_RouteSyncPeerGroups(t *testing.T) { require.NoError(t, err) assert.Len(t, peer4Routes.Routes, 1, "HA route should have 1 server route") - groups, err := am.ListGroups(context.Background(), account.Id) + groups, err := am.Store.GetAccountGroups(context.Background(), LockingStrengthShare, account.Id) require.NoError(t, err) var groupHA1, groupHA2 *nbgroup.Group for _, group := range groups { diff --git a/management/server/setupkey.go b/management/server/setupkey.go index 960532bf9..cae0dfecb 100644 --- a/management/server/setupkey.go +++ b/management/server/setupkey.go @@ -453,14 +453,14 @@ func (am *DefaultAccountManager) prepareSetupKeyEvents(ctx context.Context, tran modifiedGroups := slices.Concat(addedGroups, removedGroups) groups, err := transaction.GetGroupsByIDs(ctx, LockingStrengthShare, accountID, modifiedGroups) if err != nil { - log.WithContext(ctx).Errorf("issue getting groups for setup key events: %v", err) + log.WithContext(ctx).Debugf("failed to get groups for setup key events: %v", err) return nil } for _, g := range removedGroups { group, ok := groups[g] if !ok { - log.WithContext(ctx).Debugf("skipped adding group: %s GroupRemovedFromSetupKey activity: %v", g, err) + log.WithContext(ctx).Debugf("skipped adding group: %s GroupRemovedFromSetupKey activity: group not found", g) continue } @@ -473,7 +473,7 @@ func (am *DefaultAccountManager) prepareSetupKeyEvents(ctx context.Context, tran for _, g := range addedGroups { group, ok := groups[g] if !ok { - log.WithContext(ctx).Debugf("skipped adding group: %s GroupAddedToSetupKey activity: %v", g, err) + log.WithContext(ctx).Debugf("skipped adding group: %s GroupAddedToSetupKey activity: group not found", g) continue } diff --git a/management/server/sql_store.go b/management/server/sql_store.go index 9dd3e778d..0ebda6440 100644 --- a/management/server/sql_store.go +++ b/management/server/sql_store.go @@ -33,12 +33,13 @@ import ( ) const ( - storeSqliteFileName = "store.db" - idQueryCondition = "id = ?" - keyQueryCondition = "key = ?" - accountAndIDQueryCondition = "account_id = ? and id = ?" - accountIDCondition = "account_id = ?" - peerNotFoundFMT = "peer %s not found" + storeSqliteFileName = "store.db" + idQueryCondition = "id = ?" + keyQueryCondition = "key = ?" + accountAndIDQueryCondition = "account_id = ? and id = ?" + accountAndIDsQueryCondition = "account_id = ? AND id IN ?" + accountIDCondition = "account_id = ?" 
+ peerNotFoundFMT = "peer %s not found" ) // SqlStore represents an account storage backed by a Sql DB persisted to disk @@ -555,9 +556,9 @@ func (s *SqlStore) GetUserByUserID(ctx context.Context, lockStrength LockingStre return &user, nil } -func (s *SqlStore) GetAccountUsers(ctx context.Context, accountID string) ([]*User, error) { +func (s *SqlStore) GetAccountUsers(ctx context.Context, lockStrength LockingStrength, accountID string) ([]*User, error) { var users []*User - result := s.db.Find(&users, accountIDCondition, accountID) + result := s.db.Clauses(clause.Locking{Strength: string(lockStrength)}).Find(&users, accountIDCondition, accountID) if result.Error != nil { if errors.Is(result.Error, gorm.ErrRecordNotFound) { return nil, status.Errorf(status.NotFound, "accountID not found: index lookup failed") @@ -857,7 +858,6 @@ func (s *SqlStore) SaveUserLastLogin(ctx context.Context, accountID, userID stri if errors.Is(result.Error, gorm.ErrRecordNotFound) { return status.NewUserNotFoundError(userID) } - return status.NewGetUserFromStoreError() } user.LastLogin = lastLogin @@ -1045,7 +1045,7 @@ func (s *SqlStore) AddPeerToGroup(ctx context.Context, accountId string, peerId result := s.db.Where(accountAndIDQueryCondition, accountId, groupID).First(&group) if result.Error != nil { if errors.Is(result.Error, gorm.ErrRecordNotFound) { - return status.Errorf(status.NotFound, "group not found for account") + return status.NewGroupNotFoundError(groupID) } return status.Errorf(status.Internal, "issue finding group: %s", result.Error) @@ -1079,10 +1079,45 @@ func (s *SqlStore) AddPeerToAccount(ctx context.Context, peer *nbpeer.Peer) erro return nil } -func (s *SqlStore) IncrementNetworkSerial(ctx context.Context, accountId string) error { - result := s.db.Model(&Account{}).Where(idQueryCondition, accountId).Update("network_serial", gorm.Expr("network_serial + 1")) +// GetPeerByID retrieves a peer by its ID and account ID. +func (s *SqlStore) GetPeerByID(ctx context.Context, lockStrength LockingStrength, accountID, peerID string) (*nbpeer.Peer, error) { + var peer *nbpeer.Peer + result := s.db.Clauses(clause.Locking{Strength: string(lockStrength)}). + First(&peer, accountAndIDQueryCondition, accountID, peerID) if result.Error != nil { - return status.Errorf(status.Internal, "issue incrementing network serial count: %s", result.Error) + if errors.Is(result.Error, gorm.ErrRecordNotFound) { + return nil, status.Errorf(status.NotFound, "peer not found") + } + log.WithContext(ctx).Errorf("failed to get peer from store: %s", result.Error) + return nil, status.Errorf(status.Internal, "failed to get peer from store") + } + + return peer, nil +} + +// GetPeersByIDs retrieves peers by their IDs and account ID. 
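GetPeersByIDs, added in the hunk that follows, mirrors GetGroupsByIDs: one IN query whose result comes back as a map keyed by ID, so IDs that do not exist are simply absent and callers check membership with the comma-ok form (as prepareGroupEvents does earlier in this commit). A minimal usage sketch with names from the diff:

    peers, err := transaction.GetPeersByIDs(ctx, LockingStrengthShare, accountID, peerIDs)
    if err != nil {
        return err
    }
    for _, peerID := range peerIDs {
        peer, ok := peers[peerID]
        if !ok {
            continue // ID not present in the store: skip rather than fail
        }
        log.WithContext(ctx).Debugf("peer %s has IP %s", peer.ID, peer.IP)
    }
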
+func (s *SqlStore) GetPeersByIDs(ctx context.Context, lockStrength LockingStrength, accountID string, peerIDs []string) (map[string]*nbpeer.Peer, error) { + var peers []*nbpeer.Peer + result := s.db.Clauses(clause.Locking{Strength: string(lockStrength)}).Find(&peers, accountAndIDsQueryCondition, accountID, peerIDs) + if result.Error != nil { + log.WithContext(ctx).Errorf("failed to get peers by ID's from the store: %s", result.Error) + return nil, status.Errorf(status.Internal, "failed to get peers by ID's from the store") + } + + peersMap := make(map[string]*nbpeer.Peer) + for _, peer := range peers { + peersMap[peer.ID] = peer + } + + return peersMap, nil +} + +func (s *SqlStore) IncrementNetworkSerial(ctx context.Context, lockStrength LockingStrength, accountId string) error { + result := s.db.Clauses(clause.Locking{Strength: string(lockStrength)}). + Model(&Account{}).Where(idQueryCondition, accountId).Update("network_serial", gorm.Expr("network_serial + 1")) + if result.Error != nil { + log.WithContext(ctx).Errorf("failed to increment network serial count in store: %v", result.Error) + return status.Errorf(status.Internal, "failed to increment network serial count in store") } return nil } @@ -1103,7 +1138,8 @@ func (s *SqlStore) ExecuteInTransaction(ctx context.Context, operation func(stor func (s *SqlStore) withTx(tx *gorm.DB) Store { return &SqlStore{ - db: tx, + db: tx, + storeEngine: s.storeEngine, } } @@ -1155,12 +1191,22 @@ func (s *SqlStore) GetAccountDomainAndCategory(ctx context.Context, lockStrength } // GetGroupByID retrieves a group by ID and account ID. -func (s *SqlStore) GetGroupByID(ctx context.Context, lockStrength LockingStrength, groupID, accountID string) (*nbgroup.Group, error) { - return getRecordByID[nbgroup.Group](s.db.Preload(clause.Associations), lockStrength, groupID, accountID) +func (s *SqlStore) GetGroupByID(ctx context.Context, lockStrength LockingStrength, accountID, groupID string) (*nbgroup.Group, error) { + var group *nbgroup.Group + result := s.db.Clauses(clause.Locking{Strength: string(lockStrength)}).First(&group, accountAndIDQueryCondition, accountID, groupID) + if err := result.Error; err != nil { + if errors.Is(err, gorm.ErrRecordNotFound) { + return nil, status.NewGroupNotFoundError(groupID) + } + log.WithContext(ctx).Errorf("failed to get group from store: %s", err) + return nil, status.Errorf(status.Internal, "failed to get group from store") + } + + return group, nil } // GetGroupByName retrieves a group by name and account ID. -func (s *SqlStore) GetGroupByName(ctx context.Context, lockStrength LockingStrength, groupName, accountID string) (*nbgroup.Group, error) { +func (s *SqlStore) GetGroupByName(ctx context.Context, lockStrength LockingStrength, accountID, groupName string) (*nbgroup.Group, error) { var group nbgroup.Group // TODO: This fix is accepted for now, but if we need to handle this more frequently @@ -1172,12 +1218,13 @@ func (s *SqlStore) GetGroupByName(ctx context.Context, lockStrength LockingStren query = query.Order("json_array_length(peers) DESC") } - result := query.First(&group, "name = ? and account_id = ?", groupName, accountID) + result := query.First(&group, "account_id = ? 
AND name = ?", accountID, groupName) if err := result.Error; err != nil { if errors.Is(result.Error, gorm.ErrRecordNotFound) { - return nil, status.Errorf(status.NotFound, "group not found") + return nil, status.NewGroupNotFoundError(groupName) } - return nil, status.Errorf(status.Internal, "failed to get group from store: %s", result.Error) + log.WithContext(ctx).Errorf("failed to get group by name from store: %v", result.Error) + return nil, status.Errorf(status.Internal, "failed to get group by name from store") } return &group, nil } @@ -1185,7 +1232,7 @@ func (s *SqlStore) GetGroupByName(ctx context.Context, lockStrength LockingStren // GetGroupsByIDs retrieves groups by their IDs and account ID. func (s *SqlStore) GetGroupsByIDs(ctx context.Context, lockStrength LockingStrength, accountID string, groupIDs []string) (map[string]*nbgroup.Group, error) { var groups []*nbgroup.Group - result := s.db.Clauses(clause.Locking{Strength: string(lockStrength)}).Find(&groups, "account_id = ? AND id in ?", accountID, groupIDs) + result := s.db.Clauses(clause.Locking{Strength: string(lockStrength)}).Find(&groups, accountAndIDsQueryCondition, accountID, groupIDs) if result.Error != nil { log.WithContext(ctx).Errorf("failed to get groups by ID's from the store: %s", result.Error) return nil, status.Errorf(status.Internal, "failed to get groups by ID's from the store") @@ -1203,11 +1250,40 @@ func (s *SqlStore) GetGroupsByIDs(ctx context.Context, lockStrength LockingStren func (s *SqlStore) SaveGroup(ctx context.Context, lockStrength LockingStrength, group *nbgroup.Group) error { result := s.db.Clauses(clause.Locking{Strength: string(lockStrength)}).Save(group) if result.Error != nil { - return status.Errorf(status.Internal, "failed to save group to store: %v", result.Error) + log.WithContext(ctx).Errorf("failed to save group to store: %v", result.Error) + return status.Errorf(status.Internal, "failed to save group to store") } return nil } +// DeleteGroup deletes a group from the database. +func (s *SqlStore) DeleteGroup(ctx context.Context, lockStrength LockingStrength, accountID, groupID string) error { + result := s.db.Clauses(clause.Locking{Strength: string(lockStrength)}). + Delete(&nbgroup.Group{}, accountAndIDQueryCondition, accountID, groupID) + if err := result.Error; err != nil { + log.WithContext(ctx).Errorf("failed to delete group from store: %s", result.Error) + return status.Errorf(status.Internal, "failed to delete group from store") + } + + if result.RowsAffected == 0 { + return status.NewGroupNotFoundError(groupID) + } + + return nil +} + +// DeleteGroups deletes groups from the database. +func (s *SqlStore) DeleteGroups(ctx context.Context, strength LockingStrength, accountID string, groupIDs []string) error { + result := s.db.Clauses(clause.Locking{Strength: string(strength)}). + Delete(&nbgroup.Group{}, accountAndIDsQueryCondition, accountID, groupIDs) + if result.Error != nil { + log.WithContext(ctx).Errorf("failed to delete groups from store: %v", result.Error) + return status.Errorf(status.Internal, "failed to delete groups from store") + } + + return nil +} + // GetAccountPolicies retrieves policies for an account. 
func (s *SqlStore) GetAccountPolicies(ctx context.Context, lockStrength LockingStrength, accountID string) ([]*Policy, error) { return getRecords[*Policy](s.db.Preload(clause.Associations), lockStrength, accountID) diff --git a/management/server/sql_store_test.go b/management/server/sql_store_test.go index 3f3b2a453..114da1ee6 100644 --- a/management/server/sql_store_test.go +++ b/management/server/sql_store_test.go @@ -14,11 +14,10 @@ import ( "time" "github.com/google/uuid" - "github.com/stretchr/testify/assert" - "github.com/stretchr/testify/require" - nbdns "github.com/netbirdio/netbird/dns" nbgroup "github.com/netbirdio/netbird/management/server/group" + "github.com/stretchr/testify/assert" + "github.com/stretchr/testify/require" route2 "github.com/netbirdio/netbird/route" @@ -1181,7 +1180,7 @@ func TestSqlite_CreateAndGetObjectInTransaction(t *testing.T) { t.Fatal("failed to save group") return err } - group, err = transaction.GetGroupByID(context.Background(), LockingStrengthUpdate, group.ID, group.AccountID) + group, err = transaction.GetGroupByID(context.Background(), LockingStrengthUpdate, group.AccountID, group.ID) if err != nil { t.Fatal("failed to get group") return err @@ -1201,7 +1200,7 @@ func TestSqlite_GetAccoundUsers(t *testing.T) { accountID := "bf1c8084-ba50-4ce7-9439-34653001fc3b" account, err := store.GetAccount(context.Background(), accountID) require.NoError(t, err) - users, err := store.GetAccountUsers(context.Background(), accountID) + users, err := store.GetAccountUsers(context.Background(), LockingStrengthShare, accountID) require.NoError(t, err) require.Len(t, users, len(account.Users)) } @@ -1260,9 +1259,9 @@ func TestSqlite_GetGroupByName(t *testing.T) { } accountID := "bf1c8084-ba50-4ce7-9439-34653001fc3b" - group, err := store.GetGroupByName(context.Background(), LockingStrengthShare, "All", accountID) + group, err := store.GetGroupByName(context.Background(), LockingStrengthShare, accountID, "All") require.NoError(t, err) - require.Equal(t, "All", group.Name) + require.True(t, group.IsGroupAll()) } func Test_DeleteSetupKeySuccessfully(t *testing.T) { @@ -1293,3 +1292,275 @@ func Test_DeleteSetupKeyFailsForNonExistingKey(t *testing.T) { err = store.DeleteSetupKey(context.Background(), LockingStrengthUpdate, accountID, nonExistingKeyID) require.Error(t, err) } + +func TestSqlStore_GetGroupsByIDs(t *testing.T) { + store, cleanup, err := NewTestStoreFromSQL(context.Background(), "testdata/extended-store.sql", t.TempDir()) + t.Cleanup(cleanup) + require.NoError(t, err) + + accountID := "bf1c8084-ba50-4ce7-9439-34653001fc3b" + + tests := []struct { + name string + groupIDs []string + expectedCount int + }{ + { + name: "retrieve existing groups by existing IDs", + groupIDs: []string{"cfefqs706sqkneg59g4g", "cfefqs706sqkneg59g3g"}, + expectedCount: 2, + }, + { + name: "empty group IDs list", + groupIDs: []string{}, + expectedCount: 0, + }, + { + name: "non-existing group IDs", + groupIDs: []string{"nonexistent1", "nonexistent2"}, + expectedCount: 0, + }, + { + name: "mixed existing and non-existing group IDs", + groupIDs: []string{"cfefqs706sqkneg59g4g", "nonexistent"}, + expectedCount: 1, + }, + } + + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + groups, err := store.GetGroupsByIDs(context.Background(), LockingStrengthShare, accountID, tt.groupIDs) + require.NoError(t, err) + require.Len(t, groups, tt.expectedCount) + }) + } +} + +func TestSqlStore_SaveGroup(t *testing.T) { + store, cleanup, err := NewTestStoreFromSQL(context.Background(), 
"testdata/extended-store.sql", t.TempDir()) + t.Cleanup(cleanup) + require.NoError(t, err) + + accountID := "bf1c8084-ba50-4ce7-9439-34653001fc3b" + + group := &nbgroup.Group{ + ID: "group-id", + AccountID: accountID, + Issued: "api", + Peers: []string{"peer1", "peer2"}, + } + err = store.SaveGroup(context.Background(), LockingStrengthUpdate, group) + require.NoError(t, err) + + savedGroup, err := store.GetGroupByID(context.Background(), LockingStrengthShare, accountID, "group-id") + require.NoError(t, err) + require.Equal(t, savedGroup, group) +} + +func TestSqlStore_SaveGroups(t *testing.T) { + store, cleanup, err := NewTestStoreFromSQL(context.Background(), "testdata/extended-store.sql", t.TempDir()) + t.Cleanup(cleanup) + require.NoError(t, err) + + accountID := "bf1c8084-ba50-4ce7-9439-34653001fc3b" + + groups := []*nbgroup.Group{ + { + ID: "group-1", + AccountID: accountID, + Issued: "api", + Peers: []string{"peer1", "peer2"}, + }, + { + ID: "group-2", + AccountID: accountID, + Issued: "integration", + Peers: []string{"peer3", "peer4"}, + }, + } + err = store.SaveGroups(context.Background(), LockingStrengthUpdate, groups) + require.NoError(t, err) +} + +func TestSqlStore_DeleteGroup(t *testing.T) { + store, cleanup, err := NewTestStoreFromSQL(context.Background(), "testdata/extended-store.sql", t.TempDir()) + t.Cleanup(cleanup) + require.NoError(t, err) + + accountID := "bf1c8084-ba50-4ce7-9439-34653001fc3b" + + tests := []struct { + name string + groupID string + expectError bool + }{ + { + name: "delete existing group", + groupID: "cfefqs706sqkneg59g4g", + expectError: false, + }, + { + name: "delete non-existing group", + groupID: "non-existing-group-id", + expectError: true, + }, + { + name: "delete with empty group ID", + groupID: "", + expectError: true, + }, + } + + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + err := store.DeleteGroup(context.Background(), LockingStrengthUpdate, accountID, tt.groupID) + if tt.expectError { + require.Error(t, err) + sErr, ok := status.FromError(err) + require.True(t, ok) + require.Equal(t, sErr.Type(), status.NotFound) + } else { + require.NoError(t, err) + + group, err := store.GetGroupByID(context.Background(), LockingStrengthShare, accountID, tt.groupID) + require.Error(t, err) + require.Nil(t, group) + } + }) + } +} + +func TestSqlStore_DeleteGroups(t *testing.T) { + store, cleanup, err := NewTestStoreFromSQL(context.Background(), "testdata/extended-store.sql", t.TempDir()) + t.Cleanup(cleanup) + require.NoError(t, err) + + accountID := "bf1c8084-ba50-4ce7-9439-34653001fc3b" + + tests := []struct { + name string + groupIDs []string + expectError bool + }{ + { + name: "delete multiple existing groups", + groupIDs: []string{"cfefqs706sqkneg59g4g", "cfefqs706sqkneg59g3g"}, + expectError: false, + }, + { + name: "delete non-existing groups", + groupIDs: []string{"non-existing-id-1", "non-existing-id-2"}, + expectError: false, + }, + { + name: "delete with empty group IDs list", + groupIDs: []string{}, + expectError: false, + }, + } + + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + err := store.DeleteGroups(context.Background(), LockingStrengthUpdate, accountID, tt.groupIDs) + if tt.expectError { + require.Error(t, err) + } else { + require.NoError(t, err) + + for _, groupID := range tt.groupIDs { + group, err := store.GetGroupByID(context.Background(), LockingStrengthShare, accountID, groupID) + require.Error(t, err) + require.Nil(t, group) + } + } + }) + } +} + +func TestSqlStore_GetPeerByID(t 
*testing.T) { + store, cleanup, err := NewTestStoreFromSQL(context.Background(), "testdata/store_policy_migrate.sql", t.TempDir()) + t.Cleanup(cleanup) + require.NoError(t, err) + + accountID := "bf1c8084-ba50-4ce7-9439-34653001fc3b" + tests := []struct { + name string + peerID string + expectError bool + }{ + { + name: "retrieve existing peer", + peerID: "cfefqs706sqkneg59g4g", + expectError: false, + }, + { + name: "retrieve non-existing peer", + peerID: "non-existing", + expectError: true, + }, + { + name: "retrieve with empty peer ID", + peerID: "", + expectError: true, + }, + } + + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + peer, err := store.GetPeerByID(context.Background(), LockingStrengthShare, accountID, tt.peerID) + if tt.expectError { + require.Error(t, err) + sErr, ok := status.FromError(err) + require.True(t, ok) + require.Equal(t, sErr.Type(), status.NotFound) + require.Nil(t, peer) + } else { + require.NoError(t, err) + require.NotNil(t, peer) + require.Equal(t, tt.peerID, peer.ID) + } + }) + } +} + +func TestSqlStore_GetPeersByIDs(t *testing.T) { + store, cleanup, err := NewTestStoreFromSQL(context.Background(), "testdata/store_policy_migrate.sql", t.TempDir()) + t.Cleanup(cleanup) + require.NoError(t, err) + + accountID := "bf1c8084-ba50-4ce7-9439-34653001fc3b" + tests := []struct { + name string + peerIDs []string + expectedCount int + }{ + { + name: "retrieve existing peers by existing IDs", + peerIDs: []string{"cfefqs706sqkneg59g4g", "cfeg6sf06sqkneg59g50"}, + expectedCount: 2, + }, + { + name: "empty peer IDs list", + peerIDs: []string{}, + expectedCount: 0, + }, + { + name: "non-existing peer IDs", + peerIDs: []string{"nonexistent1", "nonexistent2"}, + expectedCount: 0, + }, + { + name: "mixed existing and non-existing peer IDs", + peerIDs: []string{"cfeg6sf06sqkneg59g50", "nonexistent"}, + expectedCount: 1, + }, + } + + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + peers, err := store.GetPeersByIDs(context.Background(), LockingStrengthShare, accountID, tt.peerIDs) + require.NoError(t, err) + require.Len(t, peers, tt.expectedCount) + }) + } +} diff --git a/management/server/status/error.go b/management/server/status/error.go index f1f3f16e6..8b6d0077b 100644 --- a/management/server/status/error.go +++ b/management/server/status/error.go @@ -3,7 +3,6 @@ package status import ( "errors" "fmt" - "time" ) const ( @@ -126,11 +125,6 @@ func NewAdminPermissionError() error { return Errorf(PermissionDenied, "admin role required to perform this action") } -// NewStoreContextCanceledError creates a new Error with Internal type for a canceled store context -func NewStoreContextCanceledError(duration time.Duration) error { - return Errorf(Internal, "store access: context canceled after %v", duration) -} - // NewInvalidKeyIDError creates a new Error with InvalidArgument type for an issue getting a setup key func NewInvalidKeyIDError() error { return Errorf(InvalidArgument, "invalid key ID") @@ -140,3 +134,8 @@ func NewInvalidKeyIDError() error { func NewGetAccountError(err error) error { return Errorf(Internal, "error getting account: %s", err) } + +// NewGroupNotFoundError creates a new Error with NotFound type for a missing group +func NewGroupNotFoundError(groupID string) error { + return Errorf(NotFound, "group: %s not found", groupID) +} diff --git a/management/server/store.go b/management/server/store.go index c7c066629..71b0d457b 100644 --- a/management/server/store.go +++ b/management/server/store.go @@ -62,7 +62,7 @@ type Store 
interface { GetUserByTokenID(ctx context.Context, tokenID string) (*User, error) GetUserByUserID(ctx context.Context, lockStrength LockingStrength, userID string) (*User, error) - GetAccountUsers(ctx context.Context, accountID string) ([]*User, error) + GetAccountUsers(ctx context.Context, lockStrength LockingStrength, accountID string) ([]*User, error) SaveUsers(accountID string, users map[string]*User) error SaveUser(ctx context.Context, lockStrength LockingStrength, user *User) error SaveUserLastLogin(ctx context.Context, accountID, userID string, lastLogin time.Time) error @@ -76,6 +76,8 @@ type Store interface { GetGroupsByIDs(ctx context.Context, lockStrength LockingStrength, accountID string, groupIDs []string) (map[string]*nbgroup.Group, error) SaveGroups(ctx context.Context, lockStrength LockingStrength, groups []*nbgroup.Group) error SaveGroup(ctx context.Context, lockStrength LockingStrength, group *nbgroup.Group) error + DeleteGroup(ctx context.Context, lockStrength LockingStrength, accountID, groupID string) error + DeleteGroups(ctx context.Context, strength LockingStrength, accountID string, groupIDs []string) error GetAccountPolicies(ctx context.Context, lockStrength LockingStrength, accountID string) ([]*Policy, error) GetPolicyByID(ctx context.Context, lockStrength LockingStrength, policyID string, accountID string) (*Policy, error) @@ -90,6 +92,8 @@ type Store interface { AddPeerToAccount(ctx context.Context, peer *nbpeer.Peer) error GetPeerByPeerPubKey(ctx context.Context, lockStrength LockingStrength, peerKey string) (*nbpeer.Peer, error) GetUserPeers(ctx context.Context, lockStrength LockingStrength, accountID, userID string) ([]*nbpeer.Peer, error) + GetPeerByID(ctx context.Context, lockStrength LockingStrength, accountID string, peerID string) (*nbpeer.Peer, error) + GetPeersByIDs(ctx context.Context, lockStrength LockingStrength, accountID string, peerIDs []string) (map[string]*nbpeer.Peer, error) SavePeer(ctx context.Context, accountID string, peer *nbpeer.Peer) error SavePeerStatus(accountID, peerID string, status nbpeer.PeerStatus) error SavePeerLocation(accountID string, peer *nbpeer.Peer) error @@ -108,7 +112,7 @@ type Store interface { GetNameServerGroupByID(ctx context.Context, lockStrength LockingStrength, nameServerGroupID string, accountID string) (*dns.NameServerGroup, error) GetTakenIPs(ctx context.Context, lockStrength LockingStrength, accountId string) ([]net.IP, error) - IncrementNetworkSerial(ctx context.Context, accountId string) error + IncrementNetworkSerial(ctx context.Context, lockStrength LockingStrength, accountId string) error GetAccountNetwork(ctx context.Context, lockStrength LockingStrength, accountId string) (*Network, error) GetInstallationID() string diff --git a/management/server/user.go b/management/server/user.go index 5e0d9d034..74062112a 100644 --- a/management/server/user.go +++ b/management/server/user.go @@ -494,7 +494,7 @@ func (am *DefaultAccountManager) deleteRegularUser(ctx context.Context, account am.StoreEvent(ctx, initiatorUserID, targetUserID, account.Id, activity.UserDeleted, meta) if updateAccountPeers { - am.updateAccountPeers(ctx, account) + am.updateAccountPeers(ctx, account.Id) } return nil @@ -835,7 +835,7 @@ func (am *DefaultAccountManager) SaveOrAddUsers(ctx context.Context, accountID, } if account.Settings.GroupsPropagationEnabled && areUsersLinkedToPeers(account, userIDs) { - am.updateAccountPeers(ctx, account) + am.updateAccountPeers(ctx, account.Id) } for _, storeEvent := range eventsToStore { @@ -1132,7 
+1132,7 @@ func (am *DefaultAccountManager) expireAndUpdatePeers(ctx context.Context, accou if len(peerIDs) != 0 { // this will trigger peer disconnect from the management service am.peersUpdateManager.CloseChannels(ctx, peerIDs) - am.updateAccountPeers(ctx, account) + am.updateAccountPeers(ctx, account.Id) } return nil } @@ -1240,7 +1240,7 @@ func (am *DefaultAccountManager) DeleteRegularUsers(ctx context.Context, account } if updateAccountPeers { - am.updateAccountPeers(ctx, account) + am.updateAccountPeers(ctx, accountID) } for targetUserID, meta := range deletedUsersMeta { From a1c5287b7c7a6152cb3cff319dea278ba1cfd3c1 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?=C4=B0smail?= Date: Fri, 15 Nov 2024 20:21:27 +0300 Subject: [PATCH 013/159] Fix the Inactivity Expiration problem. (#2865) --- management/server/account.go | 29 +++++++++++++++++------------ 1 file changed, 17 insertions(+), 12 deletions(-) diff --git a/management/server/account.go b/management/server/account.go index 1bd8a99a9..95c93a22b 100644 --- a/management/server/account.go +++ b/management/server/account.go @@ -1186,20 +1186,25 @@ func (am *DefaultAccountManager) UpdateAccountSettings(ctx context.Context, acco } func (am *DefaultAccountManager) handleInactivityExpirationSettings(ctx context.Context, account *Account, oldSettings, newSettings *Settings, userID, accountID string) error { - if oldSettings.PeerInactivityExpirationEnabled != newSettings.PeerInactivityExpirationEnabled { - event := activity.AccountPeerInactivityExpirationEnabled - if !newSettings.PeerInactivityExpirationEnabled { - event = activity.AccountPeerInactivityExpirationDisabled - am.peerInactivityExpiry.Cancel(ctx, []string{accountID}) - } else { + + if newSettings.PeerInactivityExpirationEnabled { + if oldSettings.PeerInactivityExpiration != newSettings.PeerInactivityExpiration { + oldSettings.PeerInactivityExpiration = newSettings.PeerInactivityExpiration + + am.StoreEvent(ctx, userID, accountID, accountID, activity.AccountPeerInactivityExpirationDurationUpdated, nil) am.checkAndSchedulePeerInactivityExpiration(ctx, account) } - am.StoreEvent(ctx, userID, accountID, accountID, event, nil) - } - - if oldSettings.PeerInactivityExpiration != newSettings.PeerInactivityExpiration { - am.StoreEvent(ctx, userID, accountID, accountID, activity.AccountPeerInactivityExpirationDurationUpdated, nil) - am.checkAndSchedulePeerInactivityExpiration(ctx, account) + } else { + if oldSettings.PeerInactivityExpirationEnabled != newSettings.PeerInactivityExpirationEnabled { + event := activity.AccountPeerInactivityExpirationEnabled + if !newSettings.PeerInactivityExpirationEnabled { + event = activity.AccountPeerInactivityExpirationDisabled + am.peerInactivityExpiry.Cancel(ctx, []string{accountID}) + } else { + am.checkAndSchedulePeerInactivityExpiration(ctx, account) + } + am.StoreEvent(ctx, userID, accountID, accountID, event, nil) + } } return nil From 121dfda915cbaa17e8a16af1f96a71927fc0ece1 Mon Sep 17 00:00:00 2001 From: Viktor Liu <17948409+lixmal@users.noreply.github.com> Date: Fri, 15 Nov 2024 20:05:26 +0100 Subject: [PATCH 014/159] [client] Fix state manager race conditions (#2890) --- client/internal/dns/server.go | 20 +++---- client/internal/engine.go | 2 +- .../routemanager/refcounter/refcounter.go | 58 +++++++++---------- .../internal/routemanager/systemops/state.go | 23 ++++---- .../systemops/systemops_generic.go | 20 +------ client/internal/statemanager/manager.go | 40 +++++++++---- util/file.go | 55 +++++++++++++----- 7 files changed, 118 insertions(+), 
100 deletions(-) diff --git a/client/internal/dns/server.go b/client/internal/dns/server.go index 6c4dccae7..f0277319c 100644 --- a/client/internal/dns/server.go +++ b/client/internal/dns/server.go @@ -7,7 +7,6 @@ import ( "runtime" "strings" "sync" - "time" "github.com/miekg/dns" "github.com/mitchellh/hashstructure/v2" @@ -323,13 +322,9 @@ func (s *DefaultServer) applyConfiguration(update nbdns.Config) error { log.Error(err) } - // persist dns state right away - ctx, cancel := context.WithTimeout(s.ctx, 3*time.Second) - defer cancel() - - // don't block go func() { - if err := s.stateManager.PersistState(ctx); err != nil { + // persist dns state right away + if err := s.stateManager.PersistState(s.ctx); err != nil { log.Errorf("Failed to persist dns state: %v", err) } }() @@ -537,12 +532,11 @@ func (s *DefaultServer) upstreamCallbacks( l.Errorf("Failed to apply nameserver deactivation on the host: %v", err) } - // persist dns state right away - ctx, cancel := context.WithTimeout(s.ctx, 3*time.Second) - defer cancel() - if err := s.stateManager.PersistState(ctx); err != nil { - l.Errorf("Failed to persist dns state: %v", err) - } + go func() { + if err := s.stateManager.PersistState(s.ctx); err != nil { + l.Errorf("Failed to persist dns state: %v", err) + } + }() if runtime.GOOS == "android" && nsGroup.Primary && len(s.hostsDNSHolder.get()) > 0 { s.addHostRootZone() diff --git a/client/internal/engine.go b/client/internal/engine.go index d4a3a561a..1c912220c 100644 --- a/client/internal/engine.go +++ b/client/internal/engine.go @@ -297,7 +297,7 @@ func (e *Engine) Stop() error { if err := e.stateManager.Stop(ctx); err != nil { return fmt.Errorf("failed to stop state manager: %w", err) } - if err := e.stateManager.PersistState(ctx); err != nil { + if err := e.stateManager.PersistState(context.Background()); err != nil { log.Errorf("failed to persist state: %v", err) } diff --git a/client/internal/routemanager/refcounter/refcounter.go b/client/internal/routemanager/refcounter/refcounter.go index 0e230ef40..f2f0a169d 100644 --- a/client/internal/routemanager/refcounter/refcounter.go +++ b/client/internal/routemanager/refcounter/refcounter.go @@ -47,10 +47,9 @@ type RemoveFunc[Key, O any] func(key Key, out O) error type Counter[Key comparable, I, O any] struct { // refCountMap keeps track of the reference Ref for keys refCountMap map[Key]Ref[O] - refCountMu sync.Mutex + mu sync.Mutex // idMap keeps track of the keys associated with an ID for removal idMap map[string][]Key - idMu sync.Mutex add AddFunc[Key, I, O] remove RemoveFunc[Key, O] } @@ -75,10 +74,8 @@ func New[Key comparable, I, O any](add AddFunc[Key, I, O], remove RemoveFunc[Key func (rm *Counter[Key, I, O]) LoadData( existingCounter *Counter[Key, I, O], ) { - rm.refCountMu.Lock() - defer rm.refCountMu.Unlock() - rm.idMu.Lock() - defer rm.idMu.Unlock() + rm.mu.Lock() + defer rm.mu.Unlock() rm.refCountMap = existingCounter.refCountMap rm.idMap = existingCounter.idMap @@ -87,8 +84,8 @@ func (rm *Counter[Key, I, O]) LoadData( // Get retrieves the current reference count and associated data for a key. // If the key doesn't exist, it returns a zero value Ref and false. func (rm *Counter[Key, I, O]) Get(key Key) (Ref[O], bool) { - rm.refCountMu.Lock() - defer rm.refCountMu.Unlock() + rm.mu.Lock() + defer rm.mu.Unlock() ref, ok := rm.refCountMap[key] return ref, ok @@ -97,9 +94,13 @@ func (rm *Counter[Key, I, O]) Get(key Key) (Ref[O], bool) { // Increment increments the reference count for the given key. 
// If this is the first reference to the key, the AddFunc is called. func (rm *Counter[Key, I, O]) Increment(key Key, in I) (Ref[O], error) { - rm.refCountMu.Lock() - defer rm.refCountMu.Unlock() + rm.mu.Lock() + defer rm.mu.Unlock() + return rm.increment(key, in) +} + +func (rm *Counter[Key, I, O]) increment(key Key, in I) (Ref[O], error) { ref := rm.refCountMap[key] logCallerF("Increasing ref count [%d -> %d] for key %v with In [%v] Out [%v]", ref.Count, ref.Count+1, key, in, ref.Out) @@ -126,10 +127,10 @@ func (rm *Counter[Key, I, O]) Increment(key Key, in I) (Ref[O], error) { // IncrementWithID increments the reference count for the given key and groups it under the given ID. // If this is the first reference to the key, the AddFunc is called. func (rm *Counter[Key, I, O]) IncrementWithID(id string, key Key, in I) (Ref[O], error) { - rm.idMu.Lock() - defer rm.idMu.Unlock() + rm.mu.Lock() + defer rm.mu.Unlock() - ref, err := rm.Increment(key, in) + ref, err := rm.increment(key, in) if err != nil { return ref, fmt.Errorf("with ID: %w", err) } @@ -141,9 +142,12 @@ func (rm *Counter[Key, I, O]) IncrementWithID(id string, key Key, in I) (Ref[O], // Decrement decrements the reference count for the given key. // If the reference count reaches 0, the RemoveFunc is called. func (rm *Counter[Key, I, O]) Decrement(key Key) (Ref[O], error) { - rm.refCountMu.Lock() - defer rm.refCountMu.Unlock() + rm.mu.Lock() + defer rm.mu.Unlock() + return rm.decrement(key) +} +func (rm *Counter[Key, I, O]) decrement(key Key) (Ref[O], error) { ref, ok := rm.refCountMap[key] if !ok { logCallerF("No reference found for key %v", key) @@ -168,12 +172,12 @@ func (rm *Counter[Key, I, O]) Decrement(key Key) (Ref[O], error) { // DecrementWithID decrements the reference count for all keys associated with the given ID. // If the reference count reaches 0, the RemoveFunc is called. func (rm *Counter[Key, I, O]) DecrementWithID(id string) error { - rm.idMu.Lock() - defer rm.idMu.Unlock() + rm.mu.Lock() + defer rm.mu.Unlock() var merr *multierror.Error for _, key := range rm.idMap[id] { - if _, err := rm.Decrement(key); err != nil { + if _, err := rm.decrement(key); err != nil { merr = multierror.Append(merr, err) } } @@ -184,10 +188,8 @@ func (rm *Counter[Key, I, O]) DecrementWithID(id string) error { // Flush removes all references and calls RemoveFunc for each key. func (rm *Counter[Key, I, O]) Flush() error { - rm.refCountMu.Lock() - defer rm.refCountMu.Unlock() - rm.idMu.Lock() - defer rm.idMu.Unlock() + rm.mu.Lock() + defer rm.mu.Unlock() var merr *multierror.Error for key := range rm.refCountMap { @@ -206,10 +208,8 @@ func (rm *Counter[Key, I, O]) Flush() error { // Clear removes all references without calling RemoveFunc. func (rm *Counter[Key, I, O]) Clear() { - rm.refCountMu.Lock() - defer rm.refCountMu.Unlock() - rm.idMu.Lock() - defer rm.idMu.Unlock() + rm.mu.Lock() + defer rm.mu.Unlock() clear(rm.refCountMap) clear(rm.idMap) @@ -217,10 +217,8 @@ func (rm *Counter[Key, I, O]) Clear() { // MarshalJSON implements the json.Marshaler interface for Counter. 
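// Note on the locking change running through this file: the separate refCountMu and idMu
// are collapsed into a single mu, and the exported Increment/Decrement now delegate to the
// unexported increment/decrement, so IncrementWithID, DecrementWithID and Flush can update
// both maps under one lock instead of juggling two mutexes.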
func (rm *Counter[Key, I, O]) MarshalJSON() ([]byte, error) { - rm.refCountMu.Lock() - defer rm.refCountMu.Unlock() - rm.idMu.Lock() - defer rm.idMu.Unlock() + rm.mu.Lock() + defer rm.mu.Unlock() return json.Marshal(struct { RefCountMap map[Key]Ref[O] `json:"refCountMap"` diff --git a/client/internal/routemanager/systemops/state.go b/client/internal/routemanager/systemops/state.go index 425908922..8e158711e 100644 --- a/client/internal/routemanager/systemops/state.go +++ b/client/internal/routemanager/systemops/state.go @@ -2,31 +2,28 @@ package systemops import ( "net/netip" - "sync" "github.com/netbirdio/netbird/client/internal/routemanager/refcounter" ) -type ShutdownState struct { - Counter *ExclusionCounter `json:"counter,omitempty"` - mu sync.RWMutex -} +type ShutdownState ExclusionCounter func (s *ShutdownState) Name() string { return "route_state" } func (s *ShutdownState) Cleanup() error { - s.mu.RLock() - defer s.mu.RUnlock() - - if s.Counter == nil { - return nil - } - sysops := NewSysOps(nil, nil) sysops.refCounter = refcounter.New[netip.Prefix, struct{}, Nexthop](nil, sysops.removeFromRouteTable) - sysops.refCounter.LoadData(s.Counter) + sysops.refCounter.LoadData((*ExclusionCounter)(s)) return sysops.refCounter.Flush() } + +func (s *ShutdownState) MarshalJSON() ([]byte, error) { + return (*ExclusionCounter)(s).MarshalJSON() +} + +func (s *ShutdownState) UnmarshalJSON(data []byte) error { + return (*ExclusionCounter)(s).UnmarshalJSON(data) +} diff --git a/client/internal/routemanager/systemops/systemops_generic.go b/client/internal/routemanager/systemops/systemops_generic.go index 4ff34aa51..f8b3ebbb8 100644 --- a/client/internal/routemanager/systemops/systemops_generic.go +++ b/client/internal/routemanager/systemops/systemops_generic.go @@ -62,7 +62,7 @@ func (r *SysOps) setupRefCounter(initAddresses []net.IP, stateManager *statemana return nexthop, err }, func(prefix netip.Prefix, nexthop Nexthop) error { - // remove from state even if we have trouble removing it from the route table + // update state even if we have trouble removing it from the route table // it could be already gone r.updateState(stateManager) @@ -75,12 +75,9 @@ func (r *SysOps) setupRefCounter(initAddresses []net.IP, stateManager *statemana return r.setupHooks(initAddresses) } +// updateState updates state on every change so it will be persisted regularly func (r *SysOps) updateState(stateManager *statemanager.Manager) { - state := getState(stateManager) - - state.Counter = r.refCounter - - if err := stateManager.UpdateState(state); err != nil { + if err := stateManager.UpdateState((*ShutdownState)(r.refCounter)); err != nil { log.Errorf("failed to update state: %v", err) } } @@ -532,14 +529,3 @@ func isVpnRoute(addr netip.Addr, vpnRoutes []netip.Prefix, localRoutes []netip.P // Return true if the longest matching prefix is from vpnRoutes return isVpn, longestPrefix } - -func getState(stateManager *statemanager.Manager) *ShutdownState { - var shutdownState *ShutdownState - if state := stateManager.GetState(shutdownState); state != nil { - shutdownState = state.(*ShutdownState) - } else { - shutdownState = &ShutdownState{} - } - - return shutdownState -} diff --git a/client/internal/statemanager/manager.go b/client/internal/statemanager/manager.go index 580ccdfc7..da6dd022f 100644 --- a/client/internal/statemanager/manager.go +++ b/client/internal/statemanager/manager.go @@ -74,15 +74,15 @@ func (m *Manager) Stop(ctx context.Context) error { m.mu.Lock() defer m.mu.Unlock() - if m.cancel != nil { - 
m.cancel() + if m.cancel == nil { + return nil + } + m.cancel() - select { - case <-ctx.Done(): - return ctx.Err() - case <-m.done: - return nil - } + select { + case <-ctx.Done(): + return ctx.Err() + case <-m.done: } return nil @@ -179,14 +179,18 @@ func (m *Manager) PersistState(ctx context.Context) error { return nil } + bs, err := marshalWithPanicRecovery(m.states) + if err != nil { + return fmt.Errorf("marshal states: %w", err) + } + ctx, cancel := context.WithTimeout(ctx, 5*time.Second) defer cancel() done := make(chan error, 1) - start := time.Now() go func() { - done <- util.WriteJsonWithRestrictedPermission(ctx, m.filePath, m.states) + done <- util.WriteBytesWithRestrictedPermission(ctx, m.filePath, bs) }() select { @@ -286,3 +290,19 @@ func (m *Manager) PerformCleanup() error { return nberrors.FormatErrorOrNil(merr) } + +func marshalWithPanicRecovery(v any) ([]byte, error) { + var bs []byte + var err error + + func() { + defer func() { + if r := recover(); r != nil { + err = fmt.Errorf("panic during marshal: %v", r) + } + }() + bs, err = json.Marshal(v) + }() + + return bs, err +} diff --git a/util/file.go b/util/file.go index 4641cc1b8..f7de7ede2 100644 --- a/util/file.go +++ b/util/file.go @@ -4,6 +4,7 @@ import ( "bytes" "context" "encoding/json" + "errors" "fmt" "io" "os" @@ -14,6 +15,19 @@ import ( log "github.com/sirupsen/logrus" ) +func WriteBytesWithRestrictedPermission(ctx context.Context, file string, bs []byte) error { + configDir, configFileName, err := prepareConfigFileDir(file) + if err != nil { + return fmt.Errorf("prepare config file dir: %w", err) + } + + if err = EnforcePermission(file); err != nil { + return fmt.Errorf("enforce permission: %w", err) + } + + return writeBytes(ctx, file, err, configDir, configFileName, bs) +} + // WriteJsonWithRestrictedPermission writes JSON config object to a file. 
Enforces permission on the parent directory func WriteJsonWithRestrictedPermission(ctx context.Context, file string, obj interface{}) error { configDir, configFileName, err := prepareConfigFileDir(file) @@ -82,29 +96,44 @@ func DirectWriteJson(ctx context.Context, file string, obj interface{}) error { func writeJson(ctx context.Context, file string, obj interface{}, configDir string, configFileName string) error { // Check context before expensive operations if ctx.Err() != nil { - return ctx.Err() + return fmt.Errorf("write json start: %w", ctx.Err()) } // make it pretty bs, err := json.MarshalIndent(obj, "", " ") if err != nil { - return err + return fmt.Errorf("marshal: %w", err) } + return writeBytes(ctx, file, err, configDir, configFileName, bs) +} + +func writeBytes(ctx context.Context, file string, err error, configDir string, configFileName string, bs []byte) error { if ctx.Err() != nil { - return ctx.Err() + return fmt.Errorf("write bytes start: %w", ctx.Err()) } tempFile, err := os.CreateTemp(configDir, ".*"+configFileName) if err != nil { - return err + return fmt.Errorf("create temp: %w", err) } tempFileName := tempFile.Name() - // closing file ops as windows doesn't allow to move it - err = tempFile.Close() + + if deadline, ok := ctx.Deadline(); ok { + if err := tempFile.SetDeadline(deadline); err != nil && !errors.Is(err, os.ErrNoDeadline) { + log.Warnf("failed to set deadline: %v", err) + } + } + + _, err = tempFile.Write(bs) if err != nil { - return err + _ = tempFile.Close() + return fmt.Errorf("write: %w", err) + } + + if err = tempFile.Close(); err != nil { + return fmt.Errorf("close %s: %w", tempFileName, err) } defer func() { @@ -114,19 +143,13 @@ func writeJson(ctx context.Context, file string, obj interface{}, configDir stri } }() - err = os.WriteFile(tempFileName, bs, 0600) - if err != nil { - return err - } - // Check context again if ctx.Err() != nil { - return ctx.Err() + return fmt.Errorf("after temp file: %w", ctx.Err()) } - err = os.Rename(tempFileName, file) - if err != nil { - return err + if err = os.Rename(tempFileName, file); err != nil { + return fmt.Errorf("move %s to %s: %w", tempFileName, file, err) } return nil From 582bb587140789884c8905e7ae8c8c3e732077dc Mon Sep 17 00:00:00 2001 From: Viktor Liu <17948409+lixmal@users.noreply.github.com> Date: Fri, 15 Nov 2024 22:55:33 +0100 Subject: [PATCH 015/159] Move state updates outside the refcounter (#2897) --- .../systemops/systemops_generic.go | 18 +++++++----------- 1 file changed, 7 insertions(+), 11 deletions(-) diff --git a/client/internal/routemanager/systemops/systemops_generic.go b/client/internal/routemanager/systemops/systemops_generic.go index f8b3ebbb8..3038c3ec5 100644 --- a/client/internal/routemanager/systemops/systemops_generic.go +++ b/client/internal/routemanager/systemops/systemops_generic.go @@ -57,22 +57,14 @@ func (r *SysOps) setupRefCounter(initAddresses []net.IP, stateManager *statemana return nexthop, refcounter.ErrIgnore } - r.updateState(stateManager) - return nexthop, err }, - func(prefix netip.Prefix, nexthop Nexthop) error { - // update state even if we have trouble removing it from the route table - // it could be already gone - r.updateState(stateManager) - - return r.removeFromRouteTable(prefix, nexthop) - }, + r.removeFromRouteTable, ) r.refCounter = refCounter - return r.setupHooks(initAddresses) + return r.setupHooks(initAddresses, stateManager) } // updateState updates state on every change so it will be persisted regularly @@ -333,7 +325,7 @@ func (r *SysOps) 
genericRemoveVPNRoute(prefix netip.Prefix, intf *net.Interface) return r.removeFromRouteTable(prefix, nextHop) } -func (r *SysOps) setupHooks(initAddresses []net.IP) (nbnet.AddHookFunc, nbnet.RemoveHookFunc, error) { +func (r *SysOps) setupHooks(initAddresses []net.IP, stateManager *statemanager.Manager) (nbnet.AddHookFunc, nbnet.RemoveHookFunc, error) { beforeHook := func(connID nbnet.ConnectionID, ip net.IP) error { prefix, err := util.GetPrefixFromIP(ip) if err != nil { @@ -344,6 +336,8 @@ func (r *SysOps) setupHooks(initAddresses []net.IP) (nbnet.AddHookFunc, nbnet.Re return fmt.Errorf("adding route reference: %v", err) } + r.updateState(stateManager) + return nil } afterHook := func(connID nbnet.ConnectionID) error { @@ -351,6 +345,8 @@ func (r *SysOps) setupHooks(initAddresses []net.IP) (nbnet.AddHookFunc, nbnet.Re return fmt.Errorf("remove route reference: %w", err) } + r.updateState(stateManager) + return nil } From a7d5c522033beef9174898b5e43a907b282f7341 Mon Sep 17 00:00:00 2001 From: Viktor Liu <17948409+lixmal@users.noreply.github.com> Date: Fri, 15 Nov 2024 22:59:49 +0100 Subject: [PATCH 016/159] Fix error state race on mgmt connection error (#2892) --- client/internal/connect.go | 3 ++- 1 file changed, 2 insertions(+), 1 deletion(-) diff --git a/client/internal/connect.go b/client/internal/connect.go index dff44f1d2..f76aa066b 100644 --- a/client/internal/connect.go +++ b/client/internal/connect.go @@ -157,7 +157,8 @@ func (c *ConnectClient) run(mobileDependency MobileDependency, probes *ProbeHold engineCtx, cancel := context.WithCancel(c.ctx) defer func() { - c.statusRecorder.MarkManagementDisconnected(state.err) + _, err := state.Status() + c.statusRecorder.MarkManagementDisconnected(err) c.statusRecorder.CleanLocalPeerState() cancel() }() From ec543f89fb819b4aae28850b370f0c06f05f7f96 Mon Sep 17 00:00:00 2001 From: Kursat Aktas Date: Sat, 16 Nov 2024 17:45:31 +0300 Subject: [PATCH 017/159] Introducing NetBird Guru on Gurubase.io (#2778) --- README.md | 4 ++++ 1 file changed, 4 insertions(+) diff --git a/README.md b/README.md index 270c9ad87..a2d7f3897 100644 --- a/README.md +++ b/README.md @@ -19,6 +19,10 @@
+    [Gurubase NetBird Guru badge]

From 65a94f695f09063d505f49a6b2496f7a2e4d48b4 Mon Sep 17 00:00:00 2001 From: Maycon Santos Date: Mon, 18 Nov 2024 12:55:02 +0100 Subject: [PATCH 018/159] use google domain for tests (#2902) --- client/internal/dns/server_test.go | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/client/internal/dns/server_test.go b/client/internal/dns/server_test.go index 21f1f1b7d..eab9f4ecb 100644 --- a/client/internal/dns/server_test.go +++ b/client/internal/dns/server_test.go @@ -782,7 +782,7 @@ func TestDNSPermanent_matchOnly(t *testing.T) { Port: 53, }, }, - Domains: []string{"customdomain.com"}, + Domains: []string{"google.com"}, Primary: false, }, }, @@ -804,7 +804,7 @@ func TestDNSPermanent_matchOnly(t *testing.T) { if ips[0] != zoneRecords[0].RData { t.Fatalf("invalid zone record: %v", err) } - _, err = resolver.LookupHost(context.Background(), "customdomain.com") + _, err = resolver.LookupHost(context.Background(), "google.com") if err != nil { t.Errorf("failed to resolve: %s", err) } From ec6438e643c3ef0b0388c05c2fad5beed99fb4fe Mon Sep 17 00:00:00 2001 From: bcmmbaga Date: Mon, 18 Nov 2024 17:12:13 +0300 Subject: [PATCH 019/159] Use update strength and simplify check Signed-off-by: bcmmbaga --- management/server/policy.go | 4 +--- 1 file changed, 1 insertion(+), 3 deletions(-) diff --git a/management/server/policy.go b/management/server/policy.go index 6dcb96316..693ae2872 100644 --- a/management/server/policy.go +++ b/management/server/policy.go @@ -435,7 +435,7 @@ func (am *DefaultAccountManager) DeletePolicy(ctx context.Context, accountID, po var updateAccountPeers bool err = am.Store.ExecuteInTransaction(ctx, func(transaction Store) error { - policy, err = transaction.GetPolicyByID(ctx, LockingStrengthShare, accountID, policyID) + policy, err = transaction.GetPolicyByID(ctx, LockingStrengthUpdate, accountID, policyID) if err != nil { return err } @@ -502,8 +502,6 @@ func arePolicyChangesAffectPeers(ctx context.Context, transaction Store, account if hasPeers { return true, nil } - - return anyGroupHasPeers(ctx, transaction, policy.AccountID, policy.ruleGroups()) } return anyGroupHasPeers(ctx, transaction, policy.AccountID, policy.ruleGroups()) From 78fab877c07ee50aae95ed037218ec41f0bf2489 Mon Sep 17 00:00:00 2001 From: Maycon Santos Date: Mon, 18 Nov 2024 15:31:53 +0100 Subject: [PATCH 020/159] [misc] Update signing pipeline version (#2900) --- .github/workflows/release.yml | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/.github/workflows/release.yml b/.github/workflows/release.yml index 14e383a27..183cdb02c 100644 --- a/.github/workflows/release.yml +++ b/.github/workflows/release.yml @@ -9,7 +9,7 @@ on: pull_request: env: - SIGN_PIPE_VER: "v0.0.16" + SIGN_PIPE_VER: "v0.0.17" GORELEASER_VER: "v2.3.2" PRODUCT_NAME: "NetBird" COPYRIGHT: "Wiretrustee UG (haftungsbeschreankt)" From df98c67ac8100ac75995fecfaa32e76691db36c0 Mon Sep 17 00:00:00 2001 From: bcmmbaga Date: Mon, 18 Nov 2024 18:46:52 +0300 Subject: [PATCH 021/159] prevent changing ruleID when not empty Signed-off-by: bcmmbaga --- management/server/http/policies_handler.go | 7 ++++++- management/server/sql_store_test.go | 2 ++ 2 files changed, 8 insertions(+), 1 deletion(-) diff --git a/management/server/http/policies_handler.go b/management/server/http/policies_handler.go index 8255e4896..ca256a183 100644 --- a/management/server/http/policies_handler.go +++ b/management/server/http/policies_handler.go @@ -128,8 +128,13 @@ func (h *Policies) savePolicy(w http.ResponseWriter, r *http.Request, 
accountID Description: req.Description, } for _, rule := range req.Rules { + ruleID := policyID // TODO: when policy can contain multiple rules, need refactor + if rule.Id != nil { + ruleID = *rule.Id + } + pr := server.PolicyRule{ - ID: policyID, // TODO: when policy can contain multiple rules, need refactor + ID: ruleID, PolicyID: policyID, Name: rule.Name, Destinations: rule.Destinations, diff --git a/management/server/sql_store_test.go b/management/server/sql_store_test.go index 8931008d7..c05793fc6 100644 --- a/management/server/sql_store_test.go +++ b/management/server/sql_store_test.go @@ -1832,6 +1832,8 @@ func TestSqlStore_SavePolicy(t *testing.T) { policy.Enabled = false policy.Description = "policy" + policy.Rules[0].Sources = []string{"group"} + policy.Rules[0].Ports = []string{"80", "443"} err = store.SavePolicy(context.Background(), LockingStrengthUpdate, policy) require.NoError(t, err) From b60e2c32614615d5f7c44d95794338ade30d9287 Mon Sep 17 00:00:00 2001 From: bcmmbaga Date: Mon, 18 Nov 2024 22:48:38 +0300 Subject: [PATCH 022/159] prevent duplicate rules during updates Signed-off-by: bcmmbaga --- management/server/http/policies_handler.go | 2 +- management/server/policy.go | 2 +- 2 files changed, 2 insertions(+), 2 deletions(-) diff --git a/management/server/http/policies_handler.go b/management/server/http/policies_handler.go index ca256a183..eff9092d4 100644 --- a/management/server/http/policies_handler.go +++ b/management/server/http/policies_handler.go @@ -128,7 +128,7 @@ func (h *Policies) savePolicy(w http.ResponseWriter, r *http.Request, accountID Description: req.Description, } for _, rule := range req.Rules { - ruleID := policyID // TODO: when policy can contain multiple rules, need refactor + var ruleID string if rule.Id != nil { ruleID = *rule.Id } diff --git a/management/server/policy.go b/management/server/policy.go index 693ae2872..2d3abc3f1 100644 --- a/management/server/policy.go +++ b/management/server/policy.go @@ -532,7 +532,7 @@ func validatePolicy(ctx context.Context, transaction Store, accountID string, po for i, rule := range policy.Rules { ruleCopy := rule.Copy() if ruleCopy.ID == "" { - ruleCopy.ID = xid.New().String() + ruleCopy.ID = policy.ID // TODO: when policy can contain multiple rules, need refactor ruleCopy.PolicyID = policy.ID } From 52ea2e84e9aa3c03fe43c5098ab50c94ff2e0818 Mon Sep 17 00:00:00 2001 From: Pascal Fischer <32096965+pascal-fischer@users.noreply.github.com> Date: Tue, 19 Nov 2024 00:04:50 +0100 Subject: [PATCH 023/159] [management] Add transaction metrics and exclude getAccount time from peers update (#2904) --- management/server/peer.go | 11 ++++++----- management/server/sql_store.go | 11 ++++++++++- management/server/telemetry/store_metrics.go | 12 ++++++++++++ 3 files changed, 28 insertions(+), 6 deletions(-) diff --git a/management/server/peer.go b/management/server/peer.go index 1405dead8..8c45e45c9 100644 --- a/management/server/peer.go +++ b/management/server/peer.go @@ -988,6 +988,12 @@ func (am *DefaultAccountManager) GetPeer(ctx context.Context, accountID, peerID, // updateAccountPeers updates all peers that belong to an account. // Should be called when changes have to be synced to peers. 
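// In the hunk below the account lookup (requestBuffer.GetAccountWithBackpressure) is moved
// ahead of start := time.Now() and the deferred metrics callback, so the recorded
// update-peers duration covers only the fan-out to peers and no longer includes the time
// spent fetching the account.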
 func (am *DefaultAccountManager) updateAccountPeers(ctx context.Context, accountID string) {
+	account, err := am.requestBuffer.GetAccountWithBackpressure(ctx, accountID)
+	if err != nil {
+		log.WithContext(ctx).Errorf("failed to send out updates to peers: %v", err)
+		return
+	}
+
 	start := time.Now()
 	defer func() {
 		if am.metrics != nil {
@@ -995,11 +1001,6 @@ func (am *DefaultAccountManager) updateAccountPeers(ctx context.Context, account
 		}
 	}()
 
-	account, err := am.requestBuffer.GetAccountWithBackpressure(ctx, accountID)
-	if err != nil {
-		log.WithContext(ctx).Errorf("failed to send out updates to peers: %v", err)
-		return
-	}
 	peers := account.GetPeers()
 
 	approvedPeersMap, err := am.GetValidatedPeers(account)
diff --git a/management/server/sql_store.go b/management/server/sql_store.go
index 0ebda6440..278f5443d 100644
--- a/management/server/sql_store.go
+++ b/management/server/sql_store.go
@@ -1123,6 +1123,7 @@ func (s *SqlStore) IncrementNetworkSerial(ctx context.Context, lockStrength Lock
 }
 
 func (s *SqlStore) ExecuteInTransaction(ctx context.Context, operation func(store Store) error) error {
+	startTime := time.Now()
 	tx := s.db.Begin()
 	if tx.Error != nil {
 		return tx.Error
@@ -1133,7 +1134,15 @@ func (s *SqlStore) ExecuteInTransaction(ctx context.Context, operation func(stor
 		tx.Rollback()
 		return err
 	}
-	return tx.Commit().Error
+
+	err = tx.Commit().Error
+
+	log.WithContext(ctx).Tracef("transaction took %v", time.Since(startTime))
+	if s.metrics != nil {
+		s.metrics.StoreMetrics().CountTransactionDuration(time.Since(startTime))
+	}
+
+	return err
 }
 
 func (s *SqlStore) withTx(tx *gorm.DB) Store {
diff --git a/management/server/telemetry/store_metrics.go b/management/server/telemetry/store_metrics.go
index b038c3d36..bb3745b5a 100644
--- a/management/server/telemetry/store_metrics.go
+++ b/management/server/telemetry/store_metrics.go
@@ -13,6 +13,7 @@ type StoreMetrics struct {
 	globalLockAcquisitionDurationMs metric.Int64Histogram
 	persistenceDurationMicro metric.Int64Histogram
 	persistenceDurationMs metric.Int64Histogram
+	transactionDurationMs metric.Int64Histogram
 	ctx context.Context
 }
 
@@ -40,11 +41,17 @@ func NewStoreMetrics(ctx context.Context, meter metric.Meter) (*StoreMetrics, er
 		return nil, err
 	}
 
+	transactionDurationMs, err := meter.Int64Histogram("management.store.transaction.duration.ms")
+	if err != nil {
+		return nil, err
+	}
+
 	return &StoreMetrics{
 		globalLockAcquisitionDurationMicro: globalLockAcquisitionDurationMicro,
 		globalLockAcquisitionDurationMs: globalLockAcquisitionDurationMs,
 		persistenceDurationMicro: persistenceDurationMicro,
 		persistenceDurationMs: persistenceDurationMs,
+		transactionDurationMs: transactionDurationMs,
 		ctx: ctx,
 	}, nil
 }
@@ -60,3 +67,8 @@ func (metrics *StoreMetrics) CountPersistenceDuration(duration time.Duration) {
 	metrics.persistenceDurationMicro.Record(metrics.ctx, duration.Microseconds())
 	metrics.persistenceDurationMs.Record(metrics.ctx, duration.Milliseconds())
 }
+
+// CountTransactionDuration counts the duration of a store transaction
+func (metrics *StoreMetrics) CountTransactionDuration(duration time.Duration) {
+	metrics.transactionDurationMs.Record(metrics.ctx, duration.Milliseconds())
+}

From eb5d0569ae0ce829a312e13ab3e9757b9cdf019f Mon Sep 17 00:00:00 2001
From: "Krzysztof Nazarewski (kdn)"
Date: Tue, 19 Nov 2024 14:14:58 +0100
Subject: [PATCH 024/159] [client] Add NB_SKIP_SOCKET_MARK & fix crash instead of returning an error (#2899)

* dialer: fix crash instead of returning error

* add NB_SKIP_SOCKET_MARK
---
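The two bullet points above describe a pattern that is easy to miss in the diff below: have the custom gRPC dialer return a status error instead of calling log.Fatalf, and allow the SO_MARK handling to be skipped via an environment variable. A minimal, self-contained sketch of that pattern, not the NetBird code itself (the helper names here are invented for illustration):

    package main

    import (
    	"fmt"
    	"os"
    	"os/user"

    	"google.golang.org/grpc/codes"
    	"google.golang.org/grpc/status"
    )

    // skipSocketMark reports whether the NB_SKIP_SOCKET_MARK switch is set,
    // mirroring how the patch gates SO_MARK behaviour behind an env var.
    func skipSocketMark() bool {
    	return os.Getenv("NB_SKIP_SOCKET_MARK") == "true"
    }

    // currentUID returns the caller's UID or a gRPC status error; returning
    // the error (instead of log.Fatalf) lets the caller handle it rather
    // than crash the whole process.
    func currentUID() (string, error) {
    	u, err := user.Current()
    	if err != nil {
    		return "", status.Errorf(codes.FailedPrecondition, "failed to get current user: %v", err)
    	}
    	return u.Uid, nil
    }

    func main() {
    	fmt.Println("skip socket mark:", skipSocketMark())

    	uid, err := currentUID()
    	if err != nil {
    		fmt.Println("lookup failed:", err)
    		return
    	}
    	fmt.Println("uid:", uid)
    }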
.../routemanager/systemops/systemops_linux.go | 2 +- util/grpc/dialer.go | 9 +++++++-- util/net/dialer_nonios.go | 2 +- util/net/net_linux.go | 12 ++++++++++++ 4 files changed, 21 insertions(+), 4 deletions(-) diff --git a/client/internal/routemanager/systemops/systemops_linux.go b/client/internal/routemanager/systemops/systemops_linux.go index 0124fd95e..71a0f26ae 100644 --- a/client/internal/routemanager/systemops/systemops_linux.go +++ b/client/internal/routemanager/systemops/systemops_linux.go @@ -55,7 +55,7 @@ type ruleParams struct { // isLegacy determines whether to use the legacy routing setup func isLegacy() bool { - return os.Getenv("NB_USE_LEGACY_ROUTING") == "true" || nbnet.CustomRoutingDisabled() + return os.Getenv("NB_USE_LEGACY_ROUTING") == "true" || nbnet.CustomRoutingDisabled() || os.Getenv(nbnet.EnvSkipSocketMark) == "true" } // setIsLegacy sets the legacy routing setup diff --git a/util/grpc/dialer.go b/util/grpc/dialer.go index 57ab8fd55..4fbffe342 100644 --- a/util/grpc/dialer.go +++ b/util/grpc/dialer.go @@ -3,6 +3,9 @@ package grpc import ( "context" "crypto/tls" + "fmt" + "google.golang.org/grpc/codes" + "google.golang.org/grpc/status" "net" "os/user" "runtime" @@ -23,20 +26,22 @@ func WithCustomDialer() grpc.DialOption { if runtime.GOOS == "linux" { currentUser, err := user.Current() if err != nil { - log.Fatalf("failed to get current user: %v", err) + return nil, status.Errorf(codes.FailedPrecondition, "failed to get current user: %v", err) } // the custom dialer requires root permissions which are not required for use cases run as non-root if currentUser.Uid != "0" { + log.Debug("Not running as root, using standard dialer") dialer := &net.Dialer{} return dialer.DialContext(ctx, "tcp", addr) } } + log.Debug("Using nbnet.NewDialer()") conn, err := nbnet.NewDialer().DialContext(ctx, "tcp", addr) if err != nil { log.Errorf("Failed to dial: %s", err) - return nil, err + return nil, fmt.Errorf("nbnet.NewDialer().DialContext: %w", err) } return conn, nil }) diff --git a/util/net/dialer_nonios.go b/util/net/dialer_nonios.go index 4032a75c0..34004a368 100644 --- a/util/net/dialer_nonios.go +++ b/util/net/dialer_nonios.go @@ -69,7 +69,7 @@ func (d *Dialer) DialContext(ctx context.Context, network, address string) (net. 
conn, err := d.Dialer.DialContext(ctx, network, address) if err != nil { - return nil, fmt.Errorf("dial: %w", err) + return nil, fmt.Errorf("d.Dialer.DialContext: %w", err) } // Wrap the connection in Conn to handle Close with hooks diff --git a/util/net/net_linux.go b/util/net/net_linux.go index 954545eb5..98f49af8d 100644 --- a/util/net/net_linux.go +++ b/util/net/net_linux.go @@ -4,9 +4,14 @@ package net import ( "fmt" + "os" "syscall" + + log "github.com/sirupsen/logrus" ) +const EnvSkipSocketMark = "NB_SKIP_SOCKET_MARK" + // SetSocketMark sets the SO_MARK option on the given socket connection func SetSocketMark(conn syscall.Conn) error { sysconn, err := conn.SyscallConn() @@ -36,6 +41,13 @@ func SetRawSocketMark(conn syscall.RawConn) error { func SetSocketOpt(fd int) error { if CustomRoutingDisabled() { + log.Infof("Custom routing is disabled, skipping SO_MARK") + return nil + } + + // Check for the new environment variable + if skipSocketMark := os.Getenv(EnvSkipSocketMark); skipSocketMark == "true" { + log.Info("NB_SKIP_SOCKET_MARK is set to true, skipping SO_MARK") return nil } From 5dd6a08ea6926cb1cb87a3301524b9031f4db61a Mon Sep 17 00:00:00 2001 From: Pascal Fischer <32096965+pascal-fischer@users.noreply.github.com> Date: Tue, 19 Nov 2024 17:25:49 +0100 Subject: [PATCH 025/159] link peer meta update back to account object (#2911) --- management/server/peer.go | 1 + 1 file changed, 1 insertion(+) diff --git a/management/server/peer.go b/management/server/peer.go index 8c45e45c9..901e4815d 100644 --- a/management/server/peer.go +++ b/management/server/peer.go @@ -667,6 +667,7 @@ func (am *DefaultAccountManager) SyncPeer(ctx context.Context, sync PeerSync, ac updated := peer.UpdateMetaIfNew(sync.Meta) if updated { + account.Peers[peer.ID] = peer log.WithContext(ctx).Tracef("peer %s metadata updated", peer.ID) err = am.Store.SavePeer(ctx, account.Id, peer) if err != nil { From f66bbcc54c65f0856679fed1b50298e97c0ec7af Mon Sep 17 00:00:00 2001 From: Pascal Fischer <32096965+pascal-fischer@users.noreply.github.com> Date: Tue, 19 Nov 2024 18:13:26 +0100 Subject: [PATCH 026/159] [management] Add metric for peer meta update (#2913) --- management/server/peer.go | 2 ++ .../server/telemetry/accountmanager_metrics.go | 12 ++++++++++++ 2 files changed, 14 insertions(+) diff --git a/management/server/peer.go b/management/server/peer.go index 901e4815d..beb833dba 100644 --- a/management/server/peer.go +++ b/management/server/peer.go @@ -667,6 +667,7 @@ func (am *DefaultAccountManager) SyncPeer(ctx context.Context, sync PeerSync, ac updated := peer.UpdateMetaIfNew(sync.Meta) if updated { + am.metrics.AccountManagerMetrics().CountPeerMetUpdate() account.Peers[peer.ID] = peer log.WithContext(ctx).Tracef("peer %s metadata updated", peer.ID) err = am.Store.SavePeer(ctx, account.Id, peer) @@ -801,6 +802,7 @@ func (am *DefaultAccountManager) LoginPeer(ctx context.Context, login PeerLogin) updated := peer.UpdateMetaIfNew(login.Meta) if updated { + am.metrics.AccountManagerMetrics().CountPeerMetUpdate() shouldStorePeer = true } diff --git a/management/server/telemetry/accountmanager_metrics.go b/management/server/telemetry/accountmanager_metrics.go index e4bb4e3c3..4a5a31e2d 100644 --- a/management/server/telemetry/accountmanager_metrics.go +++ b/management/server/telemetry/accountmanager_metrics.go @@ -13,6 +13,7 @@ type AccountManagerMetrics struct { updateAccountPeersDurationMs metric.Float64Histogram getPeerNetworkMapDurationMs metric.Float64Histogram networkMapObjectCount metric.Int64Histogram + 
peerMetaUpdateCount metric.Int64Counter } // NewAccountManagerMetrics creates an instance of AccountManagerMetrics @@ -44,11 +45,17 @@ func NewAccountManagerMetrics(ctx context.Context, meter metric.Meter) (*Account return nil, err } + peerMetaUpdateCount, err := meter.Int64Counter("management.account.peer.meta.update.counter", metric.WithUnit("1")) + if err != nil { + return nil, err + } + return &AccountManagerMetrics{ ctx: ctx, getPeerNetworkMapDurationMs: getPeerNetworkMapDurationMs, updateAccountPeersDurationMs: updateAccountPeersDurationMs, networkMapObjectCount: networkMapObjectCount, + peerMetaUpdateCount: peerMetaUpdateCount, }, nil } @@ -67,3 +74,8 @@ func (metrics *AccountManagerMetrics) CountGetPeerNetworkMapDuration(duration ti func (metrics *AccountManagerMetrics) CountNetworkMapObjects(count int64) { metrics.networkMapObjectCount.Record(metrics.ctx, count) } + +// CountPeerMetUpdate counts the number of peer meta updates +func (metrics *AccountManagerMetrics) CountPeerMetUpdate() { + metrics.peerMetaUpdateCount.Add(metrics.ctx, 1) +} From aa575d6f445e74f34f8353a4c413adc209c56f4b Mon Sep 17 00:00:00 2001 From: Pascal Fischer <32096965+pascal-fischer@users.noreply.github.com> Date: Thu, 21 Nov 2024 15:10:34 +0100 Subject: [PATCH 027/159] [management] Add activity events to group propagation flow (#2916) --- management/server/account.go | 38 ++++++++++- management/server/activity/codes.go | 6 ++ management/server/user.go | 97 ++++++++++++++++++++++------- 3 files changed, 116 insertions(+), 25 deletions(-) diff --git a/management/server/account.go b/management/server/account.go index 95c93a22b..0ab123655 100644 --- a/management/server/account.go +++ b/management/server/account.go @@ -965,7 +965,9 @@ func (am *DefaultAccountManager) getJWTGroupsChanges(user *User, groups []*nbgro } // UserGroupsAddToPeers adds groups to all peers of user -func (a *Account) UserGroupsAddToPeers(userID string, groups ...string) { +func (a *Account) UserGroupsAddToPeers(userID string, groups ...string) map[string][]string { + groupUpdates := make(map[string][]string) + userPeers := make(map[string]struct{}) for pid, peer := range a.Peers { if peer.UserID == userID { @@ -979,6 +981,8 @@ func (a *Account) UserGroupsAddToPeers(userID string, groups ...string) { continue } + oldPeers := group.Peers + groupPeers := make(map[string]struct{}) for _, pid := range group.Peers { groupPeers[pid] = struct{}{} @@ -992,16 +996,25 @@ func (a *Account) UserGroupsAddToPeers(userID string, groups ...string) { for pid := range groupPeers { group.Peers = append(group.Peers, pid) } + + groupUpdates[gid] = difference(group.Peers, oldPeers) } + + return groupUpdates } // UserGroupsRemoveFromPeers removes groups from all peers of user -func (a *Account) UserGroupsRemoveFromPeers(userID string, groups ...string) { +func (a *Account) UserGroupsRemoveFromPeers(userID string, groups ...string) map[string][]string { + groupUpdates := make(map[string][]string) + for _, gid := range groups { group, ok := a.Groups[gid] if !ok || group.Name == "All" { continue } + + oldPeers := group.Peers + update := make([]string, 0, len(group.Peers)) for _, pid := range group.Peers { peer, ok := a.Peers[pid] @@ -1013,7 +1026,10 @@ func (a *Account) UserGroupsRemoveFromPeers(userID string, groups ...string) { } } group.Peers = update + groupUpdates[gid] = difference(oldPeers, group.Peers) } + + return groupUpdates } // BuildManager creates a new DefaultAccountManager with a provided Store @@ -1175,6 +1191,11 @@ func (am *DefaultAccountManager) 
UpdateAccountSettings(ctx context.Context, acco return nil, err } + err = am.handleGroupsPropagationSettings(ctx, oldSettings, newSettings, userID, accountID) + if err != nil { + return nil, fmt.Errorf("groups propagation failed: %w", err) + } + updatedAccount := account.UpdateSettings(newSettings) err = am.Store.SaveAccount(ctx, account) @@ -1185,6 +1206,19 @@ func (am *DefaultAccountManager) UpdateAccountSettings(ctx context.Context, acco return updatedAccount, nil } +func (am *DefaultAccountManager) handleGroupsPropagationSettings(ctx context.Context, oldSettings, newSettings *Settings, userID, accountID string) error { + if oldSettings.GroupsPropagationEnabled != newSettings.GroupsPropagationEnabled { + if newSettings.GroupsPropagationEnabled { + am.StoreEvent(ctx, userID, accountID, accountID, activity.UserGroupPropagationEnabled, nil) + // Todo: retroactively add user groups to all peers + } else { + am.StoreEvent(ctx, userID, accountID, accountID, activity.UserGroupPropagationDisabled, nil) + } + } + + return nil +} + func (am *DefaultAccountManager) handleInactivityExpirationSettings(ctx context.Context, account *Account, oldSettings, newSettings *Settings, userID, accountID string) error { if newSettings.PeerInactivityExpirationEnabled { diff --git a/management/server/activity/codes.go b/management/server/activity/codes.go index 603260dbc..4c57d65fb 100644 --- a/management/server/activity/codes.go +++ b/management/server/activity/codes.go @@ -148,6 +148,9 @@ const ( AccountPeerInactivityExpirationDurationUpdated Activity = 67 SetupKeyDeleted Activity = 68 + + UserGroupPropagationEnabled Activity = 69 + UserGroupPropagationDisabled Activity = 70 ) var activityMap = map[Activity]Code{ @@ -222,6 +225,9 @@ var activityMap = map[Activity]Code{ AccountPeerInactivityExpirationDisabled: {"Account peer inactivity expiration disabled", "account.peer.inactivity.expiration.disable"}, AccountPeerInactivityExpirationDurationUpdated: {"Account peer inactivity expiration duration updated", "account.peer.inactivity.expiration.update"}, SetupKeyDeleted: {"Setup key deleted", "setupkey.delete"}, + + UserGroupPropagationEnabled: {"User group propagation enabled", "account.setting.group.propagation.enable"}, + UserGroupPropagationDisabled: {"User group propagation disabled", "account.setting.group.propagation.disable"}, } // StringCode returns a string code of the activity diff --git a/management/server/user.go b/management/server/user.go index 74062112a..edb5e6fd3 100644 --- a/management/server/user.go +++ b/management/server/user.go @@ -805,15 +805,20 @@ func (am *DefaultAccountManager) SaveOrAddUsers(ctx context.Context, accountID, expiredPeers = append(expiredPeers, blockedPeers...) } + peerGroupsAdded := make(map[string][]string) + peerGroupsRemoved := make(map[string][]string) if update.AutoGroups != nil && account.Settings.GroupsPropagationEnabled { removedGroups := difference(oldUser.AutoGroups, update.AutoGroups) // need force update all auto groups in any case they will not be duplicated - account.UserGroupsAddToPeers(oldUser.Id, update.AutoGroups...) - account.UserGroupsRemoveFromPeers(oldUser.Id, removedGroups...) + peerGroupsAdded = account.UserGroupsAddToPeers(oldUser.Id, update.AutoGroups...) + peerGroupsRemoved = account.UserGroupsRemoveFromPeers(oldUser.Id, removedGroups...) } - events := am.prepareUserUpdateEvents(ctx, initiatorUser.Id, oldUser, newUser, account, transferredOwnerRole) - eventsToStore = append(eventsToStore, events...) 
+ userUpdateEvents := am.prepareUserUpdateEvents(ctx, initiatorUser.Id, oldUser, newUser, account, transferredOwnerRole) + eventsToStore = append(eventsToStore, userUpdateEvents...) + + userGroupsEvents := am.prepareUserGroupsEvents(ctx, initiatorUser.Id, oldUser, newUser, account, peerGroupsAdded, peerGroupsRemoved) + eventsToStore = append(eventsToStore, userGroupsEvents...) updatedUserInfo, err := getUserInfo(ctx, am, newUser, account) if err != nil { @@ -872,32 +877,78 @@ func (am *DefaultAccountManager) prepareUserUpdateEvents(ctx context.Context, in }) } + return eventsToStore +} + +func (am *DefaultAccountManager) prepareUserGroupsEvents(ctx context.Context, initiatorUserID string, oldUser, newUser *User, account *Account, peerGroupsAdded, peerGroupsRemoved map[string][]string) []func() { + var eventsToStore []func() if newUser.AutoGroups != nil { removedGroups := difference(oldUser.AutoGroups, newUser.AutoGroups) addedGroups := difference(newUser.AutoGroups, oldUser.AutoGroups) - for _, g := range removedGroups { - group := account.GetGroup(g) - if group != nil { - eventsToStore = append(eventsToStore, func() { - am.StoreEvent(ctx, initiatorUserID, oldUser.Id, account.Id, activity.GroupRemovedFromUser, - map[string]any{"group": group.Name, "group_id": group.ID, "is_service_user": newUser.IsServiceUser, "user_name": newUser.ServiceUserName}) - }) - } else { - log.WithContext(ctx).Errorf("group %s not found while saving user activity event of account %s", g, account.Id) - } - } - for _, g := range addedGroups { - group := account.GetGroup(g) - if group != nil { - eventsToStore = append(eventsToStore, func() { - am.StoreEvent(ctx, initiatorUserID, oldUser.Id, account.Id, activity.GroupAddedToUser, - map[string]any{"group": group.Name, "group_id": group.ID, "is_service_user": newUser.IsServiceUser, "user_name": newUser.ServiceUserName}) - }) - } + removedEvents := am.handleGroupRemovedFromUser(ctx, initiatorUserID, oldUser, newUser, account, removedGroups, peerGroupsRemoved) + eventsToStore = append(eventsToStore, removedEvents...) + + addedEvents := am.handleGroupAddedToUser(ctx, initiatorUserID, oldUser, newUser, account, addedGroups, peerGroupsAdded) + eventsToStore = append(eventsToStore, addedEvents...) 
+ } + return eventsToStore +} + +func (am *DefaultAccountManager) handleGroupAddedToUser(ctx context.Context, initiatorUserID string, oldUser, newUser *User, account *Account, addedGroups []string, peerGroupsAdded map[string][]string) []func() { + var eventsToStore []func() + for _, g := range addedGroups { + group := account.GetGroup(g) + if group != nil { + eventsToStore = append(eventsToStore, func() { + am.StoreEvent(ctx, initiatorUserID, oldUser.Id, account.Id, activity.GroupAddedToUser, + map[string]any{"group": group.Name, "group_id": group.ID, "is_service_user": newUser.IsServiceUser, "user_name": newUser.ServiceUserName}) + }) } } + for groupID, peerIDs := range peerGroupsAdded { + group := account.GetGroup(groupID) + for _, peerID := range peerIDs { + peer := account.GetPeer(peerID) + eventsToStore = append(eventsToStore, func() { + meta := map[string]any{ + "group": group.Name, "group_id": group.ID, + "peer_ip": peer.IP.String(), "peer_fqdn": peer.FQDN(am.GetDNSDomain()), + } + am.StoreEvent(ctx, activity.SystemInitiator, peer.ID, account.Id, activity.GroupAddedToPeer, meta) + }) + } + } + return eventsToStore +} +func (am *DefaultAccountManager) handleGroupRemovedFromUser(ctx context.Context, initiatorUserID string, oldUser, newUser *User, account *Account, removedGroups []string, peerGroupsRemoved map[string][]string) []func() { + var eventsToStore []func() + for _, g := range removedGroups { + group := account.GetGroup(g) + if group != nil { + eventsToStore = append(eventsToStore, func() { + am.StoreEvent(ctx, initiatorUserID, oldUser.Id, account.Id, activity.GroupRemovedFromUser, + map[string]any{"group": group.Name, "group_id": group.ID, "is_service_user": newUser.IsServiceUser, "user_name": newUser.ServiceUserName}) + }) + + } else { + log.WithContext(ctx).Errorf("group %s not found while saving user activity event of account %s", g, account.Id) + } + } + for groupID, peerIDs := range peerGroupsRemoved { + group := account.GetGroup(groupID) + for _, peerID := range peerIDs { + peer := account.GetPeer(peerID) + eventsToStore = append(eventsToStore, func() { + meta := map[string]any{ + "group": group.Name, "group_id": group.ID, + "peer_ip": peer.IP.String(), "peer_fqdn": peer.FQDN(am.GetDNSDomain()), + } + am.StoreEvent(ctx, activity.SystemInitiator, peer.ID, account.Id, activity.GroupRemovedFromPeer, meta) + }) + } + } return eventsToStore } From 82746d93ee93c5f1cbcd3d36f28d9dc89377e6b4 Mon Sep 17 00:00:00 2001 From: bcmmbaga Date: Thu, 21 Nov 2024 17:15:07 +0300 Subject: [PATCH 028/159] Use UTC time in test Signed-off-by: bcmmbaga --- management/server/sql_store_test.go | 12 ++++++------ 1 file changed, 6 insertions(+), 6 deletions(-) diff --git a/management/server/sql_store_test.go b/management/server/sql_store_test.go index 39c42bfa6..ad1a39cda 100644 --- a/management/server/sql_store_test.go +++ b/management/server/sql_store_test.go @@ -392,7 +392,7 @@ func TestSqlite_SavePeer(t *testing.T) { IP: net.IP{127, 0, 0, 1}, Meta: nbpeer.PeerSystemMeta{Hostname: "testingpeer"}, Name: "peer name", - Status: &nbpeer.PeerStatus{Connected: true, LastSeen: time.Now().Local()}, + Status: &nbpeer.PeerStatus{Connected: true, LastSeen: time.Now().UTC()}, } ctx := context.Background() err = store.SavePeer(ctx, LockingStrengthUpdate, account.Id, peer) @@ -422,7 +422,7 @@ func TestSqlite_SavePeer(t *testing.T) { assert.Equal(t, updatedPeer.Status.Connected, actual.Status.Connected) assert.Equal(t, updatedPeer.Status.LoginExpired, actual.Status.LoginExpired) assert.Equal(t, 
updatedPeer.Status.RequiresApproval, actual.Status.RequiresApproval) - assert.WithinDurationf(t, updatedPeer.Status.LastSeen, actual.Status.LastSeen, time.Millisecond, "LastSeen should be equal") + assert.WithinDurationf(t, updatedPeer.Status.LastSeen, actual.Status.LastSeen.UTC(), time.Millisecond, "LastSeen should be equal") } func TestSqlite_SavePeerStatus(t *testing.T) { @@ -434,7 +434,7 @@ func TestSqlite_SavePeerStatus(t *testing.T) { require.NoError(t, err) // save status of non-existing peer - newStatus := nbpeer.PeerStatus{Connected: false, LastSeen: time.Now().Local()} + newStatus := nbpeer.PeerStatus{Connected: false, LastSeen: time.Now().UTC()} err = store.SavePeerStatus(context.Background(), LockingStrengthUpdate, account.Id, "non-existing-peer", newStatus) assert.Error(t, err) parsedErr, ok := status.FromError(err) @@ -448,7 +448,7 @@ func TestSqlite_SavePeerStatus(t *testing.T) { IP: net.IP{127, 0, 0, 1}, Meta: nbpeer.PeerSystemMeta{}, Name: "peer name", - Status: &nbpeer.PeerStatus{Connected: true, LastSeen: time.Now().Local()}, + Status: &nbpeer.PeerStatus{Connected: true, LastSeen: time.Now().UTC()}, } err = store.SaveAccount(context.Background(), account) @@ -464,7 +464,7 @@ func TestSqlite_SavePeerStatus(t *testing.T) { assert.Equal(t, newStatus.Connected, actual.Connected) assert.Equal(t, newStatus.LoginExpired, actual.LoginExpired) assert.Equal(t, newStatus.RequiresApproval, actual.RequiresApproval) - assert.WithinDurationf(t, newStatus.LastSeen, actual.LastSeen, time.Millisecond, "LastSeen should be equal") + assert.WithinDurationf(t, newStatus.LastSeen, actual.LastSeen.UTC(), time.Millisecond, "LastSeen should be equal") newStatus.Connected = true @@ -478,7 +478,7 @@ func TestSqlite_SavePeerStatus(t *testing.T) { assert.Equal(t, newStatus.Connected, actual.Connected) assert.Equal(t, newStatus.LoginExpired, actual.LoginExpired) assert.Equal(t, newStatus.RequiresApproval, actual.RequiresApproval) - assert.WithinDurationf(t, newStatus.LastSeen, actual.LastSeen, time.Millisecond, "LastSeen should be equal") + assert.WithinDurationf(t, newStatus.LastSeen, actual.LastSeen.UTC(), time.Millisecond, "LastSeen should be equal") } func TestSqlite_SavePeerLocation(t *testing.T) { From 1bbabf70b057c2384a077742b9e6760161e153aa Mon Sep 17 00:00:00 2001 From: Viktor Liu <17948409+lixmal@users.noreply.github.com> Date: Thu, 21 Nov 2024 16:53:37 +0100 Subject: [PATCH 029/159] [client] Fix allow netbird rule verdict (#2925) * Fix allow netbird rule verdict * Fix chain name --- client/firewall/nftables/manager_linux.go | 8 +++++--- 1 file changed, 5 insertions(+), 3 deletions(-) diff --git a/client/firewall/nftables/manager_linux.go b/client/firewall/nftables/manager_linux.go index 3f8fac249..8e1aa0d80 100644 --- a/client/firewall/nftables/manager_linux.go +++ b/client/firewall/nftables/manager_linux.go @@ -199,7 +199,7 @@ func (m *Manager) AllowNetbird() error { var chain *nftables.Chain for _, c := range chains { - if c.Table.Name == tableNameFilter && c.Name == chainNameForward { + if c.Table.Name == tableNameFilter && c.Name == chainNameInput { chain = c break } @@ -276,7 +276,7 @@ func (m *Manager) resetNetbirdInputRules() error { func (m *Manager) deleteNetbirdInputRules(chains []*nftables.Chain) { for _, c := range chains { - if c.Table.Name == "filter" && c.Name == "INPUT" { + if c.Table.Name == tableNameFilter && c.Name == chainNameInput { rules, err := m.rConn.GetRules(c.Table, c) if err != nil { log.Errorf("get rules for chain %q: %v", c.Name, err) @@ -351,7 +351,9 @@ func 
(m *Manager) applyAllowNetbirdRules(chain *nftables.Chain) { Register: 1, Data: ifname(m.wgIface.Name()), }, - &expr.Verdict{}, + &expr.Verdict{ + Kind: expr.VerdictAccept, + }, }, UserData: []byte(allowNetbirdInputRuleID), } From 9db1932664557da94ff64bfc03f5acdcf30667ea Mon Sep 17 00:00:00 2001 From: Pascal Fischer <32096965+pascal-fischer@users.noreply.github.com> Date: Fri, 22 Nov 2024 10:15:51 +0100 Subject: [PATCH 030/159] [management] Fix getSetupKey call (#2927) --- management/server/http/api/openapi.yml | 30 +++++++-- management/server/http/api/types.gen.go | 89 ++++++++++++++++++++++++- management/server/setupkey.go | 2 +- management/server/setupkey_test.go | 43 ++++++++---- 4 files changed, 144 insertions(+), 20 deletions(-) diff --git a/management/server/http/api/openapi.yml b/management/server/http/api/openapi.yml index bfb375277..2e084f6e4 100644 --- a/management/server/http/api/openapi.yml +++ b/management/server/http/api/openapi.yml @@ -439,17 +439,13 @@ components: example: 5 required: - accessible_peers_count - SetupKey: + SetupKeyBase: type: object properties: id: description: Setup Key ID type: string example: 2531583362 - key: - description: Setup Key value - type: string - example: A616097E-FCF0-48FA-9354-CA4A61142761 name: description: Setup key name identifier type: string @@ -518,6 +514,28 @@ components: - updated_at - usage_limit - ephemeral + SetupKeyClear: + allOf: + - $ref: '#/components/schemas/SetupKeyBase' + - type: object + properties: + key: + description: Setup Key as plain text + type: string + example: A616097E-FCF0-48FA-9354-CA4A61142761 + required: + - key + SetupKey: + allOf: + - $ref: '#/components/schemas/SetupKeyBase' + - type: object + properties: + key: + description: Setup Key as secret + type: string + example: A6160**** + required: + - key SetupKeyRequest: type: object properties: @@ -1918,7 +1936,7 @@ paths: content: application/json: schema: - $ref: '#/components/schemas/SetupKey' + $ref: '#/components/schemas/SetupKeyClear' '400': "$ref": "#/components/responses/bad_request" '401': diff --git a/management/server/http/api/types.gen.go b/management/server/http/api/types.gen.go index f219c4574..321395d25 100644 --- a/management/server/http/api/types.gen.go +++ b/management/server/http/api/types.gen.go @@ -1062,7 +1062,94 @@ type SetupKey struct { // Id Setup Key ID Id string `json:"id"` - // Key Setup Key value + // Key Setup Key as secret + Key string `json:"key"` + + // LastUsed Setup key last usage date + LastUsed time.Time `json:"last_used"` + + // Name Setup key name identifier + Name string `json:"name"` + + // Revoked Setup key revocation status + Revoked bool `json:"revoked"` + + // State Setup key status, "valid", "overused","expired" or "revoked" + State string `json:"state"` + + // Type Setup key type, one-off for single time usage and reusable + Type string `json:"type"` + + // UpdatedAt Setup key last update date + UpdatedAt time.Time `json:"updated_at"` + + // UsageLimit A number of times this key can be used. The value of 0 indicates the unlimited usage. + UsageLimit int `json:"usage_limit"` + + // UsedTimes Usage count of setup key + UsedTimes int `json:"used_times"` + + // Valid Setup key validity status + Valid bool `json:"valid"` +} + +// SetupKeyBase defines model for SetupKeyBase. 
+type SetupKeyBase struct { + // AutoGroups List of group IDs to auto-assign to peers registered with this key + AutoGroups []string `json:"auto_groups"` + + // Ephemeral Indicate that the peer will be ephemeral or not + Ephemeral bool `json:"ephemeral"` + + // Expires Setup Key expiration date + Expires time.Time `json:"expires"` + + // Id Setup Key ID + Id string `json:"id"` + + // LastUsed Setup key last usage date + LastUsed time.Time `json:"last_used"` + + // Name Setup key name identifier + Name string `json:"name"` + + // Revoked Setup key revocation status + Revoked bool `json:"revoked"` + + // State Setup key status, "valid", "overused","expired" or "revoked" + State string `json:"state"` + + // Type Setup key type, one-off for single time usage and reusable + Type string `json:"type"` + + // UpdatedAt Setup key last update date + UpdatedAt time.Time `json:"updated_at"` + + // UsageLimit A number of times this key can be used. The value of 0 indicates the unlimited usage. + UsageLimit int `json:"usage_limit"` + + // UsedTimes Usage count of setup key + UsedTimes int `json:"used_times"` + + // Valid Setup key validity status + Valid bool `json:"valid"` +} + +// SetupKeyClear defines model for SetupKeyClear. +type SetupKeyClear struct { + // AutoGroups List of group IDs to auto-assign to peers registered with this key + AutoGroups []string `json:"auto_groups"` + + // Ephemeral Indicate that the peer will be ephemeral or not + Ephemeral bool `json:"ephemeral"` + + // Expires Setup Key expiration date + Expires time.Time `json:"expires"` + + // Id Setup Key ID + Id string `json:"id"` + + // Key Setup Key as plain text Key string `json:"key"` // LastUsed Setup key last usage date diff --git a/management/server/setupkey.go b/management/server/setupkey.go index cae0dfecb..ef431d3ad 100644 --- a/management/server/setupkey.go +++ b/management/server/setupkey.go @@ -379,7 +379,7 @@ func (am *DefaultAccountManager) GetSetupKey(ctx context.Context, accountID, use return nil, status.NewAdminPermissionError() } - setupKey, err := am.Store.GetSetupKeyByID(ctx, LockingStrengthShare, keyID, accountID) + setupKey, err := am.Store.GetSetupKeyByID(ctx, LockingStrengthShare, accountID, keyID) if err != nil { return nil, err } diff --git a/management/server/setupkey_test.go b/management/server/setupkey_test.go index 94ed022fa..7c8200706 100644 --- a/management/server/setupkey_test.go +++ b/management/server/setupkey_test.go @@ -210,22 +210,41 @@ func TestGetSetupKeys(t *testing.T) { t.Fatal(err) } - err = manager.SaveGroup(context.Background(), account.Id, userID, &nbgroup.Group{ - ID: "group_1", - Name: "group_name_1", - Peers: []string{}, - }) + plainKey, err := manager.CreateSetupKey(context.Background(), account.Id, "key1", SetupKeyReusable, time.Hour, nil, SetupKeyUnlimitedUsage, userID, false) if err != nil { t.Fatal(err) } - err = manager.SaveGroup(context.Background(), account.Id, userID, &nbgroup.Group{ - ID: "group_2", - Name: "group_name_2", - Peers: []string{}, - }) - if err != nil { - t.Fatal(err) + type testCase struct { + name string + keyId string + expectedFailure bool + } + + testCase1 := testCase{ + name: "Should get existing Setup Key", + keyId: plainKey.Id, + expectedFailure: false, + } + testCase2 := testCase{ + name: "Should fail to get non-existent Setup Key", + keyId: "some key", + expectedFailure: true, + } + + for _, tCase := range []testCase{testCase1, testCase2} { + t.Run(tCase.name, func(t *testing.T) { + key, err := manager.GetSetupKey(context.Background(), account.Id, 
userID, tCase.keyId) + + if tCase.expectedFailure { + if err == nil { + t.Fatal("expected to fail") + } + return + } + + assert.NotEqual(t, plainKey.Key, key.Key) + }) } } From 2a5cb1649402d42f588d374f6a775c62e92a5522 Mon Sep 17 00:00:00 2001 From: Zoltan Papp Date: Fri, 22 Nov 2024 18:12:34 +0100 Subject: [PATCH 031/159] [relay] Refactor initial Relay connection (#2800) Can support firewalls with restricted WS rules allow to run engine without Relay servers keep up to date Relay address changes --- client/internal/connect.go | 3 +- client/internal/engine.go | 10 ++- client/internal/peer/status.go | 20 ++--- client/internal/peer/worker_ice.go | 14 +-- relay/client/client.go | 6 +- relay/client/client_test.go | 2 +- relay/client/guard.go | 91 +++++++++++++++---- relay/client/manager.go | 140 +++++++++++++++++++---------- relay/client/picker.go | 16 ++-- relay/client/picker_test.go | 5 +- 10 files changed, 211 insertions(+), 96 deletions(-) diff --git a/client/internal/connect.go b/client/internal/connect.go index f76aa066b..8c2ad4aa1 100644 --- a/client/internal/connect.go +++ b/client/internal/connect.go @@ -232,6 +232,7 @@ func (c *ConnectClient) run(mobileDependency MobileDependency, probes *ProbeHold relayURLs, token := parseRelayInfo(loginResp) relayManager := relayClient.NewManager(engineCtx, relayURLs, myPrivateKey.PublicKey().String()) + c.statusRecorder.SetRelayMgr(relayManager) if len(relayURLs) > 0 { if token != nil { if err := relayManager.UpdateToken(token); err != nil { @@ -242,9 +243,7 @@ func (c *ConnectClient) run(mobileDependency MobileDependency, probes *ProbeHold log.Infof("connecting to the Relay service(s): %s", strings.Join(relayURLs, ", ")) if err = relayManager.Serve(); err != nil { log.Error(err) - return wrapErr(err) } - c.statusRecorder.SetRelayMgr(relayManager) } peerConfig := loginResp.GetPeerConfig() diff --git a/client/internal/engine.go b/client/internal/engine.go index 1c912220c..dc4499e17 100644 --- a/client/internal/engine.go +++ b/client/internal/engine.go @@ -538,6 +538,7 @@ func (e *Engine) handleSync(update *mgmProto.SyncResponse) error { relayMsg := wCfg.GetRelay() if relayMsg != nil { + // when we receive token we expect valid address list too c := &auth.Token{ Payload: relayMsg.GetTokenPayload(), Signature: relayMsg.GetTokenSignature(), @@ -546,9 +547,16 @@ func (e *Engine) handleSync(update *mgmProto.SyncResponse) error { log.Errorf("failed to update relay token: %v", err) return fmt.Errorf("update relay token: %w", err) } + + e.relayManager.UpdateServerURLs(relayMsg.Urls) + + // Just in case the agent started with an MGM server where the relay was disabled but was later enabled. + // We can ignore all errors because the guard will manage the reconnection retries. 
+ _ = e.relayManager.Serve() + } else { + e.relayManager.UpdateServerURLs(nil) } - // todo update relay address in the relay manager // todo update signal } diff --git a/client/internal/peer/status.go b/client/internal/peer/status.go index 0444dc60b..74e2ee82c 100644 --- a/client/internal/peer/status.go +++ b/client/internal/peer/status.go @@ -676,25 +676,23 @@ func (d *Status) GetRelayStates() []relay.ProbeResult { // extend the list of stun, turn servers with relay address relayStates := slices.Clone(d.relayStates) - var relayState relay.ProbeResult - // if the server connection is not established then we will use the general address // in case of connection we will use the instance specific address instanceAddr, err := d.relayMgr.RelayInstanceAddress() if err != nil { // TODO add their status - if errors.Is(err, relayClient.ErrRelayClientNotConnected) { - for _, r := range d.relayMgr.ServerURLs() { - relayStates = append(relayStates, relay.ProbeResult{ - URI: r, - }) - } - return relayStates + for _, r := range d.relayMgr.ServerURLs() { + relayStates = append(relayStates, relay.ProbeResult{ + URI: r, + Err: err, + }) } - relayState.Err = err + return relayStates } - relayState.URI = instanceAddr + relayState := relay.ProbeResult{ + URI: instanceAddr, + } return append(relayStates, relayState) } diff --git a/client/internal/peer/worker_ice.go b/client/internal/peer/worker_ice.go index 4c67cb781..7ce4797c3 100644 --- a/client/internal/peer/worker_ice.go +++ b/client/internal/peer/worker_ice.go @@ -46,8 +46,6 @@ type WorkerICE struct { hasRelayOnLocally bool conn WorkerICECallbacks - selectedPriority ConnPriority - agent *ice.Agent muxAgent sync.Mutex @@ -95,10 +93,8 @@ func (w *WorkerICE) OnNewOffer(remoteOfferAnswer *OfferAnswer) { var preferredCandidateTypes []ice.CandidateType if w.hasRelayOnLocally && remoteOfferAnswer.RelaySrvAddress != "" { - w.selectedPriority = connPriorityICEP2P preferredCandidateTypes = icemaker.CandidateTypesP2P() } else { - w.selectedPriority = connPriorityICETurn preferredCandidateTypes = icemaker.CandidateTypes() } @@ -159,7 +155,7 @@ func (w *WorkerICE) OnNewOffer(remoteOfferAnswer *OfferAnswer) { RelayedOnLocal: isRelayCandidate(pair.Local), } w.log.Debugf("on ICE conn read to use ready") - go w.conn.OnConnReady(w.selectedPriority, ci) + go w.conn.OnConnReady(selectedPriority(pair), ci) } // OnRemoteCandidate Handles ICE connection Candidate provided by the remote peer. @@ -394,3 +390,11 @@ func isRelayed(pair *ice.CandidatePair) bool { } return false } + +func selectedPriority(pair *ice.CandidatePair) ConnPriority { + if isRelayed(pair) { + return connPriorityICETurn + } else { + return connPriorityICEP2P + } +} diff --git a/relay/client/client.go b/relay/client/client.go index 154c1787f..db5252f50 100644 --- a/relay/client/client.go +++ b/relay/client/client.go @@ -140,7 +140,7 @@ type Client struct { instanceURL *RelayAddr muInstanceURL sync.Mutex - onDisconnectListener func() + onDisconnectListener func(string) onConnectedListener func() listenerMutex sync.Mutex } @@ -233,7 +233,7 @@ func (c *Client) ServerInstanceURL() (string, error) { } // SetOnDisconnectListener sets a function that will be called when the connection to the relay server is closed. 
-func (c *Client) SetOnDisconnectListener(fn func()) { +func (c *Client) SetOnDisconnectListener(fn func(string)) { c.listenerMutex.Lock() defer c.listenerMutex.Unlock() c.onDisconnectListener = fn @@ -554,7 +554,7 @@ func (c *Client) notifyDisconnected() { if c.onDisconnectListener == nil { return } - go c.onDisconnectListener() + go c.onDisconnectListener(c.connectionURL) } func (c *Client) notifyConnected() { diff --git a/relay/client/client_test.go b/relay/client/client_test.go index ef28203e9..7ddfba4c6 100644 --- a/relay/client/client_test.go +++ b/relay/client/client_test.go @@ -551,7 +551,7 @@ func TestCloseByServer(t *testing.T) { } disconnected := make(chan struct{}) - relayClient.SetOnDisconnectListener(func() { + relayClient.SetOnDisconnectListener(func(_ string) { log.Infof("client disconnected") close(disconnected) }) diff --git a/relay/client/guard.go b/relay/client/guard.go index d6b6b0da5..b971363a8 100644 --- a/relay/client/guard.go +++ b/relay/client/guard.go @@ -4,65 +4,120 @@ import ( "context" "time" + "github.com/cenkalti/backoff/v4" log "github.com/sirupsen/logrus" ) var ( - reconnectingTimeout = 5 * time.Second + reconnectingTimeout = 60 * time.Second ) // Guard manage the reconnection tries to the Relay server in case of disconnection event. type Guard struct { - ctx context.Context - relayClient *Client + // OnNewRelayClient is a channel that is used to notify the relay client about a new relay client instance. + OnNewRelayClient chan *Client + serverPicker *ServerPicker } // NewGuard creates a new guard for the relay client. -func NewGuard(context context.Context, relayClient *Client) *Guard { +func NewGuard(sp *ServerPicker) *Guard { g := &Guard{ - ctx: context, - relayClient: relayClient, + OnNewRelayClient: make(chan *Client, 1), + serverPicker: sp, } return g } -// OnDisconnected is called when the relay client is disconnected from the relay server. It will trigger the reconnection +// StartReconnectTrys is called when the relay client is disconnected from the relay server. +// It attempts to reconnect to the relay server. The function first tries a quick reconnect +// to the same server that was used before, if the server URL is still valid. If the quick +// reconnect fails, it starts a ticker to periodically attempt server picking until it +// succeeds or the context is done. +// +// Parameters: +// - ctx: The context to control the lifecycle of the reconnection attempts. +// - relayClient: The relay client instance that was disconnected. // todo prevent multiple reconnection instances. 
In the current usage it should not happen, but it is better to prevent -func (g *Guard) OnDisconnected() { - if g.quickReconnect() { +func (g *Guard) StartReconnectTrys(ctx context.Context, relayClient *Client) { + if relayClient == nil { + goto RETRY + } + if g.isServerURLStillValid(relayClient) && g.quickReconnect(ctx, relayClient) { return } - ticker := time.NewTicker(reconnectingTimeout) +RETRY: + ticker := exponentTicker(ctx) defer ticker.Stop() for { select { case <-ticker.C: - err := g.relayClient.Connect() - if err != nil { - log.Errorf("failed to reconnect to relay server: %s", err) + if err := g.retry(ctx); err != nil { + log.Errorf("failed to pick new Relay server: %s", err) continue } return - case <-g.ctx.Done(): + case <-ctx.Done(): return } } } -func (g *Guard) quickReconnect() bool { - ctx, cancel := context.WithTimeout(g.ctx, 1500*time.Millisecond) +func (g *Guard) retry(ctx context.Context) error { + log.Infof("try to pick up a new Relay server") + relayClient, err := g.serverPicker.PickServer(ctx) + if err != nil { + return err + } + + // prevent to work with a deprecated Relay client instance + g.drainRelayClientChan() + + g.OnNewRelayClient <- relayClient + return nil +} + +func (g *Guard) quickReconnect(parentCtx context.Context, rc *Client) bool { + ctx, cancel := context.WithTimeout(parentCtx, 1500*time.Millisecond) defer cancel() <-ctx.Done() - if g.ctx.Err() != nil { + if parentCtx.Err() != nil { return false } + log.Infof("try to reconnect to Relay server: %s", rc.connectionURL) - if err := g.relayClient.Connect(); err != nil { + if err := rc.Connect(); err != nil { log.Errorf("failed to reconnect to relay server: %s", err) return false } return true } + +func (g *Guard) drainRelayClientChan() { + select { + case <-g.OnNewRelayClient: + default: + } +} + +func (g *Guard) isServerURLStillValid(rc *Client) bool { + for _, url := range g.serverPicker.ServerURLs.Load().([]string) { + if url == rc.connectionURL { + return true + } + } + return false +} + +func exponentTicker(ctx context.Context) *backoff.Ticker { + bo := backoff.WithContext(&backoff.ExponentialBackOff{ + InitialInterval: 2 * time.Second, + Multiplier: 2, + MaxInterval: reconnectingTimeout, + Clock: backoff.SystemClock, + }, ctx) + + return backoff.NewTicker(bo) +} diff --git a/relay/client/manager.go b/relay/client/manager.go index b14a7701b..d847bb879 100644 --- a/relay/client/manager.go +++ b/relay/client/manager.go @@ -57,12 +57,15 @@ type ManagerService interface { // relay servers will be closed if there is no active connection. Periodically the manager will check if there is any // unused relay connection and close it. type Manager struct { - ctx context.Context - serverURLs []string - peerID string - tokenStore *relayAuth.TokenStore + ctx context.Context + peerID string + running bool + tokenStore *relayAuth.TokenStore + serverPicker *ServerPicker - relayClient *Client + relayClient *Client + // the guard logic can overwrite the relayClient variable, this mutex protect the usage of the variable + relayClientMu sync.Mutex reconnectGuard *Guard relayClients map[string]*RelayTrack @@ -76,48 +79,54 @@ type Manager struct { // NewManager creates a new manager instance. // The serverURL address can be empty. In this case, the manager will not serve. 
func NewManager(ctx context.Context, serverURLs []string, peerID string) *Manager { - return &Manager{ - ctx: ctx, - serverURLs: serverURLs, - peerID: peerID, - tokenStore: &relayAuth.TokenStore{}, + tokenStore := &relayAuth.TokenStore{} + + m := &Manager{ + ctx: ctx, + peerID: peerID, + tokenStore: tokenStore, + serverPicker: &ServerPicker{ + TokenStore: tokenStore, + PeerID: peerID, + }, relayClients: make(map[string]*RelayTrack), onDisconnectedListeners: make(map[string]*list.List), } + m.serverPicker.ServerURLs.Store(serverURLs) + m.reconnectGuard = NewGuard(m.serverPicker) + return m } -// Serve starts the manager. It will establish a connection to the relay server and start the relay cleanup loop for -// the unused relay connections. The manager will automatically reconnect to the relay server in case of disconnection. +// Serve starts the manager, attempting to establish a connection with the relay server. +// If the connection fails, it will keep trying to reconnect in the background. +// Additionally, it starts a cleanup loop to remove unused relay connections. +// The manager will automatically reconnect to the relay server in case of disconnection. func (m *Manager) Serve() error { - if m.relayClient != nil { + if m.running { return fmt.Errorf("manager already serving") } - log.Debugf("starting relay client manager with %v relay servers", m.serverURLs) + m.running = true + log.Debugf("starting relay client manager with %v relay servers", m.serverPicker.ServerURLs.Load()) - sp := ServerPicker{ - TokenStore: m.tokenStore, - PeerID: m.peerID, - } - - client, err := sp.PickServer(m.ctx, m.serverURLs) + client, err := m.serverPicker.PickServer(m.ctx) if err != nil { - return err + go m.reconnectGuard.StartReconnectTrys(m.ctx, nil) + } else { + m.storeClient(client) } - m.relayClient = client - m.reconnectGuard = NewGuard(m.ctx, m.relayClient) - m.relayClient.SetOnConnectedListener(m.onServerConnected) - m.relayClient.SetOnDisconnectListener(func() { - m.onServerDisconnected(client.connectionURL) - }) - m.startCleanupLoop() - return nil + go m.listenGuardEvent(m.ctx) + go m.startCleanupLoop() + return err } // OpenConn opens a connection to the given peer key. If the peer is on the same relay server, the connection will be // established via the relay server. If the peer is on a different relay server, the manager will establish a new // connection to the relay server. It returns back with a net.Conn what represent the remote peer connection. func (m *Manager) OpenConn(serverAddress, peerKey string) (net.Conn, error) { + m.relayClientMu.Lock() + defer m.relayClientMu.Unlock() + if m.relayClient == nil { return nil, ErrRelayClientNotConnected } @@ -146,6 +155,9 @@ func (m *Manager) OpenConn(serverAddress, peerKey string) (net.Conn, error) { // Ready returns true if the home Relay client is connected to the relay server. func (m *Manager) Ready() bool { + m.relayClientMu.Lock() + defer m.relayClientMu.Unlock() + if m.relayClient == nil { return false } @@ -159,6 +171,13 @@ func (m *Manager) SetOnReconnectedListener(f func()) { // AddCloseListener adds a listener to the given server instance address. The listener will be called if the connection // closed. 
func (m *Manager) AddCloseListener(serverAddress string, onClosedListener OnServerCloseListener) error { + m.relayClientMu.Lock() + defer m.relayClientMu.Unlock() + + if m.relayClient == nil { + return ErrRelayClientNotConnected + } + foreign, err := m.isForeignServer(serverAddress) if err != nil { return err @@ -177,6 +196,9 @@ func (m *Manager) AddCloseListener(serverAddress string, onClosedListener OnServ // RelayInstanceAddress returns the address of the permanent relay server. It could change if the network connection is // lost. This address will be sent to the target peer to choose the common relay server for the communication. func (m *Manager) RelayInstanceAddress() (string, error) { + m.relayClientMu.Lock() + defer m.relayClientMu.Unlock() + if m.relayClient == nil { return "", ErrRelayClientNotConnected } @@ -185,13 +207,18 @@ func (m *Manager) RelayInstanceAddress() (string, error) { // ServerURLs returns the addresses of the relay servers. func (m *Manager) ServerURLs() []string { - return m.serverURLs + return m.serverPicker.ServerURLs.Load().([]string) } // HasRelayAddress returns true if the manager is serving. With this method can check if the peer can communicate with // Relay service. func (m *Manager) HasRelayAddress() bool { - return len(m.serverURLs) > 0 + return len(m.serverPicker.ServerURLs.Load().([]string)) > 0 +} + +func (m *Manager) UpdateServerURLs(serverURLs []string) { + log.Infof("update relay server URLs: %v", serverURLs) + m.serverPicker.ServerURLs.Store(serverURLs) } // UpdateToken updates the token in the token store. @@ -245,9 +272,7 @@ func (m *Manager) openConnVia(serverAddress, peerKey string) (net.Conn, error) { return nil, err } // if connection closed then delete the relay client from the list - relayClient.SetOnDisconnectListener(func() { - m.onServerDisconnected(serverAddress) - }) + relayClient.SetOnDisconnectListener(m.onServerDisconnected) rt.relayClient = relayClient rt.Unlock() @@ -265,14 +290,37 @@ func (m *Manager) onServerConnected() { go m.onReconnectedListenerFn() } +// onServerDisconnected start to reconnection for home server only func (m *Manager) onServerDisconnected(serverAddress string) { + m.relayClientMu.Lock() if serverAddress == m.relayClient.connectionURL { - go m.reconnectGuard.OnDisconnected() + go m.reconnectGuard.StartReconnectTrys(m.ctx, m.relayClient) } + m.relayClientMu.Unlock() m.notifyOnDisconnectListeners(serverAddress) } +func (m *Manager) listenGuardEvent(ctx context.Context) { + for { + select { + case rc := <-m.reconnectGuard.OnNewRelayClient: + m.storeClient(rc) + case <-ctx.Done(): + return + } + } +} + +func (m *Manager) storeClient(client *Client) { + m.relayClientMu.Lock() + defer m.relayClientMu.Unlock() + + m.relayClient = client + m.relayClient.SetOnConnectedListener(m.onServerConnected) + m.relayClient.SetOnDisconnectListener(m.onServerDisconnected) +} + func (m *Manager) isForeignServer(address string) (bool, error) { rAddr, err := m.relayClient.ServerInstanceURL() if err != nil { @@ -282,22 +330,16 @@ func (m *Manager) isForeignServer(address string) (bool, error) { } func (m *Manager) startCleanupLoop() { - if m.ctx.Err() != nil { - return - } - ticker := time.NewTicker(relayCleanupInterval) - go func() { - defer ticker.Stop() - for { - select { - case <-m.ctx.Done(): - return - case <-ticker.C: - m.cleanUpUnusedRelays() - } + defer ticker.Stop() + for { + select { + case <-m.ctx.Done(): + return + case <-ticker.C: + m.cleanUpUnusedRelays() } - }() + } } func (m *Manager) cleanUpUnusedRelays() { 
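Before the picker changes, a standalone sketch of the retry pattern the Guard's exponentTicker (above) is built on: an exponential-backoff ticker from github.com/cenkalti/backoff/v4, wrapped with the caller's context, that fires at growing intervals until the work succeeds or the context is cancelled. This is not part of the patch; the helper name retryWithBackoff, the tryConnect callback, and the concrete intervals are illustrative assumptions chosen to mirror the values used in guard.go.

// Sketch only: mirrors the Guard's exponentTicker usage, not code from this repository.
package main

import (
	"context"
	"errors"
	"log"
	"time"

	"github.com/cenkalti/backoff/v4"
)

// retryWithBackoff keeps calling tryConnect with exponentially growing pauses
// until it succeeds or ctx is cancelled.
func retryWithBackoff(ctx context.Context, tryConnect func() error) error {
	bo := backoff.WithContext(&backoff.ExponentialBackOff{
		InitialInterval: 2 * time.Second,  // first retry roughly 2s after the failure
		Multiplier:      2,                // double the pause after every failed attempt
		MaxInterval:     60 * time.Second, // cap, analogous to reconnectingTimeout above
		Clock:           backoff.SystemClock,
	}, ctx)

	ticker := backoff.NewTicker(bo)
	defer ticker.Stop()

	for {
		select {
		case _, ok := <-ticker.C:
			if !ok {
				// the backoff gave up (context cancelled), ticker channel was closed
				return ctx.Err()
			}
			if err := tryConnect(); err != nil {
				log.Printf("connect attempt failed, retrying: %v", err)
				continue
			}
			return nil
		case <-ctx.Done():
			return ctx.Err()
		}
	}
}

func main() {
	attempts := 0
	err := retryWithBackoff(context.Background(), func() error {
		attempts++
		if attempts < 3 {
			return errors.New("relay server unreachable")
		}
		return nil
	})
	log.Printf("connected after %d attempts, err=%v", attempts, err)
}

The design choice this illustrates: because the ticker is context-aware, the reconnect loop needs no separate shutdown flag, and the first tick fires immediately, so a quick reconnect attempt happens before any exponential delay is applied.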
diff --git a/relay/client/picker.go b/relay/client/picker.go index 13b0547aa..eb5062dbb 100644 --- a/relay/client/picker.go +++ b/relay/client/picker.go @@ -4,6 +4,7 @@ import ( "context" "errors" "fmt" + "sync/atomic" "time" log "github.com/sirupsen/logrus" @@ -12,10 +13,13 @@ import ( ) const ( - connectionTimeout = 30 * time.Second maxConcurrentServers = 7 ) +var ( + connectionTimeout = 30 * time.Second +) + type connResult struct { RelayClient *Client Url string @@ -24,20 +28,22 @@ type connResult struct { type ServerPicker struct { TokenStore *auth.TokenStore + ServerURLs atomic.Value PeerID string } -func (sp *ServerPicker) PickServer(parentCtx context.Context, urls []string) (*Client, error) { +func (sp *ServerPicker) PickServer(parentCtx context.Context) (*Client, error) { ctx, cancel := context.WithTimeout(parentCtx, connectionTimeout) defer cancel() - totalServers := len(urls) + totalServers := len(sp.ServerURLs.Load().([]string)) connResultChan := make(chan connResult, totalServers) successChan := make(chan connResult, 1) concurrentLimiter := make(chan struct{}, maxConcurrentServers) - for _, url := range urls { + log.Debugf("pick server from list: %v", sp.ServerURLs.Load().([]string)) + for _, url := range sp.ServerURLs.Load().([]string) { // todo check if we have a successful connection so we do not need to connect to other servers concurrentLimiter <- struct{}{} go func(url string) { @@ -78,7 +84,7 @@ func (sp *ServerPicker) processConnResults(resultChan chan connResult, successCh for numOfResults := 0; numOfResults < cap(resultChan); numOfResults++ { cr := <-resultChan if cr.Err != nil { - log.Debugf("failed to connect to Relay server: %s: %v", cr.Url, cr.Err) + log.Tracef("failed to connect to Relay server: %s: %v", cr.Url, cr.Err) continue } log.Infof("connected to Relay server: %s", cr.Url) diff --git a/relay/client/picker_test.go b/relay/client/picker_test.go index 4800e05ba..20a03e64d 100644 --- a/relay/client/picker_test.go +++ b/relay/client/picker_test.go @@ -7,16 +7,19 @@ import ( ) func TestServerPicker_UnavailableServers(t *testing.T) { + connectionTimeout = 5 * time.Second + sp := ServerPicker{ TokenStore: nil, PeerID: "test", } + sp.ServerURLs.Store([]string{"rel://dummy1", "rel://dummy2"}) ctx, cancel := context.WithTimeout(context.Background(), connectionTimeout+1) defer cancel() go func() { - _, err := sp.PickServer(ctx, []string{"rel://dummy1", "rel://dummy2"}) + _, err := sp.PickServer(ctx) if err == nil { t.Error(err) } From 05c4aa7c2cad19f3679bf2548f8dc296dd1e043b Mon Sep 17 00:00:00 2001 From: Maycon Santos Date: Fri, 22 Nov 2024 18:50:47 +0100 Subject: [PATCH 032/159] [misc] Renew slack link (#2938) --- README.md | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/README.md b/README.md index a2d7f3897..e7925ae09 100644 --- a/README.md +++ b/README.md @@ -17,7 +17,7 @@
- +
@@ -34,7 +34,7 @@
See Documentation
- Join our Slack channel + Join our Slack channel
From 56cecf849ea4f6092b0dc9b421126da399bffa95 Mon Sep 17 00:00:00 2001 From: Maycon Santos Date: Fri, 22 Nov 2024 20:40:30 +0100 Subject: [PATCH 033/159] Import time package (#2940) --- relay/client/picker_test.go | 1 + 1 file changed, 1 insertion(+) diff --git a/relay/client/picker_test.go b/relay/client/picker_test.go index 20a03e64d..28167c5ce 100644 --- a/relay/client/picker_test.go +++ b/relay/client/picker_test.go @@ -4,6 +4,7 @@ import ( "context" "errors" "testing" + "time" ) func TestServerPicker_UnavailableServers(t *testing.T) { From 940d0c48c69198803b4cd88125214c5cb666bf2b Mon Sep 17 00:00:00 2001 From: Viktor Liu <17948409+lixmal@users.noreply.github.com> Date: Mon, 25 Nov 2024 15:11:31 +0100 Subject: [PATCH 034/159] [client] Don't return error in userspace mode without firewall (#2924) --- client/firewall/uspfilter/uspfilter.go | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/client/firewall/uspfilter/uspfilter.go b/client/firewall/uspfilter/uspfilter.go index af5dc6733..fb726395b 100644 --- a/client/firewall/uspfilter/uspfilter.go +++ b/client/firewall/uspfilter/uspfilter.go @@ -239,7 +239,7 @@ func (m *Manager) DeletePeerRule(rule firewall.Rule) error { // SetLegacyManagement doesn't need to be implemented for this manager func (m *Manager) SetLegacyManagement(isLegacy bool) error { if m.nativeFirewall == nil { - return errRouteNotSupported + return nil } return m.nativeFirewall.SetLegacyManagement(isLegacy) } From 0ecd5f211850d50dcf7179adbf3ae0246705f0a3 Mon Sep 17 00:00:00 2001 From: Viktor Liu <17948409+lixmal@users.noreply.github.com> Date: Mon, 25 Nov 2024 15:11:56 +0100 Subject: [PATCH 035/159] [client] Test nftables for incompatible iptables rules (#2948) --- .../firewall/nftables/manager_linux_test.go | 104 ++++++++++++++++++ 1 file changed, 104 insertions(+) diff --git a/client/firewall/nftables/manager_linux_test.go b/client/firewall/nftables/manager_linux_test.go index 77f4f0306..33fdc4b3d 100644 --- a/client/firewall/nftables/manager_linux_test.go +++ b/client/firewall/nftables/manager_linux_test.go @@ -1,9 +1,11 @@ package nftables import ( + "bytes" "fmt" "net" "net/netip" + "os/exec" "testing" "time" @@ -225,3 +227,105 @@ func TestNFtablesCreatePerformance(t *testing.T) { }) } } + +func runIptablesSave(t *testing.T) (string, string) { + t.Helper() + var stdout, stderr bytes.Buffer + cmd := exec.Command("iptables-save") + cmd.Stdout = &stdout + cmd.Stderr = &stderr + + err := cmd.Run() + require.NoError(t, err, "iptables-save failed to run") + + return stdout.String(), stderr.String() +} + +func verifyIptablesOutput(t *testing.T, stdout, stderr string) { + t.Helper() + // Check for any incompatibility warnings + require.NotContains(t, + stderr, + "incompatible", + "iptables-save produced compatibility warning. 
Full stderr: %s", + stderr, + ) + + // Verify standard tables are present + expectedTables := []string{ + "*filter", + "*nat", + "*mangle", + } + + for _, table := range expectedTables { + require.Contains(t, + stdout, + table, + "iptables-save output missing expected table: %s\nFull stdout: %s", + table, + stdout, + ) + } +} + +func TestNftablesManagerCompatibilityWithIptables(t *testing.T) { + if check() != NFTABLES { + t.Skip("nftables not supported on this system") + } + + if _, err := exec.LookPath("iptables-save"); err != nil { + t.Skipf("iptables-save not available on this system: %v", err) + } + + // First ensure iptables-nft tables exist by running iptables-save + stdout, stderr := runIptablesSave(t) + verifyIptablesOutput(t, stdout, stderr) + + manager, err := Create(ifaceMock) + require.NoError(t, err, "failed to create manager") + require.NoError(t, manager.Init(nil)) + + t.Cleanup(func() { + err := manager.Reset(nil) + require.NoError(t, err, "failed to reset manager state") + + // Verify iptables output after reset + stdout, stderr := runIptablesSave(t) + verifyIptablesOutput(t, stdout, stderr) + }) + + ip := net.ParseIP("100.96.0.1") + _, err = manager.AddPeerFiltering( + ip, + fw.ProtocolTCP, + nil, + &fw.Port{Values: []int{80}}, + fw.RuleDirectionIN, + fw.ActionAccept, + "", + "test rule", + ) + require.NoError(t, err, "failed to add peer filtering rule") + + _, err = manager.AddRouteFiltering( + []netip.Prefix{netip.MustParsePrefix("192.168.2.0/24")}, + netip.MustParsePrefix("10.1.0.0/24"), + fw.ProtocolTCP, + nil, + &fw.Port{Values: []int{443}}, + fw.ActionAccept, + ) + require.NoError(t, err, "failed to add route filtering rule") + + pair := fw.RouterPair{ + Source: netip.MustParsePrefix("192.168.1.0/24"), + Destination: netip.MustParsePrefix("10.0.0.0/24"), + Masquerade: true, + } + err = manager.AddNatRule(pair) + require.NoError(t, err, "failed to add NAT rule") + + stdout, stderr = runIptablesSave(t) + verifyIptablesOutput(t, stdout, stderr) +} From f1625b32bdd6c1fe0abb73756835947b99d1a97f Mon Sep 17 00:00:00 2001 From: Viktor Liu <17948409+lixmal@users.noreply.github.com> Date: Mon, 25 Nov 2024 15:12:16 +0100 Subject: [PATCH 036/159] [client] Set up sysctl and routing table name only if routing rules are available (#2933) --- .../routemanager/systemops/systemops_linux.go | 22 +++++++++---------- 1 file changed, 11 insertions(+), 11 deletions(-) diff --git a/client/internal/routemanager/systemops/systemops_linux.go b/client/internal/routemanager/systemops/systemops_linux.go index 71a0f26ae..ac4fd5c71 100644 --- a/client/internal/routemanager/systemops/systemops_linux.go +++ b/client/internal/routemanager/systemops/systemops_linux.go @@ -92,17 +92,6 @@ func (r *SysOps) SetupRouting(initAddresses []net.IP, stateManager *statemanager return r.setupRefCounter(initAddresses, stateManager) } - if err = addRoutingTableName(); err != nil { - log.Errorf("Error adding routing table name: %v", err) - } - - originalValues, err := sysctl.Setup(r.wgInterface) - if err != nil { - log.Errorf("Error setting up sysctl: %v", err) - sysctlFailed = true - } - originalSysctl = originalValues - defer func() { if err != nil { if cleanErr := r.CleanupRouting(stateManager); cleanErr != nil { @@ -123,6 +112,17 @@ func (r *SysOps) SetupRouting(initAddresses []net.IP, stateManager *statemanager } } + if err = addRoutingTableName(); err != nil { + log.Errorf("Error adding routing table name: %v", err) + } + + originalValues, err := sysctl.Setup(r.wgInterface) + if err != nil { + 
log.Errorf("Error setting up sysctl: %v", err) + sysctlFailed = true + } + originalSysctl = originalValues + return nil, nil, nil } From 9810386937edd1665109479f745e6faaf21db840 Mon Sep 17 00:00:00 2001 From: Viktor Liu <17948409+lixmal@users.noreply.github.com> Date: Mon, 25 Nov 2024 15:19:56 +0100 Subject: [PATCH 037/159] [client] Allow routing to fallback to exclusion routes if rules are not supported (#2909) --- client/internal/routemanager/systemops/systemops_linux.go | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/client/internal/routemanager/systemops/systemops_linux.go b/client/internal/routemanager/systemops/systemops_linux.go index ac4fd5c71..1d629d6e9 100644 --- a/client/internal/routemanager/systemops/systemops_linux.go +++ b/client/internal/routemanager/systemops/systemops_linux.go @@ -450,7 +450,7 @@ func addRule(params ruleParams) error { rule.Invert = params.invert rule.SuppressPrefixlen = params.suppressPrefix - if err := netlink.RuleAdd(rule); err != nil && !errors.Is(err, syscall.EEXIST) && !errors.Is(err, syscall.EAFNOSUPPORT) { + if err := netlink.RuleAdd(rule); err != nil && !errors.Is(err, syscall.EEXIST) { return fmt.Errorf("add routing rule: %w", err) } @@ -467,7 +467,7 @@ func removeRule(params ruleParams) error { rule.Priority = params.priority rule.SuppressPrefixlen = params.suppressPrefix - if err := netlink.RuleDel(rule); err != nil && !errors.Is(err, syscall.ENOENT) && !errors.Is(err, syscall.EAFNOSUPPORT) { + if err := netlink.RuleDel(rule); err != nil && !errors.Is(err, syscall.ENOENT) { return fmt.Errorf("remove routing rule: %w", err) } From ca12bc6953b8ba0e8645b90dc2ef8d50c8c38a01 Mon Sep 17 00:00:00 2001 From: Bethuel Mmbaga Date: Mon, 25 Nov 2024 18:26:24 +0300 Subject: [PATCH 038/159] [management] Refactor posture check to use store methods (#2874) --- management/server/account.go | 2 +- management/server/dns.go | 2 +- management/server/group.go | 19 +- .../server/http/posture_checks_handler.go | 3 +- .../http/posture_checks_handler_test.go | 6 +- management/server/mock_server/account_mock.go | 6 +- management/server/nameserver.go | 10 +- management/server/peer.go | 25 +- management/server/policy.go | 6 +- management/server/posture/checks.go | 6 - management/server/posture_checks.go | 337 +++++++++++------- management/server/posture_checks_test.go | 221 +++++++----- management/server/route.go | 10 +- management/server/sql_store.go | 51 ++- management/server/sql_store_test.go | 135 +++++++ management/server/status/error.go | 5 + management/server/store.go | 4 +- management/server/testdata/extended-store.sql | 1 + 18 files changed, 589 insertions(+), 260 deletions(-) diff --git a/management/server/account.go b/management/server/account.go index 0ab123655..9fb56c855 100644 --- a/management/server/account.go +++ b/management/server/account.go @@ -139,7 +139,7 @@ type AccountManager interface { HasConnectedChannel(peerID string) bool GetExternalCacheManager() ExternalCacheManager GetPostureChecks(ctx context.Context, accountID, postureChecksID, userID string) (*posture.Checks, error) - SavePostureChecks(ctx context.Context, accountID, userID string, postureChecks *posture.Checks) error + SavePostureChecks(ctx context.Context, accountID, userID string, postureChecks *posture.Checks) (*posture.Checks, error) DeletePostureChecks(ctx context.Context, accountID, postureChecksID, userID string) error ListPostureChecks(ctx context.Context, accountID, userID string) ([]*posture.Checks, error) GetIdpManager() idp.Manager diff --git 
a/management/server/dns.go b/management/server/dns.go index 4551be5ab..e52be6016 100644 --- a/management/server/dns.go +++ b/management/server/dns.go @@ -145,7 +145,7 @@ func (am *DefaultAccountManager) SaveDNSSettings(ctx context.Context, accountID am.StoreEvent(ctx, userID, accountID, accountID, activity.GroupRemovedFromDisabledManagementGroups, meta) } - if anyGroupHasPeers(account, addedGroups) || anyGroupHasPeers(account, removedGroups) { + if am.anyGroupHasPeers(account, addedGroups) || am.anyGroupHasPeers(account, removedGroups) { am.updateAccountPeers(ctx, accountID) } diff --git a/management/server/group.go b/management/server/group.go index a36213f04..7b307cf1a 100644 --- a/management/server/group.go +++ b/management/server/group.go @@ -566,8 +566,7 @@ func areGroupChangesAffectPeers(ctx context.Context, transaction Store, accountI return false, nil } -// anyGroupHasPeers checks if any of the given groups in the account have peers. -func anyGroupHasPeers(account *Account, groupIDs []string) bool { +func (am *DefaultAccountManager) anyGroupHasPeers(account *Account, groupIDs []string) bool { for _, groupID := range groupIDs { if group, exists := account.Groups[groupID]; exists && group.HasPeers() { return true @@ -575,3 +574,19 @@ func anyGroupHasPeers(account *Account, groupIDs []string) bool { } return false } + +// anyGroupHasPeers checks if any of the given groups in the account have peers. +func anyGroupHasPeers(ctx context.Context, transaction Store, accountID string, groupIDs []string) (bool, error) { + groups, err := transaction.GetGroupsByIDs(ctx, LockingStrengthShare, accountID, groupIDs) + if err != nil { + return false, err + } + + for _, group := range groups { + if group.HasPeers() { + return true, nil + } + } + + return false, nil +} diff --git a/management/server/http/posture_checks_handler.go b/management/server/http/posture_checks_handler.go index 1d020e9bc..2c8204292 100644 --- a/management/server/http/posture_checks_handler.go +++ b/management/server/http/posture_checks_handler.go @@ -169,7 +169,8 @@ func (p *PostureChecksHandler) savePostureChecks(w http.ResponseWriter, r *http. 
return } - if err := p.accountManager.SavePostureChecks(r.Context(), accountID, userID, postureChecks); err != nil { + postureChecks, err = p.accountManager.SavePostureChecks(r.Context(), accountID, userID, postureChecks) + if err != nil { util.WriteError(r.Context(), err, w) return } diff --git a/management/server/http/posture_checks_handler_test.go b/management/server/http/posture_checks_handler_test.go index 02f0f0d83..f400cec81 100644 --- a/management/server/http/posture_checks_handler_test.go +++ b/management/server/http/posture_checks_handler_test.go @@ -40,15 +40,15 @@ func initPostureChecksTestData(postureChecks ...*posture.Checks) *PostureChecksH } return p, nil }, - SavePostureChecksFunc: func(_ context.Context, accountID, userID string, postureChecks *posture.Checks) error { + SavePostureChecksFunc: func(_ context.Context, accountID, userID string, postureChecks *posture.Checks) (*posture.Checks, error) { postureChecks.ID = "postureCheck" testPostureChecks[postureChecks.ID] = postureChecks if err := postureChecks.Validate(); err != nil { - return status.Errorf(status.InvalidArgument, err.Error()) //nolint + return nil, status.Errorf(status.InvalidArgument, err.Error()) //nolint } - return nil + return postureChecks, nil }, DeletePostureChecksFunc: func(_ context.Context, accountID, postureChecksID, userID string) error { _, ok := testPostureChecks[postureChecksID] diff --git a/management/server/mock_server/account_mock.go b/management/server/mock_server/account_mock.go index aa6a47b15..673ed33bb 100644 --- a/management/server/mock_server/account_mock.go +++ b/management/server/mock_server/account_mock.go @@ -96,7 +96,7 @@ type MockAccountManager struct { HasConnectedChannelFunc func(peerID string) bool GetExternalCacheManagerFunc func() server.ExternalCacheManager GetPostureChecksFunc func(ctx context.Context, accountID, postureChecksID, userID string) (*posture.Checks, error) - SavePostureChecksFunc func(ctx context.Context, accountID, userID string, postureChecks *posture.Checks) error + SavePostureChecksFunc func(ctx context.Context, accountID, userID string, postureChecks *posture.Checks) (*posture.Checks, error) DeletePostureChecksFunc func(ctx context.Context, accountID, postureChecksID, userID string) error ListPostureChecksFunc func(ctx context.Context, accountID, userID string) ([]*posture.Checks, error) GetIdpManagerFunc func() idp.Manager @@ -730,11 +730,11 @@ func (am *MockAccountManager) GetPostureChecks(ctx context.Context, accountID, p } // SavePostureChecks mocks SavePostureChecks of the AccountManager interface -func (am *MockAccountManager) SavePostureChecks(ctx context.Context, accountID, userID string, postureChecks *posture.Checks) error { +func (am *MockAccountManager) SavePostureChecks(ctx context.Context, accountID, userID string, postureChecks *posture.Checks) (*posture.Checks, error) { if am.SavePostureChecksFunc != nil { return am.SavePostureChecksFunc(ctx, accountID, userID, postureChecks) } - return status.Errorf(codes.Unimplemented, "method SavePostureChecks is not implemented") + return nil, status.Errorf(codes.Unimplemented, "method SavePostureChecks is not implemented") } // DeletePostureChecks mocks DeletePostureChecks of the AccountManager interface diff --git a/management/server/nameserver.go b/management/server/nameserver.go index 957008714..9119a3dec 100644 --- a/management/server/nameserver.go +++ b/management/server/nameserver.go @@ -70,7 +70,7 @@ func (am *DefaultAccountManager) CreateNameServerGroup(ctx context.Context, acco return nil, 
err } - if anyGroupHasPeers(account, newNSGroup.Groups) { + if am.anyGroupHasPeers(account, newNSGroup.Groups) { am.updateAccountPeers(ctx, accountID) } am.StoreEvent(ctx, userID, newNSGroup.ID, accountID, activity.NameserverGroupCreated, newNSGroup.EventMeta()) @@ -105,7 +105,7 @@ func (am *DefaultAccountManager) SaveNameServerGroup(ctx context.Context, accoun return err } - if areNameServerGroupChangesAffectPeers(account, nsGroupToSave, oldNSGroup) { + if am.areNameServerGroupChangesAffectPeers(account, nsGroupToSave, oldNSGroup) { am.updateAccountPeers(ctx, accountID) } am.StoreEvent(ctx, userID, nsGroupToSave.ID, accountID, activity.NameserverGroupUpdated, nsGroupToSave.EventMeta()) @@ -135,7 +135,7 @@ func (am *DefaultAccountManager) DeleteNameServerGroup(ctx context.Context, acco return err } - if anyGroupHasPeers(account, nsGroup.Groups) { + if am.anyGroupHasPeers(account, nsGroup.Groups) { am.updateAccountPeers(ctx, accountID) } am.StoreEvent(ctx, userID, nsGroup.ID, accountID, activity.NameserverGroupDeleted, nsGroup.EventMeta()) @@ -279,9 +279,9 @@ func validateDomain(domain string) error { } // areNameServerGroupChangesAffectPeers checks if the changes in the nameserver group affect the peers. -func areNameServerGroupChangesAffectPeers(account *Account, newNSGroup, oldNSGroup *nbdns.NameServerGroup) bool { +func (am *DefaultAccountManager) areNameServerGroupChangesAffectPeers(account *Account, newNSGroup, oldNSGroup *nbdns.NameServerGroup) bool { if !newNSGroup.Enabled && !oldNSGroup.Enabled { return false } - return anyGroupHasPeers(account, newNSGroup.Groups) || anyGroupHasPeers(account, oldNSGroup.Groups) + return am.anyGroupHasPeers(account, newNSGroup.Groups) || am.anyGroupHasPeers(account, oldNSGroup.Groups) } diff --git a/management/server/peer.go b/management/server/peer.go index beb833dba..dcb47af3b 100644 --- a/management/server/peer.go +++ b/management/server/peer.go @@ -617,7 +617,11 @@ func (am *DefaultAccountManager) AddPeer(ctx context.Context, setupKey, userID s return nil, nil, nil, err } - postureChecks := am.getPeerPostureChecks(account, newPeer) + postureChecks, err := am.getPeerPostureChecks(ctx, account.Id, newPeer.ID) + if err != nil { + return nil, nil, nil, err + } + customZone := account.GetPeersCustomZone(ctx, am.dnsDomain) networkMap := account.GetPeerNetworkMap(ctx, newPeer.ID, customZone, approvedPeersMap, am.metrics.AccountManagerMetrics()) return newPeer, networkMap, postureChecks, nil @@ -702,7 +706,11 @@ func (am *DefaultAccountManager) SyncPeer(ctx context.Context, sync PeerSync, ac if err != nil { return nil, nil, nil, fmt.Errorf("failed to get validated peers: %w", err) } - postureChecks = am.getPeerPostureChecks(account, peer) + + postureChecks, err = am.getPeerPostureChecks(ctx, account.Id, peer.ID) + if err != nil { + return nil, nil, nil, err + } customZone := account.GetPeersCustomZone(ctx, am.dnsDomain) return peer, account.GetPeerNetworkMap(ctx, peer.ID, customZone, validPeersMap, am.metrics.AccountManagerMetrics()), postureChecks, nil @@ -876,7 +884,11 @@ func (am *DefaultAccountManager) getValidatedPeerWithMap(ctx context.Context, is if err != nil { return nil, nil, nil, err } - postureChecks = am.getPeerPostureChecks(account, peer) + + postureChecks, err = am.getPeerPostureChecks(ctx, account.Id, peer.ID) + if err != nil { + return nil, nil, nil, err + } customZone := account.GetPeersCustomZone(ctx, am.dnsDomain) return peer, account.GetPeerNetworkMap(ctx, peer.ID, customZone, approvedPeersMap, am.metrics.AccountManagerMetrics()), 
postureChecks, nil @@ -1030,7 +1042,12 @@ func (am *DefaultAccountManager) updateAccountPeers(ctx context.Context, account defer wg.Done() defer func() { <-semaphore }() - postureChecks := am.getPeerPostureChecks(account, p) + postureChecks, err := am.getPeerPostureChecks(ctx, account.Id, p.ID) + if err != nil { + log.WithContext(ctx).Errorf("failed to send out updates to peers, failed to get peer: %s posture checks: %v", p.ID, err) + return + } + remotePeerNetworkMap := account.GetPeerNetworkMap(ctx, p.ID, customZone, approvedPeersMap, am.metrics.AccountManagerMetrics()) update := toSyncResponse(ctx, nil, p, nil, nil, remotePeerNetworkMap, am.GetDNSDomain(), postureChecks, dnsCache) am.peersUpdateManager.SendUpdate(ctx, p.ID, &UpdateMessage{Update: update, NetworkMap: remotePeerNetworkMap}) diff --git a/management/server/policy.go b/management/server/policy.go index 8a5733f01..c7872591d 100644 --- a/management/server/policy.go +++ b/management/server/policy.go @@ -405,7 +405,7 @@ func (am *DefaultAccountManager) DeletePolicy(ctx context.Context, accountID, po am.StoreEvent(ctx, userID, policy.ID, accountID, activity.PolicyRemoved, policy.EventMeta()) - if anyGroupHasPeers(account, policy.ruleGroups()) { + if am.anyGroupHasPeers(account, policy.ruleGroups()) { am.updateAccountPeers(ctx, accountID) } @@ -469,7 +469,7 @@ func (am *DefaultAccountManager) savePolicy(account *Account, policyToSave *Poli if !policyToSave.Enabled && !oldPolicy.Enabled { return false, nil } - updateAccountPeers := anyGroupHasPeers(account, oldPolicy.ruleGroups()) || anyGroupHasPeers(account, policyToSave.ruleGroups()) + updateAccountPeers := am.anyGroupHasPeers(account, oldPolicy.ruleGroups()) || am.anyGroupHasPeers(account, policyToSave.ruleGroups()) return updateAccountPeers, nil } @@ -477,7 +477,7 @@ func (am *DefaultAccountManager) savePolicy(account *Account, policyToSave *Poli // Add the new policy to the account account.Policies = append(account.Policies, policyToSave) - return anyGroupHasPeers(account, policyToSave.ruleGroups()), nil + return am.anyGroupHasPeers(account, policyToSave.ruleGroups()), nil } func toProtocolFirewallRules(rules []*FirewallRule) []*proto.FirewallRule { diff --git a/management/server/posture/checks.go b/management/server/posture/checks.go index f2739dddf..b2f308d76 100644 --- a/management/server/posture/checks.go +++ b/management/server/posture/checks.go @@ -7,8 +7,6 @@ import ( "regexp" "github.com/hashicorp/go-version" - "github.com/rs/xid" - "github.com/netbirdio/netbird/management/server/http/api" nbpeer "github.com/netbirdio/netbird/management/server/peer" "github.com/netbirdio/netbird/management/server/status" @@ -172,10 +170,6 @@ func NewChecksFromAPIPostureCheckUpdate(source api.PostureCheckUpdate, postureCh } func buildPostureCheck(postureChecksID string, name string, description string, checks api.Checks) (*Checks, error) { - if postureChecksID == "" { - postureChecksID = xid.New().String() - } - postureChecks := Checks{ ID: postureChecksID, Name: name, diff --git a/management/server/posture_checks.go b/management/server/posture_checks.go index 096cff3f5..59e726c41 100644 --- a/management/server/posture_checks.go +++ b/management/server/posture_checks.go @@ -2,16 +2,14 @@ package server import ( "context" + "fmt" "slices" "github.com/netbirdio/netbird/management/server/activity" - nbpeer "github.com/netbirdio/netbird/management/server/peer" "github.com/netbirdio/netbird/management/server/posture" "github.com/netbirdio/netbird/management/server/status" -) - -const ( - 
errMsgPostureAdminOnly = "only users with admin power are allowed to view posture checks" + "github.com/rs/xid" + "golang.org/x/exp/maps" ) func (am *DefaultAccountManager) GetPostureChecks(ctx context.Context, accountID, postureChecksID, userID string) (*posture.Checks, error) { @@ -20,219 +18,284 @@ func (am *DefaultAccountManager) GetPostureChecks(ctx context.Context, accountID return nil, err } - if !user.HasAdminPower() || user.AccountID != accountID { - return nil, status.Errorf(status.PermissionDenied, errMsgPostureAdminOnly) - } - - return am.Store.GetPostureChecksByID(ctx, LockingStrengthShare, postureChecksID, accountID) -} - -func (am *DefaultAccountManager) SavePostureChecks(ctx context.Context, accountID, userID string, postureChecks *posture.Checks) error { - unlock := am.Store.AcquireWriteLockByUID(ctx, accountID) - defer unlock() - - account, err := am.Store.GetAccount(ctx, accountID) - if err != nil { - return err - } - - user, err := account.FindUser(userID) - if err != nil { - return err + if user.AccountID != accountID { + return nil, status.NewUserNotPartOfAccountError() } if !user.HasAdminPower() { - return status.Errorf(status.PermissionDenied, errMsgPostureAdminOnly) + return nil, status.NewAdminPermissionError() } - if err := postureChecks.Validate(); err != nil { - return status.Errorf(status.InvalidArgument, err.Error()) //nolint + return am.Store.GetPostureChecksByID(ctx, LockingStrengthShare, accountID, postureChecksID) +} + +// SavePostureChecks saves a posture check. +func (am *DefaultAccountManager) SavePostureChecks(ctx context.Context, accountID, userID string, postureChecks *posture.Checks) (*posture.Checks, error) { + unlock := am.Store.AcquireWriteLockByUID(ctx, accountID) + defer unlock() + + user, err := am.Store.GetUserByUserID(ctx, LockingStrengthShare, userID) + if err != nil { + return nil, err } - exists, uniqName := am.savePostureChecks(account, postureChecks) - - // we do not allow create new posture checks with non uniq name - if !exists && !uniqName { - return status.Errorf(status.PreconditionFailed, "Posture check name should be unique") + if user.AccountID != accountID { + return nil, status.NewUserNotPartOfAccountError() } - action := activity.PostureCheckCreated - if exists { - action = activity.PostureCheckUpdated - account.Network.IncSerial() + if !user.HasAdminPower() { + return nil, status.NewAdminPermissionError() } - if err = am.Store.SaveAccount(ctx, account); err != nil { - return err + var updateAccountPeers bool + var isUpdate = postureChecks.ID != "" + var action = activity.PostureCheckCreated + + err = am.Store.ExecuteInTransaction(ctx, func(transaction Store) error { + if err = validatePostureChecks(ctx, transaction, accountID, postureChecks); err != nil { + return err + } + + if isUpdate { + updateAccountPeers, err = arePostureCheckChangesAffectPeers(ctx, transaction, accountID, postureChecks.ID) + if err != nil { + return err + } + + if err = transaction.IncrementNetworkSerial(ctx, LockingStrengthUpdate, accountID); err != nil { + return err + } + + action = activity.PostureCheckUpdated + } + + postureChecks.AccountID = accountID + return transaction.SavePostureChecks(ctx, LockingStrengthUpdate, postureChecks) + }) + if err != nil { + return nil, err } am.StoreEvent(ctx, userID, postureChecks.ID, accountID, action, postureChecks.EventMeta()) - if arePostureCheckChangesAffectingPeers(account, postureChecks.ID, exists) { + if updateAccountPeers { am.updateAccountPeers(ctx, accountID) } - return nil + return postureChecks, nil 
} +// DeletePostureChecks deletes a posture check by ID. func (am *DefaultAccountManager) DeletePostureChecks(ctx context.Context, accountID, postureChecksID, userID string) error { unlock := am.Store.AcquireWriteLockByUID(ctx, accountID) defer unlock() - account, err := am.Store.GetAccount(ctx, accountID) + user, err := am.Store.GetUserByUserID(ctx, LockingStrengthShare, userID) if err != nil { return err } - user, err := account.FindUser(userID) - if err != nil { - return err + if user.AccountID != accountID { + return status.NewUserNotPartOfAccountError() } if !user.HasAdminPower() { - return status.Errorf(status.PermissionDenied, errMsgPostureAdminOnly) + return status.NewAdminPermissionError() } - postureChecks, err := am.deletePostureChecks(account, postureChecksID) + var postureChecks *posture.Checks + + err = am.Store.ExecuteInTransaction(ctx, func(transaction Store) error { + postureChecks, err = transaction.GetPostureChecksByID(ctx, LockingStrengthShare, accountID, postureChecksID) + if err != nil { + return err + } + + if err = isPostureCheckLinkedToPolicy(ctx, transaction, postureChecksID, accountID); err != nil { + return err + } + + if err = transaction.IncrementNetworkSerial(ctx, LockingStrengthUpdate, accountID); err != nil { + return err + } + + return transaction.DeletePostureChecks(ctx, LockingStrengthUpdate, accountID, postureChecksID) + }) if err != nil { return err } - if err = am.Store.SaveAccount(ctx, account); err != nil { - return err - } - am.StoreEvent(ctx, userID, postureChecks.ID, accountID, activity.PostureCheckDeleted, postureChecks.EventMeta()) return nil } +// ListPostureChecks returns a list of posture checks. func (am *DefaultAccountManager) ListPostureChecks(ctx context.Context, accountID, userID string) ([]*posture.Checks, error) { user, err := am.Store.GetUserByUserID(ctx, LockingStrengthShare, userID) if err != nil { return nil, err } - if !user.HasAdminPower() || user.AccountID != accountID { - return nil, status.Errorf(status.PermissionDenied, errMsgPostureAdminOnly) + if user.AccountID != accountID { + return nil, status.NewUserNotPartOfAccountError() + } + + if !user.HasAdminPower() { + return nil, status.NewAdminPermissionError() } return am.Store.GetAccountPostureChecks(ctx, LockingStrengthShare, accountID) } -func (am *DefaultAccountManager) savePostureChecks(account *Account, postureChecks *posture.Checks) (exists, uniqName bool) { - uniqName = true - for i, p := range account.PostureChecks { - if !exists && p.ID == postureChecks.ID { - account.PostureChecks[i] = postureChecks - exists = true - } - if p.Name == postureChecks.Name { - uniqName = false - } - } - if !exists { - account.PostureChecks = append(account.PostureChecks, postureChecks) - } - return -} - -func (am *DefaultAccountManager) deletePostureChecks(account *Account, postureChecksID string) (*posture.Checks, error) { - postureChecksIdx := -1 - for i, postureChecks := range account.PostureChecks { - if postureChecks.ID == postureChecksID { - postureChecksIdx = i - break - } - } - if postureChecksIdx < 0 { - return nil, status.Errorf(status.NotFound, "posture checks with ID %s doesn't exist", postureChecksID) - } - - // Check if posture check is linked to any policy - if isLinked, linkedPolicy := isPostureCheckLinkedToPolicy(account, postureChecksID); isLinked { - return nil, status.Errorf(status.PreconditionFailed, "posture checks have been linked to policy: %s", linkedPolicy.Name) - } - - postureChecks := account.PostureChecks[postureChecksIdx] - account.PostureChecks = 
append(account.PostureChecks[:postureChecksIdx], account.PostureChecks[postureChecksIdx+1:]...) - - return postureChecks, nil -} - // getPeerPostureChecks returns the posture checks applied for a given peer. -func (am *DefaultAccountManager) getPeerPostureChecks(account *Account, peer *nbpeer.Peer) []*posture.Checks { - peerPostureChecks := make(map[string]posture.Checks) +func (am *DefaultAccountManager) getPeerPostureChecks(ctx context.Context, accountID string, peerID string) ([]*posture.Checks, error) { + peerPostureChecks := make(map[string]*posture.Checks) - if len(account.PostureChecks) == 0 { + err := am.Store.ExecuteInTransaction(ctx, func(transaction Store) error { + postureChecks, err := transaction.GetAccountPostureChecks(ctx, LockingStrengthShare, accountID) + if err != nil { + return err + } + + if len(postureChecks) == 0 { + return nil + } + + policies, err := transaction.GetAccountPolicies(ctx, LockingStrengthShare, accountID) + if err != nil { + return err + } + + for _, policy := range policies { + if !policy.Enabled { + continue + } + + if err = addPolicyPostureChecks(ctx, transaction, accountID, peerID, policy, peerPostureChecks); err != nil { + return err + } + } + + return nil + }) + if err != nil { + return nil, err + } + + return maps.Values(peerPostureChecks), nil +} + +// arePostureCheckChangesAffectPeers checks if the changes in posture checks are affecting peers. +func arePostureCheckChangesAffectPeers(ctx context.Context, transaction Store, accountID, postureCheckID string) (bool, error) { + policies, err := transaction.GetAccountPolicies(ctx, LockingStrengthShare, accountID) + if err != nil { + return false, err + } + + for _, policy := range policies { + if slices.Contains(policy.SourcePostureChecks, postureCheckID) { + hasPeers, err := anyGroupHasPeers(ctx, transaction, accountID, policy.ruleGroups()) + if err != nil { + return false, err + } + + if hasPeers { + return true, nil + } + } + } + + return false, nil +} + +// validatePostureChecks validates the posture checks. +func validatePostureChecks(ctx context.Context, transaction Store, accountID string, postureChecks *posture.Checks) error { + if err := postureChecks.Validate(); err != nil { + return status.Errorf(status.InvalidArgument, err.Error()) //nolint + } + + // If the posture check already has an ID, verify its existence in the store. + if postureChecks.ID != "" { + if _, err := transaction.GetPostureChecksByID(ctx, LockingStrengthShare, accountID, postureChecks.ID); err != nil { + return err + } return nil } - for _, policy := range account.Policies { - if !policy.Enabled { - continue - } + // For new posture checks, ensure no duplicates by name. + checks, err := transaction.GetAccountPostureChecks(ctx, LockingStrengthShare, accountID) + if err != nil { + return err + } - if isPeerInPolicySourceGroups(peer.ID, account, policy) { - addPolicyPostureChecks(account, policy, peerPostureChecks) + for _, check := range checks { + if check.Name == postureChecks.Name && check.ID != postureChecks.ID { + return status.Errorf(status.InvalidArgument, "posture checks with name %s already exists", postureChecks.Name) } } - postureChecksList := make([]*posture.Checks, 0, len(peerPostureChecks)) - for _, check := range peerPostureChecks { - checkCopy := check - postureChecksList = append(postureChecksList, &checkCopy) + postureChecks.ID = xid.New().String() + + return nil +} + +// addPolicyPostureChecks adds posture checks from a policy to the peer posture checks map if the peer is in the policy's source groups. 
+func addPolicyPostureChecks(ctx context.Context, transaction Store, accountID, peerID string, policy *Policy, peerPostureChecks map[string]*posture.Checks) error { + isInGroup, err := isPeerInPolicySourceGroups(ctx, transaction, accountID, peerID, policy) + if err != nil { + return err } - return postureChecksList + if !isInGroup { + return nil + } + + for _, sourcePostureCheckID := range policy.SourcePostureChecks { + postureCheck, err := transaction.GetPostureChecksByID(ctx, LockingStrengthShare, accountID, sourcePostureCheckID) + if err != nil { + return err + } + peerPostureChecks[sourcePostureCheckID] = postureCheck + } + + return nil } // isPeerInPolicySourceGroups checks if a peer is present in any of the policy rule source groups. -func isPeerInPolicySourceGroups(peerID string, account *Account, policy *Policy) bool { +func isPeerInPolicySourceGroups(ctx context.Context, transaction Store, accountID, peerID string, policy *Policy) (bool, error) { for _, rule := range policy.Rules { if !rule.Enabled { continue } for _, sourceGroup := range rule.Sources { - group, ok := account.Groups[sourceGroup] - if ok && slices.Contains(group.Peers, peerID) { - return true + group, err := transaction.GetGroupByID(ctx, LockingStrengthShare, accountID, sourceGroup) + if err != nil { + return false, fmt.Errorf("failed to check peer in policy source group: %w", err) + } + + if slices.Contains(group.Peers, peerID) { + return true, nil } } } - return false -} - -func addPolicyPostureChecks(account *Account, policy *Policy, peerPostureChecks map[string]posture.Checks) { - for _, sourcePostureCheckID := range policy.SourcePostureChecks { - for _, postureCheck := range account.PostureChecks { - if postureCheck.ID == sourcePostureCheckID { - peerPostureChecks[sourcePostureCheckID] = *postureCheck - } - } - } -} - -func isPostureCheckLinkedToPolicy(account *Account, postureChecksID string) (bool, *Policy) { - for _, policy := range account.Policies { - if slices.Contains(policy.SourcePostureChecks, postureChecksID) { - return true, policy - } - } return false, nil } -// arePostureCheckChangesAffectingPeers checks if the changes in posture checks are affecting peers. -func arePostureCheckChangesAffectingPeers(account *Account, postureCheckID string, exists bool) bool { - if !exists { - return false +// isPostureCheckLinkedToPolicy checks whether the posture check is linked to any account policy. 
+func isPostureCheckLinkedToPolicy(ctx context.Context, transaction Store, postureChecksID, accountID string) error { + policies, err := transaction.GetAccountPolicies(ctx, LockingStrengthShare, accountID) + if err != nil { + return err } - isLinked, linkedPolicy := isPostureCheckLinkedToPolicy(account, postureCheckID) - if !isLinked { - return false + for _, policy := range policies { + if slices.Contains(policy.SourcePostureChecks, postureChecksID) { + return status.Errorf(status.PreconditionFailed, "posture checks have been linked to policy: %s", policy.Name) + } } - return anyGroupHasPeers(account, linkedPolicy.ruleGroups()) + + return nil } diff --git a/management/server/posture_checks_test.go b/management/server/posture_checks_test.go index c63538b9d..3c5c5fc79 100644 --- a/management/server/posture_checks_test.go +++ b/management/server/posture_checks_test.go @@ -7,6 +7,7 @@ import ( "github.com/rs/xid" "github.com/stretchr/testify/assert" + "github.com/stretchr/testify/require" "github.com/netbirdio/netbird/management/server/group" @@ -16,7 +17,6 @@ import ( const ( adminUserID = "adminUserID" regularUserID = "regularUserID" - postureCheckID = "existing-id" postureCheckName = "Existing check" ) @@ -33,7 +33,7 @@ func TestDefaultAccountManager_PostureCheck(t *testing.T) { t.Run("Generic posture check flow", func(t *testing.T) { // regular users can not create checks - err := am.SavePostureChecks(context.Background(), account.Id, regularUserID, &posture.Checks{}) + _, err = am.SavePostureChecks(context.Background(), account.Id, regularUserID, &posture.Checks{}) assert.Error(t, err) // regular users cannot list check @@ -41,8 +41,7 @@ func TestDefaultAccountManager_PostureCheck(t *testing.T) { assert.Error(t, err) // should be possible to create posture check with uniq name - err = am.SavePostureChecks(context.Background(), account.Id, adminUserID, &posture.Checks{ - ID: postureCheckID, + postureCheck, err := am.SavePostureChecks(context.Background(), account.Id, adminUserID, &posture.Checks{ Name: postureCheckName, Checks: posture.ChecksDefinition{ NBVersionCheck: &posture.NBVersionCheck{ @@ -58,8 +57,7 @@ func TestDefaultAccountManager_PostureCheck(t *testing.T) { assert.Len(t, checks, 1) // should not be possible to create posture check with non uniq name - err = am.SavePostureChecks(context.Background(), account.Id, adminUserID, &posture.Checks{ - ID: "new-id", + _, err = am.SavePostureChecks(context.Background(), account.Id, adminUserID, &posture.Checks{ Name: postureCheckName, Checks: posture.ChecksDefinition{ GeoLocationCheck: &posture.GeoLocationCheck{ @@ -74,23 +72,20 @@ func TestDefaultAccountManager_PostureCheck(t *testing.T) { assert.Error(t, err) // admins can update posture checks - err = am.SavePostureChecks(context.Background(), account.Id, adminUserID, &posture.Checks{ - ID: postureCheckID, - Name: postureCheckName, - Checks: posture.ChecksDefinition{ - NBVersionCheck: &posture.NBVersionCheck{ - MinVersion: "0.27.0", - }, + postureCheck.Checks = posture.ChecksDefinition{ + NBVersionCheck: &posture.NBVersionCheck{ + MinVersion: "0.27.0", }, - }) + } + _, err = am.SavePostureChecks(context.Background(), account.Id, adminUserID, postureCheck) assert.NoError(t, err) // users should not be able to delete posture checks - err = am.DeletePostureChecks(context.Background(), account.Id, postureCheckID, regularUserID) + err = am.DeletePostureChecks(context.Background(), account.Id, postureCheck.ID, regularUserID) assert.Error(t, err) // admin should be able to delete posture 
checks - err = am.DeletePostureChecks(context.Background(), account.Id, postureCheckID, adminUserID) + err = am.DeletePostureChecks(context.Background(), account.Id, postureCheck.ID, adminUserID) assert.NoError(t, err) checks, err = am.ListPostureChecks(context.Background(), account.Id, adminUserID) assert.NoError(t, err) @@ -150,9 +145,22 @@ func TestPostureCheckAccountPeersUpdate(t *testing.T) { manager.peersUpdateManager.CloseChannel(context.Background(), peer1.ID) }) - postureCheck := posture.Checks{ - ID: "postureCheck", - Name: "postureCheck", + postureCheckA := &posture.Checks{ + Name: "postureCheckA", + AccountID: account.Id, + Checks: posture.ChecksDefinition{ + ProcessCheck: &posture.ProcessCheck{ + Processes: []posture.Process{ + {LinuxPath: "/usr/bin/netbird", MacPath: "/usr/local/bin/netbird"}, + }, + }, + }, + } + postureCheckA, err = manager.SavePostureChecks(context.Background(), account.Id, userID, postureCheckA) + require.NoError(t, err) + + postureCheckB := &posture.Checks{ + Name: "postureCheckB", AccountID: account.Id, Checks: posture.ChecksDefinition{ NBVersionCheck: &posture.NBVersionCheck{ @@ -169,7 +177,7 @@ func TestPostureCheckAccountPeersUpdate(t *testing.T) { close(done) }() - err := manager.SavePostureChecks(context.Background(), account.Id, userID, &postureCheck) + postureCheckB, err = manager.SavePostureChecks(context.Background(), account.Id, userID, postureCheckB) assert.NoError(t, err) select { @@ -187,12 +195,12 @@ func TestPostureCheckAccountPeersUpdate(t *testing.T) { close(done) }() - postureCheck.Checks = posture.ChecksDefinition{ + postureCheckB.Checks = posture.ChecksDefinition{ NBVersionCheck: &posture.NBVersionCheck{ MinVersion: "0.29.0", }, } - err := manager.SavePostureChecks(context.Background(), account.Id, userID, &postureCheck) + _, err = manager.SavePostureChecks(context.Background(), account.Id, userID, postureCheckB) assert.NoError(t, err) select { @@ -215,7 +223,7 @@ func TestPostureCheckAccountPeersUpdate(t *testing.T) { Action: PolicyTrafficActionAccept, }, }, - SourcePostureChecks: []string{postureCheck.ID}, + SourcePostureChecks: []string{postureCheckB.ID}, } // Linking posture check to policy should trigger update account peers and send peer update @@ -238,7 +246,7 @@ func TestPostureCheckAccountPeersUpdate(t *testing.T) { // Updating linked posture checks should update account peers and send peer update t.Run("updating linked to posture check with peers", func(t *testing.T) { - postureCheck.Checks = posture.ChecksDefinition{ + postureCheckB.Checks = posture.ChecksDefinition{ NBVersionCheck: &posture.NBVersionCheck{ MinVersion: "0.29.0", }, @@ -255,7 +263,7 @@ func TestPostureCheckAccountPeersUpdate(t *testing.T) { close(done) }() - err := manager.SavePostureChecks(context.Background(), account.Id, userID, &postureCheck) + _, err = manager.SavePostureChecks(context.Background(), account.Id, userID, postureCheckB) assert.NoError(t, err) select { @@ -293,7 +301,7 @@ func TestPostureCheckAccountPeersUpdate(t *testing.T) { close(done) }() - err := manager.DeletePostureChecks(context.Background(), account.Id, "postureCheck", userID) + err := manager.DeletePostureChecks(context.Background(), account.Id, postureCheckA.ID, userID) assert.NoError(t, err) select { @@ -303,7 +311,7 @@ func TestPostureCheckAccountPeersUpdate(t *testing.T) { } }) - err = manager.SavePostureChecks(context.Background(), account.Id, userID, &postureCheck) + _, err = manager.SavePostureChecks(context.Background(), account.Id, userID, postureCheckB) 
assert.NoError(t, err) // Updating linked posture check to policy with no peers should not trigger account peers update and not send peer update @@ -321,7 +329,7 @@ func TestPostureCheckAccountPeersUpdate(t *testing.T) { Action: PolicyTrafficActionAccept, }, }, - SourcePostureChecks: []string{postureCheck.ID}, + SourcePostureChecks: []string{postureCheckB.ID}, } err = manager.SavePolicy(context.Background(), account.Id, userID, &policy, false) assert.NoError(t, err) @@ -332,12 +340,12 @@ func TestPostureCheckAccountPeersUpdate(t *testing.T) { close(done) }() - postureCheck.Checks = posture.ChecksDefinition{ + postureCheckB.Checks = posture.ChecksDefinition{ NBVersionCheck: &posture.NBVersionCheck{ MinVersion: "0.29.0", }, } - err := manager.SavePostureChecks(context.Background(), account.Id, userID, &postureCheck) + _, err = manager.SavePostureChecks(context.Background(), account.Id, userID, postureCheckB) assert.NoError(t, err) select { @@ -367,7 +375,7 @@ func TestPostureCheckAccountPeersUpdate(t *testing.T) { Action: PolicyTrafficActionAccept, }, }, - SourcePostureChecks: []string{postureCheck.ID}, + SourcePostureChecks: []string{postureCheckB.ID}, } err = manager.SavePolicy(context.Background(), account.Id, userID, &policy, true) @@ -379,12 +387,12 @@ func TestPostureCheckAccountPeersUpdate(t *testing.T) { close(done) }() - postureCheck.Checks = posture.ChecksDefinition{ + postureCheckB.Checks = posture.ChecksDefinition{ NBVersionCheck: &posture.NBVersionCheck{ MinVersion: "0.29.0", }, } - err := manager.SavePostureChecks(context.Background(), account.Id, userID, &postureCheck) + _, err = manager.SavePostureChecks(context.Background(), account.Id, userID, postureCheckB) assert.NoError(t, err) select { @@ -409,7 +417,7 @@ func TestPostureCheckAccountPeersUpdate(t *testing.T) { Action: PolicyTrafficActionAccept, }, }, - SourcePostureChecks: []string{postureCheck.ID}, + SourcePostureChecks: []string{postureCheckB.ID}, } err = manager.SavePolicy(context.Background(), account.Id, userID, &policy, true) assert.NoError(t, err) @@ -420,7 +428,7 @@ func TestPostureCheckAccountPeersUpdate(t *testing.T) { close(done) }() - postureCheck.Checks = posture.ChecksDefinition{ + postureCheckB.Checks = posture.ChecksDefinition{ ProcessCheck: &posture.ProcessCheck{ Processes: []posture.Process{ { @@ -429,7 +437,7 @@ func TestPostureCheckAccountPeersUpdate(t *testing.T) { }, }, } - err := manager.SavePostureChecks(context.Background(), account.Id, userID, &postureCheck) + _, err = manager.SavePostureChecks(context.Background(), account.Id, userID, postureCheckB) assert.NoError(t, err) select { @@ -440,80 +448,123 @@ func TestPostureCheckAccountPeersUpdate(t *testing.T) { }) } -func TestArePostureCheckChangesAffectingPeers(t *testing.T) { - account := &Account{ - Policies: []*Policy{ - { - ID: "policyA", - Rules: []*PolicyRule{ - { - Enabled: true, - Sources: []string{"groupA"}, - Destinations: []string{"groupA"}, - }, - }, - SourcePostureChecks: []string{"checkA"}, - }, - }, - Groups: map[string]*group.Group{ - "groupA": { - ID: "groupA", - Peers: []string{"peer1"}, - }, - "groupB": { - ID: "groupB", - Peers: []string{}, - }, - }, - PostureChecks: []*posture.Checks{ - { - ID: "checkA", - }, - { - ID: "checkB", - }, - }, +func TestArePostureCheckChangesAffectPeers(t *testing.T) { + manager, err := createManager(t) + require.NoError(t, err, "failed to create account manager") + + account, err := initTestPostureChecksAccount(manager) + require.NoError(t, err, "failed to init testing account") + + groupA := 
&group.Group{ + ID: "groupA", + AccountID: account.Id, + Peers: []string{"peer1"}, } + groupB := &group.Group{ + ID: "groupB", + AccountID: account.Id, + Peers: []string{}, + } + err = manager.Store.SaveGroups(context.Background(), LockingStrengthUpdate, []*group.Group{groupA, groupB}) + require.NoError(t, err, "failed to save groups") + + postureCheckA := &posture.Checks{ + Name: "checkA", + AccountID: account.Id, + Checks: posture.ChecksDefinition{ + NBVersionCheck: &posture.NBVersionCheck{MinVersion: "0.33.1"}, + }, + } + postureCheckA, err = manager.SavePostureChecks(context.Background(), account.Id, adminUserID, postureCheckA) + require.NoError(t, err, "failed to save postureCheckA") + + postureCheckB := &posture.Checks{ + Name: "checkB", + AccountID: account.Id, + Checks: posture.ChecksDefinition{ + NBVersionCheck: &posture.NBVersionCheck{MinVersion: "0.33.1"}, + }, + } + postureCheckB, err = manager.SavePostureChecks(context.Background(), account.Id, adminUserID, postureCheckB) + require.NoError(t, err, "failed to save postureCheckB") + + policy := &Policy{ + ID: "policyA", + AccountID: account.Id, + Rules: []*PolicyRule{ + { + ID: "ruleA", + PolicyID: "policyA", + Enabled: true, + Sources: []string{"groupA"}, + Destinations: []string{"groupA"}, + }, + }, + SourcePostureChecks: []string{postureCheckA.ID}, + } + + err = manager.SavePolicy(context.Background(), account.Id, userID, policy, false) + require.NoError(t, err, "failed to save policy") + t.Run("posture check exists and is linked to policy with peers", func(t *testing.T) { - result := arePostureCheckChangesAffectingPeers(account, "checkA", true) + result, err := arePostureCheckChangesAffectPeers(context.Background(), manager.Store, account.Id, postureCheckA.ID) + require.NoError(t, err) assert.True(t, result) }) t.Run("posture check exists but is not linked to any policy", func(t *testing.T) { - result := arePostureCheckChangesAffectingPeers(account, "checkB", true) + result, err := arePostureCheckChangesAffectPeers(context.Background(), manager.Store, account.Id, postureCheckB.ID) + require.NoError(t, err) assert.False(t, result) }) t.Run("posture check does not exist", func(t *testing.T) { - result := arePostureCheckChangesAffectingPeers(account, "unknown", false) + result, err := arePostureCheckChangesAffectPeers(context.Background(), manager.Store, account.Id, "unknown") + require.NoError(t, err) assert.False(t, result) }) t.Run("posture check is linked to policy with no peers in source groups", func(t *testing.T) { - account.Policies[0].Rules[0].Sources = []string{"groupB"} - account.Policies[0].Rules[0].Destinations = []string{"groupA"} - result := arePostureCheckChangesAffectingPeers(account, "checkA", true) + policy.Rules[0].Sources = []string{"groupB"} + policy.Rules[0].Destinations = []string{"groupA"} + err = manager.SavePolicy(context.Background(), account.Id, userID, policy, true) + require.NoError(t, err, "failed to update policy") + + result, err := arePostureCheckChangesAffectPeers(context.Background(), manager.Store, account.Id, postureCheckA.ID) + require.NoError(t, err) assert.True(t, result) }) t.Run("posture check is linked to policy with no peers in destination groups", func(t *testing.T) { - account.Policies[0].Rules[0].Sources = []string{"groupA"} - account.Policies[0].Rules[0].Destinations = []string{"groupB"} - result := arePostureCheckChangesAffectingPeers(account, "checkA", true) + policy.Rules[0].Sources = []string{"groupA"} + policy.Rules[0].Destinations = []string{"groupB"} + err = 
manager.SavePolicy(context.Background(), account.Id, userID, policy, true) + require.NoError(t, err, "failed to update policy") + + result, err := arePostureCheckChangesAffectPeers(context.Background(), manager.Store, account.Id, postureCheckA.ID) + require.NoError(t, err) assert.True(t, result) }) - t.Run("posture check is linked to policy with non-existent group", func(t *testing.T) { - account.Policies[0].Rules[0].Sources = []string{"nonExistentGroup"} - account.Policies[0].Rules[0].Destinations = []string{"nonExistentGroup"} - result := arePostureCheckChangesAffectingPeers(account, "checkA", true) + t.Run("posture check is linked to policy but no peers in groups", func(t *testing.T) { + groupA.Peers = []string{} + err = manager.Store.SaveGroup(context.Background(), LockingStrengthUpdate, groupA) + require.NoError(t, err, "failed to save groups") + + result, err := arePostureCheckChangesAffectPeers(context.Background(), manager.Store, account.Id, postureCheckA.ID) + require.NoError(t, err) assert.False(t, result) }) - t.Run("posture check is linked to policy but no peers in groups", func(t *testing.T) { - account.Groups["groupA"].Peers = []string{} - result := arePostureCheckChangesAffectingPeers(account, "checkA", true) + t.Run("posture check is linked to policy with non-existent group", func(t *testing.T) { + policy.Rules[0].Sources = []string{"nonExistentGroup"} + policy.Rules[0].Destinations = []string{"nonExistentGroup"} + err = manager.SavePolicy(context.Background(), account.Id, userID, policy, true) + require.NoError(t, err, "failed to update policy") + + result, err := arePostureCheckChangesAffectPeers(context.Background(), manager.Store, account.Id, postureCheckA.ID) + require.NoError(t, err) assert.False(t, result) }) } diff --git a/management/server/route.go b/management/server/route.go index dcf2cb0d3..ecb562645 100644 --- a/management/server/route.go +++ b/management/server/route.go @@ -237,7 +237,7 @@ func (am *DefaultAccountManager) CreateRoute(ctx context.Context, accountID stri return nil, err } - if isRouteChangeAffectPeers(account, &newRoute) { + if am.isRouteChangeAffectPeers(account, &newRoute) { am.updateAccountPeers(ctx, accountID) } @@ -323,7 +323,7 @@ func (am *DefaultAccountManager) SaveRoute(ctx context.Context, accountID, userI return err } - if isRouteChangeAffectPeers(account, oldRoute) || isRouteChangeAffectPeers(account, routeToSave) { + if am.isRouteChangeAffectPeers(account, oldRoute) || am.isRouteChangeAffectPeers(account, routeToSave) { am.updateAccountPeers(ctx, accountID) } @@ -355,7 +355,7 @@ func (am *DefaultAccountManager) DeleteRoute(ctx context.Context, accountID stri am.StoreEvent(ctx, userID, string(routy.ID), accountID, activity.RouteRemoved, routy.EventMeta()) - if isRouteChangeAffectPeers(account, routy) { + if am.isRouteChangeAffectPeers(account, routy) { am.updateAccountPeers(ctx, accountID) } @@ -651,6 +651,6 @@ func getProtoPortInfo(rule *RouteFirewallRule) *proto.PortInfo { // isRouteChangeAffectPeers checks if a given route affects peers by determining // if it has a routing peer, distribution, or peer groups that include peers -func isRouteChangeAffectPeers(account *Account, route *route.Route) bool { - return anyGroupHasPeers(account, route.Groups) || anyGroupHasPeers(account, route.PeerGroups) || route.Peer != "" +func (am *DefaultAccountManager) isRouteChangeAffectPeers(account *Account, route *route.Route) bool { + return am.anyGroupHasPeers(account, route.Groups) || am.anyGroupHasPeers(account, route.PeerGroups) || route.Peer 
!= "" } diff --git a/management/server/sql_store.go b/management/server/sql_store.go index 278f5443d..47c17bb92 100644 --- a/management/server/sql_store.go +++ b/management/server/sql_store.go @@ -1305,12 +1305,57 @@ func (s *SqlStore) GetPolicyByID(ctx context.Context, lockStrength LockingStreng // GetAccountPostureChecks retrieves posture checks for an account. func (s *SqlStore) GetAccountPostureChecks(ctx context.Context, lockStrength LockingStrength, accountID string) ([]*posture.Checks, error) { - return getRecords[*posture.Checks](s.db, lockStrength, accountID) + var postureChecks []*posture.Checks + result := s.db.Clauses(clause.Locking{Strength: string(lockStrength)}).Find(&postureChecks, accountIDCondition, accountID) + if result.Error != nil { + log.WithContext(ctx).Errorf("failed to get posture checks from store: %s", result.Error) + return nil, status.Errorf(status.Internal, "failed to get posture checks from store") + } + + return postureChecks, nil } // GetPostureChecksByID retrieves posture checks by their ID and account ID. -func (s *SqlStore) GetPostureChecksByID(ctx context.Context, lockStrength LockingStrength, postureCheckID string, accountID string) (*posture.Checks, error) { - return getRecordByID[posture.Checks](s.db, lockStrength, postureCheckID, accountID) +func (s *SqlStore) GetPostureChecksByID(ctx context.Context, lockStrength LockingStrength, accountID, postureChecksID string) (*posture.Checks, error) { + var postureCheck *posture.Checks + result := s.db.Clauses(clause.Locking{Strength: string(lockStrength)}). + First(&postureCheck, accountAndIDQueryCondition, accountID, postureChecksID) + if result.Error != nil { + if errors.Is(result.Error, gorm.ErrRecordNotFound) { + return nil, status.NewPostureChecksNotFoundError(postureChecksID) + } + log.WithContext(ctx).Errorf("failed to get posture check from store: %s", result.Error) + return nil, status.Errorf(status.Internal, "failed to get posture check from store") + } + + return postureCheck, nil +} + +// SavePostureChecks saves a posture checks to the database. +func (s *SqlStore) SavePostureChecks(ctx context.Context, lockStrength LockingStrength, postureCheck *posture.Checks) error { + result := s.db.Clauses(clause.Locking{Strength: string(lockStrength)}).Save(postureCheck) + if result.Error != nil { + log.WithContext(ctx).Errorf("failed to save posture checks to store: %s", result.Error) + return status.Errorf(status.Internal, "failed to save posture checks to store") + } + + return nil +} + +// DeletePostureChecks deletes a posture checks from the database. +func (s *SqlStore) DeletePostureChecks(ctx context.Context, lockStrength LockingStrength, accountID, postureChecksID string) error { + result := s.db.Clauses(clause.Locking{Strength: string(lockStrength)}). + Delete(&posture.Checks{}, accountAndIDQueryCondition, accountID, postureChecksID) + if result.Error != nil { + log.WithContext(ctx).Errorf("failed to delete posture checks from store: %s", result.Error) + return status.Errorf(status.Internal, "failed to delete posture checks from store") + } + + if result.RowsAffected == 0 { + return status.NewPostureChecksNotFoundError(postureChecksID) + } + + return nil } // GetAccountRoutes retrieves network routes for an account. 
diff --git a/management/server/sql_store_test.go b/management/server/sql_store_test.go index 114da1ee6..de939e8d0 100644 --- a/management/server/sql_store_test.go +++ b/management/server/sql_store_test.go @@ -16,6 +16,7 @@ import ( "github.com/google/uuid" nbdns "github.com/netbirdio/netbird/dns" nbgroup "github.com/netbirdio/netbird/management/server/group" + "github.com/netbirdio/netbird/management/server/posture" "github.com/stretchr/testify/assert" "github.com/stretchr/testify/require" @@ -1564,3 +1565,137 @@ func TestSqlStore_GetPeersByIDs(t *testing.T) { }) } } + +func TestSqlStore_GetPostureChecksByID(t *testing.T) { + store, cleanup, err := NewTestStoreFromSQL(context.Background(), "testdata/extended-store.sql", t.TempDir()) + t.Cleanup(cleanup) + require.NoError(t, err) + + accountID := "bf1c8084-ba50-4ce7-9439-34653001fc3b" + tests := []struct { + name string + postureChecksID string + expectError bool + }{ + { + name: "retrieve existing posture checks", + postureChecksID: "csplshq7qv948l48f7t0", + expectError: false, + }, + { + name: "retrieve non-existing posture checks", + postureChecksID: "non-existing", + expectError: true, + }, + { + name: "retrieve with empty posture checks ID", + postureChecksID: "", + expectError: true, + }, + } + + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + postureChecks, err := store.GetPostureChecksByID(context.Background(), LockingStrengthShare, accountID, tt.postureChecksID) + if tt.expectError { + require.Error(t, err) + sErr, ok := status.FromError(err) + require.True(t, ok) + require.Equal(t, sErr.Type(), status.NotFound) + require.Nil(t, postureChecks) + } else { + require.NoError(t, err) + require.NotNil(t, postureChecks) + require.Equal(t, tt.postureChecksID, postureChecks.ID) + } + }) + } +} + +func TestSqlStore_SavePostureChecks(t *testing.T) { + store, cleanup, err := NewTestStoreFromSQL(context.Background(), "testdata/extended-store.sql", t.TempDir()) + t.Cleanup(cleanup) + require.NoError(t, err) + + accountID := "bf1c8084-ba50-4ce7-9439-34653001fc3b" + + postureChecks := &posture.Checks{ + ID: "posture-checks-id", + AccountID: accountID, + Checks: posture.ChecksDefinition{ + NBVersionCheck: &posture.NBVersionCheck{ + MinVersion: "0.31.0", + }, + OSVersionCheck: &posture.OSVersionCheck{ + Ios: &posture.MinVersionCheck{ + MinVersion: "13.0.1", + }, + Linux: &posture.MinKernelVersionCheck{ + MinKernelVersion: "5.3.3-dev", + }, + }, + GeoLocationCheck: &posture.GeoLocationCheck{ + Locations: []posture.Location{ + { + CountryCode: "DE", + CityName: "Berlin", + }, + }, + Action: posture.CheckActionAllow, + }, + }, + } + err = store.SavePostureChecks(context.Background(), LockingStrengthUpdate, postureChecks) + require.NoError(t, err) + + savePostureChecks, err := store.GetPostureChecksByID(context.Background(), LockingStrengthShare, accountID, "posture-checks-id") + require.NoError(t, err) + require.Equal(t, savePostureChecks, postureChecks) +} + +func TestSqlStore_DeletePostureChecks(t *testing.T) { + store, cleanup, err := NewTestStoreFromSQL(context.Background(), "testdata/extended-store.sql", t.TempDir()) + t.Cleanup(cleanup) + require.NoError(t, err) + + accountID := "bf1c8084-ba50-4ce7-9439-34653001fc3b" + + tests := []struct { + name string + postureChecksID string + expectError bool + }{ + { + name: "delete existing posture checks", + postureChecksID: "csplshq7qv948l48f7t0", + expectError: false, + }, + { + name: "delete non-existing posture checks", + postureChecksID: "non-existing-posture-checks-id", + 
expectError: true, + }, + { + name: "delete with empty posture checks ID", + postureChecksID: "", + expectError: true, + }, + } + + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + err = store.DeletePostureChecks(context.Background(), LockingStrengthUpdate, accountID, tt.postureChecksID) + if tt.expectError { + require.Error(t, err) + sErr, ok := status.FromError(err) + require.True(t, ok) + require.Equal(t, sErr.Type(), status.NotFound) + } else { + require.NoError(t, err) + group, err := store.GetPostureChecksByID(context.Background(), LockingStrengthShare, accountID, tt.postureChecksID) + require.Error(t, err) + require.Nil(t, group) + } + }) + } +} diff --git a/management/server/status/error.go b/management/server/status/error.go index 8b6d0077b..44391e1f1 100644 --- a/management/server/status/error.go +++ b/management/server/status/error.go @@ -139,3 +139,8 @@ func NewGetAccountError(err error) error { func NewGroupNotFoundError(groupID string) error { return Errorf(NotFound, "group: %s not found", groupID) } + +// NewPostureChecksNotFoundError creates a new Error with NotFound type for a missing posture checks +func NewPostureChecksNotFoundError(postureChecksID string) error { + return Errorf(NotFound, "posture checks: %s not found", postureChecksID) +} diff --git a/management/server/store.go b/management/server/store.go index 71b0d457b..03b5821e7 100644 --- a/management/server/store.go +++ b/management/server/store.go @@ -84,7 +84,9 @@ type Store interface { GetPostureCheckByChecksDefinition(accountID string, checks *posture.ChecksDefinition) (*posture.Checks, error) GetAccountPostureChecks(ctx context.Context, lockStrength LockingStrength, accountID string) ([]*posture.Checks, error) - GetPostureChecksByID(ctx context.Context, lockStrength LockingStrength, postureCheckID string, accountID string) (*posture.Checks, error) + GetPostureChecksByID(ctx context.Context, lockStrength LockingStrength, accountID, postureCheckID string) (*posture.Checks, error) + SavePostureChecks(ctx context.Context, lockStrength LockingStrength, postureCheck *posture.Checks) error + DeletePostureChecks(ctx context.Context, lockStrength LockingStrength, accountID, postureChecksID string) error GetPeerLabelsInAccount(ctx context.Context, lockStrength LockingStrength, accountId string) ([]string, error) AddPeerToAllGroup(ctx context.Context, accountID string, peerID string) error diff --git a/management/server/testdata/extended-store.sql b/management/server/testdata/extended-store.sql index b522741e7..1646ff4da 100644 --- a/management/server/testdata/extended-store.sql +++ b/management/server/testdata/extended-store.sql @@ -34,4 +34,5 @@ INSERT INTO personal_access_tokens VALUES('9dj38s35-63fb-11ec-90d6-0242ac120003' INSERT INTO "groups" VALUES('cfefqs706sqkneg59g4g','bf1c8084-ba50-4ce7-9439-34653001fc3b','All','api','[]',0,''); INSERT INTO "groups" VALUES('cfefqs706sqkneg59g3g','bf1c8084-ba50-4ce7-9439-34653001fc3b','AwesomeGroup1','api','[]',0,''); INSERT INTO "groups" VALUES('cfefqs706sqkneg59g2g','bf1c8084-ba50-4ce7-9439-34653001fc3b','AwesomeGroup2','api','[]',0,''); +INSERT INTO posture_checks VALUES('csplshq7qv948l48f7t0','NetBird Version > 0.32.0','','bf1c8084-ba50-4ce7-9439-34653001fc3b','{"NBVersionCheck":{"MinVersion":"0.31.0"}}'); INSERT INTO installations VALUES(1,''); From f118d81d3219b78b689206523e4159fcb495fa12 Mon Sep 17 00:00:00 2001 From: Bethuel Mmbaga Date: Tue, 26 Nov 2024 12:46:05 +0300 Subject: [PATCH 039/159] [management] Refactor policy to use store methods (#2878) 
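Patch 039 below reworks SavePolicy to return the stored policy and to decide create-versus-update from whether the incoming policy already carries an ID, generating a fresh xid on creation, the same convention the posture-checks refactor above applies. A minimal sketch of that decision follows; it assumes a trimmed stand-in Policy type in place of the management server's, and only illustrates the ID handling, not the transaction or validation steps.

package main

import (
	"fmt"

	"github.com/rs/xid"
)

// Policy is a trimmed stand-in for the management server's Policy type.
type Policy struct {
	ID        string
	AccountID string
	Name      string
}

// preparePolicy mirrors the create-vs-update decision in the refactored
// SavePolicy: an empty ID marks a new policy and gets a fresh xid,
// a non-empty ID marks an update of an existing record.
func preparePolicy(accountID string, p *Policy) (isUpdate bool) {
	if p.ID != "" {
		return true
	}
	p.ID = xid.New().String()
	p.AccountID = accountID
	return false
}

func main() {
	newPolicy := &Policy{Name: "allow-lan"}
	fmt.Println("update?", preparePolicy("account-1", newPolicy), "id:", newPolicy.ID)

	existing := &Policy{ID: "cs0abcd1234567890abc", AccountID: "account-1", Name: "allow-lan"}
	fmt.Println("update?", preparePolicy("account-1", existing), "id:", existing.ID)
}

In the refactored validatePolicy further down, the generated policy ID is also reused as the single rule's ID and PolicyID, as the TODO comment in the diff notes.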
--- management/server/account.go | 2 +- management/server/account_test.go | 57 ++-- management/server/group_test.go | 5 +- management/server/http/policies_handler.go | 26 +- .../server/http/policies_handler_test.go | 4 +- management/server/mock_server/account_mock.go | 8 +- management/server/peer_test.go | 31 +- management/server/policy.go | 319 +++++++++++------- management/server/policy_test.go | 165 +++------ management/server/posture_checks_test.go | 43 +-- management/server/route_test.go | 3 +- management/server/setupkey_test.go | 5 +- management/server/sql_store.go | 82 ++++- management/server/sql_store_test.go | 158 +++++++++ management/server/status/error.go | 5 + management/server/store.go | 6 +- management/server/testdata/extended-store.sql | 1 + management/server/user_test.go | 5 +- 18 files changed, 576 insertions(+), 349 deletions(-) diff --git a/management/server/account.go b/management/server/account.go index 9fb56c855..fbe6fcc1a 100644 --- a/management/server/account.go +++ b/management/server/account.go @@ -113,7 +113,7 @@ type AccountManager interface { GroupAddPeer(ctx context.Context, accountId, groupID, peerID string) error GroupDeletePeer(ctx context.Context, accountId, groupID, peerID string) error GetPolicy(ctx context.Context, accountID, policyID, userID string) (*Policy, error) - SavePolicy(ctx context.Context, accountID, userID string, policy *Policy, isUpdate bool) error + SavePolicy(ctx context.Context, accountID, userID string, policy *Policy) (*Policy, error) DeletePolicy(ctx context.Context, accountID, policyID, userID string) error ListPolicies(ctx context.Context, accountID, userID string) ([]*Policy, error) GetRoute(ctx context.Context, accountID string, routeID route.ID, userID string) (*route.Route, error) diff --git a/management/server/account_test.go b/management/server/account_test.go index 97e0d45f0..c8c2d5941 100644 --- a/management/server/account_test.go +++ b/management/server/account_test.go @@ -1238,8 +1238,7 @@ func TestAccountManager_NetworkUpdates_SaveGroup(t *testing.T) { return } - policy := Policy{ - ID: "policy", + _, err := manager.SavePolicy(context.Background(), account.Id, userID, &Policy{ Enabled: true, Rules: []*PolicyRule{ { @@ -1250,8 +1249,7 @@ func TestAccountManager_NetworkUpdates_SaveGroup(t *testing.T) { Action: PolicyTrafficActionAccept, }, }, - } - err := manager.SavePolicy(context.Background(), account.Id, userID, &policy, false) + }) require.NoError(t, err) updMsg := manager.peersUpdateManager.CreateChannel(context.Background(), peer1.ID) @@ -1320,19 +1318,6 @@ func TestAccountManager_NetworkUpdates_SavePolicy(t *testing.T) { updMsg := manager.peersUpdateManager.CreateChannel(context.Background(), peer1.ID) defer manager.peersUpdateManager.CloseChannel(context.Background(), peer1.ID) - policy := Policy{ - Enabled: true, - Rules: []*PolicyRule{ - { - Enabled: true, - Sources: []string{"groupA"}, - Destinations: []string{"groupA"}, - Bidirectional: true, - Action: PolicyTrafficActionAccept, - }, - }, - } - wg := sync.WaitGroup{} wg.Add(1) go func() { @@ -1345,7 +1330,19 @@ func TestAccountManager_NetworkUpdates_SavePolicy(t *testing.T) { } }() - if err := manager.SavePolicy(context.Background(), account.Id, userID, &policy, false); err != nil { + _, err := manager.SavePolicy(context.Background(), account.Id, userID, &Policy{ + Enabled: true, + Rules: []*PolicyRule{ + { + Enabled: true, + Sources: []string{"groupA"}, + Destinations: []string{"groupA"}, + Bidirectional: true, + Action: PolicyTrafficActionAccept, + }, + }, + 
}) + if err != nil { t.Errorf("delete default rule: %v", err) return } @@ -1366,7 +1363,7 @@ func TestAccountManager_NetworkUpdates_DeletePeer(t *testing.T) { return } - policy := Policy{ + _, err := manager.SavePolicy(context.Background(), account.Id, userID, &Policy{ Enabled: true, Rules: []*PolicyRule{ { @@ -1377,9 +1374,8 @@ func TestAccountManager_NetworkUpdates_DeletePeer(t *testing.T) { Action: PolicyTrafficActionAccept, }, }, - } - - if err := manager.SavePolicy(context.Background(), account.Id, userID, &policy, false); err != nil { + }) + if err != nil { t.Errorf("save policy: %v", err) return } @@ -1421,7 +1417,12 @@ func TestAccountManager_NetworkUpdates_DeleteGroup(t *testing.T) { require.NoError(t, err, "failed to save group") - policy := Policy{ + if err := manager.DeletePolicy(context.Background(), account.Id, account.Policies[0].ID, userID); err != nil { + t.Errorf("delete default rule: %v", err) + return + } + + policy, err := manager.SavePolicy(context.Background(), account.Id, userID, &Policy{ Enabled: true, Rules: []*PolicyRule{ { @@ -1432,14 +1433,8 @@ func TestAccountManager_NetworkUpdates_DeleteGroup(t *testing.T) { Action: PolicyTrafficActionAccept, }, }, - } - - if err := manager.DeletePolicy(context.Background(), account.Id, account.Policies[0].ID, userID); err != nil { - t.Errorf("delete default rule: %v", err) - return - } - - if err := manager.SavePolicy(context.Background(), account.Id, userID, &policy, false); err != nil { + }) + if err != nil { t.Errorf("save policy: %v", err) return } diff --git a/management/server/group_test.go b/management/server/group_test.go index 59094a23e..ec017fc57 100644 --- a/management/server/group_test.go +++ b/management/server/group_test.go @@ -500,8 +500,7 @@ func TestGroupAccountPeersUpdate(t *testing.T) { }) // adding a group to policy - err = manager.SavePolicy(context.Background(), account.Id, userID, &Policy{ - ID: "policy", + _, err = manager.SavePolicy(context.Background(), account.Id, userID, &Policy{ Enabled: true, Rules: []*PolicyRule{ { @@ -512,7 +511,7 @@ func TestGroupAccountPeersUpdate(t *testing.T) { Action: PolicyTrafficActionAccept, }, }, - }, false) + }) assert.NoError(t, err) // Saving a group linked to policy should update account peers and send peer update diff --git a/management/server/http/policies_handler.go b/management/server/http/policies_handler.go index 73f3803b5..eff9092d4 100644 --- a/management/server/http/policies_handler.go +++ b/management/server/http/policies_handler.go @@ -6,10 +6,8 @@ import ( "strconv" "github.com/gorilla/mux" - nbgroup "github.com/netbirdio/netbird/management/server/group" - "github.com/rs/xid" - "github.com/netbirdio/netbird/management/server" + nbgroup "github.com/netbirdio/netbird/management/server/group" "github.com/netbirdio/netbird/management/server/http/api" "github.com/netbirdio/netbird/management/server/http/util" "github.com/netbirdio/netbird/management/server/jwtclaims" @@ -122,21 +120,22 @@ func (h *Policies) savePolicy(w http.ResponseWriter, r *http.Request, accountID return } - isUpdate := policyID != "" - - if policyID == "" { - policyID = xid.New().String() - } - - policy := server.Policy{ + policy := &server.Policy{ ID: policyID, + AccountID: accountID, Name: req.Name, Enabled: req.Enabled, Description: req.Description, } for _, rule := range req.Rules { + var ruleID string + if rule.Id != nil { + ruleID = *rule.Id + } + pr := server.PolicyRule{ - ID: policyID, // TODO: when policy can contain multiple rules, need refactor + ID: ruleID, + PolicyID: 
policyID, Name: rule.Name, Destinations: rule.Destinations, Sources: rule.Sources, @@ -225,7 +224,8 @@ func (h *Policies) savePolicy(w http.ResponseWriter, r *http.Request, accountID policy.SourcePostureChecks = *req.SourcePostureChecks } - if err := h.accountManager.SavePolicy(r.Context(), accountID, userID, &policy, isUpdate); err != nil { + policy, err := h.accountManager.SavePolicy(r.Context(), accountID, userID, policy) + if err != nil { util.WriteError(r.Context(), err, w) return } @@ -236,7 +236,7 @@ func (h *Policies) savePolicy(w http.ResponseWriter, r *http.Request, accountID return } - resp := toPolicyResponse(allGroups, &policy) + resp := toPolicyResponse(allGroups, policy) if len(resp.Rules) == 0 { util.WriteError(r.Context(), status.Errorf(status.Internal, "no rules in the policy"), w) return diff --git a/management/server/http/policies_handler_test.go b/management/server/http/policies_handler_test.go index 228ebcbce..f8a897eb2 100644 --- a/management/server/http/policies_handler_test.go +++ b/management/server/http/policies_handler_test.go @@ -38,12 +38,12 @@ func initPoliciesTestData(policies ...*server.Policy) *Policies { } return policy, nil }, - SavePolicyFunc: func(_ context.Context, _, _ string, policy *server.Policy, _ bool) error { + SavePolicyFunc: func(_ context.Context, _, _ string, policy *server.Policy) (*server.Policy, error) { if !strings.HasPrefix(policy.ID, "id-") { policy.ID = "id-was-set" policy.Rules[0].ID = "id-was-set" } - return nil + return policy, nil }, GetAllGroupsFunc: func(ctx context.Context, accountID, userID string) ([]*nbgroup.Group, error) { return []*nbgroup.Group{{ID: "F"}, {ID: "G"}}, nil diff --git a/management/server/mock_server/account_mock.go b/management/server/mock_server/account_mock.go index 673ed33bb..46a4fbc1f 100644 --- a/management/server/mock_server/account_mock.go +++ b/management/server/mock_server/account_mock.go @@ -49,7 +49,7 @@ type MockAccountManager struct { GroupDeletePeerFunc func(ctx context.Context, accountID, groupID, peerID string) error DeleteRuleFunc func(ctx context.Context, accountID, ruleID, userID string) error GetPolicyFunc func(ctx context.Context, accountID, policyID, userID string) (*server.Policy, error) - SavePolicyFunc func(ctx context.Context, accountID, userID string, policy *server.Policy, isUpdate bool) error + SavePolicyFunc func(ctx context.Context, accountID, userID string, policy *server.Policy) (*server.Policy, error) DeletePolicyFunc func(ctx context.Context, accountID, policyID, userID string) error ListPoliciesFunc func(ctx context.Context, accountID, userID string) ([]*server.Policy, error) GetUsersFromAccountFunc func(ctx context.Context, accountID, userID string) ([]*server.UserInfo, error) @@ -386,11 +386,11 @@ func (am *MockAccountManager) GetPolicy(ctx context.Context, accountID, policyID } // SavePolicy mock implementation of SavePolicy from server.AccountManager interface -func (am *MockAccountManager) SavePolicy(ctx context.Context, accountID, userID string, policy *server.Policy, isUpdate bool) error { +func (am *MockAccountManager) SavePolicy(ctx context.Context, accountID, userID string, policy *server.Policy) (*server.Policy, error) { if am.SavePolicyFunc != nil { - return am.SavePolicyFunc(ctx, accountID, userID, policy, isUpdate) + return am.SavePolicyFunc(ctx, accountID, userID, policy) } - return status.Errorf(codes.Unimplemented, "method SavePolicy is not implemented") + return nil, status.Errorf(codes.Unimplemented, "method SavePolicy is not implemented") } // 
DeletePolicy mock implementation of DeletePolicy from server.AccountManager interface diff --git a/management/server/peer_test.go b/management/server/peer_test.go index 4e2dcb2c3..e410fa892 100644 --- a/management/server/peer_test.go +++ b/management/server/peer_test.go @@ -283,14 +283,12 @@ func TestAccountManager_GetNetworkMapWithPolicy(t *testing.T) { var ( group1 nbgroup.Group group2 nbgroup.Group - policy Policy ) group1.ID = xid.New().String() group2.ID = xid.New().String() group1.Name = "src" group2.Name = "dst" - policy.ID = xid.New().String() group1.Peers = append(group1.Peers, peer1.ID) group2.Peers = append(group2.Peers, peer2.ID) @@ -305,18 +303,20 @@ func TestAccountManager_GetNetworkMapWithPolicy(t *testing.T) { return } - policy.Name = "test" - policy.Enabled = true - policy.Rules = []*PolicyRule{ - { - Enabled: true, - Sources: []string{group1.ID}, - Destinations: []string{group2.ID}, - Bidirectional: true, - Action: PolicyTrafficActionAccept, + policy := &Policy{ + Name: "test", + Enabled: true, + Rules: []*PolicyRule{ + { + Enabled: true, + Sources: []string{group1.ID}, + Destinations: []string{group2.ID}, + Bidirectional: true, + Action: PolicyTrafficActionAccept, + }, }, } - err = manager.SavePolicy(context.Background(), account.Id, userID, &policy, false) + policy, err = manager.SavePolicy(context.Background(), account.Id, userID, policy) if err != nil { t.Errorf("expecting rule to be added, got failure %v", err) return @@ -364,7 +364,7 @@ func TestAccountManager_GetNetworkMapWithPolicy(t *testing.T) { } policy.Enabled = false - err = manager.SavePolicy(context.Background(), account.Id, userID, &policy, true) + _, err = manager.SavePolicy(context.Background(), account.Id, userID, policy) if err != nil { t.Errorf("expecting rule to be added, got failure %v", err) return @@ -1445,8 +1445,7 @@ func TestPeerAccountPeersUpdate(t *testing.T) { // Adding peer to group linked with policy should update account peers and send peer update t.Run("adding peer to group linked with policy", func(t *testing.T) { - err = manager.SavePolicy(context.Background(), account.Id, userID, &Policy{ - ID: "policy", + _, err = manager.SavePolicy(context.Background(), account.Id, userID, &Policy{ Enabled: true, Rules: []*PolicyRule{ { @@ -1457,7 +1456,7 @@ func TestPeerAccountPeersUpdate(t *testing.T) { Action: PolicyTrafficActionAccept, }, }, - }, false) + }) require.NoError(t, err) done := make(chan struct{}) diff --git a/management/server/policy.go b/management/server/policy.go index c7872591d..2d3abc3f1 100644 --- a/management/server/policy.go +++ b/management/server/policy.go @@ -3,13 +3,13 @@ package server import ( "context" _ "embed" - "slices" "strconv" "strings" + "github.com/netbirdio/netbird/management/proto" + "github.com/rs/xid" log "github.com/sirupsen/logrus" - "github.com/netbirdio/netbird/management/proto" "github.com/netbirdio/netbird/management/server/activity" nbgroup "github.com/netbirdio/netbird/management/server/group" nbpeer "github.com/netbirdio/netbird/management/server/peer" @@ -125,6 +125,7 @@ type PolicyRule struct { func (pm *PolicyRule) Copy() *PolicyRule { rule := &PolicyRule{ ID: pm.ID, + PolicyID: pm.PolicyID, Name: pm.Name, Description: pm.Description, Enabled: pm.Enabled, @@ -171,6 +172,7 @@ type Policy struct { func (p *Policy) Copy() *Policy { c := &Policy{ ID: p.ID, + AccountID: p.AccountID, Name: p.Name, Description: p.Description, Enabled: p.Enabled, @@ -343,157 +345,207 @@ func (am *DefaultAccountManager) GetPolicy(ctx context.Context, accountID, polic 
return nil, err } - if !user.IsAdminOrServiceUser() || user.AccountID != accountID { - return nil, status.Errorf(status.PermissionDenied, "only users with admin power are allowed to view policies") + if user.AccountID != accountID { + return nil, status.NewUserNotPartOfAccountError() } - return am.Store.GetPolicyByID(ctx, LockingStrengthShare, policyID, accountID) + if user.IsRegularUser() { + return nil, status.NewAdminPermissionError() + } + + return am.Store.GetPolicyByID(ctx, LockingStrengthShare, accountID, policyID) } // SavePolicy in the store -func (am *DefaultAccountManager) SavePolicy(ctx context.Context, accountID, userID string, policy *Policy, isUpdate bool) error { +func (am *DefaultAccountManager) SavePolicy(ctx context.Context, accountID, userID string, policy *Policy) (*Policy, error) { unlock := am.Store.AcquireWriteLockByUID(ctx, accountID) defer unlock() - account, err := am.Store.GetAccount(ctx, accountID) + user, err := am.Store.GetUserByUserID(ctx, LockingStrengthShare, userID) if err != nil { - return err + return nil, err } - updateAccountPeers, err := am.savePolicy(account, policy, isUpdate) + if user.AccountID != accountID { + return nil, status.NewUserNotPartOfAccountError() + } + + if user.IsRegularUser() { + return nil, status.NewAdminPermissionError() + } + + var isUpdate = policy.ID != "" + var updateAccountPeers bool + var action = activity.PolicyAdded + + err = am.Store.ExecuteInTransaction(ctx, func(transaction Store) error { + if err = validatePolicy(ctx, transaction, accountID, policy); err != nil { + return err + } + + updateAccountPeers, err = arePolicyChangesAffectPeers(ctx, transaction, accountID, policy, isUpdate) + if err != nil { + return err + } + + if err = transaction.IncrementNetworkSerial(ctx, LockingStrengthUpdate, accountID); err != nil { + return err + } + + saveFunc := transaction.CreatePolicy + if isUpdate { + action = activity.PolicyUpdated + saveFunc = transaction.SavePolicy + } + + return saveFunc(ctx, LockingStrengthUpdate, policy) + }) if err != nil { - return err + return nil, err } - account.Network.IncSerial() - if err = am.Store.SaveAccount(ctx, account); err != nil { - return err - } - - action := activity.PolicyAdded - if isUpdate { - action = activity.PolicyUpdated - } am.StoreEvent(ctx, userID, policy.ID, accountID, action, policy.EventMeta()) if updateAccountPeers { am.updateAccountPeers(ctx, accountID) } + return policy, nil +} + +// DeletePolicy from the store +func (am *DefaultAccountManager) DeletePolicy(ctx context.Context, accountID, policyID, userID string) error { + unlock := am.Store.AcquireWriteLockByUID(ctx, accountID) + defer unlock() + + user, err := am.Store.GetUserByUserID(ctx, LockingStrengthShare, userID) + if err != nil { + return err + } + + if user.AccountID != accountID { + return status.NewUserNotPartOfAccountError() + } + + if user.IsRegularUser() { + return status.NewAdminPermissionError() + } + + var policy *Policy + var updateAccountPeers bool + + err = am.Store.ExecuteInTransaction(ctx, func(transaction Store) error { + policy, err = transaction.GetPolicyByID(ctx, LockingStrengthUpdate, accountID, policyID) + if err != nil { + return err + } + + updateAccountPeers, err = arePolicyChangesAffectPeers(ctx, transaction, accountID, policy, false) + if err != nil { + return err + } + + if err = transaction.IncrementNetworkSerial(ctx, LockingStrengthUpdate, accountID); err != nil { + return err + } + + return transaction.DeletePolicy(ctx, LockingStrengthUpdate, accountID, policyID) + }) + if err != nil 
{ + return err + } + + am.StoreEvent(ctx, userID, policyID, accountID, activity.PolicyRemoved, policy.EventMeta()) + + if updateAccountPeers { + am.updateAccountPeers(ctx, accountID) + } + return nil } -// DeletePolicy from the store -func (am *DefaultAccountManager) DeletePolicy(ctx context.Context, accountID, policyID, userID string) error { - unlock := am.Store.AcquireWriteLockByUID(ctx, accountID) - defer unlock() - - account, err := am.Store.GetAccount(ctx, accountID) - if err != nil { - return err - } - - policy, err := am.deletePolicy(account, policyID) - if err != nil { - return err - } - - account.Network.IncSerial() - if err = am.Store.SaveAccount(ctx, account); err != nil { - return err - } - - am.StoreEvent(ctx, userID, policy.ID, accountID, activity.PolicyRemoved, policy.EventMeta()) - - if am.anyGroupHasPeers(account, policy.ruleGroups()) { - am.updateAccountPeers(ctx, accountID) - } - - return nil -} - -// ListPolicies from the store +// ListPolicies from the store. func (am *DefaultAccountManager) ListPolicies(ctx context.Context, accountID, userID string) ([]*Policy, error) { user, err := am.Store.GetUserByUserID(ctx, LockingStrengthShare, userID) if err != nil { return nil, err } - if !user.IsAdminOrServiceUser() || user.AccountID != accountID { - return nil, status.Errorf(status.PermissionDenied, "only users with admin power are allowed to view policies") + if user.AccountID != accountID { + return nil, status.NewUserNotPartOfAccountError() + } + + if user.IsRegularUser() { + return nil, status.NewAdminPermissionError() } return am.Store.GetAccountPolicies(ctx, LockingStrengthShare, accountID) } -func (am *DefaultAccountManager) deletePolicy(account *Account, policyID string) (*Policy, error) { - policyIdx := -1 - for i, policy := range account.Policies { - if policy.ID == policyID { - policyIdx = i - break - } - } - if policyIdx < 0 { - return nil, status.Errorf(status.NotFound, "rule with ID %s doesn't exist", policyID) - } - - policy := account.Policies[policyIdx] - account.Policies = append(account.Policies[:policyIdx], account.Policies[policyIdx+1:]...) - return policy, nil -} - -// savePolicy saves or updates a policy in the given account. -// If isUpdate is true, the function updates the existing policy; otherwise, it adds a new policy. -func (am *DefaultAccountManager) savePolicy(account *Account, policyToSave *Policy, isUpdate bool) (bool, error) { - for index, rule := range policyToSave.Rules { - rule.Sources = filterValidGroupIDs(account, rule.Sources) - rule.Destinations = filterValidGroupIDs(account, rule.Destinations) - policyToSave.Rules[index] = rule - } - - if policyToSave.SourcePostureChecks != nil { - policyToSave.SourcePostureChecks = filterValidPostureChecks(account, policyToSave.SourcePostureChecks) - } - +// arePolicyChangesAffectPeers checks if changes to a policy will affect any associated peers. 
+func arePolicyChangesAffectPeers(ctx context.Context, transaction Store, accountID string, policy *Policy, isUpdate bool) (bool, error) { if isUpdate { - policyIdx := slices.IndexFunc(account.Policies, func(policy *Policy) bool { return policy.ID == policyToSave.ID }) - if policyIdx < 0 { - return false, status.Errorf(status.NotFound, "couldn't find policy id %s", policyToSave.ID) + existingPolicy, err := transaction.GetPolicyByID(ctx, LockingStrengthShare, accountID, policy.ID) + if err != nil { + return false, err } - oldPolicy := account.Policies[policyIdx] - // Update the existing policy - account.Policies[policyIdx] = policyToSave - - if !policyToSave.Enabled && !oldPolicy.Enabled { + if !policy.Enabled && !existingPolicy.Enabled { return false, nil } - updateAccountPeers := am.anyGroupHasPeers(account, oldPolicy.ruleGroups()) || am.anyGroupHasPeers(account, policyToSave.ruleGroups()) - return updateAccountPeers, nil - } + hasPeers, err := anyGroupHasPeers(ctx, transaction, policy.AccountID, existingPolicy.ruleGroups()) + if err != nil { + return false, err + } - // Add the new policy to the account - account.Policies = append(account.Policies, policyToSave) - - return am.anyGroupHasPeers(account, policyToSave.ruleGroups()), nil -} - -func toProtocolFirewallRules(rules []*FirewallRule) []*proto.FirewallRule { - result := make([]*proto.FirewallRule, len(rules)) - for i := range rules { - rule := rules[i] - - result[i] = &proto.FirewallRule{ - PeerIP: rule.PeerIP, - Direction: getProtoDirection(rule.Direction), - Action: getProtoAction(rule.Action), - Protocol: getProtoProtocol(rule.Protocol), - Port: rule.Port, + if hasPeers { + return true, nil } } - return result + + return anyGroupHasPeers(ctx, transaction, policy.AccountID, policy.ruleGroups()) +} + +// validatePolicy validates the policy and its rules. +func validatePolicy(ctx context.Context, transaction Store, accountID string, policy *Policy) error { + if policy.ID != "" { + _, err := transaction.GetPolicyByID(ctx, LockingStrengthShare, accountID, policy.ID) + if err != nil { + return err + } + } else { + policy.ID = xid.New().String() + policy.AccountID = accountID + } + + groups, err := transaction.GetGroupsByIDs(ctx, LockingStrengthShare, accountID, policy.ruleGroups()) + if err != nil { + return err + } + + postureChecks, err := transaction.GetPostureChecksByIDs(ctx, LockingStrengthShare, accountID, policy.SourcePostureChecks) + if err != nil { + return err + } + + for i, rule := range policy.Rules { + ruleCopy := rule.Copy() + if ruleCopy.ID == "" { + ruleCopy.ID = policy.ID // TODO: when policy can contain multiple rules, need refactor + ruleCopy.PolicyID = policy.ID + } + + ruleCopy.Sources = getValidGroupIDs(groups, ruleCopy.Sources) + ruleCopy.Destinations = getValidGroupIDs(groups, ruleCopy.Destinations) + policy.Rules[i] = ruleCopy + } + + if policy.SourcePostureChecks != nil { + policy.SourcePostureChecks = getValidPostureCheckIDs(postureChecks, policy.SourcePostureChecks) + } + + return nil } // getAllPeersFromGroups for given peer ID and list of groups @@ -574,27 +626,42 @@ func (a *Account) getPostureChecks(postureChecksID string) *posture.Checks { return nil } -// filterValidPostureChecks filters and returns the posture check IDs from the given list -// that are valid within the provided account. 
-func filterValidPostureChecks(account *Account, postureChecksIds []string) []string { - result := make([]string, 0, len(postureChecksIds)) +// getValidPostureCheckIDs filters and returns only the valid posture check IDs from the provided list. +func getValidPostureCheckIDs(postureChecks map[string]*posture.Checks, postureChecksIds []string) []string { + validIDs := make([]string, 0, len(postureChecksIds)) for _, id := range postureChecksIds { - for _, postureCheck := range account.PostureChecks { - if id == postureCheck.ID { - result = append(result, id) - continue - } + if _, exists := postureChecks[id]; exists { + validIDs = append(validIDs, id) } } - return result + + return validIDs } -// filterValidGroupIDs filters a list of group IDs and returns only the ones present in the account's group map. -func filterValidGroupIDs(account *Account, groupIDs []string) []string { - result := make([]string, 0, len(groupIDs)) - for _, groupID := range groupIDs { - if _, exists := account.Groups[groupID]; exists { - result = append(result, groupID) +// getValidGroupIDs filters and returns only the valid group IDs from the provided list. +func getValidGroupIDs(groups map[string]*nbgroup.Group, groupIDs []string) []string { + validIDs := make([]string, 0, len(groupIDs)) + for _, id := range groupIDs { + if _, exists := groups[id]; exists { + validIDs = append(validIDs, id) + } + } + + return validIDs +} + +// toProtocolFirewallRules converts the firewall rules to the protocol firewall rules. +func toProtocolFirewallRules(rules []*FirewallRule) []*proto.FirewallRule { + result := make([]*proto.FirewallRule, len(rules)) + for i := range rules { + rule := rules[i] + + result[i] = &proto.FirewallRule{ + PeerIP: rule.PeerIP, + Direction: getProtoDirection(rule.Direction), + Action: getProtoAction(rule.Action), + Protocol: getProtoProtocol(rule.Protocol), + Port: rule.Port, } } return result diff --git a/management/server/policy_test.go b/management/server/policy_test.go index e7f0f9cd2..62d80f46e 100644 --- a/management/server/policy_test.go +++ b/management/server/policy_test.go @@ -7,7 +7,6 @@ import ( "testing" "time" - "github.com/rs/xid" "github.com/stretchr/testify/assert" "golang.org/x/exp/slices" @@ -859,14 +858,23 @@ func TestPolicyAccountPeersUpdate(t *testing.T) { manager.peersUpdateManager.CloseChannel(context.Background(), peer1.ID) }) + var policyWithGroupRulesNoPeers *Policy + var policyWithDestinationPeersOnly *Policy + var policyWithSourceAndDestinationPeers *Policy + // Saving policy with rule groups with no peers should not update account's peers and not send peer update t.Run("saving policy with rule groups with no peers", func(t *testing.T) { - policy := Policy{ - ID: "policy-rule-groups-no-peers", - Enabled: true, + done := make(chan struct{}) + go func() { + peerShouldNotReceiveUpdate(t, updMsg) + close(done) + }() + + policyWithGroupRulesNoPeers, err = manager.SavePolicy(context.Background(), account.Id, userID, &Policy{ + AccountID: account.Id, + Enabled: true, Rules: []*PolicyRule{ { - ID: xid.New().String(), Enabled: true, Sources: []string{"groupB"}, Destinations: []string{"groupC"}, @@ -874,15 +882,7 @@ func TestPolicyAccountPeersUpdate(t *testing.T) { Action: PolicyTrafficActionAccept, }, }, - } - - done := make(chan struct{}) - go func() { - peerShouldNotReceiveUpdate(t, updMsg) - close(done) - }() - - err := manager.SavePolicy(context.Background(), account.Id, userID, &policy, false) + }) assert.NoError(t, err) select { @@ -895,12 +895,17 @@ func 
TestPolicyAccountPeersUpdate(t *testing.T) { // Saving policy with source group containing peers, but destination group without peers should // update account's peers and send peer update t.Run("saving policy where source has peers but destination does not", func(t *testing.T) { - policy := Policy{ - ID: "policy-source-has-peers-destination-none", - Enabled: true, + done := make(chan struct{}) + go func() { + peerShouldReceiveUpdate(t, updMsg) + close(done) + }() + + _, err = manager.SavePolicy(context.Background(), account.Id, userID, &Policy{ + AccountID: account.Id, + Enabled: true, Rules: []*PolicyRule{ { - ID: xid.New().String(), Enabled: true, Sources: []string{"groupA"}, Destinations: []string{"groupB"}, @@ -909,15 +914,7 @@ func TestPolicyAccountPeersUpdate(t *testing.T) { Action: PolicyTrafficActionAccept, }, }, - } - - done := make(chan struct{}) - go func() { - peerShouldReceiveUpdate(t, updMsg) - close(done) - }() - - err := manager.SavePolicy(context.Background(), account.Id, userID, &policy, false) + }) assert.NoError(t, err) select { @@ -930,13 +927,18 @@ func TestPolicyAccountPeersUpdate(t *testing.T) { // Saving policy with destination group containing peers, but source group without peers should // update account's peers and send peer update t.Run("saving policy where destination has peers but source does not", func(t *testing.T) { - policy := Policy{ - ID: "policy-destination-has-peers-source-none", - Enabled: true, + done := make(chan struct{}) + go func() { + peerShouldReceiveUpdate(t, updMsg) + close(done) + }() + + policyWithDestinationPeersOnly, err = manager.SavePolicy(context.Background(), account.Id, userID, &Policy{ + AccountID: account.Id, + Enabled: true, Rules: []*PolicyRule{ { - ID: xid.New().String(), - Enabled: false, + Enabled: true, Sources: []string{"groupC"}, Destinations: []string{"groupD"}, Bidirectional: true, @@ -944,15 +946,7 @@ func TestPolicyAccountPeersUpdate(t *testing.T) { Action: PolicyTrafficActionAccept, }, }, - } - - done := make(chan struct{}) - go func() { - peerShouldReceiveUpdate(t, updMsg) - close(done) - }() - - err := manager.SavePolicy(context.Background(), account.Id, userID, &policy, false) + }) assert.NoError(t, err) select { @@ -965,12 +959,17 @@ func TestPolicyAccountPeersUpdate(t *testing.T) { // Saving policy with destination and source groups containing peers should update account's peers // and send peer update t.Run("saving policy with source and destination groups with peers", func(t *testing.T) { - policy := Policy{ - ID: "policy-source-destination-peers", - Enabled: true, + done := make(chan struct{}) + go func() { + peerShouldReceiveUpdate(t, updMsg) + close(done) + }() + + policyWithSourceAndDestinationPeers, err = manager.SavePolicy(context.Background(), account.Id, userID, &Policy{ + AccountID: account.Id, + Enabled: true, Rules: []*PolicyRule{ { - ID: xid.New().String(), Enabled: true, Sources: []string{"groupA"}, Destinations: []string{"groupD"}, @@ -978,15 +977,7 @@ func TestPolicyAccountPeersUpdate(t *testing.T) { Action: PolicyTrafficActionAccept, }, }, - } - - done := make(chan struct{}) - go func() { - peerShouldReceiveUpdate(t, updMsg) - close(done) - }() - - err := manager.SavePolicy(context.Background(), account.Id, userID, &policy, false) + }) assert.NoError(t, err) select { @@ -999,28 +990,14 @@ func TestPolicyAccountPeersUpdate(t *testing.T) { // Disabling policy with destination and source groups containing peers should update account's peers // and send peer update t.Run("disabling policy with 
source and destination groups with peers", func(t *testing.T) { - policy := Policy{ - ID: "policy-source-destination-peers", - Enabled: false, - Rules: []*PolicyRule{ - { - ID: xid.New().String(), - Enabled: true, - Sources: []string{"groupA"}, - Destinations: []string{"groupD"}, - Bidirectional: true, - Action: PolicyTrafficActionAccept, - }, - }, - } - done := make(chan struct{}) go func() { peerShouldReceiveUpdate(t, updMsg) close(done) }() - err := manager.SavePolicy(context.Background(), account.Id, userID, &policy, true) + policyWithSourceAndDestinationPeers.Enabled = false + policyWithSourceAndDestinationPeers, err = manager.SavePolicy(context.Background(), account.Id, userID, policyWithSourceAndDestinationPeers) assert.NoError(t, err) select { @@ -1033,29 +1010,15 @@ func TestPolicyAccountPeersUpdate(t *testing.T) { // Updating disabled policy with destination and source groups containing peers should not update account's peers // or send peer update t.Run("updating disabled policy with source and destination groups with peers", func(t *testing.T) { - policy := Policy{ - ID: "policy-source-destination-peers", - Description: "updated description", - Enabled: false, - Rules: []*PolicyRule{ - { - ID: xid.New().String(), - Enabled: true, - Sources: []string{"groupA"}, - Destinations: []string{"groupA"}, - Bidirectional: true, - Action: PolicyTrafficActionAccept, - }, - }, - } - done := make(chan struct{}) go func() { peerShouldNotReceiveUpdate(t, updMsg) close(done) }() - err := manager.SavePolicy(context.Background(), account.Id, userID, &policy, true) + policyWithSourceAndDestinationPeers.Description = "updated description" + policyWithSourceAndDestinationPeers.Rules[0].Destinations = []string{"groupA"} + policyWithSourceAndDestinationPeers, err = manager.SavePolicy(context.Background(), account.Id, userID, policyWithSourceAndDestinationPeers) assert.NoError(t, err) select { @@ -1068,28 +1031,14 @@ func TestPolicyAccountPeersUpdate(t *testing.T) { // Enabling policy with destination and source groups containing peers should update account's peers // and send peer update t.Run("enabling policy with source and destination groups with peers", func(t *testing.T) { - policy := Policy{ - ID: "policy-source-destination-peers", - Enabled: true, - Rules: []*PolicyRule{ - { - ID: xid.New().String(), - Enabled: true, - Sources: []string{"groupA"}, - Destinations: []string{"groupD"}, - Bidirectional: true, - Action: PolicyTrafficActionAccept, - }, - }, - } - done := make(chan struct{}) go func() { peerShouldReceiveUpdate(t, updMsg) close(done) }() - err := manager.SavePolicy(context.Background(), account.Id, userID, &policy, true) + policyWithSourceAndDestinationPeers.Enabled = true + policyWithSourceAndDestinationPeers, err = manager.SavePolicy(context.Background(), account.Id, userID, policyWithSourceAndDestinationPeers) assert.NoError(t, err) select { @@ -1101,15 +1050,13 @@ func TestPolicyAccountPeersUpdate(t *testing.T) { // Deleting policy should trigger account peers update and send peer update t.Run("deleting policy with source and destination groups with peers", func(t *testing.T) { - policyID := "policy-source-destination-peers" - done := make(chan struct{}) go func() { peerShouldReceiveUpdate(t, updMsg) close(done) }() - err := manager.DeletePolicy(context.Background(), account.Id, policyID, userID) + err := manager.DeletePolicy(context.Background(), account.Id, policyWithSourceAndDestinationPeers.ID, userID) assert.NoError(t, err) select { @@ -1123,14 +1070,13 @@ func 
TestPolicyAccountPeersUpdate(t *testing.T) { // Deleting policy with destination group containing peers, but source group without peers should // update account's peers and send peer update t.Run("deleting policy where destination has peers but source does not", func(t *testing.T) { - policyID := "policy-destination-has-peers-source-none" done := make(chan struct{}) go func() { peerShouldReceiveUpdate(t, updMsg) close(done) }() - err := manager.DeletePolicy(context.Background(), account.Id, policyID, userID) + err := manager.DeletePolicy(context.Background(), account.Id, policyWithDestinationPeersOnly.ID, userID) assert.NoError(t, err) select { @@ -1142,14 +1088,13 @@ func TestPolicyAccountPeersUpdate(t *testing.T) { // Deleting policy with no peers in groups should not update account's peers and not send peer update t.Run("deleting policy with no peers in groups", func(t *testing.T) { - policyID := "policy-rule-groups-no-peers" done := make(chan struct{}) go func() { peerShouldNotReceiveUpdate(t, updMsg) close(done) }() - err := manager.DeletePolicy(context.Background(), account.Id, policyID, userID) + err := manager.DeletePolicy(context.Background(), account.Id, policyWithGroupRulesNoPeers.ID, userID) assert.NoError(t, err) select { diff --git a/management/server/posture_checks_test.go b/management/server/posture_checks_test.go index 3c5c5fc79..93e5741cf 100644 --- a/management/server/posture_checks_test.go +++ b/management/server/posture_checks_test.go @@ -5,7 +5,6 @@ import ( "testing" "time" - "github.com/rs/xid" "github.com/stretchr/testify/assert" "github.com/stretchr/testify/require" @@ -210,12 +209,10 @@ func TestPostureCheckAccountPeersUpdate(t *testing.T) { } }) - policy := Policy{ - ID: "policyA", + policy := &Policy{ Enabled: true, Rules: []*PolicyRule{ { - ID: xid.New().String(), Enabled: true, Sources: []string{"groupA"}, Destinations: []string{"groupA"}, @@ -234,7 +231,7 @@ func TestPostureCheckAccountPeersUpdate(t *testing.T) { close(done) }() - err := manager.SavePolicy(context.Background(), account.Id, userID, &policy, false) + policy, err = manager.SavePolicy(context.Background(), account.Id, userID, policy) assert.NoError(t, err) select { @@ -282,8 +279,7 @@ func TestPostureCheckAccountPeersUpdate(t *testing.T) { }() policy.SourcePostureChecks = []string{} - - err := manager.SavePolicy(context.Background(), account.Id, userID, &policy, true) + _, err := manager.SavePolicy(context.Background(), account.Id, userID, policy) assert.NoError(t, err) select { @@ -316,12 +312,10 @@ func TestPostureCheckAccountPeersUpdate(t *testing.T) { // Updating linked posture check to policy with no peers should not trigger account peers update and not send peer update t.Run("updating linked posture check to policy with no peers", func(t *testing.T) { - policy = Policy{ - ID: "policyB", + _, err = manager.SavePolicy(context.Background(), account.Id, userID, &Policy{ Enabled: true, Rules: []*PolicyRule{ { - ID: xid.New().String(), Enabled: true, Sources: []string{"groupB"}, Destinations: []string{"groupC"}, @@ -330,8 +324,7 @@ func TestPostureCheckAccountPeersUpdate(t *testing.T) { }, }, SourcePostureChecks: []string{postureCheckB.ID}, - } - err = manager.SavePolicy(context.Background(), account.Id, userID, &policy, false) + }) assert.NoError(t, err) done := make(chan struct{}) @@ -362,12 +355,11 @@ func TestPostureCheckAccountPeersUpdate(t *testing.T) { t.Cleanup(func() { manager.peersUpdateManager.CloseChannel(context.Background(), peer2.ID) }) - policy = Policy{ - ID: "policyB", + + _, 
err = manager.SavePolicy(context.Background(), account.Id, userID, &Policy{ Enabled: true, Rules: []*PolicyRule{ { - ID: xid.New().String(), Enabled: true, Sources: []string{"groupB"}, Destinations: []string{"groupA"}, @@ -376,9 +368,7 @@ func TestPostureCheckAccountPeersUpdate(t *testing.T) { }, }, SourcePostureChecks: []string{postureCheckB.ID}, - } - - err = manager.SavePolicy(context.Background(), account.Id, userID, &policy, true) + }) assert.NoError(t, err) done := make(chan struct{}) @@ -405,8 +395,7 @@ func TestPostureCheckAccountPeersUpdate(t *testing.T) { // Updating linked client posture check to policy where source has peers but destination does not, // should trigger account peers update and send peer update t.Run("updating linked posture check to policy where source has peers but destination does not", func(t *testing.T) { - policy = Policy{ - ID: "policyB", + _, err = manager.SavePolicy(context.Background(), account.Id, userID, &Policy{ Enabled: true, Rules: []*PolicyRule{ { @@ -418,8 +407,7 @@ func TestPostureCheckAccountPeersUpdate(t *testing.T) { }, }, SourcePostureChecks: []string{postureCheckB.ID}, - } - err = manager.SavePolicy(context.Background(), account.Id, userID, &policy, true) + }) assert.NoError(t, err) done := make(chan struct{}) @@ -490,12 +478,9 @@ func TestArePostureCheckChangesAffectPeers(t *testing.T) { require.NoError(t, err, "failed to save postureCheckB") policy := &Policy{ - ID: "policyA", AccountID: account.Id, Rules: []*PolicyRule{ { - ID: "ruleA", - PolicyID: "policyA", Enabled: true, Sources: []string{"groupA"}, Destinations: []string{"groupA"}, @@ -504,7 +489,7 @@ func TestArePostureCheckChangesAffectPeers(t *testing.T) { SourcePostureChecks: []string{postureCheckA.ID}, } - err = manager.SavePolicy(context.Background(), account.Id, userID, policy, false) + policy, err = manager.SavePolicy(context.Background(), account.Id, adminUserID, policy) require.NoError(t, err, "failed to save policy") t.Run("posture check exists and is linked to policy with peers", func(t *testing.T) { @@ -528,7 +513,7 @@ func TestArePostureCheckChangesAffectPeers(t *testing.T) { t.Run("posture check is linked to policy with no peers in source groups", func(t *testing.T) { policy.Rules[0].Sources = []string{"groupB"} policy.Rules[0].Destinations = []string{"groupA"} - err = manager.SavePolicy(context.Background(), account.Id, userID, policy, true) + _, err = manager.SavePolicy(context.Background(), account.Id, adminUserID, policy) require.NoError(t, err, "failed to update policy") result, err := arePostureCheckChangesAffectPeers(context.Background(), manager.Store, account.Id, postureCheckA.ID) @@ -539,7 +524,7 @@ func TestArePostureCheckChangesAffectPeers(t *testing.T) { t.Run("posture check is linked to policy with no peers in destination groups", func(t *testing.T) { policy.Rules[0].Sources = []string{"groupA"} policy.Rules[0].Destinations = []string{"groupB"} - err = manager.SavePolicy(context.Background(), account.Id, userID, policy, true) + _, err = manager.SavePolicy(context.Background(), account.Id, adminUserID, policy) require.NoError(t, err, "failed to update policy") result, err := arePostureCheckChangesAffectPeers(context.Background(), manager.Store, account.Id, postureCheckA.ID) @@ -560,7 +545,7 @@ func TestArePostureCheckChangesAffectPeers(t *testing.T) { t.Run("posture check is linked to policy with non-existent group", func(t *testing.T) { policy.Rules[0].Sources = []string{"nonExistentGroup"} policy.Rules[0].Destinations = []string{"nonExistentGroup"} - err 
= manager.SavePolicy(context.Background(), account.Id, userID, policy, true) + _, err = manager.SavePolicy(context.Background(), account.Id, adminUserID, policy) require.NoError(t, err, "failed to update policy") result, err := arePostureCheckChangesAffectPeers(context.Background(), manager.Store, account.Id, postureCheckA.ID) diff --git a/management/server/route_test.go b/management/server/route_test.go index 5c848f68c..108f791e0 100644 --- a/management/server/route_test.go +++ b/management/server/route_test.go @@ -1214,12 +1214,11 @@ func TestGetNetworkMap_RouteSync(t *testing.T) { defaultRule := rules[0] newPolicy := defaultRule.Copy() - newPolicy.ID = xid.New().String() newPolicy.Name = "peer1 only" newPolicy.Rules[0].Sources = []string{newGroup.ID} newPolicy.Rules[0].Destinations = []string{newGroup.ID} - err = am.SavePolicy(context.Background(), account.Id, userID, newPolicy, false) + _, err = am.SavePolicy(context.Background(), account.Id, userID, newPolicy) require.NoError(t, err) err = am.DeletePolicy(context.Background(), account.Id, defaultRule.ID, userID) diff --git a/management/server/setupkey_test.go b/management/server/setupkey_test.go index 7c8200706..614547c60 100644 --- a/management/server/setupkey_test.go +++ b/management/server/setupkey_test.go @@ -406,8 +406,7 @@ func TestSetupKeyAccountPeersUpdate(t *testing.T) { }) assert.NoError(t, err) - policy := Policy{ - ID: "policy", + policy := &Policy{ Enabled: true, Rules: []*PolicyRule{ { @@ -419,7 +418,7 @@ func TestSetupKeyAccountPeersUpdate(t *testing.T) { }, }, } - err = manager.SavePolicy(context.Background(), account.Id, userID, &policy, false) + _, err = manager.SavePolicy(context.Background(), account.Id, userID, policy) require.NoError(t, err) updMsg := manager.peersUpdateManager.CreateChannel(context.Background(), peer1.ID) diff --git a/management/server/sql_store.go b/management/server/sql_store.go index 47c17bb92..9a24857d1 100644 --- a/management/server/sql_store.go +++ b/management/server/sql_store.go @@ -1243,8 +1243,8 @@ func (s *SqlStore) GetGroupsByIDs(ctx context.Context, lockStrength LockingStren var groups []*nbgroup.Group result := s.db.Clauses(clause.Locking{Strength: string(lockStrength)}).Find(&groups, accountAndIDsQueryCondition, accountID, groupIDs) if result.Error != nil { - log.WithContext(ctx).Errorf("failed to get groups by ID's from the store: %s", result.Error) - return nil, status.Errorf(status.Internal, "failed to get groups by ID's from the store") + log.WithContext(ctx).Errorf("failed to get groups by ID's from store: %s", result.Error) + return nil, status.Errorf(status.Internal, "failed to get groups by ID's from store") } groupsMap := make(map[string]*nbgroup.Group) @@ -1295,12 +1295,67 @@ func (s *SqlStore) DeleteGroups(ctx context.Context, strength LockingStrength, a // GetAccountPolicies retrieves policies for an account. func (s *SqlStore) GetAccountPolicies(ctx context.Context, lockStrength LockingStrength, accountID string) ([]*Policy, error) { - return getRecords[*Policy](s.db.Preload(clause.Associations), lockStrength, accountID) + var policies []*Policy + result := s.db.Clauses(clause.Locking{Strength: string(lockStrength)}). 
+		Preload(clause.Associations).Find(&policies, accountIDCondition, accountID)
+	if err := result.Error; err != nil {
+		log.WithContext(ctx).Errorf("failed to get policies from the store: %s", result.Error)
+		return nil, status.Errorf(status.Internal, "failed to get policies from store")
+	}
+
+	return policies, nil
 }
 
 // GetPolicyByID retrieves a policy by its ID and account ID.
-func (s *SqlStore) GetPolicyByID(ctx context.Context, lockStrength LockingStrength, policyID string, accountID string) (*Policy, error) {
-	return getRecordByID[Policy](s.db.Preload(clause.Associations), lockStrength, policyID, accountID)
+func (s *SqlStore) GetPolicyByID(ctx context.Context, lockStrength LockingStrength, accountID, policyID string) (*Policy, error) {
+	var policy *Policy
+	result := s.db.Clauses(clause.Locking{Strength: string(lockStrength)}).Preload(clause.Associations).
+		First(&policy, accountAndIDQueryCondition, accountID, policyID)
+	if err := result.Error; err != nil {
+		if errors.Is(err, gorm.ErrRecordNotFound) {
+			return nil, status.NewPolicyNotFoundError(policyID)
+		}
+		log.WithContext(ctx).Errorf("failed to get policy from store: %s", err)
+		return nil, status.Errorf(status.Internal, "failed to get policy from store")
+	}
+
+	return policy, nil
+}
+
+func (s *SqlStore) CreatePolicy(ctx context.Context, lockStrength LockingStrength, policy *Policy) error {
+	result := s.db.Clauses(clause.Locking{Strength: string(lockStrength)}).Create(policy)
+	if result.Error != nil {
+		log.WithContext(ctx).Errorf("failed to create policy in store: %s", result.Error)
+		return status.Errorf(status.Internal, "failed to create policy in store")
+	}
+
+	return nil
+}
+
+// SavePolicy saves a policy to the database.
+func (s *SqlStore) SavePolicy(ctx context.Context, lockStrength LockingStrength, policy *Policy) error {
+	result := s.db.Session(&gorm.Session{FullSaveAssociations: true}).
+		Clauses(clause.Locking{Strength: string(lockStrength)}).Save(policy)
+	if err := result.Error; err != nil {
+		log.WithContext(ctx).Errorf("failed to save policy to the store: %s", err)
+		return status.Errorf(status.Internal, "failed to save policy to store")
+	}
+	return nil
+}
+
+func (s *SqlStore) DeletePolicy(ctx context.Context, lockStrength LockingStrength, accountID, policyID string) error {
+	result := s.db.Clauses(clause.Locking{Strength: string(lockStrength)}).
+		Delete(&Policy{}, accountAndIDQueryCondition, accountID, policyID)
+	if err := result.Error; err != nil {
+		log.WithContext(ctx).Errorf("failed to delete policy from store: %s", err)
+		return status.Errorf(status.Internal, "failed to delete policy from store")
+	}
+
+	if result.RowsAffected == 0 {
+		return status.NewPolicyNotFoundError(policyID)
+	}
+
+	return nil
 }
 
 // GetAccountPostureChecks retrieves posture checks for an account.
@@ -1331,6 +1386,23 @@ func (s *SqlStore) GetPostureChecksByID(ctx context.Context, lockStrength Lockin
 	return postureCheck, nil
 }
 
+// GetPostureChecksByIDs retrieves posture checks by their IDs and account ID.
+func (s *SqlStore) GetPostureChecksByIDs(ctx context.Context, lockStrength LockingStrength, accountID string, postureChecksIDs []string) (map[string]*posture.Checks, error) { + var postureChecks []*posture.Checks + result := s.db.Clauses(clause.Locking{Strength: string(lockStrength)}).Find(&postureChecks, accountAndIDsQueryCondition, accountID, postureChecksIDs) + if result.Error != nil { + log.WithContext(ctx).Errorf("failed to get posture checks by ID's from store: %s", result.Error) + return nil, status.Errorf(status.Internal, "failed to get posture checks by ID's from store") + } + + postureChecksMap := make(map[string]*posture.Checks) + for _, postureCheck := range postureChecks { + postureChecksMap[postureCheck.ID] = postureCheck + } + + return postureChecksMap, nil +} + // SavePostureChecks saves a posture checks to the database. func (s *SqlStore) SavePostureChecks(ctx context.Context, lockStrength LockingStrength, postureCheck *posture.Checks) error { result := s.db.Clauses(clause.Locking{Strength: string(lockStrength)}).Save(postureCheck) diff --git a/management/server/sql_store_test.go b/management/server/sql_store_test.go index de939e8d0..c05793fc6 100644 --- a/management/server/sql_store_test.go +++ b/management/server/sql_store_test.go @@ -1612,6 +1612,49 @@ func TestSqlStore_GetPostureChecksByID(t *testing.T) { } } +func TestSqlStore_GetPostureChecksByIDs(t *testing.T) { + store, cleanup, err := NewTestStoreFromSQL(context.Background(), "testdata/extended-store.sql", t.TempDir()) + t.Cleanup(cleanup) + require.NoError(t, err) + + accountID := "bf1c8084-ba50-4ce7-9439-34653001fc3b" + + tests := []struct { + name string + postureCheckIDs []string + expectedCount int + }{ + { + name: "retrieve existing posture checks by existing IDs", + postureCheckIDs: []string{"csplshq7qv948l48f7t0", "cspnllq7qv95uq1r4k90"}, + expectedCount: 2, + }, + { + name: "empty posture check IDs list", + postureCheckIDs: []string{}, + expectedCount: 0, + }, + { + name: "non-existing posture check IDs", + postureCheckIDs: []string{"nonexistent1", "nonexistent2"}, + expectedCount: 0, + }, + { + name: "mixed existing and non-existing posture check IDs", + postureCheckIDs: []string{"cspnllq7qv95uq1r4k90", "nonexistent"}, + expectedCount: 1, + }, + } + + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + groups, err := store.GetPostureChecksByIDs(context.Background(), LockingStrengthShare, accountID, tt.postureCheckIDs) + require.NoError(t, err) + require.Len(t, groups, tt.expectedCount) + }) + } +} + func TestSqlStore_SavePostureChecks(t *testing.T) { store, cleanup, err := NewTestStoreFromSQL(context.Background(), "testdata/extended-store.sql", t.TempDir()) t.Cleanup(cleanup) @@ -1699,3 +1742,118 @@ func TestSqlStore_DeletePostureChecks(t *testing.T) { }) } } + +func TestSqlStore_GetPolicyByID(t *testing.T) { + store, cleanup, err := NewTestStoreFromSQL(context.Background(), "testdata/store.sql", t.TempDir()) + t.Cleanup(cleanup) + require.NoError(t, err) + + accountID := "bf1c8084-ba50-4ce7-9439-34653001fc3b" + tests := []struct { + name string + policyID string + expectError bool + }{ + { + name: "retrieve existing policy", + policyID: "cs1tnh0hhcjnqoiuebf0", + expectError: false, + }, + { + name: "retrieve non-existing policy checks", + policyID: "non-existing", + expectError: true, + }, + { + name: "retrieve with empty policy ID", + policyID: "", + expectError: true, + }, + } + + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + policy, err := 
store.GetPolicyByID(context.Background(), LockingStrengthShare, accountID, tt.policyID) + if tt.expectError { + require.Error(t, err) + sErr, ok := status.FromError(err) + require.True(t, ok) + require.Equal(t, sErr.Type(), status.NotFound) + require.Nil(t, policy) + } else { + require.NoError(t, err) + require.NotNil(t, policy) + require.Equal(t, tt.policyID, policy.ID) + } + }) + } +} + +func TestSqlStore_CreatePolicy(t *testing.T) { + store, cleanup, err := NewTestStoreFromSQL(context.Background(), "testdata/store.sql", t.TempDir()) + t.Cleanup(cleanup) + require.NoError(t, err) + + accountID := "bf1c8084-ba50-4ce7-9439-34653001fc3b" + + policy := &Policy{ + ID: "policy-id", + AccountID: accountID, + Enabled: true, + Rules: []*PolicyRule{ + { + Enabled: true, + Sources: []string{"groupA"}, + Destinations: []string{"groupC"}, + Bidirectional: true, + Action: PolicyTrafficActionAccept, + }, + }, + } + err = store.CreatePolicy(context.Background(), LockingStrengthUpdate, policy) + require.NoError(t, err) + + savePolicy, err := store.GetPolicyByID(context.Background(), LockingStrengthShare, accountID, policy.ID) + require.NoError(t, err) + require.Equal(t, savePolicy, policy) + +} + +func TestSqlStore_SavePolicy(t *testing.T) { + store, cleanup, err := NewTestStoreFromSQL(context.Background(), "testdata/store.sql", t.TempDir()) + t.Cleanup(cleanup) + require.NoError(t, err) + + accountID := "bf1c8084-ba50-4ce7-9439-34653001fc3b" + policyID := "cs1tnh0hhcjnqoiuebf0" + + policy, err := store.GetPolicyByID(context.Background(), LockingStrengthShare, accountID, policyID) + require.NoError(t, err) + + policy.Enabled = false + policy.Description = "policy" + policy.Rules[0].Sources = []string{"group"} + policy.Rules[0].Ports = []string{"80", "443"} + err = store.SavePolicy(context.Background(), LockingStrengthUpdate, policy) + require.NoError(t, err) + + savePolicy, err := store.GetPolicyByID(context.Background(), LockingStrengthShare, accountID, policy.ID) + require.NoError(t, err) + require.Equal(t, savePolicy, policy) +} + +func TestSqlStore_DeletePolicy(t *testing.T) { + store, cleanup, err := NewTestStoreFromSQL(context.Background(), "testdata/store.sql", t.TempDir()) + t.Cleanup(cleanup) + require.NoError(t, err) + + accountID := "bf1c8084-ba50-4ce7-9439-34653001fc3b" + policyID := "cs1tnh0hhcjnqoiuebf0" + + err = store.DeletePolicy(context.Background(), LockingStrengthShare, accountID, policyID) + require.NoError(t, err) + + policy, err := store.GetPolicyByID(context.Background(), LockingStrengthShare, accountID, policyID) + require.Error(t, err) + require.Nil(t, policy) +} diff --git a/management/server/status/error.go b/management/server/status/error.go index 44391e1f1..0fff53559 100644 --- a/management/server/status/error.go +++ b/management/server/status/error.go @@ -144,3 +144,8 @@ func NewGroupNotFoundError(groupID string) error { func NewPostureChecksNotFoundError(postureChecksID string) error { return Errorf(NotFound, "posture checks: %s not found", postureChecksID) } + +// NewPolicyNotFoundError creates a new Error with NotFound type for a missing policy +func NewPolicyNotFoundError(policyID string) error { + return Errorf(NotFound, "policy: %s not found", policyID) +} diff --git a/management/server/store.go b/management/server/store.go index 03b5821e7..ba61d552d 100644 --- a/management/server/store.go +++ b/management/server/store.go @@ -80,11 +80,15 @@ type Store interface { DeleteGroups(ctx context.Context, strength LockingStrength, accountID string, groupIDs []string) error 
GetAccountPolicies(ctx context.Context, lockStrength LockingStrength, accountID string) ([]*Policy, error) - GetPolicyByID(ctx context.Context, lockStrength LockingStrength, policyID string, accountID string) (*Policy, error) + GetPolicyByID(ctx context.Context, lockStrength LockingStrength, accountID, policyID string) (*Policy, error) + CreatePolicy(ctx context.Context, lockStrength LockingStrength, policy *Policy) error + SavePolicy(ctx context.Context, lockStrength LockingStrength, policy *Policy) error + DeletePolicy(ctx context.Context, lockStrength LockingStrength, accountID, policyID string) error GetPostureCheckByChecksDefinition(accountID string, checks *posture.ChecksDefinition) (*posture.Checks, error) GetAccountPostureChecks(ctx context.Context, lockStrength LockingStrength, accountID string) ([]*posture.Checks, error) GetPostureChecksByID(ctx context.Context, lockStrength LockingStrength, accountID, postureCheckID string) (*posture.Checks, error) + GetPostureChecksByIDs(ctx context.Context, lockStrength LockingStrength, accountID string, postureChecksIDs []string) (map[string]*posture.Checks, error) SavePostureChecks(ctx context.Context, lockStrength LockingStrength, postureCheck *posture.Checks) error DeletePostureChecks(ctx context.Context, lockStrength LockingStrength, accountID, postureChecksID string) error diff --git a/management/server/testdata/extended-store.sql b/management/server/testdata/extended-store.sql index 1646ff4da..37db27316 100644 --- a/management/server/testdata/extended-store.sql +++ b/management/server/testdata/extended-store.sql @@ -35,4 +35,5 @@ INSERT INTO "groups" VALUES('cfefqs706sqkneg59g4g','bf1c8084-ba50-4ce7-9439-3465 INSERT INTO "groups" VALUES('cfefqs706sqkneg59g3g','bf1c8084-ba50-4ce7-9439-34653001fc3b','AwesomeGroup1','api','[]',0,''); INSERT INTO "groups" VALUES('cfefqs706sqkneg59g2g','bf1c8084-ba50-4ce7-9439-34653001fc3b','AwesomeGroup2','api','[]',0,''); INSERT INTO posture_checks VALUES('csplshq7qv948l48f7t0','NetBird Version > 0.32.0','','bf1c8084-ba50-4ce7-9439-34653001fc3b','{"NBVersionCheck":{"MinVersion":"0.31.0"}}'); +INSERT INTO posture_checks VALUES('cspnllq7qv95uq1r4k90','Allow Berlin and Deny local network 172.16.1.0/24','','bf1c8084-ba50-4ce7-9439-34653001fc3b','{"GeoLocationCheck":{"Locations":[{"CountryCode":"DE","CityName":"Berlin"}],"Action":"allow"},"PeerNetworkRangeCheck":{"Action":"deny","Ranges":["172.16.1.0/24"]}}'); INSERT INTO installations VALUES(1,''); diff --git a/management/server/user_test.go b/management/server/user_test.go index d4f560a54..498017afa 100644 --- a/management/server/user_test.go +++ b/management/server/user_test.go @@ -1279,8 +1279,7 @@ func TestUserAccountPeersUpdate(t *testing.T) { }) require.NoError(t, err) - policy := Policy{ - ID: "policy", + policy := &Policy{ Enabled: true, Rules: []*PolicyRule{ { @@ -1292,7 +1291,7 @@ func TestUserAccountPeersUpdate(t *testing.T) { }, }, } - err = manager.SavePolicy(context.Background(), account.Id, userID, &policy, false) + _, err = manager.SavePolicy(context.Background(), account.Id, userID, policy) require.NoError(t, err) updMsg := manager.peersUpdateManager.CreateChannel(context.Background(), peer1.ID) From 0e48a772ff8de32cb6710f8d2f8c36a6bd468551 Mon Sep 17 00:00:00 2001 From: Bethuel Mmbaga Date: Tue, 26 Nov 2024 15:43:05 +0300 Subject: [PATCH 040/159] [management] Refactor DNS settings to use store methods (#2883) --- management/server/dns.go | 164 ++++++++++++++++++++-------- management/server/sql_store.go | 21 +++- 
management/server/sql_store_test.go | 63 +++++++++++ management/server/store.go | 1 + 4 files changed, 204 insertions(+), 45 deletions(-) diff --git a/management/server/dns.go b/management/server/dns.go index e52be6016..8df211b0b 100644 --- a/management/server/dns.go +++ b/management/server/dns.go @@ -3,6 +3,7 @@ package server import ( "context" "fmt" + "slices" "strconv" "sync" @@ -85,8 +86,12 @@ func (am *DefaultAccountManager) GetDNSSettings(ctx context.Context, accountID s return nil, err } - if !user.IsAdminOrServiceUser() || user.AccountID != accountID { - return nil, status.Errorf(status.PermissionDenied, "only users with admin power are allowed to view DNS settings") + if user.AccountID != accountID { + return nil, status.NewUserNotPartOfAccountError() + } + + if user.IsRegularUser() { + return nil, status.NewAdminPermissionError() } return am.Store.GetAccountDNSSettings(ctx, LockingStrengthShare, accountID) @@ -94,64 +99,137 @@ func (am *DefaultAccountManager) GetDNSSettings(ctx context.Context, accountID s // SaveDNSSettings validates a user role and updates the account's DNS settings func (am *DefaultAccountManager) SaveDNSSettings(ctx context.Context, accountID string, userID string, dnsSettingsToSave *DNSSettings) error { - unlock := am.Store.AcquireWriteLockByUID(ctx, accountID) - defer unlock() - - account, err := am.Store.GetAccount(ctx, accountID) - if err != nil { - return err - } - - user, err := account.FindUser(userID) - if err != nil { - return err - } - - if !user.HasAdminPower() { - return status.Errorf(status.PermissionDenied, "only users with admin power are allowed to update DNS settings") - } - if dnsSettingsToSave == nil { return status.Errorf(status.InvalidArgument, "the dns settings provided are nil") } - if len(dnsSettingsToSave.DisabledManagementGroups) != 0 { - err = validateGroups(dnsSettingsToSave.DisabledManagementGroups, account.Groups) - if err != nil { - return err - } - } - - oldSettings := account.DNSSettings.Copy() - account.DNSSettings = dnsSettingsToSave.Copy() - - addedGroups := difference(dnsSettingsToSave.DisabledManagementGroups, oldSettings.DisabledManagementGroups) - removedGroups := difference(oldSettings.DisabledManagementGroups, dnsSettingsToSave.DisabledManagementGroups) - - account.Network.IncSerial() - if err = am.Store.SaveAccount(ctx, account); err != nil { + user, err := am.Store.GetUserByUserID(ctx, LockingStrengthShare, userID) + if err != nil { return err } - for _, id := range addedGroups { - group := account.GetGroup(id) - meta := map[string]any{"group": group.Name, "group_id": group.ID} - am.StoreEvent(ctx, userID, accountID, accountID, activity.GroupAddedToDisabledManagementGroups, meta) + if user.AccountID != accountID { + return status.NewUserNotPartOfAccountError() } - for _, id := range removedGroups { - group := account.GetGroup(id) - meta := map[string]any{"group": group.Name, "group_id": group.ID} - am.StoreEvent(ctx, userID, accountID, accountID, activity.GroupRemovedFromDisabledManagementGroups, meta) + if !user.HasAdminPower() { + return status.NewAdminPermissionError() } - if am.anyGroupHasPeers(account, addedGroups) || am.anyGroupHasPeers(account, removedGroups) { + var updateAccountPeers bool + var eventsToStore []func() + + err = am.Store.ExecuteInTransaction(ctx, func(transaction Store) error { + if err = validateDNSSettings(ctx, transaction, accountID, dnsSettingsToSave); err != nil { + return err + } + + oldSettings, err := transaction.GetAccountDNSSettings(ctx, LockingStrengthUpdate, accountID) + if err 
!= nil { + return err + } + + addedGroups := difference(dnsSettingsToSave.DisabledManagementGroups, oldSettings.DisabledManagementGroups) + removedGroups := difference(oldSettings.DisabledManagementGroups, dnsSettingsToSave.DisabledManagementGroups) + + updateAccountPeers, err = areDNSSettingChangesAffectPeers(ctx, transaction, accountID, addedGroups, removedGroups) + if err != nil { + return err + } + + events := am.prepareDNSSettingsEvents(ctx, transaction, accountID, userID, addedGroups, removedGroups) + eventsToStore = append(eventsToStore, events...) + + if err = transaction.IncrementNetworkSerial(ctx, LockingStrengthUpdate, accountID); err != nil { + return err + } + + return transaction.SaveDNSSettings(ctx, LockingStrengthUpdate, accountID, dnsSettingsToSave) + }) + if err != nil { + return err + } + + for _, storeEvent := range eventsToStore { + storeEvent() + } + + if updateAccountPeers { am.updateAccountPeers(ctx, accountID) } return nil } +// prepareDNSSettingsEvents prepares a list of event functions to be stored. +func (am *DefaultAccountManager) prepareDNSSettingsEvents(ctx context.Context, transaction Store, accountID, userID string, addedGroups, removedGroups []string) []func() { + var eventsToStore []func() + + modifiedGroups := slices.Concat(addedGroups, removedGroups) + groups, err := transaction.GetGroupsByIDs(ctx, LockingStrengthShare, accountID, modifiedGroups) + if err != nil { + log.WithContext(ctx).Debugf("failed to get groups for dns settings events: %v", err) + return nil + } + + for _, groupID := range addedGroups { + group, ok := groups[groupID] + if !ok { + log.WithContext(ctx).Debugf("skipped adding group: %s GroupAddedToDisabledManagementGroups activity", groupID) + continue + } + + eventsToStore = append(eventsToStore, func() { + meta := map[string]any{"group": group.Name, "group_id": group.ID} + am.StoreEvent(ctx, userID, accountID, accountID, activity.GroupAddedToDisabledManagementGroups, meta) + }) + + } + + for _, groupID := range removedGroups { + group, ok := groups[groupID] + if !ok { + log.WithContext(ctx).Debugf("skipped adding group: %s GroupRemovedFromDisabledManagementGroups activity", groupID) + continue + } + + eventsToStore = append(eventsToStore, func() { + meta := map[string]any{"group": group.Name, "group_id": group.ID} + am.StoreEvent(ctx, userID, accountID, accountID, activity.GroupRemovedFromDisabledManagementGroups, meta) + }) + } + + return eventsToStore +} + +// areDNSSettingChangesAffectPeers checks if the DNS settings changes affect any peers. +func areDNSSettingChangesAffectPeers(ctx context.Context, transaction Store, accountID string, addedGroups, removedGroups []string) (bool, error) { + hasPeers, err := anyGroupHasPeers(ctx, transaction, accountID, addedGroups) + if err != nil { + return false, err + } + + if hasPeers { + return true, nil + } + + return anyGroupHasPeers(ctx, transaction, accountID, removedGroups) +} + +// validateDNSSettings validates the DNS settings. 
+func validateDNSSettings(ctx context.Context, transaction Store, accountID string, settings *DNSSettings) error {
+	if len(settings.DisabledManagementGroups) == 0 {
+		return nil
+	}
+
+	groups, err := transaction.GetGroupsByIDs(ctx, LockingStrengthShare, accountID, settings.DisabledManagementGroups)
+	if err != nil {
+		return err
+	}
+
+	return validateGroups(settings.DisabledManagementGroups, groups)
+}
+
 // toProtocolDNSConfig converts nbdns.Config to proto.DNSConfig using the cache
 func toProtocolDNSConfig(update nbdns.Config, cache *DNSConfigCache) *proto.DNSConfig {
 	protoUpdate := &proto.DNSConfig{
diff --git a/management/server/sql_store.go b/management/server/sql_store.go
index 9a24857d1..f58ceb1ad 100644
--- a/management/server/sql_store.go
+++ b/management/server/sql_store.go
@@ -1162,9 +1162,10 @@ func (s *SqlStore) GetAccountDNSSettings(ctx context.Context, lockStrength Locki
 		First(&accountDNSSettings, idQueryCondition, accountID)
 	if result.Error != nil {
 		if errors.Is(result.Error, gorm.ErrRecordNotFound) {
-			return nil, status.Errorf(status.NotFound, "dns settings not found")
+			return nil, status.NewAccountNotFoundError(accountID)
 		}
-		return nil, status.Errorf(status.Internal, "failed to get dns settings from store: %v", result.Error)
+		log.WithContext(ctx).Errorf("failed to get dns settings from store: %v", result.Error)
+		return nil, status.Errorf(status.Internal, "failed to get dns settings from store")
 	}
 	return &accountDNSSettings.DNSSettings, nil
 }
@@ -1537,3 +1538,19 @@ func getRecordByID[T any](db *gorm.DB, lockStrength LockingStrength, recordID, a
 	}
 	return &record, nil
 }
+
+// SaveDNSSettings saves the DNS settings to the store.
+func (s *SqlStore) SaveDNSSettings(ctx context.Context, lockStrength LockingStrength, accountID string, settings *DNSSettings) error {
+	result := s.db.Clauses(clause.Locking{Strength: string(lockStrength)}).Model(&Account{}).
+ Where(idQueryCondition, accountID).Updates(&AccountDNSSettings{DNSSettings: *settings}) + if result.Error != nil { + log.WithContext(ctx).Errorf("failed to save dns settings to store: %v", result.Error) + return status.Errorf(status.Internal, "failed to save dns settings to store") + } + + if result.RowsAffected == 0 { + return status.NewAccountNotFoundError(accountID) + } + + return nil +} diff --git a/management/server/sql_store_test.go b/management/server/sql_store_test.go index c05793fc6..df5294d73 100644 --- a/management/server/sql_store_test.go +++ b/management/server/sql_store_test.go @@ -1857,3 +1857,66 @@ func TestSqlStore_DeletePolicy(t *testing.T) { require.Error(t, err) require.Nil(t, policy) } + +func TestSqlStore_GetDNSSettings(t *testing.T) { + store, cleanup, err := NewTestStoreFromSQL(context.Background(), "testdata/store.sql", t.TempDir()) + t.Cleanup(cleanup) + require.NoError(t, err) + + tests := []struct { + name string + accountID string + expectError bool + }{ + { + name: "retrieve existing account dns settings", + accountID: "bf1c8084-ba50-4ce7-9439-34653001fc3b", + expectError: false, + }, + { + name: "retrieve non-existing account dns settings", + accountID: "non-existing", + expectError: true, + }, + { + name: "retrieve dns settings with empty account ID", + accountID: "", + expectError: true, + }, + } + + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + dnsSettings, err := store.GetAccountDNSSettings(context.Background(), LockingStrengthShare, tt.accountID) + if tt.expectError { + require.Error(t, err) + sErr, ok := status.FromError(err) + require.True(t, ok) + require.Equal(t, sErr.Type(), status.NotFound) + require.Nil(t, dnsSettings) + } else { + require.NoError(t, err) + require.NotNil(t, dnsSettings) + } + }) + } +} + +func TestSqlStore_SaveDNSSettings(t *testing.T) { + store, cleanup, err := NewTestStoreFromSQL(context.Background(), "testdata/store.sql", t.TempDir()) + t.Cleanup(cleanup) + require.NoError(t, err) + + accountID := "bf1c8084-ba50-4ce7-9439-34653001fc3b" + + dnsSettings, err := store.GetAccountDNSSettings(context.Background(), LockingStrengthShare, accountID) + require.NoError(t, err) + + dnsSettings.DisabledManagementGroups = []string{"groupA", "groupB"} + err = store.SaveDNSSettings(context.Background(), LockingStrengthUpdate, accountID, dnsSettings) + require.NoError(t, err) + + saveDNSSettings, err := store.GetAccountDNSSettings(context.Background(), LockingStrengthShare, accountID) + require.NoError(t, err) + require.Equal(t, saveDNSSettings, dnsSettings) +} diff --git a/management/server/store.go b/management/server/store.go index ba61d552d..cca014b52 100644 --- a/management/server/store.go +++ b/management/server/store.go @@ -59,6 +59,7 @@ type Store interface { SaveAccount(ctx context.Context, account *Account) error DeleteAccount(ctx context.Context, account *Account) error UpdateAccountDomainAttributes(ctx context.Context, accountID string, domain string, category string, isPrimaryDomain bool) error + SaveDNSSettings(ctx context.Context, lockStrength LockingStrength, accountID string, settings *DNSSettings) error GetUserByTokenID(ctx context.Context, tokenID string) (*User, error) GetUserByUserID(ctx context.Context, lockStrength LockingStrength, userID string) (*User, error) From 9683da54b06846f5d09dafc39e917f68d48f9067 Mon Sep 17 00:00:00 2001 From: Bethuel Mmbaga Date: Tue, 26 Nov 2024 19:39:04 +0300 Subject: [PATCH 041/159] [management] Refactor nameserver groups to use store methods (#2888) --- 
management/server/nameserver.go | 213 +++++++++++------- management/server/sql_store.go | 49 +++- management/server/sql_store_test.go | 131 +++++++++++ management/server/status/error.go | 5 + management/server/store.go | 2 + management/server/testdata/extended-store.sql | 1 + 6 files changed, 319 insertions(+), 82 deletions(-) diff --git a/management/server/nameserver.go b/management/server/nameserver.go index 9119a3dec..e7a5387a1 100644 --- a/management/server/nameserver.go +++ b/management/server/nameserver.go @@ -24,26 +24,34 @@ func (am *DefaultAccountManager) GetNameServerGroup(ctx context.Context, account return nil, err } - if !user.IsAdminOrServiceUser() || user.AccountID != accountID { - return nil, status.Errorf(status.PermissionDenied, "only users with admin power can view name server groups") + if user.AccountID != accountID { + return nil, status.NewUserNotPartOfAccountError() } - return am.Store.GetNameServerGroupByID(ctx, LockingStrengthShare, nsGroupID, accountID) + if user.IsRegularUser() { + return nil, status.NewAdminPermissionError() + } + + return am.Store.GetNameServerGroupByID(ctx, LockingStrengthShare, accountID, nsGroupID) } // CreateNameServerGroup creates and saves a new nameserver group func (am *DefaultAccountManager) CreateNameServerGroup(ctx context.Context, accountID string, name, description string, nameServerList []nbdns.NameServer, groups []string, primary bool, domains []string, enabled bool, userID string, searchDomainEnabled bool) (*nbdns.NameServerGroup, error) { - unlock := am.Store.AcquireWriteLockByUID(ctx, accountID) defer unlock() - account, err := am.Store.GetAccount(ctx, accountID) + user, err := am.Store.GetUserByUserID(ctx, LockingStrengthShare, userID) if err != nil { return nil, err } + if user.AccountID != accountID { + return nil, status.NewUserNotPartOfAccountError() + } + newNSGroup := &nbdns.NameServerGroup{ ID: xid.New().String(), + AccountID: accountID, Name: name, Description: description, NameServers: nameServerList, @@ -54,26 +62,33 @@ func (am *DefaultAccountManager) CreateNameServerGroup(ctx context.Context, acco SearchDomainsEnabled: searchDomainEnabled, } - err = validateNameServerGroup(false, newNSGroup, account) + var updateAccountPeers bool + + err = am.Store.ExecuteInTransaction(ctx, func(transaction Store) error { + if err = validateNameServerGroup(ctx, transaction, accountID, newNSGroup); err != nil { + return err + } + + updateAccountPeers, err = anyGroupHasPeers(ctx, transaction, accountID, newNSGroup.Groups) + if err != nil { + return err + } + + if err = transaction.IncrementNetworkSerial(ctx, LockingStrengthUpdate, accountID); err != nil { + return err + } + + return transaction.SaveNameServerGroup(ctx, LockingStrengthUpdate, newNSGroup) + }) if err != nil { return nil, err } - if account.NameServerGroups == nil { - account.NameServerGroups = make(map[string]*nbdns.NameServerGroup) - } + am.StoreEvent(ctx, userID, newNSGroup.ID, accountID, activity.NameserverGroupCreated, newNSGroup.EventMeta()) - account.NameServerGroups[newNSGroup.ID] = newNSGroup - - account.Network.IncSerial() - if err = am.Store.SaveAccount(ctx, account); err != nil { - return nil, err - } - - if am.anyGroupHasPeers(account, newNSGroup.Groups) { + if updateAccountPeers { am.updateAccountPeers(ctx, accountID) } - am.StoreEvent(ctx, userID, newNSGroup.ID, accountID, activity.NameserverGroupCreated, newNSGroup.EventMeta()) return newNSGroup.Copy(), nil } @@ -87,58 +102,95 @@ func (am *DefaultAccountManager) SaveNameServerGroup(ctx context.Context, 
accoun return status.Errorf(status.InvalidArgument, "nameserver group provided is nil") } - account, err := am.Store.GetAccount(ctx, accountID) + user, err := am.Store.GetUserByUserID(ctx, LockingStrengthShare, userID) if err != nil { return err } - err = validateNameServerGroup(true, nsGroupToSave, account) + if user.AccountID != accountID { + return status.NewUserNotPartOfAccountError() + } + + var updateAccountPeers bool + + err = am.Store.ExecuteInTransaction(ctx, func(transaction Store) error { + oldNSGroup, err := transaction.GetNameServerGroupByID(ctx, LockingStrengthShare, accountID, nsGroupToSave.ID) + if err != nil { + return err + } + nsGroupToSave.AccountID = accountID + + if err = validateNameServerGroup(ctx, transaction, accountID, nsGroupToSave); err != nil { + return err + } + + updateAccountPeers, err = areNameServerGroupChangesAffectPeers(ctx, transaction, nsGroupToSave, oldNSGroup) + if err != nil { + return err + } + + if err = transaction.IncrementNetworkSerial(ctx, LockingStrengthUpdate, accountID); err != nil { + return err + } + + return transaction.SaveNameServerGroup(ctx, LockingStrengthUpdate, nsGroupToSave) + }) if err != nil { return err } - oldNSGroup := account.NameServerGroups[nsGroupToSave.ID] - account.NameServerGroups[nsGroupToSave.ID] = nsGroupToSave + am.StoreEvent(ctx, userID, nsGroupToSave.ID, accountID, activity.NameserverGroupUpdated, nsGroupToSave.EventMeta()) - account.Network.IncSerial() - if err = am.Store.SaveAccount(ctx, account); err != nil { - return err - } - - if am.areNameServerGroupChangesAffectPeers(account, nsGroupToSave, oldNSGroup) { + if updateAccountPeers { am.updateAccountPeers(ctx, accountID) } - am.StoreEvent(ctx, userID, nsGroupToSave.ID, accountID, activity.NameserverGroupUpdated, nsGroupToSave.EventMeta()) return nil } // DeleteNameServerGroup deletes nameserver group with nsGroupID func (am *DefaultAccountManager) DeleteNameServerGroup(ctx context.Context, accountID, nsGroupID, userID string) error { - unlock := am.Store.AcquireWriteLockByUID(ctx, accountID) defer unlock() - account, err := am.Store.GetAccount(ctx, accountID) + user, err := am.Store.GetUserByUserID(ctx, LockingStrengthShare, userID) if err != nil { return err } - nsGroup := account.NameServerGroups[nsGroupID] - if nsGroup == nil { - return status.Errorf(status.NotFound, "nameserver group %s wasn't found", nsGroupID) + if user.AccountID != accountID { + return status.NewUserNotPartOfAccountError() } - delete(account.NameServerGroups, nsGroupID) - account.Network.IncSerial() - if err = am.Store.SaveAccount(ctx, account); err != nil { + var nsGroup *nbdns.NameServerGroup + var updateAccountPeers bool + + err = am.Store.ExecuteInTransaction(ctx, func(transaction Store) error { + nsGroup, err = transaction.GetNameServerGroupByID(ctx, LockingStrengthUpdate, accountID, nsGroupID) + if err != nil { + return err + } + + updateAccountPeers, err = anyGroupHasPeers(ctx, transaction, accountID, nsGroup.Groups) + if err != nil { + return err + } + + if err = transaction.IncrementNetworkSerial(ctx, LockingStrengthUpdate, accountID); err != nil { + return err + } + + return transaction.DeleteNameServerGroup(ctx, LockingStrengthUpdate, accountID, nsGroupID) + }) + if err != nil { return err } - if am.anyGroupHasPeers(account, nsGroup.Groups) { + am.StoreEvent(ctx, userID, nsGroup.ID, accountID, activity.NameserverGroupDeleted, nsGroup.EventMeta()) + + if updateAccountPeers { am.updateAccountPeers(ctx, accountID) } - am.StoreEvent(ctx, userID, nsGroup.ID, accountID, 
activity.NameserverGroupDeleted, nsGroup.EventMeta()) return nil } @@ -150,44 +202,62 @@ func (am *DefaultAccountManager) ListNameServerGroups(ctx context.Context, accou return nil, err } - if !user.IsAdminOrServiceUser() || user.AccountID != accountID { - return nil, status.Errorf(status.PermissionDenied, "only users with admin power can view name server groups") + if user.AccountID != accountID { + return nil, status.NewUserNotPartOfAccountError() + } + + if user.IsRegularUser() { + return nil, status.NewAdminPermissionError() } return am.Store.GetAccountNameServerGroups(ctx, LockingStrengthShare, accountID) } -func validateNameServerGroup(existingGroup bool, nameserverGroup *nbdns.NameServerGroup, account *Account) error { - nsGroupID := "" - if existingGroup { - nsGroupID = nameserverGroup.ID - _, found := account.NameServerGroups[nsGroupID] - if !found { - return status.Errorf(status.NotFound, "nameserver group with ID %s was not found", nsGroupID) - } - } - +func validateNameServerGroup(ctx context.Context, transaction Store, accountID string, nameserverGroup *nbdns.NameServerGroup) error { err := validateDomainInput(nameserverGroup.Primary, nameserverGroup.Domains, nameserverGroup.SearchDomainsEnabled) if err != nil { return err } - err = validateNSGroupName(nameserverGroup.Name, nsGroupID, account.NameServerGroups) - if err != nil { - return err - } - err = validateNSList(nameserverGroup.NameServers) if err != nil { return err } - err = validateGroups(nameserverGroup.Groups, account.Groups) + nsServerGroups, err := transaction.GetAccountNameServerGroups(ctx, LockingStrengthShare, accountID) if err != nil { return err } - return nil + err = validateNSGroupName(nameserverGroup.Name, nameserverGroup.ID, nsServerGroups) + if err != nil { + return err + } + + groups, err := transaction.GetGroupsByIDs(ctx, LockingStrengthShare, accountID, nameserverGroup.Groups) + if err != nil { + return err + } + + return validateGroups(nameserverGroup.Groups, groups) +} + +// areNameServerGroupChangesAffectPeers checks if the changes in the nameserver group affect the peers. 
+func areNameServerGroupChangesAffectPeers(ctx context.Context, transaction Store, newNSGroup, oldNSGroup *nbdns.NameServerGroup) (bool, error) { + if !newNSGroup.Enabled && !oldNSGroup.Enabled { + return false, nil + } + + hasPeers, err := anyGroupHasPeers(ctx, transaction, newNSGroup.AccountID, newNSGroup.Groups) + if err != nil { + return false, err + } + + if hasPeers { + return true, nil + } + + return anyGroupHasPeers(ctx, transaction, oldNSGroup.AccountID, oldNSGroup.Groups) } func validateDomainInput(primary bool, domains []string, searchDomainsEnabled bool) error { @@ -213,14 +283,14 @@ func validateDomainInput(primary bool, domains []string, searchDomainsEnabled bo return nil } -func validateNSGroupName(name, nsGroupID string, nsGroupMap map[string]*nbdns.NameServerGroup) error { +func validateNSGroupName(name, nsGroupID string, groups []*nbdns.NameServerGroup) error { if utf8.RuneCountInString(name) > nbdns.MaxGroupNameChar || name == "" { return status.Errorf(status.InvalidArgument, "nameserver group name should be between 1 and %d", nbdns.MaxGroupNameChar) } - for _, nsGroup := range nsGroupMap { + for _, nsGroup := range groups { if name == nsGroup.Name && nsGroup.ID != nsGroupID { - return status.Errorf(status.InvalidArgument, "a nameserver group with name %s already exist", name) + return status.Errorf(status.InvalidArgument, "nameserver group with name %s already exist", name) } } @@ -228,8 +298,8 @@ func validateNSGroupName(name, nsGroupID string, nsGroupMap map[string]*nbdns.Na } func validateNSList(list []nbdns.NameServer) error { - nsListLenght := len(list) - if nsListLenght == 0 || nsListLenght > 3 { + nsListLength := len(list) + if nsListLength == 0 || nsListLength > 3 { return status.Errorf(status.InvalidArgument, "the list of nameservers should be 1 or 3, got %d", len(list)) } return nil @@ -244,14 +314,7 @@ func validateGroups(list []string, groups map[string]*nbgroup.Group) error { if id == "" { return status.Errorf(status.InvalidArgument, "group ID should not be empty string") } - found := false - for groupID := range groups { - if id == groupID { - found = true - break - } - } - if !found { + if _, found := groups[id]; !found { return status.Errorf(status.InvalidArgument, "group id %s not found", id) } } @@ -277,11 +340,3 @@ func validateDomain(domain string) error { return nil } - -// areNameServerGroupChangesAffectPeers checks if the changes in the nameserver group affect the peers. -func (am *DefaultAccountManager) areNameServerGroupChangesAffectPeers(account *Account, newNSGroup, oldNSGroup *nbdns.NameServerGroup) bool { - if !newNSGroup.Enabled && !oldNSGroup.Enabled { - return false - } - return am.anyGroupHasPeers(account, newNSGroup.Groups) || am.anyGroupHasPeers(account, oldNSGroup.Groups) -} diff --git a/management/server/sql_store.go b/management/server/sql_store.go index f58ceb1ad..1fd8ae2aa 100644 --- a/management/server/sql_store.go +++ b/management/server/sql_store.go @@ -1498,12 +1498,55 @@ func (s *SqlStore) DeleteSetupKey(ctx context.Context, lockStrength LockingStren // GetAccountNameServerGroups retrieves name server groups for an account. 
func (s *SqlStore) GetAccountNameServerGroups(ctx context.Context, lockStrength LockingStrength, accountID string) ([]*nbdns.NameServerGroup, error) { - return getRecords[*nbdns.NameServerGroup](s.db, lockStrength, accountID) + var nsGroups []*nbdns.NameServerGroup + result := s.db.WithContext(ctx).Clauses(clause.Locking{Strength: string(lockStrength)}).Find(&nsGroups, accountIDCondition, accountID) + if err := result.Error; err != nil { + log.WithContext(ctx).Errorf("failed to get name server groups from the store: %s", err) + return nil, status.Errorf(status.Internal, "failed to get name server groups from store") + } + + return nsGroups, nil } // GetNameServerGroupByID retrieves a name server group by its ID and account ID. -func (s *SqlStore) GetNameServerGroupByID(ctx context.Context, lockStrength LockingStrength, nsGroupID string, accountID string) (*nbdns.NameServerGroup, error) { - return getRecordByID[nbdns.NameServerGroup](s.db, lockStrength, nsGroupID, accountID) +func (s *SqlStore) GetNameServerGroupByID(ctx context.Context, lockStrength LockingStrength, accountID, nsGroupID string) (*nbdns.NameServerGroup, error) { + var nsGroup *nbdns.NameServerGroup + result := s.db.WithContext(ctx).Clauses(clause.Locking{Strength: string(lockStrength)}). + First(&nsGroup, accountAndIDQueryCondition, accountID, nsGroupID) + if err := result.Error; err != nil { + if errors.Is(err, gorm.ErrRecordNotFound) { + return nil, status.NewNameServerGroupNotFoundError(nsGroupID) + } + log.WithContext(ctx).Errorf("failed to get name server group from the store: %s", err) + return nil, status.Errorf(status.Internal, "failed to get name server group from store") + } + + return nsGroup, nil +} + +// SaveNameServerGroup saves a name server group to the database. +func (s *SqlStore) SaveNameServerGroup(ctx context.Context, lockStrength LockingStrength, nameServerGroup *nbdns.NameServerGroup) error { + result := s.db.WithContext(ctx).Clauses(clause.Locking{Strength: string(lockStrength)}).Save(nameServerGroup) + if err := result.Error; err != nil { + log.WithContext(ctx).Errorf("failed to save name server group to the store: %s", err) + return status.Errorf(status.Internal, "failed to save name server group to store") + } + return nil +} + +// DeleteNameServerGroup deletes a name server group from the database. +func (s *SqlStore) DeleteNameServerGroup(ctx context.Context, lockStrength LockingStrength, accountID, nsGroupID string) error { + result := s.db.Clauses(clause.Locking{Strength: string(lockStrength)}).Delete(&nbdns.NameServerGroup{}, accountAndIDQueryCondition, accountID, nsGroupID) + if err := result.Error; err != nil { + log.WithContext(ctx).Errorf("failed to delete name server group from the store: %s", err) + return status.Errorf(status.Internal, "failed to delete name server group from store") + } + + if result.RowsAffected == 0 { + return status.NewNameServerGroupNotFoundError(nsGroupID) + } + + return nil } // getRecords retrieves records from the database based on the account ID. 
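
Note on usage (not part of the patch): the store methods added above are intended to be combined with the refactored validateNameServerGroup helper inside a single store transaction, mirroring the other write paths in this series. A minimal sketch of that call pattern follows; the wrapper name saveNSGroup and its variables are assumptions for illustration only, while ExecuteInTransaction, validateNameServerGroup, SaveNameServerGroup and LockingStrengthUpdate are the identifiers introduced or reused by this patch.

// saveNSGroup is a hypothetical helper sketching how the transactional
// store API above is meant to be used: validate against the current
// state read through the transaction, then persist the group.
func saveNSGroup(ctx context.Context, store Store, accountID string, nsGroup *nbdns.NameServerGroup) error {
	return store.ExecuteInTransaction(ctx, func(transaction Store) error {
		// The refactored validator reads groups and existing nameserver
		// groups via the transaction instead of a loaded Account object.
		if err := validateNameServerGroup(ctx, transaction, accountID, nsGroup); err != nil {
			return err
		}
		// Pass LockingStrengthUpdate through to the store, matching how
		// the other write operations in this patch request their locks.
		return transaction.SaveNameServerGroup(ctx, LockingStrengthUpdate, nsGroup)
	})
}
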
diff --git a/management/server/sql_store_test.go b/management/server/sql_store_test.go index df5294d73..6064b019f 100644 --- a/management/server/sql_store_test.go +++ b/management/server/sql_store_test.go @@ -1920,3 +1920,134 @@ func TestSqlStore_SaveDNSSettings(t *testing.T) { require.NoError(t, err) require.Equal(t, saveDNSSettings, dnsSettings) } + +func TestSqlStore_GetAccountNameServerGroups(t *testing.T) { + store, cleanup, err := NewTestStoreFromSQL(context.Background(), "testdata/extended-store.sql", t.TempDir()) + t.Cleanup(cleanup) + require.NoError(t, err) + + tests := []struct { + name string + accountID string + expectedCount int + }{ + { + name: "retrieve name server groups by existing account ID", + accountID: "bf1c8084-ba50-4ce7-9439-34653001fc3b", + expectedCount: 1, + }, + { + name: "non-existing account ID", + accountID: "nonexistent", + expectedCount: 0, + }, + { + name: "empty account ID", + accountID: "", + expectedCount: 0, + }, + } + + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + peers, err := store.GetAccountNameServerGroups(context.Background(), LockingStrengthShare, tt.accountID) + require.NoError(t, err) + require.Len(t, peers, tt.expectedCount) + }) + } + +} + +func TestSqlStore_GetNameServerByID(t *testing.T) { + store, cleanup, err := NewTestStoreFromSQL(context.Background(), "testdata/extended-store.sql", t.TempDir()) + t.Cleanup(cleanup) + require.NoError(t, err) + + accountID := "bf1c8084-ba50-4ce7-9439-34653001fc3b" + tests := []struct { + name string + nsGroupID string + expectError bool + }{ + { + name: "retrieve existing nameserver group", + nsGroupID: "csqdelq7qv97ncu7d9t0", + expectError: false, + }, + { + name: "retrieve non-existing nameserver group", + nsGroupID: "non-existing", + expectError: true, + }, + { + name: "retrieve with empty nameserver group ID", + nsGroupID: "", + expectError: true, + }, + } + + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + nsGroup, err := store.GetNameServerGroupByID(context.Background(), LockingStrengthShare, accountID, tt.nsGroupID) + if tt.expectError { + require.Error(t, err) + sErr, ok := status.FromError(err) + require.True(t, ok) + require.Equal(t, sErr.Type(), status.NotFound) + require.Nil(t, nsGroup) + } else { + require.NoError(t, err) + require.NotNil(t, nsGroup) + require.Equal(t, tt.nsGroupID, nsGroup.ID) + } + }) + } +} + +func TestSqlStore_SaveNameServerGroup(t *testing.T) { + store, cleanup, err := NewTestStoreFromSQL(context.Background(), "testdata/extended-store.sql", t.TempDir()) + t.Cleanup(cleanup) + require.NoError(t, err) + + accountID := "bf1c8084-ba50-4ce7-9439-34653001fc3b" + + nsGroup := &nbdns.NameServerGroup{ + ID: "ns-group-id", + AccountID: accountID, + Name: "NS Group", + NameServers: []nbdns.NameServer{ + { + IP: netip.MustParseAddr("8.8.8.8"), + NSType: 1, + Port: 53, + }, + }, + Groups: []string{"groupA"}, + Primary: true, + Enabled: true, + SearchDomainsEnabled: false, + } + + err = store.SaveNameServerGroup(context.Background(), LockingStrengthUpdate, nsGroup) + require.NoError(t, err) + + saveNSGroup, err := store.GetNameServerGroupByID(context.Background(), LockingStrengthShare, accountID, nsGroup.ID) + require.NoError(t, err) + require.Equal(t, saveNSGroup, nsGroup) +} + +func TestSqlStore_DeleteNameServerGroup(t *testing.T) { + store, cleanup, err := NewTestStoreFromSQL(context.Background(), "testdata/extended-store.sql", t.TempDir()) + t.Cleanup(cleanup) + require.NoError(t, err) + + accountID := 
"bf1c8084-ba50-4ce7-9439-34653001fc3b" + nsGroupID := "csqdelq7qv97ncu7d9t0" + + err = store.DeleteNameServerGroup(context.Background(), LockingStrengthShare, accountID, nsGroupID) + require.NoError(t, err) + + nsGroup, err := store.GetNameServerGroupByID(context.Background(), LockingStrengthShare, accountID, nsGroupID) + require.Error(t, err) + require.Nil(t, nsGroup) +} diff --git a/management/server/status/error.go b/management/server/status/error.go index 0fff53559..59f436f5b 100644 --- a/management/server/status/error.go +++ b/management/server/status/error.go @@ -149,3 +149,8 @@ func NewPostureChecksNotFoundError(postureChecksID string) error { func NewPolicyNotFoundError(policyID string) error { return Errorf(NotFound, "policy: %s not found", policyID) } + +// NewNameServerGroupNotFoundError creates a new Error with NotFound type for a missing name server group +func NewNameServerGroupNotFoundError(nsGroupID string) error { + return Errorf(NotFound, "nameserver group: %s not found", nsGroupID) +} diff --git a/management/server/store.go b/management/server/store.go index cca014b52..b16ad8a1a 100644 --- a/management/server/store.go +++ b/management/server/store.go @@ -117,6 +117,8 @@ type Store interface { GetAccountNameServerGroups(ctx context.Context, lockStrength LockingStrength, accountID string) ([]*dns.NameServerGroup, error) GetNameServerGroupByID(ctx context.Context, lockStrength LockingStrength, nameServerGroupID string, accountID string) (*dns.NameServerGroup, error) + SaveNameServerGroup(ctx context.Context, lockStrength LockingStrength, nameServerGroup *dns.NameServerGroup) error + DeleteNameServerGroup(ctx context.Context, lockStrength LockingStrength, accountID, nameServerGroupID string) error GetTakenIPs(ctx context.Context, lockStrength LockingStrength, accountId string) ([]net.IP, error) IncrementNetworkSerial(ctx context.Context, lockStrength LockingStrength, accountId string) error diff --git a/management/server/testdata/extended-store.sql b/management/server/testdata/extended-store.sql index 37db27316..455111439 100644 --- a/management/server/testdata/extended-store.sql +++ b/management/server/testdata/extended-store.sql @@ -36,4 +36,5 @@ INSERT INTO "groups" VALUES('cfefqs706sqkneg59g3g','bf1c8084-ba50-4ce7-9439-3465 INSERT INTO "groups" VALUES('cfefqs706sqkneg59g2g','bf1c8084-ba50-4ce7-9439-34653001fc3b','AwesomeGroup2','api','[]',0,''); INSERT INTO posture_checks VALUES('csplshq7qv948l48f7t0','NetBird Version > 0.32.0','','bf1c8084-ba50-4ce7-9439-34653001fc3b','{"NBVersionCheck":{"MinVersion":"0.31.0"}}'); INSERT INTO posture_checks VALUES('cspnllq7qv95uq1r4k90','Allow Berlin and Deny local network 172.16.1.0/24','','bf1c8084-ba50-4ce7-9439-34653001fc3b','{"GeoLocationCheck":{"Locations":[{"CountryCode":"DE","CityName":"Berlin"}],"Action":"allow"},"PeerNetworkRangeCheck":{"Action":"deny","Ranges":["172.16.1.0/24"]}}'); +INSERT INTO name_server_groups VALUES('csqdelq7qv97ncu7d9t0','bf1c8084-ba50-4ce7-9439-34653001fc3b','Google DNS','Google DNS Servers','[{"IP":"8.8.8.8","NSType":1,"Port":53},{"IP":"8.8.4.4","NSType":1,"Port":53}]','["cfefqs706sqkneg59g2g"]',1,'[]',1,0); INSERT INTO installations VALUES(1,''); From 92036900337fce22e5afce7a86f2781369ea6fc8 Mon Sep 17 00:00:00 2001 From: Zoltan Papp Date: Tue, 26 Nov 2024 23:34:27 +0100 Subject: [PATCH 042/159] [client] Code cleaning in net pkg and fix exit node feature on Android(#2932) Code cleaning around the util/net package. 
The goal was to write a more understandable source code but modify nothing on the logic. Protect the WireGuard UDP listeners with marks. The implementation can support the VPN permission revocation events in thread safe way. It will be important if we start to support the running time route and DNS update features. - uniformize the file name convention: [struct_name] _ [functions] _ [os].go - code cleaning in net_linux.go - move env variables to env.go file --- client/iface/bind/control_android.go | 12 ++++ .../routemanager/systemops/systemops_linux.go | 2 +- go.mod | 2 +- go.sum | 4 +- util/net/conn.go | 31 ++++++++ util/net/dial.go | 58 +++++++++++++++ util/net/{dialer_ios.go => dial_ios.go} | 0 util/net/dialer_android.go | 25 ------- util/net/{dialer_nonios.go => dialer_dial.go} | 70 ------------------- util/net/dialer_init_android.go | 5 ++ .../{dialer_linux.go => dialer_init_linux.go} | 2 +- ...er_nonlinux.go => dialer_init_nonlinux.go} | 1 + util/net/env.go | 29 ++++++++ util/net/listen.go | 37 ++++++++++ util/net/{listener_ios.go => listen_ios.go} | 0 util/net/listener_android.go | 26 ------- util/net/listener_init_android.go | 6 ++ ...stener_linux.go => listener_init_linux.go} | 2 +- ..._nonlinux.go => listener_init_nonlinux.go} | 1 + ...{listener_nonios.go => listener_listen.go} | 25 ------- util/net/net.go | 12 ---- util/net/net_linux.go | 50 ++++++++----- util/net/protectsocket_android.go | 34 ++++++++- 23 files changed, 249 insertions(+), 185 deletions(-) create mode 100644 client/iface/bind/control_android.go create mode 100644 util/net/conn.go create mode 100644 util/net/dial.go rename util/net/{dialer_ios.go => dial_ios.go} (100%) delete mode 100644 util/net/dialer_android.go rename util/net/{dialer_nonios.go => dialer_dial.go} (63%) create mode 100644 util/net/dialer_init_android.go rename util/net/{dialer_linux.go => dialer_init_linux.go} (88%) rename util/net/{dialer_nonlinux.go => dialer_init_nonlinux.go} (58%) create mode 100644 util/net/env.go create mode 100644 util/net/listen.go rename util/net/{listener_ios.go => listen_ios.go} (100%) delete mode 100644 util/net/listener_android.go create mode 100644 util/net/listener_init_android.go rename util/net/{listener_linux.go => listener_init_linux.go} (89%) rename util/net/{listener_nonlinux.go => listener_init_nonlinux.go} (61%) rename util/net/{listener_nonios.go => listener_listen.go} (84%) diff --git a/client/iface/bind/control_android.go b/client/iface/bind/control_android.go new file mode 100644 index 000000000..b8a865e39 --- /dev/null +++ b/client/iface/bind/control_android.go @@ -0,0 +1,12 @@ +package bind + +import ( + wireguard "golang.zx2c4.com/wireguard/conn" + + nbnet "github.com/netbirdio/netbird/util/net" +) + +func init() { + // ControlFns is not thread safe and should only be modified during init. 
+ *wireguard.ControlFns = append(*wireguard.ControlFns, nbnet.ControlProtectSocket) +} diff --git a/client/internal/routemanager/systemops/systemops_linux.go b/client/internal/routemanager/systemops/systemops_linux.go index 1d629d6e9..455e3407e 100644 --- a/client/internal/routemanager/systemops/systemops_linux.go +++ b/client/internal/routemanager/systemops/systemops_linux.go @@ -55,7 +55,7 @@ type ruleParams struct { // isLegacy determines whether to use the legacy routing setup func isLegacy() bool { - return os.Getenv("NB_USE_LEGACY_ROUTING") == "true" || nbnet.CustomRoutingDisabled() || os.Getenv(nbnet.EnvSkipSocketMark) == "true" + return os.Getenv("NB_USE_LEGACY_ROUTING") == "true" || nbnet.CustomRoutingDisabled() || nbnet.SkipSocketMark() } // setIsLegacy sets the legacy routing setup diff --git a/go.mod b/go.mod index 0a16753ea..e8c655422 100644 --- a/go.mod +++ b/go.mod @@ -236,7 +236,7 @@ replace github.com/kardianos/service => github.com/netbirdio/service v0.0.0-2024 replace github.com/getlantern/systray => github.com/netbirdio/systray v0.0.0-20231030152038-ef1ed2a27949 -replace golang.zx2c4.com/wireguard => github.com/netbirdio/wireguard-go v0.0.0-20241107152827-57d8513b5f73 +replace golang.zx2c4.com/wireguard => github.com/netbirdio/wireguard-go v0.0.0-20241125150134-f9cdce5e32e9 replace github.com/cloudflare/circl => github.com/cunicu/circl v0.0.0-20230801113412-fec58fc7b5f6 diff --git a/go.sum b/go.sum index a4d7ea7f9..47975d4ea 100644 --- a/go.sum +++ b/go.sum @@ -527,8 +527,8 @@ github.com/netbirdio/service v0.0.0-20240911161631-f62744f42502 h1:3tHlFmhTdX9ax github.com/netbirdio/service v0.0.0-20240911161631-f62744f42502/go.mod h1:CIMRFEJVL+0DS1a3Nx06NaMn4Dz63Ng6O7dl0qH0zVM= github.com/netbirdio/signal-dispatcher/dispatcher v0.0.0-20241010133937-e0df50df217d h1:bRq5TKgC7Iq20pDiuC54yXaWnAVeS5PdGpSokFTlR28= github.com/netbirdio/signal-dispatcher/dispatcher v0.0.0-20241010133937-e0df50df217d/go.mod h1:5/sjFmLb8O96B5737VCqhHyGRzNFIaN/Bu7ZodXc3qQ= -github.com/netbirdio/wireguard-go v0.0.0-20241107152827-57d8513b5f73 h1:jayg97LH/jJlvpIHVxueTfa+tfQ+FY8fy2sIhCwkz0g= -github.com/netbirdio/wireguard-go v0.0.0-20241107152827-57d8513b5f73/go.mod h1:tkCQ4FQXmpAgYVh++1cq16/dH4QJtmvpRv19DWGAHSA= +github.com/netbirdio/wireguard-go v0.0.0-20241125150134-f9cdce5e32e9 h1:Pu/7EukijT09ynHUOzQYW7cC3M/BKU8O4qyN/TvTGoY= +github.com/netbirdio/wireguard-go v0.0.0-20241125150134-f9cdce5e32e9/go.mod h1:tkCQ4FQXmpAgYVh++1cq16/dH4QJtmvpRv19DWGAHSA= github.com/nicksnyder/go-i18n/v2 v2.4.0 h1:3IcvPOAvnCKwNm0TB0dLDTuawWEj+ax/RERNC+diLMM= github.com/nicksnyder/go-i18n/v2 v2.4.0/go.mod h1:nxYSZE9M0bf3Y70gPQjN9ha7XNHX7gMc814+6wVyEI4= github.com/nxadm/tail v1.4.4/go.mod h1:kenIhsEOeOJmVchQTgglprH7qJGnHDVpk1VPCcaMI8A= diff --git a/util/net/conn.go b/util/net/conn.go new file mode 100644 index 000000000..26693f841 --- /dev/null +++ b/util/net/conn.go @@ -0,0 +1,31 @@ +//go:build !ios + +package net + +import ( + "net" + + log "github.com/sirupsen/logrus" +) + +// Conn wraps a net.Conn to override the Close method +type Conn struct { + net.Conn + ID ConnectionID +} + +// Close overrides the net.Conn Close method to execute all registered hooks after closing the connection +func (c *Conn) Close() error { + err := c.Conn.Close() + + dialerCloseHooksMutex.RLock() + defer dialerCloseHooksMutex.RUnlock() + + for _, hook := range dialerCloseHooks { + if err := hook(c.ID, &c.Conn); err != nil { + log.Errorf("Error executing dialer close hook: %v", err) + } + } + + return err +} diff --git a/util/net/dial.go 
b/util/net/dial.go new file mode 100644 index 000000000..595311492 --- /dev/null +++ b/util/net/dial.go @@ -0,0 +1,58 @@ +//go:build !ios + +package net + +import ( + "fmt" + "net" + + log "github.com/sirupsen/logrus" +) + +func DialUDP(network string, laddr, raddr *net.UDPAddr) (*net.UDPConn, error) { + if CustomRoutingDisabled() { + return net.DialUDP(network, laddr, raddr) + } + + dialer := NewDialer() + dialer.LocalAddr = laddr + + conn, err := dialer.Dial(network, raddr.String()) + if err != nil { + return nil, fmt.Errorf("dialing UDP %s: %w", raddr.String(), err) + } + + udpConn, ok := conn.(*Conn).Conn.(*net.UDPConn) + if !ok { + if err := conn.Close(); err != nil { + log.Errorf("Failed to close connection: %v", err) + } + return nil, fmt.Errorf("expected UDP connection, got different type: %T", conn) + } + + return udpConn, nil +} + +func DialTCP(network string, laddr, raddr *net.TCPAddr) (*net.TCPConn, error) { + if CustomRoutingDisabled() { + return net.DialTCP(network, laddr, raddr) + } + + dialer := NewDialer() + dialer.LocalAddr = laddr + + conn, err := dialer.Dial(network, raddr.String()) + if err != nil { + return nil, fmt.Errorf("dialing TCP %s: %w", raddr.String(), err) + } + + tcpConn, ok := conn.(*Conn).Conn.(*net.TCPConn) + if !ok { + if err := conn.Close(); err != nil { + log.Errorf("Failed to close connection: %v", err) + } + return nil, fmt.Errorf("expected TCP connection, got different type: %T", conn) + } + + return tcpConn, nil +} diff --git a/util/net/dialer_ios.go b/util/net/dial_ios.go similarity index 100% rename from util/net/dialer_ios.go rename to util/net/dial_ios.go diff --git a/util/net/dialer_android.go b/util/net/dialer_android.go deleted file mode 100644 index 4cbded536..000000000 --- a/util/net/dialer_android.go +++ /dev/null @@ -1,25 +0,0 @@ -package net - -import ( - "syscall" - - log "github.com/sirupsen/logrus" -) - -func (d *Dialer) init() { - d.Dialer.Control = func(_, _ string, c syscall.RawConn) error { - err := c.Control(func(fd uintptr) { - androidProtectSocketLock.Lock() - f := androidProtectSocket - androidProtectSocketLock.Unlock() - if f == nil { - return - } - ok := f(int32(fd)) - if !ok { - log.Errorf("failed to protect socket: %d", fd) - } - }) - return err - } -} diff --git a/util/net/dialer_nonios.go b/util/net/dialer_dial.go similarity index 63% rename from util/net/dialer_nonios.go rename to util/net/dialer_dial.go index 34004a368..1659b6220 100644 --- a/util/net/dialer_nonios.go +++ b/util/net/dialer_dial.go @@ -81,28 +81,6 @@ func (d *Dialer) Dial(network, address string) (net.Conn, error) { return d.DialContext(context.Background(), network, address) } -// Conn wraps a net.Conn to override the Close method -type Conn struct { - net.Conn - ID ConnectionID -} - -// Close overrides the net.Conn Close method to execute all registered hooks after closing the connection -func (c *Conn) Close() error { - err := c.Conn.Close() - - dialerCloseHooksMutex.RLock() - defer dialerCloseHooksMutex.RUnlock() - - for _, hook := range dialerCloseHooks { - if err := hook(c.ID, &c.Conn); err != nil { - log.Errorf("Error executing dialer close hook: %v", err) - } - } - - return err -} - func callDialerHooks(ctx context.Context, connID ConnectionID, address string, resolver *net.Resolver) error { host, _, err := net.SplitHostPort(address) if err != nil { @@ -127,51 +105,3 @@ func callDialerHooks(ctx context.Context, connID ConnectionID, address string, r return result.ErrorOrNil() } - -func DialUDP(network string, laddr, raddr *net.UDPAddr) 
(*net.UDPConn, error) { - if CustomRoutingDisabled() { - return net.DialUDP(network, laddr, raddr) - } - - dialer := NewDialer() - dialer.LocalAddr = laddr - - conn, err := dialer.Dial(network, raddr.String()) - if err != nil { - return nil, fmt.Errorf("dialing UDP %s: %w", raddr.String(), err) - } - - udpConn, ok := conn.(*Conn).Conn.(*net.UDPConn) - if !ok { - if err := conn.Close(); err != nil { - log.Errorf("Failed to close connection: %v", err) - } - return nil, fmt.Errorf("expected UDP connection, got different type: %T", conn) - } - - return udpConn, nil -} - -func DialTCP(network string, laddr, raddr *net.TCPAddr) (*net.TCPConn, error) { - if CustomRoutingDisabled() { - return net.DialTCP(network, laddr, raddr) - } - - dialer := NewDialer() - dialer.LocalAddr = laddr - - conn, err := dialer.Dial(network, raddr.String()) - if err != nil { - return nil, fmt.Errorf("dialing TCP %s: %w", raddr.String(), err) - } - - tcpConn, ok := conn.(*Conn).Conn.(*net.TCPConn) - if !ok { - if err := conn.Close(); err != nil { - log.Errorf("Failed to close connection: %v", err) - } - return nil, fmt.Errorf("expected TCP connection, got different type: %T", conn) - } - - return tcpConn, nil -} diff --git a/util/net/dialer_init_android.go b/util/net/dialer_init_android.go new file mode 100644 index 000000000..63b903348 --- /dev/null +++ b/util/net/dialer_init_android.go @@ -0,0 +1,5 @@ +package net + +func (d *Dialer) init() { + d.Dialer.Control = ControlProtectSocket +} diff --git a/util/net/dialer_linux.go b/util/net/dialer_init_linux.go similarity index 88% rename from util/net/dialer_linux.go rename to util/net/dialer_init_linux.go index aed5c59a3..d801e6080 100644 --- a/util/net/dialer_linux.go +++ b/util/net/dialer_init_linux.go @@ -7,6 +7,6 @@ import "syscall" // init configures the net.Dialer Control function to set the fwmark on the socket func (d *Dialer) init() { d.Dialer.Control = func(_, _ string, c syscall.RawConn) error { - return SetRawSocketMark(c) + return setRawSocketMark(c) } } diff --git a/util/net/dialer_nonlinux.go b/util/net/dialer_init_nonlinux.go similarity index 58% rename from util/net/dialer_nonlinux.go rename to util/net/dialer_init_nonlinux.go index c838441bd..8c57ebbaa 100644 --- a/util/net/dialer_nonlinux.go +++ b/util/net/dialer_init_nonlinux.go @@ -3,4 +3,5 @@ package net func (d *Dialer) init() { + // implemented on Linux and Android only } diff --git a/util/net/env.go b/util/net/env.go new file mode 100644 index 000000000..099da39b7 --- /dev/null +++ b/util/net/env.go @@ -0,0 +1,29 @@ +package net + +import ( + "os" + + log "github.com/sirupsen/logrus" + + "github.com/netbirdio/netbird/client/iface/netstack" +) + +const ( + envDisableCustomRouting = "NB_DISABLE_CUSTOM_ROUTING" + envSkipSocketMark = "NB_SKIP_SOCKET_MARK" +) + +func CustomRoutingDisabled() bool { + if netstack.IsEnabled() { + return true + } + return os.Getenv(envDisableCustomRouting) == "true" +} + +func SkipSocketMark() bool { + if skipSocketMark := os.Getenv(envSkipSocketMark); skipSocketMark == "true" { + log.Infof("%s is set to true, skipping SO_MARK", envSkipSocketMark) + return true + } + return false +} diff --git a/util/net/listen.go b/util/net/listen.go new file mode 100644 index 000000000..3ae8a9435 --- /dev/null +++ b/util/net/listen.go @@ -0,0 +1,37 @@ +//go:build !ios + +package net + +import ( + "context" + "fmt" + "net" + "sync" + + "github.com/pion/transport/v3" + log "github.com/sirupsen/logrus" +) + +// ListenUDP listens on the network address and returns a transport.UDPConn +// 
which includes support for write and close hooks. +func ListenUDP(network string, laddr *net.UDPAddr) (transport.UDPConn, error) { + if CustomRoutingDisabled() { + return net.ListenUDP(network, laddr) + } + + conn, err := NewListener().ListenPacket(context.Background(), network, laddr.String()) + if err != nil { + return nil, fmt.Errorf("listen UDP: %w", err) + } + + packetConn := conn.(*PacketConn) + udpConn, ok := packetConn.PacketConn.(*net.UDPConn) + if !ok { + if err := packetConn.Close(); err != nil { + log.Errorf("Failed to close connection: %v", err) + } + return nil, fmt.Errorf("expected UDPConn, got different type: %T", udpConn) + } + + return &UDPConn{UDPConn: udpConn, ID: packetConn.ID, seenAddrs: &sync.Map{}}, nil +} diff --git a/util/net/listener_ios.go b/util/net/listen_ios.go similarity index 100% rename from util/net/listener_ios.go rename to util/net/listen_ios.go diff --git a/util/net/listener_android.go b/util/net/listener_android.go deleted file mode 100644 index d4167ad53..000000000 --- a/util/net/listener_android.go +++ /dev/null @@ -1,26 +0,0 @@ -package net - -import ( - "syscall" - - log "github.com/sirupsen/logrus" -) - -// init configures the net.ListenerConfig Control function to set the fwmark on the socket -func (l *ListenerConfig) init() { - l.ListenConfig.Control = func(_, _ string, c syscall.RawConn) error { - err := c.Control(func(fd uintptr) { - androidProtectSocketLock.Lock() - f := androidProtectSocket - androidProtectSocketLock.Unlock() - if f == nil { - return - } - ok := f(int32(fd)) - if !ok { - log.Errorf("failed to protect listener socket: %d", fd) - } - }) - return err - } -} diff --git a/util/net/listener_init_android.go b/util/net/listener_init_android.go new file mode 100644 index 000000000..f7bfa1dab --- /dev/null +++ b/util/net/listener_init_android.go @@ -0,0 +1,6 @@ +package net + +// init configures the net.ListenerConfig Control function to set the fwmark on the socket +func (l *ListenerConfig) init() { + l.ListenConfig.Control = ControlProtectSocket +} diff --git a/util/net/listener_linux.go b/util/net/listener_init_linux.go similarity index 89% rename from util/net/listener_linux.go rename to util/net/listener_init_linux.go index 8d332160a..e32d5d894 100644 --- a/util/net/listener_linux.go +++ b/util/net/listener_init_linux.go @@ -9,6 +9,6 @@ import ( // init configures the net.ListenerConfig Control function to set the fwmark on the socket func (l *ListenerConfig) init() { l.ListenConfig.Control = func(_, _ string, c syscall.RawConn) error { - return SetRawSocketMark(c) + return setRawSocketMark(c) } } diff --git a/util/net/listener_nonlinux.go b/util/net/listener_init_nonlinux.go similarity index 61% rename from util/net/listener_nonlinux.go rename to util/net/listener_init_nonlinux.go index 14a6be49d..80f6f7f1a 100644 --- a/util/net/listener_nonlinux.go +++ b/util/net/listener_init_nonlinux.go @@ -3,4 +3,5 @@ package net func (l *ListenerConfig) init() { + // implemented on Linux and Android only } diff --git a/util/net/listener_nonios.go b/util/net/listener_listen.go similarity index 84% rename from util/net/listener_nonios.go rename to util/net/listener_listen.go index ae4be3494..efffba40e 100644 --- a/util/net/listener_nonios.go +++ b/util/net/listener_listen.go @@ -8,7 +8,6 @@ import ( "net" "sync" - "github.com/pion/transport/v3" log "github.com/sirupsen/logrus" ) @@ -146,27 +145,3 @@ func closeConn(id ConnectionID, conn net.PacketConn) error { return err } - -// ListenUDP listens on the network address and returns a 
transport.UDPConn -// which includes support for write and close hooks. -func ListenUDP(network string, laddr *net.UDPAddr) (transport.UDPConn, error) { - if CustomRoutingDisabled() { - return net.ListenUDP(network, laddr) - } - - conn, err := NewListener().ListenPacket(context.Background(), network, laddr.String()) - if err != nil { - return nil, fmt.Errorf("listen UDP: %w", err) - } - - packetConn := conn.(*PacketConn) - udpConn, ok := packetConn.PacketConn.(*net.UDPConn) - if !ok { - if err := packetConn.Close(); err != nil { - log.Errorf("Failed to close connection: %v", err) - } - return nil, fmt.Errorf("expected UDPConn, got different type: %T", udpConn) - } - - return &UDPConn{UDPConn: udpConn, ID: packetConn.ID, seenAddrs: &sync.Map{}}, nil -} diff --git a/util/net/net.go b/util/net/net.go index 5448eb85a..403aa87e7 100644 --- a/util/net/net.go +++ b/util/net/net.go @@ -2,9 +2,6 @@ package net import ( "net" - "os" - - "github.com/netbirdio/netbird/client/iface/netstack" "github.com/google/uuid" ) @@ -16,8 +13,6 @@ const ( PreroutingFwmarkRedirected = 0x1BD01 PreroutingFwmarkMasquerade = 0x1BD11 PreroutingFwmarkMasqueradeReturn = 0x1BD12 - - envDisableCustomRouting = "NB_DISABLE_CUSTOM_ROUTING" ) // ConnectionID provides a globally unique identifier for network connections. @@ -31,10 +26,3 @@ type RemoveHookFunc func(connID ConnectionID) error func GenerateConnID() ConnectionID { return ConnectionID(uuid.NewString()) } - -func CustomRoutingDisabled() bool { - if netstack.IsEnabled() { - return true - } - return os.Getenv(envDisableCustomRouting) == "true" -} diff --git a/util/net/net_linux.go b/util/net/net_linux.go index 98f49af8d..fc486ebd4 100644 --- a/util/net/net_linux.go +++ b/util/net/net_linux.go @@ -4,29 +4,42 @@ package net import ( "fmt" - "os" "syscall" log "github.com/sirupsen/logrus" ) -const EnvSkipSocketMark = "NB_SKIP_SOCKET_MARK" - // SetSocketMark sets the SO_MARK option on the given socket connection func SetSocketMark(conn syscall.Conn) error { + if isSocketMarkDisabled() { + return nil + } + sysconn, err := conn.SyscallConn() if err != nil { return fmt.Errorf("get raw conn: %w", err) } - return SetRawSocketMark(sysconn) + return setRawSocketMark(sysconn) } -func SetRawSocketMark(conn syscall.RawConn) error { +// SetSocketOpt sets the SO_MARK option on the given file descriptor +func SetSocketOpt(fd int) error { + if isSocketMarkDisabled() { + return nil + } + + return setSocketOptInt(fd) +} + +func setRawSocketMark(conn syscall.RawConn) error { var setErr error err := conn.Control(func(fd uintptr) { - setErr = SetSocketOpt(int(fd)) + if isSocketMarkDisabled() { + return + } + setErr = setSocketOptInt(int(fd)) }) if err != nil { return fmt.Errorf("control: %w", err) @@ -39,17 +52,18 @@ func SetRawSocketMark(conn syscall.RawConn) error { return nil } -func SetSocketOpt(fd int) error { - if CustomRoutingDisabled() { - log.Infof("Custom routing is disabled, skipping SO_MARK") - return nil - } - - // Check for the new environment variable - if skipSocketMark := os.Getenv(EnvSkipSocketMark); skipSocketMark == "true" { - log.Info("NB_SKIP_SOCKET_MARK is set to true, skipping SO_MARK") - return nil - } - +func setSocketOptInt(fd int) error { return syscall.SetsockoptInt(fd, syscall.SOL_SOCKET, syscall.SO_MARK, NetbirdFwmark) } + +func isSocketMarkDisabled() bool { + if CustomRoutingDisabled() { + log.Infof("Custom routing is disabled, skipping SO_MARK") + return true + } + + if SkipSocketMark() { + return true + } + return false +} diff --git 
a/util/net/protectsocket_android.go b/util/net/protectsocket_android.go index 64fb45aa4..febed8a1e 100644 --- a/util/net/protectsocket_android.go +++ b/util/net/protectsocket_android.go @@ -1,14 +1,42 @@ package net -import "sync" +import ( + "fmt" + "sync" + "syscall" +) var ( androidProtectSocketLock sync.Mutex androidProtectSocket func(fd int32) bool ) -func SetAndroidProtectSocketFn(f func(fd int32) bool) { +func SetAndroidProtectSocketFn(fn func(fd int32) bool) { androidProtectSocketLock.Lock() - androidProtectSocket = f + androidProtectSocket = fn androidProtectSocketLock.Unlock() } + +// ControlProtectSocket is a Control function that sets the fwmark on the socket +func ControlProtectSocket(_, _ string, c syscall.RawConn) error { + var aErr error + err := c.Control(func(fd uintptr) { + androidProtectSocketLock.Lock() + defer androidProtectSocketLock.Unlock() + + if androidProtectSocket == nil { + aErr = fmt.Errorf("socket protection function not set") + return + } + + if !androidProtectSocket(int32(fd)) { + aErr = fmt.Errorf("failed to protect socket via Android") + } + }) + + if err != nil { + return err + } + + return aErr +} From fde9f2ffdafda92802753afeef2ac029a932302b Mon Sep 17 00:00:00 2001 From: bcmmbaga Date: Thu, 28 Nov 2024 12:18:02 +0300 Subject: [PATCH 043/159] Add store locks and prevent fetching setup keys peers when retrieving user peers with empty userID Signed-off-by: bcmmbaga --- management/server/peer.go | 6 +++--- management/server/sql_store.go | 24 ++++++++++++++++-------- management/server/store.go | 6 +++--- 3 files changed, 22 insertions(+), 14 deletions(-) diff --git a/management/server/peer.go b/management/server/peer.go index ff5bc23d5..586f6d919 100644 --- a/management/server/peer.go +++ b/management/server/peer.go @@ -558,21 +558,21 @@ func (am *DefaultAccountManager) AddPeer(ctx context.Context, setupKey, userID s } newPeer = am.integratedPeerValidator.PreparePeer(ctx, accountID, newPeer, groupsToAdd, settings.Extra) - err = transaction.AddPeerToAllGroup(ctx, accountID, newPeer.ID) + err = transaction.AddPeerToAllGroup(ctx, LockingStrengthUpdate, accountID, newPeer.ID) if err != nil { return fmt.Errorf("failed adding peer to All group: %w", err) } if len(groupsToAdd) > 0 { for _, g := range groupsToAdd { - err = transaction.AddPeerToGroup(ctx, accountID, newPeer.ID, g) + err = transaction.AddPeerToGroup(ctx, LockingStrengthUpdate, accountID, newPeer.ID, g) if err != nil { return err } } } - err = transaction.AddPeerToAccount(ctx, newPeer) + err = transaction.AddPeerToAccount(ctx, LockingStrengthUpdate, newPeer) if err != nil { return fmt.Errorf("failed to add peer to account: %w", err) } diff --git a/management/server/sql_store.go b/management/server/sql_store.go index 84c7ab8a9..1280cc888 100644 --- a/management/server/sql_store.go +++ b/management/server/sql_store.go @@ -1030,9 +1030,10 @@ func (s *SqlStore) IncrementSetupKeyUsage(ctx context.Context, setupKeyID string return nil } -func (s *SqlStore) AddPeerToAllGroup(ctx context.Context, accountID string, peerID string) error { +func (s *SqlStore) AddPeerToAllGroup(ctx context.Context, lockStrength LockingStrength, accountID string, peerID string) error { var group nbgroup.Group - result := s.db.Where("account_id = ? AND name = ?", accountID, "All").First(&group) + result := s.db.Clauses(clause.Locking{Strength: string(lockStrength)}). + First(&group, "account_id = ? 
AND name = ?", accountID, "All") if result.Error != nil { if errors.Is(result.Error, gorm.ErrRecordNotFound) { return status.Errorf(status.NotFound, "group 'All' not found for account") @@ -1048,16 +1049,17 @@ func (s *SqlStore) AddPeerToAllGroup(ctx context.Context, accountID string, peer group.Peers = append(group.Peers, peerID) - if err := s.db.Save(&group).Error; err != nil { + if err := s.db.Clauses(clause.Locking{Strength: string(lockStrength)}).Save(&group).Error; err != nil { return status.Errorf(status.Internal, "issue updating group 'All': %s", err) } return nil } -func (s *SqlStore) AddPeerToGroup(ctx context.Context, accountId string, peerId string, groupID string) error { +func (s *SqlStore) AddPeerToGroup(ctx context.Context, lockStrength LockingStrength, accountId string, peerId string, groupID string) error { var group nbgroup.Group - result := s.db.Where(accountAndIDQueryCondition, accountId, groupID).First(&group) + result := s.db.Clauses(clause.Locking{Strength: string(lockStrength)}).Where(accountAndIDQueryCondition, accountId, groupID). + First(&group) if result.Error != nil { if errors.Is(result.Error, gorm.ErrRecordNotFound) { return status.NewGroupNotFoundError(groupID) @@ -1074,7 +1076,7 @@ func (s *SqlStore) AddPeerToGroup(ctx context.Context, accountId string, peerId group.Peers = append(group.Peers, peerId) - if err := s.db.Save(&group).Error; err != nil { + if err := s.db.Clauses(clause.Locking{Strength: string(lockStrength)}).Save(&group).Error; err != nil { return status.Errorf(status.Internal, "issue updating group: %s", err) } @@ -1096,6 +1098,12 @@ func (s *SqlStore) GetAccountPeers(ctx context.Context, lockStrength LockingStre // GetUserPeers retrieves peers for a user. func (s *SqlStore) GetUserPeers(ctx context.Context, lockStrength LockingStrength, accountID, userID string) ([]*nbpeer.Peer, error) { var peers []*nbpeer.Peer + + // Exclude peers added via setup keys, as they are not user-specific and have an empty user_id. + if userID == "" { + return peers, nil + } + result := s.db.Clauses(clause.Locking{Strength: string(lockStrength)}). Find(&peers, "account_id = ? 
AND user_id = ?", accountID, userID) if err := result.Error; err != nil { @@ -1106,8 +1114,8 @@ func (s *SqlStore) GetUserPeers(ctx context.Context, lockStrength LockingStrengt return peers, nil } -func (s *SqlStore) AddPeerToAccount(ctx context.Context, peer *nbpeer.Peer) error { - if err := s.db.Create(peer).Error; err != nil { +func (s *SqlStore) AddPeerToAccount(ctx context.Context, lockStrength LockingStrength, peer *nbpeer.Peer) error { + if err := s.db.Clauses(clause.Locking{Strength: string(lockStrength)}).Create(peer).Error; err != nil { return status.Errorf(status.Internal, "issue adding peer to account: %s", err) } diff --git a/management/server/store.go b/management/server/store.go index 5b48de378..852ca6911 100644 --- a/management/server/store.go +++ b/management/server/store.go @@ -95,9 +95,9 @@ type Store interface { DeletePostureChecks(ctx context.Context, lockStrength LockingStrength, accountID, postureChecksID string) error GetPeerLabelsInAccount(ctx context.Context, lockStrength LockingStrength, accountId string) ([]string, error) - AddPeerToAllGroup(ctx context.Context, accountID string, peerID string) error - AddPeerToGroup(ctx context.Context, accountId string, peerId string, groupID string) error - AddPeerToAccount(ctx context.Context, peer *nbpeer.Peer) error + AddPeerToAllGroup(ctx context.Context, lockStrength LockingStrength, accountID string, peerID string) error + AddPeerToGroup(ctx context.Context, lockStrength LockingStrength, accountId string, peerId string, groupID string) error + AddPeerToAccount(ctx context.Context, lockStrength LockingStrength, peer *nbpeer.Peer) error GetPeerByPeerPubKey(ctx context.Context, lockStrength LockingStrength, peerKey string) (*nbpeer.Peer, error) GetAccountPeers(ctx context.Context, lockStrength LockingStrength, accountID string) ([]*nbpeer.Peer, error) GetUserPeers(ctx context.Context, lockStrength LockingStrength, accountID, userID string) ([]*nbpeer.Peer, error) From a22d5041e3287182252ce7b6d8ab6dc76a796c01 Mon Sep 17 00:00:00 2001 From: bcmmbaga Date: Thu, 28 Nov 2024 12:21:15 +0300 Subject: [PATCH 044/159] Add missing tests Signed-off-by: bcmmbaga --- management/server/sql_store_test.go | 313 +++++++++++++++++- .../server/testdata/store_policy_migrate.sql | 1 + 2 files changed, 310 insertions(+), 4 deletions(-) diff --git a/management/server/sql_store_test.go b/management/server/sql_store_test.go index c19eb1117..6ee161c79 100644 --- a/management/server/sql_store_test.go +++ b/management/server/sql_store_test.go @@ -1005,7 +1005,7 @@ func TestSqlite_GetTakenIPs(t *testing.T) { AccountID: existingAccountID, IP: net.IP{1, 1, 1, 1}, } - err = store.AddPeerToAccount(context.Background(), peer1) + err = store.AddPeerToAccount(context.Background(), LockingStrengthUpdate, peer1) require.NoError(t, err) takenIPs, err = store.GetTakenIPs(context.Background(), LockingStrengthShare, existingAccountID) @@ -1018,7 +1018,7 @@ func TestSqlite_GetTakenIPs(t *testing.T) { AccountID: existingAccountID, IP: net.IP{2, 2, 2, 2}, } - err = store.AddPeerToAccount(context.Background(), peer2) + err = store.AddPeerToAccount(context.Background(), LockingStrengthUpdate, peer2) require.NoError(t, err) takenIPs, err = store.GetTakenIPs(context.Background(), LockingStrengthShare, existingAccountID) @@ -1050,7 +1050,7 @@ func TestSqlite_GetPeerLabelsInAccount(t *testing.T) { AccountID: existingAccountID, DNSLabel: "peer1.domain.test", } - err = store.AddPeerToAccount(context.Background(), peer1) + err = store.AddPeerToAccount(context.Background(), 
LockingStrengthUpdate, peer1) require.NoError(t, err) labels, err = store.GetPeerLabelsInAccount(context.Background(), LockingStrengthShare, existingAccountID) @@ -1062,7 +1062,7 @@ func TestSqlite_GetPeerLabelsInAccount(t *testing.T) { AccountID: existingAccountID, DNSLabel: "peer2.domain.test", } - err = store.AddPeerToAccount(context.Background(), peer2) + err = store.AddPeerToAccount(context.Background(), LockingStrengthUpdate, peer2) require.NoError(t, err) labels, err = store.GetPeerLabelsInAccount(context.Background(), LockingStrengthShare, existingAccountID) @@ -2045,3 +2045,308 @@ func TestSqlStore_DeleteNameServerGroup(t *testing.T) { require.Error(t, err) require.Nil(t, nsGroup) } + +func TestSqlStore_AddPeerToGroup(t *testing.T) { + store, cleanup, err := NewTestStoreFromSQL(context.Background(), "testdata/store_policy_migrate.sql", t.TempDir()) + t.Cleanup(cleanup) + require.NoError(t, err) + + accountID := "bf1c8084-ba50-4ce7-9439-34653001fc3b" + peerID := "cfefqs706sqkneg59g4g" + groupID := "cfefqs706sqkneg59g4h" + + group, err := store.GetGroupByID(context.Background(), LockingStrengthShare, accountID, groupID) + require.NoError(t, err, "failed to get group") + require.Len(t, group.Peers, 0, "group should have 0 peers") + + err = store.AddPeerToGroup(context.Background(), LockingStrengthUpdate, accountID, peerID, groupID) + require.NoError(t, err, "failed to add peer to group") + + group, err = store.GetGroupByID(context.Background(), LockingStrengthShare, accountID, groupID) + require.NoError(t, err, "failed to get group") + require.Len(t, group.Peers, 1, "group should have 1 peers") + require.Contains(t, group.Peers, peerID) +} + +func TestSqlStore_AddPeerToAllGroup(t *testing.T) { + store, cleanup, err := NewTestStoreFromSQL(context.Background(), "testdata/store_policy_migrate.sql", t.TempDir()) + t.Cleanup(cleanup) + require.NoError(t, err) + + accountID := "bf1c8084-ba50-4ce7-9439-34653001fc3b" + groupID := "cfefqs706sqkneg59g3g" + + peer := &nbpeer.Peer{ + ID: "peer1", + AccountID: accountID, + DNSLabel: "peer1.domain.test", + } + + group, err := store.GetGroupByID(context.Background(), LockingStrengthShare, accountID, groupID) + require.NoError(t, err, "failed to get group") + require.Len(t, group.Peers, 2, "group should have 2 peers") + require.NotContains(t, group.Peers, peer.ID) + + err = store.AddPeerToAccount(context.Background(), LockingStrengthUpdate, peer) + require.NoError(t, err, "failed to add peer to account") + + err = store.AddPeerToAllGroup(context.Background(), LockingStrengthUpdate, accountID, peer.ID) + require.NoError(t, err, "failed to add peer to all group") + + group, err = store.GetGroupByID(context.Background(), LockingStrengthShare, accountID, groupID) + require.NoError(t, err, "failed to get group") + require.Len(t, group.Peers, 3, "group should have peers") + require.Contains(t, group.Peers, peer.ID) +} + +func TestSqlStore_AddPeerToAccount(t *testing.T) { + store, cleanup, err := NewTestStoreFromSQL(context.Background(), "testdata/store_policy_migrate.sql", t.TempDir()) + t.Cleanup(cleanup) + require.NoError(t, err) + + accountID := "bf1c8084-ba50-4ce7-9439-34653001fc3b" + + peer := &nbpeer.Peer{ + ID: "peer1", + AccountID: accountID, + Key: "key", + IP: net.IP{1, 1, 1, 1}, + Meta: nbpeer.PeerSystemMeta{ + Hostname: "hostname", + GoOS: "linux", + Kernel: "Linux", + Core: "21.04", + Platform: "x86_64", + OS: "Ubuntu", + WtVersion: "development", + UIVersion: "development", + }, + Name: "peer.test", + DNSLabel: "peer", + Status: 
&nbpeer.PeerStatus{ + LastSeen: time.Now().UTC(), + Connected: true, + LoginExpired: false, + RequiresApproval: false, + }, + SSHKey: "ssh-key", + SSHEnabled: false, + LoginExpirationEnabled: true, + InactivityExpirationEnabled: false, + LastLogin: time.Now().UTC(), + CreatedAt: time.Now().UTC(), + Ephemeral: true, + } + err = store.AddPeerToAccount(context.Background(), LockingStrengthUpdate, peer) + require.NoError(t, err, "failed to add peer to account") + + storedPeer, err := store.GetPeerByID(context.Background(), LockingStrengthShare, accountID, peer.ID) + require.NoError(t, err, "failed to get peer") + + assert.Equal(t, peer.ID, storedPeer.ID) + assert.Equal(t, peer.AccountID, storedPeer.AccountID) + assert.Equal(t, peer.Key, storedPeer.Key) + assert.Equal(t, peer.IP.String(), storedPeer.IP.String()) + assert.Equal(t, peer.Meta, storedPeer.Meta) + assert.Equal(t, peer.Name, storedPeer.Name) + assert.Equal(t, peer.DNSLabel, storedPeer.DNSLabel) + assert.Equal(t, peer.SSHKey, storedPeer.SSHKey) + assert.Equal(t, peer.SSHEnabled, storedPeer.SSHEnabled) + assert.Equal(t, peer.LoginExpirationEnabled, storedPeer.LoginExpirationEnabled) + assert.Equal(t, peer.InactivityExpirationEnabled, storedPeer.InactivityExpirationEnabled) + assert.WithinDurationf(t, peer.LastLogin, storedPeer.LastLogin.UTC(), time.Millisecond, "LastLogin should be equal") + assert.WithinDurationf(t, peer.CreatedAt, storedPeer.CreatedAt.UTC(), time.Millisecond, "CreatedAt should be equal") + assert.Equal(t, peer.Ephemeral, storedPeer.Ephemeral) + assert.Equal(t, peer.Status.Connected, storedPeer.Status.Connected) + assert.Equal(t, peer.Status.LoginExpired, storedPeer.Status.LoginExpired) + assert.Equal(t, peer.Status.RequiresApproval, storedPeer.Status.RequiresApproval) + assert.WithinDurationf(t, peer.Status.LastSeen, storedPeer.Status.LastSeen.UTC(), time.Millisecond, "LastSeen should be equal") +} + +func TestSqlStore_GetAccountPeers(t *testing.T) { + store, cleanup, err := NewTestStoreFromSQL(context.Background(), "testdata/store_with_expired_peers.sql", t.TempDir()) + t.Cleanup(cleanup) + require.NoError(t, err) + + tests := []struct { + name string + accountID string + expectedCount int + }{ + { + name: "should retrieve peers for an existing account ID", + accountID: "bf1c8084-ba50-4ce7-9439-34653001fc3b", + expectedCount: 4, + }, + { + name: "should return no peers for a non-existing account ID", + accountID: "nonexistent", + expectedCount: 0, + }, + { + name: "should return no peers for an empty account ID", + accountID: "", + expectedCount: 0, + }, + } + + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + peers, err := store.GetAccountPeers(context.Background(), LockingStrengthShare, tt.accountID) + require.NoError(t, err) + require.Len(t, peers, tt.expectedCount) + }) + } + +} + +func TestSqlStore_GetAccountPeersWithExpiration(t *testing.T) { + store, cleanup, err := NewTestStoreFromSQL(context.Background(), "testdata/store_with_expired_peers.sql", t.TempDir()) + t.Cleanup(cleanup) + require.NoError(t, err) + + tests := []struct { + name string + accountID string + expectedCount int + }{ + { + name: "should retrieve peers with expiration for an existing account ID", + accountID: "bf1c8084-ba50-4ce7-9439-34653001fc3b", + expectedCount: 1, + }, + { + name: "should return no peers with expiration for a non-existing account ID", + accountID: "nonexistent", + expectedCount: 0, + }, + { + name: "should return no peers with expiration for a empty account ID", + accountID: "", + expectedCount: 0, + 
}, + } + + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + peers, err := store.GetAccountPeersWithExpiration(context.Background(), LockingStrengthShare, tt.accountID) + require.NoError(t, err) + require.Len(t, peers, tt.expectedCount) + }) + } +} + +func TestSqlStore_GetAccountPeersWithInactivity(t *testing.T) { + store, cleanup, err := NewTestStoreFromSQL(context.Background(), "testdata/store_with_expired_peers.sql", t.TempDir()) + t.Cleanup(cleanup) + require.NoError(t, err) + + tests := []struct { + name string + accountID string + expectedCount int + }{ + { + name: "should retrieve peers with inactivity for an existing account ID", + accountID: "bf1c8084-ba50-4ce7-9439-34653001fc3b", + expectedCount: 1, + }, + { + name: "should return no peers with inactivity for a non-existing account ID", + accountID: "nonexistent", + expectedCount: 0, + }, + { + name: "should return no peers with inactivity for an empty account ID", + accountID: "", + expectedCount: 0, + }, + } + + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + peers, err := store.GetAccountPeersWithInactivity(context.Background(), LockingStrengthShare, tt.accountID) + require.NoError(t, err) + require.Len(t, peers, tt.expectedCount) + }) + } +} + +func TestSqlStore_GetAllEphemeralPeers(t *testing.T) { + store, cleanup, err := NewTestStoreFromSQL(context.Background(), "testdata/storev1.sql", t.TempDir()) + t.Cleanup(cleanup) + require.NoError(t, err) + + peers, err := store.GetAllEphemeralPeers(context.Background(), LockingStrengthShare) + require.NoError(t, err) + require.Len(t, peers, 1) + require.True(t, peers[0].Ephemeral) +} + +func TestSqlStore_GetUserPeers(t *testing.T) { + store, cleanup, err := NewTestStoreFromSQL(context.Background(), "testdata/store_with_expired_peers.sql", t.TempDir()) + t.Cleanup(cleanup) + require.NoError(t, err) + + tests := []struct { + name string + accountID string + userID string + expectedCount int + }{ + { + name: "should retrieve peers for existing account ID and user ID", + accountID: "bf1c8084-ba50-4ce7-9439-34653001fc3b", + userID: "f4f6d672-63fb-11ec-90d6-0242ac120003", + expectedCount: 1, + }, + { + name: "should return no peers for non-existing account ID with existing user ID", + accountID: "nonexistent", + userID: "f4f6d672-63fb-11ec-90d6-0242ac120003", + expectedCount: 0, + }, + { + name: "should return no peers for non-existing user ID with existing account ID", + accountID: "bf1c8084-ba50-4ce7-9439-34653001fc3b", + userID: "nonexistent_user", + expectedCount: 0, + }, + { + name: "should retrieve peers for another valid account ID and user ID", + accountID: "bf1c8084-ba50-4ce7-9439-34653001fc3b", + userID: "edafee4e-63fb-11ec-90d6-0242ac120003", + expectedCount: 2, + }, + { + name: "should return no peers for existing account ID with empty user ID", + accountID: "bf1c8084-ba50-4ce7-9439-34653001fc3b", + userID: "", + expectedCount: 0, + }, + } + + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + peers, err := store.GetUserPeers(context.Background(), LockingStrengthShare, tt.accountID, tt.userID) + require.NoError(t, err) + require.Len(t, peers, tt.expectedCount) + }) + } +} + +func TestSqlStore_DeletePeer(t *testing.T) { + store, cleanup, err := NewTestStoreFromSQL(context.Background(), "testdata/store_with_expired_peers.sql", t.TempDir()) + t.Cleanup(cleanup) + require.NoError(t, err) + + accountID := "bf1c8084-ba50-4ce7-9439-34653001fc3b" + peerID := "csrnkiq7qv9d8aitqd50" + + err = store.DeletePeer(context.Background(), 
LockingStrengthUpdate, accountID, peerID) + require.NoError(t, err) + + peer, err := store.GetPeerByID(context.Background(), LockingStrengthShare, accountID, peerID) + require.Error(t, err) + require.Nil(t, peer) +} diff --git a/management/server/testdata/store_policy_migrate.sql b/management/server/testdata/store_policy_migrate.sql index a9360e9d6..15917f391 100644 --- a/management/server/testdata/store_policy_migrate.sql +++ b/management/server/testdata/store_policy_migrate.sql @@ -32,4 +32,5 @@ INSERT INTO peers VALUES('cfeg6sf06sqkneg59g50','bf1c8084-ba50-4ce7-9439-3465300 INSERT INTO users VALUES('edafee4e-63fb-11ec-90d6-0242ac120003','bf1c8084-ba50-4ce7-9439-34653001fc3b','admin',0,0,'','[]',0,'0001-01-01 00:00:00+00:00','2024-10-02 16:04:23.539152+02:00','api',0,''); INSERT INTO users VALUES('f4f6d672-63fb-11ec-90d6-0242ac120003','bf1c8084-ba50-4ce7-9439-34653001fc3b','user',0,0,'','[]',0,'0001-01-01 00:00:00+00:00','2024-10-02 16:04:23.539152+02:00','api',0,''); INSERT INTO "groups" VALUES('cfefqs706sqkneg59g3g','bf1c8084-ba50-4ce7-9439-34653001fc3b','All','api','["cfefqs706sqkneg59g4g","cfeg6sf06sqkneg59g50"]',0,''); +INSERT INTO "groups" VALUES('cfefqs706sqkneg59g4h','bf1c8084-ba50-4ce7-9439-34653001fc3b','groupA','api','',0,''); INSERT INTO installations VALUES(1,''); From cde0e51c720883fa308c050dabfd060557c992eb Mon Sep 17 00:00:00 2001 From: bcmmbaga Date: Thu, 28 Nov 2024 12:30:38 +0300 Subject: [PATCH 045/159] Refactor test names and remove duplicate TestPostgresql_SavePeerStatus Signed-off-by: bcmmbaga --- management/server/sql_store_test.go | 47 ++--------------------------- 1 file changed, 3 insertions(+), 44 deletions(-) diff --git a/management/server/sql_store_test.go b/management/server/sql_store_test.go index 6ee161c79..eddd628bd 100644 --- a/management/server/sql_store_test.go +++ b/management/server/sql_store_test.go @@ -377,7 +377,7 @@ func TestSqlite_GetAccount(t *testing.T) { require.Equal(t, status.NotFound, parsedErr.Type(), "should return not found error") } -func TestSqlite_SavePeer(t *testing.T) { +func TestSqlStore_SavePeer(t *testing.T) { store, cleanUp, err := NewTestStoreFromSQL(context.Background(), "testdata/store.sql", t.TempDir()) t.Cleanup(cleanUp) assert.NoError(t, err) @@ -425,7 +425,7 @@ func TestSqlite_SavePeer(t *testing.T) { assert.WithinDurationf(t, updatedPeer.Status.LastSeen, actual.Status.LastSeen.UTC(), time.Millisecond, "LastSeen should be equal") } -func TestSqlite_SavePeerStatus(t *testing.T) { +func TestSqlStore_SavePeerStatus(t *testing.T) { store, cleanUp, err := NewTestStoreFromSQL(context.Background(), "testdata/store.sql", t.TempDir()) t.Cleanup(cleanUp) assert.NoError(t, err) @@ -481,7 +481,7 @@ func TestSqlite_SavePeerStatus(t *testing.T) { assert.WithinDurationf(t, newStatus.LastSeen, actual.LastSeen.UTC(), time.Millisecond, "LastSeen should be equal") } -func TestSqlite_SavePeerLocation(t *testing.T) { +func TestSqlStore_SavePeerLocation(t *testing.T) { store, cleanUp, err := NewTestStoreFromSQL(context.Background(), "testdata/store.sql", t.TempDir()) t.Cleanup(cleanUp) assert.NoError(t, err) @@ -887,47 +887,6 @@ func TestPostgresql_DeleteAccount(t *testing.T) { } -func TestPostgresql_SavePeerStatus(t *testing.T) { - if (os.Getenv("CI") == "true" && runtime.GOOS == "darwin") || runtime.GOOS == "windows" { - t.Skip("skip CI tests on darwin and windows") - } - - t.Setenv("NETBIRD_STORE_ENGINE", string(PostgresStoreEngine)) - store, cleanUp, err := NewTestStoreFromSQL(context.Background(), "testdata/store.sql", t.TempDir()) - 
t.Cleanup(cleanUp) - assert.NoError(t, err) - - account, err := store.GetAccount(context.Background(), "bf1c8084-ba50-4ce7-9439-34653001fc3b") - require.NoError(t, err) - - // save status of non-existing peer - newStatus := nbpeer.PeerStatus{Connected: true, LastSeen: time.Now().UTC()} - err = store.SavePeerStatus(context.Background(), LockingStrengthUpdate, account.Id, "non-existing-peer", newStatus) - assert.Error(t, err) - - // save new status of existing peer - account.Peers["testpeer"] = &nbpeer.Peer{ - Key: "peerkey", - ID: "testpeer", - IP: net.IP{127, 0, 0, 1}, - Meta: nbpeer.PeerSystemMeta{}, - Name: "peer name", - Status: &nbpeer.PeerStatus{Connected: false, LastSeen: time.Now().UTC()}, - } - - err = store.SaveAccount(context.Background(), account) - require.NoError(t, err) - - err = store.SavePeerStatus(context.Background(), LockingStrengthUpdate, account.Id, "testpeer", newStatus) - require.NoError(t, err) - - account, err = store.GetAccount(context.Background(), account.Id) - require.NoError(t, err) - - actual := account.Peers["testpeer"].Status - assert.Equal(t, newStatus.Connected, actual.Connected) -} - func TestPostgresql_TestGetAccountByPrivateDomain(t *testing.T) { if (os.Getenv("CI") == "true" && runtime.GOOS == "darwin") || runtime.GOOS == "windows" { t.Skip("skip CI tests on darwin and windows") From 00c3b671821ec9be52d412948b3fc48dae4b341c Mon Sep 17 00:00:00 2001 From: Pascal Fischer <32096965+pascal-fischer@users.noreply.github.com> Date: Thu, 28 Nov 2024 11:13:01 +0100 Subject: [PATCH 046/159] [management] refactor to use account object instead of separate db calls for peer update (#2957) --- management/server/peer.go | 8 ++-- management/server/peer_test.go | 33 ++++++++++------ management/server/posture_checks.go | 61 +++++++++++------------------ 3 files changed, 48 insertions(+), 54 deletions(-) diff --git a/management/server/peer.go b/management/server/peer.go index dcb47af3b..d45bb1a50 100644 --- a/management/server/peer.go +++ b/management/server/peer.go @@ -617,7 +617,7 @@ func (am *DefaultAccountManager) AddPeer(ctx context.Context, setupKey, userID s return nil, nil, nil, err } - postureChecks, err := am.getPeerPostureChecks(ctx, account.Id, newPeer.ID) + postureChecks, err := am.getPeerPostureChecks(account, newPeer.ID) if err != nil { return nil, nil, nil, err } @@ -707,7 +707,7 @@ func (am *DefaultAccountManager) SyncPeer(ctx context.Context, sync PeerSync, ac return nil, nil, nil, fmt.Errorf("failed to get validated peers: %w", err) } - postureChecks, err = am.getPeerPostureChecks(ctx, account.Id, peer.ID) + postureChecks, err = am.getPeerPostureChecks(account, peer.ID) if err != nil { return nil, nil, nil, err } @@ -885,7 +885,7 @@ func (am *DefaultAccountManager) getValidatedPeerWithMap(ctx context.Context, is return nil, nil, nil, err } - postureChecks, err = am.getPeerPostureChecks(ctx, account.Id, peer.ID) + postureChecks, err = am.getPeerPostureChecks(account, peer.ID) if err != nil { return nil, nil, nil, err } @@ -1042,7 +1042,7 @@ func (am *DefaultAccountManager) updateAccountPeers(ctx context.Context, account defer wg.Done() defer func() { <-semaphore }() - postureChecks, err := am.getPeerPostureChecks(ctx, account.Id, p.ID) + postureChecks, err := am.getPeerPostureChecks(account, p.ID) if err != nil { log.WithContext(ctx).Errorf("failed to send out updates to peers, failed to get peer: %s posture checks: %v", p.ID, err) return diff --git a/management/server/peer_test.go b/management/server/peer_test.go index e410fa892..e5eaa20d6 100644 --- 
a/management/server/peer_test.go +++ b/management/server/peer_test.go @@ -833,19 +833,20 @@ func BenchmarkGetPeers(b *testing.B) { }) } } - func BenchmarkUpdateAccountPeers(b *testing.B) { benchCases := []struct { - name string - peers int - groups int + name string + peers int + groups int + minMsPerOp float64 + maxMsPerOp float64 }{ - {"Small", 50, 5}, - {"Medium", 500, 10}, - {"Large", 5000, 20}, - {"Small single", 50, 1}, - {"Medium single", 500, 1}, - {"Large 5", 5000, 5}, + {"Small", 50, 5, 90, 120}, + {"Medium", 500, 100, 110, 140}, + {"Large", 5000, 200, 800, 1300}, + {"Small single", 50, 10, 90, 120}, + {"Medium single", 500, 10, 110, 170}, + {"Large 5", 5000, 15, 1300, 1800}, } log.SetOutput(io.Discard) @@ -881,8 +882,16 @@ func BenchmarkUpdateAccountPeers(b *testing.B) { } duration := time.Since(start) - b.ReportMetric(float64(duration.Nanoseconds())/float64(b.N)/1e6, "ms/op") - b.ReportMetric(0, "ns/op") + msPerOp := float64(duration.Nanoseconds()) / float64(b.N) / 1e6 + b.ReportMetric(msPerOp, "ms/op") + + if msPerOp < bc.minMsPerOp { + b.Fatalf("Benchmark %s failed: too fast (%.2f ms/op, minimum %.2f ms/op)", bc.name, msPerOp, bc.minMsPerOp) + } + + if msPerOp > bc.maxMsPerOp { + b.Fatalf("Benchmark %s failed: too slow (%.2f ms/op, maximum %.2f ms/op)", bc.name, msPerOp, bc.maxMsPerOp) + } }) } } diff --git a/management/server/posture_checks.go b/management/server/posture_checks.go index 59e726c41..0467efedb 100644 --- a/management/server/posture_checks.go +++ b/management/server/posture_checks.go @@ -2,14 +2,16 @@ package server import ( "context" + "errors" "fmt" "slices" + "github.com/rs/xid" + "golang.org/x/exp/maps" + "github.com/netbirdio/netbird/management/server/activity" "github.com/netbirdio/netbird/management/server/posture" "github.com/netbirdio/netbird/management/server/status" - "github.com/rs/xid" - "golang.org/x/exp/maps" ) func (am *DefaultAccountManager) GetPostureChecks(ctx context.Context, accountID, postureChecksID, userID string) (*posture.Checks, error) { @@ -149,38 +151,21 @@ func (am *DefaultAccountManager) ListPostureChecks(ctx context.Context, accountI } // getPeerPostureChecks returns the posture checks applied for a given peer. 
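// After this refactor the checks are resolved from the in-memory Account that callers such as
// AddPeer, SyncPeer and updateAccountPeers already hold (its Policies, Groups and PostureChecks),
// instead of opening a store transaction and issuing separate queries for every peer.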
-func (am *DefaultAccountManager) getPeerPostureChecks(ctx context.Context, accountID string, peerID string) ([]*posture.Checks, error) { +func (am *DefaultAccountManager) getPeerPostureChecks(account *Account, peerID string) ([]*posture.Checks, error) { peerPostureChecks := make(map[string]*posture.Checks) - err := am.Store.ExecuteInTransaction(ctx, func(transaction Store) error { - postureChecks, err := transaction.GetAccountPostureChecks(ctx, LockingStrengthShare, accountID) - if err != nil { - return err + if len(account.PostureChecks) == 0 { + return nil, nil + } + + for _, policy := range account.Policies { + if !policy.Enabled || len(policy.SourcePostureChecks) == 0 { + continue } - if len(postureChecks) == 0 { - return nil + if err := addPolicyPostureChecks(account, peerID, policy, peerPostureChecks); err != nil { + return nil, err } - - policies, err := transaction.GetAccountPolicies(ctx, LockingStrengthShare, accountID) - if err != nil { - return err - } - - for _, policy := range policies { - if !policy.Enabled { - continue - } - - if err = addPolicyPostureChecks(ctx, transaction, accountID, peerID, policy, peerPostureChecks); err != nil { - return err - } - } - - return nil - }) - if err != nil { - return nil, err } return maps.Values(peerPostureChecks), nil @@ -241,8 +226,8 @@ func validatePostureChecks(ctx context.Context, transaction Store, accountID str } // addPolicyPostureChecks adds posture checks from a policy to the peer posture checks map if the peer is in the policy's source groups. -func addPolicyPostureChecks(ctx context.Context, transaction Store, accountID, peerID string, policy *Policy, peerPostureChecks map[string]*posture.Checks) error { - isInGroup, err := isPeerInPolicySourceGroups(ctx, transaction, accountID, peerID, policy) +func addPolicyPostureChecks(account *Account, peerID string, policy *Policy, peerPostureChecks map[string]*posture.Checks) error { + isInGroup, err := isPeerInPolicySourceGroups(account, peerID, policy) if err != nil { return err } @@ -252,9 +237,9 @@ func addPolicyPostureChecks(ctx context.Context, transaction Store, accountID, p } for _, sourcePostureCheckID := range policy.SourcePostureChecks { - postureCheck, err := transaction.GetPostureChecksByID(ctx, LockingStrengthShare, accountID, sourcePostureCheckID) - if err != nil { - return err + postureCheck := account.getPostureChecks(sourcePostureCheckID) + if postureCheck == nil { + return errors.New("failed to add policy posture checks: posture checks not found") } peerPostureChecks[sourcePostureCheckID] = postureCheck } @@ -263,16 +248,16 @@ func addPolicyPostureChecks(ctx context.Context, transaction Store, accountID, p } // isPeerInPolicySourceGroups checks if a peer is present in any of the policy rule source groups. 
-func isPeerInPolicySourceGroups(ctx context.Context, transaction Store, accountID, peerID string, policy *Policy) (bool, error) { +func isPeerInPolicySourceGroups(account *Account, peerID string, policy *Policy) (bool, error) { for _, rule := range policy.Rules { if !rule.Enabled { continue } for _, sourceGroup := range rule.Sources { - group, err := transaction.GetGroupByID(ctx, LockingStrengthShare, accountID, sourceGroup) - if err != nil { - return false, fmt.Errorf("failed to check peer in policy source group: %w", err) + group := account.GetGroup(sourceGroup) + if group == nil { + return false, fmt.Errorf("failed to check peer in policy source group: group not found") } if slices.Contains(group.Peers, peerID) { From f87bc601c60562a7a0e05627c60d4320860f5b7a Mon Sep 17 00:00:00 2001 From: bcmmbaga Date: Thu, 28 Nov 2024 14:03:08 +0300 Subject: [PATCH 047/159] Add account locks and remove redundant ephemeral check Signed-off-by: bcmmbaga --- management/server/account.go | 8 ++++++-- management/server/ephemeral.go | 8 ++------ 2 files changed, 8 insertions(+), 8 deletions(-) diff --git a/management/server/account.go b/management/server/account.go index f42569542..d25fe5b79 100644 --- a/management/server/account.go +++ b/management/server/account.go @@ -1244,10 +1244,11 @@ func (am *DefaultAccountManager) handleInactivityExpirationSettings(ctx context. return nil } - - func (am *DefaultAccountManager) peerLoginExpirationJob(ctx context.Context, accountID string) func() (time.Duration, bool) { return func() (time.Duration, bool) { + unlock := am.Store.AcquireWriteLockByUID(ctx, accountID) + defer unlock() + expiredPeers, err := am.getExpiredPeers(ctx, accountID) if err != nil { return 0, false @@ -1279,6 +1280,9 @@ func (am *DefaultAccountManager) checkAndSchedulePeerLoginExpiration(ctx context // peerInactivityExpirationJob marks login expired for all inactive peers and returns the minimum duration in which the next peer of the account will expire by inactivity if found func (am *DefaultAccountManager) peerInactivityExpirationJob(ctx context.Context, accountID string) func() (time.Duration, bool) { return func() (time.Duration, bool) { + unlock := am.Store.AcquireWriteLockByUID(ctx, accountID) + defer unlock() + inactivePeers, err := am.getInactivePeers(ctx, accountID) if err != nil { log.WithContext(ctx).Errorf("failed getting inactive peers for account %s", accountID) diff --git a/management/server/ephemeral.go b/management/server/ephemeral.go index 6e245ec5a..111d5e3fc 100644 --- a/management/server/ephemeral.go +++ b/management/server/ephemeral.go @@ -127,15 +127,11 @@ func (e *EphemeralManager) loadEphemeralPeers(ctx context.Context) { } t := newDeadLine() - count := 0 for _, p := range peers { - if p.Ephemeral { - count++ - e.addPeer(p.AccountID, p.ID, t) - } + e.addPeer(p.AccountID, p.ID, t) } - log.WithContext(ctx).Debugf("loaded ephemeral peer(s): %d", count) + log.WithContext(ctx).Debugf("loaded ephemeral peer(s): %d", len(peers)) } func (e *EphemeralManager) cleanup(ctx context.Context) { From 1ba6eb62a607fe3da9bf09abd79fe2cd1e3ec142 Mon Sep 17 00:00:00 2001 From: bcmmbaga Date: Thu, 28 Nov 2024 15:01:44 +0300 Subject: [PATCH 048/159] Retrieve all groups for peers and restrict groups for regular users Signed-off-by: bcmmbaga --- management/server/http/peers_handler.go | 47 ++++++++++++++----------- 1 file changed, 27 insertions(+), 20 deletions(-) diff --git a/management/server/http/peers_handler.go b/management/server/http/peers_handler.go index 235e744b3..4d0bdec2d 100644 --- 
a/management/server/http/peers_handler.go +++ b/management/server/http/peers_handler.go @@ -62,12 +62,8 @@ func (h *PeersHandler) getPeer(ctx context.Context, accountID, peerID, userID st } dnsDomain := h.accountManager.GetDNSDomain() - peerGroups, err := h.accountManager.GetPeerGroups(ctx, accountID, peer.ID) - if err != nil { - util.WriteError(ctx, err, w) - return - } - groupsInfo := toGroupsInfo(peerGroups) + groups, _ := h.accountManager.GetAllGroups(ctx, accountID, userID) + groupsInfo := toGroupsInfo(groups, peerID) validPeers, err := h.accountManager.GetValidatedPeers(ctx, accountID) if err != nil { @@ -116,7 +112,7 @@ func (h *PeersHandler) updatePeer(ctx context.Context, accountID, userID, peerID util.WriteError(ctx, err, w) return } - groupMinimumInfo := toGroupsInfo(peerGroups) + groupMinimumInfo := toGroupsInfo(peerGroups, peer.ID) validPeers, err := h.accountManager.GetValidatedPeers(ctx, accountID) if err != nil { @@ -187,6 +183,8 @@ func (h *PeersHandler) GetAllPeers(w http.ResponseWriter, r *http.Request) { dnsDomain := h.accountManager.GetDNSDomain() + groups, _ := h.accountManager.GetAllGroups(r.Context(), accountID, userID) + respBody := make([]*api.PeerBatch, 0, len(peers)) for _, peer := range peers { peerToReturn, err := h.checkPeerStatus(peer) @@ -195,12 +193,7 @@ func (h *PeersHandler) GetAllPeers(w http.ResponseWriter, r *http.Request) { return } - peerGroups, err := h.accountManager.GetPeerGroups(r.Context(), accountID, peer.ID) - if err != nil { - util.WriteError(r.Context(), err, w) - return - } - groupMinimumInfo := toGroupsInfo(peerGroups) + groupMinimumInfo := toGroupsInfo(groups, peer.ID) respBody = append(respBody, toPeerListItemResponse(peerToReturn, groupMinimumInfo, dnsDomain, 0)) } @@ -312,14 +305,28 @@ func peerToAccessiblePeer(peer *nbpeer.Peer, dnsDomain string) api.AccessiblePee } } -func toGroupsInfo(groups []*nbgroup.Group) []api.GroupMinimum { - groupsInfo := make([]api.GroupMinimum, 0, len(groups)) +func toGroupsInfo(groups []*nbgroup.Group, peerID string) []api.GroupMinimum { + groupsInfo := []api.GroupMinimum{} + groupsChecked := make(map[string]struct{}) + for _, group := range groups { - groupsInfo = append(groupsInfo, api.GroupMinimum{ - Id: group.ID, - Name: group.Name, - PeersCount: len(group.Peers), - }) + _, ok := groupsChecked[group.ID] + if ok { + continue + } + + groupsChecked[group.ID] = struct{}{} + for _, pk := range group.Peers { + if pk == peerID { + info := api.GroupMinimum{ + Id: group.ID, + Name: group.Name, + PeersCount: len(group.Peers), + } + groupsInfo = append(groupsInfo, info) + break + } + } } return groupsInfo } From d66140fc82af4a55c522df851876832c931807e9 Mon Sep 17 00:00:00 2001 From: bcmmbaga Date: Thu, 28 Nov 2024 15:08:42 +0300 Subject: [PATCH 049/159] Fix merge Signed-off-by: bcmmbaga --- management/server/peer.go | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/management/server/peer.go b/management/server/peer.go index 0547cbbbc..9360ce29f 100644 --- a/management/server/peer.go +++ b/management/server/peer.go @@ -880,7 +880,7 @@ func (am *DefaultAccountManager) getValidatedPeerWithMap(ctx context.Context, is return nil, nil, nil, err } - postureChecks, err = am.getPeerPostureChecks(account, peer.ID) + postureChecks, err := am.getPeerPostureChecks(account, peer.ID) if err != nil { return nil, nil, nil, err } From 89cf8a55e2e11f0d75949523c064a8be558506c4 Mon Sep 17 00:00:00 2001 From: Pascal Fischer <32096965+pascal-fischer@users.noreply.github.com> Date: Thu, 28 Nov 2024 14:59:53 +0100 
Subject: [PATCH 050/159] [management] Add performance test for login and sync calls (#2960) --- management/server/account_test.go | 195 +++++++++++++++++++++++++++++- 1 file changed, 192 insertions(+), 3 deletions(-) diff --git a/management/server/account_test.go b/management/server/account_test.go index c8c2d5941..dbabfd366 100644 --- a/management/server/account_test.go +++ b/management/server/account_test.go @@ -6,13 +6,17 @@ import ( b64 "encoding/base64" "encoding/json" "fmt" + "io" "net" + "os" "reflect" + "strconv" "sync" "testing" "time" "github.com/golang-jwt/jwt" + log "github.com/sirupsen/logrus" "github.com/stretchr/testify/assert" "github.com/stretchr/testify/require" "golang.zx2c4.com/wireguard/wgctrl/wgtypes" @@ -1038,7 +1042,7 @@ func BenchmarkTest_GetAccountWithclaims(b *testing.B) { } b.Run("public without account ID", func(b *testing.B) { - //b.ResetTimer() + // b.ResetTimer() for i := 0; i < b.N; i++ { _, err := am.getAccountIDWithAuthorizationClaims(context.Background(), publicClaims) if err != nil { @@ -1048,7 +1052,7 @@ func BenchmarkTest_GetAccountWithclaims(b *testing.B) { }) b.Run("private without account ID", func(b *testing.B) { - //b.ResetTimer() + // b.ResetTimer() for i := 0; i < b.N; i++ { _, err := am.getAccountIDWithAuthorizationClaims(context.Background(), claims) if err != nil { @@ -1059,7 +1063,7 @@ func BenchmarkTest_GetAccountWithclaims(b *testing.B) { b.Run("private with account ID", func(b *testing.B) { claims.AccountId = id - //b.ResetTimer() + // b.ResetTimer() for i := 0; i < b.N; i++ { _, err := am.getAccountIDWithAuthorizationClaims(context.Background(), claims) if err != nil { @@ -2982,3 +2986,188 @@ func peerShouldReceiveUpdate(t *testing.T, updateMessage <-chan *UpdateMessage) t.Error("Timed out waiting for update message") } } + +func BenchmarkSyncAndMarkPeer(b *testing.B) { + benchCases := []struct { + name string + peers int + groups int + minMsPerOp float64 + maxMsPerOp float64 + }{ + {"Small", 50, 5, 1, 3}, + {"Medium", 500, 100, 7, 13}, + {"Large", 5000, 200, 65, 80}, + {"Small single", 50, 10, 1, 3}, + {"Medium single", 500, 10, 7, 13}, + {"Large 5", 5000, 15, 65, 80}, + } + + log.SetOutput(io.Discard) + defer log.SetOutput(os.Stderr) + for _, bc := range benchCases { + b.Run(bc.name, func(b *testing.B) { + manager, accountID, _, err := setupTestAccountManager(b, bc.peers, bc.groups) + if err != nil { + b.Fatalf("Failed to setup test account manager: %v", err) + } + ctx := context.Background() + account, err := manager.Store.GetAccount(ctx, accountID) + if err != nil { + b.Fatalf("Failed to get account: %v", err) + } + peerChannels := make(map[string]chan *UpdateMessage) + for peerID := range account.Peers { + peerChannels[peerID] = make(chan *UpdateMessage, channelBufferSize) + } + manager.peersUpdateManager.peerChannels = peerChannels + + b.ResetTimer() + start := time.Now() + for i := 0; i < b.N; i++ { + _, _, _, err := manager.SyncAndMarkPeer(context.Background(), account.Id, account.Peers["peer-1"].Key, nbpeer.PeerSystemMeta{Hostname: strconv.Itoa(i)}, net.IP{1, 1, 1, 1}) + assert.NoError(b, err) + } + + duration := time.Since(start) + msPerOp := float64(duration.Nanoseconds()) / float64(b.N) / 1e6 + b.ReportMetric(msPerOp, "ms/op") + + if msPerOp < bc.minMsPerOp { + b.Fatalf("Benchmark %s failed: too fast (%.2f ms/op, minimum %.2f ms/op)", bc.name, msPerOp, bc.minMsPerOp) + } + + if msPerOp > bc.maxMsPerOp { + b.Fatalf("Benchmark %s failed: too slow (%.2f ms/op, maximum %.2f ms/op)", bc.name, msPerOp, bc.maxMsPerOp) + } + }) + } 
+} + +func BenchmarkLoginPeer_ExistingPeer(b *testing.B) { + benchCases := []struct { + name string + peers int + groups int + minMsPerOp float64 + maxMsPerOp float64 + }{ + {"Small", 50, 5, 102, 110}, + {"Medium", 500, 100, 105, 140}, + {"Large", 5000, 200, 160, 200}, + {"Small single", 50, 10, 102, 110}, + {"Medium single", 500, 10, 105, 140}, + {"Large 5", 5000, 15, 160, 200}, + } + + log.SetOutput(io.Discard) + defer log.SetOutput(os.Stderr) + for _, bc := range benchCases { + b.Run(bc.name, func(b *testing.B) { + manager, accountID, _, err := setupTestAccountManager(b, bc.peers, bc.groups) + if err != nil { + b.Fatalf("Failed to setup test account manager: %v", err) + } + ctx := context.Background() + account, err := manager.Store.GetAccount(ctx, accountID) + if err != nil { + b.Fatalf("Failed to get account: %v", err) + } + peerChannels := make(map[string]chan *UpdateMessage) + for peerID := range account.Peers { + peerChannels[peerID] = make(chan *UpdateMessage, channelBufferSize) + } + manager.peersUpdateManager.peerChannels = peerChannels + + b.ResetTimer() + start := time.Now() + for i := 0; i < b.N; i++ { + _, _, _, err := manager.LoginPeer(context.Background(), PeerLogin{ + WireGuardPubKey: account.Peers["peer-1"].Key, + SSHKey: "someKey", + Meta: nbpeer.PeerSystemMeta{Hostname: strconv.Itoa(i)}, + UserID: "regular_user", + SetupKey: "", + ConnectionIP: net.IP{1, 1, 1, 1}, + }) + assert.NoError(b, err) + } + + duration := time.Since(start) + msPerOp := float64(duration.Nanoseconds()) / float64(b.N) / 1e6 + b.ReportMetric(msPerOp, "ms/op") + + if msPerOp < bc.minMsPerOp { + b.Fatalf("Benchmark %s failed: too fast (%.2f ms/op, minimum %.2f ms/op)", bc.name, msPerOp, bc.minMsPerOp) + } + + if msPerOp > bc.maxMsPerOp { + b.Fatalf("Benchmark %s failed: too slow (%.2f ms/op, maximum %.2f ms/op)", bc.name, msPerOp, bc.maxMsPerOp) + } + }) + } +} + +func BenchmarkLoginPeer_NewPeer(b *testing.B) { + benchCases := []struct { + name string + peers int + groups int + minMsPerOp float64 + maxMsPerOp float64 + }{ + {"Small", 50, 5, 107, 120}, + {"Medium", 500, 100, 105, 140}, + {"Large", 5000, 200, 180, 220}, + {"Small single", 50, 10, 107, 120}, + {"Medium single", 500, 10, 105, 140}, + {"Large 5", 5000, 15, 180, 220}, + } + + log.SetOutput(io.Discard) + defer log.SetOutput(os.Stderr) + for _, bc := range benchCases { + b.Run(bc.name, func(b *testing.B) { + manager, accountID, _, err := setupTestAccountManager(b, bc.peers, bc.groups) + if err != nil { + b.Fatalf("Failed to setup test account manager: %v", err) + } + ctx := context.Background() + account, err := manager.Store.GetAccount(ctx, accountID) + if err != nil { + b.Fatalf("Failed to get account: %v", err) + } + peerChannels := make(map[string]chan *UpdateMessage) + for peerID := range account.Peers { + peerChannels[peerID] = make(chan *UpdateMessage, channelBufferSize) + } + manager.peersUpdateManager.peerChannels = peerChannels + + b.ResetTimer() + start := time.Now() + for i := 0; i < b.N; i++ { + _, _, _, err := manager.LoginPeer(context.Background(), PeerLogin{ + WireGuardPubKey: "some-new-key" + strconv.Itoa(i), + SSHKey: "someKey", + Meta: nbpeer.PeerSystemMeta{Hostname: strconv.Itoa(i)}, + UserID: "regular_user", + SetupKey: "", + ConnectionIP: net.IP{1, 1, 1, 1}, + }) + assert.NoError(b, err) + } + + duration := time.Since(start) + msPerOp := float64(duration.Nanoseconds()) / float64(b.N) / 1e6 + b.ReportMetric(msPerOp, "ms/op") + + if msPerOp < bc.minMsPerOp { + b.Fatalf("Benchmark %s failed: too fast (%.2f ms/op, minimum 
%.2f ms/op)", bc.name, msPerOp, bc.minMsPerOp) + } + + if msPerOp > bc.maxMsPerOp { + b.Fatalf("Benchmark %s failed: too slow (%.2f ms/op, maximum %.2f ms/op)", bc.name, msPerOp, bc.maxMsPerOp) + } + }) + } +} From c6641be94b3c120fd49513fa19c4aeabe273a3f5 Mon Sep 17 00:00:00 2001 From: Pascal Fischer <32096965+pascal-fischer@users.noreply.github.com> Date: Thu, 28 Nov 2024 19:22:01 +0100 Subject: [PATCH 051/159] [tests] Enable benchmark tests on github actions (#2961) --- .github/workflows/golang-test-linux.yml | 41 ++++++++ management/server/account_test.go | 120 +++++++++++++++--------- management/server/peer_test.go | 40 +++++--- 3 files changed, 141 insertions(+), 60 deletions(-) diff --git a/.github/workflows/golang-test-linux.yml b/.github/workflows/golang-test-linux.yml index ef6672002..36dcb791f 100644 --- a/.github/workflows/golang-test-linux.yml +++ b/.github/workflows/golang-test-linux.yml @@ -52,6 +52,47 @@ jobs: - name: Test run: CGO_ENABLED=1 GOARCH=${{ matrix.arch }} NETBIRD_STORE_ENGINE=${{ matrix.store }} CI=true go test -exec 'sudo --preserve-env=CI,NETBIRD_STORE_ENGINE' -timeout 10m -p 1 ./... + benchmark: + strategy: + fail-fast: false + matrix: + arch: [ '386','amd64' ] + store: [ 'sqlite', 'postgres' ] + runs-on: ubuntu-22.04 + steps: + - name: Install Go + uses: actions/setup-go@v5 + with: + go-version: "1.23.x" + + + - name: Cache Go modules + uses: actions/cache@v4 + with: + path: ~/go/pkg/mod + key: ${{ runner.os }}-go-${{ hashFiles('**/go.sum') }} + restore-keys: | + ${{ runner.os }}-go- + + - name: Checkout code + uses: actions/checkout@v4 + + - name: Install dependencies + run: sudo apt update && sudo apt install -y -q libgtk-3-dev libayatana-appindicator3-dev libgl1-mesa-dev xorg-dev gcc-multilib libpcap-dev + + - name: Install 32-bit libpcap + if: matrix.arch == '386' + run: sudo dpkg --add-architecture i386 && sudo apt update && sudo apt-get install -y libpcap0.8-dev:i386 + + - name: Install modules + run: go mod tidy + + - name: check git status + run: git --no-pager diff --exit-code + + - name: Test + run: CGO_ENABLED=1 GOARCH=${{ matrix.arch }} NETBIRD_STORE_ENGINE=${{ matrix.store }} CI=true go test -run=^$ -bench=. -exec 'sudo --preserve-env=CI,NETBIRD_STORE_ENGINE' -timeout 10m -p 1 ./... 
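The -run=^$ pattern matches no test names, so this step executes only the benchmarks, while sudo --preserve-env=CI,NETBIRD_STORE_ENGINE keeps CI=true and the selected store engine visible to the threshold checks added below. A rough local equivalent (benchmark name and package path chosen here only for illustration) would omit CI=true so the local expectations apply:

  NETBIRD_STORE_ENGINE=sqlite go test -run '^$' -bench BenchmarkUpdateAccountPeers -timeout 10m ./management/server/...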
+ test_client_on_docker: runs-on: ubuntu-20.04 steps: diff --git a/management/server/account_test.go b/management/server/account_test.go index dbabfd366..4ff812607 100644 --- a/management/server/account_test.go +++ b/management/server/account_test.go @@ -2989,18 +2989,21 @@ func peerShouldReceiveUpdate(t *testing.T, updateMessage <-chan *UpdateMessage) func BenchmarkSyncAndMarkPeer(b *testing.B) { benchCases := []struct { - name string - peers int - groups int - minMsPerOp float64 - maxMsPerOp float64 + name string + peers int + groups int + // We need different expectations for CI/CD and local runs because of the different performance characteristics + minMsPerOpLocal float64 + maxMsPerOpLocal float64 + minMsPerOpCICD float64 + maxMsPerOpCICD float64 }{ - {"Small", 50, 5, 1, 3}, - {"Medium", 500, 100, 7, 13}, - {"Large", 5000, 200, 65, 80}, - {"Small single", 50, 10, 1, 3}, - {"Medium single", 500, 10, 7, 13}, - {"Large 5", 5000, 15, 65, 80}, + {"Small", 50, 5, 1, 3, 4, 10}, + {"Medium", 500, 100, 7, 13, 10, 60}, + {"Large", 5000, 200, 65, 80, 60, 170}, + {"Small single", 50, 10, 1, 3, 4, 60}, + {"Medium single", 500, 10, 7, 13, 10, 26}, + {"Large 5", 5000, 15, 65, 80, 60, 170}, } log.SetOutput(io.Discard) @@ -3033,12 +3036,19 @@ func BenchmarkSyncAndMarkPeer(b *testing.B) { msPerOp := float64(duration.Nanoseconds()) / float64(b.N) / 1e6 b.ReportMetric(msPerOp, "ms/op") - if msPerOp < bc.minMsPerOp { - b.Fatalf("Benchmark %s failed: too fast (%.2f ms/op, minimum %.2f ms/op)", bc.name, msPerOp, bc.minMsPerOp) + minExpected := bc.minMsPerOpLocal + maxExpected := bc.maxMsPerOpLocal + if os.Getenv("CI") == "true" { + minExpected = bc.minMsPerOpCICD + maxExpected = bc.maxMsPerOpCICD } - if msPerOp > bc.maxMsPerOp { - b.Fatalf("Benchmark %s failed: too slow (%.2f ms/op, maximum %.2f ms/op)", bc.name, msPerOp, bc.maxMsPerOp) + if msPerOp < minExpected { + b.Fatalf("Benchmark %s failed: too fast (%.2f ms/op, minimum %.2f ms/op)", bc.name, msPerOp, minExpected) + } + + if msPerOp > maxExpected { + b.Fatalf("Benchmark %s failed: too slow (%.2f ms/op, maximum %.2f ms/op)", bc.name, msPerOp, maxExpected) } }) } @@ -3046,18 +3056,21 @@ func BenchmarkSyncAndMarkPeer(b *testing.B) { func BenchmarkLoginPeer_ExistingPeer(b *testing.B) { benchCases := []struct { - name string - peers int - groups int - minMsPerOp float64 - maxMsPerOp float64 + name string + peers int + groups int + // We need different expectations for CI/CD and local runs because of the different performance characteristics + minMsPerOpLocal float64 + maxMsPerOpLocal float64 + minMsPerOpCICD float64 + maxMsPerOpCICD float64 }{ - {"Small", 50, 5, 102, 110}, - {"Medium", 500, 100, 105, 140}, - {"Large", 5000, 200, 160, 200}, - {"Small single", 50, 10, 102, 110}, - {"Medium single", 500, 10, 105, 140}, - {"Large 5", 5000, 15, 160, 200}, + {"Small", 50, 5, 102, 110, 102, 120}, + {"Medium", 500, 100, 105, 140, 105, 170}, + {"Large", 5000, 200, 160, 200, 160, 270}, + {"Small single", 50, 10, 102, 110, 102, 120}, + {"Medium single", 500, 10, 105, 140, 105, 170}, + {"Large 5", 5000, 15, 160, 200, 160, 270}, } log.SetOutput(io.Discard) @@ -3097,12 +3110,19 @@ func BenchmarkLoginPeer_ExistingPeer(b *testing.B) { msPerOp := float64(duration.Nanoseconds()) / float64(b.N) / 1e6 b.ReportMetric(msPerOp, "ms/op") - if msPerOp < bc.minMsPerOp { - b.Fatalf("Benchmark %s failed: too fast (%.2f ms/op, minimum %.2f ms/op)", bc.name, msPerOp, bc.minMsPerOp) + minExpected := bc.minMsPerOpLocal + maxExpected := bc.maxMsPerOpLocal + if os.Getenv("CI") == "true" { 
+ minExpected = bc.minMsPerOpCICD + maxExpected = bc.maxMsPerOpCICD } - if msPerOp > bc.maxMsPerOp { - b.Fatalf("Benchmark %s failed: too slow (%.2f ms/op, maximum %.2f ms/op)", bc.name, msPerOp, bc.maxMsPerOp) + if msPerOp < minExpected { + b.Fatalf("Benchmark %s failed: too fast (%.2f ms/op, minimum %.2f ms/op)", bc.name, msPerOp, minExpected) + } + + if msPerOp > maxExpected { + b.Fatalf("Benchmark %s failed: too slow (%.2f ms/op, maximum %.2f ms/op)", bc.name, msPerOp, maxExpected) } }) } @@ -3110,18 +3130,21 @@ func BenchmarkLoginPeer_ExistingPeer(b *testing.B) { func BenchmarkLoginPeer_NewPeer(b *testing.B) { benchCases := []struct { - name string - peers int - groups int - minMsPerOp float64 - maxMsPerOp float64 + name string + peers int + groups int + // We need different expectations for CI/CD and local runs because of the different performance characteristics + minMsPerOpLocal float64 + maxMsPerOpLocal float64 + minMsPerOpCICD float64 + maxMsPerOpCICD float64 }{ - {"Small", 50, 5, 107, 120}, - {"Medium", 500, 100, 105, 140}, - {"Large", 5000, 200, 180, 220}, - {"Small single", 50, 10, 107, 120}, - {"Medium single", 500, 10, 105, 140}, - {"Large 5", 5000, 15, 180, 220}, + {"Small", 50, 5, 107, 120, 107, 140}, + {"Medium", 500, 100, 105, 140, 105, 170}, + {"Large", 5000, 200, 180, 220, 180, 320}, + {"Small single", 50, 10, 107, 120, 105, 140}, + {"Medium single", 500, 10, 105, 140, 105, 170}, + {"Large 5", 5000, 15, 180, 220, 180, 320}, } log.SetOutput(io.Discard) @@ -3161,12 +3184,19 @@ func BenchmarkLoginPeer_NewPeer(b *testing.B) { msPerOp := float64(duration.Nanoseconds()) / float64(b.N) / 1e6 b.ReportMetric(msPerOp, "ms/op") - if msPerOp < bc.minMsPerOp { - b.Fatalf("Benchmark %s failed: too fast (%.2f ms/op, minimum %.2f ms/op)", bc.name, msPerOp, bc.minMsPerOp) + minExpected := bc.minMsPerOpLocal + maxExpected := bc.maxMsPerOpLocal + if os.Getenv("CI") == "true" { + minExpected = bc.minMsPerOpCICD + maxExpected = bc.maxMsPerOpCICD } - if msPerOp > bc.maxMsPerOp { - b.Fatalf("Benchmark %s failed: too slow (%.2f ms/op, maximum %.2f ms/op)", bc.name, msPerOp, bc.maxMsPerOp) + if msPerOp < minExpected { + b.Fatalf("Benchmark %s failed: too fast (%.2f ms/op, minimum %.2f ms/op)", bc.name, msPerOp, minExpected) + } + + if msPerOp > maxExpected { + b.Fatalf("Benchmark %s failed: too slow (%.2f ms/op, maximum %.2f ms/op)", bc.name, msPerOp, maxExpected) } }) } diff --git a/management/server/peer_test.go b/management/server/peer_test.go index e5eaa20d6..b15315f98 100644 --- a/management/server/peer_test.go +++ b/management/server/peer_test.go @@ -835,18 +835,21 @@ func BenchmarkGetPeers(b *testing.B) { } func BenchmarkUpdateAccountPeers(b *testing.B) { benchCases := []struct { - name string - peers int - groups int - minMsPerOp float64 - maxMsPerOp float64 + name string + peers int + groups int + // We need different expectations for CI/CD and local runs because of the different performance characteristics + minMsPerOpLocal float64 + maxMsPerOpLocal float64 + minMsPerOpCICD float64 + maxMsPerOpCICD float64 }{ - {"Small", 50, 5, 90, 120}, - {"Medium", 500, 100, 110, 140}, - {"Large", 5000, 200, 800, 1300}, - {"Small single", 50, 10, 90, 120}, - {"Medium single", 500, 10, 110, 170}, - {"Large 5", 5000, 15, 1300, 1800}, + {"Small", 50, 5, 90, 120, 90, 120}, + {"Medium", 500, 100, 110, 140, 120, 200}, + {"Large", 5000, 200, 800, 1300, 2500, 3600}, + {"Small single", 50, 10, 90, 120, 90, 120}, + {"Medium single", 500, 10, 110, 170, 120, 200}, + {"Large 5", 5000, 15, 1300, 1800, 5000, 
6000}, } log.SetOutput(io.Discard) @@ -885,12 +888,19 @@ func BenchmarkUpdateAccountPeers(b *testing.B) { msPerOp := float64(duration.Nanoseconds()) / float64(b.N) / 1e6 b.ReportMetric(msPerOp, "ms/op") - if msPerOp < bc.minMsPerOp { - b.Fatalf("Benchmark %s failed: too fast (%.2f ms/op, minimum %.2f ms/op)", bc.name, msPerOp, bc.minMsPerOp) + minExpected := bc.minMsPerOpLocal + maxExpected := bc.maxMsPerOpLocal + if os.Getenv("CI") == "true" { + minExpected = bc.minMsPerOpCICD + maxExpected = bc.maxMsPerOpCICD } - if msPerOp > bc.maxMsPerOp { - b.Fatalf("Benchmark %s failed: too slow (%.2f ms/op, maximum %.2f ms/op)", bc.name, msPerOp, bc.maxMsPerOp) + if msPerOp < minExpected { + b.Fatalf("Benchmark %s failed: too fast (%.2f ms/op, minimum %.2f ms/op)", bc.name, msPerOp, minExpected) + } + + if msPerOp > maxExpected { + b.Fatalf("Benchmark %s failed: too slow (%.2f ms/op, maximum %.2f ms/op)", bc.name, msPerOp, maxExpected) } }) } From 8efad1d170d0c956652cb86fa9a6923a5e1e0f5e Mon Sep 17 00:00:00 2001 From: Maycon Santos Date: Fri, 29 Nov 2024 10:06:40 +0100 Subject: [PATCH 052/159] Add guide when signing key is not found (#2942) Some users face issues with their IdP due to signing key not being refreshed With this change we advise users to configure key refresh Co-authored-by: Copilot <175728472+Copilot@users.noreply.github.com> * removing leftover --------- Co-authored-by: Copilot <175728472+Copilot@users.noreply.github.com> --- management/server/jwtclaims/jwtValidator.go | 19 +++++++++++++------ 1 file changed, 13 insertions(+), 6 deletions(-) diff --git a/management/server/jwtclaims/jwtValidator.go b/management/server/jwtclaims/jwtValidator.go index d5c1e7c9e..b91616fa5 100644 --- a/management/server/jwtclaims/jwtValidator.go +++ b/management/server/jwtclaims/jwtValidator.go @@ -77,6 +77,8 @@ type JWTValidator struct { options Options } +var keyNotFound = errors.New("unable to find appropriate key") + // NewJWTValidator constructor func NewJWTValidator(ctx context.Context, issuer string, audienceList []string, keysLocation string, idpSignkeyRefreshEnabled bool) (*JWTValidator, error) { keys, err := getPemKeys(ctx, keysLocation) @@ -124,12 +126,18 @@ func NewJWTValidator(ctx context.Context, issuer string, audienceList []string, } publicKey, err := getPublicKey(ctx, token, keys) - if err != nil { - log.WithContext(ctx).Errorf("getPublicKey error: %s", err) - return nil, err + if err == nil { + return publicKey, nil } - return publicKey, nil + msg := fmt.Sprintf("getPublicKey error: %s", err) + if errors.Is(err, keyNotFound) && !idpSignkeyRefreshEnabled { + msg = fmt.Sprintf("getPublicKey error: %s. 
You can enable key refresh by setting HttpServerConfig.IdpSignKeyRefreshEnabled to true in your management.json file and restart the service", err) + } + + log.WithContext(ctx).Error(msg) + + return nil, err }, EnableAuthOnOptions: false, } @@ -229,7 +237,7 @@ func getPublicKey(ctx context.Context, token *jwt.Token, jwks *Jwks) (interface{ log.WithContext(ctx).Debugf("Key Type: %s not yet supported, please raise ticket!", jwks.Keys[k].Kty) } - return nil, errors.New("unable to find appropriate key") + return nil, keyNotFound } func getPublicKeyFromECDSA(jwk JSONWebKey) (publicKey *ecdsa.PublicKey, err error) { @@ -310,4 +318,3 @@ func getMaxAgeFromCacheHeader(ctx context.Context, cacheControl string) int { return 0 } - From f9723c9266e17aeafcc7d05d6b0e63f605bbea2f Mon Sep 17 00:00:00 2001 From: Maycon Santos Date: Fri, 29 Nov 2024 17:50:35 +0100 Subject: [PATCH 053/159] [client] Account different policiy rules for routes firewall rules (#2939) * Account different policies rules for routes firewall rules This change ensures that route firewall rules will consider source group peers in the rules generation for access control policies. This fixes the behavior where multiple policies with different levels of access was being applied to all peers in a distribution group * split function * avoid unnecessary allocation Co-authored-by: Viktor Liu <17948409+lixmal@users.noreply.github.com> --------- Co-authored-by: Viktor Liu <17948409+lixmal@users.noreply.github.com> --- management/server/route.go | 85 ++++++++++++++++---- management/server/route_test.go | 132 +++++++++++++++++++++++++++++++- 2 files changed, 199 insertions(+), 18 deletions(-) diff --git a/management/server/route.go b/management/server/route.go index ecb562645..23bea87e3 100644 --- a/management/server/route.go +++ b/management/server/route.go @@ -417,27 +417,84 @@ func (a *Account) getPeerRoutesFirewallRules(ctx context.Context, peerID string, continue } - policies := getAllRoutePoliciesFromGroups(a, route.AccessControlGroups) - for _, policy := range policies { - if !policy.Enabled { - continue - } + distributionPeers := a.getDistributionGroupsPeers(route) - for _, rule := range policy.Rules { - if !rule.Enabled { - continue - } - - distributionGroupPeers, _ := a.getAllPeersFromGroups(ctx, route.Groups, peerID, nil, validatedPeersMap) - rules := generateRouteFirewallRules(ctx, route, rule, distributionGroupPeers, firewallRuleDirectionIN) - routesFirewallRules = append(routesFirewallRules, rules...) - } + for _, accessGroup := range route.AccessControlGroups { + policies := getAllRoutePoliciesFromGroups(a, []string{accessGroup}) + rules := a.getRouteFirewallRules(ctx, peerID, policies, route, validatedPeersMap, distributionPeers) + routesFirewallRules = append(routesFirewallRules, rules...) } } return routesFirewallRules } +func (a *Account) getRouteFirewallRules(ctx context.Context, peerID string, policies []*Policy, route *route.Route, validatedPeersMap map[string]struct{}, distributionPeers map[string]struct{}) []*RouteFirewallRule { + var fwRules []*RouteFirewallRule + for _, policy := range policies { + if !policy.Enabled { + continue + } + + for _, rule := range policy.Rules { + if !rule.Enabled { + continue + } + + rulePeers := a.getRulePeers(rule, peerID, distributionPeers, validatedPeersMap) + rules := generateRouteFirewallRules(ctx, route, rule, rulePeers, firewallRuleDirectionIN) + fwRules = append(fwRules, rules...) 
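			// getRulePeers (defined below) keeps only source peers of this rule that are also in the
			// route's distribution groups and in the validated-peers set, so a policy reaches just the
			// peers of its own source groups instead of every peer in the distribution group.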
+ } + } + return fwRules +} + +func (a *Account) getRulePeers(rule *PolicyRule, peerID string, distributionPeers map[string]struct{}, validatedPeersMap map[string]struct{}) []*nbpeer.Peer { + distPeersWithPolicy := make(map[string]struct{}) + for _, id := range rule.Sources { + group := a.Groups[id] + if group == nil { + continue + } + + for _, pID := range group.Peers { + if pID == peerID { + continue + } + _, distPeer := distributionPeers[pID] + _, valid := validatedPeersMap[pID] + if distPeer && valid { + distPeersWithPolicy[pID] = struct{}{} + } + } + } + + distributionGroupPeers := make([]*nbpeer.Peer, 0, len(distPeersWithPolicy)) + for pID := range distPeersWithPolicy { + peer := a.Peers[pID] + if peer == nil { + continue + } + distributionGroupPeers = append(distributionGroupPeers, peer) + } + return distributionGroupPeers +} + +func (a *Account) getDistributionGroupsPeers(route *route.Route) map[string]struct{} { + distPeers := make(map[string]struct{}) + for _, id := range route.Groups { + group := a.Groups[id] + if group == nil { + continue + } + + for _, pID := range group.Peers { + distPeers[pID] = struct{}{} + } + } + return distPeers +} + func getDefaultPermit(route *route.Route) []*RouteFirewallRule { var rules []*RouteFirewallRule diff --git a/management/server/route_test.go b/management/server/route_test.go index 108f791e0..8bf9a3aeb 100644 --- a/management/server/route_test.go +++ b/management/server/route_test.go @@ -5,6 +5,7 @@ import ( "fmt" "net" "net/netip" + "sort" "testing" "time" @@ -1486,6 +1487,8 @@ func TestAccount_getPeersRoutesFirewall(t *testing.T) { peerBIp = "100.65.80.39" peerCIp = "100.65.254.139" peerHIp = "100.65.29.55" + peerJIp = "100.65.29.65" + peerKIp = "100.65.29.66" ) account := &Account{ @@ -1541,6 +1544,16 @@ func TestAccount_getPeersRoutesFirewall(t *testing.T) { IP: net.ParseIP(peerHIp), Status: &nbpeer.PeerStatus{}, }, + "peerJ": { + ID: "peerJ", + IP: net.ParseIP(peerJIp), + Status: &nbpeer.PeerStatus{}, + }, + "peerK": { + ID: "peerK", + IP: net.ParseIP(peerKIp), + Status: &nbpeer.PeerStatus{}, + }, }, Groups: map[string]*nbgroup.Group{ "routingPeer1": { @@ -1567,6 +1580,11 @@ func TestAccount_getPeersRoutesFirewall(t *testing.T) { Name: "Route2", Peers: []string{}, }, + "route4": { + ID: "route4", + Name: "route4", + Peers: []string{}, + }, "finance": { ID: "finance", Name: "Finance", @@ -1584,6 +1602,28 @@ func TestAccount_getPeersRoutesFirewall(t *testing.T) { "peerB", }, }, + "qa": { + ID: "qa", + Name: "QA", + Peers: []string{ + "peerJ", + "peerK", + }, + }, + "restrictQA": { + ID: "restrictQA", + Name: "restrictQA", + Peers: []string{ + "peerJ", + }, + }, + "unrestrictedQA": { + ID: "unrestrictedQA", + Name: "unrestrictedQA", + Peers: []string{ + "peerK", + }, + }, "contractors": { ID: "contractors", Name: "Contractors", @@ -1631,6 +1671,19 @@ func TestAccount_getPeersRoutesFirewall(t *testing.T) { Groups: []string{"contractors"}, AccessControlGroups: []string{}, }, + "route4": { + ID: "route4", + Network: netip.MustParsePrefix("192.168.10.0/16"), + NetID: "route4", + NetworkType: route.IPv4Network, + PeerGroups: []string{"routingPeer1"}, + Description: "Route4", + Masquerade: false, + Metric: 9999, + Enabled: true, + Groups: []string{"qa"}, + AccessControlGroups: []string{"route4"}, + }, }, Policies: []*Policy{ { @@ -1685,6 +1738,49 @@ func TestAccount_getPeersRoutesFirewall(t *testing.T) { }, }, }, + { + ID: "RuleRoute4", + Name: "RuleRoute4", + Enabled: true, + Rules: []*PolicyRule{ + { + ID: "RuleRoute4", + Name: "RuleRoute4", + 
Bidirectional: true, + Enabled: true, + Protocol: PolicyRuleProtocolTCP, + Action: PolicyTrafficActionAccept, + Ports: []string{"80"}, + Sources: []string{ + "restrictQA", + }, + Destinations: []string{ + "route4", + }, + }, + }, + }, + { + ID: "RuleRoute5", + Name: "RuleRoute5", + Enabled: true, + Rules: []*PolicyRule{ + { + ID: "RuleRoute5", + Name: "RuleRoute5", + Bidirectional: true, + Enabled: true, + Protocol: PolicyRuleProtocolALL, + Action: PolicyTrafficActionAccept, + Sources: []string{ + "unrestrictedQA", + }, + Destinations: []string{ + "route4", + }, + }, + }, + }, }, } @@ -1709,7 +1805,7 @@ func TestAccount_getPeersRoutesFirewall(t *testing.T) { t.Run("check peer routes firewall rules", func(t *testing.T) { routesFirewallRules := account.getPeerRoutesFirewallRules(context.Background(), "peerA", validatedPeers) - assert.Len(t, routesFirewallRules, 2) + assert.Len(t, routesFirewallRules, 4) expectedRoutesFirewallRules := []*RouteFirewallRule{ { @@ -1735,12 +1831,32 @@ func TestAccount_getPeersRoutesFirewall(t *testing.T) { Port: 320, }, } - assert.ElementsMatch(t, routesFirewallRules, expectedRoutesFirewallRules) + additionalFirewallRule := []*RouteFirewallRule{ + { + SourceRanges: []string{ + fmt.Sprintf(AllowedIPsFormat, peerJIp), + }, + Action: "accept", + Destination: "192.168.10.0/16", + Protocol: "tcp", + Port: 80, + }, + { + SourceRanges: []string{ + fmt.Sprintf(AllowedIPsFormat, peerKIp), + }, + Action: "accept", + Destination: "192.168.10.0/16", + Protocol: "all", + }, + } + + assert.ElementsMatch(t, orderRuleSourceRanges(routesFirewallRules), orderRuleSourceRanges(append(expectedRoutesFirewallRules, additionalFirewallRule...))) // peerD is also the routing peer for route1, should contain same routes firewall rules as peerA routesFirewallRules = account.getPeerRoutesFirewallRules(context.Background(), "peerD", validatedPeers) assert.Len(t, routesFirewallRules, 2) - assert.ElementsMatch(t, routesFirewallRules, expectedRoutesFirewallRules) + assert.ElementsMatch(t, orderRuleSourceRanges(routesFirewallRules), orderRuleSourceRanges(expectedRoutesFirewallRules)) // peerE is a single routing peer for route 2 and route 3 routesFirewallRules = account.getPeerRoutesFirewallRules(context.Background(), "peerE", validatedPeers) @@ -1769,7 +1885,7 @@ func TestAccount_getPeersRoutesFirewall(t *testing.T) { IsDynamic: true, }, } - assert.ElementsMatch(t, routesFirewallRules, expectedRoutesFirewallRules) + assert.ElementsMatch(t, orderRuleSourceRanges(routesFirewallRules), orderRuleSourceRanges(expectedRoutesFirewallRules)) // peerC is part of route1 distribution groups but should not receive the routes firewall rules routesFirewallRules = account.getPeerRoutesFirewallRules(context.Background(), "peerC", validatedPeers) @@ -1778,6 +1894,14 @@ func TestAccount_getPeersRoutesFirewall(t *testing.T) { } +// orderList is a helper function to sort a list of strings +func orderRuleSourceRanges(ruleList []*RouteFirewallRule) []*RouteFirewallRule { + for _, rule := range ruleList { + sort.Strings(rule.SourceRanges) + } + return ruleList +} + func TestRouteAccountPeersUpdate(t *testing.T) { manager, err := createRouterManager(t) require.NoError(t, err, "failed to create account manager") From e52d352a48aac55caada063ef44bd7e79e173bdb Mon Sep 17 00:00:00 2001 From: v1rusnl <18641204+v1rusnl@users.noreply.github.com> Date: Sat, 30 Nov 2024 10:26:31 +0100 Subject: [PATCH 054/159] Update Caddyfile and Docker Compose to support HTTP3 (#2822) --- infrastructure_files/getting-started-with-zitadel.sh | 3 
++- 1 file changed, 2 insertions(+), 1 deletion(-) diff --git a/infrastructure_files/getting-started-with-zitadel.sh b/infrastructure_files/getting-started-with-zitadel.sh index 0b2b65142..7793d1fda 100644 --- a/infrastructure_files/getting-started-with-zitadel.sh +++ b/infrastructure_files/getting-started-with-zitadel.sh @@ -530,7 +530,7 @@ renderCaddyfile() { { debug servers :80,:443 { - protocols h1 h2c + protocols h1 h2c h3 } } @@ -788,6 +788,7 @@ services: networks: [ netbird ] ports: - '443:443' + - '443:443/udp' - '80:80' - '8080:8080' volumes: From e4a5fb3e91bfbe202d3f13f66599f6ebfd9fe77e Mon Sep 17 00:00:00 2001 From: victorserbu2709 Date: Sat, 30 Nov 2024 11:34:52 +0200 Subject: [PATCH 055/159] Unspecified address: default NetworkTypeUDP4+NetworkTypeUDP6 (#2804) --- client/iface/bind/udp_mux.go | 5 +++-- 1 file changed, 3 insertions(+), 2 deletions(-) diff --git a/client/iface/bind/udp_mux.go b/client/iface/bind/udp_mux.go index 12f7a8129..00a91f0ec 100644 --- a/client/iface/bind/udp_mux.go +++ b/client/iface/bind/udp_mux.go @@ -162,12 +162,13 @@ func NewUDPMuxDefault(params UDPMuxParams) *UDPMuxDefault { params.Logger.Warn("UDPMuxDefault should not listening on unspecified address, use NewMultiUDPMuxFromPort instead") var networks []ice.NetworkType switch { - case addr.IP.To4() != nil: - networks = []ice.NetworkType{ice.NetworkTypeUDP4} case addr.IP.To16() != nil: networks = []ice.NetworkType{ice.NetworkTypeUDP4, ice.NetworkTypeUDP6} + case addr.IP.To4() != nil: + networks = []ice.NetworkType{ice.NetworkTypeUDP4} + default: params.Logger.Errorf("LocalAddr expected IPV4 or IPV6, got %T", params.UDPConn.LocalAddr()) } From ecb44ff3065156ba0d548b4c4fe04c927b0d947e Mon Sep 17 00:00:00 2001 From: Zoltan Papp Date: Sun, 1 Dec 2024 19:22:52 +0100 Subject: [PATCH 056/159] [client] Add pprof build tag (#2964) * Add pprof build tag * Change env handling --- client/cmd/pprof.go | 33 +++++++++++++++++++++++++++++++++ 1 file changed, 33 insertions(+) create mode 100644 client/cmd/pprof.go diff --git a/client/cmd/pprof.go b/client/cmd/pprof.go new file mode 100644 index 000000000..37efd35f0 --- /dev/null +++ b/client/cmd/pprof.go @@ -0,0 +1,33 @@ +//go:build pprof +// +build pprof + +package cmd + +import ( + "net/http" + _ "net/http/pprof" + "os" + + log "github.com/sirupsen/logrus" +) + +func init() { + addr := pprofAddr() + go pprof(addr) +} + +func pprofAddr() string { + listenAddr := os.Getenv("NB_PPROF_ADDR") + if listenAddr == "" { + return "localhost:6969" + } + + return listenAddr +} + +func pprof(listenAddr string) { + log.Infof("listening pprof on: %s\n", listenAddr) + if err := http.ListenAndServe(listenAddr, nil); err != nil { + log.Fatalf("Failed to start pprof: %v", err) + } +} From 5142dc52c11ad9841fe8ae229ed585d27f976503 Mon Sep 17 00:00:00 2001 From: Viktor Liu <17948409+lixmal@users.noreply.github.com> Date: Mon, 2 Dec 2024 17:55:02 +0100 Subject: [PATCH 057/159] [client] Persist route selection (#2810) --- client/firewall/iptables/rulestore_linux.go | 10 ++ client/internal/engine.go | 13 +- client/internal/engine_test.go | 5 +- client/internal/routemanager/manager.go | 38 ++++- client/internal/routemanager/manager_test.go | 4 +- client/internal/routemanager/mock.go | 2 +- .../routemanager/refcounter/refcounter.go | 13 ++ client/internal/routemanager/state.go | 19 +++ .../internal/routeselector/routeselector.go | 67 +++++++++ client/internal/statemanager/manager.go | 142 ++++++++++++++---- 10 files changed, 273 insertions(+), 40 deletions(-) create mode 100644 
client/internal/routemanager/state.go diff --git a/client/firewall/iptables/rulestore_linux.go b/client/firewall/iptables/rulestore_linux.go index bfd08bee2..004c512a4 100644 --- a/client/firewall/iptables/rulestore_linux.go +++ b/client/firewall/iptables/rulestore_linux.go @@ -37,6 +37,11 @@ func (s *ipList) UnmarshalJSON(data []byte) error { return err } s.ips = temp.IPs + + if temp.IPs == nil { + temp.IPs = make(map[string]struct{}) + } + return nil } @@ -89,5 +94,10 @@ func (s *ipsetStore) UnmarshalJSON(data []byte) error { return err } s.ipsets = temp.IPSets + + if temp.IPSets == nil { + temp.IPSets = make(map[string]*ipList) + } + return nil } diff --git a/client/internal/engine.go b/client/internal/engine.go index dc4499e17..920c295cd 100644 --- a/client/internal/engine.go +++ b/client/internal/engine.go @@ -349,8 +349,17 @@ func (e *Engine) Start() error { } e.dnsServer = dnsServer - e.routeManager = routemanager.NewManager(e.ctx, e.config.WgPrivateKey.PublicKey().String(), e.config.DNSRouteInterval, e.wgInterface, e.statusRecorder, e.relayManager, initialRoutes) - beforePeerHook, afterPeerHook, err := e.routeManager.Init(e.stateManager) + e.routeManager = routemanager.NewManager( + e.ctx, + e.config.WgPrivateKey.PublicKey().String(), + e.config.DNSRouteInterval, + e.wgInterface, + e.statusRecorder, + e.relayManager, + initialRoutes, + e.stateManager, + ) + beforePeerHook, afterPeerHook, err := e.routeManager.Init() if err != nil { log.Errorf("Failed to initialize route manager: %s", err) } else { diff --git a/client/internal/engine_test.go b/client/internal/engine_test.go index b6c6186ea..b58c1f7e9 100644 --- a/client/internal/engine_test.go +++ b/client/internal/engine_test.go @@ -245,12 +245,15 @@ func TestEngine_UpdateNetworkMap(t *testing.T) { nil) wgIface := &iface.MockWGIface{ + NameFunc: func() string { return "utun102" }, RemovePeerFunc: func(peerKey string) error { return nil }, } engine.wgInterface = wgIface - engine.routeManager = routemanager.NewManager(ctx, key.PublicKey().String(), time.Minute, engine.wgInterface, engine.statusRecorder, relayMgr, nil) + engine.routeManager = routemanager.NewManager(ctx, key.PublicKey().String(), time.Minute, engine.wgInterface, engine.statusRecorder, relayMgr, nil, nil) + _, _, err = engine.routeManager.Init() + require.NoError(t, err) engine.dnsServer = &dns.MockServer{ UpdateDNSServerFunc: func(serial uint64, update nbdns.Config) error { return nil }, } diff --git a/client/internal/routemanager/manager.go b/client/internal/routemanager/manager.go index 0a1c7dc56..f1c4ae5ef 100644 --- a/client/internal/routemanager/manager.go +++ b/client/internal/routemanager/manager.go @@ -32,7 +32,7 @@ import ( // Manager is a route manager interface type Manager interface { - Init(*statemanager.Manager) (nbnet.AddHookFunc, nbnet.RemoveHookFunc, error) + Init() (nbnet.AddHookFunc, nbnet.RemoveHookFunc, error) UpdateRoutes(updateSerial uint64, newRoutes []*route.Route) (map[route.ID]*route.Route, route.HAMap, error) TriggerSelection(route.HAMap) GetRouteSelector() *routeselector.RouteSelector @@ -59,6 +59,7 @@ type DefaultManager struct { routeRefCounter *refcounter.RouteRefCounter allowedIPsRefCounter *refcounter.AllowedIPsRefCounter dnsRouteInterval time.Duration + stateManager *statemanager.Manager } func NewManager( @@ -69,6 +70,7 @@ func NewManager( statusRecorder *peer.Status, relayMgr *relayClient.Manager, initialRoutes []*route.Route, + stateManager *statemanager.Manager, ) *DefaultManager { mCTX, cancel := context.WithCancel(ctx) notifier 
:= notifier.NewNotifier() @@ -80,12 +82,12 @@ func NewManager( dnsRouteInterval: dnsRouteInterval, clientNetworks: make(map[route.HAUniqueID]*clientNetwork), relayMgr: relayMgr, - routeSelector: routeselector.NewRouteSelector(), sysOps: sysOps, statusRecorder: statusRecorder, wgInterface: wgInterface, pubKey: pubKey, notifier: notifier, + stateManager: stateManager, } dm.routeRefCounter = refcounter.New( @@ -121,7 +123,7 @@ func NewManager( } // Init sets up the routing -func (m *DefaultManager) Init(stateManager *statemanager.Manager) (nbnet.AddHookFunc, nbnet.RemoveHookFunc, error) { +func (m *DefaultManager) Init() (nbnet.AddHookFunc, nbnet.RemoveHookFunc, error) { if nbnet.CustomRoutingDisabled() { return nil, nil, nil } @@ -137,14 +139,38 @@ func (m *DefaultManager) Init(stateManager *statemanager.Manager) (nbnet.AddHook ips := resolveURLsToIPs(initialAddresses) - beforePeerHook, afterPeerHook, err := m.sysOps.SetupRouting(ips, stateManager) + beforePeerHook, afterPeerHook, err := m.sysOps.SetupRouting(ips, m.stateManager) if err != nil { return nil, nil, fmt.Errorf("setup routing: %w", err) } + + m.routeSelector = m.initSelector() + log.Info("Routing setup complete") return beforePeerHook, afterPeerHook, nil } +func (m *DefaultManager) initSelector() *routeselector.RouteSelector { + var state *SelectorState + m.stateManager.RegisterState(state) + + // restore selector state if it exists + if err := m.stateManager.LoadState(state); err != nil { + log.Warnf("failed to load state: %v", err) + return routeselector.NewRouteSelector() + } + + if state := m.stateManager.GetState(state); state != nil { + if selector, ok := state.(*SelectorState); ok { + return (*routeselector.RouteSelector)(selector) + } + + log.Warnf("failed to convert state with type %T to SelectorState", state) + } + + return routeselector.NewRouteSelector() +} + func (m *DefaultManager) EnableServerRouter(firewall firewall.Manager) error { var err error m.serverRouter, err = newServerRouter(m.ctx, m.wgInterface, firewall, m.statusRecorder) @@ -252,6 +278,10 @@ func (m *DefaultManager) TriggerSelection(networks route.HAMap) { go clientNetworkWatcher.peersStateAndUpdateWatcher() clientNetworkWatcher.sendUpdateToClientNetworkWatcher(routesUpdate{routes: routes}) } + + if err := m.stateManager.UpdateState((*SelectorState)(m.routeSelector)); err != nil { + log.Errorf("failed to update state: %v", err) + } } // stopObsoleteClients stops the client network watcher for the networks that are not in the new list diff --git a/client/internal/routemanager/manager_test.go b/client/internal/routemanager/manager_test.go index e669bc44a..07dac21b8 100644 --- a/client/internal/routemanager/manager_test.go +++ b/client/internal/routemanager/manager_test.go @@ -424,9 +424,9 @@ func TestManagerUpdateRoutes(t *testing.T) { statusRecorder := peer.NewRecorder("https://mgm") ctx := context.TODO() - routeManager := NewManager(ctx, localPeerKey, 0, wgInterface, statusRecorder, nil, nil) + routeManager := NewManager(ctx, localPeerKey, 0, wgInterface, statusRecorder, nil, nil, nil) - _, _, err = routeManager.Init(nil) + _, _, err = routeManager.Init() require.NoError(t, err, "should init route manager") defer routeManager.Stop(nil) diff --git a/client/internal/routemanager/mock.go b/client/internal/routemanager/mock.go index 503185f03..556a62351 100644 --- a/client/internal/routemanager/mock.go +++ b/client/internal/routemanager/mock.go @@ -21,7 +21,7 @@ type MockManager struct { StopFunc func(manager *statemanager.Manager) } -func (m *MockManager) 
Init(*statemanager.Manager) (net.AddHookFunc, net.RemoveHookFunc, error) { +func (m *MockManager) Init() (net.AddHookFunc, net.RemoveHookFunc, error) { return nil, nil, nil } diff --git a/client/internal/routemanager/refcounter/refcounter.go b/client/internal/routemanager/refcounter/refcounter.go index f2f0a169d..27a724f50 100644 --- a/client/internal/routemanager/refcounter/refcounter.go +++ b/client/internal/routemanager/refcounter/refcounter.go @@ -71,11 +71,14 @@ func New[Key comparable, I, O any](add AddFunc[Key, I, O], remove RemoveFunc[Key } // LoadData loads the data from the existing counter +// The passed counter should not be used any longer after calling this function. func (rm *Counter[Key, I, O]) LoadData( existingCounter *Counter[Key, I, O], ) { rm.mu.Lock() defer rm.mu.Unlock() + existingCounter.mu.Lock() + defer existingCounter.mu.Unlock() rm.refCountMap = existingCounter.refCountMap rm.idMap = existingCounter.idMap @@ -231,6 +234,9 @@ func (rm *Counter[Key, I, O]) MarshalJSON() ([]byte, error) { // UnmarshalJSON implements the json.Unmarshaler interface for Counter. func (rm *Counter[Key, I, O]) UnmarshalJSON(data []byte) error { + rm.mu.Lock() + defer rm.mu.Unlock() + var temp struct { RefCountMap map[Key]Ref[O] `json:"refCountMap"` IDMap map[string][]Key `json:"idMap"` @@ -241,6 +247,13 @@ func (rm *Counter[Key, I, O]) UnmarshalJSON(data []byte) error { rm.refCountMap = temp.RefCountMap rm.idMap = temp.IDMap + if temp.RefCountMap == nil { + temp.RefCountMap = map[Key]Ref[O]{} + } + if temp.IDMap == nil { + temp.IDMap = map[string][]Key{} + } + return nil } diff --git a/client/internal/routemanager/state.go b/client/internal/routemanager/state.go new file mode 100644 index 000000000..a45c32b50 --- /dev/null +++ b/client/internal/routemanager/state.go @@ -0,0 +1,19 @@ +package routemanager + +import ( + "github.com/netbirdio/netbird/client/internal/routeselector" +) + +type SelectorState routeselector.RouteSelector + +func (s *SelectorState) Name() string { + return "routeselector_state" +} + +func (s *SelectorState) MarshalJSON() ([]byte, error) { + return (*routeselector.RouteSelector)(s).MarshalJSON() +} + +func (s *SelectorState) UnmarshalJSON(data []byte) error { + return (*routeselector.RouteSelector)(s).UnmarshalJSON(data) +} diff --git a/client/internal/routeselector/routeselector.go b/client/internal/routeselector/routeselector.go index 00128a27b..2874604fd 100644 --- a/client/internal/routeselector/routeselector.go +++ b/client/internal/routeselector/routeselector.go @@ -1,8 +1,10 @@ package routeselector import ( + "encoding/json" "fmt" "slices" + "sync" "github.com/hashicorp/go-multierror" "golang.org/x/exp/maps" @@ -12,6 +14,7 @@ import ( ) type RouteSelector struct { + mu sync.RWMutex selectedRoutes map[route.NetID]struct{} selectAll bool } @@ -26,6 +29,9 @@ func NewRouteSelector() *RouteSelector { // SelectRoutes updates the selected routes based on the provided route IDs. func (rs *RouteSelector) SelectRoutes(routes []route.NetID, appendRoute bool, allRoutes []route.NetID) error { + rs.mu.Lock() + defer rs.mu.Unlock() + if !appendRoute { rs.selectedRoutes = map[route.NetID]struct{}{} } @@ -46,6 +52,9 @@ func (rs *RouteSelector) SelectRoutes(routes []route.NetID, appendRoute bool, al // SelectAllRoutes sets the selector to select all routes. 
func (rs *RouteSelector) SelectAllRoutes() { + rs.mu.Lock() + defer rs.mu.Unlock() + rs.selectAll = true rs.selectedRoutes = map[route.NetID]struct{}{} } @@ -53,6 +62,9 @@ func (rs *RouteSelector) SelectAllRoutes() { // DeselectRoutes removes specific routes from the selection. // If the selector is in "select all" mode, it will transition to "select specific" mode. func (rs *RouteSelector) DeselectRoutes(routes []route.NetID, allRoutes []route.NetID) error { + rs.mu.Lock() + defer rs.mu.Unlock() + if rs.selectAll { rs.selectAll = false rs.selectedRoutes = map[route.NetID]struct{}{} @@ -76,12 +88,18 @@ func (rs *RouteSelector) DeselectRoutes(routes []route.NetID, allRoutes []route. // DeselectAllRoutes deselects all routes, effectively disabling route selection. func (rs *RouteSelector) DeselectAllRoutes() { + rs.mu.Lock() + defer rs.mu.Unlock() + rs.selectAll = false rs.selectedRoutes = map[route.NetID]struct{}{} } // IsSelected checks if a specific route is selected. func (rs *RouteSelector) IsSelected(routeID route.NetID) bool { + rs.mu.RLock() + defer rs.mu.RUnlock() + if rs.selectAll { return true } @@ -91,6 +109,9 @@ func (rs *RouteSelector) IsSelected(routeID route.NetID) bool { // FilterSelected removes unselected routes from the provided map. func (rs *RouteSelector) FilterSelected(routes route.HAMap) route.HAMap { + rs.mu.RLock() + defer rs.mu.RUnlock() + if rs.selectAll { return maps.Clone(routes) } @@ -103,3 +124,49 @@ func (rs *RouteSelector) FilterSelected(routes route.HAMap) route.HAMap { } return filtered } + +// MarshalJSON implements the json.Marshaler interface +func (rs *RouteSelector) MarshalJSON() ([]byte, error) { + rs.mu.RLock() + defer rs.mu.RUnlock() + + return json.Marshal(struct { + SelectedRoutes map[route.NetID]struct{} `json:"selected_routes"` + SelectAll bool `json:"select_all"` + }{ + SelectAll: rs.selectAll, + SelectedRoutes: rs.selectedRoutes, + }) +} + +// UnmarshalJSON implements the json.Unmarshaler interface +// If the JSON is empty or null, it will initialize like a NewRouteSelector. 
+func (rs *RouteSelector) UnmarshalJSON(data []byte) error { + rs.mu.Lock() + defer rs.mu.Unlock() + + // Check for null or empty JSON + if len(data) == 0 || string(data) == "null" { + rs.selectedRoutes = map[route.NetID]struct{}{} + rs.selectAll = true + return nil + } + + var temp struct { + SelectedRoutes map[route.NetID]struct{} `json:"selected_routes"` + SelectAll bool `json:"select_all"` + } + + if err := json.Unmarshal(data, &temp); err != nil { + return err + } + + rs.selectedRoutes = temp.SelectedRoutes + rs.selectAll = temp.SelectAll + + if rs.selectedRoutes == nil { + rs.selectedRoutes = map[route.NetID]struct{}{} + } + + return nil +} diff --git a/client/internal/statemanager/manager.go b/client/internal/statemanager/manager.go index da6dd022f..aae73b79f 100644 --- a/client/internal/statemanager/manager.go +++ b/client/internal/statemanager/manager.go @@ -22,9 +22,28 @@ import ( // State interface defines the methods that all state types must implement type State interface { Name() string +} + +// CleanableState interface extends State with cleanup capability +type CleanableState interface { + State Cleanup() error } +// RawState wraps raw JSON data for unregistered states +type RawState struct { + data json.RawMessage +} + +func (r *RawState) Name() string { + return "" // This is a placeholder implementation +} + +// MarshalJSON implements json.Marshaler to preserve the original JSON +func (r *RawState) MarshalJSON() ([]byte, error) { + return r.data, nil +} + // Manager handles the persistence and management of various states type Manager struct { mu sync.Mutex @@ -209,15 +228,15 @@ func (m *Manager) PersistState(ctx context.Context) error { return nil } -// loadState loads the existing state from the state file -func (m *Manager) loadState() error { +// loadStateFile reads and unmarshals the state file into a map of raw JSON messages +func (m *Manager) loadStateFile() (map[string]json.RawMessage, error) { data, err := os.ReadFile(m.filePath) if err != nil { if errors.Is(err, fs.ErrNotExist) { log.Debug("state file does not exist") - return nil + return nil, nil // nolint:nilnil } - return fmt.Errorf("read state file: %w", err) + return nil, fmt.Errorf("read state file: %w", err) } var rawStates map[string]json.RawMessage @@ -228,37 +247,69 @@ func (m *Manager) loadState() error { } else { log.Info("State file deleted") } - return fmt.Errorf("unmarshal states: %w", err) + return nil, fmt.Errorf("unmarshal states: %w", err) } - var merr *multierror.Error + return rawStates, nil +} - for name, rawState := range rawStates { - stateType, ok := m.stateTypes[name] - if !ok { - merr = multierror.Append(merr, fmt.Errorf("unknown state type: %s", name)) - continue - } +// loadSingleRawState unmarshals a raw state into a concrete state object +func (m *Manager) loadSingleRawState(name string, rawState json.RawMessage) (State, error) { + stateType, ok := m.stateTypes[name] + if !ok { + return nil, fmt.Errorf("state %s not registered", name) + } - if string(rawState) == "null" { - continue - } + if string(rawState) == "null" { + return nil, nil //nolint:nilnil + } - statePtr := reflect.New(stateType).Interface().(State) - if err := json.Unmarshal(rawState, statePtr); err != nil { - merr = multierror.Append(merr, fmt.Errorf("unmarshal state %s: %w", name, err)) - continue - } + statePtr := reflect.New(stateType).Interface().(State) + if err := json.Unmarshal(rawState, statePtr); err != nil { + return nil, fmt.Errorf("unmarshal state %s: %w", name, err) + } - m.states[name] = statePtr + 
return statePtr, nil +} + +// LoadState loads a specific state from the state file +func (m *Manager) LoadState(state State) error { + if m == nil { + return nil + } + + m.mu.Lock() + defer m.mu.Unlock() + + rawStates, err := m.loadStateFile() + if err != nil { + return err + } + if rawStates == nil { + return nil + } + + name := state.Name() + rawState, exists := rawStates[name] + if !exists { + return nil + } + + loadedState, err := m.loadSingleRawState(name, rawState) + if err != nil { + return err + } + + m.states[name] = loadedState + if loadedState != nil { log.Debugf("loaded state: %s", name) } - return nberrors.FormatErrorOrNil(merr) + return nil } -// PerformCleanup retrieves all states from the state file for the registered states and calls Cleanup on them. -// If the cleanup is successful, the state is marked for deletion. +// PerformCleanup retrieves all states from the state file and calls Cleanup on registered states that support it. +// Unregistered states are preserved in their original state. func (m *Manager) PerformCleanup() error { if m == nil { return nil @@ -267,22 +318,53 @@ func (m *Manager) PerformCleanup() error { m.mu.Lock() defer m.mu.Unlock() - if err := m.loadState(); err != nil { + // Load raw states from file + rawStates, err := m.loadStateFile() + if err != nil { log.Warnf("Failed to load state during cleanup: %v", err) + return err + } + if rawStates == nil { + return nil } var merr *multierror.Error - for name, state := range m.states { - if state == nil { - // If no state was found in the state file, we don't mark the state dirty nor return an error + + // Process each state in the file + for name, rawState := range rawStates { + // For unregistered states, preserve the raw JSON + if _, registered := m.stateTypes[name]; !registered { + m.states[name] = &RawState{data: rawState} continue } + // Load the registered state + loadedState, err := m.loadSingleRawState(name, rawState) + if err != nil { + merr = multierror.Append(merr, err) + continue + } + + if loadedState == nil { + continue + } + + // Check if state supports cleanup + cleanableState, isCleanable := loadedState.(CleanableState) + if !isCleanable { + // If it doesn't support cleanup, keep it as-is + m.states[name] = loadedState + continue + } + + // Perform cleanup for cleanable states log.Infof("client was not shut down properly, cleaning up %s", name) - if err := state.Cleanup(); err != nil { + if err := cleanableState.Cleanup(); err != nil { merr = multierror.Append(merr, fmt.Errorf("cleanup state for %s: %w", name, err)) + // On cleanup error, preserve the state + m.states[name] = loadedState } else { - // mark for deletion on cleanup success + // Successfully cleaned up - mark for deletion m.states[name] = nil m.dirty[name] = struct{}{} } From c7e7ad5030cd695f1ba1ad2c819b2a8a551a90d8 Mon Sep 17 00:00:00 2001 From: Viktor Liu <17948409+lixmal@users.noreply.github.com> Date: Mon, 2 Dec 2024 18:04:02 +0100 Subject: [PATCH 058/159] [client] Add state file to debug bundle (#2969) --- client/anonymize/anonymize.go | 4 +- client/anonymize/anonymize_test.go | 8 + client/server/debug.go | 178 +++++++++++++++++--- client/server/debug_test.go | 258 +++++++++++++++++++++++++++++ 4 files changed, 425 insertions(+), 23 deletions(-) create mode 100644 client/server/debug_test.go diff --git a/client/anonymize/anonymize.go b/client/anonymize/anonymize.go index 7ebe0442d..ad31682f2 100644 --- a/client/anonymize/anonymize.go +++ b/client/anonymize/anonymize.go @@ -152,9 +152,9 @@ func (a *Anonymizer) 
AnonymizeString(str string) string { return str } -// AnonymizeSchemeURI finds and anonymizes URIs with stun, stuns, turn, and turns schemes. +// AnonymizeSchemeURI finds and anonymizes URIs with ws, wss, rel, rels, stun, stuns, turn, and turns schemes. func (a *Anonymizer) AnonymizeSchemeURI(text string) string { - re := regexp.MustCompile(`(?i)\b(stuns?:|turns?:|https?://)\S+\b`) + re := regexp.MustCompile(`(?i)\b(wss?://|rels?://|stuns?:|turns?:|https?://)\S+\b`) return re.ReplaceAllStringFunc(text, a.AnonymizeURI) } diff --git a/client/anonymize/anonymize_test.go b/client/anonymize/anonymize_test.go index e660749ec..605788ab5 100644 --- a/client/anonymize/anonymize_test.go +++ b/client/anonymize/anonymize_test.go @@ -140,8 +140,16 @@ func TestAnonymizeSchemeURI(t *testing.T) { expect string }{ {"STUN URI in text", "Connection made via stun:example.com", `Connection made via stun:anon-[a-zA-Z0-9]+\.domain`}, + {"STUNS URI in message", "Secure connection to stuns:example.com:443", `Secure connection to stuns:anon-[a-zA-Z0-9]+\.domain:443`}, {"TURN URI in log", "Failed attempt turn:some.example.com:3478?transport=tcp: retrying", `Failed attempt turn:some.anon-[a-zA-Z0-9]+\.domain:3478\?transport=tcp: retrying`}, + {"TURNS URI in message", "Secure connection to turns:example.com:5349", `Secure connection to turns:anon-[a-zA-Z0-9]+\.domain:5349`}, + {"HTTP URI in text", "Visit http://example.com for more", `Visit http://anon-[a-zA-Z0-9]+\.domain for more`}, + {"HTTPS URI in CAPS", "Visit HTTPS://example.com for more", `Visit https://anon-[a-zA-Z0-9]+\.domain for more`}, {"HTTPS URI in message", "Visit https://example.com for more", `Visit https://anon-[a-zA-Z0-9]+\.domain for more`}, + {"WS URI in log", "Connection established to ws://example.com:8080", `Connection established to ws://anon-[a-zA-Z0-9]+\.domain:8080`}, + {"WSS URI in message", "Secure connection to wss://example.com", `Secure connection to wss://anon-[a-zA-Z0-9]+\.domain`}, + {"Rel URI in text", "Relaying to rel://example.com", `Relaying to rel://anon-[a-zA-Z0-9]+\.domain`}, + {"Rels URI in message", "Relaying to rels://example.com", `Relaying to rels://anon-[a-zA-Z0-9]+\.domain`}, } for _, tc := range tests { diff --git a/client/server/debug.go b/client/server/debug.go index 5ed43293b..1bad907ba 100644 --- a/client/server/debug.go +++ b/client/server/debug.go @@ -5,9 +5,13 @@ package server import ( "archive/zip" "bufio" + "bytes" "context" + "encoding/json" + "errors" "fmt" "io" + "io/fs" "net" "net/netip" "os" @@ -20,6 +24,7 @@ import ( "github.com/netbirdio/netbird/client/anonymize" "github.com/netbirdio/netbird/client/internal/peer" "github.com/netbirdio/netbird/client/internal/routemanager/systemops" + "github.com/netbirdio/netbird/client/internal/statemanager" "github.com/netbirdio/netbird/client/proto" ) @@ -31,6 +36,7 @@ client.log: Most recent, anonymized log file of the NetBird client. routes.txt: Anonymized system routes, if --system-info flag was provided. interfaces.txt: Anonymized network interface information, if --system-info flag was provided. config.txt: Anonymized configuration information of the NetBird client. +state.json: Anonymized client state dump containing netbird states. Anonymization Process @@ -50,8 +56,22 @@ Domains All domain names (except for the netbird domains) are replaced with randomly generated strings ending in ".domain". Anonymized domains are consistent across all files in the bundle. Reoccuring domain names are replaced with the same anonymized domain. 
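A rough illustration of the consistency guarantee described above, assuming only the anonymize package API visible in these diffs (NewAnonymizer, DefaultAddresses, AnonymizeDomain, AnonymizeSchemeURI); this sketch is not part of the patch, and the concrete anonymized values are random per run, but repeated inputs map to the same output within one Anonymizer instance.

	package main

	import (
		"fmt"

		"github.com/netbirdio/netbird/client/anonymize"
	)

	func main() {
		a := anonymize.NewAnonymizer(anonymize.DefaultAddresses())

		// Hosts under the same base domain share one anonymized base,
		// so related entries can still be correlated inside the bundle.
		fmt.Println(a.AnonymizeDomain("peer1.example.com")) // e.g. peer1.anon-ab12c.domain
		fmt.Println(a.AnonymizeDomain("peer2.example.com")) // e.g. peer2.anon-ab12c.domain

		// URIs with supported schemes are rewritten as a whole.
		fmt.Println(a.AnonymizeSchemeURI("failed to reach stun:stun.example.com:3478"))
	}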
+State File +The state.json file contains anonymized internal state information of the NetBird client, including: +- DNS settings and configuration +- Firewall rules +- Exclusion routes +- Route selection +- Other internal states that may be present + +The state file follows the same anonymization rules as other files: +- IP addresses (both individual and CIDR ranges) are anonymized while preserving their structure +- Domain names are consistently anonymized +- Technical identifiers and non-sensitive data remain unchanged + Routes For anonymized routes, the IP addresses are replaced as described above. The prefix length remains unchanged. Note that for prefixes, the anonymized IP might not be a network address, but the prefix length is still correct. + Network Interfaces The interfaces.txt file contains information about network interfaces, including: - Interface name @@ -132,6 +152,10 @@ func (s *Server) createArchive(bundlePath *os.File, req *proto.DebugBundleReques } } + if err := s.addStateFile(req, anonymizer, archive); err != nil { + log.Errorf("Failed to add state file to debug bundle: %v", err) + } + if err := s.addLogfile(req, anonymizer, archive); err != nil { return fmt.Errorf("add log file: %w", err) } @@ -248,6 +272,44 @@ func (s *Server) addInterfaces(req *proto.DebugBundleRequest, anonymizer *anonym return nil } +func (s *Server) addStateFile(req *proto.DebugBundleRequest, anonymizer *anonymize.Anonymizer, archive *zip.Writer) error { + path := statemanager.GetDefaultStatePath() + if path == "" { + return nil + } + + data, err := os.ReadFile(path) + if err != nil { + if errors.Is(err, fs.ErrNotExist) { + return nil + } + return fmt.Errorf("read state file: %w", err) + } + + if req.GetAnonymize() { + var rawStates map[string]json.RawMessage + if err := json.Unmarshal(data, &rawStates); err != nil { + return fmt.Errorf("unmarshal states: %w", err) + } + + if err := anonymizeStateFile(&rawStates, anonymizer); err != nil { + return fmt.Errorf("anonymize state file: %w", err) + } + + bs, err := json.MarshalIndent(rawStates, "", " ") + if err != nil { + return fmt.Errorf("marshal states: %w", err) + } + data = bs + } + + if err := addFileToZip(archive, bytes.NewReader(data), "state.json"); err != nil { + return fmt.Errorf("add state file to zip: %w", err) + } + + return nil +} + func (s *Server) addLogfile(req *proto.DebugBundleRequest, anonymizer *anonymize.Anonymizer, archive *zip.Writer) (err error) { logFile, err := os.Open(s.logFile) if err != nil { @@ -264,7 +326,7 @@ func (s *Server) addLogfile(req *proto.DebugBundleRequest, anonymizer *anonymize var writer *io.PipeWriter logReader, writer = io.Pipe() - go s.anonymize(logFile, writer, anonymizer) + go anonymizeLog(logFile, writer, anonymizer) } else { logReader = logFile } @@ -275,26 +337,6 @@ func (s *Server) addLogfile(req *proto.DebugBundleRequest, anonymizer *anonymize return nil } -func (s *Server) anonymize(reader io.Reader, writer *io.PipeWriter, anonymizer *anonymize.Anonymizer) { - defer func() { - // always nil - _ = writer.Close() - }() - - scanner := bufio.NewScanner(reader) - for scanner.Scan() { - line := anonymizer.AnonymizeString(scanner.Text()) - if _, err := writer.Write([]byte(line + "\n")); err != nil { - writer.CloseWithError(fmt.Errorf("anonymize write: %w", err)) - return - } - } - if err := scanner.Err(); err != nil { - writer.CloseWithError(fmt.Errorf("anonymize scan: %w", err)) - return - } -} - // GetLogLevel gets the current logging level for the server. 
func (s *Server) GetLogLevel(_ context.Context, _ *proto.GetLogLevelRequest) (*proto.GetLogLevelResponse, error) { level := ParseLogLevel(log.GetLevel().String()) @@ -458,6 +500,26 @@ func formatInterfaces(interfaces []net.Interface, anonymize bool, anonymizer *an return builder.String() } +func anonymizeLog(reader io.Reader, writer *io.PipeWriter, anonymizer *anonymize.Anonymizer) { + defer func() { + // always nil + _ = writer.Close() + }() + + scanner := bufio.NewScanner(reader) + for scanner.Scan() { + line := anonymizer.AnonymizeString(scanner.Text()) + if _, err := writer.Write([]byte(line + "\n")); err != nil { + writer.CloseWithError(fmt.Errorf("anonymize write: %w", err)) + return + } + } + if err := scanner.Err(); err != nil { + writer.CloseWithError(fmt.Errorf("anonymize scan: %w", err)) + return + } +} + func anonymizeNATExternalIPs(ips []string, anonymizer *anonymize.Anonymizer) []string { anonymizedIPs := make([]string, len(ips)) for i, ip := range ips { @@ -484,3 +546,77 @@ func anonymizeNATExternalIPs(ips []string, anonymizer *anonymize.Anonymizer) []s } return anonymizedIPs } + +func anonymizeStateFile(rawStates *map[string]json.RawMessage, anonymizer *anonymize.Anonymizer) error { + for name, rawState := range *rawStates { + if string(rawState) == "null" { + continue + } + + var state map[string]any + if err := json.Unmarshal(rawState, &state); err != nil { + return fmt.Errorf("unmarshal state %s: %w", name, err) + } + + state = anonymizeValue(state, anonymizer).(map[string]any) + + bs, err := json.Marshal(state) + if err != nil { + return fmt.Errorf("marshal state %s: %w", name, err) + } + + (*rawStates)[name] = bs + } + + return nil +} + +func anonymizeValue(value any, anonymizer *anonymize.Anonymizer) any { + switch v := value.(type) { + case string: + return anonymizeString(v, anonymizer) + case map[string]any: + return anonymizeMap(v, anonymizer) + case []any: + return anonymizeSlice(v, anonymizer) + } + return value +} + +func anonymizeString(v string, anonymizer *anonymize.Anonymizer) string { + if prefix, err := netip.ParsePrefix(v); err == nil { + anonIP := anonymizer.AnonymizeIP(prefix.Addr()) + return fmt.Sprintf("%s/%d", anonIP, prefix.Bits()) + } + if ip, err := netip.ParseAddr(v); err == nil { + return anonymizer.AnonymizeIP(ip).String() + } + return anonymizer.AnonymizeString(v) +} + +func anonymizeMap(v map[string]any, anonymizer *anonymize.Anonymizer) map[string]any { + result := make(map[string]any, len(v)) + for key, val := range v { + newKey := anonymizeMapKey(key, anonymizer) + result[newKey] = anonymizeValue(val, anonymizer) + } + return result +} + +func anonymizeMapKey(key string, anonymizer *anonymize.Anonymizer) string { + if prefix, err := netip.ParsePrefix(key); err == nil { + anonIP := anonymizer.AnonymizeIP(prefix.Addr()) + return fmt.Sprintf("%s/%d", anonIP, prefix.Bits()) + } + if ip, err := netip.ParseAddr(key); err == nil { + return anonymizer.AnonymizeIP(ip).String() + } + return key +} + +func anonymizeSlice(v []any, anonymizer *anonymize.Anonymizer) []any { + for i, val := range v { + v[i] = anonymizeValue(val, anonymizer) + } + return v +} diff --git a/client/server/debug_test.go b/client/server/debug_test.go new file mode 100644 index 000000000..303e5e661 --- /dev/null +++ b/client/server/debug_test.go @@ -0,0 +1,258 @@ +package server + +import ( + "encoding/json" + "strings" + "testing" + + "github.com/stretchr/testify/assert" + "github.com/stretchr/testify/require" + + "github.com/netbirdio/netbird/client/anonymize" +) + +func 
TestAnonymizeStateFile(t *testing.T) { + testState := map[string]json.RawMessage{ + "null_state": json.RawMessage("null"), + "test_state": mustMarshal(map[string]any{ + // Test simple fields + "public_ip": "203.0.113.1", + "private_ip": "192.168.1.1", + "protected_ip": "100.64.0.1", + "well_known_ip": "8.8.8.8", + "ipv6_addr": "2001:db8::1", + "private_ipv6": "fd00::1", + "domain": "test.example.com", + "uri": "stun:stun.example.com:3478", + "uri_with_ip": "turn:203.0.113.1:3478", + "netbird_domain": "device.netbird.cloud", + + // Test CIDR ranges + "public_cidr": "203.0.113.0/24", + "private_cidr": "192.168.0.0/16", + "protected_cidr": "100.64.0.0/10", + "ipv6_cidr": "2001:db8::/32", + "private_ipv6_cidr": "fd00::/8", + + // Test nested structures + "nested": map[string]any{ + "ip": "203.0.113.2", + "domain": "nested.example.com", + "more_nest": map[string]any{ + "ip": "203.0.113.3", + "domain": "deep.example.com", + }, + }, + + // Test arrays + "string_array": []any{ + "203.0.113.4", + "test1.example.com", + "test2.example.com", + }, + "object_array": []any{ + map[string]any{ + "ip": "203.0.113.5", + "domain": "array1.example.com", + }, + map[string]any{ + "ip": "203.0.113.6", + "domain": "array2.example.com", + }, + }, + + // Test multiple occurrences of same value + "duplicate_ip": "203.0.113.1", // Same as public_ip + "duplicate_domain": "test.example.com", // Same as domain + + // Test URIs with various schemes + "stun_uri": "stun:stun.example.com:3478", + "turns_uri": "turns:turns.example.com:5349", + "http_uri": "http://web.example.com:80", + "https_uri": "https://secure.example.com:443", + + // Test strings that might look like IPs but aren't + "not_ip": "300.300.300.300", + "partial_ip": "192.168", + "ip_like_string": "1234.5678", + + // Test mixed content strings + "mixed_content": "Server at 203.0.113.1 (test.example.com) on port 80", + + // Test empty and special values + "empty_string": "", + "null_value": nil, + "numeric_value": 42, + "boolean_value": true, + }), + "route_state": mustMarshal(map[string]any{ + "routes": []any{ + map[string]any{ + "network": "203.0.113.0/24", + "gateway": "203.0.113.1", + "domains": []any{ + "route1.example.com", + "route2.example.com", + }, + }, + map[string]any{ + "network": "2001:db8::/32", + "gateway": "2001:db8::1", + "domains": []any{ + "route3.example.com", + "route4.example.com", + }, + }, + }, + // Test map with IP/CIDR keys + "refCountMap": map[string]any{ + "203.0.113.1/32": map[string]any{ + "Count": 1, + "Out": map[string]any{ + "IP": "192.168.0.1", + "Intf": map[string]any{ + "Name": "eth0", + "Index": 1, + }, + }, + }, + "2001:db8::1/128": map[string]any{ + "Count": 1, + "Out": map[string]any{ + "IP": "fe80::1", + "Intf": map[string]any{ + "Name": "eth0", + "Index": 1, + }, + }, + }, + "10.0.0.1/32": map[string]any{ // private IP should remain unchanged + "Count": 1, + "Out": map[string]any{ + "IP": "192.168.0.1", + }, + }, + }, + }), + } + + anonymizer := anonymize.NewAnonymizer(anonymize.DefaultAddresses()) + err := anonymizeStateFile(&testState, anonymizer) + require.NoError(t, err) + + // Helper function to unmarshal and get nested values + var state map[string]any + err = json.Unmarshal(testState["test_state"], &state) + require.NoError(t, err) + + // Test null state remains unchanged + require.Equal(t, "null", string(testState["null_state"])) + + // Basic assertions + assert.NotEqual(t, "203.0.113.1", state["public_ip"]) + assert.Equal(t, "192.168.1.1", state["private_ip"]) // Private IP unchanged + assert.Equal(t, 
"100.64.0.1", state["protected_ip"]) // Protected IP unchanged + assert.Equal(t, "8.8.8.8", state["well_known_ip"]) // Well-known IP unchanged + assert.NotEqual(t, "2001:db8::1", state["ipv6_addr"]) + assert.Equal(t, "fd00::1", state["private_ipv6"]) // Private IPv6 unchanged + assert.NotEqual(t, "test.example.com", state["domain"]) + assert.True(t, strings.HasSuffix(state["domain"].(string), ".domain")) + assert.Equal(t, "device.netbird.cloud", state["netbird_domain"]) // Netbird domain unchanged + + // CIDR ranges + assert.NotEqual(t, "203.0.113.0/24", state["public_cidr"]) + assert.Contains(t, state["public_cidr"], "/24") // Prefix preserved + assert.Equal(t, "192.168.0.0/16", state["private_cidr"]) // Private CIDR unchanged + assert.Equal(t, "100.64.0.0/10", state["protected_cidr"]) // Protected CIDR unchanged + assert.NotEqual(t, "2001:db8::/32", state["ipv6_cidr"]) + assert.Contains(t, state["ipv6_cidr"], "/32") // IPv6 prefix preserved + + // Nested structures + nested := state["nested"].(map[string]any) + assert.NotEqual(t, "203.0.113.2", nested["ip"]) + assert.NotEqual(t, "nested.example.com", nested["domain"]) + moreNest := nested["more_nest"].(map[string]any) + assert.NotEqual(t, "203.0.113.3", moreNest["ip"]) + assert.NotEqual(t, "deep.example.com", moreNest["domain"]) + + // Arrays + strArray := state["string_array"].([]any) + assert.NotEqual(t, "203.0.113.4", strArray[0]) + assert.NotEqual(t, "test1.example.com", strArray[1]) + assert.True(t, strings.HasSuffix(strArray[1].(string), ".domain")) + + objArray := state["object_array"].([]any) + firstObj := objArray[0].(map[string]any) + assert.NotEqual(t, "203.0.113.5", firstObj["ip"]) + assert.NotEqual(t, "array1.example.com", firstObj["domain"]) + + // Duplicate values should be anonymized consistently + assert.Equal(t, state["public_ip"], state["duplicate_ip"]) + assert.Equal(t, state["domain"], state["duplicate_domain"]) + + // URIs + assert.NotContains(t, state["stun_uri"], "stun.example.com") + assert.NotContains(t, state["turns_uri"], "turns.example.com") + assert.NotContains(t, state["http_uri"], "web.example.com") + assert.NotContains(t, state["https_uri"], "secure.example.com") + + // Non-IP strings should remain unchanged + assert.Equal(t, "300.300.300.300", state["not_ip"]) + assert.Equal(t, "192.168", state["partial_ip"]) + assert.Equal(t, "1234.5678", state["ip_like_string"]) + + // Mixed content should have IPs and domains replaced + mixedContent := state["mixed_content"].(string) + assert.NotContains(t, mixedContent, "203.0.113.1") + assert.NotContains(t, mixedContent, "test.example.com") + assert.Contains(t, mixedContent, "Server at ") + assert.Contains(t, mixedContent, " on port 80") + + // Special values should remain unchanged + assert.Equal(t, "", state["empty_string"]) + assert.Nil(t, state["null_value"]) + assert.Equal(t, float64(42), state["numeric_value"]) + assert.Equal(t, true, state["boolean_value"]) + + // Check route state + var routeState map[string]any + err = json.Unmarshal(testState["route_state"], &routeState) + require.NoError(t, err) + + routes := routeState["routes"].([]any) + route1 := routes[0].(map[string]any) + assert.NotEqual(t, "203.0.113.0/24", route1["network"]) + assert.Contains(t, route1["network"], "/24") + assert.NotEqual(t, "203.0.113.1", route1["gateway"]) + domains := route1["domains"].([]any) + assert.True(t, strings.HasSuffix(domains[0].(string), ".domain")) + assert.True(t, strings.HasSuffix(domains[1].(string), ".domain")) + + // Check map keys are anonymized + refCountMap 
:= routeState["refCountMap"].(map[string]any) + hasPublicIPKey := false + hasIPv6Key := false + hasPrivateIPKey := false + for key := range refCountMap { + if strings.Contains(key, "203.0.113.1") { + hasPublicIPKey = true + } + if strings.Contains(key, "2001:db8::1") { + hasIPv6Key = true + } + if key == "10.0.0.1/32" { + hasPrivateIPKey = true + } + } + assert.False(t, hasPublicIPKey, "public IP in key should be anonymized") + assert.False(t, hasIPv6Key, "IPv6 in key should be anonymized") + assert.True(t, hasPrivateIPKey, "private IP in key should remain unchanged") +} + +func mustMarshal(v any) json.RawMessage { + data, err := json.Marshal(v) + if err != nil { + panic(err) + } + return data +} From dffce78a8c8e78161e8fe1f8d7f01ba232c0d8e7 Mon Sep 17 00:00:00 2001 From: Viktor Liu <17948409+lixmal@users.noreply.github.com> Date: Mon, 2 Dec 2024 20:19:34 +0100 Subject: [PATCH 059/159] [client] Fix debug bundle state anonymization test (#2976) --- client/server/debug_test.go | 7 +++++++ 1 file changed, 7 insertions(+) diff --git a/client/server/debug_test.go b/client/server/debug_test.go index 303e5e661..1515036bd 100644 --- a/client/server/debug_test.go +++ b/client/server/debug_test.go @@ -137,6 +137,13 @@ func TestAnonymizeStateFile(t *testing.T) { } anonymizer := anonymize.NewAnonymizer(anonymize.DefaultAddresses()) + + // Pre-seed the domains we need to verify in the test assertions + anonymizer.AnonymizeDomain("test.example.com") + anonymizer.AnonymizeDomain("nested.example.com") + anonymizer.AnonymizeDomain("deep.example.com") + anonymizer.AnonymizeDomain("array1.example.com") + err := anonymizeStateFile(&testState, anonymizer) require.NoError(t, err) From a0bf0bdcc077e54294c6e31a24efcd9047f3f97f Mon Sep 17 00:00:00 2001 From: Zoltan Papp Date: Tue, 3 Dec 2024 10:13:27 +0100 Subject: [PATCH 060/159] Pass IP instead of net to Rosenpass (#2975) --- client/internal/peer/conn.go | 6 ++---- 1 file changed, 2 insertions(+), 4 deletions(-) diff --git a/client/internal/peer/conn.go b/client/internal/peer/conn.go index 81c456db7..a8de2fccb 100644 --- a/client/internal/peer/conn.go +++ b/client/internal/peer/conn.go @@ -83,7 +83,6 @@ type Conn struct { signaler *Signaler relayManager *relayClient.Manager allowedIP net.IP - allowedNet string handshaker *Handshaker onConnected func(remoteWireGuardKey string, remoteRosenpassPubKey []byte, wireGuardIP string, remoteRosenpassAddr string) @@ -111,7 +110,7 @@ type Conn struct { // NewConn creates a new not opened Conn to the remote peer. 
// To establish a connection run Conn.Open func NewConn(engineCtx context.Context, config ConnConfig, statusRecorder *Status, signaler *Signaler, iFaceDiscover stdnet.ExternalIFaceDiscover, relayManager *relayClient.Manager, srWatcher *guard.SRWatcher) (*Conn, error) { - allowedIP, allowedNet, err := net.ParseCIDR(config.WgConfig.AllowedIps) + allowedIP, _, err := net.ParseCIDR(config.WgConfig.AllowedIps) if err != nil { log.Errorf("failed to parse allowedIPS: %v", err) return nil, err @@ -129,7 +128,6 @@ func NewConn(engineCtx context.Context, config ConnConfig, statusRecorder *Statu signaler: signaler, relayManager: relayManager, allowedIP: allowedIP, - allowedNet: allowedNet.String(), statusRelay: NewAtomicConnStatus(), statusICE: NewAtomicConnStatus(), } @@ -594,7 +592,7 @@ func (conn *Conn) doOnConnected(remoteRosenpassPubKey []byte, remoteRosenpassAdd } if conn.onConnected != nil { - conn.onConnected(conn.config.Key, remoteRosenpassPubKey, conn.allowedNet, remoteRosenpassAddr) + conn.onConnected(conn.config.Key, remoteRosenpassPubKey, conn.allowedIP.String(), remoteRosenpassAddr) } } From a4826cfb5fb5e0510f485644583bb14da1aa2ae8 Mon Sep 17 00:00:00 2001 From: Maycon Santos Date: Tue, 3 Dec 2024 10:22:04 +0100 Subject: [PATCH 061/159] [client] Get static system info once (#2965) Get static system info once for Windows, Darwin, and Linux nodes This should improve startup and peer authentication times --- client/system/info.go | 8 ++++ client/system/info_darwin.go | 20 +++++----- client/system/info_linux.go | 61 +++++++++++++---------------- client/system/info_windows.go | 53 +++++++++++++------------ client/system/static_info.go | 46 ++++++++++++++++++++++ client/system/sysinfo_linux_test.go | 5 ++- 6 files changed, 123 insertions(+), 70 deletions(-) create mode 100644 client/system/static_info.go diff --git a/client/system/info.go b/client/system/info.go index 2af2e637b..200d835df 100644 --- a/client/system/info.go +++ b/client/system/info.go @@ -61,6 +61,14 @@ type Info struct { Files []File // for posture checks } +// StaticInfo is an object that contains machine information that does not change +type StaticInfo struct { + SystemSerialNumber string + SystemProductName string + SystemManufacturer string + Environment Environment +} + // extractUserAgent extracts Netbird's agent (client) name and version from the outgoing context func extractUserAgent(ctx context.Context) string { md, hasMeta := metadata.FromOutgoingContext(ctx) diff --git a/client/system/info_darwin.go b/client/system/info_darwin.go index 6f4ed173b..13b0a446b 100644 --- a/client/system/info_darwin.go +++ b/client/system/info_darwin.go @@ -10,13 +10,12 @@ import ( "os/exec" "runtime" "strings" + "time" "golang.org/x/sys/unix" log "github.com/sirupsen/logrus" - "github.com/netbirdio/netbird/client/system/detect_cloud" - "github.com/netbirdio/netbird/client/system/detect_platform" "github.com/netbirdio/netbird/version" ) @@ -41,11 +40,10 @@ func GetInfo(ctx context.Context) *Info { log.Warnf("failed to discover network addresses: %s", err) } - serialNum, prodName, manufacturer := sysInfo() - - env := Environment{ - Cloud: detect_cloud.Detect(ctx), - Platform: detect_platform.Detect(ctx), + start := time.Now() + si := updateStaticInfo() + if time.Since(start) > 1*time.Second { + log.Warnf("updateStaticInfo took %s", time.Since(start)) } gio := &Info{ @@ -57,10 +55,10 @@ func GetInfo(ctx context.Context) *Info { CPUs: runtime.NumCPU(), KernelVersion: release, NetworkAddresses: addrs, - SystemSerialNumber: serialNum, - 
SystemProductName: prodName, - SystemManufacturer: manufacturer, - Environment: env, + SystemSerialNumber: si.SystemSerialNumber, + SystemProductName: si.SystemProductName, + SystemManufacturer: si.SystemManufacturer, + Environment: si.Environment, } systemHostname, _ := os.Hostname() diff --git a/client/system/info_linux.go b/client/system/info_linux.go index b6a142bce..bfc77be19 100644 --- a/client/system/info_linux.go +++ b/client/system/info_linux.go @@ -1,5 +1,4 @@ //go:build !android -// +build !android package system @@ -16,30 +15,13 @@ import ( log "github.com/sirupsen/logrus" "github.com/zcalusic/sysinfo" - "github.com/netbirdio/netbird/client/system/detect_cloud" - "github.com/netbirdio/netbird/client/system/detect_platform" "github.com/netbirdio/netbird/version" ) -type SysInfoGetter interface { - GetSysInfo() SysInfo -} - -type SysInfoWrapper struct { - si sysinfo.SysInfo -} - -func (s SysInfoWrapper) GetSysInfo() SysInfo { - s.si.GetSysInfo() - return SysInfo{ - ChassisSerial: s.si.Chassis.Serial, - ProductSerial: s.si.Product.Serial, - BoardSerial: s.si.Board.Serial, - ProductName: s.si.Product.Name, - BoardName: s.si.Board.Name, - ProductVendor: s.si.Product.Vendor, - } -} +var ( + // it is override in tests + getSystemInfo = defaultSysInfoImplementation +) // GetInfo retrieves and parses the system information func GetInfo(ctx context.Context) *Info { @@ -65,12 +47,10 @@ func GetInfo(ctx context.Context) *Info { log.Warnf("failed to discover network addresses: %s", err) } - si := SysInfoWrapper{} - serialNum, prodName, manufacturer := sysInfo(si.GetSysInfo()) - - env := Environment{ - Cloud: detect_cloud.Detect(ctx), - Platform: detect_platform.Detect(ctx), + start := time.Now() + si := updateStaticInfo() + if time.Since(start) > 1*time.Second { + log.Warnf("updateStaticInfo took %s", time.Since(start)) } gio := &Info{ @@ -85,10 +65,10 @@ func GetInfo(ctx context.Context) *Info { UIVersion: extractUserAgent(ctx), KernelVersion: osInfo[1], NetworkAddresses: addrs, - SystemSerialNumber: serialNum, - SystemProductName: prodName, - SystemManufacturer: manufacturer, - Environment: env, + SystemSerialNumber: si.SystemSerialNumber, + SystemProductName: si.SystemProductName, + SystemManufacturer: si.SystemManufacturer, + Environment: si.Environment, } return gio @@ -108,9 +88,9 @@ func _getInfo() string { return out.String() } -func sysInfo(si SysInfo) (string, string, string) { +func sysInfo() (string, string, string) { isascii := regexp.MustCompile("^[[:ascii:]]+$") - + si := getSystemInfo() serials := []string{si.ChassisSerial, si.ProductSerial} serial := "" @@ -141,3 +121,16 @@ func sysInfo(si SysInfo) (string, string, string) { } return serial, name, manufacturer } + +func defaultSysInfoImplementation() SysInfo { + si := sysinfo.SysInfo{} + si.GetSysInfo() + return SysInfo{ + ChassisSerial: si.Chassis.Serial, + ProductSerial: si.Product.Serial, + BoardSerial: si.Board.Serial, + ProductName: si.Product.Name, + BoardName: si.Board.Name, + ProductVendor: si.Product.Vendor, + } +} diff --git a/client/system/info_windows.go b/client/system/info_windows.go index 68631fe16..28bd3d300 100644 --- a/client/system/info_windows.go +++ b/client/system/info_windows.go @@ -6,13 +6,12 @@ import ( "os" "runtime" "strings" + "time" log "github.com/sirupsen/logrus" "github.com/yusufpapurcu/wmi" "golang.org/x/sys/windows/registry" - "github.com/netbirdio/netbird/client/system/detect_cloud" - "github.com/netbirdio/netbird/client/system/detect_platform" "github.com/netbirdio/netbird/version" ) @@ 
-42,24 +41,10 @@ func GetInfo(ctx context.Context) *Info { log.Warnf("failed to discover network addresses: %s", err) } - serialNum, err := sysNumber() - if err != nil { - log.Warnf("failed to get system serial number: %s", err) - } - - prodName, err := sysProductName() - if err != nil { - log.Warnf("failed to get system product name: %s", err) - } - - manufacturer, err := sysManufacturer() - if err != nil { - log.Warnf("failed to get system manufacturer: %s", err) - } - - env := Environment{ - Cloud: detect_cloud.Detect(ctx), - Platform: detect_platform.Detect(ctx), + start := time.Now() + si := updateStaticInfo() + if time.Since(start) > 1*time.Second { + log.Warnf("updateStaticInfo took %s", time.Since(start)) } gio := &Info{ @@ -71,10 +56,10 @@ func GetInfo(ctx context.Context) *Info { CPUs: runtime.NumCPU(), KernelVersion: buildVersion, NetworkAddresses: addrs, - SystemSerialNumber: serialNum, - SystemProductName: prodName, - SystemManufacturer: manufacturer, - Environment: env, + SystemSerialNumber: si.SystemSerialNumber, + SystemProductName: si.SystemProductName, + SystemManufacturer: si.SystemManufacturer, + Environment: si.Environment, } systemHostname, _ := os.Hostname() @@ -85,6 +70,26 @@ func GetInfo(ctx context.Context) *Info { return gio } +func sysInfo() (serialNumber string, productName string, manufacturer string) { + var err error + serialNumber, err = sysNumber() + if err != nil { + log.Warnf("failed to get system serial number: %s", err) + } + + productName, err = sysProductName() + if err != nil { + log.Warnf("failed to get system product name: %s", err) + } + + manufacturer, err = sysManufacturer() + if err != nil { + log.Warnf("failed to get system manufacturer: %s", err) + } + + return serialNumber, productName, manufacturer +} + func getOSNameAndVersion() (string, string) { var dst []Win32_OperatingSystem query := wmi.CreateQuery(&dst, "") diff --git a/client/system/static_info.go b/client/system/static_info.go new file mode 100644 index 000000000..fabe65a68 --- /dev/null +++ b/client/system/static_info.go @@ -0,0 +1,46 @@ +//go:build (linux && !android) || windows || (darwin && !ios) + +package system + +import ( + "context" + "sync" + "time" + + "github.com/netbirdio/netbird/client/system/detect_cloud" + "github.com/netbirdio/netbird/client/system/detect_platform" +) + +var ( + staticInfo StaticInfo + once sync.Once +) + +func init() { + go func() { + _ = updateStaticInfo() + }() +} + +func updateStaticInfo() StaticInfo { + once.Do(func() { + ctx, cancel := context.WithTimeout(context.Background(), 5*time.Second) + defer cancel() + wg := sync.WaitGroup{} + wg.Add(3) + go func() { + staticInfo.SystemSerialNumber, staticInfo.SystemProductName, staticInfo.SystemManufacturer = sysInfo() + wg.Done() + }() + go func() { + staticInfo.Environment.Cloud = detect_cloud.Detect(ctx) + wg.Done() + }() + go func() { + staticInfo.Environment.Platform = detect_platform.Detect(ctx) + wg.Done() + }() + wg.Wait() + }) + return staticInfo +} diff --git a/client/system/sysinfo_linux_test.go b/client/system/sysinfo_linux_test.go index f6a0b7058..ae89bfcf9 100644 --- a/client/system/sysinfo_linux_test.go +++ b/client/system/sysinfo_linux_test.go @@ -183,7 +183,10 @@ func Test_sysInfo(t *testing.T) { for _, tt := range tests { t.Run(tt.name, func(t *testing.T) { - gotSerialNum, gotProdName, gotManufacturer := sysInfo(tt.sysInfo) + getSystemInfo = func() SysInfo { + return tt.sysInfo + } + gotSerialNum, gotProdName, gotManufacturer := sysInfo() if gotSerialNum != tt.wantSerialNum { 
t.Errorf("sysInfo() gotSerialNum = %v, want %v", gotSerialNum, tt.wantSerialNum) } From 6285e0d23ef95d98abc815315c67a24371f7a2a4 Mon Sep 17 00:00:00 2001 From: Viktor Liu <17948409+lixmal@users.noreply.github.com> Date: Tue, 3 Dec 2024 12:43:17 +0100 Subject: [PATCH 062/159] [client] Add netbird.err and netbird.out to debug bundle (#2971) --- client/server/debug.go | 45 +++++++++++++++++++++++++++++++++++------- 1 file changed, 38 insertions(+), 7 deletions(-) diff --git a/client/server/debug.go b/client/server/debug.go index 1bad907ba..d87fc9b6a 100644 --- a/client/server/debug.go +++ b/client/server/debug.go @@ -15,6 +15,7 @@ import ( "net" "net/netip" "os" + "path/filepath" "sort" "strings" "time" @@ -32,7 +33,9 @@ const readmeContent = `Netbird debug bundle This debug bundle contains the following files: status.txt: Anonymized status information of the NetBird client. -client.log: Most recent, anonymized log file of the NetBird client. +client.log: Most recent, anonymized client log file of the NetBird client. +netbird.err: Most recent, anonymized stderr log file of the NetBird client. +netbird.out: Most recent, anonymized stdout log file of the NetBird client. routes.txt: Anonymized system routes, if --system-info flag was provided. interfaces.txt: Anonymized network interface information, if --system-info flag was provided. config.txt: Anonymized configuration information of the NetBird client. @@ -92,6 +95,12 @@ The config.txt file contains anonymized configuration information of the NetBird Other non-sensitive configuration options are included without anonymization. ` +const ( + clientLogFile = "client.log" + errorLogFile = "netbird.err" + stdoutLogFile = "netbird.out" +) + // DebugBundle creates a debug bundle and returns the location. func (s *Server) DebugBundle(_ context.Context, req *proto.DebugBundleRequest) (resp *proto.DebugBundleResponse, err error) { s.mutex.Lock() @@ -310,14 +319,35 @@ func (s *Server) addStateFile(req *proto.DebugBundleRequest, anonymizer *anonymi return nil } -func (s *Server) addLogfile(req *proto.DebugBundleRequest, anonymizer *anonymize.Anonymizer, archive *zip.Writer) (err error) { - logFile, err := os.Open(s.logFile) +func (s *Server) addLogfile(req *proto.DebugBundleRequest, anonymizer *anonymize.Anonymizer, archive *zip.Writer) error { + logDir := filepath.Dir(s.logFile) + + if err := s.addSingleLogfile(s.logFile, clientLogFile, req, anonymizer, archive); err != nil { + return fmt.Errorf("add client log file to zip: %w", err) + } + + errLogPath := filepath.Join(logDir, errorLogFile) + if err := s.addSingleLogfile(errLogPath, errorLogFile, req, anonymizer, archive); err != nil { + log.Warnf("Failed to add %s to zip: %v", errorLogFile, err) + } + + stdoutLogPath := filepath.Join(logDir, stdoutLogFile) + if err := s.addSingleLogfile(stdoutLogPath, stdoutLogFile, req, anonymizer, archive); err != nil { + log.Warnf("Failed to add %s to zip: %v", stdoutLogFile, err) + } + + return nil +} + +// addSingleLogfile adds a single log file to the archive +func (s *Server) addSingleLogfile(logPath, targetName string, req *proto.DebugBundleRequest, anonymizer *anonymize.Anonymizer, archive *zip.Writer) error { + logFile, err := os.Open(logPath) if err != nil { - return fmt.Errorf("open log file: %w", err) + return fmt.Errorf("open log file %s: %w", targetName, err) } defer func() { if err := logFile.Close(); err != nil { - log.Errorf("Failed to close original log file: %v", err) + log.Errorf("Failed to close log file %s: %v", targetName, err) } }() @@ -330,8 
+360,9 @@ func (s *Server) addLogfile(req *proto.DebugBundleRequest, anonymizer *anonymize } else { logReader = logFile } - if err := addFileToZip(archive, logReader, "client.log"); err != nil { - return fmt.Errorf("add log file to zip: %w", err) + + if err := addFileToZip(archive, logReader, targetName); err != nil { + return fmt.Errorf("add %s to zip: %w", targetName, err) } return nil From 7dacd9cb235cc7c7501cd444bc0f76633900177b Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Joakim=20Nohlg=C3=A5rd?= Date: Tue, 3 Dec 2024 13:49:02 +0100 Subject: [PATCH 063/159] [management] Add missing parentheses on iphone hostname generation condition (#2977) --- management/server/peer.go | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/management/server/peer.go b/management/server/peer.go index d45bb1a50..a301544d5 100644 --- a/management/server/peer.go +++ b/management/server/peer.go @@ -477,7 +477,7 @@ func (am *DefaultAccountManager) AddPeer(ctx context.Context, setupKey, userID s setupKeyName = sk.Name } - if strings.ToLower(peer.Meta.Hostname) == "iphone" || strings.ToLower(peer.Meta.Hostname) == "ipad" && userID != "" { + if (strings.ToLower(peer.Meta.Hostname) == "iphone" || strings.ToLower(peer.Meta.Hostname) == "ipad") && userID != "" { if am.idpManager != nil { userdata, err := am.idpManager.GetUserDataByID(ctx, userID, idp.AppMetadata{WTAccountID: accountID}) if err == nil && userdata != nil { From 17c20b45ce9a029f7180bbef49eb5cd3ada000b4 Mon Sep 17 00:00:00 2001 From: Viktor Liu <17948409+lixmal@users.noreply.github.com> Date: Tue, 3 Dec 2024 14:50:12 +0100 Subject: [PATCH 064/159] [client] Add network map to debug bundle (#2966) --- client/anonymize/anonymize.go | 38 ++-- client/anonymize/anonymize_test.go | 18 ++ client/cmd/debug.go | 52 +++++ client/cmd/root.go | 1 + client/internal/connect.go | 20 +- client/internal/engine.go | 66 +++++- client/proto/daemon.pb.go | 352 ++++++++++++++++++++--------- client/proto/daemon.proto | 14 +- client/proto/daemon_grpc.pb.go | 38 ++++ client/server/debug.go | 289 ++++++++++++++++++++++- client/server/debug_test.go | 165 ++++++++++++++ client/server/server.go | 3 + go.mod | 4 +- go.sum | 8 +- 14 files changed, 915 insertions(+), 153 deletions(-) diff --git a/client/anonymize/anonymize.go b/client/anonymize/anonymize.go index ad31682f2..9a6d97207 100644 --- a/client/anonymize/anonymize.go +++ b/client/anonymize/anonymize.go @@ -12,6 +12,8 @@ import ( "strings" ) +const anonTLD = ".domain" + type Anonymizer struct { ipAnonymizer map[netip.Addr]netip.Addr domainAnonymizer map[string]string @@ -83,29 +85,39 @@ func (a *Anonymizer) AnonymizeIPString(ip string) string { } func (a *Anonymizer) AnonymizeDomain(domain string) string { - if strings.HasSuffix(domain, "netbird.io") || - strings.HasSuffix(domain, "netbird.selfhosted") || - strings.HasSuffix(domain, "netbird.cloud") || - strings.HasSuffix(domain, "netbird.stage") || - strings.HasSuffix(domain, ".domain") { + baseDomain := domain + hasDot := strings.HasSuffix(domain, ".") + if hasDot { + baseDomain = domain[:len(domain)-1] + } + + if strings.HasSuffix(baseDomain, "netbird.io") || + strings.HasSuffix(baseDomain, "netbird.selfhosted") || + strings.HasSuffix(baseDomain, "netbird.cloud") || + strings.HasSuffix(baseDomain, "netbird.stage") || + strings.HasSuffix(baseDomain, anonTLD) { return domain } - parts := strings.Split(domain, ".") + parts := strings.Split(baseDomain, ".") if len(parts) < 2 { return domain } - baseDomain := parts[len(parts)-2] + "." 
+ parts[len(parts)-1] + baseForLookup := parts[len(parts)-2] + "." + parts[len(parts)-1] - anonymized, ok := a.domainAnonymizer[baseDomain] + anonymized, ok := a.domainAnonymizer[baseForLookup] if !ok { - anonymizedBase := "anon-" + generateRandomString(5) + ".domain" - a.domainAnonymizer[baseDomain] = anonymizedBase + anonymizedBase := "anon-" + generateRandomString(5) + anonTLD + a.domainAnonymizer[baseForLookup] = anonymizedBase anonymized = anonymizedBase } - return strings.Replace(domain, baseDomain, anonymized, 1) + result := strings.Replace(baseDomain, baseForLookup, anonymized, 1) + if hasDot { + result += "." + } + return result } func (a *Anonymizer) AnonymizeURI(uri string) string { @@ -168,10 +180,10 @@ func (a *Anonymizer) AnonymizeDNSLogLine(logEntry string) string { parts := strings.Split(match, `"`) if len(parts) >= 2 { domain := parts[1] - if strings.HasSuffix(domain, ".domain") { + if strings.HasSuffix(domain, anonTLD) { return match } - randomDomain := generateRandomString(10) + ".domain" + randomDomain := generateRandomString(10) + anonTLD return strings.Replace(match, domain, randomDomain, 1) } return match diff --git a/client/anonymize/anonymize_test.go b/client/anonymize/anonymize_test.go index 605788ab5..a3aae1ee9 100644 --- a/client/anonymize/anonymize_test.go +++ b/client/anonymize/anonymize_test.go @@ -67,18 +67,36 @@ func TestAnonymizeDomain(t *testing.T) { `^anon-[a-zA-Z0-9]+\.domain$`, true, }, + { + "Domain with Trailing Dot", + "example.com.", + `^anon-[a-zA-Z0-9]+\.domain.$`, + true, + }, { "Subdomain", "sub.example.com", `^sub\.anon-[a-zA-Z0-9]+\.domain$`, true, }, + { + "Subdomain with Trailing Dot", + "sub.example.com.", + `^sub\.anon-[a-zA-Z0-9]+\.domain.$`, + true, + }, { "Protected Domain", "netbird.io", `^netbird\.io$`, false, }, + { + "Protected Domain with Trailing Dot", + "netbird.io.", + `^netbird\.io.$`, + false, + }, } for _, tc := range tests { diff --git a/client/cmd/debug.go b/client/cmd/debug.go index 9abd2039d..c7ab87b47 100644 --- a/client/cmd/debug.go +++ b/client/cmd/debug.go @@ -3,6 +3,7 @@ package cmd import ( "context" "fmt" + "strings" "time" log "github.com/sirupsen/logrus" @@ -61,6 +62,15 @@ var forCmd = &cobra.Command{ RunE: runForDuration, } +var persistenceCmd = &cobra.Command{ + Use: "persistence [on|off]", + Short: "Set network map memory persistence", + Long: `Configure whether the latest network map should persist in memory. 
When enabled, the last known network map will be kept in memory.`, + Example: " netbird debug persistence on", + Args: cobra.ExactArgs(1), + RunE: setNetworkMapPersistence, +} + func debugBundle(cmd *cobra.Command, _ []string) error { conn, err := getClient(cmd) if err != nil { @@ -171,6 +181,13 @@ func runForDuration(cmd *cobra.Command, args []string) error { time.Sleep(1 * time.Second) + // Enable network map persistence before bringing the service up + if _, err := client.SetNetworkMapPersistence(cmd.Context(), &proto.SetNetworkMapPersistenceRequest{ + Enabled: true, + }); err != nil { + return fmt.Errorf("failed to enable network map persistence: %v", status.Convert(err).Message()) + } + if _, err := client.Up(cmd.Context(), &proto.UpRequest{}); err != nil { return fmt.Errorf("failed to up: %v", status.Convert(err).Message()) } @@ -200,6 +217,13 @@ func runForDuration(cmd *cobra.Command, args []string) error { return fmt.Errorf("failed to bundle debug: %v", status.Convert(err).Message()) } + // Disable network map persistence after creating the debug bundle + if _, err := client.SetNetworkMapPersistence(cmd.Context(), &proto.SetNetworkMapPersistenceRequest{ + Enabled: false, + }); err != nil { + return fmt.Errorf("failed to disable network map persistence: %v", status.Convert(err).Message()) + } + if stateWasDown { if _, err := client.Down(cmd.Context(), &proto.DownRequest{}); err != nil { return fmt.Errorf("failed to down: %v", status.Convert(err).Message()) @@ -219,6 +243,34 @@ func runForDuration(cmd *cobra.Command, args []string) error { return nil } +func setNetworkMapPersistence(cmd *cobra.Command, args []string) error { + conn, err := getClient(cmd) + if err != nil { + return err + } + defer func() { + if err := conn.Close(); err != nil { + log.Errorf(errCloseConnection, err) + } + }() + + persistence := strings.ToLower(args[0]) + if persistence != "on" && persistence != "off" { + return fmt.Errorf("invalid persistence value: %s. 
Use 'on' or 'off'", args[0]) + } + + client := proto.NewDaemonServiceClient(conn) + _, err = client.SetNetworkMapPersistence(cmd.Context(), &proto.SetNetworkMapPersistenceRequest{ + Enabled: persistence == "on", + }) + if err != nil { + return fmt.Errorf("failed to set network map persistence: %v", status.Convert(err).Message()) + } + + cmd.Printf("Network map persistence set to: %s\n", persistence) + return nil +} + func getStatusOutput(cmd *cobra.Command) string { var statusOutputString string statusResp, err := getStatus(cmd.Context()) diff --git a/client/cmd/root.go b/client/cmd/root.go index 8dae6e273..3f2d04ef3 100644 --- a/client/cmd/root.go +++ b/client/cmd/root.go @@ -155,6 +155,7 @@ func init() { debugCmd.AddCommand(logCmd) logCmd.AddCommand(logLevelCmd) debugCmd.AddCommand(forCmd) + debugCmd.AddCommand(persistenceCmd) upCmd.PersistentFlags().StringSliceVar(&natExternalIPs, externalIPMapFlag, nil, `Sets external IPs maps between local addresses and interfaces.`+ diff --git a/client/internal/connect.go b/client/internal/connect.go index 8c2ad4aa1..21f6d3ec8 100644 --- a/client/internal/connect.go +++ b/client/internal/connect.go @@ -40,6 +40,8 @@ type ConnectClient struct { statusRecorder *peer.Status engine *Engine engineMutex sync.Mutex + + persistNetworkMap bool } func NewConnectClient( @@ -258,7 +260,7 @@ func (c *ConnectClient) run(mobileDependency MobileDependency, probes *ProbeHold c.engineMutex.Lock() c.engine = NewEngineWithProbes(engineCtx, cancel, signalClient, mgmClient, relayManager, engineConfig, mobileDependency, c.statusRecorder, probes, checks) - + c.engine.SetNetworkMapPersistence(c.persistNetworkMap) c.engineMutex.Unlock() if err := c.engine.Start(); err != nil { @@ -362,6 +364,22 @@ func (c *ConnectClient) isContextCancelled() bool { } } +// SetNetworkMapPersistence enables or disables network map persistence. +// When enabled, the last received network map will be stored and can be retrieved +// through the Engine's getLatestNetworkMap method. When disabled, any stored +// network map will be cleared. This functionality is primarily used for debugging +// and should not be enabled during normal operation. 
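A rough usage sketch (not part of this patch) of the persistence hooks introduced in this change, assuming an already constructed *ConnectClient in the same package; the helper name is hypothetical and error handling is trimmed. The debug CLI above drives the same flow over gRPC.

	// captureNetworkMap is a hypothetical helper: it keeps the last received
	// network map in memory while an issue is reproduced, then reads it back
	// from the engine once it is available.
	func captureNetworkMap(client *ConnectClient) {
		client.SetNetworkMapPersistence(true) // store the last received network map
		defer client.SetNetworkMapPersistence(false)

		// ... reproduce the issue under investigation ...

		if engine := client.Engine(); engine != nil {
			if nm, err := engine.GetLatestNetworkMap(); err == nil && nm != nil {
				log.Debugf("captured network map with serial %d", nm.GetSerial())
			}
		}
	}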
+func (c *ConnectClient) SetNetworkMapPersistence(enabled bool) { + c.engineMutex.Lock() + c.persistNetworkMap = enabled + c.engineMutex.Unlock() + + engine := c.Engine() + if engine != nil { + engine.SetNetworkMapPersistence(enabled) + } +} + // createEngineConfig converts configuration received from Management Service to EngineConfig func createEngineConfig(key wgtypes.Key, config *Config, peerConfig *mgmProto.PeerConfig) (*EngineConfig, error) { nm := false diff --git a/client/internal/engine.go b/client/internal/engine.go index 920c295cd..fc9620d80 100644 --- a/client/internal/engine.go +++ b/client/internal/engine.go @@ -21,6 +21,7 @@ import ( "github.com/pion/stun/v2" log "github.com/sirupsen/logrus" "golang.zx2c4.com/wireguard/wgctrl/wgtypes" + "google.golang.org/protobuf/proto" "github.com/netbirdio/netbird/client/firewall" "github.com/netbirdio/netbird/client/firewall/manager" @@ -172,6 +173,10 @@ type Engine struct { relayManager *relayClient.Manager stateManager *statemanager.Manager srWatcher *guard.SRWatcher + + // Network map persistence + persistNetworkMap bool + latestNetworkMap *mgmProto.NetworkMap } // Peer is an instance of the Connection Peer @@ -573,13 +578,22 @@ func (e *Engine) handleSync(update *mgmProto.SyncResponse) error { return err } - if update.GetNetworkMap() != nil { - // only apply new changes and ignore old ones - err := e.updateNetworkMap(update.GetNetworkMap()) - if err != nil { - return err - } + nm := update.GetNetworkMap() + if nm == nil { + return nil } + + // Store network map if persistence is enabled + if e.persistNetworkMap { + e.latestNetworkMap = nm + log.Debugf("network map persisted with serial %d", nm.GetSerial()) + } + + // only apply new changes and ignore old ones + if err := e.updateNetworkMap(nm); err != nil { + return err + } + return nil } @@ -1500,6 +1514,46 @@ func (e *Engine) stopDNSServer() { e.statusRecorder.UpdateDNSStates(nsGroupStates) } +// SetNetworkMapPersistence enables or disables network map persistence +func (e *Engine) SetNetworkMapPersistence(enabled bool) { + e.syncMsgMux.Lock() + defer e.syncMsgMux.Unlock() + + if enabled == e.persistNetworkMap { + return + } + e.persistNetworkMap = enabled + log.Debugf("Network map persistence is set to %t", enabled) + + if !enabled { + e.latestNetworkMap = nil + } +} + +// GetLatestNetworkMap returns the stored network map if persistence is enabled +func (e *Engine) GetLatestNetworkMap() (*mgmProto.NetworkMap, error) { + e.syncMsgMux.Lock() + defer e.syncMsgMux.Unlock() + + if !e.persistNetworkMap { + return nil, errors.New("network map persistence is disabled") + } + + if e.latestNetworkMap == nil { + //nolint:nilnil + return nil, nil + } + + // Create a deep copy to avoid external modifications + nm, ok := proto.Clone(e.latestNetworkMap).(*mgmProto.NetworkMap) + if !ok { + + return nil, fmt.Errorf("failed to clone network map") + } + + return nm, nil +} + // isChecksEqual checks if two slices of checks are equal. func isChecksEqual(checks []*mgmProto.Checks, oChecks []*mgmProto.Checks) bool { for _, check := range checks { diff --git a/client/proto/daemon.pb.go b/client/proto/daemon.pb.go index b942d8b6e..0eae4e0d6 100644 --- a/client/proto/daemon.pb.go +++ b/client/proto/daemon.pb.go @@ -1,7 +1,7 @@ // Code generated by protoc-gen-go. DO NOT EDIT. 
// versions: // protoc-gen-go v1.26.0 -// protoc v3.21.12 +// protoc v4.23.4 // source: daemon.proto package proto @@ -2103,6 +2103,91 @@ func (*SetLogLevelResponse) Descriptor() ([]byte, []int) { return file_daemon_proto_rawDescGZIP(), []int{30} } +type SetNetworkMapPersistenceRequest struct { + state protoimpl.MessageState + sizeCache protoimpl.SizeCache + unknownFields protoimpl.UnknownFields + + Enabled bool `protobuf:"varint,1,opt,name=enabled,proto3" json:"enabled,omitempty"` +} + +func (x *SetNetworkMapPersistenceRequest) Reset() { + *x = SetNetworkMapPersistenceRequest{} + if protoimpl.UnsafeEnabled { + mi := &file_daemon_proto_msgTypes[31] + ms := protoimpl.X.MessageStateOf(protoimpl.Pointer(x)) + ms.StoreMessageInfo(mi) + } +} + +func (x *SetNetworkMapPersistenceRequest) String() string { + return protoimpl.X.MessageStringOf(x) +} + +func (*SetNetworkMapPersistenceRequest) ProtoMessage() {} + +func (x *SetNetworkMapPersistenceRequest) ProtoReflect() protoreflect.Message { + mi := &file_daemon_proto_msgTypes[31] + if protoimpl.UnsafeEnabled && x != nil { + ms := protoimpl.X.MessageStateOf(protoimpl.Pointer(x)) + if ms.LoadMessageInfo() == nil { + ms.StoreMessageInfo(mi) + } + return ms + } + return mi.MessageOf(x) +} + +// Deprecated: Use SetNetworkMapPersistenceRequest.ProtoReflect.Descriptor instead. +func (*SetNetworkMapPersistenceRequest) Descriptor() ([]byte, []int) { + return file_daemon_proto_rawDescGZIP(), []int{31} +} + +func (x *SetNetworkMapPersistenceRequest) GetEnabled() bool { + if x != nil { + return x.Enabled + } + return false +} + +type SetNetworkMapPersistenceResponse struct { + state protoimpl.MessageState + sizeCache protoimpl.SizeCache + unknownFields protoimpl.UnknownFields +} + +func (x *SetNetworkMapPersistenceResponse) Reset() { + *x = SetNetworkMapPersistenceResponse{} + if protoimpl.UnsafeEnabled { + mi := &file_daemon_proto_msgTypes[32] + ms := protoimpl.X.MessageStateOf(protoimpl.Pointer(x)) + ms.StoreMessageInfo(mi) + } +} + +func (x *SetNetworkMapPersistenceResponse) String() string { + return protoimpl.X.MessageStringOf(x) +} + +func (*SetNetworkMapPersistenceResponse) ProtoMessage() {} + +func (x *SetNetworkMapPersistenceResponse) ProtoReflect() protoreflect.Message { + mi := &file_daemon_proto_msgTypes[32] + if protoimpl.UnsafeEnabled && x != nil { + ms := protoimpl.X.MessageStateOf(protoimpl.Pointer(x)) + if ms.LoadMessageInfo() == nil { + ms.StoreMessageInfo(mi) + } + return ms + } + return mi.MessageOf(x) +} + +// Deprecated: Use SetNetworkMapPersistenceResponse.ProtoReflect.Descriptor instead. 
+func (*SetNetworkMapPersistenceResponse) Descriptor() ([]byte, []int) { + return file_daemon_proto_rawDescGZIP(), []int{32} +} + var File_daemon_proto protoreflect.FileDescriptor var file_daemon_proto_rawDesc = []byte{ @@ -2399,66 +2484,79 @@ var file_daemon_proto_rawDesc = []byte{ 0x32, 0x10, 0x2e, 0x64, 0x61, 0x65, 0x6d, 0x6f, 0x6e, 0x2e, 0x4c, 0x6f, 0x67, 0x4c, 0x65, 0x76, 0x65, 0x6c, 0x52, 0x05, 0x6c, 0x65, 0x76, 0x65, 0x6c, 0x22, 0x15, 0x0a, 0x13, 0x53, 0x65, 0x74, 0x4c, 0x6f, 0x67, 0x4c, 0x65, 0x76, 0x65, 0x6c, 0x52, 0x65, 0x73, 0x70, 0x6f, 0x6e, 0x73, 0x65, - 0x2a, 0x62, 0x0a, 0x08, 0x4c, 0x6f, 0x67, 0x4c, 0x65, 0x76, 0x65, 0x6c, 0x12, 0x0b, 0x0a, 0x07, - 0x55, 0x4e, 0x4b, 0x4e, 0x4f, 0x57, 0x4e, 0x10, 0x00, 0x12, 0x09, 0x0a, 0x05, 0x50, 0x41, 0x4e, - 0x49, 0x43, 0x10, 0x01, 0x12, 0x09, 0x0a, 0x05, 0x46, 0x41, 0x54, 0x41, 0x4c, 0x10, 0x02, 0x12, - 0x09, 0x0a, 0x05, 0x45, 0x52, 0x52, 0x4f, 0x52, 0x10, 0x03, 0x12, 0x08, 0x0a, 0x04, 0x57, 0x41, - 0x52, 0x4e, 0x10, 0x04, 0x12, 0x08, 0x0a, 0x04, 0x49, 0x4e, 0x46, 0x4f, 0x10, 0x05, 0x12, 0x09, - 0x0a, 0x05, 0x44, 0x45, 0x42, 0x55, 0x47, 0x10, 0x06, 0x12, 0x09, 0x0a, 0x05, 0x54, 0x52, 0x41, - 0x43, 0x45, 0x10, 0x07, 0x32, 0xb8, 0x06, 0x0a, 0x0d, 0x44, 0x61, 0x65, 0x6d, 0x6f, 0x6e, 0x53, - 0x65, 0x72, 0x76, 0x69, 0x63, 0x65, 0x12, 0x36, 0x0a, 0x05, 0x4c, 0x6f, 0x67, 0x69, 0x6e, 0x12, - 0x14, 0x2e, 0x64, 0x61, 0x65, 0x6d, 0x6f, 0x6e, 0x2e, 0x4c, 0x6f, 0x67, 0x69, 0x6e, 0x52, 0x65, - 0x71, 0x75, 0x65, 0x73, 0x74, 0x1a, 0x15, 0x2e, 0x64, 0x61, 0x65, 0x6d, 0x6f, 0x6e, 0x2e, 0x4c, - 0x6f, 0x67, 0x69, 0x6e, 0x52, 0x65, 0x73, 0x70, 0x6f, 0x6e, 0x73, 0x65, 0x22, 0x00, 0x12, 0x4b, - 0x0a, 0x0c, 0x57, 0x61, 0x69, 0x74, 0x53, 0x53, 0x4f, 0x4c, 0x6f, 0x67, 0x69, 0x6e, 0x12, 0x1b, - 0x2e, 0x64, 0x61, 0x65, 0x6d, 0x6f, 0x6e, 0x2e, 0x57, 0x61, 0x69, 0x74, 0x53, 0x53, 0x4f, 0x4c, - 0x6f, 0x67, 0x69, 0x6e, 0x52, 0x65, 0x71, 0x75, 0x65, 0x73, 0x74, 0x1a, 0x1c, 0x2e, 0x64, 0x61, - 0x65, 0x6d, 0x6f, 0x6e, 0x2e, 0x57, 0x61, 0x69, 0x74, 0x53, 0x53, 0x4f, 0x4c, 0x6f, 0x67, 0x69, - 0x6e, 0x52, 0x65, 0x73, 0x70, 0x6f, 0x6e, 0x73, 0x65, 0x22, 0x00, 0x12, 0x2d, 0x0a, 0x02, 0x55, - 0x70, 0x12, 0x11, 0x2e, 0x64, 0x61, 0x65, 0x6d, 0x6f, 0x6e, 0x2e, 0x55, 0x70, 0x52, 0x65, 0x71, - 0x75, 0x65, 0x73, 0x74, 0x1a, 0x12, 0x2e, 0x64, 0x61, 0x65, 0x6d, 0x6f, 0x6e, 0x2e, 0x55, 0x70, - 0x52, 0x65, 0x73, 0x70, 0x6f, 0x6e, 0x73, 0x65, 0x22, 0x00, 0x12, 0x39, 0x0a, 0x06, 0x53, 0x74, - 0x61, 0x74, 0x75, 0x73, 0x12, 0x15, 0x2e, 0x64, 0x61, 0x65, 0x6d, 0x6f, 0x6e, 0x2e, 0x53, 0x74, - 0x61, 0x74, 0x75, 0x73, 0x52, 0x65, 0x71, 0x75, 0x65, 0x73, 0x74, 0x1a, 0x16, 0x2e, 0x64, 0x61, - 0x65, 0x6d, 0x6f, 0x6e, 0x2e, 0x53, 0x74, 0x61, 0x74, 0x75, 0x73, 0x52, 0x65, 0x73, 0x70, 0x6f, - 0x6e, 0x73, 0x65, 0x22, 0x00, 0x12, 0x33, 0x0a, 0x04, 0x44, 0x6f, 0x77, 0x6e, 0x12, 0x13, 0x2e, - 0x64, 0x61, 0x65, 0x6d, 0x6f, 0x6e, 0x2e, 0x44, 0x6f, 0x77, 0x6e, 0x52, 0x65, 0x71, 0x75, 0x65, - 0x73, 0x74, 0x1a, 0x14, 0x2e, 0x64, 0x61, 0x65, 0x6d, 0x6f, 0x6e, 0x2e, 0x44, 0x6f, 0x77, 0x6e, - 0x52, 0x65, 0x73, 0x70, 0x6f, 0x6e, 0x73, 0x65, 0x22, 0x00, 0x12, 0x42, 0x0a, 0x09, 0x47, 0x65, - 0x74, 0x43, 0x6f, 0x6e, 0x66, 0x69, 0x67, 0x12, 0x18, 0x2e, 0x64, 0x61, 0x65, 0x6d, 0x6f, 0x6e, - 0x2e, 0x47, 0x65, 0x74, 0x43, 0x6f, 0x6e, 0x66, 0x69, 0x67, 0x52, 0x65, 0x71, 0x75, 0x65, 0x73, - 0x74, 0x1a, 0x19, 0x2e, 0x64, 0x61, 0x65, 0x6d, 0x6f, 0x6e, 0x2e, 0x47, 0x65, 0x74, 0x43, 0x6f, - 0x6e, 0x66, 0x69, 0x67, 0x52, 0x65, 0x73, 0x70, 0x6f, 0x6e, 0x73, 0x65, 0x22, 0x00, 0x12, 0x45, - 0x0a, 0x0a, 0x4c, 0x69, 0x73, 0x74, 
0x52, 0x6f, 0x75, 0x74, 0x65, 0x73, 0x12, 0x19, 0x2e, 0x64, - 0x61, 0x65, 0x6d, 0x6f, 0x6e, 0x2e, 0x4c, 0x69, 0x73, 0x74, 0x52, 0x6f, 0x75, 0x74, 0x65, 0x73, - 0x52, 0x65, 0x71, 0x75, 0x65, 0x73, 0x74, 0x1a, 0x1a, 0x2e, 0x64, 0x61, 0x65, 0x6d, 0x6f, 0x6e, - 0x2e, 0x4c, 0x69, 0x73, 0x74, 0x52, 0x6f, 0x75, 0x74, 0x65, 0x73, 0x52, 0x65, 0x73, 0x70, 0x6f, - 0x6e, 0x73, 0x65, 0x22, 0x00, 0x12, 0x4b, 0x0a, 0x0c, 0x53, 0x65, 0x6c, 0x65, 0x63, 0x74, 0x52, + 0x22, 0x3b, 0x0a, 0x1f, 0x53, 0x65, 0x74, 0x4e, 0x65, 0x74, 0x77, 0x6f, 0x72, 0x6b, 0x4d, 0x61, + 0x70, 0x50, 0x65, 0x72, 0x73, 0x69, 0x73, 0x74, 0x65, 0x6e, 0x63, 0x65, 0x52, 0x65, 0x71, 0x75, + 0x65, 0x73, 0x74, 0x12, 0x18, 0x0a, 0x07, 0x65, 0x6e, 0x61, 0x62, 0x6c, 0x65, 0x64, 0x18, 0x01, + 0x20, 0x01, 0x28, 0x08, 0x52, 0x07, 0x65, 0x6e, 0x61, 0x62, 0x6c, 0x65, 0x64, 0x22, 0x22, 0x0a, + 0x20, 0x53, 0x65, 0x74, 0x4e, 0x65, 0x74, 0x77, 0x6f, 0x72, 0x6b, 0x4d, 0x61, 0x70, 0x50, 0x65, + 0x72, 0x73, 0x69, 0x73, 0x74, 0x65, 0x6e, 0x63, 0x65, 0x52, 0x65, 0x73, 0x70, 0x6f, 0x6e, 0x73, + 0x65, 0x2a, 0x62, 0x0a, 0x08, 0x4c, 0x6f, 0x67, 0x4c, 0x65, 0x76, 0x65, 0x6c, 0x12, 0x0b, 0x0a, + 0x07, 0x55, 0x4e, 0x4b, 0x4e, 0x4f, 0x57, 0x4e, 0x10, 0x00, 0x12, 0x09, 0x0a, 0x05, 0x50, 0x41, + 0x4e, 0x49, 0x43, 0x10, 0x01, 0x12, 0x09, 0x0a, 0x05, 0x46, 0x41, 0x54, 0x41, 0x4c, 0x10, 0x02, + 0x12, 0x09, 0x0a, 0x05, 0x45, 0x52, 0x52, 0x4f, 0x52, 0x10, 0x03, 0x12, 0x08, 0x0a, 0x04, 0x57, + 0x41, 0x52, 0x4e, 0x10, 0x04, 0x12, 0x08, 0x0a, 0x04, 0x49, 0x4e, 0x46, 0x4f, 0x10, 0x05, 0x12, + 0x09, 0x0a, 0x05, 0x44, 0x45, 0x42, 0x55, 0x47, 0x10, 0x06, 0x12, 0x09, 0x0a, 0x05, 0x54, 0x52, + 0x41, 0x43, 0x45, 0x10, 0x07, 0x32, 0xa9, 0x07, 0x0a, 0x0d, 0x44, 0x61, 0x65, 0x6d, 0x6f, 0x6e, + 0x53, 0x65, 0x72, 0x76, 0x69, 0x63, 0x65, 0x12, 0x36, 0x0a, 0x05, 0x4c, 0x6f, 0x67, 0x69, 0x6e, + 0x12, 0x14, 0x2e, 0x64, 0x61, 0x65, 0x6d, 0x6f, 0x6e, 0x2e, 0x4c, 0x6f, 0x67, 0x69, 0x6e, 0x52, + 0x65, 0x71, 0x75, 0x65, 0x73, 0x74, 0x1a, 0x15, 0x2e, 0x64, 0x61, 0x65, 0x6d, 0x6f, 0x6e, 0x2e, + 0x4c, 0x6f, 0x67, 0x69, 0x6e, 0x52, 0x65, 0x73, 0x70, 0x6f, 0x6e, 0x73, 0x65, 0x22, 0x00, 0x12, + 0x4b, 0x0a, 0x0c, 0x57, 0x61, 0x69, 0x74, 0x53, 0x53, 0x4f, 0x4c, 0x6f, 0x67, 0x69, 0x6e, 0x12, + 0x1b, 0x2e, 0x64, 0x61, 0x65, 0x6d, 0x6f, 0x6e, 0x2e, 0x57, 0x61, 0x69, 0x74, 0x53, 0x53, 0x4f, + 0x4c, 0x6f, 0x67, 0x69, 0x6e, 0x52, 0x65, 0x71, 0x75, 0x65, 0x73, 0x74, 0x1a, 0x1c, 0x2e, 0x64, + 0x61, 0x65, 0x6d, 0x6f, 0x6e, 0x2e, 0x57, 0x61, 0x69, 0x74, 0x53, 0x53, 0x4f, 0x4c, 0x6f, 0x67, + 0x69, 0x6e, 0x52, 0x65, 0x73, 0x70, 0x6f, 0x6e, 0x73, 0x65, 0x22, 0x00, 0x12, 0x2d, 0x0a, 0x02, + 0x55, 0x70, 0x12, 0x11, 0x2e, 0x64, 0x61, 0x65, 0x6d, 0x6f, 0x6e, 0x2e, 0x55, 0x70, 0x52, 0x65, + 0x71, 0x75, 0x65, 0x73, 0x74, 0x1a, 0x12, 0x2e, 0x64, 0x61, 0x65, 0x6d, 0x6f, 0x6e, 0x2e, 0x55, + 0x70, 0x52, 0x65, 0x73, 0x70, 0x6f, 0x6e, 0x73, 0x65, 0x22, 0x00, 0x12, 0x39, 0x0a, 0x06, 0x53, + 0x74, 0x61, 0x74, 0x75, 0x73, 0x12, 0x15, 0x2e, 0x64, 0x61, 0x65, 0x6d, 0x6f, 0x6e, 0x2e, 0x53, + 0x74, 0x61, 0x74, 0x75, 0x73, 0x52, 0x65, 0x71, 0x75, 0x65, 0x73, 0x74, 0x1a, 0x16, 0x2e, 0x64, + 0x61, 0x65, 0x6d, 0x6f, 0x6e, 0x2e, 0x53, 0x74, 0x61, 0x74, 0x75, 0x73, 0x52, 0x65, 0x73, 0x70, + 0x6f, 0x6e, 0x73, 0x65, 0x22, 0x00, 0x12, 0x33, 0x0a, 0x04, 0x44, 0x6f, 0x77, 0x6e, 0x12, 0x13, + 0x2e, 0x64, 0x61, 0x65, 0x6d, 0x6f, 0x6e, 0x2e, 0x44, 0x6f, 0x77, 0x6e, 0x52, 0x65, 0x71, 0x75, + 0x65, 0x73, 0x74, 0x1a, 0x14, 0x2e, 0x64, 0x61, 0x65, 0x6d, 0x6f, 0x6e, 0x2e, 0x44, 0x6f, 0x77, + 0x6e, 0x52, 0x65, 0x73, 0x70, 0x6f, 0x6e, 0x73, 0x65, 0x22, 
0x00, 0x12, 0x42, 0x0a, 0x09, 0x47, + 0x65, 0x74, 0x43, 0x6f, 0x6e, 0x66, 0x69, 0x67, 0x12, 0x18, 0x2e, 0x64, 0x61, 0x65, 0x6d, 0x6f, + 0x6e, 0x2e, 0x47, 0x65, 0x74, 0x43, 0x6f, 0x6e, 0x66, 0x69, 0x67, 0x52, 0x65, 0x71, 0x75, 0x65, + 0x73, 0x74, 0x1a, 0x19, 0x2e, 0x64, 0x61, 0x65, 0x6d, 0x6f, 0x6e, 0x2e, 0x47, 0x65, 0x74, 0x43, + 0x6f, 0x6e, 0x66, 0x69, 0x67, 0x52, 0x65, 0x73, 0x70, 0x6f, 0x6e, 0x73, 0x65, 0x22, 0x00, 0x12, + 0x45, 0x0a, 0x0a, 0x4c, 0x69, 0x73, 0x74, 0x52, 0x6f, 0x75, 0x74, 0x65, 0x73, 0x12, 0x19, 0x2e, + 0x64, 0x61, 0x65, 0x6d, 0x6f, 0x6e, 0x2e, 0x4c, 0x69, 0x73, 0x74, 0x52, 0x6f, 0x75, 0x74, 0x65, + 0x73, 0x52, 0x65, 0x71, 0x75, 0x65, 0x73, 0x74, 0x1a, 0x1a, 0x2e, 0x64, 0x61, 0x65, 0x6d, 0x6f, + 0x6e, 0x2e, 0x4c, 0x69, 0x73, 0x74, 0x52, 0x6f, 0x75, 0x74, 0x65, 0x73, 0x52, 0x65, 0x73, 0x70, + 0x6f, 0x6e, 0x73, 0x65, 0x22, 0x00, 0x12, 0x4b, 0x0a, 0x0c, 0x53, 0x65, 0x6c, 0x65, 0x63, 0x74, + 0x52, 0x6f, 0x75, 0x74, 0x65, 0x73, 0x12, 0x1b, 0x2e, 0x64, 0x61, 0x65, 0x6d, 0x6f, 0x6e, 0x2e, + 0x53, 0x65, 0x6c, 0x65, 0x63, 0x74, 0x52, 0x6f, 0x75, 0x74, 0x65, 0x73, 0x52, 0x65, 0x71, 0x75, + 0x65, 0x73, 0x74, 0x1a, 0x1c, 0x2e, 0x64, 0x61, 0x65, 0x6d, 0x6f, 0x6e, 0x2e, 0x53, 0x65, 0x6c, + 0x65, 0x63, 0x74, 0x52, 0x6f, 0x75, 0x74, 0x65, 0x73, 0x52, 0x65, 0x73, 0x70, 0x6f, 0x6e, 0x73, + 0x65, 0x22, 0x00, 0x12, 0x4d, 0x0a, 0x0e, 0x44, 0x65, 0x73, 0x65, 0x6c, 0x65, 0x63, 0x74, 0x52, 0x6f, 0x75, 0x74, 0x65, 0x73, 0x12, 0x1b, 0x2e, 0x64, 0x61, 0x65, 0x6d, 0x6f, 0x6e, 0x2e, 0x53, 0x65, 0x6c, 0x65, 0x63, 0x74, 0x52, 0x6f, 0x75, 0x74, 0x65, 0x73, 0x52, 0x65, 0x71, 0x75, 0x65, 0x73, 0x74, 0x1a, 0x1c, 0x2e, 0x64, 0x61, 0x65, 0x6d, 0x6f, 0x6e, 0x2e, 0x53, 0x65, 0x6c, 0x65, 0x63, 0x74, 0x52, 0x6f, 0x75, 0x74, 0x65, 0x73, 0x52, 0x65, 0x73, 0x70, 0x6f, 0x6e, 0x73, 0x65, - 0x22, 0x00, 0x12, 0x4d, 0x0a, 0x0e, 0x44, 0x65, 0x73, 0x65, 0x6c, 0x65, 0x63, 0x74, 0x52, 0x6f, - 0x75, 0x74, 0x65, 0x73, 0x12, 0x1b, 0x2e, 0x64, 0x61, 0x65, 0x6d, 0x6f, 0x6e, 0x2e, 0x53, 0x65, - 0x6c, 0x65, 0x63, 0x74, 0x52, 0x6f, 0x75, 0x74, 0x65, 0x73, 0x52, 0x65, 0x71, 0x75, 0x65, 0x73, - 0x74, 0x1a, 0x1c, 0x2e, 0x64, 0x61, 0x65, 0x6d, 0x6f, 0x6e, 0x2e, 0x53, 0x65, 0x6c, 0x65, 0x63, - 0x74, 0x52, 0x6f, 0x75, 0x74, 0x65, 0x73, 0x52, 0x65, 0x73, 0x70, 0x6f, 0x6e, 0x73, 0x65, 0x22, - 0x00, 0x12, 0x48, 0x0a, 0x0b, 0x44, 0x65, 0x62, 0x75, 0x67, 0x42, 0x75, 0x6e, 0x64, 0x6c, 0x65, - 0x12, 0x1a, 0x2e, 0x64, 0x61, 0x65, 0x6d, 0x6f, 0x6e, 0x2e, 0x44, 0x65, 0x62, 0x75, 0x67, 0x42, - 0x75, 0x6e, 0x64, 0x6c, 0x65, 0x52, 0x65, 0x71, 0x75, 0x65, 0x73, 0x74, 0x1a, 0x1b, 0x2e, 0x64, - 0x61, 0x65, 0x6d, 0x6f, 0x6e, 0x2e, 0x44, 0x65, 0x62, 0x75, 0x67, 0x42, 0x75, 0x6e, 0x64, 0x6c, - 0x65, 0x52, 0x65, 0x73, 0x70, 0x6f, 0x6e, 0x73, 0x65, 0x22, 0x00, 0x12, 0x48, 0x0a, 0x0b, 0x47, - 0x65, 0x74, 0x4c, 0x6f, 0x67, 0x4c, 0x65, 0x76, 0x65, 0x6c, 0x12, 0x1a, 0x2e, 0x64, 0x61, 0x65, - 0x6d, 0x6f, 0x6e, 0x2e, 0x47, 0x65, 0x74, 0x4c, 0x6f, 0x67, 0x4c, 0x65, 0x76, 0x65, 0x6c, 0x52, - 0x65, 0x71, 0x75, 0x65, 0x73, 0x74, 0x1a, 0x1b, 0x2e, 0x64, 0x61, 0x65, 0x6d, 0x6f, 0x6e, 0x2e, - 0x47, 0x65, 0x74, 0x4c, 0x6f, 0x67, 0x4c, 0x65, 0x76, 0x65, 0x6c, 0x52, 0x65, 0x73, 0x70, 0x6f, - 0x6e, 0x73, 0x65, 0x22, 0x00, 0x12, 0x48, 0x0a, 0x0b, 0x53, 0x65, 0x74, 0x4c, 0x6f, 0x67, 0x4c, - 0x65, 0x76, 0x65, 0x6c, 0x12, 0x1a, 0x2e, 0x64, 0x61, 0x65, 0x6d, 0x6f, 0x6e, 0x2e, 0x53, 0x65, - 0x74, 0x4c, 0x6f, 0x67, 0x4c, 0x65, 0x76, 0x65, 0x6c, 0x52, 0x65, 0x71, 0x75, 0x65, 0x73, 0x74, - 0x1a, 0x1b, 0x2e, 0x64, 0x61, 0x65, 0x6d, 0x6f, 0x6e, 0x2e, 0x53, 0x65, 0x74, 0x4c, 0x6f, 
0x67, - 0x4c, 0x65, 0x76, 0x65, 0x6c, 0x52, 0x65, 0x73, 0x70, 0x6f, 0x6e, 0x73, 0x65, 0x22, 0x00, 0x42, - 0x08, 0x5a, 0x06, 0x2f, 0x70, 0x72, 0x6f, 0x74, 0x6f, 0x62, 0x06, 0x70, 0x72, 0x6f, 0x74, 0x6f, - 0x33, + 0x22, 0x00, 0x12, 0x48, 0x0a, 0x0b, 0x44, 0x65, 0x62, 0x75, 0x67, 0x42, 0x75, 0x6e, 0x64, 0x6c, + 0x65, 0x12, 0x1a, 0x2e, 0x64, 0x61, 0x65, 0x6d, 0x6f, 0x6e, 0x2e, 0x44, 0x65, 0x62, 0x75, 0x67, + 0x42, 0x75, 0x6e, 0x64, 0x6c, 0x65, 0x52, 0x65, 0x71, 0x75, 0x65, 0x73, 0x74, 0x1a, 0x1b, 0x2e, + 0x64, 0x61, 0x65, 0x6d, 0x6f, 0x6e, 0x2e, 0x44, 0x65, 0x62, 0x75, 0x67, 0x42, 0x75, 0x6e, 0x64, + 0x6c, 0x65, 0x52, 0x65, 0x73, 0x70, 0x6f, 0x6e, 0x73, 0x65, 0x22, 0x00, 0x12, 0x48, 0x0a, 0x0b, + 0x47, 0x65, 0x74, 0x4c, 0x6f, 0x67, 0x4c, 0x65, 0x76, 0x65, 0x6c, 0x12, 0x1a, 0x2e, 0x64, 0x61, + 0x65, 0x6d, 0x6f, 0x6e, 0x2e, 0x47, 0x65, 0x74, 0x4c, 0x6f, 0x67, 0x4c, 0x65, 0x76, 0x65, 0x6c, + 0x52, 0x65, 0x71, 0x75, 0x65, 0x73, 0x74, 0x1a, 0x1b, 0x2e, 0x64, 0x61, 0x65, 0x6d, 0x6f, 0x6e, + 0x2e, 0x47, 0x65, 0x74, 0x4c, 0x6f, 0x67, 0x4c, 0x65, 0x76, 0x65, 0x6c, 0x52, 0x65, 0x73, 0x70, + 0x6f, 0x6e, 0x73, 0x65, 0x22, 0x00, 0x12, 0x48, 0x0a, 0x0b, 0x53, 0x65, 0x74, 0x4c, 0x6f, 0x67, + 0x4c, 0x65, 0x76, 0x65, 0x6c, 0x12, 0x1a, 0x2e, 0x64, 0x61, 0x65, 0x6d, 0x6f, 0x6e, 0x2e, 0x53, + 0x65, 0x74, 0x4c, 0x6f, 0x67, 0x4c, 0x65, 0x76, 0x65, 0x6c, 0x52, 0x65, 0x71, 0x75, 0x65, 0x73, + 0x74, 0x1a, 0x1b, 0x2e, 0x64, 0x61, 0x65, 0x6d, 0x6f, 0x6e, 0x2e, 0x53, 0x65, 0x74, 0x4c, 0x6f, + 0x67, 0x4c, 0x65, 0x76, 0x65, 0x6c, 0x52, 0x65, 0x73, 0x70, 0x6f, 0x6e, 0x73, 0x65, 0x22, 0x00, + 0x12, 0x6f, 0x0a, 0x18, 0x53, 0x65, 0x74, 0x4e, 0x65, 0x74, 0x77, 0x6f, 0x72, 0x6b, 0x4d, 0x61, + 0x70, 0x50, 0x65, 0x72, 0x73, 0x69, 0x73, 0x74, 0x65, 0x6e, 0x63, 0x65, 0x12, 0x27, 0x2e, 0x64, + 0x61, 0x65, 0x6d, 0x6f, 0x6e, 0x2e, 0x53, 0x65, 0x74, 0x4e, 0x65, 0x74, 0x77, 0x6f, 0x72, 0x6b, + 0x4d, 0x61, 0x70, 0x50, 0x65, 0x72, 0x73, 0x69, 0x73, 0x74, 0x65, 0x6e, 0x63, 0x65, 0x52, 0x65, + 0x71, 0x75, 0x65, 0x73, 0x74, 0x1a, 0x28, 0x2e, 0x64, 0x61, 0x65, 0x6d, 0x6f, 0x6e, 0x2e, 0x53, + 0x65, 0x74, 0x4e, 0x65, 0x74, 0x77, 0x6f, 0x72, 0x6b, 0x4d, 0x61, 0x70, 0x50, 0x65, 0x72, 0x73, + 0x69, 0x73, 0x74, 0x65, 0x6e, 0x63, 0x65, 0x52, 0x65, 0x73, 0x70, 0x6f, 0x6e, 0x73, 0x65, 0x22, + 0x00, 0x42, 0x08, 0x5a, 0x06, 0x2f, 0x70, 0x72, 0x6f, 0x74, 0x6f, 0x62, 0x06, 0x70, 0x72, 0x6f, + 0x74, 0x6f, 0x33, } var ( @@ -2474,50 +2572,52 @@ func file_daemon_proto_rawDescGZIP() []byte { } var file_daemon_proto_enumTypes = make([]protoimpl.EnumInfo, 1) -var file_daemon_proto_msgTypes = make([]protoimpl.MessageInfo, 32) +var file_daemon_proto_msgTypes = make([]protoimpl.MessageInfo, 34) var file_daemon_proto_goTypes = []interface{}{ - (LogLevel)(0), // 0: daemon.LogLevel - (*LoginRequest)(nil), // 1: daemon.LoginRequest - (*LoginResponse)(nil), // 2: daemon.LoginResponse - (*WaitSSOLoginRequest)(nil), // 3: daemon.WaitSSOLoginRequest - (*WaitSSOLoginResponse)(nil), // 4: daemon.WaitSSOLoginResponse - (*UpRequest)(nil), // 5: daemon.UpRequest - (*UpResponse)(nil), // 6: daemon.UpResponse - (*StatusRequest)(nil), // 7: daemon.StatusRequest - (*StatusResponse)(nil), // 8: daemon.StatusResponse - (*DownRequest)(nil), // 9: daemon.DownRequest - (*DownResponse)(nil), // 10: daemon.DownResponse - (*GetConfigRequest)(nil), // 11: daemon.GetConfigRequest - (*GetConfigResponse)(nil), // 12: daemon.GetConfigResponse - (*PeerState)(nil), // 13: daemon.PeerState - (*LocalPeerState)(nil), // 14: daemon.LocalPeerState - (*SignalState)(nil), // 15: daemon.SignalState - 
(*ManagementState)(nil), // 16: daemon.ManagementState - (*RelayState)(nil), // 17: daemon.RelayState - (*NSGroupState)(nil), // 18: daemon.NSGroupState - (*FullStatus)(nil), // 19: daemon.FullStatus - (*ListRoutesRequest)(nil), // 20: daemon.ListRoutesRequest - (*ListRoutesResponse)(nil), // 21: daemon.ListRoutesResponse - (*SelectRoutesRequest)(nil), // 22: daemon.SelectRoutesRequest - (*SelectRoutesResponse)(nil), // 23: daemon.SelectRoutesResponse - (*IPList)(nil), // 24: daemon.IPList - (*Route)(nil), // 25: daemon.Route - (*DebugBundleRequest)(nil), // 26: daemon.DebugBundleRequest - (*DebugBundleResponse)(nil), // 27: daemon.DebugBundleResponse - (*GetLogLevelRequest)(nil), // 28: daemon.GetLogLevelRequest - (*GetLogLevelResponse)(nil), // 29: daemon.GetLogLevelResponse - (*SetLogLevelRequest)(nil), // 30: daemon.SetLogLevelRequest - (*SetLogLevelResponse)(nil), // 31: daemon.SetLogLevelResponse - nil, // 32: daemon.Route.ResolvedIPsEntry - (*durationpb.Duration)(nil), // 33: google.protobuf.Duration - (*timestamppb.Timestamp)(nil), // 34: google.protobuf.Timestamp + (LogLevel)(0), // 0: daemon.LogLevel + (*LoginRequest)(nil), // 1: daemon.LoginRequest + (*LoginResponse)(nil), // 2: daemon.LoginResponse + (*WaitSSOLoginRequest)(nil), // 3: daemon.WaitSSOLoginRequest + (*WaitSSOLoginResponse)(nil), // 4: daemon.WaitSSOLoginResponse + (*UpRequest)(nil), // 5: daemon.UpRequest + (*UpResponse)(nil), // 6: daemon.UpResponse + (*StatusRequest)(nil), // 7: daemon.StatusRequest + (*StatusResponse)(nil), // 8: daemon.StatusResponse + (*DownRequest)(nil), // 9: daemon.DownRequest + (*DownResponse)(nil), // 10: daemon.DownResponse + (*GetConfigRequest)(nil), // 11: daemon.GetConfigRequest + (*GetConfigResponse)(nil), // 12: daemon.GetConfigResponse + (*PeerState)(nil), // 13: daemon.PeerState + (*LocalPeerState)(nil), // 14: daemon.LocalPeerState + (*SignalState)(nil), // 15: daemon.SignalState + (*ManagementState)(nil), // 16: daemon.ManagementState + (*RelayState)(nil), // 17: daemon.RelayState + (*NSGroupState)(nil), // 18: daemon.NSGroupState + (*FullStatus)(nil), // 19: daemon.FullStatus + (*ListRoutesRequest)(nil), // 20: daemon.ListRoutesRequest + (*ListRoutesResponse)(nil), // 21: daemon.ListRoutesResponse + (*SelectRoutesRequest)(nil), // 22: daemon.SelectRoutesRequest + (*SelectRoutesResponse)(nil), // 23: daemon.SelectRoutesResponse + (*IPList)(nil), // 24: daemon.IPList + (*Route)(nil), // 25: daemon.Route + (*DebugBundleRequest)(nil), // 26: daemon.DebugBundleRequest + (*DebugBundleResponse)(nil), // 27: daemon.DebugBundleResponse + (*GetLogLevelRequest)(nil), // 28: daemon.GetLogLevelRequest + (*GetLogLevelResponse)(nil), // 29: daemon.GetLogLevelResponse + (*SetLogLevelRequest)(nil), // 30: daemon.SetLogLevelRequest + (*SetLogLevelResponse)(nil), // 31: daemon.SetLogLevelResponse + (*SetNetworkMapPersistenceRequest)(nil), // 32: daemon.SetNetworkMapPersistenceRequest + (*SetNetworkMapPersistenceResponse)(nil), // 33: daemon.SetNetworkMapPersistenceResponse + nil, // 34: daemon.Route.ResolvedIPsEntry + (*durationpb.Duration)(nil), // 35: google.protobuf.Duration + (*timestamppb.Timestamp)(nil), // 36: google.protobuf.Timestamp } var file_daemon_proto_depIdxs = []int32{ - 33, // 0: daemon.LoginRequest.dnsRouteInterval:type_name -> google.protobuf.Duration + 35, // 0: daemon.LoginRequest.dnsRouteInterval:type_name -> google.protobuf.Duration 19, // 1: daemon.StatusResponse.fullStatus:type_name -> daemon.FullStatus - 34, // 2: daemon.PeerState.connStatusUpdate:type_name -> 
google.protobuf.Timestamp - 34, // 3: daemon.PeerState.lastWireguardHandshake:type_name -> google.protobuf.Timestamp - 33, // 4: daemon.PeerState.latency:type_name -> google.protobuf.Duration + 36, // 2: daemon.PeerState.connStatusUpdate:type_name -> google.protobuf.Timestamp + 36, // 3: daemon.PeerState.lastWireguardHandshake:type_name -> google.protobuf.Timestamp + 35, // 4: daemon.PeerState.latency:type_name -> google.protobuf.Duration 16, // 5: daemon.FullStatus.managementState:type_name -> daemon.ManagementState 15, // 6: daemon.FullStatus.signalState:type_name -> daemon.SignalState 14, // 7: daemon.FullStatus.localPeerState:type_name -> daemon.LocalPeerState @@ -2525,7 +2625,7 @@ var file_daemon_proto_depIdxs = []int32{ 17, // 9: daemon.FullStatus.relays:type_name -> daemon.RelayState 18, // 10: daemon.FullStatus.dns_servers:type_name -> daemon.NSGroupState 25, // 11: daemon.ListRoutesResponse.routes:type_name -> daemon.Route - 32, // 12: daemon.Route.resolvedIPs:type_name -> daemon.Route.ResolvedIPsEntry + 34, // 12: daemon.Route.resolvedIPs:type_name -> daemon.Route.ResolvedIPsEntry 0, // 13: daemon.GetLogLevelResponse.level:type_name -> daemon.LogLevel 0, // 14: daemon.SetLogLevelRequest.level:type_name -> daemon.LogLevel 24, // 15: daemon.Route.ResolvedIPsEntry.value:type_name -> daemon.IPList @@ -2541,20 +2641,22 @@ var file_daemon_proto_depIdxs = []int32{ 26, // 25: daemon.DaemonService.DebugBundle:input_type -> daemon.DebugBundleRequest 28, // 26: daemon.DaemonService.GetLogLevel:input_type -> daemon.GetLogLevelRequest 30, // 27: daemon.DaemonService.SetLogLevel:input_type -> daemon.SetLogLevelRequest - 2, // 28: daemon.DaemonService.Login:output_type -> daemon.LoginResponse - 4, // 29: daemon.DaemonService.WaitSSOLogin:output_type -> daemon.WaitSSOLoginResponse - 6, // 30: daemon.DaemonService.Up:output_type -> daemon.UpResponse - 8, // 31: daemon.DaemonService.Status:output_type -> daemon.StatusResponse - 10, // 32: daemon.DaemonService.Down:output_type -> daemon.DownResponse - 12, // 33: daemon.DaemonService.GetConfig:output_type -> daemon.GetConfigResponse - 21, // 34: daemon.DaemonService.ListRoutes:output_type -> daemon.ListRoutesResponse - 23, // 35: daemon.DaemonService.SelectRoutes:output_type -> daemon.SelectRoutesResponse - 23, // 36: daemon.DaemonService.DeselectRoutes:output_type -> daemon.SelectRoutesResponse - 27, // 37: daemon.DaemonService.DebugBundle:output_type -> daemon.DebugBundleResponse - 29, // 38: daemon.DaemonService.GetLogLevel:output_type -> daemon.GetLogLevelResponse - 31, // 39: daemon.DaemonService.SetLogLevel:output_type -> daemon.SetLogLevelResponse - 28, // [28:40] is the sub-list for method output_type - 16, // [16:28] is the sub-list for method input_type + 32, // 28: daemon.DaemonService.SetNetworkMapPersistence:input_type -> daemon.SetNetworkMapPersistenceRequest + 2, // 29: daemon.DaemonService.Login:output_type -> daemon.LoginResponse + 4, // 30: daemon.DaemonService.WaitSSOLogin:output_type -> daemon.WaitSSOLoginResponse + 6, // 31: daemon.DaemonService.Up:output_type -> daemon.UpResponse + 8, // 32: daemon.DaemonService.Status:output_type -> daemon.StatusResponse + 10, // 33: daemon.DaemonService.Down:output_type -> daemon.DownResponse + 12, // 34: daemon.DaemonService.GetConfig:output_type -> daemon.GetConfigResponse + 21, // 35: daemon.DaemonService.ListRoutes:output_type -> daemon.ListRoutesResponse + 23, // 36: daemon.DaemonService.SelectRoutes:output_type -> daemon.SelectRoutesResponse + 23, // 37: 
daemon.DaemonService.DeselectRoutes:output_type -> daemon.SelectRoutesResponse + 27, // 38: daemon.DaemonService.DebugBundle:output_type -> daemon.DebugBundleResponse + 29, // 39: daemon.DaemonService.GetLogLevel:output_type -> daemon.GetLogLevelResponse + 31, // 40: daemon.DaemonService.SetLogLevel:output_type -> daemon.SetLogLevelResponse + 33, // 41: daemon.DaemonService.SetNetworkMapPersistence:output_type -> daemon.SetNetworkMapPersistenceResponse + 29, // [29:42] is the sub-list for method output_type + 16, // [16:29] is the sub-list for method input_type 16, // [16:16] is the sub-list for extension type_name 16, // [16:16] is the sub-list for extension extendee 0, // [0:16] is the sub-list for field type_name @@ -2938,6 +3040,30 @@ func file_daemon_proto_init() { return nil } } + file_daemon_proto_msgTypes[31].Exporter = func(v interface{}, i int) interface{} { + switch v := v.(*SetNetworkMapPersistenceRequest); i { + case 0: + return &v.state + case 1: + return &v.sizeCache + case 2: + return &v.unknownFields + default: + return nil + } + } + file_daemon_proto_msgTypes[32].Exporter = func(v interface{}, i int) interface{} { + switch v := v.(*SetNetworkMapPersistenceResponse); i { + case 0: + return &v.state + case 1: + return &v.sizeCache + case 2: + return &v.unknownFields + default: + return nil + } + } } file_daemon_proto_msgTypes[0].OneofWrappers = []interface{}{} type x struct{} @@ -2946,7 +3072,7 @@ func file_daemon_proto_init() { GoPackagePath: reflect.TypeOf(x{}).PkgPath(), RawDescriptor: file_daemon_proto_rawDesc, NumEnums: 1, - NumMessages: 32, + NumMessages: 34, NumExtensions: 0, NumServices: 1, }, diff --git a/client/proto/daemon.proto b/client/proto/daemon.proto index 384bc0e62..04aa66882 100644 --- a/client/proto/daemon.proto +++ b/client/proto/daemon.proto @@ -45,7 +45,11 @@ service DaemonService { // SetLogLevel sets the log level of the daemon rpc SetLogLevel(SetLogLevelRequest) returns (SetLogLevelResponse) {} -}; + + // SetNetworkMapPersistence enables or disables network map persistence + rpc SetNetworkMapPersistence(SetNetworkMapPersistenceRequest) returns (SetNetworkMapPersistenceResponse) {} +} + message LoginRequest { // setupKey wiretrustee setup key. 
@@ -293,4 +297,10 @@ message SetLogLevelRequest { } message SetLogLevelResponse { -} \ No newline at end of file +} + +message SetNetworkMapPersistenceRequest { + bool enabled = 1; +} + +message SetNetworkMapPersistenceResponse {} diff --git a/client/proto/daemon_grpc.pb.go b/client/proto/daemon_grpc.pb.go index e0bc117e5..31645763e 100644 --- a/client/proto/daemon_grpc.pb.go +++ b/client/proto/daemon_grpc.pb.go @@ -43,6 +43,8 @@ type DaemonServiceClient interface { GetLogLevel(ctx context.Context, in *GetLogLevelRequest, opts ...grpc.CallOption) (*GetLogLevelResponse, error) // SetLogLevel sets the log level of the daemon SetLogLevel(ctx context.Context, in *SetLogLevelRequest, opts ...grpc.CallOption) (*SetLogLevelResponse, error) + // SetNetworkMapPersistence enables or disables network map persistence + SetNetworkMapPersistence(ctx context.Context, in *SetNetworkMapPersistenceRequest, opts ...grpc.CallOption) (*SetNetworkMapPersistenceResponse, error) } type daemonServiceClient struct { @@ -161,6 +163,15 @@ func (c *daemonServiceClient) SetLogLevel(ctx context.Context, in *SetLogLevelRe return out, nil } +func (c *daemonServiceClient) SetNetworkMapPersistence(ctx context.Context, in *SetNetworkMapPersistenceRequest, opts ...grpc.CallOption) (*SetNetworkMapPersistenceResponse, error) { + out := new(SetNetworkMapPersistenceResponse) + err := c.cc.Invoke(ctx, "/daemon.DaemonService/SetNetworkMapPersistence", in, out, opts...) + if err != nil { + return nil, err + } + return out, nil +} + // DaemonServiceServer is the server API for DaemonService service. // All implementations must embed UnimplementedDaemonServiceServer // for forward compatibility @@ -190,6 +201,8 @@ type DaemonServiceServer interface { GetLogLevel(context.Context, *GetLogLevelRequest) (*GetLogLevelResponse, error) // SetLogLevel sets the log level of the daemon SetLogLevel(context.Context, *SetLogLevelRequest) (*SetLogLevelResponse, error) + // SetNetworkMapPersistence enables or disables network map persistence + SetNetworkMapPersistence(context.Context, *SetNetworkMapPersistenceRequest) (*SetNetworkMapPersistenceResponse, error) mustEmbedUnimplementedDaemonServiceServer() } @@ -233,6 +246,9 @@ func (UnimplementedDaemonServiceServer) GetLogLevel(context.Context, *GetLogLeve func (UnimplementedDaemonServiceServer) SetLogLevel(context.Context, *SetLogLevelRequest) (*SetLogLevelResponse, error) { return nil, status.Errorf(codes.Unimplemented, "method SetLogLevel not implemented") } +func (UnimplementedDaemonServiceServer) SetNetworkMapPersistence(context.Context, *SetNetworkMapPersistenceRequest) (*SetNetworkMapPersistenceResponse, error) { + return nil, status.Errorf(codes.Unimplemented, "method SetNetworkMapPersistence not implemented") +} func (UnimplementedDaemonServiceServer) mustEmbedUnimplementedDaemonServiceServer() {} // UnsafeDaemonServiceServer may be embedded to opt out of forward compatibility for this service. 
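
For illustration, a minimal sketch of calling the new RPC through the generated client shown above. Only proto.NewDaemonServiceClient, SetNetworkMapPersistenceRequest, and the SetNetworkMapPersistence method come from this patch; the dial target and credentials are placeholders, since the real CLI obtains its daemon connection through its own helper rather than dialing directly like this.

package main

import (
	"context"
	"log"

	"google.golang.org/grpc"
	"google.golang.org/grpc/credentials/insecure"

	"github.com/netbirdio/netbird/client/proto"
)

func main() {
	// Placeholder target: the real client resolves the daemon address itself.
	conn, err := grpc.Dial("unix:///var/run/netbird.sock",
		grpc.WithTransportCredentials(insecure.NewCredentials()))
	if err != nil {
		log.Fatalf("dial daemon: %v", err)
	}
	defer conn.Close()

	client := proto.NewDaemonServiceClient(conn)

	// Enable persistence so a later debug bundle can include network_map.json.
	_, err = client.SetNetworkMapPersistence(context.Background(),
		&proto.SetNetworkMapPersistenceRequest{Enabled: true})
	if err != nil {
		log.Fatalf("set network map persistence: %v", err)
	}
}
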
@@ -462,6 +478,24 @@ func _DaemonService_SetLogLevel_Handler(srv interface{}, ctx context.Context, de return interceptor(ctx, in, info, handler) } +func _DaemonService_SetNetworkMapPersistence_Handler(srv interface{}, ctx context.Context, dec func(interface{}) error, interceptor grpc.UnaryServerInterceptor) (interface{}, error) { + in := new(SetNetworkMapPersistenceRequest) + if err := dec(in); err != nil { + return nil, err + } + if interceptor == nil { + return srv.(DaemonServiceServer).SetNetworkMapPersistence(ctx, in) + } + info := &grpc.UnaryServerInfo{ + Server: srv, + FullMethod: "/daemon.DaemonService/SetNetworkMapPersistence", + } + handler := func(ctx context.Context, req interface{}) (interface{}, error) { + return srv.(DaemonServiceServer).SetNetworkMapPersistence(ctx, req.(*SetNetworkMapPersistenceRequest)) + } + return interceptor(ctx, in, info, handler) +} + // DaemonService_ServiceDesc is the grpc.ServiceDesc for DaemonService service. // It's only intended for direct use with grpc.RegisterService, // and not to be introspected or modified (even as a copy) @@ -517,6 +551,10 @@ var DaemonService_ServiceDesc = grpc.ServiceDesc{ MethodName: "SetLogLevel", Handler: _DaemonService_SetLogLevel_Handler, }, + { + MethodName: "SetNetworkMapPersistence", + Handler: _DaemonService_SetNetworkMapPersistence_Handler, + }, }, Streams: []grpc.StreamDesc{}, Metadata: "daemon.proto", diff --git a/client/server/debug.go b/client/server/debug.go index d87fc9b6a..c12fd99db 100644 --- a/client/server/debug.go +++ b/client/server/debug.go @@ -21,12 +21,14 @@ import ( "time" log "github.com/sirupsen/logrus" + "google.golang.org/protobuf/encoding/protojson" "github.com/netbirdio/netbird/client/anonymize" "github.com/netbirdio/netbird/client/internal/peer" "github.com/netbirdio/netbird/client/internal/routemanager/systemops" "github.com/netbirdio/netbird/client/internal/statemanager" "github.com/netbirdio/netbird/client/proto" + mgmProto "github.com/netbirdio/netbird/management/proto" ) const readmeContent = `Netbird debug bundle @@ -39,6 +41,7 @@ netbird.out: Most recent, anonymized stdout log file of the NetBird client. routes.txt: Anonymized system routes, if --system-info flag was provided. interfaces.txt: Anonymized network interface information, if --system-info flag was provided. config.txt: Anonymized configuration information of the NetBird client. +network_map.json: Anonymized network map containing peer configurations, routes, DNS settings, and firewall rules. state.json: Anonymized client state dump containing netbird states. @@ -59,6 +62,16 @@ Domains All domain names (except for the netbird domains) are replaced with randomly generated strings ending in ".domain". Anonymized domains are consistent across all files in the bundle. Reoccuring domain names are replaced with the same anonymized domain. +Network Map +The network_map.json file contains the following anonymized information: +- Peer configurations (addresses, FQDNs, DNS settings) +- Remote and offline peer information (allowed IPs, FQDNs) +- Routes (network ranges, associated domains) +- DNS configuration (nameservers, domains, custom zones) +- Firewall rules (peer IPs, source/destination ranges) + +SSH keys in the network map are replaced with a placeholder value. All IP addresses and domains in the network map follow the same anonymization rules as described above. 
+ State File The state.json file contains anonymized internal state information of the NetBird client, including: - DNS settings and configuration @@ -148,19 +161,23 @@ func (s *Server) createArchive(bundlePath *os.File, req *proto.DebugBundleReques seedFromStatus(anonymizer, &status) if err := s.addConfig(req, anonymizer, archive); err != nil { - return fmt.Errorf("add config: %w", err) + log.Errorf("Failed to add config to debug bundle: %v", err) } if req.GetSystemInfo() { if err := s.addRoutes(req, anonymizer, archive); err != nil { - return fmt.Errorf("add routes: %w", err) + log.Errorf("Failed to add routes to debug bundle: %v", err) } if err := s.addInterfaces(req, anonymizer, archive); err != nil { - return fmt.Errorf("add interfaces: %w", err) + log.Errorf("Failed to add interfaces to debug bundle: %v", err) } } + if err := s.addNetworkMap(req, anonymizer, archive); err != nil { + return fmt.Errorf("add network map: %w", err) + } + if err := s.addStateFile(req, anonymizer, archive); err != nil { log.Errorf("Failed to add state file to debug bundle: %v", err) } @@ -253,15 +270,16 @@ func (s *Server) addCommonConfigFields(configContent *strings.Builder) { } func (s *Server) addRoutes(req *proto.DebugBundleRequest, anonymizer *anonymize.Anonymizer, archive *zip.Writer) error { - if routes, err := systemops.GetRoutesFromTable(); err != nil { - log.Errorf("Failed to get routes: %v", err) - } else { - // TODO: get routes including nexthop - routesContent := formatRoutes(routes, req.GetAnonymize(), anonymizer) - routesReader := strings.NewReader(routesContent) - if err := addFileToZip(archive, routesReader, "routes.txt"); err != nil { - return fmt.Errorf("add routes file to zip: %w", err) - } + routes, err := systemops.GetRoutesFromTable() + if err != nil { + return fmt.Errorf("get routes: %w", err) + } + + // TODO: get routes including nexthop + routesContent := formatRoutes(routes, req.GetAnonymize(), anonymizer) + routesReader := strings.NewReader(routesContent) + if err := addFileToZip(archive, routesReader, "routes.txt"); err != nil { + return fmt.Errorf("add routes file to zip: %w", err) } return nil } @@ -281,6 +299,39 @@ func (s *Server) addInterfaces(req *proto.DebugBundleRequest, anonymizer *anonym return nil } +func (s *Server) addNetworkMap(req *proto.DebugBundleRequest, anonymizer *anonymize.Anonymizer, archive *zip.Writer) error { + networkMap, err := s.getLatestNetworkMap() + if err != nil { + // Skip if network map is not available, but log it + log.Debugf("skipping empty network map in debug bundle: %v", err) + return nil + } + + if req.GetAnonymize() { + if err := anonymizeNetworkMap(networkMap, anonymizer); err != nil { + return fmt.Errorf("anonymize network map: %w", err) + } + } + + options := protojson.MarshalOptions{ + EmitUnpopulated: true, + UseProtoNames: true, + Indent: " ", + AllowPartial: true, + } + + jsonBytes, err := options.Marshal(networkMap) + if err != nil { + return fmt.Errorf("generate json: %w", err) + } + + if err := addFileToZip(archive, bytes.NewReader(jsonBytes), "network_map.json"); err != nil { + return fmt.Errorf("add network map to zip: %w", err) + } + + return nil +} + func (s *Server) addStateFile(req *proto.DebugBundleRequest, anonymizer *anonymize.Anonymizer, archive *zip.Writer) error { path := statemanager.GetDefaultStatePath() if path == "" { @@ -368,14 +419,43 @@ func (s *Server) addSingleLogfile(logPath, targetName string, req *proto.DebugBu return nil } +// getLatestNetworkMap returns the latest network map from the engine if 
network map persistence is enabled +func (s *Server) getLatestNetworkMap() (*mgmProto.NetworkMap, error) { + if s.connectClient == nil { + return nil, errors.New("connect client is not initialized") + } + + engine := s.connectClient.Engine() + if engine == nil { + return nil, errors.New("engine is not initialized") + } + + networkMap, err := engine.GetLatestNetworkMap() + if err != nil { + return nil, fmt.Errorf("get latest network map: %w", err) + } + + if networkMap == nil { + return nil, errors.New("network map is not available") + } + + return networkMap, nil +} + // GetLogLevel gets the current logging level for the server. func (s *Server) GetLogLevel(_ context.Context, _ *proto.GetLogLevelRequest) (*proto.GetLogLevelResponse, error) { + s.mutex.Lock() + defer s.mutex.Unlock() + level := ParseLogLevel(log.GetLevel().String()) return &proto.GetLogLevelResponse{Level: level}, nil } // SetLogLevel sets the logging level for the server. func (s *Server) SetLogLevel(_ context.Context, req *proto.SetLogLevelRequest) (*proto.SetLogLevelResponse, error) { + s.mutex.Lock() + defer s.mutex.Unlock() + level, err := log.ParseLevel(req.Level.String()) if err != nil { return nil, fmt.Errorf("invalid log level: %w", err) @@ -386,6 +466,20 @@ func (s *Server) SetLogLevel(_ context.Context, req *proto.SetLogLevelRequest) ( return &proto.SetLogLevelResponse{}, nil } +// SetNetworkMapPersistence sets the network map persistence for the server. +func (s *Server) SetNetworkMapPersistence(_ context.Context, req *proto.SetNetworkMapPersistenceRequest) (*proto.SetNetworkMapPersistenceResponse, error) { + s.mutex.Lock() + defer s.mutex.Unlock() + + enabled := req.GetEnabled() + s.persistNetworkMap = enabled + if s.connectClient != nil { + s.connectClient.SetNetworkMapPersistence(enabled) + } + + return &proto.SetNetworkMapPersistenceResponse{}, nil +} + func addFileToZip(archive *zip.Writer, reader io.Reader, filename string) error { header := &zip.FileHeader{ Name: filename, @@ -578,6 +672,177 @@ func anonymizeNATExternalIPs(ips []string, anonymizer *anonymize.Anonymizer) []s return anonymizedIPs } +func anonymizeNetworkMap(networkMap *mgmProto.NetworkMap, anonymizer *anonymize.Anonymizer) error { + if networkMap.PeerConfig != nil { + anonymizePeerConfig(networkMap.PeerConfig, anonymizer) + } + + for _, peer := range networkMap.RemotePeers { + anonymizeRemotePeer(peer, anonymizer) + } + + for _, peer := range networkMap.OfflinePeers { + anonymizeRemotePeer(peer, anonymizer) + } + + for _, r := range networkMap.Routes { + anonymizeRoute(r, anonymizer) + } + + if networkMap.DNSConfig != nil { + anonymizeDNSConfig(networkMap.DNSConfig, anonymizer) + } + + for _, rule := range networkMap.FirewallRules { + anonymizeFirewallRule(rule, anonymizer) + } + + for _, rule := range networkMap.RoutesFirewallRules { + anonymizeRouteFirewallRule(rule, anonymizer) + } + + return nil +} + +func anonymizePeerConfig(config *mgmProto.PeerConfig, anonymizer *anonymize.Anonymizer) { + if config == nil { + return + } + + if addr, err := netip.ParseAddr(config.Address); err == nil { + config.Address = anonymizer.AnonymizeIP(addr).String() + } + + if config.SshConfig != nil && len(config.SshConfig.SshPubKey) > 0 { + config.SshConfig.SshPubKey = []byte("ssh-placeholder-key") + } + + config.Dns = anonymizer.AnonymizeString(config.Dns) + config.Fqdn = anonymizer.AnonymizeDomain(config.Fqdn) +} + +func anonymizeRemotePeer(peer *mgmProto.RemotePeerConfig, anonymizer *anonymize.Anonymizer) { + if peer == nil { + return + } + + for i, ip 
:= range peer.AllowedIps { + // Try to parse as prefix first (CIDR) + if prefix, err := netip.ParsePrefix(ip); err == nil { + anonIP := anonymizer.AnonymizeIP(prefix.Addr()) + peer.AllowedIps[i] = fmt.Sprintf("%s/%d", anonIP, prefix.Bits()) + } else if addr, err := netip.ParseAddr(ip); err == nil { + peer.AllowedIps[i] = anonymizer.AnonymizeIP(addr).String() + } + } + + peer.Fqdn = anonymizer.AnonymizeDomain(peer.Fqdn) + + if peer.SshConfig != nil && len(peer.SshConfig.SshPubKey) > 0 { + peer.SshConfig.SshPubKey = []byte("ssh-placeholder-key") + } +} + +func anonymizeRoute(route *mgmProto.Route, anonymizer *anonymize.Anonymizer) { + if route == nil { + return + } + + if prefix, err := netip.ParsePrefix(route.Network); err == nil { + anonIP := anonymizer.AnonymizeIP(prefix.Addr()) + route.Network = fmt.Sprintf("%s/%d", anonIP, prefix.Bits()) + } + + for i, domain := range route.Domains { + route.Domains[i] = anonymizer.AnonymizeDomain(domain) + } + + route.NetID = anonymizer.AnonymizeString(route.NetID) +} + +func anonymizeDNSConfig(config *mgmProto.DNSConfig, anonymizer *anonymize.Anonymizer) { + if config == nil { + return + } + + anonymizeNameServerGroups(config.NameServerGroups, anonymizer) + anonymizeCustomZones(config.CustomZones, anonymizer) +} + +func anonymizeNameServerGroups(groups []*mgmProto.NameServerGroup, anonymizer *anonymize.Anonymizer) { + for _, group := range groups { + anonymizeServers(group.NameServers, anonymizer) + anonymizeDomains(group.Domains, anonymizer) + } +} + +func anonymizeServers(servers []*mgmProto.NameServer, anonymizer *anonymize.Anonymizer) { + for _, server := range servers { + if addr, err := netip.ParseAddr(server.IP); err == nil { + server.IP = anonymizer.AnonymizeIP(addr).String() + } + } +} + +func anonymizeDomains(domains []string, anonymizer *anonymize.Anonymizer) { + for i, domain := range domains { + domains[i] = anonymizer.AnonymizeDomain(domain) + } +} + +func anonymizeCustomZones(zones []*mgmProto.CustomZone, anonymizer *anonymize.Anonymizer) { + for _, zone := range zones { + zone.Domain = anonymizer.AnonymizeDomain(zone.Domain) + anonymizeRecords(zone.Records, anonymizer) + } +} + +func anonymizeRecords(records []*mgmProto.SimpleRecord, anonymizer *anonymize.Anonymizer) { + for _, record := range records { + record.Name = anonymizer.AnonymizeDomain(record.Name) + anonymizeRData(record, anonymizer) + } +} + +func anonymizeRData(record *mgmProto.SimpleRecord, anonymizer *anonymize.Anonymizer) { + switch record.Type { + case 1, 28: // A or AAAA record + if addr, err := netip.ParseAddr(record.RData); err == nil { + record.RData = anonymizer.AnonymizeIP(addr).String() + } + default: + record.RData = anonymizer.AnonymizeString(record.RData) + } +} + +func anonymizeFirewallRule(rule *mgmProto.FirewallRule, anonymizer *anonymize.Anonymizer) { + if rule == nil { + return + } + + if addr, err := netip.ParseAddr(rule.PeerIP); err == nil { + rule.PeerIP = anonymizer.AnonymizeIP(addr).String() + } +} + +func anonymizeRouteFirewallRule(rule *mgmProto.RouteFirewallRule, anonymizer *anonymize.Anonymizer) { + if rule == nil { + return + } + + for i, sourceRange := range rule.SourceRanges { + if prefix, err := netip.ParsePrefix(sourceRange); err == nil { + anonIP := anonymizer.AnonymizeIP(prefix.Addr()) + rule.SourceRanges[i] = fmt.Sprintf("%s/%d", anonIP, prefix.Bits()) + } + } + + if prefix, err := netip.ParsePrefix(rule.Destination); err == nil { + anonIP := anonymizer.AnonymizeIP(prefix.Addr()) + rule.Destination = fmt.Sprintf("%s/%d", anonIP, 
prefix.Bits()) + } +} + func anonymizeStateFile(rawStates *map[string]json.RawMessage, anonymizer *anonymize.Anonymizer) error { for name, rawState := range *rawStates { if string(rawState) == "null" { diff --git a/client/server/debug_test.go b/client/server/debug_test.go index 1515036bd..c8f7bae5d 100644 --- a/client/server/debug_test.go +++ b/client/server/debug_test.go @@ -2,6 +2,7 @@ package server import ( "encoding/json" + "net" "strings" "testing" @@ -9,6 +10,7 @@ import ( "github.com/stretchr/testify/require" "github.com/netbirdio/netbird/client/anonymize" + mgmProto "github.com/netbirdio/netbird/management/proto" ) func TestAnonymizeStateFile(t *testing.T) { @@ -263,3 +265,166 @@ func mustMarshal(v any) json.RawMessage { } return data } + +func TestAnonymizeNetworkMap(t *testing.T) { + networkMap := &mgmProto.NetworkMap{ + PeerConfig: &mgmProto.PeerConfig{ + Address: "203.0.113.5", + Dns: "1.2.3.4", + Fqdn: "peer1.corp.example.com", + SshConfig: &mgmProto.SSHConfig{ + SshPubKey: []byte("ssh-rsa AAAAB3NzaC1..."), + }, + }, + RemotePeers: []*mgmProto.RemotePeerConfig{ + { + AllowedIps: []string{ + "203.0.113.1/32", + "2001:db8:1234::1/128", + "192.168.1.1/32", + "100.64.0.1/32", + "10.0.0.1/32", + }, + Fqdn: "peer2.corp.example.com", + SshConfig: &mgmProto.SSHConfig{ + SshPubKey: []byte("ssh-rsa AAAAB3NzaC2..."), + }, + }, + }, + Routes: []*mgmProto.Route{ + { + Network: "197.51.100.0/24", + Domains: []string{"prod.example.com", "staging.example.com"}, + NetID: "net-123abc", + }, + }, + DNSConfig: &mgmProto.DNSConfig{ + NameServerGroups: []*mgmProto.NameServerGroup{ + { + NameServers: []*mgmProto.NameServer{ + {IP: "8.8.8.8"}, + {IP: "1.1.1.1"}, + {IP: "203.0.113.53"}, + }, + Domains: []string{"example.com", "internal.example.com"}, + }, + }, + CustomZones: []*mgmProto.CustomZone{ + { + Domain: "custom.example.com", + Records: []*mgmProto.SimpleRecord{ + { + Name: "www.custom.example.com", + Type: 1, + RData: "203.0.113.10", + }, + { + Name: "internal.custom.example.com", + Type: 1, + RData: "192.168.1.10", + }, + }, + }, + }, + }, + } + + // Create anonymizer with test addresses + anonymizer := anonymize.NewAnonymizer(anonymize.DefaultAddresses()) + + // Anonymize the network map + err := anonymizeNetworkMap(networkMap, anonymizer) + require.NoError(t, err) + + // Test PeerConfig anonymization + peerCfg := networkMap.PeerConfig + require.NotEqual(t, "203.0.113.5", peerCfg.Address) + + // Verify DNS and FQDN are properly anonymized + require.NotEqual(t, "1.2.3.4", peerCfg.Dns) + require.NotEqual(t, "peer1.corp.example.com", peerCfg.Fqdn) + require.True(t, strings.HasSuffix(peerCfg.Fqdn, ".domain")) + + // Verify SSH key is replaced + require.Equal(t, []byte("ssh-placeholder-key"), peerCfg.SshConfig.SshPubKey) + + // Test RemotePeers anonymization + remotePeer := networkMap.RemotePeers[0] + + // Verify FQDN is anonymized + require.NotEqual(t, "peer2.corp.example.com", remotePeer.Fqdn) + require.True(t, strings.HasSuffix(remotePeer.Fqdn, ".domain")) + + // Check that public IPs are anonymized but private IPs are preserved + for _, allowedIP := range remotePeer.AllowedIps { + ip, _, err := net.ParseCIDR(allowedIP) + require.NoError(t, err) + + if ip.IsPrivate() || isInCGNATRange(ip) { + require.Contains(t, []string{ + "192.168.1.1/32", + "100.64.0.1/32", + "10.0.0.1/32", + }, allowedIP) + } else { + require.NotContains(t, []string{ + "203.0.113.1/32", + "2001:db8:1234::1/128", + }, allowedIP) + } + } + + // Test Routes anonymization + route := networkMap.Routes[0] + require.NotEqual(t, 
"197.51.100.0/24", route.Network) + for _, domain := range route.Domains { + require.True(t, strings.HasSuffix(domain, ".domain")) + require.NotContains(t, domain, "example.com") + } + + // Test DNS config anonymization + dnsConfig := networkMap.DNSConfig + nameServerGroup := dnsConfig.NameServerGroups[0] + + // Verify well-known DNS servers are preserved + require.Equal(t, "8.8.8.8", nameServerGroup.NameServers[0].IP) + require.Equal(t, "1.1.1.1", nameServerGroup.NameServers[1].IP) + + // Verify public DNS server is anonymized + require.NotEqual(t, "203.0.113.53", nameServerGroup.NameServers[2].IP) + + // Verify domains are anonymized + for _, domain := range nameServerGroup.Domains { + require.True(t, strings.HasSuffix(domain, ".domain")) + require.NotContains(t, domain, "example.com") + } + + // Test CustomZones anonymization + customZone := dnsConfig.CustomZones[0] + require.True(t, strings.HasSuffix(customZone.Domain, ".domain")) + require.NotContains(t, customZone.Domain, "example.com") + + // Verify records are properly anonymized + for _, record := range customZone.Records { + require.True(t, strings.HasSuffix(record.Name, ".domain")) + require.NotContains(t, record.Name, "example.com") + + ip := net.ParseIP(record.RData) + if ip != nil { + if !ip.IsPrivate() { + require.NotEqual(t, "203.0.113.10", record.RData) + } else { + require.Equal(t, "192.168.1.10", record.RData) + } + } + } +} + +// Helper function to check if IP is in CGNAT range +func isInCGNATRange(ip net.IP) bool { + cgnat := net.IPNet{ + IP: net.ParseIP("100.64.0.0"), + Mask: net.CIDRMask(10, 32), + } + return cgnat.Contains(ip) +} diff --git a/client/server/server.go b/client/server/server.go index 106bdf32b..71eb58a66 100644 --- a/client/server/server.go +++ b/client/server/server.go @@ -68,6 +68,8 @@ type Server struct { relayProbe *internal.Probe wgProbe *internal.Probe lastProbe time.Time + + persistNetworkMap bool } type oauthAuthFlow struct { @@ -196,6 +198,7 @@ func (s *Server) connectWithRetryRuns(ctx context.Context, config *internal.Conf runOperation := func() error { log.Tracef("running client connection") s.connectClient = internal.NewConnectClient(ctx, config, statusRecorder) + s.connectClient.SetNetworkMapPersistence(s.persistNetworkMap) probes := internal.ProbeHolder{ MgmProbe: s.mgmProbe, diff --git a/go.mod b/go.mod index e8c655422..c08713f2b 100644 --- a/go.mod +++ b/go.mod @@ -25,7 +25,7 @@ require ( golang.zx2c4.com/wireguard/wgctrl v0.0.0-20230429144221-925a1e7659e6 golang.zx2c4.com/wireguard/windows v0.5.3 google.golang.org/grpc v1.64.1 - google.golang.org/protobuf v1.34.1 + google.golang.org/protobuf v1.34.2 gopkg.in/natefinch/lumberjack.v2 v2.0.0 ) @@ -224,7 +224,7 @@ require ( golang.org/x/tools v0.21.1-0.20240508182429-e35e4ccd0d2d // indirect golang.zx2c4.com/wintun v0.0.0-20230126152724-0fa3db229ce2 // indirect google.golang.org/genproto/googleapis/api v0.0.0-20240509183442-62759503f434 // indirect - google.golang.org/genproto/googleapis/rpc v0.0.0-20240515191416-fc5f0ca64291 // indirect + google.golang.org/genproto/googleapis/rpc v0.0.0-20240903143218-8af14fe29dc1 // indirect gopkg.in/square/go-jose.v2 v2.6.0 // indirect gopkg.in/tomb.v1 v1.0.0-20141024135613-dd632973f1e7 // indirect gopkg.in/tomb.v2 v2.0.0-20161208151619-d5d1b5820637 // indirect diff --git a/go.sum b/go.sum index 47975d4ea..5519d6ded 100644 --- a/go.sum +++ b/go.sum @@ -1151,8 +1151,8 @@ google.golang.org/genproto v0.0.0-20210402141018-6c239bbf2bb1/go.mod h1:9lPAdzaE google.golang.org/genproto 
v0.0.0-20210602131652-f16073e35f0c/go.mod h1:UODoCrxHCcBojKKwX1terBiRUaqAsFqJiF615XL43r0= google.golang.org/genproto/googleapis/api v0.0.0-20240509183442-62759503f434 h1:OpXbo8JnN8+jZGPrL4SSfaDjSCjupr8lXyBAbexEm/U= google.golang.org/genproto/googleapis/api v0.0.0-20240509183442-62759503f434/go.mod h1:FfiGhwUm6CJviekPrc0oJ+7h29e+DmWU6UtjX0ZvI7Y= -google.golang.org/genproto/googleapis/rpc v0.0.0-20240515191416-fc5f0ca64291 h1:AgADTJarZTBqgjiUzRgfaBchgYB3/WFTC80GPwsMcRI= -google.golang.org/genproto/googleapis/rpc v0.0.0-20240515191416-fc5f0ca64291/go.mod h1:EfXuqaE1J41VCDicxHzUDm+8rk+7ZdXzHV0IhO/I6s0= +google.golang.org/genproto/googleapis/rpc v0.0.0-20240903143218-8af14fe29dc1 h1:pPJltXNxVzT4pK9yD8vR9X75DaWYYmLGMsEvBfFQZzQ= +google.golang.org/genproto/googleapis/rpc v0.0.0-20240903143218-8af14fe29dc1/go.mod h1:UqMtugtsSgubUsoxbuAoiCXvqvErP7Gf0so0mK9tHxU= google.golang.org/grpc v1.19.0/go.mod h1:mqu4LbDTu4XGKhr4mRzUsmM4RtVoemTSY81AxZiDr8c= google.golang.org/grpc v1.20.1/go.mod h1:10oTOabMzJvdu6/UiuZezV6QK5dSlG84ov/aaiqXj38= google.golang.org/grpc v1.21.1/go.mod h1:oYelfM1adQP15Ek0mdvEgi9Df8B9CZIaU1084ijfRaM= @@ -1189,8 +1189,8 @@ google.golang.org/protobuf v1.26.0-rc.1/go.mod h1:jlhhOSvTdKEhbULTjvd4ARK9grFBp0 google.golang.org/protobuf v1.26.0/go.mod h1:9q0QmTI4eRPtz6boOQmLYwt+qCgq0jsYwAQnmE0givc= google.golang.org/protobuf v1.28.0/go.mod h1:HV8QOd/L58Z+nl8r43ehVNZIU/HEI6OcFqwMG9pJV4I= google.golang.org/protobuf v1.30.0/go.mod h1:HV8QOd/L58Z+nl8r43ehVNZIU/HEI6OcFqwMG9pJV4I= -google.golang.org/protobuf v1.34.1 h1:9ddQBjfCyZPOHPUiPxpYESBLc+T8P3E+Vo4IbKZgFWg= -google.golang.org/protobuf v1.34.1/go.mod h1:c6P6GXX6sHbq/GpV6MGZEdwhWPcYBgnhAHhKbcUYpos= +google.golang.org/protobuf v1.34.2 h1:6xV6lTsCfpGD21XK49h7MhtcApnLqkfYgPcdHftf6hg= +google.golang.org/protobuf v1.34.2/go.mod h1:qYOHts0dSfpeUzUFpOMr/WGzszTmLH+DiWniOlNbLDw= gopkg.in/check.v1 v0.0.0-20161208181325-20d25e280405/go.mod h1:Co6ibVJAznAaIkqp8huTwlJQCZ016jof/cbN4VW5Yz0= gopkg.in/check.v1 v1.0.0-20180628173108-788fd7840127/go.mod h1:Co6ibVJAznAaIkqp8huTwlJQCZ016jof/cbN4VW5Yz0= gopkg.in/check.v1 v1.0.0-20190902080502-41f04d3bba15/go.mod h1:Co6ibVJAznAaIkqp8huTwlJQCZ016jof/cbN4VW5Yz0= From 8866394eb6879aa94d5fbd259b64830fdca45569 Mon Sep 17 00:00:00 2001 From: Viktor Liu <17948409+lixmal@users.noreply.github.com> Date: Tue, 3 Dec 2024 15:33:41 +0100 Subject: [PATCH 065/159] [client] Don't choke on non-existent interface in route updates (#2922) --- client/firewall/nftables/state.go | 1 - .../internal/routemanager/systemops/systemops_windows.go | 9 ++++++--- 2 files changed, 6 insertions(+), 4 deletions(-) delete mode 100644 client/firewall/nftables/state.go diff --git a/client/firewall/nftables/state.go b/client/firewall/nftables/state.go deleted file mode 100644 index 7027fe987..000000000 --- a/client/firewall/nftables/state.go +++ /dev/null @@ -1 +0,0 @@ -package nftables diff --git a/client/internal/routemanager/systemops/systemops_windows.go b/client/internal/routemanager/systemops/systemops_windows.go index b1732a080..ad325e123 100644 --- a/client/internal/routemanager/systemops/systemops_windows.go +++ b/client/internal/routemanager/systemops/systemops_windows.go @@ -230,10 +230,13 @@ func (rm *RouteMonitor) parseUpdate(row *MIB_IPFORWARD_ROW2, notificationType MI if idx != 0 { intf, err := net.InterfaceByIndex(idx) if err != nil { - return update, fmt.Errorf("get interface name: %w", err) + log.Warnf("failed to get interface name for index %d: %v", idx, err) + update.Interface = &net.Interface{ + Index: idx, + } + } else { + 
update.Interface = intf } - - update.Interface = intf } log.Tracef("Received route update with destination %v, next hop %v, interface %v", row.DestinationPrefix, row.NextHop, update.Interface) From e5d42bc96330b8b10e7704c0f5a80fbcf960cb55 Mon Sep 17 00:00:00 2001 From: Viktor Liu <17948409+lixmal@users.noreply.github.com> Date: Tue, 3 Dec 2024 16:07:18 +0100 Subject: [PATCH 066/159] [client] Add state handling cmdline options (#2821) --- client/cmd/state.go | 181 ++++++ client/internal/connect.go | 13 + client/internal/statemanager/manager.go | 228 ++++++-- client/proto/daemon.pb.go | 726 ++++++++++++++++++++---- client/proto/daemon.proto | 45 ++ client/proto/daemon_grpc.pb.go | 114 ++++ client/server/state.go | 103 +++- 7 files changed, 1237 insertions(+), 173 deletions(-) create mode 100644 client/cmd/state.go diff --git a/client/cmd/state.go b/client/cmd/state.go new file mode 100644 index 000000000..21a5508f4 --- /dev/null +++ b/client/cmd/state.go @@ -0,0 +1,181 @@ +package cmd + +import ( + "fmt" + + log "github.com/sirupsen/logrus" + "github.com/spf13/cobra" + "google.golang.org/grpc/status" + + "github.com/netbirdio/netbird/client/proto" +) + +var ( + allFlag bool +) + +var stateCmd = &cobra.Command{ + Use: "state", + Short: "Manage daemon state", + Long: "Provides commands for managing and inspecting the Netbird daemon state.", +} + +var stateListCmd = &cobra.Command{ + Use: "list", + Aliases: []string{"ls"}, + Short: "List all stored states", + Long: "Lists all registered states with their status and basic information.", + Example: " netbird state list", + RunE: stateList, +} + +var stateCleanCmd = &cobra.Command{ + Use: "clean [state-name]", + Short: "Clean stored states", + Long: `Clean specific state or all states. The daemon must not be running. +This will perform cleanup operations and remove the state.`, + Example: ` netbird state clean dns_state + netbird state clean --all`, + RunE: stateClean, + PreRunE: func(cmd *cobra.Command, args []string) error { + // Check mutual exclusivity between --all flag and state-name argument + if allFlag && len(args) > 0 { + return fmt.Errorf("cannot specify both --all flag and state name") + } + if !allFlag && len(args) != 1 { + return fmt.Errorf("requires a state name argument or --all flag") + } + return nil + }, +} + +var stateDeleteCmd = &cobra.Command{ + Use: "delete [state-name]", + Short: "Delete stored states", + Long: `Delete specific state or all states from storage. The daemon must not be running. 
+This will remove the state without performing any cleanup operations.`, + Example: ` netbird state delete dns_state + netbird state delete --all`, + RunE: stateDelete, + PreRunE: func(cmd *cobra.Command, args []string) error { + // Check mutual exclusivity between --all flag and state-name argument + if allFlag && len(args) > 0 { + return fmt.Errorf("cannot specify both --all flag and state name") + } + if !allFlag && len(args) != 1 { + return fmt.Errorf("requires a state name argument or --all flag") + } + return nil + }, +} + +func init() { + rootCmd.AddCommand(stateCmd) + stateCmd.AddCommand(stateListCmd, stateCleanCmd, stateDeleteCmd) + + stateCleanCmd.Flags().BoolVarP(&allFlag, "all", "a", false, "Clean all states") + stateDeleteCmd.Flags().BoolVarP(&allFlag, "all", "a", false, "Delete all states") +} + +func stateList(cmd *cobra.Command, _ []string) error { + conn, err := getClient(cmd) + if err != nil { + return err + } + defer func() { + if err := conn.Close(); err != nil { + log.Errorf(errCloseConnection, err) + } + }() + + client := proto.NewDaemonServiceClient(conn) + resp, err := client.ListStates(cmd.Context(), &proto.ListStatesRequest{}) + if err != nil { + return fmt.Errorf("failed to list states: %v", status.Convert(err).Message()) + } + + cmd.Printf("\nStored states:\n\n") + for _, state := range resp.States { + cmd.Printf("- %s\n", state.Name) + } + + return nil +} + +func stateClean(cmd *cobra.Command, args []string) error { + var stateName string + if !allFlag { + stateName = args[0] + } + + conn, err := getClient(cmd) + if err != nil { + return err + } + defer func() { + if err := conn.Close(); err != nil { + log.Errorf(errCloseConnection, err) + } + }() + + client := proto.NewDaemonServiceClient(conn) + resp, err := client.CleanState(cmd.Context(), &proto.CleanStateRequest{ + StateName: stateName, + All: allFlag, + }) + if err != nil { + return fmt.Errorf("failed to clean state: %v", status.Convert(err).Message()) + } + + if resp.CleanedStates == 0 { + cmd.Println("No states were cleaned") + return nil + } + + if allFlag { + cmd.Printf("Successfully cleaned %d states\n", resp.CleanedStates) + } else { + cmd.Printf("Successfully cleaned state %q\n", stateName) + } + + return nil +} + +func stateDelete(cmd *cobra.Command, args []string) error { + var stateName string + if !allFlag { + stateName = args[0] + } + + conn, err := getClient(cmd) + if err != nil { + return err + } + defer func() { + if err := conn.Close(); err != nil { + log.Errorf(errCloseConnection, err) + } + }() + + client := proto.NewDaemonServiceClient(conn) + resp, err := client.DeleteState(cmd.Context(), &proto.DeleteStateRequest{ + StateName: stateName, + All: allFlag, + }) + if err != nil { + return fmt.Errorf("failed to delete state: %v", status.Convert(err).Message()) + } + + if resp.DeletedStates == 0 { + cmd.Println("No states were deleted") + return nil + } + + if allFlag { + cmd.Printf("Successfully deleted %d states\n", resp.DeletedStates) + } else { + cmd.Printf("Successfully deleted state %q\n", stateName) + } + + return nil +} diff --git a/client/internal/connect.go b/client/internal/connect.go index 21f6d3ec8..4848b1c11 100644 --- a/client/internal/connect.go +++ b/client/internal/connect.go @@ -338,6 +338,19 @@ func (c *ConnectClient) Engine() *Engine { return e } +// Status returns the current client status +func (c *ConnectClient) Status() StatusType { + if c == nil { + return StatusIdle + } + status, err := CtxGetState(c.ctx).Status() + if err != nil { + return StatusIdle + } + + 
return status +} + func (c *ConnectClient) Stop() error { if c == nil { return nil diff --git a/client/internal/statemanager/manager.go b/client/internal/statemanager/manager.go index aae73b79f..9a99c76f1 100644 --- a/client/internal/statemanager/manager.go +++ b/client/internal/statemanager/manager.go @@ -19,6 +19,11 @@ import ( "github.com/netbirdio/netbird/util" ) +const ( + errStateNotRegistered = "state %s not registered" + errLoadStateFile = "load state file: %w" +) + // State interface defines the methods that all state types must implement type State interface { Name() string @@ -159,7 +164,7 @@ func (m *Manager) setState(name string, state State) error { defer m.mu.Unlock() if _, exists := m.states[name]; !exists { - return fmt.Errorf("state %s not registered", name) + return fmt.Errorf(errStateNotRegistered, name) } m.states[name] = state @@ -168,6 +173,63 @@ func (m *Manager) setState(name string, state State) error { return nil } +// DeleteStateByName handles deletion of states without cleanup. +// It doesn't require the state to be registered. +func (m *Manager) DeleteStateByName(stateName string) error { + if m == nil { + return nil + } + + m.mu.Lock() + defer m.mu.Unlock() + + rawStates, err := m.loadStateFile(false) + if err != nil { + return fmt.Errorf(errLoadStateFile, err) + } + if rawStates == nil { + return nil + } + + if _, exists := rawStates[stateName]; !exists { + return fmt.Errorf("state %s not found", stateName) + } + + // Mark state as deleted by setting it to nil and marking it dirty + m.states[stateName] = nil + m.dirty[stateName] = struct{}{} + + return nil +} + +// DeleteAllStates removes all states. +func (m *Manager) DeleteAllStates() (int, error) { + if m == nil { + return 0, nil + } + + m.mu.Lock() + defer m.mu.Unlock() + + rawStates, err := m.loadStateFile(false) + if err != nil { + return 0, fmt.Errorf(errLoadStateFile, err) + } + if rawStates == nil { + return 0, nil + } + + count := len(rawStates) + + // Mark all states as deleted and dirty + for name := range rawStates { + m.states[name] = nil + m.dirty[name] = struct{}{} + } + + return count, nil +} + func (m *Manager) periodicStateSave(ctx context.Context) { ticker := time.NewTicker(10 * time.Second) defer ticker.Stop() @@ -221,7 +283,7 @@ func (m *Manager) PersistState(ctx context.Context) error { } } - log.Debugf("persisted shutdown states: %v, took %v", maps.Keys(m.dirty), time.Since(start)) + log.Debugf("persisted states: %v, took %v", maps.Keys(m.dirty), time.Since(start)) clear(m.dirty) @@ -229,7 +291,7 @@ func (m *Manager) PersistState(ctx context.Context) error { } // loadStateFile reads and unmarshals the state file into a map of raw JSON messages -func (m *Manager) loadStateFile() (map[string]json.RawMessage, error) { +func (m *Manager) loadStateFile(deleteCorrupt bool) (map[string]json.RawMessage, error) { data, err := os.ReadFile(m.filePath) if err != nil { if errors.Is(err, fs.ErrNotExist) { @@ -241,11 +303,13 @@ func (m *Manager) loadStateFile() (map[string]json.RawMessage, error) { var rawStates map[string]json.RawMessage if err := json.Unmarshal(data, &rawStates); err != nil { - log.Warn("State file appears to be corrupted, attempting to delete it") - if err := os.Remove(m.filePath); err != nil { - log.Errorf("Failed to delete corrupted state file: %v", err) - } else { - log.Info("State file deleted") + if deleteCorrupt { + log.Warn("State file appears to be corrupted, attempting to delete it", err) + if err := os.Remove(m.filePath); err != nil { + log.Errorf("Failed to delete 
corrupted state file: %v", err) + } else { + log.Info("State file deleted") + } } return nil, fmt.Errorf("unmarshal states: %w", err) } @@ -257,7 +321,7 @@ func (m *Manager) loadStateFile() (map[string]json.RawMessage, error) { func (m *Manager) loadSingleRawState(name string, rawState json.RawMessage) (State, error) { stateType, ok := m.stateTypes[name] if !ok { - return nil, fmt.Errorf("state %s not registered", name) + return nil, fmt.Errorf(errStateNotRegistered, name) } if string(rawState) == "null" { @@ -281,7 +345,7 @@ func (m *Manager) LoadState(state State) error { m.mu.Lock() defer m.mu.Unlock() - rawStates, err := m.loadStateFile() + rawStates, err := m.loadStateFile(false) if err != nil { return err } @@ -308,6 +372,84 @@ func (m *Manager) LoadState(state State) error { return nil } +// cleanupSingleState handles the cleanup of a specific state and returns any error. +// The caller must hold the mutex. +func (m *Manager) cleanupSingleState(name string, rawState json.RawMessage) error { + // For unregistered states, preserve the raw JSON + if _, registered := m.stateTypes[name]; !registered { + m.states[name] = &RawState{data: rawState} + return nil + } + + // Load the state + loadedState, err := m.loadSingleRawState(name, rawState) + if err != nil { + return err + } + + if loadedState == nil { + return nil + } + + // Check if state supports cleanup + cleanableState, isCleanable := loadedState.(CleanableState) + if !isCleanable { + // If it doesn't support cleanup, keep it as-is + m.states[name] = loadedState + return nil + } + + // Perform cleanup + log.Infof("cleaning up state %s", name) + if err := cleanableState.Cleanup(); err != nil { + // On cleanup error, preserve the state + m.states[name] = loadedState + return fmt.Errorf("cleanup state: %w", err) + } + + // Successfully cleaned up - mark for deletion + m.states[name] = nil + m.dirty[name] = struct{}{} + return nil +} + +// CleanupStateByName loads and cleans up a specific state by name if it implements CleanableState. +// Returns an error if the state doesn't exist, isn't registered, or cleanup fails. +func (m *Manager) CleanupStateByName(name string) error { + if m == nil { + return nil + } + + m.mu.Lock() + defer m.mu.Unlock() + + // Check if state is registered + if _, registered := m.stateTypes[name]; !registered { + return fmt.Errorf(errStateNotRegistered, name) + } + + // Load raw states from file + rawStates, err := m.loadStateFile(false) + if err != nil { + return err + } + if rawStates == nil { + return nil + } + + // Check if state exists in file + rawState, exists := rawStates[name] + if !exists { + return nil + } + + if err := m.cleanupSingleState(name, rawState); err != nil { + return fmt.Errorf("%s: %w", name, err) + } + + return nil +} + // PerformCleanup retrieves all states from the state file and calls Cleanup on registered states that support it. // Unregistered states are preserved in their original state. 
func (m *Manager) PerformCleanup() error { @@ -319,10 +461,9 @@ func (m *Manager) PerformCleanup() error { defer m.mu.Unlock() // Load raw states from file - rawStates, err := m.loadStateFile() + rawStates, err := m.loadStateFile(true) if err != nil { - log.Warnf("Failed to load state during cleanup: %v", err) - return err + return fmt.Errorf(errLoadStateFile, err) } if rawStates == nil { return nil @@ -332,47 +473,38 @@ func (m *Manager) PerformCleanup() error { // Process each state in the file for name, rawState := range rawStates { - // For unregistered states, preserve the raw JSON - if _, registered := m.stateTypes[name]; !registered { - m.states[name] = &RawState{data: rawState} - continue - } - - // Load the registered state - loadedState, err := m.loadSingleRawState(name, rawState) - if err != nil { - merr = multierror.Append(merr, err) - continue - } - - if loadedState == nil { - continue - } - - // Check if state supports cleanup - cleanableState, isCleanable := loadedState.(CleanableState) - if !isCleanable { - // If it doesn't support cleanup, keep it as-is - m.states[name] = loadedState - continue - } - - // Perform cleanup for cleanable states - log.Infof("client was not shut down properly, cleaning up %s", name) - if err := cleanableState.Cleanup(); err != nil { - merr = multierror.Append(merr, fmt.Errorf("cleanup state for %s: %w", name, err)) - // On cleanup error, preserve the state - m.states[name] = loadedState - } else { - // Successfully cleaned up - mark for deletion - m.states[name] = nil - m.dirty[name] = struct{}{} + if err := m.cleanupSingleState(name, rawState); err != nil { + merr = multierror.Append(merr, fmt.Errorf("%s: %w", name, err)) } } return nberrors.FormatErrorOrNil(merr) } +// GetSavedStateNames returns all state names that are currently saved in the state file. 
+func (m *Manager) GetSavedStateNames() ([]string, error) { + if m == nil { + return nil, nil + } + + rawStates, err := m.loadStateFile(false) + if err != nil { + return nil, fmt.Errorf(errLoadStateFile, err) + } + if rawStates == nil { + return nil, nil + } + + var states []string + for name, state := range rawStates { + if len(state) != 0 && string(state) != "null" { + states = append(states, name) + } + } + + return states, nil +} + func marshalWithPanicRecovery(v any) ([]byte, error) { var bs []byte var err error diff --git a/client/proto/daemon.pb.go b/client/proto/daemon.pb.go index 0eae4e0d6..98ce2c4a2 100644 --- a/client/proto/daemon.pb.go +++ b/client/proto/daemon.pb.go @@ -2103,6 +2103,349 @@ func (*SetLogLevelResponse) Descriptor() ([]byte, []int) { return file_daemon_proto_rawDescGZIP(), []int{30} } +// State represents a daemon state entry +type State struct { + state protoimpl.MessageState + sizeCache protoimpl.SizeCache + unknownFields protoimpl.UnknownFields + + Name string `protobuf:"bytes,1,opt,name=name,proto3" json:"name,omitempty"` +} + +func (x *State) Reset() { + *x = State{} + if protoimpl.UnsafeEnabled { + mi := &file_daemon_proto_msgTypes[31] + ms := protoimpl.X.MessageStateOf(protoimpl.Pointer(x)) + ms.StoreMessageInfo(mi) + } +} + +func (x *State) String() string { + return protoimpl.X.MessageStringOf(x) +} + +func (*State) ProtoMessage() {} + +func (x *State) ProtoReflect() protoreflect.Message { + mi := &file_daemon_proto_msgTypes[31] + if protoimpl.UnsafeEnabled && x != nil { + ms := protoimpl.X.MessageStateOf(protoimpl.Pointer(x)) + if ms.LoadMessageInfo() == nil { + ms.StoreMessageInfo(mi) + } + return ms + } + return mi.MessageOf(x) +} + +// Deprecated: Use State.ProtoReflect.Descriptor instead. +func (*State) Descriptor() ([]byte, []int) { + return file_daemon_proto_rawDescGZIP(), []int{31} +} + +func (x *State) GetName() string { + if x != nil { + return x.Name + } + return "" +} + +// ListStatesRequest is empty as it requires no parameters +type ListStatesRequest struct { + state protoimpl.MessageState + sizeCache protoimpl.SizeCache + unknownFields protoimpl.UnknownFields +} + +func (x *ListStatesRequest) Reset() { + *x = ListStatesRequest{} + if protoimpl.UnsafeEnabled { + mi := &file_daemon_proto_msgTypes[32] + ms := protoimpl.X.MessageStateOf(protoimpl.Pointer(x)) + ms.StoreMessageInfo(mi) + } +} + +func (x *ListStatesRequest) String() string { + return protoimpl.X.MessageStringOf(x) +} + +func (*ListStatesRequest) ProtoMessage() {} + +func (x *ListStatesRequest) ProtoReflect() protoreflect.Message { + mi := &file_daemon_proto_msgTypes[32] + if protoimpl.UnsafeEnabled && x != nil { + ms := protoimpl.X.MessageStateOf(protoimpl.Pointer(x)) + if ms.LoadMessageInfo() == nil { + ms.StoreMessageInfo(mi) + } + return ms + } + return mi.MessageOf(x) +} + +// Deprecated: Use ListStatesRequest.ProtoReflect.Descriptor instead. 
+func (*ListStatesRequest) Descriptor() ([]byte, []int) { + return file_daemon_proto_rawDescGZIP(), []int{32} +} + +// ListStatesResponse contains a list of states +type ListStatesResponse struct { + state protoimpl.MessageState + sizeCache protoimpl.SizeCache + unknownFields protoimpl.UnknownFields + + States []*State `protobuf:"bytes,1,rep,name=states,proto3" json:"states,omitempty"` +} + +func (x *ListStatesResponse) Reset() { + *x = ListStatesResponse{} + if protoimpl.UnsafeEnabled { + mi := &file_daemon_proto_msgTypes[33] + ms := protoimpl.X.MessageStateOf(protoimpl.Pointer(x)) + ms.StoreMessageInfo(mi) + } +} + +func (x *ListStatesResponse) String() string { + return protoimpl.X.MessageStringOf(x) +} + +func (*ListStatesResponse) ProtoMessage() {} + +func (x *ListStatesResponse) ProtoReflect() protoreflect.Message { + mi := &file_daemon_proto_msgTypes[33] + if protoimpl.UnsafeEnabled && x != nil { + ms := protoimpl.X.MessageStateOf(protoimpl.Pointer(x)) + if ms.LoadMessageInfo() == nil { + ms.StoreMessageInfo(mi) + } + return ms + } + return mi.MessageOf(x) +} + +// Deprecated: Use ListStatesResponse.ProtoReflect.Descriptor instead. +func (*ListStatesResponse) Descriptor() ([]byte, []int) { + return file_daemon_proto_rawDescGZIP(), []int{33} +} + +func (x *ListStatesResponse) GetStates() []*State { + if x != nil { + return x.States + } + return nil +} + +// CleanStateRequest for cleaning states +type CleanStateRequest struct { + state protoimpl.MessageState + sizeCache protoimpl.SizeCache + unknownFields protoimpl.UnknownFields + + StateName string `protobuf:"bytes,1,opt,name=state_name,json=stateName,proto3" json:"state_name,omitempty"` + All bool `protobuf:"varint,2,opt,name=all,proto3" json:"all,omitempty"` +} + +func (x *CleanStateRequest) Reset() { + *x = CleanStateRequest{} + if protoimpl.UnsafeEnabled { + mi := &file_daemon_proto_msgTypes[34] + ms := protoimpl.X.MessageStateOf(protoimpl.Pointer(x)) + ms.StoreMessageInfo(mi) + } +} + +func (x *CleanStateRequest) String() string { + return protoimpl.X.MessageStringOf(x) +} + +func (*CleanStateRequest) ProtoMessage() {} + +func (x *CleanStateRequest) ProtoReflect() protoreflect.Message { + mi := &file_daemon_proto_msgTypes[34] + if protoimpl.UnsafeEnabled && x != nil { + ms := protoimpl.X.MessageStateOf(protoimpl.Pointer(x)) + if ms.LoadMessageInfo() == nil { + ms.StoreMessageInfo(mi) + } + return ms + } + return mi.MessageOf(x) +} + +// Deprecated: Use CleanStateRequest.ProtoReflect.Descriptor instead. 
+func (*CleanStateRequest) Descriptor() ([]byte, []int) { + return file_daemon_proto_rawDescGZIP(), []int{34} +} + +func (x *CleanStateRequest) GetStateName() string { + if x != nil { + return x.StateName + } + return "" +} + +func (x *CleanStateRequest) GetAll() bool { + if x != nil { + return x.All + } + return false +} + +// CleanStateResponse contains the result of the clean operation +type CleanStateResponse struct { + state protoimpl.MessageState + sizeCache protoimpl.SizeCache + unknownFields protoimpl.UnknownFields + + CleanedStates int32 `protobuf:"varint,1,opt,name=cleaned_states,json=cleanedStates,proto3" json:"cleaned_states,omitempty"` +} + +func (x *CleanStateResponse) Reset() { + *x = CleanStateResponse{} + if protoimpl.UnsafeEnabled { + mi := &file_daemon_proto_msgTypes[35] + ms := protoimpl.X.MessageStateOf(protoimpl.Pointer(x)) + ms.StoreMessageInfo(mi) + } +} + +func (x *CleanStateResponse) String() string { + return protoimpl.X.MessageStringOf(x) +} + +func (*CleanStateResponse) ProtoMessage() {} + +func (x *CleanStateResponse) ProtoReflect() protoreflect.Message { + mi := &file_daemon_proto_msgTypes[35] + if protoimpl.UnsafeEnabled && x != nil { + ms := protoimpl.X.MessageStateOf(protoimpl.Pointer(x)) + if ms.LoadMessageInfo() == nil { + ms.StoreMessageInfo(mi) + } + return ms + } + return mi.MessageOf(x) +} + +// Deprecated: Use CleanStateResponse.ProtoReflect.Descriptor instead. +func (*CleanStateResponse) Descriptor() ([]byte, []int) { + return file_daemon_proto_rawDescGZIP(), []int{35} +} + +func (x *CleanStateResponse) GetCleanedStates() int32 { + if x != nil { + return x.CleanedStates + } + return 0 +} + +// DeleteStateRequest for deleting states +type DeleteStateRequest struct { + state protoimpl.MessageState + sizeCache protoimpl.SizeCache + unknownFields protoimpl.UnknownFields + + StateName string `protobuf:"bytes,1,opt,name=state_name,json=stateName,proto3" json:"state_name,omitempty"` + All bool `protobuf:"varint,2,opt,name=all,proto3" json:"all,omitempty"` +} + +func (x *DeleteStateRequest) Reset() { + *x = DeleteStateRequest{} + if protoimpl.UnsafeEnabled { + mi := &file_daemon_proto_msgTypes[36] + ms := protoimpl.X.MessageStateOf(protoimpl.Pointer(x)) + ms.StoreMessageInfo(mi) + } +} + +func (x *DeleteStateRequest) String() string { + return protoimpl.X.MessageStringOf(x) +} + +func (*DeleteStateRequest) ProtoMessage() {} + +func (x *DeleteStateRequest) ProtoReflect() protoreflect.Message { + mi := &file_daemon_proto_msgTypes[36] + if protoimpl.UnsafeEnabled && x != nil { + ms := protoimpl.X.MessageStateOf(protoimpl.Pointer(x)) + if ms.LoadMessageInfo() == nil { + ms.StoreMessageInfo(mi) + } + return ms + } + return mi.MessageOf(x) +} + +// Deprecated: Use DeleteStateRequest.ProtoReflect.Descriptor instead. 
+func (*DeleteStateRequest) Descriptor() ([]byte, []int) { + return file_daemon_proto_rawDescGZIP(), []int{36} +} + +func (x *DeleteStateRequest) GetStateName() string { + if x != nil { + return x.StateName + } + return "" +} + +func (x *DeleteStateRequest) GetAll() bool { + if x != nil { + return x.All + } + return false +} + +// DeleteStateResponse contains the result of the delete operation +type DeleteStateResponse struct { + state protoimpl.MessageState + sizeCache protoimpl.SizeCache + unknownFields protoimpl.UnknownFields + + DeletedStates int32 `protobuf:"varint,1,opt,name=deleted_states,json=deletedStates,proto3" json:"deleted_states,omitempty"` +} + +func (x *DeleteStateResponse) Reset() { + *x = DeleteStateResponse{} + if protoimpl.UnsafeEnabled { + mi := &file_daemon_proto_msgTypes[37] + ms := protoimpl.X.MessageStateOf(protoimpl.Pointer(x)) + ms.StoreMessageInfo(mi) + } +} + +func (x *DeleteStateResponse) String() string { + return protoimpl.X.MessageStringOf(x) +} + +func (*DeleteStateResponse) ProtoMessage() {} + +func (x *DeleteStateResponse) ProtoReflect() protoreflect.Message { + mi := &file_daemon_proto_msgTypes[37] + if protoimpl.UnsafeEnabled && x != nil { + ms := protoimpl.X.MessageStateOf(protoimpl.Pointer(x)) + if ms.LoadMessageInfo() == nil { + ms.StoreMessageInfo(mi) + } + return ms + } + return mi.MessageOf(x) +} + +// Deprecated: Use DeleteStateResponse.ProtoReflect.Descriptor instead. +func (*DeleteStateResponse) Descriptor() ([]byte, []int) { + return file_daemon_proto_rawDescGZIP(), []int{37} +} + +func (x *DeleteStateResponse) GetDeletedStates() int32 { + if x != nil { + return x.DeletedStates + } + return 0 +} + type SetNetworkMapPersistenceRequest struct { state protoimpl.MessageState sizeCache protoimpl.SizeCache @@ -2114,7 +2457,7 @@ type SetNetworkMapPersistenceRequest struct { func (x *SetNetworkMapPersistenceRequest) Reset() { *x = SetNetworkMapPersistenceRequest{} if protoimpl.UnsafeEnabled { - mi := &file_daemon_proto_msgTypes[31] + mi := &file_daemon_proto_msgTypes[38] ms := protoimpl.X.MessageStateOf(protoimpl.Pointer(x)) ms.StoreMessageInfo(mi) } @@ -2127,7 +2470,7 @@ func (x *SetNetworkMapPersistenceRequest) String() string { func (*SetNetworkMapPersistenceRequest) ProtoMessage() {} func (x *SetNetworkMapPersistenceRequest) ProtoReflect() protoreflect.Message { - mi := &file_daemon_proto_msgTypes[31] + mi := &file_daemon_proto_msgTypes[38] if protoimpl.UnsafeEnabled && x != nil { ms := protoimpl.X.MessageStateOf(protoimpl.Pointer(x)) if ms.LoadMessageInfo() == nil { @@ -2140,7 +2483,7 @@ func (x *SetNetworkMapPersistenceRequest) ProtoReflect() protoreflect.Message { // Deprecated: Use SetNetworkMapPersistenceRequest.ProtoReflect.Descriptor instead. 
func (*SetNetworkMapPersistenceRequest) Descriptor() ([]byte, []int) { - return file_daemon_proto_rawDescGZIP(), []int{31} + return file_daemon_proto_rawDescGZIP(), []int{38} } func (x *SetNetworkMapPersistenceRequest) GetEnabled() bool { @@ -2159,7 +2502,7 @@ type SetNetworkMapPersistenceResponse struct { func (x *SetNetworkMapPersistenceResponse) Reset() { *x = SetNetworkMapPersistenceResponse{} if protoimpl.UnsafeEnabled { - mi := &file_daemon_proto_msgTypes[32] + mi := &file_daemon_proto_msgTypes[39] ms := protoimpl.X.MessageStateOf(protoimpl.Pointer(x)) ms.StoreMessageInfo(mi) } @@ -2172,7 +2515,7 @@ func (x *SetNetworkMapPersistenceResponse) String() string { func (*SetNetworkMapPersistenceResponse) ProtoMessage() {} func (x *SetNetworkMapPersistenceResponse) ProtoReflect() protoreflect.Message { - mi := &file_daemon_proto_msgTypes[32] + mi := &file_daemon_proto_msgTypes[39] if protoimpl.UnsafeEnabled && x != nil { ms := protoimpl.X.MessageStateOf(protoimpl.Pointer(x)) if ms.LoadMessageInfo() == nil { @@ -2185,7 +2528,7 @@ func (x *SetNetworkMapPersistenceResponse) ProtoReflect() protoreflect.Message { // Deprecated: Use SetNetworkMapPersistenceResponse.ProtoReflect.Descriptor instead. func (*SetNetworkMapPersistenceResponse) Descriptor() ([]byte, []int) { - return file_daemon_proto_rawDescGZIP(), []int{32} + return file_daemon_proto_rawDescGZIP(), []int{39} } var File_daemon_proto protoreflect.FileDescriptor @@ -2484,79 +2827,116 @@ var file_daemon_proto_rawDesc = []byte{ 0x32, 0x10, 0x2e, 0x64, 0x61, 0x65, 0x6d, 0x6f, 0x6e, 0x2e, 0x4c, 0x6f, 0x67, 0x4c, 0x65, 0x76, 0x65, 0x6c, 0x52, 0x05, 0x6c, 0x65, 0x76, 0x65, 0x6c, 0x22, 0x15, 0x0a, 0x13, 0x53, 0x65, 0x74, 0x4c, 0x6f, 0x67, 0x4c, 0x65, 0x76, 0x65, 0x6c, 0x52, 0x65, 0x73, 0x70, 0x6f, 0x6e, 0x73, 0x65, - 0x22, 0x3b, 0x0a, 0x1f, 0x53, 0x65, 0x74, 0x4e, 0x65, 0x74, 0x77, 0x6f, 0x72, 0x6b, 0x4d, 0x61, - 0x70, 0x50, 0x65, 0x72, 0x73, 0x69, 0x73, 0x74, 0x65, 0x6e, 0x63, 0x65, 0x52, 0x65, 0x71, 0x75, - 0x65, 0x73, 0x74, 0x12, 0x18, 0x0a, 0x07, 0x65, 0x6e, 0x61, 0x62, 0x6c, 0x65, 0x64, 0x18, 0x01, - 0x20, 0x01, 0x28, 0x08, 0x52, 0x07, 0x65, 0x6e, 0x61, 0x62, 0x6c, 0x65, 0x64, 0x22, 0x22, 0x0a, - 0x20, 0x53, 0x65, 0x74, 0x4e, 0x65, 0x74, 0x77, 0x6f, 0x72, 0x6b, 0x4d, 0x61, 0x70, 0x50, 0x65, - 0x72, 0x73, 0x69, 0x73, 0x74, 0x65, 0x6e, 0x63, 0x65, 0x52, 0x65, 0x73, 0x70, 0x6f, 0x6e, 0x73, - 0x65, 0x2a, 0x62, 0x0a, 0x08, 0x4c, 0x6f, 0x67, 0x4c, 0x65, 0x76, 0x65, 0x6c, 0x12, 0x0b, 0x0a, - 0x07, 0x55, 0x4e, 0x4b, 0x4e, 0x4f, 0x57, 0x4e, 0x10, 0x00, 0x12, 0x09, 0x0a, 0x05, 0x50, 0x41, - 0x4e, 0x49, 0x43, 0x10, 0x01, 0x12, 0x09, 0x0a, 0x05, 0x46, 0x41, 0x54, 0x41, 0x4c, 0x10, 0x02, - 0x12, 0x09, 0x0a, 0x05, 0x45, 0x52, 0x52, 0x4f, 0x52, 0x10, 0x03, 0x12, 0x08, 0x0a, 0x04, 0x57, - 0x41, 0x52, 0x4e, 0x10, 0x04, 0x12, 0x08, 0x0a, 0x04, 0x49, 0x4e, 0x46, 0x4f, 0x10, 0x05, 0x12, - 0x09, 0x0a, 0x05, 0x44, 0x45, 0x42, 0x55, 0x47, 0x10, 0x06, 0x12, 0x09, 0x0a, 0x05, 0x54, 0x52, - 0x41, 0x43, 0x45, 0x10, 0x07, 0x32, 0xa9, 0x07, 0x0a, 0x0d, 0x44, 0x61, 0x65, 0x6d, 0x6f, 0x6e, - 0x53, 0x65, 0x72, 0x76, 0x69, 0x63, 0x65, 0x12, 0x36, 0x0a, 0x05, 0x4c, 0x6f, 0x67, 0x69, 0x6e, - 0x12, 0x14, 0x2e, 0x64, 0x61, 0x65, 0x6d, 0x6f, 0x6e, 0x2e, 0x4c, 0x6f, 0x67, 0x69, 0x6e, 0x52, - 0x65, 0x71, 0x75, 0x65, 0x73, 0x74, 0x1a, 0x15, 0x2e, 0x64, 0x61, 0x65, 0x6d, 0x6f, 0x6e, 0x2e, - 0x4c, 0x6f, 0x67, 0x69, 0x6e, 0x52, 0x65, 0x73, 0x70, 0x6f, 0x6e, 0x73, 0x65, 0x22, 0x00, 0x12, - 0x4b, 0x0a, 0x0c, 0x57, 0x61, 0x69, 0x74, 0x53, 0x53, 0x4f, 0x4c, 0x6f, 0x67, 0x69, 0x6e, 
0x12, - 0x1b, 0x2e, 0x64, 0x61, 0x65, 0x6d, 0x6f, 0x6e, 0x2e, 0x57, 0x61, 0x69, 0x74, 0x53, 0x53, 0x4f, - 0x4c, 0x6f, 0x67, 0x69, 0x6e, 0x52, 0x65, 0x71, 0x75, 0x65, 0x73, 0x74, 0x1a, 0x1c, 0x2e, 0x64, - 0x61, 0x65, 0x6d, 0x6f, 0x6e, 0x2e, 0x57, 0x61, 0x69, 0x74, 0x53, 0x53, 0x4f, 0x4c, 0x6f, 0x67, - 0x69, 0x6e, 0x52, 0x65, 0x73, 0x70, 0x6f, 0x6e, 0x73, 0x65, 0x22, 0x00, 0x12, 0x2d, 0x0a, 0x02, - 0x55, 0x70, 0x12, 0x11, 0x2e, 0x64, 0x61, 0x65, 0x6d, 0x6f, 0x6e, 0x2e, 0x55, 0x70, 0x52, 0x65, - 0x71, 0x75, 0x65, 0x73, 0x74, 0x1a, 0x12, 0x2e, 0x64, 0x61, 0x65, 0x6d, 0x6f, 0x6e, 0x2e, 0x55, - 0x70, 0x52, 0x65, 0x73, 0x70, 0x6f, 0x6e, 0x73, 0x65, 0x22, 0x00, 0x12, 0x39, 0x0a, 0x06, 0x53, - 0x74, 0x61, 0x74, 0x75, 0x73, 0x12, 0x15, 0x2e, 0x64, 0x61, 0x65, 0x6d, 0x6f, 0x6e, 0x2e, 0x53, - 0x74, 0x61, 0x74, 0x75, 0x73, 0x52, 0x65, 0x71, 0x75, 0x65, 0x73, 0x74, 0x1a, 0x16, 0x2e, 0x64, - 0x61, 0x65, 0x6d, 0x6f, 0x6e, 0x2e, 0x53, 0x74, 0x61, 0x74, 0x75, 0x73, 0x52, 0x65, 0x73, 0x70, - 0x6f, 0x6e, 0x73, 0x65, 0x22, 0x00, 0x12, 0x33, 0x0a, 0x04, 0x44, 0x6f, 0x77, 0x6e, 0x12, 0x13, - 0x2e, 0x64, 0x61, 0x65, 0x6d, 0x6f, 0x6e, 0x2e, 0x44, 0x6f, 0x77, 0x6e, 0x52, 0x65, 0x71, 0x75, - 0x65, 0x73, 0x74, 0x1a, 0x14, 0x2e, 0x64, 0x61, 0x65, 0x6d, 0x6f, 0x6e, 0x2e, 0x44, 0x6f, 0x77, - 0x6e, 0x52, 0x65, 0x73, 0x70, 0x6f, 0x6e, 0x73, 0x65, 0x22, 0x00, 0x12, 0x42, 0x0a, 0x09, 0x47, - 0x65, 0x74, 0x43, 0x6f, 0x6e, 0x66, 0x69, 0x67, 0x12, 0x18, 0x2e, 0x64, 0x61, 0x65, 0x6d, 0x6f, - 0x6e, 0x2e, 0x47, 0x65, 0x74, 0x43, 0x6f, 0x6e, 0x66, 0x69, 0x67, 0x52, 0x65, 0x71, 0x75, 0x65, - 0x73, 0x74, 0x1a, 0x19, 0x2e, 0x64, 0x61, 0x65, 0x6d, 0x6f, 0x6e, 0x2e, 0x47, 0x65, 0x74, 0x43, - 0x6f, 0x6e, 0x66, 0x69, 0x67, 0x52, 0x65, 0x73, 0x70, 0x6f, 0x6e, 0x73, 0x65, 0x22, 0x00, 0x12, - 0x45, 0x0a, 0x0a, 0x4c, 0x69, 0x73, 0x74, 0x52, 0x6f, 0x75, 0x74, 0x65, 0x73, 0x12, 0x19, 0x2e, - 0x64, 0x61, 0x65, 0x6d, 0x6f, 0x6e, 0x2e, 0x4c, 0x69, 0x73, 0x74, 0x52, 0x6f, 0x75, 0x74, 0x65, - 0x73, 0x52, 0x65, 0x71, 0x75, 0x65, 0x73, 0x74, 0x1a, 0x1a, 0x2e, 0x64, 0x61, 0x65, 0x6d, 0x6f, - 0x6e, 0x2e, 0x4c, 0x69, 0x73, 0x74, 0x52, 0x6f, 0x75, 0x74, 0x65, 0x73, 0x52, 0x65, 0x73, 0x70, - 0x6f, 0x6e, 0x73, 0x65, 0x22, 0x00, 0x12, 0x4b, 0x0a, 0x0c, 0x53, 0x65, 0x6c, 0x65, 0x63, 0x74, - 0x52, 0x6f, 0x75, 0x74, 0x65, 0x73, 0x12, 0x1b, 0x2e, 0x64, 0x61, 0x65, 0x6d, 0x6f, 0x6e, 0x2e, - 0x53, 0x65, 0x6c, 0x65, 0x63, 0x74, 0x52, 0x6f, 0x75, 0x74, 0x65, 0x73, 0x52, 0x65, 0x71, 0x75, - 0x65, 0x73, 0x74, 0x1a, 0x1c, 0x2e, 0x64, 0x61, 0x65, 0x6d, 0x6f, 0x6e, 0x2e, 0x53, 0x65, 0x6c, - 0x65, 0x63, 0x74, 0x52, 0x6f, 0x75, 0x74, 0x65, 0x73, 0x52, 0x65, 0x73, 0x70, 0x6f, 0x6e, 0x73, - 0x65, 0x22, 0x00, 0x12, 0x4d, 0x0a, 0x0e, 0x44, 0x65, 0x73, 0x65, 0x6c, 0x65, 0x63, 0x74, 0x52, - 0x6f, 0x75, 0x74, 0x65, 0x73, 0x12, 0x1b, 0x2e, 0x64, 0x61, 0x65, 0x6d, 0x6f, 0x6e, 0x2e, 0x53, - 0x65, 0x6c, 0x65, 0x63, 0x74, 0x52, 0x6f, 0x75, 0x74, 0x65, 0x73, 0x52, 0x65, 0x71, 0x75, 0x65, - 0x73, 0x74, 0x1a, 0x1c, 0x2e, 0x64, 0x61, 0x65, 0x6d, 0x6f, 0x6e, 0x2e, 0x53, 0x65, 0x6c, 0x65, - 0x63, 0x74, 0x52, 0x6f, 0x75, 0x74, 0x65, 0x73, 0x52, 0x65, 0x73, 0x70, 0x6f, 0x6e, 0x73, 0x65, - 0x22, 0x00, 0x12, 0x48, 0x0a, 0x0b, 0x44, 0x65, 0x62, 0x75, 0x67, 0x42, 0x75, 0x6e, 0x64, 0x6c, - 0x65, 0x12, 0x1a, 0x2e, 0x64, 0x61, 0x65, 0x6d, 0x6f, 0x6e, 0x2e, 0x44, 0x65, 0x62, 0x75, 0x67, - 0x42, 0x75, 0x6e, 0x64, 0x6c, 0x65, 0x52, 0x65, 0x71, 0x75, 0x65, 0x73, 0x74, 0x1a, 0x1b, 0x2e, - 0x64, 0x61, 0x65, 0x6d, 0x6f, 0x6e, 0x2e, 0x44, 0x65, 0x62, 0x75, 0x67, 0x42, 0x75, 0x6e, 0x64, - 0x6c, 0x65, 0x52, 
0x65, 0x73, 0x70, 0x6f, 0x6e, 0x73, 0x65, 0x22, 0x00, 0x12, 0x48, 0x0a, 0x0b, - 0x47, 0x65, 0x74, 0x4c, 0x6f, 0x67, 0x4c, 0x65, 0x76, 0x65, 0x6c, 0x12, 0x1a, 0x2e, 0x64, 0x61, - 0x65, 0x6d, 0x6f, 0x6e, 0x2e, 0x47, 0x65, 0x74, 0x4c, 0x6f, 0x67, 0x4c, 0x65, 0x76, 0x65, 0x6c, - 0x52, 0x65, 0x71, 0x75, 0x65, 0x73, 0x74, 0x1a, 0x1b, 0x2e, 0x64, 0x61, 0x65, 0x6d, 0x6f, 0x6e, - 0x2e, 0x47, 0x65, 0x74, 0x4c, 0x6f, 0x67, 0x4c, 0x65, 0x76, 0x65, 0x6c, 0x52, 0x65, 0x73, 0x70, - 0x6f, 0x6e, 0x73, 0x65, 0x22, 0x00, 0x12, 0x48, 0x0a, 0x0b, 0x53, 0x65, 0x74, 0x4c, 0x6f, 0x67, - 0x4c, 0x65, 0x76, 0x65, 0x6c, 0x12, 0x1a, 0x2e, 0x64, 0x61, 0x65, 0x6d, 0x6f, 0x6e, 0x2e, 0x53, - 0x65, 0x74, 0x4c, 0x6f, 0x67, 0x4c, 0x65, 0x76, 0x65, 0x6c, 0x52, 0x65, 0x71, 0x75, 0x65, 0x73, - 0x74, 0x1a, 0x1b, 0x2e, 0x64, 0x61, 0x65, 0x6d, 0x6f, 0x6e, 0x2e, 0x53, 0x65, 0x74, 0x4c, 0x6f, - 0x67, 0x4c, 0x65, 0x76, 0x65, 0x6c, 0x52, 0x65, 0x73, 0x70, 0x6f, 0x6e, 0x73, 0x65, 0x22, 0x00, - 0x12, 0x6f, 0x0a, 0x18, 0x53, 0x65, 0x74, 0x4e, 0x65, 0x74, 0x77, 0x6f, 0x72, 0x6b, 0x4d, 0x61, - 0x70, 0x50, 0x65, 0x72, 0x73, 0x69, 0x73, 0x74, 0x65, 0x6e, 0x63, 0x65, 0x12, 0x27, 0x2e, 0x64, - 0x61, 0x65, 0x6d, 0x6f, 0x6e, 0x2e, 0x53, 0x65, 0x74, 0x4e, 0x65, 0x74, 0x77, 0x6f, 0x72, 0x6b, - 0x4d, 0x61, 0x70, 0x50, 0x65, 0x72, 0x73, 0x69, 0x73, 0x74, 0x65, 0x6e, 0x63, 0x65, 0x52, 0x65, - 0x71, 0x75, 0x65, 0x73, 0x74, 0x1a, 0x28, 0x2e, 0x64, 0x61, 0x65, 0x6d, 0x6f, 0x6e, 0x2e, 0x53, - 0x65, 0x74, 0x4e, 0x65, 0x74, 0x77, 0x6f, 0x72, 0x6b, 0x4d, 0x61, 0x70, 0x50, 0x65, 0x72, 0x73, - 0x69, 0x73, 0x74, 0x65, 0x6e, 0x63, 0x65, 0x52, 0x65, 0x73, 0x70, 0x6f, 0x6e, 0x73, 0x65, 0x22, - 0x00, 0x42, 0x08, 0x5a, 0x06, 0x2f, 0x70, 0x72, 0x6f, 0x74, 0x6f, 0x62, 0x06, 0x70, 0x72, 0x6f, - 0x74, 0x6f, 0x33, + 0x22, 0x1b, 0x0a, 0x05, 0x53, 0x74, 0x61, 0x74, 0x65, 0x12, 0x12, 0x0a, 0x04, 0x6e, 0x61, 0x6d, + 0x65, 0x18, 0x01, 0x20, 0x01, 0x28, 0x09, 0x52, 0x04, 0x6e, 0x61, 0x6d, 0x65, 0x22, 0x13, 0x0a, + 0x11, 0x4c, 0x69, 0x73, 0x74, 0x53, 0x74, 0x61, 0x74, 0x65, 0x73, 0x52, 0x65, 0x71, 0x75, 0x65, + 0x73, 0x74, 0x22, 0x3b, 0x0a, 0x12, 0x4c, 0x69, 0x73, 0x74, 0x53, 0x74, 0x61, 0x74, 0x65, 0x73, + 0x52, 0x65, 0x73, 0x70, 0x6f, 0x6e, 0x73, 0x65, 0x12, 0x25, 0x0a, 0x06, 0x73, 0x74, 0x61, 0x74, + 0x65, 0x73, 0x18, 0x01, 0x20, 0x03, 0x28, 0x0b, 0x32, 0x0d, 0x2e, 0x64, 0x61, 0x65, 0x6d, 0x6f, + 0x6e, 0x2e, 0x53, 0x74, 0x61, 0x74, 0x65, 0x52, 0x06, 0x73, 0x74, 0x61, 0x74, 0x65, 0x73, 0x22, + 0x44, 0x0a, 0x11, 0x43, 0x6c, 0x65, 0x61, 0x6e, 0x53, 0x74, 0x61, 0x74, 0x65, 0x52, 0x65, 0x71, + 0x75, 0x65, 0x73, 0x74, 0x12, 0x1d, 0x0a, 0x0a, 0x73, 0x74, 0x61, 0x74, 0x65, 0x5f, 0x6e, 0x61, + 0x6d, 0x65, 0x18, 0x01, 0x20, 0x01, 0x28, 0x09, 0x52, 0x09, 0x73, 0x74, 0x61, 0x74, 0x65, 0x4e, + 0x61, 0x6d, 0x65, 0x12, 0x10, 0x0a, 0x03, 0x61, 0x6c, 0x6c, 0x18, 0x02, 0x20, 0x01, 0x28, 0x08, + 0x52, 0x03, 0x61, 0x6c, 0x6c, 0x22, 0x3b, 0x0a, 0x12, 0x43, 0x6c, 0x65, 0x61, 0x6e, 0x53, 0x74, + 0x61, 0x74, 0x65, 0x52, 0x65, 0x73, 0x70, 0x6f, 0x6e, 0x73, 0x65, 0x12, 0x25, 0x0a, 0x0e, 0x63, + 0x6c, 0x65, 0x61, 0x6e, 0x65, 0x64, 0x5f, 0x73, 0x74, 0x61, 0x74, 0x65, 0x73, 0x18, 0x01, 0x20, + 0x01, 0x28, 0x05, 0x52, 0x0d, 0x63, 0x6c, 0x65, 0x61, 0x6e, 0x65, 0x64, 0x53, 0x74, 0x61, 0x74, + 0x65, 0x73, 0x22, 0x45, 0x0a, 0x12, 0x44, 0x65, 0x6c, 0x65, 0x74, 0x65, 0x53, 0x74, 0x61, 0x74, + 0x65, 0x52, 0x65, 0x71, 0x75, 0x65, 0x73, 0x74, 0x12, 0x1d, 0x0a, 0x0a, 0x73, 0x74, 0x61, 0x74, + 0x65, 0x5f, 0x6e, 0x61, 0x6d, 0x65, 0x18, 0x01, 0x20, 0x01, 0x28, 0x09, 0x52, 0x09, 0x73, 0x74, + 0x61, 0x74, 0x65, 0x4e, 
0x61, 0x6d, 0x65, 0x12, 0x10, 0x0a, 0x03, 0x61, 0x6c, 0x6c, 0x18, 0x02, + 0x20, 0x01, 0x28, 0x08, 0x52, 0x03, 0x61, 0x6c, 0x6c, 0x22, 0x3c, 0x0a, 0x13, 0x44, 0x65, 0x6c, + 0x65, 0x74, 0x65, 0x53, 0x74, 0x61, 0x74, 0x65, 0x52, 0x65, 0x73, 0x70, 0x6f, 0x6e, 0x73, 0x65, + 0x12, 0x25, 0x0a, 0x0e, 0x64, 0x65, 0x6c, 0x65, 0x74, 0x65, 0x64, 0x5f, 0x73, 0x74, 0x61, 0x74, + 0x65, 0x73, 0x18, 0x01, 0x20, 0x01, 0x28, 0x05, 0x52, 0x0d, 0x64, 0x65, 0x6c, 0x65, 0x74, 0x65, + 0x64, 0x53, 0x74, 0x61, 0x74, 0x65, 0x73, 0x22, 0x3b, 0x0a, 0x1f, 0x53, 0x65, 0x74, 0x4e, 0x65, + 0x74, 0x77, 0x6f, 0x72, 0x6b, 0x4d, 0x61, 0x70, 0x50, 0x65, 0x72, 0x73, 0x69, 0x73, 0x74, 0x65, + 0x6e, 0x63, 0x65, 0x52, 0x65, 0x71, 0x75, 0x65, 0x73, 0x74, 0x12, 0x18, 0x0a, 0x07, 0x65, 0x6e, + 0x61, 0x62, 0x6c, 0x65, 0x64, 0x18, 0x01, 0x20, 0x01, 0x28, 0x08, 0x52, 0x07, 0x65, 0x6e, 0x61, + 0x62, 0x6c, 0x65, 0x64, 0x22, 0x22, 0x0a, 0x20, 0x53, 0x65, 0x74, 0x4e, 0x65, 0x74, 0x77, 0x6f, + 0x72, 0x6b, 0x4d, 0x61, 0x70, 0x50, 0x65, 0x72, 0x73, 0x69, 0x73, 0x74, 0x65, 0x6e, 0x63, 0x65, + 0x52, 0x65, 0x73, 0x70, 0x6f, 0x6e, 0x73, 0x65, 0x2a, 0x62, 0x0a, 0x08, 0x4c, 0x6f, 0x67, 0x4c, + 0x65, 0x76, 0x65, 0x6c, 0x12, 0x0b, 0x0a, 0x07, 0x55, 0x4e, 0x4b, 0x4e, 0x4f, 0x57, 0x4e, 0x10, + 0x00, 0x12, 0x09, 0x0a, 0x05, 0x50, 0x41, 0x4e, 0x49, 0x43, 0x10, 0x01, 0x12, 0x09, 0x0a, 0x05, + 0x46, 0x41, 0x54, 0x41, 0x4c, 0x10, 0x02, 0x12, 0x09, 0x0a, 0x05, 0x45, 0x52, 0x52, 0x4f, 0x52, + 0x10, 0x03, 0x12, 0x08, 0x0a, 0x04, 0x57, 0x41, 0x52, 0x4e, 0x10, 0x04, 0x12, 0x08, 0x0a, 0x04, + 0x49, 0x4e, 0x46, 0x4f, 0x10, 0x05, 0x12, 0x09, 0x0a, 0x05, 0x44, 0x45, 0x42, 0x55, 0x47, 0x10, + 0x06, 0x12, 0x09, 0x0a, 0x05, 0x54, 0x52, 0x41, 0x43, 0x45, 0x10, 0x07, 0x32, 0x81, 0x09, 0x0a, + 0x0d, 0x44, 0x61, 0x65, 0x6d, 0x6f, 0x6e, 0x53, 0x65, 0x72, 0x76, 0x69, 0x63, 0x65, 0x12, 0x36, + 0x0a, 0x05, 0x4c, 0x6f, 0x67, 0x69, 0x6e, 0x12, 0x14, 0x2e, 0x64, 0x61, 0x65, 0x6d, 0x6f, 0x6e, + 0x2e, 0x4c, 0x6f, 0x67, 0x69, 0x6e, 0x52, 0x65, 0x71, 0x75, 0x65, 0x73, 0x74, 0x1a, 0x15, 0x2e, + 0x64, 0x61, 0x65, 0x6d, 0x6f, 0x6e, 0x2e, 0x4c, 0x6f, 0x67, 0x69, 0x6e, 0x52, 0x65, 0x73, 0x70, + 0x6f, 0x6e, 0x73, 0x65, 0x22, 0x00, 0x12, 0x4b, 0x0a, 0x0c, 0x57, 0x61, 0x69, 0x74, 0x53, 0x53, + 0x4f, 0x4c, 0x6f, 0x67, 0x69, 0x6e, 0x12, 0x1b, 0x2e, 0x64, 0x61, 0x65, 0x6d, 0x6f, 0x6e, 0x2e, + 0x57, 0x61, 0x69, 0x74, 0x53, 0x53, 0x4f, 0x4c, 0x6f, 0x67, 0x69, 0x6e, 0x52, 0x65, 0x71, 0x75, + 0x65, 0x73, 0x74, 0x1a, 0x1c, 0x2e, 0x64, 0x61, 0x65, 0x6d, 0x6f, 0x6e, 0x2e, 0x57, 0x61, 0x69, + 0x74, 0x53, 0x53, 0x4f, 0x4c, 0x6f, 0x67, 0x69, 0x6e, 0x52, 0x65, 0x73, 0x70, 0x6f, 0x6e, 0x73, + 0x65, 0x22, 0x00, 0x12, 0x2d, 0x0a, 0x02, 0x55, 0x70, 0x12, 0x11, 0x2e, 0x64, 0x61, 0x65, 0x6d, + 0x6f, 0x6e, 0x2e, 0x55, 0x70, 0x52, 0x65, 0x71, 0x75, 0x65, 0x73, 0x74, 0x1a, 0x12, 0x2e, 0x64, + 0x61, 0x65, 0x6d, 0x6f, 0x6e, 0x2e, 0x55, 0x70, 0x52, 0x65, 0x73, 0x70, 0x6f, 0x6e, 0x73, 0x65, + 0x22, 0x00, 0x12, 0x39, 0x0a, 0x06, 0x53, 0x74, 0x61, 0x74, 0x75, 0x73, 0x12, 0x15, 0x2e, 0x64, + 0x61, 0x65, 0x6d, 0x6f, 0x6e, 0x2e, 0x53, 0x74, 0x61, 0x74, 0x75, 0x73, 0x52, 0x65, 0x71, 0x75, + 0x65, 0x73, 0x74, 0x1a, 0x16, 0x2e, 0x64, 0x61, 0x65, 0x6d, 0x6f, 0x6e, 0x2e, 0x53, 0x74, 0x61, + 0x74, 0x75, 0x73, 0x52, 0x65, 0x73, 0x70, 0x6f, 0x6e, 0x73, 0x65, 0x22, 0x00, 0x12, 0x33, 0x0a, + 0x04, 0x44, 0x6f, 0x77, 0x6e, 0x12, 0x13, 0x2e, 0x64, 0x61, 0x65, 0x6d, 0x6f, 0x6e, 0x2e, 0x44, + 0x6f, 0x77, 0x6e, 0x52, 0x65, 0x71, 0x75, 0x65, 0x73, 0x74, 0x1a, 0x14, 0x2e, 0x64, 0x61, 0x65, + 0x6d, 0x6f, 0x6e, 0x2e, 0x44, 0x6f, 0x77, 0x6e, 
0x52, 0x65, 0x73, 0x70, 0x6f, 0x6e, 0x73, 0x65, + 0x22, 0x00, 0x12, 0x42, 0x0a, 0x09, 0x47, 0x65, 0x74, 0x43, 0x6f, 0x6e, 0x66, 0x69, 0x67, 0x12, + 0x18, 0x2e, 0x64, 0x61, 0x65, 0x6d, 0x6f, 0x6e, 0x2e, 0x47, 0x65, 0x74, 0x43, 0x6f, 0x6e, 0x66, + 0x69, 0x67, 0x52, 0x65, 0x71, 0x75, 0x65, 0x73, 0x74, 0x1a, 0x19, 0x2e, 0x64, 0x61, 0x65, 0x6d, + 0x6f, 0x6e, 0x2e, 0x47, 0x65, 0x74, 0x43, 0x6f, 0x6e, 0x66, 0x69, 0x67, 0x52, 0x65, 0x73, 0x70, + 0x6f, 0x6e, 0x73, 0x65, 0x22, 0x00, 0x12, 0x45, 0x0a, 0x0a, 0x4c, 0x69, 0x73, 0x74, 0x52, 0x6f, + 0x75, 0x74, 0x65, 0x73, 0x12, 0x19, 0x2e, 0x64, 0x61, 0x65, 0x6d, 0x6f, 0x6e, 0x2e, 0x4c, 0x69, + 0x73, 0x74, 0x52, 0x6f, 0x75, 0x74, 0x65, 0x73, 0x52, 0x65, 0x71, 0x75, 0x65, 0x73, 0x74, 0x1a, + 0x1a, 0x2e, 0x64, 0x61, 0x65, 0x6d, 0x6f, 0x6e, 0x2e, 0x4c, 0x69, 0x73, 0x74, 0x52, 0x6f, 0x75, + 0x74, 0x65, 0x73, 0x52, 0x65, 0x73, 0x70, 0x6f, 0x6e, 0x73, 0x65, 0x22, 0x00, 0x12, 0x4b, 0x0a, + 0x0c, 0x53, 0x65, 0x6c, 0x65, 0x63, 0x74, 0x52, 0x6f, 0x75, 0x74, 0x65, 0x73, 0x12, 0x1b, 0x2e, + 0x64, 0x61, 0x65, 0x6d, 0x6f, 0x6e, 0x2e, 0x53, 0x65, 0x6c, 0x65, 0x63, 0x74, 0x52, 0x6f, 0x75, + 0x74, 0x65, 0x73, 0x52, 0x65, 0x71, 0x75, 0x65, 0x73, 0x74, 0x1a, 0x1c, 0x2e, 0x64, 0x61, 0x65, + 0x6d, 0x6f, 0x6e, 0x2e, 0x53, 0x65, 0x6c, 0x65, 0x63, 0x74, 0x52, 0x6f, 0x75, 0x74, 0x65, 0x73, + 0x52, 0x65, 0x73, 0x70, 0x6f, 0x6e, 0x73, 0x65, 0x22, 0x00, 0x12, 0x4d, 0x0a, 0x0e, 0x44, 0x65, + 0x73, 0x65, 0x6c, 0x65, 0x63, 0x74, 0x52, 0x6f, 0x75, 0x74, 0x65, 0x73, 0x12, 0x1b, 0x2e, 0x64, + 0x61, 0x65, 0x6d, 0x6f, 0x6e, 0x2e, 0x53, 0x65, 0x6c, 0x65, 0x63, 0x74, 0x52, 0x6f, 0x75, 0x74, + 0x65, 0x73, 0x52, 0x65, 0x71, 0x75, 0x65, 0x73, 0x74, 0x1a, 0x1c, 0x2e, 0x64, 0x61, 0x65, 0x6d, + 0x6f, 0x6e, 0x2e, 0x53, 0x65, 0x6c, 0x65, 0x63, 0x74, 0x52, 0x6f, 0x75, 0x74, 0x65, 0x73, 0x52, + 0x65, 0x73, 0x70, 0x6f, 0x6e, 0x73, 0x65, 0x22, 0x00, 0x12, 0x48, 0x0a, 0x0b, 0x44, 0x65, 0x62, + 0x75, 0x67, 0x42, 0x75, 0x6e, 0x64, 0x6c, 0x65, 0x12, 0x1a, 0x2e, 0x64, 0x61, 0x65, 0x6d, 0x6f, + 0x6e, 0x2e, 0x44, 0x65, 0x62, 0x75, 0x67, 0x42, 0x75, 0x6e, 0x64, 0x6c, 0x65, 0x52, 0x65, 0x71, + 0x75, 0x65, 0x73, 0x74, 0x1a, 0x1b, 0x2e, 0x64, 0x61, 0x65, 0x6d, 0x6f, 0x6e, 0x2e, 0x44, 0x65, + 0x62, 0x75, 0x67, 0x42, 0x75, 0x6e, 0x64, 0x6c, 0x65, 0x52, 0x65, 0x73, 0x70, 0x6f, 0x6e, 0x73, + 0x65, 0x22, 0x00, 0x12, 0x48, 0x0a, 0x0b, 0x47, 0x65, 0x74, 0x4c, 0x6f, 0x67, 0x4c, 0x65, 0x76, + 0x65, 0x6c, 0x12, 0x1a, 0x2e, 0x64, 0x61, 0x65, 0x6d, 0x6f, 0x6e, 0x2e, 0x47, 0x65, 0x74, 0x4c, + 0x6f, 0x67, 0x4c, 0x65, 0x76, 0x65, 0x6c, 0x52, 0x65, 0x71, 0x75, 0x65, 0x73, 0x74, 0x1a, 0x1b, + 0x2e, 0x64, 0x61, 0x65, 0x6d, 0x6f, 0x6e, 0x2e, 0x47, 0x65, 0x74, 0x4c, 0x6f, 0x67, 0x4c, 0x65, + 0x76, 0x65, 0x6c, 0x52, 0x65, 0x73, 0x70, 0x6f, 0x6e, 0x73, 0x65, 0x22, 0x00, 0x12, 0x48, 0x0a, + 0x0b, 0x53, 0x65, 0x74, 0x4c, 0x6f, 0x67, 0x4c, 0x65, 0x76, 0x65, 0x6c, 0x12, 0x1a, 0x2e, 0x64, + 0x61, 0x65, 0x6d, 0x6f, 0x6e, 0x2e, 0x53, 0x65, 0x74, 0x4c, 0x6f, 0x67, 0x4c, 0x65, 0x76, 0x65, + 0x6c, 0x52, 0x65, 0x71, 0x75, 0x65, 0x73, 0x74, 0x1a, 0x1b, 0x2e, 0x64, 0x61, 0x65, 0x6d, 0x6f, + 0x6e, 0x2e, 0x53, 0x65, 0x74, 0x4c, 0x6f, 0x67, 0x4c, 0x65, 0x76, 0x65, 0x6c, 0x52, 0x65, 0x73, + 0x70, 0x6f, 0x6e, 0x73, 0x65, 0x22, 0x00, 0x12, 0x45, 0x0a, 0x0a, 0x4c, 0x69, 0x73, 0x74, 0x53, + 0x74, 0x61, 0x74, 0x65, 0x73, 0x12, 0x19, 0x2e, 0x64, 0x61, 0x65, 0x6d, 0x6f, 0x6e, 0x2e, 0x4c, + 0x69, 0x73, 0x74, 0x53, 0x74, 0x61, 0x74, 0x65, 0x73, 0x52, 0x65, 0x71, 0x75, 0x65, 0x73, 0x74, + 0x1a, 0x1a, 0x2e, 0x64, 0x61, 0x65, 0x6d, 0x6f, 0x6e, 0x2e, 0x4c, 0x69, 
0x73, 0x74, 0x53, 0x74, + 0x61, 0x74, 0x65, 0x73, 0x52, 0x65, 0x73, 0x70, 0x6f, 0x6e, 0x73, 0x65, 0x22, 0x00, 0x12, 0x45, + 0x0a, 0x0a, 0x43, 0x6c, 0x65, 0x61, 0x6e, 0x53, 0x74, 0x61, 0x74, 0x65, 0x12, 0x19, 0x2e, 0x64, + 0x61, 0x65, 0x6d, 0x6f, 0x6e, 0x2e, 0x43, 0x6c, 0x65, 0x61, 0x6e, 0x53, 0x74, 0x61, 0x74, 0x65, + 0x52, 0x65, 0x71, 0x75, 0x65, 0x73, 0x74, 0x1a, 0x1a, 0x2e, 0x64, 0x61, 0x65, 0x6d, 0x6f, 0x6e, + 0x2e, 0x43, 0x6c, 0x65, 0x61, 0x6e, 0x53, 0x74, 0x61, 0x74, 0x65, 0x52, 0x65, 0x73, 0x70, 0x6f, + 0x6e, 0x73, 0x65, 0x22, 0x00, 0x12, 0x48, 0x0a, 0x0b, 0x44, 0x65, 0x6c, 0x65, 0x74, 0x65, 0x53, + 0x74, 0x61, 0x74, 0x65, 0x12, 0x1a, 0x2e, 0x64, 0x61, 0x65, 0x6d, 0x6f, 0x6e, 0x2e, 0x44, 0x65, + 0x6c, 0x65, 0x74, 0x65, 0x53, 0x74, 0x61, 0x74, 0x65, 0x52, 0x65, 0x71, 0x75, 0x65, 0x73, 0x74, + 0x1a, 0x1b, 0x2e, 0x64, 0x61, 0x65, 0x6d, 0x6f, 0x6e, 0x2e, 0x44, 0x65, 0x6c, 0x65, 0x74, 0x65, + 0x53, 0x74, 0x61, 0x74, 0x65, 0x52, 0x65, 0x73, 0x70, 0x6f, 0x6e, 0x73, 0x65, 0x22, 0x00, 0x12, + 0x6f, 0x0a, 0x18, 0x53, 0x65, 0x74, 0x4e, 0x65, 0x74, 0x77, 0x6f, 0x72, 0x6b, 0x4d, 0x61, 0x70, + 0x50, 0x65, 0x72, 0x73, 0x69, 0x73, 0x74, 0x65, 0x6e, 0x63, 0x65, 0x12, 0x27, 0x2e, 0x64, 0x61, + 0x65, 0x6d, 0x6f, 0x6e, 0x2e, 0x53, 0x65, 0x74, 0x4e, 0x65, 0x74, 0x77, 0x6f, 0x72, 0x6b, 0x4d, + 0x61, 0x70, 0x50, 0x65, 0x72, 0x73, 0x69, 0x73, 0x74, 0x65, 0x6e, 0x63, 0x65, 0x52, 0x65, 0x71, + 0x75, 0x65, 0x73, 0x74, 0x1a, 0x28, 0x2e, 0x64, 0x61, 0x65, 0x6d, 0x6f, 0x6e, 0x2e, 0x53, 0x65, + 0x74, 0x4e, 0x65, 0x74, 0x77, 0x6f, 0x72, 0x6b, 0x4d, 0x61, 0x70, 0x50, 0x65, 0x72, 0x73, 0x69, + 0x73, 0x74, 0x65, 0x6e, 0x63, 0x65, 0x52, 0x65, 0x73, 0x70, 0x6f, 0x6e, 0x73, 0x65, 0x22, 0x00, + 0x42, 0x08, 0x5a, 0x06, 0x2f, 0x70, 0x72, 0x6f, 0x74, 0x6f, 0x62, 0x06, 0x70, 0x72, 0x6f, 0x74, + 0x6f, 0x33, } var ( @@ -2572,7 +2952,7 @@ func file_daemon_proto_rawDescGZIP() []byte { } var file_daemon_proto_enumTypes = make([]protoimpl.EnumInfo, 1) -var file_daemon_proto_msgTypes = make([]protoimpl.MessageInfo, 34) +var file_daemon_proto_msgTypes = make([]protoimpl.MessageInfo, 41) var file_daemon_proto_goTypes = []interface{}{ (LogLevel)(0), // 0: daemon.LogLevel (*LoginRequest)(nil), // 1: daemon.LoginRequest @@ -2606,18 +2986,25 @@ var file_daemon_proto_goTypes = []interface{}{ (*GetLogLevelResponse)(nil), // 29: daemon.GetLogLevelResponse (*SetLogLevelRequest)(nil), // 30: daemon.SetLogLevelRequest (*SetLogLevelResponse)(nil), // 31: daemon.SetLogLevelResponse - (*SetNetworkMapPersistenceRequest)(nil), // 32: daemon.SetNetworkMapPersistenceRequest - (*SetNetworkMapPersistenceResponse)(nil), // 33: daemon.SetNetworkMapPersistenceResponse - nil, // 34: daemon.Route.ResolvedIPsEntry - (*durationpb.Duration)(nil), // 35: google.protobuf.Duration - (*timestamppb.Timestamp)(nil), // 36: google.protobuf.Timestamp + (*State)(nil), // 32: daemon.State + (*ListStatesRequest)(nil), // 33: daemon.ListStatesRequest + (*ListStatesResponse)(nil), // 34: daemon.ListStatesResponse + (*CleanStateRequest)(nil), // 35: daemon.CleanStateRequest + (*CleanStateResponse)(nil), // 36: daemon.CleanStateResponse + (*DeleteStateRequest)(nil), // 37: daemon.DeleteStateRequest + (*DeleteStateResponse)(nil), // 38: daemon.DeleteStateResponse + (*SetNetworkMapPersistenceRequest)(nil), // 39: daemon.SetNetworkMapPersistenceRequest + (*SetNetworkMapPersistenceResponse)(nil), // 40: daemon.SetNetworkMapPersistenceResponse + nil, // 41: daemon.Route.ResolvedIPsEntry + (*durationpb.Duration)(nil), // 42: google.protobuf.Duration + 
(*timestamppb.Timestamp)(nil), // 43: google.protobuf.Timestamp } var file_daemon_proto_depIdxs = []int32{ - 35, // 0: daemon.LoginRequest.dnsRouteInterval:type_name -> google.protobuf.Duration + 42, // 0: daemon.LoginRequest.dnsRouteInterval:type_name -> google.protobuf.Duration 19, // 1: daemon.StatusResponse.fullStatus:type_name -> daemon.FullStatus - 36, // 2: daemon.PeerState.connStatusUpdate:type_name -> google.protobuf.Timestamp - 36, // 3: daemon.PeerState.lastWireguardHandshake:type_name -> google.protobuf.Timestamp - 35, // 4: daemon.PeerState.latency:type_name -> google.protobuf.Duration + 43, // 2: daemon.PeerState.connStatusUpdate:type_name -> google.protobuf.Timestamp + 43, // 3: daemon.PeerState.lastWireguardHandshake:type_name -> google.protobuf.Timestamp + 42, // 4: daemon.PeerState.latency:type_name -> google.protobuf.Duration 16, // 5: daemon.FullStatus.managementState:type_name -> daemon.ManagementState 15, // 6: daemon.FullStatus.signalState:type_name -> daemon.SignalState 14, // 7: daemon.FullStatus.localPeerState:type_name -> daemon.LocalPeerState @@ -2625,41 +3012,48 @@ var file_daemon_proto_depIdxs = []int32{ 17, // 9: daemon.FullStatus.relays:type_name -> daemon.RelayState 18, // 10: daemon.FullStatus.dns_servers:type_name -> daemon.NSGroupState 25, // 11: daemon.ListRoutesResponse.routes:type_name -> daemon.Route - 34, // 12: daemon.Route.resolvedIPs:type_name -> daemon.Route.ResolvedIPsEntry + 41, // 12: daemon.Route.resolvedIPs:type_name -> daemon.Route.ResolvedIPsEntry 0, // 13: daemon.GetLogLevelResponse.level:type_name -> daemon.LogLevel 0, // 14: daemon.SetLogLevelRequest.level:type_name -> daemon.LogLevel - 24, // 15: daemon.Route.ResolvedIPsEntry.value:type_name -> daemon.IPList - 1, // 16: daemon.DaemonService.Login:input_type -> daemon.LoginRequest - 3, // 17: daemon.DaemonService.WaitSSOLogin:input_type -> daemon.WaitSSOLoginRequest - 5, // 18: daemon.DaemonService.Up:input_type -> daemon.UpRequest - 7, // 19: daemon.DaemonService.Status:input_type -> daemon.StatusRequest - 9, // 20: daemon.DaemonService.Down:input_type -> daemon.DownRequest - 11, // 21: daemon.DaemonService.GetConfig:input_type -> daemon.GetConfigRequest - 20, // 22: daemon.DaemonService.ListRoutes:input_type -> daemon.ListRoutesRequest - 22, // 23: daemon.DaemonService.SelectRoutes:input_type -> daemon.SelectRoutesRequest - 22, // 24: daemon.DaemonService.DeselectRoutes:input_type -> daemon.SelectRoutesRequest - 26, // 25: daemon.DaemonService.DebugBundle:input_type -> daemon.DebugBundleRequest - 28, // 26: daemon.DaemonService.GetLogLevel:input_type -> daemon.GetLogLevelRequest - 30, // 27: daemon.DaemonService.SetLogLevel:input_type -> daemon.SetLogLevelRequest - 32, // 28: daemon.DaemonService.SetNetworkMapPersistence:input_type -> daemon.SetNetworkMapPersistenceRequest - 2, // 29: daemon.DaemonService.Login:output_type -> daemon.LoginResponse - 4, // 30: daemon.DaemonService.WaitSSOLogin:output_type -> daemon.WaitSSOLoginResponse - 6, // 31: daemon.DaemonService.Up:output_type -> daemon.UpResponse - 8, // 32: daemon.DaemonService.Status:output_type -> daemon.StatusResponse - 10, // 33: daemon.DaemonService.Down:output_type -> daemon.DownResponse - 12, // 34: daemon.DaemonService.GetConfig:output_type -> daemon.GetConfigResponse - 21, // 35: daemon.DaemonService.ListRoutes:output_type -> daemon.ListRoutesResponse - 23, // 36: daemon.DaemonService.SelectRoutes:output_type -> daemon.SelectRoutesResponse - 23, // 37: daemon.DaemonService.DeselectRoutes:output_type -> 
daemon.SelectRoutesResponse - 27, // 38: daemon.DaemonService.DebugBundle:output_type -> daemon.DebugBundleResponse - 29, // 39: daemon.DaemonService.GetLogLevel:output_type -> daemon.GetLogLevelResponse - 31, // 40: daemon.DaemonService.SetLogLevel:output_type -> daemon.SetLogLevelResponse - 33, // 41: daemon.DaemonService.SetNetworkMapPersistence:output_type -> daemon.SetNetworkMapPersistenceResponse - 29, // [29:42] is the sub-list for method output_type - 16, // [16:29] is the sub-list for method input_type - 16, // [16:16] is the sub-list for extension type_name - 16, // [16:16] is the sub-list for extension extendee - 0, // [0:16] is the sub-list for field type_name + 32, // 15: daemon.ListStatesResponse.states:type_name -> daemon.State + 24, // 16: daemon.Route.ResolvedIPsEntry.value:type_name -> daemon.IPList + 1, // 17: daemon.DaemonService.Login:input_type -> daemon.LoginRequest + 3, // 18: daemon.DaemonService.WaitSSOLogin:input_type -> daemon.WaitSSOLoginRequest + 5, // 19: daemon.DaemonService.Up:input_type -> daemon.UpRequest + 7, // 20: daemon.DaemonService.Status:input_type -> daemon.StatusRequest + 9, // 21: daemon.DaemonService.Down:input_type -> daemon.DownRequest + 11, // 22: daemon.DaemonService.GetConfig:input_type -> daemon.GetConfigRequest + 20, // 23: daemon.DaemonService.ListRoutes:input_type -> daemon.ListRoutesRequest + 22, // 24: daemon.DaemonService.SelectRoutes:input_type -> daemon.SelectRoutesRequest + 22, // 25: daemon.DaemonService.DeselectRoutes:input_type -> daemon.SelectRoutesRequest + 26, // 26: daemon.DaemonService.DebugBundle:input_type -> daemon.DebugBundleRequest + 28, // 27: daemon.DaemonService.GetLogLevel:input_type -> daemon.GetLogLevelRequest + 30, // 28: daemon.DaemonService.SetLogLevel:input_type -> daemon.SetLogLevelRequest + 33, // 29: daemon.DaemonService.ListStates:input_type -> daemon.ListStatesRequest + 35, // 30: daemon.DaemonService.CleanState:input_type -> daemon.CleanStateRequest + 37, // 31: daemon.DaemonService.DeleteState:input_type -> daemon.DeleteStateRequest + 39, // 32: daemon.DaemonService.SetNetworkMapPersistence:input_type -> daemon.SetNetworkMapPersistenceRequest + 2, // 33: daemon.DaemonService.Login:output_type -> daemon.LoginResponse + 4, // 34: daemon.DaemonService.WaitSSOLogin:output_type -> daemon.WaitSSOLoginResponse + 6, // 35: daemon.DaemonService.Up:output_type -> daemon.UpResponse + 8, // 36: daemon.DaemonService.Status:output_type -> daemon.StatusResponse + 10, // 37: daemon.DaemonService.Down:output_type -> daemon.DownResponse + 12, // 38: daemon.DaemonService.GetConfig:output_type -> daemon.GetConfigResponse + 21, // 39: daemon.DaemonService.ListRoutes:output_type -> daemon.ListRoutesResponse + 23, // 40: daemon.DaemonService.SelectRoutes:output_type -> daemon.SelectRoutesResponse + 23, // 41: daemon.DaemonService.DeselectRoutes:output_type -> daemon.SelectRoutesResponse + 27, // 42: daemon.DaemonService.DebugBundle:output_type -> daemon.DebugBundleResponse + 29, // 43: daemon.DaemonService.GetLogLevel:output_type -> daemon.GetLogLevelResponse + 31, // 44: daemon.DaemonService.SetLogLevel:output_type -> daemon.SetLogLevelResponse + 34, // 45: daemon.DaemonService.ListStates:output_type -> daemon.ListStatesResponse + 36, // 46: daemon.DaemonService.CleanState:output_type -> daemon.CleanStateResponse + 38, // 47: daemon.DaemonService.DeleteState:output_type -> daemon.DeleteStateResponse + 40, // 48: daemon.DaemonService.SetNetworkMapPersistence:output_type -> daemon.SetNetworkMapPersistenceResponse + 33, // 
[33:49] is the sub-list for method output_type + 17, // [17:33] is the sub-list for method input_type + 17, // [17:17] is the sub-list for extension type_name + 17, // [17:17] is the sub-list for extension extendee + 0, // [0:17] is the sub-list for field type_name } func init() { file_daemon_proto_init() } @@ -3041,7 +3435,7 @@ func file_daemon_proto_init() { } } file_daemon_proto_msgTypes[31].Exporter = func(v interface{}, i int) interface{} { - switch v := v.(*SetNetworkMapPersistenceRequest); i { + switch v := v.(*State); i { case 0: return &v.state case 1: @@ -3053,6 +3447,90 @@ func file_daemon_proto_init() { } } file_daemon_proto_msgTypes[32].Exporter = func(v interface{}, i int) interface{} { + switch v := v.(*ListStatesRequest); i { + case 0: + return &v.state + case 1: + return &v.sizeCache + case 2: + return &v.unknownFields + default: + return nil + } + } + file_daemon_proto_msgTypes[33].Exporter = func(v interface{}, i int) interface{} { + switch v := v.(*ListStatesResponse); i { + case 0: + return &v.state + case 1: + return &v.sizeCache + case 2: + return &v.unknownFields + default: + return nil + } + } + file_daemon_proto_msgTypes[34].Exporter = func(v interface{}, i int) interface{} { + switch v := v.(*CleanStateRequest); i { + case 0: + return &v.state + case 1: + return &v.sizeCache + case 2: + return &v.unknownFields + default: + return nil + } + } + file_daemon_proto_msgTypes[35].Exporter = func(v interface{}, i int) interface{} { + switch v := v.(*CleanStateResponse); i { + case 0: + return &v.state + case 1: + return &v.sizeCache + case 2: + return &v.unknownFields + default: + return nil + } + } + file_daemon_proto_msgTypes[36].Exporter = func(v interface{}, i int) interface{} { + switch v := v.(*DeleteStateRequest); i { + case 0: + return &v.state + case 1: + return &v.sizeCache + case 2: + return &v.unknownFields + default: + return nil + } + } + file_daemon_proto_msgTypes[37].Exporter = func(v interface{}, i int) interface{} { + switch v := v.(*DeleteStateResponse); i { + case 0: + return &v.state + case 1: + return &v.sizeCache + case 2: + return &v.unknownFields + default: + return nil + } + } + file_daemon_proto_msgTypes[38].Exporter = func(v interface{}, i int) interface{} { + switch v := v.(*SetNetworkMapPersistenceRequest); i { + case 0: + return &v.state + case 1: + return &v.sizeCache + case 2: + return &v.unknownFields + default: + return nil + } + } + file_daemon_proto_msgTypes[39].Exporter = func(v interface{}, i int) interface{} { switch v := v.(*SetNetworkMapPersistenceResponse); i { case 0: return &v.state @@ -3072,7 +3550,7 @@ func file_daemon_proto_init() { GoPackagePath: reflect.TypeOf(x{}).PkgPath(), RawDescriptor: file_daemon_proto_rawDesc, NumEnums: 1, - NumMessages: 34, + NumMessages: 41, NumExtensions: 0, NumServices: 1, }, diff --git a/client/proto/daemon.proto b/client/proto/daemon.proto index 04aa66882..96ade5b4e 100644 --- a/client/proto/daemon.proto +++ b/client/proto/daemon.proto @@ -46,6 +46,15 @@ service DaemonService { // SetLogLevel sets the log level of the daemon rpc SetLogLevel(SetLogLevelRequest) returns (SetLogLevelResponse) {} + // List all states + rpc ListStates(ListStatesRequest) returns (ListStatesResponse) {} + + // Clean specific state or all states + rpc CleanState(CleanStateRequest) returns (CleanStateResponse) {} + + // Delete specific state or all states + rpc DeleteState(DeleteStateRequest) returns (DeleteStateResponse) {} + // SetNetworkMapPersistence enables or disables network map persistence rpc 
SetNetworkMapPersistence(SetNetworkMapPersistenceRequest) returns (SetNetworkMapPersistenceResponse) {} } @@ -299,6 +308,42 @@ message SetLogLevelRequest { message SetLogLevelResponse { } +// State represents a daemon state entry +message State { + string name = 1; +} + +// ListStatesRequest is empty as it requires no parameters +message ListStatesRequest {} + +// ListStatesResponse contains a list of states +message ListStatesResponse { + repeated State states = 1; +} + +// CleanStateRequest for cleaning states +message CleanStateRequest { + string state_name = 1; + bool all = 2; +} + +// CleanStateResponse contains the result of the clean operation +message CleanStateResponse { + int32 cleaned_states = 1; +} + +// DeleteStateRequest for deleting states +message DeleteStateRequest { + string state_name = 1; + bool all = 2; +} + +// DeleteStateResponse contains the result of the delete operation +message DeleteStateResponse { + int32 deleted_states = 1; +} + + message SetNetworkMapPersistenceRequest { bool enabled = 1; } diff --git a/client/proto/daemon_grpc.pb.go b/client/proto/daemon_grpc.pb.go index 31645763e..2e063604a 100644 --- a/client/proto/daemon_grpc.pb.go +++ b/client/proto/daemon_grpc.pb.go @@ -43,6 +43,12 @@ type DaemonServiceClient interface { GetLogLevel(ctx context.Context, in *GetLogLevelRequest, opts ...grpc.CallOption) (*GetLogLevelResponse, error) // SetLogLevel sets the log level of the daemon SetLogLevel(ctx context.Context, in *SetLogLevelRequest, opts ...grpc.CallOption) (*SetLogLevelResponse, error) + // List all states + ListStates(ctx context.Context, in *ListStatesRequest, opts ...grpc.CallOption) (*ListStatesResponse, error) + // Clean specific state or all states + CleanState(ctx context.Context, in *CleanStateRequest, opts ...grpc.CallOption) (*CleanStateResponse, error) + // Delete specific state or all states + DeleteState(ctx context.Context, in *DeleteStateRequest, opts ...grpc.CallOption) (*DeleteStateResponse, error) // SetNetworkMapPersistence enables or disables network map persistence SetNetworkMapPersistence(ctx context.Context, in *SetNetworkMapPersistenceRequest, opts ...grpc.CallOption) (*SetNetworkMapPersistenceResponse, error) } @@ -163,6 +169,33 @@ func (c *daemonServiceClient) SetLogLevel(ctx context.Context, in *SetLogLevelRe return out, nil } +func (c *daemonServiceClient) ListStates(ctx context.Context, in *ListStatesRequest, opts ...grpc.CallOption) (*ListStatesResponse, error) { + out := new(ListStatesResponse) + err := c.cc.Invoke(ctx, "/daemon.DaemonService/ListStates", in, out, opts...) + if err != nil { + return nil, err + } + return out, nil +} + +func (c *daemonServiceClient) CleanState(ctx context.Context, in *CleanStateRequest, opts ...grpc.CallOption) (*CleanStateResponse, error) { + out := new(CleanStateResponse) + err := c.cc.Invoke(ctx, "/daemon.DaemonService/CleanState", in, out, opts...) + if err != nil { + return nil, err + } + return out, nil +} + +func (c *daemonServiceClient) DeleteState(ctx context.Context, in *DeleteStateRequest, opts ...grpc.CallOption) (*DeleteStateResponse, error) { + out := new(DeleteStateResponse) + err := c.cc.Invoke(ctx, "/daemon.DaemonService/DeleteState", in, out, opts...) 
+ if err != nil { + return nil, err + } + return out, nil +} + func (c *daemonServiceClient) SetNetworkMapPersistence(ctx context.Context, in *SetNetworkMapPersistenceRequest, opts ...grpc.CallOption) (*SetNetworkMapPersistenceResponse, error) { out := new(SetNetworkMapPersistenceResponse) err := c.cc.Invoke(ctx, "/daemon.DaemonService/SetNetworkMapPersistence", in, out, opts...) @@ -201,6 +234,12 @@ type DaemonServiceServer interface { GetLogLevel(context.Context, *GetLogLevelRequest) (*GetLogLevelResponse, error) // SetLogLevel sets the log level of the daemon SetLogLevel(context.Context, *SetLogLevelRequest) (*SetLogLevelResponse, error) + // List all states + ListStates(context.Context, *ListStatesRequest) (*ListStatesResponse, error) + // Clean specific state or all states + CleanState(context.Context, *CleanStateRequest) (*CleanStateResponse, error) + // Delete specific state or all states + DeleteState(context.Context, *DeleteStateRequest) (*DeleteStateResponse, error) // SetNetworkMapPersistence enables or disables network map persistence SetNetworkMapPersistence(context.Context, *SetNetworkMapPersistenceRequest) (*SetNetworkMapPersistenceResponse, error) mustEmbedUnimplementedDaemonServiceServer() @@ -246,6 +285,15 @@ func (UnimplementedDaemonServiceServer) GetLogLevel(context.Context, *GetLogLeve func (UnimplementedDaemonServiceServer) SetLogLevel(context.Context, *SetLogLevelRequest) (*SetLogLevelResponse, error) { return nil, status.Errorf(codes.Unimplemented, "method SetLogLevel not implemented") } +func (UnimplementedDaemonServiceServer) ListStates(context.Context, *ListStatesRequest) (*ListStatesResponse, error) { + return nil, status.Errorf(codes.Unimplemented, "method ListStates not implemented") +} +func (UnimplementedDaemonServiceServer) CleanState(context.Context, *CleanStateRequest) (*CleanStateResponse, error) { + return nil, status.Errorf(codes.Unimplemented, "method CleanState not implemented") +} +func (UnimplementedDaemonServiceServer) DeleteState(context.Context, *DeleteStateRequest) (*DeleteStateResponse, error) { + return nil, status.Errorf(codes.Unimplemented, "method DeleteState not implemented") +} func (UnimplementedDaemonServiceServer) SetNetworkMapPersistence(context.Context, *SetNetworkMapPersistenceRequest) (*SetNetworkMapPersistenceResponse, error) { return nil, status.Errorf(codes.Unimplemented, "method SetNetworkMapPersistence not implemented") } @@ -478,6 +526,60 @@ func _DaemonService_SetLogLevel_Handler(srv interface{}, ctx context.Context, de return interceptor(ctx, in, info, handler) } +func _DaemonService_ListStates_Handler(srv interface{}, ctx context.Context, dec func(interface{}) error, interceptor grpc.UnaryServerInterceptor) (interface{}, error) { + in := new(ListStatesRequest) + if err := dec(in); err != nil { + return nil, err + } + if interceptor == nil { + return srv.(DaemonServiceServer).ListStates(ctx, in) + } + info := &grpc.UnaryServerInfo{ + Server: srv, + FullMethod: "/daemon.DaemonService/ListStates", + } + handler := func(ctx context.Context, req interface{}) (interface{}, error) { + return srv.(DaemonServiceServer).ListStates(ctx, req.(*ListStatesRequest)) + } + return interceptor(ctx, in, info, handler) +} + +func _DaemonService_CleanState_Handler(srv interface{}, ctx context.Context, dec func(interface{}) error, interceptor grpc.UnaryServerInterceptor) (interface{}, error) { + in := new(CleanStateRequest) + if err := dec(in); err != nil { + return nil, err + } + if interceptor == nil { + return 
srv.(DaemonServiceServer).CleanState(ctx, in) + } + info := &grpc.UnaryServerInfo{ + Server: srv, + FullMethod: "/daemon.DaemonService/CleanState", + } + handler := func(ctx context.Context, req interface{}) (interface{}, error) { + return srv.(DaemonServiceServer).CleanState(ctx, req.(*CleanStateRequest)) + } + return interceptor(ctx, in, info, handler) +} + +func _DaemonService_DeleteState_Handler(srv interface{}, ctx context.Context, dec func(interface{}) error, interceptor grpc.UnaryServerInterceptor) (interface{}, error) { + in := new(DeleteStateRequest) + if err := dec(in); err != nil { + return nil, err + } + if interceptor == nil { + return srv.(DaemonServiceServer).DeleteState(ctx, in) + } + info := &grpc.UnaryServerInfo{ + Server: srv, + FullMethod: "/daemon.DaemonService/DeleteState", + } + handler := func(ctx context.Context, req interface{}) (interface{}, error) { + return srv.(DaemonServiceServer).DeleteState(ctx, req.(*DeleteStateRequest)) + } + return interceptor(ctx, in, info, handler) +} + func _DaemonService_SetNetworkMapPersistence_Handler(srv interface{}, ctx context.Context, dec func(interface{}) error, interceptor grpc.UnaryServerInterceptor) (interface{}, error) { in := new(SetNetworkMapPersistenceRequest) if err := dec(in); err != nil { @@ -551,6 +653,18 @@ var DaemonService_ServiceDesc = grpc.ServiceDesc{ MethodName: "SetLogLevel", Handler: _DaemonService_SetLogLevel_Handler, }, + { + MethodName: "ListStates", + Handler: _DaemonService_ListStates_Handler, + }, + { + MethodName: "CleanState", + Handler: _DaemonService_CleanState_Handler, + }, + { + MethodName: "DeleteState", + Handler: _DaemonService_DeleteState_Handler, + }, { MethodName: "SetNetworkMapPersistence", Handler: _DaemonService_SetNetworkMapPersistence_Handler, diff --git a/client/server/state.go b/client/server/state.go index 509782e86..222c7c7bd 100644 --- a/client/server/state.go +++ b/client/server/state.go @@ -5,12 +5,112 @@ import ( "fmt" "github.com/hashicorp/go-multierror" + "google.golang.org/grpc/codes" + "google.golang.org/grpc/status" nberrors "github.com/netbirdio/netbird/client/errors" + "github.com/netbirdio/netbird/client/internal" "github.com/netbirdio/netbird/client/internal/statemanager" + "github.com/netbirdio/netbird/client/proto" ) -// restoreResidualConfig checks if the client was not shut down in a clean way and restores residual state if required. 
+// ListStates returns a list of all saved states
+func (s *Server) ListStates(_ context.Context, _ *proto.ListStatesRequest) (*proto.ListStatesResponse, error) {
+	mgr := statemanager.New(statemanager.GetDefaultStatePath())
+
+	stateNames, err := mgr.GetSavedStateNames()
+	if err != nil {
+		return nil, status.Errorf(codes.Internal, "failed to get saved state names: %v", err)
+	}
+
+	states := make([]*proto.State, 0, len(stateNames))
+	for _, name := range stateNames {
+		states = append(states, &proto.State{
+			Name: name,
+		})
+	}
+
+	return &proto.ListStatesResponse{
+		States: states,
+	}, nil
+}
+
+// CleanState handles cleaning of states (performing cleanup operations)
+func (s *Server) CleanState(ctx context.Context, req *proto.CleanStateRequest) (*proto.CleanStateResponse, error) {
+	if s.connectClient.Status() == internal.StatusConnected || s.connectClient.Status() == internal.StatusConnecting {
+		return nil, status.Errorf(codes.FailedPrecondition, "cannot clean state while connecting or connected, run 'netbird down' first.")
+	}
+
+	if req.All {
+		// Reuse existing cleanup logic for all states
+		if err := restoreResidualState(ctx); err != nil {
+			return nil, status.Errorf(codes.Internal, "failed to clean all states: %v", err)
+		}
+
+		// Get count of cleaned states
+		mgr := statemanager.New(statemanager.GetDefaultStatePath())
+		stateNames, err := mgr.GetSavedStateNames()
+		if err != nil {
+			return nil, status.Errorf(codes.Internal, "failed to get state count: %v", err)
+		}
+
+		return &proto.CleanStateResponse{
+			CleanedStates: int32(len(stateNames)),
+		}, nil
+	}
+
+	// Handle single state cleanup
+	mgr := statemanager.New(statemanager.GetDefaultStatePath())
+	registerStates(mgr)
+
+	if err := mgr.CleanupStateByName(req.StateName); err != nil {
+		return nil, status.Errorf(codes.Internal, "failed to clean state %s: %v", req.StateName, err)
+	}
+
+	if err := mgr.PersistState(ctx); err != nil {
+		return nil, status.Errorf(codes.Internal, "failed to persist state changes: %v", err)
+	}
+
+	return &proto.CleanStateResponse{
+		CleanedStates: 1,
+	}, nil
+}
+
+// DeleteState handles deletion of states without cleanup
+func (s *Server) DeleteState(ctx context.Context, req *proto.DeleteStateRequest) (*proto.DeleteStateResponse, error) {
+	if s.connectClient.Status() == internal.StatusConnected || s.connectClient.Status() == internal.StatusConnecting {
+		return nil, status.Errorf(codes.FailedPrecondition, "cannot delete state while connecting or connected, run 'netbird down' first.")
+	}
+
+	mgr := statemanager.New(statemanager.GetDefaultStatePath())
+
+	var count int
+	var err error
+
+	if req.All {
+		count, err = mgr.DeleteAllStates()
+	} else {
+		err = mgr.DeleteStateByName(req.StateName)
+		if err == nil {
+			count = 1
+		}
+	}
+
+	if err != nil {
+		return nil, status.Errorf(codes.Internal, "failed to delete state: %v", err)
+	}
+
+	// Persist the changes
+	if err := mgr.PersistState(ctx); err != nil {
+		return nil, status.Errorf(codes.Internal, "failed to persist state changes: %v", err)
+	}
+
+	return &proto.DeleteStateResponse{
+		DeletedStates: int32(count),
+	}, nil
+}
+
+// restoreResidualState checks if the client was not shut down in a clean way and restores residual state if required.
 // Otherwise, we might not be able to connect to the management server to retrieve new config.
func restoreResidualState(ctx context.Context) error { path := statemanager.GetDefaultStatePath() @@ -24,6 +124,7 @@ func restoreResidualState(ctx context.Context) error { registerStates(mgr) var merr *multierror.Error + if err := mgr.PerformCleanup(); err != nil { merr = multierror.Append(merr, fmt.Errorf("perform cleanup: %w", err)) } From d063fbb8b93f0053a0e734fa118816f0cdc5ad91 Mon Sep 17 00:00:00 2001 From: Pascal Fischer <32096965+pascal-fischer@users.noreply.github.com> Date: Tue, 3 Dec 2024 16:41:19 +0100 Subject: [PATCH 067/159] [management] merge update account peers in sync call (#2978) --- management/server/peer.go | 24 +++++++++--------------- 1 file changed, 9 insertions(+), 15 deletions(-) diff --git a/management/server/peer.go b/management/server/peer.go index a301544d5..474c2d665 100644 --- a/management/server/peer.go +++ b/management/server/peer.go @@ -678,10 +678,6 @@ func (am *DefaultAccountManager) SyncPeer(ctx context.Context, sync PeerSync, ac if err != nil { return nil, nil, nil, fmt.Errorf("failed to save peer: %w", err) } - - if sync.UpdateAccountPeers { - am.updateAccountPeers(ctx, account.Id) - } } peerNotValid, isStatusChanged, err := am.integratedPeerValidator.IsNotValidPeer(ctx, account.Id, peer, account.GetPeerGroupsList(peer.ID), account.Settings.Extra) @@ -689,17 +685,20 @@ func (am *DefaultAccountManager) SyncPeer(ctx context.Context, sync PeerSync, ac return nil, nil, nil, fmt.Errorf("failed to validate peer: %w", err) } - var postureChecks []*posture.Checks + postureChecks, err := am.getPeerPostureChecks(account, peer.ID) + if err != nil { + return nil, nil, nil, err + } + + if isStatusChanged || sync.UpdateAccountPeers || (updated && len(postureChecks) > 0) { + am.updateAccountPeers(ctx, account.Id) + } if peerNotValid { emptyMap := &NetworkMap{ Network: account.Network.Copy(), } - return peer, emptyMap, postureChecks, nil - } - - if isStatusChanged { - am.updateAccountPeers(ctx, account.Id) + return peer, emptyMap, []*posture.Checks{}, nil } validPeersMap, err := am.GetValidatedPeers(account) @@ -707,11 +706,6 @@ func (am *DefaultAccountManager) SyncPeer(ctx context.Context, sync PeerSync, ac return nil, nil, nil, fmt.Errorf("failed to get validated peers: %w", err) } - postureChecks, err = am.getPeerPostureChecks(account, peer.ID) - if err != nil { - return nil, nil, nil, err - } - customZone := account.GetPeersCustomZone(ctx, am.dnsDomain) return peer, account.GetPeerNetworkMap(ctx, peer.ID, customZone, validPeersMap, am.metrics.AccountManagerMetrics()), postureChecks, nil } From b50b89ba14b44f8fb3699b00136ea2debc703ca2 Mon Sep 17 00:00:00 2001 From: Maycon Santos Date: Wed, 4 Dec 2024 14:09:04 +0100 Subject: [PATCH 068/159] [client] Cleanup status resources on engine stop (#2981) cleanup leftovers from status recorder when stopping the engine --- client/internal/engine.go | 4 ++++ management/server/account_test.go | 8 ++++---- 2 files changed, 8 insertions(+), 4 deletions(-) diff --git a/client/internal/engine.go b/client/internal/engine.go index fc9620d80..782bb48bb 100644 --- a/client/internal/engine.go +++ b/client/internal/engine.go @@ -276,6 +276,10 @@ func (e *Engine) Stop() error { e.srWatcher.Close() } + e.statusRecorder.ReplaceOfflinePeers([]peer.State{}) + e.statusRecorder.UpdateDNSStates([]peer.NSGroupState{}) + e.statusRecorder.UpdateRelayStates([]relay.ProbeResult{}) + err := e.removeAllPeers() if err != nil { return fmt.Errorf("failed to remove all peers: %s", err) diff --git a/management/server/account_test.go 
b/management/server/account_test.go index 4ff812607..d952e118a 100644 --- a/management/server/account_test.go +++ b/management/server/account_test.go @@ -2998,10 +2998,10 @@ func BenchmarkSyncAndMarkPeer(b *testing.B) { minMsPerOpCICD float64 maxMsPerOpCICD float64 }{ - {"Small", 50, 5, 1, 3, 4, 10}, + {"Small", 50, 5, 1, 3, 3, 10}, {"Medium", 500, 100, 7, 13, 10, 60}, {"Large", 5000, 200, 65, 80, 60, 170}, - {"Small single", 50, 10, 1, 3, 4, 60}, + {"Small single", 50, 10, 1, 3, 3, 60}, {"Medium single", 500, 10, 7, 13, 10, 26}, {"Large 5", 5000, 15, 65, 80, 60, 170}, } @@ -3141,10 +3141,10 @@ func BenchmarkLoginPeer_NewPeer(b *testing.B) { }{ {"Small", 50, 5, 107, 120, 107, 140}, {"Medium", 500, 100, 105, 140, 105, 170}, - {"Large", 5000, 200, 180, 220, 180, 320}, + {"Large", 5000, 200, 180, 220, 180, 340}, {"Small single", 50, 10, 107, 120, 105, 140}, {"Medium single", 500, 10, 105, 140, 105, 170}, - {"Large 5", 5000, 15, 180, 220, 180, 320}, + {"Large 5", 5000, 15, 180, 220, 180, 340}, } log.SetOutput(io.Discard) From c853011a32d3fb052df7388dbbde92e9cfb3d31c Mon Sep 17 00:00:00 2001 From: Viktor Liu <17948409+lixmal@users.noreply.github.com> Date: Thu, 5 Dec 2024 12:28:35 +0100 Subject: [PATCH 069/159] [client] Don't return error in rule removal if protocol is not supported (#2990) --- client/internal/routemanager/systemops/systemops_linux.go | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/client/internal/routemanager/systemops/systemops_linux.go b/client/internal/routemanager/systemops/systemops_linux.go index 455e3407e..9041cbf2d 100644 --- a/client/internal/routemanager/systemops/systemops_linux.go +++ b/client/internal/routemanager/systemops/systemops_linux.go @@ -450,7 +450,7 @@ func addRule(params ruleParams) error { rule.Invert = params.invert rule.SuppressPrefixlen = params.suppressPrefix - if err := netlink.RuleAdd(rule); err != nil && !errors.Is(err, syscall.EEXIST) { + if err := netlink.RuleAdd(rule); err != nil && !errors.Is(err, syscall.EEXIST) && !errors.Is(err, syscall.EAFNOSUPPORT) { return fmt.Errorf("add routing rule: %w", err) } @@ -467,7 +467,7 @@ func removeRule(params ruleParams) error { rule.Priority = params.priority rule.SuppressPrefixlen = params.suppressPrefix - if err := netlink.RuleDel(rule); err != nil && !errors.Is(err, syscall.ENOENT) { + if err := netlink.RuleDel(rule); err != nil && !errors.Is(err, syscall.ENOENT) && !errors.Is(err, syscall.EAFNOSUPPORT) { return fmt.Errorf("remove routing rule: %w", err) } From 6cfbb1f3206ceafba1c411d2b5cff385dd38cad1 Mon Sep 17 00:00:00 2001 From: Viktor Liu <17948409+lixmal@users.noreply.github.com> Date: Thu, 5 Dec 2024 12:41:12 +0100 Subject: [PATCH 070/159] [client] Init route selector early (#2989) --- client/internal/routemanager/manager.go | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/client/internal/routemanager/manager.go b/client/internal/routemanager/manager.go index f1c4ae5ef..8bf3a91b0 100644 --- a/client/internal/routemanager/manager.go +++ b/client/internal/routemanager/manager.go @@ -124,6 +124,8 @@ func NewManager( // Init sets up the routing func (m *DefaultManager) Init() (nbnet.AddHookFunc, nbnet.RemoveHookFunc, error) { + m.routeSelector = m.initSelector() + if nbnet.CustomRoutingDisabled() { return nil, nil, nil } @@ -144,8 +146,6 @@ func (m *DefaultManager) Init() (nbnet.AddHookFunc, nbnet.RemoveHookFunc, error) return nil, nil, fmt.Errorf("setup routing: %w", err) } - m.routeSelector = m.initSelector() - log.Info("Routing setup complete") return 
beforePeerHook, afterPeerHook, nil } From e67fe89adb35fa2eb4afb54afc98515ea97c895d Mon Sep 17 00:00:00 2001 From: Maycon Santos Date: Thu, 5 Dec 2024 13:03:11 +0100 Subject: [PATCH 071/159] Reduce max wait time to initialize peer connections (#2984) * Reduce max wait time to initialize peer connections setting rand time range to 100-300ms instead of 100-800ms * remove min wait time --- client/internal/peer/conn.go | 5 ++--- 1 file changed, 2 insertions(+), 3 deletions(-) diff --git a/client/internal/peer/conn.go b/client/internal/peer/conn.go index a8de2fccb..3a698a82a 100644 --- a/client/internal/peer/conn.go +++ b/client/internal/peer/conn.go @@ -597,9 +597,8 @@ func (conn *Conn) doOnConnected(remoteRosenpassPubKey []byte, remoteRosenpassAdd } func (conn *Conn) waitInitialRandomSleepTime(ctx context.Context) { - minWait := 100 - maxWait := 800 - duration := time.Duration(rand.Intn(maxWait-minWait)+minWait) * time.Millisecond + maxWait := 300 + duration := time.Duration(rand.Intn(maxWait)) * time.Millisecond timeout := time.NewTimer(duration) defer timeout.Stop() From 713e320c4c9b605fffdc04365d0b4a0ed4b59faf Mon Sep 17 00:00:00 2001 From: "M. Essam" Date: Thu, 5 Dec 2024 15:15:23 +0200 Subject: [PATCH 072/159] Update account peers on login on meta change (#2991) * Update account peers on login on meta change * Factor out LoginPeer peer not found handling --- management/server/peer.go | 41 ++++++++++++++++++++++++--------------- 1 file changed, 25 insertions(+), 16 deletions(-) diff --git a/management/server/peer.go b/management/server/peer.go index 474c2d665..761aa39a2 100644 --- a/management/server/peer.go +++ b/management/server/peer.go @@ -710,26 +710,30 @@ func (am *DefaultAccountManager) SyncPeer(ctx context.Context, sync PeerSync, ac return peer, account.GetPeerNetworkMap(ctx, peer.ID, customZone, validPeersMap, am.metrics.AccountManagerMetrics()), postureChecks, nil } +func (am *DefaultAccountManager) handlePeerLoginNotFound(ctx context.Context, login PeerLogin, err error) (*nbpeer.Peer, *NetworkMap, []*posture.Checks, error) { + if errStatus, ok := status.FromError(err); ok && errStatus.Type() == status.NotFound { + // we couldn't find this peer by its public key which can mean that peer hasn't been registered yet. + // Try registering it. + newPeer := &nbpeer.Peer{ + Key: login.WireGuardPubKey, + Meta: login.Meta, + SSHKey: login.SSHKey, + Location: nbpeer.Location{ConnectionIP: login.ConnectionIP}, + } + + return am.AddPeer(ctx, login.SetupKey, login.UserID, newPeer) + } + + log.WithContext(ctx).Errorf("failed while logging in peer %s: %v", login.WireGuardPubKey, err) + return nil, nil, nil, status.Errorf(status.Internal, "failed while logging in peer") +} + // LoginPeer logs in or registers a peer. // If peer doesn't exist the function checks whether a setup key or a user is present and registers a new peer if so. func (am *DefaultAccountManager) LoginPeer(ctx context.Context, login PeerLogin) (*nbpeer.Peer, *NetworkMap, []*posture.Checks, error) { accountID, err := am.Store.GetAccountIDByPeerPubKey(ctx, login.WireGuardPubKey) if err != nil { - if errStatus, ok := status.FromError(err); ok && errStatus.Type() == status.NotFound { - // we couldn't find this peer by its public key which can mean that peer hasn't been registered yet. - // Try registering it. 
- newPeer := &nbpeer.Peer{ - Key: login.WireGuardPubKey, - Meta: login.Meta, - SSHKey: login.SSHKey, - Location: nbpeer.Location{ConnectionIP: login.ConnectionIP}, - } - - return am.AddPeer(ctx, login.SetupKey, login.UserID, newPeer) - } - - log.WithContext(ctx).Errorf("failed while logging in peer %s: %v", login.WireGuardPubKey, err) - return nil, nil, nil, status.Errorf(status.Internal, "failed while logging in peer") + return am.handlePeerLoginNotFound(ctx, login, err) } // when the client sends a login request with a JWT which is used to get the user ID, @@ -828,7 +832,12 @@ func (am *DefaultAccountManager) LoginPeer(ctx context.Context, login PeerLogin) return nil, nil, nil, err } - if updateRemotePeers || isStatusChanged { + postureChecks, err := am.getPeerPostureChecks(account, peer.ID) + if err != nil { + return nil, nil, nil, err + } + + if updateRemotePeers || isStatusChanged || (updated && len(postureChecks) > 0) { am.updateAccountPeers(ctx, accountID) } From ff330e644e2bd9c6658fdc1daae0798886e32adb Mon Sep 17 00:00:00 2001 From: Edouard Vanbelle <15628033+EdouardVanbelle@users.noreply.github.com> Date: Thu, 5 Dec 2024 15:38:00 +0100 Subject: [PATCH 073/159] upgrade zcalusic/sysinfo@v1.1.3 (add serial for ARM arch) (#2954) Signed-off-by: Edouard Vanbelle --- go.mod | 2 +- go.sum | 4 ++-- 2 files changed, 3 insertions(+), 3 deletions(-) diff --git a/go.mod b/go.mod index c08713f2b..2b4111ce3 100644 --- a/go.mod +++ b/go.mod @@ -80,7 +80,7 @@ require ( github.com/testcontainers/testcontainers-go/modules/postgres v0.31.0 github.com/things-go/go-socks5 v0.0.4 github.com/yusufpapurcu/wmi v1.2.4 - github.com/zcalusic/sysinfo v1.0.2 + github.com/zcalusic/sysinfo v1.1.3 go.opentelemetry.io/contrib/instrumentation/google.golang.org/grpc/otelgrpc v0.49.0 go.opentelemetry.io/otel v1.26.0 go.opentelemetry.io/otel/exporters/prometheus v0.48.0 diff --git a/go.sum b/go.sum index 5519d6ded..35abe82d2 100644 --- a/go.sum +++ b/go.sum @@ -708,8 +708,8 @@ github.com/yuin/goldmark v1.7.1 h1:3bajkSilaCbjdKVsKdZjZCLBNPL9pYzrCakKaf4U49U= github.com/yuin/goldmark v1.7.1/go.mod h1:uzxRWxtg69N339t3louHJ7+O03ezfj6PlliRlaOzY1E= github.com/yusufpapurcu/wmi v1.2.4 h1:zFUKzehAFReQwLys1b/iSMl+JQGSCSjtVqQn9bBrPo0= github.com/yusufpapurcu/wmi v1.2.4/go.mod h1:SBZ9tNy3G9/m5Oi98Zks0QjeHVDvuK0qfxQmPyzfmi0= -github.com/zcalusic/sysinfo v1.0.2 h1:nwTTo2a+WQ0NXwo0BGRojOJvJ/5XKvQih+2RrtWqfxc= -github.com/zcalusic/sysinfo v1.0.2/go.mod h1:kluzTYflRWo6/tXVMJPdEjShsbPpsFRyy+p1mBQPC30= +github.com/zcalusic/sysinfo v1.1.3 h1:u/AVENkuoikKuIZ4sUEJ6iibpmQP6YpGD8SSMCrqAF0= +github.com/zcalusic/sysinfo v1.1.3/go.mod h1:NX+qYnWGtJVPV0yWldff9uppNKU4h40hJIRPf/pGLv4= github.com/zeebo/assert v1.1.0 h1:hU1L1vLTHsnO8x8c9KAR5GmM5QscxHg5RNU5z5qbUWY= github.com/zeebo/assert v1.1.0/go.mod h1:Pq9JiuJQpG8JLJdtkwrJESF0Foym2/D9XMU5ciN/wJ0= github.com/zeebo/blake3 v0.2.3 h1:TFoLXsjeXqRNFxSbk35Dk4YtszE/MQQGK10BH4ptoTg= From e40a29ba17749e0276510149b8da7b3b71550ff4 Mon Sep 17 00:00:00 2001 From: Pascal Fischer <32096965+pascal-fischer@users.noreply.github.com> Date: Fri, 6 Dec 2024 16:51:42 +0100 Subject: [PATCH 074/159] [client] Add support for state manager on iOS (#2996) --- client/internal/config.go | 21 +++-- client/internal/connect.go | 2 + client/internal/engine.go | 11 +++ client/internal/mobile_dependency.go | 1 + .../routeselector/routeselector_test.go | 85 +++++++++++++++++++ client/ios/NetBirdSDK/client.go | 9 +- client/ios/NetBirdSDK/preferences.go | 5 +- client/ios/NetBirdSDK/preferences_test.go | 11 ++- 8 files changed, 130 
insertions(+), 15 deletions(-) diff --git a/client/internal/config.go b/client/internal/config.go index ce87835cd..998690ef1 100644 --- a/client/internal/config.go +++ b/client/internal/config.go @@ -46,6 +46,7 @@ type ConfigInput struct { ManagementURL string AdminURL string ConfigPath string + StateFilePath string PreSharedKey *string ServerSSHAllowed *bool NATExternalIPs []string @@ -105,10 +106,10 @@ type Config struct { // DNSRouteInterval is the interval in which the DNS routes are updated DNSRouteInterval time.Duration - //Path to a certificate used for mTLS authentication + // Path to a certificate used for mTLS authentication ClientCertPath string - //Path to corresponding private key of ClientCertPath + // Path to corresponding private key of ClientCertPath ClientCertKeyPath string ClientCertKeyPair *tls.Certificate `json:"-"` @@ -116,7 +117,7 @@ type Config struct { // ReadConfig read config file and return with Config. If it is not exists create a new with default values func ReadConfig(configPath string) (*Config, error) { - if configFileIsExists(configPath) { + if fileExists(configPath) { err := util.EnforcePermission(configPath) if err != nil { log.Errorf("failed to enforce permission on config dir: %v", err) @@ -149,7 +150,7 @@ func ReadConfig(configPath string) (*Config, error) { // UpdateConfig update existing configuration according to input configuration and return with the configuration func UpdateConfig(input ConfigInput) (*Config, error) { - if !configFileIsExists(input.ConfigPath) { + if !fileExists(input.ConfigPath) { return nil, status.Errorf(codes.NotFound, "config file doesn't exist") } @@ -158,7 +159,7 @@ func UpdateConfig(input ConfigInput) (*Config, error) { // UpdateOrCreateConfig reads existing config or generates a new one func UpdateOrCreateConfig(input ConfigInput) (*Config, error) { - if !configFileIsExists(input.ConfigPath) { + if !fileExists(input.ConfigPath) { log.Infof("generating new config %s", input.ConfigPath) cfg, err := createNewConfig(input) if err != nil { @@ -472,11 +473,19 @@ func isPreSharedKeyHidden(preSharedKey *string) bool { return false } -func configFileIsExists(path string) bool { +func fileExists(path string) bool { _, err := os.Stat(path) return !os.IsNotExist(err) } +func createFile(path string) error { + file, err := os.Create(path) + if err != nil { + return err + } + return file.Close() +} + // UpdateOldManagementURL checks whether client can switch to the new Management URL with port 443 and the management domain. // If it can switch, then it updates the config and returns a new one. Otherwise, it returns the provided config. // The check is performed only for the NetBird's managed version. diff --git a/client/internal/connect.go b/client/internal/connect.go index 4848b1c11..782984e27 100644 --- a/client/internal/connect.go +++ b/client/internal/connect.go @@ -91,6 +91,7 @@ func (c *ConnectClient) RunOniOS( fileDescriptor int32, networkChangeListener listener.NetworkChangeListener, dnsManager dns.IosDnsManager, + stateFilePath string, ) error { // Set GC percent to 5% to reduce memory usage as iOS only allows 50MB of memory for the extension. 
debug.SetGCPercent(5) @@ -99,6 +100,7 @@ func (c *ConnectClient) RunOniOS( FileDescriptor: fileDescriptor, NetworkChangeListener: networkChangeListener, DnsManager: dnsManager, + StateFilePath: stateFilePath, } return c.run(mobileDependency, nil, nil) } diff --git a/client/internal/engine.go b/client/internal/engine.go index 782bb48bb..63caec02a 100644 --- a/client/internal/engine.go +++ b/client/internal/engine.go @@ -243,6 +243,17 @@ func NewEngineWithProbes( probes: probes, checks: checks, } + if runtime.GOOS == "ios" { + if !fileExists(mobileDep.StateFilePath) { + err := createFile(mobileDep.StateFilePath) + if err != nil { + log.Errorf("failed to create state file: %v", err) + // we are not exiting as we can run without the state manager + } + } + + engine.stateManager = statemanager.New(mobileDep.StateFilePath) + } if path := statemanager.GetDefaultStatePath(); path != "" { engine.stateManager = statemanager.New(path) } diff --git a/client/internal/mobile_dependency.go b/client/internal/mobile_dependency.go index 2b0c92cc6..4ac0fc141 100644 --- a/client/internal/mobile_dependency.go +++ b/client/internal/mobile_dependency.go @@ -19,4 +19,5 @@ type MobileDependency struct { // iOS only DnsManager dns.IosDnsManager FileDescriptor int32 + StateFilePath string } diff --git a/client/internal/routeselector/routeselector_test.go b/client/internal/routeselector/routeselector_test.go index 7df433f92..b1671f254 100644 --- a/client/internal/routeselector/routeselector_test.go +++ b/client/internal/routeselector/routeselector_test.go @@ -273,3 +273,88 @@ func TestRouteSelector_FilterSelected(t *testing.T) { "route2|192.168.0.0/16": {}, }, filtered) } + +func TestRouteSelector_NewRoutesBehavior(t *testing.T) { + initialRoutes := []route.NetID{"route1", "route2", "route3"} + newRoutes := []route.NetID{"route1", "route2", "route3", "route4", "route5"} + + tests := []struct { + name string + initialState func(rs *routeselector.RouteSelector) error // Setup initial state + wantNewSelected []route.NetID // Expected selected routes after new routes appear + }{ + { + name: "New routes with initial selectAll state", + initialState: func(rs *routeselector.RouteSelector) error { + rs.SelectAllRoutes() + return nil + }, + // When selectAll is true, all routes including new ones should be selected + wantNewSelected: []route.NetID{"route1", "route2", "route3", "route4", "route5"}, + }, + { + name: "New routes after specific selection", + initialState: func(rs *routeselector.RouteSelector) error { + return rs.SelectRoutes([]route.NetID{"route1", "route2"}, false, initialRoutes) + }, + // When specific routes were selected, new routes should remain unselected + wantNewSelected: []route.NetID{"route1", "route2"}, + }, + { + name: "New routes after deselect all", + initialState: func(rs *routeselector.RouteSelector) error { + rs.DeselectAllRoutes() + return nil + }, + // After deselect all, new routes should remain unselected + wantNewSelected: []route.NetID{}, + }, + { + name: "New routes after deselecting specific routes", + initialState: func(rs *routeselector.RouteSelector) error { + rs.SelectAllRoutes() + return rs.DeselectRoutes([]route.NetID{"route1"}, initialRoutes) + }, + // After deselecting specific routes, new routes should remain unselected + wantNewSelected: []route.NetID{"route2", "route3"}, + }, + { + name: "New routes after selecting with append", + initialState: func(rs *routeselector.RouteSelector) error { + return rs.SelectRoutes([]route.NetID{"route1"}, true, initialRoutes) + }, + // When 
routes were appended, new routes should remain unselected + wantNewSelected: []route.NetID{"route1"}, + }, + } + + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + rs := routeselector.NewRouteSelector() + + // Setup initial state + err := tt.initialState(rs) + require.NoError(t, err) + + // Verify selection state with new routes + for _, id := range newRoutes { + assert.Equal(t, rs.IsSelected(id), slices.Contains(tt.wantNewSelected, id), + "Route %s selection state incorrect", id) + } + + // Additional verification using FilterSelected + routes := route.HAMap{ + "route1|10.0.0.0/8": {}, + "route2|192.168.0.0/16": {}, + "route3|172.16.0.0/12": {}, + "route4|10.10.0.0/16": {}, + "route5|192.168.1.0/24": {}, + } + + filtered := rs.FilterSelected(routes) + expectedLen := len(tt.wantNewSelected) + assert.Equal(t, expectedLen, len(filtered), + "FilterSelected returned wrong number of routes, got %d want %d", len(filtered), expectedLen) + }) + } +} diff --git a/client/ios/NetBirdSDK/client.go b/client/ios/NetBirdSDK/client.go index 9d65bdbe0..6f501e0c6 100644 --- a/client/ios/NetBirdSDK/client.go +++ b/client/ios/NetBirdSDK/client.go @@ -59,6 +59,7 @@ func init() { // Client struct manage the life circle of background service type Client struct { cfgFile string + stateFile string recorder *peer.Status ctxCancel context.CancelFunc ctxCancelLock *sync.Mutex @@ -73,9 +74,10 @@ type Client struct { } // NewClient instantiate a new Client -func NewClient(cfgFile, deviceName string, osVersion string, osName string, networkChangeListener NetworkChangeListener, dnsManager DnsManager) *Client { +func NewClient(cfgFile, stateFile, deviceName string, osVersion string, osName string, networkChangeListener NetworkChangeListener, dnsManager DnsManager) *Client { return &Client{ cfgFile: cfgFile, + stateFile: stateFile, deviceName: deviceName, osName: osName, osVersion: osVersion, @@ -91,7 +93,8 @@ func (c *Client) Run(fd int32, interfaceName string) error { log.Infof("Starting NetBird client") log.Debugf("Tunnel uses interface: %s", interfaceName) cfg, err := internal.UpdateOrCreateConfig(internal.ConfigInput{ - ConfigPath: c.cfgFile, + ConfigPath: c.cfgFile, + StateFilePath: c.stateFile, }) if err != nil { return err @@ -124,7 +127,7 @@ func (c *Client) Run(fd int32, interfaceName string) error { cfg.WgIface = interfaceName c.connectClient = internal.NewConnectClient(ctx, cfg, c.recorder) - return c.connectClient.RunOniOS(fd, c.networkChangeListener, c.dnsManager) + return c.connectClient.RunOniOS(fd, c.networkChangeListener, c.dnsManager, c.stateFile) } // Stop the internal client and free the resources diff --git a/client/ios/NetBirdSDK/preferences.go b/client/ios/NetBirdSDK/preferences.go index b78146679..5a0abd9a7 100644 --- a/client/ios/NetBirdSDK/preferences.go +++ b/client/ios/NetBirdSDK/preferences.go @@ -10,9 +10,10 @@ type Preferences struct { } // NewPreferences create new Preferences instance -func NewPreferences(configPath string) *Preferences { +func NewPreferences(configPath string, stateFilePath string) *Preferences { ci := internal.ConfigInput{ - ConfigPath: configPath, + ConfigPath: configPath, + StateFilePath: stateFilePath, } return &Preferences{ci} } diff --git a/client/ios/NetBirdSDK/preferences_test.go b/client/ios/NetBirdSDK/preferences_test.go index aa6a475ae..7e5325a00 100644 --- a/client/ios/NetBirdSDK/preferences_test.go +++ b/client/ios/NetBirdSDK/preferences_test.go @@ -9,7 +9,8 @@ import ( func TestPreferences_DefaultValues(t *testing.T) { cfgFile := 
filepath.Join(t.TempDir(), "netbird.json") - p := NewPreferences(cfgFile) + stateFile := filepath.Join(t.TempDir(), "state.json") + p := NewPreferences(cfgFile, stateFile) defaultVar, err := p.GetAdminURL() if err != nil { t.Fatalf("failed to read default value: %s", err) @@ -42,7 +43,8 @@ func TestPreferences_DefaultValues(t *testing.T) { func TestPreferences_ReadUncommitedValues(t *testing.T) { exampleString := "exampleString" cfgFile := filepath.Join(t.TempDir(), "netbird.json") - p := NewPreferences(cfgFile) + stateFile := filepath.Join(t.TempDir(), "state.json") + p := NewPreferences(cfgFile, stateFile) p.SetAdminURL(exampleString) resp, err := p.GetAdminURL() @@ -79,7 +81,8 @@ func TestPreferences_Commit(t *testing.T) { exampleURL := "https://myurl.com:443" examplePresharedKey := "topsecret" cfgFile := filepath.Join(t.TempDir(), "netbird.json") - p := NewPreferences(cfgFile) + stateFile := filepath.Join(t.TempDir(), "state.json") + p := NewPreferences(cfgFile, stateFile) p.SetAdminURL(exampleURL) p.SetManagementURL(exampleURL) @@ -90,7 +93,7 @@ func TestPreferences_Commit(t *testing.T) { t.Fatalf("failed to save changes: %s", err) } - p = NewPreferences(cfgFile) + p = NewPreferences(cfgFile, stateFile) resp, err := p.GetAdminURL() if err != nil { t.Fatalf("failed to read admin url: %s", err) From 9a96b91d9d679e4bf6081c76a3e341e946f340f2 Mon Sep 17 00:00:00 2001 From: bcmmbaga Date: Mon, 9 Dec 2024 14:21:28 +0100 Subject: [PATCH 075/159] Fix merge Signed-off-by: bcmmbaga --- management/server/peer.go | 85 +++++++++++++++++++++++++++++++++------ 1 file changed, 73 insertions(+), 12 deletions(-) diff --git a/management/server/peer.go b/management/server/peer.go index f6e4de7b7..a174bc415 100644 --- a/management/server/peer.go +++ b/management/server/peer.go @@ -649,6 +649,7 @@ func (am *DefaultAccountManager) SyncPeer(ctx context.Context, sync PeerSync, ac var isStatusChanged bool var updated bool var err error + var postureChecks []*posture.Checks err = am.Store.ExecuteInTransaction(ctx, func(transaction Store) error { peer, err = transaction.GetPeerByPeerPubKey(ctx, LockingStrengthUpdate, sync.WireGuardPubKey) @@ -690,7 +691,11 @@ func (am *DefaultAccountManager) SyncPeer(ctx context.Context, sync PeerSync, ac if updated { am.metrics.AccountManagerMetrics().CountPeerMetUpdate() log.WithContext(ctx).Tracef("peer %s metadata updated", peer.ID) - err = transaction.SavePeer(ctx, LockingStrengthUpdate, accountID, peer) + if err = transaction.SavePeer(ctx, LockingStrengthUpdate, accountID, peer); err != nil { + return err + } + + postureChecks, err = getPeerPostureChecks(ctx, transaction, accountID, peer.ID) if err != nil { return err } @@ -701,12 +706,7 @@ func (am *DefaultAccountManager) SyncPeer(ctx context.Context, sync PeerSync, ac return nil, nil, nil, err } - postureChecks, err := am.getPeerPostureChecks(account, peer.ID) - if err != nil { - return nil, nil, nil, err - } - - if isStatusChanged || (updated && sync.UpdateAccountPeers) || (updated && len(postureChecks) > 0){ + if isStatusChanged || sync.UpdateAccountPeers || (updated && len(postureChecks) > 0) { am.updateAccountPeers(ctx, accountID) } @@ -764,6 +764,7 @@ func (am *DefaultAccountManager) LoginPeer(ctx context.Context, login PeerLogin) var isRequiresApproval bool var isStatusChanged bool var isPeerUpdated bool + var postureChecks []*posture.Checks err = am.Store.ExecuteInTransaction(ctx, func(transaction Store) error { peer, err = transaction.GetPeerByPeerPubKey(ctx, LockingStrengthUpdate, login.WireGuardPubKey) @@ -809,6 
+810,11 @@ func (am *DefaultAccountManager) LoginPeer(ctx context.Context, login PeerLogin) if isPeerUpdated { am.metrics.AccountManagerMetrics().CountPeerMetUpdate() shouldStorePeer = true + + postureChecks, err = getPeerPostureChecks(ctx, transaction, accountID, peer.ID) + if err != nil { + return err + } } if peer.SSHKey != login.SSHKey { @@ -831,11 +837,6 @@ func (am *DefaultAccountManager) LoginPeer(ctx context.Context, login PeerLogin) unlockPeer() unlockPeer = nil - postureChecks, err := am.getPeerPostureChecks(account, peer.ID) - if err != nil { - return nil, nil, nil, err - } - if updateRemotePeers || isStatusChanged || (isPeerUpdated && len(postureChecks) > 0) { am.updateAccountPeers(ctx, accountID) } @@ -843,6 +844,66 @@ func (am *DefaultAccountManager) LoginPeer(ctx context.Context, login PeerLogin) return am.getValidatedPeerWithMap(ctx, isRequiresApproval, accountID, peer) } +// getPeerPostureChecks returns the posture checks for the peer. +func getPeerPostureChecks(ctx context.Context, transaction Store, accountID, peerID string) ([]*posture.Checks, error) { + policies, err := transaction.GetAccountPolicies(ctx, LockingStrengthShare, accountID) + if err != nil { + return nil, err + } + + if len(policies) == 0 { + return nil, nil + } + + var peerPostureChecksIDs []string + + for _, policy := range policies { + if !policy.Enabled || len(policy.SourcePostureChecks) == 0 { + continue + } + + postureChecksIDs, err := processPeerPostureChecks(ctx, transaction, policy, accountID, peerID) + if err != nil { + return nil, err + } + + peerPostureChecksIDs = append(peerPostureChecksIDs, postureChecksIDs...) + } + + peerPostureChecks, err := transaction.GetPostureChecksByIDs(ctx, LockingStrengthShare, accountID, peerPostureChecksIDs) + if err != nil { + return nil, err + } + + return maps.Values(peerPostureChecks), nil +} + +// processPeerPostureChecks checks if the peer is in the source group of the policy and returns the posture checks. +func processPeerPostureChecks(ctx context.Context, transaction Store, policy *Policy, accountID, peerID string) ([]string, error) { + for _, rule := range policy.Rules { + if !rule.Enabled { + continue + } + + sourceGroups, err := transaction.GetGroupsByIDs(ctx, LockingStrengthShare, accountID, rule.Sources) + if err != nil { + return nil, err + } + + for _, sourceGroup := range rule.Sources { + group, ok := sourceGroups[sourceGroup] + if !ok { + return nil, fmt.Errorf("failed to check peer in policy source group") + } + + if slices.Contains(group.Peers, peerID) { + return policy.SourcePostureChecks, nil + } + } + } + return nil, nil +} + // checkIFPeerNeedsLoginWithoutLock checks if the peer needs login without acquiring the account lock. The check validate if the peer was not added via SSO // and if the peer login is expired. 
// The NetBird client doesn't have a way to check if the peer needs login besides sending a login request From 2147bf75eb25a4d63ea428484a1dbfac5e2fbf82 Mon Sep 17 00:00:00 2001 From: Maycon Santos Date: Mon, 9 Dec 2024 17:10:31 +0100 Subject: [PATCH 076/159] [client] Add peer conn init limit (#3001) Limit the peer connection initialization to 200 peers at the same time --- client/internal/engine.go | 6 +- client/internal/peer/conn.go | 9 ++- client/internal/peer/conn_test.go | 9 +-- util/semaphore-group/semaphore_group.go | 48 ++++++++++++++ util/semaphore-group/semaphore_group_test.go | 66 ++++++++++++++++++++ 5 files changed, 131 insertions(+), 7 deletions(-) create mode 100644 util/semaphore-group/semaphore_group.go create mode 100644 util/semaphore-group/semaphore_group_test.go diff --git a/client/internal/engine.go b/client/internal/engine.go index 63caec02a..34219def1 100644 --- a/client/internal/engine.go +++ b/client/internal/engine.go @@ -39,6 +39,7 @@ import ( "github.com/netbirdio/netbird/client/internal/routemanager" "github.com/netbirdio/netbird/client/internal/routemanager/systemops" "github.com/netbirdio/netbird/client/internal/statemanager" + semaphoregroup "github.com/netbirdio/netbird/util/semaphore-group" nbssh "github.com/netbirdio/netbird/client/ssh" "github.com/netbirdio/netbird/client/system" @@ -62,6 +63,7 @@ import ( const ( PeerConnectionTimeoutMax = 45000 // ms PeerConnectionTimeoutMin = 30000 // ms + connInitLimit = 200 ) var ErrResetConnection = fmt.Errorf("reset connection") @@ -177,6 +179,7 @@ type Engine struct { // Network map persistence persistNetworkMap bool latestNetworkMap *mgmProto.NetworkMap + connSemaphore *semaphoregroup.SemaphoreGroup } // Peer is an instance of the Connection Peer @@ -242,6 +245,7 @@ func NewEngineWithProbes( statusRecorder: statusRecorder, probes: probes, checks: checks, + connSemaphore: semaphoregroup.NewSemaphoreGroup(connInitLimit), } if runtime.GOOS == "ios" { if !fileExists(mobileDep.StateFilePath) { @@ -1051,7 +1055,7 @@ func (e *Engine) createPeerConn(pubKey string, allowedIPs string) (*peer.Conn, e }, } - peerConn, err := peer.NewConn(e.ctx, config, e.statusRecorder, e.signaler, e.mobileDep.IFaceDiscover, e.relayManager, e.srWatcher) + peerConn, err := peer.NewConn(e.ctx, config, e.statusRecorder, e.signaler, e.mobileDep.IFaceDiscover, e.relayManager, e.srWatcher, e.connSemaphore) if err != nil { return nil, err } diff --git a/client/internal/peer/conn.go b/client/internal/peer/conn.go index 3a698a82a..5c2e2cb60 100644 --- a/client/internal/peer/conn.go +++ b/client/internal/peer/conn.go @@ -23,6 +23,7 @@ import ( relayClient "github.com/netbirdio/netbird/relay/client" "github.com/netbirdio/netbird/route" nbnet "github.com/netbirdio/netbird/util/net" + semaphoregroup "github.com/netbirdio/netbird/util/semaphore-group" ) type ConnPriority int @@ -104,12 +105,13 @@ type Conn struct { wgProxyICE wgproxy.Proxy wgProxyRelay wgproxy.Proxy - guard *guard.Guard + guard *guard.Guard + semaphore *semaphoregroup.SemaphoreGroup } // NewConn creates a new not opened Conn to the remote peer. 
// To establish a connection run Conn.Open -func NewConn(engineCtx context.Context, config ConnConfig, statusRecorder *Status, signaler *Signaler, iFaceDiscover stdnet.ExternalIFaceDiscover, relayManager *relayClient.Manager, srWatcher *guard.SRWatcher) (*Conn, error) { +func NewConn(engineCtx context.Context, config ConnConfig, statusRecorder *Status, signaler *Signaler, iFaceDiscover stdnet.ExternalIFaceDiscover, relayManager *relayClient.Manager, srWatcher *guard.SRWatcher, semaphore *semaphoregroup.SemaphoreGroup) (*Conn, error) { allowedIP, _, err := net.ParseCIDR(config.WgConfig.AllowedIps) if err != nil { log.Errorf("failed to parse allowedIPS: %v", err) @@ -130,6 +132,7 @@ func NewConn(engineCtx context.Context, config ConnConfig, statusRecorder *Statu allowedIP: allowedIP, statusRelay: NewAtomicConnStatus(), statusICE: NewAtomicConnStatus(), + semaphore: semaphore, } rFns := WorkerRelayCallbacks{ @@ -169,6 +172,7 @@ func NewConn(engineCtx context.Context, config ConnConfig, statusRecorder *Statu // It will try to establish a connection using ICE and in parallel with relay. The higher priority connection type will // be used. func (conn *Conn) Open() { + conn.semaphore.Add(conn.ctx) conn.log.Debugf("open connection to peer") conn.mu.Lock() @@ -191,6 +195,7 @@ func (conn *Conn) Open() { } func (conn *Conn) startHandshakeAndReconnect(ctx context.Context) { + defer conn.semaphore.Done(conn.ctx) conn.waitInitialRandomSleepTime(ctx) err := conn.handshaker.sendOffer() diff --git a/client/internal/peer/conn_test.go b/client/internal/peer/conn_test.go index 039952588..b3e9d5b60 100644 --- a/client/internal/peer/conn_test.go +++ b/client/internal/peer/conn_test.go @@ -14,6 +14,7 @@ import ( "github.com/netbirdio/netbird/client/internal/peer/ice" "github.com/netbirdio/netbird/client/internal/stdnet" "github.com/netbirdio/netbird/util" + semaphoregroup "github.com/netbirdio/netbird/util/semaphore-group" ) var connConf = ConnConfig{ @@ -46,7 +47,7 @@ func TestNewConn_interfaceFilter(t *testing.T) { func TestConn_GetKey(t *testing.T) { swWatcher := guard.NewSRWatcher(nil, nil, nil, connConf.ICEConfig) - conn, err := NewConn(context.Background(), connConf, nil, nil, nil, nil, swWatcher) + conn, err := NewConn(context.Background(), connConf, nil, nil, nil, nil, swWatcher, semaphoregroup.NewSemaphoreGroup(1)) if err != nil { return } @@ -58,7 +59,7 @@ func TestConn_GetKey(t *testing.T) { func TestConn_OnRemoteOffer(t *testing.T) { swWatcher := guard.NewSRWatcher(nil, nil, nil, connConf.ICEConfig) - conn, err := NewConn(context.Background(), connConf, NewRecorder("https://mgm"), nil, nil, nil, swWatcher) + conn, err := NewConn(context.Background(), connConf, NewRecorder("https://mgm"), nil, nil, nil, swWatcher, semaphoregroup.NewSemaphoreGroup(1)) if err != nil { return } @@ -92,7 +93,7 @@ func TestConn_OnRemoteOffer(t *testing.T) { func TestConn_OnRemoteAnswer(t *testing.T) { swWatcher := guard.NewSRWatcher(nil, nil, nil, connConf.ICEConfig) - conn, err := NewConn(context.Background(), connConf, NewRecorder("https://mgm"), nil, nil, nil, swWatcher) + conn, err := NewConn(context.Background(), connConf, NewRecorder("https://mgm"), nil, nil, nil, swWatcher, semaphoregroup.NewSemaphoreGroup(1)) if err != nil { return } @@ -125,7 +126,7 @@ func TestConn_OnRemoteAnswer(t *testing.T) { } func TestConn_Status(t *testing.T) { swWatcher := guard.NewSRWatcher(nil, nil, nil, connConf.ICEConfig) - conn, err := NewConn(context.Background(), connConf, NewRecorder("https://mgm"), nil, nil, nil, swWatcher) + 
conn, err := NewConn(context.Background(), connConf, NewRecorder("https://mgm"), nil, nil, nil, swWatcher, semaphoregroup.NewSemaphoreGroup(1)) if err != nil { return } diff --git a/util/semaphore-group/semaphore_group.go b/util/semaphore-group/semaphore_group.go new file mode 100644 index 000000000..ad74e1bfc --- /dev/null +++ b/util/semaphore-group/semaphore_group.go @@ -0,0 +1,48 @@ +package semaphoregroup + +import ( + "context" + "sync" +) + +// SemaphoreGroup is a custom type that combines sync.WaitGroup and a semaphore. +type SemaphoreGroup struct { + waitGroup sync.WaitGroup + semaphore chan struct{} +} + +// NewSemaphoreGroup creates a new SemaphoreGroup with the specified semaphore limit. +func NewSemaphoreGroup(limit int) *SemaphoreGroup { + return &SemaphoreGroup{ + semaphore: make(chan struct{}, limit), + } +} + +// Add increments the internal WaitGroup counter and acquires a semaphore slot. +func (sg *SemaphoreGroup) Add(ctx context.Context) { + sg.waitGroup.Add(1) + + // Acquire semaphore slot + select { + case <-ctx.Done(): + return + case sg.semaphore <- struct{}{}: + } +} + +// Done decrements the internal WaitGroup counter and releases a semaphore slot. +func (sg *SemaphoreGroup) Done(ctx context.Context) { + sg.waitGroup.Done() + + // Release semaphore slot + select { + case <-ctx.Done(): + return + case <-sg.semaphore: + } +} + +// Wait waits until the internal WaitGroup counter is zero. +func (sg *SemaphoreGroup) Wait() { + sg.waitGroup.Wait() +} diff --git a/util/semaphore-group/semaphore_group_test.go b/util/semaphore-group/semaphore_group_test.go new file mode 100644 index 000000000..d4491cf77 --- /dev/null +++ b/util/semaphore-group/semaphore_group_test.go @@ -0,0 +1,66 @@ +package semaphoregroup + +import ( + "context" + "testing" + "time" +) + +func TestSemaphoreGroup(t *testing.T) { + semGroup := NewSemaphoreGroup(2) + + for i := 0; i < 5; i++ { + semGroup.Add(context.Background()) + go func(id int) { + defer semGroup.Done(context.Background()) + + got := len(semGroup.semaphore) + if got == 0 { + t.Errorf("Expected semaphore length > 0 , got 0") + } + + time.Sleep(time.Millisecond) + t.Logf("Goroutine %d is running\n", id) + }(i) + } + + semGroup.Wait() + + want := 0 + got := len(semGroup.semaphore) + if got != want { + t.Errorf("Expected semaphore length %d, got %d", want, got) + } +} + +func TestSemaphoreGroupContext(t *testing.T) { + semGroup := NewSemaphoreGroup(1) + semGroup.Add(context.Background()) + ctx, cancel := context.WithTimeout(context.Background(), 1*time.Second) + t.Cleanup(cancel) + rChan := make(chan struct{}) + + go func() { + semGroup.Add(ctx) + rChan <- struct{}{} + }() + select { + case <-rChan: + case <-time.NewTimer(2 * time.Second).C: + t.Error("Adding to semaphore group should not block when context is not done") + } + + semGroup.Done(context.Background()) + + ctxDone, cancelDone := context.WithTimeout(context.Background(), 1*time.Second) + t.Cleanup(cancelDone) + go func() { + semGroup.Done(ctxDone) + rChan <- struct{}{} + }() + select { + case <-rChan: + case <-time.NewTimer(2 * time.Second).C: + t.Error("Releasing from semaphore group should not block when context is not done") + } +} From 97bb74f824786c20f78f3131bb8ed3e9ece26782 Mon Sep 17 00:00:00 2001 From: Bethuel Mmbaga Date: Mon, 9 Dec 2024 18:40:06 +0100 Subject: [PATCH 077/159] Remove peer login log (#3005) Signed-off-by: bcmmbaga --- management/server/peer.go | 1 - 1 file changed, 1 deletion(-) diff --git a/management/server/peer.go b/management/server/peer.go index 
761aa39a2..ba211be96 100644 --- a/management/server/peer.go +++ b/management/server/peer.go @@ -740,7 +740,6 @@ func (am *DefaultAccountManager) LoginPeer(ctx context.Context, login PeerLogin) // it means that the client has already checked if it needs login and had been through the SSO flow // so, we can skip this check and directly proceed with the login if login.UserID == "" { - log.Info("Peer needs login") err = am.checkIFPeerNeedsLoginWithoutLock(ctx, accountID, login) if err != nil { return nil, nil, nil, err From 6142828a9ce8ca72d25a1ded4d3ddc9605566f58 Mon Sep 17 00:00:00 2001 From: Pascal Fischer <32096965+pascal-fischer@users.noreply.github.com> Date: Tue, 10 Dec 2024 15:59:25 +0100 Subject: [PATCH 078/159] [management] restructure api files (#3013) --- management/cmd/management.go | 3 +- management/server/http/configs/auth.go | 9 ++ management/server/http/handler.go | 151 +++--------------- .../accounts}/accounts_handler.go | 34 ++-- .../accounts}/accounts_handler_test.go | 12 +- .../dns}/dns_settings_handler.go | 33 ++-- .../dns}/dns_settings_handler_test.go | 10 +- .../{ => handlers/dns}/nameservers_handler.go | 44 +++-- .../dns}/nameservers_handler_test.go | 14 +- .../{ => handlers/events}/events_handler.go | 25 +-- .../events}/events_handler_test.go | 10 +- .../{ => handlers/groups}/groups_handler.go | 47 +++--- .../groups}/groups_handler_test.go | 18 +-- .../{ => handlers/peers}/peers_handler.go | 39 +++-- .../peers}/peers_handler_test.go | 6 +- .../policies}/geolocation_handler_test.go | 16 +- .../policies}/geolocations_handler.go | 29 ++-- .../policies}/policies_handler.go | 49 +++--- .../policies}/policies_handler_test.go | 16 +- .../policies}/posture_checks_handler.go | 47 +++--- .../policies}/posture_checks_handler_test.go | 32 ++-- .../{ => handlers/routes}/routes_handler.go | 46 +++--- .../routes}/routes_handler_test.go | 16 +- .../setup_keys}/setupkeys_handler.go | 42 +++-- .../setup_keys}/setupkeys_handler_test.go | 17 +- .../http/{ => handlers/users}/pat_handler.go | 39 +++-- .../{ => handlers/users}/pat_handler_test.go | 14 +- .../{ => handlers/users}/users_handler.go | 47 +++--- .../users}/users_handler_test.go | 18 +-- management/server/http/util/util.go | 4 + 30 files changed, 461 insertions(+), 426 deletions(-) create mode 100644 management/server/http/configs/auth.go rename management/server/http/{ => handlers/accounts}/accounts_handler.go (79%) rename management/server/http/{ => handlers/accounts}/accounts_handler_test.go (97%) rename management/server/http/{ => handlers/dns}/dns_settings_handler.go (62%) rename management/server/http/{ => handlers/dns}/dns_settings_handler_test.go (94%) rename management/server/http/{ => handlers/dns}/nameservers_handler.go (77%) rename management/server/http/{ => handlers/dns}/nameservers_handler_test.go (95%) rename management/server/http/{ => handlers/events}/events_handler.go (79%) rename management/server/http/{ => handlers/events}/events_handler_test.go (97%) rename management/server/http/{ => handlers/groups}/groups_handler.go (81%) rename management/server/http/{ => handlers/groups}/groups_handler_test.go (95%) rename management/server/http/{ => handlers/peers}/peers_handler.go (88%) rename management/server/http/{ => handlers/peers}/peers_handler_test.go (99%) rename management/server/http/{ => handlers/policies}/geolocation_handler_test.go (94%) rename management/server/http/{ => handlers/policies}/geolocations_handler.go (72%) rename management/server/http/{ => handlers/policies}/policies_handler.go (84%) rename 
management/server/http/{ => handlers/policies}/policies_handler_test.go (95%) rename management/server/http/{ => handlers/policies}/posture_checks_handler.go (70%) rename management/server/http/{ => handlers/policies}/posture_checks_handler_test.go (96%) rename management/server/http/{ => handlers/routes}/routes_handler.go (85%) rename management/server/http/{ => handlers/routes}/routes_handler_test.go (98%) rename management/server/http/{ => handlers/setup_keys}/setupkeys_handler.go (78%) rename management/server/http/{ => handlers/setup_keys}/setupkeys_handler_test.go (95%) rename management/server/http/{ => handlers/users}/pat_handler.go (75%) rename management/server/http/{ => handlers/users}/pat_handler_test.go (96%) rename management/server/http/{ => handlers/users}/users_handler.go (80%) rename management/server/http/{ => handlers/users}/users_handler_test.go (97%) diff --git a/management/cmd/management.go b/management/cmd/management.go index 719d1a78c..bfa158c5b 100644 --- a/management/cmd/management.go +++ b/management/cmd/management.go @@ -42,6 +42,7 @@ import ( nbContext "github.com/netbirdio/netbird/management/server/context" "github.com/netbirdio/netbird/management/server/geolocation" httpapi "github.com/netbirdio/netbird/management/server/http" + "github.com/netbirdio/netbird/management/server/http/configs" "github.com/netbirdio/netbird/management/server/idp" "github.com/netbirdio/netbird/management/server/jwtclaims" "github.com/netbirdio/netbird/management/server/metrics" @@ -257,7 +258,7 @@ var ( return fmt.Errorf("failed creating JWT validator: %v", err) } - httpAPIAuthCfg := httpapi.AuthCfg{ + httpAPIAuthCfg := configs.AuthCfg{ Issuer: config.HttpConfig.AuthIssuer, Audience: config.HttpConfig.AuthAudience, UserIDClaim: config.HttpConfig.AuthUserIDClaim, diff --git a/management/server/http/configs/auth.go b/management/server/http/configs/auth.go new file mode 100644 index 000000000..aa91fa55b --- /dev/null +++ b/management/server/http/configs/auth.go @@ -0,0 +1,9 @@ +package configs + +// AuthCfg contains parameters for authentication middleware +type AuthCfg struct { + Issuer string + Audience string + UserIDClaim string + KeysLocation string +} diff --git a/management/server/http/handler.go b/management/server/http/handler.go index c3928bff6..373aa4dd7 100644 --- a/management/server/http/handler.go +++ b/management/server/http/handler.go @@ -12,6 +12,16 @@ import ( s "github.com/netbirdio/netbird/management/server" "github.com/netbirdio/netbird/management/server/geolocation" + "github.com/netbirdio/netbird/management/server/http/configs" + "github.com/netbirdio/netbird/management/server/http/handlers/accounts" + "github.com/netbirdio/netbird/management/server/http/handlers/dns" + "github.com/netbirdio/netbird/management/server/http/handlers/events" + "github.com/netbirdio/netbird/management/server/http/handlers/groups" + "github.com/netbirdio/netbird/management/server/http/handlers/peers" + "github.com/netbirdio/netbird/management/server/http/handlers/policies" + "github.com/netbirdio/netbird/management/server/http/handlers/routes" + "github.com/netbirdio/netbird/management/server/http/handlers/setup_keys" + "github.com/netbirdio/netbird/management/server/http/handlers/users" "github.com/netbirdio/netbird/management/server/http/middleware" "github.com/netbirdio/netbird/management/server/integrated_validator" "github.com/netbirdio/netbird/management/server/jwtclaims" @@ -20,27 +30,15 @@ import ( const apiPrefix = "/api" -// AuthCfg contains parameters for authentication 
middleware -type AuthCfg struct { - Issuer string - Audience string - UserIDClaim string - KeysLocation string -} - type apiHandler struct { Router *mux.Router AccountManager s.AccountManager geolocationManager *geolocation.Geolocation - AuthCfg AuthCfg -} - -// EmptyObject is an empty struct used to return empty JSON object -type emptyObject struct { + AuthCfg configs.AuthCfg } // APIHandler creates the Management service HTTP API handler registering all the available endpoints. -func APIHandler(ctx context.Context, accountManager s.AccountManager, LocationManager *geolocation.Geolocation, jwtValidator jwtclaims.JWTValidator, appMetrics telemetry.AppMetrics, authCfg AuthCfg, integratedValidator integrated_validator.IntegratedValidator) (http.Handler, error) { +func APIHandler(ctx context.Context, accountManager s.AccountManager, LocationManager *geolocation.Geolocation, jwtValidator jwtclaims.JWTValidator, appMetrics telemetry.AppMetrics, authCfg configs.AuthCfg, integratedValidator integrated_validator.IntegratedValidator) (http.Handler, error) { claimsExtractor := jwtclaims.NewClaimsExtractor( jwtclaims.WithAudience(authCfg.Audience), jwtclaims.WithUserIDClaim(authCfg.UserIDClaim), @@ -86,122 +84,15 @@ func APIHandler(ctx context.Context, accountManager s.AccountManager, LocationMa return nil, fmt.Errorf("register integrations endpoints: %w", err) } - api.addAccountsEndpoint() - api.addPeersEndpoint() - api.addUsersEndpoint() - api.addUsersTokensEndpoint() - api.addSetupKeysEndpoint() - api.addPoliciesEndpoint() - api.addGroupsEndpoint() - api.addRoutesEndpoint() - api.addDNSNameserversEndpoint() - api.addDNSSettingEndpoint() - api.addEventsEndpoint() - api.addPostureCheckEndpoint() - api.addLocationsEndpoint() + accounts.AddEndpoints(api.AccountManager, authCfg, router) + peers.AddEndpoints(api.AccountManager, authCfg, router) + users.AddEndpoints(api.AccountManager, authCfg, router) + setup_keys.AddEndpoints(api.AccountManager, authCfg, router) + policies.AddEndpoints(api.AccountManager, api.geolocationManager, authCfg, router) + groups.AddEndpoints(api.AccountManager, authCfg, router) + routes.AddEndpoints(api.AccountManager, authCfg, router) + dns.AddEndpoints(api.AccountManager, authCfg, router) + events.AddEndpoints(api.AccountManager, authCfg, router) return rootRouter, nil } - -func (apiHandler *apiHandler) addAccountsEndpoint() { - accountsHandler := NewAccountsHandler(apiHandler.AccountManager, apiHandler.AuthCfg) - apiHandler.Router.HandleFunc("/accounts/{accountId}", accountsHandler.UpdateAccount).Methods("PUT", "OPTIONS") - apiHandler.Router.HandleFunc("/accounts/{accountId}", accountsHandler.DeleteAccount).Methods("DELETE", "OPTIONS") - apiHandler.Router.HandleFunc("/accounts", accountsHandler.GetAllAccounts).Methods("GET", "OPTIONS") -} - -func (apiHandler *apiHandler) addPeersEndpoint() { - peersHandler := NewPeersHandler(apiHandler.AccountManager, apiHandler.AuthCfg) - apiHandler.Router.HandleFunc("/peers", peersHandler.GetAllPeers).Methods("GET", "OPTIONS") - apiHandler.Router.HandleFunc("/peers/{peerId}", peersHandler.HandlePeer). 
- Methods("GET", "PUT", "DELETE", "OPTIONS") - apiHandler.Router.HandleFunc("/peers/{peerId}/accessible-peers", peersHandler.GetAccessiblePeers).Methods("GET", "OPTIONS") -} - -func (apiHandler *apiHandler) addUsersEndpoint() { - userHandler := NewUsersHandler(apiHandler.AccountManager, apiHandler.AuthCfg) - apiHandler.Router.HandleFunc("/users", userHandler.GetAllUsers).Methods("GET", "OPTIONS") - apiHandler.Router.HandleFunc("/users/{userId}", userHandler.UpdateUser).Methods("PUT", "OPTIONS") - apiHandler.Router.HandleFunc("/users/{userId}", userHandler.DeleteUser).Methods("DELETE", "OPTIONS") - apiHandler.Router.HandleFunc("/users", userHandler.CreateUser).Methods("POST", "OPTIONS") - apiHandler.Router.HandleFunc("/users/{userId}/invite", userHandler.InviteUser).Methods("POST", "OPTIONS") -} - -func (apiHandler *apiHandler) addUsersTokensEndpoint() { - tokenHandler := NewPATsHandler(apiHandler.AccountManager, apiHandler.AuthCfg) - apiHandler.Router.HandleFunc("/users/{userId}/tokens", tokenHandler.GetAllTokens).Methods("GET", "OPTIONS") - apiHandler.Router.HandleFunc("/users/{userId}/tokens", tokenHandler.CreateToken).Methods("POST", "OPTIONS") - apiHandler.Router.HandleFunc("/users/{userId}/tokens/{tokenId}", tokenHandler.GetToken).Methods("GET", "OPTIONS") - apiHandler.Router.HandleFunc("/users/{userId}/tokens/{tokenId}", tokenHandler.DeleteToken).Methods("DELETE", "OPTIONS") -} - -func (apiHandler *apiHandler) addSetupKeysEndpoint() { - keysHandler := NewSetupKeysHandler(apiHandler.AccountManager, apiHandler.AuthCfg) - apiHandler.Router.HandleFunc("/setup-keys", keysHandler.GetAllSetupKeys).Methods("GET", "OPTIONS") - apiHandler.Router.HandleFunc("/setup-keys", keysHandler.CreateSetupKey).Methods("POST", "OPTIONS") - apiHandler.Router.HandleFunc("/setup-keys/{keyId}", keysHandler.GetSetupKey).Methods("GET", "OPTIONS") - apiHandler.Router.HandleFunc("/setup-keys/{keyId}", keysHandler.UpdateSetupKey).Methods("PUT", "OPTIONS") - apiHandler.Router.HandleFunc("/setup-keys/{keyId}", keysHandler.DeleteSetupKey).Methods("DELETE", "OPTIONS") -} - -func (apiHandler *apiHandler) addPoliciesEndpoint() { - policiesHandler := NewPoliciesHandler(apiHandler.AccountManager, apiHandler.AuthCfg) - apiHandler.Router.HandleFunc("/policies", policiesHandler.GetAllPolicies).Methods("GET", "OPTIONS") - apiHandler.Router.HandleFunc("/policies", policiesHandler.CreatePolicy).Methods("POST", "OPTIONS") - apiHandler.Router.HandleFunc("/policies/{policyId}", policiesHandler.UpdatePolicy).Methods("PUT", "OPTIONS") - apiHandler.Router.HandleFunc("/policies/{policyId}", policiesHandler.GetPolicy).Methods("GET", "OPTIONS") - apiHandler.Router.HandleFunc("/policies/{policyId}", policiesHandler.DeletePolicy).Methods("DELETE", "OPTIONS") -} - -func (apiHandler *apiHandler) addGroupsEndpoint() { - groupsHandler := NewGroupsHandler(apiHandler.AccountManager, apiHandler.AuthCfg) - apiHandler.Router.HandleFunc("/groups", groupsHandler.GetAllGroups).Methods("GET", "OPTIONS") - apiHandler.Router.HandleFunc("/groups", groupsHandler.CreateGroup).Methods("POST", "OPTIONS") - apiHandler.Router.HandleFunc("/groups/{groupId}", groupsHandler.UpdateGroup).Methods("PUT", "OPTIONS") - apiHandler.Router.HandleFunc("/groups/{groupId}", groupsHandler.GetGroup).Methods("GET", "OPTIONS") - apiHandler.Router.HandleFunc("/groups/{groupId}", groupsHandler.DeleteGroup).Methods("DELETE", "OPTIONS") -} - -func (apiHandler *apiHandler) addRoutesEndpoint() { - routesHandler := NewRoutesHandler(apiHandler.AccountManager, apiHandler.AuthCfg) - 
apiHandler.Router.HandleFunc("/routes", routesHandler.GetAllRoutes).Methods("GET", "OPTIONS") - apiHandler.Router.HandleFunc("/routes", routesHandler.CreateRoute).Methods("POST", "OPTIONS") - apiHandler.Router.HandleFunc("/routes/{routeId}", routesHandler.UpdateRoute).Methods("PUT", "OPTIONS") - apiHandler.Router.HandleFunc("/routes/{routeId}", routesHandler.GetRoute).Methods("GET", "OPTIONS") - apiHandler.Router.HandleFunc("/routes/{routeId}", routesHandler.DeleteRoute).Methods("DELETE", "OPTIONS") -} - -func (apiHandler *apiHandler) addDNSNameserversEndpoint() { - nameserversHandler := NewNameserversHandler(apiHandler.AccountManager, apiHandler.AuthCfg) - apiHandler.Router.HandleFunc("/dns/nameservers", nameserversHandler.GetAllNameservers).Methods("GET", "OPTIONS") - apiHandler.Router.HandleFunc("/dns/nameservers", nameserversHandler.CreateNameserverGroup).Methods("POST", "OPTIONS") - apiHandler.Router.HandleFunc("/dns/nameservers/{nsgroupId}", nameserversHandler.UpdateNameserverGroup).Methods("PUT", "OPTIONS") - apiHandler.Router.HandleFunc("/dns/nameservers/{nsgroupId}", nameserversHandler.GetNameserverGroup).Methods("GET", "OPTIONS") - apiHandler.Router.HandleFunc("/dns/nameservers/{nsgroupId}", nameserversHandler.DeleteNameserverGroup).Methods("DELETE", "OPTIONS") -} - -func (apiHandler *apiHandler) addDNSSettingEndpoint() { - dnsSettingsHandler := NewDNSSettingsHandler(apiHandler.AccountManager, apiHandler.AuthCfg) - apiHandler.Router.HandleFunc("/dns/settings", dnsSettingsHandler.GetDNSSettings).Methods("GET", "OPTIONS") - apiHandler.Router.HandleFunc("/dns/settings", dnsSettingsHandler.UpdateDNSSettings).Methods("PUT", "OPTIONS") -} - -func (apiHandler *apiHandler) addEventsEndpoint() { - eventsHandler := NewEventsHandler(apiHandler.AccountManager, apiHandler.AuthCfg) - apiHandler.Router.HandleFunc("/events", eventsHandler.GetAllEvents).Methods("GET", "OPTIONS") -} - -func (apiHandler *apiHandler) addPostureCheckEndpoint() { - postureCheckHandler := NewPostureChecksHandler(apiHandler.AccountManager, apiHandler.geolocationManager, apiHandler.AuthCfg) - apiHandler.Router.HandleFunc("/posture-checks", postureCheckHandler.GetAllPostureChecks).Methods("GET", "OPTIONS") - apiHandler.Router.HandleFunc("/posture-checks", postureCheckHandler.CreatePostureCheck).Methods("POST", "OPTIONS") - apiHandler.Router.HandleFunc("/posture-checks/{postureCheckId}", postureCheckHandler.UpdatePostureCheck).Methods("PUT", "OPTIONS") - apiHandler.Router.HandleFunc("/posture-checks/{postureCheckId}", postureCheckHandler.GetPostureCheck).Methods("GET", "OPTIONS") - apiHandler.Router.HandleFunc("/posture-checks/{postureCheckId}", postureCheckHandler.DeletePostureCheck).Methods("DELETE", "OPTIONS") -} - -func (apiHandler *apiHandler) addLocationsEndpoint() { - locationHandler := NewGeolocationsHandlerHandler(apiHandler.AccountManager, apiHandler.geolocationManager, apiHandler.AuthCfg) - apiHandler.Router.HandleFunc("/locations/countries", locationHandler.GetAllCountries).Methods("GET", "OPTIONS") - apiHandler.Router.HandleFunc("/locations/countries/{country}/cities", locationHandler.GetCitiesByCountry).Methods("GET", "OPTIONS") -} diff --git a/management/server/http/accounts_handler.go b/management/server/http/handlers/accounts/accounts_handler.go similarity index 79% rename from management/server/http/accounts_handler.go rename to management/server/http/handlers/accounts/accounts_handler.go index 4baf9c692..c95207777 100644 --- a/management/server/http/accounts_handler.go +++ 
b/management/server/http/handlers/accounts/accounts_handler.go @@ -1,4 +1,4 @@ -package http +package accounts import ( "encoding/json" @@ -10,20 +10,28 @@ import ( "github.com/netbirdio/netbird/management/server" "github.com/netbirdio/netbird/management/server/account" "github.com/netbirdio/netbird/management/server/http/api" + "github.com/netbirdio/netbird/management/server/http/configs" "github.com/netbirdio/netbird/management/server/http/util" "github.com/netbirdio/netbird/management/server/jwtclaims" "github.com/netbirdio/netbird/management/server/status" ) -// AccountsHandler is a handler that handles the server.Account HTTP endpoints -type AccountsHandler struct { +// handler is a handler that handles the server.Account HTTP endpoints +type handler struct { accountManager server.AccountManager claimsExtractor *jwtclaims.ClaimsExtractor } -// NewAccountsHandler creates a new AccountsHandler HTTP handler -func NewAccountsHandler(accountManager server.AccountManager, authCfg AuthCfg) *AccountsHandler { - return &AccountsHandler{ +func AddEndpoints(accountManager server.AccountManager, authCfg configs.AuthCfg, router *mux.Router) { + accountsHandler := newHandler(accountManager, authCfg) + router.HandleFunc("/accounts/{accountId}", accountsHandler.updateAccount).Methods("PUT", "OPTIONS") + router.HandleFunc("/accounts/{accountId}", accountsHandler.deleteAccount).Methods("DELETE", "OPTIONS") + router.HandleFunc("/accounts", accountsHandler.getAllAccounts).Methods("GET", "OPTIONS") +} + +// newHandler creates a new handler HTTP handler +func newHandler(accountManager server.AccountManager, authCfg configs.AuthCfg) *handler { + return &handler{ accountManager: accountManager, claimsExtractor: jwtclaims.NewClaimsExtractor( jwtclaims.WithAudience(authCfg.Audience), @@ -32,8 +40,8 @@ func NewAccountsHandler(accountManager server.AccountManager, authCfg AuthCfg) * } } -// GetAllAccounts is HTTP GET handler that returns a list of accounts. Effectively returns just a single account. -func (h *AccountsHandler) GetAllAccounts(w http.ResponseWriter, r *http.Request) { +// getAllAccounts is HTTP GET handler that returns a list of accounts. Effectively returns just a single account. +func (h *handler) getAllAccounts(w http.ResponseWriter, r *http.Request) { claims := h.claimsExtractor.FromRequestContext(r) accountID, userID, err := h.accountManager.GetAccountIDFromToken(r.Context(), claims) if err != nil { @@ -51,8 +59,8 @@ func (h *AccountsHandler) GetAllAccounts(w http.ResponseWriter, r *http.Request) util.WriteJSONObject(r.Context(), w, []*api.Account{resp}) } -// UpdateAccount is HTTP PUT handler that updates the provided account. Updates only account settings (server.Settings) -func (h *AccountsHandler) UpdateAccount(w http.ResponseWriter, r *http.Request) { +// updateAccount is HTTP PUT handler that updates the provided account. 
Updates only account settings (server.Settings) +func (h *handler) updateAccount(w http.ResponseWriter, r *http.Request) { claims := h.claimsExtractor.FromRequestContext(r) _, userID, err := h.accountManager.GetAccountIDFromToken(r.Context(), claims) if err != nil { @@ -111,8 +119,8 @@ func (h *AccountsHandler) UpdateAccount(w http.ResponseWriter, r *http.Request) util.WriteJSONObject(r.Context(), w, &resp) } -// DeleteAccount is a HTTP DELETE handler to delete an account -func (h *AccountsHandler) DeleteAccount(w http.ResponseWriter, r *http.Request) { +// deleteAccount is a HTTP DELETE handler to delete an account +func (h *handler) deleteAccount(w http.ResponseWriter, r *http.Request) { claims := h.claimsExtractor.FromRequestContext(r) vars := mux.Vars(r) targetAccountID := vars["accountId"] @@ -127,7 +135,7 @@ func (h *AccountsHandler) DeleteAccount(w http.ResponseWriter, r *http.Request) return } - util.WriteJSONObject(r.Context(), w, emptyObject{}) + util.WriteJSONObject(r.Context(), w, util.EmptyObject{}) } func toAccountResponse(accountID string, settings *server.Settings) *api.Account { diff --git a/management/server/http/accounts_handler_test.go b/management/server/http/handlers/accounts/accounts_handler_test.go similarity index 97% rename from management/server/http/accounts_handler_test.go rename to management/server/http/handlers/accounts/accounts_handler_test.go index cacb3d430..9d7e8a84d 100644 --- a/management/server/http/accounts_handler_test.go +++ b/management/server/http/handlers/accounts/accounts_handler_test.go @@ -1,4 +1,4 @@ -package http +package accounts import ( "bytes" @@ -20,8 +20,8 @@ import ( "github.com/netbirdio/netbird/management/server/status" ) -func initAccountsTestData(account *server.Account, admin *server.User) *AccountsHandler { - return &AccountsHandler{ +func initAccountsTestData(account *server.Account, admin *server.User) *handler { + return &handler{ accountManager: &mock_server.MockAccountManager{ GetAccountIDFromTokenFunc: func(ctx context.Context, claims jwtclaims.AuthorizationClaims) (string, string, error) { return account.Id, admin.Id, nil @@ -89,7 +89,7 @@ func TestAccounts_AccountsHandler(t *testing.T) { requestBody io.Reader }{ { - name: "GetAllAccounts OK", + name: "getAllAccounts OK", expectedBody: true, requestType: http.MethodGet, requestPath: "/api/accounts", @@ -189,8 +189,8 @@ func TestAccounts_AccountsHandler(t *testing.T) { req := httptest.NewRequest(tc.requestType, tc.requestPath, tc.requestBody) router := mux.NewRouter() - router.HandleFunc("/api/accounts", handler.GetAllAccounts).Methods("GET") - router.HandleFunc("/api/accounts/{accountId}", handler.UpdateAccount).Methods("PUT") + router.HandleFunc("/api/accounts", handler.getAllAccounts).Methods("GET") + router.HandleFunc("/api/accounts/{accountId}", handler.updateAccount).Methods("PUT") router.ServeHTTP(recorder, req) res := recorder.Result() diff --git a/management/server/http/dns_settings_handler.go b/management/server/http/handlers/dns/dns_settings_handler.go similarity index 62% rename from management/server/http/dns_settings_handler.go rename to management/server/http/handlers/dns/dns_settings_handler.go index 13c2101a7..7dd8c1fc1 100644 --- a/management/server/http/dns_settings_handler.go +++ b/management/server/http/handlers/dns/dns_settings_handler.go @@ -1,26 +1,39 @@ -package http +package dns import ( "encoding/json" "net/http" + "github.com/gorilla/mux" log "github.com/sirupsen/logrus" "github.com/netbirdio/netbird/management/server" 
"github.com/netbirdio/netbird/management/server/http/api" + "github.com/netbirdio/netbird/management/server/http/configs" "github.com/netbirdio/netbird/management/server/http/util" "github.com/netbirdio/netbird/management/server/jwtclaims" ) -// DNSSettingsHandler is a handler that returns the DNS settings of the account -type DNSSettingsHandler struct { +// dnsSettingsHandler is a handler that returns the DNS settings of the account +type dnsSettingsHandler struct { accountManager server.AccountManager claimsExtractor *jwtclaims.ClaimsExtractor } -// NewDNSSettingsHandler returns a new instance of DNSSettingsHandler handler -func NewDNSSettingsHandler(accountManager server.AccountManager, authCfg AuthCfg) *DNSSettingsHandler { - return &DNSSettingsHandler{ +func AddEndpoints(accountManager server.AccountManager, authCfg configs.AuthCfg, router *mux.Router) { + addDNSSettingEndpoint(accountManager, authCfg, router) + addDNSNameserversEndpoint(accountManager, authCfg, router) +} + +func addDNSSettingEndpoint(accountManager server.AccountManager, authCfg configs.AuthCfg, router *mux.Router) { + dnsSettingsHandler := newDNSSettingsHandler(accountManager, authCfg) + router.HandleFunc("/dns/settings", dnsSettingsHandler.getDNSSettings).Methods("GET", "OPTIONS") + router.HandleFunc("/dns/settings", dnsSettingsHandler.updateDNSSettings).Methods("PUT", "OPTIONS") +} + +// newDNSSettingsHandler returns a new instance of dnsSettingsHandler handler +func newDNSSettingsHandler(accountManager server.AccountManager, authCfg configs.AuthCfg) *dnsSettingsHandler { + return &dnsSettingsHandler{ accountManager: accountManager, claimsExtractor: jwtclaims.NewClaimsExtractor( jwtclaims.WithAudience(authCfg.Audience), @@ -29,8 +42,8 @@ func NewDNSSettingsHandler(accountManager server.AccountManager, authCfg AuthCfg } } -// GetDNSSettings returns the DNS settings for the account -func (h *DNSSettingsHandler) GetDNSSettings(w http.ResponseWriter, r *http.Request) { +// getDNSSettings returns the DNS settings for the account +func (h *dnsSettingsHandler) getDNSSettings(w http.ResponseWriter, r *http.Request) { claims := h.claimsExtractor.FromRequestContext(r) accountID, userID, err := h.accountManager.GetAccountIDFromToken(r.Context(), claims) if err != nil { @@ -52,8 +65,8 @@ func (h *DNSSettingsHandler) GetDNSSettings(w http.ResponseWriter, r *http.Reque util.WriteJSONObject(r.Context(), w, apiDNSSettings) } -// UpdateDNSSettings handles update to DNS settings of an account -func (h *DNSSettingsHandler) UpdateDNSSettings(w http.ResponseWriter, r *http.Request) { +// updateDNSSettings handles update to DNS settings of an account +func (h *dnsSettingsHandler) updateDNSSettings(w http.ResponseWriter, r *http.Request) { claims := h.claimsExtractor.FromRequestContext(r) accountID, userID, err := h.accountManager.GetAccountIDFromToken(r.Context(), claims) if err != nil { diff --git a/management/server/http/dns_settings_handler_test.go b/management/server/http/handlers/dns/dns_settings_handler_test.go similarity index 94% rename from management/server/http/dns_settings_handler_test.go rename to management/server/http/handlers/dns/dns_settings_handler_test.go index 8baea7b15..a64e3fd83 100644 --- a/management/server/http/dns_settings_handler_test.go +++ b/management/server/http/handlers/dns/dns_settings_handler_test.go @@ -1,4 +1,4 @@ -package http +package dns import ( "bytes" @@ -40,8 +40,8 @@ var testingDNSSettingsAccount = &server.Account{ DNSSettings: baseExistingDNSSettings, } -func initDNSSettingsTestData() 
*DNSSettingsHandler { - return &DNSSettingsHandler{ +func initDNSSettingsTestData() *dnsSettingsHandler { + return &dnsSettingsHandler{ accountManager: &mock_server.MockAccountManager{ GetDNSSettingsFunc: func(ctx context.Context, accountID string, userID string) (*server.DNSSettings, error) { return &testingDNSSettingsAccount.DNSSettings, nil @@ -120,8 +120,8 @@ func TestDNSSettingsHandlers(t *testing.T) { req := httptest.NewRequest(tc.requestType, tc.requestPath, tc.requestBody) router := mux.NewRouter() - router.HandleFunc("/api/dns/settings", p.GetDNSSettings).Methods("GET") - router.HandleFunc("/api/dns/settings", p.UpdateDNSSettings).Methods("PUT") + router.HandleFunc("/api/dns/settings", p.getDNSSettings).Methods("GET") + router.HandleFunc("/api/dns/settings", p.updateDNSSettings).Methods("PUT") router.ServeHTTP(recorder, req) res := recorder.Result() diff --git a/management/server/http/nameservers_handler.go b/management/server/http/handlers/dns/nameservers_handler.go similarity index 77% rename from management/server/http/nameservers_handler.go rename to management/server/http/handlers/dns/nameservers_handler.go index e7a2bc2ae..09047e231 100644 --- a/management/server/http/nameservers_handler.go +++ b/management/server/http/handlers/dns/nameservers_handler.go @@ -1,4 +1,4 @@ -package http +package dns import ( "encoding/json" @@ -11,20 +11,30 @@ import ( nbdns "github.com/netbirdio/netbird/dns" "github.com/netbirdio/netbird/management/server" "github.com/netbirdio/netbird/management/server/http/api" + "github.com/netbirdio/netbird/management/server/http/configs" "github.com/netbirdio/netbird/management/server/http/util" "github.com/netbirdio/netbird/management/server/jwtclaims" "github.com/netbirdio/netbird/management/server/status" ) -// NameserversHandler is the nameserver group handler of the account -type NameserversHandler struct { +// nameserversHandler is the nameserver group handler of the account +type nameserversHandler struct { accountManager server.AccountManager claimsExtractor *jwtclaims.ClaimsExtractor } -// NewNameserversHandler returns a new instance of NameserversHandler handler -func NewNameserversHandler(accountManager server.AccountManager, authCfg AuthCfg) *NameserversHandler { - return &NameserversHandler{ +func addDNSNameserversEndpoint(accountManager server.AccountManager, authCfg configs.AuthCfg, router *mux.Router) { + nameserversHandler := newNameserversHandler(accountManager, authCfg) + router.HandleFunc("/dns/nameservers", nameserversHandler.getAllNameservers).Methods("GET", "OPTIONS") + router.HandleFunc("/dns/nameservers", nameserversHandler.createNameserverGroup).Methods("POST", "OPTIONS") + router.HandleFunc("/dns/nameservers/{nsgroupId}", nameserversHandler.updateNameserverGroup).Methods("PUT", "OPTIONS") + router.HandleFunc("/dns/nameservers/{nsgroupId}", nameserversHandler.getNameserverGroup).Methods("GET", "OPTIONS") + router.HandleFunc("/dns/nameservers/{nsgroupId}", nameserversHandler.deleteNameserverGroup).Methods("DELETE", "OPTIONS") +} + +// newNameserversHandler returns a new instance of nameserversHandler handler +func newNameserversHandler(accountManager server.AccountManager, authCfg configs.AuthCfg) *nameserversHandler { + return &nameserversHandler{ accountManager: accountManager, claimsExtractor: jwtclaims.NewClaimsExtractor( jwtclaims.WithAudience(authCfg.Audience), @@ -33,8 +43,8 @@ func NewNameserversHandler(accountManager server.AccountManager, authCfg AuthCfg } } -// GetAllNameservers returns the list of nameserver groups for 
the account -func (h *NameserversHandler) GetAllNameservers(w http.ResponseWriter, r *http.Request) { +// getAllNameservers returns the list of nameserver groups for the account +func (h *nameserversHandler) getAllNameservers(w http.ResponseWriter, r *http.Request) { claims := h.claimsExtractor.FromRequestContext(r) accountID, userID, err := h.accountManager.GetAccountIDFromToken(r.Context(), claims) if err != nil { @@ -57,8 +67,8 @@ func (h *NameserversHandler) GetAllNameservers(w http.ResponseWriter, r *http.Re util.WriteJSONObject(r.Context(), w, apiNameservers) } -// CreateNameserverGroup handles nameserver group creation request -func (h *NameserversHandler) CreateNameserverGroup(w http.ResponseWriter, r *http.Request) { +// createNameserverGroup handles nameserver group creation request +func (h *nameserversHandler) createNameserverGroup(w http.ResponseWriter, r *http.Request) { claims := h.claimsExtractor.FromRequestContext(r) accountID, userID, err := h.accountManager.GetAccountIDFromToken(r.Context(), claims) if err != nil { @@ -90,8 +100,8 @@ func (h *NameserversHandler) CreateNameserverGroup(w http.ResponseWriter, r *htt util.WriteJSONObject(r.Context(), w, &resp) } -// UpdateNameserverGroup handles update to a nameserver group identified by a given ID -func (h *NameserversHandler) UpdateNameserverGroup(w http.ResponseWriter, r *http.Request) { +// updateNameserverGroup handles update to a nameserver group identified by a given ID +func (h *nameserversHandler) updateNameserverGroup(w http.ResponseWriter, r *http.Request) { claims := h.claimsExtractor.FromRequestContext(r) accountID, userID, err := h.accountManager.GetAccountIDFromToken(r.Context(), claims) if err != nil { @@ -141,8 +151,8 @@ func (h *NameserversHandler) UpdateNameserverGroup(w http.ResponseWriter, r *htt util.WriteJSONObject(r.Context(), w, &resp) } -// DeleteNameserverGroup handles nameserver group deletion request -func (h *NameserversHandler) DeleteNameserverGroup(w http.ResponseWriter, r *http.Request) { +// deleteNameserverGroup handles nameserver group deletion request +func (h *nameserversHandler) deleteNameserverGroup(w http.ResponseWriter, r *http.Request) { claims := h.claimsExtractor.FromRequestContext(r) accountID, userID, err := h.accountManager.GetAccountIDFromToken(r.Context(), claims) if err != nil { @@ -162,11 +172,11 @@ func (h *NameserversHandler) DeleteNameserverGroup(w http.ResponseWriter, r *htt return } - util.WriteJSONObject(r.Context(), w, emptyObject{}) + util.WriteJSONObject(r.Context(), w, util.EmptyObject{}) } -// GetNameserverGroup handles a nameserver group Get request identified by ID -func (h *NameserversHandler) GetNameserverGroup(w http.ResponseWriter, r *http.Request) { +// getNameserverGroup handles a nameserver group Get request identified by ID +func (h *nameserversHandler) getNameserverGroup(w http.ResponseWriter, r *http.Request) { claims := h.claimsExtractor.FromRequestContext(r) accountID, userID, err := h.accountManager.GetAccountIDFromToken(r.Context(), claims) if err != nil { diff --git a/management/server/http/nameservers_handler_test.go b/management/server/http/handlers/dns/nameservers_handler_test.go similarity index 95% rename from management/server/http/nameservers_handler_test.go rename to management/server/http/handlers/dns/nameservers_handler_test.go index 98c2e402d..c6561e4d8 100644 --- a/management/server/http/nameservers_handler_test.go +++ b/management/server/http/handlers/dns/nameservers_handler_test.go @@ -1,4 +1,4 @@ -package http +package dns import ( 
"bytes" @@ -50,8 +50,8 @@ var baseExistingNSGroup = &nbdns.NameServerGroup{ Enabled: true, } -func initNameserversTestData() *NameserversHandler { - return &NameserversHandler{ +func initNameserversTestData() *nameserversHandler { + return &nameserversHandler{ accountManager: &mock_server.MockAccountManager{ GetNameServerGroupFunc: func(_ context.Context, accountID, userID, nsGroupID string) (*nbdns.NameServerGroup, error) { if nsGroupID == existingNSGroupID { @@ -206,10 +206,10 @@ func TestNameserversHandlers(t *testing.T) { req := httptest.NewRequest(tc.requestType, tc.requestPath, tc.requestBody) router := mux.NewRouter() - router.HandleFunc("/api/dns/nameservers/{nsgroupId}", p.GetNameserverGroup).Methods("GET") - router.HandleFunc("/api/dns/nameservers", p.CreateNameserverGroup).Methods("POST") - router.HandleFunc("/api/dns/nameservers/{nsgroupId}", p.DeleteNameserverGroup).Methods("DELETE") - router.HandleFunc("/api/dns/nameservers/{nsgroupId}", p.UpdateNameserverGroup).Methods("PUT") + router.HandleFunc("/api/dns/nameservers/{nsgroupId}", p.getNameserverGroup).Methods("GET") + router.HandleFunc("/api/dns/nameservers", p.createNameserverGroup).Methods("POST") + router.HandleFunc("/api/dns/nameservers/{nsgroupId}", p.deleteNameserverGroup).Methods("DELETE") + router.HandleFunc("/api/dns/nameservers/{nsgroupId}", p.updateNameserverGroup).Methods("PUT") router.ServeHTTP(recorder, req) res := recorder.Result() diff --git a/management/server/http/events_handler.go b/management/server/http/handlers/events/events_handler.go similarity index 79% rename from management/server/http/events_handler.go rename to management/server/http/handlers/events/events_handler.go index ee0c63f28..62da59535 100644 --- a/management/server/http/events_handler.go +++ b/management/server/http/handlers/events/events_handler.go @@ -1,28 +1,35 @@ -package http +package events import ( "context" "fmt" "net/http" + "github.com/gorilla/mux" log "github.com/sirupsen/logrus" "github.com/netbirdio/netbird/management/server" "github.com/netbirdio/netbird/management/server/activity" "github.com/netbirdio/netbird/management/server/http/api" + "github.com/netbirdio/netbird/management/server/http/configs" "github.com/netbirdio/netbird/management/server/http/util" "github.com/netbirdio/netbird/management/server/jwtclaims" ) -// EventsHandler HTTP handler -type EventsHandler struct { +// handler HTTP handler +type handler struct { accountManager server.AccountManager claimsExtractor *jwtclaims.ClaimsExtractor } -// NewEventsHandler creates a new EventsHandler HTTP handler -func NewEventsHandler(accountManager server.AccountManager, authCfg AuthCfg) *EventsHandler { - return &EventsHandler{ +func AddEndpoints(accountManager server.AccountManager, authCfg configs.AuthCfg, router *mux.Router) { + eventsHandler := newHandler(accountManager, authCfg) + router.HandleFunc("/events", eventsHandler.getAllEvents).Methods("GET", "OPTIONS") +} + +// newHandler creates a new events handler +func newHandler(accountManager server.AccountManager, authCfg configs.AuthCfg) *handler { + return &handler{ accountManager: accountManager, claimsExtractor: jwtclaims.NewClaimsExtractor( jwtclaims.WithAudience(authCfg.Audience), @@ -31,8 +38,8 @@ func NewEventsHandler(accountManager server.AccountManager, authCfg AuthCfg) *Ev } } -// GetAllEvents list of the given account -func (h *EventsHandler) GetAllEvents(w http.ResponseWriter, r *http.Request) { +// getAllEvents list of the given account +func (h *handler) getAllEvents(w http.ResponseWriter, r 
*http.Request) { claims := h.claimsExtractor.FromRequestContext(r) accountID, userID, err := h.accountManager.GetAccountIDFromToken(r.Context(), claims) if err != nil { @@ -60,7 +67,7 @@ func (h *EventsHandler) GetAllEvents(w http.ResponseWriter, r *http.Request) { util.WriteJSONObject(r.Context(), w, events) } -func (h *EventsHandler) fillEventsWithUserInfo(ctx context.Context, events []*api.Event, accountId, userId string) error { +func (h *handler) fillEventsWithUserInfo(ctx context.Context, events []*api.Event, accountId, userId string) error { // build email, name maps based on users userInfos, err := h.accountManager.GetUsersFromAccount(ctx, accountId, userId) if err != nil { diff --git a/management/server/http/events_handler_test.go b/management/server/http/handlers/events/events_handler_test.go similarity index 97% rename from management/server/http/events_handler_test.go rename to management/server/http/handlers/events/events_handler_test.go index e525cf2ee..6af2e5346 100644 --- a/management/server/http/events_handler_test.go +++ b/management/server/http/handlers/events/events_handler_test.go @@ -1,4 +1,4 @@ -package http +package events import ( "context" @@ -20,8 +20,8 @@ import ( "github.com/netbirdio/netbird/management/server/mock_server" ) -func initEventsTestData(account string, events ...*activity.Event) *EventsHandler { - return &EventsHandler{ +func initEventsTestData(account string, events ...*activity.Event) *handler { + return &handler{ accountManager: &mock_server.MockAccountManager{ GetEventsFunc: func(_ context.Context, accountID, userID string) ([]*activity.Event, error) { if accountID == account { @@ -183,7 +183,7 @@ func TestEvents_GetEvents(t *testing.T) { requestBody io.Reader }{ { - name: "GetAllEvents OK", + name: "getAllEvents OK", expectedBody: true, requestType: http.MethodGet, requestPath: "/api/events/", @@ -201,7 +201,7 @@ func TestEvents_GetEvents(t *testing.T) { req := httptest.NewRequest(tc.requestType, tc.requestPath, tc.requestBody) router := mux.NewRouter() - router.HandleFunc("/api/events/", handler.GetAllEvents).Methods("GET") + router.HandleFunc("/api/events/", handler.getAllEvents).Methods("GET") router.ServeHTTP(recorder, req) res := recorder.Result() diff --git a/management/server/http/groups_handler.go b/management/server/http/handlers/groups/groups_handler.go similarity index 81% rename from management/server/http/groups_handler.go rename to management/server/http/handlers/groups/groups_handler.go index f369d1a00..e60529cec 100644 --- a/management/server/http/groups_handler.go +++ b/management/server/http/handlers/groups/groups_handler.go @@ -1,13 +1,15 @@ -package http +package groups import ( "encoding/json" "net/http" "github.com/gorilla/mux" - nbpeer "github.com/netbirdio/netbird/management/server/peer" log "github.com/sirupsen/logrus" + "github.com/netbirdio/netbird/management/server/http/configs" + nbpeer "github.com/netbirdio/netbird/management/server/peer" + "github.com/netbirdio/netbird/management/server" nbgroup "github.com/netbirdio/netbird/management/server/group" "github.com/netbirdio/netbird/management/server/http/api" @@ -16,15 +18,24 @@ import ( "github.com/netbirdio/netbird/management/server/status" ) -// GroupsHandler is a handler that returns groups of the account -type GroupsHandler struct { +// handler is a handler that returns groups of the account +type handler struct { accountManager server.AccountManager claimsExtractor *jwtclaims.ClaimsExtractor } -// NewGroupsHandler creates a new GroupsHandler HTTP handler -func 
NewGroupsHandler(accountManager server.AccountManager, authCfg AuthCfg) *GroupsHandler { - return &GroupsHandler{ +func AddEndpoints(accountManager server.AccountManager, authCfg configs.AuthCfg, router *mux.Router) { + groupsHandler := newHandler(accountManager, authCfg) + router.HandleFunc("/groups", groupsHandler.getAllGroups).Methods("GET", "OPTIONS") + router.HandleFunc("/groups", groupsHandler.createGroup).Methods("POST", "OPTIONS") + router.HandleFunc("/groups/{groupId}", groupsHandler.updateGroup).Methods("PUT", "OPTIONS") + router.HandleFunc("/groups/{groupId}", groupsHandler.getGroup).Methods("GET", "OPTIONS") + router.HandleFunc("/groups/{groupId}", groupsHandler.deleteGroup).Methods("DELETE", "OPTIONS") +} + +// newHandler creates a new groups handler +func newHandler(accountManager server.AccountManager, authCfg configs.AuthCfg) *handler { + return &handler{ accountManager: accountManager, claimsExtractor: jwtclaims.NewClaimsExtractor( jwtclaims.WithAudience(authCfg.Audience), @@ -33,8 +44,8 @@ func NewGroupsHandler(accountManager server.AccountManager, authCfg AuthCfg) *Gr } } -// GetAllGroups list for the account -func (h *GroupsHandler) GetAllGroups(w http.ResponseWriter, r *http.Request) { +// getAllGroups list for the account +func (h *handler) getAllGroups(w http.ResponseWriter, r *http.Request) { claims := h.claimsExtractor.FromRequestContext(r) accountID, userID, err := h.accountManager.GetAccountIDFromToken(r.Context(), claims) if err != nil { @@ -63,8 +74,8 @@ func (h *GroupsHandler) GetAllGroups(w http.ResponseWriter, r *http.Request) { util.WriteJSONObject(r.Context(), w, groupsResponse) } -// UpdateGroup handles update to a group identified by a given ID -func (h *GroupsHandler) UpdateGroup(w http.ResponseWriter, r *http.Request) { +// updateGroup handles update to a group identified by a given ID +func (h *handler) updateGroup(w http.ResponseWriter, r *http.Request) { claims := h.claimsExtractor.FromRequestContext(r) accountID, userID, err := h.accountManager.GetAccountIDFromToken(r.Context(), claims) if err != nil { @@ -141,8 +152,8 @@ func (h *GroupsHandler) UpdateGroup(w http.ResponseWriter, r *http.Request) { util.WriteJSONObject(r.Context(), w, toGroupResponse(accountPeers, &group)) } -// CreateGroup handles group creation request -func (h *GroupsHandler) CreateGroup(w http.ResponseWriter, r *http.Request) { +// createGroup handles group creation request +func (h *handler) createGroup(w http.ResponseWriter, r *http.Request) { claims := h.claimsExtractor.FromRequestContext(r) accountID, userID, err := h.accountManager.GetAccountIDFromToken(r.Context(), claims) if err != nil { @@ -189,8 +200,8 @@ func (h *GroupsHandler) CreateGroup(w http.ResponseWriter, r *http.Request) { util.WriteJSONObject(r.Context(), w, toGroupResponse(accountPeers, &group)) } -// DeleteGroup handles group deletion request -func (h *GroupsHandler) DeleteGroup(w http.ResponseWriter, r *http.Request) { +// deleteGroup handles group deletion request +func (h *handler) deleteGroup(w http.ResponseWriter, r *http.Request) { claims := h.claimsExtractor.FromRequestContext(r) accountID, userID, err := h.accountManager.GetAccountIDFromToken(r.Context(), claims) if err != nil { @@ -215,11 +226,11 @@ func (h *GroupsHandler) DeleteGroup(w http.ResponseWriter, r *http.Request) { return } - util.WriteJSONObject(r.Context(), w, emptyObject{}) + util.WriteJSONObject(r.Context(), w, util.EmptyObject{}) } -// GetGroup returns a group -func (h *GroupsHandler) GetGroup(w http.ResponseWriter, r *http.Request) 
{ +// getGroup returns a group +func (h *handler) getGroup(w http.ResponseWriter, r *http.Request) { claims := h.claimsExtractor.FromRequestContext(r) accountID, userID, err := h.accountManager.GetAccountIDFromToken(r.Context(), claims) if err != nil { diff --git a/management/server/http/groups_handler_test.go b/management/server/http/handlers/groups/groups_handler_test.go similarity index 95% rename from management/server/http/groups_handler_test.go rename to management/server/http/handlers/groups/groups_handler_test.go index 7f3c81f18..089c1a40f 100644 --- a/management/server/http/groups_handler_test.go +++ b/management/server/http/handlers/groups/groups_handler_test.go @@ -1,4 +1,4 @@ -package http +package groups import ( "bytes" @@ -31,8 +31,8 @@ var TestPeers = map[string]*nbpeer.Peer{ "B": {Key: "B", ID: "peer-B-ID", IP: net.ParseIP("200.200.200.200")}, } -func initGroupTestData(initGroups ...*nbgroup.Group) *GroupsHandler { - return &GroupsHandler{ +func initGroupTestData(initGroups ...*nbgroup.Group) *handler { + return &handler{ accountManager: &mock_server.MockAccountManager{ SaveGroupFunc: func(_ context.Context, accountID, userID string, group *nbgroup.Group) error { if !strings.HasPrefix(group.ID, "id-") { @@ -106,14 +106,14 @@ func TestGetGroup(t *testing.T) { requestBody io.Reader }{ { - name: "GetGroup OK", + name: "getGroup OK", expectedBody: true, requestType: http.MethodGet, requestPath: "/api/groups/idofthegroup", expectedStatus: http.StatusOK, }, { - name: "GetGroup not found", + name: "getGroup not found", requestType: http.MethodGet, requestPath: "/api/groups/notexists", expectedStatus: http.StatusNotFound, @@ -133,7 +133,7 @@ func TestGetGroup(t *testing.T) { req := httptest.NewRequest(tc.requestType, tc.requestPath, tc.requestBody) router := mux.NewRouter() - router.HandleFunc("/api/groups/{groupId}", p.GetGroup).Methods("GET") + router.HandleFunc("/api/groups/{groupId}", p.getGroup).Methods("GET") router.ServeHTTP(recorder, req) res := recorder.Result() @@ -254,8 +254,8 @@ func TestWriteGroup(t *testing.T) { req := httptest.NewRequest(tc.requestType, tc.requestPath, tc.requestBody) router := mux.NewRouter() - router.HandleFunc("/api/groups", p.CreateGroup).Methods("POST") - router.HandleFunc("/api/groups/{groupId}", p.UpdateGroup).Methods("PUT") + router.HandleFunc("/api/groups", p.createGroup).Methods("POST") + router.HandleFunc("/api/groups/{groupId}", p.updateGroup).Methods("PUT") router.ServeHTTP(recorder, req) res := recorder.Result() @@ -331,7 +331,7 @@ func TestDeleteGroup(t *testing.T) { req := httptest.NewRequest(tc.requestType, tc.requestPath, nil) router := mux.NewRouter() - router.HandleFunc("/api/groups/{groupId}", p.DeleteGroup).Methods("DELETE") + router.HandleFunc("/api/groups/{groupId}", p.deleteGroup).Methods("DELETE") router.ServeHTTP(recorder, req) res := recorder.Result() diff --git a/management/server/http/peers_handler.go b/management/server/http/handlers/peers/peers_handler.go similarity index 88% rename from management/server/http/peers_handler.go rename to management/server/http/handlers/peers/peers_handler.go index f5027cd77..c53cbc038 100644 --- a/management/server/http/peers_handler.go +++ b/management/server/http/handlers/peers/peers_handler.go @@ -1,4 +1,4 @@ -package http +package peers import ( "context" @@ -12,21 +12,30 @@ import ( "github.com/netbirdio/netbird/management/server" nbgroup "github.com/netbirdio/netbird/management/server/group" "github.com/netbirdio/netbird/management/server/http/api" + 
"github.com/netbirdio/netbird/management/server/http/configs" "github.com/netbirdio/netbird/management/server/http/util" "github.com/netbirdio/netbird/management/server/jwtclaims" nbpeer "github.com/netbirdio/netbird/management/server/peer" "github.com/netbirdio/netbird/management/server/status" ) -// PeersHandler is a handler that returns peers of the account -type PeersHandler struct { +// Handler is a handler that returns peers of the account +type Handler struct { accountManager server.AccountManager claimsExtractor *jwtclaims.ClaimsExtractor } -// NewPeersHandler creates a new PeersHandler HTTP handler -func NewPeersHandler(accountManager server.AccountManager, authCfg AuthCfg) *PeersHandler { - return &PeersHandler{ +func AddEndpoints(accountManager server.AccountManager, authCfg configs.AuthCfg, router *mux.Router) { + peersHandler := NewHandler(accountManager, authCfg) + router.HandleFunc("/peers", peersHandler.GetAllPeers).Methods("GET", "OPTIONS") + router.HandleFunc("/peers/{peerId}", peersHandler.HandlePeer). + Methods("GET", "PUT", "DELETE", "OPTIONS") + router.HandleFunc("/peers/{peerId}/accessible-peers", peersHandler.GetAccessiblePeers).Methods("GET", "OPTIONS") +} + +// NewHandler creates a new peers Handler +func NewHandler(accountManager server.AccountManager, authCfg configs.AuthCfg) *Handler { + return &Handler{ accountManager: accountManager, claimsExtractor: jwtclaims.NewClaimsExtractor( jwtclaims.WithAudience(authCfg.Audience), @@ -35,7 +44,7 @@ func NewPeersHandler(accountManager server.AccountManager, authCfg AuthCfg) *Pee } } -func (h *PeersHandler) checkPeerStatus(peer *nbpeer.Peer) (*nbpeer.Peer, error) { +func (h *Handler) checkPeerStatus(peer *nbpeer.Peer) (*nbpeer.Peer, error) { peerToReturn := peer.Copy() if peer.Status.Connected { // Although we have online status in store we do not yet have an updated channel so have to show it as disconnected @@ -48,7 +57,7 @@ func (h *PeersHandler) checkPeerStatus(peer *nbpeer.Peer) (*nbpeer.Peer, error) return peerToReturn, nil } -func (h *PeersHandler) getPeer(ctx context.Context, account *server.Account, peerID, userID string, w http.ResponseWriter) { +func (h *Handler) getPeer(ctx context.Context, account *server.Account, peerID, userID string, w http.ResponseWriter) { peer, err := h.accountManager.GetPeer(ctx, account.Id, peerID, userID) if err != nil { util.WriteError(ctx, err, w) @@ -75,7 +84,7 @@ func (h *PeersHandler) getPeer(ctx context.Context, account *server.Account, pee util.WriteJSONObject(ctx, w, toSinglePeerResponse(peerToReturn, groupsInfo, dnsDomain, valid)) } -func (h *PeersHandler) updatePeer(ctx context.Context, account *server.Account, userID, peerID string, w http.ResponseWriter, r *http.Request) { +func (h *Handler) updatePeer(ctx context.Context, account *server.Account, userID, peerID string, w http.ResponseWriter, r *http.Request) { req := &api.PeerRequest{} err := json.NewDecoder(r.Body).Decode(&req) if err != nil { @@ -120,18 +129,18 @@ func (h *PeersHandler) updatePeer(ctx context.Context, account *server.Account, util.WriteJSONObject(r.Context(), w, toSinglePeerResponse(peer, groupMinimumInfo, dnsDomain, valid)) } -func (h *PeersHandler) deletePeer(ctx context.Context, accountID, userID string, peerID string, w http.ResponseWriter) { +func (h *Handler) deletePeer(ctx context.Context, accountID, userID string, peerID string, w http.ResponseWriter) { err := h.accountManager.DeletePeer(ctx, accountID, peerID, userID) if err != nil { log.WithContext(ctx).Errorf("failed to delete peer: %v", 
err) util.WriteError(ctx, err, w) return } - util.WriteJSONObject(ctx, w, emptyObject{}) + util.WriteJSONObject(ctx, w, util.EmptyObject{}) } // HandlePeer handles all peer requests for GET, PUT and DELETE operations -func (h *PeersHandler) HandlePeer(w http.ResponseWriter, r *http.Request) { +func (h *Handler) HandlePeer(w http.ResponseWriter, r *http.Request) { claims := h.claimsExtractor.FromRequestContext(r) accountID, userID, err := h.accountManager.GetAccountIDFromToken(r.Context(), claims) if err != nil { @@ -168,7 +177,7 @@ func (h *PeersHandler) HandlePeer(w http.ResponseWriter, r *http.Request) { } // GetAllPeers returns a list of all peers associated with a provided account -func (h *PeersHandler) GetAllPeers(w http.ResponseWriter, r *http.Request) { +func (h *Handler) GetAllPeers(w http.ResponseWriter, r *http.Request) { claims := h.claimsExtractor.FromRequestContext(r) accountID, userID, err := h.accountManager.GetAccountIDFromToken(r.Context(), claims) if err != nil { @@ -219,7 +228,7 @@ func (h *PeersHandler) GetAllPeers(w http.ResponseWriter, r *http.Request) { util.WriteJSONObject(r.Context(), w, respBody) } -func (h *PeersHandler) setApprovalRequiredFlag(respBody []*api.PeerBatch, approvedPeersMap map[string]struct{}) { +func (h *Handler) setApprovalRequiredFlag(respBody []*api.PeerBatch, approvedPeersMap map[string]struct{}) { for _, peer := range respBody { _, ok := approvedPeersMap[peer.Id] if !ok { @@ -229,7 +238,7 @@ func (h *PeersHandler) setApprovalRequiredFlag(respBody []*api.PeerBatch, approv } // GetAccessiblePeers returns a list of all peers that the specified peer can connect to within the network. -func (h *PeersHandler) GetAccessiblePeers(w http.ResponseWriter, r *http.Request) { +func (h *Handler) GetAccessiblePeers(w http.ResponseWriter, r *http.Request) { claims := h.claimsExtractor.FromRequestContext(r) accountID, userID, err := h.accountManager.GetAccountIDFromToken(r.Context(), claims) if err != nil { diff --git a/management/server/http/peers_handler_test.go b/management/server/http/handlers/peers/peers_handler_test.go similarity index 99% rename from management/server/http/peers_handler_test.go rename to management/server/http/handlers/peers/peers_handler_test.go index dd49c03b8..3e3e39deb 100644 --- a/management/server/http/peers_handler_test.go +++ b/management/server/http/handlers/peers/peers_handler_test.go @@ -1,4 +1,4 @@ -package http +package peers import ( "bytes" @@ -38,8 +38,8 @@ const ( userIDKey ctxKey = "user_id" ) -func initTestMetaData(peers ...*nbpeer.Peer) *PeersHandler { - return &PeersHandler{ +func initTestMetaData(peers ...*nbpeer.Peer) *Handler { + return &Handler{ accountManager: &mock_server.MockAccountManager{ UpdatePeerFunc: func(_ context.Context, accountID, userID string, update *nbpeer.Peer) (*nbpeer.Peer, error) { var p *nbpeer.Peer diff --git a/management/server/http/geolocation_handler_test.go b/management/server/http/handlers/policies/geolocation_handler_test.go similarity index 94% rename from management/server/http/geolocation_handler_test.go rename to management/server/http/handlers/policies/geolocation_handler_test.go index 19c916dd2..002b914ef 100644 --- a/management/server/http/geolocation_handler_test.go +++ b/management/server/http/handlers/policies/geolocation_handler_test.go @@ -1,4 +1,4 @@ -package http +package policies import ( "context" @@ -11,9 +11,9 @@ import ( "testing" "github.com/gorilla/mux" - "github.com/netbirdio/netbird/management/server" "github.com/stretchr/testify/assert" + 
"github.com/netbirdio/netbird/management/server" "github.com/netbirdio/netbird/management/server/geolocation" "github.com/netbirdio/netbird/management/server/http/api" "github.com/netbirdio/netbird/management/server/jwtclaims" @@ -21,12 +21,12 @@ import ( "github.com/netbirdio/netbird/util" ) -func initGeolocationTestData(t *testing.T) *GeolocationsHandler { +func initGeolocationTestData(t *testing.T) *geolocationsHandler { t.Helper() var ( - mmdbPath = "../testdata/GeoLite2-City_20240305.mmdb" - geonamesdbPath = "../testdata/geonames_20240305.db" + mmdbPath = "../../../testdata/GeoLite2-City_20240305.mmdb" + geonamesdbPath = "../../../testdata/geonames_20240305.db" ) tempDir := t.TempDir() @@ -41,7 +41,7 @@ func initGeolocationTestData(t *testing.T) *GeolocationsHandler { assert.NoError(t, err) t.Cleanup(func() { _ = geo.Stop() }) - return &GeolocationsHandler{ + return &geolocationsHandler{ accountManager: &mock_server.MockAccountManager{ GetAccountIDFromTokenFunc: func(_ context.Context, claims jwtclaims.AuthorizationClaims) (string, string, error) { return claims.AccountId, claims.UserId, nil @@ -114,7 +114,7 @@ func TestGetCitiesByCountry(t *testing.T) { req := httptest.NewRequest(tc.requestType, tc.requestPath, nil) router := mux.NewRouter() - router.HandleFunc("/api/locations/countries/{country}/cities", geolocationHandler.GetCitiesByCountry).Methods("GET") + router.HandleFunc("/api/locations/countries/{country}/cities", geolocationHandler.getCitiesByCountry).Methods("GET") router.ServeHTTP(recorder, req) res := recorder.Result() @@ -202,7 +202,7 @@ func TestGetAllCountries(t *testing.T) { req := httptest.NewRequest(tc.requestType, tc.requestPath, nil) router := mux.NewRouter() - router.HandleFunc("/api/locations/countries", geolocationHandler.GetAllCountries).Methods("GET") + router.HandleFunc("/api/locations/countries", geolocationHandler.getAllCountries).Methods("GET") router.ServeHTTP(recorder, req) res := recorder.Result() diff --git a/management/server/http/geolocations_handler.go b/management/server/http/handlers/policies/geolocations_handler.go similarity index 72% rename from management/server/http/geolocations_handler.go rename to management/server/http/handlers/policies/geolocations_handler.go index 418228abf..e5bf3e695 100644 --- a/management/server/http/geolocations_handler.go +++ b/management/server/http/handlers/policies/geolocations_handler.go @@ -1,4 +1,4 @@ -package http +package policies import ( "net/http" @@ -9,6 +9,7 @@ import ( "github.com/netbirdio/netbird/management/server" "github.com/netbirdio/netbird/management/server/geolocation" "github.com/netbirdio/netbird/management/server/http/api" + "github.com/netbirdio/netbird/management/server/http/configs" "github.com/netbirdio/netbird/management/server/http/util" "github.com/netbirdio/netbird/management/server/jwtclaims" "github.com/netbirdio/netbird/management/server/status" @@ -18,16 +19,22 @@ var ( countryCodeRegex = regexp.MustCompile("^[a-zA-Z]{2}$") ) -// GeolocationsHandler is a handler that returns locations. -type GeolocationsHandler struct { +// geolocationsHandler is a handler that returns locations. 
+type geolocationsHandler struct { accountManager server.AccountManager geolocationManager *geolocation.Geolocation claimsExtractor *jwtclaims.ClaimsExtractor } -// NewGeolocationsHandlerHandler creates a new Geolocations handler -func NewGeolocationsHandlerHandler(accountManager server.AccountManager, geolocationManager *geolocation.Geolocation, authCfg AuthCfg) *GeolocationsHandler { - return &GeolocationsHandler{ +func addLocationsEndpoint(accountManager server.AccountManager, locationManager *geolocation.Geolocation, authCfg configs.AuthCfg, router *mux.Router) { + locationHandler := newGeolocationsHandlerHandler(accountManager, locationManager, authCfg) + router.HandleFunc("/locations/countries", locationHandler.getAllCountries).Methods("GET", "OPTIONS") + router.HandleFunc("/locations/countries/{country}/cities", locationHandler.getCitiesByCountry).Methods("GET", "OPTIONS") +} + +// newGeolocationsHandlerHandler creates a new Geolocations handler +func newGeolocationsHandlerHandler(accountManager server.AccountManager, geolocationManager *geolocation.Geolocation, authCfg configs.AuthCfg) *geolocationsHandler { + return &geolocationsHandler{ accountManager: accountManager, geolocationManager: geolocationManager, claimsExtractor: jwtclaims.NewClaimsExtractor( @@ -37,8 +44,8 @@ func NewGeolocationsHandlerHandler(accountManager server.AccountManager, geoloca } } -// GetAllCountries retrieves a list of all countries -func (l *GeolocationsHandler) GetAllCountries(w http.ResponseWriter, r *http.Request) { +// getAllCountries retrieves a list of all countries +func (l *geolocationsHandler) getAllCountries(w http.ResponseWriter, r *http.Request) { if err := l.authenticateUser(r); err != nil { util.WriteError(r.Context(), err, w) return @@ -63,8 +70,8 @@ func (l *GeolocationsHandler) GetAllCountries(w http.ResponseWriter, r *http.Req util.WriteJSONObject(r.Context(), w, countries) } -// GetCitiesByCountry retrieves a list of cities based on the given country code -func (l *GeolocationsHandler) GetCitiesByCountry(w http.ResponseWriter, r *http.Request) { +// getCitiesByCountry retrieves a list of cities based on the given country code +func (l *geolocationsHandler) getCitiesByCountry(w http.ResponseWriter, r *http.Request) { if err := l.authenticateUser(r); err != nil { util.WriteError(r.Context(), err, w) return @@ -96,7 +103,7 @@ func (l *GeolocationsHandler) GetCitiesByCountry(w http.ResponseWriter, r *http. 
util.WriteJSONObject(r.Context(), w, cities) } -func (l *GeolocationsHandler) authenticateUser(r *http.Request) error { +func (l *geolocationsHandler) authenticateUser(r *http.Request) error { claims := l.claimsExtractor.FromRequestContext(r) _, userID, err := l.accountManager.GetAccountIDFromToken(r.Context(), claims) if err != nil { diff --git a/management/server/http/policies_handler.go b/management/server/http/handlers/policies/policies_handler.go similarity index 84% rename from management/server/http/policies_handler.go rename to management/server/http/handlers/policies/policies_handler.go index eff9092d4..a47f2e620 100644 --- a/management/server/http/policies_handler.go +++ b/management/server/http/handlers/policies/policies_handler.go @@ -1,4 +1,4 @@ -package http +package policies import ( "encoding/json" @@ -6,23 +6,36 @@ import ( "strconv" "github.com/gorilla/mux" + "github.com/netbirdio/netbird/management/server" + "github.com/netbirdio/netbird/management/server/geolocation" nbgroup "github.com/netbirdio/netbird/management/server/group" "github.com/netbirdio/netbird/management/server/http/api" + "github.com/netbirdio/netbird/management/server/http/configs" "github.com/netbirdio/netbird/management/server/http/util" "github.com/netbirdio/netbird/management/server/jwtclaims" "github.com/netbirdio/netbird/management/server/status" ) -// Policies is a handler that returns policy of the account -type Policies struct { +// handler is a handler that returns policy of the account +type handler struct { accountManager server.AccountManager claimsExtractor *jwtclaims.ClaimsExtractor } -// NewPoliciesHandler creates a new Policies handler -func NewPoliciesHandler(accountManager server.AccountManager, authCfg AuthCfg) *Policies { - return &Policies{ +func AddEndpoints(accountManager server.AccountManager, locationManager *geolocation.Geolocation, authCfg configs.AuthCfg, router *mux.Router) { + policiesHandler := newHandler(accountManager, authCfg) + router.HandleFunc("/policies", policiesHandler.getAllPolicies).Methods("GET", "OPTIONS") + router.HandleFunc("/policies", policiesHandler.createPolicy).Methods("POST", "OPTIONS") + router.HandleFunc("/policies/{policyId}", policiesHandler.updatePolicy).Methods("PUT", "OPTIONS") + router.HandleFunc("/policies/{policyId}", policiesHandler.getPolicy).Methods("GET", "OPTIONS") + router.HandleFunc("/policies/{policyId}", policiesHandler.deletePolicy).Methods("DELETE", "OPTIONS") + addPostureCheckEndpoint(accountManager, locationManager, authCfg, router) +} + +// newHandler creates a new policies handler +func newHandler(accountManager server.AccountManager, authCfg configs.AuthCfg) *handler { + return &handler{ accountManager: accountManager, claimsExtractor: jwtclaims.NewClaimsExtractor( jwtclaims.WithAudience(authCfg.Audience), @@ -31,8 +44,8 @@ func NewPoliciesHandler(accountManager server.AccountManager, authCfg AuthCfg) * } } -// GetAllPolicies list for the account -func (h *Policies) GetAllPolicies(w http.ResponseWriter, r *http.Request) { +// getAllPolicies list for the account +func (h *handler) getAllPolicies(w http.ResponseWriter, r *http.Request) { claims := h.claimsExtractor.FromRequestContext(r) accountID, userID, err := h.accountManager.GetAccountIDFromToken(r.Context(), claims) if err != nil { @@ -65,8 +78,8 @@ func (h *Policies) GetAllPolicies(w http.ResponseWriter, r *http.Request) { util.WriteJSONObject(r.Context(), w, policies) } -// UpdatePolicy handles update to a policy identified by a given ID -func (h *Policies) 
UpdatePolicy(w http.ResponseWriter, r *http.Request) { +// updatePolicy handles update to a policy identified by a given ID +func (h *handler) updatePolicy(w http.ResponseWriter, r *http.Request) { claims := h.claimsExtractor.FromRequestContext(r) accountID, userID, err := h.accountManager.GetAccountIDFromToken(r.Context(), claims) if err != nil { @@ -90,8 +103,8 @@ func (h *Policies) UpdatePolicy(w http.ResponseWriter, r *http.Request) { h.savePolicy(w, r, accountID, userID, policyID) } -// CreatePolicy handles policy creation request -func (h *Policies) CreatePolicy(w http.ResponseWriter, r *http.Request) { +// createPolicy handles policy creation request +func (h *handler) createPolicy(w http.ResponseWriter, r *http.Request) { claims := h.claimsExtractor.FromRequestContext(r) accountID, userID, err := h.accountManager.GetAccountIDFromToken(r.Context(), claims) if err != nil { @@ -103,7 +116,7 @@ func (h *Policies) CreatePolicy(w http.ResponseWriter, r *http.Request) { } // savePolicy handles policy creation and update -func (h *Policies) savePolicy(w http.ResponseWriter, r *http.Request, accountID string, userID string, policyID string) { +func (h *handler) savePolicy(w http.ResponseWriter, r *http.Request, accountID string, userID string, policyID string) { var req api.PutApiPoliciesPolicyIdJSONRequestBody if err := json.NewDecoder(r.Body).Decode(&req); err != nil { util.WriteErrorResponse("couldn't parse JSON request", http.StatusBadRequest, w) @@ -245,8 +258,8 @@ func (h *Policies) savePolicy(w http.ResponseWriter, r *http.Request, accountID util.WriteJSONObject(r.Context(), w, resp) } -// DeletePolicy handles policy deletion request -func (h *Policies) DeletePolicy(w http.ResponseWriter, r *http.Request) { +// deletePolicy handles policy deletion request +func (h *handler) deletePolicy(w http.ResponseWriter, r *http.Request) { claims := h.claimsExtractor.FromRequestContext(r) accountID, userID, err := h.accountManager.GetAccountIDFromToken(r.Context(), claims) if err != nil { @@ -266,11 +279,11 @@ func (h *Policies) DeletePolicy(w http.ResponseWriter, r *http.Request) { return } - util.WriteJSONObject(r.Context(), w, emptyObject{}) + util.WriteJSONObject(r.Context(), w, util.EmptyObject{}) } -// GetPolicy handles a group Get request identified by ID -func (h *Policies) GetPolicy(w http.ResponseWriter, r *http.Request) { +// getPolicy handles a group Get request identified by ID +func (h *handler) getPolicy(w http.ResponseWriter, r *http.Request) { claims := h.claimsExtractor.FromRequestContext(r) accountID, userID, err := h.accountManager.GetAccountIDFromToken(r.Context(), claims) if err != nil { diff --git a/management/server/http/policies_handler_test.go b/management/server/http/handlers/policies/policies_handler_test.go similarity index 95% rename from management/server/http/policies_handler_test.go rename to management/server/http/handlers/policies/policies_handler_test.go index f8a897eb2..4b465a85a 100644 --- a/management/server/http/policies_handler_test.go +++ b/management/server/http/handlers/policies/policies_handler_test.go @@ -1,4 +1,4 @@ -package http +package policies import ( "bytes" @@ -24,12 +24,12 @@ import ( "github.com/netbirdio/netbird/management/server/mock_server" ) -func initPoliciesTestData(policies ...*server.Policy) *Policies { +func initPoliciesTestData(policies ...*server.Policy) *handler { testPolicies := make(map[string]*server.Policy, len(policies)) for _, policy := range policies { testPolicies[policy.ID] = policy } - return &Policies{ + return 
&handler{ accountManager: &mock_server.MockAccountManager{ GetPolicyFunc: func(_ context.Context, _, policyID, _ string) (*server.Policy, error) { policy, ok := testPolicies[policyID] @@ -91,14 +91,14 @@ func TestPoliciesGetPolicy(t *testing.T) { requestBody io.Reader }{ { - name: "GetPolicy OK", + name: "getPolicy OK", expectedBody: true, requestType: http.MethodGet, requestPath: "/api/policies/idofthepolicy", expectedStatus: http.StatusOK, }, { - name: "GetPolicy not found", + name: "getPolicy not found", requestType: http.MethodGet, requestPath: "/api/policies/notexists", expectedStatus: http.StatusNotFound, @@ -121,7 +121,7 @@ func TestPoliciesGetPolicy(t *testing.T) { req := httptest.NewRequest(tc.requestType, tc.requestPath, tc.requestBody) router := mux.NewRouter() - router.HandleFunc("/api/policies/{policyId}", p.GetPolicy).Methods("GET") + router.HandleFunc("/api/policies/{policyId}", p.getPolicy).Methods("GET") router.ServeHTTP(recorder, req) res := recorder.Result() @@ -269,8 +269,8 @@ func TestPoliciesWritePolicy(t *testing.T) { req := httptest.NewRequest(tc.requestType, tc.requestPath, tc.requestBody) router := mux.NewRouter() - router.HandleFunc("/api/policies", p.CreatePolicy).Methods("POST") - router.HandleFunc("/api/policies/{policyId}", p.UpdatePolicy).Methods("PUT") + router.HandleFunc("/api/policies", p.createPolicy).Methods("POST") + router.HandleFunc("/api/policies/{policyId}", p.updatePolicy).Methods("PUT") router.ServeHTTP(recorder, req) res := recorder.Result() diff --git a/management/server/http/posture_checks_handler.go b/management/server/http/handlers/policies/posture_checks_handler.go similarity index 70% rename from management/server/http/posture_checks_handler.go rename to management/server/http/handlers/policies/posture_checks_handler.go index 2c8204292..44917605b 100644 --- a/management/server/http/posture_checks_handler.go +++ b/management/server/http/handlers/policies/posture_checks_handler.go @@ -1,4 +1,4 @@ -package http +package policies import ( "encoding/json" @@ -9,22 +9,33 @@ import ( "github.com/netbirdio/netbird/management/server" "github.com/netbirdio/netbird/management/server/geolocation" "github.com/netbirdio/netbird/management/server/http/api" + "github.com/netbirdio/netbird/management/server/http/configs" "github.com/netbirdio/netbird/management/server/http/util" "github.com/netbirdio/netbird/management/server/jwtclaims" "github.com/netbirdio/netbird/management/server/posture" "github.com/netbirdio/netbird/management/server/status" ) -// PostureChecksHandler is a handler that returns posture checks of the account. -type PostureChecksHandler struct { +// postureChecksHandler is a handler that returns posture checks of the account. 
+type postureChecksHandler struct { accountManager server.AccountManager geolocationManager *geolocation.Geolocation claimsExtractor *jwtclaims.ClaimsExtractor } -// NewPostureChecksHandler creates a new PostureChecks handler -func NewPostureChecksHandler(accountManager server.AccountManager, geolocationManager *geolocation.Geolocation, authCfg AuthCfg) *PostureChecksHandler { - return &PostureChecksHandler{ +func addPostureCheckEndpoint(accountManager server.AccountManager, locationManager *geolocation.Geolocation, authCfg configs.AuthCfg, router *mux.Router) { + postureCheckHandler := newPostureChecksHandler(accountManager, locationManager, authCfg) + router.HandleFunc("/posture-checks", postureCheckHandler.getAllPostureChecks).Methods("GET", "OPTIONS") + router.HandleFunc("/posture-checks", postureCheckHandler.createPostureCheck).Methods("POST", "OPTIONS") + router.HandleFunc("/posture-checks/{postureCheckId}", postureCheckHandler.updatePostureCheck).Methods("PUT", "OPTIONS") + router.HandleFunc("/posture-checks/{postureCheckId}", postureCheckHandler.getPostureCheck).Methods("GET", "OPTIONS") + router.HandleFunc("/posture-checks/{postureCheckId}", postureCheckHandler.deletePostureCheck).Methods("DELETE", "OPTIONS") + addLocationsEndpoint(accountManager, locationManager, authCfg, router) +} + +// newPostureChecksHandler creates a new PostureChecks handler +func newPostureChecksHandler(accountManager server.AccountManager, geolocationManager *geolocation.Geolocation, authCfg configs.AuthCfg) *postureChecksHandler { + return &postureChecksHandler{ accountManager: accountManager, geolocationManager: geolocationManager, claimsExtractor: jwtclaims.NewClaimsExtractor( @@ -34,8 +45,8 @@ func NewPostureChecksHandler(accountManager server.AccountManager, geolocationMa } } -// GetAllPostureChecks list for the account -func (p *PostureChecksHandler) GetAllPostureChecks(w http.ResponseWriter, r *http.Request) { +// getAllPostureChecks list for the account +func (p *postureChecksHandler) getAllPostureChecks(w http.ResponseWriter, r *http.Request) { claims := p.claimsExtractor.FromRequestContext(r) accountID, userID, err := p.accountManager.GetAccountIDFromToken(r.Context(), claims) if err != nil { @@ -57,8 +68,8 @@ func (p *PostureChecksHandler) GetAllPostureChecks(w http.ResponseWriter, r *htt util.WriteJSONObject(r.Context(), w, postureChecks) } -// UpdatePostureCheck handles update to a posture check identified by a given ID -func (p *PostureChecksHandler) UpdatePostureCheck(w http.ResponseWriter, r *http.Request) { +// updatePostureCheck handles update to a posture check identified by a given ID +func (p *postureChecksHandler) updatePostureCheck(w http.ResponseWriter, r *http.Request) { claims := p.claimsExtractor.FromRequestContext(r) accountID, userID, err := p.accountManager.GetAccountIDFromToken(r.Context(), claims) if err != nil { @@ -82,8 +93,8 @@ func (p *PostureChecksHandler) UpdatePostureCheck(w http.ResponseWriter, r *http p.savePostureChecks(w, r, accountID, userID, postureChecksID) } -// CreatePostureCheck handles posture check creation request -func (p *PostureChecksHandler) CreatePostureCheck(w http.ResponseWriter, r *http.Request) { +// createPostureCheck handles posture check creation request +func (p *postureChecksHandler) createPostureCheck(w http.ResponseWriter, r *http.Request) { claims := p.claimsExtractor.FromRequestContext(r) accountID, userID, err := p.accountManager.GetAccountIDFromToken(r.Context(), claims) if err != nil { @@ -94,8 +105,8 @@ func (p 
*PostureChecksHandler) CreatePostureCheck(w http.ResponseWriter, r *http p.savePostureChecks(w, r, accountID, userID, "") } -// GetPostureCheck handles a posture check Get request identified by ID -func (p *PostureChecksHandler) GetPostureCheck(w http.ResponseWriter, r *http.Request) { +// getPostureCheck handles a posture check Get request identified by ID +func (p *postureChecksHandler) getPostureCheck(w http.ResponseWriter, r *http.Request) { claims := p.claimsExtractor.FromRequestContext(r) accountID, userID, err := p.accountManager.GetAccountIDFromToken(r.Context(), claims) if err != nil { @@ -119,8 +130,8 @@ func (p *PostureChecksHandler) GetPostureCheck(w http.ResponseWriter, r *http.Re util.WriteJSONObject(r.Context(), w, postureChecks.ToAPIResponse()) } -// DeletePostureCheck handles posture check deletion request -func (p *PostureChecksHandler) DeletePostureCheck(w http.ResponseWriter, r *http.Request) { +// deletePostureCheck handles posture check deletion request +func (p *postureChecksHandler) deletePostureCheck(w http.ResponseWriter, r *http.Request) { claims := p.claimsExtractor.FromRequestContext(r) accountID, userID, err := p.accountManager.GetAccountIDFromToken(r.Context(), claims) if err != nil { @@ -140,11 +151,11 @@ func (p *PostureChecksHandler) DeletePostureCheck(w http.ResponseWriter, r *http return } - util.WriteJSONObject(r.Context(), w, emptyObject{}) + util.WriteJSONObject(r.Context(), w, util.EmptyObject{}) } // savePostureChecks handles posture checks create and update -func (p *PostureChecksHandler) savePostureChecks(w http.ResponseWriter, r *http.Request, accountID, userID, postureChecksID string) { +func (p *postureChecksHandler) savePostureChecks(w http.ResponseWriter, r *http.Request, accountID, userID, postureChecksID string) { var ( err error req api.PostureCheckUpdate diff --git a/management/server/http/posture_checks_handler_test.go b/management/server/http/handlers/policies/posture_checks_handler_test.go similarity index 96% rename from management/server/http/posture_checks_handler_test.go rename to management/server/http/handlers/policies/posture_checks_handler_test.go index f400cec81..e9a539e45 100644 --- a/management/server/http/posture_checks_handler_test.go +++ b/management/server/http/handlers/policies/posture_checks_handler_test.go @@ -1,4 +1,4 @@ -package http +package policies import ( "bytes" @@ -25,13 +25,13 @@ import ( var berlin = "Berlin" var losAngeles = "Los Angeles" -func initPostureChecksTestData(postureChecks ...*posture.Checks) *PostureChecksHandler { +func initPostureChecksTestData(postureChecks ...*posture.Checks) *postureChecksHandler { testPostureChecks := make(map[string]*posture.Checks, len(postureChecks)) for _, postureCheck := range postureChecks { testPostureChecks[postureCheck.ID] = postureCheck } - return &PostureChecksHandler{ + return &postureChecksHandler{ accountManager: &mock_server.MockAccountManager{ GetPostureChecksFunc: func(_ context.Context, accountID, postureChecksID, userID string) (*posture.Checks, error) { p, ok := testPostureChecks[postureChecksID] @@ -147,35 +147,35 @@ func TestGetPostureCheck(t *testing.T) { requestBody io.Reader }{ { - name: "GetPostureCheck NBVersion OK", + name: "getPostureCheck NBVersion OK", expectedBody: true, id: postureCheck.ID, checkName: postureCheck.Name, expectedStatus: http.StatusOK, }, { - name: "GetPostureCheck OSVersion OK", + name: "getPostureCheck OSVersion OK", expectedBody: true, id: osPostureCheck.ID, checkName: osPostureCheck.Name, expectedStatus: http.StatusOK, 
}, { - name: "GetPostureCheck GeoLocation OK", + name: "getPostureCheck GeoLocation OK", expectedBody: true, id: geoPostureCheck.ID, checkName: geoPostureCheck.Name, expectedStatus: http.StatusOK, }, { - name: "GetPostureCheck PrivateNetwork OK", + name: "getPostureCheck PrivateNetwork OK", expectedBody: true, id: privateNetworkCheck.ID, checkName: privateNetworkCheck.Name, expectedStatus: http.StatusOK, }, { - name: "GetPostureCheck Not Found", + name: "getPostureCheck Not Found", id: "not-exists", expectedStatus: http.StatusNotFound, }, @@ -189,7 +189,7 @@ func TestGetPostureCheck(t *testing.T) { req := httptest.NewRequest(http.MethodGet, "/api/posture-checks/"+tc.id, tc.requestBody) router := mux.NewRouter() - router.HandleFunc("/api/posture-checks/{postureCheckId}", p.GetPostureCheck).Methods("GET") + router.HandleFunc("/api/posture-checks/{postureCheckId}", p.getPostureCheck).Methods("GET") router.ServeHTTP(recorder, req) res := recorder.Result() @@ -231,7 +231,7 @@ func TestPostureCheckUpdate(t *testing.T) { requestType string requestPath string requestBody io.Reader - setupHandlerFunc func(handler *PostureChecksHandler) + setupHandlerFunc func(handler *postureChecksHandler) }{ { name: "Create Posture Checks NB version", @@ -286,7 +286,7 @@ func TestPostureCheckUpdate(t *testing.T) { }, }, }, - setupHandlerFunc: func(handler *PostureChecksHandler) { + setupHandlerFunc: func(handler *postureChecksHandler) { handler.geolocationManager = nil }, }, @@ -427,7 +427,7 @@ func TestPostureCheckUpdate(t *testing.T) { }`)), expectedStatus: http.StatusPreconditionFailed, expectedBody: false, - setupHandlerFunc: func(handler *PostureChecksHandler) { + setupHandlerFunc: func(handler *postureChecksHandler) { handler.geolocationManager = nil }, }, @@ -614,7 +614,7 @@ func TestPostureCheckUpdate(t *testing.T) { }, }, }, - setupHandlerFunc: func(handler *PostureChecksHandler) { + setupHandlerFunc: func(handler *postureChecksHandler) { handler.geolocationManager = nil }, }, @@ -677,7 +677,7 @@ func TestPostureCheckUpdate(t *testing.T) { }`)), expectedStatus: http.StatusPreconditionFailed, expectedBody: false, - setupHandlerFunc: func(handler *PostureChecksHandler) { + setupHandlerFunc: func(handler *postureChecksHandler) { handler.geolocationManager = nil }, }, @@ -842,8 +842,8 @@ func TestPostureCheckUpdate(t *testing.T) { } router := mux.NewRouter() - router.HandleFunc("/api/posture-checks", defaultHandler.CreatePostureCheck).Methods("POST") - router.HandleFunc("/api/posture-checks/{postureCheckId}", defaultHandler.UpdatePostureCheck).Methods("PUT") + router.HandleFunc("/api/posture-checks", defaultHandler.createPostureCheck).Methods("POST") + router.HandleFunc("/api/posture-checks/{postureCheckId}", defaultHandler.updatePostureCheck).Methods("PUT") router.ServeHTTP(recorder, req) res := recorder.Result() diff --git a/management/server/http/routes_handler.go b/management/server/http/handlers/routes/routes_handler.go similarity index 85% rename from management/server/http/routes_handler.go rename to management/server/http/handlers/routes/routes_handler.go index f44a164e2..9d420066c 100644 --- a/management/server/http/routes_handler.go +++ b/management/server/http/handlers/routes/routes_handler.go @@ -1,4 +1,4 @@ -package http +package routes import ( "encoding/json" @@ -14,6 +14,7 @@ import ( "github.com/netbirdio/netbird/management/domain" "github.com/netbirdio/netbird/management/server" "github.com/netbirdio/netbird/management/server/http/api" + 
"github.com/netbirdio/netbird/management/server/http/configs" "github.com/netbirdio/netbird/management/server/http/util" "github.com/netbirdio/netbird/management/server/jwtclaims" "github.com/netbirdio/netbird/management/server/status" @@ -23,15 +24,24 @@ import ( const maxDomains = 32 const failedToConvertRoute = "failed to convert route to response: %v" -// RoutesHandler is the routes handler of the account -type RoutesHandler struct { +// handler is the routes handler of the account +type handler struct { accountManager server.AccountManager claimsExtractor *jwtclaims.ClaimsExtractor } -// NewRoutesHandler returns a new instance of RoutesHandler handler -func NewRoutesHandler(accountManager server.AccountManager, authCfg AuthCfg) *RoutesHandler { - return &RoutesHandler{ +func AddEndpoints(accountManager server.AccountManager, authCfg configs.AuthCfg, router *mux.Router) { + routesHandler := newHandler(accountManager, authCfg) + router.HandleFunc("/routes", routesHandler.getAllRoutes).Methods("GET", "OPTIONS") + router.HandleFunc("/routes", routesHandler.createRoute).Methods("POST", "OPTIONS") + router.HandleFunc("/routes/{routeId}", routesHandler.updateRoute).Methods("PUT", "OPTIONS") + router.HandleFunc("/routes/{routeId}", routesHandler.getRoute).Methods("GET", "OPTIONS") + router.HandleFunc("/routes/{routeId}", routesHandler.deleteRoute).Methods("DELETE", "OPTIONS") +} + +// newHandler returns a new instance of routes handler +func newHandler(accountManager server.AccountManager, authCfg configs.AuthCfg) *handler { + return &handler{ accountManager: accountManager, claimsExtractor: jwtclaims.NewClaimsExtractor( jwtclaims.WithAudience(authCfg.Audience), @@ -40,8 +50,8 @@ func NewRoutesHandler(accountManager server.AccountManager, authCfg AuthCfg) *Ro } } -// GetAllRoutes returns the list of routes for the account -func (h *RoutesHandler) GetAllRoutes(w http.ResponseWriter, r *http.Request) { +// getAllRoutes returns the list of routes for the account +func (h *handler) getAllRoutes(w http.ResponseWriter, r *http.Request) { claims := h.claimsExtractor.FromRequestContext(r) accountID, userID, err := h.accountManager.GetAccountIDFromToken(r.Context(), claims) if err != nil { @@ -67,8 +77,8 @@ func (h *RoutesHandler) GetAllRoutes(w http.ResponseWriter, r *http.Request) { util.WriteJSONObject(r.Context(), w, apiRoutes) } -// CreateRoute handles route creation request -func (h *RoutesHandler) CreateRoute(w http.ResponseWriter, r *http.Request) { +// createRoute handles route creation request +func (h *handler) createRoute(w http.ResponseWriter, r *http.Request) { claims := h.claimsExtractor.FromRequestContext(r) accountID, userID, err := h.accountManager.GetAccountIDFromToken(r.Context(), claims) if err != nil { @@ -139,7 +149,7 @@ func (h *RoutesHandler) CreateRoute(w http.ResponseWriter, r *http.Request) { util.WriteJSONObject(r.Context(), w, routes) } -func (h *RoutesHandler) validateRoute(req api.PostApiRoutesJSONRequestBody) error { +func (h *handler) validateRoute(req api.PostApiRoutesJSONRequestBody) error { if req.Network != nil && req.Domains != nil { return status.Errorf(status.InvalidArgument, "only one of 'network' or 'domains' should be provided") } @@ -164,8 +174,8 @@ func (h *RoutesHandler) validateRoute(req api.PostApiRoutesJSONRequestBody) erro return nil } -// UpdateRoute handles update to a route identified by a given ID -func (h *RoutesHandler) UpdateRoute(w http.ResponseWriter, r *http.Request) { +// updateRoute handles update to a route identified by a given ID +func 
(h *handler) updateRoute(w http.ResponseWriter, r *http.Request) { claims := h.claimsExtractor.FromRequestContext(r) accountID, userID, err := h.accountManager.GetAccountIDFromToken(r.Context(), claims) if err != nil { @@ -257,8 +267,8 @@ func (h *RoutesHandler) UpdateRoute(w http.ResponseWriter, r *http.Request) { util.WriteJSONObject(r.Context(), w, routes) } -// DeleteRoute handles route deletion request -func (h *RoutesHandler) DeleteRoute(w http.ResponseWriter, r *http.Request) { +// deleteRoute handles route deletion request +func (h *handler) deleteRoute(w http.ResponseWriter, r *http.Request) { claims := h.claimsExtractor.FromRequestContext(r) accountID, userID, err := h.accountManager.GetAccountIDFromToken(r.Context(), claims) if err != nil { @@ -278,11 +288,11 @@ func (h *RoutesHandler) DeleteRoute(w http.ResponseWriter, r *http.Request) { return } - util.WriteJSONObject(r.Context(), w, emptyObject{}) + util.WriteJSONObject(r.Context(), w, util.EmptyObject{}) } -// GetRoute handles a route Get request identified by ID -func (h *RoutesHandler) GetRoute(w http.ResponseWriter, r *http.Request) { +// getRoute handles a route Get request identified by ID +func (h *handler) getRoute(w http.ResponseWriter, r *http.Request) { claims := h.claimsExtractor.FromRequestContext(r) accountID, userID, err := h.accountManager.GetAccountIDFromToken(r.Context(), claims) if err != nil { diff --git a/management/server/http/routes_handler_test.go b/management/server/http/handlers/routes/routes_handler_test.go similarity index 98% rename from management/server/http/routes_handler_test.go rename to management/server/http/handlers/routes/routes_handler_test.go index 83bd7004d..a25c899c9 100644 --- a/management/server/http/routes_handler_test.go +++ b/management/server/http/handlers/routes/routes_handler_test.go @@ -1,4 +1,4 @@ -package http +package routes import ( "bytes" @@ -87,8 +87,8 @@ var testingAccount = &server.Account{ }, } -func initRoutesTestData() *RoutesHandler { - return &RoutesHandler{ +func initRoutesTestData() *handler { + return &handler{ accountManager: &mock_server.MockAccountManager{ GetRouteFunc: func(_ context.Context, _ string, routeID route.ID, _ string) (*route.Route, error) { if routeID == existingRouteID { @@ -152,7 +152,7 @@ func initRoutesTestData() *RoutesHandler { return nil }, GetAccountIDFromTokenFunc: func(_ context.Context, _ jwtclaims.AuthorizationClaims) (string, string, error) { - //return testingAccount, testingAccount.Users["test_user"], nil + // return testingAccount, testingAccount.Users["test_user"], nil return testingAccount.Id, testingAccount.Users["test_user"].Id, nil }, }, @@ -521,10 +521,10 @@ func TestRoutesHandlers(t *testing.T) { req := httptest.NewRequest(tc.requestType, tc.requestPath, tc.requestBody) router := mux.NewRouter() - router.HandleFunc("/api/routes/{routeId}", p.GetRoute).Methods("GET") - router.HandleFunc("/api/routes/{routeId}", p.DeleteRoute).Methods("DELETE") - router.HandleFunc("/api/routes", p.CreateRoute).Methods("POST") - router.HandleFunc("/api/routes/{routeId}", p.UpdateRoute).Methods("PUT") + router.HandleFunc("/api/routes/{routeId}", p.getRoute).Methods("GET") + router.HandleFunc("/api/routes/{routeId}", p.deleteRoute).Methods("DELETE") + router.HandleFunc("/api/routes", p.createRoute).Methods("POST") + router.HandleFunc("/api/routes/{routeId}", p.updateRoute).Methods("PUT") router.ServeHTTP(recorder, req) res := recorder.Result() diff --git a/management/server/http/setupkeys_handler.go 
b/management/server/http/handlers/setup_keys/setupkeys_handler.go similarity index 78% rename from management/server/http/setupkeys_handler.go rename to management/server/http/handlers/setup_keys/setupkeys_handler.go index 9ba5977bb..9432d5549 100644 --- a/management/server/http/setupkeys_handler.go +++ b/management/server/http/handlers/setup_keys/setupkeys_handler.go @@ -1,4 +1,4 @@ -package http +package setup_keys import ( "context" @@ -10,20 +10,30 @@ import ( "github.com/netbirdio/netbird/management/server" "github.com/netbirdio/netbird/management/server/http/api" + "github.com/netbirdio/netbird/management/server/http/configs" "github.com/netbirdio/netbird/management/server/http/util" "github.com/netbirdio/netbird/management/server/jwtclaims" "github.com/netbirdio/netbird/management/server/status" ) -// SetupKeysHandler is a handler that returns a list of setup keys of the account -type SetupKeysHandler struct { +// handler is a handler that returns a list of setup keys of the account +type handler struct { accountManager server.AccountManager claimsExtractor *jwtclaims.ClaimsExtractor } -// NewSetupKeysHandler creates a new SetupKeysHandler HTTP handler -func NewSetupKeysHandler(accountManager server.AccountManager, authCfg AuthCfg) *SetupKeysHandler { - return &SetupKeysHandler{ +func AddEndpoints(accountManager server.AccountManager, authCfg configs.AuthCfg, router *mux.Router) { + keysHandler := newHandler(accountManager, authCfg) + router.HandleFunc("/setup-keys", keysHandler.getAllSetupKeys).Methods("GET", "OPTIONS") + router.HandleFunc("/setup-keys", keysHandler.createSetupKey).Methods("POST", "OPTIONS") + router.HandleFunc("/setup-keys/{keyId}", keysHandler.getSetupKey).Methods("GET", "OPTIONS") + router.HandleFunc("/setup-keys/{keyId}", keysHandler.updateSetupKey).Methods("PUT", "OPTIONS") + router.HandleFunc("/setup-keys/{keyId}", keysHandler.deleteSetupKey).Methods("DELETE", "OPTIONS") +} + +// newHandler creates a new setup key handler +func newHandler(accountManager server.AccountManager, authCfg configs.AuthCfg) *handler { + return &handler{ accountManager: accountManager, claimsExtractor: jwtclaims.NewClaimsExtractor( jwtclaims.WithAudience(authCfg.Audience), @@ -32,8 +42,8 @@ func NewSetupKeysHandler(accountManager server.AccountManager, authCfg AuthCfg) } } -// CreateSetupKey is a POST requests that creates a new SetupKey -func (h *SetupKeysHandler) CreateSetupKey(w http.ResponseWriter, r *http.Request) { +// createSetupKey is a POST requests that creates a new SetupKey +func (h *handler) createSetupKey(w http.ResponseWriter, r *http.Request) { claims := h.claimsExtractor.FromRequestContext(r) accountID, userID, err := h.accountManager.GetAccountIDFromToken(r.Context(), claims) if err != nil { @@ -89,8 +99,8 @@ func (h *SetupKeysHandler) CreateSetupKey(w http.ResponseWriter, r *http.Request util.WriteJSONObject(r.Context(), w, apiSetupKeys) } -// GetSetupKey is a GET request to get a SetupKey by ID -func (h *SetupKeysHandler) GetSetupKey(w http.ResponseWriter, r *http.Request) { +// getSetupKey is a GET request to get a SetupKey by ID +func (h *handler) getSetupKey(w http.ResponseWriter, r *http.Request) { claims := h.claimsExtractor.FromRequestContext(r) accountID, userID, err := h.accountManager.GetAccountIDFromToken(r.Context(), claims) if err != nil { @@ -114,8 +124,8 @@ func (h *SetupKeysHandler) GetSetupKey(w http.ResponseWriter, r *http.Request) { writeSuccess(r.Context(), w, key) } -// UpdateSetupKey is a PUT request to update server.SetupKey -func (h 
*SetupKeysHandler) UpdateSetupKey(w http.ResponseWriter, r *http.Request) { +// updateSetupKey is a PUT request to update server.SetupKey +func (h *handler) updateSetupKey(w http.ResponseWriter, r *http.Request) { claims := h.claimsExtractor.FromRequestContext(r) accountID, userID, err := h.accountManager.GetAccountIDFromToken(r.Context(), claims) if err != nil { @@ -155,8 +165,8 @@ func (h *SetupKeysHandler) UpdateSetupKey(w http.ResponseWriter, r *http.Request writeSuccess(r.Context(), w, newKey) } -// GetAllSetupKeys is a GET request that returns a list of SetupKey -func (h *SetupKeysHandler) GetAllSetupKeys(w http.ResponseWriter, r *http.Request) { +// getAllSetupKeys is a GET request that returns a list of SetupKey +func (h *handler) getAllSetupKeys(w http.ResponseWriter, r *http.Request) { claims := h.claimsExtractor.FromRequestContext(r) accountID, userID, err := h.accountManager.GetAccountIDFromToken(r.Context(), claims) if err != nil { @@ -178,7 +188,7 @@ func (h *SetupKeysHandler) GetAllSetupKeys(w http.ResponseWriter, r *http.Reques util.WriteJSONObject(r.Context(), w, apiSetupKeys) } -func (h *SetupKeysHandler) DeleteSetupKey(w http.ResponseWriter, r *http.Request) { +func (h *handler) deleteSetupKey(w http.ResponseWriter, r *http.Request) { claims := h.claimsExtractor.FromRequestContext(r) accountID, userID, err := h.accountManager.GetAccountIDFromToken(r.Context(), claims) if err != nil { @@ -199,7 +209,7 @@ func (h *SetupKeysHandler) DeleteSetupKey(w http.ResponseWriter, r *http.Request return } - util.WriteJSONObject(r.Context(), w, emptyObject{}) + util.WriteJSONObject(r.Context(), w, util.EmptyObject{}) } func writeSuccess(ctx context.Context, w http.ResponseWriter, key *server.SetupKey) { diff --git a/management/server/http/setupkeys_handler_test.go b/management/server/http/handlers/setup_keys/setupkeys_handler_test.go similarity index 95% rename from management/server/http/setupkeys_handler_test.go rename to management/server/http/handlers/setup_keys/setupkeys_handler_test.go index 09256d0ea..516a2ab8b 100644 --- a/management/server/http/setupkeys_handler_test.go +++ b/management/server/http/handlers/setup_keys/setupkeys_handler_test.go @@ -1,4 +1,4 @@ -package http +package setup_keys import ( "bytes" @@ -26,12 +26,13 @@ const ( newSetupKeyName = "New Setup Key" updatedSetupKeyName = "KKKey" notFoundSetupKeyID = "notFoundSetupKeyID" + testAccountID = "test_id" ) func initSetupKeysTestMetaData(defaultKey *server.SetupKey, newKey *server.SetupKey, updatedSetupKey *server.SetupKey, user *server.User, -) *SetupKeysHandler { - return &SetupKeysHandler{ +) *handler { + return &handler{ accountManager: &mock_server.MockAccountManager{ GetAccountIDFromTokenFunc: func(_ context.Context, claims jwtclaims.AuthorizationClaims) (string, string, error) { return claims.AccountId, claims.UserId, nil @@ -178,11 +179,11 @@ func TestSetupKeysHandlers(t *testing.T) { req := httptest.NewRequest(tc.requestType, tc.requestPath, tc.requestBody) router := mux.NewRouter() - router.HandleFunc("/api/setup-keys", handler.GetAllSetupKeys).Methods("GET", "OPTIONS") - router.HandleFunc("/api/setup-keys", handler.CreateSetupKey).Methods("POST", "OPTIONS") - router.HandleFunc("/api/setup-keys/{keyId}", handler.GetSetupKey).Methods("GET", "OPTIONS") - router.HandleFunc("/api/setup-keys/{keyId}", handler.UpdateSetupKey).Methods("PUT", "OPTIONS") - router.HandleFunc("/api/setup-keys/{keyId}", handler.DeleteSetupKey).Methods("DELETE", "OPTIONS") + router.HandleFunc("/api/setup-keys", 
handler.getAllSetupKeys).Methods("GET", "OPTIONS") + router.HandleFunc("/api/setup-keys", handler.createSetupKey).Methods("POST", "OPTIONS") + router.HandleFunc("/api/setup-keys/{keyId}", handler.getSetupKey).Methods("GET", "OPTIONS") + router.HandleFunc("/api/setup-keys/{keyId}", handler.updateSetupKey).Methods("PUT", "OPTIONS") + router.HandleFunc("/api/setup-keys/{keyId}", handler.deleteSetupKey).Methods("DELETE", "OPTIONS") router.ServeHTTP(recorder, req) res := recorder.Result() diff --git a/management/server/http/pat_handler.go b/management/server/http/handlers/users/pat_handler.go similarity index 75% rename from management/server/http/pat_handler.go rename to management/server/http/handlers/users/pat_handler.go index dfa9563e3..2caf98ad8 100644 --- a/management/server/http/pat_handler.go +++ b/management/server/http/handlers/users/pat_handler.go @@ -1,4 +1,4 @@ -package http +package users import ( "encoding/json" @@ -9,20 +9,29 @@ import ( "github.com/netbirdio/netbird/management/server" "github.com/netbirdio/netbird/management/server/http/api" + "github.com/netbirdio/netbird/management/server/http/configs" "github.com/netbirdio/netbird/management/server/http/util" "github.com/netbirdio/netbird/management/server/jwtclaims" "github.com/netbirdio/netbird/management/server/status" ) -// PATHandler is the nameserver group handler of the account -type PATHandler struct { +// patHandler is the nameserver group handler of the account +type patHandler struct { accountManager server.AccountManager claimsExtractor *jwtclaims.ClaimsExtractor } -// NewPATsHandler creates a new PATHandler HTTP handler -func NewPATsHandler(accountManager server.AccountManager, authCfg AuthCfg) *PATHandler { - return &PATHandler{ +func addUsersTokensEndpoint(accountManager server.AccountManager, authCfg configs.AuthCfg, router *mux.Router) { + tokenHandler := newPATsHandler(accountManager, authCfg) + router.HandleFunc("/users/{userId}/tokens", tokenHandler.getAllTokens).Methods("GET", "OPTIONS") + router.HandleFunc("/users/{userId}/tokens", tokenHandler.createToken).Methods("POST", "OPTIONS") + router.HandleFunc("/users/{userId}/tokens/{tokenId}", tokenHandler.getToken).Methods("GET", "OPTIONS") + router.HandleFunc("/users/{userId}/tokens/{tokenId}", tokenHandler.deleteToken).Methods("DELETE", "OPTIONS") +} + +// newPATsHandler creates a new patHandler HTTP handler +func newPATsHandler(accountManager server.AccountManager, authCfg configs.AuthCfg) *patHandler { + return &patHandler{ accountManager: accountManager, claimsExtractor: jwtclaims.NewClaimsExtractor( jwtclaims.WithAudience(authCfg.Audience), @@ -31,8 +40,8 @@ func NewPATsHandler(accountManager server.AccountManager, authCfg AuthCfg) *PATH } } -// GetAllTokens is HTTP GET handler that returns a list of all personal access tokens for the given user -func (h *PATHandler) GetAllTokens(w http.ResponseWriter, r *http.Request) { +// getAllTokens is HTTP GET handler that returns a list of all personal access tokens for the given user +func (h *patHandler) getAllTokens(w http.ResponseWriter, r *http.Request) { claims := h.claimsExtractor.FromRequestContext(r) accountID, userID, err := h.accountManager.GetAccountIDFromToken(r.Context(), claims) if err != nil { @@ -61,8 +70,8 @@ func (h *PATHandler) GetAllTokens(w http.ResponseWriter, r *http.Request) { util.WriteJSONObject(r.Context(), w, patResponse) } -// GetToken is HTTP GET handler that returns a personal access token for the given user -func (h *PATHandler) GetToken(w http.ResponseWriter, r *http.Request) { 
+// getToken is HTTP GET handler that returns a personal access token for the given user +func (h *patHandler) getToken(w http.ResponseWriter, r *http.Request) { claims := h.claimsExtractor.FromRequestContext(r) accountID, userID, err := h.accountManager.GetAccountIDFromToken(r.Context(), claims) if err != nil { @@ -92,8 +101,8 @@ func (h *PATHandler) GetToken(w http.ResponseWriter, r *http.Request) { util.WriteJSONObject(r.Context(), w, toPATResponse(pat)) } -// CreateToken is HTTP POST handler that creates a personal access token for the given user -func (h *PATHandler) CreateToken(w http.ResponseWriter, r *http.Request) { +// createToken is HTTP POST handler that creates a personal access token for the given user +func (h *patHandler) createToken(w http.ResponseWriter, r *http.Request) { claims := h.claimsExtractor.FromRequestContext(r) accountID, userID, err := h.accountManager.GetAccountIDFromToken(r.Context(), claims) if err != nil { @@ -124,8 +133,8 @@ func (h *PATHandler) CreateToken(w http.ResponseWriter, r *http.Request) { util.WriteJSONObject(r.Context(), w, toPATGeneratedResponse(pat)) } -// DeleteToken is HTTP DELETE handler that deletes a personal access token for the given user -func (h *PATHandler) DeleteToken(w http.ResponseWriter, r *http.Request) { +// deleteToken is HTTP DELETE handler that deletes a personal access token for the given user +func (h *patHandler) deleteToken(w http.ResponseWriter, r *http.Request) { claims := h.claimsExtractor.FromRequestContext(r) accountID, userID, err := h.accountManager.GetAccountIDFromToken(r.Context(), claims) if err != nil { @@ -152,7 +161,7 @@ func (h *PATHandler) DeleteToken(w http.ResponseWriter, r *http.Request) { return } - util.WriteJSONObject(r.Context(), w, emptyObject{}) + util.WriteJSONObject(r.Context(), w, util.EmptyObject{}) } func toPATResponse(pat *server.PersonalAccessToken) *api.PersonalAccessToken { diff --git a/management/server/http/pat_handler_test.go b/management/server/http/handlers/users/pat_handler_test.go similarity index 96% rename from management/server/http/pat_handler_test.go rename to management/server/http/handlers/users/pat_handler_test.go index c28228a50..ef6fb973e 100644 --- a/management/server/http/pat_handler_test.go +++ b/management/server/http/handlers/users/pat_handler_test.go @@ -1,4 +1,4 @@ -package http +package users import ( "bytes" @@ -61,8 +61,8 @@ var testAccount = &server.Account{ }, } -func initPATTestData() *PATHandler { - return &PATHandler{ +func initPATTestData() *patHandler { + return &patHandler{ accountManager: &mock_server.MockAccountManager{ CreatePATFunc: func(_ context.Context, accountID string, initiatorUserID string, targetUserID string, tokenName string, expiresIn int) (*server.PersonalAccessTokenGenerated, error) { if accountID != existingAccountID { @@ -186,10 +186,10 @@ func TestTokenHandlers(t *testing.T) { req := httptest.NewRequest(tc.requestType, tc.requestPath, tc.requestBody) router := mux.NewRouter() - router.HandleFunc("/api/users/{userId}/tokens", p.GetAllTokens).Methods("GET") - router.HandleFunc("/api/users/{userId}/tokens/{tokenId}", p.GetToken).Methods("GET") - router.HandleFunc("/api/users/{userId}/tokens", p.CreateToken).Methods("POST") - router.HandleFunc("/api/users/{userId}/tokens/{tokenId}", p.DeleteToken).Methods("DELETE") + router.HandleFunc("/api/users/{userId}/tokens", p.getAllTokens).Methods("GET") + router.HandleFunc("/api/users/{userId}/tokens/{tokenId}", p.getToken).Methods("GET") + router.HandleFunc("/api/users/{userId}/tokens", 
p.createToken).Methods("POST") + router.HandleFunc("/api/users/{userId}/tokens/{tokenId}", p.deleteToken).Methods("DELETE") router.ServeHTTP(recorder, req) res := recorder.Result() diff --git a/management/server/http/users_handler.go b/management/server/http/handlers/users/users_handler.go similarity index 80% rename from management/server/http/users_handler.go rename to management/server/http/handlers/users/users_handler.go index 6e151a0da..c843bc52b 100644 --- a/management/server/http/users_handler.go +++ b/management/server/http/handlers/users/users_handler.go @@ -1,4 +1,4 @@ -package http +package users import ( "encoding/json" @@ -9,6 +9,7 @@ import ( log "github.com/sirupsen/logrus" "github.com/netbirdio/netbird/management/server/http/api" + "github.com/netbirdio/netbird/management/server/http/configs" "github.com/netbirdio/netbird/management/server/http/util" "github.com/netbirdio/netbird/management/server/status" @@ -16,15 +17,25 @@ import ( "github.com/netbirdio/netbird/management/server/jwtclaims" ) -// UsersHandler is a handler that returns users of the account -type UsersHandler struct { +// handler is a handler that returns users of the account +type handler struct { accountManager server.AccountManager claimsExtractor *jwtclaims.ClaimsExtractor } -// NewUsersHandler creates a new UsersHandler HTTP handler -func NewUsersHandler(accountManager server.AccountManager, authCfg AuthCfg) *UsersHandler { - return &UsersHandler{ +func AddEndpoints(accountManager server.AccountManager, authCfg configs.AuthCfg, router *mux.Router) { + userHandler := newHandler(accountManager, authCfg) + router.HandleFunc("/users", userHandler.getAllUsers).Methods("GET", "OPTIONS") + router.HandleFunc("/users/{userId}", userHandler.updateUser).Methods("PUT", "OPTIONS") + router.HandleFunc("/users/{userId}", userHandler.deleteUser).Methods("DELETE", "OPTIONS") + router.HandleFunc("/users", userHandler.createUser).Methods("POST", "OPTIONS") + router.HandleFunc("/users/{userId}/invite", userHandler.inviteUser).Methods("POST", "OPTIONS") + addUsersTokensEndpoint(accountManager, authCfg, router) +} + +// newHandler creates a new UsersHandler HTTP handler +func newHandler(accountManager server.AccountManager, authCfg configs.AuthCfg) *handler { + return &handler{ accountManager: accountManager, claimsExtractor: jwtclaims.NewClaimsExtractor( jwtclaims.WithAudience(authCfg.Audience), @@ -33,8 +44,8 @@ func NewUsersHandler(accountManager server.AccountManager, authCfg AuthCfg) *Use } } -// UpdateUser is a PUT requests to update User data -func (h *UsersHandler) UpdateUser(w http.ResponseWriter, r *http.Request) { +// updateUser is a PUT requests to update User data +func (h *handler) updateUser(w http.ResponseWriter, r *http.Request) { if r.Method != http.MethodPut { util.WriteErrorResponse("wrong HTTP method", http.StatusMethodNotAllowed, w) return @@ -94,8 +105,8 @@ func (h *UsersHandler) UpdateUser(w http.ResponseWriter, r *http.Request) { util.WriteJSONObject(r.Context(), w, toUserResponse(newUser, claims.UserId)) } -// DeleteUser is a DELETE request to delete a user -func (h *UsersHandler) DeleteUser(w http.ResponseWriter, r *http.Request) { +// deleteUser is a DELETE request to delete a user +func (h *handler) deleteUser(w http.ResponseWriter, r *http.Request) { if r.Method != http.MethodDelete { util.WriteErrorResponse("wrong HTTP method", http.StatusMethodNotAllowed, w) return @@ -121,11 +132,11 @@ func (h *UsersHandler) DeleteUser(w http.ResponseWriter, r *http.Request) { return } - 
util.WriteJSONObject(r.Context(), w, emptyObject{}) + util.WriteJSONObject(r.Context(), w, util.EmptyObject{}) } -// CreateUser creates a User in the system with a status "invited" (effectively this is a user invite). -func (h *UsersHandler) CreateUser(w http.ResponseWriter, r *http.Request) { +// createUser creates a User in the system with a status "invited" (effectively this is a user invite). +func (h *handler) createUser(w http.ResponseWriter, r *http.Request) { if r.Method != http.MethodPost { util.WriteErrorResponse("wrong HTTP method", http.StatusMethodNotAllowed, w) return @@ -175,9 +186,9 @@ func (h *UsersHandler) CreateUser(w http.ResponseWriter, r *http.Request) { util.WriteJSONObject(r.Context(), w, toUserResponse(newUser, claims.UserId)) } -// GetAllUsers returns a list of users of the account this user belongs to. +// getAllUsers returns a list of users of the account this user belongs to. // It also gathers additional user data (like email and name) from the IDP manager. -func (h *UsersHandler) GetAllUsers(w http.ResponseWriter, r *http.Request) { +func (h *handler) getAllUsers(w http.ResponseWriter, r *http.Request) { if r.Method != http.MethodGet { util.WriteErrorResponse("wrong HTTP method", http.StatusMethodNotAllowed, w) return @@ -222,9 +233,9 @@ func (h *UsersHandler) GetAllUsers(w http.ResponseWriter, r *http.Request) { util.WriteJSONObject(r.Context(), w, users) } -// InviteUser resend invitations to users who haven't activated their accounts, +// inviteUser resend invitations to users who haven't activated their accounts, // prior to the expiration period. -func (h *UsersHandler) InviteUser(w http.ResponseWriter, r *http.Request) { +func (h *handler) inviteUser(w http.ResponseWriter, r *http.Request) { if r.Method != http.MethodPost { util.WriteErrorResponse("wrong HTTP method", http.StatusMethodNotAllowed, w) return @@ -250,7 +261,7 @@ func (h *UsersHandler) InviteUser(w http.ResponseWriter, r *http.Request) { return } - util.WriteJSONObject(r.Context(), w, emptyObject{}) + util.WriteJSONObject(r.Context(), w, util.EmptyObject{}) } func toUserResponse(user *server.UserInfo, currenUserID string) *api.User { diff --git a/management/server/http/users_handler_test.go b/management/server/http/handlers/users/users_handler_test.go similarity index 97% rename from management/server/http/users_handler_test.go rename to management/server/http/handlers/users/users_handler_test.go index f3d989da1..6f6a91236 100644 --- a/management/server/http/users_handler_test.go +++ b/management/server/http/handlers/users/users_handler_test.go @@ -1,4 +1,4 @@ -package http +package users import ( "bytes" @@ -61,8 +61,8 @@ var usersTestAccount = &server.Account{ }, } -func initUsersTestData() *UsersHandler { - return &UsersHandler{ +func initUsersTestData() *handler { + return &handler{ accountManager: &mock_server.MockAccountManager{ GetAccountIDFromTokenFunc: func(_ context.Context, claims jwtclaims.AuthorizationClaims) (string, string, error) { return usersTestAccount.Id, claims.UserId, nil @@ -147,7 +147,7 @@ func TestGetUsers(t *testing.T) { requestPath string expectedUserIDs []string }{ - {name: "GetAllUsers", requestType: http.MethodGet, requestPath: "/api/users", expectedStatus: http.StatusOK, expectedUserIDs: []string{existingUserID, regularUserID, serviceUserID}}, + {name: "getAllUsers", requestType: http.MethodGet, requestPath: "/api/users", expectedStatus: http.StatusOK, expectedUserIDs: []string{existingUserID, regularUserID, serviceUserID}}, {name: "GetOnlyServiceUsers", 
requestType: http.MethodGet, requestPath: "/api/users?service_user=true", expectedStatus: http.StatusOK, expectedUserIDs: []string{serviceUserID}}, {name: "GetOnlyRegularUsers", requestType: http.MethodGet, requestPath: "/api/users?service_user=false", expectedStatus: http.StatusOK, expectedUserIDs: []string{existingUserID, regularUserID}}, } @@ -159,7 +159,7 @@ func TestGetUsers(t *testing.T) { recorder := httptest.NewRecorder() req := httptest.NewRequest(tc.requestType, tc.requestPath, nil) - userHandler.GetAllUsers(recorder, req) + userHandler.getAllUsers(recorder, req) res := recorder.Result() defer res.Body.Close() @@ -265,7 +265,7 @@ func TestUpdateUser(t *testing.T) { req := httptest.NewRequest(tc.requestType, tc.requestPath, tc.requestBody) router := mux.NewRouter() - router.HandleFunc("/api/users/{userId}", userHandler.UpdateUser).Methods("PUT") + router.HandleFunc("/api/users/{userId}", userHandler.updateUser).Methods("PUT") router.ServeHTTP(recorder, req) res := recorder.Result() @@ -356,7 +356,7 @@ func TestCreateUser(t *testing.T) { req := httptest.NewRequest(tc.requestType, tc.requestPath, tc.requestBody) rr := httptest.NewRecorder() - userHandler.CreateUser(rr, req) + userHandler.createUser(rr, req) res := rr.Result() defer res.Body.Close() @@ -401,7 +401,7 @@ func TestInviteUser(t *testing.T) { req = mux.SetURLVars(req, tc.requestVars) rr := httptest.NewRecorder() - userHandler.InviteUser(rr, req) + userHandler.inviteUser(rr, req) res := rr.Result() defer res.Body.Close() @@ -454,7 +454,7 @@ func TestDeleteUser(t *testing.T) { req = mux.SetURLVars(req, tc.requestVars) rr := httptest.NewRecorder() - userHandler.DeleteUser(rr, req) + userHandler.deleteUser(rr, req) res := rr.Result() defer res.Body.Close() diff --git a/management/server/http/util/util.go b/management/server/http/util/util.go index 603c1c696..3d7eed498 100644 --- a/management/server/http/util/util.go +++ b/management/server/http/util/util.go @@ -14,6 +14,10 @@ import ( "github.com/netbirdio/netbird/management/server/status" ) +// EmptyObject is an empty struct used to return empty JSON object +type EmptyObject struct { +} + type ErrorResponse struct { Message string `json:"message"` Code int `json:"code"` From dcba6a6b7e5bcb8980860fee4d634002a9270e56 Mon Sep 17 00:00:00 2001 From: Maycon Santos Date: Wed, 11 Dec 2024 16:46:51 +0100 Subject: [PATCH 079/159] fix: client/Dockerfile to reduce vulnerabilities (#3019) The following vulnerabilities are fixed with an upgrade: - https://snyk.io/vuln/SNYK-ALPINE320-OPENSSL-8235201 - https://snyk.io/vuln/SNYK-ALPINE320-OPENSSL-8235201 Co-authored-by: snyk-bot --- client/Dockerfile | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/client/Dockerfile b/client/Dockerfile index b9f7c1355..2f5ff14ae 100644 --- a/client/Dockerfile +++ b/client/Dockerfile @@ -1,4 +1,4 @@ -FROM alpine:3.20 +FROM alpine:3.21.0 RUN apk add --no-cache ca-certificates iptables ip6tables ENV NB_FOREGROUND_MODE=true ENTRYPOINT [ "/usr/local/bin/netbird","up"] From a4a30744adab711799e10e1355c20b73c81fb2b3 Mon Sep 17 00:00:00 2001 From: "M. 
Essam" Date: Sat, 14 Dec 2024 22:17:53 +0200 Subject: [PATCH 080/159] Fix race condition with systray ready (#2993) --- client/ui/client_ui.go | 1 + 1 file changed, 1 insertion(+) diff --git a/client/ui/client_ui.go b/client/ui/client_ui.go index d046bab5f..8ca0db73f 100644 --- a/client/ui/client_ui.go +++ b/client/ui/client_ui.go @@ -572,6 +572,7 @@ func (s *serviceClient) onTrayReady() { s.update.SetOnUpdateListener(s.onUpdateAvailable) go func() { s.getSrvConfig() + time.Sleep(100 * time.Millisecond) // To prevent race condition caused by systray not being fully initialized and ignoring setIcon for { err := s.updateStatus() if err != nil { From 287ae811958c362b9acaf888f07968b7960964d2 Mon Sep 17 00:00:00 2001 From: Maycon Santos Date: Sat, 14 Dec 2024 21:18:46 +0100 Subject: [PATCH 081/159] [misc] split tests with management and rest (#3051) optimize go cache for tests --- .github/workflows/golang-test-darwin.yml | 6 +- .github/workflows/golang-test-linux.yml | 206 +++++++++++++++++++--- .github/workflows/golang-test-windows.yml | 25 ++- 3 files changed, 204 insertions(+), 33 deletions(-) diff --git a/.github/workflows/golang-test-darwin.yml b/.github/workflows/golang-test-darwin.yml index 88db8c5e8..2dbeb106a 100644 --- a/.github/workflows/golang-test-darwin.yml +++ b/.github/workflows/golang-test-darwin.yml @@ -21,6 +21,7 @@ jobs: uses: actions/setup-go@v5 with: go-version: "1.23.x" + cache: false - name: Checkout code uses: actions/checkout@v4 @@ -28,8 +29,9 @@ jobs: uses: actions/cache@v4 with: path: ~/go/pkg/mod - key: macos-go-${{ hashFiles('**/go.sum') }} + key: macos-gotest-${{ hashFiles('**/go.sum') }} restore-keys: | + macos-gotest- macos-go- - name: Install libpcap @@ -42,4 +44,4 @@ jobs: run: git --no-pager diff --exit-code - name: Test - run: NETBIRD_STORE_ENGINE=${{ matrix.store }} CI=true go test -exec 'sudo --preserve-env=CI,NETBIRD_STORE_ENGINE' -timeout 5m -p 1 ./... + run: NETBIRD_STORE_ENGINE=${{ matrix.store }} CI=true go test -exec 'sudo --preserve-env=CI,NETBIRD_STORE_ENGINE' -timeout 5m -p 1 $(go list ./... | grep -v /management) diff --git a/.github/workflows/golang-test-linux.yml b/.github/workflows/golang-test-linux.yml index 36dcb791f..85658d237 100644 --- a/.github/workflows/golang-test-linux.yml +++ b/.github/workflows/golang-test-linux.yml @@ -11,7 +11,133 @@ concurrency: cancel-in-progress: true jobs: + build-cache: + runs-on: ubuntu-22.04 + steps: + - name: Checkout code + uses: actions/checkout@v4 + + - name: Install Go + uses: actions/setup-go@v5 + with: + go-version: "1.23.x" + cache: false + + - name: Get Go environment + run: | + echo "cache=$(go env GOCACHE)" >> $GITHUB_ENV + echo "modcache=$(go env GOMODCACHE)" >> $GITHUB_ENV + + - name: Cache Go modules + uses: actions/cache@v4 + id: cache + with: + path: | + ${{ env.cache }} + ${{ env.modcache }} + key: ${{ runner.os }}-gotest-cache-${{ hashFiles('**/go.sum') }} + restore-keys: | + ${{ runner.os }}-gotest-cache-${{ hashFiles('**/go.sum') }} + + + - name: Install dependencies + if: steps.cache.outputs.cache-hit != 'true' + run: sudo apt update && sudo apt install -y -q libgtk-3-dev libayatana-appindicator3-dev libgl1-mesa-dev xorg-dev gcc-multilib libpcap-dev + + - name: Install 32-bit libpcap + if: steps.cache.outputs.cache-hit != 'true' + run: sudo dpkg --add-architecture i386 && sudo apt update && sudo apt-get install -y libpcap0.8-dev:i386 + + - name: Build client + if: steps.cache.outputs.cache-hit != 'true' + working-directory: client + run: CGO_ENABLED=1 go build . 
+ + - name: Build client 386 + if: steps.cache.outputs.cache-hit != 'true' + working-directory: client + run: CGO_ENABLED=1 GOARCH=386 go build -o client-386 . + + - name: Build management + if: steps.cache.outputs.cache-hit != 'true' + working-directory: management + run: CGO_ENABLED=1 go build . + + - name: Build management 386 + if: steps.cache.outputs.cache-hit != 'true' + working-directory: management + run: CGO_ENABLED=1 GOARCH=386 go build -o management-386 . + + - name: Build signal + if: steps.cache.outputs.cache-hit != 'true' + working-directory: signal + run: CGO_ENABLED=1 go build . + + - name: Build signal 386 + if: steps.cache.outputs.cache-hit != 'true' + working-directory: signal + run: CGO_ENABLED=1 GOARCH=386 go build -o signal-386 . + + - name: Build relay + if: steps.cache.outputs.cache-hit != 'true' + working-directory: relay + run: CGO_ENABLED=1 go build . + + - name: Build relay 386 + if: steps.cache.outputs.cache-hit != 'true' + working-directory: relay + run: CGO_ENABLED=1 GOARCH=386 go build -o relay-386 . + test: + needs: [build-cache] + strategy: + fail-fast: false + matrix: + arch: [ '386','amd64' ] + runs-on: ubuntu-22.04 + steps: + - name: Install Go + uses: actions/setup-go@v5 + with: + go-version: "1.23.x" + cache: false + + - name: Checkout code + uses: actions/checkout@v4 + + - name: Get Go environment + run: | + echo "cache=$(go env GOCACHE)" >> $GITHUB_ENV + echo "modcache=$(go env GOMODCACHE)" >> $GITHUB_ENV + + - name: Cache Go modules + uses: actions/cache/restore@v4 + with: + path: | + ${{ env.cache }} + ${{ env.modcache }} + key: ${{ runner.os }}-gotest-cache-${{ hashFiles('**/go.sum') }} + restore-keys: | + ${{ runner.os }}-gotest-cache- + + - name: Install dependencies + run: sudo apt update && sudo apt install -y -q libgtk-3-dev libayatana-appindicator3-dev libgl1-mesa-dev xorg-dev gcc-multilib libpcap-dev + + - name: Install 32-bit libpcap + if: matrix.arch == '386' + run: sudo dpkg --add-architecture i386 && sudo apt update && sudo apt-get install -y libpcap0.8-dev:i386 + + - name: Install modules + run: go mod tidy + + - name: check git status + run: git --no-pager diff --exit-code + + - name: Test + run: CGO_ENABLED=1 GOARCH=${{ matrix.arch }} CI=true go test -exec 'sudo' -timeout 10m -p 1 $(go list ./... | grep -v /management) + + test_management: + needs: [ build-cache ] strategy: fail-fast: false matrix: @@ -23,19 +149,26 @@ jobs: uses: actions/setup-go@v5 with: go-version: "1.23.x" - - - - name: Cache Go modules - uses: actions/cache@v4 - with: - path: ~/go/pkg/mod - key: ${{ runner.os }}-go-${{ hashFiles('**/go.sum') }} - restore-keys: | - ${{ runner.os }}-go- + cache: false - name: Checkout code uses: actions/checkout@v4 + - name: Get Go environment + run: | + echo "cache=$(go env GOCACHE)" >> $GITHUB_ENV + echo "modcache=$(go env GOMODCACHE)" >> $GITHUB_ENV + + - name: Cache Go modules + uses: actions/cache/restore@v4 + with: + path: | + ${{ env.cache }} + ${{ env.modcache }} + key: ${{ runner.os }}-gotest-cache-${{ hashFiles('**/go.sum') }} + restore-keys: | + ${{ runner.os }}-gotest-cache- + - name: Install dependencies run: sudo apt update && sudo apt install -y -q libgtk-3-dev libayatana-appindicator3-dev libgl1-mesa-dev xorg-dev gcc-multilib libpcap-dev @@ -50,9 +183,10 @@ jobs: run: git --no-pager diff --exit-code - name: Test - run: CGO_ENABLED=1 GOARCH=${{ matrix.arch }} NETBIRD_STORE_ENGINE=${{ matrix.store }} CI=true go test -exec 'sudo --preserve-env=CI,NETBIRD_STORE_ENGINE' -timeout 10m -p 1 ./... 
+ run: CGO_ENABLED=1 GOARCH=${{ matrix.arch }} NETBIRD_STORE_ENGINE=${{ matrix.store }} CI=true go test -exec 'sudo --preserve-env=CI,NETBIRD_STORE_ENGINE' -timeout 10m $(go list ./... | grep /management) benchmark: + needs: [ build-cache ] strategy: fail-fast: false matrix: @@ -64,19 +198,26 @@ jobs: uses: actions/setup-go@v5 with: go-version: "1.23.x" - - - - name: Cache Go modules - uses: actions/cache@v4 - with: - path: ~/go/pkg/mod - key: ${{ runner.os }}-go-${{ hashFiles('**/go.sum') }} - restore-keys: | - ${{ runner.os }}-go- + cache: false - name: Checkout code uses: actions/checkout@v4 + - name: Get Go environment + run: | + echo "cache=$(go env GOCACHE)" >> $GITHUB_ENV + echo "modcache=$(go env GOMODCACHE)" >> $GITHUB_ENV + + - name: Cache Go modules + uses: actions/cache/restore@v4 + with: + path: | + ${{ env.cache }} + ${{ env.modcache }} + key: ${{ runner.os }}-gotest-cache-${{ hashFiles('**/go.sum') }} + restore-keys: | + ${{ runner.os }}-gotest-cache- + - name: Install dependencies run: sudo apt update && sudo apt install -y -q libgtk-3-dev libayatana-appindicator3-dev libgl1-mesa-dev xorg-dev gcc-multilib libpcap-dev @@ -91,27 +232,36 @@ jobs: run: git --no-pager diff --exit-code - name: Test - run: CGO_ENABLED=1 GOARCH=${{ matrix.arch }} NETBIRD_STORE_ENGINE=${{ matrix.store }} CI=true go test -run=^$ -bench=. -exec 'sudo --preserve-env=CI,NETBIRD_STORE_ENGINE' -timeout 10m -p 1 ./... + run: CGO_ENABLED=1 GOARCH=${{ matrix.arch }} NETBIRD_STORE_ENGINE=${{ matrix.store }} CI=true go test -run=^$ -bench=. -exec 'sudo --preserve-env=CI,NETBIRD_STORE_ENGINE' -timeout 10m ./... test_client_on_docker: + needs: [ build-cache ] runs-on: ubuntu-20.04 steps: - name: Install Go uses: actions/setup-go@v5 with: go-version: "1.23.x" - - - name: Cache Go modules - uses: actions/cache@v4 - with: - path: ~/go/pkg/mod - key: ${{ runner.os }}-go-${{ hashFiles('**/go.sum') }} - restore-keys: | - ${{ runner.os }}-go- + cache: false - name: Checkout code uses: actions/checkout@v4 + - name: Get Go environment + run: | + echo "cache=$(go env GOCACHE)" >> $GITHUB_ENV + echo "modcache=$(go env GOMODCACHE)" >> $GITHUB_ENV + + - name: Cache Go modules + uses: actions/cache/restore@v4 + with: + path: | + ${{ env.cache }} + ${{ env.modcache }} + key: ${{ runner.os }}-gotest-cache-${{ hashFiles('**/go.sum') }} + restore-keys: | + ${{ runner.os }}-gotest-cache- + - name: Install dependencies run: sudo apt update && sudo apt install -y -q libgtk-3-dev libayatana-appindicator3-dev libgl1-mesa-dev xorg-dev gcc-multilib libpcap-dev diff --git a/.github/workflows/golang-test-windows.yml b/.github/workflows/golang-test-windows.yml index d378bec3f..3a3c47052 100644 --- a/.github/workflows/golang-test-windows.yml +++ b/.github/workflows/golang-test-windows.yml @@ -24,6 +24,23 @@ jobs: id: go with: go-version: "1.23.x" + cache: false + + - name: Get Go environment + run: | + echo "cache=$(go env GOCACHE)" >> $env:GITHUB_ENV + echo "modcache=$(go env GOMODCACHE)" >> $env:GITHUB_ENV + + - name: Cache Go modules + uses: actions/cache@v4 + with: + path: | + ${{ env.cache }} + ${{ env.modcache }} + key: ${{ runner.os }}-gotest-${{ hashFiles('**/go.sum') }} + restore-keys: | + ${{ runner.os }}-gotest- + ${{ runner.os }}-go- - name: Download wintun uses: carlosperate/download-file-action@v2 @@ -42,11 +59,13 @@ jobs: - run: choco install -y sysinternals --ignore-checksums - run: choco install -y mingw - - run: PsExec64 -s -w ${{ github.workspace }} C:\hostedtoolcache\windows\go\${{ steps.go.outputs.go-version 
}}\x64\bin\go.exe env -w GOMODCACHE=C:\Users\runneradmin\go\pkg\mod - - run: PsExec64 -s -w ${{ github.workspace }} C:\hostedtoolcache\windows\go\${{ steps.go.outputs.go-version }}\x64\bin\go.exe env -w GOCACHE=C:\Users\runneradmin\AppData\Local\go-build + - run: PsExec64 -s -w ${{ github.workspace }} C:\hostedtoolcache\windows\go\${{ steps.go.outputs.go-version }}\x64\bin\go.exe env -w GOMODCACHE=${{ env.cache }} + - run: PsExec64 -s -w ${{ github.workspace }} C:\hostedtoolcache\windows\go\${{ steps.go.outputs.go-version }}\x64\bin\go.exe env -w GOCACHE=${{ env.modcache }} + - run: PsExec64 -s -w ${{ github.workspace }} C:\hostedtoolcache\windows\go\${{ steps.go.outputs.go-version }}\x64\bin\go.exe mod tidy + - run: echo "files=$(go list ./... | ForEach-Object { $_ } | Where-Object { $_ -notmatch '/management' })" >> $env:GITHUB_ENV - name: test - run: PsExec64 -s -w ${{ github.workspace }} cmd.exe /c "C:\hostedtoolcache\windows\go\${{ steps.go.outputs.go-version }}\x64\bin\go.exe test -timeout 10m -p 1 ./... > test-out.txt 2>&1" + run: PsExec64 -s -w ${{ github.workspace }} cmd.exe /c "C:\hostedtoolcache\windows\go\${{ steps.go.outputs.go-version }}\x64\bin\go.exe test -timeout 10m -p 1 ${{ env.files }} > test-out.txt 2>&1" - name: test output if: ${{ always() }} run: Get-Content test-out.txt From f591e4740475e7bc0028fe91f27e4865ab699ec0 Mon Sep 17 00:00:00 2001 From: "M. Essam" Date: Mon, 16 Dec 2024 10:41:36 +0200 Subject: [PATCH 082/159] Handle DNF5 install script (#3026) --- release_files/install.sh | 7 ++++++- 1 file changed, 6 insertions(+), 1 deletion(-) diff --git a/release_files/install.sh b/release_files/install.sh index b0fec2733..bb917c39a 100755 --- a/release_files/install.sh +++ b/release_files/install.sh @@ -239,7 +239,12 @@ install_netbird() { dnf) add_rpm_repo ${SUDO} dnf -y install dnf-plugin-config-manager - ${SUDO} dnf config-manager --add-repo /etc/yum.repos.d/netbird.repo + if [[ "$(dnf --version | head -n1 | cut -d. -f1)" > "4" ]]; + then + ${SUDO} dnf config-manager addrepo --from-repofile=/etc/yum.repos.d/netbird.repo + else + ${SUDO} dnf config-manager --add-repo /etc/yum.repos.d/netbird.repo + fi ${SUDO} dnf -y install netbird if ! 
$SKIP_UI_APP; then From 3844516aa7722e3a9cdaf478bfb409caed701b8a Mon Sep 17 00:00:00 2001 From: Jesse R Codling Date: Mon, 16 Dec 2024 03:58:54 -0500 Subject: [PATCH 083/159] [client] fix: reformat IPv6 ICE addresses when punching (#3050) Should fix #2327 and #2606 by checking for IPv6 addresses from ICE --- client/internal/peer/worker_ice.go | 8 +++++++- 1 file changed, 7 insertions(+), 1 deletion(-) diff --git a/client/internal/peer/worker_ice.go b/client/internal/peer/worker_ice.go index 7ce4797c3..4cdd18ff1 100644 --- a/client/internal/peer/worker_ice.go +++ b/client/internal/peer/worker_ice.go @@ -264,7 +264,13 @@ func (w *WorkerICE) closeAgent(cancel context.CancelFunc) { func (w *WorkerICE) punchRemoteWGPort(pair *ice.CandidatePair, remoteWgPort int) { // wait local endpoint configuration time.Sleep(time.Second) - addr, err := net.ResolveUDPAddr("udp", fmt.Sprintf("%s:%d", pair.Remote.Address(), remoteWgPort)) + addrString := pair.Remote.Address() + parsed, err := netip.ParseAddr(addrString) + if (err == nil) && (parsed.Is6()) { + addrString = fmt.Sprintf("[%s]", addrString) + //IPv6 Literals need to be wrapped in brackets for Resolve*Addr() + } + addr, err := net.ResolveUDPAddr("udp", fmt.Sprintf("%s:%d", addrString, remoteWgPort)) if err != nil { w.log.Warnf("got an error while resolving the udp address, err: %s", err) return From 9eff58ae62c094363172a0bf2cdc1557b13b4a20 Mon Sep 17 00:00:00 2001 From: Maycon Santos Date: Mon, 16 Dec 2024 10:30:41 +0100 Subject: [PATCH 084/159] Upgrade x/crypto package (#3055) Mitigates the CVE-2024-45337 --- go.mod | 10 +++++----- go.sum | 20 ++++++++++---------- 2 files changed, 15 insertions(+), 15 deletions(-) diff --git a/go.mod b/go.mod index 2b4111ce3..c504925d2 100644 --- a/go.mod +++ b/go.mod @@ -19,8 +19,8 @@ require ( github.com/spf13/cobra v1.7.0 github.com/spf13/pflag v1.0.5 github.com/vishvananda/netlink v1.2.1-beta.2 - golang.org/x/crypto v0.28.0 - golang.org/x/sys v0.26.0 + golang.org/x/crypto v0.31.0 + golang.org/x/sys v0.28.0 golang.zx2c4.com/wireguard v0.0.0-20230704135630-469159ecf7d1 golang.zx2c4.com/wireguard/wgctrl v0.0.0-20230429144221-925a1e7659e6 golang.zx2c4.com/wireguard/windows v0.5.3 @@ -92,8 +92,8 @@ require ( golang.org/x/mobile v0.0.0-20231127183840-76ac6878050a golang.org/x/net v0.30.0 golang.org/x/oauth2 v0.19.0 - golang.org/x/sync v0.8.0 - golang.org/x/term v0.25.0 + golang.org/x/sync v0.10.0 + golang.org/x/term v0.27.0 google.golang.org/api v0.177.0 gopkg.in/yaml.v3 v3.0.1 gorm.io/driver/postgres v1.5.7 @@ -219,7 +219,7 @@ require ( go.uber.org/multierr v1.11.0 // indirect golang.org/x/image v0.18.0 // indirect golang.org/x/mod v0.17.0 // indirect - golang.org/x/text v0.19.0 // indirect + golang.org/x/text v0.21.0 // indirect golang.org/x/time v0.5.0 // indirect golang.org/x/tools v0.21.1-0.20240508182429-e35e4ccd0d2d // indirect golang.zx2c4.com/wintun v0.0.0-20230126152724-0fa3db229ce2 // indirect diff --git a/go.sum b/go.sum index 35abe82d2..04d2bc59a 100644 --- a/go.sum +++ b/go.sum @@ -774,8 +774,8 @@ golang.org/x/crypto v0.0.0-20210921155107-089bfa567519/go.mod h1:GvvjBRRGRdwPK5y golang.org/x/crypto v0.8.0/go.mod h1:mRqEX+O9/h5TFCrQhkgjo2yKi0yYA+9ecGkdQoHrywE= golang.org/x/crypto v0.12.0/go.mod h1:NF0Gs7EO5K4qLn+Ylc+fih8BSTeIjAP05siRnAh98yw= golang.org/x/crypto v0.18.0/go.mod h1:R0j02AL6hcrfOiy9T4ZYp/rcWeMxM3L6QYxlOuEG1mg= -golang.org/x/crypto v0.28.0 h1:GBDwsMXVQi34v5CCYUm2jkJvu4cbtru2U4TN2PSyQnw= -golang.org/x/crypto v0.28.0/go.mod h1:rmgy+3RHxRZMyY0jjAJShp2zgEdOqj2AO7U0pYmeQ7U= +golang.org/x/crypto 
v0.31.0 h1:ihbySMvVjLAeSH1IbfcRTkD/iNscyz8rGzjF/E5hV6U= +golang.org/x/crypto v0.31.0/go.mod h1:kDsLvtWBEx7MV9tJOj9bnXsPbxwJQ6csT/x4KIN4Ssk= golang.org/x/exp v0.0.0-20190121172915-509febef88a4/go.mod h1:CJ0aWSM057203Lf6IL+f9T1iT9GByDxfZKAQTCR3kQA= golang.org/x/exp v0.0.0-20190306152737-a1d7652674e8/go.mod h1:CJ0aWSM057203Lf6IL+f9T1iT9GByDxfZKAQTCR3kQA= golang.org/x/exp v0.0.0-20190510132918-efd6b22b2522/go.mod h1:ZjyILWgesfNpC6sMxTJOJm9Kp84zZh5NQWvqDGG3Qr8= @@ -901,8 +901,8 @@ golang.org/x/sync v0.0.0-20201207232520-09787c993a3a/go.mod h1:RxMgew5VJxzue5/jJ golang.org/x/sync v0.0.0-20210220032951-036812b2e83c/go.mod h1:RxMgew5VJxzue5/jJTE5uejpjVlOe/izrB70Jof72aM= golang.org/x/sync v0.0.0-20220722155255-886fb9371eb4/go.mod h1:RxMgew5VJxzue5/jJTE5uejpjVlOe/izrB70Jof72aM= golang.org/x/sync v0.1.0/go.mod h1:RxMgew5VJxzue5/jJTE5uejpjVlOe/izrB70Jof72aM= -golang.org/x/sync v0.8.0 h1:3NFvSEYkUoMifnESzZl15y791HH1qU2xm6eCJU5ZPXQ= -golang.org/x/sync v0.8.0/go.mod h1:Czt+wKu1gCyEFDUtn0jG5QVvpJ6rzVqr5aXyt9drQfk= +golang.org/x/sync v0.10.0 h1:3NQrjDixjgGwUOCaF8w2+VYHv0Ve/vGYSbdkTa98gmQ= +golang.org/x/sync v0.10.0/go.mod h1:Czt+wKu1gCyEFDUtn0jG5QVvpJ6rzVqr5aXyt9drQfk= golang.org/x/sys v0.0.0-20170830134202-bb24a47a89ea/go.mod h1:STP8DvDyc/dI5b8T5hshtkjS+E42TnysNCUPdjciGhY= golang.org/x/sys v0.0.0-20180823144017-11551d06cbcc/go.mod h1:STP8DvDyc/dI5b8T5hshtkjS+E42TnysNCUPdjciGhY= golang.org/x/sys v0.0.0-20180830151530-49385e6e1522/go.mod h1:STP8DvDyc/dI5b8T5hshtkjS+E42TnysNCUPdjciGhY= @@ -974,8 +974,8 @@ golang.org/x/sys v0.8.0/go.mod h1:oPkhp1MJrh7nUepCBck5+mAzfO9JrbApNNgaTdGDITg= golang.org/x/sys v0.11.0/go.mod h1:oPkhp1MJrh7nUepCBck5+mAzfO9JrbApNNgaTdGDITg= golang.org/x/sys v0.16.0/go.mod h1:/VUhepiaJMQUp4+oa/7Zr1D23ma6VTLIYjOOTFZPUcA= golang.org/x/sys v0.19.0/go.mod h1:/VUhepiaJMQUp4+oa/7Zr1D23ma6VTLIYjOOTFZPUcA= -golang.org/x/sys v0.26.0 h1:KHjCJyddX0LoSTb3J+vWpupP9p0oznkqVk/IfjymZbo= -golang.org/x/sys v0.26.0/go.mod h1:/VUhepiaJMQUp4+oa/7Zr1D23ma6VTLIYjOOTFZPUcA= +golang.org/x/sys v0.28.0 h1:Fksou7UEQUWlKvIdsqzJmUmCX3cZuD2+P3XyyzwMhlA= +golang.org/x/sys v0.28.0/go.mod h1:/VUhepiaJMQUp4+oa/7Zr1D23ma6VTLIYjOOTFZPUcA= golang.org/x/term v0.0.0-20201126162022-7de9c90e9dd1/go.mod h1:bj7SfCRtBDWHUb9snDiAeCFNEtKQo2Wmx5Cou7ajbmo= golang.org/x/term v0.0.0-20210927222741-03fcf44c2211/go.mod h1:jbD1KX2456YbFQfuXm/mYQcufACuNUgVhRMnK/tPxf8= golang.org/x/term v0.5.0/go.mod h1:jMB1sMXY+tzblOD4FWmEbocvup2/aLOaQEp7JmGp78k= @@ -983,8 +983,8 @@ golang.org/x/term v0.7.0/go.mod h1:P32HKFT3hSsZrRxla30E9HqToFYAQPCMs/zFMBUFqPY= golang.org/x/term v0.8.0/go.mod h1:xPskH00ivmX89bAKVGSKKtLOWNx2+17Eiy94tnKShWo= golang.org/x/term v0.11.0/go.mod h1:zC9APTIj3jG3FdV/Ons+XE1riIZXG4aZ4GTHiPZJPIU= golang.org/x/term v0.16.0/go.mod h1:yn7UURbUtPyrVJPGPq404EukNFxcm/foM+bV/bfcDsY= -golang.org/x/term v0.25.0 h1:WtHI/ltw4NvSUig5KARz9h521QvRC8RmF/cuYqifU24= -golang.org/x/term v0.25.0/go.mod h1:RPyXicDX+6vLxogjjRxjgD2TKtmAO6NZBsBRfrOLu7M= +golang.org/x/term v0.27.0 h1:WP60Sv1nlK1T6SupCHbXzSaN0b9wUmsPoRS9b61A23Q= +golang.org/x/term v0.27.0/go.mod h1:iMsnZpn0cago0GOrHO2+Y7u7JPn5AylBrcoWkElMTSM= golang.org/x/text v0.0.0-20160726164857-2910a502d2bf/go.mod h1:NqM8EUOU14njkJ3fqMW+pc6Ldnwhi/IjpwHt7yyuwOQ= golang.org/x/text v0.0.0-20170915032832-14c0d48ead0c/go.mod h1:NqM8EUOU14njkJ3fqMW+pc6Ldnwhi/IjpwHt7yyuwOQ= golang.org/x/text v0.3.0/go.mod h1:NqM8EUOU14njkJ3fqMW+pc6Ldnwhi/IjpwHt7yyuwOQ= @@ -999,8 +999,8 @@ golang.org/x/text v0.7.0/go.mod h1:mrYo+phRRbMaCq/xk9113O4dZlRixOauAjOtrjsXDZ8= golang.org/x/text v0.9.0/go.mod 
h1:e1OnstbJyHTd6l/uOt8jFFHp6TRDWZR/bV3emEE/zU8= golang.org/x/text v0.12.0/go.mod h1:TvPlkZtksWOMsz7fbANvkp4WM8x/WCo/om8BMLbz+aE= golang.org/x/text v0.14.0/go.mod h1:18ZOQIKpY8NJVqYksKHtTdi31H5itFRjB5/qKTNYzSU= -golang.org/x/text v0.19.0 h1:kTxAhCbGbxhK0IwgSKiMO5awPoDQ0RpfiVYBfK860YM= -golang.org/x/text v0.19.0/go.mod h1:BuEKDfySbSR4drPmRPG/7iBdf8hvFMuRexcpahXilzY= +golang.org/x/text v0.21.0 h1:zyQAAkrwaneQ066sspRyJaG9VNi/YJ1NfzcGB3hZ/qo= +golang.org/x/text v0.21.0/go.mod h1:4IBbMaMmOPCJ8SecivzSH54+73PCFmPWxNTLm+vZkEQ= golang.org/x/time v0.0.0-20181108054448-85acf8d2951c/go.mod h1:tRJNPiyCQ0inRvYxbN9jk5I+vvW/OXSQhTDSoE431IQ= golang.org/x/time v0.0.0-20190308202827-9d24e82272b4/go.mod h1:tRJNPiyCQ0inRvYxbN9jk5I+vvW/OXSQhTDSoE431IQ= golang.org/x/time v0.0.0-20191024005414-555d28b269f0/go.mod h1:tRJNPiyCQ0inRvYxbN9jk5I+vvW/OXSQhTDSoE431IQ= From 703647da1e59636ac4affe34261c27f4e18b8878 Mon Sep 17 00:00:00 2001 From: "VYSE V.E.O" Date: Mon, 16 Dec 2024 21:17:46 +0800 Subject: [PATCH 085/159] fix client unsupported h2 protocol when only 443 activated (#3009) When the port 80 HTTP listener is removed from the Caddyfile, the netbird client cannot connect to the server on port 443. The logs show the error below: {"level":"debug","ts":1733809631.4012625,"logger":"http.stdlib","msg":"http: TLS handshake error from redacted:41580: tls: client requested unsupported application protocols ([h2])"} It looks like the h2 protocol is missing from the advertised protocols here. --- infrastructure_files/getting-started-with-zitadel.sh | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/infrastructure_files/getting-started-with-zitadel.sh b/infrastructure_files/getting-started-with-zitadel.sh index 7793d1fda..9b80058c2 100644 --- a/infrastructure_files/getting-started-with-zitadel.sh +++ b/infrastructure_files/getting-started-with-zitadel.sh @@ -530,7 +530,7 @@ renderCaddyfile() { { debug servers :80,:443 { - protocols h1 h2c h3 + protocols h1 h2c h2 h3 } } From 37ad370344aec96c01d461db53473c78ed7f7240 Mon Sep 17 00:00:00 2001 From: Maycon Santos Date: Mon, 16 Dec 2024 18:09:31 +0100 Subject: [PATCH 086/159] [client] Avoid using iota on mixed const block (#3057) Use the explicit values that iota previously resolved to, since the first iota-assigned constant was the second entry in the const block.
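
For context on the pitfall: iota numbers the constant specs from the top of the block, so when a const block opens with a plain (non-iota) constant, the first iota-assigned constant resolves to 1 rather than 0. Below is a minimal, self-contained sketch of that behavior, simplified from the block changed in this patch (the bare int type and the main function are illustrative only):

    package main

    import "fmt"

    const (
        defaultModuleDir     = "/lib/modules" // spec 0: plain string constant
        unknown          int = iota           // spec 1: iota is already 1 here
        unloaded                              // repeats "int = iota" -> 2
        unloading                             // 3
    )

    func main() {
        fmt.Println(unknown, unloaded, unloading) // prints: 1 2 3
    }

Spelling the status values out explicitly, as this patch does, keeps them stable even if constants are later added, removed, or reordered within the block.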
--- client/iface/device/kernel_module_linux.go | 16 ++++++++-------- 1 file changed, 8 insertions(+), 8 deletions(-) diff --git a/client/iface/device/kernel_module_linux.go b/client/iface/device/kernel_module_linux.go index 0d195779d..b28ddd36c 100644 --- a/client/iface/device/kernel_module_linux.go +++ b/client/iface/device/kernel_module_linux.go @@ -27,14 +27,14 @@ import ( type status int const ( - defaultModuleDir = "/lib/modules" - unknown status = iota - unloaded - unloading - loading - live - inuse - envDisableWireGuardKernel = "NB_WG_KERNEL_DISABLED" + unknown status = 1 + unloaded status = 2 + unloading status = 3 + loading status = 4 + live status = 5 + inuse status = 6 + defaultModuleDir = "/lib/modules" + envDisableWireGuardKernel = "NB_WG_KERNEL_DISABLED" ) type module struct { From ddc365f7a0d52ab1d10d44d76f70176040e2fe64 Mon Sep 17 00:00:00 2001 From: Viktor Liu <17948409+lixmal@users.noreply.github.com> Date: Fri, 20 Dec 2024 11:30:28 +0100 Subject: [PATCH 087/159] [client, management] Add new network concept (#3047) --------- Co-authored-by: Pascal Fischer <32096965+pascal-fischer@users.noreply.github.com> Co-authored-by: bcmmbaga Co-authored-by: Maycon Santos Co-authored-by: Zoltan Papp --- .github/workflows/golang-test-linux.yml | 2 +- .github/workflows/golangci-lint.yml | 4 +- client/anonymize/anonymize.go | 15 +- client/anonymize/anonymize_test.go | 56 +- client/cmd/networks.go | 173 ++ client/cmd/root.go | 6 +- client/cmd/route.go | 174 -- client/cmd/status.go | 44 +- client/cmd/status_test.go | 28 +- client/cmd/testutil_test.go | 6 +- client/internal/dns/handler_chain.go | 222 +++ client/internal/dns/handler_chain_test.go | 511 ++++++ client/internal/dns/local.go | 14 +- client/internal/dns/mock_server.go | 22 +- client/internal/dns/server.go | 104 +- client/internal/dns/server_test.go | 91 +- client/internal/dns/service_listener.go | 1 + client/internal/dns/upstream.go | 9 + client/internal/dnsfwd/forwarder.go | 157 ++ client/internal/dnsfwd/manager.go | 106 ++ client/internal/engine.go | 155 +- client/internal/engine_test.go | 27 +- client/internal/peer/conn.go | 5 + client/internal/peer/status.go | 28 +- client/internal/peerstore/store.go | 87 + client/internal/routemanager/client.go | 81 +- .../routemanager/dnsinterceptor/handler.go | 356 ++++ client/internal/routemanager/dynamic/route.go | 8 +- client/internal/routemanager/manager.go | 113 +- client/internal/routemanager/manager_test.go | 6 +- client/internal/routemanager/mock.go | 31 +- client/ios/NetBirdSDK/client.go | 17 +- client/proto/daemon.pb.go | 622 +++---- client/proto/daemon.proto | 28 +- client/proto/daemon_grpc.pb.go | 88 +- client/server/{route.go => network.go} | 67 +- client/server/server.go | 4 +- client/server/server_test.go | 6 +- client/ui/client_ui.go | 8 +- client/ui/{route.go => network.go} | 164 +- dns/dns.go | 6 + go.mod | 3 +- go.sum | 5 +- management/client/client_test.go | 6 +- management/cmd/management.go | 24 +- management/cmd/migration_up.go | 4 +- management/proto/management.pb.go | 527 +++--- management/proto/management.proto | 9 + management/server/account.go | 1105 ++---------- management/server/account_request_buffer.go | 11 +- management/server/account_test.go | 345 ++-- management/server/activity/codes.go | 36 + management/server/config.go | 3 +- management/server/dns.go | 141 +- management/server/dns_test.go | 39 +- management/server/ephemeral.go | 10 +- management/server/ephemeral_test.go | 12 +- management/server/group.go | 237 ++- management/server/group_test.go | 76 +- 
management/server/groups/manager.go | 196 +++ management/server/grpcserver.go | 33 +- management/server/http/api/openapi.yml | 697 +++++++- management/server/http/api/types.gen.go | 194 ++- management/server/http/handler.go | 8 +- .../handlers/accounts/accounts_handler.go | 9 +- .../accounts/accounts_handler_test.go | 78 +- .../http/handlers/dns/dns_settings_handler.go | 3 +- .../handlers/dns/dns_settings_handler_test.go | 14 +- .../handlers/events/events_handler_test.go | 8 +- .../http/handlers/groups/groups_handler.go | 49 +- .../handlers/groups/groups_handler_test.go | 24 +- .../server/http/handlers/networks/handler.go | 321 ++++ .../handlers/networks/resources_handler.go | 222 +++ .../http/handlers/networks/routers_handler.go | 165 ++ .../http/handlers/peers/peers_handler.go | 47 +- .../http/handlers/peers/peers_handler_test.go | 27 +- .../policies/geolocation_handler_test.go | 6 +- .../handlers/policies/policies_handler.go | 101 +- .../policies/policies_handler_test.go | 47 +- .../http/handlers/routes/routes_handler.go | 2 +- .../handlers/routes/routes_handler_test.go | 40 +- .../handlers/setup_keys/setupkeys_handler.go | 13 +- .../setup_keys/setupkeys_handler_test.go | 26 +- .../server/http/handlers/users/pat_handler.go | 5 +- .../http/handlers/users/pat_handler_test.go | 24 +- .../http/handlers/users/users_handler.go | 15 +- .../http/handlers/users/users_handler_test.go | 32 +- .../server/http/middleware/access_control.go | 4 +- .../server/http/middleware/auth_middleware.go | 4 +- .../http/middleware/auth_middleware_test.go | 10 +- management/server/integrated_validator.go | 8 +- .../server/integrated_validator/interface.go | 4 +- management/server/management_proto_test.go | 11 +- management/server/management_test.go | 10 +- management/server/metrics/selfhosted.go | 7 +- management/server/metrics/selfhosted_test.go | 58 +- management/server/migration/migration_test.go | 66 +- management/server/mock_server/account_mock.go | 162 +- management/server/nameserver.go | 63 +- management/server/nameserver_test.go | 15 +- management/server/networks/manager.go | 187 +++ management/server/networks/manager_test.go | 254 +++ .../server/networks/resources/manager.go | 383 +++++ .../server/networks/resources/manager_test.go | 411 +++++ .../networks/resources/types/resource.go | 169 ++ .../networks/resources/types/resource_test.go | 53 + management/server/networks/routers/manager.go | 289 ++++ .../server/networks/routers/manager_test.go | 234 +++ .../server/networks/routers/types/router.go | 75 + .../networks/routers/types/router_test.go | 100 ++ management/server/networks/types/network.go | 56 + management/server/peer.go | 104 +- management/server/peer_test.go | 243 ++- management/server/permissions/manager.go | 102 ++ management/server/policy.go | 467 +----- management/server/policy_test.go | 192 +-- management/server/posture_checks.go | 54 +- management/server/posture_checks_test.go | 51 +- management/server/resource.go | 21 + management/server/route.go | 304 +--- management/server/route_test.go | 620 ++++++- management/server/settings/manager.go | 37 + management/server/setupkey.go | 241 +-- management/server/setupkey_test.go | 60 +- management/server/status/error.go | 32 + management/server/{ => store}/file_store.go | 32 +- management/server/{ => store}/sql_store.go | 461 ++++-- .../server/{ => store}/sql_store_test.go | 696 ++++++-- management/server/{ => store}/store.go | 119 +- management/server/{ => store}/store_test.go | 10 +- management/server/testdata/networks.sql | 18 + 
management/server/testdata/store.sql | 14 + management/server/types/account.go | 1475 +++++++++++++++++ management/server/types/account_test.go | 375 +++++ management/server/types/dns_settings.go | 16 + management/server/types/firewall_rule.go | 130 ++ management/server/{group => types}/group.go | 44 +- .../server/{group => types}/group_test.go | 2 +- management/server/{ => types}/network.go | 12 +- management/server/{ => types}/network_test.go | 2 +- .../{ => types}/personal_access_token.go | 2 +- .../{ => types}/personal_access_token_test.go | 2 +- management/server/types/policy.go | 125 ++ management/server/types/policyrule.go | 87 + management/server/types/resource.go | 30 + .../server/types/route_firewall_rule.go | 32 + management/server/types/settings.go | 68 + management/server/types/setupkey.go | 181 ++ management/server/types/user.go | 231 +++ management/server/updatechannel.go | 3 +- management/server/user.go | 356 +--- management/server/user_test.go | 444 +++-- management/server/users/manager.go | 49 + management/server/util/util.go | 16 + route/route.go | 58 +- 155 files changed, 13909 insertions(+), 4993 deletions(-) create mode 100644 client/cmd/networks.go delete mode 100644 client/cmd/route.go create mode 100644 client/internal/dns/handler_chain.go create mode 100644 client/internal/dns/handler_chain_test.go create mode 100644 client/internal/dnsfwd/forwarder.go create mode 100644 client/internal/dnsfwd/manager.go create mode 100644 client/internal/peerstore/store.go create mode 100644 client/internal/routemanager/dnsinterceptor/handler.go rename client/server/{route.go => network.go} (58%) rename client/ui/{route.go => network.go} (56%) create mode 100644 management/server/groups/manager.go create mode 100644 management/server/http/handlers/networks/handler.go create mode 100644 management/server/http/handlers/networks/resources_handler.go create mode 100644 management/server/http/handlers/networks/routers_handler.go create mode 100644 management/server/networks/manager.go create mode 100644 management/server/networks/manager_test.go create mode 100644 management/server/networks/resources/manager.go create mode 100644 management/server/networks/resources/manager_test.go create mode 100644 management/server/networks/resources/types/resource.go create mode 100644 management/server/networks/resources/types/resource_test.go create mode 100644 management/server/networks/routers/manager.go create mode 100644 management/server/networks/routers/manager_test.go create mode 100644 management/server/networks/routers/types/router.go create mode 100644 management/server/networks/routers/types/router_test.go create mode 100644 management/server/networks/types/network.go create mode 100644 management/server/permissions/manager.go create mode 100644 management/server/resource.go create mode 100644 management/server/settings/manager.go rename management/server/{ => store}/file_store.go (89%) rename management/server/{ => store}/sql_store.go (76%) rename management/server/{ => store}/sql_store_test.go (75%) rename management/server/{ => store}/store.go (73%) rename management/server/{ => store}/store_test.go (93%) create mode 100644 management/server/testdata/networks.sql create mode 100644 management/server/types/account.go create mode 100644 management/server/types/account_test.go create mode 100644 management/server/types/dns_settings.go create mode 100644 management/server/types/firewall_rule.go rename management/server/{group => types}/group.go (58%) rename management/server/{group => 
types}/group_test.go (99%) rename management/server/{ => types}/network.go (96%) rename management/server/{ => types}/network_test.go (98%) rename management/server/{ => types}/personal_access_token.go (99%) rename management/server/{ => types}/personal_access_token_test.go (98%) create mode 100644 management/server/types/policy.go create mode 100644 management/server/types/policyrule.go create mode 100644 management/server/types/resource.go create mode 100644 management/server/types/route_firewall_rule.go create mode 100644 management/server/types/settings.go create mode 100644 management/server/types/setupkey.go create mode 100644 management/server/types/user.go create mode 100644 management/server/users/manager.go create mode 100644 management/server/util/util.go diff --git a/.github/workflows/golang-test-linux.yml b/.github/workflows/golang-test-linux.yml index 85658d237..50c6cba84 100644 --- a/.github/workflows/golang-test-linux.yml +++ b/.github/workflows/golang-test-linux.yml @@ -232,7 +232,7 @@ jobs: run: git --no-pager diff --exit-code - name: Test - run: CGO_ENABLED=1 GOARCH=${{ matrix.arch }} NETBIRD_STORE_ENGINE=${{ matrix.store }} CI=true go test -run=^$ -bench=. -exec 'sudo --preserve-env=CI,NETBIRD_STORE_ENGINE' -timeout 10m ./... + run: CGO_ENABLED=1 GOARCH=${{ matrix.arch }} NETBIRD_STORE_ENGINE=${{ matrix.store }} CI=true go test -run=^$ -bench=. -exec 'sudo --preserve-env=CI,NETBIRD_STORE_ENGINE' -timeout 10m -p 1 ./... test_client_on_docker: needs: [ build-cache ] diff --git a/.github/workflows/golangci-lint.yml b/.github/workflows/golangci-lint.yml index dacb1922b..89defce32 100644 --- a/.github/workflows/golangci-lint.yml +++ b/.github/workflows/golangci-lint.yml @@ -46,7 +46,7 @@ jobs: if: matrix.os == 'ubuntu-latest' run: sudo apt update && sudo apt install -y -q libgtk-3-dev libayatana-appindicator3-dev libgl1-mesa-dev xorg-dev libpcap-dev - name: golangci-lint - uses: golangci/golangci-lint-action@v3 + uses: golangci/golangci-lint-action@v4 with: version: latest - args: --timeout=12m + args: --timeout=12m --out-format colored-line-number diff --git a/client/anonymize/anonymize.go b/client/anonymize/anonymize.go index 9a6d97207..89552724a 100644 --- a/client/anonymize/anonymize.go +++ b/client/anonymize/anonymize.go @@ -21,6 +21,8 @@ type Anonymizer struct { currentAnonIPv6 netip.Addr startAnonIPv4 netip.Addr startAnonIPv6 netip.Addr + + domainKeyRegex *regexp.Regexp } func DefaultAddresses() (netip.Addr, netip.Addr) { @@ -36,6 +38,8 @@ func NewAnonymizer(startIPv4, startIPv6 netip.Addr) *Anonymizer { currentAnonIPv6: startIPv6, startAnonIPv4: startIPv4, startAnonIPv6: startIPv6, + + domainKeyRegex: regexp.MustCompile(`\bdomain=([^\s,:"]+)`), } } @@ -171,20 +175,15 @@ func (a *Anonymizer) AnonymizeSchemeURI(text string) string { return re.ReplaceAllStringFunc(text, a.AnonymizeURI) } -// AnonymizeDNSLogLine anonymizes domain names in DNS log entries by replacing them with a random string. 
func (a *Anonymizer) AnonymizeDNSLogLine(logEntry string) string { - domainPattern := `dns\.Question{Name:"([^"]+)",` - domainRegex := regexp.MustCompile(domainPattern) - - return domainRegex.ReplaceAllStringFunc(logEntry, func(match string) string { - parts := strings.Split(match, `"`) + return a.domainKeyRegex.ReplaceAllStringFunc(logEntry, func(match string) string { + parts := strings.SplitN(match, "=", 2) if len(parts) >= 2 { domain := parts[1] if strings.HasSuffix(domain, anonTLD) { return match } - randomDomain := generateRandomString(10) + anonTLD - return strings.Replace(match, domain, randomDomain, 1) + return "domain=" + a.AnonymizeDomain(domain) } return match }) diff --git a/client/anonymize/anonymize_test.go b/client/anonymize/anonymize_test.go index a3aae1ee9..ff2e48869 100644 --- a/client/anonymize/anonymize_test.go +++ b/client/anonymize/anonymize_test.go @@ -46,11 +46,59 @@ func TestAnonymizeIP(t *testing.T) { func TestAnonymizeDNSLogLine(t *testing.T) { anonymizer := anonymize.NewAnonymizer(netip.Addr{}, netip.Addr{}) - testLog := `2024-04-23T20:01:11+02:00 TRAC client/internal/dns/local.go:25: received question: dns.Question{Name:"example.com", Qtype:0x1c, Qclass:0x1}` + tests := []struct { + name string + input string + original string + expect string + }{ + { + name: "Basic domain with trailing content", + input: "received DNS request for DNS forwarder: domain=example.com: something happened with code=123", + original: "example.com", + expect: `received DNS request for DNS forwarder: domain=anon-[a-zA-Z0-9]+\.domain: something happened with code=123`, + }, + { + name: "Domain with trailing dot", + input: "domain=example.com. processing request with status=pending", + original: "example.com", + expect: `domain=anon-[a-zA-Z0-9]+\.domain\. processing request with status=pending`, + }, + { + name: "Multiple domains in log", + input: "forward domain=first.com status=ok, redirect to domain=second.com port=443", + original: "first.com", // testing just one is sufficient as AnonymizeDomain is tested separately + expect: `forward domain=anon-[a-zA-Z0-9]+\.domain status=ok, redirect to domain=anon-[a-zA-Z0-9]+\.domain port=443`, + }, + { + name: "Already anonymized domain", + input: "got request domain=anon-xyz123.domain from=client1 to=server2", + original: "", // nothing should be anonymized + expect: `got request domain=anon-xyz123\.domain from=client1 to=server2`, + }, + { + name: "Subdomain with trailing dot", + input: "domain=sub.example.com. next_hop=10.0.0.1 proto=udp", + original: "example.com", + expect: `domain=sub\.anon-[a-zA-Z0-9]+\.domain\. next_hop=10\.0\.0\.1 proto=udp`, + }, + { + name: "Handler chain pattern log", + input: "pattern: domain=example.com. original: domain=*.example.com. wildcard=true priority=100", + original: "example.com", + expect: `pattern: domain=anon-[a-zA-Z0-9]+\.domain\. original: domain=\*\.anon-[a-zA-Z0-9]+\.domain\. 
wildcard=true priority=100`, + }, + } - result := anonymizer.AnonymizeDNSLogLine(testLog) - require.NotEqual(t, testLog, result) - assert.NotContains(t, result, "example.com") + for _, tc := range tests { + t.Run(tc.name, func(t *testing.T) { + result := anonymizer.AnonymizeDNSLogLine(tc.input) + if tc.original != "" { + assert.NotContains(t, result, tc.original) + } + assert.Regexp(t, tc.expect, result) + }) + } } func TestAnonymizeDomain(t *testing.T) { diff --git a/client/cmd/networks.go b/client/cmd/networks.go new file mode 100644 index 000000000..7b9724bc5 --- /dev/null +++ b/client/cmd/networks.go @@ -0,0 +1,173 @@ +package cmd + +import ( + "fmt" + "strings" + + "github.com/spf13/cobra" + "google.golang.org/grpc/status" + + "github.com/netbirdio/netbird/client/proto" +) + +var appendFlag bool + +var networksCMD = &cobra.Command{ + Use: "networks", + Aliases: []string{"routes"}, + Short: "Manage networks", + Long: `Commands to list, select, or deselect networks. Replaces the "routes" command.`, +} + +var routesListCmd = &cobra.Command{ + Use: "list", + Aliases: []string{"ls"}, + Short: "List networks", + Example: " netbird networks list", + Long: "List all available network routes.", + RunE: networksList, +} + +var routesSelectCmd = &cobra.Command{ + Use: "select network...|all", + Short: "Select network", + Long: "Select a list of networks by identifiers or 'all' to clear all selections and to accept all (including new) networks.\nDefault mode is replace, use -a to append to already selected networks.", + Example: " netbird networks select all\n netbird networks select route1 route2\n netbird routes select -a route3", + Args: cobra.MinimumNArgs(1), + RunE: networksSelect, +} + +var routesDeselectCmd = &cobra.Command{ + Use: "deselect network...|all", + Short: "Deselect networks", + Long: "Deselect previously selected networks by identifiers or 'all' to disable accepting any networks.", + Example: " netbird networks deselect all\n netbird networks deselect route1 route2", + Args: cobra.MinimumNArgs(1), + RunE: networksDeselect, +} + +func init() { + routesSelectCmd.PersistentFlags().BoolVarP(&appendFlag, "append", "a", false, "Append to current network selection instead of replacing") +} + +func networksList(cmd *cobra.Command, _ []string) error { + conn, err := getClient(cmd) + if err != nil { + return err + } + defer conn.Close() + + client := proto.NewDaemonServiceClient(conn) + resp, err := client.ListNetworks(cmd.Context(), &proto.ListNetworksRequest{}) + if err != nil { + return fmt.Errorf("failed to list network: %v", status.Convert(err).Message()) + } + + if len(resp.Routes) == 0 { + cmd.Println("No networks available.") + return nil + } + + printNetworks(cmd, resp) + + return nil +} + +func printNetworks(cmd *cobra.Command, resp *proto.ListNetworksResponse) { + cmd.Println("Available Networks:") + for _, route := range resp.Routes { + printNetwork(cmd, route) + } +} + +func printNetwork(cmd *cobra.Command, route *proto.Network) { + selectedStatus := getSelectedStatus(route) + domains := route.GetDomains() + + if len(domains) > 0 { + printDomainRoute(cmd, route, domains, selectedStatus) + } else { + printNetworkRoute(cmd, route, selectedStatus) + } +} + +func getSelectedStatus(route *proto.Network) string { + if route.GetSelected() { + return "Selected" + } + return "Not Selected" +} + +func printDomainRoute(cmd *cobra.Command, route *proto.Network, domains []string, selectedStatus string) { + cmd.Printf("\n - ID: %s\n Domains: %s\n Status: %s\n", route.GetID(), 
strings.Join(domains, ", "), selectedStatus) + resolvedIPs := route.GetResolvedIPs() + + if len(resolvedIPs) > 0 { + printResolvedIPs(cmd, domains, resolvedIPs) + } else { + cmd.Printf(" Resolved IPs: -\n") + } +} + +func printNetworkRoute(cmd *cobra.Command, route *proto.Network, selectedStatus string) { + cmd.Printf("\n - ID: %s\n Network: %s\n Status: %s\n", route.GetID(), route.GetRange(), selectedStatus) +} + +func printResolvedIPs(cmd *cobra.Command, _ []string, resolvedIPs map[string]*proto.IPList) { + cmd.Printf(" Resolved IPs:\n") + for resolvedDomain, ipList := range resolvedIPs { + cmd.Printf(" [%s]: %s\n", resolvedDomain, strings.Join(ipList.GetIps(), ", ")) + } +} + +func networksSelect(cmd *cobra.Command, args []string) error { + conn, err := getClient(cmd) + if err != nil { + return err + } + defer conn.Close() + + client := proto.NewDaemonServiceClient(conn) + req := &proto.SelectNetworksRequest{ + NetworkIDs: args, + } + + if len(args) == 1 && args[0] == "all" { + req.All = true + } else if appendFlag { + req.Append = true + } + + if _, err := client.SelectNetworks(cmd.Context(), req); err != nil { + return fmt.Errorf("failed to select networks: %v", status.Convert(err).Message()) + } + + cmd.Println("Networks selected successfully.") + + return nil +} + +func networksDeselect(cmd *cobra.Command, args []string) error { + conn, err := getClient(cmd) + if err != nil { + return err + } + defer conn.Close() + + client := proto.NewDaemonServiceClient(conn) + req := &proto.SelectNetworksRequest{ + NetworkIDs: args, + } + + if len(args) == 1 && args[0] == "all" { + req.All = true + } + + if _, err := client.DeselectNetworks(cmd.Context(), req); err != nil { + return fmt.Errorf("failed to deselect networks: %v", status.Convert(err).Message()) + } + + cmd.Println("Networks deselected successfully.") + + return nil +} diff --git a/client/cmd/root.go b/client/cmd/root.go index 3f2d04ef3..0305bacc8 100644 --- a/client/cmd/root.go +++ b/client/cmd/root.go @@ -142,14 +142,14 @@ func init() { rootCmd.AddCommand(loginCmd) rootCmd.AddCommand(versionCmd) rootCmd.AddCommand(sshCmd) - rootCmd.AddCommand(routesCmd) + rootCmd.AddCommand(networksCMD) rootCmd.AddCommand(debugCmd) serviceCmd.AddCommand(runCmd, startCmd, stopCmd, restartCmd) // service control commands are subcommands of service serviceCmd.AddCommand(installCmd, uninstallCmd) // service installer commands are subcommands of service - routesCmd.AddCommand(routesListCmd) - routesCmd.AddCommand(routesSelectCmd, routesDeselectCmd) + networksCMD.AddCommand(routesListCmd) + networksCMD.AddCommand(routesSelectCmd, routesDeselectCmd) debugCmd.AddCommand(debugBundleCmd) debugCmd.AddCommand(logCmd) diff --git a/client/cmd/route.go b/client/cmd/route.go deleted file mode 100644 index c8881822b..000000000 --- a/client/cmd/route.go +++ /dev/null @@ -1,174 +0,0 @@ -package cmd - -import ( - "fmt" - "strings" - - "github.com/spf13/cobra" - "google.golang.org/grpc/status" - - "github.com/netbirdio/netbird/client/proto" -) - -var appendFlag bool - -var routesCmd = &cobra.Command{ - Use: "routes", - Short: "Manage network routes", - Long: `Commands to list, select, or deselect network routes.`, -} - -var routesListCmd = &cobra.Command{ - Use: "list", - Aliases: []string{"ls"}, - Short: "List routes", - Example: " netbird routes list", - Long: "List all available network routes.", - RunE: routesList, -} - -var routesSelectCmd = &cobra.Command{ - Use: "select route...|all", - Short: "Select routes", - Long: "Select a list of routes by identifiers or 
'all' to clear all selections and to accept all (including new) routes.\nDefault mode is replace, use -a to append to already selected routes.", - Example: " netbird routes select all\n netbird routes select route1 route2\n netbird routes select -a route3", - Args: cobra.MinimumNArgs(1), - RunE: routesSelect, -} - -var routesDeselectCmd = &cobra.Command{ - Use: "deselect route...|all", - Short: "Deselect routes", - Long: "Deselect previously selected routes by identifiers or 'all' to disable accepting any routes.", - Example: " netbird routes deselect all\n netbird routes deselect route1 route2", - Args: cobra.MinimumNArgs(1), - RunE: routesDeselect, -} - -func init() { - routesSelectCmd.PersistentFlags().BoolVarP(&appendFlag, "append", "a", false, "Append to current route selection instead of replacing") -} - -func routesList(cmd *cobra.Command, _ []string) error { - conn, err := getClient(cmd) - if err != nil { - return err - } - defer conn.Close() - - client := proto.NewDaemonServiceClient(conn) - resp, err := client.ListRoutes(cmd.Context(), &proto.ListRoutesRequest{}) - if err != nil { - return fmt.Errorf("failed to list routes: %v", status.Convert(err).Message()) - } - - if len(resp.Routes) == 0 { - cmd.Println("No routes available.") - return nil - } - - printRoutes(cmd, resp) - - return nil -} - -func printRoutes(cmd *cobra.Command, resp *proto.ListRoutesResponse) { - cmd.Println("Available Routes:") - for _, route := range resp.Routes { - printRoute(cmd, route) - } -} - -func printRoute(cmd *cobra.Command, route *proto.Route) { - selectedStatus := getSelectedStatus(route) - domains := route.GetDomains() - - if len(domains) > 0 { - printDomainRoute(cmd, route, domains, selectedStatus) - } else { - printNetworkRoute(cmd, route, selectedStatus) - } -} - -func getSelectedStatus(route *proto.Route) string { - if route.GetSelected() { - return "Selected" - } - return "Not Selected" -} - -func printDomainRoute(cmd *cobra.Command, route *proto.Route, domains []string, selectedStatus string) { - cmd.Printf("\n - ID: %s\n Domains: %s\n Status: %s\n", route.GetID(), strings.Join(domains, ", "), selectedStatus) - resolvedIPs := route.GetResolvedIPs() - - if len(resolvedIPs) > 0 { - printResolvedIPs(cmd, domains, resolvedIPs) - } else { - cmd.Printf(" Resolved IPs: -\n") - } -} - -func printNetworkRoute(cmd *cobra.Command, route *proto.Route, selectedStatus string) { - cmd.Printf("\n - ID: %s\n Network: %s\n Status: %s\n", route.GetID(), route.GetNetwork(), selectedStatus) -} - -func printResolvedIPs(cmd *cobra.Command, domains []string, resolvedIPs map[string]*proto.IPList) { - cmd.Printf(" Resolved IPs:\n") - for _, domain := range domains { - if ipList, exists := resolvedIPs[domain]; exists { - cmd.Printf(" [%s]: %s\n", domain, strings.Join(ipList.GetIps(), ", ")) - } - } -} - -func routesSelect(cmd *cobra.Command, args []string) error { - conn, err := getClient(cmd) - if err != nil { - return err - } - defer conn.Close() - - client := proto.NewDaemonServiceClient(conn) - req := &proto.SelectRoutesRequest{ - RouteIDs: args, - } - - if len(args) == 1 && args[0] == "all" { - req.All = true - } else if appendFlag { - req.Append = true - } - - if _, err := client.SelectRoutes(cmd.Context(), req); err != nil { - return fmt.Errorf("failed to select routes: %v", status.Convert(err).Message()) - } - - cmd.Println("Routes selected successfully.") - - return nil -} - -func routesDeselect(cmd *cobra.Command, args []string) error { - conn, err := getClient(cmd) - if err != nil { - return err - } - 
defer conn.Close() - - client := proto.NewDaemonServiceClient(conn) - req := &proto.SelectRoutesRequest{ - RouteIDs: args, - } - - if len(args) == 1 && args[0] == "all" { - req.All = true - } - - if _, err := client.DeselectRoutes(cmd.Context(), req); err != nil { - return fmt.Errorf("failed to deselect routes: %v", status.Convert(err).Message()) - } - - cmd.Println("Routes deselected successfully.") - - return nil -} diff --git a/client/cmd/status.go b/client/cmd/status.go index 6db52a677..fa4bff77b 100644 --- a/client/cmd/status.go +++ b/client/cmd/status.go @@ -40,6 +40,7 @@ type peerStateDetailOutput struct { Latency time.Duration `json:"latency" yaml:"latency"` RosenpassEnabled bool `json:"quantumResistance" yaml:"quantumResistance"` Routes []string `json:"routes" yaml:"routes"` + Networks []string `json:"networks" yaml:"networks"` } type peersStateOutput struct { @@ -98,6 +99,7 @@ type statusOutputOverview struct { RosenpassEnabled bool `json:"quantumResistance" yaml:"quantumResistance"` RosenpassPermissive bool `json:"quantumResistancePermissive" yaml:"quantumResistancePermissive"` Routes []string `json:"routes" yaml:"routes"` + Networks []string `json:"networks" yaml:"networks"` NSServerGroups []nsServerGroupStateOutput `json:"dnsServers" yaml:"dnsServers"` } @@ -282,7 +284,8 @@ func convertToStatusOutputOverview(resp *proto.StatusResponse) statusOutputOverv FQDN: pbFullStatus.GetLocalPeerState().GetFqdn(), RosenpassEnabled: pbFullStatus.GetLocalPeerState().GetRosenpassEnabled(), RosenpassPermissive: pbFullStatus.GetLocalPeerState().GetRosenpassPermissive(), - Routes: pbFullStatus.GetLocalPeerState().GetRoutes(), + Routes: pbFullStatus.GetLocalPeerState().GetNetworks(), + Networks: pbFullStatus.GetLocalPeerState().GetNetworks(), NSServerGroups: mapNSGroups(pbFullStatus.GetDnsServers()), } @@ -390,7 +393,8 @@ func mapPeers(peers []*proto.PeerState) peersStateOutput { TransferSent: transferSent, Latency: pbPeerState.GetLatency().AsDuration(), RosenpassEnabled: pbPeerState.GetRosenpassEnabled(), - Routes: pbPeerState.GetRoutes(), + Routes: pbPeerState.GetNetworks(), + Networks: pbPeerState.GetNetworks(), } peersStateDetail = append(peersStateDetail, peerState) @@ -491,10 +495,10 @@ func parseGeneralSummary(overview statusOutputOverview, showURL bool, showRelays relaysString = fmt.Sprintf("%d/%d Available", overview.Relays.Available, overview.Relays.Total) } - routes := "-" - if len(overview.Routes) > 0 { - sort.Strings(overview.Routes) - routes = strings.Join(overview.Routes, ", ") + networks := "-" + if len(overview.Networks) > 0 { + sort.Strings(overview.Networks) + networks = strings.Join(overview.Networks, ", ") } var dnsServersString string @@ -556,6 +560,7 @@ func parseGeneralSummary(overview statusOutputOverview, showURL bool, showRelays "Interface type: %s\n"+ "Quantum resistance: %s\n"+ "Routes: %s\n"+ + "Networks: %s\n"+ "Peers count: %s\n", fmt.Sprintf("%s/%s%s", goos, goarch, goarm), overview.DaemonVersion, @@ -568,7 +573,8 @@ func parseGeneralSummary(overview statusOutputOverview, showURL bool, showRelays interfaceIP, interfaceTypeString, rosenpassEnabledStatus, - routes, + networks, + networks, peersCountString, ) return summary @@ -631,10 +637,10 @@ func parsePeers(peers peersStateOutput, rosenpassEnabled, rosenpassPermissive bo } } - routes := "-" - if len(peerState.Routes) > 0 { - sort.Strings(peerState.Routes) - routes = strings.Join(peerState.Routes, ", ") + networks := "-" + if len(peerState.Networks) > 0 { + sort.Strings(peerState.Networks) + networks = 
strings.Join(peerState.Networks, ", ") } peerString := fmt.Sprintf( @@ -652,6 +658,7 @@ func parsePeers(peers peersStateOutput, rosenpassEnabled, rosenpassPermissive bo " Transfer status (received/sent) %s/%s\n"+ " Quantum resistance: %s\n"+ " Routes: %s\n"+ + " Networks: %s\n"+ " Latency: %s\n", peerState.FQDN, peerState.IP, @@ -668,7 +675,8 @@ func parsePeers(peers peersStateOutput, rosenpassEnabled, rosenpassPermissive bo toIEC(peerState.TransferReceived), toIEC(peerState.TransferSent), rosenpassEnabledStatus, - routes, + networks, + networks, peerState.Latency.String(), ) @@ -810,6 +818,14 @@ func anonymizePeerDetail(a *anonymize.Anonymizer, peer *peerStateDetailOutput) { peer.RelayAddress = a.AnonymizeURI(peer.RelayAddress) + for i, route := range peer.Networks { + peer.Networks[i] = a.AnonymizeIPString(route) + } + + for i, route := range peer.Networks { + peer.Networks[i] = a.AnonymizeRoute(route) + } + for i, route := range peer.Routes { peer.Routes[i] = a.AnonymizeIPString(route) } @@ -850,6 +866,10 @@ func anonymizeOverview(a *anonymize.Anonymizer, overview *statusOutputOverview) } } + for i, route := range overview.Networks { + overview.Networks[i] = a.AnonymizeRoute(route) + } + for i, route := range overview.Routes { overview.Routes[i] = a.AnonymizeRoute(route) } diff --git a/client/cmd/status_test.go b/client/cmd/status_test.go index ca43df8a5..1f1e95726 100644 --- a/client/cmd/status_test.go +++ b/client/cmd/status_test.go @@ -44,7 +44,7 @@ var resp = &proto.StatusResponse{ LastWireguardHandshake: timestamppb.New(time.Date(2001, time.Month(1), 1, 1, 1, 2, 0, time.UTC)), BytesRx: 200, BytesTx: 100, - Routes: []string{ + Networks: []string{ "10.1.0.0/24", }, Latency: durationpb.New(time.Duration(10000000)), @@ -93,7 +93,7 @@ var resp = &proto.StatusResponse{ PubKey: "Some-Pub-Key", KernelInterface: true, Fqdn: "some-localhost.awesome-domain.com", - Routes: []string{ + Networks: []string{ "10.10.0.0/24", }, }, @@ -149,6 +149,9 @@ var overview = statusOutputOverview{ Routes: []string{ "10.1.0.0/24", }, + Networks: []string{ + "10.1.0.0/24", + }, Latency: time.Duration(10000000), }, { @@ -230,6 +233,9 @@ var overview = statusOutputOverview{ Routes: []string{ "10.10.0.0/24", }, + Networks: []string{ + "10.10.0.0/24", + }, } func TestConversionFromFullStatusToOutputOverview(t *testing.T) { @@ -295,6 +301,9 @@ func TestParsingToJSON(t *testing.T) { "quantumResistance": false, "routes": [ "10.1.0.0/24" + ], + "networks": [ + "10.1.0.0/24" ] }, { @@ -318,7 +327,8 @@ func TestParsingToJSON(t *testing.T) { "transferSent": 1000, "latency": 10000000, "quantumResistance": false, - "routes": null + "routes": null, + "networks": null } ] }, @@ -359,6 +369,9 @@ func TestParsingToJSON(t *testing.T) { "routes": [ "10.10.0.0/24" ], + "networks": [ + "10.10.0.0/24" + ], "dnsServers": [ { "servers": [ @@ -418,6 +431,8 @@ func TestParsingToYAML(t *testing.T) { quantumResistance: false routes: - 10.1.0.0/24 + networks: + - 10.1.0.0/24 - fqdn: peer-2.awesome-domain.com netbirdIp: 192.168.178.102 publicKey: Pubkey2 @@ -437,6 +452,7 @@ func TestParsingToYAML(t *testing.T) { latency: 10ms quantumResistance: false routes: [] + networks: [] cliVersion: development daemonVersion: 0.14.1 management: @@ -465,6 +481,8 @@ quantumResistance: false quantumResistancePermissive: false routes: - 10.10.0.0/24 +networks: + - 10.10.0.0/24 dnsServers: - servers: - 8.8.8.8:53 @@ -509,6 +527,7 @@ func TestParsingToDetail(t *testing.T) { Transfer status (received/sent) 200 B/100 B Quantum resistance: false Routes: 
10.1.0.0/24 + Networks: 10.1.0.0/24 Latency: 10ms peer-2.awesome-domain.com: @@ -525,6 +544,7 @@ func TestParsingToDetail(t *testing.T) { Transfer status (received/sent) 2.0 KiB/1000 B Quantum resistance: false Routes: - + Networks: - Latency: 10ms OS: %s/%s @@ -543,6 +563,7 @@ NetBird IP: 192.168.178.100/16 Interface type: Kernel Quantum resistance: false Routes: 10.10.0.0/24 +Networks: 10.10.0.0/24 Peers count: 2/2 Connected `, lastConnectionUpdate1, lastHandshake1, lastConnectionUpdate2, lastHandshake2, runtime.GOOS, runtime.GOARCH, overview.CliVersion) @@ -564,6 +585,7 @@ NetBird IP: 192.168.178.100/16 Interface type: Kernel Quantum resistance: false Routes: 10.10.0.0/24 +Networks: 10.10.0.0/24 Peers count: 2/2 Connected ` diff --git a/client/cmd/testutil_test.go b/client/cmd/testutil_test.go index d998f9ea9..e3e644357 100644 --- a/client/cmd/testutil_test.go +++ b/client/cmd/testutil_test.go @@ -10,6 +10,8 @@ import ( "go.opentelemetry.io/otel" "github.com/netbirdio/netbird/management/server/activity" + "github.com/netbirdio/netbird/management/server/settings" + "github.com/netbirdio/netbird/management/server/store" "github.com/netbirdio/netbird/management/server/telemetry" "github.com/netbirdio/netbird/util" @@ -71,7 +73,7 @@ func startManagement(t *testing.T, config *mgmt.Config, testFile string) (*grpc. t.Fatal(err) } s := grpc.NewServer() - store, cleanUp, err := mgmt.NewTestStoreFromSQL(context.Background(), testFile, t.TempDir()) + store, cleanUp, err := store.NewTestStoreFromSQL(context.Background(), testFile, t.TempDir()) if err != nil { t.Fatal(err) } @@ -93,7 +95,7 @@ func startManagement(t *testing.T, config *mgmt.Config, testFile string) (*grpc. } secretsManager := mgmt.NewTimeBasedAuthSecretsManager(peersUpdateManager, config.TURNConfig, config.Relay) - mgmtServer, err := mgmt.NewServer(context.Background(), config, accountManager, peersUpdateManager, secretsManager, nil, nil) + mgmtServer, err := mgmt.NewServer(context.Background(), config, accountManager, settings.NewManager(store), peersUpdateManager, secretsManager, nil, nil) if err != nil { t.Fatal(err) } diff --git a/client/internal/dns/handler_chain.go b/client/internal/dns/handler_chain.go new file mode 100644 index 000000000..c6ac3ebd6 --- /dev/null +++ b/client/internal/dns/handler_chain.go @@ -0,0 +1,222 @@ +package dns + +import ( + "strings" + "sync" + + "github.com/miekg/dns" + log "github.com/sirupsen/logrus" +) + +const ( + PriorityDNSRoute = 100 + PriorityMatchDomain = 50 + PriorityDefault = 0 +) + +type SubdomainMatcher interface { + dns.Handler + MatchSubdomains() bool +} + +type HandlerEntry struct { + Handler dns.Handler + Priority int + Pattern string + OrigPattern string + IsWildcard bool + StopHandler handlerWithStop + MatchSubdomains bool +} + +// HandlerChain represents a prioritized chain of DNS handlers +type HandlerChain struct { + mu sync.RWMutex + handlers []HandlerEntry +} + +// ResponseWriterChain wraps a dns.ResponseWriter to track if handler wants to continue chain +type ResponseWriterChain struct { + dns.ResponseWriter + origPattern string + shouldContinue bool +} + +func (w *ResponseWriterChain) WriteMsg(m *dns.Msg) error { + // Check if this is a continue signal (NXDOMAIN with Zero bit set) + if m.Rcode == dns.RcodeNameError && m.MsgHdr.Zero { + w.shouldContinue = true + return nil + } + return w.ResponseWriter.WriteMsg(m) +} + +func NewHandlerChain() *HandlerChain { + return &HandlerChain{ + handlers: make([]HandlerEntry, 0), + } +} + +// GetOrigPattern returns the original pattern 
of the handler that wrote the response +func (w *ResponseWriterChain) GetOrigPattern() string { + return w.origPattern +} + +// AddHandler adds a new handler to the chain, replacing any existing handler with the same pattern and priority +func (c *HandlerChain) AddHandler(pattern string, handler dns.Handler, priority int, stopHandler handlerWithStop) { + c.mu.Lock() + defer c.mu.Unlock() + + origPattern := pattern + isWildcard := strings.HasPrefix(pattern, "*.") + if isWildcard { + pattern = pattern[2:] + } + pattern = dns.Fqdn(pattern) + origPattern = dns.Fqdn(origPattern) + + // First remove any existing handler with same original pattern and priority + for i := len(c.handlers) - 1; i >= 0; i-- { + if c.handlers[i].OrigPattern == origPattern && c.handlers[i].Priority == priority { + if c.handlers[i].StopHandler != nil { + c.handlers[i].StopHandler.stop() + } + c.handlers = append(c.handlers[:i], c.handlers[i+1:]...) + break + } + } + + // Check if handler implements SubdomainMatcher interface + matchSubdomains := false + if matcher, ok := handler.(SubdomainMatcher); ok { + matchSubdomains = matcher.MatchSubdomains() + } + + log.Debugf("adding handler pattern: domain=%s original: domain=%s wildcard=%v match_subdomain=%v priority=%d", + pattern, origPattern, isWildcard, matchSubdomains, priority) + + entry := HandlerEntry{ + Handler: handler, + Priority: priority, + Pattern: pattern, + OrigPattern: origPattern, + IsWildcard: isWildcard, + StopHandler: stopHandler, + MatchSubdomains: matchSubdomains, + } + + // Insert handler in priority order + pos := 0 + for i, h := range c.handlers { + if h.Priority < priority { + pos = i + break + } + pos = i + 1 + } + + c.handlers = append(c.handlers[:pos], append([]HandlerEntry{entry}, c.handlers[pos:]...)...) +} + +// RemoveHandler removes a handler for the given pattern and priority +func (c *HandlerChain) RemoveHandler(pattern string, priority int) { + c.mu.Lock() + defer c.mu.Unlock() + + pattern = dns.Fqdn(pattern) + + // Find and remove handlers matching both original pattern and priority + for i := len(c.handlers) - 1; i >= 0; i-- { + entry := c.handlers[i] + if entry.OrigPattern == pattern && entry.Priority == priority { + if entry.StopHandler != nil { + entry.StopHandler.stop() + } + c.handlers = append(c.handlers[:i], c.handlers[i+1:]...) 
+ return + } + } +} + +// HasHandlers returns true if there are any handlers remaining for the given pattern +func (c *HandlerChain) HasHandlers(pattern string) bool { + c.mu.RLock() + defer c.mu.RUnlock() + + pattern = dns.Fqdn(pattern) + for _, entry := range c.handlers { + if entry.Pattern == pattern { + return true + } + } + return false +} + +func (c *HandlerChain) ServeDNS(w dns.ResponseWriter, r *dns.Msg) { + if len(r.Question) == 0 { + return + } + + qname := r.Question[0].Name + log.Tracef("handling DNS request for domain=%s", qname) + + c.mu.RLock() + defer c.mu.RUnlock() + + log.Tracef("current handlers (%d):", len(c.handlers)) + for _, h := range c.handlers { + log.Tracef(" - pattern: domain=%s original: domain=%s wildcard=%v priority=%d", + h.Pattern, h.OrigPattern, h.IsWildcard, h.Priority) + } + + // Try handlers in priority order + for _, entry := range c.handlers { + var matched bool + switch { + case entry.Pattern == ".": + matched = true + case entry.IsWildcard: + parts := strings.Split(strings.TrimSuffix(qname, entry.Pattern), ".") + matched = len(parts) >= 2 && strings.HasSuffix(qname, entry.Pattern) + default: + // For non-wildcard patterns: + // If handler wants subdomain matching, allow suffix match + // Otherwise require exact match + if entry.MatchSubdomains { + matched = qname == entry.Pattern || strings.HasSuffix(qname, "."+entry.Pattern) + } else { + matched = qname == entry.Pattern + } + } + + if !matched { + log.Tracef("trying domain match: request: domain=%s pattern: domain=%s wildcard=%v match_subdomain=%v matched=false", + qname, entry.OrigPattern, entry.MatchSubdomains, entry.IsWildcard) + continue + } + + log.Tracef("handler matched: request: domain=%s pattern: domain=%s wildcard=%v match_subdomain=%v", + qname, entry.OrigPattern, entry.IsWildcard, entry.MatchSubdomains) + + chainWriter := &ResponseWriterChain{ + ResponseWriter: w, + origPattern: entry.OrigPattern, + } + entry.Handler.ServeDNS(chainWriter, r) + + // If handler wants to continue, try next handler + if chainWriter.shouldContinue { + log.Tracef("handler requested continue to next handler") + continue + } + return + } + + // No handler matched or all handlers passed + log.Tracef("no handler found for domain=%s", qname) + resp := &dns.Msg{} + resp.SetRcode(r, dns.RcodeNameError) + if err := w.WriteMsg(resp); err != nil { + log.Errorf("failed to write DNS response: %v", err) + } +} diff --git a/client/internal/dns/handler_chain_test.go b/client/internal/dns/handler_chain_test.go new file mode 100644 index 000000000..727b6e908 --- /dev/null +++ b/client/internal/dns/handler_chain_test.go @@ -0,0 +1,511 @@ +package dns_test + +import ( + "net" + "testing" + + "github.com/miekg/dns" + "github.com/stretchr/testify/assert" + "github.com/stretchr/testify/mock" + + nbdns "github.com/netbirdio/netbird/client/internal/dns" +) + +// TestHandlerChain_ServeDNS_Priorities tests that handlers are executed in priority order +func TestHandlerChain_ServeDNS_Priorities(t *testing.T) { + chain := nbdns.NewHandlerChain() + + // Create mock handlers for different priorities + defaultHandler := &nbdns.MockHandler{} + matchDomainHandler := &nbdns.MockHandler{} + dnsRouteHandler := &nbdns.MockHandler{} + + // Setup handlers with different priorities + chain.AddHandler("example.com.", defaultHandler, nbdns.PriorityDefault, nil) + chain.AddHandler("example.com.", matchDomainHandler, nbdns.PriorityMatchDomain, nil) + chain.AddHandler("example.com.", dnsRouteHandler, nbdns.PriorityDNSRoute, nil) + + // Create test 
request + r := new(dns.Msg) + r.SetQuestion("example.com.", dns.TypeA) + + // Create test writer + w := &nbdns.ResponseWriterChain{ResponseWriter: &mockResponseWriter{}} + + // Setup expectations - only highest priority handler should be called + dnsRouteHandler.On("ServeDNS", mock.Anything, r).Once() + matchDomainHandler.On("ServeDNS", mock.Anything, r).Maybe() + defaultHandler.On("ServeDNS", mock.Anything, r).Maybe() + + // Execute + chain.ServeDNS(w, r) + + // Verify all expectations were met + dnsRouteHandler.AssertExpectations(t) + matchDomainHandler.AssertExpectations(t) + defaultHandler.AssertExpectations(t) +} + +// TestHandlerChain_ServeDNS_DomainMatching tests various domain matching scenarios +func TestHandlerChain_ServeDNS_DomainMatching(t *testing.T) { + tests := []struct { + name string + handlerDomain string + queryDomain string + isWildcard bool + matchSubdomains bool + shouldMatch bool + }{ + { + name: "exact match", + handlerDomain: "example.com.", + queryDomain: "example.com.", + isWildcard: false, + matchSubdomains: false, + shouldMatch: true, + }, + { + name: "subdomain with non-wildcard and MatchSubdomains true", + handlerDomain: "example.com.", + queryDomain: "sub.example.com.", + isWildcard: false, + matchSubdomains: true, + shouldMatch: true, + }, + { + name: "subdomain with non-wildcard and MatchSubdomains false", + handlerDomain: "example.com.", + queryDomain: "sub.example.com.", + isWildcard: false, + matchSubdomains: false, + shouldMatch: false, + }, + { + name: "wildcard match", + handlerDomain: "*.example.com.", + queryDomain: "sub.example.com.", + isWildcard: true, + matchSubdomains: false, + shouldMatch: true, + }, + { + name: "wildcard no match on apex", + handlerDomain: "*.example.com.", + queryDomain: "example.com.", + isWildcard: true, + matchSubdomains: false, + shouldMatch: false, + }, + { + name: "root zone match", + handlerDomain: ".", + queryDomain: "anything.com.", + isWildcard: false, + matchSubdomains: false, + shouldMatch: true, + }, + { + name: "no match different domain", + handlerDomain: "example.com.", + queryDomain: "example.org.", + isWildcard: false, + matchSubdomains: false, + shouldMatch: false, + }, + } + + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + chain := nbdns.NewHandlerChain() + var handler dns.Handler + + if tt.matchSubdomains { + mockSubHandler := &nbdns.MockSubdomainHandler{Subdomains: true} + handler = mockSubHandler + if tt.shouldMatch { + mockSubHandler.On("ServeDNS", mock.Anything, mock.Anything).Once() + } + } else { + mockHandler := &nbdns.MockHandler{} + handler = mockHandler + if tt.shouldMatch { + mockHandler.On("ServeDNS", mock.Anything, mock.Anything).Once() + } + } + + pattern := tt.handlerDomain + if tt.isWildcard { + pattern = "*." 
+ tt.handlerDomain[2:] + } + + chain.AddHandler(pattern, handler, nbdns.PriorityDefault, nil) + + r := new(dns.Msg) + r.SetQuestion(tt.queryDomain, dns.TypeA) + w := &nbdns.ResponseWriterChain{ResponseWriter: &mockResponseWriter{}} + + chain.ServeDNS(w, r) + + if h, ok := handler.(*nbdns.MockHandler); ok { + h.AssertExpectations(t) + } else if h, ok := handler.(*nbdns.MockSubdomainHandler); ok { + h.AssertExpectations(t) + } + }) + } +} + +// TestHandlerChain_ServeDNS_OverlappingDomains tests behavior with overlapping domain patterns +func TestHandlerChain_ServeDNS_OverlappingDomains(t *testing.T) { + tests := []struct { + name string + handlers []struct { + pattern string + priority int + } + queryDomain string + expectedCalls int + expectedHandler int // index of the handler that should be called + }{ + { + name: "wildcard and exact same priority - exact should win", + handlers: []struct { + pattern string + priority int + }{ + {pattern: "*.example.com.", priority: nbdns.PriorityDefault}, + {pattern: "example.com.", priority: nbdns.PriorityDefault}, + }, + queryDomain: "example.com.", + expectedCalls: 1, + expectedHandler: 1, // exact match handler should be called + }, + { + name: "higher priority wildcard over lower priority exact", + handlers: []struct { + pattern string + priority int + }{ + {pattern: "example.com.", priority: nbdns.PriorityDefault}, + {pattern: "*.example.com.", priority: nbdns.PriorityDNSRoute}, + }, + queryDomain: "test.example.com.", + expectedCalls: 1, + expectedHandler: 1, // higher priority wildcard handler should be called + }, + { + name: "multiple wildcards different priorities", + handlers: []struct { + pattern string + priority int + }{ + {pattern: "*.example.com.", priority: nbdns.PriorityDefault}, + {pattern: "*.example.com.", priority: nbdns.PriorityMatchDomain}, + {pattern: "*.example.com.", priority: nbdns.PriorityDNSRoute}, + }, + queryDomain: "test.example.com.", + expectedCalls: 1, + expectedHandler: 2, // highest priority handler should be called + }, + { + name: "subdomain with mix of patterns", + handlers: []struct { + pattern string + priority int + }{ + {pattern: "*.example.com.", priority: nbdns.PriorityDefault}, + {pattern: "test.example.com.", priority: nbdns.PriorityMatchDomain}, + {pattern: "*.test.example.com.", priority: nbdns.PriorityDNSRoute}, + }, + queryDomain: "sub.test.example.com.", + expectedCalls: 1, + expectedHandler: 2, // highest priority matching handler should be called + }, + { + name: "root zone with specific domain", + handlers: []struct { + pattern string + priority int + }{ + {pattern: ".", priority: nbdns.PriorityDefault}, + {pattern: "example.com.", priority: nbdns.PriorityDNSRoute}, + }, + queryDomain: "example.com.", + expectedCalls: 1, + expectedHandler: 1, // higher priority specific domain should win over root + }, + } + + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + chain := nbdns.NewHandlerChain() + var handlers []*nbdns.MockHandler + + // Setup handlers and expectations + for i := range tt.handlers { + handler := &nbdns.MockHandler{} + handlers = append(handlers, handler) + + // Set expectation based on whether this handler should be called + if i == tt.expectedHandler { + handler.On("ServeDNS", mock.Anything, mock.Anything).Once() + } else { + handler.On("ServeDNS", mock.Anything, mock.Anything).Maybe() + } + + chain.AddHandler(tt.handlers[i].pattern, handler, tt.handlers[i].priority, nil) + } + + // Create and execute request + r := new(dns.Msg) + r.SetQuestion(tt.queryDomain, 
dns.TypeA) + w := &nbdns.ResponseWriterChain{ResponseWriter: &mockResponseWriter{}} + chain.ServeDNS(w, r) + + // Verify expectations + for _, handler := range handlers { + handler.AssertExpectations(t) + } + }) + } +} + +// TestHandlerChain_ServeDNS_ChainContinuation tests the chain continuation functionality +func TestHandlerChain_ServeDNS_ChainContinuation(t *testing.T) { + chain := nbdns.NewHandlerChain() + + // Create handlers + handler1 := &nbdns.MockHandler{} + handler2 := &nbdns.MockHandler{} + handler3 := &nbdns.MockHandler{} + + // Add handlers in priority order + chain.AddHandler("example.com.", handler1, nbdns.PriorityDNSRoute, nil) + chain.AddHandler("example.com.", handler2, nbdns.PriorityMatchDomain, nil) + chain.AddHandler("example.com.", handler3, nbdns.PriorityDefault, nil) + + // Create test request + r := new(dns.Msg) + r.SetQuestion("example.com.", dns.TypeA) + + // Setup mock responses to simulate chain continuation + handler1.On("ServeDNS", mock.Anything, r).Run(func(args mock.Arguments) { + // First handler signals continue + w := args.Get(0).(*nbdns.ResponseWriterChain) + resp := new(dns.Msg) + resp.SetRcode(r, dns.RcodeNameError) + resp.MsgHdr.Zero = true // Signal to continue + assert.NoError(t, w.WriteMsg(resp)) + }).Once() + + handler2.On("ServeDNS", mock.Anything, r).Run(func(args mock.Arguments) { + // Second handler signals continue + w := args.Get(0).(*nbdns.ResponseWriterChain) + resp := new(dns.Msg) + resp.SetRcode(r, dns.RcodeNameError) + resp.MsgHdr.Zero = true + assert.NoError(t, w.WriteMsg(resp)) + }).Once() + + handler3.On("ServeDNS", mock.Anything, r).Run(func(args mock.Arguments) { + // Last handler responds normally + w := args.Get(0).(*nbdns.ResponseWriterChain) + resp := new(dns.Msg) + resp.SetRcode(r, dns.RcodeSuccess) + assert.NoError(t, w.WriteMsg(resp)) + }).Once() + + // Execute + w := &nbdns.ResponseWriterChain{ResponseWriter: &mockResponseWriter{}} + chain.ServeDNS(w, r) + + // Verify all handlers were called in order + handler1.AssertExpectations(t) + handler2.AssertExpectations(t) + handler3.AssertExpectations(t) +} + +// mockResponseWriter implements dns.ResponseWriter for testing +type mockResponseWriter struct { + mock.Mock +} + +func (m *mockResponseWriter) LocalAddr() net.Addr { return nil } +func (m *mockResponseWriter) RemoteAddr() net.Addr { return nil } +func (m *mockResponseWriter) WriteMsg(*dns.Msg) error { return nil } +func (m *mockResponseWriter) Write([]byte) (int, error) { return 0, nil } +func (m *mockResponseWriter) Close() error { return nil } +func (m *mockResponseWriter) TsigStatus() error { return nil } +func (m *mockResponseWriter) TsigTimersOnly(bool) {} +func (m *mockResponseWriter) Hijack() {} + +func TestHandlerChain_PriorityDeregistration(t *testing.T) { + tests := []struct { + name string + ops []struct { + action string // "add" or "remove" + pattern string + priority int + } + query string + expectedCalls map[int]bool // map[priority]shouldBeCalled + }{ + { + name: "remove high priority keeps lower priority handler", + ops: []struct { + action string + pattern string + priority int + }{ + {"add", "example.com.", nbdns.PriorityDNSRoute}, + {"add", "example.com.", nbdns.PriorityMatchDomain}, + {"remove", "example.com.", nbdns.PriorityDNSRoute}, + }, + query: "example.com.", + expectedCalls: map[int]bool{ + nbdns.PriorityDNSRoute: false, + nbdns.PriorityMatchDomain: true, + }, + }, + { + name: "remove lower priority keeps high priority handler", + ops: []struct { + action string + pattern string + priority 
int + }{ + {"add", "example.com.", nbdns.PriorityDNSRoute}, + {"add", "example.com.", nbdns.PriorityMatchDomain}, + {"remove", "example.com.", nbdns.PriorityMatchDomain}, + }, + query: "example.com.", + expectedCalls: map[int]bool{ + nbdns.PriorityDNSRoute: true, + nbdns.PriorityMatchDomain: false, + }, + }, + { + name: "remove all handlers in order", + ops: []struct { + action string + pattern string + priority int + }{ + {"add", "example.com.", nbdns.PriorityDNSRoute}, + {"add", "example.com.", nbdns.PriorityMatchDomain}, + {"add", "example.com.", nbdns.PriorityDefault}, + {"remove", "example.com.", nbdns.PriorityDNSRoute}, + {"remove", "example.com.", nbdns.PriorityMatchDomain}, + }, + query: "example.com.", + expectedCalls: map[int]bool{ + nbdns.PriorityDNSRoute: false, + nbdns.PriorityMatchDomain: false, + nbdns.PriorityDefault: true, + }, + }, + } + + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + chain := nbdns.NewHandlerChain() + handlers := make(map[int]*nbdns.MockHandler) + + // Execute operations + for _, op := range tt.ops { + if op.action == "add" { + handler := &nbdns.MockHandler{} + handlers[op.priority] = handler + chain.AddHandler(op.pattern, handler, op.priority, nil) + } else { + chain.RemoveHandler(op.pattern, op.priority) + } + } + + // Create test request + r := new(dns.Msg) + r.SetQuestion(tt.query, dns.TypeA) + w := &nbdns.ResponseWriterChain{ResponseWriter: &mockResponseWriter{}} + + // Setup expectations + for priority, handler := range handlers { + if shouldCall, exists := tt.expectedCalls[priority]; exists && shouldCall { + handler.On("ServeDNS", mock.Anything, r).Once() + } else { + handler.On("ServeDNS", mock.Anything, r).Maybe() + } + } + + // Execute request + chain.ServeDNS(w, r) + + // Verify expectations + for _, handler := range handlers { + handler.AssertExpectations(t) + } + + // Verify handler exists check + for priority, shouldExist := range tt.expectedCalls { + if shouldExist { + assert.True(t, chain.HasHandlers(tt.ops[0].pattern), + "Handler chain should have handlers for pattern after removing priority %d", priority) + } + } + }) + } +} + +func TestHandlerChain_MultiPriorityHandling(t *testing.T) { + chain := nbdns.NewHandlerChain() + + testDomain := "example.com." + testQuery := "test.example.com." 
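+	// Dispatch must always go to the highest-priority handler still registered
+	// (PriorityDNSRoute > PriorityMatchDomain > PriorityDefault) and fall back to
+	// the next one as handlers are removed, regardless of registration order.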
+ + // Create handlers with MatchSubdomains enabled + routeHandler := &nbdns.MockSubdomainHandler{Subdomains: true} + matchHandler := &nbdns.MockSubdomainHandler{Subdomains: true} + defaultHandler := &nbdns.MockSubdomainHandler{Subdomains: true} + + // Create test request that will be reused + r := new(dns.Msg) + r.SetQuestion(testQuery, dns.TypeA) + + // Add handlers in mixed order + chain.AddHandler(testDomain, defaultHandler, nbdns.PriorityDefault, nil) + chain.AddHandler(testDomain, routeHandler, nbdns.PriorityDNSRoute, nil) + chain.AddHandler(testDomain, matchHandler, nbdns.PriorityMatchDomain, nil) + + // Test 1: Initial state with all three handlers + w := &nbdns.ResponseWriterChain{ResponseWriter: &mockResponseWriter{}} + // Highest priority handler (routeHandler) should be called + routeHandler.On("ServeDNS", mock.Anything, r).Return().Once() + + chain.ServeDNS(w, r) + routeHandler.AssertExpectations(t) + + // Test 2: Remove highest priority handler + chain.RemoveHandler(testDomain, nbdns.PriorityDNSRoute) + assert.True(t, chain.HasHandlers(testDomain)) + + w = &nbdns.ResponseWriterChain{ResponseWriter: &mockResponseWriter{}} + // Now middle priority handler (matchHandler) should be called + matchHandler.On("ServeDNS", mock.Anything, r).Return().Once() + + chain.ServeDNS(w, r) + matchHandler.AssertExpectations(t) + + // Test 3: Remove middle priority handler + chain.RemoveHandler(testDomain, nbdns.PriorityMatchDomain) + assert.True(t, chain.HasHandlers(testDomain)) + + w = &nbdns.ResponseWriterChain{ResponseWriter: &mockResponseWriter{}} + // Now lowest priority handler (defaultHandler) should be called + defaultHandler.On("ServeDNS", mock.Anything, r).Return().Once() + + chain.ServeDNS(w, r) + defaultHandler.AssertExpectations(t) + + // Test 4: Remove last handler + chain.RemoveHandler(testDomain, nbdns.PriorityDefault) + assert.False(t, chain.HasHandlers(testDomain)) +} diff --git a/client/internal/dns/local.go b/client/internal/dns/local.go index 6a459794b..9a78d4d50 100644 --- a/client/internal/dns/local.go +++ b/client/internal/dns/local.go @@ -17,12 +17,24 @@ type localResolver struct { records sync.Map } +func (d *localResolver) MatchSubdomains() bool { + return true +} + func (d *localResolver) stop() { } +// String returns a string representation of the local resolver +func (d *localResolver) String() string { + return fmt.Sprintf("local resolver [%d records]", len(d.registeredMap)) +} + // ServeDNS handles a DNS request func (d *localResolver) ServeDNS(w dns.ResponseWriter, r *dns.Msg) { - log.Tracef("received question: %#v", r.Question[0]) + if len(r.Question) > 0 { + log.Tracef("received question: domain=%s type=%v class=%v", r.Question[0].Name, r.Question[0].Qtype, r.Question[0].Qclass) + } + replyMessage := &dns.Msg{} replyMessage.SetReply(r) replyMessage.RecursionAvailable = true diff --git a/client/internal/dns/mock_server.go b/client/internal/dns/mock_server.go index 0739f0542..7e36ea5df 100644 --- a/client/internal/dns/mock_server.go +++ b/client/internal/dns/mock_server.go @@ -3,14 +3,30 @@ package dns import ( "fmt" + "github.com/miekg/dns" + nbdns "github.com/netbirdio/netbird/dns" ) // MockServer is the mock instance of a dns server type MockServer struct { - InitializeFunc func() error - StopFunc func() - UpdateDNSServerFunc func(serial uint64, update nbdns.Config) error + InitializeFunc func() error + StopFunc func() + UpdateDNSServerFunc func(serial uint64, update nbdns.Config) error + RegisterHandlerFunc func([]string, dns.Handler, int) + 
DeregisterHandlerFunc func([]string, int) +} + +func (m *MockServer) RegisterHandler(domains []string, handler dns.Handler, priority int) { + if m.RegisterHandlerFunc != nil { + m.RegisterHandlerFunc(domains, handler, priority) + } +} + +func (m *MockServer) DeregisterHandler(domains []string, priority int) { + if m.DeregisterHandlerFunc != nil { + m.DeregisterHandlerFunc(domains, priority) + } } // Initialize mock implementation of Initialize from Server interface diff --git a/client/internal/dns/server.go b/client/internal/dns/server.go index f0277319c..5a9cb50d0 100644 --- a/client/internal/dns/server.go +++ b/client/internal/dns/server.go @@ -30,6 +30,8 @@ type IosDnsManager interface { // Server is a dns server interface type Server interface { + RegisterHandler(domains []string, handler dns.Handler, priority int) + DeregisterHandler(domains []string, priority int) Initialize() error Stop() DnsIP() string @@ -48,12 +50,14 @@ type DefaultServer struct { mux sync.Mutex service service dnsMuxMap registeredHandlerMap + handlerPriorities map[string]int localResolver *localResolver wgInterface WGIface hostManager hostManager updateSerial uint64 previousConfigHash uint64 currentConfig HostDNSConfig + handlerChain *HandlerChain // permanent related properties permanent bool @@ -74,8 +78,9 @@ type handlerWithStop interface { } type muxUpdate struct { - domain string - handler handlerWithStop + domain string + handler handlerWithStop + priority int } // NewDefaultServer returns a new dns server @@ -135,10 +140,12 @@ func NewDefaultServerIos( func newDefaultServer(ctx context.Context, wgInterface WGIface, dnsService service, statusRecorder *peer.Status, stateManager *statemanager.Manager) *DefaultServer { ctx, stop := context.WithCancel(ctx) defaultServer := &DefaultServer{ - ctx: ctx, - ctxCancel: stop, - service: dnsService, - dnsMuxMap: make(registeredHandlerMap), + ctx: ctx, + ctxCancel: stop, + service: dnsService, + handlerChain: NewHandlerChain(), + dnsMuxMap: make(registeredHandlerMap), + handlerPriorities: make(map[string]int), localResolver: &localResolver{ registeredMap: make(registrationMap), }, @@ -151,6 +158,51 @@ func newDefaultServer(ctx context.Context, wgInterface WGIface, dnsService servi return defaultServer } +func (s *DefaultServer) RegisterHandler(domains []string, handler dns.Handler, priority int) { + s.mux.Lock() + defer s.mux.Unlock() + + s.registerHandler(domains, handler, priority) +} + +func (s *DefaultServer) registerHandler(domains []string, handler dns.Handler, priority int) { + log.Debugf("registering handler %s with priority %d", handler, priority) + + for _, domain := range domains { + if domain == "" { + log.Warn("skipping empty domain") + continue + } + s.handlerChain.AddHandler(domain, handler, priority, nil) + s.handlerPriorities[domain] = priority + s.service.RegisterMux(nbdns.NormalizeZone(domain), s.handlerChain) + } +} + +func (s *DefaultServer) DeregisterHandler(domains []string, priority int) { + s.mux.Lock() + defer s.mux.Unlock() + + s.deregisterHandler(domains, priority) +} + +func (s *DefaultServer) deregisterHandler(domains []string, priority int) { + log.Debugf("deregistering handler %v with priority %d", domains, priority) + + for _, domain := range domains { + s.handlerChain.RemoveHandler(domain, priority) + + // Only deregister from service if no handlers remain + if !s.handlerChain.HasHandlers(domain) { + if domain == "" { + log.Warn("skipping empty domain") + continue + } + s.service.DeregisterMux(nbdns.NormalizeZone(domain)) + } + } +} + 
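For context on how a consumer is expected to use the new Server methods (the dnsinterceptor handler later in this patch follows the same pattern), here is a minimal sketch; it is not part of the patch, and attachRouteHandler and its arguments are placeholder names:

package example

import (
	"github.com/miekg/dns"

	nbdns "github.com/netbirdio/netbird/client/internal/dns"
)

// attachRouteHandler registers handler for the given domains at DNS-route
// priority and returns a cleanup func. Deregistration is keyed by
// (domain, priority), so lower-priority handlers registered for the same
// domains keep serving queries after cleanup runs.
func attachRouteHandler(srv nbdns.Server, domains []string, handler dns.Handler) func() {
	srv.RegisterHandler(domains, handler, nbdns.PriorityDNSRoute)
	return func() {
		srv.DeregisterHandler(domains, nbdns.PriorityDNSRoute)
	}
}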
// Initialize instantiate host manager and the dns service func (s *DefaultServer) Initialize() (err error) { s.mux.Lock() @@ -343,14 +395,14 @@ func (s *DefaultServer) buildLocalHandlerUpdate(customZones []nbdns.CustomZone) localRecords := make(map[string]nbdns.SimpleRecord, 0) for _, customZone := range customZones { - if len(customZone.Records) == 0 { return nil, nil, fmt.Errorf("received an empty list of records") } muxUpdates = append(muxUpdates, muxUpdate{ - domain: customZone.Domain, - handler: s.localResolver, + domain: customZone.Domain, + handler: s.localResolver, + priority: PriorityMatchDomain, }) for _, record := range customZone.Records { @@ -412,8 +464,9 @@ func (s *DefaultServer) buildUpstreamHandlerUpdate(nameServerGroups []*nbdns.Nam if nsGroup.Primary { muxUpdates = append(muxUpdates, muxUpdate{ - domain: nbdns.RootZone, - handler: handler, + domain: nbdns.RootZone, + handler: handler, + priority: PriorityDefault, }) continue } @@ -429,8 +482,9 @@ func (s *DefaultServer) buildUpstreamHandlerUpdate(nameServerGroups []*nbdns.Nam return nil, fmt.Errorf("received a nameserver group with an empty domain element") } muxUpdates = append(muxUpdates, muxUpdate{ - domain: domain, - handler: handler, + domain: domain, + handler: handler, + priority: PriorityMatchDomain, }) } } @@ -440,12 +494,16 @@ func (s *DefaultServer) buildUpstreamHandlerUpdate(nameServerGroups []*nbdns.Nam func (s *DefaultServer) updateMux(muxUpdates []muxUpdate) { muxUpdateMap := make(registeredHandlerMap) + handlersByPriority := make(map[string]int) var isContainRootUpdate bool + // First register new handlers for _, update := range muxUpdates { - s.service.RegisterMux(update.domain, update.handler) + s.registerHandler([]string{update.domain}, update.handler, update.priority) muxUpdateMap[update.domain] = update.handler + handlersByPriority[update.domain] = update.priority + if existingHandler, ok := s.dnsMuxMap[update.domain]; ok { existingHandler.stop() } @@ -455,6 +513,7 @@ func (s *DefaultServer) updateMux(muxUpdates []muxUpdate) { } } + // Then deregister old handlers not in the update for key, existingHandler := range s.dnsMuxMap { _, found := muxUpdateMap[key] if !found { @@ -463,12 +522,16 @@ func (s *DefaultServer) updateMux(muxUpdates []muxUpdate) { existingHandler.stop() } else { existingHandler.stop() - s.service.DeregisterMux(key) + // Deregister with the priority that was used to register + if oldPriority, ok := s.handlerPriorities[key]; ok { + s.deregisterHandler([]string{key}, oldPriority) + } } } } s.dnsMuxMap = muxUpdateMap + s.handlerPriorities = handlersByPriority } func (s *DefaultServer) updateLocalResolver(update map[string]nbdns.SimpleRecord) { @@ -517,13 +580,13 @@ func (s *DefaultServer) upstreamCallbacks( if nsGroup.Primary { removeIndex[nbdns.RootZone] = -1 s.currentConfig.RouteAll = false - s.service.DeregisterMux(nbdns.RootZone) + s.deregisterHandler([]string{nbdns.RootZone}, PriorityDefault) } for i, item := range s.currentConfig.Domains { if _, found := removeIndex[item.Domain]; found { s.currentConfig.Domains[i].Disabled = true - s.service.DeregisterMux(item.Domain) + s.deregisterHandler([]string{item.Domain}, PriorityMatchDomain) removeIndex[item.Domain] = i } } @@ -554,7 +617,7 @@ func (s *DefaultServer) upstreamCallbacks( continue } s.currentConfig.Domains[i].Disabled = false - s.service.RegisterMux(domain, handler) + s.registerHandler([]string{domain}, handler, PriorityMatchDomain) } l := log.WithField("nameservers", nsGroup.NameServers) @@ -562,7 +625,7 @@ func (s 
*DefaultServer) upstreamCallbacks( if nsGroup.Primary { s.currentConfig.RouteAll = true - s.service.RegisterMux(nbdns.RootZone, handler) + s.registerHandler([]string{nbdns.RootZone}, handler, PriorityDefault) } if err := s.hostManager.applyDNSConfig(s.currentConfig, s.stateManager); err != nil { l.WithError(err).Error("reactivate temporary disabled nameserver group, DNS update apply") @@ -593,7 +656,8 @@ func (s *DefaultServer) addHostRootZone() { } handler.deactivate = func(error) {} handler.reactivate = func() {} - s.service.RegisterMux(nbdns.RootZone, handler) + + s.registerHandler([]string{nbdns.RootZone}, handler, PriorityDefault) } func (s *DefaultServer) updateNSGroupStates(groups []*nbdns.NameServerGroup) { diff --git a/client/internal/dns/server_test.go b/client/internal/dns/server_test.go index eab9f4ecb..44d20c6f3 100644 --- a/client/internal/dns/server_test.go +++ b/client/internal/dns/server_test.go @@ -11,7 +11,9 @@ import ( "time" "github.com/golang/mock/gomock" + "github.com/miekg/dns" log "github.com/sirupsen/logrus" + "github.com/stretchr/testify/mock" "golang.zx2c4.com/wireguard/wgctrl/wgtypes" "github.com/netbirdio/netbird/client/firewall/uspfilter" @@ -512,7 +514,7 @@ func TestDNSServerStartStop(t *testing.T) { t.Error(err) } - dnsServer.service.RegisterMux("netbird.cloud", dnsServer.localResolver) + dnsServer.registerHandler([]string{"netbird.cloud"}, dnsServer.localResolver, 1) resolver := &net.Resolver{ PreferGo: true, @@ -560,7 +562,9 @@ func TestDNSServerUpstreamDeactivateCallback(t *testing.T) { localResolver: &localResolver{ registeredMap: make(registrationMap), }, - hostManager: hostManager, + handlerChain: NewHandlerChain(), + handlerPriorities: make(map[string]int), + hostManager: hostManager, currentConfig: HostDNSConfig{ Domains: []DomainConfig{ {false, "domain0", false}, @@ -872,3 +876,86 @@ func newDnsResolver(ip string, port int) *net.Resolver { }, } } + +// MockHandler implements dns.Handler interface for testing +type MockHandler struct { + mock.Mock +} + +func (m *MockHandler) ServeDNS(w dns.ResponseWriter, r *dns.Msg) { + m.Called(w, r) +} + +type MockSubdomainHandler struct { + MockHandler + Subdomains bool +} + +func (m *MockSubdomainHandler) MatchSubdomains() bool { + return m.Subdomains +} + +func TestHandlerChain_DomainPriorities(t *testing.T) { + chain := NewHandlerChain() + + dnsRouteHandler := &MockHandler{} + upstreamHandler := &MockSubdomainHandler{ + Subdomains: true, + } + + chain.AddHandler("example.com.", dnsRouteHandler, PriorityDNSRoute, nil) + chain.AddHandler("example.com.", upstreamHandler, PriorityMatchDomain, nil) + + testCases := []struct { + name string + query string + expectedHandler dns.Handler + }{ + { + name: "exact domain with dns route handler", + query: "example.com.", + expectedHandler: dnsRouteHandler, + }, + { + name: "subdomain should use upstream handler", + query: "sub.example.com.", + expectedHandler: upstreamHandler, + }, + { + name: "deep subdomain should use upstream handler", + query: "deep.sub.example.com.", + expectedHandler: upstreamHandler, + }, + } + + for _, tc := range testCases { + t.Run(tc.name, func(t *testing.T) { + r := new(dns.Msg) + r.SetQuestion(tc.query, dns.TypeA) + w := &ResponseWriterChain{ResponseWriter: &mockResponseWriter{}} + + if mh, ok := tc.expectedHandler.(*MockHandler); ok { + mh.On("ServeDNS", mock.Anything, r).Once() + } else if mh, ok := tc.expectedHandler.(*MockSubdomainHandler); ok { + mh.On("ServeDNS", mock.Anything, r).Once() + } + + chain.ServeDNS(w, r) + + if mh, ok := 
tc.expectedHandler.(*MockHandler); ok { + mh.AssertExpectations(t) + } else if mh, ok := tc.expectedHandler.(*MockSubdomainHandler); ok { + mh.AssertExpectations(t) + } + + // Reset mocks + if mh, ok := tc.expectedHandler.(*MockHandler); ok { + mh.ExpectedCalls = nil + mh.Calls = nil + } else if mh, ok := tc.expectedHandler.(*MockSubdomainHandler); ok { + mh.ExpectedCalls = nil + mh.Calls = nil + } + }) + } +} diff --git a/client/internal/dns/service_listener.go b/client/internal/dns/service_listener.go index e0f9da26f..72dc4bc6e 100644 --- a/client/internal/dns/service_listener.go +++ b/client/internal/dns/service_listener.go @@ -105,6 +105,7 @@ func (s *serviceViaListener) Stop() { } func (s *serviceViaListener) RegisterMux(pattern string, handler dns.Handler) { + log.Debugf("registering dns handler for pattern: %s", pattern) s.dnsMux.Handle(pattern, handler) } diff --git a/client/internal/dns/upstream.go b/client/internal/dns/upstream.go index b3baf2fa8..f0aa12b65 100644 --- a/client/internal/dns/upstream.go +++ b/client/internal/dns/upstream.go @@ -66,6 +66,15 @@ func newUpstreamResolverBase(ctx context.Context, statusRecorder *peer.Status) * } } +// String returns a string representation of the upstream resolver +func (u *upstreamResolverBase) String() string { + return fmt.Sprintf("upstream %v", u.upstreamServers) +} + +func (u *upstreamResolverBase) MatchSubdomains() bool { + return true +} + func (u *upstreamResolverBase) stop() { log.Debugf("stopping serving DNS for upstreams %s", u.upstreamServers) u.cancel() diff --git a/client/internal/dnsfwd/forwarder.go b/client/internal/dnsfwd/forwarder.go new file mode 100644 index 000000000..ae31ffac6 --- /dev/null +++ b/client/internal/dnsfwd/forwarder.go @@ -0,0 +1,157 @@ +package dnsfwd + +import ( + "context" + "errors" + "net" + + "github.com/miekg/dns" + log "github.com/sirupsen/logrus" + + nbdns "github.com/netbirdio/netbird/dns" +) + +const errResolveFailed = "failed to resolve query for domain=%s: %v" + +type DNSForwarder struct { + listenAddress string + ttl uint32 + domains []string + + dnsServer *dns.Server + mux *dns.ServeMux +} + +func NewDNSForwarder(listenAddress string, ttl uint32) *DNSForwarder { + log.Debugf("creating DNS forwarder with listen_address=%s ttl=%d", listenAddress, ttl) + return &DNSForwarder{ + listenAddress: listenAddress, + ttl: ttl, + } +} + +func (f *DNSForwarder) Listen(domains []string) error { + log.Infof("listen DNS forwarder on address=%s", f.listenAddress) + mux := dns.NewServeMux() + + dnsServer := &dns.Server{ + Addr: f.listenAddress, + Net: "udp", + Handler: mux, + } + f.dnsServer = dnsServer + f.mux = mux + + f.UpdateDomains(domains) + + return dnsServer.ListenAndServe() +} + +func (f *DNSForwarder) UpdateDomains(domains []string) { + log.Debugf("Updating domains from %v to %v", f.domains, domains) + + for _, d := range f.domains { + f.mux.HandleRemove(d) + } + + newDomains := filterDomains(domains) + for _, d := range newDomains { + f.mux.HandleFunc(d, f.handleDNSQuery) + } + f.domains = newDomains +} + +func (f *DNSForwarder) Close(ctx context.Context) error { + if f.dnsServer == nil { + return nil + } + return f.dnsServer.ShutdownContext(ctx) +} + +func (f *DNSForwarder) handleDNSQuery(w dns.ResponseWriter, query *dns.Msg) { + if len(query.Question) == 0 { + return + } + log.Tracef("received DNS request for DNS forwarder: domain=%v type=%v class=%v", + query.Question[0].Name, query.Question[0].Qtype, query.Question[0].Qclass) + + question := query.Question[0] + domain := question.Name + + 
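+	// Resolve with the local system resolver and synthesize A/AAAA answers using
+	// the configured TTL; lookup errors map to SERVFAIL, and NXDOMAIN from the
+	// resolver is passed through to the client.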
resp := query.SetReply(query) + + ips, err := net.LookupIP(domain) + if err != nil { + var dnsErr *net.DNSError + + switch { + case errors.As(err, &dnsErr): + resp.Rcode = dns.RcodeServerFailure + if dnsErr.IsNotFound { + // Pass through NXDOMAIN + resp.Rcode = dns.RcodeNameError + } + + if dnsErr.Server != "" { + log.Warnf("failed to resolve query for domain=%s server=%s: %v", domain, dnsErr.Server, err) + } else { + log.Warnf(errResolveFailed, domain, err) + } + default: + resp.Rcode = dns.RcodeServerFailure + log.Warnf(errResolveFailed, domain, err) + } + + if err := w.WriteMsg(resp); err != nil { + log.Errorf("failed to write failure DNS response: %v", err) + } + return + } + + for _, ip := range ips { + var respRecord dns.RR + if ip.To4() == nil { + log.Tracef("resolved domain=%s to IPv6=%s", domain, ip) + rr := dns.AAAA{ + AAAA: ip, + Hdr: dns.RR_Header{ + Name: domain, + Rrtype: dns.TypeAAAA, + Class: dns.ClassINET, + Ttl: f.ttl, + }, + } + respRecord = &rr + } else { + log.Tracef("resolved domain=%s to IPv4=%s", domain, ip) + rr := dns.A{ + A: ip, + Hdr: dns.RR_Header{ + Name: domain, + Rrtype: dns.TypeA, + Class: dns.ClassINET, + Ttl: f.ttl, + }, + } + respRecord = &rr + } + resp.Answer = append(resp.Answer, respRecord) + } + + if err := w.WriteMsg(resp); err != nil { + log.Errorf("failed to write DNS response: %v", err) + } +} + +// filterDomains returns a list of normalized domains +func filterDomains(domains []string) []string { + newDomains := make([]string, 0, len(domains)) + for _, d := range domains { + if d == "" { + log.Warn("empty domain in DNS forwarder") + continue + } + newDomains = append(newDomains, nbdns.NormalizeZone(d)) + } + return newDomains +} diff --git a/client/internal/dnsfwd/manager.go b/client/internal/dnsfwd/manager.go new file mode 100644 index 000000000..7cff6d517 --- /dev/null +++ b/client/internal/dnsfwd/manager.go @@ -0,0 +1,106 @@ +package dnsfwd + +import ( + "context" + "fmt" + "net" + + "github.com/hashicorp/go-multierror" + log "github.com/sirupsen/logrus" + + nberrors "github.com/netbirdio/netbird/client/errors" + firewall "github.com/netbirdio/netbird/client/firewall/manager" +) + +const ( + // ListenPort is the port that the DNS forwarder listens on. 
It has been used by the client peers also + ListenPort = 5353 + dnsTTL = 60 //seconds +) + +type Manager struct { + firewall firewall.Manager + + fwRules []firewall.Rule + dnsForwarder *DNSForwarder +} + +func NewManager(fw firewall.Manager) *Manager { + return &Manager{ + firewall: fw, + } +} + +func (m *Manager) Start(domains []string) error { + log.Infof("starting DNS forwarder") + if m.dnsForwarder != nil { + return nil + } + + if err := m.allowDNSFirewall(); err != nil { + return err + } + + m.dnsForwarder = NewDNSForwarder(fmt.Sprintf(":%d", ListenPort), dnsTTL) + go func() { + if err := m.dnsForwarder.Listen(domains); err != nil { + // todo handle close error if it is exists + log.Errorf("failed to start DNS forwarder, err: %v", err) + } + }() + + return nil +} + +func (m *Manager) UpdateDomains(domains []string) { + if m.dnsForwarder == nil { + return + } + + m.dnsForwarder.UpdateDomains(domains) +} + +func (m *Manager) Stop(ctx context.Context) error { + if m.dnsForwarder == nil { + return nil + } + + var mErr *multierror.Error + if err := m.dropDNSFirewall(); err != nil { + mErr = multierror.Append(mErr, err) + } + + if err := m.dnsForwarder.Close(ctx); err != nil { + mErr = multierror.Append(mErr, err) + } + + m.dnsForwarder = nil + return nberrors.FormatErrorOrNil(mErr) +} + +func (h *Manager) allowDNSFirewall() error { + dport := &firewall.Port{ + IsRange: false, + Values: []int{ListenPort}, + } + dnsRules, err := h.firewall.AddPeerFiltering(net.ParseIP("0.0.0.0"), firewall.ProtocolUDP, nil, dport, firewall.RuleDirectionIN, firewall.ActionAccept, "", "") + if err != nil { + log.Errorf("failed to add allow DNS router rules, err: %v", err) + return err + } + h.fwRules = dnsRules + + return nil +} + +func (h *Manager) dropDNSFirewall() error { + var mErr *multierror.Error + for _, rule := range h.fwRules { + if err := h.firewall.DeletePeerRule(rule); err != nil { + mErr = multierror.Append(mErr, fmt.Errorf("failed to delete DNS router rules, err: %v", err)) + } + } + + h.fwRules = nil + return nberrors.FormatErrorOrNil(mErr) +} diff --git a/client/internal/engine.go b/client/internal/engine.go index 34219def1..9724e2a22 100644 --- a/client/internal/engine.go +++ b/client/internal/engine.go @@ -4,7 +4,6 @@ import ( "context" "errors" "fmt" - "maps" "math/rand" "net" "net/netip" @@ -30,10 +29,12 @@ import ( "github.com/netbirdio/netbird/client/iface/device" "github.com/netbirdio/netbird/client/internal/acl" "github.com/netbirdio/netbird/client/internal/dns" + "github.com/netbirdio/netbird/client/internal/dnsfwd" "github.com/netbirdio/netbird/client/internal/networkmonitor" "github.com/netbirdio/netbird/client/internal/peer" "github.com/netbirdio/netbird/client/internal/peer/guard" icemaker "github.com/netbirdio/netbird/client/internal/peer/ice" + "github.com/netbirdio/netbird/client/internal/peerstore" "github.com/netbirdio/netbird/client/internal/relay" "github.com/netbirdio/netbird/client/internal/rosenpass" "github.com/netbirdio/netbird/client/internal/routemanager" @@ -117,7 +118,7 @@ type Engine struct { // mgmClient is a Management Service client mgmClient mgm.Client // peerConns is a map that holds all the peers that are known to this peer - peerConns map[string]*peer.Conn + peerStore *peerstore.Store beforePeerHook nbnet.AddHookFunc afterPeerHook nbnet.RemoveHookFunc @@ -137,10 +138,6 @@ type Engine struct { TURNs []*stun.URI stunTurn atomic.Value - // clientRoutes is the most recent list of clientRoutes received from the Management Service - clientRoutes route.HAMap - 
clientRoutesMu sync.RWMutex - clientCtx context.Context clientCancel context.CancelFunc @@ -161,9 +158,10 @@ type Engine struct { statusRecorder *peer.Status - firewall manager.Manager - routeManager routemanager.Manager - acl acl.Manager + firewall manager.Manager + routeManager routemanager.Manager + acl acl.Manager + dnsForwardMgr *dnsfwd.Manager dnsServer dns.Server @@ -234,7 +232,7 @@ func NewEngineWithProbes( signaler: peer.NewSignaler(signalClient, config.WgPrivateKey), mgmClient: mgmClient, relayManager: relayManager, - peerConns: make(map[string]*peer.Conn), + peerStore: peerstore.NewConnStore(), syncMsgMux: &sync.Mutex{}, config: config, mobileDep: mobileDep, @@ -287,6 +285,13 @@ func (e *Engine) Stop() error { e.routeManager.Stop(e.stateManager) } + if e.dnsForwardMgr != nil { + if err := e.dnsForwardMgr.Stop(context.Background()); err != nil { + log.Errorf("failed to stop DNS forward: %v", err) + } + e.dnsForwardMgr = nil + } + if e.srWatcher != nil { e.srWatcher.Close() } @@ -300,10 +305,6 @@ func (e *Engine) Stop() error { return fmt.Errorf("failed to remove all peers: %s", err) } - e.clientRoutesMu.Lock() - e.clientRoutes = nil - e.clientRoutesMu.Unlock() - if e.cancel != nil { e.cancel() } @@ -382,6 +383,8 @@ func (e *Engine) Start() error { e.relayManager, initialRoutes, e.stateManager, + dnsServer, + e.peerStore, ) beforePeerHook, afterPeerHook, err := e.routeManager.Init() if err != nil { @@ -460,8 +463,8 @@ func (e *Engine) modifyPeers(peersUpdate []*mgmProto.RemotePeerConfig) error { var modified []*mgmProto.RemotePeerConfig for _, p := range peersUpdate { peerPubKey := p.GetWgPubKey() - if peerConn, ok := e.peerConns[peerPubKey]; ok { - if peerConn.WgConfig().AllowedIps != strings.Join(p.AllowedIps, ",") { + if allowedIPs, ok := e.peerStore.AllowedIPs(peerPubKey); ok { + if allowedIPs != strings.Join(p.AllowedIps, ",") { modified = append(modified, p) continue } @@ -492,17 +495,12 @@ func (e *Engine) modifyPeers(peersUpdate []*mgmProto.RemotePeerConfig) error { // removePeers finds and removes peers that do not exist anymore in the network map received from the Management Service. // It also removes peers that have been modified (e.g. change of IP address). They will be added again in addPeers method. 
func (e *Engine) removePeers(peersUpdate []*mgmProto.RemotePeerConfig) error { - currentPeers := make([]string, 0, len(e.peerConns)) - for p := range e.peerConns { - currentPeers = append(currentPeers, p) - } - newPeers := make([]string, 0, len(peersUpdate)) for _, p := range peersUpdate { newPeers = append(newPeers, p.GetWgPubKey()) } - toRemove := util.SliceDiff(currentPeers, newPeers) + toRemove := util.SliceDiff(e.peerStore.PeersPubKey(), newPeers) for _, p := range toRemove { err := e.removePeer(p) @@ -516,7 +514,7 @@ func (e *Engine) removePeers(peersUpdate []*mgmProto.RemotePeerConfig) error { func (e *Engine) removeAllPeers() error { log.Debugf("removing all peer connections") - for p := range e.peerConns { + for _, p := range e.peerStore.PeersPubKey() { err := e.removePeer(p) if err != nil { return err @@ -540,9 +538,8 @@ func (e *Engine) removePeer(peerKey string) error { } }() - conn, exists := e.peerConns[peerKey] + conn, exists := e.peerStore.Remove(peerKey) if exists { - delete(e.peerConns, peerKey) conn.Close() } return nil @@ -786,7 +783,6 @@ func (e *Engine) updateTURNs(turns []*mgmProto.ProtectedHostConfig) error { } func (e *Engine) updateNetworkMap(networkMap *mgmProto.NetworkMap) error { - // intentionally leave it before checking serial because for now it can happen that peer IP changed but serial didn't if networkMap.GetPeerConfig() != nil { err := e.updateConfig(networkMap.GetPeerConfig()) @@ -806,20 +802,18 @@ func (e *Engine) updateNetworkMap(networkMap *mgmProto.NetworkMap) error { e.acl.ApplyFiltering(networkMap) } - protoRoutes := networkMap.GetRoutes() - if protoRoutes == nil { - protoRoutes = []*mgmProto.Route{} + var dnsRouteFeatureFlag bool + if networkMap.PeerConfig != nil { + dnsRouteFeatureFlag = networkMap.PeerConfig.RoutingPeerDnsResolutionEnabled } + routedDomains, routes := toRoutes(networkMap.GetRoutes()) - _, clientRoutes, err := e.routeManager.UpdateRoutes(serial, toRoutes(protoRoutes)) - if err != nil { + e.updateDNSForwarder(dnsRouteFeatureFlag, routedDomains) + + if err := e.routeManager.UpdateRoutes(serial, routes, dnsRouteFeatureFlag); err != nil { log.Errorf("failed to update clientRoutes, err: %v", err) } - e.clientRoutesMu.Lock() - e.clientRoutes = clientRoutes - e.clientRoutesMu.Unlock() - log.Debugf("got peers update from Management Service, total peers to connect to = %d", len(networkMap.GetRemotePeers())) e.updateOfflinePeers(networkMap.GetOfflinePeers()) @@ -867,8 +861,7 @@ func (e *Engine) updateNetworkMap(networkMap *mgmProto.NetworkMap) error { protoDNSConfig = &mgmProto.DNSConfig{} } - err = e.dnsServer.UpdateDNSServer(serial, toDNSConfig(protoDNSConfig)) - if err != nil { + if err := e.dnsServer.UpdateDNSServer(serial, toDNSConfig(protoDNSConfig)); err != nil { log.Errorf("failed to update dns server, err: %v", err) } @@ -881,7 +874,12 @@ func (e *Engine) updateNetworkMap(networkMap *mgmProto.NetworkMap) error { return nil } -func toRoutes(protoRoutes []*mgmProto.Route) []*route.Route { +func toRoutes(protoRoutes []*mgmProto.Route) ([]string, []*route.Route) { + if protoRoutes == nil { + protoRoutes = []*mgmProto.Route{} + } + + var dnsRoutes []string routes := make([]*route.Route, 0) for _, protoRoute := range protoRoutes { var prefix netip.Prefix @@ -892,6 +890,8 @@ func toRoutes(protoRoutes []*mgmProto.Route) []*route.Route { continue } } + dnsRoutes = append(dnsRoutes, protoRoute.Domains...) 
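+		// Domains collected here are returned to updateNetworkMap, which hands them
+		// to the DNS forwarder via updateDNSForwarder; the parsed routes below go to
+		// the route manager.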
+ convertedRoute := &route.Route{ ID: route.ID(protoRoute.ID), Network: prefix, @@ -905,7 +905,7 @@ func toRoutes(protoRoutes []*mgmProto.Route) []*route.Route { } routes = append(routes, convertedRoute) } - return routes + return dnsRoutes, routes } func toDNSConfig(protoDNSConfig *mgmProto.DNSConfig) nbdns.Config { @@ -982,12 +982,16 @@ func (e *Engine) addNewPeers(peersUpdate []*mgmProto.RemotePeerConfig) error { func (e *Engine) addNewPeer(peerConfig *mgmProto.RemotePeerConfig) error { peerKey := peerConfig.GetWgPubKey() peerIPs := peerConfig.GetAllowedIps() - if _, ok := e.peerConns[peerKey]; !ok { + if _, ok := e.peerStore.PeerConn(peerKey); !ok { conn, err := e.createPeerConn(peerKey, strings.Join(peerIPs, ",")) if err != nil { return fmt.Errorf("create peer connection: %w", err) } - e.peerConns[peerKey] = conn + + if ok := e.peerStore.AddPeerConn(peerKey, conn); !ok { + conn.Close() + return fmt.Errorf("peer already exists: %s", peerKey) + } if e.beforePeerHook != nil && e.afterPeerHook != nil { conn.AddBeforeAddPeerHook(e.beforePeerHook) @@ -1076,8 +1080,8 @@ func (e *Engine) receiveSignalEvents() { e.syncMsgMux.Lock() defer e.syncMsgMux.Unlock() - conn := e.peerConns[msg.Key] - if conn == nil { + conn, ok := e.peerStore.PeerConn(msg.Key) + if !ok { return fmt.Errorf("wrongly addressed message %s", msg.Key) } @@ -1135,7 +1139,7 @@ func (e *Engine) receiveSignalEvents() { return err } - go conn.OnRemoteCandidate(candidate, e.GetClientRoutes()) + go conn.OnRemoteCandidate(candidate, e.routeManager.GetClientRoutes()) case sProto.Body_MODE: } @@ -1239,7 +1243,7 @@ func (e *Engine) readInitialSettings() ([]*route.Route, *nbdns.Config, error) { if err != nil { return nil, nil, err } - routes := toRoutes(netMap.GetRoutes()) + _, routes := toRoutes(netMap.GetRoutes()) dnsCfg := toDNSConfig(netMap.GetDNSConfig()) return routes, &dnsCfg, nil } @@ -1322,26 +1326,6 @@ func (e *Engine) newDnsServer() ([]*route.Route, dns.Server, error) { } } -// GetClientRoutes returns the current routes from the route map -func (e *Engine) GetClientRoutes() route.HAMap { - e.clientRoutesMu.RLock() - defer e.clientRoutesMu.RUnlock() - - return maps.Clone(e.clientRoutes) -} - -// GetClientRoutesWithNetID returns the current routes from the route map, but the keys consist of the network ID only -func (e *Engine) GetClientRoutesWithNetID() map[route.NetID][]*route.Route { - e.clientRoutesMu.RLock() - defer e.clientRoutesMu.RUnlock() - - routes := make(map[route.NetID][]*route.Route, len(e.clientRoutes)) - for id, v := range e.clientRoutes { - routes[id.NetID()] = v - } - return routes -} - // GetRouteManager returns the route manager func (e *Engine) GetRouteManager() routemanager.Manager { return e.routeManager @@ -1426,9 +1410,8 @@ func (e *Engine) receiveProbeEvents() { go e.probes.WgProbe.Receive(e.ctx, func() bool { log.Debug("received wg probe request") - for _, peer := range e.peerConns { - key := peer.GetKey() - wgStats, err := peer.WgConfig().WgInterface.GetStats(key) + for _, key := range e.peerStore.PeersPubKey() { + wgStats, err := e.wgInterface.GetStats(key) if err != nil { log.Debugf("failed to get wg stats for peer %s: %s", key, err) } @@ -1505,7 +1488,7 @@ func (e *Engine) startNetworkMonitor() { func (e *Engine) addrViaRoutes(addr netip.Addr) (bool, netip.Prefix, error) { var vpnRoutes []netip.Prefix - for _, routes := range e.GetClientRoutes() { + for _, routes := range e.routeManager.GetClientRoutes() { if len(routes) > 0 && routes[0] != nil { vpnRoutes = append(vpnRoutes, routes[0].Network) } 
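With GetClientRoutes removed from the engine, callers read client routes off the route manager instead. A minimal sketch of the lookup pattern used by addrViaRoutes above, assuming GetClientRoutes keeps the route.HAMap shape of the removed engine accessor; vpnPrefixes is a placeholder name:

package example

import (
	"net/netip"

	"github.com/netbirdio/netbird/client/internal/routemanager"
)

// vpnPrefixes collects the primary network of every HA route group currently
// known to the route manager.
func vpnPrefixes(rm routemanager.Manager) []netip.Prefix {
	var prefixes []netip.Prefix
	for _, haRoutes := range rm.GetClientRoutes() {
		if len(haRoutes) > 0 && haRoutes[0] != nil {
			prefixes = append(prefixes, haRoutes[0].Network)
		}
	}
	return prefixes
}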
@@ -1573,6 +1556,40 @@ func (e *Engine) GetLatestNetworkMap() (*mgmProto.NetworkMap, error) { return nm, nil } +// updateDNSForwarder start or stop the DNS forwarder based on the domains and the feature flag +func (e *Engine) updateDNSForwarder(enabled bool, domains []string) { + if !enabled { + if e.dnsForwardMgr == nil { + return + } + if err := e.dnsForwardMgr.Stop(context.Background()); err != nil { + log.Errorf("failed to stop DNS forward: %v", err) + } + return + } + + if len(domains) > 0 { + log.Infof("enable domain router service for domains: %v", domains) + if e.dnsForwardMgr == nil { + e.dnsForwardMgr = dnsfwd.NewManager(e.firewall) + + if err := e.dnsForwardMgr.Start(domains); err != nil { + log.Errorf("failed to start DNS forward: %v", err) + e.dnsForwardMgr = nil + } + } else { + log.Infof("update domain router service for domains: %v", domains) + e.dnsForwardMgr.UpdateDomains(domains) + } + } else if e.dnsForwardMgr != nil { + log.Infof("disable domain router service") + if err := e.dnsForwardMgr.Stop(context.Background()); err != nil { + log.Errorf("failed to stop DNS forward: %v", err) + } + e.dnsForwardMgr = nil + } +} + // isChecksEqual checks if two slices of checks are equal. func isChecksEqual(checks []*mgmProto.Checks, oChecks []*mgmProto.Checks) bool { for _, check := range checks { diff --git a/client/internal/engine_test.go b/client/internal/engine_test.go index b58c1f7e9..b81d8bd3f 100644 --- a/client/internal/engine_test.go +++ b/client/internal/engine_test.go @@ -39,6 +39,8 @@ import ( mgmtProto "github.com/netbirdio/netbird/management/proto" "github.com/netbirdio/netbird/management/server" "github.com/netbirdio/netbird/management/server/activity" + "github.com/netbirdio/netbird/management/server/settings" + "github.com/netbirdio/netbird/management/server/store" "github.com/netbirdio/netbird/management/server/telemetry" relayClient "github.com/netbirdio/netbird/relay/client" "github.com/netbirdio/netbird/route" @@ -251,7 +253,7 @@ func TestEngine_UpdateNetworkMap(t *testing.T) { }, } engine.wgInterface = wgIface - engine.routeManager = routemanager.NewManager(ctx, key.PublicKey().String(), time.Minute, engine.wgInterface, engine.statusRecorder, relayMgr, nil, nil) + engine.routeManager = routemanager.NewManager(ctx, key.PublicKey().String(), time.Minute, engine.wgInterface, engine.statusRecorder, relayMgr, nil, nil, nil, nil) _, _, err = engine.routeManager.Init() require.NoError(t, err) engine.dnsServer = &dns.MockServer{ @@ -391,8 +393,8 @@ func TestEngine_UpdateNetworkMap(t *testing.T) { return } - if len(engine.peerConns) != c.expectedLen { - t.Errorf("expecting Engine.peerConns to be of size %d, got %d", c.expectedLen, len(engine.peerConns)) + if len(engine.peerStore.PeersPubKey()) != c.expectedLen { + t.Errorf("expecting Engine.peerConns to be of size %d, got %d", c.expectedLen, len(engine.peerStore.PeersPubKey())) } if engine.networkSerial != c.expectedSerial { @@ -400,7 +402,7 @@ func TestEngine_UpdateNetworkMap(t *testing.T) { } for _, p := range c.expectedPeers { - conn, ok := engine.peerConns[p.GetWgPubKey()] + conn, ok := engine.peerStore.PeerConn(p.GetWgPubKey()) if !ok { t.Errorf("expecting Engine.peerConns to contain peer %s", p) } @@ -625,10 +627,10 @@ func TestEngine_UpdateNetworkMapWithRoutes(t *testing.T) { }{} mockRouteManager := &routemanager.MockManager{ - UpdateRoutesFunc: func(updateSerial uint64, newRoutes []*route.Route) (map[route.ID]*route.Route, route.HAMap, error) { + UpdateRoutesFunc: func(updateSerial uint64, newRoutes 
[]*route.Route) error { input.inputSerial = updateSerial input.inputRoutes = newRoutes - return nil, nil, testCase.inputErr + return testCase.inputErr }, } @@ -801,8 +803,8 @@ func TestEngine_UpdateNetworkMapWithDNSUpdate(t *testing.T) { assert.NoError(t, err, "shouldn't return error") mockRouteManager := &routemanager.MockManager{ - UpdateRoutesFunc: func(updateSerial uint64, newRoutes []*route.Route) (map[route.ID]*route.Route, route.HAMap, error) { - return nil, nil, nil + UpdateRoutesFunc: func(updateSerial uint64, newRoutes []*route.Route) error { + return nil }, } @@ -1196,7 +1198,7 @@ func startManagement(t *testing.T, dataDir, testFile string) (*grpc.Server, stri } s := grpc.NewServer(grpc.KeepaliveEnforcementPolicy(kaep), grpc.KeepaliveParams(kasp)) - store, cleanUp, err := server.NewTestStoreFromSQL(context.Background(), testFile, config.Datadir) + store, cleanUp, err := store.NewTestStoreFromSQL(context.Background(), testFile, config.Datadir) if err != nil { return nil, "", err } @@ -1218,7 +1220,7 @@ func startManagement(t *testing.T, dataDir, testFile string) (*grpc.Server, stri } secretsManager := server.NewTimeBasedAuthSecretsManager(peersUpdateManager, config.TURNConfig, config.Relay) - mgmtServer, err := server.NewServer(context.Background(), config, accountManager, peersUpdateManager, secretsManager, nil, nil) + mgmtServer, err := server.NewServer(context.Background(), config, accountManager, settings.NewManager(store), peersUpdateManager, secretsManager, nil, nil) if err != nil { return nil, "", err } @@ -1237,7 +1239,8 @@ func getConnectedPeers(e *Engine) int { e.syncMsgMux.Lock() defer e.syncMsgMux.Unlock() i := 0 - for _, conn := range e.peerConns { + for _, id := range e.peerStore.PeersPubKey() { + conn, _ := e.peerStore.PeerConn(id) if conn.Status() == peer.StatusConnected { i++ } @@ -1249,5 +1252,5 @@ func getPeers(e *Engine) int { e.syncMsgMux.Lock() defer e.syncMsgMux.Unlock() - return len(e.peerConns) + return len(e.peerStore.PeersPubKey()) } diff --git a/client/internal/peer/conn.go b/client/internal/peer/conn.go index 5c2e2cb60..b8cb2582f 100644 --- a/client/internal/peer/conn.go +++ b/client/internal/peer/conn.go @@ -747,6 +747,11 @@ func (conn *Conn) setRelayedProxy(proxy wgproxy.Proxy) { conn.wgProxyRelay = proxy } +// AllowedIP returns the allowed IP of the remote peer +func (conn *Conn) AllowedIP() net.IP { + return conn.allowedIP +} + func isController(config ConnConfig) bool { return config.LocalKey > config.Key } diff --git a/client/internal/peer/status.go b/client/internal/peer/status.go index 74e2ee82c..dc461257a 100644 --- a/client/internal/peer/status.go +++ b/client/internal/peer/status.go @@ -17,6 +17,11 @@ import ( relayClient "github.com/netbirdio/netbird/relay/client" ) +type ResolvedDomainInfo struct { + Prefixes []netip.Prefix + ParentDomain domain.Domain +} + // State contains the latest state of a peer type State struct { Mux *sync.RWMutex @@ -138,7 +143,7 @@ type Status struct { rosenpassEnabled bool rosenpassPermissive bool nsGroupStates []NSGroupState - resolvedDomainsStates map[domain.Domain][]netip.Prefix + resolvedDomainsStates map[domain.Domain]ResolvedDomainInfo // To reduce the number of notification invocation this bool will be true when need to call the notification // Some Peer actions mostly used by in a batch when the network map has been synchronized. 
In these type of events @@ -156,7 +161,7 @@ func NewRecorder(mgmAddress string) *Status { offlinePeers: make([]State, 0), notifier: newNotifier(), mgmAddress: mgmAddress, - resolvedDomainsStates: make(map[domain.Domain][]netip.Prefix), + resolvedDomainsStates: map[domain.Domain]ResolvedDomainInfo{}, } } @@ -591,16 +596,27 @@ func (d *Status) UpdateDNSStates(dnsStates []NSGroupState) { d.nsGroupStates = dnsStates } -func (d *Status) UpdateResolvedDomainsStates(domain domain.Domain, prefixes []netip.Prefix) { +func (d *Status) UpdateResolvedDomainsStates(originalDomain domain.Domain, resolvedDomain domain.Domain, prefixes []netip.Prefix) { d.mux.Lock() defer d.mux.Unlock() - d.resolvedDomainsStates[domain] = prefixes + + // Store both the original domain pattern and resolved domain + d.resolvedDomainsStates[resolvedDomain] = ResolvedDomainInfo{ + Prefixes: prefixes, + ParentDomain: originalDomain, + } } func (d *Status) DeleteResolvedDomainsStates(domain domain.Domain) { d.mux.Lock() defer d.mux.Unlock() - delete(d.resolvedDomainsStates, domain) + + // Remove all entries that have this domain as their parent + for k, v := range d.resolvedDomainsStates { + if v.ParentDomain == domain { + delete(d.resolvedDomainsStates, k) + } + } } func (d *Status) GetRosenpassState() RosenpassState { @@ -702,7 +718,7 @@ func (d *Status) GetDNSStates() []NSGroupState { return d.nsGroupStates } -func (d *Status) GetResolvedDomainsStates() map[domain.Domain][]netip.Prefix { +func (d *Status) GetResolvedDomainsStates() map[domain.Domain]ResolvedDomainInfo { d.mux.Lock() defer d.mux.Unlock() return maps.Clone(d.resolvedDomainsStates) diff --git a/client/internal/peerstore/store.go b/client/internal/peerstore/store.go new file mode 100644 index 000000000..6b3385ff5 --- /dev/null +++ b/client/internal/peerstore/store.go @@ -0,0 +1,87 @@ +package peerstore + +import ( + "net" + "sync" + + "golang.org/x/exp/maps" + + "github.com/netbirdio/netbird/client/internal/peer" +) + +// Store is a thread-safe store for peer connections. 
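+// It replaces the Engine's peerConns map: the engine, the route manager, and the
+// DNS route handlers all look up peer connections (and their allowed IPs) through
+// this store instead of reaching into engine-owned state.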
+type Store struct { + peerConns map[string]*peer.Conn + peerConnsMu sync.RWMutex +} + +func NewConnStore() *Store { + return &Store{ + peerConns: make(map[string]*peer.Conn), + } +} + +func (s *Store) AddPeerConn(pubKey string, conn *peer.Conn) bool { + s.peerConnsMu.Lock() + defer s.peerConnsMu.Unlock() + + _, ok := s.peerConns[pubKey] + if ok { + return false + } + + s.peerConns[pubKey] = conn + return true +} + +func (s *Store) Remove(pubKey string) (*peer.Conn, bool) { + s.peerConnsMu.Lock() + defer s.peerConnsMu.Unlock() + + p, ok := s.peerConns[pubKey] + if !ok { + return nil, false + } + delete(s.peerConns, pubKey) + return p, true +} + +func (s *Store) AllowedIPs(pubKey string) (string, bool) { + s.peerConnsMu.RLock() + defer s.peerConnsMu.RUnlock() + + p, ok := s.peerConns[pubKey] + if !ok { + return "", false + } + return p.WgConfig().AllowedIps, true +} + +func (s *Store) AllowedIP(pubKey string) (net.IP, bool) { + s.peerConnsMu.RLock() + defer s.peerConnsMu.RUnlock() + + p, ok := s.peerConns[pubKey] + if !ok { + return nil, false + } + return p.AllowedIP(), true +} + +func (s *Store) PeerConn(pubKey string) (*peer.Conn, bool) { + s.peerConnsMu.RLock() + defer s.peerConnsMu.RUnlock() + + p, ok := s.peerConns[pubKey] + if !ok { + return nil, false + } + return p, true +} + +func (s *Store) PeersPubKey() []string { + s.peerConnsMu.RLock() + defer s.peerConnsMu.RUnlock() + + return maps.Keys(s.peerConns) +} diff --git a/client/internal/routemanager/client.go b/client/internal/routemanager/client.go index 13e45b3a3..73f552aab 100644 --- a/client/internal/routemanager/client.go +++ b/client/internal/routemanager/client.go @@ -13,12 +13,20 @@ import ( "github.com/netbirdio/netbird/client/iface" nbdns "github.com/netbirdio/netbird/client/internal/dns" "github.com/netbirdio/netbird/client/internal/peer" + "github.com/netbirdio/netbird/client/internal/peerstore" + "github.com/netbirdio/netbird/client/internal/routemanager/dnsinterceptor" "github.com/netbirdio/netbird/client/internal/routemanager/dynamic" "github.com/netbirdio/netbird/client/internal/routemanager/refcounter" "github.com/netbirdio/netbird/client/internal/routemanager/static" "github.com/netbirdio/netbird/route" ) +const ( + handlerTypeDynamic = iota + handlerTypeDomain + handlerTypeStatic +) + type routerPeerStatus struct { connected bool relayed bool @@ -53,7 +61,18 @@ type clientNetwork struct { updateSerial uint64 } -func newClientNetworkWatcher(ctx context.Context, dnsRouteInterval time.Duration, wgInterface iface.IWGIface, statusRecorder *peer.Status, rt *route.Route, routeRefCounter *refcounter.RouteRefCounter, allowedIPsRefCounter *refcounter.AllowedIPsRefCounter) *clientNetwork { +func newClientNetworkWatcher( + ctx context.Context, + dnsRouteInterval time.Duration, + wgInterface iface.IWGIface, + statusRecorder *peer.Status, + rt *route.Route, + routeRefCounter *refcounter.RouteRefCounter, + allowedIPsRefCounter *refcounter.AllowedIPsRefCounter, + dnsServer nbdns.Server, + peerStore *peerstore.Store, + useNewDNSRoute bool, +) *clientNetwork { ctx, cancel := context.WithCancel(ctx) client := &clientNetwork{ @@ -65,7 +84,17 @@ func newClientNetworkWatcher(ctx context.Context, dnsRouteInterval time.Duration routePeersNotifiers: make(map[string]chan struct{}), routeUpdate: make(chan routesUpdate), peerStateUpdate: make(chan struct{}), - handler: handlerFromRoute(rt, routeRefCounter, allowedIPsRefCounter, dnsRouteInterval, statusRecorder, wgInterface), + handler: handlerFromRoute( + rt, + routeRefCounter, + 
allowedIPsRefCounter, + dnsRouteInterval, + statusRecorder, + wgInterface, + dnsServer, + peerStore, + useNewDNSRoute, + ), } return client } @@ -368,10 +397,50 @@ func (c *clientNetwork) peersStateAndUpdateWatcher() { } } -func handlerFromRoute(rt *route.Route, routeRefCounter *refcounter.RouteRefCounter, allowedIPsRefCounter *refcounter.AllowedIPsRefCounter, dnsRouterInteval time.Duration, statusRecorder *peer.Status, wgInterface iface.IWGIface) RouteHandler { - if rt.IsDynamic() { +func handlerFromRoute( + rt *route.Route, + routeRefCounter *refcounter.RouteRefCounter, + allowedIPsRefCounter *refcounter.AllowedIPsRefCounter, + dnsRouterInteval time.Duration, + statusRecorder *peer.Status, + wgInterface iface.IWGIface, + dnsServer nbdns.Server, + peerStore *peerstore.Store, + useNewDNSRoute bool, +) RouteHandler { + switch handlerType(rt, useNewDNSRoute) { + case handlerTypeDomain: + return dnsinterceptor.New( + rt, + routeRefCounter, + allowedIPsRefCounter, + statusRecorder, + dnsServer, + peerStore, + ) + case handlerTypeDynamic: dns := nbdns.NewServiceViaMemory(wgInterface) - return dynamic.NewRoute(rt, routeRefCounter, allowedIPsRefCounter, dnsRouterInteval, statusRecorder, wgInterface, fmt.Sprintf("%s:%d", dns.RuntimeIP(), dns.RuntimePort())) + return dynamic.NewRoute( + rt, + routeRefCounter, + allowedIPsRefCounter, + dnsRouterInteval, + statusRecorder, + wgInterface, + fmt.Sprintf("%s:%d", dns.RuntimeIP(), dns.RuntimePort()), + ) + default: + return static.NewRoute(rt, routeRefCounter, allowedIPsRefCounter) } - return static.NewRoute(rt, routeRefCounter, allowedIPsRefCounter) +} + +func handlerType(rt *route.Route, useNewDNSRoute bool) int { + if !rt.IsDynamic() { + return handlerTypeStatic + } + + if useNewDNSRoute { + return handlerTypeDomain + } + return handlerTypeDynamic } diff --git a/client/internal/routemanager/dnsinterceptor/handler.go b/client/internal/routemanager/dnsinterceptor/handler.go new file mode 100644 index 000000000..10cb03f1d --- /dev/null +++ b/client/internal/routemanager/dnsinterceptor/handler.go @@ -0,0 +1,356 @@ +package dnsinterceptor + +import ( + "context" + "fmt" + "net" + "net/netip" + "strings" + "sync" + "time" + + "github.com/hashicorp/go-multierror" + "github.com/miekg/dns" + log "github.com/sirupsen/logrus" + + nberrors "github.com/netbirdio/netbird/client/errors" + nbdns "github.com/netbirdio/netbird/client/internal/dns" + "github.com/netbirdio/netbird/client/internal/dnsfwd" + "github.com/netbirdio/netbird/client/internal/peer" + "github.com/netbirdio/netbird/client/internal/peerstore" + "github.com/netbirdio/netbird/client/internal/routemanager/refcounter" + "github.com/netbirdio/netbird/management/domain" + "github.com/netbirdio/netbird/route" +) + +type domainMap map[domain.Domain][]netip.Prefix + +type DnsInterceptor struct { + mu sync.RWMutex + route *route.Route + routeRefCounter *refcounter.RouteRefCounter + allowedIPsRefcounter *refcounter.AllowedIPsRefCounter + statusRecorder *peer.Status + dnsServer nbdns.Server + currentPeerKey string + interceptedDomains domainMap + peerStore *peerstore.Store +} + +func New( + rt *route.Route, + routeRefCounter *refcounter.RouteRefCounter, + allowedIPsRefCounter *refcounter.AllowedIPsRefCounter, + statusRecorder *peer.Status, + dnsServer nbdns.Server, + peerStore *peerstore.Store, +) *DnsInterceptor { + return &DnsInterceptor{ + route: rt, + routeRefCounter: routeRefCounter, + allowedIPsRefcounter: allowedIPsRefCounter, + statusRecorder: statusRecorder, + dnsServer: dnsServer, + 
interceptedDomains: make(domainMap), + peerStore: peerStore, + } +} + +func (d *DnsInterceptor) String() string { + return d.route.Domains.SafeString() +} + +func (d *DnsInterceptor) AddRoute(context.Context) error { + d.dnsServer.RegisterHandler(d.route.Domains.ToPunycodeList(), d, nbdns.PriorityDNSRoute) + return nil +} + +func (d *DnsInterceptor) RemoveRoute() error { + d.mu.Lock() + + var merr *multierror.Error + for domain, prefixes := range d.interceptedDomains { + for _, prefix := range prefixes { + if _, err := d.routeRefCounter.Decrement(prefix); err != nil { + merr = multierror.Append(merr, fmt.Errorf("remove dynamic route for IP %s: %v", prefix, err)) + } + if d.currentPeerKey != "" { + if _, err := d.allowedIPsRefcounter.Decrement(prefix); err != nil { + merr = multierror.Append(merr, fmt.Errorf("remove allowed IP %s: %v", prefix, err)) + } + } + } + log.Debugf("removed dynamic route(s) for [%s]: %s", domain.SafeString(), strings.ReplaceAll(fmt.Sprintf("%s", prefixes), " ", ", ")) + + } + for _, domain := range d.route.Domains { + d.statusRecorder.DeleteResolvedDomainsStates(domain) + } + + clear(d.interceptedDomains) + d.mu.Unlock() + + d.dnsServer.DeregisterHandler(d.route.Domains.ToPunycodeList(), nbdns.PriorityDNSRoute) + + return nberrors.FormatErrorOrNil(merr) +} + +func (d *DnsInterceptor) AddAllowedIPs(peerKey string) error { + d.mu.Lock() + defer d.mu.Unlock() + + var merr *multierror.Error + for domain, prefixes := range d.interceptedDomains { + for _, prefix := range prefixes { + if ref, err := d.allowedIPsRefcounter.Increment(prefix, peerKey); err != nil { + merr = multierror.Append(merr, fmt.Errorf("add allowed IP %s: %v", prefix, err)) + } else if ref.Count > 1 && ref.Out != peerKey { + log.Warnf("IP [%s] for domain [%s] is already routed by peer [%s]. 
HA routing disabled", + prefix.Addr(), + domain.SafeString(), + ref.Out, + ) + } + } + } + + d.currentPeerKey = peerKey + return nberrors.FormatErrorOrNil(merr) +} + +func (d *DnsInterceptor) RemoveAllowedIPs() error { + d.mu.Lock() + defer d.mu.Unlock() + + var merr *multierror.Error + for _, prefixes := range d.interceptedDomains { + for _, prefix := range prefixes { + if _, err := d.allowedIPsRefcounter.Decrement(prefix); err != nil { + merr = multierror.Append(merr, fmt.Errorf("remove allowed IP %s: %v", prefix, err)) + } + } + } + + d.currentPeerKey = "" + return nberrors.FormatErrorOrNil(merr) +} + +// ServeDNS implements the dns.Handler interface +func (d *DnsInterceptor) ServeDNS(w dns.ResponseWriter, r *dns.Msg) { + if len(r.Question) == 0 { + return + } + log.Tracef("received DNS request for domain=%s type=%v class=%v", + r.Question[0].Name, r.Question[0].Qtype, r.Question[0].Qclass) + + d.mu.RLock() + peerKey := d.currentPeerKey + d.mu.RUnlock() + + if peerKey == "" { + log.Tracef("no current peer key set, letting next handler try for domain=%s", r.Question[0].Name) + + d.continueToNextHandler(w, r, "no current peer key") + return + } + + upstreamIP, err := d.getUpstreamIP(peerKey) + if err != nil { + log.Errorf("failed to get upstream IP: %v", err) + d.continueToNextHandler(w, r, fmt.Sprintf("failed to get upstream IP: %v", err)) + return + } + + client := &dns.Client{ + Timeout: 5 * time.Second, + Net: "udp", + } + upstream := fmt.Sprintf("%s:%d", upstreamIP, dnsfwd.ListenPort) + reply, _, err := client.ExchangeContext(context.Background(), r, upstream) + + var answer []dns.RR + if reply != nil { + answer = reply.Answer + } + log.Tracef("upstream %s (%s) DNS response for domain=%s answers=%v", upstreamIP, peerKey, r.Question[0].Name, answer) + + if err != nil { + log.Errorf("failed to exchange DNS request with %s: %v", upstream, err) + if err := w.WriteMsg(&dns.Msg{MsgHdr: dns.MsgHdr{Rcode: dns.RcodeServerFailure, Id: r.Id}}); err != nil { + log.Errorf("failed writing DNS response: %v", err) + } + return + } + + reply.Id = r.Id + if err := d.writeMsg(w, reply); err != nil { + log.Errorf("failed writing DNS response: %v", err) + } +} + +// continueToNextHandler signals the handler chain to try the next handler +func (d *DnsInterceptor) continueToNextHandler(w dns.ResponseWriter, r *dns.Msg, reason string) { + log.Tracef("continuing to next handler for domain=%s reason=%s", r.Question[0].Name, reason) + + resp := new(dns.Msg) + resp.SetRcode(r, dns.RcodeNameError) + // Set Zero bit to signal handler chain to continue + resp.MsgHdr.Zero = true + if err := w.WriteMsg(resp); err != nil { + log.Errorf("failed writing DNS continue response: %v", err) + } +} + +func (d *DnsInterceptor) getUpstreamIP(peerKey string) (net.IP, error) { + peerAllowedIP, exists := d.peerStore.AllowedIP(peerKey) + if !exists { + return nil, fmt.Errorf("peer connection not found for key: %s", peerKey) + } + return peerAllowedIP, nil +} + +func (d *DnsInterceptor) writeMsg(w dns.ResponseWriter, r *dns.Msg) error { + if r == nil { + return fmt.Errorf("received nil DNS message") + } + + if len(r.Answer) > 0 && len(r.Question) > 0 { + origPattern := "" + if writer, ok := w.(*nbdns.ResponseWriterChain); ok { + origPattern = writer.GetOrigPattern() + } + + resolvedDomain := domain.Domain(r.Question[0].Name) + + // already punycode via RegisterHandler() + originalDomain := domain.Domain(origPattern) + if originalDomain == "" { + originalDomain = resolvedDomain + } + + var newPrefixes []netip.Prefix + for _, answer 
:= range r.Answer { + var ip netip.Addr + switch rr := answer.(type) { + case *dns.A: + addr, ok := netip.AddrFromSlice(rr.A) + if !ok { + log.Tracef("failed to convert A record for domain=%s ip=%v", resolvedDomain, rr.A) + continue + } + ip = addr + case *dns.AAAA: + addr, ok := netip.AddrFromSlice(rr.AAAA) + if !ok { + log.Tracef("failed to convert AAAA record for domain=%s ip=%v", resolvedDomain, rr.AAAA) + continue + } + ip = addr + default: + continue + } + + prefix := netip.PrefixFrom(ip, ip.BitLen()) + newPrefixes = append(newPrefixes, prefix) + } + + if len(newPrefixes) > 0 { + if err := d.updateDomainPrefixes(resolvedDomain, originalDomain, newPrefixes); err != nil { + log.Errorf("failed to update domain prefixes: %v", err) + } + } + } + + if err := w.WriteMsg(r); err != nil { + return fmt.Errorf("failed to write DNS response: %v", err) + } + + return nil +} + +func (d *DnsInterceptor) updateDomainPrefixes(resolvedDomain, originalDomain domain.Domain, newPrefixes []netip.Prefix) error { + d.mu.Lock() + defer d.mu.Unlock() + + oldPrefixes := d.interceptedDomains[resolvedDomain] + toAdd, toRemove := determinePrefixChanges(oldPrefixes, newPrefixes) + + var merr *multierror.Error + + // Add new prefixes + for _, prefix := range toAdd { + if _, err := d.routeRefCounter.Increment(prefix, struct{}{}); err != nil { + merr = multierror.Append(merr, fmt.Errorf("add route for IP %s: %v", prefix, err)) + continue + } + + if d.currentPeerKey == "" { + continue + } + if ref, err := d.allowedIPsRefcounter.Increment(prefix, d.currentPeerKey); err != nil { + merr = multierror.Append(merr, fmt.Errorf("add allowed IP %s: %v", prefix, err)) + } else if ref.Count > 1 && ref.Out != d.currentPeerKey { + log.Warnf("IP [%s] for domain [%s] is already routed by peer [%s]. 
HA routing disabled", + prefix.Addr(), + resolvedDomain.SafeString(), + ref.Out, + ) + } + } + + if !d.route.KeepRoute { + // Remove old prefixes + for _, prefix := range toRemove { + if _, err := d.routeRefCounter.Decrement(prefix); err != nil { + merr = multierror.Append(merr, fmt.Errorf("remove route for IP %s: %v", prefix, err)) + } + if d.currentPeerKey != "" { + if _, err := d.allowedIPsRefcounter.Decrement(prefix); err != nil { + merr = multierror.Append(merr, fmt.Errorf("remove allowed IP %s: %v", prefix, err)) + } + } + } + } + + // Update domain prefixes using resolved domain as key + if len(toAdd) > 0 || len(toRemove) > 0 { + d.interceptedDomains[resolvedDomain] = newPrefixes + originalDomain = domain.Domain(strings.TrimSuffix(string(originalDomain), ".")) + d.statusRecorder.UpdateResolvedDomainsStates(originalDomain, resolvedDomain, newPrefixes) + + if len(toAdd) > 0 { + log.Debugf("added dynamic route(s) for domain=%s (pattern: domain=%s): %s", + resolvedDomain.SafeString(), + originalDomain.SafeString(), + toAdd) + } + if len(toRemove) > 0 { + log.Debugf("removed dynamic route(s) for domain=%s (pattern: domain=%s): %s", + resolvedDomain.SafeString(), + originalDomain.SafeString(), + toRemove) + } + } + + return nberrors.FormatErrorOrNil(merr) +} + +func determinePrefixChanges(oldPrefixes, newPrefixes []netip.Prefix) (toAdd, toRemove []netip.Prefix) { + prefixSet := make(map[netip.Prefix]bool) + for _, prefix := range oldPrefixes { + prefixSet[prefix] = false + } + for _, prefix := range newPrefixes { + if _, exists := prefixSet[prefix]; exists { + prefixSet[prefix] = true + } else { + toAdd = append(toAdd, prefix) + } + } + for prefix, inUse := range prefixSet { + if !inUse { + toRemove = append(toRemove, prefix) + } + } + return +} diff --git a/client/internal/routemanager/dynamic/route.go b/client/internal/routemanager/dynamic/route.go index ac94d4a5c..a0fff7713 100644 --- a/client/internal/routemanager/dynamic/route.go +++ b/client/internal/routemanager/dynamic/route.go @@ -74,11 +74,7 @@ func NewRoute( } func (r *Route) String() string { - s, err := r.route.Domains.String() - if err != nil { - return r.route.Domains.PunycodeString() - } - return s + return r.route.Domains.SafeString() } func (r *Route) AddRoute(ctx context.Context) error { @@ -292,7 +288,7 @@ func (r *Route) updateDynamicRoutes(ctx context.Context, newDomains domainMap) e updatedPrefixes := combinePrefixes(oldPrefixes, removedPrefixes, addedPrefixes) r.dynamicDomains[domain] = updatedPrefixes - r.statusRecorder.UpdateResolvedDomainsStates(domain, updatedPrefixes) + r.statusRecorder.UpdateResolvedDomainsStates(domain, domain, updatedPrefixes) } return nberrors.FormatErrorOrNil(merr) diff --git a/client/internal/routemanager/manager.go b/client/internal/routemanager/manager.go index 8bf3a91b0..389e97e2d 100644 --- a/client/internal/routemanager/manager.go +++ b/client/internal/routemanager/manager.go @@ -12,12 +12,15 @@ import ( "time" log "github.com/sirupsen/logrus" + "golang.org/x/exp/maps" firewall "github.com/netbirdio/netbird/client/firewall/manager" "github.com/netbirdio/netbird/client/iface" "github.com/netbirdio/netbird/client/iface/configurer" + "github.com/netbirdio/netbird/client/internal/dns" "github.com/netbirdio/netbird/client/internal/listener" "github.com/netbirdio/netbird/client/internal/peer" + "github.com/netbirdio/netbird/client/internal/peerstore" "github.com/netbirdio/netbird/client/internal/routemanager/notifier" "github.com/netbirdio/netbird/client/internal/routemanager/refcounter" 
"github.com/netbirdio/netbird/client/internal/routemanager/systemops" @@ -33,9 +36,11 @@ import ( // Manager is a route manager interface type Manager interface { Init() (nbnet.AddHookFunc, nbnet.RemoveHookFunc, error) - UpdateRoutes(updateSerial uint64, newRoutes []*route.Route) (map[route.ID]*route.Route, route.HAMap, error) + UpdateRoutes(updateSerial uint64, newRoutes []*route.Route, useNewDNSRoute bool) error TriggerSelection(route.HAMap) GetRouteSelector() *routeselector.RouteSelector + GetClientRoutes() route.HAMap + GetClientRoutesWithNetID() map[route.NetID][]*route.Route SetRouteChangeListener(listener listener.NetworkChangeListener) InitialRouteRange() []string EnableServerRouter(firewall firewall.Manager) error @@ -60,6 +65,11 @@ type DefaultManager struct { allowedIPsRefCounter *refcounter.AllowedIPsRefCounter dnsRouteInterval time.Duration stateManager *statemanager.Manager + // clientRoutes is the most recent list of clientRoutes received from the Management Service + clientRoutes route.HAMap + dnsServer dns.Server + peerStore *peerstore.Store + useNewDNSRoute bool } func NewManager( @@ -71,6 +81,8 @@ func NewManager( relayMgr *relayClient.Manager, initialRoutes []*route.Route, stateManager *statemanager.Manager, + dnsServer dns.Server, + peerStore *peerstore.Store, ) *DefaultManager { mCTX, cancel := context.WithCancel(ctx) notifier := notifier.NewNotifier() @@ -88,6 +100,8 @@ func NewManager( pubKey: pubKey, notifier: notifier, stateManager: stateManager, + dnsServer: dnsServer, + peerStore: peerStore, } dm.routeRefCounter = refcounter.New( @@ -116,7 +130,7 @@ func NewManager( ) if runtime.GOOS == "android" { - cr := dm.clientRoutes(initialRoutes) + cr := dm.initialClientRoutes(initialRoutes) dm.notifier.SetInitialClientRoutes(cr) } return dm @@ -207,33 +221,41 @@ func (m *DefaultManager) Stop(stateManager *statemanager.Manager) { } m.ctx = nil + + m.mux.Lock() + defer m.mux.Unlock() + m.clientRoutes = nil } // UpdateRoutes compares received routes with existing routes and removes, updates or adds them to the client and server maps -func (m *DefaultManager) UpdateRoutes(updateSerial uint64, newRoutes []*route.Route) (map[route.ID]*route.Route, route.HAMap, error) { +func (m *DefaultManager) UpdateRoutes(updateSerial uint64, newRoutes []*route.Route, useNewDNSRoute bool) error { select { case <-m.ctx.Done(): log.Infof("not updating routes as context is closed") - return nil, nil, m.ctx.Err() + return nil default: - m.mux.Lock() - defer m.mux.Unlock() - - newServerRoutesMap, newClientRoutesIDMap := m.classifyRoutes(newRoutes) - - filteredClientRoutes := m.routeSelector.FilterSelected(newClientRoutesIDMap) - m.updateClientNetworks(updateSerial, filteredClientRoutes) - m.notifier.OnNewRoutes(filteredClientRoutes) - - if m.serverRouter != nil { - err := m.serverRouter.updateRoutes(newServerRoutesMap) - if err != nil { - return nil, nil, fmt.Errorf("update routes: %w", err) - } - } - - return newServerRoutesMap, newClientRoutesIDMap, nil } + + m.mux.Lock() + defer m.mux.Unlock() + m.useNewDNSRoute = useNewDNSRoute + + newServerRoutesMap, newClientRoutesIDMap := m.classifyRoutes(newRoutes) + + filteredClientRoutes := m.routeSelector.FilterSelected(newClientRoutesIDMap) + m.updateClientNetworks(updateSerial, filteredClientRoutes) + m.notifier.OnNewRoutes(filteredClientRoutes) + + if m.serverRouter != nil { + err := m.serverRouter.updateRoutes(newServerRoutesMap) + if err != nil { + return err + } + } + + m.clientRoutes = newClientRoutesIDMap + + return nil } // 
SetRouteChangeListener set RouteListener for route change Notifier @@ -251,9 +273,24 @@ func (m *DefaultManager) GetRouteSelector() *routeselector.RouteSelector { return m.routeSelector } -// GetClientRoutes returns the client routes -func (m *DefaultManager) GetClientRoutes() map[route.HAUniqueID]*clientNetwork { - return m.clientNetworks +// GetClientRoutes returns most recent list of clientRoutes received from the Management Service +func (m *DefaultManager) GetClientRoutes() route.HAMap { + m.mux.Lock() + defer m.mux.Unlock() + + return maps.Clone(m.clientRoutes) +} + +// GetClientRoutesWithNetID returns the current routes from the route map, but the keys consist of the network ID only +func (m *DefaultManager) GetClientRoutesWithNetID() map[route.NetID][]*route.Route { + m.mux.Lock() + defer m.mux.Unlock() + + routes := make(map[route.NetID][]*route.Route, len(m.clientRoutes)) + for id, v := range m.clientRoutes { + routes[id.NetID()] = v + } + return routes } // TriggerSelection triggers the selection of routes, stopping deselected watchers and starting newly selected ones @@ -273,7 +310,18 @@ func (m *DefaultManager) TriggerSelection(networks route.HAMap) { continue } - clientNetworkWatcher := newClientNetworkWatcher(m.ctx, m.dnsRouteInterval, m.wgInterface, m.statusRecorder, routes[0], m.routeRefCounter, m.allowedIPsRefCounter) + clientNetworkWatcher := newClientNetworkWatcher( + m.ctx, + m.dnsRouteInterval, + m.wgInterface, + m.statusRecorder, + routes[0], + m.routeRefCounter, + m.allowedIPsRefCounter, + m.dnsServer, + m.peerStore, + m.useNewDNSRoute, + ) m.clientNetworks[id] = clientNetworkWatcher go clientNetworkWatcher.peersStateAndUpdateWatcher() clientNetworkWatcher.sendUpdateToClientNetworkWatcher(routesUpdate{routes: routes}) @@ -302,7 +350,18 @@ func (m *DefaultManager) updateClientNetworks(updateSerial uint64, networks rout for id, routes := range networks { clientNetworkWatcher, found := m.clientNetworks[id] if !found { - clientNetworkWatcher = newClientNetworkWatcher(m.ctx, m.dnsRouteInterval, m.wgInterface, m.statusRecorder, routes[0], m.routeRefCounter, m.allowedIPsRefCounter) + clientNetworkWatcher = newClientNetworkWatcher( + m.ctx, + m.dnsRouteInterval, + m.wgInterface, + m.statusRecorder, + routes[0], + m.routeRefCounter, + m.allowedIPsRefCounter, + m.dnsServer, + m.peerStore, + m.useNewDNSRoute, + ) m.clientNetworks[id] = clientNetworkWatcher go clientNetworkWatcher.peersStateAndUpdateWatcher() } @@ -345,7 +404,7 @@ func (m *DefaultManager) classifyRoutes(newRoutes []*route.Route) (map[route.ID] return newServerRoutesMap, newClientRoutesIDMap } -func (m *DefaultManager) clientRoutes(initialRoutes []*route.Route) []*route.Route { +func (m *DefaultManager) initialClientRoutes(initialRoutes []*route.Route) []*route.Route { _, crMap := m.classifyRoutes(initialRoutes) rs := make([]*route.Route, 0, len(crMap)) for _, routes := range crMap { diff --git a/client/internal/routemanager/manager_test.go b/client/internal/routemanager/manager_test.go index 07dac21b8..4b7c984e5 100644 --- a/client/internal/routemanager/manager_test.go +++ b/client/internal/routemanager/manager_test.go @@ -424,7 +424,7 @@ func TestManagerUpdateRoutes(t *testing.T) { statusRecorder := peer.NewRecorder("https://mgm") ctx := context.TODO() - routeManager := NewManager(ctx, localPeerKey, 0, wgInterface, statusRecorder, nil, nil, nil) + routeManager := NewManager(ctx, localPeerKey, 0, wgInterface, statusRecorder, nil, nil, nil, nil, nil) _, _, err = routeManager.Init() @@ -436,11 +436,11 @@ func 
TestManagerUpdateRoutes(t *testing.T) { } if len(testCase.inputInitRoutes) > 0 { - _, _, err = routeManager.UpdateRoutes(testCase.inputSerial, testCase.inputRoutes) + err = routeManager.UpdateRoutes(testCase.inputSerial, testCase.inputRoutes, false) require.NoError(t, err, "should update routes with init routes") } - _, _, err = routeManager.UpdateRoutes(testCase.inputSerial+uint64(len(testCase.inputInitRoutes)), testCase.inputRoutes) + err = routeManager.UpdateRoutes(testCase.inputSerial+uint64(len(testCase.inputInitRoutes)), testCase.inputRoutes, false) require.NoError(t, err, "should update routes") expectedWatchers := testCase.clientNetworkWatchersExpected diff --git a/client/internal/routemanager/mock.go b/client/internal/routemanager/mock.go index 556a62351..64fdffceb 100644 --- a/client/internal/routemanager/mock.go +++ b/client/internal/routemanager/mock.go @@ -2,7 +2,6 @@ package routemanager import ( "context" - "fmt" firewall "github.com/netbirdio/netbird/client/firewall/manager" "github.com/netbirdio/netbird/client/iface" @@ -15,10 +14,12 @@ import ( // MockManager is the mock instance of a route manager type MockManager struct { - UpdateRoutesFunc func(updateSerial uint64, newRoutes []*route.Route) (map[route.ID]*route.Route, route.HAMap, error) - TriggerSelectionFunc func(haMap route.HAMap) - GetRouteSelectorFunc func() *routeselector.RouteSelector - StopFunc func(manager *statemanager.Manager) + UpdateRoutesFunc func(updateSerial uint64, newRoutes []*route.Route) error + TriggerSelectionFunc func(haMap route.HAMap) + GetRouteSelectorFunc func() *routeselector.RouteSelector + GetClientRoutesFunc func() route.HAMap + GetClientRoutesWithNetIDFunc func() map[route.NetID][]*route.Route + StopFunc func(manager *statemanager.Manager) } func (m *MockManager) Init() (net.AddHookFunc, net.RemoveHookFunc, error) { @@ -31,11 +32,11 @@ func (m *MockManager) InitialRouteRange() []string { } // UpdateRoutes mock implementation of UpdateRoutes from Manager interface -func (m *MockManager) UpdateRoutes(updateSerial uint64, newRoutes []*route.Route) (map[route.ID]*route.Route, route.HAMap, error) { +func (m *MockManager) UpdateRoutes(updateSerial uint64, newRoutes []*route.Route, b bool) error { if m.UpdateRoutesFunc != nil { return m.UpdateRoutesFunc(updateSerial, newRoutes) } - return nil, nil, fmt.Errorf("method UpdateRoutes is not implemented") + return nil } func (m *MockManager) TriggerSelection(networks route.HAMap) { @@ -52,6 +53,22 @@ func (m *MockManager) GetRouteSelector() *routeselector.RouteSelector { return nil } +// GetClientRoutes mock implementation of GetClientRoutes from Manager interface +func (m *MockManager) GetClientRoutes() route.HAMap { + if m.GetClientRoutesFunc != nil { + return m.GetClientRoutesFunc() + } + return nil +} + +// GetClientRoutesWithNetID mock implementation of GetClientRoutesWithNetID from Manager interface +func (m *MockManager) GetClientRoutesWithNetID() map[route.NetID][]*route.Route { + if m.GetClientRoutesWithNetIDFunc != nil { + return m.GetClientRoutesWithNetIDFunc() + } + return nil +} + // Start mock implementation of Start from Manager interface func (m *MockManager) Start(ctx context.Context, iface *iface.WGIface) { } diff --git a/client/ios/NetBirdSDK/client.go b/client/ios/NetBirdSDK/client.go index 6f501e0c6..befce56a2 100644 --- a/client/ios/NetBirdSDK/client.go +++ b/client/ios/NetBirdSDK/client.go @@ -272,8 +272,8 @@ func (c *Client) GetRoutesSelectionDetails() (*RoutesSelectionDetails, error) { return nil, fmt.Errorf("not connected") } -
routesMap := engine.GetClientRoutesWithNetID() routeManager := engine.GetRouteManager() if routeManager == nil { return nil, fmt.Errorf("could not get route manager") } + routesMap := routeManager.GetClientRoutesWithNetID() @@ -317,7 +317,7 @@ func (c *Client) GetRoutesSelectionDetails() (*RoutesSelectionDetails, error) { } -func prepareRouteSelectionDetails(routes []*selectRoute, resolvedDomains map[domain.Domain][]netip.Prefix) *RoutesSelectionDetails { +func prepareRouteSelectionDetails(routes []*selectRoute, resolvedDomains map[domain.Domain]peer.ResolvedDomainInfo) *RoutesSelectionDetails { var routeSelection []RoutesSelectionInfo for _, r := range routes { domainList := make([]DomainInfo, 0) @@ -325,9 +325,10 @@ func prepareRouteSelectionDetails(routes []*selectRoute, resolvedDomains map[dom domainResp := DomainInfo{ Domain: d.SafeString(), } - if prefixes, exists := resolvedDomains[d]; exists { + + if info, exists := resolvedDomains[d]; exists { var ipStrings []string - for _, prefix := range prefixes { + for _, prefix := range info.Prefixes { ipStrings = append(ipStrings, prefix.Addr().String()) } domainResp.ResolvedIPs = strings.Join(ipStrings, ", ") @@ -365,12 +366,12 @@ func (c *Client) SelectRoute(id string) error { } else { log.Debugf("select route with id: %s", id) routes := toNetIDs([]string{id}) - if err := routeSelector.SelectRoutes(routes, true, maps.Keys(engine.GetClientRoutesWithNetID())); err != nil { + if err := routeSelector.SelectRoutes(routes, true, maps.Keys(routeManager.GetClientRoutesWithNetID())); err != nil { log.Debugf("error when selecting routes: %s", err) return fmt.Errorf("select routes: %w", err) } } - routeManager.TriggerSelection(engine.GetClientRoutes()) + routeManager.TriggerSelection(routeManager.GetClientRoutes()) return nil } @@ -392,12 +393,12 @@ func (c *Client) DeselectRoute(id string) error { } else { log.Debugf("deselect route with id: %s", id) routes := toNetIDs([]string{id}) - if err := routeSelector.DeselectRoutes(routes, maps.Keys(engine.GetClientRoutesWithNetID())); err != nil { + if err := routeSelector.DeselectRoutes(routes, maps.Keys(routeManager.GetClientRoutesWithNetID())); err != nil { log.Debugf("error when deselecting routes: %s", err) return fmt.Errorf("deselect routes: %w", err) } } - routeManager.TriggerSelection(engine.GetClientRoutes()) + routeManager.TriggerSelection(routeManager.GetClientRoutes()) return nil } diff --git a/client/proto/daemon.pb.go b/client/proto/daemon.pb.go index 98ce2c4a2..f0d3021e9 100644 --- a/client/proto/daemon.pb.go +++ b/client/proto/daemon.pb.go @@ -1,7 +1,7 @@ // Code generated by protoc-gen-go. DO NOT EDIT. 
// versions: // protoc-gen-go v1.26.0 -// protoc v4.23.4 +// protoc v3.21.9 // source: daemon.proto package proto @@ -908,7 +908,7 @@ type PeerState struct { BytesRx int64 `protobuf:"varint,13,opt,name=bytesRx,proto3" json:"bytesRx,omitempty"` BytesTx int64 `protobuf:"varint,14,opt,name=bytesTx,proto3" json:"bytesTx,omitempty"` RosenpassEnabled bool `protobuf:"varint,15,opt,name=rosenpassEnabled,proto3" json:"rosenpassEnabled,omitempty"` - Routes []string `protobuf:"bytes,16,rep,name=routes,proto3" json:"routes,omitempty"` + Networks []string `protobuf:"bytes,16,rep,name=networks,proto3" json:"networks,omitempty"` Latency *durationpb.Duration `protobuf:"bytes,17,opt,name=latency,proto3" json:"latency,omitempty"` RelayAddress string `protobuf:"bytes,18,opt,name=relayAddress,proto3" json:"relayAddress,omitempty"` } @@ -1043,9 +1043,9 @@ func (x *PeerState) GetRosenpassEnabled() bool { return false } -func (x *PeerState) GetRoutes() []string { +func (x *PeerState) GetNetworks() []string { if x != nil { - return x.Routes + return x.Networks } return nil } @@ -1076,7 +1076,7 @@ type LocalPeerState struct { Fqdn string `protobuf:"bytes,4,opt,name=fqdn,proto3" json:"fqdn,omitempty"` RosenpassEnabled bool `protobuf:"varint,5,opt,name=rosenpassEnabled,proto3" json:"rosenpassEnabled,omitempty"` RosenpassPermissive bool `protobuf:"varint,6,opt,name=rosenpassPermissive,proto3" json:"rosenpassPermissive,omitempty"` - Routes []string `protobuf:"bytes,7,rep,name=routes,proto3" json:"routes,omitempty"` + Networks []string `protobuf:"bytes,7,rep,name=networks,proto3" json:"networks,omitempty"` } func (x *LocalPeerState) Reset() { @@ -1153,9 +1153,9 @@ func (x *LocalPeerState) GetRosenpassPermissive() bool { return false } -func (x *LocalPeerState) GetRoutes() []string { +func (x *LocalPeerState) GetNetworks() []string { if x != nil { - return x.Routes + return x.Networks } return nil } @@ -1511,14 +1511,14 @@ func (x *FullStatus) GetDnsServers() []*NSGroupState { return nil } -type ListRoutesRequest struct { +type ListNetworksRequest struct { state protoimpl.MessageState sizeCache protoimpl.SizeCache unknownFields protoimpl.UnknownFields } -func (x *ListRoutesRequest) Reset() { - *x = ListRoutesRequest{} +func (x *ListNetworksRequest) Reset() { + *x = ListNetworksRequest{} if protoimpl.UnsafeEnabled { mi := &file_daemon_proto_msgTypes[19] ms := protoimpl.X.MessageStateOf(protoimpl.Pointer(x)) @@ -1526,13 +1526,13 @@ func (x *ListRoutesRequest) Reset() { } } -func (x *ListRoutesRequest) String() string { +func (x *ListNetworksRequest) String() string { return protoimpl.X.MessageStringOf(x) } -func (*ListRoutesRequest) ProtoMessage() {} +func (*ListNetworksRequest) ProtoMessage() {} -func (x *ListRoutesRequest) ProtoReflect() protoreflect.Message { +func (x *ListNetworksRequest) ProtoReflect() protoreflect.Message { mi := &file_daemon_proto_msgTypes[19] if protoimpl.UnsafeEnabled && x != nil { ms := protoimpl.X.MessageStateOf(protoimpl.Pointer(x)) @@ -1544,21 +1544,21 @@ func (x *ListRoutesRequest) ProtoReflect() protoreflect.Message { return mi.MessageOf(x) } -// Deprecated: Use ListRoutesRequest.ProtoReflect.Descriptor instead. -func (*ListRoutesRequest) Descriptor() ([]byte, []int) { +// Deprecated: Use ListNetworksRequest.ProtoReflect.Descriptor instead. 
+func (*ListNetworksRequest) Descriptor() ([]byte, []int) { return file_daemon_proto_rawDescGZIP(), []int{19} } -type ListRoutesResponse struct { +type ListNetworksResponse struct { state protoimpl.MessageState sizeCache protoimpl.SizeCache unknownFields protoimpl.UnknownFields - Routes []*Route `protobuf:"bytes,1,rep,name=routes,proto3" json:"routes,omitempty"` + Routes []*Network `protobuf:"bytes,1,rep,name=routes,proto3" json:"routes,omitempty"` } -func (x *ListRoutesResponse) Reset() { - *x = ListRoutesResponse{} +func (x *ListNetworksResponse) Reset() { + *x = ListNetworksResponse{} if protoimpl.UnsafeEnabled { mi := &file_daemon_proto_msgTypes[20] ms := protoimpl.X.MessageStateOf(protoimpl.Pointer(x)) @@ -1566,13 +1566,13 @@ func (x *ListRoutesResponse) Reset() { } } -func (x *ListRoutesResponse) String() string { +func (x *ListNetworksResponse) String() string { return protoimpl.X.MessageStringOf(x) } -func (*ListRoutesResponse) ProtoMessage() {} +func (*ListNetworksResponse) ProtoMessage() {} -func (x *ListRoutesResponse) ProtoReflect() protoreflect.Message { +func (x *ListNetworksResponse) ProtoReflect() protoreflect.Message { mi := &file_daemon_proto_msgTypes[20] if protoimpl.UnsafeEnabled && x != nil { ms := protoimpl.X.MessageStateOf(protoimpl.Pointer(x)) @@ -1584,30 +1584,30 @@ func (x *ListRoutesResponse) ProtoReflect() protoreflect.Message { return mi.MessageOf(x) } -// Deprecated: Use ListRoutesResponse.ProtoReflect.Descriptor instead. -func (*ListRoutesResponse) Descriptor() ([]byte, []int) { +// Deprecated: Use ListNetworksResponse.ProtoReflect.Descriptor instead. +func (*ListNetworksResponse) Descriptor() ([]byte, []int) { return file_daemon_proto_rawDescGZIP(), []int{20} } -func (x *ListRoutesResponse) GetRoutes() []*Route { +func (x *ListNetworksResponse) GetRoutes() []*Network { if x != nil { return x.Routes } return nil } -type SelectRoutesRequest struct { +type SelectNetworksRequest struct { state protoimpl.MessageState sizeCache protoimpl.SizeCache unknownFields protoimpl.UnknownFields - RouteIDs []string `protobuf:"bytes,1,rep,name=routeIDs,proto3" json:"routeIDs,omitempty"` - Append bool `protobuf:"varint,2,opt,name=append,proto3" json:"append,omitempty"` - All bool `protobuf:"varint,3,opt,name=all,proto3" json:"all,omitempty"` + NetworkIDs []string `protobuf:"bytes,1,rep,name=networkIDs,proto3" json:"networkIDs,omitempty"` + Append bool `protobuf:"varint,2,opt,name=append,proto3" json:"append,omitempty"` + All bool `protobuf:"varint,3,opt,name=all,proto3" json:"all,omitempty"` } -func (x *SelectRoutesRequest) Reset() { - *x = SelectRoutesRequest{} +func (x *SelectNetworksRequest) Reset() { + *x = SelectNetworksRequest{} if protoimpl.UnsafeEnabled { mi := &file_daemon_proto_msgTypes[21] ms := protoimpl.X.MessageStateOf(protoimpl.Pointer(x)) @@ -1615,13 +1615,13 @@ func (x *SelectRoutesRequest) Reset() { } } -func (x *SelectRoutesRequest) String() string { +func (x *SelectNetworksRequest) String() string { return protoimpl.X.MessageStringOf(x) } -func (*SelectRoutesRequest) ProtoMessage() {} +func (*SelectNetworksRequest) ProtoMessage() {} -func (x *SelectRoutesRequest) ProtoReflect() protoreflect.Message { +func (x *SelectNetworksRequest) ProtoReflect() protoreflect.Message { mi := &file_daemon_proto_msgTypes[21] if protoimpl.UnsafeEnabled && x != nil { ms := protoimpl.X.MessageStateOf(protoimpl.Pointer(x)) @@ -1633,40 +1633,40 @@ func (x *SelectRoutesRequest) ProtoReflect() protoreflect.Message { return mi.MessageOf(x) } -// Deprecated: Use 
SelectRoutesRequest.ProtoReflect.Descriptor instead. -func (*SelectRoutesRequest) Descriptor() ([]byte, []int) { +// Deprecated: Use SelectNetworksRequest.ProtoReflect.Descriptor instead. +func (*SelectNetworksRequest) Descriptor() ([]byte, []int) { return file_daemon_proto_rawDescGZIP(), []int{21} } -func (x *SelectRoutesRequest) GetRouteIDs() []string { +func (x *SelectNetworksRequest) GetNetworkIDs() []string { if x != nil { - return x.RouteIDs + return x.NetworkIDs } return nil } -func (x *SelectRoutesRequest) GetAppend() bool { +func (x *SelectNetworksRequest) GetAppend() bool { if x != nil { return x.Append } return false } -func (x *SelectRoutesRequest) GetAll() bool { +func (x *SelectNetworksRequest) GetAll() bool { if x != nil { return x.All } return false } -type SelectRoutesResponse struct { +type SelectNetworksResponse struct { state protoimpl.MessageState sizeCache protoimpl.SizeCache unknownFields protoimpl.UnknownFields } -func (x *SelectRoutesResponse) Reset() { - *x = SelectRoutesResponse{} +func (x *SelectNetworksResponse) Reset() { + *x = SelectNetworksResponse{} if protoimpl.UnsafeEnabled { mi := &file_daemon_proto_msgTypes[22] ms := protoimpl.X.MessageStateOf(protoimpl.Pointer(x)) @@ -1674,13 +1674,13 @@ func (x *SelectRoutesResponse) Reset() { } } -func (x *SelectRoutesResponse) String() string { +func (x *SelectNetworksResponse) String() string { return protoimpl.X.MessageStringOf(x) } -func (*SelectRoutesResponse) ProtoMessage() {} +func (*SelectNetworksResponse) ProtoMessage() {} -func (x *SelectRoutesResponse) ProtoReflect() protoreflect.Message { +func (x *SelectNetworksResponse) ProtoReflect() protoreflect.Message { mi := &file_daemon_proto_msgTypes[22] if protoimpl.UnsafeEnabled && x != nil { ms := protoimpl.X.MessageStateOf(protoimpl.Pointer(x)) @@ -1692,8 +1692,8 @@ func (x *SelectRoutesResponse) ProtoReflect() protoreflect.Message { return mi.MessageOf(x) } -// Deprecated: Use SelectRoutesResponse.ProtoReflect.Descriptor instead. -func (*SelectRoutesResponse) Descriptor() ([]byte, []int) { +// Deprecated: Use SelectNetworksResponse.ProtoReflect.Descriptor instead. 
+func (*SelectNetworksResponse) Descriptor() ([]byte, []int) { return file_daemon_proto_rawDescGZIP(), []int{22} } @@ -1744,20 +1744,20 @@ func (x *IPList) GetIps() []string { return nil } -type Route struct { +type Network struct { state protoimpl.MessageState sizeCache protoimpl.SizeCache unknownFields protoimpl.UnknownFields ID string `protobuf:"bytes,1,opt,name=ID,proto3" json:"ID,omitempty"` - Network string `protobuf:"bytes,2,opt,name=network,proto3" json:"network,omitempty"` + Range string `protobuf:"bytes,2,opt,name=range,proto3" json:"range,omitempty"` Selected bool `protobuf:"varint,3,opt,name=selected,proto3" json:"selected,omitempty"` Domains []string `protobuf:"bytes,4,rep,name=domains,proto3" json:"domains,omitempty"` ResolvedIPs map[string]*IPList `protobuf:"bytes,5,rep,name=resolvedIPs,proto3" json:"resolvedIPs,omitempty" protobuf_key:"bytes,1,opt,name=key,proto3" protobuf_val:"bytes,2,opt,name=value,proto3"` } -func (x *Route) Reset() { - *x = Route{} +func (x *Network) Reset() { + *x = Network{} if protoimpl.UnsafeEnabled { mi := &file_daemon_proto_msgTypes[24] ms := protoimpl.X.MessageStateOf(protoimpl.Pointer(x)) @@ -1765,13 +1765,13 @@ func (x *Route) Reset() { } } -func (x *Route) String() string { +func (x *Network) String() string { return protoimpl.X.MessageStringOf(x) } -func (*Route) ProtoMessage() {} +func (*Network) ProtoMessage() {} -func (x *Route) ProtoReflect() protoreflect.Message { +func (x *Network) ProtoReflect() protoreflect.Message { mi := &file_daemon_proto_msgTypes[24] if protoimpl.UnsafeEnabled && x != nil { ms := protoimpl.X.MessageStateOf(protoimpl.Pointer(x)) @@ -1783,40 +1783,40 @@ func (x *Route) ProtoReflect() protoreflect.Message { return mi.MessageOf(x) } -// Deprecated: Use Route.ProtoReflect.Descriptor instead. -func (*Route) Descriptor() ([]byte, []int) { +// Deprecated: Use Network.ProtoReflect.Descriptor instead. 
+func (*Network) Descriptor() ([]byte, []int) { return file_daemon_proto_rawDescGZIP(), []int{24} } -func (x *Route) GetID() string { +func (x *Network) GetID() string { if x != nil { return x.ID } return "" } -func (x *Route) GetNetwork() string { +func (x *Network) GetRange() string { if x != nil { - return x.Network + return x.Range } return "" } -func (x *Route) GetSelected() bool { +func (x *Network) GetSelected() bool { if x != nil { return x.Selected } return false } -func (x *Route) GetDomains() []string { +func (x *Network) GetDomains() []string { if x != nil { return x.Domains } return nil } -func (x *Route) GetResolvedIPs() map[string]*IPList { +func (x *Network) GetResolvedIPs() map[string]*IPList { if x != nil { return x.ResolvedIPs } @@ -2671,7 +2671,7 @@ var file_daemon_proto_rawDesc = []byte{ 0x45, 0x6e, 0x61, 0x62, 0x6c, 0x65, 0x64, 0x12, 0x30, 0x0a, 0x13, 0x72, 0x6f, 0x73, 0x65, 0x6e, 0x70, 0x61, 0x73, 0x73, 0x50, 0x65, 0x72, 0x6d, 0x69, 0x73, 0x73, 0x69, 0x76, 0x65, 0x18, 0x0c, 0x20, 0x01, 0x28, 0x08, 0x52, 0x13, 0x72, 0x6f, 0x73, 0x65, 0x6e, 0x70, 0x61, 0x73, 0x73, 0x50, - 0x65, 0x72, 0x6d, 0x69, 0x73, 0x73, 0x69, 0x76, 0x65, 0x22, 0xda, 0x05, 0x0a, 0x09, 0x50, 0x65, + 0x65, 0x72, 0x6d, 0x69, 0x73, 0x73, 0x69, 0x76, 0x65, 0x22, 0xde, 0x05, 0x0a, 0x09, 0x50, 0x65, 0x65, 0x72, 0x53, 0x74, 0x61, 0x74, 0x65, 0x12, 0x0e, 0x0a, 0x02, 0x49, 0x50, 0x18, 0x01, 0x20, 0x01, 0x28, 0x09, 0x52, 0x02, 0x49, 0x50, 0x12, 0x16, 0x0a, 0x06, 0x70, 0x75, 0x62, 0x4b, 0x65, 0x79, 0x18, 0x02, 0x20, 0x01, 0x28, 0x09, 0x52, 0x06, 0x70, 0x75, 0x62, 0x4b, 0x65, 0x79, 0x12, @@ -2710,233 +2710,235 @@ var file_daemon_proto_rawDesc = []byte{ 0x52, 0x07, 0x62, 0x79, 0x74, 0x65, 0x73, 0x54, 0x78, 0x12, 0x2a, 0x0a, 0x10, 0x72, 0x6f, 0x73, 0x65, 0x6e, 0x70, 0x61, 0x73, 0x73, 0x45, 0x6e, 0x61, 0x62, 0x6c, 0x65, 0x64, 0x18, 0x0f, 0x20, 0x01, 0x28, 0x08, 0x52, 0x10, 0x72, 0x6f, 0x73, 0x65, 0x6e, 0x70, 0x61, 0x73, 0x73, 0x45, 0x6e, - 0x61, 0x62, 0x6c, 0x65, 0x64, 0x12, 0x16, 0x0a, 0x06, 0x72, 0x6f, 0x75, 0x74, 0x65, 0x73, 0x18, - 0x10, 0x20, 0x03, 0x28, 0x09, 0x52, 0x06, 0x72, 0x6f, 0x75, 0x74, 0x65, 0x73, 0x12, 0x33, 0x0a, - 0x07, 0x6c, 0x61, 0x74, 0x65, 0x6e, 0x63, 0x79, 0x18, 0x11, 0x20, 0x01, 0x28, 0x0b, 0x32, 0x19, - 0x2e, 0x67, 0x6f, 0x6f, 0x67, 0x6c, 0x65, 0x2e, 0x70, 0x72, 0x6f, 0x74, 0x6f, 0x62, 0x75, 0x66, - 0x2e, 0x44, 0x75, 0x72, 0x61, 0x74, 0x69, 0x6f, 0x6e, 0x52, 0x07, 0x6c, 0x61, 0x74, 0x65, 0x6e, - 0x63, 0x79, 0x12, 0x22, 0x0a, 0x0c, 0x72, 0x65, 0x6c, 0x61, 0x79, 0x41, 0x64, 0x64, 0x72, 0x65, - 0x73, 0x73, 0x18, 0x12, 0x20, 0x01, 0x28, 0x09, 0x52, 0x0c, 0x72, 0x65, 0x6c, 0x61, 0x79, 0x41, - 0x64, 0x64, 0x72, 0x65, 0x73, 0x73, 0x22, 0xec, 0x01, 0x0a, 0x0e, 0x4c, 0x6f, 0x63, 0x61, 0x6c, - 0x50, 0x65, 0x65, 0x72, 0x53, 0x74, 0x61, 0x74, 0x65, 0x12, 0x0e, 0x0a, 0x02, 0x49, 0x50, 0x18, - 0x01, 0x20, 0x01, 0x28, 0x09, 0x52, 0x02, 0x49, 0x50, 0x12, 0x16, 0x0a, 0x06, 0x70, 0x75, 0x62, - 0x4b, 0x65, 0x79, 0x18, 0x02, 0x20, 0x01, 0x28, 0x09, 0x52, 0x06, 0x70, 0x75, 0x62, 0x4b, 0x65, - 0x79, 0x12, 0x28, 0x0a, 0x0f, 0x6b, 0x65, 0x72, 0x6e, 0x65, 0x6c, 0x49, 0x6e, 0x74, 0x65, 0x72, - 0x66, 0x61, 0x63, 0x65, 0x18, 0x03, 0x20, 0x01, 0x28, 0x08, 0x52, 0x0f, 0x6b, 0x65, 0x72, 0x6e, - 0x65, 0x6c, 0x49, 0x6e, 0x74, 0x65, 0x72, 0x66, 0x61, 0x63, 0x65, 0x12, 0x12, 0x0a, 0x04, 0x66, - 0x71, 0x64, 0x6e, 0x18, 0x04, 0x20, 0x01, 0x28, 0x09, 0x52, 0x04, 0x66, 0x71, 0x64, 0x6e, 0x12, - 0x2a, 0x0a, 0x10, 0x72, 0x6f, 0x73, 0x65, 0x6e, 0x70, 0x61, 0x73, 0x73, 0x45, 0x6e, 0x61, 0x62, - 0x6c, 0x65, 0x64, 0x18, 0x05, 0x20, 
0x01, 0x28, 0x08, 0x52, 0x10, 0x72, 0x6f, 0x73, 0x65, 0x6e, - 0x70, 0x61, 0x73, 0x73, 0x45, 0x6e, 0x61, 0x62, 0x6c, 0x65, 0x64, 0x12, 0x30, 0x0a, 0x13, 0x72, - 0x6f, 0x73, 0x65, 0x6e, 0x70, 0x61, 0x73, 0x73, 0x50, 0x65, 0x72, 0x6d, 0x69, 0x73, 0x73, 0x69, - 0x76, 0x65, 0x18, 0x06, 0x20, 0x01, 0x28, 0x08, 0x52, 0x13, 0x72, 0x6f, 0x73, 0x65, 0x6e, 0x70, - 0x61, 0x73, 0x73, 0x50, 0x65, 0x72, 0x6d, 0x69, 0x73, 0x73, 0x69, 0x76, 0x65, 0x12, 0x16, 0x0a, - 0x06, 0x72, 0x6f, 0x75, 0x74, 0x65, 0x73, 0x18, 0x07, 0x20, 0x03, 0x28, 0x09, 0x52, 0x06, 0x72, - 0x6f, 0x75, 0x74, 0x65, 0x73, 0x22, 0x53, 0x0a, 0x0b, 0x53, 0x69, 0x67, 0x6e, 0x61, 0x6c, 0x53, - 0x74, 0x61, 0x74, 0x65, 0x12, 0x10, 0x0a, 0x03, 0x55, 0x52, 0x4c, 0x18, 0x01, 0x20, 0x01, 0x28, - 0x09, 0x52, 0x03, 0x55, 0x52, 0x4c, 0x12, 0x1c, 0x0a, 0x09, 0x63, 0x6f, 0x6e, 0x6e, 0x65, 0x63, - 0x74, 0x65, 0x64, 0x18, 0x02, 0x20, 0x01, 0x28, 0x08, 0x52, 0x09, 0x63, 0x6f, 0x6e, 0x6e, 0x65, - 0x63, 0x74, 0x65, 0x64, 0x12, 0x14, 0x0a, 0x05, 0x65, 0x72, 0x72, 0x6f, 0x72, 0x18, 0x03, 0x20, - 0x01, 0x28, 0x09, 0x52, 0x05, 0x65, 0x72, 0x72, 0x6f, 0x72, 0x22, 0x57, 0x0a, 0x0f, 0x4d, 0x61, - 0x6e, 0x61, 0x67, 0x65, 0x6d, 0x65, 0x6e, 0x74, 0x53, 0x74, 0x61, 0x74, 0x65, 0x12, 0x10, 0x0a, - 0x03, 0x55, 0x52, 0x4c, 0x18, 0x01, 0x20, 0x01, 0x28, 0x09, 0x52, 0x03, 0x55, 0x52, 0x4c, 0x12, - 0x1c, 0x0a, 0x09, 0x63, 0x6f, 0x6e, 0x6e, 0x65, 0x63, 0x74, 0x65, 0x64, 0x18, 0x02, 0x20, 0x01, - 0x28, 0x08, 0x52, 0x09, 0x63, 0x6f, 0x6e, 0x6e, 0x65, 0x63, 0x74, 0x65, 0x64, 0x12, 0x14, 0x0a, - 0x05, 0x65, 0x72, 0x72, 0x6f, 0x72, 0x18, 0x03, 0x20, 0x01, 0x28, 0x09, 0x52, 0x05, 0x65, 0x72, - 0x72, 0x6f, 0x72, 0x22, 0x52, 0x0a, 0x0a, 0x52, 0x65, 0x6c, 0x61, 0x79, 0x53, 0x74, 0x61, 0x74, - 0x65, 0x12, 0x10, 0x0a, 0x03, 0x55, 0x52, 0x49, 0x18, 0x01, 0x20, 0x01, 0x28, 0x09, 0x52, 0x03, - 0x55, 0x52, 0x49, 0x12, 0x1c, 0x0a, 0x09, 0x61, 0x76, 0x61, 0x69, 0x6c, 0x61, 0x62, 0x6c, 0x65, - 0x18, 0x02, 0x20, 0x01, 0x28, 0x08, 0x52, 0x09, 0x61, 0x76, 0x61, 0x69, 0x6c, 0x61, 0x62, 0x6c, - 0x65, 0x12, 0x14, 0x0a, 0x05, 0x65, 0x72, 0x72, 0x6f, 0x72, 0x18, 0x03, 0x20, 0x01, 0x28, 0x09, - 0x52, 0x05, 0x65, 0x72, 0x72, 0x6f, 0x72, 0x22, 0x72, 0x0a, 0x0c, 0x4e, 0x53, 0x47, 0x72, 0x6f, - 0x75, 0x70, 0x53, 0x74, 0x61, 0x74, 0x65, 0x12, 0x18, 0x0a, 0x07, 0x73, 0x65, 0x72, 0x76, 0x65, - 0x72, 0x73, 0x18, 0x01, 0x20, 0x03, 0x28, 0x09, 0x52, 0x07, 0x73, 0x65, 0x72, 0x76, 0x65, 0x72, - 0x73, 0x12, 0x18, 0x0a, 0x07, 0x64, 0x6f, 0x6d, 0x61, 0x69, 0x6e, 0x73, 0x18, 0x02, 0x20, 0x03, - 0x28, 0x09, 0x52, 0x07, 0x64, 0x6f, 0x6d, 0x61, 0x69, 0x6e, 0x73, 0x12, 0x18, 0x0a, 0x07, 0x65, - 0x6e, 0x61, 0x62, 0x6c, 0x65, 0x64, 0x18, 0x03, 0x20, 0x01, 0x28, 0x08, 0x52, 0x07, 0x65, 0x6e, - 0x61, 0x62, 0x6c, 0x65, 0x64, 0x12, 0x14, 0x0a, 0x05, 0x65, 0x72, 0x72, 0x6f, 0x72, 0x18, 0x04, - 0x20, 0x01, 0x28, 0x09, 0x52, 0x05, 0x65, 0x72, 0x72, 0x6f, 0x72, 0x22, 0xd2, 0x02, 0x0a, 0x0a, - 0x46, 0x75, 0x6c, 0x6c, 0x53, 0x74, 0x61, 0x74, 0x75, 0x73, 0x12, 0x41, 0x0a, 0x0f, 0x6d, 0x61, - 0x6e, 0x61, 0x67, 0x65, 0x6d, 0x65, 0x6e, 0x74, 0x53, 0x74, 0x61, 0x74, 0x65, 0x18, 0x01, 0x20, - 0x01, 0x28, 0x0b, 0x32, 0x17, 0x2e, 0x64, 0x61, 0x65, 0x6d, 0x6f, 0x6e, 0x2e, 0x4d, 0x61, 0x6e, - 0x61, 0x67, 0x65, 0x6d, 0x65, 0x6e, 0x74, 0x53, 0x74, 0x61, 0x74, 0x65, 0x52, 0x0f, 0x6d, 0x61, - 0x6e, 0x61, 0x67, 0x65, 0x6d, 0x65, 0x6e, 0x74, 0x53, 0x74, 0x61, 0x74, 0x65, 0x12, 0x35, 0x0a, - 0x0b, 0x73, 0x69, 0x67, 0x6e, 0x61, 0x6c, 0x53, 0x74, 0x61, 0x74, 0x65, 0x18, 0x02, 0x20, 0x01, - 0x28, 0x0b, 0x32, 0x13, 0x2e, 0x64, 0x61, 0x65, 0x6d, 0x6f, 
0x6e, 0x2e, 0x53, 0x69, 0x67, 0x6e, - 0x61, 0x6c, 0x53, 0x74, 0x61, 0x74, 0x65, 0x52, 0x0b, 0x73, 0x69, 0x67, 0x6e, 0x61, 0x6c, 0x53, - 0x74, 0x61, 0x74, 0x65, 0x12, 0x3e, 0x0a, 0x0e, 0x6c, 0x6f, 0x63, 0x61, 0x6c, 0x50, 0x65, 0x65, - 0x72, 0x53, 0x74, 0x61, 0x74, 0x65, 0x18, 0x03, 0x20, 0x01, 0x28, 0x0b, 0x32, 0x16, 0x2e, 0x64, - 0x61, 0x65, 0x6d, 0x6f, 0x6e, 0x2e, 0x4c, 0x6f, 0x63, 0x61, 0x6c, 0x50, 0x65, 0x65, 0x72, 0x53, - 0x74, 0x61, 0x74, 0x65, 0x52, 0x0e, 0x6c, 0x6f, 0x63, 0x61, 0x6c, 0x50, 0x65, 0x65, 0x72, 0x53, - 0x74, 0x61, 0x74, 0x65, 0x12, 0x27, 0x0a, 0x05, 0x70, 0x65, 0x65, 0x72, 0x73, 0x18, 0x04, 0x20, - 0x03, 0x28, 0x0b, 0x32, 0x11, 0x2e, 0x64, 0x61, 0x65, 0x6d, 0x6f, 0x6e, 0x2e, 0x50, 0x65, 0x65, - 0x72, 0x53, 0x74, 0x61, 0x74, 0x65, 0x52, 0x05, 0x70, 0x65, 0x65, 0x72, 0x73, 0x12, 0x2a, 0x0a, - 0x06, 0x72, 0x65, 0x6c, 0x61, 0x79, 0x73, 0x18, 0x05, 0x20, 0x03, 0x28, 0x0b, 0x32, 0x12, 0x2e, - 0x64, 0x61, 0x65, 0x6d, 0x6f, 0x6e, 0x2e, 0x52, 0x65, 0x6c, 0x61, 0x79, 0x53, 0x74, 0x61, 0x74, - 0x65, 0x52, 0x06, 0x72, 0x65, 0x6c, 0x61, 0x79, 0x73, 0x12, 0x35, 0x0a, 0x0b, 0x64, 0x6e, 0x73, - 0x5f, 0x73, 0x65, 0x72, 0x76, 0x65, 0x72, 0x73, 0x18, 0x06, 0x20, 0x03, 0x28, 0x0b, 0x32, 0x14, - 0x2e, 0x64, 0x61, 0x65, 0x6d, 0x6f, 0x6e, 0x2e, 0x4e, 0x53, 0x47, 0x72, 0x6f, 0x75, 0x70, 0x53, - 0x74, 0x61, 0x74, 0x65, 0x52, 0x0a, 0x64, 0x6e, 0x73, 0x53, 0x65, 0x72, 0x76, 0x65, 0x72, 0x73, - 0x22, 0x13, 0x0a, 0x11, 0x4c, 0x69, 0x73, 0x74, 0x52, 0x6f, 0x75, 0x74, 0x65, 0x73, 0x52, 0x65, - 0x71, 0x75, 0x65, 0x73, 0x74, 0x22, 0x3b, 0x0a, 0x12, 0x4c, 0x69, 0x73, 0x74, 0x52, 0x6f, 0x75, - 0x74, 0x65, 0x73, 0x52, 0x65, 0x73, 0x70, 0x6f, 0x6e, 0x73, 0x65, 0x12, 0x25, 0x0a, 0x06, 0x72, - 0x6f, 0x75, 0x74, 0x65, 0x73, 0x18, 0x01, 0x20, 0x03, 0x28, 0x0b, 0x32, 0x0d, 0x2e, 0x64, 0x61, - 0x65, 0x6d, 0x6f, 0x6e, 0x2e, 0x52, 0x6f, 0x75, 0x74, 0x65, 0x52, 0x06, 0x72, 0x6f, 0x75, 0x74, - 0x65, 0x73, 0x22, 0x5b, 0x0a, 0x13, 0x53, 0x65, 0x6c, 0x65, 0x63, 0x74, 0x52, 0x6f, 0x75, 0x74, - 0x65, 0x73, 0x52, 0x65, 0x71, 0x75, 0x65, 0x73, 0x74, 0x12, 0x1a, 0x0a, 0x08, 0x72, 0x6f, 0x75, - 0x74, 0x65, 0x49, 0x44, 0x73, 0x18, 0x01, 0x20, 0x03, 0x28, 0x09, 0x52, 0x08, 0x72, 0x6f, 0x75, - 0x74, 0x65, 0x49, 0x44, 0x73, 0x12, 0x16, 0x0a, 0x06, 0x61, 0x70, 0x70, 0x65, 0x6e, 0x64, 0x18, - 0x02, 0x20, 0x01, 0x28, 0x08, 0x52, 0x06, 0x61, 0x70, 0x70, 0x65, 0x6e, 0x64, 0x12, 0x10, 0x0a, - 0x03, 0x61, 0x6c, 0x6c, 0x18, 0x03, 0x20, 0x01, 0x28, 0x08, 0x52, 0x03, 0x61, 0x6c, 0x6c, 0x22, - 0x16, 0x0a, 0x14, 0x53, 0x65, 0x6c, 0x65, 0x63, 0x74, 0x52, 0x6f, 0x75, 0x74, 0x65, 0x73, 0x52, - 0x65, 0x73, 0x70, 0x6f, 0x6e, 0x73, 0x65, 0x22, 0x1a, 0x0a, 0x06, 0x49, 0x50, 0x4c, 0x69, 0x73, - 0x74, 0x12, 0x10, 0x0a, 0x03, 0x69, 0x70, 0x73, 0x18, 0x01, 0x20, 0x03, 0x28, 0x09, 0x52, 0x03, - 0x69, 0x70, 0x73, 0x22, 0xf9, 0x01, 0x0a, 0x05, 0x52, 0x6f, 0x75, 0x74, 0x65, 0x12, 0x0e, 0x0a, - 0x02, 0x49, 0x44, 0x18, 0x01, 0x20, 0x01, 0x28, 0x09, 0x52, 0x02, 0x49, 0x44, 0x12, 0x18, 0x0a, - 0x07, 0x6e, 0x65, 0x74, 0x77, 0x6f, 0x72, 0x6b, 0x18, 0x02, 0x20, 0x01, 0x28, 0x09, 0x52, 0x07, - 0x6e, 0x65, 0x74, 0x77, 0x6f, 0x72, 0x6b, 0x12, 0x1a, 0x0a, 0x08, 0x73, 0x65, 0x6c, 0x65, 0x63, - 0x74, 0x65, 0x64, 0x18, 0x03, 0x20, 0x01, 0x28, 0x08, 0x52, 0x08, 0x73, 0x65, 0x6c, 0x65, 0x63, - 0x74, 0x65, 0x64, 0x12, 0x18, 0x0a, 0x07, 0x64, 0x6f, 0x6d, 0x61, 0x69, 0x6e, 0x73, 0x18, 0x04, - 0x20, 0x03, 0x28, 0x09, 0x52, 0x07, 0x64, 0x6f, 0x6d, 0x61, 0x69, 0x6e, 0x73, 0x12, 0x40, 0x0a, - 0x0b, 0x72, 0x65, 0x73, 0x6f, 0x6c, 0x76, 0x65, 0x64, 0x49, 0x50, 0x73, 0x18, 0x05, 
0x20, 0x03, - 0x28, 0x0b, 0x32, 0x1e, 0x2e, 0x64, 0x61, 0x65, 0x6d, 0x6f, 0x6e, 0x2e, 0x52, 0x6f, 0x75, 0x74, - 0x65, 0x2e, 0x52, 0x65, 0x73, 0x6f, 0x6c, 0x76, 0x65, 0x64, 0x49, 0x50, 0x73, 0x45, 0x6e, 0x74, - 0x72, 0x79, 0x52, 0x0b, 0x72, 0x65, 0x73, 0x6f, 0x6c, 0x76, 0x65, 0x64, 0x49, 0x50, 0x73, 0x1a, - 0x4e, 0x0a, 0x10, 0x52, 0x65, 0x73, 0x6f, 0x6c, 0x76, 0x65, 0x64, 0x49, 0x50, 0x73, 0x45, 0x6e, - 0x74, 0x72, 0x79, 0x12, 0x10, 0x0a, 0x03, 0x6b, 0x65, 0x79, 0x18, 0x01, 0x20, 0x01, 0x28, 0x09, - 0x52, 0x03, 0x6b, 0x65, 0x79, 0x12, 0x24, 0x0a, 0x05, 0x76, 0x61, 0x6c, 0x75, 0x65, 0x18, 0x02, - 0x20, 0x01, 0x28, 0x0b, 0x32, 0x0e, 0x2e, 0x64, 0x61, 0x65, 0x6d, 0x6f, 0x6e, 0x2e, 0x49, 0x50, - 0x4c, 0x69, 0x73, 0x74, 0x52, 0x05, 0x76, 0x61, 0x6c, 0x75, 0x65, 0x3a, 0x02, 0x38, 0x01, 0x22, - 0x6a, 0x0a, 0x12, 0x44, 0x65, 0x62, 0x75, 0x67, 0x42, 0x75, 0x6e, 0x64, 0x6c, 0x65, 0x52, 0x65, - 0x71, 0x75, 0x65, 0x73, 0x74, 0x12, 0x1c, 0x0a, 0x09, 0x61, 0x6e, 0x6f, 0x6e, 0x79, 0x6d, 0x69, - 0x7a, 0x65, 0x18, 0x01, 0x20, 0x01, 0x28, 0x08, 0x52, 0x09, 0x61, 0x6e, 0x6f, 0x6e, 0x79, 0x6d, - 0x69, 0x7a, 0x65, 0x12, 0x16, 0x0a, 0x06, 0x73, 0x74, 0x61, 0x74, 0x75, 0x73, 0x18, 0x02, 0x20, - 0x01, 0x28, 0x09, 0x52, 0x06, 0x73, 0x74, 0x61, 0x74, 0x75, 0x73, 0x12, 0x1e, 0x0a, 0x0a, 0x73, - 0x79, 0x73, 0x74, 0x65, 0x6d, 0x49, 0x6e, 0x66, 0x6f, 0x18, 0x03, 0x20, 0x01, 0x28, 0x08, 0x52, - 0x0a, 0x73, 0x79, 0x73, 0x74, 0x65, 0x6d, 0x49, 0x6e, 0x66, 0x6f, 0x22, 0x29, 0x0a, 0x13, 0x44, - 0x65, 0x62, 0x75, 0x67, 0x42, 0x75, 0x6e, 0x64, 0x6c, 0x65, 0x52, 0x65, 0x73, 0x70, 0x6f, 0x6e, - 0x73, 0x65, 0x12, 0x12, 0x0a, 0x04, 0x70, 0x61, 0x74, 0x68, 0x18, 0x01, 0x20, 0x01, 0x28, 0x09, - 0x52, 0x04, 0x70, 0x61, 0x74, 0x68, 0x22, 0x14, 0x0a, 0x12, 0x47, 0x65, 0x74, 0x4c, 0x6f, 0x67, - 0x4c, 0x65, 0x76, 0x65, 0x6c, 0x52, 0x65, 0x71, 0x75, 0x65, 0x73, 0x74, 0x22, 0x3d, 0x0a, 0x13, - 0x47, 0x65, 0x74, 0x4c, 0x6f, 0x67, 0x4c, 0x65, 0x76, 0x65, 0x6c, 0x52, 0x65, 0x73, 0x70, 0x6f, - 0x6e, 0x73, 0x65, 0x12, 0x26, 0x0a, 0x05, 0x6c, 0x65, 0x76, 0x65, 0x6c, 0x18, 0x01, 0x20, 0x01, - 0x28, 0x0e, 0x32, 0x10, 0x2e, 0x64, 0x61, 0x65, 0x6d, 0x6f, 0x6e, 0x2e, 0x4c, 0x6f, 0x67, 0x4c, - 0x65, 0x76, 0x65, 0x6c, 0x52, 0x05, 0x6c, 0x65, 0x76, 0x65, 0x6c, 0x22, 0x3c, 0x0a, 0x12, 0x53, - 0x65, 0x74, 0x4c, 0x6f, 0x67, 0x4c, 0x65, 0x76, 0x65, 0x6c, 0x52, 0x65, 0x71, 0x75, 0x65, 0x73, - 0x74, 0x12, 0x26, 0x0a, 0x05, 0x6c, 0x65, 0x76, 0x65, 0x6c, 0x18, 0x01, 0x20, 0x01, 0x28, 0x0e, - 0x32, 0x10, 0x2e, 0x64, 0x61, 0x65, 0x6d, 0x6f, 0x6e, 0x2e, 0x4c, 0x6f, 0x67, 0x4c, 0x65, 0x76, - 0x65, 0x6c, 0x52, 0x05, 0x6c, 0x65, 0x76, 0x65, 0x6c, 0x22, 0x15, 0x0a, 0x13, 0x53, 0x65, 0x74, - 0x4c, 0x6f, 0x67, 0x4c, 0x65, 0x76, 0x65, 0x6c, 0x52, 0x65, 0x73, 0x70, 0x6f, 0x6e, 0x73, 0x65, - 0x22, 0x1b, 0x0a, 0x05, 0x53, 0x74, 0x61, 0x74, 0x65, 0x12, 0x12, 0x0a, 0x04, 0x6e, 0x61, 0x6d, - 0x65, 0x18, 0x01, 0x20, 0x01, 0x28, 0x09, 0x52, 0x04, 0x6e, 0x61, 0x6d, 0x65, 0x22, 0x13, 0x0a, - 0x11, 0x4c, 0x69, 0x73, 0x74, 0x53, 0x74, 0x61, 0x74, 0x65, 0x73, 0x52, 0x65, 0x71, 0x75, 0x65, - 0x73, 0x74, 0x22, 0x3b, 0x0a, 0x12, 0x4c, 0x69, 0x73, 0x74, 0x53, 0x74, 0x61, 0x74, 0x65, 0x73, - 0x52, 0x65, 0x73, 0x70, 0x6f, 0x6e, 0x73, 0x65, 0x12, 0x25, 0x0a, 0x06, 0x73, 0x74, 0x61, 0x74, - 0x65, 0x73, 0x18, 0x01, 0x20, 0x03, 0x28, 0x0b, 0x32, 0x0d, 0x2e, 0x64, 0x61, 0x65, 0x6d, 0x6f, - 0x6e, 0x2e, 0x53, 0x74, 0x61, 0x74, 0x65, 0x52, 0x06, 0x73, 0x74, 0x61, 0x74, 0x65, 0x73, 0x22, - 0x44, 0x0a, 0x11, 0x43, 0x6c, 0x65, 0x61, 0x6e, 0x53, 0x74, 0x61, 0x74, 0x65, 0x52, 0x65, 0x71, - 0x75, 0x65, 
0x73, 0x74, 0x12, 0x1d, 0x0a, 0x0a, 0x73, 0x74, 0x61, 0x74, 0x65, 0x5f, 0x6e, 0x61, - 0x6d, 0x65, 0x18, 0x01, 0x20, 0x01, 0x28, 0x09, 0x52, 0x09, 0x73, 0x74, 0x61, 0x74, 0x65, 0x4e, - 0x61, 0x6d, 0x65, 0x12, 0x10, 0x0a, 0x03, 0x61, 0x6c, 0x6c, 0x18, 0x02, 0x20, 0x01, 0x28, 0x08, - 0x52, 0x03, 0x61, 0x6c, 0x6c, 0x22, 0x3b, 0x0a, 0x12, 0x43, 0x6c, 0x65, 0x61, 0x6e, 0x53, 0x74, - 0x61, 0x74, 0x65, 0x52, 0x65, 0x73, 0x70, 0x6f, 0x6e, 0x73, 0x65, 0x12, 0x25, 0x0a, 0x0e, 0x63, - 0x6c, 0x65, 0x61, 0x6e, 0x65, 0x64, 0x5f, 0x73, 0x74, 0x61, 0x74, 0x65, 0x73, 0x18, 0x01, 0x20, - 0x01, 0x28, 0x05, 0x52, 0x0d, 0x63, 0x6c, 0x65, 0x61, 0x6e, 0x65, 0x64, 0x53, 0x74, 0x61, 0x74, - 0x65, 0x73, 0x22, 0x45, 0x0a, 0x12, 0x44, 0x65, 0x6c, 0x65, 0x74, 0x65, 0x53, 0x74, 0x61, 0x74, - 0x65, 0x52, 0x65, 0x71, 0x75, 0x65, 0x73, 0x74, 0x12, 0x1d, 0x0a, 0x0a, 0x73, 0x74, 0x61, 0x74, - 0x65, 0x5f, 0x6e, 0x61, 0x6d, 0x65, 0x18, 0x01, 0x20, 0x01, 0x28, 0x09, 0x52, 0x09, 0x73, 0x74, - 0x61, 0x74, 0x65, 0x4e, 0x61, 0x6d, 0x65, 0x12, 0x10, 0x0a, 0x03, 0x61, 0x6c, 0x6c, 0x18, 0x02, - 0x20, 0x01, 0x28, 0x08, 0x52, 0x03, 0x61, 0x6c, 0x6c, 0x22, 0x3c, 0x0a, 0x13, 0x44, 0x65, 0x6c, - 0x65, 0x74, 0x65, 0x53, 0x74, 0x61, 0x74, 0x65, 0x52, 0x65, 0x73, 0x70, 0x6f, 0x6e, 0x73, 0x65, - 0x12, 0x25, 0x0a, 0x0e, 0x64, 0x65, 0x6c, 0x65, 0x74, 0x65, 0x64, 0x5f, 0x73, 0x74, 0x61, 0x74, - 0x65, 0x73, 0x18, 0x01, 0x20, 0x01, 0x28, 0x05, 0x52, 0x0d, 0x64, 0x65, 0x6c, 0x65, 0x74, 0x65, - 0x64, 0x53, 0x74, 0x61, 0x74, 0x65, 0x73, 0x22, 0x3b, 0x0a, 0x1f, 0x53, 0x65, 0x74, 0x4e, 0x65, - 0x74, 0x77, 0x6f, 0x72, 0x6b, 0x4d, 0x61, 0x70, 0x50, 0x65, 0x72, 0x73, 0x69, 0x73, 0x74, 0x65, - 0x6e, 0x63, 0x65, 0x52, 0x65, 0x71, 0x75, 0x65, 0x73, 0x74, 0x12, 0x18, 0x0a, 0x07, 0x65, 0x6e, - 0x61, 0x62, 0x6c, 0x65, 0x64, 0x18, 0x01, 0x20, 0x01, 0x28, 0x08, 0x52, 0x07, 0x65, 0x6e, 0x61, - 0x62, 0x6c, 0x65, 0x64, 0x22, 0x22, 0x0a, 0x20, 0x53, 0x65, 0x74, 0x4e, 0x65, 0x74, 0x77, 0x6f, - 0x72, 0x6b, 0x4d, 0x61, 0x70, 0x50, 0x65, 0x72, 0x73, 0x69, 0x73, 0x74, 0x65, 0x6e, 0x63, 0x65, - 0x52, 0x65, 0x73, 0x70, 0x6f, 0x6e, 0x73, 0x65, 0x2a, 0x62, 0x0a, 0x08, 0x4c, 0x6f, 0x67, 0x4c, - 0x65, 0x76, 0x65, 0x6c, 0x12, 0x0b, 0x0a, 0x07, 0x55, 0x4e, 0x4b, 0x4e, 0x4f, 0x57, 0x4e, 0x10, - 0x00, 0x12, 0x09, 0x0a, 0x05, 0x50, 0x41, 0x4e, 0x49, 0x43, 0x10, 0x01, 0x12, 0x09, 0x0a, 0x05, - 0x46, 0x41, 0x54, 0x41, 0x4c, 0x10, 0x02, 0x12, 0x09, 0x0a, 0x05, 0x45, 0x52, 0x52, 0x4f, 0x52, - 0x10, 0x03, 0x12, 0x08, 0x0a, 0x04, 0x57, 0x41, 0x52, 0x4e, 0x10, 0x04, 0x12, 0x08, 0x0a, 0x04, - 0x49, 0x4e, 0x46, 0x4f, 0x10, 0x05, 0x12, 0x09, 0x0a, 0x05, 0x44, 0x45, 0x42, 0x55, 0x47, 0x10, - 0x06, 0x12, 0x09, 0x0a, 0x05, 0x54, 0x52, 0x41, 0x43, 0x45, 0x10, 0x07, 0x32, 0x81, 0x09, 0x0a, - 0x0d, 0x44, 0x61, 0x65, 0x6d, 0x6f, 0x6e, 0x53, 0x65, 0x72, 0x76, 0x69, 0x63, 0x65, 0x12, 0x36, - 0x0a, 0x05, 0x4c, 0x6f, 0x67, 0x69, 0x6e, 0x12, 0x14, 0x2e, 0x64, 0x61, 0x65, 0x6d, 0x6f, 0x6e, - 0x2e, 0x4c, 0x6f, 0x67, 0x69, 0x6e, 0x52, 0x65, 0x71, 0x75, 0x65, 0x73, 0x74, 0x1a, 0x15, 0x2e, - 0x64, 0x61, 0x65, 0x6d, 0x6f, 0x6e, 0x2e, 0x4c, 0x6f, 0x67, 0x69, 0x6e, 0x52, 0x65, 0x73, 0x70, - 0x6f, 0x6e, 0x73, 0x65, 0x22, 0x00, 0x12, 0x4b, 0x0a, 0x0c, 0x57, 0x61, 0x69, 0x74, 0x53, 0x53, - 0x4f, 0x4c, 0x6f, 0x67, 0x69, 0x6e, 0x12, 0x1b, 0x2e, 0x64, 0x61, 0x65, 0x6d, 0x6f, 0x6e, 0x2e, - 0x57, 0x61, 0x69, 0x74, 0x53, 0x53, 0x4f, 0x4c, 0x6f, 0x67, 0x69, 0x6e, 0x52, 0x65, 0x71, 0x75, - 0x65, 0x73, 0x74, 0x1a, 0x1c, 0x2e, 0x64, 0x61, 0x65, 0x6d, 0x6f, 0x6e, 0x2e, 0x57, 0x61, 0x69, - 0x74, 0x53, 0x53, 0x4f, 0x4c, 0x6f, 
0x67, 0x69, 0x6e, 0x52, 0x65, 0x73, 0x70, 0x6f, 0x6e, 0x73, - 0x65, 0x22, 0x00, 0x12, 0x2d, 0x0a, 0x02, 0x55, 0x70, 0x12, 0x11, 0x2e, 0x64, 0x61, 0x65, 0x6d, - 0x6f, 0x6e, 0x2e, 0x55, 0x70, 0x52, 0x65, 0x71, 0x75, 0x65, 0x73, 0x74, 0x1a, 0x12, 0x2e, 0x64, - 0x61, 0x65, 0x6d, 0x6f, 0x6e, 0x2e, 0x55, 0x70, 0x52, 0x65, 0x73, 0x70, 0x6f, 0x6e, 0x73, 0x65, - 0x22, 0x00, 0x12, 0x39, 0x0a, 0x06, 0x53, 0x74, 0x61, 0x74, 0x75, 0x73, 0x12, 0x15, 0x2e, 0x64, - 0x61, 0x65, 0x6d, 0x6f, 0x6e, 0x2e, 0x53, 0x74, 0x61, 0x74, 0x75, 0x73, 0x52, 0x65, 0x71, 0x75, - 0x65, 0x73, 0x74, 0x1a, 0x16, 0x2e, 0x64, 0x61, 0x65, 0x6d, 0x6f, 0x6e, 0x2e, 0x53, 0x74, 0x61, - 0x74, 0x75, 0x73, 0x52, 0x65, 0x73, 0x70, 0x6f, 0x6e, 0x73, 0x65, 0x22, 0x00, 0x12, 0x33, 0x0a, - 0x04, 0x44, 0x6f, 0x77, 0x6e, 0x12, 0x13, 0x2e, 0x64, 0x61, 0x65, 0x6d, 0x6f, 0x6e, 0x2e, 0x44, - 0x6f, 0x77, 0x6e, 0x52, 0x65, 0x71, 0x75, 0x65, 0x73, 0x74, 0x1a, 0x14, 0x2e, 0x64, 0x61, 0x65, - 0x6d, 0x6f, 0x6e, 0x2e, 0x44, 0x6f, 0x77, 0x6e, 0x52, 0x65, 0x73, 0x70, 0x6f, 0x6e, 0x73, 0x65, - 0x22, 0x00, 0x12, 0x42, 0x0a, 0x09, 0x47, 0x65, 0x74, 0x43, 0x6f, 0x6e, 0x66, 0x69, 0x67, 0x12, - 0x18, 0x2e, 0x64, 0x61, 0x65, 0x6d, 0x6f, 0x6e, 0x2e, 0x47, 0x65, 0x74, 0x43, 0x6f, 0x6e, 0x66, - 0x69, 0x67, 0x52, 0x65, 0x71, 0x75, 0x65, 0x73, 0x74, 0x1a, 0x19, 0x2e, 0x64, 0x61, 0x65, 0x6d, - 0x6f, 0x6e, 0x2e, 0x47, 0x65, 0x74, 0x43, 0x6f, 0x6e, 0x66, 0x69, 0x67, 0x52, 0x65, 0x73, 0x70, - 0x6f, 0x6e, 0x73, 0x65, 0x22, 0x00, 0x12, 0x45, 0x0a, 0x0a, 0x4c, 0x69, 0x73, 0x74, 0x52, 0x6f, - 0x75, 0x74, 0x65, 0x73, 0x12, 0x19, 0x2e, 0x64, 0x61, 0x65, 0x6d, 0x6f, 0x6e, 0x2e, 0x4c, 0x69, - 0x73, 0x74, 0x52, 0x6f, 0x75, 0x74, 0x65, 0x73, 0x52, 0x65, 0x71, 0x75, 0x65, 0x73, 0x74, 0x1a, - 0x1a, 0x2e, 0x64, 0x61, 0x65, 0x6d, 0x6f, 0x6e, 0x2e, 0x4c, 0x69, 0x73, 0x74, 0x52, 0x6f, 0x75, - 0x74, 0x65, 0x73, 0x52, 0x65, 0x73, 0x70, 0x6f, 0x6e, 0x73, 0x65, 0x22, 0x00, 0x12, 0x4b, 0x0a, - 0x0c, 0x53, 0x65, 0x6c, 0x65, 0x63, 0x74, 0x52, 0x6f, 0x75, 0x74, 0x65, 0x73, 0x12, 0x1b, 0x2e, - 0x64, 0x61, 0x65, 0x6d, 0x6f, 0x6e, 0x2e, 0x53, 0x65, 0x6c, 0x65, 0x63, 0x74, 0x52, 0x6f, 0x75, - 0x74, 0x65, 0x73, 0x52, 0x65, 0x71, 0x75, 0x65, 0x73, 0x74, 0x1a, 0x1c, 0x2e, 0x64, 0x61, 0x65, - 0x6d, 0x6f, 0x6e, 0x2e, 0x53, 0x65, 0x6c, 0x65, 0x63, 0x74, 0x52, 0x6f, 0x75, 0x74, 0x65, 0x73, - 0x52, 0x65, 0x73, 0x70, 0x6f, 0x6e, 0x73, 0x65, 0x22, 0x00, 0x12, 0x4d, 0x0a, 0x0e, 0x44, 0x65, - 0x73, 0x65, 0x6c, 0x65, 0x63, 0x74, 0x52, 0x6f, 0x75, 0x74, 0x65, 0x73, 0x12, 0x1b, 0x2e, 0x64, - 0x61, 0x65, 0x6d, 0x6f, 0x6e, 0x2e, 0x53, 0x65, 0x6c, 0x65, 0x63, 0x74, 0x52, 0x6f, 0x75, 0x74, - 0x65, 0x73, 0x52, 0x65, 0x71, 0x75, 0x65, 0x73, 0x74, 0x1a, 0x1c, 0x2e, 0x64, 0x61, 0x65, 0x6d, - 0x6f, 0x6e, 0x2e, 0x53, 0x65, 0x6c, 0x65, 0x63, 0x74, 0x52, 0x6f, 0x75, 0x74, 0x65, 0x73, 0x52, - 0x65, 0x73, 0x70, 0x6f, 0x6e, 0x73, 0x65, 0x22, 0x00, 0x12, 0x48, 0x0a, 0x0b, 0x44, 0x65, 0x62, - 0x75, 0x67, 0x42, 0x75, 0x6e, 0x64, 0x6c, 0x65, 0x12, 0x1a, 0x2e, 0x64, 0x61, 0x65, 0x6d, 0x6f, - 0x6e, 0x2e, 0x44, 0x65, 0x62, 0x75, 0x67, 0x42, 0x75, 0x6e, 0x64, 0x6c, 0x65, 0x52, 0x65, 0x71, - 0x75, 0x65, 0x73, 0x74, 0x1a, 0x1b, 0x2e, 0x64, 0x61, 0x65, 0x6d, 0x6f, 0x6e, 0x2e, 0x44, 0x65, - 0x62, 0x75, 0x67, 0x42, 0x75, 0x6e, 0x64, 0x6c, 0x65, 0x52, 0x65, 0x73, 0x70, 0x6f, 0x6e, 0x73, - 0x65, 0x22, 0x00, 0x12, 0x48, 0x0a, 0x0b, 0x47, 0x65, 0x74, 0x4c, 0x6f, 0x67, 0x4c, 0x65, 0x76, - 0x65, 0x6c, 0x12, 0x1a, 0x2e, 0x64, 0x61, 0x65, 0x6d, 0x6f, 0x6e, 0x2e, 0x47, 0x65, 0x74, 0x4c, - 0x6f, 0x67, 0x4c, 0x65, 0x76, 0x65, 0x6c, 0x52, 0x65, 0x71, 
0x75, 0x65, 0x73, 0x74, 0x1a, 0x1b, - 0x2e, 0x64, 0x61, 0x65, 0x6d, 0x6f, 0x6e, 0x2e, 0x47, 0x65, 0x74, 0x4c, 0x6f, 0x67, 0x4c, 0x65, - 0x76, 0x65, 0x6c, 0x52, 0x65, 0x73, 0x70, 0x6f, 0x6e, 0x73, 0x65, 0x22, 0x00, 0x12, 0x48, 0x0a, - 0x0b, 0x53, 0x65, 0x74, 0x4c, 0x6f, 0x67, 0x4c, 0x65, 0x76, 0x65, 0x6c, 0x12, 0x1a, 0x2e, 0x64, - 0x61, 0x65, 0x6d, 0x6f, 0x6e, 0x2e, 0x53, 0x65, 0x74, 0x4c, 0x6f, 0x67, 0x4c, 0x65, 0x76, 0x65, - 0x6c, 0x52, 0x65, 0x71, 0x75, 0x65, 0x73, 0x74, 0x1a, 0x1b, 0x2e, 0x64, 0x61, 0x65, 0x6d, 0x6f, - 0x6e, 0x2e, 0x53, 0x65, 0x74, 0x4c, 0x6f, 0x67, 0x4c, 0x65, 0x76, 0x65, 0x6c, 0x52, 0x65, 0x73, - 0x70, 0x6f, 0x6e, 0x73, 0x65, 0x22, 0x00, 0x12, 0x45, 0x0a, 0x0a, 0x4c, 0x69, 0x73, 0x74, 0x53, - 0x74, 0x61, 0x74, 0x65, 0x73, 0x12, 0x19, 0x2e, 0x64, 0x61, 0x65, 0x6d, 0x6f, 0x6e, 0x2e, 0x4c, - 0x69, 0x73, 0x74, 0x53, 0x74, 0x61, 0x74, 0x65, 0x73, 0x52, 0x65, 0x71, 0x75, 0x65, 0x73, 0x74, - 0x1a, 0x1a, 0x2e, 0x64, 0x61, 0x65, 0x6d, 0x6f, 0x6e, 0x2e, 0x4c, 0x69, 0x73, 0x74, 0x53, 0x74, - 0x61, 0x74, 0x65, 0x73, 0x52, 0x65, 0x73, 0x70, 0x6f, 0x6e, 0x73, 0x65, 0x22, 0x00, 0x12, 0x45, - 0x0a, 0x0a, 0x43, 0x6c, 0x65, 0x61, 0x6e, 0x53, 0x74, 0x61, 0x74, 0x65, 0x12, 0x19, 0x2e, 0x64, - 0x61, 0x65, 0x6d, 0x6f, 0x6e, 0x2e, 0x43, 0x6c, 0x65, 0x61, 0x6e, 0x53, 0x74, 0x61, 0x74, 0x65, - 0x52, 0x65, 0x71, 0x75, 0x65, 0x73, 0x74, 0x1a, 0x1a, 0x2e, 0x64, 0x61, 0x65, 0x6d, 0x6f, 0x6e, - 0x2e, 0x43, 0x6c, 0x65, 0x61, 0x6e, 0x53, 0x74, 0x61, 0x74, 0x65, 0x52, 0x65, 0x73, 0x70, 0x6f, - 0x6e, 0x73, 0x65, 0x22, 0x00, 0x12, 0x48, 0x0a, 0x0b, 0x44, 0x65, 0x6c, 0x65, 0x74, 0x65, 0x53, - 0x74, 0x61, 0x74, 0x65, 0x12, 0x1a, 0x2e, 0x64, 0x61, 0x65, 0x6d, 0x6f, 0x6e, 0x2e, 0x44, 0x65, - 0x6c, 0x65, 0x74, 0x65, 0x53, 0x74, 0x61, 0x74, 0x65, 0x52, 0x65, 0x71, 0x75, 0x65, 0x73, 0x74, - 0x1a, 0x1b, 0x2e, 0x64, 0x61, 0x65, 0x6d, 0x6f, 0x6e, 0x2e, 0x44, 0x65, 0x6c, 0x65, 0x74, 0x65, - 0x53, 0x74, 0x61, 0x74, 0x65, 0x52, 0x65, 0x73, 0x70, 0x6f, 0x6e, 0x73, 0x65, 0x22, 0x00, 0x12, - 0x6f, 0x0a, 0x18, 0x53, 0x65, 0x74, 0x4e, 0x65, 0x74, 0x77, 0x6f, 0x72, 0x6b, 0x4d, 0x61, 0x70, - 0x50, 0x65, 0x72, 0x73, 0x69, 0x73, 0x74, 0x65, 0x6e, 0x63, 0x65, 0x12, 0x27, 0x2e, 0x64, 0x61, - 0x65, 0x6d, 0x6f, 0x6e, 0x2e, 0x53, 0x65, 0x74, 0x4e, 0x65, 0x74, 0x77, 0x6f, 0x72, 0x6b, 0x4d, - 0x61, 0x70, 0x50, 0x65, 0x72, 0x73, 0x69, 0x73, 0x74, 0x65, 0x6e, 0x63, 0x65, 0x52, 0x65, 0x71, - 0x75, 0x65, 0x73, 0x74, 0x1a, 0x28, 0x2e, 0x64, 0x61, 0x65, 0x6d, 0x6f, 0x6e, 0x2e, 0x53, 0x65, + 0x61, 0x62, 0x6c, 0x65, 0x64, 0x12, 0x1a, 0x0a, 0x08, 0x6e, 0x65, 0x74, 0x77, 0x6f, 0x72, 0x6b, + 0x73, 0x18, 0x10, 0x20, 0x03, 0x28, 0x09, 0x52, 0x08, 0x6e, 0x65, 0x74, 0x77, 0x6f, 0x72, 0x6b, + 0x73, 0x12, 0x33, 0x0a, 0x07, 0x6c, 0x61, 0x74, 0x65, 0x6e, 0x63, 0x79, 0x18, 0x11, 0x20, 0x01, + 0x28, 0x0b, 0x32, 0x19, 0x2e, 0x67, 0x6f, 0x6f, 0x67, 0x6c, 0x65, 0x2e, 0x70, 0x72, 0x6f, 0x74, + 0x6f, 0x62, 0x75, 0x66, 0x2e, 0x44, 0x75, 0x72, 0x61, 0x74, 0x69, 0x6f, 0x6e, 0x52, 0x07, 0x6c, + 0x61, 0x74, 0x65, 0x6e, 0x63, 0x79, 0x12, 0x22, 0x0a, 0x0c, 0x72, 0x65, 0x6c, 0x61, 0x79, 0x41, + 0x64, 0x64, 0x72, 0x65, 0x73, 0x73, 0x18, 0x12, 0x20, 0x01, 0x28, 0x09, 0x52, 0x0c, 0x72, 0x65, + 0x6c, 0x61, 0x79, 0x41, 0x64, 0x64, 0x72, 0x65, 0x73, 0x73, 0x22, 0xf0, 0x01, 0x0a, 0x0e, 0x4c, + 0x6f, 0x63, 0x61, 0x6c, 0x50, 0x65, 0x65, 0x72, 0x53, 0x74, 0x61, 0x74, 0x65, 0x12, 0x0e, 0x0a, + 0x02, 0x49, 0x50, 0x18, 0x01, 0x20, 0x01, 0x28, 0x09, 0x52, 0x02, 0x49, 0x50, 0x12, 0x16, 0x0a, + 0x06, 0x70, 0x75, 0x62, 0x4b, 0x65, 0x79, 0x18, 0x02, 0x20, 0x01, 0x28, 0x09, 0x52, 
0x06, 0x70, + 0x75, 0x62, 0x4b, 0x65, 0x79, 0x12, 0x28, 0x0a, 0x0f, 0x6b, 0x65, 0x72, 0x6e, 0x65, 0x6c, 0x49, + 0x6e, 0x74, 0x65, 0x72, 0x66, 0x61, 0x63, 0x65, 0x18, 0x03, 0x20, 0x01, 0x28, 0x08, 0x52, 0x0f, + 0x6b, 0x65, 0x72, 0x6e, 0x65, 0x6c, 0x49, 0x6e, 0x74, 0x65, 0x72, 0x66, 0x61, 0x63, 0x65, 0x12, + 0x12, 0x0a, 0x04, 0x66, 0x71, 0x64, 0x6e, 0x18, 0x04, 0x20, 0x01, 0x28, 0x09, 0x52, 0x04, 0x66, + 0x71, 0x64, 0x6e, 0x12, 0x2a, 0x0a, 0x10, 0x72, 0x6f, 0x73, 0x65, 0x6e, 0x70, 0x61, 0x73, 0x73, + 0x45, 0x6e, 0x61, 0x62, 0x6c, 0x65, 0x64, 0x18, 0x05, 0x20, 0x01, 0x28, 0x08, 0x52, 0x10, 0x72, + 0x6f, 0x73, 0x65, 0x6e, 0x70, 0x61, 0x73, 0x73, 0x45, 0x6e, 0x61, 0x62, 0x6c, 0x65, 0x64, 0x12, + 0x30, 0x0a, 0x13, 0x72, 0x6f, 0x73, 0x65, 0x6e, 0x70, 0x61, 0x73, 0x73, 0x50, 0x65, 0x72, 0x6d, + 0x69, 0x73, 0x73, 0x69, 0x76, 0x65, 0x18, 0x06, 0x20, 0x01, 0x28, 0x08, 0x52, 0x13, 0x72, 0x6f, + 0x73, 0x65, 0x6e, 0x70, 0x61, 0x73, 0x73, 0x50, 0x65, 0x72, 0x6d, 0x69, 0x73, 0x73, 0x69, 0x76, + 0x65, 0x12, 0x1a, 0x0a, 0x08, 0x6e, 0x65, 0x74, 0x77, 0x6f, 0x72, 0x6b, 0x73, 0x18, 0x07, 0x20, + 0x03, 0x28, 0x09, 0x52, 0x08, 0x6e, 0x65, 0x74, 0x77, 0x6f, 0x72, 0x6b, 0x73, 0x22, 0x53, 0x0a, + 0x0b, 0x53, 0x69, 0x67, 0x6e, 0x61, 0x6c, 0x53, 0x74, 0x61, 0x74, 0x65, 0x12, 0x10, 0x0a, 0x03, + 0x55, 0x52, 0x4c, 0x18, 0x01, 0x20, 0x01, 0x28, 0x09, 0x52, 0x03, 0x55, 0x52, 0x4c, 0x12, 0x1c, + 0x0a, 0x09, 0x63, 0x6f, 0x6e, 0x6e, 0x65, 0x63, 0x74, 0x65, 0x64, 0x18, 0x02, 0x20, 0x01, 0x28, + 0x08, 0x52, 0x09, 0x63, 0x6f, 0x6e, 0x6e, 0x65, 0x63, 0x74, 0x65, 0x64, 0x12, 0x14, 0x0a, 0x05, + 0x65, 0x72, 0x72, 0x6f, 0x72, 0x18, 0x03, 0x20, 0x01, 0x28, 0x09, 0x52, 0x05, 0x65, 0x72, 0x72, + 0x6f, 0x72, 0x22, 0x57, 0x0a, 0x0f, 0x4d, 0x61, 0x6e, 0x61, 0x67, 0x65, 0x6d, 0x65, 0x6e, 0x74, + 0x53, 0x74, 0x61, 0x74, 0x65, 0x12, 0x10, 0x0a, 0x03, 0x55, 0x52, 0x4c, 0x18, 0x01, 0x20, 0x01, + 0x28, 0x09, 0x52, 0x03, 0x55, 0x52, 0x4c, 0x12, 0x1c, 0x0a, 0x09, 0x63, 0x6f, 0x6e, 0x6e, 0x65, + 0x63, 0x74, 0x65, 0x64, 0x18, 0x02, 0x20, 0x01, 0x28, 0x08, 0x52, 0x09, 0x63, 0x6f, 0x6e, 0x6e, + 0x65, 0x63, 0x74, 0x65, 0x64, 0x12, 0x14, 0x0a, 0x05, 0x65, 0x72, 0x72, 0x6f, 0x72, 0x18, 0x03, + 0x20, 0x01, 0x28, 0x09, 0x52, 0x05, 0x65, 0x72, 0x72, 0x6f, 0x72, 0x22, 0x52, 0x0a, 0x0a, 0x52, + 0x65, 0x6c, 0x61, 0x79, 0x53, 0x74, 0x61, 0x74, 0x65, 0x12, 0x10, 0x0a, 0x03, 0x55, 0x52, 0x49, + 0x18, 0x01, 0x20, 0x01, 0x28, 0x09, 0x52, 0x03, 0x55, 0x52, 0x49, 0x12, 0x1c, 0x0a, 0x09, 0x61, + 0x76, 0x61, 0x69, 0x6c, 0x61, 0x62, 0x6c, 0x65, 0x18, 0x02, 0x20, 0x01, 0x28, 0x08, 0x52, 0x09, + 0x61, 0x76, 0x61, 0x69, 0x6c, 0x61, 0x62, 0x6c, 0x65, 0x12, 0x14, 0x0a, 0x05, 0x65, 0x72, 0x72, + 0x6f, 0x72, 0x18, 0x03, 0x20, 0x01, 0x28, 0x09, 0x52, 0x05, 0x65, 0x72, 0x72, 0x6f, 0x72, 0x22, + 0x72, 0x0a, 0x0c, 0x4e, 0x53, 0x47, 0x72, 0x6f, 0x75, 0x70, 0x53, 0x74, 0x61, 0x74, 0x65, 0x12, + 0x18, 0x0a, 0x07, 0x73, 0x65, 0x72, 0x76, 0x65, 0x72, 0x73, 0x18, 0x01, 0x20, 0x03, 0x28, 0x09, + 0x52, 0x07, 0x73, 0x65, 0x72, 0x76, 0x65, 0x72, 0x73, 0x12, 0x18, 0x0a, 0x07, 0x64, 0x6f, 0x6d, + 0x61, 0x69, 0x6e, 0x73, 0x18, 0x02, 0x20, 0x03, 0x28, 0x09, 0x52, 0x07, 0x64, 0x6f, 0x6d, 0x61, + 0x69, 0x6e, 0x73, 0x12, 0x18, 0x0a, 0x07, 0x65, 0x6e, 0x61, 0x62, 0x6c, 0x65, 0x64, 0x18, 0x03, + 0x20, 0x01, 0x28, 0x08, 0x52, 0x07, 0x65, 0x6e, 0x61, 0x62, 0x6c, 0x65, 0x64, 0x12, 0x14, 0x0a, + 0x05, 0x65, 0x72, 0x72, 0x6f, 0x72, 0x18, 0x04, 0x20, 0x01, 0x28, 0x09, 0x52, 0x05, 0x65, 0x72, + 0x72, 0x6f, 0x72, 0x22, 0xd2, 0x02, 0x0a, 0x0a, 0x46, 0x75, 0x6c, 0x6c, 0x53, 0x74, 0x61, 0x74, + 0x75, 0x73, 
0x12, 0x41, 0x0a, 0x0f, 0x6d, 0x61, 0x6e, 0x61, 0x67, 0x65, 0x6d, 0x65, 0x6e, 0x74, + 0x53, 0x74, 0x61, 0x74, 0x65, 0x18, 0x01, 0x20, 0x01, 0x28, 0x0b, 0x32, 0x17, 0x2e, 0x64, 0x61, + 0x65, 0x6d, 0x6f, 0x6e, 0x2e, 0x4d, 0x61, 0x6e, 0x61, 0x67, 0x65, 0x6d, 0x65, 0x6e, 0x74, 0x53, + 0x74, 0x61, 0x74, 0x65, 0x52, 0x0f, 0x6d, 0x61, 0x6e, 0x61, 0x67, 0x65, 0x6d, 0x65, 0x6e, 0x74, + 0x53, 0x74, 0x61, 0x74, 0x65, 0x12, 0x35, 0x0a, 0x0b, 0x73, 0x69, 0x67, 0x6e, 0x61, 0x6c, 0x53, + 0x74, 0x61, 0x74, 0x65, 0x18, 0x02, 0x20, 0x01, 0x28, 0x0b, 0x32, 0x13, 0x2e, 0x64, 0x61, 0x65, + 0x6d, 0x6f, 0x6e, 0x2e, 0x53, 0x69, 0x67, 0x6e, 0x61, 0x6c, 0x53, 0x74, 0x61, 0x74, 0x65, 0x52, + 0x0b, 0x73, 0x69, 0x67, 0x6e, 0x61, 0x6c, 0x53, 0x74, 0x61, 0x74, 0x65, 0x12, 0x3e, 0x0a, 0x0e, + 0x6c, 0x6f, 0x63, 0x61, 0x6c, 0x50, 0x65, 0x65, 0x72, 0x53, 0x74, 0x61, 0x74, 0x65, 0x18, 0x03, + 0x20, 0x01, 0x28, 0x0b, 0x32, 0x16, 0x2e, 0x64, 0x61, 0x65, 0x6d, 0x6f, 0x6e, 0x2e, 0x4c, 0x6f, + 0x63, 0x61, 0x6c, 0x50, 0x65, 0x65, 0x72, 0x53, 0x74, 0x61, 0x74, 0x65, 0x52, 0x0e, 0x6c, 0x6f, + 0x63, 0x61, 0x6c, 0x50, 0x65, 0x65, 0x72, 0x53, 0x74, 0x61, 0x74, 0x65, 0x12, 0x27, 0x0a, 0x05, + 0x70, 0x65, 0x65, 0x72, 0x73, 0x18, 0x04, 0x20, 0x03, 0x28, 0x0b, 0x32, 0x11, 0x2e, 0x64, 0x61, + 0x65, 0x6d, 0x6f, 0x6e, 0x2e, 0x50, 0x65, 0x65, 0x72, 0x53, 0x74, 0x61, 0x74, 0x65, 0x52, 0x05, + 0x70, 0x65, 0x65, 0x72, 0x73, 0x12, 0x2a, 0x0a, 0x06, 0x72, 0x65, 0x6c, 0x61, 0x79, 0x73, 0x18, + 0x05, 0x20, 0x03, 0x28, 0x0b, 0x32, 0x12, 0x2e, 0x64, 0x61, 0x65, 0x6d, 0x6f, 0x6e, 0x2e, 0x52, + 0x65, 0x6c, 0x61, 0x79, 0x53, 0x74, 0x61, 0x74, 0x65, 0x52, 0x06, 0x72, 0x65, 0x6c, 0x61, 0x79, + 0x73, 0x12, 0x35, 0x0a, 0x0b, 0x64, 0x6e, 0x73, 0x5f, 0x73, 0x65, 0x72, 0x76, 0x65, 0x72, 0x73, + 0x18, 0x06, 0x20, 0x03, 0x28, 0x0b, 0x32, 0x14, 0x2e, 0x64, 0x61, 0x65, 0x6d, 0x6f, 0x6e, 0x2e, + 0x4e, 0x53, 0x47, 0x72, 0x6f, 0x75, 0x70, 0x53, 0x74, 0x61, 0x74, 0x65, 0x52, 0x0a, 0x64, 0x6e, + 0x73, 0x53, 0x65, 0x72, 0x76, 0x65, 0x72, 0x73, 0x22, 0x15, 0x0a, 0x13, 0x4c, 0x69, 0x73, 0x74, + 0x4e, 0x65, 0x74, 0x77, 0x6f, 0x72, 0x6b, 0x73, 0x52, 0x65, 0x71, 0x75, 0x65, 0x73, 0x74, 0x22, + 0x3f, 0x0a, 0x14, 0x4c, 0x69, 0x73, 0x74, 0x4e, 0x65, 0x74, 0x77, 0x6f, 0x72, 0x6b, 0x73, 0x52, + 0x65, 0x73, 0x70, 0x6f, 0x6e, 0x73, 0x65, 0x12, 0x27, 0x0a, 0x06, 0x72, 0x6f, 0x75, 0x74, 0x65, + 0x73, 0x18, 0x01, 0x20, 0x03, 0x28, 0x0b, 0x32, 0x0f, 0x2e, 0x64, 0x61, 0x65, 0x6d, 0x6f, 0x6e, + 0x2e, 0x4e, 0x65, 0x74, 0x77, 0x6f, 0x72, 0x6b, 0x52, 0x06, 0x72, 0x6f, 0x75, 0x74, 0x65, 0x73, + 0x22, 0x61, 0x0a, 0x15, 0x53, 0x65, 0x6c, 0x65, 0x63, 0x74, 0x4e, 0x65, 0x74, 0x77, 0x6f, 0x72, + 0x6b, 0x73, 0x52, 0x65, 0x71, 0x75, 0x65, 0x73, 0x74, 0x12, 0x1e, 0x0a, 0x0a, 0x6e, 0x65, 0x74, + 0x77, 0x6f, 0x72, 0x6b, 0x49, 0x44, 0x73, 0x18, 0x01, 0x20, 0x03, 0x28, 0x09, 0x52, 0x0a, 0x6e, + 0x65, 0x74, 0x77, 0x6f, 0x72, 0x6b, 0x49, 0x44, 0x73, 0x12, 0x16, 0x0a, 0x06, 0x61, 0x70, 0x70, + 0x65, 0x6e, 0x64, 0x18, 0x02, 0x20, 0x01, 0x28, 0x08, 0x52, 0x06, 0x61, 0x70, 0x70, 0x65, 0x6e, + 0x64, 0x12, 0x10, 0x0a, 0x03, 0x61, 0x6c, 0x6c, 0x18, 0x03, 0x20, 0x01, 0x28, 0x08, 0x52, 0x03, + 0x61, 0x6c, 0x6c, 0x22, 0x18, 0x0a, 0x16, 0x53, 0x65, 0x6c, 0x65, 0x63, 0x74, 0x4e, 0x65, 0x74, + 0x77, 0x6f, 0x72, 0x6b, 0x73, 0x52, 0x65, 0x73, 0x70, 0x6f, 0x6e, 0x73, 0x65, 0x22, 0x1a, 0x0a, + 0x06, 0x49, 0x50, 0x4c, 0x69, 0x73, 0x74, 0x12, 0x10, 0x0a, 0x03, 0x69, 0x70, 0x73, 0x18, 0x01, + 0x20, 0x03, 0x28, 0x09, 0x52, 0x03, 0x69, 0x70, 0x73, 0x22, 0xf9, 0x01, 0x0a, 0x07, 0x4e, 0x65, + 0x74, 0x77, 0x6f, 0x72, 0x6b, 0x12, 
0x0e, 0x0a, 0x02, 0x49, 0x44, 0x18, 0x01, 0x20, 0x01, 0x28, + 0x09, 0x52, 0x02, 0x49, 0x44, 0x12, 0x14, 0x0a, 0x05, 0x72, 0x61, 0x6e, 0x67, 0x65, 0x18, 0x02, + 0x20, 0x01, 0x28, 0x09, 0x52, 0x05, 0x72, 0x61, 0x6e, 0x67, 0x65, 0x12, 0x1a, 0x0a, 0x08, 0x73, + 0x65, 0x6c, 0x65, 0x63, 0x74, 0x65, 0x64, 0x18, 0x03, 0x20, 0x01, 0x28, 0x08, 0x52, 0x08, 0x73, + 0x65, 0x6c, 0x65, 0x63, 0x74, 0x65, 0x64, 0x12, 0x18, 0x0a, 0x07, 0x64, 0x6f, 0x6d, 0x61, 0x69, + 0x6e, 0x73, 0x18, 0x04, 0x20, 0x03, 0x28, 0x09, 0x52, 0x07, 0x64, 0x6f, 0x6d, 0x61, 0x69, 0x6e, + 0x73, 0x12, 0x42, 0x0a, 0x0b, 0x72, 0x65, 0x73, 0x6f, 0x6c, 0x76, 0x65, 0x64, 0x49, 0x50, 0x73, + 0x18, 0x05, 0x20, 0x03, 0x28, 0x0b, 0x32, 0x20, 0x2e, 0x64, 0x61, 0x65, 0x6d, 0x6f, 0x6e, 0x2e, + 0x4e, 0x65, 0x74, 0x77, 0x6f, 0x72, 0x6b, 0x2e, 0x52, 0x65, 0x73, 0x6f, 0x6c, 0x76, 0x65, 0x64, + 0x49, 0x50, 0x73, 0x45, 0x6e, 0x74, 0x72, 0x79, 0x52, 0x0b, 0x72, 0x65, 0x73, 0x6f, 0x6c, 0x76, + 0x65, 0x64, 0x49, 0x50, 0x73, 0x1a, 0x4e, 0x0a, 0x10, 0x52, 0x65, 0x73, 0x6f, 0x6c, 0x76, 0x65, + 0x64, 0x49, 0x50, 0x73, 0x45, 0x6e, 0x74, 0x72, 0x79, 0x12, 0x10, 0x0a, 0x03, 0x6b, 0x65, 0x79, + 0x18, 0x01, 0x20, 0x01, 0x28, 0x09, 0x52, 0x03, 0x6b, 0x65, 0x79, 0x12, 0x24, 0x0a, 0x05, 0x76, + 0x61, 0x6c, 0x75, 0x65, 0x18, 0x02, 0x20, 0x01, 0x28, 0x0b, 0x32, 0x0e, 0x2e, 0x64, 0x61, 0x65, + 0x6d, 0x6f, 0x6e, 0x2e, 0x49, 0x50, 0x4c, 0x69, 0x73, 0x74, 0x52, 0x05, 0x76, 0x61, 0x6c, 0x75, + 0x65, 0x3a, 0x02, 0x38, 0x01, 0x22, 0x6a, 0x0a, 0x12, 0x44, 0x65, 0x62, 0x75, 0x67, 0x42, 0x75, + 0x6e, 0x64, 0x6c, 0x65, 0x52, 0x65, 0x71, 0x75, 0x65, 0x73, 0x74, 0x12, 0x1c, 0x0a, 0x09, 0x61, + 0x6e, 0x6f, 0x6e, 0x79, 0x6d, 0x69, 0x7a, 0x65, 0x18, 0x01, 0x20, 0x01, 0x28, 0x08, 0x52, 0x09, + 0x61, 0x6e, 0x6f, 0x6e, 0x79, 0x6d, 0x69, 0x7a, 0x65, 0x12, 0x16, 0x0a, 0x06, 0x73, 0x74, 0x61, + 0x74, 0x75, 0x73, 0x18, 0x02, 0x20, 0x01, 0x28, 0x09, 0x52, 0x06, 0x73, 0x74, 0x61, 0x74, 0x75, + 0x73, 0x12, 0x1e, 0x0a, 0x0a, 0x73, 0x79, 0x73, 0x74, 0x65, 0x6d, 0x49, 0x6e, 0x66, 0x6f, 0x18, + 0x03, 0x20, 0x01, 0x28, 0x08, 0x52, 0x0a, 0x73, 0x79, 0x73, 0x74, 0x65, 0x6d, 0x49, 0x6e, 0x66, + 0x6f, 0x22, 0x29, 0x0a, 0x13, 0x44, 0x65, 0x62, 0x75, 0x67, 0x42, 0x75, 0x6e, 0x64, 0x6c, 0x65, + 0x52, 0x65, 0x73, 0x70, 0x6f, 0x6e, 0x73, 0x65, 0x12, 0x12, 0x0a, 0x04, 0x70, 0x61, 0x74, 0x68, + 0x18, 0x01, 0x20, 0x01, 0x28, 0x09, 0x52, 0x04, 0x70, 0x61, 0x74, 0x68, 0x22, 0x14, 0x0a, 0x12, + 0x47, 0x65, 0x74, 0x4c, 0x6f, 0x67, 0x4c, 0x65, 0x76, 0x65, 0x6c, 0x52, 0x65, 0x71, 0x75, 0x65, + 0x73, 0x74, 0x22, 0x3d, 0x0a, 0x13, 0x47, 0x65, 0x74, 0x4c, 0x6f, 0x67, 0x4c, 0x65, 0x76, 0x65, + 0x6c, 0x52, 0x65, 0x73, 0x70, 0x6f, 0x6e, 0x73, 0x65, 0x12, 0x26, 0x0a, 0x05, 0x6c, 0x65, 0x76, + 0x65, 0x6c, 0x18, 0x01, 0x20, 0x01, 0x28, 0x0e, 0x32, 0x10, 0x2e, 0x64, 0x61, 0x65, 0x6d, 0x6f, + 0x6e, 0x2e, 0x4c, 0x6f, 0x67, 0x4c, 0x65, 0x76, 0x65, 0x6c, 0x52, 0x05, 0x6c, 0x65, 0x76, 0x65, + 0x6c, 0x22, 0x3c, 0x0a, 0x12, 0x53, 0x65, 0x74, 0x4c, 0x6f, 0x67, 0x4c, 0x65, 0x76, 0x65, 0x6c, + 0x52, 0x65, 0x71, 0x75, 0x65, 0x73, 0x74, 0x12, 0x26, 0x0a, 0x05, 0x6c, 0x65, 0x76, 0x65, 0x6c, + 0x18, 0x01, 0x20, 0x01, 0x28, 0x0e, 0x32, 0x10, 0x2e, 0x64, 0x61, 0x65, 0x6d, 0x6f, 0x6e, 0x2e, + 0x4c, 0x6f, 0x67, 0x4c, 0x65, 0x76, 0x65, 0x6c, 0x52, 0x05, 0x6c, 0x65, 0x76, 0x65, 0x6c, 0x22, + 0x15, 0x0a, 0x13, 0x53, 0x65, 0x74, 0x4c, 0x6f, 0x67, 0x4c, 0x65, 0x76, 0x65, 0x6c, 0x52, 0x65, + 0x73, 0x70, 0x6f, 0x6e, 0x73, 0x65, 0x22, 0x1b, 0x0a, 0x05, 0x53, 0x74, 0x61, 0x74, 0x65, 0x12, + 0x12, 0x0a, 0x04, 0x6e, 0x61, 0x6d, 0x65, 0x18, 0x01, 0x20, 
0x01, 0x28, 0x09, 0x52, 0x04, 0x6e, + 0x61, 0x6d, 0x65, 0x22, 0x13, 0x0a, 0x11, 0x4c, 0x69, 0x73, 0x74, 0x53, 0x74, 0x61, 0x74, 0x65, + 0x73, 0x52, 0x65, 0x71, 0x75, 0x65, 0x73, 0x74, 0x22, 0x3b, 0x0a, 0x12, 0x4c, 0x69, 0x73, 0x74, + 0x53, 0x74, 0x61, 0x74, 0x65, 0x73, 0x52, 0x65, 0x73, 0x70, 0x6f, 0x6e, 0x73, 0x65, 0x12, 0x25, + 0x0a, 0x06, 0x73, 0x74, 0x61, 0x74, 0x65, 0x73, 0x18, 0x01, 0x20, 0x03, 0x28, 0x0b, 0x32, 0x0d, + 0x2e, 0x64, 0x61, 0x65, 0x6d, 0x6f, 0x6e, 0x2e, 0x53, 0x74, 0x61, 0x74, 0x65, 0x52, 0x06, 0x73, + 0x74, 0x61, 0x74, 0x65, 0x73, 0x22, 0x44, 0x0a, 0x11, 0x43, 0x6c, 0x65, 0x61, 0x6e, 0x53, 0x74, + 0x61, 0x74, 0x65, 0x52, 0x65, 0x71, 0x75, 0x65, 0x73, 0x74, 0x12, 0x1d, 0x0a, 0x0a, 0x73, 0x74, + 0x61, 0x74, 0x65, 0x5f, 0x6e, 0x61, 0x6d, 0x65, 0x18, 0x01, 0x20, 0x01, 0x28, 0x09, 0x52, 0x09, + 0x73, 0x74, 0x61, 0x74, 0x65, 0x4e, 0x61, 0x6d, 0x65, 0x12, 0x10, 0x0a, 0x03, 0x61, 0x6c, 0x6c, + 0x18, 0x02, 0x20, 0x01, 0x28, 0x08, 0x52, 0x03, 0x61, 0x6c, 0x6c, 0x22, 0x3b, 0x0a, 0x12, 0x43, + 0x6c, 0x65, 0x61, 0x6e, 0x53, 0x74, 0x61, 0x74, 0x65, 0x52, 0x65, 0x73, 0x70, 0x6f, 0x6e, 0x73, + 0x65, 0x12, 0x25, 0x0a, 0x0e, 0x63, 0x6c, 0x65, 0x61, 0x6e, 0x65, 0x64, 0x5f, 0x73, 0x74, 0x61, + 0x74, 0x65, 0x73, 0x18, 0x01, 0x20, 0x01, 0x28, 0x05, 0x52, 0x0d, 0x63, 0x6c, 0x65, 0x61, 0x6e, + 0x65, 0x64, 0x53, 0x74, 0x61, 0x74, 0x65, 0x73, 0x22, 0x45, 0x0a, 0x12, 0x44, 0x65, 0x6c, 0x65, + 0x74, 0x65, 0x53, 0x74, 0x61, 0x74, 0x65, 0x52, 0x65, 0x71, 0x75, 0x65, 0x73, 0x74, 0x12, 0x1d, + 0x0a, 0x0a, 0x73, 0x74, 0x61, 0x74, 0x65, 0x5f, 0x6e, 0x61, 0x6d, 0x65, 0x18, 0x01, 0x20, 0x01, + 0x28, 0x09, 0x52, 0x09, 0x73, 0x74, 0x61, 0x74, 0x65, 0x4e, 0x61, 0x6d, 0x65, 0x12, 0x10, 0x0a, + 0x03, 0x61, 0x6c, 0x6c, 0x18, 0x02, 0x20, 0x01, 0x28, 0x08, 0x52, 0x03, 0x61, 0x6c, 0x6c, 0x22, + 0x3c, 0x0a, 0x13, 0x44, 0x65, 0x6c, 0x65, 0x74, 0x65, 0x53, 0x74, 0x61, 0x74, 0x65, 0x52, 0x65, + 0x73, 0x70, 0x6f, 0x6e, 0x73, 0x65, 0x12, 0x25, 0x0a, 0x0e, 0x64, 0x65, 0x6c, 0x65, 0x74, 0x65, + 0x64, 0x5f, 0x73, 0x74, 0x61, 0x74, 0x65, 0x73, 0x18, 0x01, 0x20, 0x01, 0x28, 0x05, 0x52, 0x0d, + 0x64, 0x65, 0x6c, 0x65, 0x74, 0x65, 0x64, 0x53, 0x74, 0x61, 0x74, 0x65, 0x73, 0x22, 0x3b, 0x0a, + 0x1f, 0x53, 0x65, 0x74, 0x4e, 0x65, 0x74, 0x77, 0x6f, 0x72, 0x6b, 0x4d, 0x61, 0x70, 0x50, 0x65, + 0x72, 0x73, 0x69, 0x73, 0x74, 0x65, 0x6e, 0x63, 0x65, 0x52, 0x65, 0x71, 0x75, 0x65, 0x73, 0x74, + 0x12, 0x18, 0x0a, 0x07, 0x65, 0x6e, 0x61, 0x62, 0x6c, 0x65, 0x64, 0x18, 0x01, 0x20, 0x01, 0x28, + 0x08, 0x52, 0x07, 0x65, 0x6e, 0x61, 0x62, 0x6c, 0x65, 0x64, 0x22, 0x22, 0x0a, 0x20, 0x53, 0x65, 0x74, 0x4e, 0x65, 0x74, 0x77, 0x6f, 0x72, 0x6b, 0x4d, 0x61, 0x70, 0x50, 0x65, 0x72, 0x73, 0x69, - 0x73, 0x74, 0x65, 0x6e, 0x63, 0x65, 0x52, 0x65, 0x73, 0x70, 0x6f, 0x6e, 0x73, 0x65, 0x22, 0x00, - 0x42, 0x08, 0x5a, 0x06, 0x2f, 0x70, 0x72, 0x6f, 0x74, 0x6f, 0x62, 0x06, 0x70, 0x72, 0x6f, 0x74, - 0x6f, 0x33, + 0x73, 0x74, 0x65, 0x6e, 0x63, 0x65, 0x52, 0x65, 0x73, 0x70, 0x6f, 0x6e, 0x73, 0x65, 0x2a, 0x62, + 0x0a, 0x08, 0x4c, 0x6f, 0x67, 0x4c, 0x65, 0x76, 0x65, 0x6c, 0x12, 0x0b, 0x0a, 0x07, 0x55, 0x4e, + 0x4b, 0x4e, 0x4f, 0x57, 0x4e, 0x10, 0x00, 0x12, 0x09, 0x0a, 0x05, 0x50, 0x41, 0x4e, 0x49, 0x43, + 0x10, 0x01, 0x12, 0x09, 0x0a, 0x05, 0x46, 0x41, 0x54, 0x41, 0x4c, 0x10, 0x02, 0x12, 0x09, 0x0a, + 0x05, 0x45, 0x52, 0x52, 0x4f, 0x52, 0x10, 0x03, 0x12, 0x08, 0x0a, 0x04, 0x57, 0x41, 0x52, 0x4e, + 0x10, 0x04, 0x12, 0x08, 0x0a, 0x04, 0x49, 0x4e, 0x46, 0x4f, 0x10, 0x05, 0x12, 0x09, 0x0a, 0x05, + 0x44, 0x45, 0x42, 0x55, 0x47, 0x10, 0x06, 0x12, 0x09, 0x0a, 0x05, 0x54, 
0x52, 0x41, 0x43, 0x45, + 0x10, 0x07, 0x32, 0x93, 0x09, 0x0a, 0x0d, 0x44, 0x61, 0x65, 0x6d, 0x6f, 0x6e, 0x53, 0x65, 0x72, + 0x76, 0x69, 0x63, 0x65, 0x12, 0x36, 0x0a, 0x05, 0x4c, 0x6f, 0x67, 0x69, 0x6e, 0x12, 0x14, 0x2e, + 0x64, 0x61, 0x65, 0x6d, 0x6f, 0x6e, 0x2e, 0x4c, 0x6f, 0x67, 0x69, 0x6e, 0x52, 0x65, 0x71, 0x75, + 0x65, 0x73, 0x74, 0x1a, 0x15, 0x2e, 0x64, 0x61, 0x65, 0x6d, 0x6f, 0x6e, 0x2e, 0x4c, 0x6f, 0x67, + 0x69, 0x6e, 0x52, 0x65, 0x73, 0x70, 0x6f, 0x6e, 0x73, 0x65, 0x22, 0x00, 0x12, 0x4b, 0x0a, 0x0c, + 0x57, 0x61, 0x69, 0x74, 0x53, 0x53, 0x4f, 0x4c, 0x6f, 0x67, 0x69, 0x6e, 0x12, 0x1b, 0x2e, 0x64, + 0x61, 0x65, 0x6d, 0x6f, 0x6e, 0x2e, 0x57, 0x61, 0x69, 0x74, 0x53, 0x53, 0x4f, 0x4c, 0x6f, 0x67, + 0x69, 0x6e, 0x52, 0x65, 0x71, 0x75, 0x65, 0x73, 0x74, 0x1a, 0x1c, 0x2e, 0x64, 0x61, 0x65, 0x6d, + 0x6f, 0x6e, 0x2e, 0x57, 0x61, 0x69, 0x74, 0x53, 0x53, 0x4f, 0x4c, 0x6f, 0x67, 0x69, 0x6e, 0x52, + 0x65, 0x73, 0x70, 0x6f, 0x6e, 0x73, 0x65, 0x22, 0x00, 0x12, 0x2d, 0x0a, 0x02, 0x55, 0x70, 0x12, + 0x11, 0x2e, 0x64, 0x61, 0x65, 0x6d, 0x6f, 0x6e, 0x2e, 0x55, 0x70, 0x52, 0x65, 0x71, 0x75, 0x65, + 0x73, 0x74, 0x1a, 0x12, 0x2e, 0x64, 0x61, 0x65, 0x6d, 0x6f, 0x6e, 0x2e, 0x55, 0x70, 0x52, 0x65, + 0x73, 0x70, 0x6f, 0x6e, 0x73, 0x65, 0x22, 0x00, 0x12, 0x39, 0x0a, 0x06, 0x53, 0x74, 0x61, 0x74, + 0x75, 0x73, 0x12, 0x15, 0x2e, 0x64, 0x61, 0x65, 0x6d, 0x6f, 0x6e, 0x2e, 0x53, 0x74, 0x61, 0x74, + 0x75, 0x73, 0x52, 0x65, 0x71, 0x75, 0x65, 0x73, 0x74, 0x1a, 0x16, 0x2e, 0x64, 0x61, 0x65, 0x6d, + 0x6f, 0x6e, 0x2e, 0x53, 0x74, 0x61, 0x74, 0x75, 0x73, 0x52, 0x65, 0x73, 0x70, 0x6f, 0x6e, 0x73, + 0x65, 0x22, 0x00, 0x12, 0x33, 0x0a, 0x04, 0x44, 0x6f, 0x77, 0x6e, 0x12, 0x13, 0x2e, 0x64, 0x61, + 0x65, 0x6d, 0x6f, 0x6e, 0x2e, 0x44, 0x6f, 0x77, 0x6e, 0x52, 0x65, 0x71, 0x75, 0x65, 0x73, 0x74, + 0x1a, 0x14, 0x2e, 0x64, 0x61, 0x65, 0x6d, 0x6f, 0x6e, 0x2e, 0x44, 0x6f, 0x77, 0x6e, 0x52, 0x65, + 0x73, 0x70, 0x6f, 0x6e, 0x73, 0x65, 0x22, 0x00, 0x12, 0x42, 0x0a, 0x09, 0x47, 0x65, 0x74, 0x43, + 0x6f, 0x6e, 0x66, 0x69, 0x67, 0x12, 0x18, 0x2e, 0x64, 0x61, 0x65, 0x6d, 0x6f, 0x6e, 0x2e, 0x47, + 0x65, 0x74, 0x43, 0x6f, 0x6e, 0x66, 0x69, 0x67, 0x52, 0x65, 0x71, 0x75, 0x65, 0x73, 0x74, 0x1a, + 0x19, 0x2e, 0x64, 0x61, 0x65, 0x6d, 0x6f, 0x6e, 0x2e, 0x47, 0x65, 0x74, 0x43, 0x6f, 0x6e, 0x66, + 0x69, 0x67, 0x52, 0x65, 0x73, 0x70, 0x6f, 0x6e, 0x73, 0x65, 0x22, 0x00, 0x12, 0x4b, 0x0a, 0x0c, + 0x4c, 0x69, 0x73, 0x74, 0x4e, 0x65, 0x74, 0x77, 0x6f, 0x72, 0x6b, 0x73, 0x12, 0x1b, 0x2e, 0x64, + 0x61, 0x65, 0x6d, 0x6f, 0x6e, 0x2e, 0x4c, 0x69, 0x73, 0x74, 0x4e, 0x65, 0x74, 0x77, 0x6f, 0x72, + 0x6b, 0x73, 0x52, 0x65, 0x71, 0x75, 0x65, 0x73, 0x74, 0x1a, 0x1c, 0x2e, 0x64, 0x61, 0x65, 0x6d, + 0x6f, 0x6e, 0x2e, 0x4c, 0x69, 0x73, 0x74, 0x4e, 0x65, 0x74, 0x77, 0x6f, 0x72, 0x6b, 0x73, 0x52, + 0x65, 0x73, 0x70, 0x6f, 0x6e, 0x73, 0x65, 0x22, 0x00, 0x12, 0x51, 0x0a, 0x0e, 0x53, 0x65, 0x6c, + 0x65, 0x63, 0x74, 0x4e, 0x65, 0x74, 0x77, 0x6f, 0x72, 0x6b, 0x73, 0x12, 0x1d, 0x2e, 0x64, 0x61, + 0x65, 0x6d, 0x6f, 0x6e, 0x2e, 0x53, 0x65, 0x6c, 0x65, 0x63, 0x74, 0x4e, 0x65, 0x74, 0x77, 0x6f, + 0x72, 0x6b, 0x73, 0x52, 0x65, 0x71, 0x75, 0x65, 0x73, 0x74, 0x1a, 0x1e, 0x2e, 0x64, 0x61, 0x65, + 0x6d, 0x6f, 0x6e, 0x2e, 0x53, 0x65, 0x6c, 0x65, 0x63, 0x74, 0x4e, 0x65, 0x74, 0x77, 0x6f, 0x72, + 0x6b, 0x73, 0x52, 0x65, 0x73, 0x70, 0x6f, 0x6e, 0x73, 0x65, 0x22, 0x00, 0x12, 0x53, 0x0a, 0x10, + 0x44, 0x65, 0x73, 0x65, 0x6c, 0x65, 0x63, 0x74, 0x4e, 0x65, 0x74, 0x77, 0x6f, 0x72, 0x6b, 0x73, + 0x12, 0x1d, 0x2e, 0x64, 0x61, 0x65, 0x6d, 0x6f, 0x6e, 0x2e, 0x53, 0x65, 0x6c, 0x65, 0x63, 0x74, + 
0x4e, 0x65, 0x74, 0x77, 0x6f, 0x72, 0x6b, 0x73, 0x52, 0x65, 0x71, 0x75, 0x65, 0x73, 0x74, 0x1a, + 0x1e, 0x2e, 0x64, 0x61, 0x65, 0x6d, 0x6f, 0x6e, 0x2e, 0x53, 0x65, 0x6c, 0x65, 0x63, 0x74, 0x4e, + 0x65, 0x74, 0x77, 0x6f, 0x72, 0x6b, 0x73, 0x52, 0x65, 0x73, 0x70, 0x6f, 0x6e, 0x73, 0x65, 0x22, + 0x00, 0x12, 0x48, 0x0a, 0x0b, 0x44, 0x65, 0x62, 0x75, 0x67, 0x42, 0x75, 0x6e, 0x64, 0x6c, 0x65, + 0x12, 0x1a, 0x2e, 0x64, 0x61, 0x65, 0x6d, 0x6f, 0x6e, 0x2e, 0x44, 0x65, 0x62, 0x75, 0x67, 0x42, + 0x75, 0x6e, 0x64, 0x6c, 0x65, 0x52, 0x65, 0x71, 0x75, 0x65, 0x73, 0x74, 0x1a, 0x1b, 0x2e, 0x64, + 0x61, 0x65, 0x6d, 0x6f, 0x6e, 0x2e, 0x44, 0x65, 0x62, 0x75, 0x67, 0x42, 0x75, 0x6e, 0x64, 0x6c, + 0x65, 0x52, 0x65, 0x73, 0x70, 0x6f, 0x6e, 0x73, 0x65, 0x22, 0x00, 0x12, 0x48, 0x0a, 0x0b, 0x47, + 0x65, 0x74, 0x4c, 0x6f, 0x67, 0x4c, 0x65, 0x76, 0x65, 0x6c, 0x12, 0x1a, 0x2e, 0x64, 0x61, 0x65, + 0x6d, 0x6f, 0x6e, 0x2e, 0x47, 0x65, 0x74, 0x4c, 0x6f, 0x67, 0x4c, 0x65, 0x76, 0x65, 0x6c, 0x52, + 0x65, 0x71, 0x75, 0x65, 0x73, 0x74, 0x1a, 0x1b, 0x2e, 0x64, 0x61, 0x65, 0x6d, 0x6f, 0x6e, 0x2e, + 0x47, 0x65, 0x74, 0x4c, 0x6f, 0x67, 0x4c, 0x65, 0x76, 0x65, 0x6c, 0x52, 0x65, 0x73, 0x70, 0x6f, + 0x6e, 0x73, 0x65, 0x22, 0x00, 0x12, 0x48, 0x0a, 0x0b, 0x53, 0x65, 0x74, 0x4c, 0x6f, 0x67, 0x4c, + 0x65, 0x76, 0x65, 0x6c, 0x12, 0x1a, 0x2e, 0x64, 0x61, 0x65, 0x6d, 0x6f, 0x6e, 0x2e, 0x53, 0x65, + 0x74, 0x4c, 0x6f, 0x67, 0x4c, 0x65, 0x76, 0x65, 0x6c, 0x52, 0x65, 0x71, 0x75, 0x65, 0x73, 0x74, + 0x1a, 0x1b, 0x2e, 0x64, 0x61, 0x65, 0x6d, 0x6f, 0x6e, 0x2e, 0x53, 0x65, 0x74, 0x4c, 0x6f, 0x67, + 0x4c, 0x65, 0x76, 0x65, 0x6c, 0x52, 0x65, 0x73, 0x70, 0x6f, 0x6e, 0x73, 0x65, 0x22, 0x00, 0x12, + 0x45, 0x0a, 0x0a, 0x4c, 0x69, 0x73, 0x74, 0x53, 0x74, 0x61, 0x74, 0x65, 0x73, 0x12, 0x19, 0x2e, + 0x64, 0x61, 0x65, 0x6d, 0x6f, 0x6e, 0x2e, 0x4c, 0x69, 0x73, 0x74, 0x53, 0x74, 0x61, 0x74, 0x65, + 0x73, 0x52, 0x65, 0x71, 0x75, 0x65, 0x73, 0x74, 0x1a, 0x1a, 0x2e, 0x64, 0x61, 0x65, 0x6d, 0x6f, + 0x6e, 0x2e, 0x4c, 0x69, 0x73, 0x74, 0x53, 0x74, 0x61, 0x74, 0x65, 0x73, 0x52, 0x65, 0x73, 0x70, + 0x6f, 0x6e, 0x73, 0x65, 0x22, 0x00, 0x12, 0x45, 0x0a, 0x0a, 0x43, 0x6c, 0x65, 0x61, 0x6e, 0x53, + 0x74, 0x61, 0x74, 0x65, 0x12, 0x19, 0x2e, 0x64, 0x61, 0x65, 0x6d, 0x6f, 0x6e, 0x2e, 0x43, 0x6c, + 0x65, 0x61, 0x6e, 0x53, 0x74, 0x61, 0x74, 0x65, 0x52, 0x65, 0x71, 0x75, 0x65, 0x73, 0x74, 0x1a, + 0x1a, 0x2e, 0x64, 0x61, 0x65, 0x6d, 0x6f, 0x6e, 0x2e, 0x43, 0x6c, 0x65, 0x61, 0x6e, 0x53, 0x74, + 0x61, 0x74, 0x65, 0x52, 0x65, 0x73, 0x70, 0x6f, 0x6e, 0x73, 0x65, 0x22, 0x00, 0x12, 0x48, 0x0a, + 0x0b, 0x44, 0x65, 0x6c, 0x65, 0x74, 0x65, 0x53, 0x74, 0x61, 0x74, 0x65, 0x12, 0x1a, 0x2e, 0x64, + 0x61, 0x65, 0x6d, 0x6f, 0x6e, 0x2e, 0x44, 0x65, 0x6c, 0x65, 0x74, 0x65, 0x53, 0x74, 0x61, 0x74, + 0x65, 0x52, 0x65, 0x71, 0x75, 0x65, 0x73, 0x74, 0x1a, 0x1b, 0x2e, 0x64, 0x61, 0x65, 0x6d, 0x6f, + 0x6e, 0x2e, 0x44, 0x65, 0x6c, 0x65, 0x74, 0x65, 0x53, 0x74, 0x61, 0x74, 0x65, 0x52, 0x65, 0x73, + 0x70, 0x6f, 0x6e, 0x73, 0x65, 0x22, 0x00, 0x12, 0x6f, 0x0a, 0x18, 0x53, 0x65, 0x74, 0x4e, 0x65, + 0x74, 0x77, 0x6f, 0x72, 0x6b, 0x4d, 0x61, 0x70, 0x50, 0x65, 0x72, 0x73, 0x69, 0x73, 0x74, 0x65, + 0x6e, 0x63, 0x65, 0x12, 0x27, 0x2e, 0x64, 0x61, 0x65, 0x6d, 0x6f, 0x6e, 0x2e, 0x53, 0x65, 0x74, + 0x4e, 0x65, 0x74, 0x77, 0x6f, 0x72, 0x6b, 0x4d, 0x61, 0x70, 0x50, 0x65, 0x72, 0x73, 0x69, 0x73, + 0x74, 0x65, 0x6e, 0x63, 0x65, 0x52, 0x65, 0x71, 0x75, 0x65, 0x73, 0x74, 0x1a, 0x28, 0x2e, 0x64, + 0x61, 0x65, 0x6d, 0x6f, 0x6e, 0x2e, 0x53, 0x65, 0x74, 0x4e, 0x65, 0x74, 0x77, 0x6f, 0x72, 0x6b, + 0x4d, 0x61, 0x70, 0x50, 
0x65, 0x72, 0x73, 0x69, 0x73, 0x74, 0x65, 0x6e, 0x63, 0x65, 0x52, 0x65, + 0x73, 0x70, 0x6f, 0x6e, 0x73, 0x65, 0x22, 0x00, 0x42, 0x08, 0x5a, 0x06, 0x2f, 0x70, 0x72, 0x6f, + 0x74, 0x6f, 0x62, 0x06, 0x70, 0x72, 0x6f, 0x74, 0x6f, 0x33, } var ( @@ -2974,12 +2976,12 @@ var file_daemon_proto_goTypes = []interface{}{ (*RelayState)(nil), // 17: daemon.RelayState (*NSGroupState)(nil), // 18: daemon.NSGroupState (*FullStatus)(nil), // 19: daemon.FullStatus - (*ListRoutesRequest)(nil), // 20: daemon.ListRoutesRequest - (*ListRoutesResponse)(nil), // 21: daemon.ListRoutesResponse - (*SelectRoutesRequest)(nil), // 22: daemon.SelectRoutesRequest - (*SelectRoutesResponse)(nil), // 23: daemon.SelectRoutesResponse + (*ListNetworksRequest)(nil), // 20: daemon.ListNetworksRequest + (*ListNetworksResponse)(nil), // 21: daemon.ListNetworksResponse + (*SelectNetworksRequest)(nil), // 22: daemon.SelectNetworksRequest + (*SelectNetworksResponse)(nil), // 23: daemon.SelectNetworksResponse (*IPList)(nil), // 24: daemon.IPList - (*Route)(nil), // 25: daemon.Route + (*Network)(nil), // 25: daemon.Network (*DebugBundleRequest)(nil), // 26: daemon.DebugBundleRequest (*DebugBundleResponse)(nil), // 27: daemon.DebugBundleResponse (*GetLogLevelRequest)(nil), // 28: daemon.GetLogLevelRequest @@ -2995,7 +2997,7 @@ var file_daemon_proto_goTypes = []interface{}{ (*DeleteStateResponse)(nil), // 38: daemon.DeleteStateResponse (*SetNetworkMapPersistenceRequest)(nil), // 39: daemon.SetNetworkMapPersistenceRequest (*SetNetworkMapPersistenceResponse)(nil), // 40: daemon.SetNetworkMapPersistenceResponse - nil, // 41: daemon.Route.ResolvedIPsEntry + nil, // 41: daemon.Network.ResolvedIPsEntry (*durationpb.Duration)(nil), // 42: google.protobuf.Duration (*timestamppb.Timestamp)(nil), // 43: google.protobuf.Timestamp } @@ -3011,21 +3013,21 @@ var file_daemon_proto_depIdxs = []int32{ 13, // 8: daemon.FullStatus.peers:type_name -> daemon.PeerState 17, // 9: daemon.FullStatus.relays:type_name -> daemon.RelayState 18, // 10: daemon.FullStatus.dns_servers:type_name -> daemon.NSGroupState - 25, // 11: daemon.ListRoutesResponse.routes:type_name -> daemon.Route - 41, // 12: daemon.Route.resolvedIPs:type_name -> daemon.Route.ResolvedIPsEntry + 25, // 11: daemon.ListNetworksResponse.routes:type_name -> daemon.Network + 41, // 12: daemon.Network.resolvedIPs:type_name -> daemon.Network.ResolvedIPsEntry 0, // 13: daemon.GetLogLevelResponse.level:type_name -> daemon.LogLevel 0, // 14: daemon.SetLogLevelRequest.level:type_name -> daemon.LogLevel 32, // 15: daemon.ListStatesResponse.states:type_name -> daemon.State - 24, // 16: daemon.Route.ResolvedIPsEntry.value:type_name -> daemon.IPList + 24, // 16: daemon.Network.ResolvedIPsEntry.value:type_name -> daemon.IPList 1, // 17: daemon.DaemonService.Login:input_type -> daemon.LoginRequest 3, // 18: daemon.DaemonService.WaitSSOLogin:input_type -> daemon.WaitSSOLoginRequest 5, // 19: daemon.DaemonService.Up:input_type -> daemon.UpRequest 7, // 20: daemon.DaemonService.Status:input_type -> daemon.StatusRequest 9, // 21: daemon.DaemonService.Down:input_type -> daemon.DownRequest 11, // 22: daemon.DaemonService.GetConfig:input_type -> daemon.GetConfigRequest - 20, // 23: daemon.DaemonService.ListRoutes:input_type -> daemon.ListRoutesRequest - 22, // 24: daemon.DaemonService.SelectRoutes:input_type -> daemon.SelectRoutesRequest - 22, // 25: daemon.DaemonService.DeselectRoutes:input_type -> daemon.SelectRoutesRequest + 20, // 23: daemon.DaemonService.ListNetworks:input_type -> daemon.ListNetworksRequest + 22, // 
24: daemon.DaemonService.SelectNetworks:input_type -> daemon.SelectNetworksRequest + 22, // 25: daemon.DaemonService.DeselectNetworks:input_type -> daemon.SelectNetworksRequest 26, // 26: daemon.DaemonService.DebugBundle:input_type -> daemon.DebugBundleRequest 28, // 27: daemon.DaemonService.GetLogLevel:input_type -> daemon.GetLogLevelRequest 30, // 28: daemon.DaemonService.SetLogLevel:input_type -> daemon.SetLogLevelRequest @@ -3039,9 +3041,9 @@ var file_daemon_proto_depIdxs = []int32{ 8, // 36: daemon.DaemonService.Status:output_type -> daemon.StatusResponse 10, // 37: daemon.DaemonService.Down:output_type -> daemon.DownResponse 12, // 38: daemon.DaemonService.GetConfig:output_type -> daemon.GetConfigResponse - 21, // 39: daemon.DaemonService.ListRoutes:output_type -> daemon.ListRoutesResponse - 23, // 40: daemon.DaemonService.SelectRoutes:output_type -> daemon.SelectRoutesResponse - 23, // 41: daemon.DaemonService.DeselectRoutes:output_type -> daemon.SelectRoutesResponse + 21, // 39: daemon.DaemonService.ListNetworks:output_type -> daemon.ListNetworksResponse + 23, // 40: daemon.DaemonService.SelectNetworks:output_type -> daemon.SelectNetworksResponse + 23, // 41: daemon.DaemonService.DeselectNetworks:output_type -> daemon.SelectNetworksResponse 27, // 42: daemon.DaemonService.DebugBundle:output_type -> daemon.DebugBundleResponse 29, // 43: daemon.DaemonService.GetLogLevel:output_type -> daemon.GetLogLevelResponse 31, // 44: daemon.DaemonService.SetLogLevel:output_type -> daemon.SetLogLevelResponse @@ -3291,7 +3293,7 @@ func file_daemon_proto_init() { } } file_daemon_proto_msgTypes[19].Exporter = func(v interface{}, i int) interface{} { - switch v := v.(*ListRoutesRequest); i { + switch v := v.(*ListNetworksRequest); i { case 0: return &v.state case 1: @@ -3303,7 +3305,7 @@ func file_daemon_proto_init() { } } file_daemon_proto_msgTypes[20].Exporter = func(v interface{}, i int) interface{} { - switch v := v.(*ListRoutesResponse); i { + switch v := v.(*ListNetworksResponse); i { case 0: return &v.state case 1: @@ -3315,7 +3317,7 @@ func file_daemon_proto_init() { } } file_daemon_proto_msgTypes[21].Exporter = func(v interface{}, i int) interface{} { - switch v := v.(*SelectRoutesRequest); i { + switch v := v.(*SelectNetworksRequest); i { case 0: return &v.state case 1: @@ -3327,7 +3329,7 @@ func file_daemon_proto_init() { } } file_daemon_proto_msgTypes[22].Exporter = func(v interface{}, i int) interface{} { - switch v := v.(*SelectRoutesResponse); i { + switch v := v.(*SelectNetworksResponse); i { case 0: return &v.state case 1: @@ -3351,7 +3353,7 @@ func file_daemon_proto_init() { } } file_daemon_proto_msgTypes[24].Exporter = func(v interface{}, i int) interface{} { - switch v := v.(*Route); i { + switch v := v.(*Network); i { case 0: return &v.state case 1: diff --git a/client/proto/daemon.proto b/client/proto/daemon.proto index 96ade5b4e..cddf78242 100644 --- a/client/proto/daemon.proto +++ b/client/proto/daemon.proto @@ -28,14 +28,14 @@ service DaemonService { // GetConfig of the daemon. 
rpc GetConfig(GetConfigRequest) returns (GetConfigResponse) {} - // List available network routes - rpc ListRoutes(ListRoutesRequest) returns (ListRoutesResponse) {} + // List available networks + rpc ListNetworks(ListNetworksRequest) returns (ListNetworksResponse) {} // Select specific routes - rpc SelectRoutes(SelectRoutesRequest) returns (SelectRoutesResponse) {} + rpc SelectNetworks(SelectNetworksRequest) returns (SelectNetworksResponse) {} // Deselect specific routes - rpc DeselectRoutes(SelectRoutesRequest) returns (SelectRoutesResponse) {} + rpc DeselectNetworks(SelectNetworksRequest) returns (SelectNetworksResponse) {} // DebugBundle creates a debug bundle rpc DebugBundle(DebugBundleRequest) returns (DebugBundleResponse) {} @@ -190,7 +190,7 @@ message PeerState { int64 bytesRx = 13; int64 bytesTx = 14; bool rosenpassEnabled = 15; - repeated string routes = 16; + repeated string networks = 16; google.protobuf.Duration latency = 17; string relayAddress = 18; } @@ -203,7 +203,7 @@ message LocalPeerState { string fqdn = 4; bool rosenpassEnabled = 5; bool rosenpassPermissive = 6; - repeated string routes = 7; + repeated string networks = 7; } // SignalState contains the latest state of a signal connection @@ -244,20 +244,20 @@ message FullStatus { repeated NSGroupState dns_servers = 6; } -message ListRoutesRequest { +message ListNetworksRequest { } -message ListRoutesResponse { - repeated Route routes = 1; +message ListNetworksResponse { + repeated Network routes = 1; } -message SelectRoutesRequest { - repeated string routeIDs = 1; +message SelectNetworksRequest { + repeated string networkIDs = 1; bool append = 2; bool all = 3; } -message SelectRoutesResponse { +message SelectNetworksResponse { } message IPList { @@ -265,9 +265,9 @@ message IPList { } -message Route { +message Network { string ID = 1; - string network = 2; + string range = 2; bool selected = 3; repeated string domains = 4; map resolvedIPs = 5; diff --git a/client/proto/daemon_grpc.pb.go b/client/proto/daemon_grpc.pb.go index 2e063604a..39424aee9 100644 --- a/client/proto/daemon_grpc.pb.go +++ b/client/proto/daemon_grpc.pb.go @@ -31,12 +31,12 @@ type DaemonServiceClient interface { Down(ctx context.Context, in *DownRequest, opts ...grpc.CallOption) (*DownResponse, error) // GetConfig of the daemon. 
GetConfig(ctx context.Context, in *GetConfigRequest, opts ...grpc.CallOption) (*GetConfigResponse, error) - // List available network routes - ListRoutes(ctx context.Context, in *ListRoutesRequest, opts ...grpc.CallOption) (*ListRoutesResponse, error) + // List available networks + ListNetworks(ctx context.Context, in *ListNetworksRequest, opts ...grpc.CallOption) (*ListNetworksResponse, error) // Select specific routes - SelectRoutes(ctx context.Context, in *SelectRoutesRequest, opts ...grpc.CallOption) (*SelectRoutesResponse, error) + SelectNetworks(ctx context.Context, in *SelectNetworksRequest, opts ...grpc.CallOption) (*SelectNetworksResponse, error) // Deselect specific routes - DeselectRoutes(ctx context.Context, in *SelectRoutesRequest, opts ...grpc.CallOption) (*SelectRoutesResponse, error) + DeselectNetworks(ctx context.Context, in *SelectNetworksRequest, opts ...grpc.CallOption) (*SelectNetworksResponse, error) // DebugBundle creates a debug bundle DebugBundle(ctx context.Context, in *DebugBundleRequest, opts ...grpc.CallOption) (*DebugBundleResponse, error) // GetLogLevel gets the log level of the daemon @@ -115,27 +115,27 @@ func (c *daemonServiceClient) GetConfig(ctx context.Context, in *GetConfigReques return out, nil } -func (c *daemonServiceClient) ListRoutes(ctx context.Context, in *ListRoutesRequest, opts ...grpc.CallOption) (*ListRoutesResponse, error) { - out := new(ListRoutesResponse) - err := c.cc.Invoke(ctx, "/daemon.DaemonService/ListRoutes", in, out, opts...) +func (c *daemonServiceClient) ListNetworks(ctx context.Context, in *ListNetworksRequest, opts ...grpc.CallOption) (*ListNetworksResponse, error) { + out := new(ListNetworksResponse) + err := c.cc.Invoke(ctx, "/daemon.DaemonService/ListNetworks", in, out, opts...) if err != nil { return nil, err } return out, nil } -func (c *daemonServiceClient) SelectRoutes(ctx context.Context, in *SelectRoutesRequest, opts ...grpc.CallOption) (*SelectRoutesResponse, error) { - out := new(SelectRoutesResponse) - err := c.cc.Invoke(ctx, "/daemon.DaemonService/SelectRoutes", in, out, opts...) +func (c *daemonServiceClient) SelectNetworks(ctx context.Context, in *SelectNetworksRequest, opts ...grpc.CallOption) (*SelectNetworksResponse, error) { + out := new(SelectNetworksResponse) + err := c.cc.Invoke(ctx, "/daemon.DaemonService/SelectNetworks", in, out, opts...) if err != nil { return nil, err } return out, nil } -func (c *daemonServiceClient) DeselectRoutes(ctx context.Context, in *SelectRoutesRequest, opts ...grpc.CallOption) (*SelectRoutesResponse, error) { - out := new(SelectRoutesResponse) - err := c.cc.Invoke(ctx, "/daemon.DaemonService/DeselectRoutes", in, out, opts...) +func (c *daemonServiceClient) DeselectNetworks(ctx context.Context, in *SelectNetworksRequest, opts ...grpc.CallOption) (*SelectNetworksResponse, error) { + out := new(SelectNetworksResponse) + err := c.cc.Invoke(ctx, "/daemon.DaemonService/DeselectNetworks", in, out, opts...) if err != nil { return nil, err } @@ -222,12 +222,12 @@ type DaemonServiceServer interface { Down(context.Context, *DownRequest) (*DownResponse, error) // GetConfig of the daemon. 
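A minimal, hypothetical sketch of how a caller could exercise the renamed RPCs through the generated DaemonServiceClient; the daemon address, the dial options, and the "office-lan" network ID are placeholder assumptions, while the method and field names (ListNetworks, SelectNetworks, ListNetworksRequest, SelectNetworksRequest, NetworkIDs, Append, Routes, GetID, GetRange, GetSelected) come from the generated code in this patch:

package main

import (
	"context"
	"fmt"
	"log"

	"google.golang.org/grpc"
	"google.golang.org/grpc/credentials/insecure"

	"github.com/netbirdio/netbird/client/proto"
)

func main() {
	// Assumed daemon address; the real client derives this from its configuration.
	conn, err := grpc.Dial("unix:///var/run/netbird.sock",
		grpc.WithTransportCredentials(insecure.NewCredentials()))
	if err != nil {
		log.Fatalf("dial daemon: %v", err)
	}
	defer conn.Close()

	client := proto.NewDaemonServiceClient(conn)

	// ListNetworks replaces ListRoutes; the response still carries a Routes field.
	resp, err := client.ListNetworks(context.Background(), &proto.ListNetworksRequest{})
	if err != nil {
		log.Fatalf("list networks: %v", err)
	}
	for _, n := range resp.Routes {
		fmt.Printf("%s %s selected=%v\n", n.GetID(), n.GetRange(), n.GetSelected())
	}

	// SelectNetworks replaces SelectRoutes; NetworkIDs replaces RouteIDs.
	if _, err := client.SelectNetworks(context.Background(), &proto.SelectNetworksRequest{
		NetworkIDs: []string{"office-lan"},
		Append:     true,
	}); err != nil {
		log.Fatalf("select networks: %v", err)
	}
}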
GetConfig(context.Context, *GetConfigRequest) (*GetConfigResponse, error) - // List available network routes - ListRoutes(context.Context, *ListRoutesRequest) (*ListRoutesResponse, error) + // List available networks + ListNetworks(context.Context, *ListNetworksRequest) (*ListNetworksResponse, error) // Select specific routes - SelectRoutes(context.Context, *SelectRoutesRequest) (*SelectRoutesResponse, error) + SelectNetworks(context.Context, *SelectNetworksRequest) (*SelectNetworksResponse, error) // Deselect specific routes - DeselectRoutes(context.Context, *SelectRoutesRequest) (*SelectRoutesResponse, error) + DeselectNetworks(context.Context, *SelectNetworksRequest) (*SelectNetworksResponse, error) // DebugBundle creates a debug bundle DebugBundle(context.Context, *DebugBundleRequest) (*DebugBundleResponse, error) // GetLogLevel gets the log level of the daemon @@ -267,14 +267,14 @@ func (UnimplementedDaemonServiceServer) Down(context.Context, *DownRequest) (*Do func (UnimplementedDaemonServiceServer) GetConfig(context.Context, *GetConfigRequest) (*GetConfigResponse, error) { return nil, status.Errorf(codes.Unimplemented, "method GetConfig not implemented") } -func (UnimplementedDaemonServiceServer) ListRoutes(context.Context, *ListRoutesRequest) (*ListRoutesResponse, error) { - return nil, status.Errorf(codes.Unimplemented, "method ListRoutes not implemented") +func (UnimplementedDaemonServiceServer) ListNetworks(context.Context, *ListNetworksRequest) (*ListNetworksResponse, error) { + return nil, status.Errorf(codes.Unimplemented, "method ListNetworks not implemented") } -func (UnimplementedDaemonServiceServer) SelectRoutes(context.Context, *SelectRoutesRequest) (*SelectRoutesResponse, error) { - return nil, status.Errorf(codes.Unimplemented, "method SelectRoutes not implemented") +func (UnimplementedDaemonServiceServer) SelectNetworks(context.Context, *SelectNetworksRequest) (*SelectNetworksResponse, error) { + return nil, status.Errorf(codes.Unimplemented, "method SelectNetworks not implemented") } -func (UnimplementedDaemonServiceServer) DeselectRoutes(context.Context, *SelectRoutesRequest) (*SelectRoutesResponse, error) { - return nil, status.Errorf(codes.Unimplemented, "method DeselectRoutes not implemented") +func (UnimplementedDaemonServiceServer) DeselectNetworks(context.Context, *SelectNetworksRequest) (*SelectNetworksResponse, error) { + return nil, status.Errorf(codes.Unimplemented, "method DeselectNetworks not implemented") } func (UnimplementedDaemonServiceServer) DebugBundle(context.Context, *DebugBundleRequest) (*DebugBundleResponse, error) { return nil, status.Errorf(codes.Unimplemented, "method DebugBundle not implemented") @@ -418,56 +418,56 @@ func _DaemonService_GetConfig_Handler(srv interface{}, ctx context.Context, dec return interceptor(ctx, in, info, handler) } -func _DaemonService_ListRoutes_Handler(srv interface{}, ctx context.Context, dec func(interface{}) error, interceptor grpc.UnaryServerInterceptor) (interface{}, error) { - in := new(ListRoutesRequest) +func _DaemonService_ListNetworks_Handler(srv interface{}, ctx context.Context, dec func(interface{}) error, interceptor grpc.UnaryServerInterceptor) (interface{}, error) { + in := new(ListNetworksRequest) if err := dec(in); err != nil { return nil, err } if interceptor == nil { - return srv.(DaemonServiceServer).ListRoutes(ctx, in) + return srv.(DaemonServiceServer).ListNetworks(ctx, in) } info := &grpc.UnaryServerInfo{ Server: srv, - FullMethod: "/daemon.DaemonService/ListRoutes", + FullMethod: 
"/daemon.DaemonService/ListNetworks", } handler := func(ctx context.Context, req interface{}) (interface{}, error) { - return srv.(DaemonServiceServer).ListRoutes(ctx, req.(*ListRoutesRequest)) + return srv.(DaemonServiceServer).ListNetworks(ctx, req.(*ListNetworksRequest)) } return interceptor(ctx, in, info, handler) } -func _DaemonService_SelectRoutes_Handler(srv interface{}, ctx context.Context, dec func(interface{}) error, interceptor grpc.UnaryServerInterceptor) (interface{}, error) { - in := new(SelectRoutesRequest) +func _DaemonService_SelectNetworks_Handler(srv interface{}, ctx context.Context, dec func(interface{}) error, interceptor grpc.UnaryServerInterceptor) (interface{}, error) { + in := new(SelectNetworksRequest) if err := dec(in); err != nil { return nil, err } if interceptor == nil { - return srv.(DaemonServiceServer).SelectRoutes(ctx, in) + return srv.(DaemonServiceServer).SelectNetworks(ctx, in) } info := &grpc.UnaryServerInfo{ Server: srv, - FullMethod: "/daemon.DaemonService/SelectRoutes", + FullMethod: "/daemon.DaemonService/SelectNetworks", } handler := func(ctx context.Context, req interface{}) (interface{}, error) { - return srv.(DaemonServiceServer).SelectRoutes(ctx, req.(*SelectRoutesRequest)) + return srv.(DaemonServiceServer).SelectNetworks(ctx, req.(*SelectNetworksRequest)) } return interceptor(ctx, in, info, handler) } -func _DaemonService_DeselectRoutes_Handler(srv interface{}, ctx context.Context, dec func(interface{}) error, interceptor grpc.UnaryServerInterceptor) (interface{}, error) { - in := new(SelectRoutesRequest) +func _DaemonService_DeselectNetworks_Handler(srv interface{}, ctx context.Context, dec func(interface{}) error, interceptor grpc.UnaryServerInterceptor) (interface{}, error) { + in := new(SelectNetworksRequest) if err := dec(in); err != nil { return nil, err } if interceptor == nil { - return srv.(DaemonServiceServer).DeselectRoutes(ctx, in) + return srv.(DaemonServiceServer).DeselectNetworks(ctx, in) } info := &grpc.UnaryServerInfo{ Server: srv, - FullMethod: "/daemon.DaemonService/DeselectRoutes", + FullMethod: "/daemon.DaemonService/DeselectNetworks", } handler := func(ctx context.Context, req interface{}) (interface{}, error) { - return srv.(DaemonServiceServer).DeselectRoutes(ctx, req.(*SelectRoutesRequest)) + return srv.(DaemonServiceServer).DeselectNetworks(ctx, req.(*SelectNetworksRequest)) } return interceptor(ctx, in, info, handler) } @@ -630,16 +630,16 @@ var DaemonService_ServiceDesc = grpc.ServiceDesc{ Handler: _DaemonService_GetConfig_Handler, }, { - MethodName: "ListRoutes", - Handler: _DaemonService_ListRoutes_Handler, + MethodName: "ListNetworks", + Handler: _DaemonService_ListNetworks_Handler, }, { - MethodName: "SelectRoutes", - Handler: _DaemonService_SelectRoutes_Handler, + MethodName: "SelectNetworks", + Handler: _DaemonService_SelectNetworks_Handler, }, { - MethodName: "DeselectRoutes", - Handler: _DaemonService_DeselectRoutes_Handler, + MethodName: "DeselectNetworks", + Handler: _DaemonService_DeselectNetworks_Handler, }, { MethodName: "DebugBundle", diff --git a/client/server/route.go b/client/server/network.go similarity index 58% rename from client/server/route.go rename to client/server/network.go index d70e0dca3..aaf361524 100644 --- a/client/server/route.go +++ b/client/server/network.go @@ -4,6 +4,7 @@ import ( "context" "fmt" "net/netip" + "slices" "sort" "golang.org/x/exp/maps" @@ -20,8 +21,8 @@ type selectRoute struct { Selected bool } -// ListRoutes returns a list of all available routes. 
-func (s *Server) ListRoutes(context.Context, *proto.ListRoutesRequest) (*proto.ListRoutesResponse, error) { +// ListNetworks returns a list of all available networks. +func (s *Server) ListNetworks(context.Context, *proto.ListNetworksRequest) (*proto.ListNetworksResponse, error) { s.mutex.Lock() defer s.mutex.Unlock() @@ -34,7 +35,7 @@ func (s *Server) ListRoutes(context.Context, *proto.ListRoutesRequest) (*proto.L return nil, fmt.Errorf("not connected") } - routesMap := engine.GetClientRoutesWithNetID() + routesMap := engine.GetRouteManager().GetClientRoutesWithNetID() routeSelector := engine.GetRouteManager().GetRouteSelector() var routes []*selectRoute @@ -67,37 +68,47 @@ func (s *Server) ListRoutes(context.Context, *proto.ListRoutesRequest) (*proto.L }) resolvedDomains := s.statusRecorder.GetResolvedDomainsStates() - var pbRoutes []*proto.Route + var pbRoutes []*proto.Network for _, route := range routes { - pbRoute := &proto.Route{ + pbRoute := &proto.Network{ ID: string(route.NetID), - Network: route.Network.String(), + Range: route.Network.String(), Domains: route.Domains.ToSafeStringList(), ResolvedIPs: map[string]*proto.IPList{}, Selected: route.Selected, } - for _, domain := range route.Domains { - if prefixes, exists := resolvedDomains[domain]; exists { - var ipStrings []string - for _, prefix := range prefixes { - ipStrings = append(ipStrings, prefix.Addr().String()) - } - pbRoute.ResolvedIPs[string(domain)] = &proto.IPList{ - Ips: ipStrings, + // Group resolved IPs by their parent domain + domainMap := map[domain.Domain][]string{} + + for resolvedDomain, info := range resolvedDomains { + // Check if this resolved domain's parent is in our route's domains + if slices.Contains(route.Domains, info.ParentDomain) { + ips := make([]string, 0, len(info.Prefixes)) + for _, prefix := range info.Prefixes { + ips = append(ips, prefix.Addr().String()) } + domainMap[resolvedDomain] = ips } } + + // Convert to proto format + for domain, ips := range domainMap { + pbRoute.ResolvedIPs[string(domain)] = &proto.IPList{ + Ips: ips, + } + } + pbRoutes = append(pbRoutes, pbRoute) } - return &proto.ListRoutesResponse{ + return &proto.ListNetworksResponse{ Routes: pbRoutes, }, nil } -// SelectRoutes selects specific routes based on the client request. -func (s *Server) SelectRoutes(_ context.Context, req *proto.SelectRoutesRequest) (*proto.SelectRoutesResponse, error) { +// SelectNetworks selects specific networks based on the client request. +func (s *Server) SelectNetworks(_ context.Context, req *proto.SelectNetworksRequest) (*proto.SelectNetworksResponse, error) { s.mutex.Lock() defer s.mutex.Unlock() @@ -115,18 +126,19 @@ func (s *Server) SelectRoutes(_ context.Context, req *proto.SelectRoutesRequest) if req.GetAll() { routeSelector.SelectAllRoutes() } else { - routes := toNetIDs(req.GetRouteIDs()) - if err := routeSelector.SelectRoutes(routes, req.GetAppend(), maps.Keys(engine.GetClientRoutesWithNetID())); err != nil { + routes := toNetIDs(req.GetNetworkIDs()) + netIdRoutes := maps.Keys(routeManager.GetClientRoutesWithNetID()) + if err := routeSelector.SelectRoutes(routes, req.GetAppend(), netIdRoutes); err != nil { return nil, fmt.Errorf("select routes: %w", err) } } - routeManager.TriggerSelection(engine.GetClientRoutes()) + routeManager.TriggerSelection(routeManager.GetClientRoutes()) - return &proto.SelectRoutesResponse{}, nil + return &proto.SelectNetworksResponse{}, nil } -// DeselectRoutes deselects specific routes based on the client request. 
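The new ListNetworks handler above groups resolved IPs under the route domain they were resolved for (matching info.ParentDomain against the route's configured domains) rather than looking up each configured domain directly. A minimal, self-contained sketch of that grouping, with resolvedInfo and groupByParent as hypothetical stand-ins for the NetBird types; slices.Contains and netip.Prefix.Addr mirror the calls used in the change:

package main

import (
	"fmt"
	"net/netip"
	"slices"
)

// resolvedInfo is a stand-in for the status recorder's resolved-domain entry.
type resolvedInfo struct {
	ParentDomain string
	Prefixes     []netip.Prefix
}

// groupByParent keeps only resolved domains whose parent is one of the
// route's configured domains and flattens their prefixes to IP strings.
func groupByParent(routeDomains []string, resolved map[string]resolvedInfo) map[string][]string {
	out := map[string][]string{}
	for name, info := range resolved {
		if !slices.Contains(routeDomains, info.ParentDomain) {
			continue
		}
		ips := make([]string, 0, len(info.Prefixes))
		for _, p := range info.Prefixes {
			ips = append(ips, p.Addr().String())
		}
		out[name] = ips
	}
	return out
}

func main() {
	resolved := map[string]resolvedInfo{
		"api.example.com": {ParentDomain: "example.com", Prefixes: []netip.Prefix{netip.MustParsePrefix("198.51.100.10/32")}},
		"other.test":      {ParentDomain: "other.test", Prefixes: []netip.Prefix{netip.MustParsePrefix("203.0.113.5/32")}},
	}
	// Only api.example.com is kept, because its parent domain is configured on the route.
	fmt.Println(groupByParent([]string{"example.com"}, resolved))
}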
-func (s *Server) DeselectRoutes(_ context.Context, req *proto.SelectRoutesRequest) (*proto.SelectRoutesResponse, error) { +// DeselectNetworks deselects specific networks based on the client request. +func (s *Server) DeselectNetworks(_ context.Context, req *proto.SelectNetworksRequest) (*proto.SelectNetworksResponse, error) { s.mutex.Lock() defer s.mutex.Unlock() @@ -144,14 +156,15 @@ func (s *Server) DeselectRoutes(_ context.Context, req *proto.SelectRoutesReques if req.GetAll() { routeSelector.DeselectAllRoutes() } else { - routes := toNetIDs(req.GetRouteIDs()) - if err := routeSelector.DeselectRoutes(routes, maps.Keys(engine.GetClientRoutesWithNetID())); err != nil { + routes := toNetIDs(req.GetNetworkIDs()) + netIdRoutes := maps.Keys(routeManager.GetClientRoutesWithNetID()) + if err := routeSelector.DeselectRoutes(routes, netIdRoutes); err != nil { return nil, fmt.Errorf("deselect routes: %w", err) } } - routeManager.TriggerSelection(engine.GetClientRoutes()) + routeManager.TriggerSelection(routeManager.GetClientRoutes()) - return &proto.SelectRoutesResponse{}, nil + return &proto.SelectNetworksResponse{}, nil } func toNetIDs(routes []string) []route.NetID { diff --git a/client/server/server.go b/client/server/server.go index 71eb58a66..5640ffa39 100644 --- a/client/server/server.go +++ b/client/server/server.go @@ -772,7 +772,7 @@ func toProtoFullStatus(fullStatus peer.FullStatus) *proto.FullStatus { pbFullStatus.LocalPeerState.Fqdn = fullStatus.LocalPeerState.FQDN pbFullStatus.LocalPeerState.RosenpassPermissive = fullStatus.RosenpassState.Permissive pbFullStatus.LocalPeerState.RosenpassEnabled = fullStatus.RosenpassState.Enabled - pbFullStatus.LocalPeerState.Routes = maps.Keys(fullStatus.LocalPeerState.Routes) + pbFullStatus.LocalPeerState.Networks = maps.Keys(fullStatus.LocalPeerState.Routes) for _, peerState := range fullStatus.Peers { pbPeerState := &proto.PeerState{ @@ -791,7 +791,7 @@ func toProtoFullStatus(fullStatus peer.FullStatus) *proto.FullStatus { BytesRx: peerState.BytesRx, BytesTx: peerState.BytesTx, RosenpassEnabled: peerState.RosenpassEnabled, - Routes: maps.Keys(peerState.GetRoutes()), + Networks: maps.Keys(peerState.GetRoutes()), Latency: durationpb.New(peerState.Latency), } pbFullStatus.Peers = append(pbFullStatus.Peers, pbPeerState) diff --git a/client/server/server_test.go b/client/server/server_test.go index 61bdaf660..128de8e02 100644 --- a/client/server/server_test.go +++ b/client/server/server_test.go @@ -20,6 +20,8 @@ import ( mgmtProto "github.com/netbirdio/netbird/management/proto" "github.com/netbirdio/netbird/management/server" "github.com/netbirdio/netbird/management/server/activity" + "github.com/netbirdio/netbird/management/server/settings" + "github.com/netbirdio/netbird/management/server/store" "github.com/netbirdio/netbird/management/server/telemetry" "github.com/netbirdio/netbird/signal/proto" signalServer "github.com/netbirdio/netbird/signal/server" @@ -110,7 +112,7 @@ func startManagement(t *testing.T, signalAddr string, counter *int) (*grpc.Serve return nil, "", err } s := grpc.NewServer(grpc.KeepaliveEnforcementPolicy(kaep), grpc.KeepaliveParams(kasp)) - store, cleanUp, err := server.NewTestStoreFromSQL(context.Background(), "", config.Datadir) + store, cleanUp, err := store.NewTestStoreFromSQL(context.Background(), "", config.Datadir) if err != nil { return nil, "", err } @@ -132,7 +134,7 @@ func startManagement(t *testing.T, signalAddr string, counter *int) (*grpc.Serve } secretsManager := 
server.NewTimeBasedAuthSecretsManager(peersUpdateManager, config.TURNConfig, config.Relay) - mgmtServer, err := server.NewServer(context.Background(), config, accountManager, peersUpdateManager, secretsManager, nil, nil) + mgmtServer, err := server.NewServer(context.Background(), config, accountManager, settings.NewManager(store), peersUpdateManager, secretsManager, nil, nil) if err != nil { return nil, "", err } diff --git a/client/ui/client_ui.go b/client/ui/client_ui.go index 8ca0db73f..49b0f53cf 100644 --- a/client/ui/client_ui.go +++ b/client/ui/client_ui.go @@ -58,7 +58,7 @@ func main() { var showSettings bool flag.BoolVar(&showSettings, "settings", false, "run settings windows") var showRoutes bool - flag.BoolVar(&showRoutes, "routes", false, "run routes windows") + flag.BoolVar(&showRoutes, "networks", false, "run networks windows") var errorMSG string flag.StringVar(&errorMSG, "error-msg", "", "displays a error message window") @@ -233,7 +233,7 @@ func newServiceClient(addr string, a fyne.App, showSettings bool, showRoutes boo s.showSettingsUI() return s } else if showRoutes { - s.showRoutesUI() + s.showNetworksUI() } return s @@ -549,7 +549,7 @@ func (s *serviceClient) onTrayReady() { s.mAdvancedSettings = s.mSettings.AddSubMenuItem("Advanced Settings", "Advanced settings of the application") s.loadSettings() - s.mRoutes = systray.AddMenuItem("Network Routes", "Open the routes management window") + s.mRoutes = systray.AddMenuItem("Networks", "Open the networks management window") s.mRoutes.Disable() systray.AddSeparator() @@ -657,7 +657,7 @@ func (s *serviceClient) onTrayReady() { s.mRoutes.Disable() go func() { defer s.mRoutes.Enable() - s.runSelfCommand("routes", "true") + s.runSelfCommand("networks", "true") }() } if err != nil { diff --git a/client/ui/route.go b/client/ui/network.go similarity index 56% rename from client/ui/route.go rename to client/ui/network.go index 5b6b8fee0..e6f027f0e 100644 --- a/client/ui/route.go +++ b/client/ui/network.go @@ -19,32 +19,32 @@ import ( ) const ( - allRoutesText = "All routes" - overlappingRoutesText = "Overlapping routes" - exitNodeRoutesText = "Exit-node routes" - allRoutes filter = "all" - overlappingRoutes filter = "overlapping" - exitNodeRoutes filter = "exit-node" - getClientFMT = "get client: %v" + allNetworksText = "All networks" + overlappingNetworksText = "Overlapping networks" + exitNodeNetworksText = "Exit-node networks" + allNetworks filter = "all" + overlappingNetworks filter = "overlapping" + exitNodeNetworks filter = "exit-node" + getClientFMT = "get client: %v" ) type filter string -func (s *serviceClient) showRoutesUI() { - s.wRoutes = s.app.NewWindow("NetBird Routes") +func (s *serviceClient) showNetworksUI() { + s.wRoutes = s.app.NewWindow("Networks") allGrid := container.New(layout.NewGridLayout(3)) - go s.updateRoutes(allGrid, allRoutes) + go s.updateNetworks(allGrid, allNetworks) overlappingGrid := container.New(layout.NewGridLayout(3)) exitNodeGrid := container.New(layout.NewGridLayout(3)) routeCheckContainer := container.NewVBox() tabs := container.NewAppTabs( - container.NewTabItem(allRoutesText, allGrid), - container.NewTabItem(overlappingRoutesText, overlappingGrid), - container.NewTabItem(exitNodeRoutesText, exitNodeGrid), + container.NewTabItem(allNetworksText, allGrid), + container.NewTabItem(overlappingNetworksText, overlappingGrid), + container.NewTabItem(exitNodeNetworksText, exitNodeGrid), ) tabs.OnSelected = func(item *container.TabItem) { - s.updateRoutesBasedOnDisplayTab(tabs, allGrid, 
overlappingGrid, exitNodeGrid) + s.updateNetworksBasedOnDisplayTab(tabs, allGrid, overlappingGrid, exitNodeGrid) } tabs.OnUnselected = func(item *container.TabItem) { grid, _ := getGridAndFilterFromTab(tabs, allGrid, overlappingGrid, exitNodeGrid) @@ -58,17 +58,17 @@ func (s *serviceClient) showRoutesUI() { buttonBox := container.NewHBox( layout.NewSpacer(), widget.NewButton("Refresh", func() { - s.updateRoutesBasedOnDisplayTab(tabs, allGrid, overlappingGrid, exitNodeGrid) + s.updateNetworksBasedOnDisplayTab(tabs, allGrid, overlappingGrid, exitNodeGrid) }), widget.NewButton("Select all", func() { _, f := getGridAndFilterFromTab(tabs, allGrid, overlappingGrid, exitNodeGrid) - s.selectAllFilteredRoutes(f) - s.updateRoutesBasedOnDisplayTab(tabs, allGrid, overlappingGrid, exitNodeGrid) + s.selectAllFilteredNetworks(f) + s.updateNetworksBasedOnDisplayTab(tabs, allGrid, overlappingGrid, exitNodeGrid) }), widget.NewButton("Deselect All", func() { _, f := getGridAndFilterFromTab(tabs, allGrid, overlappingGrid, exitNodeGrid) - s.deselectAllFilteredRoutes(f) - s.updateRoutesBasedOnDisplayTab(tabs, allGrid, overlappingGrid, exitNodeGrid) + s.deselectAllFilteredNetworks(f) + s.updateNetworksBasedOnDisplayTab(tabs, allGrid, overlappingGrid, exitNodeGrid) }), layout.NewSpacer(), ) @@ -81,36 +81,36 @@ func (s *serviceClient) showRoutesUI() { s.startAutoRefresh(10*time.Second, tabs, allGrid, overlappingGrid, exitNodeGrid) } -func (s *serviceClient) updateRoutes(grid *fyne.Container, f filter) { +func (s *serviceClient) updateNetworks(grid *fyne.Container, f filter) { grid.Objects = nil grid.Refresh() idHeader := widget.NewLabelWithStyle(" ID", fyne.TextAlignLeading, fyne.TextStyle{Bold: true}) - networkHeader := widget.NewLabelWithStyle("Network/Domains", fyne.TextAlignLeading, fyne.TextStyle{Bold: true}) + networkHeader := widget.NewLabelWithStyle("Range/Domains", fyne.TextAlignLeading, fyne.TextStyle{Bold: true}) resolvedIPsHeader := widget.NewLabelWithStyle("Resolved IPs", fyne.TextAlignLeading, fyne.TextStyle{Bold: true}) grid.Add(idHeader) grid.Add(networkHeader) grid.Add(resolvedIPsHeader) - filteredRoutes, err := s.getFilteredRoutes(f) + filteredRoutes, err := s.getFilteredNetworks(f) if err != nil { return } - sortRoutesByIDs(filteredRoutes) + sortNetworksByIDs(filteredRoutes) for _, route := range filteredRoutes { r := route checkBox := widget.NewCheck(r.GetID(), func(checked bool) { - s.selectRoute(r.ID, checked) + s.selectNetwork(r.ID, checked) }) checkBox.Checked = route.Selected checkBox.Resize(fyne.NewSize(20, 20)) checkBox.Refresh() grid.Add(checkBox) - network := r.GetNetwork() + network := r.GetRange() domains := r.GetDomains() if len(domains) == 0 { @@ -129,10 +129,8 @@ func (s *serviceClient) updateRoutes(grid *fyne.Container, f filter) { grid.Add(domainsSelector) var resolvedIPsList []string - for _, domain := range domains { - if ipList, exists := r.GetResolvedIPs()[domain]; exists { - resolvedIPsList = append(resolvedIPsList, fmt.Sprintf("%s: %s", domain, strings.Join(ipList.GetIps(), ", "))) - } + for domain, ipList := range r.GetResolvedIPs() { + resolvedIPsList = append(resolvedIPsList, fmt.Sprintf("%s: %s", domain, strings.Join(ipList.GetIps(), ", "))) } if len(resolvedIPsList) == 0 { @@ -151,35 +149,35 @@ func (s *serviceClient) updateRoutes(grid *fyne.Container, f filter) { grid.Refresh() } -func (s *serviceClient) getFilteredRoutes(f filter) ([]*proto.Route, error) { - routes, err := s.fetchRoutes() +func (s *serviceClient) getFilteredNetworks(f filter) ([]*proto.Network, 
error) { + routes, err := s.fetchNetworks() if err != nil { log.Errorf(getClientFMT, err) s.showError(fmt.Errorf(getClientFMT, err)) return nil, err } switch f { - case overlappingRoutes: - return getOverlappingRoutes(routes), nil - case exitNodeRoutes: - return getExitNodeRoutes(routes), nil + case overlappingNetworks: + return getOverlappingNetworks(routes), nil + case exitNodeNetworks: + return getExitNodeNetworks(routes), nil default: } return routes, nil } -func getOverlappingRoutes(routes []*proto.Route) []*proto.Route { - var filteredRoutes []*proto.Route - existingRange := make(map[string][]*proto.Route) +func getOverlappingNetworks(routes []*proto.Network) []*proto.Network { + var filteredRoutes []*proto.Network + existingRange := make(map[string][]*proto.Network) for _, route := range routes { if len(route.Domains) > 0 { continue } - if r, exists := existingRange[route.GetNetwork()]; exists { + if r, exists := existingRange[route.GetRange()]; exists { r = append(r, route) - existingRange[route.GetNetwork()] = r + existingRange[route.GetRange()] = r } else { - existingRange[route.GetNetwork()] = []*proto.Route{route} + existingRange[route.GetRange()] = []*proto.Network{route} } } for _, r := range existingRange { @@ -190,29 +188,29 @@ func getOverlappingRoutes(routes []*proto.Route) []*proto.Route { return filteredRoutes } -func getExitNodeRoutes(routes []*proto.Route) []*proto.Route { - var filteredRoutes []*proto.Route +func getExitNodeNetworks(routes []*proto.Network) []*proto.Network { + var filteredRoutes []*proto.Network for _, route := range routes { - if route.Network == "0.0.0.0/0" { + if route.Range == "0.0.0.0/0" { filteredRoutes = append(filteredRoutes, route) } } return filteredRoutes } -func sortRoutesByIDs(routes []*proto.Route) { +func sortNetworksByIDs(routes []*proto.Network) { sort.Slice(routes, func(i, j int) bool { return strings.ToLower(routes[i].GetID()) < strings.ToLower(routes[j].GetID()) }) } -func (s *serviceClient) fetchRoutes() ([]*proto.Route, error) { +func (s *serviceClient) fetchNetworks() ([]*proto.Network, error) { conn, err := s.getSrvClient(defaultFailTimeout) if err != nil { return nil, fmt.Errorf(getClientFMT, err) } - resp, err := conn.ListRoutes(s.ctx, &proto.ListRoutesRequest{}) + resp, err := conn.ListNetworks(s.ctx, &proto.ListNetworksRequest{}) if err != nil { return nil, fmt.Errorf("failed to list routes: %v", err) } @@ -220,7 +218,7 @@ func (s *serviceClient) fetchRoutes() ([]*proto.Route, error) { return resp.Routes, nil } -func (s *serviceClient) selectRoute(id string, checked bool) { +func (s *serviceClient) selectNetwork(id string, checked bool) { conn, err := s.getSrvClient(defaultFailTimeout) if err != nil { log.Errorf(getClientFMT, err) @@ -228,73 +226,73 @@ func (s *serviceClient) selectRoute(id string, checked bool) { return } - req := &proto.SelectRoutesRequest{ - RouteIDs: []string{id}, - Append: checked, + req := &proto.SelectNetworksRequest{ + NetworkIDs: []string{id}, + Append: checked, } if checked { - if _, err := conn.SelectRoutes(s.ctx, req); err != nil { - log.Errorf("failed to select route: %v", err) - s.showError(fmt.Errorf("failed to select route: %v", err)) + if _, err := conn.SelectNetworks(s.ctx, req); err != nil { + log.Errorf("failed to select network: %v", err) + s.showError(fmt.Errorf("failed to select network: %v", err)) return } log.Infof("Route %s selected", id) } else { - if _, err := conn.DeselectRoutes(s.ctx, req); err != nil { - log.Errorf("failed to deselect route: %v", err) - 
s.showError(fmt.Errorf("failed to deselect route: %v", err)) + if _, err := conn.DeselectNetworks(s.ctx, req); err != nil { + log.Errorf("failed to deselect network: %v", err) + s.showError(fmt.Errorf("failed to deselect network: %v", err)) return } - log.Infof("Route %s deselected", id) + log.Infof("Network %s deselected", id) } } -func (s *serviceClient) selectAllFilteredRoutes(f filter) { +func (s *serviceClient) selectAllFilteredNetworks(f filter) { conn, err := s.getSrvClient(defaultFailTimeout) if err != nil { log.Errorf(getClientFMT, err) return } - req := s.getRoutesRequest(f, true) - if _, err := conn.SelectRoutes(s.ctx, req); err != nil { - log.Errorf("failed to select all routes: %v", err) - s.showError(fmt.Errorf("failed to select all routes: %v", err)) + req := s.getNetworksRequest(f, true) + if _, err := conn.SelectNetworks(s.ctx, req); err != nil { + log.Errorf("failed to select all networks: %v", err) + s.showError(fmt.Errorf("failed to select all networks: %v", err)) return } - log.Debug("All routes selected") + log.Debug("All networks selected") } -func (s *serviceClient) deselectAllFilteredRoutes(f filter) { +func (s *serviceClient) deselectAllFilteredNetworks(f filter) { conn, err := s.getSrvClient(defaultFailTimeout) if err != nil { log.Errorf(getClientFMT, err) return } - req := s.getRoutesRequest(f, false) - if _, err := conn.DeselectRoutes(s.ctx, req); err != nil { - log.Errorf("failed to deselect all routes: %v", err) - s.showError(fmt.Errorf("failed to deselect all routes: %v", err)) + req := s.getNetworksRequest(f, false) + if _, err := conn.DeselectNetworks(s.ctx, req); err != nil { + log.Errorf("failed to deselect all networks: %v", err) + s.showError(fmt.Errorf("failed to deselect all networks: %v", err)) return } - log.Debug("All routes deselected") + log.Debug("All networks deselected") } -func (s *serviceClient) getRoutesRequest(f filter, appendRoute bool) *proto.SelectRoutesRequest { - req := &proto.SelectRoutesRequest{} - if f == allRoutes { +func (s *serviceClient) getNetworksRequest(f filter, appendRoute bool) *proto.SelectNetworksRequest { + req := &proto.SelectNetworksRequest{} + if f == allNetworks { req.All = true } else { - routes, err := s.getFilteredRoutes(f) + routes, err := s.getFilteredNetworks(f) if err != nil { return nil } for _, route := range routes { - req.RouteIDs = append(req.RouteIDs, route.GetID()) + req.NetworkIDs = append(req.NetworkIDs, route.GetID()) } req.Append = appendRoute } @@ -311,7 +309,7 @@ func (s *serviceClient) startAutoRefresh(interval time.Duration, tabs *container ticker := time.NewTicker(interval) go func() { for range ticker.C { - s.updateRoutesBasedOnDisplayTab(tabs, allGrid, overlappingGrid, exitNodesGrid) + s.updateNetworksBasedOnDisplayTab(tabs, allGrid, overlappingGrid, exitNodesGrid) } }() @@ -320,20 +318,20 @@ func (s *serviceClient) startAutoRefresh(interval time.Duration, tabs *container }) } -func (s *serviceClient) updateRoutesBasedOnDisplayTab(tabs *container.AppTabs, allGrid, overlappingGrid, exitNodesGrid *fyne.Container) { +func (s *serviceClient) updateNetworksBasedOnDisplayTab(tabs *container.AppTabs, allGrid, overlappingGrid, exitNodesGrid *fyne.Container) { grid, f := getGridAndFilterFromTab(tabs, allGrid, overlappingGrid, exitNodesGrid) s.wRoutes.Content().Refresh() - s.updateRoutes(grid, f) + s.updateNetworks(grid, f) } func getGridAndFilterFromTab(tabs *container.AppTabs, allGrid, overlappingGrid, exitNodesGrid *fyne.Container) (*fyne.Container, filter) { switch tabs.Selected().Text { - case 
overlappingRoutesText: - return overlappingGrid, overlappingRoutes - case exitNodeRoutesText: - return exitNodesGrid, exitNodeRoutes + case overlappingNetworksText: + return overlappingGrid, overlappingNetworks + case exitNodeNetworksText: + return exitNodesGrid, exitNodeNetworks default: - return allGrid, allRoutes + return allGrid, allNetworks } } diff --git a/dns/dns.go b/dns/dns.go index 18528c743..8dfdf8526 100644 --- a/dns/dns.go +++ b/dns/dns.go @@ -108,3 +108,9 @@ func GetParsedDomainLabel(name string) (string, error) { return validHost, nil } + +// NormalizeZone returns a normalized domain name without the wildcard prefix +func NormalizeZone(domain string) string { + d, _ := strings.CutPrefix(domain, "*.") + return d +} diff --git a/go.mod b/go.mod index c504925d2..d48280df0 100644 --- a/go.mod +++ b/go.mod @@ -60,7 +60,7 @@ require ( github.com/miekg/dns v1.1.59 github.com/mitchellh/hashstructure/v2 v2.0.2 github.com/nadoo/ipset v0.5.0 - github.com/netbirdio/management-integrations/integrations v0.0.0-20241106153857-de8e2beb5254 + github.com/netbirdio/management-integrations/integrations v0.0.0-20241211172827-ba0a446be480 github.com/netbirdio/signal-dispatcher/dispatcher v0.0.0-20241010133937-e0df50df217d github.com/okta/okta-sdk-golang/v2 v2.18.0 github.com/oschwald/maxminddb-golang v1.12.0 @@ -207,6 +207,7 @@ require ( github.com/spf13/cast v1.5.0 // indirect github.com/srwiley/oksvg v0.0.0-20221011165216-be6e8873101c // indirect github.com/srwiley/rasterx v0.0.0-20220730225603-2ab79fcdd4ef // indirect + github.com/stretchr/objx v0.5.2 // indirect github.com/tklauser/go-sysconf v0.3.14 // indirect github.com/tklauser/numcpus v0.8.0 // indirect github.com/vishvananda/netns v0.0.4 // indirect diff --git a/go.sum b/go.sum index 04d2bc59a..540cbf20b 100644 --- a/go.sum +++ b/go.sum @@ -521,8 +521,8 @@ github.com/netbirdio/go-netroute v0.0.0-20240611143515-f59b0e1d3944 h1:TDtJKmM6S github.com/netbirdio/go-netroute v0.0.0-20240611143515-f59b0e1d3944/go.mod h1:sHA6TRxjQ6RLbnI+3R4DZo2Eseg/iKiPRfNmcuNySVQ= github.com/netbirdio/ice/v3 v3.0.0-20240315174635-e72a50fcb64e h1:PURA50S8u4mF6RrkYYCAvvPCixhqqEiEy3Ej6avh04c= github.com/netbirdio/ice/v3 v3.0.0-20240315174635-e72a50fcb64e/go.mod h1:YMLU7qbKfVjmEv7EoZPIVEI+kNYxWCdPK3VS0BU+U4Q= -github.com/netbirdio/management-integrations/integrations v0.0.0-20241106153857-de8e2beb5254 h1:L8mNd3tBxMdnQNxMNJ+/EiwHwizNOMy8/nHLVGNfjpg= -github.com/netbirdio/management-integrations/integrations v0.0.0-20241106153857-de8e2beb5254/go.mod h1:nykwWZnxb+sJz2Z//CEq45CMRWSHllH8pODKRB8eY7Y= +github.com/netbirdio/management-integrations/integrations v0.0.0-20241211172827-ba0a446be480 h1:M+UPn/o+plVE7ZehgL6/1dftptsO1tyTPssgImgi+28= +github.com/netbirdio/management-integrations/integrations v0.0.0-20241211172827-ba0a446be480/go.mod h1:RC0PnyATSBPrRWKQgb+7KcC1tMta9eYyzuA414RG9wQ= github.com/netbirdio/service v0.0.0-20240911161631-f62744f42502 h1:3tHlFmhTdX9axERMVN63dqyFqnvuD+EMJHzM7mNGON8= github.com/netbirdio/service v0.0.0-20240911161631-f62744f42502/go.mod h1:CIMRFEJVL+0DS1a3Nx06NaMn4Dz63Ng6O7dl0qH0zVM= github.com/netbirdio/signal-dispatcher/dispatcher v0.0.0-20241010133937-e0df50df217d h1:bRq5TKgC7Iq20pDiuC54yXaWnAVeS5PdGpSokFTlR28= @@ -662,6 +662,7 @@ github.com/stretchr/objx v0.1.0/go.mod h1:HFkY916IF+rwdDfMAkV7OtwuqBVzrE8GR6GFx+ github.com/stretchr/objx v0.1.1/go.mod h1:HFkY916IF+rwdDfMAkV7OtwuqBVzrE8GR6GFx+wExME= github.com/stretchr/objx v0.4.0/go.mod h1:YvHI0jy2hoMjB+UWwv71VJQ9isScKT/TqJzVSSt89Yw= github.com/stretchr/objx v0.5.0/go.mod 
h1:Yh+to48EsGEfYuaHDzXPcE3xhTkx73EhmCGUpEOglKo= +github.com/stretchr/objx v0.5.2 h1:xuMeJ0Sdp5ZMRXx/aWO6RZxdr3beISkG5/G/aIRr3pY= github.com/stretchr/objx v0.5.2/go.mod h1:FRsXN1f5AsAjCGJKqEizvkpNtU+EGNCLh3NxZ/8L+MA= github.com/stretchr/testify v0.0.0-20151208002404-e3a8ff8ce365/go.mod h1:a8OnRcib4nhh0OaRAV+Yts87kKdq0PP7pXfy6kDkUVs= github.com/stretchr/testify v1.2.2/go.mod h1:a8OnRcib4nhh0OaRAV+Yts87kKdq0PP7pXfy6kDkUVs= diff --git a/management/client/client_test.go b/management/client/client_test.go index 100b3fcaa..8bd8af8d2 100644 --- a/management/client/client_test.go +++ b/management/client/client_test.go @@ -11,6 +11,8 @@ import ( "github.com/stretchr/testify/require" "github.com/netbirdio/netbird/management/server/activity" + "github.com/netbirdio/netbird/management/server/settings" + "github.com/netbirdio/netbird/management/server/store" "github.com/netbirdio/netbird/management/server/telemetry" "github.com/netbirdio/netbird/client/system" @@ -57,7 +59,7 @@ func startManagement(t *testing.T) (*grpc.Server, net.Listener) { t.Fatal(err) } s := grpc.NewServer() - store, cleanUp, err := mgmt.NewTestStoreFromSQL(context.Background(), "../server/testdata/store.sql", t.TempDir()) + store, cleanUp, err := store.NewTestStoreFromSQL(context.Background(), "../server/testdata/store.sql", t.TempDir()) if err != nil { t.Fatal(err) } @@ -76,7 +78,7 @@ func startManagement(t *testing.T) (*grpc.Server, net.Listener) { } secretsManager := mgmt.NewTimeBasedAuthSecretsManager(peersUpdateManager, config.TURNConfig, config.Relay) - mgmtServer, err := mgmt.NewServer(context.Background(), config, accountManager, peersUpdateManager, secretsManager, nil, nil) + mgmtServer, err := mgmt.NewServer(context.Background(), config, accountManager, settings.NewManager(store), peersUpdateManager, secretsManager, nil, nil) if err != nil { t.Fatal(err) } diff --git a/management/cmd/management.go b/management/cmd/management.go index bfa158c5b..4f34009b7 100644 --- a/management/cmd/management.go +++ b/management/cmd/management.go @@ -41,12 +41,20 @@ import ( "github.com/netbirdio/netbird/management/server" nbContext "github.com/netbirdio/netbird/management/server/context" "github.com/netbirdio/netbird/management/server/geolocation" + "github.com/netbirdio/netbird/management/server/groups" httpapi "github.com/netbirdio/netbird/management/server/http" "github.com/netbirdio/netbird/management/server/http/configs" "github.com/netbirdio/netbird/management/server/idp" "github.com/netbirdio/netbird/management/server/jwtclaims" "github.com/netbirdio/netbird/management/server/metrics" + "github.com/netbirdio/netbird/management/server/networks" + "github.com/netbirdio/netbird/management/server/networks/resources" + "github.com/netbirdio/netbird/management/server/networks/routers" + "github.com/netbirdio/netbird/management/server/permissions" + "github.com/netbirdio/netbird/management/server/settings" + "github.com/netbirdio/netbird/management/server/store" "github.com/netbirdio/netbird/management/server/telemetry" + "github.com/netbirdio/netbird/management/server/users" "github.com/netbirdio/netbird/util" "github.com/netbirdio/netbird/version" ) @@ -150,7 +158,7 @@ var ( if err != nil { return err } - store, err := server.NewStore(ctx, config.StoreConfig.Engine, config.Datadir, appMetrics) + store, err := store.NewStore(ctx, config.StoreConfig.Engine, config.Datadir, appMetrics) if err != nil { return fmt.Errorf("failed creating Store: %s: %v", config.Datadir, err) } @@ -265,7 +273,15 @@ var ( KeysLocation: 
config.HttpConfig.AuthKeysLocation, } - httpAPIHandler, err := httpapi.APIHandler(ctx, accountManager, geo, *jwtValidator, appMetrics, httpAPIAuthCfg, integratedPeerValidator) + userManager := users.NewManager(store) + settingsManager := settings.NewManager(store) + permissionsManager := permissions.NewManager(userManager, settingsManager) + groupsManager := groups.NewManager(store, permissionsManager, accountManager) + resourcesManager := resources.NewManager(store, permissionsManager, groupsManager, accountManager) + routersManager := routers.NewManager(store, permissionsManager, accountManager) + networksManager := networks.NewManager(store, permissionsManager, resourcesManager, routersManager, accountManager) + + httpAPIHandler, err := httpapi.APIHandler(ctx, accountManager, networksManager, resourcesManager, routersManager, groupsManager, geo, *jwtValidator, appMetrics, httpAPIAuthCfg, integratedPeerValidator) if err != nil { return fmt.Errorf("failed creating HTTP API handler: %v", err) } @@ -274,7 +290,7 @@ var ( ephemeralManager.LoadInitialPeers(ctx) gRPCAPIHandler := grpc.NewServer(gRPCOpts...) - srv, err := server.NewServer(ctx, config, accountManager, peersUpdateManager, secretsManager, appMetrics, ephemeralManager) + srv, err := server.NewServer(ctx, config, accountManager, settingsManager, peersUpdateManager, secretsManager, appMetrics, ephemeralManager) if err != nil { return fmt.Errorf("failed creating gRPC API handler: %v", err) } @@ -400,7 +416,7 @@ func notifyStop(ctx context.Context, msg string) { } } -func getInstallationID(ctx context.Context, store server.Store) (string, error) { +func getInstallationID(ctx context.Context, store store.Store) (string, error) { installationID := store.GetInstallationID() if installationID != "" { return installationID, nil diff --git a/management/cmd/migration_up.go b/management/cmd/migration_up.go index 7aa11f0c9..183fc554d 100644 --- a/management/cmd/migration_up.go +++ b/management/cmd/migration_up.go @@ -9,7 +9,7 @@ import ( "github.com/spf13/cobra" "github.com/netbirdio/netbird/formatter" - "github.com/netbirdio/netbird/management/server" + "github.com/netbirdio/netbird/management/server/store" "github.com/netbirdio/netbird/util" ) @@ -32,7 +32,7 @@ var upCmd = &cobra.Command{ //nolint ctx := context.WithValue(cmd.Context(), formatter.ExecutionContextKey, formatter.SystemSource) - if err := server.MigrateFileStoreToSqlite(ctx, mgmtDataDir); err != nil { + if err := store.MigrateFileStoreToSqlite(ctx, mgmtDataDir); err != nil { return err } log.WithContext(ctx).Info("Migration finished successfully") diff --git a/management/proto/management.pb.go b/management/proto/management.pb.go index 672b2a102..b4ff16e6d 100644 --- a/management/proto/management.pb.go +++ b/management/proto/management.pb.go @@ -1,7 +1,7 @@ // Code generated by protoc-gen-go. DO NOT EDIT. // versions: // protoc-gen-go v1.26.0 -// protoc v4.23.4 +// protoc v4.24.3 // source: management.proto package proto @@ -29,6 +29,7 @@ const ( RuleProtocol_TCP RuleProtocol = 2 RuleProtocol_UDP RuleProtocol = 3 RuleProtocol_ICMP RuleProtocol = 4 + RuleProtocol_CUSTOM RuleProtocol = 5 ) // Enum value maps for RuleProtocol. @@ -39,6 +40,7 @@ var ( 2: "TCP", 3: "UDP", 4: "ICMP", + 5: "CUSTOM", } RuleProtocol_value = map[string]int32{ "UNKNOWN": 0, @@ -46,6 +48,7 @@ var ( "TCP": 2, "UDP": 3, "ICMP": 4, + "CUSTOM": 5, } ) @@ -1393,7 +1396,8 @@ type PeerConfig struct { // SSHConfig of the peer. 
SshConfig *SSHConfig `protobuf:"bytes,3,opt,name=sshConfig,proto3" json:"sshConfig,omitempty"` // Peer fully qualified domain name - Fqdn string `protobuf:"bytes,4,opt,name=fqdn,proto3" json:"fqdn,omitempty"` + Fqdn string `protobuf:"bytes,4,opt,name=fqdn,proto3" json:"fqdn,omitempty"` + RoutingPeerDnsResolutionEnabled bool `protobuf:"varint,5,opt,name=RoutingPeerDnsResolutionEnabled,proto3" json:"RoutingPeerDnsResolutionEnabled,omitempty"` } func (x *PeerConfig) Reset() { @@ -1456,6 +1460,13 @@ func (x *PeerConfig) GetFqdn() string { return "" } +func (x *PeerConfig) GetRoutingPeerDnsResolutionEnabled() bool { + if x != nil { + return x.RoutingPeerDnsResolutionEnabled + } + return false +} + // NetworkMap represents a network state of the peer with the corresponding configuration parameters to establish peer-to-peer connections type NetworkMap struct { state protoimpl.MessageState @@ -2780,6 +2791,10 @@ type RouteFirewallRule struct { PortInfo *PortInfo `protobuf:"bytes,5,opt,name=portInfo,proto3" json:"portInfo,omitempty"` // IsDynamic indicates if the route is a DNS route. IsDynamic bool `protobuf:"varint,6,opt,name=isDynamic,proto3" json:"isDynamic,omitempty"` + // Domains is a list of domains for which the rule is applicable. + Domains []string `protobuf:"bytes,7,rep,name=domains,proto3" json:"domains,omitempty"` + // CustomProtocol is a custom protocol ID. + CustomProtocol uint32 `protobuf:"varint,8,opt,name=customProtocol,proto3" json:"customProtocol,omitempty"` } func (x *RouteFirewallRule) Reset() { @@ -2856,6 +2871,20 @@ func (x *RouteFirewallRule) GetIsDynamic() bool { return false } +func (x *RouteFirewallRule) GetDomains() []string { + if x != nil { + return x.Domains + } + return nil +} + +func (x *RouteFirewallRule) GetCustomProtocol() uint32 { + if x != nil { + return x.CustomProtocol + } + return 0 +} + type PortInfo_Range struct { state protoimpl.MessageState sizeCache protoimpl.SizeCache @@ -3075,7 +3104,7 @@ var file_management_proto_rawDesc = []byte{ 0x74, 0x43, 0x6f, 0x6e, 0x66, 0x69, 0x67, 0x12, 0x12, 0x0a, 0x04, 0x75, 0x73, 0x65, 0x72, 0x18, 0x02, 0x20, 0x01, 0x28, 0x09, 0x52, 0x04, 0x75, 0x73, 0x65, 0x72, 0x12, 0x1a, 0x0a, 0x08, 0x70, 0x61, 0x73, 0x73, 0x77, 0x6f, 0x72, 0x64, 0x18, 0x03, 0x20, 0x01, 0x28, 0x09, 0x52, 0x08, 0x70, - 0x61, 0x73, 0x73, 0x77, 0x6f, 0x72, 0x64, 0x22, 0x81, 0x01, 0x0a, 0x0a, 0x50, 0x65, 0x65, 0x72, + 0x61, 0x73, 0x73, 0x77, 0x6f, 0x72, 0x64, 0x22, 0xcb, 0x01, 0x0a, 0x0a, 0x50, 0x65, 0x65, 0x72, 0x43, 0x6f, 0x6e, 0x66, 0x69, 0x67, 0x12, 0x18, 0x0a, 0x07, 0x61, 0x64, 0x64, 0x72, 0x65, 0x73, 0x73, 0x18, 0x01, 0x20, 0x01, 0x28, 0x09, 0x52, 0x07, 0x61, 0x64, 0x64, 0x72, 0x65, 0x73, 0x73, 0x12, 0x10, 0x0a, 0x03, 0x64, 0x6e, 0x73, 0x18, 0x02, 0x20, 0x01, 0x28, 0x09, 0x52, 0x03, 0x64, @@ -3083,250 +3112,260 @@ var file_management_proto_rawDesc = []byte{ 0x03, 0x20, 0x01, 0x28, 0x0b, 0x32, 0x15, 0x2e, 0x6d, 0x61, 0x6e, 0x61, 0x67, 0x65, 0x6d, 0x65, 0x6e, 0x74, 0x2e, 0x53, 0x53, 0x48, 0x43, 0x6f, 0x6e, 0x66, 0x69, 0x67, 0x52, 0x09, 0x73, 0x73, 0x68, 0x43, 0x6f, 0x6e, 0x66, 0x69, 0x67, 0x12, 0x12, 0x0a, 0x04, 0x66, 0x71, 0x64, 0x6e, 0x18, - 0x04, 0x20, 0x01, 0x28, 0x09, 0x52, 0x04, 0x66, 0x71, 0x64, 0x6e, 0x22, 0xf3, 0x04, 0x0a, 0x0a, - 0x4e, 0x65, 0x74, 0x77, 0x6f, 0x72, 0x6b, 0x4d, 0x61, 0x70, 0x12, 0x16, 0x0a, 0x06, 0x53, 0x65, - 0x72, 0x69, 0x61, 0x6c, 0x18, 0x01, 0x20, 0x01, 0x28, 0x04, 0x52, 0x06, 0x53, 0x65, 0x72, 0x69, - 0x61, 0x6c, 0x12, 0x36, 0x0a, 0x0a, 0x70, 0x65, 0x65, 0x72, 0x43, 0x6f, 0x6e, 0x66, 0x69, 0x67, - 0x18, 0x02, 0x20, 0x01, 
0x28, 0x0b, 0x32, 0x16, 0x2e, 0x6d, 0x61, 0x6e, 0x61, 0x67, 0x65, 0x6d, - 0x65, 0x6e, 0x74, 0x2e, 0x50, 0x65, 0x65, 0x72, 0x43, 0x6f, 0x6e, 0x66, 0x69, 0x67, 0x52, 0x0a, - 0x70, 0x65, 0x65, 0x72, 0x43, 0x6f, 0x6e, 0x66, 0x69, 0x67, 0x12, 0x3e, 0x0a, 0x0b, 0x72, 0x65, - 0x6d, 0x6f, 0x74, 0x65, 0x50, 0x65, 0x65, 0x72, 0x73, 0x18, 0x03, 0x20, 0x03, 0x28, 0x0b, 0x32, - 0x1c, 0x2e, 0x6d, 0x61, 0x6e, 0x61, 0x67, 0x65, 0x6d, 0x65, 0x6e, 0x74, 0x2e, 0x52, 0x65, 0x6d, - 0x6f, 0x74, 0x65, 0x50, 0x65, 0x65, 0x72, 0x43, 0x6f, 0x6e, 0x66, 0x69, 0x67, 0x52, 0x0b, 0x72, - 0x65, 0x6d, 0x6f, 0x74, 0x65, 0x50, 0x65, 0x65, 0x72, 0x73, 0x12, 0x2e, 0x0a, 0x12, 0x72, 0x65, - 0x6d, 0x6f, 0x74, 0x65, 0x50, 0x65, 0x65, 0x72, 0x73, 0x49, 0x73, 0x45, 0x6d, 0x70, 0x74, 0x79, - 0x18, 0x04, 0x20, 0x01, 0x28, 0x08, 0x52, 0x12, 0x72, 0x65, 0x6d, 0x6f, 0x74, 0x65, 0x50, 0x65, - 0x65, 0x72, 0x73, 0x49, 0x73, 0x45, 0x6d, 0x70, 0x74, 0x79, 0x12, 0x29, 0x0a, 0x06, 0x52, 0x6f, - 0x75, 0x74, 0x65, 0x73, 0x18, 0x05, 0x20, 0x03, 0x28, 0x0b, 0x32, 0x11, 0x2e, 0x6d, 0x61, 0x6e, - 0x61, 0x67, 0x65, 0x6d, 0x65, 0x6e, 0x74, 0x2e, 0x52, 0x6f, 0x75, 0x74, 0x65, 0x52, 0x06, 0x52, - 0x6f, 0x75, 0x74, 0x65, 0x73, 0x12, 0x33, 0x0a, 0x09, 0x44, 0x4e, 0x53, 0x43, 0x6f, 0x6e, 0x66, - 0x69, 0x67, 0x18, 0x06, 0x20, 0x01, 0x28, 0x0b, 0x32, 0x15, 0x2e, 0x6d, 0x61, 0x6e, 0x61, 0x67, - 0x65, 0x6d, 0x65, 0x6e, 0x74, 0x2e, 0x44, 0x4e, 0x53, 0x43, 0x6f, 0x6e, 0x66, 0x69, 0x67, 0x52, - 0x09, 0x44, 0x4e, 0x53, 0x43, 0x6f, 0x6e, 0x66, 0x69, 0x67, 0x12, 0x40, 0x0a, 0x0c, 0x6f, 0x66, - 0x66, 0x6c, 0x69, 0x6e, 0x65, 0x50, 0x65, 0x65, 0x72, 0x73, 0x18, 0x07, 0x20, 0x03, 0x28, 0x0b, - 0x32, 0x1c, 0x2e, 0x6d, 0x61, 0x6e, 0x61, 0x67, 0x65, 0x6d, 0x65, 0x6e, 0x74, 0x2e, 0x52, 0x65, - 0x6d, 0x6f, 0x74, 0x65, 0x50, 0x65, 0x65, 0x72, 0x43, 0x6f, 0x6e, 0x66, 0x69, 0x67, 0x52, 0x0c, - 0x6f, 0x66, 0x66, 0x6c, 0x69, 0x6e, 0x65, 0x50, 0x65, 0x65, 0x72, 0x73, 0x12, 0x3e, 0x0a, 0x0d, - 0x46, 0x69, 0x72, 0x65, 0x77, 0x61, 0x6c, 0x6c, 0x52, 0x75, 0x6c, 0x65, 0x73, 0x18, 0x08, 0x20, - 0x03, 0x28, 0x0b, 0x32, 0x18, 0x2e, 0x6d, 0x61, 0x6e, 0x61, 0x67, 0x65, 0x6d, 0x65, 0x6e, 0x74, - 0x2e, 0x46, 0x69, 0x72, 0x65, 0x77, 0x61, 0x6c, 0x6c, 0x52, 0x75, 0x6c, 0x65, 0x52, 0x0d, 0x46, - 0x69, 0x72, 0x65, 0x77, 0x61, 0x6c, 0x6c, 0x52, 0x75, 0x6c, 0x65, 0x73, 0x12, 0x32, 0x0a, 0x14, - 0x66, 0x69, 0x72, 0x65, 0x77, 0x61, 0x6c, 0x6c, 0x52, 0x75, 0x6c, 0x65, 0x73, 0x49, 0x73, 0x45, - 0x6d, 0x70, 0x74, 0x79, 0x18, 0x09, 0x20, 0x01, 0x28, 0x08, 0x52, 0x14, 0x66, 0x69, 0x72, 0x65, - 0x77, 0x61, 0x6c, 0x6c, 0x52, 0x75, 0x6c, 0x65, 0x73, 0x49, 0x73, 0x45, 0x6d, 0x70, 0x74, 0x79, - 0x12, 0x4f, 0x0a, 0x13, 0x72, 0x6f, 0x75, 0x74, 0x65, 0x73, 0x46, 0x69, 0x72, 0x65, 0x77, 0x61, - 0x6c, 0x6c, 0x52, 0x75, 0x6c, 0x65, 0x73, 0x18, 0x0a, 0x20, 0x03, 0x28, 0x0b, 0x32, 0x1d, 0x2e, - 0x6d, 0x61, 0x6e, 0x61, 0x67, 0x65, 0x6d, 0x65, 0x6e, 0x74, 0x2e, 0x52, 0x6f, 0x75, 0x74, 0x65, - 0x46, 0x69, 0x72, 0x65, 0x77, 0x61, 0x6c, 0x6c, 0x52, 0x75, 0x6c, 0x65, 0x52, 0x13, 0x72, 0x6f, + 0x04, 0x20, 0x01, 0x28, 0x09, 0x52, 0x04, 0x66, 0x71, 0x64, 0x6e, 0x12, 0x48, 0x0a, 0x1f, 0x52, + 0x6f, 0x75, 0x74, 0x69, 0x6e, 0x67, 0x50, 0x65, 0x65, 0x72, 0x44, 0x6e, 0x73, 0x52, 0x65, 0x73, + 0x6f, 0x6c, 0x75, 0x74, 0x69, 0x6f, 0x6e, 0x45, 0x6e, 0x61, 0x62, 0x6c, 0x65, 0x64, 0x18, 0x05, + 0x20, 0x01, 0x28, 0x08, 0x52, 0x1f, 0x52, 0x6f, 0x75, 0x74, 0x69, 0x6e, 0x67, 0x50, 0x65, 0x65, + 0x72, 0x44, 0x6e, 0x73, 0x52, 0x65, 0x73, 0x6f, 0x6c, 0x75, 0x74, 0x69, 0x6f, 0x6e, 0x45, 0x6e, + 0x61, 0x62, 0x6c, 0x65, 0x64, 0x22, 0xf3, 0x04, 
0x0a, 0x0a, 0x4e, 0x65, 0x74, 0x77, 0x6f, 0x72, + 0x6b, 0x4d, 0x61, 0x70, 0x12, 0x16, 0x0a, 0x06, 0x53, 0x65, 0x72, 0x69, 0x61, 0x6c, 0x18, 0x01, + 0x20, 0x01, 0x28, 0x04, 0x52, 0x06, 0x53, 0x65, 0x72, 0x69, 0x61, 0x6c, 0x12, 0x36, 0x0a, 0x0a, + 0x70, 0x65, 0x65, 0x72, 0x43, 0x6f, 0x6e, 0x66, 0x69, 0x67, 0x18, 0x02, 0x20, 0x01, 0x28, 0x0b, + 0x32, 0x16, 0x2e, 0x6d, 0x61, 0x6e, 0x61, 0x67, 0x65, 0x6d, 0x65, 0x6e, 0x74, 0x2e, 0x50, 0x65, + 0x65, 0x72, 0x43, 0x6f, 0x6e, 0x66, 0x69, 0x67, 0x52, 0x0a, 0x70, 0x65, 0x65, 0x72, 0x43, 0x6f, + 0x6e, 0x66, 0x69, 0x67, 0x12, 0x3e, 0x0a, 0x0b, 0x72, 0x65, 0x6d, 0x6f, 0x74, 0x65, 0x50, 0x65, + 0x65, 0x72, 0x73, 0x18, 0x03, 0x20, 0x03, 0x28, 0x0b, 0x32, 0x1c, 0x2e, 0x6d, 0x61, 0x6e, 0x61, + 0x67, 0x65, 0x6d, 0x65, 0x6e, 0x74, 0x2e, 0x52, 0x65, 0x6d, 0x6f, 0x74, 0x65, 0x50, 0x65, 0x65, + 0x72, 0x43, 0x6f, 0x6e, 0x66, 0x69, 0x67, 0x52, 0x0b, 0x72, 0x65, 0x6d, 0x6f, 0x74, 0x65, 0x50, + 0x65, 0x65, 0x72, 0x73, 0x12, 0x2e, 0x0a, 0x12, 0x72, 0x65, 0x6d, 0x6f, 0x74, 0x65, 0x50, 0x65, + 0x65, 0x72, 0x73, 0x49, 0x73, 0x45, 0x6d, 0x70, 0x74, 0x79, 0x18, 0x04, 0x20, 0x01, 0x28, 0x08, + 0x52, 0x12, 0x72, 0x65, 0x6d, 0x6f, 0x74, 0x65, 0x50, 0x65, 0x65, 0x72, 0x73, 0x49, 0x73, 0x45, + 0x6d, 0x70, 0x74, 0x79, 0x12, 0x29, 0x0a, 0x06, 0x52, 0x6f, 0x75, 0x74, 0x65, 0x73, 0x18, 0x05, + 0x20, 0x03, 0x28, 0x0b, 0x32, 0x11, 0x2e, 0x6d, 0x61, 0x6e, 0x61, 0x67, 0x65, 0x6d, 0x65, 0x6e, + 0x74, 0x2e, 0x52, 0x6f, 0x75, 0x74, 0x65, 0x52, 0x06, 0x52, 0x6f, 0x75, 0x74, 0x65, 0x73, 0x12, + 0x33, 0x0a, 0x09, 0x44, 0x4e, 0x53, 0x43, 0x6f, 0x6e, 0x66, 0x69, 0x67, 0x18, 0x06, 0x20, 0x01, + 0x28, 0x0b, 0x32, 0x15, 0x2e, 0x6d, 0x61, 0x6e, 0x61, 0x67, 0x65, 0x6d, 0x65, 0x6e, 0x74, 0x2e, + 0x44, 0x4e, 0x53, 0x43, 0x6f, 0x6e, 0x66, 0x69, 0x67, 0x52, 0x09, 0x44, 0x4e, 0x53, 0x43, 0x6f, + 0x6e, 0x66, 0x69, 0x67, 0x12, 0x40, 0x0a, 0x0c, 0x6f, 0x66, 0x66, 0x6c, 0x69, 0x6e, 0x65, 0x50, + 0x65, 0x65, 0x72, 0x73, 0x18, 0x07, 0x20, 0x03, 0x28, 0x0b, 0x32, 0x1c, 0x2e, 0x6d, 0x61, 0x6e, + 0x61, 0x67, 0x65, 0x6d, 0x65, 0x6e, 0x74, 0x2e, 0x52, 0x65, 0x6d, 0x6f, 0x74, 0x65, 0x50, 0x65, + 0x65, 0x72, 0x43, 0x6f, 0x6e, 0x66, 0x69, 0x67, 0x52, 0x0c, 0x6f, 0x66, 0x66, 0x6c, 0x69, 0x6e, + 0x65, 0x50, 0x65, 0x65, 0x72, 0x73, 0x12, 0x3e, 0x0a, 0x0d, 0x46, 0x69, 0x72, 0x65, 0x77, 0x61, + 0x6c, 0x6c, 0x52, 0x75, 0x6c, 0x65, 0x73, 0x18, 0x08, 0x20, 0x03, 0x28, 0x0b, 0x32, 0x18, 0x2e, + 0x6d, 0x61, 0x6e, 0x61, 0x67, 0x65, 0x6d, 0x65, 0x6e, 0x74, 0x2e, 0x46, 0x69, 0x72, 0x65, 0x77, + 0x61, 0x6c, 0x6c, 0x52, 0x75, 0x6c, 0x65, 0x52, 0x0d, 0x46, 0x69, 0x72, 0x65, 0x77, 0x61, 0x6c, + 0x6c, 0x52, 0x75, 0x6c, 0x65, 0x73, 0x12, 0x32, 0x0a, 0x14, 0x66, 0x69, 0x72, 0x65, 0x77, 0x61, + 0x6c, 0x6c, 0x52, 0x75, 0x6c, 0x65, 0x73, 0x49, 0x73, 0x45, 0x6d, 0x70, 0x74, 0x79, 0x18, 0x09, + 0x20, 0x01, 0x28, 0x08, 0x52, 0x14, 0x66, 0x69, 0x72, 0x65, 0x77, 0x61, 0x6c, 0x6c, 0x52, 0x75, + 0x6c, 0x65, 0x73, 0x49, 0x73, 0x45, 0x6d, 0x70, 0x74, 0x79, 0x12, 0x4f, 0x0a, 0x13, 0x72, 0x6f, 0x75, 0x74, 0x65, 0x73, 0x46, 0x69, 0x72, 0x65, 0x77, 0x61, 0x6c, 0x6c, 0x52, 0x75, 0x6c, 0x65, - 0x73, 0x12, 0x3e, 0x0a, 0x1a, 0x72, 0x6f, 0x75, 0x74, 0x65, 0x73, 0x46, 0x69, 0x72, 0x65, 0x77, - 0x61, 0x6c, 0x6c, 0x52, 0x75, 0x6c, 0x65, 0x73, 0x49, 0x73, 0x45, 0x6d, 0x70, 0x74, 0x79, 0x18, - 0x0b, 0x20, 0x01, 0x28, 0x08, 0x52, 0x1a, 0x72, 0x6f, 0x75, 0x74, 0x65, 0x73, 0x46, 0x69, 0x72, - 0x65, 0x77, 0x61, 0x6c, 0x6c, 0x52, 0x75, 0x6c, 0x65, 0x73, 0x49, 0x73, 0x45, 0x6d, 0x70, 0x74, - 0x79, 0x22, 0x97, 0x01, 0x0a, 0x10, 0x52, 0x65, 0x6d, 0x6f, 0x74, 0x65, 
0x50, 0x65, 0x65, 0x72, - 0x43, 0x6f, 0x6e, 0x66, 0x69, 0x67, 0x12, 0x1a, 0x0a, 0x08, 0x77, 0x67, 0x50, 0x75, 0x62, 0x4b, - 0x65, 0x79, 0x18, 0x01, 0x20, 0x01, 0x28, 0x09, 0x52, 0x08, 0x77, 0x67, 0x50, 0x75, 0x62, 0x4b, - 0x65, 0x79, 0x12, 0x1e, 0x0a, 0x0a, 0x61, 0x6c, 0x6c, 0x6f, 0x77, 0x65, 0x64, 0x49, 0x70, 0x73, - 0x18, 0x02, 0x20, 0x03, 0x28, 0x09, 0x52, 0x0a, 0x61, 0x6c, 0x6c, 0x6f, 0x77, 0x65, 0x64, 0x49, - 0x70, 0x73, 0x12, 0x33, 0x0a, 0x09, 0x73, 0x73, 0x68, 0x43, 0x6f, 0x6e, 0x66, 0x69, 0x67, 0x18, - 0x03, 0x20, 0x01, 0x28, 0x0b, 0x32, 0x15, 0x2e, 0x6d, 0x61, 0x6e, 0x61, 0x67, 0x65, 0x6d, 0x65, - 0x6e, 0x74, 0x2e, 0x53, 0x53, 0x48, 0x43, 0x6f, 0x6e, 0x66, 0x69, 0x67, 0x52, 0x09, 0x73, 0x73, - 0x68, 0x43, 0x6f, 0x6e, 0x66, 0x69, 0x67, 0x12, 0x12, 0x0a, 0x04, 0x66, 0x71, 0x64, 0x6e, 0x18, - 0x04, 0x20, 0x01, 0x28, 0x09, 0x52, 0x04, 0x66, 0x71, 0x64, 0x6e, 0x22, 0x49, 0x0a, 0x09, 0x53, - 0x53, 0x48, 0x43, 0x6f, 0x6e, 0x66, 0x69, 0x67, 0x12, 0x1e, 0x0a, 0x0a, 0x73, 0x73, 0x68, 0x45, - 0x6e, 0x61, 0x62, 0x6c, 0x65, 0x64, 0x18, 0x01, 0x20, 0x01, 0x28, 0x08, 0x52, 0x0a, 0x73, 0x73, - 0x68, 0x45, 0x6e, 0x61, 0x62, 0x6c, 0x65, 0x64, 0x12, 0x1c, 0x0a, 0x09, 0x73, 0x73, 0x68, 0x50, - 0x75, 0x62, 0x4b, 0x65, 0x79, 0x18, 0x02, 0x20, 0x01, 0x28, 0x0c, 0x52, 0x09, 0x73, 0x73, 0x68, - 0x50, 0x75, 0x62, 0x4b, 0x65, 0x79, 0x22, 0x20, 0x0a, 0x1e, 0x44, 0x65, 0x76, 0x69, 0x63, 0x65, - 0x41, 0x75, 0x74, 0x68, 0x6f, 0x72, 0x69, 0x7a, 0x61, 0x74, 0x69, 0x6f, 0x6e, 0x46, 0x6c, 0x6f, - 0x77, 0x52, 0x65, 0x71, 0x75, 0x65, 0x73, 0x74, 0x22, 0xbf, 0x01, 0x0a, 0x17, 0x44, 0x65, 0x76, - 0x69, 0x63, 0x65, 0x41, 0x75, 0x74, 0x68, 0x6f, 0x72, 0x69, 0x7a, 0x61, 0x74, 0x69, 0x6f, 0x6e, - 0x46, 0x6c, 0x6f, 0x77, 0x12, 0x48, 0x0a, 0x08, 0x50, 0x72, 0x6f, 0x76, 0x69, 0x64, 0x65, 0x72, - 0x18, 0x01, 0x20, 0x01, 0x28, 0x0e, 0x32, 0x2c, 0x2e, 0x6d, 0x61, 0x6e, 0x61, 0x67, 0x65, 0x6d, - 0x65, 0x6e, 0x74, 0x2e, 0x44, 0x65, 0x76, 0x69, 0x63, 0x65, 0x41, 0x75, 0x74, 0x68, 0x6f, 0x72, - 0x69, 0x7a, 0x61, 0x74, 0x69, 0x6f, 0x6e, 0x46, 0x6c, 0x6f, 0x77, 0x2e, 0x70, 0x72, 0x6f, 0x76, - 0x69, 0x64, 0x65, 0x72, 0x52, 0x08, 0x50, 0x72, 0x6f, 0x76, 0x69, 0x64, 0x65, 0x72, 0x12, 0x42, - 0x0a, 0x0e, 0x50, 0x72, 0x6f, 0x76, 0x69, 0x64, 0x65, 0x72, 0x43, 0x6f, 0x6e, 0x66, 0x69, 0x67, - 0x18, 0x02, 0x20, 0x01, 0x28, 0x0b, 0x32, 0x1a, 0x2e, 0x6d, 0x61, 0x6e, 0x61, 0x67, 0x65, 0x6d, - 0x65, 0x6e, 0x74, 0x2e, 0x50, 0x72, 0x6f, 0x76, 0x69, 0x64, 0x65, 0x72, 0x43, 0x6f, 0x6e, 0x66, - 0x69, 0x67, 0x52, 0x0e, 0x50, 0x72, 0x6f, 0x76, 0x69, 0x64, 0x65, 0x72, 0x43, 0x6f, 0x6e, 0x66, - 0x69, 0x67, 0x22, 0x16, 0x0a, 0x08, 0x70, 0x72, 0x6f, 0x76, 0x69, 0x64, 0x65, 0x72, 0x12, 0x0a, - 0x0a, 0x06, 0x48, 0x4f, 0x53, 0x54, 0x45, 0x44, 0x10, 0x00, 0x22, 0x1e, 0x0a, 0x1c, 0x50, 0x4b, - 0x43, 0x45, 0x41, 0x75, 0x74, 0x68, 0x6f, 0x72, 0x69, 0x7a, 0x61, 0x74, 0x69, 0x6f, 0x6e, 0x46, - 0x6c, 0x6f, 0x77, 0x52, 0x65, 0x71, 0x75, 0x65, 0x73, 0x74, 0x22, 0x5b, 0x0a, 0x15, 0x50, 0x4b, - 0x43, 0x45, 0x41, 0x75, 0x74, 0x68, 0x6f, 0x72, 0x69, 0x7a, 0x61, 0x74, 0x69, 0x6f, 0x6e, 0x46, - 0x6c, 0x6f, 0x77, 0x12, 0x42, 0x0a, 0x0e, 0x50, 0x72, 0x6f, 0x76, 0x69, 0x64, 0x65, 0x72, 0x43, - 0x6f, 0x6e, 0x66, 0x69, 0x67, 0x18, 0x01, 0x20, 0x01, 0x28, 0x0b, 0x32, 0x1a, 0x2e, 0x6d, 0x61, - 0x6e, 0x61, 0x67, 0x65, 0x6d, 0x65, 0x6e, 0x74, 0x2e, 0x50, 0x72, 0x6f, 0x76, 0x69, 0x64, 0x65, - 0x72, 0x43, 0x6f, 0x6e, 0x66, 0x69, 0x67, 0x52, 0x0e, 0x50, 0x72, 0x6f, 0x76, 0x69, 0x64, 0x65, - 0x72, 0x43, 0x6f, 0x6e, 0x66, 0x69, 0x67, 0x22, 0xea, 0x02, 0x0a, 0x0e, 0x50, 0x72, 0x6f, 0x76, - 
0x69, 0x64, 0x65, 0x72, 0x43, 0x6f, 0x6e, 0x66, 0x69, 0x67, 0x12, 0x1a, 0x0a, 0x08, 0x43, 0x6c, - 0x69, 0x65, 0x6e, 0x74, 0x49, 0x44, 0x18, 0x01, 0x20, 0x01, 0x28, 0x09, 0x52, 0x08, 0x43, 0x6c, - 0x69, 0x65, 0x6e, 0x74, 0x49, 0x44, 0x12, 0x22, 0x0a, 0x0c, 0x43, 0x6c, 0x69, 0x65, 0x6e, 0x74, - 0x53, 0x65, 0x63, 0x72, 0x65, 0x74, 0x18, 0x02, 0x20, 0x01, 0x28, 0x09, 0x52, 0x0c, 0x43, 0x6c, - 0x69, 0x65, 0x6e, 0x74, 0x53, 0x65, 0x63, 0x72, 0x65, 0x74, 0x12, 0x16, 0x0a, 0x06, 0x44, 0x6f, - 0x6d, 0x61, 0x69, 0x6e, 0x18, 0x03, 0x20, 0x01, 0x28, 0x09, 0x52, 0x06, 0x44, 0x6f, 0x6d, 0x61, - 0x69, 0x6e, 0x12, 0x1a, 0x0a, 0x08, 0x41, 0x75, 0x64, 0x69, 0x65, 0x6e, 0x63, 0x65, 0x18, 0x04, - 0x20, 0x01, 0x28, 0x09, 0x52, 0x08, 0x41, 0x75, 0x64, 0x69, 0x65, 0x6e, 0x63, 0x65, 0x12, 0x2e, - 0x0a, 0x12, 0x44, 0x65, 0x76, 0x69, 0x63, 0x65, 0x41, 0x75, 0x74, 0x68, 0x45, 0x6e, 0x64, 0x70, - 0x6f, 0x69, 0x6e, 0x74, 0x18, 0x05, 0x20, 0x01, 0x28, 0x09, 0x52, 0x12, 0x44, 0x65, 0x76, 0x69, - 0x63, 0x65, 0x41, 0x75, 0x74, 0x68, 0x45, 0x6e, 0x64, 0x70, 0x6f, 0x69, 0x6e, 0x74, 0x12, 0x24, - 0x0a, 0x0d, 0x54, 0x6f, 0x6b, 0x65, 0x6e, 0x45, 0x6e, 0x64, 0x70, 0x6f, 0x69, 0x6e, 0x74, 0x18, - 0x06, 0x20, 0x01, 0x28, 0x09, 0x52, 0x0d, 0x54, 0x6f, 0x6b, 0x65, 0x6e, 0x45, 0x6e, 0x64, 0x70, - 0x6f, 0x69, 0x6e, 0x74, 0x12, 0x14, 0x0a, 0x05, 0x53, 0x63, 0x6f, 0x70, 0x65, 0x18, 0x07, 0x20, - 0x01, 0x28, 0x09, 0x52, 0x05, 0x53, 0x63, 0x6f, 0x70, 0x65, 0x12, 0x1e, 0x0a, 0x0a, 0x55, 0x73, - 0x65, 0x49, 0x44, 0x54, 0x6f, 0x6b, 0x65, 0x6e, 0x18, 0x08, 0x20, 0x01, 0x28, 0x08, 0x52, 0x0a, - 0x55, 0x73, 0x65, 0x49, 0x44, 0x54, 0x6f, 0x6b, 0x65, 0x6e, 0x12, 0x34, 0x0a, 0x15, 0x41, 0x75, - 0x74, 0x68, 0x6f, 0x72, 0x69, 0x7a, 0x61, 0x74, 0x69, 0x6f, 0x6e, 0x45, 0x6e, 0x64, 0x70, 0x6f, - 0x69, 0x6e, 0x74, 0x18, 0x09, 0x20, 0x01, 0x28, 0x09, 0x52, 0x15, 0x41, 0x75, 0x74, 0x68, 0x6f, - 0x72, 0x69, 0x7a, 0x61, 0x74, 0x69, 0x6f, 0x6e, 0x45, 0x6e, 0x64, 0x70, 0x6f, 0x69, 0x6e, 0x74, - 0x12, 0x22, 0x0a, 0x0c, 0x52, 0x65, 0x64, 0x69, 0x72, 0x65, 0x63, 0x74, 0x55, 0x52, 0x4c, 0x73, - 0x18, 0x0a, 0x20, 0x03, 0x28, 0x09, 0x52, 0x0c, 0x52, 0x65, 0x64, 0x69, 0x72, 0x65, 0x63, 0x74, - 0x55, 0x52, 0x4c, 0x73, 0x22, 0xed, 0x01, 0x0a, 0x05, 0x52, 0x6f, 0x75, 0x74, 0x65, 0x12, 0x0e, - 0x0a, 0x02, 0x49, 0x44, 0x18, 0x01, 0x20, 0x01, 0x28, 0x09, 0x52, 0x02, 0x49, 0x44, 0x12, 0x18, - 0x0a, 0x07, 0x4e, 0x65, 0x74, 0x77, 0x6f, 0x72, 0x6b, 0x18, 0x02, 0x20, 0x01, 0x28, 0x09, 0x52, - 0x07, 0x4e, 0x65, 0x74, 0x77, 0x6f, 0x72, 0x6b, 0x12, 0x20, 0x0a, 0x0b, 0x4e, 0x65, 0x74, 0x77, - 0x6f, 0x72, 0x6b, 0x54, 0x79, 0x70, 0x65, 0x18, 0x03, 0x20, 0x01, 0x28, 0x03, 0x52, 0x0b, 0x4e, - 0x65, 0x74, 0x77, 0x6f, 0x72, 0x6b, 0x54, 0x79, 0x70, 0x65, 0x12, 0x12, 0x0a, 0x04, 0x50, 0x65, - 0x65, 0x72, 0x18, 0x04, 0x20, 0x01, 0x28, 0x09, 0x52, 0x04, 0x50, 0x65, 0x65, 0x72, 0x12, 0x16, - 0x0a, 0x06, 0x4d, 0x65, 0x74, 0x72, 0x69, 0x63, 0x18, 0x05, 0x20, 0x01, 0x28, 0x03, 0x52, 0x06, - 0x4d, 0x65, 0x74, 0x72, 0x69, 0x63, 0x12, 0x1e, 0x0a, 0x0a, 0x4d, 0x61, 0x73, 0x71, 0x75, 0x65, - 0x72, 0x61, 0x64, 0x65, 0x18, 0x06, 0x20, 0x01, 0x28, 0x08, 0x52, 0x0a, 0x4d, 0x61, 0x73, 0x71, - 0x75, 0x65, 0x72, 0x61, 0x64, 0x65, 0x12, 0x14, 0x0a, 0x05, 0x4e, 0x65, 0x74, 0x49, 0x44, 0x18, - 0x07, 0x20, 0x01, 0x28, 0x09, 0x52, 0x05, 0x4e, 0x65, 0x74, 0x49, 0x44, 0x12, 0x18, 0x0a, 0x07, - 0x44, 0x6f, 0x6d, 0x61, 0x69, 0x6e, 0x73, 0x18, 0x08, 0x20, 0x03, 0x28, 0x09, 0x52, 0x07, 0x44, - 0x6f, 0x6d, 0x61, 0x69, 0x6e, 0x73, 0x12, 0x1c, 0x0a, 0x09, 0x6b, 0x65, 0x65, 0x70, 0x52, 0x6f, - 0x75, 0x74, 0x65, 0x18, 
0x09, 0x20, 0x01, 0x28, 0x08, 0x52, 0x09, 0x6b, 0x65, 0x65, 0x70, 0x52, - 0x6f, 0x75, 0x74, 0x65, 0x22, 0xb4, 0x01, 0x0a, 0x09, 0x44, 0x4e, 0x53, 0x43, 0x6f, 0x6e, 0x66, - 0x69, 0x67, 0x12, 0x24, 0x0a, 0x0d, 0x53, 0x65, 0x72, 0x76, 0x69, 0x63, 0x65, 0x45, 0x6e, 0x61, - 0x62, 0x6c, 0x65, 0x18, 0x01, 0x20, 0x01, 0x28, 0x08, 0x52, 0x0d, 0x53, 0x65, 0x72, 0x76, 0x69, - 0x63, 0x65, 0x45, 0x6e, 0x61, 0x62, 0x6c, 0x65, 0x12, 0x47, 0x0a, 0x10, 0x4e, 0x61, 0x6d, 0x65, - 0x53, 0x65, 0x72, 0x76, 0x65, 0x72, 0x47, 0x72, 0x6f, 0x75, 0x70, 0x73, 0x18, 0x02, 0x20, 0x03, - 0x28, 0x0b, 0x32, 0x1b, 0x2e, 0x6d, 0x61, 0x6e, 0x61, 0x67, 0x65, 0x6d, 0x65, 0x6e, 0x74, 0x2e, - 0x4e, 0x61, 0x6d, 0x65, 0x53, 0x65, 0x72, 0x76, 0x65, 0x72, 0x47, 0x72, 0x6f, 0x75, 0x70, 0x52, - 0x10, 0x4e, 0x61, 0x6d, 0x65, 0x53, 0x65, 0x72, 0x76, 0x65, 0x72, 0x47, 0x72, 0x6f, 0x75, 0x70, - 0x73, 0x12, 0x38, 0x0a, 0x0b, 0x43, 0x75, 0x73, 0x74, 0x6f, 0x6d, 0x5a, 0x6f, 0x6e, 0x65, 0x73, - 0x18, 0x03, 0x20, 0x03, 0x28, 0x0b, 0x32, 0x16, 0x2e, 0x6d, 0x61, 0x6e, 0x61, 0x67, 0x65, 0x6d, - 0x65, 0x6e, 0x74, 0x2e, 0x43, 0x75, 0x73, 0x74, 0x6f, 0x6d, 0x5a, 0x6f, 0x6e, 0x65, 0x52, 0x0b, - 0x43, 0x75, 0x73, 0x74, 0x6f, 0x6d, 0x5a, 0x6f, 0x6e, 0x65, 0x73, 0x22, 0x58, 0x0a, 0x0a, 0x43, - 0x75, 0x73, 0x74, 0x6f, 0x6d, 0x5a, 0x6f, 0x6e, 0x65, 0x12, 0x16, 0x0a, 0x06, 0x44, 0x6f, 0x6d, - 0x61, 0x69, 0x6e, 0x18, 0x01, 0x20, 0x01, 0x28, 0x09, 0x52, 0x06, 0x44, 0x6f, 0x6d, 0x61, 0x69, - 0x6e, 0x12, 0x32, 0x0a, 0x07, 0x52, 0x65, 0x63, 0x6f, 0x72, 0x64, 0x73, 0x18, 0x02, 0x20, 0x03, - 0x28, 0x0b, 0x32, 0x18, 0x2e, 0x6d, 0x61, 0x6e, 0x61, 0x67, 0x65, 0x6d, 0x65, 0x6e, 0x74, 0x2e, - 0x53, 0x69, 0x6d, 0x70, 0x6c, 0x65, 0x52, 0x65, 0x63, 0x6f, 0x72, 0x64, 0x52, 0x07, 0x52, 0x65, - 0x63, 0x6f, 0x72, 0x64, 0x73, 0x22, 0x74, 0x0a, 0x0c, 0x53, 0x69, 0x6d, 0x70, 0x6c, 0x65, 0x52, - 0x65, 0x63, 0x6f, 0x72, 0x64, 0x12, 0x12, 0x0a, 0x04, 0x4e, 0x61, 0x6d, 0x65, 0x18, 0x01, 0x20, - 0x01, 0x28, 0x09, 0x52, 0x04, 0x4e, 0x61, 0x6d, 0x65, 0x12, 0x12, 0x0a, 0x04, 0x54, 0x79, 0x70, - 0x65, 0x18, 0x02, 0x20, 0x01, 0x28, 0x03, 0x52, 0x04, 0x54, 0x79, 0x70, 0x65, 0x12, 0x14, 0x0a, - 0x05, 0x43, 0x6c, 0x61, 0x73, 0x73, 0x18, 0x03, 0x20, 0x01, 0x28, 0x09, 0x52, 0x05, 0x43, 0x6c, - 0x61, 0x73, 0x73, 0x12, 0x10, 0x0a, 0x03, 0x54, 0x54, 0x4c, 0x18, 0x04, 0x20, 0x01, 0x28, 0x03, - 0x52, 0x03, 0x54, 0x54, 0x4c, 0x12, 0x14, 0x0a, 0x05, 0x52, 0x44, 0x61, 0x74, 0x61, 0x18, 0x05, - 0x20, 0x01, 0x28, 0x09, 0x52, 0x05, 0x52, 0x44, 0x61, 0x74, 0x61, 0x22, 0xb3, 0x01, 0x0a, 0x0f, - 0x4e, 0x61, 0x6d, 0x65, 0x53, 0x65, 0x72, 0x76, 0x65, 0x72, 0x47, 0x72, 0x6f, 0x75, 0x70, 0x12, - 0x38, 0x0a, 0x0b, 0x4e, 0x61, 0x6d, 0x65, 0x53, 0x65, 0x72, 0x76, 0x65, 0x72, 0x73, 0x18, 0x01, - 0x20, 0x03, 0x28, 0x0b, 0x32, 0x16, 0x2e, 0x6d, 0x61, 0x6e, 0x61, 0x67, 0x65, 0x6d, 0x65, 0x6e, - 0x74, 0x2e, 0x4e, 0x61, 0x6d, 0x65, 0x53, 0x65, 0x72, 0x76, 0x65, 0x72, 0x52, 0x0b, 0x4e, 0x61, - 0x6d, 0x65, 0x53, 0x65, 0x72, 0x76, 0x65, 0x72, 0x73, 0x12, 0x18, 0x0a, 0x07, 0x50, 0x72, 0x69, - 0x6d, 0x61, 0x72, 0x79, 0x18, 0x02, 0x20, 0x01, 0x28, 0x08, 0x52, 0x07, 0x50, 0x72, 0x69, 0x6d, - 0x61, 0x72, 0x79, 0x12, 0x18, 0x0a, 0x07, 0x44, 0x6f, 0x6d, 0x61, 0x69, 0x6e, 0x73, 0x18, 0x03, - 0x20, 0x03, 0x28, 0x09, 0x52, 0x07, 0x44, 0x6f, 0x6d, 0x61, 0x69, 0x6e, 0x73, 0x12, 0x32, 0x0a, - 0x14, 0x53, 0x65, 0x61, 0x72, 0x63, 0x68, 0x44, 0x6f, 0x6d, 0x61, 0x69, 0x6e, 0x73, 0x45, 0x6e, - 0x61, 0x62, 0x6c, 0x65, 0x64, 0x18, 0x04, 0x20, 0x01, 0x28, 0x08, 0x52, 0x14, 0x53, 0x65, 0x61, - 0x72, 0x63, 0x68, 0x44, 0x6f, 0x6d, 0x61, 0x69, 
0x6e, 0x73, 0x45, 0x6e, 0x61, 0x62, 0x6c, 0x65, - 0x64, 0x22, 0x48, 0x0a, 0x0a, 0x4e, 0x61, 0x6d, 0x65, 0x53, 0x65, 0x72, 0x76, 0x65, 0x72, 0x12, - 0x0e, 0x0a, 0x02, 0x49, 0x50, 0x18, 0x01, 0x20, 0x01, 0x28, 0x09, 0x52, 0x02, 0x49, 0x50, 0x12, - 0x16, 0x0a, 0x06, 0x4e, 0x53, 0x54, 0x79, 0x70, 0x65, 0x18, 0x02, 0x20, 0x01, 0x28, 0x03, 0x52, - 0x06, 0x4e, 0x53, 0x54, 0x79, 0x70, 0x65, 0x12, 0x12, 0x0a, 0x04, 0x50, 0x6f, 0x72, 0x74, 0x18, - 0x03, 0x20, 0x01, 0x28, 0x03, 0x52, 0x04, 0x50, 0x6f, 0x72, 0x74, 0x22, 0xd9, 0x01, 0x0a, 0x0c, - 0x46, 0x69, 0x72, 0x65, 0x77, 0x61, 0x6c, 0x6c, 0x52, 0x75, 0x6c, 0x65, 0x12, 0x16, 0x0a, 0x06, - 0x50, 0x65, 0x65, 0x72, 0x49, 0x50, 0x18, 0x01, 0x20, 0x01, 0x28, 0x09, 0x52, 0x06, 0x50, 0x65, - 0x65, 0x72, 0x49, 0x50, 0x12, 0x37, 0x0a, 0x09, 0x44, 0x69, 0x72, 0x65, 0x63, 0x74, 0x69, 0x6f, - 0x6e, 0x18, 0x02, 0x20, 0x01, 0x28, 0x0e, 0x32, 0x19, 0x2e, 0x6d, 0x61, 0x6e, 0x61, 0x67, 0x65, - 0x6d, 0x65, 0x6e, 0x74, 0x2e, 0x52, 0x75, 0x6c, 0x65, 0x44, 0x69, 0x72, 0x65, 0x63, 0x74, 0x69, - 0x6f, 0x6e, 0x52, 0x09, 0x44, 0x69, 0x72, 0x65, 0x63, 0x74, 0x69, 0x6f, 0x6e, 0x12, 0x2e, 0x0a, - 0x06, 0x41, 0x63, 0x74, 0x69, 0x6f, 0x6e, 0x18, 0x03, 0x20, 0x01, 0x28, 0x0e, 0x32, 0x16, 0x2e, - 0x6d, 0x61, 0x6e, 0x61, 0x67, 0x65, 0x6d, 0x65, 0x6e, 0x74, 0x2e, 0x52, 0x75, 0x6c, 0x65, 0x41, - 0x63, 0x74, 0x69, 0x6f, 0x6e, 0x52, 0x06, 0x41, 0x63, 0x74, 0x69, 0x6f, 0x6e, 0x12, 0x34, 0x0a, - 0x08, 0x50, 0x72, 0x6f, 0x74, 0x6f, 0x63, 0x6f, 0x6c, 0x18, 0x04, 0x20, 0x01, 0x28, 0x0e, 0x32, - 0x18, 0x2e, 0x6d, 0x61, 0x6e, 0x61, 0x67, 0x65, 0x6d, 0x65, 0x6e, 0x74, 0x2e, 0x52, 0x75, 0x6c, - 0x65, 0x50, 0x72, 0x6f, 0x74, 0x6f, 0x63, 0x6f, 0x6c, 0x52, 0x08, 0x50, 0x72, 0x6f, 0x74, 0x6f, - 0x63, 0x6f, 0x6c, 0x12, 0x12, 0x0a, 0x04, 0x50, 0x6f, 0x72, 0x74, 0x18, 0x05, 0x20, 0x01, 0x28, - 0x09, 0x52, 0x04, 0x50, 0x6f, 0x72, 0x74, 0x22, 0x38, 0x0a, 0x0e, 0x4e, 0x65, 0x74, 0x77, 0x6f, - 0x72, 0x6b, 0x41, 0x64, 0x64, 0x72, 0x65, 0x73, 0x73, 0x12, 0x14, 0x0a, 0x05, 0x6e, 0x65, 0x74, - 0x49, 0x50, 0x18, 0x01, 0x20, 0x01, 0x28, 0x09, 0x52, 0x05, 0x6e, 0x65, 0x74, 0x49, 0x50, 0x12, - 0x10, 0x0a, 0x03, 0x6d, 0x61, 0x63, 0x18, 0x02, 0x20, 0x01, 0x28, 0x09, 0x52, 0x03, 0x6d, 0x61, - 0x63, 0x22, 0x1e, 0x0a, 0x06, 0x43, 0x68, 0x65, 0x63, 0x6b, 0x73, 0x12, 0x14, 0x0a, 0x05, 0x46, - 0x69, 0x6c, 0x65, 0x73, 0x18, 0x01, 0x20, 0x03, 0x28, 0x09, 0x52, 0x05, 0x46, 0x69, 0x6c, 0x65, - 0x73, 0x22, 0x96, 0x01, 0x0a, 0x08, 0x50, 0x6f, 0x72, 0x74, 0x49, 0x6e, 0x66, 0x6f, 0x12, 0x14, - 0x0a, 0x04, 0x70, 0x6f, 0x72, 0x74, 0x18, 0x01, 0x20, 0x01, 0x28, 0x0d, 0x48, 0x00, 0x52, 0x04, - 0x70, 0x6f, 0x72, 0x74, 0x12, 0x32, 0x0a, 0x05, 0x72, 0x61, 0x6e, 0x67, 0x65, 0x18, 0x02, 0x20, - 0x01, 0x28, 0x0b, 0x32, 0x1a, 0x2e, 0x6d, 0x61, 0x6e, 0x61, 0x67, 0x65, 0x6d, 0x65, 0x6e, 0x74, - 0x2e, 0x50, 0x6f, 0x72, 0x74, 0x49, 0x6e, 0x66, 0x6f, 0x2e, 0x52, 0x61, 0x6e, 0x67, 0x65, 0x48, - 0x00, 0x52, 0x05, 0x72, 0x61, 0x6e, 0x67, 0x65, 0x1a, 0x2f, 0x0a, 0x05, 0x52, 0x61, 0x6e, 0x67, - 0x65, 0x12, 0x14, 0x0a, 0x05, 0x73, 0x74, 0x61, 0x72, 0x74, 0x18, 0x01, 0x20, 0x01, 0x28, 0x0d, - 0x52, 0x05, 0x73, 0x74, 0x61, 0x72, 0x74, 0x12, 0x10, 0x0a, 0x03, 0x65, 0x6e, 0x64, 0x18, 0x02, - 0x20, 0x01, 0x28, 0x0d, 0x52, 0x03, 0x65, 0x6e, 0x64, 0x42, 0x0f, 0x0a, 0x0d, 0x70, 0x6f, 0x72, - 0x74, 0x53, 0x65, 0x6c, 0x65, 0x63, 0x74, 0x69, 0x6f, 0x6e, 0x22, 0x8f, 0x02, 0x0a, 0x11, 0x52, - 0x6f, 0x75, 0x74, 0x65, 0x46, 0x69, 0x72, 0x65, 0x77, 0x61, 0x6c, 0x6c, 0x52, 0x75, 0x6c, 0x65, - 0x12, 0x22, 0x0a, 0x0c, 0x73, 0x6f, 0x75, 0x72, 0x63, 0x65, 0x52, 0x61, 
0x6e, 0x67, 0x65, 0x73, - 0x18, 0x01, 0x20, 0x03, 0x28, 0x09, 0x52, 0x0c, 0x73, 0x6f, 0x75, 0x72, 0x63, 0x65, 0x52, 0x61, - 0x6e, 0x67, 0x65, 0x73, 0x12, 0x2e, 0x0a, 0x06, 0x61, 0x63, 0x74, 0x69, 0x6f, 0x6e, 0x18, 0x02, - 0x20, 0x01, 0x28, 0x0e, 0x32, 0x16, 0x2e, 0x6d, 0x61, 0x6e, 0x61, 0x67, 0x65, 0x6d, 0x65, 0x6e, - 0x74, 0x2e, 0x52, 0x75, 0x6c, 0x65, 0x41, 0x63, 0x74, 0x69, 0x6f, 0x6e, 0x52, 0x06, 0x61, 0x63, - 0x74, 0x69, 0x6f, 0x6e, 0x12, 0x20, 0x0a, 0x0b, 0x64, 0x65, 0x73, 0x74, 0x69, 0x6e, 0x61, 0x74, - 0x69, 0x6f, 0x6e, 0x18, 0x03, 0x20, 0x01, 0x28, 0x09, 0x52, 0x0b, 0x64, 0x65, 0x73, 0x74, 0x69, - 0x6e, 0x61, 0x74, 0x69, 0x6f, 0x6e, 0x12, 0x34, 0x0a, 0x08, 0x70, 0x72, 0x6f, 0x74, 0x6f, 0x63, - 0x6f, 0x6c, 0x18, 0x04, 0x20, 0x01, 0x28, 0x0e, 0x32, 0x18, 0x2e, 0x6d, 0x61, 0x6e, 0x61, 0x67, - 0x65, 0x6d, 0x65, 0x6e, 0x74, 0x2e, 0x52, 0x75, 0x6c, 0x65, 0x50, 0x72, 0x6f, 0x74, 0x6f, 0x63, - 0x6f, 0x6c, 0x52, 0x08, 0x70, 0x72, 0x6f, 0x74, 0x6f, 0x63, 0x6f, 0x6c, 0x12, 0x30, 0x0a, 0x08, - 0x70, 0x6f, 0x72, 0x74, 0x49, 0x6e, 0x66, 0x6f, 0x18, 0x05, 0x20, 0x01, 0x28, 0x0b, 0x32, 0x14, - 0x2e, 0x6d, 0x61, 0x6e, 0x61, 0x67, 0x65, 0x6d, 0x65, 0x6e, 0x74, 0x2e, 0x50, 0x6f, 0x72, 0x74, - 0x49, 0x6e, 0x66, 0x6f, 0x52, 0x08, 0x70, 0x6f, 0x72, 0x74, 0x49, 0x6e, 0x66, 0x6f, 0x12, 0x1c, - 0x0a, 0x09, 0x69, 0x73, 0x44, 0x79, 0x6e, 0x61, 0x6d, 0x69, 0x63, 0x18, 0x06, 0x20, 0x01, 0x28, - 0x08, 0x52, 0x09, 0x69, 0x73, 0x44, 0x79, 0x6e, 0x61, 0x6d, 0x69, 0x63, 0x2a, 0x40, 0x0a, 0x0c, - 0x52, 0x75, 0x6c, 0x65, 0x50, 0x72, 0x6f, 0x74, 0x6f, 0x63, 0x6f, 0x6c, 0x12, 0x0b, 0x0a, 0x07, - 0x55, 0x4e, 0x4b, 0x4e, 0x4f, 0x57, 0x4e, 0x10, 0x00, 0x12, 0x07, 0x0a, 0x03, 0x41, 0x4c, 0x4c, - 0x10, 0x01, 0x12, 0x07, 0x0a, 0x03, 0x54, 0x43, 0x50, 0x10, 0x02, 0x12, 0x07, 0x0a, 0x03, 0x55, - 0x44, 0x50, 0x10, 0x03, 0x12, 0x08, 0x0a, 0x04, 0x49, 0x43, 0x4d, 0x50, 0x10, 0x04, 0x2a, 0x20, - 0x0a, 0x0d, 0x52, 0x75, 0x6c, 0x65, 0x44, 0x69, 0x72, 0x65, 0x63, 0x74, 0x69, 0x6f, 0x6e, 0x12, - 0x06, 0x0a, 0x02, 0x49, 0x4e, 0x10, 0x00, 0x12, 0x07, 0x0a, 0x03, 0x4f, 0x55, 0x54, 0x10, 0x01, - 0x2a, 0x22, 0x0a, 0x0a, 0x52, 0x75, 0x6c, 0x65, 0x41, 0x63, 0x74, 0x69, 0x6f, 0x6e, 0x12, 0x0a, - 0x0a, 0x06, 0x41, 0x43, 0x43, 0x45, 0x50, 0x54, 0x10, 0x00, 0x12, 0x08, 0x0a, 0x04, 0x44, 0x52, - 0x4f, 0x50, 0x10, 0x01, 0x32, 0x90, 0x04, 0x0a, 0x11, 0x4d, 0x61, 0x6e, 0x61, 0x67, 0x65, 0x6d, - 0x65, 0x6e, 0x74, 0x53, 0x65, 0x72, 0x76, 0x69, 0x63, 0x65, 0x12, 0x45, 0x0a, 0x05, 0x4c, 0x6f, - 0x67, 0x69, 0x6e, 0x12, 0x1c, 0x2e, 0x6d, 0x61, 0x6e, 0x61, 0x67, 0x65, 0x6d, 0x65, 0x6e, 0x74, - 0x2e, 0x45, 0x6e, 0x63, 0x72, 0x79, 0x70, 0x74, 0x65, 0x64, 0x4d, 0x65, 0x73, 0x73, 0x61, 0x67, - 0x65, 0x1a, 0x1c, 0x2e, 0x6d, 0x61, 0x6e, 0x61, 0x67, 0x65, 0x6d, 0x65, 0x6e, 0x74, 0x2e, 0x45, - 0x6e, 0x63, 0x72, 0x79, 0x70, 0x74, 0x65, 0x64, 0x4d, 0x65, 0x73, 0x73, 0x61, 0x67, 0x65, 0x22, - 0x00, 0x12, 0x46, 0x0a, 0x04, 0x53, 0x79, 0x6e, 0x63, 0x12, 0x1c, 0x2e, 0x6d, 0x61, 0x6e, 0x61, + 0x73, 0x18, 0x0a, 0x20, 0x03, 0x28, 0x0b, 0x32, 0x1d, 0x2e, 0x6d, 0x61, 0x6e, 0x61, 0x67, 0x65, + 0x6d, 0x65, 0x6e, 0x74, 0x2e, 0x52, 0x6f, 0x75, 0x74, 0x65, 0x46, 0x69, 0x72, 0x65, 0x77, 0x61, + 0x6c, 0x6c, 0x52, 0x75, 0x6c, 0x65, 0x52, 0x13, 0x72, 0x6f, 0x75, 0x74, 0x65, 0x73, 0x46, 0x69, + 0x72, 0x65, 0x77, 0x61, 0x6c, 0x6c, 0x52, 0x75, 0x6c, 0x65, 0x73, 0x12, 0x3e, 0x0a, 0x1a, 0x72, + 0x6f, 0x75, 0x74, 0x65, 0x73, 0x46, 0x69, 0x72, 0x65, 0x77, 0x61, 0x6c, 0x6c, 0x52, 0x75, 0x6c, + 0x65, 0x73, 0x49, 0x73, 0x45, 0x6d, 0x70, 0x74, 0x79, 0x18, 0x0b, 0x20, 0x01, 0x28, 0x08, 0x52, + 
0x1a, 0x72, 0x6f, 0x75, 0x74, 0x65, 0x73, 0x46, 0x69, 0x72, 0x65, 0x77, 0x61, 0x6c, 0x6c, 0x52, + 0x75, 0x6c, 0x65, 0x73, 0x49, 0x73, 0x45, 0x6d, 0x70, 0x74, 0x79, 0x22, 0x97, 0x01, 0x0a, 0x10, + 0x52, 0x65, 0x6d, 0x6f, 0x74, 0x65, 0x50, 0x65, 0x65, 0x72, 0x43, 0x6f, 0x6e, 0x66, 0x69, 0x67, + 0x12, 0x1a, 0x0a, 0x08, 0x77, 0x67, 0x50, 0x75, 0x62, 0x4b, 0x65, 0x79, 0x18, 0x01, 0x20, 0x01, + 0x28, 0x09, 0x52, 0x08, 0x77, 0x67, 0x50, 0x75, 0x62, 0x4b, 0x65, 0x79, 0x12, 0x1e, 0x0a, 0x0a, + 0x61, 0x6c, 0x6c, 0x6f, 0x77, 0x65, 0x64, 0x49, 0x70, 0x73, 0x18, 0x02, 0x20, 0x03, 0x28, 0x09, + 0x52, 0x0a, 0x61, 0x6c, 0x6c, 0x6f, 0x77, 0x65, 0x64, 0x49, 0x70, 0x73, 0x12, 0x33, 0x0a, 0x09, + 0x73, 0x73, 0x68, 0x43, 0x6f, 0x6e, 0x66, 0x69, 0x67, 0x18, 0x03, 0x20, 0x01, 0x28, 0x0b, 0x32, + 0x15, 0x2e, 0x6d, 0x61, 0x6e, 0x61, 0x67, 0x65, 0x6d, 0x65, 0x6e, 0x74, 0x2e, 0x53, 0x53, 0x48, + 0x43, 0x6f, 0x6e, 0x66, 0x69, 0x67, 0x52, 0x09, 0x73, 0x73, 0x68, 0x43, 0x6f, 0x6e, 0x66, 0x69, + 0x67, 0x12, 0x12, 0x0a, 0x04, 0x66, 0x71, 0x64, 0x6e, 0x18, 0x04, 0x20, 0x01, 0x28, 0x09, 0x52, + 0x04, 0x66, 0x71, 0x64, 0x6e, 0x22, 0x49, 0x0a, 0x09, 0x53, 0x53, 0x48, 0x43, 0x6f, 0x6e, 0x66, + 0x69, 0x67, 0x12, 0x1e, 0x0a, 0x0a, 0x73, 0x73, 0x68, 0x45, 0x6e, 0x61, 0x62, 0x6c, 0x65, 0x64, + 0x18, 0x01, 0x20, 0x01, 0x28, 0x08, 0x52, 0x0a, 0x73, 0x73, 0x68, 0x45, 0x6e, 0x61, 0x62, 0x6c, + 0x65, 0x64, 0x12, 0x1c, 0x0a, 0x09, 0x73, 0x73, 0x68, 0x50, 0x75, 0x62, 0x4b, 0x65, 0x79, 0x18, + 0x02, 0x20, 0x01, 0x28, 0x0c, 0x52, 0x09, 0x73, 0x73, 0x68, 0x50, 0x75, 0x62, 0x4b, 0x65, 0x79, + 0x22, 0x20, 0x0a, 0x1e, 0x44, 0x65, 0x76, 0x69, 0x63, 0x65, 0x41, 0x75, 0x74, 0x68, 0x6f, 0x72, + 0x69, 0x7a, 0x61, 0x74, 0x69, 0x6f, 0x6e, 0x46, 0x6c, 0x6f, 0x77, 0x52, 0x65, 0x71, 0x75, 0x65, + 0x73, 0x74, 0x22, 0xbf, 0x01, 0x0a, 0x17, 0x44, 0x65, 0x76, 0x69, 0x63, 0x65, 0x41, 0x75, 0x74, + 0x68, 0x6f, 0x72, 0x69, 0x7a, 0x61, 0x74, 0x69, 0x6f, 0x6e, 0x46, 0x6c, 0x6f, 0x77, 0x12, 0x48, + 0x0a, 0x08, 0x50, 0x72, 0x6f, 0x76, 0x69, 0x64, 0x65, 0x72, 0x18, 0x01, 0x20, 0x01, 0x28, 0x0e, + 0x32, 0x2c, 0x2e, 0x6d, 0x61, 0x6e, 0x61, 0x67, 0x65, 0x6d, 0x65, 0x6e, 0x74, 0x2e, 0x44, 0x65, + 0x76, 0x69, 0x63, 0x65, 0x41, 0x75, 0x74, 0x68, 0x6f, 0x72, 0x69, 0x7a, 0x61, 0x74, 0x69, 0x6f, + 0x6e, 0x46, 0x6c, 0x6f, 0x77, 0x2e, 0x70, 0x72, 0x6f, 0x76, 0x69, 0x64, 0x65, 0x72, 0x52, 0x08, + 0x50, 0x72, 0x6f, 0x76, 0x69, 0x64, 0x65, 0x72, 0x12, 0x42, 0x0a, 0x0e, 0x50, 0x72, 0x6f, 0x76, + 0x69, 0x64, 0x65, 0x72, 0x43, 0x6f, 0x6e, 0x66, 0x69, 0x67, 0x18, 0x02, 0x20, 0x01, 0x28, 0x0b, + 0x32, 0x1a, 0x2e, 0x6d, 0x61, 0x6e, 0x61, 0x67, 0x65, 0x6d, 0x65, 0x6e, 0x74, 0x2e, 0x50, 0x72, + 0x6f, 0x76, 0x69, 0x64, 0x65, 0x72, 0x43, 0x6f, 0x6e, 0x66, 0x69, 0x67, 0x52, 0x0e, 0x50, 0x72, + 0x6f, 0x76, 0x69, 0x64, 0x65, 0x72, 0x43, 0x6f, 0x6e, 0x66, 0x69, 0x67, 0x22, 0x16, 0x0a, 0x08, + 0x70, 0x72, 0x6f, 0x76, 0x69, 0x64, 0x65, 0x72, 0x12, 0x0a, 0x0a, 0x06, 0x48, 0x4f, 0x53, 0x54, + 0x45, 0x44, 0x10, 0x00, 0x22, 0x1e, 0x0a, 0x1c, 0x50, 0x4b, 0x43, 0x45, 0x41, 0x75, 0x74, 0x68, + 0x6f, 0x72, 0x69, 0x7a, 0x61, 0x74, 0x69, 0x6f, 0x6e, 0x46, 0x6c, 0x6f, 0x77, 0x52, 0x65, 0x71, + 0x75, 0x65, 0x73, 0x74, 0x22, 0x5b, 0x0a, 0x15, 0x50, 0x4b, 0x43, 0x45, 0x41, 0x75, 0x74, 0x68, + 0x6f, 0x72, 0x69, 0x7a, 0x61, 0x74, 0x69, 0x6f, 0x6e, 0x46, 0x6c, 0x6f, 0x77, 0x12, 0x42, 0x0a, + 0x0e, 0x50, 0x72, 0x6f, 0x76, 0x69, 0x64, 0x65, 0x72, 0x43, 0x6f, 0x6e, 0x66, 0x69, 0x67, 0x18, + 0x01, 0x20, 0x01, 0x28, 0x0b, 0x32, 0x1a, 0x2e, 0x6d, 0x61, 0x6e, 0x61, 0x67, 0x65, 0x6d, 0x65, + 0x6e, 0x74, 0x2e, 0x50, 
0x72, 0x6f, 0x76, 0x69, 0x64, 0x65, 0x72, 0x43, 0x6f, 0x6e, 0x66, 0x69, + 0x67, 0x52, 0x0e, 0x50, 0x72, 0x6f, 0x76, 0x69, 0x64, 0x65, 0x72, 0x43, 0x6f, 0x6e, 0x66, 0x69, + 0x67, 0x22, 0xea, 0x02, 0x0a, 0x0e, 0x50, 0x72, 0x6f, 0x76, 0x69, 0x64, 0x65, 0x72, 0x43, 0x6f, + 0x6e, 0x66, 0x69, 0x67, 0x12, 0x1a, 0x0a, 0x08, 0x43, 0x6c, 0x69, 0x65, 0x6e, 0x74, 0x49, 0x44, + 0x18, 0x01, 0x20, 0x01, 0x28, 0x09, 0x52, 0x08, 0x43, 0x6c, 0x69, 0x65, 0x6e, 0x74, 0x49, 0x44, + 0x12, 0x22, 0x0a, 0x0c, 0x43, 0x6c, 0x69, 0x65, 0x6e, 0x74, 0x53, 0x65, 0x63, 0x72, 0x65, 0x74, + 0x18, 0x02, 0x20, 0x01, 0x28, 0x09, 0x52, 0x0c, 0x43, 0x6c, 0x69, 0x65, 0x6e, 0x74, 0x53, 0x65, + 0x63, 0x72, 0x65, 0x74, 0x12, 0x16, 0x0a, 0x06, 0x44, 0x6f, 0x6d, 0x61, 0x69, 0x6e, 0x18, 0x03, + 0x20, 0x01, 0x28, 0x09, 0x52, 0x06, 0x44, 0x6f, 0x6d, 0x61, 0x69, 0x6e, 0x12, 0x1a, 0x0a, 0x08, + 0x41, 0x75, 0x64, 0x69, 0x65, 0x6e, 0x63, 0x65, 0x18, 0x04, 0x20, 0x01, 0x28, 0x09, 0x52, 0x08, + 0x41, 0x75, 0x64, 0x69, 0x65, 0x6e, 0x63, 0x65, 0x12, 0x2e, 0x0a, 0x12, 0x44, 0x65, 0x76, 0x69, + 0x63, 0x65, 0x41, 0x75, 0x74, 0x68, 0x45, 0x6e, 0x64, 0x70, 0x6f, 0x69, 0x6e, 0x74, 0x18, 0x05, + 0x20, 0x01, 0x28, 0x09, 0x52, 0x12, 0x44, 0x65, 0x76, 0x69, 0x63, 0x65, 0x41, 0x75, 0x74, 0x68, + 0x45, 0x6e, 0x64, 0x70, 0x6f, 0x69, 0x6e, 0x74, 0x12, 0x24, 0x0a, 0x0d, 0x54, 0x6f, 0x6b, 0x65, + 0x6e, 0x45, 0x6e, 0x64, 0x70, 0x6f, 0x69, 0x6e, 0x74, 0x18, 0x06, 0x20, 0x01, 0x28, 0x09, 0x52, + 0x0d, 0x54, 0x6f, 0x6b, 0x65, 0x6e, 0x45, 0x6e, 0x64, 0x70, 0x6f, 0x69, 0x6e, 0x74, 0x12, 0x14, + 0x0a, 0x05, 0x53, 0x63, 0x6f, 0x70, 0x65, 0x18, 0x07, 0x20, 0x01, 0x28, 0x09, 0x52, 0x05, 0x53, + 0x63, 0x6f, 0x70, 0x65, 0x12, 0x1e, 0x0a, 0x0a, 0x55, 0x73, 0x65, 0x49, 0x44, 0x54, 0x6f, 0x6b, + 0x65, 0x6e, 0x18, 0x08, 0x20, 0x01, 0x28, 0x08, 0x52, 0x0a, 0x55, 0x73, 0x65, 0x49, 0x44, 0x54, + 0x6f, 0x6b, 0x65, 0x6e, 0x12, 0x34, 0x0a, 0x15, 0x41, 0x75, 0x74, 0x68, 0x6f, 0x72, 0x69, 0x7a, + 0x61, 0x74, 0x69, 0x6f, 0x6e, 0x45, 0x6e, 0x64, 0x70, 0x6f, 0x69, 0x6e, 0x74, 0x18, 0x09, 0x20, + 0x01, 0x28, 0x09, 0x52, 0x15, 0x41, 0x75, 0x74, 0x68, 0x6f, 0x72, 0x69, 0x7a, 0x61, 0x74, 0x69, + 0x6f, 0x6e, 0x45, 0x6e, 0x64, 0x70, 0x6f, 0x69, 0x6e, 0x74, 0x12, 0x22, 0x0a, 0x0c, 0x52, 0x65, + 0x64, 0x69, 0x72, 0x65, 0x63, 0x74, 0x55, 0x52, 0x4c, 0x73, 0x18, 0x0a, 0x20, 0x03, 0x28, 0x09, + 0x52, 0x0c, 0x52, 0x65, 0x64, 0x69, 0x72, 0x65, 0x63, 0x74, 0x55, 0x52, 0x4c, 0x73, 0x22, 0xed, + 0x01, 0x0a, 0x05, 0x52, 0x6f, 0x75, 0x74, 0x65, 0x12, 0x0e, 0x0a, 0x02, 0x49, 0x44, 0x18, 0x01, + 0x20, 0x01, 0x28, 0x09, 0x52, 0x02, 0x49, 0x44, 0x12, 0x18, 0x0a, 0x07, 0x4e, 0x65, 0x74, 0x77, + 0x6f, 0x72, 0x6b, 0x18, 0x02, 0x20, 0x01, 0x28, 0x09, 0x52, 0x07, 0x4e, 0x65, 0x74, 0x77, 0x6f, + 0x72, 0x6b, 0x12, 0x20, 0x0a, 0x0b, 0x4e, 0x65, 0x74, 0x77, 0x6f, 0x72, 0x6b, 0x54, 0x79, 0x70, + 0x65, 0x18, 0x03, 0x20, 0x01, 0x28, 0x03, 0x52, 0x0b, 0x4e, 0x65, 0x74, 0x77, 0x6f, 0x72, 0x6b, + 0x54, 0x79, 0x70, 0x65, 0x12, 0x12, 0x0a, 0x04, 0x50, 0x65, 0x65, 0x72, 0x18, 0x04, 0x20, 0x01, + 0x28, 0x09, 0x52, 0x04, 0x50, 0x65, 0x65, 0x72, 0x12, 0x16, 0x0a, 0x06, 0x4d, 0x65, 0x74, 0x72, + 0x69, 0x63, 0x18, 0x05, 0x20, 0x01, 0x28, 0x03, 0x52, 0x06, 0x4d, 0x65, 0x74, 0x72, 0x69, 0x63, + 0x12, 0x1e, 0x0a, 0x0a, 0x4d, 0x61, 0x73, 0x71, 0x75, 0x65, 0x72, 0x61, 0x64, 0x65, 0x18, 0x06, + 0x20, 0x01, 0x28, 0x08, 0x52, 0x0a, 0x4d, 0x61, 0x73, 0x71, 0x75, 0x65, 0x72, 0x61, 0x64, 0x65, + 0x12, 0x14, 0x0a, 0x05, 0x4e, 0x65, 0x74, 0x49, 0x44, 0x18, 0x07, 0x20, 0x01, 0x28, 0x09, 0x52, + 0x05, 0x4e, 0x65, 0x74, 0x49, 0x44, 0x12, 0x18, 
0x0a, 0x07, 0x44, 0x6f, 0x6d, 0x61, 0x69, 0x6e, + 0x73, 0x18, 0x08, 0x20, 0x03, 0x28, 0x09, 0x52, 0x07, 0x44, 0x6f, 0x6d, 0x61, 0x69, 0x6e, 0x73, + 0x12, 0x1c, 0x0a, 0x09, 0x6b, 0x65, 0x65, 0x70, 0x52, 0x6f, 0x75, 0x74, 0x65, 0x18, 0x09, 0x20, + 0x01, 0x28, 0x08, 0x52, 0x09, 0x6b, 0x65, 0x65, 0x70, 0x52, 0x6f, 0x75, 0x74, 0x65, 0x22, 0xb4, + 0x01, 0x0a, 0x09, 0x44, 0x4e, 0x53, 0x43, 0x6f, 0x6e, 0x66, 0x69, 0x67, 0x12, 0x24, 0x0a, 0x0d, + 0x53, 0x65, 0x72, 0x76, 0x69, 0x63, 0x65, 0x45, 0x6e, 0x61, 0x62, 0x6c, 0x65, 0x18, 0x01, 0x20, + 0x01, 0x28, 0x08, 0x52, 0x0d, 0x53, 0x65, 0x72, 0x76, 0x69, 0x63, 0x65, 0x45, 0x6e, 0x61, 0x62, + 0x6c, 0x65, 0x12, 0x47, 0x0a, 0x10, 0x4e, 0x61, 0x6d, 0x65, 0x53, 0x65, 0x72, 0x76, 0x65, 0x72, + 0x47, 0x72, 0x6f, 0x75, 0x70, 0x73, 0x18, 0x02, 0x20, 0x03, 0x28, 0x0b, 0x32, 0x1b, 0x2e, 0x6d, + 0x61, 0x6e, 0x61, 0x67, 0x65, 0x6d, 0x65, 0x6e, 0x74, 0x2e, 0x4e, 0x61, 0x6d, 0x65, 0x53, 0x65, + 0x72, 0x76, 0x65, 0x72, 0x47, 0x72, 0x6f, 0x75, 0x70, 0x52, 0x10, 0x4e, 0x61, 0x6d, 0x65, 0x53, + 0x65, 0x72, 0x76, 0x65, 0x72, 0x47, 0x72, 0x6f, 0x75, 0x70, 0x73, 0x12, 0x38, 0x0a, 0x0b, 0x43, + 0x75, 0x73, 0x74, 0x6f, 0x6d, 0x5a, 0x6f, 0x6e, 0x65, 0x73, 0x18, 0x03, 0x20, 0x03, 0x28, 0x0b, + 0x32, 0x16, 0x2e, 0x6d, 0x61, 0x6e, 0x61, 0x67, 0x65, 0x6d, 0x65, 0x6e, 0x74, 0x2e, 0x43, 0x75, + 0x73, 0x74, 0x6f, 0x6d, 0x5a, 0x6f, 0x6e, 0x65, 0x52, 0x0b, 0x43, 0x75, 0x73, 0x74, 0x6f, 0x6d, + 0x5a, 0x6f, 0x6e, 0x65, 0x73, 0x22, 0x58, 0x0a, 0x0a, 0x43, 0x75, 0x73, 0x74, 0x6f, 0x6d, 0x5a, + 0x6f, 0x6e, 0x65, 0x12, 0x16, 0x0a, 0x06, 0x44, 0x6f, 0x6d, 0x61, 0x69, 0x6e, 0x18, 0x01, 0x20, + 0x01, 0x28, 0x09, 0x52, 0x06, 0x44, 0x6f, 0x6d, 0x61, 0x69, 0x6e, 0x12, 0x32, 0x0a, 0x07, 0x52, + 0x65, 0x63, 0x6f, 0x72, 0x64, 0x73, 0x18, 0x02, 0x20, 0x03, 0x28, 0x0b, 0x32, 0x18, 0x2e, 0x6d, + 0x61, 0x6e, 0x61, 0x67, 0x65, 0x6d, 0x65, 0x6e, 0x74, 0x2e, 0x53, 0x69, 0x6d, 0x70, 0x6c, 0x65, + 0x52, 0x65, 0x63, 0x6f, 0x72, 0x64, 0x52, 0x07, 0x52, 0x65, 0x63, 0x6f, 0x72, 0x64, 0x73, 0x22, + 0x74, 0x0a, 0x0c, 0x53, 0x69, 0x6d, 0x70, 0x6c, 0x65, 0x52, 0x65, 0x63, 0x6f, 0x72, 0x64, 0x12, + 0x12, 0x0a, 0x04, 0x4e, 0x61, 0x6d, 0x65, 0x18, 0x01, 0x20, 0x01, 0x28, 0x09, 0x52, 0x04, 0x4e, + 0x61, 0x6d, 0x65, 0x12, 0x12, 0x0a, 0x04, 0x54, 0x79, 0x70, 0x65, 0x18, 0x02, 0x20, 0x01, 0x28, + 0x03, 0x52, 0x04, 0x54, 0x79, 0x70, 0x65, 0x12, 0x14, 0x0a, 0x05, 0x43, 0x6c, 0x61, 0x73, 0x73, + 0x18, 0x03, 0x20, 0x01, 0x28, 0x09, 0x52, 0x05, 0x43, 0x6c, 0x61, 0x73, 0x73, 0x12, 0x10, 0x0a, + 0x03, 0x54, 0x54, 0x4c, 0x18, 0x04, 0x20, 0x01, 0x28, 0x03, 0x52, 0x03, 0x54, 0x54, 0x4c, 0x12, + 0x14, 0x0a, 0x05, 0x52, 0x44, 0x61, 0x74, 0x61, 0x18, 0x05, 0x20, 0x01, 0x28, 0x09, 0x52, 0x05, + 0x52, 0x44, 0x61, 0x74, 0x61, 0x22, 0xb3, 0x01, 0x0a, 0x0f, 0x4e, 0x61, 0x6d, 0x65, 0x53, 0x65, + 0x72, 0x76, 0x65, 0x72, 0x47, 0x72, 0x6f, 0x75, 0x70, 0x12, 0x38, 0x0a, 0x0b, 0x4e, 0x61, 0x6d, + 0x65, 0x53, 0x65, 0x72, 0x76, 0x65, 0x72, 0x73, 0x18, 0x01, 0x20, 0x03, 0x28, 0x0b, 0x32, 0x16, + 0x2e, 0x6d, 0x61, 0x6e, 0x61, 0x67, 0x65, 0x6d, 0x65, 0x6e, 0x74, 0x2e, 0x4e, 0x61, 0x6d, 0x65, + 0x53, 0x65, 0x72, 0x76, 0x65, 0x72, 0x52, 0x0b, 0x4e, 0x61, 0x6d, 0x65, 0x53, 0x65, 0x72, 0x76, + 0x65, 0x72, 0x73, 0x12, 0x18, 0x0a, 0x07, 0x50, 0x72, 0x69, 0x6d, 0x61, 0x72, 0x79, 0x18, 0x02, + 0x20, 0x01, 0x28, 0x08, 0x52, 0x07, 0x50, 0x72, 0x69, 0x6d, 0x61, 0x72, 0x79, 0x12, 0x18, 0x0a, + 0x07, 0x44, 0x6f, 0x6d, 0x61, 0x69, 0x6e, 0x73, 0x18, 0x03, 0x20, 0x03, 0x28, 0x09, 0x52, 0x07, + 0x44, 0x6f, 0x6d, 0x61, 0x69, 0x6e, 0x73, 0x12, 0x32, 0x0a, 0x14, 0x53, 
0x65, 0x61, 0x72, 0x63, + 0x68, 0x44, 0x6f, 0x6d, 0x61, 0x69, 0x6e, 0x73, 0x45, 0x6e, 0x61, 0x62, 0x6c, 0x65, 0x64, 0x18, + 0x04, 0x20, 0x01, 0x28, 0x08, 0x52, 0x14, 0x53, 0x65, 0x61, 0x72, 0x63, 0x68, 0x44, 0x6f, 0x6d, + 0x61, 0x69, 0x6e, 0x73, 0x45, 0x6e, 0x61, 0x62, 0x6c, 0x65, 0x64, 0x22, 0x48, 0x0a, 0x0a, 0x4e, + 0x61, 0x6d, 0x65, 0x53, 0x65, 0x72, 0x76, 0x65, 0x72, 0x12, 0x0e, 0x0a, 0x02, 0x49, 0x50, 0x18, + 0x01, 0x20, 0x01, 0x28, 0x09, 0x52, 0x02, 0x49, 0x50, 0x12, 0x16, 0x0a, 0x06, 0x4e, 0x53, 0x54, + 0x79, 0x70, 0x65, 0x18, 0x02, 0x20, 0x01, 0x28, 0x03, 0x52, 0x06, 0x4e, 0x53, 0x54, 0x79, 0x70, + 0x65, 0x12, 0x12, 0x0a, 0x04, 0x50, 0x6f, 0x72, 0x74, 0x18, 0x03, 0x20, 0x01, 0x28, 0x03, 0x52, + 0x04, 0x50, 0x6f, 0x72, 0x74, 0x22, 0xd9, 0x01, 0x0a, 0x0c, 0x46, 0x69, 0x72, 0x65, 0x77, 0x61, + 0x6c, 0x6c, 0x52, 0x75, 0x6c, 0x65, 0x12, 0x16, 0x0a, 0x06, 0x50, 0x65, 0x65, 0x72, 0x49, 0x50, + 0x18, 0x01, 0x20, 0x01, 0x28, 0x09, 0x52, 0x06, 0x50, 0x65, 0x65, 0x72, 0x49, 0x50, 0x12, 0x37, + 0x0a, 0x09, 0x44, 0x69, 0x72, 0x65, 0x63, 0x74, 0x69, 0x6f, 0x6e, 0x18, 0x02, 0x20, 0x01, 0x28, + 0x0e, 0x32, 0x19, 0x2e, 0x6d, 0x61, 0x6e, 0x61, 0x67, 0x65, 0x6d, 0x65, 0x6e, 0x74, 0x2e, 0x52, + 0x75, 0x6c, 0x65, 0x44, 0x69, 0x72, 0x65, 0x63, 0x74, 0x69, 0x6f, 0x6e, 0x52, 0x09, 0x44, 0x69, + 0x72, 0x65, 0x63, 0x74, 0x69, 0x6f, 0x6e, 0x12, 0x2e, 0x0a, 0x06, 0x41, 0x63, 0x74, 0x69, 0x6f, + 0x6e, 0x18, 0x03, 0x20, 0x01, 0x28, 0x0e, 0x32, 0x16, 0x2e, 0x6d, 0x61, 0x6e, 0x61, 0x67, 0x65, + 0x6d, 0x65, 0x6e, 0x74, 0x2e, 0x52, 0x75, 0x6c, 0x65, 0x41, 0x63, 0x74, 0x69, 0x6f, 0x6e, 0x52, + 0x06, 0x41, 0x63, 0x74, 0x69, 0x6f, 0x6e, 0x12, 0x34, 0x0a, 0x08, 0x50, 0x72, 0x6f, 0x74, 0x6f, + 0x63, 0x6f, 0x6c, 0x18, 0x04, 0x20, 0x01, 0x28, 0x0e, 0x32, 0x18, 0x2e, 0x6d, 0x61, 0x6e, 0x61, + 0x67, 0x65, 0x6d, 0x65, 0x6e, 0x74, 0x2e, 0x52, 0x75, 0x6c, 0x65, 0x50, 0x72, 0x6f, 0x74, 0x6f, + 0x63, 0x6f, 0x6c, 0x52, 0x08, 0x50, 0x72, 0x6f, 0x74, 0x6f, 0x63, 0x6f, 0x6c, 0x12, 0x12, 0x0a, + 0x04, 0x50, 0x6f, 0x72, 0x74, 0x18, 0x05, 0x20, 0x01, 0x28, 0x09, 0x52, 0x04, 0x50, 0x6f, 0x72, + 0x74, 0x22, 0x38, 0x0a, 0x0e, 0x4e, 0x65, 0x74, 0x77, 0x6f, 0x72, 0x6b, 0x41, 0x64, 0x64, 0x72, + 0x65, 0x73, 0x73, 0x12, 0x14, 0x0a, 0x05, 0x6e, 0x65, 0x74, 0x49, 0x50, 0x18, 0x01, 0x20, 0x01, + 0x28, 0x09, 0x52, 0x05, 0x6e, 0x65, 0x74, 0x49, 0x50, 0x12, 0x10, 0x0a, 0x03, 0x6d, 0x61, 0x63, + 0x18, 0x02, 0x20, 0x01, 0x28, 0x09, 0x52, 0x03, 0x6d, 0x61, 0x63, 0x22, 0x1e, 0x0a, 0x06, 0x43, + 0x68, 0x65, 0x63, 0x6b, 0x73, 0x12, 0x14, 0x0a, 0x05, 0x46, 0x69, 0x6c, 0x65, 0x73, 0x18, 0x01, + 0x20, 0x03, 0x28, 0x09, 0x52, 0x05, 0x46, 0x69, 0x6c, 0x65, 0x73, 0x22, 0x96, 0x01, 0x0a, 0x08, + 0x50, 0x6f, 0x72, 0x74, 0x49, 0x6e, 0x66, 0x6f, 0x12, 0x14, 0x0a, 0x04, 0x70, 0x6f, 0x72, 0x74, + 0x18, 0x01, 0x20, 0x01, 0x28, 0x0d, 0x48, 0x00, 0x52, 0x04, 0x70, 0x6f, 0x72, 0x74, 0x12, 0x32, + 0x0a, 0x05, 0x72, 0x61, 0x6e, 0x67, 0x65, 0x18, 0x02, 0x20, 0x01, 0x28, 0x0b, 0x32, 0x1a, 0x2e, + 0x6d, 0x61, 0x6e, 0x61, 0x67, 0x65, 0x6d, 0x65, 0x6e, 0x74, 0x2e, 0x50, 0x6f, 0x72, 0x74, 0x49, + 0x6e, 0x66, 0x6f, 0x2e, 0x52, 0x61, 0x6e, 0x67, 0x65, 0x48, 0x00, 0x52, 0x05, 0x72, 0x61, 0x6e, + 0x67, 0x65, 0x1a, 0x2f, 0x0a, 0x05, 0x52, 0x61, 0x6e, 0x67, 0x65, 0x12, 0x14, 0x0a, 0x05, 0x73, + 0x74, 0x61, 0x72, 0x74, 0x18, 0x01, 0x20, 0x01, 0x28, 0x0d, 0x52, 0x05, 0x73, 0x74, 0x61, 0x72, + 0x74, 0x12, 0x10, 0x0a, 0x03, 0x65, 0x6e, 0x64, 0x18, 0x02, 0x20, 0x01, 0x28, 0x0d, 0x52, 0x03, + 0x65, 0x6e, 0x64, 0x42, 0x0f, 0x0a, 0x0d, 0x70, 0x6f, 0x72, 0x74, 0x53, 0x65, 0x6c, 0x65, 0x63, + 
0x74, 0x69, 0x6f, 0x6e, 0x22, 0xd1, 0x02, 0x0a, 0x11, 0x52, 0x6f, 0x75, 0x74, 0x65, 0x46, 0x69, + 0x72, 0x65, 0x77, 0x61, 0x6c, 0x6c, 0x52, 0x75, 0x6c, 0x65, 0x12, 0x22, 0x0a, 0x0c, 0x73, 0x6f, + 0x75, 0x72, 0x63, 0x65, 0x52, 0x61, 0x6e, 0x67, 0x65, 0x73, 0x18, 0x01, 0x20, 0x03, 0x28, 0x09, + 0x52, 0x0c, 0x73, 0x6f, 0x75, 0x72, 0x63, 0x65, 0x52, 0x61, 0x6e, 0x67, 0x65, 0x73, 0x12, 0x2e, + 0x0a, 0x06, 0x61, 0x63, 0x74, 0x69, 0x6f, 0x6e, 0x18, 0x02, 0x20, 0x01, 0x28, 0x0e, 0x32, 0x16, + 0x2e, 0x6d, 0x61, 0x6e, 0x61, 0x67, 0x65, 0x6d, 0x65, 0x6e, 0x74, 0x2e, 0x52, 0x75, 0x6c, 0x65, + 0x41, 0x63, 0x74, 0x69, 0x6f, 0x6e, 0x52, 0x06, 0x61, 0x63, 0x74, 0x69, 0x6f, 0x6e, 0x12, 0x20, + 0x0a, 0x0b, 0x64, 0x65, 0x73, 0x74, 0x69, 0x6e, 0x61, 0x74, 0x69, 0x6f, 0x6e, 0x18, 0x03, 0x20, + 0x01, 0x28, 0x09, 0x52, 0x0b, 0x64, 0x65, 0x73, 0x74, 0x69, 0x6e, 0x61, 0x74, 0x69, 0x6f, 0x6e, + 0x12, 0x34, 0x0a, 0x08, 0x70, 0x72, 0x6f, 0x74, 0x6f, 0x63, 0x6f, 0x6c, 0x18, 0x04, 0x20, 0x01, + 0x28, 0x0e, 0x32, 0x18, 0x2e, 0x6d, 0x61, 0x6e, 0x61, 0x67, 0x65, 0x6d, 0x65, 0x6e, 0x74, 0x2e, + 0x52, 0x75, 0x6c, 0x65, 0x50, 0x72, 0x6f, 0x74, 0x6f, 0x63, 0x6f, 0x6c, 0x52, 0x08, 0x70, 0x72, + 0x6f, 0x74, 0x6f, 0x63, 0x6f, 0x6c, 0x12, 0x30, 0x0a, 0x08, 0x70, 0x6f, 0x72, 0x74, 0x49, 0x6e, + 0x66, 0x6f, 0x18, 0x05, 0x20, 0x01, 0x28, 0x0b, 0x32, 0x14, 0x2e, 0x6d, 0x61, 0x6e, 0x61, 0x67, + 0x65, 0x6d, 0x65, 0x6e, 0x74, 0x2e, 0x50, 0x6f, 0x72, 0x74, 0x49, 0x6e, 0x66, 0x6f, 0x52, 0x08, + 0x70, 0x6f, 0x72, 0x74, 0x49, 0x6e, 0x66, 0x6f, 0x12, 0x1c, 0x0a, 0x09, 0x69, 0x73, 0x44, 0x79, + 0x6e, 0x61, 0x6d, 0x69, 0x63, 0x18, 0x06, 0x20, 0x01, 0x28, 0x08, 0x52, 0x09, 0x69, 0x73, 0x44, + 0x79, 0x6e, 0x61, 0x6d, 0x69, 0x63, 0x12, 0x18, 0x0a, 0x07, 0x64, 0x6f, 0x6d, 0x61, 0x69, 0x6e, + 0x73, 0x18, 0x07, 0x20, 0x03, 0x28, 0x09, 0x52, 0x07, 0x64, 0x6f, 0x6d, 0x61, 0x69, 0x6e, 0x73, + 0x12, 0x26, 0x0a, 0x0e, 0x63, 0x75, 0x73, 0x74, 0x6f, 0x6d, 0x50, 0x72, 0x6f, 0x74, 0x6f, 0x63, + 0x6f, 0x6c, 0x18, 0x08, 0x20, 0x01, 0x28, 0x0d, 0x52, 0x0e, 0x63, 0x75, 0x73, 0x74, 0x6f, 0x6d, + 0x50, 0x72, 0x6f, 0x74, 0x6f, 0x63, 0x6f, 0x6c, 0x2a, 0x4c, 0x0a, 0x0c, 0x52, 0x75, 0x6c, 0x65, + 0x50, 0x72, 0x6f, 0x74, 0x6f, 0x63, 0x6f, 0x6c, 0x12, 0x0b, 0x0a, 0x07, 0x55, 0x4e, 0x4b, 0x4e, + 0x4f, 0x57, 0x4e, 0x10, 0x00, 0x12, 0x07, 0x0a, 0x03, 0x41, 0x4c, 0x4c, 0x10, 0x01, 0x12, 0x07, + 0x0a, 0x03, 0x54, 0x43, 0x50, 0x10, 0x02, 0x12, 0x07, 0x0a, 0x03, 0x55, 0x44, 0x50, 0x10, 0x03, + 0x12, 0x08, 0x0a, 0x04, 0x49, 0x43, 0x4d, 0x50, 0x10, 0x04, 0x12, 0x0a, 0x0a, 0x06, 0x43, 0x55, + 0x53, 0x54, 0x4f, 0x4d, 0x10, 0x05, 0x2a, 0x20, 0x0a, 0x0d, 0x52, 0x75, 0x6c, 0x65, 0x44, 0x69, + 0x72, 0x65, 0x63, 0x74, 0x69, 0x6f, 0x6e, 0x12, 0x06, 0x0a, 0x02, 0x49, 0x4e, 0x10, 0x00, 0x12, + 0x07, 0x0a, 0x03, 0x4f, 0x55, 0x54, 0x10, 0x01, 0x2a, 0x22, 0x0a, 0x0a, 0x52, 0x75, 0x6c, 0x65, + 0x41, 0x63, 0x74, 0x69, 0x6f, 0x6e, 0x12, 0x0a, 0x0a, 0x06, 0x41, 0x43, 0x43, 0x45, 0x50, 0x54, + 0x10, 0x00, 0x12, 0x08, 0x0a, 0x04, 0x44, 0x52, 0x4f, 0x50, 0x10, 0x01, 0x32, 0x90, 0x04, 0x0a, + 0x11, 0x4d, 0x61, 0x6e, 0x61, 0x67, 0x65, 0x6d, 0x65, 0x6e, 0x74, 0x53, 0x65, 0x72, 0x76, 0x69, + 0x63, 0x65, 0x12, 0x45, 0x0a, 0x05, 0x4c, 0x6f, 0x67, 0x69, 0x6e, 0x12, 0x1c, 0x2e, 0x6d, 0x61, + 0x6e, 0x61, 0x67, 0x65, 0x6d, 0x65, 0x6e, 0x74, 0x2e, 0x45, 0x6e, 0x63, 0x72, 0x79, 0x70, 0x74, + 0x65, 0x64, 0x4d, 0x65, 0x73, 0x73, 0x61, 0x67, 0x65, 0x1a, 0x1c, 0x2e, 0x6d, 0x61, 0x6e, 0x61, 0x67, 0x65, 0x6d, 0x65, 0x6e, 0x74, 0x2e, 0x45, 0x6e, 0x63, 0x72, 0x79, 0x70, 0x74, 0x65, 0x64, - 0x4d, 0x65, 0x73, 0x73, 
0x61, 0x67, 0x65, 0x1a, 0x1c, 0x2e, 0x6d, 0x61, 0x6e, 0x61, 0x67, 0x65, - 0x6d, 0x65, 0x6e, 0x74, 0x2e, 0x45, 0x6e, 0x63, 0x72, 0x79, 0x70, 0x74, 0x65, 0x64, 0x4d, 0x65, - 0x73, 0x73, 0x61, 0x67, 0x65, 0x22, 0x00, 0x30, 0x01, 0x12, 0x42, 0x0a, 0x0c, 0x47, 0x65, 0x74, - 0x53, 0x65, 0x72, 0x76, 0x65, 0x72, 0x4b, 0x65, 0x79, 0x12, 0x11, 0x2e, 0x6d, 0x61, 0x6e, 0x61, - 0x67, 0x65, 0x6d, 0x65, 0x6e, 0x74, 0x2e, 0x45, 0x6d, 0x70, 0x74, 0x79, 0x1a, 0x1d, 0x2e, 0x6d, - 0x61, 0x6e, 0x61, 0x67, 0x65, 0x6d, 0x65, 0x6e, 0x74, 0x2e, 0x53, 0x65, 0x72, 0x76, 0x65, 0x72, - 0x4b, 0x65, 0x79, 0x52, 0x65, 0x73, 0x70, 0x6f, 0x6e, 0x73, 0x65, 0x22, 0x00, 0x12, 0x33, 0x0a, - 0x09, 0x69, 0x73, 0x48, 0x65, 0x61, 0x6c, 0x74, 0x68, 0x79, 0x12, 0x11, 0x2e, 0x6d, 0x61, 0x6e, - 0x61, 0x67, 0x65, 0x6d, 0x65, 0x6e, 0x74, 0x2e, 0x45, 0x6d, 0x70, 0x74, 0x79, 0x1a, 0x11, 0x2e, - 0x6d, 0x61, 0x6e, 0x61, 0x67, 0x65, 0x6d, 0x65, 0x6e, 0x74, 0x2e, 0x45, 0x6d, 0x70, 0x74, 0x79, - 0x22, 0x00, 0x12, 0x5a, 0x0a, 0x1a, 0x47, 0x65, 0x74, 0x44, 0x65, 0x76, 0x69, 0x63, 0x65, 0x41, - 0x75, 0x74, 0x68, 0x6f, 0x72, 0x69, 0x7a, 0x61, 0x74, 0x69, 0x6f, 0x6e, 0x46, 0x6c, 0x6f, 0x77, - 0x12, 0x1c, 0x2e, 0x6d, 0x61, 0x6e, 0x61, 0x67, 0x65, 0x6d, 0x65, 0x6e, 0x74, 0x2e, 0x45, 0x6e, - 0x63, 0x72, 0x79, 0x70, 0x74, 0x65, 0x64, 0x4d, 0x65, 0x73, 0x73, 0x61, 0x67, 0x65, 0x1a, 0x1c, - 0x2e, 0x6d, 0x61, 0x6e, 0x61, 0x67, 0x65, 0x6d, 0x65, 0x6e, 0x74, 0x2e, 0x45, 0x6e, 0x63, 0x72, - 0x79, 0x70, 0x74, 0x65, 0x64, 0x4d, 0x65, 0x73, 0x73, 0x61, 0x67, 0x65, 0x22, 0x00, 0x12, 0x58, - 0x0a, 0x18, 0x47, 0x65, 0x74, 0x50, 0x4b, 0x43, 0x45, 0x41, 0x75, 0x74, 0x68, 0x6f, 0x72, 0x69, - 0x7a, 0x61, 0x74, 0x69, 0x6f, 0x6e, 0x46, 0x6c, 0x6f, 0x77, 0x12, 0x1c, 0x2e, 0x6d, 0x61, 0x6e, - 0x61, 0x67, 0x65, 0x6d, 0x65, 0x6e, 0x74, 0x2e, 0x45, 0x6e, 0x63, 0x72, 0x79, 0x70, 0x74, 0x65, - 0x64, 0x4d, 0x65, 0x73, 0x73, 0x61, 0x67, 0x65, 0x1a, 0x1c, 0x2e, 0x6d, 0x61, 0x6e, 0x61, 0x67, + 0x4d, 0x65, 0x73, 0x73, 0x61, 0x67, 0x65, 0x22, 0x00, 0x12, 0x46, 0x0a, 0x04, 0x53, 0x79, 0x6e, + 0x63, 0x12, 0x1c, 0x2e, 0x6d, 0x61, 0x6e, 0x61, 0x67, 0x65, 0x6d, 0x65, 0x6e, 0x74, 0x2e, 0x45, + 0x6e, 0x63, 0x72, 0x79, 0x70, 0x74, 0x65, 0x64, 0x4d, 0x65, 0x73, 0x73, 0x61, 0x67, 0x65, 0x1a, + 0x1c, 0x2e, 0x6d, 0x61, 0x6e, 0x61, 0x67, 0x65, 0x6d, 0x65, 0x6e, 0x74, 0x2e, 0x45, 0x6e, 0x63, + 0x72, 0x79, 0x70, 0x74, 0x65, 0x64, 0x4d, 0x65, 0x73, 0x73, 0x61, 0x67, 0x65, 0x22, 0x00, 0x30, + 0x01, 0x12, 0x42, 0x0a, 0x0c, 0x47, 0x65, 0x74, 0x53, 0x65, 0x72, 0x76, 0x65, 0x72, 0x4b, 0x65, + 0x79, 0x12, 0x11, 0x2e, 0x6d, 0x61, 0x6e, 0x61, 0x67, 0x65, 0x6d, 0x65, 0x6e, 0x74, 0x2e, 0x45, + 0x6d, 0x70, 0x74, 0x79, 0x1a, 0x1d, 0x2e, 0x6d, 0x61, 0x6e, 0x61, 0x67, 0x65, 0x6d, 0x65, 0x6e, + 0x74, 0x2e, 0x53, 0x65, 0x72, 0x76, 0x65, 0x72, 0x4b, 0x65, 0x79, 0x52, 0x65, 0x73, 0x70, 0x6f, + 0x6e, 0x73, 0x65, 0x22, 0x00, 0x12, 0x33, 0x0a, 0x09, 0x69, 0x73, 0x48, 0x65, 0x61, 0x6c, 0x74, + 0x68, 0x79, 0x12, 0x11, 0x2e, 0x6d, 0x61, 0x6e, 0x61, 0x67, 0x65, 0x6d, 0x65, 0x6e, 0x74, 0x2e, + 0x45, 0x6d, 0x70, 0x74, 0x79, 0x1a, 0x11, 0x2e, 0x6d, 0x61, 0x6e, 0x61, 0x67, 0x65, 0x6d, 0x65, + 0x6e, 0x74, 0x2e, 0x45, 0x6d, 0x70, 0x74, 0x79, 0x22, 0x00, 0x12, 0x5a, 0x0a, 0x1a, 0x47, 0x65, + 0x74, 0x44, 0x65, 0x76, 0x69, 0x63, 0x65, 0x41, 0x75, 0x74, 0x68, 0x6f, 0x72, 0x69, 0x7a, 0x61, + 0x74, 0x69, 0x6f, 0x6e, 0x46, 0x6c, 0x6f, 0x77, 0x12, 0x1c, 0x2e, 0x6d, 0x61, 0x6e, 0x61, 0x67, 0x65, 0x6d, 0x65, 0x6e, 0x74, 0x2e, 0x45, 0x6e, 0x63, 0x72, 0x79, 0x70, 0x74, 0x65, 0x64, 0x4d, - 0x65, 0x73, 0x73, 0x61, 0x67, 0x65, 0x22, 0x00, 
0x12, 0x3d, 0x0a, 0x08, 0x53, 0x79, 0x6e, 0x63, - 0x4d, 0x65, 0x74, 0x61, 0x12, 0x1c, 0x2e, 0x6d, 0x61, 0x6e, 0x61, 0x67, 0x65, 0x6d, 0x65, 0x6e, - 0x74, 0x2e, 0x45, 0x6e, 0x63, 0x72, 0x79, 0x70, 0x74, 0x65, 0x64, 0x4d, 0x65, 0x73, 0x73, 0x61, - 0x67, 0x65, 0x1a, 0x11, 0x2e, 0x6d, 0x61, 0x6e, 0x61, 0x67, 0x65, 0x6d, 0x65, 0x6e, 0x74, 0x2e, - 0x45, 0x6d, 0x70, 0x74, 0x79, 0x22, 0x00, 0x42, 0x08, 0x5a, 0x06, 0x2f, 0x70, 0x72, 0x6f, 0x74, - 0x6f, 0x62, 0x06, 0x70, 0x72, 0x6f, 0x74, 0x6f, 0x33, + 0x65, 0x73, 0x73, 0x61, 0x67, 0x65, 0x1a, 0x1c, 0x2e, 0x6d, 0x61, 0x6e, 0x61, 0x67, 0x65, 0x6d, + 0x65, 0x6e, 0x74, 0x2e, 0x45, 0x6e, 0x63, 0x72, 0x79, 0x70, 0x74, 0x65, 0x64, 0x4d, 0x65, 0x73, + 0x73, 0x61, 0x67, 0x65, 0x22, 0x00, 0x12, 0x58, 0x0a, 0x18, 0x47, 0x65, 0x74, 0x50, 0x4b, 0x43, + 0x45, 0x41, 0x75, 0x74, 0x68, 0x6f, 0x72, 0x69, 0x7a, 0x61, 0x74, 0x69, 0x6f, 0x6e, 0x46, 0x6c, + 0x6f, 0x77, 0x12, 0x1c, 0x2e, 0x6d, 0x61, 0x6e, 0x61, 0x67, 0x65, 0x6d, 0x65, 0x6e, 0x74, 0x2e, + 0x45, 0x6e, 0x63, 0x72, 0x79, 0x70, 0x74, 0x65, 0x64, 0x4d, 0x65, 0x73, 0x73, 0x61, 0x67, 0x65, + 0x1a, 0x1c, 0x2e, 0x6d, 0x61, 0x6e, 0x61, 0x67, 0x65, 0x6d, 0x65, 0x6e, 0x74, 0x2e, 0x45, 0x6e, + 0x63, 0x72, 0x79, 0x70, 0x74, 0x65, 0x64, 0x4d, 0x65, 0x73, 0x73, 0x61, 0x67, 0x65, 0x22, 0x00, + 0x12, 0x3d, 0x0a, 0x08, 0x53, 0x79, 0x6e, 0x63, 0x4d, 0x65, 0x74, 0x61, 0x12, 0x1c, 0x2e, 0x6d, + 0x61, 0x6e, 0x61, 0x67, 0x65, 0x6d, 0x65, 0x6e, 0x74, 0x2e, 0x45, 0x6e, 0x63, 0x72, 0x79, 0x70, + 0x74, 0x65, 0x64, 0x4d, 0x65, 0x73, 0x73, 0x61, 0x67, 0x65, 0x1a, 0x11, 0x2e, 0x6d, 0x61, 0x6e, + 0x61, 0x67, 0x65, 0x6d, 0x65, 0x6e, 0x74, 0x2e, 0x45, 0x6d, 0x70, 0x74, 0x79, 0x22, 0x00, 0x42, + 0x08, 0x5a, 0x06, 0x2f, 0x70, 0x72, 0x6f, 0x74, 0x6f, 0x62, 0x06, 0x70, 0x72, 0x6f, 0x74, 0x6f, + 0x33, } var ( diff --git a/management/proto/management.proto b/management/proto/management.proto index fe6a828b1..5f4e0df46 100644 --- a/management/proto/management.proto +++ b/management/proto/management.proto @@ -222,6 +222,8 @@ message PeerConfig { SSHConfig sshConfig = 3; // Peer fully qualified domain name string fqdn = 4; + + bool RoutingPeerDnsResolutionEnabled = 5; } // NetworkMap represents a network state of the peer with the corresponding configuration parameters to establish peer-to-peer connections @@ -396,6 +398,7 @@ enum RuleProtocol { TCP = 2; UDP = 3; ICMP = 4; + CUSTOM = 5; } enum RuleDirection { @@ -459,5 +462,11 @@ message RouteFirewallRule { // IsDynamic indicates if the route is a DNS route. bool isDynamic = 6; + + // Domains is a list of domains for which the rule is applicable. + repeated string domains = 7; + + // CustomProtocol is a custom protocol ID. 
+ uint32 customProtocol = 8; } diff --git a/management/server/account.go b/management/server/account.go index fbe6fcc1a..e60b41b4e 100644 --- a/management/server/account.go +++ b/management/server/account.go @@ -19,8 +19,6 @@ import ( "github.com/eko/gocache/v3/cache" cacheStore "github.com/eko/gocache/v3/store" - "github.com/hashicorp/go-multierror" - "github.com/miekg/dns" gocache "github.com/patrickmn/go-cache" "github.com/rs/xid" log "github.com/sirupsen/logrus" @@ -29,31 +27,26 @@ import ( "github.com/netbirdio/netbird/base62" nbdns "github.com/netbirdio/netbird/dns" "github.com/netbirdio/netbird/management/domain" - "github.com/netbirdio/netbird/management/server/account" "github.com/netbirdio/netbird/management/server/activity" "github.com/netbirdio/netbird/management/server/geolocation" - nbgroup "github.com/netbirdio/netbird/management/server/group" "github.com/netbirdio/netbird/management/server/idp" "github.com/netbirdio/netbird/management/server/integrated_validator" - "github.com/netbirdio/netbird/management/server/integration_reference" "github.com/netbirdio/netbird/management/server/jwtclaims" nbpeer "github.com/netbirdio/netbird/management/server/peer" "github.com/netbirdio/netbird/management/server/posture" "github.com/netbirdio/netbird/management/server/status" + "github.com/netbirdio/netbird/management/server/store" "github.com/netbirdio/netbird/management/server/telemetry" + "github.com/netbirdio/netbird/management/server/types" + "github.com/netbirdio/netbird/management/server/util" "github.com/netbirdio/netbird/route" ) const ( - PublicCategory = "public" - PrivateCategory = "private" - UnknownCategory = "unknown" - CacheExpirationMax = 7 * 24 * 3600 * time.Second // 7 days - CacheExpirationMin = 3 * 24 * 3600 * time.Second // 3 days - DefaultPeerLoginExpiration = 24 * time.Hour - DefaultPeerInactivityExpiration = 10 * time.Minute - emptyUserID = "empty user ID in claims" - errorGettingDomainAccIDFmt = "error getting account ID by private domain: %v" + CacheExpirationMax = 7 * 24 * 3600 * time.Second // 7 days + CacheExpirationMin = 3 * 24 * 3600 * time.Second // 3 days + emptyUserID = "empty user ID in claims" + errorGettingDomainAccIDFmt = "error getting account ID by private domain: %v" ) type userLoggedInOnce bool @@ -66,56 +59,56 @@ func cacheEntryExpiration() time.Duration { } type AccountManager interface { - GetOrCreateAccountByUser(ctx context.Context, userId, domain string) (*Account, error) - GetAccount(ctx context.Context, accountID string) (*Account, error) - CreateSetupKey(ctx context.Context, accountID string, keyName string, keyType SetupKeyType, expiresIn time.Duration, - autoGroups []string, usageLimit int, userID string, ephemeral bool) (*SetupKey, error) - SaveSetupKey(ctx context.Context, accountID string, key *SetupKey, userID string) (*SetupKey, error) - CreateUser(ctx context.Context, accountID, initiatorUserID string, key *UserInfo) (*UserInfo, error) + GetOrCreateAccountByUser(ctx context.Context, userId, domain string) (*types.Account, error) + GetAccount(ctx context.Context, accountID string) (*types.Account, error) + CreateSetupKey(ctx context.Context, accountID string, keyName string, keyType types.SetupKeyType, expiresIn time.Duration, + autoGroups []string, usageLimit int, userID string, ephemeral bool) (*types.SetupKey, error) + SaveSetupKey(ctx context.Context, accountID string, key *types.SetupKey, userID string) (*types.SetupKey, error) + CreateUser(ctx context.Context, accountID, initiatorUserID string, key *types.UserInfo) 
(*types.UserInfo, error) DeleteUser(ctx context.Context, accountID, initiatorUserID string, targetUserID string) error DeleteRegularUsers(ctx context.Context, accountID, initiatorUserID string, targetUserIDs []string) error InviteUser(ctx context.Context, accountID string, initiatorUserID string, targetUserID string) error - ListSetupKeys(ctx context.Context, accountID, userID string) ([]*SetupKey, error) - SaveUser(ctx context.Context, accountID, initiatorUserID string, update *User) (*UserInfo, error) - SaveOrAddUser(ctx context.Context, accountID, initiatorUserID string, update *User, addIfNotExists bool) (*UserInfo, error) - SaveOrAddUsers(ctx context.Context, accountID, initiatorUserID string, updates []*User, addIfNotExists bool) ([]*UserInfo, error) - GetSetupKey(ctx context.Context, accountID, userID, keyID string) (*SetupKey, error) - GetAccountByID(ctx context.Context, accountID string, userID string) (*Account, error) + ListSetupKeys(ctx context.Context, accountID, userID string) ([]*types.SetupKey, error) + SaveUser(ctx context.Context, accountID, initiatorUserID string, update *types.User) (*types.UserInfo, error) + SaveOrAddUser(ctx context.Context, accountID, initiatorUserID string, update *types.User, addIfNotExists bool) (*types.UserInfo, error) + SaveOrAddUsers(ctx context.Context, accountID, initiatorUserID string, updates []*types.User, addIfNotExists bool) ([]*types.UserInfo, error) + GetSetupKey(ctx context.Context, accountID, userID, keyID string) (*types.SetupKey, error) + GetAccountByID(ctx context.Context, accountID string, userID string) (*types.Account, error) AccountExists(ctx context.Context, accountID string) (bool, error) GetAccountIDByUserID(ctx context.Context, userID, domain string) (string, error) GetAccountIDFromToken(ctx context.Context, claims jwtclaims.AuthorizationClaims) (string, string, error) CheckUserAccessByJWTGroups(ctx context.Context, claims jwtclaims.AuthorizationClaims) error - GetAccountFromPAT(ctx context.Context, pat string) (*Account, *User, *PersonalAccessToken, error) + GetAccountFromPAT(ctx context.Context, pat string) (*types.Account, *types.User, *types.PersonalAccessToken, error) DeleteAccount(ctx context.Context, accountID, userID string) error MarkPATUsed(ctx context.Context, tokenID string) error - GetUserByID(ctx context.Context, id string) (*User, error) - GetUser(ctx context.Context, claims jwtclaims.AuthorizationClaims) (*User, error) - ListUsers(ctx context.Context, accountID string) ([]*User, error) + GetUserByID(ctx context.Context, id string) (*types.User, error) + GetUser(ctx context.Context, claims jwtclaims.AuthorizationClaims) (*types.User, error) + ListUsers(ctx context.Context, accountID string) ([]*types.User, error) GetPeers(ctx context.Context, accountID, userID string) ([]*nbpeer.Peer, error) - MarkPeerConnected(ctx context.Context, peerKey string, connected bool, realIP net.IP, account *Account) error + MarkPeerConnected(ctx context.Context, peerKey string, connected bool, realIP net.IP, account *types.Account) error DeletePeer(ctx context.Context, accountID, peerID, userID string) error UpdatePeer(ctx context.Context, accountID, userID string, peer *nbpeer.Peer) (*nbpeer.Peer, error) - GetNetworkMap(ctx context.Context, peerID string) (*NetworkMap, error) - GetPeerNetwork(ctx context.Context, peerID string) (*Network, error) - AddPeer(ctx context.Context, setupKey, userID string, peer *nbpeer.Peer) (*nbpeer.Peer, *NetworkMap, []*posture.Checks, error) - CreatePAT(ctx context.Context, accountID string, 
initiatorUserID string, targetUserID string, tokenName string, expiresIn int) (*PersonalAccessTokenGenerated, error) + GetNetworkMap(ctx context.Context, peerID string) (*types.NetworkMap, error) + GetPeerNetwork(ctx context.Context, peerID string) (*types.Network, error) + AddPeer(ctx context.Context, setupKey, userID string, peer *nbpeer.Peer) (*nbpeer.Peer, *types.NetworkMap, []*posture.Checks, error) + CreatePAT(ctx context.Context, accountID string, initiatorUserID string, targetUserID string, tokenName string, expiresIn int) (*types.PersonalAccessTokenGenerated, error) DeletePAT(ctx context.Context, accountID string, initiatorUserID string, targetUserID string, tokenID string) error - GetPAT(ctx context.Context, accountID string, initiatorUserID string, targetUserID string, tokenID string) (*PersonalAccessToken, error) - GetAllPATs(ctx context.Context, accountID string, initiatorUserID string, targetUserID string) ([]*PersonalAccessToken, error) - GetUsersFromAccount(ctx context.Context, accountID, userID string) ([]*UserInfo, error) - GetGroup(ctx context.Context, accountId, groupID, userID string) (*nbgroup.Group, error) - GetAllGroups(ctx context.Context, accountID, userID string) ([]*nbgroup.Group, error) - GetGroupByName(ctx context.Context, groupName, accountID string) (*nbgroup.Group, error) - SaveGroup(ctx context.Context, accountID, userID string, group *nbgroup.Group) error - SaveGroups(ctx context.Context, accountID, userID string, newGroups []*nbgroup.Group) error + GetPAT(ctx context.Context, accountID string, initiatorUserID string, targetUserID string, tokenID string) (*types.PersonalAccessToken, error) + GetAllPATs(ctx context.Context, accountID string, initiatorUserID string, targetUserID string) ([]*types.PersonalAccessToken, error) + GetUsersFromAccount(ctx context.Context, accountID, userID string) ([]*types.UserInfo, error) + GetGroup(ctx context.Context, accountId, groupID, userID string) (*types.Group, error) + GetAllGroups(ctx context.Context, accountID, userID string) ([]*types.Group, error) + GetGroupByName(ctx context.Context, groupName, accountID string) (*types.Group, error) + SaveGroup(ctx context.Context, accountID, userID string, group *types.Group) error + SaveGroups(ctx context.Context, accountID, userID string, newGroups []*types.Group) error DeleteGroup(ctx context.Context, accountId, userId, groupID string) error DeleteGroups(ctx context.Context, accountId, userId string, groupIDs []string) error GroupAddPeer(ctx context.Context, accountId, groupID, peerID string) error GroupDeletePeer(ctx context.Context, accountId, groupID, peerID string) error - GetPolicy(ctx context.Context, accountID, policyID, userID string) (*Policy, error) - SavePolicy(ctx context.Context, accountID, userID string, policy *Policy) (*Policy, error) + GetPolicy(ctx context.Context, accountID, policyID, userID string) (*types.Policy, error) + SavePolicy(ctx context.Context, accountID, userID string, policy *types.Policy) (*types.Policy, error) DeletePolicy(ctx context.Context, accountID, policyID, userID string) error - ListPolicies(ctx context.Context, accountID, userID string) ([]*Policy, error) + ListPolicies(ctx context.Context, accountID, userID string) ([]*types.Policy, error) GetRoute(ctx context.Context, accountID string, routeID route.ID, userID string) (*route.Route, error) CreateRoute(ctx context.Context, accountID string, prefix netip.Prefix, networkType route.NetworkType, domains domain.List, peerID string, peerGroupIDs []string, description string, netID 
route.NetID, masquerade bool, metric int, groups, accessControlGroupIDs []string, enabled bool, userID string, keepRoute bool) (*route.Route, error) SaveRoute(ctx context.Context, accountID, userID string, route *route.Route) error @@ -129,12 +122,12 @@ type AccountManager interface { GetDNSDomain() string StoreEvent(ctx context.Context, initiatorID, targetID, accountID string, activityID activity.ActivityDescriber, meta map[string]any) GetEvents(ctx context.Context, accountID, userID string) ([]*activity.Event, error) - GetDNSSettings(ctx context.Context, accountID string, userID string) (*DNSSettings, error) - SaveDNSSettings(ctx context.Context, accountID string, userID string, dnsSettingsToSave *DNSSettings) error + GetDNSSettings(ctx context.Context, accountID string, userID string) (*types.DNSSettings, error) + SaveDNSSettings(ctx context.Context, accountID string, userID string, dnsSettingsToSave *types.DNSSettings) error GetPeer(ctx context.Context, accountID, peerID, userID string) (*nbpeer.Peer, error) - UpdateAccountSettings(ctx context.Context, accountID, userID string, newSettings *Settings) (*Account, error) - LoginPeer(ctx context.Context, login PeerLogin) (*nbpeer.Peer, *NetworkMap, []*posture.Checks, error) // used by peer gRPC API - SyncPeer(ctx context.Context, sync PeerSync, account *Account) (*nbpeer.Peer, *NetworkMap, []*posture.Checks, error) // used by peer gRPC API + UpdateAccountSettings(ctx context.Context, accountID, userID string, newSettings *types.Settings) (*types.Account, error) + LoginPeer(ctx context.Context, login PeerLogin) (*nbpeer.Peer, *types.NetworkMap, []*posture.Checks, error) // used by peer gRPC API + SyncPeer(ctx context.Context, sync PeerSync, account *types.Account) (*nbpeer.Peer, *types.NetworkMap, []*posture.Checks, error) // used by peer gRPC API GetAllConnectedPeers() (map[string]struct{}, error) HasConnectedChannel(peerID string) bool GetExternalCacheManager() ExternalCacheManager @@ -145,18 +138,19 @@ type AccountManager interface { GetIdpManager() idp.Manager UpdateIntegratedValidatorGroups(ctx context.Context, accountID string, userID string, groups []string) error GroupValidation(ctx context.Context, accountId string, groups []string) (bool, error) - GetValidatedPeers(account *Account) (map[string]struct{}, error) - SyncAndMarkPeer(ctx context.Context, accountID string, peerPubKey string, meta nbpeer.PeerSystemMeta, realIP net.IP) (*nbpeer.Peer, *NetworkMap, []*posture.Checks, error) + GetValidatedPeers(account *types.Account) (map[string]struct{}, error) + SyncAndMarkPeer(ctx context.Context, accountID string, peerPubKey string, meta nbpeer.PeerSystemMeta, realIP net.IP) (*nbpeer.Peer, *types.NetworkMap, []*posture.Checks, error) OnPeerDisconnected(ctx context.Context, accountID string, peerPubKey string) error SyncPeerMeta(ctx context.Context, peerPubKey string, meta nbpeer.PeerSystemMeta) error FindExistingPostureCheck(accountID string, checks *posture.ChecksDefinition) (*posture.Checks, error) GetAccountIDForPeerKey(ctx context.Context, peerKey string) (string, error) - GetAccountSettings(ctx context.Context, accountID string, userID string) (*Settings, error) + GetAccountSettings(ctx context.Context, accountID string, userID string) (*types.Settings, error) DeleteSetupKey(ctx context.Context, accountID, userID, keyID string) error + UpdateAccountPeers(ctx context.Context, accountID string) } type DefaultAccountManager struct { - Store Store + Store store.Store // cacheMux and cacheLoading helps to make sure that only a single 
cache reload runs at a time per accountID cacheMux sync.Mutex // cacheLoading keeps the accountIDs that are currently reloading. The accountID has to be removed once cache has been reloaded @@ -191,763 +185,40 @@ type DefaultAccountManager struct { metrics telemetry.AppMetrics } -// Settings represents Account settings structure that can be modified via API and Dashboard -type Settings struct { - // PeerLoginExpirationEnabled globally enables or disables peer login expiration - PeerLoginExpirationEnabled bool - - // PeerLoginExpiration is a setting that indicates when peer login expires. - // Applies to all peers that have Peer.LoginExpirationEnabled set to true. - PeerLoginExpiration time.Duration - - // PeerInactivityExpirationEnabled globally enables or disables peer inactivity expiration - PeerInactivityExpirationEnabled bool - - // PeerInactivityExpiration is a setting that indicates when peer inactivity expires. - // Applies to all peers that have Peer.PeerInactivityExpirationEnabled set to true. - PeerInactivityExpiration time.Duration - - // RegularUsersViewBlocked allows to block regular users from viewing even their own peers and some UI elements - RegularUsersViewBlocked bool - - // GroupsPropagationEnabled allows to propagate auto groups from the user to the peer - GroupsPropagationEnabled bool - - // JWTGroupsEnabled allows extract groups from JWT claim, which name defined in the JWTGroupsClaimName - // and add it to account groups. - JWTGroupsEnabled bool - - // JWTGroupsClaimName from which we extract groups name to add it to account groups - JWTGroupsClaimName string - - // JWTAllowGroups list of groups to which users are allowed access - JWTAllowGroups []string `gorm:"serializer:json"` - - // Extra is a dictionary of Account settings - Extra *account.ExtraSettings `gorm:"embedded;embeddedPrefix:extra_"` -} - -// Copy copies the Settings struct -func (s *Settings) Copy() *Settings { - settings := &Settings{ - PeerLoginExpirationEnabled: s.PeerLoginExpirationEnabled, - PeerLoginExpiration: s.PeerLoginExpiration, - JWTGroupsEnabled: s.JWTGroupsEnabled, - JWTGroupsClaimName: s.JWTGroupsClaimName, - GroupsPropagationEnabled: s.GroupsPropagationEnabled, - JWTAllowGroups: s.JWTAllowGroups, - RegularUsersViewBlocked: s.RegularUsersViewBlocked, - - PeerInactivityExpirationEnabled: s.PeerInactivityExpirationEnabled, - PeerInactivityExpiration: s.PeerInactivityExpiration, - } - if s.Extra != nil { - settings.Extra = s.Extra.Copy() - } - return settings -} - -// Account represents a unique account of the system -type Account struct { - // we have to name column to aid as it collides with Network.Id when work with associations - Id string `gorm:"primaryKey"` - - // User.Id it was created by - CreatedBy string - CreatedAt time.Time - Domain string `gorm:"index"` - DomainCategory string - IsDomainPrimaryAccount bool - SetupKeys map[string]*SetupKey `gorm:"-"` - SetupKeysG []SetupKey `json:"-" gorm:"foreignKey:AccountID;references:id"` - Network *Network `gorm:"embedded;embeddedPrefix:network_"` - Peers map[string]*nbpeer.Peer `gorm:"-"` - PeersG []nbpeer.Peer `json:"-" gorm:"foreignKey:AccountID;references:id"` - Users map[string]*User `gorm:"-"` - UsersG []User `json:"-" gorm:"foreignKey:AccountID;references:id"` - Groups map[string]*nbgroup.Group `gorm:"-"` - GroupsG []nbgroup.Group `json:"-" gorm:"foreignKey:AccountID;references:id"` - Policies []*Policy `gorm:"foreignKey:AccountID;references:id"` - Routes map[route.ID]*route.Route `gorm:"-"` - RoutesG []route.Route `json:"-" 
gorm:"foreignKey:AccountID;references:id"` - NameServerGroups map[string]*nbdns.NameServerGroup `gorm:"-"` - NameServerGroupsG []nbdns.NameServerGroup `json:"-" gorm:"foreignKey:AccountID;references:id"` - DNSSettings DNSSettings `gorm:"embedded;embeddedPrefix:dns_settings_"` - PostureChecks []*posture.Checks `gorm:"foreignKey:AccountID;references:id"` - // Settings is a dictionary of Account settings - Settings *Settings `gorm:"embedded;embeddedPrefix:settings_"` -} - -// Subclass used in gorm to only load settings and not whole account -type AccountSettings struct { - Settings *Settings `gorm:"embedded;embeddedPrefix:settings_"` -} - -// Subclass used in gorm to only load network and not whole account -type AccountNetwork struct { - Network *Network `gorm:"embedded;embeddedPrefix:network_"` -} - -// AccountDNSSettings used in gorm to only load dns settings and not whole account -type AccountDNSSettings struct { - DNSSettings DNSSettings `gorm:"embedded;embeddedPrefix:dns_settings_"` -} - -type UserPermissions struct { - DashboardView string `json:"dashboard_view"` -} - -type UserInfo struct { - ID string `json:"id"` - Email string `json:"email"` - Name string `json:"name"` - Role string `json:"role"` - AutoGroups []string `json:"auto_groups"` - Status string `json:"-"` - IsServiceUser bool `json:"is_service_user"` - IsBlocked bool `json:"is_blocked"` - NonDeletable bool `json:"non_deletable"` - LastLogin time.Time `json:"last_login"` - Issued string `json:"issued"` - IntegrationReference integration_reference.IntegrationReference `json:"-"` - Permissions UserPermissions `json:"permissions"` -} - -// getRoutesToSync returns the enabled routes for the peer ID and the routes -// from the ACL peers that have distribution groups associated with the peer ID. -// Please mind, that the returned route.Route objects will contain Peer.Key instead of Peer.ID. -func (a *Account) getRoutesToSync(ctx context.Context, peerID string, aclPeers []*nbpeer.Peer) []*route.Route { - routes, peerDisabledRoutes := a.getRoutingPeerRoutes(ctx, peerID) - peerRoutesMembership := make(lookupMap) - for _, r := range append(routes, peerDisabledRoutes...) { - peerRoutesMembership[string(r.GetHAUniqueID())] = struct{}{} - } - - groupListMap := a.getPeerGroups(peerID) - for _, peer := range aclPeers { - activeRoutes, _ := a.getRoutingPeerRoutes(ctx, peer.ID) - groupFilteredRoutes := a.filterRoutesByGroups(activeRoutes, groupListMap) - filteredRoutes := a.filterRoutesFromPeersOfSameHAGroup(groupFilteredRoutes, peerRoutesMembership) - routes = append(routes, filteredRoutes...) 
- } - - return routes -} - -// filterRoutesFromPeersOfSameHAGroup filters and returns a list of routes that don't share the same HA route membership -func (a *Account) filterRoutesFromPeersOfSameHAGroup(routes []*route.Route, peerMemberships lookupMap) []*route.Route { - var filteredRoutes []*route.Route - for _, r := range routes { - _, found := peerMemberships[string(r.GetHAUniqueID())] - if !found { - filteredRoutes = append(filteredRoutes, r) - } - } - return filteredRoutes -} - -// filterRoutesByGroups returns a list with routes that have distribution groups in the group's map -func (a *Account) filterRoutesByGroups(routes []*route.Route, groupListMap lookupMap) []*route.Route { - var filteredRoutes []*route.Route - for _, r := range routes { - for _, groupID := range r.Groups { - _, found := groupListMap[groupID] - if found { - filteredRoutes = append(filteredRoutes, r) - break - } - } - } - return filteredRoutes -} - -// getRoutingPeerRoutes returns the enabled and disabled lists of routes that the given routing peer serves -// Please mind, that the returned route.Route objects will contain Peer.Key instead of Peer.ID. -// If the given is not a routing peer, then the lists are empty. -func (a *Account) getRoutingPeerRoutes(ctx context.Context, peerID string) (enabledRoutes []*route.Route, disabledRoutes []*route.Route) { - - peer := a.GetPeer(peerID) - if peer == nil { - log.WithContext(ctx).Errorf("peer %s that doesn't exist under account %s", peerID, a.Id) - return enabledRoutes, disabledRoutes - } - - // currently we support only linux routing peers - if peer.Meta.GoOS != "linux" { - return enabledRoutes, disabledRoutes - } - - seenRoute := make(map[route.ID]struct{}) - - takeRoute := func(r *route.Route, id string) { - if _, ok := seenRoute[r.ID]; ok { - return - } - seenRoute[r.ID] = struct{}{} - - if r.Enabled { - r.Peer = peer.Key - enabledRoutes = append(enabledRoutes, r) - return - } - disabledRoutes = append(disabledRoutes, r) - } - - for _, r := range a.Routes { - for _, groupID := range r.PeerGroups { - group := a.GetGroup(groupID) - if group == nil { - log.WithContext(ctx).Errorf("route %s has peers group %s that doesn't exist under account %s", r.ID, groupID, a.Id) - continue - } - for _, id := range group.Peers { - if id != peerID { - continue - } - - newPeerRoute := r.Copy() - newPeerRoute.Peer = id - newPeerRoute.PeerGroups = nil - newPeerRoute.ID = route.ID(string(r.ID) + ":" + id) // we have to provide unique route id when distribute network map - takeRoute(newPeerRoute, id) - break - } - } - if r.Peer == peerID { - takeRoute(r.Copy(), peerID) - } - } - - return enabledRoutes, disabledRoutes -} - -// GetRoutesByPrefixOrDomains return list of routes by account and route prefix -func (a *Account) GetRoutesByPrefixOrDomains(prefix netip.Prefix, domains domain.List) []*route.Route { - var routes []*route.Route - for _, r := range a.Routes { - dynamic := r.IsDynamic() - if dynamic && r.Domains.PunycodeString() == domains.PunycodeString() || - !dynamic && r.Network.String() == prefix.String() { - routes = append(routes, r) - } - } - - return routes -} - -// GetGroup returns a group by ID if exists, nil otherwise -func (a *Account) GetGroup(groupID string) *nbgroup.Group { - return a.Groups[groupID] -} - -// GetPeerNetworkMap returns the networkmap for the given peer ID. 
-func (a *Account) GetPeerNetworkMap( - ctx context.Context, - peerID string, - peersCustomZone nbdns.CustomZone, - validatedPeersMap map[string]struct{}, - metrics *telemetry.AccountManagerMetrics, -) *NetworkMap { - start := time.Now() - - peer := a.Peers[peerID] - if peer == nil { - return &NetworkMap{ - Network: a.Network.Copy(), - } - } - - if _, ok := validatedPeersMap[peerID]; !ok { - return &NetworkMap{ - Network: a.Network.Copy(), - } - } - - aclPeers, firewallRules := a.getPeerConnectionResources(ctx, peerID, validatedPeersMap) - // exclude expired peers - var peersToConnect []*nbpeer.Peer - var expiredPeers []*nbpeer.Peer - for _, p := range aclPeers { - expired, _ := p.LoginExpired(a.Settings.PeerLoginExpiration) - if a.Settings.PeerLoginExpirationEnabled && expired { - expiredPeers = append(expiredPeers, p) - continue - } - peersToConnect = append(peersToConnect, p) - } - - routesUpdate := a.getRoutesToSync(ctx, peerID, peersToConnect) - routesFirewallRules := a.getPeerRoutesFirewallRules(ctx, peerID, validatedPeersMap) - - dnsManagementStatus := a.getPeerDNSManagementStatus(peerID) - dnsUpdate := nbdns.Config{ - ServiceEnable: dnsManagementStatus, - } - - if dnsManagementStatus { - var zones []nbdns.CustomZone - - if peersCustomZone.Domain != "" { - zones = append(zones, peersCustomZone) - } - dnsUpdate.CustomZones = zones - dnsUpdate.NameServerGroups = getPeerNSGroups(a, peerID) - } - - nm := &NetworkMap{ - Peers: peersToConnect, - Network: a.Network.Copy(), - Routes: routesUpdate, - DNSConfig: dnsUpdate, - OfflinePeers: expiredPeers, - FirewallRules: firewallRules, - RoutesFirewallRules: routesFirewallRules, - } - - if metrics != nil { - objectCount := int64(len(peersToConnect) + len(expiredPeers) + len(routesUpdate) + len(firewallRules)) - metrics.CountNetworkMapObjects(objectCount) - metrics.CountGetPeerNetworkMapDuration(time.Since(start)) - - if objectCount > 5000 { - log.WithContext(ctx).Tracef("account: %s has a total resource count of %d objects, "+ - "peers to connect: %d, expired peers: %d, routes: %d, firewall rules: %d", - a.Id, objectCount, len(peersToConnect), len(expiredPeers), len(routesUpdate), len(firewallRules)) - } - } - - return nm -} - -func (a *Account) GetPeersCustomZone(ctx context.Context, dnsDomain string) nbdns.CustomZone { - var merr *multierror.Error - - if dnsDomain == "" { - log.WithContext(ctx).Error("no dns domain is set, returning empty zone") - return nbdns.CustomZone{} - } - - customZone := nbdns.CustomZone{ - Domain: dns.Fqdn(dnsDomain), - Records: make([]nbdns.SimpleRecord, 0, len(a.Peers)), - } - - domainSuffix := "." 
+ dnsDomain - - var sb strings.Builder - for _, peer := range a.Peers { - if peer.DNSLabel == "" { - merr = multierror.Append(merr, fmt.Errorf("peer %s has an empty DNS label", peer.Name)) - continue - } - - sb.Grow(len(peer.DNSLabel) + len(domainSuffix)) - sb.WriteString(peer.DNSLabel) - sb.WriteString(domainSuffix) - - customZone.Records = append(customZone.Records, nbdns.SimpleRecord{ - Name: sb.String(), - Type: int(dns.TypeA), - Class: nbdns.DefaultClass, - TTL: defaultTTL, - RData: peer.IP.String(), - }) - - sb.Reset() - } - - go func() { - if merr != nil { - log.WithContext(ctx).Errorf("error generating custom zone for account %s: %v", a.Id, merr) - } - }() - - return customZone -} - -// GetExpiredPeers returns peers that have been expired -func (a *Account) GetExpiredPeers() []*nbpeer.Peer { - var peers []*nbpeer.Peer - for _, peer := range a.GetPeersWithExpiration() { - expired, _ := peer.LoginExpired(a.Settings.PeerLoginExpiration) - if expired { - peers = append(peers, peer) - } - } - - return peers -} - -// GetNextPeerExpiration returns the minimum duration in which the next peer of the account will expire if it was found. -// If there is no peer that expires this function returns false and a duration of 0. -// This function only considers peers that haven't been expired yet and that are connected. -func (a *Account) GetNextPeerExpiration() (time.Duration, bool) { - peersWithExpiry := a.GetPeersWithExpiration() - if len(peersWithExpiry) == 0 { - return 0, false - } - var nextExpiry *time.Duration - for _, peer := range peersWithExpiry { - // consider only connected peers because others will require login on connecting to the management server - if peer.Status.LoginExpired || !peer.Status.Connected { - continue - } - _, duration := peer.LoginExpired(a.Settings.PeerLoginExpiration) - if nextExpiry == nil || duration < *nextExpiry { - // if expiration is below 1s return 1s duration - // this avoids issues with ticker that can't be set to < 0 - if duration < time.Second { - return time.Second, true - } - nextExpiry = &duration - } - } - - if nextExpiry == nil { - return 0, false - } - - return *nextExpiry, true -} - -// GetPeersWithExpiration returns a list of peers that have Peer.LoginExpirationEnabled set to true and that were added by a user -func (a *Account) GetPeersWithExpiration() []*nbpeer.Peer { - peers := make([]*nbpeer.Peer, 0) - for _, peer := range a.Peers { - if peer.LoginExpirationEnabled && peer.AddedWithSSOLogin() { - peers = append(peers, peer) - } - } - return peers -} - -// GetInactivePeers returns peers that have been expired by inactivity -func (a *Account) GetInactivePeers() []*nbpeer.Peer { - var peers []*nbpeer.Peer - for _, inactivePeer := range a.GetPeersWithInactivity() { - inactive, _ := inactivePeer.SessionExpired(a.Settings.PeerInactivityExpiration) - if inactive { - peers = append(peers, inactivePeer) - } - } - return peers -} - -// GetNextInactivePeerExpiration returns the minimum duration in which the next peer of the account will expire if it was found. -// If there is no peer that expires this function returns false and a duration of 0. -// This function only considers peers that haven't been expired yet and that are not connected. 
-func (a *Account) GetNextInactivePeerExpiration() (time.Duration, bool) { - peersWithExpiry := a.GetPeersWithInactivity() - if len(peersWithExpiry) == 0 { - return 0, false - } - var nextExpiry *time.Duration - for _, peer := range peersWithExpiry { - if peer.Status.LoginExpired || peer.Status.Connected { - continue - } - _, duration := peer.SessionExpired(a.Settings.PeerInactivityExpiration) - if nextExpiry == nil || duration < *nextExpiry { - // if expiration is below 1s return 1s duration - // this avoids issues with ticker that can't be set to < 0 - if duration < time.Second { - return time.Second, true - } - nextExpiry = &duration - } - } - - if nextExpiry == nil { - return 0, false - } - - return *nextExpiry, true -} - -// GetPeersWithInactivity eturns a list of peers that have Peer.InactivityExpirationEnabled set to true and that were added by a user -func (a *Account) GetPeersWithInactivity() []*nbpeer.Peer { - peers := make([]*nbpeer.Peer, 0) - for _, peer := range a.Peers { - if peer.InactivityExpirationEnabled && peer.AddedWithSSOLogin() { - peers = append(peers, peer) - } - } - return peers -} - -// GetPeers returns a list of all Account peers -func (a *Account) GetPeers() []*nbpeer.Peer { - var peers []*nbpeer.Peer - for _, peer := range a.Peers { - peers = append(peers, peer) - } - return peers -} - -// UpdateSettings saves new account settings -func (a *Account) UpdateSettings(update *Settings) *Account { - a.Settings = update.Copy() - return a -} - -// UpdatePeer saves new or replaces existing peer -func (a *Account) UpdatePeer(update *nbpeer.Peer) { - a.Peers[update.ID] = update -} - -// DeletePeer deletes peer from the account cleaning up all the references -func (a *Account) DeletePeer(peerID string) { - // delete peer from groups - for _, g := range a.Groups { - for i, pk := range g.Peers { - if pk == peerID { - g.Peers = append(g.Peers[:i], g.Peers[i+1:]...) - break - } - } - } - - for _, r := range a.Routes { - if r.Peer == peerID { - r.Enabled = false - r.Peer = "" - } - } - - delete(a.Peers, peerID) - a.Network.IncSerial() -} - -// FindPeerByPubKey looks for a Peer by provided WireGuard public key in the Account or returns error if it wasn't found. -// It will return an object copy of the peer. -func (a *Account) FindPeerByPubKey(peerPubKey string) (*nbpeer.Peer, error) { - for _, peer := range a.Peers { - if peer.Key == peerPubKey { - return peer.Copy(), nil - } - } - - return nil, status.Errorf(status.NotFound, "peer with the public key %s not found", peerPubKey) -} - -// FindUserPeers returns a list of peers that user owns (created) -func (a *Account) FindUserPeers(userID string) ([]*nbpeer.Peer, error) { - peers := make([]*nbpeer.Peer, 0) - for _, peer := range a.Peers { - if peer.UserID == userID { - peers = append(peers, peer) - } - } - - return peers, nil -} - -// FindUser looks for a given user in the Account or returns error if user wasn't found. -func (a *Account) FindUser(userID string) (*User, error) { - user := a.Users[userID] - if user == nil { - return nil, status.Errorf(status.NotFound, "user %s not found", userID) - } - - return user, nil -} - -// FindGroupByName looks for a given group in the Account by name or returns error if the group wasn't found. 
-func (a *Account) FindGroupByName(groupName string) (*nbgroup.Group, error) { - for _, group := range a.Groups { - if group.Name == groupName { - return group, nil - } - } - return nil, status.Errorf(status.NotFound, "group %s not found", groupName) -} - -// FindSetupKey looks for a given SetupKey in the Account or returns error if it wasn't found. -func (a *Account) FindSetupKey(setupKey string) (*SetupKey, error) { - key := a.SetupKeys[setupKey] - if key == nil { - return nil, status.Errorf(status.NotFound, "setup key not found") - } - - return key, nil -} - -// GetPeerGroupsList return with the list of groups ID. -func (a *Account) GetPeerGroupsList(peerID string) []string { - var grps []string - for groupID, group := range a.Groups { - for _, id := range group.Peers { - if id == peerID { - grps = append(grps, groupID) - break - } - } - } - return grps -} - -func (a *Account) getPeerDNSManagementStatus(peerID string) bool { - peerGroups := a.getPeerGroups(peerID) - enabled := true - for _, groupID := range a.DNSSettings.DisabledManagementGroups { - _, found := peerGroups[groupID] - if found { - enabled = false - break - } - } - return enabled -} - -func (a *Account) getPeerGroups(peerID string) lookupMap { - groupList := make(lookupMap) - for groupID, group := range a.Groups { - for _, id := range group.Peers { - if id == peerID { - groupList[groupID] = struct{}{} - break - } - } - } - return groupList -} - -func (a *Account) getTakenIPs() []net.IP { - var takenIps []net.IP - for _, existingPeer := range a.Peers { - takenIps = append(takenIps, existingPeer.IP) - } - - return takenIps -} - -func (a *Account) getPeerDNSLabels() lookupMap { - existingLabels := make(lookupMap) - for _, peer := range a.Peers { - if peer.DNSLabel != "" { - existingLabels[peer.DNSLabel] = struct{}{} - } - } - return existingLabels -} - -func (a *Account) Copy() *Account { - peers := map[string]*nbpeer.Peer{} - for id, peer := range a.Peers { - peers[id] = peer.Copy() - } - - users := map[string]*User{} - for id, user := range a.Users { - users[id] = user.Copy() - } - - setupKeys := map[string]*SetupKey{} - for id, key := range a.SetupKeys { - setupKeys[id] = key.Copy() - } - - groups := map[string]*nbgroup.Group{} - for id, group := range a.Groups { - groups[id] = group.Copy() - } - - policies := []*Policy{} - for _, policy := range a.Policies { - policies = append(policies, policy.Copy()) - } - - routes := map[route.ID]*route.Route{} - for id, r := range a.Routes { - routes[id] = r.Copy() - } - - nsGroups := map[string]*nbdns.NameServerGroup{} - for id, nsGroup := range a.NameServerGroups { - nsGroups[id] = nsGroup.Copy() - } - - dnsSettings := a.DNSSettings.Copy() - - var settings *Settings - if a.Settings != nil { - settings = a.Settings.Copy() - } - - postureChecks := []*posture.Checks{} - for _, postureCheck := range a.PostureChecks { - postureChecks = append(postureChecks, postureCheck.Copy()) - } - - return &Account{ - Id: a.Id, - CreatedBy: a.CreatedBy, - CreatedAt: a.CreatedAt, - Domain: a.Domain, - DomainCategory: a.DomainCategory, - IsDomainPrimaryAccount: a.IsDomainPrimaryAccount, - SetupKeys: setupKeys, - Network: a.Network.Copy(), - Peers: peers, - Users: users, - Groups: groups, - Policies: policies, - Routes: routes, - NameServerGroups: nsGroups, - DNSSettings: dnsSettings, - PostureChecks: postureChecks, - Settings: settings, - } -} - -func (a *Account) GetGroupAll() (*nbgroup.Group, error) { - for _, g := range a.Groups { - if g.Name == "All" { - return g, nil - } - } - return nil, 
fmt.Errorf("no group ALL found") -} - -// GetPeer looks up a Peer by ID -func (a *Account) GetPeer(peerID string) *nbpeer.Peer { - return a.Peers[peerID] -} - // getJWTGroupsChanges calculates the changes needed to sync a user's JWT groups. // Returns a bool indicating if there are changes in the JWT group membership, the updated user AutoGroups, // newly groups to create and an error if any occurred. -func (am *DefaultAccountManager) getJWTGroupsChanges(user *User, groups []*nbgroup.Group, groupNames []string) (bool, []string, []*nbgroup.Group, error) { - existedGroupsByName := make(map[string]*nbgroup.Group) +func (am *DefaultAccountManager) getJWTGroupsChanges(user *types.User, groups []*types.Group, groupNames []string) (bool, []string, []*types.Group, error) { + existedGroupsByName := make(map[string]*types.Group) for _, group := range groups { existedGroupsByName[group.Name] = group } newUserAutoGroups, jwtGroupsMap := separateGroups(user.AutoGroups, groups) - groupsToAdd := difference(groupNames, maps.Keys(jwtGroupsMap)) - groupsToRemove := difference(maps.Keys(jwtGroupsMap), groupNames) + groupsToAdd := util.Difference(groupNames, maps.Keys(jwtGroupsMap)) + groupsToRemove := util.Difference(maps.Keys(jwtGroupsMap), groupNames) // If no groups are added or removed, we should not sync account if len(groupsToAdd) == 0 && len(groupsToRemove) == 0 { return false, nil, nil, nil } - newGroupsToCreate := make([]*nbgroup.Group, 0) + newGroupsToCreate := make([]*types.Group, 0) var modified bool for _, name := range groupsToAdd { group, exists := existedGroupsByName[name] if !exists { - group = &nbgroup.Group{ + group = &types.Group{ ID: xid.New().String(), AccountID: user.AccountID, Name: name, - Issued: nbgroup.GroupIssuedJWT, + Issued: types.GroupIssuedJWT, } newGroupsToCreate = append(newGroupsToCreate, group) } - if group.Issued == nbgroup.GroupIssuedJWT { + if group.Issued == types.GroupIssuedJWT { newUserAutoGroups = append(newUserAutoGroups, group.ID) modified = true } @@ -964,78 +235,10 @@ func (am *DefaultAccountManager) getJWTGroupsChanges(user *User, groups []*nbgro return modified, newUserAutoGroups, newGroupsToCreate, nil } -// UserGroupsAddToPeers adds groups to all peers of user -func (a *Account) UserGroupsAddToPeers(userID string, groups ...string) map[string][]string { - groupUpdates := make(map[string][]string) - - userPeers := make(map[string]struct{}) - for pid, peer := range a.Peers { - if peer.UserID == userID { - userPeers[pid] = struct{}{} - } - } - - for _, gid := range groups { - group, ok := a.Groups[gid] - if !ok { - continue - } - - oldPeers := group.Peers - - groupPeers := make(map[string]struct{}) - for _, pid := range group.Peers { - groupPeers[pid] = struct{}{} - } - - for pid := range userPeers { - groupPeers[pid] = struct{}{} - } - - group.Peers = group.Peers[:0] - for pid := range groupPeers { - group.Peers = append(group.Peers, pid) - } - - groupUpdates[gid] = difference(group.Peers, oldPeers) - } - - return groupUpdates -} - -// UserGroupsRemoveFromPeers removes groups from all peers of user -func (a *Account) UserGroupsRemoveFromPeers(userID string, groups ...string) map[string][]string { - groupUpdates := make(map[string][]string) - - for _, gid := range groups { - group, ok := a.Groups[gid] - if !ok || group.Name == "All" { - continue - } - - oldPeers := group.Peers - - update := make([]string, 0, len(group.Peers)) - for _, pid := range group.Peers { - peer, ok := a.Peers[pid] - if !ok { - continue - } - if peer.UserID != userID { - update = 
append(update, pid) - } - } - group.Peers = update - groupUpdates[gid] = difference(oldPeers, group.Peers) - } - - return groupUpdates -} - // BuildManager creates a new DefaultAccountManager with a provided Store func BuildManager( ctx context.Context, - store Store, + store store.Store, peersUpdateManager *PeersUpdateManager, idpManager idp.Manager, singleAccountModeDomain string, @@ -1137,7 +340,7 @@ func (am *DefaultAccountManager) GetIdpManager() idp.Manager { // Only users with role UserRoleAdmin can update the account. // User that performs the update has to belong to the account. // Returns an updated Account -func (am *DefaultAccountManager) UpdateAccountSettings(ctx context.Context, accountID, userID string, newSettings *Settings) (*Account, error) { +func (am *DefaultAccountManager) UpdateAccountSettings(ctx context.Context, accountID, userID string, newSettings *types.Settings) (*types.Account, error) { halfYearLimit := 180 * 24 * time.Hour if newSettings.PeerLoginExpiration > halfYearLimit { return nil, status.Errorf(status.InvalidArgument, "peer login expiration can't be larger than 180 days") @@ -1186,6 +389,17 @@ func (am *DefaultAccountManager) UpdateAccountSettings(ctx context.Context, acco am.checkAndSchedulePeerLoginExpiration(ctx, account) } + updateAccountPeers := false + if oldSettings.RoutingPeerDNSResolutionEnabled != newSettings.RoutingPeerDNSResolutionEnabled { + if newSettings.RoutingPeerDNSResolutionEnabled { + am.StoreEvent(ctx, userID, accountID, accountID, activity.AccountRoutingPeerDNSResolutionEnabled, nil) + } else { + am.StoreEvent(ctx, userID, accountID, accountID, activity.AccountRoutingPeerDNSResolutionDisabled, nil) + } + updateAccountPeers = true + account.Network.Serial++ + } + err = am.handleInactivityExpirationSettings(ctx, account, oldSettings, newSettings, userID, accountID) if err != nil { return nil, err @@ -1203,10 +417,14 @@ func (am *DefaultAccountManager) UpdateAccountSettings(ctx context.Context, acco return nil, err } + if updateAccountPeers { + go am.UpdateAccountPeers(ctx, accountID) + } + return updatedAccount, nil } -func (am *DefaultAccountManager) handleGroupsPropagationSettings(ctx context.Context, oldSettings, newSettings *Settings, userID, accountID string) error { +func (am *DefaultAccountManager) handleGroupsPropagationSettings(ctx context.Context, oldSettings, newSettings *types.Settings, userID, accountID string) error { if oldSettings.GroupsPropagationEnabled != newSettings.GroupsPropagationEnabled { if newSettings.GroupsPropagationEnabled { am.StoreEvent(ctx, userID, accountID, accountID, activity.UserGroupPropagationEnabled, nil) @@ -1219,7 +437,7 @@ func (am *DefaultAccountManager) handleGroupsPropagationSettings(ctx context.Con return nil } -func (am *DefaultAccountManager) handleInactivityExpirationSettings(ctx context.Context, account *Account, oldSettings, newSettings *Settings, userID, accountID string) error { +func (am *DefaultAccountManager) handleInactivityExpirationSettings(ctx context.Context, account *types.Account, oldSettings, newSettings *types.Settings, userID, accountID string) error { if newSettings.PeerInactivityExpirationEnabled { if oldSettings.PeerInactivityExpiration != newSettings.PeerInactivityExpiration { @@ -1272,7 +490,7 @@ func (am *DefaultAccountManager) peerLoginExpirationJob(ctx context.Context, acc } } -func (am *DefaultAccountManager) checkAndSchedulePeerLoginExpiration(ctx context.Context, account *Account) { +func (am *DefaultAccountManager) checkAndSchedulePeerLoginExpiration(ctx 
context.Context, account *types.Account) { am.peerLoginExpiry.Cancel(ctx, []string{account.Id}) if nextRun, ok := account.GetNextPeerExpiration(); ok { go am.peerLoginExpiry.Schedule(ctx, nextRun, account.Id, am.peerLoginExpirationJob(ctx, account.Id)) @@ -1309,7 +527,7 @@ func (am *DefaultAccountManager) peerInactivityExpirationJob(ctx context.Context } // checkAndSchedulePeerInactivityExpiration periodically checks for inactive peers to end their sessions -func (am *DefaultAccountManager) checkAndSchedulePeerInactivityExpiration(ctx context.Context, account *Account) { +func (am *DefaultAccountManager) checkAndSchedulePeerInactivityExpiration(ctx context.Context, account *types.Account) { am.peerInactivityExpiry.Cancel(ctx, []string{account.Id}) if nextRun, ok := account.GetNextInactivePeerExpiration(); ok { go am.peerInactivityExpiry.Schedule(ctx, nextRun, account.Id, am.peerInactivityExpirationJob(ctx, account.Id)) @@ -1318,7 +536,7 @@ func (am *DefaultAccountManager) checkAndSchedulePeerInactivityExpiration(ctx co // newAccount creates a new Account with a generated ID and generated default setup keys. // If ID is already in use (due to collision) we try one more time before returning error -func (am *DefaultAccountManager) newAccount(ctx context.Context, userID, domain string) (*Account, error) { +func (am *DefaultAccountManager) newAccount(ctx context.Context, userID, domain string) (*types.Account, error) { for i := 0; i < 2; i++ { accountId := xid.New().String() @@ -1398,7 +616,7 @@ func (am *DefaultAccountManager) DeleteAccount(ctx context.Context, accountID, u return status.Errorf(status.PermissionDenied, "user is not allowed to delete account") } - if user.Role != UserRoleOwner { + if user.Role != types.UserRoleOwner { return status.Errorf(status.PermissionDenied, "user is not allowed to delete account. Only account owner can delete account") } for _, otherUser := range account.Users { @@ -1436,7 +654,7 @@ func (am *DefaultAccountManager) DeleteAccount(ctx context.Context, accountID, u // AccountExists checks if an account exists. func (am *DefaultAccountManager) AccountExists(ctx context.Context, accountID string) (bool, error) { - return am.Store.AccountExists(ctx, LockingStrengthShare, accountID) + return am.Store.AccountExists(ctx, store.LockingStrengthShare, accountID) } // GetAccountIDByUserID retrieves the account ID based on the userID provided. @@ -1473,13 +691,13 @@ func isNil(i idp.Manager) bool { // addAccountIDToIDPAppMeta update user's app metadata in idp manager func (am *DefaultAccountManager) addAccountIDToIDPAppMeta(ctx context.Context, userID string, accountID string) error { if !isNil(am.idpManager) { - accountUsers, err := am.Store.GetAccountUsers(ctx, LockingStrengthShare, accountID) + accountUsers, err := am.Store.GetAccountUsers(ctx, store.LockingStrengthShare, accountID) if err != nil { return err } - cachedAccount := &Account{ + cachedAccount := &types.Account{ Id: accountID, - Users: make(map[string]*User), + Users: make(map[string]*types.User), } for _, user := range accountUsers { cachedAccount.Users[user.Id] = user @@ -1562,14 +780,14 @@ func (am *DefaultAccountManager) lookupUserInCacheByEmail(ctx context.Context, e } // lookupUserInCache looks up user in the IdP cache and returns it. 
If the user wasn't found, the function returns nil -func (am *DefaultAccountManager) lookupUserInCache(ctx context.Context, userID string, account *Account) (*idp.UserData, error) { +func (am *DefaultAccountManager) lookupUserInCache(ctx context.Context, userID string, account *types.Account) (*idp.UserData, error) { users := make(map[string]userLoggedInOnce, len(account.Users)) // ignore service users and users provisioned by integrations than are never logged in for _, user := range account.Users { if user.IsServiceUser { continue } - if user.Issued == UserIssuedIntegration { + if user.Issued == types.UserIssuedIntegration { continue } users[user.Id] = userLoggedInOnce(!user.LastLogin.IsZero()) @@ -1739,7 +957,7 @@ func (am *DefaultAccountManager) updateAccountDomainAttributesIfNotUpToDate(ctx unlockAccount := am.Store.AcquireWriteLockByUID(ctx, accountID) defer unlockAccount() - accountDomain, domainCategory, err := am.Store.GetAccountDomainAndCategory(ctx, LockingStrengthShare, accountID) + accountDomain, domainCategory, err := am.Store.GetAccountDomainAndCategory(ctx, store.LockingStrengthShare, accountID) if err != nil { log.WithContext(ctx).Errorf("error getting account domain and category: %v", err) return err @@ -1749,7 +967,7 @@ func (am *DefaultAccountManager) updateAccountDomainAttributesIfNotUpToDate(ctx return nil } - user, err := am.Store.GetUserByUserID(ctx, LockingStrengthShare, claims.UserId) + user, err := am.Store.GetUserByUserID(ctx, store.LockingStrengthShare, claims.UserId) if err != nil { log.WithContext(ctx).Errorf("error getting user: %v", err) return err @@ -1834,8 +1052,8 @@ func (am *DefaultAccountManager) addNewUserToDomainAccount(ctx context.Context, unlockAccount := am.Store.AcquireWriteLockByUID(ctx, domainAccountID) defer unlockAccount() - usersMap := make(map[string]*User) - usersMap[claims.UserId] = NewRegularUser(claims.UserId) + usersMap := make(map[string]*types.User) + usersMap[claims.UserId] = types.NewRegularUser(claims.UserId) err := am.Store.SaveUsers(domainAccountID, usersMap) if err != nil { return "", err @@ -1923,22 +1141,22 @@ func (am *DefaultAccountManager) MarkPATUsed(ctx context.Context, tokenID string } // GetAccount returns an account associated with this account ID. 
-func (am *DefaultAccountManager) GetAccount(ctx context.Context, accountID string) (*Account, error) { +func (am *DefaultAccountManager) GetAccount(ctx context.Context, accountID string) (*types.Account, error) { return am.Store.GetAccount(ctx, accountID) } // GetAccountFromPAT returns Account and User associated with a personal access token -func (am *DefaultAccountManager) GetAccountFromPAT(ctx context.Context, token string) (*Account, *User, *PersonalAccessToken, error) { - if len(token) != PATLength { +func (am *DefaultAccountManager) GetAccountFromPAT(ctx context.Context, token string) (*types.Account, *types.User, *types.PersonalAccessToken, error) { + if len(token) != types.PATLength { return nil, nil, nil, fmt.Errorf("token has wrong length") } - prefix := token[:len(PATPrefix)] - if prefix != PATPrefix { + prefix := token[:len(types.PATPrefix)] + if prefix != types.PATPrefix { return nil, nil, nil, fmt.Errorf("token has wrong prefix") } - secret := token[len(PATPrefix) : len(PATPrefix)+PATSecretLength] - encodedChecksum := token[len(PATPrefix)+PATSecretLength : len(PATPrefix)+PATSecretLength+PATChecksumLength] + secret := token[len(types.PATPrefix) : len(types.PATPrefix)+types.PATSecretLength] + encodedChecksum := token[len(types.PATPrefix)+types.PATSecretLength : len(types.PATPrefix)+types.PATSecretLength+types.PATChecksumLength] verificationChecksum, err := base62.Decode(encodedChecksum) if err != nil { @@ -1976,8 +1194,8 @@ func (am *DefaultAccountManager) GetAccountFromPAT(ctx context.Context, token st } // GetAccountByID returns an account associated with this account ID. -func (am *DefaultAccountManager) GetAccountByID(ctx context.Context, accountID string, userID string) (*Account, error) { - user, err := am.Store.GetUserByUserID(ctx, LockingStrengthShare, userID) +func (am *DefaultAccountManager) GetAccountByID(ctx context.Context, accountID string, userID string) (*types.Account, error) { + user, err := am.Store.GetUserByUserID(ctx, store.LockingStrengthShare, userID) if err != nil { return nil, err } @@ -1998,7 +1216,7 @@ func (am *DefaultAccountManager) GetAccountIDFromToken(ctx context.Context, clai // This section is mostly related to self-hosted installations. // We override incoming domain claims to group users under a single account. claims.Domain = am.singleAccountModeDomain - claims.DomainCategory = PrivateCategory + claims.DomainCategory = types.PrivateCategory log.WithContext(ctx).Debugf("overriding JWT Domain and DomainCategory claims since single account mode is enabled") } @@ -2007,7 +1225,7 @@ func (am *DefaultAccountManager) GetAccountIDFromToken(ctx context.Context, clai return "", "", err } - user, err := am.Store.GetUserByUserID(ctx, LockingStrengthShare, claims.UserId) + user, err := am.Store.GetUserByUserID(ctx, store.LockingStrengthShare, claims.UserId) if err != nil { // this is not really possible because we got an account by user ID return "", "", status.Errorf(status.NotFound, "user %s not found", claims.UserId) @@ -2034,7 +1252,7 @@ func (am *DefaultAccountManager) GetAccountIDFromToken(ctx context.Context, clai // syncJWTGroups processes the JWT groups for a user, updates the account based on the groups, // and propagates changes to peers if group propagation is enabled. 
func (am *DefaultAccountManager) syncJWTGroups(ctx context.Context, accountID string, claims jwtclaims.AuthorizationClaims) error { - settings, err := am.Store.GetAccountSettings(ctx, LockingStrengthShare, accountID) + settings, err := am.Store.GetAccountSettings(ctx, store.LockingStrengthShare, accountID) if err != nil { return err } @@ -2060,14 +1278,14 @@ func (am *DefaultAccountManager) syncJWTGroups(ctx context.Context, accountID st var addNewGroups []string var removeOldGroups []string var hasChanges bool - var user *User - err = am.Store.ExecuteInTransaction(ctx, func(transaction Store) error { - user, err = transaction.GetUserByUserID(ctx, LockingStrengthShare, claims.UserId) + var user *types.User + err = am.Store.ExecuteInTransaction(ctx, func(transaction store.Store) error { + user, err = transaction.GetUserByUserID(ctx, store.LockingStrengthShare, claims.UserId) if err != nil { return fmt.Errorf("error getting user: %w", err) } - groups, err := transaction.GetAccountGroups(ctx, LockingStrengthShare, accountID) + groups, err := transaction.GetAccountGroups(ctx, store.LockingStrengthShare, accountID) if err != nil { return fmt.Errorf("error getting account groups: %w", err) } @@ -2083,31 +1301,31 @@ func (am *DefaultAccountManager) syncJWTGroups(ctx context.Context, accountID st return nil } - if err = transaction.SaveGroups(ctx, LockingStrengthUpdate, newGroupsToCreate); err != nil { + if err = transaction.SaveGroups(ctx, store.LockingStrengthUpdate, newGroupsToCreate); err != nil { return fmt.Errorf("error saving groups: %w", err) } - addNewGroups = difference(updatedAutoGroups, user.AutoGroups) - removeOldGroups = difference(user.AutoGroups, updatedAutoGroups) + addNewGroups = util.Difference(updatedAutoGroups, user.AutoGroups) + removeOldGroups = util.Difference(user.AutoGroups, updatedAutoGroups) user.AutoGroups = updatedAutoGroups - if err = transaction.SaveUser(ctx, LockingStrengthUpdate, user); err != nil { + if err = transaction.SaveUser(ctx, store.LockingStrengthUpdate, user); err != nil { return fmt.Errorf("error saving user: %w", err) } // Propagate changes to peers if group propagation is enabled if settings.GroupsPropagationEnabled { - groups, err = transaction.GetAccountGroups(ctx, LockingStrengthShare, accountID) + groups, err = transaction.GetAccountGroups(ctx, store.LockingStrengthShare, accountID) if err != nil { return fmt.Errorf("error getting account groups: %w", err) } - groupsMap := make(map[string]*nbgroup.Group, len(groups)) + groupsMap := make(map[string]*types.Group, len(groups)) for _, group := range groups { groupsMap[group.ID] = group } - peers, err := transaction.GetUserPeers(ctx, LockingStrengthShare, accountID, claims.UserId) + peers, err := transaction.GetUserPeers(ctx, store.LockingStrengthShare, accountID, claims.UserId) if err != nil { return fmt.Errorf("error getting user peers: %w", err) } @@ -2117,11 +1335,11 @@ func (am *DefaultAccountManager) syncJWTGroups(ctx context.Context, accountID st return fmt.Errorf("error modifying user peers in groups: %w", err) } - if err = transaction.SaveGroups(ctx, LockingStrengthUpdate, updatedGroups); err != nil { + if err = transaction.SaveGroups(ctx, store.LockingStrengthUpdate, updatedGroups); err != nil { return fmt.Errorf("error saving groups: %w", err) } - if err = transaction.IncrementNetworkSerial(ctx, LockingStrengthUpdate, accountID); err != nil { + if err = transaction.IncrementNetworkSerial(ctx, store.LockingStrengthUpdate, accountID); err != nil { return fmt.Errorf("error incrementing network 
serial: %w", err) } } @@ -2139,7 +1357,7 @@ func (am *DefaultAccountManager) syncJWTGroups(ctx context.Context, accountID st } for _, g := range addNewGroups { - group, err := am.Store.GetGroupByID(ctx, LockingStrengthShare, accountID, g) + group, err := am.Store.GetGroupByID(ctx, store.LockingStrengthShare, accountID, g) if err != nil { log.WithContext(ctx).Debugf("group %s not found while saving user activity event of account %s", g, accountID) } else { @@ -2152,7 +1370,7 @@ func (am *DefaultAccountManager) syncJWTGroups(ctx context.Context, accountID st } for _, g := range removeOldGroups { - group, err := am.Store.GetGroupByID(ctx, LockingStrengthShare, accountID, g) + group, err := am.Store.GetGroupByID(ctx, store.LockingStrengthShare, accountID, g) if err != nil { log.WithContext(ctx).Debugf("group %s not found while saving user activity event of account %s", g, accountID) } else { @@ -2177,7 +1395,7 @@ func (am *DefaultAccountManager) syncJWTGroups(ctx context.Context, accountID st if removedGroupAffectsPeers || newGroupsAffectsPeers { log.WithContext(ctx).Tracef("user %s: JWT group membership changed, updating account peers", claims.UserId) - am.updateAccountPeers(ctx, accountID) + am.UpdateAccountPeers(ctx, accountID) } } @@ -2210,7 +1428,7 @@ func (am *DefaultAccountManager) getAccountIDWithAuthorizationClaims(ctx context return "", errors.New(emptyUserID) } - if claims.DomainCategory != PrivateCategory || !isDomainValid(claims.Domain) { + if claims.DomainCategory != types.PrivateCategory || !isDomainValid(claims.Domain) { return am.GetAccountIDByUserID(ctx, claims.UserId, claims.Domain) } @@ -2248,7 +1466,7 @@ func (am *DefaultAccountManager) getAccountIDWithAuthorizationClaims(ctx context return am.addNewPrivateAccount(ctx, domainAccountID, claims) } func (am *DefaultAccountManager) getPrivateDomainWithGlobalLock(ctx context.Context, domain string) (string, context.CancelFunc, error) { - domainAccountID, err := am.Store.GetAccountIDByPrivateDomain(ctx, LockingStrengthShare, domain) + domainAccountID, err := am.Store.GetAccountIDByPrivateDomain(ctx, store.LockingStrengthShare, domain) if handleNotFound(err) != nil { log.WithContext(ctx).Errorf(errorGettingDomainAccIDFmt, err) @@ -2263,7 +1481,7 @@ func (am *DefaultAccountManager) getPrivateDomainWithGlobalLock(ctx context.Cont cancel := am.Store.AcquireGlobalLock(ctx) // check again if the domain has a primary account because of simultaneous requests - domainAccountID, err = am.Store.GetAccountIDByPrivateDomain(ctx, LockingStrengthShare, domain) + domainAccountID, err = am.Store.GetAccountIDByPrivateDomain(ctx, store.LockingStrengthShare, domain) if handleNotFound(err) != nil { cancel() log.WithContext(ctx).Errorf(errorGettingDomainAccIDFmt, err) @@ -2284,7 +1502,7 @@ func (am *DefaultAccountManager) handlePrivateAccountWithIDFromClaim(ctx context return "", fmt.Errorf("user %s is not part of the account id %s", claims.UserId, claims.AccountId) } - accountDomain, domainCategory, err := am.Store.GetAccountDomainAndCategory(ctx, LockingStrengthShare, claims.AccountId) + accountDomain, domainCategory, err := am.Store.GetAccountDomainAndCategory(ctx, store.LockingStrengthShare, claims.AccountId) if handleNotFound(err) != nil { log.WithContext(ctx).Errorf("error getting account domain and category: %v", err) return "", err @@ -2295,7 +1513,7 @@ func (am *DefaultAccountManager) handlePrivateAccountWithIDFromClaim(ctx context } // We checked if the domain has a primary account already - domainAccountID, err := 
am.Store.GetAccountIDByPrivateDomain(ctx, LockingStrengthShare, claims.Domain) + domainAccountID, err := am.Store.GetAccountIDByPrivateDomain(ctx, store.LockingStrengthShare, claims.Domain) if handleNotFound(err) != nil { log.WithContext(ctx).Errorf(errorGettingDomainAccIDFmt, err) return "", err @@ -2322,10 +1540,10 @@ func handleNotFound(err error) error { } func domainIsUpToDate(domain string, domainCategory string, claims jwtclaims.AuthorizationClaims) bool { - return domainCategory == PrivateCategory || claims.DomainCategory != PrivateCategory || domain != claims.Domain + return domainCategory == types.PrivateCategory || claims.DomainCategory != types.PrivateCategory || domain != claims.Domain } -func (am *DefaultAccountManager) SyncAndMarkPeer(ctx context.Context, accountID string, peerPubKey string, meta nbpeer.PeerSystemMeta, realIP net.IP) (*nbpeer.Peer, *NetworkMap, []*posture.Checks, error) { +func (am *DefaultAccountManager) SyncAndMarkPeer(ctx context.Context, accountID string, peerPubKey string, meta nbpeer.PeerSystemMeta, realIP net.IP) (*nbpeer.Peer, *types.NetworkMap, []*posture.Checks, error) { accountUnlock := am.Store.AcquireReadLockByUID(ctx, accountID) defer accountUnlock() peerUnlock := am.Store.AcquireWriteLockByUID(ctx, peerPubKey) @@ -2422,7 +1640,7 @@ func (am *DefaultAccountManager) CheckUserAccessByJWTGroups(ctx context.Context, return err } - settings, err := am.Store.GetAccountSettings(ctx, LockingStrengthShare, accountID) + settings, err := am.Store.GetAccountSettings(ctx, store.LockingStrengthShare, accountID) if err != nil { return err } @@ -2444,7 +1662,7 @@ func (am *DefaultAccountManager) CheckUserAccessByJWTGroups(ctx context.Context, func (am *DefaultAccountManager) onPeersInvalidated(ctx context.Context, accountID string) { log.WithContext(ctx).Debugf("validated peers has been invalidated for account %s", accountID) - am.updateAccountPeers(ctx, accountID) + am.UpdateAccountPeers(ctx, accountID) } func (am *DefaultAccountManager) FindExistingPostureCheck(accountID string, checks *posture.ChecksDefinition) (*posture.Checks, error) { @@ -2455,8 +1673,8 @@ func (am *DefaultAccountManager) GetAccountIDForPeerKey(ctx context.Context, pee return am.Store.GetAccountIDByPeerPubKey(ctx, peerKey) } -func (am *DefaultAccountManager) handleUserPeer(ctx context.Context, peer *nbpeer.Peer, settings *Settings) (bool, error) { - user, err := am.Store.GetUserByUserID(ctx, LockingStrengthShare, peer.UserID) +func (am *DefaultAccountManager) handleUserPeer(ctx context.Context, peer *nbpeer.Peer, settings *types.Settings) (bool, error) { + user, err := am.Store.GetUserByUserID(ctx, store.LockingStrengthShare, peer.UserID) if err != nil { return false, err } @@ -2477,14 +1695,14 @@ func (am *DefaultAccountManager) handleUserPeer(ctx context.Context, peer *nbpee return false, nil } -func (am *DefaultAccountManager) getFreeDNSLabel(ctx context.Context, store Store, accountID string, peerHostName string) (string, error) { - existingLabels, err := store.GetPeerLabelsInAccount(ctx, LockingStrengthShare, accountID) +func (am *DefaultAccountManager) getFreeDNSLabel(ctx context.Context, s store.Store, accountID string, peerHostName string) (string, error) { + existingLabels, err := s.GetPeerLabelsInAccount(ctx, store.LockingStrengthShare, accountID) if err != nil { return "", fmt.Errorf("failed to get peer dns labels: %w", err) } labelMap := ConvertSliceToMap(existingLabels) - newLabel, err := getPeerHostLabel(peerHostName, labelMap) + newLabel, err := 
types.GetPeerHostLabel(peerHostName, labelMap) if err != nil { return "", fmt.Errorf("failed to get new host label: %w", err) } @@ -2496,8 +1714,8 @@ func (am *DefaultAccountManager) getFreeDNSLabel(ctx context.Context, store Stor return newLabel, nil } -func (am *DefaultAccountManager) GetAccountSettings(ctx context.Context, accountID string, userID string) (*Settings, error) { - user, err := am.Store.GetUserByUserID(ctx, LockingStrengthShare, userID) +func (am *DefaultAccountManager) GetAccountSettings(ctx context.Context, accountID string, userID string) (*types.Settings, error) { + user, err := am.Store.GetUserByUserID(ctx, store.LockingStrengthShare, userID) if err != nil { return nil, err } @@ -2506,70 +1724,70 @@ func (am *DefaultAccountManager) GetAccountSettings(ctx context.Context, account return nil, status.Errorf(status.PermissionDenied, "the user has no permission to access account data") } - return am.Store.GetAccountSettings(ctx, LockingStrengthShare, accountID) + return am.Store.GetAccountSettings(ctx, store.LockingStrengthShare, accountID) } // addAllGroup to account object if it doesn't exist -func addAllGroup(account *Account) error { +func addAllGroup(account *types.Account) error { if len(account.Groups) == 0 { - allGroup := &nbgroup.Group{ + allGroup := &types.Group{ ID: xid.New().String(), Name: "All", - Issued: nbgroup.GroupIssuedAPI, + Issued: types.GroupIssuedAPI, } for _, peer := range account.Peers { allGroup.Peers = append(allGroup.Peers, peer.ID) } - account.Groups = map[string]*nbgroup.Group{allGroup.ID: allGroup} + account.Groups = map[string]*types.Group{allGroup.ID: allGroup} id := xid.New().String() - defaultPolicy := &Policy{ + defaultPolicy := &types.Policy{ ID: id, - Name: DefaultRuleName, - Description: DefaultRuleDescription, + Name: types.DefaultRuleName, + Description: types.DefaultRuleDescription, Enabled: true, - Rules: []*PolicyRule{ + Rules: []*types.PolicyRule{ { ID: id, - Name: DefaultRuleName, - Description: DefaultRuleDescription, + Name: types.DefaultRuleName, + Description: types.DefaultRuleDescription, Enabled: true, Sources: []string{allGroup.ID}, Destinations: []string{allGroup.ID}, Bidirectional: true, - Protocol: PolicyRuleProtocolALL, - Action: PolicyTrafficActionAccept, + Protocol: types.PolicyRuleProtocolALL, + Action: types.PolicyTrafficActionAccept, }, }, } - account.Policies = []*Policy{defaultPolicy} + account.Policies = []*types.Policy{defaultPolicy} } return nil } // newAccountWithId creates a new Account with a default SetupKey (doesn't store in a Store) and provided id -func newAccountWithId(ctx context.Context, accountID, userID, domain string) *Account { +func newAccountWithId(ctx context.Context, accountID, userID, domain string) *types.Account { log.WithContext(ctx).Debugf("creating new account") - network := NewNetwork() + network := types.NewNetwork() peers := make(map[string]*nbpeer.Peer) - users := make(map[string]*User) + users := make(map[string]*types.User) routes := make(map[route.ID]*route.Route) - setupKeys := map[string]*SetupKey{} + setupKeys := map[string]*types.SetupKey{} nameServersGroups := make(map[string]*nbdns.NameServerGroup) - owner := NewOwnerUser(userID) + owner := types.NewOwnerUser(userID) owner.AccountID = accountID users[userID] = owner - dnsSettings := DNSSettings{ + dnsSettings := types.DNSSettings{ DisabledManagementGroups: make([]string, 0), } log.WithContext(ctx).Debugf("created new account %s", accountID) - acc := &Account{ + acc := &types.Account{ Id: accountID, CreatedAt: 
time.Now().UTC(), SetupKeys: setupKeys, @@ -2581,14 +1799,15 @@ func newAccountWithId(ctx context.Context, accountID, userID, domain string) *Ac Routes: routes, NameServerGroups: nameServersGroups, DNSSettings: dnsSettings, - Settings: &Settings{ + Settings: &types.Settings{ PeerLoginExpirationEnabled: true, - PeerLoginExpiration: DefaultPeerLoginExpiration, + PeerLoginExpiration: types.DefaultPeerLoginExpiration, GroupsPropagationEnabled: true, RegularUsersViewBlocked: true, PeerInactivityExpirationEnabled: false, - PeerInactivityExpiration: DefaultPeerInactivityExpiration, + PeerInactivityExpiration: types.DefaultPeerInactivityExpiration, + RoutingPeerDNSResolutionEnabled: true, }, } @@ -2634,18 +1853,18 @@ func userHasAllowedGroup(allowedGroups []string, userGroups []string) bool { // separateGroups separates user's auto groups into non-JWT and JWT groups. // Returns the list of standard auto groups and a map of JWT auto groups, // where the keys are the group names and the values are the group IDs. -func separateGroups(autoGroups []string, allGroups []*nbgroup.Group) ([]string, map[string]string) { +func separateGroups(autoGroups []string, allGroups []*types.Group) ([]string, map[string]string) { newAutoGroups := make([]string, 0) jwtAutoGroups := make(map[string]string) // map of group name to group ID - allGroupsMap := make(map[string]*nbgroup.Group, len(allGroups)) + allGroupsMap := make(map[string]*types.Group, len(allGroups)) for _, group := range allGroups { allGroupsMap[group.ID] = group } for _, id := range autoGroups { if group, ok := allGroupsMap[id]; ok { - if group.Issued == nbgroup.GroupIssuedJWT { + if group.Issued == types.GroupIssuedJWT { jwtAutoGroups[group.Name] = id } else { newAutoGroups = append(newAutoGroups, id) diff --git a/management/server/account_request_buffer.go b/management/server/account_request_buffer.go index 5f4897e6a..fa6c45856 100644 --- a/management/server/account_request_buffer.go +++ b/management/server/account_request_buffer.go @@ -7,6 +7,9 @@ import ( "time" log "github.com/sirupsen/logrus" + + "github.com/netbirdio/netbird/management/server/store" + "github.com/netbirdio/netbird/management/server/types" ) // AccountRequest holds the result channel to return the requested account. @@ -17,19 +20,19 @@ type AccountRequest struct { // AccountResult holds the account data or an error. 
type AccountResult struct { - Account *Account + Account *types.Account Err error } type AccountRequestBuffer struct { - store Store + store store.Store getAccountRequests map[string][]*AccountRequest mu sync.Mutex getAccountRequestCh chan *AccountRequest bufferInterval time.Duration } -func NewAccountRequestBuffer(ctx context.Context, store Store) *AccountRequestBuffer { +func NewAccountRequestBuffer(ctx context.Context, store store.Store) *AccountRequestBuffer { bufferIntervalStr := os.Getenv("NB_GET_ACCOUNT_BUFFER_INTERVAL") bufferInterval, err := time.ParseDuration(bufferIntervalStr) if err != nil { @@ -52,7 +55,7 @@ func NewAccountRequestBuffer(ctx context.Context, store Store) *AccountRequestBu return &ac } -func (ac *AccountRequestBuffer) GetAccountWithBackpressure(ctx context.Context, accountID string) (*Account, error) { +func (ac *AccountRequestBuffer) GetAccountWithBackpressure(ctx context.Context, accountID string) (*types.Account, error) { req := &AccountRequest{ AccountID: accountID, ResultChan: make(chan *AccountResult, 1), diff --git a/management/server/account_test.go b/management/server/account_test.go index d952e118a..d83eab6d1 100644 --- a/management/server/account_test.go +++ b/management/server/account_test.go @@ -16,6 +16,11 @@ import ( "time" "github.com/golang-jwt/jwt" + + resourceTypes "github.com/netbirdio/netbird/management/server/networks/resources/types" + routerTypes "github.com/netbirdio/netbird/management/server/networks/routers/types" + networkTypes "github.com/netbirdio/netbird/management/server/networks/types" + log "github.com/sirupsen/logrus" "github.com/stretchr/testify/assert" "github.com/stretchr/testify/require" @@ -24,11 +29,12 @@ import ( nbdns "github.com/netbirdio/netbird/dns" "github.com/netbirdio/netbird/management/server/account" "github.com/netbirdio/netbird/management/server/activity" - "github.com/netbirdio/netbird/management/server/group" "github.com/netbirdio/netbird/management/server/jwtclaims" nbpeer "github.com/netbirdio/netbird/management/server/peer" "github.com/netbirdio/netbird/management/server/posture" + "github.com/netbirdio/netbird/management/server/store" "github.com/netbirdio/netbird/management/server/telemetry" + "github.com/netbirdio/netbird/management/server/types" "github.com/netbirdio/netbird/route" ) @@ -46,7 +52,7 @@ func (a MocIntegratedValidator) ValidatePeer(_ context.Context, update *nbpeer.P } return update, false, nil } -func (a MocIntegratedValidator) GetValidatedPeers(accountID string, groups map[string]*group.Group, peers map[string]*nbpeer.Peer, extraSettings *account.ExtraSettings) (map[string]struct{}, error) { +func (a MocIntegratedValidator) GetValidatedPeers(accountID string, groups map[string]*types.Group, peers map[string]*nbpeer.Peer, extraSettings *account.ExtraSettings) (map[string]struct{}, error) { validatedPeers := make(map[string]struct{}) for _, peer := range peers { validatedPeers[peer.ID] = struct{}{} @@ -73,7 +79,7 @@ func (MocIntegratedValidator) SetPeerInvalidationListener(func(accountID string) func (MocIntegratedValidator) Stop(_ context.Context) { } -func verifyCanAddPeerToAccount(t *testing.T, manager AccountManager, account *Account, userID string) { +func verifyCanAddPeerToAccount(t *testing.T, manager AccountManager, account *types.Account, userID string) { t.Helper() peer := &nbpeer.Peer{ Key: "BhRPtynAAYRDy08+q4HTMsos8fs4plTP4NOSh7C1ry8=", @@ -101,7 +107,7 @@ func verifyCanAddPeerToAccount(t *testing.T, manager AccountManager, account *Ac } } -func 
verifyNewAccountHasDefaultFields(t *testing.T, account *Account, createdBy string, domain string, expectedUsers []string) { +func verifyNewAccountHasDefaultFields(t *testing.T, account *types.Account, createdBy string, domain string, expectedUsers []string) { t.Helper() if len(account.Peers) != 0 { t.Errorf("expected account to have len(Peers) = %v, got %v", 0, len(account.Peers)) @@ -156,7 +162,7 @@ func TestAccount_GetPeerNetworkMap(t *testing.T) { // peerID3 := "peer-3" tt := []struct { name string - accountSettings Settings + accountSettings types.Settings peerID string expectedPeers []string expectedOfflinePeers []string @@ -164,7 +170,7 @@ func TestAccount_GetPeerNetworkMap(t *testing.T) { }{ { name: "Should return ALL peers when global peer login expiration disabled", - accountSettings: Settings{PeerLoginExpirationEnabled: false, PeerLoginExpiration: time.Hour}, + accountSettings: types.Settings{PeerLoginExpirationEnabled: false, PeerLoginExpiration: time.Hour}, peerID: peerID1, expectedPeers: []string{peerID2}, expectedOfflinePeers: []string{}, @@ -202,7 +208,7 @@ func TestAccount_GetPeerNetworkMap(t *testing.T) { }, { name: "Should return no peers when global peer login expiration enabled and peers expired", - accountSettings: Settings{PeerLoginExpirationEnabled: true, PeerLoginExpiration: time.Hour}, + accountSettings: types.Settings{PeerLoginExpirationEnabled: true, PeerLoginExpiration: time.Hour}, peerID: peerID1, expectedPeers: []string{}, expectedOfflinePeers: []string{peerID2}, @@ -396,12 +402,12 @@ func TestAccount_GetPeerNetworkMap(t *testing.T) { netIP := net.IP{100, 64, 0, 0} netMask := net.IPMask{255, 255, 0, 0} - network := &Network{ + network := &types.Network{ Identifier: "network", Net: net.IPNet{IP: netIP, Mask: netMask}, Dns: "netbird.selfhosted", Serial: 0, - mu: sync.Mutex{}, + Mu: sync.Mutex{}, } for _, testCase := range tt { @@ -420,7 +426,7 @@ func TestAccount_GetPeerNetworkMap(t *testing.T) { } customZone := account.GetPeersCustomZone(context.Background(), "netbird.io") - networkMap := account.GetPeerNetworkMap(context.Background(), testCase.peerID, customZone, validatedPeers, nil) + networkMap := account.GetPeerNetworkMap(context.Background(), testCase.peerID, customZone, validatedPeers, account.GetResourcePoliciesMap(), account.GetResourceRoutersMap(), nil) assert.Len(t, networkMap.Peers, len(testCase.expectedPeers)) assert.Len(t, networkMap.OfflinePeers, len(testCase.expectedOfflinePeers)) } @@ -485,12 +491,12 @@ func TestDefaultAccountManager_GetAccountIDFromToken(t *testing.T) { } initUnknown := defaultInitAccount - initUnknown.DomainCategory = UnknownCategory + initUnknown.DomainCategory = types.UnknownCategory initUnknown.Domain = unknownDomain privateInitAccount := defaultInitAccount privateInitAccount.Domain = privateDomain - privateInitAccount.DomainCategory = PrivateCategory + privateInitAccount.DomainCategory = types.PrivateCategory testCases := []struct { name string @@ -500,7 +506,7 @@ func TestDefaultAccountManager_GetAccountIDFromToken(t *testing.T) { inputUpdateClaimAccount bool testingFunc require.ComparisonAssertionFunc expectedMSG string - expectedUserRole UserRole + expectedUserRole types.UserRole expectedDomainCategory string expectedDomain string expectedPrimaryDomainStatus bool @@ -512,12 +518,12 @@ func TestDefaultAccountManager_GetAccountIDFromToken(t *testing.T) { inputClaims: jwtclaims.AuthorizationClaims{ Domain: publicDomain, UserId: "pub-domain-user", - DomainCategory: PublicCategory, + DomainCategory: types.PublicCategory, }, 
inputInitUserParams: defaultInitAccount, testingFunc: require.NotEqual, expectedMSG: "account IDs shouldn't match", - expectedUserRole: UserRoleOwner, + expectedUserRole: types.UserRoleOwner, expectedDomainCategory: "", expectedDomain: publicDomain, expectedPrimaryDomainStatus: false, @@ -529,12 +535,12 @@ func TestDefaultAccountManager_GetAccountIDFromToken(t *testing.T) { inputClaims: jwtclaims.AuthorizationClaims{ Domain: unknownDomain, UserId: "unknown-domain-user", - DomainCategory: UnknownCategory, + DomainCategory: types.UnknownCategory, }, inputInitUserParams: initUnknown, testingFunc: require.NotEqual, expectedMSG: "account IDs shouldn't match", - expectedUserRole: UserRoleOwner, + expectedUserRole: types.UserRoleOwner, expectedDomain: unknownDomain, expectedDomainCategory: "", expectedPrimaryDomainStatus: false, @@ -546,14 +552,14 @@ func TestDefaultAccountManager_GetAccountIDFromToken(t *testing.T) { inputClaims: jwtclaims.AuthorizationClaims{ Domain: privateDomain, UserId: "pvt-domain-user", - DomainCategory: PrivateCategory, + DomainCategory: types.PrivateCategory, }, inputInitUserParams: defaultInitAccount, testingFunc: require.NotEqual, expectedMSG: "account IDs shouldn't match", - expectedUserRole: UserRoleOwner, + expectedUserRole: types.UserRoleOwner, expectedDomain: privateDomain, - expectedDomainCategory: PrivateCategory, + expectedDomainCategory: types.PrivateCategory, expectedPrimaryDomainStatus: true, expectedCreatedBy: "pvt-domain-user", expectedUsers: []string{"pvt-domain-user"}, @@ -563,15 +569,15 @@ func TestDefaultAccountManager_GetAccountIDFromToken(t *testing.T) { inputClaims: jwtclaims.AuthorizationClaims{ Domain: privateDomain, UserId: "new-pvt-domain-user", - DomainCategory: PrivateCategory, + DomainCategory: types.PrivateCategory, }, inputUpdateAttrs: true, inputInitUserParams: privateInitAccount, testingFunc: require.Equal, expectedMSG: "account IDs should match", - expectedUserRole: UserRoleUser, + expectedUserRole: types.UserRoleUser, expectedDomain: privateDomain, - expectedDomainCategory: PrivateCategory, + expectedDomainCategory: types.PrivateCategory, expectedPrimaryDomainStatus: true, expectedCreatedBy: defaultInitAccount.UserId, expectedUsers: []string{defaultInitAccount.UserId, "new-pvt-domain-user"}, @@ -581,14 +587,14 @@ func TestDefaultAccountManager_GetAccountIDFromToken(t *testing.T) { inputClaims: jwtclaims.AuthorizationClaims{ Domain: defaultInitAccount.Domain, UserId: defaultInitAccount.UserId, - DomainCategory: PrivateCategory, + DomainCategory: types.PrivateCategory, }, inputInitUserParams: defaultInitAccount, testingFunc: require.Equal, expectedMSG: "account IDs should match", - expectedUserRole: UserRoleOwner, + expectedUserRole: types.UserRoleOwner, expectedDomain: defaultInitAccount.Domain, - expectedDomainCategory: PrivateCategory, + expectedDomainCategory: types.PrivateCategory, expectedPrimaryDomainStatus: true, expectedCreatedBy: defaultInitAccount.UserId, expectedUsers: []string{defaultInitAccount.UserId}, @@ -598,15 +604,15 @@ func TestDefaultAccountManager_GetAccountIDFromToken(t *testing.T) { inputClaims: jwtclaims.AuthorizationClaims{ Domain: defaultInitAccount.Domain, UserId: defaultInitAccount.UserId, - DomainCategory: PrivateCategory, + DomainCategory: types.PrivateCategory, }, inputUpdateClaimAccount: true, inputInitUserParams: defaultInitAccount, testingFunc: require.Equal, expectedMSG: "account IDs should match", - expectedUserRole: UserRoleOwner, + expectedUserRole: types.UserRoleOwner, expectedDomain: 
defaultInitAccount.Domain, - expectedDomainCategory: PrivateCategory, + expectedDomainCategory: types.PrivateCategory, expectedPrimaryDomainStatus: true, expectedCreatedBy: defaultInitAccount.UserId, expectedUsers: []string{defaultInitAccount.UserId}, @@ -616,12 +622,12 @@ func TestDefaultAccountManager_GetAccountIDFromToken(t *testing.T) { inputClaims: jwtclaims.AuthorizationClaims{ Domain: "", UserId: "pvt-domain-user", - DomainCategory: PrivateCategory, + DomainCategory: types.PrivateCategory, }, inputInitUserParams: defaultInitAccount, testingFunc: require.NotEqual, expectedMSG: "account IDs shouldn't match", - expectedUserRole: UserRoleOwner, + expectedUserRole: types.UserRoleOwner, expectedDomain: "", expectedDomainCategory: "", expectedPrimaryDomainStatus: false, @@ -733,7 +739,7 @@ func TestDefaultAccountManager_GetGroupsFromTheToken(t *testing.T) { require.Len(t, account.Groups, 3, "groups should be added to the account") - groupsByNames := map[string]*group.Group{} + groupsByNames := map[string]*types.Group{} for _, g := range account.Groups { groupsByNames[g.Name] = g } @@ -741,32 +747,36 @@ func TestDefaultAccountManager_GetGroupsFromTheToken(t *testing.T) { g1, ok := groupsByNames["group1"] require.True(t, ok, "group1 should be added to the account") require.Equal(t, g1.Name, "group1", "group1 name should match") - require.Equal(t, g1.Issued, group.GroupIssuedJWT, "group1 issued should match") + require.Equal(t, g1.Issued, types.GroupIssuedJWT, "group1 issued should match") g2, ok := groupsByNames["group2"] require.True(t, ok, "group2 should be added to the account") require.Equal(t, g2.Name, "group2", "group2 name should match") - require.Equal(t, g2.Issued, group.GroupIssuedJWT, "group2 issued should match") + require.Equal(t, g2.Issued, types.GroupIssuedJWT, "group2 issued should match") }) } func TestAccountManager_GetAccountFromPAT(t *testing.T) { - store := newStore(t) + store, cleanup, err := store.NewTestStoreFromSQL(context.Background(), "", t.TempDir()) + if err != nil { + t.Fatalf("Error when creating store: %s", err) + } + t.Cleanup(cleanup) account := newAccountWithId(context.Background(), "account_id", "testuser", "") token := "nbp_9999EUDNdkeusjentDLSJEn1902u84390W6W" hashedToken := sha256.Sum256([]byte(token)) encodedHashedToken := b64.StdEncoding.EncodeToString(hashedToken[:]) - account.Users["someUser"] = &User{ + account.Users["someUser"] = &types.User{ Id: "someUser", - PATs: map[string]*PersonalAccessToken{ + PATs: map[string]*types.PersonalAccessToken{ "tokenId": { ID: "tokenId", HashedToken: encodedHashedToken, }, }, } - err := store.SaveAccount(context.Background(), account) + err = store.SaveAccount(context.Background(), account) if err != nil { t.Fatalf("Error when saving account: %s", err) } @@ -786,15 +796,20 @@ func TestAccountManager_GetAccountFromPAT(t *testing.T) { } func TestDefaultAccountManager_MarkPATUsed(t *testing.T) { - store := newStore(t) + store, cleanup, err := store.NewTestStoreFromSQL(context.Background(), "", t.TempDir()) + if err != nil { + t.Fatalf("Error when creating store: %s", err) + } + t.Cleanup(cleanup) + account := newAccountWithId(context.Background(), "account_id", "testuser", "") token := "nbp_9999EUDNdkeusjentDLSJEn1902u84390W6W" hashedToken := sha256.Sum256([]byte(token)) encodedHashedToken := b64.StdEncoding.EncodeToString(hashedToken[:]) - account.Users["someUser"] = &User{ + account.Users["someUser"] = &types.User{ Id: "someUser", - PATs: map[string]*PersonalAccessToken{ + PATs: 
map[string]*types.PersonalAccessToken{ "tokenId": { ID: "tokenId", HashedToken: encodedHashedToken, @@ -802,7 +817,7 @@ func TestDefaultAccountManager_MarkPATUsed(t *testing.T) { }, }, } - err := store.SaveAccount(context.Background(), account) + err = store.SaveAccount(context.Background(), account) if err != nil { t.Fatalf("Error when saving account: %s", err) } @@ -904,7 +919,7 @@ func TestAccountManager_GetAccountByUserID(t *testing.T) { return } - exists, err := manager.Store.AccountExists(context.Background(), LockingStrengthShare, accountID) + exists, err := manager.Store.AccountExists(context.Background(), store.LockingStrengthShare, accountID) assert.NoError(t, err) assert.True(t, exists, "expected to get existing account after creation using userid") @@ -914,7 +929,7 @@ func TestAccountManager_GetAccountByUserID(t *testing.T) { } } -func createAccount(am *DefaultAccountManager, accountID, userID, domain string) (*Account, error) { +func createAccount(am *DefaultAccountManager, accountID, userID, domain string) (*types.Account, error) { account := newAccountWithId(context.Background(), accountID, userID, domain) err := am.Store.SaveAccount(context.Background(), account) if err != nil { @@ -990,13 +1005,13 @@ func BenchmarkTest_GetAccountWithclaims(b *testing.B) { claims := jwtclaims.AuthorizationClaims{ Domain: "example.com", UserId: "pvt-domain-user", - DomainCategory: PrivateCategory, + DomainCategory: types.PrivateCategory, } publicClaims := jwtclaims.AuthorizationClaims{ Domain: "test.com", UserId: "public-domain-user", - DomainCategory: PublicCategory, + DomainCategory: types.PublicCategory, } am, err := createManager(b) @@ -1074,13 +1089,13 @@ func BenchmarkTest_GetAccountWithclaims(b *testing.B) { } -func genUsers(p string, n int) map[string]*User { - users := map[string]*User{} +func genUsers(p string, n int) map[string]*types.User { + users := map[string]*types.User{} now := time.Now() for i := 0; i < n; i++ { - users[fmt.Sprintf("%s-%d", p, i)] = &User{ + users[fmt.Sprintf("%s-%d", p, i)] = &types.User{ Id: fmt.Sprintf("%s-%d", p, i), - Role: UserRoleAdmin, + Role: types.UserRoleAdmin, LastLogin: now, CreatedAt: now, Issued: "api", @@ -1105,7 +1120,7 @@ func TestAccountManager_AddPeer(t *testing.T) { serial := account.Network.CurrentSerial() // should be 0 - setupKey, err := manager.CreateSetupKey(context.Background(), account.Id, "test-key", SetupKeyReusable, time.Hour, nil, 999, userID, false) + setupKey, err := manager.CreateSetupKey(context.Background(), account.Id, "test-key", types.SetupKeyReusable, time.Hour, nil, 999, userID, false) if err != nil { t.Fatal("error creating setup key") return @@ -1232,7 +1247,7 @@ func TestAccountManager_AddPeerWithUserID(t *testing.T) { func TestAccountManager_NetworkUpdates_SaveGroup(t *testing.T) { manager, account, peer1, peer2, peer3 := setupNetworkMapTest(t) - group := group.Group{ + group := types.Group{ ID: "groupA", Name: "GroupA", Peers: []string{}, @@ -1242,15 +1257,15 @@ func TestAccountManager_NetworkUpdates_SaveGroup(t *testing.T) { return } - _, err := manager.SavePolicy(context.Background(), account.Id, userID, &Policy{ + _, err := manager.SavePolicy(context.Background(), account.Id, userID, &types.Policy{ Enabled: true, - Rules: []*PolicyRule{ + Rules: []*types.PolicyRule{ { Enabled: true, Sources: []string{"groupA"}, Destinations: []string{"groupA"}, Bidirectional: true, - Action: PolicyTrafficActionAccept, + Action: types.PolicyTrafficActionAccept, }, }, }) @@ -1309,7 +1324,7 @@ func 
TestAccountManager_NetworkUpdates_DeletePolicy(t *testing.T) { func TestAccountManager_NetworkUpdates_SavePolicy(t *testing.T) { manager, account, peer1, peer2, _ := setupNetworkMapTest(t) - group := group.Group{ + group := types.Group{ ID: "groupA", Name: "GroupA", Peers: []string{peer1.ID, peer2.ID}, @@ -1334,15 +1349,15 @@ func TestAccountManager_NetworkUpdates_SavePolicy(t *testing.T) { } }() - _, err := manager.SavePolicy(context.Background(), account.Id, userID, &Policy{ + _, err := manager.SavePolicy(context.Background(), account.Id, userID, &types.Policy{ Enabled: true, - Rules: []*PolicyRule{ + Rules: []*types.PolicyRule{ { Enabled: true, Sources: []string{"groupA"}, Destinations: []string{"groupA"}, Bidirectional: true, - Action: PolicyTrafficActionAccept, + Action: types.PolicyTrafficActionAccept, }, }, }) @@ -1357,7 +1372,7 @@ func TestAccountManager_NetworkUpdates_SavePolicy(t *testing.T) { func TestAccountManager_NetworkUpdates_DeletePeer(t *testing.T) { manager, account, peer1, _, peer3 := setupNetworkMapTest(t) - group := group.Group{ + group := types.Group{ ID: "groupA", Name: "GroupA", Peers: []string{peer1.ID, peer3.ID}, @@ -1367,15 +1382,15 @@ func TestAccountManager_NetworkUpdates_DeletePeer(t *testing.T) { return } - _, err := manager.SavePolicy(context.Background(), account.Id, userID, &Policy{ + _, err := manager.SavePolicy(context.Background(), account.Id, userID, &types.Policy{ Enabled: true, - Rules: []*PolicyRule{ + Rules: []*types.PolicyRule{ { Enabled: true, Sources: []string{"groupA"}, Destinations: []string{"groupA"}, Bidirectional: true, - Action: PolicyTrafficActionAccept, + Action: types.PolicyTrafficActionAccept, }, }, }) @@ -1413,7 +1428,7 @@ func TestAccountManager_NetworkUpdates_DeleteGroup(t *testing.T) { updMsg := manager.peersUpdateManager.CreateChannel(context.Background(), peer1.ID) defer manager.peersUpdateManager.CloseChannel(context.Background(), peer1.ID) - err := manager.SaveGroup(context.Background(), account.Id, userID, &group.Group{ + err := manager.SaveGroup(context.Background(), account.Id, userID, &types.Group{ ID: "groupA", Name: "GroupA", Peers: []string{peer1.ID, peer2.ID, peer3.ID}, @@ -1426,15 +1441,15 @@ func TestAccountManager_NetworkUpdates_DeleteGroup(t *testing.T) { return } - policy, err := manager.SavePolicy(context.Background(), account.Id, userID, &Policy{ + policy, err := manager.SavePolicy(context.Background(), account.Id, userID, &types.Policy{ Enabled: true, - Rules: []*PolicyRule{ + Rules: []*types.PolicyRule{ { Enabled: true, Sources: []string{"groupA"}, Destinations: []string{"groupA"}, Bidirectional: true, - Action: PolicyTrafficActionAccept, + Action: types.PolicyTrafficActionAccept, }, }, }) @@ -1482,7 +1497,7 @@ func TestAccountManager_DeletePeer(t *testing.T) { t.Fatal(err) } - setupKey, err := manager.CreateSetupKey(context.Background(), account.Id, "test-key", SetupKeyReusable, time.Hour, nil, 999, userID, false) + setupKey, err := manager.CreateSetupKey(context.Background(), account.Id, "test-key", types.SetupKeyReusable, time.Hour, nil, 999, userID, false) if err != nil { t.Fatal("error creating setup key") return @@ -1557,7 +1572,7 @@ func TestGetUsersFromAccount(t *testing.T) { t.Fatal(err) } - users := map[string]*User{"1": {Id: "1", Role: UserRoleOwner}, "2": {Id: "2", Role: "user"}, "3": {Id: "3", Role: "user"}} + users := map[string]*types.User{"1": {Id: "1", Role: types.UserRoleOwner}, "2": {Id: "2", Role: "user"}, "3": {Id: "3", Role: "user"}} accountId := "test_account_id" account, err := 
createAccount(manager, accountId, users["1"].Id, "") @@ -1589,7 +1604,7 @@ func TestFileStore_GetRoutesByPrefix(t *testing.T) { if err != nil { t.Fatal(err) } - account := &Account{ + account := &types.Account{ Routes: map[route.ID]*route.Route{ "route-1": { ID: "route-1", @@ -1636,11 +1651,11 @@ func TestAccount_GetRoutesToSync(t *testing.T) { if err != nil { t.Fatal(err) } - account := &Account{ + account := &types.Account{ Peers: map[string]*nbpeer.Peer{ "peer-1": {Key: "peer-1", Meta: nbpeer.PeerSystemMeta{GoOS: "linux"}}, "peer-2": {Key: "peer-2", Meta: nbpeer.PeerSystemMeta{GoOS: "linux"}}, "peer-3": {Key: "peer-1", Meta: nbpeer.PeerSystemMeta{GoOS: "linux"}}, }, - Groups: map[string]*group.Group{"group1": {ID: "group1", Peers: []string{"peer-1", "peer-2"}}}, + Groups: map[string]*types.Group{"group1": {ID: "group1", Peers: []string{"peer-1", "peer-2"}}}, Routes: map[route.ID]*route.Route{ "route-1": { ID: "route-1", @@ -1681,7 +1696,7 @@ func TestAccount_GetRoutesToSync(t *testing.T) { }, } - routes := account.getRoutesToSync(context.Background(), "peer-2", []*nbpeer.Peer{{Key: "peer-1"}, {Key: "peer-3"}}) + routes := account.GetRoutesToSync(context.Background(), "peer-2", []*nbpeer.Peer{{Key: "peer-1"}, {Key: "peer-3"}}) assert.Len(t, routes, 2) routeIDs := make(map[route.ID]struct{}, 2) @@ -1691,26 +1706,26 @@ func TestAccount_GetRoutesToSync(t *testing.T) { assert.Contains(t, routeIDs, route.ID("route-2")) assert.Contains(t, routeIDs, route.ID("route-3")) - emptyRoutes := account.getRoutesToSync(context.Background(), "peer-3", []*nbpeer.Peer{{Key: "peer-1"}, {Key: "peer-2"}}) + emptyRoutes := account.GetRoutesToSync(context.Background(), "peer-3", []*nbpeer.Peer{{Key: "peer-1"}, {Key: "peer-2"}}) assert.Len(t, emptyRoutes, 0) } func TestAccount_Copy(t *testing.T) { - account := &Account{ + account := &types.Account{ Id: "account1", CreatedBy: "tester", CreatedAt: time.Now().UTC(), Domain: "test.com", DomainCategory: "public", IsDomainPrimaryAccount: true, - SetupKeys: map[string]*SetupKey{ + SetupKeys: map[string]*types.SetupKey{ "setup1": { Id: "setup1", AutoGroups: []string{"group1"}, }, }, - Network: &Network{ + Network: &types.Network{ Identifier: "net1", }, Peers: map[string]*nbpeer.Peer{ @@ -1723,12 +1738,12 @@ func TestAccount_Copy(t *testing.T) { }, }, }, - Users: map[string]*User{ + Users: map[string]*types.User{ "user1": { Id: "user1", - Role: UserRoleAdmin, + Role: types.UserRoleAdmin, AutoGroups: []string{"group1"}, - PATs: map[string]*PersonalAccessToken{ + PATs: map[string]*types.PersonalAccessToken{ "pat1": { ID: "pat1", Name: "First PAT", @@ -1741,17 +1756,18 @@ func TestAccount_Copy(t *testing.T) { }, }, }, - Groups: map[string]*group.Group{ + Groups: map[string]*types.Group{ "group1": { - ID: "group1", - Peers: []string{"peer1"}, + ID: "group1", + Peers: []string{"peer1"}, + Resources: []types.Resource{}, }, }, - Policies: []*Policy{ + Policies: []*types.Policy{ { ID: "policy1", Enabled: true, - Rules: make([]*PolicyRule, 0), + Rules: make([]*types.PolicyRule, 0), SourcePostureChecks: make([]string, 0), }, }, @@ -1771,13 +1787,36 @@ func TestAccount_Copy(t *testing.T) { NameServers: []nbdns.NameServer{}, }, }, - DNSSettings: DNSSettings{DisabledManagementGroups: []string{}}, + DNSSettings: types.DNSSettings{DisabledManagementGroups: []string{}}, PostureChecks: []*posture.Checks{ { ID: "posture Checks1", }, }, - Settings: &Settings{}, + Settings: &types.Settings{}, + Networks: []*networkTypes.Network{ + { + ID: "network1", + }, + }, + NetworkRouters: 
[]*routerTypes.NetworkRouter{ + { + ID: "router1", + NetworkID: "network1", + PeerGroups: []string{"group1"}, + Masquerade: false, + Metric: 0, + }, + }, + NetworkResources: []*resourceTypes.NetworkResource{ + { + ID: "resource1", + NetworkID: "network1", + Name: "resource", + Type: "Subnet", + Address: "172.12.6.1/24", + }, + }, } err := hasNilField(account) if err != nil { @@ -1830,7 +1869,7 @@ func TestDefaultAccountManager_DefaultAccountSettings(t *testing.T) { accountID, err := manager.GetAccountIDByUserID(context.Background(), userID, "") require.NoError(t, err, "unable to create an account") - settings, err := manager.Store.GetAccountSettings(context.Background(), LockingStrengthShare, accountID) + settings, err := manager.Store.GetAccountSettings(context.Background(), store.LockingStrengthShare, accountID) require.NoError(t, err, "unable to get account settings") assert.NotNil(t, settings) @@ -1863,7 +1902,7 @@ func TestDefaultAccountManager_UpdatePeer_PeerLoginExpiration(t *testing.T) { err = manager.MarkPeerConnected(context.Background(), key.PublicKey().String(), true, nil, account) require.NoError(t, err, "unable to mark peer connected") - account, err = manager.UpdateAccountSettings(context.Background(), accountID, userID, &Settings{ + account, err = manager.UpdateAccountSettings(context.Background(), accountID, userID, &types.Settings{ PeerLoginExpiration: time.Hour, PeerLoginExpirationEnabled: true, }) @@ -1911,7 +1950,7 @@ func TestDefaultAccountManager_MarkPeerConnected_PeerLoginExpiration(t *testing. LoginExpirationEnabled: true, }) require.NoError(t, err, "unable to add peer") - _, err = manager.UpdateAccountSettings(context.Background(), accountID, userID, &Settings{ + _, err = manager.UpdateAccountSettings(context.Background(), accountID, userID, &types.Settings{ PeerLoginExpiration: time.Hour, PeerLoginExpirationEnabled: true, }) @@ -1980,7 +2019,7 @@ func TestDefaultAccountManager_UpdateAccountSettings_PeerLoginExpiration(t *test }, } // enabling PeerLoginExpirationEnabled should trigger the expiration job - account, err = manager.UpdateAccountSettings(context.Background(), account.Id, userID, &Settings{ + account, err = manager.UpdateAccountSettings(context.Background(), account.Id, userID, &types.Settings{ PeerLoginExpiration: time.Hour, PeerLoginExpirationEnabled: true, }) @@ -1993,7 +2032,7 @@ func TestDefaultAccountManager_UpdateAccountSettings_PeerLoginExpiration(t *test wg.Add(1) // disabling PeerLoginExpirationEnabled should trigger cancel - _, err = manager.UpdateAccountSettings(context.Background(), account.Id, userID, &Settings{ + _, err = manager.UpdateAccountSettings(context.Background(), account.Id, userID, &types.Settings{ PeerLoginExpiration: time.Hour, PeerLoginExpirationEnabled: false, }) @@ -2011,7 +2050,7 @@ func TestDefaultAccountManager_UpdateAccountSettings(t *testing.T) { accountID, err := manager.GetAccountIDByUserID(context.Background(), userID, "") require.NoError(t, err, "unable to create an account") - updated, err := manager.UpdateAccountSettings(context.Background(), accountID, userID, &Settings{ + updated, err := manager.UpdateAccountSettings(context.Background(), accountID, userID, &types.Settings{ PeerLoginExpiration: time.Hour, PeerLoginExpirationEnabled: false, }) @@ -2019,19 +2058,19 @@ func TestDefaultAccountManager_UpdateAccountSettings(t *testing.T) { assert.False(t, updated.Settings.PeerLoginExpirationEnabled) assert.Equal(t, updated.Settings.PeerLoginExpiration, time.Hour) - settings, err := 
manager.Store.GetAccountSettings(context.Background(), LockingStrengthShare, accountID) + settings, err := manager.Store.GetAccountSettings(context.Background(), store.LockingStrengthShare, accountID) require.NoError(t, err, "unable to get account settings") assert.False(t, settings.PeerLoginExpirationEnabled) assert.Equal(t, settings.PeerLoginExpiration, time.Hour) - _, err = manager.UpdateAccountSettings(context.Background(), accountID, userID, &Settings{ + _, err = manager.UpdateAccountSettings(context.Background(), accountID, userID, &types.Settings{ PeerLoginExpiration: time.Second, PeerLoginExpirationEnabled: false, }) require.Error(t, err, "expecting to fail when providing PeerLoginExpiration less than one hour") - _, err = manager.UpdateAccountSettings(context.Background(), accountID, userID, &Settings{ + _, err = manager.UpdateAccountSettings(context.Background(), accountID, userID, &types.Settings{ PeerLoginExpiration: time.Hour * 24 * 181, PeerLoginExpirationEnabled: false, }) @@ -2104,9 +2143,9 @@ func TestAccount_GetExpiredPeers(t *testing.T) { for _, testCase := range testCases { t.Run(testCase.name, func(t *testing.T) { - account := &Account{ + account := &types.Account{ Peers: testCase.peers, - Settings: &Settings{ + Settings: &types.Settings{ PeerLoginExpirationEnabled: true, PeerLoginExpiration: time.Hour, }, @@ -2188,9 +2227,9 @@ func TestAccount_GetInactivePeers(t *testing.T) { for _, testCase := range testCases { t.Run(testCase.name, func(t *testing.T) { - account := &Account{ + account := &types.Account{ Peers: testCase.peers, - Settings: &Settings{ + Settings: &types.Settings{ PeerInactivityExpirationEnabled: true, PeerInactivityExpiration: time.Second, }, @@ -2255,7 +2294,7 @@ func TestAccount_GetPeersWithExpiration(t *testing.T) { for _, testCase := range testCases { t.Run(testCase.name, func(t *testing.T) { - account := &Account{ + account := &types.Account{ Peers: testCase.peers, } @@ -2324,7 +2363,7 @@ func TestAccount_GetPeersWithInactivity(t *testing.T) { for _, testCase := range testCases { t.Run(testCase.name, func(t *testing.T) { - account := &Account{ + account := &types.Account{ Peers: testCase.peers, } @@ -2488,9 +2527,9 @@ func TestAccount_GetNextPeerExpiration(t *testing.T) { } for _, testCase := range testCases { t.Run(testCase.name, func(t *testing.T) { - account := &Account{ + account := &types.Account{ Peers: testCase.peers, - Settings: &Settings{PeerLoginExpiration: testCase.expiration, PeerLoginExpirationEnabled: testCase.expirationEnabled}, + Settings: &types.Settings{PeerLoginExpiration: testCase.expiration, PeerLoginExpirationEnabled: testCase.expirationEnabled}, } expiration, ok := account.GetNextPeerExpiration() @@ -2648,9 +2687,9 @@ func TestAccount_GetNextInactivePeerExpiration(t *testing.T) { } for _, testCase := range testCases { t.Run(testCase.name, func(t *testing.T) { - account := &Account{ + account := &types.Account{ Peers: testCase.peers, - Settings: &Settings{PeerInactivityExpiration: testCase.expiration, PeerInactivityExpirationEnabled: testCase.expirationEnabled}, + Settings: &types.Settings{PeerInactivityExpiration: testCase.expiration, PeerInactivityExpirationEnabled: testCase.expirationEnabled}, } expiration, ok := account.GetNextInactivePeerExpiration() @@ -2669,7 +2708,7 @@ func TestAccount_SetJWTGroups(t *testing.T) { require.NoError(t, err, "unable to create account manager") // create a new account - account := &Account{ + account := &types.Account{ Id: "accountID", Peers: map[string]*nbpeer.Peer{ "peer1": {ID: "peer1", 
Key: "key1", UserID: "user1"}, @@ -2678,11 +2717,11 @@ func TestAccount_SetJWTGroups(t *testing.T) { "peer4": {ID: "peer4", Key: "key4", UserID: "user2"}, "peer5": {ID: "peer5", Key: "key5", UserID: "user2"}, }, - Groups: map[string]*group.Group{ - "group1": {ID: "group1", Name: "group1", Issued: group.GroupIssuedAPI, Peers: []string{}}, + Groups: map[string]*types.Group{ + "group1": {ID: "group1", Name: "group1", Issued: types.GroupIssuedAPI, Peers: []string{}}, }, - Settings: &Settings{GroupsPropagationEnabled: true, JWTGroupsEnabled: true, JWTGroupsClaimName: "groups"}, - Users: map[string]*User{ + Settings: &types.Settings{GroupsPropagationEnabled: true, JWTGroupsEnabled: true, JWTGroupsClaimName: "groups"}, + Users: map[string]*types.User{ "user1": {Id: "user1", AccountID: "accountID"}, "user2": {Id: "user2", AccountID: "accountID"}, }, @@ -2698,7 +2737,7 @@ func TestAccount_SetJWTGroups(t *testing.T) { err := manager.syncJWTGroups(context.Background(), "accountID", claims) assert.NoError(t, err, "unable to sync jwt groups") - user, err := manager.Store.GetUserByUserID(context.Background(), LockingStrengthShare, "user1") + user, err := manager.Store.GetUserByUserID(context.Background(), store.LockingStrengthShare, "user1") assert.NoError(t, err, "unable to get user") assert.Empty(t, user.AutoGroups, "auto groups must be empty") }) @@ -2711,18 +2750,18 @@ func TestAccount_SetJWTGroups(t *testing.T) { err := manager.syncJWTGroups(context.Background(), "accountID", claims) assert.NoError(t, err, "unable to sync jwt groups") - user, err := manager.Store.GetUserByUserID(context.Background(), LockingStrengthShare, "user1") + user, err := manager.Store.GetUserByUserID(context.Background(), store.LockingStrengthShare, "user1") assert.NoError(t, err, "unable to get user") assert.Len(t, user.AutoGroups, 0) - group1, err := manager.Store.GetGroupByID(context.Background(), LockingStrengthShare, "accountID", "group1") + group1, err := manager.Store.GetGroupByID(context.Background(), store.LockingStrengthShare, "accountID", "group1") assert.NoError(t, err, "unable to get group") - assert.Equal(t, group1.Issued, group.GroupIssuedAPI, "group should be api issued") + assert.Equal(t, group1.Issued, types.GroupIssuedAPI, "group should be api issued") }) t.Run("jwt match existing api group in user auto groups", func(t *testing.T) { account.Users["user1"].AutoGroups = []string{"group1"} - assert.NoError(t, manager.Store.SaveUser(context.Background(), LockingStrengthUpdate, account.Users["user1"])) + assert.NoError(t, manager.Store.SaveUser(context.Background(), store.LockingStrengthUpdate, account.Users["user1"])) claims := jwtclaims.AuthorizationClaims{ UserId: "user1", @@ -2731,13 +2770,13 @@ func TestAccount_SetJWTGroups(t *testing.T) { err = manager.syncJWTGroups(context.Background(), "accountID", claims) assert.NoError(t, err, "unable to sync jwt groups") - user, err := manager.Store.GetUserByUserID(context.Background(), LockingStrengthShare, "user1") + user, err := manager.Store.GetUserByUserID(context.Background(), store.LockingStrengthShare, "user1") assert.NoError(t, err, "unable to get user") assert.Len(t, user.AutoGroups, 1) - group1, err := manager.Store.GetGroupByID(context.Background(), LockingStrengthShare, "accountID", "group1") + group1, err := manager.Store.GetGroupByID(context.Background(), store.LockingStrengthShare, "accountID", "group1") assert.NoError(t, err, "unable to get group") - assert.Equal(t, group1.Issued, group.GroupIssuedAPI, "group should be api issued") + assert.Equal(t, 
group1.Issued, types.GroupIssuedAPI, "group should be api issued") }) t.Run("add jwt group", func(t *testing.T) { @@ -2748,7 +2787,7 @@ func TestAccount_SetJWTGroups(t *testing.T) { err = manager.syncJWTGroups(context.Background(), "accountID", claims) assert.NoError(t, err, "unable to sync jwt groups") - user, err := manager.Store.GetUserByUserID(context.Background(), LockingStrengthShare, "user1") + user, err := manager.Store.GetUserByUserID(context.Background(), store.LockingStrengthShare, "user1") assert.NoError(t, err, "unable to get user") assert.Len(t, user.AutoGroups, 2, "groups count should not be change") }) @@ -2761,7 +2800,7 @@ func TestAccount_SetJWTGroups(t *testing.T) { err = manager.syncJWTGroups(context.Background(), "accountID", claims) assert.NoError(t, err, "unable to sync jwt groups") - user, err := manager.Store.GetUserByUserID(context.Background(), LockingStrengthShare, "user1") + user, err := manager.Store.GetUserByUserID(context.Background(), store.LockingStrengthShare, "user1") assert.NoError(t, err, "unable to get user") assert.Len(t, user.AutoGroups, 2, "groups count should not be change") }) @@ -2774,11 +2813,11 @@ func TestAccount_SetJWTGroups(t *testing.T) { err = manager.syncJWTGroups(context.Background(), "accountID", claims) assert.NoError(t, err, "unable to sync jwt groups") - groups, err := manager.Store.GetAccountGroups(context.Background(), LockingStrengthShare, "accountID") + groups, err := manager.Store.GetAccountGroups(context.Background(), store.LockingStrengthShare, "accountID") assert.NoError(t, err) assert.Len(t, groups, 3, "new group3 should be added") - user, err := manager.Store.GetUserByUserID(context.Background(), LockingStrengthShare, "user2") + user, err := manager.Store.GetUserByUserID(context.Background(), store.LockingStrengthShare, "user2") assert.NoError(t, err, "unable to get user") assert.Len(t, user.AutoGroups, 1, "new group should be added") }) @@ -2791,7 +2830,7 @@ func TestAccount_SetJWTGroups(t *testing.T) { err = manager.syncJWTGroups(context.Background(), "accountID", claims) assert.NoError(t, err, "unable to sync jwt groups") - user, err := manager.Store.GetUserByUserID(context.Background(), LockingStrengthShare, "user1") + user, err := manager.Store.GetUserByUserID(context.Background(), store.LockingStrengthShare, "user1") assert.NoError(t, err, "unable to get user") assert.Len(t, user.AutoGroups, 1, "only non-JWT groups should remain") assert.Contains(t, user.AutoGroups, "group1", " group1 should still be present") @@ -2799,7 +2838,7 @@ func TestAccount_SetJWTGroups(t *testing.T) { } func TestAccount_UserGroupsAddToPeers(t *testing.T) { - account := &Account{ + account := &types.Account{ Peers: map[string]*nbpeer.Peer{ "peer1": {ID: "peer1", Key: "key1", UserID: "user1"}, "peer2": {ID: "peer2", Key: "key2", UserID: "user1"}, @@ -2807,12 +2846,12 @@ func TestAccount_UserGroupsAddToPeers(t *testing.T) { "peer4": {ID: "peer4", Key: "key4", UserID: "user2"}, "peer5": {ID: "peer5", Key: "key5", UserID: "user2"}, }, - Groups: map[string]*group.Group{ - "group1": {ID: "group1", Name: "group1", Issued: group.GroupIssuedAPI, Peers: []string{}}, - "group2": {ID: "group2", Name: "group2", Issued: group.GroupIssuedAPI, Peers: []string{}}, - "group3": {ID: "group3", Name: "group3", Issued: group.GroupIssuedAPI, Peers: []string{}}, + Groups: map[string]*types.Group{ + "group1": {ID: "group1", Name: "group1", Issued: types.GroupIssuedAPI, Peers: []string{}}, + "group2": {ID: "group2", Name: "group2", Issued: types.GroupIssuedAPI, Peers: 
[]string{}}, + "group3": {ID: "group3", Name: "group3", Issued: types.GroupIssuedAPI, Peers: []string{}}, }, - Users: map[string]*User{"user1": {Id: "user1"}, "user2": {Id: "user2"}}, + Users: map[string]*types.User{"user1": {Id: "user1"}, "user2": {Id: "user2"}}, } t.Run("add groups", func(t *testing.T) { @@ -2835,7 +2874,7 @@ func TestAccount_UserGroupsAddToPeers(t *testing.T) { } func TestAccount_UserGroupsRemoveFromPeers(t *testing.T) { - account := &Account{ + account := &types.Account{ Peers: map[string]*nbpeer.Peer{ "peer1": {ID: "peer1", Key: "key1", UserID: "user1"}, "peer2": {ID: "peer2", Key: "key2", UserID: "user1"}, @@ -2843,12 +2882,12 @@ func TestAccount_UserGroupsRemoveFromPeers(t *testing.T) { "peer4": {ID: "peer4", Key: "key4", UserID: "user2"}, "peer5": {ID: "peer5", Key: "key5", UserID: "user2"}, }, - Groups: map[string]*group.Group{ - "group1": {ID: "group1", Name: "group1", Issued: group.GroupIssuedAPI, Peers: []string{"peer1", "peer2", "peer3"}}, - "group2": {ID: "group2", Name: "group2", Issued: group.GroupIssuedAPI, Peers: []string{"peer1", "peer2", "peer3", "peer4", "peer5"}}, - "group3": {ID: "group3", Name: "group3", Issued: group.GroupIssuedAPI, Peers: []string{"peer4", "peer5"}}, + Groups: map[string]*types.Group{ + "group1": {ID: "group1", Name: "group1", Issued: types.GroupIssuedAPI, Peers: []string{"peer1", "peer2", "peer3"}}, + "group2": {ID: "group2", Name: "group2", Issued: types.GroupIssuedAPI, Peers: []string{"peer1", "peer2", "peer3", "peer4", "peer5"}}, + "group3": {ID: "group3", Name: "group3", Issued: types.GroupIssuedAPI, Peers: []string{"peer4", "peer5"}}, }, - Users: map[string]*User{"user1": {Id: "user1"}, "user2": {Id: "user2"}}, + Users: map[string]*types.User{"user1": {Id: "user1"}, "user2": {Id: "user2"}}, } t.Run("remove groups", func(t *testing.T) { @@ -2891,10 +2930,10 @@ func createManager(t TB) (*DefaultAccountManager, error) { return manager, nil } -func createStore(t TB) (Store, error) { +func createStore(t TB) (store.Store, error) { t.Helper() dataDir := t.TempDir() - store, cleanUp, err := NewTestStoreFromSQL(context.Background(), "", dataDir) + store, cleanUp, err := store.NewTestStoreFromSQL(context.Background(), "", dataDir) if err != nil { return nil, err } @@ -2917,7 +2956,7 @@ func waitTimeout(wg *sync.WaitGroup, timeout time.Duration) bool { } } -func setupNetworkMapTest(t *testing.T) (*DefaultAccountManager, *Account, *nbpeer.Peer, *nbpeer.Peer, *nbpeer.Peer) { +func setupNetworkMapTest(t *testing.T) (*DefaultAccountManager, *types.Account, *nbpeer.Peer, *nbpeer.Peer, *nbpeer.Peer) { t.Helper() manager, err := createManager(t) @@ -2930,12 +2969,12 @@ func setupNetworkMapTest(t *testing.T) (*DefaultAccountManager, *Account, *nbpee t.Fatal(err) } - setupKey, err := manager.CreateSetupKey(context.Background(), account.Id, "test-key", SetupKeyReusable, time.Hour, nil, 999, userID, false) + setupKey, err := manager.CreateSetupKey(context.Background(), account.Id, "test-key", types.SetupKeyReusable, time.Hour, nil, 999, userID, false) if err != nil { t.Fatal("error creating setup key") } - getPeer := func(manager *DefaultAccountManager, setupKey *SetupKey) *nbpeer.Peer { + getPeer := func(manager *DefaultAccountManager, setupKey *types.SetupKey) *nbpeer.Peer { key, err := wgtypes.GeneratePrivateKey() if err != nil { t.Fatal(err) @@ -2999,11 +3038,11 @@ func BenchmarkSyncAndMarkPeer(b *testing.B) { maxMsPerOpCICD float64 }{ {"Small", 50, 5, 1, 3, 3, 10}, - {"Medium", 500, 100, 7, 13, 10, 60}, - {"Large", 5000, 200, 65, 80, 60, 
170}, - {"Small single", 50, 10, 1, 3, 3, 60}, + {"Medium", 500, 100, 7, 13, 10, 70}, + {"Large", 5000, 200, 65, 80, 60, 200}, + {"Small single", 50, 10, 1, 3, 3, 70}, {"Medium single", 500, 10, 7, 13, 10, 26}, - {"Large 5", 5000, 15, 65, 80, 60, 170}, + {"Large 5", 5000, 15, 65, 80, 60, 200}, } log.SetOutput(io.Discard) @@ -3047,7 +3086,7 @@ func BenchmarkSyncAndMarkPeer(b *testing.B) { b.Fatalf("Benchmark %s failed: too fast (%.2f ms/op, minimum %.2f ms/op)", bc.name, msPerOp, minExpected) } - if msPerOp > maxExpected { + if msPerOp > (maxExpected * 1.1) { b.Fatalf("Benchmark %s failed: too slow (%.2f ms/op, maximum %.2f ms/op)", bc.name, msPerOp, maxExpected) } }) @@ -3067,7 +3106,7 @@ func BenchmarkLoginPeer_ExistingPeer(b *testing.B) { }{ {"Small", 50, 5, 102, 110, 102, 120}, {"Medium", 500, 100, 105, 140, 105, 170}, - {"Large", 5000, 200, 160, 200, 160, 270}, + {"Large", 5000, 200, 160, 200, 160, 300}, {"Small single", 50, 10, 102, 110, 102, 120}, {"Medium single", 500, 10, 105, 140, 105, 170}, {"Large 5", 5000, 15, 160, 200, 160, 270}, @@ -3121,7 +3160,7 @@ func BenchmarkLoginPeer_ExistingPeer(b *testing.B) { b.Fatalf("Benchmark %s failed: too fast (%.2f ms/op, minimum %.2f ms/op)", bc.name, msPerOp, minExpected) } - if msPerOp > maxExpected { + if msPerOp > (maxExpected * 1.1) { b.Fatalf("Benchmark %s failed: too slow (%.2f ms/op, maximum %.2f ms/op)", bc.name, msPerOp, maxExpected) } }) @@ -3139,10 +3178,10 @@ func BenchmarkLoginPeer_NewPeer(b *testing.B) { minMsPerOpCICD float64 maxMsPerOpCICD float64 }{ - {"Small", 50, 5, 107, 120, 107, 140}, - {"Medium", 500, 100, 105, 140, 105, 170}, - {"Large", 5000, 200, 180, 220, 180, 340}, - {"Small single", 50, 10, 107, 120, 105, 140}, + {"Small", 50, 5, 107, 120, 107, 160}, + {"Medium", 500, 100, 105, 140, 105, 190}, + {"Large", 5000, 200, 180, 220, 180, 350}, + {"Small single", 50, 10, 107, 120, 105, 160}, {"Medium single", 500, 10, 105, 140, 105, 170}, {"Large 5", 5000, 15, 180, 220, 180, 340}, } @@ -3195,7 +3234,7 @@ func BenchmarkLoginPeer_NewPeer(b *testing.B) { b.Fatalf("Benchmark %s failed: too fast (%.2f ms/op, minimum %.2f ms/op)", bc.name, msPerOp, minExpected) } - if msPerOp > maxExpected { + if msPerOp > (maxExpected * 1.1) { b.Fatalf("Benchmark %s failed: too slow (%.2f ms/op, maximum %.2f ms/op)", bc.name, msPerOp, maxExpected) } }) diff --git a/management/server/activity/codes.go b/management/server/activity/codes.go index 4c57d65fb..5379a8dd8 100644 --- a/management/server/activity/codes.go +++ b/management/server/activity/codes.go @@ -151,6 +151,24 @@ const ( UserGroupPropagationEnabled Activity = 69 UserGroupPropagationDisabled Activity = 70 + + AccountRoutingPeerDNSResolutionEnabled Activity = 71 + AccountRoutingPeerDNSResolutionDisabled Activity = 72 + + NetworkCreated Activity = 73 + NetworkUpdated Activity = 74 + NetworkDeleted Activity = 75 + + NetworkResourceCreated Activity = 76 + NetworkResourceUpdated Activity = 77 + NetworkResourceDeleted Activity = 78 + + NetworkRouterCreated Activity = 79 + NetworkRouterUpdated Activity = 80 + NetworkRouterDeleted Activity = 81 + + ResourceAddedToGroup Activity = 82 + ResourceRemovedFromGroup Activity = 83 ) var activityMap = map[Activity]Code{ @@ -228,6 +246,24 @@ var activityMap = map[Activity]Code{ UserGroupPropagationEnabled: {"User group propagation enabled", "account.setting.group.propagation.enable"}, UserGroupPropagationDisabled: {"User group propagation disabled", "account.setting.group.propagation.disable"}, + + AccountRoutingPeerDNSResolutionEnabled: {"Account 
routing peer DNS resolution enabled", "account.setting.routing.peer.dns.resolution.enable"}, + AccountRoutingPeerDNSResolutionDisabled: {"Account routing peer DNS resolution disabled", "account.setting.routing.peer.dns.resolution.disable"}, + + NetworkCreated: {"Network created", "network.create"}, + NetworkUpdated: {"Network updated", "network.update"}, + NetworkDeleted: {"Network deleted", "network.delete"}, + + NetworkResourceCreated: {"Network resource created", "network.resource.create"}, + NetworkResourceUpdated: {"Network resource updated", "network.resource.update"}, + NetworkResourceDeleted: {"Network resource deleted", "network.resource.delete"}, + + NetworkRouterCreated: {"Network router created", "network.router.create"}, + NetworkRouterUpdated: {"Network router updated", "network.router.update"}, + NetworkRouterDeleted: {"Network router deleted", "network.router.delete"}, + + ResourceAddedToGroup: {"Resource added to group", "resource.group.add"}, + ResourceRemovedFromGroup: {"Resource removed from group", "resource.group.delete"}, } // StringCode returns a string code of the activity diff --git a/management/server/config.go b/management/server/config.go index 2f7e49766..f3555b92b 100644 --- a/management/server/config.go +++ b/management/server/config.go @@ -5,6 +5,7 @@ import ( "net/url" "github.com/netbirdio/netbird/management/server/idp" + "github.com/netbirdio/netbird/management/server/store" "github.com/netbirdio/netbird/util" ) @@ -156,7 +157,7 @@ type ProviderConfig struct { // StoreConfig contains Store configuration type StoreConfig struct { - Engine StoreEngine + Engine store.Engine } // ReverseProxy contains reverse proxy configuration in front of management. diff --git a/management/server/dns.go b/management/server/dns.go index 8df211b0b..39dc11eb2 100644 --- a/management/server/dns.go +++ b/management/server/dns.go @@ -2,9 +2,7 @@ package server import ( "context" - "fmt" "slices" - "strconv" "sync" log "github.com/sirupsen/logrus" @@ -12,12 +10,12 @@ import ( nbdns "github.com/netbirdio/netbird/dns" "github.com/netbirdio/netbird/management/proto" "github.com/netbirdio/netbird/management/server/activity" - nbpeer "github.com/netbirdio/netbird/management/server/peer" "github.com/netbirdio/netbird/management/server/status" + "github.com/netbirdio/netbird/management/server/store" + "github.com/netbirdio/netbird/management/server/types" + "github.com/netbirdio/netbird/management/server/util" ) -const defaultTTL = 300 - // DNSConfigCache is a thread-safe cache for DNS configuration components type DNSConfigCache struct { CustomZones sync.Map @@ -62,26 +60,9 @@ func (c *DNSConfigCache) SetNameServerGroup(key string, value *proto.NameServerG c.NameServerGroups.Store(key, value) } -type lookupMap map[string]struct{} - -// DNSSettings defines dns settings at the account level -type DNSSettings struct { - // DisabledManagementGroups groups whose DNS management is disabled - DisabledManagementGroups []string `gorm:"serializer:json"` -} - -// Copy returns a copy of the DNS settings -func (d DNSSettings) Copy() DNSSettings { - settings := DNSSettings{ - DisabledManagementGroups: make([]string, len(d.DisabledManagementGroups)), - } - copy(settings.DisabledManagementGroups, d.DisabledManagementGroups) - return settings -} - // GetDNSSettings validates a user role and returns the DNS settings for the provided account ID -func (am *DefaultAccountManager) GetDNSSettings(ctx context.Context, accountID string, userID string) (*DNSSettings, error) { - user, err := 
am.Store.GetUserByUserID(ctx, LockingStrengthShare, userID) +func (am *DefaultAccountManager) GetDNSSettings(ctx context.Context, accountID string, userID string) (*types.DNSSettings, error) { + user, err := am.Store.GetUserByUserID(ctx, store.LockingStrengthShare, userID) if err != nil { return nil, err } @@ -94,16 +75,16 @@ func (am *DefaultAccountManager) GetDNSSettings(ctx context.Context, accountID s return nil, status.NewAdminPermissionError() } - return am.Store.GetAccountDNSSettings(ctx, LockingStrengthShare, accountID) + return am.Store.GetAccountDNSSettings(ctx, store.LockingStrengthShare, accountID) } // SaveDNSSettings validates a user role and updates the account's DNS settings -func (am *DefaultAccountManager) SaveDNSSettings(ctx context.Context, accountID string, userID string, dnsSettingsToSave *DNSSettings) error { +func (am *DefaultAccountManager) SaveDNSSettings(ctx context.Context, accountID string, userID string, dnsSettingsToSave *types.DNSSettings) error { if dnsSettingsToSave == nil { return status.Errorf(status.InvalidArgument, "the dns settings provided are nil") } - user, err := am.Store.GetUserByUserID(ctx, LockingStrengthShare, userID) + user, err := am.Store.GetUserByUserID(ctx, store.LockingStrengthShare, userID) if err != nil { return err } @@ -119,18 +100,18 @@ func (am *DefaultAccountManager) SaveDNSSettings(ctx context.Context, accountID var updateAccountPeers bool var eventsToStore []func() - err = am.Store.ExecuteInTransaction(ctx, func(transaction Store) error { + err = am.Store.ExecuteInTransaction(ctx, func(transaction store.Store) error { if err = validateDNSSettings(ctx, transaction, accountID, dnsSettingsToSave); err != nil { return err } - oldSettings, err := transaction.GetAccountDNSSettings(ctx, LockingStrengthUpdate, accountID) + oldSettings, err := transaction.GetAccountDNSSettings(ctx, store.LockingStrengthUpdate, accountID) if err != nil { return err } - addedGroups := difference(dnsSettingsToSave.DisabledManagementGroups, oldSettings.DisabledManagementGroups) - removedGroups := difference(oldSettings.DisabledManagementGroups, dnsSettingsToSave.DisabledManagementGroups) + addedGroups := util.Difference(dnsSettingsToSave.DisabledManagementGroups, oldSettings.DisabledManagementGroups) + removedGroups := util.Difference(oldSettings.DisabledManagementGroups, dnsSettingsToSave.DisabledManagementGroups) updateAccountPeers, err = areDNSSettingChangesAffectPeers(ctx, transaction, accountID, addedGroups, removedGroups) if err != nil { @@ -140,11 +121,11 @@ func (am *DefaultAccountManager) SaveDNSSettings(ctx context.Context, accountID events := am.prepareDNSSettingsEvents(ctx, transaction, accountID, userID, addedGroups, removedGroups) eventsToStore = append(eventsToStore, events...) - if err = transaction.IncrementNetworkSerial(ctx, LockingStrengthUpdate, accountID); err != nil { + if err = transaction.IncrementNetworkSerial(ctx, store.LockingStrengthUpdate, accountID); err != nil { return err } - return transaction.SaveDNSSettings(ctx, LockingStrengthUpdate, accountID, dnsSettingsToSave) + return transaction.SaveDNSSettings(ctx, store.LockingStrengthUpdate, accountID, dnsSettingsToSave) }) if err != nil { return err @@ -155,18 +136,18 @@ func (am *DefaultAccountManager) SaveDNSSettings(ctx context.Context, accountID } if updateAccountPeers { - am.updateAccountPeers(ctx, accountID) + am.UpdateAccountPeers(ctx, accountID) } return nil } // prepareDNSSettingsEvents prepares a list of event functions to be stored. 
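The addedGroups/removedGroups computation above relies on util.Difference, whose body is not shown in this hunk; the sketch below assumes it keeps the semantics of the server-local difference helper that this patch removes from group.go further down, i.e. it returns the elements of a that are not in b. Runnable on its own:

package main

import "fmt"

// difference mirrors the assumed behaviour of util.Difference: the elements
// of a that are not present in b.
func difference(a, b []string) []string {
	mb := make(map[string]struct{}, len(b))
	for _, x := range b {
		mb[x] = struct{}{}
	}
	var diff []string
	for _, x := range a {
		if _, found := mb[x]; !found {
			diff = append(diff, x)
		}
	}
	return diff
}

func main() {
	oldGroups := []string{"groupA", "groupB"}
	newGroups := []string{"groupB", "groupC"}
	fmt.Println(difference(newGroups, oldGroups)) // [groupC] -> addedGroups
	fmt.Println(difference(oldGroups, newGroups)) // [groupA] -> removedGroups
}
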
-func (am *DefaultAccountManager) prepareDNSSettingsEvents(ctx context.Context, transaction Store, accountID, userID string, addedGroups, removedGroups []string) []func() { +func (am *DefaultAccountManager) prepareDNSSettingsEvents(ctx context.Context, transaction store.Store, accountID, userID string, addedGroups, removedGroups []string) []func() { var eventsToStore []func() modifiedGroups := slices.Concat(addedGroups, removedGroups) - groups, err := transaction.GetGroupsByIDs(ctx, LockingStrengthShare, accountID, modifiedGroups) + groups, err := transaction.GetGroupsByIDs(ctx, store.LockingStrengthShare, accountID, modifiedGroups) if err != nil { log.WithContext(ctx).Debugf("failed to get groups for dns settings events: %v", err) return nil @@ -203,8 +184,8 @@ func (am *DefaultAccountManager) prepareDNSSettingsEvents(ctx context.Context, t } // areDNSSettingChangesAffectPeers checks if the DNS settings changes affect any peers. -func areDNSSettingChangesAffectPeers(ctx context.Context, transaction Store, accountID string, addedGroups, removedGroups []string) (bool, error) { - hasPeers, err := anyGroupHasPeers(ctx, transaction, accountID, addedGroups) +func areDNSSettingChangesAffectPeers(ctx context.Context, transaction store.Store, accountID string, addedGroups, removedGroups []string) (bool, error) { + hasPeers, err := anyGroupHasPeersOrResources(ctx, transaction, accountID, addedGroups) if err != nil { return false, err } @@ -213,16 +194,16 @@ func areDNSSettingChangesAffectPeers(ctx context.Context, transaction Store, acc return true, nil } - return anyGroupHasPeers(ctx, transaction, accountID, removedGroups) + return anyGroupHasPeersOrResources(ctx, transaction, accountID, removedGroups) } // validateDNSSettings validates the DNS settings. -func validateDNSSettings(ctx context.Context, transaction Store, accountID string, settings *DNSSettings) error { +func validateDNSSettings(ctx context.Context, transaction store.Store, accountID string, settings *types.DNSSettings) error { if len(settings.DisabledManagementGroups) == 0 { return nil } - groups, err := transaction.GetGroupsByIDs(ctx, LockingStrengthShare, accountID, settings.DisabledManagementGroups) + groups, err := transaction.GetGroupsByIDs(ctx, store.LockingStrengthShare, accountID, settings.DisabledManagementGroups) if err != nil { return err } @@ -298,81 +279,3 @@ func convertToProtoNameServerGroup(nsGroup *nbdns.NameServerGroup) *proto.NameSe } return protoGroup } - -func getPeerNSGroups(account *Account, peerID string) []*nbdns.NameServerGroup { - groupList := account.getPeerGroups(peerID) - - var peerNSGroups []*nbdns.NameServerGroup - - for _, nsGroup := range account.NameServerGroups { - if !nsGroup.Enabled { - continue - } - for _, gID := range nsGroup.Groups { - _, found := groupList[gID] - if found { - if !peerIsNameserver(account.GetPeer(peerID), nsGroup) { - peerNSGroups = append(peerNSGroups, nsGroup.Copy()) - break - } - } - } - } - - return peerNSGroups -} - -// peerIsNameserver returns true if the peer is a nameserver for a nsGroup -func peerIsNameserver(peer *nbpeer.Peer, nsGroup *nbdns.NameServerGroup) bool { - for _, ns := range nsGroup.NameServers { - if peer.IP.Equal(ns.IP.AsSlice()) { - return true - } - } - return false -} - -func addPeerLabelsToAccount(ctx context.Context, account *Account, peerLabels lookupMap) { - for _, peer := range account.Peers { - label, err := getPeerHostLabel(peer.Name, peerLabels) - if err != nil { - log.WithContext(ctx).Errorf("got an error while generating a peer host label. 
Peer name %s, error: %v. Trying with the peer's meta hostname", peer.Name, err) - label, err = getPeerHostLabel(peer.Meta.Hostname, peerLabels) - if err != nil { - log.WithContext(ctx).Errorf("got another error while generating a peer host label with hostname. Peer hostname %s, error: %v. Skipping", peer.Meta.Hostname, err) - continue - } - } - peer.DNSLabel = label - peerLabels[label] = struct{}{} - } -} - -func getPeerHostLabel(name string, peerLabels lookupMap) (string, error) { - label, err := nbdns.GetParsedDomainLabel(name) - if err != nil { - return "", err - } - - uniqueLabel := getUniqueHostLabel(label, peerLabels) - if uniqueLabel == "" { - return "", fmt.Errorf("couldn't find a unique valid label for %s, parsed label %s", name, label) - } - return uniqueLabel, nil -} - -// getUniqueHostLabel look for a unique host label, and if doesn't find add a suffix up to 999 -func getUniqueHostLabel(name string, peerLabels lookupMap) string { - _, found := peerLabels[name] - if !found { - return name - } - for i := 1; i < 1000; i++ { - nameWithSuffix := name + "-" + strconv.Itoa(i) - _, found = peerLabels[nameWithSuffix] - if !found { - return nameWithSuffix - } - } - return "" -} diff --git a/management/server/dns_test.go b/management/server/dns_test.go index 8a66da96c..6fb9f6a29 100644 --- a/management/server/dns_test.go +++ b/management/server/dns_test.go @@ -11,13 +11,14 @@ import ( "github.com/stretchr/testify/assert" nbdns "github.com/netbirdio/netbird/dns" + "github.com/netbirdio/netbird/management/server/store" "github.com/netbirdio/netbird/management/server/telemetry" + "github.com/netbirdio/netbird/management/server/types" "github.com/stretchr/testify/require" "github.com/netbirdio/netbird/dns" "github.com/netbirdio/netbird/management/server/activity" - "github.com/netbirdio/netbird/management/server/group" nbpeer "github.com/netbirdio/netbird/management/server/peer" "github.com/netbirdio/netbird/management/server/status" ) @@ -53,7 +54,7 @@ func TestGetDNSSettings(t *testing.T) { t.Fatal("DNS settings for new accounts shouldn't return nil") } - account.DNSSettings = DNSSettings{ + account.DNSSettings = types.DNSSettings{ DisabledManagementGroups: []string{group1ID}, } @@ -86,20 +87,20 @@ func TestSaveDNSSettings(t *testing.T) { testCases := []struct { name string userID string - inputSettings *DNSSettings + inputSettings *types.DNSSettings shouldFail bool }{ { name: "Saving As Admin Should Be OK", userID: dnsAdminUserID, - inputSettings: &DNSSettings{ + inputSettings: &types.DNSSettings{ DisabledManagementGroups: []string{dnsGroup1ID}, }, }, { name: "Should Not Update Settings As Regular User", userID: dnsRegularUserID, - inputSettings: &DNSSettings{ + inputSettings: &types.DNSSettings{ DisabledManagementGroups: []string{dnsGroup1ID}, }, shouldFail: true, @@ -113,7 +114,7 @@ func TestSaveDNSSettings(t *testing.T) { { name: "Should Not Update Settings If Group Is Invalid", userID: dnsAdminUserID, - inputSettings: &DNSSettings{ + inputSettings: &types.DNSSettings{ DisabledManagementGroups: []string{"non-existing-group"}, }, shouldFail: true, @@ -210,10 +211,10 @@ func createDNSManager(t *testing.T) (*DefaultAccountManager, error) { return BuildManager(context.Background(), store, NewPeersUpdateManager(nil), nil, "", "netbird.test", eventStore, nil, false, MocIntegratedValidator{}, metrics) } -func createDNSStore(t *testing.T) (Store, error) { +func createDNSStore(t *testing.T) (store.Store, error) { t.Helper() dataDir := t.TempDir() - store, cleanUp, err := 
NewTestStoreFromSQL(context.Background(), "", dataDir) + store, cleanUp, err := store.NewTestStoreFromSQL(context.Background(), "", dataDir) if err != nil { return nil, err } @@ -222,7 +223,7 @@ func createDNSStore(t *testing.T) (Store, error) { return store, nil } -func initTestDNSAccount(t *testing.T, am *DefaultAccountManager) (*Account, error) { +func initTestDNSAccount(t *testing.T, am *DefaultAccountManager) (*types.Account, error) { t.Helper() peer1 := &nbpeer.Peer{ Key: dnsPeer1Key, @@ -259,9 +260,9 @@ func initTestDNSAccount(t *testing.T, am *DefaultAccountManager) (*Account, erro account := newAccountWithId(context.Background(), dnsAccountID, dnsAdminUserID, domain) - account.Users[dnsRegularUserID] = &User{ + account.Users[dnsRegularUserID] = &types.User{ Id: dnsRegularUserID, - Role: UserRoleUser, + Role: types.UserRoleUser, } err := am.Store.SaveAccount(context.Background(), account) @@ -293,13 +294,13 @@ func initTestDNSAccount(t *testing.T, am *DefaultAccountManager) (*Account, erro return nil, err } - newGroup1 := &group.Group{ + newGroup1 := &types.Group{ ID: dnsGroup1ID, Peers: []string{peer1.ID}, Name: dnsGroup1ID, } - newGroup2 := &group.Group{ + newGroup2 := &types.Group{ ID: dnsGroup2ID, Name: dnsGroup2ID, } @@ -483,7 +484,7 @@ func TestToProtocolDNSConfigWithCache(t *testing.T) { func TestDNSAccountPeersUpdate(t *testing.T) { manager, account, peer1, peer2, peer3 := setupNetworkMapTest(t) - err := manager.SaveGroups(context.Background(), account.Id, userID, []*group.Group{ + err := manager.SaveGroups(context.Background(), account.Id, userID, []*types.Group{ { ID: "groupA", Name: "GroupA", @@ -510,7 +511,7 @@ func TestDNSAccountPeersUpdate(t *testing.T) { close(done) }() - err := manager.SaveDNSSettings(context.Background(), account.Id, userID, &DNSSettings{ + err := manager.SaveDNSSettings(context.Background(), account.Id, userID, &types.DNSSettings{ DisabledManagementGroups: []string{"groupA"}, }) assert.NoError(t, err) @@ -550,7 +551,7 @@ func TestDNSAccountPeersUpdate(t *testing.T) { // Creating DNS settings with groups that have peers should update account peers and send peer update t.Run("creating dns setting with used groups", func(t *testing.T) { - err = manager.SaveGroup(context.Background(), account.Id, userID, &group.Group{ + err = manager.SaveGroup(context.Background(), account.Id, userID, &types.Group{ ID: "groupA", Name: "GroupA", Peers: []string{peer1.ID, peer2.ID, peer3.ID}, @@ -589,7 +590,7 @@ func TestDNSAccountPeersUpdate(t *testing.T) { close(done) }() - err := manager.SaveDNSSettings(context.Background(), account.Id, userID, &DNSSettings{ + err := manager.SaveDNSSettings(context.Background(), account.Id, userID, &types.DNSSettings{ DisabledManagementGroups: []string{"groupA", "groupB"}, }) assert.NoError(t, err) @@ -609,7 +610,7 @@ func TestDNSAccountPeersUpdate(t *testing.T) { close(done) }() - err := manager.SaveDNSSettings(context.Background(), account.Id, userID, &DNSSettings{ + err := manager.SaveDNSSettings(context.Background(), account.Id, userID, &types.DNSSettings{ DisabledManagementGroups: []string{"groupA"}, }) assert.NoError(t, err) @@ -629,7 +630,7 @@ func TestDNSAccountPeersUpdate(t *testing.T) { close(done) }() - err := manager.SaveDNSSettings(context.Background(), account.Id, userID, &DNSSettings{ + err := manager.SaveDNSSettings(context.Background(), account.Id, userID, &types.DNSSettings{ DisabledManagementGroups: []string{}, }) assert.NoError(t, err) diff --git a/management/server/ephemeral.go b/management/server/ephemeral.go 
index 590b1d708..3c629a0db 100644 --- a/management/server/ephemeral.go +++ b/management/server/ephemeral.go @@ -9,6 +9,8 @@ import ( "github.com/netbirdio/netbird/management/server/activity" nbpeer "github.com/netbirdio/netbird/management/server/peer" + "github.com/netbirdio/netbird/management/server/store" + "github.com/netbirdio/netbird/management/server/types" ) const ( @@ -21,7 +23,7 @@ var ( type ephemeralPeer struct { id string - account *Account + account *types.Account deadline time.Time next *ephemeralPeer } @@ -32,7 +34,7 @@ type ephemeralPeer struct { // EphemeralManager keep a list of ephemeral peers. After ephemeralLifeTime inactivity the peer will be deleted // automatically. Inactivity means the peer disconnected from the Management server. type EphemeralManager struct { - store Store + store store.Store accountManager AccountManager headPeer *ephemeralPeer @@ -42,7 +44,7 @@ type EphemeralManager struct { } // NewEphemeralManager instantiate new EphemeralManager -func NewEphemeralManager(store Store, accountManager AccountManager) *EphemeralManager { +func NewEphemeralManager(store store.Store, accountManager AccountManager) *EphemeralManager { return &EphemeralManager{ store: store, accountManager: accountManager, @@ -177,7 +179,7 @@ func (e *EphemeralManager) cleanup(ctx context.Context) { } } -func (e *EphemeralManager) addPeer(id string, account *Account, deadline time.Time) { +func (e *EphemeralManager) addPeer(id string, account *types.Account, deadline time.Time) { ep := &ephemeralPeer{ id: id, account: account, diff --git a/management/server/ephemeral_test.go b/management/server/ephemeral_test.go index 1390352a5..ac8372440 100644 --- a/management/server/ephemeral_test.go +++ b/management/server/ephemeral_test.go @@ -8,18 +8,20 @@ import ( nbpeer "github.com/netbirdio/netbird/management/server/peer" "github.com/netbirdio/netbird/management/server/status" + "github.com/netbirdio/netbird/management/server/store" + "github.com/netbirdio/netbird/management/server/types" ) type MockStore struct { - Store - account *Account + store.Store + account *types.Account } -func (s *MockStore) GetAllAccounts(_ context.Context) []*Account { - return []*Account{s.account} +func (s *MockStore) GetAllAccounts(_ context.Context) []*types.Account { + return []*types.Account{s.account} } -func (s *MockStore) GetAccountByPeerID(_ context.Context, peerId string) (*Account, error) { +func (s *MockStore) GetAccountByPeerID(_ context.Context, peerId string) (*types.Account, error) { _, ok := s.account.Peers[peerId] if ok { return s.account, nil diff --git a/management/server/group.go b/management/server/group.go index 7b307cf1a..d433a3485 100644 --- a/management/server/group.go +++ b/management/server/group.go @@ -10,10 +10,12 @@ import ( log "github.com/sirupsen/logrus" nbdns "github.com/netbirdio/netbird/dns" + "github.com/netbirdio/netbird/management/server/store" + "github.com/netbirdio/netbird/management/server/types" + "github.com/netbirdio/netbird/management/server/util" "github.com/netbirdio/netbird/route" "github.com/netbirdio/netbird/management/server/activity" - nbgroup "github.com/netbirdio/netbird/management/server/group" "github.com/netbirdio/netbird/management/server/status" ) @@ -28,7 +30,7 @@ func (e *GroupLinkError) Error() string { // CheckGroupPermissions validates if a user has the necessary permissions to view groups func (am *DefaultAccountManager) CheckGroupPermissions(ctx context.Context, accountID, userID string) error { - user, err := am.Store.GetUserByUserID(ctx, 
LockingStrengthShare, userID) + user, err := am.Store.GetUserByUserID(ctx, store.LockingStrengthShare, userID) if err != nil { return err } @@ -45,38 +47,38 @@ func (am *DefaultAccountManager) CheckGroupPermissions(ctx context.Context, acco } // GetGroup returns a specific group by groupID in an account -func (am *DefaultAccountManager) GetGroup(ctx context.Context, accountID, groupID, userID string) (*nbgroup.Group, error) { +func (am *DefaultAccountManager) GetGroup(ctx context.Context, accountID, groupID, userID string) (*types.Group, error) { if err := am.CheckGroupPermissions(ctx, accountID, userID); err != nil { return nil, err } - return am.Store.GetGroupByID(ctx, LockingStrengthShare, accountID, groupID) + return am.Store.GetGroupByID(ctx, store.LockingStrengthShare, accountID, groupID) } // GetAllGroups returns all groups in an account -func (am *DefaultAccountManager) GetAllGroups(ctx context.Context, accountID, userID string) ([]*nbgroup.Group, error) { +func (am *DefaultAccountManager) GetAllGroups(ctx context.Context, accountID, userID string) ([]*types.Group, error) { if err := am.CheckGroupPermissions(ctx, accountID, userID); err != nil { return nil, err } - return am.Store.GetAccountGroups(ctx, LockingStrengthShare, accountID) + return am.Store.GetAccountGroups(ctx, store.LockingStrengthShare, accountID) } // GetGroupByName filters all groups in an account by name and returns the one with the most peers -func (am *DefaultAccountManager) GetGroupByName(ctx context.Context, groupName, accountID string) (*nbgroup.Group, error) { - return am.Store.GetGroupByName(ctx, LockingStrengthShare, accountID, groupName) +func (am *DefaultAccountManager) GetGroupByName(ctx context.Context, groupName, accountID string) (*types.Group, error) { + return am.Store.GetGroupByName(ctx, store.LockingStrengthShare, accountID, groupName) } // SaveGroup object of the peers -func (am *DefaultAccountManager) SaveGroup(ctx context.Context, accountID, userID string, newGroup *nbgroup.Group) error { +func (am *DefaultAccountManager) SaveGroup(ctx context.Context, accountID, userID string, newGroup *types.Group) error { unlock := am.Store.AcquireWriteLockByUID(ctx, accountID) defer unlock() - return am.SaveGroups(ctx, accountID, userID, []*nbgroup.Group{newGroup}) + return am.SaveGroups(ctx, accountID, userID, []*types.Group{newGroup}) } // SaveGroups adds new groups to the account. // Note: This function does not acquire the global lock. // It is the caller's responsibility to ensure proper locking is in place before invoking this method. 
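SaveGroups deliberately leaves locking to the caller, as the note above says; a minimal caller-side sketch of the expected pattern, mirroring how SaveGroup wraps it in this same hunk (am, ctx, accountID, userID and groups are assumed to be in scope):

// Acquire the per-account write lock before calling SaveGroups, exactly as
// SaveGroup does above, and release it when done.
unlock := am.Store.AcquireWriteLockByUID(ctx, accountID)
defer unlock()

if err := am.SaveGroups(ctx, accountID, userID, groups); err != nil {
	return err
}
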
-func (am *DefaultAccountManager) SaveGroups(ctx context.Context, accountID, userID string, groups []*nbgroup.Group) error { - user, err := am.Store.GetUserByUserID(ctx, LockingStrengthShare, userID) +func (am *DefaultAccountManager) SaveGroups(ctx context.Context, accountID, userID string, groups []*types.Group) error { + user, err := am.Store.GetUserByUserID(ctx, store.LockingStrengthShare, userID) if err != nil { return err } @@ -90,10 +92,10 @@ func (am *DefaultAccountManager) SaveGroups(ctx context.Context, accountID, user } var eventsToStore []func() - var groupsToSave []*nbgroup.Group + var groupsToSave []*types.Group var updateAccountPeers bool - err = am.Store.ExecuteInTransaction(ctx, func(transaction Store) error { + err = am.Store.ExecuteInTransaction(ctx, func(transaction store.Store) error { groupIDs := make([]string, 0, len(groups)) for _, newGroup := range groups { if err = validateNewGroup(ctx, transaction, accountID, newGroup); err != nil { @@ -113,11 +115,11 @@ func (am *DefaultAccountManager) SaveGroups(ctx context.Context, accountID, user return err } - if err = transaction.IncrementNetworkSerial(ctx, LockingStrengthUpdate, accountID); err != nil { + if err = transaction.IncrementNetworkSerial(ctx, store.LockingStrengthUpdate, accountID); err != nil { return err } - return transaction.SaveGroups(ctx, LockingStrengthUpdate, groupsToSave) + return transaction.SaveGroups(ctx, store.LockingStrengthUpdate, groupsToSave) }) if err != nil { return err @@ -128,23 +130,23 @@ func (am *DefaultAccountManager) SaveGroups(ctx context.Context, accountID, user } if updateAccountPeers { - am.updateAccountPeers(ctx, accountID) + am.UpdateAccountPeers(ctx, accountID) } return nil } // prepareGroupEvents prepares a list of event functions to be stored. -func (am *DefaultAccountManager) prepareGroupEvents(ctx context.Context, transaction Store, accountID, userID string, newGroup *nbgroup.Group) []func() { +func (am *DefaultAccountManager) prepareGroupEvents(ctx context.Context, transaction store.Store, accountID, userID string, newGroup *types.Group) []func() { var eventsToStore []func() addedPeers := make([]string, 0) removedPeers := make([]string, 0) - oldGroup, err := transaction.GetGroupByID(ctx, LockingStrengthShare, accountID, newGroup.ID) + oldGroup, err := transaction.GetGroupByID(ctx, store.LockingStrengthShare, accountID, newGroup.ID) if err == nil && oldGroup != nil { - addedPeers = difference(newGroup.Peers, oldGroup.Peers) - removedPeers = difference(oldGroup.Peers, newGroup.Peers) + addedPeers = util.Difference(newGroup.Peers, oldGroup.Peers) + removedPeers = util.Difference(oldGroup.Peers, newGroup.Peers) } else { addedPeers = append(addedPeers, newGroup.Peers...) eventsToStore = append(eventsToStore, func() { @@ -153,7 +155,7 @@ func (am *DefaultAccountManager) prepareGroupEvents(ctx context.Context, transac } modifiedPeers := slices.Concat(addedPeers, removedPeers) - peers, err := transaction.GetPeersByIDs(ctx, LockingStrengthShare, accountID, modifiedPeers) + peers, err := transaction.GetPeersByIDs(ctx, store.LockingStrengthShare, accountID, modifiedPeers) if err != nil { log.WithContext(ctx).Debugf("failed to get peers for group events: %v", err) return nil @@ -194,21 +196,6 @@ func (am *DefaultAccountManager) prepareGroupEvents(ctx context.Context, transac return eventsToStore } -// difference returns the elements in `a` that aren't in `b`. 
-func difference(a, b []string) []string { - mb := make(map[string]struct{}, len(b)) - for _, x := range b { - mb[x] = struct{}{} - } - var diff []string - for _, x := range a { - if _, found := mb[x]; !found { - diff = append(diff, x) - } - } - return diff -} - // DeleteGroup object of the peers. func (am *DefaultAccountManager) DeleteGroup(ctx context.Context, accountID, userID, groupID string) error { unlock := am.Store.AcquireWriteLockByUID(ctx, accountID) @@ -223,7 +210,7 @@ func (am *DefaultAccountManager) DeleteGroup(ctx context.Context, accountID, use // If an error occurs while deleting a group, the function skips it and continues deleting other groups. // Errors are collected and returned at the end. func (am *DefaultAccountManager) DeleteGroups(ctx context.Context, accountID, userID string, groupIDs []string) error { - user, err := am.Store.GetUserByUserID(ctx, LockingStrengthShare, userID) + user, err := am.Store.GetUserByUserID(ctx, store.LockingStrengthShare, userID) if err != nil { return err } @@ -238,11 +225,11 @@ func (am *DefaultAccountManager) DeleteGroups(ctx context.Context, accountID, us var allErrors error var groupIDsToDelete []string - var deletedGroups []*nbgroup.Group + var deletedGroups []*types.Group - err = am.Store.ExecuteInTransaction(ctx, func(transaction Store) error { + err = am.Store.ExecuteInTransaction(ctx, func(transaction store.Store) error { for _, groupID := range groupIDs { - group, err := transaction.GetGroupByID(ctx, LockingStrengthUpdate, accountID, groupID) + group, err := transaction.GetGroupByID(ctx, store.LockingStrengthUpdate, accountID, groupID) if err != nil { allErrors = errors.Join(allErrors, err) continue @@ -257,11 +244,11 @@ func (am *DefaultAccountManager) DeleteGroups(ctx context.Context, accountID, us deletedGroups = append(deletedGroups, group) } - if err = transaction.IncrementNetworkSerial(ctx, LockingStrengthUpdate, accountID); err != nil { + if err = transaction.IncrementNetworkSerial(ctx, store.LockingStrengthUpdate, accountID); err != nil { return err } - return transaction.DeleteGroups(ctx, LockingStrengthUpdate, accountID, groupIDsToDelete) + return transaction.DeleteGroups(ctx, store.LockingStrengthUpdate, accountID, groupIDsToDelete) }) if err != nil { return err @@ -279,12 +266,12 @@ func (am *DefaultAccountManager) GroupAddPeer(ctx context.Context, accountID, gr unlock := am.Store.AcquireWriteLockByUID(ctx, accountID) defer unlock() - var group *nbgroup.Group + var group *types.Group var updateAccountPeers bool var err error - err = am.Store.ExecuteInTransaction(ctx, func(transaction Store) error { - group, err = transaction.GetGroupByID(context.Background(), LockingStrengthUpdate, accountID, groupID) + err = am.Store.ExecuteInTransaction(ctx, func(transaction store.Store) error { + group, err = transaction.GetGroupByID(context.Background(), store.LockingStrengthUpdate, accountID, groupID) if err != nil { return err } @@ -298,18 +285,59 @@ func (am *DefaultAccountManager) GroupAddPeer(ctx context.Context, accountID, gr return err } - if err = transaction.IncrementNetworkSerial(ctx, LockingStrengthUpdate, accountID); err != nil { + if err = transaction.IncrementNetworkSerial(ctx, store.LockingStrengthUpdate, accountID); err != nil { return err } - return transaction.SaveGroup(ctx, LockingStrengthUpdate, group) + return transaction.SaveGroup(ctx, store.LockingStrengthUpdate, group) }) if err != nil { return err } if updateAccountPeers { - am.updateAccountPeers(ctx, accountID) + am.UpdateAccountPeers(ctx, accountID) + 
} + + return nil +} + +// GroupAddResource appends resource to the group +func (am *DefaultAccountManager) GroupAddResource(ctx context.Context, accountID, groupID string, resource types.Resource) error { + unlock := am.Store.AcquireWriteLockByUID(ctx, accountID) + defer unlock() + + var group *types.Group + var updateAccountPeers bool + var err error + + err = am.Store.ExecuteInTransaction(ctx, func(transaction store.Store) error { + group, err = transaction.GetGroupByID(context.Background(), store.LockingStrengthUpdate, accountID, groupID) + if err != nil { + return err + } + + if updated := group.AddResource(resource); !updated { + return nil + } + + updateAccountPeers, err = areGroupChangesAffectPeers(ctx, transaction, accountID, []string{groupID}) + if err != nil { + return err + } + + if err = transaction.IncrementNetworkSerial(ctx, store.LockingStrengthUpdate, accountID); err != nil { + return err + } + + return transaction.SaveGroup(ctx, store.LockingStrengthUpdate, group) + }) + if err != nil { + return err + } + + if updateAccountPeers { + am.UpdateAccountPeers(ctx, accountID) } return nil @@ -320,12 +348,12 @@ func (am *DefaultAccountManager) GroupDeletePeer(ctx context.Context, accountID, unlock := am.Store.AcquireWriteLockByUID(ctx, accountID) defer unlock() - var group *nbgroup.Group + var group *types.Group var updateAccountPeers bool var err error - err = am.Store.ExecuteInTransaction(ctx, func(transaction Store) error { - group, err = transaction.GetGroupByID(context.Background(), LockingStrengthUpdate, accountID, groupID) + err = am.Store.ExecuteInTransaction(ctx, func(transaction store.Store) error { + group, err = transaction.GetGroupByID(context.Background(), store.LockingStrengthUpdate, accountID, groupID) if err != nil { return err } @@ -339,31 +367,72 @@ func (am *DefaultAccountManager) GroupDeletePeer(ctx context.Context, accountID, return err } - if err = transaction.IncrementNetworkSerial(ctx, LockingStrengthUpdate, accountID); err != nil { + if err = transaction.IncrementNetworkSerial(ctx, store.LockingStrengthUpdate, accountID); err != nil { return err } - return transaction.SaveGroup(ctx, LockingStrengthUpdate, group) + return transaction.SaveGroup(ctx, store.LockingStrengthUpdate, group) }) if err != nil { return err } if updateAccountPeers { - am.updateAccountPeers(ctx, accountID) + am.UpdateAccountPeers(ctx, accountID) + } + + return nil +} + +// GroupDeleteResource removes resource from the group +func (am *DefaultAccountManager) GroupDeleteResource(ctx context.Context, accountID, groupID string, resource types.Resource) error { + unlock := am.Store.AcquireWriteLockByUID(ctx, accountID) + defer unlock() + + var group *types.Group + var updateAccountPeers bool + var err error + + err = am.Store.ExecuteInTransaction(ctx, func(transaction store.Store) error { + group, err = transaction.GetGroupByID(context.Background(), store.LockingStrengthUpdate, accountID, groupID) + if err != nil { + return err + } + + if updated := group.RemoveResource(resource); !updated { + return nil + } + + updateAccountPeers, err = areGroupChangesAffectPeers(ctx, transaction, accountID, []string{groupID}) + if err != nil { + return err + } + + if err = transaction.IncrementNetworkSerial(ctx, store.LockingStrengthUpdate, accountID); err != nil { + return err + } + + return transaction.SaveGroup(ctx, store.LockingStrengthUpdate, group) + }) + if err != nil { + return err + } + + if updateAccountPeers { + am.UpdateAccountPeers(ctx, accountID) } return nil } // validateNewGroup 
validates the new group for existence and required fields. -func validateNewGroup(ctx context.Context, transaction Store, accountID string, newGroup *nbgroup.Group) error { - if newGroup.ID == "" && newGroup.Issued != nbgroup.GroupIssuedAPI { +func validateNewGroup(ctx context.Context, transaction store.Store, accountID string, newGroup *types.Group) error { + if newGroup.ID == "" && newGroup.Issued != types.GroupIssuedAPI { return status.Errorf(status.InvalidArgument, "%s group without ID set", newGroup.Issued) } - if newGroup.ID == "" && newGroup.Issued == nbgroup.GroupIssuedAPI { - existingGroup, err := transaction.GetGroupByName(ctx, LockingStrengthShare, accountID, newGroup.Name) + if newGroup.ID == "" && newGroup.Issued == types.GroupIssuedAPI { + existingGroup, err := transaction.GetGroupByName(ctx, store.LockingStrengthShare, accountID, newGroup.Name) if err != nil { if s, ok := status.FromError(err); !ok || s.Type() != status.NotFound { return err @@ -380,7 +449,7 @@ func validateNewGroup(ctx context.Context, transaction Store, accountID string, } for _, peerID := range newGroup.Peers { - _, err := transaction.GetPeerByID(ctx, LockingStrengthShare, accountID, peerID) + _, err := transaction.GetPeerByID(ctx, store.LockingStrengthShare, accountID, peerID) if err != nil { return status.Errorf(status.InvalidArgument, "peer with ID \"%s\" not found", peerID) } @@ -389,14 +458,14 @@ func validateNewGroup(ctx context.Context, transaction Store, accountID string, return nil } -func validateDeleteGroup(ctx context.Context, transaction Store, group *nbgroup.Group, userID string) error { +func validateDeleteGroup(ctx context.Context, transaction store.Store, group *types.Group, userID string) error { // disable a deleting integration group if the initiator is not an admin service user - if group.Issued == nbgroup.GroupIssuedIntegration { - executingUser, err := transaction.GetUserByUserID(ctx, LockingStrengthShare, userID) + if group.Issued == types.GroupIssuedIntegration { + executingUser, err := transaction.GetUserByUserID(ctx, store.LockingStrengthShare, userID) if err != nil { return err } - if executingUser.Role != UserRoleAdmin || !executingUser.IsServiceUser { + if executingUser.Role != types.UserRoleAdmin || !executingUser.IsServiceUser { return status.Errorf(status.PermissionDenied, "only service users with admin power can delete integration group") } } @@ -429,8 +498,8 @@ func validateDeleteGroup(ctx context.Context, transaction Store, group *nbgroup. } // checkGroupLinkedToSettings verifies if a group is linked to any settings in the account. 
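validateDeleteGroup reports linkage through *GroupLinkError, letting callers tell "group still in use" apart from other failures; a hedged caller sketch (the errors.As unwrapping is an assumption about typical caller code, not something this hunk shows; errors and fmt are assumed imported, and am, ctx, accountID, userID, groupID in scope):

if err := am.DeleteGroup(ctx, accountID, userID, groupID); err != nil {
	// The error may wrap a *GroupLinkError produced by validateDeleteGroup.
	var linkErr *GroupLinkError
	if errors.As(err, &linkErr) {
		// The group is still referenced by a route, policy, DNS setting,
		// setup key or user, so deletion was refused.
		return fmt.Errorf("group still in use: %w", linkErr)
	}
	return err
}
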
-func checkGroupLinkedToSettings(ctx context.Context, transaction Store, group *nbgroup.Group) error { - dnsSettings, err := transaction.GetAccountDNSSettings(ctx, LockingStrengthShare, group.AccountID) +func checkGroupLinkedToSettings(ctx context.Context, transaction store.Store, group *types.Group) error { + dnsSettings, err := transaction.GetAccountDNSSettings(ctx, store.LockingStrengthShare, group.AccountID) if err != nil { return err } @@ -439,7 +508,7 @@ func checkGroupLinkedToSettings(ctx context.Context, transaction Store, group *n return &GroupLinkError{"disabled DNS management groups", group.Name} } - settings, err := transaction.GetAccountSettings(ctx, LockingStrengthShare, group.AccountID) + settings, err := transaction.GetAccountSettings(ctx, store.LockingStrengthShare, group.AccountID) if err != nil { return err } @@ -452,8 +521,8 @@ func checkGroupLinkedToSettings(ctx context.Context, transaction Store, group *n } // isGroupLinkedToRoute checks if a group is linked to any route in the account. -func isGroupLinkedToRoute(ctx context.Context, transaction Store, accountID string, groupID string) (bool, *route.Route) { - routes, err := transaction.GetAccountRoutes(ctx, LockingStrengthShare, accountID) +func isGroupLinkedToRoute(ctx context.Context, transaction store.Store, accountID string, groupID string) (bool, *route.Route) { + routes, err := transaction.GetAccountRoutes(ctx, store.LockingStrengthShare, accountID) if err != nil { log.WithContext(ctx).Errorf("error retrieving routes while checking group linkage: %v", err) return false, nil @@ -469,8 +538,8 @@ func isGroupLinkedToRoute(ctx context.Context, transaction Store, accountID stri } // isGroupLinkedToPolicy checks if a group is linked to any policy in the account. -func isGroupLinkedToPolicy(ctx context.Context, transaction Store, accountID string, groupID string) (bool, *Policy) { - policies, err := transaction.GetAccountPolicies(ctx, LockingStrengthShare, accountID) +func isGroupLinkedToPolicy(ctx context.Context, transaction store.Store, accountID string, groupID string) (bool, *types.Policy) { + policies, err := transaction.GetAccountPolicies(ctx, store.LockingStrengthShare, accountID) if err != nil { log.WithContext(ctx).Errorf("error retrieving policies while checking group linkage: %v", err) return false, nil @@ -487,8 +556,8 @@ func isGroupLinkedToPolicy(ctx context.Context, transaction Store, accountID str } // isGroupLinkedToDns checks if a group is linked to any nameserver group in the account. -func isGroupLinkedToDns(ctx context.Context, transaction Store, accountID string, groupID string) (bool, *nbdns.NameServerGroup) { - nameServerGroups, err := transaction.GetAccountNameServerGroups(ctx, LockingStrengthShare, accountID) +func isGroupLinkedToDns(ctx context.Context, transaction store.Store, accountID string, groupID string) (bool, *nbdns.NameServerGroup) { + nameServerGroups, err := transaction.GetAccountNameServerGroups(ctx, store.LockingStrengthShare, accountID) if err != nil { log.WithContext(ctx).Errorf("error retrieving name server groups while checking group linkage: %v", err) return false, nil @@ -506,8 +575,8 @@ func isGroupLinkedToDns(ctx context.Context, transaction Store, accountID string } // isGroupLinkedToSetupKey checks if a group is linked to any setup key in the account. 
-func isGroupLinkedToSetupKey(ctx context.Context, transaction Store, accountID string, groupID string) (bool, *SetupKey) { - setupKeys, err := transaction.GetAccountSetupKeys(ctx, LockingStrengthShare, accountID) +func isGroupLinkedToSetupKey(ctx context.Context, transaction store.Store, accountID string, groupID string) (bool, *types.SetupKey) { + setupKeys, err := transaction.GetAccountSetupKeys(ctx, store.LockingStrengthShare, accountID) if err != nil { log.WithContext(ctx).Errorf("error retrieving setup keys while checking group linkage: %v", err) return false, nil @@ -522,8 +591,8 @@ func isGroupLinkedToSetupKey(ctx context.Context, transaction Store, accountID s } // isGroupLinkedToUser checks if a group is linked to any user in the account. -func isGroupLinkedToUser(ctx context.Context, transaction Store, accountID string, groupID string) (bool, *User) { - users, err := transaction.GetAccountUsers(ctx, LockingStrengthShare, accountID) +func isGroupLinkedToUser(ctx context.Context, transaction store.Store, accountID string, groupID string) (bool, *types.User) { + users, err := transaction.GetAccountUsers(ctx, store.LockingStrengthShare, accountID) if err != nil { log.WithContext(ctx).Errorf("error retrieving users while checking group linkage: %v", err) return false, nil @@ -538,12 +607,12 @@ func isGroupLinkedToUser(ctx context.Context, transaction Store, accountID strin } // areGroupChangesAffectPeers checks if any changes to the specified groups will affect peers. -func areGroupChangesAffectPeers(ctx context.Context, transaction Store, accountID string, groupIDs []string) (bool, error) { +func areGroupChangesAffectPeers(ctx context.Context, transaction store.Store, accountID string, groupIDs []string) (bool, error) { if len(groupIDs) == 0 { return false, nil } - dnsSettings, err := transaction.GetAccountDNSSettings(ctx, LockingStrengthShare, accountID) + dnsSettings, err := transaction.GetAccountDNSSettings(ctx, store.LockingStrengthShare, accountID) if err != nil { return false, err } @@ -566,7 +635,7 @@ func areGroupChangesAffectPeers(ctx context.Context, transaction Store, accountI return false, nil } -func (am *DefaultAccountManager) anyGroupHasPeers(account *Account, groupIDs []string) bool { +func (am *DefaultAccountManager) anyGroupHasPeers(account *types.Account, groupIDs []string) bool { for _, groupID := range groupIDs { if group, exists := account.Groups[groupID]; exists && group.HasPeers() { return true @@ -575,15 +644,15 @@ func (am *DefaultAccountManager) anyGroupHasPeers(account *Account, groupIDs []s return false } -// anyGroupHasPeers checks if any of the given groups in the account have peers. -func anyGroupHasPeers(ctx context.Context, transaction Store, accountID string, groupIDs []string) (bool, error) { - groups, err := transaction.GetGroupsByIDs(ctx, LockingStrengthShare, accountID, groupIDs) +// anyGroupHasPeersOrResources checks if any of the given groups in the account have peers or resources. 
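The peers-or-resources check below leans on two helpers on types.Group whose bodies are not part of these hunks; judging from how the Peers and Resources fields are used elsewhere in this patch, they presumably reduce to length checks. Sketch of the assumed shape only, not the actual types package code:

// Assumed shape of the helpers consulted by anyGroupHasPeersOrResources.
func (g *Group) HasPeers() bool     { return len(g.Peers) > 0 }
func (g *Group) HasResources() bool { return len(g.Resources) > 0 }
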
+func anyGroupHasPeersOrResources(ctx context.Context, transaction store.Store, accountID string, groupIDs []string) (bool, error) { + groups, err := transaction.GetGroupsByIDs(ctx, store.LockingStrengthShare, accountID, groupIDs) if err != nil { return false, err } for _, group := range groups { - if group.HasPeers() { + if group.HasPeers() || group.HasResources() { return true, nil } } diff --git a/management/server/group_test.go b/management/server/group_test.go index ec017fc57..834388d1e 100644 --- a/management/server/group_test.go +++ b/management/server/group_test.go @@ -12,8 +12,8 @@ import ( "github.com/stretchr/testify/require" nbdns "github.com/netbirdio/netbird/dns" - nbgroup "github.com/netbirdio/netbird/management/server/group" "github.com/netbirdio/netbird/management/server/status" + "github.com/netbirdio/netbird/management/server/types" "github.com/netbirdio/netbird/route" ) @@ -32,22 +32,22 @@ func TestDefaultAccountManager_CreateGroup(t *testing.T) { t.Error("failed to init testing account") } for _, group := range account.Groups { - group.Issued = nbgroup.GroupIssuedIntegration + group.Issued = types.GroupIssuedIntegration err = am.SaveGroup(context.Background(), account.Id, groupAdminUserID, group) if err != nil { - t.Errorf("should allow to create %s groups", nbgroup.GroupIssuedIntegration) + t.Errorf("should allow to create %s groups", types.GroupIssuedIntegration) } } for _, group := range account.Groups { - group.Issued = nbgroup.GroupIssuedJWT + group.Issued = types.GroupIssuedJWT err = am.SaveGroup(context.Background(), account.Id, groupAdminUserID, group) if err != nil { - t.Errorf("should allow to create %s groups", nbgroup.GroupIssuedJWT) + t.Errorf("should allow to create %s groups", types.GroupIssuedJWT) } } for _, group := range account.Groups { - group.Issued = nbgroup.GroupIssuedAPI + group.Issued = types.GroupIssuedAPI group.ID = "" err = am.SaveGroup(context.Background(), account.Id, groupAdminUserID, group) if err == nil { @@ -145,13 +145,13 @@ func TestDefaultAccountManager_DeleteGroups(t *testing.T) { manager, account, err := initTestGroupAccount(am) assert.NoError(t, err, "Failed to init testing account") - groups := make([]*nbgroup.Group, 10) + groups := make([]*types.Group, 10) for i := 0; i < 10; i++ { - groups[i] = &nbgroup.Group{ + groups[i] = &types.Group{ ID: fmt.Sprintf("group-%d", i+1), AccountID: account.Id, Name: fmt.Sprintf("group-%d", i+1), - Issued: nbgroup.GroupIssuedAPI, + Issued: types.GroupIssuedAPI, } } @@ -267,63 +267,63 @@ func TestDefaultAccountManager_DeleteGroups(t *testing.T) { } } -func initTestGroupAccount(am *DefaultAccountManager) (*DefaultAccountManager, *Account, error) { +func initTestGroupAccount(am *DefaultAccountManager) (*DefaultAccountManager, *types.Account, error) { accountID := "testingAcc" domain := "example.com" - groupForRoute := &nbgroup.Group{ + groupForRoute := &types.Group{ ID: "grp-for-route", AccountID: "account-id", Name: "Group for route", - Issued: nbgroup.GroupIssuedAPI, + Issued: types.GroupIssuedAPI, Peers: make([]string, 0), } - groupForRoute2 := &nbgroup.Group{ + groupForRoute2 := &types.Group{ ID: "grp-for-route2", AccountID: "account-id", Name: "Group for route", - Issued: nbgroup.GroupIssuedAPI, + Issued: types.GroupIssuedAPI, Peers: make([]string, 0), } - groupForNameServerGroups := &nbgroup.Group{ + groupForNameServerGroups := &types.Group{ ID: "grp-for-name-server-grp", AccountID: "account-id", Name: "Group for name server groups", - Issued: nbgroup.GroupIssuedAPI, + Issued: 
types.GroupIssuedAPI, Peers: make([]string, 0), } - groupForPolicies := &nbgroup.Group{ + groupForPolicies := &types.Group{ ID: "grp-for-policies", AccountID: "account-id", Name: "Group for policies", - Issued: nbgroup.GroupIssuedAPI, + Issued: types.GroupIssuedAPI, Peers: make([]string, 0), } - groupForSetupKeys := &nbgroup.Group{ + groupForSetupKeys := &types.Group{ ID: "grp-for-keys", AccountID: "account-id", Name: "Group for setup keys", - Issued: nbgroup.GroupIssuedAPI, + Issued: types.GroupIssuedAPI, Peers: make([]string, 0), } - groupForUsers := &nbgroup.Group{ + groupForUsers := &types.Group{ ID: "grp-for-users", AccountID: "account-id", Name: "Group for users", - Issued: nbgroup.GroupIssuedAPI, + Issued: types.GroupIssuedAPI, Peers: make([]string, 0), } - groupForIntegration := &nbgroup.Group{ + groupForIntegration := &types.Group{ ID: "grp-for-integration", AccountID: "account-id", Name: "Group for users integration", - Issued: nbgroup.GroupIssuedIntegration, + Issued: types.GroupIssuedIntegration, Peers: make([]string, 0), } @@ -342,9 +342,9 @@ func initTestGroupAccount(am *DefaultAccountManager) (*DefaultAccountManager, *A Groups: []string{groupForNameServerGroups.ID}, } - policy := &Policy{ + policy := &types.Policy{ ID: "example policy", - Rules: []*PolicyRule{ + Rules: []*types.PolicyRule{ { ID: "example policy rule", Destinations: []string{groupForPolicies.ID}, @@ -352,12 +352,12 @@ func initTestGroupAccount(am *DefaultAccountManager) (*DefaultAccountManager, *A }, } - setupKey := &SetupKey{ + setupKey := &types.SetupKey{ Id: "example setup key", AutoGroups: []string{groupForSetupKeys.ID}, } - user := &User{ + user := &types.User{ Id: "example user", AutoGroups: []string{groupForUsers.ID}, } @@ -392,7 +392,7 @@ func initTestGroupAccount(am *DefaultAccountManager) (*DefaultAccountManager, *A func TestGroupAccountPeersUpdate(t *testing.T) { manager, account, peer1, peer2, peer3 := setupNetworkMapTest(t) - err := manager.SaveGroups(context.Background(), account.Id, userID, []*nbgroup.Group{ + err := manager.SaveGroups(context.Background(), account.Id, userID, []*types.Group{ { ID: "groupA", Name: "GroupA", @@ -429,7 +429,7 @@ func TestGroupAccountPeersUpdate(t *testing.T) { close(done) }() - err := manager.SaveGroup(context.Background(), account.Id, userID, &nbgroup.Group{ + err := manager.SaveGroup(context.Background(), account.Id, userID, &types.Group{ ID: "groupB", Name: "GroupB", Peers: []string{peer1.ID, peer2.ID}, @@ -500,15 +500,15 @@ func TestGroupAccountPeersUpdate(t *testing.T) { }) // adding a group to policy - _, err = manager.SavePolicy(context.Background(), account.Id, userID, &Policy{ + _, err = manager.SavePolicy(context.Background(), account.Id, userID, &types.Policy{ Enabled: true, - Rules: []*PolicyRule{ + Rules: []*types.PolicyRule{ { Enabled: true, Sources: []string{"groupA"}, Destinations: []string{"groupA"}, Bidirectional: true, - Action: PolicyTrafficActionAccept, + Action: types.PolicyTrafficActionAccept, }, }, }) @@ -522,7 +522,7 @@ func TestGroupAccountPeersUpdate(t *testing.T) { close(done) }() - err := manager.SaveGroup(context.Background(), account.Id, userID, &nbgroup.Group{ + err := manager.SaveGroup(context.Background(), account.Id, userID, &types.Group{ ID: "groupA", Name: "GroupA", Peers: []string{peer1.ID, peer2.ID}, @@ -591,7 +591,7 @@ func TestGroupAccountPeersUpdate(t *testing.T) { close(done) }() - err := manager.SaveGroup(context.Background(), account.Id, userID, &nbgroup.Group{ + err := manager.SaveGroup(context.Background(), 
account.Id, userID, &types.Group{ ID: "groupC", Name: "GroupC", Peers: []string{peer1.ID, peer3.ID}, @@ -632,7 +632,7 @@ func TestGroupAccountPeersUpdate(t *testing.T) { close(done) }() - err = manager.SaveGroup(context.Background(), account.Id, userID, &nbgroup.Group{ + err = manager.SaveGroup(context.Background(), account.Id, userID, &types.Group{ ID: "groupA", Name: "GroupA", Peers: []string{peer1.ID, peer2.ID, peer3.ID}, @@ -648,7 +648,7 @@ func TestGroupAccountPeersUpdate(t *testing.T) { // Saving a group linked to dns settings should update account peers and send peer update t.Run("saving group linked to dns settings", func(t *testing.T) { - err := manager.SaveDNSSettings(context.Background(), account.Id, userID, &DNSSettings{ + err := manager.SaveDNSSettings(context.Background(), account.Id, userID, &types.DNSSettings{ DisabledManagementGroups: []string{"groupD"}, }) assert.NoError(t, err) @@ -659,7 +659,7 @@ func TestGroupAccountPeersUpdate(t *testing.T) { close(done) }() - err = manager.SaveGroup(context.Background(), account.Id, userID, &nbgroup.Group{ + err = manager.SaveGroup(context.Background(), account.Id, userID, &types.Group{ ID: "groupD", Name: "GroupD", Peers: []string{peer1.ID}, diff --git a/management/server/groups/manager.go b/management/server/groups/manager.go new file mode 100644 index 000000000..f5abb212e --- /dev/null +++ b/management/server/groups/manager.go @@ -0,0 +1,196 @@ +package groups + +import ( + "context" + "fmt" + + s "github.com/netbirdio/netbird/management/server" + "github.com/netbirdio/netbird/management/server/activity" + "github.com/netbirdio/netbird/management/server/http/api" + "github.com/netbirdio/netbird/management/server/permissions" + "github.com/netbirdio/netbird/management/server/store" + "github.com/netbirdio/netbird/management/server/types" +) + +type Manager interface { + GetAllGroups(ctx context.Context, accountID, userID string) (map[string]*types.Group, error) + GetResourceGroupsInTransaction(ctx context.Context, transaction store.Store, lockingStrength store.LockingStrength, accountID, resourceID string) ([]*types.Group, error) + AddResourceToGroup(ctx context.Context, accountID, userID, groupID string, resourceID *types.Resource) error + AddResourceToGroupInTransaction(ctx context.Context, transaction store.Store, accountID, userID, groupID string, resourceID *types.Resource) (func(), error) + RemoveResourceFromGroupInTransaction(ctx context.Context, transaction store.Store, accountID, userID, groupID, resourceID string) (func(), error) +} + +type managerImpl struct { + store store.Store + permissionsManager permissions.Manager + accountManager s.AccountManager +} + +type mockManager struct { +} + +func NewManager(store store.Store, permissionsManager permissions.Manager, accountManager s.AccountManager) Manager { + return &managerImpl{ + store: store, + permissionsManager: permissionsManager, + accountManager: accountManager, + } +} + +func (m *managerImpl) GetAllGroups(ctx context.Context, accountID, userID string) (map[string]*types.Group, error) { + ok, err := m.permissionsManager.ValidateUserPermissions(ctx, accountID, userID, permissions.Groups, permissions.Read) + if err != nil { + return nil, err + } + if !ok { + return nil, err + } + + groups, err := m.store.GetAccountGroups(ctx, store.LockingStrengthShare, accountID) + if err != nil { + return nil, fmt.Errorf("error getting account groups: %w", err) + } + + groupsMap := make(map[string]*types.Group) + for _, group := range groups { + groupsMap[group.ID] = group + } + + 
return groupsMap, nil +} + +func (m *managerImpl) AddResourceToGroup(ctx context.Context, accountID, userID, groupID string, resource *types.Resource) error { + ok, err := m.permissionsManager.ValidateUserPermissions(ctx, accountID, userID, permissions.Groups, permissions.Write) + if err != nil { + return err + } + if !ok { + return err + } + + event, err := m.AddResourceToGroupInTransaction(ctx, m.store, accountID, userID, groupID, resource) + if err != nil { + return fmt.Errorf("error adding resource to group: %w", err) + } + + event() + + return nil +} + +func (m *managerImpl) AddResourceToGroupInTransaction(ctx context.Context, transaction store.Store, accountID, userID, groupID string, resource *types.Resource) (func(), error) { + err := transaction.AddResourceToGroup(ctx, accountID, groupID, resource) + if err != nil { + return nil, fmt.Errorf("error adding resource to group: %w", err) + } + + group, err := transaction.GetGroupByID(ctx, store.LockingStrengthShare, accountID, groupID) + if err != nil { + return nil, fmt.Errorf("error getting group: %w", err) + } + + // TODO: at some point, this will need to become a switch statement + networkResource, err := transaction.GetNetworkResourceByID(ctx, store.LockingStrengthShare, accountID, resource.ID) + if err != nil { + return nil, fmt.Errorf("error getting network resource: %w", err) + } + + event := func() { + m.accountManager.StoreEvent(ctx, userID, groupID, accountID, activity.ResourceAddedToGroup, group.EventMetaResource(networkResource)) + } + + return event, nil +} + +func (m *managerImpl) RemoveResourceFromGroupInTransaction(ctx context.Context, transaction store.Store, accountID, userID, groupID, resourceID string) (func(), error) { + err := transaction.RemoveResourceFromGroup(ctx, accountID, groupID, resourceID) + if err != nil { + return nil, fmt.Errorf("error removing resource from group: %w", err) + } + + group, err := transaction.GetGroupByID(ctx, store.LockingStrengthShare, accountID, groupID) + if err != nil { + return nil, fmt.Errorf("error getting group: %w", err) + } + + // TODO: at some point, this will need to become a switch statement + networkResource, err := transaction.GetNetworkResourceByID(ctx, store.LockingStrengthShare, accountID, resourceID) + if err != nil { + return nil, fmt.Errorf("error getting network resource: %w", err) + } + + event := func() { + m.accountManager.StoreEvent(ctx, userID, groupID, accountID, activity.ResourceRemovedFromGroup, group.EventMetaResource(networkResource)) + } + + return event, nil +} + +func (m *managerImpl) GetResourceGroupsInTransaction(ctx context.Context, transaction store.Store, lockingStrength store.LockingStrength, accountID, resourceID string) ([]*types.Group, error) { + return transaction.GetResourceGroups(ctx, lockingStrength, accountID, resourceID) +} + +func ToGroupsInfo(groups map[string]*types.Group, id string) []api.GroupMinimum { + groupsInfo := []api.GroupMinimum{} + groupsChecked := make(map[string]struct{}) + for _, group := range groups { + _, ok := groupsChecked[group.ID] + if ok { + continue + } + groupsChecked[group.ID] = struct{}{} + for _, pk := range group.Peers { + if pk == id { + info := api.GroupMinimum{ + Id: group.ID, + Name: group.Name, + PeersCount: len(group.Peers), + ResourcesCount: len(group.Resources), + } + groupsInfo = append(groupsInfo, info) + break + } + } + for _, rk := range group.Resources { + if rk.ID == id { + info := api.GroupMinimum{ + Id: group.ID, + Name: group.Name, + PeersCount: len(group.Peers), + ResourcesCount: 
len(group.Resources), + } + groupsInfo = append(groupsInfo, info) + break + } + } + } + return groupsInfo +} + +func (m *mockManager) GetAllGroups(ctx context.Context, accountID, userID string) (map[string]*types.Group, error) { + return map[string]*types.Group{}, nil +} + +func (m *mockManager) GetResourceGroupsInTransaction(ctx context.Context, transaction store.Store, lockingStrength store.LockingStrength, accountID, resourceID string) ([]*types.Group, error) { + return []*types.Group{}, nil +} + +func (m *mockManager) AddResourceToGroup(ctx context.Context, accountID, userID, groupID string, resourceID *types.Resource) error { + return nil +} + +func (m *mockManager) AddResourceToGroupInTransaction(ctx context.Context, transaction store.Store, accountID, userID, groupID string, resourceID *types.Resource) (func(), error) { + return func() { + // noop + }, nil +} + +func (m *mockManager) RemoveResourceFromGroupInTransaction(ctx context.Context, transaction store.Store, accountID, userID, groupID, resourceID string) (func(), error) { + return func() { + // noop + }, nil +} + +func NewManagerMock() Manager { + return &mockManager{} +} diff --git a/management/server/grpcserver.go b/management/server/grpcserver.go index 9c12336f8..2635ac11b 100644 --- a/management/server/grpcserver.go +++ b/management/server/grpcserver.go @@ -23,14 +23,17 @@ import ( "github.com/netbirdio/netbird/management/server/jwtclaims" nbpeer "github.com/netbirdio/netbird/management/server/peer" "github.com/netbirdio/netbird/management/server/posture" + "github.com/netbirdio/netbird/management/server/settings" internalStatus "github.com/netbirdio/netbird/management/server/status" "github.com/netbirdio/netbird/management/server/telemetry" + "github.com/netbirdio/netbird/management/server/types" ) // GRPCServer an instance of a Management gRPC API server type GRPCServer struct { - accountManager AccountManager - wgKey wgtypes.Key + accountManager AccountManager + settingsManager settings.Manager + wgKey wgtypes.Key proto.UnimplementedManagementServiceServer peersUpdateManager *PeersUpdateManager config *Config @@ -47,6 +50,7 @@ func NewServer( ctx context.Context, config *Config, accountManager AccountManager, + settingsManager settings.Manager, peersUpdateManager *PeersUpdateManager, secretsManager SecretsManager, appMetrics telemetry.AppMetrics, @@ -99,6 +103,7 @@ func NewServer( // peerKey -> event channel peersUpdateManager: peersUpdateManager, accountManager: accountManager, + settingsManager: settingsManager, config: config, secretsManager: secretsManager, jwtValidator: jwtValidator, @@ -483,7 +488,7 @@ func (s *GRPCServer) Login(ctx context.Context, req *proto.EncryptedMessage) (*p // if peer has reached this point then it has logged in loginResp := &proto.LoginResponse{ WiretrusteeConfig: toWiretrusteeConfig(s.config, nil, relayToken), - PeerConfig: toPeerConfig(peer, netMap.Network, s.accountManager.GetDNSDomain()), + PeerConfig: toPeerConfig(peer, netMap.Network, s.accountManager.GetDNSDomain(), false), Checks: toProtocolChecks(ctx, postureChecks), } encryptedResp, err := encryption.EncryptMessage(peerKey, s.wgKey, loginResp) @@ -599,20 +604,21 @@ func toWiretrusteeConfig(config *Config, turnCredentials *Token, relayToken *Tok } } -func toPeerConfig(peer *nbpeer.Peer, network *Network, dnsName string) *proto.PeerConfig { +func toPeerConfig(peer *nbpeer.Peer, network *types.Network, dnsName string, dnsResolutionOnRoutingPeerEnabled bool) *proto.PeerConfig { netmask, _ := network.Net.Mask.Size() fqdn := 
peer.FQDN(dnsName) return &proto.PeerConfig{ - Address: fmt.Sprintf("%s/%d", peer.IP.String(), netmask), // take it from the network - SshConfig: &proto.SSHConfig{SshEnabled: peer.SSHEnabled}, - Fqdn: fqdn, + Address: fmt.Sprintf("%s/%d", peer.IP.String(), netmask), // take it from the network + SshConfig: &proto.SSHConfig{SshEnabled: peer.SSHEnabled}, + Fqdn: fqdn, + RoutingPeerDnsResolutionEnabled: dnsResolutionOnRoutingPeerEnabled, } } -func toSyncResponse(ctx context.Context, config *Config, peer *nbpeer.Peer, turnCredentials *Token, relayCredentials *Token, networkMap *NetworkMap, dnsName string, checks []*posture.Checks, dnsCache *DNSConfigCache) *proto.SyncResponse { +func toSyncResponse(ctx context.Context, config *Config, peer *nbpeer.Peer, turnCredentials *Token, relayCredentials *Token, networkMap *types.NetworkMap, dnsName string, checks []*posture.Checks, dnsCache *DNSConfigCache, dnsResolutionOnRoutingPeerEnbled bool) *proto.SyncResponse { response := &proto.SyncResponse{ WiretrusteeConfig: toWiretrusteeConfig(config, turnCredentials, relayCredentials), - PeerConfig: toPeerConfig(peer, networkMap.Network, dnsName), + PeerConfig: toPeerConfig(peer, networkMap.Network, dnsName, dnsResolutionOnRoutingPeerEnbled), NetworkMap: &proto.NetworkMap{ Serial: networkMap.Network.CurrentSerial(), Routes: toProtocolRoutes(networkMap.Routes), @@ -661,7 +667,7 @@ func (s *GRPCServer) IsHealthy(ctx context.Context, req *proto.Empty) (*proto.Em } // sendInitialSync sends initial proto.SyncResponse to the peer requesting synchronization -func (s *GRPCServer) sendInitialSync(ctx context.Context, peerKey wgtypes.Key, peer *nbpeer.Peer, networkMap *NetworkMap, postureChecks []*posture.Checks, srv proto.ManagementService_SyncServer) error { +func (s *GRPCServer) sendInitialSync(ctx context.Context, peerKey wgtypes.Key, peer *nbpeer.Peer, networkMap *types.NetworkMap, postureChecks []*posture.Checks, srv proto.ManagementService_SyncServer) error { var err error var turnToken *Token @@ -680,7 +686,12 @@ func (s *GRPCServer) sendInitialSync(ctx context.Context, peerKey wgtypes.Key, p } } - plainResp := toSyncResponse(ctx, s.config, peer, turnToken, relayToken, networkMap, s.accountManager.GetDNSDomain(), postureChecks, nil) + settings, err := s.settingsManager.GetSettings(ctx, peer.AccountID, peer.UserID) + if err != nil { + return status.Errorf(codes.Internal, "error handling request") + } + + plainResp := toSyncResponse(ctx, s.config, peer, turnToken, relayToken, networkMap, s.accountManager.GetDNSDomain(), postureChecks, nil, settings.RoutingPeerDNSResolutionEnabled) encryptedResp, err := encryption.EncryptMessage(peerKey, s.wgKey, plainResp) if err != nil { diff --git a/management/server/http/api/openapi.yml b/management/server/http/api/openapi.yml index 2e084f6e4..351976baf 100644 --- a/management/server/http/api/openapi.yml +++ b/management/server/http/api/openapi.yml @@ -84,6 +84,10 @@ components: items: type: string example: Administrators + routing_peer_dns_resolution_enabled: + description: Enables or disables DNS resolution on the routing peers + type: boolean + example: true extra: $ref: '#/components/schemas/AccountExtraSettings' required: @@ -668,6 +672,10 @@ components: description: Count of peers associated to the group type: integer example: 2 + resources_count: + description: Count of resources associated to the group + type: integer + example: 5 issued: description: How the group was issued (api, integration, jwt) type: string @@ -677,6 +685,7 @@ components: - id - name - 
peers_count + - resources_count GroupRequest: type: object properties: @@ -690,6 +699,10 @@ components: items: type: string example: "ch8i4ug6lnn4g9hqv7m1" + resources: + type: array + items: + $ref: '#/components/schemas/Resource' required: - name Group: @@ -702,8 +715,13 @@ components: type: array items: $ref: '#/components/schemas/PeerMinimum' + resources: + type: array + items: + $ref: '#/components/schemas/Resource' required: - peers + - resources PolicyRuleMinimum: type: object properties: @@ -782,15 +800,18 @@ components: items: type: string example: "ch8i4ug6lnn4g9hqv797" + sourceResource: + description: Policy rule source resource that the rule is applied to + $ref: '#/components/schemas/Resource' destinations: description: Policy rule destination group IDs type: array items: type: string example: "ch8i4ug6lnn4g9h7v7m0" - required: - - sources - - destinations + destinationResource: + description: Policy rule destination resource that the rule is applied to + $ref: '#/components/schemas/Resource' PolicyRule: allOf: - $ref: '#/components/schemas/PolicyRuleMinimum' @@ -801,14 +822,17 @@ components: type: array items: $ref: '#/components/schemas/GroupMinimum' + sourceResource: + description: Policy rule source resource that the rule is applied to + $ref: '#/components/schemas/Resource' destinations: description: Policy rule destination group IDs type: array items: $ref: '#/components/schemas/GroupMinimum' - required: - - sources - - destinations + destinationResource: + description: Policy rule destination resource that the rule is applied to + $ref: '#/components/schemas/Resource' PolicyMinimum: type: object properties: @@ -1176,6 +1200,171 @@ components: - id - network_type - $ref: '#/components/schemas/RouteRequest' + Resource: + type: object + properties: + id: + description: ID of the resource + type: string + example: chacdk86lnnboviihd7g + type: + description: Type of the resource + $ref: '#/components/schemas/ResourceType' + required: + - id + - type + ResourceType: + allOf: + - $ref: '#/components/schemas/NetworkResourceType' + - type: string + example: host + NetworkRequest: + type: object + properties: + name: + description: Network name + type: string + example: Remote Network 1 + description: + description: Network description + type: string + example: A remote network that needs to be accessed + required: + - name + Network: + allOf: + - type: object + properties: + id: + description: Network ID + type: string + example: chacdk86lnnboviihd7g + routers: + description: List of router IDs associated with the network + type: array + items: + type: string + example: ch8i4ug6lnn4g9hqv7m0 + routing_peers_count: + description: Count of routing peers associated with the network + type: integer + example: 2 + resources: + description: List of network resource IDs associated with the network + type: array + items: + type: string + example: ch8i4ug6lnn4g9hqv7m1 + policies: + description: List of policy IDs associated with the network + type: array + items: + type: string + example: ch8i4ug6lnn4g9hqv7m2 + required: + - id + - routers + - resources + - routing_peers_count + - policies + - $ref: '#/components/schemas/NetworkRequest' + NetworkResourceMinimum: + type: object + properties: + name: + description: Network resource name + type: string + example: Remote Resource 1 + description: + description: Network resource description + type: string + example: A remote resource inside network 1 + address: + description: Network resource address (either a direct host like 1.1.1.1 or 
1.1.1.1/32, or a subnet like 192.168.178.0/24, or domains like example.com and *.example.com) + type: string + example: "1.1.1.1" + required: + - name + - address + NetworkResourceRequest: + allOf: + - $ref: '#/components/schemas/NetworkResourceMinimum' + - type: object + properties: + groups: + description: Group IDs containing the resource + type: array + items: + type: string + example: "chacdk86lnnboviihd70" + required: + - groups + - address + NetworkResource: + allOf: + - type: object + properties: + id: + description: Network Resource ID + type: string + example: chacdk86lnnboviihd7g + type: + $ref: '#/components/schemas/NetworkResourceType' + groups: + description: Groups that the resource belongs to + type: array + items: + $ref: '#/components/schemas/GroupMinimum' + required: + - id + - type + - groups + - $ref: '#/components/schemas/NetworkResourceMinimum' + NetworkResourceType: + description: Network resource type based of the address + type: string + enum: [ "host", "subnet", "domain" ] + example: host + NetworkRouterRequest: + type: object + properties: + peer: + description: Peer Identifier associated with route. This property can not be set together with `peer_groups` + type: string + example: chacbco6lnnbn6cg5s91 + peer_groups: + description: Peers Group Identifier associated with route. This property can not be set together with `peer` + type: array + items: + type: string + example: chacbco6lnnbn6cg5s91 + metric: + description: Route metric number. Lowest number has higher priority + type: integer + maximum: 9999 + minimum: 1 + example: 9999 + masquerade: + description: Indicate if peer should masquerade traffic to this route's prefix + type: boolean + example: true + required: + # Only one property has to be set + #- peer + #- peer_groups + - metric + - masquerade + NetworkRouter: + allOf: + - type: object + properties: + id: + description: Network Router Id + type: string + example: chacdk86lnnboviihd7g + required: + - id + - $ref: '#/components/schemas/NetworkRouterRequest' Nameserver: type: object properties: @@ -2460,6 +2649,502 @@ paths: "$ref": "#/components/responses/forbidden" '500': "$ref": "#/components/responses/internal_error" + /api/networks: + get: + summary: List all Networks + description: Returns a list of all networks + tags: [ Networks ] + security: + - BearerAuth: [ ] + - TokenAuth: [ ] + responses: + '200': + description: A JSON Array of Networks + content: + application/json: + schema: + type: array + items: + $ref: '#/components/schemas/Network' + '400': + "$ref": "#/components/responses/bad_request" + '401': + "$ref": "#/components/responses/requires_authentication" + '403': + "$ref": "#/components/responses/forbidden" + '500': + "$ref": "#/components/responses/internal_error" + post: + summary: Create a Network + description: Creates a Network + tags: [ Networks ] + security: + - BearerAuth: [ ] + - TokenAuth: [ ] + requestBody: + description: New Network request + content: + 'application/json': + schema: + $ref: '#/components/schemas/NetworkRequest' + responses: + '200': + description: A Network Object + content: + application/json: + schema: + $ref: '#/components/schemas/Network' + '400': + "$ref": "#/components/responses/bad_request" + '401': + "$ref": "#/components/responses/requires_authentication" + '403': + "$ref": "#/components/responses/forbidden" + '500': + "$ref": "#/components/responses/internal_error" + /api/networks/{networkId}: + get: + summary: Retrieve a Network + description: Get information about a Network + tags: [ Networks ] 
+ security: + - BearerAuth: [ ] + - TokenAuth: [ ] + parameters: + - in: path + name: networkId + required: true + schema: + type: string + description: The unique identifier of a network + responses: + '200': + description: A Network object + content: + application/json: + schema: + $ref: '#/components/schemas/Network' + '400': + "$ref": "#/components/responses/bad_request" + '401': + "$ref": "#/components/responses/requires_authentication" + '403': + "$ref": "#/components/responses/forbidden" + '500': + "$ref": "#/components/responses/internal_error" + put: + summary: Update a Network + description: Update/Replace a Network + tags: [ Networks ] + security: + - BearerAuth: [ ] + - TokenAuth: [ ] + parameters: + - in: path + name: networkId + required: true + schema: + type: string + description: The unique identifier of a network + requestBody: + description: Update Network request + content: + application/json: + schema: + $ref: '#/components/schemas/NetworkRequest' + responses: + '200': + description: A Network object + content: + application/json: + schema: + $ref: '#/components/schemas/Network' + '400': + "$ref": "#/components/responses/bad_request" + '401': + "$ref": "#/components/responses/requires_authentication" + '403': + "$ref": "#/components/responses/forbidden" + '500': + "$ref": "#/components/responses/internal_error" + delete: + summary: Delete a Network + description: Delete a network + tags: [ Networks ] + security: + - BearerAuth: [ ] + - TokenAuth: [ ] + parameters: + - in: path + name: networkId + required: true + schema: + type: string + description: The unique identifier of a network + responses: + '200': + description: Delete status code + content: { } + '400': + "$ref": "#/components/responses/bad_request" + '401': + "$ref": "#/components/responses/requires_authentication" + '403': + "$ref": "#/components/responses/forbidden" + '500': + "$ref": "#/components/responses/internal_error" + /api/networks/{networkId}/resources: + get: + summary: List all Network Resources + description: Returns a list of all resources in a network + tags: [ Networks ] + security: + - BearerAuth: [ ] + - TokenAuth: [ ] + parameters: + - in: path + name: networkId + required: true + schema: + type: string + description: The unique identifier of a network + responses: + '200': + description: A JSON Array of Resources + content: + application/json: + schema: + type: array + items: + $ref: '#/components/schemas/NetworkResource' + '400': + "$ref": "#/components/responses/bad_request" + '401': + "$ref": "#/components/responses/requires_authentication" + '403': + "$ref": "#/components/responses/forbidden" + '500': + "$ref": "#/components/responses/internal_error" + post: + summary: Create a Network Resource + description: Creates a Network Resource + tags: [ Networks ] + security: + - BearerAuth: [ ] + - TokenAuth: [ ] + parameters: + - in: path + name: networkId + required: true + schema: + type: string + description: The unique identifier of a network + requestBody: + description: New Network Resource request + content: + 'application/json': + schema: + $ref: '#/components/schemas/NetworkResourceRequest' + responses: + '200': + description: A Network Resource Object + content: + application/json: + schema: + $ref: '#/components/schemas/NetworkResource' + '400': + "$ref": "#/components/responses/bad_request" + '401': + "$ref": "#/components/responses/requires_authentication" + '403': + "$ref": "#/components/responses/forbidden" + '500': + "$ref": "#/components/responses/internal_error" + 
/api/networks/{networkId}/resources/{resourceId}: + get: + summary: Retrieve a Network Resource + description: Get information about a Network Resource + tags: [ Networks ] + security: + - BearerAuth: [ ] + - TokenAuth: [ ] + parameters: + - in: path + name: networkId + required: true + schema: + type: string + description: The unique identifier of a network + - in: path + name: resourceId + required: true + schema: + type: string + description: The unique identifier of a network resource + responses: + '200': + description: A Network Resource object + content: + application/json: + schema: + $ref: '#/components/schemas/NetworkResource' + '400': + "$ref": "#/components/responses/bad_request" + '401': + "$ref": "#/components/responses/requires_authentication" + '403': + "$ref": "#/components/responses/forbidden" + '500': + "$ref": "#/components/responses/internal_error" + put: + summary: Update a Network Resource + description: Update a Network Resource + tags: [ Networks ] + security: + - BearerAuth: [ ] + - TokenAuth: [ ] + parameters: + - in: path + name: networkId + required: true + schema: + type: string + description: The unique identifier of a network + - in: path + name: resourceId + required: true + schema: + type: string + description: The unique identifier of a resource + requestBody: + description: Update Network Resource request + content: + 'application/json': + schema: + $ref: '#/components/schemas/NetworkResourceRequest' + responses: + '200': + description: A Network Resource object + content: + application/json: + schema: + $ref: '#/components/schemas/NetworkResource' + '400': + "$ref": "#/components/responses/bad_request" + '401': + "$ref": "#/components/responses/requires_authentication" + '403': + "$ref": "#/components/responses/forbidden" + '500': + "$ref": "#/components/responses/internal_error" + delete: + summary: Delete a Network Resource + description: Delete a network resource + tags: [ Networks ] + security: + - BearerAuth: [ ] + - TokenAuth: [ ] + parameters: + - in: path + name: networkId + required: true + schema: + type: string + description: The unique identifier of a network + - in: path + name: resourceId + required: true + schema: + type: string + description: The unique identifier of a network resource + responses: + '200': + description: Delete status code + content: { } + '400': + "$ref": "#/components/responses/bad_request" + '401': + "$ref": "#/components/responses/requires_authentication" + '403': + "$ref": "#/components/responses/forbidden" + '500': + "$ref": "#/components/responses/internal_error" + /api/networks/{networkId}/routers: + get: + summary: List all Network Routers + description: Returns a list of all routers in a network + tags: [ Networks ] + security: + - BearerAuth: [ ] + - TokenAuth: [ ] + parameters: + - in: path + name: networkId + required: true + schema: + type: string + description: The unique identifier of a network + responses: + '200': + description: A JSON Array of Routers + content: + application/json: + schema: + type: array + items: + $ref: '#/components/schemas/NetworkRouter' + '400': + "$ref": "#/components/responses/bad_request" + '401': + "$ref": "#/components/responses/requires_authentication" + '403': + "$ref": "#/components/responses/forbidden" + '500': + "$ref": "#/components/responses/internal_error" + post: + summary: Create a Network Router + description: Creates a Network Router + tags: [ Networks ] + security: + - BearerAuth: [ ] + - TokenAuth: [ ] + parameters: + - in: path + name: networkId + required: 
true + schema: + type: string + description: The unique identifier of a network + requestBody: + description: New Network Router request + content: + 'application/json': + schema: + $ref: '#/components/schemas/NetworkRouterRequest' + responses: + '200': + description: A Router Object + content: + application/json: + schema: + $ref: '#/components/schemas/NetworkRouter' + '400': + "$ref": "#/components/responses/bad_request" + '401': + "$ref": "#/components/responses/requires_authentication" + '403': + "$ref": "#/components/responses/forbidden" + '500': + "$ref": "#/components/responses/internal_error" + /api/networks/{networkId}/routers/{routerId}: + get: + summary: Retrieve a Network Router + description: Get information about a Network Router + tags: [ Networks ] + security: + - BearerAuth: [ ] + - TokenAuth: [ ] + parameters: + - in: path + name: networkId + required: true + schema: + type: string + description: The unique identifier of a network + - in: path + name: routerId + required: true + schema: + type: string + description: The unique identifier of a router + responses: + '200': + description: A Router object + content: + application/json: + schema: + $ref: '#/components/schemas/NetworkRouter' + '400': + "$ref": "#/components/responses/bad_request" + '401': + "$ref": "#/components/responses/requires_authentication" + '403': + "$ref": "#/components/responses/forbidden" + '500': + "$ref": "#/components/responses/internal_error" + put: + summary: Update a Network Router + description: Update a Network Router + tags: [ Networks ] + security: + - BearerAuth: [ ] + - TokenAuth: [ ] + parameters: + - in: path + name: networkId + required: true + schema: + type: string + description: The unique identifier of a network + - in: path + name: routerId + required: true + schema: + type: string + description: The unique identifier of a router + requestBody: + description: Update Network Router request + content: + 'application/json': + schema: + $ref: '#/components/schemas/NetworkRouterRequest' + responses: + '200': + description: A Router object + content: + application/json: + schema: + $ref: '#/components/schemas/NetworkRouter' + '400': + "$ref": "#/components/responses/bad_request" + '401': + "$ref": "#/components/responses/requires_authentication" + '403': + "$ref": "#/components/responses/forbidden" + '500': + "$ref": "#/components/responses/internal_error" + delete: + summary: Delete a Network Router + description: Delete a network router + tags: [ Networks ] + security: + - BearerAuth: [ ] + - TokenAuth: [ ] + parameters: + - in: path + name: networkId + required: true + schema: + type: string + description: The unique identifier of a network + - in: path + name: routerId + required: true + schema: + type: string + description: The unique identifier of a router + responses: + '200': + description: Delete status code + content: { } + '400': + "$ref": "#/components/responses/bad_request" + '401': + "$ref": "#/components/responses/requires_authentication" + '403': + "$ref": "#/components/responses/forbidden" + '500': + "$ref": "#/components/responses/internal_error" /api/dns/nameservers: get: summary: List all Nameserver Groups diff --git a/management/server/http/api/types.gen.go b/management/server/http/api/types.gen.go index 321395d25..40574d6f1 100644 --- a/management/server/http/api/types.gen.go +++ b/management/server/http/api/types.gen.go @@ -88,6 +88,13 @@ const ( NameserverNsTypeUdp NameserverNsType = "udp" ) +// Defines values for NetworkResourceType. 
+const ( + NetworkResourceTypeDomain NetworkResourceType = "domain" + NetworkResourceTypeHost NetworkResourceType = "host" + NetworkResourceTypeSubnet NetworkResourceType = "subnet" +) + // Defines values for PeerNetworkRangeCheckAction. const ( PeerNetworkRangeCheckActionAllow PeerNetworkRangeCheckAction = "allow" @@ -136,6 +143,13 @@ const ( PolicyRuleUpdateProtocolUdp PolicyRuleUpdateProtocol = "udp" ) +// Defines values for ResourceType. +const ( + ResourceTypeDomain ResourceType = "domain" + ResourceTypeHost ResourceType = "host" + ResourceTypeSubnet ResourceType = "subnet" +) + // Defines values for UserStatus. const ( UserStatusActive UserStatus = "active" @@ -234,6 +248,9 @@ type AccountSettings struct { // RegularUsersViewBlocked Allows blocking regular users from viewing parts of the system. RegularUsersViewBlocked bool `json:"regular_users_view_blocked"` + + // RoutingPeerDnsResolutionEnabled Enables or disables DNS resolution on the routing peers + RoutingPeerDnsResolutionEnabled *bool `json:"routing_peer_dns_resolution_enabled,omitempty"` } // Checks List of objects that perform the actual checks @@ -365,7 +382,11 @@ type Group struct { Peers []PeerMinimum `json:"peers"` // PeersCount Count of peers associated to the group - PeersCount int `json:"peers_count"` + PeersCount int `json:"peers_count"` + Resources []Resource `json:"resources"` + + // ResourcesCount Count of resources associated to the group + ResourcesCount int `json:"resources_count"` } // GroupIssued How the group was issued (api, integration, jwt) @@ -384,6 +405,9 @@ type GroupMinimum struct { // PeersCount Count of peers associated to the group PeersCount int `json:"peers_count"` + + // ResourcesCount Count of resources associated to the group + ResourcesCount int `json:"resources_count"` } // GroupMinimumIssued How the group was issued (api, integration, jwt) @@ -395,7 +419,8 @@ type GroupRequest struct { Name string `json:"name"` // Peers List of peers ids - Peers *[]string `json:"peers,omitempty"` + Peers *[]string `json:"peers,omitempty"` + Resources *[]Resource `json:"resources,omitempty"` } // Location Describe geographical location information @@ -494,6 +519,123 @@ type NameserverGroupRequest struct { SearchDomainsEnabled bool `json:"search_domains_enabled"` } +// Network defines model for Network. +type Network struct { + // Description Network description + Description *string `json:"description,omitempty"` + + // Id Network ID + Id string `json:"id"` + + // Name Network name + Name string `json:"name"` + + // Policies List of policy IDs associated with the network + Policies []string `json:"policies"` + + // Resources List of network resource IDs associated with the network + Resources []string `json:"resources"` + + // Routers List of router IDs associated with the network + Routers []string `json:"routers"` + + // RoutingPeersCount Count of routing peers associated with the network + RoutingPeersCount int `json:"routing_peers_count"` +} + +// NetworkRequest defines model for NetworkRequest. +type NetworkRequest struct { + // Description Network description + Description *string `json:"description,omitempty"` + + // Name Network name + Name string `json:"name"` +} + +// NetworkResource defines model for NetworkResource. 
+type NetworkResource struct { + // Address Network resource address (either a direct host like 1.1.1.1 or 1.1.1.1/32, or a subnet like 192.168.178.0/24, or domains like example.com and *.example.com) + Address string `json:"address"` + + // Description Network resource description + Description *string `json:"description,omitempty"` + + // Groups Groups that the resource belongs to + Groups []GroupMinimum `json:"groups"` + + // Id Network Resource ID + Id string `json:"id"` + + // Name Network resource name + Name string `json:"name"` + + // Type Network resource type based of the address + Type NetworkResourceType `json:"type"` +} + +// NetworkResourceMinimum defines model for NetworkResourceMinimum. +type NetworkResourceMinimum struct { + // Address Network resource address (either a direct host like 1.1.1.1 or 1.1.1.1/32, or a subnet like 192.168.178.0/24, or domains like example.com and *.example.com) + Address string `json:"address"` + + // Description Network resource description + Description *string `json:"description,omitempty"` + + // Name Network resource name + Name string `json:"name"` +} + +// NetworkResourceRequest defines model for NetworkResourceRequest. +type NetworkResourceRequest struct { + // Address Network resource address (either a direct host like 1.1.1.1 or 1.1.1.1/32, or a subnet like 192.168.178.0/24, or domains like example.com and *.example.com) + Address string `json:"address"` + + // Description Network resource description + Description *string `json:"description,omitempty"` + + // Groups Group IDs containing the resource + Groups []string `json:"groups"` + + // Name Network resource name + Name string `json:"name"` +} + +// NetworkResourceType Network resource type based of the address +type NetworkResourceType string + +// NetworkRouter defines model for NetworkRouter. +type NetworkRouter struct { + // Id Network Router Id + Id string `json:"id"` + + // Masquerade Indicate if peer should masquerade traffic to this route's prefix + Masquerade bool `json:"masquerade"` + + // Metric Route metric number. Lowest number has higher priority + Metric int `json:"metric"` + + // Peer Peer Identifier associated with route. This property can not be set together with `peer_groups` + Peer *string `json:"peer,omitempty"` + + // PeerGroups Peers Group Identifier associated with route. This property can not be set together with `peer` + PeerGroups *[]string `json:"peer_groups,omitempty"` +} + +// NetworkRouterRequest defines model for NetworkRouterRequest. +type NetworkRouterRequest struct { + // Masquerade Indicate if peer should masquerade traffic to this route's prefix + Masquerade bool `json:"masquerade"` + + // Metric Route metric number. Lowest number has higher priority + Metric int `json:"metric"` + + // Peer Peer Identifier associated with route. This property can not be set together with `peer_groups` + Peer *string `json:"peer,omitempty"` + + // PeerGroups Peers Group Identifier associated with route. 
This property can not be set together with `peer` + PeerGroups *[]string `json:"peer_groups,omitempty"` +} + // OSVersionCheck Posture check for the version of operating system type OSVersionCheck struct { // Android Posture check for the version of operating system @@ -779,10 +921,11 @@ type PolicyRule struct { Bidirectional bool `json:"bidirectional"` // Description Policy rule friendly description - Description *string `json:"description,omitempty"` + Description *string `json:"description,omitempty"` + DestinationResource *Resource `json:"destinationResource,omitempty"` // Destinations Policy rule destination group IDs - Destinations []GroupMinimum `json:"destinations"` + Destinations *[]GroupMinimum `json:"destinations,omitempty"` // Enabled Policy rule status Enabled bool `json:"enabled"` @@ -800,10 +943,11 @@ type PolicyRule struct { Ports *[]string `json:"ports,omitempty"` // Protocol Policy rule type of the traffic - Protocol PolicyRuleProtocol `json:"protocol"` + Protocol PolicyRuleProtocol `json:"protocol"` + SourceResource *Resource `json:"sourceResource,omitempty"` // Sources Policy rule source group IDs - Sources []GroupMinimum `json:"sources"` + Sources *[]GroupMinimum `json:"sources,omitempty"` } // PolicyRuleAction Policy rule accept or drops packets @@ -857,10 +1001,11 @@ type PolicyRuleUpdate struct { Bidirectional bool `json:"bidirectional"` // Description Policy rule friendly description - Description *string `json:"description,omitempty"` + Description *string `json:"description,omitempty"` + DestinationResource *Resource `json:"destinationResource,omitempty"` // Destinations Policy rule destination group IDs - Destinations []string `json:"destinations"` + Destinations *[]string `json:"destinations,omitempty"` // Enabled Policy rule status Enabled bool `json:"enabled"` @@ -878,10 +1023,11 @@ type PolicyRuleUpdate struct { Ports *[]string `json:"ports,omitempty"` // Protocol Policy rule type of the traffic - Protocol PolicyRuleUpdateProtocol `json:"protocol"` + Protocol PolicyRuleUpdateProtocol `json:"protocol"` + SourceResource *Resource `json:"sourceResource,omitempty"` // Sources Policy rule source group IDs - Sources []string `json:"sources"` + Sources *[]string `json:"sources,omitempty"` } // PolicyRuleUpdateAction Policy rule accept or drops packets @@ -955,6 +1101,16 @@ type ProcessCheck struct { Processes []Process `json:"processes"` } +// Resource defines model for Resource. +type Resource struct { + // Id ID of the resource + Id string `json:"id"` + Type ResourceType `json:"type"` +} + +// ResourceType defines model for ResourceType. +type ResourceType string + // Route defines model for Route. type Route struct { // AccessControlGroups Access control group identifier associated with route. @@ -1292,6 +1448,24 @@ type PostApiGroupsJSONRequestBody = GroupRequest // PutApiGroupsGroupIdJSONRequestBody defines body for PutApiGroupsGroupId for application/json ContentType. type PutApiGroupsGroupIdJSONRequestBody = GroupRequest +// PostApiNetworksJSONRequestBody defines body for PostApiNetworks for application/json ContentType. +type PostApiNetworksJSONRequestBody = NetworkRequest + +// PutApiNetworksNetworkIdJSONRequestBody defines body for PutApiNetworksNetworkId for application/json ContentType. +type PutApiNetworksNetworkIdJSONRequestBody = NetworkRequest + +// PostApiNetworksNetworkIdResourcesJSONRequestBody defines body for PostApiNetworksNetworkIdResources for application/json ContentType. 
+type PostApiNetworksNetworkIdResourcesJSONRequestBody = NetworkResourceRequest + +// PutApiNetworksNetworkIdResourcesResourceIdJSONRequestBody defines body for PutApiNetworksNetworkIdResourcesResourceId for application/json ContentType. +type PutApiNetworksNetworkIdResourcesResourceIdJSONRequestBody = NetworkResourceRequest + +// PostApiNetworksNetworkIdRoutersJSONRequestBody defines body for PostApiNetworksNetworkIdRouters for application/json ContentType. +type PostApiNetworksNetworkIdRoutersJSONRequestBody = NetworkRouterRequest + +// PutApiNetworksNetworkIdRoutersRouterIdJSONRequestBody defines body for PutApiNetworksNetworkIdRoutersRouterId for application/json ContentType. +type PutApiNetworksNetworkIdRoutersRouterIdJSONRequestBody = NetworkRouterRequest + // PutApiPeersPeerIdJSONRequestBody defines body for PutApiPeersPeerId for application/json ContentType. type PutApiPeersPeerIdJSONRequestBody = PeerRequest diff --git a/management/server/http/handler.go b/management/server/http/handler.go index 373aa4dd7..7db7ab5b8 100644 --- a/management/server/http/handler.go +++ b/management/server/http/handler.go @@ -12,11 +12,13 @@ import ( s "github.com/netbirdio/netbird/management/server" "github.com/netbirdio/netbird/management/server/geolocation" + nbgroups "github.com/netbirdio/netbird/management/server/groups" "github.com/netbirdio/netbird/management/server/http/configs" "github.com/netbirdio/netbird/management/server/http/handlers/accounts" "github.com/netbirdio/netbird/management/server/http/handlers/dns" "github.com/netbirdio/netbird/management/server/http/handlers/events" "github.com/netbirdio/netbird/management/server/http/handlers/groups" + "github.com/netbirdio/netbird/management/server/http/handlers/networks" "github.com/netbirdio/netbird/management/server/http/handlers/peers" "github.com/netbirdio/netbird/management/server/http/handlers/policies" "github.com/netbirdio/netbird/management/server/http/handlers/routes" @@ -25,6 +27,9 @@ import ( "github.com/netbirdio/netbird/management/server/http/middleware" "github.com/netbirdio/netbird/management/server/integrated_validator" "github.com/netbirdio/netbird/management/server/jwtclaims" + nbnetworks "github.com/netbirdio/netbird/management/server/networks" + "github.com/netbirdio/netbird/management/server/networks/resources" + "github.com/netbirdio/netbird/management/server/networks/routers" "github.com/netbirdio/netbird/management/server/telemetry" ) @@ -38,7 +43,7 @@ type apiHandler struct { } // APIHandler creates the Management service HTTP API handler registering all the available endpoints. 
-func APIHandler(ctx context.Context, accountManager s.AccountManager, LocationManager *geolocation.Geolocation, jwtValidator jwtclaims.JWTValidator, appMetrics telemetry.AppMetrics, authCfg configs.AuthCfg, integratedValidator integrated_validator.IntegratedValidator) (http.Handler, error) { +func APIHandler(ctx context.Context, accountManager s.AccountManager, networksManager nbnetworks.Manager, resourceManager resources.Manager, routerManager routers.Manager, groupsManager nbgroups.Manager, LocationManager *geolocation.Geolocation, jwtValidator jwtclaims.JWTValidator, appMetrics telemetry.AppMetrics, authCfg configs.AuthCfg, integratedValidator integrated_validator.IntegratedValidator) (http.Handler, error) { claimsExtractor := jwtclaims.NewClaimsExtractor( jwtclaims.WithAudience(authCfg.Audience), jwtclaims.WithUserIDClaim(authCfg.UserIDClaim), @@ -93,6 +98,7 @@ func APIHandler(ctx context.Context, accountManager s.AccountManager, LocationMa routes.AddEndpoints(api.AccountManager, authCfg, router) dns.AddEndpoints(api.AccountManager, authCfg, router) events.AddEndpoints(api.AccountManager, authCfg, router) + networks.AddEndpoints(networksManager, resourceManager, routerManager, groupsManager, api.AccountManager, api.AccountManager.GetAccountIDFromToken, authCfg, router) return rootRouter, nil } diff --git a/management/server/http/handlers/accounts/accounts_handler.go b/management/server/http/handlers/accounts/accounts_handler.go index c95207777..a23628cdc 100644 --- a/management/server/http/handlers/accounts/accounts_handler.go +++ b/management/server/http/handlers/accounts/accounts_handler.go @@ -14,6 +14,7 @@ import ( "github.com/netbirdio/netbird/management/server/http/util" "github.com/netbirdio/netbird/management/server/jwtclaims" "github.com/netbirdio/netbird/management/server/status" + "github.com/netbirdio/netbird/management/server/types" ) // handler is a handler that handles the server.Account HTTP endpoints @@ -82,7 +83,7 @@ func (h *handler) updateAccount(w http.ResponseWriter, r *http.Request) { return } - settings := &server.Settings{ + settings := &types.Settings{ PeerLoginExpirationEnabled: req.Settings.PeerLoginExpirationEnabled, PeerLoginExpiration: time.Duration(float64(time.Second.Nanoseconds()) * float64(req.Settings.PeerLoginExpiration)), RegularUsersViewBlocked: req.Settings.RegularUsersViewBlocked, @@ -107,6 +108,9 @@ func (h *handler) updateAccount(w http.ResponseWriter, r *http.Request) { if req.Settings.JwtAllowGroups != nil { settings.JWTAllowGroups = *req.Settings.JwtAllowGroups } + if req.Settings.RoutingPeerDnsResolutionEnabled != nil { + settings.RoutingPeerDNSResolutionEnabled = *req.Settings.RoutingPeerDnsResolutionEnabled + } updatedAccount, err := h.accountManager.UpdateAccountSettings(r.Context(), accountID, userID, settings) if err != nil { @@ -138,7 +142,7 @@ func (h *handler) deleteAccount(w http.ResponseWriter, r *http.Request) { util.WriteJSONObject(r.Context(), w, util.EmptyObject{}) } -func toAccountResponse(accountID string, settings *server.Settings) *api.Account { +func toAccountResponse(accountID string, settings *types.Settings) *api.Account { jwtAllowGroups := settings.JWTAllowGroups if jwtAllowGroups == nil { jwtAllowGroups = []string{} @@ -154,6 +158,7 @@ func toAccountResponse(accountID string, settings *server.Settings) *api.Account JwtGroupsClaimName: &settings.JWTGroupsClaimName, JwtAllowGroups: &jwtAllowGroups, RegularUsersViewBlocked: settings.RegularUsersViewBlocked, + RoutingPeerDnsResolutionEnabled: 
&settings.RoutingPeerDNSResolutionEnabled, } if settings.Extra != nil { diff --git a/management/server/http/handlers/accounts/accounts_handler_test.go b/management/server/http/handlers/accounts/accounts_handler_test.go index 9d7e8a84d..e8a599863 100644 --- a/management/server/http/handlers/accounts/accounts_handler_test.go +++ b/management/server/http/handlers/accounts/accounts_handler_test.go @@ -13,23 +13,23 @@ import ( "github.com/gorilla/mux" "github.com/stretchr/testify/assert" - "github.com/netbirdio/netbird/management/server" "github.com/netbirdio/netbird/management/server/http/api" "github.com/netbirdio/netbird/management/server/jwtclaims" "github.com/netbirdio/netbird/management/server/mock_server" "github.com/netbirdio/netbird/management/server/status" + "github.com/netbirdio/netbird/management/server/types" ) -func initAccountsTestData(account *server.Account, admin *server.User) *handler { +func initAccountsTestData(account *types.Account, admin *types.User) *handler { return &handler{ accountManager: &mock_server.MockAccountManager{ GetAccountIDFromTokenFunc: func(ctx context.Context, claims jwtclaims.AuthorizationClaims) (string, string, error) { return account.Id, admin.Id, nil }, - GetAccountSettingsFunc: func(ctx context.Context, accountID string, userID string) (*server.Settings, error) { + GetAccountSettingsFunc: func(ctx context.Context, accountID string, userID string) (*types.Settings, error) { return account.Settings, nil }, - UpdateAccountSettingsFunc: func(ctx context.Context, accountID, userID string, newSettings *server.Settings) (*server.Account, error) { + UpdateAccountSettingsFunc: func(ctx context.Context, accountID, userID string, newSettings *types.Settings) (*types.Account, error) { halfYearLimit := 180 * 24 * time.Hour if newSettings.PeerLoginExpiration > halfYearLimit { return nil, status.Errorf(status.InvalidArgument, "peer login expiration can't be larger than 180 days") @@ -58,19 +58,19 @@ func initAccountsTestData(account *server.Account, admin *server.User) *handler func TestAccounts_AccountsHandler(t *testing.T) { accountID := "test_account" - adminUser := server.NewAdminUser("test_user") + adminUser := types.NewAdminUser("test_user") sr := func(v string) *string { return &v } br := func(v bool) *bool { return &v } - handler := initAccountsTestData(&server.Account{ + handler := initAccountsTestData(&types.Account{ Id: accountID, Domain: "hotmail.com", - Network: server.NewNetwork(), - Users: map[string]*server.User{ + Network: types.NewNetwork(), + Users: map[string]*types.User{ adminUser.Id: adminUser, }, - Settings: &server.Settings{ + Settings: &types.Settings{ PeerLoginExpirationEnabled: false, PeerLoginExpiration: time.Hour, RegularUsersViewBlocked: true, @@ -95,13 +95,14 @@ func TestAccounts_AccountsHandler(t *testing.T) { requestPath: "/api/accounts", expectedStatus: http.StatusOK, expectedSettings: api.AccountSettings{ - PeerLoginExpiration: int(time.Hour.Seconds()), - PeerLoginExpirationEnabled: false, - GroupsPropagationEnabled: br(false), - JwtGroupsClaimName: sr(""), - JwtGroupsEnabled: br(false), - JwtAllowGroups: &[]string{}, - RegularUsersViewBlocked: true, + PeerLoginExpiration: int(time.Hour.Seconds()), + PeerLoginExpirationEnabled: false, + GroupsPropagationEnabled: br(false), + JwtGroupsClaimName: sr(""), + JwtGroupsEnabled: br(false), + JwtAllowGroups: &[]string{}, + RegularUsersViewBlocked: true, + RoutingPeerDnsResolutionEnabled: br(false), }, expectedArray: true, expectedID: accountID, @@ -114,13 +115,14 @@ func 
TestAccounts_AccountsHandler(t *testing.T) { requestBody: bytes.NewBufferString("{\"settings\": {\"peer_login_expiration\": 15552000,\"peer_login_expiration_enabled\": true}}"), expectedStatus: http.StatusOK, expectedSettings: api.AccountSettings{ - PeerLoginExpiration: 15552000, - PeerLoginExpirationEnabled: true, - GroupsPropagationEnabled: br(false), - JwtGroupsClaimName: sr(""), - JwtGroupsEnabled: br(false), - JwtAllowGroups: &[]string{}, - RegularUsersViewBlocked: false, + PeerLoginExpiration: 15552000, + PeerLoginExpirationEnabled: true, + GroupsPropagationEnabled: br(false), + JwtGroupsClaimName: sr(""), + JwtGroupsEnabled: br(false), + JwtAllowGroups: &[]string{}, + RegularUsersViewBlocked: false, + RoutingPeerDnsResolutionEnabled: br(false), }, expectedArray: false, expectedID: accountID, @@ -133,13 +135,14 @@ func TestAccounts_AccountsHandler(t *testing.T) { requestBody: bytes.NewBufferString("{\"settings\": {\"peer_login_expiration\": 15552000,\"peer_login_expiration_enabled\": false,\"jwt_groups_enabled\":true,\"jwt_groups_claim_name\":\"roles\",\"jwt_allow_groups\":[\"test\"],\"regular_users_view_blocked\":true}}"), expectedStatus: http.StatusOK, expectedSettings: api.AccountSettings{ - PeerLoginExpiration: 15552000, - PeerLoginExpirationEnabled: false, - GroupsPropagationEnabled: br(false), - JwtGroupsClaimName: sr("roles"), - JwtGroupsEnabled: br(true), - JwtAllowGroups: &[]string{"test"}, - RegularUsersViewBlocked: true, + PeerLoginExpiration: 15552000, + PeerLoginExpirationEnabled: false, + GroupsPropagationEnabled: br(false), + JwtGroupsClaimName: sr("roles"), + JwtGroupsEnabled: br(true), + JwtAllowGroups: &[]string{"test"}, + RegularUsersViewBlocked: true, + RoutingPeerDnsResolutionEnabled: br(false), }, expectedArray: false, expectedID: accountID, @@ -152,13 +155,14 @@ func TestAccounts_AccountsHandler(t *testing.T) { requestBody: bytes.NewBufferString("{\"settings\": {\"peer_login_expiration\": 554400,\"peer_login_expiration_enabled\": true,\"jwt_groups_enabled\":true,\"jwt_groups_claim_name\":\"groups\",\"groups_propagation_enabled\":true,\"regular_users_view_blocked\":true}}"), expectedStatus: http.StatusOK, expectedSettings: api.AccountSettings{ - PeerLoginExpiration: 554400, - PeerLoginExpirationEnabled: true, - GroupsPropagationEnabled: br(true), - JwtGroupsClaimName: sr("groups"), - JwtGroupsEnabled: br(true), - JwtAllowGroups: &[]string{}, - RegularUsersViewBlocked: true, + PeerLoginExpiration: 554400, + PeerLoginExpirationEnabled: true, + GroupsPropagationEnabled: br(true), + JwtGroupsClaimName: sr("groups"), + JwtGroupsEnabled: br(true), + JwtAllowGroups: &[]string{}, + RegularUsersViewBlocked: true, + RoutingPeerDnsResolutionEnabled: br(false), }, expectedArray: false, expectedID: accountID, diff --git a/management/server/http/handlers/dns/dns_settings_handler.go b/management/server/http/handlers/dns/dns_settings_handler.go index 7dd8c1fc1..112eee179 100644 --- a/management/server/http/handlers/dns/dns_settings_handler.go +++ b/management/server/http/handlers/dns/dns_settings_handler.go @@ -12,6 +12,7 @@ import ( "github.com/netbirdio/netbird/management/server/http/configs" "github.com/netbirdio/netbird/management/server/http/util" "github.com/netbirdio/netbird/management/server/jwtclaims" + "github.com/netbirdio/netbird/management/server/types" ) // dnsSettingsHandler is a handler that returns the DNS settings of the account @@ -81,7 +82,7 @@ func (h *dnsSettingsHandler) updateDNSSettings(w http.ResponseWriter, r *http.Re return } - updateDNSSettings := 
&server.DNSSettings{ + updateDNSSettings := &types.DNSSettings{ DisabledManagementGroups: req.DisabledManagementGroups, } diff --git a/management/server/http/handlers/dns/dns_settings_handler_test.go b/management/server/http/handlers/dns/dns_settings_handler_test.go index a64e3fd83..9ca1dc032 100644 --- a/management/server/http/handlers/dns/dns_settings_handler_test.go +++ b/management/server/http/handlers/dns/dns_settings_handler_test.go @@ -13,10 +13,10 @@ import ( "github.com/netbirdio/netbird/management/server/http/api" "github.com/netbirdio/netbird/management/server/status" + "github.com/netbirdio/netbird/management/server/types" "github.com/gorilla/mux" - "github.com/netbirdio/netbird/management/server" "github.com/netbirdio/netbird/management/server/jwtclaims" "github.com/netbirdio/netbird/management/server/mock_server" ) @@ -27,15 +27,15 @@ const ( testDNSSettingsUserID = "test_user" ) -var baseExistingDNSSettings = server.DNSSettings{ +var baseExistingDNSSettings = types.DNSSettings{ DisabledManagementGroups: []string{testDNSSettingsExistingGroup}, } -var testingDNSSettingsAccount = &server.Account{ +var testingDNSSettingsAccount = &types.Account{ Id: testDNSSettingsAccountID, Domain: "hotmail.com", - Users: map[string]*server.User{ - testDNSSettingsUserID: server.NewAdminUser("test_user"), + Users: map[string]*types.User{ + testDNSSettingsUserID: types.NewAdminUser("test_user"), }, DNSSettings: baseExistingDNSSettings, } @@ -43,10 +43,10 @@ var testingDNSSettingsAccount = &server.Account{ func initDNSSettingsTestData() *dnsSettingsHandler { return &dnsSettingsHandler{ accountManager: &mock_server.MockAccountManager{ - GetDNSSettingsFunc: func(ctx context.Context, accountID string, userID string) (*server.DNSSettings, error) { + GetDNSSettingsFunc: func(ctx context.Context, accountID string, userID string) (*types.DNSSettings, error) { return &testingDNSSettingsAccount.DNSSettings, nil }, - SaveDNSSettingsFunc: func(ctx context.Context, accountID string, userID string, dnsSettingsToSave *server.DNSSettings) error { + SaveDNSSettingsFunc: func(ctx context.Context, accountID string, userID string, dnsSettingsToSave *types.DNSSettings) error { if dnsSettingsToSave != nil { return nil } diff --git a/management/server/http/handlers/events/events_handler_test.go b/management/server/http/handlers/events/events_handler_test.go index 6af2e5346..17478aba3 100644 --- a/management/server/http/handlers/events/events_handler_test.go +++ b/management/server/http/handlers/events/events_handler_test.go @@ -13,11 +13,11 @@ import ( "github.com/gorilla/mux" "github.com/stretchr/testify/assert" - "github.com/netbirdio/netbird/management/server" "github.com/netbirdio/netbird/management/server/activity" "github.com/netbirdio/netbird/management/server/http/api" "github.com/netbirdio/netbird/management/server/jwtclaims" "github.com/netbirdio/netbird/management/server/mock_server" + "github.com/netbirdio/netbird/management/server/types" ) func initEventsTestData(account string, events ...*activity.Event) *handler { @@ -32,8 +32,8 @@ func initEventsTestData(account string, events ...*activity.Event) *handler { GetAccountIDFromTokenFunc: func(_ context.Context, claims jwtclaims.AuthorizationClaims) (string, string, error) { return claims.AccountId, claims.UserId, nil }, - GetUsersFromAccountFunc: func(_ context.Context, accountID, userID string) ([]*server.UserInfo, error) { - return make([]*server.UserInfo, 0), nil + GetUsersFromAccountFunc: func(_ context.Context, accountID, userID string) 
([]*types.UserInfo, error) { + return make([]*types.UserInfo, 0), nil }, }, claimsExtractor: jwtclaims.NewClaimsExtractor( @@ -191,7 +191,7 @@ func TestEvents_GetEvents(t *testing.T) { }, } accountID := "test_account" - adminUser := server.NewAdminUser("test_user") + adminUser := types.NewAdminUser("test_user") events := generateEvents(accountID, adminUser.Id) handler := initEventsTestData(accountID, events...) diff --git a/management/server/http/handlers/groups/groups_handler.go b/management/server/http/handlers/groups/groups_handler.go index e60529cec..0ecea7ec2 100644 --- a/management/server/http/handlers/groups/groups_handler.go +++ b/management/server/http/handlers/groups/groups_handler.go @@ -9,9 +9,9 @@ import ( "github.com/netbirdio/netbird/management/server/http/configs" nbpeer "github.com/netbirdio/netbird/management/server/peer" + "github.com/netbirdio/netbird/management/server/types" "github.com/netbirdio/netbird/management/server" - nbgroup "github.com/netbirdio/netbird/management/server/group" "github.com/netbirdio/netbird/management/server/http/api" "github.com/netbirdio/netbird/management/server/http/util" "github.com/netbirdio/netbird/management/server/jwtclaims" @@ -129,10 +129,21 @@ func (h *handler) updateGroup(w http.ResponseWriter, r *http.Request) { } else { peers = *req.Peers } - group := nbgroup.Group{ + + resources := make([]types.Resource, 0) + if req.Resources != nil { + for _, res := range *req.Resources { + resource := types.Resource{} + resource.FromAPIRequest(&res) + resources = append(resources, resource) + } + } + + group := types.Group{ ID: groupID, Name: req.Name, Peers: peers, + Resources: resources, Issued: existingGroup.Issued, IntegrationReference: existingGroup.IntegrationReference, } @@ -179,10 +190,21 @@ func (h *handler) createGroup(w http.ResponseWriter, r *http.Request) { } else { peers = *req.Peers } - group := nbgroup.Group{ - Name: req.Name, - Peers: peers, - Issued: nbgroup.GroupIssuedAPI, + + resources := make([]types.Resource, 0) + if req.Resources != nil { + for _, res := range *req.Resources { + resource := types.Resource{} + resource.FromAPIRequest(&res) + resources = append(resources, resource) + } + } + + group := types.Group{ + Name: req.Name, + Peers: peers, + Resources: resources, + Issued: types.GroupIssuedAPI, } err = h.accountManager.SaveGroup(r.Context(), accountID, userID, &group) @@ -259,13 +281,13 @@ func (h *handler) getGroup(w http.ResponseWriter, r *http.Request) { } -func toGroupResponse(peers []*nbpeer.Peer, group *nbgroup.Group) *api.Group { +func toGroupResponse(peers []*nbpeer.Peer, group *types.Group) *api.Group { peersMap := make(map[string]*nbpeer.Peer, len(peers)) for _, peer := range peers { peersMap[peer.ID] = peer } - cache := make(map[string]api.PeerMinimum) + peerCache := make(map[string]api.PeerMinimum) gr := api.Group{ Id: group.ID, Name: group.Name, @@ -273,7 +295,7 @@ func toGroupResponse(peers []*nbpeer.Peer, group *nbgroup.Group) *api.Group { } for _, pid := range group.Peers { - _, ok := cache[pid] + _, ok := peerCache[pid] if !ok { peer, ok := peersMap[pid] if !ok { @@ -283,12 +305,19 @@ func toGroupResponse(peers []*nbpeer.Peer, group *nbgroup.Group) *api.Group { Id: peer.ID, Name: peer.Name, } - cache[pid] = peerResp + peerCache[pid] = peerResp gr.Peers = append(gr.Peers, peerResp) } } gr.PeersCount = len(gr.Peers) + for _, res := range group.Resources { + resResp := res.ToAPIResponse() + gr.Resources = append(gr.Resources, *resResp) + } + + gr.ResourcesCount = len(gr.Resources) + return &gr } 
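Note: toGroupResponse and the create/update handlers above call types.Resource.ToAPIResponse and types.Resource.FromAPIRequest, which are introduced elsewhere in this series and are not part of this hunk. The sketch below illustrates the assumed shape of those conversions between the stored resource reference and the generated api.Resource; the ID/Type field names and the string-based ResourceType are assumptions for illustration, not the authoritative implementation.

package types

import "github.com/netbirdio/netbird/management/server/http/api"

// ToAPIResponse converts a stored resource reference into its REST representation (sketch).
func (r *Resource) ToAPIResponse() *api.Resource {
	return &api.Resource{
		Id:   r.ID,
		Type: api.ResourceType(r.Type), // assumes ResourceType has a string underlying type
	}
}

// FromAPIRequest populates the stored resource reference from a REST request object (sketch).
func (r *Resource) FromAPIRequest(req *api.Resource) {
	if req == nil {
		return
	}
	r.ID = req.Id
	r.Type = ResourceType(req.Type)
}
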
diff --git a/management/server/http/handlers/groups/groups_handler_test.go b/management/server/http/handlers/groups/groups_handler_test.go index 089c1a40f..49805ca9b 100644 --- a/management/server/http/handlers/groups/groups_handler_test.go +++ b/management/server/http/handlers/groups/groups_handler_test.go @@ -17,13 +17,13 @@ import ( "golang.org/x/exp/maps" "github.com/netbirdio/netbird/management/server" - nbgroup "github.com/netbirdio/netbird/management/server/group" "github.com/netbirdio/netbird/management/server/http/api" "github.com/netbirdio/netbird/management/server/http/util" "github.com/netbirdio/netbird/management/server/jwtclaims" "github.com/netbirdio/netbird/management/server/mock_server" nbpeer "github.com/netbirdio/netbird/management/server/peer" "github.com/netbirdio/netbird/management/server/status" + "github.com/netbirdio/netbird/management/server/types" ) var TestPeers = map[string]*nbpeer.Peer{ @@ -31,20 +31,20 @@ var TestPeers = map[string]*nbpeer.Peer{ "B": {Key: "B", ID: "peer-B-ID", IP: net.ParseIP("200.200.200.200")}, } -func initGroupTestData(initGroups ...*nbgroup.Group) *handler { +func initGroupTestData(initGroups ...*types.Group) *handler { return &handler{ accountManager: &mock_server.MockAccountManager{ - SaveGroupFunc: func(_ context.Context, accountID, userID string, group *nbgroup.Group) error { + SaveGroupFunc: func(_ context.Context, accountID, userID string, group *types.Group) error { if !strings.HasPrefix(group.ID, "id-") { group.ID = "id-was-set" } return nil }, - GetGroupFunc: func(_ context.Context, _, groupID, _ string) (*nbgroup.Group, error) { - groups := map[string]*nbgroup.Group{ - "id-jwt-group": {ID: "id-jwt-group", Name: "From JWT", Issued: nbgroup.GroupIssuedJWT}, - "id-existed": {ID: "id-existed", Peers: []string{"A", "B"}, Issued: nbgroup.GroupIssuedAPI}, - "id-all": {ID: "id-all", Name: "All", Issued: nbgroup.GroupIssuedAPI}, + GetGroupFunc: func(_ context.Context, _, groupID, _ string) (*types.Group, error) { + groups := map[string]*types.Group{ + "id-jwt-group": {ID: "id-jwt-group", Name: "From JWT", Issued: types.GroupIssuedJWT}, + "id-existed": {ID: "id-existed", Peers: []string{"A", "B"}, Issued: types.GroupIssuedAPI}, + "id-all": {ID: "id-all", Name: "All", Issued: types.GroupIssuedAPI}, } for _, group := range initGroups { @@ -61,9 +61,9 @@ func initGroupTestData(initGroups ...*nbgroup.Group) *handler { GetAccountIDFromTokenFunc: func(_ context.Context, claims jwtclaims.AuthorizationClaims) (string, string, error) { return claims.AccountId, claims.UserId, nil }, - GetGroupByNameFunc: func(ctx context.Context, groupName, _ string) (*nbgroup.Group, error) { + GetGroupByNameFunc: func(ctx context.Context, groupName, _ string) (*types.Group, error) { if groupName == "All" { - return &nbgroup.Group{ID: "id-all", Name: "All", Issued: nbgroup.GroupIssuedAPI}, nil + return &types.Group{ID: "id-all", Name: "All", Issued: types.GroupIssuedAPI}, nil } return nil, fmt.Errorf("unknown group name") @@ -120,7 +120,7 @@ func TestGetGroup(t *testing.T) { }, } - group := &nbgroup.Group{ + group := &types.Group{ ID: "idofthegroup", Name: "Group", } @@ -154,7 +154,7 @@ func TestGetGroup(t *testing.T) { t.Fatalf("I don't know what I expected; %v", err) } - got := &nbgroup.Group{} + got := &types.Group{} if err = json.Unmarshal(content, &got); err != nil { t.Fatalf("Sent content is not in correct json format; %v", err) } diff --git a/management/server/http/handlers/networks/handler.go b/management/server/http/handlers/networks/handler.go new file 
mode 100644 index 000000000..6b36a8fce --- /dev/null +++ b/management/server/http/handlers/networks/handler.go @@ -0,0 +1,321 @@ +package networks + +import ( + "context" + "encoding/json" + "fmt" + "net/http" + + "github.com/gorilla/mux" + + s "github.com/netbirdio/netbird/management/server" + "github.com/netbirdio/netbird/management/server/groups" + "github.com/netbirdio/netbird/management/server/http/api" + "github.com/netbirdio/netbird/management/server/http/configs" + "github.com/netbirdio/netbird/management/server/http/util" + "github.com/netbirdio/netbird/management/server/jwtclaims" + "github.com/netbirdio/netbird/management/server/networks" + "github.com/netbirdio/netbird/management/server/networks/resources" + "github.com/netbirdio/netbird/management/server/networks/routers" + routerTypes "github.com/netbirdio/netbird/management/server/networks/routers/types" + "github.com/netbirdio/netbird/management/server/networks/types" + "github.com/netbirdio/netbird/management/server/status" + nbtypes "github.com/netbirdio/netbird/management/server/types" +) + +// handler is a handler that returns networks of the account +type handler struct { + networksManager networks.Manager + resourceManager resources.Manager + routerManager routers.Manager + accountManager s.AccountManager + + groupsManager groups.Manager + extractFromToken func(ctx context.Context, claims jwtclaims.AuthorizationClaims) (string, string, error) + claimsExtractor *jwtclaims.ClaimsExtractor +} + +func AddEndpoints(networksManager networks.Manager, resourceManager resources.Manager, routerManager routers.Manager, groupsManager groups.Manager, accountManager s.AccountManager, extractFromToken func(ctx context.Context, claims jwtclaims.AuthorizationClaims) (string, string, error), authCfg configs.AuthCfg, router *mux.Router) { + addRouterEndpoints(routerManager, extractFromToken, authCfg, router) + addResourceEndpoints(resourceManager, groupsManager, extractFromToken, authCfg, router) + + networksHandler := newHandler(networksManager, resourceManager, routerManager, groupsManager, accountManager, extractFromToken, authCfg) + router.HandleFunc("/networks", networksHandler.getAllNetworks).Methods("GET", "OPTIONS") + router.HandleFunc("/networks", networksHandler.createNetwork).Methods("POST", "OPTIONS") + router.HandleFunc("/networks/{networkId}", networksHandler.getNetwork).Methods("GET", "OPTIONS") + router.HandleFunc("/networks/{networkId}", networksHandler.updateNetwork).Methods("PUT", "OPTIONS") + router.HandleFunc("/networks/{networkId}", networksHandler.deleteNetwork).Methods("DELETE", "OPTIONS") +} + +func newHandler(networksManager networks.Manager, resourceManager resources.Manager, routerManager routers.Manager, groupsManager groups.Manager, accountManager s.AccountManager, extractFromToken func(ctx context.Context, claims jwtclaims.AuthorizationClaims) (string, string, error), authCfg configs.AuthCfg) *handler { + return &handler{ + networksManager: networksManager, + resourceManager: resourceManager, + routerManager: routerManager, + groupsManager: groupsManager, + accountManager: accountManager, + extractFromToken: extractFromToken, + claimsExtractor: jwtclaims.NewClaimsExtractor( + jwtclaims.WithAudience(authCfg.Audience), + jwtclaims.WithUserIDClaim(authCfg.UserIDClaim), + ), + } +} + +func (h *handler) getAllNetworks(w http.ResponseWriter, r *http.Request) { + claims := h.claimsExtractor.FromRequestContext(r) + accountID, userID, err := h.extractFromToken(r.Context(), claims) + if err != nil { + 
util.WriteError(r.Context(), err, w) + return + } + + networks, err := h.networksManager.GetAllNetworks(r.Context(), accountID, userID) + if err != nil { + util.WriteError(r.Context(), err, w) + return + } + + resourceIDs, err := h.resourceManager.GetAllResourceIDsInAccount(r.Context(), accountID, userID) + if err != nil { + util.WriteError(r.Context(), err, w) + return + } + + groups, err := h.groupsManager.GetAllGroups(r.Context(), accountID, userID) + if err != nil { + util.WriteError(r.Context(), err, w) + return + } + + routers, err := h.routerManager.GetAllRoutersInAccount(r.Context(), accountID, userID) + if err != nil { + util.WriteError(r.Context(), err, w) + return + } + + account, err := h.accountManager.GetAccount(r.Context(), accountID) + if err != nil { + util.WriteError(r.Context(), err, w) + return + } + + util.WriteJSONObject(r.Context(), w, h.generateNetworkResponse(networks, routers, resourceIDs, groups, account)) +} + +func (h *handler) createNetwork(w http.ResponseWriter, r *http.Request) { + claims := h.claimsExtractor.FromRequestContext(r) + accountID, userID, err := h.extractFromToken(r.Context(), claims) + if err != nil { + util.WriteError(r.Context(), err, w) + return + } + + var req api.NetworkRequest + err = json.NewDecoder(r.Body).Decode(&req) + if err != nil { + util.WriteErrorResponse("couldn't parse JSON request", http.StatusBadRequest, w) + return + } + + network := &types.Network{} + network.FromAPIRequest(&req) + + network.AccountID = accountID + network, err = h.networksManager.CreateNetwork(r.Context(), userID, network) + if err != nil { + util.WriteError(r.Context(), err, w) + return + } + + account, err := h.accountManager.GetAccount(r.Context(), accountID) + if err != nil { + util.WriteError(r.Context(), err, w) + return + } + + policyIDs := account.GetPoliciesAppliedInNetwork(network.ID) + + util.WriteJSONObject(r.Context(), w, network.ToAPIResponse([]string{}, []string{}, 0, policyIDs)) +} + +func (h *handler) getNetwork(w http.ResponseWriter, r *http.Request) { + claims := h.claimsExtractor.FromRequestContext(r) + accountID, userID, err := h.extractFromToken(r.Context(), claims) + if err != nil { + util.WriteError(r.Context(), err, w) + return + } + + vars := mux.Vars(r) + networkID := vars["networkId"] + if len(networkID) == 0 { + util.WriteError(r.Context(), status.Errorf(status.InvalidArgument, "invalid network ID"), w) + return + } + + network, err := h.networksManager.GetNetwork(r.Context(), accountID, userID, networkID) + if err != nil { + util.WriteError(r.Context(), err, w) + return + } + + routerIDs, resourceIDs, peerCount, err := h.collectIDsInNetwork(r.Context(), accountID, userID, networkID) + if err != nil { + util.WriteError(r.Context(), err, w) + return + } + + account, err := h.accountManager.GetAccount(r.Context(), accountID) + if err != nil { + util.WriteError(r.Context(), err, w) + return + } + + policyIDs := account.GetPoliciesAppliedInNetwork(networkID) + + util.WriteJSONObject(r.Context(), w, network.ToAPIResponse(routerIDs, resourceIDs, peerCount, policyIDs)) +} + +func (h *handler) updateNetwork(w http.ResponseWriter, r *http.Request) { + claims := h.claimsExtractor.FromRequestContext(r) + accountID, userID, err := h.extractFromToken(r.Context(), claims) + if err != nil { + util.WriteError(r.Context(), err, w) + return + } + + vars := mux.Vars(r) + networkID := vars["networkId"] + if len(networkID) == 0 { + util.WriteError(r.Context(), status.Errorf(status.InvalidArgument, "invalid network ID"), w) + return + } + + var req 
api.NetworkRequest + err = json.NewDecoder(r.Body).Decode(&req) + if err != nil { + util.WriteErrorResponse("couldn't parse JSON request", http.StatusBadRequest, w) + return + } + + network := &types.Network{} + network.FromAPIRequest(&req) + + network.ID = networkID + network.AccountID = accountID + network, err = h.networksManager.UpdateNetwork(r.Context(), userID, network) + if err != nil { + util.WriteError(r.Context(), err, w) + return + } + + routerIDs, resourceIDs, peerCount, err := h.collectIDsInNetwork(r.Context(), accountID, userID, networkID) + if err != nil { + util.WriteError(r.Context(), err, w) + return + } + + account, err := h.accountManager.GetAccount(r.Context(), accountID) + if err != nil { + util.WriteError(r.Context(), err, w) + return + } + + policyIDs := account.GetPoliciesAppliedInNetwork(networkID) + + util.WriteJSONObject(r.Context(), w, network.ToAPIResponse(routerIDs, resourceIDs, peerCount, policyIDs)) +} + +func (h *handler) deleteNetwork(w http.ResponseWriter, r *http.Request) { + claims := h.claimsExtractor.FromRequestContext(r) + accountID, userID, err := h.extractFromToken(r.Context(), claims) + if err != nil { + util.WriteError(r.Context(), err, w) + return + } + + vars := mux.Vars(r) + networkID := vars["networkId"] + if len(networkID) == 0 { + util.WriteError(r.Context(), status.Errorf(status.InvalidArgument, "invalid network ID"), w) + return + } + + err = h.networksManager.DeleteNetwork(r.Context(), accountID, userID, networkID) + if err != nil { + util.WriteError(r.Context(), err, w) + return + } + + util.WriteJSONObject(r.Context(), w, util.EmptyObject{}) +} + +func (h *handler) collectIDsInNetwork(ctx context.Context, accountID, userID, networkID string) ([]string, []string, int, error) { + resources, err := h.resourceManager.GetAllResourcesInNetwork(ctx, accountID, userID, networkID) + if err != nil { + return nil, nil, 0, fmt.Errorf("failed to get resources in network: %w", err) + } + + var resourceIDs []string + for _, resource := range resources { + resourceIDs = append(resourceIDs, resource.ID) + } + + routers, err := h.routerManager.GetAllRoutersInNetwork(ctx, accountID, userID, networkID) + if err != nil { + return nil, nil, 0, fmt.Errorf("failed to get routers in network: %w", err) + } + + groups, err := h.groupsManager.GetAllGroups(ctx, accountID, userID) + if err != nil { + return nil, nil, 0, fmt.Errorf("failed to get groups: %w", err) + } + + peerCounter := 0 + var routerIDs []string + for _, router := range routers { + routerIDs = append(routerIDs, router.ID) + if router.Peer != "" { + peerCounter++ + } + if len(router.PeerGroups) > 0 { + for _, groupID := range router.PeerGroups { + peerCounter += len(groups[groupID].Peers) + } + } + } + + return routerIDs, resourceIDs, peerCounter, nil +} + +func (h *handler) generateNetworkResponse(networks []*types.Network, routers map[string][]*routerTypes.NetworkRouter, resourceIDs map[string][]string, groups map[string]*nbtypes.Group, account *nbtypes.Account) []*api.Network { + var networkResponse []*api.Network + for _, network := range networks { + routerIDs, peerCounter := getRouterIDs(network, routers, groups) + policyIDs := account.GetPoliciesAppliedInNetwork(network.ID) + networkResponse = append(networkResponse, network.ToAPIResponse(routerIDs, resourceIDs[network.ID], peerCounter, policyIDs)) + } + return networkResponse +} + +func getRouterIDs(network *types.Network, routers map[string][]*routerTypes.NetworkRouter, groups map[string]*nbtypes.Group) ([]string, int) { + routerIDs := 
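Editorial sketch: the per-network peer count above is computed twice, in collectIDsInNetwork for a single network and in getRouterIDs for the list endpoint; both add one for a router's direct Peer plus the member count of each referenced peer group, but only getRouterIDs guards against group IDs missing from the groups map. The standalone sketch below restates the guarded variant with stand-in types instead of the real NetBird ones.

package main

import "fmt"

// Stand-ins for the router and group shapes used above (illustrative only).
type router struct {
	ID         string
	Peer       string   // directly assigned routing peer, if any
	PeerGroups []string // groups whose members act as routing peers
}

type group struct {
	Peers []string
}

// countRoutingPeers mirrors the counting in getRouterIDs: one for a direct
// peer, plus the members of each referenced group that actually exists.
func countRoutingPeers(routers []router, groups map[string]group) (ids []string, peers int) {
	for _, r := range routers {
		ids = append(ids, r.ID)
		if r.Peer != "" {
			peers++
		}
		for _, gid := range r.PeerGroups {
			g, ok := groups[gid]
			if !ok {
				continue // unknown group: skip instead of panicking
			}
			peers += len(g.Peers)
		}
	}
	return ids, peers
}

func main() {
	ids, n := countRoutingPeers(
		[]router{{ID: "r1", Peer: "p1"}, {ID: "r2", PeerGroups: []string{"g1", "missing"}}},
		map[string]group{"g1": {Peers: []string{"p2", "p3"}}},
	)
	fmt.Println(ids, n) // [r1 r2] 3
}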
[]string{} + peerCounter := 0 + for _, router := range routers[network.ID] { + routerIDs = append(routerIDs, router.ID) + if router.Peer != "" { + peerCounter++ + } + if len(router.PeerGroups) > 0 { + for _, groupID := range router.PeerGroups { + group, ok := groups[groupID] + if !ok { + continue + } + peerCounter += len(group.Peers) + } + } + } + return routerIDs, peerCounter +} diff --git a/management/server/http/handlers/networks/resources_handler.go b/management/server/http/handlers/networks/resources_handler.go new file mode 100644 index 000000000..a0dc9a10d --- /dev/null +++ b/management/server/http/handlers/networks/resources_handler.go @@ -0,0 +1,222 @@ +package networks + +import ( + "context" + "encoding/json" + "net/http" + + "github.com/gorilla/mux" + + "github.com/netbirdio/netbird/management/server/groups" + "github.com/netbirdio/netbird/management/server/http/api" + "github.com/netbirdio/netbird/management/server/http/configs" + "github.com/netbirdio/netbird/management/server/http/util" + "github.com/netbirdio/netbird/management/server/jwtclaims" + "github.com/netbirdio/netbird/management/server/networks/resources" + "github.com/netbirdio/netbird/management/server/networks/resources/types" +) + +type resourceHandler struct { + resourceManager resources.Manager + groupsManager groups.Manager + extractFromToken func(ctx context.Context, claims jwtclaims.AuthorizationClaims) (string, string, error) + claimsExtractor *jwtclaims.ClaimsExtractor +} + +func addResourceEndpoints(resourcesManager resources.Manager, groupsManager groups.Manager, extractFromToken func(ctx context.Context, claims jwtclaims.AuthorizationClaims) (string, string, error), authCfg configs.AuthCfg, router *mux.Router) { + resourceHandler := newResourceHandler(resourcesManager, groupsManager, extractFromToken, authCfg) + router.HandleFunc("/networks/resources", resourceHandler.getAllResourcesInAccount).Methods("GET", "OPTIONS") + router.HandleFunc("/networks/{networkId}/resources", resourceHandler.getAllResourcesInNetwork).Methods("GET", "OPTIONS") + router.HandleFunc("/networks/{networkId}/resources", resourceHandler.createResource).Methods("POST", "OPTIONS") + router.HandleFunc("/networks/{networkId}/resources/{resourceId}", resourceHandler.getResource).Methods("GET", "OPTIONS") + router.HandleFunc("/networks/{networkId}/resources/{resourceId}", resourceHandler.updateResource).Methods("PUT", "OPTIONS") + router.HandleFunc("/networks/{networkId}/resources/{resourceId}", resourceHandler.deleteResource).Methods("DELETE", "OPTIONS") +} + +func newResourceHandler(resourceManager resources.Manager, groupsManager groups.Manager, extractFromToken func(ctx context.Context, claims jwtclaims.AuthorizationClaims) (string, string, error), authCfg configs.AuthCfg) *resourceHandler { + return &resourceHandler{ + resourceManager: resourceManager, + groupsManager: groupsManager, + extractFromToken: extractFromToken, + claimsExtractor: jwtclaims.NewClaimsExtractor( + jwtclaims.WithAudience(authCfg.Audience), + jwtclaims.WithUserIDClaim(authCfg.UserIDClaim), + ), + } +} + +func (h *resourceHandler) getAllResourcesInNetwork(w http.ResponseWriter, r *http.Request) { + claims := h.claimsExtractor.FromRequestContext(r) + accountID, userID, err := h.extractFromToken(r.Context(), claims) + if err != nil { + util.WriteError(r.Context(), err, w) + return + } + + networkID := mux.Vars(r)["networkId"] + resources, err := h.resourceManager.GetAllResourcesInNetwork(r.Context(), accountID, userID, networkID) + if err != nil { + 
util.WriteError(r.Context(), err, w) + return + } + + grps, err := h.groupsManager.GetAllGroups(r.Context(), accountID, userID) + if err != nil { + util.WriteError(r.Context(), err, w) + return + } + + var resourcesResponse []*api.NetworkResource + for _, resource := range resources { + groupMinimumInfo := groups.ToGroupsInfo(grps, resource.ID) + resourcesResponse = append(resourcesResponse, resource.ToAPIResponse(groupMinimumInfo)) + } + + util.WriteJSONObject(r.Context(), w, resourcesResponse) +} +func (h *resourceHandler) getAllResourcesInAccount(w http.ResponseWriter, r *http.Request) { + claims := h.claimsExtractor.FromRequestContext(r) + accountID, userID, err := h.extractFromToken(r.Context(), claims) + if err != nil { + util.WriteError(r.Context(), err, w) + return + } + + resources, err := h.resourceManager.GetAllResourcesInAccount(r.Context(), accountID, userID) + if err != nil { + util.WriteError(r.Context(), err, w) + return + } + + grps, err := h.groupsManager.GetAllGroups(r.Context(), accountID, userID) + if err != nil { + util.WriteError(r.Context(), err, w) + return + } + + var resourcesResponse []*api.NetworkResource + for _, resource := range resources { + groupMinimumInfo := groups.ToGroupsInfo(grps, resource.ID) + resourcesResponse = append(resourcesResponse, resource.ToAPIResponse(groupMinimumInfo)) + } + + util.WriteJSONObject(r.Context(), w, resourcesResponse) +} + +func (h *resourceHandler) createResource(w http.ResponseWriter, r *http.Request) { + claims := h.claimsExtractor.FromRequestContext(r) + accountID, userID, err := h.extractFromToken(r.Context(), claims) + if err != nil { + util.WriteError(r.Context(), err, w) + return + } + + var req api.NetworkResourceRequest + err = json.NewDecoder(r.Body).Decode(&req) + if err != nil { + util.WriteErrorResponse("couldn't parse JSON request", http.StatusBadRequest, w) + return + } + + resource := &types.NetworkResource{} + resource.FromAPIRequest(&req) + + resource.NetworkID = mux.Vars(r)["networkId"] + resource.AccountID = accountID + resource, err = h.resourceManager.CreateResource(r.Context(), userID, resource) + if err != nil { + util.WriteError(r.Context(), err, w) + return + } + + grps, err := h.groupsManager.GetAllGroups(r.Context(), accountID, userID) + if err != nil { + util.WriteError(r.Context(), err, w) + return + } + + groupMinimumInfo := groups.ToGroupsInfo(grps, resource.ID) + util.WriteJSONObject(r.Context(), w, resource.ToAPIResponse(groupMinimumInfo)) +} + +func (h *resourceHandler) getResource(w http.ResponseWriter, r *http.Request) { + claims := h.claimsExtractor.FromRequestContext(r) + accountID, userID, err := h.extractFromToken(r.Context(), claims) + if err != nil { + util.WriteError(r.Context(), err, w) + return + } + + networkID := mux.Vars(r)["networkId"] + resourceID := mux.Vars(r)["resourceId"] + resource, err := h.resourceManager.GetResource(r.Context(), accountID, userID, networkID, resourceID) + if err != nil { + util.WriteError(r.Context(), err, w) + return + } + + grps, err := h.groupsManager.GetAllGroups(r.Context(), accountID, userID) + if err != nil { + util.WriteError(r.Context(), err, w) + return + } + + groupMinimumInfo := groups.ToGroupsInfo(grps, resource.ID) + util.WriteJSONObject(r.Context(), w, resource.ToAPIResponse(groupMinimumInfo)) +} + +func (h *resourceHandler) updateResource(w http.ResponseWriter, r *http.Request) { + claims := h.claimsExtractor.FromRequestContext(r) + accountID, userID, err := h.extractFromToken(r.Context(), claims) + if err != nil { + 
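Editorial sketch: createResource and updateResource above share one shape: decode the JSON body, answer 400 on malformed input, then overwrite the identifiers from the URL path and the authenticated account before calling the manager. A minimal version of that decode-then-override step, using a hypothetical payload type rather than the real api.NetworkResourceRequest:

package main

import (
	"encoding/json"
	"fmt"
	"net/http"

	"github.com/gorilla/mux"
)

type resourceRequest struct { // hypothetical stand-in for the API request type
	Name    string `json:"name"`
	Address string `json:"address"`
}

func updateResource(w http.ResponseWriter, r *http.Request) {
	var req resourceRequest
	if err := json.NewDecoder(r.Body).Decode(&req); err != nil {
		// Mirrors util.WriteErrorResponse("couldn't parse JSON request", 400, w) above.
		http.Error(w, "couldn't parse JSON request", http.StatusBadRequest)
		return
	}

	// Identifiers come from the route and the auth context, never the body,
	// which is why the handlers above overwrite ID/NetworkID/AccountID after decoding.
	vars := mux.Vars(r)
	fmt.Fprintf(w, "update %s in network %s to name %q\n", vars["resourceId"], vars["networkId"], req.Name)
}

func main() {
	r := mux.NewRouter()
	r.HandleFunc("/networks/{networkId}/resources/{resourceId}", updateResource).Methods("PUT", "OPTIONS")
	_ = http.ListenAndServe(":8080", r)
}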
util.WriteError(r.Context(), err, w) + return + } + + var req api.NetworkResourceRequest + err = json.NewDecoder(r.Body).Decode(&req) + if err != nil { + util.WriteErrorResponse("couldn't parse JSON request", http.StatusBadRequest, w) + return + } + + resource := &types.NetworkResource{} + resource.FromAPIRequest(&req) + + resource.ID = mux.Vars(r)["resourceId"] + resource.NetworkID = mux.Vars(r)["networkId"] + resource.AccountID = accountID + resource, err = h.resourceManager.UpdateResource(r.Context(), userID, resource) + if err != nil { + util.WriteError(r.Context(), err, w) + return + } + + grps, err := h.groupsManager.GetAllGroups(r.Context(), accountID, userID) + if err != nil { + util.WriteError(r.Context(), err, w) + return + } + + groupMinimumInfo := groups.ToGroupsInfo(grps, resource.ID) + util.WriteJSONObject(r.Context(), w, resource.ToAPIResponse(groupMinimumInfo)) +} + +func (h *resourceHandler) deleteResource(w http.ResponseWriter, r *http.Request) { + claims := h.claimsExtractor.FromRequestContext(r) + accountID, userID, err := h.extractFromToken(r.Context(), claims) + if err != nil { + util.WriteError(r.Context(), err, w) + return + } + + networkID := mux.Vars(r)["networkId"] + resourceID := mux.Vars(r)["resourceId"] + err = h.resourceManager.DeleteResource(r.Context(), accountID, userID, networkID, resourceID) + if err != nil { + util.WriteError(r.Context(), err, w) + return + } + + util.WriteJSONObject(r.Context(), w, util.EmptyObject{}) +} diff --git a/management/server/http/handlers/networks/routers_handler.go b/management/server/http/handlers/networks/routers_handler.go new file mode 100644 index 000000000..2cf39a132 --- /dev/null +++ b/management/server/http/handlers/networks/routers_handler.go @@ -0,0 +1,165 @@ +package networks + +import ( + "context" + "encoding/json" + "net/http" + + "github.com/gorilla/mux" + + "github.com/netbirdio/netbird/management/server/http/api" + "github.com/netbirdio/netbird/management/server/http/configs" + "github.com/netbirdio/netbird/management/server/http/util" + "github.com/netbirdio/netbird/management/server/jwtclaims" + "github.com/netbirdio/netbird/management/server/networks/routers" + "github.com/netbirdio/netbird/management/server/networks/routers/types" +) + +type routersHandler struct { + routersManager routers.Manager + extractFromToken func(ctx context.Context, claims jwtclaims.AuthorizationClaims) (string, string, error) + claimsExtractor *jwtclaims.ClaimsExtractor +} + +func addRouterEndpoints(routersManager routers.Manager, extractFromToken func(ctx context.Context, claims jwtclaims.AuthorizationClaims) (string, string, error), authCfg configs.AuthCfg, router *mux.Router) { + routersHandler := newRoutersHandler(routersManager, extractFromToken, authCfg) + router.HandleFunc("/networks/{networkId}/routers", routersHandler.getAllRouters).Methods("GET", "OPTIONS") + router.HandleFunc("/networks/{networkId}/routers", routersHandler.createRouter).Methods("POST", "OPTIONS") + router.HandleFunc("/networks/{networkId}/routers/{routerId}", routersHandler.getRouter).Methods("GET", "OPTIONS") + router.HandleFunc("/networks/{networkId}/routers/{routerId}", routersHandler.updateRouter).Methods("PUT", "OPTIONS") + router.HandleFunc("/networks/{networkId}/routers/{routerId}", routersHandler.deleteRouter).Methods("DELETE", "OPTIONS") +} + +func newRoutersHandler(routersManager routers.Manager, extractFromToken func(ctx context.Context, claims jwtclaims.AuthorizationClaims) (string, string, error), authCfg configs.AuthCfg) *routersHandler 
{ + return &routersHandler{ + routersManager: routersManager, + extractFromToken: extractFromToken, + claimsExtractor: jwtclaims.NewClaimsExtractor( + jwtclaims.WithAudience(authCfg.Audience), + jwtclaims.WithUserIDClaim(authCfg.UserIDClaim), + ), + } +} + +func (h *routersHandler) getAllRouters(w http.ResponseWriter, r *http.Request) { + claims := h.claimsExtractor.FromRequestContext(r) + accountID, userID, err := h.extractFromToken(r.Context(), claims) + if err != nil { + util.WriteError(r.Context(), err, w) + return + } + + networkID := mux.Vars(r)["networkId"] + routers, err := h.routersManager.GetAllRoutersInNetwork(r.Context(), accountID, userID, networkID) + if err != nil { + util.WriteError(r.Context(), err, w) + return + } + + var routersResponse []*api.NetworkRouter + for _, router := range routers { + routersResponse = append(routersResponse, router.ToAPIResponse()) + } + + util.WriteJSONObject(r.Context(), w, routersResponse) +} + +func (h *routersHandler) createRouter(w http.ResponseWriter, r *http.Request) { + claims := h.claimsExtractor.FromRequestContext(r) + accountID, userID, err := h.extractFromToken(r.Context(), claims) + if err != nil { + util.WriteError(r.Context(), err, w) + return + } + + networkID := mux.Vars(r)["networkId"] + var req api.NetworkRouterRequest + err = json.NewDecoder(r.Body).Decode(&req) + if err != nil { + util.WriteErrorResponse("couldn't parse JSON request", http.StatusBadRequest, w) + return + } + + router := &types.NetworkRouter{} + router.FromAPIRequest(&req) + + router.NetworkID = networkID + router.AccountID = accountID + + router, err = h.routersManager.CreateRouter(r.Context(), userID, router) + if err != nil { + util.WriteError(r.Context(), err, w) + return + } + + util.WriteJSONObject(r.Context(), w, router.ToAPIResponse()) +} + +func (h *routersHandler) getRouter(w http.ResponseWriter, r *http.Request) { + claims := h.claimsExtractor.FromRequestContext(r) + accountID, userID, err := h.extractFromToken(r.Context(), claims) + if err != nil { + util.WriteError(r.Context(), err, w) + return + } + + routerID := mux.Vars(r)["routerId"] + networkID := mux.Vars(r)["networkId"] + router, err := h.routersManager.GetRouter(r.Context(), accountID, userID, networkID, routerID) + if err != nil { + util.WriteError(r.Context(), err, w) + return + } + + util.WriteJSONObject(r.Context(), w, router.ToAPIResponse()) +} + +func (h *routersHandler) updateRouter(w http.ResponseWriter, r *http.Request) { + claims := h.claimsExtractor.FromRequestContext(r) + accountID, userID, err := h.extractFromToken(r.Context(), claims) + if err != nil { + util.WriteError(r.Context(), err, w) + return + } + + var req api.NetworkRouterRequest + err = json.NewDecoder(r.Body).Decode(&req) + if err != nil { + util.WriteErrorResponse("couldn't parse JSON request", http.StatusBadRequest, w) + return + } + + router := &types.NetworkRouter{} + router.FromAPIRequest(&req) + + router.NetworkID = mux.Vars(r)["networkId"] + router.ID = mux.Vars(r)["routerId"] + router.AccountID = accountID + + router, err = h.routersManager.UpdateRouter(r.Context(), userID, router) + if err != nil { + util.WriteError(r.Context(), err, w) + return + } + + util.WriteJSONObject(r.Context(), w, router.ToAPIResponse()) +} + +func (h *routersHandler) deleteRouter(w http.ResponseWriter, r *http.Request) { + claims := h.claimsExtractor.FromRequestContext(r) + accountID, userID, err := h.extractFromToken(r.Context(), claims) + if err != nil { + util.WriteError(r.Context(), err, w) + return + } + + routerID := 
mux.Vars(r)["routerId"] + networkID := mux.Vars(r)["networkId"] + err = h.routersManager.DeleteRouter(r.Context(), accountID, userID, networkID, routerID) + if err != nil { + util.WriteError(r.Context(), err, w) + return + } + + util.WriteJSONObject(r.Context(), w, struct{}{}) +} diff --git a/management/server/http/handlers/peers/peers_handler.go b/management/server/http/handlers/peers/peers_handler.go index c53cbc038..5bc616e47 100644 --- a/management/server/http/handlers/peers/peers_handler.go +++ b/management/server/http/handlers/peers/peers_handler.go @@ -10,13 +10,14 @@ import ( log "github.com/sirupsen/logrus" "github.com/netbirdio/netbird/management/server" - nbgroup "github.com/netbirdio/netbird/management/server/group" + "github.com/netbirdio/netbird/management/server/groups" "github.com/netbirdio/netbird/management/server/http/api" "github.com/netbirdio/netbird/management/server/http/configs" "github.com/netbirdio/netbird/management/server/http/util" "github.com/netbirdio/netbird/management/server/jwtclaims" nbpeer "github.com/netbirdio/netbird/management/server/peer" "github.com/netbirdio/netbird/management/server/status" + "github.com/netbirdio/netbird/management/server/types" ) // Handler is a handler that returns peers of the account @@ -57,7 +58,7 @@ func (h *Handler) checkPeerStatus(peer *nbpeer.Peer) (*nbpeer.Peer, error) { return peerToReturn, nil } -func (h *Handler) getPeer(ctx context.Context, account *server.Account, peerID, userID string, w http.ResponseWriter) { +func (h *Handler) getPeer(ctx context.Context, account *types.Account, peerID, userID string, w http.ResponseWriter) { peer, err := h.accountManager.GetPeer(ctx, account.Id, peerID, userID) if err != nil { util.WriteError(ctx, err, w) @@ -71,7 +72,7 @@ func (h *Handler) getPeer(ctx context.Context, account *server.Account, peerID, } dnsDomain := h.accountManager.GetDNSDomain() - groupsInfo := toGroupsInfo(account.Groups, peer.ID) + groupsInfo := groups.ToGroupsInfo(account.Groups, peer.ID) validPeers, err := h.accountManager.GetValidatedPeers(account) if err != nil { @@ -84,7 +85,7 @@ func (h *Handler) getPeer(ctx context.Context, account *server.Account, peerID, util.WriteJSONObject(ctx, w, toSinglePeerResponse(peerToReturn, groupsInfo, dnsDomain, valid)) } -func (h *Handler) updatePeer(ctx context.Context, account *server.Account, userID, peerID string, w http.ResponseWriter, r *http.Request) { +func (h *Handler) updatePeer(ctx context.Context, account *types.Account, userID, peerID string, w http.ResponseWriter, r *http.Request) { req := &api.PeerRequest{} err := json.NewDecoder(r.Body).Decode(&req) if err != nil { @@ -115,7 +116,7 @@ func (h *Handler) updatePeer(ctx context.Context, account *server.Account, userI } dnsDomain := h.accountManager.GetDNSDomain() - groupMinimumInfo := toGroupsInfo(account.Groups, peer.ID) + groupMinimumInfo := groups.ToGroupsInfo(account.Groups, peer.ID) validPeers, err := h.accountManager.GetValidatedPeers(account) if err != nil { @@ -199,9 +200,9 @@ func (h *Handler) GetAllPeers(w http.ResponseWriter, r *http.Request) { return } - groupsMap := map[string]*nbgroup.Group{} - groups, _ := h.accountManager.GetAllGroups(r.Context(), accountID, userID) - for _, group := range groups { + groupsMap := map[string]*types.Group{} + grps, _ := h.accountManager.GetAllGroups(r.Context(), accountID, userID) + for _, group := range grps { groupsMap[group.ID] = group } @@ -212,7 +213,7 @@ func (h *Handler) GetAllPeers(w http.ResponseWriter, r *http.Request) { util.WriteError(r.Context(), 
err, w) return } - groupMinimumInfo := toGroupsInfo(groupsMap, peer.ID) + groupMinimumInfo := groups.ToGroupsInfo(groupsMap, peer.ID) respBody = append(respBody, toPeerListItemResponse(peerToReturn, groupMinimumInfo, dnsDomain, 0)) } @@ -290,12 +291,12 @@ func (h *Handler) GetAccessiblePeers(w http.ResponseWriter, r *http.Request) { } customZone := account.GetPeersCustomZone(r.Context(), h.accountManager.GetDNSDomain()) - netMap := account.GetPeerNetworkMap(r.Context(), peerID, customZone, validPeers, nil) + netMap := account.GetPeerNetworkMap(r.Context(), peerID, customZone, validPeers, account.GetResourcePoliciesMap(), account.GetResourceRoutersMap(), nil) util.WriteJSONObject(r.Context(), w, toAccessiblePeers(netMap, dnsDomain)) } -func toAccessiblePeers(netMap *server.NetworkMap, dnsDomain string) []api.AccessiblePeer { +func toAccessiblePeers(netMap *types.NetworkMap, dnsDomain string) []api.AccessiblePeer { accessiblePeers := make([]api.AccessiblePeer, 0, len(netMap.Peers)+len(netMap.OfflinePeers)) for _, p := range netMap.Peers { accessiblePeers = append(accessiblePeers, peerToAccessiblePeer(p, dnsDomain)) @@ -324,30 +325,6 @@ func peerToAccessiblePeer(peer *nbpeer.Peer, dnsDomain string) api.AccessiblePee } } -func toGroupsInfo(groups map[string]*nbgroup.Group, peerID string) []api.GroupMinimum { - groupsInfo := []api.GroupMinimum{} - groupsChecked := make(map[string]struct{}) - for _, group := range groups { - _, ok := groupsChecked[group.ID] - if ok { - continue - } - groupsChecked[group.ID] = struct{}{} - for _, pk := range group.Peers { - if pk == peerID { - info := api.GroupMinimum{ - Id: group.ID, - Name: group.Name, - PeersCount: len(group.Peers), - } - groupsInfo = append(groupsInfo, info) - break - } - } - } - return groupsInfo -} - func toSinglePeerResponse(peer *nbpeer.Peer, groupsInfo []api.GroupMinimum, dnsDomain string, approved bool) *api.Peer { osVersion := peer.Meta.OSVersion if osVersion == "" { diff --git a/management/server/http/handlers/peers/peers_handler_test.go b/management/server/http/handlers/peers/peers_handler_test.go index 3e3e39deb..83abc1c40 100644 --- a/management/server/http/handlers/peers/peers_handler_test.go +++ b/management/server/http/handlers/peers/peers_handler_test.go @@ -15,11 +15,10 @@ import ( "github.com/gorilla/mux" "golang.org/x/exp/maps" - "github.com/netbirdio/netbird/management/server" - nbgroup "github.com/netbirdio/netbird/management/server/group" "github.com/netbirdio/netbird/management/server/http/api" "github.com/netbirdio/netbird/management/server/jwtclaims" nbpeer "github.com/netbirdio/netbird/management/server/peer" + "github.com/netbirdio/netbird/management/server/types" "github.com/stretchr/testify/assert" @@ -73,18 +72,18 @@ func initTestMetaData(peers ...*nbpeer.Peer) *Handler { GetAccountIDFromTokenFunc: func(_ context.Context, claims jwtclaims.AuthorizationClaims) (string, string, error) { return claims.AccountId, claims.UserId, nil }, - GetAccountByIDFunc: func(ctx context.Context, accountID string, userID string) (*server.Account, error) { + GetAccountByIDFunc: func(ctx context.Context, accountID string, userID string) (*types.Account, error) { peersMap := make(map[string]*nbpeer.Peer) for _, peer := range peers { peersMap[peer.ID] = peer.Copy() } - policy := &server.Policy{ + policy := &types.Policy{ ID: "policy", AccountID: accountID, Name: "policy", Enabled: true, - Rules: []*server.PolicyRule{ + Rules: []*types.PolicyRule{ { ID: "rule", Name: "rule", @@ -99,19 +98,19 @@ func initTestMetaData(peers 
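Editorial sketch: the local toGroupsInfo helper deleted above is replaced by the shared groups.ToGroupsInfo. Judging only from the removed code, the shared helper is expected to behave roughly like the function below: map a peer (or resource) ID to the minimal summary of every group that contains it, visiting each group once. The types here are stand-ins for api.GroupMinimum and types.Group.

package main

import "fmt"

type groupMinimum struct { // stand-in for api.GroupMinimum
	ID         string
	Name       string
	PeersCount int
}

type grp struct { // stand-in for types.Group
	ID    string
	Name  string
	Peers []string
}

// groupsInfoFor mirrors the deleted toGroupsInfo: for each group listing the
// given ID among its peers, emit a minimal summary, deduplicating by group ID.
func groupsInfoFor(groups map[string]*grp, id string) []groupMinimum {
	out := []groupMinimum{}
	seen := map[string]struct{}{}
	for _, g := range groups {
		if _, ok := seen[g.ID]; ok {
			continue
		}
		seen[g.ID] = struct{}{}
		for _, pk := range g.Peers {
			if pk == id {
				out = append(out, groupMinimum{ID: g.ID, Name: g.Name, PeersCount: len(g.Peers)})
				break
			}
		}
	}
	return out
}

func main() {
	gs := map[string]*grp{"g1": {ID: "g1", Name: "All", Peers: []string{"peer-A", "peer-B"}}}
	fmt.Println(groupsInfoFor(gs, "peer-A")) // [{g1 All 2}]
}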
...*nbpeer.Peer) *Handler { }, } - srvUser := server.NewRegularUser(serviceUser) + srvUser := types.NewRegularUser(serviceUser) srvUser.IsServiceUser = true - account := &server.Account{ + account := &types.Account{ Id: accountID, Domain: "hotmail.com", Peers: peersMap, - Users: map[string]*server.User{ - adminUser: server.NewAdminUser(adminUser), - regularUser: server.NewRegularUser(regularUser), + Users: map[string]*types.User{ + adminUser: types.NewAdminUser(adminUser), + regularUser: types.NewRegularUser(regularUser), serviceUser: srvUser, }, - Groups: map[string]*nbgroup.Group{ + Groups: map[string]*types.Group{ "group1": { ID: "group1", AccountID: accountID, @@ -120,12 +119,12 @@ func initTestMetaData(peers ...*nbpeer.Peer) *Handler { Peers: maps.Keys(peersMap), }, }, - Settings: &server.Settings{ + Settings: &types.Settings{ PeerLoginExpirationEnabled: true, PeerLoginExpiration: time.Hour, }, - Policies: []*server.Policy{policy}, - Network: &server.Network{ + Policies: []*types.Policy{policy}, + Network: &types.Network{ Identifier: "ciclqisab2ss43jdn8q0", Net: net.IPNet{ IP: net.ParseIP("100.67.0.0"), diff --git a/management/server/http/handlers/policies/geolocation_handler_test.go b/management/server/http/handlers/policies/geolocation_handler_test.go index 002b914ef..fc5839baa 100644 --- a/management/server/http/handlers/policies/geolocation_handler_test.go +++ b/management/server/http/handlers/policies/geolocation_handler_test.go @@ -13,11 +13,11 @@ import ( "github.com/gorilla/mux" "github.com/stretchr/testify/assert" - "github.com/netbirdio/netbird/management/server" "github.com/netbirdio/netbird/management/server/geolocation" "github.com/netbirdio/netbird/management/server/http/api" "github.com/netbirdio/netbird/management/server/jwtclaims" "github.com/netbirdio/netbird/management/server/mock_server" + "github.com/netbirdio/netbird/management/server/types" "github.com/netbirdio/netbird/util" ) @@ -46,8 +46,8 @@ func initGeolocationTestData(t *testing.T) *geolocationsHandler { GetAccountIDFromTokenFunc: func(_ context.Context, claims jwtclaims.AuthorizationClaims) (string, string, error) { return claims.AccountId, claims.UserId, nil }, - GetUserByIDFunc: func(ctx context.Context, id string) (*server.User, error) { - return server.NewAdminUser(id), nil + GetUserByIDFunc: func(ctx context.Context, id string) (*types.User, error) { + return types.NewAdminUser(id), nil }, }, geolocationManager: geo, diff --git a/management/server/http/handlers/policies/policies_handler.go b/management/server/http/handlers/policies/policies_handler.go index a47f2e620..d538d07db 100644 --- a/management/server/http/handlers/policies/policies_handler.go +++ b/management/server/http/handlers/policies/policies_handler.go @@ -9,12 +9,12 @@ import ( "github.com/netbirdio/netbird/management/server" "github.com/netbirdio/netbird/management/server/geolocation" - nbgroup "github.com/netbirdio/netbird/management/server/group" "github.com/netbirdio/netbird/management/server/http/api" "github.com/netbirdio/netbird/management/server/http/configs" "github.com/netbirdio/netbird/management/server/http/util" "github.com/netbirdio/netbird/management/server/jwtclaims" "github.com/netbirdio/netbird/management/server/status" + "github.com/netbirdio/netbird/management/server/types" ) // handler is a handler that returns policy of the account @@ -133,7 +133,7 @@ func (h *handler) savePolicy(w http.ResponseWriter, r *http.Request, accountID s return } - policy := &server.Policy{ + policy := &types.Policy{ ID: policyID, 
AccountID: accountID, Name: req.Name, @@ -146,15 +146,56 @@ func (h *handler) savePolicy(w http.ResponseWriter, r *http.Request, accountID s ruleID = *rule.Id } - pr := server.PolicyRule{ + hasSources := rule.Sources != nil + hasSourceResource := rule.SourceResource != nil + + hasDestinations := rule.Destinations != nil + hasDestinationResource := rule.DestinationResource != nil + + if hasSources && hasSourceResource { + util.WriteError(r.Context(), status.Errorf(status.InvalidArgument, "specify either sources or source resources, not both"), w) + return + } + + if hasDestinations && hasDestinationResource { + util.WriteError(r.Context(), status.Errorf(status.InvalidArgument, "specify either destinations or destination resources, not both"), w) + return + } + + if !(hasSources || hasSourceResource) || !(hasDestinations || hasDestinationResource) { + util.WriteError(r.Context(), status.Errorf(status.InvalidArgument, "specify either sources or source resources and destinations or destination resources"), w) + return + } + + pr := types.PolicyRule{ ID: ruleID, PolicyID: policyID, Name: rule.Name, - Destinations: rule.Destinations, - Sources: rule.Sources, Bidirectional: rule.Bidirectional, } + if hasSources { + pr.Sources = *rule.Sources + } + + if hasSourceResource { + // TODO: validate the resource id and type + sourceResource := &types.Resource{} + sourceResource.FromAPIRequest(rule.SourceResource) + pr.SourceResource = *sourceResource + } + + if hasDestinations { + pr.Destinations = *rule.Destinations + } + + if hasDestinationResource { + // TODO: validate the resource id and type + destinationResource := &types.Resource{} + destinationResource.FromAPIRequest(rule.DestinationResource) + pr.DestinationResource = *destinationResource + } + pr.Enabled = rule.Enabled if rule.Description != nil { pr.Description = *rule.Description @@ -162,9 +203,9 @@ func (h *handler) savePolicy(w http.ResponseWriter, r *http.Request, accountID s switch rule.Action { case api.PolicyRuleUpdateActionAccept: - pr.Action = server.PolicyTrafficActionAccept + pr.Action = types.PolicyTrafficActionAccept case api.PolicyRuleUpdateActionDrop: - pr.Action = server.PolicyTrafficActionDrop + pr.Action = types.PolicyTrafficActionDrop default: util.WriteError(r.Context(), status.Errorf(status.InvalidArgument, "unknown action type"), w) return @@ -172,13 +213,13 @@ func (h *handler) savePolicy(w http.ResponseWriter, r *http.Request, accountID s switch rule.Protocol { case api.PolicyRuleUpdateProtocolAll: - pr.Protocol = server.PolicyRuleProtocolALL + pr.Protocol = types.PolicyRuleProtocolALL case api.PolicyRuleUpdateProtocolTcp: - pr.Protocol = server.PolicyRuleProtocolTCP + pr.Protocol = types.PolicyRuleProtocolTCP case api.PolicyRuleUpdateProtocolUdp: - pr.Protocol = server.PolicyRuleProtocolUDP + pr.Protocol = types.PolicyRuleProtocolUDP case api.PolicyRuleUpdateProtocolIcmp: - pr.Protocol = server.PolicyRuleProtocolICMP + pr.Protocol = types.PolicyRuleProtocolICMP default: util.WriteError(r.Context(), status.Errorf(status.InvalidArgument, "unknown protocol type: %v", rule.Protocol), w) return @@ -205,7 +246,7 @@ func (h *handler) savePolicy(w http.ResponseWriter, r *http.Request, accountID s util.WriteError(r.Context(), status.Errorf(status.InvalidArgument, "valid port value is in 1..65535 range"), w) return } - pr.PortRanges = append(pr.PortRanges, server.RulePortRange{ + pr.PortRanges = append(pr.PortRanges, types.RulePortRange{ Start: uint16(portRange.Start), End: uint16(portRange.End), }) @@ -214,7 +255,7 @@ func (h 
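Editorial sketch: the validation added to savePolicy above requires each rule to name either group sources or a source resource (and likewise for destinations), but never both and never neither. The predicate can be read in isolation as:

package main

import (
	"errors"
	"fmt"
)

// validateRuleEndpoints mirrors the checks added to savePolicy: exactly one of
// groups / resource must be supplied on each side of the rule.
func validateRuleEndpoints(hasSources, hasSourceResource, hasDestinations, hasDestinationResource bool) error {
	if hasSources && hasSourceResource {
		return errors.New("specify either sources or source resources, not both")
	}
	if hasDestinations && hasDestinationResource {
		return errors.New("specify either destinations or destination resources, not both")
	}
	if !(hasSources || hasSourceResource) || !(hasDestinations || hasDestinationResource) {
		return errors.New("specify either sources or source resources and destinations or destination resources")
	}
	return nil
}

func main() {
	fmt.Println(validateRuleEndpoints(true, false, false, true))  // <nil>: groups -> resource is allowed
	fmt.Println(validateRuleEndpoints(true, true, true, false))   // error: both source forms set
	fmt.Println(validateRuleEndpoints(false, false, true, false)) // error: no source at all
}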
*handler) savePolicy(w http.ResponseWriter, r *http.Request, accountID s // validate policy object switch pr.Protocol { - case server.PolicyRuleProtocolALL, server.PolicyRuleProtocolICMP: + case types.PolicyRuleProtocolALL, types.PolicyRuleProtocolICMP: if len(pr.Ports) != 0 || len(pr.PortRanges) != 0 { util.WriteError(r.Context(), status.Errorf(status.InvalidArgument, "for ALL or ICMP protocol ports is not allowed"), w) return @@ -223,7 +264,7 @@ func (h *handler) savePolicy(w http.ResponseWriter, r *http.Request, accountID s util.WriteError(r.Context(), status.Errorf(status.InvalidArgument, "for ALL or ICMP protocol type flow can be only bi-directional"), w) return } - case server.PolicyRuleProtocolTCP, server.PolicyRuleProtocolUDP: + case types.PolicyRuleProtocolTCP, types.PolicyRuleProtocolUDP: if !pr.Bidirectional && (len(pr.Ports) == 0 || len(pr.PortRanges) != 0) { util.WriteError(r.Context(), status.Errorf(status.InvalidArgument, "for ALL or ICMP protocol type flow can be only bi-directional"), w) return @@ -319,8 +360,8 @@ func (h *handler) getPolicy(w http.ResponseWriter, r *http.Request) { util.WriteJSONObject(r.Context(), w, resp) } -func toPolicyResponse(groups []*nbgroup.Group, policy *server.Policy) *api.Policy { - groupsMap := make(map[string]*nbgroup.Group) +func toPolicyResponse(groups []*types.Group, policy *types.Policy) *api.Policy { + groupsMap := make(map[string]*types.Group) for _, group := range groups { groupsMap[group.ID] = group } @@ -337,13 +378,15 @@ func toPolicyResponse(groups []*nbgroup.Group, policy *server.Policy) *api.Polic rID := r.ID rDescription := r.Description rule := api.PolicyRule{ - Id: &rID, - Name: r.Name, - Enabled: r.Enabled, - Description: &rDescription, - Bidirectional: r.Bidirectional, - Protocol: api.PolicyRuleProtocol(r.Protocol), - Action: api.PolicyRuleAction(r.Action), + Id: &rID, + Name: r.Name, + Enabled: r.Enabled, + Description: &rDescription, + Bidirectional: r.Bidirectional, + Protocol: api.PolicyRuleProtocol(r.Protocol), + Action: api.PolicyRuleAction(r.Action), + SourceResource: r.SourceResource.ToAPIResponse(), + DestinationResource: r.DestinationResource.ToAPIResponse(), } if len(r.Ports) != 0 { @@ -362,26 +405,30 @@ func toPolicyResponse(groups []*nbgroup.Group, policy *server.Policy) *api.Polic rule.PortRanges = &portRanges } + var sources []api.GroupMinimum for _, gid := range r.Sources { _, ok := cache[gid] if ok { continue } + if group, ok := groupsMap[gid]; ok { minimum := api.GroupMinimum{ Id: group.ID, Name: group.Name, PeersCount: len(group.Peers), } - rule.Sources = append(rule.Sources, minimum) + sources = append(sources, minimum) cache[gid] = minimum } } + rule.Sources = &sources + var destinations []api.GroupMinimum for _, gid := range r.Destinations { cachedMinimum, ok := cache[gid] if ok { - rule.Destinations = append(rule.Destinations, cachedMinimum) + destinations = append(destinations, cachedMinimum) continue } if group, ok := groupsMap[gid]; ok { @@ -390,10 +437,12 @@ func toPolicyResponse(groups []*nbgroup.Group, policy *server.Policy) *api.Polic Name: group.Name, PeersCount: len(group.Peers), } - rule.Destinations = append(rule.Destinations, minimum) + destinations = append(destinations, minimum) cache[gid] = minimum } } + rule.Destinations = &destinations + ap.Rules = append(ap.Rules, rule) } return ap diff --git a/management/server/http/handlers/policies/policies_handler_test.go b/management/server/http/handlers/policies/policies_handler_test.go index 4b465a85a..956d0b7cd 100644 --- 
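Editorial sketch: two smaller checks travel with the server-to-types rename above: individual port ranges must stay inside 1..65535 (the exact bounds condition is not visible in this hunk, so the check below is an assumption), and ALL/ICMP rules may carry no ports and must be bidirectional. The TCP/UDP branch keeps its pre-existing, more involved condition and is not restated here.

package main

import (
	"errors"
	"fmt"
)

type portRange struct{ Start, End int }

// validatePortRange is an assumed reconstruction of the bounds check guarding
// the "valid port value is in 1..65535 range" error above.
func validatePortRange(pr portRange) error {
	if pr.Start < 1 || pr.End > 65535 {
		return errors.New("valid port value is in 1..65535 range")
	}
	return nil
}

// validateAllOrICMP mirrors the first case of the protocol switch above.
func validateAllOrICMP(ports, ranges int, bidirectional bool) error {
	if ports != 0 || ranges != 0 {
		return errors.New("for ALL or ICMP protocol ports is not allowed")
	}
	if !bidirectional {
		return errors.New("for ALL or ICMP protocol type flow can be only bi-directional")
	}
	return nil
}

func main() {
	fmt.Println(validatePortRange(portRange{Start: 80, End: 443}))  // <nil>
	fmt.Println(validatePortRange(portRange{Start: 0, End: 70000})) // error
	fmt.Println(validateAllOrICMP(0, 0, true))                      // <nil>
	fmt.Println(validateAllOrICMP(1, 0, true))                      // error: ports not allowed
}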
a/management/server/http/handlers/policies/policies_handler_test.go +++ b/management/server/http/handlers/policies/policies_handler_test.go @@ -10,9 +10,9 @@ import ( "strings" "testing" - nbgroup "github.com/netbirdio/netbird/management/server/group" "github.com/netbirdio/netbird/management/server/http/api" "github.com/netbirdio/netbird/management/server/status" + "github.com/netbirdio/netbird/management/server/types" "github.com/gorilla/mux" @@ -20,50 +20,49 @@ import ( "github.com/magiconair/properties/assert" - "github.com/netbirdio/netbird/management/server" "github.com/netbirdio/netbird/management/server/mock_server" ) -func initPoliciesTestData(policies ...*server.Policy) *handler { - testPolicies := make(map[string]*server.Policy, len(policies)) +func initPoliciesTestData(policies ...*types.Policy) *handler { + testPolicies := make(map[string]*types.Policy, len(policies)) for _, policy := range policies { testPolicies[policy.ID] = policy } return &handler{ accountManager: &mock_server.MockAccountManager{ - GetPolicyFunc: func(_ context.Context, _, policyID, _ string) (*server.Policy, error) { + GetPolicyFunc: func(_ context.Context, _, policyID, _ string) (*types.Policy, error) { policy, ok := testPolicies[policyID] if !ok { return nil, status.Errorf(status.NotFound, "policy not found") } return policy, nil }, - SavePolicyFunc: func(_ context.Context, _, _ string, policy *server.Policy) (*server.Policy, error) { + SavePolicyFunc: func(_ context.Context, _, _ string, policy *types.Policy) (*types.Policy, error) { if !strings.HasPrefix(policy.ID, "id-") { policy.ID = "id-was-set" policy.Rules[0].ID = "id-was-set" } return policy, nil }, - GetAllGroupsFunc: func(ctx context.Context, accountID, userID string) ([]*nbgroup.Group, error) { - return []*nbgroup.Group{{ID: "F"}, {ID: "G"}}, nil + GetAllGroupsFunc: func(ctx context.Context, accountID, userID string) ([]*types.Group, error) { + return []*types.Group{{ID: "F"}, {ID: "G"}}, nil }, GetAccountIDFromTokenFunc: func(_ context.Context, claims jwtclaims.AuthorizationClaims) (string, string, error) { return claims.AccountId, claims.UserId, nil }, - GetAccountByIDFunc: func(ctx context.Context, accountID string, userID string) (*server.Account, error) { - user := server.NewAdminUser(userID) - return &server.Account{ + GetAccountByIDFunc: func(ctx context.Context, accountID string, userID string) (*types.Account, error) { + user := types.NewAdminUser(userID) + return &types.Account{ Id: accountID, Domain: "hotmail.com", - Policies: []*server.Policy{ + Policies: []*types.Policy{ {ID: "id-existed"}, }, - Groups: map[string]*nbgroup.Group{ + Groups: map[string]*types.Group{ "F": {ID: "F"}, "G": {ID: "G"}, }, - Users: map[string]*server.User{ + Users: map[string]*types.User{ "test_user": user, }, }, nil @@ -105,10 +104,10 @@ func TestPoliciesGetPolicy(t *testing.T) { }, } - policy := &server.Policy{ + policy := &types.Policy{ ID: "idofthepolicy", Name: "Rule", - Rules: []*server.PolicyRule{ + Rules: []*types.PolicyRule{ {ID: "idoftherule", Name: "Rule"}, }, } @@ -177,7 +176,9 @@ func TestPoliciesWritePolicy(t *testing.T) { "Description": "Description", "Protocol": "tcp", "Action": "accept", - "Bidirectional":true + "Bidirectional":true, + "Sources": ["F"], + "Destinations": ["G"] } ]}`)), expectedStatus: http.StatusOK, @@ -193,6 +194,8 @@ func TestPoliciesWritePolicy(t *testing.T) { Protocol: "tcp", Action: "accept", Bidirectional: true, + Sources: &[]api.GroupMinimum{{Id: "F"}}, + Destinations: &[]api.GroupMinimum{{Id: "G"}}, }, }, }, @@ 
-221,7 +224,9 @@ func TestPoliciesWritePolicy(t *testing.T) { "Description": "Description", "Protocol": "tcp", "Action": "accept", - "Bidirectional":true + "Bidirectional":true, + "Sources": ["F"], + "Destinations": ["F"] } ]}`)), expectedStatus: http.StatusOK, @@ -237,6 +242,8 @@ func TestPoliciesWritePolicy(t *testing.T) { Protocol: "tcp", Action: "accept", Bidirectional: true, + Sources: &[]api.GroupMinimum{{Id: "F"}}, + Destinations: &[]api.GroupMinimum{{Id: "F"}}, }, }, }, @@ -251,10 +258,10 @@ func TestPoliciesWritePolicy(t *testing.T) { }, } - p := initPoliciesTestData(&server.Policy{ + p := initPoliciesTestData(&types.Policy{ ID: "id-existed", Name: "Default POSTed Rule", - Rules: []*server.PolicyRule{ + Rules: []*types.PolicyRule{ { ID: "id-existed", Name: "Default POSTed Rule", diff --git a/management/server/http/handlers/routes/routes_handler.go b/management/server/http/handlers/routes/routes_handler.go index 9d420066c..a29ba4562 100644 --- a/management/server/http/handlers/routes/routes_handler.go +++ b/management/server/http/handlers/routes/routes_handler.go @@ -360,7 +360,7 @@ func validateDomains(domains []string) (domain.List, error) { return nil, fmt.Errorf("domains list exceeds maximum allowed domains: %d", maxDomains) } - domainRegex := regexp.MustCompile(`^(?:(?:xn--)?[a-zA-Z0-9_](?:[a-zA-Z0-9-_]{0,61}[a-zA-Z0-9])?\.)*(?:xn--)?[a-zA-Z0-9](?:[a-zA-Z0-9-_]{0,61}[a-zA-Z0-9])?$`) + domainRegex := regexp.MustCompile(`^(?:\*\.)?(?:(?:xn--)?[a-zA-Z0-9_](?:[a-zA-Z0-9-_]{0,61}[a-zA-Z0-9])?\.)*(?:xn--)?[a-zA-Z0-9](?:[a-zA-Z0-9-_]{0,61}[a-zA-Z0-9])?$`) var domainList domain.List diff --git a/management/server/http/handlers/routes/routes_handler_test.go b/management/server/http/handlers/routes/routes_handler_test.go index a25c899c9..4cee3ee30 100644 --- a/management/server/http/handlers/routes/routes_handler_test.go +++ b/management/server/http/handlers/routes/routes_handler_test.go @@ -16,13 +16,13 @@ import ( "github.com/netbirdio/netbird/management/server/http/api" nbpeer "github.com/netbirdio/netbird/management/server/peer" "github.com/netbirdio/netbird/management/server/status" + "github.com/netbirdio/netbird/management/server/types" "github.com/netbirdio/netbird/route" "github.com/gorilla/mux" "github.com/magiconair/properties/assert" "github.com/netbirdio/netbird/management/domain" - "github.com/netbirdio/netbird/management/server" "github.com/netbirdio/netbird/management/server/jwtclaims" "github.com/netbirdio/netbird/management/server/mock_server" ) @@ -61,7 +61,7 @@ var baseExistingRoute = &route.Route{ Groups: []string{existingGroupID}, } -var testingAccount = &server.Account{ +var testingAccount = &types.Account{ Id: testAccountID, Domain: "hotmail.com", Peers: map[string]*nbpeer.Peer{ @@ -82,8 +82,8 @@ var testingAccount = &server.Account{ }, }, }, - Users: map[string]*server.User{ - "test_user": server.NewAdminUser("test_user"), + Users: map[string]*types.User{ + "test_user": types.NewAdminUser("test_user"), }, } @@ -330,6 +330,14 @@ func TestRoutesHandlers(t *testing.T) { expectedStatus: http.StatusUnprocessableEntity, expectedBody: false, }, + { + name: "POST Wildcard Domain", + requestType: http.MethodPost, + requestPath: "/api/routes", + requestBody: bytes.NewBufferString(fmt.Sprintf(`{"Description":"Post","domains":["*.example.com"],"network_id":"awesomeNet","Peer":"%s","groups":["%s"]}`, existingPeerID, existingGroupID)), + expectedStatus: http.StatusOK, + expectedBody: false, + }, { name: "POST UnprocessableEntity when both network and domains are provided", 
requestType: http.MethodPost, @@ -609,6 +617,30 @@ func TestValidateDomains(t *testing.T) { expected: domain.List{"google.com"}, wantErr: true, }, + { + name: "Valid wildcard domain", + domains: []string{"*.example.com"}, + expected: domain.List{"*.example.com"}, + wantErr: false, + }, + { + name: "Wildcard with dot domain", + domains: []string{".*.example.com"}, + expected: nil, + wantErr: true, + }, + { + name: "Wildcard with dot domain", + domains: []string{".*.example.com"}, + expected: nil, + wantErr: true, + }, + { + name: "Invalid wildcard domain", + domains: []string{"a.*.example.com"}, + expected: nil, + wantErr: true, + }, } for _, tt := range tests { diff --git a/management/server/http/handlers/setup_keys/setupkeys_handler.go b/management/server/http/handlers/setup_keys/setupkeys_handler.go index 9432d5549..89696a165 100644 --- a/management/server/http/handlers/setup_keys/setupkeys_handler.go +++ b/management/server/http/handlers/setup_keys/setupkeys_handler.go @@ -14,6 +14,7 @@ import ( "github.com/netbirdio/netbird/management/server/http/util" "github.com/netbirdio/netbird/management/server/jwtclaims" "github.com/netbirdio/netbird/management/server/status" + "github.com/netbirdio/netbird/management/server/types" ) // handler is a handler that returns a list of setup keys of the account @@ -63,8 +64,8 @@ func (h *handler) createSetupKey(w http.ResponseWriter, r *http.Request) { return } - if !(server.SetupKeyType(req.Type) == server.SetupKeyReusable || - server.SetupKeyType(req.Type) == server.SetupKeyOneOff) { + if !(types.SetupKeyType(req.Type) == types.SetupKeyReusable || + types.SetupKeyType(req.Type) == types.SetupKeyOneOff) { util.WriteError(r.Context(), status.Errorf(status.InvalidArgument, "unknown setup key type %s", req.Type), w) return } @@ -85,7 +86,7 @@ func (h *handler) createSetupKey(w http.ResponseWriter, r *http.Request) { ephemeral = *req.Ephemeral } - setupKey, err := h.accountManager.CreateSetupKey(r.Context(), accountID, req.Name, server.SetupKeyType(req.Type), expiresIn, + setupKey, err := h.accountManager.CreateSetupKey(r.Context(), accountID, req.Name, types.SetupKeyType(req.Type), expiresIn, req.AutoGroups, req.UsageLimit, userID, ephemeral) if err != nil { util.WriteError(r.Context(), err, w) @@ -152,7 +153,7 @@ func (h *handler) updateSetupKey(w http.ResponseWriter, r *http.Request) { return } - newKey := &server.SetupKey{} + newKey := &types.SetupKey{} newKey.AutoGroups = req.AutoGroups newKey.Revoked = req.Revoked newKey.Id = keyID @@ -212,7 +213,7 @@ func (h *handler) deleteSetupKey(w http.ResponseWriter, r *http.Request) { util.WriteJSONObject(r.Context(), w, util.EmptyObject{}) } -func writeSuccess(ctx context.Context, w http.ResponseWriter, key *server.SetupKey) { +func writeSuccess(ctx context.Context, w http.ResponseWriter, key *types.SetupKey) { w.Header().Set("Content-Type", "application/json") w.WriteHeader(200) err := json.NewEncoder(w).Encode(toResponseBody(key)) @@ -222,7 +223,7 @@ func writeSuccess(ctx context.Context, w http.ResponseWriter, key *server.SetupK } } -func toResponseBody(key *server.SetupKey) *api.SetupKey { +func toResponseBody(key *types.SetupKey) *api.SetupKey { var state string switch { case key.IsExpired(): diff --git a/management/server/http/handlers/setup_keys/setupkeys_handler_test.go b/management/server/http/handlers/setup_keys/setupkeys_handler_test.go index 516a2ab8b..4ecb1e9ed 100644 --- a/management/server/http/handlers/setup_keys/setupkeys_handler_test.go +++ 
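Editorial sketch: the only behavioural change in routes_handler.go above is the domain regex, which now admits a single leading wildcard label via an optional (?:\*\.)? prefix; the added test cases accept "*.example.com" and reject "a.*.example.com" and ".*.example.com". The pattern below is copied verbatim from the diff; the surrounding program is only a harness.

package main

import (
	"fmt"
	"regexp"
)

// Pattern copied from validateDomains after this change: an optional "*." prefix,
// then ordinary (optionally punycode "xn--") labels.
var domainRegex = regexp.MustCompile(`^(?:\*\.)?(?:(?:xn--)?[a-zA-Z0-9_](?:[a-zA-Z0-9-_]{0,61}[a-zA-Z0-9])?\.)*(?:xn--)?[a-zA-Z0-9](?:[a-zA-Z0-9-_]{0,61}[a-zA-Z0-9])?$`)

func main() {
	for _, d := range []string{"example.com", "*.example.com", "a.*.example.com", ".*.example.com"} {
		fmt.Printf("%-18s -> %v\n", d, domainRegex.MatchString(d))
	}
	// example.com        -> true
	// *.example.com      -> true
	// a.*.example.com    -> false
	// .*.example.com     -> false
}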
b/management/server/http/handlers/setup_keys/setupkeys_handler_test.go @@ -14,11 +14,11 @@ import ( "github.com/gorilla/mux" "github.com/stretchr/testify/assert" - "github.com/netbirdio/netbird/management/server" "github.com/netbirdio/netbird/management/server/http/api" "github.com/netbirdio/netbird/management/server/jwtclaims" "github.com/netbirdio/netbird/management/server/mock_server" "github.com/netbirdio/netbird/management/server/status" + "github.com/netbirdio/netbird/management/server/types" ) const ( @@ -29,17 +29,17 @@ const ( testAccountID = "test_id" ) -func initSetupKeysTestMetaData(defaultKey *server.SetupKey, newKey *server.SetupKey, updatedSetupKey *server.SetupKey, - user *server.User, +func initSetupKeysTestMetaData(defaultKey *types.SetupKey, newKey *types.SetupKey, updatedSetupKey *types.SetupKey, + user *types.User, ) *handler { return &handler{ accountManager: &mock_server.MockAccountManager{ GetAccountIDFromTokenFunc: func(_ context.Context, claims jwtclaims.AuthorizationClaims) (string, string, error) { return claims.AccountId, claims.UserId, nil }, - CreateSetupKeyFunc: func(_ context.Context, _ string, keyName string, typ server.SetupKeyType, _ time.Duration, _ []string, + CreateSetupKeyFunc: func(_ context.Context, _ string, keyName string, typ types.SetupKeyType, _ time.Duration, _ []string, _ int, _ string, ephemeral bool, - ) (*server.SetupKey, error) { + ) (*types.SetupKey, error) { if keyName == newKey.Name || typ != newKey.Type { nk := newKey.Copy() nk.Ephemeral = ephemeral @@ -47,7 +47,7 @@ func initSetupKeysTestMetaData(defaultKey *server.SetupKey, newKey *server.Setup } return nil, fmt.Errorf("failed creating setup key") }, - GetSetupKeyFunc: func(_ context.Context, accountID, userID, keyID string) (*server.SetupKey, error) { + GetSetupKeyFunc: func(_ context.Context, accountID, userID, keyID string) (*types.SetupKey, error) { switch keyID { case defaultKey.Id: return defaultKey, nil @@ -58,15 +58,15 @@ func initSetupKeysTestMetaData(defaultKey *server.SetupKey, newKey *server.Setup } }, - SaveSetupKeyFunc: func(_ context.Context, accountID string, key *server.SetupKey, _ string) (*server.SetupKey, error) { + SaveSetupKeyFunc: func(_ context.Context, accountID string, key *types.SetupKey, _ string) (*types.SetupKey, error) { if key.Id == updatedSetupKey.Id { return updatedSetupKey, nil } return nil, status.Errorf(status.NotFound, "key %s not found", key.Id) }, - ListSetupKeysFunc: func(_ context.Context, accountID, userID string) ([]*server.SetupKey, error) { - return []*server.SetupKey{defaultKey}, nil + ListSetupKeysFunc: func(_ context.Context, accountID, userID string) ([]*types.SetupKey, error) { + return []*types.SetupKey{defaultKey}, nil }, DeleteSetupKeyFunc: func(_ context.Context, accountID, userID, keyID string) error { @@ -89,13 +89,13 @@ func initSetupKeysTestMetaData(defaultKey *server.SetupKey, newKey *server.Setup } func TestSetupKeysHandlers(t *testing.T) { - defaultSetupKey, _ := server.GenerateDefaultSetupKey() + defaultSetupKey, _ := types.GenerateDefaultSetupKey() defaultSetupKey.Id = existingSetupKeyID - adminUser := server.NewAdminUser("test_user") + adminUser := types.NewAdminUser("test_user") - newSetupKey, plainKey := server.GenerateSetupKey(newSetupKeyName, server.SetupKeyReusable, 0, []string{"group-1"}, - server.SetupKeyUnlimitedUsage, true) + newSetupKey, plainKey := types.GenerateSetupKey(newSetupKeyName, types.SetupKeyReusable, 0, []string{"group-1"}, + types.SetupKeyUnlimitedUsage, true) newSetupKey.Key = plainKey 
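Editorial sketch: the *_handler_test.go changes in this patch are mechanical renames from server.* fixtures to types.*, and the files keep the usual table-driven net/http/httptest structure: build a handler around a mocked account manager, fire a synthetic request, and assert on the recorded response. A generic version of that structure, with a trivial stand-in handler rather than a NetBird one (place it in a _test.go file):

package main

import (
	"io"
	"net/http"
	"net/http/httptest"
	"testing"
)

func pingHandler(w http.ResponseWriter, r *http.Request) { // stand-in handler
	w.WriteHeader(http.StatusOK)
	_, _ = io.WriteString(w, `{"status":"ok"}`)
}

func TestPingHandler(t *testing.T) {
	cases := []struct {
		name       string
		method     string
		path       string
		wantStatus int
	}{
		{name: "GetPing", method: http.MethodGet, path: "/api/ping", wantStatus: http.StatusOK},
	}

	for _, tc := range cases {
		t.Run(tc.name, func(t *testing.T) {
			req := httptest.NewRequest(tc.method, tc.path, nil)
			rec := httptest.NewRecorder()

			pingHandler(rec, req)

			if rec.Code != tc.wantStatus {
				t.Fatalf("expected status %d, got %d", tc.wantStatus, rec.Code)
			}
			if body, _ := io.ReadAll(rec.Body); len(body) == 0 {
				t.Fatal("expected a JSON body")
			}
		})
	}
}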
updatedDefaultSetupKey := defaultSetupKey.Copy() updatedDefaultSetupKey.AutoGroups = []string{"group-1"} diff --git a/management/server/http/handlers/users/pat_handler.go b/management/server/http/handlers/users/pat_handler.go index 2caf98ad8..197785b34 100644 --- a/management/server/http/handlers/users/pat_handler.go +++ b/management/server/http/handlers/users/pat_handler.go @@ -13,6 +13,7 @@ import ( "github.com/netbirdio/netbird/management/server/http/util" "github.com/netbirdio/netbird/management/server/jwtclaims" "github.com/netbirdio/netbird/management/server/status" + "github.com/netbirdio/netbird/management/server/types" ) // patHandler is the nameserver group handler of the account @@ -164,7 +165,7 @@ func (h *patHandler) deleteToken(w http.ResponseWriter, r *http.Request) { util.WriteJSONObject(r.Context(), w, util.EmptyObject{}) } -func toPATResponse(pat *server.PersonalAccessToken) *api.PersonalAccessToken { +func toPATResponse(pat *types.PersonalAccessToken) *api.PersonalAccessToken { var lastUsed *time.Time if !pat.LastUsed.IsZero() { lastUsed = &pat.LastUsed @@ -179,7 +180,7 @@ func toPATResponse(pat *server.PersonalAccessToken) *api.PersonalAccessToken { } } -func toPATGeneratedResponse(pat *server.PersonalAccessTokenGenerated) *api.PersonalAccessTokenGenerated { +func toPATGeneratedResponse(pat *types.PersonalAccessTokenGenerated) *api.PersonalAccessTokenGenerated { return &api.PersonalAccessTokenGenerated{ PlainToken: pat.PlainToken, PersonalAccessToken: *toPATResponse(&pat.PersonalAccessToken), diff --git a/management/server/http/handlers/users/pat_handler_test.go b/management/server/http/handlers/users/pat_handler_test.go index ef6fb973e..21bdc461e 100644 --- a/management/server/http/handlers/users/pat_handler_test.go +++ b/management/server/http/handlers/users/pat_handler_test.go @@ -14,11 +14,11 @@ import ( "github.com/gorilla/mux" "github.com/stretchr/testify/assert" - "github.com/netbirdio/netbird/management/server" "github.com/netbirdio/netbird/management/server/http/api" "github.com/netbirdio/netbird/management/server/jwtclaims" "github.com/netbirdio/netbird/management/server/mock_server" "github.com/netbirdio/netbird/management/server/status" + "github.com/netbirdio/netbird/management/server/types" ) const ( @@ -31,13 +31,13 @@ const ( testDomain = "hotmail.com" ) -var testAccount = &server.Account{ +var testAccount = &types.Account{ Id: existingAccountID, Domain: testDomain, - Users: map[string]*server.User{ + Users: map[string]*types.User{ existingUserID: { Id: existingUserID, - PATs: map[string]*server.PersonalAccessToken{ + PATs: map[string]*types.PersonalAccessToken{ existingTokenID: { ID: existingTokenID, Name: "My first token", @@ -64,16 +64,16 @@ var testAccount = &server.Account{ func initPATTestData() *patHandler { return &patHandler{ accountManager: &mock_server.MockAccountManager{ - CreatePATFunc: func(_ context.Context, accountID string, initiatorUserID string, targetUserID string, tokenName string, expiresIn int) (*server.PersonalAccessTokenGenerated, error) { + CreatePATFunc: func(_ context.Context, accountID string, initiatorUserID string, targetUserID string, tokenName string, expiresIn int) (*types.PersonalAccessTokenGenerated, error) { if accountID != existingAccountID { return nil, status.Errorf(status.NotFound, "account with ID %s not found", accountID) } if targetUserID != existingUserID { return nil, status.Errorf(status.NotFound, "user with ID %s not found", targetUserID) } - return &server.PersonalAccessTokenGenerated{ + return 
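Editorial sketch: toPATResponse above (now taking *types.PersonalAccessToken) keeps the idiom of converting zero time.Time values into nil pointers so unused timestamps drop out of the JSON response. The idiom in isolation, with a stand-in response struct (the omitempty tag is an assumption about the API type):

package main

import (
	"encoding/json"
	"os"
	"time"
)

type tokenResponse struct { // stand-in for the API token shape
	Name     string     `json:"name"`
	LastUsed *time.Time `json:"last_used,omitempty"`
}

func toResponse(name string, lastUsed time.Time) tokenResponse {
	var lu *time.Time
	if !lastUsed.IsZero() {
		lu = &lastUsed // only expose the field when the token was actually used
	}
	return tokenResponse{Name: name, LastUsed: lu}
}

func main() {
	enc := json.NewEncoder(os.Stdout)
	_ = enc.Encode(toResponse("unused token", time.Time{})) // {"name":"unused token"}
	_ = enc.Encode(toResponse("used token", time.Now()))    // includes "last_used"
}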
&types.PersonalAccessTokenGenerated{ PlainToken: "nbp_z1pvsg2wP3EzmEou4S679KyTNhov632eyrXe", - PersonalAccessToken: server.PersonalAccessToken{}, + PersonalAccessToken: types.PersonalAccessToken{}, }, nil }, @@ -92,7 +92,7 @@ func initPATTestData() *patHandler { } return nil }, - GetPATFunc: func(_ context.Context, accountID string, initiatorUserID string, targetUserID string, tokenID string) (*server.PersonalAccessToken, error) { + GetPATFunc: func(_ context.Context, accountID string, initiatorUserID string, targetUserID string, tokenID string) (*types.PersonalAccessToken, error) { if accountID != existingAccountID { return nil, status.Errorf(status.NotFound, "account with ID %s not found", accountID) } @@ -104,14 +104,14 @@ func initPATTestData() *patHandler { } return testAccount.Users[existingUserID].PATs[existingTokenID], nil }, - GetAllPATsFunc: func(_ context.Context, accountID string, initiatorUserID string, targetUserID string) ([]*server.PersonalAccessToken, error) { + GetAllPATsFunc: func(_ context.Context, accountID string, initiatorUserID string, targetUserID string) ([]*types.PersonalAccessToken, error) { if accountID != existingAccountID { return nil, status.Errorf(status.NotFound, "account with ID %s not found", accountID) } if targetUserID != existingUserID { return nil, status.Errorf(status.NotFound, "user with ID %s not found", targetUserID) } - return []*server.PersonalAccessToken{testAccount.Users[existingUserID].PATs[existingTokenID], testAccount.Users[existingUserID].PATs["token2"]}, nil + return []*types.PersonalAccessToken{testAccount.Users[existingUserID].PATs[existingTokenID], testAccount.Users[existingUserID].PATs["token2"]}, nil }, }, claimsExtractor: jwtclaims.NewClaimsExtractor( @@ -217,7 +217,7 @@ func TestTokenHandlers(t *testing.T) { t.Fatalf("Sent content is not in correct json format; %v", err) } assert.NotEmpty(t, got.PlainToken) - assert.Equal(t, server.PATLength, len(got.PlainToken)) + assert.Equal(t, types.PATLength, len(got.PlainToken)) case "Get All Tokens": expectedTokens := []api.PersonalAccessToken{ toTokenResponse(*testAccount.Users[existingUserID].PATs[existingTokenID]), @@ -243,7 +243,7 @@ func TestTokenHandlers(t *testing.T) { } } -func toTokenResponse(serverToken server.PersonalAccessToken) api.PersonalAccessToken { +func toTokenResponse(serverToken types.PersonalAccessToken) api.PersonalAccessToken { return api.PersonalAccessToken{ Id: serverToken.ID, Name: serverToken.Name, diff --git a/management/server/http/handlers/users/users_handler.go b/management/server/http/handlers/users/users_handler.go index c843bc52b..7380dd97e 100644 --- a/management/server/http/handlers/users/users_handler.go +++ b/management/server/http/handlers/users/users_handler.go @@ -12,6 +12,7 @@ import ( "github.com/netbirdio/netbird/management/server/http/configs" "github.com/netbirdio/netbird/management/server/http/util" "github.com/netbirdio/netbird/management/server/status" + "github.com/netbirdio/netbird/management/server/types" "github.com/netbirdio/netbird/management/server" "github.com/netbirdio/netbird/management/server/jwtclaims" @@ -83,13 +84,13 @@ func (h *handler) updateUser(w http.ResponseWriter, r *http.Request) { return } - userRole := server.StrRoleToUserRole(req.Role) - if userRole == server.UserRoleUnknown { + userRole := types.StrRoleToUserRole(req.Role) + if userRole == types.UserRoleUnknown { util.WriteError(r.Context(), status.Errorf(status.InvalidArgument, "invalid user role"), w) return } - newUser, err := 
h.accountManager.SaveUser(r.Context(), accountID, userID, &server.User{ + newUser, err := h.accountManager.SaveUser(r.Context(), accountID, userID, &types.User{ Id: targetUserID, Role: userRole, AutoGroups: req.AutoGroups, @@ -156,7 +157,7 @@ func (h *handler) createUser(w http.ResponseWriter, r *http.Request) { return } - if server.StrRoleToUserRole(req.Role) == server.UserRoleUnknown { + if types.StrRoleToUserRole(req.Role) == types.UserRoleUnknown { util.WriteError(r.Context(), status.Errorf(status.InvalidArgument, "unknown user role %s", req.Role), w) return } @@ -171,13 +172,13 @@ func (h *handler) createUser(w http.ResponseWriter, r *http.Request) { name = *req.Name } - newUser, err := h.accountManager.CreateUser(r.Context(), accountID, userID, &server.UserInfo{ + newUser, err := h.accountManager.CreateUser(r.Context(), accountID, userID, &types.UserInfo{ Email: email, Name: name, Role: req.Role, AutoGroups: req.AutoGroups, IsServiceUser: req.IsServiceUser, - Issued: server.UserIssuedAPI, + Issued: types.UserIssuedAPI, }) if err != nil { util.WriteError(r.Context(), err, w) @@ -264,7 +265,7 @@ func (h *handler) inviteUser(w http.ResponseWriter, r *http.Request) { util.WriteJSONObject(r.Context(), w, util.EmptyObject{}) } -func toUserResponse(user *server.UserInfo, currenUserID string) *api.User { +func toUserResponse(user *types.UserInfo, currenUserID string) *api.User { autoGroups := user.AutoGroups if autoGroups == nil { autoGroups = []string{} diff --git a/management/server/http/handlers/users/users_handler_test.go b/management/server/http/handlers/users/users_handler_test.go index 6f6a91236..90081830a 100644 --- a/management/server/http/handlers/users/users_handler_test.go +++ b/management/server/http/handlers/users/users_handler_test.go @@ -13,11 +13,11 @@ import ( "github.com/gorilla/mux" "github.com/stretchr/testify/assert" - "github.com/netbirdio/netbird/management/server" "github.com/netbirdio/netbird/management/server/http/api" "github.com/netbirdio/netbird/management/server/jwtclaims" "github.com/netbirdio/netbird/management/server/mock_server" "github.com/netbirdio/netbird/management/server/status" + "github.com/netbirdio/netbird/management/server/types" ) const ( @@ -26,37 +26,37 @@ const ( regularUserID = "regularUserID" ) -var usersTestAccount = &server.Account{ +var usersTestAccount = &types.Account{ Id: existingAccountID, Domain: testDomain, - Users: map[string]*server.User{ + Users: map[string]*types.User{ existingUserID: { Id: existingUserID, Role: "admin", IsServiceUser: false, AutoGroups: []string{"group_1"}, - Issued: server.UserIssuedAPI, + Issued: types.UserIssuedAPI, }, regularUserID: { Id: regularUserID, Role: "user", IsServiceUser: false, AutoGroups: []string{"group_1"}, - Issued: server.UserIssuedAPI, + Issued: types.UserIssuedAPI, }, serviceUserID: { Id: serviceUserID, Role: "user", IsServiceUser: true, AutoGroups: []string{"group_1"}, - Issued: server.UserIssuedAPI, + Issued: types.UserIssuedAPI, }, nonDeletableServiceUserID: { Id: serviceUserID, Role: "admin", IsServiceUser: true, NonDeletable: true, - Issued: server.UserIssuedIntegration, + Issued: types.UserIssuedIntegration, }, }, } @@ -67,13 +67,13 @@ func initUsersTestData() *handler { GetAccountIDFromTokenFunc: func(_ context.Context, claims jwtclaims.AuthorizationClaims) (string, string, error) { return usersTestAccount.Id, claims.UserId, nil }, - GetUserByIDFunc: func(ctx context.Context, id string) (*server.User, error) { + GetUserByIDFunc: func(ctx context.Context, id string) (*types.User, 
error) { return usersTestAccount.Users[id], nil }, - GetUsersFromAccountFunc: func(_ context.Context, accountID, userID string) ([]*server.UserInfo, error) { - users := make([]*server.UserInfo, 0) + GetUsersFromAccountFunc: func(_ context.Context, accountID, userID string) ([]*types.UserInfo, error) { + users := make([]*types.UserInfo, 0) for _, v := range usersTestAccount.Users { - users = append(users, &server.UserInfo{ + users = append(users, &types.UserInfo{ ID: v.Id, Role: string(v.Role), Name: "", @@ -85,7 +85,7 @@ func initUsersTestData() *handler { } return users, nil }, - CreateUserFunc: func(_ context.Context, accountID, userID string, key *server.UserInfo) (*server.UserInfo, error) { + CreateUserFunc: func(_ context.Context, accountID, userID string, key *types.UserInfo) (*types.UserInfo, error) { if userID != existingUserID { return nil, status.Errorf(status.NotFound, "user with ID %s does not exists", userID) } @@ -100,7 +100,7 @@ func initUsersTestData() *handler { } return nil }, - SaveUserFunc: func(_ context.Context, accountID, userID string, update *server.User) (*server.UserInfo, error) { + SaveUserFunc: func(_ context.Context, accountID, userID string, update *types.User) (*types.UserInfo, error) { if update.Id == notFoundUserID { return nil, status.Errorf(status.NotFound, "user with ID %s does not exists", update.Id) } @@ -109,7 +109,7 @@ func initUsersTestData() *handler { return nil, status.Errorf(status.NotFound, "user with ID %s does not exists", userID) } - info, err := update.Copy().ToUserInfo(nil, &server.Settings{RegularUsersViewBlocked: false}) + info, err := update.Copy().ToUserInfo(nil, &types.Settings{RegularUsersViewBlocked: false}) if err != nil { return nil, err } @@ -175,7 +175,7 @@ func TestGetUsers(t *testing.T) { return } - respBody := []*server.UserInfo{} + respBody := []*types.UserInfo{} err = json.Unmarshal(content, &respBody) if err != nil { t.Fatalf("Sent content is not in correct json format; %v", err) @@ -342,7 +342,7 @@ func TestCreateUser(t *testing.T) { requestType string requestPath string requestBody io.Reader - expectedResult []*server.User + expectedResult []*types.User }{ {name: "CreateServiceUser", requestType: http.MethodPost, requestPath: "/api/users", expectedStatus: http.StatusOK, requestBody: bytes.NewBuffer(serviceUserString)}, // right now creation is blocked in AC middleware, will be refactored in the future diff --git a/management/server/http/middleware/access_control.go b/management/server/http/middleware/access_control.go index 0ad250f43..c5bdf5fe7 100644 --- a/management/server/http/middleware/access_control.go +++ b/management/server/http/middleware/access_control.go @@ -7,16 +7,16 @@ import ( log "github.com/sirupsen/logrus" - "github.com/netbirdio/netbird/management/server" "github.com/netbirdio/netbird/management/server/http/middleware/bypass" "github.com/netbirdio/netbird/management/server/http/util" "github.com/netbirdio/netbird/management/server/status" + "github.com/netbirdio/netbird/management/server/types" "github.com/netbirdio/netbird/management/server/jwtclaims" ) // GetUser function defines a function to fetch user from Account by jwtclaims.AuthorizationClaims -type GetUser func(ctx context.Context, claims jwtclaims.AuthorizationClaims) (*server.User, error) +type GetUser func(ctx context.Context, claims jwtclaims.AuthorizationClaims) (*types.User, error) // AccessControl middleware to restrict to make POST/PUT/DELETE requests by admin only type AccessControl struct { diff --git 
a/management/server/http/middleware/auth_middleware.go b/management/server/http/middleware/auth_middleware.go index b25aad99c..0d3459712 100644 --- a/management/server/http/middleware/auth_middleware.go +++ b/management/server/http/middleware/auth_middleware.go @@ -11,16 +11,16 @@ import ( "github.com/golang-jwt/jwt" log "github.com/sirupsen/logrus" - "github.com/netbirdio/netbird/management/server" nbContext "github.com/netbirdio/netbird/management/server/context" "github.com/netbirdio/netbird/management/server/http/middleware/bypass" "github.com/netbirdio/netbird/management/server/http/util" "github.com/netbirdio/netbird/management/server/jwtclaims" "github.com/netbirdio/netbird/management/server/status" + "github.com/netbirdio/netbird/management/server/types" ) // GetAccountFromPATFunc function -type GetAccountFromPATFunc func(ctx context.Context, token string) (*server.Account, *server.User, *server.PersonalAccessToken, error) +type GetAccountFromPATFunc func(ctx context.Context, token string) (*types.Account, *types.User, *types.PersonalAccessToken, error) // ValidateAndParseTokenFunc function type ValidateAndParseTokenFunc func(ctx context.Context, token string) (*jwt.Token, error) diff --git a/management/server/http/middleware/auth_middleware_test.go b/management/server/http/middleware/auth_middleware_test.go index fdfb0ea24..b0d970c5d 100644 --- a/management/server/http/middleware/auth_middleware_test.go +++ b/management/server/http/middleware/auth_middleware_test.go @@ -10,9 +10,9 @@ import ( "github.com/golang-jwt/jwt" - "github.com/netbirdio/netbird/management/server" "github.com/netbirdio/netbird/management/server/http/middleware/bypass" "github.com/netbirdio/netbird/management/server/jwtclaims" + "github.com/netbirdio/netbird/management/server/types" ) const ( @@ -28,13 +28,13 @@ const ( wrongToken = "wrongToken" ) -var testAccount = &server.Account{ +var testAccount = &types.Account{ Id: accountID, Domain: domain, - Users: map[string]*server.User{ + Users: map[string]*types.User{ userID: { Id: userID, - PATs: map[string]*server.PersonalAccessToken{ + PATs: map[string]*types.PersonalAccessToken{ tokenID: { ID: tokenID, Name: "My first token", @@ -49,7 +49,7 @@ var testAccount = &server.Account{ }, } -func mockGetAccountFromPAT(_ context.Context, token string) (*server.Account, *server.User, *server.PersonalAccessToken, error) { +func mockGetAccountFromPAT(_ context.Context, token string) (*types.Account, *types.User, *types.PersonalAccessToken, error) { if token == PAT { return testAccount, testAccount.Users[userID], testAccount.Users[userID].PATs[tokenID], nil } diff --git a/management/server/integrated_validator.go b/management/server/integrated_validator.go index 0c70b702a..47c4ca6ae 100644 --- a/management/server/integrated_validator.go +++ b/management/server/integrated_validator.go @@ -7,6 +7,8 @@ import ( log "github.com/sirupsen/logrus" "github.com/netbirdio/netbird/management/server/account" + "github.com/netbirdio/netbird/management/server/store" + "github.com/netbirdio/netbird/management/server/types" ) // UpdateIntegratedValidatorGroups updates the integrated validator groups for a specified account. 
@@ -57,9 +59,9 @@ func (am *DefaultAccountManager) GroupValidation(ctx context.Context, accountID return true, nil } - err := am.Store.ExecuteInTransaction(ctx, func(transaction Store) error { + err := am.Store.ExecuteInTransaction(ctx, func(transaction store.Store) error { for _, groupID := range groupIDs { - _, err := transaction.GetGroupByID(context.Background(), LockingStrengthShare, accountID, groupID) + _, err := transaction.GetGroupByID(context.Background(), store.LockingStrengthShare, accountID, groupID) if err != nil { return err } @@ -73,6 +75,6 @@ func (am *DefaultAccountManager) GroupValidation(ctx context.Context, accountID return true, nil } -func (am *DefaultAccountManager) GetValidatedPeers(account *Account) (map[string]struct{}, error) { +func (am *DefaultAccountManager) GetValidatedPeers(account *types.Account) (map[string]struct{}, error) { return am.integratedPeerValidator.GetValidatedPeers(account.Id, account.Groups, account.Peers, account.Settings.Extra) } diff --git a/management/server/integrated_validator/interface.go b/management/server/integrated_validator/interface.go index 03be9d039..22b8026aa 100644 --- a/management/server/integrated_validator/interface.go +++ b/management/server/integrated_validator/interface.go @@ -4,8 +4,8 @@ import ( "context" "github.com/netbirdio/netbird/management/server/account" - nbgroup "github.com/netbirdio/netbird/management/server/group" nbpeer "github.com/netbirdio/netbird/management/server/peer" + "github.com/netbirdio/netbird/management/server/types" ) // IntegratedValidator interface exists to avoid the circle dependencies @@ -14,7 +14,7 @@ type IntegratedValidator interface { ValidatePeer(ctx context.Context, update *nbpeer.Peer, peer *nbpeer.Peer, userID string, accountID string, dnsDomain string, peersGroup []string, extraSettings *account.ExtraSettings) (*nbpeer.Peer, bool, error) PreparePeer(ctx context.Context, accountID string, peer *nbpeer.Peer, peersGroup []string, extraSettings *account.ExtraSettings) *nbpeer.Peer IsNotValidPeer(ctx context.Context, accountID string, peer *nbpeer.Peer, peersGroup []string, extraSettings *account.ExtraSettings) (bool, bool, error) - GetValidatedPeers(accountID string, groups map[string]*nbgroup.Group, peers map[string]*nbpeer.Peer, extraSettings *account.ExtraSettings) (map[string]struct{}, error) + GetValidatedPeers(accountID string, groups map[string]*types.Group, peers map[string]*nbpeer.Peer, extraSettings *account.ExtraSettings) (map[string]struct{}, error) PeerDeleted(ctx context.Context, accountID, peerID string) error SetPeerInvalidationListener(fn func(accountID string)) Stop(ctx context.Context) diff --git a/management/server/management_proto_test.go b/management/server/management_proto_test.go index dc8765e19..c66423736 100644 --- a/management/server/management_proto_test.go +++ b/management/server/management_proto_test.go @@ -23,7 +23,10 @@ import ( "github.com/netbirdio/netbird/formatter" mgmtProto "github.com/netbirdio/netbird/management/proto" "github.com/netbirdio/netbird/management/server/activity" + "github.com/netbirdio/netbird/management/server/settings" + "github.com/netbirdio/netbird/management/server/store" "github.com/netbirdio/netbird/management/server/telemetry" + "github.com/netbirdio/netbird/management/server/types" "github.com/netbirdio/netbird/util" ) @@ -413,7 +416,7 @@ func startManagementForTest(t *testing.T, testFile string, config *Config) (*grp } s := grpc.NewServer(grpc.KeepaliveEnforcementPolicy(kaep), grpc.KeepaliveParams(kasp)) - store, cleanup, 
err := NewTestStoreFromSQL(context.Background(), testFile, t.TempDir()) + store, cleanup, err := store.NewTestStoreFromSQL(context.Background(), testFile, t.TempDir()) if err != nil { t.Fatal(err) } @@ -437,7 +440,7 @@ func startManagementForTest(t *testing.T, testFile string, config *Config) (*grp secretsManager := NewTimeBasedAuthSecretsManager(peersUpdateManager, config.TURNConfig, config.Relay) ephemeralMgr := NewEphemeralManager(store, accountManager) - mgmtServer, err := NewServer(context.Background(), config, accountManager, peersUpdateManager, secretsManager, nil, ephemeralMgr) + mgmtServer, err := NewServer(context.Background(), config, accountManager, settings.NewManager(store), peersUpdateManager, secretsManager, nil, ephemeralMgr) if err != nil { return nil, nil, "", cleanup, err } @@ -618,7 +621,7 @@ func testSyncStatusRace(t *testing.T) { } time.Sleep(10 * time.Millisecond) - peer, err := am.Store.GetPeerByPeerPubKey(context.Background(), LockingStrengthShare, peerWithInvalidStatus.PublicKey().String()) + peer, err := am.Store.GetPeerByPeerPubKey(context.Background(), store.LockingStrengthShare, peerWithInvalidStatus.PublicKey().String()) if err != nil { t.Fatal(err) return @@ -705,7 +708,7 @@ func Test_LoginPerformance(t *testing.T) { return } - setupKey, err := am.CreateSetupKey(context.Background(), account.Id, fmt.Sprintf("key-%d", j), SetupKeyReusable, time.Hour, nil, 0, fmt.Sprintf("user-%d", j), false) + setupKey, err := am.CreateSetupKey(context.Background(), account.Id, fmt.Sprintf("key-%d", j), types.SetupKeyReusable, time.Hour, nil, 0, fmt.Sprintf("user-%d", j), false) if err != nil { t.Logf("error creating setup key: %v", err) return diff --git a/management/server/management_test.go b/management/server/management_test.go index 5361da53f..40514ae14 100644 --- a/management/server/management_test.go +++ b/management/server/management_test.go @@ -23,9 +23,11 @@ import ( "github.com/netbirdio/netbird/management/server" "github.com/netbirdio/netbird/management/server/account" "github.com/netbirdio/netbird/management/server/activity" - "github.com/netbirdio/netbird/management/server/group" nbpeer "github.com/netbirdio/netbird/management/server/peer" + "github.com/netbirdio/netbird/management/server/settings" + "github.com/netbirdio/netbird/management/server/store" "github.com/netbirdio/netbird/management/server/telemetry" + "github.com/netbirdio/netbird/management/server/types" "github.com/netbirdio/netbird/util" ) @@ -457,7 +459,7 @@ func (a MocIntegratedValidator) ValidatePeer(_ context.Context, update *nbpeer.P return update, false, nil } -func (a MocIntegratedValidator) GetValidatedPeers(accountID string, groups map[string]*group.Group, peers map[string]*nbpeer.Peer, extraSettings *account.ExtraSettings) (map[string]struct{}, error) { +func (a MocIntegratedValidator) GetValidatedPeers(accountID string, groups map[string]*types.Group, peers map[string]*nbpeer.Peer, extraSettings *account.ExtraSettings) (map[string]struct{}, error) { validatedPeers := make(map[string]struct{}) for p := range peers { validatedPeers[p] = struct{}{} @@ -532,7 +534,7 @@ func startServer(config *server.Config, dataDir string, testFile string) (*grpc. 
Expect(err).NotTo(HaveOccurred()) s := grpc.NewServer() - store, _, err := server.NewTestStoreFromSQL(context.Background(), testFile, dataDir) + store, _, err := store.NewTestStoreFromSQL(context.Background(), testFile, dataDir) if err != nil { log.Fatalf("failed creating a store: %s: %v", config.Datadir, err) } @@ -551,7 +553,7 @@ func startServer(config *server.Config, dataDir string, testFile string) (*grpc. } secretsManager := server.NewTimeBasedAuthSecretsManager(peersUpdateManager, config.TURNConfig, config.Relay) - mgmtServer, err := server.NewServer(context.Background(), config, accountManager, peersUpdateManager, secretsManager, nil, nil) + mgmtServer, err := server.NewServer(context.Background(), config, accountManager, settings.NewManager(store), peersUpdateManager, secretsManager, nil, nil) Expect(err).NotTo(HaveOccurred()) mgmtProto.RegisterManagementServiceServer(s, mgmtServer) go func() { diff --git a/management/server/metrics/selfhosted.go b/management/server/metrics/selfhosted.go index 843fa575e..82b34393f 100644 --- a/management/server/metrics/selfhosted.go +++ b/management/server/metrics/selfhosted.go @@ -15,7 +15,8 @@ import ( "github.com/hashicorp/go-version" log "github.com/sirupsen/logrus" - "github.com/netbirdio/netbird/management/server" + "github.com/netbirdio/netbird/management/server/store" + "github.com/netbirdio/netbird/management/server/types" nbversion "github.com/netbirdio/netbird/version" ) @@ -47,8 +48,8 @@ type properties map[string]interface{} // DataSource metric data source type DataSource interface { - GetAllAccounts(ctx context.Context) []*server.Account - GetStoreEngine() server.StoreEngine + GetAllAccounts(ctx context.Context) []*types.Account + GetStoreEngine() store.Engine } // ConnManager peer connection manager that holds state for current active connections diff --git a/management/server/metrics/selfhosted_test.go b/management/server/metrics/selfhosted_test.go index 2ac2d68a0..1d356387f 100644 --- a/management/server/metrics/selfhosted_test.go +++ b/management/server/metrics/selfhosted_test.go @@ -5,10 +5,10 @@ import ( "testing" nbdns "github.com/netbirdio/netbird/dns" - "github.com/netbirdio/netbird/management/server" - "github.com/netbirdio/netbird/management/server/group" nbpeer "github.com/netbirdio/netbird/management/server/peer" "github.com/netbirdio/netbird/management/server/posture" + "github.com/netbirdio/netbird/management/server/store" + "github.com/netbirdio/netbird/management/server/types" "github.com/netbirdio/netbird/route" ) @@ -22,19 +22,19 @@ func (mockDatasource) GetAllConnectedPeers() map[string]struct{} { } // GetAllAccounts returns a list of *server.Account for use in tests with predefined information -func (mockDatasource) GetAllAccounts(_ context.Context) []*server.Account { - return []*server.Account{ +func (mockDatasource) GetAllAccounts(_ context.Context) []*types.Account { + return []*types.Account{ { Id: "1", - Settings: &server.Settings{PeerLoginExpirationEnabled: true}, - SetupKeys: map[string]*server.SetupKey{ + Settings: &types.Settings{PeerLoginExpirationEnabled: true}, + SetupKeys: map[string]*types.SetupKey{ "1": { Id: "1", Ephemeral: true, UsedTimes: 1, }, }, - Groups: map[string]*group.Group{ + Groups: map[string]*types.Group{ "1": {}, "2": {}, }, @@ -49,20 +49,20 @@ func (mockDatasource) GetAllAccounts(_ context.Context) []*server.Account { Meta: nbpeer.PeerSystemMeta{GoOS: "linux", WtVersion: "0.0.1"}, }, }, - Policies: []*server.Policy{ + Policies: []*types.Policy{ { - Rules: []*server.PolicyRule{ + 
Rules: []*types.PolicyRule{ { Bidirectional: true, - Protocol: server.PolicyRuleProtocolTCP, + Protocol: types.PolicyRuleProtocolTCP, }, }, }, { - Rules: []*server.PolicyRule{ + Rules: []*types.PolicyRule{ { Bidirectional: false, - Protocol: server.PolicyRuleProtocolTCP, + Protocol: types.PolicyRuleProtocolTCP, }, }, SourcePostureChecks: []string{"1"}, @@ -94,16 +94,16 @@ func (mockDatasource) GetAllAccounts(_ context.Context) []*server.Account { }, }, }, - Users: map[string]*server.User{ + Users: map[string]*types.User{ "1": { IsServiceUser: true, - PATs: map[string]*server.PersonalAccessToken{ + PATs: map[string]*types.PersonalAccessToken{ "1": {}, }, }, "2": { IsServiceUser: false, - PATs: map[string]*server.PersonalAccessToken{ + PATs: map[string]*types.PersonalAccessToken{ "1": {}, }, }, @@ -111,15 +111,15 @@ func (mockDatasource) GetAllAccounts(_ context.Context) []*server.Account { }, { Id: "2", - Settings: &server.Settings{PeerLoginExpirationEnabled: true}, - SetupKeys: map[string]*server.SetupKey{ + Settings: &types.Settings{PeerLoginExpirationEnabled: true}, + SetupKeys: map[string]*types.SetupKey{ "1": { Id: "1", Ephemeral: true, UsedTimes: 1, }, }, - Groups: map[string]*group.Group{ + Groups: map[string]*types.Group{ "1": {}, "2": {}, }, @@ -134,20 +134,20 @@ func (mockDatasource) GetAllAccounts(_ context.Context) []*server.Account { Meta: nbpeer.PeerSystemMeta{GoOS: "linux", WtVersion: "0.0.1"}, }, }, - Policies: []*server.Policy{ + Policies: []*types.Policy{ { - Rules: []*server.PolicyRule{ + Rules: []*types.PolicyRule{ { Bidirectional: true, - Protocol: server.PolicyRuleProtocolTCP, + Protocol: types.PolicyRuleProtocolTCP, }, }, }, { - Rules: []*server.PolicyRule{ + Rules: []*types.PolicyRule{ { Bidirectional: false, - Protocol: server.PolicyRuleProtocolTCP, + Protocol: types.PolicyRuleProtocolTCP, }, }, }, @@ -158,16 +158,16 @@ func (mockDatasource) GetAllAccounts(_ context.Context) []*server.Account { PeerGroups: make([]string, 1), }, }, - Users: map[string]*server.User{ + Users: map[string]*types.User{ "1": { IsServiceUser: true, - PATs: map[string]*server.PersonalAccessToken{ + PATs: map[string]*types.PersonalAccessToken{ "1": {}, }, }, "2": { IsServiceUser: false, - PATs: map[string]*server.PersonalAccessToken{ + PATs: map[string]*types.PersonalAccessToken{ "1": {}, }, }, @@ -177,8 +177,8 @@ func (mockDatasource) GetAllAccounts(_ context.Context) []*server.Account { } // GetStoreEngine returns FileStoreEngine -func (mockDatasource) GetStoreEngine() server.StoreEngine { - return server.FileStoreEngine +func (mockDatasource) GetStoreEngine() store.Engine { + return store.FileStoreEngine } // TestGenerateProperties tests and validate the properties generation by using the mockDatasource for the Worker.generateProperties @@ -267,7 +267,7 @@ func TestGenerateProperties(t *testing.T) { t.Errorf("expected 2 user_peers, got %d", properties["user_peers"]) } - if properties["store_engine"] != server.FileStoreEngine { + if properties["store_engine"] != store.FileStoreEngine { t.Errorf("expected JsonFile, got %s", properties["store_engine"]) } diff --git a/management/server/migration/migration_test.go b/management/server/migration/migration_test.go index 51358c7ad..a645ae325 100644 --- a/management/server/migration/migration_test.go +++ b/management/server/migration/migration_test.go @@ -12,9 +12,9 @@ import ( "gorm.io/driver/sqlite" "gorm.io/gorm" - "github.com/netbirdio/netbird/management/server" "github.com/netbirdio/netbird/management/server/migration" nbpeer 
"github.com/netbirdio/netbird/management/server/peer" + "github.com/netbirdio/netbird/management/server/types" "github.com/netbirdio/netbird/route" ) @@ -31,64 +31,64 @@ func setupDatabase(t *testing.T) *gorm.DB { func TestMigrateFieldFromGobToJSON_EmptyDB(t *testing.T) { db := setupDatabase(t) - err := migration.MigrateFieldFromGobToJSON[server.Account, net.IPNet](context.Background(), db, "network_net") + err := migration.MigrateFieldFromGobToJSON[types.Account, net.IPNet](context.Background(), db, "network_net") require.NoError(t, err, "Migration should not fail for an empty database") } func TestMigrateFieldFromGobToJSON_WithGobData(t *testing.T) { db := setupDatabase(t) - err := db.AutoMigrate(&server.Account{}, &route.Route{}) + err := db.AutoMigrate(&types.Account{}, &route.Route{}) require.NoError(t, err, "Failed to auto-migrate tables") _, ipnet, err := net.ParseCIDR("10.0.0.0/24") require.NoError(t, err, "Failed to parse CIDR") type network struct { - server.Network + types.Network Net net.IPNet `gorm:"serializer:gob"` } type account struct { - server.Account + types.Account Network *network `gorm:"embedded;embeddedPrefix:network_"` } - err = db.Save(&account{Account: server.Account{Id: "123"}, Network: &network{Net: *ipnet}}).Error + err = db.Save(&account{Account: types.Account{Id: "123"}, Network: &network{Net: *ipnet}}).Error require.NoError(t, err, "Failed to insert Gob data") var gobStr string - err = db.Model(&server.Account{}).Select("network_net").First(&gobStr).Error + err = db.Model(&types.Account{}).Select("network_net").First(&gobStr).Error assert.NoError(t, err, "Failed to fetch Gob data") err = gob.NewDecoder(strings.NewReader(gobStr)).Decode(&ipnet) require.NoError(t, err, "Failed to decode Gob data") - err = migration.MigrateFieldFromGobToJSON[server.Account, net.IPNet](context.Background(), db, "network_net") + err = migration.MigrateFieldFromGobToJSON[types.Account, net.IPNet](context.Background(), db, "network_net") require.NoError(t, err, "Migration should not fail with Gob data") var jsonStr string - db.Model(&server.Account{}).Select("network_net").First(&jsonStr) + db.Model(&types.Account{}).Select("network_net").First(&jsonStr) assert.JSONEq(t, `{"IP":"10.0.0.0","Mask":"////AA=="}`, jsonStr, "Data should be migrated") } func TestMigrateFieldFromGobToJSON_WithJSONData(t *testing.T) { db := setupDatabase(t) - err := db.AutoMigrate(&server.Account{}, &route.Route{}) + err := db.AutoMigrate(&types.Account{}, &route.Route{}) require.NoError(t, err, "Failed to auto-migrate tables") _, ipnet, err := net.ParseCIDR("10.0.0.0/24") require.NoError(t, err, "Failed to parse CIDR") - err = db.Save(&server.Account{Network: &server.Network{Net: *ipnet}}).Error + err = db.Save(&types.Account{Network: &types.Network{Net: *ipnet}}).Error require.NoError(t, err, "Failed to insert JSON data") - err = migration.MigrateFieldFromGobToJSON[server.Account, net.IPNet](context.Background(), db, "network_net") + err = migration.MigrateFieldFromGobToJSON[types.Account, net.IPNet](context.Background(), db, "network_net") require.NoError(t, err, "Migration should not fail with JSON data") var jsonStr string - db.Model(&server.Account{}).Select("network_net").First(&jsonStr) + db.Model(&types.Account{}).Select("network_net").First(&jsonStr) assert.JSONEq(t, `{"IP":"10.0.0.0","Mask":"////AA=="}`, jsonStr, "Data should be unchanged") } @@ -101,7 +101,7 @@ func TestMigrateNetIPFieldFromBlobToJSON_EmptyDB(t *testing.T) { func TestMigrateNetIPFieldFromBlobToJSON_WithBlobData(t *testing.T) { 
db := setupDatabase(t) - err := db.AutoMigrate(&server.Account{}, &nbpeer.Peer{}) + err := db.AutoMigrate(&types.Account{}, &nbpeer.Peer{}) require.NoError(t, err, "Failed to auto-migrate tables") type location struct { @@ -115,12 +115,12 @@ func TestMigrateNetIPFieldFromBlobToJSON_WithBlobData(t *testing.T) { } type account struct { - server.Account + types.Account Peers []peer `gorm:"foreignKey:AccountID;references:id"` } err = db.Save(&account{ - Account: server.Account{Id: "123"}, + Account: types.Account{Id: "123"}, Peers: []peer{ {Location: location{ConnectionIP: net.IP{10, 0, 0, 1}}}, }}, @@ -142,10 +142,10 @@ func TestMigrateNetIPFieldFromBlobToJSON_WithBlobData(t *testing.T) { func TestMigrateNetIPFieldFromBlobToJSON_WithJSONData(t *testing.T) { db := setupDatabase(t) - err := db.AutoMigrate(&server.Account{}, &nbpeer.Peer{}) + err := db.AutoMigrate(&types.Account{}, &nbpeer.Peer{}) require.NoError(t, err, "Failed to auto-migrate tables") - err = db.Save(&server.Account{ + err = db.Save(&types.Account{ Id: "1234", PeersG: []nbpeer.Peer{ {Location: nbpeer.Location{ConnectionIP: net.IP{10, 0, 0, 1}}}, @@ -164,20 +164,20 @@ func TestMigrateNetIPFieldFromBlobToJSON_WithJSONData(t *testing.T) { func TestMigrateSetupKeyToHashedSetupKey_ForPlainKey(t *testing.T) { db := setupDatabase(t) - err := db.AutoMigrate(&server.SetupKey{}) + err := db.AutoMigrate(&types.SetupKey{}) require.NoError(t, err, "Failed to auto-migrate tables") - err = db.Save(&server.SetupKey{ + err = db.Save(&types.SetupKey{ Id: "1", Key: "EEFDAB47-C1A5-4472-8C05-71DE9A1E8382", }).Error require.NoError(t, err, "Failed to insert setup key") - err = migration.MigrateSetupKeyToHashedSetupKey[server.SetupKey](context.Background(), db) + err = migration.MigrateSetupKeyToHashedSetupKey[types.SetupKey](context.Background(), db) require.NoError(t, err, "Migration should not fail to migrate setup key") - var key server.SetupKey - err = db.Model(&server.SetupKey{}).First(&key).Error + var key types.SetupKey + err = db.Model(&types.SetupKey{}).First(&key).Error assert.NoError(t, err, "Failed to fetch setup key") assert.Equal(t, "EEFDA****", key.KeySecret, "Key should be secret") @@ -187,21 +187,21 @@ func TestMigrateSetupKeyToHashedSetupKey_ForPlainKey(t *testing.T) { func TestMigrateSetupKeyToHashedSetupKey_ForAlreadyMigratedKey_Case1(t *testing.T) { db := setupDatabase(t) - err := db.AutoMigrate(&server.SetupKey{}) + err := db.AutoMigrate(&types.SetupKey{}) require.NoError(t, err, "Failed to auto-migrate tables") - err = db.Save(&server.SetupKey{ + err = db.Save(&types.SetupKey{ Id: "1", Key: "9+FQcmNd2GCxIK+SvHmtp6PPGV4MKEicDS+xuSQmvlE=", KeySecret: "EEFDA****", }).Error require.NoError(t, err, "Failed to insert setup key") - err = migration.MigrateSetupKeyToHashedSetupKey[server.SetupKey](context.Background(), db) + err = migration.MigrateSetupKeyToHashedSetupKey[types.SetupKey](context.Background(), db) require.NoError(t, err, "Migration should not fail to migrate setup key") - var key server.SetupKey - err = db.Model(&server.SetupKey{}).First(&key).Error + var key types.SetupKey + err = db.Model(&types.SetupKey{}).First(&key).Error assert.NoError(t, err, "Failed to fetch setup key") assert.Equal(t, "EEFDA****", key.KeySecret, "Key should be secret") @@ -211,20 +211,20 @@ func TestMigrateSetupKeyToHashedSetupKey_ForAlreadyMigratedKey_Case1(t *testing. 
func TestMigrateSetupKeyToHashedSetupKey_ForAlreadyMigratedKey_Case2(t *testing.T) { db := setupDatabase(t) - err := db.AutoMigrate(&server.SetupKey{}) + err := db.AutoMigrate(&types.SetupKey{}) require.NoError(t, err, "Failed to auto-migrate tables") - err = db.Save(&server.SetupKey{ + err = db.Save(&types.SetupKey{ Id: "1", Key: "9+FQcmNd2GCxIK+SvHmtp6PPGV4MKEicDS+xuSQmvlE=", }).Error require.NoError(t, err, "Failed to insert setup key") - err = migration.MigrateSetupKeyToHashedSetupKey[server.SetupKey](context.Background(), db) + err = migration.MigrateSetupKeyToHashedSetupKey[types.SetupKey](context.Background(), db) require.NoError(t, err, "Migration should not fail to migrate setup key") - var key server.SetupKey - err = db.Model(&server.SetupKey{}).First(&key).Error + var key types.SetupKey + err = db.Model(&types.SetupKey{}).First(&key).Error assert.NoError(t, err, "Failed to fetch setup key") assert.Equal(t, "9+FQcmNd2GCxIK+SvHmtp6PPGV4MKEicDS+xuSQmvlE=", key.Key, "Key should be hashed") diff --git a/management/server/mock_server/account_mock.go b/management/server/mock_server/account_mock.go index 46a4fbc1f..042137b1b 100644 --- a/management/server/mock_server/account_mock.go +++ b/management/server/mock_server/account_mock.go @@ -13,47 +13,47 @@ import ( "github.com/netbirdio/netbird/management/domain" "github.com/netbirdio/netbird/management/server" "github.com/netbirdio/netbird/management/server/activity" - "github.com/netbirdio/netbird/management/server/group" "github.com/netbirdio/netbird/management/server/idp" "github.com/netbirdio/netbird/management/server/jwtclaims" nbpeer "github.com/netbirdio/netbird/management/server/peer" "github.com/netbirdio/netbird/management/server/posture" + "github.com/netbirdio/netbird/management/server/types" "github.com/netbirdio/netbird/route" ) type MockAccountManager struct { - GetOrCreateAccountByUserFunc func(ctx context.Context, userId, domain string) (*server.Account, error) - GetAccountFunc func(ctx context.Context, accountID string) (*server.Account, error) - CreateSetupKeyFunc func(ctx context.Context, accountId string, keyName string, keyType server.SetupKeyType, - expiresIn time.Duration, autoGroups []string, usageLimit int, userID string, ephemeral bool) (*server.SetupKey, error) - GetSetupKeyFunc func(ctx context.Context, accountID, userID, keyID string) (*server.SetupKey, error) + GetOrCreateAccountByUserFunc func(ctx context.Context, userId, domain string) (*types.Account, error) + GetAccountFunc func(ctx context.Context, accountID string) (*types.Account, error) + CreateSetupKeyFunc func(ctx context.Context, accountId string, keyName string, keyType types.SetupKeyType, + expiresIn time.Duration, autoGroups []string, usageLimit int, userID string, ephemeral bool) (*types.SetupKey, error) + GetSetupKeyFunc func(ctx context.Context, accountID, userID, keyID string) (*types.SetupKey, error) AccountExistsFunc func(ctx context.Context, accountID string) (bool, error) GetAccountIDByUserIdFunc func(ctx context.Context, userId, domain string) (string, error) - GetUserFunc func(ctx context.Context, claims jwtclaims.AuthorizationClaims) (*server.User, error) - ListUsersFunc func(ctx context.Context, accountID string) ([]*server.User, error) + GetUserFunc func(ctx context.Context, claims jwtclaims.AuthorizationClaims) (*types.User, error) + ListUsersFunc func(ctx context.Context, accountID string) ([]*types.User, error) GetPeersFunc func(ctx context.Context, accountID, userID string) ([]*nbpeer.Peer, error) MarkPeerConnectedFunc func(ctx 
context.Context, peerKey string, connected bool, realIP net.IP) error - SyncAndMarkPeerFunc func(ctx context.Context, accountID string, peerPubKey string, meta nbpeer.PeerSystemMeta, realIP net.IP) (*nbpeer.Peer, *server.NetworkMap, []*posture.Checks, error) + SyncAndMarkPeerFunc func(ctx context.Context, accountID string, peerPubKey string, meta nbpeer.PeerSystemMeta, realIP net.IP) (*nbpeer.Peer, *types.NetworkMap, []*posture.Checks, error) DeletePeerFunc func(ctx context.Context, accountID, peerKey, userID string) error - GetNetworkMapFunc func(ctx context.Context, peerKey string) (*server.NetworkMap, error) - GetPeerNetworkFunc func(ctx context.Context, peerKey string) (*server.Network, error) - AddPeerFunc func(ctx context.Context, setupKey string, userId string, peer *nbpeer.Peer) (*nbpeer.Peer, *server.NetworkMap, []*posture.Checks, error) - GetGroupFunc func(ctx context.Context, accountID, groupID, userID string) (*group.Group, error) - GetAllGroupsFunc func(ctx context.Context, accountID, userID string) ([]*group.Group, error) - GetGroupByNameFunc func(ctx context.Context, accountID, groupName string) (*group.Group, error) - SaveGroupFunc func(ctx context.Context, accountID, userID string, group *group.Group) error - SaveGroupsFunc func(ctx context.Context, accountID, userID string, groups []*group.Group) error + GetNetworkMapFunc func(ctx context.Context, peerKey string) (*types.NetworkMap, error) + GetPeerNetworkFunc func(ctx context.Context, peerKey string) (*types.Network, error) + AddPeerFunc func(ctx context.Context, setupKey string, userId string, peer *nbpeer.Peer) (*nbpeer.Peer, *types.NetworkMap, []*posture.Checks, error) + GetGroupFunc func(ctx context.Context, accountID, groupID, userID string) (*types.Group, error) + GetAllGroupsFunc func(ctx context.Context, accountID, userID string) ([]*types.Group, error) + GetGroupByNameFunc func(ctx context.Context, accountID, groupName string) (*types.Group, error) + SaveGroupFunc func(ctx context.Context, accountID, userID string, group *types.Group) error + SaveGroupsFunc func(ctx context.Context, accountID, userID string, groups []*types.Group) error DeleteGroupFunc func(ctx context.Context, accountID, userId, groupID string) error DeleteGroupsFunc func(ctx context.Context, accountId, userId string, groupIDs []string) error GroupAddPeerFunc func(ctx context.Context, accountID, groupID, peerID string) error GroupDeletePeerFunc func(ctx context.Context, accountID, groupID, peerID string) error DeleteRuleFunc func(ctx context.Context, accountID, ruleID, userID string) error - GetPolicyFunc func(ctx context.Context, accountID, policyID, userID string) (*server.Policy, error) - SavePolicyFunc func(ctx context.Context, accountID, userID string, policy *server.Policy) (*server.Policy, error) + GetPolicyFunc func(ctx context.Context, accountID, policyID, userID string) (*types.Policy, error) + SavePolicyFunc func(ctx context.Context, accountID, userID string, policy *types.Policy) (*types.Policy, error) DeletePolicyFunc func(ctx context.Context, accountID, policyID, userID string) error - ListPoliciesFunc func(ctx context.Context, accountID, userID string) ([]*server.Policy, error) - GetUsersFromAccountFunc func(ctx context.Context, accountID, userID string) ([]*server.UserInfo, error) - GetAccountFromPATFunc func(ctx context.Context, pat string) (*server.Account, *server.User, *server.PersonalAccessToken, error) + ListPoliciesFunc func(ctx context.Context, accountID, userID string) ([]*types.Policy, error) + GetUsersFromAccountFunc 
func(ctx context.Context, accountID, userID string) ([]*types.UserInfo, error) + GetAccountFromPATFunc func(ctx context.Context, pat string) (*types.Account, *types.User, *types.PersonalAccessToken, error) MarkPATUsedFunc func(ctx context.Context, pat string) error UpdatePeerMetaFunc func(ctx context.Context, peerID string, meta nbpeer.PeerSystemMeta) error UpdatePeerFunc func(ctx context.Context, accountID, userID string, peer *nbpeer.Peer) (*nbpeer.Peer, error) @@ -62,35 +62,35 @@ type MockAccountManager struct { SaveRouteFunc func(ctx context.Context, accountID string, userID string, route *route.Route) error DeleteRouteFunc func(ctx context.Context, accountID string, routeID route.ID, userID string) error ListRoutesFunc func(ctx context.Context, accountID, userID string) ([]*route.Route, error) - SaveSetupKeyFunc func(ctx context.Context, accountID string, key *server.SetupKey, userID string) (*server.SetupKey, error) - ListSetupKeysFunc func(ctx context.Context, accountID, userID string) ([]*server.SetupKey, error) - SaveUserFunc func(ctx context.Context, accountID, userID string, user *server.User) (*server.UserInfo, error) - SaveOrAddUserFunc func(ctx context.Context, accountID, userID string, user *server.User, addIfNotExists bool) (*server.UserInfo, error) - SaveOrAddUsersFunc func(ctx context.Context, accountID, initiatorUserID string, update []*server.User, addIfNotExists bool) ([]*server.UserInfo, error) + SaveSetupKeyFunc func(ctx context.Context, accountID string, key *types.SetupKey, userID string) (*types.SetupKey, error) + ListSetupKeysFunc func(ctx context.Context, accountID, userID string) ([]*types.SetupKey, error) + SaveUserFunc func(ctx context.Context, accountID, userID string, user *types.User) (*types.UserInfo, error) + SaveOrAddUserFunc func(ctx context.Context, accountID, userID string, user *types.User, addIfNotExists bool) (*types.UserInfo, error) + SaveOrAddUsersFunc func(ctx context.Context, accountID, initiatorUserID string, update []*types.User, addIfNotExists bool) ([]*types.UserInfo, error) DeleteUserFunc func(ctx context.Context, accountID string, initiatorUserID string, targetUserID string) error DeleteRegularUsersFunc func(ctx context.Context, accountID, initiatorUserID string, targetUserIDs []string) error - CreatePATFunc func(ctx context.Context, accountID string, initiatorUserID string, targetUserId string, tokenName string, expiresIn int) (*server.PersonalAccessTokenGenerated, error) + CreatePATFunc func(ctx context.Context, accountID string, initiatorUserID string, targetUserId string, tokenName string, expiresIn int) (*types.PersonalAccessTokenGenerated, error) DeletePATFunc func(ctx context.Context, accountID string, initiatorUserID string, targetUserId string, tokenID string) error - GetPATFunc func(ctx context.Context, accountID string, initiatorUserID string, targetUserId string, tokenID string) (*server.PersonalAccessToken, error) - GetAllPATsFunc func(ctx context.Context, accountID string, initiatorUserID string, targetUserId string) ([]*server.PersonalAccessToken, error) + GetPATFunc func(ctx context.Context, accountID string, initiatorUserID string, targetUserId string, tokenID string) (*types.PersonalAccessToken, error) + GetAllPATsFunc func(ctx context.Context, accountID string, initiatorUserID string, targetUserId string) ([]*types.PersonalAccessToken, error) GetNameServerGroupFunc func(ctx context.Context, accountID, userID, nsGroupID string) (*nbdns.NameServerGroup, error) CreateNameServerGroupFunc func(ctx context.Context, accountID 
string, name, description string, nameServerList []nbdns.NameServer, groups []string, primary bool, domains []string, enabled bool, userID string, searchDomainsEnabled bool) (*nbdns.NameServerGroup, error) SaveNameServerGroupFunc func(ctx context.Context, accountID, userID string, nsGroupToSave *nbdns.NameServerGroup) error DeleteNameServerGroupFunc func(ctx context.Context, accountID, nsGroupID, userID string) error ListNameServerGroupsFunc func(ctx context.Context, accountID string, userID string) ([]*nbdns.NameServerGroup, error) - CreateUserFunc func(ctx context.Context, accountID, userID string, key *server.UserInfo) (*server.UserInfo, error) + CreateUserFunc func(ctx context.Context, accountID, userID string, key *types.UserInfo) (*types.UserInfo, error) GetAccountIDFromTokenFunc func(ctx context.Context, claims jwtclaims.AuthorizationClaims) (string, string, error) CheckUserAccessByJWTGroupsFunc func(ctx context.Context, claims jwtclaims.AuthorizationClaims) error DeleteAccountFunc func(ctx context.Context, accountID, userID string) error GetDNSDomainFunc func() string StoreEventFunc func(ctx context.Context, initiatorID, targetID, accountID string, activityID activity.ActivityDescriber, meta map[string]any) GetEventsFunc func(ctx context.Context, accountID, userID string) ([]*activity.Event, error) - GetDNSSettingsFunc func(ctx context.Context, accountID, userID string) (*server.DNSSettings, error) - SaveDNSSettingsFunc func(ctx context.Context, accountID, userID string, dnsSettingsToSave *server.DNSSettings) error + GetDNSSettingsFunc func(ctx context.Context, accountID, userID string) (*types.DNSSettings, error) + SaveDNSSettingsFunc func(ctx context.Context, accountID, userID string, dnsSettingsToSave *types.DNSSettings) error GetPeerFunc func(ctx context.Context, accountID, peerID, userID string) (*nbpeer.Peer, error) - UpdateAccountSettingsFunc func(ctx context.Context, accountID, userID string, newSettings *server.Settings) (*server.Account, error) - LoginPeerFunc func(ctx context.Context, login server.PeerLogin) (*nbpeer.Peer, *server.NetworkMap, []*posture.Checks, error) - SyncPeerFunc func(ctx context.Context, sync server.PeerSync, account *server.Account) (*nbpeer.Peer, *server.NetworkMap, []*posture.Checks, error) + UpdateAccountSettingsFunc func(ctx context.Context, accountID, userID string, newSettings *types.Settings) (*types.Account, error) + LoginPeerFunc func(ctx context.Context, login server.PeerLogin) (*nbpeer.Peer, *types.NetworkMap, []*posture.Checks, error) + SyncPeerFunc func(ctx context.Context, sync server.PeerSync, account *types.Account) (*nbpeer.Peer, *types.NetworkMap, []*posture.Checks, error) InviteUserFunc func(ctx context.Context, accountID string, initiatorUserID string, targetUserEmail string) error GetAllConnectedPeersFunc func() (map[string]struct{}, error) HasConnectedChannelFunc func(peerID string) bool @@ -105,12 +105,16 @@ type MockAccountManager struct { SyncPeerMetaFunc func(ctx context.Context, peerPubKey string, meta nbpeer.PeerSystemMeta) error FindExistingPostureCheckFunc func(accountID string, checks *posture.ChecksDefinition) (*posture.Checks, error) GetAccountIDForPeerKeyFunc func(ctx context.Context, peerKey string) (string, error) - GetAccountByIDFunc func(ctx context.Context, accountID string, userID string) (*server.Account, error) - GetUserByIDFunc func(ctx context.Context, id string) (*server.User, error) - GetAccountSettingsFunc func(ctx context.Context, accountID string, userID string) (*server.Settings, error) + 
GetAccountByIDFunc func(ctx context.Context, accountID string, userID string) (*types.Account, error) + GetUserByIDFunc func(ctx context.Context, id string) (*types.User, error) + GetAccountSettingsFunc func(ctx context.Context, accountID string, userID string) (*types.Settings, error) DeleteSetupKeyFunc func(ctx context.Context, accountID, userID, keyID string) error } +func (am *MockAccountManager) UpdateAccountPeers(ctx context.Context, accountID string) { + // do nothing +} + func (am *MockAccountManager) DeleteSetupKey(ctx context.Context, accountID, userID, keyID string) error { if am.DeleteSetupKeyFunc != nil { return am.DeleteSetupKeyFunc(ctx, accountID, userID, keyID) @@ -118,7 +122,7 @@ func (am *MockAccountManager) DeleteSetupKey(ctx context.Context, accountID, use return status.Errorf(codes.Unimplemented, "method DeleteSetupKey is not implemented") } -func (am *MockAccountManager) SyncAndMarkPeer(ctx context.Context, accountID string, peerPubKey string, meta nbpeer.PeerSystemMeta, realIP net.IP) (*nbpeer.Peer, *server.NetworkMap, []*posture.Checks, error) { +func (am *MockAccountManager) SyncAndMarkPeer(ctx context.Context, accountID string, peerPubKey string, meta nbpeer.PeerSystemMeta, realIP net.IP) (*nbpeer.Peer, *types.NetworkMap, []*posture.Checks, error) { if am.SyncAndMarkPeerFunc != nil { return am.SyncAndMarkPeerFunc(ctx, accountID, peerPubKey, meta, realIP) } @@ -130,7 +134,7 @@ func (am *MockAccountManager) OnPeerDisconnected(_ context.Context, accountID st panic("implement me") } -func (am *MockAccountManager) GetValidatedPeers(account *server.Account) (map[string]struct{}, error) { +func (am *MockAccountManager) GetValidatedPeers(account *types.Account) (map[string]struct{}, error) { approvedPeers := make(map[string]struct{}) for id := range account.Peers { approvedPeers[id] = struct{}{} @@ -139,7 +143,7 @@ func (am *MockAccountManager) GetValidatedPeers(account *server.Account) (map[st } // GetGroup mock implementation of GetGroup from server.AccountManager interface -func (am *MockAccountManager) GetGroup(ctx context.Context, accountId, groupID, userID string) (*group.Group, error) { +func (am *MockAccountManager) GetGroup(ctx context.Context, accountId, groupID, userID string) (*types.Group, error) { if am.GetGroupFunc != nil { return am.GetGroupFunc(ctx, accountId, groupID, userID) } @@ -147,7 +151,7 @@ func (am *MockAccountManager) GetGroup(ctx context.Context, accountId, groupID, } // GetAllGroups mock implementation of GetAllGroups from server.AccountManager interface -func (am *MockAccountManager) GetAllGroups(ctx context.Context, accountID, userID string) ([]*group.Group, error) { +func (am *MockAccountManager) GetAllGroups(ctx context.Context, accountID, userID string) ([]*types.Group, error) { if am.GetAllGroupsFunc != nil { return am.GetAllGroupsFunc(ctx, accountID, userID) } @@ -155,7 +159,7 @@ func (am *MockAccountManager) GetAllGroups(ctx context.Context, accountID, userI } // GetUsersFromAccount mock implementation of GetUsersFromAccount from server.AccountManager interface -func (am *MockAccountManager) GetUsersFromAccount(ctx context.Context, accountID string, userID string) ([]*server.UserInfo, error) { +func (am *MockAccountManager) GetUsersFromAccount(ctx context.Context, accountID string, userID string) ([]*types.UserInfo, error) { if am.GetUsersFromAccountFunc != nil { return am.GetUsersFromAccountFunc(ctx, accountID, userID) } @@ -173,7 +177,7 @@ func (am *MockAccountManager) DeletePeer(ctx context.Context, accountID, peerID, // 
GetOrCreateAccountByUser mock implementation of GetOrCreateAccountByUser from server.AccountManager interface func (am *MockAccountManager) GetOrCreateAccountByUser( ctx context.Context, userId, domain string, -) (*server.Account, error) { +) (*types.Account, error) { if am.GetOrCreateAccountByUserFunc != nil { return am.GetOrCreateAccountByUserFunc(ctx, userId, domain) } @@ -188,13 +192,13 @@ func (am *MockAccountManager) CreateSetupKey( ctx context.Context, accountID string, keyName string, - keyType server.SetupKeyType, + keyType types.SetupKeyType, expiresIn time.Duration, autoGroups []string, usageLimit int, userID string, ephemeral bool, -) (*server.SetupKey, error) { +) (*types.SetupKey, error) { if am.CreateSetupKeyFunc != nil { return am.CreateSetupKeyFunc(ctx, accountID, keyName, keyType, expiresIn, autoGroups, usageLimit, userID, ephemeral) } @@ -221,7 +225,7 @@ func (am *MockAccountManager) GetAccountIDByUserID(ctx context.Context, userId, } // MarkPeerConnected mock implementation of MarkPeerConnected from server.AccountManager interface -func (am *MockAccountManager) MarkPeerConnected(ctx context.Context, peerKey string, connected bool, realIP net.IP, account *server.Account) error { +func (am *MockAccountManager) MarkPeerConnected(ctx context.Context, peerKey string, connected bool, realIP net.IP, account *types.Account) error { if am.MarkPeerConnectedFunc != nil { return am.MarkPeerConnectedFunc(ctx, peerKey, connected, realIP) } @@ -229,7 +233,7 @@ func (am *MockAccountManager) MarkPeerConnected(ctx context.Context, peerKey str } // GetAccountFromPAT mock implementation of GetAccountFromPAT from server.AccountManager interface -func (am *MockAccountManager) GetAccountFromPAT(ctx context.Context, pat string) (*server.Account, *server.User, *server.PersonalAccessToken, error) { +func (am *MockAccountManager) GetAccountFromPAT(ctx context.Context, pat string) (*types.Account, *types.User, *types.PersonalAccessToken, error) { if am.GetAccountFromPATFunc != nil { return am.GetAccountFromPATFunc(ctx, pat) } @@ -253,7 +257,7 @@ func (am *MockAccountManager) MarkPATUsed(ctx context.Context, pat string) error } // CreatePAT mock implementation of GetPAT from server.AccountManager interface -func (am *MockAccountManager) CreatePAT(ctx context.Context, accountID string, initiatorUserID string, targetUserID string, name string, expiresIn int) (*server.PersonalAccessTokenGenerated, error) { +func (am *MockAccountManager) CreatePAT(ctx context.Context, accountID string, initiatorUserID string, targetUserID string, name string, expiresIn int) (*types.PersonalAccessTokenGenerated, error) { if am.CreatePATFunc != nil { return am.CreatePATFunc(ctx, accountID, initiatorUserID, targetUserID, name, expiresIn) } @@ -269,7 +273,7 @@ func (am *MockAccountManager) DeletePAT(ctx context.Context, accountID string, i } // GetPAT mock implementation of GetPAT from server.AccountManager interface -func (am *MockAccountManager) GetPAT(ctx context.Context, accountID string, initiatorUserID string, targetUserID string, tokenID string) (*server.PersonalAccessToken, error) { +func (am *MockAccountManager) GetPAT(ctx context.Context, accountID string, initiatorUserID string, targetUserID string, tokenID string) (*types.PersonalAccessToken, error) { if am.GetPATFunc != nil { return am.GetPATFunc(ctx, accountID, initiatorUserID, targetUserID, tokenID) } @@ -277,7 +281,7 @@ func (am *MockAccountManager) GetPAT(ctx context.Context, accountID string, init } // GetAllPATs mock implementation of GetAllPATs from 
server.AccountManager interface -func (am *MockAccountManager) GetAllPATs(ctx context.Context, accountID string, initiatorUserID string, targetUserID string) ([]*server.PersonalAccessToken, error) { +func (am *MockAccountManager) GetAllPATs(ctx context.Context, accountID string, initiatorUserID string, targetUserID string) ([]*types.PersonalAccessToken, error) { if am.GetAllPATsFunc != nil { return am.GetAllPATsFunc(ctx, accountID, initiatorUserID, targetUserID) } @@ -285,7 +289,7 @@ func (am *MockAccountManager) GetAllPATs(ctx context.Context, accountID string, } // GetNetworkMap mock implementation of GetNetworkMap from server.AccountManager interface -func (am *MockAccountManager) GetNetworkMap(ctx context.Context, peerKey string) (*server.NetworkMap, error) { +func (am *MockAccountManager) GetNetworkMap(ctx context.Context, peerKey string) (*types.NetworkMap, error) { if am.GetNetworkMapFunc != nil { return am.GetNetworkMapFunc(ctx, peerKey) } @@ -293,7 +297,7 @@ func (am *MockAccountManager) GetNetworkMap(ctx context.Context, peerKey string) } // GetPeerNetwork mock implementation of GetPeerNetwork from server.AccountManager interface -func (am *MockAccountManager) GetPeerNetwork(ctx context.Context, peerKey string) (*server.Network, error) { +func (am *MockAccountManager) GetPeerNetwork(ctx context.Context, peerKey string) (*types.Network, error) { if am.GetPeerNetworkFunc != nil { return am.GetPeerNetworkFunc(ctx, peerKey) } @@ -306,7 +310,7 @@ func (am *MockAccountManager) AddPeer( setupKey string, userId string, peer *nbpeer.Peer, -) (*nbpeer.Peer, *server.NetworkMap, []*posture.Checks, error) { +) (*nbpeer.Peer, *types.NetworkMap, []*posture.Checks, error) { if am.AddPeerFunc != nil { return am.AddPeerFunc(ctx, setupKey, userId, peer) } @@ -314,7 +318,7 @@ func (am *MockAccountManager) AddPeer( } // GetGroupByName mock implementation of GetGroupByName from server.AccountManager interface -func (am *MockAccountManager) GetGroupByName(ctx context.Context, accountID, groupName string) (*group.Group, error) { +func (am *MockAccountManager) GetGroupByName(ctx context.Context, accountID, groupName string) (*types.Group, error) { if am.GetGroupFunc != nil { return am.GetGroupByNameFunc(ctx, accountID, groupName) } @@ -322,7 +326,7 @@ func (am *MockAccountManager) GetGroupByName(ctx context.Context, accountID, gro } // SaveGroup mock implementation of SaveGroup from server.AccountManager interface -func (am *MockAccountManager) SaveGroup(ctx context.Context, accountID, userID string, group *group.Group) error { +func (am *MockAccountManager) SaveGroup(ctx context.Context, accountID, userID string, group *types.Group) error { if am.SaveGroupFunc != nil { return am.SaveGroupFunc(ctx, accountID, userID, group) } @@ -330,7 +334,7 @@ func (am *MockAccountManager) SaveGroup(ctx context.Context, accountID, userID s } // SaveGroups mock implementation of SaveGroups from server.AccountManager interface -func (am *MockAccountManager) SaveGroups(ctx context.Context, accountID, userID string, groups []*group.Group) error { +func (am *MockAccountManager) SaveGroups(ctx context.Context, accountID, userID string, groups []*types.Group) error { if am.SaveGroupsFunc != nil { return am.SaveGroupsFunc(ctx, accountID, userID, groups) } @@ -378,7 +382,7 @@ func (am *MockAccountManager) DeleteRule(ctx context.Context, accountID, ruleID, } // GetPolicy mock implementation of GetPolicy from server.AccountManager interface -func (am *MockAccountManager) GetPolicy(ctx context.Context, accountID, policyID, userID 
string) (*server.Policy, error) { +func (am *MockAccountManager) GetPolicy(ctx context.Context, accountID, policyID, userID string) (*types.Policy, error) { if am.GetPolicyFunc != nil { return am.GetPolicyFunc(ctx, accountID, policyID, userID) } @@ -386,7 +390,7 @@ func (am *MockAccountManager) GetPolicy(ctx context.Context, accountID, policyID } // SavePolicy mock implementation of SavePolicy from server.AccountManager interface -func (am *MockAccountManager) SavePolicy(ctx context.Context, accountID, userID string, policy *server.Policy) (*server.Policy, error) { +func (am *MockAccountManager) SavePolicy(ctx context.Context, accountID, userID string, policy *types.Policy) (*types.Policy, error) { if am.SavePolicyFunc != nil { return am.SavePolicyFunc(ctx, accountID, userID, policy) } @@ -402,7 +406,7 @@ func (am *MockAccountManager) DeletePolicy(ctx context.Context, accountID, polic } // ListPolicies mock implementation of ListPolicies from server.AccountManager interface -func (am *MockAccountManager) ListPolicies(ctx context.Context, accountID, userID string) ([]*server.Policy, error) { +func (am *MockAccountManager) ListPolicies(ctx context.Context, accountID, userID string) ([]*types.Policy, error) { if am.ListPoliciesFunc != nil { return am.ListPoliciesFunc(ctx, accountID, userID) } @@ -418,14 +422,14 @@ func (am *MockAccountManager) UpdatePeerMeta(ctx context.Context, peerID string, } // GetUser mock implementation of GetUser from server.AccountManager interface -func (am *MockAccountManager) GetUser(ctx context.Context, claims jwtclaims.AuthorizationClaims) (*server.User, error) { +func (am *MockAccountManager) GetUser(ctx context.Context, claims jwtclaims.AuthorizationClaims) (*types.User, error) { if am.GetUserFunc != nil { return am.GetUserFunc(ctx, claims) } return nil, status.Errorf(codes.Unimplemented, "method GetUser is not implemented") } -func (am *MockAccountManager) ListUsers(ctx context.Context, accountID string) ([]*server.User, error) { +func (am *MockAccountManager) ListUsers(ctx context.Context, accountID string) ([]*types.User, error) { if am.ListUsersFunc != nil { return am.ListUsersFunc(ctx, accountID) } @@ -481,7 +485,7 @@ func (am *MockAccountManager) ListRoutes(ctx context.Context, accountID, userID } // SaveSetupKey mocks SaveSetupKey of the AccountManager interface -func (am *MockAccountManager) SaveSetupKey(ctx context.Context, accountID string, key *server.SetupKey, userID string) (*server.SetupKey, error) { +func (am *MockAccountManager) SaveSetupKey(ctx context.Context, accountID string, key *types.SetupKey, userID string) (*types.SetupKey, error) { if am.SaveSetupKeyFunc != nil { return am.SaveSetupKeyFunc(ctx, accountID, key, userID) } @@ -490,7 +494,7 @@ func (am *MockAccountManager) SaveSetupKey(ctx context.Context, accountID string } // GetSetupKey mocks GetSetupKey of the AccountManager interface -func (am *MockAccountManager) GetSetupKey(ctx context.Context, accountID, userID, keyID string) (*server.SetupKey, error) { +func (am *MockAccountManager) GetSetupKey(ctx context.Context, accountID, userID, keyID string) (*types.SetupKey, error) { if am.GetSetupKeyFunc != nil { return am.GetSetupKeyFunc(ctx, accountID, userID, keyID) } @@ -499,7 +503,7 @@ func (am *MockAccountManager) GetSetupKey(ctx context.Context, accountID, userID } // ListSetupKeys mocks ListSetupKeys of the AccountManager interface -func (am *MockAccountManager) ListSetupKeys(ctx context.Context, accountID, userID string) ([]*server.SetupKey, error) { +func (am *MockAccountManager) 
ListSetupKeys(ctx context.Context, accountID, userID string) ([]*types.SetupKey, error) { if am.ListSetupKeysFunc != nil { return am.ListSetupKeysFunc(ctx, accountID, userID) } @@ -508,7 +512,7 @@ func (am *MockAccountManager) ListSetupKeys(ctx context.Context, accountID, user } // SaveUser mocks SaveUser of the AccountManager interface -func (am *MockAccountManager) SaveUser(ctx context.Context, accountID, userID string, user *server.User) (*server.UserInfo, error) { +func (am *MockAccountManager) SaveUser(ctx context.Context, accountID, userID string, user *types.User) (*types.UserInfo, error) { if am.SaveUserFunc != nil { return am.SaveUserFunc(ctx, accountID, userID, user) } @@ -516,7 +520,7 @@ func (am *MockAccountManager) SaveUser(ctx context.Context, accountID, userID st } // SaveOrAddUser mocks SaveOrAddUser of the AccountManager interface -func (am *MockAccountManager) SaveOrAddUser(ctx context.Context, accountID, userID string, user *server.User, addIfNotExists bool) (*server.UserInfo, error) { +func (am *MockAccountManager) SaveOrAddUser(ctx context.Context, accountID, userID string, user *types.User, addIfNotExists bool) (*types.UserInfo, error) { if am.SaveOrAddUserFunc != nil { return am.SaveOrAddUserFunc(ctx, accountID, userID, user, addIfNotExists) } @@ -524,7 +528,7 @@ func (am *MockAccountManager) SaveOrAddUser(ctx context.Context, accountID, user } // SaveOrAddUsers mocks SaveOrAddUsers of the AccountManager interface -func (am *MockAccountManager) SaveOrAddUsers(ctx context.Context, accountID, userID string, users []*server.User, addIfNotExists bool) ([]*server.UserInfo, error) { +func (am *MockAccountManager) SaveOrAddUsers(ctx context.Context, accountID, userID string, users []*types.User, addIfNotExists bool) ([]*types.UserInfo, error) { if am.SaveOrAddUsersFunc != nil { return am.SaveOrAddUsersFunc(ctx, accountID, userID, users, addIfNotExists) } @@ -595,7 +599,7 @@ func (am *MockAccountManager) ListNameServerGroups(ctx context.Context, accountI } // CreateUser mocks CreateUser of the AccountManager interface -func (am *MockAccountManager) CreateUser(ctx context.Context, accountID, userID string, invite *server.UserInfo) (*server.UserInfo, error) { +func (am *MockAccountManager) CreateUser(ctx context.Context, accountID, userID string, invite *types.UserInfo) (*types.UserInfo, error) { if am.CreateUserFunc != nil { return am.CreateUserFunc(ctx, accountID, userID, invite) } @@ -642,7 +646,7 @@ func (am *MockAccountManager) GetEvents(ctx context.Context, accountID, userID s } // GetDNSSettings mocks GetDNSSettings of the AccountManager interface -func (am *MockAccountManager) GetDNSSettings(ctx context.Context, accountID string, userID string) (*server.DNSSettings, error) { +func (am *MockAccountManager) GetDNSSettings(ctx context.Context, accountID string, userID string) (*types.DNSSettings, error) { if am.GetDNSSettingsFunc != nil { return am.GetDNSSettingsFunc(ctx, accountID, userID) } @@ -650,7 +654,7 @@ func (am *MockAccountManager) GetDNSSettings(ctx context.Context, accountID stri } // SaveDNSSettings mocks SaveDNSSettings of the AccountManager interface -func (am *MockAccountManager) SaveDNSSettings(ctx context.Context, accountID string, userID string, dnsSettingsToSave *server.DNSSettings) error { +func (am *MockAccountManager) SaveDNSSettings(ctx context.Context, accountID string, userID string, dnsSettingsToSave *types.DNSSettings) error { if am.SaveDNSSettingsFunc != nil { return am.SaveDNSSettingsFunc(ctx, accountID, userID, dnsSettingsToSave) } @@ -666,7 
+670,7 @@ func (am *MockAccountManager) GetPeer(ctx context.Context, accountID, peerID, us } // UpdateAccountSettings mocks UpdateAccountSettings of the AccountManager interface -func (am *MockAccountManager) UpdateAccountSettings(ctx context.Context, accountID, userID string, newSettings *server.Settings) (*server.Account, error) { +func (am *MockAccountManager) UpdateAccountSettings(ctx context.Context, accountID, userID string, newSettings *types.Settings) (*types.Account, error) { if am.UpdateAccountSettingsFunc != nil { return am.UpdateAccountSettingsFunc(ctx, accountID, userID, newSettings) } @@ -674,7 +678,7 @@ func (am *MockAccountManager) UpdateAccountSettings(ctx context.Context, account } // LoginPeer mocks LoginPeer of the AccountManager interface -func (am *MockAccountManager) LoginPeer(ctx context.Context, login server.PeerLogin) (*nbpeer.Peer, *server.NetworkMap, []*posture.Checks, error) { +func (am *MockAccountManager) LoginPeer(ctx context.Context, login server.PeerLogin) (*nbpeer.Peer, *types.NetworkMap, []*posture.Checks, error) { if am.LoginPeerFunc != nil { return am.LoginPeerFunc(ctx, login) } @@ -682,7 +686,7 @@ func (am *MockAccountManager) LoginPeer(ctx context.Context, login server.PeerLo } // SyncPeer mocks SyncPeer of the AccountManager interface -func (am *MockAccountManager) SyncPeer(ctx context.Context, sync server.PeerSync, account *server.Account) (*nbpeer.Peer, *server.NetworkMap, []*posture.Checks, error) { +func (am *MockAccountManager) SyncPeer(ctx context.Context, sync server.PeerSync, account *types.Account) (*nbpeer.Peer, *types.NetworkMap, []*posture.Checks, error) { if am.SyncPeerFunc != nil { return am.SyncPeerFunc(ctx, sync, account) } @@ -803,7 +807,7 @@ func (am *MockAccountManager) GetAccountIDForPeerKey(ctx context.Context, peerKe } // GetAccountByID mocks GetAccountByID of the AccountManager interface -func (am *MockAccountManager) GetAccountByID(ctx context.Context, accountID string, userID string) (*server.Account, error) { +func (am *MockAccountManager) GetAccountByID(ctx context.Context, accountID string, userID string) (*types.Account, error) { if am.GetAccountByIDFunc != nil { return am.GetAccountByIDFunc(ctx, accountID, userID) } @@ -811,21 +815,21 @@ func (am *MockAccountManager) GetAccountByID(ctx context.Context, accountID stri } // GetUserByID mocks GetUserByID of the AccountManager interface -func (am *MockAccountManager) GetUserByID(ctx context.Context, id string) (*server.User, error) { +func (am *MockAccountManager) GetUserByID(ctx context.Context, id string) (*types.User, error) { if am.GetUserByIDFunc != nil { return am.GetUserByIDFunc(ctx, id) } return nil, status.Errorf(codes.Unimplemented, "method GetUserByID is not implemented") } -func (am *MockAccountManager) GetAccountSettings(ctx context.Context, accountID string, userID string) (*server.Settings, error) { +func (am *MockAccountManager) GetAccountSettings(ctx context.Context, accountID string, userID string) (*types.Settings, error) { if am.GetAccountSettingsFunc != nil { return am.GetAccountSettingsFunc(ctx, accountID, userID) } return nil, status.Errorf(codes.Unimplemented, "method GetAccountSettings is not implemented") } -func (am *MockAccountManager) GetAccount(ctx context.Context, accountID string) (*server.Account, error) { +func (am *MockAccountManager) GetAccount(ctx context.Context, accountID string) (*types.Account, error) { if am.GetAccountFunc != nil { return am.GetAccountFunc(ctx, accountID) } diff --git a/management/server/nameserver.go 
b/management/server/nameserver.go index e7a5387a1..1a01c7a89 100644 --- a/management/server/nameserver.go +++ b/management/server/nameserver.go @@ -11,15 +11,16 @@ import ( nbdns "github.com/netbirdio/netbird/dns" "github.com/netbirdio/netbird/management/server/activity" - nbgroup "github.com/netbirdio/netbird/management/server/group" "github.com/netbirdio/netbird/management/server/status" + "github.com/netbirdio/netbird/management/server/store" + "github.com/netbirdio/netbird/management/server/types" ) const domainPattern = `^(?i)[a-z0-9]+([\-\.]{1}[a-z0-9]+)*\.[a-z]{2,}$` // GetNameServerGroup gets a nameserver group object from account and nameserver group IDs func (am *DefaultAccountManager) GetNameServerGroup(ctx context.Context, accountID, userID, nsGroupID string) (*nbdns.NameServerGroup, error) { - user, err := am.Store.GetUserByUserID(ctx, LockingStrengthShare, userID) + user, err := am.Store.GetUserByUserID(ctx, store.LockingStrengthShare, userID) if err != nil { return nil, err } @@ -32,7 +33,7 @@ func (am *DefaultAccountManager) GetNameServerGroup(ctx context.Context, account return nil, status.NewAdminPermissionError() } - return am.Store.GetNameServerGroupByID(ctx, LockingStrengthShare, accountID, nsGroupID) + return am.Store.GetNameServerGroupByID(ctx, store.LockingStrengthShare, accountID, nsGroupID) } // CreateNameServerGroup creates and saves a new nameserver group @@ -40,7 +41,7 @@ func (am *DefaultAccountManager) CreateNameServerGroup(ctx context.Context, acco unlock := am.Store.AcquireWriteLockByUID(ctx, accountID) defer unlock() - user, err := am.Store.GetUserByUserID(ctx, LockingStrengthShare, userID) + user, err := am.Store.GetUserByUserID(ctx, store.LockingStrengthShare, userID) if err != nil { return nil, err } @@ -64,21 +65,21 @@ func (am *DefaultAccountManager) CreateNameServerGroup(ctx context.Context, acco var updateAccountPeers bool - err = am.Store.ExecuteInTransaction(ctx, func(transaction Store) error { + err = am.Store.ExecuteInTransaction(ctx, func(transaction store.Store) error { if err = validateNameServerGroup(ctx, transaction, accountID, newNSGroup); err != nil { return err } - updateAccountPeers, err = anyGroupHasPeers(ctx, transaction, accountID, newNSGroup.Groups) + updateAccountPeers, err = anyGroupHasPeersOrResources(ctx, transaction, accountID, newNSGroup.Groups) if err != nil { return err } - if err = transaction.IncrementNetworkSerial(ctx, LockingStrengthUpdate, accountID); err != nil { + if err = transaction.IncrementNetworkSerial(ctx, store.LockingStrengthUpdate, accountID); err != nil { return err } - return transaction.SaveNameServerGroup(ctx, LockingStrengthUpdate, newNSGroup) + return transaction.SaveNameServerGroup(ctx, store.LockingStrengthUpdate, newNSGroup) }) if err != nil { return nil, err @@ -87,7 +88,7 @@ func (am *DefaultAccountManager) CreateNameServerGroup(ctx context.Context, acco am.StoreEvent(ctx, userID, newNSGroup.ID, accountID, activity.NameserverGroupCreated, newNSGroup.EventMeta()) if updateAccountPeers { - am.updateAccountPeers(ctx, accountID) + am.UpdateAccountPeers(ctx, accountID) } return newNSGroup.Copy(), nil @@ -102,7 +103,7 @@ func (am *DefaultAccountManager) SaveNameServerGroup(ctx context.Context, accoun return status.Errorf(status.InvalidArgument, "nameserver group provided is nil") } - user, err := am.Store.GetUserByUserID(ctx, LockingStrengthShare, userID) + user, err := am.Store.GetUserByUserID(ctx, store.LockingStrengthShare, userID) if err != nil { return err } @@ -113,8 +114,8 @@ func (am 
*DefaultAccountManager) SaveNameServerGroup(ctx context.Context, accoun var updateAccountPeers bool - err = am.Store.ExecuteInTransaction(ctx, func(transaction Store) error { - oldNSGroup, err := transaction.GetNameServerGroupByID(ctx, LockingStrengthShare, accountID, nsGroupToSave.ID) + err = am.Store.ExecuteInTransaction(ctx, func(transaction store.Store) error { + oldNSGroup, err := transaction.GetNameServerGroupByID(ctx, store.LockingStrengthShare, accountID, nsGroupToSave.ID) if err != nil { return err } @@ -129,11 +130,11 @@ func (am *DefaultAccountManager) SaveNameServerGroup(ctx context.Context, accoun return err } - if err = transaction.IncrementNetworkSerial(ctx, LockingStrengthUpdate, accountID); err != nil { + if err = transaction.IncrementNetworkSerial(ctx, store.LockingStrengthUpdate, accountID); err != nil { return err } - return transaction.SaveNameServerGroup(ctx, LockingStrengthUpdate, nsGroupToSave) + return transaction.SaveNameServerGroup(ctx, store.LockingStrengthUpdate, nsGroupToSave) }) if err != nil { return err @@ -142,7 +143,7 @@ func (am *DefaultAccountManager) SaveNameServerGroup(ctx context.Context, accoun am.StoreEvent(ctx, userID, nsGroupToSave.ID, accountID, activity.NameserverGroupUpdated, nsGroupToSave.EventMeta()) if updateAccountPeers { - am.updateAccountPeers(ctx, accountID) + am.UpdateAccountPeers(ctx, accountID) } return nil @@ -153,7 +154,7 @@ func (am *DefaultAccountManager) DeleteNameServerGroup(ctx context.Context, acco unlock := am.Store.AcquireWriteLockByUID(ctx, accountID) defer unlock() - user, err := am.Store.GetUserByUserID(ctx, LockingStrengthShare, userID) + user, err := am.Store.GetUserByUserID(ctx, store.LockingStrengthShare, userID) if err != nil { return err } @@ -165,22 +166,22 @@ func (am *DefaultAccountManager) DeleteNameServerGroup(ctx context.Context, acco var nsGroup *nbdns.NameServerGroup var updateAccountPeers bool - err = am.Store.ExecuteInTransaction(ctx, func(transaction Store) error { - nsGroup, err = transaction.GetNameServerGroupByID(ctx, LockingStrengthUpdate, accountID, nsGroupID) + err = am.Store.ExecuteInTransaction(ctx, func(transaction store.Store) error { + nsGroup, err = transaction.GetNameServerGroupByID(ctx, store.LockingStrengthUpdate, accountID, nsGroupID) if err != nil { return err } - updateAccountPeers, err = anyGroupHasPeers(ctx, transaction, accountID, nsGroup.Groups) + updateAccountPeers, err = anyGroupHasPeersOrResources(ctx, transaction, accountID, nsGroup.Groups) if err != nil { return err } - if err = transaction.IncrementNetworkSerial(ctx, LockingStrengthUpdate, accountID); err != nil { + if err = transaction.IncrementNetworkSerial(ctx, store.LockingStrengthUpdate, accountID); err != nil { return err } - return transaction.DeleteNameServerGroup(ctx, LockingStrengthUpdate, accountID, nsGroupID) + return transaction.DeleteNameServerGroup(ctx, store.LockingStrengthUpdate, accountID, nsGroupID) }) if err != nil { return err @@ -189,7 +190,7 @@ func (am *DefaultAccountManager) DeleteNameServerGroup(ctx context.Context, acco am.StoreEvent(ctx, userID, nsGroup.ID, accountID, activity.NameserverGroupDeleted, nsGroup.EventMeta()) if updateAccountPeers { - am.updateAccountPeers(ctx, accountID) + am.UpdateAccountPeers(ctx, accountID) } return nil @@ -197,7 +198,7 @@ func (am *DefaultAccountManager) DeleteNameServerGroup(ctx context.Context, acco // ListNameServerGroups returns a list of nameserver groups from account func (am *DefaultAccountManager) ListNameServerGroups(ctx context.Context, accountID string, 
userID string) ([]*nbdns.NameServerGroup, error) { - user, err := am.Store.GetUserByUserID(ctx, LockingStrengthShare, userID) + user, err := am.Store.GetUserByUserID(ctx, store.LockingStrengthShare, userID) if err != nil { return nil, err } @@ -210,10 +211,10 @@ func (am *DefaultAccountManager) ListNameServerGroups(ctx context.Context, accou return nil, status.NewAdminPermissionError() } - return am.Store.GetAccountNameServerGroups(ctx, LockingStrengthShare, accountID) + return am.Store.GetAccountNameServerGroups(ctx, store.LockingStrengthShare, accountID) } -func validateNameServerGroup(ctx context.Context, transaction Store, accountID string, nameserverGroup *nbdns.NameServerGroup) error { +func validateNameServerGroup(ctx context.Context, transaction store.Store, accountID string, nameserverGroup *nbdns.NameServerGroup) error { err := validateDomainInput(nameserverGroup.Primary, nameserverGroup.Domains, nameserverGroup.SearchDomainsEnabled) if err != nil { return err @@ -224,7 +225,7 @@ func validateNameServerGroup(ctx context.Context, transaction Store, accountID s return err } - nsServerGroups, err := transaction.GetAccountNameServerGroups(ctx, LockingStrengthShare, accountID) + nsServerGroups, err := transaction.GetAccountNameServerGroups(ctx, store.LockingStrengthShare, accountID) if err != nil { return err } @@ -234,7 +235,7 @@ func validateNameServerGroup(ctx context.Context, transaction Store, accountID s return err } - groups, err := transaction.GetGroupsByIDs(ctx, LockingStrengthShare, accountID, nameserverGroup.Groups) + groups, err := transaction.GetGroupsByIDs(ctx, store.LockingStrengthShare, accountID, nameserverGroup.Groups) if err != nil { return err } @@ -243,12 +244,12 @@ func validateNameServerGroup(ctx context.Context, transaction Store, accountID s } // areNameServerGroupChangesAffectPeers checks if the changes in the nameserver group affect the peers. 
-func areNameServerGroupChangesAffectPeers(ctx context.Context, transaction Store, newNSGroup, oldNSGroup *nbdns.NameServerGroup) (bool, error) { +func areNameServerGroupChangesAffectPeers(ctx context.Context, transaction store.Store, newNSGroup, oldNSGroup *nbdns.NameServerGroup) (bool, error) { if !newNSGroup.Enabled && !oldNSGroup.Enabled { return false, nil } - hasPeers, err := anyGroupHasPeers(ctx, transaction, newNSGroup.AccountID, newNSGroup.Groups) + hasPeers, err := anyGroupHasPeersOrResources(ctx, transaction, newNSGroup.AccountID, newNSGroup.Groups) if err != nil { return false, err } @@ -257,7 +258,7 @@ func areNameServerGroupChangesAffectPeers(ctx context.Context, transaction Store return true, nil } - return anyGroupHasPeers(ctx, transaction, oldNSGroup.AccountID, oldNSGroup.Groups) + return anyGroupHasPeersOrResources(ctx, transaction, oldNSGroup.AccountID, oldNSGroup.Groups) } func validateDomainInput(primary bool, domains []string, searchDomainsEnabled bool) error { @@ -305,7 +306,7 @@ func validateNSList(list []nbdns.NameServer) error { return nil } -func validateGroups(list []string, groups map[string]*nbgroup.Group) error { +func validateGroups(list []string, groups map[string]*types.Group) error { if len(list) == 0 { return status.Errorf(status.InvalidArgument, "the list of group IDs should not be empty") } diff --git a/management/server/nameserver_test.go b/management/server/nameserver_test.go index 846dbf023..0743db513 100644 --- a/management/server/nameserver_test.go +++ b/management/server/nameserver_test.go @@ -11,9 +11,10 @@ import ( nbdns "github.com/netbirdio/netbird/dns" "github.com/netbirdio/netbird/management/server/activity" - nbgroup "github.com/netbirdio/netbird/management/server/group" nbpeer "github.com/netbirdio/netbird/management/server/peer" + "github.com/netbirdio/netbird/management/server/store" "github.com/netbirdio/netbird/management/server/telemetry" + "github.com/netbirdio/netbird/management/server/types" ) const ( @@ -772,10 +773,10 @@ func createNSManager(t *testing.T) (*DefaultAccountManager, error) { return BuildManager(context.Background(), store, NewPeersUpdateManager(nil), nil, "", "netbird.selfhosted", eventStore, nil, false, MocIntegratedValidator{}, metrics) } -func createNSStore(t *testing.T) (Store, error) { +func createNSStore(t *testing.T) (store.Store, error) { t.Helper() dataDir := t.TempDir() - store, cleanUp, err := NewTestStoreFromSQL(context.Background(), "", dataDir) + store, cleanUp, err := store.NewTestStoreFromSQL(context.Background(), "", dataDir) if err != nil { return nil, err } @@ -784,7 +785,7 @@ func createNSStore(t *testing.T) (Store, error) { return store, nil } -func initTestNSAccount(t *testing.T, am *DefaultAccountManager) (*Account, error) { +func initTestNSAccount(t *testing.T, am *DefaultAccountManager) (*types.Account, error) { t.Helper() peer1 := &nbpeer.Peer{ Key: nsGroupPeer1Key, @@ -842,12 +843,12 @@ func initTestNSAccount(t *testing.T, am *DefaultAccountManager) (*Account, error account.NameServerGroups[existingNSGroup.ID] = &existingNSGroup - newGroup1 := &nbgroup.Group{ + newGroup1 := &types.Group{ ID: group1ID, Name: group1ID, } - newGroup2 := &nbgroup.Group{ + newGroup2 := &types.Group{ ID: group2ID, Name: group2ID, } @@ -944,7 +945,7 @@ func TestNameServerAccountPeersUpdate(t *testing.T) { var newNameServerGroupA *nbdns.NameServerGroup var newNameServerGroupB *nbdns.NameServerGroup - err := manager.SaveGroups(context.Background(), account.Id, userID, []*nbgroup.Group{ + err := 
manager.SaveGroups(context.Background(), account.Id, userID, []*types.Group{
 		{
 			ID:   "groupA",
 			Name: "GroupA",
diff --git a/management/server/networks/manager.go b/management/server/networks/manager.go
new file mode 100644
index 000000000..4a7b3db77
--- /dev/null
+++ b/management/server/networks/manager.go
@@ -0,0 +1,187 @@
+package networks
+
+import (
+	"context"
+	"fmt"
+
+	"github.com/rs/xid"
+
+	s "github.com/netbirdio/netbird/management/server"
+	"github.com/netbirdio/netbird/management/server/activity"
+	"github.com/netbirdio/netbird/management/server/networks/resources"
+	"github.com/netbirdio/netbird/management/server/networks/routers"
+	"github.com/netbirdio/netbird/management/server/networks/types"
+	"github.com/netbirdio/netbird/management/server/permissions"
+	"github.com/netbirdio/netbird/management/server/status"
+	"github.com/netbirdio/netbird/management/server/store"
+)
+
+type Manager interface {
+	GetAllNetworks(ctx context.Context, accountID, userID string) ([]*types.Network, error)
+	CreateNetwork(ctx context.Context, userID string, network *types.Network) (*types.Network, error)
+	GetNetwork(ctx context.Context, accountID, userID, networkID string) (*types.Network, error)
+	UpdateNetwork(ctx context.Context, userID string, network *types.Network) (*types.Network, error)
+	DeleteNetwork(ctx context.Context, accountID, userID, networkID string) error
+}
+
+type managerImpl struct {
+	store              store.Store
+	accountManager     s.AccountManager
+	permissionsManager permissions.Manager
+	resourcesManager   resources.Manager
+	routersManager     routers.Manager
+}
+
+func NewManager(store store.Store, permissionsManager permissions.Manager, resourceManager resources.Manager, routersManager routers.Manager, accountManager s.AccountManager) Manager {
+	return &managerImpl{
+		store:              store,
+		permissionsManager: permissionsManager,
+		resourcesManager:   resourceManager,
+		routersManager:     routersManager,
+		accountManager:     accountManager,
+	}
+}
+
+func (m *managerImpl) GetAllNetworks(ctx context.Context, accountID, userID string) ([]*types.Network, error) {
+	ok, err := m.permissionsManager.ValidateUserPermissions(ctx, accountID, userID, permissions.Networks, permissions.Read)
+	if err != nil {
+		return nil, status.NewPermissionValidationError(err)
+	}
+	if !ok {
+		return nil, status.NewPermissionDeniedError()
+	}
+
+	return m.store.GetAccountNetworks(ctx, store.LockingStrengthShare, accountID)
+}
+
+func (m *managerImpl) CreateNetwork(ctx context.Context, userID string, network *types.Network) (*types.Network, error) {
+	ok, err := m.permissionsManager.ValidateUserPermissions(ctx, network.AccountID, userID, permissions.Networks, permissions.Write)
+	if err != nil {
+		return nil, status.NewPermissionValidationError(err)
+	}
+	if !ok {
+		return nil, status.NewPermissionDeniedError()
+	}
+
+	network.ID = xid.New().String()
+
+	unlock := m.store.AcquireWriteLockByUID(ctx, network.AccountID)
+	defer unlock()
+
+	err = m.store.SaveNetwork(ctx, store.LockingStrengthUpdate, network)
+	if err != nil {
+		return nil, fmt.Errorf("failed to save network: %w", err)
+	}
+
+	m.accountManager.StoreEvent(ctx, userID, network.ID, network.AccountID, activity.NetworkCreated, network.EventMeta())
+
+	return network, nil
+}
+
+func (m *managerImpl) GetNetwork(ctx context.Context, accountID, userID, networkID string) (*types.Network, error) {
+	ok, err := m.permissionsManager.ValidateUserPermissions(ctx, accountID, userID, permissions.Networks, permissions.Read)
+	if err != nil {
+		return nil, status.NewPermissionValidationError(err)
+	}
+	if !ok {
+		return nil, status.NewPermissionDeniedError()
+	}
+
+	return m.store.GetNetworkByID(ctx, store.LockingStrengthShare, accountID, networkID)
+}
+
+func (m *managerImpl) UpdateNetwork(ctx context.Context, userID string, network *types.Network) (*types.Network, error) {
+	ok, err := m.permissionsManager.ValidateUserPermissions(ctx, network.AccountID, userID, permissions.Networks, permissions.Write)
+	if err != nil {
+		return nil, status.NewPermissionValidationError(err)
+	}
+	if !ok {
+		return nil, status.NewPermissionDeniedError()
+	}
+
+	unlock := m.store.AcquireWriteLockByUID(ctx, network.AccountID)
+	defer unlock()
+
+	_, err = m.store.GetNetworkByID(ctx, store.LockingStrengthUpdate, network.AccountID, network.ID)
+	if err != nil {
+		return nil, fmt.Errorf("failed to get network: %w", err)
+	}
+
+	m.accountManager.StoreEvent(ctx, userID, network.ID, network.AccountID, activity.NetworkUpdated, network.EventMeta())
+
+	return network, m.store.SaveNetwork(ctx, store.LockingStrengthUpdate, network)
+}
+
+func (m *managerImpl) DeleteNetwork(ctx context.Context, accountID, userID, networkID string) error {
+	ok, err := m.permissionsManager.ValidateUserPermissions(ctx, accountID, userID, permissions.Networks, permissions.Write)
+	if err != nil {
+		return status.NewPermissionValidationError(err)
+	}
+	if !ok {
+		return status.NewPermissionDeniedError()
+	}
+
+	network, err := m.store.GetNetworkByID(ctx, store.LockingStrengthUpdate, accountID, networkID)
+	if err != nil {
+		return fmt.Errorf("failed to get network: %w", err)
+	}
+
+	unlock := m.store.AcquireWriteLockByUID(ctx, accountID)
+	defer unlock()
+
+	var eventsToStore []func()
+	err = m.store.ExecuteInTransaction(ctx, func(transaction store.Store) error {
+		resources, err := transaction.GetNetworkResourcesByNetID(ctx, store.LockingStrengthUpdate, accountID, networkID)
+		if err != nil {
+			return fmt.Errorf("failed to get resources in network: %w", err)
+		}
+
+		for _, resource := range resources {
+			event, err := m.resourcesManager.DeleteResourceInTransaction(ctx, transaction, accountID, userID, networkID, resource.ID)
+			if err != nil {
+				return fmt.Errorf("failed to delete resource: %w", err)
+			}
+			eventsToStore = append(eventsToStore, event...)
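DeleteNetwork above follows the write pattern used throughout this change: mutate inside store.ExecuteInTransaction, collect the activity events as closures while the transaction runs, and fire them only after it has committed. A minimal, self-contained sketch of that collect-then-fire pattern; the transaction handle and event payloads below are illustrative stand-ins, not the real store API:

package main

import (
	"context"
	"fmt"
)

// fakeTx is an illustrative stand-in for the store transaction handle.
type fakeTx struct{}

func (fakeTx) deleteResource(ctx context.Context, id string) error { return nil }

// executeInTransaction mimics the commit-or-rollback contract: effects count
// only if the callback returns nil.
func executeInTransaction(ctx context.Context, fn func(tx fakeTx) error) error {
	return fn(fakeTx{})
}

func deleteAll(ctx context.Context, ids []string) error {
	// Events are captured as closures inside the transaction but fired only
	// after it succeeds, so no event is emitted for rolled-back work.
	var eventsToStore []func()

	err := executeInTransaction(ctx, func(tx fakeTx) error {
		for _, id := range ids {
			if err := tx.deleteResource(ctx, id); err != nil {
				return fmt.Errorf("failed to delete resource %s: %w", id, err)
			}
			id := id
			eventsToStore = append(eventsToStore, func() { fmt.Println("resource deleted:", id) })
		}
		return nil
	})
	if err != nil {
		return err
	}

	for _, event := range eventsToStore {
		event()
	}
	return nil
}

func main() {
	_ = deleteAll(context.Background(), []string{"res1", "res2"})
}

Deferring the events this way keeps activity listeners from ever observing work that the transaction rolled back.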
+ } + + routers, err := transaction.GetNetworkRoutersByNetID(ctx, store.LockingStrengthUpdate, accountID, networkID) + if err != nil { + return fmt.Errorf("failed to get routers in network: %w", err) + } + + for _, router := range routers { + event, err := m.routersManager.DeleteRouterInTransaction(ctx, transaction, accountID, userID, networkID, router.ID) + if err != nil { + return fmt.Errorf("failed to delete router: %w", err) + } + eventsToStore = append(eventsToStore, event) + } + + err = transaction.DeleteNetwork(ctx, store.LockingStrengthUpdate, accountID, networkID) + if err != nil { + return fmt.Errorf("failed to delete network: %w", err) + } + + err = transaction.IncrementNetworkSerial(ctx, store.LockingStrengthUpdate, accountID) + if err != nil { + return fmt.Errorf("failed to increment network serial: %w", err) + } + + eventsToStore = append(eventsToStore, func() { + m.accountManager.StoreEvent(ctx, userID, networkID, accountID, activity.NetworkDeleted, network.EventMeta()) + }) + + return nil + }) + if err != nil { + return fmt.Errorf("failed to delete network: %w", err) + } + + for _, event := range eventsToStore { + event() + } + + go m.accountManager.UpdateAccountPeers(ctx, accountID) + + return nil +} diff --git a/management/server/networks/manager_test.go b/management/server/networks/manager_test.go new file mode 100644 index 000000000..edd830c25 --- /dev/null +++ b/management/server/networks/manager_test.go @@ -0,0 +1,254 @@ +package networks + +import ( + "context" + "testing" + + "github.com/stretchr/testify/require" + + "github.com/netbirdio/netbird/management/server/groups" + "github.com/netbirdio/netbird/management/server/mock_server" + "github.com/netbirdio/netbird/management/server/networks/resources" + "github.com/netbirdio/netbird/management/server/networks/routers" + "github.com/netbirdio/netbird/management/server/networks/types" + "github.com/netbirdio/netbird/management/server/permissions" + "github.com/netbirdio/netbird/management/server/store" +) + +func Test_GetAllNetworksReturnsNetworks(t *testing.T) { + ctx := context.Background() + accountID := "testAccountId" + userID := "allowedUser" + + s, cleanUp, err := store.NewTestStoreFromSQL(context.Background(), "../testdata/networks.sql", t.TempDir()) + if err != nil { + t.Fatal(err) + } + t.Cleanup(cleanUp) + am := mock_server.MockAccountManager{} + permissionsManager := permissions.NewManagerMock() + groupsManager := groups.NewManagerMock() + routerManager := routers.NewManagerMock() + resourcesManager := resources.NewManager(s, permissionsManager, groupsManager, &am) + manager := NewManager(s, permissionsManager, resourcesManager, routerManager, &am) + + networks, err := manager.GetAllNetworks(ctx, accountID, userID) + require.NoError(t, err) + require.Len(t, networks, 1) + require.Equal(t, "testNetworkId", networks[0].ID) +} + +func Test_GetAllNetworksReturnsPermissionDenied(t *testing.T) { + ctx := context.Background() + accountID := "testAccountId" + userID := "invalidUser" + + s, cleanUp, err := store.NewTestStoreFromSQL(context.Background(), "../testdata/networks.sql", t.TempDir()) + if err != nil { + t.Fatal(err) + } + t.Cleanup(cleanUp) + am := mock_server.MockAccountManager{} + permissionsManager := permissions.NewManagerMock() + groupsManager := groups.NewManagerMock() + routerManager := routers.NewManagerMock() + resourcesManager := resources.NewManager(s, permissionsManager, groupsManager, &am) + manager := NewManager(s, permissionsManager, resourcesManager, routerManager, &am) + + networks, err 
:= manager.GetAllNetworks(ctx, accountID, userID) + require.Error(t, err) + require.Nil(t, networks) +} + +func Test_GetNetworkReturnsNetwork(t *testing.T) { + ctx := context.Background() + accountID := "testAccountId" + userID := "allowedUser" + networkID := "testNetworkId" + + s, cleanUp, err := store.NewTestStoreFromSQL(context.Background(), "../testdata/networks.sql", t.TempDir()) + if err != nil { + t.Fatal(err) + } + t.Cleanup(cleanUp) + am := mock_server.MockAccountManager{} + permissionsManager := permissions.NewManagerMock() + groupsManager := groups.NewManagerMock() + routerManager := routers.NewManagerMock() + resourcesManager := resources.NewManager(s, permissionsManager, groupsManager, &am) + manager := NewManager(s, permissionsManager, resourcesManager, routerManager, &am) + + networks, err := manager.GetNetwork(ctx, accountID, userID, networkID) + require.NoError(t, err) + require.Equal(t, "testNetworkId", networks.ID) +} + +func Test_GetNetworkReturnsPermissionDenied(t *testing.T) { + ctx := context.Background() + accountID := "testAccountId" + userID := "invalidUser" + networkID := "testNetworkId" + + s, cleanUp, err := store.NewTestStoreFromSQL(context.Background(), "../testdata/networks.sql", t.TempDir()) + if err != nil { + t.Fatal(err) + } + t.Cleanup(cleanUp) + am := mock_server.MockAccountManager{} + permissionsManager := permissions.NewManagerMock() + groupsManager := groups.NewManagerMock() + routerManager := routers.NewManagerMock() + resourcesManager := resources.NewManager(s, permissionsManager, groupsManager, &am) + manager := NewManager(s, permissionsManager, resourcesManager, routerManager, &am) + + network, err := manager.GetNetwork(ctx, accountID, userID, networkID) + require.Error(t, err) + require.Nil(t, network) +} + +func Test_CreateNetworkSuccessfully(t *testing.T) { + ctx := context.Background() + userID := "allowedUser" + network := &types.Network{ + AccountID: "testAccountId", + Name: "new-network", + } + + s, cleanUp, err := store.NewTestStoreFromSQL(context.Background(), "../testdata/networks.sql", t.TempDir()) + if err != nil { + t.Fatal(err) + } + t.Cleanup(cleanUp) + am := mock_server.MockAccountManager{} + permissionsManager := permissions.NewManagerMock() + groupsManager := groups.NewManagerMock() + routerManager := routers.NewManagerMock() + resourcesManager := resources.NewManager(s, permissionsManager, groupsManager, &am) + manager := NewManager(s, permissionsManager, resourcesManager, routerManager, &am) + + createdNetwork, err := manager.CreateNetwork(ctx, userID, network) + require.NoError(t, err) + require.Equal(t, network.Name, createdNetwork.Name) +} + +func Test_CreateNetworkFailsWithPermissionDenied(t *testing.T) { + ctx := context.Background() + userID := "invalidUser" + network := &types.Network{ + AccountID: "testAccountId", + Name: "new-network", + } + + s, cleanUp, err := store.NewTestStoreFromSQL(context.Background(), "../testdata/networks.sql", t.TempDir()) + if err != nil { + t.Fatal(err) + } + t.Cleanup(cleanUp) + am := mock_server.MockAccountManager{} + permissionsManager := permissions.NewManagerMock() + groupsManager := groups.NewManagerMock() + routerManager := routers.NewManagerMock() + resourcesManager := resources.NewManager(s, permissionsManager, groupsManager, &am) + manager := NewManager(s, permissionsManager, resourcesManager, routerManager, &am) + + createdNetwork, err := manager.CreateNetwork(ctx, userID, network) + require.Error(t, err) + require.Nil(t, createdNetwork) +} + +func 
Test_DeleteNetworkSuccessfully(t *testing.T) { + ctx := context.Background() + accountID := "testAccountId" + userID := "allowedUser" + networkID := "testNetworkId" + + s, cleanUp, err := store.NewTestStoreFromSQL(context.Background(), "../testdata/networks.sql", t.TempDir()) + if err != nil { + t.Fatal(err) + } + t.Cleanup(cleanUp) + am := mock_server.MockAccountManager{} + permissionsManager := permissions.NewManagerMock() + groupsManager := groups.NewManagerMock() + routerManager := routers.NewManagerMock() + resourcesManager := resources.NewManager(s, permissionsManager, groupsManager, &am) + manager := NewManager(s, permissionsManager, resourcesManager, routerManager, &am) + + err = manager.DeleteNetwork(ctx, accountID, userID, networkID) + require.NoError(t, err) +} + +func Test_DeleteNetworkFailsWithPermissionDenied(t *testing.T) { + ctx := context.Background() + accountID := "testAccountId" + userID := "invalidUser" + networkID := "testNetworkId" + + s, cleanUp, err := store.NewTestStoreFromSQL(context.Background(), "../testdata/networks.sql", t.TempDir()) + if err != nil { + t.Fatal(err) + } + t.Cleanup(cleanUp) + am := mock_server.MockAccountManager{} + permissionsManager := permissions.NewManagerMock() + groupsManager := groups.NewManagerMock() + routerManager := routers.NewManagerMock() + resourcesManager := resources.NewManager(s, permissionsManager, groupsManager, &am) + manager := NewManager(s, permissionsManager, resourcesManager, routerManager, &am) + + err = manager.DeleteNetwork(ctx, accountID, userID, networkID) + require.Error(t, err) +} + +func Test_UpdateNetworkSuccessfully(t *testing.T) { + ctx := context.Background() + userID := "allowedUser" + network := &types.Network{ + AccountID: "testAccountId", + ID: "testNetworkId", + Name: "new-network", + } + + s, cleanUp, err := store.NewTestStoreFromSQL(context.Background(), "../testdata/networks.sql", t.TempDir()) + if err != nil { + t.Fatal(err) + } + t.Cleanup(cleanUp) + am := mock_server.MockAccountManager{} + permissionsManager := permissions.NewManagerMock() + groupsManager := groups.NewManagerMock() + routerManager := routers.NewManagerMock() + resourcesManager := resources.NewManager(s, permissionsManager, groupsManager, &am) + manager := NewManager(s, permissionsManager, resourcesManager, routerManager, &am) + + updatedNetwork, err := manager.UpdateNetwork(ctx, userID, network) + require.NoError(t, err) + require.Equal(t, network.Name, updatedNetwork.Name) +} + +func Test_UpdateNetworkFailsWithPermissionDenied(t *testing.T) { + ctx := context.Background() + userID := "invalidUser" + network := &types.Network{ + AccountID: "testAccountId", + ID: "testNetworkId", + Name: "new-network", + } + + s, cleanUp, err := store.NewTestStoreFromSQL(context.Background(), "../testdata/networks.sql", t.TempDir()) + if err != nil { + t.Fatal(err) + } + t.Cleanup(cleanUp) + + am := mock_server.MockAccountManager{} + permissionsManager := permissions.NewManagerMock() + groupsManager := groups.NewManagerMock() + routerManager := routers.NewManagerMock() + resourcesManager := resources.NewManager(s, permissionsManager, groupsManager, &am) + manager := NewManager(s, permissionsManager, resourcesManager, routerManager, &am) + + updatedNetwork, err := manager.UpdateNetwork(ctx, userID, network) + require.Error(t, err) + require.Nil(t, updatedNetwork) +} diff --git a/management/server/networks/resources/manager.go b/management/server/networks/resources/manager.go new file mode 100644 index 000000000..0fff5bcf8 --- /dev/null +++ 
b/management/server/networks/resources/manager.go @@ -0,0 +1,383 @@ +package resources + +import ( + "context" + "errors" + "fmt" + + s "github.com/netbirdio/netbird/management/server" + "github.com/netbirdio/netbird/management/server/activity" + "github.com/netbirdio/netbird/management/server/groups" + "github.com/netbirdio/netbird/management/server/networks/resources/types" + "github.com/netbirdio/netbird/management/server/permissions" + "github.com/netbirdio/netbird/management/server/status" + "github.com/netbirdio/netbird/management/server/store" + nbtypes "github.com/netbirdio/netbird/management/server/types" + "github.com/netbirdio/netbird/management/server/util" +) + +type Manager interface { + GetAllResourcesInNetwork(ctx context.Context, accountID, userID, networkID string) ([]*types.NetworkResource, error) + GetAllResourcesInAccount(ctx context.Context, accountID, userID string) ([]*types.NetworkResource, error) + GetAllResourceIDsInAccount(ctx context.Context, accountID, userID string) (map[string][]string, error) + CreateResource(ctx context.Context, userID string, resource *types.NetworkResource) (*types.NetworkResource, error) + GetResource(ctx context.Context, accountID, userID, networkID, resourceID string) (*types.NetworkResource, error) + UpdateResource(ctx context.Context, userID string, resource *types.NetworkResource) (*types.NetworkResource, error) + DeleteResource(ctx context.Context, accountID, userID, networkID, resourceID string) error + DeleteResourceInTransaction(ctx context.Context, transaction store.Store, accountID, userID, networkID, resourceID string) ([]func(), error) +} + +type managerImpl struct { + store store.Store + permissionsManager permissions.Manager + groupsManager groups.Manager + accountManager s.AccountManager +} + +func NewManager(store store.Store, permissionsManager permissions.Manager, groupsManager groups.Manager, accountManager s.AccountManager) Manager { + return &managerImpl{ + store: store, + permissionsManager: permissionsManager, + groupsManager: groupsManager, + accountManager: accountManager, + } +} + +func (m *managerImpl) GetAllResourcesInNetwork(ctx context.Context, accountID, userID, networkID string) ([]*types.NetworkResource, error) { + ok, err := m.permissionsManager.ValidateUserPermissions(ctx, accountID, userID, permissions.Networks, permissions.Read) + if err != nil { + return nil, status.NewPermissionValidationError(err) + } + if !ok { + return nil, status.NewPermissionDeniedError() + } + + return m.store.GetNetworkResourcesByNetID(ctx, store.LockingStrengthShare, accountID, networkID) +} + +func (m *managerImpl) GetAllResourcesInAccount(ctx context.Context, accountID, userID string) ([]*types.NetworkResource, error) { + ok, err := m.permissionsManager.ValidateUserPermissions(ctx, accountID, userID, permissions.Networks, permissions.Read) + if err != nil { + return nil, status.NewPermissionValidationError(err) + } + if !ok { + return nil, status.NewPermissionDeniedError() + } + + return m.store.GetNetworkResourcesByAccountID(ctx, store.LockingStrengthShare, accountID) +} + +func (m *managerImpl) GetAllResourceIDsInAccount(ctx context.Context, accountID, userID string) (map[string][]string, error) { + ok, err := m.permissionsManager.ValidateUserPermissions(ctx, accountID, userID, permissions.Networks, permissions.Read) + if err != nil { + return nil, status.NewPermissionValidationError(err) + } + if !ok { + return nil, status.NewPermissionDeniedError() + } + + resources, err := m.store.GetNetworkResourcesByAccountID(ctx, 
store.LockingStrengthShare, accountID) + if err != nil { + return nil, fmt.Errorf("failed to get network resources: %w", err) + } + + resourceMap := make(map[string][]string) + for _, resource := range resources { + resourceMap[resource.NetworkID] = append(resourceMap[resource.NetworkID], resource.ID) + } + + return resourceMap, nil +} + +func (m *managerImpl) CreateResource(ctx context.Context, userID string, resource *types.NetworkResource) (*types.NetworkResource, error) { + ok, err := m.permissionsManager.ValidateUserPermissions(ctx, resource.AccountID, userID, permissions.Networks, permissions.Write) + if err != nil { + return nil, status.NewPermissionValidationError(err) + } + if !ok { + return nil, status.NewPermissionDeniedError() + } + + resource, err = types.NewNetworkResource(resource.AccountID, resource.NetworkID, resource.Name, resource.Description, resource.Address, resource.GroupIDs) + if err != nil { + return nil, fmt.Errorf("failed to create new network resource: %w", err) + } + + unlock := m.store.AcquireWriteLockByUID(ctx, resource.AccountID) + defer unlock() + + var eventsToStore []func() + err = m.store.ExecuteInTransaction(ctx, func(transaction store.Store) error { + _, err = transaction.GetNetworkResourceByName(ctx, store.LockingStrengthShare, resource.AccountID, resource.Name) + if err == nil { + return errors.New("resource already exists") + } + + network, err := transaction.GetNetworkByID(ctx, store.LockingStrengthUpdate, resource.AccountID, resource.NetworkID) + if err != nil { + return fmt.Errorf("failed to get network: %w", err) + } + + err = transaction.SaveNetworkResource(ctx, store.LockingStrengthUpdate, resource) + if err != nil { + return fmt.Errorf("failed to save network resource: %w", err) + } + + event := func() { + m.accountManager.StoreEvent(ctx, userID, resource.ID, resource.AccountID, activity.NetworkResourceCreated, resource.EventMeta(network)) + } + eventsToStore = append(eventsToStore, event) + + res := nbtypes.Resource{ + ID: resource.ID, + Type: resource.Type.String(), + } + for _, groupID := range resource.GroupIDs { + event, err := m.groupsManager.AddResourceToGroupInTransaction(ctx, transaction, resource.AccountID, userID, groupID, &res) + if err != nil { + return fmt.Errorf("failed to add resource to group: %w", err) + } + eventsToStore = append(eventsToStore, event) + } + + err = transaction.IncrementNetworkSerial(ctx, store.LockingStrengthUpdate, resource.AccountID) + if err != nil { + return fmt.Errorf("failed to increment network serial: %w", err) + } + + return nil + }) + if err != nil { + return nil, fmt.Errorf("failed to create network resource: %w", err) + } + + for _, event := range eventsToStore { + event() + } + + go m.accountManager.UpdateAccountPeers(ctx, resource.AccountID) + + return resource, nil +} + +func (m *managerImpl) GetResource(ctx context.Context, accountID, userID, networkID, resourceID string) (*types.NetworkResource, error) { + ok, err := m.permissionsManager.ValidateUserPermissions(ctx, accountID, userID, permissions.Networks, permissions.Read) + if err != nil { + return nil, status.NewPermissionValidationError(err) + } + if !ok { + return nil, status.NewPermissionDeniedError() + } + + resource, err := m.store.GetNetworkResourceByID(ctx, store.LockingStrengthShare, accountID, resourceID) + if err != nil { + return nil, fmt.Errorf("failed to get network resource: %w", err) + } + + if resource.NetworkID != networkID { + return nil, errors.New("resource not part of network") + } + + return resource, nil +} + +func 
(m *managerImpl) UpdateResource(ctx context.Context, userID string, resource *types.NetworkResource) (*types.NetworkResource, error) { + ok, err := m.permissionsManager.ValidateUserPermissions(ctx, resource.AccountID, userID, permissions.Networks, permissions.Write) + if err != nil { + return nil, status.NewPermissionValidationError(err) + } + if !ok { + return nil, status.NewPermissionDeniedError() + } + + resourceType, domain, prefix, err := types.GetResourceType(resource.Address) + if err != nil { + return nil, fmt.Errorf("failed to get resource type: %w", err) + } + + resource.Type = resourceType + resource.Domain = domain + resource.Prefix = prefix + + unlock := m.store.AcquireWriteLockByUID(ctx, resource.AccountID) + defer unlock() + + var eventsToStore []func() + err = m.store.ExecuteInTransaction(ctx, func(transaction store.Store) error { + network, err := transaction.GetNetworkByID(ctx, store.LockingStrengthUpdate, resource.AccountID, resource.NetworkID) + if err != nil { + return fmt.Errorf("failed to get network: %w", err) + } + + if network.ID != resource.NetworkID { + return status.NewResourceNotPartOfNetworkError(resource.ID, resource.NetworkID) + } + + _, err = transaction.GetNetworkResourceByID(ctx, store.LockingStrengthShare, resource.AccountID, resource.ID) + if err != nil { + return fmt.Errorf("failed to get network resource: %w", err) + } + + oldResource, err := transaction.GetNetworkResourceByName(ctx, store.LockingStrengthShare, resource.AccountID, resource.Name) + if err == nil && oldResource.ID != resource.ID { + return errors.New("new resource name already exists") + } + + oldResource, err = transaction.GetNetworkResourceByID(ctx, store.LockingStrengthShare, resource.AccountID, resource.ID) + if err != nil { + return fmt.Errorf("failed to get network resource: %w", err) + } + + err = transaction.SaveNetworkResource(ctx, store.LockingStrengthUpdate, resource) + if err != nil { + return fmt.Errorf("failed to save network resource: %w", err) + } + + events, err := m.updateResourceGroups(ctx, transaction, userID, resource, oldResource) + if err != nil { + return fmt.Errorf("failed to update resource groups: %w", err) + } + + eventsToStore = append(eventsToStore, events...) 
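The group reconciliation delegated to updateResourceGroups below works by set-differencing the new and old group ID lists: IDs present only in the new list are added, IDs present only in the old list are removed. A small sketch of that step, using a local difference helper in place of the project's util.Difference:

package main

import "fmt"

// difference returns the elements of a that are not present in b.
func difference(a, b []string) []string {
	seen := make(map[string]struct{}, len(b))
	for _, v := range b {
		seen[v] = struct{}{}
	}
	var out []string
	for _, v := range a {
		if _, ok := seen[v]; !ok {
			out = append(out, v)
		}
	}
	return out
}

func main() {
	oldGroups := []string{"g1", "g2"}
	newGroups := []string{"g2", "g3"}

	toAdd := difference(newGroups, oldGroups)    // groups the resource should join: [g3]
	toRemove := difference(oldGroups, newGroups) // groups the resource should leave: [g1]

	fmt.Println(toAdd, toRemove)
}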
+ eventsToStore = append(eventsToStore, func() { + m.accountManager.StoreEvent(ctx, userID, resource.ID, resource.AccountID, activity.NetworkResourceUpdated, resource.EventMeta(network)) + }) + + err = transaction.IncrementNetworkSerial(ctx, store.LockingStrengthUpdate, resource.AccountID) + if err != nil { + return fmt.Errorf("failed to increment network serial: %w", err) + } + + return nil + }) + if err != nil { + return nil, fmt.Errorf("failed to update network resource: %w", err) + } + + for _, event := range eventsToStore { + event() + } + + go m.accountManager.UpdateAccountPeers(ctx, resource.AccountID) + + return resource, nil +} + +func (m *managerImpl) updateResourceGroups(ctx context.Context, transaction store.Store, userID string, newResource, oldResource *types.NetworkResource) ([]func(), error) { + res := nbtypes.Resource{ + ID: newResource.ID, + Type: newResource.Type.String(), + } + + oldResourceGroups, err := m.groupsManager.GetResourceGroupsInTransaction(ctx, transaction, store.LockingStrengthUpdate, oldResource.AccountID, oldResource.ID) + if err != nil { + return nil, fmt.Errorf("failed to get resource groups: %w", err) + } + + oldGroupsIds := make([]string, 0) + for _, group := range oldResourceGroups { + oldGroupsIds = append(oldGroupsIds, group.ID) + } + + var eventsToStore []func() + groupsToAdd := util.Difference(newResource.GroupIDs, oldGroupsIds) + for _, groupID := range groupsToAdd { + events, err := m.groupsManager.AddResourceToGroupInTransaction(ctx, transaction, newResource.AccountID, userID, groupID, &res) + if err != nil { + return nil, fmt.Errorf("failed to add resource to group: %w", err) + } + eventsToStore = append(eventsToStore, events) + } + + groupsToRemove := util.Difference(oldGroupsIds, newResource.GroupIDs) + for _, groupID := range groupsToRemove { + events, err := m.groupsManager.RemoveResourceFromGroupInTransaction(ctx, transaction, newResource.AccountID, userID, groupID, res.ID) + if err != nil { + return nil, fmt.Errorf("failed to add resource to group: %w", err) + } + eventsToStore = append(eventsToStore, events) + } + + return eventsToStore, nil +} + +func (m *managerImpl) DeleteResource(ctx context.Context, accountID, userID, networkID, resourceID string) error { + ok, err := m.permissionsManager.ValidateUserPermissions(ctx, accountID, userID, permissions.Networks, permissions.Write) + if err != nil { + return status.NewPermissionValidationError(err) + } + if !ok { + return status.NewPermissionDeniedError() + } + + unlock := m.store.AcquireWriteLockByUID(ctx, accountID) + defer unlock() + + var events []func() + err = m.store.ExecuteInTransaction(ctx, func(transaction store.Store) error { + events, err = m.DeleteResourceInTransaction(ctx, transaction, accountID, userID, networkID, resourceID) + if err != nil { + return fmt.Errorf("failed to delete resource: %w", err) + } + + err = transaction.IncrementNetworkSerial(ctx, store.LockingStrengthUpdate, accountID) + if err != nil { + return fmt.Errorf("failed to increment network serial: %w", err) + } + + return nil + }) + if err != nil { + return fmt.Errorf("failed to delete network resource: %w", err) + } + + for _, event := range events { + event() + } + + go m.accountManager.UpdateAccountPeers(ctx, accountID) + + return nil +} + +func (m *managerImpl) DeleteResourceInTransaction(ctx context.Context, transaction store.Store, accountID, userID, networkID, resourceID string) ([]func(), error) { + resource, err := transaction.GetNetworkResourceByID(ctx, store.LockingStrengthUpdate, accountID, 
resourceID) + if err != nil { + return nil, fmt.Errorf("failed to get network resource: %w", err) + } + + network, err := transaction.GetNetworkByID(ctx, store.LockingStrengthUpdate, accountID, networkID) + if err != nil { + return nil, fmt.Errorf("failed to get network: %w", err) + } + + if resource.NetworkID != networkID { + return nil, errors.New("resource not part of network") + } + + groups, err := m.groupsManager.GetResourceGroupsInTransaction(ctx, transaction, store.LockingStrengthUpdate, accountID, resourceID) + if err != nil { + return nil, fmt.Errorf("failed to get resource groups: %w", err) + } + + var eventsToStore []func() + + for _, group := range groups { + event, err := m.groupsManager.RemoveResourceFromGroupInTransaction(ctx, transaction, accountID, userID, group.ID, resourceID) + if err != nil { + return nil, fmt.Errorf("failed to remove resource from group: %w", err) + } + eventsToStore = append(eventsToStore, event) + } + + err = transaction.DeleteNetworkResource(ctx, store.LockingStrengthUpdate, accountID, resourceID) + if err != nil { + return nil, fmt.Errorf("failed to delete network resource: %w", err) + } + + eventsToStore = append(eventsToStore, func() { + m.accountManager.StoreEvent(ctx, userID, resourceID, accountID, activity.NetworkResourceDeleted, resource.EventMeta(network)) + }) + + return eventsToStore, nil +} diff --git a/management/server/networks/resources/manager_test.go b/management/server/networks/resources/manager_test.go new file mode 100644 index 000000000..993cd65df --- /dev/null +++ b/management/server/networks/resources/manager_test.go @@ -0,0 +1,411 @@ +package resources + +import ( + "context" + "testing" + + "github.com/stretchr/testify/require" + + "github.com/netbirdio/netbird/management/server/groups" + "github.com/netbirdio/netbird/management/server/mock_server" + "github.com/netbirdio/netbird/management/server/networks/resources/types" + "github.com/netbirdio/netbird/management/server/permissions" + "github.com/netbirdio/netbird/management/server/status" + "github.com/netbirdio/netbird/management/server/store" +) + +func Test_GetAllResourcesInNetworkReturnsResources(t *testing.T) { + ctx := context.Background() + accountID := "testAccountId" + userID := "allowedUser" + networkID := "testNetworkId" + + store, cleanUp, err := store.NewTestStoreFromSQL(context.Background(), "../../testdata/networks.sql", t.TempDir()) + if err != nil { + t.Fatal(err) + } + t.Cleanup(cleanUp) + permissionsManager := permissions.NewManagerMock() + am := mock_server.MockAccountManager{} + groupsManager := groups.NewManagerMock() + manager := NewManager(store, permissionsManager, groupsManager, &am) + + resources, err := manager.GetAllResourcesInNetwork(ctx, accountID, userID, networkID) + require.NoError(t, err) + require.Len(t, resources, 2) +} + +func Test_GetAllResourcesInNetworkReturnsPermissionDenied(t *testing.T) { + ctx := context.Background() + accountID := "testAccountId" + userID := "invalidUser" + networkID := "testNetworkId" + + store, cleanUp, err := store.NewTestStoreFromSQL(context.Background(), "../../testdata/networks.sql", t.TempDir()) + if err != nil { + t.Fatal(err) + } + t.Cleanup(cleanUp) + permissionsManager := permissions.NewManagerMock() + am := mock_server.MockAccountManager{} + groupsManager := groups.NewManagerMock() + manager := NewManager(store, permissionsManager, groupsManager, &am) + + resources, err := manager.GetAllResourcesInNetwork(ctx, accountID, userID, networkID) + require.Error(t, err) + require.Equal(t, 
status.NewPermissionDeniedError(), err) + require.Nil(t, resources) +} +func Test_GetAllResourcesInAccountReturnsResources(t *testing.T) { + ctx := context.Background() + accountID := "testAccountId" + userID := "allowedUser" + + store, cleanUp, err := store.NewTestStoreFromSQL(context.Background(), "../../testdata/networks.sql", t.TempDir()) + if err != nil { + t.Fatal(err) + } + t.Cleanup(cleanUp) + permissionsManager := permissions.NewManagerMock() + am := mock_server.MockAccountManager{} + groupsManager := groups.NewManagerMock() + manager := NewManager(store, permissionsManager, groupsManager, &am) + + resources, err := manager.GetAllResourcesInAccount(ctx, accountID, userID) + require.NoError(t, err) + require.Len(t, resources, 2) +} + +func Test_GetAllResourcesInAccountReturnsPermissionDenied(t *testing.T) { + ctx := context.Background() + accountID := "testAccountId" + userID := "invalidUser" + + store, cleanUp, err := store.NewTestStoreFromSQL(context.Background(), "../../testdata/networks.sql", t.TempDir()) + if err != nil { + t.Fatal(err) + } + t.Cleanup(cleanUp) + permissionsManager := permissions.NewManagerMock() + am := mock_server.MockAccountManager{} + groupsManager := groups.NewManagerMock() + manager := NewManager(store, permissionsManager, groupsManager, &am) + + resources, err := manager.GetAllResourcesInAccount(ctx, accountID, userID) + require.Error(t, err) + require.Equal(t, status.NewPermissionDeniedError(), err) + require.Nil(t, resources) +} + +func Test_GetResourceInNetworkReturnsResources(t *testing.T) { + ctx := context.Background() + accountID := "testAccountId" + userID := "allowedUser" + networkID := "testNetworkId" + resourceID := "testResourceId" + + store, cleanUp, err := store.NewTestStoreFromSQL(context.Background(), "../../testdata/networks.sql", t.TempDir()) + if err != nil { + t.Fatal(err) + } + t.Cleanup(cleanUp) + permissionsManager := permissions.NewManagerMock() + am := mock_server.MockAccountManager{} + groupsManager := groups.NewManagerMock() + manager := NewManager(store, permissionsManager, groupsManager, &am) + + resource, err := manager.GetResource(ctx, accountID, userID, networkID, resourceID) + require.NoError(t, err) + require.Equal(t, resourceID, resource.ID) +} + +func Test_GetResourceInNetworkReturnsPermissionDenied(t *testing.T) { + ctx := context.Background() + accountID := "testAccountId" + userID := "invalidUser" + networkID := "testNetworkId" + resourceID := "testResourceId" + + store, cleanUp, err := store.NewTestStoreFromSQL(context.Background(), "../../testdata/networks.sql", t.TempDir()) + if err != nil { + t.Fatal(err) + } + t.Cleanup(cleanUp) + permissionsManager := permissions.NewManagerMock() + am := mock_server.MockAccountManager{} + groupsManager := groups.NewManagerMock() + manager := NewManager(store, permissionsManager, groupsManager, &am) + + resources, err := manager.GetResource(ctx, accountID, userID, networkID, resourceID) + require.Error(t, err) + require.Equal(t, status.NewPermissionDeniedError(), err) + require.Nil(t, resources) +} + +func Test_CreateResourceSuccessfully(t *testing.T) { + ctx := context.Background() + userID := "allowedUser" + resource := &types.NetworkResource{ + AccountID: "testAccountId", + NetworkID: "testNetworkId", + Name: "newResourceId", + Description: "description", + Address: "192.168.1.1", + } + + store, cleanUp, err := store.NewTestStoreFromSQL(context.Background(), "../../testdata/networks.sql", t.TempDir()) + if err != nil { + t.Fatal(err) + } + t.Cleanup(cleanUp) + 
permissionsManager := permissions.NewManagerMock() + am := mock_server.MockAccountManager{} + groupsManager := groups.NewManagerMock() + manager := NewManager(store, permissionsManager, groupsManager, &am) + + createdResource, err := manager.CreateResource(ctx, userID, resource) + require.NoError(t, err) + require.Equal(t, resource.Name, createdResource.Name) +} + +func Test_CreateResourceFailsWithPermissionDenied(t *testing.T) { + ctx := context.Background() + userID := "invalidUser" + resource := &types.NetworkResource{ + AccountID: "testAccountId", + NetworkID: "testNetworkId", + Name: "testResourceId", + Description: "description", + Address: "192.168.1.1", + } + + store, cleanUp, err := store.NewTestStoreFromSQL(context.Background(), "../../testdata/networks.sql", t.TempDir()) + if err != nil { + t.Fatal(err) + } + t.Cleanup(cleanUp) + permissionsManager := permissions.NewManagerMock() + am := mock_server.MockAccountManager{} + groupsManager := groups.NewManagerMock() + manager := NewManager(store, permissionsManager, groupsManager, &am) + + createdResource, err := manager.CreateResource(ctx, userID, resource) + require.Error(t, err) + require.Equal(t, status.NewPermissionDeniedError(), err) + require.Nil(t, createdResource) +} + +func Test_CreateResourceFailsWithInvalidAddress(t *testing.T) { + ctx := context.Background() + userID := "allowedUser" + resource := &types.NetworkResource{ + AccountID: "testAccountId", + NetworkID: "testNetworkId", + Name: "testResourceId", + Description: "description", + Address: "invalid-address", + } + + store, cleanUp, err := store.NewTestStoreFromSQL(context.Background(), "../../testdata/networks.sql", t.TempDir()) + if err != nil { + t.Fatal(err) + } + t.Cleanup(cleanUp) + permissionsManager := permissions.NewManagerMock() + am := mock_server.MockAccountManager{} + groupsManager := groups.NewManagerMock() + manager := NewManager(store, permissionsManager, groupsManager, &am) + + createdResource, err := manager.CreateResource(ctx, userID, resource) + require.Error(t, err) + require.Nil(t, createdResource) +} + +func Test_CreateResourceFailsWithUsedName(t *testing.T) { + ctx := context.Background() + userID := "allowedUser" + resource := &types.NetworkResource{ + AccountID: "testAccountId", + NetworkID: "testNetworkId", + Name: "testResourceId", + Description: "description", + Address: "invalid-address", + } + + store, cleanUp, err := store.NewTestStoreFromSQL(context.Background(), "../../testdata/networks.sql", t.TempDir()) + if err != nil { + t.Fatal(err) + } + t.Cleanup(cleanUp) + permissionsManager := permissions.NewManagerMock() + am := mock_server.MockAccountManager{} + groupsManager := groups.NewManagerMock() + manager := NewManager(store, permissionsManager, groupsManager, &am) + + createdResource, err := manager.CreateResource(ctx, userID, resource) + require.Error(t, err) + require.Nil(t, createdResource) +} + +func Test_UpdateResourceSuccessfully(t *testing.T) { + ctx := context.Background() + accountID := "testAccountId" + userID := "allowedUser" + networkID := "testNetworkId" + resourceID := "testResourceId" + resource := &types.NetworkResource{ + AccountID: accountID, + NetworkID: networkID, + Name: "someNewName", + ID: resourceID, + Description: "new-description", + Address: "1.2.3.0/24", + } + + store, cleanUp, err := store.NewTestStoreFromSQL(context.Background(), "../../testdata/networks.sql", t.TempDir()) + if err != nil { + t.Fatal(err) + } + t.Cleanup(cleanUp) + permissionsManager := permissions.NewManagerMock() + am := 
mock_server.MockAccountManager{} + groupsManager := groups.NewManagerMock() + manager := NewManager(store, permissionsManager, groupsManager, &am) + + updatedResource, err := manager.UpdateResource(ctx, userID, resource) + require.NoError(t, err) + require.NotNil(t, updatedResource) + require.Equal(t, "new-description", updatedResource.Description) + require.Equal(t, "1.2.3.0/24", updatedResource.Address) + require.Equal(t, types.NetworkResourceType("subnet"), updatedResource.Type) +} + +func Test_UpdateResourceFailsWithResourceNotFound(t *testing.T) { + ctx := context.Background() + accountID := "testAccountId" + userID := "allowedUser" + networkID := "testNetworkId" + resourceID := "otherResourceId" + resource := &types.NetworkResource{ + AccountID: accountID, + NetworkID: networkID, + Name: resourceID, + Description: "new-description", + Address: "1.2.3.0/24", + } + + store, cleanUp, err := store.NewTestStoreFromSQL(context.Background(), "../../testdata/networks.sql", t.TempDir()) + if err != nil { + t.Fatal(err) + } + t.Cleanup(cleanUp) + permissionsManager := permissions.NewManagerMock() + am := mock_server.MockAccountManager{} + groupsManager := groups.NewManagerMock() + manager := NewManager(store, permissionsManager, groupsManager, &am) + + updatedResource, err := manager.UpdateResource(ctx, userID, resource) + require.Error(t, err) + require.Nil(t, updatedResource) +} + +func Test_UpdateResourceFailsWithNameInUse(t *testing.T) { + ctx := context.Background() + accountID := "testAccountId" + userID := "allowedUser" + networkID := "testNetworkId" + resourceID := "testResourceId" + resource := &types.NetworkResource{ + AccountID: accountID, + NetworkID: networkID, + ID: resourceID, + Name: "used-name", + Description: "new-description", + Address: "1.2.3.0/24", + } + + store, cleanUp, err := store.NewTestStoreFromSQL(context.Background(), "../../testdata/networks.sql", t.TempDir()) + if err != nil { + t.Fatal(err) + } + t.Cleanup(cleanUp) + permissionsManager := permissions.NewManagerMock() + am := mock_server.MockAccountManager{} + groupsManager := groups.NewManagerMock() + manager := NewManager(store, permissionsManager, groupsManager, &am) + + updatedResource, err := manager.UpdateResource(ctx, userID, resource) + require.Error(t, err) + require.Nil(t, updatedResource) +} + +func Test_UpdateResourceFailsWithPermissionDenied(t *testing.T) { + ctx := context.Background() + accountID := "testAccountId" + userID := "invalidUser" + networkID := "testNetworkId" + resourceID := "testResourceId" + resource := &types.NetworkResource{ + AccountID: accountID, + NetworkID: networkID, + Name: resourceID, + Description: "new-description", + Address: "1.2.3.0/24", + } + + store, cleanUp, err := store.NewTestStoreFromSQL(context.Background(), "../../testdata/networks.sql", t.TempDir()) + if err != nil { + t.Fatal(err) + } + t.Cleanup(cleanUp) + permissionsManager := permissions.NewManagerMock() + am := mock_server.MockAccountManager{} + groupsManager := groups.NewManagerMock() + manager := NewManager(store, permissionsManager, groupsManager, &am) + + updatedResource, err := manager.UpdateResource(ctx, userID, resource) + require.Error(t, err) + require.Nil(t, updatedResource) +} + +func Test_DeleteResourceSuccessfully(t *testing.T) { + ctx := context.Background() + accountID := "testAccountId" + userID := "allowedUser" + networkID := "testNetworkId" + resourceID := "testResourceId" + + store, cleanUp, err := store.NewTestStoreFromSQL(context.Background(), "../../testdata/networks.sql", 
t.TempDir()) + if err != nil { + t.Fatal(err) + } + t.Cleanup(cleanUp) + permissionsManager := permissions.NewManagerMock() + am := mock_server.MockAccountManager{} + groupsManager := groups.NewManagerMock() + manager := NewManager(store, permissionsManager, groupsManager, &am) + + err = manager.DeleteResource(ctx, accountID, userID, networkID, resourceID) + require.NoError(t, err) +} + +func Test_DeleteResourceFailsWithPermissionDenied(t *testing.T) { + ctx := context.Background() + accountID := "testAccountId" + userID := "invalidUser" + networkID := "testNetworkId" + resourceID := "testResourceId" + + store, cleanUp, err := store.NewTestStoreFromSQL(context.Background(), "../../testdata/networks.sql", t.TempDir()) + if err != nil { + t.Fatal(err) + } + t.Cleanup(cleanUp) + permissionsManager := permissions.NewManagerMock() + am := mock_server.MockAccountManager{} + groupsManager := groups.NewManagerMock() + manager := NewManager(store, permissionsManager, groupsManager, &am) + + err = manager.DeleteResource(ctx, accountID, userID, networkID, resourceID) + require.Error(t, err) +} diff --git a/management/server/networks/resources/types/resource.go b/management/server/networks/resources/types/resource.go new file mode 100644 index 000000000..7eecdce0f --- /dev/null +++ b/management/server/networks/resources/types/resource.go @@ -0,0 +1,169 @@ +package types + +import ( + "errors" + "fmt" + "net/netip" + "regexp" + + "github.com/rs/xid" + + nbDomain "github.com/netbirdio/netbird/management/domain" + routerTypes "github.com/netbirdio/netbird/management/server/networks/routers/types" + networkTypes "github.com/netbirdio/netbird/management/server/networks/types" + nbpeer "github.com/netbirdio/netbird/management/server/peer" + "github.com/netbirdio/netbird/route" + + "github.com/netbirdio/netbird/management/server/http/api" +) + +type NetworkResourceType string + +const ( + host NetworkResourceType = "host" + subnet NetworkResourceType = "subnet" + domain NetworkResourceType = "domain" +) + +func (p NetworkResourceType) String() string { + return string(p) +} + +type NetworkResource struct { + ID string `gorm:"index"` + NetworkID string `gorm:"index"` + AccountID string `gorm:"index"` + Name string + Description string + Type NetworkResourceType + Address string `gorm:"-"` + GroupIDs []string `gorm:"-"` + Domain string + Prefix netip.Prefix `gorm:"serializer:json"` +} + +func NewNetworkResource(accountID, networkID, name, description, address string, groupIDs []string) (*NetworkResource, error) { + resourceType, domain, prefix, err := GetResourceType(address) + if err != nil { + return nil, fmt.Errorf("invalid address: %w", err) + } + + return &NetworkResource{ + ID: xid.New().String(), + AccountID: accountID, + NetworkID: networkID, + Name: name, + Description: description, + Type: resourceType, + Address: address, + Domain: domain, + Prefix: prefix, + GroupIDs: groupIDs, + }, nil +} + +func (n *NetworkResource) ToAPIResponse(groups []api.GroupMinimum) *api.NetworkResource { + addr := n.Prefix.String() + if n.Type == domain { + addr = n.Domain + } + + return &api.NetworkResource{ + Id: n.ID, + Name: n.Name, + Description: &n.Description, + Type: api.NetworkResourceType(n.Type.String()), + Address: addr, + Groups: groups, + } +} + +func (n *NetworkResource) FromAPIRequest(req *api.NetworkResourceRequest) { + n.Name = req.Name + + if req.Description != nil { + n.Description = *req.Description + } + n.Address = req.Address + n.GroupIDs = req.Groups +} + +func (n *NetworkResource) Copy() 
*NetworkResource { + return &NetworkResource{ + ID: n.ID, + AccountID: n.AccountID, + NetworkID: n.NetworkID, + Name: n.Name, + Description: n.Description, + Type: n.Type, + Address: n.Address, + Domain: n.Domain, + Prefix: n.Prefix, + GroupIDs: n.GroupIDs, + } +} + +func (n *NetworkResource) ToRoute(peer *nbpeer.Peer, router *routerTypes.NetworkRouter) *route.Route { + r := &route.Route{ + ID: route.ID(fmt.Sprintf("%s:%s", n.ID, peer.ID)), + AccountID: n.AccountID, + KeepRoute: true, + NetID: route.NetID(n.Name), + Description: n.Description, + Peer: peer.Key, + PeerGroups: nil, + Masquerade: router.Masquerade, + Metric: router.Metric, + Enabled: true, + Groups: nil, + AccessControlGroups: nil, + } + + if n.Type == host || n.Type == subnet { + r.Network = n.Prefix + + r.NetworkType = route.IPv4Network + if n.Prefix.Addr().Is6() { + r.NetworkType = route.IPv6Network + } + } + + if n.Type == domain { + domainList, err := nbDomain.FromStringList([]string{n.Domain}) + if err != nil { + return nil + } + r.Domains = domainList + r.NetworkType = route.DomainNetwork + + // add default placeholder for domain network + r.Network = netip.PrefixFrom(netip.AddrFrom4([4]byte{192, 0, 2, 0}), 32) + } + + return r +} + +func (n *NetworkResource) EventMeta(network *networkTypes.Network) map[string]any { + return map[string]any{"name": n.Name, "type": n.Type, "network_name": network.Name, "network_id": network.ID} +} + +// GetResourceType returns the type of the resource based on the address +func GetResourceType(address string) (NetworkResourceType, string, netip.Prefix, error) { + if prefix, err := netip.ParsePrefix(address); err == nil { + if prefix.Bits() == 32 || prefix.Bits() == 128 { + return host, "", prefix, nil + } + return subnet, "", prefix, nil + } + + if ip, err := netip.ParseAddr(address); err == nil { + return host, "", netip.PrefixFrom(ip, ip.BitLen()), nil + } + + domainRegex := regexp.MustCompile(`^(\*\.)?([a-zA-Z0-9-]+\.)+[a-zA-Z]{2,}$`) + if domainRegex.MatchString(address) { + return domain, address, netip.Prefix{}, nil + } + + return "", "", netip.Prefix{}, errors.New("not a valid host, subnet, or domain") +} diff --git a/management/server/networks/resources/types/resource_test.go b/management/server/networks/resources/types/resource_test.go new file mode 100644 index 000000000..6af384cce --- /dev/null +++ b/management/server/networks/resources/types/resource_test.go @@ -0,0 +1,53 @@ +package types + +import ( + "net/netip" + "testing" +) + +func TestGetResourceType(t *testing.T) { + tests := []struct { + input string + expectedType NetworkResourceType + expectedErr bool + expectedDomain string + expectedPrefix netip.Prefix + }{ + // Valid host IPs + {"1.1.1.1", host, false, "", netip.MustParsePrefix("1.1.1.1/32")}, + {"1.1.1.1/32", host, false, "", netip.MustParsePrefix("1.1.1.1/32")}, + // Valid subnets + {"192.168.1.0/24", subnet, false, "", netip.MustParsePrefix("192.168.1.0/24")}, + {"10.0.0.0/16", subnet, false, "", netip.MustParsePrefix("10.0.0.0/16")}, + // Valid domains + {"example.com", domain, false, "example.com", netip.Prefix{}}, + {"*.example.com", domain, false, "*.example.com", netip.Prefix{}}, + {"sub.example.com", domain, false, "sub.example.com", netip.Prefix{}}, + // Invalid inputs + {"invalid", "", true, "", netip.Prefix{}}, + {"1.1.1.1/abc", "", true, "", netip.Prefix{}}, + {"1234", "", true, "", netip.Prefix{}}, + } + + for _, tt := range tests { + t.Run(tt.input, func(t *testing.T) { + result, domain, prefix, err := GetResourceType(tt.input) + + if result != 
tt.expectedType { + t.Errorf("Expected type %v, got %v", tt.expectedType, result) + } + + if tt.expectedErr && err == nil { + t.Errorf("Expected error, got nil") + } + + if prefix != tt.expectedPrefix { + t.Errorf("Expected address %v, got %v", tt.expectedPrefix, prefix) + } + + if domain != tt.expectedDomain { + t.Errorf("Expected domain %v, got %v", tt.expectedDomain, domain) + } + }) + } +} diff --git a/management/server/networks/routers/manager.go b/management/server/networks/routers/manager.go new file mode 100644 index 000000000..3b32810a2 --- /dev/null +++ b/management/server/networks/routers/manager.go @@ -0,0 +1,289 @@ +package routers + +import ( + "context" + "errors" + "fmt" + + "github.com/rs/xid" + + s "github.com/netbirdio/netbird/management/server" + "github.com/netbirdio/netbird/management/server/activity" + "github.com/netbirdio/netbird/management/server/networks/routers/types" + networkTypes "github.com/netbirdio/netbird/management/server/networks/types" + "github.com/netbirdio/netbird/management/server/permissions" + "github.com/netbirdio/netbird/management/server/status" + "github.com/netbirdio/netbird/management/server/store" +) + +type Manager interface { + GetAllRoutersInNetwork(ctx context.Context, accountID, userID, networkID string) ([]*types.NetworkRouter, error) + GetAllRoutersInAccount(ctx context.Context, accountID, userID string) (map[string][]*types.NetworkRouter, error) + CreateRouter(ctx context.Context, userID string, router *types.NetworkRouter) (*types.NetworkRouter, error) + GetRouter(ctx context.Context, accountID, userID, networkID, routerID string) (*types.NetworkRouter, error) + UpdateRouter(ctx context.Context, userID string, router *types.NetworkRouter) (*types.NetworkRouter, error) + DeleteRouter(ctx context.Context, accountID, userID, networkID, routerID string) error + DeleteRouterInTransaction(ctx context.Context, transaction store.Store, accountID, userID, networkID, routerID string) (func(), error) +} + +type managerImpl struct { + store store.Store + permissionsManager permissions.Manager + accountManager s.AccountManager +} + +type mockManager struct { +} + +func NewManager(store store.Store, permissionsManager permissions.Manager, accountManager s.AccountManager) Manager { + return &managerImpl{ + store: store, + permissionsManager: permissionsManager, + accountManager: accountManager, + } +} + +func (m *managerImpl) GetAllRoutersInNetwork(ctx context.Context, accountID, userID, networkID string) ([]*types.NetworkRouter, error) { + ok, err := m.permissionsManager.ValidateUserPermissions(ctx, accountID, userID, permissions.Networks, permissions.Read) + if err != nil { + return nil, status.NewPermissionValidationError(err) + } + if !ok { + return nil, status.NewPermissionDeniedError() + } + + return m.store.GetNetworkRoutersByNetID(ctx, store.LockingStrengthShare, accountID, networkID) +} + +func (m *managerImpl) GetAllRoutersInAccount(ctx context.Context, accountID, userID string) (map[string][]*types.NetworkRouter, error) { + ok, err := m.permissionsManager.ValidateUserPermissions(ctx, accountID, userID, permissions.Networks, permissions.Read) + if err != nil { + return nil, status.NewPermissionValidationError(err) + } + if !ok { + return nil, status.NewPermissionDeniedError() + } + + routers, err := m.store.GetNetworkRoutersByAccountID(ctx, store.LockingStrengthShare, accountID) + if err != nil { + return nil, fmt.Errorf("failed to get network routers: %w", err) + } + + routersMap := make(map[string][]*types.NetworkRouter) + for _, 
router := range routers { + routersMap[router.NetworkID] = append(routersMap[router.NetworkID], router) + } + + return routersMap, nil +} + +func (m *managerImpl) CreateRouter(ctx context.Context, userID string, router *types.NetworkRouter) (*types.NetworkRouter, error) { + ok, err := m.permissionsManager.ValidateUserPermissions(ctx, router.AccountID, userID, permissions.Networks, permissions.Write) + if err != nil { + return nil, status.NewPermissionValidationError(err) + } + if !ok { + return nil, status.NewPermissionDeniedError() + } + + unlock := m.store.AcquireWriteLockByUID(ctx, router.AccountID) + defer unlock() + + var network *networkTypes.Network + err = m.store.ExecuteInTransaction(ctx, func(transaction store.Store) error { + network, err = transaction.GetNetworkByID(ctx, store.LockingStrengthShare, router.AccountID, router.NetworkID) + if err != nil { + return fmt.Errorf("failed to get network: %w", err) + } + + if network.ID != router.NetworkID { + return status.NewNetworkNotFoundError(router.NetworkID) + } + + router.ID = xid.New().String() + + err = transaction.SaveNetworkRouter(ctx, store.LockingStrengthUpdate, router) + if err != nil { + return fmt.Errorf("failed to create network router: %w", err) + } + + err = transaction.IncrementNetworkSerial(ctx, store.LockingStrengthUpdate, router.AccountID) + if err != nil { + return fmt.Errorf("failed to increment network serial: %w", err) + } + + return nil + }) + if err != nil { + return nil, err + } + + m.accountManager.StoreEvent(ctx, userID, router.ID, router.AccountID, activity.NetworkRouterCreated, router.EventMeta(network)) + + go m.accountManager.UpdateAccountPeers(ctx, router.AccountID) + + return router, nil +} + +func (m *managerImpl) GetRouter(ctx context.Context, accountID, userID, networkID, routerID string) (*types.NetworkRouter, error) { + ok, err := m.permissionsManager.ValidateUserPermissions(ctx, accountID, userID, permissions.Networks, permissions.Read) + if err != nil { + return nil, status.NewPermissionValidationError(err) + } + if !ok { + return nil, status.NewPermissionDeniedError() + } + + router, err := m.store.GetNetworkRouterByID(ctx, store.LockingStrengthShare, accountID, routerID) + if err != nil { + return nil, fmt.Errorf("failed to get network router: %w", err) + } + + if router.NetworkID != networkID { + return nil, errors.New("router not part of network") + } + + return router, nil +} + +func (m *managerImpl) UpdateRouter(ctx context.Context, userID string, router *types.NetworkRouter) (*types.NetworkRouter, error) { + ok, err := m.permissionsManager.ValidateUserPermissions(ctx, router.AccountID, userID, permissions.Networks, permissions.Write) + if err != nil { + return nil, status.NewPermissionValidationError(err) + } + if !ok { + return nil, status.NewPermissionDeniedError() + } + + unlock := m.store.AcquireWriteLockByUID(ctx, router.AccountID) + defer unlock() + + var network *networkTypes.Network + err = m.store.ExecuteInTransaction(ctx, func(transaction store.Store) error { + network, err = transaction.GetNetworkByID(ctx, store.LockingStrengthShare, router.AccountID, router.NetworkID) + if err != nil { + return fmt.Errorf("failed to get network: %w", err) + } + + if network.ID != router.NetworkID { + return status.NewRouterNotPartOfNetworkError(router.ID, router.NetworkID) + } + + err = transaction.SaveNetworkRouter(ctx, store.LockingStrengthUpdate, router) + if err != nil { + return fmt.Errorf("failed to update network router: %w", err) + } + + err = transaction.IncrementNetworkSerial(ctx, 
store.LockingStrengthUpdate, router.AccountID) + if err != nil { + return fmt.Errorf("failed to increment network serial: %w", err) + } + + return nil + }) + if err != nil { + return nil, err + } + + m.accountManager.StoreEvent(ctx, userID, router.ID, router.AccountID, activity.NetworkRouterUpdated, router.EventMeta(network)) + + go m.accountManager.UpdateAccountPeers(ctx, router.AccountID) + + return router, nil +} + +func (m *managerImpl) DeleteRouter(ctx context.Context, accountID, userID, networkID, routerID string) error { + ok, err := m.permissionsManager.ValidateUserPermissions(ctx, accountID, userID, permissions.Networks, permissions.Write) + if err != nil { + return status.NewPermissionValidationError(err) + } + if !ok { + return status.NewPermissionDeniedError() + } + + unlock := m.store.AcquireWriteLockByUID(ctx, accountID) + defer unlock() + + var event func() + err = m.store.ExecuteInTransaction(ctx, func(transaction store.Store) error { + event, err = m.DeleteRouterInTransaction(ctx, transaction, accountID, userID, networkID, routerID) + if err != nil { + return fmt.Errorf("failed to delete network router: %w", err) + } + + err = transaction.IncrementNetworkSerial(ctx, store.LockingStrengthUpdate, accountID) + if err != nil { + return fmt.Errorf("failed to increment network serial: %w", err) + } + + return nil + }) + if err != nil { + return err + } + + event() + + go m.accountManager.UpdateAccountPeers(ctx, accountID) + + return nil +} + +func (m *managerImpl) DeleteRouterInTransaction(ctx context.Context, transaction store.Store, accountID, userID, networkID, routerID string) (func(), error) { + network, err := transaction.GetNetworkByID(ctx, store.LockingStrengthShare, accountID, networkID) + if err != nil { + return nil, fmt.Errorf("failed to get network: %w", err) + } + + router, err := transaction.GetNetworkRouterByID(ctx, store.LockingStrengthUpdate, accountID, routerID) + if err != nil { + return nil, fmt.Errorf("failed to get network router: %w", err) + } + + if router.NetworkID != networkID { + return nil, status.NewRouterNotPartOfNetworkError(routerID, networkID) + } + + err = transaction.DeleteNetworkRouter(ctx, store.LockingStrengthUpdate, accountID, routerID) + if err != nil { + return nil, fmt.Errorf("failed to delete network router: %w", err) + } + + event := func() { + m.accountManager.StoreEvent(ctx, userID, routerID, accountID, activity.NetworkRouterDeleted, router.EventMeta(network)) + } + + return event, nil +} + +func NewManagerMock() Manager { + return &mockManager{} +} + +func (m *mockManager) GetAllRoutersInNetwork(ctx context.Context, accountID, userID, networkID string) ([]*types.NetworkRouter, error) { + return []*types.NetworkRouter{}, nil +} + +func (m *mockManager) GetAllRoutersInAccount(ctx context.Context, accountID, userID string) (map[string][]*types.NetworkRouter, error) { + return map[string][]*types.NetworkRouter{}, nil +} + +func (m *mockManager) CreateRouter(ctx context.Context, userID string, router *types.NetworkRouter) (*types.NetworkRouter, error) { + return router, nil +} + +func (m *mockManager) GetRouter(ctx context.Context, accountID, userID, networkID, routerID string) (*types.NetworkRouter, error) { + return &types.NetworkRouter{}, nil +} + +func (m *mockManager) UpdateRouter(ctx context.Context, userID string, router *types.NetworkRouter) (*types.NetworkRouter, error) { + return router, nil +} + +func (m *mockManager) DeleteRouter(ctx context.Context, accountID, userID, networkID, routerID string) error { + return nil +} + 
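// Illustrative usage sketch (hypothetical, not part of this patch): one way a caller could
// drive the routers Manager interface declared above, assuming a wired-up store.Store,
// permissions.Manager and account manager. Function and argument names ("exampleRouterLifecycle",
// "accountID", "networkID", "peerID", "userID") are placeholders; only methods declared in the
// interface and the types.NewNetworkRouter constructor from this package tree are used.
//
//	func exampleRouterLifecycle(ctx context.Context, m Manager) error {
//		// build a router attached to a single peer, no masquerading, metric 100
//		router, err := types.NewNetworkRouter("accountID", "networkID", "peerID", nil, false, 100)
//		if err != nil {
//			return err
//		}
//		created, err := m.CreateRouter(ctx, "userID", router)
//		if err != nil {
//			return err
//		}
//		// read it back, then remove it again
//		if _, err := m.GetRouter(ctx, created.AccountID, "userID", created.NetworkID, created.ID); err != nil {
//			return err
//		}
//		return m.DeleteRouter(ctx, created.AccountID, "userID", created.NetworkID, created.ID)
//	}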
+func (m *mockManager) DeleteRouterInTransaction(ctx context.Context, transaction store.Store, accountID, userID, networkID, routerID string) (func(), error) { + return func() {}, nil +} diff --git a/management/server/networks/routers/manager_test.go b/management/server/networks/routers/manager_test.go new file mode 100644 index 000000000..e650074cc --- /dev/null +++ b/management/server/networks/routers/manager_test.go @@ -0,0 +1,234 @@ +package routers + +import ( + "context" + "testing" + + "github.com/stretchr/testify/require" + + "github.com/netbirdio/netbird/management/server/mock_server" + "github.com/netbirdio/netbird/management/server/networks/routers/types" + "github.com/netbirdio/netbird/management/server/permissions" + "github.com/netbirdio/netbird/management/server/status" + "github.com/netbirdio/netbird/management/server/store" +) + +func Test_GetAllRoutersInNetworkReturnsRouters(t *testing.T) { + ctx := context.Background() + accountID := "testAccountId" + userID := "allowedUser" + networkID := "testNetworkId" + + s, cleanUp, err := store.NewTestStoreFromSQL(context.Background(), "../../testdata/networks.sql", t.TempDir()) + if err != nil { + t.Fatal(err) + } + t.Cleanup(cleanUp) + permissionsManager := permissions.NewManagerMock() + am := mock_server.MockAccountManager{} + manager := NewManager(s, permissionsManager, &am) + + routers, err := manager.GetAllRoutersInNetwork(ctx, accountID, userID, networkID) + require.NoError(t, err) + require.Len(t, routers, 1) + require.Equal(t, "testRouterId", routers[0].ID) +} + +func Test_GetAllRoutersInNetworkReturnsPermissionDenied(t *testing.T) { + ctx := context.Background() + accountID := "testAccountId" + userID := "invalidUser" + networkID := "testNetworkId" + + s, cleanUp, err := store.NewTestStoreFromSQL(context.Background(), "../../testdata/networks.sql", t.TempDir()) + if err != nil { + t.Fatal(err) + } + t.Cleanup(cleanUp) + permissionsManager := permissions.NewManagerMock() + am := mock_server.MockAccountManager{} + manager := NewManager(s, permissionsManager, &am) + + routers, err := manager.GetAllRoutersInNetwork(ctx, accountID, userID, networkID) + require.Error(t, err) + require.Equal(t, status.NewPermissionDeniedError(), err) + require.Nil(t, routers) +} + +func Test_GetRouterReturnsRouter(t *testing.T) { + ctx := context.Background() + accountID := "testAccountId" + userID := "allowedUser" + networkID := "testNetworkId" + resourceID := "testRouterId" + + s, cleanUp, err := store.NewTestStoreFromSQL(context.Background(), "../../testdata/networks.sql", t.TempDir()) + if err != nil { + t.Fatal(err) + } + t.Cleanup(cleanUp) + permissionsManager := permissions.NewManagerMock() + am := mock_server.MockAccountManager{} + manager := NewManager(s, permissionsManager, &am) + + router, err := manager.GetRouter(ctx, accountID, userID, networkID, resourceID) + require.NoError(t, err) + require.Equal(t, "testRouterId", router.ID) +} + +func Test_GetRouterReturnsPermissionDenied(t *testing.T) { + ctx := context.Background() + accountID := "testAccountId" + userID := "invalidUser" + networkID := "testNetworkId" + resourceID := "testRouterId" + + s, cleanUp, err := store.NewTestStoreFromSQL(context.Background(), "../../testdata/networks.sql", t.TempDir()) + if err != nil { + t.Fatal(err) + } + t.Cleanup(cleanUp) + permissionsManager := permissions.NewManagerMock() + am := mock_server.MockAccountManager{} + manager := NewManager(s, permissionsManager, &am) + + router, err := manager.GetRouter(ctx, accountID, userID, networkID, 
resourceID) + require.Error(t, err) + require.Equal(t, status.NewPermissionDeniedError(), err) + require.Nil(t, router) +} + +func Test_CreateRouterSuccessfully(t *testing.T) { + ctx := context.Background() + userID := "allowedUser" + router, err := types.NewNetworkRouter("testAccountId", "testNetworkId", "testPeerId", []string{}, false, 9999) + if err != nil { + require.NoError(t, err) + } + + s, cleanUp, err := store.NewTestStoreFromSQL(context.Background(), "../../testdata/networks.sql", t.TempDir()) + if err != nil { + t.Fatal(err) + } + t.Cleanup(cleanUp) + permissionsManager := permissions.NewManagerMock() + am := mock_server.MockAccountManager{} + manager := NewManager(s, permissionsManager, &am) + + createdRouter, err := manager.CreateRouter(ctx, userID, router) + require.NoError(t, err) + require.NotEqual(t, "", router.ID) + require.Equal(t, router.NetworkID, createdRouter.NetworkID) + require.Equal(t, router.Peer, createdRouter.Peer) + require.Equal(t, router.Metric, createdRouter.Metric) + require.Equal(t, router.Masquerade, createdRouter.Masquerade) +} + +func Test_CreateRouterFailsWithPermissionDenied(t *testing.T) { + ctx := context.Background() + userID := "invalidUser" + router, err := types.NewNetworkRouter("testAccountId", "testNetworkId", "testPeerId", []string{}, false, 9999) + if err != nil { + require.NoError(t, err) + } + + s, cleanUp, err := store.NewTestStoreFromSQL(context.Background(), "../../testdata/networks.sql", t.TempDir()) + if err != nil { + t.Fatal(err) + } + t.Cleanup(cleanUp) + permissionsManager := permissions.NewManagerMock() + am := mock_server.MockAccountManager{} + manager := NewManager(s, permissionsManager, &am) + + createdRouter, err := manager.CreateRouter(ctx, userID, router) + require.Error(t, err) + require.Equal(t, status.NewPermissionDeniedError(), err) + require.Nil(t, createdRouter) +} + +func Test_DeleteRouterSuccessfully(t *testing.T) { + ctx := context.Background() + accountID := "testAccountId" + userID := "allowedUser" + networkID := "testNetworkId" + routerID := "testRouterId" + + s, cleanUp, err := store.NewTestStoreFromSQL(context.Background(), "../../testdata/networks.sql", t.TempDir()) + if err != nil { + t.Fatal(err) + } + t.Cleanup(cleanUp) + permissionsManager := permissions.NewManagerMock() + am := mock_server.MockAccountManager{} + manager := NewManager(s, permissionsManager, &am) + + err = manager.DeleteRouter(ctx, accountID, userID, networkID, routerID) + require.NoError(t, err) +} + +func Test_DeleteRouterFailsWithPermissionDenied(t *testing.T) { + ctx := context.Background() + accountID := "testAccountId" + userID := "invalidUser" + networkID := "testNetworkId" + routerID := "testRouterId" + + s, cleanUp, err := store.NewTestStoreFromSQL(context.Background(), "../../testdata/networks.sql", t.TempDir()) + if err != nil { + t.Fatal(err) + } + t.Cleanup(cleanUp) + permissionsManager := permissions.NewManagerMock() + am := mock_server.MockAccountManager{} + manager := NewManager(s, permissionsManager, &am) + + err = manager.DeleteRouter(ctx, accountID, userID, networkID, routerID) + require.Error(t, err) + require.Equal(t, status.NewPermissionDeniedError(), err) +} + +func Test_UpdateRouterSuccessfully(t *testing.T) { + ctx := context.Background() + userID := "allowedUser" + router, err := types.NewNetworkRouter("testAccountId", "testNetworkId", "testPeerId", []string{}, false, 1) + if err != nil { + require.NoError(t, err) + } + + s, cleanUp, err := store.NewTestStoreFromSQL(context.Background(), 
"../../testdata/networks.sql", t.TempDir()) + if err != nil { + t.Fatal(err) + } + t.Cleanup(cleanUp) + permissionsManager := permissions.NewManagerMock() + am := mock_server.MockAccountManager{} + manager := NewManager(s, permissionsManager, &am) + + updatedRouter, err := manager.UpdateRouter(ctx, userID, router) + require.NoError(t, err) + require.Equal(t, router.Metric, updatedRouter.Metric) +} + +func Test_UpdateRouterFailsWithPermissionDenied(t *testing.T) { + ctx := context.Background() + userID := "invalidUser" + router, err := types.NewNetworkRouter("testAccountId", "testNetworkId", "testPeerId", []string{}, false, 1) + if err != nil { + require.NoError(t, err) + } + + s, cleanUp, err := store.NewTestStoreFromSQL(context.Background(), "../../testdata/networks.sql", t.TempDir()) + if err != nil { + t.Fatal(err) + } + t.Cleanup(cleanUp) + permissionsManager := permissions.NewManagerMock() + am := mock_server.MockAccountManager{} + manager := NewManager(s, permissionsManager, &am) + + updatedRouter, err := manager.UpdateRouter(ctx, userID, router) + require.Error(t, err) + require.Equal(t, status.NewPermissionDeniedError(), err) + require.Nil(t, updatedRouter) +} diff --git a/management/server/networks/routers/types/router.go b/management/server/networks/routers/types/router.go new file mode 100644 index 000000000..f37ae0861 --- /dev/null +++ b/management/server/networks/routers/types/router.go @@ -0,0 +1,75 @@ +package types + +import ( + "errors" + + "github.com/rs/xid" + + "github.com/netbirdio/netbird/management/server/http/api" + "github.com/netbirdio/netbird/management/server/networks/types" +) + +type NetworkRouter struct { + ID string `gorm:"index"` + NetworkID string `gorm:"index"` + AccountID string `gorm:"index"` + Peer string + PeerGroups []string `gorm:"serializer:json"` + Masquerade bool + Metric int +} + +func NewNetworkRouter(accountID string, networkID string, peer string, peerGroups []string, masquerade bool, metric int) (*NetworkRouter, error) { + if peer != "" && len(peerGroups) > 0 { + return nil, errors.New("peer and peerGroups cannot be set at the same time") + } + + return &NetworkRouter{ + ID: xid.New().String(), + AccountID: accountID, + NetworkID: networkID, + Peer: peer, + PeerGroups: peerGroups, + Masquerade: masquerade, + Metric: metric, + }, nil +} + +func (n *NetworkRouter) ToAPIResponse() *api.NetworkRouter { + return &api.NetworkRouter{ + Id: n.ID, + Peer: &n.Peer, + PeerGroups: &n.PeerGroups, + Masquerade: n.Masquerade, + Metric: n.Metric, + } +} + +func (n *NetworkRouter) FromAPIRequest(req *api.NetworkRouterRequest) { + if req.Peer != nil { + n.Peer = *req.Peer + } + + if req.PeerGroups != nil { + n.PeerGroups = *req.PeerGroups + } + + n.Masquerade = req.Masquerade + n.Metric = req.Metric +} + +func (n *NetworkRouter) Copy() *NetworkRouter { + return &NetworkRouter{ + ID: n.ID, + NetworkID: n.NetworkID, + AccountID: n.AccountID, + Peer: n.Peer, + PeerGroups: n.PeerGroups, + Masquerade: n.Masquerade, + Metric: n.Metric, + } +} + +func (n *NetworkRouter) EventMeta(network *types.Network) map[string]any { + return map[string]any{"network_name": network.Name, "network_id": network.ID, "peer": n.Peer, "peer_groups": n.PeerGroups} +} diff --git a/management/server/networks/routers/types/router_test.go b/management/server/networks/routers/types/router_test.go new file mode 100644 index 000000000..3335f7c89 --- /dev/null +++ b/management/server/networks/routers/types/router_test.go @@ -0,0 +1,100 @@ +package types + +import "testing" + +func 
TestNewNetworkRouter(t *testing.T) {
+	tests := []struct {
+		name          string
+		accountID     string
+		networkID     string
+		peer          string
+		peerGroups    []string
+		masquerade    bool
+		metric        int
+		expectedError bool
+	}{
+		// Valid cases
+		{
+			name:          "Valid with peer only",
+			networkID:     "network-1",
+			accountID:     "account-1",
+			peer:          "peer-1",
+			peerGroups:    nil,
+			masquerade:    true,
+			metric:        100,
+			expectedError: false,
+		},
+		{
+			name:          "Valid with peerGroups only",
+			networkID:     "network-2",
+			accountID:     "account-2",
+			peer:          "",
+			peerGroups:    []string{"group-1", "group-2"},
+			masquerade:    false,
+			metric:        200,
+			expectedError: false,
+		},
+		{
+			name:          "Valid with no peer or peerGroups",
+			networkID:     "network-3",
+			accountID:     "account-3",
+			peer:          "",
+			peerGroups:    nil,
+			masquerade:    true,
+			metric:        300,
+			expectedError: false,
+		},
+
+		// Invalid cases
+		{
+			name:          "Invalid with both peer and peerGroups",
+			networkID:     "network-4",
+			accountID:     "account-4",
+			peer:          "peer-2",
+			peerGroups:    []string{"group-3"},
+			masquerade:    false,
+			metric:        400,
+			expectedError: true,
+		},
+	}
+
+	for _, tt := range tests {
+		t.Run(tt.name, func(t *testing.T) {
+			router, err := NewNetworkRouter(tt.accountID, tt.networkID, tt.peer, tt.peerGroups, tt.masquerade, tt.metric)
+
+			if tt.expectedError && err == nil {
+				t.Fatalf("Expected an error, got nil")
+			}
+
+			if tt.expectedError == false {
+				if router == nil {
+					t.Fatalf("Expected a NetworkRouter object, got nil")
+				}
+
+				if router.AccountID != tt.accountID {
+					t.Errorf("Expected AccountID %s, got %s", tt.accountID, router.AccountID)
+				}
+
+				if router.NetworkID != tt.networkID {
+					t.Errorf("Expected NetworkID %s, got %s", tt.networkID, router.NetworkID)
+				}
+
+				if router.Peer != tt.peer {
+					t.Errorf("Expected Peer %s, got %s", tt.peer, router.Peer)
+				}
+
+				if len(router.PeerGroups) != len(tt.peerGroups) {
+					t.Errorf("Expected PeerGroups %v, got %v", tt.peerGroups, router.PeerGroups)
+				}
+
+				if router.Masquerade != tt.masquerade {
+					t.Errorf("Expected Masquerade %v, got %v", tt.masquerade, router.Masquerade)
+				}
+
+				if router.Metric != tt.metric {
+					t.Errorf("Expected Metric %d, got %d", tt.metric, router.Metric)
+				}
+			}
+		})
+	}
+}
diff --git a/management/server/networks/types/network.go b/management/server/networks/types/network.go
new file mode 100644
index 000000000..a4ba7b821
--- /dev/null
+++ b/management/server/networks/types/network.go
@@ -0,0 +1,56 @@
+package types
+
+import (
+	"github.com/rs/xid"
+
+	"github.com/netbirdio/netbird/management/server/http/api"
+)
+
+type Network struct {
+	ID          string `gorm:"index"`
+	AccountID   string `gorm:"index"`
+	Name        string
+	Description string
+}
+
+func NewNetwork(accountId, name, description string) *Network {
+	return &Network{
+		ID:          xid.New().String(),
+		AccountID:   accountId,
+		Name:        name,
+		Description: description,
+	}
+}
+
+func (n *Network) ToAPIResponse(routerIDs []string, resourceIDs []string, routingPeersCount int, policyIDs []string) *api.Network {
+	return &api.Network{
+		Id:                n.ID,
+		Name:              n.Name,
+		Description:       &n.Description,
+		Routers:           routerIDs,
+		Resources:         resourceIDs,
+		RoutingPeersCount: routingPeersCount,
+		Policies:          policyIDs,
+	}
+}
+
+func (n *Network) FromAPIRequest(req *api.NetworkRequest) {
+	n.Name = req.Name
+	if req.Description != nil {
+		n.Description = *req.Description
+	}
+}
+
+// Copy returns a copy of a network.
+func (n *Network) Copy() *Network { + return &Network{ + ID: n.ID, + AccountID: n.AccountID, + Name: n.Name, + Description: n.Description, + } +} + +func (n *Network) EventMeta() map[string]any { + return map[string]any{"name": n.Name} +} diff --git a/management/server/peer.go b/management/server/peer.go index ba211be96..ad20d279a 100644 --- a/management/server/peer.go +++ b/management/server/peer.go @@ -16,6 +16,8 @@ import ( "github.com/netbirdio/netbird/management/server/idp" "github.com/netbirdio/netbird/management/server/posture" + "github.com/netbirdio/netbird/management/server/store" + "github.com/netbirdio/netbird/management/server/types" "github.com/netbirdio/netbird/management/proto" "github.com/netbirdio/netbird/management/server/activity" @@ -92,7 +94,7 @@ func (am *DefaultAccountManager) GetPeers(ctx context.Context, accountID, userID // fetch all the peers that have access to the user's peers for _, peer := range peers { - aclPeers, _ := account.getPeerConnectionResources(ctx, peer.ID, approvedPeersMap) + aclPeers, _ := account.GetPeerConnectionResources(ctx, peer.ID, approvedPeersMap) for _, p := range aclPeers { peersMap[p.ID] = p } @@ -107,7 +109,7 @@ func (am *DefaultAccountManager) GetPeers(ctx context.Context, accountID, userID } // MarkPeerConnected marks peer as connected (true) or disconnected (false) -func (am *DefaultAccountManager) MarkPeerConnected(ctx context.Context, peerPubKey string, connected bool, realIP net.IP, account *Account) error { +func (am *DefaultAccountManager) MarkPeerConnected(ctx context.Context, peerPubKey string, connected bool, realIP net.IP, account *types.Account) error { peer, err := account.FindPeerByPubKey(peerPubKey) if err != nil { return fmt.Errorf("failed to find peer by pub key: %w", err) @@ -133,13 +135,13 @@ func (am *DefaultAccountManager) MarkPeerConnected(ctx context.Context, peerPubK if expired { // we need to update other peers because when peer login expires all other peers are notified to disconnect from // the expired one. Here we notify them that connection is now allowed again. - am.updateAccountPeers(ctx, account.Id) + am.UpdateAccountPeers(ctx, account.Id) } return nil } -func (am *DefaultAccountManager) updatePeerStatusAndLocation(ctx context.Context, peer *nbpeer.Peer, connected bool, realIP net.IP, account *Account) (bool, error) { +func (am *DefaultAccountManager) updatePeerStatusAndLocation(ctx context.Context, peer *nbpeer.Peer, connected bool, realIP net.IP, account *types.Account) (bool, error) { oldStatus := peer.Status.Copy() newStatus := oldStatus newStatus.LastSeen = time.Now().UTC() @@ -213,9 +215,9 @@ func (am *DefaultAccountManager) UpdatePeer(ctx context.Context, accountID, user if peerLabelUpdated { peer.Name = update.Name - existingLabels := account.getPeerDNSLabels() + existingLabels := account.GetPeerDNSLabels() - newLabel, err := getPeerHostLabel(peer.Name, existingLabels) + newLabel, err := types.GetPeerHostLabel(peer.Name, existingLabels) if err != nil { return nil, err } @@ -271,14 +273,14 @@ func (am *DefaultAccountManager) UpdatePeer(ctx context.Context, accountID, user } if peerLabelUpdated || requiresPeerUpdates { - am.updateAccountPeers(ctx, accountID) + am.UpdateAccountPeers(ctx, accountID) } return peer, nil } // deletePeers will delete all specified peers and send updates to the remote peers. 
Don't call without acquiring account lock -func (am *DefaultAccountManager) deletePeers(ctx context.Context, account *Account, peerIDs []string, userID string) error { +func (am *DefaultAccountManager) deletePeers(ctx context.Context, account *types.Account, peerIDs []string, userID string) error { // the first loop is needed to ensure all peers present under the account before modifying, otherwise // we might have some inconsistencies @@ -316,7 +318,7 @@ func (am *DefaultAccountManager) deletePeers(ctx context.Context, account *Accou FirewallRulesIsEmpty: true, }, }, - NetworkMap: &NetworkMap{}, + NetworkMap: &types.NetworkMap{}, }) am.peersUpdateManager.CloseChannel(ctx, peer.ID) am.StoreEvent(ctx, userID, peer.ID, account.Id, activity.PeerRemovedByUser, peer.EventMeta(am.GetDNSDomain())) @@ -351,14 +353,14 @@ func (am *DefaultAccountManager) DeletePeer(ctx context.Context, accountID, peer } if updateAccountPeers { - am.updateAccountPeers(ctx, accountID) + am.UpdateAccountPeers(ctx, accountID) } return nil } // GetNetworkMap returns Network map for a given peer (omits original peer from the Peers result) -func (am *DefaultAccountManager) GetNetworkMap(ctx context.Context, peerID string) (*NetworkMap, error) { +func (am *DefaultAccountManager) GetNetworkMap(ctx context.Context, peerID string) (*types.NetworkMap, error) { account, err := am.Store.GetAccountByPeerID(ctx, peerID) if err != nil { return nil, err @@ -379,11 +381,11 @@ func (am *DefaultAccountManager) GetNetworkMap(ctx context.Context, peerID strin return nil, err } customZone := account.GetPeersCustomZone(ctx, am.dnsDomain) - return account.GetPeerNetworkMap(ctx, peer.ID, customZone, validatedPeers, nil), nil + return account.GetPeerNetworkMap(ctx, peer.ID, customZone, validatedPeers, account.GetResourcePoliciesMap(), account.GetResourceRoutersMap(), nil), nil } // GetPeerNetwork returns the Network for a given peer -func (am *DefaultAccountManager) GetPeerNetwork(ctx context.Context, peerID string) (*Network, error) { +func (am *DefaultAccountManager) GetPeerNetwork(ctx context.Context, peerID string) (*types.Network, error) { account, err := am.Store.GetAccountByPeerID(ctx, peerID) if err != nil { return nil, err @@ -399,7 +401,7 @@ func (am *DefaultAccountManager) GetPeerNetwork(ctx context.Context, peerID stri // to it. We also add the User ID to the peer metadata to identify registrant. If no userID provided, then fail with status.PermissionDenied // Each new Peer will be assigned a new next net.IP from the Account.Network and Account.Network.LastIP will be updated (IP's are not reused). // The peer property is just a placeholder for the Peer properties to pass further -func (am *DefaultAccountManager) AddPeer(ctx context.Context, setupKey, userID string, peer *nbpeer.Peer) (*nbpeer.Peer, *NetworkMap, []*posture.Checks, error) { +func (am *DefaultAccountManager) AddPeer(ctx context.Context, setupKey, userID string, peer *nbpeer.Peer) (*nbpeer.Peer, *types.NetworkMap, []*posture.Checks, error) { if setupKey == "" && userID == "" { // no auth method provided => reject access return nil, nil, nil, status.Errorf(status.Unauthenticated, "no peer auth method provided, please use a setup key or interactive SSO login") @@ -433,7 +435,7 @@ func (am *DefaultAccountManager) AddPeer(ctx context.Context, setupKey, userID s // and the peer disconnects with a timeout and tries to register again. // We just check if this machine has been registered before and reject the second registration. 
// The connecting peer should be able to recover with a retry. - _, err = am.Store.GetPeerByPeerPubKey(ctx, LockingStrengthShare, peer.Key) + _, err = am.Store.GetPeerByPeerPubKey(ctx, store.LockingStrengthShare, peer.Key) if err == nil { return nil, nil, nil, status.Errorf(status.PreconditionFailed, "peer has been already registered") } @@ -446,12 +448,12 @@ func (am *DefaultAccountManager) AddPeer(ctx context.Context, setupKey, userID s var newPeer *nbpeer.Peer var groupsToAdd []string - err = am.Store.ExecuteInTransaction(ctx, func(transaction Store) error { + err = am.Store.ExecuteInTransaction(ctx, func(transaction store.Store) error { var setupKeyID string var setupKeyName string var ephemeral bool if addedByUser { - user, err := transaction.GetUserByUserID(ctx, LockingStrengthUpdate, userID) + user, err := transaction.GetUserByUserID(ctx, store.LockingStrengthUpdate, userID) if err != nil { return fmt.Errorf("failed to get user groups: %w", err) } @@ -460,7 +462,7 @@ func (am *DefaultAccountManager) AddPeer(ctx context.Context, setupKey, userID s opEvent.Activity = activity.PeerAddedByUser } else { // Validate the setup key - sk, err := transaction.GetSetupKeyBySecret(ctx, LockingStrengthUpdate, encodedHashedKey) + sk, err := transaction.GetSetupKeyBySecret(ctx, store.LockingStrengthUpdate, encodedHashedKey) if err != nil { return fmt.Errorf("failed to get setup key: %w", err) } @@ -533,7 +535,7 @@ func (am *DefaultAccountManager) AddPeer(ctx context.Context, setupKey, userID s } } - settings, err := transaction.GetAccountSettings(ctx, LockingStrengthShare, accountID) + settings, err := transaction.GetAccountSettings(ctx, store.LockingStrengthShare, accountID) if err != nil { return fmt.Errorf("failed to get account settings: %w", err) } @@ -558,7 +560,7 @@ func (am *DefaultAccountManager) AddPeer(ctx context.Context, setupKey, userID s return fmt.Errorf("failed to add peer to account: %w", err) } - err = transaction.IncrementNetworkSerial(ctx, LockingStrengthUpdate, accountID) + err = transaction.IncrementNetworkSerial(ctx, store.LockingStrengthUpdate, accountID) if err != nil { return fmt.Errorf("failed to increment network serial: %w", err) } @@ -609,7 +611,7 @@ func (am *DefaultAccountManager) AddPeer(ctx context.Context, setupKey, userID s } if newGroupsAffectsPeers { - am.updateAccountPeers(ctx, accountID) + am.UpdateAccountPeers(ctx, accountID) } approvedPeersMap, err := am.GetValidatedPeers(account) @@ -623,22 +625,22 @@ func (am *DefaultAccountManager) AddPeer(ctx context.Context, setupKey, userID s } customZone := account.GetPeersCustomZone(ctx, am.dnsDomain) - networkMap := account.GetPeerNetworkMap(ctx, newPeer.ID, customZone, approvedPeersMap, am.metrics.AccountManagerMetrics()) + networkMap := account.GetPeerNetworkMap(ctx, newPeer.ID, customZone, approvedPeersMap, account.GetResourcePoliciesMap(), account.GetResourceRoutersMap(), am.metrics.AccountManagerMetrics()) return newPeer, networkMap, postureChecks, nil } -func (am *DefaultAccountManager) getFreeIP(ctx context.Context, store Store, accountID string) (net.IP, error) { - takenIps, err := store.GetTakenIPs(ctx, LockingStrengthShare, accountID) +func (am *DefaultAccountManager) getFreeIP(ctx context.Context, s store.Store, accountID string) (net.IP, error) { + takenIps, err := s.GetTakenIPs(ctx, store.LockingStrengthUpdate, accountID) if err != nil { return nil, fmt.Errorf("failed to get taken IPs: %w", err) } - network, err := store.GetAccountNetwork(ctx, LockingStrengthUpdate, accountID) + network, err := 
s.GetAccountNetwork(ctx, store.LockingStrengthUpdate, accountID) if err != nil { return nil, fmt.Errorf("failed getting network: %w", err) } - nextIp, err := AllocatePeerIP(network.Net, takenIps) + nextIp, err := types.AllocatePeerIP(network.Net, takenIps) if err != nil { return nil, fmt.Errorf("failed to allocate new peer ip: %w", err) } @@ -647,7 +649,7 @@ func (am *DefaultAccountManager) getFreeIP(ctx context.Context, store Store, acc } // SyncPeer checks whether peer is eligible for receiving NetworkMap (authenticated) and returns its NetworkMap if eligible -func (am *DefaultAccountManager) SyncPeer(ctx context.Context, sync PeerSync, account *Account) (*nbpeer.Peer, *NetworkMap, []*posture.Checks, error) { +func (am *DefaultAccountManager) SyncPeer(ctx context.Context, sync PeerSync, account *types.Account) (*nbpeer.Peer, *types.NetworkMap, []*posture.Checks, error) { peer, err := account.FindPeerByPubKey(sync.WireGuardPubKey) if err != nil { return nil, nil, nil, status.NewPeerNotRegisteredError() @@ -691,11 +693,11 @@ func (am *DefaultAccountManager) SyncPeer(ctx context.Context, sync PeerSync, ac } if isStatusChanged || sync.UpdateAccountPeers || (updated && len(postureChecks) > 0) { - am.updateAccountPeers(ctx, account.Id) + am.UpdateAccountPeers(ctx, account.Id) } if peerNotValid { - emptyMap := &NetworkMap{ + emptyMap := &types.NetworkMap{ Network: account.Network.Copy(), } return peer, emptyMap, []*posture.Checks{}, nil @@ -707,10 +709,10 @@ func (am *DefaultAccountManager) SyncPeer(ctx context.Context, sync PeerSync, ac } customZone := account.GetPeersCustomZone(ctx, am.dnsDomain) - return peer, account.GetPeerNetworkMap(ctx, peer.ID, customZone, validPeersMap, am.metrics.AccountManagerMetrics()), postureChecks, nil + return peer, account.GetPeerNetworkMap(ctx, peer.ID, customZone, validPeersMap, account.GetResourcePoliciesMap(), account.GetResourceRoutersMap(), am.metrics.AccountManagerMetrics()), postureChecks, nil } -func (am *DefaultAccountManager) handlePeerLoginNotFound(ctx context.Context, login PeerLogin, err error) (*nbpeer.Peer, *NetworkMap, []*posture.Checks, error) { +func (am *DefaultAccountManager) handlePeerLoginNotFound(ctx context.Context, login PeerLogin, err error) (*nbpeer.Peer, *types.NetworkMap, []*posture.Checks, error) { if errStatus, ok := status.FromError(err); ok && errStatus.Type() == status.NotFound { // we couldn't find this peer by its public key which can mean that peer hasn't been registered yet. // Try registering it. @@ -730,7 +732,7 @@ func (am *DefaultAccountManager) handlePeerLoginNotFound(ctx context.Context, lo // LoginPeer logs in or registers a peer. // If peer doesn't exist the function checks whether a setup key or a user is present and registers a new peer if so. 
-func (am *DefaultAccountManager) LoginPeer(ctx context.Context, login PeerLogin) (*nbpeer.Peer, *NetworkMap, []*posture.Checks, error) { +func (am *DefaultAccountManager) LoginPeer(ctx context.Context, login PeerLogin) (*nbpeer.Peer, *types.NetworkMap, []*posture.Checks, error) { accountID, err := am.Store.GetAccountIDByPeerPubKey(ctx, login.WireGuardPubKey) if err != nil { return am.handlePeerLoginNotFound(ctx, login, err) @@ -755,12 +757,12 @@ func (am *DefaultAccountManager) LoginPeer(ctx context.Context, login PeerLogin) } }() - peer, err := am.Store.GetPeerByPeerPubKey(ctx, LockingStrengthUpdate, login.WireGuardPubKey) + peer, err := am.Store.GetPeerByPeerPubKey(ctx, store.LockingStrengthUpdate, login.WireGuardPubKey) if err != nil { return nil, nil, nil, err } - settings, err := am.Store.GetAccountSettings(ctx, LockingStrengthShare, accountID) + settings, err := am.Store.GetAccountSettings(ctx, store.LockingStrengthShare, accountID) if err != nil { return nil, nil, nil, err } @@ -785,7 +787,7 @@ func (am *DefaultAccountManager) LoginPeer(ctx context.Context, login PeerLogin) } } - groups, err := am.Store.GetAccountGroups(ctx, LockingStrengthShare, accountID) + groups, err := am.Store.GetAccountGroups(ctx, store.LockingStrengthShare, accountID) if err != nil { return nil, nil, nil, err } @@ -837,7 +839,7 @@ func (am *DefaultAccountManager) LoginPeer(ctx context.Context, login PeerLogin) } if updateRemotePeers || isStatusChanged || (updated && len(postureChecks) > 0) { - am.updateAccountPeers(ctx, accountID) + am.UpdateAccountPeers(ctx, accountID) } return am.getValidatedPeerWithMap(ctx, isRequiresApproval, account, peer) @@ -849,7 +851,7 @@ func (am *DefaultAccountManager) LoginPeer(ctx context.Context, login PeerLogin) // with no JWT token and usually no setup-key. As the client can send up to two login request to check if it is expired // and before starting the engine, we do the checks without an account lock to avoid piling up requests. 
func (am *DefaultAccountManager) checkIFPeerNeedsLoginWithoutLock(ctx context.Context, accountID string, login PeerLogin) error { - peer, err := am.Store.GetPeerByPeerPubKey(ctx, LockingStrengthShare, login.WireGuardPubKey) + peer, err := am.Store.GetPeerByPeerPubKey(ctx, store.LockingStrengthShare, login.WireGuardPubKey) if err != nil { return err } @@ -860,7 +862,7 @@ func (am *DefaultAccountManager) checkIFPeerNeedsLoginWithoutLock(ctx context.Co return nil } - settings, err := am.Store.GetAccountSettings(ctx, LockingStrengthShare, accountID) + settings, err := am.Store.GetAccountSettings(ctx, store.LockingStrengthShare, accountID) if err != nil { return err } @@ -872,11 +874,11 @@ func (am *DefaultAccountManager) checkIFPeerNeedsLoginWithoutLock(ctx context.Co return nil } -func (am *DefaultAccountManager) getValidatedPeerWithMap(ctx context.Context, isRequiresApproval bool, account *Account, peer *nbpeer.Peer) (*nbpeer.Peer, *NetworkMap, []*posture.Checks, error) { +func (am *DefaultAccountManager) getValidatedPeerWithMap(ctx context.Context, isRequiresApproval bool, account *types.Account, peer *nbpeer.Peer) (*nbpeer.Peer, *types.NetworkMap, []*posture.Checks, error) { var postureChecks []*posture.Checks if isRequiresApproval { - emptyMap := &NetworkMap{ + emptyMap := &types.NetworkMap{ Network: account.Network.Copy(), } return peer, emptyMap, nil, nil @@ -893,10 +895,10 @@ func (am *DefaultAccountManager) getValidatedPeerWithMap(ctx context.Context, is } customZone := account.GetPeersCustomZone(ctx, am.dnsDomain) - return peer, account.GetPeerNetworkMap(ctx, peer.ID, customZone, approvedPeersMap, am.metrics.AccountManagerMetrics()), postureChecks, nil + return peer, account.GetPeerNetworkMap(ctx, peer.ID, customZone, approvedPeersMap, account.GetResourcePoliciesMap(), account.GetResourceRoutersMap(), am.metrics.AccountManagerMetrics()), postureChecks, nil } -func (am *DefaultAccountManager) handleExpiredPeer(ctx context.Context, user *User, peer *nbpeer.Peer) error { +func (am *DefaultAccountManager) handleExpiredPeer(ctx context.Context, user *types.User, peer *nbpeer.Peer) error { err := checkAuth(ctx, user.Id, peer) if err != nil { return err @@ -918,7 +920,7 @@ func (am *DefaultAccountManager) handleExpiredPeer(ctx context.Context, user *Us return nil } -func checkIfPeerOwnerIsBlocked(peer *nbpeer.Peer, user *User) error { +func checkIfPeerOwnerIsBlocked(peer *nbpeer.Peer, user *types.User) error { if peer.AddedWithSSOLogin() { if user.IsBlocked() { return status.Errorf(status.PermissionDenied, "user is blocked") @@ -939,7 +941,7 @@ func checkAuth(ctx context.Context, loginUserID string, peer *nbpeer.Peer) error return nil } -func peerLoginExpired(ctx context.Context, peer *nbpeer.Peer, settings *Settings) bool { +func peerLoginExpired(ctx context.Context, peer *nbpeer.Peer, settings *types.Settings) bool { expired, expiresIn := peer.LoginExpired(settings.PeerLoginExpiration) expired = settings.PeerLoginExpirationEnabled && expired if expired || peer.Status.LoginExpired { @@ -991,7 +993,7 @@ func (am *DefaultAccountManager) GetPeer(ctx context.Context, accountID, peerID, } for _, p := range userPeers { - aclPeers, _ := account.getPeerConnectionResources(ctx, p.ID, approvedPeersMap) + aclPeers, _ := account.GetPeerConnectionResources(ctx, p.ID, approvedPeersMap) for _, aclPeer := range aclPeers { if aclPeer.ID == peerID { return peer, nil @@ -1002,9 +1004,9 @@ func (am *DefaultAccountManager) GetPeer(ctx context.Context, accountID, peerID, return nil, 
status.Errorf(status.Internal, "user %s has no access to peer %s under account %s", userID, peerID, accountID) } -// updateAccountPeers updates all peers that belong to an account. +// UpdateAccountPeers updates all peers that belong to an account. // Should be called when changes have to be synced to peers. -func (am *DefaultAccountManager) updateAccountPeers(ctx context.Context, accountID string) { +func (am *DefaultAccountManager) UpdateAccountPeers(ctx context.Context, accountID string) { account, err := am.requestBuffer.GetAccountWithBackpressure(ctx, accountID) if err != nil { log.WithContext(ctx).Errorf("failed to send out updates to peers: %v", err) @@ -1031,6 +1033,8 @@ func (am *DefaultAccountManager) updateAccountPeers(ctx context.Context, account dnsCache := &DNSConfigCache{} customZone := account.GetPeersCustomZone(ctx, am.dnsDomain) + resourcePolicies := account.GetResourcePoliciesMap() + routers := account.GetResourceRoutersMap() for _, peer := range peers { if !am.peersUpdateManager.HasChannel(peer.ID) { @@ -1050,8 +1054,8 @@ func (am *DefaultAccountManager) updateAccountPeers(ctx context.Context, account return } - remotePeerNetworkMap := account.GetPeerNetworkMap(ctx, p.ID, customZone, approvedPeersMap, am.metrics.AccountManagerMetrics()) - update := toSyncResponse(ctx, nil, p, nil, nil, remotePeerNetworkMap, am.GetDNSDomain(), postureChecks, dnsCache) + remotePeerNetworkMap := account.GetPeerNetworkMap(ctx, p.ID, customZone, approvedPeersMap, resourcePolicies, routers, am.metrics.AccountManagerMetrics()) + update := toSyncResponse(ctx, nil, p, nil, nil, remotePeerNetworkMap, am.GetDNSDomain(), postureChecks, dnsCache, account.Settings.RoutingPeerDNSResolutionEnabled) am.peersUpdateManager.SendUpdate(ctx, p.ID, &UpdateMessage{Update: update, NetworkMap: remotePeerNetworkMap}) }(peer) } @@ -1069,7 +1073,7 @@ func ConvertSliceToMap(existingLabels []string) map[string]struct{} { // IsPeerInActiveGroup checks if the given peer is part of a group that is used // in an active DNS, route, or ACL configuration. 
-func (am *DefaultAccountManager) isPeerInActiveGroup(ctx context.Context, account *Account, peerID string) (bool, error) { +func (am *DefaultAccountManager) isPeerInActiveGroup(ctx context.Context, account *types.Account, peerID string) (bool, error) { peerGroupIDs := make([]string, 0) for _, group := range account.Groups { if slices.Contains(group.Peers, peerID) { diff --git a/management/server/peer_test.go b/management/server/peer_test.go index b15315f98..2ab262ff0 100644 --- a/management/server/peer_test.go +++ b/management/server/peer_test.go @@ -19,15 +19,20 @@ import ( "github.com/stretchr/testify/require" "golang.zx2c4.com/wireguard/wgctrl/wgtypes" + resourceTypes "github.com/netbirdio/netbird/management/server/networks/resources/types" + routerTypes "github.com/netbirdio/netbird/management/server/networks/routers/types" + networkTypes "github.com/netbirdio/netbird/management/server/networks/types" + nbdns "github.com/netbirdio/netbird/dns" "github.com/netbirdio/netbird/management/domain" "github.com/netbirdio/netbird/management/proto" nbAccount "github.com/netbirdio/netbird/management/server/account" "github.com/netbirdio/netbird/management/server/activity" - nbgroup "github.com/netbirdio/netbird/management/server/group" nbpeer "github.com/netbirdio/netbird/management/server/peer" "github.com/netbirdio/netbird/management/server/posture" + "github.com/netbirdio/netbird/management/server/store" "github.com/netbirdio/netbird/management/server/telemetry" + "github.com/netbirdio/netbird/management/server/types" nbroute "github.com/netbirdio/netbird/route" ) @@ -37,13 +42,13 @@ func TestPeer_LoginExpired(t *testing.T) { expirationEnabled bool lastLogin time.Time expected bool - accountSettings *Settings + accountSettings *types.Settings }{ { name: "Peer Login Expiration Disabled. Peer Login Should Not Expire", expirationEnabled: false, lastLogin: time.Now().UTC().Add(-25 * time.Hour), - accountSettings: &Settings{ + accountSettings: &types.Settings{ PeerLoginExpirationEnabled: true, PeerLoginExpiration: time.Hour, }, @@ -53,7 +58,7 @@ func TestPeer_LoginExpired(t *testing.T) { name: "Peer Login Should Expire", expirationEnabled: true, lastLogin: time.Now().UTC().Add(-25 * time.Hour), - accountSettings: &Settings{ + accountSettings: &types.Settings{ PeerLoginExpirationEnabled: true, PeerLoginExpiration: time.Hour, }, @@ -63,7 +68,7 @@ func TestPeer_LoginExpired(t *testing.T) { name: "Peer Login Should Not Expire", expirationEnabled: true, lastLogin: time.Now().UTC(), - accountSettings: &Settings{ + accountSettings: &types.Settings{ PeerLoginExpirationEnabled: true, PeerLoginExpiration: time.Hour, }, @@ -92,14 +97,14 @@ func TestPeer_SessionExpired(t *testing.T) { lastLogin time.Time connected bool expected bool - accountSettings *Settings + accountSettings *types.Settings }{ { name: "Peer Inactivity Expiration Disabled. 
Peer Inactivity Should Not Expire", expirationEnabled: false, connected: false, lastLogin: time.Now().UTC().Add(-1 * time.Second), - accountSettings: &Settings{ + accountSettings: &types.Settings{ PeerInactivityExpirationEnabled: true, PeerInactivityExpiration: time.Hour, }, @@ -110,7 +115,7 @@ func TestPeer_SessionExpired(t *testing.T) { expirationEnabled: true, connected: false, lastLogin: time.Now().UTC().Add(-1 * time.Second), - accountSettings: &Settings{ + accountSettings: &types.Settings{ PeerInactivityExpirationEnabled: true, PeerInactivityExpiration: time.Second, }, @@ -121,7 +126,7 @@ func TestPeer_SessionExpired(t *testing.T) { expirationEnabled: true, connected: true, lastLogin: time.Now().UTC(), - accountSettings: &Settings{ + accountSettings: &types.Settings{ PeerInactivityExpirationEnabled: true, PeerInactivityExpiration: time.Second, }, @@ -161,7 +166,7 @@ func TestAccountManager_GetNetworkMap(t *testing.T) { t.Fatal(err) } - setupKey, err := manager.CreateSetupKey(context.Background(), account.Id, "test-key", SetupKeyReusable, time.Hour, nil, 999, userId, false) + setupKey, err := manager.CreateSetupKey(context.Background(), account.Id, "test-key", types.SetupKeyReusable, time.Hour, nil, 999, userId, false) if err != nil { t.Fatal("error creating setup key") return @@ -233,9 +238,9 @@ func TestAccountManager_GetNetworkMapWithPolicy(t *testing.T) { t.Fatal(err) } - var setupKey *SetupKey + var setupKey *types.SetupKey for _, key := range account.SetupKeys { - if key.Type == SetupKeyReusable { + if key.Type == types.SetupKeyReusable { setupKey = key } } @@ -281,8 +286,8 @@ func TestAccountManager_GetNetworkMapWithPolicy(t *testing.T) { return } var ( - group1 nbgroup.Group - group2 nbgroup.Group + group1 types.Group + group2 types.Group ) group1.ID = xid.New().String() @@ -303,16 +308,16 @@ func TestAccountManager_GetNetworkMapWithPolicy(t *testing.T) { return } - policy := &Policy{ + policy := &types.Policy{ Name: "test", Enabled: true, - Rules: []*PolicyRule{ + Rules: []*types.PolicyRule{ { Enabled: true, Sources: []string{group1.ID}, Destinations: []string{group2.ID}, Bidirectional: true, - Action: PolicyTrafficActionAccept, + Action: types.PolicyTrafficActionAccept, }, }, } @@ -410,7 +415,7 @@ func TestAccountManager_GetPeerNetwork(t *testing.T) { t.Fatal(err) } - setupKey, err := manager.CreateSetupKey(context.Background(), account.Id, "test-key", SetupKeyReusable, time.Hour, nil, 999, userId, false) + setupKey, err := manager.CreateSetupKey(context.Background(), account.Id, "test-key", types.SetupKeyReusable, time.Hour, nil, 999, userId, false) if err != nil { t.Fatal("error creating setup key") return @@ -469,9 +474,9 @@ func TestDefaultAccountManager_GetPeer(t *testing.T) { adminUser := "account_creator" someUser := "some_user" account := newAccountWithId(context.Background(), accountID, adminUser, "") - account.Users[someUser] = &User{ + account.Users[someUser] = &types.User{ Id: someUser, - Role: UserRoleUser, + Role: types.UserRoleUser, } account.Settings.RegularUsersViewBlocked = false @@ -482,7 +487,7 @@ func TestDefaultAccountManager_GetPeer(t *testing.T) { } // two peers one added by a regular user and one with a setup key - setupKey, err := manager.CreateSetupKey(context.Background(), account.Id, "test-key", SetupKeyReusable, time.Hour, nil, 999, adminUser, false) + setupKey, err := manager.CreateSetupKey(context.Background(), account.Id, "test-key", types.SetupKeyReusable, time.Hour, nil, 999, adminUser, false) if err != nil { t.Fatal("error creating setup 
key") return @@ -567,77 +572,77 @@ func TestDefaultAccountManager_GetPeer(t *testing.T) { func TestDefaultAccountManager_GetPeers(t *testing.T) { testCases := []struct { name string - role UserRole + role types.UserRole limitedViewSettings bool isServiceUser bool expectedPeerCount int }{ { name: "Regular user, no limited view settings, not a service user", - role: UserRoleUser, + role: types.UserRoleUser, limitedViewSettings: false, isServiceUser: false, expectedPeerCount: 1, }, { name: "Service user, no limited view settings", - role: UserRoleUser, + role: types.UserRoleUser, limitedViewSettings: false, isServiceUser: true, expectedPeerCount: 2, }, { name: "Regular user, limited view settings", - role: UserRoleUser, + role: types.UserRoleUser, limitedViewSettings: true, isServiceUser: false, expectedPeerCount: 0, }, { name: "Service user, limited view settings", - role: UserRoleUser, + role: types.UserRoleUser, limitedViewSettings: true, isServiceUser: true, expectedPeerCount: 2, }, { name: "Admin, no limited view settings, not a service user", - role: UserRoleAdmin, + role: types.UserRoleAdmin, limitedViewSettings: false, isServiceUser: false, expectedPeerCount: 2, }, { name: "Admin service user, no limited view settings", - role: UserRoleAdmin, + role: types.UserRoleAdmin, limitedViewSettings: false, isServiceUser: true, expectedPeerCount: 2, }, { name: "Admin, limited view settings", - role: UserRoleAdmin, + role: types.UserRoleAdmin, limitedViewSettings: true, isServiceUser: false, expectedPeerCount: 2, }, { name: "Admin Service user, limited view settings", - role: UserRoleAdmin, + role: types.UserRoleAdmin, limitedViewSettings: true, isServiceUser: true, expectedPeerCount: 2, }, { name: "Owner, no limited view settings", - role: UserRoleOwner, + role: types.UserRoleOwner, limitedViewSettings: true, isServiceUser: false, expectedPeerCount: 2, }, { name: "Owner, limited view settings", - role: UserRoleOwner, + role: types.UserRoleOwner, limitedViewSettings: true, isServiceUser: false, expectedPeerCount: 2, @@ -656,12 +661,12 @@ func TestDefaultAccountManager_GetPeers(t *testing.T) { adminUser := "account_creator" someUser := "some_user" account := newAccountWithId(context.Background(), accountID, adminUser, "") - account.Users[someUser] = &User{ + account.Users[someUser] = &types.User{ Id: someUser, Role: testCase.role, IsServiceUser: testCase.isServiceUser, } - account.Policies = []*Policy{} + account.Policies = []*types.Policy{} account.Settings.RegularUsersViewBlocked = testCase.limitedViewSettings err = manager.Store.SaveAccount(context.Background(), account) @@ -726,9 +731,9 @@ func setupTestAccountManager(b *testing.B, peers int, groups int) (*DefaultAccou regularUser := "regular_user" account := newAccountWithId(context.Background(), accountID, adminUser, "") - account.Users[regularUser] = &User{ + account.Users[regularUser] = &types.User{ Id: regularUser, - Role: UserRoleUser, + Role: types.UserRoleUser, } // Create peers @@ -746,10 +751,10 @@ func setupTestAccountManager(b *testing.B, peers int, groups int) (*DefaultAccou } // Create groups and policies - account.Policies = make([]*Policy, 0, groups) + account.Policies = make([]*types.Policy, 0, groups) for i := 0; i < groups; i++ { groupID := fmt.Sprintf("group-%d", i) - group := &nbgroup.Group{ + group := &types.Group{ ID: groupID, Name: fmt.Sprintf("Group %d", i), } @@ -757,14 +762,95 @@ func setupTestAccountManager(b *testing.B, peers int, groups int) (*DefaultAccou peerIndex := i*(peers/groups) + j group.Peers = 
append(group.Peers, fmt.Sprintf("peer-%d", peerIndex)) } + + // Create network, router and resource for this group + network := &networkTypes.Network{ + ID: fmt.Sprintf("network-%d", i), + AccountID: account.Id, + Name: fmt.Sprintf("Network for Group %d", i), + } + account.Networks = append(account.Networks, network) + + ips := account.GetTakenIPs() + peerIP, err := types.AllocatePeerIP(account.Network.Net, ips) + if err != nil { + return nil, "", "", err + } + + peerKey, _ := wgtypes.GeneratePrivateKey() + peer := &nbpeer.Peer{ + ID: fmt.Sprintf("peer-nr-%d", len(account.Peers)+1), + DNSLabel: fmt.Sprintf("peer-nr-%d", len(account.Peers)+1), + Key: peerKey.PublicKey().String(), + IP: peerIP, + Status: &nbpeer.PeerStatus{}, + UserID: regularUser, + Meta: nbpeer.PeerSystemMeta{ + Hostname: fmt.Sprintf("peer-nr-%d", len(account.Peers)+1), + GoOS: "linux", + Kernel: "Linux", + Core: "21.04", + Platform: "x86_64", + OS: "Ubuntu", + WtVersion: "development", + UIVersion: "development", + }, + } + account.Peers[peer.ID] = peer + + group.Peers = append(group.Peers, peer.ID) account.Groups[groupID] = group + router := &routerTypes.NetworkRouter{ + ID: fmt.Sprintf("network-router-%d", i), + NetworkID: network.ID, + AccountID: account.Id, + Peer: peer.ID, + PeerGroups: []string{}, + Masquerade: false, + Metric: 9999, + } + account.NetworkRouters = append(account.NetworkRouters, router) + + resource := &resourceTypes.NetworkResource{ + ID: fmt.Sprintf("network-resource-%d", i), + NetworkID: network.ID, + AccountID: account.Id, + Name: fmt.Sprintf("Network resource for Group %d", i), + Type: "host", + Address: "192.0.2.0/32", + } + account.NetworkResources = append(account.NetworkResources, resource) + + // Create a policy for this network resource + nrPolicy := &types.Policy{ + ID: fmt.Sprintf("policy-nr-%d", i), + Name: fmt.Sprintf("Policy for network resource %d", i), + Enabled: true, + Rules: []*types.PolicyRule{ + { + ID: fmt.Sprintf("rule-nr-%d", i), + Name: fmt.Sprintf("Rule for network resource %d", i), + Enabled: true, + Sources: []string{groupID}, + Destinations: []string{}, + DestinationResource: types.Resource{ + ID: resource.ID, + }, + Bidirectional: true, + Protocol: types.PolicyRuleProtocolALL, + Action: types.PolicyTrafficActionAccept, + }, + }, + } + account.Policies = append(account.Policies, nrPolicy) + // Create a policy for this group - policy := &Policy{ + policy := &types.Policy{ ID: fmt.Sprintf("policy-%d", i), Name: fmt.Sprintf("Policy for Group %d", i), Enabled: true, - Rules: []*PolicyRule{ + Rules: []*types.PolicyRule{ { ID: fmt.Sprintf("rule-%d", i), Name: fmt.Sprintf("Rule for Group %d", i), @@ -772,8 +858,8 @@ func setupTestAccountManager(b *testing.B, peers int, groups int) (*DefaultAccou Sources: []string{groupID}, Destinations: []string{groupID}, Bidirectional: true, - Protocol: PolicyRuleProtocolALL, - Action: PolicyTrafficActionAccept, + Protocol: types.PolicyRuleProtocolALL, + Action: types.PolicyTrafficActionAccept, }, }, } @@ -845,11 +931,12 @@ func BenchmarkUpdateAccountPeers(b *testing.B) { maxMsPerOpCICD float64 }{ {"Small", 50, 5, 90, 120, 90, 120}, - {"Medium", 500, 100, 110, 140, 120, 200}, - {"Large", 5000, 200, 800, 1300, 2500, 3600}, + {"Medium", 500, 100, 110, 150, 120, 260}, + {"Large", 5000, 200, 800, 1390, 2500, 4600}, {"Small single", 50, 10, 90, 120, 90, 120}, {"Medium single", 500, 10, 110, 170, 120, 200}, - {"Large 5", 5000, 15, 1300, 1800, 5000, 6000}, + {"Large 5", 5000, 15, 1300, 2100, 5000, 7000}, + {"Extra Large", 2000, 2000, 1300, 2100, 
4000, 6000}, } log.SetOutput(io.Discard) @@ -881,7 +968,7 @@ func BenchmarkUpdateAccountPeers(b *testing.B) { start := time.Now() for i := 0; i < b.N; i++ { - manager.updateAccountPeers(ctx, account.Id) + manager.UpdateAccountPeers(ctx, account.Id) } duration := time.Since(start) @@ -899,7 +986,7 @@ func BenchmarkUpdateAccountPeers(b *testing.B) { b.Fatalf("Benchmark %s failed: too fast (%.2f ms/op, minimum %.2f ms/op)", bc.name, msPerOp, minExpected) } - if msPerOp > maxExpected { + if msPerOp > (maxExpected * 1.1) { b.Fatalf("Benchmark %s failed: too slow (%.2f ms/op, maximum %.2f ms/op)", bc.name, msPerOp, maxExpected) } }) @@ -939,8 +1026,8 @@ func TestToSyncResponse(t *testing.T) { Payload: "turn-user", Signature: "turn-pass", } - networkMap := &NetworkMap{ - Network: &Network{Net: *ipnet, Serial: 1000}, + networkMap := &types.NetworkMap{ + Network: &types.Network{Net: *ipnet, Serial: 1000}, Peers: []*nbpeer.Peer{{IP: net.ParseIP("192.168.1.2"), Key: "peer2-key", DNSLabel: "peer2", SSHEnabled: true, SSHKey: "peer2-ssh-key"}}, OfflinePeers: []*nbpeer.Peer{{IP: net.ParseIP("192.168.1.3"), Key: "peer3-key", DNSLabel: "peer3", SSHEnabled: true, SSHKey: "peer3-ssh-key"}}, Routes: []*nbroute.Route{ @@ -987,8 +1074,8 @@ func TestToSyncResponse(t *testing.T) { }, CustomZones: []nbdns.CustomZone{{Domain: "example.com", Records: []nbdns.SimpleRecord{{Name: "example.com", Type: 1, Class: "IN", TTL: 60, RData: "100.64.0.1"}}}}, }, - FirewallRules: []*FirewallRule{ - {PeerIP: "192.168.1.2", Direction: firewallRuleDirectionIN, Action: string(PolicyTrafficActionAccept), Protocol: string(PolicyRuleProtocolTCP), Port: "80"}, + FirewallRules: []*types.FirewallRule{ + {PeerIP: "192.168.1.2", Direction: types.FirewallRuleDirectionIN, Action: string(types.PolicyTrafficActionAccept), Protocol: string(types.PolicyRuleProtocolTCP), Port: "80"}, }, } dnsName := "example.com" @@ -1003,7 +1090,7 @@ func TestToSyncResponse(t *testing.T) { } dnsCache := &DNSConfigCache{} - response := toSyncResponse(context.Background(), config, peer, turnRelayToken, turnRelayToken, networkMap, dnsName, checks, dnsCache) + response := toSyncResponse(context.Background(), config, peer, turnRelayToken, turnRelayToken, networkMap, dnsName, checks, dnsCache, true) assert.NotNil(t, response) // assert peer config @@ -1088,7 +1175,7 @@ func Test_RegisterPeerByUser(t *testing.T) { t.Skip("The SQLite store is not properly supported by Windows yet") } - store, cleanup, err := NewTestStoreFromSQL(context.Background(), "testdata/extended-store.sql", t.TempDir()) + s, cleanup, err := store.NewTestStoreFromSQL(context.Background(), "testdata/extended-store.sql", t.TempDir()) if err != nil { t.Fatal(err) } @@ -1099,13 +1186,13 @@ func Test_RegisterPeerByUser(t *testing.T) { metrics, err := telemetry.NewDefaultAppMetrics(context.Background()) assert.NoError(t, err) - am, err := BuildManager(context.Background(), store, NewPeersUpdateManager(nil), nil, "", "netbird.cloud", eventStore, nil, false, MocIntegratedValidator{}, metrics) + am, err := BuildManager(context.Background(), s, NewPeersUpdateManager(nil), nil, "", "netbird.cloud", eventStore, nil, false, MocIntegratedValidator{}, metrics) assert.NoError(t, err) existingAccountID := "bf1c8084-ba50-4ce7-9439-34653001fc3b" existingUserID := "edafee4e-63fb-11ec-90d6-0242ac120003" - _, err = store.GetAccount(context.Background(), existingAccountID) + _, err = s.GetAccount(context.Background(), existingAccountID) require.NoError(t, err) newPeer := &nbpeer.Peer{ @@ -1128,12 +1215,12 @@ func 
Test_RegisterPeerByUser(t *testing.T) { addedPeer, _, _, err := am.AddPeer(context.Background(), "", existingUserID, newPeer) require.NoError(t, err) - peer, err := store.GetPeerByPeerPubKey(context.Background(), LockingStrengthShare, addedPeer.Key) + peer, err := s.GetPeerByPeerPubKey(context.Background(), store.LockingStrengthShare, addedPeer.Key) require.NoError(t, err) assert.Equal(t, peer.AccountID, existingAccountID) assert.Equal(t, peer.UserID, existingUserID) - account, err := store.GetAccount(context.Background(), existingAccountID) + account, err := s.GetAccount(context.Background(), existingAccountID) require.NoError(t, err) assert.Contains(t, account.Peers, addedPeer.ID) assert.Equal(t, peer.Meta.Hostname, newPeer.Meta.Hostname) @@ -1152,7 +1239,7 @@ func Test_RegisterPeerBySetupKey(t *testing.T) { t.Skip("The SQLite store is not properly supported by Windows yet") } - store, cleanup, err := NewTestStoreFromSQL(context.Background(), "testdata/extended-store.sql", t.TempDir()) + s, cleanup, err := store.NewTestStoreFromSQL(context.Background(), "testdata/extended-store.sql", t.TempDir()) if err != nil { t.Fatal(err) } @@ -1163,13 +1250,13 @@ func Test_RegisterPeerBySetupKey(t *testing.T) { metrics, err := telemetry.NewDefaultAppMetrics(context.Background()) assert.NoError(t, err) - am, err := BuildManager(context.Background(), store, NewPeersUpdateManager(nil), nil, "", "netbird.cloud", eventStore, nil, false, MocIntegratedValidator{}, metrics) + am, err := BuildManager(context.Background(), s, NewPeersUpdateManager(nil), nil, "", "netbird.cloud", eventStore, nil, false, MocIntegratedValidator{}, metrics) assert.NoError(t, err) existingAccountID := "bf1c8084-ba50-4ce7-9439-34653001fc3b" existingSetupKeyID := "A2C8E62B-38F5-4553-B31E-DD66C696CEBB" - _, err = store.GetAccount(context.Background(), existingAccountID) + _, err = s.GetAccount(context.Background(), existingAccountID) require.NoError(t, err) newPeer := &nbpeer.Peer{ @@ -1192,11 +1279,11 @@ func Test_RegisterPeerBySetupKey(t *testing.T) { require.NoError(t, err) - peer, err := store.GetPeerByPeerPubKey(context.Background(), LockingStrengthShare, newPeer.Key) + peer, err := s.GetPeerByPeerPubKey(context.Background(), store.LockingStrengthShare, newPeer.Key) require.NoError(t, err) assert.Equal(t, peer.AccountID, existingAccountID) - account, err := store.GetAccount(context.Background(), existingAccountID) + account, err := s.GetAccount(context.Background(), existingAccountID) require.NoError(t, err) assert.Contains(t, account.Peers, addedPeer.ID) assert.Contains(t, account.Groups["cfefqs706sqkneg59g2g"].Peers, addedPeer.ID) @@ -1219,7 +1306,7 @@ func Test_RegisterPeerRollbackOnFailure(t *testing.T) { t.Skip("The SQLite store is not properly supported by Windows yet") } - store, cleanup, err := NewTestStoreFromSQL(context.Background(), "testdata/extended-store.sql", t.TempDir()) + s, cleanup, err := store.NewTestStoreFromSQL(context.Background(), "testdata/extended-store.sql", t.TempDir()) if err != nil { t.Fatal(err) } @@ -1230,13 +1317,13 @@ func Test_RegisterPeerRollbackOnFailure(t *testing.T) { metrics, err := telemetry.NewDefaultAppMetrics(context.Background()) assert.NoError(t, err) - am, err := BuildManager(context.Background(), store, NewPeersUpdateManager(nil), nil, "", "netbird.cloud", eventStore, nil, false, MocIntegratedValidator{}, metrics) + am, err := BuildManager(context.Background(), s, NewPeersUpdateManager(nil), nil, "", "netbird.cloud", eventStore, nil, false, MocIntegratedValidator{}, metrics) 
assert.NoError(t, err) existingAccountID := "bf1c8084-ba50-4ce7-9439-34653001fc3b" faultyKey := "A2C8E62B-38F5-4553-B31E-DD66C696CEBC" - _, err = store.GetAccount(context.Background(), existingAccountID) + _, err = s.GetAccount(context.Background(), existingAccountID) require.NoError(t, err) newPeer := &nbpeer.Peer{ @@ -1258,10 +1345,10 @@ func Test_RegisterPeerRollbackOnFailure(t *testing.T) { _, _, _, err = am.AddPeer(context.Background(), faultyKey, "", newPeer) require.Error(t, err) - _, err = store.GetPeerByPeerPubKey(context.Background(), LockingStrengthShare, newPeer.Key) + _, err = s.GetPeerByPeerPubKey(context.Background(), store.LockingStrengthShare, newPeer.Key) require.Error(t, err) - account, err := store.GetAccount(context.Background(), existingAccountID) + account, err := s.GetAccount(context.Background(), existingAccountID) require.NoError(t, err) assert.NotContains(t, account.Peers, newPeer.ID) assert.NotContains(t, account.Groups["cfefqs706sqkneg59g3g"].Peers, newPeer.ID) @@ -1284,7 +1371,7 @@ func TestPeerAccountPeersUpdate(t *testing.T) { err := manager.DeletePolicy(context.Background(), account.Id, account.Policies[0].ID, userID) require.NoError(t, err) - err = manager.SaveGroups(context.Background(), account.Id, userID, []*nbgroup.Group{ + err = manager.SaveGroups(context.Background(), account.Id, userID, []*types.Group{ { ID: "groupA", Name: "GroupA", @@ -1304,26 +1391,26 @@ func TestPeerAccountPeersUpdate(t *testing.T) { require.NoError(t, err) // create a user with auto groups - _, err = manager.SaveOrAddUsers(context.Background(), account.Id, userID, []*User{ + _, err = manager.SaveOrAddUsers(context.Background(), account.Id, userID, []*types.User{ { Id: "regularUser1", AccountID: account.Id, - Role: UserRoleAdmin, - Issued: UserIssuedAPI, + Role: types.UserRoleAdmin, + Issued: types.UserIssuedAPI, AutoGroups: []string{"groupA"}, }, { Id: "regularUser2", AccountID: account.Id, - Role: UserRoleAdmin, - Issued: UserIssuedAPI, + Role: types.UserRoleAdmin, + Issued: types.UserIssuedAPI, AutoGroups: []string{"groupB"}, }, { Id: "regularUser3", AccountID: account.Id, - Role: UserRoleAdmin, - Issued: UserIssuedAPI, + Role: types.UserRoleAdmin, + Issued: types.UserIssuedAPI, AutoGroups: []string{"groupC"}, }, }, true) @@ -1464,15 +1551,15 @@ func TestPeerAccountPeersUpdate(t *testing.T) { // Adding peer to group linked with policy should update account peers and send peer update t.Run("adding peer to group linked with policy", func(t *testing.T) { - _, err = manager.SavePolicy(context.Background(), account.Id, userID, &Policy{ + _, err = manager.SavePolicy(context.Background(), account.Id, userID, &types.Policy{ Enabled: true, - Rules: []*PolicyRule{ + Rules: []*types.PolicyRule{ { Enabled: true, Sources: []string{"groupA"}, Destinations: []string{"groupA"}, Bidirectional: true, - Action: PolicyTrafficActionAccept, + Action: types.PolicyTrafficActionAccept, }, }, }) diff --git a/management/server/permissions/manager.go b/management/server/permissions/manager.go new file mode 100644 index 000000000..320aad027 --- /dev/null +++ b/management/server/permissions/manager.go @@ -0,0 +1,102 @@ +package permissions + +import ( + "context" + "errors" + "fmt" + + "github.com/netbirdio/netbird/management/server/settings" + "github.com/netbirdio/netbird/management/server/types" + "github.com/netbirdio/netbird/management/server/users" +) + +type Module string + +const ( + Networks Module = "networks" + Peers Module = "peers" + Groups Module = "groups" +) + +type Operation string + 
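+// Read and Write are the operations a caller can request on a module.
+// ValidateUserPermissions combines the operation with the user's role:
+// admins and owners are always allowed, billing admins are denied, and
+// regular users get at most read access to the peers module
+// (see validateRegularUserPermissions below).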
+const ( + Read Operation = "read" + Write Operation = "write" +) + +type Manager interface { + ValidateUserPermissions(ctx context.Context, accountID, userID string, module Module, operation Operation) (bool, error) +} + +type managerImpl struct { + userManager users.Manager + settingsManager settings.Manager +} + +type managerMock struct { +} + +func NewManager(userManager users.Manager, settingsManager settings.Manager) Manager { + return &managerImpl{ + userManager: userManager, + settingsManager: settingsManager, + } +} + +func (m *managerImpl) ValidateUserPermissions(ctx context.Context, accountID, userID string, module Module, operation Operation) (bool, error) { + user, err := m.userManager.GetUser(ctx, userID) + if err != nil { + return false, err + } + + if user == nil { + return false, errors.New("user not found") + } + + if user.AccountID != accountID { + return false, errors.New("user does not belong to account") + } + + switch user.Role { + case types.UserRoleAdmin, types.UserRoleOwner: + return true, nil + case types.UserRoleUser: + return m.validateRegularUserPermissions(ctx, accountID, userID, module, operation) + case types.UserRoleBillingAdmin: + return false, nil + default: + return false, errors.New("invalid role") + } +} + +func (m *managerImpl) validateRegularUserPermissions(ctx context.Context, accountID, userID string, module Module, operation Operation) (bool, error) { + settings, err := m.settingsManager.GetSettings(ctx, accountID, userID) + if err != nil { + return false, fmt.Errorf("failed to get settings: %w", err) + } + if settings.RegularUsersViewBlocked { + return false, nil + } + + if operation == Write { + return false, nil + } + + if module == Peers { + return true, nil + } + + return false, nil +} + +func NewManagerMock() Manager { + return &managerMock{} +} + +func (m *managerMock) ValidateUserPermissions(ctx context.Context, accountID, userID string, module Module, operation Operation) (bool, error) { + if userID == "allowedUser" { + return true, nil + } + return false, nil +} diff --git a/management/server/policy.go b/management/server/policy.go index 2d3abc3f1..45b3e93e6 100644 --- a/management/server/policy.go +++ b/management/server/policy.go @@ -3,344 +3,21 @@ package server import ( "context" _ "embed" - "strconv" - "strings" + + "github.com/rs/xid" "github.com/netbirdio/netbird/management/proto" - "github.com/rs/xid" - log "github.com/sirupsen/logrus" + "github.com/netbirdio/netbird/management/server/store" + "github.com/netbirdio/netbird/management/server/types" "github.com/netbirdio/netbird/management/server/activity" - nbgroup "github.com/netbirdio/netbird/management/server/group" - nbpeer "github.com/netbirdio/netbird/management/server/peer" "github.com/netbirdio/netbird/management/server/posture" "github.com/netbirdio/netbird/management/server/status" ) -// PolicyUpdateOperationType operation type -type PolicyUpdateOperationType int - -// PolicyTrafficActionType action type for the firewall -type PolicyTrafficActionType string - -// PolicyRuleProtocolType type of traffic -type PolicyRuleProtocolType string - -// PolicyRuleDirection direction of traffic -type PolicyRuleDirection string - -const ( - // PolicyTrafficActionAccept indicates that the traffic is accepted - PolicyTrafficActionAccept = PolicyTrafficActionType("accept") - // PolicyTrafficActionDrop indicates that the traffic is dropped - PolicyTrafficActionDrop = PolicyTrafficActionType("drop") -) - -const ( - // PolicyRuleProtocolALL type of traffic - PolicyRuleProtocolALL = 
PolicyRuleProtocolType("all") - // PolicyRuleProtocolTCP type of traffic - PolicyRuleProtocolTCP = PolicyRuleProtocolType("tcp") - // PolicyRuleProtocolUDP type of traffic - PolicyRuleProtocolUDP = PolicyRuleProtocolType("udp") - // PolicyRuleProtocolICMP type of traffic - PolicyRuleProtocolICMP = PolicyRuleProtocolType("icmp") -) - -const ( - // PolicyRuleFlowDirect allows traffic from source to destination - PolicyRuleFlowDirect = PolicyRuleDirection("direct") - // PolicyRuleFlowBidirect allows traffic to both directions - PolicyRuleFlowBidirect = PolicyRuleDirection("bidirect") -) - -const ( - // DefaultRuleName is a name for the Default rule that is created for every account - DefaultRuleName = "Default" - // DefaultRuleDescription is a description for the Default rule that is created for every account - DefaultRuleDescription = "This is a default rule that allows connections between all the resources" - // DefaultPolicyName is a name for the Default policy that is created for every account - DefaultPolicyName = "Default" - // DefaultPolicyDescription is a description for the Default policy that is created for every account - DefaultPolicyDescription = "This is a default policy that allows connections between all the resources" -) - -const ( - firewallRuleDirectionIN = 0 - firewallRuleDirectionOUT = 1 -) - -// PolicyUpdateOperation operation object with type and values to be applied -type PolicyUpdateOperation struct { - Type PolicyUpdateOperationType - Values []string -} - -// RulePortRange represents a range of ports for a firewall rule. -type RulePortRange struct { - Start uint16 - End uint16 -} - -// PolicyRule is the metadata of the policy -type PolicyRule struct { - // ID of the policy rule - ID string `gorm:"primaryKey"` - - // PolicyID is a reference to Policy that this object belongs - PolicyID string `json:"-" gorm:"index"` - - // Name of the rule visible in the UI - Name string - - // Description of the rule visible in the UI - Description string - - // Enabled status of rule in the system - Enabled bool - - // Action policy accept or drops packets - Action PolicyTrafficActionType - - // Destinations policy destination groups - Destinations []string `gorm:"serializer:json"` - - // Sources policy source groups - Sources []string `gorm:"serializer:json"` - - // Bidirectional define if the rule is applicable in both directions, sources, and destinations - Bidirectional bool - - // Protocol type of the traffic - Protocol PolicyRuleProtocolType - - // Ports or it ranges list - Ports []string `gorm:"serializer:json"` - - // PortRanges a list of port ranges. 
- PortRanges []RulePortRange `gorm:"serializer:json"` -} - -// Copy returns a copy of a policy rule -func (pm *PolicyRule) Copy() *PolicyRule { - rule := &PolicyRule{ - ID: pm.ID, - PolicyID: pm.PolicyID, - Name: pm.Name, - Description: pm.Description, - Enabled: pm.Enabled, - Action: pm.Action, - Destinations: make([]string, len(pm.Destinations)), - Sources: make([]string, len(pm.Sources)), - Bidirectional: pm.Bidirectional, - Protocol: pm.Protocol, - Ports: make([]string, len(pm.Ports)), - PortRanges: make([]RulePortRange, len(pm.PortRanges)), - } - copy(rule.Destinations, pm.Destinations) - copy(rule.Sources, pm.Sources) - copy(rule.Ports, pm.Ports) - copy(rule.PortRanges, pm.PortRanges) - return rule -} - -// Policy of the Rego query -type Policy struct { - // ID of the policy' - ID string `gorm:"primaryKey"` - - // AccountID is a reference to Account that this object belongs - AccountID string `json:"-" gorm:"index"` - - // Name of the Policy - Name string - - // Description of the policy visible in the UI - Description string - - // Enabled status of the policy - Enabled bool - - // Rules of the policy - Rules []*PolicyRule `gorm:"foreignKey:PolicyID;references:id;constraint:OnDelete:CASCADE;"` - - // SourcePostureChecks are ID references to Posture checks for policy source groups - SourcePostureChecks []string `gorm:"serializer:json"` -} - -// Copy returns a copy of the policy. -func (p *Policy) Copy() *Policy { - c := &Policy{ - ID: p.ID, - AccountID: p.AccountID, - Name: p.Name, - Description: p.Description, - Enabled: p.Enabled, - Rules: make([]*PolicyRule, len(p.Rules)), - SourcePostureChecks: make([]string, len(p.SourcePostureChecks)), - } - for i, r := range p.Rules { - c.Rules[i] = r.Copy() - } - copy(c.SourcePostureChecks, p.SourcePostureChecks) - return c -} - -// EventMeta returns activity event meta related to this policy -func (p *Policy) EventMeta() map[string]any { - return map[string]any{"name": p.Name} -} - -// UpgradeAndFix different version of policies to latest version -func (p *Policy) UpgradeAndFix() { - for _, r := range p.Rules { - // start migrate from version v0.20.3 - if r.Protocol == "" { - r.Protocol = PolicyRuleProtocolALL - } - if r.Protocol == PolicyRuleProtocolALL && !r.Bidirectional { - r.Bidirectional = true - } - // -- v0.20.4 - } -} - -// ruleGroups returns a list of all groups referenced in the policy's rules, -// including sources and destinations. -func (p *Policy) ruleGroups() []string { - groups := make([]string, 0) - for _, rule := range p.Rules { - groups = append(groups, rule.Sources...) - groups = append(groups, rule.Destinations...) - } - - return groups -} - -// FirewallRule is a rule of the firewall. -type FirewallRule struct { - // PeerIP of the peer - PeerIP string - - // Direction of the traffic - Direction int - - // Action of the traffic - Action string - - // Protocol of the traffic - Protocol string - - // Port of the traffic - Port string -} - -// getPeerConnectionResources for a given peer -// -// This function returns the list of peers and firewall rules that are applicable to a given peer. 
-func (a *Account) getPeerConnectionResources(ctx context.Context, peerID string, validatedPeersMap map[string]struct{}) ([]*nbpeer.Peer, []*FirewallRule) { - generateResources, getAccumulatedResources := a.connResourcesGenerator(ctx) - for _, policy := range a.Policies { - if !policy.Enabled { - continue - } - - for _, rule := range policy.Rules { - if !rule.Enabled { - continue - } - - sourcePeers, peerInSources := a.getAllPeersFromGroups(ctx, rule.Sources, peerID, policy.SourcePostureChecks, validatedPeersMap) - destinationPeers, peerInDestinations := a.getAllPeersFromGroups(ctx, rule.Destinations, peerID, nil, validatedPeersMap) - - if rule.Bidirectional { - if peerInSources { - generateResources(rule, destinationPeers, firewallRuleDirectionIN) - } - if peerInDestinations { - generateResources(rule, sourcePeers, firewallRuleDirectionOUT) - } - } - - if peerInSources { - generateResources(rule, destinationPeers, firewallRuleDirectionOUT) - } - - if peerInDestinations { - generateResources(rule, sourcePeers, firewallRuleDirectionIN) - } - } - } - - return getAccumulatedResources() -} - -// connResourcesGenerator returns generator and accumulator function which returns the result of generator calls -// -// The generator function is used to generate the list of peers and firewall rules that are applicable to a given peer. -// It safe to call the generator function multiple times for same peer and different rules no duplicates will be -// generated. The accumulator function returns the result of all the generator calls. -func (a *Account) connResourcesGenerator(ctx context.Context) (func(*PolicyRule, []*nbpeer.Peer, int), func() ([]*nbpeer.Peer, []*FirewallRule)) { - rulesExists := make(map[string]struct{}) - peersExists := make(map[string]struct{}) - rules := make([]*FirewallRule, 0) - peers := make([]*nbpeer.Peer, 0) - - all, err := a.GetGroupAll() - if err != nil { - log.WithContext(ctx).Errorf("failed to get group all: %v", err) - all = &nbgroup.Group{} - } - - return func(rule *PolicyRule, groupPeers []*nbpeer.Peer, direction int) { - isAll := (len(all.Peers) - 1) == len(groupPeers) - for _, peer := range groupPeers { - if peer == nil { - continue - } - - if _, ok := peersExists[peer.ID]; !ok { - peers = append(peers, peer) - peersExists[peer.ID] = struct{}{} - } - - fr := FirewallRule{ - PeerIP: peer.IP.String(), - Direction: direction, - Action: string(rule.Action), - Protocol: string(rule.Protocol), - } - - if isAll { - fr.PeerIP = "0.0.0.0" - } - - ruleID := rule.ID + fr.PeerIP + strconv.Itoa(direction) + - fr.Protocol + fr.Action + strings.Join(rule.Ports, ",") - if _, ok := rulesExists[ruleID]; ok { - continue - } - rulesExists[ruleID] = struct{}{} - - if len(rule.Ports) == 0 { - rules = append(rules, &fr) - continue - } - - for _, port := range rule.Ports { - pr := fr // clone rule and add set new port - pr.Port = port - rules = append(rules, &pr) - } - } - }, func() ([]*nbpeer.Peer, []*FirewallRule) { - return peers, rules - } -} - // GetPolicy from the store -func (am *DefaultAccountManager) GetPolicy(ctx context.Context, accountID, policyID, userID string) (*Policy, error) { - user, err := am.Store.GetUserByUserID(ctx, LockingStrengthShare, userID) +func (am *DefaultAccountManager) GetPolicy(ctx context.Context, accountID, policyID, userID string) (*types.Policy, error) { + user, err := am.Store.GetUserByUserID(ctx, store.LockingStrengthShare, userID) if err != nil { return nil, err } @@ -353,15 +30,15 @@ func (am *DefaultAccountManager) GetPolicy(ctx context.Context, 
accountID, polic return nil, status.NewAdminPermissionError() } - return am.Store.GetPolicyByID(ctx, LockingStrengthShare, accountID, policyID) + return am.Store.GetPolicyByID(ctx, store.LockingStrengthShare, accountID, policyID) } // SavePolicy in the store -func (am *DefaultAccountManager) SavePolicy(ctx context.Context, accountID, userID string, policy *Policy) (*Policy, error) { +func (am *DefaultAccountManager) SavePolicy(ctx context.Context, accountID, userID string, policy *types.Policy) (*types.Policy, error) { unlock := am.Store.AcquireWriteLockByUID(ctx, accountID) defer unlock() - user, err := am.Store.GetUserByUserID(ctx, LockingStrengthShare, userID) + user, err := am.Store.GetUserByUserID(ctx, store.LockingStrengthShare, userID) if err != nil { return nil, err } @@ -378,7 +55,7 @@ func (am *DefaultAccountManager) SavePolicy(ctx context.Context, accountID, user var updateAccountPeers bool var action = activity.PolicyAdded - err = am.Store.ExecuteInTransaction(ctx, func(transaction Store) error { + err = am.Store.ExecuteInTransaction(ctx, func(transaction store.Store) error { if err = validatePolicy(ctx, transaction, accountID, policy); err != nil { return err } @@ -388,7 +65,7 @@ func (am *DefaultAccountManager) SavePolicy(ctx context.Context, accountID, user return err } - if err = transaction.IncrementNetworkSerial(ctx, LockingStrengthUpdate, accountID); err != nil { + if err = transaction.IncrementNetworkSerial(ctx, store.LockingStrengthUpdate, accountID); err != nil { return err } @@ -398,7 +75,7 @@ func (am *DefaultAccountManager) SavePolicy(ctx context.Context, accountID, user saveFunc = transaction.SavePolicy } - return saveFunc(ctx, LockingStrengthUpdate, policy) + return saveFunc(ctx, store.LockingStrengthUpdate, policy) }) if err != nil { return nil, err @@ -407,7 +84,7 @@ func (am *DefaultAccountManager) SavePolicy(ctx context.Context, accountID, user am.StoreEvent(ctx, userID, policy.ID, accountID, action, policy.EventMeta()) if updateAccountPeers { - am.updateAccountPeers(ctx, accountID) + am.UpdateAccountPeers(ctx, accountID) } return policy, nil @@ -418,7 +95,7 @@ func (am *DefaultAccountManager) DeletePolicy(ctx context.Context, accountID, po unlock := am.Store.AcquireWriteLockByUID(ctx, accountID) defer unlock() - user, err := am.Store.GetUserByUserID(ctx, LockingStrengthShare, userID) + user, err := am.Store.GetUserByUserID(ctx, store.LockingStrengthShare, userID) if err != nil { return err } @@ -431,11 +108,11 @@ func (am *DefaultAccountManager) DeletePolicy(ctx context.Context, accountID, po return status.NewAdminPermissionError() } - var policy *Policy + var policy *types.Policy var updateAccountPeers bool - err = am.Store.ExecuteInTransaction(ctx, func(transaction Store) error { - policy, err = transaction.GetPolicyByID(ctx, LockingStrengthUpdate, accountID, policyID) + err = am.Store.ExecuteInTransaction(ctx, func(transaction store.Store) error { + policy, err = transaction.GetPolicyByID(ctx, store.LockingStrengthUpdate, accountID, policyID) if err != nil { return err } @@ -445,11 +122,11 @@ func (am *DefaultAccountManager) DeletePolicy(ctx context.Context, accountID, po return err } - if err = transaction.IncrementNetworkSerial(ctx, LockingStrengthUpdate, accountID); err != nil { + if err = transaction.IncrementNetworkSerial(ctx, store.LockingStrengthUpdate, accountID); err != nil { return err } - return transaction.DeletePolicy(ctx, LockingStrengthUpdate, accountID, policyID) + return transaction.DeletePolicy(ctx, store.LockingStrengthUpdate, 
accountID, policyID) }) if err != nil { return err @@ -458,15 +135,15 @@ func (am *DefaultAccountManager) DeletePolicy(ctx context.Context, accountID, po am.StoreEvent(ctx, userID, policyID, accountID, activity.PolicyRemoved, policy.EventMeta()) if updateAccountPeers { - am.updateAccountPeers(ctx, accountID) + am.UpdateAccountPeers(ctx, accountID) } return nil } // ListPolicies from the store. -func (am *DefaultAccountManager) ListPolicies(ctx context.Context, accountID, userID string) ([]*Policy, error) { - user, err := am.Store.GetUserByUserID(ctx, LockingStrengthShare, userID) +func (am *DefaultAccountManager) ListPolicies(ctx context.Context, accountID, userID string) ([]*types.Policy, error) { + user, err := am.Store.GetUserByUserID(ctx, store.LockingStrengthShare, userID) if err != nil { return nil, err } @@ -479,13 +156,13 @@ func (am *DefaultAccountManager) ListPolicies(ctx context.Context, accountID, us return nil, status.NewAdminPermissionError() } - return am.Store.GetAccountPolicies(ctx, LockingStrengthShare, accountID) + return am.Store.GetAccountPolicies(ctx, store.LockingStrengthShare, accountID) } // arePolicyChangesAffectPeers checks if changes to a policy will affect any associated peers. -func arePolicyChangesAffectPeers(ctx context.Context, transaction Store, accountID string, policy *Policy, isUpdate bool) (bool, error) { +func arePolicyChangesAffectPeers(ctx context.Context, transaction store.Store, accountID string, policy *types.Policy, isUpdate bool) (bool, error) { if isUpdate { - existingPolicy, err := transaction.GetPolicyByID(ctx, LockingStrengthShare, accountID, policy.ID) + existingPolicy, err := transaction.GetPolicyByID(ctx, store.LockingStrengthShare, accountID, policy.ID) if err != nil { return false, err } @@ -494,7 +171,7 @@ func arePolicyChangesAffectPeers(ctx context.Context, transaction Store, account return false, nil } - hasPeers, err := anyGroupHasPeers(ctx, transaction, policy.AccountID, existingPolicy.ruleGroups()) + hasPeers, err := anyGroupHasPeersOrResources(ctx, transaction, policy.AccountID, existingPolicy.RuleGroups()) if err != nil { return false, err } @@ -504,13 +181,13 @@ func arePolicyChangesAffectPeers(ctx context.Context, transaction Store, account } } - return anyGroupHasPeers(ctx, transaction, policy.AccountID, policy.ruleGroups()) + return anyGroupHasPeersOrResources(ctx, transaction, policy.AccountID, policy.RuleGroups()) } // validatePolicy validates the policy and its rules. 
-func validatePolicy(ctx context.Context, transaction Store, accountID string, policy *Policy) error { +func validatePolicy(ctx context.Context, transaction store.Store, accountID string, policy *types.Policy) error { if policy.ID != "" { - _, err := transaction.GetPolicyByID(ctx, LockingStrengthShare, accountID, policy.ID) + _, err := transaction.GetPolicyByID(ctx, store.LockingStrengthShare, accountID, policy.ID) if err != nil { return err } @@ -519,12 +196,12 @@ func validatePolicy(ctx context.Context, transaction Store, accountID string, po policy.AccountID = accountID } - groups, err := transaction.GetGroupsByIDs(ctx, LockingStrengthShare, accountID, policy.ruleGroups()) + groups, err := transaction.GetGroupsByIDs(ctx, store.LockingStrengthShare, accountID, policy.RuleGroups()) if err != nil { return err } - postureChecks, err := transaction.GetPostureChecksByIDs(ctx, LockingStrengthShare, accountID, policy.SourcePostureChecks) + postureChecks, err := transaction.GetPostureChecksByIDs(ctx, store.LockingStrengthShare, accountID, policy.SourcePostureChecks) if err != nil { return err } @@ -548,84 +225,6 @@ func validatePolicy(ctx context.Context, transaction Store, accountID string, po return nil } -// getAllPeersFromGroups for given peer ID and list of groups -// -// Returns a list of peers from specified groups that pass specified posture checks -// and a boolean indicating if the supplied peer ID exists within these groups. -// -// Important: Posture checks are applicable only to source group peers, -// for destination group peers, call this method with an empty list of sourcePostureChecksIDs -func (a *Account) getAllPeersFromGroups(ctx context.Context, groups []string, peerID string, sourcePostureChecksIDs []string, validatedPeersMap map[string]struct{}) ([]*nbpeer.Peer, bool) { - peerInGroups := false - filteredPeers := make([]*nbpeer.Peer, 0, len(groups)) - for _, g := range groups { - group, ok := a.Groups[g] - if !ok { - continue - } - - for _, p := range group.Peers { - peer, ok := a.Peers[p] - if !ok || peer == nil { - continue - } - - // validate the peer based on policy posture checks applied - isValid := a.validatePostureChecksOnPeer(ctx, sourcePostureChecksIDs, peer.ID) - if !isValid { - continue - } - - if _, ok := validatedPeersMap[peer.ID]; !ok { - continue - } - - if peer.ID == peerID { - peerInGroups = true - continue - } - - filteredPeers = append(filteredPeers, peer) - } - } - return filteredPeers, peerInGroups -} - -// validatePostureChecksOnPeer validates the posture checks on a peer -func (a *Account) validatePostureChecksOnPeer(ctx context.Context, sourcePostureChecksID []string, peerID string) bool { - peer, ok := a.Peers[peerID] - if !ok && peer == nil { - return false - } - - for _, postureChecksID := range sourcePostureChecksID { - postureChecks := a.getPostureChecks(postureChecksID) - if postureChecks == nil { - continue - } - - for _, check := range postureChecks.GetChecks() { - isValid, err := check.Check(ctx, *peer) - if err != nil { - log.WithContext(ctx).Debugf("an error occurred check %s: on peer: %s :%s", check.Name(), peer.ID, err.Error()) - } - if !isValid { - return false - } - } - } - return true -} - -func (a *Account) getPostureChecks(postureChecksID string) *posture.Checks { - for _, postureChecks := range a.PostureChecks { - if postureChecks.ID == postureChecksID { - return postureChecks - } - } - return nil -} - // getValidPostureCheckIDs filters and returns only the valid posture check IDs from the provided list. 
func getValidPostureCheckIDs(postureChecks map[string]*posture.Checks, postureChecksIds []string) []string { validIDs := make([]string, 0, len(postureChecksIds)) @@ -639,7 +238,7 @@ func getValidPostureCheckIDs(postureChecks map[string]*posture.Checks, postureCh } // getValidGroupIDs filters and returns only the valid group IDs from the provided list. -func getValidGroupIDs(groups map[string]*nbgroup.Group, groupIDs []string) []string { +func getValidGroupIDs(groups map[string]*types.Group, groupIDs []string) []string { validIDs := make([]string, 0, len(groupIDs)) for _, id := range groupIDs { if _, exists := groups[id]; exists { @@ -651,7 +250,7 @@ func getValidGroupIDs(groups map[string]*nbgroup.Group, groupIDs []string) []str } // toProtocolFirewallRules converts the firewall rules to the protocol firewall rules. -func toProtocolFirewallRules(rules []*FirewallRule) []*proto.FirewallRule { +func toProtocolFirewallRules(rules []*types.FirewallRule) []*proto.FirewallRule { result := make([]*proto.FirewallRule, len(rules)) for i := range rules { rule := rules[i] diff --git a/management/server/policy_test.go b/management/server/policy_test.go index 62d80f46e..fab738abe 100644 --- a/management/server/policy_test.go +++ b/management/server/policy_test.go @@ -10,13 +10,13 @@ import ( "github.com/stretchr/testify/assert" "golang.org/x/exp/slices" - nbgroup "github.com/netbirdio/netbird/management/server/group" nbpeer "github.com/netbirdio/netbird/management/server/peer" "github.com/netbirdio/netbird/management/server/posture" + "github.com/netbirdio/netbird/management/server/types" ) func TestAccount_getPeersByPolicy(t *testing.T) { - account := &Account{ + account := &types.Account{ Peers: map[string]*nbpeer.Peer{ "peerA": { ID: "peerA", @@ -59,7 +59,7 @@ func TestAccount_getPeersByPolicy(t *testing.T) { Status: &nbpeer.PeerStatus{}, }, }, - Groups: map[string]*nbgroup.Group{ + Groups: map[string]*types.Group{ "GroupAll": { ID: "GroupAll", Name: "All", @@ -87,21 +87,21 @@ func TestAccount_getPeersByPolicy(t *testing.T) { }, }, }, - Policies: []*Policy{ + Policies: []*types.Policy{ { ID: "RuleDefault", Name: "Default", Description: "This is a default rule that allows connections between all the resources", Enabled: true, - Rules: []*PolicyRule{ + Rules: []*types.PolicyRule{ { ID: "RuleDefault", Name: "Default", Description: "This is a default rule that allows connections between all the resources", Bidirectional: true, Enabled: true, - Protocol: PolicyRuleProtocolALL, - Action: PolicyTrafficActionAccept, + Protocol: types.PolicyRuleProtocolALL, + Action: types.PolicyTrafficActionAccept, Sources: []string{ "GroupAll", }, @@ -116,15 +116,15 @@ func TestAccount_getPeersByPolicy(t *testing.T) { Name: "Swarm", Description: "No description", Enabled: true, - Rules: []*PolicyRule{ + Rules: []*types.PolicyRule{ { ID: "RuleSwarm", Name: "Swarm", Description: "No description", Bidirectional: true, Enabled: true, - Protocol: PolicyRuleProtocolALL, - Action: PolicyTrafficActionAccept, + Protocol: types.PolicyRuleProtocolALL, + Action: types.PolicyTrafficActionAccept, Sources: []string{ "GroupSwarm", "GroupAll", @@ -145,14 +145,14 @@ func TestAccount_getPeersByPolicy(t *testing.T) { t.Run("check that all peers get map", func(t *testing.T) { for _, p := range account.Peers { - peers, firewallRules := account.getPeerConnectionResources(context.Background(), p.ID, validatedPeers) + peers, firewallRules := account.GetPeerConnectionResources(context.Background(), p.ID, validatedPeers) assert.GreaterOrEqual(t, 
len(peers), 2, "minimum number peers should present") assert.GreaterOrEqual(t, len(firewallRules), 2, "minimum number of firewall rules should present") } }) t.Run("check first peer map details", func(t *testing.T) { - peers, firewallRules := account.getPeerConnectionResources(context.Background(), "peerB", validatedPeers) + peers, firewallRules := account.GetPeerConnectionResources(context.Background(), "peerB", validatedPeers) assert.Len(t, peers, 7) assert.Contains(t, peers, account.Peers["peerA"]) assert.Contains(t, peers, account.Peers["peerC"]) @@ -160,45 +160,45 @@ func TestAccount_getPeersByPolicy(t *testing.T) { assert.Contains(t, peers, account.Peers["peerE"]) assert.Contains(t, peers, account.Peers["peerF"]) - epectedFirewallRules := []*FirewallRule{ + epectedFirewallRules := []*types.FirewallRule{ { PeerIP: "0.0.0.0", - Direction: firewallRuleDirectionIN, + Direction: types.FirewallRuleDirectionIN, Action: "accept", Protocol: "all", Port: "", }, { PeerIP: "0.0.0.0", - Direction: firewallRuleDirectionOUT, + Direction: types.FirewallRuleDirectionOUT, Action: "accept", Protocol: "all", Port: "", }, { PeerIP: "100.65.14.88", - Direction: firewallRuleDirectionIN, + Direction: types.FirewallRuleDirectionIN, Action: "accept", Protocol: "all", Port: "", }, { PeerIP: "100.65.14.88", - Direction: firewallRuleDirectionOUT, + Direction: types.FirewallRuleDirectionOUT, Action: "accept", Protocol: "all", Port: "", }, { PeerIP: "100.65.254.139", - Direction: firewallRuleDirectionOUT, + Direction: types.FirewallRuleDirectionOUT, Action: "accept", Protocol: "all", Port: "", }, { PeerIP: "100.65.254.139", - Direction: firewallRuleDirectionIN, + Direction: types.FirewallRuleDirectionIN, Action: "accept", Protocol: "all", Port: "", @@ -206,14 +206,14 @@ func TestAccount_getPeersByPolicy(t *testing.T) { { PeerIP: "100.65.62.5", - Direction: firewallRuleDirectionOUT, + Direction: types.FirewallRuleDirectionOUT, Action: "accept", Protocol: "all", Port: "", }, { PeerIP: "100.65.62.5", - Direction: firewallRuleDirectionIN, + Direction: types.FirewallRuleDirectionIN, Action: "accept", Protocol: "all", Port: "", @@ -221,14 +221,14 @@ func TestAccount_getPeersByPolicy(t *testing.T) { { PeerIP: "100.65.32.206", - Direction: firewallRuleDirectionOUT, + Direction: types.FirewallRuleDirectionOUT, Action: "accept", Protocol: "all", Port: "", }, { PeerIP: "100.65.32.206", - Direction: firewallRuleDirectionIN, + Direction: types.FirewallRuleDirectionIN, Action: "accept", Protocol: "all", Port: "", @@ -236,14 +236,14 @@ func TestAccount_getPeersByPolicy(t *testing.T) { { PeerIP: "100.65.250.202", - Direction: firewallRuleDirectionOUT, + Direction: types.FirewallRuleDirectionOUT, Action: "accept", Protocol: "all", Port: "", }, { PeerIP: "100.65.250.202", - Direction: firewallRuleDirectionIN, + Direction: types.FirewallRuleDirectionIN, Action: "accept", Protocol: "all", Port: "", @@ -251,14 +251,14 @@ func TestAccount_getPeersByPolicy(t *testing.T) { { PeerIP: "100.65.13.186", - Direction: firewallRuleDirectionOUT, + Direction: types.FirewallRuleDirectionOUT, Action: "accept", Protocol: "all", Port: "", }, { PeerIP: "100.65.13.186", - Direction: firewallRuleDirectionIN, + Direction: types.FirewallRuleDirectionIN, Action: "accept", Protocol: "all", Port: "", @@ -266,14 +266,14 @@ func TestAccount_getPeersByPolicy(t *testing.T) { { PeerIP: "100.65.29.55", - Direction: firewallRuleDirectionOUT, + Direction: types.FirewallRuleDirectionOUT, Action: "accept", Protocol: "all", Port: "", }, { PeerIP: "100.65.29.55", - 
Direction: firewallRuleDirectionIN, + Direction: types.FirewallRuleDirectionIN, Action: "accept", Protocol: "all", Port: "", @@ -289,7 +289,7 @@ func TestAccount_getPeersByPolicy(t *testing.T) { } func TestAccount_getPeersByPolicyDirect(t *testing.T) { - account := &Account{ + account := &types.Account{ Peers: map[string]*nbpeer.Peer{ "peerA": { ID: "peerA", @@ -307,7 +307,7 @@ func TestAccount_getPeersByPolicyDirect(t *testing.T) { Status: &nbpeer.PeerStatus{}, }, }, - Groups: map[string]*nbgroup.Group{ + Groups: map[string]*types.Group{ "GroupAll": { ID: "GroupAll", Name: "All", @@ -332,21 +332,21 @@ func TestAccount_getPeersByPolicyDirect(t *testing.T) { }, }, }, - Policies: []*Policy{ + Policies: []*types.Policy{ { ID: "RuleDefault", Name: "Default", Description: "This is a default rule that allows connections between all the resources", Enabled: false, - Rules: []*PolicyRule{ + Rules: []*types.PolicyRule{ { ID: "RuleDefault", Name: "Default", Description: "This is a default rule that allows connections between all the resources", Bidirectional: true, Enabled: false, - Protocol: PolicyRuleProtocolALL, - Action: PolicyTrafficActionAccept, + Protocol: types.PolicyRuleProtocolALL, + Action: types.PolicyTrafficActionAccept, Sources: []string{ "GroupAll", }, @@ -361,15 +361,15 @@ func TestAccount_getPeersByPolicyDirect(t *testing.T) { Name: "Swarm", Description: "No description", Enabled: true, - Rules: []*PolicyRule{ + Rules: []*types.PolicyRule{ { ID: "RuleSwarm", Name: "Swarm", Description: "No description", Bidirectional: true, Enabled: true, - Protocol: PolicyRuleProtocolALL, - Action: PolicyTrafficActionAccept, + Protocol: types.PolicyRuleProtocolALL, + Action: types.PolicyTrafficActionAccept, Sources: []string{ "GroupSwarm", }, @@ -388,20 +388,20 @@ func TestAccount_getPeersByPolicyDirect(t *testing.T) { } t.Run("check first peer map", func(t *testing.T) { - peers, firewallRules := account.getPeerConnectionResources(context.Background(), "peerB", approvedPeers) + peers, firewallRules := account.GetPeerConnectionResources(context.Background(), "peerB", approvedPeers) assert.Contains(t, peers, account.Peers["peerC"]) - epectedFirewallRules := []*FirewallRule{ + epectedFirewallRules := []*types.FirewallRule{ { PeerIP: "100.65.254.139", - Direction: firewallRuleDirectionIN, + Direction: types.FirewallRuleDirectionIN, Action: "accept", Protocol: "all", Port: "", }, { PeerIP: "100.65.254.139", - Direction: firewallRuleDirectionOUT, + Direction: types.FirewallRuleDirectionOUT, Action: "accept", Protocol: "all", Port: "", @@ -416,20 +416,20 @@ func TestAccount_getPeersByPolicyDirect(t *testing.T) { }) t.Run("check second peer map", func(t *testing.T) { - peers, firewallRules := account.getPeerConnectionResources(context.Background(), "peerC", approvedPeers) + peers, firewallRules := account.GetPeerConnectionResources(context.Background(), "peerC", approvedPeers) assert.Contains(t, peers, account.Peers["peerB"]) - epectedFirewallRules := []*FirewallRule{ + epectedFirewallRules := []*types.FirewallRule{ { PeerIP: "100.65.80.39", - Direction: firewallRuleDirectionIN, + Direction: types.FirewallRuleDirectionIN, Action: "accept", Protocol: "all", Port: "", }, { PeerIP: "100.65.80.39", - Direction: firewallRuleDirectionOUT, + Direction: types.FirewallRuleDirectionOUT, Action: "accept", Protocol: "all", Port: "", @@ -446,13 +446,13 @@ func TestAccount_getPeersByPolicyDirect(t *testing.T) { account.Policies[1].Rules[0].Bidirectional = false t.Run("check first peer map directional only", func(t 
*testing.T) { - peers, firewallRules := account.getPeerConnectionResources(context.Background(), "peerB", approvedPeers) + peers, firewallRules := account.GetPeerConnectionResources(context.Background(), "peerB", approvedPeers) assert.Contains(t, peers, account.Peers["peerC"]) - epectedFirewallRules := []*FirewallRule{ + epectedFirewallRules := []*types.FirewallRule{ { PeerIP: "100.65.254.139", - Direction: firewallRuleDirectionOUT, + Direction: types.FirewallRuleDirectionOUT, Action: "accept", Protocol: "all", Port: "", @@ -467,13 +467,13 @@ func TestAccount_getPeersByPolicyDirect(t *testing.T) { }) t.Run("check second peer map directional only", func(t *testing.T) { - peers, firewallRules := account.getPeerConnectionResources(context.Background(), "peerC", approvedPeers) + peers, firewallRules := account.GetPeerConnectionResources(context.Background(), "peerC", approvedPeers) assert.Contains(t, peers, account.Peers["peerB"]) - epectedFirewallRules := []*FirewallRule{ + epectedFirewallRules := []*types.FirewallRule{ { PeerIP: "100.65.80.39", - Direction: firewallRuleDirectionIN, + Direction: types.FirewallRuleDirectionIN, Action: "accept", Protocol: "all", Port: "", @@ -489,7 +489,7 @@ func TestAccount_getPeersByPolicyDirect(t *testing.T) { } func TestAccount_getPeersByPolicyPostureChecks(t *testing.T) { - account := &Account{ + account := &types.Account{ Peers: map[string]*nbpeer.Peer{ "peerA": { ID: "peerA", @@ -582,7 +582,7 @@ func TestAccount_getPeersByPolicyPostureChecks(t *testing.T) { }, }, }, - Groups: map[string]*nbgroup.Group{ + Groups: map[string]*types.Group{ "GroupAll": { ID: "GroupAll", Name: "All", @@ -630,17 +630,17 @@ func TestAccount_getPeersByPolicyPostureChecks(t *testing.T) { }, } - account.Policies = append(account.Policies, &Policy{ + account.Policies = append(account.Policies, &types.Policy{ ID: "PolicyPostureChecks", Name: "", Description: "This is the policy with posture checks applied", Enabled: true, - Rules: []*PolicyRule{ + Rules: []*types.PolicyRule{ { ID: "RuleSwarm", Name: "Swarm", Enabled: true, - Action: PolicyTrafficActionAccept, + Action: types.PolicyTrafficActionAccept, Destinations: []string{ "GroupSwarm", }, @@ -648,7 +648,7 @@ func TestAccount_getPeersByPolicyPostureChecks(t *testing.T) { "GroupAll", }, Bidirectional: false, - Protocol: PolicyRuleProtocolTCP, + Protocol: types.PolicyRuleProtocolTCP, Ports: []string{"80"}, }, }, @@ -664,7 +664,7 @@ func TestAccount_getPeersByPolicyPostureChecks(t *testing.T) { t.Run("verify peer's network map with default group peer list", func(t *testing.T) { // peerB doesn't fulfill the NB posture check but is included in the destination group Swarm, // will establish a connection with all source peers satisfying the NB posture check. 
- peers, firewallRules := account.getPeerConnectionResources(context.Background(), "peerB", approvedPeers) + peers, firewallRules := account.GetPeerConnectionResources(context.Background(), "peerB", approvedPeers) assert.Len(t, peers, 4) assert.Len(t, firewallRules, 4) assert.Contains(t, peers, account.Peers["peerA"]) @@ -674,13 +674,13 @@ func TestAccount_getPeersByPolicyPostureChecks(t *testing.T) { // peerC satisfy the NB posture check, should establish connection to all destination group peer's // We expect a single permissive firewall rule which all outgoing connections - peers, firewallRules = account.getPeerConnectionResources(context.Background(), "peerC", approvedPeers) + peers, firewallRules = account.GetPeerConnectionResources(context.Background(), "peerC", approvedPeers) assert.Len(t, peers, len(account.Groups["GroupSwarm"].Peers)) assert.Len(t, firewallRules, 1) - expectedFirewallRules := []*FirewallRule{ + expectedFirewallRules := []*types.FirewallRule{ { PeerIP: "0.0.0.0", - Direction: firewallRuleDirectionOUT, + Direction: types.FirewallRuleDirectionOUT, Action: "accept", Protocol: "tcp", Port: "80", @@ -690,7 +690,7 @@ func TestAccount_getPeersByPolicyPostureChecks(t *testing.T) { // peerE doesn't fulfill the NB posture check and exists in only destination group Swarm, // all source group peers satisfying the NB posture check should establish connection - peers, firewallRules = account.getPeerConnectionResources(context.Background(), "peerE", approvedPeers) + peers, firewallRules = account.GetPeerConnectionResources(context.Background(), "peerE", approvedPeers) assert.Len(t, peers, 4) assert.Len(t, firewallRules, 4) assert.Contains(t, peers, account.Peers["peerA"]) @@ -700,7 +700,7 @@ func TestAccount_getPeersByPolicyPostureChecks(t *testing.T) { // peerI doesn't fulfill the OS version posture check and exists in only destination group Swarm, // all source group peers satisfying the NB posture check should establish connection - peers, firewallRules = account.getPeerConnectionResources(context.Background(), "peerI", approvedPeers) + peers, firewallRules = account.GetPeerConnectionResources(context.Background(), "peerI", approvedPeers) assert.Len(t, peers, 4) assert.Len(t, firewallRules, 4) assert.Contains(t, peers, account.Peers["peerA"]) @@ -715,19 +715,19 @@ func TestAccount_getPeersByPolicyPostureChecks(t *testing.T) { // peerB doesn't satisfy the NB posture check, and doesn't exist in destination group peer's // no connection should be established to any peer of destination group - peers, firewallRules := account.getPeerConnectionResources(context.Background(), "peerB", approvedPeers) + peers, firewallRules := account.GetPeerConnectionResources(context.Background(), "peerB", approvedPeers) assert.Len(t, peers, 0) assert.Len(t, firewallRules, 0) // peerI doesn't satisfy the OS version posture check, and doesn't exist in destination group peer's // no connection should be established to any peer of destination group - peers, firewallRules = account.getPeerConnectionResources(context.Background(), "peerI", approvedPeers) + peers, firewallRules = account.GetPeerConnectionResources(context.Background(), "peerI", approvedPeers) assert.Len(t, peers, 0) assert.Len(t, firewallRules, 0) // peerC satisfy the NB posture check, should establish connection to all destination group peer's // We expect a single permissive firewall rule which all outgoing connections - peers, firewallRules = account.getPeerConnectionResources(context.Background(), "peerC", approvedPeers) + peers, 
firewallRules = account.GetPeerConnectionResources(context.Background(), "peerC", approvedPeers) assert.Len(t, peers, len(account.Groups["GroupSwarm"].Peers)) assert.Len(t, firewallRules, len(account.Groups["GroupSwarm"].Peers)) @@ -742,14 +742,14 @@ func TestAccount_getPeersByPolicyPostureChecks(t *testing.T) { // peerE doesn't fulfill the NB posture check and exists in only destination group Swarm, // all source group peers satisfying the NB posture check should establish connection - peers, firewallRules = account.getPeerConnectionResources(context.Background(), "peerE", approvedPeers) + peers, firewallRules = account.GetPeerConnectionResources(context.Background(), "peerE", approvedPeers) assert.Len(t, peers, 3) assert.Len(t, firewallRules, 3) assert.Contains(t, peers, account.Peers["peerA"]) assert.Contains(t, peers, account.Peers["peerC"]) assert.Contains(t, peers, account.Peers["peerD"]) - peers, firewallRules = account.getPeerConnectionResources(context.Background(), "peerA", approvedPeers) + peers, firewallRules = account.GetPeerConnectionResources(context.Background(), "peerA", approvedPeers) assert.Len(t, peers, 5) // assert peers from Group Swarm assert.Contains(t, peers, account.Peers["peerD"]) @@ -760,45 +760,45 @@ func TestAccount_getPeersByPolicyPostureChecks(t *testing.T) { // assert peers from Group All assert.Contains(t, peers, account.Peers["peerC"]) - expectedFirewallRules := []*FirewallRule{ + expectedFirewallRules := []*types.FirewallRule{ { PeerIP: "100.65.62.5", - Direction: firewallRuleDirectionOUT, + Direction: types.FirewallRuleDirectionOUT, Action: "accept", Protocol: "tcp", Port: "80", }, { PeerIP: "100.65.32.206", - Direction: firewallRuleDirectionOUT, + Direction: types.FirewallRuleDirectionOUT, Action: "accept", Protocol: "tcp", Port: "80", }, { PeerIP: "100.65.13.186", - Direction: firewallRuleDirectionOUT, + Direction: types.FirewallRuleDirectionOUT, Action: "accept", Protocol: "tcp", Port: "80", }, { PeerIP: "100.65.29.55", - Direction: firewallRuleDirectionOUT, + Direction: types.FirewallRuleDirectionOUT, Action: "accept", Protocol: "tcp", Port: "80", }, { PeerIP: "100.65.254.139", - Direction: firewallRuleDirectionIN, + Direction: types.FirewallRuleDirectionIN, Action: "accept", Protocol: "tcp", Port: "80", }, { PeerIP: "100.65.62.5", - Direction: firewallRuleDirectionIN, + Direction: types.FirewallRuleDirectionIN, Action: "accept", Protocol: "tcp", Port: "80", @@ -809,8 +809,8 @@ func TestAccount_getPeersByPolicyPostureChecks(t *testing.T) { }) } -func sortFunc() func(a *FirewallRule, b *FirewallRule) int { - return func(a, b *FirewallRule) int { +func sortFunc() func(a *types.FirewallRule, b *types.FirewallRule) int { + return func(a, b *types.FirewallRule) int { // Concatenate PeerIP and Direction as string for comparison aStr := a.PeerIP + fmt.Sprintf("%d", a.Direction) bStr := b.PeerIP + fmt.Sprintf("%d", b.Direction) @@ -829,7 +829,7 @@ func sortFunc() func(a *FirewallRule, b *FirewallRule) int { func TestPolicyAccountPeersUpdate(t *testing.T) { manager, account, peer1, peer2, peer3 := setupNetworkMapTest(t) - err := manager.SaveGroups(context.Background(), account.Id, userID, []*nbgroup.Group{ + err := manager.SaveGroups(context.Background(), account.Id, userID, []*types.Group{ { ID: "groupA", Name: "GroupA", @@ -858,9 +858,9 @@ func TestPolicyAccountPeersUpdate(t *testing.T) { manager.peersUpdateManager.CloseChannel(context.Background(), peer1.ID) }) - var policyWithGroupRulesNoPeers *Policy - var policyWithDestinationPeersOnly *Policy - var 
policyWithSourceAndDestinationPeers *Policy + var policyWithGroupRulesNoPeers *types.Policy + var policyWithDestinationPeersOnly *types.Policy + var policyWithSourceAndDestinationPeers *types.Policy // Saving policy with rule groups with no peers should not update account's peers and not send peer update t.Run("saving policy with rule groups with no peers", func(t *testing.T) { @@ -870,16 +870,16 @@ func TestPolicyAccountPeersUpdate(t *testing.T) { close(done) }() - policyWithGroupRulesNoPeers, err = manager.SavePolicy(context.Background(), account.Id, userID, &Policy{ + policyWithGroupRulesNoPeers, err = manager.SavePolicy(context.Background(), account.Id, userID, &types.Policy{ AccountID: account.Id, Enabled: true, - Rules: []*PolicyRule{ + Rules: []*types.PolicyRule{ { Enabled: true, Sources: []string{"groupB"}, Destinations: []string{"groupC"}, Bidirectional: true, - Action: PolicyTrafficActionAccept, + Action: types.PolicyTrafficActionAccept, }, }, }) @@ -901,17 +901,17 @@ func TestPolicyAccountPeersUpdate(t *testing.T) { close(done) }() - _, err = manager.SavePolicy(context.Background(), account.Id, userID, &Policy{ + _, err = manager.SavePolicy(context.Background(), account.Id, userID, &types.Policy{ AccountID: account.Id, Enabled: true, - Rules: []*PolicyRule{ + Rules: []*types.PolicyRule{ { Enabled: true, Sources: []string{"groupA"}, Destinations: []string{"groupB"}, - Protocol: PolicyRuleProtocolTCP, + Protocol: types.PolicyRuleProtocolTCP, Bidirectional: true, - Action: PolicyTrafficActionAccept, + Action: types.PolicyTrafficActionAccept, }, }, }) @@ -933,17 +933,17 @@ func TestPolicyAccountPeersUpdate(t *testing.T) { close(done) }() - policyWithDestinationPeersOnly, err = manager.SavePolicy(context.Background(), account.Id, userID, &Policy{ + policyWithDestinationPeersOnly, err = manager.SavePolicy(context.Background(), account.Id, userID, &types.Policy{ AccountID: account.Id, Enabled: true, - Rules: []*PolicyRule{ + Rules: []*types.PolicyRule{ { Enabled: true, Sources: []string{"groupC"}, Destinations: []string{"groupD"}, Bidirectional: true, - Protocol: PolicyRuleProtocolTCP, - Action: PolicyTrafficActionAccept, + Protocol: types.PolicyRuleProtocolTCP, + Action: types.PolicyTrafficActionAccept, }, }, }) @@ -965,16 +965,16 @@ func TestPolicyAccountPeersUpdate(t *testing.T) { close(done) }() - policyWithSourceAndDestinationPeers, err = manager.SavePolicy(context.Background(), account.Id, userID, &Policy{ + policyWithSourceAndDestinationPeers, err = manager.SavePolicy(context.Background(), account.Id, userID, &types.Policy{ AccountID: account.Id, Enabled: true, - Rules: []*PolicyRule{ + Rules: []*types.PolicyRule{ { Enabled: true, Sources: []string{"groupA"}, Destinations: []string{"groupD"}, Bidirectional: true, - Action: PolicyTrafficActionAccept, + Action: types.PolicyTrafficActionAccept, }, }, }) diff --git a/management/server/posture_checks.go b/management/server/posture_checks.go index 0467efedb..1690f8e33 100644 --- a/management/server/posture_checks.go +++ b/management/server/posture_checks.go @@ -12,10 +12,12 @@ import ( "github.com/netbirdio/netbird/management/server/activity" "github.com/netbirdio/netbird/management/server/posture" "github.com/netbirdio/netbird/management/server/status" + "github.com/netbirdio/netbird/management/server/store" + "github.com/netbirdio/netbird/management/server/types" ) func (am *DefaultAccountManager) GetPostureChecks(ctx context.Context, accountID, postureChecksID, userID string) (*posture.Checks, error) { - user, err := 
am.Store.GetUserByUserID(ctx, LockingStrengthShare, userID) + user, err := am.Store.GetUserByUserID(ctx, store.LockingStrengthShare, userID) if err != nil { return nil, err } @@ -28,7 +30,7 @@ func (am *DefaultAccountManager) GetPostureChecks(ctx context.Context, accountID return nil, status.NewAdminPermissionError() } - return am.Store.GetPostureChecksByID(ctx, LockingStrengthShare, accountID, postureChecksID) + return am.Store.GetPostureChecksByID(ctx, store.LockingStrengthShare, accountID, postureChecksID) } // SavePostureChecks saves a posture check. @@ -36,7 +38,7 @@ func (am *DefaultAccountManager) SavePostureChecks(ctx context.Context, accountI unlock := am.Store.AcquireWriteLockByUID(ctx, accountID) defer unlock() - user, err := am.Store.GetUserByUserID(ctx, LockingStrengthShare, userID) + user, err := am.Store.GetUserByUserID(ctx, store.LockingStrengthShare, userID) if err != nil { return nil, err } @@ -53,7 +55,7 @@ func (am *DefaultAccountManager) SavePostureChecks(ctx context.Context, accountI var isUpdate = postureChecks.ID != "" var action = activity.PostureCheckCreated - err = am.Store.ExecuteInTransaction(ctx, func(transaction Store) error { + err = am.Store.ExecuteInTransaction(ctx, func(transaction store.Store) error { if err = validatePostureChecks(ctx, transaction, accountID, postureChecks); err != nil { return err } @@ -64,7 +66,7 @@ func (am *DefaultAccountManager) SavePostureChecks(ctx context.Context, accountI return err } - if err = transaction.IncrementNetworkSerial(ctx, LockingStrengthUpdate, accountID); err != nil { + if err = transaction.IncrementNetworkSerial(ctx, store.LockingStrengthUpdate, accountID); err != nil { return err } @@ -72,7 +74,7 @@ func (am *DefaultAccountManager) SavePostureChecks(ctx context.Context, accountI } postureChecks.AccountID = accountID - return transaction.SavePostureChecks(ctx, LockingStrengthUpdate, postureChecks) + return transaction.SavePostureChecks(ctx, store.LockingStrengthUpdate, postureChecks) }) if err != nil { return nil, err @@ -81,7 +83,7 @@ func (am *DefaultAccountManager) SavePostureChecks(ctx context.Context, accountI am.StoreEvent(ctx, userID, postureChecks.ID, accountID, action, postureChecks.EventMeta()) if updateAccountPeers { - am.updateAccountPeers(ctx, accountID) + am.UpdateAccountPeers(ctx, accountID) } return postureChecks, nil @@ -92,7 +94,7 @@ func (am *DefaultAccountManager) DeletePostureChecks(ctx context.Context, accoun unlock := am.Store.AcquireWriteLockByUID(ctx, accountID) defer unlock() - user, err := am.Store.GetUserByUserID(ctx, LockingStrengthShare, userID) + user, err := am.Store.GetUserByUserID(ctx, store.LockingStrengthShare, userID) if err != nil { return err } @@ -107,8 +109,8 @@ func (am *DefaultAccountManager) DeletePostureChecks(ctx context.Context, accoun var postureChecks *posture.Checks - err = am.Store.ExecuteInTransaction(ctx, func(transaction Store) error { - postureChecks, err = transaction.GetPostureChecksByID(ctx, LockingStrengthShare, accountID, postureChecksID) + err = am.Store.ExecuteInTransaction(ctx, func(transaction store.Store) error { + postureChecks, err = transaction.GetPostureChecksByID(ctx, store.LockingStrengthShare, accountID, postureChecksID) if err != nil { return err } @@ -117,11 +119,11 @@ func (am *DefaultAccountManager) DeletePostureChecks(ctx context.Context, accoun return err } - if err = transaction.IncrementNetworkSerial(ctx, LockingStrengthUpdate, accountID); err != nil { + if err = transaction.IncrementNetworkSerial(ctx, store.LockingStrengthUpdate, 
accountID); err != nil { return err } - return transaction.DeletePostureChecks(ctx, LockingStrengthUpdate, accountID, postureChecksID) + return transaction.DeletePostureChecks(ctx, store.LockingStrengthUpdate, accountID, postureChecksID) }) if err != nil { return err @@ -134,7 +136,7 @@ func (am *DefaultAccountManager) DeletePostureChecks(ctx context.Context, accoun // ListPostureChecks returns a list of posture checks. func (am *DefaultAccountManager) ListPostureChecks(ctx context.Context, accountID, userID string) ([]*posture.Checks, error) { - user, err := am.Store.GetUserByUserID(ctx, LockingStrengthShare, userID) + user, err := am.Store.GetUserByUserID(ctx, store.LockingStrengthShare, userID) if err != nil { return nil, err } @@ -147,11 +149,11 @@ func (am *DefaultAccountManager) ListPostureChecks(ctx context.Context, accountI return nil, status.NewAdminPermissionError() } - return am.Store.GetAccountPostureChecks(ctx, LockingStrengthShare, accountID) + return am.Store.GetAccountPostureChecks(ctx, store.LockingStrengthShare, accountID) } // getPeerPostureChecks returns the posture checks applied for a given peer. -func (am *DefaultAccountManager) getPeerPostureChecks(account *Account, peerID string) ([]*posture.Checks, error) { +func (am *DefaultAccountManager) getPeerPostureChecks(account *types.Account, peerID string) ([]*posture.Checks, error) { peerPostureChecks := make(map[string]*posture.Checks) if len(account.PostureChecks) == 0 { @@ -172,15 +174,15 @@ func (am *DefaultAccountManager) getPeerPostureChecks(account *Account, peerID s } // arePostureCheckChangesAffectPeers checks if the changes in posture checks are affecting peers. -func arePostureCheckChangesAffectPeers(ctx context.Context, transaction Store, accountID, postureCheckID string) (bool, error) { - policies, err := transaction.GetAccountPolicies(ctx, LockingStrengthShare, accountID) +func arePostureCheckChangesAffectPeers(ctx context.Context, transaction store.Store, accountID, postureCheckID string) (bool, error) { + policies, err := transaction.GetAccountPolicies(ctx, store.LockingStrengthShare, accountID) if err != nil { return false, err } for _, policy := range policies { if slices.Contains(policy.SourcePostureChecks, postureCheckID) { - hasPeers, err := anyGroupHasPeers(ctx, transaction, accountID, policy.ruleGroups()) + hasPeers, err := anyGroupHasPeersOrResources(ctx, transaction, accountID, policy.RuleGroups()) if err != nil { return false, err } @@ -195,21 +197,21 @@ func arePostureCheckChangesAffectPeers(ctx context.Context, transaction Store, a } // validatePostureChecks validates the posture checks. -func validatePostureChecks(ctx context.Context, transaction Store, accountID string, postureChecks *posture.Checks) error { +func validatePostureChecks(ctx context.Context, transaction store.Store, accountID string, postureChecks *posture.Checks) error { if err := postureChecks.Validate(); err != nil { return status.Errorf(status.InvalidArgument, err.Error()) //nolint } // If the posture check already has an ID, verify its existence in the store. if postureChecks.ID != "" { - if _, err := transaction.GetPostureChecksByID(ctx, LockingStrengthShare, accountID, postureChecks.ID); err != nil { + if _, err := transaction.GetPostureChecksByID(ctx, store.LockingStrengthShare, accountID, postureChecks.ID); err != nil { return err } return nil } // For new posture checks, ensure no duplicates by name. 
- checks, err := transaction.GetAccountPostureChecks(ctx, LockingStrengthShare, accountID) + checks, err := transaction.GetAccountPostureChecks(ctx, store.LockingStrengthShare, accountID) if err != nil { return err } @@ -226,7 +228,7 @@ func validatePostureChecks(ctx context.Context, transaction Store, accountID str } // addPolicyPostureChecks adds posture checks from a policy to the peer posture checks map if the peer is in the policy's source groups. -func addPolicyPostureChecks(account *Account, peerID string, policy *Policy, peerPostureChecks map[string]*posture.Checks) error { +func addPolicyPostureChecks(account *types.Account, peerID string, policy *types.Policy, peerPostureChecks map[string]*posture.Checks) error { isInGroup, err := isPeerInPolicySourceGroups(account, peerID, policy) if err != nil { return err @@ -237,7 +239,7 @@ func addPolicyPostureChecks(account *Account, peerID string, policy *Policy, pee } for _, sourcePostureCheckID := range policy.SourcePostureChecks { - postureCheck := account.getPostureChecks(sourcePostureCheckID) + postureCheck := account.GetPostureChecks(sourcePostureCheckID) if postureCheck == nil { return errors.New("failed to add policy posture checks: posture checks not found") } @@ -248,7 +250,7 @@ func addPolicyPostureChecks(account *Account, peerID string, policy *Policy, pee } // isPeerInPolicySourceGroups checks if a peer is present in any of the policy rule source groups. -func isPeerInPolicySourceGroups(account *Account, peerID string, policy *Policy) (bool, error) { +func isPeerInPolicySourceGroups(account *types.Account, peerID string, policy *types.Policy) (bool, error) { for _, rule := range policy.Rules { if !rule.Enabled { continue @@ -270,8 +272,8 @@ func isPeerInPolicySourceGroups(account *Account, peerID string, policy *Policy) } // isPostureCheckLinkedToPolicy checks whether the posture check is linked to any account policy. 
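// The loop body of isPostureCheckLinkedToPolicy is elided by the hunk below; a minimal
// sketch of such a linkage check is given here, assuming the link is expressed through
// Policy.SourcePostureChecks (as arePostureCheckChangesAffectPeers above does) and that
// the fmt and slices packages are imported. The "Sketch" name is illustrative only and
// is not part of this patch.
func isPostureCheckLinkedToPolicySketch(ctx context.Context, transaction store.Store, postureChecksID, accountID string) error {
	policies, err := transaction.GetAccountPolicies(ctx, store.LockingStrengthShare, accountID)
	if err != nil {
		return err
	}
	for _, policy := range policies {
		// a posture check referenced as a policy rule source is considered linked
		if slices.Contains(policy.SourcePostureChecks, postureChecksID) {
			return fmt.Errorf("posture check %s is linked to policy %s", postureChecksID, policy.Name)
		}
	}
	return nil
}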
-func isPostureCheckLinkedToPolicy(ctx context.Context, transaction Store, postureChecksID, accountID string) error { - policies, err := transaction.GetAccountPolicies(ctx, LockingStrengthShare, accountID) +func isPostureCheckLinkedToPolicy(ctx context.Context, transaction store.Store, postureChecksID, accountID string) error { + policies, err := transaction.GetAccountPolicies(ctx, store.LockingStrengthShare, accountID) if err != nil { return err } diff --git a/management/server/posture_checks_test.go b/management/server/posture_checks_test.go index 93e5741cf..bad162f05 100644 --- a/management/server/posture_checks_test.go +++ b/management/server/posture_checks_test.go @@ -8,7 +8,8 @@ import ( "github.com/stretchr/testify/assert" "github.com/stretchr/testify/require" - "github.com/netbirdio/netbird/management/server/group" + "github.com/netbirdio/netbird/management/server/store" + "github.com/netbirdio/netbird/management/server/types" "github.com/netbirdio/netbird/management/server/posture" ) @@ -92,17 +93,17 @@ func TestDefaultAccountManager_PostureCheck(t *testing.T) { }) } -func initTestPostureChecksAccount(am *DefaultAccountManager) (*Account, error) { +func initTestPostureChecksAccount(am *DefaultAccountManager) (*types.Account, error) { accountID := "testingAccount" domain := "example.com" - admin := &User{ + admin := &types.User{ Id: adminUserID, - Role: UserRoleAdmin, + Role: types.UserRoleAdmin, } - user := &User{ + user := &types.User{ Id: regularUserID, - Role: UserRoleUser, + Role: types.UserRoleUser, } account := newAccountWithId(context.Background(), accountID, groupAdminUserID, domain) @@ -120,7 +121,7 @@ func initTestPostureChecksAccount(am *DefaultAccountManager) (*Account, error) { func TestPostureCheckAccountPeersUpdate(t *testing.T) { manager, account, peer1, peer2, peer3 := setupNetworkMapTest(t) - err := manager.SaveGroups(context.Background(), account.Id, userID, []*group.Group{ + err := manager.SaveGroups(context.Background(), account.Id, userID, []*types.Group{ { ID: "groupA", Name: "GroupA", @@ -209,15 +210,15 @@ func TestPostureCheckAccountPeersUpdate(t *testing.T) { } }) - policy := &Policy{ + policy := &types.Policy{ Enabled: true, - Rules: []*PolicyRule{ + Rules: []*types.PolicyRule{ { Enabled: true, Sources: []string{"groupA"}, Destinations: []string{"groupA"}, Bidirectional: true, - Action: PolicyTrafficActionAccept, + Action: types.PolicyTrafficActionAccept, }, }, SourcePostureChecks: []string{postureCheckB.ID}, @@ -312,15 +313,15 @@ func TestPostureCheckAccountPeersUpdate(t *testing.T) { // Updating linked posture check to policy with no peers should not trigger account peers update and not send peer update t.Run("updating linked posture check to policy with no peers", func(t *testing.T) { - _, err = manager.SavePolicy(context.Background(), account.Id, userID, &Policy{ + _, err = manager.SavePolicy(context.Background(), account.Id, userID, &types.Policy{ Enabled: true, - Rules: []*PolicyRule{ + Rules: []*types.PolicyRule{ { Enabled: true, Sources: []string{"groupB"}, Destinations: []string{"groupC"}, Bidirectional: true, - Action: PolicyTrafficActionAccept, + Action: types.PolicyTrafficActionAccept, }, }, SourcePostureChecks: []string{postureCheckB.ID}, @@ -356,15 +357,15 @@ func TestPostureCheckAccountPeersUpdate(t *testing.T) { manager.peersUpdateManager.CloseChannel(context.Background(), peer2.ID) }) - _, err = manager.SavePolicy(context.Background(), account.Id, userID, &Policy{ + _, err = manager.SavePolicy(context.Background(), account.Id, userID, 
&types.Policy{ Enabled: true, - Rules: []*PolicyRule{ + Rules: []*types.PolicyRule{ { Enabled: true, Sources: []string{"groupB"}, Destinations: []string{"groupA"}, Bidirectional: true, - Action: PolicyTrafficActionAccept, + Action: types.PolicyTrafficActionAccept, }, }, SourcePostureChecks: []string{postureCheckB.ID}, @@ -395,15 +396,15 @@ func TestPostureCheckAccountPeersUpdate(t *testing.T) { // Updating linked client posture check to policy where source has peers but destination does not, // should trigger account peers update and send peer update t.Run("updating linked posture check to policy where source has peers but destination does not", func(t *testing.T) { - _, err = manager.SavePolicy(context.Background(), account.Id, userID, &Policy{ + _, err = manager.SavePolicy(context.Background(), account.Id, userID, &types.Policy{ Enabled: true, - Rules: []*PolicyRule{ + Rules: []*types.PolicyRule{ { Enabled: true, Sources: []string{"groupA"}, Destinations: []string{"groupB"}, Bidirectional: true, - Action: PolicyTrafficActionAccept, + Action: types.PolicyTrafficActionAccept, }, }, SourcePostureChecks: []string{postureCheckB.ID}, @@ -443,18 +444,18 @@ func TestArePostureCheckChangesAffectPeers(t *testing.T) { account, err := initTestPostureChecksAccount(manager) require.NoError(t, err, "failed to init testing account") - groupA := &group.Group{ + groupA := &types.Group{ ID: "groupA", AccountID: account.Id, Peers: []string{"peer1"}, } - groupB := &group.Group{ + groupB := &types.Group{ ID: "groupB", AccountID: account.Id, Peers: []string{}, } - err = manager.Store.SaveGroups(context.Background(), LockingStrengthUpdate, []*group.Group{groupA, groupB}) + err = manager.Store.SaveGroups(context.Background(), store.LockingStrengthUpdate, []*types.Group{groupA, groupB}) require.NoError(t, err, "failed to save groups") postureCheckA := &posture.Checks{ @@ -477,9 +478,9 @@ func TestArePostureCheckChangesAffectPeers(t *testing.T) { postureCheckB, err = manager.SavePostureChecks(context.Background(), account.Id, adminUserID, postureCheckB) require.NoError(t, err, "failed to save postureCheckB") - policy := &Policy{ + policy := &types.Policy{ AccountID: account.Id, - Rules: []*PolicyRule{ + Rules: []*types.PolicyRule{ { Enabled: true, Sources: []string{"groupA"}, @@ -534,7 +535,7 @@ func TestArePostureCheckChangesAffectPeers(t *testing.T) { t.Run("posture check is linked to policy but no peers in groups", func(t *testing.T) { groupA.Peers = []string{} - err = manager.Store.SaveGroup(context.Background(), LockingStrengthUpdate, groupA) + err = manager.Store.SaveGroup(context.Background(), store.LockingStrengthUpdate, groupA) require.NoError(t, err, "failed to save groups") result, err := arePostureCheckChangesAffectPeers(context.Background(), manager.Store, account.Id, postureCheckA.ID) diff --git a/management/server/resource.go b/management/server/resource.go new file mode 100644 index 000000000..77a5612b3 --- /dev/null +++ b/management/server/resource.go @@ -0,0 +1,21 @@ +package server + +type ResourceType string + +const ( + // nolint + hostType ResourceType = "Host" + //nolint + subnetType ResourceType = "Subnet" + // nolint + domainType ResourceType = "Domain" +) + +func (p ResourceType) String() string { + return string(p) +} + +type Resource struct { + Type ResourceType + ID string +} diff --git a/management/server/route.go b/management/server/route.go index 23bea87e3..1eb51aea7 100644 --- a/management/server/route.go +++ b/management/server/route.go @@ -4,15 +4,12 @@ import ( "context" "fmt" 
"net/netip" - "slices" - "strconv" - "strings" "unicode/utf8" "github.com/rs/xid" - log "github.com/sirupsen/logrus" - nbpeer "github.com/netbirdio/netbird/management/server/peer" + "github.com/netbirdio/netbird/management/server/store" + "github.com/netbirdio/netbird/management/server/types" "github.com/netbirdio/netbird/management/domain" "github.com/netbirdio/netbird/management/proto" @@ -21,33 +18,9 @@ import ( "github.com/netbirdio/netbird/route" ) -// RouteFirewallRule a firewall rule applicable for a routed network. -type RouteFirewallRule struct { - // SourceRanges IP ranges of the routing peers. - SourceRanges []string - - // Action of the traffic when the rule is applicable - Action string - - // Destination a network prefix for the routed traffic - Destination string - - // Protocol of the traffic - Protocol string - - // Port of the traffic - Port uint16 - - // PortRange represents the range of ports for a firewall rule - PortRange RulePortRange - - // isDynamic indicates whether the rule is for DNS routing - IsDynamic bool -} - // GetRoute gets a route object from account and route IDs func (am *DefaultAccountManager) GetRoute(ctx context.Context, accountID string, routeID route.ID, userID string) (*route.Route, error) { - user, err := am.Store.GetUserByUserID(ctx, LockingStrengthShare, userID) + user, err := am.Store.GetUserByUserID(ctx, store.LockingStrengthShare, userID) if err != nil { return nil, err } @@ -56,11 +29,11 @@ func (am *DefaultAccountManager) GetRoute(ctx context.Context, accountID string, return nil, status.Errorf(status.PermissionDenied, "only users with admin power can view Network Routes") } - return am.Store.GetRouteByID(ctx, LockingStrengthShare, string(routeID), accountID) + return am.Store.GetRouteByID(ctx, store.LockingStrengthShare, string(routeID), accountID) } // checkRoutePrefixOrDomainsExistForPeers checks if a route with a given prefix exists for a single peer or multiple peer groups. 
-func (am *DefaultAccountManager) checkRoutePrefixOrDomainsExistForPeers(account *Account, peerID string, routeID route.ID, peerGroupIDs []string, prefix netip.Prefix, domains domain.List) error { +func (am *DefaultAccountManager) checkRoutePrefixOrDomainsExistForPeers(account *types.Account, peerID string, routeID route.ID, peerGroupIDs []string, prefix netip.Prefix, domains domain.List) error { // routes can have both peer and peer_groups routesWithPrefix := account.GetRoutesByPrefixOrDomains(prefix, domains) @@ -238,7 +211,7 @@ func (am *DefaultAccountManager) CreateRoute(ctx context.Context, accountID stri } if am.isRouteChangeAffectPeers(account, &newRoute) { - am.updateAccountPeers(ctx, accountID) + am.UpdateAccountPeers(ctx, accountID) } am.StoreEvent(ctx, userID, string(newRoute.ID), accountID, activity.RouteCreated, newRoute.EventMeta()) @@ -324,7 +297,7 @@ func (am *DefaultAccountManager) SaveRoute(ctx context.Context, accountID, userI } if am.isRouteChangeAffectPeers(account, oldRoute) || am.isRouteChangeAffectPeers(account, routeToSave) { - am.updateAccountPeers(ctx, accountID) + am.UpdateAccountPeers(ctx, accountID) } am.StoreEvent(ctx, userID, string(routeToSave.ID), accountID, activity.RouteUpdated, routeToSave.EventMeta()) @@ -356,7 +329,7 @@ func (am *DefaultAccountManager) DeleteRoute(ctx context.Context, accountID stri am.StoreEvent(ctx, userID, string(routy.ID), accountID, activity.RouteRemoved, routy.EventMeta()) if am.isRouteChangeAffectPeers(account, routy) { - am.updateAccountPeers(ctx, accountID) + am.UpdateAccountPeers(ctx, accountID) } return nil @@ -364,7 +337,7 @@ func (am *DefaultAccountManager) DeleteRoute(ctx context.Context, accountID stri // ListRoutes returns a list of routes from account func (am *DefaultAccountManager) ListRoutes(ctx context.Context, accountID, userID string) ([]*route.Route, error) { - user, err := am.Store.GetUserByUserID(ctx, LockingStrengthShare, userID) + user, err := am.Store.GetUserByUserID(ctx, store.LockingStrengthShare, userID) if err != nil { return nil, err } @@ -373,7 +346,7 @@ func (am *DefaultAccountManager) ListRoutes(ctx context.Context, accountID, user return nil, status.Errorf(status.PermissionDenied, "only users with admin power can view Network Routes") } - return am.Store.GetAccountRoutes(ctx, LockingStrengthShare, accountID) + return am.Store.GetAccountRoutes(ctx, store.LockingStrengthShare, accountID) } func toProtocolRoute(route *route.Route) *proto.Route { @@ -404,244 +377,7 @@ func getPlaceholderIP() netip.Prefix { return netip.PrefixFrom(netip.AddrFrom4([4]byte{192, 0, 2, 0}), 32) } -// getPeerRoutesFirewallRules gets the routes firewall rules associated with a routing peer ID for the account. -func (a *Account) getPeerRoutesFirewallRules(ctx context.Context, peerID string, validatedPeersMap map[string]struct{}) []*RouteFirewallRule { - routesFirewallRules := make([]*RouteFirewallRule, 0, len(a.Routes)) - - enabledRoutes, _ := a.getRoutingPeerRoutes(ctx, peerID) - for _, route := range enabledRoutes { - // If no access control groups are specified, accept all traffic. - if len(route.AccessControlGroups) == 0 { - defaultPermit := getDefaultPermit(route) - routesFirewallRules = append(routesFirewallRules, defaultPermit...) 
- continue - } - - distributionPeers := a.getDistributionGroupsPeers(route) - - for _, accessGroup := range route.AccessControlGroups { - policies := getAllRoutePoliciesFromGroups(a, []string{accessGroup}) - rules := a.getRouteFirewallRules(ctx, peerID, policies, route, validatedPeersMap, distributionPeers) - routesFirewallRules = append(routesFirewallRules, rules...) - } - } - - return routesFirewallRules -} - -func (a *Account) getRouteFirewallRules(ctx context.Context, peerID string, policies []*Policy, route *route.Route, validatedPeersMap map[string]struct{}, distributionPeers map[string]struct{}) []*RouteFirewallRule { - var fwRules []*RouteFirewallRule - for _, policy := range policies { - if !policy.Enabled { - continue - } - - for _, rule := range policy.Rules { - if !rule.Enabled { - continue - } - - rulePeers := a.getRulePeers(rule, peerID, distributionPeers, validatedPeersMap) - rules := generateRouteFirewallRules(ctx, route, rule, rulePeers, firewallRuleDirectionIN) - fwRules = append(fwRules, rules...) - } - } - return fwRules -} - -func (a *Account) getRulePeers(rule *PolicyRule, peerID string, distributionPeers map[string]struct{}, validatedPeersMap map[string]struct{}) []*nbpeer.Peer { - distPeersWithPolicy := make(map[string]struct{}) - for _, id := range rule.Sources { - group := a.Groups[id] - if group == nil { - continue - } - - for _, pID := range group.Peers { - if pID == peerID { - continue - } - _, distPeer := distributionPeers[pID] - _, valid := validatedPeersMap[pID] - if distPeer && valid { - distPeersWithPolicy[pID] = struct{}{} - } - } - } - - distributionGroupPeers := make([]*nbpeer.Peer, 0, len(distPeersWithPolicy)) - for pID := range distPeersWithPolicy { - peer := a.Peers[pID] - if peer == nil { - continue - } - distributionGroupPeers = append(distributionGroupPeers, peer) - } - return distributionGroupPeers -} - -func (a *Account) getDistributionGroupsPeers(route *route.Route) map[string]struct{} { - distPeers := make(map[string]struct{}) - for _, id := range route.Groups { - group := a.Groups[id] - if group == nil { - continue - } - - for _, pID := range group.Peers { - distPeers[pID] = struct{}{} - } - } - return distPeers -} - -func getDefaultPermit(route *route.Route) []*RouteFirewallRule { - var rules []*RouteFirewallRule - - sources := []string{"0.0.0.0/0"} - if route.Network.Addr().Is6() { - sources = []string{"::/0"} - } - rule := RouteFirewallRule{ - SourceRanges: sources, - Action: string(PolicyTrafficActionAccept), - Destination: route.Network.String(), - Protocol: string(PolicyRuleProtocolALL), - IsDynamic: route.IsDynamic(), - } - - rules = append(rules, &rule) - - // dynamic routes always contain an IPv4 placeholder as destination, hence we must add IPv6 rules additionally - if route.IsDynamic() { - ruleV6 := rule - ruleV6.SourceRanges = []string{"::/0"} - rules = append(rules, &ruleV6) - } - - return rules -} - -// getAllRoutePoliciesFromGroups retrieves route policies associated with the specified access control groups -// and returns a list of policies that have rules with destinations matching the specified groups. 
-func getAllRoutePoliciesFromGroups(account *Account, accessControlGroups []string) []*Policy { - routePolicies := make([]*Policy, 0) - for _, groupID := range accessControlGroups { - group, ok := account.Groups[groupID] - if !ok { - continue - } - - for _, policy := range account.Policies { - for _, rule := range policy.Rules { - exist := slices.ContainsFunc(rule.Destinations, func(groupID string) bool { - return groupID == group.ID - }) - if exist { - routePolicies = append(routePolicies, policy) - continue - } - } - } - } - - return routePolicies -} - -// generateRouteFirewallRules generates a list of firewall rules for a given route. -func generateRouteFirewallRules(ctx context.Context, route *route.Route, rule *PolicyRule, groupPeers []*nbpeer.Peer, direction int) []*RouteFirewallRule { - rulesExists := make(map[string]struct{}) - rules := make([]*RouteFirewallRule, 0) - - sourceRanges := make([]string, 0, len(groupPeers)) - for _, peer := range groupPeers { - if peer == nil { - continue - } - sourceRanges = append(sourceRanges, fmt.Sprintf(AllowedIPsFormat, peer.IP)) - } - - baseRule := RouteFirewallRule{ - SourceRanges: sourceRanges, - Action: string(rule.Action), - Destination: route.Network.String(), - Protocol: string(rule.Protocol), - IsDynamic: route.IsDynamic(), - } - - // generate rule for port range - if len(rule.Ports) == 0 { - rules = append(rules, generateRulesWithPortRanges(baseRule, rule, rulesExists)...) - } else { - rules = append(rules, generateRulesWithPorts(ctx, baseRule, rule, rulesExists)...) - - } - - // TODO: generate IPv6 rules for dynamic routes - - return rules -} - -// generateRuleIDBase generates the base rule ID for checking duplicates. -func generateRuleIDBase(rule *PolicyRule, baseRule RouteFirewallRule) string { - return rule.ID + strings.Join(baseRule.SourceRanges, ",") + strconv.Itoa(firewallRuleDirectionIN) + baseRule.Protocol + baseRule.Action -} - -// generateRulesForPeer generates rules for a given peer based on ports and port ranges. -func generateRulesWithPortRanges(baseRule RouteFirewallRule, rule *PolicyRule, rulesExists map[string]struct{}) []*RouteFirewallRule { - rules := make([]*RouteFirewallRule, 0) - - ruleIDBase := generateRuleIDBase(rule, baseRule) - if len(rule.Ports) == 0 { - if len(rule.PortRanges) == 0 { - if _, ok := rulesExists[ruleIDBase]; !ok { - rulesExists[ruleIDBase] = struct{}{} - rules = append(rules, &baseRule) - } - } else { - for _, portRange := range rule.PortRanges { - ruleID := fmt.Sprintf("%s%d-%d", ruleIDBase, portRange.Start, portRange.End) - if _, ok := rulesExists[ruleID]; !ok { - rulesExists[ruleID] = struct{}{} - pr := baseRule - pr.PortRange = portRange - rules = append(rules, &pr) - } - } - } - return rules - } - - return rules -} - -// generateRulesWithPorts generates rules when specific ports are provided. 
-func generateRulesWithPorts(ctx context.Context, baseRule RouteFirewallRule, rule *PolicyRule, rulesExists map[string]struct{}) []*RouteFirewallRule { - rules := make([]*RouteFirewallRule, 0) - ruleIDBase := generateRuleIDBase(rule, baseRule) - - for _, port := range rule.Ports { - ruleID := ruleIDBase + port - if _, ok := rulesExists[ruleID]; ok { - continue - } - rulesExists[ruleID] = struct{}{} - - pr := baseRule - p, err := strconv.ParseUint(port, 10, 16) - if err != nil { - log.WithContext(ctx).Errorf("failed to parse port %s for rule: %s", port, rule.ID) - continue - } - - pr.Port = uint16(p) - rules = append(rules, &pr) - } - - return rules -} - -func toProtocolRoutesFirewallRules(rules []*RouteFirewallRule) []*proto.RouteFirewallRule { +func toProtocolRoutesFirewallRules(rules []*types.RouteFirewallRule) []*proto.RouteFirewallRule { result := make([]*proto.RouteFirewallRule, len(rules)) for i := range rules { rule := rules[i] @@ -660,7 +396,7 @@ func toProtocolRoutesFirewallRules(rules []*RouteFirewallRule) []*proto.RouteFir // getProtoDirection converts the direction to proto.RuleDirection. func getProtoDirection(direction int) proto.RuleDirection { - if direction == firewallRuleDirectionOUT { + if direction == types.FirewallRuleDirectionOUT { return proto.RuleDirection_OUT } return proto.RuleDirection_IN @@ -668,7 +404,7 @@ func getProtoDirection(direction int) proto.RuleDirection { // getProtoAction converts the action to proto.RuleAction. func getProtoAction(action string) proto.RuleAction { - if action == string(PolicyTrafficActionDrop) { + if action == string(types.PolicyTrafficActionDrop) { return proto.RuleAction_DROP } return proto.RuleAction_ACCEPT @@ -676,14 +412,14 @@ func getProtoAction(action string) proto.RuleAction { // getProtoProtocol converts the protocol to proto.RuleProtocol. func getProtoProtocol(protocol string) proto.RuleProtocol { - switch PolicyRuleProtocolType(protocol) { - case PolicyRuleProtocolALL: + switch types.PolicyRuleProtocolType(protocol) { + case types.PolicyRuleProtocolALL: return proto.RuleProtocol_ALL - case PolicyRuleProtocolTCP: + case types.PolicyRuleProtocolTCP: return proto.RuleProtocol_TCP - case PolicyRuleProtocolUDP: + case types.PolicyRuleProtocolUDP: return proto.RuleProtocol_UDP - case PolicyRuleProtocolICMP: + case types.PolicyRuleProtocolICMP: return proto.RuleProtocol_ICMP default: return proto.RuleProtocol_UNKNOWN @@ -691,7 +427,7 @@ func getProtoProtocol(protocol string) proto.RuleProtocol { } // getProtoPortInfo converts the port info to proto.PortInfo. 
-func getProtoPortInfo(rule *RouteFirewallRule) *proto.PortInfo { +func getProtoPortInfo(rule *types.RouteFirewallRule) *proto.PortInfo { var portInfo proto.PortInfo if rule.Port != 0 { portInfo.PortSelection = &proto.PortInfo_Port{Port: uint32(rule.Port)} @@ -708,6 +444,6 @@ func getProtoPortInfo(rule *RouteFirewallRule) *proto.PortInfo { // isRouteChangeAffectPeers checks if a given route affects peers by determining // if it has a routing peer, distribution, or peer groups that include peers -func (am *DefaultAccountManager) isRouteChangeAffectPeers(account *Account, route *route.Route) bool { +func (am *DefaultAccountManager) isRouteChangeAffectPeers(account *types.Account, route *route.Route) bool { return am.anyGroupHasPeers(account, route.Groups) || am.anyGroupHasPeers(account, route.PeerGroups) || route.Peer != "" } diff --git a/management/server/route_test.go b/management/server/route_test.go index 8bf9a3aeb..5390cb66b 100644 --- a/management/server/route_test.go +++ b/management/server/route_test.go @@ -13,11 +13,16 @@ import ( "github.com/stretchr/testify/assert" "github.com/stretchr/testify/require" + resourceTypes "github.com/netbirdio/netbird/management/server/networks/resources/types" + routerTypes "github.com/netbirdio/netbird/management/server/networks/routers/types" + networkTypes "github.com/netbirdio/netbird/management/server/networks/types" + "github.com/netbirdio/netbird/management/domain" "github.com/netbirdio/netbird/management/server/activity" - nbgroup "github.com/netbirdio/netbird/management/server/group" nbpeer "github.com/netbirdio/netbird/management/server/peer" + "github.com/netbirdio/netbird/management/server/store" "github.com/netbirdio/netbird/management/server/telemetry" + "github.com/netbirdio/netbird/management/server/types" "github.com/netbirdio/netbird/route" ) @@ -1092,9 +1097,9 @@ func TestGetNetworkMap_RouteSyncPeerGroups(t *testing.T) { require.NoError(t, err) assert.Len(t, peer4Routes.Routes, 1, "HA route should have 1 server route") - groups, err := am.Store.GetAccountGroups(context.Background(), LockingStrengthShare, account.Id) + groups, err := am.Store.GetAccountGroups(context.Background(), store.LockingStrengthShare, account.Id) require.NoError(t, err) - var groupHA1, groupHA2 *nbgroup.Group + var groupHA1, groupHA2 *types.Group for _, group := range groups { switch group.Name { case routeGroupHA1: @@ -1202,7 +1207,7 @@ func TestGetNetworkMap_RouteSync(t *testing.T) { require.Len(t, peer2Routes.Routes, 1, "we should receive one route") require.True(t, peer1Routes.Routes[0].IsEqual(peer2Routes.Routes[0]), "routes should be the same for peers in the same group") - newGroup := &nbgroup.Group{ + newGroup := &types.Group{ ID: xid.New().String(), Name: "peer1 group", Peers: []string{peer1ID}, @@ -1255,10 +1260,10 @@ func createRouterManager(t *testing.T) (*DefaultAccountManager, error) { return BuildManager(context.Background(), store, NewPeersUpdateManager(nil), nil, "", "netbird.selfhosted", eventStore, nil, false, MocIntegratedValidator{}, metrics) } -func createRouterStore(t *testing.T) (Store, error) { +func createRouterStore(t *testing.T) (store.Store, error) { t.Helper() dataDir := t.TempDir() - store, cleanUp, err := NewTestStoreFromSQL(context.Background(), "", dataDir) + store, cleanUp, err := store.NewTestStoreFromSQL(context.Background(), "", dataDir) if err != nil { return nil, err } @@ -1267,7 +1272,7 @@ func createRouterStore(t *testing.T) (Store, error) { return store, nil } -func initTestRouteAccount(t *testing.T, am 
*DefaultAccountManager) (*Account, error) { +func initTestRouteAccount(t *testing.T, am *DefaultAccountManager) (*types.Account, error) { t.Helper() accountID := "testingAcc" @@ -1279,8 +1284,8 @@ func initTestRouteAccount(t *testing.T, am *DefaultAccountManager) (*Account, er return nil, err } - ips := account.getTakenIPs() - peer1IP, err := AllocatePeerIP(account.Network.Net, ips) + ips := account.GetTakenIPs() + peer1IP, err := types.AllocatePeerIP(account.Network.Net, ips) if err != nil { return nil, err } @@ -1306,8 +1311,8 @@ func initTestRouteAccount(t *testing.T, am *DefaultAccountManager) (*Account, er } account.Peers[peer1.ID] = peer1 - ips = account.getTakenIPs() - peer2IP, err := AllocatePeerIP(account.Network.Net, ips) + ips = account.GetTakenIPs() + peer2IP, err := types.AllocatePeerIP(account.Network.Net, ips) if err != nil { return nil, err } @@ -1333,8 +1338,8 @@ func initTestRouteAccount(t *testing.T, am *DefaultAccountManager) (*Account, er } account.Peers[peer2.ID] = peer2 - ips = account.getTakenIPs() - peer3IP, err := AllocatePeerIP(account.Network.Net, ips) + ips = account.GetTakenIPs() + peer3IP, err := types.AllocatePeerIP(account.Network.Net, ips) if err != nil { return nil, err } @@ -1360,8 +1365,8 @@ func initTestRouteAccount(t *testing.T, am *DefaultAccountManager) (*Account, er } account.Peers[peer3.ID] = peer3 - ips = account.getTakenIPs() - peer4IP, err := AllocatePeerIP(account.Network.Net, ips) + ips = account.GetTakenIPs() + peer4IP, err := types.AllocatePeerIP(account.Network.Net, ips) if err != nil { return nil, err } @@ -1387,8 +1392,8 @@ func initTestRouteAccount(t *testing.T, am *DefaultAccountManager) (*Account, er } account.Peers[peer4.ID] = peer4 - ips = account.getTakenIPs() - peer5IP, err := AllocatePeerIP(account.Network.Net, ips) + ips = account.GetTakenIPs() + peer5IP, err := types.AllocatePeerIP(account.Network.Net, ips) if err != nil { return nil, err } @@ -1439,7 +1444,7 @@ func initTestRouteAccount(t *testing.T, am *DefaultAccountManager) (*Account, er return nil, err } - newGroup := []*nbgroup.Group{ + newGroup := []*types.Group{ { ID: routeGroup1, Name: routeGroup1, @@ -1491,7 +1496,7 @@ func TestAccount_getPeersRoutesFirewall(t *testing.T) { peerKIp = "100.65.29.66" ) - account := &Account{ + account := &types.Account{ Peers: map[string]*nbpeer.Peer{ "peerA": { ID: "peerA", @@ -1555,7 +1560,7 @@ func TestAccount_getPeersRoutesFirewall(t *testing.T) { Status: &nbpeer.PeerStatus{}, }, }, - Groups: map[string]*nbgroup.Group{ + Groups: map[string]*types.Group{ "routingPeer1": { ID: "routingPeer1", Name: "RoutingPeer1", @@ -1685,19 +1690,19 @@ func TestAccount_getPeersRoutesFirewall(t *testing.T) { AccessControlGroups: []string{"route4"}, }, }, - Policies: []*Policy{ + Policies: []*types.Policy{ { ID: "RuleRoute1", Name: "Route1", Enabled: true, - Rules: []*PolicyRule{ + Rules: []*types.PolicyRule{ { ID: "RuleRoute1", Name: "ruleRoute1", Bidirectional: true, Enabled: true, - Protocol: PolicyRuleProtocolALL, - Action: PolicyTrafficActionAccept, + Protocol: types.PolicyRuleProtocolALL, + Action: types.PolicyTrafficActionAccept, Ports: []string{"80", "320"}, Sources: []string{ "dev", @@ -1712,15 +1717,15 @@ func TestAccount_getPeersRoutesFirewall(t *testing.T) { ID: "RuleRoute2", Name: "Route2", Enabled: true, - Rules: []*PolicyRule{ + Rules: []*types.PolicyRule{ { ID: "RuleRoute2", Name: "ruleRoute2", Bidirectional: true, Enabled: true, - Protocol: PolicyRuleProtocolTCP, - Action: PolicyTrafficActionAccept, - PortRanges: []RulePortRange{ + 
Protocol: types.PolicyRuleProtocolTCP, + Action: types.PolicyTrafficActionAccept, + PortRanges: []types.RulePortRange{ { Start: 80, End: 350, @@ -1742,14 +1747,14 @@ func TestAccount_getPeersRoutesFirewall(t *testing.T) { ID: "RuleRoute4", Name: "RuleRoute4", Enabled: true, - Rules: []*PolicyRule{ + Rules: []*types.PolicyRule{ { ID: "RuleRoute4", Name: "RuleRoute4", Bidirectional: true, Enabled: true, - Protocol: PolicyRuleProtocolTCP, - Action: PolicyTrafficActionAccept, + Protocol: types.PolicyRuleProtocolTCP, + Action: types.PolicyTrafficActionAccept, Ports: []string{"80"}, Sources: []string{ "restrictQA", @@ -1764,14 +1769,14 @@ func TestAccount_getPeersRoutesFirewall(t *testing.T) { ID: "RuleRoute5", Name: "RuleRoute5", Enabled: true, - Rules: []*PolicyRule{ + Rules: []*types.PolicyRule{ { ID: "RuleRoute5", Name: "RuleRoute5", Bidirectional: true, Enabled: true, - Protocol: PolicyRuleProtocolALL, - Action: PolicyTrafficActionAccept, + Protocol: types.PolicyRuleProtocolALL, + Action: types.PolicyTrafficActionAccept, Sources: []string{ "unrestrictedQA", }, @@ -1791,28 +1796,28 @@ func TestAccount_getPeersRoutesFirewall(t *testing.T) { t.Run("check applied policies for the route", func(t *testing.T) { route1 := account.Routes["route1"] - policies := getAllRoutePoliciesFromGroups(account, route1.AccessControlGroups) + policies := types.GetAllRoutePoliciesFromGroups(account, route1.AccessControlGroups) assert.Len(t, policies, 1) route2 := account.Routes["route2"] - policies = getAllRoutePoliciesFromGroups(account, route2.AccessControlGroups) + policies = types.GetAllRoutePoliciesFromGroups(account, route2.AccessControlGroups) assert.Len(t, policies, 1) route3 := account.Routes["route3"] - policies = getAllRoutePoliciesFromGroups(account, route3.AccessControlGroups) + policies = types.GetAllRoutePoliciesFromGroups(account, route3.AccessControlGroups) assert.Len(t, policies, 0) }) t.Run("check peer routes firewall rules", func(t *testing.T) { - routesFirewallRules := account.getPeerRoutesFirewallRules(context.Background(), "peerA", validatedPeers) + routesFirewallRules := account.GetPeerRoutesFirewallRules(context.Background(), "peerA", validatedPeers) assert.Len(t, routesFirewallRules, 4) - expectedRoutesFirewallRules := []*RouteFirewallRule{ + expectedRoutesFirewallRules := []*types.RouteFirewallRule{ { SourceRanges: []string{ - fmt.Sprintf(AllowedIPsFormat, peerCIp), - fmt.Sprintf(AllowedIPsFormat, peerHIp), - fmt.Sprintf(AllowedIPsFormat, peerBIp), + fmt.Sprintf(types.AllowedIPsFormat, peerCIp), + fmt.Sprintf(types.AllowedIPsFormat, peerHIp), + fmt.Sprintf(types.AllowedIPsFormat, peerBIp), }, Action: "accept", Destination: "192.168.0.0/16", @@ -1821,9 +1826,9 @@ func TestAccount_getPeersRoutesFirewall(t *testing.T) { }, { SourceRanges: []string{ - fmt.Sprintf(AllowedIPsFormat, peerCIp), - fmt.Sprintf(AllowedIPsFormat, peerHIp), - fmt.Sprintf(AllowedIPsFormat, peerBIp), + fmt.Sprintf(types.AllowedIPsFormat, peerCIp), + fmt.Sprintf(types.AllowedIPsFormat, peerHIp), + fmt.Sprintf(types.AllowedIPsFormat, peerBIp), }, Action: "accept", Destination: "192.168.0.0/16", @@ -1831,10 +1836,10 @@ func TestAccount_getPeersRoutesFirewall(t *testing.T) { Port: 320, }, } - additionalFirewallRule := []*RouteFirewallRule{ + additionalFirewallRule := []*types.RouteFirewallRule{ { SourceRanges: []string{ - fmt.Sprintf(AllowedIPsFormat, peerJIp), + fmt.Sprintf(types.AllowedIPsFormat, peerJIp), }, Action: "accept", Destination: "192.168.10.0/16", @@ -1843,7 +1848,7 @@ func TestAccount_getPeersRoutesFirewall(t 
*testing.T) { }, { SourceRanges: []string{ - fmt.Sprintf(AllowedIPsFormat, peerKIp), + fmt.Sprintf(types.AllowedIPsFormat, peerKIp), }, Action: "accept", Destination: "192.168.10.0/16", @@ -1854,27 +1859,28 @@ func TestAccount_getPeersRoutesFirewall(t *testing.T) { assert.ElementsMatch(t, orderRuleSourceRanges(routesFirewallRules), orderRuleSourceRanges(append(expectedRoutesFirewallRules, additionalFirewallRule...))) // peerD is also the routing peer for route1, should contain same routes firewall rules as peerA - routesFirewallRules = account.getPeerRoutesFirewallRules(context.Background(), "peerD", validatedPeers) + routesFirewallRules = account.GetPeerRoutesFirewallRules(context.Background(), "peerD", validatedPeers) assert.Len(t, routesFirewallRules, 2) assert.ElementsMatch(t, orderRuleSourceRanges(routesFirewallRules), orderRuleSourceRanges(expectedRoutesFirewallRules)) // peerE is a single routing peer for route 2 and route 3 - routesFirewallRules = account.getPeerRoutesFirewallRules(context.Background(), "peerE", validatedPeers) + routesFirewallRules = account.GetPeerRoutesFirewallRules(context.Background(), "peerE", validatedPeers) assert.Len(t, routesFirewallRules, 3) - expectedRoutesFirewallRules = []*RouteFirewallRule{ + expectedRoutesFirewallRules = []*types.RouteFirewallRule{ { SourceRanges: []string{"100.65.250.202/32", "100.65.13.186/32"}, Action: "accept", Destination: existingNetwork.String(), Protocol: "tcp", - PortRange: RulePortRange{Start: 80, End: 350}, + PortRange: types.RulePortRange{Start: 80, End: 350}, }, { SourceRanges: []string{"0.0.0.0/0"}, Action: "accept", Destination: "192.0.2.0/32", Protocol: "all", + Domains: domain.List{"example.com"}, IsDynamic: true, }, { @@ -1882,20 +1888,21 @@ func TestAccount_getPeersRoutesFirewall(t *testing.T) { Action: "accept", Destination: "192.0.2.0/32", Protocol: "all", + Domains: domain.List{"example.com"}, IsDynamic: true, }, } assert.ElementsMatch(t, orderRuleSourceRanges(routesFirewallRules), orderRuleSourceRanges(expectedRoutesFirewallRules)) // peerC is part of route1 distribution groups but should not receive the routes firewall rules - routesFirewallRules = account.getPeerRoutesFirewallRules(context.Background(), "peerC", validatedPeers) + routesFirewallRules = account.GetPeerRoutesFirewallRules(context.Background(), "peerC", validatedPeers) assert.Len(t, routesFirewallRules, 0) }) } // orderList is a helper function to sort a list of strings -func orderRuleSourceRanges(ruleList []*RouteFirewallRule) []*RouteFirewallRule { +func orderRuleSourceRanges(ruleList []*types.RouteFirewallRule) []*types.RouteFirewallRule { for _, rule := range ruleList { sort.Strings(rule.SourceRanges) } @@ -1909,7 +1916,7 @@ func TestRouteAccountPeersUpdate(t *testing.T) { account, err := initTestRouteAccount(t, manager) require.NoError(t, err, "failed to init testing account") - err = manager.SaveGroups(context.Background(), account.Id, userID, []*nbgroup.Group{ + err = manager.SaveGroups(context.Background(), account.Id, userID, []*types.Group{ { ID: "groupA", Name: "GroupA", @@ -2105,7 +2112,7 @@ func TestRouteAccountPeersUpdate(t *testing.T) { close(done) }() - err = manager.SaveGroup(context.Background(), account.Id, userID, &nbgroup.Group{ + err = manager.SaveGroup(context.Background(), account.Id, userID, &types.Group{ ID: "groupB", Name: "GroupB", Peers: []string{peer1ID}, @@ -2145,7 +2152,7 @@ func TestRouteAccountPeersUpdate(t *testing.T) { close(done) }() - err = manager.SaveGroup(context.Background(), account.Id, userID, 
&nbgroup.Group{ + err = manager.SaveGroup(context.Background(), account.Id, userID, &types.Group{ ID: "groupC", Name: "GroupC", Peers: []string{peer1ID}, @@ -2159,3 +2166,502 @@ func TestRouteAccountPeersUpdate(t *testing.T) { } }) } + +func TestAccount_GetPeerNetworkResourceFirewallRules(t *testing.T) { + var ( + peerBIp = "100.65.80.39" + peerCIp = "100.65.254.139" + peerHIp = "100.65.29.55" + peerJIp = "100.65.29.65" + peerKIp = "100.65.29.66" + peerMIp = "100.65.29.67" + ) + + account := &types.Account{ + Peers: map[string]*nbpeer.Peer{ + "peerA": { + ID: "peerA", + IP: net.ParseIP("100.65.14.88"), + Key: "peerA", + Status: &nbpeer.PeerStatus{}, + Meta: nbpeer.PeerSystemMeta{ + GoOS: "linux", + }, + }, + "peerB": { + ID: "peerB", + IP: net.ParseIP(peerBIp), + Status: &nbpeer.PeerStatus{}, + Meta: nbpeer.PeerSystemMeta{}, + }, + "peerC": { + ID: "peerC", + IP: net.ParseIP(peerCIp), + Status: &nbpeer.PeerStatus{}, + }, + "peerD": { + ID: "peerD", + IP: net.ParseIP("100.65.62.5"), + Key: "peerD", + Status: &nbpeer.PeerStatus{}, + Meta: nbpeer.PeerSystemMeta{ + GoOS: "linux", + }, + }, + "peerE": { + ID: "peerE", + IP: net.ParseIP("100.65.32.206"), + Key: "peerE", + Status: &nbpeer.PeerStatus{}, + Meta: nbpeer.PeerSystemMeta{ + GoOS: "linux", + }, + }, + "peerF": { + ID: "peerF", + IP: net.ParseIP("100.65.250.202"), + Status: &nbpeer.PeerStatus{}, + }, + "peerG": { + ID: "peerG", + IP: net.ParseIP("100.65.13.186"), + Status: &nbpeer.PeerStatus{}, + }, + "peerH": { + ID: "peerH", + IP: net.ParseIP(peerHIp), + Status: &nbpeer.PeerStatus{}, + }, + "peerJ": { + ID: "peerJ", + IP: net.ParseIP(peerJIp), + Status: &nbpeer.PeerStatus{}, + }, + "peerK": { + ID: "peerK", + IP: net.ParseIP(peerKIp), + Status: &nbpeer.PeerStatus{}, + }, + "peerL": { + ID: "peerL", + IP: net.ParseIP("100.65.19.186"), + Key: "peerL", + Status: &nbpeer.PeerStatus{}, + Meta: nbpeer.PeerSystemMeta{ + GoOS: "linux", + }, + }, + "peerM": { + ID: "peerM", + IP: net.ParseIP(peerMIp), + Status: &nbpeer.PeerStatus{}, + }, + }, + Groups: map[string]*types.Group{ + "router1": { + ID: "router1", + Name: "router1", + Peers: []string{ + "peerA", + }, + }, + "router2": { + ID: "router2", + Name: "router2", + Peers: []string{ + "peerD", + }, + }, + "finance": { + ID: "finance", + Name: "Finance", + Peers: []string{ + "peerF", + "peerG", + }, + }, + "dev": { + ID: "dev", + Name: "Dev", + Peers: []string{ + "peerC", + "peerH", + "peerB", + }, + Resources: []types.Resource{ + {ID: "resource2"}, + }, + }, + "qa": { + ID: "qa", + Name: "QA", + Peers: []string{ + "peerJ", + "peerK", + }, + }, + "restrictQA": { + ID: "restrictQA", + Name: "restrictQA", + Peers: []string{ + "peerJ", + }, + Resources: []types.Resource{ + {ID: "resource4"}, + }, + }, + "unrestrictedQA": { + ID: "unrestrictedQA", + Name: "unrestrictedQA", + Peers: []string{ + "peerK", + }, + Resources: []types.Resource{ + {ID: "resource4"}, + }, + }, + "contractors": { + ID: "contractors", + Name: "Contractors", + Peers: []string{}, + }, + "pipeline": { + ID: "pipeline", + Name: "Pipeline", + Peers: []string{"peerM"}, + }, + }, + Networks: []*networkTypes.Network{ + { + ID: "network1", + Name: "Finance Network", + }, + { + ID: "network2", + Name: "Devs Network", + }, + { + ID: "network3", + Name: "Contractors Network", + }, + { + ID: "network4", + Name: "QA Network", + }, + { + ID: "network5", + Name: "Pipeline Network", + }, + }, + NetworkRouters: []*routerTypes.NetworkRouter{ + { + ID: "router1", + NetworkID: "network1", + Peer: "peerE", + PeerGroups: nil, + Masquerade: false, 
+ Metric: 9999, + }, + { + ID: "router2", + NetworkID: "network2", + PeerGroups: []string{"router1", "router2"}, + Masquerade: false, + Metric: 9999, + }, + { + ID: "router3", + NetworkID: "network3", + Peer: "peerE", + PeerGroups: []string{}, + }, + { + ID: "router4", + NetworkID: "network4", + PeerGroups: []string{"router1"}, + Masquerade: false, + Metric: 9999, + }, + { + ID: "router5", + NetworkID: "network5", + Peer: "peerL", + Masquerade: false, + Metric: 9999, + }, + }, + NetworkResources: []*resourceTypes.NetworkResource{ + { + ID: "resource1", + NetworkID: "network1", + Name: "Resource 1", + Type: "subnet", + Prefix: netip.MustParsePrefix("10.10.10.0/24"), + }, + { + ID: "resource2", + NetworkID: "network2", + Name: "Resource 2", + Type: "subnet", + Prefix: netip.MustParsePrefix("192.168.0.0/16"), + }, + { + ID: "resource3", + NetworkID: "network3", + Name: "Resource 3", + Type: "domain", + Domain: "example.com", + }, + { + ID: "resource4", + NetworkID: "network4", + Name: "Resource 4", + Type: "domain", + Domain: "example.com", + }, + { + ID: "resource5", + NetworkID: "network5", + Name: "Resource 5", + Type: "host", + Prefix: netip.MustParsePrefix("10.12.12.1/32"), + }, + }, + Policies: []*types.Policy{ + { + ID: "policyResource1", + Name: "Policy for resource 1", + Enabled: true, + Rules: []*types.PolicyRule{ + { + ID: "ruleResource1", + Name: "ruleResource1", + Bidirectional: true, + Enabled: true, + Protocol: types.PolicyRuleProtocolTCP, + Action: types.PolicyTrafficActionAccept, + PortRanges: []types.RulePortRange{ + { + Start: 80, + End: 350, + }, { + Start: 80, + End: 350, + }, + }, + Sources: []string{ + "finance", + }, + DestinationResource: types.Resource{ID: "resource1"}, + }, + }, + }, + { + ID: "policyResource2", + Name: "Policy for resource 2", + Enabled: true, + Rules: []*types.PolicyRule{ + { + ID: "ruleResource2", + Name: "ruleResource2", + Bidirectional: true, + Enabled: true, + Protocol: types.PolicyRuleProtocolALL, + Action: types.PolicyTrafficActionAccept, + Ports: []string{"80", "320"}, + Sources: []string{"dev"}, + Destinations: []string{"dev"}, + }, + }, + }, + { + ID: "policyResource3", + Name: "policyResource3", + Enabled: true, + Rules: []*types.PolicyRule{ + { + ID: "ruleResource3", + Name: "ruleResource3", + Bidirectional: true, + Enabled: true, + Protocol: types.PolicyRuleProtocolTCP, + Action: types.PolicyTrafficActionAccept, + Ports: []string{"80"}, + Sources: []string{"restrictQA"}, + Destinations: []string{"restrictQA"}, + }, + }, + }, + { + ID: "policyResource4", + Name: "policyResource4", + Enabled: true, + Rules: []*types.PolicyRule{ + { + ID: "ruleResource4", + Name: "ruleResource4", + Bidirectional: true, + Enabled: true, + Protocol: types.PolicyRuleProtocolALL, + Action: types.PolicyTrafficActionAccept, + Sources: []string{"unrestrictedQA"}, + Destinations: []string{"unrestrictedQA"}, + }, + }, + }, + { + ID: "policyResource5", + Name: "policyResource5", + Enabled: true, + Rules: []*types.PolicyRule{ + { + ID: "ruleResource5", + Name: "ruleResource5", + Bidirectional: true, + Enabled: true, + Protocol: types.PolicyRuleProtocolTCP, + Action: types.PolicyTrafficActionAccept, + Ports: []string{"8080"}, + Sources: []string{"pipeline"}, + DestinationResource: types.Resource{ID: "resource5"}, + }, + }, + }, + }, + } + + validatedPeers := make(map[string]struct{}) + for p := range account.Peers { + validatedPeers[p] = struct{}{} + } + + t.Run("validate applied policies for different network resources", func(t *testing.T) { + // Test case: 
Resource1 is directly applied to the policy (policyResource1) + policies := account.GetPoliciesForNetworkResource("resource1") + assert.Len(t, policies, 1, "resource1 should have exactly 1 policy applied directly") + + // Test case: Resource2 is applied to an access control group (dev), + // which is part of the destination in the policy (policyResource2) + policies = account.GetPoliciesForNetworkResource("resource2") + assert.Len(t, policies, 1, "resource2 should have exactly 1 policy applied via access control group") + + // Test case: Resource3 is not applied to any access control group or policy + policies = account.GetPoliciesForNetworkResource("resource3") + assert.Len(t, policies, 0, "resource3 should have no policies applied") + + // Test case: Resource4 is applied to the access control groups (restrictQA and unrestrictedQA), + // which are part of the destination in the policies (policyResource3 and policyResource4) + policies = account.GetPoliciesForNetworkResource("resource4") + assert.Len(t, policies, 2, "resource4 should have exactly 2 policies applied via access control groups") + }) + + t.Run("validate routing peer firewall rules for network resources", func(t *testing.T) { + resourcePoliciesMap := account.GetResourcePoliciesMap() + resourceRoutersMap := account.GetResourceRoutersMap() + _, routes, sourcePeers := account.GetNetworkResourcesRoutesToSync(context.Background(), "peerA", resourcePoliciesMap, resourceRoutersMap) + firewallRules := account.GetPeerNetworkResourceFirewallRules(context.Background(), account.Peers["peerA"], validatedPeers, routes, resourcePoliciesMap) + assert.Len(t, firewallRules, 4) + assert.Len(t, sourcePeers, 5) + + expectedFirewallRules := []*types.RouteFirewallRule{ + { + SourceRanges: []string{ + fmt.Sprintf(types.AllowedIPsFormat, peerCIp), + fmt.Sprintf(types.AllowedIPsFormat, peerHIp), + fmt.Sprintf(types.AllowedIPsFormat, peerBIp), + }, + Action: "accept", + Destination: "192.168.0.0/16", + Protocol: "all", + Port: 80, + }, + { + SourceRanges: []string{ + fmt.Sprintf(types.AllowedIPsFormat, peerCIp), + fmt.Sprintf(types.AllowedIPsFormat, peerHIp), + fmt.Sprintf(types.AllowedIPsFormat, peerBIp), + }, + Action: "accept", + Destination: "192.168.0.0/16", + Protocol: "all", + Port: 320, + }, + } + + additionalFirewallRules := []*types.RouteFirewallRule{ + { + SourceRanges: []string{ + fmt.Sprintf(types.AllowedIPsFormat, peerJIp), + }, + Action: "accept", + Destination: "192.0.2.0/32", + Protocol: "tcp", + Port: 80, + Domains: domain.List{"example.com"}, + IsDynamic: true, + }, + { + SourceRanges: []string{ + fmt.Sprintf(types.AllowedIPsFormat, peerKIp), + }, + Action: "accept", + Destination: "192.0.2.0/32", + Protocol: "all", + Domains: domain.List{"example.com"}, + IsDynamic: true, + }, + } + assert.ElementsMatch(t, orderRuleSourceRanges(firewallRules), orderRuleSourceRanges(append(expectedFirewallRules, additionalFirewallRules...))) + + // peerD is also the routing peer for resource2 + _, routes, sourcePeers = account.GetNetworkResourcesRoutesToSync(context.Background(), "peerD", resourcePoliciesMap, resourceRoutersMap) + firewallRules = account.GetPeerNetworkResourceFirewallRules(context.Background(), account.Peers["peerD"], validatedPeers, routes, resourcePoliciesMap) + assert.Len(t, firewallRules, 2) + assert.ElementsMatch(t, orderRuleSourceRanges(firewallRules), orderRuleSourceRanges(expectedFirewallRules)) + assert.Len(t, sourcePeers, 3) + + // peerE is a single routing peer for resource1 and resource3 + // PeerE should only receive rules
for resource1 since resource3 has no applied policy + _, routes, sourcePeers = account.GetNetworkResourcesRoutesToSync(context.Background(), "peerE", resourcePoliciesMap, resourceRoutersMap) + firewallRules = account.GetPeerNetworkResourceFirewallRules(context.Background(), account.Peers["peerE"], validatedPeers, routes, resourcePoliciesMap) + assert.Len(t, firewallRules, 1) + assert.Len(t, sourcePeers, 2) + + expectedFirewallRules = []*types.RouteFirewallRule{ + { + SourceRanges: []string{"100.65.250.202/32", "100.65.13.186/32"}, + Action: "accept", + Destination: "10.10.10.0/24", + Protocol: "tcp", + PortRange: types.RulePortRange{Start: 80, End: 350}, + }, + } + assert.ElementsMatch(t, orderRuleSourceRanges(firewallRules), orderRuleSourceRanges(expectedFirewallRules)) + + // peerC is part of distribution groups for resource2 but should not receive the firewall rules + firewallRules = account.GetPeerRoutesFirewallRules(context.Background(), "peerC", validatedPeers) + assert.Len(t, firewallRules, 0) + + // peerL is the single routing peer for resource5 + _, routes, sourcePeers = account.GetNetworkResourcesRoutesToSync(context.Background(), "peerL", resourcePoliciesMap, resourceRoutersMap) + assert.Len(t, routes, 1) + firewallRules = account.GetPeerNetworkResourceFirewallRules(context.Background(), account.Peers["peerL"], validatedPeers, routes, resourcePoliciesMap) + assert.Len(t, firewallRules, 1) + assert.Len(t, sourcePeers, 1) + + expectedFirewallRules = []*types.RouteFirewallRule{ + { + SourceRanges: []string{"100.65.29.67/32"}, + Action: "accept", + Destination: "10.12.12.1/32", + Protocol: "tcp", + Port: 8080, + }, + } + assert.ElementsMatch(t, orderRuleSourceRanges(firewallRules), orderRuleSourceRanges(expectedFirewallRules)) + + _, routes, sourcePeers = account.GetNetworkResourcesRoutesToSync(context.Background(), "peerM", resourcePoliciesMap, resourceRoutersMap) + assert.Len(t, routes, 1) + assert.Len(t, sourcePeers, 0) + }) +} diff --git a/management/server/settings/manager.go b/management/server/settings/manager.go new file mode 100644 index 000000000..37bc9f549 --- /dev/null +++ b/management/server/settings/manager.go @@ -0,0 +1,37 @@ +package settings + +import ( + "context" + + "github.com/netbirdio/netbird/management/server/store" + "github.com/netbirdio/netbird/management/server/types" +) + +type Manager interface { + GetSettings(ctx context.Context, accountID string, userID string) (*types.Settings, error) +} + +type managerImpl struct { + store store.Store +} + +type managerMock struct { +} + +func NewManager(store store.Store) Manager { + return &managerImpl{ + store: store, + } +} + +func (m *managerImpl) GetSettings(ctx context.Context, accountID string, userID string) (*types.Settings, error) { + return m.store.GetAccountSettings(ctx, store.LockingStrengthShare, accountID) +} + +func NewManagerMock() Manager { + return &managerMock{} +} + +func (m *managerMock) GetSettings(ctx context.Context, accountID string, userID string) (*types.Settings, error) { + return &types.Settings{}, nil +} diff --git a/management/server/setupkey.go b/management/server/setupkey.go index ef431d3ad..9a4a1efb8 100644 --- a/management/server/setupkey.go +++ b/management/server/setupkey.go @@ -2,34 +2,16 @@ package server import ( "context" - "crypto/sha256" - b64 "encoding/base64" - "hash/fnv" "slices" - "strconv" - "strings" "time" - "unicode/utf8" - "github.com/google/uuid" log "github.com/sirupsen/logrus" "github.com/netbirdio/netbird/management/server/activity" 
"github.com/netbirdio/netbird/management/server/status" -) - -const ( - // SetupKeyReusable is a multi-use key (can be used for multiple machines) - SetupKeyReusable SetupKeyType = "reusable" - // SetupKeyOneOff is a single use key (can be used only once) - SetupKeyOneOff SetupKeyType = "one-off" - - // DefaultSetupKeyDuration = 1 month - DefaultSetupKeyDuration = 24 * 30 * time.Hour - // DefaultSetupKeyName is a default name of the default setup key - DefaultSetupKeyName = "Default key" - // SetupKeyUnlimitedUsage indicates an unlimited usage of a setup key - SetupKeyUnlimitedUsage = 0 + "github.com/netbirdio/netbird/management/server/store" + "github.com/netbirdio/netbird/management/server/types" + "github.com/netbirdio/netbird/management/server/util" ) const ( @@ -67,169 +49,14 @@ type SetupKeyUpdateOperation struct { Values []string } -// SetupKeyType is the type of setup key -type SetupKeyType string - -// SetupKey represents a pre-authorized key used to register machines (peers) -type SetupKey struct { - Id string - // AccountID is a reference to Account that this object belongs - AccountID string `json:"-" gorm:"index"` - Key string - KeySecret string - Name string - Type SetupKeyType - CreatedAt time.Time - ExpiresAt time.Time - UpdatedAt time.Time `gorm:"autoUpdateTime:false"` - // Revoked indicates whether the key was revoked or not (we don't remove them for tracking purposes) - Revoked bool - // UsedTimes indicates how many times the key was used - UsedTimes int - // LastUsed last time the key was used for peer registration - LastUsed time.Time - // AutoGroups is a list of Group IDs that are auto assigned to a Peer when it uses this key to register - AutoGroups []string `gorm:"serializer:json"` - // UsageLimit indicates the number of times this key can be used to enroll a machine. - // The value of 0 indicates the unlimited usage. - UsageLimit int - // Ephemeral indicate if the peers will be ephemeral or not - Ephemeral bool -} - -// Copy copies SetupKey to a new object -func (key *SetupKey) Copy() *SetupKey { - autoGroups := make([]string, len(key.AutoGroups)) - copy(autoGroups, key.AutoGroups) - if key.UpdatedAt.IsZero() { - key.UpdatedAt = key.CreatedAt - } - return &SetupKey{ - Id: key.Id, - AccountID: key.AccountID, - Key: key.Key, - KeySecret: key.KeySecret, - Name: key.Name, - Type: key.Type, - CreatedAt: key.CreatedAt, - ExpiresAt: key.ExpiresAt, - UpdatedAt: key.UpdatedAt, - Revoked: key.Revoked, - UsedTimes: key.UsedTimes, - LastUsed: key.LastUsed, - AutoGroups: autoGroups, - UsageLimit: key.UsageLimit, - Ephemeral: key.Ephemeral, - } -} - -// EventMeta returns activity event meta related to the setup key -func (key *SetupKey) EventMeta() map[string]any { - return map[string]any{"name": key.Name, "type": key.Type, "key": key.KeySecret} -} - -// hiddenKey returns the Key value hidden with "*" and a 5 character prefix. 
-// E.g., "831F6*******************************" -func hiddenKey(key string, length int) string { - prefix := key[0:5] - if length > utf8.RuneCountInString(key) { - length = utf8.RuneCountInString(key) - len(prefix) - } - return prefix + strings.Repeat("*", length) -} - -// IncrementUsage makes a copy of a key, increments the UsedTimes by 1 and sets LastUsed to now -func (key *SetupKey) IncrementUsage() *SetupKey { - c := key.Copy() - c.UsedTimes++ - c.LastUsed = time.Now().UTC() - return c -} - -// IsValid is true if the key was not revoked, is not expired and used not more than it was supposed to -func (key *SetupKey) IsValid() bool { - return !key.IsRevoked() && !key.IsExpired() && !key.IsOverUsed() -} - -// IsRevoked if key was revoked -func (key *SetupKey) IsRevoked() bool { - return key.Revoked -} - -// IsExpired if key was expired -func (key *SetupKey) IsExpired() bool { - if key.ExpiresAt.IsZero() { - return false - } - return time.Now().After(key.ExpiresAt) -} - -// IsOverUsed if the key was used too many times. SetupKey.UsageLimit == 0 indicates the unlimited usage. -func (key *SetupKey) IsOverUsed() bool { - limit := key.UsageLimit - if key.Type == SetupKeyOneOff { - limit = 1 - } - return limit > 0 && key.UsedTimes >= limit -} - -// GenerateSetupKey generates a new setup key -func GenerateSetupKey(name string, t SetupKeyType, validFor time.Duration, autoGroups []string, - usageLimit int, ephemeral bool) (*SetupKey, string) { - key := strings.ToUpper(uuid.New().String()) - limit := usageLimit - if t == SetupKeyOneOff { - limit = 1 - } - - expiresAt := time.Time{} - if validFor != 0 { - expiresAt = time.Now().UTC().Add(validFor) - } - - hashedKey := sha256.Sum256([]byte(key)) - encodedHashedKey := b64.StdEncoding.EncodeToString(hashedKey[:]) - - return &SetupKey{ - Id: strconv.Itoa(int(Hash(key))), - Key: encodedHashedKey, - KeySecret: hiddenKey(key, 4), - Name: name, - Type: t, - CreatedAt: time.Now().UTC(), - ExpiresAt: expiresAt, - UpdatedAt: time.Now().UTC(), - Revoked: false, - UsedTimes: 0, - AutoGroups: autoGroups, - UsageLimit: limit, - Ephemeral: ephemeral, - }, key -} - -// GenerateDefaultSetupKey generates a default reusable setup key with an unlimited usage and 30 days expiration -func GenerateDefaultSetupKey() (*SetupKey, string) { - return GenerateSetupKey(DefaultSetupKeyName, SetupKeyReusable, DefaultSetupKeyDuration, []string{}, - SetupKeyUnlimitedUsage, false) -} - -func Hash(s string) uint32 { - h := fnv.New32a() - _, err := h.Write([]byte(s)) - if err != nil { - panic(err) - } - return h.Sum32() -} - // CreateSetupKey generates a new setup key with a given name, type, list of groups IDs to auto-assign to peers registered with this key, // and adds it to the specified account. A list of autoGroups IDs can be empty. 
-func (am *DefaultAccountManager) CreateSetupKey(ctx context.Context, accountID string, keyName string, keyType SetupKeyType, - expiresIn time.Duration, autoGroups []string, usageLimit int, userID string, ephemeral bool) (*SetupKey, error) { +func (am *DefaultAccountManager) CreateSetupKey(ctx context.Context, accountID string, keyName string, keyType types.SetupKeyType, + expiresIn time.Duration, autoGroups []string, usageLimit int, userID string, ephemeral bool) (*types.SetupKey, error) { unlock := am.Store.AcquireWriteLockByUID(ctx, accountID) defer unlock() - user, err := am.Store.GetUserByUserID(ctx, LockingStrengthShare, userID) + user, err := am.Store.GetUserByUserID(ctx, store.LockingStrengthShare, userID) if err != nil { return nil, err } @@ -242,22 +69,22 @@ func (am *DefaultAccountManager) CreateSetupKey(ctx context.Context, accountID s return nil, status.NewAdminPermissionError() } - var setupKey *SetupKey + var setupKey *types.SetupKey var plainKey string var eventsToStore []func() - err = am.Store.ExecuteInTransaction(ctx, func(transaction Store) error { + err = am.Store.ExecuteInTransaction(ctx, func(transaction store.Store) error { if err = validateSetupKeyAutoGroups(ctx, transaction, accountID, autoGroups); err != nil { return err } - setupKey, plainKey = GenerateSetupKey(keyName, keyType, expiresIn, autoGroups, usageLimit, ephemeral) + setupKey, plainKey = types.GenerateSetupKey(keyName, keyType, expiresIn, autoGroups, usageLimit, ephemeral) setupKey.AccountID = accountID events := am.prepareSetupKeyEvents(ctx, transaction, accountID, userID, autoGroups, nil, setupKey) eventsToStore = append(eventsToStore, events...) - return transaction.SaveSetupKey(ctx, LockingStrengthUpdate, setupKey) + return transaction.SaveSetupKey(ctx, store.LockingStrengthUpdate, setupKey) }) if err != nil { return nil, err @@ -278,7 +105,7 @@ func (am *DefaultAccountManager) CreateSetupKey(ctx context.Context, accountID s // Due to the unique nature of a SetupKey certain properties must not be overwritten // (e.g. the key itself, creation date, ID, etc). // These properties are overwritten: AutoGroups, Revoked (only from false to true), and the UpdatedAt. The rest is copied from the existing key. 
-func (am *DefaultAccountManager) SaveSetupKey(ctx context.Context, accountID string, keyToSave *SetupKey, userID string) (*SetupKey, error) { +func (am *DefaultAccountManager) SaveSetupKey(ctx context.Context, accountID string, keyToSave *types.SetupKey, userID string) (*types.SetupKey, error) { if keyToSave == nil { return nil, status.Errorf(status.InvalidArgument, "provided setup key to update is nil") } @@ -286,7 +113,7 @@ func (am *DefaultAccountManager) SaveSetupKey(ctx context.Context, accountID str unlock := am.Store.AcquireWriteLockByUID(ctx, accountID) defer unlock() - user, err := am.Store.GetUserByUserID(ctx, LockingStrengthShare, userID) + user, err := am.Store.GetUserByUserID(ctx, store.LockingStrengthShare, userID) if err != nil { return nil, err } @@ -299,16 +126,16 @@ func (am *DefaultAccountManager) SaveSetupKey(ctx context.Context, accountID str return nil, status.NewAdminPermissionError() } - var oldKey *SetupKey - var newKey *SetupKey + var oldKey *types.SetupKey + var newKey *types.SetupKey var eventsToStore []func() - err = am.Store.ExecuteInTransaction(ctx, func(transaction Store) error { + err = am.Store.ExecuteInTransaction(ctx, func(transaction store.Store) error { if err = validateSetupKeyAutoGroups(ctx, transaction, accountID, keyToSave.AutoGroups); err != nil { return err } - oldKey, err = transaction.GetSetupKeyByID(ctx, LockingStrengthShare, accountID, keyToSave.Id) + oldKey, err = transaction.GetSetupKeyByID(ctx, store.LockingStrengthShare, accountID, keyToSave.Id) if err != nil { return err } @@ -323,13 +150,13 @@ func (am *DefaultAccountManager) SaveSetupKey(ctx context.Context, accountID str newKey.Revoked = keyToSave.Revoked newKey.UpdatedAt = time.Now().UTC() - addedGroups := difference(newKey.AutoGroups, oldKey.AutoGroups) - removedGroups := difference(oldKey.AutoGroups, newKey.AutoGroups) + addedGroups := util.Difference(newKey.AutoGroups, oldKey.AutoGroups) + removedGroups := util.Difference(oldKey.AutoGroups, newKey.AutoGroups) events := am.prepareSetupKeyEvents(ctx, transaction, accountID, userID, addedGroups, removedGroups, oldKey) eventsToStore = append(eventsToStore, events...) - return transaction.SaveSetupKey(ctx, LockingStrengthUpdate, newKey) + return transaction.SaveSetupKey(ctx, store.LockingStrengthUpdate, newKey) }) if err != nil { return nil, err @@ -347,8 +174,8 @@ func (am *DefaultAccountManager) SaveSetupKey(ctx context.Context, accountID str } // ListSetupKeys returns a list of all setup keys of the account -func (am *DefaultAccountManager) ListSetupKeys(ctx context.Context, accountID, userID string) ([]*SetupKey, error) { - user, err := am.Store.GetUserByUserID(ctx, LockingStrengthShare, userID) +func (am *DefaultAccountManager) ListSetupKeys(ctx context.Context, accountID, userID string) ([]*types.SetupKey, error) { + user, err := am.Store.GetUserByUserID(ctx, store.LockingStrengthShare, userID) if err != nil { return nil, err } @@ -361,12 +188,12 @@ func (am *DefaultAccountManager) ListSetupKeys(ctx context.Context, accountID, u return nil, status.NewAdminPermissionError() } - return am.Store.GetAccountSetupKeys(ctx, LockingStrengthShare, accountID) + return am.Store.GetAccountSetupKeys(ctx, store.LockingStrengthShare, accountID) } // GetSetupKey looks up a SetupKey by KeyID, returns NotFound error if not found. 
-func (am *DefaultAccountManager) GetSetupKey(ctx context.Context, accountID, userID, keyID string) (*SetupKey, error) { - user, err := am.Store.GetUserByUserID(ctx, LockingStrengthShare, userID) +func (am *DefaultAccountManager) GetSetupKey(ctx context.Context, accountID, userID, keyID string) (*types.SetupKey, error) { + user, err := am.Store.GetUserByUserID(ctx, store.LockingStrengthShare, userID) if err != nil { return nil, err } @@ -379,7 +206,7 @@ func (am *DefaultAccountManager) GetSetupKey(ctx context.Context, accountID, use return nil, status.NewAdminPermissionError() } - setupKey, err := am.Store.GetSetupKeyByID(ctx, LockingStrengthShare, accountID, keyID) + setupKey, err := am.Store.GetSetupKeyByID(ctx, store.LockingStrengthShare, accountID, keyID) if err != nil { return nil, err } @@ -394,7 +221,7 @@ func (am *DefaultAccountManager) GetSetupKey(ctx context.Context, accountID, use // DeleteSetupKey removes the setup key from the account func (am *DefaultAccountManager) DeleteSetupKey(ctx context.Context, accountID, userID, keyID string) error { - user, err := am.Store.GetUserByUserID(ctx, LockingStrengthShare, userID) + user, err := am.Store.GetUserByUserID(ctx, store.LockingStrengthShare, userID) if err != nil { return err } @@ -407,15 +234,15 @@ func (am *DefaultAccountManager) DeleteSetupKey(ctx context.Context, accountID, return status.NewAdminPermissionError() } - var deletedSetupKey *SetupKey + var deletedSetupKey *types.SetupKey - err = am.Store.ExecuteInTransaction(ctx, func(transaction Store) error { - deletedSetupKey, err = transaction.GetSetupKeyByID(ctx, LockingStrengthShare, accountID, keyID) + err = am.Store.ExecuteInTransaction(ctx, func(transaction store.Store) error { + deletedSetupKey, err = transaction.GetSetupKeyByID(ctx, store.LockingStrengthShare, accountID, keyID) if err != nil { return err } - return transaction.DeleteSetupKey(ctx, LockingStrengthUpdate, accountID, keyID) + return transaction.DeleteSetupKey(ctx, store.LockingStrengthUpdate, accountID, keyID) }) if err != nil { return err @@ -426,8 +253,8 @@ func (am *DefaultAccountManager) DeleteSetupKey(ctx context.Context, accountID, return nil } -func validateSetupKeyAutoGroups(ctx context.Context, transaction Store, accountID string, autoGroupIDs []string) error { - groups, err := transaction.GetGroupsByIDs(ctx, LockingStrengthShare, accountID, autoGroupIDs) +func validateSetupKeyAutoGroups(ctx context.Context, transaction store.Store, accountID string, autoGroupIDs []string) error { + groups, err := transaction.GetGroupsByIDs(ctx, store.LockingStrengthShare, accountID, autoGroupIDs) if err != nil { return err } @@ -447,11 +274,11 @@ func validateSetupKeyAutoGroups(ctx context.Context, transaction Store, accountI } // prepareSetupKeyEvents prepares a list of event functions to be stored. 
-func (am *DefaultAccountManager) prepareSetupKeyEvents(ctx context.Context, transaction Store, accountID, userID string, addedGroups, removedGroups []string, key *SetupKey) []func() { +func (am *DefaultAccountManager) prepareSetupKeyEvents(ctx context.Context, transaction store.Store, accountID, userID string, addedGroups, removedGroups []string, key *types.SetupKey) []func() { var eventsToStore []func() modifiedGroups := slices.Concat(addedGroups, removedGroups) - groups, err := transaction.GetGroupsByIDs(ctx, LockingStrengthShare, accountID, modifiedGroups) + groups, err := transaction.GetGroupsByIDs(ctx, store.LockingStrengthShare, accountID, modifiedGroups) if err != nil { log.WithContext(ctx).Debugf("failed to get groups for setup key events: %v", err) return nil diff --git a/management/server/setupkey_test.go b/management/server/setupkey_test.go index 614547c60..f728db5d4 100644 --- a/management/server/setupkey_test.go +++ b/management/server/setupkey_test.go @@ -15,7 +15,7 @@ import ( "github.com/stretchr/testify/require" "github.com/netbirdio/netbird/management/server/activity" - nbgroup "github.com/netbirdio/netbird/management/server/group" + "github.com/netbirdio/netbird/management/server/types" ) func TestDefaultAccountManager_SaveSetupKey(t *testing.T) { @@ -30,7 +30,7 @@ func TestDefaultAccountManager_SaveSetupKey(t *testing.T) { t.Fatal(err) } - err = manager.SaveGroups(context.Background(), account.Id, userID, []*nbgroup.Group{ + err = manager.SaveGroups(context.Background(), account.Id, userID, []*types.Group{ { ID: "group_1", Name: "group_name_1", @@ -49,15 +49,15 @@ func TestDefaultAccountManager_SaveSetupKey(t *testing.T) { expiresIn := time.Hour keyName := "my-test-key" - key, err := manager.CreateSetupKey(context.Background(), account.Id, keyName, SetupKeyReusable, expiresIn, []string{}, - SetupKeyUnlimitedUsage, userID, false) + key, err := manager.CreateSetupKey(context.Background(), account.Id, keyName, types.SetupKeyReusable, expiresIn, []string{}, + types.SetupKeyUnlimitedUsage, userID, false) if err != nil { t.Fatal(err) } autoGroups := []string{"group_1", "group_2"} revoked := true - newKey, err := manager.SaveSetupKey(context.Background(), account.Id, &SetupKey{ + newKey, err := manager.SaveSetupKey(context.Background(), account.Id, &types.SetupKey{ Id: key.Id, Revoked: revoked, AutoGroups: autoGroups, @@ -85,7 +85,7 @@ func TestDefaultAccountManager_SaveSetupKey(t *testing.T) { // saving setup key with All group assigned to auto groups should return error autoGroups = append(autoGroups, groupAll.ID) - _, err = manager.SaveSetupKey(context.Background(), account.Id, &SetupKey{ + _, err = manager.SaveSetupKey(context.Background(), account.Id, &types.SetupKey{ Id: key.Id, Revoked: revoked, AutoGroups: autoGroups, @@ -105,7 +105,7 @@ func TestDefaultAccountManager_CreateSetupKey(t *testing.T) { t.Fatal(err) } - err = manager.SaveGroup(context.Background(), account.Id, userID, &nbgroup.Group{ + err = manager.SaveGroup(context.Background(), account.Id, userID, &types.Group{ ID: "group_1", Name: "group_name_1", Peers: []string{}, @@ -114,7 +114,7 @@ func TestDefaultAccountManager_CreateSetupKey(t *testing.T) { t.Fatal(err) } - err = manager.SaveGroup(context.Background(), account.Id, userID, &nbgroup.Group{ + err = manager.SaveGroup(context.Background(), account.Id, userID, &types.Group{ ID: "group_2", Name: "group_name_2", Peers: []string{}, @@ -167,8 +167,8 @@ func TestDefaultAccountManager_CreateSetupKey(t *testing.T) { for _, tCase := range []testCase{testCase1, 
testCase2, testCase3} { t.Run(tCase.name, func(t *testing.T) { - key, err := manager.CreateSetupKey(context.Background(), account.Id, tCase.expectedKeyName, SetupKeyReusable, expiresIn, - tCase.expectedGroups, SetupKeyUnlimitedUsage, userID, false) + key, err := manager.CreateSetupKey(context.Background(), account.Id, tCase.expectedKeyName, types.SetupKeyReusable, expiresIn, + tCase.expectedGroups, types.SetupKeyUnlimitedUsage, userID, false) if tCase.expectedFailure { if err == nil { @@ -182,7 +182,7 @@ func TestDefaultAccountManager_CreateSetupKey(t *testing.T) { } assertKey(t, key, tCase.expectedKeyName, false, tCase.expectedType, tCase.expectedUsedTimes, - tCase.expectedCreatedAt, tCase.expectedExpiresAt, strconv.Itoa(int(Hash(key.Key))), + tCase.expectedCreatedAt, tCase.expectedExpiresAt, strconv.Itoa(int(types.Hash(key.Key))), tCase.expectedUpdatedAt, tCase.expectedGroups, false) // check the corresponding events that should have been generated @@ -210,7 +210,7 @@ func TestGetSetupKeys(t *testing.T) { t.Fatal(err) } - plainKey, err := manager.CreateSetupKey(context.Background(), account.Id, "key1", SetupKeyReusable, time.Hour, nil, SetupKeyUnlimitedUsage, userID, false) + plainKey, err := manager.CreateSetupKey(context.Background(), account.Id, "key1", types.SetupKeyReusable, time.Hour, nil, types.SetupKeyUnlimitedUsage, userID, false) if err != nil { t.Fatal(err) } @@ -258,10 +258,10 @@ func TestGenerateDefaultSetupKey(t *testing.T) { expectedExpiresAt := time.Now().UTC().Add(24 * 30 * time.Hour) var expectedAutoGroups []string - key, plainKey := GenerateDefaultSetupKey() + key, plainKey := types.GenerateDefaultSetupKey() assertKey(t, key, expectedName, expectedRevoke, expectedType, expectedUsedTimes, expectedCreatedAt, - expectedExpiresAt, strconv.Itoa(int(Hash(plainKey))), expectedUpdatedAt, expectedAutoGroups, true) + expectedExpiresAt, strconv.Itoa(int(types.Hash(plainKey))), expectedUpdatedAt, expectedAutoGroups, true) } @@ -275,48 +275,48 @@ func TestGenerateSetupKey(t *testing.T) { expectedUpdatedAt := time.Now().UTC() var expectedAutoGroups []string - key, plain := GenerateSetupKey(expectedName, SetupKeyOneOff, time.Hour, []string{}, SetupKeyUnlimitedUsage, false) + key, plain := types.GenerateSetupKey(expectedName, types.SetupKeyOneOff, time.Hour, []string{}, types.SetupKeyUnlimitedUsage, false) assertKey(t, key, expectedName, expectedRevoke, expectedType, expectedUsedTimes, expectedCreatedAt, - expectedExpiresAt, strconv.Itoa(int(Hash(plain))), expectedUpdatedAt, expectedAutoGroups, true) + expectedExpiresAt, strconv.Itoa(int(types.Hash(plain))), expectedUpdatedAt, expectedAutoGroups, true) } func TestSetupKey_IsValid(t *testing.T) { - validKey, _ := GenerateSetupKey("valid key", SetupKeyOneOff, time.Hour, []string{}, SetupKeyUnlimitedUsage, false) + validKey, _ := types.GenerateSetupKey("valid key", types.SetupKeyOneOff, time.Hour, []string{}, types.SetupKeyUnlimitedUsage, false) if !validKey.IsValid() { t.Errorf("expected key to be valid, got invalid %v", validKey) } // expired - expiredKey, _ := GenerateSetupKey("invalid key", SetupKeyOneOff, -time.Hour, []string{}, SetupKeyUnlimitedUsage, false) + expiredKey, _ := types.GenerateSetupKey("invalid key", types.SetupKeyOneOff, -time.Hour, []string{}, types.SetupKeyUnlimitedUsage, false) if expiredKey.IsValid() { t.Errorf("expected key to be invalid due to expiration, got valid %v", expiredKey) } // revoked - revokedKey, _ := GenerateSetupKey("invalid key", SetupKeyOneOff, time.Hour, []string{}, SetupKeyUnlimitedUsage, 
false) + revokedKey, _ := types.GenerateSetupKey("invalid key", types.SetupKeyOneOff, time.Hour, []string{}, types.SetupKeyUnlimitedUsage, false) revokedKey.Revoked = true if revokedKey.IsValid() { t.Errorf("expected revoked key to be invalid, got valid %v", revokedKey) } // overused - overUsedKey, _ := GenerateSetupKey("invalid key", SetupKeyOneOff, time.Hour, []string{}, SetupKeyUnlimitedUsage, false) + overUsedKey, _ := types.GenerateSetupKey("invalid key", types.SetupKeyOneOff, time.Hour, []string{}, types.SetupKeyUnlimitedUsage, false) overUsedKey.UsedTimes = 1 if overUsedKey.IsValid() { t.Errorf("expected overused key to be invalid, got valid %v", overUsedKey) } // overused - reusableKey, _ := GenerateSetupKey("valid key", SetupKeyReusable, time.Hour, []string{}, SetupKeyUnlimitedUsage, false) + reusableKey, _ := types.GenerateSetupKey("valid key", types.SetupKeyReusable, time.Hour, []string{}, types.SetupKeyUnlimitedUsage, false) reusableKey.UsedTimes = 99 if !reusableKey.IsValid() { t.Errorf("expected reusable key to be valid when used many times, got valid %v", reusableKey) } } -func assertKey(t *testing.T, key *SetupKey, expectedName string, expectedRevoke bool, expectedType string, +func assertKey(t *testing.T, key *types.SetupKey, expectedName string, expectedRevoke bool, expectedType string, expectedUsedTimes int, expectedCreatedAt time.Time, expectedExpiresAt time.Time, expectedID string, expectedUpdatedAt time.Time, expectedAutoGroups []string, expectHashedKey bool) { t.Helper() @@ -388,7 +388,7 @@ func isValidBase64SHA256(encodedKey string) bool { func TestSetupKey_Copy(t *testing.T) { - key, _ := GenerateSetupKey("key name", SetupKeyOneOff, time.Hour, []string{}, SetupKeyUnlimitedUsage, false) + key, _ := types.GenerateSetupKey("key name", types.SetupKeyOneOff, time.Hour, []string{}, types.SetupKeyUnlimitedUsage, false) keyCopy := key.Copy() assertKey(t, keyCopy, key.Name, key.Revoked, string(key.Type), key.UsedTimes, key.CreatedAt, key.ExpiresAt, key.Id, @@ -399,22 +399,22 @@ func TestSetupKey_Copy(t *testing.T) { func TestSetupKeyAccountPeersUpdate(t *testing.T) { manager, account, peer1, peer2, peer3 := setupNetworkMapTest(t) - err := manager.SaveGroup(context.Background(), account.Id, userID, &nbgroup.Group{ + err := manager.SaveGroup(context.Background(), account.Id, userID, &types.Group{ ID: "groupA", Name: "GroupA", Peers: []string{peer1.ID, peer2.ID, peer3.ID}, }) assert.NoError(t, err) - policy := &Policy{ + policy := &types.Policy{ Enabled: true, - Rules: []*PolicyRule{ + Rules: []*types.PolicyRule{ { Enabled: true, Sources: []string{"groupA"}, Destinations: []string{"group"}, Bidirectional: true, - Action: PolicyTrafficActionAccept, + Action: types.PolicyTrafficActionAccept, }, }, } @@ -426,7 +426,7 @@ func TestSetupKeyAccountPeersUpdate(t *testing.T) { manager.peersUpdateManager.CloseChannel(context.Background(), peer1.ID) }) - var setupKey *SetupKey + var setupKey *types.SetupKey // Creating setup key should not update account peers and not send peer update t.Run("creating setup key", func(t *testing.T) { @@ -436,7 +436,7 @@ func TestSetupKeyAccountPeersUpdate(t *testing.T) { close(done) }() - setupKey, err = manager.CreateSetupKey(context.Background(), account.Id, "key1", SetupKeyReusable, time.Hour, nil, 999, userID, false) + setupKey, err = manager.CreateSetupKey(context.Background(), account.Id, "key1", types.SetupKeyReusable, time.Hour, nil, 999, userID, false) assert.NoError(t, err) select { @@ -477,7 +477,7 @@ func 
TestDefaultAccountManager_CreateSetupKey_ShouldNotAllowToUpdateRevokedKey(t t.Fatal(err) } - key, err := manager.CreateSetupKey(context.Background(), account.Id, "testName", SetupKeyReusable, time.Hour, nil, SetupKeyUnlimitedUsage, userID, false) + key, err := manager.CreateSetupKey(context.Background(), account.Id, "testName", types.SetupKeyReusable, time.Hour, nil, types.SetupKeyUnlimitedUsage, userID, false) assert.NoError(t, err) // revoke the key diff --git a/management/server/status/error.go b/management/server/status/error.go index 59f436f5b..d9cab0231 100644 --- a/management/server/status/error.go +++ b/management/server/status/error.go @@ -154,3 +154,35 @@ func NewPolicyNotFoundError(policyID string) error { func NewNameServerGroupNotFoundError(nsGroupID string) error { return Errorf(NotFound, "nameserver group: %s not found", nsGroupID) } + +// NewNetworkNotFoundError creates a new Error with NotFound type for a missing network. +func NewNetworkNotFoundError(networkID string) error { + return Errorf(NotFound, "network: %s not found", networkID) +} + +// NewNetworkRouterNotFoundError creates a new Error with NotFound type for a missing network router. +func NewNetworkRouterNotFoundError(routerID string) error { + return Errorf(NotFound, "network router: %s not found", routerID) +} + +// NewNetworkResourceNotFoundError creates a new Error with NotFound type for a missing network resource. +func NewNetworkResourceNotFoundError(resourceID string) error { + return Errorf(NotFound, "network resource: %s not found", resourceID) +} + +// NewPermissionDeniedError creates a new Error with PermissionDenied type for a permission denied error. +func NewPermissionDeniedError() error { + return Errorf(PermissionDenied, "permission denied") +} + +func NewPermissionValidationError(err error) error { + return Errorf(PermissionDenied, "failed to validate user permissions: %s", err) +} + +func NewResourceNotPartOfNetworkError(resourceID, networkID string) error { + return Errorf(BadRequest, "resource %s is not part of the network %s", resourceID, networkID) +} + +func NewRouterNotPartOfNetworkError(routerID, networkID string) error { + return Errorf(BadRequest, "router %s is not part of the network %s", routerID, networkID) +} diff --git a/management/server/file_store.go b/management/server/store/file_store.go similarity index 89% rename from management/server/file_store.go rename to management/server/store/file_store.go index f375fb990..f40a0392b 100644 --- a/management/server/file_store.go +++ b/management/server/store/file_store.go @@ -1,4 +1,4 @@ -package server +package store import ( "context" @@ -11,9 +11,9 @@ import ( "github.com/rs/xid" log "github.com/sirupsen/logrus" - nbgroup "github.com/netbirdio/netbird/management/server/group" nbpeer "github.com/netbirdio/netbird/management/server/peer" "github.com/netbirdio/netbird/management/server/telemetry" + "github.com/netbirdio/netbird/management/server/types" "github.com/netbirdio/netbird/util" ) @@ -22,7 +22,7 @@ const storeFileName = "store.json" // FileStore represents an account storage backed by a file persisted to disk type FileStore struct { - Accounts map[string]*Account + Accounts map[string]*types.Account SetupKeyID2AccountID map[string]string `json:"-"` PeerKeyID2AccountID map[string]string `json:"-"` PeerID2AccountID map[string]string `json:"-"` @@ -55,7 +55,7 @@ func restore(ctx context.Context, file string) (*FileStore, error) { if _, err := os.Stat(file); os.IsNotExist(err) { // create a new FileStore if previously didn't exist
(e.g. first run) s := &FileStore{ - Accounts: make(map[string]*Account), + Accounts: make(map[string]*types.Account), mux: sync.Mutex{}, SetupKeyID2AccountID: make(map[string]string), PeerKeyID2AccountID: make(map[string]string), @@ -92,12 +92,14 @@ func restore(ctx context.Context, file string) (*FileStore, error) { for accountID, account := range store.Accounts { if account.Settings == nil { - account.Settings = &Settings{ + account.Settings = &types.Settings{ PeerLoginExpirationEnabled: false, - PeerLoginExpiration: DefaultPeerLoginExpiration, + PeerLoginExpiration: types.DefaultPeerLoginExpiration, PeerInactivityExpirationEnabled: false, - PeerInactivityExpiration: DefaultPeerInactivityExpiration, + PeerInactivityExpiration: types.DefaultPeerInactivityExpiration, + + RoutingPeerDNSResolutionEnabled: true, } } @@ -112,7 +114,7 @@ func restore(ctx context.Context, file string) (*FileStore, error) { for _, user := range account.Users { store.UserID2AccountID[user.Id] = accountID if user.Issued == "" { - user.Issued = UserIssuedAPI + user.Issued = types.UserIssuedAPI account.Users[user.Id] = user } @@ -122,7 +124,7 @@ func restore(ctx context.Context, file string) (*FileStore, error) { } } - if account.Domain != "" && account.DomainCategory == PrivateCategory && + if account.Domain != "" && account.DomainCategory == types.PrivateCategory && account.IsDomainPrimaryAccount { store.PrivateDomain2AccountID[account.Domain] = accountID } @@ -134,20 +136,20 @@ func restore(ctx context.Context, file string) (*FileStore, error) { policy.UpgradeAndFix() } if account.Policies == nil { - account.Policies = make([]*Policy, 0) + account.Policies = make([]*types.Policy, 0) } // for data migration. Can be removed once most base will be with labels - existingLabels := account.getPeerDNSLabels() + existingLabels := account.GetPeerDNSLabels() if len(existingLabels) != len(account.Peers) { - addPeerLabelsToAccount(ctx, account, existingLabels) + types.AddPeerLabelsToAccount(ctx, account, existingLabels) } // TODO: delete this block after migration // Set API as issuer for groups which has not this field for _, group := range account.Groups { if group.Issued == "" { - group.Issued = nbgroup.GroupIssuedAPI + group.Issued = types.GroupIssuedAPI } } @@ -236,7 +238,7 @@ func (s *FileStore) persist(ctx context.Context, file string) error { } // GetAllAccounts returns all accounts -func (s *FileStore) GetAllAccounts(_ context.Context) (all []*Account) { +func (s *FileStore) GetAllAccounts(_ context.Context) (all []*types.Account) { s.mux.Lock() defer s.mux.Unlock() for _, a := range s.Accounts { @@ -257,6 +259,6 @@ func (s *FileStore) Close(ctx context.Context) error { } // GetStoreEngine returns FileStoreEngine -func (s *FileStore) GetStoreEngine() StoreEngine { +func (s *FileStore) GetStoreEngine() Engine { return FileStoreEngine } diff --git a/management/server/sql_store.go b/management/server/store/sql_store.go similarity index 76% rename from management/server/sql_store.go rename to management/server/store/sql_store.go index 1fd8ae2aa..62b004f9c 100644 --- a/management/server/sql_store.go +++ b/management/server/store/sql_store.go @@ -1,4 +1,4 @@ -package server +package store import ( "context" @@ -24,11 +24,14 @@ import ( nbdns "github.com/netbirdio/netbird/dns" "github.com/netbirdio/netbird/management/server/account" - nbgroup "github.com/netbirdio/netbird/management/server/group" + resourceTypes "github.com/netbirdio/netbird/management/server/networks/resources/types" + routerTypes 
"github.com/netbirdio/netbird/management/server/networks/routers/types" + networkTypes "github.com/netbirdio/netbird/management/server/networks/types" nbpeer "github.com/netbirdio/netbird/management/server/peer" "github.com/netbirdio/netbird/management/server/posture" "github.com/netbirdio/netbird/management/server/status" "github.com/netbirdio/netbird/management/server/telemetry" + "github.com/netbirdio/netbird/management/server/types" "github.com/netbirdio/netbird/route" ) @@ -49,7 +52,7 @@ type SqlStore struct { globalAccountLock sync.Mutex metrics telemetry.AppMetrics installationPK int - storeEngine StoreEngine + storeEngine Engine } type installation struct { @@ -60,7 +63,7 @@ type installation struct { type migrationFunc func(*gorm.DB) error // NewSqlStore creates a new SqlStore instance. -func NewSqlStore(ctx context.Context, db *gorm.DB, storeEngine StoreEngine, metrics telemetry.AppMetrics) (*SqlStore, error) { +func NewSqlStore(ctx context.Context, db *gorm.DB, storeEngine Engine, metrics telemetry.AppMetrics) (*SqlStore, error) { sql, err := db.DB() if err != nil { return nil, err @@ -86,9 +89,10 @@ func NewSqlStore(ctx context.Context, db *gorm.DB, storeEngine StoreEngine, metr return nil, fmt.Errorf("migrate: %w", err) } err = db.AutoMigrate( - &SetupKey{}, &nbpeer.Peer{}, &User{}, &PersonalAccessToken{}, &nbgroup.Group{}, - &Account{}, &Policy{}, &PolicyRule{}, &route.Route{}, &nbdns.NameServerGroup{}, + &types.SetupKey{}, &nbpeer.Peer{}, &types.User{}, &types.PersonalAccessToken{}, &types.Group{}, + &types.Account{}, &types.Policy{}, &types.PolicyRule{}, &route.Route{}, &nbdns.NameServerGroup{}, &installation{}, &account.ExtraSettings{}, &posture.Checks{}, &nbpeer.NetworkAddress{}, + &networkTypes.Network{}, &routerTypes.NetworkRouter{}, &resourceTypes.NetworkResource{}, ) if err != nil { return nil, fmt.Errorf("auto migrate: %w", err) @@ -151,7 +155,7 @@ func (s *SqlStore) AcquireReadLockByUID(ctx context.Context, uniqueID string) (u return unlock } -func (s *SqlStore) SaveAccount(ctx context.Context, account *Account) error { +func (s *SqlStore) SaveAccount(ctx context.Context, account *types.Account) error { start := time.Now() defer func() { elapsed := time.Since(start) @@ -201,7 +205,7 @@ func (s *SqlStore) SaveAccount(ctx context.Context, account *Account) error { } // generateAccountSQLTypes generates the GORM compatible types for the account -func generateAccountSQLTypes(account *Account) { +func generateAccountSQLTypes(account *types.Account) { for _, key := range account.SetupKeys { account.SetupKeysG = append(account.SetupKeysG, *key) } @@ -238,7 +242,7 @@ func generateAccountSQLTypes(account *Account) { // checkAccountDomainBeforeSave temporary method to troubleshoot an issue with domains getting blank func (s *SqlStore) checkAccountDomainBeforeSave(ctx context.Context, accountID, newDomain string) { - var acc Account + var acc types.Account var domain string result := s.db.Model(&acc).Select("domain").Where(idQueryCondition, accountID).First(&domain) if result.Error != nil { @@ -252,7 +256,7 @@ func (s *SqlStore) checkAccountDomainBeforeSave(ctx context.Context, accountID, } } -func (s *SqlStore) DeleteAccount(ctx context.Context, account *Account) error { +func (s *SqlStore) DeleteAccount(ctx context.Context, account *types.Account) error { start := time.Now() err := s.db.Transaction(func(tx *gorm.DB) error { @@ -333,14 +337,14 @@ func (s *SqlStore) SavePeer(ctx context.Context, accountID string, peer *nbpeer. 
} func (s *SqlStore) UpdateAccountDomainAttributes(ctx context.Context, accountID string, domain string, category string, isPrimaryDomain bool) error { - accountCopy := Account{ + accountCopy := types.Account{ Domain: domain, DomainCategory: category, IsDomainPrimaryAccount: isPrimaryDomain, } fieldsToUpdate := []string{"domain", "domain_category", "is_domain_primary_account"} - result := s.db.Model(&Account{}). + result := s.db.Model(&types.Account{}). Select(fieldsToUpdate). Where(idQueryCondition, accountID). Updates(&accountCopy) @@ -402,8 +406,8 @@ func (s *SqlStore) SavePeerLocation(accountID string, peerWithLocation *nbpeer.P // SaveUsers saves the given list of users to the database. // It updates existing users if a conflict occurs. -func (s *SqlStore) SaveUsers(accountID string, users map[string]*User) error { - usersToSave := make([]User, 0, len(users)) +func (s *SqlStore) SaveUsers(accountID string, users map[string]*types.User) error { + usersToSave := make([]types.User, 0, len(users)) for _, user := range users { user.AccountID = accountID for id, pat := range user.PATs { @@ -423,7 +427,7 @@ func (s *SqlStore) SaveUsers(accountID string, users map[string]*User) error { } // SaveUser saves the given user to the database. -func (s *SqlStore) SaveUser(ctx context.Context, lockStrength LockingStrength, user *User) error { +func (s *SqlStore) SaveUser(ctx context.Context, lockStrength LockingStrength, user *types.User) error { result := s.db.Clauses(clause.Locking{Strength: string(lockStrength)}).Save(user) if result.Error != nil { return status.Errorf(status.Internal, "failed to save user to store: %v", result.Error) @@ -432,7 +436,7 @@ func (s *SqlStore) SaveUser(ctx context.Context, lockStrength LockingStrength, u } // SaveGroups saves the given list of groups to the database. -func (s *SqlStore) SaveGroups(ctx context.Context, lockStrength LockingStrength, groups []*nbgroup.Group) error { +func (s *SqlStore) SaveGroups(ctx context.Context, lockStrength LockingStrength, groups []*types.Group) error { if len(groups) == 0 { return nil } @@ -454,7 +458,7 @@ func (s *SqlStore) DeleteTokenID2UserIDIndex(tokenID string) error { return nil } -func (s *SqlStore) GetAccountByPrivateDomain(ctx context.Context, domain string) (*Account, error) { +func (s *SqlStore) GetAccountByPrivateDomain(ctx context.Context, domain string) (*types.Account, error) { accountID, err := s.GetAccountIDByPrivateDomain(ctx, LockingStrengthShare, domain) if err != nil { return nil, err @@ -466,9 +470,9 @@ func (s *SqlStore) GetAccountByPrivateDomain(ctx context.Context, domain string) func (s *SqlStore) GetAccountIDByPrivateDomain(ctx context.Context, lockStrength LockingStrength, domain string) (string, error) { var accountID string - result := s.db.Clauses(clause.Locking{Strength: string(lockStrength)}).Model(&Account{}).Select("id"). + result := s.db.Clauses(clause.Locking{Strength: string(lockStrength)}).Model(&types.Account{}).Select("id"). Where("domain = ? and is_domain_primary_account = ? 
and domain_category = ?", - strings.ToLower(domain), true, PrivateCategory, + strings.ToLower(domain), true, types.PrivateCategory, ).First(&accountID) if result.Error != nil { if errors.Is(result.Error, gorm.ErrRecordNotFound) { @@ -481,8 +485,8 @@ func (s *SqlStore) GetAccountIDByPrivateDomain(ctx context.Context, lockStrength return accountID, nil } -func (s *SqlStore) GetAccountBySetupKey(ctx context.Context, setupKey string) (*Account, error) { - var key SetupKey +func (s *SqlStore) GetAccountBySetupKey(ctx context.Context, setupKey string) (*types.Account, error) { + var key types.SetupKey result := s.db.Select("account_id").First(&key, keyQueryCondition, setupKey) if result.Error != nil { if errors.Is(result.Error, gorm.ErrRecordNotFound) { @@ -500,7 +504,7 @@ func (s *SqlStore) GetAccountBySetupKey(ctx context.Context, setupKey string) (* } func (s *SqlStore) GetTokenIDByHashedToken(ctx context.Context, hashedToken string) (string, error) { - var token PersonalAccessToken + var token types.PersonalAccessToken result := s.db.First(&token, "hashed_token = ?", hashedToken) if result.Error != nil { if errors.Is(result.Error, gorm.ErrRecordNotFound) { @@ -513,8 +517,8 @@ func (s *SqlStore) GetTokenIDByHashedToken(ctx context.Context, hashedToken stri return token.ID, nil } -func (s *SqlStore) GetUserByTokenID(ctx context.Context, tokenID string) (*User, error) { - var token PersonalAccessToken +func (s *SqlStore) GetUserByTokenID(ctx context.Context, tokenID string) (*types.User, error) { + var token types.PersonalAccessToken result := s.db.First(&token, idQueryCondition, tokenID) if result.Error != nil { if errors.Is(result.Error, gorm.ErrRecordNotFound) { @@ -528,13 +532,13 @@ func (s *SqlStore) GetUserByTokenID(ctx context.Context, tokenID string) (*User, return nil, status.Errorf(status.NotFound, "account not found: index lookup failed") } - var user User + var user types.User result = s.db.Preload("PATsG").First(&user, idQueryCondition, token.UserID) if result.Error != nil { return nil, status.Errorf(status.NotFound, "account not found: index lookup failed") } - user.PATs = make(map[string]*PersonalAccessToken, len(user.PATsG)) + user.PATs = make(map[string]*types.PersonalAccessToken, len(user.PATsG)) for _, pat := range user.PATsG { user.PATs[pat.ID] = pat.Copy() } @@ -542,8 +546,8 @@ func (s *SqlStore) GetUserByTokenID(ctx context.Context, tokenID string) (*User, return &user, nil } -func (s *SqlStore) GetUserByUserID(ctx context.Context, lockStrength LockingStrength, userID string) (*User, error) { - var user User +func (s *SqlStore) GetUserByUserID(ctx context.Context, lockStrength LockingStrength, userID string) (*types.User, error) { + var user types.User result := s.db.Clauses(clause.Locking{Strength: string(lockStrength)}). 
Preload(clause.Associations).First(&user, idQueryCondition, userID) if result.Error != nil { @@ -556,8 +560,8 @@ func (s *SqlStore) GetUserByUserID(ctx context.Context, lockStrength LockingStre return &user, nil } -func (s *SqlStore) GetAccountUsers(ctx context.Context, lockStrength LockingStrength, accountID string) ([]*User, error) { - var users []*User +func (s *SqlStore) GetAccountUsers(ctx context.Context, lockStrength LockingStrength, accountID string) ([]*types.User, error) { + var users []*types.User result := s.db.Clauses(clause.Locking{Strength: string(lockStrength)}).Find(&users, accountIDCondition, accountID) if result.Error != nil { if errors.Is(result.Error, gorm.ErrRecordNotFound) { @@ -570,8 +574,8 @@ func (s *SqlStore) GetAccountUsers(ctx context.Context, lockStrength LockingStre return users, nil } -func (s *SqlStore) GetAccountGroups(ctx context.Context, lockStrength LockingStrength, accountID string) ([]*nbgroup.Group, error) { - var groups []*nbgroup.Group +func (s *SqlStore) GetAccountGroups(ctx context.Context, lockStrength LockingStrength, accountID string) ([]*types.Group, error) { + var groups []*types.Group result := s.db.Clauses(clause.Locking{Strength: string(lockStrength)}).Find(&groups, accountIDCondition, accountID) if result.Error != nil { if errors.Is(result.Error, gorm.ErrRecordNotFound) { @@ -584,8 +588,27 @@ func (s *SqlStore) GetAccountGroups(ctx context.Context, lockStrength LockingStr return groups, nil } -func (s *SqlStore) GetAllAccounts(ctx context.Context) (all []*Account) { - var accounts []Account +func (s *SqlStore) GetResourceGroups(ctx context.Context, lockStrength LockingStrength, accountID, resourceID string) ([]*types.Group, error) { + var groups []*types.Group + + likePattern := `%"ID":"` + resourceID + `"%` + + result := s.db.Clauses(clause.Locking{Strength: string(lockStrength)}). + Where("resources LIKE ?", likePattern). + Find(&groups) + + if result.Error != nil { + if errors.Is(result.Error, gorm.ErrRecordNotFound) { + return nil, nil + } + return nil, result.Error + } + + return groups, nil +} + +func (s *SqlStore) GetAllAccounts(ctx context.Context) (all []*types.Account) { + var accounts []types.Account result := s.db.Find(&accounts) if result.Error != nil { return all @@ -600,7 +623,7 @@ func (s *SqlStore) GetAllAccounts(ctx context.Context) (all []*Account) { return all } -func (s *SqlStore) GetAccount(ctx context.Context, accountID string) (*Account, error) { +func (s *SqlStore) GetAccount(ctx context.Context, accountID string) (*types.Account, error) { start := time.Now() defer func() { elapsed := time.Since(start) @@ -609,7 +632,7 @@ func (s *SqlStore) GetAccount(ctx context.Context, accountID string) (*Account, } }() - var account Account + var account types.Account result := s.db.Model(&account). Preload("UsersG.PATsG"). // have to be specifies as this is nester reference Preload(clause.Associations). 
@@ -624,15 +647,15 @@ func (s *SqlStore) GetAccount(ctx context.Context, accountID string) (*Account, // we have to manually preload policy rules as it seems that gorm preloading doesn't do it for us for i, policy := range account.Policies { - var rules []*PolicyRule - err := s.db.Model(&PolicyRule{}).Find(&rules, "policy_id = ?", policy.ID).Error + var rules []*types.PolicyRule + err := s.db.Model(&types.PolicyRule{}).Find(&rules, "policy_id = ?", policy.ID).Error if err != nil { return nil, status.Errorf(status.NotFound, "rule not found") } account.Policies[i].Rules = rules } - account.SetupKeys = make(map[string]*SetupKey, len(account.SetupKeysG)) + account.SetupKeys = make(map[string]*types.SetupKey, len(account.SetupKeysG)) for _, key := range account.SetupKeysG { account.SetupKeys[key.Key] = key.Copy() } @@ -644,9 +667,9 @@ func (s *SqlStore) GetAccount(ctx context.Context, accountID string) (*Account, } account.PeersG = nil - account.Users = make(map[string]*User, len(account.UsersG)) + account.Users = make(map[string]*types.User, len(account.UsersG)) for _, user := range account.UsersG { - user.PATs = make(map[string]*PersonalAccessToken, len(user.PATs)) + user.PATs = make(map[string]*types.PersonalAccessToken, len(user.PATs)) for _, pat := range user.PATsG { user.PATs[pat.ID] = pat.Copy() } @@ -654,7 +677,7 @@ func (s *SqlStore) GetAccount(ctx context.Context, accountID string) (*Account, } account.UsersG = nil - account.Groups = make(map[string]*nbgroup.Group, len(account.GroupsG)) + account.Groups = make(map[string]*types.Group, len(account.GroupsG)) for _, group := range account.GroupsG { account.Groups[group.ID] = group.Copy() } @@ -675,8 +698,8 @@ func (s *SqlStore) GetAccount(ctx context.Context, accountID string) (*Account, return &account, nil } -func (s *SqlStore) GetAccountByUser(ctx context.Context, userID string) (*Account, error) { - var user User +func (s *SqlStore) GetAccountByUser(ctx context.Context, userID string) (*types.Account, error) { + var user types.User result := s.db.Select("account_id").First(&user, idQueryCondition, userID) if result.Error != nil { if errors.Is(result.Error, gorm.ErrRecordNotFound) { @@ -692,7 +715,7 @@ func (s *SqlStore) GetAccountByUser(ctx context.Context, userID string) (*Accoun return s.GetAccount(ctx, user.AccountID) } -func (s *SqlStore) GetAccountByPeerID(ctx context.Context, peerID string) (*Account, error) { +func (s *SqlStore) GetAccountByPeerID(ctx context.Context, peerID string) (*types.Account, error) { var peer nbpeer.Peer result := s.db.Select("account_id").First(&peer, idQueryCondition, peerID) if result.Error != nil { @@ -709,7 +732,7 @@ func (s *SqlStore) GetAccountByPeerID(ctx context.Context, peerID string) (*Acco return s.GetAccount(ctx, peer.AccountID) } -func (s *SqlStore) GetAccountByPeerPubKey(ctx context.Context, peerKey string) (*Account, error) { +func (s *SqlStore) GetAccountByPeerPubKey(ctx context.Context, peerKey string) (*types.Account, error) { var peer nbpeer.Peer result := s.db.Select("account_id").First(&peer, keyQueryCondition, peerKey) if result.Error != nil { @@ -742,7 +765,7 @@ func (s *SqlStore) GetAccountIDByPeerPubKey(ctx context.Context, peerKey string) func (s *SqlStore) GetAccountIDByUserID(userID string) (string, error) { var accountID string - result := s.db.Model(&User{}).Select("account_id").Where(idQueryCondition, userID).First(&accountID) + result := s.db.Model(&types.User{}).Select("account_id").Where(idQueryCondition, userID).First(&accountID) if result.Error != nil { if 
errors.Is(result.Error, gorm.ErrRecordNotFound) { return "", status.Errorf(status.NotFound, "account not found: index lookup failed") @@ -755,7 +778,7 @@ func (s *SqlStore) GetAccountIDByUserID(userID string) (string, error) { func (s *SqlStore) GetAccountIDBySetupKey(ctx context.Context, setupKey string) (string, error) { var accountID string - result := s.db.Model(&SetupKey{}).Select("account_id").Where(keyQueryCondition, setupKey).First(&accountID) + result := s.db.Model(&types.SetupKey{}).Select("account_id").Where(keyQueryCondition, setupKey).First(&accountID) if result.Error != nil { if errors.Is(result.Error, gorm.ErrRecordNotFound) { return "", status.NewSetupKeyNotFoundError(setupKey) @@ -815,9 +838,9 @@ func (s *SqlStore) GetPeerLabelsInAccount(ctx context.Context, lockStrength Lock return labels, nil } -func (s *SqlStore) GetAccountNetwork(ctx context.Context, lockStrength LockingStrength, accountID string) (*Network, error) { - var accountNetwork AccountNetwork - if err := s.db.Model(&Account{}).Where(idQueryCondition, accountID).First(&accountNetwork).Error; err != nil { +func (s *SqlStore) GetAccountNetwork(ctx context.Context, lockStrength LockingStrength, accountID string) (*types.Network, error) { + var accountNetwork types.AccountNetwork + if err := s.db.Model(&types.Account{}).Where(idQueryCondition, accountID).First(&accountNetwork).Error; err != nil { if errors.Is(err, gorm.ErrRecordNotFound) { return nil, status.NewAccountNotFoundError(accountID) } @@ -839,9 +862,9 @@ func (s *SqlStore) GetPeerByPeerPubKey(ctx context.Context, lockStrength Locking return &peer, nil } -func (s *SqlStore) GetAccountSettings(ctx context.Context, lockStrength LockingStrength, accountID string) (*Settings, error) { - var accountSettings AccountSettings - if err := s.db.Clauses(clause.Locking{Strength: string(lockStrength)}).Model(&Account{}).Where(idQueryCondition, accountID).First(&accountSettings).Error; err != nil { +func (s *SqlStore) GetAccountSettings(ctx context.Context, lockStrength LockingStrength, accountID string) (*types.Settings, error) { + var accountSettings types.AccountSettings + if err := s.db.Clauses(clause.Locking{Strength: string(lockStrength)}).Model(&types.Account{}).Where(idQueryCondition, accountID).First(&accountSettings).Error; err != nil { if errors.Is(err, gorm.ErrRecordNotFound) { return nil, status.Errorf(status.NotFound, "settings not found") } @@ -852,7 +875,7 @@ func (s *SqlStore) GetAccountSettings(ctx context.Context, lockStrength LockingS // SaveUserLastLogin stores the last login time for a user in DB. 
func (s *SqlStore) SaveUserLastLogin(ctx context.Context, accountID, userID string, lastLogin time.Time) error { - var user User + var user types.User result := s.db.First(&user, accountAndIDQueryCondition, accountID, userID) if result.Error != nil { if errors.Is(result.Error, gorm.ErrRecordNotFound) { @@ -890,7 +913,7 @@ func (s *SqlStore) Close(_ context.Context) error { } // GetStoreEngine returns underlying store engine -func (s *SqlStore) GetStoreEngine() StoreEngine { +func (s *SqlStore) GetStoreEngine() Engine { return s.storeEngine } @@ -982,8 +1005,8 @@ func NewPostgresqlStoreFromSqlStore(ctx context.Context, sqliteStore *SqlStore, return store, nil } -func (s *SqlStore) GetSetupKeyBySecret(ctx context.Context, lockStrength LockingStrength, key string) (*SetupKey, error) { - var setupKey SetupKey +func (s *SqlStore) GetSetupKeyBySecret(ctx context.Context, lockStrength LockingStrength, key string) (*types.SetupKey, error) { + var setupKey types.SetupKey result := s.db.Clauses(clause.Locking{Strength: string(lockStrength)}). First(&setupKey, keyQueryCondition, key) if result.Error != nil { @@ -997,7 +1020,7 @@ func (s *SqlStore) GetSetupKeyBySecret(ctx context.Context, lockStrength Locking } func (s *SqlStore) IncrementSetupKeyUsage(ctx context.Context, setupKeyID string) error { - result := s.db.Model(&SetupKey{}). + result := s.db.Model(&types.SetupKey{}). Where(idQueryCondition, setupKeyID). Updates(map[string]interface{}{ "used_times": gorm.Expr("used_times + 1"), @@ -1015,8 +1038,9 @@ func (s *SqlStore) IncrementSetupKeyUsage(ctx context.Context, setupKeyID string return nil } +// AddPeerToAllGroup adds a peer to the 'All' group. Method always needs to run in a transaction func (s *SqlStore) AddPeerToAllGroup(ctx context.Context, accountID string, peerID string) error { - var group nbgroup.Group + var group types.Group result := s.db.Where("account_id = ? AND name = ?", accountID, "All").First(&group) if result.Error != nil { if errors.Is(result.Error, gorm.ErrRecordNotFound) { @@ -1040,8 +1064,9 @@ func (s *SqlStore) AddPeerToAllGroup(ctx context.Context, accountID string, peer return nil } +// AddPeerToGroup adds a peer to a group. Method always needs to run in a transaction func (s *SqlStore) AddPeerToGroup(ctx context.Context, accountId string, peerId string, groupID string) error { - var group nbgroup.Group + var group types.Group result := s.db.Where(accountAndIDQueryCondition, accountId, groupID).First(&group) if result.Error != nil { if errors.Is(result.Error, gorm.ErrRecordNotFound) { @@ -1066,6 +1091,59 @@ func (s *SqlStore) AddPeerToGroup(ctx context.Context, accountId string, peerId return nil } +// AddResourceToGroup adds a resource to a group. 
Method always needs to run in a transaction
+func (s *SqlStore) AddResourceToGroup(ctx context.Context, accountId string, groupID string, resource *types.Resource) error {
+	var group types.Group
+	result := s.db.Where(accountAndIDQueryCondition, accountId, groupID).First(&group)
+	if result.Error != nil {
+		if errors.Is(result.Error, gorm.ErrRecordNotFound) {
+			return status.NewGroupNotFoundError(groupID)
+		}
+
+		return status.Errorf(status.Internal, "issue finding group: %s", result.Error)
+	}
+
+	for _, res := range group.Resources {
+		if res.ID == resource.ID {
+			return nil
+		}
+	}
+
+	group.Resources = append(group.Resources, *resource)
+
+	if err := s.db.Save(&group).Error; err != nil {
+		return status.Errorf(status.Internal, "issue updating group: %s", err)
+	}
+
+	return nil
+}
+
+// RemoveResourceFromGroup removes a resource from a group. Method always needs to run in a transaction
+func (s *SqlStore) RemoveResourceFromGroup(ctx context.Context, accountId string, groupID string, resourceID string) error {
+	var group types.Group
+	result := s.db.Where(accountAndIDQueryCondition, accountId, groupID).First(&group)
+	if result.Error != nil {
+		if errors.Is(result.Error, gorm.ErrRecordNotFound) {
+			return status.NewGroupNotFoundError(groupID)
+		}
+
+		return status.Errorf(status.Internal, "issue finding group: %s", result.Error)
+	}
+
+	for i, res := range group.Resources {
+		if res.ID == resourceID {
+			group.Resources = append(group.Resources[:i], group.Resources[i+1:]...)
+			break
+		}
+	}
+
+	if err := s.db.Save(&group).Error; err != nil {
+		return status.Errorf(status.Internal, "issue updating group: %s", err)
+	}
+
+	return nil
+}
+
 // GetUserPeers retrieves peers for a user.
 func (s *SqlStore) GetUserPeers(ctx context.Context, lockStrength LockingStrength, accountID, userID string) ([]*nbpeer.Peer, error) {
 	return getRecords[*nbpeer.Peer](s.db.Where("user_id = ?", userID), lockStrength, accountID)
@@ -1114,7 +1192,7 @@ func (s *SqlStore) GetPeersByIDs(ctx context.Context, lockStrength LockingStreng
 
 func (s *SqlStore) IncrementNetworkSerial(ctx context.Context, lockStrength LockingStrength, accountId string) error {
 	result := s.db.Clauses(clause.Locking{Strength: string(lockStrength)}).
-		Model(&Account{}).Where(idQueryCondition, accountId).Update("network_serial", gorm.Expr("network_serial + 1"))
+		Model(&types.Account{}).Where(idQueryCondition, accountId).Update("network_serial", gorm.Expr("network_serial + 1"))
 	if result.Error != nil {
 		log.WithContext(ctx).Errorf("failed to increment network serial count in store: %v", result.Error)
 		return status.Errorf(status.Internal, "failed to increment network serial count in store")
@@ -1156,9 +1234,9 @@ func (s *SqlStore) GetDB() *gorm.DB {
 	return s.db
 }
 
-func (s *SqlStore) GetAccountDNSSettings(ctx context.Context, lockStrength LockingStrength, accountID string) (*DNSSettings, error) {
-	var accountDNSSettings AccountDNSSettings
-	result := s.db.Clauses(clause.Locking{Strength: string(lockStrength)}).Model(&Account{}).
+func (s *SqlStore) GetAccountDNSSettings(ctx context.Context, lockStrength LockingStrength, accountID string) (*types.DNSSettings, error) {
+	var accountDNSSettings types.AccountDNSSettings
+	result := s.db.Clauses(clause.Locking{Strength: string(lockStrength)}).Model(&types.Account{}).
First(&accountDNSSettings, idQueryCondition, accountID) if result.Error != nil { if errors.Is(result.Error, gorm.ErrRecordNotFound) { @@ -1173,7 +1251,7 @@ func (s *SqlStore) GetAccountDNSSettings(ctx context.Context, lockStrength Locki // AccountExists checks whether an account exists by the given ID. func (s *SqlStore) AccountExists(ctx context.Context, lockStrength LockingStrength, id string) (bool, error) { var accountID string - result := s.db.Clauses(clause.Locking{Strength: string(lockStrength)}).Model(&Account{}). + result := s.db.Clauses(clause.Locking{Strength: string(lockStrength)}).Model(&types.Account{}). Select("id").First(&accountID, idQueryCondition, id) if result.Error != nil { if errors.Is(result.Error, gorm.ErrRecordNotFound) { @@ -1187,8 +1265,8 @@ func (s *SqlStore) AccountExists(ctx context.Context, lockStrength LockingStreng // GetAccountDomainAndCategory retrieves the Domain and DomainCategory fields for an account based on the given accountID. func (s *SqlStore) GetAccountDomainAndCategory(ctx context.Context, lockStrength LockingStrength, accountID string) (string, string, error) { - var account Account - result := s.db.Clauses(clause.Locking{Strength: string(lockStrength)}).Model(&Account{}).Select("domain", "domain_category"). + var account types.Account + result := s.db.Clauses(clause.Locking{Strength: string(lockStrength)}).Model(&types.Account{}).Select("domain", "domain_category"). Where(idQueryCondition, accountID).First(&account) if result.Error != nil { if errors.Is(result.Error, gorm.ErrRecordNotFound) { @@ -1201,8 +1279,8 @@ func (s *SqlStore) GetAccountDomainAndCategory(ctx context.Context, lockStrength } // GetGroupByID retrieves a group by ID and account ID. -func (s *SqlStore) GetGroupByID(ctx context.Context, lockStrength LockingStrength, accountID, groupID string) (*nbgroup.Group, error) { - var group *nbgroup.Group +func (s *SqlStore) GetGroupByID(ctx context.Context, lockStrength LockingStrength, accountID, groupID string) (*types.Group, error) { + var group *types.Group result := s.db.Clauses(clause.Locking{Strength: string(lockStrength)}).First(&group, accountAndIDQueryCondition, accountID, groupID) if err := result.Error; err != nil { if errors.Is(err, gorm.ErrRecordNotFound) { @@ -1216,8 +1294,8 @@ func (s *SqlStore) GetGroupByID(ctx context.Context, lockStrength LockingStrengt } // GetGroupByName retrieves a group by name and account ID. -func (s *SqlStore) GetGroupByName(ctx context.Context, lockStrength LockingStrength, accountID, groupName string) (*nbgroup.Group, error) { - var group nbgroup.Group +func (s *SqlStore) GetGroupByName(ctx context.Context, lockStrength LockingStrength, accountID, groupName string) (*types.Group, error) { + var group types.Group // TODO: This fix is accepted for now, but if we need to handle this more frequently // we may need to reconsider changing the types. @@ -1240,15 +1318,15 @@ func (s *SqlStore) GetGroupByName(ctx context.Context, lockStrength LockingStren } // GetGroupsByIDs retrieves groups by their IDs and account ID. 
-func (s *SqlStore) GetGroupsByIDs(ctx context.Context, lockStrength LockingStrength, accountID string, groupIDs []string) (map[string]*nbgroup.Group, error) { - var groups []*nbgroup.Group +func (s *SqlStore) GetGroupsByIDs(ctx context.Context, lockStrength LockingStrength, accountID string, groupIDs []string) (map[string]*types.Group, error) { + var groups []*types.Group result := s.db.Clauses(clause.Locking{Strength: string(lockStrength)}).Find(&groups, accountAndIDsQueryCondition, accountID, groupIDs) if result.Error != nil { log.WithContext(ctx).Errorf("failed to get groups by ID's from store: %s", result.Error) return nil, status.Errorf(status.Internal, "failed to get groups by ID's from store") } - groupsMap := make(map[string]*nbgroup.Group) + groupsMap := make(map[string]*types.Group) for _, group := range groups { groupsMap[group.ID] = group } @@ -1257,7 +1335,7 @@ func (s *SqlStore) GetGroupsByIDs(ctx context.Context, lockStrength LockingStren } // SaveGroup saves a group to the store. -func (s *SqlStore) SaveGroup(ctx context.Context, lockStrength LockingStrength, group *nbgroup.Group) error { +func (s *SqlStore) SaveGroup(ctx context.Context, lockStrength LockingStrength, group *types.Group) error { result := s.db.Clauses(clause.Locking{Strength: string(lockStrength)}).Save(group) if result.Error != nil { log.WithContext(ctx).Errorf("failed to save group to store: %v", result.Error) @@ -1269,7 +1347,7 @@ func (s *SqlStore) SaveGroup(ctx context.Context, lockStrength LockingStrength, // DeleteGroup deletes a group from the database. func (s *SqlStore) DeleteGroup(ctx context.Context, lockStrength LockingStrength, accountID, groupID string) error { result := s.db.Clauses(clause.Locking{Strength: string(lockStrength)}). - Delete(&nbgroup.Group{}, accountAndIDQueryCondition, accountID, groupID) + Delete(&types.Group{}, accountAndIDQueryCondition, accountID, groupID) if err := result.Error; err != nil { log.WithContext(ctx).Errorf("failed to delete group from store: %s", result.Error) return status.Errorf(status.Internal, "failed to delete group from store") @@ -1285,7 +1363,7 @@ func (s *SqlStore) DeleteGroup(ctx context.Context, lockStrength LockingStrength // DeleteGroups deletes groups from the database. func (s *SqlStore) DeleteGroups(ctx context.Context, strength LockingStrength, accountID string, groupIDs []string) error { result := s.db.Clauses(clause.Locking{Strength: string(strength)}). - Delete(&nbgroup.Group{}, accountAndIDsQueryCondition, accountID, groupIDs) + Delete(&types.Group{}, accountAndIDsQueryCondition, accountID, groupIDs) if result.Error != nil { log.WithContext(ctx).Errorf("failed to delete groups from store: %v", result.Error) return status.Errorf(status.Internal, "failed to delete groups from store") @@ -1295,8 +1373,8 @@ func (s *SqlStore) DeleteGroups(ctx context.Context, strength LockingStrength, a } // GetAccountPolicies retrieves policies for an account. -func (s *SqlStore) GetAccountPolicies(ctx context.Context, lockStrength LockingStrength, accountID string) ([]*Policy, error) { - var policies []*Policy +func (s *SqlStore) GetAccountPolicies(ctx context.Context, lockStrength LockingStrength, accountID string) ([]*types.Policy, error) { + var policies []*types.Policy result := s.db.Clauses(clause.Locking{Strength: string(lockStrength)}). 
Preload(clause.Associations).Find(&policies, accountIDCondition, accountID) if err := result.Error; err != nil { @@ -1308,8 +1386,8 @@ func (s *SqlStore) GetAccountPolicies(ctx context.Context, lockStrength LockingS } // GetPolicyByID retrieves a policy by its ID and account ID. -func (s *SqlStore) GetPolicyByID(ctx context.Context, lockStrength LockingStrength, accountID, policyID string) (*Policy, error) { - var policy *Policy +func (s *SqlStore) GetPolicyByID(ctx context.Context, lockStrength LockingStrength, accountID, policyID string) (*types.Policy, error) { + var policy *types.Policy result := s.db.Clauses(clause.Locking{Strength: string(lockStrength)}).Preload(clause.Associations). First(&policy, accountAndIDQueryCondition, accountID, policyID) if err := result.Error; err != nil { @@ -1323,7 +1401,7 @@ func (s *SqlStore) GetPolicyByID(ctx context.Context, lockStrength LockingStreng return policy, nil } -func (s *SqlStore) CreatePolicy(ctx context.Context, lockStrength LockingStrength, policy *Policy) error { +func (s *SqlStore) CreatePolicy(ctx context.Context, lockStrength LockingStrength, policy *types.Policy) error { result := s.db.Clauses(clause.Locking{Strength: string(lockStrength)}).Create(policy) if result.Error != nil { log.WithContext(ctx).Errorf("failed to create policy in store: %s", result.Error) @@ -1334,7 +1412,7 @@ func (s *SqlStore) CreatePolicy(ctx context.Context, lockStrength LockingStrengt } // SavePolicy saves a policy to the database. -func (s *SqlStore) SavePolicy(ctx context.Context, lockStrength LockingStrength, policy *Policy) error { +func (s *SqlStore) SavePolicy(ctx context.Context, lockStrength LockingStrength, policy *types.Policy) error { result := s.db.Session(&gorm.Session{FullSaveAssociations: true}). Clauses(clause.Locking{Strength: string(lockStrength)}).Save(policy) if err := result.Error; err != nil { @@ -1346,7 +1424,7 @@ func (s *SqlStore) SavePolicy(ctx context.Context, lockStrength LockingStrength, func (s *SqlStore) DeletePolicy(ctx context.Context, lockStrength LockingStrength, accountID, policyID string) error { result := s.db.Clauses(clause.Locking{Strength: string(lockStrength)}). - Delete(&Policy{}, accountAndIDQueryCondition, accountID, policyID) + Delete(&types.Policy{}, accountAndIDQueryCondition, accountID, policyID) if err := result.Error; err != nil { log.WithContext(ctx).Errorf("failed to delete policy from store: %s", err) return status.Errorf(status.Internal, "failed to delete policy from store") @@ -1442,8 +1520,8 @@ func (s *SqlStore) GetRouteByID(ctx context.Context, lockStrength LockingStrengt } // GetAccountSetupKeys retrieves setup keys for an account. -func (s *SqlStore) GetAccountSetupKeys(ctx context.Context, lockStrength LockingStrength, accountID string) ([]*SetupKey, error) { - var setupKeys []*SetupKey +func (s *SqlStore) GetAccountSetupKeys(ctx context.Context, lockStrength LockingStrength, accountID string) ([]*types.SetupKey, error) { + var setupKeys []*types.SetupKey result := s.db.Clauses(clause.Locking{Strength: string(lockStrength)}). Find(&setupKeys, accountIDCondition, accountID) if err := result.Error; err != nil { @@ -1455,8 +1533,8 @@ func (s *SqlStore) GetAccountSetupKeys(ctx context.Context, lockStrength Locking } // GetSetupKeyByID retrieves a setup key by its ID and account ID. 
-func (s *SqlStore) GetSetupKeyByID(ctx context.Context, lockStrength LockingStrength, accountID, setupKeyID string) (*SetupKey, error) { - var setupKey *SetupKey +func (s *SqlStore) GetSetupKeyByID(ctx context.Context, lockStrength LockingStrength, accountID, setupKeyID string) (*types.SetupKey, error) { + var setupKey *types.SetupKey result := s.db.Clauses(clause.Locking{Strength: string(lockStrength)}). First(&setupKey, accountAndIDQueryCondition, accountID, setupKeyID) if err := result.Error; err != nil { @@ -1471,7 +1549,7 @@ func (s *SqlStore) GetSetupKeyByID(ctx context.Context, lockStrength LockingStre } // SaveSetupKey saves a setup key to the database. -func (s *SqlStore) SaveSetupKey(ctx context.Context, lockStrength LockingStrength, setupKey *SetupKey) error { +func (s *SqlStore) SaveSetupKey(ctx context.Context, lockStrength LockingStrength, setupKey *types.SetupKey) error { result := s.db.Clauses(clause.Locking{Strength: string(lockStrength)}).Save(setupKey) if result.Error != nil { log.WithContext(ctx).Errorf("failed to save setup key to store: %s", result.Error) @@ -1483,7 +1561,7 @@ func (s *SqlStore) SaveSetupKey(ctx context.Context, lockStrength LockingStrengt // DeleteSetupKey deletes a setup key from the database. func (s *SqlStore) DeleteSetupKey(ctx context.Context, lockStrength LockingStrength, accountID, keyID string) error { - result := s.db.Clauses(clause.Locking{Strength: string(lockStrength)}).Delete(&SetupKey{}, accountAndIDQueryCondition, accountID, keyID) + result := s.db.Clauses(clause.Locking{Strength: string(lockStrength)}).Delete(&types.SetupKey{}, accountAndIDQueryCondition, accountID, keyID) if result.Error != nil { log.WithContext(ctx).Errorf("failed to delete setup key from store: %s", result.Error) return status.Errorf(status.Internal, "failed to delete setup key from store") @@ -1583,9 +1661,9 @@ func getRecordByID[T any](db *gorm.DB, lockStrength LockingStrength, recordID, a } // SaveDNSSettings saves the DNS settings to the store. -func (s *SqlStore) SaveDNSSettings(ctx context.Context, lockStrength LockingStrength, accountID string, settings *DNSSettings) error { - result := s.db.Clauses(clause.Locking{Strength: string(lockStrength)}).Model(&Account{}). - Where(idQueryCondition, accountID).Updates(&AccountDNSSettings{DNSSettings: *settings}) +func (s *SqlStore) SaveDNSSettings(ctx context.Context, lockStrength LockingStrength, accountID string, settings *types.DNSSettings) error { + result := s.db.Clauses(clause.Locking{Strength: string(lockStrength)}).Model(&types.Account{}). 
+ Where(idQueryCondition, accountID).Updates(&types.AccountDNSSettings{DNSSettings: *settings}) if result.Error != nil { log.WithContext(ctx).Errorf("failed to save dns settings to store: %v", result.Error) return status.Errorf(status.Internal, "failed to save dns settings to store") @@ -1597,3 +1675,198 @@ func (s *SqlStore) SaveDNSSettings(ctx context.Context, lockStrength LockingStre return nil } + +func (s *SqlStore) GetAccountNetworks(ctx context.Context, lockStrength LockingStrength, accountID string) ([]*networkTypes.Network, error) { + var networks []*networkTypes.Network + result := s.db.Clauses(clause.Locking{Strength: string(lockStrength)}).Find(&networks, accountIDCondition, accountID) + if result.Error != nil { + log.WithContext(ctx).Errorf("failed to get networks from the store: %s", result.Error) + return nil, status.Errorf(status.Internal, "failed to get networks from store") + } + + return networks, nil +} + +func (s *SqlStore) GetNetworkByID(ctx context.Context, lockStrength LockingStrength, accountID, networkID string) (*networkTypes.Network, error) { + var network *networkTypes.Network + result := s.db.Clauses(clause.Locking{Strength: string(lockStrength)}). + First(&network, accountAndIDQueryCondition, accountID, networkID) + if result.Error != nil { + if errors.Is(result.Error, gorm.ErrRecordNotFound) { + return nil, status.NewNetworkNotFoundError(networkID) + } + + log.WithContext(ctx).Errorf("failed to get network from store: %v", result.Error) + return nil, status.Errorf(status.Internal, "failed to get network from store") + } + + return network, nil +} + +func (s *SqlStore) SaveNetwork(ctx context.Context, lockStrength LockingStrength, network *networkTypes.Network) error { + result := s.db.Clauses(clause.Locking{Strength: string(lockStrength)}).Save(network) + if result.Error != nil { + log.WithContext(ctx).Errorf("failed to save network to store: %v", result.Error) + return status.Errorf(status.Internal, "failed to save network to store") + } + + return nil +} + +func (s *SqlStore) DeleteNetwork(ctx context.Context, lockStrength LockingStrength, accountID, networkID string) error { + result := s.db.Clauses(clause.Locking{Strength: string(lockStrength)}). + Delete(&networkTypes.Network{}, accountAndIDQueryCondition, accountID, networkID) + if result.Error != nil { + log.WithContext(ctx).Errorf("failed to delete network from store: %v", result.Error) + return status.Errorf(status.Internal, "failed to delete network from store") + } + + if result.RowsAffected == 0 { + return status.NewNetworkNotFoundError(networkID) + } + + return nil +} + +func (s *SqlStore) GetNetworkRoutersByNetID(ctx context.Context, lockStrength LockingStrength, accountID, netID string) ([]*routerTypes.NetworkRouter, error) { + var netRouters []*routerTypes.NetworkRouter + result := s.db.Clauses(clause.Locking{Strength: string(lockStrength)}). + Find(&netRouters, "account_id = ? AND network_id = ?", accountID, netID) + if result.Error != nil { + log.WithContext(ctx).Errorf("failed to get network routers from store: %v", result.Error) + return nil, status.Errorf(status.Internal, "failed to get network routers from store") + } + + return netRouters, nil +} + +func (s *SqlStore) GetNetworkRoutersByAccountID(ctx context.Context, lockStrength LockingStrength, accountID string) ([]*routerTypes.NetworkRouter, error) { + var netRouters []*routerTypes.NetworkRouter + result := s.db.Clauses(clause.Locking{Strength: string(lockStrength)}). 
+ Find(&netRouters, accountIDCondition, accountID) + if result.Error != nil { + log.WithContext(ctx).Errorf("failed to get network routers from store: %v", result.Error) + return nil, status.Errorf(status.Internal, "failed to get network routers from store") + } + + return netRouters, nil +} + +func (s *SqlStore) GetNetworkRouterByID(ctx context.Context, lockStrength LockingStrength, accountID, routerID string) (*routerTypes.NetworkRouter, error) { + var netRouter *routerTypes.NetworkRouter + result := s.db.Clauses(clause.Locking{Strength: string(lockStrength)}). + First(&netRouter, accountAndIDQueryCondition, accountID, routerID) + if result.Error != nil { + if errors.Is(result.Error, gorm.ErrRecordNotFound) { + return nil, status.NewNetworkRouterNotFoundError(routerID) + } + log.WithContext(ctx).Errorf("failed to get network router from store: %v", result.Error) + return nil, status.Errorf(status.Internal, "failed to get network router from store") + } + + return netRouter, nil +} + +func (s *SqlStore) SaveNetworkRouter(ctx context.Context, lockStrength LockingStrength, router *routerTypes.NetworkRouter) error { + result := s.db.Clauses(clause.Locking{Strength: string(lockStrength)}).Save(router) + if result.Error != nil { + log.WithContext(ctx).Errorf("failed to save network router to store: %v", result.Error) + return status.Errorf(status.Internal, "failed to save network router to store") + } + + return nil +} + +func (s *SqlStore) DeleteNetworkRouter(ctx context.Context, lockStrength LockingStrength, accountID, routerID string) error { + result := s.db.Clauses(clause.Locking{Strength: string(lockStrength)}). + Delete(&routerTypes.NetworkRouter{}, accountAndIDQueryCondition, accountID, routerID) + if result.Error != nil { + log.WithContext(ctx).Errorf("failed to delete network router from store: %v", result.Error) + return status.Errorf(status.Internal, "failed to delete network router from store") + } + + if result.RowsAffected == 0 { + return status.NewNetworkRouterNotFoundError(routerID) + } + + return nil +} + +func (s *SqlStore) GetNetworkResourcesByNetID(ctx context.Context, lockStrength LockingStrength, accountID, networkID string) ([]*resourceTypes.NetworkResource, error) { + var netResources []*resourceTypes.NetworkResource + result := s.db.Clauses(clause.Locking{Strength: string(lockStrength)}). + Find(&netResources, "account_id = ? AND network_id = ?", accountID, networkID) + if result.Error != nil { + log.WithContext(ctx).Errorf("failed to get network resources from store: %v", result.Error) + return nil, status.Errorf(status.Internal, "failed to get network resources from store") + } + + return netResources, nil +} + +func (s *SqlStore) GetNetworkResourcesByAccountID(ctx context.Context, lockStrength LockingStrength, accountID string) ([]*resourceTypes.NetworkResource, error) { + var netResources []*resourceTypes.NetworkResource + result := s.db.Clauses(clause.Locking{Strength: string(lockStrength)}). 
+ Find(&netResources, accountIDCondition, accountID) + if result.Error != nil { + log.WithContext(ctx).Errorf("failed to get network resources from store: %v", result.Error) + return nil, status.Errorf(status.Internal, "failed to get network resources from store") + } + + return netResources, nil +} + +func (s *SqlStore) GetNetworkResourceByID(ctx context.Context, lockStrength LockingStrength, accountID, resourceID string) (*resourceTypes.NetworkResource, error) { + var netResources *resourceTypes.NetworkResource + result := s.db.Clauses(clause.Locking{Strength: string(lockStrength)}). + First(&netResources, accountAndIDQueryCondition, accountID, resourceID) + if result.Error != nil { + if errors.Is(result.Error, gorm.ErrRecordNotFound) { + return nil, status.NewNetworkResourceNotFoundError(resourceID) + } + log.WithContext(ctx).Errorf("failed to get network resource from store: %v", result.Error) + return nil, status.Errorf(status.Internal, "failed to get network resource from store") + } + + return netResources, nil +} + +func (s *SqlStore) GetNetworkResourceByName(ctx context.Context, lockStrength LockingStrength, accountID, resourceName string) (*resourceTypes.NetworkResource, error) { + var netResources *resourceTypes.NetworkResource + result := s.db.Clauses(clause.Locking{Strength: string(lockStrength)}). + First(&netResources, "account_id = ? AND name = ?", accountID, resourceName) + if result.Error != nil { + if errors.Is(result.Error, gorm.ErrRecordNotFound) { + return nil, status.NewNetworkResourceNotFoundError(resourceName) + } + log.WithContext(ctx).Errorf("failed to get network resource from store: %v", result.Error) + return nil, status.Errorf(status.Internal, "failed to get network resource from store") + } + + return netResources, nil +} + +func (s *SqlStore) SaveNetworkResource(ctx context.Context, lockStrength LockingStrength, resource *resourceTypes.NetworkResource) error { + result := s.db.Clauses(clause.Locking{Strength: string(lockStrength)}).Save(resource) + if result.Error != nil { + log.WithContext(ctx).Errorf("failed to save network resource to store: %v", result.Error) + return status.Errorf(status.Internal, "failed to save network resource to store") + } + + return nil +} + +func (s *SqlStore) DeleteNetworkResource(ctx context.Context, lockStrength LockingStrength, accountID, resourceID string) error { + result := s.db.Clauses(clause.Locking{Strength: string(lockStrength)}). 
+ Delete(&resourceTypes.NetworkResource{}, accountAndIDQueryCondition, accountID, resourceID) + if result.Error != nil { + log.WithContext(ctx).Errorf("failed to delete network resource from store: %v", result.Error) + return status.Errorf(status.Internal, "failed to delete network resource from store") + } + + if result.RowsAffected == 0 { + return status.NewNetworkResourceNotFoundError(resourceID) + } + + return nil +} diff --git a/management/server/sql_store_test.go b/management/server/store/sql_store_test.go similarity index 75% rename from management/server/sql_store_test.go rename to management/server/store/sql_store_test.go index 6064b019f..845bc8fd4 100644 --- a/management/server/sql_store_test.go +++ b/management/server/store/sql_store_test.go @@ -1,4 +1,4 @@ -package server +package store import ( "context" @@ -14,17 +14,24 @@ import ( "time" "github.com/google/uuid" - nbdns "github.com/netbirdio/netbird/dns" - nbgroup "github.com/netbirdio/netbird/management/server/group" - "github.com/netbirdio/netbird/management/server/posture" + "github.com/rs/xid" + log "github.com/sirupsen/logrus" "github.com/stretchr/testify/assert" "github.com/stretchr/testify/require" + nbdns "github.com/netbirdio/netbird/dns" + resourceTypes "github.com/netbirdio/netbird/management/server/networks/resources/types" + routerTypes "github.com/netbirdio/netbird/management/server/networks/routers/types" + networkTypes "github.com/netbirdio/netbird/management/server/networks/types" + "github.com/netbirdio/netbird/management/server/posture" + "github.com/netbirdio/netbird/management/server/types" + route2 "github.com/netbirdio/netbird/route" "github.com/netbirdio/netbird/management/server/status" nbpeer "github.com/netbirdio/netbird/management/server/peer" + nbroute "github.com/netbirdio/netbird/route" ) func TestSqlite_NewStore(t *testing.T) { @@ -73,7 +80,7 @@ func runLargeTest(t *testing.T, store Store) { if err != nil { t.Fatal(err) } - setupKey, _ := GenerateDefaultSetupKey() + setupKey, _ := types.GenerateDefaultSetupKey() account.SetupKeys[setupKey.Key] = setupKey const numPerAccount = 6000 for n := 0; n < numPerAccount; n++ { @@ -86,14 +93,14 @@ func runLargeTest(t *testing.T, store Store) { IP: netIP, Name: peerID, DNSLabel: peerID, - UserID: userID, + UserID: "testuser", Status: &nbpeer.PeerStatus{Connected: false, LastSeen: time.Now()}, SSHEnabled: false, } account.Peers[peerID] = peer group, _ := account.GetGroupAll() group.Peers = append(group.Peers, peerID) - user := &User{ + user := &types.User{ Id: fmt.Sprintf("%s-user-%d", account.Id, n), AccountID: account.Id, } @@ -111,7 +118,7 @@ func runLargeTest(t *testing.T, store Store) { } account.Routes[route.ID] = route - group = &nbgroup.Group{ + group = &types.Group{ ID: fmt.Sprintf("group-id-%d", n), AccountID: account.Id, Name: fmt.Sprintf("group-id-%d", n), @@ -134,7 +141,7 @@ func runLargeTest(t *testing.T, store Store) { } account.NameServerGroups[nameserver.ID] = nameserver - setupKey, _ := GenerateDefaultSetupKey() + setupKey, _ := types.GenerateDefaultSetupKey() account.SetupKeys[setupKey.Key] = setupKey } @@ -216,7 +223,7 @@ func TestSqlite_SaveAccount(t *testing.T) { assert.NoError(t, err) account := newAccountWithId(context.Background(), "account_id", "testuser", "") - setupKey, _ := GenerateDefaultSetupKey() + setupKey, _ := types.GenerateDefaultSetupKey() account.SetupKeys[setupKey.Key] = setupKey account.Peers["testpeer"] = &nbpeer.Peer{ Key: "peerkey", @@ -230,7 +237,7 @@ func TestSqlite_SaveAccount(t *testing.T) { require.NoError(t, 
err) account2 := newAccountWithId(context.Background(), "account_id2", "testuser2", "") - setupKey, _ = GenerateDefaultSetupKey() + setupKey, _ = types.GenerateDefaultSetupKey() account2.SetupKeys[setupKey.Key] = setupKey account2.Peers["testpeer2"] = &nbpeer.Peer{ Key: "peerkey2", @@ -289,14 +296,14 @@ func TestSqlite_DeleteAccount(t *testing.T) { assert.NoError(t, err) testUserID := "testuser" - user := NewAdminUser(testUserID) - user.PATs = map[string]*PersonalAccessToken{"testtoken": { + user := types.NewAdminUser(testUserID) + user.PATs = map[string]*types.PersonalAccessToken{"testtoken": { ID: "testtoken", Name: "test token", }} account := newAccountWithId(context.Background(), "account_id", testUserID, "") - setupKey, _ := GenerateDefaultSetupKey() + setupKey, _ := types.GenerateDefaultSetupKey() account.SetupKeys[setupKey.Key] = setupKey account.Peers["testpeer"] = &nbpeer.Peer{ Key: "peerkey", @@ -306,6 +313,35 @@ func TestSqlite_DeleteAccount(t *testing.T) { Status: &nbpeer.PeerStatus{Connected: true, LastSeen: time.Now().UTC()}, } account.Users[testUserID] = user + account.Networks = []*networkTypes.Network{ + { + ID: "network_id", + AccountID: account.Id, + Name: "network name", + Description: "network description", + }, + } + account.NetworkRouters = []*routerTypes.NetworkRouter{ + { + ID: "router_id", + NetworkID: account.Networks[0].ID, + AccountID: account.Id, + PeerGroups: []string{"group_id"}, + Masquerade: true, + Metric: 1, + }, + } + account.NetworkResources = []*resourceTypes.NetworkResource{ + { + ID: "resource_id", + NetworkID: account.Networks[0].ID, + AccountID: account.Id, + Name: "Name", + Description: "Description", + Type: "Domain", + Address: "example.com", + }, + } err = store.SaveAccount(context.Background(), account) require.NoError(t, err) @@ -337,21 +373,30 @@ func TestSqlite_DeleteAccount(t *testing.T) { require.Error(t, err, "expecting error after removing DeleteAccount when getting account by id") for _, policy := range account.Policies { - var rules []*PolicyRule - err = store.(*SqlStore).db.Model(&PolicyRule{}).Find(&rules, "policy_id = ?", policy.ID).Error + var rules []*types.PolicyRule + err = store.(*SqlStore).db.Model(&types.PolicyRule{}).Find(&rules, "policy_id = ?", policy.ID).Error require.NoError(t, err, "expecting no error after removing DeleteAccount when searching for policy rules") require.Len(t, rules, 0, "expecting no policy rules to be found after removing DeleteAccount") } for _, accountUser := range account.Users { - var pats []*PersonalAccessToken - err = store.(*SqlStore).db.Model(&PersonalAccessToken{}).Find(&pats, "user_id = ?", accountUser.Id).Error + var pats []*types.PersonalAccessToken + err = store.(*SqlStore).db.Model(&types.PersonalAccessToken{}).Find(&pats, "user_id = ?", accountUser.Id).Error require.NoError(t, err, "expecting no error after removing DeleteAccount when searching for personal access token") require.Len(t, pats, 0, "expecting no personal access token to be found after removing DeleteAccount") } + for _, network := range account.Networks { + routers, err := store.GetNetworkRoutersByNetID(context.Background(), LockingStrengthShare, account.Id, network.ID) + require.NoError(t, err, "expecting no error after removing DeleteAccount when searching for network routers") + require.Len(t, routers, 0, "expecting no network routers to be found after DeleteAccount") + + resources, err := store.GetNetworkResourcesByNetID(context.Background(), LockingStrengthShare, account.Id, network.ID) + require.NoError(t, err, 
"expecting no error after removing DeleteAccount when searching for network resources") + require.Len(t, resources, 0, "expecting no network resources to be found after DeleteAccount") + } } func TestSqlite_GetAccount(t *testing.T) { @@ -360,7 +405,7 @@ func TestSqlite_GetAccount(t *testing.T) { } t.Setenv("NETBIRD_STORE_ENGINE", string(SqliteStoreEngine)) - store, cleanUp, err := NewTestStoreFromSQL(context.Background(), "testdata/store.sql", t.TempDir()) + store, cleanUp, err := NewTestStoreFromSQL(context.Background(), "../testdata/store.sql", t.TempDir()) t.Cleanup(cleanUp) assert.NoError(t, err) @@ -383,7 +428,7 @@ func TestSqlite_SavePeer(t *testing.T) { } t.Setenv("NETBIRD_STORE_ENGINE", string(SqliteStoreEngine)) - store, cleanUp, err := NewTestStoreFromSQL(context.Background(), "testdata/store.sql", t.TempDir()) + store, cleanUp, err := NewTestStoreFromSQL(context.Background(), "../testdata/store.sql", t.TempDir()) t.Cleanup(cleanUp) assert.NoError(t, err) @@ -433,7 +478,7 @@ func TestSqlite_SavePeerStatus(t *testing.T) { } t.Setenv("NETBIRD_STORE_ENGINE", string(SqliteStoreEngine)) - store, cleanUp, err := NewTestStoreFromSQL(context.Background(), "testdata/store.sql", t.TempDir()) + store, cleanUp, err := NewTestStoreFromSQL(context.Background(), "../testdata/store.sql", t.TempDir()) t.Cleanup(cleanUp) assert.NoError(t, err) @@ -488,7 +533,7 @@ func TestSqlite_SavePeerLocation(t *testing.T) { } t.Setenv("NETBIRD_STORE_ENGINE", string(SqliteStoreEngine)) - store, cleanUp, err := NewTestStoreFromSQL(context.Background(), "testdata/store.sql", t.TempDir()) + store, cleanUp, err := NewTestStoreFromSQL(context.Background(), "../testdata/store.sql", t.TempDir()) t.Cleanup(cleanUp) assert.NoError(t, err) @@ -542,7 +587,7 @@ func TestSqlite_TestGetAccountByPrivateDomain(t *testing.T) { } t.Setenv("NETBIRD_STORE_ENGINE", string(SqliteStoreEngine)) - store, cleanUp, err := NewTestStoreFromSQL(context.Background(), "testdata/store.sql", t.TempDir()) + store, cleanUp, err := NewTestStoreFromSQL(context.Background(), "../testdata/store.sql", t.TempDir()) t.Cleanup(cleanUp) assert.NoError(t, err) @@ -565,7 +610,7 @@ func TestSqlite_GetTokenIDByHashedToken(t *testing.T) { } t.Setenv("NETBIRD_STORE_ENGINE", string(SqliteStoreEngine)) - store, cleanUp, err := NewTestStoreFromSQL(context.Background(), "testdata/store.sql", t.TempDir()) + store, cleanUp, err := NewTestStoreFromSQL(context.Background(), "../testdata/store.sql", t.TempDir()) t.Cleanup(cleanUp) assert.NoError(t, err) @@ -589,7 +634,7 @@ func TestSqlite_GetUserByTokenID(t *testing.T) { } t.Setenv("NETBIRD_STORE_ENGINE", string(SqliteStoreEngine)) - store, cleanUp, err := NewTestStoreFromSQL(context.Background(), "testdata/store.sql", t.TempDir()) + store, cleanUp, err := NewTestStoreFromSQL(context.Background(), "../testdata/store.sql", t.TempDir()) t.Cleanup(cleanUp) assert.NoError(t, err) @@ -625,7 +670,7 @@ func TestMigrate(t *testing.T) { require.NoError(t, err, "Failed to parse CIDR") type network struct { - Network + types.Network Net net.IPNet `gorm:"serializer:gob"` } @@ -640,7 +685,7 @@ func TestMigrate(t *testing.T) { } type account struct { - Account + types.Account Network *network `gorm:"embedded;embeddedPrefix:network_"` Peers []peer `gorm:"foreignKey:AccountID;references:id"` } @@ -700,23 +745,10 @@ func TestMigrate(t *testing.T) { } -func newSqliteStore(t *testing.T) *SqlStore { - t.Helper() - - store, err := NewSqliteStore(context.Background(), t.TempDir(), nil) - t.Cleanup(func() { - store.Close(context.Background()) 
- }) - require.NoError(t, err) - require.NotNil(t, store) - - return store -} - func newAccount(store Store, id int) error { str := fmt.Sprintf("%s-%d", uuid.New().String(), id) account := newAccountWithId(context.Background(), str, str+"-testuser", "example.com") - setupKey, _ := GenerateDefaultSetupKey() + setupKey, _ := types.GenerateDefaultSetupKey() account.SetupKeys[setupKey.Key] = setupKey account.Peers["p"+str] = &nbpeer.Peer{ Key: "peerkey" + str, @@ -755,7 +787,7 @@ func TestPostgresql_SaveAccount(t *testing.T) { assert.NoError(t, err) account := newAccountWithId(context.Background(), "account_id", "testuser", "") - setupKey, _ := GenerateDefaultSetupKey() + setupKey, _ := types.GenerateDefaultSetupKey() account.SetupKeys[setupKey.Key] = setupKey account.Peers["testpeer"] = &nbpeer.Peer{ Key: "peerkey", @@ -769,7 +801,7 @@ func TestPostgresql_SaveAccount(t *testing.T) { require.NoError(t, err) account2 := newAccountWithId(context.Background(), "account_id2", "testuser2", "") - setupKey, _ = GenerateDefaultSetupKey() + setupKey, _ = types.GenerateDefaultSetupKey() account2.SetupKeys[setupKey.Key] = setupKey account2.Peers["testpeer2"] = &nbpeer.Peer{ Key: "peerkey2", @@ -828,14 +860,14 @@ func TestPostgresql_DeleteAccount(t *testing.T) { assert.NoError(t, err) testUserID := "testuser" - user := NewAdminUser(testUserID) - user.PATs = map[string]*PersonalAccessToken{"testtoken": { + user := types.NewAdminUser(testUserID) + user.PATs = map[string]*types.PersonalAccessToken{"testtoken": { ID: "testtoken", Name: "test token", }} account := newAccountWithId(context.Background(), "account_id", testUserID, "") - setupKey, _ := GenerateDefaultSetupKey() + setupKey, _ := types.GenerateDefaultSetupKey() account.SetupKeys[setupKey.Key] = setupKey account.Peers["testpeer"] = &nbpeer.Peer{ Key: "peerkey", @@ -876,16 +908,16 @@ func TestPostgresql_DeleteAccount(t *testing.T) { require.Error(t, err, "expecting error after removing DeleteAccount when getting account by id") for _, policy := range account.Policies { - var rules []*PolicyRule - err = store.(*SqlStore).db.Model(&PolicyRule{}).Find(&rules, "policy_id = ?", policy.ID).Error + var rules []*types.PolicyRule + err = store.(*SqlStore).db.Model(&types.PolicyRule{}).Find(&rules, "policy_id = ?", policy.ID).Error require.NoError(t, err, "expecting no error after removing DeleteAccount when searching for policy rules") require.Len(t, rules, 0, "expecting no policy rules to be found after removing DeleteAccount") } for _, accountUser := range account.Users { - var pats []*PersonalAccessToken - err = store.(*SqlStore).db.Model(&PersonalAccessToken{}).Find(&pats, "user_id = ?", accountUser.Id).Error + var pats []*types.PersonalAccessToken + err = store.(*SqlStore).db.Model(&types.PersonalAccessToken{}).Find(&pats, "user_id = ?", accountUser.Id).Error require.NoError(t, err, "expecting no error after removing DeleteAccount when searching for personal access token") require.Len(t, pats, 0, "expecting no personal access token to be found after removing DeleteAccount") @@ -899,7 +931,7 @@ func TestPostgresql_SavePeerStatus(t *testing.T) { } t.Setenv("NETBIRD_STORE_ENGINE", string(PostgresStoreEngine)) - store, cleanUp, err := NewTestStoreFromSQL(context.Background(), "testdata/store.sql", t.TempDir()) + store, cleanUp, err := NewTestStoreFromSQL(context.Background(), "../testdata/store.sql", t.TempDir()) t.Cleanup(cleanUp) assert.NoError(t, err) @@ -940,7 +972,7 @@ func TestPostgresql_TestGetAccountByPrivateDomain(t *testing.T) { } 
t.Setenv("NETBIRD_STORE_ENGINE", string(PostgresStoreEngine)) - store, cleanUp, err := NewTestStoreFromSQL(context.Background(), "testdata/store.sql", t.TempDir()) + store, cleanUp, err := NewTestStoreFromSQL(context.Background(), "../testdata/store.sql", t.TempDir()) t.Cleanup(cleanUp) assert.NoError(t, err) @@ -960,7 +992,7 @@ func TestPostgresql_GetTokenIDByHashedToken(t *testing.T) { } t.Setenv("NETBIRD_STORE_ENGINE", string(PostgresStoreEngine)) - store, cleanUp, err := NewTestStoreFromSQL(context.Background(), "testdata/store.sql", t.TempDir()) + store, cleanUp, err := NewTestStoreFromSQL(context.Background(), "../testdata/store.sql", t.TempDir()) t.Cleanup(cleanUp) assert.NoError(t, err) @@ -978,7 +1010,7 @@ func TestPostgresql_GetUserByTokenID(t *testing.T) { } t.Setenv("NETBIRD_STORE_ENGINE", string(PostgresStoreEngine)) - store, cleanUp, err := NewTestStoreFromSQL(context.Background(), "testdata/store.sql", t.TempDir()) + store, cleanUp, err := NewTestStoreFromSQL(context.Background(), "../testdata/store.sql", t.TempDir()) t.Cleanup(cleanUp) assert.NoError(t, err) @@ -991,7 +1023,7 @@ func TestPostgresql_GetUserByTokenID(t *testing.T) { func TestSqlite_GetTakenIPs(t *testing.T) { t.Setenv("NETBIRD_STORE_ENGINE", string(SqliteStoreEngine)) - store, cleanup, err := NewTestStoreFromSQL(context.Background(), "testdata/extended-store.sql", t.TempDir()) + store, cleanup, err := NewTestStoreFromSQL(context.Background(), "../testdata/extended-store.sql", t.TempDir()) defer cleanup() if err != nil { t.Fatal(err) @@ -1036,7 +1068,7 @@ func TestSqlite_GetTakenIPs(t *testing.T) { func TestSqlite_GetPeerLabelsInAccount(t *testing.T) { t.Setenv("NETBIRD_STORE_ENGINE", string(SqliteStoreEngine)) - store, cleanup, err := NewTestStoreFromSQL(context.Background(), "testdata/extended-store.sql", t.TempDir()) + store, cleanup, err := NewTestStoreFromSQL(context.Background(), "../testdata/extended-store.sql", t.TempDir()) if err != nil { return } @@ -1078,7 +1110,7 @@ func TestSqlite_GetPeerLabelsInAccount(t *testing.T) { func TestSqlite_GetAccountNetwork(t *testing.T) { t.Setenv("NETBIRD_STORE_ENGINE", string(SqliteStoreEngine)) - store, cleanup, err := NewTestStoreFromSQL(context.Background(), "testdata/extended-store.sql", t.TempDir()) + store, cleanup, err := NewTestStoreFromSQL(context.Background(), "../testdata/extended-store.sql", t.TempDir()) t.Cleanup(cleanup) if err != nil { t.Fatal(err) @@ -1101,7 +1133,7 @@ func TestSqlite_GetAccountNetwork(t *testing.T) { func TestSqlite_GetSetupKeyBySecret(t *testing.T) { t.Setenv("NETBIRD_STORE_ENGINE", string(SqliteStoreEngine)) - store, cleanup, err := NewTestStoreFromSQL(context.Background(), "testdata/extended-store.sql", t.TempDir()) + store, cleanup, err := NewTestStoreFromSQL(context.Background(), "../testdata/extended-store.sql", t.TempDir()) t.Cleanup(cleanup) if err != nil { t.Fatal(err) @@ -1119,14 +1151,14 @@ func TestSqlite_GetSetupKeyBySecret(t *testing.T) { setupKey, err := store.GetSetupKeyBySecret(context.Background(), LockingStrengthShare, encodedHashedKey) require.NoError(t, err) assert.Equal(t, encodedHashedKey, setupKey.Key) - assert.Equal(t, hiddenKey(plainKey, 4), setupKey.KeySecret) + assert.Equal(t, types.HiddenKey(plainKey, 4), setupKey.KeySecret) assert.Equal(t, "bf1c8084-ba50-4ce7-9439-34653001fc3b", setupKey.AccountID) assert.Equal(t, "Default key", setupKey.Name) } func TestSqlite_incrementSetupKeyUsage(t *testing.T) { t.Setenv("NETBIRD_STORE_ENGINE", string(SqliteStoreEngine)) - store, cleanup, err := 
NewTestStoreFromSQL(context.Background(), "testdata/extended-store.sql", t.TempDir()) + store, cleanup, err := NewTestStoreFromSQL(context.Background(), "../testdata/extended-store.sql", t.TempDir()) t.Cleanup(cleanup) if err != nil { t.Fatal(err) @@ -1162,13 +1194,13 @@ func TestSqlite_incrementSetupKeyUsage(t *testing.T) { func TestSqlite_CreateAndGetObjectInTransaction(t *testing.T) { t.Setenv("NETBIRD_STORE_ENGINE", string(SqliteStoreEngine)) - store, cleanup, err := NewTestStoreFromSQL(context.Background(), "testdata/extended-store.sql", t.TempDir()) + store, cleanup, err := NewTestStoreFromSQL(context.Background(), "../testdata/extended-store.sql", t.TempDir()) t.Cleanup(cleanup) if err != nil { t.Fatal(err) } - group := &nbgroup.Group{ + group := &types.Group{ ID: "group-id", AccountID: "account-id", Name: "group-name", @@ -1193,7 +1225,7 @@ func TestSqlite_CreateAndGetObjectInTransaction(t *testing.T) { } func TestSqlite_GetAccoundUsers(t *testing.T) { - store, cleanup, err := NewTestStoreFromSQL(context.Background(), "testdata/extended-store.sql", t.TempDir()) + store, cleanup, err := NewTestStoreFromSQL(context.Background(), "../testdata/extended-store.sql", t.TempDir()) t.Cleanup(cleanup) if err != nil { t.Fatal(err) @@ -1207,7 +1239,7 @@ func TestSqlite_GetAccoundUsers(t *testing.T) { } func TestSqlStore_UpdateAccountDomainAttributes(t *testing.T) { - store, cleanup, err := NewTestStoreFromSQL(context.Background(), "testdata/extended-store.sql", t.TempDir()) + store, cleanup, err := NewTestStoreFromSQL(context.Background(), "../testdata/extended-store.sql", t.TempDir()) t.Cleanup(cleanup) if err != nil { t.Fatal(err) @@ -1253,7 +1285,7 @@ func TestSqlStore_UpdateAccountDomainAttributes(t *testing.T) { } func TestSqlite_GetGroupByName(t *testing.T) { - store, cleanup, err := NewTestStoreFromSQL(context.Background(), "testdata/extended-store.sql", t.TempDir()) + store, cleanup, err := NewTestStoreFromSQL(context.Background(), "../testdata/extended-store.sql", t.TempDir()) t.Cleanup(cleanup) if err != nil { t.Fatal(err) @@ -1267,7 +1299,7 @@ func TestSqlite_GetGroupByName(t *testing.T) { func Test_DeleteSetupKeySuccessfully(t *testing.T) { t.Setenv("NETBIRD_STORE_ENGINE", string(SqliteStoreEngine)) - store, cleanup, err := NewTestStoreFromSQL(context.Background(), "testdata/extended-store.sql", t.TempDir()) + store, cleanup, err := NewTestStoreFromSQL(context.Background(), "../testdata/extended-store.sql", t.TempDir()) t.Cleanup(cleanup) require.NoError(t, err) @@ -1283,7 +1315,7 @@ func Test_DeleteSetupKeySuccessfully(t *testing.T) { func Test_DeleteSetupKeyFailsForNonExistingKey(t *testing.T) { t.Setenv("NETBIRD_STORE_ENGINE", string(SqliteStoreEngine)) - store, cleanup, err := NewTestStoreFromSQL(context.Background(), "testdata/extended-store.sql", t.TempDir()) + store, cleanup, err := NewTestStoreFromSQL(context.Background(), "../testdata/extended-store.sql", t.TempDir()) t.Cleanup(cleanup) require.NoError(t, err) @@ -1295,7 +1327,7 @@ func Test_DeleteSetupKeyFailsForNonExistingKey(t *testing.T) { } func TestSqlStore_GetGroupsByIDs(t *testing.T) { - store, cleanup, err := NewTestStoreFromSQL(context.Background(), "testdata/extended-store.sql", t.TempDir()) + store, cleanup, err := NewTestStoreFromSQL(context.Background(), "../testdata/extended-store.sql", t.TempDir()) t.Cleanup(cleanup) require.NoError(t, err) @@ -1338,13 +1370,13 @@ func TestSqlStore_GetGroupsByIDs(t *testing.T) { } func TestSqlStore_SaveGroup(t *testing.T) { - store, cleanup, err := 
NewTestStoreFromSQL(context.Background(), "testdata/extended-store.sql", t.TempDir()) + store, cleanup, err := NewTestStoreFromSQL(context.Background(), "../testdata/extended-store.sql", t.TempDir()) t.Cleanup(cleanup) require.NoError(t, err) accountID := "bf1c8084-ba50-4ce7-9439-34653001fc3b" - group := &nbgroup.Group{ + group := &types.Group{ ID: "group-id", AccountID: accountID, Issued: "api", @@ -1359,13 +1391,13 @@ func TestSqlStore_SaveGroup(t *testing.T) { } func TestSqlStore_SaveGroups(t *testing.T) { - store, cleanup, err := NewTestStoreFromSQL(context.Background(), "testdata/extended-store.sql", t.TempDir()) + store, cleanup, err := NewTestStoreFromSQL(context.Background(), "../testdata/extended-store.sql", t.TempDir()) t.Cleanup(cleanup) require.NoError(t, err) accountID := "bf1c8084-ba50-4ce7-9439-34653001fc3b" - groups := []*nbgroup.Group{ + groups := []*types.Group{ { ID: "group-1", AccountID: accountID, @@ -1384,7 +1416,7 @@ func TestSqlStore_SaveGroups(t *testing.T) { } func TestSqlStore_DeleteGroup(t *testing.T) { - store, cleanup, err := NewTestStoreFromSQL(context.Background(), "testdata/extended-store.sql", t.TempDir()) + store, cleanup, err := NewTestStoreFromSQL(context.Background(), "../testdata/extended-store.sql", t.TempDir()) t.Cleanup(cleanup) require.NoError(t, err) @@ -1432,7 +1464,7 @@ func TestSqlStore_DeleteGroup(t *testing.T) { } func TestSqlStore_DeleteGroups(t *testing.T) { - store, cleanup, err := NewTestStoreFromSQL(context.Background(), "testdata/extended-store.sql", t.TempDir()) + store, cleanup, err := NewTestStoreFromSQL(context.Background(), "../testdata/extended-store.sql", t.TempDir()) t.Cleanup(cleanup) require.NoError(t, err) @@ -1479,7 +1511,7 @@ func TestSqlStore_DeleteGroups(t *testing.T) { } func TestSqlStore_GetPeerByID(t *testing.T) { - store, cleanup, err := NewTestStoreFromSQL(context.Background(), "testdata/store_policy_migrate.sql", t.TempDir()) + store, cleanup, err := NewTestStoreFromSQL(context.Background(), "../testdata/store_policy_migrate.sql", t.TempDir()) t.Cleanup(cleanup) require.NoError(t, err) @@ -1525,7 +1557,7 @@ func TestSqlStore_GetPeerByID(t *testing.T) { } func TestSqlStore_GetPeersByIDs(t *testing.T) { - store, cleanup, err := NewTestStoreFromSQL(context.Background(), "testdata/store_policy_migrate.sql", t.TempDir()) + store, cleanup, err := NewTestStoreFromSQL(context.Background(), "../testdata/store_policy_migrate.sql", t.TempDir()) t.Cleanup(cleanup) require.NoError(t, err) @@ -1567,7 +1599,7 @@ func TestSqlStore_GetPeersByIDs(t *testing.T) { } func TestSqlStore_GetPostureChecksByID(t *testing.T) { - store, cleanup, err := NewTestStoreFromSQL(context.Background(), "testdata/extended-store.sql", t.TempDir()) + store, cleanup, err := NewTestStoreFromSQL(context.Background(), "../testdata/extended-store.sql", t.TempDir()) t.Cleanup(cleanup) require.NoError(t, err) @@ -1613,7 +1645,7 @@ func TestSqlStore_GetPostureChecksByID(t *testing.T) { } func TestSqlStore_GetPostureChecksByIDs(t *testing.T) { - store, cleanup, err := NewTestStoreFromSQL(context.Background(), "testdata/extended-store.sql", t.TempDir()) + store, cleanup, err := NewTestStoreFromSQL(context.Background(), "../testdata/extended-store.sql", t.TempDir()) t.Cleanup(cleanup) require.NoError(t, err) @@ -1656,7 +1688,7 @@ func TestSqlStore_GetPostureChecksByIDs(t *testing.T) { } func TestSqlStore_SavePostureChecks(t *testing.T) { - store, cleanup, err := NewTestStoreFromSQL(context.Background(), "testdata/extended-store.sql", t.TempDir()) + store, cleanup, 
err := NewTestStoreFromSQL(context.Background(), "../testdata/extended-store.sql", t.TempDir()) t.Cleanup(cleanup) require.NoError(t, err) @@ -1697,7 +1729,7 @@ func TestSqlStore_SavePostureChecks(t *testing.T) { } func TestSqlStore_DeletePostureChecks(t *testing.T) { - store, cleanup, err := NewTestStoreFromSQL(context.Background(), "testdata/extended-store.sql", t.TempDir()) + store, cleanup, err := NewTestStoreFromSQL(context.Background(), "../testdata/extended-store.sql", t.TempDir()) t.Cleanup(cleanup) require.NoError(t, err) @@ -1744,7 +1776,7 @@ func TestSqlStore_DeletePostureChecks(t *testing.T) { } func TestSqlStore_GetPolicyByID(t *testing.T) { - store, cleanup, err := NewTestStoreFromSQL(context.Background(), "testdata/store.sql", t.TempDir()) + store, cleanup, err := NewTestStoreFromSQL(context.Background(), "../testdata/store.sql", t.TempDir()) t.Cleanup(cleanup) require.NoError(t, err) @@ -1790,23 +1822,23 @@ func TestSqlStore_GetPolicyByID(t *testing.T) { } func TestSqlStore_CreatePolicy(t *testing.T) { - store, cleanup, err := NewTestStoreFromSQL(context.Background(), "testdata/store.sql", t.TempDir()) + store, cleanup, err := NewTestStoreFromSQL(context.Background(), "../testdata/store.sql", t.TempDir()) t.Cleanup(cleanup) require.NoError(t, err) accountID := "bf1c8084-ba50-4ce7-9439-34653001fc3b" - policy := &Policy{ + policy := &types.Policy{ ID: "policy-id", AccountID: accountID, Enabled: true, - Rules: []*PolicyRule{ + Rules: []*types.PolicyRule{ { Enabled: true, Sources: []string{"groupA"}, Destinations: []string{"groupC"}, Bidirectional: true, - Action: PolicyTrafficActionAccept, + Action: types.PolicyTrafficActionAccept, }, }, } @@ -1820,7 +1852,7 @@ func TestSqlStore_CreatePolicy(t *testing.T) { } func TestSqlStore_SavePolicy(t *testing.T) { - store, cleanup, err := NewTestStoreFromSQL(context.Background(), "testdata/store.sql", t.TempDir()) + store, cleanup, err := NewTestStoreFromSQL(context.Background(), "../testdata/store.sql", t.TempDir()) t.Cleanup(cleanup) require.NoError(t, err) @@ -1843,7 +1875,7 @@ func TestSqlStore_SavePolicy(t *testing.T) { } func TestSqlStore_DeletePolicy(t *testing.T) { - store, cleanup, err := NewTestStoreFromSQL(context.Background(), "testdata/store.sql", t.TempDir()) + store, cleanup, err := NewTestStoreFromSQL(context.Background(), "../testdata/store.sql", t.TempDir()) t.Cleanup(cleanup) require.NoError(t, err) @@ -1859,7 +1891,7 @@ func TestSqlStore_DeletePolicy(t *testing.T) { } func TestSqlStore_GetDNSSettings(t *testing.T) { - store, cleanup, err := NewTestStoreFromSQL(context.Background(), "testdata/store.sql", t.TempDir()) + store, cleanup, err := NewTestStoreFromSQL(context.Background(), "../testdata/store.sql", t.TempDir()) t.Cleanup(cleanup) require.NoError(t, err) @@ -1903,7 +1935,7 @@ func TestSqlStore_GetDNSSettings(t *testing.T) { } func TestSqlStore_SaveDNSSettings(t *testing.T) { - store, cleanup, err := NewTestStoreFromSQL(context.Background(), "testdata/store.sql", t.TempDir()) + store, cleanup, err := NewTestStoreFromSQL(context.Background(), "../testdata/store.sql", t.TempDir()) t.Cleanup(cleanup) require.NoError(t, err) @@ -1922,7 +1954,7 @@ func TestSqlStore_SaveDNSSettings(t *testing.T) { } func TestSqlStore_GetAccountNameServerGroups(t *testing.T) { - store, cleanup, err := NewTestStoreFromSQL(context.Background(), "testdata/extended-store.sql", t.TempDir()) + store, cleanup, err := NewTestStoreFromSQL(context.Background(), "../testdata/extended-store.sql", t.TempDir()) t.Cleanup(cleanup) require.NoError(t, 
err) @@ -1959,7 +1991,7 @@ func TestSqlStore_GetAccountNameServerGroups(t *testing.T) { } func TestSqlStore_GetNameServerByID(t *testing.T) { - store, cleanup, err := NewTestStoreFromSQL(context.Background(), "testdata/extended-store.sql", t.TempDir()) + store, cleanup, err := NewTestStoreFromSQL(context.Background(), "../testdata/extended-store.sql", t.TempDir()) t.Cleanup(cleanup) require.NoError(t, err) @@ -2005,7 +2037,7 @@ func TestSqlStore_GetNameServerByID(t *testing.T) { } func TestSqlStore_SaveNameServerGroup(t *testing.T) { - store, cleanup, err := NewTestStoreFromSQL(context.Background(), "testdata/extended-store.sql", t.TempDir()) + store, cleanup, err := NewTestStoreFromSQL(context.Background(), "../testdata/extended-store.sql", t.TempDir()) t.Cleanup(cleanup) require.NoError(t, err) @@ -2037,7 +2069,7 @@ func TestSqlStore_SaveNameServerGroup(t *testing.T) { } func TestSqlStore_DeleteNameServerGroup(t *testing.T) { - store, cleanup, err := NewTestStoreFromSQL(context.Background(), "testdata/extended-store.sql", t.TempDir()) + store, cleanup, err := NewTestStoreFromSQL(context.Background(), "../testdata/extended-store.sql", t.TempDir()) t.Cleanup(cleanup) require.NoError(t, err) @@ -2051,3 +2083,481 @@ func TestSqlStore_DeleteNameServerGroup(t *testing.T) { require.Error(t, err) require.Nil(t, nsGroup) } + +// newAccountWithId creates a new Account with a default SetupKey (doesn't store in a Store) and provided id +func newAccountWithId(ctx context.Context, accountID, userID, domain string) *types.Account { + log.WithContext(ctx).Debugf("creating new account") + + network := types.NewNetwork() + peers := make(map[string]*nbpeer.Peer) + users := make(map[string]*types.User) + routes := make(map[nbroute.ID]*nbroute.Route) + setupKeys := map[string]*types.SetupKey{} + nameServersGroups := make(map[string]*nbdns.NameServerGroup) + + owner := types.NewOwnerUser(userID) + owner.AccountID = accountID + users[userID] = owner + + dnsSettings := types.DNSSettings{ + DisabledManagementGroups: make([]string, 0), + } + log.WithContext(ctx).Debugf("created new account %s", accountID) + + acc := &types.Account{ + Id: accountID, + CreatedAt: time.Now().UTC(), + SetupKeys: setupKeys, + Network: network, + Peers: peers, + Users: users, + CreatedBy: userID, + Domain: domain, + Routes: routes, + NameServerGroups: nameServersGroups, + DNSSettings: dnsSettings, + Settings: &types.Settings{ + PeerLoginExpirationEnabled: true, + PeerLoginExpiration: types.DefaultPeerLoginExpiration, + GroupsPropagationEnabled: true, + RegularUsersViewBlocked: true, + + PeerInactivityExpirationEnabled: false, + PeerInactivityExpiration: types.DefaultPeerInactivityExpiration, + }, + } + + if err := addAllGroup(acc); err != nil { + log.WithContext(ctx).Errorf("error adding all group to account %s: %v", acc.Id, err) + } + return acc +} + +// addAllGroup to account object if it doesn't exist +func addAllGroup(account *types.Account) error { + if len(account.Groups) == 0 { + allGroup := &types.Group{ + ID: xid.New().String(), + Name: "All", + Issued: types.GroupIssuedAPI, + } + for _, peer := range account.Peers { + allGroup.Peers = append(allGroup.Peers, peer.ID) + } + account.Groups = map[string]*types.Group{allGroup.ID: allGroup} + + id := xid.New().String() + + defaultPolicy := &types.Policy{ + ID: id, + Name: types.DefaultRuleName, + Description: types.DefaultRuleDescription, + Enabled: true, + Rules: []*types.PolicyRule{ + { + ID: id, + Name: types.DefaultRuleName, + Description: types.DefaultRuleDescription, + 
Enabled: true, + Sources: []string{allGroup.ID}, + Destinations: []string{allGroup.ID}, + Bidirectional: true, + Protocol: types.PolicyRuleProtocolALL, + Action: types.PolicyTrafficActionAccept, + }, + }, + } + + account.Policies = []*types.Policy{defaultPolicy} + } + return nil +} + +func TestSqlStore_GetAccountNetworks(t *testing.T) { + store, cleanup, err := NewTestStoreFromSQL(context.Background(), "../testdata/store.sql", t.TempDir()) + t.Cleanup(cleanup) + require.NoError(t, err) + + tests := []struct { + name string + accountID string + expectedCount int + }{ + { + name: "retrieve networks by existing account ID", + accountID: "bf1c8084-ba50-4ce7-9439-34653001fc3b", + expectedCount: 1, + }, + + { + name: "retrieve networks by non-existing account ID", + accountID: "non-existent", + expectedCount: 0, + }, + } + + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + networks, err := store.GetAccountNetworks(context.Background(), LockingStrengthShare, tt.accountID) + require.NoError(t, err) + require.Len(t, networks, tt.expectedCount) + }) + } +} + +func TestSqlStore_GetNetworkByID(t *testing.T) { + store, cleanup, err := NewTestStoreFromSQL(context.Background(), "../testdata/store.sql", t.TempDir()) + t.Cleanup(cleanup) + require.NoError(t, err) + + accountID := "bf1c8084-ba50-4ce7-9439-34653001fc3b" + tests := []struct { + name string + networkID string + expectError bool + }{ + { + name: "retrieve existing network ID", + networkID: "ct286bi7qv930dsrrug0", + expectError: false, + }, + { + name: "retrieve non-existing network ID", + networkID: "non-existing", + expectError: true, + }, + { + name: "retrieve network with empty ID", + networkID: "", + expectError: true, + }, + } + + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + network, err := store.GetNetworkByID(context.Background(), LockingStrengthShare, accountID, tt.networkID) + if tt.expectError { + require.Error(t, err) + sErr, ok := status.FromError(err) + require.True(t, ok) + require.Equal(t, sErr.Type(), status.NotFound) + require.Nil(t, network) + } else { + require.NoError(t, err) + require.NotNil(t, network) + require.Equal(t, tt.networkID, network.ID) + } + }) + } +} + +func TestSqlStore_SaveNetwork(t *testing.T) { + store, cleanup, err := NewTestStoreFromSQL(context.Background(), "../testdata/store.sql", t.TempDir()) + t.Cleanup(cleanup) + require.NoError(t, err) + + accountID := "bf1c8084-ba50-4ce7-9439-34653001fc3b" + network := &networkTypes.Network{ + ID: "net-id", + AccountID: accountID, + Name: "net", + } + + err = store.SaveNetwork(context.Background(), LockingStrengthUpdate, network) + require.NoError(t, err) + + savedNet, err := store.GetNetworkByID(context.Background(), LockingStrengthShare, accountID, network.ID) + require.NoError(t, err) + require.Equal(t, network, savedNet) +} + +func TestSqlStore_DeleteNetwork(t *testing.T) { + store, cleanup, err := NewTestStoreFromSQL(context.Background(), "../testdata/store.sql", t.TempDir()) + t.Cleanup(cleanup) + require.NoError(t, err) + + accountID := "bf1c8084-ba50-4ce7-9439-34653001fc3b" + networkID := "ct286bi7qv930dsrrug0" + + err = store.DeleteNetwork(context.Background(), LockingStrengthUpdate, accountID, networkID) + require.NoError(t, err) + + network, err := store.GetNetworkByID(context.Background(), LockingStrengthShare, accountID, networkID) + require.Error(t, err) + sErr, ok := status.FromError(err) + require.True(t, ok) + require.Equal(t, status.NotFound, sErr.Type()) + require.Nil(t, network) +} + +func 
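
The newAccountWithId and addAllGroup helpers above build a test account with an owner user, an "All" group, and a default allow-all policy, without going through the account manager. A minimal sketch of how they might be exercised against the SQL store; the test name and the IDs are illustrative and not part of this patch:

    func TestSqlStore_SaveGeneratedAccount(t *testing.T) {
        store, cleanup, err := NewTestStoreFromSQL(context.Background(), "../testdata/store.sql", t.TempDir())
        t.Cleanup(cleanup)
        require.NoError(t, err)

        // hypothetical IDs; the helper seeds the owner user, the "All" group
        // and the default policy before anything is persisted
        account := newAccountWithId(context.Background(), "generated-account-id", "owner-user-id", "example.com")
        require.Len(t, account.Groups, 1)
        require.Len(t, account.Policies, 1)

        require.NoError(t, store.SaveAccount(context.Background(), account))

        saved, err := store.GetAccount(context.Background(), account.Id)
        require.NoError(t, err)
        require.Equal(t, account.Id, saved.Id)
    }
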
TestSqlStore_GetNetworkRoutersByNetID(t *testing.T) { + store, cleanup, err := NewTestStoreFromSQL(context.Background(), "../testdata/store.sql", t.TempDir()) + t.Cleanup(cleanup) + require.NoError(t, err) + + accountID := "bf1c8084-ba50-4ce7-9439-34653001fc3b" + + tests := []struct { + name string + networkID string + expectedCount int + }{ + { + name: "retrieve routers by existing network ID", + networkID: "ct286bi7qv930dsrrug0", + expectedCount: 1, + }, + { + name: "retrieve routers by non-existing network ID", + networkID: "non-existent", + expectedCount: 0, + }, + } + + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + routers, err := store.GetNetworkRoutersByNetID(context.Background(), LockingStrengthShare, accountID, tt.networkID) + require.NoError(t, err) + require.Len(t, routers, tt.expectedCount) + }) + } +} + +func TestSqlStore_GetNetworkRouterByID(t *testing.T) { + store, cleanup, err := NewTestStoreFromSQL(context.Background(), "../testdata/store.sql", t.TempDir()) + t.Cleanup(cleanup) + require.NoError(t, err) + + accountID := "bf1c8084-ba50-4ce7-9439-34653001fc3b" + tests := []struct { + name string + networkRouterID string + expectError bool + }{ + { + name: "retrieve existing network router ID", + networkRouterID: "ctc20ji7qv9ck2sebc80", + expectError: false, + }, + { + name: "retrieve non-existing network router ID", + networkRouterID: "non-existing", + expectError: true, + }, + { + name: "retrieve network with empty router ID", + networkRouterID: "", + expectError: true, + }, + } + + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + networkRouter, err := store.GetNetworkRouterByID(context.Background(), LockingStrengthShare, accountID, tt.networkRouterID) + if tt.expectError { + require.Error(t, err) + sErr, ok := status.FromError(err) + require.True(t, ok) + require.Equal(t, sErr.Type(), status.NotFound) + require.Nil(t, networkRouter) + } else { + require.NoError(t, err) + require.NotNil(t, networkRouter) + require.Equal(t, tt.networkRouterID, networkRouter.ID) + } + }) + } +} + +func TestSqlStore_SaveNetworkRouter(t *testing.T) { + store, cleanup, err := NewTestStoreFromSQL(context.Background(), "../testdata/store.sql", t.TempDir()) + t.Cleanup(cleanup) + require.NoError(t, err) + + accountID := "bf1c8084-ba50-4ce7-9439-34653001fc3b" + networkID := "ct286bi7qv930dsrrug0" + + netRouter, err := routerTypes.NewNetworkRouter(accountID, networkID, "", []string{"net-router-grp"}, true, 0) + require.NoError(t, err) + + err = store.SaveNetworkRouter(context.Background(), LockingStrengthUpdate, netRouter) + require.NoError(t, err) + + savedNetRouter, err := store.GetNetworkRouterByID(context.Background(), LockingStrengthShare, accountID, netRouter.ID) + require.NoError(t, err) + require.Equal(t, netRouter, savedNetRouter) +} + +func TestSqlStore_DeleteNetworkRouter(t *testing.T) { + store, cleanup, err := NewTestStoreFromSQL(context.Background(), "../testdata/store.sql", t.TempDir()) + t.Cleanup(cleanup) + require.NoError(t, err) + + accountID := "bf1c8084-ba50-4ce7-9439-34653001fc3b" + netRouterID := "ctc20ji7qv9ck2sebc80" + + err = store.DeleteNetworkRouter(context.Background(), LockingStrengthUpdate, accountID, netRouterID) + require.NoError(t, err) + + netRouter, err := store.GetNetworkByID(context.Background(), LockingStrengthShare, accountID, netRouterID) + require.Error(t, err) + sErr, ok := status.FromError(err) + require.True(t, ok) + require.Equal(t, status.NotFound, sErr.Type()) + require.Nil(t, netRouter) +} + +func 
TestSqlStore_GetNetworkResourcesByNetID(t *testing.T) { + store, cleanup, err := NewTestStoreFromSQL(context.Background(), "../testdata/store.sql", t.TempDir()) + t.Cleanup(cleanup) + require.NoError(t, err) + + accountID := "bf1c8084-ba50-4ce7-9439-34653001fc3b" + + tests := []struct { + name string + networkID string + expectedCount int + }{ + { + name: "retrieve resources by existing network ID", + networkID: "ct286bi7qv930dsrrug0", + expectedCount: 1, + }, + { + name: "retrieve resources by non-existing network ID", + networkID: "non-existent", + expectedCount: 0, + }, + } + + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + netResources, err := store.GetNetworkResourcesByNetID(context.Background(), LockingStrengthShare, accountID, tt.networkID) + require.NoError(t, err) + require.Len(t, netResources, tt.expectedCount) + }) + } +} + +func TestSqlStore_GetNetworkResourceByID(t *testing.T) { + store, cleanup, err := NewTestStoreFromSQL(context.Background(), "../testdata/store.sql", t.TempDir()) + t.Cleanup(cleanup) + require.NoError(t, err) + + accountID := "bf1c8084-ba50-4ce7-9439-34653001fc3b" + tests := []struct { + name string + netResourceID string + expectError bool + }{ + { + name: "retrieve existing network resource ID", + netResourceID: "ctc4nci7qv9061u6ilfg", + expectError: false, + }, + { + name: "retrieve non-existing network resource ID", + netResourceID: "non-existing", + expectError: true, + }, + { + name: "retrieve network with empty resource ID", + netResourceID: "", + expectError: true, + }, + } + + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + netResource, err := store.GetNetworkResourceByID(context.Background(), LockingStrengthShare, accountID, tt.netResourceID) + if tt.expectError { + require.Error(t, err) + sErr, ok := status.FromError(err) + require.True(t, ok) + require.Equal(t, sErr.Type(), status.NotFound) + require.Nil(t, netResource) + } else { + require.NoError(t, err) + require.NotNil(t, netResource) + require.Equal(t, tt.netResourceID, netResource.ID) + } + }) + } +} + +func TestSqlStore_SaveNetworkResource(t *testing.T) { + store, cleanup, err := NewTestStoreFromSQL(context.Background(), "../testdata/store.sql", t.TempDir()) + t.Cleanup(cleanup) + require.NoError(t, err) + + accountID := "bf1c8084-ba50-4ce7-9439-34653001fc3b" + networkID := "ct286bi7qv930dsrrug0" + + netResource, err := resourceTypes.NewNetworkResource(accountID, networkID, "resource-name", "", "example.com", []string{}) + require.NoError(t, err) + + err = store.SaveNetworkResource(context.Background(), LockingStrengthUpdate, netResource) + require.NoError(t, err) + + savedNetResource, err := store.GetNetworkResourceByID(context.Background(), LockingStrengthShare, accountID, netResource.ID) + require.NoError(t, err) + require.Equal(t, netResource.ID, savedNetResource.ID) + require.Equal(t, netResource.Name, savedNetResource.Name) + require.Equal(t, netResource.NetworkID, savedNetResource.NetworkID) + require.Equal(t, netResource.Type, resourceTypes.NetworkResourceType("domain")) + require.Equal(t, netResource.Domain, "example.com") + require.Equal(t, netResource.AccountID, savedNetResource.AccountID) + require.Equal(t, netResource.Prefix, netip.Prefix{}) +} + +func TestSqlStore_DeleteNetworkResource(t *testing.T) { + store, cleanup, err := NewTestStoreFromSQL(context.Background(), "../testdata/store.sql", t.TempDir()) + t.Cleanup(cleanup) + require.NoError(t, err) + + accountID := "bf1c8084-ba50-4ce7-9439-34653001fc3b" + netResourceID := 
"ctc4nci7qv9061u6ilfg" + + err = store.DeleteNetworkResource(context.Background(), LockingStrengthUpdate, accountID, netResourceID) + require.NoError(t, err) + + netResource, err := store.GetNetworkByID(context.Background(), LockingStrengthShare, accountID, netResourceID) + require.Error(t, err) + sErr, ok := status.FromError(err) + require.True(t, ok) + require.Equal(t, status.NotFound, sErr.Type()) + require.Nil(t, netResource) +} + +func TestSqlStore_AddAndRemoveResourceFromGroup(t *testing.T) { + store, cleanup, err := NewTestStoreFromSQL(context.Background(), "../testdata/store.sql", t.TempDir()) + require.NoError(t, err) + t.Cleanup(cleanup) + + accountID := "bf1c8084-ba50-4ce7-9439-34653001fc3b" + resourceId := "ctc4nci7qv9061u6ilfg" + groupID := "cs1tnh0hhcjnqoiuebeg" + + res := &types.Resource{ + ID: resourceId, + Type: "host", + } + err = store.AddResourceToGroup(context.Background(), accountID, groupID, res) + require.NoError(t, err) + + group, err := store.GetGroupByID(context.Background(), LockingStrengthShare, accountID, groupID) + require.NoError(t, err) + require.Contains(t, group.Resources, *res) + + groups, err := store.GetResourceGroups(context.Background(), LockingStrengthShare, accountID, resourceId) + require.NoError(t, err) + require.Len(t, groups, 1) + + err = store.RemoveResourceFromGroup(context.Background(), accountID, groupID, res.ID) + require.NoError(t, err) + + group, err = store.GetGroupByID(context.Background(), LockingStrengthShare, accountID, groupID) + require.NoError(t, err) + require.NotContains(t, group.Resources, *res) +} diff --git a/management/server/store.go b/management/server/store/store.go similarity index 73% rename from management/server/store.go rename to management/server/store/store.go index b16ad8a1a..d9dc6b8f7 100644 --- a/management/server/store.go +++ b/management/server/store/store.go @@ -1,4 +1,4 @@ -package server +package store import ( "context" @@ -18,13 +18,15 @@ import ( "gorm.io/gorm" "github.com/netbirdio/netbird/dns" - - nbgroup "github.com/netbirdio/netbird/management/server/group" + "github.com/netbirdio/netbird/management/server/types" "github.com/netbirdio/netbird/management/server/telemetry" "github.com/netbirdio/netbird/util" "github.com/netbirdio/netbird/management/server/migration" + resourceTypes "github.com/netbirdio/netbird/management/server/networks/resources/types" + routerTypes "github.com/netbirdio/netbird/management/server/networks/routers/types" + networkTypes "github.com/netbirdio/netbird/management/server/networks/types" nbpeer "github.com/netbirdio/netbird/management/server/peer" "github.com/netbirdio/netbird/management/server/posture" "github.com/netbirdio/netbird/management/server/testutil" @@ -41,49 +43,50 @@ const ( ) type Store interface { - GetAllAccounts(ctx context.Context) []*Account - GetAccount(ctx context.Context, accountID string) (*Account, error) + GetAllAccounts(ctx context.Context) []*types.Account + GetAccount(ctx context.Context, accountID string) (*types.Account, error) AccountExists(ctx context.Context, lockStrength LockingStrength, id string) (bool, error) GetAccountDomainAndCategory(ctx context.Context, lockStrength LockingStrength, accountID string) (string, string, error) - GetAccountByUser(ctx context.Context, userID string) (*Account, error) - GetAccountByPeerPubKey(ctx context.Context, peerKey string) (*Account, error) + GetAccountByUser(ctx context.Context, userID string) (*types.Account, error) + GetAccountByPeerPubKey(ctx context.Context, peerKey string) (*types.Account, 
error) GetAccountIDByPeerPubKey(ctx context.Context, peerKey string) (string, error) GetAccountIDByUserID(userID string) (string, error) GetAccountIDBySetupKey(ctx context.Context, peerKey string) (string, error) - GetAccountByPeerID(ctx context.Context, peerID string) (*Account, error) - GetAccountBySetupKey(ctx context.Context, setupKey string) (*Account, error) // todo use key hash later - GetAccountByPrivateDomain(ctx context.Context, domain string) (*Account, error) + GetAccountByPeerID(ctx context.Context, peerID string) (*types.Account, error) + GetAccountBySetupKey(ctx context.Context, setupKey string) (*types.Account, error) // todo use key hash later + GetAccountByPrivateDomain(ctx context.Context, domain string) (*types.Account, error) GetAccountIDByPrivateDomain(ctx context.Context, lockStrength LockingStrength, domain string) (string, error) - GetAccountSettings(ctx context.Context, lockStrength LockingStrength, accountID string) (*Settings, error) - GetAccountDNSSettings(ctx context.Context, lockStrength LockingStrength, accountID string) (*DNSSettings, error) - SaveAccount(ctx context.Context, account *Account) error - DeleteAccount(ctx context.Context, account *Account) error + GetAccountSettings(ctx context.Context, lockStrength LockingStrength, accountID string) (*types.Settings, error) + GetAccountDNSSettings(ctx context.Context, lockStrength LockingStrength, accountID string) (*types.DNSSettings, error) + SaveAccount(ctx context.Context, account *types.Account) error + DeleteAccount(ctx context.Context, account *types.Account) error UpdateAccountDomainAttributes(ctx context.Context, accountID string, domain string, category string, isPrimaryDomain bool) error - SaveDNSSettings(ctx context.Context, lockStrength LockingStrength, accountID string, settings *DNSSettings) error + SaveDNSSettings(ctx context.Context, lockStrength LockingStrength, accountID string, settings *types.DNSSettings) error - GetUserByTokenID(ctx context.Context, tokenID string) (*User, error) - GetUserByUserID(ctx context.Context, lockStrength LockingStrength, userID string) (*User, error) - GetAccountUsers(ctx context.Context, lockStrength LockingStrength, accountID string) ([]*User, error) - SaveUsers(accountID string, users map[string]*User) error - SaveUser(ctx context.Context, lockStrength LockingStrength, user *User) error + GetUserByTokenID(ctx context.Context, tokenID string) (*types.User, error) + GetUserByUserID(ctx context.Context, lockStrength LockingStrength, userID string) (*types.User, error) + GetAccountUsers(ctx context.Context, lockStrength LockingStrength, accountID string) ([]*types.User, error) + SaveUsers(accountID string, users map[string]*types.User) error + SaveUser(ctx context.Context, lockStrength LockingStrength, user *types.User) error SaveUserLastLogin(ctx context.Context, accountID, userID string, lastLogin time.Time) error GetTokenIDByHashedToken(ctx context.Context, secret string) (string, error) DeleteHashedPAT2TokenIDIndex(hashedToken string) error DeleteTokenID2UserIDIndex(tokenID string) error - GetAccountGroups(ctx context.Context, lockStrength LockingStrength, accountID string) ([]*nbgroup.Group, error) - GetGroupByID(ctx context.Context, lockStrength LockingStrength, groupID, accountID string) (*nbgroup.Group, error) - GetGroupByName(ctx context.Context, lockStrength LockingStrength, groupName, accountID string) (*nbgroup.Group, error) - GetGroupsByIDs(ctx context.Context, lockStrength LockingStrength, accountID string, groupIDs []string) 
(map[string]*nbgroup.Group, error) - SaveGroups(ctx context.Context, lockStrength LockingStrength, groups []*nbgroup.Group) error - SaveGroup(ctx context.Context, lockStrength LockingStrength, group *nbgroup.Group) error + GetAccountGroups(ctx context.Context, lockStrength LockingStrength, accountID string) ([]*types.Group, error) + GetResourceGroups(ctx context.Context, lockStrength LockingStrength, accountID, resourceID string) ([]*types.Group, error) + GetGroupByID(ctx context.Context, lockStrength LockingStrength, accountID, groupID string) (*types.Group, error) + GetGroupByName(ctx context.Context, lockStrength LockingStrength, groupName, accountID string) (*types.Group, error) + GetGroupsByIDs(ctx context.Context, lockStrength LockingStrength, accountID string, groupIDs []string) (map[string]*types.Group, error) + SaveGroups(ctx context.Context, lockStrength LockingStrength, groups []*types.Group) error + SaveGroup(ctx context.Context, lockStrength LockingStrength, group *types.Group) error DeleteGroup(ctx context.Context, lockStrength LockingStrength, accountID, groupID string) error DeleteGroups(ctx context.Context, strength LockingStrength, accountID string, groupIDs []string) error - GetAccountPolicies(ctx context.Context, lockStrength LockingStrength, accountID string) ([]*Policy, error) - GetPolicyByID(ctx context.Context, lockStrength LockingStrength, accountID, policyID string) (*Policy, error) - CreatePolicy(ctx context.Context, lockStrength LockingStrength, policy *Policy) error - SavePolicy(ctx context.Context, lockStrength LockingStrength, policy *Policy) error + GetAccountPolicies(ctx context.Context, lockStrength LockingStrength, accountID string) ([]*types.Policy, error) + GetPolicyByID(ctx context.Context, lockStrength LockingStrength, accountID, policyID string) (*types.Policy, error) + CreatePolicy(ctx context.Context, lockStrength LockingStrength, policy *types.Policy) error + SavePolicy(ctx context.Context, lockStrength LockingStrength, policy *types.Policy) error DeletePolicy(ctx context.Context, lockStrength LockingStrength, accountID, policyID string) error GetPostureCheckByChecksDefinition(accountID string, checks *posture.ChecksDefinition) (*posture.Checks, error) @@ -96,6 +99,8 @@ type Store interface { GetPeerLabelsInAccount(ctx context.Context, lockStrength LockingStrength, accountId string) ([]string, error) AddPeerToAllGroup(ctx context.Context, accountID string, peerID string) error AddPeerToGroup(ctx context.Context, accountId string, peerId string, groupID string) error + AddResourceToGroup(ctx context.Context, accountId string, groupID string, resource *types.Resource) error + RemoveResourceFromGroup(ctx context.Context, accountId string, groupID string, resourceID string) error AddPeerToAccount(ctx context.Context, peer *nbpeer.Peer) error GetPeerByPeerPubKey(ctx context.Context, lockStrength LockingStrength, peerKey string) (*nbpeer.Peer, error) GetUserPeers(ctx context.Context, lockStrength LockingStrength, accountID, userID string) ([]*nbpeer.Peer, error) @@ -105,11 +110,11 @@ type Store interface { SavePeerStatus(accountID, peerID string, status nbpeer.PeerStatus) error SavePeerLocation(accountID string, peer *nbpeer.Peer) error - GetSetupKeyBySecret(ctx context.Context, lockStrength LockingStrength, key string) (*SetupKey, error) + GetSetupKeyBySecret(ctx context.Context, lockStrength LockingStrength, key string) (*types.SetupKey, error) IncrementSetupKeyUsage(ctx context.Context, setupKeyID string) error - GetAccountSetupKeys(ctx 
context.Context, lockStrength LockingStrength, accountID string) ([]*SetupKey, error) - GetSetupKeyByID(ctx context.Context, lockStrength LockingStrength, accountID, setupKeyID string) (*SetupKey, error) - SaveSetupKey(ctx context.Context, lockStrength LockingStrength, setupKey *SetupKey) error + GetAccountSetupKeys(ctx context.Context, lockStrength LockingStrength, accountID string) ([]*types.SetupKey, error) + GetSetupKeyByID(ctx context.Context, lockStrength LockingStrength, accountID, setupKeyID string) (*types.SetupKey, error) + SaveSetupKey(ctx context.Context, lockStrength LockingStrength, setupKey *types.SetupKey) error DeleteSetupKey(ctx context.Context, lockStrength LockingStrength, accountID, keyID string) error GetAccountRoutes(ctx context.Context, lockStrength LockingStrength, accountID string) ([]*route.Route, error) @@ -122,7 +127,7 @@ type Store interface { GetTakenIPs(ctx context.Context, lockStrength LockingStrength, accountId string) ([]net.IP, error) IncrementNetworkSerial(ctx context.Context, lockStrength LockingStrength, accountId string) error - GetAccountNetwork(ctx context.Context, lockStrength LockingStrength, accountId string) (*Network, error) + GetAccountNetwork(ctx context.Context, lockStrength LockingStrength, accountId string) (*types.Network, error) GetInstallationID() string SaveInstallationID(ctx context.Context, ID string) error @@ -136,30 +141,48 @@ type Store interface { // Close should close the store persisting all unsaved data. Close(ctx context.Context) error - // GetStoreEngine should return StoreEngine of the current store implementation. + // GetStoreEngine should return Engine of the current store implementation. // This is also a method of metrics.DataSource interface. - GetStoreEngine() StoreEngine + GetStoreEngine() Engine ExecuteInTransaction(ctx context.Context, f func(store Store) error) error + + GetAccountNetworks(ctx context.Context, lockStrength LockingStrength, accountID string) ([]*networkTypes.Network, error) + GetNetworkByID(ctx context.Context, lockStrength LockingStrength, accountID, networkID string) (*networkTypes.Network, error) + SaveNetwork(ctx context.Context, lockStrength LockingStrength, network *networkTypes.Network) error + DeleteNetwork(ctx context.Context, lockStrength LockingStrength, accountID, networkID string) error + + GetNetworkRoutersByNetID(ctx context.Context, lockStrength LockingStrength, accountID, netID string) ([]*routerTypes.NetworkRouter, error) + GetNetworkRoutersByAccountID(ctx context.Context, lockStrength LockingStrength, accountID string) ([]*routerTypes.NetworkRouter, error) + GetNetworkRouterByID(ctx context.Context, lockStrength LockingStrength, accountID, routerID string) (*routerTypes.NetworkRouter, error) + SaveNetworkRouter(ctx context.Context, lockStrength LockingStrength, router *routerTypes.NetworkRouter) error + DeleteNetworkRouter(ctx context.Context, lockStrength LockingStrength, accountID, routerID string) error + + GetNetworkResourcesByNetID(ctx context.Context, lockStrength LockingStrength, accountID, netID string) ([]*resourceTypes.NetworkResource, error) + GetNetworkResourcesByAccountID(ctx context.Context, lockStrength LockingStrength, accountID string) ([]*resourceTypes.NetworkResource, error) + GetNetworkResourceByID(ctx context.Context, lockStrength LockingStrength, accountID, resourceID string) (*resourceTypes.NetworkResource, error) + GetNetworkResourceByName(ctx context.Context, lockStrength LockingStrength, accountID, resourceName string) (*resourceTypes.NetworkResource, 
error) + SaveNetworkResource(ctx context.Context, lockStrength LockingStrength, resource *resourceTypes.NetworkResource) error + DeleteNetworkResource(ctx context.Context, lockStrength LockingStrength, accountID, resourceID string) error } -type StoreEngine string +type Engine string const ( - FileStoreEngine StoreEngine = "jsonfile" - SqliteStoreEngine StoreEngine = "sqlite" - PostgresStoreEngine StoreEngine = "postgres" + FileStoreEngine Engine = "jsonfile" + SqliteStoreEngine Engine = "sqlite" + PostgresStoreEngine Engine = "postgres" postgresDsnEnv = "NETBIRD_STORE_ENGINE_POSTGRES_DSN" ) -func getStoreEngineFromEnv() StoreEngine { +func getStoreEngineFromEnv() Engine { // NETBIRD_STORE_ENGINE supposed to be used in tests. Otherwise, rely on the config file. kind, ok := os.LookupEnv("NETBIRD_STORE_ENGINE") if !ok { return "" } - value := StoreEngine(strings.ToLower(kind)) + value := Engine(strings.ToLower(kind)) if value == SqliteStoreEngine || value == PostgresStoreEngine { return value } @@ -171,7 +194,7 @@ func getStoreEngineFromEnv() StoreEngine { // If no engine is specified, it attempts to retrieve it from the environment. // If still not specified, it defaults to using SQLite. // Additionally, it handles the migration from a JSON store file to SQLite if applicable. -func getStoreEngine(ctx context.Context, dataDir string, kind StoreEngine) StoreEngine { +func getStoreEngine(ctx context.Context, dataDir string, kind Engine) Engine { if kind == "" { kind = getStoreEngineFromEnv() if kind == "" { @@ -197,7 +220,7 @@ func getStoreEngine(ctx context.Context, dataDir string, kind StoreEngine) Store } // NewStore creates a new store based on the provided engine type, data directory, and telemetry metrics -func NewStore(ctx context.Context, kind StoreEngine, dataDir string, metrics telemetry.AppMetrics) (Store, error) { +func NewStore(ctx context.Context, kind Engine, dataDir string, metrics telemetry.AppMetrics) (Store, error) { kind = getStoreEngine(ctx, dataDir, kind) if err := checkFileStoreEngine(kind, dataDir); err != nil { @@ -216,7 +239,7 @@ func NewStore(ctx context.Context, kind StoreEngine, dataDir string, metrics tel } } -func checkFileStoreEngine(kind StoreEngine, dataDir string) error { +func checkFileStoreEngine(kind Engine, dataDir string) error { if kind == FileStoreEngine { storeFile := filepath.Join(dataDir, storeFileName) if util.FileExists(storeFile) { @@ -243,7 +266,7 @@ func migrate(ctx context.Context, db *gorm.DB) error { func getMigrations(ctx context.Context) []migrationFunc { return []migrationFunc{ func(db *gorm.DB) error { - return migration.MigrateFieldFromGobToJSON[Account, net.IPNet](ctx, db, "network_net") + return migration.MigrateFieldFromGobToJSON[types.Account, net.IPNet](ctx, db, "network_net") }, func(db *gorm.DB) error { return migration.MigrateFieldFromGobToJSON[route.Route, netip.Prefix](ctx, db, "network") @@ -258,7 +281,7 @@ func getMigrations(ctx context.Context) []migrationFunc { return migration.MigrateNetIPFieldFromBlobToJSON[nbpeer.Peer](ctx, db, "ip", "idx_peers_account_id_ip") }, func(db *gorm.DB) error { - return migration.MigrateSetupKeyToHashedSetupKey[SetupKey](ctx, db) + return migration.MigrateSetupKeyToHashedSetupKey[types.SetupKey](ctx, db) }, } } diff --git a/management/server/store_test.go b/management/server/store/store_test.go similarity index 93% rename from management/server/store_test.go rename to management/server/store/store_test.go index fc821670d..1d0026e3d 100644 --- a/management/server/store_test.go +++ 
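
The Store interface above now exposes CRUD for networks, network routers, and network resources alongside ExecuteInTransaction. A hedged sketch of how a caller in this package might compose them atomically; the helper function, IDs, and group name are made up for illustration and mirror the call shapes used in the tests above:

    func createNetworkWithRouter(ctx context.Context, s Store, accountID string) error {
        return s.ExecuteInTransaction(ctx, func(tx Store) error {
            network := &networkTypes.Network{
                ID:        "net-example",
                AccountID: accountID,
                Name:      "example-network",
            }
            if err := tx.SaveNetwork(ctx, LockingStrengthUpdate, network); err != nil {
                return err
            }

            // peer left empty, a peer group used instead; masquerade on, metric 0
            router, err := routerTypes.NewNetworkRouter(accountID, network.ID, "", []string{"routing-peers-group"}, true, 0)
            if err != nil {
                return err
            }
            return tx.SaveNetworkRouter(ctx, LockingStrengthUpdate, router)
        })
    }
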
b/management/server/store/store_test.go @@ -1,4 +1,4 @@ -package server +package store import ( "context" @@ -76,11 +76,3 @@ func BenchmarkTest_StoreRead(b *testing.B) { }) } } - -func newStore(t *testing.T) Store { - t.Helper() - - store := newSqliteStore(t) - - return store -} diff --git a/management/server/testdata/networks.sql b/management/server/testdata/networks.sql new file mode 100644 index 000000000..8138ce520 --- /dev/null +++ b/management/server/testdata/networks.sql @@ -0,0 +1,18 @@ +CREATE TABLE `accounts` (`id` text,`created_by` text,`created_at` datetime,`domain` text,`domain_category` text,`is_domain_primary_account` numeric,`network_identifier` text,`network_net` text,`network_dns` text,`network_serial` integer,`dns_settings_disabled_management_groups` text,`settings_peer_login_expiration_enabled` numeric,`settings_peer_login_expiration` integer,`settings_regular_users_view_blocked` numeric,`settings_groups_propagation_enabled` numeric,`settings_jwt_groups_enabled` numeric,`settings_jwt_groups_claim_name` text,`settings_jwt_allow_groups` text,`settings_extra_peer_approval_enabled` numeric,`settings_extra_integrated_validator_groups` text,PRIMARY KEY (`id`)); +INSERT INTO accounts VALUES('testAccountId','','2024-10-02 16:01:38.000000000+00:00','test.com','private',1,'testNetworkIdentifier','{"IP":"100.64.0.0","Mask":"//8AAA=="}','',0,'[]',0,86400000000000,0,0,0,'',NULL,NULL,NULL); + +CREATE TABLE `groups` (`id` text,`account_id` text,`name` text,`issued` text,`peers` text,`integration_ref_id` integer,`integration_ref_integration_type` text,PRIMARY KEY (`id`),CONSTRAINT `fk_accounts_groups_g` FOREIGN KEY (`account_id`) REFERENCES `accounts`(`id`)); +INSERT INTO "groups" VALUES('testGroupId','testAccountId','testGroupName','api','[]',0,''); + +CREATE TABLE `peers` (`id` text,`account_id` text,`key` text,`setup_key` text,`ip` text,`meta_hostname` text,`meta_go_os` text,`meta_kernel` text,`meta_core` text,`meta_platform` text,`meta_os` text,`meta_os_version` text,`meta_wt_version` text,`meta_ui_version` text,`meta_kernel_version` text,`meta_network_addresses` text,`meta_system_serial_number` text,`meta_system_product_name` text,`meta_system_manufacturer` text,`meta_environment` text,`meta_files` text,`name` text,`dns_label` text,`peer_status_last_seen` datetime,`peer_status_connected` numeric,`peer_status_login_expired` numeric,`peer_status_requires_approval` numeric,`user_id` text,`ssh_key` text,`ssh_enabled` numeric,`login_expiration_enabled` numeric,`last_login` datetime,`created_at` datetime,`ephemeral` numeric,`location_connection_ip` text,`location_country_code` text,`location_city_name` text,`location_geo_name_id` integer,PRIMARY KEY (`id`),CONSTRAINT `fk_accounts_peers_g` FOREIGN KEY (`account_id`) REFERENCES `accounts`(`id`)); +INSERT INTO peers VALUES('testPeerId','testAccountId','5rvhvriKJZ3S9oxYToVj5TzDM9u9y8cxg7htIMWlYAg=','72546A29-6BC8-4311-BCFC-9CDBF33F1A48','"100.64.114.31"','f2a34f6a4731','linux','Linux','11','unknown','Debian GNU/Linux','','0.12.0','','',NULL,'','','','{"Cloud":"","Platform":""}',NULL,'f2a34f6a4731','f2a34f6a4731','2023-03-02 09:21:02.189035775+01:00',0,0,0,'','ssh-ed25519 AAAAC3NzaC1lZDI1NTE5AAAAILzUUSYG/LGnV8zarb2SGN+tib/PZ+M7cL4WtTzUrTpk',0,1,'2023-03-01 19:48:19.817799698+01:00','2024-10-02 17:00:32.527947+02:00',0,'""','','',0); + +CREATE TABLE `networks` (`id` text,`account_id` text,`name` text,`description` text,PRIMARY KEY (`id`),CONSTRAINT `fk_accounts_networks` FOREIGN KEY (`account_id`) REFERENCES `accounts`(`id`)); +INSERT INTO 
networks VALUES('testNetworkId','testAccountId','some-name','some-description'); + +CREATE TABLE `network_routers` (`id` text,`network_id` text,`account_id` text,`peer` text,`peer_groups` text,`masquerade` numeric,`metric` integer,PRIMARY KEY (`id`),CONSTRAINT `fk_accounts_network_routers` FOREIGN KEY (`account_id`) REFERENCES `accounts`(`id`)); +INSERT INTO network_routers VALUES('testRouterId','testNetworkId','testAccountId','','["csquuo4jcko732k1ag00"]',0,9999); + +CREATE TABLE `network_resources` (`id` text,`network_id` text,`account_id` text,`name` text,`description` text,`type` text,`address` text,PRIMARY KEY (`id`),CONSTRAINT `fk_accounts_network_resources` FOREIGN KEY (`account_id`) REFERENCES `accounts`(`id`)); +INSERT INTO network_resources VALUES('testResourceId','testNetworkId','testAccountId','some-name','some-description','host','3.3.3.3/32'); +INSERT INTO network_resources VALUES('anotherTestResourceId','testNetworkId','testAccountId','used-name','some-description','host','3.3.3.3/32'); diff --git a/management/server/testdata/store.sql b/management/server/testdata/store.sql index 168973cad..7f0c7b5a4 100644 --- a/management/server/testdata/store.sql +++ b/management/server/testdata/store.sql @@ -12,6 +12,9 @@ CREATE TABLE `installations` (`id` integer,`installation_id_value` text,PRIMARY CREATE TABLE `extra_settings` (`peer_approval_enabled` numeric,`integrated_validator_groups` text); CREATE TABLE `posture_checks` (`id` text,`name` text,`description` text,`account_id` text,`checks` text,PRIMARY KEY (`id`),CONSTRAINT `fk_accounts_posture_checks` FOREIGN KEY (`account_id`) REFERENCES `accounts`(`id`)); CREATE TABLE `network_addresses` (`net_ip` text,`mac` text); +CREATE TABLE `network_routers` (`id` text,`network_id` text,`account_id` text,`peer` text,`peer_groups` text,`masquerade` numeric,`metric` integer,PRIMARY KEY (`id`)); +CREATE TABLE `network_resources` (`id` text,`network_id` text,`account_id` text,`type` text,`address` text,PRIMARY KEY (`id`)); +CREATE TABLE `networks` (`id` text,`account_id` text,`name` text,`description` text,PRIMARY KEY (`id`),CONSTRAINT `fk_accounts_networks` FOREIGN KEY (`account_id`) REFERENCES `accounts`(`id`)); CREATE INDEX `idx_accounts_domain` ON `accounts`(`domain`); CREATE INDEX `idx_setup_keys_account_id` ON `setup_keys`(`account_id`); CREATE INDEX `idx_peers_key` ON `peers`(`key`); @@ -24,6 +27,14 @@ CREATE INDEX `idx_policy_rules_policy_id` ON `policy_rules`(`policy_id`); CREATE INDEX `idx_routes_account_id` ON `routes`(`account_id`); CREATE INDEX `idx_name_server_groups_account_id` ON `name_server_groups`(`account_id`); CREATE INDEX `idx_posture_checks_account_id` ON `posture_checks`(`account_id`); +CREATE INDEX `idx_network_routers_id` ON `network_routers`(`id`); +CREATE INDEX `idx_network_routers_account_id` ON `network_routers`(`account_id`); +CREATE INDEX `idx_network_routers_network_id` ON `network_routers`(`network_id`); +CREATE INDEX `idx_network_resources_account_id` ON `network_resources`(`account_id`); +CREATE INDEX `idx_network_resources_network_id` ON `network_resources`(`network_id`); +CREATE INDEX `idx_network_resources_id` ON `network_resources`(`id`); +CREATE INDEX `idx_networks_id` ON `networks`(`id`); +CREATE INDEX `idx_networks_account_id` ON `networks`(`account_id`); INSERT INTO accounts VALUES('bf1c8084-ba50-4ce7-9439-34653001fc3b','','2024-10-02 
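
The new testdata/networks.sql fixture above seeds one account, one network, one router, and two resources under fixed IDs. A hypothetical test in the store package could load it the same way the other fixtures are loaded; whether any particular test consumes this file is outside this hunk:

    func TestSqlStore_LoadNetworksFixture(t *testing.T) {
        store, cleanup, err := NewTestStoreFromSQL(context.Background(), "../testdata/networks.sql", t.TempDir())
        t.Cleanup(cleanup)
        require.NoError(t, err)

        network, err := store.GetNetworkByID(context.Background(), LockingStrengthShare, "testAccountId", "testNetworkId")
        require.NoError(t, err)
        require.Equal(t, "some-name", network.Name)

        resources, err := store.GetNetworkResourcesByNetID(context.Background(), LockingStrengthShare, "testAccountId", "testNetworkId")
        require.NoError(t, err)
        require.Len(t, resources, 2)
    }
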
16:03:06.778746+02:00','test.com','private',1,'af1c8024-ha40-4ce2-9418-34653101fc3c','{"IP":"100.64.0.0","Mask":"//8AAA=="}','',0,'[]',0,86400000000000,0,0,0,'',NULL,NULL,NULL); INSERT INTO "groups" VALUES('cs1tnh0hhcjnqoiuebeg','bf1c8084-ba50-4ce7-9439-34653001fc3b','All','api','[]',0,''); @@ -34,3 +45,6 @@ INSERT INTO personal_access_tokens VALUES('9dj38s35-63fb-11ec-90d6-0242ac120003' INSERT INTO installations VALUES(1,''); INSERT INTO policies VALUES('cs1tnh0hhcjnqoiuebf0','bf1c8084-ba50-4ce7-9439-34653001fc3b','Default','This is a default rule that allows connections between all the resources',1,'[]'); INSERT INTO policy_rules VALUES('cs387mkv2d4bgq41b6n0','cs1tnh0hhcjnqoiuebf0','Default','This is a default rule that allows connections between all the resources',1,'accept','["cs1tnh0hhcjnqoiuebeg"]','["cs1tnh0hhcjnqoiuebeg"]',1,'all',NULL,NULL); +INSERT INTO network_routers VALUES('ctc20ji7qv9ck2sebc80','ct286bi7qv930dsrrug0','bf1c8084-ba50-4ce7-9439-34653001fc3b','cs1tnh0hhcjnqoiuebeg',NULL,0,0); +INSERT INTO network_resources VALUES ('ctc4nci7qv9061u6ilfg','ct286bi7qv930dsrrug0','bf1c8084-ba50-4ce7-9439-34653001fc3b','Host','192.168.1.1'); +INSERT INTO networks VALUES('ct286bi7qv930dsrrug0','bf1c8084-ba50-4ce7-9439-34653001fc3b','Test Network','Test Network'); diff --git a/management/server/types/account.go b/management/server/types/account.go new file mode 100644 index 000000000..5d0f12019 --- /dev/null +++ b/management/server/types/account.go @@ -0,0 +1,1475 @@ +package types + +import ( + "context" + "fmt" + "net" + "net/netip" + "slices" + "strconv" + "strings" + "time" + + "github.com/hashicorp/go-multierror" + "github.com/miekg/dns" + log "github.com/sirupsen/logrus" + + nbdns "github.com/netbirdio/netbird/dns" + "github.com/netbirdio/netbird/management/domain" + resourceTypes "github.com/netbirdio/netbird/management/server/networks/resources/types" + routerTypes "github.com/netbirdio/netbird/management/server/networks/routers/types" + networkTypes "github.com/netbirdio/netbird/management/server/networks/types" + nbpeer "github.com/netbirdio/netbird/management/server/peer" + "github.com/netbirdio/netbird/management/server/posture" + "github.com/netbirdio/netbird/management/server/status" + "github.com/netbirdio/netbird/management/server/telemetry" + "github.com/netbirdio/netbird/management/server/util" + "github.com/netbirdio/netbird/route" +) + +const ( + defaultTTL = 300 + DefaultPeerLoginExpiration = 24 * time.Hour + DefaultPeerInactivityExpiration = 10 * time.Minute + + PublicCategory = "public" + PrivateCategory = "private" + UnknownCategory = "unknown" +) + +type LookupMap map[string]struct{} + +// Account represents a unique account of the system +type Account struct { + // we have to name column to aid as it collides with Network.Id when work with associations + Id string `gorm:"primaryKey"` + + // User.Id it was created by + CreatedBy string + CreatedAt time.Time + Domain string `gorm:"index"` + DomainCategory string + IsDomainPrimaryAccount bool + SetupKeys map[string]*SetupKey `gorm:"-"` + SetupKeysG []SetupKey `json:"-" gorm:"foreignKey:AccountID;references:id"` + Network *Network `gorm:"embedded;embeddedPrefix:network_"` + Peers map[string]*nbpeer.Peer `gorm:"-"` + PeersG []nbpeer.Peer `json:"-" gorm:"foreignKey:AccountID;references:id"` + Users map[string]*User `gorm:"-"` + UsersG []User `json:"-" gorm:"foreignKey:AccountID;references:id"` + Groups map[string]*Group `gorm:"-"` + GroupsG []Group `json:"-" gorm:"foreignKey:AccountID;references:id"` + Policies 
[]*Policy `gorm:"foreignKey:AccountID;references:id"` + Routes map[route.ID]*route.Route `gorm:"-"` + RoutesG []route.Route `json:"-" gorm:"foreignKey:AccountID;references:id"` + NameServerGroups map[string]*nbdns.NameServerGroup `gorm:"-"` + NameServerGroupsG []nbdns.NameServerGroup `json:"-" gorm:"foreignKey:AccountID;references:id"` + DNSSettings DNSSettings `gorm:"embedded;embeddedPrefix:dns_settings_"` + PostureChecks []*posture.Checks `gorm:"foreignKey:AccountID;references:id"` + // Settings is a dictionary of Account settings + Settings *Settings `gorm:"embedded;embeddedPrefix:settings_"` + + Networks []*networkTypes.Network `gorm:"foreignKey:AccountID;references:id"` + NetworkRouters []*routerTypes.NetworkRouter `gorm:"foreignKey:AccountID;references:id"` + NetworkResources []*resourceTypes.NetworkResource `gorm:"foreignKey:AccountID;references:id"` +} + +// Subclass used in gorm to only load network and not whole account +type AccountNetwork struct { + Network *Network `gorm:"embedded;embeddedPrefix:network_"` +} + +// AccountDNSSettings used in gorm to only load dns settings and not whole account +type AccountDNSSettings struct { + DNSSettings DNSSettings `gorm:"embedded;embeddedPrefix:dns_settings_"` +} + +// Subclass used in gorm to only load settings and not whole account +type AccountSettings struct { + Settings *Settings `gorm:"embedded;embeddedPrefix:settings_"` +} + +// GetRoutesToSync returns the enabled routes for the peer ID and the routes +// from the ACL peers that have distribution groups associated with the peer ID. +// Please mind, that the returned route.Route objects will contain Peer.Key instead of Peer.ID. +func (a *Account) GetRoutesToSync(ctx context.Context, peerID string, aclPeers []*nbpeer.Peer) []*route.Route { + routes, peerDisabledRoutes := a.getRoutingPeerRoutes(ctx, peerID) + peerRoutesMembership := make(LookupMap) + for _, r := range append(routes, peerDisabledRoutes...) { + peerRoutesMembership[string(r.GetHAUniqueID())] = struct{}{} + } + + groupListMap := a.GetPeerGroups(peerID) + for _, peer := range aclPeers { + activeRoutes, _ := a.getRoutingPeerRoutes(ctx, peer.ID) + groupFilteredRoutes := a.filterRoutesByGroups(activeRoutes, groupListMap) + filteredRoutes := a.filterRoutesFromPeersOfSameHAGroup(groupFilteredRoutes, peerRoutesMembership) + routes = append(routes, filteredRoutes...) + } + + return routes +} + +// filterRoutesFromPeersOfSameHAGroup filters and returns a list of routes that don't share the same HA route membership +func (a *Account) filterRoutesFromPeersOfSameHAGroup(routes []*route.Route, peerMemberships LookupMap) []*route.Route { + var filteredRoutes []*route.Route + for _, r := range routes { + _, found := peerMemberships[string(r.GetHAUniqueID())] + if !found { + filteredRoutes = append(filteredRoutes, r) + } + } + return filteredRoutes +} + +// filterRoutesByGroups returns a list with routes that have distribution groups in the group's map +func (a *Account) filterRoutesByGroups(routes []*route.Route, groupListMap LookupMap) []*route.Route { + var filteredRoutes []*route.Route + for _, r := range routes { + for _, groupID := range r.Groups { + _, found := groupListMap[groupID] + if found { + filteredRoutes = append(filteredRoutes, r) + break + } + } + } + return filteredRoutes +} + +// getRoutingPeerRoutes returns the enabled and disabled lists of routes that the given routing peer serves +// Please mind, that the returned route.Route objects will contain Peer.Key instead of Peer.ID. 
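
The AccountNetwork, AccountDNSSettings, and AccountSettings wrappers above exist so queries can hydrate only the embedded column groups instead of the whole Account row. A rough sketch of the intended query shape, assuming a *gorm.DB handle; the exact call used by the SQL store is not shown in this hunk:

    var accNet AccountNetwork
    err := db.Model(&Account{}).
        Where("id = ?", accountID).
        First(&accNet).Error
    // accNet.Network is populated from the network_* columns only
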
+// If the given is not a routing peer, then the lists are empty. +func (a *Account) getRoutingPeerRoutes(ctx context.Context, peerID string) (enabledRoutes []*route.Route, disabledRoutes []*route.Route) { + + peer := a.GetPeer(peerID) + if peer == nil { + log.WithContext(ctx).Errorf("peer %s that doesn't exist under account %s", peerID, a.Id) + return enabledRoutes, disabledRoutes + } + + // currently we support only linux routing peers + if peer.Meta.GoOS != "linux" { + return enabledRoutes, disabledRoutes + } + + seenRoute := make(map[route.ID]struct{}) + + takeRoute := func(r *route.Route, id string) { + if _, ok := seenRoute[r.ID]; ok { + return + } + seenRoute[r.ID] = struct{}{} + + if r.Enabled { + r.Peer = peer.Key + enabledRoutes = append(enabledRoutes, r) + return + } + disabledRoutes = append(disabledRoutes, r) + } + + for _, r := range a.Routes { + for _, groupID := range r.PeerGroups { + group := a.GetGroup(groupID) + if group == nil { + log.WithContext(ctx).Errorf("route %s has peers group %s that doesn't exist under account %s", r.ID, groupID, a.Id) + continue + } + for _, id := range group.Peers { + if id != peerID { + continue + } + + newPeerRoute := r.Copy() + newPeerRoute.Peer = id + newPeerRoute.PeerGroups = nil + newPeerRoute.ID = route.ID(string(r.ID) + ":" + id) // we have to provide unique route id when distribute network map + takeRoute(newPeerRoute, id) + break + } + } + if r.Peer == peerID { + takeRoute(r.Copy(), peerID) + } + } + + return enabledRoutes, disabledRoutes +} + +// GetRoutesByPrefixOrDomains return list of routes by account and route prefix +func (a *Account) GetRoutesByPrefixOrDomains(prefix netip.Prefix, domains domain.List) []*route.Route { + var routes []*route.Route + for _, r := range a.Routes { + dynamic := r.IsDynamic() + if dynamic && r.Domains.PunycodeString() == domains.PunycodeString() || + !dynamic && r.Network.String() == prefix.String() { + routes = append(routes, r) + } + } + + return routes +} + +// GetGroup returns a group by ID if exists, nil otherwise +func (a *Account) GetGroup(groupID string) *Group { + return a.Groups[groupID] +} + +// GetPeerNetworkMap returns the networkmap for the given peer ID. 
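
When a route is distributed through PeerGroups, getRoutingPeerRoutes above hands the routing peer a copy whose ID gets the peer ID appended, keeping route IDs unique when the same group route fans out to several routers in the network map. A toy illustration of the derived ID, with made-up values:

    originalID := route.ID("route-1") // route configured with PeerGroups
    routingPeer := "peer-a-id"        // member of one of those groups
    derived := route.ID(string(originalID) + ":" + routingPeer)
    // derived == "route-1:peer-a-id"
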
+func (a *Account) GetPeerNetworkMap( + ctx context.Context, + peerID string, + peersCustomZone nbdns.CustomZone, + validatedPeersMap map[string]struct{}, + resourcePolicies map[string][]*Policy, + routers map[string]map[string]*routerTypes.NetworkRouter, + metrics *telemetry.AccountManagerMetrics, +) *NetworkMap { + start := time.Now() + + peer := a.Peers[peerID] + if peer == nil { + return &NetworkMap{ + Network: a.Network.Copy(), + } + } + + if _, ok := validatedPeersMap[peerID]; !ok { + return &NetworkMap{ + Network: a.Network.Copy(), + } + } + + aclPeers, firewallRules := a.GetPeerConnectionResources(ctx, peerID, validatedPeersMap) + // exclude expired peers + var peersToConnect []*nbpeer.Peer + var expiredPeers []*nbpeer.Peer + for _, p := range aclPeers { + expired, _ := p.LoginExpired(a.Settings.PeerLoginExpiration) + if a.Settings.PeerLoginExpirationEnabled && expired { + expiredPeers = append(expiredPeers, p) + continue + } + peersToConnect = append(peersToConnect, p) + } + + routesUpdate := a.GetRoutesToSync(ctx, peerID, peersToConnect) + routesFirewallRules := a.GetPeerRoutesFirewallRules(ctx, peerID, validatedPeersMap) + isRouter, networkResourcesRoutes, sourcePeers := a.GetNetworkResourcesRoutesToSync(ctx, peerID, resourcePolicies, routers) + var networkResourcesFirewallRules []*RouteFirewallRule + if isRouter { + networkResourcesFirewallRules = a.GetPeerNetworkResourceFirewallRules(ctx, peer, validatedPeersMap, networkResourcesRoutes, resourcePolicies) + } + peersToConnectIncludingRouters := a.addNetworksRoutingPeers(networkResourcesRoutes, peer, peersToConnect, expiredPeers, isRouter, sourcePeers) + + dnsManagementStatus := a.getPeerDNSManagementStatus(peerID) + dnsUpdate := nbdns.Config{ + ServiceEnable: dnsManagementStatus, + } + + if dnsManagementStatus { + var zones []nbdns.CustomZone + + if peersCustomZone.Domain != "" { + zones = append(zones, peersCustomZone) + } + dnsUpdate.CustomZones = zones + dnsUpdate.NameServerGroups = getPeerNSGroups(a, peerID) + } + + nm := &NetworkMap{ + Peers: peersToConnectIncludingRouters, + Network: a.Network.Copy(), + Routes: slices.Concat(networkResourcesRoutes, routesUpdate), + DNSConfig: dnsUpdate, + OfflinePeers: expiredPeers, + FirewallRules: firewallRules, + RoutesFirewallRules: slices.Concat(networkResourcesFirewallRules, routesFirewallRules), + } + + if metrics != nil { + objectCount := int64(len(peersToConnect) + len(expiredPeers) + len(routesUpdate) + len(firewallRules)) + metrics.CountNetworkMapObjects(objectCount) + metrics.CountGetPeerNetworkMapDuration(time.Since(start)) + + if objectCount > 5000 { + log.WithContext(ctx).Tracef("account: %s has a total resource count of %d objects, "+ + "peers to connect: %d, expired peers: %d, routes: %d, firewall rules: %d", + a.Id, objectCount, len(peersToConnect), len(expiredPeers), len(routesUpdate), len(firewallRules)) + } + } + + return nm +} + +func (a *Account) addNetworksRoutingPeers(networkResourcesRoutes []*route.Route, peer *nbpeer.Peer, peersToConnect []*nbpeer.Peer, expiredPeers []*nbpeer.Peer, isRouter bool, sourcePeers []string) []*nbpeer.Peer { + missingPeers := map[string]struct{}{} + for _, r := range networkResourcesRoutes { + if r.Peer == peer.Key { + continue + } + + missing := true + for _, p := range slices.Concat(peersToConnect, expiredPeers) { + if r.Peer == p.Key { + missing = false + break + } + } + if missing { + missingPeers[r.Peer] = struct{}{} + } + } + + if isRouter { + for _, s := range sourcePeers { + if s == peer.ID { + continue + } + + missing := true 
+ for _, p := range slices.Concat(peersToConnect, expiredPeers) { + if s == p.ID { + missing = false + break + } + } + if missing { + p, ok := a.Peers[s] + if ok { + missingPeers[p.Key] = struct{}{} + } + } + } + } + + for p := range missingPeers { + for _, p2 := range a.Peers { + if p2.Key == p { + peersToConnect = append(peersToConnect, p2) + break + } + } + } + return peersToConnect +} + +func getPeerNSGroups(account *Account, peerID string) []*nbdns.NameServerGroup { + groupList := account.GetPeerGroups(peerID) + + var peerNSGroups []*nbdns.NameServerGroup + + for _, nsGroup := range account.NameServerGroups { + if !nsGroup.Enabled { + continue + } + for _, gID := range nsGroup.Groups { + _, found := groupList[gID] + if found { + if !peerIsNameserver(account.GetPeer(peerID), nsGroup) { + peerNSGroups = append(peerNSGroups, nsGroup.Copy()) + break + } + } + } + } + + return peerNSGroups +} + +// peerIsNameserver returns true if the peer is a nameserver for a nsGroup +func peerIsNameserver(peer *nbpeer.Peer, nsGroup *nbdns.NameServerGroup) bool { + for _, ns := range nsGroup.NameServers { + if peer.IP.Equal(ns.IP.AsSlice()) { + return true + } + } + return false +} + +func AddPeerLabelsToAccount(ctx context.Context, account *Account, peerLabels LookupMap) { + for _, peer := range account.Peers { + label, err := GetPeerHostLabel(peer.Name, peerLabels) + if err != nil { + log.WithContext(ctx).Errorf("got an error while generating a peer host label. Peer name %s, error: %v. Trying with the peer's meta hostname", peer.Name, err) + label, err = GetPeerHostLabel(peer.Meta.Hostname, peerLabels) + if err != nil { + log.WithContext(ctx).Errorf("got another error while generating a peer host label with hostname. Peer hostname %s, error: %v. Skipping", peer.Meta.Hostname, err) + continue + } + } + peer.DNSLabel = label + peerLabels[label] = struct{}{} + } +} + +func GetPeerHostLabel(name string, peerLabels LookupMap) (string, error) { + label, err := nbdns.GetParsedDomainLabel(name) + if err != nil { + return "", err + } + + uniqueLabel := getUniqueHostLabel(label, peerLabels) + if uniqueLabel == "" { + return "", fmt.Errorf("couldn't find a unique valid label for %s, parsed label %s", name, label) + } + return uniqueLabel, nil +} + +// getUniqueHostLabel look for a unique host label, and if doesn't find add a suffix up to 999 +func getUniqueHostLabel(name string, peerLabels LookupMap) string { + _, found := peerLabels[name] + if !found { + return name + } + for i := 1; i < 1000; i++ { + nameWithSuffix := name + "-" + strconv.Itoa(i) + _, found = peerLabels[nameWithSuffix] + if !found { + return nameWithSuffix + } + } + return "" +} + +func (a *Account) GetPeersCustomZone(ctx context.Context, dnsDomain string) nbdns.CustomZone { + var merr *multierror.Error + + if dnsDomain == "" { + log.WithContext(ctx).Error("no dns domain is set, returning empty zone") + return nbdns.CustomZone{} + } + + customZone := nbdns.CustomZone{ + Domain: dns.Fqdn(dnsDomain), + Records: make([]nbdns.SimpleRecord, 0, len(a.Peers)), + } + + domainSuffix := "." 
+ dnsDomain + + var sb strings.Builder + for _, peer := range a.Peers { + if peer.DNSLabel == "" { + merr = multierror.Append(merr, fmt.Errorf("peer %s has an empty DNS label", peer.Name)) + continue + } + + sb.Grow(len(peer.DNSLabel) + len(domainSuffix)) + sb.WriteString(peer.DNSLabel) + sb.WriteString(domainSuffix) + + customZone.Records = append(customZone.Records, nbdns.SimpleRecord{ + Name: sb.String(), + Type: int(dns.TypeA), + Class: nbdns.DefaultClass, + TTL: defaultTTL, + RData: peer.IP.String(), + }) + + sb.Reset() + } + + go func() { + if merr != nil { + log.WithContext(ctx).Errorf("error generating custom zone for account %s: %v", a.Id, merr) + } + }() + + return customZone +} + +// GetExpiredPeers returns peers that have been expired +func (a *Account) GetExpiredPeers() []*nbpeer.Peer { + var peers []*nbpeer.Peer + for _, peer := range a.GetPeersWithExpiration() { + expired, _ := peer.LoginExpired(a.Settings.PeerLoginExpiration) + if expired { + peers = append(peers, peer) + } + } + + return peers +} + +// GetNextPeerExpiration returns the minimum duration in which the next peer of the account will expire if it was found. +// If there is no peer that expires this function returns false and a duration of 0. +// This function only considers peers that haven't been expired yet and that are connected. +func (a *Account) GetNextPeerExpiration() (time.Duration, bool) { + peersWithExpiry := a.GetPeersWithExpiration() + if len(peersWithExpiry) == 0 { + return 0, false + } + var nextExpiry *time.Duration + for _, peer := range peersWithExpiry { + // consider only connected peers because others will require login on connecting to the management server + if peer.Status.LoginExpired || !peer.Status.Connected { + continue + } + _, duration := peer.LoginExpired(a.Settings.PeerLoginExpiration) + if nextExpiry == nil || duration < *nextExpiry { + // if expiration is below 1s return 1s duration + // this avoids issues with ticker that can't be set to < 0 + if duration < time.Second { + return time.Second, true + } + nextExpiry = &duration + } + } + + if nextExpiry == nil { + return 0, false + } + + return *nextExpiry, true +} + +// GetPeersWithExpiration returns a list of peers that have Peer.LoginExpirationEnabled set to true and that were added by a user +func (a *Account) GetPeersWithExpiration() []*nbpeer.Peer { + peers := make([]*nbpeer.Peer, 0) + for _, peer := range a.Peers { + if peer.LoginExpirationEnabled && peer.AddedWithSSOLogin() { + peers = append(peers, peer) + } + } + return peers +} + +// GetInactivePeers returns peers that have been expired by inactivity +func (a *Account) GetInactivePeers() []*nbpeer.Peer { + var peers []*nbpeer.Peer + for _, inactivePeer := range a.GetPeersWithInactivity() { + inactive, _ := inactivePeer.SessionExpired(a.Settings.PeerInactivityExpiration) + if inactive { + peers = append(peers, inactivePeer) + } + } + return peers +} + +// GetNextInactivePeerExpiration returns the minimum duration in which the next peer of the account will expire if it was found. +// If there is no peer that expires this function returns false and a duration of 0. +// This function only considers peers that haven't been expired yet and that are not connected. 
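
GetNextPeerExpiration above and its inactivity counterpart below never return less than one second, so a caller can always arm a timer with the result. A hedged sketch of how a scheduler might consume it; the surrounding loop and re-check logic are assumptions, not part of this patch:

    if expiresIn, ok := account.GetNextPeerExpiration(); ok {
        timer := time.NewTimer(expiresIn) // at least one second, per the clamp above
        defer timer.Stop()
        <-timer.C
        // re-evaluate here, e.g. act on account.GetExpiredPeers()
    }
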
+func (a *Account) GetNextInactivePeerExpiration() (time.Duration, bool) {
+    peersWithExpiry := a.GetPeersWithInactivity()
+    if len(peersWithExpiry) == 0 {
+        return 0, false
+    }
+    var nextExpiry *time.Duration
+    for _, peer := range peersWithExpiry {
+        if peer.Status.LoginExpired || peer.Status.Connected {
+            continue
+        }
+        _, duration := peer.SessionExpired(a.Settings.PeerInactivityExpiration)
+        if nextExpiry == nil || duration < *nextExpiry {
+            // if expiration is below 1s return 1s duration
+            // this avoids issues with ticker that can't be set to < 0
+            if duration < time.Second {
+                return time.Second, true
+            }
+            nextExpiry = &duration
+        }
+    }
+
+    if nextExpiry == nil {
+        return 0, false
+    }
+
+    return *nextExpiry, true
+}
+
+// GetPeersWithInactivity returns a list of peers that have Peer.InactivityExpirationEnabled set to true and that were added by a user
+func (a *Account) GetPeersWithInactivity() []*nbpeer.Peer {
+    peers := make([]*nbpeer.Peer, 0)
+    for _, peer := range a.Peers {
+        if peer.InactivityExpirationEnabled && peer.AddedWithSSOLogin() {
+            peers = append(peers, peer)
+        }
+    }
+    return peers
+}
+
+// GetPeers returns a list of all Account peers
+func (a *Account) GetPeers() []*nbpeer.Peer {
+    var peers []*nbpeer.Peer
+    for _, peer := range a.Peers {
+        peers = append(peers, peer)
+    }
+    return peers
+}
+
+// UpdateSettings saves new account settings
+func (a *Account) UpdateSettings(update *Settings) *Account {
+    a.Settings = update.Copy()
+    return a
+}
+
+// UpdatePeer saves new or replaces existing peer
+func (a *Account) UpdatePeer(update *nbpeer.Peer) {
+    a.Peers[update.ID] = update
+}
+
+// DeletePeer deletes peer from the account cleaning up all the references
+func (a *Account) DeletePeer(peerID string) {
+    // delete peer from groups
+    for _, g := range a.Groups {
+        for i, pk := range g.Peers {
+            if pk == peerID {
+                g.Peers = append(g.Peers[:i], g.Peers[i+1:]...)
+                break
+            }
+        }
+    }
+
+    for _, r := range a.Routes {
+        if r.Peer == peerID {
+            r.Enabled = false
+            r.Peer = ""
+        }
+    }
+
+    for i, r := range a.NetworkRouters {
+        if r.Peer == peerID {
+            a.NetworkRouters = append(a.NetworkRouters[:i], a.NetworkRouters[i+1:]...)
+            break
+        }
+    }
+
+    delete(a.Peers, peerID)
+    a.Network.IncSerial()
+}
+
+func (a *Account) DeleteResource(resourceID string) {
+    // delete resource from groups
+    for _, g := range a.Groups {
+        for i, pk := range g.Resources {
+            if pk.ID == resourceID {
+                g.Resources = append(g.Resources[:i], g.Resources[i+1:]...)
+                break
+            }
+        }
+    }
+}
+
+// FindPeerByPubKey looks for a Peer by provided WireGuard public key in the Account or returns error if it wasn't found.
+// It will return an object copy of the peer.
+func (a *Account) FindPeerByPubKey(peerPubKey string) (*nbpeer.Peer, error) {
+    for _, peer := range a.Peers {
+        if peer.Key == peerPubKey {
+            return peer.Copy(), nil
+        }
+    }
+
+    return nil, status.Errorf(status.NotFound, "peer with the public key %s not found", peerPubKey)
+}
+
+// FindUserPeers returns a list of peers that user owns (created)
+func (a *Account) FindUserPeers(userID string) ([]*nbpeer.Peer, error) {
+    peers := make([]*nbpeer.Peer, 0)
+    for _, peer := range a.Peers {
+        if peer.UserID == userID {
+            peers = append(peers, peer)
+        }
+    }
+
+    return peers, nil
+}
+
+// FindUser looks for a given user in the Account or returns error if user wasn't found.
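+// Note that the returned user is the stored object, not a copy.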
+func (a *Account) FindUser(userID string) (*User, error) { + user := a.Users[userID] + if user == nil { + return nil, status.Errorf(status.NotFound, "user %s not found", userID) + } + + return user, nil +} + +// FindGroupByName looks for a given group in the Account by name or returns error if the group wasn't found. +func (a *Account) FindGroupByName(groupName string) (*Group, error) { + for _, group := range a.Groups { + if group.Name == groupName { + return group, nil + } + } + return nil, status.Errorf(status.NotFound, "group %s not found", groupName) +} + +// FindSetupKey looks for a given SetupKey in the Account or returns error if it wasn't found. +func (a *Account) FindSetupKey(setupKey string) (*SetupKey, error) { + key := a.SetupKeys[setupKey] + if key == nil { + return nil, status.Errorf(status.NotFound, "setup key not found") + } + + return key, nil +} + +// GetPeerGroupsList return with the list of groups ID. +func (a *Account) GetPeerGroupsList(peerID string) []string { + var grps []string + for groupID, group := range a.Groups { + for _, id := range group.Peers { + if id == peerID { + grps = append(grps, groupID) + break + } + } + } + return grps +} + +func (a *Account) getPeerDNSManagementStatus(peerID string) bool { + peerGroups := a.GetPeerGroups(peerID) + enabled := true + for _, groupID := range a.DNSSettings.DisabledManagementGroups { + _, found := peerGroups[groupID] + if found { + enabled = false + break + } + } + return enabled +} + +func (a *Account) GetPeerGroups(peerID string) LookupMap { + groupList := make(LookupMap) + for groupID, group := range a.Groups { + for _, id := range group.Peers { + if id == peerID { + groupList[groupID] = struct{}{} + break + } + } + } + return groupList +} + +func (a *Account) GetTakenIPs() []net.IP { + var takenIps []net.IP + for _, existingPeer := range a.Peers { + takenIps = append(takenIps, existingPeer.IP) + } + + return takenIps +} + +func (a *Account) GetPeerDNSLabels() LookupMap { + existingLabels := make(LookupMap) + for _, peer := range a.Peers { + if peer.DNSLabel != "" { + existingLabels[peer.DNSLabel] = struct{}{} + } + } + return existingLabels +} + +func (a *Account) Copy() *Account { + peers := map[string]*nbpeer.Peer{} + for id, peer := range a.Peers { + peers[id] = peer.Copy() + } + + users := map[string]*User{} + for id, user := range a.Users { + users[id] = user.Copy() + } + + setupKeys := map[string]*SetupKey{} + for id, key := range a.SetupKeys { + setupKeys[id] = key.Copy() + } + + groups := map[string]*Group{} + for id, group := range a.Groups { + groups[id] = group.Copy() + } + + policies := []*Policy{} + for _, policy := range a.Policies { + policies = append(policies, policy.Copy()) + } + + routes := map[route.ID]*route.Route{} + for id, r := range a.Routes { + routes[id] = r.Copy() + } + + nsGroups := map[string]*nbdns.NameServerGroup{} + for id, nsGroup := range a.NameServerGroups { + nsGroups[id] = nsGroup.Copy() + } + + dnsSettings := a.DNSSettings.Copy() + + var settings *Settings + if a.Settings != nil { + settings = a.Settings.Copy() + } + + postureChecks := []*posture.Checks{} + for _, postureCheck := range a.PostureChecks { + postureChecks = append(postureChecks, postureCheck.Copy()) + } + + nets := []*networkTypes.Network{} + for _, network := range a.Networks { + nets = append(nets, network.Copy()) + } + + networkRouters := []*routerTypes.NetworkRouter{} + for _, router := range a.NetworkRouters { + networkRouters = append(networkRouters, router.Copy()) + } + + networkResources := 
[]*resourceTypes.NetworkResource{} + for _, resource := range a.NetworkResources { + networkResources = append(networkResources, resource.Copy()) + } + + return &Account{ + Id: a.Id, + CreatedBy: a.CreatedBy, + CreatedAt: a.CreatedAt, + Domain: a.Domain, + DomainCategory: a.DomainCategory, + IsDomainPrimaryAccount: a.IsDomainPrimaryAccount, + SetupKeys: setupKeys, + Network: a.Network.Copy(), + Peers: peers, + Users: users, + Groups: groups, + Policies: policies, + Routes: routes, + NameServerGroups: nsGroups, + DNSSettings: dnsSettings, + PostureChecks: postureChecks, + Settings: settings, + Networks: nets, + NetworkRouters: networkRouters, + NetworkResources: networkResources, + } +} + +func (a *Account) GetGroupAll() (*Group, error) { + for _, g := range a.Groups { + if g.Name == "All" { + return g, nil + } + } + return nil, fmt.Errorf("no group ALL found") +} + +// GetPeer looks up a Peer by ID +func (a *Account) GetPeer(peerID string) *nbpeer.Peer { + return a.Peers[peerID] +} + +// UserGroupsAddToPeers adds groups to all peers of user +func (a *Account) UserGroupsAddToPeers(userID string, groups ...string) map[string][]string { + groupUpdates := make(map[string][]string) + + userPeers := make(map[string]struct{}) + for pid, peer := range a.Peers { + if peer.UserID == userID { + userPeers[pid] = struct{}{} + } + } + + for _, gid := range groups { + group, ok := a.Groups[gid] + if !ok { + continue + } + + oldPeers := group.Peers + + groupPeers := make(map[string]struct{}) + for _, pid := range group.Peers { + groupPeers[pid] = struct{}{} + } + + for pid := range userPeers { + groupPeers[pid] = struct{}{} + } + + group.Peers = group.Peers[:0] + for pid := range groupPeers { + group.Peers = append(group.Peers, pid) + } + + groupUpdates[gid] = util.Difference(group.Peers, oldPeers) + } + + return groupUpdates +} + +// UserGroupsRemoveFromPeers removes groups from all peers of user +func (a *Account) UserGroupsRemoveFromPeers(userID string, groups ...string) map[string][]string { + groupUpdates := make(map[string][]string) + + for _, gid := range groups { + group, ok := a.Groups[gid] + if !ok || group.Name == "All" { + continue + } + + oldPeers := group.Peers + + update := make([]string, 0, len(group.Peers)) + for _, pid := range group.Peers { + peer, ok := a.Peers[pid] + if !ok { + continue + } + if peer.UserID != userID { + update = append(update, pid) + } + } + group.Peers = update + groupUpdates[gid] = util.Difference(oldPeers, group.Peers) + } + + return groupUpdates +} + +// GetPeerConnectionResources for a given peer +// +// This function returns the list of peers and firewall rules that are applicable to a given peer. 
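+// Only enabled policies and rules are evaluated; the returned peers and firewall rules are deduplicated by the underlying generator.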
+func (a *Account) GetPeerConnectionResources(ctx context.Context, peerID string, validatedPeersMap map[string]struct{}) ([]*nbpeer.Peer, []*FirewallRule) {
+    generateResources, getAccumulatedResources := a.connResourcesGenerator(ctx)
+    for _, policy := range a.Policies {
+        if !policy.Enabled {
+            continue
+        }
+
+        for _, rule := range policy.Rules {
+            if !rule.Enabled {
+                continue
+            }
+
+            sourcePeers, peerInSources := a.getAllPeersFromGroups(ctx, rule.Sources, peerID, policy.SourcePostureChecks, validatedPeersMap)
+            destinationPeers, peerInDestinations := a.getAllPeersFromGroups(ctx, rule.Destinations, peerID, nil, validatedPeersMap)
+
+            if rule.Bidirectional {
+                if peerInSources {
+                    generateResources(rule, destinationPeers, FirewallRuleDirectionIN)
+                }
+                if peerInDestinations {
+                    generateResources(rule, sourcePeers, FirewallRuleDirectionOUT)
+                }
+            }
+
+            if peerInSources {
+                generateResources(rule, destinationPeers, FirewallRuleDirectionOUT)
+            }
+
+            if peerInDestinations {
+                generateResources(rule, sourcePeers, FirewallRuleDirectionIN)
+            }
+        }
+    }
+
+    return getAccumulatedResources()
+}
+
+// connResourcesGenerator returns a generator function and an accumulator function that returns the result of the generator calls.
+//
+// The generator function is used to generate the list of peers and firewall rules that are applicable to a given peer.
+// It is safe to call the generator function multiple times for the same peer and different rules; no duplicates will be
+// generated. The accumulator function returns the result of all the generator calls.
+func (a *Account) connResourcesGenerator(ctx context.Context) (func(*PolicyRule, []*nbpeer.Peer, int), func() ([]*nbpeer.Peer, []*FirewallRule)) {
+    rulesExists := make(map[string]struct{})
+    peersExists := make(map[string]struct{})
+    rules := make([]*FirewallRule, 0)
+    peers := make([]*nbpeer.Peer, 0)
+
+    all, err := a.GetGroupAll()
+    if err != nil {
+        log.WithContext(ctx).Errorf("failed to get group all: %v", err)
+        all = &Group{}
+    }
+
+    return func(rule *PolicyRule, groupPeers []*nbpeer.Peer, direction int) {
+            isAll := (len(all.Peers) - 1) == len(groupPeers)
+            for _, peer := range groupPeers {
+                if peer == nil {
+                    continue
+                }
+
+                if _, ok := peersExists[peer.ID]; !ok {
+                    peers = append(peers, peer)
+                    peersExists[peer.ID] = struct{}{}
+                }
+
+                fr := FirewallRule{
+                    PeerIP:    peer.IP.String(),
+                    Direction: direction,
+                    Action:    string(rule.Action),
+                    Protocol:  string(rule.Protocol),
+                }
+
+                if isAll {
+                    fr.PeerIP = "0.0.0.0"
+                }
+
+                ruleID := rule.ID + fr.PeerIP + strconv.Itoa(direction) +
+                    fr.Protocol + fr.Action + strings.Join(rule.Ports, ",")
+                if _, ok := rulesExists[ruleID]; ok {
+                    continue
+                }
+                rulesExists[ruleID] = struct{}{}
+
+                if len(rule.Ports) == 0 {
+                    rules = append(rules, &fr)
+                    continue
+                }
+
+                for _, port := range rule.Ports {
+                    pr := fr // clone the rule and set the new port
+                    pr.Port = port
+                    rules = append(rules, &pr)
+                }
+            }
+        }, func() ([]*nbpeer.Peer, []*FirewallRule) {
+            return peers, rules
+        }
+}
+
+// getAllPeersFromGroups collects the peers of the given groups for the given peer ID.
+//
+// Returns a list of peers from specified groups that pass specified posture checks
+// and a boolean indicating if the supplied peer ID exists within these groups.
+//
+// Important: Posture checks are applicable only to source group peers;
+// for destination group peers, call this method with an empty list of sourcePostureChecksIDs.
+func (a *Account) getAllPeersFromGroups(ctx context.Context, groups []string, peerID string, sourcePostureChecksIDs []string, validatedPeersMap map[string]struct{}) ([]*nbpeer.Peer, bool) {
+    peerInGroups := false
+    filteredPeers := make([]*nbpeer.Peer, 0, len(groups))
+    for _, g := range groups {
+        group, ok := a.Groups[g]
+        if !ok {
+            continue
+        }
+
+        for _, p := range group.Peers {
+            peer, ok := a.Peers[p]
+            if !ok || peer == nil {
+                continue
+            }
+
+            // validate the peer based on policy posture checks applied
+            isValid := a.validatePostureChecksOnPeer(ctx, sourcePostureChecksIDs, peer.ID)
+            if !isValid {
+                continue
+            }
+
+            if _, ok := validatedPeersMap[peer.ID]; !ok {
+                continue
+            }
+
+            if peer.ID == peerID {
+                peerInGroups = true
+                continue
+            }
+
+            filteredPeers = append(filteredPeers, peer)
+        }
+    }
+    return filteredPeers, peerInGroups
+}
+
+// validatePostureChecksOnPeer validates the posture checks on a peer
+func (a *Account) validatePostureChecksOnPeer(ctx context.Context, sourcePostureChecksID []string, peerID string) bool {
+    peer, ok := a.Peers[peerID]
+    if !ok || peer == nil {
+        return false
+    }
+
+    for _, postureChecksID := range sourcePostureChecksID {
+        postureChecks := a.GetPostureChecks(postureChecksID)
+        if postureChecks == nil {
+            continue
+        }
+
+        for _, check := range postureChecks.GetChecks() {
+            isValid, err := check.Check(ctx, *peer)
+            if err != nil {
+                log.WithContext(ctx).Debugf("an error occurred during check %s on peer %s: %s", check.Name(), peer.ID, err.Error())
+            }
+            if !isValid {
+                return false
+            }
+        }
+    }
+    return true
+}
+
+func (a *Account) GetPostureChecks(postureChecksID string) *posture.Checks {
+    for _, postureChecks := range a.PostureChecks {
+        if postureChecks.ID == postureChecksID {
+            return postureChecks
+        }
+    }
+    return nil
+}
+
+// GetPeerRoutesFirewallRules gets the routes firewall rules associated with a routing peer ID for the account.
+func (a *Account) GetPeerRoutesFirewallRules(ctx context.Context, peerID string, validatedPeersMap map[string]struct{}) []*RouteFirewallRule {
+    routesFirewallRules := make([]*RouteFirewallRule, 0, len(a.Routes))
+
+    enabledRoutes, _ := a.getRoutingPeerRoutes(ctx, peerID)
+    for _, route := range enabledRoutes {
+        // If no access control groups are specified, accept all traffic.
+        if len(route.AccessControlGroups) == 0 {
+            defaultPermit := getDefaultPermit(route)
+            routesFirewallRules = append(routesFirewallRules, defaultPermit...)
+            continue
+        }
+
+        distributionPeers := a.getDistributionGroupsPeers(route)
+
+        for _, accessGroup := range route.AccessControlGroups {
+            policies := GetAllRoutePoliciesFromGroups(a, []string{accessGroup})
+            rules := a.getRouteFirewallRules(ctx, peerID, policies, route, validatedPeersMap, distributionPeers)
+            routesFirewallRules = append(routesFirewallRules, rules...)
+ } + } + + return routesFirewallRules +} + +func (a *Account) getRouteFirewallRules(ctx context.Context, peerID string, policies []*Policy, route *route.Route, validatedPeersMap map[string]struct{}, distributionPeers map[string]struct{}) []*RouteFirewallRule { + var fwRules []*RouteFirewallRule + for _, policy := range policies { + if !policy.Enabled { + continue + } + + for _, rule := range policy.Rules { + if !rule.Enabled { + continue + } + + rulePeers := a.getRulePeers(rule, peerID, distributionPeers, validatedPeersMap) + rules := generateRouteFirewallRules(ctx, route, rule, rulePeers, FirewallRuleDirectionIN) + fwRules = append(fwRules, rules...) + } + } + return fwRules +} + +func (a *Account) getRulePeers(rule *PolicyRule, peerID string, distributionPeers map[string]struct{}, validatedPeersMap map[string]struct{}) []*nbpeer.Peer { + distPeersWithPolicy := make(map[string]struct{}) + for _, id := range rule.Sources { + group := a.Groups[id] + if group == nil { + continue + } + + for _, pID := range group.Peers { + if pID == peerID { + continue + } + _, distPeer := distributionPeers[pID] + _, valid := validatedPeersMap[pID] + if distPeer && valid { + distPeersWithPolicy[pID] = struct{}{} + } + } + } + + distributionGroupPeers := make([]*nbpeer.Peer, 0, len(distPeersWithPolicy)) + for pID := range distPeersWithPolicy { + peer := a.Peers[pID] + if peer == nil { + continue + } + distributionGroupPeers = append(distributionGroupPeers, peer) + } + return distributionGroupPeers +} + +func (a *Account) getDistributionGroupsPeers(route *route.Route) map[string]struct{} { + distPeers := make(map[string]struct{}) + for _, id := range route.Groups { + group := a.Groups[id] + if group == nil { + continue + } + + for _, pID := range group.Peers { + distPeers[pID] = struct{}{} + } + } + return distPeers +} + +func getDefaultPermit(route *route.Route) []*RouteFirewallRule { + var rules []*RouteFirewallRule + + sources := []string{"0.0.0.0/0"} + if route.Network.Addr().Is6() { + sources = []string{"::/0"} + } + rule := RouteFirewallRule{ + SourceRanges: sources, + Action: string(PolicyTrafficActionAccept), + Destination: route.Network.String(), + Protocol: string(PolicyRuleProtocolALL), + Domains: route.Domains, + IsDynamic: route.IsDynamic(), + } + + rules = append(rules, &rule) + + // dynamic routes always contain an IPv4 placeholder as destination, hence we must add IPv6 rules additionally + if route.IsDynamic() { + ruleV6 := rule + ruleV6.SourceRanges = []string{"::/0"} + rules = append(rules, &ruleV6) + } + + return rules +} + +// GetAllRoutePoliciesFromGroups retrieves route policies associated with the specified access control groups +// and returns a list of policies that have rules with destinations matching the specified groups. +func GetAllRoutePoliciesFromGroups(account *Account, accessControlGroups []string) []*Policy { + routePolicies := make([]*Policy, 0) + for _, groupID := range accessControlGroups { + group, ok := account.Groups[groupID] + if !ok { + continue + } + + for _, policy := range account.Policies { + for _, rule := range policy.Rules { + exist := slices.ContainsFunc(rule.Destinations, func(groupID string) bool { + return groupID == group.ID + }) + if exist { + routePolicies = append(routePolicies, policy) + continue + } + } + } + } + + return routePolicies +} + +// GetPeerNetworkResourceFirewallRules gets the network resources firewall rules associated with a routing peer ID for the account. 
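+// Only routes served by the given peer (matched by its public key) are considered; the allowed sources are derived from the policies applied to each resource.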
+func (a *Account) GetPeerNetworkResourceFirewallRules(ctx context.Context, peer *nbpeer.Peer, validatedPeersMap map[string]struct{}, routes []*route.Route, resourcePolicies map[string][]*Policy) []*RouteFirewallRule { + routesFirewallRules := make([]*RouteFirewallRule, 0) + + for _, route := range routes { + if route.Peer != peer.Key { + continue + } + resourceAppliedPolicies := resourcePolicies[route.GetResourceID()] + distributionPeers := getPoliciesSourcePeers(resourceAppliedPolicies, a.Groups) + + rules := a.getRouteFirewallRules(ctx, peer.ID, resourceAppliedPolicies, route, validatedPeersMap, distributionPeers) + routesFirewallRules = append(routesFirewallRules, rules...) + } + + return routesFirewallRules +} + +// getNetworkResourceGroups retrieves all groups associated with the given network resource. +func (a *Account) getNetworkResourceGroups(resourceID string) []*Group { + var networkResourceGroups []*Group + + for _, group := range a.Groups { + for _, resource := range group.Resources { + if resource.ID == resourceID { + networkResourceGroups = append(networkResourceGroups, group) + } + } + } + + return networkResourceGroups +} + +// GetResourcePoliciesMap returns a map of networks resource IDs and their associated policies. +func (a *Account) GetResourcePoliciesMap() map[string][]*Policy { + resourcePolicies := make(map[string][]*Policy) + for _, resource := range a.NetworkResources { + resourceAppliedPolicies := a.GetPoliciesForNetworkResource(resource.ID) + resourcePolicies[resource.ID] = resourceAppliedPolicies + } + return resourcePolicies +} + +// GetNetworkResourcesRoutesToSync returns network routes for syncing with a specific peer and its ACL peers. +func (a *Account) GetNetworkResourcesRoutesToSync(ctx context.Context, peerID string, resourcePolicies map[string][]*Policy, routers map[string]map[string]*routerTypes.NetworkRouter) (bool, []*route.Route, []string) { + var isRoutingPeer bool + var routes []*route.Route + var allSourcePeers []string + + for _, resource := range a.NetworkResources { + var addSourcePeers bool + + networkRoutingPeers, exists := routers[resource.NetworkID] + if exists { + if router, ok := networkRoutingPeers[peerID]; ok { + isRoutingPeer, addSourcePeers = true, true + routes = append(routes, a.getNetworkResourcesRoutes(resource, peerID, router, resourcePolicies)...) + } + } + + for _, policy := range resourcePolicies[resource.ID] { + for _, sourceGroup := range policy.SourceGroups() { + group := a.GetGroup(sourceGroup) + if group == nil { + log.WithContext(ctx).Warnf("policy %s has source group %s that doesn't exist under account %s, will continue map generation without it", policy.ID, sourceGroup, a.Id) + continue + } + + // routing peer should be able to connect with all source peers + if addSourcePeers { + allSourcePeers = append(allSourcePeers, group.Peers...) + } + + // add routes for the resource if the peer is in the distribution group + if slices.Contains(group.Peers, peerID) { + for peerId, router := range networkRoutingPeers { + routes = append(routes, a.getNetworkResourcesRoutes(resource, peerId, router, resourcePolicies)...) + } + } + } + } + } + + return isRoutingPeer, routes, allSourcePeers +} + +// getNetworkResources filters and returns a list of network resources associated with the given network ID. 
+func (a *Account) getNetworkResources(networkID string) []*resourceTypes.NetworkResource { + var resources []*resourceTypes.NetworkResource + for _, resource := range a.NetworkResources { + if resource.NetworkID == networkID { + resources = append(resources, resource) + } + } + return resources +} + +// GetPoliciesForNetworkResource retrieves the list of policies that apply to a specific network resource. +// A policy is deemed applicable if its destination groups include any of the given network resource groups +// or if its destination resource explicitly matches the provided resource. +func (a *Account) GetPoliciesForNetworkResource(resourceId string) []*Policy { + var resourceAppliedPolicies []*Policy + + networkResourceGroups := a.getNetworkResourceGroups(resourceId) + + for _, policy := range a.Policies { + if !policy.Enabled { + continue + } + + for _, rule := range policy.Rules { + if !rule.Enabled { + continue + } + + if rule.DestinationResource.ID == resourceId { + resourceAppliedPolicies = append(resourceAppliedPolicies, policy) + break + } + + for _, group := range networkResourceGroups { + if slices.Contains(rule.Destinations, group.ID) { + resourceAppliedPolicies = append(resourceAppliedPolicies, policy) + break + } + } + } + } + + return resourceAppliedPolicies +} + +func (a *Account) GetPoliciesAppliedInNetwork(networkID string) []string { + networkResources := a.getNetworkResources(networkID) + + policiesIDs := map[string]struct{}{} + for _, resource := range networkResources { + resourceAppliedPolicies := a.GetPoliciesForNetworkResource(resource.ID) + for _, policy := range resourceAppliedPolicies { + policiesIDs[policy.ID] = struct{}{} + } + } + + result := make([]string, 0, len(policiesIDs)) + for id := range policiesIDs { + result = append(result, id) + } + + return result +} + +// getNetworkResourcesRoutes convert the network resources list to routes list. +func (a *Account) getNetworkResourcesRoutes(resource *resourceTypes.NetworkResource, peerId string, router *routerTypes.NetworkRouter, resourcePolicies map[string][]*Policy) []*route.Route { + resourceAppliedPolicies := resourcePolicies[resource.ID] + + var routes []*route.Route + // distribute the resource routes only if there is policy applied to it + if len(resourceAppliedPolicies) > 0 { + peer := a.GetPeer(peerId) + if peer != nil { + routes = append(routes, resource.ToRoute(peer, router)) + } + } + + return routes +} + +func (a *Account) GetResourceRoutersMap() map[string]map[string]*routerTypes.NetworkRouter { + routers := make(map[string]map[string]*routerTypes.NetworkRouter) + + for _, router := range a.NetworkRouters { + if routers[router.NetworkID] == nil { + routers[router.NetworkID] = make(map[string]*routerTypes.NetworkRouter) + } + + if router.Peer != "" { + routers[router.NetworkID][router.Peer] = router + continue + } + + for _, peerGroup := range router.PeerGroups { + g := a.Groups[peerGroup] + if g != nil { + for _, peerID := range g.Peers { + routers[router.NetworkID][peerID] = router + } + } + } + } + + return routers +} + +// getPoliciesSourcePeers collects all unique peers from the source groups defined in the given policies. 
+func getPoliciesSourcePeers(policies []*Policy, groups map[string]*Group) map[string]struct{} { + sourcePeers := make(map[string]struct{}) + + for _, policy := range policies { + for _, rule := range policy.Rules { + for _, sourceGroup := range rule.Sources { + group := groups[sourceGroup] + if group == nil { + continue + } + + for _, peer := range group.Peers { + sourcePeers[peer] = struct{}{} + } + } + } + } + + return sourcePeers +} diff --git a/management/server/types/account_test.go b/management/server/types/account_test.go new file mode 100644 index 000000000..c73421d16 --- /dev/null +++ b/management/server/types/account_test.go @@ -0,0 +1,375 @@ +package types + +import ( + "testing" + + "github.com/stretchr/testify/require" + + resourceTypes "github.com/netbirdio/netbird/management/server/networks/resources/types" + routerTypes "github.com/netbirdio/netbird/management/server/networks/routers/types" + networkTypes "github.com/netbirdio/netbird/management/server/networks/types" + nbpeer "github.com/netbirdio/netbird/management/server/peer" + "github.com/netbirdio/netbird/route" +) + +func setupTestAccount() *Account { + return &Account{ + Id: "accountID", + Peers: map[string]*nbpeer.Peer{ + "peer1": { + ID: "peer1", + AccountID: "accountID", + Key: "peer1Key", + }, + "peer2": { + ID: "peer2", + AccountID: "accountID", + Key: "peer2Key", + }, + "peer3": { + ID: "peer3", + AccountID: "accountID", + Key: "peer3Key", + }, + "peer11": { + ID: "peer11", + AccountID: "accountID", + Key: "peer11Key", + }, + "peer12": { + ID: "peer12", + AccountID: "accountID", + Key: "peer12Key", + }, + "peer21": { + ID: "peer21", + AccountID: "accountID", + Key: "peer21Key", + }, + "peer31": { + ID: "peer31", + AccountID: "accountID", + Key: "peer31Key", + }, + "peer32": { + ID: "peer32", + AccountID: "accountID", + Key: "peer32Key", + }, + "peer41": { + ID: "peer41", + AccountID: "accountID", + Key: "peer41Key", + }, + "peer51": { + ID: "peer51", + AccountID: "accountID", + Key: "peer51Key", + }, + "peer61": { + ID: "peer61", + AccountID: "accountID", + Key: "peer61Key", + }, + }, + Groups: map[string]*Group{ + "group1": { + ID: "group1", + Peers: []string{"peer11", "peer12"}, + Resources: []Resource{ + { + ID: "resource1ID", + Type: "Host", + }, + }, + }, + "group2": { + ID: "group2", + Peers: []string{"peer21"}, + Resources: []Resource{ + { + ID: "resource2ID", + Type: "Domain", + }, + }, + }, + "group3": { + ID: "group3", + Peers: []string{"peer31", "peer32"}, + Resources: []Resource{ + { + ID: "resource3ID", + Type: "Subnet", + }, + }, + }, + "group4": { + ID: "group4", + Peers: []string{"peer41"}, + Resources: []Resource{ + { + ID: "resource3ID", + Type: "Subnet", + }, + }, + }, + "group5": { + ID: "group5", + Peers: []string{"peer51"}, + }, + "group6": { + ID: "group6", + Peers: []string{"peer61"}, + }, + }, + Networks: []*networkTypes.Network{ + { + ID: "network1ID", + AccountID: "accountID", + Name: "network1", + }, + { + ID: "network2ID", + AccountID: "accountID", + Name: "network2", + }, + }, + NetworkRouters: []*routerTypes.NetworkRouter{ + { + ID: "router1ID", + NetworkID: "network1ID", + AccountID: "accountID", + Peer: "peer1", + PeerGroups: []string{}, + Masquerade: false, + Metric: 100, + }, + { + ID: "router2ID", + NetworkID: "network2ID", + AccountID: "accountID", + Peer: "peer2", + PeerGroups: []string{}, + Masquerade: false, + Metric: 100, + }, + { + ID: "router3ID", + NetworkID: "network1ID", + AccountID: "accountID", + Peer: "peer3", + PeerGroups: []string{}, + Masquerade: false, + 
Metric: 100, + }, + { + ID: "router4ID", + NetworkID: "network1ID", + AccountID: "accountID", + Peer: "", + PeerGroups: []string{"group1"}, + Masquerade: false, + Metric: 100, + }, + { + ID: "router5ID", + NetworkID: "network1ID", + AccountID: "accountID", + Peer: "", + PeerGroups: []string{"group2", "group3"}, + Masquerade: false, + Metric: 100, + }, + { + ID: "router6ID", + NetworkID: "network2ID", + AccountID: "accountID", + Peer: "", + PeerGroups: []string{"group4"}, + Masquerade: false, + Metric: 100, + }, + }, + NetworkResources: []*resourceTypes.NetworkResource{ + { + ID: "resource1ID", + AccountID: "accountID", + NetworkID: "network1ID", + }, + { + ID: "resource2ID", + AccountID: "accountID", + NetworkID: "network2ID", + }, + { + ID: "resource3ID", + AccountID: "accountID", + NetworkID: "network1ID", + }, + { + ID: "resource4ID", + AccountID: "accountID", + NetworkID: "network1ID", + }, + }, + Policies: []*Policy{ + { + ID: "policy1ID", + AccountID: "accountID", + Enabled: true, + Rules: []*PolicyRule{ + { + ID: "rule1ID", + Enabled: true, + Destinations: []string{"group1"}, + }, + }, + }, + { + ID: "policy2ID", + AccountID: "accountID", + Enabled: true, + Rules: []*PolicyRule{ + { + ID: "rule2ID", + Enabled: true, + Destinations: []string{"group3"}, + }, + }, + }, + { + ID: "policy3ID", + AccountID: "accountID", + Enabled: true, + Rules: []*PolicyRule{ + { + ID: "rule3ID", + Enabled: true, + Destinations: []string{"group2", "group4"}, + }, + }, + }, + { + ID: "policy4ID", + AccountID: "accountID", + Enabled: true, + Rules: []*PolicyRule{ + { + ID: "rule4ID", + Enabled: true, + DestinationResource: Resource{ + ID: "resource4ID", + Type: "Host", + }, + }, + }, + }, + { + ID: "policy5ID", + AccountID: "accountID", + Enabled: true, + Rules: []*PolicyRule{ + { + ID: "rule5ID", + Enabled: true, + }, + }, + }, + }, + } +} + +func Test_GetResourceRoutersMap(t *testing.T) { + account := setupTestAccount() + routers := account.GetResourceRoutersMap() + require.Equal(t, 2, len(routers)) + + require.Equal(t, 7, len(routers["network1ID"])) + require.NotNil(t, routers["network1ID"]["peer1"]) + require.NotNil(t, routers["network1ID"]["peer3"]) + require.NotNil(t, routers["network1ID"]["peer11"]) + require.NotNil(t, routers["network1ID"]["peer12"]) + require.NotNil(t, routers["network1ID"]["peer21"]) + require.NotNil(t, routers["network1ID"]["peer31"]) + require.NotNil(t, routers["network1ID"]["peer32"]) + + require.Equal(t, 2, len(routers["network2ID"])) + require.NotNil(t, routers["network2ID"]["peer2"]) + require.NotNil(t, routers["network2ID"]["peer41"]) +} + +func Test_GetResourcePoliciesMap(t *testing.T) { + account := setupTestAccount() + policies := account.GetResourcePoliciesMap() + require.Equal(t, 4, len(policies)) + require.Equal(t, 1, len(policies["resource1ID"])) + require.Equal(t, 1, len(policies["resource2ID"])) + require.Equal(t, 2, len(policies["resource3ID"])) + require.Equal(t, 1, len(policies["resource4ID"])) +} + +func Test_AddNetworksRoutingPeersAddsMissingPeers(t *testing.T) { + account := setupTestAccount() + peer := &nbpeer.Peer{Key: "peer1"} + networkResourcesRoutes := []*route.Route{ + {Peer: "peer2Key"}, + {Peer: "peer3Key"}, + } + peersToConnect := []*nbpeer.Peer{ + {Key: "peer2Key"}, + } + expiredPeers := []*nbpeer.Peer{ + {Key: "peer4Key"}, + } + + result := account.addNetworksRoutingPeers(networkResourcesRoutes, peer, peersToConnect, expiredPeers, false, []string{}) + require.Len(t, result, 2) + require.Equal(t, "peer2Key", result[0].Key) + require.Equal(t, 
"peer3Key", result[1].Key) +} + +func Test_AddNetworksRoutingPeersIgnoresExistingPeers(t *testing.T) { + account := setupTestAccount() + peer := &nbpeer.Peer{Key: "peer1"} + networkResourcesRoutes := []*route.Route{ + {Peer: "peer2Key"}, + } + peersToConnect := []*nbpeer.Peer{ + {Key: "peer2Key"}, + } + expiredPeers := []*nbpeer.Peer{} + + result := account.addNetworksRoutingPeers(networkResourcesRoutes, peer, peersToConnect, expiredPeers, false, []string{}) + require.Len(t, result, 1) + require.Equal(t, "peer2Key", result[0].Key) +} + +func Test_AddNetworksRoutingPeersAddsExpiredPeers(t *testing.T) { + account := setupTestAccount() + peer := &nbpeer.Peer{Key: "peer1Key"} + networkResourcesRoutes := []*route.Route{ + {Peer: "peer2Key"}, + {Peer: "peer3Key"}, + } + peersToConnect := []*nbpeer.Peer{ + {Key: "peer2Key"}, + } + expiredPeers := []*nbpeer.Peer{ + {Key: "peer3Key"}, + } + + result := account.addNetworksRoutingPeers(networkResourcesRoutes, peer, peersToConnect, expiredPeers, false, []string{}) + require.Len(t, result, 1) + require.Equal(t, "peer2Key", result[0].Key) +} + +func Test_AddNetworksRoutingPeersHandlesNoMissingPeers(t *testing.T) { + account := setupTestAccount() + peer := &nbpeer.Peer{Key: "peer1"} + networkResourcesRoutes := []*route.Route{} + peersToConnect := []*nbpeer.Peer{} + expiredPeers := []*nbpeer.Peer{} + + result := account.addNetworksRoutingPeers(networkResourcesRoutes, peer, peersToConnect, expiredPeers, false, []string{}) + require.Len(t, result, 0) +} diff --git a/management/server/types/dns_settings.go b/management/server/types/dns_settings.go new file mode 100644 index 000000000..1d33bb9fb --- /dev/null +++ b/management/server/types/dns_settings.go @@ -0,0 +1,16 @@ +package types + +// DNSSettings defines dns settings at the account level +type DNSSettings struct { + // DisabledManagementGroups groups whose DNS management is disabled + DisabledManagementGroups []string `gorm:"serializer:json"` +} + +// Copy returns a copy of the DNS settings +func (d DNSSettings) Copy() DNSSettings { + settings := DNSSettings{ + DisabledManagementGroups: make([]string, len(d.DisabledManagementGroups)), + } + copy(settings.DisabledManagementGroups, d.DisabledManagementGroups) + return settings +} diff --git a/management/server/types/firewall_rule.go b/management/server/types/firewall_rule.go new file mode 100644 index 000000000..3d1b7e225 --- /dev/null +++ b/management/server/types/firewall_rule.go @@ -0,0 +1,130 @@ +package types + +import ( + "context" + "fmt" + "strconv" + "strings" + + log "github.com/sirupsen/logrus" + + nbpeer "github.com/netbirdio/netbird/management/server/peer" + nbroute "github.com/netbirdio/netbird/route" +) + +const ( + FirewallRuleDirectionIN = 0 + FirewallRuleDirectionOUT = 1 +) + +// FirewallRule is a rule of the firewall. +type FirewallRule struct { + // PeerIP of the peer + PeerIP string + + // Direction of the traffic + Direction int + + // Action of the traffic + Action string + + // Protocol of the traffic + Protocol string + + // Port of the traffic + Port string +} + +// generateRouteFirewallRules generates a list of firewall rules for a given route. 
+func generateRouteFirewallRules(ctx context.Context, route *nbroute.Route, rule *PolicyRule, groupPeers []*nbpeer.Peer, direction int) []*RouteFirewallRule {
+    rulesExists := make(map[string]struct{})
+    rules := make([]*RouteFirewallRule, 0)
+
+    sourceRanges := make([]string, 0, len(groupPeers))
+    for _, peer := range groupPeers {
+        if peer == nil {
+            continue
+        }
+        sourceRanges = append(sourceRanges, fmt.Sprintf(AllowedIPsFormat, peer.IP))
+    }
+
+    baseRule := RouteFirewallRule{
+        SourceRanges: sourceRanges,
+        Action:       string(rule.Action),
+        Destination:  route.Network.String(),
+        Protocol:     string(rule.Protocol),
+        Domains:      route.Domains,
+        IsDynamic:    route.IsDynamic(),
+    }
+
+    // generate rule for port range
+    if len(rule.Ports) == 0 {
+        rules = append(rules, generateRulesWithPortRanges(baseRule, rule, rulesExists)...)
+    } else {
+        rules = append(rules, generateRulesWithPorts(ctx, baseRule, rule, rulesExists)...)
+    }
+
+    // TODO: generate IPv6 rules for dynamic routes
+
+    return rules
+}
+
+// generateRulesWithPortRanges generates rules from the rule's port ranges when no specific ports are provided.
+func generateRulesWithPortRanges(baseRule RouteFirewallRule, rule *PolicyRule, rulesExists map[string]struct{}) []*RouteFirewallRule {
+    rules := make([]*RouteFirewallRule, 0)
+
+    ruleIDBase := generateRuleIDBase(rule, baseRule)
+    if len(rule.Ports) == 0 {
+        if len(rule.PortRanges) == 0 {
+            if _, ok := rulesExists[ruleIDBase]; !ok {
+                rulesExists[ruleIDBase] = struct{}{}
+                rules = append(rules, &baseRule)
+            }
+        } else {
+            for _, portRange := range rule.PortRanges {
+                ruleID := fmt.Sprintf("%s%d-%d", ruleIDBase, portRange.Start, portRange.End)
+                if _, ok := rulesExists[ruleID]; !ok {
+                    rulesExists[ruleID] = struct{}{}
+                    pr := baseRule
+                    pr.PortRange = portRange
+                    rules = append(rules, &pr)
+                }
+            }
+        }
+        return rules
+    }
+
+    return rules
+}
+
+// generateRulesWithPorts generates rules when specific ports are provided.
+func generateRulesWithPorts(ctx context.Context, baseRule RouteFirewallRule, rule *PolicyRule, rulesExists map[string]struct{}) []*RouteFirewallRule {
+    rules := make([]*RouteFirewallRule, 0)
+    ruleIDBase := generateRuleIDBase(rule, baseRule)
+
+    for _, port := range rule.Ports {
+        ruleID := ruleIDBase + port
+        if _, ok := rulesExists[ruleID]; ok {
+            continue
+        }
+        rulesExists[ruleID] = struct{}{}
+
+        pr := baseRule
+        p, err := strconv.ParseUint(port, 10, 16)
+        if err != nil {
+            log.WithContext(ctx).Errorf("failed to parse port %s for rule: %s", port, rule.ID)
+            continue
+        }
+
+        pr.Port = uint16(p)
+        rules = append(rules, &pr)
+    }
+
+    return rules
+}
+
+// generateRuleIDBase generates the base rule ID for checking duplicates.
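+// For example, a rule with ID "ruleA", source range "100.64.0.1/32", protocol "tcp" and action "accept"
+// (hypothetical values) yields the base ID "ruleA100.64.0.1/320tcpaccept"; callers append the port or port range to it.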
+func generateRuleIDBase(rule *PolicyRule, baseRule RouteFirewallRule) string { + return rule.ID + strings.Join(baseRule.SourceRanges, ",") + strconv.Itoa(FirewallRuleDirectionIN) + baseRule.Protocol + baseRule.Action +} diff --git a/management/server/group/group.go b/management/server/types/group.go similarity index 58% rename from management/server/group/group.go rename to management/server/types/group.go index 24c60d3ce..00a28fa77 100644 --- a/management/server/group/group.go +++ b/management/server/types/group.go @@ -1,6 +1,9 @@ -package group +package types -import "github.com/netbirdio/netbird/management/server/integration_reference" +import ( + "github.com/netbirdio/netbird/management/server/integration_reference" + "github.com/netbirdio/netbird/management/server/networks/resources/types" +) const ( GroupIssuedAPI = "api" @@ -25,6 +28,9 @@ type Group struct { // Peers list of the group Peers []string `gorm:"serializer:json"` + // Resources contains a list of resources in that group + Resources []Resource `gorm:"serializer:json"` + IntegrationReference integration_reference.IntegrationReference `gorm:"embedded;embeddedPrefix:integration_ref_"` } @@ -33,15 +39,21 @@ func (g *Group) EventMeta() map[string]any { return map[string]any{"name": g.Name} } +func (g *Group) EventMetaResource(resource *types.NetworkResource) map[string]any { + return map[string]any{"name": g.Name, "id": g.ID, "resource_name": resource.Name, "resource_id": resource.ID, "resource_type": resource.Type} +} + func (g *Group) Copy() *Group { group := &Group{ ID: g.ID, Name: g.Name, Issued: g.Issued, Peers: make([]string, len(g.Peers)), + Resources: make([]Resource, len(g.Resources)), IntegrationReference: g.IntegrationReference, } copy(group.Peers, g.Peers) + copy(group.Resources, g.Resources) return group } @@ -81,3 +93,31 @@ func (g *Group) RemovePeer(peerID string) bool { } return false } + +// AddResource adds resource to Resources if not present, returning true if added. +func (g *Group) AddResource(resource Resource) bool { + for _, item := range g.Resources { + if item == resource { + return false + } + } + + g.Resources = append(g.Resources, resource) + return true +} + +// RemoveResource removes resource from Resources if present, returning true if removed. +func (g *Group) RemoveResource(resource Resource) bool { + for i, item := range g.Resources { + if item == resource { + g.Resources = append(g.Resources[:i], g.Resources[i+1:]...) + return true + } + } + return false +} + +// HasResources checks if the group has any resources. +func (g *Group) HasResources() bool { + return len(g.Resources) > 0 +} diff --git a/management/server/group/group_test.go b/management/server/types/group_test.go similarity index 99% rename from management/server/group/group_test.go rename to management/server/types/group_test.go index cb002f8d9..12107c603 100644 --- a/management/server/group/group_test.go +++ b/management/server/types/group_test.go @@ -1,4 +1,4 @@ -package group +package types import ( "testing" diff --git a/management/server/network.go b/management/server/types/network.go similarity index 96% rename from management/server/network.go rename to management/server/types/network.go index a5b188b46..d1fccd149 100644 --- a/management/server/network.go +++ b/management/server/types/network.go @@ -1,4 +1,4 @@ -package server +package types import ( "math/rand" @@ -43,7 +43,7 @@ type Network struct { // Used to synchronize state to the client apps. 
Serial uint64 - mu sync.Mutex `json:"-" gorm:"-"` + Mu sync.Mutex `json:"-" gorm:"-"` } // NewNetwork creates a new Network initializing it with a Serial=0 @@ -66,15 +66,15 @@ func NewNetwork() *Network { // IncSerial increments Serial by 1 reflecting that the network state has been changed func (n *Network) IncSerial() { - n.mu.Lock() - defer n.mu.Unlock() + n.Mu.Lock() + defer n.Mu.Unlock() n.Serial++ } // CurrentSerial returns the Network.Serial of the network (latest state id) func (n *Network) CurrentSerial() uint64 { - n.mu.Lock() - defer n.mu.Unlock() + n.Mu.Lock() + defer n.Mu.Unlock() return n.Serial } diff --git a/management/server/network_test.go b/management/server/types/network_test.go similarity index 98% rename from management/server/network_test.go rename to management/server/types/network_test.go index b067c4991..d0b0894d4 100644 --- a/management/server/network_test.go +++ b/management/server/types/network_test.go @@ -1,4 +1,4 @@ -package server +package types import ( "net" diff --git a/management/server/personal_access_token.go b/management/server/types/personal_access_token.go similarity index 99% rename from management/server/personal_access_token.go rename to management/server/types/personal_access_token.go index f46666112..1bf225856 100644 --- a/management/server/personal_access_token.go +++ b/management/server/types/personal_access_token.go @@ -1,4 +1,4 @@ -package server +package types import ( "crypto/sha256" diff --git a/management/server/personal_access_token_test.go b/management/server/types/personal_access_token_test.go similarity index 98% rename from management/server/personal_access_token_test.go rename to management/server/types/personal_access_token_test.go index 311ffd9cf..ac3377151 100644 --- a/management/server/personal_access_token_test.go +++ b/management/server/types/personal_access_token_test.go @@ -1,4 +1,4 @@ -package server +package types import ( "crypto/sha256" diff --git a/management/server/types/policy.go b/management/server/types/policy.go new file mode 100644 index 000000000..c2b82d68a --- /dev/null +++ b/management/server/types/policy.go @@ -0,0 +1,125 @@ +package types + +const ( + // PolicyTrafficActionAccept indicates that the traffic is accepted + PolicyTrafficActionAccept = PolicyTrafficActionType("accept") + // PolicyTrafficActionDrop indicates that the traffic is dropped + PolicyTrafficActionDrop = PolicyTrafficActionType("drop") +) + +const ( + // PolicyRuleProtocolALL type of traffic + PolicyRuleProtocolALL = PolicyRuleProtocolType("all") + // PolicyRuleProtocolTCP type of traffic + PolicyRuleProtocolTCP = PolicyRuleProtocolType("tcp") + // PolicyRuleProtocolUDP type of traffic + PolicyRuleProtocolUDP = PolicyRuleProtocolType("udp") + // PolicyRuleProtocolICMP type of traffic + PolicyRuleProtocolICMP = PolicyRuleProtocolType("icmp") +) + +const ( + // PolicyRuleFlowDirect allows traffic from source to destination + PolicyRuleFlowDirect = PolicyRuleDirection("direct") + // PolicyRuleFlowBidirect allows traffic to both directions + PolicyRuleFlowBidirect = PolicyRuleDirection("bidirect") +) + +const ( + // DefaultRuleName is a name for the Default rule that is created for every account + DefaultRuleName = "Default" + // DefaultRuleDescription is a description for the Default rule that is created for every account + DefaultRuleDescription = "This is a default rule that allows connections between all the resources" + // DefaultPolicyName is a name for the Default policy that is created for every account + DefaultPolicyName = 
"Default" + // DefaultPolicyDescription is a description for the Default policy that is created for every account + DefaultPolicyDescription = "This is a default policy that allows connections between all the resources" +) + +// PolicyUpdateOperation operation object with type and values to be applied +type PolicyUpdateOperation struct { + Type PolicyUpdateOperationType + Values []string +} + +// Policy of the Rego query +type Policy struct { + // ID of the policy' + ID string `gorm:"primaryKey"` + + // AccountID is a reference to Account that this object belongs + AccountID string `json:"-" gorm:"index"` + + // Name of the Policy + Name string + + // Description of the policy visible in the UI + Description string + + // Enabled status of the policy + Enabled bool + + // Rules of the policy + Rules []*PolicyRule `gorm:"foreignKey:PolicyID;references:id;constraint:OnDelete:CASCADE;"` + + // SourcePostureChecks are ID references to Posture checks for policy source groups + SourcePostureChecks []string `gorm:"serializer:json"` +} + +// Copy returns a copy of the policy. +func (p *Policy) Copy() *Policy { + c := &Policy{ + ID: p.ID, + AccountID: p.AccountID, + Name: p.Name, + Description: p.Description, + Enabled: p.Enabled, + Rules: make([]*PolicyRule, len(p.Rules)), + SourcePostureChecks: make([]string, len(p.SourcePostureChecks)), + } + for i, r := range p.Rules { + c.Rules[i] = r.Copy() + } + copy(c.SourcePostureChecks, p.SourcePostureChecks) + return c +} + +// EventMeta returns activity event meta related to this policy +func (p *Policy) EventMeta() map[string]any { + return map[string]any{"name": p.Name} +} + +// UpgradeAndFix different version of policies to latest version +func (p *Policy) UpgradeAndFix() { + for _, r := range p.Rules { + // start migrate from version v0.20.3 + if r.Protocol == "" { + r.Protocol = PolicyRuleProtocolALL + } + if r.Protocol == PolicyRuleProtocolALL && !r.Bidirectional { + r.Bidirectional = true + } + // -- v0.20.4 + } +} + +// RuleGroups returns a list of all groups referenced in the policy's rules, +// including sources and destinations. +func (p *Policy) RuleGroups() []string { + groups := make([]string, 0) + for _, rule := range p.Rules { + groups = append(groups, rule.Sources...) + groups = append(groups, rule.Destinations...) + } + + return groups +} + +// SourceGroups returns a slice of all unique source groups referenced in the policy's rules. +func (p *Policy) SourceGroups() []string { + groups := make([]string, 0) + for _, rule := range p.Rules { + groups = append(groups, rule.Sources...) + } + return groups +} diff --git a/management/server/types/policyrule.go b/management/server/types/policyrule.go new file mode 100644 index 000000000..bd9a99292 --- /dev/null +++ b/management/server/types/policyrule.go @@ -0,0 +1,87 @@ +package types + +// PolicyUpdateOperationType operation type +type PolicyUpdateOperationType int + +// PolicyTrafficActionType action type for the firewall +type PolicyTrafficActionType string + +// PolicyRuleProtocolType type of traffic +type PolicyRuleProtocolType string + +// PolicyRuleDirection direction of traffic +type PolicyRuleDirection string + +// RulePortRange represents a range of ports for a firewall rule. 
+type RulePortRange struct { + Start uint16 + End uint16 +} + +// PolicyRule is the metadata of the policy +type PolicyRule struct { + // ID of the policy rule + ID string `gorm:"primaryKey"` + + // PolicyID is a reference to Policy that this object belongs + PolicyID string `json:"-" gorm:"index"` + + // Name of the rule visible in the UI + Name string + + // Description of the rule visible in the UI + Description string + + // Enabled status of rule in the system + Enabled bool + + // Action policy accept or drops packets + Action PolicyTrafficActionType + + // Destinations policy destination groups + Destinations []string `gorm:"serializer:json"` + + // DestinationResource policy destination resource that the rule is applied to + DestinationResource Resource `gorm:"serializer:json"` + + // Sources policy source groups + Sources []string `gorm:"serializer:json"` + + // SourceResource policy source resource that the rule is applied to + SourceResource Resource `gorm:"serializer:json"` + + // Bidirectional define if the rule is applicable in both directions, sources, and destinations + Bidirectional bool + + // Protocol type of the traffic + Protocol PolicyRuleProtocolType + + // Ports or it ranges list + Ports []string `gorm:"serializer:json"` + + // PortRanges a list of port ranges. + PortRanges []RulePortRange `gorm:"serializer:json"` +} + +// Copy returns a copy of a policy rule +func (pm *PolicyRule) Copy() *PolicyRule { + rule := &PolicyRule{ + ID: pm.ID, + PolicyID: pm.PolicyID, + Name: pm.Name, + Description: pm.Description, + Enabled: pm.Enabled, + Action: pm.Action, + Destinations: make([]string, len(pm.Destinations)), + Sources: make([]string, len(pm.Sources)), + Bidirectional: pm.Bidirectional, + Protocol: pm.Protocol, + Ports: make([]string, len(pm.Ports)), + PortRanges: make([]RulePortRange, len(pm.PortRanges)), + } + copy(rule.Destinations, pm.Destinations) + copy(rule.Sources, pm.Sources) + copy(rule.Ports, pm.Ports) + copy(rule.PortRanges, pm.PortRanges) + return rule +} diff --git a/management/server/types/resource.go b/management/server/types/resource.go new file mode 100644 index 000000000..820872f20 --- /dev/null +++ b/management/server/types/resource.go @@ -0,0 +1,30 @@ +package types + +import ( + "github.com/netbirdio/netbird/management/server/http/api" +) + +type Resource struct { + ID string + Type string +} + +func (r *Resource) ToAPIResponse() *api.Resource { + if r.ID == "" && r.Type == "" { + return nil + } + + return &api.Resource{ + Id: r.ID, + Type: api.ResourceType(r.Type), + } +} + +func (r *Resource) FromAPIRequest(req *api.Resource) { + if req == nil { + return + } + + r.ID = req.Id + r.Type = string(req.Type) +} diff --git a/management/server/types/route_firewall_rule.go b/management/server/types/route_firewall_rule.go new file mode 100644 index 000000000..64708d68a --- /dev/null +++ b/management/server/types/route_firewall_rule.go @@ -0,0 +1,32 @@ +package types + +import ( + "github.com/netbirdio/netbird/management/domain" +) + +// RouteFirewallRule a firewall rule applicable for a routed network. +type RouteFirewallRule struct { + // SourceRanges IP ranges of the routing peers. 
+ SourceRanges []string + + // Action of the traffic when the rule is applicable + Action string + + // Destination a network prefix for the routed traffic + Destination string + + // Protocol of the traffic + Protocol string + + // Port of the traffic + Port uint16 + + // PortRange represents the range of ports for a firewall rule + PortRange RulePortRange + + // Domains list of network domains for the routed traffic + Domains domain.List + + // isDynamic indicates whether the rule is for DNS routing + IsDynamic bool +} diff --git a/management/server/types/settings.go b/management/server/types/settings.go new file mode 100644 index 000000000..0ce5a6133 --- /dev/null +++ b/management/server/types/settings.go @@ -0,0 +1,68 @@ +package types + +import ( + "time" + + "github.com/netbirdio/netbird/management/server/account" +) + +// Settings represents Account settings structure that can be modified via API and Dashboard +type Settings struct { + // PeerLoginExpirationEnabled globally enables or disables peer login expiration + PeerLoginExpirationEnabled bool + + // PeerLoginExpiration is a setting that indicates when peer login expires. + // Applies to all peers that have Peer.LoginExpirationEnabled set to true. + PeerLoginExpiration time.Duration + + // PeerInactivityExpirationEnabled globally enables or disables peer inactivity expiration + PeerInactivityExpirationEnabled bool + + // PeerInactivityExpiration is a setting that indicates when peer inactivity expires. + // Applies to all peers that have Peer.PeerInactivityExpirationEnabled set to true. + PeerInactivityExpiration time.Duration + + // RegularUsersViewBlocked allows to block regular users from viewing even their own peers and some UI elements + RegularUsersViewBlocked bool + + // GroupsPropagationEnabled allows to propagate auto groups from the user to the peer + GroupsPropagationEnabled bool + + // JWTGroupsEnabled allows extract groups from JWT claim, which name defined in the JWTGroupsClaimName + // and add it to account groups. 
+ JWTGroupsEnabled bool + + // JWTGroupsClaimName from which we extract groups name to add it to account groups + JWTGroupsClaimName string + + // JWTAllowGroups list of groups to which users are allowed access + JWTAllowGroups []string `gorm:"serializer:json"` + + // RoutingPeerDNSResolutionEnabled enabled the DNS resolution on the routing peers + RoutingPeerDNSResolutionEnabled bool + + // Extra is a dictionary of Account settings + Extra *account.ExtraSettings `gorm:"embedded;embeddedPrefix:extra_"` +} + +// Copy copies the Settings struct +func (s *Settings) Copy() *Settings { + settings := &Settings{ + PeerLoginExpirationEnabled: s.PeerLoginExpirationEnabled, + PeerLoginExpiration: s.PeerLoginExpiration, + JWTGroupsEnabled: s.JWTGroupsEnabled, + JWTGroupsClaimName: s.JWTGroupsClaimName, + GroupsPropagationEnabled: s.GroupsPropagationEnabled, + JWTAllowGroups: s.JWTAllowGroups, + RegularUsersViewBlocked: s.RegularUsersViewBlocked, + + PeerInactivityExpirationEnabled: s.PeerInactivityExpirationEnabled, + PeerInactivityExpiration: s.PeerInactivityExpiration, + + RoutingPeerDNSResolutionEnabled: s.RoutingPeerDNSResolutionEnabled, + } + if s.Extra != nil { + settings.Extra = s.Extra.Copy() + } + return settings +} diff --git a/management/server/types/setupkey.go b/management/server/types/setupkey.go new file mode 100644 index 000000000..a5cf346a0 --- /dev/null +++ b/management/server/types/setupkey.go @@ -0,0 +1,181 @@ +package types + +import ( + "crypto/sha256" + b64 "encoding/base64" + "hash/fnv" + "strconv" + "strings" + "time" + "unicode/utf8" + + "github.com/google/uuid" +) + +const ( + // SetupKeyReusable is a multi-use key (can be used for multiple machines) + SetupKeyReusable SetupKeyType = "reusable" + // SetupKeyOneOff is a single use key (can be used only once) + SetupKeyOneOff SetupKeyType = "one-off" + // DefaultSetupKeyDuration = 1 month + DefaultSetupKeyDuration = 24 * 30 * time.Hour + // DefaultSetupKeyName is a default name of the default setup key + DefaultSetupKeyName = "Default key" + // SetupKeyUnlimitedUsage indicates an unlimited usage of a setup key + SetupKeyUnlimitedUsage = 0 +) + +// SetupKeyType is the type of setup key +type SetupKeyType string + +// SetupKey represents a pre-authorized key used to register machines (peers) +type SetupKey struct { + Id string + // AccountID is a reference to Account that this object belongs + AccountID string `json:"-" gorm:"index"` + Key string + KeySecret string + Name string + Type SetupKeyType + CreatedAt time.Time + ExpiresAt time.Time + UpdatedAt time.Time `gorm:"autoUpdateTime:false"` + // Revoked indicates whether the key was revoked or not (we don't remove them for tracking purposes) + Revoked bool + // UsedTimes indicates how many times the key was used + UsedTimes int + // LastUsed last time the key was used for peer registration + LastUsed time.Time + // AutoGroups is a list of Group IDs that are auto assigned to a Peer when it uses this key to register + AutoGroups []string `gorm:"serializer:json"` + // UsageLimit indicates the number of times this key can be used to enroll a machine. + // The value of 0 indicates the unlimited usage. 
+ UsageLimit int + // Ephemeral indicate if the peers will be ephemeral or not + Ephemeral bool +} + +// Copy copies SetupKey to a new object +func (key *SetupKey) Copy() *SetupKey { + autoGroups := make([]string, len(key.AutoGroups)) + copy(autoGroups, key.AutoGroups) + if key.UpdatedAt.IsZero() { + key.UpdatedAt = key.CreatedAt + } + return &SetupKey{ + Id: key.Id, + AccountID: key.AccountID, + Key: key.Key, + KeySecret: key.KeySecret, + Name: key.Name, + Type: key.Type, + CreatedAt: key.CreatedAt, + ExpiresAt: key.ExpiresAt, + UpdatedAt: key.UpdatedAt, + Revoked: key.Revoked, + UsedTimes: key.UsedTimes, + LastUsed: key.LastUsed, + AutoGroups: autoGroups, + UsageLimit: key.UsageLimit, + Ephemeral: key.Ephemeral, + } +} + +// EventMeta returns activity event meta related to the setup key +func (key *SetupKey) EventMeta() map[string]any { + return map[string]any{"name": key.Name, "type": key.Type, "key": key.KeySecret} +} + +// HiddenKey returns the Key value hidden with "*" and a 5 character prefix. +// E.g., "831F6*******************************" +func HiddenKey(key string, length int) string { + prefix := key[0:5] + if length > utf8.RuneCountInString(key) { + length = utf8.RuneCountInString(key) - len(prefix) + } + return prefix + strings.Repeat("*", length) +} + +// IncrementUsage makes a copy of a key, increments the UsedTimes by 1 and sets LastUsed to now +func (key *SetupKey) IncrementUsage() *SetupKey { + c := key.Copy() + c.UsedTimes++ + c.LastUsed = time.Now().UTC() + return c +} + +// IsValid is true if the key was not revoked, is not expired and used not more than it was supposed to +func (key *SetupKey) IsValid() bool { + return !key.IsRevoked() && !key.IsExpired() && !key.IsOverUsed() +} + +// IsRevoked if key was revoked +func (key *SetupKey) IsRevoked() bool { + return key.Revoked +} + +// IsExpired if key was expired +func (key *SetupKey) IsExpired() bool { + if key.ExpiresAt.IsZero() { + return false + } + return time.Now().After(key.ExpiresAt) +} + +// IsOverUsed if the key was used too many times. SetupKey.UsageLimit == 0 indicates the unlimited usage. 
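Taken together, these helpers (plus IsOverUsed and GenerateSetupKey, both defined just below) cover the setup-key lifecycle: the plain key is returned exactly once at generation time, only its SHA-256 hash and a masked KeySecret are stored, and registration code checks validity before bumping the usage counter on a copy. A minimal usage sketch, assuming the package is importable as management/server/types; the key name and group ID are made up:

package main

import (
	"fmt"

	"github.com/netbirdio/netbird/management/server/types"
)

func main() {
	// The returned record keeps only the hashed key; plainKey must be handed
	// to the enrolling machine now, or it is lost.
	record, plainKey := types.GenerateSetupKey(
		"ci key", types.SetupKeyReusable, types.DefaultSetupKeyDuration,
		[]string{"group1"}, types.SetupKeyUnlimitedUsage, false)

	fmt.Println("share once:", plainKey)
	fmt.Println("stored secret:", record.KeySecret) // already masked via HiddenKey

	if record.IsValid() {
		// On successful peer registration the counter is incremented on a copy.
		record = record.IncrementUsage()
	}
	fmt.Println("used", record.UsedTimes, "times")
}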
+func (key *SetupKey) IsOverUsed() bool { + limit := key.UsageLimit + if key.Type == SetupKeyOneOff { + limit = 1 + } + return limit > 0 && key.UsedTimes >= limit +} + +// GenerateSetupKey generates a new setup key +func GenerateSetupKey(name string, t SetupKeyType, validFor time.Duration, autoGroups []string, + usageLimit int, ephemeral bool) (*SetupKey, string) { + key := strings.ToUpper(uuid.New().String()) + limit := usageLimit + if t == SetupKeyOneOff { + limit = 1 + } + + expiresAt := time.Time{} + if validFor != 0 { + expiresAt = time.Now().UTC().Add(validFor) + } + + hashedKey := sha256.Sum256([]byte(key)) + encodedHashedKey := b64.StdEncoding.EncodeToString(hashedKey[:]) + + return &SetupKey{ + Id: strconv.Itoa(int(Hash(key))), + Key: encodedHashedKey, + KeySecret: HiddenKey(key, 4), + Name: name, + Type: t, + CreatedAt: time.Now().UTC(), + ExpiresAt: expiresAt, + UpdatedAt: time.Now().UTC(), + Revoked: false, + UsedTimes: 0, + AutoGroups: autoGroups, + UsageLimit: limit, + Ephemeral: ephemeral, + }, key +} + +// GenerateDefaultSetupKey generates a default reusable setup key with an unlimited usage and 30 days expiration +func GenerateDefaultSetupKey() (*SetupKey, string) { + return GenerateSetupKey(DefaultSetupKeyName, SetupKeyReusable, DefaultSetupKeyDuration, []string{}, + SetupKeyUnlimitedUsage, false) +} + +func Hash(s string) uint32 { + h := fnv.New32a() + _, err := h.Write([]byte(s)) + if err != nil { + panic(err) + } + return h.Sum32() +} diff --git a/management/server/types/user.go b/management/server/types/user.go new file mode 100644 index 000000000..5f1b71792 --- /dev/null +++ b/management/server/types/user.go @@ -0,0 +1,231 @@ +package types + +import ( + "fmt" + "strings" + "time" + + "github.com/netbirdio/netbird/management/server/idp" + "github.com/netbirdio/netbird/management/server/integration_reference" +) + +const ( + UserRoleOwner UserRole = "owner" + UserRoleAdmin UserRole = "admin" + UserRoleUser UserRole = "user" + UserRoleUnknown UserRole = "unknown" + UserRoleBillingAdmin UserRole = "billing_admin" + + UserStatusActive UserStatus = "active" + UserStatusDisabled UserStatus = "disabled" + UserStatusInvited UserStatus = "invited" + + UserIssuedAPI = "api" + UserIssuedIntegration = "integration" +) + +// StrRoleToUserRole returns UserRole for a given strRole or UserRoleUnknown if the specified role is unknown +func StrRoleToUserRole(strRole string) UserRole { + switch strings.ToLower(strRole) { + case "owner": + return UserRoleOwner + case "admin": + return UserRoleAdmin + case "user": + return UserRoleUser + case "billing_admin": + return UserRoleBillingAdmin + default: + return UserRoleUnknown + } +} + +// UserStatus is the status of a User +type UserStatus string + +// UserRole is the role of a User +type UserRole string + +type UserInfo struct { + ID string `json:"id"` + Email string `json:"email"` + Name string `json:"name"` + Role string `json:"role"` + AutoGroups []string `json:"auto_groups"` + Status string `json:"-"` + IsServiceUser bool `json:"is_service_user"` + IsBlocked bool `json:"is_blocked"` + NonDeletable bool `json:"non_deletable"` + LastLogin time.Time `json:"last_login"` + Issued string `json:"issued"` + IntegrationReference integration_reference.IntegrationReference `json:"-"` + Permissions UserPermissions `json:"permissions"` +} + +type UserPermissions struct { + DashboardView string `json:"dashboard_view"` +} + +// User represents a user of the system +type User struct { + Id string `gorm:"primaryKey"` + // AccountID is a reference to 
Account that this object belongs + AccountID string `json:"-" gorm:"index"` + Role UserRole + IsServiceUser bool + // NonDeletable indicates whether the service user can be deleted + NonDeletable bool + // ServiceUserName is only set if IsServiceUser is true + ServiceUserName string + // AutoGroups is a list of Group IDs to auto-assign to peers registered by this user + AutoGroups []string `gorm:"serializer:json"` + PATs map[string]*PersonalAccessToken `gorm:"-"` + PATsG []PersonalAccessToken `json:"-" gorm:"foreignKey:UserID;references:id"` + // Blocked indicates whether the user is blocked. Blocked users can't use the system. + Blocked bool + // LastLogin is the last time the user logged in to IdP + LastLogin time.Time + // CreatedAt records the time the user was created + CreatedAt time.Time + + // Issued of the user + Issued string `gorm:"default:api"` + + IntegrationReference integration_reference.IntegrationReference `gorm:"embedded;embeddedPrefix:integration_ref_"` +} + +// IsBlocked returns true if the user is blocked, false otherwise +func (u *User) IsBlocked() bool { + return u.Blocked +} + +func (u *User) LastDashboardLoginChanged(LastLogin time.Time) bool { + return LastLogin.After(u.LastLogin) && !u.LastLogin.IsZero() +} + +// HasAdminPower returns true if the user has admin or owner roles, false otherwise +func (u *User) HasAdminPower() bool { + return u.Role == UserRoleAdmin || u.Role == UserRoleOwner +} + +// IsAdminOrServiceUser checks if the user has admin power or is a service user. +func (u *User) IsAdminOrServiceUser() bool { + return u.HasAdminPower() || u.IsServiceUser +} + +// IsRegularUser checks if the user is a regular user. +func (u *User) IsRegularUser() bool { + return !u.HasAdminPower() && !u.IsServiceUser +} + +// ToUserInfo converts a User object to a UserInfo object. 
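The role helpers above and ToUserInfo (whose body follows) are what the API layer relies on for dashboard visibility: regular users get a limited or fully blocked view depending on Settings.RegularUsersViewBlocked. A rough sketch of that interplay, using the types package import path introduced by this patch; the user IDs are illustrative:

package main

import (
	"fmt"

	"github.com/netbirdio/netbird/management/server/types"
)

func main() {
	admin := types.NewAdminUser("user-a") // the constructors are defined further below
	regular := types.NewRegularUser("user-b")

	fmt.Println(admin.HasAdminPower(), admin.IsRegularUser())     // true false
	fmt.Println(regular.HasAdminPower(), regular.IsRegularUser()) // false true

	// With no IdP data, ToUserInfo still fills in role-derived permissions.
	settings := &types.Settings{RegularUsersViewBlocked: true}
	info, err := regular.ToUserInfo(nil, settings)
	if err != nil {
		panic(err)
	}
	fmt.Println(info.Permissions.DashboardView) // "blocked" for a regular user here
}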
+func (u *User) ToUserInfo(userData *idp.UserData, settings *Settings) (*UserInfo, error) { + autoGroups := u.AutoGroups + if autoGroups == nil { + autoGroups = []string{} + } + + dashboardViewPermissions := "full" + if !u.HasAdminPower() { + dashboardViewPermissions = "limited" + if settings.RegularUsersViewBlocked { + dashboardViewPermissions = "blocked" + } + } + + if userData == nil { + return &UserInfo{ + ID: u.Id, + Email: "", + Name: u.ServiceUserName, + Role: string(u.Role), + AutoGroups: u.AutoGroups, + Status: string(UserStatusActive), + IsServiceUser: u.IsServiceUser, + IsBlocked: u.Blocked, + LastLogin: u.LastLogin, + Issued: u.Issued, + Permissions: UserPermissions{ + DashboardView: dashboardViewPermissions, + }, + }, nil + } + if userData.ID != u.Id { + return nil, fmt.Errorf("wrong UserData provided for user %s", u.Id) + } + + userStatus := UserStatusActive + if userData.AppMetadata.WTPendingInvite != nil && *userData.AppMetadata.WTPendingInvite { + userStatus = UserStatusInvited + } + + return &UserInfo{ + ID: u.Id, + Email: userData.Email, + Name: userData.Name, + Role: string(u.Role), + AutoGroups: autoGroups, + Status: string(userStatus), + IsServiceUser: u.IsServiceUser, + IsBlocked: u.Blocked, + LastLogin: u.LastLogin, + Issued: u.Issued, + Permissions: UserPermissions{ + DashboardView: dashboardViewPermissions, + }, + }, nil +} + +// Copy the user +func (u *User) Copy() *User { + autoGroups := make([]string, len(u.AutoGroups)) + copy(autoGroups, u.AutoGroups) + pats := make(map[string]*PersonalAccessToken, len(u.PATs)) + for k, v := range u.PATs { + pats[k] = v.Copy() + } + return &User{ + Id: u.Id, + AccountID: u.AccountID, + Role: u.Role, + AutoGroups: autoGroups, + IsServiceUser: u.IsServiceUser, + NonDeletable: u.NonDeletable, + ServiceUserName: u.ServiceUserName, + PATs: pats, + Blocked: u.Blocked, + LastLogin: u.LastLogin, + CreatedAt: u.CreatedAt, + Issued: u.Issued, + IntegrationReference: u.IntegrationReference, + } +} + +// NewUser creates a new user +func NewUser(id string, role UserRole, isServiceUser bool, nonDeletable bool, serviceUserName string, autoGroups []string, issued string) *User { + return &User{ + Id: id, + Role: role, + IsServiceUser: isServiceUser, + NonDeletable: nonDeletable, + ServiceUserName: serviceUserName, + AutoGroups: autoGroups, + Issued: issued, + CreatedAt: time.Now().UTC(), + } +} + +// NewRegularUser creates a new user with role UserRoleUser +func NewRegularUser(id string) *User { + return NewUser(id, UserRoleUser, false, false, "", []string{}, UserIssuedAPI) +} + +// NewAdminUser creates a new user with role UserRoleAdmin +func NewAdminUser(id string) *User { + return NewUser(id, UserRoleAdmin, false, false, "", []string{}, UserIssuedAPI) +} + +// NewOwnerUser creates a new user with role UserRoleOwner +func NewOwnerUser(id string) *User { + return NewUser(id, UserRoleOwner, false, false, "", []string{}, UserIssuedAPI) +} diff --git a/management/server/updatechannel.go b/management/server/updatechannel.go index d338b84b1..de7dd57df 100644 --- a/management/server/updatechannel.go +++ b/management/server/updatechannel.go @@ -9,13 +9,14 @@ import ( "github.com/netbirdio/netbird/management/proto" "github.com/netbirdio/netbird/management/server/telemetry" + "github.com/netbirdio/netbird/management/server/types" ) const channelBufferSize = 100 type UpdateMessage struct { Update *proto.SyncResponse - NetworkMap *NetworkMap + NetworkMap *types.NetworkMap } type PeersUpdateManager struct { diff --git a/management/server/user.go 
b/management/server/user.go index edb5e6fd3..457721917 100644 --- a/management/server/user.go +++ b/management/server/user.go @@ -13,217 +13,17 @@ import ( "github.com/netbirdio/netbird/management/server/activity" nbContext "github.com/netbirdio/netbird/management/server/context" - nbgroup "github.com/netbirdio/netbird/management/server/group" "github.com/netbirdio/netbird/management/server/idp" - "github.com/netbirdio/netbird/management/server/integration_reference" "github.com/netbirdio/netbird/management/server/jwtclaims" nbpeer "github.com/netbirdio/netbird/management/server/peer" "github.com/netbirdio/netbird/management/server/status" + "github.com/netbirdio/netbird/management/server/store" + "github.com/netbirdio/netbird/management/server/types" + "github.com/netbirdio/netbird/management/server/util" ) -const ( - UserRoleOwner UserRole = "owner" - UserRoleAdmin UserRole = "admin" - UserRoleUser UserRole = "user" - UserRoleUnknown UserRole = "unknown" - UserRoleBillingAdmin UserRole = "billing_admin" - - UserStatusActive UserStatus = "active" - UserStatusDisabled UserStatus = "disabled" - UserStatusInvited UserStatus = "invited" - - UserIssuedAPI = "api" - UserIssuedIntegration = "integration" -) - -// StrRoleToUserRole returns UserRole for a given strRole or UserRoleUnknown if the specified role is unknown -func StrRoleToUserRole(strRole string) UserRole { - switch strings.ToLower(strRole) { - case "owner": - return UserRoleOwner - case "admin": - return UserRoleAdmin - case "user": - return UserRoleUser - case "billing_admin": - return UserRoleBillingAdmin - default: - return UserRoleUnknown - } -} - -// UserStatus is the status of a User -type UserStatus string - -// UserRole is the role of a User -type UserRole string - -// User represents a user of the system -type User struct { - Id string `gorm:"primaryKey"` - // AccountID is a reference to Account that this object belongs - AccountID string `json:"-" gorm:"index"` - Role UserRole - IsServiceUser bool - // NonDeletable indicates whether the service user can be deleted - NonDeletable bool - // ServiceUserName is only set if IsServiceUser is true - ServiceUserName string - // AutoGroups is a list of Group IDs to auto-assign to peers registered by this user - AutoGroups []string `gorm:"serializer:json"` - PATs map[string]*PersonalAccessToken `gorm:"-"` - PATsG []PersonalAccessToken `json:"-" gorm:"foreignKey:UserID;references:id"` - // Blocked indicates whether the user is blocked. Blocked users can't use the system. - Blocked bool - // LastLogin is the last time the user logged in to IdP - LastLogin time.Time - // CreatedAt records the time the user was created - CreatedAt time.Time - - // Issued of the user - Issued string `gorm:"default:api"` - - IntegrationReference integration_reference.IntegrationReference `gorm:"embedded;embeddedPrefix:integration_ref_"` -} - -// IsBlocked returns true if the user is blocked, false otherwise -func (u *User) IsBlocked() bool { - return u.Blocked -} - -func (u *User) LastDashboardLoginChanged(LastLogin time.Time) bool { - return LastLogin.After(u.LastLogin) && !u.LastLogin.IsZero() -} - -// HasAdminPower returns true if the user has admin or owner roles, false otherwise -func (u *User) HasAdminPower() bool { - return u.Role == UserRoleAdmin || u.Role == UserRoleOwner -} - -// IsAdminOrServiceUser checks if the user has admin power or is a service user. 
-func (u *User) IsAdminOrServiceUser() bool { - return u.HasAdminPower() || u.IsServiceUser -} - -// IsRegularUser checks if the user is a regular user. -func (u *User) IsRegularUser() bool { - return !u.HasAdminPower() && !u.IsServiceUser -} - -// ToUserInfo converts a User object to a UserInfo object. -func (u *User) ToUserInfo(userData *idp.UserData, settings *Settings) (*UserInfo, error) { - autoGroups := u.AutoGroups - if autoGroups == nil { - autoGroups = []string{} - } - - dashboardViewPermissions := "full" - if !u.HasAdminPower() { - dashboardViewPermissions = "limited" - if settings.RegularUsersViewBlocked { - dashboardViewPermissions = "blocked" - } - } - - if userData == nil { - return &UserInfo{ - ID: u.Id, - Email: "", - Name: u.ServiceUserName, - Role: string(u.Role), - AutoGroups: u.AutoGroups, - Status: string(UserStatusActive), - IsServiceUser: u.IsServiceUser, - IsBlocked: u.Blocked, - LastLogin: u.LastLogin, - Issued: u.Issued, - Permissions: UserPermissions{ - DashboardView: dashboardViewPermissions, - }, - }, nil - } - if userData.ID != u.Id { - return nil, fmt.Errorf("wrong UserData provided for user %s", u.Id) - } - - userStatus := UserStatusActive - if userData.AppMetadata.WTPendingInvite != nil && *userData.AppMetadata.WTPendingInvite { - userStatus = UserStatusInvited - } - - return &UserInfo{ - ID: u.Id, - Email: userData.Email, - Name: userData.Name, - Role: string(u.Role), - AutoGroups: autoGroups, - Status: string(userStatus), - IsServiceUser: u.IsServiceUser, - IsBlocked: u.Blocked, - LastLogin: u.LastLogin, - Issued: u.Issued, - Permissions: UserPermissions{ - DashboardView: dashboardViewPermissions, - }, - }, nil -} - -// Copy the user -func (u *User) Copy() *User { - autoGroups := make([]string, len(u.AutoGroups)) - copy(autoGroups, u.AutoGroups) - pats := make(map[string]*PersonalAccessToken, len(u.PATs)) - for k, v := range u.PATs { - pats[k] = v.Copy() - } - return &User{ - Id: u.Id, - AccountID: u.AccountID, - Role: u.Role, - AutoGroups: autoGroups, - IsServiceUser: u.IsServiceUser, - NonDeletable: u.NonDeletable, - ServiceUserName: u.ServiceUserName, - PATs: pats, - Blocked: u.Blocked, - LastLogin: u.LastLogin, - CreatedAt: u.CreatedAt, - Issued: u.Issued, - IntegrationReference: u.IntegrationReference, - } -} - -// NewUser creates a new user -func NewUser(id string, role UserRole, isServiceUser bool, nonDeletable bool, serviceUserName string, autoGroups []string, issued string) *User { - return &User{ - Id: id, - Role: role, - IsServiceUser: isServiceUser, - NonDeletable: nonDeletable, - ServiceUserName: serviceUserName, - AutoGroups: autoGroups, - Issued: issued, - CreatedAt: time.Now().UTC(), - } -} - -// NewRegularUser creates a new user with role UserRoleUser -func NewRegularUser(id string) *User { - return NewUser(id, UserRoleUser, false, false, "", []string{}, UserIssuedAPI) -} - -// NewAdminUser creates a new user with role UserRoleAdmin -func NewAdminUser(id string) *User { - return NewUser(id, UserRoleAdmin, false, false, "", []string{}, UserIssuedAPI) -} - -// NewOwnerUser creates a new user with role UserRoleOwner -func NewOwnerUser(id string) *User { - return NewUser(id, UserRoleOwner, false, false, "", []string{}, UserIssuedAPI) -} - // createServiceUser creates a new service user under the given account. 
-func (am *DefaultAccountManager) createServiceUser(ctx context.Context, accountID string, initiatorUserID string, role UserRole, serviceUserName string, nonDeletable bool, autoGroups []string) (*UserInfo, error) { +func (am *DefaultAccountManager) createServiceUser(ctx context.Context, accountID string, initiatorUserID string, role types.UserRole, serviceUserName string, nonDeletable bool, autoGroups []string) (*types.UserInfo, error) { unlock := am.Store.AcquireWriteLockByUID(ctx, accountID) defer unlock() @@ -240,12 +40,12 @@ func (am *DefaultAccountManager) createServiceUser(ctx context.Context, accountI return nil, status.Errorf(status.PermissionDenied, "only users with admin power can create service users") } - if role == UserRoleOwner { + if role == types.UserRoleOwner { return nil, status.Errorf(status.InvalidArgument, "can't create a service user with owner role") } newUserID := uuid.New().String() - newUser := NewUser(newUserID, role, true, nonDeletable, serviceUserName, autoGroups, UserIssuedAPI) + newUser := types.NewUser(newUserID, role, true, nonDeletable, serviceUserName, autoGroups, types.UserIssuedAPI) log.WithContext(ctx).Debugf("New User: %v", newUser) account.Users[newUserID] = newUser @@ -257,29 +57,29 @@ func (am *DefaultAccountManager) createServiceUser(ctx context.Context, accountI meta := map[string]any{"name": newUser.ServiceUserName} am.StoreEvent(ctx, initiatorUserID, newUser.Id, accountID, activity.ServiceUserCreated, meta) - return &UserInfo{ + return &types.UserInfo{ ID: newUser.Id, Email: "", Name: newUser.ServiceUserName, Role: string(newUser.Role), AutoGroups: newUser.AutoGroups, - Status: string(UserStatusActive), + Status: string(types.UserStatusActive), IsServiceUser: true, LastLogin: time.Time{}, - Issued: UserIssuedAPI, + Issued: types.UserIssuedAPI, }, nil } // CreateUser creates a new user under the given account. Effectively this is a user invite. 
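As a caller-side sketch of the updated CreateUser signature that follows: the account manager is assumed to be already initialized (as in the user_test.go cases later in this patch), the management/server package is assumed to be importable under the alias server, and the user name, role, and group ID are placeholders:

package example

import (
	"context"

	server "github.com/netbirdio/netbird/management/server"
	"github.com/netbirdio/netbird/management/server/types"
)

// createAutomationUser creates a service user through the account manager,
// mirroring how the tests below drive CreateUser with a types.UserInfo.
func createAutomationUser(ctx context.Context, am *server.DefaultAccountManager, accountID, adminID string) (*types.UserInfo, error) {
	return am.CreateUser(ctx, accountID, adminID, &types.UserInfo{
		Name:          "automation",
		Role:          string(types.UserRoleAdmin),
		IsServiceUser: true,
		AutoGroups:    []string{"group1"},
	})
}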
-func (am *DefaultAccountManager) CreateUser(ctx context.Context, accountID, userID string, user *UserInfo) (*UserInfo, error) { +func (am *DefaultAccountManager) CreateUser(ctx context.Context, accountID, userID string, user *types.UserInfo) (*types.UserInfo, error) { if user.IsServiceUser { - return am.createServiceUser(ctx, accountID, userID, StrRoleToUserRole(user.Role), user.Name, user.NonDeletable, user.AutoGroups) + return am.createServiceUser(ctx, accountID, userID, types.StrRoleToUserRole(user.Role), user.Name, user.NonDeletable, user.AutoGroups) } return am.inviteNewUser(ctx, accountID, userID, user) } // inviteNewUser Invites a USer to a given account and creates reference in datastore -func (am *DefaultAccountManager) inviteNewUser(ctx context.Context, accountID, userID string, invite *UserInfo) (*UserInfo, error) { +func (am *DefaultAccountManager) inviteNewUser(ctx context.Context, accountID, userID string, invite *types.UserInfo) (*types.UserInfo, error) { unlock := am.Store.AcquireWriteLockByUID(ctx, accountID) defer unlock() @@ -291,14 +91,14 @@ func (am *DefaultAccountManager) inviteNewUser(ctx context.Context, accountID, u return nil, fmt.Errorf("provided user update is nil") } - invitedRole := StrRoleToUserRole(invite.Role) + invitedRole := types.StrRoleToUserRole(invite.Role) switch { case invite.Name == "": return nil, status.Errorf(status.InvalidArgument, "name can't be empty") case invite.Email == "": return nil, status.Errorf(status.InvalidArgument, "email can't be empty") - case invitedRole == UserRoleOwner: + case invitedRole == types.UserRoleOwner: return nil, status.Errorf(status.InvalidArgument, "can't invite a user with owner role") default: } @@ -348,7 +148,7 @@ func (am *DefaultAccountManager) inviteNewUser(ctx context.Context, accountID, u return nil, err } - newUser := &User{ + newUser := &types.User{ Id: idpUser.ID, Role: invitedRole, AutoGroups: invite.AutoGroups, @@ -373,19 +173,19 @@ func (am *DefaultAccountManager) inviteNewUser(ctx context.Context, accountID, u return newUser.ToUserInfo(idpUser, account.Settings) } -func (am *DefaultAccountManager) GetUserByID(ctx context.Context, id string) (*User, error) { - return am.Store.GetUserByUserID(ctx, LockingStrengthShare, id) +func (am *DefaultAccountManager) GetUserByID(ctx context.Context, id string) (*types.User, error) { + return am.Store.GetUserByUserID(ctx, store.LockingStrengthShare, id) } // GetUser looks up a user by provided authorization claims. // It will also create an account if didn't exist for this user before. -func (am *DefaultAccountManager) GetUser(ctx context.Context, claims jwtclaims.AuthorizationClaims) (*User, error) { +func (am *DefaultAccountManager) GetUser(ctx context.Context, claims jwtclaims.AuthorizationClaims) (*types.User, error) { accountID, userID, err := am.GetAccountIDFromToken(ctx, claims) if err != nil { return nil, fmt.Errorf("failed to get account with token claims %v", err) } - user, err := am.Store.GetUserByUserID(ctx, LockingStrengthShare, userID) + user, err := am.Store.GetUserByUserID(ctx, store.LockingStrengthShare, userID) if err != nil { return nil, err } @@ -409,7 +209,7 @@ func (am *DefaultAccountManager) GetUser(ctx context.Context, claims jwtclaims.A // ListUsers returns lists of all users under the account. // It doesn't populate user information such as email or name. 
-func (am *DefaultAccountManager) ListUsers(ctx context.Context, accountID string) ([]*User, error) { +func (am *DefaultAccountManager) ListUsers(ctx context.Context, accountID string) ([]*types.User, error) { unlock := am.Store.AcquireWriteLockByUID(ctx, accountID) defer unlock() @@ -418,7 +218,7 @@ func (am *DefaultAccountManager) ListUsers(ctx context.Context, accountID string return nil, err } - users := make([]*User, 0, len(account.Users)) + users := make([]*types.User, 0, len(account.Users)) for _, item := range account.Users { users = append(users, item) } @@ -426,7 +226,7 @@ func (am *DefaultAccountManager) ListUsers(ctx context.Context, accountID string return users, nil } -func (am *DefaultAccountManager) deleteServiceUser(ctx context.Context, account *Account, initiatorUserID string, targetUser *User) { +func (am *DefaultAccountManager) deleteServiceUser(ctx context.Context, account *types.Account, initiatorUserID string, targetUser *types.User) { meta := map[string]any{"name": targetUser.ServiceUserName, "created_at": targetUser.CreatedAt} am.StoreEvent(ctx, initiatorUserID, targetUser.Id, account.Id, activity.ServiceUserDeleted, meta) delete(account.Users, targetUser.Id) @@ -458,12 +258,12 @@ func (am *DefaultAccountManager) DeleteUser(ctx context.Context, accountID, init return status.Errorf(status.NotFound, "target user not found") } - if targetUser.Role == UserRoleOwner { + if targetUser.Role == types.UserRoleOwner { return status.Errorf(status.PermissionDenied, "unable to delete a user with owner role") } // disable deleting integration user if the initiator is not admin service user - if targetUser.Issued == UserIssuedIntegration && !executingUser.IsServiceUser { + if targetUser.Issued == types.UserIssuedIntegration && !executingUser.IsServiceUser { return status.Errorf(status.PermissionDenied, "only integration service user can delete this user") } @@ -480,7 +280,7 @@ func (am *DefaultAccountManager) DeleteUser(ctx context.Context, accountID, init return am.deleteRegularUser(ctx, account, initiatorUserID, targetUserID) } -func (am *DefaultAccountManager) deleteRegularUser(ctx context.Context, account *Account, initiatorUserID, targetUserID string) error { +func (am *DefaultAccountManager) deleteRegularUser(ctx context.Context, account *types.Account, initiatorUserID, targetUserID string) error { meta, updateAccountPeers, err := am.prepareUserDeletion(ctx, account, initiatorUserID, targetUserID) if err != nil { return err @@ -494,13 +294,13 @@ func (am *DefaultAccountManager) deleteRegularUser(ctx context.Context, account am.StoreEvent(ctx, initiatorUserID, targetUserID, account.Id, activity.UserDeleted, meta) if updateAccountPeers { - am.updateAccountPeers(ctx, account.Id) + am.UpdateAccountPeers(ctx, account.Id) } return nil } -func (am *DefaultAccountManager) deleteUserPeers(ctx context.Context, initiatorUserID string, targetUserID string, account *Account) (bool, error) { +func (am *DefaultAccountManager) deleteUserPeers(ctx context.Context, initiatorUserID string, targetUserID string, account *types.Account) (bool, error) { peers, err := account.FindUserPeers(targetUserID) if err != nil { return false, status.Errorf(status.Internal, "failed to find user peers") @@ -560,7 +360,7 @@ func (am *DefaultAccountManager) InviteUser(ctx context.Context, accountID strin } // CreatePAT creates a new PAT for the given user -func (am *DefaultAccountManager) CreatePAT(ctx context.Context, accountID string, initiatorUserID string, targetUserID string, tokenName string, expiresIn 
int) (*PersonalAccessTokenGenerated, error) { +func (am *DefaultAccountManager) CreatePAT(ctx context.Context, accountID string, initiatorUserID string, targetUserID string, tokenName string, expiresIn int) (*types.PersonalAccessTokenGenerated, error) { unlock := am.Store.AcquireWriteLockByUID(ctx, accountID) defer unlock() @@ -591,7 +391,7 @@ func (am *DefaultAccountManager) CreatePAT(ctx context.Context, accountID string return nil, status.Errorf(status.PermissionDenied, "no permission to create PAT for this user") } - pat, err := CreateNewPAT(tokenName, expiresIn, executingUser.Id) + pat, err := types.CreateNewPAT(tokenName, expiresIn, executingUser.Id) if err != nil { return nil, status.Errorf(status.Internal, "failed to create PAT: %v", err) } @@ -660,13 +460,13 @@ func (am *DefaultAccountManager) DeletePAT(ctx context.Context, accountID string } // GetPAT returns a specific PAT from a user -func (am *DefaultAccountManager) GetPAT(ctx context.Context, accountID string, initiatorUserID string, targetUserID string, tokenID string) (*PersonalAccessToken, error) { - initiatorUser, err := am.Store.GetUserByUserID(ctx, LockingStrengthShare, initiatorUserID) +func (am *DefaultAccountManager) GetPAT(ctx context.Context, accountID string, initiatorUserID string, targetUserID string, tokenID string) (*types.PersonalAccessToken, error) { + initiatorUser, err := am.Store.GetUserByUserID(ctx, store.LockingStrengthShare, initiatorUserID) if err != nil { return nil, err } - targetUser, err := am.Store.GetUserByUserID(ctx, LockingStrengthShare, targetUserID) + targetUser, err := am.Store.GetUserByUserID(ctx, store.LockingStrengthShare, targetUserID) if err != nil { return nil, err } @@ -685,13 +485,13 @@ func (am *DefaultAccountManager) GetPAT(ctx context.Context, accountID string, i } // GetAllPATs returns all PATs for a user -func (am *DefaultAccountManager) GetAllPATs(ctx context.Context, accountID string, initiatorUserID string, targetUserID string) ([]*PersonalAccessToken, error) { - initiatorUser, err := am.Store.GetUserByUserID(ctx, LockingStrengthShare, initiatorUserID) +func (am *DefaultAccountManager) GetAllPATs(ctx context.Context, accountID string, initiatorUserID string, targetUserID string) ([]*types.PersonalAccessToken, error) { + initiatorUser, err := am.Store.GetUserByUserID(ctx, store.LockingStrengthShare, initiatorUserID) if err != nil { return nil, err } - targetUser, err := am.Store.GetUserByUserID(ctx, LockingStrengthShare, targetUserID) + targetUser, err := am.Store.GetUserByUserID(ctx, store.LockingStrengthShare, targetUserID) if err != nil { return nil, err } @@ -700,7 +500,7 @@ func (am *DefaultAccountManager) GetAllPATs(ctx context.Context, accountID strin return nil, status.Errorf(status.PermissionDenied, "no permission to get PAT for this user") } - pats := make([]*PersonalAccessToken, 0, len(targetUser.PATsG)) + pats := make([]*types.PersonalAccessToken, 0, len(targetUser.PATsG)) for _, pat := range targetUser.PATsG { pats = append(pats, pat.Copy()) } @@ -709,13 +509,13 @@ func (am *DefaultAccountManager) GetAllPATs(ctx context.Context, accountID strin } // SaveUser saves updates to the given user. If the user doesn't exist, it will throw status.NotFound error. 
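The PAT hunks above only swap in the relocated types; the signatures otherwise keep their shape. A caller sketch against the new return types, assuming the same server import alias as before, that self-access to one's own tokens is permitted as the tests suggest, and with a placeholder token name and expiration value (the unit of expiresIn is not spelled out here):

package example

import (
	"context"

	server "github.com/netbirdio/netbird/management/server"
	"github.com/netbirdio/netbird/management/server/types"
)

// issueAndListPATs creates a token for the caller's own user and then lists
// all of that user's tokens through the relocated types.
func issueAndListPATs(ctx context.Context, am *server.DefaultAccountManager, accountID, userID string) (*types.PersonalAccessTokenGenerated, []*types.PersonalAccessToken, error) {
	generated, err := am.CreatePAT(ctx, accountID, userID, userID, "automation-token", 30)
	if err != nil {
		return nil, nil, err
	}

	pats, err := am.GetAllPATs(ctx, accountID, userID, userID)
	if err != nil {
		return nil, nil, err
	}
	return generated, pats, nil
}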
-func (am *DefaultAccountManager) SaveUser(ctx context.Context, accountID, initiatorUserID string, update *User) (*UserInfo, error) { +func (am *DefaultAccountManager) SaveUser(ctx context.Context, accountID, initiatorUserID string, update *types.User) (*types.UserInfo, error) { return am.SaveOrAddUser(ctx, accountID, initiatorUserID, update, false) // false means do not create user and throw status.NotFound } // SaveOrAddUser updates the given user. If addIfNotExists is set to true it will add user when no exist // Only User.AutoGroups, User.Role, and User.Blocked fields are allowed to be updated for now. -func (am *DefaultAccountManager) SaveOrAddUser(ctx context.Context, accountID, initiatorUserID string, update *User, addIfNotExists bool) (*UserInfo, error) { +func (am *DefaultAccountManager) SaveOrAddUser(ctx context.Context, accountID, initiatorUserID string, update *types.User, addIfNotExists bool) (*types.UserInfo, error) { if update == nil { return nil, status.Errorf(status.InvalidArgument, "provided user update is nil") } @@ -723,7 +523,7 @@ func (am *DefaultAccountManager) SaveOrAddUser(ctx context.Context, accountID, i unlock := am.Store.AcquireWriteLockByUID(ctx, accountID) defer unlock() - updatedUsers, err := am.SaveOrAddUsers(ctx, accountID, initiatorUserID, []*User{update}, addIfNotExists) + updatedUsers, err := am.SaveOrAddUsers(ctx, accountID, initiatorUserID, []*types.User{update}, addIfNotExists) if err != nil { return nil, err } @@ -738,7 +538,7 @@ func (am *DefaultAccountManager) SaveOrAddUser(ctx context.Context, accountID, i // SaveOrAddUsers updates existing users or adds new users to the account. // Note: This function does not acquire the global lock. // It is the caller's responsibility to ensure proper locking is in place before invoking this method. -func (am *DefaultAccountManager) SaveOrAddUsers(ctx context.Context, accountID, initiatorUserID string, updates []*User, addIfNotExists bool) ([]*UserInfo, error) { +func (am *DefaultAccountManager) SaveOrAddUsers(ctx context.Context, accountID, initiatorUserID string, updates []*types.User, addIfNotExists bool) ([]*types.UserInfo, error) { if len(updates) == 0 { return nil, nil //nolint:nilnil } @@ -757,7 +557,7 @@ func (am *DefaultAccountManager) SaveOrAddUsers(ctx context.Context, accountID, return nil, status.Errorf(status.PermissionDenied, "only users with admin power are authorized to perform user update operations") } - updatedUsers := make([]*UserInfo, 0, len(updates)) + updatedUsers := make([]*types.UserInfo, 0, len(updates)) var ( expiredPeers []*nbpeer.Peer userIDs []string @@ -808,7 +608,7 @@ func (am *DefaultAccountManager) SaveOrAddUsers(ctx context.Context, accountID, peerGroupsAdded := make(map[string][]string) peerGroupsRemoved := make(map[string][]string) if update.AutoGroups != nil && account.Settings.GroupsPropagationEnabled { - removedGroups := difference(oldUser.AutoGroups, update.AutoGroups) + removedGroups := util.Difference(oldUser.AutoGroups, update.AutoGroups) // need force update all auto groups in any case they will not be duplicated peerGroupsAdded = account.UserGroupsAddToPeers(oldUser.Id, update.AutoGroups...) peerGroupsRemoved = account.UserGroupsRemoveFromPeers(oldUser.Id, removedGroups...) 
@@ -840,7 +640,7 @@ func (am *DefaultAccountManager) SaveOrAddUsers(ctx context.Context, accountID, } if account.Settings.GroupsPropagationEnabled && areUsersLinkedToPeers(account, userIDs) { - am.updateAccountPeers(ctx, account.Id) + am.UpdateAccountPeers(ctx, account.Id) } for _, storeEvent := range eventsToStore { @@ -851,7 +651,7 @@ func (am *DefaultAccountManager) SaveOrAddUsers(ctx context.Context, accountID, } // prepareUserUpdateEvents prepares a list user update events based on the changes between the old and new user data. -func (am *DefaultAccountManager) prepareUserUpdateEvents(ctx context.Context, initiatorUserID string, oldUser, newUser *User, account *Account, transferredOwnerRole bool) []func() { +func (am *DefaultAccountManager) prepareUserUpdateEvents(ctx context.Context, initiatorUserID string, oldUser, newUser *types.User, account *types.Account, transferredOwnerRole bool) []func() { var eventsToStore []func() if oldUser.IsBlocked() != newUser.IsBlocked() { @@ -880,11 +680,11 @@ func (am *DefaultAccountManager) prepareUserUpdateEvents(ctx context.Context, in return eventsToStore } -func (am *DefaultAccountManager) prepareUserGroupsEvents(ctx context.Context, initiatorUserID string, oldUser, newUser *User, account *Account, peerGroupsAdded, peerGroupsRemoved map[string][]string) []func() { +func (am *DefaultAccountManager) prepareUserGroupsEvents(ctx context.Context, initiatorUserID string, oldUser, newUser *types.User, account *types.Account, peerGroupsAdded, peerGroupsRemoved map[string][]string) []func() { var eventsToStore []func() if newUser.AutoGroups != nil { - removedGroups := difference(oldUser.AutoGroups, newUser.AutoGroups) - addedGroups := difference(newUser.AutoGroups, oldUser.AutoGroups) + removedGroups := util.Difference(oldUser.AutoGroups, newUser.AutoGroups) + addedGroups := util.Difference(newUser.AutoGroups, oldUser.AutoGroups) removedEvents := am.handleGroupRemovedFromUser(ctx, initiatorUserID, oldUser, newUser, account, removedGroups, peerGroupsRemoved) eventsToStore = append(eventsToStore, removedEvents...) 
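Both hunks above replace the package-local difference helper with util.Difference when working out which auto-groups were removed from or added to a user. A standalone sketch of that computation, assuming util.Difference keeps the old semantics (elements of the first slice that are absent from the second) and a []string-in, []string-out signature:

package main

import (
	"fmt"

	"github.com/netbirdio/netbird/management/server/util"
)

func main() {
	oldGroups := []string{"group-a", "group-b"}
	newGroups := []string{"group-b", "group-c"}

	// removed: present before, absent now; added: the reverse.
	removed := util.Difference(oldGroups, newGroups)
	added := util.Difference(newGroups, oldGroups)

	fmt.Println(removed) // expected: [group-a]
	fmt.Println(added)   // expected: [group-c]
}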
@@ -895,7 +695,7 @@ func (am *DefaultAccountManager) prepareUserGroupsEvents(ctx context.Context, in return eventsToStore } -func (am *DefaultAccountManager) handleGroupAddedToUser(ctx context.Context, initiatorUserID string, oldUser, newUser *User, account *Account, addedGroups []string, peerGroupsAdded map[string][]string) []func() { +func (am *DefaultAccountManager) handleGroupAddedToUser(ctx context.Context, initiatorUserID string, oldUser, newUser *types.User, account *types.Account, addedGroups []string, peerGroupsAdded map[string][]string) []func() { var eventsToStore []func() for _, g := range addedGroups { group := account.GetGroup(g) @@ -922,7 +722,7 @@ func (am *DefaultAccountManager) handleGroupAddedToUser(ctx context.Context, ini return eventsToStore } -func (am *DefaultAccountManager) handleGroupRemovedFromUser(ctx context.Context, initiatorUserID string, oldUser, newUser *User, account *Account, removedGroups []string, peerGroupsRemoved map[string][]string) []func() { +func (am *DefaultAccountManager) handleGroupRemovedFromUser(ctx context.Context, initiatorUserID string, oldUser, newUser *types.User, account *types.Account, removedGroups []string, peerGroupsRemoved map[string][]string) []func() { var eventsToStore []func() for _, g := range removedGroups { group := account.GetGroup(g) @@ -952,10 +752,10 @@ func (am *DefaultAccountManager) handleGroupRemovedFromUser(ctx context.Context, return eventsToStore } -func handleOwnerRoleTransfer(account *Account, initiatorUser, update *User) bool { - if initiatorUser.Role == UserRoleOwner && initiatorUser.Id != update.Id && update.Role == UserRoleOwner { +func handleOwnerRoleTransfer(account *types.Account, initiatorUser, update *types.User) bool { + if initiatorUser.Role == types.UserRoleOwner && initiatorUser.Id != update.Id && update.Role == types.UserRoleOwner { newInitiatorUser := initiatorUser.Copy() - newInitiatorUser.Role = UserRoleAdmin + newInitiatorUser.Role = types.UserRoleAdmin account.Users[initiatorUser.Id] = newInitiatorUser return true } @@ -965,7 +765,7 @@ func handleOwnerRoleTransfer(account *Account, initiatorUser, update *User) bool // getUserInfo retrieves the UserInfo for a given User and Account. // If the AccountManager has a non-nil idpManager and the User is not a service user, // it will attempt to look up the UserData from the cache. -func getUserInfo(ctx context.Context, am *DefaultAccountManager, user *User, account *Account) (*UserInfo, error) { +func getUserInfo(ctx context.Context, am *DefaultAccountManager, user *types.User, account *types.Account) (*types.UserInfo, error) { if !isNil(am.idpManager) && !user.IsServiceUser { userData, err := am.lookupUserInCache(ctx, user.Id, account) if err != nil { @@ -977,23 +777,23 @@ func getUserInfo(ctx context.Context, am *DefaultAccountManager, user *User, acc } // validateUserUpdate validates the update operation for a user. 
-func validateUserUpdate(account *Account, initiatorUser, oldUser, update *User) error { +func validateUserUpdate(account *types.Account, initiatorUser, oldUser, update *types.User) error { if initiatorUser.HasAdminPower() && initiatorUser.Id == update.Id && oldUser.Blocked != update.Blocked { return status.Errorf(status.PermissionDenied, "admins can't block or unblock themselves") } if initiatorUser.HasAdminPower() && initiatorUser.Id == update.Id && update.Role != initiatorUser.Role { return status.Errorf(status.PermissionDenied, "admins can't change their role") } - if initiatorUser.Role == UserRoleAdmin && oldUser.Role == UserRoleOwner && update.Role != oldUser.Role { + if initiatorUser.Role == types.UserRoleAdmin && oldUser.Role == types.UserRoleOwner && update.Role != oldUser.Role { return status.Errorf(status.PermissionDenied, "only owners can remove owner role from their user") } - if initiatorUser.Role == UserRoleAdmin && oldUser.Role == UserRoleOwner && update.IsBlocked() && !oldUser.IsBlocked() { + if initiatorUser.Role == types.UserRoleAdmin && oldUser.Role == types.UserRoleOwner && update.IsBlocked() && !oldUser.IsBlocked() { return status.Errorf(status.PermissionDenied, "unable to block owner user") } - if initiatorUser.Role == UserRoleAdmin && update.Role == UserRoleOwner && update.Role != oldUser.Role { + if initiatorUser.Role == types.UserRoleAdmin && update.Role == types.UserRoleOwner && update.Role != oldUser.Role { return status.Errorf(status.PermissionDenied, "only owners can add owner role to other users") } - if oldUser.IsServiceUser && update.Role == UserRoleOwner { + if oldUser.IsServiceUser && update.Role == types.UserRoleOwner { return status.Errorf(status.PermissionDenied, "can't update a service user with owner role") } @@ -1012,7 +812,7 @@ func validateUserUpdate(account *Account, initiatorUser, oldUser, update *User) } // GetOrCreateAccountByUser returns an existing account for a given user id or creates a new one if doesn't exist -func (am *DefaultAccountManager) GetOrCreateAccountByUser(ctx context.Context, userID, domain string) (*Account, error) { +func (am *DefaultAccountManager) GetOrCreateAccountByUser(ctx context.Context, userID, domain string) (*types.Account, error) { start := time.Now() unlock := am.Store.AcquireGlobalLock(ctx) defer unlock() @@ -1039,7 +839,7 @@ func (am *DefaultAccountManager) GetOrCreateAccountByUser(ctx context.Context, u userObj := account.Users[userID] - if lowerDomain != "" && account.Domain != lowerDomain && userObj.Role == UserRoleOwner { + if lowerDomain != "" && account.Domain != lowerDomain && userObj.Role == types.UserRoleOwner { account.Domain = lowerDomain err = am.Store.SaveAccount(ctx, account) if err != nil { @@ -1052,7 +852,7 @@ func (am *DefaultAccountManager) GetOrCreateAccountByUser(ctx context.Context, u // GetUsersFromAccount performs a batched request for users from IDP by account ID apply filter on what data to return // based on provided user role. 
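The validation above is what ultimately guards role changes made through SaveUser and SaveOrAddUsers: admins cannot grant or strip the owner role, block an owner, or change their own role. A hedged caller sketch of a role update that stays within those rules, reusing the assumed server alias; the IDs are illustrative:

package example

import (
	"context"

	server "github.com/netbirdio/netbird/management/server"
	"github.com/netbirdio/netbird/management/server/types"
)

// promoteToAdmin updates a copy of the target user to the admin role. Per the
// checks above, an admin initiator may do this as long as the target is not
// currently an owner; setting UserRoleOwner instead would require an owner.
func promoteToAdmin(ctx context.Context, am *server.DefaultAccountManager, accountID, initiatorID string, target *types.User) (*types.UserInfo, error) {
	update := target.Copy()
	update.Role = types.UserRoleAdmin
	return am.SaveUser(ctx, accountID, initiatorID, update)
}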
-func (am *DefaultAccountManager) GetUsersFromAccount(ctx context.Context, accountID, userID string) ([]*UserInfo, error) { +func (am *DefaultAccountManager) GetUsersFromAccount(ctx context.Context, accountID, userID string) ([]*types.UserInfo, error) { account, err := am.Store.GetAccount(ctx, accountID) if err != nil { return nil, err @@ -1068,7 +868,7 @@ func (am *DefaultAccountManager) GetUsersFromAccount(ctx context.Context, accoun users := make(map[string]userLoggedInOnce, len(account.Users)) usersFromIntegration := make([]*idp.UserData, 0) for _, user := range account.Users { - if user.Issued == UserIssuedIntegration { + if user.Issued == types.UserIssuedIntegration { key := user.IntegrationReference.CacheKey(accountID, user.Id) info, err := am.externalCacheManager.Get(am.ctx, key) if err != nil { @@ -1092,7 +892,7 @@ func (am *DefaultAccountManager) GetUsersFromAccount(ctx context.Context, accoun queriedUsers = append(queriedUsers, usersFromIntegration...) } - userInfos := make([]*UserInfo, 0) + userInfos := make([]*types.UserInfo, 0) // in case of self-hosted, or IDP doesn't return anything, we will return the locally stored userInfo if len(queriedUsers) == 0 { @@ -1116,7 +916,7 @@ func (am *DefaultAccountManager) GetUsersFromAccount(ctx context.Context, accoun continue } - var info *UserInfo + var info *types.UserInfo if queriedUser, contains := findUserInIDPUserdata(localUser.Id, queriedUsers); contains { info, err = localUser.ToUserInfo(queriedUser, account.Settings) if err != nil { @@ -1136,16 +936,16 @@ func (am *DefaultAccountManager) GetUsersFromAccount(ctx context.Context, accoun } } - info = &UserInfo{ + info = &types.UserInfo{ ID: localUser.Id, Email: "", Name: name, Role: string(localUser.Role), AutoGroups: localUser.AutoGroups, - Status: string(UserStatusActive), + Status: string(types.UserStatusActive), IsServiceUser: localUser.IsServiceUser, NonDeletable: localUser.NonDeletable, - Permissions: UserPermissions{DashboardView: dashboardViewPermissions}, + Permissions: types.UserPermissions{DashboardView: dashboardViewPermissions}, } } userInfos = append(userInfos, info) @@ -1155,7 +955,7 @@ func (am *DefaultAccountManager) GetUsersFromAccount(ctx context.Context, accoun } // expireAndUpdatePeers expires all peers of the given user and updates them in the account -func (am *DefaultAccountManager) expireAndUpdatePeers(ctx context.Context, account *Account, peers []*nbpeer.Peer) error { +func (am *DefaultAccountManager) expireAndUpdatePeers(ctx context.Context, account *types.Account, peers []*nbpeer.Peer) error { var peerIDs []string for _, peer := range peers { // nolint:staticcheck @@ -1183,7 +983,7 @@ func (am *DefaultAccountManager) expireAndUpdatePeers(ctx context.Context, accou if len(peerIDs) != 0 { // this will trigger peer disconnect from the management service am.peersUpdateManager.CloseChannels(ctx, peerIDs) - am.updateAccountPeers(ctx, account.Id) + am.UpdateAccountPeers(ctx, account.Id) } return nil } @@ -1260,13 +1060,13 @@ func (am *DefaultAccountManager) DeleteRegularUsers(ctx context.Context, account continue } - if targetUser.Role == UserRoleOwner { + if targetUser.Role == types.UserRoleOwner { allErrors = errors.Join(allErrors, fmt.Errorf("unable to delete a user: %s with owner role", targetUserID)) continue } // disable deleting integration user if the initiator is not admin service user - if targetUser.Issued == UserIssuedIntegration && !executingUser.IsServiceUser { + if targetUser.Issued == types.UserIssuedIntegration && !executingUser.IsServiceUser 
{ allErrors = errors.Join(allErrors, errors.New("only integration service user can delete this user")) continue } @@ -1291,7 +1091,7 @@ func (am *DefaultAccountManager) DeleteRegularUsers(ctx context.Context, account } if updateAccountPeers { - am.updateAccountPeers(ctx, accountID) + am.UpdateAccountPeers(ctx, accountID) } for targetUserID, meta := range deletedUsersMeta { @@ -1301,7 +1101,7 @@ func (am *DefaultAccountManager) DeleteRegularUsers(ctx context.Context, account return allErrors } -func (am *DefaultAccountManager) prepareUserDeletion(ctx context.Context, account *Account, initiatorUserID, targetUserID string) (map[string]any, bool, error) { +func (am *DefaultAccountManager) prepareUserDeletion(ctx context.Context, account *types.Account, initiatorUserID, targetUserID string) (map[string]any, bool, error) { tuEmail, tuName, err := am.getEmailAndNameOfTargetUser(ctx, account.Id, initiatorUserID, targetUserID) if err != nil { log.WithContext(ctx).Errorf("failed to resolve email address: %s", err) @@ -1342,8 +1142,8 @@ func (am *DefaultAccountManager) prepareUserDeletion(ctx context.Context, accoun } // updateUserPeersInGroups updates the user's peers in the specified groups by adding or removing them. -func (am *DefaultAccountManager) updateUserPeersInGroups(accountGroups map[string]*nbgroup.Group, peers []*nbpeer.Peer, groupsToAdd, - groupsToRemove []string) (groupsToUpdate []*nbgroup.Group, err error) { +func (am *DefaultAccountManager) updateUserPeersInGroups(accountGroups map[string]*types.Group, peers []*nbpeer.Peer, groupsToAdd, + groupsToRemove []string) (groupsToUpdate []*types.Group, err error) { if len(groupsToAdd) == 0 && len(groupsToRemove) == 0 { return @@ -1376,7 +1176,7 @@ func (am *DefaultAccountManager) updateUserPeersInGroups(accountGroups map[strin } // addUserPeersToGroup adds the user's peers to the group. -func addUserPeersToGroup(userPeerIDs map[string]struct{}, group *nbgroup.Group) { +func addUserPeersToGroup(userPeerIDs map[string]struct{}, group *types.Group) { groupPeers := make(map[string]struct{}, len(group.Peers)) for _, pid := range group.Peers { groupPeers[pid] = struct{}{} @@ -1393,7 +1193,7 @@ func addUserPeersToGroup(userPeerIDs map[string]struct{}, group *nbgroup.Group) } // removeUserPeersFromGroup removes user's peers from the group. -func removeUserPeersFromGroup(userPeerIDs map[string]struct{}, group *nbgroup.Group) { +func removeUserPeersFromGroup(userPeerIDs map[string]struct{}, group *types.Group) { // skip removing peers from group All if group.Name == "All" { return @@ -1419,7 +1219,7 @@ func findUserInIDPUserdata(userID string, userData []*idp.UserData) (*idp.UserDa } // areUsersLinkedToPeers checks if any of the given userIDs are linked to any of the peers in the account. 
-func areUsersLinkedToPeers(account *Account, userIDs []string) bool { +func areUsersLinkedToPeers(account *types.Account, userIDs []string) bool { for _, peer := range account.Peers { if slices.Contains(userIDs, peer.UserID) { return true diff --git a/management/server/user_test.go b/management/server/user_test.go index 498017afa..75d88f9c8 100644 --- a/management/server/user_test.go +++ b/management/server/user_test.go @@ -10,8 +10,11 @@ import ( "github.com/eko/gocache/v3/cache" cacheStore "github.com/eko/gocache/v3/store" "github.com/google/go-cmp/cmp" - nbgroup "github.com/netbirdio/netbird/management/server/group" + nbpeer "github.com/netbirdio/netbird/management/server/peer" + "github.com/netbirdio/netbird/management/server/store" + "github.com/netbirdio/netbird/management/server/types" + gocache "github.com/patrickmn/go-cache" "github.com/stretchr/testify/assert" "github.com/stretchr/testify/require" @@ -41,11 +44,15 @@ const ( ) func TestUser_CreatePAT_ForSameUser(t *testing.T) { - store := newStore(t) - defer store.Close(context.Background()) + store, cleanup, err := store.NewTestStoreFromSQL(context.Background(), "", t.TempDir()) + if err != nil { + t.Fatalf("Error when creating store: %s", err) + } + t.Cleanup(cleanup) + account := newAccountWithId(context.Background(), mockAccountID, mockUserID, "") - err := store.SaveAccount(context.Background(), account) + err = store.SaveAccount(context.Background(), account) if err != nil { t.Fatalf("Error when saving account: %s", err) } @@ -82,14 +89,18 @@ func TestUser_CreatePAT_ForSameUser(t *testing.T) { } func TestUser_CreatePAT_ForDifferentUser(t *testing.T) { - store := newStore(t) - defer store.Close(context.Background()) + store, cleanup, err := store.NewTestStoreFromSQL(context.Background(), "", t.TempDir()) + if err != nil { + t.Fatalf("Error when creating store: %s", err) + } + t.Cleanup(cleanup) + account := newAccountWithId(context.Background(), mockAccountID, mockUserID, "") - account.Users[mockTargetUserId] = &User{ + account.Users[mockTargetUserId] = &types.User{ Id: mockTargetUserId, IsServiceUser: false, } - err := store.SaveAccount(context.Background(), account) + err = store.SaveAccount(context.Background(), account) if err != nil { t.Fatalf("Error when saving account: %s", err) } @@ -104,14 +115,18 @@ func TestUser_CreatePAT_ForDifferentUser(t *testing.T) { } func TestUser_CreatePAT_ForServiceUser(t *testing.T) { - store := newStore(t) - defer store.Close(context.Background()) + store, cleanup, err := store.NewTestStoreFromSQL(context.Background(), "", t.TempDir()) + if err != nil { + t.Fatalf("Error when creating store: %s", err) + } + t.Cleanup(cleanup) + account := newAccountWithId(context.Background(), mockAccountID, mockUserID, "") - account.Users[mockTargetUserId] = &User{ + account.Users[mockTargetUserId] = &types.User{ Id: mockTargetUserId, IsServiceUser: true, } - err := store.SaveAccount(context.Background(), account) + err = store.SaveAccount(context.Background(), account) if err != nil { t.Fatalf("Error when saving account: %s", err) } @@ -130,11 +145,15 @@ func TestUser_CreatePAT_ForServiceUser(t *testing.T) { } func TestUser_CreatePAT_WithWrongExpiration(t *testing.T) { - store := newStore(t) - defer store.Close(context.Background()) + store, cleanup, err := store.NewTestStoreFromSQL(context.Background(), "", t.TempDir()) + if err != nil { + t.Fatalf("Error when creating store: %s", err) + } + t.Cleanup(cleanup) + account := newAccountWithId(context.Background(), mockAccountID, mockUserID, "") - err := 
store.SaveAccount(context.Background(), account) + err = store.SaveAccount(context.Background(), account) if err != nil { t.Fatalf("Error when saving account: %s", err) } @@ -149,11 +168,15 @@ func TestUser_CreatePAT_WithWrongExpiration(t *testing.T) { } func TestUser_CreatePAT_WithEmptyName(t *testing.T) { - store := newStore(t) - defer store.Close(context.Background()) + store, cleanup, err := store.NewTestStoreFromSQL(context.Background(), "", t.TempDir()) + if err != nil { + t.Fatalf("Error when creating store: %s", err) + } + t.Cleanup(cleanup) + account := newAccountWithId(context.Background(), mockAccountID, mockUserID, "") - err := store.SaveAccount(context.Background(), account) + err = store.SaveAccount(context.Background(), account) if err != nil { t.Fatalf("Error when saving account: %s", err) } @@ -168,19 +191,23 @@ func TestUser_CreatePAT_WithEmptyName(t *testing.T) { } func TestUser_DeletePAT(t *testing.T) { - store := newStore(t) - defer store.Close(context.Background()) + store, cleanup, err := store.NewTestStoreFromSQL(context.Background(), "", t.TempDir()) + if err != nil { + t.Fatalf("Error when creating store: %s", err) + } + t.Cleanup(cleanup) + account := newAccountWithId(context.Background(), mockAccountID, mockUserID, "") - account.Users[mockUserID] = &User{ + account.Users[mockUserID] = &types.User{ Id: mockUserID, - PATs: map[string]*PersonalAccessToken{ + PATs: map[string]*types.PersonalAccessToken{ mockTokenID1: { ID: mockTokenID1, HashedToken: mockToken1, }, }, } - err := store.SaveAccount(context.Background(), account) + err = store.SaveAccount(context.Background(), account) if err != nil { t.Fatalf("Error when saving account: %s", err) } @@ -204,20 +231,24 @@ func TestUser_DeletePAT(t *testing.T) { } func TestUser_GetPAT(t *testing.T) { - store := newStore(t) - defer store.Close(context.Background()) + store, cleanup, err := store.NewTestStoreFromSQL(context.Background(), "", t.TempDir()) + if err != nil { + t.Fatalf("Error when creating store: %s", err) + } + t.Cleanup(cleanup) + account := newAccountWithId(context.Background(), mockAccountID, mockUserID, "") - account.Users[mockUserID] = &User{ + account.Users[mockUserID] = &types.User{ Id: mockUserID, AccountID: mockAccountID, - PATs: map[string]*PersonalAccessToken{ + PATs: map[string]*types.PersonalAccessToken{ mockTokenID1: { ID: mockTokenID1, HashedToken: mockToken1, }, }, } - err := store.SaveAccount(context.Background(), account) + err = store.SaveAccount(context.Background(), account) if err != nil { t.Fatalf("Error when saving account: %s", err) } @@ -237,13 +268,17 @@ func TestUser_GetPAT(t *testing.T) { } func TestUser_GetAllPATs(t *testing.T) { - store := newStore(t) - defer store.Close(context.Background()) + store, cleanup, err := store.NewTestStoreFromSQL(context.Background(), "", t.TempDir()) + if err != nil { + t.Fatalf("Error when creating store: %s", err) + } + t.Cleanup(cleanup) + account := newAccountWithId(context.Background(), mockAccountID, mockUserID, "") - account.Users[mockUserID] = &User{ + account.Users[mockUserID] = &types.User{ Id: mockUserID, AccountID: mockAccountID, - PATs: map[string]*PersonalAccessToken{ + PATs: map[string]*types.PersonalAccessToken{ mockTokenID1: { ID: mockTokenID1, HashedToken: mockToken1, @@ -254,7 +289,7 @@ func TestUser_GetAllPATs(t *testing.T) { }, }, } - err := store.SaveAccount(context.Background(), account) + err = store.SaveAccount(context.Background(), account) if err != nil { t.Fatalf("Error when saving account: %s", err) } @@ -274,14 +309,14 
@@ func TestUser_GetAllPATs(t *testing.T) { func TestUser_Copy(t *testing.T) { // this is an imaginary case which will never be in DB this way - user := User{ + user := types.User{ Id: "userId", AccountID: "accountId", Role: "role", IsServiceUser: true, ServiceUserName: "servicename", AutoGroups: []string{"group1", "group2"}, - PATs: map[string]*PersonalAccessToken{ + PATs: map[string]*types.PersonalAccessToken{ "pat1": { ID: "pat1", Name: "First PAT", @@ -340,11 +375,15 @@ func validateStruct(s interface{}) (err error) { } func TestUser_CreateServiceUser(t *testing.T) { - store := newStore(t) - defer store.Close(context.Background()) + store, cleanup, err := store.NewTestStoreFromSQL(context.Background(), "", t.TempDir()) + if err != nil { + t.Fatalf("Error when creating store: %s", err) + } + t.Cleanup(cleanup) + account := newAccountWithId(context.Background(), mockAccountID, mockUserID, "") - err := store.SaveAccount(context.Background(), account) + err = store.SaveAccount(context.Background(), account) if err != nil { t.Fatalf("Error when saving account: %s", err) } @@ -366,26 +405,30 @@ func TestUser_CreateServiceUser(t *testing.T) { assert.NotNil(t, account.Users[user.ID]) assert.True(t, account.Users[user.ID].IsServiceUser) assert.Equal(t, mockServiceUserName, account.Users[user.ID].ServiceUserName) - assert.Equal(t, UserRole(mockRole), account.Users[user.ID].Role) + assert.Equal(t, types.UserRole(mockRole), account.Users[user.ID].Role) assert.Equal(t, []string{"group1", "group2"}, account.Users[user.ID].AutoGroups) - assert.Equal(t, map[string]*PersonalAccessToken{}, account.Users[user.ID].PATs) + assert.Equal(t, map[string]*types.PersonalAccessToken{}, account.Users[user.ID].PATs) assert.Zero(t, user.Email) assert.True(t, user.IsServiceUser) assert.Equal(t, "active", user.Status) - _, err = am.createServiceUser(context.Background(), mockAccountID, mockUserID, UserRoleOwner, mockServiceUserName, false, nil) + _, err = am.createServiceUser(context.Background(), mockAccountID, mockUserID, types.UserRoleOwner, mockServiceUserName, false, nil) if err == nil { t.Fatal("should return error when creating service user with owner role") } } func TestUser_CreateUser_ServiceUser(t *testing.T) { - store := newStore(t) - defer store.Close(context.Background()) + store, cleanup, err := store.NewTestStoreFromSQL(context.Background(), "", t.TempDir()) + if err != nil { + t.Fatalf("Error when creating store: %s", err) + } + t.Cleanup(cleanup) + account := newAccountWithId(context.Background(), mockAccountID, mockUserID, "") - err := store.SaveAccount(context.Background(), account) + err = store.SaveAccount(context.Background(), account) if err != nil { t.Fatalf("Error when saving account: %s", err) } @@ -395,7 +438,7 @@ func TestUser_CreateUser_ServiceUser(t *testing.T) { eventStore: &activity.InMemoryEventStore{}, } - user, err := am.CreateUser(context.Background(), mockAccountID, mockUserID, &UserInfo{ + user, err := am.CreateUser(context.Background(), mockAccountID, mockUserID, &types.UserInfo{ Name: mockServiceUserName, Role: mockRole, IsServiceUser: true, @@ -413,7 +456,7 @@ func TestUser_CreateUser_ServiceUser(t *testing.T) { assert.Equal(t, 2, len(account.Users)) assert.True(t, account.Users[user.ID].IsServiceUser) assert.Equal(t, mockServiceUserName, account.Users[user.ID].ServiceUserName) - assert.Equal(t, UserRole(mockRole), account.Users[user.ID].Role) + assert.Equal(t, types.UserRole(mockRole), account.Users[user.ID].Role) assert.Equal(t, []string{"group1", "group2"}, 
account.Users[user.ID].AutoGroups) assert.Equal(t, mockServiceUserName, user.Name) @@ -423,11 +466,15 @@ func TestUser_CreateUser_ServiceUser(t *testing.T) { } func TestUser_CreateUser_RegularUser(t *testing.T) { - store := newStore(t) - defer store.Close(context.Background()) + store, cleanup, err := store.NewTestStoreFromSQL(context.Background(), "", t.TempDir()) + if err != nil { + t.Fatalf("Error when creating store: %s", err) + } + t.Cleanup(cleanup) + account := newAccountWithId(context.Background(), mockAccountID, mockUserID, "") - err := store.SaveAccount(context.Background(), account) + err = store.SaveAccount(context.Background(), account) if err != nil { t.Fatalf("Error when saving account: %s", err) } @@ -437,7 +484,7 @@ func TestUser_CreateUser_RegularUser(t *testing.T) { eventStore: &activity.InMemoryEventStore{}, } - _, err = am.CreateUser(context.Background(), mockAccountID, mockUserID, &UserInfo{ + _, err = am.CreateUser(context.Background(), mockAccountID, mockUserID, &types.UserInfo{ Name: mockServiceUserName, Role: mockRole, IsServiceUser: false, @@ -448,11 +495,15 @@ func TestUser_CreateUser_RegularUser(t *testing.T) { } func TestUser_InviteNewUser(t *testing.T) { - store := newStore(t) - defer store.Close(context.Background()) + store, cleanup, err := store.NewTestStoreFromSQL(context.Background(), "", t.TempDir()) + if err != nil { + t.Fatalf("Error when creating store: %s", err) + } + t.Cleanup(cleanup) + account := newAccountWithId(context.Background(), mockAccountID, mockUserID, "") - err := store.SaveAccount(context.Background(), account) + err = store.SaveAccount(context.Background(), account) if err != nil { t.Fatalf("Error when saving account: %s", err) } @@ -495,7 +546,7 @@ func TestUser_InviteNewUser(t *testing.T) { am.idpManager = &idpMock // test if new invite with regular role works - _, err = am.inviteNewUser(context.Background(), mockAccountID, mockUserID, &UserInfo{ + _, err = am.inviteNewUser(context.Background(), mockAccountID, mockUserID, &types.UserInfo{ Name: mockServiceUserName, Role: mockRole, Email: "test@teste.com", @@ -506,9 +557,9 @@ func TestUser_InviteNewUser(t *testing.T) { assert.NoErrorf(t, err, "Invite user should not throw error") // test if new invite with owner role fails - _, err = am.inviteNewUser(context.Background(), mockAccountID, mockUserID, &UserInfo{ + _, err = am.inviteNewUser(context.Background(), mockAccountID, mockUserID, &types.UserInfo{ Name: mockServiceUserName, - Role: string(UserRoleOwner), + Role: string(types.UserRoleOwner), Email: "test2@teste.com", IsServiceUser: false, AutoGroups: []string{"group1", "group2"}, @@ -520,13 +571,13 @@ func TestUser_InviteNewUser(t *testing.T) { func TestUser_DeleteUser_ServiceUser(t *testing.T) { tests := []struct { name string - serviceUser *User + serviceUser *types.User assertErrFunc assert.ErrorAssertionFunc assertErrMessage string }{ { name: "Can delete service user", - serviceUser: &User{ + serviceUser: &types.User{ Id: mockServiceUserID, IsServiceUser: true, ServiceUserName: mockServiceUserName, @@ -535,7 +586,7 @@ func TestUser_DeleteUser_ServiceUser(t *testing.T) { }, { name: "Cannot delete non-deletable service user", - serviceUser: &User{ + serviceUser: &types.User{ Id: mockServiceUserID, IsServiceUser: true, ServiceUserName: mockServiceUserName, @@ -548,11 +599,16 @@ func TestUser_DeleteUser_ServiceUser(t *testing.T) { for _, tt := range tests { t.Run(tt.name, func(t *testing.T) { - store := newStore(t) + store, cleanup, err := 
store.NewTestStoreFromSQL(context.Background(), "", t.TempDir()) + if err != nil { + t.Fatalf("Error when creating store: %s", err) + } + t.Cleanup(cleanup) + account := newAccountWithId(context.Background(), mockAccountID, mockUserID, "") account.Users[mockServiceUserID] = tt.serviceUser - err := store.SaveAccount(context.Background(), account) + err = store.SaveAccount(context.Background(), account) if err != nil { t.Fatalf("Error when saving account: %s", err) } @@ -580,11 +636,15 @@ func TestUser_DeleteUser_ServiceUser(t *testing.T) { } func TestUser_DeleteUser_SelfDelete(t *testing.T) { - store := newStore(t) - defer store.Close(context.Background()) + store, cleanup, err := store.NewTestStoreFromSQL(context.Background(), "", t.TempDir()) + if err != nil { + t.Fatalf("Error when creating store: %s", err) + } + t.Cleanup(cleanup) + account := newAccountWithId(context.Background(), mockAccountID, mockUserID, "") - err := store.SaveAccount(context.Background(), account) + err = store.SaveAccount(context.Background(), account) if err != nil { t.Fatalf("Error when saving account: %s", err) } @@ -601,38 +661,42 @@ func TestUser_DeleteUser_SelfDelete(t *testing.T) { } func TestUser_DeleteUser_regularUser(t *testing.T) { - store := newStore(t) - defer store.Close(context.Background()) + store, cleanup, err := store.NewTestStoreFromSQL(context.Background(), "", t.TempDir()) + if err != nil { + t.Fatalf("Error when creating store: %s", err) + } + t.Cleanup(cleanup) + account := newAccountWithId(context.Background(), mockAccountID, mockUserID, "") targetId := "user2" - account.Users[targetId] = &User{ + account.Users[targetId] = &types.User{ Id: targetId, IsServiceUser: true, ServiceUserName: "user2username", } targetId = "user3" - account.Users[targetId] = &User{ + account.Users[targetId] = &types.User{ Id: targetId, IsServiceUser: false, - Issued: UserIssuedAPI, + Issued: types.UserIssuedAPI, } targetId = "user4" - account.Users[targetId] = &User{ + account.Users[targetId] = &types.User{ Id: targetId, IsServiceUser: false, - Issued: UserIssuedIntegration, + Issued: types.UserIssuedIntegration, } targetId = "user5" - account.Users[targetId] = &User{ + account.Users[targetId] = &types.User{ Id: targetId, IsServiceUser: false, - Issued: UserIssuedAPI, - Role: UserRoleOwner, + Issued: types.UserIssuedAPI, + Role: types.UserRoleOwner, } - err := store.SaveAccount(context.Background(), account) + err = store.SaveAccount(context.Background(), account) if err != nil { t.Fatalf("Error when saving account: %s", err) } @@ -683,60 +747,64 @@ func TestUser_DeleteUser_regularUser(t *testing.T) { } func TestUser_DeleteUser_RegularUsers(t *testing.T) { - store := newStore(t) - defer store.Close(context.Background()) + store, cleanup, err := store.NewTestStoreFromSQL(context.Background(), "", t.TempDir()) + if err != nil { + t.Fatalf("Error when creating store: %s", err) + } + t.Cleanup(cleanup) + account := newAccountWithId(context.Background(), mockAccountID, mockUserID, "") targetId := "user2" - account.Users[targetId] = &User{ + account.Users[targetId] = &types.User{ Id: targetId, IsServiceUser: true, ServiceUserName: "user2username", } targetId = "user3" - account.Users[targetId] = &User{ + account.Users[targetId] = &types.User{ Id: targetId, IsServiceUser: false, - Issued: UserIssuedAPI, + Issued: types.UserIssuedAPI, } targetId = "user4" - account.Users[targetId] = &User{ + account.Users[targetId] = &types.User{ Id: targetId, IsServiceUser: false, - Issued: UserIssuedIntegration, + Issued: 
types.UserIssuedIntegration, } targetId = "user5" - account.Users[targetId] = &User{ + account.Users[targetId] = &types.User{ Id: targetId, IsServiceUser: false, - Issued: UserIssuedAPI, - Role: UserRoleOwner, + Issued: types.UserIssuedAPI, + Role: types.UserRoleOwner, } - account.Users["user6"] = &User{ + account.Users["user6"] = &types.User{ Id: "user6", IsServiceUser: false, - Issued: UserIssuedAPI, + Issued: types.UserIssuedAPI, } - account.Users["user7"] = &User{ + account.Users["user7"] = &types.User{ Id: "user7", IsServiceUser: false, - Issued: UserIssuedAPI, + Issued: types.UserIssuedAPI, } - account.Users["user8"] = &User{ + account.Users["user8"] = &types.User{ Id: "user8", IsServiceUser: false, - Issued: UserIssuedAPI, - Role: UserRoleAdmin, + Issued: types.UserIssuedAPI, + Role: types.UserRoleAdmin, } - account.Users["user9"] = &User{ + account.Users["user9"] = &types.User{ Id: "user9", IsServiceUser: false, - Issued: UserIssuedAPI, - Role: UserRoleAdmin, + Issued: types.UserIssuedAPI, + Role: types.UserRoleAdmin, } - err := store.SaveAccount(context.Background(), account) + err = store.SaveAccount(context.Background(), account) if err != nil { t.Fatalf("Error when saving account: %s", err) } @@ -834,11 +902,15 @@ func TestUser_DeleteUser_RegularUsers(t *testing.T) { } func TestDefaultAccountManager_GetUser(t *testing.T) { - store := newStore(t) - defer store.Close(context.Background()) + store, cleanup, err := store.NewTestStoreFromSQL(context.Background(), "", t.TempDir()) + if err != nil { + t.Fatalf("Error when creating store: %s", err) + } + t.Cleanup(cleanup) + account := newAccountWithId(context.Background(), mockAccountID, mockUserID, "") - err := store.SaveAccount(context.Background(), account) + err = store.SaveAccount(context.Background(), account) if err != nil { t.Fatalf("Error when saving account: %s", err) } @@ -863,13 +935,17 @@ func TestDefaultAccountManager_GetUser(t *testing.T) { } func TestDefaultAccountManager_ListUsers(t *testing.T) { - store := newStore(t) - defer store.Close(context.Background()) - account := newAccountWithId(context.Background(), mockAccountID, mockUserID, "") - account.Users["normal_user1"] = NewRegularUser("normal_user1") - account.Users["normal_user2"] = NewRegularUser("normal_user2") + store, cleanup, err := store.NewTestStoreFromSQL(context.Background(), "", t.TempDir()) + if err != nil { + t.Fatalf("Error when creating store: %s", err) + } + t.Cleanup(cleanup) - err := store.SaveAccount(context.Background(), account) + account := newAccountWithId(context.Background(), mockAccountID, mockUserID, "") + account.Users["normal_user1"] = types.NewRegularUser("normal_user1") + account.Users["normal_user2"] = types.NewRegularUser("normal_user2") + + err = store.SaveAccount(context.Background(), account) if err != nil { t.Fatalf("Error when saving account: %s", err) } @@ -901,43 +977,43 @@ func TestDefaultAccountManager_ListUsers(t *testing.T) { func TestDefaultAccountManager_ListUsers_DashboardPermissions(t *testing.T) { testCases := []struct { name string - role UserRole + role types.UserRole limitedViewSettings bool expectedDashboardPermissions string }{ { name: "Regular user, no limited view settings", - role: UserRoleUser, + role: types.UserRoleUser, limitedViewSettings: false, expectedDashboardPermissions: "limited", }, { name: "Admin user, no limited view settings", - role: UserRoleAdmin, + role: types.UserRoleAdmin, limitedViewSettings: false, expectedDashboardPermissions: "full", }, { name: "Owner, no limited view settings", - 
role: UserRoleOwner, + role: types.UserRoleOwner, limitedViewSettings: false, expectedDashboardPermissions: "full", }, { name: "Regular user, limited view settings", - role: UserRoleUser, + role: types.UserRoleUser, limitedViewSettings: true, expectedDashboardPermissions: "blocked", }, { name: "Admin user, limited view settings", - role: UserRoleAdmin, + role: types.UserRoleAdmin, limitedViewSettings: true, expectedDashboardPermissions: "full", }, { name: "Owner, limited view settings", - role: UserRoleOwner, + role: types.UserRoleOwner, limitedViewSettings: true, expectedDashboardPermissions: "full", }, @@ -945,13 +1021,18 @@ func TestDefaultAccountManager_ListUsers_DashboardPermissions(t *testing.T) { for _, testCase := range testCases { t.Run(testCase.name, func(t *testing.T) { - store := newStore(t) + store, cleanup, err := store.NewTestStoreFromSQL(context.Background(), "", t.TempDir()) + if err != nil { + t.Fatalf("Error when creating store: %s", err) + } + t.Cleanup(cleanup) + account := newAccountWithId(context.Background(), mockAccountID, mockUserID, "") - account.Users["normal_user1"] = NewUser("normal_user1", testCase.role, false, false, "", []string{}, UserIssuedAPI) + account.Users["normal_user1"] = types.NewUser("normal_user1", testCase.role, false, false, "", []string{}, types.UserIssuedAPI) account.Settings.RegularUsersViewBlocked = testCase.limitedViewSettings delete(account.Users, mockUserID) - err := store.SaveAccount(context.Background(), account) + err = store.SaveAccount(context.Background(), account) if err != nil { t.Fatalf("Error when saving account: %s", err) } @@ -976,13 +1057,17 @@ func TestDefaultAccountManager_ListUsers_DashboardPermissions(t *testing.T) { } func TestDefaultAccountManager_ExternalCache(t *testing.T) { - store := newStore(t) - defer store.Close(context.Background()) + store, cleanup, err := store.NewTestStoreFromSQL(context.Background(), "", t.TempDir()) + if err != nil { + t.Fatalf("Error when creating store: %s", err) + } + t.Cleanup(cleanup) + account := newAccountWithId(context.Background(), mockAccountID, mockUserID, "") - externalUser := &User{ + externalUser := &types.User{ Id: "externalUser", - Role: UserRoleUser, - Issued: UserIssuedIntegration, + Role: types.UserRoleUser, + Issued: types.UserIssuedIntegration, IntegrationReference: integration_reference.IntegrationReference{ ID: 1, IntegrationType: "external", @@ -990,7 +1075,7 @@ func TestDefaultAccountManager_ExternalCache(t *testing.T) { } account.Users[externalUser.Id] = externalUser - err := store.SaveAccount(context.Background(), account) + err = store.SaveAccount(context.Background(), account) if err != nil { t.Fatalf("Error when saving account: %s", err) } @@ -1020,7 +1105,7 @@ func TestDefaultAccountManager_ExternalCache(t *testing.T) { infos, err := am.GetUsersFromAccount(context.Background(), mockAccountID, mockUserID) assert.NoError(t, err) assert.Equal(t, 2, len(infos)) - var user *UserInfo + var user *types.UserInfo for _, info := range infos { if info.ID == externalUser.Id { user = info @@ -1032,24 +1117,28 @@ func TestDefaultAccountManager_ExternalCache(t *testing.T) { func TestUser_IsAdmin(t *testing.T) { - user := NewAdminUser(mockUserID) + user := types.NewAdminUser(mockUserID) assert.True(t, user.HasAdminPower()) - user = NewRegularUser(mockUserID) + user = types.NewRegularUser(mockUserID) assert.False(t, user.HasAdminPower()) } func TestUser_GetUsersFromAccount_ForAdmin(t *testing.T) { - store := newStore(t) - defer store.Close(context.Background()) + store, 
cleanup, err := store.NewTestStoreFromSQL(context.Background(), "", t.TempDir()) + if err != nil { + t.Fatalf("Error when creating store: %s", err) + } + t.Cleanup(cleanup) + account := newAccountWithId(context.Background(), mockAccountID, mockUserID, "") - account.Users[mockServiceUserID] = &User{ + account.Users[mockServiceUserID] = &types.User{ Id: mockServiceUserID, Role: "user", IsServiceUser: true, } - err := store.SaveAccount(context.Background(), account) + err = store.SaveAccount(context.Background(), account) if err != nil { t.Fatalf("Error when saving account: %s", err) } @@ -1068,17 +1157,20 @@ func TestUser_GetUsersFromAccount_ForAdmin(t *testing.T) { } func TestUser_GetUsersFromAccount_ForUser(t *testing.T) { - store := newStore(t) - defer store.Close(context.Background()) + store, cleanup, err := store.NewTestStoreFromSQL(context.Background(), "", t.TempDir()) + if err != nil { + t.Fatalf("Error when creating store: %s", err) + } + t.Cleanup(cleanup) account := newAccountWithId(context.Background(), mockAccountID, mockUserID, "") - account.Users[mockServiceUserID] = &User{ + account.Users[mockServiceUserID] = &types.User{ Id: mockServiceUserID, Role: "user", IsServiceUser: true, } - err := store.SaveAccount(context.Background(), account) + err = store.SaveAccount(context.Background(), account) if err != nil { t.Fatalf("Error when saving account: %s", err) } @@ -1112,25 +1204,25 @@ func TestDefaultAccountManager_SaveUser(t *testing.T) { tt := []struct { name string initiatorID string - update *User + update *types.User expectedErr bool }{ { name: "Should_Fail_To_Update_Admin_Role", expectedErr: true, initiatorID: adminUserID, - update: &User{ + update: &types.User{ Id: adminUserID, - Role: UserRoleUser, + Role: types.UserRoleUser, Blocked: false, }, }, { name: "Should_Fail_When_Admin_Blocks_Themselves", expectedErr: true, initiatorID: adminUserID, - update: &User{ + update: &types.User{ Id: adminUserID, - Role: UserRoleAdmin, + Role: types.UserRoleAdmin, Blocked: true, }, }, @@ -1138,9 +1230,9 @@ func TestDefaultAccountManager_SaveUser(t *testing.T) { name: "Should_Fail_To_Update_Non_Existing_User", expectedErr: true, initiatorID: adminUserID, - update: &User{ + update: &types.User{ Id: userID, - Role: UserRoleAdmin, + Role: types.UserRoleAdmin, Blocked: true, }, }, @@ -1148,9 +1240,9 @@ func TestDefaultAccountManager_SaveUser(t *testing.T) { name: "Should_Fail_To_Update_When_Initiator_Is_Not_An_Admin", expectedErr: true, initiatorID: regularUserID, - update: &User{ + update: &types.User{ Id: adminUserID, - Role: UserRoleAdmin, + Role: types.UserRoleAdmin, Blocked: true, }, }, @@ -1158,9 +1250,9 @@ func TestDefaultAccountManager_SaveUser(t *testing.T) { name: "Should_Update_User", expectedErr: false, initiatorID: adminUserID, - update: &User{ + update: &types.User{ Id: regularUserID, - Role: UserRoleAdmin, + Role: types.UserRoleAdmin, Blocked: true, }, }, @@ -1168,9 +1260,9 @@ func TestDefaultAccountManager_SaveUser(t *testing.T) { name: "Should_Transfer_Owner_Role_To_User", expectedErr: false, initiatorID: ownerUserID, - update: &User{ + update: &types.User{ Id: adminUserID, - Role: UserRoleAdmin, + Role: types.UserRoleAdmin, Blocked: false, }, }, @@ -1178,9 +1270,9 @@ func TestDefaultAccountManager_SaveUser(t *testing.T) { name: "Should_Fail_To_Transfer_Owner_Role_To_Service_User", expectedErr: true, initiatorID: ownerUserID, - update: &User{ + update: &types.User{ Id: serviceUserID, - Role: UserRoleOwner, + Role: types.UserRoleOwner, Blocked: false, }, }, @@ -1188,9 
+1280,9 @@ func TestDefaultAccountManager_SaveUser(t *testing.T) { name: "Should_Fail_To_Update_Owner_User_Role_By_Admin", expectedErr: true, initiatorID: adminUserID, - update: &User{ + update: &types.User{ Id: ownerUserID, - Role: UserRoleAdmin, + Role: types.UserRoleAdmin, Blocked: false, }, }, @@ -1198,9 +1290,9 @@ func TestDefaultAccountManager_SaveUser(t *testing.T) { name: "Should_Fail_To_Update_Owner_User_Role_By_User", expectedErr: true, initiatorID: regularUserID, - update: &User{ + update: &types.User{ Id: ownerUserID, - Role: UserRoleAdmin, + Role: types.UserRoleAdmin, Blocked: false, }, }, @@ -1208,9 +1300,9 @@ func TestDefaultAccountManager_SaveUser(t *testing.T) { name: "Should_Fail_To_Update_Owner_User_Role_By_Service_User", expectedErr: true, initiatorID: serviceUserID, - update: &User{ + update: &types.User{ Id: ownerUserID, - Role: UserRoleAdmin, + Role: types.UserRoleAdmin, Blocked: false, }, }, @@ -1218,9 +1310,9 @@ func TestDefaultAccountManager_SaveUser(t *testing.T) { name: "Should_Fail_To_Update_Owner_Role_By_Admin", expectedErr: true, initiatorID: adminUserID, - update: &User{ + update: &types.User{ Id: regularUserID, - Role: UserRoleOwner, + Role: types.UserRoleOwner, Blocked: false, }, }, @@ -1228,9 +1320,9 @@ func TestDefaultAccountManager_SaveUser(t *testing.T) { name: "Should_Fail_To_Block_Owner_Role_By_Admin", expectedErr: true, initiatorID: adminUserID, - update: &User{ + update: &types.User{ Id: ownerUserID, - Role: UserRoleOwner, + Role: types.UserRoleOwner, Blocked: true, }, }, @@ -1246,9 +1338,9 @@ func TestDefaultAccountManager_SaveUser(t *testing.T) { } // create other users - account.Users[regularUserID] = NewRegularUser(regularUserID) - account.Users[adminUserID] = NewAdminUser(adminUserID) - account.Users[serviceUserID] = &User{IsServiceUser: true, Id: serviceUserID, Role: UserRoleAdmin, ServiceUserName: "service"} + account.Users[regularUserID] = types.NewRegularUser(regularUserID) + account.Users[adminUserID] = types.NewAdminUser(adminUserID) + account.Users[serviceUserID] = &types.User{IsServiceUser: true, Id: serviceUserID, Role: types.UserRoleAdmin, ServiceUserName: "service"} err = manager.Store.SaveAccount(context.Background(), account) if err != nil { t.Fatal(err) @@ -1272,22 +1364,22 @@ func TestUserAccountPeersUpdate(t *testing.T) { // account groups propagation is enabled manager, account, peer1, peer2, peer3 := setupNetworkMapTest(t) - err := manager.SaveGroup(context.Background(), account.Id, userID, &nbgroup.Group{ + err := manager.SaveGroup(context.Background(), account.Id, userID, &types.Group{ ID: "groupA", Name: "GroupA", Peers: []string{peer1.ID, peer2.ID, peer3.ID}, }) require.NoError(t, err) - policy := &Policy{ + policy := &types.Policy{ Enabled: true, - Rules: []*PolicyRule{ + Rules: []*types.PolicyRule{ { Enabled: true, Sources: []string{"groupA"}, Destinations: []string{"groupA"}, Bidirectional: true, - Action: PolicyTrafficActionAccept, + Action: types.PolicyTrafficActionAccept, }, }, } @@ -1307,11 +1399,11 @@ func TestUserAccountPeersUpdate(t *testing.T) { close(done) }() - _, err = manager.SaveOrAddUser(context.Background(), account.Id, userID, &User{ + _, err = manager.SaveOrAddUser(context.Background(), account.Id, userID, &types.User{ Id: "regularUser1", AccountID: account.Id, - Role: UserRoleUser, - Issued: UserIssuedAPI, + Role: types.UserRoleUser, + Issued: types.UserIssuedAPI, }, true) require.NoError(t, err) @@ -1330,11 +1422,11 @@ func TestUserAccountPeersUpdate(t *testing.T) { close(done) }() - _, err = 
manager.SaveOrAddUser(context.Background(), account.Id, userID, &User{ + _, err = manager.SaveOrAddUser(context.Background(), account.Id, userID, &types.User{ Id: "regularUser1", AccountID: account.Id, - Role: UserRoleUser, - Issued: UserIssuedAPI, + Role: types.UserRoleUser, + Issued: types.UserIssuedAPI, }, false) require.NoError(t, err) @@ -1364,11 +1456,11 @@ func TestUserAccountPeersUpdate(t *testing.T) { }) // create a user and add new peer with the user - _, err = manager.SaveOrAddUser(context.Background(), account.Id, userID, &User{ + _, err = manager.SaveOrAddUser(context.Background(), account.Id, userID, &types.User{ Id: "regularUser2", AccountID: account.Id, - Role: UserRoleAdmin, - Issued: UserIssuedAPI, + Role: types.UserRoleAdmin, + Issued: types.UserIssuedAPI, }, true) require.NoError(t, err) @@ -1390,11 +1482,11 @@ func TestUserAccountPeersUpdate(t *testing.T) { close(done) }() - _, err = manager.SaveOrAddUser(context.Background(), account.Id, userID, &User{ + _, err = manager.SaveOrAddUser(context.Background(), account.Id, userID, &types.User{ Id: "regularUser2", AccountID: account.Id, - Role: UserRoleAdmin, - Issued: UserIssuedAPI, + Role: types.UserRoleAdmin, + Issued: types.UserIssuedAPI, }, false) require.NoError(t, err) diff --git a/management/server/users/manager.go b/management/server/users/manager.go new file mode 100644 index 000000000..718eb6190 --- /dev/null +++ b/management/server/users/manager.go @@ -0,0 +1,49 @@ +package users + +import ( + "context" + "errors" + + "github.com/netbirdio/netbird/management/server/store" + "github.com/netbirdio/netbird/management/server/types" +) + +type Manager interface { + GetUser(ctx context.Context, userID string) (*types.User, error) +} + +type managerImpl struct { + store store.Store +} + +type managerMock struct { +} + +func NewManager(store store.Store) Manager { + return &managerImpl{ + store: store, + } +} + +func (m *managerImpl) GetUser(ctx context.Context, userID string) (*types.User, error) { + return m.store.GetUserByUserID(ctx, store.LockingStrengthShare, userID) +} + +func NewManagerMock() Manager { + return &managerMock{} +} + +func (m *managerMock) GetUser(ctx context.Context, userID string) (*types.User, error) { + switch userID { + case "adminUser": + return &types.User{Id: userID, Role: types.UserRoleAdmin}, nil + case "regularUser": + return &types.User{Id: userID, Role: types.UserRoleUser}, nil + case "ownerUser": + return &types.User{Id: userID, Role: types.UserRoleOwner}, nil + case "billingUser": + return &types.User{Id: userID, Role: types.UserRoleBillingAdmin}, nil + default: + return nil, errors.New("user not found") + } +} diff --git a/management/server/util/util.go b/management/server/util/util.go new file mode 100644 index 000000000..ff738781f --- /dev/null +++ b/management/server/util/util.go @@ -0,0 +1,16 @@ +package util + +// Difference returns the elements in `a` that aren't in `b`. 
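+// For example, Difference([]string{"a", "b", "c"}, []string{"b"}) returns []string{"a", "c"}.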
+func Difference(a, b []string) []string { + mb := make(map[string]struct{}, len(b)) + for _, x := range b { + mb[x] = struct{}{} + } + var diff []string + for _, x := range a { + if _, found := mb[x]; !found { + diff = append(diff, x) + } + } + return diff +} diff --git a/route/route.go b/route/route.go index e23801e6e..8f3c99b4c 100644 --- a/route/route.go +++ b/route/route.go @@ -4,6 +4,7 @@ import ( "fmt" "net/netip" "slices" + "strings" log "github.com/sirupsen/logrus" @@ -88,18 +89,18 @@ type Route struct { // AccountID is a reference to Account that this object belongs AccountID string `gorm:"index"` // Network and Domains are mutually exclusive - Network netip.Prefix `gorm:"serializer:json"` - Domains domain.List `gorm:"serializer:json"` - KeepRoute bool - NetID NetID - Description string - Peer string - PeerGroups []string `gorm:"serializer:json"` - NetworkType NetworkType - Masquerade bool - Metric int - Enabled bool - Groups []string `gorm:"serializer:json"` + Network netip.Prefix `gorm:"serializer:json"` + Domains domain.List `gorm:"serializer:json"` + KeepRoute bool + NetID NetID + Description string + Peer string + PeerGroups []string `gorm:"serializer:json"` + NetworkType NetworkType + Masquerade bool + Metric int + Enabled bool + Groups []string `gorm:"serializer:json"` AccessControlGroups []string `gorm:"serializer:json"` } @@ -111,19 +112,19 @@ func (r *Route) EventMeta() map[string]any { // Copy copies a route object func (r *Route) Copy() *Route { route := &Route{ - ID: r.ID, - Description: r.Description, - NetID: r.NetID, - Network: r.Network, - Domains: slices.Clone(r.Domains), - KeepRoute: r.KeepRoute, - NetworkType: r.NetworkType, - Peer: r.Peer, - PeerGroups: slices.Clone(r.PeerGroups), - Metric: r.Metric, - Masquerade: r.Masquerade, - Enabled: r.Enabled, - Groups: slices.Clone(r.Groups), + ID: r.ID, + Description: r.Description, + NetID: r.NetID, + Network: r.Network, + Domains: slices.Clone(r.Domains), + KeepRoute: r.KeepRoute, + NetworkType: r.NetworkType, + Peer: r.Peer, + PeerGroups: slices.Clone(r.PeerGroups), + Metric: r.Metric, + Masquerade: r.Masquerade, + Enabled: r.Enabled, + Groups: slices.Clone(r.Groups), AccessControlGroups: slices.Clone(r.AccessControlGroups), } return route @@ -149,7 +150,7 @@ func (r *Route) IsEqual(other *Route) bool { other.Masquerade == r.Masquerade && other.Enabled == r.Enabled && slices.Equal(r.Groups, other.Groups) && - slices.Equal(r.PeerGroups, other.PeerGroups)&& + slices.Equal(r.PeerGroups, other.PeerGroups) && slices.Equal(r.AccessControlGroups, other.AccessControlGroups) } @@ -170,6 +171,11 @@ func (r *Route) GetHAUniqueID() HAUniqueID { return HAUniqueID(fmt.Sprintf("%s%s%s", r.NetID, haSeparator, r.Network.String())) } +// GetResourceID returns the Networks Resource ID from a route ID +func (r *Route) GetResourceID() string { + return strings.Split(string(r.ID), ":")[0] +} + // ParseNetwork Parses a network prefix string and returns a netip.Prefix object and if is invalid, IPv4 or IPv6 func ParseNetwork(networkString string) (NetworkType, netip.Prefix, error) { prefix, err := netip.ParsePrefix(networkString) From 82b4e58ad0b2c0c3efb310b32da46eff2fc92281 Mon Sep 17 00:00:00 2001 From: Zoltan Papp Date: Fri, 20 Dec 2024 16:20:50 +0100 Subject: [PATCH 088/159] Do not start DNS forwarder on client side (#3094) --- client/internal/engine.go | 44 ++++++++++++++++++++++++++++----------- 1 file changed, 32 insertions(+), 12 deletions(-) diff --git a/client/internal/engine.go b/client/internal/engine.go index 
9724e2a22..042d384dc 100644 --- a/client/internal/engine.go +++ b/client/internal/engine.go @@ -802,14 +802,12 @@ func (e *Engine) updateNetworkMap(networkMap *mgmProto.NetworkMap) error { e.acl.ApplyFiltering(networkMap) } - var dnsRouteFeatureFlag bool - if networkMap.PeerConfig != nil { - dnsRouteFeatureFlag = networkMap.PeerConfig.RoutingPeerDnsResolutionEnabled - } - routedDomains, routes := toRoutes(networkMap.GetRoutes()) - - e.updateDNSForwarder(dnsRouteFeatureFlag, routedDomains) + // DNS forwarder + dnsRouteFeatureFlag := toDNSFeatureFlag(networkMap) + dnsRouteDomains := toRouteDomains(e.config.WgPrivateKey.PublicKey().String(), networkMap.GetRoutes()) + e.updateDNSForwarder(dnsRouteFeatureFlag, dnsRouteDomains) + routes := toRoutes(networkMap.GetRoutes()) if err := e.routeManager.UpdateRoutes(serial, routes, dnsRouteFeatureFlag); err != nil { log.Errorf("failed to update clientRoutes, err: %v", err) } @@ -874,12 +872,18 @@ func (e *Engine) updateNetworkMap(networkMap *mgmProto.NetworkMap) error { return nil } -func toRoutes(protoRoutes []*mgmProto.Route) ([]string, []*route.Route) { +func toDNSFeatureFlag(networkMap *mgmProto.NetworkMap) bool { + if networkMap.PeerConfig != nil { + return networkMap.PeerConfig.RoutingPeerDnsResolutionEnabled + } + return false +} + +func toRoutes(protoRoutes []*mgmProto.Route) []*route.Route { if protoRoutes == nil { protoRoutes = []*mgmProto.Route{} } - var dnsRoutes []string routes := make([]*route.Route, 0) for _, protoRoute := range protoRoutes { var prefix netip.Prefix @@ -890,7 +894,6 @@ func toRoutes(protoRoutes []*mgmProto.Route) ([]string, []*route.Route) { continue } } - dnsRoutes = append(dnsRoutes, protoRoute.Domains...) convertedRoute := &route.Route{ ID: route.ID(protoRoute.ID), @@ -905,7 +908,24 @@ func toRoutes(protoRoutes []*mgmProto.Route) ([]string, []*route.Route) { } routes = append(routes, convertedRoute) } - return dnsRoutes, routes + return routes +} + +func toRouteDomains(myPubKey string, protoRoutes []*mgmProto.Route) []string { + if protoRoutes == nil { + protoRoutes = []*mgmProto.Route{} + } + + var dnsRoutes []string + for _, protoRoute := range protoRoutes { + if len(protoRoute.Domains) == 0 { + continue + } + if protoRoute.Peer == myPubKey { + dnsRoutes = append(dnsRoutes, protoRoute.Domains...) 
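+			// only domains routed by this peer are handed to the local DNS forwarder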
+ } + } + return dnsRoutes } func toDNSConfig(protoDNSConfig *mgmProto.DNSConfig) nbdns.Config { @@ -1243,7 +1263,7 @@ func (e *Engine) readInitialSettings() ([]*route.Route, *nbdns.Config, error) { if err != nil { return nil, nil, err } - _, routes := toRoutes(netMap.GetRoutes()) + routes := toRoutes(netMap.GetRoutes()) dnsCfg := toDNSConfig(netMap.GetDNSConfig()) return routes, &dnsCfg, nil } From 7ee7ada27347bec20d0650a9c0feeb991779b2d4 Mon Sep 17 00:00:00 2001 From: Bethuel Mmbaga Date: Fri, 20 Dec 2024 21:10:53 +0300 Subject: [PATCH 089/159] [management] Fix duplicate resource routes when routing peer is part of the source group (#3095) * Remove duplicate resource routes when routing peer is part of the source group Signed-off-by: bcmmbaga * Add tests Signed-off-by: bcmmbaga --------- Signed-off-by: bcmmbaga --- management/server/route_test.go | 67 ++++++++++++++++++++++++++++++ management/server/types/account.go | 6 +-- 2 files changed, 69 insertions(+), 4 deletions(-) diff --git a/management/server/route_test.go b/management/server/route_test.go index 5390cb66b..2ef2b0173 100644 --- a/management/server/route_test.go +++ b/management/server/route_test.go @@ -2175,6 +2175,7 @@ func TestAccount_GetPeerNetworkResourceFirewallRules(t *testing.T) { peerJIp = "100.65.29.65" peerKIp = "100.65.29.66" peerMIp = "100.65.29.67" + peerOIp = "100.65.29.68" ) account := &types.Account{ @@ -2256,6 +2257,20 @@ func TestAccount_GetPeerNetworkResourceFirewallRules(t *testing.T) { IP: net.ParseIP(peerMIp), Status: &nbpeer.PeerStatus{}, }, + "peerN": { + ID: "peerN", + IP: net.ParseIP("100.65.20.18"), + Key: "peerN", + Status: &nbpeer.PeerStatus{}, + Meta: nbpeer.PeerSystemMeta{ + GoOS: "linux", + }, + }, + "peerO": { + ID: "peerO", + IP: net.ParseIP(peerOIp), + Status: &nbpeer.PeerStatus{}, + }, }, Groups: map[string]*types.Group{ "router1": { @@ -2330,6 +2345,14 @@ func TestAccount_GetPeerNetworkResourceFirewallRules(t *testing.T) { Name: "Pipeline", Peers: []string{"peerM"}, }, + "metrics": { + ID: "metrics", + Name: "Metrics", + Peers: []string{"peerN", "peerO"}, + Resources: []types.Resource{ + {ID: "resource6"}, + }, + }, }, Networks: []*networkTypes.Network{ { @@ -2352,6 +2375,10 @@ func TestAccount_GetPeerNetworkResourceFirewallRules(t *testing.T) { ID: "network5", Name: "Pipeline Network", }, + { + ID: "network6", + Name: "Metrics Network", + }, }, NetworkRouters: []*routerTypes.NetworkRouter{ { @@ -2389,6 +2416,13 @@ func TestAccount_GetPeerNetworkResourceFirewallRules(t *testing.T) { Masquerade: false, Metric: 9999, }, + { + ID: "router6", + NetworkID: "network6", + Peer: "peerN", + Masquerade: false, + Metric: 9999, + }, }, NetworkResources: []*resourceTypes.NetworkResource{ { @@ -2426,6 +2460,13 @@ func TestAccount_GetPeerNetworkResourceFirewallRules(t *testing.T) { Type: "host", Prefix: netip.MustParsePrefix("10.12.12.1/32"), }, + { + ID: "resource6", + NetworkID: "network6", + Name: "Resource 6", + Type: "domain", + Domain: "*.google.com", + }, }, Policies: []*types.Policy{ { @@ -2527,6 +2568,24 @@ func TestAccount_GetPeerNetworkResourceFirewallRules(t *testing.T) { }, }, }, + { + ID: "policyResource6", + Name: "policyResource6", + Enabled: true, + Rules: []*types.PolicyRule{ + { + ID: "ruleResource6", + Name: "ruleResource6", + Bidirectional: true, + Enabled: true, + Protocol: types.PolicyRuleProtocolTCP, + Action: types.PolicyTrafficActionAccept, + Ports: []string{"9090"}, + Sources: []string{"metrics"}, + Destinations: []string{"metrics"}, + }, + }, + }, }, } @@ -2553,6 +2612,10 @@ 
func TestAccount_GetPeerNetworkResourceFirewallRules(t *testing.T) { // which is part of the destination in the policies (policyResource3 and policyResource4) policies = account.GetPoliciesForNetworkResource("resource4") assert.Len(t, policies, 2, "resource4 should have exactly 2 policy applied via access control groups") + + // Test case: Resource6 is applied to the access control groups (metrics), + policies = account.GetPoliciesForNetworkResource("resource6") + assert.Len(t, policies, 1, "resource6 should have exactly 1 policy applied via access control groups") }) t.Run("validate routing peer firewall rules for network resources", func(t *testing.T) { @@ -2663,5 +2726,9 @@ func TestAccount_GetPeerNetworkResourceFirewallRules(t *testing.T) { _, routes, sourcePeers = account.GetNetworkResourcesRoutesToSync(context.Background(), "peerM", resourcePoliciesMap, resourceRoutersMap) assert.Len(t, routes, 1) assert.Len(t, sourcePeers, 0) + + _, routes, sourcePeers = account.GetNetworkResourcesRoutesToSync(context.Background(), "peerN", resourcePoliciesMap, resourceRoutersMap) + assert.Len(t, routes, 1) + assert.Len(t, sourcePeers, 2) }) } diff --git a/management/server/types/account.go b/management/server/types/account.go index 5d0f12019..b36b719e4 100644 --- a/management/server/types/account.go +++ b/management/server/types/account.go @@ -1330,10 +1330,8 @@ func (a *Account) GetNetworkResourcesRoutesToSync(ctx context.Context, peerID st // routing peer should be able to connect with all source peers if addSourcePeers { allSourcePeers = append(allSourcePeers, group.Peers...) - } - - // add routes for the resource if the peer is in the distribution group - if slices.Contains(group.Peers, peerID) { + } else if slices.Contains(group.Peers, peerID) { + // add routes for the resource if the peer is in the distribution group for peerId, router := range networkRoutingPeers { routes = append(routes, a.getNetworkResourcesRoutes(resource, peerId, router, resourcePolicies)...) 
} From b48cf1bf65782c68600b00541113f90bb1aa88db Mon Sep 17 00:00:00 2001 From: Viktor Liu <17948409+lixmal@users.noreply.github.com> Date: Sat, 21 Dec 2024 15:56:52 +0100 Subject: [PATCH 090/159] [client] Reduce DNS handler chain lock contention (#3099) --- client/internal/dns/handler_chain.go | 16 ++++++++++------ 1 file changed, 10 insertions(+), 6 deletions(-) diff --git a/client/internal/dns/handler_chain.go b/client/internal/dns/handler_chain.go index c6ac3ebd6..9302d50b1 100644 --- a/client/internal/dns/handler_chain.go +++ b/client/internal/dns/handler_chain.go @@ -1,6 +1,7 @@ package dns import ( + "slices" "strings" "sync" @@ -161,16 +162,19 @@ func (c *HandlerChain) ServeDNS(w dns.ResponseWriter, r *dns.Msg) { log.Tracef("handling DNS request for domain=%s", qname) c.mu.RLock() - defer c.mu.RUnlock() + handlers := slices.Clone(c.handlers) + c.mu.RUnlock() - log.Tracef("current handlers (%d):", len(c.handlers)) - for _, h := range c.handlers { - log.Tracef(" - pattern: domain=%s original: domain=%s wildcard=%v priority=%d", - h.Pattern, h.OrigPattern, h.IsWildcard, h.Priority) + if log.IsLevelEnabled(log.TraceLevel) { + log.Tracef("current handlers (%d):", len(handlers)) + for _, h := range handlers { + log.Tracef(" - pattern: domain=%s original: domain=%s wildcard=%v priority=%d", + h.Pattern, h.OrigPattern, h.IsWildcard, h.Priority) + } } // Try handlers in priority order - for _, entry := range c.handlers { + for _, entry := range handlers { var matched bool switch { case entry.Pattern == ".": From e670068cabfa4288455d34d3b917eaa5a5ae2f7a Mon Sep 17 00:00:00 2001 From: Pascal Fischer <32096965+pascal-fischer@users.noreply.github.com> Date: Mon, 23 Dec 2024 14:37:09 +0100 Subject: [PATCH 091/159] [management] Run test sequential (#3101) --- .github/workflows/golang-test-linux.yml | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/.github/workflows/golang-test-linux.yml b/.github/workflows/golang-test-linux.yml index 50c6cba84..155190398 100644 --- a/.github/workflows/golang-test-linux.yml +++ b/.github/workflows/golang-test-linux.yml @@ -183,7 +183,7 @@ jobs: run: git --no-pager diff --exit-code - name: Test - run: CGO_ENABLED=1 GOARCH=${{ matrix.arch }} NETBIRD_STORE_ENGINE=${{ matrix.store }} CI=true go test -exec 'sudo --preserve-env=CI,NETBIRD_STORE_ENGINE' -timeout 10m $(go list ./... | grep /management) + run: CGO_ENABLED=1 GOARCH=${{ matrix.arch }} NETBIRD_STORE_ENGINE=${{ matrix.store }} CI=true go test -p 1 -exec 'sudo --preserve-env=CI,NETBIRD_STORE_ENGINE' -timeout 10m $(go list ./... 
| grep /management) benchmark: needs: [ build-cache ] From 05930ee6b192fa9928e709a8fc74ef6e2cb779a1 Mon Sep 17 00:00:00 2001 From: Viktor Liu <17948409+lixmal@users.noreply.github.com> Date: Mon, 23 Dec 2024 15:57:15 +0100 Subject: [PATCH 092/159] [client] Add firewall rules to the debug bundle (#3089) Adds the following to the debug bundle: - iptables: `iptables-save`, `iptables -v -n -L` - nftables: `nft list ruleset` or if not available formatted output from netlink (WIP) --- client/server/debug.go | 24 ++ client/server/debug_linux.go | 693 ++++++++++++++++++++++++++++++++ client/server/debug_nonlinux.go | 15 + client/server/debug_test.go | 113 ++++++ 4 files changed, 845 insertions(+) create mode 100644 client/server/debug_linux.go create mode 100644 client/server/debug_nonlinux.go diff --git a/client/server/debug.go b/client/server/debug.go index c12fd99db..9dfde0367 100644 --- a/client/server/debug.go +++ b/client/server/debug.go @@ -40,6 +40,8 @@ netbird.err: Most recent, anonymized stderr log file of the NetBird client. netbird.out: Most recent, anonymized stdout log file of the NetBird client. routes.txt: Anonymized system routes, if --system-info flag was provided. interfaces.txt: Anonymized network interface information, if --system-info flag was provided. +iptables.txt: Anonymized iptables rules with packet counters, if --system-info flag was provided. +nftables.txt: Anonymized nftables rules with packet counters, if --system-info flag was provided. config.txt: Anonymized configuration information of the NetBird client. network_map.json: Anonymized network map containing peer configurations, routes, DNS settings, and firewall rules. state.json: Anonymized client state dump containing netbird states. @@ -106,6 +108,24 @@ The config.txt file contains anonymized configuration information of the NetBird - CustomDNSAddress Other non-sensitive configuration options are included without anonymization. 
+ +Firewall Rules (Linux only) +The bundle includes two separate firewall rule files: + +iptables.txt: +- Complete iptables ruleset with packet counters using 'iptables -v -n -L' +- Includes all tables (filter, nat, mangle, raw, security) +- Shows packet and byte counters for each rule +- All IP addresses are anonymized +- Chain names, table names, and other non-sensitive information remain unchanged + +nftables.txt: +- Complete nftables ruleset obtained via 'nft -a list ruleset' +- Includes rule handle numbers and packet counters +- All tables, chains, and rules are included +- Shows packet and byte counters for each rule +- All IP addresses are anonymized +- Chain names, table names, and other non-sensitive information remain unchanged ` const ( @@ -172,6 +192,10 @@ func (s *Server) createArchive(bundlePath *os.File, req *proto.DebugBundleReques if err := s.addInterfaces(req, anonymizer, archive); err != nil { log.Errorf("Failed to add interfaces to debug bundle: %v", err) } + + if err := s.addFirewallRules(req, anonymizer, archive); err != nil { + log.Errorf("Failed to add firewall rules to debug bundle: %v", err) + } } if err := s.addNetworkMap(req, anonymizer, archive); err != nil { diff --git a/client/server/debug_linux.go b/client/server/debug_linux.go new file mode 100644 index 000000000..60bc40561 --- /dev/null +++ b/client/server/debug_linux.go @@ -0,0 +1,693 @@ +//go:build linux && !android + +package server + +import ( + "archive/zip" + "bytes" + "encoding/binary" + "fmt" + "os/exec" + "sort" + "strings" + + "github.com/google/nftables" + "github.com/google/nftables/expr" + log "github.com/sirupsen/logrus" + + "github.com/netbirdio/netbird/client/anonymize" + "github.com/netbirdio/netbird/client/proto" +) + +// addFirewallRules collects and adds firewall rules to the archive +func (s *Server) addFirewallRules(req *proto.DebugBundleRequest, anonymizer *anonymize.Anonymizer, archive *zip.Writer) error { + log.Info("Collecting firewall rules") + // Collect and add iptables rules + iptablesRules, err := collectIPTablesRules() + if err != nil { + log.Warnf("Failed to collect iptables rules: %v", err) + } else { + if req.GetAnonymize() { + iptablesRules = anonymizer.AnonymizeString(iptablesRules) + } + if err := addFileToZip(archive, strings.NewReader(iptablesRules), "iptables.txt"); err != nil { + log.Warnf("Failed to add iptables rules to bundle: %v", err) + } + } + + // Collect and add nftables rules + nftablesRules, err := collectNFTablesRules() + if err != nil { + log.Warnf("Failed to collect nftables rules: %v", err) + } else { + if req.GetAnonymize() { + nftablesRules = anonymizer.AnonymizeString(nftablesRules) + } + if err := addFileToZip(archive, strings.NewReader(nftablesRules), "nftables.txt"); err != nil { + log.Warnf("Failed to add nftables rules to bundle: %v", err) + } + } + + return nil +} + +// collectIPTablesRules collects rules using both iptables-save and verbose listing +func collectIPTablesRules() (string, error) { + var builder strings.Builder + + // First try using iptables-save + saveOutput, err := collectIPTablesSave() + if err != nil { + log.Warnf("Failed to collect iptables rules using iptables-save: %v", err) + } else { + builder.WriteString("=== iptables-save output ===\n") + builder.WriteString(saveOutput) + builder.WriteString("\n") + } + + // Then get verbose statistics for each table + builder.WriteString("=== iptables -v -n -L output ===\n") + + // Get list of tables + tables := []string{"filter", "nat", "mangle", "raw", "security"} + + for _, table 
:= range tables { + builder.WriteString(fmt.Sprintf("*%s\n", table)) + + // Get verbose statistics for the entire table + stats, err := getTableStatistics(table) + if err != nil { + log.Warnf("Failed to get statistics for table %s: %v", table, err) + continue + } + builder.WriteString(stats) + builder.WriteString("\n") + } + + return builder.String(), nil +} + +// collectIPTablesSave uses iptables-save to get rule definitions +func collectIPTablesSave() (string, error) { + cmd := exec.Command("iptables-save") + var stdout, stderr bytes.Buffer + cmd.Stdout = &stdout + cmd.Stderr = &stderr + + if err := cmd.Run(); err != nil { + return "", fmt.Errorf("execute iptables-save: %w (stderr: %s)", err, stderr.String()) + } + + rules := stdout.String() + if strings.TrimSpace(rules) == "" { + return "", fmt.Errorf("no iptables rules found") + } + + return rules, nil +} + +// getTableStatistics gets verbose statistics for an entire table using iptables command +func getTableStatistics(table string) (string, error) { + cmd := exec.Command("iptables", "-v", "-n", "-L", "-t", table) + var stdout, stderr bytes.Buffer + cmd.Stdout = &stdout + cmd.Stderr = &stderr + + if err := cmd.Run(); err != nil { + return "", fmt.Errorf("execute iptables -v -n -L: %w (stderr: %s)", err, stderr.String()) + } + + return stdout.String(), nil +} + +// collectNFTablesRules attempts to collect nftables rules using either nft command or netlink +func collectNFTablesRules() (string, error) { + // First try using nft command + rules, err := collectNFTablesFromCommand() + if err != nil { + log.Debugf("Failed to collect nftables rules using nft command: %v, falling back to netlink", err) + // Fall back to netlink + rules, err = collectNFTablesFromNetlink() + if err != nil { + return "", fmt.Errorf("collect nftables rules using both nft and netlink failed: %w", err) + } + } + return rules, nil +} + +// collectNFTablesFromCommand attempts to collect rules using nft command +func collectNFTablesFromCommand() (string, error) { + cmd := exec.Command("nft", "-a", "list", "ruleset") + var stdout, stderr bytes.Buffer + cmd.Stdout = &stdout + cmd.Stderr = &stderr + + if err := cmd.Run(); err != nil { + return "", fmt.Errorf("execute nft list ruleset: %w (stderr: %s)", err, stderr.String()) + } + + rules := stdout.String() + if strings.TrimSpace(rules) == "" { + return "", fmt.Errorf("no nftables rules found") + } + + return rules, nil +} + +// collectNFTablesFromNetlink collects rules using netlink library +func collectNFTablesFromNetlink() (string, error) { + conn, err := nftables.New() + if err != nil { + return "", fmt.Errorf("create nftables connection: %w", err) + } + + tables, err := conn.ListTables() + if err != nil { + return "", fmt.Errorf("list tables: %w", err) + } + + sortTables(tables) + return formatTables(conn, tables), nil +} + +func formatTables(conn *nftables.Conn, tables []*nftables.Table) string { + var builder strings.Builder + + for _, table := range tables { + builder.WriteString(fmt.Sprintf("table %s %s {\n", formatFamily(table.Family), table.Name)) + + chains, err := getAndSortTableChains(conn, table) + if err != nil { + log.Warnf("Failed to list chains for table %s: %v", table.Name, err) + continue + } + + // Format chains + for _, chain := range chains { + formatChain(conn, table, chain, &builder) + } + + // Format sets + if sets, err := conn.GetSets(table); err != nil { + log.Warnf("Failed to get sets for table %s: %v", table.Name, err) + } else if len(sets) > 0 { + builder.WriteString("\n") + for _, set := 
range sets { + builder.WriteString(formatSet(conn, set)) + } + } + + builder.WriteString("}\n") + } + + return builder.String() +} + +func getAndSortTableChains(conn *nftables.Conn, table *nftables.Table) ([]*nftables.Chain, error) { + chains, err := conn.ListChains() + if err != nil { + return nil, err + } + + var tableChains []*nftables.Chain + for _, chain := range chains { + if chain.Table.Name == table.Name && chain.Table.Family == table.Family { + tableChains = append(tableChains, chain) + } + } + + sort.Slice(tableChains, func(i, j int) bool { + return tableChains[i].Name < tableChains[j].Name + }) + + return tableChains, nil +} + +func formatChain(conn *nftables.Conn, table *nftables.Table, chain *nftables.Chain, builder *strings.Builder) { + builder.WriteString(fmt.Sprintf("\tchain %s {\n", chain.Name)) + + if chain.Type != "" { + var policy string + if chain.Policy != nil { + policy = fmt.Sprintf("; policy %s", formatPolicy(*chain.Policy)) + } + builder.WriteString(fmt.Sprintf("\t\ttype %s hook %s priority %d%s\n", + formatChainType(chain.Type), + formatChainHook(chain.Hooknum), + chain.Priority, + policy)) + } + + rules, err := conn.GetRules(table, chain) + if err != nil { + log.Warnf("Failed to get rules for chain %s: %v", chain.Name, err) + } else { + sort.Slice(rules, func(i, j int) bool { + return rules[i].Position < rules[j].Position + }) + for _, rule := range rules { + builder.WriteString(formatRule(rule)) + } + } + + builder.WriteString("\t}\n") +} + +func sortTables(tables []*nftables.Table) { + sort.Slice(tables, func(i, j int) bool { + if tables[i].Family != tables[j].Family { + return tables[i].Family < tables[j].Family + } + return tables[i].Name < tables[j].Name + }) +} + +func formatFamily(family nftables.TableFamily) string { + switch family { + case nftables.TableFamilyIPv4: + return "ip" + case nftables.TableFamilyIPv6: + return "ip6" + case nftables.TableFamilyINet: + return "inet" + case nftables.TableFamilyARP: + return "arp" + case nftables.TableFamilyBridge: + return "bridge" + case nftables.TableFamilyNetdev: + return "netdev" + default: + return fmt.Sprintf("family-%d", family) + } +} + +func formatChainType(typ nftables.ChainType) string { + switch typ { + case nftables.ChainTypeFilter: + return "filter" + case nftables.ChainTypeNAT: + return "nat" + case nftables.ChainTypeRoute: + return "route" + default: + return fmt.Sprintf("type-%s", typ) + } +} + +func formatChainHook(hook *nftables.ChainHook) string { + if hook == nil { + return "none" + } + switch *hook { + case *nftables.ChainHookPrerouting: + return "prerouting" + case *nftables.ChainHookInput: + return "input" + case *nftables.ChainHookForward: + return "forward" + case *nftables.ChainHookOutput: + return "output" + case *nftables.ChainHookPostrouting: + return "postrouting" + default: + return fmt.Sprintf("hook-%d", *hook) + } +} + +func formatPolicy(policy nftables.ChainPolicy) string { + switch policy { + case nftables.ChainPolicyDrop: + return "drop" + case nftables.ChainPolicyAccept: + return "accept" + default: + return fmt.Sprintf("policy-%d", policy) + } +} + +func formatRule(rule *nftables.Rule) string { + var builder strings.Builder + builder.WriteString("\t\t") + + for i := 0; i < len(rule.Exprs); i++ { + if i > 0 { + builder.WriteString(" ") + } + i = formatExprSequence(&builder, rule.Exprs, i) + } + + builder.WriteString("\n") + return builder.String() +} + +func formatExprSequence(builder *strings.Builder, exprs []expr.Any, i int) int { + curr := exprs[i] + + // Handle Meta + 
Cmp sequence + if meta, ok := curr.(*expr.Meta); ok && i+1 < len(exprs) { + if cmp, ok := exprs[i+1].(*expr.Cmp); ok { + if formatted := formatMetaWithCmp(meta, cmp); formatted != "" { + builder.WriteString(formatted) + return i + 1 + } + } + } + + // Handle Payload + Cmp sequence + if payload, ok := curr.(*expr.Payload); ok && i+1 < len(exprs) { + if cmp, ok := exprs[i+1].(*expr.Cmp); ok { + builder.WriteString(formatPayloadWithCmp(payload, cmp)) + return i + 1 + } + } + + builder.WriteString(formatExpr(curr)) + return i +} + +func formatMetaWithCmp(meta *expr.Meta, cmp *expr.Cmp) string { + switch meta.Key { + case expr.MetaKeyIIFNAME: + name := strings.TrimRight(string(cmp.Data), "\x00") + return fmt.Sprintf("iifname %s %q", formatCmpOp(cmp.Op), name) + case expr.MetaKeyOIFNAME: + name := strings.TrimRight(string(cmp.Data), "\x00") + return fmt.Sprintf("oifname %s %q", formatCmpOp(cmp.Op), name) + case expr.MetaKeyMARK: + if len(cmp.Data) == 4 { + val := binary.BigEndian.Uint32(cmp.Data) + return fmt.Sprintf("meta mark %s 0x%x", formatCmpOp(cmp.Op), val) + } + } + return "" +} + +func formatPayloadWithCmp(p *expr.Payload, cmp *expr.Cmp) string { + if p.Base == expr.PayloadBaseNetworkHeader { + switch p.Offset { + case 12: // Source IP + if p.Len == 4 { + return fmt.Sprintf("ip saddr %s %s", formatCmpOp(cmp.Op), formatIPBytes(cmp.Data)) + } else if p.Len == 2 { + return fmt.Sprintf("ip saddr %s %s", formatCmpOp(cmp.Op), formatIPBytes(cmp.Data)) + } + case 16: // Destination IP + if p.Len == 4 { + return fmt.Sprintf("ip daddr %s %s", formatCmpOp(cmp.Op), formatIPBytes(cmp.Data)) + } else if p.Len == 2 { + return fmt.Sprintf("ip daddr %s %s", formatCmpOp(cmp.Op), formatIPBytes(cmp.Data)) + } + } + } + return fmt.Sprintf("%d reg%d [%d:%d] %s %v", + p.Base, p.DestRegister, p.Offset, p.Len, + formatCmpOp(cmp.Op), cmp.Data) +} + +func formatIPBytes(data []byte) string { + if len(data) == 4 { + return fmt.Sprintf("%d.%d.%d.%d", data[0], data[1], data[2], data[3]) + } else if len(data) == 2 { + return fmt.Sprintf("%d.%d.0.0/16", data[0], data[1]) + } + return fmt.Sprintf("%v", data) +} + +func formatCmpOp(op expr.CmpOp) string { + switch op { + case expr.CmpOpEq: + return "==" + case expr.CmpOpNeq: + return "!=" + case expr.CmpOpLt: + return "<" + case expr.CmpOpLte: + return "<=" + case expr.CmpOpGt: + return ">" + case expr.CmpOpGte: + return ">=" + default: + return fmt.Sprintf("op-%d", op) + } +} + +// formatExpr formats an expression in nft-like syntax +func formatExpr(exp expr.Any) string { + switch e := exp.(type) { + case *expr.Meta: + return formatMeta(e) + case *expr.Cmp: + return formatCmp(e) + case *expr.Payload: + return formatPayload(e) + case *expr.Verdict: + return formatVerdict(e) + case *expr.Counter: + return fmt.Sprintf("counter packets %d bytes %d", e.Packets, e.Bytes) + case *expr.Masq: + return "masquerade" + case *expr.NAT: + return formatNat(e) + case *expr.Match: + return formatMatch(e) + case *expr.Queue: + return fmt.Sprintf("queue num %d", e.Num) + case *expr.Lookup: + return fmt.Sprintf("@%s", e.SetName) + case *expr.Bitwise: + return formatBitwise(e) + case *expr.Fib: + return formatFib(e) + case *expr.Target: + return fmt.Sprintf("jump %s", e.Name) // Properly format jump targets + case *expr.Immediate: + if e.Register == 1 { + return formatImmediateData(e.Data) + } + return fmt.Sprintf("immediate %v", e.Data) + default: + return fmt.Sprintf("<%T>", exp) + } +} + +func formatImmediateData(data []byte) string { + // For IP addresses (4 bytes) + if len(data) == 4 { 
+ return fmt.Sprintf("%d.%d.%d.%d", data[0], data[1], data[2], data[3]) + } + return fmt.Sprintf("%v", data) +} + +func formatMeta(e *expr.Meta) string { + // Handle source register case first (meta mark set) + if e.SourceRegister { + return fmt.Sprintf("meta %s set reg %d", formatMetaKey(e.Key), e.Register) + } + + // For interface names, handle register load operation + switch e.Key { + case expr.MetaKeyIIFNAME, + expr.MetaKeyOIFNAME, + expr.MetaKeyBRIIIFNAME, + expr.MetaKeyBRIOIFNAME: + // Simply the key name with no register reference + return formatMetaKey(e.Key) + + case expr.MetaKeyMARK: + // For mark operations, we want just "mark" + return "mark" + } + + // For other meta keys, show as loading into register + return fmt.Sprintf("meta %s => reg %d", formatMetaKey(e.Key), e.Register) +} + +func formatMetaKey(key expr.MetaKey) string { + switch key { + case expr.MetaKeyLEN: + return "length" + case expr.MetaKeyPROTOCOL: + return "protocol" + case expr.MetaKeyPRIORITY: + return "priority" + case expr.MetaKeyMARK: + return "mark" + case expr.MetaKeyIIF: + return "iif" + case expr.MetaKeyOIF: + return "oif" + case expr.MetaKeyIIFNAME: + return "iifname" + case expr.MetaKeyOIFNAME: + return "oifname" + case expr.MetaKeyIIFTYPE: + return "iiftype" + case expr.MetaKeyOIFTYPE: + return "oiftype" + case expr.MetaKeySKUID: + return "skuid" + case expr.MetaKeySKGID: + return "skgid" + case expr.MetaKeyNFTRACE: + return "nftrace" + case expr.MetaKeyRTCLASSID: + return "rtclassid" + case expr.MetaKeySECMARK: + return "secmark" + case expr.MetaKeyNFPROTO: + return "nfproto" + case expr.MetaKeyL4PROTO: + return "l4proto" + case expr.MetaKeyBRIIIFNAME: + return "briifname" + case expr.MetaKeyBRIOIFNAME: + return "broifname" + case expr.MetaKeyPKTTYPE: + return "pkttype" + case expr.MetaKeyCPU: + return "cpu" + case expr.MetaKeyIIFGROUP: + return "iifgroup" + case expr.MetaKeyOIFGROUP: + return "oifgroup" + case expr.MetaKeyCGROUP: + return "cgroup" + case expr.MetaKeyPRANDOM: + return "prandom" + default: + return fmt.Sprintf("meta-%d", key) + } +} + +func formatCmp(e *expr.Cmp) string { + ops := map[expr.CmpOp]string{ + expr.CmpOpEq: "==", + expr.CmpOpNeq: "!=", + expr.CmpOpLt: "<", + expr.CmpOpLte: "<=", + expr.CmpOpGt: ">", + expr.CmpOpGte: ">=", + } + return fmt.Sprintf("%s %v", ops[e.Op], e.Data) +} + +func formatPayload(e *expr.Payload) string { + var proto string + switch e.Base { + case expr.PayloadBaseNetworkHeader: + proto = "ip" + case expr.PayloadBaseTransportHeader: + proto = "tcp" + default: + proto = fmt.Sprintf("payload-%d", e.Base) + } + return fmt.Sprintf("%s reg%d [%d:%d]", proto, e.DestRegister, e.Offset, e.Len) +} + +func formatVerdict(e *expr.Verdict) string { + switch e.Kind { + case expr.VerdictAccept: + return "accept" + case expr.VerdictDrop: + return "drop" + case expr.VerdictJump: + return fmt.Sprintf("jump %s", e.Chain) + case expr.VerdictGoto: + return fmt.Sprintf("goto %s", e.Chain) + case expr.VerdictReturn: + return "return" + default: + return fmt.Sprintf("verdict-%d", e.Kind) + } +} + +func formatNat(e *expr.NAT) string { + switch e.Type { + case expr.NATTypeSourceNAT: + return "snat" + case expr.NATTypeDestNAT: + return "dnat" + default: + return fmt.Sprintf("nat-%d", e.Type) + } +} + +func formatMatch(e *expr.Match) string { + return fmt.Sprintf("match %s rev %d", e.Name, e.Rev) +} + +func formatBitwise(e *expr.Bitwise) string { + return fmt.Sprintf("bitwise reg%d = reg%d & %v ^ %v", + e.DestRegister, e.SourceRegister, e.Mask, e.Xor) +} + +func formatFib(e 
*expr.Fib) string { + var flags []string + if e.FlagSADDR { + flags = append(flags, "saddr") + } + if e.FlagDADDR { + flags = append(flags, "daddr") + } + if e.FlagMARK { + flags = append(flags, "mark") + } + if e.FlagIIF { + flags = append(flags, "iif") + } + if e.FlagOIF { + flags = append(flags, "oif") + } + if e.ResultADDRTYPE { + flags = append(flags, "type") + } + return fmt.Sprintf("fib reg%d %s", e.Register, strings.Join(flags, ",")) +} + +func formatSet(conn *nftables.Conn, set *nftables.Set) string { + var builder strings.Builder + builder.WriteString(fmt.Sprintf("\tset %s {\n", set.Name)) + builder.WriteString(fmt.Sprintf("\t\ttype %s\n", formatSetKeyType(set.KeyType))) + if set.ID > 0 { + builder.WriteString(fmt.Sprintf("\t\t# handle %d\n", set.ID)) + } + + elements, err := conn.GetSetElements(set) + if err != nil { + log.Warnf("Failed to get elements for set %s: %v", set.Name, err) + } else if len(elements) > 0 { + builder.WriteString("\t\telements = {") + for i, elem := range elements { + if i > 0 { + builder.WriteString(", ") + } + builder.WriteString(fmt.Sprintf("%v", elem.Key)) + } + builder.WriteString("}\n") + } + + builder.WriteString("\t}\n") + return builder.String() +} + +func formatSetKeyType(keyType nftables.SetDatatype) string { + switch keyType { + case nftables.TypeInvalid: + return "invalid" + case nftables.TypeIPAddr: + return "ipv4_addr" + case nftables.TypeIP6Addr: + return "ipv6_addr" + case nftables.TypeEtherAddr: + return "ether_addr" + case nftables.TypeInetProto: + return "inet_proto" + case nftables.TypeInetService: + return "inet_service" + case nftables.TypeMark: + return "mark" + default: + return fmt.Sprintf("type-%v", keyType) + } +} diff --git a/client/server/debug_nonlinux.go b/client/server/debug_nonlinux.go new file mode 100644 index 000000000..c54ac9b6e --- /dev/null +++ b/client/server/debug_nonlinux.go @@ -0,0 +1,15 @@ +//go:build !linux || android + +package server + +import ( + "archive/zip" + + "github.com/netbirdio/netbird/client/anonymize" + "github.com/netbirdio/netbird/client/proto" +) + +// collectFirewallRules returns nothing on non-linux systems +func (s *Server) addFirewallRules(req *proto.DebugBundleRequest, anonymizer *anonymize.Anonymizer, archive *zip.Writer) error { + return nil +} diff --git a/client/server/debug_test.go b/client/server/debug_test.go index c8f7bae5d..ebd0bffbc 100644 --- a/client/server/debug_test.go +++ b/client/server/debug_test.go @@ -428,3 +428,116 @@ func isInCGNATRange(ip net.IP) bool { } return cgnat.Contains(ip) } + +func TestAnonymizeFirewallRules(t *testing.T) { + // TODO: Add ipv6 + + // Example iptables-save output + iptablesSave := `# Generated by iptables-save v1.8.7 on Thu Dec 19 10:00:00 2024 +*filter +:INPUT ACCEPT [0:0] +:FORWARD ACCEPT [0:0] +:OUTPUT ACCEPT [0:0] +-A INPUT -s 192.168.1.0/24 -j ACCEPT +-A INPUT -s 44.192.140.1/32 -j DROP +-A FORWARD -s 10.0.0.0/8 -j DROP +-A FORWARD -s 44.192.140.0/24 -d 52.84.12.34/24 -j ACCEPT +COMMIT + +*nat +:PREROUTING ACCEPT [0:0] +:INPUT ACCEPT [0:0] +:OUTPUT ACCEPT [0:0] +:POSTROUTING ACCEPT [0:0] +-A POSTROUTING -s 192.168.100.0/24 -j MASQUERADE +-A PREROUTING -d 44.192.140.10/32 -p tcp -m tcp --dport 80 -j DNAT --to-destination 192.168.1.10:80 +COMMIT` + + // Example iptables -v -n -L output + iptablesVerbose := `Chain INPUT (policy ACCEPT 0 packets, 0 bytes) + pkts bytes target prot opt in out source destination + 0 0 ACCEPT all -- * * 192.168.1.0/24 0.0.0.0/0 + 100 1024 DROP all -- * * 44.192.140.1 0.0.0.0/0 + +Chain FORWARD (policy ACCEPT 0 
packets, 0 bytes) + pkts bytes target prot opt in out source destination + 0 0 DROP all -- * * 10.0.0.0/8 0.0.0.0/0 + 25 256 ACCEPT all -- * * 44.192.140.0/24 52.84.12.34/24 + +Chain OUTPUT (policy ACCEPT 0 packets, 0 bytes) + pkts bytes target prot opt in out source destination` + + // Example nftables output + nftablesRules := `table inet filter { + chain input { + type filter hook input priority filter; policy accept; + ip saddr 192.168.1.1 accept + ip saddr 44.192.140.1 drop + } + chain forward { + type filter hook forward priority filter; policy accept; + ip saddr 10.0.0.0/8 drop + ip saddr 44.192.140.0/24 ip daddr 52.84.12.34/24 accept + } + }` + + anonymizer := anonymize.NewAnonymizer(anonymize.DefaultAddresses()) + + // Test iptables-save anonymization + anonIptablesSave := anonymizer.AnonymizeString(iptablesSave) + + // Private IP addresses should remain unchanged + assert.Contains(t, anonIptablesSave, "192.168.1.0/24") + assert.Contains(t, anonIptablesSave, "10.0.0.0/8") + assert.Contains(t, anonIptablesSave, "192.168.100.0/24") + assert.Contains(t, anonIptablesSave, "192.168.1.10") + + // Public IP addresses should be anonymized to the default range + assert.NotContains(t, anonIptablesSave, "44.192.140.1") + assert.NotContains(t, anonIptablesSave, "44.192.140.0/24") + assert.NotContains(t, anonIptablesSave, "52.84.12.34") + assert.Contains(t, anonIptablesSave, "198.51.100.") // Default anonymous range + + // Structure should be preserved + assert.Contains(t, anonIptablesSave, "*filter") + assert.Contains(t, anonIptablesSave, ":INPUT ACCEPT [0:0]") + assert.Contains(t, anonIptablesSave, "COMMIT") + assert.Contains(t, anonIptablesSave, "-j MASQUERADE") + assert.Contains(t, anonIptablesSave, "--dport 80") + + // Test iptables verbose output anonymization + anonIptablesVerbose := anonymizer.AnonymizeString(iptablesVerbose) + + // Private IP addresses should remain unchanged + assert.Contains(t, anonIptablesVerbose, "192.168.1.0/24") + assert.Contains(t, anonIptablesVerbose, "10.0.0.0/8") + + // Public IP addresses should be anonymized to the default range + assert.NotContains(t, anonIptablesVerbose, "44.192.140.1") + assert.NotContains(t, anonIptablesVerbose, "44.192.140.0/24") + assert.NotContains(t, anonIptablesVerbose, "52.84.12.34") + assert.Contains(t, anonIptablesVerbose, "198.51.100.") // Default anonymous range + + // Structure and counters should be preserved + assert.Contains(t, anonIptablesVerbose, "Chain INPUT (policy ACCEPT 0 packets, 0 bytes)") + assert.Contains(t, anonIptablesVerbose, "100 1024 DROP") + assert.Contains(t, anonIptablesVerbose, "pkts bytes target") + + // Test nftables anonymization + anonNftables := anonymizer.AnonymizeString(nftablesRules) + + // Private IP addresses should remain unchanged + assert.Contains(t, anonNftables, "192.168.1.1") + assert.Contains(t, anonNftables, "10.0.0.0/8") + + // Public IP addresses should be anonymized to the default range + assert.NotContains(t, anonNftables, "44.192.140.1") + assert.NotContains(t, anonNftables, "44.192.140.0/24") + assert.NotContains(t, anonNftables, "52.84.12.34") + assert.Contains(t, anonNftables, "198.51.100.") // Default anonymous range + + // Structure should be preserved + assert.Contains(t, anonNftables, "table inet filter {") + assert.Contains(t, anonNftables, "chain input {") + assert.Contains(t, anonNftables, "type filter hook input priority filter; policy accept;") +} From ad9f044aadadf6a77facd418373e9aad32095443 Mon Sep 17 00:00:00 2001 From: Viktor Liu 
<17948409+lixmal@users.noreply.github.com> Date: Mon, 23 Dec 2024 18:22:17 +0100 Subject: [PATCH 093/159] [client] Add stateful userspace firewall and remove egress filters (#3093) - Add stateful firewall functionality for UDP/TCP/ICMP in userspace firewalll - Removes all egress drop rules/filters, still needs refactoring so we don't add output rules to any chains/filters. - on Linux, if the OUTPUT policy is DROP then we don't do anything about it (no extra allow rules). This is up to the user, if they don't want anything leaving their machine they'll have to manage these rules explicitly. --- client/firewall/iptables/acl_linux.go | 6 - client/firewall/iptables/manager_linux.go | 14 +- client/firewall/nftables/acl_linux.go | 53 - client/firewall/uspfilter/allow_netbird.go | 20 +- .../uspfilter/allow_netbird_windows.go | 16 + client/firewall/uspfilter/conntrack/common.go | 138 +++ .../uspfilter/conntrack/common_test.go | 115 ++ client/firewall/uspfilter/conntrack/icmp.go | 170 +++ .../firewall/uspfilter/conntrack/icmp_test.go | 39 + client/firewall/uspfilter/conntrack/tcp.go | 376 +++++++ .../firewall/uspfilter/conntrack/tcp_test.go | 311 ++++++ client/firewall/uspfilter/conntrack/udp.go | 158 +++ .../firewall/uspfilter/conntrack/udp_test.go | 243 +++++ client/firewall/uspfilter/uspfilter.go | 258 ++++- .../uspfilter/uspfilter_bench_test.go | 998 ++++++++++++++++++ client/firewall/uspfilter/uspfilter_test.go | 309 +++++- 16 files changed, 3104 insertions(+), 120 deletions(-) create mode 100644 client/firewall/uspfilter/conntrack/common.go create mode 100644 client/firewall/uspfilter/conntrack/common_test.go create mode 100644 client/firewall/uspfilter/conntrack/icmp.go create mode 100644 client/firewall/uspfilter/conntrack/icmp_test.go create mode 100644 client/firewall/uspfilter/conntrack/tcp.go create mode 100644 client/firewall/uspfilter/conntrack/tcp_test.go create mode 100644 client/firewall/uspfilter/conntrack/udp.go create mode 100644 client/firewall/uspfilter/conntrack/udp_test.go create mode 100644 client/firewall/uspfilter/uspfilter_bench_test.go diff --git a/client/firewall/iptables/acl_linux.go b/client/firewall/iptables/acl_linux.go index 1c0527ebc..d774f4538 100644 --- a/client/firewall/iptables/acl_linux.go +++ b/client/firewall/iptables/acl_linux.go @@ -332,18 +332,12 @@ func (m *aclManager) createDefaultChains() error { // The OUTPUT chain gets an extra rule to allow traffic to any set up routes, the return traffic is handled by the INPUT related/established rule. 
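The userspace filter introduced in this patch applies the same established-state idea as the INPUT related/established rule above, but through its own conntrack package: outbound packets are tracked, and inbound packets are only admitted when they match tracked state, which is what makes the explicit egress and return-traffic rules removable. A minimal sketch of that flow using the UDP tracker added later in this patch (peer addresses and ports are illustrative):

package main

import (
	"fmt"
	"net"

	"github.com/netbirdio/netbird/client/firewall/uspfilter/conntrack"
)

func main() {
	tracker := conntrack.NewUDPTracker(conntrack.DefaultUDPTimeout)
	defer tracker.Close()

	local := net.ParseIP("100.64.0.1")  // local peer (illustrative)
	remote := net.ParseIP("100.64.0.2") // remote peer (illustrative)

	// An outbound DNS query is recorded instead of being matched by an egress rule.
	tracker.TrackOutbound(local, remote, 54321, 53)

	// The reply matches the tracked connection and is allowed.
	fmt.Println(tracker.IsValidInbound(remote, local, 53, 54321)) // true

	// An unsolicited inbound packet matches no tracked state and is dropped.
	fmt.Println(tracker.IsValidInbound(remote, local, 53, 12345)) // false
}

The TCP and ICMP trackers added alongside follow the same shape, with the TCP tracker additionally validating flag combinations and state transitions before accepting inbound packets.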
func (m *aclManager) seedInitialEntries() { - established := getConntrackEstablished() m.appendToEntries("INPUT", []string{"-i", m.wgIface.Name(), "-j", "DROP"}) m.appendToEntries("INPUT", []string{"-i", m.wgIface.Name(), "-j", chainNameInputRules}) m.appendToEntries("INPUT", append([]string{"-i", m.wgIface.Name()}, established...)) - m.appendToEntries("OUTPUT", []string{"-o", m.wgIface.Name(), "-j", "DROP"}) - m.appendToEntries("OUTPUT", []string{"-o", m.wgIface.Name(), "-j", chainNameOutputRules}) - m.appendToEntries("OUTPUT", []string{"-o", m.wgIface.Name(), "!", "-d", m.wgIface.Address().String(), "-j", "ACCEPT"}) - m.appendToEntries("OUTPUT", append([]string{"-o", m.wgIface.Name()}, established...)) - m.appendToEntries("FORWARD", []string{"-i", m.wgIface.Name(), "-j", "DROP"}) m.appendToEntries("FORWARD", []string{"-i", m.wgIface.Name(), "-j", m.routingFwChainName}) m.appendToEntries("FORWARD", append([]string{"-o", m.wgIface.Name()}, established...)) diff --git a/client/firewall/iptables/manager_linux.go b/client/firewall/iptables/manager_linux.go index adb8f20ef..0e1e5836f 100644 --- a/client/firewall/iptables/manager_linux.go +++ b/client/firewall/iptables/manager_linux.go @@ -207,19 +207,9 @@ func (m *Manager) AllowNetbird() error { "", ) if err != nil { - return fmt.Errorf("failed to allow netbird interface traffic: %w", err) + return fmt.Errorf("allow netbird interface traffic: %w", err) } - _, err = m.AddPeerFiltering( - net.ParseIP("0.0.0.0"), - "all", - nil, - nil, - firewall.RuleDirectionOUT, - firewall.ActionAccept, - "", - "", - ) - return err + return nil } // Flush doesn't need to be implemented for this manager diff --git a/client/firewall/nftables/acl_linux.go b/client/firewall/nftables/acl_linux.go index abe890fb9..852cfec8d 100644 --- a/client/firewall/nftables/acl_linux.go +++ b/client/firewall/nftables/acl_linux.go @@ -5,7 +5,6 @@ import ( "encoding/binary" "fmt" "net" - "net/netip" "strconv" "strings" "time" @@ -28,7 +27,6 @@ const ( // filter chains contains the rules that jump to the rules chains chainNameInputFilter = "netbird-acl-input-filter" - chainNameOutputFilter = "netbird-acl-output-filter" chainNameForwardFilter = "netbird-acl-forward-filter" chainNamePrerouting = "netbird-rt-prerouting" @@ -441,18 +439,6 @@ func (m *AclManager) createDefaultChains() (err error) { return err } - // netbird-acl-output-filter - // type filter hook output priority filter; policy accept; - chain = m.createFilterChainWithHook(chainNameOutputFilter, nftables.ChainHookOutput) - m.addFwdAllow(chain, expr.MetaKeyOIFNAME) - m.addJumpRule(chain, m.chainOutputRules.Name, expr.MetaKeyOIFNAME) // to netbird-acl-output-rules - m.addDropExpressions(chain, expr.MetaKeyOIFNAME) - err = m.rConn.Flush() - if err != nil { - log.Debugf("failed to create chain (%s): %s", chainNameOutputFilter, err) - return err - } - // netbird-acl-forward-filter chainFwFilter := m.createFilterChainWithHook(chainNameForwardFilter, nftables.ChainHookForward) m.addJumpRulesToRtForward(chainFwFilter) // to netbird-rt-fwd @@ -619,45 +605,6 @@ func (m *AclManager) addDropExpressions(chain *nftables.Chain, ifaceKey expr.Met return nil } -func (m *AclManager) addFwdAllow(chain *nftables.Chain, iifname expr.MetaKey) { - ip, _ := netip.AddrFromSlice(m.wgIface.Address().Network.IP.To4()) - dstOp := expr.CmpOpNeq - expressions := []expr.Any{ - &expr.Meta{Key: iifname, Register: 1}, - &expr.Cmp{ - Op: expr.CmpOpEq, - Register: 1, - Data: ifname(m.wgIface.Name()), - }, - &expr.Payload{ - DestRegister: 2, - Base: 
expr.PayloadBaseNetworkHeader, - Offset: 16, - Len: 4, - }, - &expr.Bitwise{ - SourceRegister: 2, - DestRegister: 2, - Len: 4, - Xor: []byte{0x0, 0x0, 0x0, 0x0}, - Mask: m.wgIface.Address().Network.Mask, - }, - &expr.Cmp{ - Op: dstOp, - Register: 2, - Data: ip.Unmap().AsSlice(), - }, - &expr.Verdict{ - Kind: expr.VerdictAccept, - }, - } - _ = m.rConn.AddRule(&nftables.Rule{ - Table: chain.Table, - Chain: chain, - Exprs: expressions, - }) -} - func (m *AclManager) addJumpRule(chain *nftables.Chain, to string, ifaceKey expr.MetaKey) { expressions := []expr.Any{ &expr.Meta{Key: ifaceKey, Register: 1}, diff --git a/client/firewall/uspfilter/allow_netbird.go b/client/firewall/uspfilter/allow_netbird.go index cefc81a3c..cc0792255 100644 --- a/client/firewall/uspfilter/allow_netbird.go +++ b/client/firewall/uspfilter/allow_netbird.go @@ -2,7 +2,10 @@ package uspfilter -import "github.com/netbirdio/netbird/client/internal/statemanager" +import ( + "github.com/netbirdio/netbird/client/firewall/uspfilter/conntrack" + "github.com/netbirdio/netbird/client/internal/statemanager" +) // Reset firewall to the default state func (m *Manager) Reset(stateManager *statemanager.Manager) error { @@ -12,6 +15,21 @@ func (m *Manager) Reset(stateManager *statemanager.Manager) error { m.outgoingRules = make(map[string]RuleSet) m.incomingRules = make(map[string]RuleSet) + if m.udpTracker != nil { + m.udpTracker.Close() + m.udpTracker = conntrack.NewUDPTracker(conntrack.DefaultUDPTimeout) + } + + if m.icmpTracker != nil { + m.icmpTracker.Close() + m.icmpTracker = conntrack.NewICMPTracker(conntrack.DefaultICMPTimeout) + } + + if m.tcpTracker != nil { + m.tcpTracker.Close() + m.tcpTracker = conntrack.NewTCPTracker(conntrack.DefaultTCPTimeout) + } + if m.nativeFirewall != nil { return m.nativeFirewall.Reset(stateManager) } diff --git a/client/firewall/uspfilter/allow_netbird_windows.go b/client/firewall/uspfilter/allow_netbird_windows.go index d3732301e..0d55d6268 100644 --- a/client/firewall/uspfilter/allow_netbird_windows.go +++ b/client/firewall/uspfilter/allow_netbird_windows.go @@ -7,6 +7,7 @@ import ( log "github.com/sirupsen/logrus" + "github.com/netbirdio/netbird/client/firewall/uspfilter/conntrack" "github.com/netbirdio/netbird/client/internal/statemanager" ) @@ -26,6 +27,21 @@ func (m *Manager) Reset(*statemanager.Manager) error { m.outgoingRules = make(map[string]RuleSet) m.incomingRules = make(map[string]RuleSet) + if m.udpTracker != nil { + m.udpTracker.Close() + m.udpTracker = conntrack.NewUDPTracker(conntrack.DefaultUDPTimeout) + } + + if m.icmpTracker != nil { + m.icmpTracker.Close() + m.icmpTracker = conntrack.NewICMPTracker(conntrack.DefaultICMPTimeout) + } + + if m.tcpTracker != nil { + m.tcpTracker.Close() + m.tcpTracker = conntrack.NewTCPTracker(conntrack.DefaultTCPTimeout) + } + if !isWindowsFirewallReachable() { return nil } diff --git a/client/firewall/uspfilter/conntrack/common.go b/client/firewall/uspfilter/conntrack/common.go new file mode 100644 index 000000000..a4b1971bf --- /dev/null +++ b/client/firewall/uspfilter/conntrack/common.go @@ -0,0 +1,138 @@ +// common.go +package conntrack + +import ( + "net" + "sync" + "sync/atomic" + "time" +) + +// BaseConnTrack provides common fields and locking for all connection types +type BaseConnTrack struct { + sync.RWMutex + SourceIP net.IP + DestIP net.IP + SourcePort uint16 + DestPort uint16 + lastSeen atomic.Int64 // Unix nano for atomic access + established atomic.Bool +} + +// these small methods will be inlined by the compiler + +// 
UpdateLastSeen safely updates the last seen timestamp +func (b *BaseConnTrack) UpdateLastSeen() { + b.lastSeen.Store(time.Now().UnixNano()) +} + +// IsEstablished safely checks if connection is established +func (b *BaseConnTrack) IsEstablished() bool { + return b.established.Load() +} + +// SetEstablished safely sets the established state +func (b *BaseConnTrack) SetEstablished(state bool) { + b.established.Store(state) +} + +// GetLastSeen safely gets the last seen timestamp +func (b *BaseConnTrack) GetLastSeen() time.Time { + return time.Unix(0, b.lastSeen.Load()) +} + +// timeoutExceeded checks if the connection has exceeded the given timeout +func (b *BaseConnTrack) timeoutExceeded(timeout time.Duration) bool { + lastSeen := time.Unix(0, b.lastSeen.Load()) + return time.Since(lastSeen) > timeout +} + +// IPAddr is a fixed-size IP address to avoid allocations +type IPAddr [16]byte + +// MakeIPAddr creates an IPAddr from net.IP +func MakeIPAddr(ip net.IP) (addr IPAddr) { + // Optimization: check for v4 first as it's more common + if ip4 := ip.To4(); ip4 != nil { + copy(addr[12:], ip4) + } else { + copy(addr[:], ip.To16()) + } + return addr +} + +// ConnKey uniquely identifies a connection +type ConnKey struct { + SrcIP IPAddr + DstIP IPAddr + SrcPort uint16 + DstPort uint16 +} + +// makeConnKey creates a connection key +func makeConnKey(srcIP net.IP, dstIP net.IP, srcPort uint16, dstPort uint16) ConnKey { + return ConnKey{ + SrcIP: MakeIPAddr(srcIP), + DstIP: MakeIPAddr(dstIP), + SrcPort: srcPort, + DstPort: dstPort, + } +} + +// ValidateIPs checks if IPs match without allocation +func ValidateIPs(connIP IPAddr, pktIP net.IP) bool { + if ip4 := pktIP.To4(); ip4 != nil { + // Compare IPv4 addresses (last 4 bytes) + for i := 0; i < 4; i++ { + if connIP[12+i] != ip4[i] { + return false + } + } + return true + } + // Compare full IPv6 addresses + ip6 := pktIP.To16() + for i := 0; i < 16; i++ { + if connIP[i] != ip6[i] { + return false + } + } + return true +} + +// PreallocatedIPs is a pool of IP byte slices to reduce allocations +type PreallocatedIPs struct { + sync.Pool +} + +// NewPreallocatedIPs creates a new IP pool +func NewPreallocatedIPs() *PreallocatedIPs { + return &PreallocatedIPs{ + Pool: sync.Pool{ + New: func() interface{} { + ip := make(net.IP, 16) + return &ip + }, + }, + } +} + +// Get retrieves an IP from the pool +func (p *PreallocatedIPs) Get() net.IP { + return *p.Pool.Get().(*net.IP) +} + +// Put returns an IP to the pool +func (p *PreallocatedIPs) Put(ip net.IP) { + p.Pool.Put(&ip) +} + +// copyIP copies an IP address efficiently +func copyIP(dst, src net.IP) { + if len(src) == 16 { + copy(dst, src) + } else { + // Handle IPv4 + copy(dst[12:], src.To4()) + } +} diff --git a/client/firewall/uspfilter/conntrack/common_test.go b/client/firewall/uspfilter/conntrack/common_test.go new file mode 100644 index 000000000..72d006def --- /dev/null +++ b/client/firewall/uspfilter/conntrack/common_test.go @@ -0,0 +1,115 @@ +package conntrack + +import ( + "net" + "testing" +) + +func BenchmarkIPOperations(b *testing.B) { + b.Run("MakeIPAddr", func(b *testing.B) { + ip := net.ParseIP("192.168.1.1") + b.ResetTimer() + for i := 0; i < b.N; i++ { + _ = MakeIPAddr(ip) + } + }) + + b.Run("ValidateIPs", func(b *testing.B) { + ip1 := net.ParseIP("192.168.1.1") + ip2 := net.ParseIP("192.168.1.1") + addr := MakeIPAddr(ip1) + b.ResetTimer() + for i := 0; i < b.N; i++ { + _ = ValidateIPs(addr, ip2) + } + }) + + b.Run("IPPool", func(b *testing.B) { + pool := NewPreallocatedIPs() + 
b.ResetTimer() + for i := 0; i < b.N; i++ { + ip := pool.Get() + pool.Put(ip) + } + }) + +} +func BenchmarkAtomicOperations(b *testing.B) { + conn := &BaseConnTrack{} + b.Run("UpdateLastSeen", func(b *testing.B) { + for i := 0; i < b.N; i++ { + conn.UpdateLastSeen() + } + }) + + b.Run("IsEstablished", func(b *testing.B) { + for i := 0; i < b.N; i++ { + _ = conn.IsEstablished() + } + }) + + b.Run("SetEstablished", func(b *testing.B) { + for i := 0; i < b.N; i++ { + conn.SetEstablished(i%2 == 0) + } + }) + + b.Run("GetLastSeen", func(b *testing.B) { + for i := 0; i < b.N; i++ { + _ = conn.GetLastSeen() + } + }) +} + +// Memory pressure tests +func BenchmarkMemoryPressure(b *testing.B) { + b.Run("TCPHighLoad", func(b *testing.B) { + tracker := NewTCPTracker(DefaultTCPTimeout) + defer tracker.Close() + + // Generate different IPs + srcIPs := make([]net.IP, 100) + dstIPs := make([]net.IP, 100) + for i := 0; i < 100; i++ { + srcIPs[i] = net.IPv4(192, 168, byte(i/256), byte(i%256)) + dstIPs[i] = net.IPv4(10, 0, byte(i/256), byte(i%256)) + } + + b.ResetTimer() + for i := 0; i < b.N; i++ { + srcIdx := i % len(srcIPs) + dstIdx := (i + 1) % len(dstIPs) + tracker.TrackOutbound(srcIPs[srcIdx], dstIPs[dstIdx], uint16(i%65535), 80, TCPSyn) + + // Simulate some valid inbound packets + if i%3 == 0 { + tracker.IsValidInbound(dstIPs[dstIdx], srcIPs[srcIdx], 80, uint16(i%65535), TCPAck) + } + } + }) + + b.Run("UDPHighLoad", func(b *testing.B) { + tracker := NewUDPTracker(DefaultUDPTimeout) + defer tracker.Close() + + // Generate different IPs + srcIPs := make([]net.IP, 100) + dstIPs := make([]net.IP, 100) + for i := 0; i < 100; i++ { + srcIPs[i] = net.IPv4(192, 168, byte(i/256), byte(i%256)) + dstIPs[i] = net.IPv4(10, 0, byte(i/256), byte(i%256)) + } + + b.ResetTimer() + for i := 0; i < b.N; i++ { + srcIdx := i % len(srcIPs) + dstIdx := (i + 1) % len(dstIPs) + tracker.TrackOutbound(srcIPs[srcIdx], dstIPs[dstIdx], uint16(i%65535), 80) + + // Simulate some valid inbound packets + if i%3 == 0 { + tracker.IsValidInbound(dstIPs[dstIdx], srcIPs[srcIdx], 80, uint16(i%65535)) + } + } + }) +} diff --git a/client/firewall/uspfilter/conntrack/icmp.go b/client/firewall/uspfilter/conntrack/icmp.go new file mode 100644 index 000000000..e0a971678 --- /dev/null +++ b/client/firewall/uspfilter/conntrack/icmp.go @@ -0,0 +1,170 @@ +package conntrack + +import ( + "net" + "sync" + "time" + + "github.com/google/gopacket/layers" +) + +const ( + // DefaultICMPTimeout is the default timeout for ICMP connections + DefaultICMPTimeout = 30 * time.Second + // ICMPCleanupInterval is how often we check for stale ICMP connections + ICMPCleanupInterval = 15 * time.Second +) + +// ICMPConnKey uniquely identifies an ICMP connection +type ICMPConnKey struct { + // Supports both IPv4 and IPv6 + SrcIP [16]byte + DstIP [16]byte + Sequence uint16 // ICMP sequence number + ID uint16 // ICMP identifier +} + +// ICMPConnTrack represents an ICMP connection state +type ICMPConnTrack struct { + BaseConnTrack + Sequence uint16 + ID uint16 +} + +// ICMPTracker manages ICMP connection states +type ICMPTracker struct { + connections map[ICMPConnKey]*ICMPConnTrack + timeout time.Duration + cleanupTicker *time.Ticker + mutex sync.RWMutex + done chan struct{} + ipPool *PreallocatedIPs +} + +// NewICMPTracker creates a new ICMP connection tracker +func NewICMPTracker(timeout time.Duration) *ICMPTracker { + if timeout == 0 { + timeout = DefaultICMPTimeout + } + + tracker := &ICMPTracker{ + connections: make(map[ICMPConnKey]*ICMPConnTrack), + timeout: timeout, + 
cleanupTicker: time.NewTicker(ICMPCleanupInterval), + done: make(chan struct{}), + ipPool: NewPreallocatedIPs(), + } + + go tracker.cleanupRoutine() + return tracker +} + +// TrackOutbound records an outbound ICMP Echo Request +func (t *ICMPTracker) TrackOutbound(srcIP net.IP, dstIP net.IP, id uint16, seq uint16) { + key := makeICMPKey(srcIP, dstIP, id, seq) + now := time.Now().UnixNano() + + t.mutex.Lock() + conn, exists := t.connections[key] + if !exists { + srcIPCopy := t.ipPool.Get() + dstIPCopy := t.ipPool.Get() + copyIP(srcIPCopy, srcIP) + copyIP(dstIPCopy, dstIP) + + conn = &ICMPConnTrack{ + BaseConnTrack: BaseConnTrack{ + SourceIP: srcIPCopy, + DestIP: dstIPCopy, + }, + ID: id, + Sequence: seq, + } + conn.lastSeen.Store(now) + conn.established.Store(true) + t.connections[key] = conn + } + t.mutex.Unlock() + + conn.lastSeen.Store(now) +} + +// IsValidInbound checks if an inbound ICMP Echo Reply matches a tracked request +func (t *ICMPTracker) IsValidInbound(srcIP net.IP, dstIP net.IP, id uint16, seq uint16, icmpType uint8) bool { + switch icmpType { + case uint8(layers.ICMPv4TypeDestinationUnreachable), + uint8(layers.ICMPv4TypeTimeExceeded): + return true + case uint8(layers.ICMPv4TypeEchoReply): + // continue processing + default: + return false + } + + key := makeICMPKey(dstIP, srcIP, id, seq) + + t.mutex.RLock() + conn, exists := t.connections[key] + t.mutex.RUnlock() + + if !exists { + return false + } + + if conn.timeoutExceeded(t.timeout) { + return false + } + + return conn.IsEstablished() && + ValidateIPs(MakeIPAddr(srcIP), conn.DestIP) && + ValidateIPs(MakeIPAddr(dstIP), conn.SourceIP) && + conn.ID == id && + conn.Sequence == seq +} + +func (t *ICMPTracker) cleanupRoutine() { + for { + select { + case <-t.cleanupTicker.C: + t.cleanup() + case <-t.done: + return + } + } +} +func (t *ICMPTracker) cleanup() { + t.mutex.Lock() + defer t.mutex.Unlock() + + for key, conn := range t.connections { + if conn.timeoutExceeded(t.timeout) { + t.ipPool.Put(conn.SourceIP) + t.ipPool.Put(conn.DestIP) + delete(t.connections, key) + } + } +} + +// Close stops the cleanup routine and releases resources +func (t *ICMPTracker) Close() { + t.cleanupTicker.Stop() + close(t.done) + + t.mutex.Lock() + for _, conn := range t.connections { + t.ipPool.Put(conn.SourceIP) + t.ipPool.Put(conn.DestIP) + } + t.connections = nil + t.mutex.Unlock() +} + +// makeICMPKey creates an ICMP connection key +func makeICMPKey(srcIP net.IP, dstIP net.IP, id uint16, seq uint16) ICMPConnKey { + return ICMPConnKey{ + SrcIP: MakeIPAddr(srcIP), + DstIP: MakeIPAddr(dstIP), + ID: id, + Sequence: seq, + } +} diff --git a/client/firewall/uspfilter/conntrack/icmp_test.go b/client/firewall/uspfilter/conntrack/icmp_test.go new file mode 100644 index 000000000..21176e719 --- /dev/null +++ b/client/firewall/uspfilter/conntrack/icmp_test.go @@ -0,0 +1,39 @@ +package conntrack + +import ( + "net" + "testing" +) + +func BenchmarkICMPTracker(b *testing.B) { + b.Run("TrackOutbound", func(b *testing.B) { + tracker := NewICMPTracker(DefaultICMPTimeout) + defer tracker.Close() + + srcIP := net.ParseIP("192.168.1.1") + dstIP := net.ParseIP("192.168.1.2") + + b.ResetTimer() + for i := 0; i < b.N; i++ { + tracker.TrackOutbound(srcIP, dstIP, uint16(i%65535), uint16(i%65535)) + } + }) + + b.Run("IsValidInbound", func(b *testing.B) { + tracker := NewICMPTracker(DefaultICMPTimeout) + defer tracker.Close() + + srcIP := net.ParseIP("192.168.1.1") + dstIP := net.ParseIP("192.168.1.2") + + // Pre-populate some connections + for i := 0; i < 1000; i++ 
{ + tracker.TrackOutbound(srcIP, dstIP, uint16(i), uint16(i)) + } + + b.ResetTimer() + for i := 0; i < b.N; i++ { + tracker.IsValidInbound(dstIP, srcIP, uint16(i%1000), uint16(i%1000), 0) + } + }) +} diff --git a/client/firewall/uspfilter/conntrack/tcp.go b/client/firewall/uspfilter/conntrack/tcp.go new file mode 100644 index 000000000..e8d20f41c --- /dev/null +++ b/client/firewall/uspfilter/conntrack/tcp.go @@ -0,0 +1,376 @@ +package conntrack + +// TODO: Send RST packets for invalid/timed-out connections + +import ( + "net" + "sync" + "time" +) + +const ( + // MSL (Maximum Segment Lifetime) is typically 2 minutes + MSL = 2 * time.Minute + // TimeWaitTimeout (TIME-WAIT) should last 2*MSL + TimeWaitTimeout = 2 * MSL +) + +const ( + TCPSyn uint8 = 0x02 + TCPAck uint8 = 0x10 + TCPFin uint8 = 0x01 + TCPRst uint8 = 0x04 + TCPPush uint8 = 0x08 + TCPUrg uint8 = 0x20 +) + +const ( + // DefaultTCPTimeout is the default timeout for established TCP connections + DefaultTCPTimeout = 3 * time.Hour + // TCPHandshakeTimeout is timeout for TCP handshake completion + TCPHandshakeTimeout = 60 * time.Second + // TCPCleanupInterval is how often we check for stale connections + TCPCleanupInterval = 5 * time.Minute +) + +// TCPState represents the state of a TCP connection +type TCPState int + +const ( + TCPStateNew TCPState = iota + TCPStateSynSent + TCPStateSynReceived + TCPStateEstablished + TCPStateFinWait1 + TCPStateFinWait2 + TCPStateClosing + TCPStateTimeWait + TCPStateCloseWait + TCPStateLastAck + TCPStateClosed +) + +// TCPConnKey uniquely identifies a TCP connection +type TCPConnKey struct { + SrcIP [16]byte + DstIP [16]byte + SrcPort uint16 + DstPort uint16 +} + +// TCPConnTrack represents a TCP connection state +type TCPConnTrack struct { + BaseConnTrack + State TCPState +} + +// TCPTracker manages TCP connection states +type TCPTracker struct { + connections map[ConnKey]*TCPConnTrack + mutex sync.RWMutex + cleanupTicker *time.Ticker + done chan struct{} + timeout time.Duration + ipPool *PreallocatedIPs +} + +// NewTCPTracker creates a new TCP connection tracker +func NewTCPTracker(timeout time.Duration) *TCPTracker { + tracker := &TCPTracker{ + connections: make(map[ConnKey]*TCPConnTrack), + cleanupTicker: time.NewTicker(TCPCleanupInterval), + done: make(chan struct{}), + timeout: timeout, + ipPool: NewPreallocatedIPs(), + } + + go tracker.cleanupRoutine() + return tracker +} + +// TrackOutbound processes an outbound TCP packet and updates connection state +func (t *TCPTracker) TrackOutbound(srcIP net.IP, dstIP net.IP, srcPort uint16, dstPort uint16, flags uint8) { + // Create key before lock + key := makeConnKey(srcIP, dstIP, srcPort, dstPort) + now := time.Now().UnixNano() + + t.mutex.Lock() + conn, exists := t.connections[key] + if !exists { + // Use preallocated IPs + srcIPCopy := t.ipPool.Get() + dstIPCopy := t.ipPool.Get() + copyIP(srcIPCopy, srcIP) + copyIP(dstIPCopy, dstIP) + + conn = &TCPConnTrack{ + BaseConnTrack: BaseConnTrack{ + SourceIP: srcIPCopy, + DestIP: dstIPCopy, + SourcePort: srcPort, + DestPort: dstPort, + }, + State: TCPStateNew, + } + conn.lastSeen.Store(now) + conn.established.Store(false) + t.connections[key] = conn + } + t.mutex.Unlock() + + // Lock individual connection for state update + conn.Lock() + t.updateState(conn, flags, true) + conn.Unlock() + conn.lastSeen.Store(now) +} + +// IsValidInbound checks if an inbound TCP packet matches a tracked connection +func (t *TCPTracker) IsValidInbound(srcIP net.IP, dstIP net.IP, srcPort uint16, dstPort uint16, flags uint8) 
bool { + if !isValidFlagCombination(flags) { + return false + } + + // Handle new SYN packets + if flags&TCPSyn != 0 && flags&TCPAck == 0 { + key := makeConnKey(dstIP, srcIP, dstPort, srcPort) + t.mutex.Lock() + if _, exists := t.connections[key]; !exists { + // Use preallocated IPs + srcIPCopy := t.ipPool.Get() + dstIPCopy := t.ipPool.Get() + copyIP(srcIPCopy, dstIP) + copyIP(dstIPCopy, srcIP) + + conn := &TCPConnTrack{ + BaseConnTrack: BaseConnTrack{ + SourceIP: srcIPCopy, + DestIP: dstIPCopy, + SourcePort: dstPort, + DestPort: srcPort, + }, + State: TCPStateSynReceived, + } + conn.lastSeen.Store(time.Now().UnixNano()) + conn.established.Store(false) + t.connections[key] = conn + } + t.mutex.Unlock() + return true + } + + // Look up existing connection + key := makeConnKey(dstIP, srcIP, dstPort, srcPort) + t.mutex.RLock() + conn, exists := t.connections[key] + t.mutex.RUnlock() + + if !exists { + return false + } + + // Handle RST packets + if flags&TCPRst != 0 { + conn.Lock() + isEstablished := conn.IsEstablished() + if isEstablished || conn.State == TCPStateSynSent || conn.State == TCPStateSynReceived { + conn.State = TCPStateClosed + conn.SetEstablished(false) + conn.Unlock() + return true + } + conn.Unlock() + return false + } + + // Update state + conn.Lock() + t.updateState(conn, flags, false) + conn.UpdateLastSeen() + isEstablished := conn.IsEstablished() + isValidState := t.isValidStateForFlags(conn.State, flags) + conn.Unlock() + + return isEstablished || isValidState +} + +// updateState updates the TCP connection state based on flags +func (t *TCPTracker) updateState(conn *TCPConnTrack, flags uint8, isOutbound bool) { + // Handle RST flag specially - it always causes transition to closed + if flags&TCPRst != 0 { + conn.State = TCPStateClosed + conn.SetEstablished(false) + return + } + + switch conn.State { + case TCPStateNew: + if flags&TCPSyn != 0 && flags&TCPAck == 0 { + conn.State = TCPStateSynSent + } + + case TCPStateSynSent: + if flags&TCPSyn != 0 && flags&TCPAck != 0 { + if isOutbound { + conn.State = TCPStateSynReceived + } else { + // Simultaneous open + conn.State = TCPStateEstablished + conn.SetEstablished(true) + } + } + + case TCPStateSynReceived: + if flags&TCPAck != 0 && flags&TCPSyn == 0 { + conn.State = TCPStateEstablished + conn.SetEstablished(true) + } + + case TCPStateEstablished: + if flags&TCPFin != 0 { + if isOutbound { + conn.State = TCPStateFinWait1 + } else { + conn.State = TCPStateCloseWait + } + conn.SetEstablished(false) + } + + case TCPStateFinWait1: + switch { + case flags&TCPFin != 0 && flags&TCPAck != 0: + // Simultaneous close - both sides sent FIN + conn.State = TCPStateClosing + case flags&TCPFin != 0: + conn.State = TCPStateFinWait2 + case flags&TCPAck != 0: + conn.State = TCPStateFinWait2 + } + + case TCPStateFinWait2: + if flags&TCPFin != 0 { + conn.State = TCPStateTimeWait + } + + case TCPStateClosing: + if flags&TCPAck != 0 { + conn.State = TCPStateTimeWait + // Keep established = false from previous state + } + + case TCPStateCloseWait: + if flags&TCPFin != 0 { + conn.State = TCPStateLastAck + } + + case TCPStateLastAck: + if flags&TCPAck != 0 { + conn.State = TCPStateClosed + } + + case TCPStateTimeWait: + // Stay in TIME-WAIT for 2MSL before transitioning to closed + // This is handled by the cleanup routine + } +} + +// isValidStateForFlags checks if the TCP flags are valid for the current connection state +func (t *TCPTracker) isValidStateForFlags(state TCPState, flags uint8) bool { + if !isValidFlagCombination(flags) { + return 
false + } + + switch state { + case TCPStateNew: + return flags&TCPSyn != 0 && flags&TCPAck == 0 + case TCPStateSynSent: + return flags&TCPSyn != 0 && flags&TCPAck != 0 + case TCPStateSynReceived: + return flags&TCPAck != 0 + case TCPStateEstablished: + if flags&TCPRst != 0 { + return true + } + return flags&TCPAck != 0 + case TCPStateFinWait1: + return flags&TCPFin != 0 || flags&TCPAck != 0 + case TCPStateFinWait2: + return flags&TCPFin != 0 || flags&TCPAck != 0 + case TCPStateClosing: + // In CLOSING state, we should accept the final ACK + return flags&TCPAck != 0 + case TCPStateTimeWait: + // In TIME_WAIT, we might see retransmissions + return flags&TCPAck != 0 + case TCPStateCloseWait: + return flags&TCPFin != 0 || flags&TCPAck != 0 + case TCPStateLastAck: + return flags&TCPAck != 0 + } + return false +} + +func (t *TCPTracker) cleanupRoutine() { + for { + select { + case <-t.cleanupTicker.C: + t.cleanup() + case <-t.done: + return + } + } +} + +func (t *TCPTracker) cleanup() { + t.mutex.Lock() + defer t.mutex.Unlock() + + for key, conn := range t.connections { + var timeout time.Duration + switch { + case conn.State == TCPStateTimeWait: + timeout = TimeWaitTimeout + case conn.IsEstablished(): + timeout = t.timeout + default: + timeout = TCPHandshakeTimeout + } + + lastSeen := conn.GetLastSeen() + if time.Since(lastSeen) > timeout { + // Return IPs to pool + t.ipPool.Put(conn.SourceIP) + t.ipPool.Put(conn.DestIP) + delete(t.connections, key) + } + } +} + +// Close stops the cleanup routine and releases resources +func (t *TCPTracker) Close() { + t.cleanupTicker.Stop() + close(t.done) + + // Clean up all remaining IPs + t.mutex.Lock() + for _, conn := range t.connections { + t.ipPool.Put(conn.SourceIP) + t.ipPool.Put(conn.DestIP) + } + t.connections = nil + t.mutex.Unlock() +} + +func isValidFlagCombination(flags uint8) bool { + // Invalid: SYN+FIN + if flags&TCPSyn != 0 && flags&TCPFin != 0 { + return false + } + + // Invalid: RST with SYN or FIN + if flags&TCPRst != 0 && (flags&TCPSyn != 0 || flags&TCPFin != 0) { + return false + } + + return true +} diff --git a/client/firewall/uspfilter/conntrack/tcp_test.go b/client/firewall/uspfilter/conntrack/tcp_test.go new file mode 100644 index 000000000..3933c8889 --- /dev/null +++ b/client/firewall/uspfilter/conntrack/tcp_test.go @@ -0,0 +1,311 @@ +package conntrack + +import ( + "net" + "testing" + "time" + + "github.com/stretchr/testify/require" +) + +func TestTCPStateMachine(t *testing.T) { + tracker := NewTCPTracker(DefaultTCPTimeout) + defer tracker.Close() + + srcIP := net.ParseIP("100.64.0.1") + dstIP := net.ParseIP("100.64.0.2") + srcPort := uint16(12345) + dstPort := uint16(80) + + t.Run("Security Tests", func(t *testing.T) { + tests := []struct { + name string + flags uint8 + wantDrop bool + desc string + }{ + { + name: "Block unsolicited SYN-ACK", + flags: TCPSyn | TCPAck, + wantDrop: true, + desc: "Should block SYN-ACK without prior SYN", + }, + { + name: "Block invalid SYN-FIN", + flags: TCPSyn | TCPFin, + wantDrop: true, + desc: "Should block invalid SYN-FIN combination", + }, + { + name: "Block unsolicited RST", + flags: TCPRst, + wantDrop: true, + desc: "Should block RST without connection", + }, + { + name: "Block unsolicited ACK", + flags: TCPAck, + wantDrop: true, + desc: "Should block ACK without connection", + }, + { + name: "Block data without connection", + flags: TCPAck | TCPPush, + wantDrop: true, + desc: "Should block data without established connection", + }, + } + + for _, tt := range tests { + t.Run(tt.name, 
func(t *testing.T) { + isValid := tracker.IsValidInbound(dstIP, srcIP, dstPort, srcPort, tt.flags) + require.Equal(t, !tt.wantDrop, isValid, tt.desc) + }) + } + }) + + t.Run("Connection Flow Tests", func(t *testing.T) { + tests := []struct { + name string + test func(*testing.T) + desc string + }{ + { + name: "Normal Handshake", + test: func(t *testing.T) { + t.Helper() + + // Send initial SYN + tracker.TrackOutbound(srcIP, dstIP, srcPort, dstPort, TCPSyn) + + // Receive SYN-ACK + valid := tracker.IsValidInbound(dstIP, srcIP, dstPort, srcPort, TCPSyn|TCPAck) + require.True(t, valid, "SYN-ACK should be allowed") + + // Send ACK + tracker.TrackOutbound(srcIP, dstIP, srcPort, dstPort, TCPAck) + + // Test data transfer + valid = tracker.IsValidInbound(dstIP, srcIP, dstPort, srcPort, TCPPush|TCPAck) + require.True(t, valid, "Data should be allowed after handshake") + }, + }, + { + name: "Normal Close", + test: func(t *testing.T) { + t.Helper() + + // First establish connection + establishConnection(t, tracker, srcIP, dstIP, srcPort, dstPort) + + // Send FIN + tracker.TrackOutbound(srcIP, dstIP, srcPort, dstPort, TCPFin|TCPAck) + + // Receive ACK for FIN + valid := tracker.IsValidInbound(dstIP, srcIP, dstPort, srcPort, TCPAck) + require.True(t, valid, "ACK for FIN should be allowed") + + // Receive FIN from other side + valid = tracker.IsValidInbound(dstIP, srcIP, dstPort, srcPort, TCPFin|TCPAck) + require.True(t, valid, "FIN should be allowed") + + // Send final ACK + tracker.TrackOutbound(srcIP, dstIP, srcPort, dstPort, TCPAck) + }, + }, + { + name: "RST During Connection", + test: func(t *testing.T) { + t.Helper() + + // First establish connection + establishConnection(t, tracker, srcIP, dstIP, srcPort, dstPort) + + // Receive RST + valid := tracker.IsValidInbound(dstIP, srcIP, dstPort, srcPort, TCPRst) + require.True(t, valid, "RST should be allowed for established connection") + + // Verify connection is closed + valid = tracker.IsValidInbound(dstIP, srcIP, dstPort, srcPort, TCPPush|TCPAck) + t.Helper() + + require.False(t, valid, "Data should be blocked after RST") + }, + }, + { + name: "Simultaneous Close", + test: func(t *testing.T) { + t.Helper() + + // First establish connection + establishConnection(t, tracker, srcIP, dstIP, srcPort, dstPort) + + // Both sides send FIN+ACK + tracker.TrackOutbound(srcIP, dstIP, srcPort, dstPort, TCPFin|TCPAck) + valid := tracker.IsValidInbound(dstIP, srcIP, dstPort, srcPort, TCPFin|TCPAck) + require.True(t, valid, "Simultaneous FIN should be allowed") + + // Both sides send final ACK + tracker.TrackOutbound(srcIP, dstIP, srcPort, dstPort, TCPAck) + valid = tracker.IsValidInbound(dstIP, srcIP, dstPort, srcPort, TCPAck) + require.True(t, valid, "Final ACKs should be allowed") + }, + }, + } + + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + t.Helper() + + tracker = NewTCPTracker(DefaultTCPTimeout) + tt.test(t) + }) + } + }) +} + +func TestRSTHandling(t *testing.T) { + tracker := NewTCPTracker(DefaultTCPTimeout) + defer tracker.Close() + + srcIP := net.ParseIP("100.64.0.1") + dstIP := net.ParseIP("100.64.0.2") + srcPort := uint16(12345) + dstPort := uint16(80) + + tests := []struct { + name string + setupState func() + sendRST func() + wantValid bool + desc string + }{ + { + name: "RST in established", + setupState: func() { + // Establish connection first + tracker.TrackOutbound(srcIP, dstIP, srcPort, dstPort, TCPSyn) + tracker.IsValidInbound(dstIP, srcIP, dstPort, srcPort, TCPSyn|TCPAck) + tracker.TrackOutbound(srcIP, dstIP, srcPort, 
dstPort, TCPAck) + }, + sendRST: func() { + tracker.IsValidInbound(dstIP, srcIP, dstPort, srcPort, TCPRst) + }, + wantValid: true, + desc: "Should accept RST for established connection", + }, + { + name: "RST without connection", + setupState: func() {}, + sendRST: func() { + tracker.IsValidInbound(dstIP, srcIP, dstPort, srcPort, TCPRst) + }, + wantValid: false, + desc: "Should reject RST without connection", + }, + } + + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + tt.setupState() + tt.sendRST() + + // Verify connection state is as expected + key := makeConnKey(srcIP, dstIP, srcPort, dstPort) + conn := tracker.connections[key] + if tt.wantValid { + require.NotNil(t, conn) + require.Equal(t, TCPStateClosed, conn.State) + require.False(t, conn.IsEstablished()) + } + }) + } +} + +// Helper to establish a TCP connection +func establishConnection(t *testing.T, tracker *TCPTracker, srcIP, dstIP net.IP, srcPort, dstPort uint16) { + t.Helper() + + tracker.TrackOutbound(srcIP, dstIP, srcPort, dstPort, TCPSyn) + + valid := tracker.IsValidInbound(dstIP, srcIP, dstPort, srcPort, TCPSyn|TCPAck) + require.True(t, valid, "SYN-ACK should be allowed") + + tracker.TrackOutbound(srcIP, dstIP, srcPort, dstPort, TCPAck) +} + +func BenchmarkTCPTracker(b *testing.B) { + b.Run("TrackOutbound", func(b *testing.B) { + tracker := NewTCPTracker(DefaultTCPTimeout) + defer tracker.Close() + + srcIP := net.ParseIP("192.168.1.1") + dstIP := net.ParseIP("192.168.1.2") + + b.ResetTimer() + for i := 0; i < b.N; i++ { + tracker.TrackOutbound(srcIP, dstIP, uint16(i%65535), 80, TCPSyn) + } + }) + + b.Run("IsValidInbound", func(b *testing.B) { + tracker := NewTCPTracker(DefaultTCPTimeout) + defer tracker.Close() + + srcIP := net.ParseIP("192.168.1.1") + dstIP := net.ParseIP("192.168.1.2") + + // Pre-populate some connections + for i := 0; i < 1000; i++ { + tracker.TrackOutbound(srcIP, dstIP, uint16(i), 80, TCPSyn) + } + + b.ResetTimer() + for i := 0; i < b.N; i++ { + tracker.IsValidInbound(dstIP, srcIP, 80, uint16(i%1000), TCPAck) + } + }) + + b.Run("ConcurrentAccess", func(b *testing.B) { + tracker := NewTCPTracker(DefaultTCPTimeout) + defer tracker.Close() + + srcIP := net.ParseIP("192.168.1.1") + dstIP := net.ParseIP("192.168.1.2") + + b.RunParallel(func(pb *testing.PB) { + i := 0 + for pb.Next() { + if i%2 == 0 { + tracker.TrackOutbound(srcIP, dstIP, uint16(i%65535), 80, TCPSyn) + } else { + tracker.IsValidInbound(dstIP, srcIP, 80, uint16(i%65535), TCPAck) + } + i++ + } + }) + }) +} + +// Benchmark connection cleanup +func BenchmarkCleanup(b *testing.B) { + b.Run("TCPCleanup", func(b *testing.B) { + tracker := NewTCPTracker(100 * time.Millisecond) // Short timeout for testing + defer tracker.Close() + + // Pre-populate with expired connections + srcIP := net.ParseIP("192.168.1.1") + dstIP := net.ParseIP("192.168.1.2") + for i := 0; i < 10000; i++ { + tracker.TrackOutbound(srcIP, dstIP, uint16(i), 80, TCPSyn) + } + + // Wait for connections to expire + time.Sleep(200 * time.Millisecond) + + b.ResetTimer() + for i := 0; i < b.N; i++ { + tracker.cleanup() + } + }) +} diff --git a/client/firewall/uspfilter/conntrack/udp.go b/client/firewall/uspfilter/conntrack/udp.go new file mode 100644 index 000000000..a969a4e84 --- /dev/null +++ b/client/firewall/uspfilter/conntrack/udp.go @@ -0,0 +1,158 @@ +package conntrack + +import ( + "net" + "sync" + "time" +) + +const ( + // DefaultUDPTimeout is the default timeout for UDP connections + DefaultUDPTimeout = 30 * time.Second + // UDPCleanupInterval is how often we 
check for stale connections + UDPCleanupInterval = 15 * time.Second +) + +// UDPConnTrack represents a UDP connection state +type UDPConnTrack struct { + BaseConnTrack +} + +// UDPTracker manages UDP connection states +type UDPTracker struct { + connections map[ConnKey]*UDPConnTrack + timeout time.Duration + cleanupTicker *time.Ticker + mutex sync.RWMutex + done chan struct{} + ipPool *PreallocatedIPs +} + +// NewUDPTracker creates a new UDP connection tracker +func NewUDPTracker(timeout time.Duration) *UDPTracker { + if timeout == 0 { + timeout = DefaultUDPTimeout + } + + tracker := &UDPTracker{ + connections: make(map[ConnKey]*UDPConnTrack), + timeout: timeout, + cleanupTicker: time.NewTicker(UDPCleanupInterval), + done: make(chan struct{}), + ipPool: NewPreallocatedIPs(), + } + + go tracker.cleanupRoutine() + return tracker +} + +// TrackOutbound records an outbound UDP connection +func (t *UDPTracker) TrackOutbound(srcIP net.IP, dstIP net.IP, srcPort uint16, dstPort uint16) { + key := makeConnKey(srcIP, dstIP, srcPort, dstPort) + now := time.Now().UnixNano() + + t.mutex.Lock() + conn, exists := t.connections[key] + if !exists { + srcIPCopy := t.ipPool.Get() + dstIPCopy := t.ipPool.Get() + copyIP(srcIPCopy, srcIP) + copyIP(dstIPCopy, dstIP) + + conn = &UDPConnTrack{ + BaseConnTrack: BaseConnTrack{ + SourceIP: srcIPCopy, + DestIP: dstIPCopy, + SourcePort: srcPort, + DestPort: dstPort, + }, + } + conn.lastSeen.Store(now) + conn.established.Store(true) + t.connections[key] = conn + } + t.mutex.Unlock() + + conn.lastSeen.Store(now) +} + +// IsValidInbound checks if an inbound packet matches a tracked connection +func (t *UDPTracker) IsValidInbound(srcIP net.IP, dstIP net.IP, srcPort uint16, dstPort uint16) bool { + key := makeConnKey(dstIP, srcIP, dstPort, srcPort) + + t.mutex.RLock() + conn, exists := t.connections[key] + t.mutex.RUnlock() + + if !exists { + return false + } + + if conn.timeoutExceeded(t.timeout) { + return false + } + + return conn.IsEstablished() && + ValidateIPs(MakeIPAddr(srcIP), conn.DestIP) && + ValidateIPs(MakeIPAddr(dstIP), conn.SourceIP) && + conn.DestPort == srcPort && + conn.SourcePort == dstPort +} + +// cleanupRoutine periodically removes stale connections +func (t *UDPTracker) cleanupRoutine() { + for { + select { + case <-t.cleanupTicker.C: + t.cleanup() + case <-t.done: + return + } + } +} + +func (t *UDPTracker) cleanup() { + t.mutex.Lock() + defer t.mutex.Unlock() + + for key, conn := range t.connections { + if conn.timeoutExceeded(t.timeout) { + t.ipPool.Put(conn.SourceIP) + t.ipPool.Put(conn.DestIP) + delete(t.connections, key) + } + } +} + +// Close stops the cleanup routine and releases resources +func (t *UDPTracker) Close() { + t.cleanupTicker.Stop() + close(t.done) + + t.mutex.Lock() + for _, conn := range t.connections { + t.ipPool.Put(conn.SourceIP) + t.ipPool.Put(conn.DestIP) + } + t.connections = nil + t.mutex.Unlock() +} + +// GetConnection safely retrieves a connection state +func (t *UDPTracker) GetConnection(srcIP net.IP, srcPort uint16, dstIP net.IP, dstPort uint16) (*UDPConnTrack, bool) { + t.mutex.RLock() + defer t.mutex.RUnlock() + + key := makeConnKey(srcIP, dstIP, srcPort, dstPort) + conn, exists := t.connections[key] + if !exists { + return nil, false + } + + return conn, true +} + +// Timeout returns the configured timeout duration for the tracker +func (t *UDPTracker) Timeout() time.Duration { + return t.timeout +} diff --git a/client/firewall/uspfilter/conntrack/udp_test.go b/client/firewall/uspfilter/conntrack/udp_test.go new 
file mode 100644 index 000000000..671721890 --- /dev/null +++ b/client/firewall/uspfilter/conntrack/udp_test.go @@ -0,0 +1,243 @@ +package conntrack + +import ( + "net" + "testing" + "time" + + "github.com/stretchr/testify/assert" + "github.com/stretchr/testify/require" +) + +func TestNewUDPTracker(t *testing.T) { + tests := []struct { + name string + timeout time.Duration + wantTimeout time.Duration + }{ + { + name: "with custom timeout", + timeout: 1 * time.Minute, + wantTimeout: 1 * time.Minute, + }, + { + name: "with zero timeout uses default", + timeout: 0, + wantTimeout: DefaultUDPTimeout, + }, + } + + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + tracker := NewUDPTracker(tt.timeout) + assert.NotNil(t, tracker) + assert.Equal(t, tt.wantTimeout, tracker.timeout) + assert.NotNil(t, tracker.connections) + assert.NotNil(t, tracker.cleanupTicker) + assert.NotNil(t, tracker.done) + }) + } +} + +func TestUDPTracker_TrackOutbound(t *testing.T) { + tracker := NewUDPTracker(DefaultUDPTimeout) + defer tracker.Close() + + srcIP := net.ParseIP("192.168.1.2") + dstIP := net.ParseIP("192.168.1.3") + srcPort := uint16(12345) + dstPort := uint16(53) + + tracker.TrackOutbound(srcIP, dstIP, srcPort, dstPort) + + // Verify connection was tracked + key := makeConnKey(srcIP, dstIP, srcPort, dstPort) + conn, exists := tracker.connections[key] + require.True(t, exists) + assert.True(t, conn.SourceIP.Equal(srcIP)) + assert.True(t, conn.DestIP.Equal(dstIP)) + assert.Equal(t, srcPort, conn.SourcePort) + assert.Equal(t, dstPort, conn.DestPort) + assert.True(t, conn.IsEstablished()) + assert.WithinDuration(t, time.Now(), conn.GetLastSeen(), 1*time.Second) +} + +func TestUDPTracker_IsValidInbound(t *testing.T) { + tracker := NewUDPTracker(1 * time.Second) + defer tracker.Close() + + srcIP := net.ParseIP("192.168.1.2") + dstIP := net.ParseIP("192.168.1.3") + srcPort := uint16(12345) + dstPort := uint16(53) + + // Track outbound connection + tracker.TrackOutbound(srcIP, dstIP, srcPort, dstPort) + + tests := []struct { + name string + srcIP net.IP + dstIP net.IP + srcPort uint16 + dstPort uint16 + sleep time.Duration + want bool + }{ + { + name: "valid inbound response", + srcIP: dstIP, // Original destination is now source + dstIP: srcIP, // Original source is now destination + srcPort: dstPort, // Original destination port is now source + dstPort: srcPort, // Original source port is now destination + sleep: 0, + want: true, + }, + { + name: "invalid source IP", + srcIP: net.ParseIP("192.168.1.4"), + dstIP: srcIP, + srcPort: dstPort, + dstPort: srcPort, + sleep: 0, + want: false, + }, + { + name: "invalid destination IP", + srcIP: dstIP, + dstIP: net.ParseIP("192.168.1.4"), + srcPort: dstPort, + dstPort: srcPort, + sleep: 0, + want: false, + }, + { + name: "invalid source port", + srcIP: dstIP, + dstIP: srcIP, + srcPort: 54321, + dstPort: srcPort, + sleep: 0, + want: false, + }, + { + name: "invalid destination port", + srcIP: dstIP, + dstIP: srcIP, + srcPort: dstPort, + dstPort: 54321, + sleep: 0, + want: false, + }, + { + name: "expired connection", + srcIP: dstIP, + dstIP: srcIP, + srcPort: dstPort, + dstPort: srcPort, + sleep: 2 * time.Second, // Longer than tracker timeout + want: false, + }, + } + + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + if tt.sleep > 0 { + time.Sleep(tt.sleep) + } + got := tracker.IsValidInbound(tt.srcIP, tt.dstIP, tt.srcPort, tt.dstPort) + assert.Equal(t, tt.want, got) + }) + } +} + +func TestUDPTracker_Cleanup(t *testing.T) { + // Use 
shorter intervals for testing + timeout := 50 * time.Millisecond + cleanupInterval := 25 * time.Millisecond + + // Create tracker with custom cleanup interval + tracker := &UDPTracker{ + connections: make(map[ConnKey]*UDPConnTrack), + timeout: timeout, + cleanupTicker: time.NewTicker(cleanupInterval), + done: make(chan struct{}), + ipPool: NewPreallocatedIPs(), + } + + // Start cleanup routine + go tracker.cleanupRoutine() + + // Add some connections + connections := []struct { + srcIP net.IP + dstIP net.IP + srcPort uint16 + dstPort uint16 + }{ + { + srcIP: net.ParseIP("192.168.1.2"), + dstIP: net.ParseIP("192.168.1.3"), + srcPort: 12345, + dstPort: 53, + }, + { + srcIP: net.ParseIP("192.168.1.4"), + dstIP: net.ParseIP("192.168.1.5"), + srcPort: 12346, + dstPort: 53, + }, + } + + for _, conn := range connections { + tracker.TrackOutbound(conn.srcIP, conn.dstIP, conn.srcPort, conn.dstPort) + } + + // Verify initial connections + assert.Len(t, tracker.connections, 2) + + // Wait for connection timeout and cleanup interval + time.Sleep(timeout + 2*cleanupInterval) + + tracker.mutex.RLock() + connCount := len(tracker.connections) + tracker.mutex.RUnlock() + + // Verify connections were cleaned up + assert.Equal(t, 0, connCount, "Expected all connections to be cleaned up") + + // Properly close the tracker + tracker.Close() +} + +func BenchmarkUDPTracker(b *testing.B) { + b.Run("TrackOutbound", func(b *testing.B) { + tracker := NewUDPTracker(DefaultUDPTimeout) + defer tracker.Close() + + srcIP := net.ParseIP("192.168.1.1") + dstIP := net.ParseIP("192.168.1.2") + + b.ResetTimer() + for i := 0; i < b.N; i++ { + tracker.TrackOutbound(srcIP, dstIP, uint16(i%65535), 80) + } + }) + + b.Run("IsValidInbound", func(b *testing.B) { + tracker := NewUDPTracker(DefaultUDPTimeout) + defer tracker.Close() + + srcIP := net.ParseIP("192.168.1.1") + dstIP := net.ParseIP("192.168.1.2") + + // Pre-populate some connections + for i := 0; i < 1000; i++ { + tracker.TrackOutbound(srcIP, dstIP, uint16(i), 80) + } + + b.ResetTimer() + for i := 0; i < b.N; i++ { + tracker.IsValidInbound(dstIP, srcIP, 80, uint16(i%1000)) + } + }) +} diff --git a/client/firewall/uspfilter/uspfilter.go b/client/firewall/uspfilter/uspfilter.go index fb726395b..24cfd6e96 100644 --- a/client/firewall/uspfilter/uspfilter.go +++ b/client/firewall/uspfilter/uspfilter.go @@ -4,6 +4,8 @@ import ( "fmt" "net" "net/netip" + "os" + "strconv" "sync" "github.com/google/gopacket" @@ -12,6 +14,7 @@ import ( log "github.com/sirupsen/logrus" firewall "github.com/netbirdio/netbird/client/firewall/manager" + "github.com/netbirdio/netbird/client/firewall/uspfilter/conntrack" "github.com/netbirdio/netbird/client/iface" "github.com/netbirdio/netbird/client/iface/device" "github.com/netbirdio/netbird/client/internal/statemanager" @@ -19,6 +22,8 @@ import ( const layerTypeAll = 0 +const EnvDisableConntrack = "NB_DISABLE_CONNTRACK" + var ( errRouteNotSupported = fmt.Errorf("route not supported with userspace firewall") ) @@ -42,6 +47,11 @@ type Manager struct { nativeFirewall firewall.Manager mutex sync.RWMutex + + stateful bool + udpTracker *conntrack.UDPTracker + icmpTracker *conntrack.ICMPTracker + tcpTracker *conntrack.TCPTracker } // decoder for packages @@ -73,6 +83,8 @@ func CreateWithNativeFirewall(iface IFaceMapper, nativeFirewall firewall.Manager } func create(iface IFaceMapper) (*Manager, error) { + disableConntrack, _ := strconv.ParseBool(os.Getenv(EnvDisableConntrack)) + m := &Manager{ decoders: sync.Pool{ New: func() any { @@ -90,6 +102,16 @@ func 
create(iface IFaceMapper) (*Manager, error) { outgoingRules: make(map[string]RuleSet), incomingRules: make(map[string]RuleSet), wgIface: iface, + stateful: !disableConntrack, + } + + // Only initialize trackers if stateful mode is enabled + if disableConntrack { + log.Info("conntrack is disabled") + } else { + m.udpTracker = conntrack.NewUDPTracker(conntrack.DefaultUDPTimeout) + m.icmpTracker = conntrack.NewICMPTracker(conntrack.DefaultICMPTimeout) + m.tcpTracker = conntrack.NewTCPTracker(conntrack.DefaultTCPTimeout) } if err := iface.SetFilter(m); err != nil { @@ -249,16 +271,16 @@ func (m *Manager) Flush() error { return nil } // DropOutgoing filter outgoing packets func (m *Manager) DropOutgoing(packetData []byte) bool { - return m.dropFilter(packetData, m.outgoingRules, false) + return m.processOutgoingHooks(packetData) } // DropIncoming filter incoming packets func (m *Manager) DropIncoming(packetData []byte) bool { - return m.dropFilter(packetData, m.incomingRules, true) + return m.dropFilter(packetData, m.incomingRules) } -// dropFilter implements same logic for booth direction of the traffic -func (m *Manager) dropFilter(packetData []byte, rules map[string]RuleSet, isIncomingPacket bool) bool { +// processOutgoingHooks processes UDP hooks for outgoing packets and tracks TCP/UDP/ICMP +func (m *Manager) processOutgoingHooks(packetData []byte) bool { m.mutex.RLock() defer m.mutex.RUnlock() @@ -266,61 +288,213 @@ func (m *Manager) dropFilter(packetData []byte, rules map[string]RuleSet, isInco defer m.decoders.Put(d) if err := d.parser.DecodeLayers(packetData, &d.decoded); err != nil { - log.Tracef("couldn't decode layer, err: %s", err) - return true + return false } if len(d.decoded) < 2 { - log.Tracef("not enough levels in network packet") + return false + } + + srcIP, dstIP := m.extractIPs(d) + if srcIP == nil { + return false + } + + // Always process UDP hooks + if d.decoded[1] == layers.LayerTypeUDP { + // Track UDP state only if enabled + if m.stateful { + m.trackUDPOutbound(d, srcIP, dstIP) + } + return m.checkUDPHooks(d, dstIP, packetData) + } + + // Track other protocols only if stateful mode is enabled + if m.stateful { + switch d.decoded[1] { + case layers.LayerTypeTCP: + m.trackTCPOutbound(d, srcIP, dstIP) + case layers.LayerTypeICMPv4: + m.trackICMPOutbound(d, srcIP, dstIP) + } + } + + return false +} + +func (m *Manager) extractIPs(d *decoder) (srcIP, dstIP net.IP) { + switch d.decoded[0] { + case layers.LayerTypeIPv4: + return d.ip4.SrcIP, d.ip4.DstIP + case layers.LayerTypeIPv6: + return d.ip6.SrcIP, d.ip6.DstIP + default: + return nil, nil + } +} + +func (m *Manager) trackTCPOutbound(d *decoder, srcIP, dstIP net.IP) { + flags := getTCPFlags(&d.tcp) + m.tcpTracker.TrackOutbound( + srcIP, + dstIP, + uint16(d.tcp.SrcPort), + uint16(d.tcp.DstPort), + flags, + ) +} + +func getTCPFlags(tcp *layers.TCP) uint8 { + var flags uint8 + if tcp.SYN { + flags |= conntrack.TCPSyn + } + if tcp.ACK { + flags |= conntrack.TCPAck + } + if tcp.FIN { + flags |= conntrack.TCPFin + } + if tcp.RST { + flags |= conntrack.TCPRst + } + if tcp.PSH { + flags |= conntrack.TCPPush + } + if tcp.URG { + flags |= conntrack.TCPUrg + } + return flags +} + +func (m *Manager) trackUDPOutbound(d *decoder, srcIP, dstIP net.IP) { + m.udpTracker.TrackOutbound( + srcIP, + dstIP, + uint16(d.udp.SrcPort), + uint16(d.udp.DstPort), + ) +} + +func (m *Manager) checkUDPHooks(d *decoder, dstIP net.IP, packetData []byte) bool { + for _, ipKey := range []string{dstIP.String(), "0.0.0.0", "::"} { + if rules, exists := 
m.outgoingRules[ipKey]; exists { + for _, rule := range rules { + if rule.udpHook != nil && (rule.dPort == 0 || rule.dPort == uint16(d.udp.DstPort)) { + return rule.udpHook(packetData) + } + } + } + } + return false +} + +func (m *Manager) trackICMPOutbound(d *decoder, srcIP, dstIP net.IP) { + if d.icmp4.TypeCode.Type() == layers.ICMPv4TypeEchoRequest { + m.icmpTracker.TrackOutbound( + srcIP, + dstIP, + d.icmp4.Id, + d.icmp4.Seq, + ) + } +} + +// dropFilter implements filtering logic for incoming packets +func (m *Manager) dropFilter(packetData []byte, rules map[string]RuleSet) bool { + m.mutex.RLock() + defer m.mutex.RUnlock() + + d := m.decoders.Get().(*decoder) + defer m.decoders.Put(d) + + if !m.isValidPacket(d, packetData) { return true } - ipLayer := d.decoded[0] - - switch ipLayer { - case layers.LayerTypeIPv4: - if !m.wgNetwork.Contains(d.ip4.SrcIP) || !m.wgNetwork.Contains(d.ip4.DstIP) { - return false - } - case layers.LayerTypeIPv6: - if !m.wgNetwork.Contains(d.ip6.SrcIP) || !m.wgNetwork.Contains(d.ip6.DstIP) { - return false - } - default: + srcIP, dstIP := m.extractIPs(d) + if srcIP == nil { log.Errorf("unknown layer: %v", d.decoded[0]) return true } - var ip net.IP - switch ipLayer { - case layers.LayerTypeIPv4: - if isIncomingPacket { - ip = d.ip4.SrcIP - } else { - ip = d.ip4.DstIP - } - case layers.LayerTypeIPv6: - if isIncomingPacket { - ip = d.ip6.SrcIP - } else { - ip = d.ip6.DstIP - } + if !m.isWireguardTraffic(srcIP, dstIP) { + return false } - filter, ok := validateRule(ip, packetData, rules[ip.String()], d) - if ok { - return filter + // Check connection state only if enabled + if m.stateful && m.isValidTrackedConnection(d, srcIP, dstIP) { + return false } - filter, ok = validateRule(ip, packetData, rules["0.0.0.0"], d) - if ok { - return filter + + return m.applyRules(srcIP, packetData, rules, d) +} + +func (m *Manager) isValidPacket(d *decoder, packetData []byte) bool { + if err := d.parser.DecodeLayers(packetData, &d.decoded); err != nil { + log.Tracef("couldn't decode layer, err: %s", err) + return false } - filter, ok = validateRule(ip, packetData, rules["::"], d) - if ok { + + if len(d.decoded) < 2 { + log.Tracef("not enough levels in network packet") + return false + } + return true +} + +func (m *Manager) isWireguardTraffic(srcIP, dstIP net.IP) bool { + return m.wgNetwork.Contains(srcIP) && m.wgNetwork.Contains(dstIP) +} + +func (m *Manager) isValidTrackedConnection(d *decoder, srcIP, dstIP net.IP) bool { + switch d.decoded[1] { + case layers.LayerTypeTCP: + return m.tcpTracker.IsValidInbound( + srcIP, + dstIP, + uint16(d.tcp.SrcPort), + uint16(d.tcp.DstPort), + getTCPFlags(&d.tcp), + ) + + case layers.LayerTypeUDP: + return m.udpTracker.IsValidInbound( + srcIP, + dstIP, + uint16(d.udp.SrcPort), + uint16(d.udp.DstPort), + ) + + case layers.LayerTypeICMPv4: + return m.icmpTracker.IsValidInbound( + srcIP, + dstIP, + d.icmp4.Id, + d.icmp4.Seq, + d.icmp4.TypeCode.Type(), + ) + + // TODO: ICMPv6 + } + + return false +} + +func (m *Manager) applyRules(srcIP net.IP, packetData []byte, rules map[string]RuleSet, d *decoder) bool { + if filter, ok := validateRule(srcIP, packetData, rules[srcIP.String()], d); ok { return filter } - // default policy is DROP ALL + if filter, ok := validateRule(srcIP, packetData, rules["0.0.0.0"], d); ok { + return filter + } + + if filter, ok := validateRule(srcIP, packetData, rules["::"], d); ok { + return filter + } + + // Default policy: DROP ALL return true } diff --git a/client/firewall/uspfilter/uspfilter_bench_test.go 
b/client/firewall/uspfilter/uspfilter_bench_test.go new file mode 100644 index 000000000..3c661e71c --- /dev/null +++ b/client/firewall/uspfilter/uspfilter_bench_test.go @@ -0,0 +1,998 @@ +package uspfilter + +import ( + "fmt" + "math/rand" + "net" + "os" + "strings" + "testing" + + "github.com/google/gopacket" + "github.com/google/gopacket/layers" + "github.com/stretchr/testify/require" + + fw "github.com/netbirdio/netbird/client/firewall/manager" + "github.com/netbirdio/netbird/client/firewall/uspfilter/conntrack" + "github.com/netbirdio/netbird/client/iface/device" +) + +// generateRandomIPs generates n different random IPs in the 100.64.0.0/10 range +func generateRandomIPs(n int) []net.IP { + ips := make([]net.IP, n) + seen := make(map[string]bool) + + for i := 0; i < n; { + ip := make(net.IP, 4) + ip[0] = 100 + ip[1] = byte(64 + rand.Intn(63)) // 64-126 + ip[2] = byte(rand.Intn(256)) + ip[3] = byte(1 + rand.Intn(254)) // avoid .0 and .255 + + key := ip.String() + if !seen[key] { + ips[i] = ip + seen[key] = true + i++ + } + } + return ips +} + +func generatePacket(b *testing.B, srcIP, dstIP net.IP, srcPort, dstPort uint16, protocol layers.IPProtocol) []byte { + b.Helper() + + ipv4 := &layers.IPv4{ + TTL: 64, + Version: 4, + SrcIP: srcIP, + DstIP: dstIP, + Protocol: protocol, + } + + var transportLayer gopacket.SerializableLayer + switch protocol { + case layers.IPProtocolTCP: + tcp := &layers.TCP{ + SrcPort: layers.TCPPort(srcPort), + DstPort: layers.TCPPort(dstPort), + SYN: true, + } + require.NoError(b, tcp.SetNetworkLayerForChecksum(ipv4)) + transportLayer = tcp + case layers.IPProtocolUDP: + udp := &layers.UDP{ + SrcPort: layers.UDPPort(srcPort), + DstPort: layers.UDPPort(dstPort), + } + require.NoError(b, udp.SetNetworkLayerForChecksum(ipv4)) + transportLayer = udp + } + + buf := gopacket.NewSerializeBuffer() + opts := gopacket.SerializeOptions{ComputeChecksums: true, FixLengths: true} + err := gopacket.SerializeLayers(buf, opts, ipv4, transportLayer, gopacket.Payload("test")) + require.NoError(b, err) + return buf.Bytes() +} + +// BenchmarkCoreFiltering focuses on the essential performance comparisons between +// stateful and stateless filtering approaches +func BenchmarkCoreFiltering(b *testing.B) { + scenarios := []struct { + name string + stateful bool + setupFunc func(*Manager) + desc string + }{ + { + name: "stateless_single_allow_all", + stateful: false, + setupFunc: func(m *Manager) { + // Single rule allowing all traffic + _, err := m.AddPeerFiltering(net.ParseIP("0.0.0.0"), fw.ProtocolALL, nil, nil, + fw.RuleDirectionIN, fw.ActionAccept, "", "allow all") + require.NoError(b, err) + }, + desc: "Baseline: Single 'allow all' rule without connection tracking", + }, + { + name: "stateful_no_rules", + stateful: true, + setupFunc: func(m *Manager) { + // No explicit rules - rely purely on connection tracking + }, + desc: "Pure connection tracking without any rules", + }, + { + name: "stateless_explicit_return", + stateful: false, + setupFunc: func(m *Manager) { + // Add explicit rules matching return traffic pattern + for i := 0; i < 1000; i++ { // Simulate realistic ruleset size + ip := generateRandomIPs(1)[0] + _, err := m.AddPeerFiltering(ip, fw.ProtocolTCP, + &fw.Port{Values: []int{1024 + i}}, + &fw.Port{Values: []int{80}}, + fw.RuleDirectionIN, fw.ActionAccept, "", "explicit return") + require.NoError(b, err) + } + }, + desc: "Explicit rules matching return traffic patterns without state", + }, + { + name: "stateful_with_established", + stateful: true, + setupFunc: func(m 
*Manager) { + // Add some basic rules but rely on state for established connections + _, err := m.AddPeerFiltering(net.ParseIP("0.0.0.0"), fw.ProtocolTCP, nil, nil, + fw.RuleDirectionIN, fw.ActionDrop, "", "default drop") + require.NoError(b, err) + }, + desc: "Connection tracking with established connections", + }, + } + + // Test both TCP and UDP + protocols := []struct { + name string + proto layers.IPProtocol + }{ + {"TCP", layers.IPProtocolTCP}, + {"UDP", layers.IPProtocolUDP}, + } + + for _, sc := range scenarios { + for _, proto := range protocols { + b.Run(fmt.Sprintf("%s_%s", sc.name, proto.name), func(b *testing.B) { + // Configure stateful/stateless mode + if !sc.stateful { + require.NoError(b, os.Setenv("NB_DISABLE_CONNTRACK", "1")) + } else { + require.NoError(b, os.Setenv("NB_CONNTRACK_TIMEOUT", "1m")) + } + + // Create manager and basic setup + manager, _ := Create(&IFaceMock{ + SetFilterFunc: func(device.PacketFilter) error { return nil }, + }) + defer b.Cleanup(func() { + require.NoError(b, manager.Reset(nil)) + }) + + manager.wgNetwork = &net.IPNet{ + IP: net.ParseIP("100.64.0.0"), + Mask: net.CIDRMask(10, 32), + } + + // Apply scenario-specific setup + sc.setupFunc(manager) + + // Generate test packets + srcIP := generateRandomIPs(1)[0] + dstIP := generateRandomIPs(1)[0] + srcPort := uint16(1024 + b.N%60000) + dstPort := uint16(80) + + outbound := generatePacket(b, srcIP, dstIP, srcPort, dstPort, proto.proto) + inbound := generatePacket(b, dstIP, srcIP, dstPort, srcPort, proto.proto) + + // For stateful scenarios, establish the connection + if sc.stateful { + manager.processOutgoingHooks(outbound) + } + + // Measure inbound packet processing + b.ResetTimer() + for i := 0; i < b.N; i++ { + manager.dropFilter(inbound, manager.incomingRules) + } + }) + } + } +} + +// BenchmarkStateScaling measures how performance scales with connection table size +func BenchmarkStateScaling(b *testing.B) { + connCounts := []int{100, 1000, 10000, 100000} + + for _, count := range connCounts { + b.Run(fmt.Sprintf("conns_%d", count), func(b *testing.B) { + manager, _ := Create(&IFaceMock{ + SetFilterFunc: func(device.PacketFilter) error { return nil }, + }) + b.Cleanup(func() { + require.NoError(b, manager.Reset(nil)) + }) + + manager.wgNetwork = &net.IPNet{ + IP: net.ParseIP("100.64.0.0"), + Mask: net.CIDRMask(10, 32), + } + + // Pre-populate connection table + srcIPs := generateRandomIPs(count) + dstIPs := generateRandomIPs(count) + for i := 0; i < count; i++ { + outbound := generatePacket(b, srcIPs[i], dstIPs[i], + uint16(1024+i), 80, layers.IPProtocolTCP) + manager.processOutgoingHooks(outbound) + } + + // Test packet + testOut := generatePacket(b, srcIPs[0], dstIPs[0], 1024, 80, layers.IPProtocolTCP) + testIn := generatePacket(b, dstIPs[0], srcIPs[0], 80, 1024, layers.IPProtocolTCP) + + // First establish our test connection + manager.processOutgoingHooks(testOut) + + b.ResetTimer() + for i := 0; i < b.N; i++ { + manager.dropFilter(testIn, manager.incomingRules) + } + }) + } +} + +// BenchmarkEstablishmentOverhead measures the overhead of connection establishment +func BenchmarkEstablishmentOverhead(b *testing.B) { + scenarios := []struct { + name string + established bool + }{ + {"established", true}, + {"new", false}, + } + + for _, sc := range scenarios { + b.Run(sc.name, func(b *testing.B) { + manager, _ := Create(&IFaceMock{ + SetFilterFunc: func(device.PacketFilter) error { return nil }, + }) + b.Cleanup(func() { + require.NoError(b, manager.Reset(nil)) + }) + + manager.wgNetwork = 
&net.IPNet{ + IP: net.ParseIP("100.64.0.0"), + Mask: net.CIDRMask(10, 32), + } + + srcIP := generateRandomIPs(1)[0] + dstIP := generateRandomIPs(1)[0] + outbound := generatePacket(b, srcIP, dstIP, 1024, 80, layers.IPProtocolTCP) + inbound := generatePacket(b, dstIP, srcIP, 80, 1024, layers.IPProtocolTCP) + + if sc.established { + manager.processOutgoingHooks(outbound) + } + + b.ResetTimer() + for i := 0; i < b.N; i++ { + manager.dropFilter(inbound, manager.incomingRules) + } + }) + } +} + +// BenchmarkRoutedNetworkReturn compares approaches for handling routed network return traffic +func BenchmarkRoutedNetworkReturn(b *testing.B) { + scenarios := []struct { + name string + proto layers.IPProtocol + state string // "new", "established", "post_handshake" (TCP only) + setupFunc func(*Manager) + genPackets func(net.IP, net.IP) ([]byte, []byte) // generates appropriate packets for the scenario + desc string + }{ + { + name: "allow_non_wg_tcp_new", + proto: layers.IPProtocolTCP, + state: "new", + setupFunc: func(m *Manager) { + m.wgNetwork = &net.IPNet{ + IP: net.ParseIP("100.64.0.0"), + Mask: net.CIDRMask(10, 32), + } + b.Setenv("NB_DISABLE_CONNTRACK", "1") + }, + genPackets: func(srcIP, dstIP net.IP) ([]byte, []byte) { + return generatePacket(b, srcIP, dstIP, 1024, 80, layers.IPProtocolTCP), + generatePacket(b, dstIP, srcIP, 80, 1024, layers.IPProtocolTCP) + }, + desc: "Allow non-WG: TCP new connection", + }, + { + name: "allow_non_wg_tcp_established", + proto: layers.IPProtocolTCP, + state: "established", + setupFunc: func(m *Manager) { + m.wgNetwork = &net.IPNet{ + IP: net.ParseIP("100.64.0.0"), + Mask: net.CIDRMask(10, 32), + } + b.Setenv("NB_DISABLE_CONNTRACK", "1") + }, + genPackets: func(srcIP, dstIP net.IP) ([]byte, []byte) { + // Generate packets with ACK flag for established connection + return generateTCPPacketWithFlags(b, srcIP, dstIP, 1024, 80, uint16(conntrack.TCPAck)), + generateTCPPacketWithFlags(b, dstIP, srcIP, 80, 1024, uint16(conntrack.TCPAck)) + }, + desc: "Allow non-WG: TCP established connection", + }, + { + name: "allow_non_wg_udp_new", + proto: layers.IPProtocolUDP, + state: "new", + setupFunc: func(m *Manager) { + m.wgNetwork = &net.IPNet{ + IP: net.ParseIP("100.64.0.0"), + Mask: net.CIDRMask(10, 32), + } + b.Setenv("NB_DISABLE_CONNTRACK", "1") + }, + genPackets: func(srcIP, dstIP net.IP) ([]byte, []byte) { + return generatePacket(b, srcIP, dstIP, 1024, 80, layers.IPProtocolUDP), + generatePacket(b, dstIP, srcIP, 80, 1024, layers.IPProtocolUDP) + }, + desc: "Allow non-WG: UDP new connection", + }, + { + name: "allow_non_wg_udp_established", + proto: layers.IPProtocolUDP, + state: "established", + setupFunc: func(m *Manager) { + m.wgNetwork = &net.IPNet{ + IP: net.ParseIP("100.64.0.0"), + Mask: net.CIDRMask(10, 32), + } + b.Setenv("NB_DISABLE_CONNTRACK", "1") + }, + genPackets: func(srcIP, dstIP net.IP) ([]byte, []byte) { + return generatePacket(b, srcIP, dstIP, 1024, 80, layers.IPProtocolUDP), + generatePacket(b, dstIP, srcIP, 80, 1024, layers.IPProtocolUDP) + }, + desc: "Allow non-WG: UDP established connection", + }, + { + name: "stateful_tcp_new", + proto: layers.IPProtocolTCP, + state: "new", + setupFunc: func(m *Manager) { + m.wgNetwork = &net.IPNet{ + IP: net.ParseIP("0.0.0.0"), + Mask: net.CIDRMask(0, 32), + } + require.NoError(b, os.Unsetenv("NB_DISABLE_CONNTRACK")) + }, + genPackets: func(srcIP, dstIP net.IP) ([]byte, []byte) { + return generatePacket(b, srcIP, dstIP, 1024, 80, layers.IPProtocolTCP), + generatePacket(b, dstIP, srcIP, 80, 1024, 
layers.IPProtocolTCP) + }, + desc: "Stateful: TCP new connection", + }, + { + name: "stateful_tcp_established", + proto: layers.IPProtocolTCP, + state: "established", + setupFunc: func(m *Manager) { + m.wgNetwork = &net.IPNet{ + IP: net.ParseIP("0.0.0.0"), + Mask: net.CIDRMask(0, 32), + } + require.NoError(b, os.Unsetenv("NB_DISABLE_CONNTRACK")) + }, + genPackets: func(srcIP, dstIP net.IP) ([]byte, []byte) { + // Generate established TCP packets (ACK flag) + return generateTCPPacketWithFlags(b, srcIP, dstIP, 1024, 80, uint16(conntrack.TCPAck)), + generateTCPPacketWithFlags(b, dstIP, srcIP, 80, 1024, uint16(conntrack.TCPAck)) + }, + desc: "Stateful: TCP established connection", + }, + { + name: "stateful_tcp_post_handshake", + proto: layers.IPProtocolTCP, + state: "post_handshake", + setupFunc: func(m *Manager) { + m.wgNetwork = &net.IPNet{ + IP: net.ParseIP("0.0.0.0"), + Mask: net.CIDRMask(0, 32), + } + require.NoError(b, os.Unsetenv("NB_DISABLE_CONNTRACK")) + }, + genPackets: func(srcIP, dstIP net.IP) ([]byte, []byte) { + // Generate packets with PSH+ACK flags for data transfer + return generateTCPPacketWithFlags(b, srcIP, dstIP, 1024, 80, uint16(conntrack.TCPPush|conntrack.TCPAck)), + generateTCPPacketWithFlags(b, dstIP, srcIP, 80, 1024, uint16(conntrack.TCPPush|conntrack.TCPAck)) + }, + desc: "Stateful: TCP post-handshake data transfer", + }, + { + name: "stateful_udp_new", + proto: layers.IPProtocolUDP, + state: "new", + setupFunc: func(m *Manager) { + m.wgNetwork = &net.IPNet{ + IP: net.ParseIP("0.0.0.0"), + Mask: net.CIDRMask(0, 32), + } + require.NoError(b, os.Unsetenv("NB_DISABLE_CONNTRACK")) + }, + genPackets: func(srcIP, dstIP net.IP) ([]byte, []byte) { + return generatePacket(b, srcIP, dstIP, 1024, 80, layers.IPProtocolUDP), + generatePacket(b, dstIP, srcIP, 80, 1024, layers.IPProtocolUDP) + }, + desc: "Stateful: UDP new connection", + }, + { + name: "stateful_udp_established", + proto: layers.IPProtocolUDP, + state: "established", + setupFunc: func(m *Manager) { + m.wgNetwork = &net.IPNet{ + IP: net.ParseIP("0.0.0.0"), + Mask: net.CIDRMask(0, 32), + } + require.NoError(b, os.Unsetenv("NB_DISABLE_CONNTRACK")) + }, + genPackets: func(srcIP, dstIP net.IP) ([]byte, []byte) { + return generatePacket(b, srcIP, dstIP, 1024, 80, layers.IPProtocolUDP), + generatePacket(b, dstIP, srcIP, 80, 1024, layers.IPProtocolUDP) + }, + desc: "Stateful: UDP established connection", + }, + } + + for _, sc := range scenarios { + b.Run(sc.name, func(b *testing.B) { + manager, _ := Create(&IFaceMock{ + SetFilterFunc: func(device.PacketFilter) error { return nil }, + }) + b.Cleanup(func() { + require.NoError(b, manager.Reset(nil)) + }) + + // Setup scenario + sc.setupFunc(manager) + + // Use IPs outside WG range for routed network simulation + srcIP := net.ParseIP("192.168.1.2") + dstIP := net.ParseIP("8.8.8.8") + outbound, inbound := sc.genPackets(srcIP, dstIP) + + // For stateful cases and established connections + if !strings.Contains(sc.name, "allow_non_wg") || + (strings.Contains(sc.state, "established") || sc.state == "post_handshake") { + manager.processOutgoingHooks(outbound) + + // For TCP post-handshake, simulate full handshake + if sc.state == "post_handshake" { + // SYN + syn := generateTCPPacketWithFlags(b, srcIP, dstIP, 1024, 80, uint16(conntrack.TCPSyn)) + manager.processOutgoingHooks(syn) + // SYN-ACK + synack := generateTCPPacketWithFlags(b, dstIP, srcIP, 80, 1024, uint16(conntrack.TCPSyn|conntrack.TCPAck)) + manager.dropFilter(synack, manager.incomingRules) + // ACK + ack := 
generateTCPPacketWithFlags(b, srcIP, dstIP, 1024, 80, uint16(conntrack.TCPAck)) + manager.processOutgoingHooks(ack) + } + } + + b.ResetTimer() + for i := 0; i < b.N; i++ { + manager.dropFilter(inbound, manager.incomingRules) + } + }) + } +} + +var scenarios = []struct { + name string + stateful bool // Whether conntrack is enabled + rules bool // Whether to add return traffic rules + routed bool // Whether to test routed network traffic + connCount int // Number of concurrent connections + desc string +}{ + { + name: "stateless_with_rules_100conns", + stateful: false, + rules: true, + routed: false, + connCount: 100, + desc: "Pure stateless with return traffic rules, 100 conns", + }, + { + name: "stateless_with_rules_1000conns", + stateful: false, + rules: true, + routed: false, + connCount: 1000, + desc: "Pure stateless with return traffic rules, 1000 conns", + }, + { + name: "stateful_no_rules_100conns", + stateful: true, + rules: false, + routed: false, + connCount: 100, + desc: "Pure stateful tracking without rules, 100 conns", + }, + { + name: "stateful_no_rules_1000conns", + stateful: true, + rules: false, + routed: false, + connCount: 1000, + desc: "Pure stateful tracking without rules, 1000 conns", + }, + { + name: "stateful_with_rules_100conns", + stateful: true, + rules: true, + routed: false, + connCount: 100, + desc: "Combined stateful + rules (current implementation), 100 conns", + }, + { + name: "stateful_with_rules_1000conns", + stateful: true, + rules: true, + routed: false, + connCount: 1000, + desc: "Combined stateful + rules (current implementation), 1000 conns", + }, + { + name: "routed_network_100conns", + stateful: true, + rules: false, + routed: true, + connCount: 100, + desc: "Routed network traffic (non-WG), 100 conns", + }, + { + name: "routed_network_1000conns", + stateful: true, + rules: false, + routed: true, + connCount: 1000, + desc: "Routed network traffic (non-WG), 1000 conns", + }, +} + +// BenchmarkLongLivedConnections tests performance with realistic TCP traffic patterns +func BenchmarkLongLivedConnections(b *testing.B) { + for _, sc := range scenarios { + b.Run(sc.name, func(b *testing.B) { + // Configure stateful/stateless mode + if !sc.stateful { + b.Setenv("NB_DISABLE_CONNTRACK", "1") + } else { + require.NoError(b, os.Unsetenv("NB_DISABLE_CONNTRACK")) + } + + manager, _ := Create(&IFaceMock{ + SetFilterFunc: func(device.PacketFilter) error { return nil }, + }) + defer b.Cleanup(func() { + require.NoError(b, manager.Reset(nil)) + }) + + manager.SetNetwork(&net.IPNet{ + IP: net.ParseIP("100.64.0.0"), + Mask: net.CIDRMask(10, 32), + }) + + // Setup initial state based on scenario + if sc.rules { + // Single rule to allow all return traffic from port 80 + _, err := manager.AddPeerFiltering(net.ParseIP("0.0.0.0"), fw.ProtocolTCP, + &fw.Port{Values: []int{80}}, + nil, + fw.RuleDirectionIN, fw.ActionAccept, "", "return traffic") + require.NoError(b, err) + } + + // Generate IPs for connections + srcIPs := make([]net.IP, sc.connCount) + dstIPs := make([]net.IP, sc.connCount) + + for i := 0; i < sc.connCount; i++ { + if sc.routed { + srcIPs[i] = net.IPv4(192, 168, 1, byte(2+(i%250))).To4() + dstIPs[i] = net.IPv4(8, 8, byte((i/250)%255), byte(2+(i%250))).To4() + } else { + srcIPs[i] = generateRandomIPs(1)[0] + dstIPs[i] = generateRandomIPs(1)[0] + } + } + + // Create established connections + for i := 0; i < sc.connCount; i++ { + // Initial SYN + syn := generateTCPPacketWithFlags(b, srcIPs[i], dstIPs[i], + uint16(1024+i), 80, uint16(conntrack.TCPSyn)) + 
manager.processOutgoingHooks(syn) + + // SYN-ACK + synack := generateTCPPacketWithFlags(b, dstIPs[i], srcIPs[i], + 80, uint16(1024+i), uint16(conntrack.TCPSyn|conntrack.TCPAck)) + manager.dropFilter(synack, manager.incomingRules) + + // ACK + ack := generateTCPPacketWithFlags(b, srcIPs[i], dstIPs[i], + uint16(1024+i), 80, uint16(conntrack.TCPAck)) + manager.processOutgoingHooks(ack) + } + + // Prepare test packets simulating bidirectional traffic + inPackets := make([][]byte, sc.connCount) + outPackets := make([][]byte, sc.connCount) + for i := 0; i < sc.connCount; i++ { + // Server -> Client (inbound) + inPackets[i] = generateTCPPacketWithFlags(b, dstIPs[i], srcIPs[i], + 80, uint16(1024+i), uint16(conntrack.TCPPush|conntrack.TCPAck)) + // Client -> Server (outbound) + outPackets[i] = generateTCPPacketWithFlags(b, srcIPs[i], dstIPs[i], + uint16(1024+i), 80, uint16(conntrack.TCPPush|conntrack.TCPAck)) + } + + b.ResetTimer() + for i := 0; i < b.N; i++ { + connIdx := i % sc.connCount + + // Simulate bidirectional traffic + // First outbound data + manager.processOutgoingHooks(outPackets[connIdx]) + // Then inbound response - this is what we're actually measuring + manager.dropFilter(inPackets[connIdx], manager.incomingRules) + } + }) + } +} + +// BenchmarkShortLivedConnections tests performance with many short-lived connections +func BenchmarkShortLivedConnections(b *testing.B) { + for _, sc := range scenarios { + b.Run(sc.name, func(b *testing.B) { + // Configure stateful/stateless mode + if !sc.stateful { + b.Setenv("NB_DISABLE_CONNTRACK", "1") + } else { + require.NoError(b, os.Unsetenv("NB_DISABLE_CONNTRACK")) + } + + manager, _ := Create(&IFaceMock{ + SetFilterFunc: func(device.PacketFilter) error { return nil }, + }) + defer b.Cleanup(func() { + require.NoError(b, manager.Reset(nil)) + }) + + manager.SetNetwork(&net.IPNet{ + IP: net.ParseIP("100.64.0.0"), + Mask: net.CIDRMask(10, 32), + }) + + // Setup initial state based on scenario + if sc.rules { + // Single rule to allow all return traffic from port 80 + _, err := manager.AddPeerFiltering(net.ParseIP("0.0.0.0"), fw.ProtocolTCP, + &fw.Port{Values: []int{80}}, + nil, + fw.RuleDirectionIN, fw.ActionAccept, "", "return traffic") + require.NoError(b, err) + } + + // Generate IPs for connections + srcIPs := make([]net.IP, sc.connCount) + dstIPs := make([]net.IP, sc.connCount) + + for i := 0; i < sc.connCount; i++ { + if sc.routed { + srcIPs[i] = net.IPv4(192, 168, 1, byte(2+(i%250))).To4() + dstIPs[i] = net.IPv4(8, 8, byte((i/250)%255), byte(2+(i%250))).To4() + } else { + srcIPs[i] = generateRandomIPs(1)[0] + dstIPs[i] = generateRandomIPs(1)[0] + } + } + + // Create packet patterns for a complete HTTP-like short connection: + // 1. Initial handshake (SYN, SYN-ACK, ACK) + // 2. HTTP Request (PSH+ACK from client) + // 3. HTTP Response (PSH+ACK from server) + // 4. 
Connection teardown (FIN+ACK, ACK, FIN+ACK, ACK) + type connPackets struct { + syn []byte + synAck []byte + ack []byte + request []byte + response []byte + finClient []byte + ackServer []byte + finServer []byte + ackClient []byte + } + + // Generate all possible connection patterns + patterns := make([]connPackets, sc.connCount) + for i := 0; i < sc.connCount; i++ { + patterns[i] = connPackets{ + // Handshake + syn: generateTCPPacketWithFlags(b, srcIPs[i], dstIPs[i], + uint16(1024+i), 80, uint16(conntrack.TCPSyn)), + synAck: generateTCPPacketWithFlags(b, dstIPs[i], srcIPs[i], + 80, uint16(1024+i), uint16(conntrack.TCPSyn|conntrack.TCPAck)), + ack: generateTCPPacketWithFlags(b, srcIPs[i], dstIPs[i], + uint16(1024+i), 80, uint16(conntrack.TCPAck)), + + // Data transfer + request: generateTCPPacketWithFlags(b, srcIPs[i], dstIPs[i], + uint16(1024+i), 80, uint16(conntrack.TCPPush|conntrack.TCPAck)), + response: generateTCPPacketWithFlags(b, dstIPs[i], srcIPs[i], + 80, uint16(1024+i), uint16(conntrack.TCPPush|conntrack.TCPAck)), + + // Connection teardown + finClient: generateTCPPacketWithFlags(b, srcIPs[i], dstIPs[i], + uint16(1024+i), 80, uint16(conntrack.TCPFin|conntrack.TCPAck)), + ackServer: generateTCPPacketWithFlags(b, dstIPs[i], srcIPs[i], + 80, uint16(1024+i), uint16(conntrack.TCPAck)), + finServer: generateTCPPacketWithFlags(b, dstIPs[i], srcIPs[i], + 80, uint16(1024+i), uint16(conntrack.TCPFin|conntrack.TCPAck)), + ackClient: generateTCPPacketWithFlags(b, srcIPs[i], dstIPs[i], + uint16(1024+i), 80, uint16(conntrack.TCPAck)), + } + } + + b.ResetTimer() + for i := 0; i < b.N; i++ { + // Each iteration creates a new short-lived connection + connIdx := i % sc.connCount + p := patterns[connIdx] + + // Connection establishment + manager.processOutgoingHooks(p.syn) + manager.dropFilter(p.synAck, manager.incomingRules) + manager.processOutgoingHooks(p.ack) + + // Data transfer + manager.processOutgoingHooks(p.request) + manager.dropFilter(p.response, manager.incomingRules) + + // Connection teardown + manager.processOutgoingHooks(p.finClient) + manager.dropFilter(p.ackServer, manager.incomingRules) + manager.dropFilter(p.finServer, manager.incomingRules) + manager.processOutgoingHooks(p.ackClient) + } + }) + } +} + +// BenchmarkParallelLongLivedConnections tests performance with realistic TCP traffic patterns in parallel +func BenchmarkParallelLongLivedConnections(b *testing.B) { + for _, sc := range scenarios { + b.Run(sc.name, func(b *testing.B) { + // Configure stateful/stateless mode + if !sc.stateful { + b.Setenv("NB_DISABLE_CONNTRACK", "1") + } else { + require.NoError(b, os.Unsetenv("NB_DISABLE_CONNTRACK")) + } + + manager, _ := Create(&IFaceMock{ + SetFilterFunc: func(device.PacketFilter) error { return nil }, + }) + defer b.Cleanup(func() { + require.NoError(b, manager.Reset(nil)) + }) + + manager.SetNetwork(&net.IPNet{ + IP: net.ParseIP("100.64.0.0"), + Mask: net.CIDRMask(10, 32), + }) + + // Setup initial state based on scenario + if sc.rules { + _, err := manager.AddPeerFiltering(net.ParseIP("0.0.0.0"), fw.ProtocolTCP, + &fw.Port{Values: []int{80}}, + nil, + fw.RuleDirectionIN, fw.ActionAccept, "", "return traffic") + require.NoError(b, err) + } + + // Generate IPs for connections + srcIPs := make([]net.IP, sc.connCount) + dstIPs := make([]net.IP, sc.connCount) + + for i := 0; i < sc.connCount; i++ { + if sc.routed { + srcIPs[i] = net.IPv4(192, 168, 1, byte(2+(i%250))).To4() + dstIPs[i] = net.IPv4(8, 8, byte((i/250)%255), byte(2+(i%250))).To4() + } else { + srcIPs[i] = 
generateRandomIPs(1)[0] + dstIPs[i] = generateRandomIPs(1)[0] + } + } + + // Create established connections + for i := 0; i < sc.connCount; i++ { + syn := generateTCPPacketWithFlags(b, srcIPs[i], dstIPs[i], + uint16(1024+i), 80, uint16(conntrack.TCPSyn)) + manager.processOutgoingHooks(syn) + + synack := generateTCPPacketWithFlags(b, dstIPs[i], srcIPs[i], + 80, uint16(1024+i), uint16(conntrack.TCPSyn|conntrack.TCPAck)) + manager.dropFilter(synack, manager.incomingRules) + + ack := generateTCPPacketWithFlags(b, srcIPs[i], dstIPs[i], + uint16(1024+i), 80, uint16(conntrack.TCPAck)) + manager.processOutgoingHooks(ack) + } + + // Pre-generate test packets + inPackets := make([][]byte, sc.connCount) + outPackets := make([][]byte, sc.connCount) + for i := 0; i < sc.connCount; i++ { + inPackets[i] = generateTCPPacketWithFlags(b, dstIPs[i], srcIPs[i], + 80, uint16(1024+i), uint16(conntrack.TCPPush|conntrack.TCPAck)) + outPackets[i] = generateTCPPacketWithFlags(b, srcIPs[i], dstIPs[i], + uint16(1024+i), 80, uint16(conntrack.TCPPush|conntrack.TCPAck)) + } + + b.ResetTimer() + b.RunParallel(func(pb *testing.PB) { + // Each goroutine gets its own counter to distribute load + counter := 0 + for pb.Next() { + connIdx := counter % sc.connCount + counter++ + + // Simulate bidirectional traffic + manager.processOutgoingHooks(outPackets[connIdx]) + manager.dropFilter(inPackets[connIdx], manager.incomingRules) + } + }) + }) + } +} + +// BenchmarkParallelShortLivedConnections tests performance with many short-lived connections in parallel +func BenchmarkParallelShortLivedConnections(b *testing.B) { + for _, sc := range scenarios { + b.Run(sc.name, func(b *testing.B) { + // Configure stateful/stateless mode + if !sc.stateful { + b.Setenv("NB_DISABLE_CONNTRACK", "1") + } else { + require.NoError(b, os.Unsetenv("NB_DISABLE_CONNTRACK")) + } + + manager, _ := Create(&IFaceMock{ + SetFilterFunc: func(device.PacketFilter) error { return nil }, + }) + defer b.Cleanup(func() { + require.NoError(b, manager.Reset(nil)) + }) + + manager.SetNetwork(&net.IPNet{ + IP: net.ParseIP("100.64.0.0"), + Mask: net.CIDRMask(10, 32), + }) + + if sc.rules { + _, err := manager.AddPeerFiltering(net.ParseIP("0.0.0.0"), fw.ProtocolTCP, + &fw.Port{Values: []int{80}}, + nil, + fw.RuleDirectionIN, fw.ActionAccept, "", "return traffic") + require.NoError(b, err) + } + + // Generate IPs and pre-generate all packet patterns + srcIPs := make([]net.IP, sc.connCount) + dstIPs := make([]net.IP, sc.connCount) + for i := 0; i < sc.connCount; i++ { + if sc.routed { + srcIPs[i] = net.IPv4(192, 168, 1, byte(2+(i%250))).To4() + dstIPs[i] = net.IPv4(8, 8, byte((i/250)%255), byte(2+(i%250))).To4() + } else { + srcIPs[i] = generateRandomIPs(1)[0] + dstIPs[i] = generateRandomIPs(1)[0] + } + } + + type connPackets struct { + syn []byte + synAck []byte + ack []byte + request []byte + response []byte + finClient []byte + ackServer []byte + finServer []byte + ackClient []byte + } + + patterns := make([]connPackets, sc.connCount) + for i := 0; i < sc.connCount; i++ { + patterns[i] = connPackets{ + syn: generateTCPPacketWithFlags(b, srcIPs[i], dstIPs[i], + uint16(1024+i), 80, uint16(conntrack.TCPSyn)), + synAck: generateTCPPacketWithFlags(b, dstIPs[i], srcIPs[i], + 80, uint16(1024+i), uint16(conntrack.TCPSyn|conntrack.TCPAck)), + ack: generateTCPPacketWithFlags(b, srcIPs[i], dstIPs[i], + uint16(1024+i), 80, uint16(conntrack.TCPAck)), + request: generateTCPPacketWithFlags(b, srcIPs[i], dstIPs[i], + uint16(1024+i), 80, uint16(conntrack.TCPPush|conntrack.TCPAck)), + 
response: generateTCPPacketWithFlags(b, dstIPs[i], srcIPs[i], + 80, uint16(1024+i), uint16(conntrack.TCPPush|conntrack.TCPAck)), + finClient: generateTCPPacketWithFlags(b, srcIPs[i], dstIPs[i], + uint16(1024+i), 80, uint16(conntrack.TCPFin|conntrack.TCPAck)), + ackServer: generateTCPPacketWithFlags(b, dstIPs[i], srcIPs[i], + 80, uint16(1024+i), uint16(conntrack.TCPAck)), + finServer: generateTCPPacketWithFlags(b, dstIPs[i], srcIPs[i], + 80, uint16(1024+i), uint16(conntrack.TCPFin|conntrack.TCPAck)), + ackClient: generateTCPPacketWithFlags(b, srcIPs[i], dstIPs[i], + uint16(1024+i), 80, uint16(conntrack.TCPAck)), + } + } + + b.ResetTimer() + b.RunParallel(func(pb *testing.PB) { + counter := 0 + for pb.Next() { + connIdx := counter % sc.connCount + counter++ + p := patterns[connIdx] + + // Full connection lifecycle + manager.processOutgoingHooks(p.syn) + manager.dropFilter(p.synAck, manager.incomingRules) + manager.processOutgoingHooks(p.ack) + + manager.processOutgoingHooks(p.request) + manager.dropFilter(p.response, manager.incomingRules) + + manager.processOutgoingHooks(p.finClient) + manager.dropFilter(p.ackServer, manager.incomingRules) + manager.dropFilter(p.finServer, manager.incomingRules) + manager.processOutgoingHooks(p.ackClient) + } + }) + }) + } +} + +// generateTCPPacketWithFlags creates a TCP packet with specific flags +func generateTCPPacketWithFlags(b *testing.B, srcIP, dstIP net.IP, srcPort, dstPort, flags uint16) []byte { + b.Helper() + + ipv4 := &layers.IPv4{ + TTL: 64, + Version: 4, + SrcIP: srcIP, + DstIP: dstIP, + Protocol: layers.IPProtocolTCP, + } + + tcp := &layers.TCP{ + SrcPort: layers.TCPPort(srcPort), + DstPort: layers.TCPPort(dstPort), + } + + // Set TCP flags + tcp.SYN = (flags & uint16(conntrack.TCPSyn)) != 0 + tcp.ACK = (flags & uint16(conntrack.TCPAck)) != 0 + tcp.PSH = (flags & uint16(conntrack.TCPPush)) != 0 + tcp.RST = (flags & uint16(conntrack.TCPRst)) != 0 + tcp.FIN = (flags & uint16(conntrack.TCPFin)) != 0 + + require.NoError(b, tcp.SetNetworkLayerForChecksum(ipv4)) + + buf := gopacket.NewSerializeBuffer() + opts := gopacket.SerializeOptions{ComputeChecksums: true, FixLengths: true} + require.NoError(b, gopacket.SerializeLayers(buf, opts, ipv4, tcp, gopacket.Payload("test"))) + return buf.Bytes() +} diff --git a/client/firewall/uspfilter/uspfilter_test.go b/client/firewall/uspfilter/uspfilter_test.go index d7c93cb7f..d3563e6f2 100644 --- a/client/firewall/uspfilter/uspfilter_test.go +++ b/client/firewall/uspfilter/uspfilter_test.go @@ -3,6 +3,7 @@ package uspfilter import ( "fmt" "net" + "sync" "testing" "time" @@ -11,6 +12,7 @@ import ( "github.com/stretchr/testify/require" fw "github.com/netbirdio/netbird/client/firewall/manager" + "github.com/netbirdio/netbird/client/firewall/uspfilter/conntrack" "github.com/netbirdio/netbird/client/iface" "github.com/netbirdio/netbird/client/iface/device" ) @@ -185,10 +187,10 @@ func TestAddUDPPacketHook(t *testing.T) { for _, tt := range tests { t.Run(tt.name, func(t *testing.T) { - manager := &Manager{ - incomingRules: map[string]RuleSet{}, - outgoingRules: map[string]RuleSet{}, - } + manager, err := Create(&IFaceMock{ + SetFilterFunc: func(device.PacketFilter) error { return nil }, + }) + require.NoError(t, err) manager.AddUDPPacketHook(tt.in, tt.ip, tt.dPort, tt.hook) @@ -313,7 +315,7 @@ func TestNotMatchByIP(t *testing.T) { t.Errorf("failed to set network layer for checksum: %v", err) return } - payload := gopacket.Payload([]byte("test")) + payload := gopacket.Payload("test") buf := 
gopacket.NewSerializeBuffer() opts := gopacket.SerializeOptions{ @@ -325,7 +327,7 @@ func TestNotMatchByIP(t *testing.T) { return } - if m.dropFilter(buf.Bytes(), m.outgoingRules, false) { + if m.dropFilter(buf.Bytes(), m.outgoingRules) { t.Errorf("expected packet to be accepted") return } @@ -348,6 +350,9 @@ func TestRemovePacketHook(t *testing.T) { if err != nil { t.Fatalf("Failed to create Manager: %s", err) } + defer func() { + require.NoError(t, manager.Reset(nil)) + }() // Add a UDP packet hook hookFunc := func(data []byte) bool { return true } @@ -384,6 +389,88 @@ func TestRemovePacketHook(t *testing.T) { } } +func TestProcessOutgoingHooks(t *testing.T) { + manager, err := Create(&IFaceMock{ + SetFilterFunc: func(device.PacketFilter) error { return nil }, + }) + require.NoError(t, err) + + manager.wgNetwork = &net.IPNet{ + IP: net.ParseIP("100.10.0.0"), + Mask: net.CIDRMask(16, 32), + } + manager.udpTracker.Close() + manager.udpTracker = conntrack.NewUDPTracker(100 * time.Millisecond) + defer func() { + require.NoError(t, manager.Reset(nil)) + }() + + manager.decoders = sync.Pool{ + New: func() any { + d := &decoder{ + decoded: []gopacket.LayerType{}, + } + d.parser = gopacket.NewDecodingLayerParser( + layers.LayerTypeIPv4, + &d.eth, &d.ip4, &d.ip6, &d.icmp4, &d.icmp6, &d.tcp, &d.udp, + ) + d.parser.IgnoreUnsupported = true + return d + }, + } + + hookCalled := false + hookID := manager.AddUDPPacketHook( + false, + net.ParseIP("100.10.0.100"), + 53, + func([]byte) bool { + hookCalled = true + return true + }, + ) + require.NotEmpty(t, hookID) + + // Create test UDP packet + ipv4 := &layers.IPv4{ + TTL: 64, + Version: 4, + SrcIP: net.ParseIP("100.10.0.1"), + DstIP: net.ParseIP("100.10.0.100"), + Protocol: layers.IPProtocolUDP, + } + udp := &layers.UDP{ + SrcPort: 51334, + DstPort: 53, + } + + err = udp.SetNetworkLayerForChecksum(ipv4) + require.NoError(t, err) + payload := gopacket.Payload("test") + + buf := gopacket.NewSerializeBuffer() + opts := gopacket.SerializeOptions{ + ComputeChecksums: true, + FixLengths: true, + } + err = gopacket.SerializeLayers(buf, opts, ipv4, udp, payload) + require.NoError(t, err) + + // Test hook gets called + result := manager.processOutgoingHooks(buf.Bytes()) + require.True(t, result) + require.True(t, hookCalled) + + // Test non-UDP packet is ignored + ipv4.Protocol = layers.IPProtocolTCP + buf = gopacket.NewSerializeBuffer() + err = gopacket.SerializeLayers(buf, opts, ipv4) + require.NoError(t, err) + + result = manager.processOutgoingHooks(buf.Bytes()) + require.False(t, result) +} + func TestUSPFilterCreatePerformance(t *testing.T) { for _, testMax := range []int{10, 20, 30, 40, 50, 60, 70, 80, 90, 100, 200, 300, 400, 500, 600, 700, 800, 900, 1000} { t.Run(fmt.Sprintf("Testing %d rules", testMax), func(t *testing.T) { @@ -418,3 +505,213 @@ func TestUSPFilterCreatePerformance(t *testing.T) { }) } } + +func TestStatefulFirewall_UDPTracking(t *testing.T) { + manager, err := Create(&IFaceMock{ + SetFilterFunc: func(device.PacketFilter) error { return nil }, + }) + require.NoError(t, err) + + manager.wgNetwork = &net.IPNet{ + IP: net.ParseIP("100.10.0.0"), + Mask: net.CIDRMask(16, 32), + } + + manager.udpTracker.Close() // Close the existing tracker + manager.udpTracker = conntrack.NewUDPTracker(200 * time.Millisecond) + manager.decoders = sync.Pool{ + New: func() any { + d := &decoder{ + decoded: []gopacket.LayerType{}, + } + d.parser = gopacket.NewDecodingLayerParser( + layers.LayerTypeIPv4, + &d.eth, &d.ip4, &d.ip6, &d.icmp4, &d.icmp6, &d.tcp, 
&d.udp, + ) + d.parser.IgnoreUnsupported = true + return d + }, + } + defer func() { + require.NoError(t, manager.Reset(nil)) + }() + + // Set up packet parameters + srcIP := net.ParseIP("100.10.0.1") + dstIP := net.ParseIP("100.10.0.100") + srcPort := uint16(51334) + dstPort := uint16(53) + + // Create outbound packet + outboundIPv4 := &layers.IPv4{ + TTL: 64, + Version: 4, + SrcIP: srcIP, + DstIP: dstIP, + Protocol: layers.IPProtocolUDP, + } + outboundUDP := &layers.UDP{ + SrcPort: layers.UDPPort(srcPort), + DstPort: layers.UDPPort(dstPort), + } + + err = outboundUDP.SetNetworkLayerForChecksum(outboundIPv4) + require.NoError(t, err) + + outboundBuf := gopacket.NewSerializeBuffer() + opts := gopacket.SerializeOptions{ + ComputeChecksums: true, + FixLengths: true, + } + + err = gopacket.SerializeLayers(outboundBuf, opts, + outboundIPv4, + outboundUDP, + gopacket.Payload("test"), + ) + require.NoError(t, err) + + // Process outbound packet and verify connection tracking + drop := manager.DropOutgoing(outboundBuf.Bytes()) + require.False(t, drop, "Initial outbound packet should not be dropped") + + // Verify connection was tracked + conn, exists := manager.udpTracker.GetConnection(srcIP, srcPort, dstIP, dstPort) + + require.True(t, exists, "Connection should be tracked after outbound packet") + require.True(t, conntrack.ValidateIPs(conntrack.MakeIPAddr(srcIP), conn.SourceIP), "Source IP should match") + require.True(t, conntrack.ValidateIPs(conntrack.MakeIPAddr(dstIP), conn.DestIP), "Destination IP should match") + require.Equal(t, srcPort, conn.SourcePort, "Source port should match") + require.Equal(t, dstPort, conn.DestPort, "Destination port should match") + + // Create valid inbound response packet + inboundIPv4 := &layers.IPv4{ + TTL: 64, + Version: 4, + SrcIP: dstIP, // Original destination is now source + DstIP: srcIP, // Original source is now destination + Protocol: layers.IPProtocolUDP, + } + inboundUDP := &layers.UDP{ + SrcPort: layers.UDPPort(dstPort), // Original destination port is now source + DstPort: layers.UDPPort(srcPort), // Original source port is now destination + } + + err = inboundUDP.SetNetworkLayerForChecksum(inboundIPv4) + require.NoError(t, err) + + inboundBuf := gopacket.NewSerializeBuffer() + err = gopacket.SerializeLayers(inboundBuf, opts, + inboundIPv4, + inboundUDP, + gopacket.Payload("response"), + ) + require.NoError(t, err) + // Test roundtrip response handling over time + checkPoints := []struct { + sleep time.Duration + shouldAllow bool + description string + }{ + { + sleep: 0, + shouldAllow: true, + description: "Immediate response should be allowed", + }, + { + sleep: 50 * time.Millisecond, + shouldAllow: true, + description: "Response within timeout should be allowed", + }, + { + sleep: 100 * time.Millisecond, + shouldAllow: true, + description: "Response at half timeout should be allowed", + }, + { + // tracker hasn't updated conn for 250ms -> greater than 200ms timeout + sleep: 250 * time.Millisecond, + shouldAllow: false, + description: "Response after timeout should be dropped", + }, + } + + for _, cp := range checkPoints { + time.Sleep(cp.sleep) + + drop = manager.dropFilter(inboundBuf.Bytes(), manager.incomingRules) + require.Equal(t, cp.shouldAllow, !drop, cp.description) + + // If the connection should still be valid, verify it exists + if cp.shouldAllow { + conn, exists := manager.udpTracker.GetConnection(srcIP, srcPort, dstIP, dstPort) + require.True(t, exists, "Connection should still exist during valid window") + require.True(t, 
time.Since(conn.GetLastSeen()) < manager.udpTracker.Timeout(), + "LastSeen should be updated for valid responses") + } + } + + // Test invalid response packets (while connection is expired) + invalidCases := []struct { + name string + modifyFunc func(*layers.IPv4, *layers.UDP) + description string + }{ + { + name: "wrong source IP", + modifyFunc: func(ip *layers.IPv4, udp *layers.UDP) { + ip.SrcIP = net.ParseIP("100.10.0.101") + }, + description: "Response from wrong IP should be dropped", + }, + { + name: "wrong destination IP", + modifyFunc: func(ip *layers.IPv4, udp *layers.UDP) { + ip.DstIP = net.ParseIP("100.10.0.2") + }, + description: "Response to wrong IP should be dropped", + }, + { + name: "wrong source port", + modifyFunc: func(ip *layers.IPv4, udp *layers.UDP) { + udp.SrcPort = 54 + }, + description: "Response from wrong port should be dropped", + }, + { + name: "wrong destination port", + modifyFunc: func(ip *layers.IPv4, udp *layers.UDP) { + udp.DstPort = 51335 + }, + description: "Response to wrong port should be dropped", + }, + } + + // Create a new outbound connection for invalid tests + drop = manager.processOutgoingHooks(outboundBuf.Bytes()) + require.False(t, drop, "Second outbound packet should not be dropped") + + for _, tc := range invalidCases { + t.Run(tc.name, func(t *testing.T) { + testIPv4 := *inboundIPv4 + testUDP := *inboundUDP + + tc.modifyFunc(&testIPv4, &testUDP) + + err = testUDP.SetNetworkLayerForChecksum(&testIPv4) + require.NoError(t, err) + + testBuf := gopacket.NewSerializeBuffer() + err = gopacket.SerializeLayers(testBuf, opts, + &testIPv4, + &testUDP, + gopacket.Payload("response"), + ) + require.NoError(t, err) + + // Verify the invalid packet is dropped + drop = manager.dropFilter(testBuf.Bytes(), manager.incomingRules) + require.True(t, drop, tc.description) + }) + } +} From 0dbaddc7bec1365d480aad7cf0505067a13ef32a Mon Sep 17 00:00:00 2001 From: Viktor Liu <17948409+lixmal@users.noreply.github.com> Date: Tue, 24 Dec 2024 15:05:23 +0100 Subject: [PATCH 094/159] [client] Don't fail debug if log file is console (#3103) --- client/server/debug.go | 36 +++++++++++++++++++----------------- 1 file changed, 19 insertions(+), 17 deletions(-) diff --git a/client/server/debug.go b/client/server/debug.go index 9dfde0367..3c4967b4e 100644 --- a/client/server/debug.go +++ b/client/server/debug.go @@ -139,10 +139,6 @@ func (s *Server) DebugBundle(_ context.Context, req *proto.DebugBundleRequest) ( s.mutex.Lock() defer s.mutex.Unlock() - if s.logFile == "console" { - return nil, fmt.Errorf("log file is set to console, cannot create debug bundle") - } - bundlePath, err := os.CreateTemp("", "netbird.debug.*.zip") if err != nil { return nil, fmt.Errorf("create zip file: %w", err) @@ -185,17 +181,7 @@ func (s *Server) createArchive(bundlePath *os.File, req *proto.DebugBundleReques } if req.GetSystemInfo() { - if err := s.addRoutes(req, anonymizer, archive); err != nil { - log.Errorf("Failed to add routes to debug bundle: %v", err) - } - - if err := s.addInterfaces(req, anonymizer, archive); err != nil { - log.Errorf("Failed to add interfaces to debug bundle: %v", err) - } - - if err := s.addFirewallRules(req, anonymizer, archive); err != nil { - log.Errorf("Failed to add firewall rules to debug bundle: %v", err) - } + s.addSystemInfo(req, anonymizer, archive) } if err := s.addNetworkMap(req, anonymizer, archive); err != nil { @@ -206,8 +192,10 @@ func (s *Server) createArchive(bundlePath *os.File, req *proto.DebugBundleReques log.Errorf("Failed to add state file to 
debug bundle: %v", err) } - if err := s.addLogfile(req, anonymizer, archive); err != nil { - return fmt.Errorf("add log file: %w", err) + if s.logFile != "console" { + if err := s.addLogfile(req, anonymizer, archive); err != nil { + return fmt.Errorf("add log file: %w", err) + } } if err := archive.Close(); err != nil { @@ -216,6 +204,20 @@ func (s *Server) createArchive(bundlePath *os.File, req *proto.DebugBundleReques return nil } +func (s *Server) addSystemInfo(req *proto.DebugBundleRequest, anonymizer *anonymize.Anonymizer, archive *zip.Writer) { + if err := s.addRoutes(req, anonymizer, archive); err != nil { + log.Errorf("Failed to add routes to debug bundle: %v", err) + } + + if err := s.addInterfaces(req, anonymizer, archive); err != nil { + log.Errorf("Failed to add interfaces to debug bundle: %v", err) + } + + if err := s.addFirewallRules(req, anonymizer, archive); err != nil { + log.Errorf("Failed to add firewall rules to debug bundle: %v", err) + } +} + func (s *Server) addReadme(req *proto.DebugBundleRequest, archive *zip.Writer) error { if req.GetAnonymize() { readmeReader := strings.NewReader(readmeContent) From b3c87cb5d1c62c2158548a8a7d0d1d25bc57c0de Mon Sep 17 00:00:00 2001 From: Viktor Liu <17948409+lixmal@users.noreply.github.com> Date: Thu, 26 Dec 2024 00:51:27 +0100 Subject: [PATCH 095/159] [client] Fix inbound tracking in userspace firewall (#3111) * Don't create state for inbound SYN * Allow final ack in some cases * Relax state machine test a little --- client/firewall/uspfilter/conntrack/common.go | 1 - client/firewall/uspfilter/conntrack/tcp.go | 40 ++++--------------- .../firewall/uspfilter/conntrack/tcp_test.go | 7 +--- 3 files changed, 10 insertions(+), 38 deletions(-) diff --git a/client/firewall/uspfilter/conntrack/common.go b/client/firewall/uspfilter/conntrack/common.go index a4b1971bf..e459bc75a 100644 --- a/client/firewall/uspfilter/conntrack/common.go +++ b/client/firewall/uspfilter/conntrack/common.go @@ -10,7 +10,6 @@ import ( // BaseConnTrack provides common fields and locking for all connection types type BaseConnTrack struct { - sync.RWMutex SourceIP net.IP DestIP net.IP SourcePort uint16 diff --git a/client/firewall/uspfilter/conntrack/tcp.go b/client/firewall/uspfilter/conntrack/tcp.go index e8d20f41c..a7968dc73 100644 --- a/client/firewall/uspfilter/conntrack/tcp.go +++ b/client/firewall/uspfilter/conntrack/tcp.go @@ -62,6 +62,7 @@ type TCPConnKey struct { type TCPConnTrack struct { BaseConnTrack State TCPState + sync.RWMutex } // TCPTracker manages TCP connection states @@ -131,36 +132,8 @@ func (t *TCPTracker) IsValidInbound(srcIP net.IP, dstIP net.IP, srcPort uint16, return false } - // Handle new SYN packets - if flags&TCPSyn != 0 && flags&TCPAck == 0 { - key := makeConnKey(dstIP, srcIP, dstPort, srcPort) - t.mutex.Lock() - if _, exists := t.connections[key]; !exists { - // Use preallocated IPs - srcIPCopy := t.ipPool.Get() - dstIPCopy := t.ipPool.Get() - copyIP(srcIPCopy, dstIP) - copyIP(dstIPCopy, srcIP) - - conn := &TCPConnTrack{ - BaseConnTrack: BaseConnTrack{ - SourceIP: srcIPCopy, - DestIP: dstIPCopy, - SourcePort: dstPort, - DestPort: srcPort, - }, - State: TCPStateSynReceived, - } - conn.lastSeen.Store(time.Now().UnixNano()) - conn.established.Store(false) - t.connections[key] = conn - } - t.mutex.Unlock() - return true - } - - // Look up existing connection key := makeConnKey(dstIP, srcIP, dstPort, srcPort) + t.mutex.RLock() conn, exists := t.connections[key] t.mutex.RUnlock() @@ -172,8 +145,7 @@ func (t *TCPTracker) 
IsValidInbound(srcIP net.IP, dstIP net.IP, srcPort uint16, // Handle RST packets if flags&TCPRst != 0 { conn.Lock() - isEstablished := conn.IsEstablished() - if isEstablished || conn.State == TCPStateSynSent || conn.State == TCPStateSynReceived { + if conn.IsEstablished() || conn.State == TCPStateSynSent || conn.State == TCPStateSynReceived { conn.State = TCPStateClosed conn.SetEstablished(false) conn.Unlock() @@ -183,7 +155,6 @@ func (t *TCPTracker) IsValidInbound(srcIP net.IP, dstIP net.IP, srcPort uint16, return false } - // Update state conn.Lock() t.updateState(conn, flags, false) conn.UpdateLastSeen() @@ -306,6 +277,11 @@ func (t *TCPTracker) isValidStateForFlags(state TCPState, flags uint8) bool { return flags&TCPFin != 0 || flags&TCPAck != 0 case TCPStateLastAck: return flags&TCPAck != 0 + case TCPStateClosed: + // Accept retransmitted ACKs in closed state + // This is important because the final ACK might be lost + // and the peer will retransmit their FIN-ACK + return flags&TCPAck != 0 } return false } diff --git a/client/firewall/uspfilter/conntrack/tcp_test.go b/client/firewall/uspfilter/conntrack/tcp_test.go index 3933c8889..6c8f82423 100644 --- a/client/firewall/uspfilter/conntrack/tcp_test.go +++ b/client/firewall/uspfilter/conntrack/tcp_test.go @@ -125,11 +125,8 @@ func TestTCPStateMachine(t *testing.T) { valid := tracker.IsValidInbound(dstIP, srcIP, dstPort, srcPort, TCPRst) require.True(t, valid, "RST should be allowed for established connection") - // Verify connection is closed - valid = tracker.IsValidInbound(dstIP, srcIP, dstPort, srcPort, TCPPush|TCPAck) - t.Helper() - - require.False(t, valid, "Data should be blocked after RST") + // Connection is logically dead but we don't enforce blocking subsequent packets + // The connection will be cleaned up by timeout }, }, { From 445b626dc8cf1f8489291bc5da59bcfdf89391ad Mon Sep 17 00:00:00 2001 From: Bethuel Mmbaga Date: Fri, 27 Dec 2024 14:39:34 +0300 Subject: [PATCH 096/159] [management] Add missing group usage checks for network resources and routes access control (#3117) * Prevent deletion of groups linked to routes access control groups Signed-off-by: bcmmbaga * Prevent deletion of groups linked to network resource Signed-off-by: bcmmbaga --------- Signed-off-by: bcmmbaga --- management/server/group.go | 9 ++++++++- 1 file changed, 8 insertions(+), 1 deletion(-) diff --git a/management/server/group.go b/management/server/group.go index d433a3485..f1057dda6 100644 --- a/management/server/group.go +++ b/management/server/group.go @@ -474,6 +474,10 @@ func validateDeleteGroup(ctx context.Context, transaction store.Store, group *ty return status.Errorf(status.InvalidArgument, "deleting group ALL is not allowed") } + if len(group.Resources) > 0 { + return &GroupLinkError{"network resource", group.Resources[0].ID} + } + if isLinked, linkedRoute := isGroupLinkedToRoute(ctx, transaction, group.AccountID, group.ID); isLinked { return &GroupLinkError{"route", string(linkedRoute.NetID)} } @@ -529,7 +533,10 @@ func isGroupLinkedToRoute(ctx context.Context, transaction store.Store, accountI } for _, r := range routes { - if slices.Contains(r.Groups, groupID) || slices.Contains(r.PeerGroups, groupID) { + isLinked := slices.Contains(r.Groups, groupID) || + slices.Contains(r.PeerGroups, groupID) || + slices.Contains(r.AccessControlGroups, groupID) + if isLinked { return true, r } } From fbce8bb51197c2e29a168386bf3ba2da7fa16c1a Mon Sep 17 00:00:00 2001 From: Pascal Fischer <32096965+pascal-fischer@users.noreply.github.com> Date: Fri, 
27 Dec 2024 14:13:36 +0100 Subject: [PATCH 097/159] [management] remove ids from policy creation api (#2997) --- management/server/http/api/openapi.yml | 62 ++++++++++++++++--- management/server/http/api/types.gen.go | 37 ++++++----- .../handlers/policies/policies_handler.go | 11 +++- .../policies/policies_handler_test.go | 11 ++-- 4 files changed, 90 insertions(+), 31 deletions(-) diff --git a/management/server/http/api/openapi.yml b/management/server/http/api/openapi.yml index 351976baf..6c1d6b424 100644 --- a/management/server/http/api/openapi.yml +++ b/management/server/http/api/openapi.yml @@ -725,10 +725,6 @@ components: PolicyRuleMinimum: type: object properties: - id: - description: Policy rule ID - type: string - example: ch8i4ug6lnn4g9hqv7mg name: description: Policy rule name identifier type: string @@ -790,6 +786,31 @@ components: - end PolicyRuleUpdate: + allOf: + - $ref: '#/components/schemas/PolicyRuleMinimum' + - type: object + properties: + id: + description: Policy rule ID + type: string + example: ch8i4ug6lnn4g9hqv7mg + sources: + description: Policy rule source group IDs + type: array + items: + type: string + example: "ch8i4ug6lnn4g9hqv797" + destinations: + description: Policy rule destination group IDs + type: array + items: + type: string + example: "ch8i4ug6lnn4g9h7v7m0" + required: + - sources + - destinations + + PolicyRuleCreate: allOf: - $ref: '#/components/schemas/PolicyRuleMinimum' - type: object @@ -817,6 +838,10 @@ components: - $ref: '#/components/schemas/PolicyRuleMinimum' - type: object properties: + id: + description: Policy rule ID + type: string + example: ch8i4ug6lnn4g9hqv7mg sources: description: Policy rule source group IDs type: array @@ -836,10 +861,6 @@ components: PolicyMinimum: type: object properties: - id: - description: Policy ID - type: string - example: ch8i4ug6lnn4g9hqv7mg name: description: Policy name identifier type: string @@ -854,7 +875,6 @@ components: example: true required: - name - - description - enabled PolicyUpdate: allOf: @@ -874,11 +894,33 @@ components: $ref: '#/components/schemas/PolicyRuleUpdate' required: - rules + PolicyCreate: + allOf: + - $ref: '#/components/schemas/PolicyMinimum' + - type: object + properties: + source_posture_checks: + description: Posture checks ID's applied to policy source groups + type: array + items: + type: string + example: "chacdk86lnnboviihd70" + rules: + description: Policy rule object for policy UI editor + type: array + items: + $ref: '#/components/schemas/PolicyRuleUpdate' + required: + - rules Policy: allOf: - $ref: '#/components/schemas/PolicyMinimum' - type: object properties: + id: + description: Policy ID + type: string + example: ch8i4ug6lnn4g9hqv7mg source_posture_checks: description: Posture checks ID's applied to policy source groups type: array @@ -2463,7 +2505,7 @@ paths: content: 'application/json': schema: - $ref: '#/components/schemas/PolicyUpdate' + $ref: '#/components/schemas/PolicyCreate' responses: '200': description: A Policy object diff --git a/management/server/http/api/types.gen.go b/management/server/http/api/types.gen.go index 40574d6f1..83226587f 100644 --- a/management/server/http/api/types.gen.go +++ b/management/server/http/api/types.gen.go @@ -879,7 +879,7 @@ type PersonalAccessTokenRequest struct { // Policy defines model for Policy. 
type Policy struct { // Description Policy friendly description - Description string `json:"description"` + Description *string `json:"description,omitempty"` // Enabled Policy status Enabled bool `json:"enabled"` @@ -897,16 +897,31 @@ type Policy struct { SourcePostureChecks []string `json:"source_posture_checks"` } -// PolicyMinimum defines model for PolicyMinimum. -type PolicyMinimum struct { +// PolicyCreate defines model for PolicyCreate. +type PolicyCreate struct { // Description Policy friendly description - Description string `json:"description"` + Description *string `json:"description,omitempty"` // Enabled Policy status Enabled bool `json:"enabled"` - // Id Policy ID - Id *string `json:"id,omitempty"` + // Name Policy name identifier + Name string `json:"name"` + + // Rules Policy rule object for policy UI editor + Rules []PolicyRuleUpdate `json:"rules"` + + // SourcePostureChecks Posture checks ID's applied to policy source groups + SourcePostureChecks *[]string `json:"source_posture_checks,omitempty"` +} + +// PolicyMinimum defines model for PolicyMinimum. +type PolicyMinimum struct { + // Description Policy friendly description + Description *string `json:"description,omitempty"` + + // Enabled Policy status + Enabled bool `json:"enabled"` // Name Policy name identifier Name string `json:"name"` @@ -970,9 +985,6 @@ type PolicyRuleMinimum struct { // Enabled Policy rule status Enabled bool `json:"enabled"` - // Id Policy rule ID - Id *string `json:"id,omitempty"` - // Name Policy rule name identifier Name string `json:"name"` @@ -1039,14 +1051,11 @@ type PolicyRuleUpdateProtocol string // PolicyUpdate defines model for PolicyUpdate. type PolicyUpdate struct { // Description Policy friendly description - Description string `json:"description"` + Description *string `json:"description,omitempty"` // Enabled Policy status Enabled bool `json:"enabled"` - // Id Policy ID - Id *string `json:"id,omitempty"` - // Name Policy name identifier Name string `json:"name"` @@ -1473,7 +1482,7 @@ type PutApiPeersPeerIdJSONRequestBody = PeerRequest type PostApiPoliciesJSONRequestBody = PolicyUpdate // PutApiPoliciesPolicyIdJSONRequestBody defines body for PutApiPoliciesPolicyId for application/json ContentType. -type PutApiPoliciesPolicyIdJSONRequestBody = PolicyUpdate +type PutApiPoliciesPolicyIdJSONRequestBody = PolicyCreate // PostApiPostureChecksJSONRequestBody defines body for PostApiPostureChecks for application/json ContentType. 
type PostApiPostureChecksJSONRequestBody = PostureCheckUpdate diff --git a/management/server/http/handlers/policies/policies_handler.go b/management/server/http/handlers/policies/policies_handler.go index d538d07db..b1035c570 100644 --- a/management/server/http/handlers/policies/policies_handler.go +++ b/management/server/http/handlers/policies/policies_handler.go @@ -133,16 +133,21 @@ func (h *handler) savePolicy(w http.ResponseWriter, r *http.Request, accountID s return } + description := "" + if req.Description != nil { + description = *req.Description + } + policy := &types.Policy{ ID: policyID, AccountID: accountID, Name: req.Name, Enabled: req.Enabled, - Description: req.Description, + Description: description, } for _, rule := range req.Rules { var ruleID string - if rule.Id != nil { + if rule.Id != nil && policyID != "" { ruleID = *rule.Id } @@ -370,7 +375,7 @@ func toPolicyResponse(groups []*types.Group, policy *types.Policy) *api.Policy { ap := &api.Policy{ Id: &policy.ID, Name: policy.Name, - Description: policy.Description, + Description: &policy.Description, Enabled: policy.Enabled, SourcePostureChecks: policy.SourcePostureChecks, } diff --git a/management/server/http/handlers/policies/policies_handler_test.go b/management/server/http/handlers/policies/policies_handler_test.go index 956d0b7cd..3e1be187c 100644 --- a/management/server/http/handlers/policies/policies_handler_test.go +++ b/management/server/http/handlers/policies/policies_handler_test.go @@ -154,6 +154,7 @@ func TestPoliciesGetPolicy(t *testing.T) { func TestPoliciesWritePolicy(t *testing.T) { str := func(s string) *string { return &s } + emptyString := "" tt := []struct { name string expectedStatus int @@ -184,8 +185,9 @@ func TestPoliciesWritePolicy(t *testing.T) { expectedStatus: http.StatusOK, expectedBody: true, expectedPolicy: &api.Policy{ - Id: str("id-was-set"), - Name: "Default POSTed Policy", + Id: str("id-was-set"), + Name: "Default POSTed Policy", + Description: &emptyString, Rules: []api.PolicyRule{ { Id: str("id-was-set"), @@ -232,8 +234,9 @@ func TestPoliciesWritePolicy(t *testing.T) { expectedStatus: http.StatusOK, expectedBody: true, expectedPolicy: &api.Policy{ - Id: str("id-existed"), - Name: "Default POSTed Policy", + Id: str("id-existed"), + Name: "Default POSTed Policy", + Description: &emptyString, Rules: []api.PolicyRule{ { Id: str("id-existed"), From 1a623943c88e872e746d3248a6a2ce963b997bcc Mon Sep 17 00:00:00 2001 From: Maycon Santos Date: Mon, 30 Dec 2024 12:40:24 +0100 Subject: [PATCH 098/159] [management] Fix networks net map generation with posture checks (#3124) --- go.mod | 1 + go.sum | 2 + management/server/account_test.go | 6 +- management/server/peer_test.go | 6 +- management/server/policy_test.go | 32 +- management/server/route.go | 2 +- management/server/types/account.go | 140 +++++---- management/server/types/account_test.go | 382 ++++++++++++++++++++++++ management/server/types/policy.go | 12 +- 9 files changed, 508 insertions(+), 75 deletions(-) diff --git a/go.mod b/go.mod index d48280df0..330d0763f 100644 --- a/go.mod +++ b/go.mod @@ -79,6 +79,7 @@ require ( github.com/testcontainers/testcontainers-go v0.31.0 github.com/testcontainers/testcontainers-go/modules/postgres v0.31.0 github.com/things-go/go-socks5 v0.0.4 + github.com/yourbasic/radix v0.0.0-20180308122924-cbe1cc82e907 github.com/yusufpapurcu/wmi v1.2.4 github.com/zcalusic/sysinfo v1.1.3 go.opentelemetry.io/contrib/instrumentation/google.golang.org/grpc/otelgrpc v0.49.0 diff --git a/go.sum b/go.sum index 
540cbf20b..ea4597836 100644 --- a/go.sum +++ b/go.sum @@ -698,6 +698,8 @@ github.com/vishvananda/netlink v1.2.1-beta.2/go.mod h1:twkDnbuQxJYemMlGd4JFIcuhg github.com/vishvananda/netns v0.0.0-20200728191858-db3c7e526aae/go.mod h1:DD4vA1DwXk04H54A1oHXtwZmA0grkVMdPxx/VGLCah0= github.com/vishvananda/netns v0.0.4 h1:Oeaw1EM2JMxD51g9uhtC0D7erkIjgmj8+JZc26m1YX8= github.com/vishvananda/netns v0.0.4/go.mod h1:SpkAiCQRtJ6TvvxPnOSyH3BMl6unz3xZlaprSwhNNJM= +github.com/yourbasic/radix v0.0.0-20180308122924-cbe1cc82e907 h1:S5h7yNKStqF8CqFtgtMNMzk/lUI3p82LrX6h2BhlsTM= +github.com/yourbasic/radix v0.0.0-20180308122924-cbe1cc82e907/go.mod h1:/7Fy/4/OyrkguTf2i2pO4erUD/8QAlrlmXSdSJPu678= github.com/yuin/goldmark v1.1.25/go.mod h1:3hX8gzYuyVAZsxl0MRgGTJEmQBFcNTphYh9decYSb74= github.com/yuin/goldmark v1.1.27/go.mod h1:3hX8gzYuyVAZsxl0MRgGTJEmQBFcNTphYh9decYSb74= github.com/yuin/goldmark v1.1.32/go.mod h1:3hX8gzYuyVAZsxl0MRgGTJEmQBFcNTphYh9decYSb74= diff --git a/management/server/account_test.go b/management/server/account_test.go index d83eab6d1..280d998fd 100644 --- a/management/server/account_test.go +++ b/management/server/account_test.go @@ -3037,9 +3037,9 @@ func BenchmarkSyncAndMarkPeer(b *testing.B) { minMsPerOpCICD float64 maxMsPerOpCICD float64 }{ - {"Small", 50, 5, 1, 3, 3, 10}, + {"Small", 50, 5, 1, 3, 3, 11}, {"Medium", 500, 100, 7, 13, 10, 70}, - {"Large", 5000, 200, 65, 80, 60, 200}, + {"Large", 5000, 200, 65, 80, 60, 220}, {"Small single", 50, 10, 1, 3, 3, 70}, {"Medium single", 500, 10, 7, 13, 10, 26}, {"Large 5", 5000, 15, 65, 80, 60, 200}, @@ -3179,7 +3179,7 @@ func BenchmarkLoginPeer_NewPeer(b *testing.B) { maxMsPerOpCICD float64 }{ {"Small", 50, 5, 107, 120, 107, 160}, - {"Medium", 500, 100, 105, 140, 105, 190}, + {"Medium", 500, 100, 105, 140, 105, 220}, {"Large", 5000, 200, 180, 220, 180, 350}, {"Small single", 50, 10, 107, 120, 105, 160}, {"Medium single", 500, 10, 105, 140, 105, 170}, diff --git a/management/server/peer_test.go b/management/server/peer_test.go index 2ab262ff0..9ad67d2bf 100644 --- a/management/server/peer_test.go +++ b/management/server/peer_test.go @@ -932,11 +932,11 @@ func BenchmarkUpdateAccountPeers(b *testing.B) { }{ {"Small", 50, 5, 90, 120, 90, 120}, {"Medium", 500, 100, 110, 150, 120, 260}, - {"Large", 5000, 200, 800, 1390, 2500, 4600}, + {"Large", 5000, 200, 800, 1700, 2500, 5000}, {"Small single", 50, 10, 90, 120, 90, 120}, {"Medium single", 500, 10, 110, 170, 120, 200}, - {"Large 5", 5000, 15, 1300, 2100, 5000, 7000}, - {"Extra Large", 2000, 2000, 1300, 2100, 4000, 6000}, + {"Large 5", 5000, 15, 1300, 2100, 4900, 7000}, + {"Extra Large", 2000, 2000, 1300, 2400, 4000, 6400}, } log.SetOutput(io.Discard) diff --git a/management/server/policy_test.go b/management/server/policy_test.go index fab738abe..0d17da23a 100644 --- a/management/server/policy_test.go +++ b/management/server/policy_test.go @@ -74,6 +74,19 @@ func TestAccount_getPeersByPolicy(t *testing.T) { "peerH", }, }, + "GroupWorkstations": { + ID: "GroupWorkstations", + Name: "All", + Peers: []string{ + "peerB", + "peerA", + "peerD", + "peerE", + "peerF", + "peerG", + "peerH", + }, + }, "GroupSwarm": { ID: "GroupSwarm", Name: "swarm", @@ -127,7 +140,7 @@ func TestAccount_getPeersByPolicy(t *testing.T) { Action: types.PolicyTrafficActionAccept, Sources: []string{ "GroupSwarm", - "GroupAll", + "GroupWorkstations", }, Destinations: []string{ "GroupSwarm", @@ -159,6 +172,8 @@ func TestAccount_getPeersByPolicy(t *testing.T) { assert.Contains(t, peers, account.Peers["peerD"]) assert.Contains(t, peers, 
account.Peers["peerE"]) assert.Contains(t, peers, account.Peers["peerF"]) + assert.Contains(t, peers, account.Peers["peerG"]) + assert.Contains(t, peers, account.Peers["peerH"]) epectedFirewallRules := []*types.FirewallRule{ { @@ -189,21 +204,6 @@ func TestAccount_getPeersByPolicy(t *testing.T) { Protocol: "all", Port: "", }, - { - PeerIP: "100.65.254.139", - Direction: types.FirewallRuleDirectionOUT, - Action: "accept", - Protocol: "all", - Port: "", - }, - { - PeerIP: "100.65.254.139", - Direction: types.FirewallRuleDirectionIN, - Action: "accept", - Protocol: "all", - Port: "", - }, - { PeerIP: "100.65.62.5", Direction: types.FirewallRuleDirectionOUT, diff --git a/management/server/route.go b/management/server/route.go index 1eb51aea7..b6b44fbbd 100644 --- a/management/server/route.go +++ b/management/server/route.go @@ -364,7 +364,7 @@ func toProtocolRoute(route *route.Route) *proto.Route { } func toProtocolRoutes(routes []*route.Route) []*proto.Route { - protoRoutes := make([]*proto.Route, 0) + protoRoutes := make([]*proto.Route, 0, len(routes)) for _, r := range routes { protoRoutes = append(protoRoutes, toProtocolRoute(r)) } diff --git a/management/server/types/account.go b/management/server/types/account.go index b36b719e4..3ef862fa6 100644 --- a/management/server/types/account.go +++ b/management/server/types/account.go @@ -13,6 +13,7 @@ import ( "github.com/hashicorp/go-multierror" "github.com/miekg/dns" log "github.com/sirupsen/logrus" + "github.com/yourbasic/radix" nbdns "github.com/netbirdio/netbird/dns" "github.com/netbirdio/netbird/management/domain" @@ -1045,37 +1046,32 @@ func (a *Account) connResourcesGenerator(ctx context.Context) (func(*PolicyRule, // for destination group peers, call this method with an empty list of sourcePostureChecksIDs func (a *Account) getAllPeersFromGroups(ctx context.Context, groups []string, peerID string, sourcePostureChecksIDs []string, validatedPeersMap map[string]struct{}) ([]*nbpeer.Peer, bool) { peerInGroups := false - filteredPeers := make([]*nbpeer.Peer, 0, len(groups)) - for _, g := range groups { - group, ok := a.Groups[g] - if !ok { + uniquePeerIDs := a.getUniquePeerIDsFromGroupsIDs(ctx, groups) + filteredPeers := make([]*nbpeer.Peer, 0, len(uniquePeerIDs)) + for _, p := range uniquePeerIDs { + peer, ok := a.Peers[p] + if !ok || peer == nil { continue } - for _, p := range group.Peers { - peer, ok := a.Peers[p] - if !ok || peer == nil { - continue - } - - // validate the peer based on policy posture checks applied - isValid := a.validatePostureChecksOnPeer(ctx, sourcePostureChecksIDs, peer.ID) - if !isValid { - continue - } - - if _, ok := validatedPeersMap[peer.ID]; !ok { - continue - } - - if peer.ID == peerID { - peerInGroups = true - continue - } - - filteredPeers = append(filteredPeers, peer) + // validate the peer based on policy posture checks applied + isValid := a.validatePostureChecksOnPeer(ctx, sourcePostureChecksIDs, peer.ID) + if !isValid { + continue } + + if _, ok := validatedPeersMap[peer.ID]; !ok { + continue + } + + if peer.ID == peerID { + peerInGroups = true + continue + } + + filteredPeers = append(filteredPeers, peer) } + return filteredPeers, peerInGroups } @@ -1151,7 +1147,7 @@ func (a *Account) getRouteFirewallRules(ctx context.Context, peerID string, poli continue } - rulePeers := a.getRulePeers(rule, peerID, distributionPeers, validatedPeersMap) + rulePeers := a.getRulePeers(rule, policy.SourcePostureChecks, peerID, distributionPeers, validatedPeersMap) rules := generateRouteFirewallRules(ctx, route, rule, 
rulePeers, FirewallRuleDirectionIN) fwRules = append(fwRules, rules...) } @@ -1159,8 +1155,8 @@ func (a *Account) getRouteFirewallRules(ctx context.Context, peerID string, poli return fwRules } -func (a *Account) getRulePeers(rule *PolicyRule, peerID string, distributionPeers map[string]struct{}, validatedPeersMap map[string]struct{}) []*nbpeer.Peer { - distPeersWithPolicy := make(map[string]struct{}) +func (a *Account) getRulePeers(rule *PolicyRule, postureChecks []string, peerID string, distributionPeers map[string]struct{}, validatedPeersMap map[string]struct{}) []*nbpeer.Peer { + distPeersWithPolicy := make([]string, 0) for _, id := range rule.Sources { group := a.Groups[id] if group == nil { @@ -1173,14 +1169,17 @@ func (a *Account) getRulePeers(rule *PolicyRule, peerID string, distributionPeer } _, distPeer := distributionPeers[pID] _, valid := validatedPeersMap[pID] - if distPeer && valid { - distPeersWithPolicy[pID] = struct{}{} + if distPeer && valid && a.validatePostureChecksOnPeer(context.Background(), postureChecks, pID) { + distPeersWithPolicy = append(distPeersWithPolicy, pID) } } } - distributionGroupPeers := make([]*nbpeer.Peer, 0, len(distPeersWithPolicy)) - for pID := range distPeersWithPolicy { + radix.Sort(distPeersWithPolicy) + uniqueDistributionPeers := slices.Compact(distPeersWithPolicy) + + distributionGroupPeers := make([]*nbpeer.Peer, 0, len(uniqueDistributionPeers)) + for _, pID := range uniqueDistributionPeers { peer := a.Peers[pID] if peer == nil { continue @@ -1271,7 +1270,11 @@ func (a *Account) GetPeerNetworkResourceFirewallRules(ctx context.Context, peer distributionPeers := getPoliciesSourcePeers(resourceAppliedPolicies, a.Groups) rules := a.getRouteFirewallRules(ctx, peer.ID, resourceAppliedPolicies, route, validatedPeersMap, distributionPeers) - routesFirewallRules = append(routesFirewallRules, rules...) + for _, rule := range rules { + if len(rule.SourceRanges) > 0 { + routesFirewallRules = append(routesFirewallRules, rule) + } + } } return routesFirewallRules @@ -1306,7 +1309,7 @@ func (a *Account) GetResourcePoliciesMap() map[string][]*Policy { func (a *Account) GetNetworkResourcesRoutesToSync(ctx context.Context, peerID string, resourcePolicies map[string][]*Policy, routers map[string]map[string]*routerTypes.NetworkRouter) (bool, []*route.Route, []string) { var isRoutingPeer bool var routes []*route.Route - var allSourcePeers []string + allSourcePeers := make([]string, 0) for _, resource := range a.NetworkResources { var addSourcePeers bool @@ -1319,28 +1322,63 @@ func (a *Account) GetNetworkResourcesRoutesToSync(ctx context.Context, peerID st } } + addedResourceRoute := false for _, policy := range resourcePolicies[resource.ID] { - for _, sourceGroup := range policy.SourceGroups() { - group := a.GetGroup(sourceGroup) - if group == nil { - log.WithContext(ctx).Warnf("policy %s has source group %s that doesn't exist under account %s, will continue map generation without it", policy.ID, sourceGroup, a.Id) - continue - } - - // routing peer should be able to connect with all source peers - if addSourcePeers { - allSourcePeers = append(allSourcePeers, group.Peers...) - } else if slices.Contains(group.Peers, peerID) { - // add routes for the resource if the peer is in the distribution group - for peerId, router := range networkRoutingPeers { - routes = append(routes, a.getNetworkResourcesRoutes(resource, peerId, router, resourcePolicies)...) 
- } + peers := a.getUniquePeerIDsFromGroupsIDs(ctx, policy.SourceGroups()) + if addSourcePeers { + allSourcePeers = append(allSourcePeers, a.getPostureValidPeers(peers, policy.SourcePostureChecks)...) + } else if slices.Contains(peers, peerID) && a.validatePostureChecksOnPeer(ctx, policy.SourcePostureChecks, peerID) { + // add routes for the resource if the peer is in the distribution group + for peerId, router := range networkRoutingPeers { + routes = append(routes, a.getNetworkResourcesRoutes(resource, peerId, router, resourcePolicies)...) } + addedResourceRoute = true + } + if addedResourceRoute { + break } } } - return isRoutingPeer, routes, allSourcePeers + radix.Sort(allSourcePeers) + return isRoutingPeer, routes, slices.Compact(allSourcePeers) +} + +func (a *Account) getPostureValidPeers(inputPeers []string, postureChecksIDs []string) []string { + var dest []string + for _, peerID := range inputPeers { + if a.validatePostureChecksOnPeer(context.Background(), postureChecksIDs, peerID) { + dest = append(dest, peerID) + } + } + return dest +} + +func (a *Account) getUniquePeerIDsFromGroupsIDs(ctx context.Context, groups []string) []string { + gObjs := make([]*Group, 0, len(groups)) + tp := 0 + for _, groupID := range groups { + group := a.GetGroup(groupID) + if group == nil { + log.WithContext(ctx).Warnf("group %s doesn't exist under account %s, will continue map generation without it", groupID, a.Id) + continue + } + + if group.IsGroupAll() || len(groups) == 1 { + return group.Peers + } + + gObjs = append(gObjs, group) + tp += len(group.Peers) + } + + ids := make([]string, 0, tp) + for _, group := range gObjs { + ids = append(ids, group.Peers...) + } + + radix.Sort(ids) + return slices.Compact(ids) } // getNetworkResources filters and returns a list of network resources associated with the given network ID. 
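As an illustration of the de-duplication idiom the account.go hunks above rely on (radix.Sort from github.com/yourbasic/radix followed by the standard library's slices.Compact), here is a minimal, self-contained sketch; the uniqueSorted helper name and the sample peer IDs are illustrative assumptions, not part of this patch.

package main

import (
	"fmt"
	"slices"

	"github.com/yourbasic/radix"
)

// uniqueSorted returns ids sorted and with duplicates removed.
// radix.Sort orders the strings in place; slices.Compact then drops
// adjacent duplicates, which only works because the slice is already sorted.
func uniqueSorted(ids []string) []string {
	radix.Sort(ids)
	return slices.Compact(ids)
}

func main() {
	fmt.Println(uniqueSorted([]string{"peerB", "peerA", "peerB", "peerC"}))
	// prints: [peerA peerB peerC]
}

A later patch in this series, 103/159 ("remove sorting from network map generation"), drops the radix dependency again in favor of map-based de-duplication.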
diff --git a/management/server/types/account_test.go b/management/server/types/account_test.go index c73421d16..efe930108 100644 --- a/management/server/types/account_test.go +++ b/management/server/types/account_test.go @@ -1,14 +1,20 @@ package types import ( + "context" + "net" + "net/netip" + "slices" "testing" + "github.com/stretchr/testify/assert" "github.com/stretchr/testify/require" resourceTypes "github.com/netbirdio/netbird/management/server/networks/resources/types" routerTypes "github.com/netbirdio/netbird/management/server/networks/routers/types" networkTypes "github.com/netbirdio/netbird/management/server/networks/types" nbpeer "github.com/netbirdio/netbird/management/server/peer" + "github.com/netbirdio/netbird/management/server/posture" "github.com/netbirdio/netbird/route" ) @@ -373,3 +379,379 @@ func Test_AddNetworksRoutingPeersHandlesNoMissingPeers(t *testing.T) { result := account.addNetworksRoutingPeers(networkResourcesRoutes, peer, peersToConnect, expiredPeers, false, []string{}) require.Len(t, result, 0) } + +const ( + accID = "accountID" + network1ID = "network1ID" + group1ID = "group1" + accNetResourcePeer1ID = "peer1" + accNetResourcePeer2ID = "peer2" + accNetResourceRouter1ID = "router1" + accNetResource1ID = "resource1ID" + accNetResourceRestrictPostureCheckID = "restrictPostureCheck" + accNetResourceRelaxedPostureCheckID = "relaxedPostureCheck" + accNetResourceLockedPostureCheckID = "lockedPostureCheck" + accNetResourceLinuxPostureCheckID = "linuxPostureCheck" +) + +var ( + accNetResourcePeer1IP = net.IP{192, 168, 1, 1} + accNetResourcePeer2IP = net.IP{192, 168, 1, 2} + accNetResourceRouter1IP = net.IP{192, 168, 1, 3} + accNetResourceValidPeers = map[string]struct{}{accNetResourcePeer1ID: {}, accNetResourcePeer2ID: {}} +) + +func getBasicAccountsWithResource() *Account { + return &Account{ + Id: accID, + Peers: map[string]*nbpeer.Peer{ + accNetResourcePeer1ID: { + ID: accNetResourcePeer1ID, + AccountID: accID, + Key: "peer1Key", + IP: accNetResourcePeer1IP, + Meta: nbpeer.PeerSystemMeta{ + GoOS: "linux", + WtVersion: "0.35.1", + KernelVersion: "4.4.0", + }, + }, + accNetResourcePeer2ID: { + ID: accNetResourcePeer2ID, + AccountID: accID, + Key: "peer2Key", + IP: accNetResourcePeer2IP, + Meta: nbpeer.PeerSystemMeta{ + GoOS: "windows", + WtVersion: "0.34.1", + KernelVersion: "4.4.0", + }, + }, + accNetResourceRouter1ID: { + ID: accNetResourceRouter1ID, + AccountID: accID, + Key: "router1Key", + IP: accNetResourceRouter1IP, + Meta: nbpeer.PeerSystemMeta{ + GoOS: "linux", + WtVersion: "0.35.1", + KernelVersion: "4.4.0", + }, + }, + }, + Groups: map[string]*Group{ + group1ID: { + ID: group1ID, + Peers: []string{accNetResourcePeer1ID, accNetResourcePeer2ID}, + }, + }, + Networks: []*networkTypes.Network{ + { + ID: network1ID, + AccountID: accID, + Name: "network1", + }, + }, + NetworkRouters: []*routerTypes.NetworkRouter{ + { + ID: accNetResourceRouter1ID, + NetworkID: network1ID, + AccountID: accID, + Peer: accNetResourceRouter1ID, + PeerGroups: []string{}, + Masquerade: false, + Metric: 100, + }, + }, + NetworkResources: []*resourceTypes.NetworkResource{ + { + ID: accNetResource1ID, + AccountID: accID, + NetworkID: network1ID, + Address: "10.10.10.0/24", + Prefix: netip.MustParsePrefix("10.10.10.0/24"), + Type: resourceTypes.NetworkResourceType("subnet"), + }, + }, + Policies: []*Policy{ + { + ID: "policy1ID", + AccountID: accID, + Enabled: true, + Rules: []*PolicyRule{ + { + ID: "rule1ID", + Enabled: true, + Sources: []string{group1ID}, + DestinationResource: 
Resource{ + ID: accNetResource1ID, + Type: "Host", + }, + Protocol: PolicyRuleProtocolTCP, + Ports: []string{"80"}, + Action: PolicyTrafficActionAccept, + }, + }, + SourcePostureChecks: nil, + }, + }, + PostureChecks: []*posture.Checks{ + { + ID: accNetResourceRestrictPostureCheckID, + Name: accNetResourceRestrictPostureCheckID, + Checks: posture.ChecksDefinition{ + NBVersionCheck: &posture.NBVersionCheck{ + MinVersion: "0.35.0", + }, + }, + }, + { + ID: accNetResourceRelaxedPostureCheckID, + Name: accNetResourceRelaxedPostureCheckID, + Checks: posture.ChecksDefinition{ + NBVersionCheck: &posture.NBVersionCheck{ + MinVersion: "0.0.1", + }, + }, + }, + { + ID: accNetResourceLockedPostureCheckID, + Name: accNetResourceLockedPostureCheckID, + Checks: posture.ChecksDefinition{ + NBVersionCheck: &posture.NBVersionCheck{ + MinVersion: "7.7.7", + }, + }, + }, + { + ID: accNetResourceLinuxPostureCheckID, + Name: accNetResourceLinuxPostureCheckID, + Checks: posture.ChecksDefinition{ + OSVersionCheck: &posture.OSVersionCheck{ + Linux: &posture.MinKernelVersionCheck{ + MinKernelVersion: "0.0.0"}, + }, + }, + }, + }, + } +} + +func Test_NetworksNetMapGenWithNoPostureChecks(t *testing.T) { + account := getBasicAccountsWithResource() + + // all peers should match the policy + + // validate for peer1 + isRouter, networkResourcesRoutes, sourcePeers := account.GetNetworkResourcesRoutesToSync(context.Background(), accNetResourcePeer1ID, account.GetResourcePoliciesMap(), account.GetResourceRoutersMap()) + assert.False(t, isRouter, "expected router status") + assert.Len(t, networkResourcesRoutes, 1, "expected network resource route don't match") + assert.Len(t, sourcePeers, 0, "expected source peers don't match") + + // validate for peer2 + isRouter, networkResourcesRoutes, sourcePeers = account.GetNetworkResourcesRoutesToSync(context.Background(), accNetResourcePeer2ID, account.GetResourcePoliciesMap(), account.GetResourceRoutersMap()) + assert.False(t, isRouter, "expected router status") + assert.Len(t, networkResourcesRoutes, 1, "expected network resource route don't match") + assert.Len(t, sourcePeers, 0, "expected source peers don't match") + + // validate routes for router1 + isRouter, networkResourcesRoutes, sourcePeers = account.GetNetworkResourcesRoutesToSync(context.Background(), accNetResourceRouter1ID, account.GetResourcePoliciesMap(), account.GetResourceRoutersMap()) + assert.True(t, isRouter, "should be router") + assert.Len(t, networkResourcesRoutes, 1, "expected network resource route don't match") + assert.Len(t, sourcePeers, 2, "expected source peers don't match") + assert.Equal(t, accNetResourcePeer1ID, sourcePeers[0], "expected source peers don't match") + assert.Equal(t, accNetResourcePeer2ID, sourcePeers[1], "expected source peers don't match") + + // validate rules for router1 + rules := account.GetPeerNetworkResourceFirewallRules(context.Background(), account.Peers[accNetResourceRouter1ID], accNetResourceValidPeers, networkResourcesRoutes, account.GetResourcePoliciesMap()) + assert.Len(t, rules, 1, "expected rules count don't match") + assert.Equal(t, uint16(80), rules[0].Port, "should have port 80") + assert.Equal(t, "tcp", rules[0].Protocol, "should have protocol tcp") + if !slices.Contains(rules[0].SourceRanges, accNetResourcePeer1IP.String()+"/32") { + t.Errorf("%s should have source range of peer1 %s", rules[0].SourceRanges, accNetResourcePeer1IP.String()) + } + if !slices.Contains(rules[0].SourceRanges, accNetResourcePeer2IP.String()+"/32") { + t.Errorf("%s should have source 
range of peer2 %s", rules[0].SourceRanges, accNetResourcePeer2IP.String()) + } +} + +func Test_NetworksNetMapGenWithPostureChecks(t *testing.T) { + account := getBasicAccountsWithResource() + + // should allow peer1 to match the policy + policy := account.Policies[0] + policy.SourcePostureChecks = []string{accNetResourceRestrictPostureCheckID} + + // validate for peer1 + isRouter, networkResourcesRoutes, sourcePeers := account.GetNetworkResourcesRoutesToSync(context.Background(), accNetResourcePeer1ID, account.GetResourcePoliciesMap(), account.GetResourceRoutersMap()) + assert.False(t, isRouter, "expected router status") + assert.Len(t, networkResourcesRoutes, 1, "expected network resource route don't match") + assert.Len(t, sourcePeers, 0, "expected source peers don't match") + + // validate for peer2 + isRouter, networkResourcesRoutes, sourcePeers = account.GetNetworkResourcesRoutesToSync(context.Background(), accNetResourcePeer2ID, account.GetResourcePoliciesMap(), account.GetResourceRoutersMap()) + assert.False(t, isRouter, "expected router status") + assert.Len(t, networkResourcesRoutes, 0, "expected network resource route don't match") + assert.Len(t, sourcePeers, 0, "expected source peers don't match") + + // validate routes for router1 + isRouter, networkResourcesRoutes, sourcePeers = account.GetNetworkResourcesRoutesToSync(context.Background(), accNetResourceRouter1ID, account.GetResourcePoliciesMap(), account.GetResourceRoutersMap()) + assert.True(t, isRouter, "should be router") + assert.Len(t, networkResourcesRoutes, 1, "expected network resource route don't match") + assert.Len(t, sourcePeers, 1, "expected source peers don't match") + assert.Equal(t, accNetResourcePeer1ID, sourcePeers[0], "expected source peers don't match") + + // validate rules for router1 + rules := account.GetPeerNetworkResourceFirewallRules(context.Background(), account.Peers[accNetResourceRouter1ID], accNetResourceValidPeers, networkResourcesRoutes, account.GetResourcePoliciesMap()) + assert.Len(t, rules, 1, "expected rules count don't match") + assert.Equal(t, uint16(80), rules[0].Port, "should have port 80") + assert.Equal(t, "tcp", rules[0].Protocol, "should have protocol tcp") + if !slices.Contains(rules[0].SourceRanges, accNetResourcePeer1IP.String()+"/32") { + t.Errorf("%s should have source range of peer1 %s", rules[0].SourceRanges, accNetResourcePeer1IP.String()) + } + if slices.Contains(rules[0].SourceRanges, accNetResourcePeer2IP.String()+"/32") { + t.Errorf("%s should not have source range of peer2 %s", rules[0].SourceRanges, accNetResourcePeer2IP.String()) + } +} + +func Test_NetworksNetMapGenWithNoMatchedPostureChecks(t *testing.T) { + account := getBasicAccountsWithResource() + + // should not match any peer + policy := account.Policies[0] + policy.SourcePostureChecks = []string{accNetResourceLockedPostureCheckID} + + // validate for peer1 + isRouter, networkResourcesRoutes, sourcePeers := account.GetNetworkResourcesRoutesToSync(context.Background(), accNetResourcePeer1ID, account.GetResourcePoliciesMap(), account.GetResourceRoutersMap()) + assert.False(t, isRouter, "expected router status") + assert.Len(t, networkResourcesRoutes, 0, "expected network resource route don't match") + assert.Len(t, sourcePeers, 0, "expected source peers don't match") + + // validate for peer2 + isRouter, networkResourcesRoutes, sourcePeers = account.GetNetworkResourcesRoutesToSync(context.Background(), accNetResourcePeer2ID, account.GetResourcePoliciesMap(), account.GetResourceRoutersMap()) + assert.False(t, 
isRouter, "expected router status") + assert.Len(t, networkResourcesRoutes, 0, "expected network resource route don't match") + assert.Len(t, sourcePeers, 0, "expected source peers don't match") + + // validate routes for router1 + isRouter, networkResourcesRoutes, sourcePeers = account.GetNetworkResourcesRoutesToSync(context.Background(), accNetResourceRouter1ID, account.GetResourcePoliciesMap(), account.GetResourceRoutersMap()) + assert.True(t, isRouter, "should be router") + assert.Len(t, networkResourcesRoutes, 1, "expected network resource route don't match") + assert.Len(t, sourcePeers, 0, "expected source peers don't match") + + // validate rules for router1 + rules := account.GetPeerNetworkResourceFirewallRules(context.Background(), account.Peers[accNetResourceRouter1ID], accNetResourceValidPeers, networkResourcesRoutes, account.GetResourcePoliciesMap()) + assert.Len(t, rules, 0, "expected rules count don't match") +} + +func Test_NetworksNetMapGenWithTwoPoliciesAndPostureChecks(t *testing.T) { + account := getBasicAccountsWithResource() + + // should allow peer1 to match the policy + policy := account.Policies[0] + policy.SourcePostureChecks = []string{accNetResourceRestrictPostureCheckID} + + // should allow peer1 and peer2 to match the policy + newPolicy := &Policy{ + ID: "policy2ID", + AccountID: accID, + Enabled: true, + Rules: []*PolicyRule{ + { + ID: "policy2ID", + Enabled: true, + Sources: []string{group1ID}, + DestinationResource: Resource{ + ID: accNetResource1ID, + Type: "Host", + }, + Protocol: PolicyRuleProtocolTCP, + Ports: []string{"22"}, + Action: PolicyTrafficActionAccept, + }, + }, + SourcePostureChecks: []string{accNetResourceRelaxedPostureCheckID}, + } + + account.Policies = append(account.Policies, newPolicy) + + // validate for peer1 + isRouter, networkResourcesRoutes, sourcePeers := account.GetNetworkResourcesRoutesToSync(context.Background(), accNetResourcePeer1ID, account.GetResourcePoliciesMap(), account.GetResourceRoutersMap()) + assert.False(t, isRouter, "expected router status") + assert.Len(t, networkResourcesRoutes, 1, "expected network resource route don't match") + assert.Len(t, sourcePeers, 0, "expected source peers don't match") + + // validate for peer2 + isRouter, networkResourcesRoutes, sourcePeers = account.GetNetworkResourcesRoutesToSync(context.Background(), accNetResourcePeer2ID, account.GetResourcePoliciesMap(), account.GetResourceRoutersMap()) + assert.False(t, isRouter, "expected router status") + assert.Len(t, networkResourcesRoutes, 1, "expected network resource route don't match") + assert.Len(t, sourcePeers, 0, "expected source peers don't match") + + // validate routes for router1 + isRouter, networkResourcesRoutes, sourcePeers = account.GetNetworkResourcesRoutesToSync(context.Background(), accNetResourceRouter1ID, account.GetResourcePoliciesMap(), account.GetResourceRoutersMap()) + assert.True(t, isRouter, "should be router") + assert.Len(t, networkResourcesRoutes, 1, "expected network resource route don't match") + assert.Len(t, sourcePeers, 2, "expected source peers don't match") + assert.Equal(t, accNetResourcePeer1ID, sourcePeers[0], "expected source peers don't match") + assert.Equal(t, accNetResourcePeer2ID, sourcePeers[1], "expected source peers don't match") + + // validate rules for router1 + rules := account.GetPeerNetworkResourceFirewallRules(context.Background(), account.Peers[accNetResourceRouter1ID], accNetResourceValidPeers, networkResourcesRoutes, account.GetResourcePoliciesMap()) + assert.Len(t, rules, 2, "expected 
rules count don't match") + assert.Equal(t, uint16(80), rules[0].Port, "should have port 80") + assert.Equal(t, "tcp", rules[0].Protocol, "should have protocol tcp") + if !slices.Contains(rules[0].SourceRanges, accNetResourcePeer1IP.String()+"/32") { + t.Errorf("%s should have source range of peer1 %s", rules[0].SourceRanges, accNetResourcePeer1IP.String()) + } + if slices.Contains(rules[0].SourceRanges, accNetResourcePeer2IP.String()+"/32") { + t.Errorf("%s should not have source range of peer2 %s", rules[0].SourceRanges, accNetResourcePeer2IP.String()) + } + + assert.Equal(t, uint16(22), rules[1].Port, "should have port 22") + assert.Equal(t, "tcp", rules[1].Protocol, "should have protocol tcp") + if !slices.Contains(rules[1].SourceRanges, accNetResourcePeer1IP.String()+"/32") { + t.Errorf("%s should have source range of peer1 %s", rules[1].SourceRanges, accNetResourcePeer1IP.String()) + } + if !slices.Contains(rules[1].SourceRanges, accNetResourcePeer2IP.String()+"/32") { + t.Errorf("%s should have source range of peer2 %s", rules[1].SourceRanges, accNetResourcePeer2IP.String()) + } +} + +func Test_NetworksNetMapGenWithTwoPostureChecks(t *testing.T) { + account := getBasicAccountsWithResource() + + // two posture checks should match only the peers that match both checks + policy := account.Policies[0] + policy.SourcePostureChecks = []string{accNetResourceRelaxedPostureCheckID, accNetResourceLinuxPostureCheckID} + + // validate for peer1 + isRouter, networkResourcesRoutes, sourcePeers := account.GetNetworkResourcesRoutesToSync(context.Background(), accNetResourcePeer1ID, account.GetResourcePoliciesMap(), account.GetResourceRoutersMap()) + assert.False(t, isRouter, "expected router status") + assert.Len(t, networkResourcesRoutes, 1, "expected network resource route don't match") + assert.Len(t, sourcePeers, 0, "expected source peers don't match") + + // validate for peer2 + isRouter, networkResourcesRoutes, sourcePeers = account.GetNetworkResourcesRoutesToSync(context.Background(), accNetResourcePeer2ID, account.GetResourcePoliciesMap(), account.GetResourceRoutersMap()) + assert.False(t, isRouter, "expected router status") + assert.Len(t, networkResourcesRoutes, 0, "expected network resource route don't match") + assert.Len(t, sourcePeers, 0, "expected source peers don't match") + + // validate routes for router1 + isRouter, networkResourcesRoutes, sourcePeers = account.GetNetworkResourcesRoutesToSync(context.Background(), accNetResourceRouter1ID, account.GetResourcePoliciesMap(), account.GetResourceRoutersMap()) + assert.True(t, isRouter, "should be router") + assert.Len(t, networkResourcesRoutes, 1, "expected network resource route don't match") + assert.Len(t, sourcePeers, 1, "expected source peers don't match") + assert.Equal(t, accNetResourcePeer1ID, sourcePeers[0], "expected source peers don't match") + + // validate rules for router1 + rules := account.GetPeerNetworkResourceFirewallRules(context.Background(), account.Peers[accNetResourceRouter1ID], accNetResourceValidPeers, networkResourcesRoutes, account.GetResourcePoliciesMap()) + assert.Len(t, rules, 1, "expected rules count don't match") + assert.Equal(t, uint16(80), rules[0].Port, "should have port 80") + assert.Equal(t, "tcp", rules[0].Protocol, "should have protocol tcp") + if !slices.Contains(rules[0].SourceRanges, accNetResourcePeer1IP.String()+"/32") { + t.Errorf("%s should have source range of peer1 %s", rules[0].SourceRanges, accNetResourcePeer1IP.String()) + } + if slices.Contains(rules[0].SourceRanges, 
accNetResourcePeer2IP.String()+"/32") { + t.Errorf("%s should not have source range of peer2 %s", rules[0].SourceRanges, accNetResourcePeer2IP.String()) + } +} diff --git a/management/server/types/policy.go b/management/server/types/policy.go index c2b82d68a..5b2cf06a0 100644 --- a/management/server/types/policy.go +++ b/management/server/types/policy.go @@ -1,5 +1,11 @@ package types +import ( + "slices" + + "github.com/yourbasic/radix" +) + const ( // PolicyTrafficActionAccept indicates that the traffic is accepted PolicyTrafficActionAccept = PolicyTrafficActionType("accept") @@ -117,9 +123,13 @@ func (p *Policy) RuleGroups() []string { // SourceGroups returns a slice of all unique source groups referenced in the policy's rules. func (p *Policy) SourceGroups() []string { + if len(p.Rules) == 1 { + return p.Rules[0].Sources + } groups := make([]string, 0) for _, rule := range p.Rules { groups = append(groups, rule.Sources...) } - return groups + radix.Sort(groups) + return slices.Compact(groups) } From 18316be09a1b8d1dfbaa942155e37c46f7375a24 Mon Sep 17 00:00:00 2001 From: Pascal Fischer <32096965+pascal-fischer@users.noreply.github.com> Date: Mon, 30 Dec 2024 12:53:51 +0100 Subject: [PATCH 099/159] [management] add selfhosted metrics for networks (#3118) --- .github/workflows/golangci-lint.yml | 2 +- management/server/metrics/selfhosted.go | 18 ++++++++++ management/server/metrics/selfhosted_test.go | 37 ++++++++++++++++++++ 3 files changed, 56 insertions(+), 1 deletion(-) diff --git a/.github/workflows/golangci-lint.yml b/.github/workflows/golangci-lint.yml index 89defce32..6705a34ec 100644 --- a/.github/workflows/golangci-lint.yml +++ b/.github/workflows/golangci-lint.yml @@ -19,7 +19,7 @@ jobs: - name: codespell uses: codespell-project/actions-codespell@v2 with: - ignore_words_list: erro,clienta,hastable,iif,groupd + ignore_words_list: erro,clienta,hastable,iif,groupd,testin skip: go.mod,go.sum only_warn: 1 golangci: diff --git a/management/server/metrics/selfhosted.go b/management/server/metrics/selfhosted.go index 82b34393f..03cb21af1 100644 --- a/management/server/metrics/selfhosted.go +++ b/management/server/metrics/selfhosted.go @@ -195,6 +195,10 @@ func (w *Worker) generateProperties(ctx context.Context) properties { groups int routes int routesWithRGGroups int + networks int + networkResources int + networkRouters int + networkRoutersWithPG int nameservers int uiClient int version string @@ -219,6 +223,16 @@ func (w *Worker) generateProperties(ctx context.Context) properties { } groups += len(account.Groups) + networks += len(account.Networks) + networkResources += len(account.NetworkResources) + + networkRouters += len(account.NetworkRouters) + for _, router := range account.NetworkRouters { + if len(router.PeerGroups) > 0 { + networkRoutersWithPG++ + } + } + routes += len(account.Routes) for _, route := range account.Routes { if len(route.PeerGroups) > 0 { @@ -312,6 +326,10 @@ func (w *Worker) generateProperties(ctx context.Context) properties { metricsProperties["rules_with_src_posture_checks"] = rulesWithSrcPostureChecks metricsProperties["posture_checks"] = postureChecks metricsProperties["groups"] = groups + metricsProperties["networks"] = networks + metricsProperties["network_resources"] = networkResources + metricsProperties["network_routers"] = networkRouters + metricsProperties["network_routers_with_groups"] = networkRoutersWithPG metricsProperties["routes"] = routes metricsProperties["routes_with_routing_groups"] = routesWithRGGroups metricsProperties["nameservers"] 
= nameservers diff --git a/management/server/metrics/selfhosted_test.go b/management/server/metrics/selfhosted_test.go index 1d356387f..4894c1ac4 100644 --- a/management/server/metrics/selfhosted_test.go +++ b/management/server/metrics/selfhosted_test.go @@ -5,6 +5,9 @@ import ( "testing" nbdns "github.com/netbirdio/netbird/dns" + resourceTypes "github.com/netbirdio/netbird/management/server/networks/resources/types" + routerTypes "github.com/netbirdio/netbird/management/server/networks/routers/types" + networkTypes "github.com/netbirdio/netbird/management/server/networks/types" nbpeer "github.com/netbirdio/netbird/management/server/peer" "github.com/netbirdio/netbird/management/server/posture" "github.com/netbirdio/netbird/management/server/store" @@ -172,6 +175,31 @@ func (mockDatasource) GetAllAccounts(_ context.Context) []*types.Account { }, }, }, + Networks: []*networkTypes.Network{ + { + ID: "1", + AccountID: "1", + }, + }, + NetworkResources: []*resourceTypes.NetworkResource{ + { + ID: "1", + AccountID: "1", + NetworkID: "1", + }, + { + ID: "2", + AccountID: "1", + NetworkID: "1", + }, + }, + NetworkRouters: []*routerTypes.NetworkRouter{ + { + ID: "1", + AccountID: "1", + NetworkID: "1", + }, + }, }, } } @@ -200,6 +228,15 @@ func TestGenerateProperties(t *testing.T) { if properties["routes"] != 2 { t.Errorf("expected 2 routes, got %d", properties["routes"]) } + if properties["networks"] != 1 { + t.Errorf("expected 1 networks, got %d", properties["networks"]) + } + if properties["network_resources"] != 2 { + t.Errorf("expected 2 network_resources, got %d", properties["network_resources"]) + } + if properties["network_routers"] != 1 { + t.Errorf("expected 1 network_routers, got %d", properties["network_routers"]) + } if properties["rules"] != 4 { t.Errorf("expected 4 rules, got %d", properties["rules"]) } From 43ef64cf673fc785a2005f0c7b3f5616211f4065 Mon Sep 17 00:00:00 2001 From: Viktor Liu <17948409+lixmal@users.noreply.github.com> Date: Tue, 31 Dec 2024 14:07:21 +0100 Subject: [PATCH 100/159] [client] Ignore case when matching domains in handler chain (#3133) --- client/internal/dns/handler_chain.go | 21 ++- client/internal/dns/handler_chain_test.go | 168 ++++++++++++++++++++++ 2 files changed, 178 insertions(+), 11 deletions(-) diff --git a/client/internal/dns/handler_chain.go b/client/internal/dns/handler_chain.go index 9302d50b1..5f63d1ab3 100644 --- a/client/internal/dns/handler_chain.go +++ b/client/internal/dns/handler_chain.go @@ -68,17 +68,16 @@ func (c *HandlerChain) AddHandler(pattern string, handler dns.Handler, priority c.mu.Lock() defer c.mu.Unlock() + pattern = strings.ToLower(dns.Fqdn(pattern)) origPattern := pattern isWildcard := strings.HasPrefix(pattern, "*.") if isWildcard { pattern = pattern[2:] } - pattern = dns.Fqdn(pattern) - origPattern = dns.Fqdn(origPattern) - // First remove any existing handler with same original pattern and priority + // First remove any existing handler with same pattern (case-insensitive) and priority for i := len(c.handlers) - 1; i >= 0; i-- { - if c.handlers[i].OrigPattern == origPattern && c.handlers[i].Priority == priority { + if strings.EqualFold(c.handlers[i].OrigPattern, origPattern) && c.handlers[i].Priority == priority { if c.handlers[i].StopHandler != nil { c.handlers[i].StopHandler.stop() } @@ -126,10 +125,10 @@ func (c *HandlerChain) RemoveHandler(pattern string, priority int) { pattern = dns.Fqdn(pattern) - // Find and remove handlers matching both original pattern and priority + // Find and remove handlers matching both 
original pattern (case-insensitive) and priority for i := len(c.handlers) - 1; i >= 0; i-- { entry := c.handlers[i] - if entry.OrigPattern == pattern && entry.Priority == priority { + if strings.EqualFold(entry.OrigPattern, pattern) && entry.Priority == priority { if entry.StopHandler != nil { entry.StopHandler.stop() } @@ -144,9 +143,9 @@ func (c *HandlerChain) HasHandlers(pattern string) bool { c.mu.RLock() defer c.mu.RUnlock() - pattern = dns.Fqdn(pattern) + pattern = strings.ToLower(dns.Fqdn(pattern)) for _, entry := range c.handlers { - if entry.Pattern == pattern { + if strings.EqualFold(entry.Pattern, pattern) { return true } } @@ -158,7 +157,7 @@ func (c *HandlerChain) ServeDNS(w dns.ResponseWriter, r *dns.Msg) { return } - qname := r.Question[0].Name + qname := strings.ToLower(r.Question[0].Name) log.Tracef("handling DNS request for domain=%s", qname) c.mu.RLock() @@ -187,9 +186,9 @@ func (c *HandlerChain) ServeDNS(w dns.ResponseWriter, r *dns.Msg) { // If handler wants subdomain matching, allow suffix match // Otherwise require exact match if entry.MatchSubdomains { - matched = qname == entry.Pattern || strings.HasSuffix(qname, "."+entry.Pattern) + matched = strings.EqualFold(qname, entry.Pattern) || strings.HasSuffix(qname, "."+entry.Pattern) } else { - matched = qname == entry.Pattern + matched = strings.EqualFold(qname, entry.Pattern) } } diff --git a/client/internal/dns/handler_chain_test.go b/client/internal/dns/handler_chain_test.go index 727b6e908..eb40c907f 100644 --- a/client/internal/dns/handler_chain_test.go +++ b/client/internal/dns/handler_chain_test.go @@ -507,5 +507,173 @@ func TestHandlerChain_MultiPriorityHandling(t *testing.T) { // Test 4: Remove last handler chain.RemoveHandler(testDomain, nbdns.PriorityDefault) + assert.False(t, chain.HasHandlers(testDomain)) } + +func TestHandlerChain_CaseSensitivity(t *testing.T) { + tests := []struct { + name string + scenario string + addHandlers []struct { + pattern string + priority int + subdomains bool + shouldMatch bool + } + query string + expectedCalls int + }{ + { + name: "case insensitive exact match", + scenario: "handler registered lowercase, query uppercase", + addHandlers: []struct { + pattern string + priority int + subdomains bool + shouldMatch bool + }{ + {"example.com.", nbdns.PriorityDefault, false, true}, + }, + query: "EXAMPLE.COM.", + expectedCalls: 1, + }, + { + name: "case insensitive wildcard match", + scenario: "handler registered mixed case wildcard, query different case", + addHandlers: []struct { + pattern string + priority int + subdomains bool + shouldMatch bool + }{ + {"*.Example.Com.", nbdns.PriorityDefault, false, true}, + }, + query: "sub.EXAMPLE.COM.", + expectedCalls: 1, + }, + { + name: "multiple handlers different case same domain", + scenario: "second handler should replace first despite case difference", + addHandlers: []struct { + pattern string + priority int + subdomains bool + shouldMatch bool + }{ + {"EXAMPLE.COM.", nbdns.PriorityDefault, false, false}, + {"example.com.", nbdns.PriorityDefault, false, true}, + }, + query: "ExAmPlE.cOm.", + expectedCalls: 1, + }, + { + name: "subdomain matching case insensitive", + scenario: "handler with MatchSubdomains true should match regardless of case", + addHandlers: []struct { + pattern string + priority int + subdomains bool + shouldMatch bool + }{ + {"example.com.", nbdns.PriorityDefault, true, true}, + }, + query: "SUB.EXAMPLE.COM.", + expectedCalls: 1, + }, + { + name: "root zone case insensitive", + scenario: "root zone handler 
should match regardless of case", + addHandlers: []struct { + pattern string + priority int + subdomains bool + shouldMatch bool + }{ + {".", nbdns.PriorityDefault, false, true}, + }, + query: "EXAMPLE.COM.", + expectedCalls: 1, + }, + { + name: "multiple handlers different priority", + scenario: "should call higher priority handler despite case differences", + addHandlers: []struct { + pattern string + priority int + subdomains bool + shouldMatch bool + }{ + {"EXAMPLE.COM.", nbdns.PriorityDefault, false, false}, + {"example.com.", nbdns.PriorityMatchDomain, false, false}, + {"Example.Com.", nbdns.PriorityDNSRoute, false, true}, + }, + query: "example.com.", + expectedCalls: 1, + }, + } + + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + chain := nbdns.NewHandlerChain() + handlerCalls := make(map[string]bool) // track which patterns were called + + // Add handlers according to test case + for _, h := range tt.addHandlers { + var handler dns.Handler + pattern := h.pattern // capture pattern for closure + + if h.subdomains { + subHandler := &nbdns.MockSubdomainHandler{ + Subdomains: true, + } + if h.shouldMatch { + subHandler.On("ServeDNS", mock.Anything, mock.Anything).Run(func(args mock.Arguments) { + handlerCalls[pattern] = true + w := args.Get(0).(dns.ResponseWriter) + r := args.Get(1).(*dns.Msg) + resp := new(dns.Msg) + resp.SetRcode(r, dns.RcodeSuccess) + assert.NoError(t, w.WriteMsg(resp)) + }).Once() + } + handler = subHandler + } else { + mockHandler := &nbdns.MockHandler{} + if h.shouldMatch { + mockHandler.On("ServeDNS", mock.Anything, mock.Anything).Run(func(args mock.Arguments) { + handlerCalls[pattern] = true + w := args.Get(0).(dns.ResponseWriter) + r := args.Get(1).(*dns.Msg) + resp := new(dns.Msg) + resp.SetRcode(r, dns.RcodeSuccess) + assert.NoError(t, w.WriteMsg(resp)) + }).Once() + } + handler = mockHandler + } + + chain.AddHandler(pattern, handler, h.priority, nil) + } + + // Execute request + r := new(dns.Msg) + r.SetQuestion(tt.query, dns.TypeA) + chain.ServeDNS(&mockResponseWriter{}, r) + + // Verify each handler was called exactly as expected + for _, h := range tt.addHandlers { + wasCalled := handlerCalls[h.pattern] + assert.Equal(t, h.shouldMatch, wasCalled, + "Handler for pattern %q was %s when it should%s have been", + h.pattern, + map[bool]string{true: "called", false: "not called"}[wasCalled], + map[bool]string{true: "", false: " not"}[wasCalled == h.shouldMatch]) + } + + // Verify total number of calls + assert.Equal(t, tt.expectedCalls, len(handlerCalls), + "Wrong number of total handler calls") + }) + } +} From abbdf20f65f031fd6dc0a95b55a38044b158046b Mon Sep 17 00:00:00 2001 From: Viktor Liu <17948409+lixmal@users.noreply.github.com> Date: Tue, 31 Dec 2024 14:08:48 +0100 Subject: [PATCH 101/159] [client] Allow inbound rosenpass port (#3109) --- client/firewall/iptables/manager_linux.go | 2 +- client/internal/dnsfwd/manager.go | 2 +- client/internal/engine.go | 45 +++++++++++++++++++---- 3 files changed, 40 insertions(+), 9 deletions(-) diff --git a/client/firewall/iptables/manager_linux.go b/client/firewall/iptables/manager_linux.go index 0e1e5836f..da8e2c08f 100644 --- a/client/firewall/iptables/manager_linux.go +++ b/client/firewall/iptables/manager_linux.go @@ -197,7 +197,7 @@ func (m *Manager) AllowNetbird() error { } _, err := m.AddPeerFiltering( - net.ParseIP("0.0.0.0"), + net.IP{0, 0, 0, 0}, "all", nil, nil, diff --git a/client/internal/dnsfwd/manager.go b/client/internal/dnsfwd/manager.go index 7cff6d517..f876bda30 100644 --- 
a/client/internal/dnsfwd/manager.go +++ b/client/internal/dnsfwd/manager.go @@ -83,7 +83,7 @@ func (h *Manager) allowDNSFirewall() error { IsRange: false, Values: []int{ListenPort}, } - dnsRules, err := h.firewall.AddPeerFiltering(net.ParseIP("0.0.0.0"), firewall.ProtocolUDP, nil, dport, firewall.RuleDirectionIN, firewall.ActionAccept, "", "") + dnsRules, err := h.firewall.AddPeerFiltering(net.IP{0, 0, 0, 0}, firewall.ProtocolUDP, nil, dport, firewall.RuleDirectionIN, firewall.ActionAccept, "", "") if err != nil { log.Errorf("failed to add allow DNS router rules, err: %v", err) return err diff --git a/client/internal/engine.go b/client/internal/engine.go index 042d384dc..896104df8 100644 --- a/client/internal/engine.go +++ b/client/internal/engine.go @@ -406,13 +406,9 @@ func (e *Engine) Start() error { e.firewall, err = firewall.NewFirewall(e.wgInterface, e.stateManager) if err != nil { log.Errorf("failed creating firewall manager: %s", err) - } - - if e.firewall != nil && e.firewall.IsServerRouteSupported() { - err = e.routeManager.EnableServerRouter(e.firewall) - if err != nil { - e.close() - return fmt.Errorf("enable server router: %w", err) + } else if e.firewall != nil { + if err := e.initFirewall(err); err != nil { + return err } } @@ -455,6 +451,41 @@ func (e *Engine) Start() error { return nil } +func (e *Engine) initFirewall(error) error { + if e.firewall.IsServerRouteSupported() { + if err := e.routeManager.EnableServerRouter(e.firewall); err != nil { + e.close() + return fmt.Errorf("enable server router: %w", err) + } + } + + if e.rpManager == nil || !e.config.RosenpassEnabled { + return nil + } + + rosenpassPort := e.rpManager.GetAddress().Port + port := manager.Port{Values: []int{rosenpassPort}} + + // this rule is static and will be torn down on engine down by the firewall manager + if _, err := e.firewall.AddPeerFiltering( + net.IP{0, 0, 0, 0}, + manager.ProtocolUDP, + nil, + &port, + manager.RuleDirectionIN, + manager.ActionAccept, + "", + "", + ); err != nil { + log.Errorf("failed to allow rosenpass interface traffic: %v", err) + return nil + } + + log.Infof("rosenpass interface traffic allowed on port %d", rosenpassPort) + + return nil +} + // modifyPeers updates peers that have been modified (e.g. IP address has been changed). // It closes the existing connection, removes it from the peerConns map, and creates a new one. func (e *Engine) modifyPeers(peersUpdate []*mgmProto.RemotePeerConfig) error { From 2bdb4cb44a8c128214eac0933d040f2e8447a92a Mon Sep 17 00:00:00 2001 From: Bethuel Mmbaga Date: Tue, 31 Dec 2024 18:59:37 +0300 Subject: [PATCH 102/159] [management] Preserve jwt groups when accessing API with PAT (#3128) * Skip JWT group sync for token-based authentication Signed-off-by: bcmmbaga * Add tests Signed-off-by: bcmmbaga --------- Signed-off-by: bcmmbaga --- management/server/account.go | 6 ++++ management/server/account_test.go | 30 +++++++++++++++++-- .../server/http/middleware/auth_middleware.go | 1 + management/server/jwtclaims/extractor.go | 2 ++ 4 files changed, 37 insertions(+), 2 deletions(-) diff --git a/management/server/account.go b/management/server/account.go index e60b41b4e..83a8759f9 100644 --- a/management/server/account.go +++ b/management/server/account.go @@ -1252,6 +1252,12 @@ func (am *DefaultAccountManager) GetAccountIDFromToken(ctx context.Context, clai // syncJWTGroups processes the JWT groups for a user, updates the account based on the groups, // and propagates changes to peers if group propagation is enabled. 
func (am *DefaultAccountManager) syncJWTGroups(ctx context.Context, accountID string, claims jwtclaims.AuthorizationClaims) error { + if claim, exists := claims.Raw[jwtclaims.IsToken]; exists { + if isToken, ok := claim.(bool); ok && isToken { + return nil + } + } + settings, err := am.Store.GetAccountSettings(ctx, store.LockingStrengthShare, accountID) if err != nil { return err diff --git a/management/server/account_test.go b/management/server/account_test.go index 280d998fd..2289c96f9 100644 --- a/management/server/account_test.go +++ b/management/server/account_test.go @@ -2729,6 +2729,19 @@ func TestAccount_SetJWTGroups(t *testing.T) { assert.NoError(t, manager.Store.SaveAccount(context.Background(), account), "unable to save account") + t.Run("skip sync for token auth type", func(t *testing.T) { + claims := jwtclaims.AuthorizationClaims{ + UserId: "user1", + Raw: jwt.MapClaims{"groups": []interface{}{"group3"}, "is_token": true}, + } + err = manager.syncJWTGroups(context.Background(), "accountID", claims) + assert.NoError(t, err, "unable to sync jwt groups") + + user, err := manager.Store.GetUserByUserID(context.Background(), store.LockingStrengthShare, "user1") + assert.NoError(t, err, "unable to get user") + assert.Len(t, user.AutoGroups, 0, "JWT groups should not be synced") + }) + t.Run("empty jwt groups", func(t *testing.T) { claims := jwtclaims.AuthorizationClaims{ UserId: "user1", @@ -2822,7 +2835,7 @@ func TestAccount_SetJWTGroups(t *testing.T) { assert.Len(t, user.AutoGroups, 1, "new group should be added") }) - t.Run("remove all JWT groups", func(t *testing.T) { + t.Run("remove all JWT groups when list is empty", func(t *testing.T) { claims := jwtclaims.AuthorizationClaims{ UserId: "user1", Raw: jwt.MapClaims{"groups": []interface{}{}}, @@ -2833,7 +2846,20 @@ func TestAccount_SetJWTGroups(t *testing.T) { user, err := manager.Store.GetUserByUserID(context.Background(), store.LockingStrengthShare, "user1") assert.NoError(t, err, "unable to get user") assert.Len(t, user.AutoGroups, 1, "only non-JWT groups should remain") - assert.Contains(t, user.AutoGroups, "group1", " group1 should still be present") + assert.Contains(t, user.AutoGroups, "group1", "group1 should still be present") + }) + + t.Run("remove all JWT groups when claim does not exist", func(t *testing.T) { + claims := jwtclaims.AuthorizationClaims{ + UserId: "user2", + Raw: jwt.MapClaims{}, + } + err = manager.syncJWTGroups(context.Background(), "accountID", claims) + assert.NoError(t, err, "unable to sync jwt groups") + + user, err := manager.Store.GetUserByUserID(context.Background(), store.LockingStrengthShare, "user2") + assert.NoError(t, err, "unable to get user") + assert.Len(t, user.AutoGroups, 0, "all JWT groups should be removed") }) } diff --git a/management/server/http/middleware/auth_middleware.go b/management/server/http/middleware/auth_middleware.go index 0d3459712..0a54cbaed 100644 --- a/management/server/http/middleware/auth_middleware.go +++ b/management/server/http/middleware/auth_middleware.go @@ -175,6 +175,7 @@ func (m *AuthMiddleware) checkPATFromRequest(w http.ResponseWriter, r *http.Requ claimMaps[m.audience+jwtclaims.AccountIDSuffix] = account.Id claimMaps[m.audience+jwtclaims.DomainIDSuffix] = account.Domain claimMaps[m.audience+jwtclaims.DomainCategorySuffix] = account.DomainCategory + claimMaps[jwtclaims.IsToken] = true jwtToken := jwt.NewWithClaims(jwt.SigningMethodHS256, claimMaps) newRequest := r.WithContext(context.WithValue(r.Context(), jwtclaims.TokenUserProperty, jwtToken)) 
//nolint // Update the current request with the new context information. diff --git a/management/server/jwtclaims/extractor.go b/management/server/jwtclaims/extractor.go index c441650e9..18214b434 100644 --- a/management/server/jwtclaims/extractor.go +++ b/management/server/jwtclaims/extractor.go @@ -22,6 +22,8 @@ const ( LastLoginSuffix = "nb_last_login" // Invited claim indicates that an incoming JWT is from a user that just accepted an invitation Invited = "nb_invited" + // IsToken claim indicates that auth type from the user is a token + IsToken = "is_token" ) // ExtractClaims Extract function type From 18b049cd2439de3c7769e7666e86ed16254df477 Mon Sep 17 00:00:00 2001 From: Pascal Fischer <32096965+pascal-fischer@users.noreply.github.com> Date: Tue, 31 Dec 2024 18:10:40 +0100 Subject: [PATCH 103/159] [management] remove sorting from network map generation (#3126) --- go.mod | 1 - go.sum | 2 - .../networks/resources/types/resource.go | 1 + management/server/types/account.go | 110 ++++++++---------- management/server/types/account_test.go | 30 ++--- management/server/types/policy.go | 21 ++-- route/route.go | 3 + 7 files changed, 79 insertions(+), 89 deletions(-) diff --git a/go.mod b/go.mod index 330d0763f..d48280df0 100644 --- a/go.mod +++ b/go.mod @@ -79,7 +79,6 @@ require ( github.com/testcontainers/testcontainers-go v0.31.0 github.com/testcontainers/testcontainers-go/modules/postgres v0.31.0 github.com/things-go/go-socks5 v0.0.4 - github.com/yourbasic/radix v0.0.0-20180308122924-cbe1cc82e907 github.com/yusufpapurcu/wmi v1.2.4 github.com/zcalusic/sysinfo v1.1.3 go.opentelemetry.io/contrib/instrumentation/google.golang.org/grpc/otelgrpc v0.49.0 diff --git a/go.sum b/go.sum index ea4597836..540cbf20b 100644 --- a/go.sum +++ b/go.sum @@ -698,8 +698,6 @@ github.com/vishvananda/netlink v1.2.1-beta.2/go.mod h1:twkDnbuQxJYemMlGd4JFIcuhg github.com/vishvananda/netns v0.0.0-20200728191858-db3c7e526aae/go.mod h1:DD4vA1DwXk04H54A1oHXtwZmA0grkVMdPxx/VGLCah0= github.com/vishvananda/netns v0.0.4 h1:Oeaw1EM2JMxD51g9uhtC0D7erkIjgmj8+JZc26m1YX8= github.com/vishvananda/netns v0.0.4/go.mod h1:SpkAiCQRtJ6TvvxPnOSyH3BMl6unz3xZlaprSwhNNJM= -github.com/yourbasic/radix v0.0.0-20180308122924-cbe1cc82e907 h1:S5h7yNKStqF8CqFtgtMNMzk/lUI3p82LrX6h2BhlsTM= -github.com/yourbasic/radix v0.0.0-20180308122924-cbe1cc82e907/go.mod h1:/7Fy/4/OyrkguTf2i2pO4erUD/8QAlrlmXSdSJPu678= github.com/yuin/goldmark v1.1.25/go.mod h1:3hX8gzYuyVAZsxl0MRgGTJEmQBFcNTphYh9decYSb74= github.com/yuin/goldmark v1.1.27/go.mod h1:3hX8gzYuyVAZsxl0MRgGTJEmQBFcNTphYh9decYSb74= github.com/yuin/goldmark v1.1.32/go.mod h1:3hX8gzYuyVAZsxl0MRgGTJEmQBFcNTphYh9decYSb74= diff --git a/management/server/networks/resources/types/resource.go b/management/server/networks/resources/types/resource.go index 7eecdce0f..162f90378 100644 --- a/management/server/networks/resources/types/resource.go +++ b/management/server/networks/resources/types/resource.go @@ -111,6 +111,7 @@ func (n *NetworkResource) ToRoute(peer *nbpeer.Peer, router *routerTypes.Network NetID: route.NetID(n.Name), Description: n.Description, Peer: peer.Key, + PeerID: peer.ID, PeerGroups: nil, Masquerade: router.Masquerade, Metric: router.Metric, diff --git a/management/server/types/account.go b/management/server/types/account.go index 3ef862fa6..f9e1cc9b4 100644 --- a/management/server/types/account.go +++ b/management/server/types/account.go @@ -13,7 +13,6 @@ import ( "github.com/hashicorp/go-multierror" "github.com/miekg/dns" log "github.com/sirupsen/logrus" - "github.com/yourbasic/radix" 
nbdns "github.com/netbirdio/netbird/dns" "github.com/netbirdio/netbird/management/domain" @@ -304,55 +303,47 @@ func (a *Account) GetPeerNetworkMap( return nm } -func (a *Account) addNetworksRoutingPeers(networkResourcesRoutes []*route.Route, peer *nbpeer.Peer, peersToConnect []*nbpeer.Peer, expiredPeers []*nbpeer.Peer, isRouter bool, sourcePeers []string) []*nbpeer.Peer { - missingPeers := map[string]struct{}{} - for _, r := range networkResourcesRoutes { - if r.Peer == peer.Key { - continue - } +func (a *Account) addNetworksRoutingPeers( + networkResourcesRoutes []*route.Route, + peer *nbpeer.Peer, + peersToConnect []*nbpeer.Peer, + expiredPeers []*nbpeer.Peer, + isRouter bool, + sourcePeers map[string]struct{}, +) []*nbpeer.Peer { - missing := true - for _, p := range slices.Concat(peersToConnect, expiredPeers) { - if r.Peer == p.Key { - missing = false - break - } - } - if missing { - missingPeers[r.Peer] = struct{}{} - } + networkRoutesPeers := make(map[string]struct{}, len(networkResourcesRoutes)) + for _, r := range networkResourcesRoutes { + networkRoutesPeers[r.PeerID] = struct{}{} } - if isRouter { - for _, s := range sourcePeers { - if s == peer.ID { - continue - } + delete(sourcePeers, peer.ID) - missing := true - for _, p := range slices.Concat(peersToConnect, expiredPeers) { - if s == p.ID { - missing = false - break - } - } - if missing { - p, ok := a.Peers[s] - if ok { - missingPeers[p.Key] = struct{}{} - } - } + for _, existingPeer := range peersToConnect { + delete(sourcePeers, existingPeer.ID) + delete(networkRoutesPeers, existingPeer.ID) + } + for _, expPeer := range expiredPeers { + delete(sourcePeers, expPeer.ID) + delete(networkRoutesPeers, expPeer.ID) + } + + missingPeers := make(map[string]struct{}, len(sourcePeers)+len(networkRoutesPeers)) + if isRouter { + for p := range sourcePeers { + missingPeers[p] = struct{}{} } } + for p := range networkRoutesPeers { + missingPeers[p] = struct{}{} + } for p := range missingPeers { - for _, p2 := range a.Peers { - if p2.Key == p { - peersToConnect = append(peersToConnect, p2) - break - } + if missingPeer := a.Peers[p]; missingPeer != nil { + peersToConnect = append(peersToConnect, missingPeer) } } + return peersToConnect } @@ -1156,7 +1147,7 @@ func (a *Account) getRouteFirewallRules(ctx context.Context, peerID string, poli } func (a *Account) getRulePeers(rule *PolicyRule, postureChecks []string, peerID string, distributionPeers map[string]struct{}, validatedPeersMap map[string]struct{}) []*nbpeer.Peer { - distPeersWithPolicy := make([]string, 0) + distPeersWithPolicy := make(map[string]struct{}) for _, id := range rule.Sources { group := a.Groups[id] if group == nil { @@ -1170,16 +1161,13 @@ func (a *Account) getRulePeers(rule *PolicyRule, postureChecks []string, peerID _, distPeer := distributionPeers[pID] _, valid := validatedPeersMap[pID] if distPeer && valid && a.validatePostureChecksOnPeer(context.Background(), postureChecks, pID) { - distPeersWithPolicy = append(distPeersWithPolicy, pID) + distPeersWithPolicy[pID] = struct{}{} } } } - radix.Sort(distPeersWithPolicy) - uniqueDistributionPeers := slices.Compact(distPeersWithPolicy) - - distributionGroupPeers := make([]*nbpeer.Peer, 0, len(uniqueDistributionPeers)) - for _, pID := range uniqueDistributionPeers { + distributionGroupPeers := make([]*nbpeer.Peer, 0, len(distPeersWithPolicy)) + for pID := range distPeersWithPolicy { peer := a.Peers[pID] if peer == nil { continue @@ -1306,10 +1294,10 @@ func (a *Account) GetResourcePoliciesMap() map[string][]*Policy { } // 
GetNetworkResourcesRoutesToSync returns network routes for syncing with a specific peer and its ACL peers. -func (a *Account) GetNetworkResourcesRoutesToSync(ctx context.Context, peerID string, resourcePolicies map[string][]*Policy, routers map[string]map[string]*routerTypes.NetworkRouter) (bool, []*route.Route, []string) { +func (a *Account) GetNetworkResourcesRoutesToSync(ctx context.Context, peerID string, resourcePolicies map[string][]*Policy, routers map[string]map[string]*routerTypes.NetworkRouter) (bool, []*route.Route, map[string]struct{}) { var isRoutingPeer bool var routes []*route.Route - allSourcePeers := make([]string, 0) + allSourcePeers := make(map[string]struct{}, len(a.Peers)) for _, resource := range a.NetworkResources { var addSourcePeers bool @@ -1326,7 +1314,9 @@ func (a *Account) GetNetworkResourcesRoutesToSync(ctx context.Context, peerID st for _, policy := range resourcePolicies[resource.ID] { peers := a.getUniquePeerIDsFromGroupsIDs(ctx, policy.SourceGroups()) if addSourcePeers { - allSourcePeers = append(allSourcePeers, a.getPostureValidPeers(peers, policy.SourcePostureChecks)...) + for _, pID := range a.getPostureValidPeers(peers, policy.SourcePostureChecks) { + allSourcePeers[pID] = struct{}{} + } } else if slices.Contains(peers, peerID) && a.validatePostureChecksOnPeer(ctx, policy.SourcePostureChecks, peerID) { // add routes for the resource if the peer is in the distribution group for peerId, router := range networkRoutingPeers { @@ -1340,8 +1330,7 @@ func (a *Account) GetNetworkResourcesRoutesToSync(ctx context.Context, peerID st } } - radix.Sort(allSourcePeers) - return isRoutingPeer, routes, slices.Compact(allSourcePeers) + return isRoutingPeer, routes, allSourcePeers } func (a *Account) getPostureValidPeers(inputPeers []string, postureChecksIDs []string) []string { @@ -1355,8 +1344,7 @@ func (a *Account) getPostureValidPeers(inputPeers []string, postureChecksIDs []s } func (a *Account) getUniquePeerIDsFromGroupsIDs(ctx context.Context, groups []string) []string { - gObjs := make([]*Group, 0, len(groups)) - tp := 0 + peerIDs := make(map[string]struct{}, len(groups)) // we expect at least one peer per group as initial capacity for _, groupID := range groups { group := a.GetGroup(groupID) if group == nil { @@ -1368,17 +1356,17 @@ func (a *Account) getUniquePeerIDsFromGroupsIDs(ctx context.Context, groups []st return group.Peers } - gObjs = append(gObjs, group) - tp += len(group.Peers) + for _, peerID := range group.Peers { + peerIDs[peerID] = struct{}{} + } } - ids := make([]string, 0, tp) - for _, group := range gObjs { - ids = append(ids, group.Peers...) + ids := make([]string, 0, len(peerIDs)) + for peerID := range peerIDs { + ids = append(ids, peerID) } - radix.Sort(ids) - return slices.Compact(ids) + return ids } // getNetworkResources filters and returns a list of network resources associated with the given network ID. 
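The account.go hunks above drop the radix.Sort/slices.Compact pipeline in favor of map[string]struct{} sets: IDs are collected as map keys, which de-duplicates them for free, and where a slice is still needed the keys are copied out without sorting. Below is a minimal, self-contained sketch of that pattern, assuming a simplified group type and made-up peer IDs rather than the real NetBird structs; it is an illustration of the technique, not code from the patch.

package main

import "fmt"

// group is a simplified stand-in for the management server's Group type
// (illustrative only; the real struct has different fields).
type group struct {
	id    string
	peers []string
}

// uniquePeerIDs gathers peer IDs from all groups into a set and returns
// them as an unordered slice, mirroring the de-duplication that replaced
// radix.Sort followed by slices.Compact.
func uniquePeerIDs(groups []group) []string {
	set := make(map[string]struct{})
	for _, g := range groups {
		for _, peerID := range g.peers {
			set[peerID] = struct{}{} // duplicate IDs collapse into one key
		}
	}

	ids := make([]string, 0, len(set))
	for id := range set {
		ids = append(ids, id) // map iteration order is unspecified
	}
	return ids
}

func main() {
	groups := []group{
		{id: "g1", peers: []string{"peerA", "peerB"}},
		{id: "g2", peers: []string{"peerB", "peerC"}},
	}
	fmt.Println(len(uniquePeerIDs(groups))) // prints 3: peerB is counted once
}

Because the resulting order is unspecified, the test changes in the following hunks replace index-based assertions on the source peers with membership-style checks on the returned map.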
diff --git a/management/server/types/account_test.go b/management/server/types/account_test.go index efe930108..367baef4f 100644 --- a/management/server/types/account_test.go +++ b/management/server/types/account_test.go @@ -316,19 +316,19 @@ func Test_GetResourcePoliciesMap(t *testing.T) { func Test_AddNetworksRoutingPeersAddsMissingPeers(t *testing.T) { account := setupTestAccount() - peer := &nbpeer.Peer{Key: "peer1"} + peer := &nbpeer.Peer{Key: "peer1Key", ID: "peer1"} networkResourcesRoutes := []*route.Route{ - {Peer: "peer2Key"}, - {Peer: "peer3Key"}, + {Peer: "peer2Key", PeerID: "peer2"}, + {Peer: "peer3Key", PeerID: "peer3"}, } peersToConnect := []*nbpeer.Peer{ - {Key: "peer2Key"}, + {Key: "peer2Key", ID: "peer2"}, } expiredPeers := []*nbpeer.Peer{ - {Key: "peer4Key"}, + {Key: "peer4Key", ID: "peer4"}, } - result := account.addNetworksRoutingPeers(networkResourcesRoutes, peer, peersToConnect, expiredPeers, false, []string{}) + result := account.addNetworksRoutingPeers(networkResourcesRoutes, peer, peersToConnect, expiredPeers, false, map[string]struct{}{}) require.Len(t, result, 2) require.Equal(t, "peer2Key", result[0].Key) require.Equal(t, "peer3Key", result[1].Key) @@ -345,7 +345,7 @@ func Test_AddNetworksRoutingPeersIgnoresExistingPeers(t *testing.T) { } expiredPeers := []*nbpeer.Peer{} - result := account.addNetworksRoutingPeers(networkResourcesRoutes, peer, peersToConnect, expiredPeers, false, []string{}) + result := account.addNetworksRoutingPeers(networkResourcesRoutes, peer, peersToConnect, expiredPeers, false, map[string]struct{}{}) require.Len(t, result, 1) require.Equal(t, "peer2Key", result[0].Key) } @@ -364,7 +364,7 @@ func Test_AddNetworksRoutingPeersAddsExpiredPeers(t *testing.T) { {Key: "peer3Key"}, } - result := account.addNetworksRoutingPeers(networkResourcesRoutes, peer, peersToConnect, expiredPeers, false, []string{}) + result := account.addNetworksRoutingPeers(networkResourcesRoutes, peer, peersToConnect, expiredPeers, false, map[string]struct{}{}) require.Len(t, result, 1) require.Equal(t, "peer2Key", result[0].Key) } @@ -376,7 +376,7 @@ func Test_AddNetworksRoutingPeersHandlesNoMissingPeers(t *testing.T) { peersToConnect := []*nbpeer.Peer{} expiredPeers := []*nbpeer.Peer{} - result := account.addNetworksRoutingPeers(networkResourcesRoutes, peer, peersToConnect, expiredPeers, false, []string{}) + result := account.addNetworksRoutingPeers(networkResourcesRoutes, peer, peersToConnect, expiredPeers, false, map[string]struct{}{}) require.Len(t, result, 0) } @@ -559,8 +559,8 @@ func Test_NetworksNetMapGenWithNoPostureChecks(t *testing.T) { assert.True(t, isRouter, "should be router") assert.Len(t, networkResourcesRoutes, 1, "expected network resource route don't match") assert.Len(t, sourcePeers, 2, "expected source peers don't match") - assert.Equal(t, accNetResourcePeer1ID, sourcePeers[0], "expected source peers don't match") - assert.Equal(t, accNetResourcePeer2ID, sourcePeers[1], "expected source peers don't match") + assert.NotNil(t, sourcePeers[accNetResourcePeer1ID], "expected source peers don't match") + assert.NotNil(t, sourcePeers[accNetResourcePeer2ID], "expected source peers don't match") // validate rules for router1 rules := account.GetPeerNetworkResourceFirewallRules(context.Background(), account.Peers[accNetResourceRouter1ID], accNetResourceValidPeers, networkResourcesRoutes, account.GetResourcePoliciesMap()) @@ -599,7 +599,7 @@ func Test_NetworksNetMapGenWithPostureChecks(t *testing.T) { assert.True(t, isRouter, "should be router") assert.Len(t, 
networkResourcesRoutes, 1, "expected network resource route don't match") assert.Len(t, sourcePeers, 1, "expected source peers don't match") - assert.Equal(t, accNetResourcePeer1ID, sourcePeers[0], "expected source peers don't match") + assert.NotNil(t, sourcePeers[accNetResourcePeer1ID], "expected source peers don't match") // validate rules for router1 rules := account.GetPeerNetworkResourceFirewallRules(context.Background(), account.Peers[accNetResourceRouter1ID], accNetResourceValidPeers, networkResourcesRoutes, account.GetResourcePoliciesMap()) @@ -692,8 +692,8 @@ func Test_NetworksNetMapGenWithTwoPoliciesAndPostureChecks(t *testing.T) { assert.True(t, isRouter, "should be router") assert.Len(t, networkResourcesRoutes, 1, "expected network resource route don't match") assert.Len(t, sourcePeers, 2, "expected source peers don't match") - assert.Equal(t, accNetResourcePeer1ID, sourcePeers[0], "expected source peers don't match") - assert.Equal(t, accNetResourcePeer2ID, sourcePeers[1], "expected source peers don't match") + assert.NotNil(t, sourcePeers[accNetResourcePeer1ID], "expected source peers don't match") + assert.NotNil(t, sourcePeers[accNetResourcePeer2ID], "expected source peers don't match") // validate rules for router1 rules := account.GetPeerNetworkResourceFirewallRules(context.Background(), account.Peers[accNetResourceRouter1ID], accNetResourceValidPeers, networkResourcesRoutes, account.GetResourcePoliciesMap()) @@ -741,7 +741,7 @@ func Test_NetworksNetMapGenWithTwoPostureChecks(t *testing.T) { assert.True(t, isRouter, "should be router") assert.Len(t, networkResourcesRoutes, 1, "expected network resource route don't match") assert.Len(t, sourcePeers, 1, "expected source peers don't match") - assert.Equal(t, accNetResourcePeer1ID, sourcePeers[0], "expected source peers don't match") + assert.NotNil(t, sourcePeers[accNetResourcePeer1ID], "expected source peers don't match") // validate rules for router1 rules := account.GetPeerNetworkResourceFirewallRules(context.Background(), account.Peers[accNetResourceRouter1ID], accNetResourceValidPeers, networkResourcesRoutes, account.GetResourcePoliciesMap()) diff --git a/management/server/types/policy.go b/management/server/types/policy.go index 5b2cf06a0..17964ed1f 100644 --- a/management/server/types/policy.go +++ b/management/server/types/policy.go @@ -1,11 +1,5 @@ package types -import ( - "slices" - - "github.com/yourbasic/radix" -) - const ( // PolicyTrafficActionAccept indicates that the traffic is accepted PolicyTrafficActionAccept = PolicyTrafficActionType("accept") @@ -126,10 +120,17 @@ func (p *Policy) SourceGroups() []string { if len(p.Rules) == 1 { return p.Rules[0].Sources } - groups := make([]string, 0) + groups := make(map[string]struct{}, len(p.Rules)) for _, rule := range p.Rules { - groups = append(groups, rule.Sources...) 
+ for _, source := range rule.Sources { + groups[source] = struct{}{} + } } - radix.Sort(groups) - return slices.Compact(groups) + + groupIDs := make([]string, 0, len(groups)) + for groupID := range groups { + groupIDs = append(groupIDs, groupID) + } + + return groupIDs } diff --git a/route/route.go b/route/route.go index 8f3c99b4c..ad2aaba89 100644 --- a/route/route.go +++ b/route/route.go @@ -95,6 +95,7 @@ type Route struct { NetID NetID Description string Peer string + PeerID string `gorm:"-"` PeerGroups []string `gorm:"serializer:json"` NetworkType NetworkType Masquerade bool @@ -120,6 +121,7 @@ func (r *Route) Copy() *Route { KeepRoute: r.KeepRoute, NetworkType: r.NetworkType, Peer: r.Peer, + PeerID: r.PeerID, PeerGroups: slices.Clone(r.PeerGroups), Metric: r.Metric, Masquerade: r.Masquerade, @@ -146,6 +148,7 @@ func (r *Route) IsEqual(other *Route) bool { other.KeepRoute == r.KeepRoute && other.NetworkType == r.NetworkType && other.Peer == r.Peer && + other.PeerID == r.PeerID && other.Metric == r.Metric && other.Masquerade == r.Masquerade && other.Enabled == r.Enabled && From 03fd656344a3e65031bbfe8f5332dbcb522a1708 Mon Sep 17 00:00:00 2001 From: Maycon Santos Date: Tue, 31 Dec 2024 18:45:40 +0100 Subject: [PATCH 104/159] [management] Fix policy tests (#3135) - Add firewall rule isEqual method - Fix tests --- management/server/policy_test.go | 16 +++++++++++----- management/server/types/firewall_rule.go | 9 +++++++++ 2 files changed, 20 insertions(+), 5 deletions(-) diff --git a/management/server/policy_test.go b/management/server/policy_test.go index 0d17da23a..73fc6edba 100644 --- a/management/server/policy_test.go +++ b/management/server/policy_test.go @@ -76,7 +76,7 @@ func TestAccount_getPeersByPolicy(t *testing.T) { }, "GroupWorkstations": { ID: "GroupWorkstations", - Name: "All", + Name: "GroupWorkstations", Peers: []string{ "peerB", "peerA", @@ -280,10 +280,16 @@ func TestAccount_getPeersByPolicy(t *testing.T) { }, } assert.Len(t, firewallRules, len(epectedFirewallRules)) - slices.SortFunc(epectedFirewallRules, sortFunc()) - slices.SortFunc(firewallRules, sortFunc()) - for i := range firewallRules { - assert.Equal(t, epectedFirewallRules[i], firewallRules[i]) + + for _, rule := range firewallRules { + contains := false + for _, expectedRule := range epectedFirewallRules { + if rule.IsEqual(expectedRule) { + contains = true + break + } + } + assert.True(t, contains, "rule not found in expected rules %#v", rule) } }) } diff --git a/management/server/types/firewall_rule.go b/management/server/types/firewall_rule.go index 3d1b7e225..4e405152c 100644 --- a/management/server/types/firewall_rule.go +++ b/management/server/types/firewall_rule.go @@ -35,6 +35,15 @@ type FirewallRule struct { Port string } +// IsEqual checks if two firewall rules are equal. +func (r *FirewallRule) IsEqual(other *FirewallRule) bool { + return r.PeerIP == other.PeerIP && + r.Direction == other.Direction && + r.Action == other.Action && + r.Protocol == other.Protocol && + r.Port == other.Port +} + // generateRouteFirewallRules generates a list of firewall rules for a given route. 
func generateRouteFirewallRules(ctx context.Context, route *nbroute.Route, rule *PolicyRule, groupPeers []*nbpeer.Peer, direction int) []*RouteFirewallRule { rulesExists := make(map[string]struct{}) From 782e3f8853f1ac455d600fbad16eddbbeebefbac Mon Sep 17 00:00:00 2001 From: Pascal Fischer <32096965+pascal-fischer@users.noreply.github.com> Date: Thu, 2 Jan 2025 13:51:01 +0100 Subject: [PATCH 105/159] [management] Add integration test for the setup-keys API endpoints (#2936) --- management/cmd/management.go | 4 +- management/server/account.go | 4 +- management/server/account_test.go | 42 - management/server/geolocation/geolocation.go | 39 +- .../server/geolocation/geolocation_test.go | 2 +- management/server/grpcserver.go | 4 +- management/server/http/handler.go | 40 +- .../handlers/policies/geolocations_handler.go | 6 +- .../handlers/policies/policies_handler.go | 2 +- .../policies/posture_checks_handler.go | 6 +- .../policies/posture_checks_handler_test.go | 2 +- .../handlers/setup_keys/setupkeys_handler.go | 8 +- .../setup_keys/setupkeys_handler_test.go | 13 +- .../setupkeys_handler_benchmark_test.go | 226 ++++ .../setupkeys_handler_integration_test.go | 1146 +++++++++++++++++ .../http/testing/testdata/setup_keys.sql | 24 + .../http/testing/testing_tools/tools.go | 307 +++++ management/server/integrated_validator.go | 43 + management/server/jwtclaims/jwtValidator.go | 39 +- management/server/management_test.go | 42 +- management/server/networks/manager.go | 27 + .../server/networks/resources/manager.go | 39 + management/server/setupkey.go | 4 +- 23 files changed, 1919 insertions(+), 150 deletions(-) create mode 100644 management/server/http/testing/benchmarks/setupkeys_handler_benchmark_test.go create mode 100644 management/server/http/testing/integration/setupkeys_handler_integration_test.go create mode 100644 management/server/http/testing/testdata/setup_keys.sql create mode 100644 management/server/http/testing/testing_tools/tools.go diff --git a/management/cmd/management.go b/management/cmd/management.go index 4f34009b7..1c8fca8dc 100644 --- a/management/cmd/management.go +++ b/management/cmd/management.go @@ -42,7 +42,7 @@ import ( nbContext "github.com/netbirdio/netbird/management/server/context" "github.com/netbirdio/netbird/management/server/geolocation" "github.com/netbirdio/netbird/management/server/groups" - httpapi "github.com/netbirdio/netbird/management/server/http" + nbhttp "github.com/netbirdio/netbird/management/server/http" "github.com/netbirdio/netbird/management/server/http/configs" "github.com/netbirdio/netbird/management/server/idp" "github.com/netbirdio/netbird/management/server/jwtclaims" @@ -281,7 +281,7 @@ var ( routersManager := routers.NewManager(store, permissionsManager, accountManager) networksManager := networks.NewManager(store, permissionsManager, resourcesManager, routersManager, accountManager) - httpAPIHandler, err := httpapi.APIHandler(ctx, accountManager, networksManager, resourcesManager, routersManager, groupsManager, geo, *jwtValidator, appMetrics, httpAPIAuthCfg, integratedPeerValidator) + httpAPIHandler, err := nbhttp.NewAPIHandler(ctx, accountManager, networksManager, resourcesManager, routersManager, groupsManager, geo, jwtValidator, appMetrics, httpAPIAuthCfg, integratedPeerValidator) if err != nil { return fmt.Errorf("failed creating HTTP API handler: %v", err) } diff --git a/management/server/account.go b/management/server/account.go index 83a8759f9..6c8205f26 100644 --- a/management/server/account.go +++ b/management/server/account.go @@ 
-161,7 +161,7 @@ type DefaultAccountManager struct { externalCacheManager ExternalCacheManager ctx context.Context eventStore activity.Store - geo *geolocation.Geolocation + geo geolocation.Geolocation requestBuffer *AccountRequestBuffer @@ -244,7 +244,7 @@ func BuildManager( singleAccountModeDomain string, dnsDomain string, eventStore activity.Store, - geo *geolocation.Geolocation, + geo geolocation.Geolocation, userDeleteFromIDPEnabled bool, integratedPeerValidator integrated_validator.IntegratedValidator, metrics telemetry.AppMetrics, diff --git a/management/server/account_test.go b/management/server/account_test.go index 2289c96f9..4f6cdf78d 100644 --- a/management/server/account_test.go +++ b/management/server/account_test.go @@ -27,7 +27,6 @@ import ( "golang.zx2c4.com/wireguard/wgctrl/wgtypes" nbdns "github.com/netbirdio/netbird/dns" - "github.com/netbirdio/netbird/management/server/account" "github.com/netbirdio/netbird/management/server/activity" "github.com/netbirdio/netbird/management/server/jwtclaims" nbpeer "github.com/netbirdio/netbird/management/server/peer" @@ -38,47 +37,6 @@ import ( "github.com/netbirdio/netbird/route" ) -type MocIntegratedValidator struct { - ValidatePeerFunc func(_ context.Context, update *nbpeer.Peer, peer *nbpeer.Peer, userID string, accountID string, dnsDomain string, peersGroup []string, extraSettings *account.ExtraSettings) (*nbpeer.Peer, bool, error) -} - -func (a MocIntegratedValidator) ValidateExtraSettings(_ context.Context, newExtraSettings *account.ExtraSettings, oldExtraSettings *account.ExtraSettings, peers map[string]*nbpeer.Peer, userID string, accountID string) error { - return nil -} - -func (a MocIntegratedValidator) ValidatePeer(_ context.Context, update *nbpeer.Peer, peer *nbpeer.Peer, userID string, accountID string, dnsDomain string, peersGroup []string, extraSettings *account.ExtraSettings) (*nbpeer.Peer, bool, error) { - if a.ValidatePeerFunc != nil { - return a.ValidatePeerFunc(context.Background(), update, peer, userID, accountID, dnsDomain, peersGroup, extraSettings) - } - return update, false, nil -} -func (a MocIntegratedValidator) GetValidatedPeers(accountID string, groups map[string]*types.Group, peers map[string]*nbpeer.Peer, extraSettings *account.ExtraSettings) (map[string]struct{}, error) { - validatedPeers := make(map[string]struct{}) - for _, peer := range peers { - validatedPeers[peer.ID] = struct{}{} - } - return validatedPeers, nil -} - -func (MocIntegratedValidator) PreparePeer(_ context.Context, accountID string, peer *nbpeer.Peer, peersGroup []string, extraSettings *account.ExtraSettings) *nbpeer.Peer { - return peer -} - -func (MocIntegratedValidator) IsNotValidPeer(_ context.Context, accountID string, peer *nbpeer.Peer, peersGroup []string, extraSettings *account.ExtraSettings) (bool, bool, error) { - return false, false, nil -} - -func (MocIntegratedValidator) PeerDeleted(_ context.Context, _, _ string) error { - return nil -} - -func (MocIntegratedValidator) SetPeerInvalidationListener(func(accountID string)) { - -} - -func (MocIntegratedValidator) Stop(_ context.Context) { -} - func verifyCanAddPeerToAccount(t *testing.T, manager AccountManager, account *types.Account, userID string) { t.Helper() peer := &nbpeer.Peer{ diff --git a/management/server/geolocation/geolocation.go b/management/server/geolocation/geolocation.go index 553a31581..c0179a1c4 100644 --- a/management/server/geolocation/geolocation.go +++ b/management/server/geolocation/geolocation.go @@ -14,7 +14,14 @@ import ( log 
"github.com/sirupsen/logrus" ) -type Geolocation struct { +type Geolocation interface { + Lookup(ip net.IP) (*Record, error) + GetAllCountries() ([]Country, error) + GetCitiesByCountry(countryISOCode string) ([]City, error) + Stop() error +} + +type geolocationImpl struct { mmdbPath string mux sync.RWMutex db *maxminddb.Reader @@ -54,7 +61,7 @@ const ( geonamesdbPattern = "geonames_*.db" ) -func NewGeolocation(ctx context.Context, dataDir string, autoUpdate bool) (*Geolocation, error) { +func NewGeolocation(ctx context.Context, dataDir string, autoUpdate bool) (Geolocation, error) { mmdbGlobPattern := filepath.Join(dataDir, mmdbPattern) mmdbFile, err := getDatabaseFilename(ctx, geoLiteCityTarGZURL, mmdbGlobPattern, autoUpdate) if err != nil { @@ -86,7 +93,7 @@ func NewGeolocation(ctx context.Context, dataDir string, autoUpdate bool) (*Geol return nil, err } - geo := &Geolocation{ + geo := &geolocationImpl{ mmdbPath: mmdbPath, mux: sync.RWMutex{}, db: db, @@ -113,7 +120,7 @@ func openDB(mmdbPath string) (*maxminddb.Reader, error) { return db, nil } -func (gl *Geolocation) Lookup(ip net.IP) (*Record, error) { +func (gl *geolocationImpl) Lookup(ip net.IP) (*Record, error) { gl.mux.RLock() defer gl.mux.RUnlock() @@ -127,7 +134,7 @@ func (gl *Geolocation) Lookup(ip net.IP) (*Record, error) { } // GetAllCountries retrieves a list of all countries. -func (gl *Geolocation) GetAllCountries() ([]Country, error) { +func (gl *geolocationImpl) GetAllCountries() ([]Country, error) { allCountries, err := gl.locationDB.GetAllCountries() if err != nil { return nil, err @@ -143,7 +150,7 @@ func (gl *Geolocation) GetAllCountries() ([]Country, error) { } // GetCitiesByCountry retrieves a list of cities in a specific country based on the country's ISO code. -func (gl *Geolocation) GetCitiesByCountry(countryISOCode string) ([]City, error) { +func (gl *geolocationImpl) GetCitiesByCountry(countryISOCode string) ([]City, error) { allCities, err := gl.locationDB.GetCitiesByCountry(countryISOCode) if err != nil { return nil, err @@ -158,7 +165,7 @@ func (gl *Geolocation) GetCitiesByCountry(countryISOCode string) ([]City, error) return cities, nil } -func (gl *Geolocation) Stop() error { +func (gl *geolocationImpl) Stop() error { close(gl.stopCh) if gl.db != nil { if err := gl.db.Close(); err != nil { @@ -259,3 +266,21 @@ func cleanupMaxMindDatabases(ctx context.Context, dataDir string, mmdbFile strin } return nil } + +type Mock struct{} + +func (g *Mock) Lookup(ip net.IP) (*Record, error) { + return &Record{}, nil +} + +func (g *Mock) GetAllCountries() ([]Country, error) { + return []Country{}, nil +} + +func (g *Mock) GetCitiesByCountry(countryISOCode string) ([]City, error) { + return []City{}, nil +} + +func (g *Mock) Stop() error { + return nil +} diff --git a/management/server/geolocation/geolocation_test.go b/management/server/geolocation/geolocation_test.go index 9bdefd268..fecd715be 100644 --- a/management/server/geolocation/geolocation_test.go +++ b/management/server/geolocation/geolocation_test.go @@ -24,7 +24,7 @@ func TestGeoLite_Lookup(t *testing.T) { db, err := openDB(filename) assert.NoError(t, err) - geo := &Geolocation{ + geo := &geolocationImpl{ mux: sync.RWMutex{}, db: db, stopCh: make(chan struct{}), diff --git a/management/server/grpcserver.go b/management/server/grpcserver.go index 2635ac11b..daa23d2ab 100644 --- a/management/server/grpcserver.go +++ b/management/server/grpcserver.go @@ -38,7 +38,7 @@ type GRPCServer struct { peersUpdateManager *PeersUpdateManager config *Config secretsManager 
SecretsManager - jwtValidator *jwtclaims.JWTValidator + jwtValidator jwtclaims.JWTValidator jwtClaimsExtractor *jwtclaims.ClaimsExtractor appMetrics telemetry.AppMetrics ephemeralManager *EphemeralManager @@ -61,7 +61,7 @@ func NewServer( return nil, err } - var jwtValidator *jwtclaims.JWTValidator + var jwtValidator jwtclaims.JWTValidator if config.HttpConfig != nil && config.HttpConfig.AuthIssuer != "" && config.HttpConfig.AuthAudience != "" && validateURL(config.HttpConfig.AuthKeysLocation) { jwtValidator, err = jwtclaims.NewJWTValidator( diff --git a/management/server/http/handler.go b/management/server/http/handler.go index 7db7ab5b8..cc2ad00b7 100644 --- a/management/server/http/handler.go +++ b/management/server/http/handler.go @@ -35,15 +35,8 @@ import ( const apiPrefix = "/api" -type apiHandler struct { - Router *mux.Router - AccountManager s.AccountManager - geolocationManager *geolocation.Geolocation - AuthCfg configs.AuthCfg -} - -// APIHandler creates the Management service HTTP API handler registering all the available endpoints. -func APIHandler(ctx context.Context, accountManager s.AccountManager, networksManager nbnetworks.Manager, resourceManager resources.Manager, routerManager routers.Manager, groupsManager nbgroups.Manager, LocationManager *geolocation.Geolocation, jwtValidator jwtclaims.JWTValidator, appMetrics telemetry.AppMetrics, authCfg configs.AuthCfg, integratedValidator integrated_validator.IntegratedValidator) (http.Handler, error) { +// NewAPIHandler creates the Management service HTTP API handler registering all the available endpoints. +func NewAPIHandler(ctx context.Context, accountManager s.AccountManager, networksManager nbnetworks.Manager, resourceManager resources.Manager, routerManager routers.Manager, groupsManager nbgroups.Manager, LocationManager geolocation.Geolocation, jwtValidator jwtclaims.JWTValidator, appMetrics telemetry.AppMetrics, authCfg configs.AuthCfg, integratedValidator integrated_validator.IntegratedValidator) (http.Handler, error) { claimsExtractor := jwtclaims.NewClaimsExtractor( jwtclaims.WithAudience(authCfg.Audience), jwtclaims.WithUserIDClaim(authCfg.UserIDClaim), @@ -78,27 +71,20 @@ func APIHandler(ctx context.Context, accountManager s.AccountManager, networksMa router := rootRouter.PathPrefix(prefix).Subrouter() router.Use(metricsMiddleware.Handler, corsMiddleware.Handler, authMiddleware.Handler, acMiddleware.Handler) - api := apiHandler{ - Router: router, - AccountManager: accountManager, - geolocationManager: LocationManager, - AuthCfg: authCfg, - } - - if _, err := integrations.RegisterHandlers(ctx, prefix, api.Router, accountManager, claimsExtractor, integratedValidator, appMetrics.GetMeter()); err != nil { + if _, err := integrations.RegisterHandlers(ctx, prefix, router, accountManager, claimsExtractor, integratedValidator, appMetrics.GetMeter()); err != nil { return nil, fmt.Errorf("register integrations endpoints: %w", err) } - accounts.AddEndpoints(api.AccountManager, authCfg, router) - peers.AddEndpoints(api.AccountManager, authCfg, router) - users.AddEndpoints(api.AccountManager, authCfg, router) - setup_keys.AddEndpoints(api.AccountManager, authCfg, router) - policies.AddEndpoints(api.AccountManager, api.geolocationManager, authCfg, router) - groups.AddEndpoints(api.AccountManager, authCfg, router) - routes.AddEndpoints(api.AccountManager, authCfg, router) - dns.AddEndpoints(api.AccountManager, authCfg, router) - events.AddEndpoints(api.AccountManager, authCfg, router) - networks.AddEndpoints(networksManager, 
resourceManager, routerManager, groupsManager, api.AccountManager, api.AccountManager.GetAccountIDFromToken, authCfg, router) + accounts.AddEndpoints(accountManager, authCfg, router) + peers.AddEndpoints(accountManager, authCfg, router) + users.AddEndpoints(accountManager, authCfg, router) + setup_keys.AddEndpoints(accountManager, authCfg, router) + policies.AddEndpoints(accountManager, LocationManager, authCfg, router) + groups.AddEndpoints(accountManager, authCfg, router) + routes.AddEndpoints(accountManager, authCfg, router) + dns.AddEndpoints(accountManager, authCfg, router) + events.AddEndpoints(accountManager, authCfg, router) + networks.AddEndpoints(networksManager, resourceManager, routerManager, groupsManager, accountManager, accountManager.GetAccountIDFromToken, authCfg, router) return rootRouter, nil } diff --git a/management/server/http/handlers/policies/geolocations_handler.go b/management/server/http/handlers/policies/geolocations_handler.go index e5bf3e695..161d97402 100644 --- a/management/server/http/handlers/policies/geolocations_handler.go +++ b/management/server/http/handlers/policies/geolocations_handler.go @@ -22,18 +22,18 @@ var ( // geolocationsHandler is a handler that returns locations. type geolocationsHandler struct { accountManager server.AccountManager - geolocationManager *geolocation.Geolocation + geolocationManager geolocation.Geolocation claimsExtractor *jwtclaims.ClaimsExtractor } -func addLocationsEndpoint(accountManager server.AccountManager, locationManager *geolocation.Geolocation, authCfg configs.AuthCfg, router *mux.Router) { +func addLocationsEndpoint(accountManager server.AccountManager, locationManager geolocation.Geolocation, authCfg configs.AuthCfg, router *mux.Router) { locationHandler := newGeolocationsHandlerHandler(accountManager, locationManager, authCfg) router.HandleFunc("/locations/countries", locationHandler.getAllCountries).Methods("GET", "OPTIONS") router.HandleFunc("/locations/countries/{country}/cities", locationHandler.getCitiesByCountry).Methods("GET", "OPTIONS") } // newGeolocationsHandlerHandler creates a new Geolocations handler -func newGeolocationsHandlerHandler(accountManager server.AccountManager, geolocationManager *geolocation.Geolocation, authCfg configs.AuthCfg) *geolocationsHandler { +func newGeolocationsHandlerHandler(accountManager server.AccountManager, geolocationManager geolocation.Geolocation, authCfg configs.AuthCfg) *geolocationsHandler { return &geolocationsHandler{ accountManager: accountManager, geolocationManager: geolocationManager, diff --git a/management/server/http/handlers/policies/policies_handler.go b/management/server/http/handlers/policies/policies_handler.go index b1035c570..a748e73b8 100644 --- a/management/server/http/handlers/policies/policies_handler.go +++ b/management/server/http/handlers/policies/policies_handler.go @@ -23,7 +23,7 @@ type handler struct { claimsExtractor *jwtclaims.ClaimsExtractor } -func AddEndpoints(accountManager server.AccountManager, locationManager *geolocation.Geolocation, authCfg configs.AuthCfg, router *mux.Router) { +func AddEndpoints(accountManager server.AccountManager, locationManager geolocation.Geolocation, authCfg configs.AuthCfg, router *mux.Router) { policiesHandler := newHandler(accountManager, authCfg) router.HandleFunc("/policies", policiesHandler.getAllPolicies).Methods("GET", "OPTIONS") router.HandleFunc("/policies", policiesHandler.createPolicy).Methods("POST", "OPTIONS") diff --git 
a/management/server/http/handlers/policies/posture_checks_handler.go b/management/server/http/handlers/policies/posture_checks_handler.go index 44917605b..ce0d4878c 100644 --- a/management/server/http/handlers/policies/posture_checks_handler.go +++ b/management/server/http/handlers/policies/posture_checks_handler.go @@ -19,11 +19,11 @@ import ( // postureChecksHandler is a handler that returns posture checks of the account. type postureChecksHandler struct { accountManager server.AccountManager - geolocationManager *geolocation.Geolocation + geolocationManager geolocation.Geolocation claimsExtractor *jwtclaims.ClaimsExtractor } -func addPostureCheckEndpoint(accountManager server.AccountManager, locationManager *geolocation.Geolocation, authCfg configs.AuthCfg, router *mux.Router) { +func addPostureCheckEndpoint(accountManager server.AccountManager, locationManager geolocation.Geolocation, authCfg configs.AuthCfg, router *mux.Router) { postureCheckHandler := newPostureChecksHandler(accountManager, locationManager, authCfg) router.HandleFunc("/posture-checks", postureCheckHandler.getAllPostureChecks).Methods("GET", "OPTIONS") router.HandleFunc("/posture-checks", postureCheckHandler.createPostureCheck).Methods("POST", "OPTIONS") @@ -34,7 +34,7 @@ func addPostureCheckEndpoint(accountManager server.AccountManager, locationManag } // newPostureChecksHandler creates a new PostureChecks handler -func newPostureChecksHandler(accountManager server.AccountManager, geolocationManager *geolocation.Geolocation, authCfg configs.AuthCfg) *postureChecksHandler { +func newPostureChecksHandler(accountManager server.AccountManager, geolocationManager geolocation.Geolocation, authCfg configs.AuthCfg) *postureChecksHandler { return &postureChecksHandler{ accountManager: accountManager, geolocationManager: geolocationManager, diff --git a/management/server/http/handlers/policies/posture_checks_handler_test.go b/management/server/http/handlers/policies/posture_checks_handler_test.go index e9a539e45..237687fd4 100644 --- a/management/server/http/handlers/policies/posture_checks_handler_test.go +++ b/management/server/http/handlers/policies/posture_checks_handler_test.go @@ -70,7 +70,7 @@ func initPostureChecksTestData(postureChecks ...*posture.Checks) *postureChecksH return claims.AccountId, claims.UserId, nil }, }, - geolocationManager: &geolocation.Geolocation{}, + geolocationManager: &geolocation.Mock{}, claimsExtractor: jwtclaims.NewClaimsExtractor( jwtclaims.WithFromRequestContext(func(r *http.Request) jwtclaims.AuthorizationClaims { return jwtclaims.AuthorizationClaims{ diff --git a/management/server/http/handlers/setup_keys/setupkeys_handler.go b/management/server/http/handlers/setup_keys/setupkeys_handler.go index 89696a165..a627d7203 100644 --- a/management/server/http/handlers/setup_keys/setupkeys_handler.go +++ b/management/server/http/handlers/setup_keys/setupkeys_handler.go @@ -93,7 +93,7 @@ func (h *handler) createSetupKey(w http.ResponseWriter, r *http.Request) { return } - apiSetupKeys := toResponseBody(setupKey) + apiSetupKeys := ToResponseBody(setupKey) // for the creation we need to send the plain key apiSetupKeys.Key = setupKey.Key @@ -183,7 +183,7 @@ func (h *handler) getAllSetupKeys(w http.ResponseWriter, r *http.Request) { apiSetupKeys := make([]*api.SetupKey, 0) for _, key := range setupKeys { - apiSetupKeys = append(apiSetupKeys, toResponseBody(key)) + apiSetupKeys = append(apiSetupKeys, ToResponseBody(key)) } util.WriteJSONObject(r.Context(), w, apiSetupKeys) @@ -216,14 +216,14 @@ func (h 
*handler) deleteSetupKey(w http.ResponseWriter, r *http.Request) { func writeSuccess(ctx context.Context, w http.ResponseWriter, key *types.SetupKey) { w.Header().Set("Content-Type", "application/json") w.WriteHeader(200) - err := json.NewEncoder(w).Encode(toResponseBody(key)) + err := json.NewEncoder(w).Encode(ToResponseBody(key)) if err != nil { util.WriteError(ctx, err, w) return } } -func toResponseBody(key *types.SetupKey) *api.SetupKey { +func ToResponseBody(key *types.SetupKey) *api.SetupKey { var state string switch { case key.IsExpired(): diff --git a/management/server/http/handlers/setup_keys/setupkeys_handler_test.go b/management/server/http/handlers/setup_keys/setupkeys_handler_test.go index 4ecb1e9ed..f56227c10 100644 --- a/management/server/http/handlers/setup_keys/setupkeys_handler_test.go +++ b/management/server/http/handlers/setup_keys/setupkeys_handler_test.go @@ -26,7 +26,6 @@ const ( newSetupKeyName = "New Setup Key" updatedSetupKeyName = "KKKey" notFoundSetupKeyID = "notFoundSetupKeyID" - testAccountID = "test_id" ) func initSetupKeysTestMetaData(defaultKey *types.SetupKey, newKey *types.SetupKey, updatedSetupKey *types.SetupKey, @@ -81,7 +80,7 @@ func initSetupKeysTestMetaData(defaultKey *types.SetupKey, newKey *types.SetupKe return jwtclaims.AuthorizationClaims{ UserId: user.Id, Domain: "hotmail.com", - AccountId: testAccountID, + AccountId: "testAccountId", } }), ), @@ -102,7 +101,7 @@ func TestSetupKeysHandlers(t *testing.T) { updatedDefaultSetupKey.Name = updatedSetupKeyName updatedDefaultSetupKey.Revoked = true - expectedNewKey := toResponseBody(newSetupKey) + expectedNewKey := ToResponseBody(newSetupKey) expectedNewKey.Key = plainKey tt := []struct { name string @@ -120,7 +119,7 @@ func TestSetupKeysHandlers(t *testing.T) { requestPath: "/api/setup-keys", expectedStatus: http.StatusOK, expectedBody: true, - expectedSetupKeys: []*api.SetupKey{toResponseBody(defaultSetupKey)}, + expectedSetupKeys: []*api.SetupKey{ToResponseBody(defaultSetupKey)}, }, { name: "Get Existing Setup Key", @@ -128,7 +127,7 @@ func TestSetupKeysHandlers(t *testing.T) { requestPath: "/api/setup-keys/" + existingSetupKeyID, expectedStatus: http.StatusOK, expectedBody: true, - expectedSetupKey: toResponseBody(defaultSetupKey), + expectedSetupKey: ToResponseBody(defaultSetupKey), }, { name: "Get Not Existing Setup Key", @@ -159,7 +158,7 @@ func TestSetupKeysHandlers(t *testing.T) { ))), expectedStatus: http.StatusOK, expectedBody: true, - expectedSetupKey: toResponseBody(updatedDefaultSetupKey), + expectedSetupKey: ToResponseBody(updatedDefaultSetupKey), }, { name: "Delete Setup Key", @@ -228,7 +227,7 @@ func TestSetupKeysHandlers(t *testing.T) { func assertKeys(t *testing.T, got *api.SetupKey, expected *api.SetupKey) { t.Helper() // this comparison is done manually because when converting to JSON dates formatted differently - // assert.Equal(t, got.UpdatedAt, tc.expectedSetupKey.UpdatedAt) //doesn't work + // assert.Equal(t, got.UpdatedAt, tc.expectedResponse.UpdatedAt) //doesn't work assert.WithinDurationf(t, got.UpdatedAt, expected.UpdatedAt, 0, "") assert.WithinDurationf(t, got.Expires, expected.Expires, 0, "") assert.Equal(t, got.Name, expected.Name) diff --git a/management/server/http/testing/benchmarks/setupkeys_handler_benchmark_test.go b/management/server/http/testing/benchmarks/setupkeys_handler_benchmark_test.go new file mode 100644 index 000000000..5e2895bcc --- /dev/null +++ b/management/server/http/testing/benchmarks/setupkeys_handler_benchmark_test.go @@ -0,0 +1,226 @@ +package 
benchmarks + +import ( + "encoding/json" + "io" + "net/http" + "net/http/httptest" + "os" + "strconv" + "testing" + "time" + + log "github.com/sirupsen/logrus" + "github.com/stretchr/testify/assert" + + "github.com/netbirdio/netbird/management/server" + "github.com/netbirdio/netbird/management/server/http/api" + "github.com/netbirdio/netbird/management/server/http/testing/testing_tools" +) + +// Map to store peers, groups, users, and setupKeys by name +var benchCasesSetupKeys = map[string]testing_tools.BenchmarkCase{ + "Setup Keys - XS": {Peers: 10000, Groups: 10000, Users: 10000, SetupKeys: 5}, + "Setup Keys - S": {Peers: 5, Groups: 5, Users: 5, SetupKeys: 100}, + "Setup Keys - M": {Peers: 100, Groups: 20, Users: 20, SetupKeys: 1000}, + "Setup Keys - L": {Peers: 5, Groups: 5, Users: 5, SetupKeys: 5000}, + "Peers - L": {Peers: 10000, Groups: 5, Users: 5, SetupKeys: 5000}, + "Groups - L": {Peers: 5, Groups: 10000, Users: 5, SetupKeys: 5000}, + "Users - L": {Peers: 5, Groups: 5, Users: 10000, SetupKeys: 5000}, + "Setup Keys - XL": {Peers: 500, Groups: 50, Users: 100, SetupKeys: 25000}, +} + +func BenchmarkCreateSetupKey(b *testing.B) { + var expectedMetrics = map[string]testing_tools.PerformanceMetrics{ + "Setup Keys - XS": {MinMsPerOpLocal: 0.5, MaxMsPerOpLocal: 2, MinMsPerOpCICD: 1, MaxMsPerOpCICD: 16}, + "Setup Keys - S": {MinMsPerOpLocal: 0.5, MaxMsPerOpLocal: 2, MinMsPerOpCICD: 1, MaxMsPerOpCICD: 16}, + "Setup Keys - M": {MinMsPerOpLocal: 0.5, MaxMsPerOpLocal: 2, MinMsPerOpCICD: 1, MaxMsPerOpCICD: 16}, + "Setup Keys - L": {MinMsPerOpLocal: 0.5, MaxMsPerOpLocal: 2, MinMsPerOpCICD: 1, MaxMsPerOpCICD: 16}, + "Peers - L": {MinMsPerOpLocal: 0.5, MaxMsPerOpLocal: 2, MinMsPerOpCICD: 1, MaxMsPerOpCICD: 16}, + "Groups - L": {MinMsPerOpLocal: 0.5, MaxMsPerOpLocal: 2, MinMsPerOpCICD: 1, MaxMsPerOpCICD: 16}, + "Users - L": {MinMsPerOpLocal: 0.5, MaxMsPerOpLocal: 2, MinMsPerOpCICD: 1, MaxMsPerOpCICD: 16}, + "Setup Keys - XL": {MinMsPerOpLocal: 0.5, MaxMsPerOpLocal: 2, MinMsPerOpCICD: 1, MaxMsPerOpCICD: 16}, + } + + log.SetOutput(io.Discard) + defer log.SetOutput(os.Stderr) + + recorder := httptest.NewRecorder() + + for name, bc := range benchCasesSetupKeys { + b.Run(name, func(b *testing.B) { + apiHandler, am, _ := testing_tools.BuildApiBlackBoxWithDBState(b, "../testdata/setup_keys.sql", nil) + testing_tools.PopulateTestData(b, am.(*server.DefaultAccountManager), bc.Peers, bc.Groups, bc.Users, bc.SetupKeys) + + b.ResetTimer() + start := time.Now() + for i := 0; i < b.N; i++ { + requestBody := api.CreateSetupKeyRequest{ + AutoGroups: []string{testing_tools.TestGroupId}, + ExpiresIn: testing_tools.ExpiresIn, + Name: testing_tools.NewKeyName + strconv.Itoa(i), + Type: "reusable", + UsageLimit: 0, + } + + // the time marshal will be recorded as well but for our use case that is ok + body, err := json.Marshal(requestBody) + assert.NoError(b, err) + + req := testing_tools.BuildRequest(b, body, http.MethodPost, "/api/setup-keys", testing_tools.TestAdminId) + apiHandler.ServeHTTP(recorder, req) + } + + testing_tools.EvaluateBenchmarkResults(b, name, time.Since(start), expectedMetrics[name], recorder) + }) + } +} + +func BenchmarkUpdateSetupKey(b *testing.B) { + var expectedMetrics = map[string]testing_tools.PerformanceMetrics{ + "Setup Keys - XS": {MinMsPerOpLocal: 0.5, MaxMsPerOpLocal: 3, MinMsPerOpCICD: 2, MaxMsPerOpCICD: 19}, + "Setup Keys - S": {MinMsPerOpLocal: 0.5, MaxMsPerOpLocal: 3, MinMsPerOpCICD: 2, MaxMsPerOpCICD: 19}, + "Setup Keys - M": {MinMsPerOpLocal: 0.5, MaxMsPerOpLocal: 3, 
MinMsPerOpCICD: 2, MaxMsPerOpCICD: 19}, + "Setup Keys - L": {MinMsPerOpLocal: 0.5, MaxMsPerOpLocal: 3, MinMsPerOpCICD: 2, MaxMsPerOpCICD: 19}, + "Peers - L": {MinMsPerOpLocal: 0.5, MaxMsPerOpLocal: 3, MinMsPerOpCICD: 2, MaxMsPerOpCICD: 19}, + "Groups - L": {MinMsPerOpLocal: 0.5, MaxMsPerOpLocal: 3, MinMsPerOpCICD: 2, MaxMsPerOpCICD: 19}, + "Users - L": {MinMsPerOpLocal: 0.5, MaxMsPerOpLocal: 3, MinMsPerOpCICD: 2, MaxMsPerOpCICD: 19}, + "Setup Keys - XL": {MinMsPerOpLocal: 0.5, MaxMsPerOpLocal: 3, MinMsPerOpCICD: 2, MaxMsPerOpCICD: 19}, + } + + log.SetOutput(io.Discard) + defer log.SetOutput(os.Stderr) + + recorder := httptest.NewRecorder() + + for name, bc := range benchCasesSetupKeys { + b.Run(name, func(b *testing.B) { + apiHandler, am, _ := testing_tools.BuildApiBlackBoxWithDBState(b, "../testdata/setup_keys.sql", nil) + testing_tools.PopulateTestData(b, am.(*server.DefaultAccountManager), bc.Peers, bc.Groups, bc.Users, bc.SetupKeys) + + b.ResetTimer() + start := time.Now() + for i := 0; i < b.N; i++ { + groupId := testing_tools.TestGroupId + if i%2 == 0 { + groupId = testing_tools.NewGroupId + } + requestBody := api.SetupKeyRequest{ + AutoGroups: []string{groupId}, + Revoked: false, + } + + // the time marshal will be recorded as well but for our use case that is ok + body, err := json.Marshal(requestBody) + assert.NoError(b, err) + + req := testing_tools.BuildRequest(b, body, http.MethodPut, "/api/setup-keys/"+testing_tools.TestKeyId, testing_tools.TestAdminId) + apiHandler.ServeHTTP(recorder, req) + } + + testing_tools.EvaluateBenchmarkResults(b, name, time.Since(start), expectedMetrics[name], recorder) + }) + } +} + +func BenchmarkGetOneSetupKey(b *testing.B) { + var expectedMetrics = map[string]testing_tools.PerformanceMetrics{ + "Setup Keys - XS": {MinMsPerOpLocal: 0.5, MaxMsPerOpLocal: 2, MinMsPerOpCICD: 1, MaxMsPerOpCICD: 16}, + "Setup Keys - S": {MinMsPerOpLocal: 0.5, MaxMsPerOpLocal: 2, MinMsPerOpCICD: 1, MaxMsPerOpCICD: 16}, + "Setup Keys - M": {MinMsPerOpLocal: 0.5, MaxMsPerOpLocal: 2, MinMsPerOpCICD: 1, MaxMsPerOpCICD: 16}, + "Setup Keys - L": {MinMsPerOpLocal: 0.5, MaxMsPerOpLocal: 2, MinMsPerOpCICD: 1, MaxMsPerOpCICD: 16}, + "Peers - L": {MinMsPerOpLocal: 0.5, MaxMsPerOpLocal: 2, MinMsPerOpCICD: 1, MaxMsPerOpCICD: 16}, + "Groups - L": {MinMsPerOpLocal: 0.5, MaxMsPerOpLocal: 2, MinMsPerOpCICD: 1, MaxMsPerOpCICD: 16}, + "Users - L": {MinMsPerOpLocal: 0.5, MaxMsPerOpLocal: 2, MinMsPerOpCICD: 1, MaxMsPerOpCICD: 16}, + "Setup Keys - XL": {MinMsPerOpLocal: 0.5, MaxMsPerOpLocal: 2, MinMsPerOpCICD: 1, MaxMsPerOpCICD: 16}, + } + + log.SetOutput(io.Discard) + defer log.SetOutput(os.Stderr) + + recorder := httptest.NewRecorder() + + for name, bc := range benchCasesSetupKeys { + b.Run(name, func(b *testing.B) { + apiHandler, am, _ := testing_tools.BuildApiBlackBoxWithDBState(b, "../testdata/setup_keys.sql", nil) + testing_tools.PopulateTestData(b, am.(*server.DefaultAccountManager), bc.Peers, bc.Groups, bc.Users, bc.SetupKeys) + + b.ResetTimer() + start := time.Now() + for i := 0; i < b.N; i++ { + req := testing_tools.BuildRequest(b, nil, http.MethodGet, "/api/setup-keys/"+testing_tools.TestKeyId, testing_tools.TestAdminId) + apiHandler.ServeHTTP(recorder, req) + } + + testing_tools.EvaluateBenchmarkResults(b, name, time.Since(start), expectedMetrics[name], recorder) + }) + } +} + +func BenchmarkGetAllSetupKeys(b *testing.B) { + var expectedMetrics = map[string]testing_tools.PerformanceMetrics{ + "Setup Keys - XS": {MinMsPerOpLocal: 0.5, MaxMsPerOpLocal: 2, MinMsPerOpCICD: 1, 
MaxMsPerOpCICD: 12}, + "Setup Keys - S": {MinMsPerOpLocal: 0.5, MaxMsPerOpLocal: 2, MinMsPerOpCICD: 1, MaxMsPerOpCICD: 15}, + "Setup Keys - M": {MinMsPerOpLocal: 5, MaxMsPerOpLocal: 10, MinMsPerOpCICD: 5, MaxMsPerOpCICD: 40}, + "Setup Keys - L": {MinMsPerOpLocal: 30, MaxMsPerOpLocal: 50, MinMsPerOpCICD: 30, MaxMsPerOpCICD: 150}, + "Peers - L": {MinMsPerOpLocal: 30, MaxMsPerOpLocal: 50, MinMsPerOpCICD: 30, MaxMsPerOpCICD: 150}, + "Groups - L": {MinMsPerOpLocal: 30, MaxMsPerOpLocal: 50, MinMsPerOpCICD: 30, MaxMsPerOpCICD: 150}, + "Users - L": {MinMsPerOpLocal: 30, MaxMsPerOpLocal: 50, MinMsPerOpCICD: 30, MaxMsPerOpCICD: 150}, + "Setup Keys - XL": {MinMsPerOpLocal: 140, MaxMsPerOpLocal: 220, MinMsPerOpCICD: 150, MaxMsPerOpCICD: 500}, + } + + log.SetOutput(io.Discard) + defer log.SetOutput(os.Stderr) + + recorder := httptest.NewRecorder() + + for name, bc := range benchCasesSetupKeys { + b.Run(name, func(b *testing.B) { + apiHandler, am, _ := testing_tools.BuildApiBlackBoxWithDBState(b, "../testdata/setup_keys.sql", nil) + testing_tools.PopulateTestData(b, am.(*server.DefaultAccountManager), bc.Peers, bc.Groups, bc.Users, bc.SetupKeys) + + b.ResetTimer() + start := time.Now() + for i := 0; i < b.N; i++ { + req := testing_tools.BuildRequest(b, nil, http.MethodGet, "/api/setup-keys", testing_tools.TestAdminId) + apiHandler.ServeHTTP(recorder, req) + } + + testing_tools.EvaluateBenchmarkResults(b, name, time.Since(start), expectedMetrics[name], recorder) + }) + } +} + +func BenchmarkDeleteSetupKey(b *testing.B) { + var expectedMetrics = map[string]testing_tools.PerformanceMetrics{ + "Setup Keys - XS": {MinMsPerOpLocal: 0.5, MaxMsPerOpLocal: 2, MinMsPerOpCICD: 1, MaxMsPerOpCICD: 16}, + "Setup Keys - S": {MinMsPerOpLocal: 0.5, MaxMsPerOpLocal: 2, MinMsPerOpCICD: 1, MaxMsPerOpCICD: 16}, + "Setup Keys - M": {MinMsPerOpLocal: 0.5, MaxMsPerOpLocal: 2, MinMsPerOpCICD: 1, MaxMsPerOpCICD: 16}, + "Setup Keys - L": {MinMsPerOpLocal: 0.5, MaxMsPerOpLocal: 2, MinMsPerOpCICD: 1, MaxMsPerOpCICD: 16}, + "Peers - L": {MinMsPerOpLocal: 0.5, MaxMsPerOpLocal: 2, MinMsPerOpCICD: 1, MaxMsPerOpCICD: 16}, + "Groups - L": {MinMsPerOpLocal: 0.5, MaxMsPerOpLocal: 2, MinMsPerOpCICD: 1, MaxMsPerOpCICD: 16}, + "Users - L": {MinMsPerOpLocal: 0.5, MaxMsPerOpLocal: 2, MinMsPerOpCICD: 1, MaxMsPerOpCICD: 16}, + "Setup Keys - XL": {MinMsPerOpLocal: 0.5, MaxMsPerOpLocal: 2, MinMsPerOpCICD: 1, MaxMsPerOpCICD: 16}, + } + + log.SetOutput(io.Discard) + defer log.SetOutput(os.Stderr) + + recorder := httptest.NewRecorder() + + for name, bc := range benchCasesSetupKeys { + b.Run(name, func(b *testing.B) { + apiHandler, am, _ := testing_tools.BuildApiBlackBoxWithDBState(b, "../testdata/setup_keys.sql", nil) + testing_tools.PopulateTestData(b, am.(*server.DefaultAccountManager), bc.Peers, bc.Groups, bc.Users, 1000) + + b.ResetTimer() + start := time.Now() + for i := 0; i < b.N; i++ { + req := testing_tools.BuildRequest(b, nil, http.MethodDelete, "/api/setup-keys/"+"oldkey-"+strconv.Itoa(i), testing_tools.TestAdminId) + apiHandler.ServeHTTP(recorder, req) + } + + testing_tools.EvaluateBenchmarkResults(b, name, time.Since(start), expectedMetrics[name], recorder) + }) + } +} diff --git a/management/server/http/testing/integration/setupkeys_handler_integration_test.go b/management/server/http/testing/integration/setupkeys_handler_integration_test.go new file mode 100644 index 000000000..193c0fb02 --- /dev/null +++ b/management/server/http/testing/integration/setupkeys_handler_integration_test.go @@ -0,0 +1,1146 @@ +package integration + +import 
( + "context" + "encoding/json" + "net/http" + "net/http/httptest" + "sort" + "strings" + "testing" + "time" + + "github.com/stretchr/testify/assert" + + "github.com/netbirdio/netbird/management/server/http/api" + "github.com/netbirdio/netbird/management/server/http/handlers/setup_keys" + "github.com/netbirdio/netbird/management/server/http/testing/testing_tools" +) + +func Test_SetupKeys_Create(t *testing.T) { + truePointer := true + + users := []struct { + name string + userId string + expectResponse bool + }{ + { + name: "Regular user", + userId: testing_tools.TestUserId, + expectResponse: false, + }, + { + name: "Admin user", + userId: testing_tools.TestAdminId, + expectResponse: true, + }, + { + name: "Owner user", + userId: testing_tools.TestOwnerId, + expectResponse: true, + }, + { + name: "Regular service user", + userId: testing_tools.TestServiceUserId, + expectResponse: false, + }, + { + name: "Admin service user", + userId: testing_tools.TestServiceAdminId, + expectResponse: true, + }, + { + name: "Blocked user", + userId: testing_tools.BlockedUserId, + expectResponse: false, + }, + { + name: "Other user", + userId: testing_tools.OtherUserId, + expectResponse: false, + }, + { + name: "Invalid token", + userId: testing_tools.InvalidToken, + expectResponse: false, + }, + } + + tt := []struct { + name string + expectedStatus int + expectedResponse *api.SetupKey + requestBody *api.CreateSetupKeyRequest + requestType string + requestPath string + userId string + }{ + { + name: "Create Setup Key", + requestType: http.MethodPost, + requestPath: "/api/setup-keys", + requestBody: &api.CreateSetupKeyRequest{ + AutoGroups: nil, + ExpiresIn: testing_tools.ExpiresIn, + Name: testing_tools.NewKeyName, + Type: "reusable", + UsageLimit: 0, + }, + expectedStatus: http.StatusOK, + expectedResponse: &api.SetupKey{ + AutoGroups: []string{}, + Ephemeral: false, + Expires: time.Time{}, + Id: "", + Key: "", + LastUsed: time.Time{}, + Name: testing_tools.NewKeyName, + Revoked: false, + State: "valid", + Type: "reusable", + UpdatedAt: time.Now(), + UsageLimit: 0, + UsedTimes: 0, + Valid: true, + }, + }, + { + name: "Create Setup Key with already existing name", + requestType: http.MethodPost, + requestPath: "/api/setup-keys", + requestBody: &api.CreateSetupKeyRequest{ + AutoGroups: nil, + ExpiresIn: testing_tools.ExpiresIn, + Name: testing_tools.ExistingKeyName, + Type: "one-off", + UsageLimit: 0, + }, + expectedStatus: http.StatusOK, + expectedResponse: &api.SetupKey{ + AutoGroups: []string{}, + Ephemeral: false, + Expires: time.Time{}, + Id: "", + Key: "", + LastUsed: time.Time{}, + Name: testing_tools.ExistingKeyName, + Revoked: false, + State: "valid", + Type: "one-off", + UpdatedAt: time.Now(), + UsageLimit: 1, + UsedTimes: 0, + Valid: true, + }, + }, + { + name: "Create Setup Key as on-off with more than one usage", + requestType: http.MethodPost, + requestPath: "/api/setup-keys", + requestBody: &api.CreateSetupKeyRequest{ + AutoGroups: nil, + ExpiresIn: testing_tools.ExpiresIn, + Name: testing_tools.NewKeyName, + Type: "one-off", + UsageLimit: 3, + }, + expectedStatus: http.StatusOK, + expectedResponse: &api.SetupKey{ + AutoGroups: []string{}, + Ephemeral: false, + Expires: time.Time{}, + Id: "", + Key: "", + LastUsed: time.Time{}, + Name: testing_tools.NewKeyName, + Revoked: false, + State: "valid", + Type: "one-off", + UpdatedAt: time.Now(), + UsageLimit: 1, + UsedTimes: 0, + Valid: true, + }, + }, + { + name: "Create Setup Key with expiration in the past", + requestType: http.MethodPost, + 
requestPath: "/api/setup-keys", + requestBody: &api.CreateSetupKeyRequest{ + AutoGroups: nil, + ExpiresIn: -testing_tools.ExpiresIn, + Name: testing_tools.NewKeyName, + Type: "one-off", + UsageLimit: 0, + }, + expectedStatus: http.StatusUnprocessableEntity, + expectedResponse: nil, + }, + { + name: "Create Setup Key with AutoGroups that do exist", + requestType: http.MethodPost, + requestPath: "/api/setup-keys", + requestBody: &api.CreateSetupKeyRequest{ + AutoGroups: []string{testing_tools.TestGroupId}, + ExpiresIn: testing_tools.ExpiresIn, + Name: testing_tools.NewKeyName, + Type: "reusable", + UsageLimit: 1, + }, + expectedStatus: http.StatusOK, + expectedResponse: &api.SetupKey{ + AutoGroups: []string{testing_tools.TestGroupId}, + Ephemeral: false, + Expires: time.Time{}, + Id: "", + Key: "", + LastUsed: time.Time{}, + Name: testing_tools.NewKeyName, + Revoked: false, + State: "valid", + Type: "reusable", + UpdatedAt: time.Now(), + UsageLimit: 1, + UsedTimes: 0, + Valid: true, + }, + }, + { + name: "Create Setup Key for ephemeral peers", + requestType: http.MethodPost, + requestPath: "/api/setup-keys", + requestBody: &api.CreateSetupKeyRequest{ + AutoGroups: []string{}, + ExpiresIn: testing_tools.ExpiresIn, + Name: testing_tools.NewKeyName, + Type: "reusable", + Ephemeral: &truePointer, + UsageLimit: 1, + }, + expectedStatus: http.StatusOK, + expectedResponse: &api.SetupKey{ + AutoGroups: []string{}, + Ephemeral: true, + Expires: time.Time{}, + Id: "", + Key: "", + LastUsed: time.Time{}, + Name: testing_tools.NewKeyName, + Revoked: false, + State: "valid", + Type: "reusable", + UpdatedAt: time.Now(), + UsageLimit: 1, + UsedTimes: 0, + Valid: true, + }, + }, + { + name: "Create Setup Key with AutoGroups that do not exist", + requestType: http.MethodPost, + requestPath: "/api/setup-keys", + requestBody: &api.CreateSetupKeyRequest{ + AutoGroups: []string{"someGroupID"}, + ExpiresIn: testing_tools.ExpiresIn, + Name: testing_tools.NewKeyName, + Type: "reusable", + UsageLimit: 0, + }, + expectedStatus: http.StatusUnprocessableEntity, + expectedResponse: nil, + }, + { + name: "Create Setup Key", + requestType: http.MethodPost, + requestPath: "/api/setup-keys", + requestBody: &api.CreateSetupKeyRequest{ + AutoGroups: nil, + ExpiresIn: testing_tools.ExpiresIn, + Name: testing_tools.NewKeyName, + Type: "reusable", + UsageLimit: 0, + }, + expectedStatus: http.StatusOK, + expectedResponse: &api.SetupKey{ + AutoGroups: []string{}, + Ephemeral: false, + Expires: time.Time{}, + Id: "", + Key: "", + LastUsed: time.Time{}, + Name: testing_tools.NewKeyName, + Revoked: false, + State: "valid", + Type: "reusable", + UpdatedAt: time.Now(), + UsageLimit: 0, + UsedTimes: 0, + Valid: true, + }, + }, + } + + for _, tc := range tt { + for _, user := range users { + t.Run(user.name+" - "+tc.name, func(t *testing.T) { + apiHandler, am, done := testing_tools.BuildApiBlackBoxWithDBState(t, "../testdata/setup_keys.sql", nil) + + body, err := json.Marshal(tc.requestBody) + if err != nil { + t.Fatalf("Failed to marshal request body: %v", err) + } + req := testing_tools.BuildRequest(t, body, tc.requestType, tc.requestPath, user.userId) + + recorder := httptest.NewRecorder() + + apiHandler.ServeHTTP(recorder, req) + + content, expectResponse := testing_tools.ReadResponse(t, recorder, tc.expectedStatus, user.expectResponse) + if !expectResponse { + return + } + got := &api.SetupKey{} + if err := json.Unmarshal(content, &got); err != nil { + t.Fatalf("Sent content is not in correct json format; %v", err) + } + + 
validateCreatedKey(t, tc.expectedResponse, got) + + key, err := am.GetSetupKey(context.Background(), testing_tools.TestAccountId, testing_tools.TestUserId, got.Id) + if err != nil { + return + } + + validateCreatedKey(t, tc.expectedResponse, setup_keys.ToResponseBody(key)) + + select { + case <-done: + case <-time.After(time.Second): + t.Error("timeout waiting for peerShouldNotReceiveUpdate") + } + }) + } + } +} + +func Test_SetupKeys_Update(t *testing.T) { + users := []struct { + name string + userId string + expectResponse bool + }{ + { + name: "Regular user", + userId: testing_tools.TestUserId, + expectResponse: false, + }, + { + name: "Admin user", + userId: testing_tools.TestAdminId, + expectResponse: true, + }, + { + name: "Owner user", + userId: testing_tools.TestOwnerId, + expectResponse: true, + }, + { + name: "Regular service user", + userId: testing_tools.TestServiceUserId, + expectResponse: false, + }, + { + name: "Admin service user", + userId: testing_tools.TestServiceAdminId, + expectResponse: true, + }, + { + name: "Blocked user", + userId: testing_tools.BlockedUserId, + expectResponse: false, + }, + { + name: "Other user", + userId: testing_tools.OtherUserId, + expectResponse: false, + }, + { + name: "Invalid token", + userId: testing_tools.InvalidToken, + expectResponse: false, + }, + } + + tt := []struct { + name string + expectedStatus int + expectedResponse *api.SetupKey + requestBody *api.SetupKeyRequest + requestType string + requestPath string + requestId string + }{ + { + name: "Add existing Group to existing Setup Key", + requestType: http.MethodPut, + requestPath: "/api/setup-keys/{id}", + requestId: testing_tools.TestKeyId, + requestBody: &api.SetupKeyRequest{ + AutoGroups: []string{testing_tools.TestGroupId, testing_tools.NewGroupId}, + Revoked: false, + }, + expectedStatus: http.StatusOK, + expectedResponse: &api.SetupKey{ + AutoGroups: []string{testing_tools.TestGroupId, testing_tools.NewGroupId}, + Ephemeral: false, + Expires: time.Time{}, + Id: "", + Key: "", + LastUsed: time.Time{}, + Name: testing_tools.ExistingKeyName, + Revoked: false, + State: "valid", + Type: "one-off", + UpdatedAt: time.Now(), + UsageLimit: 1, + UsedTimes: 0, + Valid: true, + }, + }, + { + name: "Add non-existing Group to existing Setup Key", + requestType: http.MethodPut, + requestPath: "/api/setup-keys/{id}", + requestId: testing_tools.TestKeyId, + requestBody: &api.SetupKeyRequest{ + AutoGroups: []string{testing_tools.TestGroupId, "someGroupId"}, + Revoked: false, + }, + expectedStatus: http.StatusUnprocessableEntity, + expectedResponse: nil, + }, + { + name: "Add existing Group to non-existing Setup Key", + requestType: http.MethodPut, + requestPath: "/api/setup-keys/{id}", + requestId: "someId", + requestBody: &api.SetupKeyRequest{ + AutoGroups: []string{testing_tools.TestGroupId, testing_tools.NewGroupId}, + Revoked: false, + }, + expectedStatus: http.StatusNotFound, + expectedResponse: nil, + }, + { + name: "Remove existing Group from existing Setup Key", + requestType: http.MethodPut, + requestPath: "/api/setup-keys/{id}", + requestId: testing_tools.TestKeyId, + requestBody: &api.SetupKeyRequest{ + AutoGroups: []string{}, + Revoked: false, + }, + expectedStatus: http.StatusOK, + expectedResponse: &api.SetupKey{ + AutoGroups: []string{}, + Ephemeral: false, + Expires: time.Time{}, + Id: "", + Key: "", + LastUsed: time.Time{}, + Name: testing_tools.ExistingKeyName, + Revoked: false, + State: "valid", + Type: "one-off", + UpdatedAt: time.Now(), + UsageLimit: 1, + UsedTimes: 0, + 
Valid: true, + }, + }, + { + name: "Remove existing Group to non-existing Setup Key", + requestType: http.MethodPut, + requestPath: "/api/setup-keys/{id}", + requestId: "someID", + requestBody: &api.SetupKeyRequest{ + AutoGroups: []string{}, + Revoked: false, + }, + expectedStatus: http.StatusNotFound, + expectedResponse: nil, + }, + { + name: "Revoke existing valid Setup Key", + requestType: http.MethodPut, + requestPath: "/api/setup-keys/{id}", + requestId: testing_tools.TestKeyId, + requestBody: &api.SetupKeyRequest{ + AutoGroups: []string{testing_tools.TestGroupId}, + Revoked: true, + }, + expectedStatus: http.StatusOK, + expectedResponse: &api.SetupKey{ + AutoGroups: []string{testing_tools.TestGroupId}, + Ephemeral: false, + Expires: time.Time{}, + Id: "", + Key: "", + LastUsed: time.Time{}, + Name: testing_tools.ExistingKeyName, + Revoked: true, + State: "revoked", + Type: "one-off", + UpdatedAt: time.Now(), + UsageLimit: 1, + UsedTimes: 0, + Valid: false, + }, + }, + { + name: "Revoke existing revoked Setup Key", + requestType: http.MethodPut, + requestPath: "/api/setup-keys/{id}", + requestId: testing_tools.RevokedKeyId, + requestBody: &api.SetupKeyRequest{ + AutoGroups: []string{testing_tools.TestGroupId}, + Revoked: true, + }, + expectedStatus: http.StatusOK, + expectedResponse: &api.SetupKey{ + AutoGroups: []string{testing_tools.TestGroupId}, + Ephemeral: false, + Expires: time.Time{}, + Id: "", + Key: "", + LastUsed: time.Time{}, + Name: testing_tools.ExistingKeyName, + Revoked: true, + State: "revoked", + Type: "reusable", + UpdatedAt: time.Now(), + UsageLimit: 3, + UsedTimes: 0, + Valid: false, + }, + }, + { + name: "Un-Revoke existing revoked Setup Key", + requestType: http.MethodPut, + requestPath: "/api/setup-keys/{id}", + requestId: testing_tools.RevokedKeyId, + requestBody: &api.SetupKeyRequest{ + AutoGroups: []string{testing_tools.TestGroupId}, + Revoked: false, + }, + expectedStatus: http.StatusUnprocessableEntity, + expectedResponse: nil, + }, + { + name: "Revoke existing expired Setup Key", + requestType: http.MethodPut, + requestPath: "/api/setup-keys/{id}", + requestId: testing_tools.ExpiredKeyId, + requestBody: &api.SetupKeyRequest{ + AutoGroups: []string{testing_tools.TestGroupId}, + Revoked: true, + }, + expectedStatus: http.StatusOK, + expectedResponse: &api.SetupKey{ + AutoGroups: []string{testing_tools.TestGroupId}, + Ephemeral: true, + Expires: time.Time{}, + Id: "", + Key: "", + LastUsed: time.Time{}, + Name: testing_tools.ExistingKeyName, + Revoked: true, + State: "expired", + Type: "reusable", + UpdatedAt: time.Now(), + UsageLimit: 5, + UsedTimes: 1, + Valid: false, + }, + }, + } + + for _, tc := range tt { + for _, user := range users { + t.Run(tc.name, func(t *testing.T) { + apiHandler, am, done := testing_tools.BuildApiBlackBoxWithDBState(t, "../testdata/setup_keys.sql", nil) + + body, err := json.Marshal(tc.requestBody) + if err != nil { + t.Fatalf("Failed to marshal request body: %v", err) + } + + req := testing_tools.BuildRequest(t, body, tc.requestType, strings.Replace(tc.requestPath, "{id}", tc.requestId, 1), user.userId) + + recorder := httptest.NewRecorder() + + apiHandler.ServeHTTP(recorder, req) + + content, expectResponse := testing_tools.ReadResponse(t, recorder, tc.expectedStatus, user.expectResponse) + if !expectResponse { + return + } + got := &api.SetupKey{} + if err := json.Unmarshal(content, &got); err != nil { + t.Fatalf("Sent content is not in correct json format; %v", err) + } + + validateCreatedKey(t, tc.expectedResponse, got) + + 
key, err := am.GetSetupKey(context.Background(), testing_tools.TestAccountId, testing_tools.TestUserId, got.Id) + if err != nil { + return + } + + validateCreatedKey(t, tc.expectedResponse, setup_keys.ToResponseBody(key)) + + select { + case <-done: + case <-time.After(time.Second): + t.Error("timeout waiting for peerShouldNotReceiveUpdate") + } + }) + } + } +} + +func Test_SetupKeys_Get(t *testing.T) { + users := []struct { + name string + userId string + expectResponse bool + }{ + { + name: "Regular user", + userId: testing_tools.TestUserId, + expectResponse: false, + }, + { + name: "Admin user", + userId: testing_tools.TestAdminId, + expectResponse: true, + }, + { + name: "Owner user", + userId: testing_tools.TestOwnerId, + expectResponse: true, + }, + { + name: "Regular service user", + userId: testing_tools.TestServiceUserId, + expectResponse: false, + }, + { + name: "Admin service user", + userId: testing_tools.TestServiceAdminId, + expectResponse: true, + }, + { + name: "Blocked user", + userId: testing_tools.BlockedUserId, + expectResponse: false, + }, + { + name: "Other user", + userId: testing_tools.OtherUserId, + expectResponse: false, + }, + { + name: "Invalid token", + userId: testing_tools.InvalidToken, + expectResponse: false, + }, + } + + tt := []struct { + name string + expectedStatus int + expectedResponse *api.SetupKey + requestType string + requestPath string + requestId string + }{ + { + name: "Get existing valid Setup Key", + requestType: http.MethodGet, + requestPath: "/api/setup-keys/{id}", + requestId: testing_tools.TestKeyId, + expectedStatus: http.StatusOK, + expectedResponse: &api.SetupKey{ + AutoGroups: []string{testing_tools.TestGroupId}, + Ephemeral: false, + Expires: time.Time{}, + Id: "", + Key: "", + LastUsed: time.Time{}, + Name: testing_tools.ExistingKeyName, + Revoked: false, + State: "valid", + Type: "one-off", + UpdatedAt: time.Date(2021, time.August, 19, 20, 46, 20, 0, time.UTC), + UsageLimit: 1, + UsedTimes: 0, + Valid: true, + }, + }, + { + name: "Get existing expired Setup Key", + requestType: http.MethodGet, + requestPath: "/api/setup-keys/{id}", + requestId: testing_tools.ExpiredKeyId, + expectedStatus: http.StatusOK, + expectedResponse: &api.SetupKey{ + AutoGroups: []string{testing_tools.TestGroupId}, + Ephemeral: true, + Expires: time.Time{}, + Id: "", + Key: "", + LastUsed: time.Time{}, + Name: testing_tools.ExistingKeyName, + Revoked: false, + State: "expired", + Type: "reusable", + UpdatedAt: time.Date(2021, time.August, 19, 20, 46, 20, 0, time.UTC), + UsageLimit: 5, + UsedTimes: 1, + Valid: false, + }, + }, + { + name: "Get existing revoked Setup Key", + requestType: http.MethodGet, + requestPath: "/api/setup-keys/{id}", + requestId: testing_tools.RevokedKeyId, + expectedStatus: http.StatusOK, + expectedResponse: &api.SetupKey{ + AutoGroups: []string{testing_tools.TestGroupId}, + Ephemeral: false, + Expires: time.Time{}, + Id: "", + Key: "", + LastUsed: time.Time{}, + Name: testing_tools.ExistingKeyName, + Revoked: true, + State: "revoked", + Type: "reusable", + UpdatedAt: time.Date(2021, time.August, 19, 20, 46, 20, 0, time.UTC), + UsageLimit: 3, + UsedTimes: 0, + Valid: false, + }, + }, + { + name: "Get non-existing Setup Key", + requestType: http.MethodGet, + requestPath: "/api/setup-keys/{id}", + requestId: "someId", + expectedStatus: http.StatusNotFound, + expectedResponse: nil, + }, + } + + for _, tc := range tt { + for _, user := range users { + t.Run(tc.name, func(t *testing.T) { + apiHandler, am, done := 
testing_tools.BuildApiBlackBoxWithDBState(t, "../testdata/setup_keys.sql", nil) + + req := testing_tools.BuildRequest(t, []byte{}, tc.requestType, strings.Replace(tc.requestPath, "{id}", tc.requestId, 1), user.userId) + + recorder := httptest.NewRecorder() + + apiHandler.ServeHTTP(recorder, req) + + content, expectRespnose := testing_tools.ReadResponse(t, recorder, tc.expectedStatus, user.expectResponse) + if !expectRespnose { + return + } + got := &api.SetupKey{} + if err := json.Unmarshal(content, &got); err != nil { + t.Fatalf("Sent content is not in correct json format; %v", err) + } + + validateCreatedKey(t, tc.expectedResponse, got) + + key, err := am.GetSetupKey(context.Background(), testing_tools.TestAccountId, testing_tools.TestUserId, got.Id) + if err != nil { + return + } + + validateCreatedKey(t, tc.expectedResponse, setup_keys.ToResponseBody(key)) + + select { + case <-done: + case <-time.After(time.Second): + t.Error("timeout waiting for peerShouldNotReceiveUpdate") + } + }) + } + } +} + +func Test_SetupKeys_GetAll(t *testing.T) { + users := []struct { + name string + userId string + expectResponse bool + }{ + { + name: "Regular user", + userId: testing_tools.TestUserId, + expectResponse: false, + }, + { + name: "Admin user", + userId: testing_tools.TestAdminId, + expectResponse: true, + }, + { + name: "Owner user", + userId: testing_tools.TestOwnerId, + expectResponse: true, + }, + { + name: "Regular service user", + userId: testing_tools.TestServiceUserId, + expectResponse: false, + }, + { + name: "Admin service user", + userId: testing_tools.TestServiceAdminId, + expectResponse: true, + }, + { + name: "Blocked user", + userId: testing_tools.BlockedUserId, + expectResponse: false, + }, + { + name: "Other user", + userId: testing_tools.OtherUserId, + expectResponse: false, + }, + { + name: "Invalid token", + userId: testing_tools.InvalidToken, + expectResponse: false, + }, + } + + tt := []struct { + name string + expectedStatus int + expectedResponse []*api.SetupKey + requestType string + requestPath string + }{ + { + name: "Get all Setup Keys", + requestType: http.MethodGet, + requestPath: "/api/setup-keys", + expectedStatus: http.StatusOK, + expectedResponse: []*api.SetupKey{ + { + AutoGroups: []string{testing_tools.TestGroupId}, + Ephemeral: false, + Expires: time.Time{}, + Id: "", + Key: "", + LastUsed: time.Time{}, + Name: testing_tools.ExistingKeyName, + Revoked: false, + State: "valid", + Type: "one-off", + UpdatedAt: time.Date(2021, time.August, 19, 20, 46, 20, 0, time.UTC), + UsageLimit: 1, + UsedTimes: 0, + Valid: true, + }, + { + AutoGroups: []string{testing_tools.TestGroupId}, + Ephemeral: false, + Expires: time.Time{}, + Id: "", + Key: "", + LastUsed: time.Time{}, + Name: testing_tools.ExistingKeyName, + Revoked: true, + State: "revoked", + Type: "reusable", + UpdatedAt: time.Date(2021, time.August, 19, 20, 46, 20, 0, time.UTC), + UsageLimit: 3, + UsedTimes: 0, + Valid: false, + }, + { + AutoGroups: []string{testing_tools.TestGroupId}, + Ephemeral: true, + Expires: time.Time{}, + Id: "", + Key: "", + LastUsed: time.Time{}, + Name: testing_tools.ExistingKeyName, + Revoked: false, + State: "expired", + Type: "reusable", + UpdatedAt: time.Date(2021, time.August, 19, 20, 46, 20, 0, time.UTC), + UsageLimit: 5, + UsedTimes: 1, + Valid: false, + }, + }, + }, + } + + for _, tc := range tt { + for _, user := range users { + t.Run(tc.name, func(t *testing.T) { + apiHandler, am, done := testing_tools.BuildApiBlackBoxWithDBState(t, "../testdata/setup_keys.sql", nil) + + req 
:= testing_tools.BuildRequest(t, []byte{}, tc.requestType, tc.requestPath, user.userId) + + recorder := httptest.NewRecorder() + + apiHandler.ServeHTTP(recorder, req) + + content, expectResponse := testing_tools.ReadResponse(t, recorder, tc.expectedStatus, user.expectResponse) + if !expectResponse { + return + } + got := []api.SetupKey{} + if err := json.Unmarshal(content, &got); err != nil { + t.Fatalf("Sent content is not in correct json format; %v", err) + } + + sort.Slice(got, func(i, j int) bool { + return got[i].UsageLimit < got[j].UsageLimit + }) + + sort.Slice(tc.expectedResponse, func(i, j int) bool { + return tc.expectedResponse[i].UsageLimit < tc.expectedResponse[j].UsageLimit + }) + + for i, _ := range tc.expectedResponse { + validateCreatedKey(t, tc.expectedResponse[i], &got[i]) + + key, err := am.GetSetupKey(context.Background(), testing_tools.TestAccountId, testing_tools.TestUserId, got[i].Id) + if err != nil { + return + } + + validateCreatedKey(t, tc.expectedResponse[i], setup_keys.ToResponseBody(key)) + } + + select { + case <-done: + case <-time.After(time.Second): + t.Error("timeout waiting for peerShouldNotReceiveUpdate") + } + }) + } + } +} + +func Test_SetupKeys_Delete(t *testing.T) { + users := []struct { + name string + userId string + expectResponse bool + }{ + { + name: "Regular user", + userId: testing_tools.TestUserId, + expectResponse: false, + }, + { + name: "Admin user", + userId: testing_tools.TestAdminId, + expectResponse: true, + }, + { + name: "Owner user", + userId: testing_tools.TestOwnerId, + expectResponse: true, + }, + { + name: "Regular service user", + userId: testing_tools.TestServiceUserId, + expectResponse: false, + }, + { + name: "Admin service user", + userId: testing_tools.TestServiceAdminId, + expectResponse: true, + }, + { + name: "Blocked user", + userId: testing_tools.BlockedUserId, + expectResponse: false, + }, + { + name: "Other user", + userId: testing_tools.OtherUserId, + expectResponse: false, + }, + { + name: "Invalid token", + userId: testing_tools.InvalidToken, + expectResponse: false, + }, + } + + tt := []struct { + name string + expectedStatus int + expectedResponse *api.SetupKey + requestType string + requestPath string + requestId string + }{ + { + name: "Delete existing valid Setup Key", + requestType: http.MethodDelete, + requestPath: "/api/setup-keys/{id}", + requestId: testing_tools.TestKeyId, + expectedStatus: http.StatusOK, + expectedResponse: &api.SetupKey{ + AutoGroups: []string{testing_tools.TestGroupId}, + Ephemeral: false, + Expires: time.Time{}, + Id: "", + Key: "", + LastUsed: time.Time{}, + Name: testing_tools.ExistingKeyName, + Revoked: false, + State: "valid", + Type: "one-off", + UpdatedAt: time.Date(2021, time.August, 19, 20, 46, 20, 0, time.UTC), + UsageLimit: 1, + UsedTimes: 0, + Valid: true, + }, + }, + { + name: "Delete existing expired Setup Key", + requestType: http.MethodDelete, + requestPath: "/api/setup-keys/{id}", + requestId: testing_tools.ExpiredKeyId, + expectedStatus: http.StatusOK, + expectedResponse: &api.SetupKey{ + AutoGroups: []string{testing_tools.TestGroupId}, + Ephemeral: true, + Expires: time.Time{}, + Id: "", + Key: "", + LastUsed: time.Time{}, + Name: testing_tools.ExistingKeyName, + Revoked: false, + State: "expired", + Type: "reusable", + UpdatedAt: time.Date(2021, time.August, 19, 20, 46, 20, 0, time.UTC), + UsageLimit: 5, + UsedTimes: 1, + Valid: false, + }, + }, + { + name: "Delete existing revoked Setup Key", + requestType: http.MethodDelete, + requestPath: 
"/api/setup-keys/{id}", + requestId: testing_tools.RevokedKeyId, + expectedStatus: http.StatusOK, + expectedResponse: &api.SetupKey{ + AutoGroups: []string{testing_tools.TestGroupId}, + Ephemeral: false, + Expires: time.Time{}, + Id: "", + Key: "", + LastUsed: time.Time{}, + Name: testing_tools.ExistingKeyName, + Revoked: true, + State: "revoked", + Type: "reusable", + UpdatedAt: time.Date(2021, time.August, 19, 20, 46, 20, 0, time.UTC), + UsageLimit: 3, + UsedTimes: 0, + Valid: false, + }, + }, + { + name: "Delete non-existing Setup Key", + requestType: http.MethodDelete, + requestPath: "/api/setup-keys/{id}", + requestId: "someId", + expectedStatus: http.StatusNotFound, + expectedResponse: nil, + }, + } + + for _, tc := range tt { + for _, user := range users { + t.Run(tc.name, func(t *testing.T) { + apiHandler, am, done := testing_tools.BuildApiBlackBoxWithDBState(t, "../testdata/setup_keys.sql", nil) + + req := testing_tools.BuildRequest(t, []byte{}, tc.requestType, strings.Replace(tc.requestPath, "{id}", tc.requestId, 1), user.userId) + + recorder := httptest.NewRecorder() + + apiHandler.ServeHTTP(recorder, req) + + content, expectResponse := testing_tools.ReadResponse(t, recorder, tc.expectedStatus, user.expectResponse) + if !expectResponse { + return + } + got := &api.SetupKey{} + if err := json.Unmarshal(content, &got); err != nil { + t.Fatalf("Sent content is not in correct json format; %v", err) + } + + _, err := am.GetSetupKey(context.Background(), testing_tools.TestAccountId, testing_tools.TestUserId, got.Id) + assert.Errorf(t, err, "Expected error when trying to get deleted key") + + select { + case <-done: + case <-time.After(time.Second): + t.Error("timeout waiting for peerShouldNotReceiveUpdate") + } + }) + } + } +} + +func validateCreatedKey(t *testing.T, expectedKey *api.SetupKey, got *api.SetupKey) { + t.Helper() + + if got.Expires.After(time.Now().Add(-1*time.Minute)) && got.Expires.Before(time.Now().Add(testing_tools.ExpiresIn*time.Second)) || + got.Expires.After(time.Date(2300, 01, 01, 0, 0, 0, 0, time.Local)) || + got.Expires.Before(time.Date(1950, 01, 01, 0, 0, 0, 0, time.Local)) { + got.Expires = time.Time{} + expectedKey.Expires = time.Time{} + } + + if got.Id == "" { + t.Fatalf("Expected key to have an ID") + } + got.Id = "" + + if got.Key == "" { + t.Fatalf("Expected key to have a key") + } + got.Key = "" + + if got.UpdatedAt.After(time.Now().Add(-1*time.Minute)) && got.UpdatedAt.Before(time.Now().Add(+1*time.Minute)) { + got.UpdatedAt = time.Time{} + expectedKey.UpdatedAt = time.Time{} + } + + expectedKey.UpdatedAt = expectedKey.UpdatedAt.In(time.UTC) + got.UpdatedAt = got.UpdatedAt.In(time.UTC) + + assert.Equal(t, expectedKey, got) +} diff --git a/management/server/http/testing/testdata/setup_keys.sql b/management/server/http/testing/testdata/setup_keys.sql new file mode 100644 index 000000000..a315ea0f7 --- /dev/null +++ b/management/server/http/testing/testdata/setup_keys.sql @@ -0,0 +1,24 @@ +CREATE TABLE `accounts` (`id` text,`created_by` text,`created_at` datetime,`domain` text,`domain_category` text,`is_domain_primary_account` numeric,`network_identifier` text,`network_net` text,`network_dns` text,`network_serial` integer,`dns_settings_disabled_management_groups` text,`settings_peer_login_expiration_enabled` numeric,`settings_peer_login_expiration` integer,`settings_regular_users_view_blocked` numeric,`settings_groups_propagation_enabled` numeric,`settings_jwt_groups_enabled` numeric,`settings_jwt_groups_claim_name` text,`settings_jwt_allow_groups` 
text,`settings_extra_peer_approval_enabled` numeric,`settings_extra_integrated_validator_groups` text,PRIMARY KEY (`id`)); +CREATE TABLE `users` (`id` text,`account_id` text,`role` text,`is_service_user` numeric,`non_deletable` numeric,`service_user_name` text,`auto_groups` text,`blocked` numeric,`last_login` datetime,`created_at` datetime,`issued` text DEFAULT "api",`integration_ref_id` integer,`integration_ref_integration_type` text,PRIMARY KEY (`id`),CONSTRAINT `fk_accounts_users_g` FOREIGN KEY (`account_id`) REFERENCES `accounts`(`id`)); +CREATE TABLE `peers` (`id` text,`account_id` text,`key` text,`setup_key` text,`ip` text,`meta_hostname` text,`meta_go_os` text,`meta_kernel` text,`meta_core` text,`meta_platform` text,`meta_os` text,`meta_os_version` text,`meta_wt_version` text,`meta_ui_version` text,`meta_kernel_version` text,`meta_network_addresses` text,`meta_system_serial_number` text,`meta_system_product_name` text,`meta_system_manufacturer` text,`meta_environment` text,`meta_files` text,`name` text,`dns_label` text,`peer_status_last_seen` datetime,`peer_status_connected` numeric,`peer_status_login_expired` numeric,`peer_status_requires_approval` numeric,`user_id` text,`ssh_key` text,`ssh_enabled` numeric,`login_expiration_enabled` numeric,`last_login` datetime,`created_at` datetime,`ephemeral` numeric,`location_connection_ip` text,`location_country_code` text,`location_city_name` text,`location_geo_name_id` integer,PRIMARY KEY (`id`),CONSTRAINT `fk_accounts_peers_g` FOREIGN KEY (`account_id`) REFERENCES `accounts`(`id`)); +CREATE TABLE `groups` (`id` text,`account_id` text,`name` text,`issued` text,`peers` text,`integration_ref_id` integer,`integration_ref_integration_type` text,PRIMARY KEY (`id`),CONSTRAINT `fk_accounts_groups_g` FOREIGN KEY (`account_id`) REFERENCES `accounts`(`id`)); + +INSERT INTO accounts VALUES('testAccountId','','2024-10-02 16:01:38.000000000+00:00','test.com','private',1,'testNetworkIdentifier','{"IP":"100.64.0.0","Mask":"//8AAA=="}','',0,'[]',0,86400000000000,0,0,0,'',NULL,NULL,NULL); +INSERT INTO users VALUES('testUserId','testAccountId','user',0,0,'','[]',0,'0001-01-01 00:00:00+00:00','2024-10-02 16:01:38.000000000+00:00','api',0,''); +INSERT INTO users VALUES('testAdminId','testAccountId','admin',0,0,'','[]',0,'0001-01-01 00:00:00+00:00','2024-10-02 16:01:38.000000000+00:00','api',0,''); +INSERT INTO users VALUES('testOwnerId','testAccountId','owner',0,0,'','[]',0,'0001-01-01 00:00:00+00:00','2024-10-02 16:01:38.000000000+00:00','api',0,''); +INSERT INTO users VALUES('testServiceUserId','testAccountId','user',1,0,'','[]',0,'0001-01-01 00:00:00+00:00','2024-10-02 16:01:38.000000000+00:00','api',0,''); +INSERT INTO users VALUES('testServiceAdminId','testAccountId','admin',1,0,'','[]',0,'0001-01-01 00:00:00+00:00','2024-10-02 16:01:38.000000000+00:00','api',0,''); +INSERT INTO users VALUES('blockedUserId','testAccountId','admin',0,0,'','[]',1,'0001-01-01 00:00:00+00:00','2024-10-02 16:01:38.000000000+00:00','api',0,''); +INSERT INTO users VALUES('otherUserId','otherAccountId','admin',0,0,'','[]',0,'0001-01-01 00:00:00+00:00','2024-10-02 16:01:38.000000000+00:00','api',0,''); +INSERT INTO peers VALUES('testPeerId','testAccountId','5rvhvriKJZ3S9oxYToVj5TzDM9u9y8cxg7htIMWlYAg=','72546A29-6BC8-4311-BCFC-9CDBF33F1A48','"100.64.114.31"','f2a34f6a4731','linux','Linux','11','unknown','Debian GNU/Linux','','0.12.0','','',NULL,'','','','{"Cloud":"","Platform":""}',NULL,'f2a34f6a4731','f2a34f6a4731','2023-03-02 09:21:02.189035775+01:00',0,0,0,'','ssh-ed25519 
AAAAC3NzaC1lZDI1NTE5AAAAILzUUSYG/LGnV8zarb2SGN+tib/PZ+M7cL4WtTzUrTpk',0,1,'2023-03-01 19:48:19.817799698+01:00','2024-10-02 17:00:32.527947+02:00',0,'""','','',0); +INSERT INTO "groups" VALUES('testGroupId','testAccountId','testGroupName','api','[]',0,''); +INSERT INTO "groups" VALUES('newGroupId','testAccountId','newGroupName','api','[]',0,''); + + +CREATE TABLE `setup_keys` (`id` text,`account_id` text,`key` text,`key_secret` text,`name` text,`type` text,`created_at` datetime,`expires_at` datetime,`updated_at` datetime,`revoked` numeric,`used_times` integer,`last_used` datetime,`auto_groups` text,`usage_limit` integer,`ephemeral` numeric,PRIMARY KEY (`id`),CONSTRAINT `fk_accounts_setup_keys_g` FOREIGN KEY (`account_id`) REFERENCES `accounts`(`id`)); + +INSERT INTO setup_keys VALUES('testKeyId','testAccountId','testKey','testK****','existingKey','one-off','2021-08-19 20:46:20.000000000+00:00','2321-09-18 20:46:20.000000000+00:00','2021-08-19 20:46:20.000000000+00:000',0,0,'0001-01-01 00:00:00+00:00','["testGroupId"]',1,0); +INSERT INTO setup_keys VALUES('revokedKeyId','testAccountId','revokedKey','testK****','existingKey','reusable','2021-08-19 20:46:20.000000000+00:00','2321-09-18 20:46:20.000000000+00:00','2021-08-19 20:46:20.000000000+00:00',1,0,'0001-01-01 00:00:00+00:00','["testGroupId"]',3,0); +INSERT INTO setup_keys VALUES('expiredKeyId','testAccountId','expiredKey','testK****','existingKey','reusable','2021-08-19 20:46:20.000000000+00:00','1921-09-18 20:46:20.000000000+00:00','2021-08-19 20:46:20.000000000+00:00',0,1,'0001-01-01 00:00:00+00:00','["testGroupId"]',5,1); + diff --git a/management/server/http/testing/testing_tools/tools.go b/management/server/http/testing/testing_tools/tools.go new file mode 100644 index 000000000..da910c5c3 --- /dev/null +++ b/management/server/http/testing/testing_tools/tools.go @@ -0,0 +1,307 @@ +package testing_tools + +import ( + "bytes" + "context" + "fmt" + "io" + "net" + "net/http" + "net/http/httptest" + "os" + "strconv" + "testing" + "time" + + "github.com/stretchr/testify/assert" + "golang.zx2c4.com/wireguard/wgctrl/wgtypes" + + "github.com/netbirdio/netbird/management/server" + "github.com/netbirdio/netbird/management/server/activity" + "github.com/netbirdio/netbird/management/server/geolocation" + "github.com/netbirdio/netbird/management/server/groups" + nbhttp "github.com/netbirdio/netbird/management/server/http" + "github.com/netbirdio/netbird/management/server/http/configs" + "github.com/netbirdio/netbird/management/server/jwtclaims" + "github.com/netbirdio/netbird/management/server/networks" + "github.com/netbirdio/netbird/management/server/networks/resources" + "github.com/netbirdio/netbird/management/server/networks/routers" + nbpeer "github.com/netbirdio/netbird/management/server/peer" + "github.com/netbirdio/netbird/management/server/posture" + "github.com/netbirdio/netbird/management/server/store" + "github.com/netbirdio/netbird/management/server/telemetry" + "github.com/netbirdio/netbird/management/server/types" +) + +const ( + TestAccountId = "testAccountId" + TestPeerId = "testPeerId" + TestGroupId = "testGroupId" + TestKeyId = "testKeyId" + + TestUserId = "testUserId" + TestAdminId = "testAdminId" + TestOwnerId = "testOwnerId" + TestServiceUserId = "testServiceUserId" + TestServiceAdminId = "testServiceAdminId" + BlockedUserId = "blockedUserId" + OtherUserId = "otherUserId" + InvalidToken = "invalidToken" + + NewKeyName = "newKey" + NewGroupId = "newGroupId" + ExpiresIn = 3600 + RevokedKeyId = "revokedKeyId" + ExpiredKeyId = 
"expiredKeyId" + + ExistingKeyName = "existingKey" +) + +type TB interface { + Cleanup(func()) + Helper() + Errorf(format string, args ...any) + Fatalf(format string, args ...any) + TempDir() string +} + +// BenchmarkCase defines a single benchmark test case +type BenchmarkCase struct { + Peers int + Groups int + Users int + SetupKeys int +} + +// PerformanceMetrics holds the performance expectations +type PerformanceMetrics struct { + MinMsPerOpLocal float64 + MaxMsPerOpLocal float64 + MinMsPerOpCICD float64 + MaxMsPerOpCICD float64 +} + +func BuildApiBlackBoxWithDBState(t TB, sqlFile string, expectedPeerUpdate *server.UpdateMessage) (http.Handler, server.AccountManager, chan struct{}) { + store, cleanup, err := store.NewTestStoreFromSQL(context.Background(), sqlFile, t.TempDir()) + if err != nil { + t.Fatalf("Failed to create test store: %v", err) + } + t.Cleanup(cleanup) + + metrics, err := telemetry.NewDefaultAppMetrics(context.Background()) + if err != nil { + t.Fatalf("Failed to create metrics: %v", err) + } + + peersUpdateManager := server.NewPeersUpdateManager(nil) + updMsg := peersUpdateManager.CreateChannel(context.Background(), TestPeerId) + done := make(chan struct{}) + go func() { + if expectedPeerUpdate != nil { + peerShouldReceiveUpdate(t, updMsg, expectedPeerUpdate) + } else { + peerShouldNotReceiveUpdate(t, updMsg) + } + close(done) + }() + + geoMock := &geolocation.Mock{} + validatorMock := server.MocIntegratedValidator{} + am, err := server.BuildManager(context.Background(), store, peersUpdateManager, nil, "", "", &activity.InMemoryEventStore{}, geoMock, false, validatorMock, metrics) + if err != nil { + t.Fatalf("Failed to create manager: %v", err) + } + + networksManagerMock := networks.NewManagerMock() + resourcesManagerMock := resources.NewManagerMock() + routersManagerMock := routers.NewManagerMock() + groupsManagerMock := groups.NewManagerMock() + apiHandler, err := nbhttp.NewAPIHandler(context.Background(), am, networksManagerMock, resourcesManagerMock, routersManagerMock, groupsManagerMock, geoMock, &jwtclaims.JwtValidatorMock{}, metrics, configs.AuthCfg{}, validatorMock) + if err != nil { + t.Fatalf("Failed to create API handler: %v", err) + } + + return apiHandler, am, done +} + +func peerShouldNotReceiveUpdate(t TB, updateMessage <-chan *server.UpdateMessage) { + t.Helper() + select { + case msg := <-updateMessage: + t.Errorf("Unexpected message received: %+v", msg) + case <-time.After(500 * time.Millisecond): + return + } +} + +func peerShouldReceiveUpdate(t TB, updateMessage <-chan *server.UpdateMessage, expected *server.UpdateMessage) { + t.Helper() + + select { + case msg := <-updateMessage: + if msg == nil { + t.Errorf("Received nil update message, expected valid message") + } + assert.Equal(t, expected, msg) + case <-time.After(500 * time.Millisecond): + t.Errorf("Timed out waiting for update message") + } +} + +func BuildRequest(t TB, requestBody []byte, requestType, requestPath, user string) *http.Request { + t.Helper() + + req := httptest.NewRequest(requestType, requestPath, bytes.NewBuffer(requestBody)) + req.Header.Set("Authorization", "Bearer "+user) + + return req +} + +func ReadResponse(t *testing.T, recorder *httptest.ResponseRecorder, expectedStatus int, expectResponse bool) ([]byte, bool) { + t.Helper() + + res := recorder.Result() + defer res.Body.Close() + + content, err := io.ReadAll(res.Body) + if err != nil { + t.Fatalf("Failed to read response body: %v", err) + } + + if !expectResponse { + return nil, false + } + + if status := 
recorder.Code; status != expectedStatus {
+        t.Fatalf("handler returned wrong status code: got %v want %v, content: %s",
+            status, expectedStatus, string(content))
+    }
+
+    return content, expectedStatus == http.StatusOK
+}
+
+func PopulateTestData(b *testing.B, am *server.DefaultAccountManager, peers, groups, users, setupKeys int) {
+    b.Helper()
+
+    ctx := context.Background()
+    account, err := am.GetAccount(ctx, TestAccountId)
+    if err != nil {
+        b.Fatalf("Failed to get account: %v", err)
+    }
+
+    // Create peers
+    for i := 0; i < peers; i++ {
+        peerKey, _ := wgtypes.GeneratePrivateKey()
+        peer := &nbpeer.Peer{
+            ID:       fmt.Sprintf("oldpeer-%d", i),
+            DNSLabel: fmt.Sprintf("oldpeer-%d", i),
+            Key:      peerKey.PublicKey().String(),
+            IP:       net.ParseIP(fmt.Sprintf("100.64.%d.%d", i/256, i%256)),
+            Status:   &nbpeer.PeerStatus{},
+            UserID:   TestUserId,
+        }
+        account.Peers[peer.ID] = peer
+    }
+
+    // Create users
+    for i := 0; i < users; i++ {
+        user := &types.User{
+            Id:        fmt.Sprintf("olduser-%d", i),
+            AccountID: account.Id,
+            Role:      types.UserRoleUser,
+        }
+        account.Users[user.Id] = user
+    }
+
+    for i := 0; i < setupKeys; i++ {
+        key := &types.SetupKey{
+            Id:         fmt.Sprintf("oldkey-%d", i),
+            AccountID:  account.Id,
+            AutoGroups: []string{"someGroupID"},
+            ExpiresAt:  time.Now().Add(ExpiresIn * time.Second),
+            Name:       NewKeyName + strconv.Itoa(i),
+            Type:       "reusable",
+            UsageLimit: 0,
+        }
+        account.SetupKeys[key.Id] = key
+    }
+
+    // Create groups and policies
+    account.Policies = make([]*types.Policy, 0, groups)
+    for i := 0; i < groups; i++ {
+        groupID := fmt.Sprintf("group-%d", i)
+        group := &types.Group{
+            ID:   groupID,
+            Name: fmt.Sprintf("Group %d", i),
+        }
+        for j := 0; j < peers/groups; j++ {
+            peerIndex := i*(peers/groups) + j
+            group.Peers = append(group.Peers, fmt.Sprintf("peer-%d", peerIndex))
+        }
+        account.Groups[groupID] = group
+
+        // Create a policy for this group
+        policy := &types.Policy{
+            ID:      fmt.Sprintf("policy-%d", i),
+            Name:    fmt.Sprintf("Policy for Group %d", i),
+            Enabled: true,
+            Rules: []*types.PolicyRule{
+                {
+                    ID:            fmt.Sprintf("rule-%d", i),
+                    Name:          fmt.Sprintf("Rule for Group %d", i),
+                    Enabled:       true,
+                    Sources:       []string{groupID},
+                    Destinations:  []string{groupID},
+                    Bidirectional: true,
+                    Protocol:      types.PolicyRuleProtocolALL,
+                    Action:        types.PolicyTrafficActionAccept,
+                },
+            },
+        }
+        account.Policies = append(account.Policies, policy)
+    }
+
+    account.PostureChecks = []*posture.Checks{
+        {
+            ID:   "PostureChecksAll",
+            Name: "All",
+            Checks: posture.ChecksDefinition{
+                NBVersionCheck: &posture.NBVersionCheck{
+                    MinVersion: "0.0.1",
+                },
+            },
+        },
+    }
+
+    err = am.Store.SaveAccount(context.Background(), account)
+    if err != nil {
+        b.Fatalf("Failed to save account: %v", err)
+    }
+
+}
+
+func EvaluateBenchmarkResults(b *testing.B, name string, duration time.Duration, perfMetrics PerformanceMetrics, recorder *httptest.ResponseRecorder) {
+    b.Helper()
+
+    if recorder.Code != http.StatusOK {
+        b.Fatalf("Benchmark %s failed: unexpected status code %d", name, recorder.Code)
+    }
+
+    msPerOp := float64(duration.Nanoseconds()) / float64(b.N) / 1e6
+    b.ReportMetric(msPerOp, "ms/op")
+
+    minExpected := perfMetrics.MinMsPerOpLocal
+    maxExpected := perfMetrics.MaxMsPerOpLocal
+    if os.Getenv("CI") == "true" {
+        minExpected = perfMetrics.MinMsPerOpCICD
+        maxExpected = perfMetrics.MaxMsPerOpCICD
+    }
+
+    if msPerOp < minExpected {
+        b.Fatalf("Benchmark %s failed: too fast (%.2f ms/op, minimum %.2f ms/op)", name, msPerOp, minExpected)
+    }
+
+    if msPerOp > maxExpected {
+        b.Fatalf("Benchmark %s failed: too
slow (%.2f ms/op, maximum %.2f ms/op)", name, msPerOp, maxExpected) + } +} diff --git a/management/server/integrated_validator.go b/management/server/integrated_validator.go index 47c4ca6ae..62e9213f7 100644 --- a/management/server/integrated_validator.go +++ b/management/server/integrated_validator.go @@ -7,6 +7,7 @@ import ( log "github.com/sirupsen/logrus" "github.com/netbirdio/netbird/management/server/account" + nbpeer "github.com/netbirdio/netbird/management/server/peer" "github.com/netbirdio/netbird/management/server/store" "github.com/netbirdio/netbird/management/server/types" ) @@ -78,3 +79,45 @@ func (am *DefaultAccountManager) GroupValidation(ctx context.Context, accountID func (am *DefaultAccountManager) GetValidatedPeers(account *types.Account) (map[string]struct{}, error) { return am.integratedPeerValidator.GetValidatedPeers(account.Id, account.Groups, account.Peers, account.Settings.Extra) } + +type MocIntegratedValidator struct { + ValidatePeerFunc func(_ context.Context, update *nbpeer.Peer, peer *nbpeer.Peer, userID string, accountID string, dnsDomain string, peersGroup []string, extraSettings *account.ExtraSettings) (*nbpeer.Peer, bool, error) +} + +func (a MocIntegratedValidator) ValidateExtraSettings(_ context.Context, newExtraSettings *account.ExtraSettings, oldExtraSettings *account.ExtraSettings, peers map[string]*nbpeer.Peer, userID string, accountID string) error { + return nil +} + +func (a MocIntegratedValidator) ValidatePeer(_ context.Context, update *nbpeer.Peer, peer *nbpeer.Peer, userID string, accountID string, dnsDomain string, peersGroup []string, extraSettings *account.ExtraSettings) (*nbpeer.Peer, bool, error) { + if a.ValidatePeerFunc != nil { + return a.ValidatePeerFunc(context.Background(), update, peer, userID, accountID, dnsDomain, peersGroup, extraSettings) + } + return update, false, nil +} +func (a MocIntegratedValidator) GetValidatedPeers(accountID string, groups map[string]*types.Group, peers map[string]*nbpeer.Peer, extraSettings *account.ExtraSettings) (map[string]struct{}, error) { + validatedPeers := make(map[string]struct{}) + for _, peer := range peers { + validatedPeers[peer.ID] = struct{}{} + } + return validatedPeers, nil +} + +func (MocIntegratedValidator) PreparePeer(_ context.Context, accountID string, peer *nbpeer.Peer, peersGroup []string, extraSettings *account.ExtraSettings) *nbpeer.Peer { + return peer +} + +func (MocIntegratedValidator) IsNotValidPeer(_ context.Context, accountID string, peer *nbpeer.Peer, peersGroup []string, extraSettings *account.ExtraSettings) (bool, bool, error) { + return false, false, nil +} + +func (MocIntegratedValidator) PeerDeleted(_ context.Context, _, _ string) error { + return nil +} + +func (MocIntegratedValidator) SetPeerInvalidationListener(func(accountID string)) { + // just a dummy +} + +func (MocIntegratedValidator) Stop(_ context.Context) { + // just a dummy +} diff --git a/management/server/jwtclaims/jwtValidator.go b/management/server/jwtclaims/jwtValidator.go index b91616fa5..79e59e76f 100644 --- a/management/server/jwtclaims/jwtValidator.go +++ b/management/server/jwtclaims/jwtValidator.go @@ -72,15 +72,19 @@ type JSONWebKey struct { X5c []string `json:"x5c"` } -// JWTValidator struct to handle token validation and parsing -type JWTValidator struct { +type JWTValidator interface { + ValidateAndParse(ctx context.Context, token string) (*jwt.Token, error) +} + +// jwtValidatorImpl struct to handle token validation and parsing +type jwtValidatorImpl struct { options Options } var 
keyNotFound = errors.New("unable to find appropriate key") // NewJWTValidator constructor -func NewJWTValidator(ctx context.Context, issuer string, audienceList []string, keysLocation string, idpSignkeyRefreshEnabled bool) (*JWTValidator, error) { +func NewJWTValidator(ctx context.Context, issuer string, audienceList []string, keysLocation string, idpSignkeyRefreshEnabled bool) (JWTValidator, error) { keys, err := getPemKeys(ctx, keysLocation) if err != nil { return nil, err @@ -146,13 +150,13 @@ func NewJWTValidator(ctx context.Context, issuer string, audienceList []string, options.UserProperty = "user" } - return &JWTValidator{ + return &jwtValidatorImpl{ options: options, }, nil } // ValidateAndParse validates the token and returns the parsed token -func (m *JWTValidator) ValidateAndParse(ctx context.Context, token string) (*jwt.Token, error) { +func (m *jwtValidatorImpl) ValidateAndParse(ctx context.Context, token string) (*jwt.Token, error) { // If the token is empty... if token == "" { // Check if it was required @@ -318,3 +322,28 @@ func getMaxAgeFromCacheHeader(ctx context.Context, cacheControl string) int { return 0 } + +type JwtValidatorMock struct{} + +func (j *JwtValidatorMock) ValidateAndParse(ctx context.Context, token string) (*jwt.Token, error) { + claimMaps := jwt.MapClaims{} + + switch token { + case "testUserId", "testAdminId", "testOwnerId", "testServiceUserId", "testServiceAdminId", "blockedUserId": + claimMaps[UserIDClaim] = token + claimMaps[AccountIDSuffix] = "testAccountId" + claimMaps[DomainIDSuffix] = "test.com" + claimMaps[DomainCategorySuffix] = "private" + case "otherUserId": + claimMaps[UserIDClaim] = "otherUserId" + claimMaps[AccountIDSuffix] = "otherAccountId" + claimMaps[DomainIDSuffix] = "other.com" + claimMaps[DomainCategorySuffix] = "private" + case "invalidToken": + return nil, errors.New("invalid token") + } + + jwtToken := jwt.NewWithClaims(jwt.SigningMethodHS256, claimMaps) + return jwtToken, nil +} + diff --git a/management/server/management_test.go b/management/server/management_test.go index 40514ae14..cfa2c138f 100644 --- a/management/server/management_test.go +++ b/management/server/management_test.go @@ -21,13 +21,10 @@ import ( "github.com/netbirdio/netbird/encryption" mgmtProto "github.com/netbirdio/netbird/management/proto" "github.com/netbirdio/netbird/management/server" - "github.com/netbirdio/netbird/management/server/account" "github.com/netbirdio/netbird/management/server/activity" - nbpeer "github.com/netbirdio/netbird/management/server/peer" "github.com/netbirdio/netbird/management/server/settings" "github.com/netbirdio/netbird/management/server/store" "github.com/netbirdio/netbird/management/server/telemetry" - "github.com/netbirdio/netbird/management/server/types" "github.com/netbirdio/netbird/util" ) @@ -448,43 +445,6 @@ var _ = Describe("Management service", func() { }) }) -type MocIntegratedValidator struct { -} - -func (a MocIntegratedValidator) ValidateExtraSettings(_ context.Context, newExtraSettings *account.ExtraSettings, oldExtraSettings *account.ExtraSettings, peers map[string]*nbpeer.Peer, userID string, accountID string) error { - return nil -} - -func (a MocIntegratedValidator) ValidatePeer(_ context.Context, update *nbpeer.Peer, peer *nbpeer.Peer, userID string, accountID string, dnsDomain string, peersGroup []string, extraSettings *account.ExtraSettings) (*nbpeer.Peer, bool, error) { - return update, false, nil -} - -func (a MocIntegratedValidator) GetValidatedPeers(accountID string, groups map[string]*types.Group, 
peers map[string]*nbpeer.Peer, extraSettings *account.ExtraSettings) (map[string]struct{}, error) { - validatedPeers := make(map[string]struct{}) - for p := range peers { - validatedPeers[p] = struct{}{} - } - return validatedPeers, nil -} - -func (MocIntegratedValidator) PreparePeer(_ context.Context, accountID string, peer *nbpeer.Peer, peersGroup []string, extraSettings *account.ExtraSettings) *nbpeer.Peer { - return peer -} - -func (MocIntegratedValidator) IsNotValidPeer(_ context.Context, accountID string, peer *nbpeer.Peer, peersGroup []string, extraSettings *account.ExtraSettings) (bool, bool, error) { - return false, false, nil -} - -func (MocIntegratedValidator) PeerDeleted(_ context.Context, _, _ string) error { - return nil -} - -func (MocIntegratedValidator) SetPeerInvalidationListener(func(accountID string)) { - -} - -func (MocIntegratedValidator) Stop(_ context.Context) {} - func loginPeerWithValidSetupKey(serverPubKey wgtypes.Key, key wgtypes.Key, client mgmtProto.ManagementServiceClient) *mgmtProto.LoginResponse { defer GinkgoRecover() @@ -547,7 +507,7 @@ func startServer(config *server.Config, dataDir string, testFile string) (*grpc. log.Fatalf("failed creating metrics: %v", err) } - accountManager, err := server.BuildManager(context.Background(), store, peersUpdateManager, nil, "", "netbird.selfhosted", eventStore, nil, false, MocIntegratedValidator{}, metrics) + accountManager, err := server.BuildManager(context.Background(), store, peersUpdateManager, nil, "", "netbird.selfhosted", eventStore, nil, false, server.MocIntegratedValidator{}, metrics) if err != nil { log.Fatalf("failed creating a manager: %v", err) } diff --git a/management/server/networks/manager.go b/management/server/networks/manager.go index 4a7b3db77..51205f1e9 100644 --- a/management/server/networks/manager.go +++ b/management/server/networks/manager.go @@ -32,6 +32,9 @@ type managerImpl struct { routersManager routers.Manager } +type mockManager struct { +} + func NewManager(store store.Store, permissionsManager permissions.Manager, resourceManager resources.Manager, routersManager routers.Manager, accountManager s.AccountManager) Manager { return &managerImpl{ store: store, @@ -185,3 +188,27 @@ func (m *managerImpl) DeleteNetwork(ctx context.Context, accountID, userID, netw return nil } + +func NewManagerMock() Manager { + return &mockManager{} +} + +func (m *mockManager) GetAllNetworks(ctx context.Context, accountID, userID string) ([]*types.Network, error) { + return []*types.Network{}, nil +} + +func (m *mockManager) CreateNetwork(ctx context.Context, userID string, network *types.Network) (*types.Network, error) { + return network, nil +} + +func (m *mockManager) GetNetwork(ctx context.Context, accountID, userID, networkID string) (*types.Network, error) { + return &types.Network{}, nil +} + +func (m *mockManager) UpdateNetwork(ctx context.Context, userID string, network *types.Network) (*types.Network, error) { + return network, nil +} + +func (m *mockManager) DeleteNetwork(ctx context.Context, accountID, userID, networkID string) error { + return nil +} diff --git a/management/server/networks/resources/manager.go b/management/server/networks/resources/manager.go index 0fff5bcf8..02b462947 100644 --- a/management/server/networks/resources/manager.go +++ b/management/server/networks/resources/manager.go @@ -34,6 +34,9 @@ type managerImpl struct { accountManager s.AccountManager } +type mockManager struct { +} + func NewManager(store store.Store, permissionsManager permissions.Manager, 
groupsManager groups.Manager, accountManager s.AccountManager) Manager { return &managerImpl{ store: store, @@ -381,3 +384,39 @@ func (m *managerImpl) DeleteResourceInTransaction(ctx context.Context, transacti return eventsToStore, nil } + +func NewManagerMock() Manager { + return &mockManager{} +} + +func (m *mockManager) GetAllResourcesInNetwork(ctx context.Context, accountID, userID, networkID string) ([]*types.NetworkResource, error) { + return []*types.NetworkResource{}, nil +} + +func (m *mockManager) GetAllResourcesInAccount(ctx context.Context, accountID, userID string) ([]*types.NetworkResource, error) { + return []*types.NetworkResource{}, nil +} + +func (m *mockManager) GetAllResourceIDsInAccount(ctx context.Context, accountID, userID string) (map[string][]string, error) { + return map[string][]string{}, nil +} + +func (m *mockManager) CreateResource(ctx context.Context, userID string, resource *types.NetworkResource) (*types.NetworkResource, error) { + return &types.NetworkResource{}, nil +} + +func (m *mockManager) GetResource(ctx context.Context, accountID, userID, networkID, resourceID string) (*types.NetworkResource, error) { + return &types.NetworkResource{}, nil +} + +func (m *mockManager) UpdateResource(ctx context.Context, userID string, resource *types.NetworkResource) (*types.NetworkResource, error) { + return &types.NetworkResource{}, nil +} + +func (m *mockManager) DeleteResource(ctx context.Context, accountID, userID, networkID, resourceID string) error { + return nil +} + +func (m *mockManager) DeleteResourceInTransaction(ctx context.Context, transaction store.Store, accountID, userID, networkID, resourceID string) ([]func(), error) { + return []func(){}, nil +} diff --git a/management/server/setupkey.go b/management/server/setupkey.go index 9a4a1efb8..f2f1aad45 100644 --- a/management/server/setupkey.go +++ b/management/server/setupkey.go @@ -75,7 +75,7 @@ func (am *DefaultAccountManager) CreateSetupKey(ctx context.Context, accountID s err = am.Store.ExecuteInTransaction(ctx, func(transaction store.Store) error { if err = validateSetupKeyAutoGroups(ctx, transaction, accountID, autoGroups); err != nil { - return err + return status.Errorf(status.InvalidArgument, "invalid auto groups: %v", err) } setupKey, plainKey = types.GenerateSetupKey(keyName, keyType, expiresIn, autoGroups, usageLimit, ephemeral) @@ -132,7 +132,7 @@ func (am *DefaultAccountManager) SaveSetupKey(ctx context.Context, accountID str err = am.Store.ExecuteInTransaction(ctx, func(transaction store.Store) error { if err = validateSetupKeyAutoGroups(ctx, transaction, accountID, keyToSave.AutoGroups); err != nil { - return err + return status.Errorf(status.InvalidArgument, "invalid auto groups: %v", err) } oldKey, err = transaction.GetSetupKeyByID(ctx, store.LockingStrengthShare, accountID, keyToSave.Id) From bc013e48886ab18c42b02885fba367879551dcf7 Mon Sep 17 00:00:00 2001 From: Pascal Fischer <32096965+pascal-fischer@users.noreply.github.com> Date: Thu, 2 Jan 2025 18:46:28 +0100 Subject: [PATCH 106/159] [management] exclude self from network map if self is routing peer (#3142) --- .../setupkeys_handler_integration_test.go | 2 +- management/server/types/account.go | 1 + management/server/types/account_test.go | 49 ++++++++++++++++--- 3 files changed, 43 insertions(+), 9 deletions(-) diff --git a/management/server/http/testing/integration/setupkeys_handler_integration_test.go b/management/server/http/testing/integration/setupkeys_handler_integration_test.go index 193c0fb02..e22a1a5a8 100644 --- 
a/management/server/http/testing/integration/setupkeys_handler_integration_test.go +++ b/management/server/http/testing/integration/setupkeys_handler_integration_test.go @@ -925,7 +925,7 @@ func Test_SetupKeys_GetAll(t *testing.T) { return tc.expectedResponse[i].UsageLimit < tc.expectedResponse[j].UsageLimit }) - for i, _ := range tc.expectedResponse { + for i := range tc.expectedResponse { validateCreatedKey(t, tc.expectedResponse[i], &got[i]) key, err := am.GetSetupKey(context.Background(), testing_tools.TestAccountId, testing_tools.TestUserId, got[i].Id) diff --git a/management/server/types/account.go b/management/server/types/account.go index f9e1cc9b4..8fa61d700 100644 --- a/management/server/types/account.go +++ b/management/server/types/account.go @@ -318,6 +318,7 @@ func (a *Account) addNetworksRoutingPeers( } delete(sourcePeers, peer.ID) + delete(networkRoutesPeers, peer.ID) for _, existingPeer := range peersToConnect { delete(sourcePeers, existingPeer.ID) diff --git a/management/server/types/account_test.go b/management/server/types/account_test.go index 367baef4f..389ab17f6 100644 --- a/management/server/types/account_test.go +++ b/management/server/types/account_test.go @@ -336,12 +336,12 @@ func Test_AddNetworksRoutingPeersAddsMissingPeers(t *testing.T) { func Test_AddNetworksRoutingPeersIgnoresExistingPeers(t *testing.T) { account := setupTestAccount() - peer := &nbpeer.Peer{Key: "peer1"} + peer := &nbpeer.Peer{Key: "peer1Key", ID: "peer1"} networkResourcesRoutes := []*route.Route{ {Peer: "peer2Key"}, } peersToConnect := []*nbpeer.Peer{ - {Key: "peer2Key"}, + {Key: "peer2Key", ID: "peer2"}, } expiredPeers := []*nbpeer.Peer{} @@ -352,16 +352,16 @@ func Test_AddNetworksRoutingPeersIgnoresExistingPeers(t *testing.T) { func Test_AddNetworksRoutingPeersAddsExpiredPeers(t *testing.T) { account := setupTestAccount() - peer := &nbpeer.Peer{Key: "peer1Key"} + peer := &nbpeer.Peer{Key: "peer1Key", ID: "peer1"} networkResourcesRoutes := []*route.Route{ - {Peer: "peer2Key"}, - {Peer: "peer3Key"}, + {Peer: "peer2Key", PeerID: "peer2"}, + {Peer: "peer3Key", PeerID: "peer3"}, } peersToConnect := []*nbpeer.Peer{ - {Key: "peer2Key"}, + {Key: "peer2Key", ID: "peer2"}, } expiredPeers := []*nbpeer.Peer{ - {Key: "peer3Key"}, + {Key: "peer3Key", ID: "peer3"}, } result := account.addNetworksRoutingPeers(networkResourcesRoutes, peer, peersToConnect, expiredPeers, false, map[string]struct{}{}) @@ -369,9 +369,24 @@ func Test_AddNetworksRoutingPeersAddsExpiredPeers(t *testing.T) { require.Equal(t, "peer2Key", result[0].Key) } +func Test_AddNetworksRoutingPeersExcludesSelf(t *testing.T) { + account := setupTestAccount() + peer := &nbpeer.Peer{Key: "peer1Key", ID: "peer1"} + networkResourcesRoutes := []*route.Route{ + {Peer: "peer1Key", PeerID: "peer1"}, + {Peer: "peer2Key", PeerID: "peer2"}, + } + peersToConnect := []*nbpeer.Peer{} + expiredPeers := []*nbpeer.Peer{} + + result := account.addNetworksRoutingPeers(networkResourcesRoutes, peer, peersToConnect, expiredPeers, true, map[string]struct{}{}) + require.Len(t, result, 1) + require.Equal(t, "peer2Key", result[0].Key) +} + func Test_AddNetworksRoutingPeersHandlesNoMissingPeers(t *testing.T) { account := setupTestAccount() - peer := &nbpeer.Peer{Key: "peer1"} + peer := &nbpeer.Peer{Key: "peer1key", ID: "peer1"} networkResourcesRoutes := []*route.Route{} peersToConnect := []*nbpeer.Peer{} expiredPeers := []*nbpeer.Peer{} @@ -755,3 +770,21 @@ func Test_NetworksNetMapGenWithTwoPostureChecks(t *testing.T) { t.Errorf("%s should not have source range of 
peer2 %s", rules[0].SourceRanges, accNetResourcePeer2IP.String()) } } + +func Test_NetworksNetMapGenShouldExcludeOtherRouters(t *testing.T) { + account := getBasicAccountsWithResource() + + account.Peers["router2Id"] = &nbpeer.Peer{Key: "router2Key", ID: "router2Id", AccountID: accID, IP: net.IP{192, 168, 1, 4}} + account.NetworkRouters = append(account.NetworkRouters, &routerTypes.NetworkRouter{ + ID: "router2Id", + NetworkID: network1ID, + AccountID: accID, + Peer: "router2Id", + }) + + // validate routes for router1 + isRouter, networkResourcesRoutes, sourcePeers := account.GetNetworkResourcesRoutesToSync(context.Background(), accNetResourceRouter1ID, account.GetResourcePoliciesMap(), account.GetResourceRoutersMap()) + assert.True(t, isRouter, "should be router") + assert.Len(t, networkResourcesRoutes, 1, "expected network resource route don't match") + assert.Len(t, sourcePeers, 2, "expected source peers don't match") +} From a01253c3c837e984b7813277a1eaa644db056c4a Mon Sep 17 00:00:00 2001 From: Pascal Fischer <32096965+pascal-fischer@users.noreply.github.com> Date: Fri, 3 Jan 2025 15:24:30 +0100 Subject: [PATCH 107/159] [management] add users benchmark (#3141) --- .github/workflows/golang-test-linux.yml | 2 +- management/server/account_test.go | 2 +- .../setupkeys_handler_benchmark_test.go | 16 +- .../users_handler_benchmark_test.go | 182 ++++++++++++++++++ .../server/http/testing/testdata/users.sql | 23 +++ 5 files changed, 215 insertions(+), 10 deletions(-) create mode 100644 management/server/http/testing/benchmarks/users_handler_benchmark_test.go create mode 100644 management/server/http/testing/testdata/users.sql diff --git a/.github/workflows/golang-test-linux.yml b/.github/workflows/golang-test-linux.yml index 155190398..2de7a518a 100644 --- a/.github/workflows/golang-test-linux.yml +++ b/.github/workflows/golang-test-linux.yml @@ -232,7 +232,7 @@ jobs: run: git --no-pager diff --exit-code - name: Test - run: CGO_ENABLED=1 GOARCH=${{ matrix.arch }} NETBIRD_STORE_ENGINE=${{ matrix.store }} CI=true go test -run=^$ -bench=. -exec 'sudo --preserve-env=CI,NETBIRD_STORE_ENGINE' -timeout 10m -p 1 ./... + run: CGO_ENABLED=1 GOARCH=${{ matrix.arch }} NETBIRD_STORE_ENGINE=${{ matrix.store }} CI=true go test -run=^$ -bench=. -exec 'sudo --preserve-env=CI,NETBIRD_STORE_ENGINE' -timeout 20m ./... 
test_client_on_docker: needs: [ build-cache ] diff --git a/management/server/account_test.go b/management/server/account_test.go index 4f6cdf78d..39a748192 100644 --- a/management/server/account_test.go +++ b/management/server/account_test.go @@ -3025,7 +3025,7 @@ func BenchmarkSyncAndMarkPeer(b *testing.B) { {"Medium", 500, 100, 7, 13, 10, 70}, {"Large", 5000, 200, 65, 80, 60, 220}, {"Small single", 50, 10, 1, 3, 3, 70}, - {"Medium single", 500, 10, 7, 13, 10, 26}, + {"Medium single", 500, 10, 7, 13, 10, 32}, {"Large 5", 5000, 15, 65, 80, 60, 200}, } diff --git a/management/server/http/testing/benchmarks/setupkeys_handler_benchmark_test.go b/management/server/http/testing/benchmarks/setupkeys_handler_benchmark_test.go index 5e2895bcc..9431b8c86 100644 --- a/management/server/http/testing/benchmarks/setupkeys_handler_benchmark_test.go +++ b/management/server/http/testing/benchmarks/setupkeys_handler_benchmark_test.go @@ -32,14 +32,14 @@ var benchCasesSetupKeys = map[string]testing_tools.BenchmarkCase{ func BenchmarkCreateSetupKey(b *testing.B) { var expectedMetrics = map[string]testing_tools.PerformanceMetrics{ - "Setup Keys - XS": {MinMsPerOpLocal: 0.5, MaxMsPerOpLocal: 2, MinMsPerOpCICD: 1, MaxMsPerOpCICD: 16}, - "Setup Keys - S": {MinMsPerOpLocal: 0.5, MaxMsPerOpLocal: 2, MinMsPerOpCICD: 1, MaxMsPerOpCICD: 16}, - "Setup Keys - M": {MinMsPerOpLocal: 0.5, MaxMsPerOpLocal: 2, MinMsPerOpCICD: 1, MaxMsPerOpCICD: 16}, - "Setup Keys - L": {MinMsPerOpLocal: 0.5, MaxMsPerOpLocal: 2, MinMsPerOpCICD: 1, MaxMsPerOpCICD: 16}, - "Peers - L": {MinMsPerOpLocal: 0.5, MaxMsPerOpLocal: 2, MinMsPerOpCICD: 1, MaxMsPerOpCICD: 16}, - "Groups - L": {MinMsPerOpLocal: 0.5, MaxMsPerOpLocal: 2, MinMsPerOpCICD: 1, MaxMsPerOpCICD: 16}, - "Users - L": {MinMsPerOpLocal: 0.5, MaxMsPerOpLocal: 2, MinMsPerOpCICD: 1, MaxMsPerOpCICD: 16}, - "Setup Keys - XL": {MinMsPerOpLocal: 0.5, MaxMsPerOpLocal: 2, MinMsPerOpCICD: 1, MaxMsPerOpCICD: 16}, + "Setup Keys - XS": {MinMsPerOpLocal: 0.5, MaxMsPerOpLocal: 2, MinMsPerOpCICD: 1, MaxMsPerOpCICD: 17}, + "Setup Keys - S": {MinMsPerOpLocal: 0.5, MaxMsPerOpLocal: 2, MinMsPerOpCICD: 1, MaxMsPerOpCICD: 17}, + "Setup Keys - M": {MinMsPerOpLocal: 0.5, MaxMsPerOpLocal: 2, MinMsPerOpCICD: 1, MaxMsPerOpCICD: 17}, + "Setup Keys - L": {MinMsPerOpLocal: 0.5, MaxMsPerOpLocal: 2, MinMsPerOpCICD: 1, MaxMsPerOpCICD: 17}, + "Peers - L": {MinMsPerOpLocal: 0.5, MaxMsPerOpLocal: 2, MinMsPerOpCICD: 1, MaxMsPerOpCICD: 17}, + "Groups - L": {MinMsPerOpLocal: 0.5, MaxMsPerOpLocal: 2, MinMsPerOpCICD: 1, MaxMsPerOpCICD: 17}, + "Users - L": {MinMsPerOpLocal: 0.5, MaxMsPerOpLocal: 2, MinMsPerOpCICD: 1, MaxMsPerOpCICD: 17}, + "Setup Keys - XL": {MinMsPerOpLocal: 0.5, MaxMsPerOpLocal: 2, MinMsPerOpCICD: 1, MaxMsPerOpCICD: 17}, } log.SetOutput(io.Discard) diff --git a/management/server/http/testing/benchmarks/users_handler_benchmark_test.go b/management/server/http/testing/benchmarks/users_handler_benchmark_test.go new file mode 100644 index 000000000..f453ba7c8 --- /dev/null +++ b/management/server/http/testing/benchmarks/users_handler_benchmark_test.go @@ -0,0 +1,182 @@ +package benchmarks + +import ( + "encoding/json" + "io" + "net/http" + "net/http/httptest" + "os" + "strconv" + "testing" + "time" + + log "github.com/sirupsen/logrus" + "github.com/stretchr/testify/assert" + + "github.com/netbirdio/netbird/management/server" + "github.com/netbirdio/netbird/management/server/http/api" + "github.com/netbirdio/netbird/management/server/http/testing/testing_tools" +) + +// Map to store peers, groups, users, and 
setupKeys by name +var benchCasesUsers = map[string]testing_tools.BenchmarkCase{ + "Users - XS": {Peers: 10000, Groups: 10000, Users: 5, SetupKeys: 10000}, + "Users - S": {Peers: 5, Groups: 5, Users: 10, SetupKeys: 5}, + "Users - M": {Peers: 100, Groups: 20, Users: 1000, SetupKeys: 1000}, + "Users - L": {Peers: 5, Groups: 5, Users: 5000, SetupKeys: 5}, + "Peers - L": {Peers: 10000, Groups: 5, Users: 5000, SetupKeys: 5}, + "Groups - L": {Peers: 5, Groups: 10000, Users: 5000, SetupKeys: 5}, + "Setup Keys - L": {Peers: 5, Groups: 5, Users: 5000, SetupKeys: 10000}, + "Users - XL": {Peers: 500, Groups: 50, Users: 25000, SetupKeys: 3000}, +} + +func BenchmarkUpdateUser(b *testing.B) { + var expectedMetrics = map[string]testing_tools.PerformanceMetrics{ + "Users - XS": {MinMsPerOpLocal: 700, MaxMsPerOpLocal: 1000, MinMsPerOpCICD: 1300, MaxMsPerOpCICD: 7000}, + "Users - S": {MinMsPerOpLocal: 1, MaxMsPerOpLocal: 5, MinMsPerOpCICD: 6, MaxMsPerOpCICD: 40}, + "Users - M": {MinMsPerOpLocal: 20, MaxMsPerOpLocal: 40, MinMsPerOpCICD: 30, MaxMsPerOpCICD: 200}, + "Users - L": {MinMsPerOpLocal: 60, MaxMsPerOpLocal: 100, MinMsPerOpCICD: 130, MaxMsPerOpCICD: 700}, + "Peers - L": {MinMsPerOpLocal: 300, MaxMsPerOpLocal: 500, MinMsPerOpCICD: 550, MaxMsPerOpCICD: 2000}, + "Groups - L": {MinMsPerOpLocal: 400, MaxMsPerOpLocal: 600, MinMsPerOpCICD: 750, MaxMsPerOpCICD: 5000}, + "Setup Keys - L": {MinMsPerOpLocal: 50, MaxMsPerOpLocal: 200, MinMsPerOpCICD: 150, MaxMsPerOpCICD: 1000}, + "Users - XL": {MinMsPerOpLocal: 350, MaxMsPerOpLocal: 550, MinMsPerOpCICD: 700, MaxMsPerOpCICD: 3500}, + } + + log.SetOutput(io.Discard) + defer log.SetOutput(os.Stderr) + + recorder := httptest.NewRecorder() + + for name, bc := range benchCasesUsers { + b.Run(name, func(b *testing.B) { + apiHandler, am, _ := testing_tools.BuildApiBlackBoxWithDBState(b, "../testdata/users.sql", nil) + testing_tools.PopulateTestData(b, am.(*server.DefaultAccountManager), bc.Peers, bc.Groups, bc.Users, bc.SetupKeys) + + b.ResetTimer() + start := time.Now() + for i := 0; i < b.N; i++ { + groupId := testing_tools.TestGroupId + if i%2 == 0 { + groupId = testing_tools.NewGroupId + } + requestBody := api.UserRequest{ + AutoGroups: []string{groupId}, + IsBlocked: false, + Role: "admin", + } + + // the time marshal will be recorded as well but for our use case that is ok + body, err := json.Marshal(requestBody) + assert.NoError(b, err) + + req := testing_tools.BuildRequest(b, body, http.MethodPut, "/api/users/"+testing_tools.TestUserId, testing_tools.TestAdminId) + apiHandler.ServeHTTP(recorder, req) + } + + testing_tools.EvaluateBenchmarkResults(b, name, time.Since(start), expectedMetrics[name], recorder) + }) + } +} + +func BenchmarkGetOneUser(b *testing.B) { + b.Skip("Skipping benchmark as endpoint is missing") + var expectedMetrics = map[string]testing_tools.PerformanceMetrics{ + "Users - XS": {MinMsPerOpLocal: 0.5, MaxMsPerOpLocal: 2, MinMsPerOpCICD: 1, MaxMsPerOpCICD: 12}, + "Users - S": {MinMsPerOpLocal: 0.5, MaxMsPerOpLocal: 2, MinMsPerOpCICD: 1, MaxMsPerOpCICD: 12}, + "Users - M": {MinMsPerOpLocal: 0.5, MaxMsPerOpLocal: 2, MinMsPerOpCICD: 1, MaxMsPerOpCICD: 12}, + "Users - L": {MinMsPerOpLocal: 0.5, MaxMsPerOpLocal: 2, MinMsPerOpCICD: 1, MaxMsPerOpCICD: 12}, + "Peers - L": {MinMsPerOpLocal: 0.5, MaxMsPerOpLocal: 2, MinMsPerOpCICD: 1, MaxMsPerOpCICD: 12}, + "Groups - L": {MinMsPerOpLocal: 0.5, MaxMsPerOpLocal: 2, MinMsPerOpCICD: 1, MaxMsPerOpCICD: 12}, + "Setup Keys - L": {MinMsPerOpLocal: 0.5, MaxMsPerOpLocal: 2, MinMsPerOpCICD: 1, MaxMsPerOpCICD: 12}, + 
"Users - XL": {MinMsPerOpLocal: 0.5, MaxMsPerOpLocal: 2, MinMsPerOpCICD: 1, MaxMsPerOpCICD: 12}, + } + + log.SetOutput(io.Discard) + defer log.SetOutput(os.Stderr) + + recorder := httptest.NewRecorder() + + for name, bc := range benchCasesUsers { + b.Run(name, func(b *testing.B) { + apiHandler, am, _ := testing_tools.BuildApiBlackBoxWithDBState(b, "../testdata/users.sql", nil) + testing_tools.PopulateTestData(b, am.(*server.DefaultAccountManager), bc.Peers, bc.Groups, bc.Users, bc.SetupKeys) + + b.ResetTimer() + start := time.Now() + for i := 0; i < b.N; i++ { + req := testing_tools.BuildRequest(b, nil, http.MethodGet, "/api/users/"+testing_tools.TestUserId, testing_tools.TestAdminId) + apiHandler.ServeHTTP(recorder, req) + } + + testing_tools.EvaluateBenchmarkResults(b, name, time.Since(start), expectedMetrics[name], recorder) + }) + } +} + +func BenchmarkGetAllUsers(b *testing.B) { + var expectedMetrics = map[string]testing_tools.PerformanceMetrics{ + "Users - XS": {MinMsPerOpLocal: 50, MaxMsPerOpLocal: 90, MinMsPerOpCICD: 60, MaxMsPerOpCICD: 180}, + "Users - S": {MinMsPerOpLocal: 0.5, MaxMsPerOpLocal: 2, MinMsPerOpCICD: 1, MaxMsPerOpCICD: 30}, + "Users - M": {MinMsPerOpLocal: 5, MaxMsPerOpLocal: 12, MinMsPerOpCICD: 1, MaxMsPerOpCICD: 30}, + "Users - L": {MinMsPerOpLocal: 0.5, MaxMsPerOpLocal: 2, MinMsPerOpCICD: 1, MaxMsPerOpCICD: 30}, + "Peers - L": {MinMsPerOpLocal: 0.5, MaxMsPerOpLocal: 2, MinMsPerOpCICD: 1, MaxMsPerOpCICD: 30}, + "Groups - L": {MinMsPerOpLocal: 0.5, MaxMsPerOpLocal: 2, MinMsPerOpCICD: 1, MaxMsPerOpCICD: 30}, + "Setup Keys - L": {MinMsPerOpLocal: 40, MaxMsPerOpLocal: 140, MinMsPerOpCICD: 60, MaxMsPerOpCICD: 200}, + "Users - XL": {MinMsPerOpLocal: 15, MaxMsPerOpLocal: 40, MinMsPerOpCICD: 20, MaxMsPerOpCICD: 90}, + } + + log.SetOutput(io.Discard) + defer log.SetOutput(os.Stderr) + + recorder := httptest.NewRecorder() + + for name, bc := range benchCasesUsers { + b.Run(name, func(b *testing.B) { + apiHandler, am, _ := testing_tools.BuildApiBlackBoxWithDBState(b, "../testdata/users.sql", nil) + testing_tools.PopulateTestData(b, am.(*server.DefaultAccountManager), bc.Peers, bc.Groups, bc.Users, bc.SetupKeys) + + b.ResetTimer() + start := time.Now() + for i := 0; i < b.N; i++ { + req := testing_tools.BuildRequest(b, nil, http.MethodGet, "/api/setup-keys", testing_tools.TestAdminId) + apiHandler.ServeHTTP(recorder, req) + } + + testing_tools.EvaluateBenchmarkResults(b, name, time.Since(start), expectedMetrics[name], recorder) + }) + } +} + +func BenchmarkDeleteUsers(b *testing.B) { + var expectedMetrics = map[string]testing_tools.PerformanceMetrics{ + "Users - XS": {MinMsPerOpLocal: 1000, MaxMsPerOpLocal: 1600, MinMsPerOpCICD: 1900, MaxMsPerOpCICD: 10000}, + "Users - S": {MinMsPerOpLocal: 15, MaxMsPerOpLocal: 40, MinMsPerOpCICD: 30, MaxMsPerOpCICD: 200}, + "Users - M": {MinMsPerOpLocal: 15, MaxMsPerOpLocal: 70, MinMsPerOpCICD: 15, MaxMsPerOpCICD: 230}, + "Users - L": {MinMsPerOpLocal: 15, MaxMsPerOpLocal: 45, MinMsPerOpCICD: 30, MaxMsPerOpCICD: 190}, + "Peers - L": {MinMsPerOpLocal: 400, MaxMsPerOpLocal: 600, MinMsPerOpCICD: 650, MaxMsPerOpCICD: 1800}, + "Groups - L": {MinMsPerOpLocal: 600, MaxMsPerOpLocal: 800, MinMsPerOpCICD: 1200, MaxMsPerOpCICD: 7500}, + "Setup Keys - L": {MinMsPerOpLocal: 20, MaxMsPerOpLocal: 200, MinMsPerOpCICD: 55, MaxMsPerOpCICD: 600}, + "Users - XL": {MinMsPerOpLocal: 50, MaxMsPerOpLocal: 150, MinMsPerOpCICD: 80, MaxMsPerOpCICD: 400}, + } + + log.SetOutput(io.Discard) + defer log.SetOutput(os.Stderr) + + recorder := httptest.NewRecorder() + + for 
name, bc := range benchCasesUsers { + b.Run(name, func(b *testing.B) { + apiHandler, am, _ := testing_tools.BuildApiBlackBoxWithDBState(b, "../testdata/users.sql", nil) + testing_tools.PopulateTestData(b, am.(*server.DefaultAccountManager), bc.Peers, bc.Groups, 1000, bc.SetupKeys) + + b.ResetTimer() + start := time.Now() + for i := 0; i < b.N; i++ { + req := testing_tools.BuildRequest(b, nil, http.MethodDelete, "/api/users/"+"olduser-"+strconv.Itoa(i), testing_tools.TestAdminId) + apiHandler.ServeHTTP(recorder, req) + } + + testing_tools.EvaluateBenchmarkResults(b, name, time.Since(start), expectedMetrics[name], recorder) + }) + } +} diff --git a/management/server/http/testing/testdata/users.sql b/management/server/http/testing/testdata/users.sql new file mode 100644 index 000000000..da7cae565 --- /dev/null +++ b/management/server/http/testing/testdata/users.sql @@ -0,0 +1,23 @@ +CREATE TABLE `accounts` (`id` text,`created_by` text,`created_at` datetime,`domain` text,`domain_category` text,`is_domain_primary_account` numeric,`network_identifier` text,`network_net` text,`network_dns` text,`network_serial` integer,`dns_settings_disabled_management_groups` text,`settings_peer_login_expiration_enabled` numeric,`settings_peer_login_expiration` integer,`settings_regular_users_view_blocked` numeric,`settings_groups_propagation_enabled` numeric,`settings_jwt_groups_enabled` numeric,`settings_jwt_groups_claim_name` text,`settings_jwt_allow_groups` text,`settings_extra_peer_approval_enabled` numeric,`settings_extra_integrated_validator_groups` text,PRIMARY KEY (`id`)); +CREATE TABLE `groups` (`id` text,`account_id` text,`name` text,`issued` text,`peers` text,`integration_ref_id` integer,`integration_ref_integration_type` text,PRIMARY KEY (`id`),CONSTRAINT `fk_accounts_groups_g` FOREIGN KEY (`account_id`) REFERENCES `accounts`(`id`)); +CREATE TABLE `setup_keys` (`id` text,`account_id` text,`key` text,`key_secret` text,`name` text,`type` text,`created_at` datetime,`expires_at` datetime,`updated_at` datetime,`revoked` numeric,`used_times` integer,`last_used` datetime,`auto_groups` text,`usage_limit` integer,`ephemeral` numeric,PRIMARY KEY (`id`),CONSTRAINT `fk_accounts_setup_keys_g` FOREIGN KEY (`account_id`) REFERENCES `accounts`(`id`)); +CREATE TABLE `peers` (`id` text,`account_id` text,`key` text,`setup_key` text,`ip` text,`meta_hostname` text,`meta_go_os` text,`meta_kernel` text,`meta_core` text,`meta_platform` text,`meta_os` text,`meta_os_version` text,`meta_wt_version` text,`meta_ui_version` text,`meta_kernel_version` text,`meta_network_addresses` text,`meta_system_serial_number` text,`meta_system_product_name` text,`meta_system_manufacturer` text,`meta_environment` text,`meta_files` text,`name` text,`dns_label` text,`peer_status_last_seen` datetime,`peer_status_connected` numeric,`peer_status_login_expired` numeric,`peer_status_requires_approval` numeric,`user_id` text,`ssh_key` text,`ssh_enabled` numeric,`login_expiration_enabled` numeric,`last_login` datetime,`created_at` datetime,`ephemeral` numeric,`location_connection_ip` text,`location_country_code` text,`location_city_name` text,`location_geo_name_id` integer,PRIMARY KEY (`id`),CONSTRAINT `fk_accounts_peers_g` FOREIGN KEY (`account_id`) REFERENCES `accounts`(`id`)); + +INSERT INTO accounts VALUES('testAccountId','','2024-10-02 16:01:38.000000000+00:00','test.com','private',1,'testNetworkIdentifier','{"IP":"100.64.0.0","Mask":"//8AAA=="}','',0,'[]',0,86400000000000,0,0,0,'',NULL,NULL,NULL); +INSERT INTO "groups" 
VALUES('testGroupId','testAccountId','testGroupName','api','[]',0,''); +INSERT INTO "groups" VALUES('newGroupId','testAccountId','newGroupName','api','[]',0,''); +INSERT INTO setup_keys VALUES('testKeyId','testAccountId','testKey','testK****','existingKey','one-off','2021-08-19 20:46:20.000000000+00:00','2321-09-18 20:46:20.000000000+00:00','2021-08-19 20:46:20.000000000+00:000',0,0,'0001-01-01 00:00:00+00:00','["testGroupId"]',1,0); +INSERT INTO setup_keys VALUES('revokedKeyId','testAccountId','revokedKey','testK****','existingKey','reusable','2021-08-19 20:46:20.000000000+00:00','2321-09-18 20:46:20.000000000+00:00','2021-08-19 20:46:20.000000000+00:00',1,0,'0001-01-01 00:00:00+00:00','["testGroupId"]',3,0); +INSERT INTO setup_keys VALUES('expiredKeyId','testAccountId','expiredKey','testK****','existingKey','reusable','2021-08-19 20:46:20.000000000+00:00','1921-09-18 20:46:20.000000000+00:00','2021-08-19 20:46:20.000000000+00:00',0,1,'0001-01-01 00:00:00+00:00','["testGroupId"]',5,1); +INSERT INTO peers VALUES('testPeerId','testAccountId','5rvhvriKJZ3S9oxYToVj5TzDM9u9y8cxg7htIMWlYAg=','72546A29-6BC8-4311-BCFC-9CDBF33F1A48','"100.64.114.31"','f2a34f6a4731','linux','Linux','11','unknown','Debian GNU/Linux','','0.12.0','','',NULL,'','','','{"Cloud":"","Platform":""}',NULL,'f2a34f6a4731','f2a34f6a4731','2023-03-02 09:21:02.189035775+01:00',0,0,0,'','ssh-ed25519 AAAAC3NzaC1lZDI1NTE5AAAAILzUUSYG/LGnV8zarb2SGN+tib/PZ+M7cL4WtTzUrTpk',0,1,'2023-03-01 19:48:19.817799698+01:00','2024-10-02 17:00:32.527947+02:00',0,'""','','',0); + + +CREATE TABLE `users` (`id` text,`account_id` text,`role` text,`is_service_user` numeric,`non_deletable` numeric,`service_user_name` text,`auto_groups` text,`blocked` numeric,`last_login` datetime,`created_at` datetime,`issued` text DEFAULT "api",`integration_ref_id` integer,`integration_ref_integration_type` text,PRIMARY KEY (`id`),CONSTRAINT `fk_accounts_users_g` FOREIGN KEY (`account_id`) REFERENCES `accounts`(`id`)); + +INSERT INTO users VALUES('testUserId','testAccountId','user',0,0,'','[]',0,'0001-01-01 00:00:00+00:00','2024-10-02 16:01:38.000000000+00:00','api',0,''); +INSERT INTO users VALUES('testAdminId','testAccountId','admin',0,0,'','[]',0,'0001-01-01 00:00:00+00:00','2024-10-02 16:01:38.000000000+00:00','api',0,''); +INSERT INTO users VALUES('testOwnerId','testAccountId','owner',0,0,'','[]',0,'0001-01-01 00:00:00+00:00','2024-10-02 16:01:38.000000000+00:00','api',0,''); +INSERT INTO users VALUES('testServiceUserId','testAccountId','user',1,0,'','[]',0,'0001-01-01 00:00:00+00:00','2024-10-02 16:01:38.000000000+00:00','api',0,''); +INSERT INTO users VALUES('testServiceAdminId','testAccountId','admin',1,0,'','[]',0,'0001-01-01 00:00:00+00:00','2024-10-02 16:01:38.000000000+00:00','api',0,''); +INSERT INTO users VALUES('blockedUserId','testAccountId','admin',0,0,'','[]',1,'0001-01-01 00:00:00+00:00','2024-10-02 16:01:38.000000000+00:00','api',0,''); +INSERT INTO users VALUES('otherUserId','otherAccountId','admin',0,0,'','[]',0,'0001-01-01 00:00:00+00:00','2024-10-02 16:01:38.000000000+00:00','api',0,''); From cfa6d09c5e90c181b57982461c3d18d3dc844c73 Mon Sep 17 00:00:00 2001 From: Pascal Fischer <32096965+pascal-fischer@users.noreply.github.com> Date: Fri, 3 Jan 2025 15:28:15 +0100 Subject: [PATCH 108/159] [management] add peers benchmark (#3143) --- .../peers_handler_benchmark_test.go | 175 ++++++++++++++++++ .../setupkeys_handler_benchmark_test.go | 10 +- .../setupkeys_handler_integration_test.go | 10 +- .../server/http/testing/testdata/peers.sql | 22 +++ 
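The benchmark bodies above and below all follow the same pattern: seed the store, populate bulk data, then time b.N requests and hand time.Since(start) plus the matching expectedMetrics entry to testing_tools.EvaluateBenchmarkResults. That helper's implementation is not part of this excerpt; purely as an illustration of the kind of bound the Min/MaxMsPerOp fields suggest, a hypothetical check could look like this (assumed helper name and logic, not the real function):

package benchexample

import (
	"testing"
	"time"
)

// checkMsPerOp is a hypothetical stand-in showing how per-operation latency
// bounds such as MinMsPerOpLocal/MaxMsPerOpLocal could be enforced.
func checkMsPerOp(b *testing.B, elapsed time.Duration, minMs, maxMs float64) {
	msPerOp := float64(elapsed) / float64(time.Millisecond) / float64(b.N)
	if msPerOp < minMs || msPerOp > maxMs {
		b.Errorf("%s: %.2f ms/op outside expected range [%.2f, %.2f]", b.Name(), msPerOp, minMs, maxMs)
	}
}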
.../http/testing/testing_tools/tools.go | 20 +- 5 files changed, 218 insertions(+), 19 deletions(-) create mode 100644 management/server/http/testing/benchmarks/peers_handler_benchmark_test.go create mode 100644 management/server/http/testing/testdata/peers.sql diff --git a/management/server/http/testing/benchmarks/peers_handler_benchmark_test.go b/management/server/http/testing/benchmarks/peers_handler_benchmark_test.go new file mode 100644 index 000000000..829649798 --- /dev/null +++ b/management/server/http/testing/benchmarks/peers_handler_benchmark_test.go @@ -0,0 +1,175 @@ +package benchmarks + +import ( + "encoding/json" + "io" + "net/http" + "net/http/httptest" + "os" + "strconv" + "testing" + "time" + + log "github.com/sirupsen/logrus" + "github.com/stretchr/testify/assert" + + "github.com/netbirdio/netbird/management/server" + "github.com/netbirdio/netbird/management/server/http/api" + "github.com/netbirdio/netbird/management/server/http/testing/testing_tools" +) + +// Map to store peers, groups, users, and setupKeys by name +var benchCasesPeers = map[string]testing_tools.BenchmarkCase{ + "Peers - XS": {Peers: 5, Groups: 10000, Users: 10000, SetupKeys: 10000}, + "Peers - S": {Peers: 100, Groups: 5, Users: 5, SetupKeys: 5}, + "Peers - M": {Peers: 1000, Groups: 20, Users: 20, SetupKeys: 100}, + "Peers - L": {Peers: 5000, Groups: 5, Users: 5, SetupKeys: 5}, + "Groups - L": {Peers: 5000, Groups: 10000, Users: 5, SetupKeys: 5}, + "Users - L": {Peers: 5000, Groups: 5, Users: 10000, SetupKeys: 5}, + "Setup Keys - L": {Peers: 5000, Groups: 5, Users: 5, SetupKeys: 10000}, + "Peers - XL": {Peers: 25000, Groups: 50, Users: 100, SetupKeys: 500}, +} + +func BenchmarkUpdatePeer(b *testing.B) { + var expectedMetrics = map[string]testing_tools.PerformanceMetrics{ + "Peers - XS": {MinMsPerOpLocal: 1300, MaxMsPerOpLocal: 1700, MinMsPerOpCICD: 2200, MaxMsPerOpCICD: 13000}, + "Peers - S": {MinMsPerOpLocal: 100, MaxMsPerOpLocal: 130, MinMsPerOpCICD: 80, MaxMsPerOpCICD: 200}, + "Peers - M": {MinMsPerOpLocal: 160, MaxMsPerOpLocal: 190, MinMsPerOpCICD: 100, MaxMsPerOpCICD: 500}, + "Peers - L": {MinMsPerOpLocal: 400, MaxMsPerOpLocal: 430, MinMsPerOpCICD: 450, MaxMsPerOpCICD: 1400}, + "Groups - L": {MinMsPerOpLocal: 1200, MaxMsPerOpLocal: 1500, MinMsPerOpCICD: 1900, MaxMsPerOpCICD: 13000}, + "Users - L": {MinMsPerOpLocal: 600, MaxMsPerOpLocal: 800, MinMsPerOpCICD: 800, MaxMsPerOpCICD: 2800}, + "Setup Keys - L": {MinMsPerOpLocal: 400, MaxMsPerOpLocal: 700, MinMsPerOpCICD: 600, MaxMsPerOpCICD: 1300}, + "Peers - XL": {MinMsPerOpLocal: 1400, MaxMsPerOpLocal: 1900, MinMsPerOpCICD: 2200, MaxMsPerOpCICD: 5000}, + } + + log.SetOutput(io.Discard) + defer log.SetOutput(os.Stderr) + + recorder := httptest.NewRecorder() + + for name, bc := range benchCasesPeers { + b.Run(name, func(b *testing.B) { + apiHandler, am, _ := testing_tools.BuildApiBlackBoxWithDBState(b, "../testdata/peers.sql", nil, false) + testing_tools.PopulateTestData(b, am.(*server.DefaultAccountManager), bc.Peers, bc.Groups, bc.Users, bc.SetupKeys) + + b.ResetTimer() + start := time.Now() + for i := 0; i < b.N; i++ { + requestBody := api.PeerRequest{ + Name: "peer" + strconv.Itoa(i), + } + + // the time marshal will be recorded as well but for our use case that is ok + body, err := json.Marshal(requestBody) + assert.NoError(b, err) + + req := testing_tools.BuildRequest(b, body, http.MethodPut, "/api/peers/"+testing_tools.TestPeerId, testing_tools.TestAdminId) + apiHandler.ServeHTTP(recorder, req) + } + + testing_tools.EvaluateBenchmarkResults(b, name, 
time.Since(start), expectedMetrics[name], recorder) + }) + } +} + +func BenchmarkGetOnePeer(b *testing.B) { + var expectedMetrics = map[string]testing_tools.PerformanceMetrics{ + "Peers - XS": {MinMsPerOpLocal: 600, MaxMsPerOpLocal: 900, MinMsPerOpCICD: 1100, MaxMsPerOpCICD: 7000}, + "Peers - S": {MinMsPerOpLocal: 3, MaxMsPerOpLocal: 7, MinMsPerOpCICD: 2, MaxMsPerOpCICD: 30}, + "Peers - M": {MinMsPerOpLocal: 20, MaxMsPerOpLocal: 40, MinMsPerOpCICD: 35, MaxMsPerOpCICD: 80}, + "Peers - L": {MinMsPerOpLocal: 120, MaxMsPerOpLocal: 160, MinMsPerOpCICD: 100, MaxMsPerOpCICD: 300}, + "Groups - L": {MinMsPerOpLocal: 500, MaxMsPerOpLocal: 750, MinMsPerOpCICD: 900, MaxMsPerOpCICD: 6500}, + "Users - L": {MinMsPerOpLocal: 200, MaxMsPerOpLocal: 300, MinMsPerOpCICD: 200, MaxMsPerOpCICD: 600}, + "Setup Keys - L": {MinMsPerOpLocal: 200, MaxMsPerOpLocal: 300, MinMsPerOpCICD: 200, MaxMsPerOpCICD: 600}, + "Peers - XL": {MinMsPerOpLocal: 600, MaxMsPerOpLocal: 800, MinMsPerOpCICD: 600, MaxMsPerOpCICD: 1500}, + } + + log.SetOutput(io.Discard) + defer log.SetOutput(os.Stderr) + + recorder := httptest.NewRecorder() + + for name, bc := range benchCasesPeers { + b.Run(name, func(b *testing.B) { + apiHandler, am, _ := testing_tools.BuildApiBlackBoxWithDBState(b, "../testdata/peers.sql", nil, false) + testing_tools.PopulateTestData(b, am.(*server.DefaultAccountManager), bc.Peers, bc.Groups, bc.Users, bc.SetupKeys) + + b.ResetTimer() + start := time.Now() + for i := 0; i < b.N; i++ { + req := testing_tools.BuildRequest(b, nil, http.MethodGet, "/api/peers/"+testing_tools.TestPeerId, testing_tools.TestAdminId) + apiHandler.ServeHTTP(recorder, req) + } + + testing_tools.EvaluateBenchmarkResults(b, name, time.Since(start), expectedMetrics[name], recorder) + }) + } +} + +func BenchmarkGetAllPeers(b *testing.B) { + var expectedMetrics = map[string]testing_tools.PerformanceMetrics{ + "Peers - XS": {MinMsPerOpLocal: 600, MaxMsPerOpLocal: 900, MinMsPerOpCICD: 1100, MaxMsPerOpCICD: 6000}, + "Peers - S": {MinMsPerOpLocal: 4, MaxMsPerOpLocal: 10, MinMsPerOpCICD: 7, MaxMsPerOpCICD: 30}, + "Peers - M": {MinMsPerOpLocal: 20, MaxMsPerOpLocal: 50, MinMsPerOpCICD: 40, MaxMsPerOpCICD: 90}, + "Peers - L": {MinMsPerOpLocal: 130, MaxMsPerOpLocal: 170, MinMsPerOpCICD: 150, MaxMsPerOpCICD: 350}, + "Groups - L": {MinMsPerOpLocal: 5000, MaxMsPerOpLocal: 5500, MinMsPerOpCICD: 7000, MaxMsPerOpCICD: 15000}, + "Users - L": {MinMsPerOpLocal: 250, MaxMsPerOpLocal: 300, MinMsPerOpCICD: 250, MaxMsPerOpCICD: 700}, + "Setup Keys - L": {MinMsPerOpLocal: 250, MaxMsPerOpLocal: 350, MinMsPerOpCICD: 250, MaxMsPerOpCICD: 700}, + "Peers - XL": {MinMsPerOpLocal: 900, MaxMsPerOpLocal: 1300, MinMsPerOpCICD: 1100, MaxMsPerOpCICD: 2200}, + } + + log.SetOutput(io.Discard) + defer log.SetOutput(os.Stderr) + + recorder := httptest.NewRecorder() + + for name, bc := range benchCasesPeers { + b.Run(name, func(b *testing.B) { + apiHandler, am, _ := testing_tools.BuildApiBlackBoxWithDBState(b, "../testdata/peers.sql", nil, false) + testing_tools.PopulateTestData(b, am.(*server.DefaultAccountManager), bc.Peers, bc.Groups, bc.Users, bc.SetupKeys) + + b.ResetTimer() + start := time.Now() + for i := 0; i < b.N; i++ { + req := testing_tools.BuildRequest(b, nil, http.MethodGet, "/api/peers", testing_tools.TestAdminId) + apiHandler.ServeHTTP(recorder, req) + } + + testing_tools.EvaluateBenchmarkResults(b, name, time.Since(start), expectedMetrics[name], recorder) + }) + } +} + +func BenchmarkDeletePeer(b *testing.B) { + var expectedMetrics = 
map[string]testing_tools.PerformanceMetrics{ + "Peers - XS": {MinMsPerOpLocal: 600, MaxMsPerOpLocal: 800, MinMsPerOpCICD: 1100, MaxMsPerOpCICD: 7000}, + "Peers - S": {MinMsPerOpLocal: 20, MaxMsPerOpLocal: 50, MinMsPerOpCICD: 50, MaxMsPerOpCICD: 210}, + "Peers - M": {MinMsPerOpLocal: 20, MaxMsPerOpLocal: 50, MinMsPerOpCICD: 50, MaxMsPerOpCICD: 230}, + "Peers - L": {MinMsPerOpLocal: 20, MaxMsPerOpLocal: 50, MinMsPerOpCICD: 50, MaxMsPerOpCICD: 210}, + "Groups - L": {MinMsPerOpLocal: 400, MaxMsPerOpLocal: 550, MinMsPerOpCICD: 700, MaxMsPerOpCICD: 5500}, + "Users - L": {MinMsPerOpLocal: 170, MaxMsPerOpLocal: 210, MinMsPerOpCICD: 290, MaxMsPerOpCICD: 1700}, + "Setup Keys - L": {MinMsPerOpLocal: 30, MaxMsPerOpLocal: 125, MinMsPerOpCICD: 55, MaxMsPerOpCICD: 280}, + "Peers - XL": {MinMsPerOpLocal: 30, MaxMsPerOpLocal: 50, MinMsPerOpCICD: 60, MaxMsPerOpCICD: 250}, + } + + log.SetOutput(io.Discard) + defer log.SetOutput(os.Stderr) + + recorder := httptest.NewRecorder() + + for name, bc := range benchCasesPeers { + b.Run(name, func(b *testing.B) { + apiHandler, am, _ := testing_tools.BuildApiBlackBoxWithDBState(b, "../testdata/peers.sql", nil, false) + testing_tools.PopulateTestData(b, am.(*server.DefaultAccountManager), 1000, bc.Groups, bc.Users, bc.SetupKeys) + + b.ResetTimer() + start := time.Now() + for i := 0; i < b.N; i++ { + req := testing_tools.BuildRequest(b, nil, http.MethodDelete, "/api/peers/"+"oldpeer-"+strconv.Itoa(i), testing_tools.TestAdminId) + apiHandler.ServeHTTP(recorder, req) + } + + testing_tools.EvaluateBenchmarkResults(b, name, time.Since(start), expectedMetrics[name], recorder) + }) + } +} diff --git a/management/server/http/testing/benchmarks/setupkeys_handler_benchmark_test.go b/management/server/http/testing/benchmarks/setupkeys_handler_benchmark_test.go index 9431b8c86..ef4815762 100644 --- a/management/server/http/testing/benchmarks/setupkeys_handler_benchmark_test.go +++ b/management/server/http/testing/benchmarks/setupkeys_handler_benchmark_test.go @@ -49,7 +49,7 @@ func BenchmarkCreateSetupKey(b *testing.B) { for name, bc := range benchCasesSetupKeys { b.Run(name, func(b *testing.B) { - apiHandler, am, _ := testing_tools.BuildApiBlackBoxWithDBState(b, "../testdata/setup_keys.sql", nil) + apiHandler, am, _ := testing_tools.BuildApiBlackBoxWithDBState(b, "../testdata/setup_keys.sql", nil, false) testing_tools.PopulateTestData(b, am.(*server.DefaultAccountManager), bc.Peers, bc.Groups, bc.Users, bc.SetupKeys) b.ResetTimer() @@ -95,7 +95,7 @@ func BenchmarkUpdateSetupKey(b *testing.B) { for name, bc := range benchCasesSetupKeys { b.Run(name, func(b *testing.B) { - apiHandler, am, _ := testing_tools.BuildApiBlackBoxWithDBState(b, "../testdata/setup_keys.sql", nil) + apiHandler, am, _ := testing_tools.BuildApiBlackBoxWithDBState(b, "../testdata/setup_keys.sql", nil, false) testing_tools.PopulateTestData(b, am.(*server.DefaultAccountManager), bc.Peers, bc.Groups, bc.Users, bc.SetupKeys) b.ResetTimer() @@ -142,7 +142,7 @@ func BenchmarkGetOneSetupKey(b *testing.B) { for name, bc := range benchCasesSetupKeys { b.Run(name, func(b *testing.B) { - apiHandler, am, _ := testing_tools.BuildApiBlackBoxWithDBState(b, "../testdata/setup_keys.sql", nil) + apiHandler, am, _ := testing_tools.BuildApiBlackBoxWithDBState(b, "../testdata/setup_keys.sql", nil, false) testing_tools.PopulateTestData(b, am.(*server.DefaultAccountManager), bc.Peers, bc.Groups, bc.Users, bc.SetupKeys) b.ResetTimer() @@ -176,7 +176,7 @@ func BenchmarkGetAllSetupKeys(b *testing.B) { for name, bc := range 
benchCasesSetupKeys { b.Run(name, func(b *testing.B) { - apiHandler, am, _ := testing_tools.BuildApiBlackBoxWithDBState(b, "../testdata/setup_keys.sql", nil) + apiHandler, am, _ := testing_tools.BuildApiBlackBoxWithDBState(b, "../testdata/setup_keys.sql", nil, false) testing_tools.PopulateTestData(b, am.(*server.DefaultAccountManager), bc.Peers, bc.Groups, bc.Users, bc.SetupKeys) b.ResetTimer() @@ -210,7 +210,7 @@ func BenchmarkDeleteSetupKey(b *testing.B) { for name, bc := range benchCasesSetupKeys { b.Run(name, func(b *testing.B) { - apiHandler, am, _ := testing_tools.BuildApiBlackBoxWithDBState(b, "../testdata/setup_keys.sql", nil) + apiHandler, am, _ := testing_tools.BuildApiBlackBoxWithDBState(b, "../testdata/setup_keys.sql", nil, false) testing_tools.PopulateTestData(b, am.(*server.DefaultAccountManager), bc.Peers, bc.Groups, bc.Users, 1000) b.ResetTimer() diff --git a/management/server/http/testing/integration/setupkeys_handler_integration_test.go b/management/server/http/testing/integration/setupkeys_handler_integration_test.go index e22a1a5a8..0d95c2635 100644 --- a/management/server/http/testing/integration/setupkeys_handler_integration_test.go +++ b/management/server/http/testing/integration/setupkeys_handler_integration_test.go @@ -284,7 +284,7 @@ func Test_SetupKeys_Create(t *testing.T) { for _, tc := range tt { for _, user := range users { t.Run(user.name+" - "+tc.name, func(t *testing.T) { - apiHandler, am, done := testing_tools.BuildApiBlackBoxWithDBState(t, "../testdata/setup_keys.sql", nil) + apiHandler, am, done := testing_tools.BuildApiBlackBoxWithDBState(t, "../testdata/setup_keys.sql", nil, true) body, err := json.Marshal(tc.requestBody) if err != nil { @@ -569,7 +569,7 @@ func Test_SetupKeys_Update(t *testing.T) { for _, tc := range tt { for _, user := range users { t.Run(tc.name, func(t *testing.T) { - apiHandler, am, done := testing_tools.BuildApiBlackBoxWithDBState(t, "../testdata/setup_keys.sql", nil) + apiHandler, am, done := testing_tools.BuildApiBlackBoxWithDBState(t, "../testdata/setup_keys.sql", nil, true) body, err := json.Marshal(tc.requestBody) if err != nil { @@ -748,7 +748,7 @@ func Test_SetupKeys_Get(t *testing.T) { for _, tc := range tt { for _, user := range users { t.Run(tc.name, func(t *testing.T) { - apiHandler, am, done := testing_tools.BuildApiBlackBoxWithDBState(t, "../testdata/setup_keys.sql", nil) + apiHandler, am, done := testing_tools.BuildApiBlackBoxWithDBState(t, "../testdata/setup_keys.sql", nil, true) req := testing_tools.BuildRequest(t, []byte{}, tc.requestType, strings.Replace(tc.requestPath, "{id}", tc.requestId, 1), user.userId) @@ -900,7 +900,7 @@ func Test_SetupKeys_GetAll(t *testing.T) { for _, tc := range tt { for _, user := range users { t.Run(tc.name, func(t *testing.T) { - apiHandler, am, done := testing_tools.BuildApiBlackBoxWithDBState(t, "../testdata/setup_keys.sql", nil) + apiHandler, am, done := testing_tools.BuildApiBlackBoxWithDBState(t, "../testdata/setup_keys.sql", nil, true) req := testing_tools.BuildRequest(t, []byte{}, tc.requestType, tc.requestPath, user.userId) @@ -1084,7 +1084,7 @@ func Test_SetupKeys_Delete(t *testing.T) { for _, tc := range tt { for _, user := range users { t.Run(tc.name, func(t *testing.T) { - apiHandler, am, done := testing_tools.BuildApiBlackBoxWithDBState(t, "../testdata/setup_keys.sql", nil) + apiHandler, am, done := testing_tools.BuildApiBlackBoxWithDBState(t, "../testdata/setup_keys.sql", nil, true) req := testing_tools.BuildRequest(t, []byte{}, tc.requestType, 
strings.Replace(tc.requestPath, "{id}", tc.requestId, 1), user.userId) diff --git a/management/server/http/testing/testdata/peers.sql b/management/server/http/testing/testdata/peers.sql new file mode 100644 index 000000000..03ff2d3d3 --- /dev/null +++ b/management/server/http/testing/testdata/peers.sql @@ -0,0 +1,22 @@ +CREATE TABLE `accounts` (`id` text,`created_by` text,`created_at` datetime,`domain` text,`domain_category` text,`is_domain_primary_account` numeric,`network_identifier` text,`network_net` text,`network_dns` text,`network_serial` integer,`dns_settings_disabled_management_groups` text,`settings_peer_login_expiration_enabled` numeric,`settings_peer_login_expiration` integer,`settings_regular_users_view_blocked` numeric,`settings_groups_propagation_enabled` numeric,`settings_jwt_groups_enabled` numeric,`settings_jwt_groups_claim_name` text,`settings_jwt_allow_groups` text,`settings_extra_peer_approval_enabled` numeric,`settings_extra_integrated_validator_groups` text,PRIMARY KEY (`id`)); +CREATE TABLE `users` (`id` text,`account_id` text,`role` text,`is_service_user` numeric,`non_deletable` numeric,`service_user_name` text,`auto_groups` text,`blocked` numeric,`last_login` datetime,`created_at` datetime,`issued` text DEFAULT "api",`integration_ref_id` integer,`integration_ref_integration_type` text,PRIMARY KEY (`id`),CONSTRAINT `fk_accounts_users_g` FOREIGN KEY (`account_id`) REFERENCES `accounts`(`id`)); +CREATE TABLE `groups` (`id` text,`account_id` text,`name` text,`issued` text,`peers` text,`integration_ref_id` integer,`integration_ref_integration_type` text,PRIMARY KEY (`id`),CONSTRAINT `fk_accounts_groups_g` FOREIGN KEY (`account_id`) REFERENCES `accounts`(`id`)); +CREATE TABLE `setup_keys` (`id` text,`account_id` text,`key` text,`key_secret` text,`name` text,`type` text,`created_at` datetime,`expires_at` datetime,`updated_at` datetime,`revoked` numeric,`used_times` integer,`last_used` datetime,`auto_groups` text,`usage_limit` integer,`ephemeral` numeric,PRIMARY KEY (`id`),CONSTRAINT `fk_accounts_setup_keys_g` FOREIGN KEY (`account_id`) REFERENCES `accounts`(`id`)); + +INSERT INTO accounts VALUES('testAccountId','','2024-10-02 16:01:38.000000000+00:00','test.com','private',1,'testNetworkIdentifier','{"IP":"100.64.0.0","Mask":"//8AAA=="}','',0,'[]',0,86400000000000,0,0,0,'',NULL,NULL,NULL); +INSERT INTO users VALUES('testUserId','testAccountId','user',0,0,'','[]',0,'0001-01-01 00:00:00+00:00','2024-10-02 16:01:38.000000000+00:00','api',0,''); +INSERT INTO users VALUES('testAdminId','testAccountId','admin',0,0,'','[]',0,'0001-01-01 00:00:00+00:00','2024-10-02 16:01:38.000000000+00:00','api',0,''); +INSERT INTO users VALUES('testOwnerId','testAccountId','owner',0,0,'','[]',0,'0001-01-01 00:00:00+00:00','2024-10-02 16:01:38.000000000+00:00','api',0,''); +INSERT INTO users VALUES('testServiceUserId','testAccountId','user',1,0,'','[]',0,'0001-01-01 00:00:00+00:00','2024-10-02 16:01:38.000000000+00:00','api',0,''); +INSERT INTO users VALUES('testServiceAdminId','testAccountId','admin',1,0,'','[]',0,'0001-01-01 00:00:00+00:00','2024-10-02 16:01:38.000000000+00:00','api',0,''); +INSERT INTO users VALUES('blockedUserId','testAccountId','admin',0,0,'','[]',1,'0001-01-01 00:00:00+00:00','2024-10-02 16:01:38.000000000+00:00','api',0,''); +INSERT INTO users VALUES('otherUserId','otherAccountId','admin',0,0,'','[]',0,'0001-01-01 00:00:00+00:00','2024-10-02 16:01:38.000000000+00:00','api',0,''); +INSERT INTO "groups" VALUES('testGroupId','testAccountId','testGroupName','api','[]',0,''); 
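The seed files users.sql and peers.sql only pin down the fixed IDs the handlers and tests reference (testAccountId, testAdminId, testPeerId, testGroupId, and so on); the bulk rows per benchmark case come from PopulateTestData. The seed itself is replayed into a fresh test store by store.NewTestStoreFromSQL, used by BuildApiBlackBoxWithDBState further below. As a rough, self-contained illustration of that idea (not the actual helper), such a file can be applied to an in-memory SQLite database like this; the semicolon splitting is deliberately naive:

package storetest

import (
	"os"
	"strings"
	"testing"

	"gorm.io/driver/sqlite"
	"gorm.io/gorm"
)

// loadSeed replays a .sql seed file into a throwaway in-memory SQLite
// database, roughly mirroring what a test-store helper needs to do.
func loadSeed(t *testing.T, path string) *gorm.DB {
	t.Helper()

	db, err := gorm.Open(sqlite.Open("file::memory:?cache=shared"), &gorm.Config{})
	if err != nil {
		t.Fatalf("open sqlite: %v", err)
	}

	raw, err := os.ReadFile(path)
	if err != nil {
		t.Fatalf("read seed file: %v", err)
	}

	for _, stmt := range strings.Split(string(raw), ";") {
		if strings.TrimSpace(stmt) == "" {
			continue
		}
		if err := db.Exec(stmt).Error; err != nil {
			t.Fatalf("apply seed statement: %v", err)
		}
	}
	return db
}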
+INSERT INTO "groups" VALUES('newGroupId','testAccountId','newGroupName','api','[]',0,''); +INSERT INTO setup_keys VALUES('testKeyId','testAccountId','testKey','testK****','existingKey','one-off','2021-08-19 20:46:20.000000000+00:00','2321-09-18 20:46:20.000000000+00:00','2021-08-19 20:46:20.000000000+00:000',0,0,'0001-01-01 00:00:00+00:00','["testGroupId"]',1,0); +INSERT INTO setup_keys VALUES('revokedKeyId','testAccountId','revokedKey','testK****','existingKey','reusable','2021-08-19 20:46:20.000000000+00:00','2321-09-18 20:46:20.000000000+00:00','2021-08-19 20:46:20.000000000+00:00',1,0,'0001-01-01 00:00:00+00:00','["testGroupId"]',3,0); +INSERT INTO setup_keys VALUES('expiredKeyId','testAccountId','expiredKey','testK****','existingKey','reusable','2021-08-19 20:46:20.000000000+00:00','1921-09-18 20:46:20.000000000+00:00','2021-08-19 20:46:20.000000000+00:00',0,1,'0001-01-01 00:00:00+00:00','["testGroupId"]',5,1); + +CREATE TABLE `peers` (`id` text,`account_id` text,`key` text,`setup_key` text,`ip` text,`meta_hostname` text,`meta_go_os` text,`meta_kernel` text,`meta_core` text,`meta_platform` text,`meta_os` text,`meta_os_version` text,`meta_wt_version` text,`meta_ui_version` text,`meta_kernel_version` text,`meta_network_addresses` text,`meta_system_serial_number` text,`meta_system_product_name` text,`meta_system_manufacturer` text,`meta_environment` text,`meta_files` text,`name` text,`dns_label` text,`peer_status_last_seen` datetime,`peer_status_connected` numeric,`peer_status_login_expired` numeric,`peer_status_requires_approval` numeric,`user_id` text,`ssh_key` text,`ssh_enabled` numeric,`login_expiration_enabled` numeric,`last_login` datetime,`created_at` datetime,`ephemeral` numeric,`location_connection_ip` text,`location_country_code` text,`location_city_name` text,`location_geo_name_id` integer,PRIMARY KEY (`id`),CONSTRAINT `fk_accounts_peers_g` FOREIGN KEY (`account_id`) REFERENCES `accounts`(`id`)); + +INSERT INTO peers VALUES('testPeerId','testAccountId','5rvhvriKJZ3S9oxYToVj5TzDM9u9y8cxg7htIMWlYAg=','72546A29-6BC8-4311-BCFC-9CDBF33F1A48','"100.64.114.31"','f2a34f6a4731','linux','Linux','11','unknown','Debian GNU/Linux','','0.12.0','','',NULL,'','','','{"Cloud":"","Platform":""}',NULL,'f2a34f6a4731','f2a34f6a4731','2023-03-02 09:21:02.189035775+01:00',0,0,0,'','ssh-ed25519 AAAAC3NzaC1lZDI1NTE5AAAAILzUUSYG/LGnV8zarb2SGN+tib/PZ+M7cL4WtTzUrTpk',0,0,'2023-03-01 19:48:19.817799698+01:00','2024-10-02 17:00:32.527947+02:00',0,'""','','',0); diff --git a/management/server/http/testing/testing_tools/tools.go b/management/server/http/testing/testing_tools/tools.go index da910c5c3..534836250 100644 --- a/management/server/http/testing/testing_tools/tools.go +++ b/management/server/http/testing/testing_tools/tools.go @@ -81,7 +81,7 @@ type PerformanceMetrics struct { MaxMsPerOpCICD float64 } -func BuildApiBlackBoxWithDBState(t TB, sqlFile string, expectedPeerUpdate *server.UpdateMessage) (http.Handler, server.AccountManager, chan struct{}) { +func BuildApiBlackBoxWithDBState(t TB, sqlFile string, expectedPeerUpdate *server.UpdateMessage, validateUpdate bool) (http.Handler, server.AccountManager, chan struct{}) { store, cleanup, err := store.NewTestStoreFromSQL(context.Background(), sqlFile, t.TempDir()) if err != nil { t.Fatalf("Failed to create test store: %v", err) @@ -96,14 +96,16 @@ func BuildApiBlackBoxWithDBState(t TB, sqlFile string, expectedPeerUpdate *serve peersUpdateManager := server.NewPeersUpdateManager(nil) updMsg := peersUpdateManager.CreateChannel(context.Background(), 
TestPeerId) done := make(chan struct{}) - go func() { - if expectedPeerUpdate != nil { - peerShouldReceiveUpdate(t, updMsg, expectedPeerUpdate) - } else { - peerShouldNotReceiveUpdate(t, updMsg) - } - close(done) - }() + if validateUpdate { + go func() { + if expectedPeerUpdate != nil { + peerShouldReceiveUpdate(t, updMsg, expectedPeerUpdate) + } else { + peerShouldNotReceiveUpdate(t, updMsg) + } + close(done) + }() + } geoMock := &geolocation.Mock{} validatorMock := server.MocIntegratedValidator{} From d9487a57497cf1d1ab0a1b871e6627870a300763 Mon Sep 17 00:00:00 2001 From: Pascal Fischer <32096965+pascal-fischer@users.noreply.github.com> Date: Fri, 3 Jan 2025 15:48:31 +0100 Subject: [PATCH 109/159] [misc] separate integration and benchmark test workflows (#3147) --- .github/workflows/golang-test-linux.yml | 98 +++++++++++++++++++ .../peers_handler_benchmark_test.go | 3 + .../setupkeys_handler_benchmark_test.go | 3 + .../users_handler_benchmark_test.go | 11 ++- .../setupkeys_handler_integration_test.go | 3 + 5 files changed, 114 insertions(+), 4 deletions(-) diff --git a/.github/workflows/golang-test-linux.yml b/.github/workflows/golang-test-linux.yml index 2de7a518a..742486234 100644 --- a/.github/workflows/golang-test-linux.yml +++ b/.github/workflows/golang-test-linux.yml @@ -234,6 +234,104 @@ jobs: - name: Test run: CGO_ENABLED=1 GOARCH=${{ matrix.arch }} NETBIRD_STORE_ENGINE=${{ matrix.store }} CI=true go test -run=^$ -bench=. -exec 'sudo --preserve-env=CI,NETBIRD_STORE_ENGINE' -timeout 20m ./... + api_benchmark: + needs: [ build-cache ] + strategy: + fail-fast: false + matrix: + arch: [ '386','amd64' ] + store: [ 'sqlite', 'postgres' ] + runs-on: ubuntu-22.04 + steps: + - name: Install Go + uses: actions/setup-go@v5 + with: + go-version: "1.23.x" + cache: false + + - name: Checkout code + uses: actions/checkout@v4 + + - name: Get Go environment + run: | + echo "cache=$(go env GOCACHE)" >> $GITHUB_ENV + echo "modcache=$(go env GOMODCACHE)" >> $GITHUB_ENV + + - name: Cache Go modules + uses: actions/cache/restore@v4 + with: + path: | + ${{ env.cache }} + ${{ env.modcache }} + key: ${{ runner.os }}-gotest-cache-${{ hashFiles('**/go.sum') }} + restore-keys: | + ${{ runner.os }}-gotest-cache- + + - name: Install dependencies + run: sudo apt update && sudo apt install -y -q libgtk-3-dev libayatana-appindicator3-dev libgl1-mesa-dev xorg-dev gcc-multilib libpcap-dev + + - name: Install 32-bit libpcap + if: matrix.arch == '386' + run: sudo dpkg --add-architecture i386 && sudo apt update && sudo apt-get install -y libpcap0.8-dev:i386 + + - name: Install modules + run: go mod tidy + + - name: check git status + run: git --no-pager diff --exit-code + + - name: Test + run: CGO_ENABLED=1 GOARCH=${{ matrix.arch }} NETBIRD_STORE_ENGINE=${{ matrix.store }} CI=true go test -tags=benchmark -run=^$ -bench=. -exec 'sudo --preserve-env=CI,NETBIRD_STORE_ENGINE' -timeout 10m $(go list ./... 
| grep /management) + + api_integration_test: + needs: [ build-cache ] + strategy: + fail-fast: false + matrix: + arch: [ '386','amd64' ] + store: [ 'sqlite', 'postgres'] + runs-on: ubuntu-22.04 + steps: + - name: Install Go + uses: actions/setup-go@v5 + with: + go-version: "1.23.x" + cache: false + + - name: Checkout code + uses: actions/checkout@v4 + + - name: Get Go environment + run: | + echo "cache=$(go env GOCACHE)" >> $GITHUB_ENV + echo "modcache=$(go env GOMODCACHE)" >> $GITHUB_ENV + + - name: Cache Go modules + uses: actions/cache/restore@v4 + with: + path: | + ${{ env.cache }} + ${{ env.modcache }} + key: ${{ runner.os }}-gotest-cache-${{ hashFiles('**/go.sum') }} + restore-keys: | + ${{ runner.os }}-gotest-cache- + + - name: Install dependencies + run: sudo apt update && sudo apt install -y -q libgtk-3-dev libayatana-appindicator3-dev libgl1-mesa-dev xorg-dev gcc-multilib libpcap-dev + + - name: Install 32-bit libpcap + if: matrix.arch == '386' + run: sudo dpkg --add-architecture i386 && sudo apt update && sudo apt-get install -y libpcap0.8-dev:i386 + + - name: Install modules + run: go mod tidy + + - name: check git status + run: git --no-pager diff --exit-code + + - name: Test + run: CGO_ENABLED=1 GOARCH=${{ matrix.arch }} NETBIRD_STORE_ENGINE=${{ matrix.store }} CI=true go test -p 1 -exec 'sudo --preserve-env=CI,NETBIRD_STORE_ENGINE' -timeout 10m -tags=integration $(go list ./... | grep /management) + test_client_on_docker: needs: [ build-cache ] runs-on: ubuntu-20.04 diff --git a/management/server/http/testing/benchmarks/peers_handler_benchmark_test.go b/management/server/http/testing/benchmarks/peers_handler_benchmark_test.go index 829649798..e76370426 100644 --- a/management/server/http/testing/benchmarks/peers_handler_benchmark_test.go +++ b/management/server/http/testing/benchmarks/peers_handler_benchmark_test.go @@ -1,3 +1,6 @@ +//go:build benchmark +// +build benchmark + package benchmarks import ( diff --git a/management/server/http/testing/benchmarks/setupkeys_handler_benchmark_test.go b/management/server/http/testing/benchmarks/setupkeys_handler_benchmark_test.go index ef4815762..bbdb4250b 100644 --- a/management/server/http/testing/benchmarks/setupkeys_handler_benchmark_test.go +++ b/management/server/http/testing/benchmarks/setupkeys_handler_benchmark_test.go @@ -1,3 +1,6 @@ +//go:build benchmark +// +build benchmark + package benchmarks import ( diff --git a/management/server/http/testing/benchmarks/users_handler_benchmark_test.go b/management/server/http/testing/benchmarks/users_handler_benchmark_test.go index f453ba7c8..b62341995 100644 --- a/management/server/http/testing/benchmarks/users_handler_benchmark_test.go +++ b/management/server/http/testing/benchmarks/users_handler_benchmark_test.go @@ -1,3 +1,6 @@ +//go:build benchmark +// +build benchmark + package benchmarks import ( @@ -49,7 +52,7 @@ func BenchmarkUpdateUser(b *testing.B) { for name, bc := range benchCasesUsers { b.Run(name, func(b *testing.B) { - apiHandler, am, _ := testing_tools.BuildApiBlackBoxWithDBState(b, "../testdata/users.sql", nil) + apiHandler, am, _ := testing_tools.BuildApiBlackBoxWithDBState(b, "../testdata/users.sql", nil, false) testing_tools.PopulateTestData(b, am.(*server.DefaultAccountManager), bc.Peers, bc.Groups, bc.Users, bc.SetupKeys) b.ResetTimer() @@ -98,7 +101,7 @@ func BenchmarkGetOneUser(b *testing.B) { for name, bc := range benchCasesUsers { b.Run(name, func(b *testing.B) { - apiHandler, am, _ := testing_tools.BuildApiBlackBoxWithDBState(b, "../testdata/users.sql", 
nil) + apiHandler, am, _ := testing_tools.BuildApiBlackBoxWithDBState(b, "../testdata/users.sql", nil, false) testing_tools.PopulateTestData(b, am.(*server.DefaultAccountManager), bc.Peers, bc.Groups, bc.Users, bc.SetupKeys) b.ResetTimer() @@ -132,7 +135,7 @@ func BenchmarkGetAllUsers(b *testing.B) { for name, bc := range benchCasesUsers { b.Run(name, func(b *testing.B) { - apiHandler, am, _ := testing_tools.BuildApiBlackBoxWithDBState(b, "../testdata/users.sql", nil) + apiHandler, am, _ := testing_tools.BuildApiBlackBoxWithDBState(b, "../testdata/users.sql", nil, false) testing_tools.PopulateTestData(b, am.(*server.DefaultAccountManager), bc.Peers, bc.Groups, bc.Users, bc.SetupKeys) b.ResetTimer() @@ -166,7 +169,7 @@ func BenchmarkDeleteUsers(b *testing.B) { for name, bc := range benchCasesUsers { b.Run(name, func(b *testing.B) { - apiHandler, am, _ := testing_tools.BuildApiBlackBoxWithDBState(b, "../testdata/users.sql", nil) + apiHandler, am, _ := testing_tools.BuildApiBlackBoxWithDBState(b, "../testdata/users.sql", nil, false) testing_tools.PopulateTestData(b, am.(*server.DefaultAccountManager), bc.Peers, bc.Groups, 1000, bc.SetupKeys) b.ResetTimer() diff --git a/management/server/http/testing/integration/setupkeys_handler_integration_test.go b/management/server/http/testing/integration/setupkeys_handler_integration_test.go index 0d95c2635..ed6e642a2 100644 --- a/management/server/http/testing/integration/setupkeys_handler_integration_test.go +++ b/management/server/http/testing/integration/setupkeys_handler_integration_test.go @@ -1,3 +1,6 @@ +//go:build integration +// +build integration + package integration import ( From 02a3feddb8713f450530f67465aa447c057d5afd Mon Sep 17 00:00:00 2001 From: Bethuel Mmbaga Date: Mon, 6 Jan 2025 15:38:30 +0300 Subject: [PATCH 110/159] [management] Add MySQL Support (#3108) * Add mysql store support * Add support to disable activity events recording --- .github/workflows/golang-test-linux.yml | 39 ++++++++- .../workflows/test-infrastructure-files.yml | 23 ++++- go.mod | 6 +- go.sum | 13 ++- infrastructure_files/configure.sh | 12 +++ infrastructure_files/docker-compose.yml.tmpl | 1 + .../docker-compose.yml.tmpl.traefik | 1 + management/server/account.go | 5 +- management/server/account_test.go | 44 +++++----- management/server/event.go | 38 +++++---- management/server/group_test.go | 1 + .../http/handlers/peers/peers_handler.go | 4 +- .../handlers/routes/routes_handler_test.go | 17 ++-- .../handlers/setup_keys/setupkeys_handler.go | 4 +- .../server/http/handlers/users/pat_handler.go | 9 +- .../http/handlers/users/pat_handler_test.go | 13 +-- .../server/http/middleware/auth_middleware.go | 2 +- .../http/middleware/auth_middleware_test.go | 5 +- .../server/http/testing/testdata/peers.sql | 24 +++--- .../http/testing/testdata/setup_keys.sql | 26 +++--- .../server/http/testing/testdata/users.sql | 24 +++--- .../http/testing/testing_tools/tools.go | 6 +- management/server/management_proto_test.go | 10 ++- management/server/migration/migration.go | 24 ++++-- management/server/peer.go | 7 +- management/server/peer/peer.go | 17 +++- management/server/peer_test.go | 15 ++-- management/server/route_test.go | 10 +-- management/server/setupkey_test.go | 8 +- management/server/store/file_store.go | 5 +- management/server/store/sql_store.go | 83 ++++++++++++++++--- management/server/store/store.go | 45 ++++++++-- management/server/testdata/extended-store.sql | 12 +-- management/server/testdata/store.sql | 10 +-- .../server/testdata/store_policy_migrate.sql | 10 +-- 
.../testdata/store_with_expired_peers.sql | 10 +-- management/server/testutil/store.go | 67 +++++++++++---- management/server/testutil/store_ios.go | 12 ++- .../server/types/personal_access_token.go | 24 +++++- management/server/types/setupkey.go | 31 +++++-- management/server/types/user.go | 18 ++-- management/server/user.go | 2 +- management/server/user_test.go | 7 +- management/server/util/util.go | 5 ++ 44 files changed, 525 insertions(+), 224 deletions(-) diff --git a/.github/workflows/golang-test-linux.yml b/.github/workflows/golang-test-linux.yml index 742486234..da1db5c03 100644 --- a/.github/workflows/golang-test-linux.yml +++ b/.github/workflows/golang-test-linux.yml @@ -142,7 +142,7 @@ jobs: fail-fast: false matrix: arch: [ '386','amd64' ] - store: [ 'sqlite', 'postgres'] + store: [ 'sqlite', 'postgres', 'mysql' ] runs-on: ubuntu-22.04 steps: - name: Install Go @@ -182,6 +182,17 @@ jobs: - name: check git status run: git --no-pager diff --exit-code + - name: Login to Docker hub + if: matrix.store == 'mysql' + uses: docker/login-action@v1 + with: + username: ${{ secrets.DOCKER_USER }} + password: ${{ secrets.DOCKER_TOKEN }} + + - name: download mysql image + if: matrix.store == 'mysql' + run: docker pull mlsmaycon/warmed-mysql:8 + - name: Test run: CGO_ENABLED=1 GOARCH=${{ matrix.arch }} NETBIRD_STORE_ENGINE=${{ matrix.store }} CI=true go test -p 1 -exec 'sudo --preserve-env=CI,NETBIRD_STORE_ENGINE' -timeout 10m $(go list ./... | grep /management) @@ -191,7 +202,7 @@ jobs: fail-fast: false matrix: arch: [ '386','amd64' ] - store: [ 'sqlite', 'postgres' ] + store: [ 'sqlite', 'postgres', 'mysql' ] runs-on: ubuntu-22.04 steps: - name: Install Go @@ -231,6 +242,17 @@ jobs: - name: check git status run: git --no-pager diff --exit-code + - name: Login to Docker hub + if: matrix.store == 'mysql' + uses: docker/login-action@v1 + with: + username: ${{ secrets.DOCKER_USER }} + password: ${{ secrets.DOCKER_TOKEN }} + + - name: download mysql image + if: matrix.store == 'mysql' + run: docker pull mlsmaycon/warmed-mysql:8 + - name: Test run: CGO_ENABLED=1 GOARCH=${{ matrix.arch }} NETBIRD_STORE_ENGINE=${{ matrix.store }} CI=true go test -run=^$ -bench=. -exec 'sudo --preserve-env=CI,NETBIRD_STORE_ENGINE' -timeout 20m ./... @@ -240,7 +262,7 @@ jobs: fail-fast: false matrix: arch: [ '386','amd64' ] - store: [ 'sqlite', 'postgres' ] + store: [ 'sqlite', 'postgres', 'mysql' ] runs-on: ubuntu-22.04 steps: - name: Install Go @@ -280,6 +302,17 @@ jobs: - name: check git status run: git --no-pager diff --exit-code + - name: Login to Docker hub + if: matrix.store == 'mysql' + uses: docker/login-action@v1 + with: + username: ${{ secrets.DOCKER_USER }} + password: ${{ secrets.DOCKER_TOKEN }} + + - name: download mysql image + if: matrix.store == 'mysql' + run: docker pull mlsmaycon/warmed-mysql:8 + - name: Test run: CGO_ENABLED=1 GOARCH=${{ matrix.arch }} NETBIRD_STORE_ENGINE=${{ matrix.store }} CI=true go test -tags=benchmark -run=^$ -bench=. -exec 'sudo --preserve-env=CI,NETBIRD_STORE_ENGINE' -timeout 10m $(go list ./... 
| grep /management) diff --git a/.github/workflows/test-infrastructure-files.yml b/.github/workflows/test-infrastructure-files.yml index da3ec746a..5a3c6c22e 100644 --- a/.github/workflows/test-infrastructure-files.yml +++ b/.github/workflows/test-infrastructure-files.yml @@ -20,7 +20,7 @@ jobs: runs-on: ubuntu-latest strategy: matrix: - store: [ 'sqlite', 'postgres' ] + store: [ 'sqlite', 'postgres', 'mysql' ] services: postgres: image: ${{ (matrix.store == 'postgres') && 'postgres' || '' }} @@ -34,6 +34,19 @@ jobs: --health-timeout 5s ports: - 5432:5432 + mysql: + image: ${{ (matrix.store == 'mysql') && 'mysql' || '' }} + env: + MYSQL_USER: netbird + MYSQL_PASSWORD: mysql + MYSQL_ROOT_PASSWORD: mysqlroot + MYSQL_DATABASE: netbird + options: >- + --health-cmd "mysqladmin ping --silent" + --health-interval 10s + --health-timeout 5s + ports: + - 3306:3306 steps: - name: Set Database Connection String run: | @@ -42,6 +55,11 @@ jobs: else echo "NETBIRD_STORE_ENGINE_POSTGRES_DSN==" >> $GITHUB_ENV fi + if [ "${{ matrix.store }}" == "mysql" ]; then + echo "NETBIRD_STORE_ENGINE_MYSQL_DSN=netbird:mysql@tcp($(hostname -I | awk '{print $1}'):3306)/netbird" >> $GITHUB_ENV + else + echo "NETBIRD_STORE_ENGINE_MYSQL_DSN==" >> $GITHUB_ENV + fi - name: Install jq run: sudo apt-get install -y jq @@ -84,6 +102,7 @@ jobs: CI_NETBIRD_AUTH_SUPPORTED_SCOPES: "openid profile email offline_access api email_verified" CI_NETBIRD_STORE_CONFIG_ENGINE: ${{ matrix.store }} NETBIRD_STORE_ENGINE_POSTGRES_DSN: ${{ env.NETBIRD_STORE_ENGINE_POSTGRES_DSN }} + NETBIRD_STORE_ENGINE_MYSQL_DSN: ${{ env.NETBIRD_STORE_ENGINE_MYSQL_DSN }} CI_NETBIRD_MGMT_IDP_SIGNKEY_REFRESH: false - name: check values @@ -112,6 +131,7 @@ jobs: CI_NETBIRD_SIGNAL_PORT: 12345 CI_NETBIRD_STORE_CONFIG_ENGINE: ${{ matrix.store }} NETBIRD_STORE_ENGINE_POSTGRES_DSN: '${{ env.NETBIRD_STORE_ENGINE_POSTGRES_DSN }}$' + NETBIRD_STORE_ENGINE_MYSQL_DSN: '${{ env.NETBIRD_STORE_ENGINE_MYSQL_DSN }}$' CI_NETBIRD_MGMT_IDP_SIGNKEY_REFRESH: false CI_NETBIRD_TURN_EXTERNAL_IP: "1.2.3.4" @@ -149,6 +169,7 @@ jobs: grep -A 10 PKCEAuthorizationFlow management.json | grep -A 10 ProviderConfig | grep Scope | grep "$CI_NETBIRD_AUTH_SUPPORTED_SCOPES" grep -A 10 PKCEAuthorizationFlow management.json | grep -A 10 ProviderConfig | grep -A 3 RedirectURLs | grep "http://localhost:53000" grep "external-ip" turnserver.conf | grep $CI_NETBIRD_TURN_EXTERNAL_IP + grep "NETBIRD_STORE_ENGINE_MYSQL_DSN=$NETBIRD_STORE_ENGINE_MYSQL_DSN" docker-compose.yml grep NETBIRD_STORE_ENGINE_POSTGRES_DSN docker-compose.yml | egrep "$NETBIRD_STORE_ENGINE_POSTGRES_DSN" # check relay values grep "NB_EXPOSED_ADDRESS=$CI_NETBIRD_DOMAIN:33445" docker-compose.yml diff --git a/go.mod b/go.mod index d48280df0..9d7cb1fe4 100644 --- a/go.mod +++ b/go.mod @@ -77,6 +77,7 @@ require ( github.com/songgao/water v0.0.0-20200317203138-2b4b6d7c09d8 github.com/stretchr/testify v1.9.0 github.com/testcontainers/testcontainers-go v0.31.0 + github.com/testcontainers/testcontainers-go/modules/mysql v0.31.0 github.com/testcontainers/testcontainers-go/modules/postgres v0.31.0 github.com/things-go/go-socks5 v0.0.4 github.com/yusufpapurcu/wmi v1.2.4 @@ -96,9 +97,10 @@ require ( golang.org/x/term v0.27.0 google.golang.org/api v0.177.0 gopkg.in/yaml.v3 v3.0.1 + gorm.io/driver/mysql v1.5.7 gorm.io/driver/postgres v1.5.7 gorm.io/driver/sqlite v1.5.3 - gorm.io/gorm v1.25.7-0.20240204074919-46816ad31dde + gorm.io/gorm v1.25.7 nhooyr.io/websocket v1.8.11 ) @@ -107,6 +109,7 @@ require ( cloud.google.com/go/auth/oauth2adapt v0.2.2 // indirect 
cloud.google.com/go/compute/metadata v0.3.0 // indirect dario.cat/mergo v1.0.0 // indirect + filippo.io/edwards25519 v1.1.0 // indirect github.com/Azure/go-ansiterm v0.0.0-20230124172434-306776ec8161 // indirect github.com/BurntSushi/toml v1.4.0 // indirect github.com/Microsoft/go-winio v0.6.2 // indirect @@ -151,6 +154,7 @@ require ( github.com/go-logr/stdr v1.2.2 // indirect github.com/go-ole/go-ole v1.3.0 // indirect github.com/go-redis/redis/v8 v8.11.5 // indirect + github.com/go-sql-driver/mysql v1.8.1 // indirect github.com/go-text/render v0.1.0 // indirect github.com/go-text/typesetting v0.1.0 // indirect github.com/gogo/protobuf v1.3.2 // indirect diff --git a/go.sum b/go.sum index 540cbf20b..8383475a4 100644 --- a/go.sum +++ b/go.sum @@ -48,6 +48,8 @@ cunicu.li/go-rosenpass v0.4.0/go.mod h1:MPbjH9nxV4l3vEagKVdFNwHOketqgS5/To1VYJpl dario.cat/mergo v1.0.0 h1:AGCNq9Evsj31mOgNPcLyXc+4PNABt905YmuqPYYpBWk= dario.cat/mergo v1.0.0/go.mod h1:uNxQE+84aUszobStD9th8a29P2fMDhsBdgRYvZOxGmk= dmitri.shuralyov.com/gpu/mtl v0.0.0-20190408044501-666a987793e9/go.mod h1:H6x//7gZCb22OMCxBHrMx7a5I7Hp++hsVxbQ4BYO7hU= +filippo.io/edwards25519 v1.1.0 h1:FNf4tywRC1HmFuKW5xopWpigGjJKiJSV0Cqo0cJWDaA= +filippo.io/edwards25519 v1.1.0/go.mod h1:BxyFTGdWcka3PhytdK4V28tE5sGfRvvvRV7EaN4VDT4= fyne.io/fyne/v2 v2.5.0 h1:lEjEIso0Vi4sJXYngIMoXOM6aUjqnPjK7pBpxRxG9aI= fyne.io/fyne/v2 v2.5.0/go.mod h1:9D4oT3NWeG+MLi/lP7ItZZyujHC/qqMJpoGTAYX5Uqc= fyne.io/systray v1.11.0 h1:D9HISlxSkx+jHSniMBR6fCFOUjk1x/OOOJLa9lJYAKg= @@ -238,6 +240,9 @@ github.com/go-quicktest/qt v1.101.0 h1:O1K29Txy5P2OK0dGo59b7b0LR6wKfIhttaAhHUyn7 github.com/go-quicktest/qt v1.101.0/go.mod h1:14Bz/f7NwaXPtdYEgzsx46kqSxVwTbzVZsDC26tQJow= github.com/go-redis/redis/v8 v8.11.5 h1:AcZZR7igkdvfVmQTPnu9WE37LRrO/YrBH5zWyjDC0oI= github.com/go-redis/redis/v8 v8.11.5/go.mod h1:gREzHqY1hg6oD9ngVRbLStwAWKhA0FEgq8Jd4h5lpwo= +github.com/go-sql-driver/mysql v1.7.0/go.mod h1:OXbVy3sEdcQ2Doequ6Z5BW6fXNQTmx+9S1MCJN5yJMI= +github.com/go-sql-driver/mysql v1.8.1 h1:LedoTUt/eveggdHS9qUFC1EFSa8bU2+1pZjSRpvNJ1Y= +github.com/go-sql-driver/mysql v1.8.1/go.mod h1:wEBSXgmK//2ZFJyE+qWnIsVGmvmEKlqwuVSjsCm7DZg= github.com/go-task/slim-sprig v0.0.0-20210107165309-348f09dbbbc0/go.mod h1:fyg7847qk6SyHyPtNmDHnmrv/HOrqktSC+C9fM+CJOE= github.com/go-task/slim-sprig v0.0.0-20230315185526-52ccab3ef572 h1:tfuBGBXKqDEevZMzYi5KSi8KkcZtzBcTgAUUtapy0OI= github.com/go-task/slim-sprig v0.0.0-20230315185526-52ccab3ef572/go.mod h1:9Pwr4B2jHnOSGXyyzV8ROjYa2ojvAY6HCGYYfMoC3Ls= @@ -681,6 +686,8 @@ github.com/stretchr/testify v1.9.0/go.mod h1:r2ic/lqez/lEtzL7wO/rwa5dbSLXVDPFyf8 github.com/subosito/gotenv v1.2.0/go.mod h1:N0PQaV/YGNqwC0u51sEeR/aUtSLEXKX9iv69rRypqCw= github.com/testcontainers/testcontainers-go v0.31.0 h1:W0VwIhcEVhRflwL9as3dhY6jXjVCA27AkmbnZ+UTh3U= github.com/testcontainers/testcontainers-go v0.31.0/go.mod h1:D2lAoA0zUFiSY+eAflqK5mcUx/A5hrrORaEQrd0SefI= +github.com/testcontainers/testcontainers-go/modules/mysql v0.31.0 h1:790+S8ewZYCbG+o8IiFlZ8ZZ33XbNO6zV9qhU6xhlRk= +github.com/testcontainers/testcontainers-go/modules/mysql v0.31.0/go.mod h1:REFmO+lSG9S6uSBEwIMZCxeI36uhScjTwChYADeO3JA= github.com/testcontainers/testcontainers-go/modules/postgres v0.31.0 h1:isAwFS3KNKRbJMbWv+wolWqOFUECmjYZ+sIRZCIBc/E= github.com/testcontainers/testcontainers-go/modules/postgres v0.31.0/go.mod h1:ZNYY8vumNCEG9YI59A9d6/YaMY49uwRhmeU563EzFGw= github.com/things-go/go-socks5 v0.0.4 h1:jMQjIc+qhD4z9cITOMnBiwo9dDmpGuXmBlkRFrl/qD0= @@ -1225,12 +1232,14 @@ gopkg.in/yaml.v3 v3.0.0-20200313102051-9f266ea9e77c/go.mod 
h1:K4uyk7z7BCEPqu6E+C gopkg.in/yaml.v3 v3.0.0-20210107192922-496545a6307b/go.mod h1:K4uyk7z7BCEPqu6E+C64Yfv1cQ7kz7rIZviUmN+EgEM= gopkg.in/yaml.v3 v3.0.1 h1:fxVm/GzAzEWqLHuvctI91KS9hhNmmWOoWu0XTYJS7CA= gopkg.in/yaml.v3 v3.0.1/go.mod h1:K4uyk7z7BCEPqu6E+C64Yfv1cQ7kz7rIZviUmN+EgEM= +gorm.io/driver/mysql v1.5.7 h1:MndhOPYOfEp2rHKgkZIhJ16eVUIRf2HmzgoPmh7FCWo= +gorm.io/driver/mysql v1.5.7/go.mod h1:sEtPWMiqiN1N1cMXoXmBbd8C6/l+TESwriotuRRpkDM= gorm.io/driver/postgres v1.5.7 h1:8ptbNJTDbEmhdr62uReG5BGkdQyeasu/FZHxI0IMGnM= gorm.io/driver/postgres v1.5.7/go.mod h1:3e019WlBaYI5o5LIdNV+LyxCMNtLOQETBXL2h4chKpA= gorm.io/driver/sqlite v1.5.3 h1:7/0dUgX28KAcopdfbRWWl68Rflh6osa4rDh+m51KL2g= gorm.io/driver/sqlite v1.5.3/go.mod h1:qxAuCol+2r6PannQDpOP1FP6ag3mKi4esLnB/jHed+4= -gorm.io/gorm v1.25.7-0.20240204074919-46816ad31dde h1:9DShaph9qhkIYw7QF91I/ynrr4cOO2PZra2PFD7Mfeg= -gorm.io/gorm v1.25.7-0.20240204074919-46816ad31dde/go.mod h1:hbnx/Oo0ChWMn1BIhpy1oYozzpM15i4YPuHDmfYtwg8= +gorm.io/gorm v1.25.7 h1:VsD6acwRjz2zFxGO50gPO6AkNs7KKnvfzUjHQhZDz/A= +gorm.io/gorm v1.25.7/go.mod h1:hbnx/Oo0ChWMn1BIhpy1oYozzpM15i4YPuHDmfYtwg8= gotest.tools/v3 v3.5.0 h1:Ljk6PdHdOhAb5aDMWXjDLMMhph+BpztA4v1QdqEW2eY= gotest.tools/v3 v3.5.0/go.mod h1:isy3WKz7GK6uNw/sbHzfKBLvlvXwUyV06n6brMxxopU= gvisor.dev/gvisor v0.0.0-20231020174304-b8a429915ff1 h1:qDCwdCWECGnwQSQC01Dpnp09fRHxJs9PbktotUqG+hs= diff --git a/infrastructure_files/configure.sh b/infrastructure_files/configure.sh index ff33004b2..d02e4f40c 100755 --- a/infrastructure_files/configure.sh +++ b/infrastructure_files/configure.sh @@ -53,6 +53,18 @@ if [[ "$NETBIRD_STORE_CONFIG_ENGINE" == "postgres" ]]; then export NETBIRD_STORE_ENGINE_POSTGRES_DSN fi +# Check if MySQL is set as the store engine +if [[ "$NETBIRD_STORE_CONFIG_ENGINE" == "mysql" ]]; then + # Exit if 'NETBIRD_STORE_ENGINE_MYSQL_DSN' is not set + if [[ -z "$NETBIRD_STORE_ENGINE_MYSQL_DSN" ]]; then + echo "Warning: NETBIRD_STORE_CONFIG_ENGINE=mysql but NETBIRD_STORE_ENGINE_MYSQL_DSN is not set." 
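The DSN checked for here uses the go-sql-driver/mysql format, user:password@tcp(host:port)/dbname, the same shape as the netbird:mysql@tcp(...)/netbird value exported in the CI workflow above. For orientation, a minimal sketch of how a DSN like this is opened with gorm's MySQL driver, which this patch adds to go.mod (illustrative; the management server's actual sql_store.go wiring is not shown here):

package main

import (
	"log"
	"os"

	"gorm.io/driver/mysql"
	"gorm.io/gorm"
)

func main() {
	// e.g. NETBIRD_STORE_ENGINE_MYSQL_DSN="netbird:mysql@tcp(127.0.0.1:3306)/netbird"
	dsn := os.Getenv("NETBIRD_STORE_ENGINE_MYSQL_DSN")

	db, err := gorm.Open(mysql.Open(dsn), &gorm.Config{})
	if err != nil {
		log.Fatalf("connect to mysql: %v", err)
	}
	_ = db // the store layer would run its migrations and queries from here
}

SQLite and Postgres follow the same pattern with their respective gorm dialectors (gorm.io/driver/sqlite, gorm.io/driver/postgres), so switching engines largely comes down to selecting a different dialector.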
+ echo "Please add the following line to your setup.env file:" + echo 'NETBIRD_STORE_ENGINE_MYSQL_DSN=":@tcp(127.0.0.1:3306)/"' + exit 1 + fi + export NETBIRD_STORE_ENGINE_MYSQL_DSN +fi + # local development or tests if [[ $NETBIRD_DOMAIN == "localhost" || $NETBIRD_DOMAIN == "127.0.0.1" ]]; then export NETBIRD_MGMT_SINGLE_ACCOUNT_MODE_DOMAIN="netbird.selfhosted" diff --git a/infrastructure_files/docker-compose.yml.tmpl b/infrastructure_files/docker-compose.yml.tmpl index ba68b3f8d..b7904fb5b 100644 --- a/infrastructure_files/docker-compose.yml.tmpl +++ b/infrastructure_files/docker-compose.yml.tmpl @@ -96,6 +96,7 @@ services: max-file: "2" environment: - NETBIRD_STORE_ENGINE_POSTGRES_DSN=$NETBIRD_STORE_ENGINE_POSTGRES_DSN + - NETBIRD_STORE_ENGINE_MYSQL_DSN=$NETBIRD_STORE_ENGINE_MYSQL_DSN # Coturn coturn: diff --git a/infrastructure_files/docker-compose.yml.tmpl.traefik b/infrastructure_files/docker-compose.yml.tmpl.traefik index c4415d848..7d51c4ffb 100644 --- a/infrastructure_files/docker-compose.yml.tmpl.traefik +++ b/infrastructure_files/docker-compose.yml.tmpl.traefik @@ -83,6 +83,7 @@ services: - traefik.http.services.netbird-management.loadbalancer.server.scheme=h2c environment: - NETBIRD_STORE_ENGINE_POSTGRES_DSN=$NETBIRD_STORE_ENGINE_POSTGRES_DSN + - NETBIRD_STORE_ENGINE_MYSQL_DSN=$NETBIRD_STORE_ENGINE_MYSQL_DSN # Coturn coturn: diff --git a/management/server/account.go b/management/server/account.go index 6c8205f26..41da7f079 100644 --- a/management/server/account.go +++ b/management/server/account.go @@ -438,7 +438,6 @@ func (am *DefaultAccountManager) handleGroupsPropagationSettings(ctx context.Con } func (am *DefaultAccountManager) handleInactivityExpirationSettings(ctx context.Context, account *types.Account, oldSettings, newSettings *types.Settings, userID, accountID string) error { - if newSettings.PeerInactivityExpirationEnabled { if oldSettings.PeerInactivityExpiration != newSettings.PeerInactivityExpiration { oldSettings.PeerInactivityExpiration = newSettings.PeerInactivityExpiration @@ -790,7 +789,7 @@ func (am *DefaultAccountManager) lookupUserInCache(ctx context.Context, userID s if user.Issued == types.UserIssuedIntegration { continue } - users[user.Id] = userLoggedInOnce(!user.LastLogin.IsZero()) + users[user.Id] = userLoggedInOnce(!user.GetLastLogin().IsZero()) } log.WithContext(ctx).Debugf("looking up user %s of account %s in cache", userID, account.Id) userData, err := am.lookupCache(ctx, users, account.Id) @@ -1135,7 +1134,7 @@ func (am *DefaultAccountManager) MarkPATUsed(ctx context.Context, tokenID string return fmt.Errorf("token not found") } - pat.LastUsed = time.Now().UTC() + pat.LastUsed = util.ToPtr(time.Now().UTC()) return am.Store.SaveAccount(ctx, account) } diff --git a/management/server/account_test.go b/management/server/account_test.go index 39a748192..d8ceef0e7 100644 --- a/management/server/account_test.go +++ b/management/server/account_test.go @@ -16,6 +16,7 @@ import ( "time" "github.com/golang-jwt/jwt" + "github.com/netbirdio/netbird/management/server/util" resourceTypes "github.com/netbirdio/netbird/management/server/networks/resources/types" routerTypes "github.com/netbirdio/netbird/management/server/networks/routers/types" @@ -145,7 +146,7 @@ func TestAccount_GetPeerNetworkMap(t *testing.T) { LoginExpired: true, }, UserID: userID, - LastLogin: time.Now().UTC().Add(-time.Hour * 24 * 30 * 30), + LastLogin: util.ToPtr(time.Now().UTC().Add(-time.Hour * 24 * 30 * 30)), }, "peer-2": { ID: peerID2, @@ -159,7 +160,7 @@ func 
TestAccount_GetPeerNetworkMap(t *testing.T) { LoginExpired: false, }, UserID: userID, - LastLogin: time.Now().UTC(), + LastLogin: util.ToPtr(time.Now().UTC()), LoginExpirationEnabled: true, }, }, @@ -183,7 +184,7 @@ func TestAccount_GetPeerNetworkMap(t *testing.T) { LoginExpired: true, }, UserID: userID, - LastLogin: time.Now().UTC().Add(-time.Hour * 24 * 30 * 30), + LastLogin: util.ToPtr(time.Now().UTC().Add(-time.Hour * 24 * 30 * 30)), LoginExpirationEnabled: true, }, "peer-2": { @@ -198,7 +199,7 @@ func TestAccount_GetPeerNetworkMap(t *testing.T) { LoginExpired: true, }, UserID: userID, - LastLogin: time.Now().UTC().Add(-time.Hour * 24 * 30 * 30), + LastLogin: util.ToPtr(time.Now().UTC().Add(-time.Hour * 24 * 30 * 30)), LoginExpirationEnabled: true, }, }, @@ -771,7 +772,6 @@ func TestDefaultAccountManager_MarkPATUsed(t *testing.T) { "tokenId": { ID: "tokenId", HashedToken: encodedHashedToken, - LastUsed: time.Time{}, }, }, } @@ -793,7 +793,7 @@ func TestDefaultAccountManager_MarkPATUsed(t *testing.T) { if err != nil { t.Fatalf("Error when getting account: %s", err) } - assert.True(t, !account.Users["someUser"].PATs["tokenId"].LastUsed.IsZero()) + assert.True(t, !account.Users["someUser"].PATs["tokenId"].GetLastUsed().IsZero()) } func TestAccountManager_PrivateAccount(t *testing.T) { @@ -1054,7 +1054,7 @@ func genUsers(p string, n int) map[string]*types.User { users[fmt.Sprintf("%s-%d", p, i)] = &types.User{ Id: fmt.Sprintf("%s-%d", p, i), Role: types.UserRoleAdmin, - LastLogin: now, + LastLogin: util.ToPtr(now), CreatedAt: now, Issued: "api", AutoGroups: []string{"one", "two", "three", "four", "five", "six", "seven", "eight", "nine", "ten"}, @@ -1706,10 +1706,10 @@ func TestAccount_Copy(t *testing.T) { ID: "pat1", Name: "First PAT", HashedToken: "SoMeHaShEdToKeN", - ExpirationDate: time.Now().UTC().AddDate(0, 0, 7), + ExpirationDate: util.ToPtr(time.Now().UTC().AddDate(0, 0, 7)), CreatedBy: "user1", CreatedAt: time.Now().UTC(), - LastUsed: time.Now().UTC(), + LastUsed: util.ToPtr(time.Now().UTC()), }, }, }, @@ -2065,7 +2065,7 @@ func TestAccount_GetExpiredPeers(t *testing.T) { Connected: true, LoginExpired: false, }, - LastLogin: time.Now().UTC().Add(-30 * time.Minute), + LastLogin: util.ToPtr(time.Now().UTC().Add(-30 * time.Minute)), UserID: userID, }, "peer-2": { @@ -2076,7 +2076,7 @@ func TestAccount_GetExpiredPeers(t *testing.T) { Connected: true, LoginExpired: false, }, - LastLogin: time.Now().UTC().Add(-2 * time.Hour), + LastLogin: util.ToPtr(time.Now().UTC().Add(-2 * time.Hour)), UserID: userID, }, @@ -2088,7 +2088,7 @@ func TestAccount_GetExpiredPeers(t *testing.T) { Connected: true, LoginExpired: false, }, - LastLogin: time.Now().UTC().Add(-1 * time.Hour), + LastLogin: util.ToPtr(time.Now().UTC().Add(-1 * time.Hour)), UserID: userID, }, }, @@ -2150,7 +2150,7 @@ func TestAccount_GetInactivePeers(t *testing.T) { Connected: false, LoginExpired: false, }, - LastLogin: time.Now().UTC().Add(-30 * time.Minute), + LastLogin: util.ToPtr(time.Now().UTC().Add(-30 * time.Minute)), UserID: userID, }, "peer-2": { @@ -2161,7 +2161,7 @@ func TestAccount_GetInactivePeers(t *testing.T) { Connected: false, LoginExpired: false, }, - LastLogin: time.Now().UTC().Add(-2 * time.Hour), + LastLogin: util.ToPtr(time.Now().UTC().Add(-2 * time.Hour)), UserID: userID, }, "peer-3": { @@ -2172,7 +2172,7 @@ func TestAccount_GetInactivePeers(t *testing.T) { Connected: true, LoginExpired: false, }, - LastLogin: time.Now().UTC().Add(-1 * time.Hour), + LastLogin: util.ToPtr(time.Now().UTC().Add(-1 * time.Hour)), 
UserID: userID, }, }, @@ -2442,7 +2442,7 @@ func TestAccount_GetNextPeerExpiration(t *testing.T) { LoginExpired: false, }, LoginExpirationEnabled: true, - LastLogin: time.Now().UTC(), + LastLogin: util.ToPtr(time.Now().UTC()), UserID: userID, }, "peer-2": { @@ -2602,7 +2602,7 @@ func TestAccount_GetNextInactivePeerExpiration(t *testing.T) { LastSeen: time.Now().Add(-1 * time.Second), }, InactivityExpirationEnabled: true, - LastLogin: time.Now().UTC(), + LastLogin: util.ToPtr(time.Now().UTC()), UserID: userID, }, "peer-2": { @@ -2680,8 +2680,8 @@ func TestAccount_SetJWTGroups(t *testing.T) { }, Settings: &types.Settings{GroupsPropagationEnabled: true, JWTGroupsEnabled: true, JWTGroupsClaimName: "groups"}, Users: map[string]*types.User{ - "user1": {Id: "user1", AccountID: "accountID"}, - "user2": {Id: "user2", AccountID: "accountID"}, + "user1": {Id: "user1", AccountID: "accountID", CreatedAt: time.Now()}, + "user2": {Id: "user2", AccountID: "accountID", CreatedAt: time.Now()}, }, } @@ -3021,8 +3021,8 @@ func BenchmarkSyncAndMarkPeer(b *testing.B) { minMsPerOpCICD float64 maxMsPerOpCICD float64 }{ - {"Small", 50, 5, 1, 3, 3, 11}, - {"Medium", 500, 100, 7, 13, 10, 70}, + {"Small", 50, 5, 1, 3, 3, 14}, + {"Medium", 500, 100, 7, 13, 10, 80}, {"Large", 5000, 200, 65, 80, 60, 220}, {"Small single", 50, 10, 1, 3, 3, 70}, {"Medium single", 500, 10, 7, 13, 10, 32}, @@ -3164,7 +3164,7 @@ func BenchmarkLoginPeer_NewPeer(b *testing.B) { }{ {"Small", 50, 5, 107, 120, 107, 160}, {"Medium", 500, 100, 105, 140, 105, 220}, - {"Large", 5000, 200, 180, 220, 180, 350}, + {"Large", 5000, 200, 180, 220, 180, 395}, {"Small single", 50, 10, 107, 120, 105, 160}, {"Medium single", 500, 10, 105, 140, 105, 170}, {"Large 5", 5000, 15, 180, 220, 180, 340}, diff --git a/management/server/event.go b/management/server/event.go index 93b809226..788d1b51c 100644 --- a/management/server/event.go +++ b/management/server/event.go @@ -3,6 +3,7 @@ package server import ( "context" "fmt" + "os" "time" log "github.com/sirupsen/logrus" @@ -11,6 +12,11 @@ import ( "github.com/netbirdio/netbird/management/server/status" ) +func isEnabled() bool { + response := os.Getenv("NB_EVENT_ACTIVITY_LOG_ENABLED") + return response == "" || response == "true" +} + // GetEvents returns a list of activity events of an account func (am *DefaultAccountManager) GetEvents(ctx context.Context, accountID, userID string) ([]*activity.Event, error) { unlock := am.Store.AcquireWriteLockByUID(ctx, accountID) @@ -56,20 +62,20 @@ func (am *DefaultAccountManager) GetEvents(ctx context.Context, accountID, userI } func (am *DefaultAccountManager) StoreEvent(ctx context.Context, initiatorID, targetID, accountID string, activityID activity.ActivityDescriber, meta map[string]any) { - - go func() { - _, err := am.eventStore.Save(ctx, &activity.Event{ - Timestamp: time.Now().UTC(), - Activity: activityID, - InitiatorID: initiatorID, - TargetID: targetID, - AccountID: accountID, - Meta: meta, - }) - if err != nil { - // todo add metric - log.WithContext(ctx).Errorf("received an error while storing an activity event, error: %s", err) - } - }() - + if isEnabled() { + go func() { + _, err := am.eventStore.Save(ctx, &activity.Event{ + Timestamp: time.Now().UTC(), + Activity: activityID, + InitiatorID: initiatorID, + TargetID: targetID, + AccountID: accountID, + Meta: meta, + }) + if err != nil { + // todo add metric + log.WithContext(ctx).Errorf("received an error while storing an activity event, error: %s", err) + } + }() + } } diff --git a/management/server/group_test.go 
b/management/server/group_test.go index 834388d1e..cc90f187b 100644 --- a/management/server/group_test.go +++ b/management/server/group_test.go @@ -355,6 +355,7 @@ func initTestGroupAccount(am *DefaultAccountManager) (*DefaultAccountManager, *t setupKey := &types.SetupKey{ Id: "example setup key", AutoGroups: []string{groupForSetupKeys.ID}, + UpdatedAt: time.Now(), } user := &types.User{ diff --git a/management/server/http/handlers/peers/peers_handler.go b/management/server/http/handlers/peers/peers_handler.go index 5bc616e47..7eb8e2153 100644 --- a/management/server/http/handlers/peers/peers_handler.go +++ b/management/server/http/handlers/peers/peers_handler.go @@ -349,7 +349,7 @@ func toSinglePeerResponse(peer *nbpeer.Peer, groupsInfo []api.GroupMinimum, dnsD UiVersion: peer.Meta.UIVersion, DnsLabel: fqdn(peer, dnsDomain), LoginExpirationEnabled: peer.LoginExpirationEnabled, - LastLogin: peer.LastLogin, + LastLogin: peer.GetLastLogin(), LoginExpired: peer.Status.LoginExpired, ApprovalRequired: !approved, CountryCode: peer.Location.CountryCode, @@ -383,7 +383,7 @@ func toPeerListItemResponse(peer *nbpeer.Peer, groupsInfo []api.GroupMinimum, dn UiVersion: peer.Meta.UIVersion, DnsLabel: fqdn(peer, dnsDomain), LoginExpirationEnabled: peer.LoginExpirationEnabled, - LastLogin: peer.LastLogin, + LastLogin: peer.GetLastLogin(), LoginExpired: peer.Status.LoginExpired, AccessiblePeersCount: accessiblePeersCount, CountryCode: peer.Location.CountryCode, diff --git a/management/server/http/handlers/routes/routes_handler_test.go b/management/server/http/handlers/routes/routes_handler_test.go index 4cee3ee30..45c465587 100644 --- a/management/server/http/handlers/routes/routes_handler_test.go +++ b/management/server/http/handlers/routes/routes_handler_test.go @@ -11,6 +11,7 @@ import ( "net/netip" "testing" + "github.com/netbirdio/netbird/management/server/util" "github.com/stretchr/testify/require" "github.com/netbirdio/netbird/management/server/http/api" @@ -239,7 +240,7 @@ func TestRoutesHandlers(t *testing.T) { Id: existingRouteID, Description: "Post", NetworkId: "awesomeNet", - Network: toPtr("192.168.0.0/16"), + Network: util.ToPtr("192.168.0.0/16"), Peer: &existingPeerID, NetworkType: route.IPv4NetworkString, Masquerade: false, @@ -259,7 +260,7 @@ func TestRoutesHandlers(t *testing.T) { Id: existingRouteID, Description: "Post", NetworkId: "domainNet", - Network: toPtr("invalid Prefix"), + Network: util.ToPtr("invalid Prefix"), KeepRoute: true, Domains: &[]string{existingDomain}, Peer: &existingPeerID, @@ -281,7 +282,7 @@ func TestRoutesHandlers(t *testing.T) { Id: existingRouteID, Description: "Post", NetworkId: "awesomeNet", - Network: toPtr("192.168.0.0/16"), + Network: util.ToPtr("192.168.0.0/16"), Peer: &existingPeerID, NetworkType: route.IPv4NetworkString, Masquerade: false, @@ -385,7 +386,7 @@ func TestRoutesHandlers(t *testing.T) { Id: existingRouteID, Description: "Post", NetworkId: "awesomeNet", - Network: toPtr("192.168.0.0/16"), + Network: util.ToPtr("192.168.0.0/16"), Peer: &existingPeerID, NetworkType: route.IPv4NetworkString, Masquerade: false, @@ -404,7 +405,7 @@ func TestRoutesHandlers(t *testing.T) { Id: existingRouteID, Description: "Post", NetworkId: "awesomeNet", - Network: toPtr("invalid Prefix"), + Network: util.ToPtr("invalid Prefix"), Domains: &[]string{existingDomain}, Peer: &existingPeerID, NetworkType: route.DomainNetworkString, @@ -425,7 +426,7 @@ func TestRoutesHandlers(t *testing.T) { Id: existingRouteID, Description: "Post", NetworkId: "awesomeNet", - Network: 
toPtr("192.168.0.0/16"), + Network: util.ToPtr("192.168.0.0/16"), Peer: &emptyString, PeerGroups: &[]string{existingGroupID}, NetworkType: route.IPv4NetworkString, @@ -663,7 +664,3 @@ func toApiRoute(t *testing.T, r *route.Route) *api.Route { require.NoError(t, err, "Failed to convert route") return apiRoute } - -func toPtr[T any](v T) *T { - return &v -} diff --git a/management/server/http/handlers/setup_keys/setupkeys_handler.go b/management/server/http/handlers/setup_keys/setupkeys_handler.go index a627d7203..67e296901 100644 --- a/management/server/http/handlers/setup_keys/setupkeys_handler.go +++ b/management/server/http/handlers/setup_keys/setupkeys_handler.go @@ -240,12 +240,12 @@ func ToResponseBody(key *types.SetupKey) *api.SetupKey { Id: key.Id, Key: key.KeySecret, Name: key.Name, - Expires: key.ExpiresAt, + Expires: key.GetExpiresAt(), Type: string(key.Type), Valid: key.IsValid(), Revoked: key.Revoked, UsedTimes: key.UsedTimes, - LastUsed: key.LastUsed, + LastUsed: key.GetLastUsed(), State: state, AutoGroups: key.AutoGroups, UpdatedAt: key.UpdatedAt, diff --git a/management/server/http/handlers/users/pat_handler.go b/management/server/http/handlers/users/pat_handler.go index 197785b34..7b93d2ae1 100644 --- a/management/server/http/handlers/users/pat_handler.go +++ b/management/server/http/handlers/users/pat_handler.go @@ -3,7 +3,6 @@ package users import ( "encoding/json" "net/http" - "time" "github.com/gorilla/mux" @@ -166,17 +165,13 @@ func (h *patHandler) deleteToken(w http.ResponseWriter, r *http.Request) { } func toPATResponse(pat *types.PersonalAccessToken) *api.PersonalAccessToken { - var lastUsed *time.Time - if !pat.LastUsed.IsZero() { - lastUsed = &pat.LastUsed - } return &api.PersonalAccessToken{ CreatedAt: pat.CreatedAt, CreatedBy: pat.CreatedBy, Name: pat.Name, - ExpirationDate: pat.ExpirationDate, + ExpirationDate: pat.GetExpirationDate(), Id: pat.ID, - LastUsed: lastUsed, + LastUsed: pat.LastUsed, } } diff --git a/management/server/http/handlers/users/pat_handler_test.go b/management/server/http/handlers/users/pat_handler_test.go index 21bdc461e..9388067a4 100644 --- a/management/server/http/handlers/users/pat_handler_test.go +++ b/management/server/http/handlers/users/pat_handler_test.go @@ -12,6 +12,7 @@ import ( "github.com/google/go-cmp/cmp" "github.com/gorilla/mux" + "github.com/netbirdio/netbird/management/server/util" "github.com/stretchr/testify/assert" "github.com/netbirdio/netbird/management/server/http/api" @@ -42,19 +43,19 @@ var testAccount = &types.Account{ ID: existingTokenID, Name: "My first token", HashedToken: "someHash", - ExpirationDate: time.Now().UTC().AddDate(0, 0, 7), + ExpirationDate: util.ToPtr(time.Now().UTC().AddDate(0, 0, 7)), CreatedBy: existingUserID, CreatedAt: time.Now().UTC(), - LastUsed: time.Now().UTC(), + LastUsed: util.ToPtr(time.Now().UTC()), }, "token2": { ID: "token2", Name: "My second token", HashedToken: "someOtherHash", - ExpirationDate: time.Now().UTC().AddDate(0, 0, 7), + ExpirationDate: util.ToPtr(time.Now().UTC().AddDate(0, 0, 7)), CreatedBy: existingUserID, CreatedAt: time.Now().UTC(), - LastUsed: time.Now().UTC(), + LastUsed: util.ToPtr(time.Now().UTC()), }, }, }, @@ -248,8 +249,8 @@ func toTokenResponse(serverToken types.PersonalAccessToken) api.PersonalAccessTo Id: serverToken.ID, Name: serverToken.Name, CreatedAt: serverToken.CreatedAt, - LastUsed: &serverToken.LastUsed, + LastUsed: serverToken.LastUsed, CreatedBy: serverToken.CreatedBy, - ExpirationDate: serverToken.ExpirationDate, + ExpirationDate: 
serverToken.GetExpirationDate(), } } diff --git a/management/server/http/middleware/auth_middleware.go b/management/server/http/middleware/auth_middleware.go index 0a54cbaed..182c30cf6 100644 --- a/management/server/http/middleware/auth_middleware.go +++ b/management/server/http/middleware/auth_middleware.go @@ -161,7 +161,7 @@ func (m *AuthMiddleware) checkPATFromRequest(w http.ResponseWriter, r *http.Requ if err != nil { return fmt.Errorf("invalid Token: %w", err) } - if time.Now().After(pat.ExpirationDate) { + if time.Now().After(pat.GetExpirationDate()) { return fmt.Errorf("token expired") } diff --git a/management/server/http/middleware/auth_middleware_test.go b/management/server/http/middleware/auth_middleware_test.go index b0d970c5d..41bdb7fc5 100644 --- a/management/server/http/middleware/auth_middleware_test.go +++ b/management/server/http/middleware/auth_middleware_test.go @@ -9,6 +9,7 @@ import ( "time" "github.com/golang-jwt/jwt" + "github.com/netbirdio/netbird/management/server/util" "github.com/netbirdio/netbird/management/server/http/middleware/bypass" "github.com/netbirdio/netbird/management/server/jwtclaims" @@ -39,10 +40,10 @@ var testAccount = &types.Account{ ID: tokenID, Name: "My first token", HashedToken: "someHash", - ExpirationDate: time.Now().UTC().AddDate(0, 0, 7), + ExpirationDate: util.ToPtr(time.Now().UTC().AddDate(0, 0, 7)), CreatedBy: userID, CreatedAt: time.Now().UTC(), - LastUsed: time.Now().UTC(), + LastUsed: util.ToPtr(time.Now().UTC()), }, }, }, diff --git a/management/server/http/testing/testdata/peers.sql b/management/server/http/testing/testdata/peers.sql index 03ff2d3d3..863eda520 100644 --- a/management/server/http/testing/testdata/peers.sql +++ b/management/server/http/testing/testdata/peers.sql @@ -1,21 +1,21 @@ CREATE TABLE `accounts` (`id` text,`created_by` text,`created_at` datetime,`domain` text,`domain_category` text,`is_domain_primary_account` numeric,`network_identifier` text,`network_net` text,`network_dns` text,`network_serial` integer,`dns_settings_disabled_management_groups` text,`settings_peer_login_expiration_enabled` numeric,`settings_peer_login_expiration` integer,`settings_regular_users_view_blocked` numeric,`settings_groups_propagation_enabled` numeric,`settings_jwt_groups_enabled` numeric,`settings_jwt_groups_claim_name` text,`settings_jwt_allow_groups` text,`settings_extra_peer_approval_enabled` numeric,`settings_extra_integrated_validator_groups` text,PRIMARY KEY (`id`)); -CREATE TABLE `users` (`id` text,`account_id` text,`role` text,`is_service_user` numeric,`non_deletable` numeric,`service_user_name` text,`auto_groups` text,`blocked` numeric,`last_login` datetime,`created_at` datetime,`issued` text DEFAULT "api",`integration_ref_id` integer,`integration_ref_integration_type` text,PRIMARY KEY (`id`),CONSTRAINT `fk_accounts_users_g` FOREIGN KEY (`account_id`) REFERENCES `accounts`(`id`)); +CREATE TABLE `users` (`id` text,`account_id` text,`role` text,`is_service_user` numeric,`non_deletable` numeric,`service_user_name` text,`auto_groups` text,`blocked` numeric,`last_login` datetime DEFAULT NULL,`created_at` datetime,`issued` text DEFAULT "api",`integration_ref_id` integer,`integration_ref_integration_type` text,PRIMARY KEY (`id`),CONSTRAINT `fk_accounts_users_g` FOREIGN KEY (`account_id`) REFERENCES `accounts`(`id`)); CREATE TABLE `groups` (`id` text,`account_id` text,`name` text,`issued` text,`peers` text,`integration_ref_id` integer,`integration_ref_integration_type` text,PRIMARY KEY (`id`),CONSTRAINT `fk_accounts_groups_g` 
FOREIGN KEY (`account_id`) REFERENCES `accounts`(`id`)); -CREATE TABLE `setup_keys` (`id` text,`account_id` text,`key` text,`key_secret` text,`name` text,`type` text,`created_at` datetime,`expires_at` datetime,`updated_at` datetime,`revoked` numeric,`used_times` integer,`last_used` datetime,`auto_groups` text,`usage_limit` integer,`ephemeral` numeric,PRIMARY KEY (`id`),CONSTRAINT `fk_accounts_setup_keys_g` FOREIGN KEY (`account_id`) REFERENCES `accounts`(`id`)); +CREATE TABLE `setup_keys` (`id` text,`account_id` text,`key` text,`key_secret` text,`name` text,`type` text,`created_at` datetime,`expires_at` datetime,`updated_at` datetime,`revoked` numeric,`used_times` integer,`last_used` datetime DEFAULT NULL,`auto_groups` text,`usage_limit` integer,`ephemeral` numeric,PRIMARY KEY (`id`),CONSTRAINT `fk_accounts_setup_keys_g` FOREIGN KEY (`account_id`) REFERENCES `accounts`(`id`)); INSERT INTO accounts VALUES('testAccountId','','2024-10-02 16:01:38.000000000+00:00','test.com','private',1,'testNetworkIdentifier','{"IP":"100.64.0.0","Mask":"//8AAA=="}','',0,'[]',0,86400000000000,0,0,0,'',NULL,NULL,NULL); -INSERT INTO users VALUES('testUserId','testAccountId','user',0,0,'','[]',0,'0001-01-01 00:00:00+00:00','2024-10-02 16:01:38.000000000+00:00','api',0,''); -INSERT INTO users VALUES('testAdminId','testAccountId','admin',0,0,'','[]',0,'0001-01-01 00:00:00+00:00','2024-10-02 16:01:38.000000000+00:00','api',0,''); -INSERT INTO users VALUES('testOwnerId','testAccountId','owner',0,0,'','[]',0,'0001-01-01 00:00:00+00:00','2024-10-02 16:01:38.000000000+00:00','api',0,''); -INSERT INTO users VALUES('testServiceUserId','testAccountId','user',1,0,'','[]',0,'0001-01-01 00:00:00+00:00','2024-10-02 16:01:38.000000000+00:00','api',0,''); -INSERT INTO users VALUES('testServiceAdminId','testAccountId','admin',1,0,'','[]',0,'0001-01-01 00:00:00+00:00','2024-10-02 16:01:38.000000000+00:00','api',0,''); -INSERT INTO users VALUES('blockedUserId','testAccountId','admin',0,0,'','[]',1,'0001-01-01 00:00:00+00:00','2024-10-02 16:01:38.000000000+00:00','api',0,''); -INSERT INTO users VALUES('otherUserId','otherAccountId','admin',0,0,'','[]',0,'0001-01-01 00:00:00+00:00','2024-10-02 16:01:38.000000000+00:00','api',0,''); +INSERT INTO users VALUES('testUserId','testAccountId','user',0,0,'','[]',0,NULL,'2024-10-02 16:01:38.000000000+00:00','api',0,''); +INSERT INTO users VALUES('testAdminId','testAccountId','admin',0,0,'','[]',0,NULL,'2024-10-02 16:01:38.000000000+00:00','api',0,''); +INSERT INTO users VALUES('testOwnerId','testAccountId','owner',0,0,'','[]',0,NULL,'2024-10-02 16:01:38.000000000+00:00','api',0,''); +INSERT INTO users VALUES('testServiceUserId','testAccountId','user',1,0,'','[]',0,NULL,'2024-10-02 16:01:38.000000000+00:00','api',0,''); +INSERT INTO users VALUES('testServiceAdminId','testAccountId','admin',1,0,'','[]',0,NULL,'2024-10-02 16:01:38.000000000+00:00','api',0,''); +INSERT INTO users VALUES('blockedUserId','testAccountId','admin',0,0,'','[]',1,NULL,'2024-10-02 16:01:38.000000000+00:00','api',0,''); +INSERT INTO users VALUES('otherUserId','otherAccountId','admin',0,0,'','[]',0,NULL,'2024-10-02 16:01:38.000000000+00:00','api',0,''); INSERT INTO "groups" VALUES('testGroupId','testAccountId','testGroupName','api','[]',0,''); INSERT INTO "groups" VALUES('newGroupId','testAccountId','newGroupName','api','[]',0,''); -INSERT INTO setup_keys VALUES('testKeyId','testAccountId','testKey','testK****','existingKey','one-off','2021-08-19 20:46:20.000000000+00:00','2321-09-18 20:46:20.000000000+00:00','2021-08-19 
20:46:20.000000000+00:000',0,0,'0001-01-01 00:00:00+00:00','["testGroupId"]',1,0); -INSERT INTO setup_keys VALUES('revokedKeyId','testAccountId','revokedKey','testK****','existingKey','reusable','2021-08-19 20:46:20.000000000+00:00','2321-09-18 20:46:20.000000000+00:00','2021-08-19 20:46:20.000000000+00:00',1,0,'0001-01-01 00:00:00+00:00','["testGroupId"]',3,0); -INSERT INTO setup_keys VALUES('expiredKeyId','testAccountId','expiredKey','testK****','existingKey','reusable','2021-08-19 20:46:20.000000000+00:00','1921-09-18 20:46:20.000000000+00:00','2021-08-19 20:46:20.000000000+00:00',0,1,'0001-01-01 00:00:00+00:00','["testGroupId"]',5,1); +INSERT INTO setup_keys VALUES('testKeyId','testAccountId','testKey','testK****','existingKey','one-off','2021-08-19 20:46:20.000000000+00:00','2321-09-18 20:46:20.000000000+00:00','2021-08-19 20:46:20.000000000+00:000',0,0,NULL,'["testGroupId"]',1,0); +INSERT INTO setup_keys VALUES('revokedKeyId','testAccountId','revokedKey','testK****','existingKey','reusable','2021-08-19 20:46:20.000000000+00:00','2321-09-18 20:46:20.000000000+00:00','2021-08-19 20:46:20.000000000+00:00',1,0,NULL,'["testGroupId"]',3,0); +INSERT INTO setup_keys VALUES('expiredKeyId','testAccountId','expiredKey','testK****','existingKey','reusable','2021-08-19 20:46:20.000000000+00:00','1921-09-18 20:46:20.000000000+00:00','2021-08-19 20:46:20.000000000+00:00',0,1,NULL,'["testGroupId"]',5,1); CREATE TABLE `peers` (`id` text,`account_id` text,`key` text,`setup_key` text,`ip` text,`meta_hostname` text,`meta_go_os` text,`meta_kernel` text,`meta_core` text,`meta_platform` text,`meta_os` text,`meta_os_version` text,`meta_wt_version` text,`meta_ui_version` text,`meta_kernel_version` text,`meta_network_addresses` text,`meta_system_serial_number` text,`meta_system_product_name` text,`meta_system_manufacturer` text,`meta_environment` text,`meta_files` text,`name` text,`dns_label` text,`peer_status_last_seen` datetime,`peer_status_connected` numeric,`peer_status_login_expired` numeric,`peer_status_requires_approval` numeric,`user_id` text,`ssh_key` text,`ssh_enabled` numeric,`login_expiration_enabled` numeric,`last_login` datetime,`created_at` datetime,`ephemeral` numeric,`location_connection_ip` text,`location_country_code` text,`location_city_name` text,`location_geo_name_id` integer,PRIMARY KEY (`id`),CONSTRAINT `fk_accounts_peers_g` FOREIGN KEY (`account_id`) REFERENCES `accounts`(`id`)); diff --git a/management/server/http/testing/testdata/setup_keys.sql b/management/server/http/testing/testdata/setup_keys.sql index a315ea0f7..6d30fb5fe 100644 --- a/management/server/http/testing/testdata/setup_keys.sql +++ b/management/server/http/testing/testdata/setup_keys.sql @@ -1,24 +1,24 @@ CREATE TABLE `accounts` (`id` text,`created_by` text,`created_at` datetime,`domain` text,`domain_category` text,`is_domain_primary_account` numeric,`network_identifier` text,`network_net` text,`network_dns` text,`network_serial` integer,`dns_settings_disabled_management_groups` text,`settings_peer_login_expiration_enabled` numeric,`settings_peer_login_expiration` integer,`settings_regular_users_view_blocked` numeric,`settings_groups_propagation_enabled` numeric,`settings_jwt_groups_enabled` numeric,`settings_jwt_groups_claim_name` text,`settings_jwt_allow_groups` text,`settings_extra_peer_approval_enabled` numeric,`settings_extra_integrated_validator_groups` text,PRIMARY KEY (`id`)); -CREATE TABLE `users` (`id` text,`account_id` text,`role` text,`is_service_user` numeric,`non_deletable` numeric,`service_user_name` 
text,`auto_groups` text,`blocked` numeric,`last_login` datetime,`created_at` datetime,`issued` text DEFAULT "api",`integration_ref_id` integer,`integration_ref_integration_type` text,PRIMARY KEY (`id`),CONSTRAINT `fk_accounts_users_g` FOREIGN KEY (`account_id`) REFERENCES `accounts`(`id`)); -CREATE TABLE `peers` (`id` text,`account_id` text,`key` text,`setup_key` text,`ip` text,`meta_hostname` text,`meta_go_os` text,`meta_kernel` text,`meta_core` text,`meta_platform` text,`meta_os` text,`meta_os_version` text,`meta_wt_version` text,`meta_ui_version` text,`meta_kernel_version` text,`meta_network_addresses` text,`meta_system_serial_number` text,`meta_system_product_name` text,`meta_system_manufacturer` text,`meta_environment` text,`meta_files` text,`name` text,`dns_label` text,`peer_status_last_seen` datetime,`peer_status_connected` numeric,`peer_status_login_expired` numeric,`peer_status_requires_approval` numeric,`user_id` text,`ssh_key` text,`ssh_enabled` numeric,`login_expiration_enabled` numeric,`last_login` datetime,`created_at` datetime,`ephemeral` numeric,`location_connection_ip` text,`location_country_code` text,`location_city_name` text,`location_geo_name_id` integer,PRIMARY KEY (`id`),CONSTRAINT `fk_accounts_peers_g` FOREIGN KEY (`account_id`) REFERENCES `accounts`(`id`)); +CREATE TABLE `users` (`id` text,`account_id` text,`role` text,`is_service_user` numeric,`non_deletable` numeric,`service_user_name` text,`auto_groups` text,`blocked` numeric,`last_login` datetime DEFAULT NULL,`created_at` datetime,`issued` text DEFAULT "api",`integration_ref_id` integer,`integration_ref_integration_type` text,PRIMARY KEY (`id`),CONSTRAINT `fk_accounts_users_g` FOREIGN KEY (`account_id`) REFERENCES `accounts`(`id`)); +CREATE TABLE `peers` (`id` text,`account_id` text,`key` text,`setup_key` text,`ip` text,`meta_hostname` text,`meta_go_os` text,`meta_kernel` text,`meta_core` text,`meta_platform` text,`meta_os` text,`meta_os_version` text,`meta_wt_version` text,`meta_ui_version` text,`meta_kernel_version` text,`meta_network_addresses` text,`meta_system_serial_number` text,`meta_system_product_name` text,`meta_system_manufacturer` text,`meta_environment` text,`meta_files` text,`name` text,`dns_label` text,`peer_status_last_seen` datetime DEFAULT NULL,`peer_status_connected` numeric,`peer_status_login_expired` numeric,`peer_status_requires_approval` numeric,`user_id` text,`ssh_key` text,`ssh_enabled` numeric,`login_expiration_enabled` numeric,`last_login` datetime DEFAULT NULL,`created_at` datetime,`ephemeral` numeric,`location_connection_ip` text,`location_country_code` text,`location_city_name` text,`location_geo_name_id` integer,PRIMARY KEY (`id`),CONSTRAINT `fk_accounts_peers_g` FOREIGN KEY (`account_id`) REFERENCES `accounts`(`id`)); CREATE TABLE `groups` (`id` text,`account_id` text,`name` text,`issued` text,`peers` text,`integration_ref_id` integer,`integration_ref_integration_type` text,PRIMARY KEY (`id`),CONSTRAINT `fk_accounts_groups_g` FOREIGN KEY (`account_id`) REFERENCES `accounts`(`id`)); INSERT INTO accounts VALUES('testAccountId','','2024-10-02 16:01:38.000000000+00:00','test.com','private',1,'testNetworkIdentifier','{"IP":"100.64.0.0","Mask":"//8AAA=="}','',0,'[]',0,86400000000000,0,0,0,'',NULL,NULL,NULL); -INSERT INTO users VALUES('testUserId','testAccountId','user',0,0,'','[]',0,'0001-01-01 00:00:00+00:00','2024-10-02 16:01:38.000000000+00:00','api',0,''); -INSERT INTO users VALUES('testAdminId','testAccountId','admin',0,0,'','[]',0,'0001-01-01 00:00:00+00:00','2024-10-02 
16:01:38.000000000+00:00','api',0,''); -INSERT INTO users VALUES('testOwnerId','testAccountId','owner',0,0,'','[]',0,'0001-01-01 00:00:00+00:00','2024-10-02 16:01:38.000000000+00:00','api',0,''); -INSERT INTO users VALUES('testServiceUserId','testAccountId','user',1,0,'','[]',0,'0001-01-01 00:00:00+00:00','2024-10-02 16:01:38.000000000+00:00','api',0,''); -INSERT INTO users VALUES('testServiceAdminId','testAccountId','admin',1,0,'','[]',0,'0001-01-01 00:00:00+00:00','2024-10-02 16:01:38.000000000+00:00','api',0,''); -INSERT INTO users VALUES('blockedUserId','testAccountId','admin',0,0,'','[]',1,'0001-01-01 00:00:00+00:00','2024-10-02 16:01:38.000000000+00:00','api',0,''); -INSERT INTO users VALUES('otherUserId','otherAccountId','admin',0,0,'','[]',0,'0001-01-01 00:00:00+00:00','2024-10-02 16:01:38.000000000+00:00','api',0,''); +INSERT INTO users VALUES('testUserId','testAccountId','user',0,0,'','[]',0,NULL,'2024-10-02 16:01:38.000000000+00:00','api',0,''); +INSERT INTO users VALUES('testAdminId','testAccountId','admin',0,0,'','[]',0,NULL,'2024-10-02 16:01:38.000000000+00:00','api',0,''); +INSERT INTO users VALUES('testOwnerId','testAccountId','owner',0,0,'','[]',0,NULL,'2024-10-02 16:01:38.000000000+00:00','api',0,''); +INSERT INTO users VALUES('testServiceUserId','testAccountId','user',1,0,'','[]',0,NULL,'2024-10-02 16:01:38.000000000+00:00','api',0,''); +INSERT INTO users VALUES('testServiceAdminId','testAccountId','admin',1,0,'','[]',0,NULL,'2024-10-02 16:01:38.000000000+00:00','api',0,''); +INSERT INTO users VALUES('blockedUserId','testAccountId','admin',0,0,'','[]',1,NULL,'2024-10-02 16:01:38.000000000+00:00','api',0,''); +INSERT INTO users VALUES('otherUserId','otherAccountId','admin',0,0,'','[]',0,NULL,'2024-10-02 16:01:38.000000000+00:00','api',0,''); INSERT INTO peers VALUES('testPeerId','testAccountId','5rvhvriKJZ3S9oxYToVj5TzDM9u9y8cxg7htIMWlYAg=','72546A29-6BC8-4311-BCFC-9CDBF33F1A48','"100.64.114.31"','f2a34f6a4731','linux','Linux','11','unknown','Debian GNU/Linux','','0.12.0','','',NULL,'','','','{"Cloud":"","Platform":""}',NULL,'f2a34f6a4731','f2a34f6a4731','2023-03-02 09:21:02.189035775+01:00',0,0,0,'','ssh-ed25519 AAAAC3NzaC1lZDI1NTE5AAAAILzUUSYG/LGnV8zarb2SGN+tib/PZ+M7cL4WtTzUrTpk',0,1,'2023-03-01 19:48:19.817799698+01:00','2024-10-02 17:00:32.527947+02:00',0,'""','','',0); INSERT INTO "groups" VALUES('testGroupId','testAccountId','testGroupName','api','[]',0,''); INSERT INTO "groups" VALUES('newGroupId','testAccountId','newGroupName','api','[]',0,''); -CREATE TABLE `setup_keys` (`id` text,`account_id` text,`key` text,`key_secret` text,`name` text,`type` text,`created_at` datetime,`expires_at` datetime,`updated_at` datetime,`revoked` numeric,`used_times` integer,`last_used` datetime,`auto_groups` text,`usage_limit` integer,`ephemeral` numeric,PRIMARY KEY (`id`),CONSTRAINT `fk_accounts_setup_keys_g` FOREIGN KEY (`account_id`) REFERENCES `accounts`(`id`)); +CREATE TABLE `setup_keys` (`id` text,`account_id` text,`key` text,`key_secret` text,`name` text,`type` text,`created_at` datetime,`expires_at` datetime,`updated_at` datetime DEFAULT NULL,`revoked` numeric,`used_times` integer,`last_used` datetime DEFAULT NULL,`auto_groups` text,`usage_limit` integer,`ephemeral` numeric,PRIMARY KEY (`id`),CONSTRAINT `fk_accounts_setup_keys_g` FOREIGN KEY (`account_id`) REFERENCES `accounts`(`id`)); -INSERT INTO setup_keys VALUES('testKeyId','testAccountId','testKey','testK****','existingKey','one-off','2021-08-19 20:46:20.000000000+00:00','2321-09-18 20:46:20.000000000+00:00','2021-08-19 
20:46:20.000000000+00:000',0,0,'0001-01-01 00:00:00+00:00','["testGroupId"]',1,0); -INSERT INTO setup_keys VALUES('revokedKeyId','testAccountId','revokedKey','testK****','existingKey','reusable','2021-08-19 20:46:20.000000000+00:00','2321-09-18 20:46:20.000000000+00:00','2021-08-19 20:46:20.000000000+00:00',1,0,'0001-01-01 00:00:00+00:00','["testGroupId"]',3,0); -INSERT INTO setup_keys VALUES('expiredKeyId','testAccountId','expiredKey','testK****','existingKey','reusable','2021-08-19 20:46:20.000000000+00:00','1921-09-18 20:46:20.000000000+00:00','2021-08-19 20:46:20.000000000+00:00',0,1,'0001-01-01 00:00:00+00:00','["testGroupId"]',5,1); +INSERT INTO setup_keys VALUES('testKeyId','testAccountId','testKey','testK****','existingKey','one-off','2021-08-19 20:46:20.000000000+00:00','2321-09-18 20:46:20.000000000+00:00','2021-08-19 20:46:20.000000000+00:000',0,0,NULL,'["testGroupId"]',1,0); +INSERT INTO setup_keys VALUES('revokedKeyId','testAccountId','revokedKey','testK****','existingKey','reusable','2021-08-19 20:46:20.000000000+00:00','2321-09-18 20:46:20.000000000+00:00','2021-08-19 20:46:20.000000000+00:00',1,0,NULL,'["testGroupId"]',3,0); +INSERT INTO setup_keys VALUES('expiredKeyId','testAccountId','expiredKey','testK****','existingKey','reusable','2021-08-19 20:46:20.000000000+00:00','1921-09-18 20:46:20.000000000+00:00','2021-08-19 20:46:20.000000000+00:00',0,1,NULL,'["testGroupId"]',5,1); diff --git a/management/server/http/testing/testdata/users.sql b/management/server/http/testing/testdata/users.sql index da7cae565..346f7b7ac 100644 --- a/management/server/http/testing/testdata/users.sql +++ b/management/server/http/testing/testdata/users.sql @@ -1,23 +1,23 @@ CREATE TABLE `accounts` (`id` text,`created_by` text,`created_at` datetime,`domain` text,`domain_category` text,`is_domain_primary_account` numeric,`network_identifier` text,`network_net` text,`network_dns` text,`network_serial` integer,`dns_settings_disabled_management_groups` text,`settings_peer_login_expiration_enabled` numeric,`settings_peer_login_expiration` integer,`settings_regular_users_view_blocked` numeric,`settings_groups_propagation_enabled` numeric,`settings_jwt_groups_enabled` numeric,`settings_jwt_groups_claim_name` text,`settings_jwt_allow_groups` text,`settings_extra_peer_approval_enabled` numeric,`settings_extra_integrated_validator_groups` text,PRIMARY KEY (`id`)); CREATE TABLE `groups` (`id` text,`account_id` text,`name` text,`issued` text,`peers` text,`integration_ref_id` integer,`integration_ref_integration_type` text,PRIMARY KEY (`id`),CONSTRAINT `fk_accounts_groups_g` FOREIGN KEY (`account_id`) REFERENCES `accounts`(`id`)); -CREATE TABLE `setup_keys` (`id` text,`account_id` text,`key` text,`key_secret` text,`name` text,`type` text,`created_at` datetime,`expires_at` datetime,`updated_at` datetime,`revoked` numeric,`used_times` integer,`last_used` datetime,`auto_groups` text,`usage_limit` integer,`ephemeral` numeric,PRIMARY KEY (`id`),CONSTRAINT `fk_accounts_setup_keys_g` FOREIGN KEY (`account_id`) REFERENCES `accounts`(`id`)); +CREATE TABLE `setup_keys` (`id` text,`account_id` text,`key` text,`key_secret` text,`name` text,`type` text,`created_at` datetime,`expires_at` datetime,`updated_at` datetime,`revoked` numeric,`used_times` integer,`last_used` datetime DEFAULT NULL,`auto_groups` text,`usage_limit` integer,`ephemeral` numeric,PRIMARY KEY (`id`),CONSTRAINT `fk_accounts_setup_keys_g` FOREIGN KEY (`account_id`) REFERENCES `accounts`(`id`)); CREATE TABLE `peers` (`id` text,`account_id` text,`key` 
text,`setup_key` text,`ip` text,`meta_hostname` text,`meta_go_os` text,`meta_kernel` text,`meta_core` text,`meta_platform` text,`meta_os` text,`meta_os_version` text,`meta_wt_version` text,`meta_ui_version` text,`meta_kernel_version` text,`meta_network_addresses` text,`meta_system_serial_number` text,`meta_system_product_name` text,`meta_system_manufacturer` text,`meta_environment` text,`meta_files` text,`name` text,`dns_label` text,`peer_status_last_seen` datetime,`peer_status_connected` numeric,`peer_status_login_expired` numeric,`peer_status_requires_approval` numeric,`user_id` text,`ssh_key` text,`ssh_enabled` numeric,`login_expiration_enabled` numeric,`last_login` datetime,`created_at` datetime,`ephemeral` numeric,`location_connection_ip` text,`location_country_code` text,`location_city_name` text,`location_geo_name_id` integer,PRIMARY KEY (`id`),CONSTRAINT `fk_accounts_peers_g` FOREIGN KEY (`account_id`) REFERENCES `accounts`(`id`)); INSERT INTO accounts VALUES('testAccountId','','2024-10-02 16:01:38.000000000+00:00','test.com','private',1,'testNetworkIdentifier','{"IP":"100.64.0.0","Mask":"//8AAA=="}','',0,'[]',0,86400000000000,0,0,0,'',NULL,NULL,NULL); INSERT INTO "groups" VALUES('testGroupId','testAccountId','testGroupName','api','[]',0,''); INSERT INTO "groups" VALUES('newGroupId','testAccountId','newGroupName','api','[]',0,''); -INSERT INTO setup_keys VALUES('testKeyId','testAccountId','testKey','testK****','existingKey','one-off','2021-08-19 20:46:20.000000000+00:00','2321-09-18 20:46:20.000000000+00:00','2021-08-19 20:46:20.000000000+00:000',0,0,'0001-01-01 00:00:00+00:00','["testGroupId"]',1,0); -INSERT INTO setup_keys VALUES('revokedKeyId','testAccountId','revokedKey','testK****','existingKey','reusable','2021-08-19 20:46:20.000000000+00:00','2321-09-18 20:46:20.000000000+00:00','2021-08-19 20:46:20.000000000+00:00',1,0,'0001-01-01 00:00:00+00:00','["testGroupId"]',3,0); -INSERT INTO setup_keys VALUES('expiredKeyId','testAccountId','expiredKey','testK****','existingKey','reusable','2021-08-19 20:46:20.000000000+00:00','1921-09-18 20:46:20.000000000+00:00','2021-08-19 20:46:20.000000000+00:00',0,1,'0001-01-01 00:00:00+00:00','["testGroupId"]',5,1); +INSERT INTO setup_keys VALUES('testKeyId','testAccountId','testKey','testK****','existingKey','one-off','2021-08-19 20:46:20.000000000+00:00','2321-09-18 20:46:20.000000000+00:00','2021-08-19 20:46:20.000000000+00:000',0,0,NULL,'["testGroupId"]',1,0); +INSERT INTO setup_keys VALUES('revokedKeyId','testAccountId','revokedKey','testK****','existingKey','reusable','2021-08-19 20:46:20.000000000+00:00','2321-09-18 20:46:20.000000000+00:00','2021-08-19 20:46:20.000000000+00:00',1,0,NULL,'["testGroupId"]',3,0); +INSERT INTO setup_keys VALUES('expiredKeyId','testAccountId','expiredKey','testK****','existingKey','reusable','2021-08-19 20:46:20.000000000+00:00','1921-09-18 20:46:20.000000000+00:00','2021-08-19 20:46:20.000000000+00:00',0,1,NULL,'["testGroupId"]',5,1); INSERT INTO peers VALUES('testPeerId','testAccountId','5rvhvriKJZ3S9oxYToVj5TzDM9u9y8cxg7htIMWlYAg=','72546A29-6BC8-4311-BCFC-9CDBF33F1A48','"100.64.114.31"','f2a34f6a4731','linux','Linux','11','unknown','Debian GNU/Linux','','0.12.0','','',NULL,'','','','{"Cloud":"","Platform":""}',NULL,'f2a34f6a4731','f2a34f6a4731','2023-03-02 09:21:02.189035775+01:00',0,0,0,'','ssh-ed25519 AAAAC3NzaC1lZDI1NTE5AAAAILzUUSYG/LGnV8zarb2SGN+tib/PZ+M7cL4WtTzUrTpk',0,1,'2023-03-01 19:48:19.817799698+01:00','2024-10-02 17:00:32.527947+02:00',0,'""','','',0); -CREATE TABLE `users` (`id` 
text,`account_id` text,`role` text,`is_service_user` numeric,`non_deletable` numeric,`service_user_name` text,`auto_groups` text,`blocked` numeric,`last_login` datetime,`created_at` datetime,`issued` text DEFAULT "api",`integration_ref_id` integer,`integration_ref_integration_type` text,PRIMARY KEY (`id`),CONSTRAINT `fk_accounts_users_g` FOREIGN KEY (`account_id`) REFERENCES `accounts`(`id`)); +CREATE TABLE `users` (`id` text,`account_id` text,`role` text,`is_service_user` numeric,`non_deletable` numeric,`service_user_name` text,`auto_groups` text,`blocked` numeric,`last_login` datetime DEFAULT NULL,`created_at` datetime,`issued` text DEFAULT "api",`integration_ref_id` integer,`integration_ref_integration_type` text,PRIMARY KEY (`id`),CONSTRAINT `fk_accounts_users_g` FOREIGN KEY (`account_id`) REFERENCES `accounts`(`id`)); -INSERT INTO users VALUES('testUserId','testAccountId','user',0,0,'','[]',0,'0001-01-01 00:00:00+00:00','2024-10-02 16:01:38.000000000+00:00','api',0,''); -INSERT INTO users VALUES('testAdminId','testAccountId','admin',0,0,'','[]',0,'0001-01-01 00:00:00+00:00','2024-10-02 16:01:38.000000000+00:00','api',0,''); -INSERT INTO users VALUES('testOwnerId','testAccountId','owner',0,0,'','[]',0,'0001-01-01 00:00:00+00:00','2024-10-02 16:01:38.000000000+00:00','api',0,''); -INSERT INTO users VALUES('testServiceUserId','testAccountId','user',1,0,'','[]',0,'0001-01-01 00:00:00+00:00','2024-10-02 16:01:38.000000000+00:00','api',0,''); -INSERT INTO users VALUES('testServiceAdminId','testAccountId','admin',1,0,'','[]',0,'0001-01-01 00:00:00+00:00','2024-10-02 16:01:38.000000000+00:00','api',0,''); -INSERT INTO users VALUES('blockedUserId','testAccountId','admin',0,0,'','[]',1,'0001-01-01 00:00:00+00:00','2024-10-02 16:01:38.000000000+00:00','api',0,''); -INSERT INTO users VALUES('otherUserId','otherAccountId','admin',0,0,'','[]',0,'0001-01-01 00:00:00+00:00','2024-10-02 16:01:38.000000000+00:00','api',0,''); +INSERT INTO users VALUES('testUserId','testAccountId','user',0,0,'','[]',0,NULL,'2024-10-02 16:01:38.000000000+00:00','api',0,''); +INSERT INTO users VALUES('testAdminId','testAccountId','admin',0,0,'','[]',0,NULL,'2024-10-02 16:01:38.000000000+00:00','api',0,''); +INSERT INTO users VALUES('testOwnerId','testAccountId','owner',0,0,'','[]',0,NULL,'2024-10-02 16:01:38.000000000+00:00','api',0,''); +INSERT INTO users VALUES('testServiceUserId','testAccountId','user',1,0,'','[]',0,NULL,'2024-10-02 16:01:38.000000000+00:00','api',0,''); +INSERT INTO users VALUES('testServiceAdminId','testAccountId','admin',1,0,'','[]',0,NULL,'2024-10-02 16:01:38.000000000+00:00','api',0,''); +INSERT INTO users VALUES('blockedUserId','testAccountId','admin',0,0,'','[]',1,NULL,'2024-10-02 16:01:38.000000000+00:00','api',0,''); +INSERT INTO users VALUES('otherUserId','otherAccountId','admin',0,0,'','[]',0,NULL,'2024-10-02 16:01:38.000000000+00:00','api',0,''); diff --git a/management/server/http/testing/testing_tools/tools.go b/management/server/http/testing/testing_tools/tools.go index 534836250..006d5679c 100644 --- a/management/server/http/testing/testing_tools/tools.go +++ b/management/server/http/testing/testing_tools/tools.go @@ -13,6 +13,7 @@ import ( "testing" "time" + "github.com/netbirdio/netbird/management/server/util" "github.com/stretchr/testify/assert" "golang.zx2c4.com/wireguard/wgctrl/wgtypes" @@ -199,7 +200,7 @@ func PopulateTestData(b *testing.B, am *server.DefaultAccountManager, peers, gro DNSLabel: fmt.Sprintf("oldpeer-%d", i), Key: peerKey.PublicKey().String(), IP: 
net.ParseIP(fmt.Sprintf("100.64.%d.%d", i/256, i%256)), - Status: &nbpeer.PeerStatus{}, + Status: &nbpeer.PeerStatus{LastSeen: time.Now().UTC(), Connected: true}, UserID: TestUserId, } account.Peers[peer.ID] = peer @@ -220,7 +221,8 @@ func PopulateTestData(b *testing.B, am *server.DefaultAccountManager, peers, gro Id: fmt.Sprintf("oldkey-%d", i), AccountID: account.Id, AutoGroups: []string{"someGroupID"}, - ExpiresAt: time.Now().Add(ExpiresIn * time.Second), + UpdatedAt: time.Now().UTC(), + ExpiresAt: util.ToPtr(time.Now().Add(ExpiresIn * time.Second)), Name: NewKeyName + strconv.Itoa(i), Type: "reusable", UsageLimit: 0, diff --git a/management/server/management_proto_test.go b/management/server/management_proto_test.go index c66423736..8147afa44 100644 --- a/management/server/management_proto_test.go +++ b/management/server/management_proto_test.go @@ -475,8 +475,14 @@ func createRawClient(addr string) (mgmtProto.ManagementServiceClient, *grpc.Clie func Test_SyncStatusRace(t *testing.T) { t.Skip() - if os.Getenv("CI") == "true" && os.Getenv("NETBIRD_STORE_ENGINE") == "postgres" { - t.Skip("Skipping on CI and Postgres store") + if os.Getenv("CI") == "true" { + if os.Getenv("NETBIRD_STORE_ENGINE") == "postgres" { + t.Skip("Skipping on CI and Postgres store") + } + + if os.Getenv("NETBIRD_STORE_ENGINE") == "mysql" { + t.Skip("Skipping on CI and MySQL store") + } } for i := 0; i < 500; i++ { t.Run(fmt.Sprintf("TestRun-%d", i), func(t *testing.T) { diff --git a/management/server/migration/migration.go b/management/server/migration/migration.go index 6f12d94b4..0a6951736 100644 --- a/management/server/migration/migration.go +++ b/management/server/migration/migration.go @@ -17,12 +17,19 @@ import ( "gorm.io/gorm" ) +func GetColumnName(db *gorm.DB, column string) string { + if db.Name() == "mysql" { + return fmt.Sprintf("`%s`", column) + } + return column +} + // MigrateFieldFromGobToJSON migrates a column from Gob encoding to JSON encoding. // T is the type of the model that contains the field to be migrated. // S is the type of the field to be migrated. func MigrateFieldFromGobToJSON[T any, S any](ctx context.Context, db *gorm.DB, fieldName string) error { - - oldColumnName := fieldName + orgColumnName := fieldName + oldColumnName := GetColumnName(db, orgColumnName) newColumnName := fieldName + "_tmp" var model T @@ -72,7 +79,7 @@ func MigrateFieldFromGobToJSON[T any, S any](ctx context.Context, db *gorm.DB, f for _, row := range rows { var field S - str, ok := row[oldColumnName].(string) + str, ok := row[orgColumnName].(string) if !ok { return fmt.Errorf("type assertion failed") } @@ -111,7 +118,8 @@ func MigrateFieldFromGobToJSON[T any, S any](ctx context.Context, db *gorm.DB, f // MigrateNetIPFieldFromBlobToJSON migrates a Net IP column from Blob encoding to JSON encoding. // T is the type of the model that contains the field to be migrated. 
func MigrateNetIPFieldFromBlobToJSON[T any](ctx context.Context, db *gorm.DB, fieldName string, indexName string) error { - oldColumnName := fieldName + orgColumnName := fieldName + oldColumnName := GetColumnName(db, orgColumnName) newColumnName := fieldName + "_tmp" var model T @@ -163,7 +171,7 @@ func MigrateNetIPFieldFromBlobToJSON[T any](ctx context.Context, db *gorm.DB, fi for _, row := range rows { var blobValue string - if columnValue := row[oldColumnName]; columnValue != nil { + if columnValue := row[orgColumnName]; columnValue != nil { value, ok := columnValue.(string) if !ok { return fmt.Errorf("type assertion failed") @@ -210,7 +218,8 @@ func MigrateNetIPFieldFromBlobToJSON[T any](ctx context.Context, db *gorm.DB, fi } func MigrateSetupKeyToHashedSetupKey[T any](ctx context.Context, db *gorm.DB) error { - oldColumnName := "key" + orgColumnName := "key" + oldColumnName := GetColumnName(db, orgColumnName) newColumnName := "key_secret" var model T @@ -250,8 +259,9 @@ func MigrateSetupKeyToHashedSetupKey[T any](ctx context.Context, db *gorm.DB) er } for _, row := range rows { + var plainKey string - if columnValue := row[oldColumnName]; columnValue != nil { + if columnValue := row[orgColumnName]; columnValue != nil { value, ok := columnValue.(string) if !ok { return fmt.Errorf("type assertion failed") diff --git a/management/server/peer.go b/management/server/peer.go index ad20d279a..6c4c2d100 100644 --- a/management/server/peer.go +++ b/management/server/peer.go @@ -11,6 +11,7 @@ import ( "sync" "time" + "github.com/netbirdio/netbird/management/server/util" "github.com/rs/xid" log "github.com/sirupsen/logrus" @@ -511,7 +512,7 @@ func (am *DefaultAccountManager) AddPeer(ctx context.Context, setupKey, userID s Status: &nbpeer.PeerStatus{Connected: false, LastSeen: registrationTime}, SSHEnabled: false, SSHKey: peer.SSHKey, - LastLogin: registrationTime, + LastLogin: util.ToPtr(registrationTime), CreatedAt: registrationTime, LoginExpirationEnabled: addedByUser, Ephemeral: ephemeral, @@ -566,7 +567,7 @@ func (am *DefaultAccountManager) AddPeer(ctx context.Context, setupKey, userID s } if addedByUser { - err := transaction.SaveUserLastLogin(ctx, accountID, userID, newPeer.LastLogin) + err := transaction.SaveUserLastLogin(ctx, accountID, userID, newPeer.GetLastLogin()) if err != nil { return fmt.Errorf("failed to update user last login: %w", err) } @@ -911,7 +912,7 @@ func (am *DefaultAccountManager) handleExpiredPeer(ctx context.Context, user *ty return err } - err = am.Store.SaveUserLastLogin(ctx, user.AccountID, user.Id, peer.LastLogin) + err = am.Store.SaveUserLastLogin(ctx, user.AccountID, user.Id, peer.GetLastLogin()) if err != nil { return err } diff --git a/management/server/peer/peer.go b/management/server/peer/peer.go index 34d791844..355d78ce0 100644 --- a/management/server/peer/peer.go +++ b/management/server/peer/peer.go @@ -6,6 +6,8 @@ import ( "slices" "sort" "time" + + "github.com/netbirdio/netbird/management/server/util" ) // Peer represents a machine connected to the network. @@ -40,7 +42,7 @@ type Peer struct { InactivityExpirationEnabled bool // LastLogin the time when peer performed last login operation - LastLogin time.Time + LastLogin *time.Time // CreatedAt records the time the peer was created CreatedAt time.Time // Indicate ephemeral peer attribute @@ -222,6 +224,15 @@ func (p *Peer) UpdateMetaIfNew(meta PeerSystemMeta) bool { return true } +// GetLastLogin returns the last login time of the peer. 
+func (p *Peer) GetLastLogin() time.Time { + if p.LastLogin != nil { + return *p.LastLogin + } + return time.Time{} + +} + // MarkLoginExpired marks peer's status expired or not func (p *Peer) MarkLoginExpired(expired bool) { newStatus := p.Status.Copy() @@ -258,7 +269,7 @@ func (p *Peer) LoginExpired(expiresIn time.Duration) (bool, time.Duration) { if !p.AddedWithSSOLogin() || !p.LoginExpirationEnabled { return false, 0 } - expiresAt := p.LastLogin.Add(expiresIn) + expiresAt := p.GetLastLogin().Add(expiresIn) now := time.Now() timeLeft := expiresAt.Sub(now) return timeLeft <= 0, timeLeft @@ -291,7 +302,7 @@ func (p *PeerStatus) Copy() *PeerStatus { // UpdateLastLogin and set login expired false func (p *Peer) UpdateLastLogin() *Peer { - p.LastLogin = time.Now().UTC() + p.LastLogin = util.ToPtr(time.Now().UTC()) newStatus := p.Status.Copy() newStatus.LoginExpired = false p.Status = newStatus diff --git a/management/server/peer_test.go b/management/server/peer_test.go index 9ad67d2bf..5f500c226 100644 --- a/management/server/peer_test.go +++ b/management/server/peer_test.go @@ -13,6 +13,7 @@ import ( "testing" "time" + "github.com/netbirdio/netbird/management/server/util" "github.com/rs/xid" log "github.com/sirupsen/logrus" "github.com/stretchr/testify/assert" @@ -80,7 +81,7 @@ func TestPeer_LoginExpired(t *testing.T) { t.Run(c.name, func(t *testing.T) { peer := &nbpeer.Peer{ LoginExpirationEnabled: c.expirationEnabled, - LastLogin: c.lastLogin, + LastLogin: util.ToPtr(c.lastLogin), UserID: userID, } @@ -141,7 +142,7 @@ func TestPeer_SessionExpired(t *testing.T) { } peer := &nbpeer.Peer{ InactivityExpirationEnabled: c.expirationEnabled, - LastLogin: c.lastLogin, + LastLogin: util.ToPtr(c.lastLogin), Status: peerStatus, UserID: userID, } @@ -744,7 +745,7 @@ func setupTestAccountManager(b *testing.B, peers int, groups int) (*DefaultAccou DNSLabel: fmt.Sprintf("peer-%d", i), Key: peerKey.PublicKey().String(), IP: net.ParseIP(fmt.Sprintf("100.64.%d.%d", i/256, i%256)), - Status: &nbpeer.PeerStatus{}, + Status: &nbpeer.PeerStatus{LastSeen: time.Now().UTC(), Connected: true}, UserID: regularUser, } account.Peers[peer.ID] = peer @@ -783,7 +784,7 @@ func setupTestAccountManager(b *testing.B, peers int, groups int) (*DefaultAccou DNSLabel: fmt.Sprintf("peer-nr-%d", len(account.Peers)+1), Key: peerKey.PublicKey().String(), IP: peerIP, - Status: &nbpeer.PeerStatus{}, + Status: &nbpeer.PeerStatus{LastSeen: time.Now().UTC(), Connected: true}, UserID: regularUser, Meta: nbpeer.PeerSystemMeta{ Hostname: fmt.Sprintf("peer-nr-%d", len(account.Peers)+1), @@ -1209,7 +1210,7 @@ func Test_RegisterPeerByUser(t *testing.T) { UserID: existingUserID, Status: &nbpeer.PeerStatus{Connected: false, LastSeen: time.Now()}, SSHEnabled: false, - LastLogin: time.Now(), + LastLogin: util.ToPtr(time.Now()), } addedPeer, _, _, err := am.AddPeer(context.Background(), "", existingUserID, newPeer) @@ -1231,7 +1232,7 @@ func Test_RegisterPeerByUser(t *testing.T) { lastLogin, err := time.Parse("2006-01-02T15:04:05Z", "0001-01-01T00:00:00Z") assert.NoError(t, err) - assert.NotEqual(t, lastLogin, account.Users[existingUserID].LastLogin) + assert.NotEqual(t, lastLogin, account.Users[existingUserID].GetLastLogin()) } func Test_RegisterPeerBySetupKey(t *testing.T) { @@ -1361,7 +1362,7 @@ func Test_RegisterPeerRollbackOnFailure(t *testing.T) { hashedKey := sha256.Sum256([]byte(faultyKey)) encodedHashedKey := b64.StdEncoding.EncodeToString(hashedKey[:]) - assert.Equal(t, lastUsed, account.SetupKeys[encodedHashedKey].LastUsed.UTC()) + 
assert.Equal(t, lastUsed, account.SetupKeys[encodedHashedKey].GetLastUsed().UTC()) assert.Equal(t, 0, account.SetupKeys[encodedHashedKey].UsedTimes) } diff --git a/management/server/route_test.go b/management/server/route_test.go index 2ef2b0173..48dbd11ed 100644 --- a/management/server/route_test.go +++ b/management/server/route_test.go @@ -1307,7 +1307,7 @@ func initTestRouteAccount(t *testing.T, am *DefaultAccountManager) (*types.Accou WtVersion: "development", UIVersion: "development", }, - Status: &nbpeer.PeerStatus{}, + Status: &nbpeer.PeerStatus{LastSeen: time.Now().UTC(), Connected: true}, } account.Peers[peer1.ID] = peer1 @@ -1334,7 +1334,7 @@ func initTestRouteAccount(t *testing.T, am *DefaultAccountManager) (*types.Accou WtVersion: "development", UIVersion: "development", }, - Status: &nbpeer.PeerStatus{}, + Status: &nbpeer.PeerStatus{LastSeen: time.Now().UTC(), Connected: true}, } account.Peers[peer2.ID] = peer2 @@ -1361,7 +1361,7 @@ func initTestRouteAccount(t *testing.T, am *DefaultAccountManager) (*types.Accou WtVersion: "development", UIVersion: "development", }, - Status: &nbpeer.PeerStatus{}, + Status: &nbpeer.PeerStatus{LastSeen: time.Now().UTC(), Connected: true}, } account.Peers[peer3.ID] = peer3 @@ -1388,7 +1388,7 @@ func initTestRouteAccount(t *testing.T, am *DefaultAccountManager) (*types.Accou WtVersion: "development", UIVersion: "development", }, - Status: &nbpeer.PeerStatus{}, + Status: &nbpeer.PeerStatus{LastSeen: time.Now().UTC(), Connected: true}, } account.Peers[peer4.ID] = peer4 @@ -1415,7 +1415,7 @@ func initTestRouteAccount(t *testing.T, am *DefaultAccountManager) (*types.Accou WtVersion: "development", UIVersion: "development", }, - Status: &nbpeer.PeerStatus{}, + Status: &nbpeer.PeerStatus{LastSeen: time.Now().UTC(), Connected: true}, } account.Peers[peer5.ID] = peer5 diff --git a/management/server/setupkey_test.go b/management/server/setupkey_test.go index f728db5d4..e225ec54b 100644 --- a/management/server/setupkey_test.go +++ b/management/server/setupkey_test.go @@ -66,7 +66,7 @@ func TestDefaultAccountManager_SaveSetupKey(t *testing.T) { t.Fatal(err) } - assertKey(t, newKey, keyName, revoked, "reusable", 0, key.CreatedAt, key.ExpiresAt, + assertKey(t, newKey, keyName, revoked, "reusable", 0, key.CreatedAt, key.GetExpiresAt(), key.Id, time.Now().UTC(), autoGroups, true) // check the corresponding events that should have been generated @@ -336,8 +336,8 @@ func assertKey(t *testing.T, key *types.SetupKey, expectedName string, expectedR t.Errorf("expected setup key to have UsedTimes = %v, got %v", expectedUsedTimes, key.UsedTimes) } - if key.ExpiresAt.Sub(expectedExpiresAt).Round(time.Hour) != 0 { - t.Errorf("expected setup key to have ExpiresAt ~ %v, got %v", expectedExpiresAt, key.ExpiresAt) + if key.GetExpiresAt().Sub(expectedExpiresAt).Round(time.Hour) != 0 { + t.Errorf("expected setup key to have ExpiresAt ~ %v, got %v", expectedExpiresAt, key.GetExpiresAt()) } if key.UpdatedAt.Sub(expectedUpdatedAt).Round(time.Hour) != 0 { @@ -391,7 +391,7 @@ func TestSetupKey_Copy(t *testing.T) { key, _ := types.GenerateSetupKey("key name", types.SetupKeyOneOff, time.Hour, []string{}, types.SetupKeyUnlimitedUsage, false) keyCopy := key.Copy() - assertKey(t, keyCopy, key.Name, key.Revoked, string(key.Type), key.UsedTimes, key.CreatedAt, key.ExpiresAt, key.Id, + assertKey(t, keyCopy, key.Name, key.Revoked, string(key.Type), key.UsedTimes, key.CreatedAt, key.GetExpiresAt(), key.Id, key.UpdatedAt, key.AutoGroups, true) } diff --git 
a/management/server/store/file_store.go b/management/server/store/file_store.go index f40a0392b..4c9134e41 100644 --- a/management/server/store/file_store.go +++ b/management/server/store/file_store.go @@ -14,6 +14,7 @@ import ( nbpeer "github.com/netbirdio/netbird/management/server/peer" "github.com/netbirdio/netbird/management/server/telemetry" "github.com/netbirdio/netbird/management/server/types" + nbutil "github.com/netbirdio/netbird/management/server/util" "github.com/netbirdio/netbird/util" ) @@ -175,8 +176,8 @@ func restore(ctx context.Context, file string) (*FileStore, error) { migrationPeers := make(map[string]*nbpeer.Peer) // key to Peer for key, peer := range account.Peers { // set LastLogin for the peers that were onboarded before the peer login expiration feature - if peer.LastLogin.IsZero() { - peer.LastLogin = time.Now().UTC() + if peer.GetLastLogin().IsZero() { + peer.LastLogin = nbutil.ToPtr(time.Now().UTC()) } if peer.ID != "" { continue diff --git a/management/server/store/sql_store.go b/management/server/store/sql_store.go index 62b004f9c..7b1a63411 100644 --- a/management/server/store/sql_store.go +++ b/management/server/store/sql_store.go @@ -16,6 +16,7 @@ import ( "time" log "github.com/sirupsen/logrus" + "gorm.io/driver/mysql" "gorm.io/driver/postgres" "gorm.io/driver/sqlite" "gorm.io/gorm" @@ -39,6 +40,7 @@ const ( storeSqliteFileName = "store.db" idQueryCondition = "id = ?" keyQueryCondition = "key = ?" + mysqlKeyQueryCondition = "`key` = ?" accountAndIDQueryCondition = "account_id = ? and id = ?" accountAndIDsQueryCondition = "account_id = ? AND id IN ?" accountIDCondition = "account_id = ?" @@ -101,6 +103,13 @@ func NewSqlStore(ctx context.Context, db *gorm.DB, storeEngine Engine, metrics t return &SqlStore{db: db, storeEngine: storeEngine, metrics: metrics, installationPK: 1}, nil } +func GetKeyQueryCondition(s *SqlStore) string { + if s.storeEngine == MysqlStoreEngine { + return mysqlKeyQueryCondition + } + return keyQueryCondition +} + // AcquireGlobalLock acquires global lock across all the accounts and returns a function that releases the lock func (s *SqlStore) AcquireGlobalLock(ctx context.Context) (unlock func()) { log.WithContext(ctx).Tracef("acquiring global lock") @@ -397,7 +406,7 @@ func (s *SqlStore) SavePeerLocation(accountID string, peerWithLocation *nbpeer.P return result.Error } - if result.RowsAffected == 0 { + if result.RowsAffected == 0 && s.storeEngine != MysqlStoreEngine { return status.Errorf(status.NotFound, peerNotFoundFMT, peerWithLocation.ID) } @@ -487,7 +496,7 @@ func (s *SqlStore) GetAccountIDByPrivateDomain(ctx context.Context, lockStrength func (s *SqlStore) GetAccountBySetupKey(ctx context.Context, setupKey string) (*types.Account, error) { var key types.SetupKey - result := s.db.Select("account_id").First(&key, keyQueryCondition, setupKey) + result := s.db.Select("account_id").First(&key, GetKeyQueryCondition(s), setupKey) if result.Error != nil { if errors.Is(result.Error, gorm.ErrRecordNotFound) { return nil, status.NewSetupKeyNotFoundError(setupKey) @@ -734,7 +743,8 @@ func (s *SqlStore) GetAccountByPeerID(ctx context.Context, peerID string) (*type func (s *SqlStore) GetAccountByPeerPubKey(ctx context.Context, peerKey string) (*types.Account, error) { var peer nbpeer.Peer - result := s.db.Select("account_id").First(&peer, keyQueryCondition, peerKey) + result := s.db.Select("account_id").First(&peer, GetKeyQueryCondition(s), peerKey) + if result.Error != nil { if errors.Is(result.Error, gorm.ErrRecordNotFound) { return nil, 
status.Errorf(status.NotFound, "account not found: index lookup failed") @@ -752,7 +762,7 @@ func (s *SqlStore) GetAccountByPeerPubKey(ctx context.Context, peerKey string) ( func (s *SqlStore) GetAccountIDByPeerPubKey(ctx context.Context, peerKey string) (string, error) { var peer nbpeer.Peer var accountID string - result := s.db.Model(&peer).Select("account_id").Where(keyQueryCondition, peerKey).First(&accountID) + result := s.db.Model(&peer).Select("account_id").Where(GetKeyQueryCondition(s), peerKey).First(&accountID) if result.Error != nil { if errors.Is(result.Error, gorm.ErrRecordNotFound) { return "", status.Errorf(status.NotFound, "account not found: index lookup failed") @@ -778,7 +788,7 @@ func (s *SqlStore) GetAccountIDByUserID(userID string) (string, error) { func (s *SqlStore) GetAccountIDBySetupKey(ctx context.Context, setupKey string) (string, error) { var accountID string - result := s.db.Model(&types.SetupKey{}).Select("account_id").Where(keyQueryCondition, setupKey).First(&accountID) + result := s.db.Model(&types.SetupKey{}).Select("account_id").Where(GetKeyQueryCondition(s), setupKey).First(&accountID) if result.Error != nil { if errors.Is(result.Error, gorm.ErrRecordNotFound) { return "", status.NewSetupKeyNotFoundError(setupKey) @@ -851,7 +861,8 @@ func (s *SqlStore) GetAccountNetwork(ctx context.Context, lockStrength LockingSt func (s *SqlStore) GetPeerByPeerPubKey(ctx context.Context, lockStrength LockingStrength, peerKey string) (*nbpeer.Peer, error) { var peer nbpeer.Peer - result := s.db.Clauses(clause.Locking{Strength: string(lockStrength)}).First(&peer, keyQueryCondition, peerKey) + result := s.db.Clauses(clause.Locking{Strength: string(lockStrength)}).First(&peer, GetKeyQueryCondition(s), peerKey) + if result.Error != nil { if errors.Is(result.Error, gorm.ErrRecordNotFound) { return nil, status.Errorf(status.NotFound, "peer not found") @@ -883,9 +894,13 @@ func (s *SqlStore) SaveUserLastLogin(ctx context.Context, accountID, userID stri } return status.NewGetUserFromStoreError() } - user.LastLogin = lastLogin - return s.db.Save(&user).Error + if !lastLogin.IsZero() { + user.LastLogin = &lastLogin + return s.db.Save(&user).Error + } + + return nil } func (s *SqlStore) GetPostureCheckByChecksDefinition(accountID string, checks *posture.ChecksDefinition) (*posture.Checks, error) { @@ -944,6 +959,16 @@ func NewPostgresqlStore(ctx context.Context, dsn string, metrics telemetry.AppMe return NewSqlStore(ctx, db, PostgresStoreEngine, metrics) } +// NewMysqlStore creates a new MySQL store. +func NewMysqlStore(ctx context.Context, dsn string, metrics telemetry.AppMetrics) (*SqlStore, error) { + db, err := gorm.Open(mysql.Open(dsn+"?charset=utf8&parseTime=True&loc=Local"), getGormConfig()) + if err != nil { + return nil, err + } + + return NewSqlStore(ctx, db, MysqlStoreEngine, metrics) +} + func getGormConfig() *gorm.Config { return &gorm.Config{ Logger: logger.Default.LogMode(logger.Silent), @@ -961,6 +986,15 @@ func newPostgresStore(ctx context.Context, metrics telemetry.AppMetrics) (Store, return NewPostgresqlStore(ctx, dsn, metrics) } +// newMysqlStore initializes a new MySQL store. +func newMysqlStore(ctx context.Context, metrics telemetry.AppMetrics) (Store, error) { + dsn, ok := os.LookupEnv(mysqlDsnEnv) + if !ok { + return nil, fmt.Errorf("%s is not set", mysqlDsnEnv) + } + return NewMysqlStore(ctx, dsn, metrics) +} + // NewSqliteStoreFromFileStore restores a store from FileStore and stores SQLite DB in the file located in datadir. 
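
Two details in the MySQL wiring above are easy to miss: "key" is a reserved word in MySQL, which is why lookups that previously hard-coded key = ? now go through GetKeyQueryCondition to pick the backtick-quoted variant, and NewMysqlStore appends the "?charset=utf8&parseTime=True&loc=Local" options to whatever DSN it receives before opening the GORM connection. A rough sketch of standing the store up directly, assuming the package layout of this patch series; the DSN value and the wrapping main function are illustrative, while NewMysqlStore, Close, GetInstallationID and NETBIRD_STORE_ENGINE_MYSQL_DSN come from the diff itself:

    package main

    import (
        "context"
        "log"
        "os"

        "github.com/netbirdio/netbird/management/server/store"
    )

    func main() {
        ctx := context.Background()

        // Illustrative DSN in go-sql-driver form; NewMysqlStore appends
        // "?charset=utf8&parseTime=True&loc=Local" on its own.
        dsn, ok := os.LookupEnv("NETBIRD_STORE_ENGINE_MYSQL_DSN")
        if !ok {
            dsn = "testing:testing@tcp(127.0.0.1:3306)/testing"
        }

        // nil metrics mirrors what the test helpers in this patch pass.
        s, err := store.NewMysqlStore(ctx, dsn, nil)
        if err != nil {
            log.Fatalf("open mysql store: %v", err)
        }
        defer s.Close(ctx)

        log.Printf("installation id: %v", s.GetInstallationID())
    }
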
func NewSqliteStoreFromFileStore(ctx context.Context, fileStore *FileStore, dataDir string, metrics telemetry.AppMetrics) (*SqlStore, error) { store, err := NewSqliteStore(ctx, dataDir, metrics) @@ -1005,10 +1039,33 @@ func NewPostgresqlStoreFromSqlStore(ctx context.Context, sqliteStore *SqlStore, return store, nil } +// NewMysqlStoreFromSqlStore restores a store from SqlStore and stores MySQL DB. +func NewMysqlStoreFromSqlStore(ctx context.Context, sqliteStore *SqlStore, dsn string, metrics telemetry.AppMetrics) (*SqlStore, error) { + store, err := NewMysqlStore(ctx, dsn, metrics) + if err != nil { + return nil, err + } + + err = store.SaveInstallationID(ctx, sqliteStore.GetInstallationID()) + if err != nil { + return nil, err + } + + for _, account := range sqliteStore.GetAllAccounts(ctx) { + err := store.SaveAccount(ctx, account) + if err != nil { + return nil, err + } + } + + return store, nil +} + func (s *SqlStore) GetSetupKeyBySecret(ctx context.Context, lockStrength LockingStrength, key string) (*types.SetupKey, error) { var setupKey types.SetupKey result := s.db.Clauses(clause.Locking{Strength: string(lockStrength)}). - First(&setupKey, keyQueryCondition, key) + First(&setupKey, GetKeyQueryCondition(s), key) + if result.Error != nil { if errors.Is(result.Error, gorm.ErrRecordNotFound) { return nil, status.NewSetupKeyNotFoundError(key) @@ -1300,9 +1357,13 @@ func (s *SqlStore) GetGroupByName(ctx context.Context, lockStrength LockingStren // TODO: This fix is accepted for now, but if we need to handle this more frequently // we may need to reconsider changing the types. query := s.db.Clauses(clause.Locking{Strength: string(lockStrength)}).Preload(clause.Associations) - if s.storeEngine == PostgresStoreEngine { + + switch s.storeEngine { + case PostgresStoreEngine: query = query.Order("json_array_length(peers::json) DESC") - } else { + case MysqlStoreEngine: + query = query.Order("JSON_LENGTH(JSON_EXTRACT(peers, \"$\")) DESC") + default: query = query.Order("json_array_length(peers) DESC") } diff --git a/management/server/store/store.go b/management/server/store/store.go index d9dc6b8f7..e1a6937e7 100644 --- a/management/server/store/store.go +++ b/management/server/store/store.go @@ -18,6 +18,7 @@ import ( "gorm.io/gorm" "github.com/netbirdio/netbird/dns" + "github.com/netbirdio/netbird/management/server/testutil" "github.com/netbirdio/netbird/management/server/types" "github.com/netbirdio/netbird/management/server/telemetry" @@ -29,7 +30,6 @@ import ( networkTypes "github.com/netbirdio/netbird/management/server/networks/types" nbpeer "github.com/netbirdio/netbird/management/server/peer" "github.com/netbirdio/netbird/management/server/posture" - "github.com/netbirdio/netbird/management/server/testutil" "github.com/netbirdio/netbird/route" ) @@ -171,8 +171,10 @@ const ( FileStoreEngine Engine = "jsonfile" SqliteStoreEngine Engine = "sqlite" PostgresStoreEngine Engine = "postgres" + MysqlStoreEngine Engine = "mysql" postgresDsnEnv = "NETBIRD_STORE_ENGINE_POSTGRES_DSN" + mysqlDsnEnv = "NETBIRD_STORE_ENGINE_MYSQL_DSN" ) func getStoreEngineFromEnv() Engine { @@ -183,7 +185,7 @@ func getStoreEngineFromEnv() Engine { } value := Engine(strings.ToLower(kind)) - if value == SqliteStoreEngine || value == PostgresStoreEngine { + if value == SqliteStoreEngine || value == PostgresStoreEngine || value == MysqlStoreEngine { return value } @@ -234,6 +236,9 @@ func NewStore(ctx context.Context, kind Engine, dataDir string, metrics telemetr case PostgresStoreEngine: log.WithContext(ctx).Info("using 
Postgres store engine") return newPostgresStore(ctx, metrics) + case MysqlStoreEngine: + log.WithContext(ctx).Info("using MySQL store engine") + return newMysqlStore(ctx, metrics) default: return nil, fmt.Errorf("unsupported kind of store: %s", kind) } @@ -317,12 +322,13 @@ func NewTestStoreFromSQL(ctx context.Context, filename string, dataDir string) ( if err != nil { return nil, nil, fmt.Errorf("failed to create test store: %v", err) } - cleanUp := func() { - store.Close(ctx) - } + return getSqlStoreEngine(ctx, store, kind) +} + +func getSqlStoreEngine(ctx context.Context, store *SqlStore, kind Engine) (Store, func(), error) { if kind == PostgresStoreEngine { - cleanUp, err = testutil.CreatePGDB() + cleanUp, err := testutil.CreatePostgresTestContainer() if err != nil { return nil, nil, err } @@ -336,9 +342,34 @@ func NewTestStoreFromSQL(ctx context.Context, filename string, dataDir string) ( if err != nil { return nil, nil, err } + + return store, cleanUp, nil } - return store, cleanUp, nil + if kind == MysqlStoreEngine { + cleanUp, err := testutil.CreateMysqlTestContainer() + if err != nil { + return nil, nil, err + } + + dsn, ok := os.LookupEnv(mysqlDsnEnv) + if !ok { + return nil, nil, fmt.Errorf("%s is not set", mysqlDsnEnv) + } + + store, err = NewMysqlStoreFromSqlStore(ctx, store, dsn, nil) + if err != nil { + return nil, nil, err + } + + return store, cleanUp, nil + } + + closeConnection := func() { + store.Close(ctx) + } + + return store, closeConnection, nil } func loadSQL(db *gorm.DB, filepath string) error { diff --git a/management/server/testdata/extended-store.sql b/management/server/testdata/extended-store.sql index 455111439..2859e82c8 100644 --- a/management/server/testdata/extended-store.sql +++ b/management/server/testdata/extended-store.sql @@ -1,7 +1,7 @@ CREATE TABLE `accounts` (`id` text,`created_by` text,`created_at` datetime,`domain` text,`domain_category` text,`is_domain_primary_account` numeric,`network_identifier` text,`network_net` text,`network_dns` text,`network_serial` integer,`dns_settings_disabled_management_groups` text,`settings_peer_login_expiration_enabled` numeric,`settings_peer_login_expiration` integer,`settings_regular_users_view_blocked` numeric,`settings_groups_propagation_enabled` numeric,`settings_jwt_groups_enabled` numeric,`settings_jwt_groups_claim_name` text,`settings_jwt_allow_groups` text,`settings_extra_peer_approval_enabled` numeric,`settings_extra_integrated_validator_groups` text,PRIMARY KEY (`id`)); -CREATE TABLE `setup_keys` (`id` text,`account_id` text,`key` text,`name` text,`type` text,`created_at` datetime,`expires_at` datetime,`updated_at` datetime,`revoked` numeric,`used_times` integer,`last_used` datetime,`auto_groups` text,`usage_limit` integer,`ephemeral` numeric,PRIMARY KEY (`id`),CONSTRAINT `fk_accounts_setup_keys_g` FOREIGN KEY (`account_id`) REFERENCES `accounts`(`id`)); +CREATE TABLE `setup_keys` (`id` text,`account_id` text,`key` text,`name` text,`type` text,`created_at` datetime,`expires_at` datetime,`updated_at` datetime,`revoked` numeric,`used_times` integer,`last_used` datetime DEFAULT NULL,`auto_groups` text,`usage_limit` integer,`ephemeral` numeric,PRIMARY KEY (`id`),CONSTRAINT `fk_accounts_setup_keys_g` FOREIGN KEY (`account_id`) REFERENCES `accounts`(`id`)); CREATE TABLE `peers` (`id` text,`account_id` text,`key` text,`setup_key` text,`ip` text,`meta_hostname` text,`meta_go_os` text,`meta_kernel` text,`meta_core` text,`meta_platform` text,`meta_os` text,`meta_os_version` text,`meta_wt_version` 
text,`meta_ui_version` text,`meta_kernel_version` text,`meta_network_addresses` text,`meta_system_serial_number` text,`meta_system_product_name` text,`meta_system_manufacturer` text,`meta_environment` text,`meta_files` text,`name` text,`dns_label` text,`peer_status_last_seen` datetime,`peer_status_connected` numeric,`peer_status_login_expired` numeric,`peer_status_requires_approval` numeric,`user_id` text,`ssh_key` text,`ssh_enabled` numeric,`login_expiration_enabled` numeric,`last_login` datetime,`created_at` datetime,`ephemeral` numeric,`location_connection_ip` text,`location_country_code` text,`location_city_name` text,`location_geo_name_id` integer,PRIMARY KEY (`id`),CONSTRAINT `fk_accounts_peers_g` FOREIGN KEY (`account_id`) REFERENCES `accounts`(`id`)); -CREATE TABLE `users` (`id` text,`account_id` text,`role` text,`is_service_user` numeric,`non_deletable` numeric,`service_user_name` text,`auto_groups` text,`blocked` numeric,`last_login` datetime,`created_at` datetime,`issued` text DEFAULT "api",`integration_ref_id` integer,`integration_ref_integration_type` text,PRIMARY KEY (`id`),CONSTRAINT `fk_accounts_users_g` FOREIGN KEY (`account_id`) REFERENCES `accounts`(`id`)); +CREATE TABLE `users` (`id` text,`account_id` text,`role` text,`is_service_user` numeric,`non_deletable` numeric,`service_user_name` text,`auto_groups` text,`blocked` numeric,`last_login` datetime DEFAULT NULL,`created_at` datetime,`issued` text DEFAULT "api",`integration_ref_id` integer,`integration_ref_integration_type` text,PRIMARY KEY (`id`),CONSTRAINT `fk_accounts_users_g` FOREIGN KEY (`account_id`) REFERENCES `accounts`(`id`)); CREATE TABLE `personal_access_tokens` (`id` text,`user_id` text,`name` text,`hashed_token` text,`expiration_date` datetime,`created_by` text,`created_at` datetime,`last_used` datetime,PRIMARY KEY (`id`),CONSTRAINT `fk_users_pa_ts_g` FOREIGN KEY (`user_id`) REFERENCES `users`(`id`)); CREATE TABLE `groups` (`id` text,`account_id` text,`name` text,`issued` text,`peers` text,`integration_ref_id` integer,`integration_ref_integration_type` text,PRIMARY KEY (`id`),CONSTRAINT `fk_accounts_groups_g` FOREIGN KEY (`account_id`) REFERENCES `accounts`(`id`)); CREATE TABLE `policies` (`id` text,`account_id` text,`name` text,`description` text,`enabled` numeric,`source_posture_checks` text,PRIMARY KEY (`id`),CONSTRAINT `fk_accounts_policies` FOREIGN KEY (`account_id`) REFERENCES `accounts`(`id`)); @@ -26,10 +26,10 @@ CREATE INDEX `idx_name_server_groups_account_id` ON `name_server_groups`(`accoun CREATE INDEX `idx_posture_checks_account_id` ON `posture_checks`(`account_id`); INSERT INTO accounts VALUES('bf1c8084-ba50-4ce7-9439-34653001fc3b','','2024-10-02 16:01:38.210014+02:00','test.com','private',1,'af1c8024-ha40-4ce2-9418-34653101fc3c','{"IP":"100.64.0.0","Mask":"//8AAA=="}','',0,'[]',0,86400000000000,0,0,0,'',NULL,NULL,NULL); -INSERT INTO setup_keys VALUES('A2C8E62B-38F5-4553-B31E-DD66C696CEBB','bf1c8084-ba50-4ce7-9439-34653001fc3b','A2C8E62B-38F5-4553-B31E-DD66C696CEBB','Default key','reusable','2021-08-19 20:46:20.005936822+02:00','2321-09-18 20:46:20.005936822+02:00','2021-08-19 20:46:20.005936822+02:00',0,0,'0001-01-01 00:00:00+00:00','["cfefqs706sqkneg59g2g"]',0,0); -INSERT INTO setup_keys VALUES('A2C8E62B-38F5-4553-B31E-DD66C696CEBC','bf1c8084-ba50-4ce7-9439-34653001fc3b','A2C8E62B-38F5-4553-B31E-DD66C696CEBC','Faulty key with non existing group','reusable','2021-08-19 20:46:20.005936822+02:00','2321-09-18 20:46:20.005936822+02:00','2021-08-19 20:46:20.005936822+02:00',0,0,'0001-01-01 
00:00:00+00:00','["abcd"]',0,0); -INSERT INTO users VALUES('edafee4e-63fb-11ec-90d6-0242ac120003','bf1c8084-ba50-4ce7-9439-34653001fc3b','admin',0,0,'','["cfefqs706sqkneg59g3g"]',0,'0001-01-01 00:00:00+00:00','2024-10-02 16:01:38.210678+02:00','api',0,''); -INSERT INTO users VALUES('f4f6d672-63fb-11ec-90d6-0242ac120003','bf1c8084-ba50-4ce7-9439-34653001fc3b','user',0,0,'','[]',0,'0001-01-01 00:00:00+00:00','2024-10-02 16:01:38.210678+02:00','api',0,''); +INSERT INTO setup_keys VALUES('A2C8E62B-38F5-4553-B31E-DD66C696CEBB','bf1c8084-ba50-4ce7-9439-34653001fc3b','A2C8E62B-38F5-4553-B31E-DD66C696CEBB','Default key','reusable','2021-08-19 20:46:20.005936822+02:00','2321-09-18 20:46:20.005936822+02:00','2021-08-19 20:46:20.005936822+02:00',0,0,NULL,'["cfefqs706sqkneg59g2g"]',0,0); +INSERT INTO setup_keys VALUES('A2C8E62B-38F5-4553-B31E-DD66C696CEBC','bf1c8084-ba50-4ce7-9439-34653001fc3b','A2C8E62B-38F5-4553-B31E-DD66C696CEBC','Faulty key with non existing group','reusable','2021-08-19 20:46:20.005936822+02:00','2321-09-18 20:46:20.005936822+02:00','2021-08-19 20:46:20.005936822+02:00',0,0,NULL,'["abcd"]',0,0); +INSERT INTO users VALUES('edafee4e-63fb-11ec-90d6-0242ac120003','bf1c8084-ba50-4ce7-9439-34653001fc3b','admin',0,0,'','["cfefqs706sqkneg59g3g"]',0,NULL,'2024-10-02 16:01:38.210678+02:00','api',0,''); +INSERT INTO users VALUES('f4f6d672-63fb-11ec-90d6-0242ac120003','bf1c8084-ba50-4ce7-9439-34653001fc3b','user',0,0,'','[]',0,NULL,'2024-10-02 16:01:38.210678+02:00','api',0,''); INSERT INTO personal_access_tokens VALUES('9dj38s35-63fb-11ec-90d6-0242ac120003','f4f6d672-63fb-11ec-90d6-0242ac120003','','SoMeHaShEdToKeN','2023-02-27 00:00:00+00:00','user','2023-01-01 00:00:00+00:00','2023-02-01 00:00:00+00:00'); INSERT INTO "groups" VALUES('cfefqs706sqkneg59g4g','bf1c8084-ba50-4ce7-9439-34653001fc3b','All','api','[]',0,''); INSERT INTO "groups" VALUES('cfefqs706sqkneg59g3g','bf1c8084-ba50-4ce7-9439-34653001fc3b','AwesomeGroup1','api','[]',0,''); diff --git a/management/server/testdata/store.sql b/management/server/testdata/store.sql index 7f0c7b5a4..17f029713 100644 --- a/management/server/testdata/store.sql +++ b/management/server/testdata/store.sql @@ -1,8 +1,8 @@ CREATE TABLE `accounts` (`id` text,`created_by` text,`created_at` datetime,`domain` text,`domain_category` text,`is_domain_primary_account` numeric,`network_identifier` text,`network_net` text,`network_dns` text,`network_serial` integer,`dns_settings_disabled_management_groups` text,`settings_peer_login_expiration_enabled` numeric,`settings_peer_login_expiration` integer,`settings_regular_users_view_blocked` numeric,`settings_groups_propagation_enabled` numeric,`settings_jwt_groups_enabled` numeric,`settings_jwt_groups_claim_name` text,`settings_jwt_allow_groups` text,`settings_extra_peer_approval_enabled` numeric,`settings_extra_integrated_validator_groups` text,PRIMARY KEY (`id`)); -CREATE TABLE `setup_keys` (`id` text,`account_id` text,`key` text,`name` text,`type` text,`created_at` datetime,`expires_at` datetime,`updated_at` datetime,`revoked` numeric,`used_times` integer,`last_used` datetime,`auto_groups` text,`usage_limit` integer,`ephemeral` numeric,PRIMARY KEY (`id`),CONSTRAINT `fk_accounts_setup_keys_g` FOREIGN KEY (`account_id`) REFERENCES `accounts`(`id`)); +CREATE TABLE `setup_keys` (`id` text,`account_id` text,`key` text,`name` text,`type` text,`created_at` datetime,`expires_at` datetime,`updated_at` datetime,`revoked` numeric,`used_times` integer,`last_used` datetime DEFAULT NULL,`auto_groups` text,`usage_limit` 
integer,`ephemeral` numeric,PRIMARY KEY (`id`),CONSTRAINT `fk_accounts_setup_keys_g` FOREIGN KEY (`account_id`) REFERENCES `accounts`(`id`)); CREATE TABLE `peers` (`id` text,`account_id` text,`key` text,`setup_key` text,`ip` text,`meta_hostname` text,`meta_go_os` text,`meta_kernel` text,`meta_core` text,`meta_platform` text,`meta_os` text,`meta_os_version` text,`meta_wt_version` text,`meta_ui_version` text,`meta_kernel_version` text,`meta_network_addresses` text,`meta_system_serial_number` text,`meta_system_product_name` text,`meta_system_manufacturer` text,`meta_environment` text,`meta_files` text,`name` text,`dns_label` text,`peer_status_last_seen` datetime,`peer_status_connected` numeric,`peer_status_login_expired` numeric,`peer_status_requires_approval` numeric,`user_id` text,`ssh_key` text,`ssh_enabled` numeric,`login_expiration_enabled` numeric,`last_login` datetime,`created_at` datetime,`ephemeral` numeric,`location_connection_ip` text,`location_country_code` text,`location_city_name` text,`location_geo_name_id` integer,PRIMARY KEY (`id`),CONSTRAINT `fk_accounts_peers_g` FOREIGN KEY (`account_id`) REFERENCES `accounts`(`id`)); CREATE TABLE `users` (`id` text,`account_id` text,`role` text,`is_service_user` numeric,`non_deletable` numeric,`service_user_name` text,`auto_groups` text,`blocked` numeric,`last_login` datetime,`created_at` datetime,`issued` text DEFAULT "api",`integration_ref_id` integer,`integration_ref_integration_type` text,PRIMARY KEY (`id`),CONSTRAINT `fk_accounts_users_g` FOREIGN KEY (`account_id`) REFERENCES `accounts`(`id`)); -CREATE TABLE `personal_access_tokens` (`id` text,`user_id` text,`name` text,`hashed_token` text,`expiration_date` datetime,`created_by` text,`created_at` datetime,`last_used` datetime,PRIMARY KEY (`id`),CONSTRAINT `fk_users_pa_ts_g` FOREIGN KEY (`user_id`) REFERENCES `users`(`id`)); +CREATE TABLE `personal_access_tokens` (`id` text,`user_id` text,`name` text,`hashed_token` text,`expiration_date` datetime,`created_by` text,`created_at` datetime,`last_used` datetime DEFAULT NULL,PRIMARY KEY (`id`),CONSTRAINT `fk_users_pa_ts_g` FOREIGN KEY (`user_id`) REFERENCES `users`(`id`)); CREATE TABLE `groups` (`id` text,`account_id` text,`name` text,`issued` text,`peers` text,`integration_ref_id` integer,`integration_ref_integration_type` text,PRIMARY KEY (`id`),CONSTRAINT `fk_accounts_groups_g` FOREIGN KEY (`account_id`) REFERENCES `accounts`(`id`)); CREATE TABLE `policies` (`id` text,`account_id` text,`name` text,`description` text,`enabled` numeric,`source_posture_checks` text,PRIMARY KEY (`id`),CONSTRAINT `fk_accounts_policies` FOREIGN KEY (`account_id`) REFERENCES `accounts`(`id`)); CREATE TABLE `policy_rules` (`id` text,`policy_id` text,`name` text,`description` text,`enabled` numeric,`action` text,`destinations` text,`sources` text,`bidirectional` numeric,`protocol` text,`ports` text,`port_ranges` text,PRIMARY KEY (`id`),CONSTRAINT `fk_policies_rules` FOREIGN KEY (`policy_id`) REFERENCES `policies`(`id`) ON DELETE CASCADE); @@ -38,9 +38,9 @@ CREATE INDEX `idx_networks_account_id` ON `networks`(`account_id`); INSERT INTO accounts VALUES('bf1c8084-ba50-4ce7-9439-34653001fc3b','','2024-10-02 16:03:06.778746+02:00','test.com','private',1,'af1c8024-ha40-4ce2-9418-34653101fc3c','{"IP":"100.64.0.0","Mask":"//8AAA=="}','',0,'[]',0,86400000000000,0,0,0,'',NULL,NULL,NULL); INSERT INTO "groups" VALUES('cs1tnh0hhcjnqoiuebeg','bf1c8084-ba50-4ce7-9439-34653001fc3b','All','api','[]',0,''); -INSERT INTO setup_keys 
VALUES('','bf1c8084-ba50-4ce7-9439-34653001fc3b','A2C8E62B-38F5-4553-B31E-DD66C696CEBB','Default key','reusable','2021-08-19 20:46:20.005936822+02:00','2321-09-18 20:46:20.005936822+02:00','2021-08-19 20:46:20.005936822+02:00',0,0,'0001-01-01 00:00:00+00:00','["cs1tnh0hhcjnqoiuebeg"]',0,0); -INSERT INTO users VALUES('edafee4e-63fb-11ec-90d6-0242ac120003','bf1c8084-ba50-4ce7-9439-34653001fc3b','admin',0,0,'','[]',0,'0001-01-01 00:00:00+00:00','2024-10-02 16:03:06.779156+02:00','api',0,''); -INSERT INTO users VALUES('f4f6d672-63fb-11ec-90d6-0242ac120003','bf1c8084-ba50-4ce7-9439-34653001fc3b','user',0,0,'','[]',0,'0001-01-01 00:00:00+00:00','2024-10-02 16:03:06.779156+02:00','api',0,''); +INSERT INTO setup_keys VALUES('','bf1c8084-ba50-4ce7-9439-34653001fc3b','A2C8E62B-38F5-4553-B31E-DD66C696CEBB','Default key','reusable','2021-08-19 20:46:20.005936822+02:00','2321-09-18 20:46:20.005936822+02:00','2021-08-19 20:46:20.005936822+02:00',0,0,NULL,'["cs1tnh0hhcjnqoiuebeg"]',0,0); +INSERT INTO users VALUES('edafee4e-63fb-11ec-90d6-0242ac120003','bf1c8084-ba50-4ce7-9439-34653001fc3b','admin',0,0,'','[]',0,NULL,'2024-10-02 16:03:06.779156+02:00','api',0,''); +INSERT INTO users VALUES('f4f6d672-63fb-11ec-90d6-0242ac120003','bf1c8084-ba50-4ce7-9439-34653001fc3b','user',0,0,'','[]',0,NULL,'2024-10-02 16:03:06.779156+02:00','api',0,''); INSERT INTO personal_access_tokens VALUES('9dj38s35-63fb-11ec-90d6-0242ac120003','f4f6d672-63fb-11ec-90d6-0242ac120003','','SoMeHaShEdToKeN','2023-02-27 00:00:00+00:00','user','2023-01-01 00:00:00+00:00','2023-02-01 00:00:00+00:00'); INSERT INTO installations VALUES(1,''); INSERT INTO policies VALUES('cs1tnh0hhcjnqoiuebf0','bf1c8084-ba50-4ce7-9439-34653001fc3b','Default','This is a default rule that allows connections between all the resources',1,'[]'); diff --git a/management/server/testdata/store_policy_migrate.sql b/management/server/testdata/store_policy_migrate.sql index a9360e9d6..9c961e389 100644 --- a/management/server/testdata/store_policy_migrate.sql +++ b/management/server/testdata/store_policy_migrate.sql @@ -1,7 +1,7 @@ CREATE TABLE `accounts` (`id` text,`created_by` text,`created_at` datetime,`domain` text,`domain_category` text,`is_domain_primary_account` numeric,`network_identifier` text,`network_net` text,`network_dns` text,`network_serial` integer,`dns_settings_disabled_management_groups` text,`settings_peer_login_expiration_enabled` numeric,`settings_peer_login_expiration` integer,`settings_regular_users_view_blocked` numeric,`settings_groups_propagation_enabled` numeric,`settings_jwt_groups_enabled` numeric,`settings_jwt_groups_claim_name` text,`settings_jwt_allow_groups` text,`settings_extra_peer_approval_enabled` numeric,`settings_extra_integrated_validator_groups` text,PRIMARY KEY (`id`)); -CREATE TABLE `setup_keys` (`id` text,`account_id` text,`key` text,`name` text,`type` text,`created_at` datetime,`expires_at` datetime,`updated_at` datetime,`revoked` numeric,`used_times` integer,`last_used` datetime,`auto_groups` text,`usage_limit` integer,`ephemeral` numeric,PRIMARY KEY (`id`),CONSTRAINT `fk_accounts_setup_keys_g` FOREIGN KEY (`account_id`) REFERENCES `accounts`(`id`)); +CREATE TABLE `setup_keys` (`id` text,`account_id` text,`key` text,`name` text,`type` text,`created_at` datetime,`expires_at` datetime,`updated_at` datetime,`revoked` numeric,`used_times` integer,`last_used` datetime DEFAULT NULL,`auto_groups` text,`usage_limit` integer,`ephemeral` numeric,PRIMARY KEY (`id`),CONSTRAINT `fk_accounts_setup_keys_g` FOREIGN KEY (`account_id`) 
REFERENCES `accounts`(`id`)); CREATE TABLE `peers` (`id` text,`account_id` text,`key` text,`setup_key` text,`ip` text,`meta_hostname` text,`meta_go_os` text,`meta_kernel` text,`meta_core` text,`meta_platform` text,`meta_os` text,`meta_os_version` text,`meta_wt_version` text,`meta_ui_version` text,`meta_kernel_version` text,`meta_network_addresses` text,`meta_system_serial_number` text,`meta_system_product_name` text,`meta_system_manufacturer` text,`meta_environment` text,`meta_files` text,`name` text,`dns_label` text,`peer_status_last_seen` datetime,`peer_status_connected` numeric,`peer_status_login_expired` numeric,`peer_status_requires_approval` numeric,`user_id` text,`ssh_key` text,`ssh_enabled` numeric,`login_expiration_enabled` numeric,`last_login` datetime,`created_at` datetime,`ephemeral` numeric,`location_connection_ip` text,`location_country_code` text,`location_city_name` text,`location_geo_name_id` integer,PRIMARY KEY (`id`),CONSTRAINT `fk_accounts_peers_g` FOREIGN KEY (`account_id`) REFERENCES `accounts`(`id`)); -CREATE TABLE `users` (`id` text,`account_id` text,`role` text,`is_service_user` numeric,`non_deletable` numeric,`service_user_name` text,`auto_groups` text,`blocked` numeric,`last_login` datetime,`created_at` datetime,`issued` text DEFAULT "api",`integration_ref_id` integer,`integration_ref_integration_type` text,PRIMARY KEY (`id`),CONSTRAINT `fk_accounts_users_g` FOREIGN KEY (`account_id`) REFERENCES `accounts`(`id`)); +CREATE TABLE `users` (`id` text,`account_id` text,`role` text,`is_service_user` numeric,`non_deletable` numeric,`service_user_name` text,`auto_groups` text,`blocked` numeric,`last_login` datetime DEFAULT NULL,`created_at` datetime,`issued` text DEFAULT "api",`integration_ref_id` integer,`integration_ref_integration_type` text,PRIMARY KEY (`id`),CONSTRAINT `fk_accounts_users_g` FOREIGN KEY (`account_id`) REFERENCES `accounts`(`id`)); CREATE TABLE `personal_access_tokens` (`id` text,`user_id` text,`name` text,`hashed_token` text,`expiration_date` datetime,`created_by` text,`created_at` datetime,`last_used` datetime,PRIMARY KEY (`id`),CONSTRAINT `fk_users_pa_ts_g` FOREIGN KEY (`user_id`) REFERENCES `users`(`id`)); CREATE TABLE `groups` (`id` text,`account_id` text,`name` text,`issued` text,`peers` text,`integration_ref_id` integer,`integration_ref_integration_type` text,PRIMARY KEY (`id`),CONSTRAINT `fk_accounts_groups_g` FOREIGN KEY (`account_id`) REFERENCES `accounts`(`id`)); CREATE TABLE `policies` (`id` text,`account_id` text,`name` text,`description` text,`enabled` numeric,`source_posture_checks` text,PRIMARY KEY (`id`),CONSTRAINT `fk_accounts_policies` FOREIGN KEY (`account_id`) REFERENCES `accounts`(`id`)); @@ -26,10 +26,10 @@ CREATE INDEX `idx_name_server_groups_account_id` ON `name_server_groups`(`accoun CREATE INDEX `idx_posture_checks_account_id` ON `posture_checks`(`account_id`); INSERT INTO accounts VALUES('bf1c8084-ba50-4ce7-9439-34653001fc3b','','2024-10-02 16:04:23.538411+02:00','test.com','private',1,'af1c8024-ha40-4ce2-9418-34653101fc3c','{"IP":"100.64.0.0","Mask":"//8AAA=="}','',0,'[]',0,86400000000000,0,0,0,'',NULL,NULL,NULL); -INSERT INTO setup_keys VALUES('','bf1c8084-ba50-4ce7-9439-34653001fc3b','A2C8E62B-38F5-4553-B31E-DD66C696CEBB','Default key','reusable','2021-08-19 20:46:20.005936822+02:00','2321-09-18 20:46:20.005936822+02:00','2021-08-19 20:46:20.005936822+02:00',0,0,'0001-01-01 00:00:00+00:00','[]',0,0); +INSERT INTO setup_keys VALUES('','bf1c8084-ba50-4ce7-9439-34653001fc3b','A2C8E62B-38F5-4553-B31E-DD66C696CEBB','Default 
key','reusable','2021-08-19 20:46:20.005936822+02:00','2321-09-18 20:46:20.005936822+02:00','2021-08-19 20:46:20.005936822+02:00',0,0,NULL,'[]',0,0); INSERT INTO peers VALUES('cfefqs706sqkneg59g4g','bf1c8084-ba50-4ce7-9439-34653001fc3b','MI5mHfJhbggPfD3FqEIsXm8X5bSWeUI2LhO9MpEEtWA=','','"100.103.179.238"','Ubuntu-2204-jammy-amd64-base','linux','Linux','22.04','x86_64','Ubuntu','','development','','',NULL,'','','','{"Cloud":"","Platform":""}',NULL,'crocodile','crocodile','2023-02-13 12:37:12.635454796+00:00',1,0,0,'edafee4e-63fb-11ec-90d6-0242ac120003','AAAAC3NzaC1lZDI1NTE5AAAAIJN1NM4bpB9K',0,0,'2024-10-02 14:04:23.523293+00:00','2024-10-02 16:04:23.538926+02:00',0,'""','','',0); INSERT INTO peers VALUES('cfeg6sf06sqkneg59g50','bf1c8084-ba50-4ce7-9439-34653001fc3b','zMAOKUeIYIuun4n0xPR1b3IdYZPmsyjYmB2jWCuloC4=','','"100.103.26.180"','borg','linux','Linux','22.04','x86_64','Ubuntu','','development','','',NULL,'','','','{"Cloud":"","Platform":""}',NULL,'dingo','dingo','2023-02-21 09:37:42.565899199+00:00',0,0,0,'f4f6d672-63fb-11ec-90d6-0242ac120003','AAAAC3NzaC1lZDI1NTE5AAAAILHW',1,0,'2024-10-02 14:04:23.523293+00:00','2024-10-02 16:04:23.538926+02:00',0,'""','','',0); -INSERT INTO users VALUES('edafee4e-63fb-11ec-90d6-0242ac120003','bf1c8084-ba50-4ce7-9439-34653001fc3b','admin',0,0,'','[]',0,'0001-01-01 00:00:00+00:00','2024-10-02 16:04:23.539152+02:00','api',0,''); -INSERT INTO users VALUES('f4f6d672-63fb-11ec-90d6-0242ac120003','bf1c8084-ba50-4ce7-9439-34653001fc3b','user',0,0,'','[]',0,'0001-01-01 00:00:00+00:00','2024-10-02 16:04:23.539152+02:00','api',0,''); +INSERT INTO users VALUES('edafee4e-63fb-11ec-90d6-0242ac120003','bf1c8084-ba50-4ce7-9439-34653001fc3b','admin',0,0,'','[]',0,NULL,'2024-10-02 16:04:23.539152+02:00','api',0,''); +INSERT INTO users VALUES('f4f6d672-63fb-11ec-90d6-0242ac120003','bf1c8084-ba50-4ce7-9439-34653001fc3b','user',0,0,'','[]',0,NULL,'2024-10-02 16:04:23.539152+02:00','api',0,''); INSERT INTO "groups" VALUES('cfefqs706sqkneg59g3g','bf1c8084-ba50-4ce7-9439-34653001fc3b','All','api','["cfefqs706sqkneg59g4g","cfeg6sf06sqkneg59g50"]',0,''); INSERT INTO installations VALUES(1,''); diff --git a/management/server/testdata/store_with_expired_peers.sql b/management/server/testdata/store_with_expired_peers.sql index 100a6470f..518c484d7 100644 --- a/management/server/testdata/store_with_expired_peers.sql +++ b/management/server/testdata/store_with_expired_peers.sql @@ -1,7 +1,7 @@ CREATE TABLE `accounts` (`id` text,`created_by` text,`created_at` datetime,`domain` text,`domain_category` text,`is_domain_primary_account` numeric,`network_identifier` text,`network_net` text,`network_dns` text,`network_serial` integer,`dns_settings_disabled_management_groups` text,`settings_peer_login_expiration_enabled` numeric,`settings_peer_login_expiration` integer,`settings_regular_users_view_blocked` numeric,`settings_groups_propagation_enabled` numeric,`settings_jwt_groups_enabled` numeric,`settings_jwt_groups_claim_name` text,`settings_jwt_allow_groups` text,`settings_extra_peer_approval_enabled` numeric,`settings_extra_integrated_validator_groups` text,PRIMARY KEY (`id`)); -CREATE TABLE `setup_keys` (`id` text,`account_id` text,`key` text,`name` text,`type` text,`created_at` datetime,`expires_at` datetime,`updated_at` datetime,`revoked` numeric,`used_times` integer,`last_used` datetime,`auto_groups` text,`usage_limit` integer,`ephemeral` numeric,PRIMARY KEY (`id`),CONSTRAINT `fk_accounts_setup_keys_g` FOREIGN KEY (`account_id`) REFERENCES `accounts`(`id`)); +CREATE TABLE 
`setup_keys` (`id` text,`account_id` text,`key` text,`name` text,`type` text,`created_at` datetime,`expires_at` datetime,`updated_at` datetime,`revoked` numeric,`used_times` integer,`last_used` datetime DEFAULT NULL,`auto_groups` text,`usage_limit` integer,`ephemeral` numeric,PRIMARY KEY (`id`),CONSTRAINT `fk_accounts_setup_keys_g` FOREIGN KEY (`account_id`) REFERENCES `accounts`(`id`)); CREATE TABLE `peers` (`id` text,`account_id` text,`key` text,`setup_key` text,`ip` text,`meta_hostname` text,`meta_go_os` text,`meta_kernel` text,`meta_core` text,`meta_platform` text,`meta_os` text,`meta_os_version` text,`meta_wt_version` text,`meta_ui_version` text,`meta_kernel_version` text,`meta_network_addresses` text,`meta_system_serial_number` text,`meta_system_product_name` text,`meta_system_manufacturer` text,`meta_environment` text,`meta_files` text,`name` text,`dns_label` text,`peer_status_last_seen` datetime,`peer_status_connected` numeric,`peer_status_login_expired` numeric,`peer_status_requires_approval` numeric,`user_id` text,`ssh_key` text,`ssh_enabled` numeric,`login_expiration_enabled` numeric,`last_login` datetime,`created_at` datetime,`ephemeral` numeric,`location_connection_ip` text,`location_country_code` text,`location_city_name` text,`location_geo_name_id` integer,PRIMARY KEY (`id`),CONSTRAINT `fk_accounts_peers_g` FOREIGN KEY (`account_id`) REFERENCES `accounts`(`id`)); -CREATE TABLE `users` (`id` text,`account_id` text,`role` text,`is_service_user` numeric,`non_deletable` numeric,`service_user_name` text,`auto_groups` text,`blocked` numeric,`last_login` datetime,`created_at` datetime,`issued` text DEFAULT "api",`integration_ref_id` integer,`integration_ref_integration_type` text,PRIMARY KEY (`id`),CONSTRAINT `fk_accounts_users_g` FOREIGN KEY (`account_id`) REFERENCES `accounts`(`id`)); +CREATE TABLE `users` (`id` text,`account_id` text,`role` text,`is_service_user` numeric,`non_deletable` numeric,`service_user_name` text,`auto_groups` text,`blocked` numeric,`last_login` datetime DEFAULT NULL,`created_at` datetime,`issued` text DEFAULT "api",`integration_ref_id` integer,`integration_ref_integration_type` text,PRIMARY KEY (`id`),CONSTRAINT `fk_accounts_users_g` FOREIGN KEY (`account_id`) REFERENCES `accounts`(`id`)); CREATE TABLE `personal_access_tokens` (`id` text,`user_id` text,`name` text,`hashed_token` text,`expiration_date` datetime,`created_by` text,`created_at` datetime,`last_used` datetime,PRIMARY KEY (`id`),CONSTRAINT `fk_users_pa_ts_g` FOREIGN KEY (`user_id`) REFERENCES `users`(`id`)); CREATE TABLE `groups` (`id` text,`account_id` text,`name` text,`issued` text,`peers` text,`integration_ref_id` integer,`integration_ref_integration_type` text,PRIMARY KEY (`id`),CONSTRAINT `fk_accounts_groups_g` FOREIGN KEY (`account_id`) REFERENCES `accounts`(`id`)); CREATE TABLE `policies` (`id` text,`account_id` text,`name` text,`description` text,`enabled` numeric,`source_posture_checks` text,PRIMARY KEY (`id`),CONSTRAINT `fk_accounts_policies` FOREIGN KEY (`account_id`) REFERENCES `accounts`(`id`)); @@ -26,10 +26,10 @@ CREATE INDEX `idx_name_server_groups_account_id` ON `name_server_groups`(`accoun CREATE INDEX `idx_posture_checks_account_id` ON `posture_checks`(`account_id`); INSERT INTO accounts VALUES('bf1c8084-ba50-4ce7-9439-34653001fc3b','','2024-10-02 17:00:32.527528+02:00','test.com','private',1,'af1c8024-ha40-4ce2-9418-34653101fc3c','{"IP":"100.64.0.0","Mask":"//8AAA=="}','',0,'[]',1,3600000000000,0,0,0,'',NULL,NULL,NULL); -INSERT INTO setup_keys 
VALUES('','bf1c8084-ba50-4ce7-9439-34653001fc3b','A2C8E62B-38F5-4553-B31E-DD66C696CEBB','Default key','reusable','2021-08-19 20:46:20.005936822+02:00','2321-09-18 20:46:20.005936822+02:00','2021-08-19 20:46:20.005936822+02:00',0,0,'0001-01-01 00:00:00+00:00','[]',0,0); +INSERT INTO setup_keys VALUES('','bf1c8084-ba50-4ce7-9439-34653001fc3b','A2C8E62B-38F5-4553-B31E-DD66C696CEBB','Default key','reusable','2021-08-19 20:46:20.005936822+02:00','2321-09-18 20:46:20.005936822+02:00','2021-08-19 20:46:20.005936822+02:00',0,0,NULL,'[]',0,0); INSERT INTO peers VALUES('cfvprsrlo1hqoo49ohog','bf1c8084-ba50-4ce7-9439-34653001fc3b','5rvhvriKJZ3S9oxYToVj5TzDM9u9y8cxg7htIMWlYAg=','72546A29-6BC8-4311-BCFC-9CDBF33F1A48','"100.64.114.31"','f2a34f6a4731','linux','Linux','11','unknown','Debian GNU/Linux','','0.12.0','','',NULL,'','','','{"Cloud":"","Platform":""}',NULL,'f2a34f6a4731','f2a34f6a4731','2023-03-02 09:21:02.189035775+01:00',0,0,0,'','ssh-ed25519 AAAAC3NzaC1lZDI1NTE5AAAAILzUUSYG/LGnV8zarb2SGN+tib/PZ+M7cL4WtTzUrTpk',0,1,'2023-03-01 19:48:19.817799698+01:00','2024-10-02 17:00:32.527947+02:00',0,'""','','',0); INSERT INTO peers VALUES('cg05lnblo1hkg2j514p0','bf1c8084-ba50-4ce7-9439-34653001fc3b','RlSy2vzoG2HyMBTUImXOiVhCBiiBa5qD5xzMxkiFDW4=','','"100.64.39.54"','expiredhost','linux','Linux','22.04','x86_64','Ubuntu','','development','','',NULL,'','','','{"Cloud":"","Platform":""}',NULL,'expiredhost','expiredhost','2023-03-02 09:19:57.276717255+01:00',0,1,0,'edafee4e-63fb-11ec-90d6-0242ac120003','ssh-ed25519 AAAAC3NzaC1lZDI1NTE5AAAAIMbK5ZXJsGOOWoBT4OmkPtgdPZe2Q7bDuS/zjn2CZxhK',0,1,'2023-03-02 09:14:21.791679181+01:00','2024-10-02 17:00:32.527947+02:00',0,'""','','',0); INSERT INTO peers VALUES('cg3161rlo1hs9cq94gdg','bf1c8084-ba50-4ce7-9439-34653001fc3b','mVABSKj28gv+JRsf7e0NEGKgSOGTfU/nPB2cpuG56HU=','','"100.64.117.96"','testhost','linux','Linux','22.04','x86_64','Ubuntu','','development','','',NULL,'','','','{"Cloud":"","Platform":""}',NULL,'testhost','testhost','2023-03-06 18:21:27.252010027+01:00',0,0,0,'edafee4e-63fb-11ec-90d6-0242ac120003','ssh-ed25519 AAAAC3NzaC1lZDI1NTE5AAAAINWvvUkFFcrj48CWTkNUb/do/n52i1L5dH4DhGu+4ZuM',0,0,'2023-03-07 09:02:47.442857106+01:00','2024-10-02 17:00:32.527947+02:00',0,'""','','',0); -INSERT INTO users VALUES('f4f6d672-63fb-11ec-90d6-0242ac120003','bf1c8084-ba50-4ce7-9439-34653001fc3b','user',0,0,'','[]',0,'0001-01-01 00:00:00+00:00','2024-10-02 17:00:32.528196+02:00','api',0,''); -INSERT INTO users VALUES('edafee4e-63fb-11ec-90d6-0242ac120003','bf1c8084-ba50-4ce7-9439-34653001fc3b','admin',0,0,'','[]',0,'0001-01-01 00:00:00+00:00','2024-10-02 17:00:32.528196+02:00','api',0,''); +INSERT INTO users VALUES('f4f6d672-63fb-11ec-90d6-0242ac120003','bf1c8084-ba50-4ce7-9439-34653001fc3b','user',0,0,'','[]',0,NULL,'2024-10-02 17:00:32.528196+02:00','api',0,''); +INSERT INTO users VALUES('edafee4e-63fb-11ec-90d6-0242ac120003','bf1c8084-ba50-4ce7-9439-34653001fc3b','admin',0,0,'','[]',0,NULL,'2024-10-02 17:00:32.528196+02:00','api',0,''); INSERT INTO installations VALUES(1,''); diff --git a/management/server/testutil/store.go b/management/server/testutil/store.go index 156a762fb..16438cab8 100644 --- a/management/server/testutil/store.go +++ b/management/server/testutil/store.go @@ -10,36 +10,75 @@ import ( log "github.com/sirupsen/logrus" "github.com/testcontainers/testcontainers-go" + "github.com/testcontainers/testcontainers-go/modules/mysql" "github.com/testcontainers/testcontainers-go/modules/postgres" "github.com/testcontainers/testcontainers-go/wait" ) -func 
CreatePGDB() (func(), error) { +// CreateMysqlTestContainer creates a new MySQL container for testing. +func CreateMysqlTestContainer() (func(), error) { ctx := context.Background() - c, err := postgres.RunContainer(ctx, - testcontainers.WithImage("postgres:alpine"), - postgres.WithDatabase("test"), - postgres.WithUsername("postgres"), - postgres.WithPassword("postgres"), + + myContainer, err := mysql.RunContainer(ctx, + testcontainers.WithImage("mlsmaycon/warmed-mysql:8"), + mysql.WithDatabase("testing"), + mysql.WithUsername("testing"), + mysql.WithPassword("testing"), testcontainers.WithWaitStrategy( - wait.ForLog("database system is ready to accept connections"). - WithOccurrence(2).WithStartupTimeout(15*time.Second)), + wait.ForLog("/usr/sbin/mysqld: ready for connections"). + WithOccurrence(1).WithStartupTimeout(15*time.Second).WithPollInterval(100*time.Millisecond), + ), ) if err != nil { return nil, err } cleanup := func() { - timeout := 10 * time.Second - err = c.Stop(ctx, &timeout) - if err != nil { - log.WithContext(ctx).Warnf("failed to stop container: %s", err) + timeoutCtx, cancelFunc := context.WithTimeout(ctx, 1*time.Second) + defer cancelFunc() + if err = myContainer.Terminate(timeoutCtx); err != nil { + log.WithContext(ctx).Warnf("failed to stop mysql container %s: %s", myContainer.GetContainerID(), err) } } - talksConn, err := c.ConnectionString(ctx) + talksConn, err := myContainer.ConnectionString(ctx) if err != nil { - return cleanup, err + return nil, err } + + return cleanup, os.Setenv("NETBIRD_STORE_ENGINE_MYSQL_DSN", talksConn) +} + +// CreatePostgresTestContainer creates a new PostgreSQL container for testing. +func CreatePostgresTestContainer() (func(), error) { + ctx := context.Background() + + pgContainer, err := postgres.RunContainer(ctx, + testcontainers.WithImage("postgres:16-alpine"), + postgres.WithDatabase("netbird"), + postgres.WithUsername("root"), + postgres.WithPassword("netbird"), + testcontainers.WithWaitStrategy( + wait.ForLog("database system is ready to accept connections"). 
+ WithOccurrence(2).WithStartupTimeout(15*time.Second), + ), + ) + if err != nil { + return nil, err + } + + cleanup := func() { + timeoutCtx, cancelFunc := context.WithTimeout(ctx, 1*time.Second) + defer cancelFunc() + if err = pgContainer.Terminate(timeoutCtx); err != nil { + log.WithContext(ctx).Warnf("failed to stop postgres container %s: %s", pgContainer.GetContainerID(), err) + } + } + + talksConn, err := pgContainer.ConnectionString(ctx) + if err != nil { + return nil, err + } + return cleanup, os.Setenv("NETBIRD_STORE_ENGINE_POSTGRES_DSN", talksConn) } diff --git a/management/server/testutil/store_ios.go b/management/server/testutil/store_ios.go index af2cf7a3f..edde62f1e 100644 --- a/management/server/testutil/store_ios.go +++ b/management/server/testutil/store_ios.go @@ -3,4 +3,14 @@ package testutil -func CreatePGDB() (func(), error) { return func() {}, nil } +func CreatePostgresTestContainer() (func(), error) { + return func() { + // Empty function for Postgres + }, nil +} + +func CreateMysqlTestContainer() (func(), error) { + return func() { + // Empty function for MySQL + }, nil +} diff --git a/management/server/types/personal_access_token.go b/management/server/types/personal_access_token.go index 1bf225856..ff157fcc6 100644 --- a/management/server/types/personal_access_token.go +++ b/management/server/types/personal_access_token.go @@ -8,6 +8,7 @@ import ( "time" b "github.com/hashicorp/go-secure-stdlib/base62" + "github.com/netbirdio/netbird/management/server/util" "github.com/rs/xid" "github.com/netbirdio/netbird/base62" @@ -31,11 +32,11 @@ type PersonalAccessToken struct { UserID string `gorm:"index"` Name string HashedToken string - ExpirationDate time.Time + ExpirationDate *time.Time // scope could be added in future CreatedBy string CreatedAt time.Time - LastUsed time.Time + LastUsed *time.Time } func (t *PersonalAccessToken) Copy() *PersonalAccessToken { @@ -50,6 +51,22 @@ func (t *PersonalAccessToken) Copy() *PersonalAccessToken { } } +// GetExpirationDate returns the expiration time of the token. +func (t *PersonalAccessToken) GetExpirationDate() time.Time { + if t.ExpirationDate != nil { + return *t.ExpirationDate + } + return time.Time{} +} + +// GetLastUsed returns the last time the token was used. 
+func (t *PersonalAccessToken) GetLastUsed() time.Time { + if t.LastUsed != nil { + return *t.LastUsed + } + return time.Time{} +} + // PersonalAccessTokenGenerated holds the new PersonalAccessToken and the plain text version of it type PersonalAccessTokenGenerated struct { PlainToken string @@ -69,10 +86,9 @@ func CreateNewPAT(name string, expirationInDays int, createdBy string) (*Persona ID: xid.New().String(), Name: name, HashedToken: hashedToken, - ExpirationDate: currentTime.AddDate(0, 0, expirationInDays), + ExpirationDate: util.ToPtr(currentTime.AddDate(0, 0, expirationInDays)), CreatedBy: createdBy, CreatedAt: currentTime, - LastUsed: time.Time{}, }, PlainToken: plainToken, }, nil diff --git a/management/server/types/setupkey.go b/management/server/types/setupkey.go index a5cf346a0..2cd835289 100644 --- a/management/server/types/setupkey.go +++ b/management/server/types/setupkey.go @@ -10,6 +10,7 @@ import ( "unicode/utf8" "github.com/google/uuid" + "github.com/netbirdio/netbird/management/server/util" ) const ( @@ -38,14 +39,14 @@ type SetupKey struct { Name string Type SetupKeyType CreatedAt time.Time - ExpiresAt time.Time + ExpiresAt *time.Time UpdatedAt time.Time `gorm:"autoUpdateTime:false"` // Revoked indicates whether the key was revoked or not (we don't remove them for tracking purposes) Revoked bool // UsedTimes indicates how many times the key was used UsedTimes int // LastUsed last time the key was used for peer registration - LastUsed time.Time + LastUsed *time.Time // AutoGroups is a list of Group IDs that are auto assigned to a Peer when it uses this key to register AutoGroups []string `gorm:"serializer:json"` // UsageLimit indicates the number of times this key can be used to enroll a machine. @@ -86,6 +87,22 @@ func (key *SetupKey) EventMeta() map[string]any { return map[string]any{"name": key.Name, "type": key.Type, "key": key.KeySecret} } +// GetLastUsed returns the last used time of the setup key. +func (key *SetupKey) GetLastUsed() time.Time { + if key.LastUsed != nil { + return *key.LastUsed + } + return time.Time{} +} + +// GetExpiresAt returns the expiration time of the setup key. +func (key *SetupKey) GetExpiresAt() time.Time { + if key.ExpiresAt != nil { + return *key.ExpiresAt + } + return time.Time{} +} + // HiddenKey returns the Key value hidden with "*" and a 5 character prefix. // E.g., "831F6*******************************" func HiddenKey(key string, length int) string { @@ -100,7 +117,7 @@ func HiddenKey(key string, length int) string { func (key *SetupKey) IncrementUsage() *SetupKey { c := key.Copy() c.UsedTimes++ - c.LastUsed = time.Now().UTC() + c.LastUsed = util.ToPtr(time.Now().UTC()) return c } @@ -116,10 +133,10 @@ func (key *SetupKey) IsRevoked() bool { // IsExpired if key was expired func (key *SetupKey) IsExpired() bool { - if key.ExpiresAt.IsZero() { + if key.GetExpiresAt().IsZero() { return false } - return time.Now().After(key.ExpiresAt) + return time.Now().After(key.GetExpiresAt()) } // IsOverUsed if the key was used too many times. SetupKey.UsageLimit == 0 indicates the unlimited usage. 
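
The pattern repeated across SetupKey, PersonalAccessToken, Peer and User in these hunks is the same: a nil timestamp pointer now stands for "never", util.ToPtr turns a concrete time into the pointer field, and the Get* accessors fold nil back into the zero time so the existing expiry math and test assertions keep working. A small illustrative test of that convention, using only names introduced in this patch series; the test wrapper itself is made up:

    package types_test

    import (
        "testing"
        "time"

        "github.com/netbirdio/netbird/management/server/types"
        "github.com/netbirdio/netbird/management/server/util"
    )

    func TestNilTimestampConvention(t *testing.T) {
        // validFor == 0 leaves ExpiresAt nil, which GetExpiresAt maps back to the zero time.
        key, _ := types.GenerateSetupKey("example", types.SetupKeyOneOff, 0, []string{}, types.SetupKeyUnlimitedUsage, false)
        if !key.GetExpiresAt().IsZero() || key.IsExpired() {
            t.Fatal("a key created without validFor should never expire")
        }

        // IncrementUsage records a concrete timestamp through the generic helper.
        if key.IncrementUsage().GetLastUsed().IsZero() {
            t.Fatal("IncrementUsage should set LastUsed")
        }

        // util.ToPtr is what the rest of the patch uses to populate the new pointer fields.
        _ = util.ToPtr(time.Now().UTC())
    }
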
@@ -140,9 +157,9 @@ func GenerateSetupKey(name string, t SetupKeyType, validFor time.Duration, autoG limit = 1 } - expiresAt := time.Time{} + var expiresAt *time.Time if validFor != 0 { - expiresAt = time.Now().UTC().Add(validFor) + expiresAt = util.ToPtr(time.Now().UTC().Add(validFor)) } hashedKey := sha256.Sum256([]byte(key)) diff --git a/management/server/types/user.go b/management/server/types/user.go index 5f1b71792..348fbfb22 100644 --- a/management/server/types/user.go +++ b/management/server/types/user.go @@ -84,7 +84,7 @@ type User struct { // Blocked indicates whether the user is blocked. Blocked users can't use the system. Blocked bool // LastLogin is the last time the user logged in to IdP - LastLogin time.Time + LastLogin *time.Time // CreatedAt records the time the user was created CreatedAt time.Time @@ -99,8 +99,16 @@ func (u *User) IsBlocked() bool { return u.Blocked } -func (u *User) LastDashboardLoginChanged(LastLogin time.Time) bool { - return LastLogin.After(u.LastLogin) && !u.LastLogin.IsZero() +func (u *User) LastDashboardLoginChanged(lastLogin time.Time) bool { + return lastLogin.After(u.GetLastLogin()) && !u.GetLastLogin().IsZero() +} + +// GetLastLogin returns the last login time of the user. +func (u *User) GetLastLogin() time.Time { + if u.LastLogin != nil { + return *u.LastLogin + } + return time.Time{} } // HasAdminPower returns true if the user has admin or owner roles, false otherwise @@ -143,7 +151,7 @@ func (u *User) ToUserInfo(userData *idp.UserData, settings *Settings) (*UserInfo Status: string(UserStatusActive), IsServiceUser: u.IsServiceUser, IsBlocked: u.Blocked, - LastLogin: u.LastLogin, + LastLogin: u.GetLastLogin(), Issued: u.Issued, Permissions: UserPermissions{ DashboardView: dashboardViewPermissions, @@ -168,7 +176,7 @@ func (u *User) ToUserInfo(userData *idp.UserData, settings *Settings) (*UserInfo Status: string(userStatus), IsServiceUser: u.IsServiceUser, IsBlocked: u.Blocked, - LastLogin: u.LastLogin, + LastLogin: u.GetLastLogin(), Issued: u.Issued, Permissions: UserPermissions{ DashboardView: dashboardViewPermissions, diff --git a/management/server/user.go b/management/server/user.go index 457721917..fcf3d34ff 100644 --- a/management/server/user.go +++ b/management/server/user.go @@ -880,7 +880,7 @@ func (am *DefaultAccountManager) GetUsersFromAccount(ctx context.Context, accoun continue } if !user.IsServiceUser { - users[user.Id] = userLoggedInOnce(!user.LastLogin.IsZero()) + users[user.Id] = userLoggedInOnce(!user.GetLastLogin().IsZero()) } } queriedUsers, err = am.lookupCache(ctx, users, accountID) diff --git a/management/server/user_test.go b/management/server/user_test.go index 75d88f9c8..a028d164b 100644 --- a/management/server/user_test.go +++ b/management/server/user_test.go @@ -10,6 +10,7 @@ import ( "github.com/eko/gocache/v3/cache" cacheStore "github.com/eko/gocache/v3/store" "github.com/google/go-cmp/cmp" + "github.com/netbirdio/netbird/management/server/util" nbpeer "github.com/netbirdio/netbird/management/server/peer" "github.com/netbirdio/netbird/management/server/store" @@ -321,14 +322,14 @@ func TestUser_Copy(t *testing.T) { ID: "pat1", Name: "First PAT", HashedToken: "SoMeHaShEdToKeN", - ExpirationDate: time.Now().AddDate(0, 0, 7), + ExpirationDate: util.ToPtr(time.Now().AddDate(0, 0, 7)), CreatedBy: "userId", CreatedAt: time.Now(), - LastUsed: time.Now(), + LastUsed: util.ToPtr(time.Now()), }, }, Blocked: false, - LastLogin: time.Now().UTC(), + LastLogin: util.ToPtr(time.Now().UTC()), CreatedAt: time.Now().UTC(), Issued: 
"test", IntegrationReference: integration_reference.IntegrationReference{ diff --git a/management/server/util/util.go b/management/server/util/util.go index ff738781f..d85b55f02 100644 --- a/management/server/util/util.go +++ b/management/server/util/util.go @@ -14,3 +14,8 @@ func Difference(a, b []string) []string { } return diff } + +// ToPtr returns a pointer to the given value. +func ToPtr[T any](value T) *T { + return &value +} From f08605a7f16a69d5d26aa7462b63868e40ec1541 Mon Sep 17 00:00:00 2001 From: Viktor Liu <17948409+lixmal@users.noreply.github.com> Date: Mon, 6 Jan 2025 14:11:43 +0100 Subject: [PATCH 111/159] [client] Enable network map persistence by default (#3152) --- client/internal/connect.go | 3 +-- client/internal/engine.go | 2 +- client/server/server.go | 2 ++ 3 files changed, 4 insertions(+), 3 deletions(-) diff --git a/client/internal/connect.go b/client/internal/connect.go index 782984e27..5cbf54f75 100644 --- a/client/internal/connect.go +++ b/client/internal/connect.go @@ -382,8 +382,7 @@ func (c *ConnectClient) isContextCancelled() bool { // SetNetworkMapPersistence enables or disables network map persistence. // When enabled, the last received network map will be stored and can be retrieved // through the Engine's getLatestNetworkMap method. When disabled, any stored -// network map will be cleared. This functionality is primarily used for debugging -// and should not be enabled during normal operation. +// network map will be cleared. func (c *ConnectClient) SetNetworkMapPersistence(enabled bool) { c.engineMutex.Lock() c.persistNetworkMap = enabled diff --git a/client/internal/engine.go b/client/internal/engine.go index 896104df8..86f3332a8 100644 --- a/client/internal/engine.go +++ b/client/internal/engine.go @@ -1597,7 +1597,7 @@ func (e *Engine) GetLatestNetworkMap() (*mgmProto.NetworkMap, error) { return nil, nil } - // Create a deep copy to avoid external modifications + log.Debugf("Retrieving latest network map with size %d bytes", proto.Size(e.latestNetworkMap)) nm, ok := proto.Clone(e.latestNetworkMap).(*mgmProto.NetworkMap) if !ok { diff --git a/client/server/server.go b/client/server/server.go index 5640ffa39..cba5a1148 100644 --- a/client/server/server.go +++ b/client/server/server.go @@ -91,6 +91,8 @@ func New(ctx context.Context, configPath, logFile string) *Server { signalProbe: internal.NewProbe(), relayProbe: internal.NewProbe(), wgProbe: internal.NewProbe(), + + persistNetworkMap: true, } } From 668aead4c87c460ec8b4cca9cd37d46118d04537 Mon Sep 17 00:00:00 2001 From: Viktor Liu <17948409+lixmal@users.noreply.github.com> Date: Mon, 6 Jan 2025 14:12:28 +0100 Subject: [PATCH 112/159] [misc] remove outdated readme header (#3151) --- README.md | 7 ------- 1 file changed, 7 deletions(-) diff --git a/README.md b/README.md index e7925ae09..0537710e9 100644 --- a/README.md +++ b/README.md @@ -1,10 +1,3 @@ -

-    :hatching_chick: New Release! Device Posture Checks.
-        Learn more
-
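
Looping back to the network-map persistence patch just above the README change: persistence is now the daemon default (persistNetworkMap: true in client/server/server.go), SetNetworkMapPersistence stays available as an opt-out, and the engine hands back a clone of the last map it received, or nil if nothing has arrived yet. A sketch of how a caller inside the netbird module might read it back; the helper function, its package name and the serial logging are illustrative, while the client/internal types and methods are the ones touched by the patch:

    package debugutil

    import (
        log "github.com/sirupsen/logrus"

        "github.com/netbirdio/netbird/client/internal"
        mgmProto "github.com/netbirdio/netbird/management/proto"
    )

    // latestMap returns the engine's stored network map, or nil if none was received yet.
    func latestMap(c *internal.ConnectClient, e *internal.Engine) *mgmProto.NetworkMap {
        // Redundant with the new default, shown only to make the opt-in explicit;
        // passing false clears any stored map.
        c.SetNetworkMapPersistence(true)

        nm, err := e.GetLatestNetworkMap()
        if err != nil {
            log.Errorf("failed to get latest network map: %v", err)
            return nil
        }
        if nm == nil {
            log.Debug("no network map received yet")
            return nil
        }
        log.Infof("latest network map serial: %d", nm.GetSerial())
        return nm
    }
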

From 6848e1e128882b5c36f187c0ef83b1bed1d465e3 Mon Sep 17 00:00:00 2001 From: Viktor Liu <17948409+lixmal@users.noreply.github.com> Date: Mon, 6 Jan 2025 14:16:31 +0100 Subject: [PATCH 113/159] [client] Add rootless container and fix client routes in netstack mode (#3150) --- .goreleaser.yaml | 57 +++++++++++++++++++ client/Dockerfile-rootless | 15 +++++ client/iface/netstack/env.go | 4 ++ client/internal/routemanager/manager.go | 40 +++++++++---- .../systemops/systemops_generic.go | 12 ++++ 5 files changed, 116 insertions(+), 12 deletions(-) create mode 100644 client/Dockerfile-rootless diff --git a/.goreleaser.yaml b/.goreleaser.yaml index e718b3fcd..d6479763e 100644 --- a/.goreleaser.yaml +++ b/.goreleaser.yaml @@ -179,6 +179,51 @@ dockers: - "--label=org.opencontainers.image.revision={{.FullCommit}}" - "--label=org.opencontainers.image.version={{.Version}}" - "--label=maintainer=dev@netbird.io" + + - image_templates: + - netbirdio/netbird:{{ .Version }}-rootless-amd64 + ids: + - netbird + goarch: amd64 + use: buildx + dockerfile: client/Dockerfile-rootless + build_flag_templates: + - "--platform=linux/amd64" + - "--label=org.opencontainers.image.created={{.Date}}" + - "--label=org.opencontainers.image.title={{.ProjectName}}" + - "--label=org.opencontainers.image.version={{.Version}}" + - "--label=org.opencontainers.image.revision={{.FullCommit}}" + - "--label=maintainer=dev@netbird.io" + - image_templates: + - netbirdio/netbird:{{ .Version }}-rootless-arm64v8 + ids: + - netbird + goarch: arm64 + use: buildx + dockerfile: client/Dockerfile-rootless + build_flag_templates: + - "--platform=linux/arm64" + - "--label=org.opencontainers.image.created={{.Date}}" + - "--label=org.opencontainers.image.title={{.ProjectName}}" + - "--label=org.opencontainers.image.version={{.Version}}" + - "--label=org.opencontainers.image.revision={{.FullCommit}}" + - "--label=maintainer=dev@netbird.io" + - image_templates: + - netbirdio/netbird:{{ .Version }}-rootless-arm + ids: + - netbird + goarch: arm + goarm: 6 + use: buildx + dockerfile: client/Dockerfile-rootless + build_flag_templates: + - "--platform=linux/arm" + - "--label=org.opencontainers.image.created={{.Date}}" + - "--label=org.opencontainers.image.title={{.ProjectName}}" + - "--label=org.opencontainers.image.version={{.Version}}" + - "--label=org.opencontainers.image.revision={{.FullCommit}}" + - "--label=maintainer=dev@netbird.io" + - image_templates: - netbirdio/relay:{{ .Version }}-amd64 ids: @@ -377,6 +422,18 @@ docker_manifests: - netbirdio/netbird:{{ .Version }}-arm - netbirdio/netbird:{{ .Version }}-amd64 + - name_template: netbirdio/netbird:{{ .Version }}-rootless + image_templates: + - netbirdio/netbird:{{ .Version }}-rootless-arm64v8 + - netbirdio/netbird:{{ .Version }}-rootless-arm + - netbirdio/netbird:{{ .Version }}-rootless-amd64 + + - name_template: netbirdio/netbird:rootless-latest + image_templates: + - netbirdio/netbird:{{ .Version }}-rootless-arm64v8 + - netbirdio/netbird:{{ .Version }}-rootless-arm + - netbirdio/netbird:{{ .Version }}-rootless-amd64 + - name_template: netbirdio/relay:{{ .Version }} image_templates: - netbirdio/relay:{{ .Version }}-arm64v8 diff --git a/client/Dockerfile-rootless b/client/Dockerfile-rootless new file mode 100644 index 000000000..f206debb5 --- /dev/null +++ b/client/Dockerfile-rootless @@ -0,0 +1,15 @@ +FROM alpine:3.21.0 + +COPY netbird /usr/local/bin/netbird + +RUN apk add --no-cache ca-certificates \ + && adduser -D -h /var/lib/netbird netbird +WORKDIR /var/lib/netbird +USER netbird:netbird + 
+ENV NB_FOREGROUND_MODE=true +ENV NB_USE_NETSTACK_MODE=true +ENV NB_CONFIG=config.json +ENV NB_DAEMON_ADDR=unix://netbird.sock + +ENTRYPOINT [ "/usr/local/bin/netbird", "up" ] diff --git a/client/iface/netstack/env.go b/client/iface/netstack/env.go index c77e39fe0..09889a57e 100644 --- a/client/iface/netstack/env.go +++ b/client/iface/netstack/env.go @@ -15,6 +15,10 @@ func IsEnabled() bool { func ListenAddr() string { sPort := os.Getenv("NB_SOCKS5_LISTENER_PORT") + if sPort == "" { + return listenAddr(DefaultSocks5Port) + } + port, err := strconv.Atoi(sPort) if err != nil { log.Warnf("invalid socks5 listener port, unable to convert it to int, falling back to default: %d", DefaultSocks5Port) diff --git a/client/internal/routemanager/manager.go b/client/internal/routemanager/manager.go index 389e97e2d..e25192398 100644 --- a/client/internal/routemanager/manager.go +++ b/client/internal/routemanager/manager.go @@ -17,6 +17,7 @@ import ( firewall "github.com/netbirdio/netbird/client/firewall/manager" "github.com/netbirdio/netbird/client/iface" "github.com/netbirdio/netbird/client/iface/configurer" + "github.com/netbirdio/netbird/client/iface/netstack" "github.com/netbirdio/netbird/client/internal/dns" "github.com/netbirdio/netbird/client/internal/listener" "github.com/netbirdio/netbird/client/internal/peer" @@ -104,22 +105,43 @@ func NewManager( peerStore: peerStore, } - dm.routeRefCounter = refcounter.New( + dm.setupRefCounters() + + if runtime.GOOS == "android" { + cr := dm.initialClientRoutes(initialRoutes) + dm.notifier.SetInitialClientRoutes(cr) + } + return dm +} + +func (m *DefaultManager) setupRefCounters() { + m.routeRefCounter = refcounter.New( func(prefix netip.Prefix, _ struct{}) (struct{}, error) { - return struct{}{}, sysOps.AddVPNRoute(prefix, wgInterface.ToInterface()) + return struct{}{}, m.sysOps.AddVPNRoute(prefix, m.wgInterface.ToInterface()) }, func(prefix netip.Prefix, _ struct{}) error { - return sysOps.RemoveVPNRoute(prefix, wgInterface.ToInterface()) + return m.sysOps.RemoveVPNRoute(prefix, m.wgInterface.ToInterface()) }, ) - dm.allowedIPsRefCounter = refcounter.New( + if netstack.IsEnabled() { + m.routeRefCounter = refcounter.New( + func(netip.Prefix, struct{}) (struct{}, error) { + return struct{}{}, refcounter.ErrIgnore + }, + func(netip.Prefix, struct{}) error { + return nil + }, + ) + } + + m.allowedIPsRefCounter = refcounter.New( func(prefix netip.Prefix, peerKey string) (string, error) { // save peerKey to use it in the remove function - return peerKey, wgInterface.AddAllowedIP(peerKey, prefix.String()) + return peerKey, m.wgInterface.AddAllowedIP(peerKey, prefix.String()) }, func(prefix netip.Prefix, peerKey string) error { - if err := wgInterface.RemoveAllowedIP(peerKey, prefix.String()); err != nil { + if err := m.wgInterface.RemoveAllowedIP(peerKey, prefix.String()); err != nil { if !errors.Is(err, configurer.ErrPeerNotFound) && !errors.Is(err, configurer.ErrAllowedIPNotFound) { return err } @@ -128,12 +150,6 @@ func NewManager( return nil }, ) - - if runtime.GOOS == "android" { - cr := dm.initialClientRoutes(initialRoutes) - dm.notifier.SetInitialClientRoutes(cr) - } - return dm } // Init sets up the routing diff --git a/client/internal/routemanager/systemops/systemops_generic.go b/client/internal/routemanager/systemops/systemops_generic.go index 3038c3ec5..31b7f3ac2 100644 --- a/client/internal/routemanager/systemops/systemops_generic.go +++ b/client/internal/routemanager/systemops/systemops_generic.go @@ -17,6 +17,7 @@ import ( nberrors 
"github.com/netbirdio/netbird/client/errors" "github.com/netbirdio/netbird/client/iface" + "github.com/netbirdio/netbird/client/iface/netstack" "github.com/netbirdio/netbird/client/internal/routemanager/refcounter" "github.com/netbirdio/netbird/client/internal/routemanager/util" "github.com/netbirdio/netbird/client/internal/routemanager/vars" @@ -62,6 +63,17 @@ func (r *SysOps) setupRefCounter(initAddresses []net.IP, stateManager *statemana r.removeFromRouteTable, ) + if netstack.IsEnabled() { + refCounter = refcounter.New( + func(netip.Prefix, struct{}) (Nexthop, error) { + return Nexthop{}, refcounter.ErrIgnore + }, + func(netip.Prefix, Nexthop) error { + return nil + }, + ) + } + r.refCounter = refCounter return r.setupHooks(initAddresses, stateManager) From 2bd68efc0815e2cc70e43cc2fabd8f1540803181 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Joakim=20Nohlg=C3=A5rd?= Date: Mon, 6 Jan 2025 17:31:35 +0100 Subject: [PATCH 114/159] [relay] Handle IPv6 addresses in X-Real-IP header on relay service (#3085) --- relay/server/listener/ws/listener.go | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/relay/server/listener/ws/listener.go b/relay/server/listener/ws/listener.go index 1ad57d27a..5c62c0826 100644 --- a/relay/server/listener/ws/listener.go +++ b/relay/server/listener/ws/listener.go @@ -96,5 +96,5 @@ func remoteAddr(r *http.Request) string { if r.Header.Get("X-Real-Ip") == "" || r.Header.Get("X-Real-Port") == "" { return r.RemoteAddr } - return fmt.Sprintf("%s:%s", r.Header.Get("X-Real-Ip"), r.Header.Get("X-Real-Port")) + return net.JoinHostPort(r.Header.Get("X-Real-Ip"), r.Header.Get("X-Real-Port")) } From d9905d1a574ad1111b4c7107aa37e02ce0e7fa1e Mon Sep 17 00:00:00 2001 From: Viktor Liu <17948409+lixmal@users.noreply.github.com> Date: Tue, 7 Jan 2025 20:38:18 +0100 Subject: [PATCH 115/159] [client] Add disable system flags (#3153) --- client/Dockerfile-rootless | 1 + client/cmd/system.go | 31 + client/cmd/up.go | 26 + client/firewall/uspfilter/uspfilter.go | 2 + client/internal/config.go | 50 ++ client/internal/connect.go | 5 + client/internal/dns/host.go | 14 + client/internal/dns/server.go | 104 ++- client/internal/dns/server_test.go | 12 +- client/internal/dnsfwd/manager.go | 5 + client/internal/engine.go | 70 +- client/internal/engine_test.go | 9 +- client/internal/routemanager/manager.go | 100 ++- client/internal/routemanager/manager_test.go | 10 +- client/internal/routemanager/server.go | 9 - .../internal/routemanager/server_android.go | 13 +- .../routemanager/server_nonandroid.go | 122 ++- client/proto/daemon.pb.go | 734 ++++++++++-------- client/proto/daemon.proto | 5 + client/server/server.go | 17 + 20 files changed, 817 insertions(+), 522 deletions(-) create mode 100644 client/cmd/system.go delete mode 100644 client/internal/routemanager/server.go diff --git a/client/Dockerfile-rootless b/client/Dockerfile-rootless index f206debb5..62bcaf964 100644 --- a/client/Dockerfile-rootless +++ b/client/Dockerfile-rootless @@ -11,5 +11,6 @@ ENV NB_FOREGROUND_MODE=true ENV NB_USE_NETSTACK_MODE=true ENV NB_CONFIG=config.json ENV NB_DAEMON_ADDR=unix://netbird.sock +ENV NB_DISABLE_DNS=true ENTRYPOINT [ "/usr/local/bin/netbird", "up" ] diff --git a/client/cmd/system.go b/client/cmd/system.go new file mode 100644 index 000000000..f628867a7 --- /dev/null +++ b/client/cmd/system.go @@ -0,0 +1,31 @@ +package cmd + +// Flag constants for system configuration +const ( + disableClientRoutesFlag = "disable-client-routes" + disableServerRoutesFlag = "disable-server-routes" + 
disableDNSFlag = "disable-dns" + disableFirewallFlag = "disable-firewall" +) + +var ( + disableClientRoutes bool + disableServerRoutes bool + disableDNS bool + disableFirewall bool +) + +func init() { + // Add system flags to upCmd + upCmd.PersistentFlags().BoolVar(&disableClientRoutes, disableClientRoutesFlag, false, + "Disable client routes. If enabled, the client won't process client routes received from the management service.") + + upCmd.PersistentFlags().BoolVar(&disableServerRoutes, disableServerRoutesFlag, false, + "Disable server routes. If enabled, the client won't act as a router for server routes received from the management service.") + + upCmd.PersistentFlags().BoolVar(&disableDNS, disableDNSFlag, false, + "Disable DNS. If enabled, the client won't configure DNS settings.") + + upCmd.PersistentFlags().BoolVar(&disableFirewall, disableFirewallFlag, false, + "Disable firewall configuration. If enabled, the client won't modify firewall rules.") +} diff --git a/client/cmd/up.go b/client/cmd/up.go index 05ecce9e0..cd5521371 100644 --- a/client/cmd/up.go +++ b/client/cmd/up.go @@ -147,6 +147,19 @@ func runInForegroundMode(ctx context.Context, cmd *cobra.Command) error { ic.DNSRouteInterval = &dnsRouteInterval } + if cmd.Flag(disableClientRoutesFlag).Changed { + ic.DisableClientRoutes = &disableClientRoutes + } + if cmd.Flag(disableServerRoutesFlag).Changed { + ic.DisableServerRoutes = &disableServerRoutes + } + if cmd.Flag(disableDNSFlag).Changed { + ic.DisableDNS = &disableDNS + } + if cmd.Flag(disableFirewallFlag).Changed { + ic.DisableFirewall = &disableFirewall + } + providedSetupKey, err := getSetupKey() if err != nil { return err @@ -264,6 +277,19 @@ func runInDaemonMode(ctx context.Context, cmd *cobra.Command) error { loginRequest.DnsRouteInterval = durationpb.New(dnsRouteInterval) } + if cmd.Flag(disableClientRoutesFlag).Changed { + loginRequest.DisableClientRoutes = &disableClientRoutes + } + if cmd.Flag(disableServerRoutesFlag).Changed { + loginRequest.DisableServerRoutes = &disableServerRoutes + } + if cmd.Flag(disableDNSFlag).Changed { + loginRequest.DisableDns = &disableDNS + } + if cmd.Flag(disableFirewallFlag).Changed { + loginRequest.DisableFirewall = &disableFirewall + } + var loginErr error var loginResp *proto.LoginResponse diff --git a/client/firewall/uspfilter/uspfilter.go b/client/firewall/uspfilter/uspfilter.go index 24cfd6e96..ebe04caee 100644 --- a/client/firewall/uspfilter/uspfilter.go +++ b/client/firewall/uspfilter/uspfilter.go @@ -402,6 +402,8 @@ func (m *Manager) trackICMPOutbound(d *decoder, srcIP, dstIP net.IP) { // dropFilter implements filtering logic for incoming packets func (m *Manager) dropFilter(packetData []byte, rules map[string]RuleSet) bool { + // TODO: Disable router if --disable-server-router is set + m.mutex.RLock() defer m.mutex.RUnlock() diff --git a/client/internal/config.go b/client/internal/config.go index 998690ef1..594bdc570 100644 --- a/client/internal/config.go +++ b/client/internal/config.go @@ -61,6 +61,11 @@ type ConfigInput struct { DNSRouteInterval *time.Duration ClientCertPath string ClientCertKeyPath string + + DisableClientRoutes *bool + DisableServerRoutes *bool + DisableDNS *bool + DisableFirewall *bool } // Config Configuration type @@ -78,6 +83,12 @@ type Config struct { RosenpassEnabled bool RosenpassPermissive bool ServerSSHAllowed *bool + + DisableClientRoutes bool + DisableServerRoutes bool + DisableDNS bool + DisableFirewall bool + // SSHKey is a private SSH key in a PEM format SSHKey string @@ -402,7 +413,46 
@@ func (config *Config) apply(input ConfigInput) (updated bool, err error) { config.DNSRouteInterval = dynamic.DefaultInterval log.Infof("using default DNS route interval %s", config.DNSRouteInterval) updated = true + } + if input.DisableClientRoutes != nil && *input.DisableClientRoutes != config.DisableClientRoutes { + if *input.DisableClientRoutes { + log.Infof("disabling client routes") + } else { + log.Infof("enabling client routes") + } + config.DisableClientRoutes = *input.DisableClientRoutes + updated = true + } + + if input.DisableServerRoutes != nil && *input.DisableServerRoutes != config.DisableServerRoutes { + if *input.DisableServerRoutes { + log.Infof("disabling server routes") + } else { + log.Infof("enabling server routes") + } + config.DisableServerRoutes = *input.DisableServerRoutes + updated = true + } + + if input.DisableDNS != nil && *input.DisableDNS != config.DisableDNS { + if *input.DisableDNS { + log.Infof("disabling DNS configuration") + } else { + log.Infof("enabling DNS configuration") + } + config.DisableDNS = *input.DisableDNS + updated = true + } + + if input.DisableFirewall != nil && *input.DisableFirewall != config.DisableFirewall { + if *input.DisableFirewall { + log.Infof("disabling firewall configuration") + } else { + log.Infof("enabling firewall configuration") + } + config.DisableFirewall = *input.DisableFirewall + updated = true } if input.ClientCertKeyPath != "" { diff --git a/client/internal/connect.go b/client/internal/connect.go index 5cbf54f75..afd1f4454 100644 --- a/client/internal/connect.go +++ b/client/internal/connect.go @@ -415,6 +415,11 @@ func createEngineConfig(key wgtypes.Key, config *Config, peerConfig *mgmProto.Pe RosenpassPermissive: config.RosenpassPermissive, ServerSSHAllowed: util.ReturnBoolWithDefaultTrue(config.ServerSSHAllowed), DNSRouteInterval: config.DNSRouteInterval, + + DisableClientRoutes: config.DisableClientRoutes, + DisableServerRoutes: config.DisableServerRoutes, + DisableDNS: config.DisableDNS, + DisableFirewall: config.DisableFirewall, } if config.PreSharedKey != "" { diff --git a/client/internal/dns/host.go b/client/internal/dns/host.go index e2b5f699a..fbe8c4dbb 100644 --- a/client/internal/dns/host.go +++ b/client/internal/dns/host.go @@ -102,3 +102,17 @@ func dnsConfigToHostDNSConfig(dnsConfig nbdns.Config, ip string, port int) HostD return config } + +type noopHostConfigurator struct{} + +func (n noopHostConfigurator) applyDNSConfig(HostDNSConfig, *statemanager.Manager) error { + return nil +} + +func (n noopHostConfigurator) restoreHostDNS() error { + return nil +} + +func (n noopHostConfigurator) supportCustomPort() bool { + return true +} diff --git a/client/internal/dns/server.go b/client/internal/dns/server.go index 5a9cb50d0..bb097c4cb 100644 --- a/client/internal/dns/server.go +++ b/client/internal/dns/server.go @@ -47,6 +47,7 @@ type registeredHandlerMap map[string]handlerWithStop type DefaultServer struct { ctx context.Context ctxCancel context.CancelFunc + disableSys bool mux sync.Mutex service service dnsMuxMap registeredHandlerMap @@ -84,7 +85,14 @@ type muxUpdate struct { } // NewDefaultServer returns a new dns server -func NewDefaultServer(ctx context.Context, wgInterface WGIface, customAddress string, statusRecorder *peer.Status, stateManager *statemanager.Manager) (*DefaultServer, error) { +func NewDefaultServer( + ctx context.Context, + wgInterface WGIface, + customAddress string, + statusRecorder *peer.Status, + stateManager *statemanager.Manager, + disableSys bool, +) (*DefaultServer, error) 
{ var addrPort *netip.AddrPort if customAddress != "" { parsedAddrPort, err := netip.ParseAddrPort(customAddress) @@ -101,7 +109,7 @@ func NewDefaultServer(ctx context.Context, wgInterface WGIface, customAddress st dnsService = newServiceViaListener(wgInterface, addrPort) } - return newDefaultServer(ctx, wgInterface, dnsService, statusRecorder, stateManager), nil + return newDefaultServer(ctx, wgInterface, dnsService, statusRecorder, stateManager, disableSys), nil } // NewDefaultServerPermanentUpstream returns a new dns server. It optimized for mobile systems @@ -112,9 +120,10 @@ func NewDefaultServerPermanentUpstream( config nbdns.Config, listener listener.NetworkChangeListener, statusRecorder *peer.Status, + disableSys bool, ) *DefaultServer { log.Debugf("host dns address list is: %v", hostsDnsList) - ds := newDefaultServer(ctx, wgInterface, NewServiceViaMemory(wgInterface), statusRecorder, nil) + ds := newDefaultServer(ctx, wgInterface, NewServiceViaMemory(wgInterface), statusRecorder, nil, disableSys) ds.hostsDNSHolder.set(hostsDnsList) ds.permanent = true ds.addHostRootZone() @@ -131,17 +140,26 @@ func NewDefaultServerIos( wgInterface WGIface, iosDnsManager IosDnsManager, statusRecorder *peer.Status, + disableSys bool, ) *DefaultServer { - ds := newDefaultServer(ctx, wgInterface, NewServiceViaMemory(wgInterface), statusRecorder, nil) + ds := newDefaultServer(ctx, wgInterface, NewServiceViaMemory(wgInterface), statusRecorder, nil, disableSys) ds.iosDnsManager = iosDnsManager return ds } -func newDefaultServer(ctx context.Context, wgInterface WGIface, dnsService service, statusRecorder *peer.Status, stateManager *statemanager.Manager) *DefaultServer { +func newDefaultServer( + ctx context.Context, + wgInterface WGIface, + dnsService service, + statusRecorder *peer.Status, + stateManager *statemanager.Manager, + disableSys bool, +) *DefaultServer { ctx, stop := context.WithCancel(ctx) defaultServer := &DefaultServer{ ctx: ctx, ctxCancel: stop, + disableSys: disableSys, service: dnsService, handlerChain: NewHandlerChain(), dnsMuxMap: make(registeredHandlerMap), @@ -220,6 +238,13 @@ func (s *DefaultServer) Initialize() (err error) { } s.stateManager.RegisterState(&ShutdownState{}) + + if s.disableSys { + log.Info("system DNS is disabled, not setting up host manager") + s.hostManager = &noopHostConfigurator{} + return nil + } + s.hostManager, err = s.initialize() if err != nil { return fmt.Errorf("initialize: %w", err) @@ -268,47 +293,47 @@ func (s *DefaultServer) OnUpdatedHostDNSServer(hostsDnsList []string) { // UpdateDNSServer processes an update received from the management service func (s *DefaultServer) UpdateDNSServer(serial uint64, update nbdns.Config) error { - select { - case <-s.ctx.Done(): + if s.ctx.Err() != nil { log.Infof("not updating DNS server as context is closed") return s.ctx.Err() - default: - if serial < s.updateSerial { - return fmt.Errorf("not applying dns update, error: "+ - "network update is %d behind the last applied update", s.updateSerial-serial) - } - s.mux.Lock() - defer s.mux.Unlock() + } - if s.hostManager == nil { - return fmt.Errorf("dns service is not initialized yet") - } + if serial < s.updateSerial { + return fmt.Errorf("not applying dns update, error: "+ + "network update is %d behind the last applied update", s.updateSerial-serial) + } - hash, err := hashstructure.Hash(update, hashstructure.FormatV2, &hashstructure.HashOptions{ - ZeroNil: true, - IgnoreZeroValue: true, - SlicesAsSets: true, - UseStringer: true, - }) - if err != nil { - 
log.Errorf("unable to hash the dns configuration update, got error: %s", err) - } + s.mux.Lock() + defer s.mux.Unlock() - if s.previousConfigHash == hash { - log.Debugf("not applying the dns configuration update as there is nothing new") - s.updateSerial = serial - return nil - } + if s.hostManager == nil { + return fmt.Errorf("dns service is not initialized yet") + } - if err := s.applyConfiguration(update); err != nil { - return fmt.Errorf("apply configuration: %w", err) - } + hash, err := hashstructure.Hash(update, hashstructure.FormatV2, &hashstructure.HashOptions{ + ZeroNil: true, + IgnoreZeroValue: true, + SlicesAsSets: true, + UseStringer: true, + }) + if err != nil { + log.Errorf("unable to hash the dns configuration update, got error: %s", err) + } + if s.previousConfigHash == hash { + log.Debugf("not applying the dns configuration update as there is nothing new") s.updateSerial = serial - s.previousConfigHash = hash - return nil } + + if err := s.applyConfiguration(update); err != nil { + return fmt.Errorf("apply configuration: %w", err) + } + + s.updateSerial = serial + s.previousConfigHash = hash + + return nil } func (s *DefaultServer) SearchDomains() []string { @@ -627,8 +652,11 @@ func (s *DefaultServer) upstreamCallbacks( s.currentConfig.RouteAll = true s.registerHandler([]string{nbdns.RootZone}, handler, PriorityDefault) } - if err := s.hostManager.applyDNSConfig(s.currentConfig, s.stateManager); err != nil { - l.WithError(err).Error("reactivate temporary disabled nameserver group, DNS update apply") + + if s.hostManager != nil { + if err := s.hostManager.applyDNSConfig(s.currentConfig, s.stateManager); err != nil { + l.WithError(err).Error("reactivate temporary disabled nameserver group, DNS update apply") + } } s.updateNSState(nsGroup, nil, true) diff --git a/client/internal/dns/server_test.go b/client/internal/dns/server_test.go index 44d20c6f3..c166820c4 100644 --- a/client/internal/dns/server_test.go +++ b/client/internal/dns/server_test.go @@ -294,7 +294,7 @@ func TestUpdateDNSServer(t *testing.T) { t.Log(err) } }() - dnsServer, err := NewDefaultServer(context.Background(), wgIface, "", &peer.Status{}, nil) + dnsServer, err := NewDefaultServer(context.Background(), wgIface, "", &peer.Status{}, nil, false) if err != nil { t.Fatal(err) } @@ -403,7 +403,7 @@ func TestDNSFakeResolverHandleUpdates(t *testing.T) { return } - dnsServer, err := NewDefaultServer(context.Background(), wgIface, "", &peer.Status{}, nil) + dnsServer, err := NewDefaultServer(context.Background(), wgIface, "", &peer.Status{}, nil, false) if err != nil { t.Errorf("create DNS server: %v", err) return @@ -498,7 +498,7 @@ func TestDNSServerStartStop(t *testing.T) { for _, testCase := range testCases { t.Run(testCase.name, func(t *testing.T) { - dnsServer, err := NewDefaultServer(context.Background(), &mocWGIface{}, testCase.addrPort, &peer.Status{}, nil) + dnsServer, err := NewDefaultServer(context.Background(), &mocWGIface{}, testCase.addrPort, &peer.Status{}, nil, false) if err != nil { t.Fatalf("%v", err) } @@ -633,7 +633,7 @@ func TestDNSPermanent_updateHostDNS_emptyUpstream(t *testing.T) { var dnsList []string dnsConfig := nbdns.Config{} - dnsServer := NewDefaultServerPermanentUpstream(context.Background(), wgIFace, dnsList, dnsConfig, nil, &peer.Status{}) + dnsServer := NewDefaultServerPermanentUpstream(context.Background(), wgIFace, dnsList, dnsConfig, nil, &peer.Status{}, false) err = dnsServer.Initialize() if err != nil { t.Errorf("failed to initialize DNS server: %v", err) @@ -657,7 +657,7 @@ 
func TestDNSPermanent_updateUpstream(t *testing.T) { } defer wgIFace.Close() dnsConfig := nbdns.Config{} - dnsServer := NewDefaultServerPermanentUpstream(context.Background(), wgIFace, []string{"8.8.8.8"}, dnsConfig, nil, &peer.Status{}) + dnsServer := NewDefaultServerPermanentUpstream(context.Background(), wgIFace, []string{"8.8.8.8"}, dnsConfig, nil, &peer.Status{}, false) err = dnsServer.Initialize() if err != nil { t.Errorf("failed to initialize DNS server: %v", err) @@ -749,7 +749,7 @@ func TestDNSPermanent_matchOnly(t *testing.T) { } defer wgIFace.Close() dnsConfig := nbdns.Config{} - dnsServer := NewDefaultServerPermanentUpstream(context.Background(), wgIFace, []string{"8.8.8.8"}, dnsConfig, nil, &peer.Status{}) + dnsServer := NewDefaultServerPermanentUpstream(context.Background(), wgIFace, []string{"8.8.8.8"}, dnsConfig, nil, &peer.Status{}, false) err = dnsServer.Initialize() if err != nil { t.Errorf("failed to initialize DNS server: %v", err) diff --git a/client/internal/dnsfwd/manager.go b/client/internal/dnsfwd/manager.go index f876bda30..e6dfd278e 100644 --- a/client/internal/dnsfwd/manager.go +++ b/client/internal/dnsfwd/manager.go @@ -83,6 +83,11 @@ func (h *Manager) allowDNSFirewall() error { IsRange: false, Values: []int{ListenPort}, } + + if h.firewall == nil { + return nil + } + dnsRules, err := h.firewall.AddPeerFiltering(net.IP{0, 0, 0, 0}, firewall.ProtocolUDP, nil, dport, firewall.RuleDirectionIN, firewall.ActionAccept, "", "") if err != nil { log.Errorf("failed to add allow DNS router rules, err: %v", err) diff --git a/client/internal/engine.go b/client/internal/engine.go index 86f3332a8..7b6f269df 100644 --- a/client/internal/engine.go +++ b/client/internal/engine.go @@ -108,6 +108,11 @@ type EngineConfig struct { ServerSSHAllowed bool DNSRouteInterval time.Duration + + DisableClientRoutes bool + DisableServerRoutes bool + DisableDNS bool + DisableFirewall bool } // Engine is a mechanism responsible for reacting on Signal and Management stream events and managing connections to the remote peers. 
@@ -374,18 +379,20 @@ func (e *Engine) Start() error { } e.dnsServer = dnsServer - e.routeManager = routemanager.NewManager( - e.ctx, - e.config.WgPrivateKey.PublicKey().String(), - e.config.DNSRouteInterval, - e.wgInterface, - e.statusRecorder, - e.relayManager, - initialRoutes, - e.stateManager, - dnsServer, - e.peerStore, - ) + e.routeManager = routemanager.NewManager(routemanager.ManagerConfig{ + Context: e.ctx, + PublicKey: e.config.WgPrivateKey.PublicKey().String(), + DNSRouteInterval: e.config.DNSRouteInterval, + WGInterface: e.wgInterface, + StatusRecorder: e.statusRecorder, + RelayManager: e.relayManager, + InitialRoutes: initialRoutes, + StateManager: e.stateManager, + DNSServer: dnsServer, + PeerStore: e.peerStore, + DisableClientRoutes: e.config.DisableClientRoutes, + DisableServerRoutes: e.config.DisableServerRoutes, + }) beforePeerHook, afterPeerHook, err := e.routeManager.Init() if err != nil { log.Errorf("Failed to initialize route manager: %s", err) @@ -403,13 +410,8 @@ func (e *Engine) Start() error { return fmt.Errorf("create wg interface: %w", err) } - e.firewall, err = firewall.NewFirewall(e.wgInterface, e.stateManager) - if err != nil { - log.Errorf("failed creating firewall manager: %s", err) - } else if e.firewall != nil { - if err := e.initFirewall(err); err != nil { - return err - } + if err := e.createFirewall(); err != nil { + return err } e.udpMux, err = e.wgInterface.Up() @@ -451,7 +453,27 @@ func (e *Engine) Start() error { return nil } -func (e *Engine) initFirewall(error) error { +func (e *Engine) createFirewall() error { + if e.config.DisableFirewall { + log.Infof("firewall is disabled") + return nil + } + + var err error + e.firewall, err = firewall.NewFirewall(e.wgInterface, e.stateManager) + if err != nil || e.firewall == nil { + log.Errorf("failed creating firewall manager: %s", err) + return nil + } + + if err := e.initFirewall(); err != nil { + return err + } + + return nil +} + +func (e *Engine) initFirewall() error { if e.firewall.IsServerRouteSupported() { if err := e.routeManager.EnableServerRouter(e.firewall); err != nil { e.close() @@ -1348,6 +1370,7 @@ func (e *Engine) newDnsServer() ([]*route.Route, dns.Server, error) { if e.dnsServer != nil { return nil, e.dnsServer, nil } + switch runtime.GOOS { case "android": routes, dnsConfig, err := e.readInitialSettings() @@ -1361,14 +1384,17 @@ func (e *Engine) newDnsServer() ([]*route.Route, dns.Server, error) { *dnsConfig, e.mobileDep.NetworkChangeListener, e.statusRecorder, + e.config.DisableDNS, ) go e.mobileDep.DnsReadyListener.OnReady() return routes, dnsServer, nil + case "ios": - dnsServer := dns.NewDefaultServerIos(e.ctx, e.wgInterface, e.mobileDep.DnsManager, e.statusRecorder) + dnsServer := dns.NewDefaultServerIos(e.ctx, e.wgInterface, e.mobileDep.DnsManager, e.statusRecorder, e.config.DisableDNS) return nil, dnsServer, nil + default: - dnsServer, err := dns.NewDefaultServer(e.ctx, e.wgInterface, e.config.CustomDNSAddress, e.statusRecorder, e.stateManager) + dnsServer, err := dns.NewDefaultServer(e.ctx, e.wgInterface, e.config.CustomDNSAddress, e.statusRecorder, e.stateManager, e.config.DisableDNS) if err != nil { return nil, nil, err } diff --git a/client/internal/engine_test.go b/client/internal/engine_test.go index b81d8bd3f..1deea1cb8 100644 --- a/client/internal/engine_test.go +++ b/client/internal/engine_test.go @@ -253,7 +253,14 @@ func TestEngine_UpdateNetworkMap(t *testing.T) { }, } engine.wgInterface = wgIface - engine.routeManager = routemanager.NewManager(ctx, 
key.PublicKey().String(), time.Minute, engine.wgInterface, engine.statusRecorder, relayMgr, nil, nil, nil, nil) + engine.routeManager = routemanager.NewManager(routemanager.ManagerConfig{ + Context: ctx, + PublicKey: key.PublicKey().String(), + DNSRouteInterval: time.Minute, + WGInterface: engine.wgInterface, + StatusRecorder: engine.statusRecorder, + RelayManager: relayMgr, + }) _, _, err = engine.routeManager.Init() require.NoError(t, err) engine.dnsServer = &dns.MockServer{ diff --git a/client/internal/routemanager/manager.go b/client/internal/routemanager/manager.go index e25192398..6f73fb166 100644 --- a/client/internal/routemanager/manager.go +++ b/client/internal/routemanager/manager.go @@ -48,6 +48,21 @@ type Manager interface { Stop(stateManager *statemanager.Manager) } +type ManagerConfig struct { + Context context.Context + PublicKey string + DNSRouteInterval time.Duration + WGInterface iface.IWGIface + StatusRecorder *peer.Status + RelayManager *relayClient.Manager + InitialRoutes []*route.Route + StateManager *statemanager.Manager + DNSServer dns.Server + PeerStore *peerstore.Store + DisableClientRoutes bool + DisableServerRoutes bool +} + // DefaultManager is the default instance of a route manager type DefaultManager struct { ctx context.Context @@ -55,7 +70,7 @@ type DefaultManager struct { mux sync.Mutex clientNetworks map[route.HAUniqueID]*clientNetwork routeSelector *routeselector.RouteSelector - serverRouter serverRouter + serverRouter *serverRouter sysOps *systemops.SysOps statusRecorder *peer.Status relayMgr *relayClient.Manager @@ -67,48 +82,46 @@ type DefaultManager struct { dnsRouteInterval time.Duration stateManager *statemanager.Manager // clientRoutes is the most recent list of clientRoutes received from the Management Service - clientRoutes route.HAMap - dnsServer dns.Server - peerStore *peerstore.Store - useNewDNSRoute bool + clientRoutes route.HAMap + dnsServer dns.Server + peerStore *peerstore.Store + useNewDNSRoute bool + disableClientRoutes bool + disableServerRoutes bool } -func NewManager( - ctx context.Context, - pubKey string, - dnsRouteInterval time.Duration, - wgInterface iface.IWGIface, - statusRecorder *peer.Status, - relayMgr *relayClient.Manager, - initialRoutes []*route.Route, - stateManager *statemanager.Manager, - dnsServer dns.Server, - peerStore *peerstore.Store, -) *DefaultManager { - mCTX, cancel := context.WithCancel(ctx) +func NewManager(config ManagerConfig) *DefaultManager { + mCTX, cancel := context.WithCancel(config.Context) notifier := notifier.NewNotifier() - sysOps := systemops.NewSysOps(wgInterface, notifier) + sysOps := systemops.NewSysOps(config.WGInterface, notifier) dm := &DefaultManager{ - ctx: mCTX, - stop: cancel, - dnsRouteInterval: dnsRouteInterval, - clientNetworks: make(map[route.HAUniqueID]*clientNetwork), - relayMgr: relayMgr, - sysOps: sysOps, - statusRecorder: statusRecorder, - wgInterface: wgInterface, - pubKey: pubKey, - notifier: notifier, - stateManager: stateManager, - dnsServer: dnsServer, - peerStore: peerStore, + ctx: mCTX, + stop: cancel, + dnsRouteInterval: config.DNSRouteInterval, + clientNetworks: make(map[route.HAUniqueID]*clientNetwork), + relayMgr: config.RelayManager, + sysOps: sysOps, + statusRecorder: config.StatusRecorder, + wgInterface: config.WGInterface, + pubKey: config.PublicKey, + notifier: notifier, + stateManager: config.StateManager, + dnsServer: config.DNSServer, + peerStore: config.PeerStore, + disableClientRoutes: config.DisableClientRoutes, + disableServerRoutes: 
config.DisableServerRoutes, + } + + // don't proceed with client routes if it is disabled + if config.DisableClientRoutes { + return dm } dm.setupRefCounters() if runtime.GOOS == "android" { - cr := dm.initialClientRoutes(initialRoutes) + cr := dm.initialClientRoutes(config.InitialRoutes) dm.notifier.SetInitialClientRoutes(cr) } return dm @@ -156,7 +169,7 @@ func (m *DefaultManager) setupRefCounters() { func (m *DefaultManager) Init() (nbnet.AddHookFunc, nbnet.RemoveHookFunc, error) { m.routeSelector = m.initSelector() - if nbnet.CustomRoutingDisabled() { + if nbnet.CustomRoutingDisabled() || m.disableClientRoutes { return nil, nil, nil } @@ -202,6 +215,15 @@ func (m *DefaultManager) initSelector() *routeselector.RouteSelector { } func (m *DefaultManager) EnableServerRouter(firewall firewall.Manager) error { + if m.disableServerRoutes { + log.Info("server routes are disabled") + return nil + } + + if firewall == nil { + return errors.New("firewall manager is not set") + } + var err error m.serverRouter, err = newServerRouter(m.ctx, m.wgInterface, firewall, m.statusRecorder) if err != nil { @@ -228,7 +250,7 @@ func (m *DefaultManager) Stop(stateManager *statemanager.Manager) { } } - if !nbnet.CustomRoutingDisabled() { + if !nbnet.CustomRoutingDisabled() && !m.disableClientRoutes { if err := m.sysOps.CleanupRouting(stateManager); err != nil { log.Errorf("Error cleaning up routing: %v", err) } else { @@ -258,9 +280,11 @@ func (m *DefaultManager) UpdateRoutes(updateSerial uint64, newRoutes []*route.Ro newServerRoutesMap, newClientRoutesIDMap := m.classifyRoutes(newRoutes) - filteredClientRoutes := m.routeSelector.FilterSelected(newClientRoutesIDMap) - m.updateClientNetworks(updateSerial, filteredClientRoutes) - m.notifier.OnNewRoutes(filteredClientRoutes) + if !m.disableClientRoutes { + filteredClientRoutes := m.routeSelector.FilterSelected(newClientRoutesIDMap) + m.updateClientNetworks(updateSerial, filteredClientRoutes) + m.notifier.OnNewRoutes(filteredClientRoutes) + } if m.serverRouter != nil { err := m.serverRouter.updateRoutes(newServerRoutesMap) diff --git a/client/internal/routemanager/manager_test.go b/client/internal/routemanager/manager_test.go index 4b7c984e5..318ef5ae5 100644 --- a/client/internal/routemanager/manager_test.go +++ b/client/internal/routemanager/manager_test.go @@ -424,7 +424,12 @@ func TestManagerUpdateRoutes(t *testing.T) { statusRecorder := peer.NewRecorder("https://mgm") ctx := context.TODO() - routeManager := NewManager(ctx, localPeerKey, 0, wgInterface, statusRecorder, nil, nil, nil, nil, nil) + routeManager := NewManager(ManagerConfig{ + Context: ctx, + PublicKey: localPeerKey, + WGInterface: wgInterface, + StatusRecorder: statusRecorder, + }) _, _, err = routeManager.Init() @@ -450,8 +455,7 @@ func TestManagerUpdateRoutes(t *testing.T) { require.Len(t, routeManager.clientNetworks, expectedWatchers, "client networks size should match") if runtime.GOOS == "linux" && routeManager.serverRouter != nil { - sr := routeManager.serverRouter.(*defaultServerRouter) - require.Len(t, sr.routes, testCase.serverRoutesExpected, "server networks size should match") + require.Len(t, routeManager.serverRouter.routes, testCase.serverRoutesExpected, "server networks size should match") } }) } diff --git a/client/internal/routemanager/server.go b/client/internal/routemanager/server.go deleted file mode 100644 index 368421eb7..000000000 --- a/client/internal/routemanager/server.go +++ /dev/null @@ -1,9 +0,0 @@ -package routemanager - -import "github.com/netbirdio/netbird/route" - 
-type serverRouter interface { - updateRoutes(map[route.ID]*route.Route) error - removeFromServerNetwork(*route.Route) error - cleanUp() -} diff --git a/client/internal/routemanager/server_android.go b/client/internal/routemanager/server_android.go index c75a0a7f2..e9cfa0826 100644 --- a/client/internal/routemanager/server_android.go +++ b/client/internal/routemanager/server_android.go @@ -9,8 +9,19 @@ import ( firewall "github.com/netbirdio/netbird/client/firewall/manager" "github.com/netbirdio/netbird/client/iface" "github.com/netbirdio/netbird/client/internal/peer" + "github.com/netbirdio/netbird/route" ) -func newServerRouter(context.Context, iface.IWGIface, firewall.Manager, *peer.Status) (serverRouter, error) { +type serverRouter struct { +} + +func (r serverRouter) cleanUp() { +} + +func (r serverRouter) updateRoutes(map[route.ID]*route.Route) error { + return nil +} + +func newServerRouter(context.Context, iface.IWGIface, firewall.Manager, *peer.Status) (*serverRouter, error) { return nil, fmt.Errorf("server route not supported on this os") } diff --git a/client/internal/routemanager/server_nonandroid.go b/client/internal/routemanager/server_nonandroid.go index ef38d5707..b60cb318e 100644 --- a/client/internal/routemanager/server_nonandroid.go +++ b/client/internal/routemanager/server_nonandroid.go @@ -17,7 +17,7 @@ import ( "github.com/netbirdio/netbird/route" ) -type defaultServerRouter struct { +type serverRouter struct { mux sync.Mutex ctx context.Context routes map[route.ID]*route.Route @@ -26,8 +26,8 @@ type defaultServerRouter struct { statusRecorder *peer.Status } -func newServerRouter(ctx context.Context, wgInterface iface.IWGIface, firewall firewall.Manager, statusRecorder *peer.Status) (serverRouter, error) { - return &defaultServerRouter{ +func newServerRouter(ctx context.Context, wgInterface iface.IWGIface, firewall firewall.Manager, statusRecorder *peer.Status) (*serverRouter, error) { + return &serverRouter{ ctx: ctx, routes: make(map[route.ID]*route.Route), firewall: firewall, @@ -36,7 +36,7 @@ func newServerRouter(ctx context.Context, wgInterface iface.IWGIface, firewall f }, nil } -func (m *defaultServerRouter) updateRoutes(routesMap map[route.ID]*route.Route) error { +func (m *serverRouter) updateRoutes(routesMap map[route.ID]*route.Route) error { serverRoutesToRemove := make([]route.ID, 0) for routeID := range m.routes { @@ -80,74 +80,72 @@ func (m *defaultServerRouter) updateRoutes(routesMap map[route.ID]*route.Route) return nil } -func (m *defaultServerRouter) removeFromServerNetwork(route *route.Route) error { - select { - case <-m.ctx.Done(): +func (m *serverRouter) removeFromServerNetwork(route *route.Route) error { + if m.ctx.Err() != nil { log.Infof("Not removing from server network because context is done") return m.ctx.Err() - default: - m.mux.Lock() - defer m.mux.Unlock() - - routerPair, err := routeToRouterPair(route) - if err != nil { - return fmt.Errorf("parse prefix: %w", err) - } - - err = m.firewall.RemoveNatRule(routerPair) - if err != nil { - return fmt.Errorf("remove routing rules: %w", err) - } - - delete(m.routes, route.ID) - - state := m.statusRecorder.GetLocalPeerState() - delete(state.Routes, route.Network.String()) - m.statusRecorder.UpdateLocalPeerState(state) - - return nil } + + m.mux.Lock() + defer m.mux.Unlock() + + routerPair, err := routeToRouterPair(route) + if err != nil { + return fmt.Errorf("parse prefix: %w", err) + } + + err = m.firewall.RemoveNatRule(routerPair) + if err != nil { + return fmt.Errorf("remove routing rules: 
%w", err) + } + + delete(m.routes, route.ID) + + state := m.statusRecorder.GetLocalPeerState() + delete(state.Routes, route.Network.String()) + m.statusRecorder.UpdateLocalPeerState(state) + + return nil } -func (m *defaultServerRouter) addToServerNetwork(route *route.Route) error { - select { - case <-m.ctx.Done(): +func (m *serverRouter) addToServerNetwork(route *route.Route) error { + if m.ctx.Err() != nil { log.Infof("Not adding to server network because context is done") return m.ctx.Err() - default: - m.mux.Lock() - defer m.mux.Unlock() - - routerPair, err := routeToRouterPair(route) - if err != nil { - return fmt.Errorf("parse prefix: %w", err) - } - - err = m.firewall.AddNatRule(routerPair) - if err != nil { - return fmt.Errorf("insert routing rules: %w", err) - } - - m.routes[route.ID] = route - - state := m.statusRecorder.GetLocalPeerState() - if state.Routes == nil { - state.Routes = map[string]struct{}{} - } - - routeStr := route.Network.String() - if route.IsDynamic() { - routeStr = route.Domains.SafeString() - } - state.Routes[routeStr] = struct{}{} - - m.statusRecorder.UpdateLocalPeerState(state) - - return nil } + + m.mux.Lock() + defer m.mux.Unlock() + + routerPair, err := routeToRouterPair(route) + if err != nil { + return fmt.Errorf("parse prefix: %w", err) + } + + err = m.firewall.AddNatRule(routerPair) + if err != nil { + return fmt.Errorf("insert routing rules: %w", err) + } + + m.routes[route.ID] = route + + state := m.statusRecorder.GetLocalPeerState() + if state.Routes == nil { + state.Routes = map[string]struct{}{} + } + + routeStr := route.Network.String() + if route.IsDynamic() { + routeStr = route.Domains.SafeString() + } + state.Routes[routeStr] = struct{}{} + + m.statusRecorder.UpdateLocalPeerState(state) + + return nil } -func (m *defaultServerRouter) cleanUp() { +func (m *serverRouter) cleanUp() { m.mux.Lock() defer m.mux.Unlock() for _, r := range m.routes { diff --git a/client/proto/daemon.pb.go b/client/proto/daemon.pb.go index f0d3021e9..659277570 100644 --- a/client/proto/daemon.pb.go +++ b/client/proto/daemon.pb.go @@ -1,7 +1,7 @@ // Code generated by protoc-gen-go. DO NOT EDIT. 
// versions: // protoc-gen-go v1.26.0 -// protoc v3.21.9 +// protoc v4.23.4 // source: daemon.proto package proto @@ -122,6 +122,10 @@ type LoginRequest struct { ExtraIFaceBlacklist []string `protobuf:"bytes,17,rep,name=extraIFaceBlacklist,proto3" json:"extraIFaceBlacklist,omitempty"` NetworkMonitor *bool `protobuf:"varint,18,opt,name=networkMonitor,proto3,oneof" json:"networkMonitor,omitempty"` DnsRouteInterval *durationpb.Duration `protobuf:"bytes,19,opt,name=dnsRouteInterval,proto3,oneof" json:"dnsRouteInterval,omitempty"` + DisableClientRoutes *bool `protobuf:"varint,20,opt,name=disable_client_routes,json=disableClientRoutes,proto3,oneof" json:"disable_client_routes,omitempty"` + DisableServerRoutes *bool `protobuf:"varint,21,opt,name=disable_server_routes,json=disableServerRoutes,proto3,oneof" json:"disable_server_routes,omitempty"` + DisableDns *bool `protobuf:"varint,22,opt,name=disable_dns,json=disableDns,proto3,oneof" json:"disable_dns,omitempty"` + DisableFirewall *bool `protobuf:"varint,23,opt,name=disable_firewall,json=disableFirewall,proto3,oneof" json:"disable_firewall,omitempty"` } func (x *LoginRequest) Reset() { @@ -290,6 +294,34 @@ func (x *LoginRequest) GetDnsRouteInterval() *durationpb.Duration { return nil } +func (x *LoginRequest) GetDisableClientRoutes() bool { + if x != nil && x.DisableClientRoutes != nil { + return *x.DisableClientRoutes + } + return false +} + +func (x *LoginRequest) GetDisableServerRoutes() bool { + if x != nil && x.DisableServerRoutes != nil { + return *x.DisableServerRoutes + } + return false +} + +func (x *LoginRequest) GetDisableDns() bool { + if x != nil && x.DisableDns != nil { + return *x.DisableDns + } + return false +} + +func (x *LoginRequest) GetDisableFirewall() bool { + if x != nil && x.DisableFirewall != nil { + return *x.DisableFirewall + } + return false +} + type LoginResponse struct { state protoimpl.MessageState sizeCache protoimpl.SizeCache @@ -2541,7 +2573,7 @@ var file_daemon_proto_rawDesc = []byte{ 0x2f, 0x70, 0x72, 0x6f, 0x74, 0x6f, 0x62, 0x75, 0x66, 0x2f, 0x74, 0x69, 0x6d, 0x65, 0x73, 0x74, 0x61, 0x6d, 0x70, 0x2e, 0x70, 0x72, 0x6f, 0x74, 0x6f, 0x1a, 0x1e, 0x67, 0x6f, 0x6f, 0x67, 0x6c, 0x65, 0x2f, 0x70, 0x72, 0x6f, 0x74, 0x6f, 0x62, 0x75, 0x66, 0x2f, 0x64, 0x75, 0x72, 0x61, 0x74, - 0x69, 0x6f, 0x6e, 0x2e, 0x70, 0x72, 0x6f, 0x74, 0x6f, 0x22, 0xb0, 0x08, 0x0a, 0x0c, 0x4c, 0x6f, + 0x69, 0x6f, 0x6e, 0x2e, 0x70, 0x72, 0x6f, 0x74, 0x6f, 0x22, 0xd1, 0x0a, 0x0a, 0x0c, 0x4c, 0x6f, 0x67, 0x69, 0x6e, 0x52, 0x65, 0x71, 0x75, 0x65, 0x73, 0x74, 0x12, 0x1a, 0x0a, 0x08, 0x73, 0x65, 0x74, 0x75, 0x70, 0x4b, 0x65, 0x79, 0x18, 0x01, 0x20, 0x01, 0x28, 0x09, 0x52, 0x08, 0x73, 0x65, 0x74, 0x75, 0x70, 0x4b, 0x65, 0x79, 0x12, 0x26, 0x0a, 0x0c, 0x70, 0x72, 0x65, 0x53, 0x68, 0x61, @@ -2596,349 +2628,367 @@ var file_daemon_proto_rawDesc = []byte{ 0x74, 0x65, 0x72, 0x76, 0x61, 0x6c, 0x18, 0x13, 0x20, 0x01, 0x28, 0x0b, 0x32, 0x19, 0x2e, 0x67, 0x6f, 0x6f, 0x67, 0x6c, 0x65, 0x2e, 0x70, 0x72, 0x6f, 0x74, 0x6f, 0x62, 0x75, 0x66, 0x2e, 0x44, 0x75, 0x72, 0x61, 0x74, 0x69, 0x6f, 0x6e, 0x48, 0x08, 0x52, 0x10, 0x64, 0x6e, 0x73, 0x52, 0x6f, - 0x75, 0x74, 0x65, 0x49, 0x6e, 0x74, 0x65, 0x72, 0x76, 0x61, 0x6c, 0x88, 0x01, 0x01, 0x42, 0x13, - 0x0a, 0x11, 0x5f, 0x72, 0x6f, 0x73, 0x65, 0x6e, 0x70, 0x61, 0x73, 0x73, 0x45, 0x6e, 0x61, 0x62, - 0x6c, 0x65, 0x64, 0x42, 0x10, 0x0a, 0x0e, 0x5f, 0x69, 0x6e, 0x74, 0x65, 0x72, 0x66, 0x61, 0x63, - 0x65, 0x4e, 0x61, 0x6d, 0x65, 0x42, 0x10, 0x0a, 0x0e, 0x5f, 0x77, 0x69, 0x72, 0x65, 0x67, 0x75, - 0x61, 0x72, 0x64, 0x50, 0x6f, 0x72, 
0x74, 0x42, 0x17, 0x0a, 0x15, 0x5f, 0x6f, 0x70, 0x74, 0x69, - 0x6f, 0x6e, 0x61, 0x6c, 0x50, 0x72, 0x65, 0x53, 0x68, 0x61, 0x72, 0x65, 0x64, 0x4b, 0x65, 0x79, - 0x42, 0x15, 0x0a, 0x13, 0x5f, 0x64, 0x69, 0x73, 0x61, 0x62, 0x6c, 0x65, 0x41, 0x75, 0x74, 0x6f, - 0x43, 0x6f, 0x6e, 0x6e, 0x65, 0x63, 0x74, 0x42, 0x13, 0x0a, 0x11, 0x5f, 0x73, 0x65, 0x72, 0x76, - 0x65, 0x72, 0x53, 0x53, 0x48, 0x41, 0x6c, 0x6c, 0x6f, 0x77, 0x65, 0x64, 0x42, 0x16, 0x0a, 0x14, - 0x5f, 0x72, 0x6f, 0x73, 0x65, 0x6e, 0x70, 0x61, 0x73, 0x73, 0x50, 0x65, 0x72, 0x6d, 0x69, 0x73, - 0x73, 0x69, 0x76, 0x65, 0x42, 0x11, 0x0a, 0x0f, 0x5f, 0x6e, 0x65, 0x74, 0x77, 0x6f, 0x72, 0x6b, - 0x4d, 0x6f, 0x6e, 0x69, 0x74, 0x6f, 0x72, 0x42, 0x13, 0x0a, 0x11, 0x5f, 0x64, 0x6e, 0x73, 0x52, - 0x6f, 0x75, 0x74, 0x65, 0x49, 0x6e, 0x74, 0x65, 0x72, 0x76, 0x61, 0x6c, 0x22, 0xb5, 0x01, 0x0a, - 0x0d, 0x4c, 0x6f, 0x67, 0x69, 0x6e, 0x52, 0x65, 0x73, 0x70, 0x6f, 0x6e, 0x73, 0x65, 0x12, 0x24, - 0x0a, 0x0d, 0x6e, 0x65, 0x65, 0x64, 0x73, 0x53, 0x53, 0x4f, 0x4c, 0x6f, 0x67, 0x69, 0x6e, 0x18, - 0x01, 0x20, 0x01, 0x28, 0x08, 0x52, 0x0d, 0x6e, 0x65, 0x65, 0x64, 0x73, 0x53, 0x53, 0x4f, 0x4c, - 0x6f, 0x67, 0x69, 0x6e, 0x12, 0x1a, 0x0a, 0x08, 0x75, 0x73, 0x65, 0x72, 0x43, 0x6f, 0x64, 0x65, - 0x18, 0x02, 0x20, 0x01, 0x28, 0x09, 0x52, 0x08, 0x75, 0x73, 0x65, 0x72, 0x43, 0x6f, 0x64, 0x65, - 0x12, 0x28, 0x0a, 0x0f, 0x76, 0x65, 0x72, 0x69, 0x66, 0x69, 0x63, 0x61, 0x74, 0x69, 0x6f, 0x6e, - 0x55, 0x52, 0x49, 0x18, 0x03, 0x20, 0x01, 0x28, 0x09, 0x52, 0x0f, 0x76, 0x65, 0x72, 0x69, 0x66, - 0x69, 0x63, 0x61, 0x74, 0x69, 0x6f, 0x6e, 0x55, 0x52, 0x49, 0x12, 0x38, 0x0a, 0x17, 0x76, 0x65, + 0x75, 0x74, 0x65, 0x49, 0x6e, 0x74, 0x65, 0x72, 0x76, 0x61, 0x6c, 0x88, 0x01, 0x01, 0x12, 0x37, + 0x0a, 0x15, 0x64, 0x69, 0x73, 0x61, 0x62, 0x6c, 0x65, 0x5f, 0x63, 0x6c, 0x69, 0x65, 0x6e, 0x74, + 0x5f, 0x72, 0x6f, 0x75, 0x74, 0x65, 0x73, 0x18, 0x14, 0x20, 0x01, 0x28, 0x08, 0x48, 0x09, 0x52, + 0x13, 0x64, 0x69, 0x73, 0x61, 0x62, 0x6c, 0x65, 0x43, 0x6c, 0x69, 0x65, 0x6e, 0x74, 0x52, 0x6f, + 0x75, 0x74, 0x65, 0x73, 0x88, 0x01, 0x01, 0x12, 0x37, 0x0a, 0x15, 0x64, 0x69, 0x73, 0x61, 0x62, + 0x6c, 0x65, 0x5f, 0x73, 0x65, 0x72, 0x76, 0x65, 0x72, 0x5f, 0x72, 0x6f, 0x75, 0x74, 0x65, 0x73, + 0x18, 0x15, 0x20, 0x01, 0x28, 0x08, 0x48, 0x0a, 0x52, 0x13, 0x64, 0x69, 0x73, 0x61, 0x62, 0x6c, + 0x65, 0x53, 0x65, 0x72, 0x76, 0x65, 0x72, 0x52, 0x6f, 0x75, 0x74, 0x65, 0x73, 0x88, 0x01, 0x01, + 0x12, 0x24, 0x0a, 0x0b, 0x64, 0x69, 0x73, 0x61, 0x62, 0x6c, 0x65, 0x5f, 0x64, 0x6e, 0x73, 0x18, + 0x16, 0x20, 0x01, 0x28, 0x08, 0x48, 0x0b, 0x52, 0x0a, 0x64, 0x69, 0x73, 0x61, 0x62, 0x6c, 0x65, + 0x44, 0x6e, 0x73, 0x88, 0x01, 0x01, 0x12, 0x2e, 0x0a, 0x10, 0x64, 0x69, 0x73, 0x61, 0x62, 0x6c, + 0x65, 0x5f, 0x66, 0x69, 0x72, 0x65, 0x77, 0x61, 0x6c, 0x6c, 0x18, 0x17, 0x20, 0x01, 0x28, 0x08, + 0x48, 0x0c, 0x52, 0x0f, 0x64, 0x69, 0x73, 0x61, 0x62, 0x6c, 0x65, 0x46, 0x69, 0x72, 0x65, 0x77, + 0x61, 0x6c, 0x6c, 0x88, 0x01, 0x01, 0x42, 0x13, 0x0a, 0x11, 0x5f, 0x72, 0x6f, 0x73, 0x65, 0x6e, + 0x70, 0x61, 0x73, 0x73, 0x45, 0x6e, 0x61, 0x62, 0x6c, 0x65, 0x64, 0x42, 0x10, 0x0a, 0x0e, 0x5f, + 0x69, 0x6e, 0x74, 0x65, 0x72, 0x66, 0x61, 0x63, 0x65, 0x4e, 0x61, 0x6d, 0x65, 0x42, 0x10, 0x0a, + 0x0e, 0x5f, 0x77, 0x69, 0x72, 0x65, 0x67, 0x75, 0x61, 0x72, 0x64, 0x50, 0x6f, 0x72, 0x74, 0x42, + 0x17, 0x0a, 0x15, 0x5f, 0x6f, 0x70, 0x74, 0x69, 0x6f, 0x6e, 0x61, 0x6c, 0x50, 0x72, 0x65, 0x53, + 0x68, 0x61, 0x72, 0x65, 0x64, 0x4b, 0x65, 0x79, 0x42, 0x15, 0x0a, 0x13, 0x5f, 0x64, 0x69, 0x73, + 0x61, 0x62, 0x6c, 0x65, 0x41, 0x75, 0x74, 0x6f, 0x43, 0x6f, 
0x6e, 0x6e, 0x65, 0x63, 0x74, 0x42, + 0x13, 0x0a, 0x11, 0x5f, 0x73, 0x65, 0x72, 0x76, 0x65, 0x72, 0x53, 0x53, 0x48, 0x41, 0x6c, 0x6c, + 0x6f, 0x77, 0x65, 0x64, 0x42, 0x16, 0x0a, 0x14, 0x5f, 0x72, 0x6f, 0x73, 0x65, 0x6e, 0x70, 0x61, + 0x73, 0x73, 0x50, 0x65, 0x72, 0x6d, 0x69, 0x73, 0x73, 0x69, 0x76, 0x65, 0x42, 0x11, 0x0a, 0x0f, + 0x5f, 0x6e, 0x65, 0x74, 0x77, 0x6f, 0x72, 0x6b, 0x4d, 0x6f, 0x6e, 0x69, 0x74, 0x6f, 0x72, 0x42, + 0x13, 0x0a, 0x11, 0x5f, 0x64, 0x6e, 0x73, 0x52, 0x6f, 0x75, 0x74, 0x65, 0x49, 0x6e, 0x74, 0x65, + 0x72, 0x76, 0x61, 0x6c, 0x42, 0x18, 0x0a, 0x16, 0x5f, 0x64, 0x69, 0x73, 0x61, 0x62, 0x6c, 0x65, + 0x5f, 0x63, 0x6c, 0x69, 0x65, 0x6e, 0x74, 0x5f, 0x72, 0x6f, 0x75, 0x74, 0x65, 0x73, 0x42, 0x18, + 0x0a, 0x16, 0x5f, 0x64, 0x69, 0x73, 0x61, 0x62, 0x6c, 0x65, 0x5f, 0x73, 0x65, 0x72, 0x76, 0x65, + 0x72, 0x5f, 0x72, 0x6f, 0x75, 0x74, 0x65, 0x73, 0x42, 0x0e, 0x0a, 0x0c, 0x5f, 0x64, 0x69, 0x73, + 0x61, 0x62, 0x6c, 0x65, 0x5f, 0x64, 0x6e, 0x73, 0x42, 0x13, 0x0a, 0x11, 0x5f, 0x64, 0x69, 0x73, + 0x61, 0x62, 0x6c, 0x65, 0x5f, 0x66, 0x69, 0x72, 0x65, 0x77, 0x61, 0x6c, 0x6c, 0x22, 0xb5, 0x01, + 0x0a, 0x0d, 0x4c, 0x6f, 0x67, 0x69, 0x6e, 0x52, 0x65, 0x73, 0x70, 0x6f, 0x6e, 0x73, 0x65, 0x12, + 0x24, 0x0a, 0x0d, 0x6e, 0x65, 0x65, 0x64, 0x73, 0x53, 0x53, 0x4f, 0x4c, 0x6f, 0x67, 0x69, 0x6e, + 0x18, 0x01, 0x20, 0x01, 0x28, 0x08, 0x52, 0x0d, 0x6e, 0x65, 0x65, 0x64, 0x73, 0x53, 0x53, 0x4f, + 0x4c, 0x6f, 0x67, 0x69, 0x6e, 0x12, 0x1a, 0x0a, 0x08, 0x75, 0x73, 0x65, 0x72, 0x43, 0x6f, 0x64, + 0x65, 0x18, 0x02, 0x20, 0x01, 0x28, 0x09, 0x52, 0x08, 0x75, 0x73, 0x65, 0x72, 0x43, 0x6f, 0x64, + 0x65, 0x12, 0x28, 0x0a, 0x0f, 0x76, 0x65, 0x72, 0x69, 0x66, 0x69, 0x63, 0x61, 0x74, 0x69, 0x6f, + 0x6e, 0x55, 0x52, 0x49, 0x18, 0x03, 0x20, 0x01, 0x28, 0x09, 0x52, 0x0f, 0x76, 0x65, 0x72, 0x69, + 0x66, 0x69, 0x63, 0x61, 0x74, 0x69, 0x6f, 0x6e, 0x55, 0x52, 0x49, 0x12, 0x38, 0x0a, 0x17, 0x76, + 0x65, 0x72, 0x69, 0x66, 0x69, 0x63, 0x61, 0x74, 0x69, 0x6f, 0x6e, 0x55, 0x52, 0x49, 0x43, 0x6f, + 0x6d, 0x70, 0x6c, 0x65, 0x74, 0x65, 0x18, 0x04, 0x20, 0x01, 0x28, 0x09, 0x52, 0x17, 0x76, 0x65, 0x72, 0x69, 0x66, 0x69, 0x63, 0x61, 0x74, 0x69, 0x6f, 0x6e, 0x55, 0x52, 0x49, 0x43, 0x6f, 0x6d, - 0x70, 0x6c, 0x65, 0x74, 0x65, 0x18, 0x04, 0x20, 0x01, 0x28, 0x09, 0x52, 0x17, 0x76, 0x65, 0x72, - 0x69, 0x66, 0x69, 0x63, 0x61, 0x74, 0x69, 0x6f, 0x6e, 0x55, 0x52, 0x49, 0x43, 0x6f, 0x6d, 0x70, - 0x6c, 0x65, 0x74, 0x65, 0x22, 0x4d, 0x0a, 0x13, 0x57, 0x61, 0x69, 0x74, 0x53, 0x53, 0x4f, 0x4c, - 0x6f, 0x67, 0x69, 0x6e, 0x52, 0x65, 0x71, 0x75, 0x65, 0x73, 0x74, 0x12, 0x1a, 0x0a, 0x08, 0x75, - 0x73, 0x65, 0x72, 0x43, 0x6f, 0x64, 0x65, 0x18, 0x01, 0x20, 0x01, 0x28, 0x09, 0x52, 0x08, 0x75, - 0x73, 0x65, 0x72, 0x43, 0x6f, 0x64, 0x65, 0x12, 0x1a, 0x0a, 0x08, 0x68, 0x6f, 0x73, 0x74, 0x6e, - 0x61, 0x6d, 0x65, 0x18, 0x02, 0x20, 0x01, 0x28, 0x09, 0x52, 0x08, 0x68, 0x6f, 0x73, 0x74, 0x6e, - 0x61, 0x6d, 0x65, 0x22, 0x16, 0x0a, 0x14, 0x57, 0x61, 0x69, 0x74, 0x53, 0x53, 0x4f, 0x4c, 0x6f, - 0x67, 0x69, 0x6e, 0x52, 0x65, 0x73, 0x70, 0x6f, 0x6e, 0x73, 0x65, 0x22, 0x0b, 0x0a, 0x09, 0x55, - 0x70, 0x52, 0x65, 0x71, 0x75, 0x65, 0x73, 0x74, 0x22, 0x0c, 0x0a, 0x0a, 0x55, 0x70, 0x52, 0x65, - 0x73, 0x70, 0x6f, 0x6e, 0x73, 0x65, 0x22, 0x3d, 0x0a, 0x0d, 0x53, 0x74, 0x61, 0x74, 0x75, 0x73, - 0x52, 0x65, 0x71, 0x75, 0x65, 0x73, 0x74, 0x12, 0x2c, 0x0a, 0x11, 0x67, 0x65, 0x74, 0x46, 0x75, - 0x6c, 0x6c, 0x50, 0x65, 0x65, 0x72, 0x53, 0x74, 0x61, 0x74, 0x75, 0x73, 0x18, 0x01, 0x20, 0x01, - 0x28, 0x08, 0x52, 0x11, 0x67, 0x65, 0x74, 0x46, 0x75, 0x6c, 0x6c, 0x50, 0x65, 0x65, 
0x72, 0x53, - 0x74, 0x61, 0x74, 0x75, 0x73, 0x22, 0x82, 0x01, 0x0a, 0x0e, 0x53, 0x74, 0x61, 0x74, 0x75, 0x73, - 0x52, 0x65, 0x73, 0x70, 0x6f, 0x6e, 0x73, 0x65, 0x12, 0x16, 0x0a, 0x06, 0x73, 0x74, 0x61, 0x74, - 0x75, 0x73, 0x18, 0x01, 0x20, 0x01, 0x28, 0x09, 0x52, 0x06, 0x73, 0x74, 0x61, 0x74, 0x75, 0x73, - 0x12, 0x32, 0x0a, 0x0a, 0x66, 0x75, 0x6c, 0x6c, 0x53, 0x74, 0x61, 0x74, 0x75, 0x73, 0x18, 0x02, - 0x20, 0x01, 0x28, 0x0b, 0x32, 0x12, 0x2e, 0x64, 0x61, 0x65, 0x6d, 0x6f, 0x6e, 0x2e, 0x46, 0x75, - 0x6c, 0x6c, 0x53, 0x74, 0x61, 0x74, 0x75, 0x73, 0x52, 0x0a, 0x66, 0x75, 0x6c, 0x6c, 0x53, 0x74, - 0x61, 0x74, 0x75, 0x73, 0x12, 0x24, 0x0a, 0x0d, 0x64, 0x61, 0x65, 0x6d, 0x6f, 0x6e, 0x56, 0x65, - 0x72, 0x73, 0x69, 0x6f, 0x6e, 0x18, 0x03, 0x20, 0x01, 0x28, 0x09, 0x52, 0x0d, 0x64, 0x61, 0x65, - 0x6d, 0x6f, 0x6e, 0x56, 0x65, 0x72, 0x73, 0x69, 0x6f, 0x6e, 0x22, 0x0d, 0x0a, 0x0b, 0x44, 0x6f, - 0x77, 0x6e, 0x52, 0x65, 0x71, 0x75, 0x65, 0x73, 0x74, 0x22, 0x0e, 0x0a, 0x0c, 0x44, 0x6f, 0x77, - 0x6e, 0x52, 0x65, 0x73, 0x70, 0x6f, 0x6e, 0x73, 0x65, 0x22, 0x12, 0x0a, 0x10, 0x47, 0x65, 0x74, - 0x43, 0x6f, 0x6e, 0x66, 0x69, 0x67, 0x52, 0x65, 0x71, 0x75, 0x65, 0x73, 0x74, 0x22, 0xb9, 0x03, - 0x0a, 0x11, 0x47, 0x65, 0x74, 0x43, 0x6f, 0x6e, 0x66, 0x69, 0x67, 0x52, 0x65, 0x73, 0x70, 0x6f, - 0x6e, 0x73, 0x65, 0x12, 0x24, 0x0a, 0x0d, 0x6d, 0x61, 0x6e, 0x61, 0x67, 0x65, 0x6d, 0x65, 0x6e, - 0x74, 0x55, 0x72, 0x6c, 0x18, 0x01, 0x20, 0x01, 0x28, 0x09, 0x52, 0x0d, 0x6d, 0x61, 0x6e, 0x61, - 0x67, 0x65, 0x6d, 0x65, 0x6e, 0x74, 0x55, 0x72, 0x6c, 0x12, 0x1e, 0x0a, 0x0a, 0x63, 0x6f, 0x6e, - 0x66, 0x69, 0x67, 0x46, 0x69, 0x6c, 0x65, 0x18, 0x02, 0x20, 0x01, 0x28, 0x09, 0x52, 0x0a, 0x63, - 0x6f, 0x6e, 0x66, 0x69, 0x67, 0x46, 0x69, 0x6c, 0x65, 0x12, 0x18, 0x0a, 0x07, 0x6c, 0x6f, 0x67, - 0x46, 0x69, 0x6c, 0x65, 0x18, 0x03, 0x20, 0x01, 0x28, 0x09, 0x52, 0x07, 0x6c, 0x6f, 0x67, 0x46, - 0x69, 0x6c, 0x65, 0x12, 0x22, 0x0a, 0x0c, 0x70, 0x72, 0x65, 0x53, 0x68, 0x61, 0x72, 0x65, 0x64, - 0x4b, 0x65, 0x79, 0x18, 0x04, 0x20, 0x01, 0x28, 0x09, 0x52, 0x0c, 0x70, 0x72, 0x65, 0x53, 0x68, - 0x61, 0x72, 0x65, 0x64, 0x4b, 0x65, 0x79, 0x12, 0x1a, 0x0a, 0x08, 0x61, 0x64, 0x6d, 0x69, 0x6e, - 0x55, 0x52, 0x4c, 0x18, 0x05, 0x20, 0x01, 0x28, 0x09, 0x52, 0x08, 0x61, 0x64, 0x6d, 0x69, 0x6e, - 0x55, 0x52, 0x4c, 0x12, 0x24, 0x0a, 0x0d, 0x69, 0x6e, 0x74, 0x65, 0x72, 0x66, 0x61, 0x63, 0x65, - 0x4e, 0x61, 0x6d, 0x65, 0x18, 0x06, 0x20, 0x01, 0x28, 0x09, 0x52, 0x0d, 0x69, 0x6e, 0x74, 0x65, - 0x72, 0x66, 0x61, 0x63, 0x65, 0x4e, 0x61, 0x6d, 0x65, 0x12, 0x24, 0x0a, 0x0d, 0x77, 0x69, 0x72, - 0x65, 0x67, 0x75, 0x61, 0x72, 0x64, 0x50, 0x6f, 0x72, 0x74, 0x18, 0x07, 0x20, 0x01, 0x28, 0x03, - 0x52, 0x0d, 0x77, 0x69, 0x72, 0x65, 0x67, 0x75, 0x61, 0x72, 0x64, 0x50, 0x6f, 0x72, 0x74, 0x12, - 0x2e, 0x0a, 0x12, 0x64, 0x69, 0x73, 0x61, 0x62, 0x6c, 0x65, 0x41, 0x75, 0x74, 0x6f, 0x43, 0x6f, - 0x6e, 0x6e, 0x65, 0x63, 0x74, 0x18, 0x09, 0x20, 0x01, 0x28, 0x08, 0x52, 0x12, 0x64, 0x69, 0x73, - 0x61, 0x62, 0x6c, 0x65, 0x41, 0x75, 0x74, 0x6f, 0x43, 0x6f, 0x6e, 0x6e, 0x65, 0x63, 0x74, 0x12, - 0x2a, 0x0a, 0x10, 0x73, 0x65, 0x72, 0x76, 0x65, 0x72, 0x53, 0x53, 0x48, 0x41, 0x6c, 0x6c, 0x6f, - 0x77, 0x65, 0x64, 0x18, 0x0a, 0x20, 0x01, 0x28, 0x08, 0x52, 0x10, 0x73, 0x65, 0x72, 0x76, 0x65, - 0x72, 0x53, 0x53, 0x48, 0x41, 0x6c, 0x6c, 0x6f, 0x77, 0x65, 0x64, 0x12, 0x2a, 0x0a, 0x10, 0x72, - 0x6f, 0x73, 0x65, 0x6e, 0x70, 0x61, 0x73, 0x73, 0x45, 0x6e, 0x61, 0x62, 0x6c, 0x65, 0x64, 0x18, - 0x0b, 0x20, 0x01, 0x28, 0x08, 0x52, 0x10, 0x72, 0x6f, 0x73, 0x65, 0x6e, 0x70, 0x61, 0x73, 0x73, - 0x45, 0x6e, 
0x61, 0x62, 0x6c, 0x65, 0x64, 0x12, 0x30, 0x0a, 0x13, 0x72, 0x6f, 0x73, 0x65, 0x6e, - 0x70, 0x61, 0x73, 0x73, 0x50, 0x65, 0x72, 0x6d, 0x69, 0x73, 0x73, 0x69, 0x76, 0x65, 0x18, 0x0c, - 0x20, 0x01, 0x28, 0x08, 0x52, 0x13, 0x72, 0x6f, 0x73, 0x65, 0x6e, 0x70, 0x61, 0x73, 0x73, 0x50, - 0x65, 0x72, 0x6d, 0x69, 0x73, 0x73, 0x69, 0x76, 0x65, 0x22, 0xde, 0x05, 0x0a, 0x09, 0x50, 0x65, - 0x65, 0x72, 0x53, 0x74, 0x61, 0x74, 0x65, 0x12, 0x0e, 0x0a, 0x02, 0x49, 0x50, 0x18, 0x01, 0x20, - 0x01, 0x28, 0x09, 0x52, 0x02, 0x49, 0x50, 0x12, 0x16, 0x0a, 0x06, 0x70, 0x75, 0x62, 0x4b, 0x65, - 0x79, 0x18, 0x02, 0x20, 0x01, 0x28, 0x09, 0x52, 0x06, 0x70, 0x75, 0x62, 0x4b, 0x65, 0x79, 0x12, - 0x1e, 0x0a, 0x0a, 0x63, 0x6f, 0x6e, 0x6e, 0x53, 0x74, 0x61, 0x74, 0x75, 0x73, 0x18, 0x03, 0x20, - 0x01, 0x28, 0x09, 0x52, 0x0a, 0x63, 0x6f, 0x6e, 0x6e, 0x53, 0x74, 0x61, 0x74, 0x75, 0x73, 0x12, - 0x46, 0x0a, 0x10, 0x63, 0x6f, 0x6e, 0x6e, 0x53, 0x74, 0x61, 0x74, 0x75, 0x73, 0x55, 0x70, 0x64, - 0x61, 0x74, 0x65, 0x18, 0x04, 0x20, 0x01, 0x28, 0x0b, 0x32, 0x1a, 0x2e, 0x67, 0x6f, 0x6f, 0x67, - 0x6c, 0x65, 0x2e, 0x70, 0x72, 0x6f, 0x74, 0x6f, 0x62, 0x75, 0x66, 0x2e, 0x54, 0x69, 0x6d, 0x65, - 0x73, 0x74, 0x61, 0x6d, 0x70, 0x52, 0x10, 0x63, 0x6f, 0x6e, 0x6e, 0x53, 0x74, 0x61, 0x74, 0x75, - 0x73, 0x55, 0x70, 0x64, 0x61, 0x74, 0x65, 0x12, 0x18, 0x0a, 0x07, 0x72, 0x65, 0x6c, 0x61, 0x79, - 0x65, 0x64, 0x18, 0x05, 0x20, 0x01, 0x28, 0x08, 0x52, 0x07, 0x72, 0x65, 0x6c, 0x61, 0x79, 0x65, - 0x64, 0x12, 0x34, 0x0a, 0x15, 0x6c, 0x6f, 0x63, 0x61, 0x6c, 0x49, 0x63, 0x65, 0x43, 0x61, 0x6e, - 0x64, 0x69, 0x64, 0x61, 0x74, 0x65, 0x54, 0x79, 0x70, 0x65, 0x18, 0x07, 0x20, 0x01, 0x28, 0x09, - 0x52, 0x15, 0x6c, 0x6f, 0x63, 0x61, 0x6c, 0x49, 0x63, 0x65, 0x43, 0x61, 0x6e, 0x64, 0x69, 0x64, - 0x61, 0x74, 0x65, 0x54, 0x79, 0x70, 0x65, 0x12, 0x36, 0x0a, 0x16, 0x72, 0x65, 0x6d, 0x6f, 0x74, - 0x65, 0x49, 0x63, 0x65, 0x43, 0x61, 0x6e, 0x64, 0x69, 0x64, 0x61, 0x74, 0x65, 0x54, 0x79, 0x70, - 0x65, 0x18, 0x08, 0x20, 0x01, 0x28, 0x09, 0x52, 0x16, 0x72, 0x65, 0x6d, 0x6f, 0x74, 0x65, 0x49, - 0x63, 0x65, 0x43, 0x61, 0x6e, 0x64, 0x69, 0x64, 0x61, 0x74, 0x65, 0x54, 0x79, 0x70, 0x65, 0x12, - 0x12, 0x0a, 0x04, 0x66, 0x71, 0x64, 0x6e, 0x18, 0x09, 0x20, 0x01, 0x28, 0x09, 0x52, 0x04, 0x66, - 0x71, 0x64, 0x6e, 0x12, 0x3c, 0x0a, 0x19, 0x6c, 0x6f, 0x63, 0x61, 0x6c, 0x49, 0x63, 0x65, 0x43, + 0x70, 0x6c, 0x65, 0x74, 0x65, 0x22, 0x4d, 0x0a, 0x13, 0x57, 0x61, 0x69, 0x74, 0x53, 0x53, 0x4f, + 0x4c, 0x6f, 0x67, 0x69, 0x6e, 0x52, 0x65, 0x71, 0x75, 0x65, 0x73, 0x74, 0x12, 0x1a, 0x0a, 0x08, + 0x75, 0x73, 0x65, 0x72, 0x43, 0x6f, 0x64, 0x65, 0x18, 0x01, 0x20, 0x01, 0x28, 0x09, 0x52, 0x08, + 0x75, 0x73, 0x65, 0x72, 0x43, 0x6f, 0x64, 0x65, 0x12, 0x1a, 0x0a, 0x08, 0x68, 0x6f, 0x73, 0x74, + 0x6e, 0x61, 0x6d, 0x65, 0x18, 0x02, 0x20, 0x01, 0x28, 0x09, 0x52, 0x08, 0x68, 0x6f, 0x73, 0x74, + 0x6e, 0x61, 0x6d, 0x65, 0x22, 0x16, 0x0a, 0x14, 0x57, 0x61, 0x69, 0x74, 0x53, 0x53, 0x4f, 0x4c, + 0x6f, 0x67, 0x69, 0x6e, 0x52, 0x65, 0x73, 0x70, 0x6f, 0x6e, 0x73, 0x65, 0x22, 0x0b, 0x0a, 0x09, + 0x55, 0x70, 0x52, 0x65, 0x71, 0x75, 0x65, 0x73, 0x74, 0x22, 0x0c, 0x0a, 0x0a, 0x55, 0x70, 0x52, + 0x65, 0x73, 0x70, 0x6f, 0x6e, 0x73, 0x65, 0x22, 0x3d, 0x0a, 0x0d, 0x53, 0x74, 0x61, 0x74, 0x75, + 0x73, 0x52, 0x65, 0x71, 0x75, 0x65, 0x73, 0x74, 0x12, 0x2c, 0x0a, 0x11, 0x67, 0x65, 0x74, 0x46, + 0x75, 0x6c, 0x6c, 0x50, 0x65, 0x65, 0x72, 0x53, 0x74, 0x61, 0x74, 0x75, 0x73, 0x18, 0x01, 0x20, + 0x01, 0x28, 0x08, 0x52, 0x11, 0x67, 0x65, 0x74, 0x46, 0x75, 0x6c, 0x6c, 0x50, 0x65, 0x65, 0x72, + 0x53, 0x74, 0x61, 0x74, 0x75, 0x73, 
0x22, 0x82, 0x01, 0x0a, 0x0e, 0x53, 0x74, 0x61, 0x74, 0x75, + 0x73, 0x52, 0x65, 0x73, 0x70, 0x6f, 0x6e, 0x73, 0x65, 0x12, 0x16, 0x0a, 0x06, 0x73, 0x74, 0x61, + 0x74, 0x75, 0x73, 0x18, 0x01, 0x20, 0x01, 0x28, 0x09, 0x52, 0x06, 0x73, 0x74, 0x61, 0x74, 0x75, + 0x73, 0x12, 0x32, 0x0a, 0x0a, 0x66, 0x75, 0x6c, 0x6c, 0x53, 0x74, 0x61, 0x74, 0x75, 0x73, 0x18, + 0x02, 0x20, 0x01, 0x28, 0x0b, 0x32, 0x12, 0x2e, 0x64, 0x61, 0x65, 0x6d, 0x6f, 0x6e, 0x2e, 0x46, + 0x75, 0x6c, 0x6c, 0x53, 0x74, 0x61, 0x74, 0x75, 0x73, 0x52, 0x0a, 0x66, 0x75, 0x6c, 0x6c, 0x53, + 0x74, 0x61, 0x74, 0x75, 0x73, 0x12, 0x24, 0x0a, 0x0d, 0x64, 0x61, 0x65, 0x6d, 0x6f, 0x6e, 0x56, + 0x65, 0x72, 0x73, 0x69, 0x6f, 0x6e, 0x18, 0x03, 0x20, 0x01, 0x28, 0x09, 0x52, 0x0d, 0x64, 0x61, + 0x65, 0x6d, 0x6f, 0x6e, 0x56, 0x65, 0x72, 0x73, 0x69, 0x6f, 0x6e, 0x22, 0x0d, 0x0a, 0x0b, 0x44, + 0x6f, 0x77, 0x6e, 0x52, 0x65, 0x71, 0x75, 0x65, 0x73, 0x74, 0x22, 0x0e, 0x0a, 0x0c, 0x44, 0x6f, + 0x77, 0x6e, 0x52, 0x65, 0x73, 0x70, 0x6f, 0x6e, 0x73, 0x65, 0x22, 0x12, 0x0a, 0x10, 0x47, 0x65, + 0x74, 0x43, 0x6f, 0x6e, 0x66, 0x69, 0x67, 0x52, 0x65, 0x71, 0x75, 0x65, 0x73, 0x74, 0x22, 0xb9, + 0x03, 0x0a, 0x11, 0x47, 0x65, 0x74, 0x43, 0x6f, 0x6e, 0x66, 0x69, 0x67, 0x52, 0x65, 0x73, 0x70, + 0x6f, 0x6e, 0x73, 0x65, 0x12, 0x24, 0x0a, 0x0d, 0x6d, 0x61, 0x6e, 0x61, 0x67, 0x65, 0x6d, 0x65, + 0x6e, 0x74, 0x55, 0x72, 0x6c, 0x18, 0x01, 0x20, 0x01, 0x28, 0x09, 0x52, 0x0d, 0x6d, 0x61, 0x6e, + 0x61, 0x67, 0x65, 0x6d, 0x65, 0x6e, 0x74, 0x55, 0x72, 0x6c, 0x12, 0x1e, 0x0a, 0x0a, 0x63, 0x6f, + 0x6e, 0x66, 0x69, 0x67, 0x46, 0x69, 0x6c, 0x65, 0x18, 0x02, 0x20, 0x01, 0x28, 0x09, 0x52, 0x0a, + 0x63, 0x6f, 0x6e, 0x66, 0x69, 0x67, 0x46, 0x69, 0x6c, 0x65, 0x12, 0x18, 0x0a, 0x07, 0x6c, 0x6f, + 0x67, 0x46, 0x69, 0x6c, 0x65, 0x18, 0x03, 0x20, 0x01, 0x28, 0x09, 0x52, 0x07, 0x6c, 0x6f, 0x67, + 0x46, 0x69, 0x6c, 0x65, 0x12, 0x22, 0x0a, 0x0c, 0x70, 0x72, 0x65, 0x53, 0x68, 0x61, 0x72, 0x65, + 0x64, 0x4b, 0x65, 0x79, 0x18, 0x04, 0x20, 0x01, 0x28, 0x09, 0x52, 0x0c, 0x70, 0x72, 0x65, 0x53, + 0x68, 0x61, 0x72, 0x65, 0x64, 0x4b, 0x65, 0x79, 0x12, 0x1a, 0x0a, 0x08, 0x61, 0x64, 0x6d, 0x69, + 0x6e, 0x55, 0x52, 0x4c, 0x18, 0x05, 0x20, 0x01, 0x28, 0x09, 0x52, 0x08, 0x61, 0x64, 0x6d, 0x69, + 0x6e, 0x55, 0x52, 0x4c, 0x12, 0x24, 0x0a, 0x0d, 0x69, 0x6e, 0x74, 0x65, 0x72, 0x66, 0x61, 0x63, + 0x65, 0x4e, 0x61, 0x6d, 0x65, 0x18, 0x06, 0x20, 0x01, 0x28, 0x09, 0x52, 0x0d, 0x69, 0x6e, 0x74, + 0x65, 0x72, 0x66, 0x61, 0x63, 0x65, 0x4e, 0x61, 0x6d, 0x65, 0x12, 0x24, 0x0a, 0x0d, 0x77, 0x69, + 0x72, 0x65, 0x67, 0x75, 0x61, 0x72, 0x64, 0x50, 0x6f, 0x72, 0x74, 0x18, 0x07, 0x20, 0x01, 0x28, + 0x03, 0x52, 0x0d, 0x77, 0x69, 0x72, 0x65, 0x67, 0x75, 0x61, 0x72, 0x64, 0x50, 0x6f, 0x72, 0x74, + 0x12, 0x2e, 0x0a, 0x12, 0x64, 0x69, 0x73, 0x61, 0x62, 0x6c, 0x65, 0x41, 0x75, 0x74, 0x6f, 0x43, + 0x6f, 0x6e, 0x6e, 0x65, 0x63, 0x74, 0x18, 0x09, 0x20, 0x01, 0x28, 0x08, 0x52, 0x12, 0x64, 0x69, + 0x73, 0x61, 0x62, 0x6c, 0x65, 0x41, 0x75, 0x74, 0x6f, 0x43, 0x6f, 0x6e, 0x6e, 0x65, 0x63, 0x74, + 0x12, 0x2a, 0x0a, 0x10, 0x73, 0x65, 0x72, 0x76, 0x65, 0x72, 0x53, 0x53, 0x48, 0x41, 0x6c, 0x6c, + 0x6f, 0x77, 0x65, 0x64, 0x18, 0x0a, 0x20, 0x01, 0x28, 0x08, 0x52, 0x10, 0x73, 0x65, 0x72, 0x76, + 0x65, 0x72, 0x53, 0x53, 0x48, 0x41, 0x6c, 0x6c, 0x6f, 0x77, 0x65, 0x64, 0x12, 0x2a, 0x0a, 0x10, + 0x72, 0x6f, 0x73, 0x65, 0x6e, 0x70, 0x61, 0x73, 0x73, 0x45, 0x6e, 0x61, 0x62, 0x6c, 0x65, 0x64, + 0x18, 0x0b, 0x20, 0x01, 0x28, 0x08, 0x52, 0x10, 0x72, 0x6f, 0x73, 0x65, 0x6e, 0x70, 0x61, 0x73, + 0x73, 0x45, 0x6e, 0x61, 0x62, 0x6c, 0x65, 0x64, 0x12, 0x30, 
0x0a, 0x13, 0x72, 0x6f, 0x73, 0x65, + 0x6e, 0x70, 0x61, 0x73, 0x73, 0x50, 0x65, 0x72, 0x6d, 0x69, 0x73, 0x73, 0x69, 0x76, 0x65, 0x18, + 0x0c, 0x20, 0x01, 0x28, 0x08, 0x52, 0x13, 0x72, 0x6f, 0x73, 0x65, 0x6e, 0x70, 0x61, 0x73, 0x73, + 0x50, 0x65, 0x72, 0x6d, 0x69, 0x73, 0x73, 0x69, 0x76, 0x65, 0x22, 0xde, 0x05, 0x0a, 0x09, 0x50, + 0x65, 0x65, 0x72, 0x53, 0x74, 0x61, 0x74, 0x65, 0x12, 0x0e, 0x0a, 0x02, 0x49, 0x50, 0x18, 0x01, + 0x20, 0x01, 0x28, 0x09, 0x52, 0x02, 0x49, 0x50, 0x12, 0x16, 0x0a, 0x06, 0x70, 0x75, 0x62, 0x4b, + 0x65, 0x79, 0x18, 0x02, 0x20, 0x01, 0x28, 0x09, 0x52, 0x06, 0x70, 0x75, 0x62, 0x4b, 0x65, 0x79, + 0x12, 0x1e, 0x0a, 0x0a, 0x63, 0x6f, 0x6e, 0x6e, 0x53, 0x74, 0x61, 0x74, 0x75, 0x73, 0x18, 0x03, + 0x20, 0x01, 0x28, 0x09, 0x52, 0x0a, 0x63, 0x6f, 0x6e, 0x6e, 0x53, 0x74, 0x61, 0x74, 0x75, 0x73, + 0x12, 0x46, 0x0a, 0x10, 0x63, 0x6f, 0x6e, 0x6e, 0x53, 0x74, 0x61, 0x74, 0x75, 0x73, 0x55, 0x70, + 0x64, 0x61, 0x74, 0x65, 0x18, 0x04, 0x20, 0x01, 0x28, 0x0b, 0x32, 0x1a, 0x2e, 0x67, 0x6f, 0x6f, + 0x67, 0x6c, 0x65, 0x2e, 0x70, 0x72, 0x6f, 0x74, 0x6f, 0x62, 0x75, 0x66, 0x2e, 0x54, 0x69, 0x6d, + 0x65, 0x73, 0x74, 0x61, 0x6d, 0x70, 0x52, 0x10, 0x63, 0x6f, 0x6e, 0x6e, 0x53, 0x74, 0x61, 0x74, + 0x75, 0x73, 0x55, 0x70, 0x64, 0x61, 0x74, 0x65, 0x12, 0x18, 0x0a, 0x07, 0x72, 0x65, 0x6c, 0x61, + 0x79, 0x65, 0x64, 0x18, 0x05, 0x20, 0x01, 0x28, 0x08, 0x52, 0x07, 0x72, 0x65, 0x6c, 0x61, 0x79, + 0x65, 0x64, 0x12, 0x34, 0x0a, 0x15, 0x6c, 0x6f, 0x63, 0x61, 0x6c, 0x49, 0x63, 0x65, 0x43, 0x61, + 0x6e, 0x64, 0x69, 0x64, 0x61, 0x74, 0x65, 0x54, 0x79, 0x70, 0x65, 0x18, 0x07, 0x20, 0x01, 0x28, + 0x09, 0x52, 0x15, 0x6c, 0x6f, 0x63, 0x61, 0x6c, 0x49, 0x63, 0x65, 0x43, 0x61, 0x6e, 0x64, 0x69, + 0x64, 0x61, 0x74, 0x65, 0x54, 0x79, 0x70, 0x65, 0x12, 0x36, 0x0a, 0x16, 0x72, 0x65, 0x6d, 0x6f, + 0x74, 0x65, 0x49, 0x63, 0x65, 0x43, 0x61, 0x6e, 0x64, 0x69, 0x64, 0x61, 0x74, 0x65, 0x54, 0x79, + 0x70, 0x65, 0x18, 0x08, 0x20, 0x01, 0x28, 0x09, 0x52, 0x16, 0x72, 0x65, 0x6d, 0x6f, 0x74, 0x65, + 0x49, 0x63, 0x65, 0x43, 0x61, 0x6e, 0x64, 0x69, 0x64, 0x61, 0x74, 0x65, 0x54, 0x79, 0x70, 0x65, + 0x12, 0x12, 0x0a, 0x04, 0x66, 0x71, 0x64, 0x6e, 0x18, 0x09, 0x20, 0x01, 0x28, 0x09, 0x52, 0x04, + 0x66, 0x71, 0x64, 0x6e, 0x12, 0x3c, 0x0a, 0x19, 0x6c, 0x6f, 0x63, 0x61, 0x6c, 0x49, 0x63, 0x65, + 0x43, 0x61, 0x6e, 0x64, 0x69, 0x64, 0x61, 0x74, 0x65, 0x45, 0x6e, 0x64, 0x70, 0x6f, 0x69, 0x6e, + 0x74, 0x18, 0x0a, 0x20, 0x01, 0x28, 0x09, 0x52, 0x19, 0x6c, 0x6f, 0x63, 0x61, 0x6c, 0x49, 0x63, + 0x65, 0x43, 0x61, 0x6e, 0x64, 0x69, 0x64, 0x61, 0x74, 0x65, 0x45, 0x6e, 0x64, 0x70, 0x6f, 0x69, + 0x6e, 0x74, 0x12, 0x3e, 0x0a, 0x1a, 0x72, 0x65, 0x6d, 0x6f, 0x74, 0x65, 0x49, 0x63, 0x65, 0x43, 0x61, 0x6e, 0x64, 0x69, 0x64, 0x61, 0x74, 0x65, 0x45, 0x6e, 0x64, 0x70, 0x6f, 0x69, 0x6e, 0x74, - 0x18, 0x0a, 0x20, 0x01, 0x28, 0x09, 0x52, 0x19, 0x6c, 0x6f, 0x63, 0x61, 0x6c, 0x49, 0x63, 0x65, - 0x43, 0x61, 0x6e, 0x64, 0x69, 0x64, 0x61, 0x74, 0x65, 0x45, 0x6e, 0x64, 0x70, 0x6f, 0x69, 0x6e, - 0x74, 0x12, 0x3e, 0x0a, 0x1a, 0x72, 0x65, 0x6d, 0x6f, 0x74, 0x65, 0x49, 0x63, 0x65, 0x43, 0x61, - 0x6e, 0x64, 0x69, 0x64, 0x61, 0x74, 0x65, 0x45, 0x6e, 0x64, 0x70, 0x6f, 0x69, 0x6e, 0x74, 0x18, - 0x0b, 0x20, 0x01, 0x28, 0x09, 0x52, 0x1a, 0x72, 0x65, 0x6d, 0x6f, 0x74, 0x65, 0x49, 0x63, 0x65, - 0x43, 0x61, 0x6e, 0x64, 0x69, 0x64, 0x61, 0x74, 0x65, 0x45, 0x6e, 0x64, 0x70, 0x6f, 0x69, 0x6e, - 0x74, 0x12, 0x52, 0x0a, 0x16, 0x6c, 0x61, 0x73, 0x74, 0x57, 0x69, 0x72, 0x65, 0x67, 0x75, 0x61, - 0x72, 0x64, 0x48, 0x61, 0x6e, 0x64, 0x73, 0x68, 0x61, 0x6b, 0x65, 0x18, 0x0c, 0x20, 
0x01, 0x28, - 0x0b, 0x32, 0x1a, 0x2e, 0x67, 0x6f, 0x6f, 0x67, 0x6c, 0x65, 0x2e, 0x70, 0x72, 0x6f, 0x74, 0x6f, - 0x62, 0x75, 0x66, 0x2e, 0x54, 0x69, 0x6d, 0x65, 0x73, 0x74, 0x61, 0x6d, 0x70, 0x52, 0x16, 0x6c, - 0x61, 0x73, 0x74, 0x57, 0x69, 0x72, 0x65, 0x67, 0x75, 0x61, 0x72, 0x64, 0x48, 0x61, 0x6e, 0x64, - 0x73, 0x68, 0x61, 0x6b, 0x65, 0x12, 0x18, 0x0a, 0x07, 0x62, 0x79, 0x74, 0x65, 0x73, 0x52, 0x78, - 0x18, 0x0d, 0x20, 0x01, 0x28, 0x03, 0x52, 0x07, 0x62, 0x79, 0x74, 0x65, 0x73, 0x52, 0x78, 0x12, - 0x18, 0x0a, 0x07, 0x62, 0x79, 0x74, 0x65, 0x73, 0x54, 0x78, 0x18, 0x0e, 0x20, 0x01, 0x28, 0x03, - 0x52, 0x07, 0x62, 0x79, 0x74, 0x65, 0x73, 0x54, 0x78, 0x12, 0x2a, 0x0a, 0x10, 0x72, 0x6f, 0x73, - 0x65, 0x6e, 0x70, 0x61, 0x73, 0x73, 0x45, 0x6e, 0x61, 0x62, 0x6c, 0x65, 0x64, 0x18, 0x0f, 0x20, - 0x01, 0x28, 0x08, 0x52, 0x10, 0x72, 0x6f, 0x73, 0x65, 0x6e, 0x70, 0x61, 0x73, 0x73, 0x45, 0x6e, - 0x61, 0x62, 0x6c, 0x65, 0x64, 0x12, 0x1a, 0x0a, 0x08, 0x6e, 0x65, 0x74, 0x77, 0x6f, 0x72, 0x6b, - 0x73, 0x18, 0x10, 0x20, 0x03, 0x28, 0x09, 0x52, 0x08, 0x6e, 0x65, 0x74, 0x77, 0x6f, 0x72, 0x6b, - 0x73, 0x12, 0x33, 0x0a, 0x07, 0x6c, 0x61, 0x74, 0x65, 0x6e, 0x63, 0x79, 0x18, 0x11, 0x20, 0x01, - 0x28, 0x0b, 0x32, 0x19, 0x2e, 0x67, 0x6f, 0x6f, 0x67, 0x6c, 0x65, 0x2e, 0x70, 0x72, 0x6f, 0x74, - 0x6f, 0x62, 0x75, 0x66, 0x2e, 0x44, 0x75, 0x72, 0x61, 0x74, 0x69, 0x6f, 0x6e, 0x52, 0x07, 0x6c, - 0x61, 0x74, 0x65, 0x6e, 0x63, 0x79, 0x12, 0x22, 0x0a, 0x0c, 0x72, 0x65, 0x6c, 0x61, 0x79, 0x41, - 0x64, 0x64, 0x72, 0x65, 0x73, 0x73, 0x18, 0x12, 0x20, 0x01, 0x28, 0x09, 0x52, 0x0c, 0x72, 0x65, - 0x6c, 0x61, 0x79, 0x41, 0x64, 0x64, 0x72, 0x65, 0x73, 0x73, 0x22, 0xf0, 0x01, 0x0a, 0x0e, 0x4c, - 0x6f, 0x63, 0x61, 0x6c, 0x50, 0x65, 0x65, 0x72, 0x53, 0x74, 0x61, 0x74, 0x65, 0x12, 0x0e, 0x0a, - 0x02, 0x49, 0x50, 0x18, 0x01, 0x20, 0x01, 0x28, 0x09, 0x52, 0x02, 0x49, 0x50, 0x12, 0x16, 0x0a, - 0x06, 0x70, 0x75, 0x62, 0x4b, 0x65, 0x79, 0x18, 0x02, 0x20, 0x01, 0x28, 0x09, 0x52, 0x06, 0x70, - 0x75, 0x62, 0x4b, 0x65, 0x79, 0x12, 0x28, 0x0a, 0x0f, 0x6b, 0x65, 0x72, 0x6e, 0x65, 0x6c, 0x49, - 0x6e, 0x74, 0x65, 0x72, 0x66, 0x61, 0x63, 0x65, 0x18, 0x03, 0x20, 0x01, 0x28, 0x08, 0x52, 0x0f, - 0x6b, 0x65, 0x72, 0x6e, 0x65, 0x6c, 0x49, 0x6e, 0x74, 0x65, 0x72, 0x66, 0x61, 0x63, 0x65, 0x12, - 0x12, 0x0a, 0x04, 0x66, 0x71, 0x64, 0x6e, 0x18, 0x04, 0x20, 0x01, 0x28, 0x09, 0x52, 0x04, 0x66, - 0x71, 0x64, 0x6e, 0x12, 0x2a, 0x0a, 0x10, 0x72, 0x6f, 0x73, 0x65, 0x6e, 0x70, 0x61, 0x73, 0x73, - 0x45, 0x6e, 0x61, 0x62, 0x6c, 0x65, 0x64, 0x18, 0x05, 0x20, 0x01, 0x28, 0x08, 0x52, 0x10, 0x72, - 0x6f, 0x73, 0x65, 0x6e, 0x70, 0x61, 0x73, 0x73, 0x45, 0x6e, 0x61, 0x62, 0x6c, 0x65, 0x64, 0x12, - 0x30, 0x0a, 0x13, 0x72, 0x6f, 0x73, 0x65, 0x6e, 0x70, 0x61, 0x73, 0x73, 0x50, 0x65, 0x72, 0x6d, - 0x69, 0x73, 0x73, 0x69, 0x76, 0x65, 0x18, 0x06, 0x20, 0x01, 0x28, 0x08, 0x52, 0x13, 0x72, 0x6f, - 0x73, 0x65, 0x6e, 0x70, 0x61, 0x73, 0x73, 0x50, 0x65, 0x72, 0x6d, 0x69, 0x73, 0x73, 0x69, 0x76, - 0x65, 0x12, 0x1a, 0x0a, 0x08, 0x6e, 0x65, 0x74, 0x77, 0x6f, 0x72, 0x6b, 0x73, 0x18, 0x07, 0x20, - 0x03, 0x28, 0x09, 0x52, 0x08, 0x6e, 0x65, 0x74, 0x77, 0x6f, 0x72, 0x6b, 0x73, 0x22, 0x53, 0x0a, - 0x0b, 0x53, 0x69, 0x67, 0x6e, 0x61, 0x6c, 0x53, 0x74, 0x61, 0x74, 0x65, 0x12, 0x10, 0x0a, 0x03, - 0x55, 0x52, 0x4c, 0x18, 0x01, 0x20, 0x01, 0x28, 0x09, 0x52, 0x03, 0x55, 0x52, 0x4c, 0x12, 0x1c, - 0x0a, 0x09, 0x63, 0x6f, 0x6e, 0x6e, 0x65, 0x63, 0x74, 0x65, 0x64, 0x18, 0x02, 0x20, 0x01, 0x28, - 0x08, 0x52, 0x09, 0x63, 0x6f, 0x6e, 0x6e, 0x65, 0x63, 0x74, 0x65, 0x64, 0x12, 0x14, 0x0a, 0x05, - 0x65, 0x72, 
0x72, 0x6f, 0x72, 0x18, 0x03, 0x20, 0x01, 0x28, 0x09, 0x52, 0x05, 0x65, 0x72, 0x72, - 0x6f, 0x72, 0x22, 0x57, 0x0a, 0x0f, 0x4d, 0x61, 0x6e, 0x61, 0x67, 0x65, 0x6d, 0x65, 0x6e, 0x74, - 0x53, 0x74, 0x61, 0x74, 0x65, 0x12, 0x10, 0x0a, 0x03, 0x55, 0x52, 0x4c, 0x18, 0x01, 0x20, 0x01, - 0x28, 0x09, 0x52, 0x03, 0x55, 0x52, 0x4c, 0x12, 0x1c, 0x0a, 0x09, 0x63, 0x6f, 0x6e, 0x6e, 0x65, - 0x63, 0x74, 0x65, 0x64, 0x18, 0x02, 0x20, 0x01, 0x28, 0x08, 0x52, 0x09, 0x63, 0x6f, 0x6e, 0x6e, - 0x65, 0x63, 0x74, 0x65, 0x64, 0x12, 0x14, 0x0a, 0x05, 0x65, 0x72, 0x72, 0x6f, 0x72, 0x18, 0x03, - 0x20, 0x01, 0x28, 0x09, 0x52, 0x05, 0x65, 0x72, 0x72, 0x6f, 0x72, 0x22, 0x52, 0x0a, 0x0a, 0x52, - 0x65, 0x6c, 0x61, 0x79, 0x53, 0x74, 0x61, 0x74, 0x65, 0x12, 0x10, 0x0a, 0x03, 0x55, 0x52, 0x49, - 0x18, 0x01, 0x20, 0x01, 0x28, 0x09, 0x52, 0x03, 0x55, 0x52, 0x49, 0x12, 0x1c, 0x0a, 0x09, 0x61, - 0x76, 0x61, 0x69, 0x6c, 0x61, 0x62, 0x6c, 0x65, 0x18, 0x02, 0x20, 0x01, 0x28, 0x08, 0x52, 0x09, - 0x61, 0x76, 0x61, 0x69, 0x6c, 0x61, 0x62, 0x6c, 0x65, 0x12, 0x14, 0x0a, 0x05, 0x65, 0x72, 0x72, - 0x6f, 0x72, 0x18, 0x03, 0x20, 0x01, 0x28, 0x09, 0x52, 0x05, 0x65, 0x72, 0x72, 0x6f, 0x72, 0x22, - 0x72, 0x0a, 0x0c, 0x4e, 0x53, 0x47, 0x72, 0x6f, 0x75, 0x70, 0x53, 0x74, 0x61, 0x74, 0x65, 0x12, - 0x18, 0x0a, 0x07, 0x73, 0x65, 0x72, 0x76, 0x65, 0x72, 0x73, 0x18, 0x01, 0x20, 0x03, 0x28, 0x09, - 0x52, 0x07, 0x73, 0x65, 0x72, 0x76, 0x65, 0x72, 0x73, 0x12, 0x18, 0x0a, 0x07, 0x64, 0x6f, 0x6d, - 0x61, 0x69, 0x6e, 0x73, 0x18, 0x02, 0x20, 0x03, 0x28, 0x09, 0x52, 0x07, 0x64, 0x6f, 0x6d, 0x61, - 0x69, 0x6e, 0x73, 0x12, 0x18, 0x0a, 0x07, 0x65, 0x6e, 0x61, 0x62, 0x6c, 0x65, 0x64, 0x18, 0x03, - 0x20, 0x01, 0x28, 0x08, 0x52, 0x07, 0x65, 0x6e, 0x61, 0x62, 0x6c, 0x65, 0x64, 0x12, 0x14, 0x0a, - 0x05, 0x65, 0x72, 0x72, 0x6f, 0x72, 0x18, 0x04, 0x20, 0x01, 0x28, 0x09, 0x52, 0x05, 0x65, 0x72, - 0x72, 0x6f, 0x72, 0x22, 0xd2, 0x02, 0x0a, 0x0a, 0x46, 0x75, 0x6c, 0x6c, 0x53, 0x74, 0x61, 0x74, - 0x75, 0x73, 0x12, 0x41, 0x0a, 0x0f, 0x6d, 0x61, 0x6e, 0x61, 0x67, 0x65, 0x6d, 0x65, 0x6e, 0x74, - 0x53, 0x74, 0x61, 0x74, 0x65, 0x18, 0x01, 0x20, 0x01, 0x28, 0x0b, 0x32, 0x17, 0x2e, 0x64, 0x61, - 0x65, 0x6d, 0x6f, 0x6e, 0x2e, 0x4d, 0x61, 0x6e, 0x61, 0x67, 0x65, 0x6d, 0x65, 0x6e, 0x74, 0x53, - 0x74, 0x61, 0x74, 0x65, 0x52, 0x0f, 0x6d, 0x61, 0x6e, 0x61, 0x67, 0x65, 0x6d, 0x65, 0x6e, 0x74, - 0x53, 0x74, 0x61, 0x74, 0x65, 0x12, 0x35, 0x0a, 0x0b, 0x73, 0x69, 0x67, 0x6e, 0x61, 0x6c, 0x53, - 0x74, 0x61, 0x74, 0x65, 0x18, 0x02, 0x20, 0x01, 0x28, 0x0b, 0x32, 0x13, 0x2e, 0x64, 0x61, 0x65, - 0x6d, 0x6f, 0x6e, 0x2e, 0x53, 0x69, 0x67, 0x6e, 0x61, 0x6c, 0x53, 0x74, 0x61, 0x74, 0x65, 0x52, - 0x0b, 0x73, 0x69, 0x67, 0x6e, 0x61, 0x6c, 0x53, 0x74, 0x61, 0x74, 0x65, 0x12, 0x3e, 0x0a, 0x0e, - 0x6c, 0x6f, 0x63, 0x61, 0x6c, 0x50, 0x65, 0x65, 0x72, 0x53, 0x74, 0x61, 0x74, 0x65, 0x18, 0x03, - 0x20, 0x01, 0x28, 0x0b, 0x32, 0x16, 0x2e, 0x64, 0x61, 0x65, 0x6d, 0x6f, 0x6e, 0x2e, 0x4c, 0x6f, - 0x63, 0x61, 0x6c, 0x50, 0x65, 0x65, 0x72, 0x53, 0x74, 0x61, 0x74, 0x65, 0x52, 0x0e, 0x6c, 0x6f, - 0x63, 0x61, 0x6c, 0x50, 0x65, 0x65, 0x72, 0x53, 0x74, 0x61, 0x74, 0x65, 0x12, 0x27, 0x0a, 0x05, - 0x70, 0x65, 0x65, 0x72, 0x73, 0x18, 0x04, 0x20, 0x03, 0x28, 0x0b, 0x32, 0x11, 0x2e, 0x64, 0x61, - 0x65, 0x6d, 0x6f, 0x6e, 0x2e, 0x50, 0x65, 0x65, 0x72, 0x53, 0x74, 0x61, 0x74, 0x65, 0x52, 0x05, - 0x70, 0x65, 0x65, 0x72, 0x73, 0x12, 0x2a, 0x0a, 0x06, 0x72, 0x65, 0x6c, 0x61, 0x79, 0x73, 0x18, - 0x05, 0x20, 0x03, 0x28, 0x0b, 0x32, 0x12, 0x2e, 0x64, 0x61, 0x65, 0x6d, 0x6f, 0x6e, 0x2e, 0x52, - 0x65, 0x6c, 0x61, 0x79, 0x53, 0x74, 
0x61, 0x74, 0x65, 0x52, 0x06, 0x72, 0x65, 0x6c, 0x61, 0x79, - 0x73, 0x12, 0x35, 0x0a, 0x0b, 0x64, 0x6e, 0x73, 0x5f, 0x73, 0x65, 0x72, 0x76, 0x65, 0x72, 0x73, - 0x18, 0x06, 0x20, 0x03, 0x28, 0x0b, 0x32, 0x14, 0x2e, 0x64, 0x61, 0x65, 0x6d, 0x6f, 0x6e, 0x2e, - 0x4e, 0x53, 0x47, 0x72, 0x6f, 0x75, 0x70, 0x53, 0x74, 0x61, 0x74, 0x65, 0x52, 0x0a, 0x64, 0x6e, - 0x73, 0x53, 0x65, 0x72, 0x76, 0x65, 0x72, 0x73, 0x22, 0x15, 0x0a, 0x13, 0x4c, 0x69, 0x73, 0x74, - 0x4e, 0x65, 0x74, 0x77, 0x6f, 0x72, 0x6b, 0x73, 0x52, 0x65, 0x71, 0x75, 0x65, 0x73, 0x74, 0x22, - 0x3f, 0x0a, 0x14, 0x4c, 0x69, 0x73, 0x74, 0x4e, 0x65, 0x74, 0x77, 0x6f, 0x72, 0x6b, 0x73, 0x52, - 0x65, 0x73, 0x70, 0x6f, 0x6e, 0x73, 0x65, 0x12, 0x27, 0x0a, 0x06, 0x72, 0x6f, 0x75, 0x74, 0x65, - 0x73, 0x18, 0x01, 0x20, 0x03, 0x28, 0x0b, 0x32, 0x0f, 0x2e, 0x64, 0x61, 0x65, 0x6d, 0x6f, 0x6e, - 0x2e, 0x4e, 0x65, 0x74, 0x77, 0x6f, 0x72, 0x6b, 0x52, 0x06, 0x72, 0x6f, 0x75, 0x74, 0x65, 0x73, - 0x22, 0x61, 0x0a, 0x15, 0x53, 0x65, 0x6c, 0x65, 0x63, 0x74, 0x4e, 0x65, 0x74, 0x77, 0x6f, 0x72, - 0x6b, 0x73, 0x52, 0x65, 0x71, 0x75, 0x65, 0x73, 0x74, 0x12, 0x1e, 0x0a, 0x0a, 0x6e, 0x65, 0x74, - 0x77, 0x6f, 0x72, 0x6b, 0x49, 0x44, 0x73, 0x18, 0x01, 0x20, 0x03, 0x28, 0x09, 0x52, 0x0a, 0x6e, - 0x65, 0x74, 0x77, 0x6f, 0x72, 0x6b, 0x49, 0x44, 0x73, 0x12, 0x16, 0x0a, 0x06, 0x61, 0x70, 0x70, - 0x65, 0x6e, 0x64, 0x18, 0x02, 0x20, 0x01, 0x28, 0x08, 0x52, 0x06, 0x61, 0x70, 0x70, 0x65, 0x6e, - 0x64, 0x12, 0x10, 0x0a, 0x03, 0x61, 0x6c, 0x6c, 0x18, 0x03, 0x20, 0x01, 0x28, 0x08, 0x52, 0x03, - 0x61, 0x6c, 0x6c, 0x22, 0x18, 0x0a, 0x16, 0x53, 0x65, 0x6c, 0x65, 0x63, 0x74, 0x4e, 0x65, 0x74, - 0x77, 0x6f, 0x72, 0x6b, 0x73, 0x52, 0x65, 0x73, 0x70, 0x6f, 0x6e, 0x73, 0x65, 0x22, 0x1a, 0x0a, - 0x06, 0x49, 0x50, 0x4c, 0x69, 0x73, 0x74, 0x12, 0x10, 0x0a, 0x03, 0x69, 0x70, 0x73, 0x18, 0x01, - 0x20, 0x03, 0x28, 0x09, 0x52, 0x03, 0x69, 0x70, 0x73, 0x22, 0xf9, 0x01, 0x0a, 0x07, 0x4e, 0x65, - 0x74, 0x77, 0x6f, 0x72, 0x6b, 0x12, 0x0e, 0x0a, 0x02, 0x49, 0x44, 0x18, 0x01, 0x20, 0x01, 0x28, - 0x09, 0x52, 0x02, 0x49, 0x44, 0x12, 0x14, 0x0a, 0x05, 0x72, 0x61, 0x6e, 0x67, 0x65, 0x18, 0x02, - 0x20, 0x01, 0x28, 0x09, 0x52, 0x05, 0x72, 0x61, 0x6e, 0x67, 0x65, 0x12, 0x1a, 0x0a, 0x08, 0x73, - 0x65, 0x6c, 0x65, 0x63, 0x74, 0x65, 0x64, 0x18, 0x03, 0x20, 0x01, 0x28, 0x08, 0x52, 0x08, 0x73, - 0x65, 0x6c, 0x65, 0x63, 0x74, 0x65, 0x64, 0x12, 0x18, 0x0a, 0x07, 0x64, 0x6f, 0x6d, 0x61, 0x69, - 0x6e, 0x73, 0x18, 0x04, 0x20, 0x03, 0x28, 0x09, 0x52, 0x07, 0x64, 0x6f, 0x6d, 0x61, 0x69, 0x6e, - 0x73, 0x12, 0x42, 0x0a, 0x0b, 0x72, 0x65, 0x73, 0x6f, 0x6c, 0x76, 0x65, 0x64, 0x49, 0x50, 0x73, - 0x18, 0x05, 0x20, 0x03, 0x28, 0x0b, 0x32, 0x20, 0x2e, 0x64, 0x61, 0x65, 0x6d, 0x6f, 0x6e, 0x2e, - 0x4e, 0x65, 0x74, 0x77, 0x6f, 0x72, 0x6b, 0x2e, 0x52, 0x65, 0x73, 0x6f, 0x6c, 0x76, 0x65, 0x64, - 0x49, 0x50, 0x73, 0x45, 0x6e, 0x74, 0x72, 0x79, 0x52, 0x0b, 0x72, 0x65, 0x73, 0x6f, 0x6c, 0x76, - 0x65, 0x64, 0x49, 0x50, 0x73, 0x1a, 0x4e, 0x0a, 0x10, 0x52, 0x65, 0x73, 0x6f, 0x6c, 0x76, 0x65, - 0x64, 0x49, 0x50, 0x73, 0x45, 0x6e, 0x74, 0x72, 0x79, 0x12, 0x10, 0x0a, 0x03, 0x6b, 0x65, 0x79, - 0x18, 0x01, 0x20, 0x01, 0x28, 0x09, 0x52, 0x03, 0x6b, 0x65, 0x79, 0x12, 0x24, 0x0a, 0x05, 0x76, - 0x61, 0x6c, 0x75, 0x65, 0x18, 0x02, 0x20, 0x01, 0x28, 0x0b, 0x32, 0x0e, 0x2e, 0x64, 0x61, 0x65, - 0x6d, 0x6f, 0x6e, 0x2e, 0x49, 0x50, 0x4c, 0x69, 0x73, 0x74, 0x52, 0x05, 0x76, 0x61, 0x6c, 0x75, - 0x65, 0x3a, 0x02, 0x38, 0x01, 0x22, 0x6a, 0x0a, 0x12, 0x44, 0x65, 0x62, 0x75, 0x67, 0x42, 0x75, - 0x6e, 0x64, 0x6c, 0x65, 0x52, 0x65, 0x71, 0x75, 0x65, 0x73, 
0x74, 0x12, 0x1c, 0x0a, 0x09, 0x61, - 0x6e, 0x6f, 0x6e, 0x79, 0x6d, 0x69, 0x7a, 0x65, 0x18, 0x01, 0x20, 0x01, 0x28, 0x08, 0x52, 0x09, - 0x61, 0x6e, 0x6f, 0x6e, 0x79, 0x6d, 0x69, 0x7a, 0x65, 0x12, 0x16, 0x0a, 0x06, 0x73, 0x74, 0x61, - 0x74, 0x75, 0x73, 0x18, 0x02, 0x20, 0x01, 0x28, 0x09, 0x52, 0x06, 0x73, 0x74, 0x61, 0x74, 0x75, - 0x73, 0x12, 0x1e, 0x0a, 0x0a, 0x73, 0x79, 0x73, 0x74, 0x65, 0x6d, 0x49, 0x6e, 0x66, 0x6f, 0x18, - 0x03, 0x20, 0x01, 0x28, 0x08, 0x52, 0x0a, 0x73, 0x79, 0x73, 0x74, 0x65, 0x6d, 0x49, 0x6e, 0x66, - 0x6f, 0x22, 0x29, 0x0a, 0x13, 0x44, 0x65, 0x62, 0x75, 0x67, 0x42, 0x75, 0x6e, 0x64, 0x6c, 0x65, - 0x52, 0x65, 0x73, 0x70, 0x6f, 0x6e, 0x73, 0x65, 0x12, 0x12, 0x0a, 0x04, 0x70, 0x61, 0x74, 0x68, - 0x18, 0x01, 0x20, 0x01, 0x28, 0x09, 0x52, 0x04, 0x70, 0x61, 0x74, 0x68, 0x22, 0x14, 0x0a, 0x12, - 0x47, 0x65, 0x74, 0x4c, 0x6f, 0x67, 0x4c, 0x65, 0x76, 0x65, 0x6c, 0x52, 0x65, 0x71, 0x75, 0x65, - 0x73, 0x74, 0x22, 0x3d, 0x0a, 0x13, 0x47, 0x65, 0x74, 0x4c, 0x6f, 0x67, 0x4c, 0x65, 0x76, 0x65, - 0x6c, 0x52, 0x65, 0x73, 0x70, 0x6f, 0x6e, 0x73, 0x65, 0x12, 0x26, 0x0a, 0x05, 0x6c, 0x65, 0x76, - 0x65, 0x6c, 0x18, 0x01, 0x20, 0x01, 0x28, 0x0e, 0x32, 0x10, 0x2e, 0x64, 0x61, 0x65, 0x6d, 0x6f, - 0x6e, 0x2e, 0x4c, 0x6f, 0x67, 0x4c, 0x65, 0x76, 0x65, 0x6c, 0x52, 0x05, 0x6c, 0x65, 0x76, 0x65, - 0x6c, 0x22, 0x3c, 0x0a, 0x12, 0x53, 0x65, 0x74, 0x4c, 0x6f, 0x67, 0x4c, 0x65, 0x76, 0x65, 0x6c, - 0x52, 0x65, 0x71, 0x75, 0x65, 0x73, 0x74, 0x12, 0x26, 0x0a, 0x05, 0x6c, 0x65, 0x76, 0x65, 0x6c, - 0x18, 0x01, 0x20, 0x01, 0x28, 0x0e, 0x32, 0x10, 0x2e, 0x64, 0x61, 0x65, 0x6d, 0x6f, 0x6e, 0x2e, - 0x4c, 0x6f, 0x67, 0x4c, 0x65, 0x76, 0x65, 0x6c, 0x52, 0x05, 0x6c, 0x65, 0x76, 0x65, 0x6c, 0x22, - 0x15, 0x0a, 0x13, 0x53, 0x65, 0x74, 0x4c, 0x6f, 0x67, 0x4c, 0x65, 0x76, 0x65, 0x6c, 0x52, 0x65, - 0x73, 0x70, 0x6f, 0x6e, 0x73, 0x65, 0x22, 0x1b, 0x0a, 0x05, 0x53, 0x74, 0x61, 0x74, 0x65, 0x12, - 0x12, 0x0a, 0x04, 0x6e, 0x61, 0x6d, 0x65, 0x18, 0x01, 0x20, 0x01, 0x28, 0x09, 0x52, 0x04, 0x6e, - 0x61, 0x6d, 0x65, 0x22, 0x13, 0x0a, 0x11, 0x4c, 0x69, 0x73, 0x74, 0x53, 0x74, 0x61, 0x74, 0x65, - 0x73, 0x52, 0x65, 0x71, 0x75, 0x65, 0x73, 0x74, 0x22, 0x3b, 0x0a, 0x12, 0x4c, 0x69, 0x73, 0x74, - 0x53, 0x74, 0x61, 0x74, 0x65, 0x73, 0x52, 0x65, 0x73, 0x70, 0x6f, 0x6e, 0x73, 0x65, 0x12, 0x25, - 0x0a, 0x06, 0x73, 0x74, 0x61, 0x74, 0x65, 0x73, 0x18, 0x01, 0x20, 0x03, 0x28, 0x0b, 0x32, 0x0d, - 0x2e, 0x64, 0x61, 0x65, 0x6d, 0x6f, 0x6e, 0x2e, 0x53, 0x74, 0x61, 0x74, 0x65, 0x52, 0x06, 0x73, - 0x74, 0x61, 0x74, 0x65, 0x73, 0x22, 0x44, 0x0a, 0x11, 0x43, 0x6c, 0x65, 0x61, 0x6e, 0x53, 0x74, - 0x61, 0x74, 0x65, 0x52, 0x65, 0x71, 0x75, 0x65, 0x73, 0x74, 0x12, 0x1d, 0x0a, 0x0a, 0x73, 0x74, - 0x61, 0x74, 0x65, 0x5f, 0x6e, 0x61, 0x6d, 0x65, 0x18, 0x01, 0x20, 0x01, 0x28, 0x09, 0x52, 0x09, - 0x73, 0x74, 0x61, 0x74, 0x65, 0x4e, 0x61, 0x6d, 0x65, 0x12, 0x10, 0x0a, 0x03, 0x61, 0x6c, 0x6c, - 0x18, 0x02, 0x20, 0x01, 0x28, 0x08, 0x52, 0x03, 0x61, 0x6c, 0x6c, 0x22, 0x3b, 0x0a, 0x12, 0x43, - 0x6c, 0x65, 0x61, 0x6e, 0x53, 0x74, 0x61, 0x74, 0x65, 0x52, 0x65, 0x73, 0x70, 0x6f, 0x6e, 0x73, - 0x65, 0x12, 0x25, 0x0a, 0x0e, 0x63, 0x6c, 0x65, 0x61, 0x6e, 0x65, 0x64, 0x5f, 0x73, 0x74, 0x61, - 0x74, 0x65, 0x73, 0x18, 0x01, 0x20, 0x01, 0x28, 0x05, 0x52, 0x0d, 0x63, 0x6c, 0x65, 0x61, 0x6e, - 0x65, 0x64, 0x53, 0x74, 0x61, 0x74, 0x65, 0x73, 0x22, 0x45, 0x0a, 0x12, 0x44, 0x65, 0x6c, 0x65, - 0x74, 0x65, 0x53, 0x74, 0x61, 0x74, 0x65, 0x52, 0x65, 0x71, 0x75, 0x65, 0x73, 0x74, 0x12, 0x1d, - 0x0a, 0x0a, 0x73, 0x74, 0x61, 0x74, 0x65, 0x5f, 0x6e, 0x61, 0x6d, 0x65, 0x18, 0x01, 
0x20, 0x01, - 0x28, 0x09, 0x52, 0x09, 0x73, 0x74, 0x61, 0x74, 0x65, 0x4e, 0x61, 0x6d, 0x65, 0x12, 0x10, 0x0a, - 0x03, 0x61, 0x6c, 0x6c, 0x18, 0x02, 0x20, 0x01, 0x28, 0x08, 0x52, 0x03, 0x61, 0x6c, 0x6c, 0x22, - 0x3c, 0x0a, 0x13, 0x44, 0x65, 0x6c, 0x65, 0x74, 0x65, 0x53, 0x74, 0x61, 0x74, 0x65, 0x52, 0x65, - 0x73, 0x70, 0x6f, 0x6e, 0x73, 0x65, 0x12, 0x25, 0x0a, 0x0e, 0x64, 0x65, 0x6c, 0x65, 0x74, 0x65, - 0x64, 0x5f, 0x73, 0x74, 0x61, 0x74, 0x65, 0x73, 0x18, 0x01, 0x20, 0x01, 0x28, 0x05, 0x52, 0x0d, - 0x64, 0x65, 0x6c, 0x65, 0x74, 0x65, 0x64, 0x53, 0x74, 0x61, 0x74, 0x65, 0x73, 0x22, 0x3b, 0x0a, - 0x1f, 0x53, 0x65, 0x74, 0x4e, 0x65, 0x74, 0x77, 0x6f, 0x72, 0x6b, 0x4d, 0x61, 0x70, 0x50, 0x65, - 0x72, 0x73, 0x69, 0x73, 0x74, 0x65, 0x6e, 0x63, 0x65, 0x52, 0x65, 0x71, 0x75, 0x65, 0x73, 0x74, - 0x12, 0x18, 0x0a, 0x07, 0x65, 0x6e, 0x61, 0x62, 0x6c, 0x65, 0x64, 0x18, 0x01, 0x20, 0x01, 0x28, - 0x08, 0x52, 0x07, 0x65, 0x6e, 0x61, 0x62, 0x6c, 0x65, 0x64, 0x22, 0x22, 0x0a, 0x20, 0x53, 0x65, - 0x74, 0x4e, 0x65, 0x74, 0x77, 0x6f, 0x72, 0x6b, 0x4d, 0x61, 0x70, 0x50, 0x65, 0x72, 0x73, 0x69, - 0x73, 0x74, 0x65, 0x6e, 0x63, 0x65, 0x52, 0x65, 0x73, 0x70, 0x6f, 0x6e, 0x73, 0x65, 0x2a, 0x62, - 0x0a, 0x08, 0x4c, 0x6f, 0x67, 0x4c, 0x65, 0x76, 0x65, 0x6c, 0x12, 0x0b, 0x0a, 0x07, 0x55, 0x4e, - 0x4b, 0x4e, 0x4f, 0x57, 0x4e, 0x10, 0x00, 0x12, 0x09, 0x0a, 0x05, 0x50, 0x41, 0x4e, 0x49, 0x43, - 0x10, 0x01, 0x12, 0x09, 0x0a, 0x05, 0x46, 0x41, 0x54, 0x41, 0x4c, 0x10, 0x02, 0x12, 0x09, 0x0a, - 0x05, 0x45, 0x52, 0x52, 0x4f, 0x52, 0x10, 0x03, 0x12, 0x08, 0x0a, 0x04, 0x57, 0x41, 0x52, 0x4e, - 0x10, 0x04, 0x12, 0x08, 0x0a, 0x04, 0x49, 0x4e, 0x46, 0x4f, 0x10, 0x05, 0x12, 0x09, 0x0a, 0x05, - 0x44, 0x45, 0x42, 0x55, 0x47, 0x10, 0x06, 0x12, 0x09, 0x0a, 0x05, 0x54, 0x52, 0x41, 0x43, 0x45, - 0x10, 0x07, 0x32, 0x93, 0x09, 0x0a, 0x0d, 0x44, 0x61, 0x65, 0x6d, 0x6f, 0x6e, 0x53, 0x65, 0x72, - 0x76, 0x69, 0x63, 0x65, 0x12, 0x36, 0x0a, 0x05, 0x4c, 0x6f, 0x67, 0x69, 0x6e, 0x12, 0x14, 0x2e, - 0x64, 0x61, 0x65, 0x6d, 0x6f, 0x6e, 0x2e, 0x4c, 0x6f, 0x67, 0x69, 0x6e, 0x52, 0x65, 0x71, 0x75, - 0x65, 0x73, 0x74, 0x1a, 0x15, 0x2e, 0x64, 0x61, 0x65, 0x6d, 0x6f, 0x6e, 0x2e, 0x4c, 0x6f, 0x67, - 0x69, 0x6e, 0x52, 0x65, 0x73, 0x70, 0x6f, 0x6e, 0x73, 0x65, 0x22, 0x00, 0x12, 0x4b, 0x0a, 0x0c, - 0x57, 0x61, 0x69, 0x74, 0x53, 0x53, 0x4f, 0x4c, 0x6f, 0x67, 0x69, 0x6e, 0x12, 0x1b, 0x2e, 0x64, - 0x61, 0x65, 0x6d, 0x6f, 0x6e, 0x2e, 0x57, 0x61, 0x69, 0x74, 0x53, 0x53, 0x4f, 0x4c, 0x6f, 0x67, - 0x69, 0x6e, 0x52, 0x65, 0x71, 0x75, 0x65, 0x73, 0x74, 0x1a, 0x1c, 0x2e, 0x64, 0x61, 0x65, 0x6d, - 0x6f, 0x6e, 0x2e, 0x57, 0x61, 0x69, 0x74, 0x53, 0x53, 0x4f, 0x4c, 0x6f, 0x67, 0x69, 0x6e, 0x52, - 0x65, 0x73, 0x70, 0x6f, 0x6e, 0x73, 0x65, 0x22, 0x00, 0x12, 0x2d, 0x0a, 0x02, 0x55, 0x70, 0x12, - 0x11, 0x2e, 0x64, 0x61, 0x65, 0x6d, 0x6f, 0x6e, 0x2e, 0x55, 0x70, 0x52, 0x65, 0x71, 0x75, 0x65, - 0x73, 0x74, 0x1a, 0x12, 0x2e, 0x64, 0x61, 0x65, 0x6d, 0x6f, 0x6e, 0x2e, 0x55, 0x70, 0x52, 0x65, - 0x73, 0x70, 0x6f, 0x6e, 0x73, 0x65, 0x22, 0x00, 0x12, 0x39, 0x0a, 0x06, 0x53, 0x74, 0x61, 0x74, - 0x75, 0x73, 0x12, 0x15, 0x2e, 0x64, 0x61, 0x65, 0x6d, 0x6f, 0x6e, 0x2e, 0x53, 0x74, 0x61, 0x74, - 0x75, 0x73, 0x52, 0x65, 0x71, 0x75, 0x65, 0x73, 0x74, 0x1a, 0x16, 0x2e, 0x64, 0x61, 0x65, 0x6d, - 0x6f, 0x6e, 0x2e, 0x53, 0x74, 0x61, 0x74, 0x75, 0x73, 0x52, 0x65, 0x73, 0x70, 0x6f, 0x6e, 0x73, - 0x65, 0x22, 0x00, 0x12, 0x33, 0x0a, 0x04, 0x44, 0x6f, 0x77, 0x6e, 0x12, 0x13, 0x2e, 0x64, 0x61, - 0x65, 0x6d, 0x6f, 0x6e, 0x2e, 0x44, 0x6f, 0x77, 0x6e, 0x52, 0x65, 0x71, 0x75, 0x65, 0x73, 0x74, - 0x1a, 0x14, 
0x2e, 0x64, 0x61, 0x65, 0x6d, 0x6f, 0x6e, 0x2e, 0x44, 0x6f, 0x77, 0x6e, 0x52, 0x65, - 0x73, 0x70, 0x6f, 0x6e, 0x73, 0x65, 0x22, 0x00, 0x12, 0x42, 0x0a, 0x09, 0x47, 0x65, 0x74, 0x43, - 0x6f, 0x6e, 0x66, 0x69, 0x67, 0x12, 0x18, 0x2e, 0x64, 0x61, 0x65, 0x6d, 0x6f, 0x6e, 0x2e, 0x47, - 0x65, 0x74, 0x43, 0x6f, 0x6e, 0x66, 0x69, 0x67, 0x52, 0x65, 0x71, 0x75, 0x65, 0x73, 0x74, 0x1a, - 0x19, 0x2e, 0x64, 0x61, 0x65, 0x6d, 0x6f, 0x6e, 0x2e, 0x47, 0x65, 0x74, 0x43, 0x6f, 0x6e, 0x66, - 0x69, 0x67, 0x52, 0x65, 0x73, 0x70, 0x6f, 0x6e, 0x73, 0x65, 0x22, 0x00, 0x12, 0x4b, 0x0a, 0x0c, - 0x4c, 0x69, 0x73, 0x74, 0x4e, 0x65, 0x74, 0x77, 0x6f, 0x72, 0x6b, 0x73, 0x12, 0x1b, 0x2e, 0x64, - 0x61, 0x65, 0x6d, 0x6f, 0x6e, 0x2e, 0x4c, 0x69, 0x73, 0x74, 0x4e, 0x65, 0x74, 0x77, 0x6f, 0x72, - 0x6b, 0x73, 0x52, 0x65, 0x71, 0x75, 0x65, 0x73, 0x74, 0x1a, 0x1c, 0x2e, 0x64, 0x61, 0x65, 0x6d, - 0x6f, 0x6e, 0x2e, 0x4c, 0x69, 0x73, 0x74, 0x4e, 0x65, 0x74, 0x77, 0x6f, 0x72, 0x6b, 0x73, 0x52, - 0x65, 0x73, 0x70, 0x6f, 0x6e, 0x73, 0x65, 0x22, 0x00, 0x12, 0x51, 0x0a, 0x0e, 0x53, 0x65, 0x6c, - 0x65, 0x63, 0x74, 0x4e, 0x65, 0x74, 0x77, 0x6f, 0x72, 0x6b, 0x73, 0x12, 0x1d, 0x2e, 0x64, 0x61, + 0x18, 0x0b, 0x20, 0x01, 0x28, 0x09, 0x52, 0x1a, 0x72, 0x65, 0x6d, 0x6f, 0x74, 0x65, 0x49, 0x63, + 0x65, 0x43, 0x61, 0x6e, 0x64, 0x69, 0x64, 0x61, 0x74, 0x65, 0x45, 0x6e, 0x64, 0x70, 0x6f, 0x69, + 0x6e, 0x74, 0x12, 0x52, 0x0a, 0x16, 0x6c, 0x61, 0x73, 0x74, 0x57, 0x69, 0x72, 0x65, 0x67, 0x75, + 0x61, 0x72, 0x64, 0x48, 0x61, 0x6e, 0x64, 0x73, 0x68, 0x61, 0x6b, 0x65, 0x18, 0x0c, 0x20, 0x01, + 0x28, 0x0b, 0x32, 0x1a, 0x2e, 0x67, 0x6f, 0x6f, 0x67, 0x6c, 0x65, 0x2e, 0x70, 0x72, 0x6f, 0x74, + 0x6f, 0x62, 0x75, 0x66, 0x2e, 0x54, 0x69, 0x6d, 0x65, 0x73, 0x74, 0x61, 0x6d, 0x70, 0x52, 0x16, + 0x6c, 0x61, 0x73, 0x74, 0x57, 0x69, 0x72, 0x65, 0x67, 0x75, 0x61, 0x72, 0x64, 0x48, 0x61, 0x6e, + 0x64, 0x73, 0x68, 0x61, 0x6b, 0x65, 0x12, 0x18, 0x0a, 0x07, 0x62, 0x79, 0x74, 0x65, 0x73, 0x52, + 0x78, 0x18, 0x0d, 0x20, 0x01, 0x28, 0x03, 0x52, 0x07, 0x62, 0x79, 0x74, 0x65, 0x73, 0x52, 0x78, + 0x12, 0x18, 0x0a, 0x07, 0x62, 0x79, 0x74, 0x65, 0x73, 0x54, 0x78, 0x18, 0x0e, 0x20, 0x01, 0x28, + 0x03, 0x52, 0x07, 0x62, 0x79, 0x74, 0x65, 0x73, 0x54, 0x78, 0x12, 0x2a, 0x0a, 0x10, 0x72, 0x6f, + 0x73, 0x65, 0x6e, 0x70, 0x61, 0x73, 0x73, 0x45, 0x6e, 0x61, 0x62, 0x6c, 0x65, 0x64, 0x18, 0x0f, + 0x20, 0x01, 0x28, 0x08, 0x52, 0x10, 0x72, 0x6f, 0x73, 0x65, 0x6e, 0x70, 0x61, 0x73, 0x73, 0x45, + 0x6e, 0x61, 0x62, 0x6c, 0x65, 0x64, 0x12, 0x1a, 0x0a, 0x08, 0x6e, 0x65, 0x74, 0x77, 0x6f, 0x72, + 0x6b, 0x73, 0x18, 0x10, 0x20, 0x03, 0x28, 0x09, 0x52, 0x08, 0x6e, 0x65, 0x74, 0x77, 0x6f, 0x72, + 0x6b, 0x73, 0x12, 0x33, 0x0a, 0x07, 0x6c, 0x61, 0x74, 0x65, 0x6e, 0x63, 0x79, 0x18, 0x11, 0x20, + 0x01, 0x28, 0x0b, 0x32, 0x19, 0x2e, 0x67, 0x6f, 0x6f, 0x67, 0x6c, 0x65, 0x2e, 0x70, 0x72, 0x6f, + 0x74, 0x6f, 0x62, 0x75, 0x66, 0x2e, 0x44, 0x75, 0x72, 0x61, 0x74, 0x69, 0x6f, 0x6e, 0x52, 0x07, + 0x6c, 0x61, 0x74, 0x65, 0x6e, 0x63, 0x79, 0x12, 0x22, 0x0a, 0x0c, 0x72, 0x65, 0x6c, 0x61, 0x79, + 0x41, 0x64, 0x64, 0x72, 0x65, 0x73, 0x73, 0x18, 0x12, 0x20, 0x01, 0x28, 0x09, 0x52, 0x0c, 0x72, + 0x65, 0x6c, 0x61, 0x79, 0x41, 0x64, 0x64, 0x72, 0x65, 0x73, 0x73, 0x22, 0xf0, 0x01, 0x0a, 0x0e, + 0x4c, 0x6f, 0x63, 0x61, 0x6c, 0x50, 0x65, 0x65, 0x72, 0x53, 0x74, 0x61, 0x74, 0x65, 0x12, 0x0e, + 0x0a, 0x02, 0x49, 0x50, 0x18, 0x01, 0x20, 0x01, 0x28, 0x09, 0x52, 0x02, 0x49, 0x50, 0x12, 0x16, + 0x0a, 0x06, 0x70, 0x75, 0x62, 0x4b, 0x65, 0x79, 0x18, 0x02, 0x20, 0x01, 0x28, 0x09, 0x52, 0x06, + 0x70, 0x75, 0x62, 0x4b, 0x65, 0x79, 
0x12, 0x28, 0x0a, 0x0f, 0x6b, 0x65, 0x72, 0x6e, 0x65, 0x6c, + 0x49, 0x6e, 0x74, 0x65, 0x72, 0x66, 0x61, 0x63, 0x65, 0x18, 0x03, 0x20, 0x01, 0x28, 0x08, 0x52, + 0x0f, 0x6b, 0x65, 0x72, 0x6e, 0x65, 0x6c, 0x49, 0x6e, 0x74, 0x65, 0x72, 0x66, 0x61, 0x63, 0x65, + 0x12, 0x12, 0x0a, 0x04, 0x66, 0x71, 0x64, 0x6e, 0x18, 0x04, 0x20, 0x01, 0x28, 0x09, 0x52, 0x04, + 0x66, 0x71, 0x64, 0x6e, 0x12, 0x2a, 0x0a, 0x10, 0x72, 0x6f, 0x73, 0x65, 0x6e, 0x70, 0x61, 0x73, + 0x73, 0x45, 0x6e, 0x61, 0x62, 0x6c, 0x65, 0x64, 0x18, 0x05, 0x20, 0x01, 0x28, 0x08, 0x52, 0x10, + 0x72, 0x6f, 0x73, 0x65, 0x6e, 0x70, 0x61, 0x73, 0x73, 0x45, 0x6e, 0x61, 0x62, 0x6c, 0x65, 0x64, + 0x12, 0x30, 0x0a, 0x13, 0x72, 0x6f, 0x73, 0x65, 0x6e, 0x70, 0x61, 0x73, 0x73, 0x50, 0x65, 0x72, + 0x6d, 0x69, 0x73, 0x73, 0x69, 0x76, 0x65, 0x18, 0x06, 0x20, 0x01, 0x28, 0x08, 0x52, 0x13, 0x72, + 0x6f, 0x73, 0x65, 0x6e, 0x70, 0x61, 0x73, 0x73, 0x50, 0x65, 0x72, 0x6d, 0x69, 0x73, 0x73, 0x69, + 0x76, 0x65, 0x12, 0x1a, 0x0a, 0x08, 0x6e, 0x65, 0x74, 0x77, 0x6f, 0x72, 0x6b, 0x73, 0x18, 0x07, + 0x20, 0x03, 0x28, 0x09, 0x52, 0x08, 0x6e, 0x65, 0x74, 0x77, 0x6f, 0x72, 0x6b, 0x73, 0x22, 0x53, + 0x0a, 0x0b, 0x53, 0x69, 0x67, 0x6e, 0x61, 0x6c, 0x53, 0x74, 0x61, 0x74, 0x65, 0x12, 0x10, 0x0a, + 0x03, 0x55, 0x52, 0x4c, 0x18, 0x01, 0x20, 0x01, 0x28, 0x09, 0x52, 0x03, 0x55, 0x52, 0x4c, 0x12, + 0x1c, 0x0a, 0x09, 0x63, 0x6f, 0x6e, 0x6e, 0x65, 0x63, 0x74, 0x65, 0x64, 0x18, 0x02, 0x20, 0x01, + 0x28, 0x08, 0x52, 0x09, 0x63, 0x6f, 0x6e, 0x6e, 0x65, 0x63, 0x74, 0x65, 0x64, 0x12, 0x14, 0x0a, + 0x05, 0x65, 0x72, 0x72, 0x6f, 0x72, 0x18, 0x03, 0x20, 0x01, 0x28, 0x09, 0x52, 0x05, 0x65, 0x72, + 0x72, 0x6f, 0x72, 0x22, 0x57, 0x0a, 0x0f, 0x4d, 0x61, 0x6e, 0x61, 0x67, 0x65, 0x6d, 0x65, 0x6e, + 0x74, 0x53, 0x74, 0x61, 0x74, 0x65, 0x12, 0x10, 0x0a, 0x03, 0x55, 0x52, 0x4c, 0x18, 0x01, 0x20, + 0x01, 0x28, 0x09, 0x52, 0x03, 0x55, 0x52, 0x4c, 0x12, 0x1c, 0x0a, 0x09, 0x63, 0x6f, 0x6e, 0x6e, + 0x65, 0x63, 0x74, 0x65, 0x64, 0x18, 0x02, 0x20, 0x01, 0x28, 0x08, 0x52, 0x09, 0x63, 0x6f, 0x6e, + 0x6e, 0x65, 0x63, 0x74, 0x65, 0x64, 0x12, 0x14, 0x0a, 0x05, 0x65, 0x72, 0x72, 0x6f, 0x72, 0x18, + 0x03, 0x20, 0x01, 0x28, 0x09, 0x52, 0x05, 0x65, 0x72, 0x72, 0x6f, 0x72, 0x22, 0x52, 0x0a, 0x0a, + 0x52, 0x65, 0x6c, 0x61, 0x79, 0x53, 0x74, 0x61, 0x74, 0x65, 0x12, 0x10, 0x0a, 0x03, 0x55, 0x52, + 0x49, 0x18, 0x01, 0x20, 0x01, 0x28, 0x09, 0x52, 0x03, 0x55, 0x52, 0x49, 0x12, 0x1c, 0x0a, 0x09, + 0x61, 0x76, 0x61, 0x69, 0x6c, 0x61, 0x62, 0x6c, 0x65, 0x18, 0x02, 0x20, 0x01, 0x28, 0x08, 0x52, + 0x09, 0x61, 0x76, 0x61, 0x69, 0x6c, 0x61, 0x62, 0x6c, 0x65, 0x12, 0x14, 0x0a, 0x05, 0x65, 0x72, + 0x72, 0x6f, 0x72, 0x18, 0x03, 0x20, 0x01, 0x28, 0x09, 0x52, 0x05, 0x65, 0x72, 0x72, 0x6f, 0x72, + 0x22, 0x72, 0x0a, 0x0c, 0x4e, 0x53, 0x47, 0x72, 0x6f, 0x75, 0x70, 0x53, 0x74, 0x61, 0x74, 0x65, + 0x12, 0x18, 0x0a, 0x07, 0x73, 0x65, 0x72, 0x76, 0x65, 0x72, 0x73, 0x18, 0x01, 0x20, 0x03, 0x28, + 0x09, 0x52, 0x07, 0x73, 0x65, 0x72, 0x76, 0x65, 0x72, 0x73, 0x12, 0x18, 0x0a, 0x07, 0x64, 0x6f, + 0x6d, 0x61, 0x69, 0x6e, 0x73, 0x18, 0x02, 0x20, 0x03, 0x28, 0x09, 0x52, 0x07, 0x64, 0x6f, 0x6d, + 0x61, 0x69, 0x6e, 0x73, 0x12, 0x18, 0x0a, 0x07, 0x65, 0x6e, 0x61, 0x62, 0x6c, 0x65, 0x64, 0x18, + 0x03, 0x20, 0x01, 0x28, 0x08, 0x52, 0x07, 0x65, 0x6e, 0x61, 0x62, 0x6c, 0x65, 0x64, 0x12, 0x14, + 0x0a, 0x05, 0x65, 0x72, 0x72, 0x6f, 0x72, 0x18, 0x04, 0x20, 0x01, 0x28, 0x09, 0x52, 0x05, 0x65, + 0x72, 0x72, 0x6f, 0x72, 0x22, 0xd2, 0x02, 0x0a, 0x0a, 0x46, 0x75, 0x6c, 0x6c, 0x53, 0x74, 0x61, + 0x74, 0x75, 0x73, 0x12, 0x41, 0x0a, 0x0f, 0x6d, 0x61, 0x6e, 
0x61, 0x67, 0x65, 0x6d, 0x65, 0x6e, + 0x74, 0x53, 0x74, 0x61, 0x74, 0x65, 0x18, 0x01, 0x20, 0x01, 0x28, 0x0b, 0x32, 0x17, 0x2e, 0x64, + 0x61, 0x65, 0x6d, 0x6f, 0x6e, 0x2e, 0x4d, 0x61, 0x6e, 0x61, 0x67, 0x65, 0x6d, 0x65, 0x6e, 0x74, + 0x53, 0x74, 0x61, 0x74, 0x65, 0x52, 0x0f, 0x6d, 0x61, 0x6e, 0x61, 0x67, 0x65, 0x6d, 0x65, 0x6e, + 0x74, 0x53, 0x74, 0x61, 0x74, 0x65, 0x12, 0x35, 0x0a, 0x0b, 0x73, 0x69, 0x67, 0x6e, 0x61, 0x6c, + 0x53, 0x74, 0x61, 0x74, 0x65, 0x18, 0x02, 0x20, 0x01, 0x28, 0x0b, 0x32, 0x13, 0x2e, 0x64, 0x61, + 0x65, 0x6d, 0x6f, 0x6e, 0x2e, 0x53, 0x69, 0x67, 0x6e, 0x61, 0x6c, 0x53, 0x74, 0x61, 0x74, 0x65, + 0x52, 0x0b, 0x73, 0x69, 0x67, 0x6e, 0x61, 0x6c, 0x53, 0x74, 0x61, 0x74, 0x65, 0x12, 0x3e, 0x0a, + 0x0e, 0x6c, 0x6f, 0x63, 0x61, 0x6c, 0x50, 0x65, 0x65, 0x72, 0x53, 0x74, 0x61, 0x74, 0x65, 0x18, + 0x03, 0x20, 0x01, 0x28, 0x0b, 0x32, 0x16, 0x2e, 0x64, 0x61, 0x65, 0x6d, 0x6f, 0x6e, 0x2e, 0x4c, + 0x6f, 0x63, 0x61, 0x6c, 0x50, 0x65, 0x65, 0x72, 0x53, 0x74, 0x61, 0x74, 0x65, 0x52, 0x0e, 0x6c, + 0x6f, 0x63, 0x61, 0x6c, 0x50, 0x65, 0x65, 0x72, 0x53, 0x74, 0x61, 0x74, 0x65, 0x12, 0x27, 0x0a, + 0x05, 0x70, 0x65, 0x65, 0x72, 0x73, 0x18, 0x04, 0x20, 0x03, 0x28, 0x0b, 0x32, 0x11, 0x2e, 0x64, + 0x61, 0x65, 0x6d, 0x6f, 0x6e, 0x2e, 0x50, 0x65, 0x65, 0x72, 0x53, 0x74, 0x61, 0x74, 0x65, 0x52, + 0x05, 0x70, 0x65, 0x65, 0x72, 0x73, 0x12, 0x2a, 0x0a, 0x06, 0x72, 0x65, 0x6c, 0x61, 0x79, 0x73, + 0x18, 0x05, 0x20, 0x03, 0x28, 0x0b, 0x32, 0x12, 0x2e, 0x64, 0x61, 0x65, 0x6d, 0x6f, 0x6e, 0x2e, + 0x52, 0x65, 0x6c, 0x61, 0x79, 0x53, 0x74, 0x61, 0x74, 0x65, 0x52, 0x06, 0x72, 0x65, 0x6c, 0x61, + 0x79, 0x73, 0x12, 0x35, 0x0a, 0x0b, 0x64, 0x6e, 0x73, 0x5f, 0x73, 0x65, 0x72, 0x76, 0x65, 0x72, + 0x73, 0x18, 0x06, 0x20, 0x03, 0x28, 0x0b, 0x32, 0x14, 0x2e, 0x64, 0x61, 0x65, 0x6d, 0x6f, 0x6e, + 0x2e, 0x4e, 0x53, 0x47, 0x72, 0x6f, 0x75, 0x70, 0x53, 0x74, 0x61, 0x74, 0x65, 0x52, 0x0a, 0x64, + 0x6e, 0x73, 0x53, 0x65, 0x72, 0x76, 0x65, 0x72, 0x73, 0x22, 0x15, 0x0a, 0x13, 0x4c, 0x69, 0x73, + 0x74, 0x4e, 0x65, 0x74, 0x77, 0x6f, 0x72, 0x6b, 0x73, 0x52, 0x65, 0x71, 0x75, 0x65, 0x73, 0x74, + 0x22, 0x3f, 0x0a, 0x14, 0x4c, 0x69, 0x73, 0x74, 0x4e, 0x65, 0x74, 0x77, 0x6f, 0x72, 0x6b, 0x73, + 0x52, 0x65, 0x73, 0x70, 0x6f, 0x6e, 0x73, 0x65, 0x12, 0x27, 0x0a, 0x06, 0x72, 0x6f, 0x75, 0x74, + 0x65, 0x73, 0x18, 0x01, 0x20, 0x03, 0x28, 0x0b, 0x32, 0x0f, 0x2e, 0x64, 0x61, 0x65, 0x6d, 0x6f, + 0x6e, 0x2e, 0x4e, 0x65, 0x74, 0x77, 0x6f, 0x72, 0x6b, 0x52, 0x06, 0x72, 0x6f, 0x75, 0x74, 0x65, + 0x73, 0x22, 0x61, 0x0a, 0x15, 0x53, 0x65, 0x6c, 0x65, 0x63, 0x74, 0x4e, 0x65, 0x74, 0x77, 0x6f, + 0x72, 0x6b, 0x73, 0x52, 0x65, 0x71, 0x75, 0x65, 0x73, 0x74, 0x12, 0x1e, 0x0a, 0x0a, 0x6e, 0x65, + 0x74, 0x77, 0x6f, 0x72, 0x6b, 0x49, 0x44, 0x73, 0x18, 0x01, 0x20, 0x03, 0x28, 0x09, 0x52, 0x0a, + 0x6e, 0x65, 0x74, 0x77, 0x6f, 0x72, 0x6b, 0x49, 0x44, 0x73, 0x12, 0x16, 0x0a, 0x06, 0x61, 0x70, + 0x70, 0x65, 0x6e, 0x64, 0x18, 0x02, 0x20, 0x01, 0x28, 0x08, 0x52, 0x06, 0x61, 0x70, 0x70, 0x65, + 0x6e, 0x64, 0x12, 0x10, 0x0a, 0x03, 0x61, 0x6c, 0x6c, 0x18, 0x03, 0x20, 0x01, 0x28, 0x08, 0x52, + 0x03, 0x61, 0x6c, 0x6c, 0x22, 0x18, 0x0a, 0x16, 0x53, 0x65, 0x6c, 0x65, 0x63, 0x74, 0x4e, 0x65, + 0x74, 0x77, 0x6f, 0x72, 0x6b, 0x73, 0x52, 0x65, 0x73, 0x70, 0x6f, 0x6e, 0x73, 0x65, 0x22, 0x1a, + 0x0a, 0x06, 0x49, 0x50, 0x4c, 0x69, 0x73, 0x74, 0x12, 0x10, 0x0a, 0x03, 0x69, 0x70, 0x73, 0x18, + 0x01, 0x20, 0x03, 0x28, 0x09, 0x52, 0x03, 0x69, 0x70, 0x73, 0x22, 0xf9, 0x01, 0x0a, 0x07, 0x4e, + 0x65, 0x74, 0x77, 0x6f, 0x72, 0x6b, 0x12, 0x0e, 0x0a, 0x02, 0x49, 0x44, 0x18, 0x01, 
0x20, 0x01, + 0x28, 0x09, 0x52, 0x02, 0x49, 0x44, 0x12, 0x14, 0x0a, 0x05, 0x72, 0x61, 0x6e, 0x67, 0x65, 0x18, + 0x02, 0x20, 0x01, 0x28, 0x09, 0x52, 0x05, 0x72, 0x61, 0x6e, 0x67, 0x65, 0x12, 0x1a, 0x0a, 0x08, + 0x73, 0x65, 0x6c, 0x65, 0x63, 0x74, 0x65, 0x64, 0x18, 0x03, 0x20, 0x01, 0x28, 0x08, 0x52, 0x08, + 0x73, 0x65, 0x6c, 0x65, 0x63, 0x74, 0x65, 0x64, 0x12, 0x18, 0x0a, 0x07, 0x64, 0x6f, 0x6d, 0x61, + 0x69, 0x6e, 0x73, 0x18, 0x04, 0x20, 0x03, 0x28, 0x09, 0x52, 0x07, 0x64, 0x6f, 0x6d, 0x61, 0x69, + 0x6e, 0x73, 0x12, 0x42, 0x0a, 0x0b, 0x72, 0x65, 0x73, 0x6f, 0x6c, 0x76, 0x65, 0x64, 0x49, 0x50, + 0x73, 0x18, 0x05, 0x20, 0x03, 0x28, 0x0b, 0x32, 0x20, 0x2e, 0x64, 0x61, 0x65, 0x6d, 0x6f, 0x6e, + 0x2e, 0x4e, 0x65, 0x74, 0x77, 0x6f, 0x72, 0x6b, 0x2e, 0x52, 0x65, 0x73, 0x6f, 0x6c, 0x76, 0x65, + 0x64, 0x49, 0x50, 0x73, 0x45, 0x6e, 0x74, 0x72, 0x79, 0x52, 0x0b, 0x72, 0x65, 0x73, 0x6f, 0x6c, + 0x76, 0x65, 0x64, 0x49, 0x50, 0x73, 0x1a, 0x4e, 0x0a, 0x10, 0x52, 0x65, 0x73, 0x6f, 0x6c, 0x76, + 0x65, 0x64, 0x49, 0x50, 0x73, 0x45, 0x6e, 0x74, 0x72, 0x79, 0x12, 0x10, 0x0a, 0x03, 0x6b, 0x65, + 0x79, 0x18, 0x01, 0x20, 0x01, 0x28, 0x09, 0x52, 0x03, 0x6b, 0x65, 0x79, 0x12, 0x24, 0x0a, 0x05, + 0x76, 0x61, 0x6c, 0x75, 0x65, 0x18, 0x02, 0x20, 0x01, 0x28, 0x0b, 0x32, 0x0e, 0x2e, 0x64, 0x61, + 0x65, 0x6d, 0x6f, 0x6e, 0x2e, 0x49, 0x50, 0x4c, 0x69, 0x73, 0x74, 0x52, 0x05, 0x76, 0x61, 0x6c, + 0x75, 0x65, 0x3a, 0x02, 0x38, 0x01, 0x22, 0x6a, 0x0a, 0x12, 0x44, 0x65, 0x62, 0x75, 0x67, 0x42, + 0x75, 0x6e, 0x64, 0x6c, 0x65, 0x52, 0x65, 0x71, 0x75, 0x65, 0x73, 0x74, 0x12, 0x1c, 0x0a, 0x09, + 0x61, 0x6e, 0x6f, 0x6e, 0x79, 0x6d, 0x69, 0x7a, 0x65, 0x18, 0x01, 0x20, 0x01, 0x28, 0x08, 0x52, + 0x09, 0x61, 0x6e, 0x6f, 0x6e, 0x79, 0x6d, 0x69, 0x7a, 0x65, 0x12, 0x16, 0x0a, 0x06, 0x73, 0x74, + 0x61, 0x74, 0x75, 0x73, 0x18, 0x02, 0x20, 0x01, 0x28, 0x09, 0x52, 0x06, 0x73, 0x74, 0x61, 0x74, + 0x75, 0x73, 0x12, 0x1e, 0x0a, 0x0a, 0x73, 0x79, 0x73, 0x74, 0x65, 0x6d, 0x49, 0x6e, 0x66, 0x6f, + 0x18, 0x03, 0x20, 0x01, 0x28, 0x08, 0x52, 0x0a, 0x73, 0x79, 0x73, 0x74, 0x65, 0x6d, 0x49, 0x6e, + 0x66, 0x6f, 0x22, 0x29, 0x0a, 0x13, 0x44, 0x65, 0x62, 0x75, 0x67, 0x42, 0x75, 0x6e, 0x64, 0x6c, + 0x65, 0x52, 0x65, 0x73, 0x70, 0x6f, 0x6e, 0x73, 0x65, 0x12, 0x12, 0x0a, 0x04, 0x70, 0x61, 0x74, + 0x68, 0x18, 0x01, 0x20, 0x01, 0x28, 0x09, 0x52, 0x04, 0x70, 0x61, 0x74, 0x68, 0x22, 0x14, 0x0a, + 0x12, 0x47, 0x65, 0x74, 0x4c, 0x6f, 0x67, 0x4c, 0x65, 0x76, 0x65, 0x6c, 0x52, 0x65, 0x71, 0x75, + 0x65, 0x73, 0x74, 0x22, 0x3d, 0x0a, 0x13, 0x47, 0x65, 0x74, 0x4c, 0x6f, 0x67, 0x4c, 0x65, 0x76, + 0x65, 0x6c, 0x52, 0x65, 0x73, 0x70, 0x6f, 0x6e, 0x73, 0x65, 0x12, 0x26, 0x0a, 0x05, 0x6c, 0x65, + 0x76, 0x65, 0x6c, 0x18, 0x01, 0x20, 0x01, 0x28, 0x0e, 0x32, 0x10, 0x2e, 0x64, 0x61, 0x65, 0x6d, + 0x6f, 0x6e, 0x2e, 0x4c, 0x6f, 0x67, 0x4c, 0x65, 0x76, 0x65, 0x6c, 0x52, 0x05, 0x6c, 0x65, 0x76, + 0x65, 0x6c, 0x22, 0x3c, 0x0a, 0x12, 0x53, 0x65, 0x74, 0x4c, 0x6f, 0x67, 0x4c, 0x65, 0x76, 0x65, + 0x6c, 0x52, 0x65, 0x71, 0x75, 0x65, 0x73, 0x74, 0x12, 0x26, 0x0a, 0x05, 0x6c, 0x65, 0x76, 0x65, + 0x6c, 0x18, 0x01, 0x20, 0x01, 0x28, 0x0e, 0x32, 0x10, 0x2e, 0x64, 0x61, 0x65, 0x6d, 0x6f, 0x6e, + 0x2e, 0x4c, 0x6f, 0x67, 0x4c, 0x65, 0x76, 0x65, 0x6c, 0x52, 0x05, 0x6c, 0x65, 0x76, 0x65, 0x6c, + 0x22, 0x15, 0x0a, 0x13, 0x53, 0x65, 0x74, 0x4c, 0x6f, 0x67, 0x4c, 0x65, 0x76, 0x65, 0x6c, 0x52, + 0x65, 0x73, 0x70, 0x6f, 0x6e, 0x73, 0x65, 0x22, 0x1b, 0x0a, 0x05, 0x53, 0x74, 0x61, 0x74, 0x65, + 0x12, 0x12, 0x0a, 0x04, 0x6e, 0x61, 0x6d, 0x65, 0x18, 0x01, 0x20, 0x01, 0x28, 0x09, 0x52, 0x04, + 0x6e, 0x61, 
0x6d, 0x65, 0x22, 0x13, 0x0a, 0x11, 0x4c, 0x69, 0x73, 0x74, 0x53, 0x74, 0x61, 0x74, + 0x65, 0x73, 0x52, 0x65, 0x71, 0x75, 0x65, 0x73, 0x74, 0x22, 0x3b, 0x0a, 0x12, 0x4c, 0x69, 0x73, + 0x74, 0x53, 0x74, 0x61, 0x74, 0x65, 0x73, 0x52, 0x65, 0x73, 0x70, 0x6f, 0x6e, 0x73, 0x65, 0x12, + 0x25, 0x0a, 0x06, 0x73, 0x74, 0x61, 0x74, 0x65, 0x73, 0x18, 0x01, 0x20, 0x03, 0x28, 0x0b, 0x32, + 0x0d, 0x2e, 0x64, 0x61, 0x65, 0x6d, 0x6f, 0x6e, 0x2e, 0x53, 0x74, 0x61, 0x74, 0x65, 0x52, 0x06, + 0x73, 0x74, 0x61, 0x74, 0x65, 0x73, 0x22, 0x44, 0x0a, 0x11, 0x43, 0x6c, 0x65, 0x61, 0x6e, 0x53, + 0x74, 0x61, 0x74, 0x65, 0x52, 0x65, 0x71, 0x75, 0x65, 0x73, 0x74, 0x12, 0x1d, 0x0a, 0x0a, 0x73, + 0x74, 0x61, 0x74, 0x65, 0x5f, 0x6e, 0x61, 0x6d, 0x65, 0x18, 0x01, 0x20, 0x01, 0x28, 0x09, 0x52, + 0x09, 0x73, 0x74, 0x61, 0x74, 0x65, 0x4e, 0x61, 0x6d, 0x65, 0x12, 0x10, 0x0a, 0x03, 0x61, 0x6c, + 0x6c, 0x18, 0x02, 0x20, 0x01, 0x28, 0x08, 0x52, 0x03, 0x61, 0x6c, 0x6c, 0x22, 0x3b, 0x0a, 0x12, + 0x43, 0x6c, 0x65, 0x61, 0x6e, 0x53, 0x74, 0x61, 0x74, 0x65, 0x52, 0x65, 0x73, 0x70, 0x6f, 0x6e, + 0x73, 0x65, 0x12, 0x25, 0x0a, 0x0e, 0x63, 0x6c, 0x65, 0x61, 0x6e, 0x65, 0x64, 0x5f, 0x73, 0x74, + 0x61, 0x74, 0x65, 0x73, 0x18, 0x01, 0x20, 0x01, 0x28, 0x05, 0x52, 0x0d, 0x63, 0x6c, 0x65, 0x61, + 0x6e, 0x65, 0x64, 0x53, 0x74, 0x61, 0x74, 0x65, 0x73, 0x22, 0x45, 0x0a, 0x12, 0x44, 0x65, 0x6c, + 0x65, 0x74, 0x65, 0x53, 0x74, 0x61, 0x74, 0x65, 0x52, 0x65, 0x71, 0x75, 0x65, 0x73, 0x74, 0x12, + 0x1d, 0x0a, 0x0a, 0x73, 0x74, 0x61, 0x74, 0x65, 0x5f, 0x6e, 0x61, 0x6d, 0x65, 0x18, 0x01, 0x20, + 0x01, 0x28, 0x09, 0x52, 0x09, 0x73, 0x74, 0x61, 0x74, 0x65, 0x4e, 0x61, 0x6d, 0x65, 0x12, 0x10, + 0x0a, 0x03, 0x61, 0x6c, 0x6c, 0x18, 0x02, 0x20, 0x01, 0x28, 0x08, 0x52, 0x03, 0x61, 0x6c, 0x6c, + 0x22, 0x3c, 0x0a, 0x13, 0x44, 0x65, 0x6c, 0x65, 0x74, 0x65, 0x53, 0x74, 0x61, 0x74, 0x65, 0x52, + 0x65, 0x73, 0x70, 0x6f, 0x6e, 0x73, 0x65, 0x12, 0x25, 0x0a, 0x0e, 0x64, 0x65, 0x6c, 0x65, 0x74, + 0x65, 0x64, 0x5f, 0x73, 0x74, 0x61, 0x74, 0x65, 0x73, 0x18, 0x01, 0x20, 0x01, 0x28, 0x05, 0x52, + 0x0d, 0x64, 0x65, 0x6c, 0x65, 0x74, 0x65, 0x64, 0x53, 0x74, 0x61, 0x74, 0x65, 0x73, 0x22, 0x3b, + 0x0a, 0x1f, 0x53, 0x65, 0x74, 0x4e, 0x65, 0x74, 0x77, 0x6f, 0x72, 0x6b, 0x4d, 0x61, 0x70, 0x50, + 0x65, 0x72, 0x73, 0x69, 0x73, 0x74, 0x65, 0x6e, 0x63, 0x65, 0x52, 0x65, 0x71, 0x75, 0x65, 0x73, + 0x74, 0x12, 0x18, 0x0a, 0x07, 0x65, 0x6e, 0x61, 0x62, 0x6c, 0x65, 0x64, 0x18, 0x01, 0x20, 0x01, + 0x28, 0x08, 0x52, 0x07, 0x65, 0x6e, 0x61, 0x62, 0x6c, 0x65, 0x64, 0x22, 0x22, 0x0a, 0x20, 0x53, + 0x65, 0x74, 0x4e, 0x65, 0x74, 0x77, 0x6f, 0x72, 0x6b, 0x4d, 0x61, 0x70, 0x50, 0x65, 0x72, 0x73, + 0x69, 0x73, 0x74, 0x65, 0x6e, 0x63, 0x65, 0x52, 0x65, 0x73, 0x70, 0x6f, 0x6e, 0x73, 0x65, 0x2a, + 0x62, 0x0a, 0x08, 0x4c, 0x6f, 0x67, 0x4c, 0x65, 0x76, 0x65, 0x6c, 0x12, 0x0b, 0x0a, 0x07, 0x55, + 0x4e, 0x4b, 0x4e, 0x4f, 0x57, 0x4e, 0x10, 0x00, 0x12, 0x09, 0x0a, 0x05, 0x50, 0x41, 0x4e, 0x49, + 0x43, 0x10, 0x01, 0x12, 0x09, 0x0a, 0x05, 0x46, 0x41, 0x54, 0x41, 0x4c, 0x10, 0x02, 0x12, 0x09, + 0x0a, 0x05, 0x45, 0x52, 0x52, 0x4f, 0x52, 0x10, 0x03, 0x12, 0x08, 0x0a, 0x04, 0x57, 0x41, 0x52, + 0x4e, 0x10, 0x04, 0x12, 0x08, 0x0a, 0x04, 0x49, 0x4e, 0x46, 0x4f, 0x10, 0x05, 0x12, 0x09, 0x0a, + 0x05, 0x44, 0x45, 0x42, 0x55, 0x47, 0x10, 0x06, 0x12, 0x09, 0x0a, 0x05, 0x54, 0x52, 0x41, 0x43, + 0x45, 0x10, 0x07, 0x32, 0x93, 0x09, 0x0a, 0x0d, 0x44, 0x61, 0x65, 0x6d, 0x6f, 0x6e, 0x53, 0x65, + 0x72, 0x76, 0x69, 0x63, 0x65, 0x12, 0x36, 0x0a, 0x05, 0x4c, 0x6f, 0x67, 0x69, 0x6e, 0x12, 0x14, + 0x2e, 0x64, 0x61, 0x65, 0x6d, 0x6f, 
0x6e, 0x2e, 0x4c, 0x6f, 0x67, 0x69, 0x6e, 0x52, 0x65, 0x71, + 0x75, 0x65, 0x73, 0x74, 0x1a, 0x15, 0x2e, 0x64, 0x61, 0x65, 0x6d, 0x6f, 0x6e, 0x2e, 0x4c, 0x6f, + 0x67, 0x69, 0x6e, 0x52, 0x65, 0x73, 0x70, 0x6f, 0x6e, 0x73, 0x65, 0x22, 0x00, 0x12, 0x4b, 0x0a, + 0x0c, 0x57, 0x61, 0x69, 0x74, 0x53, 0x53, 0x4f, 0x4c, 0x6f, 0x67, 0x69, 0x6e, 0x12, 0x1b, 0x2e, + 0x64, 0x61, 0x65, 0x6d, 0x6f, 0x6e, 0x2e, 0x57, 0x61, 0x69, 0x74, 0x53, 0x53, 0x4f, 0x4c, 0x6f, + 0x67, 0x69, 0x6e, 0x52, 0x65, 0x71, 0x75, 0x65, 0x73, 0x74, 0x1a, 0x1c, 0x2e, 0x64, 0x61, 0x65, + 0x6d, 0x6f, 0x6e, 0x2e, 0x57, 0x61, 0x69, 0x74, 0x53, 0x53, 0x4f, 0x4c, 0x6f, 0x67, 0x69, 0x6e, + 0x52, 0x65, 0x73, 0x70, 0x6f, 0x6e, 0x73, 0x65, 0x22, 0x00, 0x12, 0x2d, 0x0a, 0x02, 0x55, 0x70, + 0x12, 0x11, 0x2e, 0x64, 0x61, 0x65, 0x6d, 0x6f, 0x6e, 0x2e, 0x55, 0x70, 0x52, 0x65, 0x71, 0x75, + 0x65, 0x73, 0x74, 0x1a, 0x12, 0x2e, 0x64, 0x61, 0x65, 0x6d, 0x6f, 0x6e, 0x2e, 0x55, 0x70, 0x52, + 0x65, 0x73, 0x70, 0x6f, 0x6e, 0x73, 0x65, 0x22, 0x00, 0x12, 0x39, 0x0a, 0x06, 0x53, 0x74, 0x61, + 0x74, 0x75, 0x73, 0x12, 0x15, 0x2e, 0x64, 0x61, 0x65, 0x6d, 0x6f, 0x6e, 0x2e, 0x53, 0x74, 0x61, + 0x74, 0x75, 0x73, 0x52, 0x65, 0x71, 0x75, 0x65, 0x73, 0x74, 0x1a, 0x16, 0x2e, 0x64, 0x61, 0x65, + 0x6d, 0x6f, 0x6e, 0x2e, 0x53, 0x74, 0x61, 0x74, 0x75, 0x73, 0x52, 0x65, 0x73, 0x70, 0x6f, 0x6e, + 0x73, 0x65, 0x22, 0x00, 0x12, 0x33, 0x0a, 0x04, 0x44, 0x6f, 0x77, 0x6e, 0x12, 0x13, 0x2e, 0x64, + 0x61, 0x65, 0x6d, 0x6f, 0x6e, 0x2e, 0x44, 0x6f, 0x77, 0x6e, 0x52, 0x65, 0x71, 0x75, 0x65, 0x73, + 0x74, 0x1a, 0x14, 0x2e, 0x64, 0x61, 0x65, 0x6d, 0x6f, 0x6e, 0x2e, 0x44, 0x6f, 0x77, 0x6e, 0x52, + 0x65, 0x73, 0x70, 0x6f, 0x6e, 0x73, 0x65, 0x22, 0x00, 0x12, 0x42, 0x0a, 0x09, 0x47, 0x65, 0x74, + 0x43, 0x6f, 0x6e, 0x66, 0x69, 0x67, 0x12, 0x18, 0x2e, 0x64, 0x61, 0x65, 0x6d, 0x6f, 0x6e, 0x2e, + 0x47, 0x65, 0x74, 0x43, 0x6f, 0x6e, 0x66, 0x69, 0x67, 0x52, 0x65, 0x71, 0x75, 0x65, 0x73, 0x74, + 0x1a, 0x19, 0x2e, 0x64, 0x61, 0x65, 0x6d, 0x6f, 0x6e, 0x2e, 0x47, 0x65, 0x74, 0x43, 0x6f, 0x6e, + 0x66, 0x69, 0x67, 0x52, 0x65, 0x73, 0x70, 0x6f, 0x6e, 0x73, 0x65, 0x22, 0x00, 0x12, 0x4b, 0x0a, + 0x0c, 0x4c, 0x69, 0x73, 0x74, 0x4e, 0x65, 0x74, 0x77, 0x6f, 0x72, 0x6b, 0x73, 0x12, 0x1b, 0x2e, + 0x64, 0x61, 0x65, 0x6d, 0x6f, 0x6e, 0x2e, 0x4c, 0x69, 0x73, 0x74, 0x4e, 0x65, 0x74, 0x77, 0x6f, + 0x72, 0x6b, 0x73, 0x52, 0x65, 0x71, 0x75, 0x65, 0x73, 0x74, 0x1a, 0x1c, 0x2e, 0x64, 0x61, 0x65, + 0x6d, 0x6f, 0x6e, 0x2e, 0x4c, 0x69, 0x73, 0x74, 0x4e, 0x65, 0x74, 0x77, 0x6f, 0x72, 0x6b, 0x73, + 0x52, 0x65, 0x73, 0x70, 0x6f, 0x6e, 0x73, 0x65, 0x22, 0x00, 0x12, 0x51, 0x0a, 0x0e, 0x53, 0x65, + 0x6c, 0x65, 0x63, 0x74, 0x4e, 0x65, 0x74, 0x77, 0x6f, 0x72, 0x6b, 0x73, 0x12, 0x1d, 0x2e, 0x64, + 0x61, 0x65, 0x6d, 0x6f, 0x6e, 0x2e, 0x53, 0x65, 0x6c, 0x65, 0x63, 0x74, 0x4e, 0x65, 0x74, 0x77, + 0x6f, 0x72, 0x6b, 0x73, 0x52, 0x65, 0x71, 0x75, 0x65, 0x73, 0x74, 0x1a, 0x1e, 0x2e, 0x64, 0x61, 0x65, 0x6d, 0x6f, 0x6e, 0x2e, 0x53, 0x65, 0x6c, 0x65, 0x63, 0x74, 0x4e, 0x65, 0x74, 0x77, 0x6f, - 0x72, 0x6b, 0x73, 0x52, 0x65, 0x71, 0x75, 0x65, 0x73, 0x74, 0x1a, 0x1e, 0x2e, 0x64, 0x61, 0x65, - 0x6d, 0x6f, 0x6e, 0x2e, 0x53, 0x65, 0x6c, 0x65, 0x63, 0x74, 0x4e, 0x65, 0x74, 0x77, 0x6f, 0x72, - 0x6b, 0x73, 0x52, 0x65, 0x73, 0x70, 0x6f, 0x6e, 0x73, 0x65, 0x22, 0x00, 0x12, 0x53, 0x0a, 0x10, - 0x44, 0x65, 0x73, 0x65, 0x6c, 0x65, 0x63, 0x74, 0x4e, 0x65, 0x74, 0x77, 0x6f, 0x72, 0x6b, 0x73, - 0x12, 0x1d, 0x2e, 0x64, 0x61, 0x65, 0x6d, 0x6f, 0x6e, 0x2e, 0x53, 0x65, 0x6c, 0x65, 0x63, 0x74, - 0x4e, 0x65, 0x74, 0x77, 0x6f, 0x72, 0x6b, 0x73, 0x52, 0x65, 
0x71, 0x75, 0x65, 0x73, 0x74, 0x1a, - 0x1e, 0x2e, 0x64, 0x61, 0x65, 0x6d, 0x6f, 0x6e, 0x2e, 0x53, 0x65, 0x6c, 0x65, 0x63, 0x74, 0x4e, - 0x65, 0x74, 0x77, 0x6f, 0x72, 0x6b, 0x73, 0x52, 0x65, 0x73, 0x70, 0x6f, 0x6e, 0x73, 0x65, 0x22, - 0x00, 0x12, 0x48, 0x0a, 0x0b, 0x44, 0x65, 0x62, 0x75, 0x67, 0x42, 0x75, 0x6e, 0x64, 0x6c, 0x65, - 0x12, 0x1a, 0x2e, 0x64, 0x61, 0x65, 0x6d, 0x6f, 0x6e, 0x2e, 0x44, 0x65, 0x62, 0x75, 0x67, 0x42, - 0x75, 0x6e, 0x64, 0x6c, 0x65, 0x52, 0x65, 0x71, 0x75, 0x65, 0x73, 0x74, 0x1a, 0x1b, 0x2e, 0x64, - 0x61, 0x65, 0x6d, 0x6f, 0x6e, 0x2e, 0x44, 0x65, 0x62, 0x75, 0x67, 0x42, 0x75, 0x6e, 0x64, 0x6c, - 0x65, 0x52, 0x65, 0x73, 0x70, 0x6f, 0x6e, 0x73, 0x65, 0x22, 0x00, 0x12, 0x48, 0x0a, 0x0b, 0x47, - 0x65, 0x74, 0x4c, 0x6f, 0x67, 0x4c, 0x65, 0x76, 0x65, 0x6c, 0x12, 0x1a, 0x2e, 0x64, 0x61, 0x65, - 0x6d, 0x6f, 0x6e, 0x2e, 0x47, 0x65, 0x74, 0x4c, 0x6f, 0x67, 0x4c, 0x65, 0x76, 0x65, 0x6c, 0x52, - 0x65, 0x71, 0x75, 0x65, 0x73, 0x74, 0x1a, 0x1b, 0x2e, 0x64, 0x61, 0x65, 0x6d, 0x6f, 0x6e, 0x2e, - 0x47, 0x65, 0x74, 0x4c, 0x6f, 0x67, 0x4c, 0x65, 0x76, 0x65, 0x6c, 0x52, 0x65, 0x73, 0x70, 0x6f, - 0x6e, 0x73, 0x65, 0x22, 0x00, 0x12, 0x48, 0x0a, 0x0b, 0x53, 0x65, 0x74, 0x4c, 0x6f, 0x67, 0x4c, - 0x65, 0x76, 0x65, 0x6c, 0x12, 0x1a, 0x2e, 0x64, 0x61, 0x65, 0x6d, 0x6f, 0x6e, 0x2e, 0x53, 0x65, - 0x74, 0x4c, 0x6f, 0x67, 0x4c, 0x65, 0x76, 0x65, 0x6c, 0x52, 0x65, 0x71, 0x75, 0x65, 0x73, 0x74, - 0x1a, 0x1b, 0x2e, 0x64, 0x61, 0x65, 0x6d, 0x6f, 0x6e, 0x2e, 0x53, 0x65, 0x74, 0x4c, 0x6f, 0x67, - 0x4c, 0x65, 0x76, 0x65, 0x6c, 0x52, 0x65, 0x73, 0x70, 0x6f, 0x6e, 0x73, 0x65, 0x22, 0x00, 0x12, - 0x45, 0x0a, 0x0a, 0x4c, 0x69, 0x73, 0x74, 0x53, 0x74, 0x61, 0x74, 0x65, 0x73, 0x12, 0x19, 0x2e, - 0x64, 0x61, 0x65, 0x6d, 0x6f, 0x6e, 0x2e, 0x4c, 0x69, 0x73, 0x74, 0x53, 0x74, 0x61, 0x74, 0x65, - 0x73, 0x52, 0x65, 0x71, 0x75, 0x65, 0x73, 0x74, 0x1a, 0x1a, 0x2e, 0x64, 0x61, 0x65, 0x6d, 0x6f, - 0x6e, 0x2e, 0x4c, 0x69, 0x73, 0x74, 0x53, 0x74, 0x61, 0x74, 0x65, 0x73, 0x52, 0x65, 0x73, 0x70, - 0x6f, 0x6e, 0x73, 0x65, 0x22, 0x00, 0x12, 0x45, 0x0a, 0x0a, 0x43, 0x6c, 0x65, 0x61, 0x6e, 0x53, - 0x74, 0x61, 0x74, 0x65, 0x12, 0x19, 0x2e, 0x64, 0x61, 0x65, 0x6d, 0x6f, 0x6e, 0x2e, 0x43, 0x6c, - 0x65, 0x61, 0x6e, 0x53, 0x74, 0x61, 0x74, 0x65, 0x52, 0x65, 0x71, 0x75, 0x65, 0x73, 0x74, 0x1a, - 0x1a, 0x2e, 0x64, 0x61, 0x65, 0x6d, 0x6f, 0x6e, 0x2e, 0x43, 0x6c, 0x65, 0x61, 0x6e, 0x53, 0x74, - 0x61, 0x74, 0x65, 0x52, 0x65, 0x73, 0x70, 0x6f, 0x6e, 0x73, 0x65, 0x22, 0x00, 0x12, 0x48, 0x0a, - 0x0b, 0x44, 0x65, 0x6c, 0x65, 0x74, 0x65, 0x53, 0x74, 0x61, 0x74, 0x65, 0x12, 0x1a, 0x2e, 0x64, - 0x61, 0x65, 0x6d, 0x6f, 0x6e, 0x2e, 0x44, 0x65, 0x6c, 0x65, 0x74, 0x65, 0x53, 0x74, 0x61, 0x74, - 0x65, 0x52, 0x65, 0x71, 0x75, 0x65, 0x73, 0x74, 0x1a, 0x1b, 0x2e, 0x64, 0x61, 0x65, 0x6d, 0x6f, - 0x6e, 0x2e, 0x44, 0x65, 0x6c, 0x65, 0x74, 0x65, 0x53, 0x74, 0x61, 0x74, 0x65, 0x52, 0x65, 0x73, - 0x70, 0x6f, 0x6e, 0x73, 0x65, 0x22, 0x00, 0x12, 0x6f, 0x0a, 0x18, 0x53, 0x65, 0x74, 0x4e, 0x65, - 0x74, 0x77, 0x6f, 0x72, 0x6b, 0x4d, 0x61, 0x70, 0x50, 0x65, 0x72, 0x73, 0x69, 0x73, 0x74, 0x65, - 0x6e, 0x63, 0x65, 0x12, 0x27, 0x2e, 0x64, 0x61, 0x65, 0x6d, 0x6f, 0x6e, 0x2e, 0x53, 0x65, 0x74, - 0x4e, 0x65, 0x74, 0x77, 0x6f, 0x72, 0x6b, 0x4d, 0x61, 0x70, 0x50, 0x65, 0x72, 0x73, 0x69, 0x73, - 0x74, 0x65, 0x6e, 0x63, 0x65, 0x52, 0x65, 0x71, 0x75, 0x65, 0x73, 0x74, 0x1a, 0x28, 0x2e, 0x64, - 0x61, 0x65, 0x6d, 0x6f, 0x6e, 0x2e, 0x53, 0x65, 0x74, 0x4e, 0x65, 0x74, 0x77, 0x6f, 0x72, 0x6b, - 0x4d, 0x61, 0x70, 0x50, 0x65, 0x72, 0x73, 0x69, 0x73, 0x74, 0x65, 0x6e, 0x63, 0x65, 
0x52, 0x65, - 0x73, 0x70, 0x6f, 0x6e, 0x73, 0x65, 0x22, 0x00, 0x42, 0x08, 0x5a, 0x06, 0x2f, 0x70, 0x72, 0x6f, - 0x74, 0x6f, 0x62, 0x06, 0x70, 0x72, 0x6f, 0x74, 0x6f, 0x33, + 0x72, 0x6b, 0x73, 0x52, 0x65, 0x73, 0x70, 0x6f, 0x6e, 0x73, 0x65, 0x22, 0x00, 0x12, 0x53, 0x0a, + 0x10, 0x44, 0x65, 0x73, 0x65, 0x6c, 0x65, 0x63, 0x74, 0x4e, 0x65, 0x74, 0x77, 0x6f, 0x72, 0x6b, + 0x73, 0x12, 0x1d, 0x2e, 0x64, 0x61, 0x65, 0x6d, 0x6f, 0x6e, 0x2e, 0x53, 0x65, 0x6c, 0x65, 0x63, + 0x74, 0x4e, 0x65, 0x74, 0x77, 0x6f, 0x72, 0x6b, 0x73, 0x52, 0x65, 0x71, 0x75, 0x65, 0x73, 0x74, + 0x1a, 0x1e, 0x2e, 0x64, 0x61, 0x65, 0x6d, 0x6f, 0x6e, 0x2e, 0x53, 0x65, 0x6c, 0x65, 0x63, 0x74, + 0x4e, 0x65, 0x74, 0x77, 0x6f, 0x72, 0x6b, 0x73, 0x52, 0x65, 0x73, 0x70, 0x6f, 0x6e, 0x73, 0x65, + 0x22, 0x00, 0x12, 0x48, 0x0a, 0x0b, 0x44, 0x65, 0x62, 0x75, 0x67, 0x42, 0x75, 0x6e, 0x64, 0x6c, + 0x65, 0x12, 0x1a, 0x2e, 0x64, 0x61, 0x65, 0x6d, 0x6f, 0x6e, 0x2e, 0x44, 0x65, 0x62, 0x75, 0x67, + 0x42, 0x75, 0x6e, 0x64, 0x6c, 0x65, 0x52, 0x65, 0x71, 0x75, 0x65, 0x73, 0x74, 0x1a, 0x1b, 0x2e, + 0x64, 0x61, 0x65, 0x6d, 0x6f, 0x6e, 0x2e, 0x44, 0x65, 0x62, 0x75, 0x67, 0x42, 0x75, 0x6e, 0x64, + 0x6c, 0x65, 0x52, 0x65, 0x73, 0x70, 0x6f, 0x6e, 0x73, 0x65, 0x22, 0x00, 0x12, 0x48, 0x0a, 0x0b, + 0x47, 0x65, 0x74, 0x4c, 0x6f, 0x67, 0x4c, 0x65, 0x76, 0x65, 0x6c, 0x12, 0x1a, 0x2e, 0x64, 0x61, + 0x65, 0x6d, 0x6f, 0x6e, 0x2e, 0x47, 0x65, 0x74, 0x4c, 0x6f, 0x67, 0x4c, 0x65, 0x76, 0x65, 0x6c, + 0x52, 0x65, 0x71, 0x75, 0x65, 0x73, 0x74, 0x1a, 0x1b, 0x2e, 0x64, 0x61, 0x65, 0x6d, 0x6f, 0x6e, + 0x2e, 0x47, 0x65, 0x74, 0x4c, 0x6f, 0x67, 0x4c, 0x65, 0x76, 0x65, 0x6c, 0x52, 0x65, 0x73, 0x70, + 0x6f, 0x6e, 0x73, 0x65, 0x22, 0x00, 0x12, 0x48, 0x0a, 0x0b, 0x53, 0x65, 0x74, 0x4c, 0x6f, 0x67, + 0x4c, 0x65, 0x76, 0x65, 0x6c, 0x12, 0x1a, 0x2e, 0x64, 0x61, 0x65, 0x6d, 0x6f, 0x6e, 0x2e, 0x53, + 0x65, 0x74, 0x4c, 0x6f, 0x67, 0x4c, 0x65, 0x76, 0x65, 0x6c, 0x52, 0x65, 0x71, 0x75, 0x65, 0x73, + 0x74, 0x1a, 0x1b, 0x2e, 0x64, 0x61, 0x65, 0x6d, 0x6f, 0x6e, 0x2e, 0x53, 0x65, 0x74, 0x4c, 0x6f, + 0x67, 0x4c, 0x65, 0x76, 0x65, 0x6c, 0x52, 0x65, 0x73, 0x70, 0x6f, 0x6e, 0x73, 0x65, 0x22, 0x00, + 0x12, 0x45, 0x0a, 0x0a, 0x4c, 0x69, 0x73, 0x74, 0x53, 0x74, 0x61, 0x74, 0x65, 0x73, 0x12, 0x19, + 0x2e, 0x64, 0x61, 0x65, 0x6d, 0x6f, 0x6e, 0x2e, 0x4c, 0x69, 0x73, 0x74, 0x53, 0x74, 0x61, 0x74, + 0x65, 0x73, 0x52, 0x65, 0x71, 0x75, 0x65, 0x73, 0x74, 0x1a, 0x1a, 0x2e, 0x64, 0x61, 0x65, 0x6d, + 0x6f, 0x6e, 0x2e, 0x4c, 0x69, 0x73, 0x74, 0x53, 0x74, 0x61, 0x74, 0x65, 0x73, 0x52, 0x65, 0x73, + 0x70, 0x6f, 0x6e, 0x73, 0x65, 0x22, 0x00, 0x12, 0x45, 0x0a, 0x0a, 0x43, 0x6c, 0x65, 0x61, 0x6e, + 0x53, 0x74, 0x61, 0x74, 0x65, 0x12, 0x19, 0x2e, 0x64, 0x61, 0x65, 0x6d, 0x6f, 0x6e, 0x2e, 0x43, + 0x6c, 0x65, 0x61, 0x6e, 0x53, 0x74, 0x61, 0x74, 0x65, 0x52, 0x65, 0x71, 0x75, 0x65, 0x73, 0x74, + 0x1a, 0x1a, 0x2e, 0x64, 0x61, 0x65, 0x6d, 0x6f, 0x6e, 0x2e, 0x43, 0x6c, 0x65, 0x61, 0x6e, 0x53, + 0x74, 0x61, 0x74, 0x65, 0x52, 0x65, 0x73, 0x70, 0x6f, 0x6e, 0x73, 0x65, 0x22, 0x00, 0x12, 0x48, + 0x0a, 0x0b, 0x44, 0x65, 0x6c, 0x65, 0x74, 0x65, 0x53, 0x74, 0x61, 0x74, 0x65, 0x12, 0x1a, 0x2e, + 0x64, 0x61, 0x65, 0x6d, 0x6f, 0x6e, 0x2e, 0x44, 0x65, 0x6c, 0x65, 0x74, 0x65, 0x53, 0x74, 0x61, + 0x74, 0x65, 0x52, 0x65, 0x71, 0x75, 0x65, 0x73, 0x74, 0x1a, 0x1b, 0x2e, 0x64, 0x61, 0x65, 0x6d, + 0x6f, 0x6e, 0x2e, 0x44, 0x65, 0x6c, 0x65, 0x74, 0x65, 0x53, 0x74, 0x61, 0x74, 0x65, 0x52, 0x65, + 0x73, 0x70, 0x6f, 0x6e, 0x73, 0x65, 0x22, 0x00, 0x12, 0x6f, 0x0a, 0x18, 0x53, 0x65, 0x74, 0x4e, + 0x65, 0x74, 0x77, 0x6f, 0x72, 0x6b, 0x4d, 0x61, 
0x70, 0x50, 0x65, 0x72, 0x73, 0x69, 0x73, 0x74, + 0x65, 0x6e, 0x63, 0x65, 0x12, 0x27, 0x2e, 0x64, 0x61, 0x65, 0x6d, 0x6f, 0x6e, 0x2e, 0x53, 0x65, + 0x74, 0x4e, 0x65, 0x74, 0x77, 0x6f, 0x72, 0x6b, 0x4d, 0x61, 0x70, 0x50, 0x65, 0x72, 0x73, 0x69, + 0x73, 0x74, 0x65, 0x6e, 0x63, 0x65, 0x52, 0x65, 0x71, 0x75, 0x65, 0x73, 0x74, 0x1a, 0x28, 0x2e, + 0x64, 0x61, 0x65, 0x6d, 0x6f, 0x6e, 0x2e, 0x53, 0x65, 0x74, 0x4e, 0x65, 0x74, 0x77, 0x6f, 0x72, + 0x6b, 0x4d, 0x61, 0x70, 0x50, 0x65, 0x72, 0x73, 0x69, 0x73, 0x74, 0x65, 0x6e, 0x63, 0x65, 0x52, + 0x65, 0x73, 0x70, 0x6f, 0x6e, 0x73, 0x65, 0x22, 0x00, 0x42, 0x08, 0x5a, 0x06, 0x2f, 0x70, 0x72, + 0x6f, 0x74, 0x6f, 0x62, 0x06, 0x70, 0x72, 0x6f, 0x74, 0x6f, 0x33, } var ( diff --git a/client/proto/daemon.proto b/client/proto/daemon.proto index cddf78242..ad3a4bc1a 100644 --- a/client/proto/daemon.proto +++ b/client/proto/daemon.proto @@ -107,6 +107,11 @@ message LoginRequest { optional bool networkMonitor = 18; optional google.protobuf.Duration dnsRouteInterval = 19; + + optional bool disable_client_routes = 20; + optional bool disable_server_routes = 21; + optional bool disable_dns = 22; + optional bool disable_firewall = 23; } message LoginResponse { diff --git a/client/server/server.go b/client/server/server.go index cba5a1148..70d19bfab 100644 --- a/client/server/server.go +++ b/client/server/server.go @@ -399,6 +399,23 @@ func (s *Server) Login(callerCtx context.Context, msg *proto.LoginRequest) (*pro s.latestConfigInput.DNSRouteInterval = &duration } + if msg.DisableClientRoutes != nil { + inputConfig.DisableClientRoutes = msg.DisableClientRoutes + s.latestConfigInput.DisableClientRoutes = msg.DisableClientRoutes + } + if msg.DisableServerRoutes != nil { + inputConfig.DisableServerRoutes = msg.DisableServerRoutes + s.latestConfigInput.DisableServerRoutes = msg.DisableServerRoutes + } + if msg.DisableDns != nil { + inputConfig.DisableDNS = msg.DisableDns + s.latestConfigInput.DisableDNS = msg.DisableDns + } + if msg.DisableFirewall != nil { + inputConfig.DisableFirewall = msg.DisableFirewall + s.latestConfigInput.DisableFirewall = msg.DisableFirewall + } + s.mutex.Unlock() if msg.OptionalPreSharedKey != nil { From 9e6e34b42d5125f3a7287eb83f31e950d8de678a Mon Sep 17 00:00:00 2001 From: Simon Smith Date: Wed, 8 Jan 2025 10:48:10 +0000 Subject: [PATCH 116/159] [misc] Upgrade go to 1.23 inn devcontainer (#3160) --- .devcontainer/Dockerfile | 2 +- .devcontainer/devcontainer.json | 2 +- 2 files changed, 2 insertions(+), 2 deletions(-) diff --git a/.devcontainer/Dockerfile b/.devcontainer/Dockerfile index 447164a9b..4697acf20 100644 --- a/.devcontainer/Dockerfile +++ b/.devcontainer/Dockerfile @@ -1,4 +1,4 @@ -FROM golang:1.21-bullseye +FROM golang:1.23-bullseye RUN apt-get update && export DEBIAN_FRONTEND=noninteractive \ && apt-get -y install --no-install-recommends\ diff --git a/.devcontainer/devcontainer.json b/.devcontainer/devcontainer.json index ef883d31a..97aad75ad 100644 --- a/.devcontainer/devcontainer.json +++ b/.devcontainer/devcontainer.json @@ -7,7 +7,7 @@ "features": { "ghcr.io/devcontainers/features/docker-in-docker:2": {}, "ghcr.io/devcontainers/features/go:1": { - "version": "1.21" + "version": "1.23" } }, "workspaceFolder": "/workspaces/${localWorkspaceFolderBasename}", From 409003b4f96909c618cfe3660fe2da0656ab4c8a Mon Sep 17 00:00:00 2001 From: Bethuel Mmbaga Date: Wed, 8 Jan 2025 19:35:57 +0300 Subject: [PATCH 117/159] [management] Add support for disabling resources and routing peers in networks (#3154) * sync openapi changes 
Signed-off-by: bcmmbaga * add option to disable network resource(s) Signed-off-by: bcmmbaga * add network resource enabled state from api Signed-off-by: bcmmbaga * fix tests Signed-off-by: bcmmbaga * add option to disable network router(s) Signed-off-by: bcmmbaga * fix tests Signed-off-by: bcmmbaga * Add tests Signed-off-by: bcmmbaga * migrate old network resources and routers Signed-off-by: bcmmbaga --------- Signed-off-by: bcmmbaga --- management/server/http/api/openapi.yml | 19 +++++-- management/server/http/api/types.gen.go | 15 ++++++ management/server/migration/migration.go | 50 +++++++++++++++++++ .../server/networks/resources/manager.go | 2 +- .../networks/resources/types/resource.go | 9 +++- .../server/networks/routers/manager_test.go | 8 +-- .../server/networks/routers/types/router.go | 7 ++- .../networks/routers/types/router_test.go | 11 +++- management/server/route_test.go | 12 +++++ management/server/store/sql_store_test.go | 4 +- management/server/store/store.go | 6 +++ management/server/types/account.go | 12 +++++ management/server/types/account_test.go | 47 +++++++++++++++++ 13 files changed, 188 insertions(+), 14 deletions(-) diff --git a/management/server/http/api/openapi.yml b/management/server/http/api/openapi.yml index 6c1d6b424..f53092415 100644 --- a/management/server/http/api/openapi.yml +++ b/management/server/http/api/openapi.yml @@ -800,15 +800,18 @@ components: items: type: string example: "ch8i4ug6lnn4g9hqv797" + sourceResource: + description: Policy rule source resource that the rule is applied to + $ref: '#/components/schemas/Resource' destinations: description: Policy rule destination group IDs type: array items: type: string example: "ch8i4ug6lnn4g9h7v7m0" - required: - - sources - - destinations + destinationResource: + description: Policy rule destination resource that the rule is applied to + $ref: '#/components/schemas/Resource' PolicyRuleCreate: allOf: @@ -1325,9 +1328,14 @@ components: description: Network resource address (either a direct host like 1.1.1.1 or 1.1.1.1/32, or a subnet like 192.168.178.0/24, or domains like example.com and *.example.com) type: string example: "1.1.1.1" + enabled: + description: Network resource status + type: boolean + example: true required: - name - address + - enabled NetworkResourceRequest: allOf: - $ref: '#/components/schemas/NetworkResourceMinimum' @@ -1390,12 +1398,17 @@ components: description: Indicate if peer should masquerade traffic to this route's prefix type: boolean example: true + enabled: + description: Network router status + type: boolean + example: true required: # Only one property has to be set #- peer #- peer_groups - metric - masquerade + - enabled NetworkRouter: allOf: - type: object diff --git a/management/server/http/api/types.gen.go b/management/server/http/api/types.gen.go index 83226587f..943d1b327 100644 --- a/management/server/http/api/types.gen.go +++ b/management/server/http/api/types.gen.go @@ -560,6 +560,9 @@ type NetworkResource struct { // Description Network resource description Description *string `json:"description,omitempty"` + // Enabled Network resource status + Enabled bool `json:"enabled"` + // Groups Groups that the resource belongs to Groups []GroupMinimum `json:"groups"` @@ -581,6 +584,9 @@ type NetworkResourceMinimum struct { // Description Network resource description Description *string `json:"description,omitempty"` + // Enabled Network resource status + Enabled bool `json:"enabled"` + // Name Network resource name Name string `json:"name"` } @@ -593,6 +599,9 @@ type 
NetworkResourceRequest struct { // Description Network resource description Description *string `json:"description,omitempty"` + // Enabled Network resource status + Enabled bool `json:"enabled"` + // Groups Group IDs containing the resource Groups []string `json:"groups"` @@ -605,6 +614,9 @@ type NetworkResourceType string // NetworkRouter defines model for NetworkRouter. type NetworkRouter struct { + // Enabled Network router status + Enabled bool `json:"enabled"` + // Id Network Router Id Id string `json:"id"` @@ -623,6 +635,9 @@ type NetworkRouter struct { // NetworkRouterRequest defines model for NetworkRouterRequest. type NetworkRouterRequest struct { + // Enabled Network router status + Enabled bool `json:"enabled"` + // Masquerade Indicate if peer should masquerade traffic to this route's prefix Masquerade bool `json:"masquerade"` diff --git a/management/server/migration/migration.go b/management/server/migration/migration.go index 0a6951736..8986d77b5 100644 --- a/management/server/migration/migration.go +++ b/management/server/migration/migration.go @@ -305,3 +305,53 @@ func hiddenKey(key string, length int) string { } return prefix + strings.Repeat("*", length) } + +func MigrateNewField[T any](ctx context.Context, db *gorm.DB, columnName string, defaultValue any) error { + var model T + + if !db.Migrator().HasTable(&model) { + log.WithContext(ctx).Debugf("Table for %T does not exist, no migration needed", model) + return nil + } + + stmt := &gorm.Statement{DB: db} + err := stmt.Parse(&model) + if err != nil { + return fmt.Errorf("parse model: %w", err) + } + tableName := stmt.Schema.Table + + if err := db.Transaction(func(tx *gorm.DB) error { + if !tx.Migrator().HasColumn(&model, columnName) { + log.WithContext(ctx).Infof("Column %s does not exist in table %s, adding it", columnName, tableName) + if err := tx.Migrator().AddColumn(&model, columnName); err != nil { + return fmt.Errorf("add column %s: %w", columnName, err) + } + } + + var rows []map[string]any + if err := tx.Table(tableName). + Select("id", columnName). + Where(columnName + " IS NULL OR " + columnName + " = ''"). 
+ Find(&rows).Error; err != nil { + return fmt.Errorf("failed to find rows with empty %s: %w", columnName, err) + } + + if len(rows) == 0 { + log.WithContext(ctx).Infof("No rows with empty %s found in table %s, no migration needed", columnName, tableName) + return nil + } + + for _, row := range rows { + if err := tx.Table(tableName).Where("id = ?", row["id"]).Update(columnName, defaultValue).Error; err != nil { + return fmt.Errorf("failed to update row with id %v: %w", row["id"], err) + } + } + return nil + }); err != nil { + return err + } + + log.WithContext(ctx).Infof("Migration of empty %s to default value in table %s completed", columnName, tableName) + return nil +} diff --git a/management/server/networks/resources/manager.go b/management/server/networks/resources/manager.go index 02b462947..725d15496 100644 --- a/management/server/networks/resources/manager.go +++ b/management/server/networks/resources/manager.go @@ -101,7 +101,7 @@ func (m *managerImpl) CreateResource(ctx context.Context, userID string, resourc return nil, status.NewPermissionDeniedError() } - resource, err = types.NewNetworkResource(resource.AccountID, resource.NetworkID, resource.Name, resource.Description, resource.Address, resource.GroupIDs) + resource, err = types.NewNetworkResource(resource.AccountID, resource.NetworkID, resource.Name, resource.Description, resource.Address, resource.GroupIDs, resource.Enabled) if err != nil { return nil, fmt.Errorf("failed to create new network resource: %w", err) } diff --git a/management/server/networks/resources/types/resource.go b/management/server/networks/resources/types/resource.go index 162f90378..0df6727c3 100644 --- a/management/server/networks/resources/types/resource.go +++ b/management/server/networks/resources/types/resource.go @@ -40,9 +40,10 @@ type NetworkResource struct { GroupIDs []string `gorm:"-"` Domain string Prefix netip.Prefix `gorm:"serializer:json"` + Enabled bool } -func NewNetworkResource(accountID, networkID, name, description, address string, groupIDs []string) (*NetworkResource, error) { +func NewNetworkResource(accountID, networkID, name, description, address string, groupIDs []string, enabled bool) (*NetworkResource, error) { resourceType, domain, prefix, err := GetResourceType(address) if err != nil { return nil, fmt.Errorf("invalid address: %w", err) @@ -59,6 +60,7 @@ func NewNetworkResource(accountID, networkID, name, description, address string, Domain: domain, Prefix: prefix, GroupIDs: groupIDs, + Enabled: enabled, }, nil } @@ -75,6 +77,7 @@ func (n *NetworkResource) ToAPIResponse(groups []api.GroupMinimum) *api.NetworkR Type: api.NetworkResourceType(n.Type.String()), Address: addr, Groups: groups, + Enabled: n.Enabled, } } @@ -86,6 +89,7 @@ func (n *NetworkResource) FromAPIRequest(req *api.NetworkResourceRequest) { } n.Address = req.Address n.GroupIDs = req.Groups + n.Enabled = req.Enabled } func (n *NetworkResource) Copy() *NetworkResource { @@ -100,6 +104,7 @@ func (n *NetworkResource) Copy() *NetworkResource { Domain: n.Domain, Prefix: n.Prefix, GroupIDs: n.GroupIDs, + Enabled: n.Enabled, } } @@ -115,7 +120,7 @@ func (n *NetworkResource) ToRoute(peer *nbpeer.Peer, router *routerTypes.Network PeerGroups: nil, Masquerade: router.Masquerade, Metric: router.Metric, - Enabled: true, + Enabled: n.Enabled, Groups: nil, AccessControlGroups: nil, } diff --git a/management/server/networks/routers/manager_test.go b/management/server/networks/routers/manager_test.go index e650074cc..47f5ad7e3 100644 --- 
a/management/server/networks/routers/manager_test.go +++ b/management/server/networks/routers/manager_test.go @@ -101,7 +101,7 @@ func Test_GetRouterReturnsPermissionDenied(t *testing.T) { func Test_CreateRouterSuccessfully(t *testing.T) { ctx := context.Background() userID := "allowedUser" - router, err := types.NewNetworkRouter("testAccountId", "testNetworkId", "testPeerId", []string{}, false, 9999) + router, err := types.NewNetworkRouter("testAccountId", "testNetworkId", "testPeerId", []string{}, false, 9999, true) if err != nil { require.NoError(t, err) } @@ -127,7 +127,7 @@ func Test_CreateRouterSuccessfully(t *testing.T) { func Test_CreateRouterFailsWithPermissionDenied(t *testing.T) { ctx := context.Background() userID := "invalidUser" - router, err := types.NewNetworkRouter("testAccountId", "testNetworkId", "testPeerId", []string{}, false, 9999) + router, err := types.NewNetworkRouter("testAccountId", "testNetworkId", "testPeerId", []string{}, false, 9999, true) if err != nil { require.NoError(t, err) } @@ -191,7 +191,7 @@ func Test_DeleteRouterFailsWithPermissionDenied(t *testing.T) { func Test_UpdateRouterSuccessfully(t *testing.T) { ctx := context.Background() userID := "allowedUser" - router, err := types.NewNetworkRouter("testAccountId", "testNetworkId", "testPeerId", []string{}, false, 1) + router, err := types.NewNetworkRouter("testAccountId", "testNetworkId", "testPeerId", []string{}, false, 1, true) if err != nil { require.NoError(t, err) } @@ -213,7 +213,7 @@ func Test_UpdateRouterSuccessfully(t *testing.T) { func Test_UpdateRouterFailsWithPermissionDenied(t *testing.T) { ctx := context.Background() userID := "invalidUser" - router, err := types.NewNetworkRouter("testAccountId", "testNetworkId", "testPeerId", []string{}, false, 1) + router, err := types.NewNetworkRouter("testAccountId", "testNetworkId", "testPeerId", []string{}, false, 1, true) if err != nil { require.NoError(t, err) } diff --git a/management/server/networks/routers/types/router.go b/management/server/networks/routers/types/router.go index f37ae0861..5158ebb12 100644 --- a/management/server/networks/routers/types/router.go +++ b/management/server/networks/routers/types/router.go @@ -17,9 +17,10 @@ type NetworkRouter struct { PeerGroups []string `gorm:"serializer:json"` Masquerade bool Metric int + Enabled bool } -func NewNetworkRouter(accountID string, networkID string, peer string, peerGroups []string, masquerade bool, metric int) (*NetworkRouter, error) { +func NewNetworkRouter(accountID string, networkID string, peer string, peerGroups []string, masquerade bool, metric int, enabled bool) (*NetworkRouter, error) { if peer != "" && len(peerGroups) > 0 { return nil, errors.New("peer and peerGroups cannot be set at the same time") } @@ -32,6 +33,7 @@ func NewNetworkRouter(accountID string, networkID string, peer string, peerGroup PeerGroups: peerGroups, Masquerade: masquerade, Metric: metric, + Enabled: enabled, }, nil } @@ -42,6 +44,7 @@ func (n *NetworkRouter) ToAPIResponse() *api.NetworkRouter { PeerGroups: &n.PeerGroups, Masquerade: n.Masquerade, Metric: n.Metric, + Enabled: n.Enabled, } } @@ -56,6 +59,7 @@ func (n *NetworkRouter) FromAPIRequest(req *api.NetworkRouterRequest) { n.Masquerade = req.Masquerade n.Metric = req.Metric + n.Enabled = req.Enabled } func (n *NetworkRouter) Copy() *NetworkRouter { @@ -67,6 +71,7 @@ func (n *NetworkRouter) Copy() *NetworkRouter { PeerGroups: n.PeerGroups, Masquerade: n.Masquerade, Metric: n.Metric, + Enabled: n.Enabled, } } diff --git 
a/management/server/networks/routers/types/router_test.go b/management/server/networks/routers/types/router_test.go index 3335f7c89..5801e3bfa 100644 --- a/management/server/networks/routers/types/router_test.go +++ b/management/server/networks/routers/types/router_test.go @@ -11,6 +11,7 @@ func TestNewNetworkRouter(t *testing.T) { peerGroups []string masquerade bool metric int + enabled bool expectedError bool }{ // Valid cases @@ -22,6 +23,7 @@ func TestNewNetworkRouter(t *testing.T) { peerGroups: nil, masquerade: true, metric: 100, + enabled: true, expectedError: false, }, { @@ -32,6 +34,7 @@ func TestNewNetworkRouter(t *testing.T) { peerGroups: []string{"group-1", "group-2"}, masquerade: false, metric: 200, + enabled: false, expectedError: false, }, { @@ -42,6 +45,7 @@ func TestNewNetworkRouter(t *testing.T) { peerGroups: nil, masquerade: true, metric: 300, + enabled: true, expectedError: false, }, @@ -54,13 +58,14 @@ func TestNewNetworkRouter(t *testing.T) { peerGroups: []string{"group-3"}, masquerade: false, metric: 400, + enabled: false, expectedError: true, }, } for _, tt := range tests { t.Run(tt.name, func(t *testing.T) { - router, err := NewNetworkRouter(tt.accountID, tt.networkID, tt.peer, tt.peerGroups, tt.masquerade, tt.metric) + router, err := NewNetworkRouter(tt.accountID, tt.networkID, tt.peer, tt.peerGroups, tt.masquerade, tt.metric, tt.enabled) if tt.expectedError && err == nil { t.Fatalf("Expected an error, got nil") @@ -94,6 +99,10 @@ func TestNewNetworkRouter(t *testing.T) { if router.Metric != tt.metric { t.Errorf("Expected Metric %d, got %d", tt.metric, router.Metric) } + + if router.Enabled != tt.enabled { + t.Errorf("Expected Enabled %v, got %v", tt.enabled, router.Enabled) + } } }) } diff --git a/management/server/route_test.go b/management/server/route_test.go index 48dbd11ed..1c5c56f60 100644 --- a/management/server/route_test.go +++ b/management/server/route_test.go @@ -2388,6 +2388,7 @@ func TestAccount_GetPeerNetworkResourceFirewallRules(t *testing.T) { PeerGroups: nil, Masquerade: false, Metric: 9999, + Enabled: true, }, { ID: "router2", @@ -2395,12 +2396,14 @@ func TestAccount_GetPeerNetworkResourceFirewallRules(t *testing.T) { PeerGroups: []string{"router1", "router2"}, Masquerade: false, Metric: 9999, + Enabled: true, }, { ID: "router3", NetworkID: "network3", Peer: "peerE", PeerGroups: []string{}, + Enabled: true, }, { ID: "router4", @@ -2408,6 +2411,7 @@ func TestAccount_GetPeerNetworkResourceFirewallRules(t *testing.T) { PeerGroups: []string{"router1"}, Masquerade: false, Metric: 9999, + Enabled: true, }, { ID: "router5", @@ -2415,6 +2419,7 @@ func TestAccount_GetPeerNetworkResourceFirewallRules(t *testing.T) { Peer: "peerL", Masquerade: false, Metric: 9999, + Enabled: true, }, { ID: "router6", @@ -2422,6 +2427,7 @@ func TestAccount_GetPeerNetworkResourceFirewallRules(t *testing.T) { Peer: "peerN", Masquerade: false, Metric: 9999, + Enabled: true, }, }, NetworkResources: []*resourceTypes.NetworkResource{ @@ -2431,6 +2437,7 @@ func TestAccount_GetPeerNetworkResourceFirewallRules(t *testing.T) { Name: "Resource 1", Type: "subnet", Prefix: netip.MustParsePrefix("10.10.10.0/24"), + Enabled: true, }, { ID: "resource2", @@ -2438,6 +2445,7 @@ func TestAccount_GetPeerNetworkResourceFirewallRules(t *testing.T) { Name: "Resource 2", Type: "subnet", Prefix: netip.MustParsePrefix("192.168.0.0/16"), + Enabled: true, }, { ID: "resource3", @@ -2445,6 +2453,7 @@ func TestAccount_GetPeerNetworkResourceFirewallRules(t *testing.T) { Name: "Resource 3", Type: "domain", 
Domain: "example.com", + Enabled: true, }, { ID: "resource4", @@ -2452,6 +2461,7 @@ func TestAccount_GetPeerNetworkResourceFirewallRules(t *testing.T) { Name: "Resource 4", Type: "domain", Domain: "example.com", + Enabled: true, }, { ID: "resource5", @@ -2459,6 +2469,7 @@ func TestAccount_GetPeerNetworkResourceFirewallRules(t *testing.T) { Name: "Resource 5", Type: "host", Prefix: netip.MustParsePrefix("10.12.12.1/32"), + Enabled: true, }, { ID: "resource6", @@ -2466,6 +2477,7 @@ func TestAccount_GetPeerNetworkResourceFirewallRules(t *testing.T) { Name: "Resource 6", Type: "domain", Domain: "*.google.com", + Enabled: true, }, }, Policies: []*types.Policy{ diff --git a/management/server/store/sql_store_test.go b/management/server/store/sql_store_test.go index 845bc8fd4..5928b45ba 100644 --- a/management/server/store/sql_store_test.go +++ b/management/server/store/sql_store_test.go @@ -2377,7 +2377,7 @@ func TestSqlStore_SaveNetworkRouter(t *testing.T) { accountID := "bf1c8084-ba50-4ce7-9439-34653001fc3b" networkID := "ct286bi7qv930dsrrug0" - netRouter, err := routerTypes.NewNetworkRouter(accountID, networkID, "", []string{"net-router-grp"}, true, 0) + netRouter, err := routerTypes.NewNetworkRouter(accountID, networkID, "", []string{"net-router-grp"}, true, 0, true) require.NoError(t, err) err = store.SaveNetworkRouter(context.Background(), LockingStrengthUpdate, netRouter) @@ -2494,7 +2494,7 @@ func TestSqlStore_SaveNetworkResource(t *testing.T) { accountID := "bf1c8084-ba50-4ce7-9439-34653001fc3b" networkID := "ct286bi7qv930dsrrug0" - netResource, err := resourceTypes.NewNetworkResource(accountID, networkID, "resource-name", "", "example.com", []string{}) + netResource, err := resourceTypes.NewNetworkResource(accountID, networkID, "resource-name", "", "example.com", []string{}, true) require.NoError(t, err) err = store.SaveNetworkResource(context.Background(), LockingStrengthUpdate, netResource) diff --git a/management/server/store/store.go b/management/server/store/store.go index e1a6937e7..91ae93c7c 100644 --- a/management/server/store/store.go +++ b/management/server/store/store.go @@ -288,6 +288,12 @@ func getMigrations(ctx context.Context) []migrationFunc { func(db *gorm.DB) error { return migration.MigrateSetupKeyToHashedSetupKey[types.SetupKey](ctx, db) }, + func(db *gorm.DB) error { + return migration.MigrateNewField[resourceTypes.NetworkResource](ctx, db, "enabled", true) + }, + func(db *gorm.DB) error { + return migration.MigrateNewField[routerTypes.NetworkRouter](ctx, db, "enabled", true) + }, } } diff --git a/management/server/types/account.go b/management/server/types/account.go index 8fa61d700..f74d38cb6 100644 --- a/management/server/types/account.go +++ b/management/server/types/account.go @@ -1288,6 +1288,10 @@ func (a *Account) getNetworkResourceGroups(resourceID string) []*Group { func (a *Account) GetResourcePoliciesMap() map[string][]*Policy { resourcePolicies := make(map[string][]*Policy) for _, resource := range a.NetworkResources { + if !resource.Enabled { + continue + } + resourceAppliedPolicies := a.GetPoliciesForNetworkResource(resource.ID) resourcePolicies[resource.ID] = resourceAppliedPolicies } @@ -1301,6 +1305,10 @@ func (a *Account) GetNetworkResourcesRoutesToSync(ctx context.Context, peerID st allSourcePeers := make(map[string]struct{}, len(a.Peers)) for _, resource := range a.NetworkResources { + if !resource.Enabled { + continue + } + var addSourcePeers bool networkRoutingPeers, exists := routers[resource.NetworkID] @@ -1455,6 +1463,10 @@ func (a 
*Account) GetResourceRoutersMap() map[string]map[string]*routerTypes.Net routers := make(map[string]map[string]*routerTypes.NetworkRouter) for _, router := range a.NetworkRouters { + if !router.Enabled { + continue + } + if routers[router.NetworkID] == nil { routers[router.NetworkID] = make(map[string]*routerTypes.NetworkRouter) } diff --git a/management/server/types/account_test.go b/management/server/types/account_test.go index 389ab17f6..f8ab1d627 100644 --- a/management/server/types/account_test.go +++ b/management/server/types/account_test.go @@ -139,6 +139,11 @@ func setupTestAccount() *Account { AccountID: "accountID", Name: "network2", }, + { + ID: "network3ID", + AccountID: "accountID", + Name: "network3", + }, }, NetworkRouters: []*routerTypes.NetworkRouter{ { @@ -149,6 +154,7 @@ func setupTestAccount() *Account { PeerGroups: []string{}, Masquerade: false, Metric: 100, + Enabled: true, }, { ID: "router2ID", @@ -158,6 +164,7 @@ func setupTestAccount() *Account { PeerGroups: []string{}, Masquerade: false, Metric: 100, + Enabled: true, }, { ID: "router3ID", @@ -167,6 +174,7 @@ func setupTestAccount() *Account { PeerGroups: []string{}, Masquerade: false, Metric: 100, + Enabled: true, }, { ID: "router4ID", @@ -176,6 +184,7 @@ func setupTestAccount() *Account { PeerGroups: []string{"group1"}, Masquerade: false, Metric: 100, + Enabled: true, }, { ID: "router5ID", @@ -185,6 +194,7 @@ func setupTestAccount() *Account { PeerGroups: []string{"group2", "group3"}, Masquerade: false, Metric: 100, + Enabled: true, }, { ID: "router6ID", @@ -194,6 +204,17 @@ func setupTestAccount() *Account { PeerGroups: []string{"group4"}, Masquerade: false, Metric: 100, + Enabled: true, + }, + { + ID: "router6ID", + NetworkID: "network3ID", + AccountID: "accountID", + Peer: "", + PeerGroups: []string{"group6"}, + Masquerade: false, + Metric: 100, + Enabled: false, }, }, NetworkResources: []*resourceTypes.NetworkResource{ @@ -201,21 +222,31 @@ func setupTestAccount() *Account { ID: "resource1ID", AccountID: "accountID", NetworkID: "network1ID", + Enabled: true, }, { ID: "resource2ID", AccountID: "accountID", NetworkID: "network2ID", + Enabled: true, }, { ID: "resource3ID", AccountID: "accountID", NetworkID: "network1ID", + Enabled: true, }, { ID: "resource4ID", AccountID: "accountID", NetworkID: "network1ID", + Enabled: true, + }, + { + ID: "resource5ID", + AccountID: "accountID", + NetworkID: "network3ID", + Enabled: false, }, }, Policies: []*Policy{ @@ -281,6 +312,17 @@ func setupTestAccount() *Account { }, }, }, + { + ID: "policy6ID", + AccountID: "accountID", + Enabled: true, + Rules: []*PolicyRule{ + { + ID: "rule6ID", + Enabled: true, + }, + }, + }, }, } } @@ -302,6 +344,8 @@ func Test_GetResourceRoutersMap(t *testing.T) { require.Equal(t, 2, len(routers["network2ID"])) require.NotNil(t, routers["network2ID"]["peer2"]) require.NotNil(t, routers["network2ID"]["peer41"]) + + require.Equal(t, 0, len(routers["network3ID"])) } func Test_GetResourcePoliciesMap(t *testing.T) { @@ -312,6 +356,7 @@ func Test_GetResourcePoliciesMap(t *testing.T) { require.Equal(t, 1, len(policies["resource2ID"])) require.Equal(t, 2, len(policies["resource3ID"])) require.Equal(t, 1, len(policies["resource4ID"])) + require.Equal(t, 0, len(policies["resource5ID"])) } func Test_AddNetworksRoutingPeersAddsMissingPeers(t *testing.T) { @@ -476,6 +521,7 @@ func getBasicAccountsWithResource() *Account { PeerGroups: []string{}, Masquerade: false, Metric: 100, + Enabled: true, }, }, NetworkResources: []*resourceTypes.NetworkResource{ @@ -486,6 
+532,7 @@ func getBasicAccountsWithResource() *Account { Address: "10.10.10.0/24", Prefix: netip.MustParsePrefix("10.10.10.0/24"), Type: resourceTypes.NetworkResourceType("subnet"), + Enabled: true, }, }, Policies: []*Policy{ From baf211203afe8e38413c9c7417a6f05d77f9bb37 Mon Sep 17 00:00:00 2001 From: bcmmbaga Date: Wed, 8 Jan 2025 23:17:54 +0300 Subject: [PATCH 118/159] fix merge Signed-off-by: bcmmbaga --- management/server/account.go | 10 ++-- management/server/ephemeral.go | 3 +- management/server/ephemeral_test.go | 2 +- .../http/handlers/peers/peers_handler.go | 51 +++++++----------- .../http/handlers/peers/peers_handler_test.go | 8 +-- management/server/integrated_validator.go | 10 ++-- management/server/mock_server/account_mock.go | 4 +- management/server/peer.go | 52 +++++++++---------- management/server/store/sql_store_test.go | 5 +- .../testdata/store_with_expired_peers.sql | 2 +- management/server/user.go | 2 +- 11 files changed, 67 insertions(+), 82 deletions(-) diff --git a/management/server/account.go b/management/server/account.go index a5b1a7070..19ff96cfb 100644 --- a/management/server/account.go +++ b/management/server/account.go @@ -105,7 +105,7 @@ type AccountManager interface { DeleteGroups(ctx context.Context, accountId, userId string, groupIDs []string) error GroupAddPeer(ctx context.Context, accountId, groupID, peerID string) error GroupDeletePeer(ctx context.Context, accountId, groupID, peerID string) error - GetPeerGroups(ctx context.Context, accountID, peerID string) ([]*nbgroup.Group, error) + GetPeerGroups(ctx context.Context, accountID, peerID string) ([]*types.Group, error) GetPolicy(ctx context.Context, accountID, policyID, userID string) (*types.Policy, error) SavePolicy(ctx context.Context, accountID, userID string, policy *types.Policy) (*types.Policy, error) DeletePolicy(ctx context.Context, accountID, policyID, userID string) error @@ -438,7 +438,7 @@ func (am *DefaultAccountManager) handleGroupsPropagationSettings(ctx context.Con return nil } -func (am *DefaultAccountManager) handleInactivityExpirationSettings(ctx context.Context, oldSettings, newSettings *Settings, userID, accountID string) error { +func (am *DefaultAccountManager) handleInactivityExpirationSettings(ctx context.Context, oldSettings, newSettings *types.Settings, userID, accountID string) error { if newSettings.PeerInactivityExpirationEnabled { if oldSettings.PeerInactivityExpiration != newSettings.PeerInactivityExpiration { oldSettings.PeerInactivityExpiration = newSettings.PeerInactivityExpiration @@ -663,7 +663,7 @@ func (am *DefaultAccountManager) GetAccountIDByUserID(ctx context.Context, userI return "", status.Errorf(status.NotFound, "no valid userID provided") } - accountID, err := am.Store.GetAccountIDByUserID(ctx, LockingStrengthShare, userID) + accountID, err := am.Store.GetAccountIDByUserID(ctx, store.LockingStrengthShare, userID) if err != nil { if s, ok := status.FromError(err); ok && s.Type() == status.NotFound { account, err := am.GetOrCreateAccountByUser(ctx, userID, domain) @@ -1448,7 +1448,7 @@ func (am *DefaultAccountManager) getAccountIDWithAuthorizationClaims(ctx context return "", err } - userAccountID, err := am.Store.GetAccountIDByUserID(ctx, LockingStrengthShare, claims.UserId) + userAccountID, err := am.Store.GetAccountIDByUserID(ctx, store.LockingStrengthShare, claims.UserId) if handleNotFound(err) != nil { log.WithContext(ctx).Errorf("error getting account ID by user ID: %v", err) return "", err @@ -1495,7 +1495,7 @@ func (am *DefaultAccountManager) 
getPrivateDomainWithGlobalLock(ctx context.Cont } func (am *DefaultAccountManager) handlePrivateAccountWithIDFromClaim(ctx context.Context, claims jwtclaims.AuthorizationClaims) (string, error) { - userAccountID, err := am.Store.GetAccountIDByUserID(ctx, LockingStrengthShare, claims.UserId) + userAccountID, err := am.Store.GetAccountIDByUserID(ctx, store.LockingStrengthShare, claims.UserId) if err != nil { log.WithContext(ctx).Errorf("error getting account ID by user ID: %v", err) return "", err diff --git a/management/server/ephemeral.go b/management/server/ephemeral.go index 41c1d5577..3d6d01434 100644 --- a/management/server/ephemeral.go +++ b/management/server/ephemeral.go @@ -10,7 +10,6 @@ import ( "github.com/netbirdio/netbird/management/server/activity" nbpeer "github.com/netbirdio/netbird/management/server/peer" "github.com/netbirdio/netbird/management/server/store" - "github.com/netbirdio/netbird/management/server/types" ) const ( @@ -122,7 +121,7 @@ func (e *EphemeralManager) OnPeerDisconnected(ctx context.Context, peer *nbpeer. } func (e *EphemeralManager) loadEphemeralPeers(ctx context.Context) { - peers, err := e.store.GetAllEphemeralPeers(ctx, LockingStrengthShare) + peers, err := e.store.GetAllEphemeralPeers(ctx, store.LockingStrengthShare) if err != nil { log.WithContext(ctx).Debugf("failed to load ephemeral peers: %s", err) return diff --git a/management/server/ephemeral_test.go b/management/server/ephemeral_test.go index ce852fdc7..df8fe98c3 100644 --- a/management/server/ephemeral_test.go +++ b/management/server/ephemeral_test.go @@ -16,7 +16,7 @@ type MockStore struct { account *types.Account } -func (s *MockStore) GetAllEphemeralPeers(_ context.Context, _ LockingStrength) ([]*nbpeer.Peer, error) { +func (s *MockStore) GetAllEphemeralPeers(_ context.Context, _ store.LockingStrength) ([]*nbpeer.Peer, error) { var peers []*nbpeer.Peer for _, v := range s.account.Peers { if v.Ephemeral { diff --git a/management/server/http/handlers/peers/peers_handler.go b/management/server/http/handlers/peers/peers_handler.go index 18557cca0..09d63ea6f 100644 --- a/management/server/http/handlers/peers/peers_handler.go +++ b/management/server/http/handlers/peers/peers_handler.go @@ -72,8 +72,13 @@ func (h *Handler) getPeer(ctx context.Context, accountID, peerID, userID string, } dnsDomain := h.accountManager.GetDNSDomain() - allGroups, _ := h.accountManager.GetAllGroups(ctx, accountID, userID) - groupsInfo := groups.ToGroupsInfo(allGroups, peerID) + groupsMap := map[string]*types.Group{} + grps, _ := h.accountManager.GetAllGroups(ctx, accountID, userID) + for _, group := range grps { + groupsMap[group.ID] = group + } + + groupsInfo := groups.ToGroupsInfo(groupsMap, peerID) validPeers, err := h.accountManager.GetValidatedPeers(ctx, accountID) if err != nil { @@ -122,7 +127,13 @@ func (h *Handler) updatePeer(ctx context.Context, accountID, userID, peerID stri util.WriteError(ctx, err, w) return } - groupMinimumInfo := groups.ToGroupsInfo(peerGroups, peer.ID) + + groupsMap := map[string]*types.Group{} + for _, group := range peerGroups { + groupsMap[group.ID] = group + } + + groupMinimumInfo := groups.ToGroupsInfo(groupsMap, peer.ID) validPeers, err := h.accountManager.GetValidatedPeers(ctx, accountID) if err != nil { @@ -193,7 +204,11 @@ func (h *Handler) GetAllPeers(w http.ResponseWriter, r *http.Request) { dnsDomain := h.accountManager.GetDNSDomain() - allGroups, _ := h.accountManager.GetAllGroups(r.Context(), accountID, userID) + groupsMap := map[string]*types.Group{} + grps, _ := 
h.accountManager.GetAllGroups(r.Context(), accountID, userID) + for _, group := range grps { + groupsMap[group.ID] = group + } respBody := make([]*api.PeerBatch, 0, len(peers)) for _, peer := range peers { @@ -202,7 +217,7 @@ func (h *Handler) GetAllPeers(w http.ResponseWriter, r *http.Request) { util.WriteError(r.Context(), err, w) return } - groupMinimumInfo := groups.ToGroupsInfo(allGroups, peer.ID) + groupMinimumInfo := groups.ToGroupsInfo(groupsMap, peer.ID) respBody = append(respBody, toPeerListItemResponse(peerToReturn, groupMinimumInfo, dnsDomain, 0)) } @@ -314,32 +329,6 @@ func peerToAccessiblePeer(peer *nbpeer.Peer, dnsDomain string) api.AccessiblePee } } -func toGroupsInfo(groups []*nbgroup.Group, peerID string) []api.GroupMinimum { - groupsInfo := []api.GroupMinimum{} - groupsChecked := make(map[string]struct{}) - - for _, group := range groups { - _, ok := groupsChecked[group.ID] - if ok { - continue - } - - groupsChecked[group.ID] = struct{}{} - for _, pk := range group.Peers { - if pk == peerID { - info := api.GroupMinimum{ - Id: group.ID, - Name: group.Name, - PeersCount: len(group.Peers), - } - groupsInfo = append(groupsInfo, info) - break - } - } - } - return groupsInfo -} - func toSinglePeerResponse(peer *nbpeer.Peer, groupsInfo []api.GroupMinimum, dnsDomain string, approved bool) *api.Peer { osVersion := peer.Meta.OSVersion if osVersion == "" { diff --git a/management/server/http/handlers/peers/peers_handler_test.go b/management/server/http/handlers/peers/peers_handler_test.go index baab39ea9..16065a677 100644 --- a/management/server/http/handlers/peers/peers_handler_test.go +++ b/management/server/http/handlers/peers/peers_handler_test.go @@ -128,12 +128,12 @@ func initTestMetaData(peers ...*nbpeer.Peer) *Handler { GetPeersFunc: func(_ context.Context, accountID, userID string) ([]*nbpeer.Peer, error) { return peers, nil }, - GetPeerGroupsFunc: func(ctx context.Context, accountID, peerID string) ([]*nbgroup.Group, error) { + GetPeerGroupsFunc: func(ctx context.Context, accountID, peerID string) ([]*types.Group, error) { peersID := make([]string, len(peers)) for _, peer := range peers { peersID = append(peersID, peer.ID) } - return []*nbgroup.Group{ + return []*types.Group{ { ID: "group1", AccountID: accountID, @@ -149,10 +149,10 @@ func initTestMetaData(peers ...*nbpeer.Peer) *Handler { GetAccountIDFromTokenFunc: func(_ context.Context, claims jwtclaims.AuthorizationClaims) (string, string, error) { return claims.AccountId, claims.UserId, nil }, - GetAccountFunc: func(ctx context.Context, accountID string) (*server.Account, error) { + GetAccountFunc: func(ctx context.Context, accountID string) (*types.Account, error) { return account, nil }, - GetAccountByIDFunc: func(ctx context.Context, accountID string, userID string) (*server.Account, error) { + GetAccountByIDFunc: func(ctx context.Context, accountID string, userID string) (*types.Account, error) { return account, nil }, HasConnectedChannelFunc: func(peerID string) bool { diff --git a/management/server/integrated_validator.go b/management/server/integrated_validator.go index 56c4e3e6f..9dad6fcd7 100644 --- a/management/server/integrated_validator.go +++ b/management/server/integrated_validator.go @@ -4,8 +4,6 @@ import ( "context" "errors" - nbgroup "github.com/netbirdio/netbird/management/server/group" - nbpeer "github.com/netbirdio/netbird/management/server/peer" log "github.com/sirupsen/logrus" "github.com/netbirdio/netbird/management/server/account" @@ -80,11 +78,11 @@ func (am *DefaultAccountManager) 
GroupValidation(ctx context.Context, accountID func (am *DefaultAccountManager) GetValidatedPeers(ctx context.Context, accountID string) (map[string]struct{}, error) { var err error - var groups []*nbgroup.Group + var groups []*types.Group var peers []*nbpeer.Peer - var settings *Settings + var settings *types.Settings - err = am.Store.ExecuteInTransaction(ctx, func(transaction Store) error { + err = am.Store.ExecuteInTransaction(ctx, func(transaction store.Store) error { groups, err = transaction.GetAccountGroups(ctx, store.LockingStrengthShare, accountID) if err != nil { return err @@ -102,7 +100,7 @@ func (am *DefaultAccountManager) GetValidatedPeers(ctx context.Context, accountI return nil, err } - groupsMap := make(map[string]*nbgroup.Group, len(groups)) + groupsMap := make(map[string]*types.Group, len(groups)) for _, group := range groups { groupsMap[group.ID] = group } diff --git a/management/server/mock_server/account_mock.go b/management/server/mock_server/account_mock.go index 641254e4f..c8e42d20a 100644 --- a/management/server/mock_server/account_mock.go +++ b/management/server/mock_server/account_mock.go @@ -47,7 +47,7 @@ type MockAccountManager struct { DeleteGroupsFunc func(ctx context.Context, accountId, userId string, groupIDs []string) error GroupAddPeerFunc func(ctx context.Context, accountID, groupID, peerID string) error GroupDeletePeerFunc func(ctx context.Context, accountID, groupID, peerID string) error - GetPeerGroupsFunc func(ctx context.Context, accountID, peerID string) ([]*group.Group, error) + GetPeerGroupsFunc func(ctx context.Context, accountID, peerID string) ([]*types.Group, error) DeleteRuleFunc func(ctx context.Context, accountID, ruleID, userID string) error GetPolicyFunc func(ctx context.Context, accountID, policyID, userID string) (*types.Policy, error) SavePolicyFunc func(ctx context.Context, accountID, userID string, policy *types.Policy) (*types.Policy, error) @@ -843,7 +843,7 @@ func (am *MockAccountManager) GetAccount(ctx context.Context, accountID string) } // GetPeerGroups mocks GetPeerGroups of the AccountManager interface -func (am *MockAccountManager) GetPeerGroups(ctx context.Context, accountID, peerID string) ([]*group.Group, error) { +func (am *MockAccountManager) GetPeerGroups(ctx context.Context, accountID, peerID string) ([]*types.Group, error) { if am.GetPeerGroupsFunc != nil { return am.GetPeerGroupsFunc(ctx, accountID, peerID) } diff --git a/management/server/peer.go b/management/server/peer.go index 16c824ccf..c09e3d051 100644 --- a/management/server/peer.go +++ b/management/server/peer.go @@ -12,8 +12,6 @@ import ( "time" "github.com/netbirdio/netbird/management/server/geolocation" - nbgroup "github.com/netbirdio/netbird/management/server/group" - "github.com/netbirdio/netbird/management/server/util" "github.com/rs/xid" log "github.com/sirupsen/logrus" "golang.org/x/exp/maps" @@ -122,11 +120,11 @@ func (am *DefaultAccountManager) GetPeers(ctx context.Context, accountID, userID // MarkPeerConnected marks peer as connected (true) or disconnected (false) func (am *DefaultAccountManager) MarkPeerConnected(ctx context.Context, peerPubKey string, connected bool, realIP net.IP, accountID string) error { var peer *nbpeer.Peer - var settings *Settings + var settings *types.Settings var expired bool var err error - err = am.Store.ExecuteInTransaction(ctx, func(transaction Store) error { + err = am.Store.ExecuteInTransaction(ctx, func(transaction store.Store) error { peer, err = transaction.GetPeerByPeerPubKey(ctx, 
store.LockingStrengthUpdate, peerPubKey) if err != nil { return err @@ -163,7 +161,7 @@ func (am *DefaultAccountManager) MarkPeerConnected(ctx context.Context, peerPubK return nil } -func updatePeerStatusAndLocation(ctx context.Context, geo *geolocation.Geolocation, transaction store.Store, peer *nbpeer.Peer, connected bool, realIP net.IP, accountID string) (bool, error) { +func updatePeerStatusAndLocation(ctx context.Context, geo geolocation.Geolocation, transaction store.Store, peer *nbpeer.Peer, connected bool, realIP net.IP, accountID string) (bool, error) { oldStatus := peer.Status.Copy() newStatus := oldStatus newStatus.LastSeen = time.Now().UTC() @@ -215,7 +213,7 @@ func (am *DefaultAccountManager) UpdatePeer(ctx context.Context, accountID, user } var peer *nbpeer.Peer - var settings *Settings + var settings *types.Settings var peerGroupList []string var requiresPeerUpdates bool var peerLabelChanged bool @@ -223,7 +221,7 @@ func (am *DefaultAccountManager) UpdatePeer(ctx context.Context, accountID, user var loginExpirationChanged bool var inactivityExpirationChanged bool - err = am.Store.ExecuteInTransaction(ctx, func(transaction Store) error { + err = am.Store.ExecuteInTransaction(ctx, func(transaction store.Store) error { peer, err = transaction.GetPeerByID(ctx, store.LockingStrengthUpdate, accountID, update.ID) if err != nil { return err @@ -348,7 +346,7 @@ func (am *DefaultAccountManager) DeletePeer(ctx context.Context, accountID, peer var updateAccountPeers bool var eventsToStore []func() - err = am.Store.ExecuteInTransaction(ctx, func(transaction Store) error { + err = am.Store.ExecuteInTransaction(ctx, func(transaction store.Store) error { peer, err = transaction.GetPeerByID(ctx, store.LockingStrengthUpdate, accountID, peerID) if err != nil { return err @@ -575,7 +573,7 @@ func (am *DefaultAccountManager) AddPeer(ctx context.Context, setupKey, userID s } } - err = transaction.AddPeerToAccount(ctx, LockingStrengthUpdate, newPeer) + err = transaction.AddPeerToAccount(ctx, store.LockingStrengthUpdate, newPeer) if err != nil { return fmt.Errorf("failed to add peer to account: %w", err) } @@ -840,14 +838,14 @@ func (am *DefaultAccountManager) LoginPeer(ctx context.Context, login PeerLogin) unlockPeer = nil if updateRemotePeers || isStatusChanged || (isPeerUpdated && len(postureChecks) > 0) { - am.updateAccountPeers(ctx, accountID) + am.UpdateAccountPeers(ctx, accountID) } return am.getValidatedPeerWithMap(ctx, isRequiresApproval, accountID, peer) } // getPeerPostureChecks returns the posture checks for the peer. -func getPeerPostureChecks(ctx context.Context, transaction Store, accountID, peerID string) ([]*posture.Checks, error) { +func getPeerPostureChecks(ctx context.Context, transaction store.Store, accountID, peerID string) ([]*posture.Checks, error) { policies, err := transaction.GetAccountPolicies(ctx, store.LockingStrengthShare, accountID) if err != nil { return nil, err @@ -872,7 +870,7 @@ func getPeerPostureChecks(ctx context.Context, transaction Store, accountID, pee peerPostureChecksIDs = append(peerPostureChecksIDs, postureChecksIDs...) 
} - peerPostureChecks, err := transaction.GetPostureChecksByIDs(ctx, LockingStrengthShare, accountID, peerPostureChecksIDs) + peerPostureChecks, err := transaction.GetPostureChecksByIDs(ctx, store.LockingStrengthShare, accountID, peerPostureChecksIDs) if err != nil { return nil, err } @@ -881,7 +879,7 @@ func getPeerPostureChecks(ctx context.Context, transaction Store, accountID, pee } // processPeerPostureChecks checks if the peer is in the source group of the policy and returns the posture checks. -func processPeerPostureChecks(ctx context.Context, transaction Store, policy *Policy, accountID, peerID string) ([]string, error) { +func processPeerPostureChecks(ctx context.Context, transaction store.Store, policy *types.Policy, accountID, peerID string) ([]string, error) { for _, rule := range policy.Rules { if !rule.Enabled { continue @@ -980,7 +978,7 @@ func (am *DefaultAccountManager) handleExpiredPeer(ctx context.Context, transact return err } - err = transaction.SaveUserLastLogin(ctx, user.AccountID, user.Id, peer.peer.GetLastLogin()) + err = transaction.SaveUserLastLogin(ctx, user.AccountID, user.Id, peer.GetLastLogin()) if err != nil { return err } @@ -1142,7 +1140,7 @@ func (am *DefaultAccountManager) UpdateAccountPeers(ctx context.Context, account // If there is no peer that expires this function returns false and a duration of 0. // This function only considers peers that haven't been expired yet and that are connected. func (am *DefaultAccountManager) getNextPeerExpiration(ctx context.Context, accountID string) (time.Duration, bool) { - peersWithExpiry, err := am.Store.GetAccountPeersWithExpiration(ctx, LockingStrengthShare, accountID) + peersWithExpiry, err := am.Store.GetAccountPeersWithExpiration(ctx, store.LockingStrengthShare, accountID) if err != nil { log.WithContext(ctx).Errorf("failed to get peers with expiration: %v", err) return 0, false @@ -1186,7 +1184,7 @@ func (am *DefaultAccountManager) getNextPeerExpiration(ctx context.Context, acco // If there is no peer that expires this function returns false and a duration of 0. // This function only considers peers that haven't been expired yet and that are not connected. func (am *DefaultAccountManager) getNextInactivePeerExpiration(ctx context.Context, accountID string) (time.Duration, bool) { - peersWithInactivity, err := am.Store.GetAccountPeersWithInactivity(ctx, LockingStrengthShare, accountID) + peersWithInactivity, err := am.Store.GetAccountPeersWithInactivity(ctx, store.LockingStrengthShare, accountID) if err != nil { log.WithContext(ctx).Errorf("failed to get peers with inactivity: %v", err) return 0, false @@ -1227,7 +1225,7 @@ func (am *DefaultAccountManager) getNextInactivePeerExpiration(ctx context.Conte // getExpiredPeers returns peers that have been expired. 
func (am *DefaultAccountManager) getExpiredPeers(ctx context.Context, accountID string) ([]*nbpeer.Peer, error) { - peersWithExpiry, err := am.Store.GetAccountPeersWithExpiration(ctx, LockingStrengthShare, accountID) + peersWithExpiry, err := am.Store.GetAccountPeersWithExpiration(ctx, store.LockingStrengthShare, accountID) if err != nil { return nil, err } @@ -1250,7 +1248,7 @@ func (am *DefaultAccountManager) getExpiredPeers(ctx context.Context, accountID // getInactivePeers returns peers that have been expired by inactivity func (am *DefaultAccountManager) getInactivePeers(ctx context.Context, accountID string) ([]*nbpeer.Peer, error) { - peersWithInactivity, err := am.Store.GetAccountPeersWithInactivity(ctx, LockingStrengthShare, accountID) + peersWithInactivity, err := am.Store.GetAccountPeersWithInactivity(ctx, store.LockingStrengthShare, accountID) if err != nil { return nil, err } @@ -1272,18 +1270,18 @@ func (am *DefaultAccountManager) getInactivePeers(ctx context.Context, accountID } // GetPeerGroups returns groups that the peer is part of. -func (am *DefaultAccountManager) GetPeerGroups(ctx context.Context, accountID, peerID string) ([]*nbgroup.Group, error) { +func (am *DefaultAccountManager) GetPeerGroups(ctx context.Context, accountID, peerID string) ([]*types.Group, error) { return getPeerGroups(ctx, am.Store, accountID, peerID) } // getPeerGroups returns the IDs of the groups that the peer is part of. -func getPeerGroups(ctx context.Context, transaction Store, accountID, peerID string) ([]*nbgroup.Group, error) { +func getPeerGroups(ctx context.Context, transaction store.Store, accountID, peerID string) ([]*types.Group, error) { groups, err := transaction.GetAccountGroups(ctx, store.LockingStrengthShare, accountID) if err != nil { return nil, err } - peerGroups := make([]*nbgroup.Group, 0) + peerGroups := make([]*types.Group, 0) for _, group := range groups { if slices.Contains(group.Peers, peerID) { peerGroups = append(peerGroups, group) @@ -1294,7 +1292,7 @@ func getPeerGroups(ctx context.Context, transaction Store, accountID, peerID str } // getPeerGroupIDs returns the IDs of the groups that the peer is part of. -func getPeerGroupIDs(ctx context.Context, transaction Store, accountID string, peerID string) ([]string, error) { +func getPeerGroupIDs(ctx context.Context, transaction store.Store, accountID string, peerID string) ([]string, error) { groups, err := getPeerGroups(ctx, transaction, accountID, peerID) if err != nil { return nil, err @@ -1308,13 +1306,13 @@ func getPeerGroupIDs(ctx context.Context, transaction Store, accountID string, p return groupIDs, err } -func getPeerDNSLabels(ctx context.Context, transaction Store, accountID string) (lookupMap, error) { +func getPeerDNSLabels(ctx context.Context, transaction store.Store, accountID string) (types.LookupMap, error) { dnsLabels, err := transaction.GetPeerLabelsInAccount(ctx, store.LockingStrengthShare, accountID) if err != nil { return nil, err } - existingLabels := make(lookupMap) + existingLabels := make(types.LookupMap) for _, label := range dnsLabels { existingLabels[label] = struct{}{} } @@ -1323,7 +1321,7 @@ func getPeerDNSLabels(ctx context.Context, transaction Store, accountID string) // IsPeerInActiveGroup checks if the given peer is part of a group that is used // in an active DNS, route, or ACL configuration. 
-func isPeerInActiveGroup(ctx context.Context, transaction Store, accountID, peerID string) (bool, error) { +func isPeerInActiveGroup(ctx context.Context, transaction store.Store, accountID, peerID string) (bool, error) { peerGroupIDs, err := getPeerGroupIDs(ctx, transaction, accountID, peerID) if err != nil { return false, err @@ -1333,7 +1331,7 @@ func isPeerInActiveGroup(ctx context.Context, transaction Store, accountID, peer // deletePeers deletes all specified peers and sends updates to the remote peers. // Returns a slice of functions to save events after successful peer deletion. -func deletePeers(ctx context.Context, am *DefaultAccountManager, transaction Store, accountID, userID string, peers []*nbpeer.Peer) ([]func(), error) { +func deletePeers(ctx context.Context, am *DefaultAccountManager, transaction store.Store, accountID, userID string, peers []*nbpeer.Peer) ([]func(), error) { var peerDeletedEvents []func() for _, peer := range peers { @@ -1362,7 +1360,7 @@ func deletePeers(ctx context.Context, am *DefaultAccountManager, transaction Sto FirewallRulesIsEmpty: true, }, }, - NetworkMap: &NetworkMap{}, + NetworkMap: &types.NetworkMap{}, }) am.peersUpdateManager.CloseChannel(ctx, peer.ID) peerDeletedEvents = append(peerDeletedEvents, func() { diff --git a/management/server/store/sql_store_test.go b/management/server/store/sql_store_test.go index 1145e928f..630d4e426 100644 --- a/management/server/store/sql_store_test.go +++ b/management/server/store/sql_store_test.go @@ -14,6 +14,7 @@ import ( "time" "github.com/google/uuid" + "github.com/netbirdio/netbird/management/server/util" "github.com/rs/xid" log "github.com/sirupsen/logrus" "github.com/stretchr/testify/assert" @@ -2602,7 +2603,7 @@ func TestSqlStore_AddPeerToAccount(t *testing.T) { SSHEnabled: false, LoginExpirationEnabled: true, InactivityExpirationEnabled: false, - LastLogin: time.Now().UTC(), + LastLogin: util.ToPtr(time.Now().UTC()), CreatedAt: time.Now().UTC(), Ephemeral: true, } @@ -2623,7 +2624,7 @@ func TestSqlStore_AddPeerToAccount(t *testing.T) { assert.Equal(t, peer.SSHEnabled, storedPeer.SSHEnabled) assert.Equal(t, peer.LoginExpirationEnabled, storedPeer.LoginExpirationEnabled) assert.Equal(t, peer.InactivityExpirationEnabled, storedPeer.InactivityExpirationEnabled) - assert.WithinDurationf(t, peer.LastLogin, storedPeer.LastLogin.UTC(), time.Millisecond, "LastLogin should be equal") + assert.WithinDurationf(t, peer.GetLastLogin(), storedPeer.GetLastLogin().UTC(), time.Millisecond, "LastLogin should be equal") assert.WithinDurationf(t, peer.CreatedAt, storedPeer.CreatedAt.UTC(), time.Millisecond, "CreatedAt should be equal") assert.Equal(t, peer.Ephemeral, storedPeer.Ephemeral) assert.Equal(t, peer.Status.Connected, storedPeer.Status.Connected) diff --git a/management/server/testdata/store_with_expired_peers.sql b/management/server/testdata/store_with_expired_peers.sql index 9e0a605fe..5990a0625 100644 --- a/management/server/testdata/store_with_expired_peers.sql +++ b/management/server/testdata/store_with_expired_peers.sql @@ -1,6 +1,6 @@ CREATE TABLE `accounts` (`id` text,`created_by` text,`created_at` datetime,`domain` text,`domain_category` text,`is_domain_primary_account` numeric,`network_identifier` text,`network_net` text,`network_dns` text,`network_serial` integer,`dns_settings_disabled_management_groups` text,`settings_peer_login_expiration_enabled` numeric,`settings_peer_login_expiration` integer,`settings_regular_users_view_blocked` numeric,`settings_groups_propagation_enabled` 
numeric,`settings_jwt_groups_enabled` numeric,`settings_jwt_groups_claim_name` text,`settings_jwt_allow_groups` text,`settings_extra_peer_approval_enabled` numeric,`settings_extra_integrated_validator_groups` text,PRIMARY KEY (`id`)); CREATE TABLE `setup_keys` (`id` text,`account_id` text,`key` text,`name` text,`type` text,`created_at` datetime,`expires_at` datetime,`updated_at` datetime,`revoked` numeric,`used_times` integer,`last_used` datetime DEFAULT NULL,`auto_groups` text,`usage_limit` integer,`ephemeral` numeric,PRIMARY KEY (`id`),CONSTRAINT `fk_accounts_setup_keys_g` FOREIGN KEY (`account_id`) REFERENCES `accounts`(`id`)); -CREATE TABLE `peers` (`id` text,`account_id` text,`key` text,`setup_key` text,`ip` text,`meta_hostname` text,`meta_go_os` text,`meta_kernel` text,`meta_core` text,`meta_platform` text,`meta_os` text,`meta_os_version` text,`meta_wt_version` text,`meta_ui_version` text,`meta_kernel_version` text,`meta_network_addresses` text,`meta_system_serial_number` text,`meta_system_product_name` text,`meta_system_manufacturer` text,`meta_environment` text,`meta_files` text,`name` text,`dns_label` text,`peer_status_last_seen` datetime,`peer_status_connected` numeric,`peer_status_login_expired` numeric,`peer_status_requires_approval` numeric,`user_id` text,`ssh_key` text,`ssh_enabled` numeric,`login_expiration_enabled` numeric,`last_login` datetime,`created_at` datetime,`ephemeral` numeric,`location_connection_ip` text,`location_country_code` text,`location_city_name` text,`location_geo_name_id` integer,PRIMARY KEY (`id`),CONSTRAINT `fk_accounts_peers_g` FOREIGN KEY (`account_id`) REFERENCES `accounts`(`id`)); +CREATE TABLE `peers` (`id` text,`account_id` text,`key` text,`setup_key` text,`ip` text,`meta_hostname` text,`meta_go_os` text,`meta_kernel` text,`meta_core` text,`meta_platform` text,`meta_os` text,`meta_os_version` text,`meta_wt_version` text,`meta_ui_version` text,`meta_kernel_version` text,`meta_network_addresses` text,`meta_system_serial_number` text,`meta_system_product_name` text,`meta_system_manufacturer` text,`meta_environment` text,`meta_files` text,`name` text,`dns_label` text,`peer_status_last_seen` datetime,`peer_status_connected` numeric,`peer_status_login_expired` numeric,`peer_status_requires_approval` numeric,`user_id` text,`ssh_key` text,`ssh_enabled` numeric,`login_expiration_enabled` numeric,`inactivity_expiration_enabled` numeric,`last_login` datetime,`created_at` datetime,`ephemeral` numeric,`location_connection_ip` text,`location_country_code` text,`location_city_name` text,`location_geo_name_id` integer,PRIMARY KEY (`id`),CONSTRAINT `fk_accounts_peers_g` FOREIGN KEY (`account_id`) REFERENCES `accounts`(`id`)); CREATE TABLE `users` (`id` text,`account_id` text,`role` text,`is_service_user` numeric,`non_deletable` numeric,`service_user_name` text,`auto_groups` text,`blocked` numeric,`last_login` datetime DEFAULT NULL,`created_at` datetime,`issued` text DEFAULT "api",`integration_ref_id` integer,`integration_ref_integration_type` text,PRIMARY KEY (`id`),CONSTRAINT `fk_accounts_users_g` FOREIGN KEY (`account_id`) REFERENCES `accounts`(`id`)); CREATE TABLE `personal_access_tokens` (`id` text,`user_id` text,`name` text,`hashed_token` text,`expiration_date` datetime,`created_by` text,`created_at` datetime,`last_used` datetime,PRIMARY KEY (`id`),CONSTRAINT `fk_users_pa_ts_g` FOREIGN KEY (`user_id`) REFERENCES `users`(`id`)); CREATE TABLE `groups` (`id` text,`account_id` text,`name` text,`issued` text,`peers` text,`integration_ref_id` 
integer,`integration_ref_integration_type` text,PRIMARY KEY (`id`),CONSTRAINT `fk_accounts_groups_g` FOREIGN KEY (`account_id`) REFERENCES `accounts`(`id`)); diff --git a/management/server/user.go b/management/server/user.go index b189b2c85..17770a423 100644 --- a/management/server/user.go +++ b/management/server/user.go @@ -979,7 +979,7 @@ func (am *DefaultAccountManager) expireAndUpdatePeers(ctx context.Context, accou peerIDs = append(peerIDs, peer.ID) peer.MarkLoginExpired(true) - if err := am.Store.SavePeerStatus(ctx, LockingStrengthUpdate, accountID, peer.ID, *peer.Status); err != nil { + if err := am.Store.SavePeerStatus(ctx, store.LockingStrengthUpdate, accountID, peer.ID, *peer.Status); err != nil { return err } am.StoreEvent( From fa1eaa0aec69710e7d872c90b07bb49fbc0d3559 Mon Sep 17 00:00:00 2001 From: bcmmbaga Date: Thu, 9 Jan 2025 13:38:09 +0300 Subject: [PATCH 119/159] fix store tests Signed-off-by: bcmmbaga --- management/server/store/sql_store.go | 2 +- management/server/store/sql_store_test.go | 40 ++++++++++++----------- management/server/testdata/storev1.sql | 10 +++--- 3 files changed, 27 insertions(+), 25 deletions(-) diff --git a/management/server/store/sql_store.go b/management/server/store/sql_store.go index 25a33d8dd..0344538e2 100644 --- a/management/server/store/sql_store.go +++ b/management/server/store/sql_store.go @@ -406,7 +406,7 @@ func (s *SqlStore) SavePeerLocation(ctx context.Context, lockStrength LockingStr return result.Error } - if result.RowsAffected == 0 && s.storeEngine != MysqlStoreEngine { + if result.RowsAffected == 0 { return status.Errorf(status.NotFound, peerNotFoundFMT, peerWithLocation.ID) } diff --git a/management/server/store/sql_store_test.go b/management/server/store/sql_store_test.go index 630d4e426..34578eca8 100644 --- a/management/server/store/sql_store_test.go +++ b/management/server/store/sql_store_test.go @@ -424,7 +424,7 @@ func TestSqlite_GetAccount(t *testing.T) { } func TestSqlStore_SavePeer(t *testing.T) { - store, cleanUp, err := NewTestStoreFromSQL(context.Background(), "testdata/store.sql", t.TempDir()) + store, cleanUp, err := NewTestStoreFromSQL(context.Background(), "../testdata/store.sql", t.TempDir()) t.Cleanup(cleanUp) assert.NoError(t, err) @@ -433,12 +433,13 @@ func TestSqlStore_SavePeer(t *testing.T) { // save status of non-existing peer peer := &nbpeer.Peer{ - Key: "peerkey", - ID: "testpeer", - IP: net.IP{127, 0, 0, 1}, - Meta: nbpeer.PeerSystemMeta{Hostname: "testingpeer"}, - Name: "peer name", - Status: &nbpeer.PeerStatus{Connected: true, LastSeen: time.Now().UTC()}, + Key: "peerkey", + ID: "testpeer", + IP: net.IP{127, 0, 0, 1}, + Meta: nbpeer.PeerSystemMeta{Hostname: "testingpeer"}, + Name: "peer name", + Status: &nbpeer.PeerStatus{Connected: true, LastSeen: time.Now().UTC()}, + CreatedAt: time.Now().UTC(), } ctx := context.Background() err = store.SavePeer(ctx, LockingStrengthUpdate, account.Id, peer) @@ -472,7 +473,7 @@ func TestSqlStore_SavePeer(t *testing.T) { } func TestSqlStore_SavePeerStatus(t *testing.T) { - store, cleanUp, err := NewTestStoreFromSQL(context.Background(), "testdata/store.sql", t.TempDir()) + store, cleanUp, err := NewTestStoreFromSQL(context.Background(), "../testdata/store.sql", t.TempDir()) t.Cleanup(cleanUp) assert.NoError(t, err) @@ -528,7 +529,7 @@ func TestSqlStore_SavePeerStatus(t *testing.T) { } func TestSqlStore_SavePeerLocation(t *testing.T) { - store, cleanUp, err := NewTestStoreFromSQL(context.Background(), "testdata/store.sql", t.TempDir()) + store, cleanUp, err := 
NewTestStoreFromSQL(context.Background(), "../testdata/store.sql", t.TempDir()) t.Cleanup(cleanUp) assert.NoError(t, err) @@ -544,7 +545,8 @@ func TestSqlStore_SavePeerLocation(t *testing.T) { CityName: "City", GeoNameID: 1, }, - Meta: nbpeer.PeerSystemMeta{}, + CreatedAt: time.Now().UTC(), + Meta: nbpeer.PeerSystemMeta{}, } // error is expected as peer is not in store yet err = store.SavePeerLocation(context.Background(), LockingStrengthUpdate, account.Id, peer) @@ -2517,7 +2519,7 @@ func TestSqlStore_AddAndRemoveResourceFromGroup(t *testing.T) { } func TestSqlStore_AddPeerToGroup(t *testing.T) { - store, cleanup, err := NewTestStoreFromSQL(context.Background(), "testdata/store_policy_migrate.sql", t.TempDir()) + store, cleanup, err := NewTestStoreFromSQL(context.Background(), "../testdata/store_policy_migrate.sql", t.TempDir()) t.Cleanup(cleanup) require.NoError(t, err) @@ -2539,7 +2541,7 @@ func TestSqlStore_AddPeerToGroup(t *testing.T) { } func TestSqlStore_AddPeerToAllGroup(t *testing.T) { - store, cleanup, err := NewTestStoreFromSQL(context.Background(), "testdata/store_policy_migrate.sql", t.TempDir()) + store, cleanup, err := NewTestStoreFromSQL(context.Background(), "../testdata/store_policy_migrate.sql", t.TempDir()) t.Cleanup(cleanup) require.NoError(t, err) @@ -2570,7 +2572,7 @@ func TestSqlStore_AddPeerToAllGroup(t *testing.T) { } func TestSqlStore_AddPeerToAccount(t *testing.T) { - store, cleanup, err := NewTestStoreFromSQL(context.Background(), "testdata/store_policy_migrate.sql", t.TempDir()) + store, cleanup, err := NewTestStoreFromSQL(context.Background(), "../testdata/store_policy_migrate.sql", t.TempDir()) t.Cleanup(cleanup) require.NoError(t, err) @@ -2634,7 +2636,7 @@ func TestSqlStore_AddPeerToAccount(t *testing.T) { } func TestSqlStore_GetAccountPeers(t *testing.T) { - store, cleanup, err := NewTestStoreFromSQL(context.Background(), "testdata/store_with_expired_peers.sql", t.TempDir()) + store, cleanup, err := NewTestStoreFromSQL(context.Background(), "../testdata/store_with_expired_peers.sql", t.TempDir()) t.Cleanup(cleanup) require.NoError(t, err) @@ -2671,7 +2673,7 @@ func TestSqlStore_GetAccountPeers(t *testing.T) { } func TestSqlStore_GetAccountPeersWithExpiration(t *testing.T) { - store, cleanup, err := NewTestStoreFromSQL(context.Background(), "testdata/store_with_expired_peers.sql", t.TempDir()) + store, cleanup, err := NewTestStoreFromSQL(context.Background(), "../testdata/store_with_expired_peers.sql", t.TempDir()) t.Cleanup(cleanup) require.NoError(t, err) @@ -2707,7 +2709,7 @@ func TestSqlStore_GetAccountPeersWithExpiration(t *testing.T) { } func TestSqlStore_GetAccountPeersWithInactivity(t *testing.T) { - store, cleanup, err := NewTestStoreFromSQL(context.Background(), "testdata/store_with_expired_peers.sql", t.TempDir()) + store, cleanup, err := NewTestStoreFromSQL(context.Background(), "../testdata/store_with_expired_peers.sql", t.TempDir()) t.Cleanup(cleanup) require.NoError(t, err) @@ -2743,7 +2745,7 @@ func TestSqlStore_GetAccountPeersWithInactivity(t *testing.T) { } func TestSqlStore_GetAllEphemeralPeers(t *testing.T) { - store, cleanup, err := NewTestStoreFromSQL(context.Background(), "testdata/storev1.sql", t.TempDir()) + store, cleanup, err := NewTestStoreFromSQL(context.Background(), "../testdata/storev1.sql", t.TempDir()) t.Cleanup(cleanup) require.NoError(t, err) @@ -2754,7 +2756,7 @@ func TestSqlStore_GetAllEphemeralPeers(t *testing.T) { } func TestSqlStore_GetUserPeers(t *testing.T) { - store, cleanup, err := 
NewTestStoreFromSQL(context.Background(), "testdata/store_with_expired_peers.sql", t.TempDir()) + store, cleanup, err := NewTestStoreFromSQL(context.Background(), "../testdata/store_with_expired_peers.sql", t.TempDir()) t.Cleanup(cleanup) require.NoError(t, err) @@ -2806,7 +2808,7 @@ func TestSqlStore_GetUserPeers(t *testing.T) { } func TestSqlStore_DeletePeer(t *testing.T) { - store, cleanup, err := NewTestStoreFromSQL(context.Background(), "testdata/store_with_expired_peers.sql", t.TempDir()) + store, cleanup, err := NewTestStoreFromSQL(context.Background(), "../testdata/store_with_expired_peers.sql", t.TempDir()) t.Cleanup(cleanup) require.NoError(t, err) diff --git a/management/server/testdata/storev1.sql b/management/server/testdata/storev1.sql index 281fdac8a..cda333d4f 100644 --- a/management/server/testdata/storev1.sql +++ b/management/server/testdata/storev1.sql @@ -1,6 +1,6 @@ CREATE TABLE `accounts` (`id` text,`created_by` text,`created_at` datetime,`domain` text,`domain_category` text,`is_domain_primary_account` numeric,`network_identifier` text,`network_net` text,`network_dns` text,`network_serial` integer,`dns_settings_disabled_management_groups` text,`settings_peer_login_expiration_enabled` numeric,`settings_peer_login_expiration` integer,`settings_regular_users_view_blocked` numeric,`settings_groups_propagation_enabled` numeric,`settings_jwt_groups_enabled` numeric,`settings_jwt_groups_claim_name` text,`settings_jwt_allow_groups` text,`settings_extra_peer_approval_enabled` numeric,`settings_extra_integrated_validator_groups` text,PRIMARY KEY (`id`)); CREATE TABLE `setup_keys` (`id` text,`account_id` text,`key` text,`name` text,`type` text,`created_at` datetime,`expires_at` datetime,`updated_at` datetime,`revoked` numeric,`used_times` integer,`last_used` datetime,`auto_groups` text,`usage_limit` integer,`ephemeral` numeric,PRIMARY KEY (`id`),CONSTRAINT `fk_accounts_setup_keys_g` FOREIGN KEY (`account_id`) REFERENCES `accounts`(`id`)); -CREATE TABLE `peers` (`id` text,`account_id` text,`key` text,`setup_key` text,`ip` text,`meta_hostname` text,`meta_go_os` text,`meta_kernel` text,`meta_core` text,`meta_platform` text,`meta_os` text,`meta_os_version` text,`meta_wt_version` text,`meta_ui_version` text,`meta_kernel_version` text,`meta_network_addresses` text,`meta_system_serial_number` text,`meta_system_product_name` text,`meta_system_manufacturer` text,`meta_environment` text,`meta_files` text,`name` text,`dns_label` text,`peer_status_last_seen` datetime,`peer_status_connected` numeric,`peer_status_login_expired` numeric,`peer_status_requires_approval` numeric,`user_id` text,`ssh_key` text,`ssh_enabled` numeric,`login_expiration_enabled` numeric,`last_login` datetime,`created_at` datetime,`ephemeral` numeric,`location_connection_ip` text,`location_country_code` text,`location_city_name` text,`location_geo_name_id` integer,PRIMARY KEY (`id`),CONSTRAINT `fk_accounts_peers_g` FOREIGN KEY (`account_id`) REFERENCES `accounts`(`id`)); +CREATE TABLE `peers` (`id` text,`account_id` text,`key` text,`setup_key` text,`ip` text,`meta_hostname` text,`meta_go_os` text,`meta_kernel` text,`meta_core` text,`meta_platform` text,`meta_os` text,`meta_os_version` text,`meta_wt_version` text,`meta_ui_version` text,`meta_kernel_version` text,`meta_network_addresses` text,`meta_system_serial_number` text,`meta_system_product_name` text,`meta_system_manufacturer` text,`meta_environment` text,`meta_files` text,`name` text,`dns_label` text,`peer_status_last_seen` datetime,`peer_status_connected` 
numeric,`peer_status_login_expired` numeric,`peer_status_requires_approval` numeric,`user_id` text,`ssh_key` text,`ssh_enabled` numeric,`login_expiration_enabled` numeric,`last_login` datetime DEFAULT NULL,`created_at` datetime,`ephemeral` numeric,`location_connection_ip` text,`location_country_code` text,`location_city_name` text,`location_geo_name_id` integer,PRIMARY KEY (`id`),CONSTRAINT `fk_accounts_peers_g` FOREIGN KEY (`account_id`) REFERENCES `accounts`(`id`)); CREATE TABLE `users` (`id` text,`account_id` text,`role` text,`is_service_user` numeric,`non_deletable` numeric,`service_user_name` text,`auto_groups` text,`blocked` numeric,`last_login` datetime,`created_at` datetime,`issued` text DEFAULT "api",`integration_ref_id` integer,`integration_ref_integration_type` text,PRIMARY KEY (`id`),CONSTRAINT `fk_accounts_users_g` FOREIGN KEY (`account_id`) REFERENCES `accounts`(`id`)); CREATE TABLE `personal_access_tokens` (`id` text,`user_id` text,`name` text,`hashed_token` text,`expiration_date` datetime,`created_by` text,`created_at` datetime,`last_used` datetime,PRIMARY KEY (`id`),CONSTRAINT `fk_users_pa_ts_g` FOREIGN KEY (`user_id`) REFERENCES `users`(`id`)); CREATE TABLE `groups` (`id` text,`account_id` text,`name` text,`issued` text,`peers` text,`integration_ref_id` integer,`integration_ref_integration_type` text,PRIMARY KEY (`id`),CONSTRAINT `fk_accounts_groups_g` FOREIGN KEY (`account_id`) REFERENCES `accounts`(`id`)); @@ -31,9 +31,9 @@ INSERT INTO setup_keys VALUES('831727121','auth0|61bf82ddeab084006aa1bccd','1B2B INSERT INTO setup_keys VALUES('1769568301','auth0|61bf82ddeab084006aa1bccd','EB51E9EB-A11F-4F6E-8E49-C982891B405A','Default key','reusable','2021-12-24 16:09:45.926073628+01:00','2022-01-23 16:09:45.926073628+01:00','2021-12-24 16:09:45.926073628+01:00',0,1,'2021-12-24 16:13:06.236748538+01:00','[]',0,0); INSERT INTO setup_keys VALUES('2485964613','google-oauth2|103201118415301331038','5AFB60DB-61F2-4251-8E11-494847EE88E9','Default key','reusable','2021-12-24 16:10:02.238476+01:00','2022-01-23 16:10:02.238476+01:00','2021-12-24 16:10:02.238476+01:00',0,1,'2021-12-24 16:12:05.994307717+01:00','[]',0,0); INSERT INTO setup_keys VALUES('3504804807','google-oauth2|103201118415301331038','A72E4DC2-00DE-4542-8A24-62945438104E','One-off key','one-off','2021-12-24 16:10:02.238478209+01:00','2022-01-23 16:10:02.238478209+01:00','2021-12-24 16:10:02.238478209+01:00',0,1,'2021-12-24 16:11:27.015741738+01:00','[]',0,0); -INSERT INTO peers VALUES('oMNaI8qWi0CyclSuwGR++SurxJyM3pQEiPEHwX8IREo=','auth0|61bf82ddeab084006aa1bccd','oMNaI8qWi0CyclSuwGR++SurxJyM3pQEiPEHwX8IREo=','EB51E9EB-A11F-4F6E-8E49-C982891B405A','"100.64.0.2"','braginini','linux','Linux','21.04','x86_64','Ubuntu','','','','',NULL,'','','','{"Cloud":"","Platform":""}',NULL,'braginini','braginini','2021-12-24 16:13:11.244342541+01:00',0,0,0,'','',0,0,'0001-01-01 00:00:00+00:00','2024-10-02 17:00:54.182618+02:00',0,'""','','',0); -INSERT INTO peers VALUES('xlx9/9D8+ibnRiIIB8nHGMxGOzxV17r8ShPHgi4aYSM=','auth0|61bf82ddeab084006aa1bccd','xlx9/9D8+ibnRiIIB8nHGMxGOzxV17r8ShPHgi4aYSM=','1B2B50B0-B3E8-4B0C-A426-525EDB8481BD','"100.64.0.1"','braginini','linux','Linux','21.04','x86_64','Ubuntu','','','','',NULL,'','','','{"Cloud":"","Platform":""}',NULL,'braginini','braginini-1','2021-12-24 16:12:49.089339333+01:00',0,0,0,'','',0,0,'0001-01-01 00:00:00+00:00','2024-10-02 17:00:54.182618+02:00',0,'""','','',0); -INSERT INTO peers 
VALUES('6kjbmVq1hmucVzvBXo5OucY5OYv+jSsB1jUTLq291Dw=','google-oauth2|103201118415301331038','6kjbmVq1hmucVzvBXo5OucY5OYv+jSsB1jUTLq291Dw=','5AFB60DB-61F2-4251-8E11-494847EE88E9','"100.64.0.2"','braginini','linux','Linux','21.04','x86_64','Ubuntu','','','','',NULL,'','','','{"Cloud":"","Platform":""}',NULL,'braginini','braginini','2021-12-24 16:12:05.994305438+01:00',0,0,0,'','',0,0,'0001-01-01 00:00:00+00:00','2024-10-02 17:00:54.228182+02:00',0,'""','','',0); -INSERT INTO peers VALUES('Ok+5QMdt/UjoktNOvicGYj+IX2g98p+0N2PJ3vJ45RI=','google-oauth2|103201118415301331038','Ok+5QMdt/UjoktNOvicGYj+IX2g98p+0N2PJ3vJ45RI=','A72E4DC2-00DE-4542-8A24-62945438104E','"100.64.0.1"','braginini','linux','Linux','21.04','x86_64','Ubuntu','','','','',NULL,'','','','{"Cloud":"","Platform":""}',NULL,'braginini','braginini-1','2021-12-24 16:11:27.015739803+01:00',0,0,0,'','',0,0,'0001-01-01 00:00:00+00:00','2024-10-02 17:00:54.228182+02:00',1,'""','','',0); +INSERT INTO peers VALUES('oMNaI8qWi0CyclSuwGR++SurxJyM3pQEiPEHwX8IREo=','auth0|61bf82ddeab084006aa1bccd','oMNaI8qWi0CyclSuwGR++SurxJyM3pQEiPEHwX8IREo=','EB51E9EB-A11F-4F6E-8E49-C982891B405A','"100.64.0.2"','braginini','linux','Linux','21.04','x86_64','Ubuntu','','','','',NULL,'','','','{"Cloud":"","Platform":""}',NULL,'braginini','braginini','2021-12-24 16:13:11.244342541+01:00',0,0,0,'','',0,0,NULL,'2024-10-02 17:00:54.182618+02:00',0,'""','','',0); +INSERT INTO peers VALUES('xlx9/9D8+ibnRiIIB8nHGMxGOzxV17r8ShPHgi4aYSM=','auth0|61bf82ddeab084006aa1bccd','xlx9/9D8+ibnRiIIB8nHGMxGOzxV17r8ShPHgi4aYSM=','1B2B50B0-B3E8-4B0C-A426-525EDB8481BD','"100.64.0.1"','braginini','linux','Linux','21.04','x86_64','Ubuntu','','','','',NULL,'','','','{"Cloud":"","Platform":""}',NULL,'braginini','braginini-1','2021-12-24 16:12:49.089339333+01:00',0,0,0,'','',0,0,NULL,'2024-10-02 17:00:54.182618+02:00',0,'""','','',0); +INSERT INTO peers VALUES('6kjbmVq1hmucVzvBXo5OucY5OYv+jSsB1jUTLq291Dw=','google-oauth2|103201118415301331038','6kjbmVq1hmucVzvBXo5OucY5OYv+jSsB1jUTLq291Dw=','5AFB60DB-61F2-4251-8E11-494847EE88E9','"100.64.0.2"','braginini','linux','Linux','21.04','x86_64','Ubuntu','','','','',NULL,'','','','{"Cloud":"","Platform":""}',NULL,'braginini','braginini','2021-12-24 16:12:05.994305438+01:00',0,0,0,'','',0,0,NULL,'2024-10-02 17:00:54.228182+02:00',0,'""','','',0); +INSERT INTO peers VALUES('Ok+5QMdt/UjoktNOvicGYj+IX2g98p+0N2PJ3vJ45RI=','google-oauth2|103201118415301331038','Ok+5QMdt/UjoktNOvicGYj+IX2g98p+0N2PJ3vJ45RI=','A72E4DC2-00DE-4542-8A24-62945438104E','"100.64.0.1"','braginini','linux','Linux','21.04','x86_64','Ubuntu','','','','',NULL,'','','','{"Cloud":"","Platform":""}',NULL,'braginini','braginini-1','2021-12-24 16:11:27.015739803+01:00',0,0,0,'','',0,0,NULL,'2024-10-02 17:00:54.228182+02:00',1,'""','','',0); INSERT INTO installations VALUES(1,''); From 649bfb236bb820a2c67940029a0f677d155b6cfb Mon Sep 17 00:00:00 2001 From: Maycon Santos Date: Fri, 10 Jan 2025 09:44:02 +0100 Subject: [PATCH 120/159] [management] Send relay credentials with turn updates (#3164) send relay credentials when sending turn credentials update to avoid removing servers from clients --- management/server/peer_test.go | 5 +++-- management/server/token_mgr.go | 17 ++++++++++++++--- management/server/token_mgr_test.go | 13 ++++++++----- 3 files changed, 25 insertions(+), 10 deletions(-) diff --git a/management/server/peer_test.go b/management/server/peer_test.go index 5f500c226..0c751e6c4 100644 --- a/management/server/peer_test.go +++ b/management/server/peer_test.go @@ -13,13 +13,14 @@ 
import ( "testing" "time" - "github.com/netbirdio/netbird/management/server/util" "github.com/rs/xid" log "github.com/sirupsen/logrus" "github.com/stretchr/testify/assert" "github.com/stretchr/testify/require" "golang.zx2c4.com/wireguard/wgctrl/wgtypes" + "github.com/netbirdio/netbird/management/server/util" + resourceTypes "github.com/netbirdio/netbird/management/server/networks/resources/types" routerTypes "github.com/netbirdio/netbird/management/server/networks/routers/types" networkTypes "github.com/netbirdio/netbird/management/server/networks/types" @@ -937,7 +938,7 @@ func BenchmarkUpdateAccountPeers(b *testing.B) { {"Small single", 50, 10, 90, 120, 90, 120}, {"Medium single", 500, 10, 110, 170, 120, 200}, {"Large 5", 5000, 15, 1300, 2100, 4900, 7000}, - {"Extra Large", 2000, 2000, 1300, 2400, 4000, 6400}, + {"Extra Large", 2000, 2000, 1300, 2400, 3900, 6400}, } log.SetOutput(io.Discard) diff --git a/management/server/token_mgr.go b/management/server/token_mgr.go index ef8276b59..fd67fa3e3 100644 --- a/management/server/token_mgr.go +++ b/management/server/token_mgr.go @@ -158,7 +158,7 @@ func (m *TimeBasedAuthSecretsManager) refreshTURNTokens(ctx context.Context, pee log.WithContext(ctx).Debugf("stopping TURN refresh for %s", peerID) return case <-ticker.C: - m.pushNewTURNTokens(ctx, peerID) + m.pushNewTURNAndRelayTokens(ctx, peerID) } } } @@ -178,7 +178,7 @@ func (m *TimeBasedAuthSecretsManager) refreshRelayTokens(ctx context.Context, pe } } -func (m *TimeBasedAuthSecretsManager) pushNewTURNTokens(ctx context.Context, peerID string) { +func (m *TimeBasedAuthSecretsManager) pushNewTURNAndRelayTokens(ctx context.Context, peerID string) { turnToken, err := m.turnHmacToken.GenerateToken(sha1.New) if err != nil { log.Errorf("failed to generate token for peer '%s': %s", peerID, err) @@ -201,10 +201,21 @@ func (m *TimeBasedAuthSecretsManager) pushNewTURNTokens(ctx context.Context, pee update := &proto.SyncResponse{ WiretrusteeConfig: &proto.WiretrusteeConfig{ Turns: turns, - // omit Relay to avoid updates there }, } + // workaround for the case when client is unable to handle turn and relay updates at different time + if m.relayCfg != nil { + token, err := m.GenerateRelayToken() + if err == nil { + update.WiretrusteeConfig.Relay = &proto.RelayConfig{ + Urls: m.relayCfg.Addresses, + TokenPayload: token.Payload, + TokenSignature: token.Signature, + } + } + } + log.WithContext(ctx).Debugf("sending new TURN credentials to peer %s", peerID) m.updateManager.SendUpdate(ctx, peerID, &UpdateMessage{Update: update}) } diff --git a/management/server/token_mgr_test.go b/management/server/token_mgr_test.go index 3e63346c2..2aafb9f68 100644 --- a/management/server/token_mgr_test.go +++ b/management/server/token_mgr_test.go @@ -133,11 +133,14 @@ loop: } } if relay := update.Update.GetWiretrusteeConfig().GetRelay(); relay != nil { - relayUpdates++ - if relayUpdates == 1 { - firstRelayUpdate = relay - } else { - secondRelayUpdate = relay + // avoid updating on turn updates since they also send relay credentials + if update.Update.GetWiretrusteeConfig().GetTurns() == nil { + relayUpdates++ + if relayUpdates == 1 { + firstRelayUpdate = relay + } else { + secondRelayUpdate = relay + } } } } From 93f3e1b14b6f2b70b535b547a2c4fd308f98b209 Mon Sep 17 00:00:00 2001 From: Viktor Liu <17948409+lixmal@users.noreply.github.com> Date: Fri, 10 Jan 2025 11:02:05 +0100 Subject: [PATCH 121/159] [client] Prevent local routes in status from being overridden by updates (#3166) --- client/internal/engine.go | 13 +++++++------ 
client/internal/peer/status.go | 8 +++++++- 2 files changed, 14 insertions(+), 7 deletions(-) diff --git a/client/internal/engine.go b/client/internal/engine.go index 7b6f269df..b50532b7d 100644 --- a/client/internal/engine.go +++ b/client/internal/engine.go @@ -763,12 +763,13 @@ func (e *Engine) updateConfig(conf *mgmProto.PeerConfig) error { } } - e.statusRecorder.UpdateLocalPeerState(peer.LocalPeerState{ - IP: e.config.WgAddr, - PubKey: e.config.WgPrivateKey.PublicKey().String(), - KernelInterface: device.WireGuardModuleIsLoaded(), - FQDN: conf.GetFqdn(), - }) + state := e.statusRecorder.GetLocalPeerState() + state.IP = e.config.WgAddr + state.PubKey = e.config.WgPrivateKey.PublicKey().String() + state.KernelInterface = device.WireGuardModuleIsLoaded() + state.FQDN = conf.GetFqdn() + + e.statusRecorder.UpdateLocalPeerState(state) return nil } diff --git a/client/internal/peer/status.go b/client/internal/peer/status.go index dc461257a..0df2a2e81 100644 --- a/client/internal/peer/status.go +++ b/client/internal/peer/status.go @@ -84,6 +84,12 @@ type LocalPeerState struct { Routes map[string]struct{} } +// Clone returns a copy of the LocalPeerState +func (l LocalPeerState) Clone() LocalPeerState { + l.Routes = maps.Clone(l.Routes) + return l +} + // SignalState contains the latest state of a signal connection type SignalState struct { URL string @@ -501,7 +507,7 @@ func (d *Status) GetPeerStateChangeNotifier(peer string) <-chan struct{} { func (d *Status) GetLocalPeerState() LocalPeerState { d.mux.Lock() defer d.mux.Unlock() - return d.localPeer + return d.localPeer.Clone() } // UpdateLocalPeerState updates local peer status From 2e596fbf1a29545129f4cbf4d8f2972f04ac5e58 Mon Sep 17 00:00:00 2001 From: bcmmbaga Date: Fri, 10 Jan 2025 13:37:38 +0300 Subject: [PATCH 122/159] use account object to get validated peers Signed-off-by: bcmmbaga --- management/server/peer.go | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/management/server/peer.go b/management/server/peer.go index c09e3d051..9a61b25ce 100644 --- a/management/server/peer.go +++ b/management/server/peer.go @@ -1095,9 +1095,9 @@ func (am *DefaultAccountManager) UpdateAccountPeers(ctx context.Context, account peers := account.GetPeers() - approvedPeersMap, err := am.GetValidatedPeers(ctx, account.Id) + approvedPeersMap, err := am.integratedPeerValidator.GetValidatedPeers(account.Id, account.Groups, account.Peers, account.Settings.Extra) if err != nil { - log.WithContext(ctx).Errorf("failed to send out updates to peers, failed to validate peer: %v", err) + log.WithContext(ctx).Errorf("failed to send out updates to peers, failed to get validate peers: %v", err) return } From f1ed8599fc62910391e12db04c3db31beebed8ec Mon Sep 17 00:00:00 2001 From: Simon Smith Date: Fri, 10 Jan 2025 17:16:11 +0000 Subject: [PATCH 123/159] [misc] add missing relay to docker-compose.yml.tmpl.traefik (#3163) --- .../docker-compose.yml.tmpl.traefik | 18 ++++++++++++++++++ 1 file changed, 18 insertions(+) diff --git a/infrastructure_files/docker-compose.yml.tmpl.traefik b/infrastructure_files/docker-compose.yml.tmpl.traefik index 7d51c4ffb..71471c3ef 100644 --- a/infrastructure_files/docker-compose.yml.tmpl.traefik +++ b/infrastructure_files/docker-compose.yml.tmpl.traefik @@ -50,6 +50,24 @@ services: - traefik.http.services.netbird-signal.loadbalancer.server.port=80 - traefik.http.services.netbird-signal.loadbalancer.server.scheme=h2c + # Relay + relay: + image: netbirdio/relay:$NETBIRD_RELAY_TAG + restart: unless-stopped + environment: + - 
NB_LOG_LEVEL=info + - NB_LISTEN_ADDRESS=:$NETBIRD_RELAY_PORT + - NB_EXPOSED_ADDRESS=$NETBIRD_RELAY_DOMAIN:$NETBIRD_RELAY_PORT + # todo: change to a secure secret + - NB_AUTH_SECRET=$NETBIRD_RELAY_AUTH_SECRET + ports: + - $NETBIRD_RELAY_PORT:$NETBIRD_RELAY_PORT + logging: + driver: "json-file" + options: + max-size: "500m" + max-file: "2" + # Management management: image: netbirdio/management:$NETBIRD_MANAGEMENT_TAG From f48e33b395d003f8535c1573a80ea556d95076a6 Mon Sep 17 00:00:00 2001 From: Viktor Liu <17948409+lixmal@users.noreply.github.com> Date: Fri, 10 Jan 2025 18:16:21 +0100 Subject: [PATCH 124/159] [client] Don't fail on v6 ops when disabled via kernel params (#3165) --- .../routemanager/systemops/systemops_linux.go | 24 +++++++++++++------ 1 file changed, 17 insertions(+), 7 deletions(-) diff --git a/client/internal/routemanager/systemops/systemops_linux.go b/client/internal/routemanager/systemops/systemops_linux.go index 9041cbf2d..1da92cc80 100644 --- a/client/internal/routemanager/systemops/systemops_linux.go +++ b/client/internal/routemanager/systemops/systemops_linux.go @@ -266,7 +266,7 @@ func addRoute(prefix netip.Prefix, nexthop Nexthop, tableID int) error { return fmt.Errorf("add gateway and device: %w", err) } - if err := netlink.RouteAdd(route); err != nil && !errors.Is(err, syscall.EEXIST) && !errors.Is(err, syscall.EAFNOSUPPORT) { + if err := netlink.RouteAdd(route); err != nil && !errors.Is(err, syscall.EEXIST) && !isOpErr(err) { return fmt.Errorf("netlink add route: %w", err) } @@ -289,7 +289,7 @@ func addUnreachableRoute(prefix netip.Prefix, tableID int) error { Dst: ipNet, } - if err := netlink.RouteAdd(route); err != nil && !errors.Is(err, syscall.EEXIST) && !errors.Is(err, syscall.EAFNOSUPPORT) { + if err := netlink.RouteAdd(route); err != nil && !errors.Is(err, syscall.EEXIST) && !isOpErr(err) { return fmt.Errorf("netlink add unreachable route: %w", err) } @@ -312,7 +312,7 @@ func removeUnreachableRoute(prefix netip.Prefix, tableID int) error { if err := netlink.RouteDel(route); err != nil && !errors.Is(err, syscall.ESRCH) && !errors.Is(err, syscall.ENOENT) && - !errors.Is(err, syscall.EAFNOSUPPORT) { + !isOpErr(err) { return fmt.Errorf("netlink remove unreachable route: %w", err) } @@ -338,7 +338,7 @@ func removeRoute(prefix netip.Prefix, nexthop Nexthop, tableID int) error { return fmt.Errorf("add gateway and device: %w", err) } - if err := netlink.RouteDel(route); err != nil && !errors.Is(err, syscall.ESRCH) && !errors.Is(err, syscall.EAFNOSUPPORT) { + if err := netlink.RouteDel(route); err != nil && !errors.Is(err, syscall.ESRCH) && !isOpErr(err) { return fmt.Errorf("netlink remove route: %w", err) } @@ -362,7 +362,7 @@ func flushRoutes(tableID, family int) error { routes[i].Dst = &net.IPNet{IP: net.IPv6zero, Mask: net.CIDRMask(0, 128)} } } - if err := netlink.RouteDel(&routes[i]); err != nil && !errors.Is(err, syscall.EAFNOSUPPORT) { + if err := netlink.RouteDel(&routes[i]); err != nil && !isOpErr(err) { result = multierror.Append(result, fmt.Errorf("failed to delete route %v from table %d: %w", routes[i], tableID, err)) } } @@ -450,7 +450,7 @@ func addRule(params ruleParams) error { rule.Invert = params.invert rule.SuppressPrefixlen = params.suppressPrefix - if err := netlink.RuleAdd(rule); err != nil && !errors.Is(err, syscall.EEXIST) && !errors.Is(err, syscall.EAFNOSUPPORT) { + if err := netlink.RuleAdd(rule); err != nil && !errors.Is(err, syscall.EEXIST) && !isOpErr(err) { return fmt.Errorf("add routing rule: %w", err) } @@ -467,7 +467,7 @@ func 
removeRule(params ruleParams) error { rule.Priority = params.priority rule.SuppressPrefixlen = params.suppressPrefix - if err := netlink.RuleDel(rule); err != nil && !errors.Is(err, syscall.ENOENT) && !errors.Is(err, syscall.EAFNOSUPPORT) { + if err := netlink.RuleDel(rule); err != nil && !errors.Is(err, syscall.ENOENT) && !isOpErr(err) { return fmt.Errorf("remove routing rule: %w", err) } @@ -509,3 +509,13 @@ func hasSeparateRouting() ([]netip.Prefix, error) { } return nil, ErrRoutingIsSeparate } + +func isOpErr(err error) bool { + // EAFTNOSUPPORT when ipv6 is disabled via sysctl, EOPNOTSUPP when disabled in boot options or otherwise not supported + if errors.Is(err, syscall.EAFNOSUPPORT) || errors.Is(err, syscall.EOPNOTSUPP) { + log.Debugf("route operation not supported: %v", err) + return true + } + + return false +} From 168ea9560e0f4bf50a78b68c962575fe17dfd626 Mon Sep 17 00:00:00 2001 From: Bethuel Mmbaga Date: Sat, 11 Jan 2025 15:19:30 +0300 Subject: [PATCH 125/159] [Management] Send peer network map when SSH status is toggled (#3172) --- management/server/peer.go | 35 ++++++++++++++++++++++++++++++++++- 1 file changed, 34 insertions(+), 1 deletion(-) diff --git a/management/server/peer.go b/management/server/peer.go index 6c4c2d100..bfa20bae2 100644 --- a/management/server/peer.go +++ b/management/server/peer.go @@ -202,7 +202,8 @@ func (am *DefaultAccountManager) UpdatePeer(ctx context.Context, accountID, user return nil, err } - if peer.SSHEnabled != update.SSHEnabled { + sshEnabledUpdated := peer.SSHEnabled != update.SSHEnabled + if sshEnabledUpdated { peer.SSHEnabled = update.SSHEnabled event := activity.PeerSSHEnabled if !update.SSHEnabled { @@ -275,6 +276,8 @@ func (am *DefaultAccountManager) UpdatePeer(ctx context.Context, accountID, user if peerLabelUpdated || requiresPeerUpdates { am.UpdateAccountPeers(ctx, accountID) + } else if sshEnabledUpdated { + am.UpdateAccountPeer(ctx, account, peer) } return peer, nil @@ -1064,6 +1067,36 @@ func (am *DefaultAccountManager) UpdateAccountPeers(ctx context.Context, account wg.Wait() } +// UpdateAccountPeer updates a single peer that belongs to an account. +// Should be called when changes need to be synced to a specific peer only. 
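+// Used, for example, when a peer's SSH setting is toggled: only the affected peer
+// needs the recalculated network map, so no account-wide fan-out is required.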
+func (am *DefaultAccountManager) UpdateAccountPeer(ctx context.Context, account *types.Account, peer *nbpeer.Peer) { + if !am.peersUpdateManager.HasChannel(peer.ID) { + log.WithContext(ctx).Tracef("peer %s doesn't have a channel, skipping network map update", peer.ID) + return + } + + approvedPeersMap, err := am.GetValidatedPeers(account) + if err != nil { + log.WithContext(ctx).Errorf("failed to send update to peer %s, failed to validate peers: %v", peer.ID, err) + return + } + + dnsCache := &DNSConfigCache{} + customZone := account.GetPeersCustomZone(ctx, am.dnsDomain) + resourcePolicies := account.GetResourcePoliciesMap() + routers := account.GetResourceRoutersMap() + + postureChecks, err := am.getPeerPostureChecks(account, peer.ID) + if err != nil { + log.WithContext(ctx).Errorf("failed to send update to peer %s, failed to get posture checks: %v", peer.ID, err) + return + } + + remotePeerNetworkMap := account.GetPeerNetworkMap(ctx, peer.ID, customZone, approvedPeersMap, resourcePolicies, routers, am.metrics.AccountManagerMetrics()) + update := toSyncResponse(ctx, nil, peer, nil, nil, remotePeerNetworkMap, am.GetDNSDomain(), postureChecks, dnsCache, account.Settings.RoutingPeerDNSResolutionEnabled) + am.peersUpdateManager.SendUpdate(ctx, peer.ID, &UpdateMessage{Update: update, NetworkMap: remotePeerNetworkMap}) +} + func ConvertSliceToMap(existingLabels []string) map[string]struct{} { labelMap := make(map[string]struct{}, len(existingLabels)) for _, label := range existingLabels { From 1cc88a21901a2ecc1da695b5a1bb81b81bdebcbd Mon Sep 17 00:00:00 2001 From: Maycon Santos Date: Sat, 11 Jan 2025 14:08:13 +0100 Subject: [PATCH 126/159] [management] adjust benchmark (#3168) --- management/server/account_test.go | 25 +++++++++++++------------ management/server/peer_test.go | 2 +- 2 files changed, 14 insertions(+), 13 deletions(-) diff --git a/management/server/account_test.go b/management/server/account_test.go index d8ceef0e7..e4f079507 100644 --- a/management/server/account_test.go +++ b/management/server/account_test.go @@ -16,6 +16,7 @@ import ( "time" "github.com/golang-jwt/jwt" + "github.com/netbirdio/netbird/management/server/util" resourceTypes "github.com/netbirdio/netbird/management/server/networks/resources/types" @@ -3021,12 +3022,12 @@ func BenchmarkSyncAndMarkPeer(b *testing.B) { minMsPerOpCICD float64 maxMsPerOpCICD float64 }{ - {"Small", 50, 5, 1, 3, 3, 14}, - {"Medium", 500, 100, 7, 13, 10, 80}, - {"Large", 5000, 200, 65, 80, 60, 220}, - {"Small single", 50, 10, 1, 3, 3, 70}, - {"Medium single", 500, 10, 7, 13, 10, 32}, - {"Large 5", 5000, 15, 65, 80, 60, 200}, + {"Small", 50, 5, 1, 3, 3, 19}, + {"Medium", 500, 100, 7, 13, 10, 90}, + {"Large", 5000, 200, 65, 80, 60, 240}, + {"Small single", 50, 10, 1, 3, 3, 80}, + {"Medium single", 500, 10, 7, 13, 10, 37}, + {"Large 5", 5000, 15, 65, 80, 60, 220}, } log.SetOutput(io.Discard) @@ -3088,12 +3089,12 @@ func BenchmarkLoginPeer_ExistingPeer(b *testing.B) { minMsPerOpCICD float64 maxMsPerOpCICD float64 }{ - {"Small", 50, 5, 102, 110, 102, 120}, - {"Medium", 500, 100, 105, 140, 105, 170}, - {"Large", 5000, 200, 160, 200, 160, 300}, - {"Small single", 50, 10, 102, 110, 102, 120}, - {"Medium single", 500, 10, 105, 140, 105, 170}, - {"Large 5", 5000, 15, 160, 200, 160, 270}, + {"Small", 50, 5, 102, 110, 102, 130}, + {"Medium", 500, 100, 105, 140, 105, 190}, + {"Large", 5000, 200, 160, 200, 160, 320}, + {"Small single", 50, 10, 102, 110, 102, 130}, + {"Medium single", 500, 10, 105, 140, 105, 190}, + {"Large 5", 5000, 15, 160, 200, 160, 
290}, } log.SetOutput(io.Discard) diff --git a/management/server/peer_test.go b/management/server/peer_test.go index 0c751e6c4..2f5d0e047 100644 --- a/management/server/peer_test.go +++ b/management/server/peer_test.go @@ -938,7 +938,7 @@ func BenchmarkUpdateAccountPeers(b *testing.B) { {"Small single", 50, 10, 90, 120, 90, 120}, {"Medium single", 500, 10, 110, 170, 120, 200}, {"Large 5", 5000, 15, 1300, 2100, 4900, 7000}, - {"Extra Large", 2000, 2000, 1300, 2400, 3900, 6400}, + {"Extra Large", 2000, 2000, 1300, 2400, 3800, 6400}, } log.SetOutput(io.Discard) From 3fce8485bb03fdd46df69c70fb8541979a25d096 Mon Sep 17 00:00:00 2001 From: Bethuel Mmbaga Date: Sat, 11 Jan 2025 22:09:29 +0300 Subject: [PATCH 127/159] Enabled new network resource and router by default (#3174) Signed-off-by: bcmmbaga --- management/server/http/handlers/networks/resources_handler.go | 1 + management/server/http/handlers/networks/routers_handler.go | 2 +- 2 files changed, 2 insertions(+), 1 deletion(-) diff --git a/management/server/http/handlers/networks/resources_handler.go b/management/server/http/handlers/networks/resources_handler.go index a0dc9a10d..6499bd652 100644 --- a/management/server/http/handlers/networks/resources_handler.go +++ b/management/server/http/handlers/networks/resources_handler.go @@ -123,6 +123,7 @@ func (h *resourceHandler) createResource(w http.ResponseWriter, r *http.Request) resource.NetworkID = mux.Vars(r)["networkId"] resource.AccountID = accountID + resource.Enabled = true resource, err = h.resourceManager.CreateResource(r.Context(), userID, resource) if err != nil { util.WriteError(r.Context(), err, w) diff --git a/management/server/http/handlers/networks/routers_handler.go b/management/server/http/handlers/networks/routers_handler.go index 2cf39a132..7ca95d902 100644 --- a/management/server/http/handlers/networks/routers_handler.go +++ b/management/server/http/handlers/networks/routers_handler.go @@ -85,7 +85,7 @@ func (h *routersHandler) createRouter(w http.ResponseWriter, r *http.Request) { router.NetworkID = networkID router.AccountID = accountID - + router.Enabled = true router, err = h.routersManager.CreateRouter(r.Context(), userID, router) if err != nil { util.WriteError(r.Context(), err, w) From e161a9289803d523384b83736febbc6522b5f2a6 Mon Sep 17 00:00:00 2001 From: Viktor Liu <17948409+lixmal@users.noreply.github.com> Date: Sun, 12 Jan 2025 16:29:25 +0100 Subject: [PATCH 128/159] [client] Update fyne dependency (#3155) --- go.mod | 10 +++++----- go.sum | 24 ++++++++++++------------ 2 files changed, 17 insertions(+), 17 deletions(-) diff --git a/go.mod b/go.mod index 9d7cb1fe4..0c6d6be99 100644 --- a/go.mod +++ b/go.mod @@ -30,7 +30,7 @@ require ( ) require ( - fyne.io/fyne/v2 v2.5.0 + fyne.io/fyne/v2 v2.5.3 fyne.io/systray v1.11.0 github.com/TheJumpCloud/jcapi-go v3.0.0+incompatible github.com/c-robinson/iplib v1.0.3 @@ -146,7 +146,7 @@ require ( github.com/felixge/httpsnoop v1.0.4 // indirect github.com/fredbi/uri v1.1.0 // indirect github.com/fyne-io/gl-js v0.0.0-20220119005834-d2da28d9ccfe // indirect - github.com/fyne-io/glfw-js v0.0.0-20240101223322-6e1efdc71b7a // indirect + github.com/fyne-io/glfw-js v0.0.0-20241126112943-313d8a0fe1d0 // indirect github.com/fyne-io/image v0.0.0-20220602074514-4956b0afb3d2 // indirect github.com/go-gl/gl v0.0.0-20211210172815-726fda9656d6 // indirect github.com/go-gl/glfw/v3.3/glfw v0.0.0-20240506104042-037f3cc74f2a // indirect @@ -155,8 +155,8 @@ require ( github.com/go-ole/go-ole v1.3.0 // indirect github.com/go-redis/redis/v8 
v8.11.5 // indirect github.com/go-sql-driver/mysql v1.8.1 // indirect - github.com/go-text/render v0.1.0 // indirect - github.com/go-text/typesetting v0.1.0 // indirect + github.com/go-text/render v0.2.0 // indirect + github.com/go-text/typesetting v0.2.0 // indirect github.com/gogo/protobuf v1.3.2 // indirect github.com/golang/groupcache v0.0.0-20210331224755-41bb18bfe9da // indirect github.com/google/btree v1.1.2 // indirect @@ -206,7 +206,7 @@ require ( github.com/prometheus/client_model v0.6.1 // indirect github.com/prometheus/common v0.53.0 // indirect github.com/prometheus/procfs v0.15.0 // indirect - github.com/rymdport/portal v0.2.2 // indirect + github.com/rymdport/portal v0.3.0 // indirect github.com/shoenig/go-m1cpu v0.1.6 // indirect github.com/spf13/cast v1.5.0 // indirect github.com/srwiley/oksvg v0.0.0-20221011165216-be6e8873101c // indirect diff --git a/go.sum b/go.sum index 8383475a4..f8b6c208b 100644 --- a/go.sum +++ b/go.sum @@ -50,8 +50,8 @@ dario.cat/mergo v1.0.0/go.mod h1:uNxQE+84aUszobStD9th8a29P2fMDhsBdgRYvZOxGmk= dmitri.shuralyov.com/gpu/mtl v0.0.0-20190408044501-666a987793e9/go.mod h1:H6x//7gZCb22OMCxBHrMx7a5I7Hp++hsVxbQ4BYO7hU= filippo.io/edwards25519 v1.1.0 h1:FNf4tywRC1HmFuKW5xopWpigGjJKiJSV0Cqo0cJWDaA= filippo.io/edwards25519 v1.1.0/go.mod h1:BxyFTGdWcka3PhytdK4V28tE5sGfRvvvRV7EaN4VDT4= -fyne.io/fyne/v2 v2.5.0 h1:lEjEIso0Vi4sJXYngIMoXOM6aUjqnPjK7pBpxRxG9aI= -fyne.io/fyne/v2 v2.5.0/go.mod h1:9D4oT3NWeG+MLi/lP7ItZZyujHC/qqMJpoGTAYX5Uqc= +fyne.io/fyne/v2 v2.5.3 h1:k6LjZx6EzRZhClsuzy6vucLZBstdH2USDGHSGWq8ly8= +fyne.io/fyne/v2 v2.5.3/go.mod h1:0GOXKqyvNwk3DLmsFu9v0oYM0ZcD1ysGnlHCerKoAmo= fyne.io/systray v1.11.0 h1:D9HISlxSkx+jHSniMBR6fCFOUjk1x/OOOJLa9lJYAKg= fyne.io/systray v1.11.0/go.mod h1:RVwqP9nYMo7h5zViCBHri2FgjXF7H2cub7MAq4NSoLs= github.com/AdaLogics/go-fuzz-headers v0.0.0-20230811130428-ced1acdcaa24 h1:bvDV9vkmnHYOMsOr4WLk+Vo07yKIzd94sVoIqshQ4bU= @@ -204,8 +204,8 @@ github.com/fsnotify/fsnotify v1.7.0 h1:8JEhPFa5W2WU7YfeZzPNqzMP6Lwt7L2715Ggo0nos github.com/fsnotify/fsnotify v1.7.0/go.mod h1:40Bi/Hjc2AVfZrqy+aj+yEI+/bRxZnMJyTJwOpGvigM= github.com/fyne-io/gl-js v0.0.0-20220119005834-d2da28d9ccfe h1:A/wiwvQ0CAjPkuJytaD+SsXkPU0asQ+guQEIg1BJGX4= github.com/fyne-io/gl-js v0.0.0-20220119005834-d2da28d9ccfe/go.mod h1:d4clgH0/GrRwWjRzJJQXxT/h1TyuNSfF/X64zb/3Ggg= -github.com/fyne-io/glfw-js v0.0.0-20240101223322-6e1efdc71b7a h1:ybgRdYvAHTn93HW79bLiBiJwVL4jVeyGQRZMgImoeWs= -github.com/fyne-io/glfw-js v0.0.0-20240101223322-6e1efdc71b7a/go.mod h1:gsGA2dotD4v0SR6PmPCYvS9JuOeMwAtmfvDE7mbYXMY= +github.com/fyne-io/glfw-js v0.0.0-20241126112943-313d8a0fe1d0 h1:/1YRWFv9bAWkoo3SuxpFfzpXH0D/bQnTjNXyF4ih7Os= +github.com/fyne-io/glfw-js v0.0.0-20241126112943-313d8a0fe1d0/go.mod h1:gsGA2dotD4v0SR6PmPCYvS9JuOeMwAtmfvDE7mbYXMY= github.com/fyne-io/image v0.0.0-20220602074514-4956b0afb3d2 h1:hnLq+55b7Zh7/2IRzWCpiTcAvjv/P8ERF+N7+xXbZhk= github.com/fyne-io/image v0.0.0-20220602074514-4956b0afb3d2/go.mod h1:eO7W361vmlPOrykIg+Rsh1SZ3tQBaOsfzZhsIOb/Lm0= github.com/ghodss/yaml v0.0.0-20150909031657-73d445a93680/go.mod h1:4dBDuWmgqj2HViK6kFavaiC9ZROes6MMH2rRYeMEF04= @@ -246,12 +246,12 @@ github.com/go-sql-driver/mysql v1.8.1/go.mod h1:wEBSXgmK//2ZFJyE+qWnIsVGmvmEKlqw github.com/go-task/slim-sprig v0.0.0-20210107165309-348f09dbbbc0/go.mod h1:fyg7847qk6SyHyPtNmDHnmrv/HOrqktSC+C9fM+CJOE= github.com/go-task/slim-sprig v0.0.0-20230315185526-52ccab3ef572 h1:tfuBGBXKqDEevZMzYi5KSi8KkcZtzBcTgAUUtapy0OI= github.com/go-task/slim-sprig v0.0.0-20230315185526-52ccab3ef572/go.mod 
h1:9Pwr4B2jHnOSGXyyzV8ROjYa2ojvAY6HCGYYfMoC3Ls= -github.com/go-text/render v0.1.0 h1:osrmVDZNHuP1RSu3pNG7Z77Sd2xSbcb/xWytAj9kyVs= -github.com/go-text/render v0.1.0/go.mod h1:jqEuNMenrmj6QRnkdpeaP0oKGFLDNhDkVKwGjsWWYU4= -github.com/go-text/typesetting v0.1.0 h1:vioSaLPYcHwPEPLT7gsjCGDCoYSbljxoHJzMnKwVvHw= -github.com/go-text/typesetting v0.1.0/go.mod h1:d22AnmeKq/on0HNv73UFriMKc4Ez6EqZAofLhAzpSzI= -github.com/go-text/typesetting-utils v0.0.0-20240329101916-eee87fb235a3 h1:levTnuLLUmpavLGbJYLJA7fQnKeS7P1eCdAlM+vReXk= -github.com/go-text/typesetting-utils v0.0.0-20240329101916-eee87fb235a3/go.mod h1:DDxDdQEnB70R8owOx3LVpEFvpMK9eeH1o2r0yZhFI9o= +github.com/go-text/render v0.2.0 h1:LBYoTmp5jYiJ4NPqDc2pz17MLmA3wHw1dZSVGcOdeAc= +github.com/go-text/render v0.2.0/go.mod h1:CkiqfukRGKJA5vZZISkjSYrcdtgKQWRa2HIzvwNN5SU= +github.com/go-text/typesetting v0.2.0 h1:fbzsgbmk04KiWtE+c3ZD4W2nmCRzBqrqQOvYlwAOdho= +github.com/go-text/typesetting v0.2.0/go.mod h1:2+owI/sxa73XA581LAzVuEBZ3WEEV2pXeDswCH/3i1I= +github.com/go-text/typesetting-utils v0.0.0-20240317173224-1986cbe96c66 h1:GUrm65PQPlhFSKjLPGOZNPNxLCybjzjYBzjfoBGaDUY= +github.com/go-text/typesetting-utils v0.0.0-20240317173224-1986cbe96c66/go.mod h1:DDxDdQEnB70R8owOx3LVpEFvpMK9eeH1o2r0yZhFI9o= github.com/godbus/dbus/v5 v5.0.4/go.mod h1:xhWf0FNVPg57R7Z0UbKHbJfkEywrmjJnf7w5xrFpKfA= github.com/godbus/dbus/v5 v5.1.0 h1:4KLkAxT3aOY8Li4FRJe/KvhoNFFxo0m6fNuFUO8QJUk= github.com/godbus/dbus/v5 v5.1.0/go.mod h1:xhWf0FNVPg57R7Z0UbKHbJfkEywrmjJnf7w5xrFpKfA= @@ -621,8 +621,8 @@ github.com/rs/xid v1.3.0/go.mod h1:trrq9SKmegXys3aeAKXMUTdJsYXVwGY3RLcfgqegfbg= github.com/russross/blackfriday/v2 v2.0.1/go.mod h1:+Rmxgy9KzJVeS9/2gXHxylqXiyQDYRxCVz55jmeOWTM= github.com/russross/blackfriday/v2 v2.1.0/go.mod h1:+Rmxgy9KzJVeS9/2gXHxylqXiyQDYRxCVz55jmeOWTM= github.com/ryanuber/columnize v0.0.0-20160712163229-9b3edd62028f/go.mod h1:sm1tb6uqfes/u+d4ooFouqFdy9/2g9QGwK3SQygK0Ts= -github.com/rymdport/portal v0.2.2 h1:P2Q/4k673zxdFAsbD8EESZ7psfuO6/4jNu6EDrDICkM= -github.com/rymdport/portal v0.2.2/go.mod h1:kFF4jslnJ8pD5uCi17brj/ODlfIidOxlgUDTO5ncnC4= +github.com/rymdport/portal v0.3.0 h1:QRHcwKwx3kY5JTQcsVhmhC3TGqGQb9LFghVNUy8AdB8= +github.com/rymdport/portal v0.3.0/go.mod h1:kFF4jslnJ8pD5uCi17brj/ODlfIidOxlgUDTO5ncnC4= github.com/sean-/seed v0.0.0-20170313163322-e2103e2c3529/go.mod h1:DxrIzT+xaE7yg65j358z/aeFdxmN0P9QXhEzd20vsDc= github.com/shirou/gopsutil/v3 v3.24.4 h1:dEHgzZXt4LMNm+oYELpzl9YCqV65Yr/6SfrvgRBtXeU= github.com/shirou/gopsutil/v3 v3.24.4/go.mod h1:lTd2mdiOspcqLgAnr9/nGi71NkeMpWKdmhuxm9GusH8= From 8154069e77896c90e9db6b863c4565aebb00f46a Mon Sep 17 00:00:00 2001 From: Maycon Santos Date: Mon, 13 Jan 2025 10:11:54 +0100 Subject: [PATCH 129/159] [misc] Skip docker step when fork PR (#3175) --- .github/workflows/golang-test-linux.yml | 8 ++++---- 1 file changed, 4 insertions(+), 4 deletions(-) diff --git a/.github/workflows/golang-test-linux.yml b/.github/workflows/golang-test-linux.yml index da1db5c03..5f7d7b4a3 100644 --- a/.github/workflows/golang-test-linux.yml +++ b/.github/workflows/golang-test-linux.yml @@ -13,7 +13,7 @@ concurrency: jobs: build-cache: runs-on: ubuntu-22.04 - steps: + steps: - name: Checkout code uses: actions/checkout@v4 @@ -183,7 +183,7 @@ jobs: run: git --no-pager diff --exit-code - name: Login to Docker hub - if: matrix.store == 'mysql' + if: matrix.store == 'mysql' && (github.repository == github.head.repo.full_name || !github.head_ref) uses: docker/login-action@v1 with: username: ${{ secrets.DOCKER_USER }} @@ -243,7 +243,7 @@ jobs: run: git 
--no-pager diff --exit-code - name: Login to Docker hub - if: matrix.store == 'mysql' + if: matrix.store == 'mysql' && (github.repository == github.head.repo.full_name || !github.head_ref) uses: docker/login-action@v1 with: username: ${{ secrets.DOCKER_USER }} @@ -303,7 +303,7 @@ jobs: run: git --no-pager diff --exit-code - name: Login to Docker hub - if: matrix.store == 'mysql' + if: matrix.store == 'mysql' && (github.repository == github.head.repo.full_name || !github.head_ref) uses: docker/login-action@v1 with: username: ${{ secrets.DOCKER_USER }} From 522dd44bfa3f32dcb39ef97aa33a073564b03e02 Mon Sep 17 00:00:00 2001 From: "Krzysztof Nazarewski (kdn)" Date: Mon, 13 Jan 2025 10:15:01 +0100 Subject: [PATCH 130/159] [client] make /var/lib/netbird paths configurable (#3084) - NB_STATE_DIR - NB_UNCLEAN_SHUTDOWN_RESOLV_FILE - NB_DNS_STATE_FILE --- client/configs/configs.go | 24 ++++++++++++++++++++++++ client/internal/dns/consts.go | 18 ++++++++++++++++++ client/internal/dns/consts_freebsd.go | 5 ----- client/internal/dns/consts_linux.go | 7 ------- client/internal/statemanager/path.go | 15 ++++----------- 5 files changed, 46 insertions(+), 23 deletions(-) create mode 100644 client/configs/configs.go create mode 100644 client/internal/dns/consts.go delete mode 100644 client/internal/dns/consts_freebsd.go delete mode 100644 client/internal/dns/consts_linux.go diff --git a/client/configs/configs.go b/client/configs/configs.go new file mode 100644 index 000000000..8f9c3ba28 --- /dev/null +++ b/client/configs/configs.go @@ -0,0 +1,24 @@ +package configs + +import ( + "os" + "path/filepath" + "runtime" +) + +var StateDir string + +func init() { + StateDir = os.Getenv("NB_STATE_DIR") + if StateDir != "" { + return + } + switch runtime.GOOS { + case "windows": + StateDir = filepath.Join(os.Getenv("PROGRAMDATA"), "Netbird") + case "darwin", "linux": + StateDir = "/var/lib/netbird" + case "freebsd", "openbsd", "netbsd", "dragonfly": + StateDir = "/var/db/netbird" + } +} diff --git a/client/internal/dns/consts.go b/client/internal/dns/consts.go new file mode 100644 index 000000000..b333d0808 --- /dev/null +++ b/client/internal/dns/consts.go @@ -0,0 +1,18 @@ +//go:build !android + +package dns + +import ( + "github.com/netbirdio/netbird/client/configs" + "os" + "path/filepath" +) + +var fileUncleanShutdownResolvConfLocation string + +func init() { + fileUncleanShutdownResolvConfLocation = os.Getenv("NB_UNCLEAN_SHUTDOWN_RESOLV_FILE") + if fileUncleanShutdownResolvConfLocation == "" { + fileUncleanShutdownResolvConfLocation = filepath.Join(configs.StateDir, "resolv.conf") + } +} diff --git a/client/internal/dns/consts_freebsd.go b/client/internal/dns/consts_freebsd.go deleted file mode 100644 index 64c8fe5eb..000000000 --- a/client/internal/dns/consts_freebsd.go +++ /dev/null @@ -1,5 +0,0 @@ -package dns - -const ( - fileUncleanShutdownResolvConfLocation = "/var/db/netbird/resolv.conf" -) diff --git a/client/internal/dns/consts_linux.go b/client/internal/dns/consts_linux.go deleted file mode 100644 index 15614b0c5..000000000 --- a/client/internal/dns/consts_linux.go +++ /dev/null @@ -1,7 +0,0 @@ -//go:build !android - -package dns - -const ( - fileUncleanShutdownResolvConfLocation = "/var/lib/netbird/resolv.conf" -) diff --git a/client/internal/statemanager/path.go b/client/internal/statemanager/path.go index 6cfd79a12..d232e5f0c 100644 --- a/client/internal/statemanager/path.go +++ b/client/internal/statemanager/path.go @@ -1,23 +1,16 @@ package statemanager import ( + 
"github.com/netbirdio/netbird/client/configs" "os" "path/filepath" - "runtime" ) // GetDefaultStatePath returns the path to the state file based on the operating system // It returns an empty string if the path cannot be determined. func GetDefaultStatePath() string { - switch runtime.GOOS { - case "windows": - return filepath.Join(os.Getenv("PROGRAMDATA"), "Netbird", "state.json") - case "darwin", "linux": - return "/var/lib/netbird/state.json" - case "freebsd", "openbsd", "netbsd", "dragonfly": - return "/var/db/netbird/state.json" + if path := os.Getenv("NB_DNS_STATE_FILE"); path != "" { + return path } - - return "" - + return filepath.Join(configs.StateDir, "state.json") } From d1e5d584f7cdf0ce6062a27827770a3243fa91d3 Mon Sep 17 00:00:00 2001 From: bcmmbaga Date: Mon, 13 Jan 2025 16:12:34 +0300 Subject: [PATCH 131/159] Fix merge Signed-off-by: bcmmbaga --- management/server/peer.go | 34 ++++++++++++++++++++++------------ 1 file changed, 22 insertions(+), 12 deletions(-) diff --git a/management/server/peer.go b/management/server/peer.go index 280f96a8f..5dfdda857 100644 --- a/management/server/peer.go +++ b/management/server/peer.go @@ -324,7 +324,7 @@ func (am *DefaultAccountManager) UpdatePeer(ctx context.Context, accountID, user if peerLabelChanged || requiresPeerUpdates { am.UpdateAccountPeers(ctx, accountID) } else if sshChanged { - am.UpdateAccountPeer(ctx, account, peer) + am.UpdateAccountPeer(ctx, accountID, peer.ID) } return peer, nil @@ -1095,8 +1095,6 @@ func (am *DefaultAccountManager) UpdateAccountPeers(ctx context.Context, account } }() - peers := account.GetPeers() - approvedPeersMap, err := am.integratedPeerValidator.GetValidatedPeers(account.Id, account.Groups, account.Peers, account.Settings.Extra) if err != nil { log.WithContext(ctx).Errorf("failed to send out updates to peers, failed to get validate peers: %v", err) @@ -1111,7 +1109,7 @@ func (am *DefaultAccountManager) UpdateAccountPeers(ctx context.Context, account resourcePolicies := account.GetResourcePoliciesMap() routers := account.GetResourceRoutersMap() - for _, peer := range peers { + for _, peer := range account.Peers { if !am.peersUpdateManager.HasChannel(peer.ID) { log.WithContext(ctx).Tracef("peer %s doesn't have a channel, skipping network map update", peer.ID) continue @@ -1140,15 +1138,27 @@ func (am *DefaultAccountManager) UpdateAccountPeers(ctx context.Context, account // UpdateAccountPeer updates a single peer that belongs to an account. // Should be called when changes need to be synced to a specific peer only. -func (am *DefaultAccountManager) UpdateAccountPeer(ctx context.Context, account *types.Account, peer *nbpeer.Peer) { - if !am.peersUpdateManager.HasChannel(peer.ID) { - log.WithContext(ctx).Tracef("peer %s doesn't have a channel, skipping network map update", peer.ID) +func (am *DefaultAccountManager) UpdateAccountPeer(ctx context.Context, accountId string, peerId string) { + account, err := am.requestBuffer.GetAccountWithBackpressure(ctx, accountId) + if err != nil { + log.WithContext(ctx).Errorf("failed to send out updates to peer %s. 
failed to get account: %v", peerId, err) return } - approvedPeersMap, err := am.GetValidatedPeers(account) + peer := account.GetPeer(peerId) + if peer == nil { + log.WithContext(ctx).Tracef("peer %s doesn't exists in account %s", peerId, accountId) + return + } + + if !am.peersUpdateManager.HasChannel(peerId) { + log.WithContext(ctx).Tracef("peer %s doesn't have a channel, skipping network map update", peerId) + return + } + + approvedPeersMap, err := am.integratedPeerValidator.GetValidatedPeers(account.Id, account.Groups, account.Peers, account.Settings.Extra) if err != nil { - log.WithContext(ctx).Errorf("failed to send update to peer %s, failed to validate peers: %v", peer.ID, err) + log.WithContext(ctx).Errorf("failed to send update to peer %s, failed to validate peers: %v", peerId, err) return } @@ -1157,13 +1167,13 @@ func (am *DefaultAccountManager) UpdateAccountPeer(ctx context.Context, account resourcePolicies := account.GetResourcePoliciesMap() routers := account.GetResourceRoutersMap() - postureChecks, err := am.getPeerPostureChecks(account, peer.ID) + postureChecks, err := am.getPeerPostureChecks(account, peerId) if err != nil { - log.WithContext(ctx).Errorf("failed to send update to peer %s, failed to get posture checks: %v", peer.ID, err) + log.WithContext(ctx).Errorf("failed to send update to peer %s, failed to get posture checks: %v", peerId, err) return } - remotePeerNetworkMap := account.GetPeerNetworkMap(ctx, peer.ID, customZone, approvedPeersMap, resourcePolicies, routers, am.metrics.AccountManagerMetrics()) + remotePeerNetworkMap := account.GetPeerNetworkMap(ctx, peerId, customZone, approvedPeersMap, resourcePolicies, routers, am.metrics.AccountManagerMetrics()) update := toSyncResponse(ctx, nil, peer, nil, nil, remotePeerNetworkMap, am.GetDNSDomain(), postureChecks, dnsCache, account.Settings.RoutingPeerDNSResolutionEnabled) am.peersUpdateManager.SendUpdate(ctx, peer.ID, &UpdateMessage{Update: update, NetworkMap: remotePeerNetworkMap}) } From 3cc6d3862d0c41fd9e1a568d98d333519914a28b Mon Sep 17 00:00:00 2001 From: bcmmbaga Date: Mon, 13 Jan 2025 17:52:39 +0300 Subject: [PATCH 132/159] Improve peer performance Signed-off-by: bcmmbaga --- management/server/peer.go | 23 +++-------------------- management/server/store/sql_store.go | 13 +++++++++++++ management/server/store/store.go | 1 + 3 files changed, 17 insertions(+), 20 deletions(-) diff --git a/management/server/peer.go b/management/server/peer.go index 5dfdda857..b5a77a674 100644 --- a/management/server/peer.go +++ b/management/server/peer.go @@ -953,7 +953,7 @@ func (am *DefaultAccountManager) getValidatedPeerWithMap(ctx context.Context, is return nil, nil, nil, err } - approvedPeersMap, err := am.GetValidatedPeers(ctx, account.Id) + approvedPeersMap, err := am.integratedPeerValidator.GetValidatedPeers(account.Id, account.Groups, account.Peers, account.Settings.Extra) if err != nil { return nil, nil, nil, err } @@ -1313,29 +1313,12 @@ func (am *DefaultAccountManager) getInactivePeers(ctx context.Context, accountID // GetPeerGroups returns groups that the peer is part of. func (am *DefaultAccountManager) GetPeerGroups(ctx context.Context, accountID, peerID string) ([]*types.Group, error) { - return getPeerGroups(ctx, am.Store, accountID, peerID) -} - -// getPeerGroups returns the IDs of the groups that the peer is part of. 
-func getPeerGroups(ctx context.Context, transaction store.Store, accountID, peerID string) ([]*types.Group, error) { - groups, err := transaction.GetAccountGroups(ctx, store.LockingStrengthShare, accountID) - if err != nil { - return nil, err - } - - peerGroups := make([]*types.Group, 0) - for _, group := range groups { - if slices.Contains(group.Peers, peerID) { - peerGroups = append(peerGroups, group) - } - } - - return peerGroups, nil + return am.Store.GetPeerGroups(ctx, store.LockingStrengthShare, accountID, peerID) } // getPeerGroupIDs returns the IDs of the groups that the peer is part of. func getPeerGroupIDs(ctx context.Context, transaction store.Store, accountID string, peerID string) ([]string, error) { - groups, err := getPeerGroups(ctx, transaction, accountID, peerID) + groups, err := transaction.GetPeerGroups(ctx, store.LockingStrengthShare, accountID, peerID) if err != nil { return nil, err } diff --git a/management/server/store/sql_store.go b/management/server/store/sql_store.go index 0344538e2..82934dd11 100644 --- a/management/server/store/sql_store.go +++ b/management/server/store/sql_store.go @@ -1218,6 +1218,19 @@ func (s *SqlStore) RemoveResourceFromGroup(ctx context.Context, accountId string return nil } +// GetPeerGroups retrieves all groups assigned to a specific peer in a given account. +func (s *SqlStore) GetPeerGroups(ctx context.Context, lockStrength LockingStrength, accountId string, peerId string) ([]*types.Group, error) { + var groups []*types.Group + query := s.db.Clauses(clause.Locking{Strength: string(lockStrength)}). + Find(&groups, "account_id = ? AND peers LIKE ?", accountId, fmt.Sprintf(`%%"%s"%%`, peerId)) + + if query.Error != nil { + return nil, query.Error + } + + return groups, nil +} + // GetAccountPeers retrieves peers for an account. 
func (s *SqlStore) GetAccountPeers(ctx context.Context, lockStrength LockingStrength, accountID string) ([]*nbpeer.Peer, error) { var peers []*nbpeer.Peer diff --git a/management/server/store/store.go b/management/server/store/store.go index d7799be41..245df1c3e 100644 --- a/management/server/store/store.go +++ b/management/server/store/store.go @@ -100,6 +100,7 @@ type Store interface { GetPeerLabelsInAccount(ctx context.Context, lockStrength LockingStrength, accountId string) ([]string, error) AddPeerToAllGroup(ctx context.Context, lockStrength LockingStrength, accountID string, peerID string) error AddPeerToGroup(ctx context.Context, lockStrength LockingStrength, accountId string, peerId string, groupID string) error + GetPeerGroups(ctx context.Context, lockStrength LockingStrength, accountId string, peerId string) ([]*types.Group, error) AddResourceToGroup(ctx context.Context, accountId string, groupID string, resource *types.Resource) error RemoveResourceFromGroup(ctx context.Context, accountId string, groupID string, resourceID string) error AddPeerToAccount(ctx context.Context, lockStrength LockingStrength, peer *nbpeer.Peer) error From 48af90c7705696b48ed3b5d935581ca62bd9cdfb Mon Sep 17 00:00:00 2001 From: bcmmbaga Date: Mon, 13 Jan 2025 19:18:18 +0300 Subject: [PATCH 133/159] Get account direct from store without buffer Signed-off-by: bcmmbaga --- management/server/peer.go | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/management/server/peer.go b/management/server/peer.go index b5a77a674..78bb1dd8c 100644 --- a/management/server/peer.go +++ b/management/server/peer.go @@ -948,7 +948,7 @@ func (am *DefaultAccountManager) getValidatedPeerWithMap(ctx context.Context, is return peer, emptyMap, nil, nil } - account, err := am.requestBuffer.GetAccountWithBackpressure(ctx, accountID) + account, err := am.Store.GetAccount(ctx, accountID) if err != nil { return nil, nil, nil, err } From c603c40a5322c8dbb775b552db811ab7063eb30c Mon Sep 17 00:00:00 2001 From: bcmmbaga Date: Mon, 13 Jan 2025 20:56:51 +0300 Subject: [PATCH 134/159] Add get peer groups tests Signed-off-by: bcmmbaga --- management/server/store/sql_store_test.go | 21 +++++++++++++++++++++ 1 file changed, 21 insertions(+) diff --git a/management/server/store/sql_store_test.go b/management/server/store/sql_store_test.go index 34578eca8..cb51dab51 100644 --- a/management/server/store/sql_store_test.go +++ b/management/server/store/sql_store_test.go @@ -2635,6 +2635,27 @@ func TestSqlStore_AddPeerToAccount(t *testing.T) { assert.WithinDurationf(t, peer.Status.LastSeen, storedPeer.Status.LastSeen.UTC(), time.Millisecond, "LastSeen should be equal") } +func TestSqlStore_GetPeerGroups(t *testing.T) { + store, cleanup, err := NewTestStoreFromSQL(context.Background(), "../testdata/store_policy_migrate.sql", t.TempDir()) + t.Cleanup(cleanup) + require.NoError(t, err) + + accountID := "bf1c8084-ba50-4ce7-9439-34653001fc3b" + peerID := "cfefqs706sqkneg59g4g" + + groups, err := store.GetPeerGroups(context.Background(), LockingStrengthShare, accountID, peerID) + require.NoError(t, err) + assert.Len(t, groups, 1) + assert.Equal(t, groups[0].Name, "All") + + err = store.AddPeerToGroup(context.Background(), LockingStrengthUpdate, accountID, peerID, "cfefqs706sqkneg59g4h") + require.NoError(t, err) + + groups, err = store.GetPeerGroups(context.Background(), LockingStrengthShare, accountID, peerID) + require.NoError(t, err) + assert.Len(t, groups, 2) +} + func TestSqlStore_GetAccountPeers(t *testing.T) { store, cleanup, err := 
NewTestStoreFromSQL(context.Background(), "../testdata/store_with_expired_peers.sql", t.TempDir()) t.Cleanup(cleanup) From 7a9c75db91ff6a3ee0d5c4a41972e186b67b38d5 Mon Sep 17 00:00:00 2001 From: bcmmbaga Date: Mon, 13 Jan 2025 23:19:30 +0300 Subject: [PATCH 135/159] Adjust benchmarks Signed-off-by: bcmmbaga --- management/server/account_test.go | 24 ++++++++++++------------ 1 file changed, 12 insertions(+), 12 deletions(-) diff --git a/management/server/account_test.go b/management/server/account_test.go index ce87987f3..6c51f896c 100644 --- a/management/server/account_test.go +++ b/management/server/account_test.go @@ -3082,12 +3082,12 @@ func BenchmarkLoginPeer_ExistingPeer(b *testing.B) { minMsPerOpCICD float64 maxMsPerOpCICD float64 }{ - {"Small", 50, 5, 102, 110, 102, 130}, - {"Medium", 500, 100, 105, 140, 105, 190}, - {"Large", 5000, 200, 160, 200, 160, 320}, - {"Small single", 50, 10, 102, 110, 102, 130}, - {"Medium single", 500, 10, 105, 140, 105, 190}, - {"Large 5", 5000, 15, 160, 200, 160, 290}, + {"Small", 50, 5, 102, 110, 3, 20}, + {"Medium", 500, 100, 105, 140, 20, 90}, + {"Large", 5000, 200, 160, 200, 120, 260}, + {"Small single", 50, 10, 102, 110, 5, 40}, + {"Medium single", 500, 10, 105, 140, 10, 60}, + {"Large 5", 5000, 15, 160, 200, 60, 140}, } log.SetOutput(io.Discard) @@ -3156,12 +3156,12 @@ func BenchmarkLoginPeer_NewPeer(b *testing.B) { minMsPerOpCICD float64 maxMsPerOpCICD float64 }{ - {"Small", 50, 5, 107, 120, 107, 160}, - {"Medium", 500, 100, 105, 140, 105, 220}, - {"Large", 5000, 200, 180, 220, 180, 395}, - {"Small single", 50, 10, 107, 120, 105, 160}, - {"Medium single", 500, 10, 105, 140, 105, 170}, - {"Large 5", 5000, 15, 180, 220, 180, 340}, + {"Small", 50, 5, 107, 120, 10, 50}, + {"Medium", 500, 100, 105, 140, 30, 140}, + {"Large", 5000, 200, 180, 220, 150, 300}, + {"Small single", 50, 10, 107, 120, 10, 60}, + {"Medium single", 500, 10, 105, 140, 20, 60}, + {"Large 5", 5000, 15, 180, 220, 80, 200}, } log.SetOutput(io.Discard) From eb062c07ec76c87e3d948be131559b3e4420c782 Mon Sep 17 00:00:00 2001 From: bcmmbaga Date: Tue, 14 Jan 2025 12:09:52 +0300 Subject: [PATCH 136/159] Adjust benchmarks Signed-off-by: bcmmbaga --- management/server/account_test.go | 8 ++++---- 1 file changed, 4 insertions(+), 4 deletions(-) diff --git a/management/server/account_test.go b/management/server/account_test.go index 6c51f896c..1fe20bbf5 100644 --- a/management/server/account_test.go +++ b/management/server/account_test.go @@ -3083,11 +3083,11 @@ func BenchmarkLoginPeer_ExistingPeer(b *testing.B) { maxMsPerOpCICD float64 }{ {"Small", 50, 5, 102, 110, 3, 20}, - {"Medium", 500, 100, 105, 140, 20, 90}, + {"Medium", 500, 100, 105, 140, 20, 110}, {"Large", 5000, 200, 160, 200, 120, 260}, {"Small single", 50, 10, 102, 110, 5, 40}, {"Medium single", 500, 10, 105, 140, 10, 60}, - {"Large 5", 5000, 15, 160, 200, 60, 140}, + {"Large 5", 5000, 15, 160, 200, 60, 180}, } log.SetOutput(io.Discard) @@ -3156,10 +3156,10 @@ func BenchmarkLoginPeer_NewPeer(b *testing.B) { minMsPerOpCICD float64 maxMsPerOpCICD float64 }{ - {"Small", 50, 5, 107, 120, 10, 50}, + {"Small", 50, 5, 107, 120, 10, 80}, {"Medium", 500, 100, 105, 140, 30, 140}, {"Large", 5000, 200, 180, 220, 150, 300}, - {"Small single", 50, 10, 107, 120, 10, 60}, + {"Small single", 50, 10, 107, 120, 10, 80}, {"Medium single", 500, 10, 105, 140, 20, 60}, {"Large 5", 5000, 15, 180, 220, 80, 200}, } From aa0480c5e673eae6b40c1fdf9720f94edc31043f Mon Sep 17 00:00:00 2001 From: Pascal Fischer <32096965+pascal-fischer@users.noreply.github.com> 
Date: Tue, 14 Jan 2025 15:14:56 +0100 Subject: [PATCH 137/159] [management] Update benchmark workflow (#3181) --- .github/workflows/golang-test-linux.yml | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/.github/workflows/golang-test-linux.yml b/.github/workflows/golang-test-linux.yml index 5f7d7b4a3..3c37ced94 100644 --- a/.github/workflows/golang-test-linux.yml +++ b/.github/workflows/golang-test-linux.yml @@ -314,7 +314,7 @@ jobs: run: docker pull mlsmaycon/warmed-mysql:8 - name: Test - run: CGO_ENABLED=1 GOARCH=${{ matrix.arch }} NETBIRD_STORE_ENGINE=${{ matrix.store }} CI=true go test -tags=benchmark -run=^$ -bench=. -exec 'sudo --preserve-env=CI,NETBIRD_STORE_ENGINE' -timeout 10m $(go list ./... | grep /management) + run: CGO_ENABLED=1 GOARCH=${{ matrix.arch }} NETBIRD_STORE_ENGINE=${{ matrix.store }} CI=true go test -run=^$ -tags=benchmark -bench=. -exec 'sudo --preserve-env=CI,NETBIRD_STORE_ENGINE' -timeout 10m $(go list -tags=benchmark ./... | grep /management) api_integration_test: needs: [ build-cache ] @@ -363,7 +363,7 @@ jobs: run: git --no-pager diff --exit-code - name: Test - run: CGO_ENABLED=1 GOARCH=${{ matrix.arch }} NETBIRD_STORE_ENGINE=${{ matrix.store }} CI=true go test -p 1 -exec 'sudo --preserve-env=CI,NETBIRD_STORE_ENGINE' -timeout 10m -tags=integration $(go list ./... | grep /management) + run: CGO_ENABLED=1 GOARCH=${{ matrix.arch }} NETBIRD_STORE_ENGINE=${{ matrix.store }} CI=true go test -tags=integration -p 1 -exec 'sudo --preserve-env=CI,NETBIRD_STORE_ENGINE' -timeout 10m $(go list -tags=integration ./... | grep /management) test_client_on_docker: needs: [ build-cache ] From 47a18db18601552d327f1d756a4231968f75ded5 Mon Sep 17 00:00:00 2001 From: Pascal Fischer Date: Tue, 14 Jan 2025 18:04:20 +0100 Subject: [PATCH 138/159] update local benchmark expectations --- .github/workflows/golang-test-linux.yml | 2 +- .../peers_handler_benchmark_test.go | 56 ++++++++-------- .../setupkeys_handler_benchmark_test.go | 64 +++++++++---------- .../users_handler_benchmark_test.go | 8 +-- 4 files changed, 65 insertions(+), 65 deletions(-) diff --git a/.github/workflows/golang-test-linux.yml b/.github/workflows/golang-test-linux.yml index 3c37ced94..ab3e8d652 100644 --- a/.github/workflows/golang-test-linux.yml +++ b/.github/workflows/golang-test-linux.yml @@ -363,7 +363,7 @@ jobs: run: git --no-pager diff --exit-code - name: Test - run: CGO_ENABLED=1 GOARCH=${{ matrix.arch }} NETBIRD_STORE_ENGINE=${{ matrix.store }} CI=true go test -tags=integration -p 1 -exec 'sudo --preserve-env=CI,NETBIRD_STORE_ENGINE' -timeout 10m $(go list -tags=integration ./... | grep /management) + run: CGO_ENABLED=1 GOARCH=${{ matrix.arch }} NETBIRD_STORE_ENGINE=${{ matrix.store }} CI=true go test -tags=integration -p 1 -exec 'sudo --preserve-env=CI,NETBIRD_STORE_ENGINE' -timeout 15m $(go list -tags=integration ./... 
| grep /management) test_client_on_docker: needs: [ build-cache ] diff --git a/management/server/http/testing/benchmarks/peers_handler_benchmark_test.go b/management/server/http/testing/benchmarks/peers_handler_benchmark_test.go index e76370426..141a759cb 100644 --- a/management/server/http/testing/benchmarks/peers_handler_benchmark_test.go +++ b/management/server/http/testing/benchmarks/peers_handler_benchmark_test.go @@ -35,14 +35,14 @@ var benchCasesPeers = map[string]testing_tools.BenchmarkCase{ func BenchmarkUpdatePeer(b *testing.B) { var expectedMetrics = map[string]testing_tools.PerformanceMetrics{ - "Peers - XS": {MinMsPerOpLocal: 1300, MaxMsPerOpLocal: 1700, MinMsPerOpCICD: 2200, MaxMsPerOpCICD: 13000}, + "Peers - XS": {MinMsPerOpLocal: 400, MaxMsPerOpLocal: 600, MinMsPerOpCICD: 2200, MaxMsPerOpCICD: 13000}, "Peers - S": {MinMsPerOpLocal: 100, MaxMsPerOpLocal: 130, MinMsPerOpCICD: 80, MaxMsPerOpCICD: 200}, - "Peers - M": {MinMsPerOpLocal: 160, MaxMsPerOpLocal: 190, MinMsPerOpCICD: 100, MaxMsPerOpCICD: 500}, - "Peers - L": {MinMsPerOpLocal: 400, MaxMsPerOpLocal: 430, MinMsPerOpCICD: 450, MaxMsPerOpCICD: 1400}, - "Groups - L": {MinMsPerOpLocal: 1200, MaxMsPerOpLocal: 1500, MinMsPerOpCICD: 1900, MaxMsPerOpCICD: 13000}, - "Users - L": {MinMsPerOpLocal: 600, MaxMsPerOpLocal: 800, MinMsPerOpCICD: 800, MaxMsPerOpCICD: 2800}, - "Setup Keys - L": {MinMsPerOpLocal: 400, MaxMsPerOpLocal: 700, MinMsPerOpCICD: 600, MaxMsPerOpCICD: 1300}, - "Peers - XL": {MinMsPerOpLocal: 1400, MaxMsPerOpLocal: 1900, MinMsPerOpCICD: 2200, MaxMsPerOpCICD: 5000}, + "Peers - M": {MinMsPerOpLocal: 130, MaxMsPerOpLocal: 150, MinMsPerOpCICD: 100, MaxMsPerOpCICD: 500}, + "Peers - L": {MinMsPerOpLocal: 230, MaxMsPerOpLocal: 270, MinMsPerOpCICD: 450, MaxMsPerOpCICD: 1400}, + "Groups - L": {MinMsPerOpLocal: 400, MaxMsPerOpLocal: 600, MinMsPerOpCICD: 1900, MaxMsPerOpCICD: 13000}, + "Users - L": {MinMsPerOpLocal: 200, MaxMsPerOpLocal: 400, MinMsPerOpCICD: 800, MaxMsPerOpCICD: 2800}, + "Setup Keys - L": {MinMsPerOpLocal: 200, MaxMsPerOpLocal: 400, MinMsPerOpCICD: 600, MaxMsPerOpCICD: 1300}, + "Peers - XL": {MinMsPerOpLocal: 600, MaxMsPerOpLocal: 1000, MinMsPerOpCICD: 2200, MaxMsPerOpCICD: 5000}, } log.SetOutput(io.Discard) @@ -77,14 +77,14 @@ func BenchmarkUpdatePeer(b *testing.B) { func BenchmarkGetOnePeer(b *testing.B) { var expectedMetrics = map[string]testing_tools.PerformanceMetrics{ - "Peers - XS": {MinMsPerOpLocal: 600, MaxMsPerOpLocal: 900, MinMsPerOpCICD: 1100, MaxMsPerOpCICD: 7000}, - "Peers - S": {MinMsPerOpLocal: 3, MaxMsPerOpLocal: 7, MinMsPerOpCICD: 2, MaxMsPerOpCICD: 30}, - "Peers - M": {MinMsPerOpLocal: 20, MaxMsPerOpLocal: 40, MinMsPerOpCICD: 35, MaxMsPerOpCICD: 80}, - "Peers - L": {MinMsPerOpLocal: 120, MaxMsPerOpLocal: 160, MinMsPerOpCICD: 100, MaxMsPerOpCICD: 300}, - "Groups - L": {MinMsPerOpLocal: 500, MaxMsPerOpLocal: 750, MinMsPerOpCICD: 900, MaxMsPerOpCICD: 6500}, - "Users - L": {MinMsPerOpLocal: 200, MaxMsPerOpLocal: 300, MinMsPerOpCICD: 200, MaxMsPerOpCICD: 600}, - "Setup Keys - L": {MinMsPerOpLocal: 200, MaxMsPerOpLocal: 300, MinMsPerOpCICD: 200, MaxMsPerOpCICD: 600}, - "Peers - XL": {MinMsPerOpLocal: 600, MaxMsPerOpLocal: 800, MinMsPerOpCICD: 600, MaxMsPerOpCICD: 1500}, + "Peers - XS": {MinMsPerOpLocal: 40, MaxMsPerOpLocal: 60, MinMsPerOpCICD: 1100, MaxMsPerOpCICD: 7000}, + "Peers - S": {MinMsPerOpLocal: 1, MaxMsPerOpLocal: 5, MinMsPerOpCICD: 2, MaxMsPerOpCICD: 30}, + "Peers - M": {MinMsPerOpLocal: 9, MaxMsPerOpLocal: 18, MinMsPerOpCICD: 35, MaxMsPerOpCICD: 80}, + "Peers - L": {MinMsPerOpLocal: 
40, MaxMsPerOpLocal: 90, MinMsPerOpCICD: 100, MaxMsPerOpCICD: 300}, + "Groups - L": {MinMsPerOpLocal: 80, MaxMsPerOpLocal: 130, MinMsPerOpCICD: 900, MaxMsPerOpCICD: 6500}, + "Users - L": {MinMsPerOpLocal: 40, MaxMsPerOpLocal: 90, MinMsPerOpCICD: 200, MaxMsPerOpCICD: 600}, + "Setup Keys - L": {MinMsPerOpLocal: 40, MaxMsPerOpLocal: 90, MinMsPerOpCICD: 200, MaxMsPerOpCICD: 600}, + "Peers - XL": {MinMsPerOpLocal: 200, MaxMsPerOpLocal: 400, MinMsPerOpCICD: 600, MaxMsPerOpCICD: 1500}, } log.SetOutput(io.Discard) @@ -111,13 +111,13 @@ func BenchmarkGetOnePeer(b *testing.B) { func BenchmarkGetAllPeers(b *testing.B) { var expectedMetrics = map[string]testing_tools.PerformanceMetrics{ - "Peers - XS": {MinMsPerOpLocal: 600, MaxMsPerOpLocal: 900, MinMsPerOpCICD: 1100, MaxMsPerOpCICD: 6000}, - "Peers - S": {MinMsPerOpLocal: 4, MaxMsPerOpLocal: 10, MinMsPerOpCICD: 7, MaxMsPerOpCICD: 30}, + "Peers - XS": {MinMsPerOpLocal: 40, MaxMsPerOpLocal: 70, MinMsPerOpCICD: 1100, MaxMsPerOpCICD: 6000}, + "Peers - S": {MinMsPerOpLocal: 2, MaxMsPerOpLocal: 10, MinMsPerOpCICD: 7, MaxMsPerOpCICD: 30}, "Peers - M": {MinMsPerOpLocal: 20, MaxMsPerOpLocal: 50, MinMsPerOpCICD: 40, MaxMsPerOpCICD: 90}, "Peers - L": {MinMsPerOpLocal: 130, MaxMsPerOpLocal: 170, MinMsPerOpCICD: 150, MaxMsPerOpCICD: 350}, - "Groups - L": {MinMsPerOpLocal: 5000, MaxMsPerOpLocal: 5500, MinMsPerOpCICD: 7000, MaxMsPerOpCICD: 15000}, - "Users - L": {MinMsPerOpLocal: 250, MaxMsPerOpLocal: 300, MinMsPerOpCICD: 250, MaxMsPerOpCICD: 700}, - "Setup Keys - L": {MinMsPerOpLocal: 250, MaxMsPerOpLocal: 350, MinMsPerOpCICD: 250, MaxMsPerOpCICD: 700}, + "Groups - L": {MinMsPerOpLocal: 4800, MaxMsPerOpLocal: 5300, MinMsPerOpCICD: 7000, MaxMsPerOpCICD: 15000}, + "Users - L": {MinMsPerOpLocal: 100, MaxMsPerOpLocal: 170, MinMsPerOpCICD: 250, MaxMsPerOpCICD: 700}, + "Setup Keys - L": {MinMsPerOpLocal: 100, MaxMsPerOpLocal: 170, MinMsPerOpCICD: 250, MaxMsPerOpCICD: 700}, "Peers - XL": {MinMsPerOpLocal: 900, MaxMsPerOpLocal: 1300, MinMsPerOpCICD: 1100, MaxMsPerOpCICD: 2200}, } @@ -145,14 +145,14 @@ func BenchmarkGetAllPeers(b *testing.B) { func BenchmarkDeletePeer(b *testing.B) { var expectedMetrics = map[string]testing_tools.PerformanceMetrics{ - "Peers - XS": {MinMsPerOpLocal: 600, MaxMsPerOpLocal: 800, MinMsPerOpCICD: 1100, MaxMsPerOpCICD: 7000}, - "Peers - S": {MinMsPerOpLocal: 20, MaxMsPerOpLocal: 50, MinMsPerOpCICD: 50, MaxMsPerOpCICD: 210}, - "Peers - M": {MinMsPerOpLocal: 20, MaxMsPerOpLocal: 50, MinMsPerOpCICD: 50, MaxMsPerOpCICD: 230}, - "Peers - L": {MinMsPerOpLocal: 20, MaxMsPerOpLocal: 50, MinMsPerOpCICD: 50, MaxMsPerOpCICD: 210}, - "Groups - L": {MinMsPerOpLocal: 400, MaxMsPerOpLocal: 550, MinMsPerOpCICD: 700, MaxMsPerOpCICD: 5500}, - "Users - L": {MinMsPerOpLocal: 170, MaxMsPerOpLocal: 210, MinMsPerOpCICD: 290, MaxMsPerOpCICD: 1700}, - "Setup Keys - L": {MinMsPerOpLocal: 30, MaxMsPerOpLocal: 125, MinMsPerOpCICD: 55, MaxMsPerOpCICD: 280}, - "Peers - XL": {MinMsPerOpLocal: 30, MaxMsPerOpLocal: 50, MinMsPerOpCICD: 60, MaxMsPerOpCICD: 250}, + "Peers - XS": {MinMsPerOpLocal: 0, MaxMsPerOpLocal: 30, MinMsPerOpCICD: 1100, MaxMsPerOpCICD: 7000}, + "Peers - S": {MinMsPerOpLocal: 0, MaxMsPerOpLocal: 30, MinMsPerOpCICD: 50, MaxMsPerOpCICD: 210}, + "Peers - M": {MinMsPerOpLocal: 0, MaxMsPerOpLocal: 30, MinMsPerOpCICD: 50, MaxMsPerOpCICD: 230}, + "Peers - L": {MinMsPerOpLocal: 0, MaxMsPerOpLocal: 30, MinMsPerOpCICD: 50, MaxMsPerOpCICD: 210}, + "Groups - L": {MinMsPerOpLocal: 0, MaxMsPerOpLocal: 30, MinMsPerOpCICD: 700, MaxMsPerOpCICD: 5500}, + "Users - L": 
{MinMsPerOpLocal: 0, MaxMsPerOpLocal: 30, MinMsPerOpCICD: 290, MaxMsPerOpCICD: 1700}, + "Setup Keys - L": {MinMsPerOpLocal: 0, MaxMsPerOpLocal: 30, MinMsPerOpCICD: 55, MaxMsPerOpCICD: 280}, + "Peers - XL": {MinMsPerOpLocal: 0, MaxMsPerOpLocal: 30, MinMsPerOpCICD: 60, MaxMsPerOpCICD: 250}, } log.SetOutput(io.Discard) diff --git a/management/server/http/testing/benchmarks/setupkeys_handler_benchmark_test.go b/management/server/http/testing/benchmarks/setupkeys_handler_benchmark_test.go index bbdb4250b..304499018 100644 --- a/management/server/http/testing/benchmarks/setupkeys_handler_benchmark_test.go +++ b/management/server/http/testing/benchmarks/setupkeys_handler_benchmark_test.go @@ -35,14 +35,14 @@ var benchCasesSetupKeys = map[string]testing_tools.BenchmarkCase{ func BenchmarkCreateSetupKey(b *testing.B) { var expectedMetrics = map[string]testing_tools.PerformanceMetrics{ - "Setup Keys - XS": {MinMsPerOpLocal: 0.5, MaxMsPerOpLocal: 2, MinMsPerOpCICD: 1, MaxMsPerOpCICD: 17}, - "Setup Keys - S": {MinMsPerOpLocal: 0.5, MaxMsPerOpLocal: 2, MinMsPerOpCICD: 1, MaxMsPerOpCICD: 17}, - "Setup Keys - M": {MinMsPerOpLocal: 0.5, MaxMsPerOpLocal: 2, MinMsPerOpCICD: 1, MaxMsPerOpCICD: 17}, - "Setup Keys - L": {MinMsPerOpLocal: 0.5, MaxMsPerOpLocal: 2, MinMsPerOpCICD: 1, MaxMsPerOpCICD: 17}, - "Peers - L": {MinMsPerOpLocal: 0.5, MaxMsPerOpLocal: 2, MinMsPerOpCICD: 1, MaxMsPerOpCICD: 17}, - "Groups - L": {MinMsPerOpLocal: 0.5, MaxMsPerOpLocal: 2, MinMsPerOpCICD: 1, MaxMsPerOpCICD: 17}, - "Users - L": {MinMsPerOpLocal: 0.5, MaxMsPerOpLocal: 2, MinMsPerOpCICD: 1, MaxMsPerOpCICD: 17}, - "Setup Keys - XL": {MinMsPerOpLocal: 0.5, MaxMsPerOpLocal: 2, MinMsPerOpCICD: 1, MaxMsPerOpCICD: 17}, + "Setup Keys - XS": {MinMsPerOpLocal: 0, MaxMsPerOpLocal: 2, MinMsPerOpCICD: 1, MaxMsPerOpCICD: 17}, + "Setup Keys - S": {MinMsPerOpLocal: 0, MaxMsPerOpLocal: 2, MinMsPerOpCICD: 1, MaxMsPerOpCICD: 17}, + "Setup Keys - M": {MinMsPerOpLocal: 0, MaxMsPerOpLocal: 2, MinMsPerOpCICD: 1, MaxMsPerOpCICD: 17}, + "Setup Keys - L": {MinMsPerOpLocal: 0, MaxMsPerOpLocal: 2, MinMsPerOpCICD: 1, MaxMsPerOpCICD: 17}, + "Peers - L": {MinMsPerOpLocal: 0, MaxMsPerOpLocal: 2, MinMsPerOpCICD: 1, MaxMsPerOpCICD: 17}, + "Groups - L": {MinMsPerOpLocal: 0, MaxMsPerOpLocal: 2, MinMsPerOpCICD: 1, MaxMsPerOpCICD: 17}, + "Users - L": {MinMsPerOpLocal: 0, MaxMsPerOpLocal: 2, MinMsPerOpCICD: 1, MaxMsPerOpCICD: 17}, + "Setup Keys - XL": {MinMsPerOpLocal: 0, MaxMsPerOpLocal: 2, MinMsPerOpCICD: 1, MaxMsPerOpCICD: 17}, } log.SetOutput(io.Discard) @@ -81,14 +81,14 @@ func BenchmarkCreateSetupKey(b *testing.B) { func BenchmarkUpdateSetupKey(b *testing.B) { var expectedMetrics = map[string]testing_tools.PerformanceMetrics{ - "Setup Keys - XS": {MinMsPerOpLocal: 0.5, MaxMsPerOpLocal: 3, MinMsPerOpCICD: 2, MaxMsPerOpCICD: 19}, - "Setup Keys - S": {MinMsPerOpLocal: 0.5, MaxMsPerOpLocal: 3, MinMsPerOpCICD: 2, MaxMsPerOpCICD: 19}, - "Setup Keys - M": {MinMsPerOpLocal: 0.5, MaxMsPerOpLocal: 3, MinMsPerOpCICD: 2, MaxMsPerOpCICD: 19}, - "Setup Keys - L": {MinMsPerOpLocal: 0.5, MaxMsPerOpLocal: 3, MinMsPerOpCICD: 2, MaxMsPerOpCICD: 19}, - "Peers - L": {MinMsPerOpLocal: 0.5, MaxMsPerOpLocal: 3, MinMsPerOpCICD: 2, MaxMsPerOpCICD: 19}, - "Groups - L": {MinMsPerOpLocal: 0.5, MaxMsPerOpLocal: 3, MinMsPerOpCICD: 2, MaxMsPerOpCICD: 19}, - "Users - L": {MinMsPerOpLocal: 0.5, MaxMsPerOpLocal: 3, MinMsPerOpCICD: 2, MaxMsPerOpCICD: 19}, - "Setup Keys - XL": {MinMsPerOpLocal: 0.5, MaxMsPerOpLocal: 3, MinMsPerOpCICD: 2, MaxMsPerOpCICD: 19}, + "Setup Keys - XS": {MinMsPerOpLocal: 
0, MaxMsPerOpLocal: 3, MinMsPerOpCICD: 2, MaxMsPerOpCICD: 19}, + "Setup Keys - S": {MinMsPerOpLocal: 0, MaxMsPerOpLocal: 3, MinMsPerOpCICD: 2, MaxMsPerOpCICD: 19}, + "Setup Keys - M": {MinMsPerOpLocal: 0, MaxMsPerOpLocal: 3, MinMsPerOpCICD: 2, MaxMsPerOpCICD: 19}, + "Setup Keys - L": {MinMsPerOpLocal: 0, MaxMsPerOpLocal: 3, MinMsPerOpCICD: 2, MaxMsPerOpCICD: 19}, + "Peers - L": {MinMsPerOpLocal: 0, MaxMsPerOpLocal: 3, MinMsPerOpCICD: 2, MaxMsPerOpCICD: 19}, + "Groups - L": {MinMsPerOpLocal: 0, MaxMsPerOpLocal: 3, MinMsPerOpCICD: 2, MaxMsPerOpCICD: 19}, + "Users - L": {MinMsPerOpLocal: 0, MaxMsPerOpLocal: 3, MinMsPerOpCICD: 2, MaxMsPerOpCICD: 19}, + "Setup Keys - XL": {MinMsPerOpLocal: 0, MaxMsPerOpLocal: 3, MinMsPerOpCICD: 2, MaxMsPerOpCICD: 19}, } log.SetOutput(io.Discard) @@ -128,14 +128,14 @@ func BenchmarkUpdateSetupKey(b *testing.B) { func BenchmarkGetOneSetupKey(b *testing.B) { var expectedMetrics = map[string]testing_tools.PerformanceMetrics{ - "Setup Keys - XS": {MinMsPerOpLocal: 0.5, MaxMsPerOpLocal: 2, MinMsPerOpCICD: 1, MaxMsPerOpCICD: 16}, - "Setup Keys - S": {MinMsPerOpLocal: 0.5, MaxMsPerOpLocal: 2, MinMsPerOpCICD: 1, MaxMsPerOpCICD: 16}, - "Setup Keys - M": {MinMsPerOpLocal: 0.5, MaxMsPerOpLocal: 2, MinMsPerOpCICD: 1, MaxMsPerOpCICD: 16}, - "Setup Keys - L": {MinMsPerOpLocal: 0.5, MaxMsPerOpLocal: 2, MinMsPerOpCICD: 1, MaxMsPerOpCICD: 16}, - "Peers - L": {MinMsPerOpLocal: 0.5, MaxMsPerOpLocal: 2, MinMsPerOpCICD: 1, MaxMsPerOpCICD: 16}, - "Groups - L": {MinMsPerOpLocal: 0.5, MaxMsPerOpLocal: 2, MinMsPerOpCICD: 1, MaxMsPerOpCICD: 16}, - "Users - L": {MinMsPerOpLocal: 0.5, MaxMsPerOpLocal: 2, MinMsPerOpCICD: 1, MaxMsPerOpCICD: 16}, - "Setup Keys - XL": {MinMsPerOpLocal: 0.5, MaxMsPerOpLocal: 2, MinMsPerOpCICD: 1, MaxMsPerOpCICD: 16}, + "Setup Keys - XS": {MinMsPerOpLocal: 0, MaxMsPerOpLocal: 2, MinMsPerOpCICD: 1, MaxMsPerOpCICD: 16}, + "Setup Keys - S": {MinMsPerOpLocal: 0, MaxMsPerOpLocal: 2, MinMsPerOpCICD: 1, MaxMsPerOpCICD: 16}, + "Setup Keys - M": {MinMsPerOpLocal: 0, MaxMsPerOpLocal: 2, MinMsPerOpCICD: 1, MaxMsPerOpCICD: 16}, + "Setup Keys - L": {MinMsPerOpLocal: 0, MaxMsPerOpLocal: 2, MinMsPerOpCICD: 1, MaxMsPerOpCICD: 16}, + "Peers - L": {MinMsPerOpLocal: 0, MaxMsPerOpLocal: 2, MinMsPerOpCICD: 1, MaxMsPerOpCICD: 16}, + "Groups - L": {MinMsPerOpLocal: 0, MaxMsPerOpLocal: 2, MinMsPerOpCICD: 1, MaxMsPerOpCICD: 16}, + "Users - L": {MinMsPerOpLocal: 0, MaxMsPerOpLocal: 2, MinMsPerOpCICD: 1, MaxMsPerOpCICD: 16}, + "Setup Keys - XL": {MinMsPerOpLocal: 0, MaxMsPerOpLocal: 2, MinMsPerOpCICD: 1, MaxMsPerOpCICD: 16}, } log.SetOutput(io.Discard) @@ -196,14 +196,14 @@ func BenchmarkGetAllSetupKeys(b *testing.B) { func BenchmarkDeleteSetupKey(b *testing.B) { var expectedMetrics = map[string]testing_tools.PerformanceMetrics{ - "Setup Keys - XS": {MinMsPerOpLocal: 0.5, MaxMsPerOpLocal: 2, MinMsPerOpCICD: 1, MaxMsPerOpCICD: 16}, - "Setup Keys - S": {MinMsPerOpLocal: 0.5, MaxMsPerOpLocal: 2, MinMsPerOpCICD: 1, MaxMsPerOpCICD: 16}, - "Setup Keys - M": {MinMsPerOpLocal: 0.5, MaxMsPerOpLocal: 2, MinMsPerOpCICD: 1, MaxMsPerOpCICD: 16}, - "Setup Keys - L": {MinMsPerOpLocal: 0.5, MaxMsPerOpLocal: 2, MinMsPerOpCICD: 1, MaxMsPerOpCICD: 16}, - "Peers - L": {MinMsPerOpLocal: 0.5, MaxMsPerOpLocal: 2, MinMsPerOpCICD: 1, MaxMsPerOpCICD: 16}, - "Groups - L": {MinMsPerOpLocal: 0.5, MaxMsPerOpLocal: 2, MinMsPerOpCICD: 1, MaxMsPerOpCICD: 16}, - "Users - L": {MinMsPerOpLocal: 0.5, MaxMsPerOpLocal: 2, MinMsPerOpCICD: 1, MaxMsPerOpCICD: 16}, - "Setup Keys - XL": {MinMsPerOpLocal: 0.5, MaxMsPerOpLocal: 2, 
MinMsPerOpCICD: 1, MaxMsPerOpCICD: 16}, + "Setup Keys - XS": {MinMsPerOpLocal: 0, MaxMsPerOpLocal: 2, MinMsPerOpCICD: 1, MaxMsPerOpCICD: 16}, + "Setup Keys - S": {MinMsPerOpLocal: 0, MaxMsPerOpLocal: 2, MinMsPerOpCICD: 1, MaxMsPerOpCICD: 16}, + "Setup Keys - M": {MinMsPerOpLocal: 0, MaxMsPerOpLocal: 2, MinMsPerOpCICD: 1, MaxMsPerOpCICD: 16}, + "Setup Keys - L": {MinMsPerOpLocal: 0, MaxMsPerOpLocal: 2, MinMsPerOpCICD: 1, MaxMsPerOpCICD: 16}, + "Peers - L": {MinMsPerOpLocal: 0, MaxMsPerOpLocal: 2, MinMsPerOpCICD: 1, MaxMsPerOpCICD: 16}, + "Groups - L": {MinMsPerOpLocal: 0, MaxMsPerOpLocal: 2, MinMsPerOpCICD: 1, MaxMsPerOpCICD: 16}, + "Users - L": {MinMsPerOpLocal: 0, MaxMsPerOpLocal: 2, MinMsPerOpCICD: 1, MaxMsPerOpCICD: 16}, + "Setup Keys - XL": {MinMsPerOpLocal: 0, MaxMsPerOpLocal: 2, MinMsPerOpCICD: 1, MaxMsPerOpCICD: 16}, } log.SetOutput(io.Discard) diff --git a/management/server/http/testing/benchmarks/users_handler_benchmark_test.go b/management/server/http/testing/benchmarks/users_handler_benchmark_test.go index b62341995..9ddc23c8c 100644 --- a/management/server/http/testing/benchmarks/users_handler_benchmark_test.go +++ b/management/server/http/testing/benchmarks/users_handler_benchmark_test.go @@ -119,11 +119,11 @@ func BenchmarkGetOneUser(b *testing.B) { func BenchmarkGetAllUsers(b *testing.B) { var expectedMetrics = map[string]testing_tools.PerformanceMetrics{ "Users - XS": {MinMsPerOpLocal: 50, MaxMsPerOpLocal: 90, MinMsPerOpCICD: 60, MaxMsPerOpCICD: 180}, - "Users - S": {MinMsPerOpLocal: 0.5, MaxMsPerOpLocal: 2, MinMsPerOpCICD: 1, MaxMsPerOpCICD: 30}, + "Users - S": {MinMsPerOpLocal: 0, MaxMsPerOpLocal: 2, MinMsPerOpCICD: 1, MaxMsPerOpCICD: 30}, "Users - M": {MinMsPerOpLocal: 5, MaxMsPerOpLocal: 12, MinMsPerOpCICD: 1, MaxMsPerOpCICD: 30}, - "Users - L": {MinMsPerOpLocal: 0.5, MaxMsPerOpLocal: 2, MinMsPerOpCICD: 1, MaxMsPerOpCICD: 30}, - "Peers - L": {MinMsPerOpLocal: 0.5, MaxMsPerOpLocal: 2, MinMsPerOpCICD: 1, MaxMsPerOpCICD: 30}, - "Groups - L": {MinMsPerOpLocal: 0.5, MaxMsPerOpLocal: 2, MinMsPerOpCICD: 1, MaxMsPerOpCICD: 30}, + "Users - L": {MinMsPerOpLocal: 0, MaxMsPerOpLocal: 2, MinMsPerOpCICD: 1, MaxMsPerOpCICD: 30}, + "Peers - L": {MinMsPerOpLocal: 0, MaxMsPerOpLocal: 2, MinMsPerOpCICD: 1, MaxMsPerOpCICD: 30}, + "Groups - L": {MinMsPerOpLocal: 0, MaxMsPerOpLocal: 2, MinMsPerOpCICD: 1, MaxMsPerOpCICD: 30}, "Setup Keys - L": {MinMsPerOpLocal: 40, MaxMsPerOpLocal: 140, MinMsPerOpCICD: 60, MaxMsPerOpCICD: 200}, "Users - XL": {MinMsPerOpLocal: 15, MaxMsPerOpLocal: 40, MinMsPerOpCICD: 20, MaxMsPerOpCICD: 90}, } From 9ff56eae646d85f209133dfdca9b2a8666ffd102 Mon Sep 17 00:00:00 2001 From: Pascal Fischer Date: Tue, 14 Jan 2025 18:43:50 +0100 Subject: [PATCH 139/159] update cloud expectations --- .../peers_handler_benchmark_test.go | 42 +++++++++---------- 1 file changed, 21 insertions(+), 21 deletions(-) diff --git a/management/server/http/testing/benchmarks/peers_handler_benchmark_test.go b/management/server/http/testing/benchmarks/peers_handler_benchmark_test.go index 141a759cb..0ac187aa6 100644 --- a/management/server/http/testing/benchmarks/peers_handler_benchmark_test.go +++ b/management/server/http/testing/benchmarks/peers_handler_benchmark_test.go @@ -35,14 +35,14 @@ var benchCasesPeers = map[string]testing_tools.BenchmarkCase{ func BenchmarkUpdatePeer(b *testing.B) { var expectedMetrics = map[string]testing_tools.PerformanceMetrics{ - "Peers - XS": {MinMsPerOpLocal: 400, MaxMsPerOpLocal: 600, MinMsPerOpCICD: 2200, MaxMsPerOpCICD: 13000}, + "Peers - XS": {MinMsPerOpLocal: 
400, MaxMsPerOpLocal: 600, MinMsPerOpCICD: 1300, MaxMsPerOpCICD: 5000}, "Peers - S": {MinMsPerOpLocal: 100, MaxMsPerOpLocal: 130, MinMsPerOpCICD: 80, MaxMsPerOpCICD: 200}, "Peers - M": {MinMsPerOpLocal: 130, MaxMsPerOpLocal: 150, MinMsPerOpCICD: 100, MaxMsPerOpCICD: 500}, - "Peers - L": {MinMsPerOpLocal: 230, MaxMsPerOpLocal: 270, MinMsPerOpCICD: 450, MaxMsPerOpCICD: 1400}, - "Groups - L": {MinMsPerOpLocal: 400, MaxMsPerOpLocal: 600, MinMsPerOpCICD: 1900, MaxMsPerOpCICD: 13000}, - "Users - L": {MinMsPerOpLocal: 200, MaxMsPerOpLocal: 400, MinMsPerOpCICD: 800, MaxMsPerOpCICD: 2800}, - "Setup Keys - L": {MinMsPerOpLocal: 200, MaxMsPerOpLocal: 400, MinMsPerOpCICD: 600, MaxMsPerOpCICD: 1300}, - "Peers - XL": {MinMsPerOpLocal: 600, MaxMsPerOpLocal: 1000, MinMsPerOpCICD: 2200, MaxMsPerOpCICD: 5000}, + "Peers - L": {MinMsPerOpLocal: 230, MaxMsPerOpLocal: 270, MinMsPerOpCICD: 250, MaxMsPerOpCICD: 1000}, + "Groups - L": {MinMsPerOpLocal: 400, MaxMsPerOpLocal: 600, MinMsPerOpCICD: 700, MaxMsPerOpCICD: 4000}, + "Users - L": {MinMsPerOpLocal: 200, MaxMsPerOpLocal: 400, MinMsPerOpCICD: 350, MaxMsPerOpCICD: 1500}, + "Setup Keys - L": {MinMsPerOpLocal: 200, MaxMsPerOpLocal: 400, MinMsPerOpCICD: 350, MaxMsPerOpCICD: 1100}, + "Peers - XL": {MinMsPerOpLocal: 600, MaxMsPerOpLocal: 1000, MinMsPerOpCICD: 600, MaxMsPerOpCICD: 3000}, } log.SetOutput(io.Discard) @@ -77,13 +77,13 @@ func BenchmarkUpdatePeer(b *testing.B) { func BenchmarkGetOnePeer(b *testing.B) { var expectedMetrics = map[string]testing_tools.PerformanceMetrics{ - "Peers - XS": {MinMsPerOpLocal: 40, MaxMsPerOpLocal: 60, MinMsPerOpCICD: 1100, MaxMsPerOpCICD: 7000}, + "Peers - XS": {MinMsPerOpLocal: 40, MaxMsPerOpLocal: 60, MinMsPerOpCICD: 50, MaxMsPerOpCICD: 300}, "Peers - S": {MinMsPerOpLocal: 1, MaxMsPerOpLocal: 5, MinMsPerOpCICD: 2, MaxMsPerOpCICD: 30}, - "Peers - M": {MinMsPerOpLocal: 9, MaxMsPerOpLocal: 18, MinMsPerOpCICD: 35, MaxMsPerOpCICD: 80}, + "Peers - M": {MinMsPerOpLocal: 9, MaxMsPerOpLocal: 18, MinMsPerOpCICD: 15, MaxMsPerOpCICD: 50}, "Peers - L": {MinMsPerOpLocal: 40, MaxMsPerOpLocal: 90, MinMsPerOpCICD: 100, MaxMsPerOpCICD: 300}, - "Groups - L": {MinMsPerOpLocal: 80, MaxMsPerOpLocal: 130, MinMsPerOpCICD: 900, MaxMsPerOpCICD: 6500}, - "Users - L": {MinMsPerOpLocal: 40, MaxMsPerOpLocal: 90, MinMsPerOpCICD: 200, MaxMsPerOpCICD: 600}, - "Setup Keys - L": {MinMsPerOpLocal: 40, MaxMsPerOpLocal: 90, MinMsPerOpCICD: 200, MaxMsPerOpCICD: 600}, + "Groups - L": {MinMsPerOpLocal: 80, MaxMsPerOpLocal: 130, MinMsPerOpCICD: 150, MaxMsPerOpCICD: 800}, + "Users - L": {MinMsPerOpLocal: 40, MaxMsPerOpLocal: 90, MinMsPerOpCICD: 100, MaxMsPerOpCICD: 300}, + "Setup Keys - L": {MinMsPerOpLocal: 40, MaxMsPerOpLocal: 90, MinMsPerOpCICD: 100, MaxMsPerOpCICD: 300}, "Peers - XL": {MinMsPerOpLocal: 200, MaxMsPerOpLocal: 400, MinMsPerOpCICD: 600, MaxMsPerOpCICD: 1500}, } @@ -111,8 +111,8 @@ func BenchmarkGetOnePeer(b *testing.B) { func BenchmarkGetAllPeers(b *testing.B) { var expectedMetrics = map[string]testing_tools.PerformanceMetrics{ - "Peers - XS": {MinMsPerOpLocal: 40, MaxMsPerOpLocal: 70, MinMsPerOpCICD: 1100, MaxMsPerOpCICD: 6000}, - "Peers - S": {MinMsPerOpLocal: 2, MaxMsPerOpLocal: 10, MinMsPerOpCICD: 7, MaxMsPerOpCICD: 30}, + "Peers - XS": {MinMsPerOpLocal: 40, MaxMsPerOpLocal: 70, MinMsPerOpCICD: 70, MaxMsPerOpCICD: 2000}, + "Peers - S": {MinMsPerOpLocal: 2, MaxMsPerOpLocal: 10, MinMsPerOpCICD: 5, MaxMsPerOpCICD: 30}, "Peers - M": {MinMsPerOpLocal: 20, MaxMsPerOpLocal: 50, MinMsPerOpCICD: 40, MaxMsPerOpCICD: 90}, "Peers - L": {MinMsPerOpLocal: 130, 
MaxMsPerOpLocal: 170, MinMsPerOpCICD: 150, MaxMsPerOpCICD: 350}, "Groups - L": {MinMsPerOpLocal: 4800, MaxMsPerOpLocal: 5300, MinMsPerOpCICD: 7000, MaxMsPerOpCICD: 15000}, @@ -145,14 +145,14 @@ func BenchmarkGetAllPeers(b *testing.B) { func BenchmarkDeletePeer(b *testing.B) { var expectedMetrics = map[string]testing_tools.PerformanceMetrics{ - "Peers - XS": {MinMsPerOpLocal: 0, MaxMsPerOpLocal: 30, MinMsPerOpCICD: 1100, MaxMsPerOpCICD: 7000}, - "Peers - S": {MinMsPerOpLocal: 0, MaxMsPerOpLocal: 30, MinMsPerOpCICD: 50, MaxMsPerOpCICD: 210}, - "Peers - M": {MinMsPerOpLocal: 0, MaxMsPerOpLocal: 30, MinMsPerOpCICD: 50, MaxMsPerOpCICD: 230}, - "Peers - L": {MinMsPerOpLocal: 0, MaxMsPerOpLocal: 30, MinMsPerOpCICD: 50, MaxMsPerOpCICD: 210}, - "Groups - L": {MinMsPerOpLocal: 0, MaxMsPerOpLocal: 30, MinMsPerOpCICD: 700, MaxMsPerOpCICD: 5500}, - "Users - L": {MinMsPerOpLocal: 0, MaxMsPerOpLocal: 30, MinMsPerOpCICD: 290, MaxMsPerOpCICD: 1700}, - "Setup Keys - L": {MinMsPerOpLocal: 0, MaxMsPerOpLocal: 30, MinMsPerOpCICD: 55, MaxMsPerOpCICD: 280}, - "Peers - XL": {MinMsPerOpLocal: 0, MaxMsPerOpLocal: 30, MinMsPerOpCICD: 60, MaxMsPerOpCICD: 250}, + "Peers - XS": {MinMsPerOpLocal: 0, MaxMsPerOpLocal: 30, MinMsPerOpCICD: 2, MaxMsPerOpCICD: 8}, + "Peers - S": {MinMsPerOpLocal: 0, MaxMsPerOpLocal: 30, MinMsPerOpCICD: 2, MaxMsPerOpCICD: 8}, + "Peers - M": {MinMsPerOpLocal: 0, MaxMsPerOpLocal: 30, MinMsPerOpCICD: 2, MaxMsPerOpCICD: 8}, + "Peers - L": {MinMsPerOpLocal: 0, MaxMsPerOpLocal: 30, MinMsPerOpCICD: 2, MaxMsPerOpCICD: 8}, + "Groups - L": {MinMsPerOpLocal: 0, MaxMsPerOpLocal: 30, MinMsPerOpCICD: 2, MaxMsPerOpCICD: 8}, + "Users - L": {MinMsPerOpLocal: 0, MaxMsPerOpLocal: 30, MinMsPerOpCICD: 2, MaxMsPerOpCICD: 8}, + "Setup Keys - L": {MinMsPerOpLocal: 0, MaxMsPerOpLocal: 30, MinMsPerOpCICD: 2, MaxMsPerOpCICD: 8}, + "Peers - XL": {MinMsPerOpLocal: 0, MaxMsPerOpLocal: 30, MinMsPerOpCICD: 2, MaxMsPerOpCICD: 8}, } log.SetOutput(io.Discard) From cd15c85d5b875473aac25b56ce8a8513701bad0a Mon Sep 17 00:00:00 2001 From: bcmmbaga Date: Tue, 14 Jan 2025 20:54:44 +0300 Subject: [PATCH 140/159] Add status error for generic result error Signed-off-by: bcmmbaga --- management/server/store/sql_store.go | 8 ++++---- 1 file changed, 4 insertions(+), 4 deletions(-) diff --git a/management/server/store/sql_store.go b/management/server/store/sql_store.go index 82934dd11..46e896b55 100644 --- a/management/server/store/sql_store.go +++ b/management/server/store/sql_store.go @@ -332,7 +332,7 @@ func (s *SqlStore) SavePeer(ctx context.Context, lockStrength LockingStrength, a result = tx.Model(&nbpeer.Peer{}).Where(accountAndIDQueryCondition, accountID, peer.ID).Save(peerCopy) if result.Error != nil { - return result.Error + return status.Errorf(status.Internal, "failed to save peer to store: %v", result.Error) } return nil @@ -358,7 +358,7 @@ func (s *SqlStore) UpdateAccountDomainAttributes(ctx context.Context, accountID Where(idQueryCondition, accountID). Updates(&accountCopy) if result.Error != nil { - return result.Error + return status.Errorf(status.Internal, "failed to update account domain attributes to store: %v", result.Error) } if result.RowsAffected == 0 { @@ -381,7 +381,7 @@ func (s *SqlStore) SavePeerStatus(ctx context.Context, lockStrength LockingStren Where(accountAndIDQueryCondition, accountID, peerID). 
Updates(&peerCopy) if result.Error != nil { - return result.Error + return status.Errorf(status.Internal, "failed to save peer status to store: %v", result.Error) } if result.RowsAffected == 0 { @@ -403,7 +403,7 @@ func (s *SqlStore) SavePeerLocation(ctx context.Context, lockStrength LockingStr Updates(peerCopy) if result.Error != nil { - return result.Error + return status.Errorf(status.Internal, "failed to save peer locations to store: %v", result.Error) } if result.RowsAffected == 0 { From ce7385119ab75665f5a39a010ed61a9fc09612c1 Mon Sep 17 00:00:00 2001 From: bcmmbaga Date: Tue, 14 Jan 2025 20:58:43 +0300 Subject: [PATCH 141/159] Use integrated validator direct Signed-off-by: bcmmbaga --- management/server/peer.go | 6 +++--- 1 file changed, 3 insertions(+), 3 deletions(-) diff --git a/management/server/peer.go b/management/server/peer.go index 78bb1dd8c..d9a58ed7f 100644 --- a/management/server/peer.go +++ b/management/server/peer.go @@ -101,7 +101,7 @@ func (am *DefaultAccountManager) GetPeers(ctx context.Context, accountID, userID return nil, err } - approvedPeersMap, err := am.GetValidatedPeers(ctx, accountID) + approvedPeersMap, err := am.integratedPeerValidator.GetValidatedPeers(accountID, account.Groups, account.Peers, account.Settings.Extra) if err != nil { return nil, err } @@ -1057,12 +1057,12 @@ func (am *DefaultAccountManager) GetPeer(ctx context.Context, accountID, peerID, return nil, err } - approvedPeersMap, err := am.GetValidatedPeers(ctx, accountID) + account, err := am.requestBuffer.GetAccountWithBackpressure(ctx, accountID) if err != nil { return nil, err } - account, err := am.requestBuffer.GetAccountWithBackpressure(ctx, accountID) + approvedPeersMap, err := am.integratedPeerValidator.GetValidatedPeers(accountID, account.Groups, account.Peers, account.Settings.Extra) if err != nil { return nil, err } From 34831399031bd97220fe51d434f773babcc39e6f Mon Sep 17 00:00:00 2001 From: Pascal Fischer Date: Tue, 14 Jan 2025 19:17:48 +0100 Subject: [PATCH 142/159] update expectations --- .../peers_handler_benchmark_test.go | 54 +++++++++---------- .../setupkeys_handler_benchmark_test.go | 48 ++++++++--------- .../users_handler_benchmark_test.go | 20 +++---- 3 files changed, 61 insertions(+), 61 deletions(-) diff --git a/management/server/http/testing/benchmarks/peers_handler_benchmark_test.go b/management/server/http/testing/benchmarks/peers_handler_benchmark_test.go index 0ac187aa6..aaf8f203a 100644 --- a/management/server/http/testing/benchmarks/peers_handler_benchmark_test.go +++ b/management/server/http/testing/benchmarks/peers_handler_benchmark_test.go @@ -35,14 +35,14 @@ var benchCasesPeers = map[string]testing_tools.BenchmarkCase{ func BenchmarkUpdatePeer(b *testing.B) { var expectedMetrics = map[string]testing_tools.PerformanceMetrics{ - "Peers - XS": {MinMsPerOpLocal: 400, MaxMsPerOpLocal: 600, MinMsPerOpCICD: 1300, MaxMsPerOpCICD: 5000}, + "Peers - XS": {MinMsPerOpLocal: 400, MaxMsPerOpLocal: 600, MinMsPerOpCICD: 600, MaxMsPerOpCICD: 2000}, "Peers - S": {MinMsPerOpLocal: 100, MaxMsPerOpLocal: 130, MinMsPerOpCICD: 80, MaxMsPerOpCICD: 200}, - "Peers - M": {MinMsPerOpLocal: 130, MaxMsPerOpLocal: 150, MinMsPerOpCICD: 100, MaxMsPerOpCICD: 500}, - "Peers - L": {MinMsPerOpLocal: 230, MaxMsPerOpLocal: 270, MinMsPerOpCICD: 250, MaxMsPerOpCICD: 1000}, - "Groups - L": {MinMsPerOpLocal: 400, MaxMsPerOpLocal: 600, MinMsPerOpCICD: 700, MaxMsPerOpCICD: 4000}, - "Users - L": {MinMsPerOpLocal: 200, MaxMsPerOpLocal: 400, MinMsPerOpCICD: 350, MaxMsPerOpCICD: 1500}, - "Setup Keys - L": 
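The sql_store.go hunks above stop returning raw GORM errors and wrap them as internal status errors instead. A minimal sketch of the resulting pattern, assuming NetBird's status package and import paths; the not-found branch for RowsAffected == 0 is an assumption, since its body is truncated in the diff.

package store

import (
	"gorm.io/gorm"

	nbpeer "github.com/netbirdio/netbird/management/server/peer"
	"github.com/netbirdio/netbird/management/server/status"
)

// savePeerFields is a hypothetical helper illustrating the pattern: driver
// errors are wrapped as status.Internal rather than leaked to callers, and an
// update that matched no rows is surfaced explicitly instead of silently succeeding.
func savePeerFields(tx *gorm.DB, accountID, peerID string, fields any) error {
	result := tx.Model(&nbpeer.Peer{}).
		Where("account_id = ? AND id = ?", accountID, peerID).
		Updates(fields)
	if result.Error != nil {
		return status.Errorf(status.Internal, "failed to save peer to store: %v", result.Error)
	}
	if result.RowsAffected == 0 {
		// assumed handling of the RowsAffected check shown (but truncated) above
		return status.Errorf(status.NotFound, "peer %s not found in account %s", peerID, accountID)
	}
	return nil
}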
{MinMsPerOpLocal: 200, MaxMsPerOpLocal: 400, MinMsPerOpCICD: 350, MaxMsPerOpCICD: 1100}, - "Peers - XL": {MinMsPerOpLocal: 600, MaxMsPerOpLocal: 1000, MinMsPerOpCICD: 600, MaxMsPerOpCICD: 3000}, + "Peers - M": {MinMsPerOpLocal: 130, MaxMsPerOpLocal: 150, MinMsPerOpCICD: 100, MaxMsPerOpCICD: 300}, + "Peers - L": {MinMsPerOpLocal: 230, MaxMsPerOpLocal: 270, MinMsPerOpCICD: 250, MaxMsPerOpCICD: 500}, + "Groups - L": {MinMsPerOpLocal: 400, MaxMsPerOpLocal: 600, MinMsPerOpCICD: 650, MaxMsPerOpCICD: 100}, + "Users - L": {MinMsPerOpLocal: 200, MaxMsPerOpLocal: 400, MinMsPerOpCICD: 350, MaxMsPerOpCICD: 600}, + "Setup Keys - L": {MinMsPerOpLocal: 200, MaxMsPerOpLocal: 400, MinMsPerOpCICD: 350, MaxMsPerOpCICD: 600}, + "Peers - XL": {MinMsPerOpLocal: 600, MaxMsPerOpLocal: 1000, MinMsPerOpCICD: 600, MaxMsPerOpCICD: 2000}, } log.SetOutput(io.Discard) @@ -77,14 +77,14 @@ func BenchmarkUpdatePeer(b *testing.B) { func BenchmarkGetOnePeer(b *testing.B) { var expectedMetrics = map[string]testing_tools.PerformanceMetrics{ - "Peers - XS": {MinMsPerOpLocal: 40, MaxMsPerOpLocal: 60, MinMsPerOpCICD: 50, MaxMsPerOpCICD: 300}, + "Peers - XS": {MinMsPerOpLocal: 40, MaxMsPerOpLocal: 60, MinMsPerOpCICD: 50, MaxMsPerOpCICD: 200}, "Peers - S": {MinMsPerOpLocal: 1, MaxMsPerOpLocal: 5, MinMsPerOpCICD: 2, MaxMsPerOpCICD: 30}, "Peers - M": {MinMsPerOpLocal: 9, MaxMsPerOpLocal: 18, MinMsPerOpCICD: 15, MaxMsPerOpCICD: 50}, - "Peers - L": {MinMsPerOpLocal: 40, MaxMsPerOpLocal: 90, MinMsPerOpCICD: 100, MaxMsPerOpCICD: 300}, - "Groups - L": {MinMsPerOpLocal: 80, MaxMsPerOpLocal: 130, MinMsPerOpCICD: 150, MaxMsPerOpCICD: 800}, - "Users - L": {MinMsPerOpLocal: 40, MaxMsPerOpLocal: 90, MinMsPerOpCICD: 100, MaxMsPerOpCICD: 300}, - "Setup Keys - L": {MinMsPerOpLocal: 40, MaxMsPerOpLocal: 90, MinMsPerOpCICD: 100, MaxMsPerOpCICD: 300}, - "Peers - XL": {MinMsPerOpLocal: 200, MaxMsPerOpLocal: 400, MinMsPerOpCICD: 600, MaxMsPerOpCICD: 1500}, + "Peers - L": {MinMsPerOpLocal: 40, MaxMsPerOpLocal: 90, MinMsPerOpCICD: 50, MaxMsPerOpCICD: 130}, + "Groups - L": {MinMsPerOpLocal: 80, MaxMsPerOpLocal: 130, MinMsPerOpCICD: 150, MaxMsPerOpCICD: 400}, + "Users - L": {MinMsPerOpLocal: 40, MaxMsPerOpLocal: 90, MinMsPerOpCICD: 50, MaxMsPerOpCICD: 130}, + "Setup Keys - L": {MinMsPerOpLocal: 40, MaxMsPerOpLocal: 90, MinMsPerOpCICD: 50, MaxMsPerOpCICD: 130}, + "Peers - XL": {MinMsPerOpLocal: 200, MaxMsPerOpLocal: 400, MinMsPerOpCICD: 200, MaxMsPerOpCICD: 600}, } log.SetOutput(io.Discard) @@ -111,14 +111,14 @@ func BenchmarkGetOnePeer(b *testing.B) { func BenchmarkGetAllPeers(b *testing.B) { var expectedMetrics = map[string]testing_tools.PerformanceMetrics{ - "Peers - XS": {MinMsPerOpLocal: 40, MaxMsPerOpLocal: 70, MinMsPerOpCICD: 70, MaxMsPerOpCICD: 2000}, + "Peers - XS": {MinMsPerOpLocal: 40, MaxMsPerOpLocal: 70, MinMsPerOpCICD: 50, MaxMsPerOpCICD: 150}, "Peers - S": {MinMsPerOpLocal: 2, MaxMsPerOpLocal: 10, MinMsPerOpCICD: 5, MaxMsPerOpCICD: 30}, - "Peers - M": {MinMsPerOpLocal: 20, MaxMsPerOpLocal: 50, MinMsPerOpCICD: 40, MaxMsPerOpCICD: 90}, - "Peers - L": {MinMsPerOpLocal: 130, MaxMsPerOpLocal: 170, MinMsPerOpCICD: 150, MaxMsPerOpCICD: 350}, + "Peers - M": {MinMsPerOpLocal: 20, MaxMsPerOpLocal: 50, MinMsPerOpCICD: 20, MaxMsPerOpCICD: 70}, + "Peers - L": {MinMsPerOpLocal: 130, MaxMsPerOpLocal: 170, MinMsPerOpCICD: 100, MaxMsPerOpCICD: 200}, "Groups - L": {MinMsPerOpLocal: 4800, MaxMsPerOpLocal: 5300, MinMsPerOpCICD: 7000, MaxMsPerOpCICD: 15000}, - "Users - L": {MinMsPerOpLocal: 100, MaxMsPerOpLocal: 170, MinMsPerOpCICD: 250, MaxMsPerOpCICD: 700}, - 
"Setup Keys - L": {MinMsPerOpLocal: 100, MaxMsPerOpLocal: 170, MinMsPerOpCICD: 250, MaxMsPerOpCICD: 700}, - "Peers - XL": {MinMsPerOpLocal: 900, MaxMsPerOpLocal: 1300, MinMsPerOpCICD: 1100, MaxMsPerOpCICD: 2200}, + "Users - L": {MinMsPerOpLocal: 100, MaxMsPerOpLocal: 170, MinMsPerOpCICD: 100, MaxMsPerOpCICD: 400}, + "Setup Keys - L": {MinMsPerOpLocal: 100, MaxMsPerOpLocal: 170, MinMsPerOpCICD: 100, MaxMsPerOpCICD: 400}, + "Peers - XL": {MinMsPerOpLocal: 900, MaxMsPerOpLocal: 1300, MinMsPerOpCICD: 800, MaxMsPerOpCICD: 1500}, } log.SetOutput(io.Discard) @@ -145,14 +145,14 @@ func BenchmarkGetAllPeers(b *testing.B) { func BenchmarkDeletePeer(b *testing.B) { var expectedMetrics = map[string]testing_tools.PerformanceMetrics{ - "Peers - XS": {MinMsPerOpLocal: 0, MaxMsPerOpLocal: 30, MinMsPerOpCICD: 2, MaxMsPerOpCICD: 8}, - "Peers - S": {MinMsPerOpLocal: 0, MaxMsPerOpLocal: 30, MinMsPerOpCICD: 2, MaxMsPerOpCICD: 8}, - "Peers - M": {MinMsPerOpLocal: 0, MaxMsPerOpLocal: 30, MinMsPerOpCICD: 2, MaxMsPerOpCICD: 8}, - "Peers - L": {MinMsPerOpLocal: 0, MaxMsPerOpLocal: 30, MinMsPerOpCICD: 2, MaxMsPerOpCICD: 8}, - "Groups - L": {MinMsPerOpLocal: 0, MaxMsPerOpLocal: 30, MinMsPerOpCICD: 2, MaxMsPerOpCICD: 8}, - "Users - L": {MinMsPerOpLocal: 0, MaxMsPerOpLocal: 30, MinMsPerOpCICD: 2, MaxMsPerOpCICD: 8}, - "Setup Keys - L": {MinMsPerOpLocal: 0, MaxMsPerOpLocal: 30, MinMsPerOpCICD: 2, MaxMsPerOpCICD: 8}, - "Peers - XL": {MinMsPerOpLocal: 0, MaxMsPerOpLocal: 30, MinMsPerOpCICD: 2, MaxMsPerOpCICD: 8}, + "Peers - XS": {MinMsPerOpLocal: 0, MaxMsPerOpLocal: 4, MinMsPerOpCICD: 2, MaxMsPerOpCICD: 15}, + "Peers - S": {MinMsPerOpLocal: 0, MaxMsPerOpLocal: 4, MinMsPerOpCICD: 2, MaxMsPerOpCICD: 15}, + "Peers - M": {MinMsPerOpLocal: 0, MaxMsPerOpLocal: 4, MinMsPerOpCICD: 2, MaxMsPerOpCICD: 15}, + "Peers - L": {MinMsPerOpLocal: 0, MaxMsPerOpLocal: 4, MinMsPerOpCICD: 2, MaxMsPerOpCICD: 15}, + "Groups - L": {MinMsPerOpLocal: 0, MaxMsPerOpLocal: 4, MinMsPerOpCICD: 2, MaxMsPerOpCICD: 15}, + "Users - L": {MinMsPerOpLocal: 0, MaxMsPerOpLocal: 4, MinMsPerOpCICD: 2, MaxMsPerOpCICD: 15}, + "Setup Keys - L": {MinMsPerOpLocal: 0, MaxMsPerOpLocal: 4, MinMsPerOpCICD: 2, MaxMsPerOpCICD: 15}, + "Peers - XL": {MinMsPerOpLocal: 0, MaxMsPerOpLocal: 4, MinMsPerOpCICD: 2, MaxMsPerOpCICD: 15}, } log.SetOutput(io.Discard) diff --git a/management/server/http/testing/benchmarks/setupkeys_handler_benchmark_test.go b/management/server/http/testing/benchmarks/setupkeys_handler_benchmark_test.go index 304499018..6d3d55141 100644 --- a/management/server/http/testing/benchmarks/setupkeys_handler_benchmark_test.go +++ b/management/server/http/testing/benchmarks/setupkeys_handler_benchmark_test.go @@ -35,14 +35,14 @@ var benchCasesSetupKeys = map[string]testing_tools.BenchmarkCase{ func BenchmarkCreateSetupKey(b *testing.B) { var expectedMetrics = map[string]testing_tools.PerformanceMetrics{ - "Setup Keys - XS": {MinMsPerOpLocal: 0, MaxMsPerOpLocal: 2, MinMsPerOpCICD: 1, MaxMsPerOpCICD: 17}, - "Setup Keys - S": {MinMsPerOpLocal: 0, MaxMsPerOpLocal: 2, MinMsPerOpCICD: 1, MaxMsPerOpCICD: 17}, - "Setup Keys - M": {MinMsPerOpLocal: 0, MaxMsPerOpLocal: 2, MinMsPerOpCICD: 1, MaxMsPerOpCICD: 17}, - "Setup Keys - L": {MinMsPerOpLocal: 0, MaxMsPerOpLocal: 2, MinMsPerOpCICD: 1, MaxMsPerOpCICD: 17}, - "Peers - L": {MinMsPerOpLocal: 0, MaxMsPerOpLocal: 2, MinMsPerOpCICD: 1, MaxMsPerOpCICD: 17}, - "Groups - L": {MinMsPerOpLocal: 0, MaxMsPerOpLocal: 2, MinMsPerOpCICD: 1, MaxMsPerOpCICD: 17}, - "Users - L": {MinMsPerOpLocal: 0, MaxMsPerOpLocal: 2, MinMsPerOpCICD: 1, 
MaxMsPerOpCICD: 17}, - "Setup Keys - XL": {MinMsPerOpLocal: 0, MaxMsPerOpLocal: 2, MinMsPerOpCICD: 1, MaxMsPerOpCICD: 17}, + "Setup Keys - XS": {MinMsPerOpLocal: 0, MaxMsPerOpLocal: 2, MinMsPerOpCICD: 0, MaxMsPerOpCICD: 17}, + "Setup Keys - S": {MinMsPerOpLocal: 0, MaxMsPerOpLocal: 2, MinMsPerOpCICD: 0, MaxMsPerOpCICD: 17}, + "Setup Keys - M": {MinMsPerOpLocal: 0, MaxMsPerOpLocal: 2, MinMsPerOpCICD: 0, MaxMsPerOpCICD: 17}, + "Setup Keys - L": {MinMsPerOpLocal: 0, MaxMsPerOpLocal: 2, MinMsPerOpCICD: 0, MaxMsPerOpCICD: 17}, + "Peers - L": {MinMsPerOpLocal: 0, MaxMsPerOpLocal: 2, MinMsPerOpCICD: 0, MaxMsPerOpCICD: 17}, + "Groups - L": {MinMsPerOpLocal: 0, MaxMsPerOpLocal: 2, MinMsPerOpCICD: 0, MaxMsPerOpCICD: 17}, + "Users - L": {MinMsPerOpLocal: 0, MaxMsPerOpLocal: 2, MinMsPerOpCICD: 0, MaxMsPerOpCICD: 17}, + "Setup Keys - XL": {MinMsPerOpLocal: 0, MaxMsPerOpLocal: 2, MinMsPerOpCICD: 0, MaxMsPerOpCICD: 17}, } log.SetOutput(io.Discard) @@ -81,14 +81,14 @@ func BenchmarkCreateSetupKey(b *testing.B) { func BenchmarkUpdateSetupKey(b *testing.B) { var expectedMetrics = map[string]testing_tools.PerformanceMetrics{ - "Setup Keys - XS": {MinMsPerOpLocal: 0, MaxMsPerOpLocal: 3, MinMsPerOpCICD: 2, MaxMsPerOpCICD: 19}, - "Setup Keys - S": {MinMsPerOpLocal: 0, MaxMsPerOpLocal: 3, MinMsPerOpCICD: 2, MaxMsPerOpCICD: 19}, - "Setup Keys - M": {MinMsPerOpLocal: 0, MaxMsPerOpLocal: 3, MinMsPerOpCICD: 2, MaxMsPerOpCICD: 19}, - "Setup Keys - L": {MinMsPerOpLocal: 0, MaxMsPerOpLocal: 3, MinMsPerOpCICD: 2, MaxMsPerOpCICD: 19}, - "Peers - L": {MinMsPerOpLocal: 0, MaxMsPerOpLocal: 3, MinMsPerOpCICD: 2, MaxMsPerOpCICD: 19}, - "Groups - L": {MinMsPerOpLocal: 0, MaxMsPerOpLocal: 3, MinMsPerOpCICD: 2, MaxMsPerOpCICD: 19}, - "Users - L": {MinMsPerOpLocal: 0, MaxMsPerOpLocal: 3, MinMsPerOpCICD: 2, MaxMsPerOpCICD: 19}, - "Setup Keys - XL": {MinMsPerOpLocal: 0, MaxMsPerOpLocal: 3, MinMsPerOpCICD: 2, MaxMsPerOpCICD: 19}, + "Setup Keys - XS": {MinMsPerOpLocal: 0, MaxMsPerOpLocal: 3, MinMsPerOpCICD: 0, MaxMsPerOpCICD: 19}, + "Setup Keys - S": {MinMsPerOpLocal: 0, MaxMsPerOpLocal: 3, MinMsPerOpCICD: 0, MaxMsPerOpCICD: 19}, + "Setup Keys - M": {MinMsPerOpLocal: 0, MaxMsPerOpLocal: 3, MinMsPerOpCICD: 0, MaxMsPerOpCICD: 19}, + "Setup Keys - L": {MinMsPerOpLocal: 0, MaxMsPerOpLocal: 3, MinMsPerOpCICD: 0, MaxMsPerOpCICD: 19}, + "Peers - L": {MinMsPerOpLocal: 0, MaxMsPerOpLocal: 3, MinMsPerOpCICD: 0, MaxMsPerOpCICD: 19}, + "Groups - L": {MinMsPerOpLocal: 0, MaxMsPerOpLocal: 3, MinMsPerOpCICD: 0, MaxMsPerOpCICD: 19}, + "Users - L": {MinMsPerOpLocal: 0, MaxMsPerOpLocal: 3, MinMsPerOpCICD: 0, MaxMsPerOpCICD: 19}, + "Setup Keys - XL": {MinMsPerOpLocal: 0, MaxMsPerOpLocal: 3, MinMsPerOpCICD: 0, MaxMsPerOpCICD: 19}, } log.SetOutput(io.Discard) @@ -128,14 +128,14 @@ func BenchmarkUpdateSetupKey(b *testing.B) { func BenchmarkGetOneSetupKey(b *testing.B) { var expectedMetrics = map[string]testing_tools.PerformanceMetrics{ - "Setup Keys - XS": {MinMsPerOpLocal: 0, MaxMsPerOpLocal: 2, MinMsPerOpCICD: 1, MaxMsPerOpCICD: 16}, - "Setup Keys - S": {MinMsPerOpLocal: 0, MaxMsPerOpLocal: 2, MinMsPerOpCICD: 1, MaxMsPerOpCICD: 16}, - "Setup Keys - M": {MinMsPerOpLocal: 0, MaxMsPerOpLocal: 2, MinMsPerOpCICD: 1, MaxMsPerOpCICD: 16}, - "Setup Keys - L": {MinMsPerOpLocal: 0, MaxMsPerOpLocal: 2, MinMsPerOpCICD: 1, MaxMsPerOpCICD: 16}, - "Peers - L": {MinMsPerOpLocal: 0, MaxMsPerOpLocal: 2, MinMsPerOpCICD: 1, MaxMsPerOpCICD: 16}, - "Groups - L": {MinMsPerOpLocal: 0, MaxMsPerOpLocal: 2, MinMsPerOpCICD: 1, MaxMsPerOpCICD: 16}, - "Users - L": {MinMsPerOpLocal: 
0, MaxMsPerOpLocal: 2, MinMsPerOpCICD: 1, MaxMsPerOpCICD: 16}, - "Setup Keys - XL": {MinMsPerOpLocal: 0, MaxMsPerOpLocal: 2, MinMsPerOpCICD: 1, MaxMsPerOpCICD: 16}, + "Setup Keys - XS": {MinMsPerOpLocal: 0, MaxMsPerOpLocal: 2, MinMsPerOpCICD: 0, MaxMsPerOpCICD: 16}, + "Setup Keys - S": {MinMsPerOpLocal: 0, MaxMsPerOpLocal: 2, MinMsPerOpCICD: 0, MaxMsPerOpCICD: 16}, + "Setup Keys - M": {MinMsPerOpLocal: 0, MaxMsPerOpLocal: 2, MinMsPerOpCICD: 0, MaxMsPerOpCICD: 16}, + "Setup Keys - L": {MinMsPerOpLocal: 0, MaxMsPerOpLocal: 2, MinMsPerOpCICD: 0, MaxMsPerOpCICD: 16}, + "Peers - L": {MinMsPerOpLocal: 0, MaxMsPerOpLocal: 2, MinMsPerOpCICD: 0, MaxMsPerOpCICD: 16}, + "Groups - L": {MinMsPerOpLocal: 0, MaxMsPerOpLocal: 2, MinMsPerOpCICD: 0, MaxMsPerOpCICD: 16}, + "Users - L": {MinMsPerOpLocal: 0, MaxMsPerOpLocal: 2, MinMsPerOpCICD: 0, MaxMsPerOpCICD: 16}, + "Setup Keys - XL": {MinMsPerOpLocal: 0, MaxMsPerOpLocal: 2, MinMsPerOpCICD: 0, MaxMsPerOpCICD: 16}, } log.SetOutput(io.Discard) diff --git a/management/server/http/testing/benchmarks/users_handler_benchmark_test.go b/management/server/http/testing/benchmarks/users_handler_benchmark_test.go index 9ddc23c8c..573c2200e 100644 --- a/management/server/http/testing/benchmarks/users_handler_benchmark_test.go +++ b/management/server/http/testing/benchmarks/users_handler_benchmark_test.go @@ -36,13 +36,13 @@ var benchCasesUsers = map[string]testing_tools.BenchmarkCase{ func BenchmarkUpdateUser(b *testing.B) { var expectedMetrics = map[string]testing_tools.PerformanceMetrics{ "Users - XS": {MinMsPerOpLocal: 700, MaxMsPerOpLocal: 1000, MinMsPerOpCICD: 1300, MaxMsPerOpCICD: 7000}, - "Users - S": {MinMsPerOpLocal: 1, MaxMsPerOpLocal: 5, MinMsPerOpCICD: 6, MaxMsPerOpCICD: 40}, + "Users - S": {MinMsPerOpLocal: 1, MaxMsPerOpLocal: 5, MinMsPerOpCICD: 4, MaxMsPerOpCICD: 40}, "Users - M": {MinMsPerOpLocal: 20, MaxMsPerOpLocal: 40, MinMsPerOpCICD: 30, MaxMsPerOpCICD: 200}, - "Users - L": {MinMsPerOpLocal: 60, MaxMsPerOpLocal: 100, MinMsPerOpCICD: 130, MaxMsPerOpCICD: 700}, + "Users - L": {MinMsPerOpLocal: 60, MaxMsPerOpLocal: 100, MinMsPerOpCICD: 90, MaxMsPerOpCICD: 700}, "Peers - L": {MinMsPerOpLocal: 300, MaxMsPerOpLocal: 500, MinMsPerOpCICD: 550, MaxMsPerOpCICD: 2000}, "Groups - L": {MinMsPerOpLocal: 400, MaxMsPerOpLocal: 600, MinMsPerOpCICD: 750, MaxMsPerOpCICD: 5000}, - "Setup Keys - L": {MinMsPerOpLocal: 50, MaxMsPerOpLocal: 200, MinMsPerOpCICD: 150, MaxMsPerOpCICD: 1000}, - "Users - XL": {MinMsPerOpLocal: 350, MaxMsPerOpLocal: 550, MinMsPerOpCICD: 700, MaxMsPerOpCICD: 3500}, + "Setup Keys - L": {MinMsPerOpLocal: 50, MaxMsPerOpLocal: 200, MinMsPerOpCICD: 130, MaxMsPerOpCICD: 1000}, + "Users - XL": {MinMsPerOpLocal: 350, MaxMsPerOpLocal: 550, MinMsPerOpCICD: 650, MaxMsPerOpCICD: 3500}, } log.SetOutput(io.Discard) @@ -119,11 +119,11 @@ func BenchmarkGetOneUser(b *testing.B) { func BenchmarkGetAllUsers(b *testing.B) { var expectedMetrics = map[string]testing_tools.PerformanceMetrics{ "Users - XS": {MinMsPerOpLocal: 50, MaxMsPerOpLocal: 90, MinMsPerOpCICD: 60, MaxMsPerOpCICD: 180}, - "Users - S": {MinMsPerOpLocal: 0, MaxMsPerOpLocal: 2, MinMsPerOpCICD: 1, MaxMsPerOpCICD: 30}, - "Users - M": {MinMsPerOpLocal: 5, MaxMsPerOpLocal: 12, MinMsPerOpCICD: 1, MaxMsPerOpCICD: 30}, - "Users - L": {MinMsPerOpLocal: 0, MaxMsPerOpLocal: 2, MinMsPerOpCICD: 1, MaxMsPerOpCICD: 30}, - "Peers - L": {MinMsPerOpLocal: 0, MaxMsPerOpLocal: 2, MinMsPerOpCICD: 1, MaxMsPerOpCICD: 30}, - "Groups - L": {MinMsPerOpLocal: 0, MaxMsPerOpLocal: 2, MinMsPerOpCICD: 1, MaxMsPerOpCICD: 30}, + "Users 
- S": {MinMsPerOpLocal: 0, MaxMsPerOpLocal: 2, MinMsPerOpCICD: 0, MaxMsPerOpCICD: 30}, + "Users - M": {MinMsPerOpLocal: 5, MaxMsPerOpLocal: 12, MinMsPerOpCICD: 0, MaxMsPerOpCICD: 30}, + "Users - L": {MinMsPerOpLocal: 0, MaxMsPerOpLocal: 2, MinMsPerOpCICD: 0, MaxMsPerOpCICD: 30}, + "Peers - L": {MinMsPerOpLocal: 0, MaxMsPerOpLocal: 2, MinMsPerOpCICD: 0, MaxMsPerOpCICD: 30}, + "Groups - L": {MinMsPerOpLocal: 0, MaxMsPerOpLocal: 2, MinMsPerOpCICD: 0, MaxMsPerOpCICD: 30}, "Setup Keys - L": {MinMsPerOpLocal: 40, MaxMsPerOpLocal: 140, MinMsPerOpCICD: 60, MaxMsPerOpCICD: 200}, "Users - XL": {MinMsPerOpLocal: 15, MaxMsPerOpLocal: 40, MinMsPerOpCICD: 20, MaxMsPerOpCICD: 90}, } @@ -158,7 +158,7 @@ func BenchmarkDeleteUsers(b *testing.B) { "Users - L": {MinMsPerOpLocal: 15, MaxMsPerOpLocal: 45, MinMsPerOpCICD: 30, MaxMsPerOpCICD: 190}, "Peers - L": {MinMsPerOpLocal: 400, MaxMsPerOpLocal: 600, MinMsPerOpCICD: 650, MaxMsPerOpCICD: 1800}, "Groups - L": {MinMsPerOpLocal: 600, MaxMsPerOpLocal: 800, MinMsPerOpCICD: 1200, MaxMsPerOpCICD: 7500}, - "Setup Keys - L": {MinMsPerOpLocal: 20, MaxMsPerOpLocal: 200, MinMsPerOpCICD: 55, MaxMsPerOpCICD: 600}, + "Setup Keys - L": {MinMsPerOpLocal: 20, MaxMsPerOpLocal: 200, MinMsPerOpCICD: 40, MaxMsPerOpCICD: 600}, "Users - XL": {MinMsPerOpLocal: 50, MaxMsPerOpLocal: 150, MinMsPerOpCICD: 80, MaxMsPerOpCICD: 400}, } From 2706ede08e00adabf41e31cad52c7cbac15b7d37 Mon Sep 17 00:00:00 2001 From: Pascal Fischer Date: Tue, 14 Jan 2025 19:42:48 +0100 Subject: [PATCH 143/159] update expectations --- .../benchmarks/peers_handler_benchmark_test.go | 14 +++++++------- .../benchmarks/setupkeys_handler_benchmark_test.go | 4 ++-- 2 files changed, 9 insertions(+), 9 deletions(-) diff --git a/management/server/http/testing/benchmarks/peers_handler_benchmark_test.go b/management/server/http/testing/benchmarks/peers_handler_benchmark_test.go index aaf8f203a..413fbb741 100644 --- a/management/server/http/testing/benchmarks/peers_handler_benchmark_test.go +++ b/management/server/http/testing/benchmarks/peers_handler_benchmark_test.go @@ -35,13 +35,13 @@ var benchCasesPeers = map[string]testing_tools.BenchmarkCase{ func BenchmarkUpdatePeer(b *testing.B) { var expectedMetrics = map[string]testing_tools.PerformanceMetrics{ - "Peers - XS": {MinMsPerOpLocal: 400, MaxMsPerOpLocal: 600, MinMsPerOpCICD: 600, MaxMsPerOpCICD: 2000}, + "Peers - XS": {MinMsPerOpLocal: 400, MaxMsPerOpLocal: 600, MinMsPerOpCICD: 600, MaxMsPerOpCICD: 3500}, "Peers - S": {MinMsPerOpLocal: 100, MaxMsPerOpLocal: 130, MinMsPerOpCICD: 80, MaxMsPerOpCICD: 200}, "Peers - M": {MinMsPerOpLocal: 130, MaxMsPerOpLocal: 150, MinMsPerOpCICD: 100, MaxMsPerOpCICD: 300}, - "Peers - L": {MinMsPerOpLocal: 230, MaxMsPerOpLocal: 270, MinMsPerOpCICD: 250, MaxMsPerOpCICD: 500}, - "Groups - L": {MinMsPerOpLocal: 400, MaxMsPerOpLocal: 600, MinMsPerOpCICD: 650, MaxMsPerOpCICD: 100}, - "Users - L": {MinMsPerOpLocal: 200, MaxMsPerOpLocal: 400, MinMsPerOpCICD: 350, MaxMsPerOpCICD: 600}, - "Setup Keys - L": {MinMsPerOpLocal: 200, MaxMsPerOpLocal: 400, MinMsPerOpCICD: 350, MaxMsPerOpCICD: 600}, + "Peers - L": {MinMsPerOpLocal: 230, MaxMsPerOpLocal: 270, MinMsPerOpCICD: 200, MaxMsPerOpCICD: 500}, + "Groups - L": {MinMsPerOpLocal: 400, MaxMsPerOpLocal: 600, MinMsPerOpCICD: 650, MaxMsPerOpCICD: 3500}, + "Users - L": {MinMsPerOpLocal: 200, MaxMsPerOpLocal: 400, MinMsPerOpCICD: 300, MaxMsPerOpCICD: 600}, + "Setup Keys - L": {MinMsPerOpLocal: 200, MaxMsPerOpLocal: 400, MinMsPerOpCICD: 300, MaxMsPerOpCICD: 600}, "Peers - XL": {MinMsPerOpLocal: 600, 
MaxMsPerOpLocal: 1000, MinMsPerOpCICD: 600, MaxMsPerOpCICD: 2000}, } @@ -81,7 +81,7 @@ func BenchmarkGetOnePeer(b *testing.B) { "Peers - S": {MinMsPerOpLocal: 1, MaxMsPerOpLocal: 5, MinMsPerOpCICD: 2, MaxMsPerOpCICD: 30}, "Peers - M": {MinMsPerOpLocal: 9, MaxMsPerOpLocal: 18, MinMsPerOpCICD: 15, MaxMsPerOpCICD: 50}, "Peers - L": {MinMsPerOpLocal: 40, MaxMsPerOpLocal: 90, MinMsPerOpCICD: 50, MaxMsPerOpCICD: 130}, - "Groups - L": {MinMsPerOpLocal: 80, MaxMsPerOpLocal: 130, MinMsPerOpCICD: 150, MaxMsPerOpCICD: 400}, + "Groups - L": {MinMsPerOpLocal: 80, MaxMsPerOpLocal: 130, MinMsPerOpCICD: 100, MaxMsPerOpCICD: 300}, "Users - L": {MinMsPerOpLocal: 40, MaxMsPerOpLocal: 90, MinMsPerOpCICD: 50, MaxMsPerOpCICD: 130}, "Setup Keys - L": {MinMsPerOpLocal: 40, MaxMsPerOpLocal: 90, MinMsPerOpCICD: 50, MaxMsPerOpCICD: 130}, "Peers - XL": {MinMsPerOpLocal: 200, MaxMsPerOpLocal: 400, MinMsPerOpCICD: 200, MaxMsPerOpCICD: 600}, @@ -114,7 +114,7 @@ func BenchmarkGetAllPeers(b *testing.B) { "Peers - XS": {MinMsPerOpLocal: 40, MaxMsPerOpLocal: 70, MinMsPerOpCICD: 50, MaxMsPerOpCICD: 150}, "Peers - S": {MinMsPerOpLocal: 2, MaxMsPerOpLocal: 10, MinMsPerOpCICD: 5, MaxMsPerOpCICD: 30}, "Peers - M": {MinMsPerOpLocal: 20, MaxMsPerOpLocal: 50, MinMsPerOpCICD: 20, MaxMsPerOpCICD: 70}, - "Peers - L": {MinMsPerOpLocal: 130, MaxMsPerOpLocal: 170, MinMsPerOpCICD: 100, MaxMsPerOpCICD: 200}, + "Peers - L": {MinMsPerOpLocal: 130, MaxMsPerOpLocal: 170, MinMsPerOpCICD: 100, MaxMsPerOpCICD: 300}, "Groups - L": {MinMsPerOpLocal: 4800, MaxMsPerOpLocal: 5300, MinMsPerOpCICD: 7000, MaxMsPerOpCICD: 15000}, "Users - L": {MinMsPerOpLocal: 100, MaxMsPerOpLocal: 170, MinMsPerOpCICD: 100, MaxMsPerOpCICD: 400}, "Setup Keys - L": {MinMsPerOpLocal: 100, MaxMsPerOpLocal: 170, MinMsPerOpCICD: 100, MaxMsPerOpCICD: 400}, diff --git a/management/server/http/testing/benchmarks/setupkeys_handler_benchmark_test.go b/management/server/http/testing/benchmarks/setupkeys_handler_benchmark_test.go index 6d3d55141..ed643f75e 100644 --- a/management/server/http/testing/benchmarks/setupkeys_handler_benchmark_test.go +++ b/management/server/http/testing/benchmarks/setupkeys_handler_benchmark_test.go @@ -162,8 +162,8 @@ func BenchmarkGetOneSetupKey(b *testing.B) { func BenchmarkGetAllSetupKeys(b *testing.B) { var expectedMetrics = map[string]testing_tools.PerformanceMetrics{ - "Setup Keys - XS": {MinMsPerOpLocal: 0.5, MaxMsPerOpLocal: 2, MinMsPerOpCICD: 1, MaxMsPerOpCICD: 12}, - "Setup Keys - S": {MinMsPerOpLocal: 0.5, MaxMsPerOpLocal: 2, MinMsPerOpCICD: 1, MaxMsPerOpCICD: 15}, + "Setup Keys - XS": {MinMsPerOpLocal: 0.5, MaxMsPerOpLocal: 2, MinMsPerOpCICD: 0, MaxMsPerOpCICD: 12}, + "Setup Keys - S": {MinMsPerOpLocal: 0.5, MaxMsPerOpLocal: 2, MinMsPerOpCICD: 0, MaxMsPerOpCICD: 15}, "Setup Keys - M": {MinMsPerOpLocal: 5, MaxMsPerOpLocal: 10, MinMsPerOpCICD: 5, MaxMsPerOpCICD: 40}, "Setup Keys - L": {MinMsPerOpLocal: 30, MaxMsPerOpLocal: 50, MinMsPerOpCICD: 30, MaxMsPerOpCICD: 150}, "Peers - L": {MinMsPerOpLocal: 30, MaxMsPerOpLocal: 50, MinMsPerOpCICD: 30, MaxMsPerOpCICD: 150}, From 8b7766e34dab911c57b01fbe9246cbf5c9b5f36b Mon Sep 17 00:00:00 2001 From: Pascal Fischer Date: Tue, 14 Jan 2025 20:10:04 +0100 Subject: [PATCH 144/159] update expectations --- .github/workflows/golang-test-linux.yml | 2 +- .../http/testing/benchmarks/peers_handler_benchmark_test.go | 4 ++-- 2 files changed, 3 insertions(+), 3 deletions(-) diff --git a/.github/workflows/golang-test-linux.yml b/.github/workflows/golang-test-linux.yml index ab3e8d652..04ec59975 100644 --- 
a/.github/workflows/golang-test-linux.yml +++ b/.github/workflows/golang-test-linux.yml @@ -314,7 +314,7 @@ jobs: run: docker pull mlsmaycon/warmed-mysql:8 - name: Test - run: CGO_ENABLED=1 GOARCH=${{ matrix.arch }} NETBIRD_STORE_ENGINE=${{ matrix.store }} CI=true go test -run=^$ -tags=benchmark -bench=. -exec 'sudo --preserve-env=CI,NETBIRD_STORE_ENGINE' -timeout 10m $(go list -tags=benchmark ./... | grep /management) + run: CGO_ENABLED=1 GOARCH=${{ matrix.arch }} NETBIRD_STORE_ENGINE=${{ matrix.store }} CI=true go test -run=^$ -tags=benchmark -bench=. -exec 'sudo --preserve-env=CI,NETBIRD_STORE_ENGINE' -timeout 15m $(go list -tags=benchmark ./... | grep /management) api_integration_test: needs: [ build-cache ] diff --git a/management/server/http/testing/benchmarks/peers_handler_benchmark_test.go b/management/server/http/testing/benchmarks/peers_handler_benchmark_test.go index 413fbb741..5caa0d4f9 100644 --- a/management/server/http/testing/benchmarks/peers_handler_benchmark_test.go +++ b/management/server/http/testing/benchmarks/peers_handler_benchmark_test.go @@ -84,7 +84,7 @@ func BenchmarkGetOnePeer(b *testing.B) { "Groups - L": {MinMsPerOpLocal: 80, MaxMsPerOpLocal: 130, MinMsPerOpCICD: 100, MaxMsPerOpCICD: 300}, "Users - L": {MinMsPerOpLocal: 40, MaxMsPerOpLocal: 90, MinMsPerOpCICD: 50, MaxMsPerOpCICD: 130}, "Setup Keys - L": {MinMsPerOpLocal: 40, MaxMsPerOpLocal: 90, MinMsPerOpCICD: 50, MaxMsPerOpCICD: 130}, - "Peers - XL": {MinMsPerOpLocal: 200, MaxMsPerOpLocal: 400, MinMsPerOpCICD: 200, MaxMsPerOpCICD: 600}, + "Peers - XL": {MinMsPerOpLocal: 200, MaxMsPerOpLocal: 400, MinMsPerOpCICD: 200, MaxMsPerOpCICD: 750}, } log.SetOutput(io.Discard) @@ -118,7 +118,7 @@ func BenchmarkGetAllPeers(b *testing.B) { "Groups - L": {MinMsPerOpLocal: 4800, MaxMsPerOpLocal: 5300, MinMsPerOpCICD: 7000, MaxMsPerOpCICD: 15000}, "Users - L": {MinMsPerOpLocal: 100, MaxMsPerOpLocal: 170, MinMsPerOpCICD: 100, MaxMsPerOpCICD: 400}, "Setup Keys - L": {MinMsPerOpLocal: 100, MaxMsPerOpLocal: 170, MinMsPerOpCICD: 100, MaxMsPerOpCICD: 400}, - "Peers - XL": {MinMsPerOpLocal: 900, MaxMsPerOpLocal: 1300, MinMsPerOpCICD: 800, MaxMsPerOpCICD: 1500}, + "Peers - XL": {MinMsPerOpLocal: 900, MaxMsPerOpLocal: 1300, MinMsPerOpCICD: 800, MaxMsPerOpCICD: 2300}, } log.SetOutput(io.Discard) From 84aea32118662027e2cb5daed1f0fe29b3ad81ef Mon Sep 17 00:00:00 2001 From: bcmmbaga Date: Tue, 14 Jan 2025 22:13:11 +0300 Subject: [PATCH 145/159] Refactor peer scheduler to retry every 3 seconds on errors Signed-off-by: bcmmbaga --- management/server/account.go | 9 +++++---- management/server/peer.go | 27 ++++++++++++++++++--------- 2 files changed, 23 insertions(+), 13 deletions(-) diff --git a/management/server/account.go b/management/server/account.go index 19ff96cfb..9a1ca9866 100644 --- a/management/server/account.go +++ b/management/server/account.go @@ -45,6 +45,7 @@ import ( const ( CacheExpirationMax = 7 * 24 * 3600 * time.Second // 7 days CacheExpirationMin = 3 * 24 * 3600 * time.Second // 3 days + peerSchedulerRetryInterval = 3 * time.Second emptyUserID = "empty user ID in claims" errorGettingDomainAccIDFmt = "error getting account ID by private domain: %v" ) @@ -469,7 +470,7 @@ func (am *DefaultAccountManager) peerLoginExpirationJob(ctx context.Context, acc expiredPeers, err := am.getExpiredPeers(ctx, accountID) if err != nil { - return 0, false + return peerSchedulerRetryInterval, true } var peerIDs []string @@ -481,7 +482,7 @@ func (am *DefaultAccountManager) peerLoginExpirationJob(ctx context.Context, acc if err := 
am.expireAndUpdatePeers(ctx, accountID, expiredPeers); err != nil { log.WithContext(ctx).Errorf("failed updating account peers while expiring peers for account %s", accountID) - return 0, false + return peerSchedulerRetryInterval, true } return am.getNextPeerExpiration(ctx, accountID) @@ -504,7 +505,7 @@ func (am *DefaultAccountManager) peerInactivityExpirationJob(ctx context.Context inactivePeers, err := am.getInactivePeers(ctx, accountID) if err != nil { log.WithContext(ctx).Errorf("failed getting inactive peers for account %s", accountID) - return 0, false + return peerSchedulerRetryInterval, true } var peerIDs []string @@ -516,7 +517,7 @@ func (am *DefaultAccountManager) peerInactivityExpirationJob(ctx context.Context if err := am.expireAndUpdatePeers(ctx, accountID, inactivePeers); err != nil { log.Errorf("failed updating account peers while expiring peers for account %s", accountID) - return 0, false + return peerSchedulerRetryInterval, true } return am.getNextInactivePeerExpiration(ctx, accountID) diff --git a/management/server/peer.go b/management/server/peer.go index d9a58ed7f..9921c0c9b 100644 --- a/management/server/peer.go +++ b/management/server/peer.go @@ -335,6 +335,15 @@ func (am *DefaultAccountManager) DeletePeer(ctx context.Context, accountID, peer unlock := am.Store.AcquireWriteLockByUID(ctx, accountID) defer unlock() + user, err := am.Store.GetUserByUserID(ctx, store.LockingStrengthShare, userID) + if err != nil { + return err + } + + if user.AccountID != accountID { + return status.NewUserNotPartOfAccountError() + } + peerAccountID, err := am.Store.GetAccountIDByPeerID(ctx, store.LockingStrengthShare, peerID) if err != nil { return err @@ -1139,6 +1148,11 @@ func (am *DefaultAccountManager) UpdateAccountPeers(ctx context.Context, account // UpdateAccountPeer updates a single peer that belongs to an account. // Should be called when changes need to be synced to a specific peer only. func (am *DefaultAccountManager) UpdateAccountPeer(ctx context.Context, accountId string, peerId string) { + if !am.peersUpdateManager.HasChannel(peerId) { + log.WithContext(ctx).Tracef("peer %s doesn't have a channel, skipping network map update", peerId) + return + } + account, err := am.requestBuffer.GetAccountWithBackpressure(ctx, accountId) if err != nil { log.WithContext(ctx).Errorf("failed to send out updates to peer %s. 
failed to get account: %v", peerId, err) @@ -1151,11 +1165,6 @@ func (am *DefaultAccountManager) UpdateAccountPeer(ctx context.Context, accountI return } - if !am.peersUpdateManager.HasChannel(peerId) { - log.WithContext(ctx).Tracef("peer %s doesn't have a channel, skipping network map update", peerId) - return - } - approvedPeersMap, err := am.integratedPeerValidator.GetValidatedPeers(account.Id, account.Groups, account.Peers, account.Settings.Extra) if err != nil { log.WithContext(ctx).Errorf("failed to send update to peer %s, failed to validate peers: %v", peerId, err) @@ -1185,7 +1194,7 @@ func (am *DefaultAccountManager) getNextPeerExpiration(ctx context.Context, acco peersWithExpiry, err := am.Store.GetAccountPeersWithExpiration(ctx, store.LockingStrengthShare, accountID) if err != nil { log.WithContext(ctx).Errorf("failed to get peers with expiration: %v", err) - return 0, false + return peerSchedulerRetryInterval, true } if len(peersWithExpiry) == 0 { @@ -1195,7 +1204,7 @@ func (am *DefaultAccountManager) getNextPeerExpiration(ctx context.Context, acco settings, err := am.Store.GetAccountSettings(ctx, store.LockingStrengthShare, accountID) if err != nil { log.WithContext(ctx).Errorf("failed to get account settings: %v", err) - return 0, false + return peerSchedulerRetryInterval, true } var nextExpiry *time.Duration @@ -1229,7 +1238,7 @@ func (am *DefaultAccountManager) getNextInactivePeerExpiration(ctx context.Conte peersWithInactivity, err := am.Store.GetAccountPeersWithInactivity(ctx, store.LockingStrengthShare, accountID) if err != nil { log.WithContext(ctx).Errorf("failed to get peers with inactivity: %v", err) - return 0, false + return peerSchedulerRetryInterval, true } if len(peersWithInactivity) == 0 { @@ -1239,7 +1248,7 @@ func (am *DefaultAccountManager) getNextInactivePeerExpiration(ctx context.Conte settings, err := am.Store.GetAccountSettings(ctx, store.LockingStrengthShare, accountID) if err != nil { log.WithContext(ctx).Errorf("failed to get account settings: %v", err) - return 0, false + return peerSchedulerRetryInterval, true } var nextExpiry *time.Duration From 29ea44b874d63cb897d8e79d4c722154b7484eb9 Mon Sep 17 00:00:00 2001 From: Pascal Fischer Date: Tue, 14 Jan 2025 20:44:06 +0100 Subject: [PATCH 146/159] update expectations --- management/server/account_test.go | 2 +- .../benchmarks/users_handler_benchmark_test.go | 10 +++++----- management/server/integrated_validator.go | 14 ++------------ management/server/peer_test.go | 2 +- 4 files changed, 9 insertions(+), 19 deletions(-) diff --git a/management/server/account_test.go b/management/server/account_test.go index 1fe20bbf5..57bc0c757 100644 --- a/management/server/account_test.go +++ b/management/server/account_test.go @@ -3158,7 +3158,7 @@ func BenchmarkLoginPeer_NewPeer(b *testing.B) { }{ {"Small", 50, 5, 107, 120, 10, 80}, {"Medium", 500, 100, 105, 140, 30, 140}, - {"Large", 5000, 200, 180, 220, 150, 300}, + {"Large", 5000, 200, 180, 220, 140, 300}, {"Small single", 50, 10, 107, 120, 10, 80}, {"Medium single", 500, 10, 105, 140, 20, 60}, {"Large 5", 5000, 15, 180, 220, 80, 200}, diff --git a/management/server/http/testing/benchmarks/users_handler_benchmark_test.go b/management/server/http/testing/benchmarks/users_handler_benchmark_test.go index 573c2200e..549a51c0e 100644 --- a/management/server/http/testing/benchmarks/users_handler_benchmark_test.go +++ b/management/server/http/testing/benchmarks/users_handler_benchmark_test.go @@ -35,11 +35,11 @@ var benchCasesUsers = 
map[string]testing_tools.BenchmarkCase{ func BenchmarkUpdateUser(b *testing.B) { var expectedMetrics = map[string]testing_tools.PerformanceMetrics{ - "Users - XS": {MinMsPerOpLocal: 700, MaxMsPerOpLocal: 1000, MinMsPerOpCICD: 1300, MaxMsPerOpCICD: 7000}, - "Users - S": {MinMsPerOpLocal: 1, MaxMsPerOpLocal: 5, MinMsPerOpCICD: 4, MaxMsPerOpCICD: 40}, - "Users - M": {MinMsPerOpLocal: 20, MaxMsPerOpLocal: 40, MinMsPerOpCICD: 30, MaxMsPerOpCICD: 200}, + "Users - XS": {MinMsPerOpLocal: 700, MaxMsPerOpLocal: 1000, MinMsPerOpCICD: 1300, MaxMsPerOpCICD: 8000}, + "Users - S": {MinMsPerOpLocal: 1, MaxMsPerOpLocal: 5, MinMsPerOpCICD: 4, MaxMsPerOpCICD: 50}, + "Users - M": {MinMsPerOpLocal: 20, MaxMsPerOpLocal: 40, MinMsPerOpCICD: 30, MaxMsPerOpCICD: 250}, "Users - L": {MinMsPerOpLocal: 60, MaxMsPerOpLocal: 100, MinMsPerOpCICD: 90, MaxMsPerOpCICD: 700}, - "Peers - L": {MinMsPerOpLocal: 300, MaxMsPerOpLocal: 500, MinMsPerOpCICD: 550, MaxMsPerOpCICD: 2000}, + "Peers - L": {MinMsPerOpLocal: 300, MaxMsPerOpLocal: 500, MinMsPerOpCICD: 550, MaxMsPerOpCICD: 2400}, "Groups - L": {MinMsPerOpLocal: 400, MaxMsPerOpLocal: 600, MinMsPerOpCICD: 750, MaxMsPerOpCICD: 5000}, "Setup Keys - L": {MinMsPerOpLocal: 50, MaxMsPerOpLocal: 200, MinMsPerOpCICD: 130, MaxMsPerOpCICD: 1000}, "Users - XL": {MinMsPerOpLocal: 350, MaxMsPerOpLocal: 550, MinMsPerOpCICD: 650, MaxMsPerOpCICD: 3500}, @@ -152,7 +152,7 @@ func BenchmarkGetAllUsers(b *testing.B) { func BenchmarkDeleteUsers(b *testing.B) { var expectedMetrics = map[string]testing_tools.PerformanceMetrics{ - "Users - XS": {MinMsPerOpLocal: 1000, MaxMsPerOpLocal: 1600, MinMsPerOpCICD: 1900, MaxMsPerOpCICD: 10000}, + "Users - XS": {MinMsPerOpLocal: 1000, MaxMsPerOpLocal: 1600, MinMsPerOpCICD: 1900, MaxMsPerOpCICD: 11000}, "Users - S": {MinMsPerOpLocal: 15, MaxMsPerOpLocal: 40, MinMsPerOpCICD: 30, MaxMsPerOpCICD: 200}, "Users - M": {MinMsPerOpLocal: 15, MaxMsPerOpLocal: 70, MinMsPerOpCICD: 15, MaxMsPerOpCICD: 230}, "Users - L": {MinMsPerOpLocal: 15, MaxMsPerOpLocal: 45, MinMsPerOpCICD: 30, MaxMsPerOpCICD: 190}, diff --git a/management/server/integrated_validator.go b/management/server/integrated_validator.go index 9dad6fcd7..79063d533 100644 --- a/management/server/integrated_validator.go +++ b/management/server/integrated_validator.go @@ -100,17 +100,7 @@ func (am *DefaultAccountManager) GetValidatedPeers(ctx context.Context, accountI return nil, err } - groupsMap := make(map[string]*types.Group, len(groups)) - for _, group := range groups { - groupsMap[group.ID] = group - } - - peersMap := make(map[string]*nbpeer.Peer, len(peers)) - for _, peer := range peers { - peersMap[peer.ID] = peer - } - - return am.integratedPeerValidator.GetValidatedPeers(accountID, groupsMap, peersMap, settings.Extra) + return am.integratedPeerValidator.GetValidatedPeers(accountID, groups, peers, settings.Extra) } type MocIntegratedValidator struct { @@ -127,7 +117,7 @@ func (a MocIntegratedValidator) ValidatePeer(_ context.Context, update *nbpeer.P } return update, false, nil } -func (a MocIntegratedValidator) GetValidatedPeers(accountID string, groups map[string]*types.Group, peers map[string]*nbpeer.Peer, extraSettings *account.ExtraSettings) (map[string]struct{}, error) { +func (a MocIntegratedValidator) GetValidatedPeers(accountID string, groups []*types.Group, peers []*nbpeer.Peer, extraSettings *account.ExtraSettings) (map[string]struct{}, error) { validatedPeers := make(map[string]struct{}) for _, peer := range peers { validatedPeers[peer.ID] = struct{}{} diff --git 
a/management/server/peer_test.go b/management/server/peer_test.go index 2f5d0e047..bf712f38a 100644 --- a/management/server/peer_test.go +++ b/management/server/peer_test.go @@ -938,7 +938,7 @@ func BenchmarkUpdateAccountPeers(b *testing.B) { {"Small single", 50, 10, 90, 120, 90, 120}, {"Medium single", 500, 10, 110, 170, 120, 200}, {"Large 5", 5000, 15, 1300, 2100, 4900, 7000}, - {"Extra Large", 2000, 2000, 1300, 2400, 3800, 6400}, + {"Extra Large", 2000, 2000, 1300, 2400, 3000, 6400}, } log.SetOutput(io.Discard) From 61b38e56e4fae8b885c0e9d32f33b25d24058af5 Mon Sep 17 00:00:00 2001 From: Pascal Fischer Date: Tue, 14 Jan 2025 20:51:07 +0100 Subject: [PATCH 147/159] fix validator --- management/server/integrated_validator.go | 13 ++++++++++++- 1 file changed, 12 insertions(+), 1 deletion(-) diff --git a/management/server/integrated_validator.go b/management/server/integrated_validator.go index 79063d533..79f0b698d 100644 --- a/management/server/integrated_validator.go +++ b/management/server/integrated_validator.go @@ -100,7 +100,18 @@ func (am *DefaultAccountManager) GetValidatedPeers(ctx context.Context, accountI return nil, err } - return am.integratedPeerValidator.GetValidatedPeers(accountID, groups, peers, settings.Extra) + groupsMap := make(map[string]*types.Group, len(groups)) + + for _, group := range groups { + groupsMap[group.ID] = group + } + + peersMap := make(map[string]*nbpeer.Peer, len(peers)) + for _, peer := range peers { + peersMap[peer.ID] = peer + } + + return am.integratedPeerValidator.GetValidatedPeers(accountID, groupsMap, peersMap, settings.Extra) } type MocIntegratedValidator struct { From a5731fe5096f5a48e17f1df34a5256eac991c753 Mon Sep 17 00:00:00 2001 From: Pascal Fischer Date: Tue, 14 Jan 2025 20:56:19 +0100 Subject: [PATCH 148/159] fix validator --- management/server/integrated_validator.go | 15 ++------------- 1 file changed, 2 insertions(+), 13 deletions(-) diff --git a/management/server/integrated_validator.go b/management/server/integrated_validator.go index 79f0b698d..9e347ddac 100644 --- a/management/server/integrated_validator.go +++ b/management/server/integrated_validator.go @@ -100,18 +100,7 @@ func (am *DefaultAccountManager) GetValidatedPeers(ctx context.Context, accountI return nil, err } - groupsMap := make(map[string]*types.Group, len(groups)) - - for _, group := range groups { - groupsMap[group.ID] = group - } - - peersMap := make(map[string]*nbpeer.Peer, len(peers)) - for _, peer := range peers { - peersMap[peer.ID] = peer - } - - return am.integratedPeerValidator.GetValidatedPeers(accountID, groupsMap, peersMap, settings.Extra) + return am.integratedPeerValidator.GetValidatedPeers(accountID, groups, peers, settings.Extra) } type MocIntegratedValidator struct { @@ -128,7 +117,7 @@ func (a MocIntegratedValidator) ValidatePeer(_ context.Context, update *nbpeer.P } return update, false, nil } -func (a MocIntegratedValidator) GetValidatedPeers(accountID string, groups []*types.Group, peers []*nbpeer.Peer, extraSettings *account.ExtraSettings) (map[string]struct{}, error) { +func (a MocIntegratedValidator) GetValidatedPeers(accountID string, groups map[string]*types.Group, peers map[string]*nbpeer.Peer, extraSettings *account.ExtraSettings) (map[string]struct{}, error) { validatedPeers := make(map[string]struct{}) for _, peer := range peers { validatedPeers[peer.ID] = struct{}{} From 44f69c70e0a112d53a964f800752852791b23161 Mon Sep 17 00:00:00 2001 From: Pascal Fischer Date: Tue, 14 Jan 2025 21:01:15 +0100 Subject: [PATCH 149/159] fix validator --- 
management/server/integrated_validator.go | 14 +++++++++++++- 1 file changed, 13 insertions(+), 1 deletion(-) diff --git a/management/server/integrated_validator.go b/management/server/integrated_validator.go index 9e347ddac..a7bc8523b 100644 --- a/management/server/integrated_validator.go +++ b/management/server/integrated_validator.go @@ -100,7 +100,18 @@ func (am *DefaultAccountManager) GetValidatedPeers(ctx context.Context, accountI return nil, err } - return am.integratedPeerValidator.GetValidatedPeers(accountID, groups, peers, settings.Extra) + groupsMap := make(map[string]*types.Group, len(groups)) + + for _, group := range groups { + groupsMap[group.ID] = group + } + + peersMap := make(map[string]*nbpeer.Peer, len(peers)) + for _, peer := range peers { + peersMap[peer.ID] = peer + } + + return am.integratedPeerValidator.GetValidatedPeers(accountID, groupsMap, peersMap, settings.Extra) } type MocIntegratedValidator struct { @@ -117,6 +128,7 @@ func (a MocIntegratedValidator) ValidatePeer(_ context.Context, update *nbpeer.P } return update, false, nil } + func (a MocIntegratedValidator) GetValidatedPeers(accountID string, groups map[string]*types.Group, peers map[string]*nbpeer.Peer, extraSettings *account.ExtraSettings) (map[string]struct{}, error) { validatedPeers := make(map[string]struct{}) for _, peer := range peers { From b1e8ed889efecbff22671ce7b1e4f07944b8181e Mon Sep 17 00:00:00 2001 From: Pascal Fischer Date: Tue, 14 Jan 2025 21:28:48 +0100 Subject: [PATCH 150/159] update timeouts --- .github/workflows/golang-test-linux.yml | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/.github/workflows/golang-test-linux.yml b/.github/workflows/golang-test-linux.yml index 04ec59975..f1e7c299d 100644 --- a/.github/workflows/golang-test-linux.yml +++ b/.github/workflows/golang-test-linux.yml @@ -314,7 +314,7 @@ jobs: run: docker pull mlsmaycon/warmed-mysql:8 - name: Test - run: CGO_ENABLED=1 GOARCH=${{ matrix.arch }} NETBIRD_STORE_ENGINE=${{ matrix.store }} CI=true go test -run=^$ -tags=benchmark -bench=. -exec 'sudo --preserve-env=CI,NETBIRD_STORE_ENGINE' -timeout 15m $(go list -tags=benchmark ./... | grep /management) + run: CGO_ENABLED=1 GOARCH=${{ matrix.arch }} NETBIRD_STORE_ENGINE=${{ matrix.store }} CI=true go test -run=^$ -tags=benchmark -bench=. -exec 'sudo --preserve-env=CI,NETBIRD_STORE_ENGINE' -timeout 20m $(go list -tags=benchmark ./... | grep /management) api_integration_test: needs: [ build-cache ] @@ -363,7 +363,7 @@ jobs: run: git --no-pager diff --exit-code - name: Test - run: CGO_ENABLED=1 GOARCH=${{ matrix.arch }} NETBIRD_STORE_ENGINE=${{ matrix.store }} CI=true go test -tags=integration -p 1 -exec 'sudo --preserve-env=CI,NETBIRD_STORE_ENGINE' -timeout 15m $(go list -tags=integration ./... | grep /management) + run: CGO_ENABLED=1 GOARCH=${{ matrix.arch }} NETBIRD_STORE_ENGINE=${{ matrix.store }} CI=true go test -tags=integration -p 1 -exec 'sudo --preserve-env=CI,NETBIRD_STORE_ENGINE' -timeout 20m $(go list -tags=integration ./... 
| grep /management) test_client_on_docker: needs: [ build-cache ] From b15ee5c07c775f3e587cca04c292a08aaa772dca Mon Sep 17 00:00:00 2001 From: bcmmbaga Date: Wed, 15 Jan 2025 00:09:10 +0300 Subject: [PATCH 151/159] Refactor ToGroupsInfo to process slices of groups Signed-off-by: bcmmbaga --- management/server/groups/manager.go | 22 +++++++++++++++---- .../server/http/handlers/networks/handler.go | 4 ++-- .../http/handlers/peers/peers_handler.go | 22 ++++--------------- 3 files changed, 24 insertions(+), 24 deletions(-) diff --git a/management/server/groups/manager.go b/management/server/groups/manager.go index f5abb212e..02b669e41 100644 --- a/management/server/groups/manager.go +++ b/management/server/groups/manager.go @@ -13,7 +13,8 @@ import ( ) type Manager interface { - GetAllGroups(ctx context.Context, accountID, userID string) (map[string]*types.Group, error) + GetAllGroups(ctx context.Context, accountID, userID string) ([]*types.Group, error) + GetAllGroupsMap(ctx context.Context, accountID, userID string) (map[string]*types.Group, error) GetResourceGroupsInTransaction(ctx context.Context, transaction store.Store, lockingStrength store.LockingStrength, accountID, resourceID string) ([]*types.Group, error) AddResourceToGroup(ctx context.Context, accountID, userID, groupID string, resourceID *types.Resource) error AddResourceToGroupInTransaction(ctx context.Context, transaction store.Store, accountID, userID, groupID string, resourceID *types.Resource) (func(), error) @@ -37,7 +38,7 @@ func NewManager(store store.Store, permissionsManager permissions.Manager, accou } } -func (m *managerImpl) GetAllGroups(ctx context.Context, accountID, userID string) (map[string]*types.Group, error) { +func (m *managerImpl) GetAllGroups(ctx context.Context, accountID, userID string) ([]*types.Group, error) { ok, err := m.permissionsManager.ValidateUserPermissions(ctx, accountID, userID, permissions.Groups, permissions.Read) if err != nil { return nil, err @@ -51,6 +52,15 @@ func (m *managerImpl) GetAllGroups(ctx context.Context, accountID, userID string return nil, fmt.Errorf("error getting account groups: %w", err) } + return groups, nil +} + +func (m *managerImpl) GetAllGroupsMap(ctx context.Context, accountID, userID string) (map[string]*types.Group, error) { + groups, err := m.GetAllGroups(ctx, accountID, userID) + if err != nil { + return nil, err + } + groupsMap := make(map[string]*types.Group) for _, group := range groups { groupsMap[group.ID] = group @@ -130,7 +140,7 @@ func (m *managerImpl) GetResourceGroupsInTransaction(ctx context.Context, transa return transaction.GetResourceGroups(ctx, lockingStrength, accountID, resourceID) } -func ToGroupsInfo(groups map[string]*types.Group, id string) []api.GroupMinimum { +func ToGroupsInfo(groups []*types.Group, id string) []api.GroupMinimum { groupsInfo := []api.GroupMinimum{} groupsChecked := make(map[string]struct{}) for _, group := range groups { @@ -167,7 +177,11 @@ func ToGroupsInfo(groups map[string]*types.Group, id string) []api.GroupMinimum return groupsInfo } -func (m *mockManager) GetAllGroups(ctx context.Context, accountID, userID string) (map[string]*types.Group, error) { +func (m *mockManager) GetAllGroups(ctx context.Context, accountID, userID string) ([]*types.Group, error) { + return []*types.Group{}, nil +} + +func (m *mockManager) GetAllGroupsMap(ctx context.Context, accountID, userID string) (map[string]*types.Group, error) { return map[string]*types.Group{}, nil } diff --git a/management/server/http/handlers/networks/handler.go 
b/management/server/http/handlers/networks/handler.go index 6b36a8fce..316b93611 100644 --- a/management/server/http/handlers/networks/handler.go +++ b/management/server/http/handlers/networks/handler.go @@ -82,7 +82,7 @@ func (h *handler) getAllNetworks(w http.ResponseWriter, r *http.Request) { return } - groups, err := h.groupsManager.GetAllGroups(r.Context(), accountID, userID) + groups, err := h.groupsManager.GetAllGroupsMap(r.Context(), accountID, userID) if err != nil { util.WriteError(r.Context(), err, w) return @@ -267,7 +267,7 @@ func (h *handler) collectIDsInNetwork(ctx context.Context, accountID, userID, ne return nil, nil, 0, fmt.Errorf("failed to get routers in network: %w", err) } - groups, err := h.groupsManager.GetAllGroups(ctx, accountID, userID) + groups, err := h.groupsManager.GetAllGroupsMap(ctx, accountID, userID) if err != nil { return nil, nil, 0, fmt.Errorf("failed to get groups: %w", err) } diff --git a/management/server/http/handlers/peers/peers_handler.go b/management/server/http/handlers/peers/peers_handler.go index 09d63ea6f..76a0149c6 100644 --- a/management/server/http/handlers/peers/peers_handler.go +++ b/management/server/http/handlers/peers/peers_handler.go @@ -72,13 +72,8 @@ func (h *Handler) getPeer(ctx context.Context, accountID, peerID, userID string, } dnsDomain := h.accountManager.GetDNSDomain() - groupsMap := map[string]*types.Group{} - grps, _ := h.accountManager.GetAllGroups(ctx, accountID, userID) - for _, group := range grps { - groupsMap[group.ID] = group - } - - groupsInfo := groups.ToGroupsInfo(groupsMap, peerID) + grps, _ := h.accountManager.GetPeerGroups(ctx, accountID, peerID) + groupsInfo := groups.ToGroupsInfo(grps, peerID) validPeers, err := h.accountManager.GetValidatedPeers(ctx, accountID) if err != nil { @@ -128,12 +123,7 @@ func (h *Handler) updatePeer(ctx context.Context, accountID, userID, peerID stri return } - groupsMap := map[string]*types.Group{} - for _, group := range peerGroups { - groupsMap[group.ID] = group - } - - groupMinimumInfo := groups.ToGroupsInfo(groupsMap, peer.ID) + groupMinimumInfo := groups.ToGroupsInfo(peerGroups, peer.ID) validPeers, err := h.accountManager.GetValidatedPeers(ctx, accountID) if err != nil { @@ -204,11 +194,7 @@ func (h *Handler) GetAllPeers(w http.ResponseWriter, r *http.Request) { dnsDomain := h.accountManager.GetDNSDomain() - groupsMap := map[string]*types.Group{} grps, _ := h.accountManager.GetAllGroups(r.Context(), accountID, userID) - for _, group := range grps { - groupsMap[group.ID] = group - } respBody := make([]*api.PeerBatch, 0, len(peers)) for _, peer := range peers { @@ -217,7 +203,7 @@ func (h *Handler) GetAllPeers(w http.ResponseWriter, r *http.Request) { util.WriteError(r.Context(), err, w) return } - groupMinimumInfo := groups.ToGroupsInfo(groupsMap, peer.ID) + groupMinimumInfo := groups.ToGroupsInfo(grps, peer.ID) respBody = append(respBody, toPeerListItemResponse(peerToReturn, groupMinimumInfo, dnsDomain, 0)) } From 30b023d1262a311e70ed10c4422b81157357b862 Mon Sep 17 00:00:00 2001 From: Pascal Fischer Date: Tue, 14 Jan 2025 23:28:06 +0100 Subject: [PATCH 152/159] update expectations --- .../http/testing/benchmarks/peers_handler_benchmark_test.go | 6 +++--- 1 file changed, 3 insertions(+), 3 deletions(-) diff --git a/management/server/http/testing/benchmarks/peers_handler_benchmark_test.go b/management/server/http/testing/benchmarks/peers_handler_benchmark_test.go index 5caa0d4f9..45a8966da 100644 --- a/management/server/http/testing/benchmarks/peers_handler_benchmark_test.go 
+++ b/management/server/http/testing/benchmarks/peers_handler_benchmark_test.go @@ -77,11 +77,11 @@ func BenchmarkUpdatePeer(b *testing.B) { func BenchmarkGetOnePeer(b *testing.B) { var expectedMetrics = map[string]testing_tools.PerformanceMetrics{ - "Peers - XS": {MinMsPerOpLocal: 40, MaxMsPerOpLocal: 60, MinMsPerOpCICD: 50, MaxMsPerOpCICD: 200}, + "Peers - XS": {MinMsPerOpLocal: 40, MaxMsPerOpLocal: 60, MinMsPerOpCICD: 30, MaxMsPerOpCICD: 70}, "Peers - S": {MinMsPerOpLocal: 1, MaxMsPerOpLocal: 5, MinMsPerOpCICD: 2, MaxMsPerOpCICD: 30}, "Peers - M": {MinMsPerOpLocal: 9, MaxMsPerOpLocal: 18, MinMsPerOpCICD: 15, MaxMsPerOpCICD: 50}, "Peers - L": {MinMsPerOpLocal: 40, MaxMsPerOpLocal: 90, MinMsPerOpCICD: 50, MaxMsPerOpCICD: 130}, - "Groups - L": {MinMsPerOpLocal: 80, MaxMsPerOpLocal: 130, MinMsPerOpCICD: 100, MaxMsPerOpCICD: 300}, + "Groups - L": {MinMsPerOpLocal: 80, MaxMsPerOpLocal: 130, MinMsPerOpCICD: 30, MaxMsPerOpCICD: 100}, "Users - L": {MinMsPerOpLocal: 40, MaxMsPerOpLocal: 90, MinMsPerOpCICD: 50, MaxMsPerOpCICD: 130}, "Setup Keys - L": {MinMsPerOpLocal: 40, MaxMsPerOpLocal: 90, MinMsPerOpCICD: 50, MaxMsPerOpCICD: 130}, "Peers - XL": {MinMsPerOpLocal: 200, MaxMsPerOpLocal: 400, MinMsPerOpCICD: 200, MaxMsPerOpCICD: 750}, @@ -115,7 +115,7 @@ func BenchmarkGetAllPeers(b *testing.B) { "Peers - S": {MinMsPerOpLocal: 2, MaxMsPerOpLocal: 10, MinMsPerOpCICD: 5, MaxMsPerOpCICD: 30}, "Peers - M": {MinMsPerOpLocal: 20, MaxMsPerOpLocal: 50, MinMsPerOpCICD: 20, MaxMsPerOpCICD: 70}, "Peers - L": {MinMsPerOpLocal: 130, MaxMsPerOpLocal: 170, MinMsPerOpCICD: 100, MaxMsPerOpCICD: 300}, - "Groups - L": {MinMsPerOpLocal: 4800, MaxMsPerOpLocal: 5300, MinMsPerOpCICD: 7000, MaxMsPerOpCICD: 15000}, + "Groups - L": {MinMsPerOpLocal: 4800, MaxMsPerOpLocal: 5300, MinMsPerOpCICD: 5000, MaxMsPerOpCICD: 8000}, "Users - L": {MinMsPerOpLocal: 100, MaxMsPerOpLocal: 170, MinMsPerOpCICD: 100, MaxMsPerOpCICD: 400}, "Setup Keys - L": {MinMsPerOpLocal: 100, MaxMsPerOpLocal: 170, MinMsPerOpCICD: 100, MaxMsPerOpCICD: 400}, "Peers - XL": {MinMsPerOpLocal: 900, MaxMsPerOpLocal: 1300, MinMsPerOpCICD: 800, MaxMsPerOpCICD: 2300}, From 167c80da4035d60340a0eebe431e50d2d73bf99e Mon Sep 17 00:00:00 2001 From: Pascal Fischer Date: Tue, 14 Jan 2025 23:38:23 +0100 Subject: [PATCH 153/159] update expectations --- .../http/testing/benchmarks/peers_handler_benchmark_test.go | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/management/server/http/testing/benchmarks/peers_handler_benchmark_test.go b/management/server/http/testing/benchmarks/peers_handler_benchmark_test.go index 45a8966da..901fe43ab 100644 --- a/management/server/http/testing/benchmarks/peers_handler_benchmark_test.go +++ b/management/server/http/testing/benchmarks/peers_handler_benchmark_test.go @@ -81,7 +81,7 @@ func BenchmarkGetOnePeer(b *testing.B) { "Peers - S": {MinMsPerOpLocal: 1, MaxMsPerOpLocal: 5, MinMsPerOpCICD: 2, MaxMsPerOpCICD: 30}, "Peers - M": {MinMsPerOpLocal: 9, MaxMsPerOpLocal: 18, MinMsPerOpCICD: 15, MaxMsPerOpCICD: 50}, "Peers - L": {MinMsPerOpLocal: 40, MaxMsPerOpLocal: 90, MinMsPerOpCICD: 50, MaxMsPerOpCICD: 130}, - "Groups - L": {MinMsPerOpLocal: 80, MaxMsPerOpLocal: 130, MinMsPerOpCICD: 30, MaxMsPerOpCICD: 100}, + "Groups - L": {MinMsPerOpLocal: 80, MaxMsPerOpLocal: 130, MinMsPerOpCICD: 30, MaxMsPerOpCICD: 150}, "Users - L": {MinMsPerOpLocal: 40, MaxMsPerOpLocal: 90, MinMsPerOpCICD: 50, MaxMsPerOpCICD: 130}, "Setup Keys - L": {MinMsPerOpLocal: 40, MaxMsPerOpLocal: 90, MinMsPerOpCICD: 50, MaxMsPerOpCICD: 130}, "Peers - XL": 
{MinMsPerOpLocal: 200, MaxMsPerOpLocal: 400, MinMsPerOpCICD: 200, MaxMsPerOpCICD: 750}, From e27db948aee482d64ddceef78e861f4e0890cfbf Mon Sep 17 00:00:00 2001 From: Pascal Fischer Date: Wed, 15 Jan 2025 00:19:42 +0100 Subject: [PATCH 154/159] update expectations --- .../http/testing/benchmarks/peers_handler_benchmark_test.go | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/management/server/http/testing/benchmarks/peers_handler_benchmark_test.go b/management/server/http/testing/benchmarks/peers_handler_benchmark_test.go index 901fe43ab..a4098f5d4 100644 --- a/management/server/http/testing/benchmarks/peers_handler_benchmark_test.go +++ b/management/server/http/testing/benchmarks/peers_handler_benchmark_test.go @@ -81,7 +81,7 @@ func BenchmarkGetOnePeer(b *testing.B) { "Peers - S": {MinMsPerOpLocal: 1, MaxMsPerOpLocal: 5, MinMsPerOpCICD: 2, MaxMsPerOpCICD: 30}, "Peers - M": {MinMsPerOpLocal: 9, MaxMsPerOpLocal: 18, MinMsPerOpCICD: 15, MaxMsPerOpCICD: 50}, "Peers - L": {MinMsPerOpLocal: 40, MaxMsPerOpLocal: 90, MinMsPerOpCICD: 50, MaxMsPerOpCICD: 130}, - "Groups - L": {MinMsPerOpLocal: 80, MaxMsPerOpLocal: 130, MinMsPerOpCICD: 30, MaxMsPerOpCICD: 150}, + "Groups - L": {MinMsPerOpLocal: 80, MaxMsPerOpLocal: 130, MinMsPerOpCICD: 30, MaxMsPerOpCICD: 200}, "Users - L": {MinMsPerOpLocal: 40, MaxMsPerOpLocal: 90, MinMsPerOpCICD: 50, MaxMsPerOpCICD: 130}, "Setup Keys - L": {MinMsPerOpLocal: 40, MaxMsPerOpLocal: 90, MinMsPerOpCICD: 50, MaxMsPerOpCICD: 130}, "Peers - XL": {MinMsPerOpLocal: 200, MaxMsPerOpLocal: 400, MinMsPerOpCICD: 200, MaxMsPerOpCICD: 750}, From 6a81ca2e52627ec8859c46731372a197ce28918d Mon Sep 17 00:00:00 2001 From: bcmmbaga Date: Wed, 15 Jan 2025 11:44:03 +0300 Subject: [PATCH 155/159] Bump integrations version Signed-off-by: bcmmbaga --- go.mod | 2 +- go.sum | 2 ++ 2 files changed, 3 insertions(+), 1 deletion(-) diff --git a/go.mod b/go.mod index 0c6d6be99..74b160b50 100644 --- a/go.mod +++ b/go.mod @@ -60,7 +60,7 @@ require ( github.com/miekg/dns v1.1.59 github.com/mitchellh/hashstructure/v2 v2.0.2 github.com/nadoo/ipset v0.5.0 - github.com/netbirdio/management-integrations/integrations v0.0.0-20241211172827-ba0a446be480 + github.com/netbirdio/management-integrations/integrations v0.0.0-20250115083837-a09722b8d2a6 github.com/netbirdio/signal-dispatcher/dispatcher v0.0.0-20241010133937-e0df50df217d github.com/okta/okta-sdk-golang/v2 v2.18.0 github.com/oschwald/maxminddb-golang v1.12.0 diff --git a/go.sum b/go.sum index f8b6c208b..77c6b4895 100644 --- a/go.sum +++ b/go.sum @@ -528,6 +528,8 @@ github.com/netbirdio/ice/v3 v3.0.0-20240315174635-e72a50fcb64e h1:PURA50S8u4mF6R github.com/netbirdio/ice/v3 v3.0.0-20240315174635-e72a50fcb64e/go.mod h1:YMLU7qbKfVjmEv7EoZPIVEI+kNYxWCdPK3VS0BU+U4Q= github.com/netbirdio/management-integrations/integrations v0.0.0-20241211172827-ba0a446be480 h1:M+UPn/o+plVE7ZehgL6/1dftptsO1tyTPssgImgi+28= github.com/netbirdio/management-integrations/integrations v0.0.0-20241211172827-ba0a446be480/go.mod h1:RC0PnyATSBPrRWKQgb+7KcC1tMta9eYyzuA414RG9wQ= +github.com/netbirdio/management-integrations/integrations v0.0.0-20250115083837-a09722b8d2a6 h1:I/ODkZ8rSDOzlJbhEjD2luSI71zl+s5JgNvFHY0+mBU= +github.com/netbirdio/management-integrations/integrations v0.0.0-20250115083837-a09722b8d2a6/go.mod h1:izUUs1NT7ja+PwSX3kJ7ox8Kkn478tboBJSjL4kU6J0= github.com/netbirdio/service v0.0.0-20240911161631-f62744f42502 h1:3tHlFmhTdX9axERMVN63dqyFqnvuD+EMJHzM7mNGON8= github.com/netbirdio/service v0.0.0-20240911161631-f62744f42502/go.mod 
h1:CIMRFEJVL+0DS1a3Nx06NaMn4Dz63Ng6O7dl0qH0zVM= github.com/netbirdio/signal-dispatcher/dispatcher v0.0.0-20241010133937-e0df50df217d h1:bRq5TKgC7Iq20pDiuC54yXaWnAVeS5PdGpSokFTlR28= From 2fb399aa7da0a29ffcaf826b4eece07905491a13 Mon Sep 17 00:00:00 2001 From: bcmmbaga Date: Wed, 15 Jan 2025 11:49:40 +0300 Subject: [PATCH 156/159] Refactor GetValidatedPeers Signed-off-by: bcmmbaga --- management/server/integrated_validator.go | 15 ++------------- .../server/integrated_validator/interface.go | 2 +- management/server/peer.go | 12 ++++++------ management/server/store/sql_store.go | 4 ++-- management/server/types/account.go | 4 ++-- 5 files changed, 13 insertions(+), 24 deletions(-) diff --git a/management/server/integrated_validator.go b/management/server/integrated_validator.go index a7bc8523b..dcb9400f6 100644 --- a/management/server/integrated_validator.go +++ b/management/server/integrated_validator.go @@ -100,18 +100,7 @@ func (am *DefaultAccountManager) GetValidatedPeers(ctx context.Context, accountI return nil, err } - groupsMap := make(map[string]*types.Group, len(groups)) - - for _, group := range groups { - groupsMap[group.ID] = group - } - - peersMap := make(map[string]*nbpeer.Peer, len(peers)) - for _, peer := range peers { - peersMap[peer.ID] = peer - } - - return am.integratedPeerValidator.GetValidatedPeers(accountID, groupsMap, peersMap, settings.Extra) + return am.integratedPeerValidator.GetValidatedPeers(accountID, groups, peers, settings.Extra) } type MocIntegratedValidator struct { @@ -129,7 +118,7 @@ func (a MocIntegratedValidator) ValidatePeer(_ context.Context, update *nbpeer.P return update, false, nil } -func (a MocIntegratedValidator) GetValidatedPeers(accountID string, groups map[string]*types.Group, peers map[string]*nbpeer.Peer, extraSettings *account.ExtraSettings) (map[string]struct{}, error) { +func (a MocIntegratedValidator) GetValidatedPeers(accountID string, groups []*types.Group, peers []*nbpeer.Peer, extraSettings *account.ExtraSettings) (map[string]struct{}, error) { validatedPeers := make(map[string]struct{}) for _, peer := range peers { validatedPeers[peer.ID] = struct{}{} diff --git a/management/server/integrated_validator/interface.go b/management/server/integrated_validator/interface.go index 22b8026aa..ff179e3c0 100644 --- a/management/server/integrated_validator/interface.go +++ b/management/server/integrated_validator/interface.go @@ -14,7 +14,7 @@ type IntegratedValidator interface { ValidatePeer(ctx context.Context, update *nbpeer.Peer, peer *nbpeer.Peer, userID string, accountID string, dnsDomain string, peersGroup []string, extraSettings *account.ExtraSettings) (*nbpeer.Peer, bool, error) PreparePeer(ctx context.Context, accountID string, peer *nbpeer.Peer, peersGroup []string, extraSettings *account.ExtraSettings) *nbpeer.Peer IsNotValidPeer(ctx context.Context, accountID string, peer *nbpeer.Peer, peersGroup []string, extraSettings *account.ExtraSettings) (bool, bool, error) - GetValidatedPeers(accountID string, groups map[string]*types.Group, peers map[string]*nbpeer.Peer, extraSettings *account.ExtraSettings) (map[string]struct{}, error) + GetValidatedPeers(accountID string, groups []*types.Group, peers []*nbpeer.Peer, extraSettings *account.ExtraSettings) (map[string]struct{}, error) PeerDeleted(ctx context.Context, accountID, peerID string) error SetPeerInvalidationListener(fn func(accountID string)) Stop(ctx context.Context) diff --git a/management/server/peer.go b/management/server/peer.go index 9921c0c9b..c7f1830de 100644 --- 
a/management/server/peer.go +++ b/management/server/peer.go @@ -101,7 +101,7 @@ func (am *DefaultAccountManager) GetPeers(ctx context.Context, accountID, userID return nil, err } - approvedPeersMap, err := am.integratedPeerValidator.GetValidatedPeers(accountID, account.Groups, account.Peers, account.Settings.Extra) + approvedPeersMap, err := am.integratedPeerValidator.GetValidatedPeers(accountID, account.GroupsG, account.PeersG, account.Settings.Extra) if err != nil { return nil, err } @@ -404,7 +404,7 @@ func (am *DefaultAccountManager) GetNetworkMap(ctx context.Context, peerID strin groups[groupID] = group.Peers } - validatedPeers, err := am.integratedPeerValidator.GetValidatedPeers(account.Id, account.Groups, account.Peers, account.Settings.Extra) + validatedPeers, err := am.integratedPeerValidator.GetValidatedPeers(account.Id, account.GroupsG, account.PeersG, account.Settings.Extra) if err != nil { return nil, err } @@ -962,7 +962,7 @@ func (am *DefaultAccountManager) getValidatedPeerWithMap(ctx context.Context, is return nil, nil, nil, err } - approvedPeersMap, err := am.integratedPeerValidator.GetValidatedPeers(account.Id, account.Groups, account.Peers, account.Settings.Extra) + approvedPeersMap, err := am.integratedPeerValidator.GetValidatedPeers(account.Id, account.GroupsG, account.PeersG, account.Settings.Extra) if err != nil { return nil, nil, nil, err } @@ -1071,7 +1071,7 @@ func (am *DefaultAccountManager) GetPeer(ctx context.Context, accountID, peerID, return nil, err } - approvedPeersMap, err := am.integratedPeerValidator.GetValidatedPeers(accountID, account.Groups, account.Peers, account.Settings.Extra) + approvedPeersMap, err := am.integratedPeerValidator.GetValidatedPeers(accountID, account.GroupsG, account.PeersG, account.Settings.Extra) if err != nil { return nil, err } @@ -1104,7 +1104,7 @@ func (am *DefaultAccountManager) UpdateAccountPeers(ctx context.Context, account } }() - approvedPeersMap, err := am.integratedPeerValidator.GetValidatedPeers(account.Id, account.Groups, account.Peers, account.Settings.Extra) + approvedPeersMap, err := am.integratedPeerValidator.GetValidatedPeers(account.Id, account.GroupsG, account.PeersG, account.Settings.Extra) if err != nil { log.WithContext(ctx).Errorf("failed to send out updates to peers, failed to get validate peers: %v", err) return @@ -1165,7 +1165,7 @@ func (am *DefaultAccountManager) UpdateAccountPeer(ctx context.Context, accountI return } - approvedPeersMap, err := am.integratedPeerValidator.GetValidatedPeers(account.Id, account.Groups, account.Peers, account.Settings.Extra) + approvedPeersMap, err := am.integratedPeerValidator.GetValidatedPeers(account.Id, account.GroupsG, account.PeersG, account.Settings.Extra) if err != nil { log.WithContext(ctx).Errorf("failed to send update to peer %s, failed to validate peers: %v", peerId, err) return diff --git a/management/server/store/sql_store.go b/management/server/store/sql_store.go index 46e896b55..29f1cde66 100644 --- a/management/server/store/sql_store.go +++ b/management/server/store/sql_store.go @@ -221,7 +221,7 @@ func generateAccountSQLTypes(account *types.Account) { for id, peer := range account.Peers { peer.ID = id - account.PeersG = append(account.PeersG, *peer) + account.PeersG = append(account.PeersG, peer) } for id, user := range account.Users { @@ -235,7 +235,7 @@ func generateAccountSQLTypes(account *types.Account) { for id, group := range account.Groups { group.ID = id - account.GroupsG = append(account.GroupsG, *group) + account.GroupsG = 
append(account.GroupsG, group) } for id, route := range account.Routes { diff --git a/management/server/types/account.go b/management/server/types/account.go index f74d38cb6..8e5d303ce 100644 --- a/management/server/types/account.go +++ b/management/server/types/account.go @@ -54,11 +54,11 @@ type Account struct { SetupKeysG []SetupKey `json:"-" gorm:"foreignKey:AccountID;references:id"` Network *Network `gorm:"embedded;embeddedPrefix:network_"` Peers map[string]*nbpeer.Peer `gorm:"-"` - PeersG []nbpeer.Peer `json:"-" gorm:"foreignKey:AccountID;references:id"` + PeersG []*nbpeer.Peer `json:"-" gorm:"foreignKey:AccountID;references:id"` Users map[string]*User `gorm:"-"` UsersG []User `json:"-" gorm:"foreignKey:AccountID;references:id"` Groups map[string]*Group `gorm:"-"` - GroupsG []Group `json:"-" gorm:"foreignKey:AccountID;references:id"` + GroupsG []*Group `json:"-" gorm:"foreignKey:AccountID;references:id"` Policies []*Policy `gorm:"foreignKey:AccountID;references:id"` Routes map[route.ID]*route.Route `gorm:"-"` RoutesG []route.Route `json:"-" gorm:"foreignKey:AccountID;references:id"` From 22798ada2533eca6c2afc97604721bc5f7f095e0 Mon Sep 17 00:00:00 2001 From: bcmmbaga Date: Wed, 15 Jan 2025 11:51:08 +0300 Subject: [PATCH 157/159] Fix tests Signed-off-by: bcmmbaga --- management/server/migration/migration_test.go | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/management/server/migration/migration_test.go b/management/server/migration/migration_test.go index a645ae325..4b4aa1184 100644 --- a/management/server/migration/migration_test.go +++ b/management/server/migration/migration_test.go @@ -147,7 +147,7 @@ func TestMigrateNetIPFieldFromBlobToJSON_WithJSONData(t *testing.T) { err = db.Save(&types.Account{ Id: "1234", - PeersG: []nbpeer.Peer{ + PeersG: []*nbpeer.Peer{ {Location: nbpeer.Location{ConnectionIP: net.IP{10, 0, 0, 1}}}, }}, ).Error From 6367fe4f7402168f349ca9bc2f36dfd13d552e88 Mon Sep 17 00:00:00 2001 From: bcmmbaga Date: Wed, 15 Jan 2025 11:52:50 +0300 Subject: [PATCH 158/159] go mod tidy Signed-off-by: bcmmbaga --- go.sum | 2 -- 1 file changed, 2 deletions(-) diff --git a/go.sum b/go.sum index 77c6b4895..f6d4590ee 100644 --- a/go.sum +++ b/go.sum @@ -526,8 +526,6 @@ github.com/netbirdio/go-netroute v0.0.0-20240611143515-f59b0e1d3944 h1:TDtJKmM6S github.com/netbirdio/go-netroute v0.0.0-20240611143515-f59b0e1d3944/go.mod h1:sHA6TRxjQ6RLbnI+3R4DZo2Eseg/iKiPRfNmcuNySVQ= github.com/netbirdio/ice/v3 v3.0.0-20240315174635-e72a50fcb64e h1:PURA50S8u4mF6RrkYYCAvvPCixhqqEiEy3Ej6avh04c= github.com/netbirdio/ice/v3 v3.0.0-20240315174635-e72a50fcb64e/go.mod h1:YMLU7qbKfVjmEv7EoZPIVEI+kNYxWCdPK3VS0BU+U4Q= -github.com/netbirdio/management-integrations/integrations v0.0.0-20241211172827-ba0a446be480 h1:M+UPn/o+plVE7ZehgL6/1dftptsO1tyTPssgImgi+28= -github.com/netbirdio/management-integrations/integrations v0.0.0-20241211172827-ba0a446be480/go.mod h1:RC0PnyATSBPrRWKQgb+7KcC1tMta9eYyzuA414RG9wQ= github.com/netbirdio/management-integrations/integrations v0.0.0-20250115083837-a09722b8d2a6 h1:I/ODkZ8rSDOzlJbhEjD2luSI71zl+s5JgNvFHY0+mBU= github.com/netbirdio/management-integrations/integrations v0.0.0-20250115083837-a09722b8d2a6/go.mod h1:izUUs1NT7ja+PwSX3kJ7ox8Kkn478tboBJSjL4kU6J0= github.com/netbirdio/service v0.0.0-20240911161631-f62744f42502 h1:3tHlFmhTdX9axERMVN63dqyFqnvuD+EMJHzM7mNGON8= From 4afaabb33c60ce9c1442cb2aacaf3a955af9cf67 Mon Sep 17 00:00:00 2001 From: bcmmbaga Date: Wed, 15 Jan 2025 12:50:25 +0300 Subject: [PATCH 159/159] Use peers and groups map for peers 
validation Signed-off-by: bcmmbaga --- management/server/migration/migration_test.go | 2 +- management/server/peer.go | 12 ++++++------ management/server/store/sql_store.go | 4 ++-- management/server/types/account.go | 4 ++-- 4 files changed, 11 insertions(+), 11 deletions(-) diff --git a/management/server/migration/migration_test.go b/management/server/migration/migration_test.go index 4b4aa1184..a645ae325 100644 --- a/management/server/migration/migration_test.go +++ b/management/server/migration/migration_test.go @@ -147,7 +147,7 @@ func TestMigrateNetIPFieldFromBlobToJSON_WithJSONData(t *testing.T) { err = db.Save(&types.Account{ Id: "1234", - PeersG: []*nbpeer.Peer{ + PeersG: []nbpeer.Peer{ {Location: nbpeer.Location{ConnectionIP: net.IP{10, 0, 0, 1}}}, }}, ).Error diff --git a/management/server/peer.go b/management/server/peer.go index c7f1830de..3ddb0a22d 100644 --- a/management/server/peer.go +++ b/management/server/peer.go @@ -101,7 +101,7 @@ func (am *DefaultAccountManager) GetPeers(ctx context.Context, accountID, userID return nil, err } - approvedPeersMap, err := am.integratedPeerValidator.GetValidatedPeers(accountID, account.GroupsG, account.PeersG, account.Settings.Extra) + approvedPeersMap, err := am.integratedPeerValidator.GetValidatedPeers(accountID, maps.Values(account.Groups), maps.Values(account.Peers), account.Settings.Extra) if err != nil { return nil, err } @@ -404,7 +404,7 @@ func (am *DefaultAccountManager) GetNetworkMap(ctx context.Context, peerID strin groups[groupID] = group.Peers } - validatedPeers, err := am.integratedPeerValidator.GetValidatedPeers(account.Id, account.GroupsG, account.PeersG, account.Settings.Extra) + validatedPeers, err := am.integratedPeerValidator.GetValidatedPeers(account.Id, maps.Values(account.Groups), maps.Values(account.Peers), account.Settings.Extra) if err != nil { return nil, err } @@ -962,7 +962,7 @@ func (am *DefaultAccountManager) getValidatedPeerWithMap(ctx context.Context, is return nil, nil, nil, err } - approvedPeersMap, err := am.integratedPeerValidator.GetValidatedPeers(account.Id, account.GroupsG, account.PeersG, account.Settings.Extra) + approvedPeersMap, err := am.integratedPeerValidator.GetValidatedPeers(account.Id, maps.Values(account.Groups), maps.Values(account.Peers), account.Settings.Extra) if err != nil { return nil, nil, nil, err } @@ -1071,7 +1071,7 @@ func (am *DefaultAccountManager) GetPeer(ctx context.Context, accountID, peerID, return nil, err } - approvedPeersMap, err := am.integratedPeerValidator.GetValidatedPeers(accountID, account.GroupsG, account.PeersG, account.Settings.Extra) + approvedPeersMap, err := am.integratedPeerValidator.GetValidatedPeers(accountID, maps.Values(account.Groups), maps.Values(account.Peers), account.Settings.Extra) if err != nil { return nil, err } @@ -1104,7 +1104,7 @@ func (am *DefaultAccountManager) UpdateAccountPeers(ctx context.Context, account } }() - approvedPeersMap, err := am.integratedPeerValidator.GetValidatedPeers(account.Id, account.GroupsG, account.PeersG, account.Settings.Extra) + approvedPeersMap, err := am.integratedPeerValidator.GetValidatedPeers(account.Id, maps.Values(account.Groups), maps.Values(account.Peers), account.Settings.Extra) if err != nil { log.WithContext(ctx).Errorf("failed to send out updates to peers, failed to get validate peers: %v", err) return @@ -1165,7 +1165,7 @@ func (am *DefaultAccountManager) UpdateAccountPeer(ctx context.Context, accountI return } - approvedPeersMap, err := am.integratedPeerValidator.GetValidatedPeers(account.Id, 
account.GroupsG, account.PeersG, account.Settings.Extra) + approvedPeersMap, err := am.integratedPeerValidator.GetValidatedPeers(account.Id, maps.Values(account.Groups), maps.Values(account.Peers), account.Settings.Extra) if err != nil { log.WithContext(ctx).Errorf("failed to send update to peer %s, failed to validate peers: %v", peerId, err) return diff --git a/management/server/store/sql_store.go b/management/server/store/sql_store.go index 29f1cde66..46e896b55 100644 --- a/management/server/store/sql_store.go +++ b/management/server/store/sql_store.go @@ -221,7 +221,7 @@ func generateAccountSQLTypes(account *types.Account) { for id, peer := range account.Peers { peer.ID = id - account.PeersG = append(account.PeersG, peer) + account.PeersG = append(account.PeersG, *peer) } for id, user := range account.Users { @@ -235,7 +235,7 @@ func generateAccountSQLTypes(account *types.Account) { for id, group := range account.Groups { group.ID = id - account.GroupsG = append(account.GroupsG, group) + account.GroupsG = append(account.GroupsG, *group) } for id, route := range account.Routes { diff --git a/management/server/types/account.go b/management/server/types/account.go index 8e5d303ce..f74d38cb6 100644 --- a/management/server/types/account.go +++ b/management/server/types/account.go @@ -54,11 +54,11 @@ type Account struct { SetupKeysG []SetupKey `json:"-" gorm:"foreignKey:AccountID;references:id"` Network *Network `gorm:"embedded;embeddedPrefix:network_"` Peers map[string]*nbpeer.Peer `gorm:"-"` - PeersG []*nbpeer.Peer `json:"-" gorm:"foreignKey:AccountID;references:id"` + PeersG []nbpeer.Peer `json:"-" gorm:"foreignKey:AccountID;references:id"` Users map[string]*User `gorm:"-"` UsersG []User `json:"-" gorm:"foreignKey:AccountID;references:id"` Groups map[string]*Group `gorm:"-"` - GroupsG []*Group `json:"-" gorm:"foreignKey:AccountID;references:id"` + GroupsG []Group `json:"-" gorm:"foreignKey:AccountID;references:id"` Policies []*Policy `gorm:"foreignKey:AccountID;references:id"` Routes map[route.ID]*route.Route `gorm:"-"` RoutesG []route.Route `json:"-" gorm:"foreignKey:AccountID;references:id"`