From 1123729c1c6bf78e79a29bade1962c93766281e1 Mon Sep 17 00:00:00 2001 From: bcmmbaga Date: Tue, 15 Oct 2024 18:17:47 +0300 Subject: [PATCH] fix merge Signed-off-by: bcmmbaga --- management/server/account.go | 132 +++--- management/server/account_test.go | 4 +- management/server/group.go | 20 +- management/server/nameserver.go | 8 +- management/server/peer.go | 143 ++++++- management/server/personal_access_token.go | 3 +- management/server/policy.go | 2 +- management/server/posture_checks.go | 8 +- management/server/route.go | 103 +++-- management/server/setupkey.go | 4 +- management/server/sql_store.go | 442 +++++++++++++++++++-- management/server/sql_store_test.go | 2 +- management/server/store.go | 39 +- management/server/user.go | 83 +--- 14 files changed, 766 insertions(+), 227 deletions(-) diff --git a/management/server/account.go b/management/server/account.go index 7c84ad1ca..33e7fc11c 100644 --- a/management/server/account.go +++ b/management/server/account.go @@ -45,15 +45,15 @@ import ( ) const ( - PublicCategory = "public" - PrivateCategory = "private" - UnknownCategory = "unknown" - CacheExpirationMax = 7 * 24 * 3600 * time.Second // 7 days - CacheExpirationMin = 3 * 24 * 3600 * time.Second // 3 days - DefaultPeerLoginExpiration = 24 * time.Hour + PublicCategory = "public" + PrivateCategory = "private" + UnknownCategory = "unknown" + CacheExpirationMax = 7 * 24 * 3600 * time.Second // 7 days + CacheExpirationMin = 3 * 24 * 3600 * time.Second // 3 days + DefaultPeerLoginExpiration = 24 * time.Hour DefaultPeerInactivityExpiration = 10 * time.Minute - emptyUserID = "empty user ID in claims" - errorGettingDomainAccIDFmt = "error getting account ID by private domain: %v" + emptyUserID = "empty user ID in claims" + errorGettingDomainAccIDFmt = "error getting account ID by private domain: %v" ) type userLoggedInOnce bool @@ -134,14 +134,14 @@ type AccountManager interface { GetDNSSettings(ctx context.Context, accountID string, userID string) (*DNSSettings, error) SaveDNSSettings(ctx context.Context, accountID string, userID string, dnsSettingsToSave *DNSSettings) error GetPeer(ctx context.Context, accountID, peerID, userID string) (*nbpeer.Peer, error) - UpdateAccountSettings(ctx context.Context, accountID, userID string, newSettings *Settings) (*Account, error) + UpdateAccountSettings(ctx context.Context, accountID, userID string, newSettings *Settings) (*Settings, error) LoginPeer(ctx context.Context, login PeerLogin) (*nbpeer.Peer, *NetworkMap, []*posture.Checks, error) // used by peer gRPC API SyncPeer(ctx context.Context, sync PeerSync, account *Account) (*nbpeer.Peer, *NetworkMap, []*posture.Checks, error) // used by peer gRPC API GetAllConnectedPeers() (map[string]struct{}, error) HasConnectedChannel(peerID string) bool GetExternalCacheManager() ExternalCacheManager GetPostureChecks(ctx context.Context, accountID, postureChecksID, userID string) (*posture.Checks, error) - SavePostureChecks(ctx context.Context, accountID, userID string, postureChecks *posture.Checks) error + SavePostureChecks(ctx context.Context, accountID, userID string, postureChecks *posture.Checks, isUpdate bool) error DeletePostureChecks(ctx context.Context, accountID, postureChecksID, userID string) error ListPostureChecks(ctx context.Context, accountID, userID string) ([]*posture.Checks, error) GetIdpManager() idp.Manager @@ -1122,7 +1122,16 @@ func (am *DefaultAccountManager) GetIdpManager() idp.Manager { // Only users with role UserRoleAdmin can update the account. // User that performs the update has to belong to the account. // Returns an updated Account -func (am *DefaultAccountManager) UpdateAccountSettings(ctx context.Context, accountID, userID string, newSettings *Settings) (*Account, error) { +func (am *DefaultAccountManager) UpdateAccountSettings(ctx context.Context, accountID, userID string, newSettings *Settings) (*Settings, error) { + user, err := am.Store.GetUserByUserID(ctx, LockingStrengthShare, userID) + if err != nil { + return nil, err + } + + if !user.HasAdminPower() || user.AccountID != accountID { + return nil, status.Errorf(status.PermissionDenied, "user is not allowed to update account") + } + halfYearLimit := 180 * 24 * time.Hour if newSettings.PeerLoginExpiration > halfYearLimit { return nil, status.Errorf(status.InvalidArgument, "peer login expiration can't be larger than 180 days") @@ -1132,78 +1141,89 @@ func (am *DefaultAccountManager) UpdateAccountSettings(ctx context.Context, acco return nil, status.Errorf(status.InvalidArgument, "peer login expiration can't be smaller than one hour") } - unlock := am.Store.AcquireWriteLockByUID(ctx, accountID) - defer unlock() + var oldSettings *Settings - account, err := am.Store.GetAccount(ctx, accountID) + err = am.Store.ExecuteInTransaction(ctx, func(transaction Store) error { + oldSettings, err = transaction.GetAccountSettings(ctx, LockingStrengthShare, accountID) + if err != nil { + return fmt.Errorf("failed to get account settings: %w", err) + } + + if err = am.validateExtraSettings(ctx, newSettings, oldSettings, userID, accountID); err != nil { + return fmt.Errorf("failed to validate extra settings: %w", err) + } + + if err = am.Store.SaveAccountSettings(ctx, LockingStrengthUpdate, accountID, newSettings); err != nil { + return fmt.Errorf("failed to update account settings: %w", err) + } + + return nil + }) if err != nil { return nil, err } - user, err := account.FindUser(userID) + am.handlePeerLoginExpirationSettings(ctx, oldSettings, newSettings, userID, accountID) + am.handleInactivityExpirationSettings(ctx, oldSettings, newSettings, userID, accountID) + + account, err := am.requestBuffer.GetAccountWithBackpressure(ctx, accountID) if err != nil { - return nil, err + return nil, fmt.Errorf("error getting account: %w", err) } + am.updateAccountPeers(ctx, account) - if !user.HasAdminPower() { - return nil, status.Errorf(status.PermissionDenied, "user is not allowed to update account") - } + return newSettings, nil +} - err = am.integratedPeerValidator.ValidateExtraSettings(ctx, newSettings.Extra, account.Settings.Extra, account.Peers, userID, accountID) +// validateExtraSettings validates the extra settings of the account. +func (am *DefaultAccountManager) validateExtraSettings(ctx context.Context, newSettings, oldSettings *Settings, userID, accountID string) error { + peers, err := am.Store.GetAccountPeers(ctx, LockingStrengthShare, accountID) if err != nil { - return nil, err + return err } - oldSettings := account.Settings + peerMap := make(map[string]*nbpeer.Peer, len(peers)) + for _, peer := range peers { + peerMap[peer.ID] = peer + } + + return am.integratedPeerValidator.ValidateExtraSettings(ctx, newSettings.Extra, oldSettings.Extra, peerMap, userID, accountID) +} + +func (am *DefaultAccountManager) handlePeerLoginExpirationSettings(ctx context.Context, oldSettings, newSettings *Settings, userID, accountID string) { if oldSettings.PeerLoginExpirationEnabled != newSettings.PeerLoginExpirationEnabled { event := activity.AccountPeerLoginExpirationEnabled if !newSettings.PeerLoginExpirationEnabled { event = activity.AccountPeerLoginExpirationDisabled am.peerLoginExpiry.Cancel(ctx, []string{accountID}) } else { - am.checkAndSchedulePeerLoginExpiration(ctx, account) + am.checkAndSchedulePeerLoginExpiration(ctx, accountID) } am.StoreEvent(ctx, userID, accountID, accountID, event, nil) } if oldSettings.PeerLoginExpiration != newSettings.PeerLoginExpiration { am.StoreEvent(ctx, userID, accountID, accountID, activity.AccountPeerLoginExpirationDurationUpdated, nil) - am.checkAndSchedulePeerLoginExpiration(ctx, account) + am.checkAndSchedulePeerLoginExpiration(ctx, accountID) } - - err = am.handleInactivityExpirationSettings(ctx, account, oldSettings, newSettings, userID, accountID) - if err != nil { - return nil, err - } - - updatedAccount := account.UpdateSettings(newSettings) - - err = am.Store.SaveAccount(ctx, account) - if err != nil { - return nil, err - } - - return updatedAccount, nil } -func (am *DefaultAccountManager) handleInactivityExpirationSettings(ctx context.Context, account *Account, oldSettings, newSettings *Settings, userID, accountID string) error { +func (am *DefaultAccountManager) handleInactivityExpirationSettings(ctx context.Context, oldSettings, newSettings *Settings, userID, accountID string) { if oldSettings.PeerInactivityExpirationEnabled != newSettings.PeerInactivityExpirationEnabled { event := activity.AccountPeerInactivityExpirationEnabled if !newSettings.PeerInactivityExpirationEnabled { event = activity.AccountPeerInactivityExpirationDisabled am.peerInactivityExpiry.Cancel(ctx, []string{accountID}) } else { - am.checkAndSchedulePeerInactivityExpiration(ctx, account) + am.checkAndSchedulePeerInactivityExpiration(ctx, accountID) } am.StoreEvent(ctx, userID, accountID, accountID, event, nil) } if oldSettings.PeerInactivityExpiration != newSettings.PeerInactivityExpiration { am.StoreEvent(ctx, userID, accountID, accountID, activity.AccountPeerInactivityExpirationDurationUpdated, nil) - am.checkAndSchedulePeerInactivityExpiration(ctx, account) + am.checkAndSchedulePeerInactivityExpiration(ctx, accountID) } - - return nil } func (am *DefaultAccountManager) peerLoginExpirationJob(ctx context.Context, accountID string) func() (time.Duration, bool) { @@ -1234,10 +1254,10 @@ func (am *DefaultAccountManager) peerLoginExpirationJob(ctx context.Context, acc } } -func (am *DefaultAccountManager) checkAndSchedulePeerLoginExpiration(ctx context.Context, account *Account) { - am.peerLoginExpiry.Cancel(ctx, []string{account.Id}) - if nextRun, ok := account.GetNextPeerExpiration(); ok { - go am.peerLoginExpiry.Schedule(ctx, nextRun, account.Id, am.peerLoginExpirationJob(ctx, account.Id)) +func (am *DefaultAccountManager) checkAndSchedulePeerLoginExpiration(ctx context.Context, accountID string) { + am.peerLoginExpiry.Cancel(ctx, []string{accountID}) + if nextRun, ok := am.getNextPeerExpiration(ctx, accountID); ok { + go am.peerLoginExpiry.Schedule(ctx, nextRun, accountID, am.peerLoginExpirationJob(ctx, accountID)) } } @@ -1271,10 +1291,10 @@ func (am *DefaultAccountManager) peerInactivityExpirationJob(ctx context.Context } // checkAndSchedulePeerInactivityExpiration periodically checks for inactive peers to end their sessions -func (am *DefaultAccountManager) checkAndSchedulePeerInactivityExpiration(ctx context.Context, account *Account) { - am.peerInactivityExpiry.Cancel(ctx, []string{account.Id}) - if nextRun, ok := account.GetNextInactivePeerExpiration(); ok { - go am.peerInactivityExpiry.Schedule(ctx, nextRun, account.Id, am.peerInactivityExpirationJob(ctx, account.Id)) +func (am *DefaultAccountManager) checkAndSchedulePeerInactivityExpiration(ctx context.Context, accountID string) { + am.peerInactivityExpiry.Cancel(ctx, []string{accountID}) + if nextRun, ok := am.getNextInactivePeerExpiration(ctx, accountID); ok { + go am.peerInactivityExpiry.Schedule(ctx, nextRun, accountID, am.peerInactivityExpirationJob(ctx, accountID)) } } @@ -1435,12 +1455,12 @@ func isNil(i idp.Manager) bool { // addAccountIDToIDPAppMeta update user's app metadata in idp manager func (am *DefaultAccountManager) addAccountIDToIDPAppMeta(ctx context.Context, userID string, accountID string) error { if !isNil(am.idpManager) { - accountUsers, err := am.Store.GetAccountUsers(ctx, accountID) + accountUsers, err := am.Store.GetAccountUsers(ctx, LockingStrengthShare, accountID) if err != nil { return err } cachedAccount := &Account{ - Id: accountID, + Id: accountID, Users: make(map[string]*User), } for _, user := range accountUsers { @@ -2083,7 +2103,7 @@ func (am *DefaultAccountManager) syncJWTGroups(ctx context.Context, accountID st return fmt.Errorf("error saving groups: %w", err) } - if err = transaction.IncrementNetworkSerial(ctx, accountID); err != nil { + if err = transaction.IncrementNetworkSerial(ctx, LockingStrengthUpdate, accountID); err != nil { return fmt.Errorf("error incrementing network serial: %w", err) } } @@ -2101,7 +2121,7 @@ func (am *DefaultAccountManager) syncJWTGroups(ctx context.Context, accountID st } for _, g := range addNewGroups { - group, err := am.Store.GetGroupByID(ctx, LockingStrengthShare, g, accountID) + group, err := am.Store.GetGroupByID(ctx, LockingStrengthShare, accountID, g) if err != nil { log.WithContext(ctx).Debugf("group %s not found while saving user activity event of account %s", g, accountID) } else { @@ -2114,7 +2134,7 @@ func (am *DefaultAccountManager) syncJWTGroups(ctx context.Context, accountID st } for _, g := range removeOldGroups { - group, err := am.Store.GetGroupByID(ctx, LockingStrengthShare, g, accountID) + group, err := am.Store.GetGroupByID(ctx, LockingStrengthShare, accountID, g) if err != nil { log.WithContext(ctx).Debugf("group %s not found while saving user activity event of account %s", g, accountID) } else { diff --git a/management/server/account_test.go b/management/server/account_test.go index f10cc0b29..aaef8fe8f 100644 --- a/management/server/account_test.go +++ b/management/server/account_test.go @@ -2553,7 +2553,7 @@ func TestAccount_SetJWTGroups(t *testing.T) { assert.NoError(t, err, "unable to get user") assert.Len(t, user.AutoGroups, 0) - group1, err := manager.Store.GetGroupByID(context.Background(), LockingStrengthShare, "group1", "accountID") + group1, err := manager.Store.GetGroupByID(context.Background(), LockingStrengthShare, "accountID", "group1") assert.NoError(t, err, "unable to get group") assert.Equal(t, group1.Issued, group.GroupIssuedAPI, "group should be api issued") }) @@ -2573,7 +2573,7 @@ func TestAccount_SetJWTGroups(t *testing.T) { assert.NoError(t, err, "unable to get user") assert.Len(t, user.AutoGroups, 1) - group1, err := manager.Store.GetGroupByID(context.Background(), LockingStrengthShare, "group1", "accountID") + group1, err := manager.Store.GetGroupByID(context.Background(), LockingStrengthShare, "accountID", "group1") assert.NoError(t, err, "unable to get group") assert.Equal(t, group1.Issued, group.GroupIssuedAPI, "group should be api issued") }) diff --git a/management/server/group.go b/management/server/group.go index 7c6ece6d7..74b7f977f 100644 --- a/management/server/group.go +++ b/management/server/group.go @@ -50,7 +50,7 @@ func (am *DefaultAccountManager) GetGroup(ctx context.Context, accountID, groupI return nil, err } - return am.Store.GetGroupByID(ctx, LockingStrengthShare, groupID, accountID) + return am.Store.GetGroupByID(ctx, LockingStrengthShare, accountID, groupID) } // GetAllGroups returns all groups in an account @@ -64,7 +64,7 @@ func (am *DefaultAccountManager) GetAllGroups(ctx context.Context, accountID, us // GetGroupByName filters all groups in an account by name and returns the one with the most peers func (am *DefaultAccountManager) GetGroupByName(ctx context.Context, groupName, accountID string) (*nbgroup.Group, error) { - return am.Store.GetGroupByName(ctx, LockingStrengthShare, groupName, accountID) + return am.Store.GetGroupByName(ctx, LockingStrengthShare, accountID, groupName) } // SaveGroup object of the peers @@ -94,7 +94,7 @@ func (am *DefaultAccountManager) SaveGroups(ctx context.Context, accountID, user } if newGroup.ID == "" && newGroup.Issued == nbgroup.GroupIssuedAPI { - existingGroup, err := am.Store.GetGroupByName(ctx, LockingStrengthShare, newGroup.Name, accountID) + existingGroup, err := am.Store.GetGroupByName(ctx, LockingStrengthShare, accountID, newGroup.Name) if err != nil { s, ok := status.FromError(err) if !ok || s.ErrorType != status.NotFound { @@ -112,7 +112,7 @@ func (am *DefaultAccountManager) SaveGroups(ctx context.Context, accountID, user } for _, peerID := range newGroup.Peers { - if _, err = am.Store.GetPeerByID(ctx, LockingStrengthShare, peerID, accountID); err != nil { + if _, err = am.Store.GetPeerByID(ctx, LockingStrengthShare, accountID, peerID); err != nil { return status.Errorf(status.InvalidArgument, "peer with ID \"%s\" not found", peerID) } } @@ -158,7 +158,7 @@ func (am *DefaultAccountManager) prepareGroupEvents(ctx context.Context, userID addedPeers := make([]string, 0) removedPeers := make([]string, 0) - oldGroup, err := am.Store.GetGroupByID(ctx, LockingStrengthShare, newGroup.ID, accountID) + oldGroup, err := am.Store.GetGroupByID(ctx, LockingStrengthShare, accountID, newGroup.ID) if err == nil && oldGroup != nil { addedPeers = difference(newGroup.Peers, oldGroup.Peers) removedPeers = difference(oldGroup.Peers, newGroup.Peers) @@ -170,7 +170,7 @@ func (am *DefaultAccountManager) prepareGroupEvents(ctx context.Context, userID } for _, peerID := range addedPeers { - peer, err := am.Store.GetPeerByID(ctx, LockingStrengthShare, peerID, accountID) + peer, err := am.Store.GetPeerByID(ctx, LockingStrengthShare, accountID, peerID) if err != nil { log.WithContext(ctx).Errorf("peer %s not found under account %s while saving group", peerID, accountID) continue @@ -187,7 +187,7 @@ func (am *DefaultAccountManager) prepareGroupEvents(ctx context.Context, userID } for _, peerID := range removedPeers { - peer, err := am.Store.GetPeerByID(ctx, LockingStrengthShare, peerID, accountID) + peer, err := am.Store.GetPeerByID(ctx, LockingStrengthShare, accountID, peerID) if err != nil { log.WithContext(ctx).Errorf("peer %s not found under account %s while saving group", peerID, accountID) continue @@ -232,7 +232,7 @@ func (am *DefaultAccountManager) DeleteGroup(ctx context.Context, accountID, use return status.Errorf(status.PermissionDenied, "no permission to delete group") } - group, err := am.Store.GetGroupByID(ctx, LockingStrengthShare, groupID, accountID) + group, err := am.Store.GetGroupByID(ctx, LockingStrengthShare, accountID, groupID) if err != nil { return err } @@ -288,7 +288,7 @@ func (am *DefaultAccountManager) DeleteGroups(ctx context.Context, accountID, us ) for _, groupID := range groupIDs { - group, err := am.Store.GetGroupByID(ctx, LockingStrengthShare, groupID, accountID) + group, err := am.Store.GetGroupByID(ctx, LockingStrengthShare, accountID, groupID) if err != nil { continue } @@ -307,7 +307,7 @@ func (am *DefaultAccountManager) DeleteGroups(ctx context.Context, accountID, us return fmt.Errorf("failed to increment network serial: %w", err) } - if err = transaction.DeleteGroups(ctx, LockingStrengthUpdate, groupIDsToDelete, accountID); err != nil { + if err = transaction.DeleteGroups(ctx, LockingStrengthUpdate, accountID, groupIDsToDelete); err != nil { return fmt.Errorf("failed to delete group: %w", err) } return nil diff --git a/management/server/nameserver.go b/management/server/nameserver.go index eecfb3355..90c1eefa2 100644 --- a/management/server/nameserver.go +++ b/management/server/nameserver.go @@ -30,7 +30,7 @@ func (am *DefaultAccountManager) GetNameServerGroup(ctx context.Context, account return nil, status.Errorf(status.PermissionDenied, "only users with admin power can view name server groups") } - return am.Store.GetNameServerGroupByID(ctx, LockingStrengthShare, nsGroupID, accountID) + return am.Store.GetNameServerGroupByID(ctx, LockingStrengthShare, accountID, nsGroupID) } // CreateNameServerGroup creates and saves a new nameserver group @@ -103,7 +103,7 @@ func (am *DefaultAccountManager) SaveNameServerGroup(ctx context.Context, accoun return status.Errorf(status.PermissionDenied, "no permission to delete nameserver for this account") } - _, err = am.Store.GetNameServerGroupByID(ctx, LockingStrengthShare, nsGroupToSave.ID, accountID) + _, err = am.Store.GetNameServerGroupByID(ctx, LockingStrengthShare, accountID, nsGroupToSave.ID) if err != nil { return err } @@ -150,7 +150,7 @@ func (am *DefaultAccountManager) DeleteNameServerGroup(ctx context.Context, acco return status.Errorf(status.PermissionDenied, "no permission to delete nameserver for this account") } - nsGroup, err := am.Store.GetNameServerGroupByID(ctx, LockingStrengthShare, nsGroupID, accountID) + nsGroup, err := am.Store.GetNameServerGroupByID(ctx, LockingStrengthShare, accountID, nsGroupID) if err != nil { return err } @@ -160,7 +160,7 @@ func (am *DefaultAccountManager) DeleteNameServerGroup(ctx context.Context, acco return fmt.Errorf("failed to increment network serial: %w", err) } - if err = transaction.DeleteNameServerGroup(ctx, LockingStrengthUpdate, nsGroupID, accountID); err != nil { + if err = transaction.DeleteNameServerGroup(ctx, LockingStrengthUpdate, accountID, nsGroupID); err != nil { return fmt.Errorf("failed to delete nameserver group: %w", err) } diff --git a/management/server/peer.go b/management/server/peer.go index a4c7e1266..461b9e310 100644 --- a/management/server/peer.go +++ b/management/server/peer.go @@ -117,11 +117,11 @@ func (am *DefaultAccountManager) MarkPeerConnected(ctx context.Context, peerPubK if peer.AddedWithSSOLogin() { if peer.LoginExpirationEnabled && account.Settings.PeerLoginExpirationEnabled { - am.checkAndSchedulePeerLoginExpiration(ctx, account) + am.checkAndSchedulePeerLoginExpiration(ctx, account.Id) } if peer.InactivityExpirationEnabled && account.Settings.PeerInactivityExpirationEnabled { - am.checkAndSchedulePeerInactivityExpiration(ctx, account) + am.checkAndSchedulePeerInactivityExpiration(ctx, account.Id) } } @@ -230,7 +230,7 @@ func (am *DefaultAccountManager) UpdatePeer(ctx context.Context, accountID, user am.StoreEvent(ctx, userID, peer.IP.String(), accountID, event, peer.EventMeta(am.GetDNSDomain())) if peer.AddedWithSSOLogin() && peer.LoginExpirationEnabled && account.Settings.PeerLoginExpirationEnabled { - am.checkAndSchedulePeerLoginExpiration(ctx, account) + am.checkAndSchedulePeerLoginExpiration(ctx, accountID) } } @@ -249,7 +249,7 @@ func (am *DefaultAccountManager) UpdatePeer(ctx context.Context, accountID, user am.StoreEvent(ctx, userID, peer.IP.String(), accountID, event, peer.EventMeta(am.GetDNSDomain())) if peer.AddedWithSSOLogin() && peer.InactivityExpirationEnabled && account.Settings.PeerInactivityExpirationEnabled { - am.checkAndSchedulePeerInactivityExpiration(ctx, account) + am.checkAndSchedulePeerInactivityExpiration(ctx, accountID) } } @@ -537,7 +537,7 @@ func (am *DefaultAccountManager) AddPeer(ctx context.Context, setupKey, userID s return fmt.Errorf("failed to add peer to account: %w", err) } - err = transaction.IncrementNetworkSerial(ctx, accountID) + err = transaction.IncrementNetworkSerial(ctx, LockingStrengthUpdate, accountID) if err != nil { return fmt.Errorf("failed to increment network serial: %w", err) } @@ -1041,6 +1041,139 @@ func (am *DefaultAccountManager) updateAccountPeers(ctx context.Context, account wg.Wait() } +// getNextPeerExpiration returns the minimum duration in which the next peer of the account will expire if it was found. +// If there is no peer that expires this function returns false and a duration of 0. +// This function only considers peers that haven't been expired yet and that are connected. +func (am *DefaultAccountManager) getNextPeerExpiration(ctx context.Context, accountID string) (time.Duration, bool) { + peersWithExpiry, err := am.Store.GetAccountPeersWithExpiration(ctx, LockingStrengthShare, accountID) + if err != nil { + log.WithContext(ctx).Errorf("failed to get peers with expiration: %v", err) + return 0, false + } + + if len(peersWithExpiry) == 0 { + return 0, false + } + + settings, err := am.Store.GetAccountSettings(ctx, LockingStrengthShare, accountID) + if err != nil { + log.WithContext(ctx).Errorf("failed to get account settings: %v", err) + return 0, false + } + + var nextExpiry *time.Duration + for _, peer := range peersWithExpiry { + // consider only connected peers because others will require login on connecting to the management server + if peer.Status.LoginExpired || !peer.Status.Connected { + continue + } + _, duration := peer.LoginExpired(settings.PeerLoginExpiration) + if nextExpiry == nil || duration < *nextExpiry { + // if expiration is below 1s return 1s duration + // this avoids issues with ticker that can't be set to < 0 + if duration < time.Second { + return time.Second, true + } + nextExpiry = &duration + } + } + + if nextExpiry == nil { + return 0, false + } + + return *nextExpiry, true +} + +// GetNextInactivePeerExpiration returns the minimum duration in which the next peer of the account will expire if it was found. +// If there is no peer that expires this function returns false and a duration of 0. +// This function only considers peers that haven't been expired yet and that are not connected. +func (am *DefaultAccountManager) getNextInactivePeerExpiration(ctx context.Context, accountID string) (time.Duration, bool) { + peersWithInactivity, err := am.Store.GetAccountPeersWithInactivity(ctx, LockingStrengthShare, accountID) + if err != nil { + log.WithContext(ctx).Errorf("failed to get peers with inactivity: %v", err) + return 0, false + } + + if len(peersWithInactivity) == 0 { + return 0, false + } + + settings, err := am.Store.GetAccountSettings(ctx, LockingStrengthShare, accountID) + if err != nil { + log.WithContext(ctx).Errorf("failed to get account settings: %v", err) + return 0, false + } + + var nextExpiry *time.Duration + for _, peer := range peersWithInactivity { + if peer.Status.LoginExpired || peer.Status.Connected { + continue + } + _, duration := peer.SessionExpired(settings.PeerInactivityExpiration) + if nextExpiry == nil || duration < *nextExpiry { + // if expiration is below 1s return 1s duration + // this avoids issues with ticker that can't be set to < 0 + if duration < time.Second { + return time.Second, true + } + nextExpiry = &duration + } + } + + if nextExpiry == nil { + return 0, false + } + + return *nextExpiry, true +} + +// getExpiredPeers returns peers that have been expired. +func (am *DefaultAccountManager) getExpiredPeers(ctx context.Context, accountID string) ([]*nbpeer.Peer, error) { + peersWithExpiry, err := am.Store.GetAccountPeersWithExpiration(ctx, LockingStrengthShare, accountID) + if err != nil { + return nil, err + } + + settings, err := am.Store.GetAccountSettings(ctx, LockingStrengthShare, accountID) + if err != nil { + return nil, err + } + + var peers []*nbpeer.Peer + for _, peer := range peersWithExpiry { + expired, _ := peer.LoginExpired(settings.PeerLoginExpiration) + if expired { + peers = append(peers, peer) + } + } + + return peers, nil +} + +// getInactivePeers returns peers that have been expired by inactivity +func (am *DefaultAccountManager) getInactivePeers(ctx context.Context, accountID string) ([]*nbpeer.Peer, error) { + peersWithInactivity, err := am.Store.GetAccountPeersWithInactivity(ctx, LockingStrengthShare, accountID) + if err != nil { + return nil, err + } + + settings, err := am.Store.GetAccountSettings(ctx, LockingStrengthShare, accountID) + if err != nil { + return nil, err + } + + var peers []*nbpeer.Peer + for _, inactivePeer := range peersWithInactivity { + inactive, _ := inactivePeer.SessionExpired(settings.PeerInactivityExpiration) + if inactive { + peers = append(peers, inactivePeer) + } + } + + return peers, nil +} + func ConvertSliceToMap(existingLabels []string) map[string]struct{} { labelMap := make(map[string]struct{}, len(existingLabels)) for _, label := range existingLabels { diff --git a/management/server/personal_access_token.go b/management/server/personal_access_token.go index 47a5e13d7..583073d97 100644 --- a/management/server/personal_access_token.go +++ b/management/server/personal_access_token.go @@ -8,9 +8,8 @@ import ( "time" b "github.com/hashicorp/go-secure-stdlib/base62" - "github.com/rs/xid" - "github.com/netbirdio/netbird/base62" + "github.com/rs/xid" ) const ( diff --git a/management/server/policy.go b/management/server/policy.go index 75647de44..0b7fce48f 100644 --- a/management/server/policy.go +++ b/management/server/policy.go @@ -335,7 +335,7 @@ func (am *DefaultAccountManager) GetPolicy(ctx context.Context, accountID, polic return nil, status.Errorf(status.PermissionDenied, "only users with admin power are allowed to view policies") } - return am.Store.GetPolicyByID(ctx, LockingStrengthShare, policyID, accountID) + return am.Store.GetPolicyByID(ctx, LockingStrengthShare, accountID, policyID) } // SavePolicy in the store diff --git a/management/server/posture_checks.go b/management/server/posture_checks.go index d5840d2be..bce3ea05e 100644 --- a/management/server/posture_checks.go +++ b/management/server/posture_checks.go @@ -25,7 +25,7 @@ func (am *DefaultAccountManager) GetPostureChecks(ctx context.Context, accountID return nil, status.Errorf(status.PermissionDenied, errMsgPostureAdminOnly) } - return am.Store.GetPostureChecksByID(ctx, LockingStrengthShare, postureChecksID, accountID) + return am.Store.GetPostureChecksByID(ctx, LockingStrengthShare, accountID, postureChecksID) } // SavePostureChecks saves a posture check. @@ -49,7 +49,7 @@ func (am *DefaultAccountManager) SavePostureChecks(ctx context.Context, accountI if isUpdate { action = activity.PostureCheckUpdated - if _, err := transaction.GetPostureChecksByID(ctx, LockingStrengthShare, postureChecks.ID, accountID); err != nil { + if _, err := transaction.GetPostureChecksByID(ctx, LockingStrengthShare, accountID, postureChecks.ID); err != nil { return fmt.Errorf("failed to get posture checks: %w", err) } @@ -114,7 +114,7 @@ func (am *DefaultAccountManager) DeletePostureChecks(ctx context.Context, accoun return err } - postureChecks, err := am.Store.GetPostureChecksByID(ctx, LockingStrengthShare, postureChecksID, accountID) + postureChecks, err := am.Store.GetPostureChecksByID(ctx, LockingStrengthShare, accountID, postureChecksID) if err != nil { return err } @@ -124,7 +124,7 @@ func (am *DefaultAccountManager) DeletePostureChecks(ctx context.Context, accoun return fmt.Errorf("failed to increment network serial: %w", err) } - if err = transaction.DeletePostureChecks(ctx, LockingStrengthUpdate, postureChecksID, accountID); err != nil { + if err = transaction.DeletePostureChecks(ctx, LockingStrengthUpdate, accountID, postureChecksID); err != nil { return fmt.Errorf("failed to delete posture checks: %w", err) } return nil diff --git a/management/server/route.go b/management/server/route.go index 39ee6170c..048b425ae 100644 --- a/management/server/route.go +++ b/management/server/route.go @@ -56,13 +56,35 @@ func (am *DefaultAccountManager) GetRoute(ctx context.Context, accountID string, return nil, status.Errorf(status.PermissionDenied, "only users with admin power can view Network Routes") } - return am.Store.GetRouteByID(ctx, LockingStrengthShare, string(routeID), accountID) + return am.Store.GetRouteByID(ctx, LockingStrengthShare, accountID, string(routeID)) +} + +// GetRoutesByPrefixOrDomains return list of routes by account and route prefix +func (am *DefaultAccountManager) GetRoutesByPrefixOrDomains(accountID string, prefix netip.Prefix, domains domain.List) ([]*route.Route, error) { + accountRoutes, err := am.Store.GetAccountRoutes(context.Background(), LockingStrengthShare, accountID) + if err != nil { + return nil, err + } + + routes := make([]*route.Route, 0) + for _, r := range accountRoutes { + dynamic := r.IsDynamic() + if dynamic && r.Domains.PunycodeString() == domains.PunycodeString() || + !dynamic && r.Network.String() == prefix.String() { + routes = append(routes, r) + } + } + + return routes, nil } // checkRoutePrefixOrDomainsExistForPeers checks if a route with a given prefix exists for a single peer or multiple peer groups. func (am *DefaultAccountManager) checkRoutePrefixOrDomainsExistForPeers(account *Account, peerID string, routeID route.ID, peerGroupIDs []string, prefix netip.Prefix, domains domain.List) error { // routes can have both peer and peer_groups - routesWithPrefix := account.GetRoutesByPrefixOrDomains(prefix, domains) + routesWithPrefix, err := am.GetRoutesByPrefixOrDomains(account.Id, prefix, domains) + if err != nil { + return err + } // lets remember all the peers and the peer groups from routesWithPrefix seenPeers := make(map[string]bool) @@ -81,8 +103,8 @@ func (am *DefaultAccountManager) checkRoutePrefixOrDomainsExistForPeers(account for _, groupID := range prefixRoute.PeerGroups { seenPeerGroups[groupID] = true - group := account.GetGroup(groupID) - if group == nil { + group, err := am.Store.GetGroupByID(context.Background(), LockingStrengthShare, account.Id, groupID) + if err != nil || group == nil { return status.Errorf( status.InvalidArgument, "failed to add route with %s - peer group %s doesn't exist", getRouteDescriptor(prefix, domains), groupID, @@ -97,10 +119,11 @@ func (am *DefaultAccountManager) checkRoutePrefixOrDomainsExistForPeers(account if peerID != "" { // check that peerID exists and is not in any route as single peer or part of the group - peer := account.GetPeer(peerID) - if peer == nil { + peer, err := am.Store.GetPeerByID(context.Background(), LockingStrengthShare, account.Id, peerID) + if err != nil || peer == nil { return status.Errorf(status.InvalidArgument, "peer with ID %s not found", peerID) } + if _, ok := seenPeers[peerID]; ok { return status.Errorf(status.AlreadyExists, "failed to add route with %s - peer %s already has this route", getRouteDescriptor(prefix, domains), peerID) @@ -109,7 +132,11 @@ func (am *DefaultAccountManager) checkRoutePrefixOrDomainsExistForPeers(account // check that peerGroupIDs are not in any route peerGroups list for _, groupID := range peerGroupIDs { - group := account.GetGroup(groupID) // we validated the group existence before entering this function, no need to check again. + // we validated the group existence before entering this function, no need to check again. + group, err := am.Store.GetGroupByID(context.Background(), LockingStrengthShare, groupID, account.Id) + if err != nil || group == nil { + return status.Errorf(status.InvalidArgument, "group with ID %s not found", peerID) + } if _, ok := seenPeerGroups[groupID]; ok { return status.Errorf( @@ -120,10 +147,11 @@ func (am *DefaultAccountManager) checkRoutePrefixOrDomainsExistForPeers(account // check that the peers from peerGroupIDs groups are not the same peers we saw in routesWithPrefix for _, id := range group.Peers { if _, ok := seenPeers[id]; ok { - peer := account.GetPeer(id) - if peer == nil { + peer, err := am.Store.GetPeerByID(context.Background(), LockingStrengthShare, peerID, account.Id) + if err != nil { return status.Errorf(status.InvalidArgument, "peer with ID %s not found", peerID) } + return status.Errorf(status.AlreadyExists, "failed to add route with %s - peer %s from the group %s already has this route", getRouteDescriptor(prefix, domains), peer.Name, group.Name) @@ -146,6 +174,15 @@ func (am *DefaultAccountManager) CreateRoute(ctx context.Context, accountID stri unlock := am.Store.AcquireWriteLockByUID(ctx, accountID) defer unlock() + user, err := am.Store.GetUserByUserID(ctx, LockingStrengthShare, userID) + if err != nil { + return nil, err + } + + if user.AccountID != accountID { + return nil, status.Errorf(status.PermissionDenied, "the user has no permission to access account data") + } + account, err := am.Store.GetAccount(ctx, accountID) if err != nil { return nil, err @@ -181,17 +218,17 @@ func (am *DefaultAccountManager) CreateRoute(ctx context.Context, accountID stri newRoute.ID = route.ID(xid.New().String()) if len(peerGroupIDs) > 0 { - err = validateGroups(peerGroupIDs, account.Groups) - if err != nil { - return nil, err - } + //err = validateGroups(peerGroupIDs, account.Groups) + //if err != nil { + // return nil, err + //} } if len(accessControlGroupIDs) > 0 { - err = validateGroups(accessControlGroupIDs, account.Groups) - if err != nil { - return nil, err - } + //err = validateGroups(accessControlGroupIDs, account.Groups) + //if err != nil { + // return nil, err + //} } err = am.checkRoutePrefixOrDomainsExistForPeers(account, peerID, newRoute.ID, peerGroupIDs, prefix, domains) @@ -207,10 +244,10 @@ func (am *DefaultAccountManager) CreateRoute(ctx context.Context, accountID stri return nil, status.Errorf(status.InvalidArgument, "identifier should be between 1 and %d", route.MaxNetIDChar) } - err = validateGroups(groups, account.Groups) - if err != nil { - return nil, err - } + //err = validateGroups(groups, account.Groups) + //if err != nil { + // return nil, err + //} newRoute.Peer = peerID newRoute.PeerGroups = peerGroupIDs @@ -290,17 +327,17 @@ func (am *DefaultAccountManager) SaveRoute(ctx context.Context, accountID, userI } if len(routeToSave.PeerGroups) > 0 { - err = validateGroups(routeToSave.PeerGroups, account.Groups) - if err != nil { - return err - } + //err = validateGroups(routeToSave.PeerGroups, account.Groups) + //if err != nil { + // return err + //} } if len(routeToSave.AccessControlGroups) > 0 { - err = validateGroups(routeToSave.AccessControlGroups, account.Groups) - if err != nil { - return err - } + //err = validateGroups(routeToSave.AccessControlGroups, account.Groups) + //if err != nil { + // return err + //} } err = am.checkRoutePrefixOrDomainsExistForPeers(account, routeToSave.Peer, routeToSave.ID, routeToSave.Copy().PeerGroups, routeToSave.Network, routeToSave.Domains) @@ -308,10 +345,10 @@ func (am *DefaultAccountManager) SaveRoute(ctx context.Context, accountID, userI return err } - err = validateGroups(routeToSave.Groups, account.Groups) - if err != nil { - return err - } + //err = validateGroups(routeToSave.Groups, account.Groups) + //if err != nil { + // return err + //} account.Routes[routeToSave.ID] = routeToSave diff --git a/management/server/setupkey.go b/management/server/setupkey.go index 41cf894ec..d6bdf74f0 100644 --- a/management/server/setupkey.go +++ b/management/server/setupkey.go @@ -287,7 +287,7 @@ func (am *DefaultAccountManager) SaveSetupKey(ctx context.Context, accountID str return nil, err } - oldKey, err := am.Store.GetSetupKeyByID(ctx, LockingStrengthShare, keyToSave.Id, accountID) + oldKey, err := am.Store.GetSetupKeyByID(ctx, LockingStrengthShare, accountID, keyToSave.Id) if err != nil { return nil, err } @@ -382,7 +382,7 @@ func (am *DefaultAccountManager) GetSetupKey(ctx context.Context, accountID, use return nil, status.Errorf(status.Unauthorized, "only users with admin power can view setup keys") } - setupKey, err := am.Store.GetSetupKeyByID(ctx, LockingStrengthShare, keyID, accountID) + setupKey, err := am.Store.GetSetupKeyByID(ctx, LockingStrengthShare, accountID, keyID) if err != nil { return nil, err } diff --git a/management/server/sql_store.go b/management/server/sql_store.go index de3dfa945..dde15b265 100644 --- a/management/server/sql_store.go +++ b/management/server/sql_store.go @@ -541,9 +541,9 @@ func (s *SqlStore) GetUserByUserID(ctx context.Context, lockStrength LockingStre return &user, nil } -func (s *SqlStore) GetAccountUsers(ctx context.Context, accountID string) ([]*User, error) { +func (s *SqlStore) GetAccountUsers(ctx context.Context, lockStrength LockingStrength, accountID string) ([]*User, error) { var users []*User - result := s.db.Find(&users, accountIDCondition, accountID) + result := s.db.WithContext(ctx).Clauses(clause.Locking{Strength: string(lockStrength)}).Find(&users, accountIDCondition, accountID) if result.Error != nil { if errors.Is(result.Error, gorm.ErrRecordNotFound) { return nil, status.Errorf(status.NotFound, "accountID not found: index lookup failed") @@ -828,7 +828,9 @@ func (s *SqlStore) GetPeerByPeerPubKey(ctx context.Context, lockStrength Locking func (s *SqlStore) GetAccountSettings(ctx context.Context, lockStrength LockingStrength, accountID string) (*Settings, error) { var accountSettings AccountSettings - if err := s.db.WithContext(ctx).Clauses(clause.Locking{Strength: string(lockStrength)}).Model(&Account{}).Where(idQueryCondition, accountID).First(&accountSettings).Error; err != nil { + result := s.db.WithContext(ctx).Clauses(clause.Locking{Strength: string(lockStrength)}).Model(&Account{}). + First(&accountSettings, idQueryCondition, accountID) + if err := result.Error; err != nil { if errors.Is(err, gorm.ErrRecordNotFound) { return nil, status.Errorf(status.NotFound, "settings not found") } @@ -837,6 +839,21 @@ func (s *SqlStore) GetAccountSettings(ctx context.Context, lockStrength LockingS return accountSettings.Settings, nil } +// SaveAccountSettings stores the account settings in DB. +func (s *SqlStore) SaveAccountSettings(ctx context.Context, lockStrength LockingStrength, accountID string, settings *Settings) error { + result := s.db.WithContext(ctx).Debug().Clauses(clause.Locking{Strength: string(lockStrength)}).Model(&Account{}). + Select("*").Where(idQueryCondition, accountID).Updates(&AccountSettings{Settings: settings}) + if result.Error != nil { + return status.Errorf(status.Internal, "failed to save account settings to store: %v", result.Error) + } + + if result.RowsAffected == 0 { + return status.Errorf(status.NotFound, "account not found") + } + + return nil +} + // SaveUserLastLogin stores the last login time for a user in DB. func (s *SqlStore) SaveUserLastLogin(ctx context.Context, accountID, userID string, lastLogin time.Time) error { var user User @@ -1054,9 +1071,72 @@ func (s *SqlStore) AddPeerToGroup(ctx context.Context, accountId string, peerId return nil } +// GetAccountPeers retrieves peers for an account. +func (s *SqlStore) GetAccountPeers(ctx context.Context, lockStrength LockingStrength, accountID string) ([]*nbpeer.Peer, error) { + var peers []*nbpeer.Peer + result := s.db.WithContext(ctx).Clauses(clause.Locking{Strength: string(lockStrength)}).Find(&peers, accountIDCondition, accountID) + if err := result.Error; err != nil { + log.WithContext(ctx).Errorf("failed to get peers from the store: %s", err) + return nil, status.Errorf(status.Internal, "failed to get peers from store") + } + + return peers, nil +} + // GetUserPeers retrieves peers for a user. func (s *SqlStore) GetUserPeers(ctx context.Context, lockStrength LockingStrength, accountID, userID string) ([]*nbpeer.Peer, error) { - return getRecords[*nbpeer.Peer](s.db.WithContext(ctx).Where("user_id = ?", userID), lockStrength, accountID) + var peers []*nbpeer.Peer + result := s.db.WithContext(ctx).Clauses(clause.Locking{Strength: string(lockStrength)}).Find(&peers, "account_id = ? AND user_id = ?", accountID, userID) + if err := result.Error; err != nil { + log.WithContext(ctx).Errorf("failed to get peers from the store: %s", err) + return nil, status.Errorf(status.Internal, "failed to get peers from store") + } + + return peers, nil +} + +// GetAccountPeersWithExpiration retrieves a list of peers that have login expiration enabled and added by a user. +func (s *SqlStore) GetAccountPeersWithExpiration(ctx context.Context, lockStrength LockingStrength, accountID string) ([]*nbpeer.Peer, error) { + var peers []*nbpeer.Peer + result := s.db.WithContext(ctx).Clauses(clause.Locking{Strength: string(lockStrength)}). + Where("login_expiration_enabled = ? AND user_id IS NOT NULL AND user_id != ''", true). + Find(&peers, accountIDCondition, accountID) + if err := result.Error; err != nil { + log.WithContext(ctx).Errorf("failed to get peers with expiration from the store: %s", result.Error) + return nil, status.Errorf(status.Internal, "failed to get peers with expiration from store") + } + + return peers, nil +} + +// GetAccountPeersWithInactivity retrieves a list of peers that have login expiration enabled and added by a user. +func (s *SqlStore) GetAccountPeersWithInactivity(ctx context.Context, lockStrength LockingStrength, accountID string) ([]*nbpeer.Peer, error) { + var peers []*nbpeer.Peer + result := s.db.WithContext(ctx).Clauses(clause.Locking{Strength: string(lockStrength)}). + Where("inactivity_expiration_enabled = ? AND user_id IS NOT NULL AND user_id != ''", true). + Find(&peers, accountIDCondition, accountID) + if err := result.Error; err != nil { + log.WithContext(ctx).Errorf("failed to get peers with inactivity from the store: %s", result.Error) + return nil, status.Errorf(status.Internal, "failed to get peers with inactivity from store") + } + + return peers, nil +} + +// GetPeerByID retrieves a peer by its ID and account ID. +func (s *SqlStore) GetPeerByID(ctx context.Context, lockStrength LockingStrength, accountID, peerID string) (*nbpeer.Peer, error) { + var peer *nbpeer.Peer + result := s.db.WithContext(ctx).Clauses(clause.Locking{Strength: string(lockStrength)}). + First(&peer, accountAndIDQueryCondition, accountID, peerID) + if err := result.Error; err != nil { + if errors.Is(result.Error, gorm.ErrRecordNotFound) { + return nil, status.Errorf(status.NotFound, "peer not found") + } + log.WithContext(ctx).Errorf("failed to get peer from the store: %s", result.Error) + return nil, status.Errorf(status.Internal, "failed to get peer from store") + } + + return peer, nil } func (s *SqlStore) AddPeerToAccount(ctx context.Context, peer *nbpeer.Peer) error { @@ -1067,8 +1147,9 @@ func (s *SqlStore) AddPeerToAccount(ctx context.Context, peer *nbpeer.Peer) erro return nil } -func (s *SqlStore) IncrementNetworkSerial(ctx context.Context, accountId string) error { - result := s.db.WithContext(ctx).Model(&Account{}).Where(idQueryCondition, accountId).Update("network_serial", gorm.Expr("network_serial + 1")) +func (s *SqlStore) IncrementNetworkSerial(ctx context.Context, lockStrength LockingStrength, accountId string) error { + result := s.db.WithContext(ctx).Clauses(clause.Locking{Strength: string(lockStrength)}).Model(&Account{}). + Where(idQueryCondition, accountId).Update("network_serial", gorm.Expr("network_serial + 1")) if result.Error != nil { return status.Errorf(status.Internal, "issue incrementing network serial count: %s", result.Error) } @@ -1113,6 +1194,19 @@ func (s *SqlStore) GetAccountDNSSettings(ctx context.Context, lockStrength Locki return &accountDNSSettings.DNSSettings, nil } +// SaveDNSSettings saves the DNS settings to the store. +func (s *SqlStore) SaveDNSSettings(ctx context.Context, lockStrength LockingStrength, accountID string, settings *DNSSettings) error { + result := s.db.WithContext(ctx).Clauses(clause.Locking{Strength: string(lockStrength)}).Model(&Account{}). + Where(idQueryCondition, accountID).Updates(&AccountDNSSettings{DNSSettings: *settings}) + if result.Error != nil { + return status.Errorf(status.Internal, "failed to save dns settings to store: %v", result.Error) + } + if result.RowsAffected == 0 { + return status.Errorf(status.NotFound, "account not found") + } + return nil +} + // AccountExists checks whether an account exists by the given ID. func (s *SqlStore) AccountExists(ctx context.Context, lockStrength LockingStrength, id string) (bool, error) { var accountID string @@ -1146,16 +1240,24 @@ func (s *SqlStore) GetAccountDomainAndCategory(ctx context.Context, lockStrength } // GetGroupByID retrieves a group by ID and account ID. -func (s *SqlStore) GetGroupByID(ctx context.Context, lockStrength LockingStrength, groupID, accountID string) (*nbgroup.Group, error) { - return getRecordByID[nbgroup.Group](s.db.WithContext(ctx).Preload(clause.Associations), lockStrength, groupID, accountID) +func (s *SqlStore) GetGroupByID(ctx context.Context, lockStrength LockingStrength, accountID, groupID string) (*nbgroup.Group, error) { + var group *nbgroup.Group + result := s.db.WithContext(ctx).Clauses(clause.Locking{Strength: string(lockStrength)}). + Find(&group, accountAndIDQueryCondition, accountID, groupID) + if err := result.Error; err != nil { + log.WithContext(ctx).Errorf("failed to get group from the store: %s", err) + return nil, status.Errorf(status.Internal, "failed to get group from store") + } + + return group, nil } // GetGroupByName retrieves a group by name and account ID. -func (s *SqlStore) GetGroupByName(ctx context.Context, lockStrength LockingStrength, groupName, accountID string) (*nbgroup.Group, error) { +func (s *SqlStore) GetGroupByName(ctx context.Context, lockStrength LockingStrength, accountID, groupName string) (*nbgroup.Group, error) { var group nbgroup.Group result := s.db.WithContext(ctx).Clauses(clause.Locking{Strength: string(lockStrength)}).Preload(clause.Associations). - Order("json_array_length(peers) DESC").First(&group, "name = ? and account_id = ?", groupName, accountID) + Order("json_array_length(peers) DESC").First(&group, "account_id = ? AND name = ?", accountID, groupName) if err := result.Error; err != nil { if errors.Is(result.Error, gorm.ErrRecordNotFound) { return nil, status.Errorf(status.NotFound, "group not found") @@ -1174,69 +1276,335 @@ func (s *SqlStore) SaveGroup(ctx context.Context, lockStrength LockingStrength, return nil } +// DeleteGroup deletes a group from the database. +func (s *SqlStore) DeleteGroup(ctx context.Context, lockStrength LockingStrength, accountID, groupID string) error { + result := s.db.WithContext(ctx).Clauses(clause.Locking{Strength: string(lockStrength)}). + Delete(&nbgroup.Group{}, accountAndIDQueryCondition, accountID, groupID) + if err := result.Error; err != nil { + log.WithContext(ctx).Errorf("failed to delete group from the store: %s", result.Error) + return status.Errorf(status.Internal, "failed to delete group from store") + } + + if result.RowsAffected == 0 { + return status.Errorf(status.NotFound, "group not found") + } + + return nil +} + +// DeleteGroups deletes groups from the database. +func (s *SqlStore) DeleteGroups(ctx context.Context, strength LockingStrength, accountID string, groupIDs []string) error { + result := s.db.WithContext(ctx).Clauses(clause.Locking{Strength: string(strength)}). + Delete(&nbgroup.Group{}, " account_id = ? AND id IN ?", accountID, groupIDs) + if result.Error != nil { + log.WithContext(ctx).Errorf("failed to delete groups from the store: %v", result.Error) + return status.Errorf(status.Internal, "failed to delete groups from store: %v", result.Error) + } + + return nil +} + // GetAccountPolicies retrieves policies for an account. func (s *SqlStore) GetAccountPolicies(ctx context.Context, lockStrength LockingStrength, accountID string) ([]*Policy, error) { - return getRecords[*Policy](s.db.WithContext(ctx).Preload(clause.Associations), lockStrength, accountID) + var policies []*Policy + result := s.db.WithContext(ctx).Clauses(clause.Locking{Strength: string(lockStrength)}). + Preload(clause.Associations).Find(&policies, accountIDCondition, accountID) + if err := result.Error; err != nil { + log.WithContext(ctx).Errorf("failed to get policies from the store: %s", result.Error) + return nil, status.Errorf(status.Internal, "failed to get policies from store") + } + + return policies, nil } // GetPolicyByID retrieves a policy by its ID and account ID. -func (s *SqlStore) GetPolicyByID(ctx context.Context, lockStrength LockingStrength, policyID string, accountID string) (*Policy, error) { - return getRecordByID[Policy](s.db.WithContext(ctx).Preload(clause.Associations), lockStrength, policyID, accountID) +func (s *SqlStore) GetPolicyByID(ctx context.Context, lockStrength LockingStrength, accountID, policyID string) (*Policy, error) { + var policy *Policy + result := s.db.WithContext(ctx).Clauses(clause.Locking{Strength: string(lockStrength)}). + Preload(clause.Associations).Find(&policy, accountAndIDQueryCondition, accountID, policyID) + if err := result.Error; err != nil { + log.WithContext(ctx).Errorf("failed to get policy from the store: %s", err) + return nil, status.Errorf(status.Internal, "failed to get policy from store") + } + + return policy, nil +} + +// SavePolicy saves a policy to the database. +func (s *SqlStore) SavePolicy(ctx context.Context, lockStrength LockingStrength, policy *Policy) error { + result := s.db.WithContext(ctx).Session(&gorm.Session{FullSaveAssociations: true}). + Clauses(clause.Locking{Strength: string(lockStrength)}).Save(policy) + if err := result.Error; err != nil { + log.WithContext(ctx).Errorf("failed to save policy to the store: %s", err) + return status.Errorf(status.Internal, "failed to save policy to store") + } + return nil +} + +func (s *SqlStore) DeletePolicy(ctx context.Context, lockStrength LockingStrength, policyID, accountID string) error { + result := s.db.WithContext(ctx).Clauses(clause.Locking{Strength: string(lockStrength)}). + Delete(&Policy{}, accountAndIDQueryCondition, accountID, policyID) + if err := result.Error; err != nil { + log.WithContext(ctx).Errorf("failed to delete policy from the store: %s", err) + return status.Errorf(status.Internal, "failed to delete policy from store") + } + + if result.RowsAffected == 0 { + return status.Errorf(status.NotFound, "policy not found") + } + + return nil } // GetAccountPostureChecks retrieves posture checks for an account. func (s *SqlStore) GetAccountPostureChecks(ctx context.Context, lockStrength LockingStrength, accountID string) ([]*posture.Checks, error) { - return getRecords[*posture.Checks](s.db.WithContext(ctx), lockStrength, accountID) + var postureChecks []*posture.Checks + result := s.db.WithContext(ctx).Clauses(clause.Locking{Strength: string(lockStrength)}).Find(&postureChecks, accountIDCondition, accountID) + if err := result.Error; err != nil { + log.WithContext(ctx).Errorf("failed to get posture checks from the store: %s", err) + return nil, status.Errorf(status.Internal, "failed to get posture checks from store") + } + + return postureChecks, nil } // GetPostureChecksByID retrieves posture checks by their ID and account ID. -func (s *SqlStore) GetPostureChecksByID(ctx context.Context, lockStrength LockingStrength, postureCheckID string, accountID string) (*posture.Checks, error) { - return getRecordByID[posture.Checks](s.db.WithContext(ctx), lockStrength, postureCheckID, accountID) +func (s *SqlStore) GetPostureChecksByID(ctx context.Context, lockStrength LockingStrength, accountID, postureCheckID string) (*posture.Checks, error) { + var postureCheck *posture.Checks + result := s.db.WithContext(ctx).Clauses(clause.Locking{Strength: string(lockStrength)}). + First(&postureCheck, accountAndIDQueryCondition, accountID, postureCheckID) + if err := result.Error; err != nil { + if errors.Is(err, gorm.ErrRecordNotFound) { + return nil, status.Errorf(status.NotFound, "posture check not found") + } + log.WithContext(ctx).Errorf("failed to get posture check from the store: %s", err) + return nil, status.Errorf(status.Internal, "failed to get posture check from store") + } + + return postureCheck, nil +} + +// SavePostureChecks saves a posture checks to the database. +func (s *SqlStore) SavePostureChecks(ctx context.Context, lockStrength LockingStrength, postureCheck *posture.Checks) error { + result := s.db.WithContext(ctx).Clauses(clause.Locking{Strength: string(lockStrength)}).Save(postureCheck) + if err := result.Error; err != nil { + if errors.Is(err, gorm.ErrDuplicatedKey) { + return status.Errorf(status.InvalidArgument, "name should be unique") + } + log.WithContext(ctx).Errorf("failed to save posture checks to the store: %s", err) + return status.Errorf(status.Internal, "failed to save posture checks to store") + } + + return nil +} + +// DeletePostureChecks deletes a posture checks from the database. +func (s *SqlStore) DeletePostureChecks(ctx context.Context, lockStrength LockingStrength, accountID, postureChecksID string) error { + result := s.db.WithContext(ctx).Clauses(clause.Locking{Strength: string(lockStrength)}). + Delete(&posture.Checks{}, accountAndIDQueryCondition, accountID, postureChecksID) + if err := result.Error; err != nil { + log.WithContext(ctx).Errorf("failed to delete posture checks from the store: %s", err) + return status.Errorf(status.Internal, "failed to delete posture checks from store") + } + + if result.RowsAffected == 0 { + return status.Errorf(status.NotFound, "posture checks not found") + } + + return nil } // GetAccountRoutes retrieves network routes for an account. func (s *SqlStore) GetAccountRoutes(ctx context.Context, lockStrength LockingStrength, accountID string) ([]*route.Route, error) { - return getRecords[*route.Route](s.db.WithContext(ctx), lockStrength, accountID) + var routes []*route.Route + result := s.db.WithContext(ctx).Clauses(clause.Locking{Strength: string(lockStrength)}). + Find(&routes, accountIDCondition, accountID) + if err := result.Error; err != nil { + log.WithContext(ctx).Errorf("failed to get routes from the store: %s", err) + return nil, status.Errorf(status.Internal, "failed to get routes from store") + } + + return routes, nil } // GetRouteByID retrieves a route by its ID and account ID. -func (s *SqlStore) GetRouteByID(ctx context.Context, lockStrength LockingStrength, routeID string, accountID string) (*route.Route, error) { - return getRecordByID[route.Route](s.db.WithContext(ctx), lockStrength, routeID, accountID) +func (s *SqlStore) GetRouteByID(ctx context.Context, lockStrength LockingStrength, accountID string, routeID string) (*route.Route, error) { + var route *route.Route + result := s.db.WithContext(ctx).Clauses(clause.Locking{Strength: string(lockStrength)}). + First(&route, accountAndIDQueryCondition, accountID, routeID) + if err := result.Error; err != nil { + if errors.Is(err, gorm.ErrRecordNotFound) { + return nil, status.Errorf(status.NotFound, "route not found") + } + log.WithContext(ctx).Errorf("failed to get route from the store: %s", err) + return nil, status.Errorf(status.Internal, "failed to get route from store") + } + + return route, nil +} + +// SaveRoute saves a route to the database. +func (s *SqlStore) SaveRoute(ctx context.Context, lockStrength LockingStrength, route *route.Route) error { + result := s.db.WithContext(ctx).Clauses(clause.Locking{Strength: string(lockStrength)}).Save(route) + if err := result.Error; err != nil { + log.WithContext(ctx).Errorf("failed to save route to the store: %s", err) + return status.Errorf(status.Internal, "failed to save route to store") + } + + return nil +} + +// DeleteRoute deletes a route from the database. +func (s *SqlStore) DeleteRoute(ctx context.Context, lockStrength LockingStrength, accountID, routeID string) error { + result := s.db.Clauses(clause.Locking{Strength: string(lockStrength)}). + Delete(&route.Route{}, accountAndIDQueryCondition, accountID, routeID) + if err := result.Error; err != nil { + log.WithContext(ctx).Errorf("failed to delete route from the store: %s", err) + return status.Errorf(status.Internal, "failed to delete route from store") + } + + if result.RowsAffected == 0 { + return status.Errorf(status.NotFound, "route not found") + } + + return nil } // GetAccountSetupKeys retrieves setup keys for an account. func (s *SqlStore) GetAccountSetupKeys(ctx context.Context, lockStrength LockingStrength, accountID string) ([]*SetupKey, error) { - return getRecords[*SetupKey](s.db.WithContext(ctx), lockStrength, accountID) + var setupKeys []*SetupKey + result := s.db.WithContext(ctx).Clauses(clause.Locking{Strength: string(lockStrength)}). + Find(&setupKeys, accountIDCondition, accountID) + if err := result.Error; err != nil { + log.WithContext(ctx).Errorf("failed to get setup keys from the store: %s", err) + return nil, status.Errorf(status.Internal, "failed to get setup keys from store") + } + + return setupKeys, nil } // GetSetupKeyByID retrieves a setup key by its ID and account ID. -func (s *SqlStore) GetSetupKeyByID(ctx context.Context, lockStrength LockingStrength, setupKeyID string, accountID string) (*SetupKey, error) { - return getRecordByID[SetupKey](s.db.WithContext(ctx), lockStrength, setupKeyID, accountID) +func (s *SqlStore) GetSetupKeyByID(ctx context.Context, lockStrength LockingStrength, accountID, setupKeyID string) (*SetupKey, error) { + var setupKey *SetupKey + result := s.db.WithContext(ctx).Clauses(clause.Locking{Strength: string(lockStrength)}). + First(&setupKey, accountAndIDQueryCondition, accountID, setupKeyID) + if err := result.Error; err != nil { + if errors.Is(err, gorm.ErrRecordNotFound) { + return nil, status.Errorf(status.NotFound, "setup key not found") + } + log.WithContext(ctx).Errorf("failed to get setup key from the store: %s", err) + return nil, status.Errorf(status.Internal, "failed to get setup key from store") + } + + return setupKey, nil +} + +// SaveSetupKey saves a setup key to the database. +func (s *SqlStore) SaveSetupKey(ctx context.Context, lockStrength LockingStrength, setupKey *SetupKey) error { + result := s.db.WithContext(ctx).Session(&gorm.Session{FullSaveAssociations: true}). + Clauses(clause.Locking{Strength: string(lockStrength)}).Save(setupKey) + if err := result.Error; err != nil { + log.WithContext(ctx).Errorf("failed to save name server group to the store: %s", err) + return status.Errorf(status.Internal, "failed to save name server group to store") + } + + return nil } // GetAccountNameServerGroups retrieves name server groups for an account. func (s *SqlStore) GetAccountNameServerGroups(ctx context.Context, lockStrength LockingStrength, accountID string) ([]*nbdns.NameServerGroup, error) { - return getRecords[*nbdns.NameServerGroup](s.db.WithContext(ctx), lockStrength, accountID) + var nsGroups []*nbdns.NameServerGroup + result := s.db.WithContext(ctx).Clauses(clause.Locking{Strength: string(lockStrength)}).Find(&nsGroups, accountIDCondition, accountID) + if err := result.Error; err != nil { + log.WithContext(ctx).Errorf("failed to get name server groups from the store: %s", err) + return nil, status.Errorf(status.Internal, "failed to get name server groups from store") + } + + return nsGroups, nil } // GetNameServerGroupByID retrieves a name server group by its ID and account ID. -func (s *SqlStore) GetNameServerGroupByID(ctx context.Context, lockStrength LockingStrength, nsGroupID string, accountID string) (*nbdns.NameServerGroup, error) { - return getRecordByID[nbdns.NameServerGroup](s.db.WithContext(ctx), lockStrength, nsGroupID, accountID) -} - -// getRecords retrieves records from the database based on the account ID. -func getRecords[T any](db *gorm.DB, lockStrength LockingStrength, accountID string) ([]T, error) { - var record []T - - result := db.Clauses(clause.Locking{Strength: string(lockStrength)}).Find(&record, accountIDCondition, accountID) +func (s *SqlStore) GetNameServerGroupByID(ctx context.Context, lockStrength LockingStrength, accountID, nsGroupID string) (*nbdns.NameServerGroup, error) { + var nsGroup *nbdns.NameServerGroup + result := s.db.WithContext(ctx).Clauses(clause.Locking{Strength: string(lockStrength)}). + First(&nsGroup, accountAndIDQueryCondition, accountID, nsGroupID) if err := result.Error; err != nil { - parts := strings.Split(fmt.Sprintf("%T", record), ".") - recordType := parts[len(parts)-1] - - return nil, status.Errorf(status.Internal, "failed to get account %ss from store: %v", recordType, err) + if errors.Is(err, gorm.ErrRecordNotFound) { + return nil, status.Errorf(status.NotFound, "name server group not found") + } + log.WithContext(ctx).Errorf("failed to get name server group from the store: %s", err) + return nil, status.Errorf(status.Internal, "failed to get name server group from store") } - return record, nil + return nsGroup, nil +} + +// SaveNameServerGroup saves a name server group to the database. +func (s *SqlStore) SaveNameServerGroup(ctx context.Context, lockStrength LockingStrength, nameServerGroup *nbdns.NameServerGroup) error { + result := s.db.WithContext(ctx).Clauses(clause.Locking{Strength: string(lockStrength)}).Save(nameServerGroup) + if err := result.Error; err != nil { + log.WithContext(ctx).Errorf("failed to save name server group to the store: %s", err) + return status.Errorf(status.Internal, "failed to save name server group to store") + } + return nil +} + +// DeleteNameServerGroup deletes a name server group from the database. +func (s *SqlStore) DeleteNameServerGroup(ctx context.Context, lockStrength LockingStrength, accountID, nameServerGroupID string) error { + result := s.db.Clauses(clause.Locking{Strength: string(lockStrength)}).Delete(&nbdns.NameServerGroup{}, accountAndIDQueryCondition, accountID, nameServerGroupID) + if err := result.Error; err != nil { + log.WithContext(ctx).Errorf("failed to delete name server group from the store: %s", err) + return status.Errorf(status.Internal, "failed to delete name server group from store") + } + + if result.RowsAffected == 0 { + return status.Errorf(status.NotFound, "name server group not found") + } + + return nil +} + +// GetPATByID retrieves a personal access token by its ID and user ID. +func (s *SqlStore) GetPATByID(ctx context.Context, lockStrength LockingStrength, patID string, userID string) (*PersonalAccessToken, error) { + var pat PersonalAccessToken + result := s.db.WithContext(ctx).Clauses(clause.Locking{Strength: string(lockStrength)}). + First(&pat, "id = ? AND user_id = ?", patID, userID) + if err := result.Error; err != nil { + if errors.Is(result.Error, gorm.ErrRecordNotFound) { + return nil, status.Errorf(status.NotFound, "PAT not found") + } + log.WithContext(ctx).Errorf("failed to get PAT from the store: %s", err) + return nil, status.Errorf(status.Internal, "failed to get PAT from store") + } + + return &pat, nil +} + +// SavePAT saves a personal access token to the database. +func (s *SqlStore) SavePAT(ctx context.Context, lockStrength LockingStrength, pat *PersonalAccessToken) error { + result := s.db.WithContext(ctx).Clauses(clause.Locking{Strength: string(lockStrength)}).Save(pat) + if err := result.Error; err != nil { + log.WithContext(ctx).Errorf("failed to save PAT to the store: %s", err) + return status.Errorf(status.Internal, "failed to save PAT to store") + } + + return nil +} + +// DeletePAT deletes a personal access token from the database. +func (s *SqlStore) DeletePAT(ctx context.Context, lockStrength LockingStrength, userID, patID string) error { + result := s.db.WithContext(ctx).Clauses(clause.Locking{Strength: string(lockStrength)}). + Delete(&PersonalAccessToken{}, "user_id = ? AND id = ?", userID, patID) + if err := result.Error; err != nil { + log.WithContext(ctx).Errorf("failed to delete PAT from the store: %s", err) + return status.Errorf(status.Internal, "failed to delete PAT from store") + } + + if result.RowsAffected == 0 { + return status.Errorf(status.NotFound, "PAT not found") + } + + return nil } // getRecordByID retrieves a record by its ID and account ID from the database. diff --git a/management/server/sql_store_test.go b/management/server/sql_store_test.go index 20e812ea7..8feec2b77 100644 --- a/management/server/sql_store_test.go +++ b/management/server/sql_store_test.go @@ -1181,7 +1181,7 @@ func TestSqlite_CreateAndGetObjectInTransaction(t *testing.T) { t.Fatal("failed to save group") return err } - group, err = transaction.GetGroupByID(context.Background(), LockingStrengthUpdate, group.ID, group.AccountID) + group, err = transaction.GetGroupByID(context.Background(), LockingStrengthUpdate, group.AccountID, group.ID) if err != nil { t.Fatal("failed to get group") return err diff --git a/management/server/store.go b/management/server/store.go index 131fd8aaa..deb55ea3f 100644 --- a/management/server/store.go +++ b/management/server/store.go @@ -56,13 +56,15 @@ type Store interface { GetAccountIDByPrivateDomain(ctx context.Context, lockStrength LockingStrength, domain string) (string, error) GetAccountSettings(ctx context.Context, lockStrength LockingStrength, accountID string) (*Settings, error) GetAccountDNSSettings(ctx context.Context, lockStrength LockingStrength, accountID string) (*DNSSettings, error) + SaveDNSSettings(ctx context.Context, lockStrength LockingStrength, accountID string, settings *DNSSettings) error + SaveAccountSettings(ctx context.Context, lockStrength LockingStrength, accountID string, settings *Settings) error SaveAccount(ctx context.Context, account *Account) error DeleteAccount(ctx context.Context, account *Account) error UpdateAccountDomainAttributes(ctx context.Context, accountID string, domain string, category string, isPrimaryDomain bool) error GetUserByTokenID(ctx context.Context, tokenID string) (*User, error) GetUserByUserID(ctx context.Context, lockStrength LockingStrength, userID string) (*User, error) - GetAccountUsers(ctx context.Context, accountID string) ([]*User, error) + GetAccountUsers(ctx context.Context, lockStrength LockingStrength, accountID string) ([]*User, error) SaveUsers(accountID string, users map[string]*User) error SaveUser(ctx context.Context, lockStrength LockingStrength, user *User) error SaveUserLastLogin(ctx context.Context, accountID, userID string, lastLogin time.Time) error @@ -71,24 +73,34 @@ type Store interface { DeleteTokenID2UserIDIndex(tokenID string) error GetAccountGroups(ctx context.Context, accountID string) ([]*nbgroup.Group, error) - GetGroupByID(ctx context.Context, lockStrength LockingStrength, groupID, accountID string) (*nbgroup.Group, error) - GetGroupByName(ctx context.Context, lockStrength LockingStrength, groupName, accountID string) (*nbgroup.Group, error) + GetGroupByID(ctx context.Context, lockStrength LockingStrength, accountID, groupID string) (*nbgroup.Group, error) + GetGroupByName(ctx context.Context, lockStrength LockingStrength, accountID, groupName string) (*nbgroup.Group, error) SaveGroups(ctx context.Context, lockStrength LockingStrength, groups []*nbgroup.Group) error SaveGroup(ctx context.Context, lockStrength LockingStrength, group *nbgroup.Group) error + DeleteGroup(ctx context.Context, lockStrength LockingStrength, accountID, groupID string) error + DeleteGroups(ctx context.Context, strength LockingStrength, accountID string, groupIDs []string) error GetAccountPolicies(ctx context.Context, lockStrength LockingStrength, accountID string) ([]*Policy, error) - GetPolicyByID(ctx context.Context, lockStrength LockingStrength, policyID string, accountID string) (*Policy, error) + GetPolicyByID(ctx context.Context, lockStrength LockingStrength, accountID, policyID string) (*Policy, error) + SavePolicy(ctx context.Context, lockStrength LockingStrength, policy *Policy) error + DeletePolicy(ctx context.Context, lockStrength LockingStrength, accountID, policyID string) error GetPostureCheckByChecksDefinition(accountID string, checks *posture.ChecksDefinition) (*posture.Checks, error) GetAccountPostureChecks(ctx context.Context, lockStrength LockingStrength, accountID string) ([]*posture.Checks, error) - GetPostureChecksByID(ctx context.Context, lockStrength LockingStrength, postureCheckID string, accountID string) (*posture.Checks, error) + GetPostureChecksByID(ctx context.Context, lockStrength LockingStrength, accountID, postureCheckID string) (*posture.Checks, error) + SavePostureChecks(ctx context.Context, lockStrength LockingStrength, postureCheck *posture.Checks) error + DeletePostureChecks(ctx context.Context, lockStrength LockingStrength, accountID, postureChecksID string) error GetPeerLabelsInAccount(ctx context.Context, lockStrength LockingStrength, accountId string) ([]string, error) AddPeerToAllGroup(ctx context.Context, accountID string, peerID string) error AddPeerToGroup(ctx context.Context, accountId string, peerId string, groupID string) error AddPeerToAccount(ctx context.Context, peer *nbpeer.Peer) error GetPeerByPeerPubKey(ctx context.Context, lockStrength LockingStrength, peerKey string) (*nbpeer.Peer, error) + GetAccountPeers(ctx context.Context, lockStrength LockingStrength, accountID string) ([]*nbpeer.Peer, error) GetUserPeers(ctx context.Context, lockStrength LockingStrength, accountID, userID string) ([]*nbpeer.Peer, error) + GetAccountPeersWithExpiration(ctx context.Context, lockStrength LockingStrength, accountID string) ([]*nbpeer.Peer, error) + GetAccountPeersWithInactivity(ctx context.Context, lockStrength LockingStrength, accountID string) ([]*nbpeer.Peer, error) + GetPeerByID(ctx context.Context, lockStrength LockingStrength, accountID string, peerID string) (*nbpeer.Peer, error) SavePeer(ctx context.Context, accountID string, peer *nbpeer.Peer) error SavePeerStatus(accountID, peerID string, status nbpeer.PeerStatus) error SavePeerLocation(accountID string, peer *nbpeer.Peer) error @@ -96,16 +108,25 @@ type Store interface { GetSetupKeyBySecret(ctx context.Context, lockStrength LockingStrength, key string) (*SetupKey, error) IncrementSetupKeyUsage(ctx context.Context, setupKeyID string) error GetAccountSetupKeys(ctx context.Context, lockStrength LockingStrength, accountID string) ([]*SetupKey, error) - GetSetupKeyByID(ctx context.Context, lockStrength LockingStrength, setupKeyID string, accountID string) (*SetupKey, error) + GetSetupKeyByID(ctx context.Context, lockStrength LockingStrength, accountID, setupKeyID string) (*SetupKey, error) + SaveSetupKey(ctx context.Context, lockStrength LockingStrength, setupKey *SetupKey) error GetAccountRoutes(ctx context.Context, lockStrength LockingStrength, accountID string) ([]*route.Route, error) - GetRouteByID(ctx context.Context, lockStrength LockingStrength, routeID string, accountID string) (*route.Route, error) + GetRouteByID(ctx context.Context, lockStrength LockingStrength, accountID, routeID string) (*route.Route, error) + SaveRoute(ctx context.Context, lockStrength LockingStrength, route *route.Route) error + DeleteRoute(ctx context.Context, lockStrength LockingStrength, accountID, routeID string) error GetAccountNameServerGroups(ctx context.Context, lockStrength LockingStrength, accountID string) ([]*dns.NameServerGroup, error) - GetNameServerGroupByID(ctx context.Context, lockStrength LockingStrength, nameServerGroupID string, accountID string) (*dns.NameServerGroup, error) + GetNameServerGroupByID(ctx context.Context, lockStrength LockingStrength, accountID, nameServerGroupID string) (*dns.NameServerGroup, error) + SaveNameServerGroup(ctx context.Context, lockStrength LockingStrength, nameServerGroup *dns.NameServerGroup) error + DeleteNameServerGroup(ctx context.Context, lockStrength LockingStrength, accountID, nameServerGroupID string) error + + GetPATByID(ctx context.Context, lockStrength LockingStrength, userID, patID string) (*PersonalAccessToken, error) + SavePAT(ctx context.Context, strength LockingStrength, pat *PersonalAccessToken) error + DeletePAT(ctx context.Context, strength LockingStrength, userID, patID string) error GetTakenIPs(ctx context.Context, lockStrength LockingStrength, accountId string) ([]net.IP, error) - IncrementNetworkSerial(ctx context.Context, accountId string) error + IncrementNetworkSerial(ctx context.Context, lockStrength LockingStrength, accountId string) error GetAccountNetwork(ctx context.Context, lockStrength LockingStrength, accountId string) (*Network, error) GetInstallationID() string diff --git a/management/server/user.go b/management/server/user.go index 71608ef20..17c5065ef 100644 --- a/management/server/user.go +++ b/management/server/user.go @@ -546,9 +546,6 @@ func (am *DefaultAccountManager) InviteUser(ctx context.Context, accountID strin // CreatePAT creates a new PAT for the given user func (am *DefaultAccountManager) CreatePAT(ctx context.Context, accountID string, initiatorUserID string, targetUserID string, tokenName string, expiresIn int) (*PersonalAccessTokenGenerated, error) { - unlock := am.Store.AcquireWriteLockByUID(ctx, accountID) - defer unlock() - if tokenName == "" { return nil, status.Errorf(status.InvalidArgument, "token name can't be empty") } @@ -557,35 +554,28 @@ func (am *DefaultAccountManager) CreatePAT(ctx context.Context, accountID string return nil, status.Errorf(status.InvalidArgument, "expiration has to be between 1 and 365") } - account, err := am.Store.GetAccount(ctx, accountID) + executingUser, err := am.Store.GetUserByUserID(ctx, LockingStrengthShare, initiatorUserID) if err != nil { return nil, err } - targetUser, ok := account.Users[targetUserID] - if !ok { - return nil, status.Errorf(status.NotFound, "user not found") + targetUser, err := am.Store.GetUserByUserID(ctx, LockingStrengthShare, targetUserID) + if err != nil { + return nil, err } - executingUser, ok := account.Users[initiatorUserID] - if !ok { - return nil, status.Errorf(status.NotFound, "user not found") - } - - if !(initiatorUserID == targetUserID || (executingUser.HasAdminPower() && targetUser.IsServiceUser)) { + if !(initiatorUserID == targetUserID || (executingUser.HasAdminPower() && targetUser.IsServiceUser)) || + executingUser.AccountID != accountID { return nil, status.Errorf(status.PermissionDenied, "no permission to create PAT for this user") } - pat, err := CreateNewPAT(tokenName, expiresIn, executingUser.Id) + pat, err := CreateNewPAT(tokenName, expiresIn, targetUser.Id, executingUser.Id) if err != nil { return nil, status.Errorf(status.Internal, "failed to create PAT: %v", err) } - targetUser.PATs[pat.ID] = &pat.PersonalAccessToken - - err = am.Store.SaveAccount(ctx, account) - if err != nil { - return nil, status.Errorf(status.Internal, "failed to save account: %v", err) + if err = am.Store.SavePAT(ctx, LockingStrengthUpdate, &pat.PersonalAccessToken); err != nil { + return nil, fmt.Errorf("failed to save PAT: %w", err) } meta := map[string]any{"name": pat.Name, "is_service_user": targetUser.IsServiceUser, "user_name": targetUser.ServiceUserName} @@ -596,51 +586,33 @@ func (am *DefaultAccountManager) CreatePAT(ctx context.Context, accountID string // DeletePAT deletes a specific PAT from a user func (am *DefaultAccountManager) DeletePAT(ctx context.Context, accountID string, initiatorUserID string, targetUserID string, tokenID string) error { - unlock := am.Store.AcquireWriteLockByUID(ctx, accountID) - defer unlock() - - account, err := am.Store.GetAccount(ctx, accountID) + executingUser, err := am.Store.GetUserByUserID(ctx, LockingStrengthShare, initiatorUserID) if err != nil { - return status.Errorf(status.NotFound, "account not found: %s", err) + return err } - targetUser, ok := account.Users[targetUserID] - if !ok { - return status.Errorf(status.NotFound, "user not found") + targetUser, err := am.Store.GetUserByUserID(ctx, LockingStrengthShare, targetUserID) + if err != nil { + return err } - executingUser, ok := account.Users[initiatorUserID] - if !ok { - return status.Errorf(status.NotFound, "user not found") - } - - if !(initiatorUserID == targetUserID || (executingUser.HasAdminPower() && targetUser.IsServiceUser)) { + if !(initiatorUserID == targetUserID || (executingUser.HasAdminPower() && targetUser.IsServiceUser)) || + executingUser.AccountID != accountID { return status.Errorf(status.PermissionDenied, "no permission to delete PAT for this user") } - pat := targetUser.PATs[tokenID] - if pat == nil { - return status.Errorf(status.NotFound, "PAT not found") + pat, err := am.Store.GetPATByID(ctx, LockingStrengthShare, tokenID, targetUserID) + if err != nil { + return err } - err = am.Store.DeleteTokenID2UserIDIndex(pat.ID) - if err != nil { - return status.Errorf(status.Internal, "Failed to delete token id index: %s", err) - } - err = am.Store.DeleteHashedPAT2TokenIDIndex(pat.HashedToken) - if err != nil { - return status.Errorf(status.Internal, "Failed to delete hashed token index: %s", err) + if err = am.Store.DeletePAT(ctx, LockingStrengthUpdate, tokenID, targetUserID); err != nil { + return fmt.Errorf("failed to delete PAT: %w", err) } meta := map[string]any{"name": pat.Name, "is_service_user": targetUser.IsServiceUser, "user_name": targetUser.ServiceUserName} am.StoreEvent(ctx, initiatorUserID, targetUserID, accountID, activity.PersonalAccessTokenDeleted, meta) - delete(targetUser.PATs, tokenID) - - err = am.Store.SaveAccount(ctx, account) - if err != nil { - return status.Errorf(status.Internal, "Failed to save account: %s", err) - } return nil } @@ -651,22 +623,11 @@ func (am *DefaultAccountManager) GetPAT(ctx context.Context, accountID string, i return nil, err } - targetUser, err := am.Store.GetUserByUserID(ctx, LockingStrengthShare, targetUserID) - if err != nil { - return nil, err - } - if (initiatorUserID != targetUserID && !initiatorUser.IsAdminOrServiceUser()) || initiatorUser.AccountID != accountID { return nil, status.Errorf(status.PermissionDenied, "no permission to get PAT for this user") } - for _, pat := range targetUser.PATsG { - if pat.ID == tokenID { - return pat.Copy(), nil - } - } - - return nil, status.Errorf(status.NotFound, "PAT not found") + return am.Store.GetPATByID(ctx, LockingStrengthShare, targetUserID, tokenID) } // GetAllPATs returns all PATs for a user