fix merge

Signed-off-by: bcmmbaga <bethuelmbaga12@gmail.com>
This commit is contained in:
bcmmbaga
2024-10-15 18:17:47 +03:00
parent a8c8b77df8
commit 1123729c1c
14 changed files with 766 additions and 227 deletions

View File

@ -45,15 +45,15 @@ import (
) )
const ( const (
PublicCategory = "public" PublicCategory = "public"
PrivateCategory = "private" PrivateCategory = "private"
UnknownCategory = "unknown" UnknownCategory = "unknown"
CacheExpirationMax = 7 * 24 * 3600 * time.Second // 7 days CacheExpirationMax = 7 * 24 * 3600 * time.Second // 7 days
CacheExpirationMin = 3 * 24 * 3600 * time.Second // 3 days CacheExpirationMin = 3 * 24 * 3600 * time.Second // 3 days
DefaultPeerLoginExpiration = 24 * time.Hour DefaultPeerLoginExpiration = 24 * time.Hour
DefaultPeerInactivityExpiration = 10 * time.Minute DefaultPeerInactivityExpiration = 10 * time.Minute
emptyUserID = "empty user ID in claims" emptyUserID = "empty user ID in claims"
errorGettingDomainAccIDFmt = "error getting account ID by private domain: %v" errorGettingDomainAccIDFmt = "error getting account ID by private domain: %v"
) )
type userLoggedInOnce bool type userLoggedInOnce bool
@ -134,14 +134,14 @@ type AccountManager interface {
GetDNSSettings(ctx context.Context, accountID string, userID string) (*DNSSettings, error) GetDNSSettings(ctx context.Context, accountID string, userID string) (*DNSSettings, error)
SaveDNSSettings(ctx context.Context, accountID string, userID string, dnsSettingsToSave *DNSSettings) error SaveDNSSettings(ctx context.Context, accountID string, userID string, dnsSettingsToSave *DNSSettings) error
GetPeer(ctx context.Context, accountID, peerID, userID string) (*nbpeer.Peer, error) GetPeer(ctx context.Context, accountID, peerID, userID string) (*nbpeer.Peer, error)
UpdateAccountSettings(ctx context.Context, accountID, userID string, newSettings *Settings) (*Account, error) UpdateAccountSettings(ctx context.Context, accountID, userID string, newSettings *Settings) (*Settings, error)
LoginPeer(ctx context.Context, login PeerLogin) (*nbpeer.Peer, *NetworkMap, []*posture.Checks, error) // used by peer gRPC API LoginPeer(ctx context.Context, login PeerLogin) (*nbpeer.Peer, *NetworkMap, []*posture.Checks, error) // used by peer gRPC API
SyncPeer(ctx context.Context, sync PeerSync, account *Account) (*nbpeer.Peer, *NetworkMap, []*posture.Checks, error) // used by peer gRPC API SyncPeer(ctx context.Context, sync PeerSync, account *Account) (*nbpeer.Peer, *NetworkMap, []*posture.Checks, error) // used by peer gRPC API
GetAllConnectedPeers() (map[string]struct{}, error) GetAllConnectedPeers() (map[string]struct{}, error)
HasConnectedChannel(peerID string) bool HasConnectedChannel(peerID string) bool
GetExternalCacheManager() ExternalCacheManager GetExternalCacheManager() ExternalCacheManager
GetPostureChecks(ctx context.Context, accountID, postureChecksID, userID string) (*posture.Checks, error) GetPostureChecks(ctx context.Context, accountID, postureChecksID, userID string) (*posture.Checks, error)
SavePostureChecks(ctx context.Context, accountID, userID string, postureChecks *posture.Checks) error SavePostureChecks(ctx context.Context, accountID, userID string, postureChecks *posture.Checks, isUpdate bool) error
DeletePostureChecks(ctx context.Context, accountID, postureChecksID, userID string) error DeletePostureChecks(ctx context.Context, accountID, postureChecksID, userID string) error
ListPostureChecks(ctx context.Context, accountID, userID string) ([]*posture.Checks, error) ListPostureChecks(ctx context.Context, accountID, userID string) ([]*posture.Checks, error)
GetIdpManager() idp.Manager GetIdpManager() idp.Manager
@ -1122,7 +1122,16 @@ func (am *DefaultAccountManager) GetIdpManager() idp.Manager {
// Only users with role UserRoleAdmin can update the account. // Only users with role UserRoleAdmin can update the account.
// User that performs the update has to belong to the account. // User that performs the update has to belong to the account.
// Returns an updated Account // Returns an updated Account
func (am *DefaultAccountManager) UpdateAccountSettings(ctx context.Context, accountID, userID string, newSettings *Settings) (*Account, error) { func (am *DefaultAccountManager) UpdateAccountSettings(ctx context.Context, accountID, userID string, newSettings *Settings) (*Settings, error) {
user, err := am.Store.GetUserByUserID(ctx, LockingStrengthShare, userID)
if err != nil {
return nil, err
}
if !user.HasAdminPower() || user.AccountID != accountID {
return nil, status.Errorf(status.PermissionDenied, "user is not allowed to update account")
}
halfYearLimit := 180 * 24 * time.Hour halfYearLimit := 180 * 24 * time.Hour
if newSettings.PeerLoginExpiration > halfYearLimit { if newSettings.PeerLoginExpiration > halfYearLimit {
return nil, status.Errorf(status.InvalidArgument, "peer login expiration can't be larger than 180 days") return nil, status.Errorf(status.InvalidArgument, "peer login expiration can't be larger than 180 days")
@ -1132,78 +1141,89 @@ func (am *DefaultAccountManager) UpdateAccountSettings(ctx context.Context, acco
return nil, status.Errorf(status.InvalidArgument, "peer login expiration can't be smaller than one hour") return nil, status.Errorf(status.InvalidArgument, "peer login expiration can't be smaller than one hour")
} }
unlock := am.Store.AcquireWriteLockByUID(ctx, accountID) var oldSettings *Settings
defer unlock()
account, err := am.Store.GetAccount(ctx, accountID) err = am.Store.ExecuteInTransaction(ctx, func(transaction Store) error {
oldSettings, err = transaction.GetAccountSettings(ctx, LockingStrengthShare, accountID)
if err != nil {
return fmt.Errorf("failed to get account settings: %w", err)
}
if err = am.validateExtraSettings(ctx, newSettings, oldSettings, userID, accountID); err != nil {
return fmt.Errorf("failed to validate extra settings: %w", err)
}
if err = am.Store.SaveAccountSettings(ctx, LockingStrengthUpdate, accountID, newSettings); err != nil {
return fmt.Errorf("failed to update account settings: %w", err)
}
return nil
})
if err != nil { if err != nil {
return nil, err return nil, err
} }
user, err := account.FindUser(userID) am.handlePeerLoginExpirationSettings(ctx, oldSettings, newSettings, userID, accountID)
am.handleInactivityExpirationSettings(ctx, oldSettings, newSettings, userID, accountID)
account, err := am.requestBuffer.GetAccountWithBackpressure(ctx, accountID)
if err != nil { if err != nil {
return nil, err return nil, fmt.Errorf("error getting account: %w", err)
} }
am.updateAccountPeers(ctx, account)
if !user.HasAdminPower() { return newSettings, nil
return nil, status.Errorf(status.PermissionDenied, "user is not allowed to update account") }
}
err = am.integratedPeerValidator.ValidateExtraSettings(ctx, newSettings.Extra, account.Settings.Extra, account.Peers, userID, accountID) // validateExtraSettings validates the extra settings of the account.
func (am *DefaultAccountManager) validateExtraSettings(ctx context.Context, newSettings, oldSettings *Settings, userID, accountID string) error {
peers, err := am.Store.GetAccountPeers(ctx, LockingStrengthShare, accountID)
if err != nil { if err != nil {
return nil, err return err
} }
oldSettings := account.Settings peerMap := make(map[string]*nbpeer.Peer, len(peers))
for _, peer := range peers {
peerMap[peer.ID] = peer
}
return am.integratedPeerValidator.ValidateExtraSettings(ctx, newSettings.Extra, oldSettings.Extra, peerMap, userID, accountID)
}
func (am *DefaultAccountManager) handlePeerLoginExpirationSettings(ctx context.Context, oldSettings, newSettings *Settings, userID, accountID string) {
if oldSettings.PeerLoginExpirationEnabled != newSettings.PeerLoginExpirationEnabled { if oldSettings.PeerLoginExpirationEnabled != newSettings.PeerLoginExpirationEnabled {
event := activity.AccountPeerLoginExpirationEnabled event := activity.AccountPeerLoginExpirationEnabled
if !newSettings.PeerLoginExpirationEnabled { if !newSettings.PeerLoginExpirationEnabled {
event = activity.AccountPeerLoginExpirationDisabled event = activity.AccountPeerLoginExpirationDisabled
am.peerLoginExpiry.Cancel(ctx, []string{accountID}) am.peerLoginExpiry.Cancel(ctx, []string{accountID})
} else { } else {
am.checkAndSchedulePeerLoginExpiration(ctx, account) am.checkAndSchedulePeerLoginExpiration(ctx, accountID)
} }
am.StoreEvent(ctx, userID, accountID, accountID, event, nil) am.StoreEvent(ctx, userID, accountID, accountID, event, nil)
} }
if oldSettings.PeerLoginExpiration != newSettings.PeerLoginExpiration { if oldSettings.PeerLoginExpiration != newSettings.PeerLoginExpiration {
am.StoreEvent(ctx, userID, accountID, accountID, activity.AccountPeerLoginExpirationDurationUpdated, nil) am.StoreEvent(ctx, userID, accountID, accountID, activity.AccountPeerLoginExpirationDurationUpdated, nil)
am.checkAndSchedulePeerLoginExpiration(ctx, account) am.checkAndSchedulePeerLoginExpiration(ctx, accountID)
} }
err = am.handleInactivityExpirationSettings(ctx, account, oldSettings, newSettings, userID, accountID)
if err != nil {
return nil, err
}
updatedAccount := account.UpdateSettings(newSettings)
err = am.Store.SaveAccount(ctx, account)
if err != nil {
return nil, err
}
return updatedAccount, nil
} }
func (am *DefaultAccountManager) handleInactivityExpirationSettings(ctx context.Context, account *Account, oldSettings, newSettings *Settings, userID, accountID string) error { func (am *DefaultAccountManager) handleInactivityExpirationSettings(ctx context.Context, oldSettings, newSettings *Settings, userID, accountID string) {
if oldSettings.PeerInactivityExpirationEnabled != newSettings.PeerInactivityExpirationEnabled { if oldSettings.PeerInactivityExpirationEnabled != newSettings.PeerInactivityExpirationEnabled {
event := activity.AccountPeerInactivityExpirationEnabled event := activity.AccountPeerInactivityExpirationEnabled
if !newSettings.PeerInactivityExpirationEnabled { if !newSettings.PeerInactivityExpirationEnabled {
event = activity.AccountPeerInactivityExpirationDisabled event = activity.AccountPeerInactivityExpirationDisabled
am.peerInactivityExpiry.Cancel(ctx, []string{accountID}) am.peerInactivityExpiry.Cancel(ctx, []string{accountID})
} else { } else {
am.checkAndSchedulePeerInactivityExpiration(ctx, account) am.checkAndSchedulePeerInactivityExpiration(ctx, accountID)
} }
am.StoreEvent(ctx, userID, accountID, accountID, event, nil) am.StoreEvent(ctx, userID, accountID, accountID, event, nil)
} }
if oldSettings.PeerInactivityExpiration != newSettings.PeerInactivityExpiration { if oldSettings.PeerInactivityExpiration != newSettings.PeerInactivityExpiration {
am.StoreEvent(ctx, userID, accountID, accountID, activity.AccountPeerInactivityExpirationDurationUpdated, nil) am.StoreEvent(ctx, userID, accountID, accountID, activity.AccountPeerInactivityExpirationDurationUpdated, nil)
am.checkAndSchedulePeerInactivityExpiration(ctx, account) am.checkAndSchedulePeerInactivityExpiration(ctx, accountID)
} }
return nil
} }
func (am *DefaultAccountManager) peerLoginExpirationJob(ctx context.Context, accountID string) func() (time.Duration, bool) { func (am *DefaultAccountManager) peerLoginExpirationJob(ctx context.Context, accountID string) func() (time.Duration, bool) {
@ -1234,10 +1254,10 @@ func (am *DefaultAccountManager) peerLoginExpirationJob(ctx context.Context, acc
} }
} }
func (am *DefaultAccountManager) checkAndSchedulePeerLoginExpiration(ctx context.Context, account *Account) { func (am *DefaultAccountManager) checkAndSchedulePeerLoginExpiration(ctx context.Context, accountID string) {
am.peerLoginExpiry.Cancel(ctx, []string{account.Id}) am.peerLoginExpiry.Cancel(ctx, []string{accountID})
if nextRun, ok := account.GetNextPeerExpiration(); ok { if nextRun, ok := am.getNextPeerExpiration(ctx, accountID); ok {
go am.peerLoginExpiry.Schedule(ctx, nextRun, account.Id, am.peerLoginExpirationJob(ctx, account.Id)) go am.peerLoginExpiry.Schedule(ctx, nextRun, accountID, am.peerLoginExpirationJob(ctx, accountID))
} }
} }
@ -1271,10 +1291,10 @@ func (am *DefaultAccountManager) peerInactivityExpirationJob(ctx context.Context
} }
// checkAndSchedulePeerInactivityExpiration periodically checks for inactive peers to end their sessions // checkAndSchedulePeerInactivityExpiration periodically checks for inactive peers to end their sessions
func (am *DefaultAccountManager) checkAndSchedulePeerInactivityExpiration(ctx context.Context, account *Account) { func (am *DefaultAccountManager) checkAndSchedulePeerInactivityExpiration(ctx context.Context, accountID string) {
am.peerInactivityExpiry.Cancel(ctx, []string{account.Id}) am.peerInactivityExpiry.Cancel(ctx, []string{accountID})
if nextRun, ok := account.GetNextInactivePeerExpiration(); ok { if nextRun, ok := am.getNextInactivePeerExpiration(ctx, accountID); ok {
go am.peerInactivityExpiry.Schedule(ctx, nextRun, account.Id, am.peerInactivityExpirationJob(ctx, account.Id)) go am.peerInactivityExpiry.Schedule(ctx, nextRun, accountID, am.peerInactivityExpirationJob(ctx, accountID))
} }
} }
@ -1435,12 +1455,12 @@ func isNil(i idp.Manager) bool {
// addAccountIDToIDPAppMeta update user's app metadata in idp manager // addAccountIDToIDPAppMeta update user's app metadata in idp manager
func (am *DefaultAccountManager) addAccountIDToIDPAppMeta(ctx context.Context, userID string, accountID string) error { func (am *DefaultAccountManager) addAccountIDToIDPAppMeta(ctx context.Context, userID string, accountID string) error {
if !isNil(am.idpManager) { if !isNil(am.idpManager) {
accountUsers, err := am.Store.GetAccountUsers(ctx, accountID) accountUsers, err := am.Store.GetAccountUsers(ctx, LockingStrengthShare, accountID)
if err != nil { if err != nil {
return err return err
} }
cachedAccount := &Account{ cachedAccount := &Account{
Id: accountID, Id: accountID,
Users: make(map[string]*User), Users: make(map[string]*User),
} }
for _, user := range accountUsers { for _, user := range accountUsers {
@ -2083,7 +2103,7 @@ func (am *DefaultAccountManager) syncJWTGroups(ctx context.Context, accountID st
return fmt.Errorf("error saving groups: %w", err) return fmt.Errorf("error saving groups: %w", err)
} }
if err = transaction.IncrementNetworkSerial(ctx, accountID); err != nil { if err = transaction.IncrementNetworkSerial(ctx, LockingStrengthUpdate, accountID); err != nil {
return fmt.Errorf("error incrementing network serial: %w", err) return fmt.Errorf("error incrementing network serial: %w", err)
} }
} }
@ -2101,7 +2121,7 @@ func (am *DefaultAccountManager) syncJWTGroups(ctx context.Context, accountID st
} }
for _, g := range addNewGroups { for _, g := range addNewGroups {
group, err := am.Store.GetGroupByID(ctx, LockingStrengthShare, g, accountID) group, err := am.Store.GetGroupByID(ctx, LockingStrengthShare, accountID, g)
if err != nil { if err != nil {
log.WithContext(ctx).Debugf("group %s not found while saving user activity event of account %s", g, accountID) log.WithContext(ctx).Debugf("group %s not found while saving user activity event of account %s", g, accountID)
} else { } else {
@ -2114,7 +2134,7 @@ func (am *DefaultAccountManager) syncJWTGroups(ctx context.Context, accountID st
} }
for _, g := range removeOldGroups { for _, g := range removeOldGroups {
group, err := am.Store.GetGroupByID(ctx, LockingStrengthShare, g, accountID) group, err := am.Store.GetGroupByID(ctx, LockingStrengthShare, accountID, g)
if err != nil { if err != nil {
log.WithContext(ctx).Debugf("group %s not found while saving user activity event of account %s", g, accountID) log.WithContext(ctx).Debugf("group %s not found while saving user activity event of account %s", g, accountID)
} else { } else {

View File

@ -2553,7 +2553,7 @@ func TestAccount_SetJWTGroups(t *testing.T) {
assert.NoError(t, err, "unable to get user") assert.NoError(t, err, "unable to get user")
assert.Len(t, user.AutoGroups, 0) assert.Len(t, user.AutoGroups, 0)
group1, err := manager.Store.GetGroupByID(context.Background(), LockingStrengthShare, "group1", "accountID") group1, err := manager.Store.GetGroupByID(context.Background(), LockingStrengthShare, "accountID", "group1")
assert.NoError(t, err, "unable to get group") assert.NoError(t, err, "unable to get group")
assert.Equal(t, group1.Issued, group.GroupIssuedAPI, "group should be api issued") assert.Equal(t, group1.Issued, group.GroupIssuedAPI, "group should be api issued")
}) })
@ -2573,7 +2573,7 @@ func TestAccount_SetJWTGroups(t *testing.T) {
assert.NoError(t, err, "unable to get user") assert.NoError(t, err, "unable to get user")
assert.Len(t, user.AutoGroups, 1) assert.Len(t, user.AutoGroups, 1)
group1, err := manager.Store.GetGroupByID(context.Background(), LockingStrengthShare, "group1", "accountID") group1, err := manager.Store.GetGroupByID(context.Background(), LockingStrengthShare, "accountID", "group1")
assert.NoError(t, err, "unable to get group") assert.NoError(t, err, "unable to get group")
assert.Equal(t, group1.Issued, group.GroupIssuedAPI, "group should be api issued") assert.Equal(t, group1.Issued, group.GroupIssuedAPI, "group should be api issued")
}) })

View File

@ -50,7 +50,7 @@ func (am *DefaultAccountManager) GetGroup(ctx context.Context, accountID, groupI
return nil, err return nil, err
} }
return am.Store.GetGroupByID(ctx, LockingStrengthShare, groupID, accountID) return am.Store.GetGroupByID(ctx, LockingStrengthShare, accountID, groupID)
} }
// GetAllGroups returns all groups in an account // GetAllGroups returns all groups in an account
@ -64,7 +64,7 @@ func (am *DefaultAccountManager) GetAllGroups(ctx context.Context, accountID, us
// GetGroupByName filters all groups in an account by name and returns the one with the most peers // GetGroupByName filters all groups in an account by name and returns the one with the most peers
func (am *DefaultAccountManager) GetGroupByName(ctx context.Context, groupName, accountID string) (*nbgroup.Group, error) { func (am *DefaultAccountManager) GetGroupByName(ctx context.Context, groupName, accountID string) (*nbgroup.Group, error) {
return am.Store.GetGroupByName(ctx, LockingStrengthShare, groupName, accountID) return am.Store.GetGroupByName(ctx, LockingStrengthShare, accountID, groupName)
} }
// SaveGroup object of the peers // SaveGroup object of the peers
@ -94,7 +94,7 @@ func (am *DefaultAccountManager) SaveGroups(ctx context.Context, accountID, user
} }
if newGroup.ID == "" && newGroup.Issued == nbgroup.GroupIssuedAPI { if newGroup.ID == "" && newGroup.Issued == nbgroup.GroupIssuedAPI {
existingGroup, err := am.Store.GetGroupByName(ctx, LockingStrengthShare, newGroup.Name, accountID) existingGroup, err := am.Store.GetGroupByName(ctx, LockingStrengthShare, accountID, newGroup.Name)
if err != nil { if err != nil {
s, ok := status.FromError(err) s, ok := status.FromError(err)
if !ok || s.ErrorType != status.NotFound { if !ok || s.ErrorType != status.NotFound {
@ -112,7 +112,7 @@ func (am *DefaultAccountManager) SaveGroups(ctx context.Context, accountID, user
} }
for _, peerID := range newGroup.Peers { for _, peerID := range newGroup.Peers {
if _, err = am.Store.GetPeerByID(ctx, LockingStrengthShare, peerID, accountID); err != nil { if _, err = am.Store.GetPeerByID(ctx, LockingStrengthShare, accountID, peerID); err != nil {
return status.Errorf(status.InvalidArgument, "peer with ID \"%s\" not found", peerID) return status.Errorf(status.InvalidArgument, "peer with ID \"%s\" not found", peerID)
} }
} }
@ -158,7 +158,7 @@ func (am *DefaultAccountManager) prepareGroupEvents(ctx context.Context, userID
addedPeers := make([]string, 0) addedPeers := make([]string, 0)
removedPeers := make([]string, 0) removedPeers := make([]string, 0)
oldGroup, err := am.Store.GetGroupByID(ctx, LockingStrengthShare, newGroup.ID, accountID) oldGroup, err := am.Store.GetGroupByID(ctx, LockingStrengthShare, accountID, newGroup.ID)
if err == nil && oldGroup != nil { if err == nil && oldGroup != nil {
addedPeers = difference(newGroup.Peers, oldGroup.Peers) addedPeers = difference(newGroup.Peers, oldGroup.Peers)
removedPeers = difference(oldGroup.Peers, newGroup.Peers) removedPeers = difference(oldGroup.Peers, newGroup.Peers)
@ -170,7 +170,7 @@ func (am *DefaultAccountManager) prepareGroupEvents(ctx context.Context, userID
} }
for _, peerID := range addedPeers { for _, peerID := range addedPeers {
peer, err := am.Store.GetPeerByID(ctx, LockingStrengthShare, peerID, accountID) peer, err := am.Store.GetPeerByID(ctx, LockingStrengthShare, accountID, peerID)
if err != nil { if err != nil {
log.WithContext(ctx).Errorf("peer %s not found under account %s while saving group", peerID, accountID) log.WithContext(ctx).Errorf("peer %s not found under account %s while saving group", peerID, accountID)
continue continue
@ -187,7 +187,7 @@ func (am *DefaultAccountManager) prepareGroupEvents(ctx context.Context, userID
} }
for _, peerID := range removedPeers { for _, peerID := range removedPeers {
peer, err := am.Store.GetPeerByID(ctx, LockingStrengthShare, peerID, accountID) peer, err := am.Store.GetPeerByID(ctx, LockingStrengthShare, accountID, peerID)
if err != nil { if err != nil {
log.WithContext(ctx).Errorf("peer %s not found under account %s while saving group", peerID, accountID) log.WithContext(ctx).Errorf("peer %s not found under account %s while saving group", peerID, accountID)
continue continue
@ -232,7 +232,7 @@ func (am *DefaultAccountManager) DeleteGroup(ctx context.Context, accountID, use
return status.Errorf(status.PermissionDenied, "no permission to delete group") return status.Errorf(status.PermissionDenied, "no permission to delete group")
} }
group, err := am.Store.GetGroupByID(ctx, LockingStrengthShare, groupID, accountID) group, err := am.Store.GetGroupByID(ctx, LockingStrengthShare, accountID, groupID)
if err != nil { if err != nil {
return err return err
} }
@ -288,7 +288,7 @@ func (am *DefaultAccountManager) DeleteGroups(ctx context.Context, accountID, us
) )
for _, groupID := range groupIDs { for _, groupID := range groupIDs {
group, err := am.Store.GetGroupByID(ctx, LockingStrengthShare, groupID, accountID) group, err := am.Store.GetGroupByID(ctx, LockingStrengthShare, accountID, groupID)
if err != nil { if err != nil {
continue continue
} }
@ -307,7 +307,7 @@ func (am *DefaultAccountManager) DeleteGroups(ctx context.Context, accountID, us
return fmt.Errorf("failed to increment network serial: %w", err) return fmt.Errorf("failed to increment network serial: %w", err)
} }
if err = transaction.DeleteGroups(ctx, LockingStrengthUpdate, groupIDsToDelete, accountID); err != nil { if err = transaction.DeleteGroups(ctx, LockingStrengthUpdate, accountID, groupIDsToDelete); err != nil {
return fmt.Errorf("failed to delete group: %w", err) return fmt.Errorf("failed to delete group: %w", err)
} }
return nil return nil

View File

@ -30,7 +30,7 @@ func (am *DefaultAccountManager) GetNameServerGroup(ctx context.Context, account
return nil, status.Errorf(status.PermissionDenied, "only users with admin power can view name server groups") return nil, status.Errorf(status.PermissionDenied, "only users with admin power can view name server groups")
} }
return am.Store.GetNameServerGroupByID(ctx, LockingStrengthShare, nsGroupID, accountID) return am.Store.GetNameServerGroupByID(ctx, LockingStrengthShare, accountID, nsGroupID)
} }
// CreateNameServerGroup creates and saves a new nameserver group // CreateNameServerGroup creates and saves a new nameserver group
@ -103,7 +103,7 @@ func (am *DefaultAccountManager) SaveNameServerGroup(ctx context.Context, accoun
return status.Errorf(status.PermissionDenied, "no permission to delete nameserver for this account") return status.Errorf(status.PermissionDenied, "no permission to delete nameserver for this account")
} }
_, err = am.Store.GetNameServerGroupByID(ctx, LockingStrengthShare, nsGroupToSave.ID, accountID) _, err = am.Store.GetNameServerGroupByID(ctx, LockingStrengthShare, accountID, nsGroupToSave.ID)
if err != nil { if err != nil {
return err return err
} }
@ -150,7 +150,7 @@ func (am *DefaultAccountManager) DeleteNameServerGroup(ctx context.Context, acco
return status.Errorf(status.PermissionDenied, "no permission to delete nameserver for this account") return status.Errorf(status.PermissionDenied, "no permission to delete nameserver for this account")
} }
nsGroup, err := am.Store.GetNameServerGroupByID(ctx, LockingStrengthShare, nsGroupID, accountID) nsGroup, err := am.Store.GetNameServerGroupByID(ctx, LockingStrengthShare, accountID, nsGroupID)
if err != nil { if err != nil {
return err return err
} }
@ -160,7 +160,7 @@ func (am *DefaultAccountManager) DeleteNameServerGroup(ctx context.Context, acco
return fmt.Errorf("failed to increment network serial: %w", err) return fmt.Errorf("failed to increment network serial: %w", err)
} }
if err = transaction.DeleteNameServerGroup(ctx, LockingStrengthUpdate, nsGroupID, accountID); err != nil { if err = transaction.DeleteNameServerGroup(ctx, LockingStrengthUpdate, accountID, nsGroupID); err != nil {
return fmt.Errorf("failed to delete nameserver group: %w", err) return fmt.Errorf("failed to delete nameserver group: %w", err)
} }

View File

@ -117,11 +117,11 @@ func (am *DefaultAccountManager) MarkPeerConnected(ctx context.Context, peerPubK
if peer.AddedWithSSOLogin() { if peer.AddedWithSSOLogin() {
if peer.LoginExpirationEnabled && account.Settings.PeerLoginExpirationEnabled { if peer.LoginExpirationEnabled && account.Settings.PeerLoginExpirationEnabled {
am.checkAndSchedulePeerLoginExpiration(ctx, account) am.checkAndSchedulePeerLoginExpiration(ctx, account.Id)
} }
if peer.InactivityExpirationEnabled && account.Settings.PeerInactivityExpirationEnabled { if peer.InactivityExpirationEnabled && account.Settings.PeerInactivityExpirationEnabled {
am.checkAndSchedulePeerInactivityExpiration(ctx, account) am.checkAndSchedulePeerInactivityExpiration(ctx, account.Id)
} }
} }
@ -230,7 +230,7 @@ func (am *DefaultAccountManager) UpdatePeer(ctx context.Context, accountID, user
am.StoreEvent(ctx, userID, peer.IP.String(), accountID, event, peer.EventMeta(am.GetDNSDomain())) am.StoreEvent(ctx, userID, peer.IP.String(), accountID, event, peer.EventMeta(am.GetDNSDomain()))
if peer.AddedWithSSOLogin() && peer.LoginExpirationEnabled && account.Settings.PeerLoginExpirationEnabled { if peer.AddedWithSSOLogin() && peer.LoginExpirationEnabled && account.Settings.PeerLoginExpirationEnabled {
am.checkAndSchedulePeerLoginExpiration(ctx, account) am.checkAndSchedulePeerLoginExpiration(ctx, accountID)
} }
} }
@ -249,7 +249,7 @@ func (am *DefaultAccountManager) UpdatePeer(ctx context.Context, accountID, user
am.StoreEvent(ctx, userID, peer.IP.String(), accountID, event, peer.EventMeta(am.GetDNSDomain())) am.StoreEvent(ctx, userID, peer.IP.String(), accountID, event, peer.EventMeta(am.GetDNSDomain()))
if peer.AddedWithSSOLogin() && peer.InactivityExpirationEnabled && account.Settings.PeerInactivityExpirationEnabled { if peer.AddedWithSSOLogin() && peer.InactivityExpirationEnabled && account.Settings.PeerInactivityExpirationEnabled {
am.checkAndSchedulePeerInactivityExpiration(ctx, account) am.checkAndSchedulePeerInactivityExpiration(ctx, accountID)
} }
} }
@ -537,7 +537,7 @@ func (am *DefaultAccountManager) AddPeer(ctx context.Context, setupKey, userID s
return fmt.Errorf("failed to add peer to account: %w", err) return fmt.Errorf("failed to add peer to account: %w", err)
} }
err = transaction.IncrementNetworkSerial(ctx, accountID) err = transaction.IncrementNetworkSerial(ctx, LockingStrengthUpdate, accountID)
if err != nil { if err != nil {
return fmt.Errorf("failed to increment network serial: %w", err) return fmt.Errorf("failed to increment network serial: %w", err)
} }
@ -1041,6 +1041,139 @@ func (am *DefaultAccountManager) updateAccountPeers(ctx context.Context, account
wg.Wait() wg.Wait()
} }
// getNextPeerExpiration returns the minimum duration in which the next peer of the account will expire if it was found.
// If there is no peer that expires this function returns false and a duration of 0.
// This function only considers peers that haven't been expired yet and that are connected.
func (am *DefaultAccountManager) getNextPeerExpiration(ctx context.Context, accountID string) (time.Duration, bool) {
peersWithExpiry, err := am.Store.GetAccountPeersWithExpiration(ctx, LockingStrengthShare, accountID)
if err != nil {
log.WithContext(ctx).Errorf("failed to get peers with expiration: %v", err)
return 0, false
}
if len(peersWithExpiry) == 0 {
return 0, false
}
settings, err := am.Store.GetAccountSettings(ctx, LockingStrengthShare, accountID)
if err != nil {
log.WithContext(ctx).Errorf("failed to get account settings: %v", err)
return 0, false
}
var nextExpiry *time.Duration
for _, peer := range peersWithExpiry {
// consider only connected peers because others will require login on connecting to the management server
if peer.Status.LoginExpired || !peer.Status.Connected {
continue
}
_, duration := peer.LoginExpired(settings.PeerLoginExpiration)
if nextExpiry == nil || duration < *nextExpiry {
// if expiration is below 1s return 1s duration
// this avoids issues with ticker that can't be set to < 0
if duration < time.Second {
return time.Second, true
}
nextExpiry = &duration
}
}
if nextExpiry == nil {
return 0, false
}
return *nextExpiry, true
}
// GetNextInactivePeerExpiration returns the minimum duration in which the next peer of the account will expire if it was found.
// If there is no peer that expires this function returns false and a duration of 0.
// This function only considers peers that haven't been expired yet and that are not connected.
func (am *DefaultAccountManager) getNextInactivePeerExpiration(ctx context.Context, accountID string) (time.Duration, bool) {
peersWithInactivity, err := am.Store.GetAccountPeersWithInactivity(ctx, LockingStrengthShare, accountID)
if err != nil {
log.WithContext(ctx).Errorf("failed to get peers with inactivity: %v", err)
return 0, false
}
if len(peersWithInactivity) == 0 {
return 0, false
}
settings, err := am.Store.GetAccountSettings(ctx, LockingStrengthShare, accountID)
if err != nil {
log.WithContext(ctx).Errorf("failed to get account settings: %v", err)
return 0, false
}
var nextExpiry *time.Duration
for _, peer := range peersWithInactivity {
if peer.Status.LoginExpired || peer.Status.Connected {
continue
}
_, duration := peer.SessionExpired(settings.PeerInactivityExpiration)
if nextExpiry == nil || duration < *nextExpiry {
// if expiration is below 1s return 1s duration
// this avoids issues with ticker that can't be set to < 0
if duration < time.Second {
return time.Second, true
}
nextExpiry = &duration
}
}
if nextExpiry == nil {
return 0, false
}
return *nextExpiry, true
}
// getExpiredPeers returns peers that have been expired.
func (am *DefaultAccountManager) getExpiredPeers(ctx context.Context, accountID string) ([]*nbpeer.Peer, error) {
peersWithExpiry, err := am.Store.GetAccountPeersWithExpiration(ctx, LockingStrengthShare, accountID)
if err != nil {
return nil, err
}
settings, err := am.Store.GetAccountSettings(ctx, LockingStrengthShare, accountID)
if err != nil {
return nil, err
}
var peers []*nbpeer.Peer
for _, peer := range peersWithExpiry {
expired, _ := peer.LoginExpired(settings.PeerLoginExpiration)
if expired {
peers = append(peers, peer)
}
}
return peers, nil
}
// getInactivePeers returns peers that have been expired by inactivity
func (am *DefaultAccountManager) getInactivePeers(ctx context.Context, accountID string) ([]*nbpeer.Peer, error) {
peersWithInactivity, err := am.Store.GetAccountPeersWithInactivity(ctx, LockingStrengthShare, accountID)
if err != nil {
return nil, err
}
settings, err := am.Store.GetAccountSettings(ctx, LockingStrengthShare, accountID)
if err != nil {
return nil, err
}
var peers []*nbpeer.Peer
for _, inactivePeer := range peersWithInactivity {
inactive, _ := inactivePeer.SessionExpired(settings.PeerInactivityExpiration)
if inactive {
peers = append(peers, inactivePeer)
}
}
return peers, nil
}
func ConvertSliceToMap(existingLabels []string) map[string]struct{} { func ConvertSliceToMap(existingLabels []string) map[string]struct{} {
labelMap := make(map[string]struct{}, len(existingLabels)) labelMap := make(map[string]struct{}, len(existingLabels))
for _, label := range existingLabels { for _, label := range existingLabels {

View File

@ -8,9 +8,8 @@ import (
"time" "time"
b "github.com/hashicorp/go-secure-stdlib/base62" b "github.com/hashicorp/go-secure-stdlib/base62"
"github.com/rs/xid"
"github.com/netbirdio/netbird/base62" "github.com/netbirdio/netbird/base62"
"github.com/rs/xid"
) )
const ( const (

View File

@ -335,7 +335,7 @@ func (am *DefaultAccountManager) GetPolicy(ctx context.Context, accountID, polic
return nil, status.Errorf(status.PermissionDenied, "only users with admin power are allowed to view policies") return nil, status.Errorf(status.PermissionDenied, "only users with admin power are allowed to view policies")
} }
return am.Store.GetPolicyByID(ctx, LockingStrengthShare, policyID, accountID) return am.Store.GetPolicyByID(ctx, LockingStrengthShare, accountID, policyID)
} }
// SavePolicy in the store // SavePolicy in the store

View File

@ -25,7 +25,7 @@ func (am *DefaultAccountManager) GetPostureChecks(ctx context.Context, accountID
return nil, status.Errorf(status.PermissionDenied, errMsgPostureAdminOnly) return nil, status.Errorf(status.PermissionDenied, errMsgPostureAdminOnly)
} }
return am.Store.GetPostureChecksByID(ctx, LockingStrengthShare, postureChecksID, accountID) return am.Store.GetPostureChecksByID(ctx, LockingStrengthShare, accountID, postureChecksID)
} }
// SavePostureChecks saves a posture check. // SavePostureChecks saves a posture check.
@ -49,7 +49,7 @@ func (am *DefaultAccountManager) SavePostureChecks(ctx context.Context, accountI
if isUpdate { if isUpdate {
action = activity.PostureCheckUpdated action = activity.PostureCheckUpdated
if _, err := transaction.GetPostureChecksByID(ctx, LockingStrengthShare, postureChecks.ID, accountID); err != nil { if _, err := transaction.GetPostureChecksByID(ctx, LockingStrengthShare, accountID, postureChecks.ID); err != nil {
return fmt.Errorf("failed to get posture checks: %w", err) return fmt.Errorf("failed to get posture checks: %w", err)
} }
@ -114,7 +114,7 @@ func (am *DefaultAccountManager) DeletePostureChecks(ctx context.Context, accoun
return err return err
} }
postureChecks, err := am.Store.GetPostureChecksByID(ctx, LockingStrengthShare, postureChecksID, accountID) postureChecks, err := am.Store.GetPostureChecksByID(ctx, LockingStrengthShare, accountID, postureChecksID)
if err != nil { if err != nil {
return err return err
} }
@ -124,7 +124,7 @@ func (am *DefaultAccountManager) DeletePostureChecks(ctx context.Context, accoun
return fmt.Errorf("failed to increment network serial: %w", err) return fmt.Errorf("failed to increment network serial: %w", err)
} }
if err = transaction.DeletePostureChecks(ctx, LockingStrengthUpdate, postureChecksID, accountID); err != nil { if err = transaction.DeletePostureChecks(ctx, LockingStrengthUpdate, accountID, postureChecksID); err != nil {
return fmt.Errorf("failed to delete posture checks: %w", err) return fmt.Errorf("failed to delete posture checks: %w", err)
} }
return nil return nil

View File

@ -56,13 +56,35 @@ func (am *DefaultAccountManager) GetRoute(ctx context.Context, accountID string,
return nil, status.Errorf(status.PermissionDenied, "only users with admin power can view Network Routes") return nil, status.Errorf(status.PermissionDenied, "only users with admin power can view Network Routes")
} }
return am.Store.GetRouteByID(ctx, LockingStrengthShare, string(routeID), accountID) return am.Store.GetRouteByID(ctx, LockingStrengthShare, accountID, string(routeID))
}
// GetRoutesByPrefixOrDomains return list of routes by account and route prefix
func (am *DefaultAccountManager) GetRoutesByPrefixOrDomains(accountID string, prefix netip.Prefix, domains domain.List) ([]*route.Route, error) {
accountRoutes, err := am.Store.GetAccountRoutes(context.Background(), LockingStrengthShare, accountID)
if err != nil {
return nil, err
}
routes := make([]*route.Route, 0)
for _, r := range accountRoutes {
dynamic := r.IsDynamic()
if dynamic && r.Domains.PunycodeString() == domains.PunycodeString() ||
!dynamic && r.Network.String() == prefix.String() {
routes = append(routes, r)
}
}
return routes, nil
} }
// checkRoutePrefixOrDomainsExistForPeers checks if a route with a given prefix exists for a single peer or multiple peer groups. // checkRoutePrefixOrDomainsExistForPeers checks if a route with a given prefix exists for a single peer or multiple peer groups.
func (am *DefaultAccountManager) checkRoutePrefixOrDomainsExistForPeers(account *Account, peerID string, routeID route.ID, peerGroupIDs []string, prefix netip.Prefix, domains domain.List) error { func (am *DefaultAccountManager) checkRoutePrefixOrDomainsExistForPeers(account *Account, peerID string, routeID route.ID, peerGroupIDs []string, prefix netip.Prefix, domains domain.List) error {
// routes can have both peer and peer_groups // routes can have both peer and peer_groups
routesWithPrefix := account.GetRoutesByPrefixOrDomains(prefix, domains) routesWithPrefix, err := am.GetRoutesByPrefixOrDomains(account.Id, prefix, domains)
if err != nil {
return err
}
// lets remember all the peers and the peer groups from routesWithPrefix // lets remember all the peers and the peer groups from routesWithPrefix
seenPeers := make(map[string]bool) seenPeers := make(map[string]bool)
@ -81,8 +103,8 @@ func (am *DefaultAccountManager) checkRoutePrefixOrDomainsExistForPeers(account
for _, groupID := range prefixRoute.PeerGroups { for _, groupID := range prefixRoute.PeerGroups {
seenPeerGroups[groupID] = true seenPeerGroups[groupID] = true
group := account.GetGroup(groupID) group, err := am.Store.GetGroupByID(context.Background(), LockingStrengthShare, account.Id, groupID)
if group == nil { if err != nil || group == nil {
return status.Errorf( return status.Errorf(
status.InvalidArgument, "failed to add route with %s - peer group %s doesn't exist", status.InvalidArgument, "failed to add route with %s - peer group %s doesn't exist",
getRouteDescriptor(prefix, domains), groupID, getRouteDescriptor(prefix, domains), groupID,
@ -97,10 +119,11 @@ func (am *DefaultAccountManager) checkRoutePrefixOrDomainsExistForPeers(account
if peerID != "" { if peerID != "" {
// check that peerID exists and is not in any route as single peer or part of the group // check that peerID exists and is not in any route as single peer or part of the group
peer := account.GetPeer(peerID) peer, err := am.Store.GetPeerByID(context.Background(), LockingStrengthShare, account.Id, peerID)
if peer == nil { if err != nil || peer == nil {
return status.Errorf(status.InvalidArgument, "peer with ID %s not found", peerID) return status.Errorf(status.InvalidArgument, "peer with ID %s not found", peerID)
} }
if _, ok := seenPeers[peerID]; ok { if _, ok := seenPeers[peerID]; ok {
return status.Errorf(status.AlreadyExists, return status.Errorf(status.AlreadyExists,
"failed to add route with %s - peer %s already has this route", getRouteDescriptor(prefix, domains), peerID) "failed to add route with %s - peer %s already has this route", getRouteDescriptor(prefix, domains), peerID)
@ -109,7 +132,11 @@ func (am *DefaultAccountManager) checkRoutePrefixOrDomainsExistForPeers(account
// check that peerGroupIDs are not in any route peerGroups list // check that peerGroupIDs are not in any route peerGroups list
for _, groupID := range peerGroupIDs { for _, groupID := range peerGroupIDs {
group := account.GetGroup(groupID) // we validated the group existence before entering this function, no need to check again. // we validated the group existence before entering this function, no need to check again.
group, err := am.Store.GetGroupByID(context.Background(), LockingStrengthShare, groupID, account.Id)
if err != nil || group == nil {
return status.Errorf(status.InvalidArgument, "group with ID %s not found", peerID)
}
if _, ok := seenPeerGroups[groupID]; ok { if _, ok := seenPeerGroups[groupID]; ok {
return status.Errorf( return status.Errorf(
@ -120,10 +147,11 @@ func (am *DefaultAccountManager) checkRoutePrefixOrDomainsExistForPeers(account
// check that the peers from peerGroupIDs groups are not the same peers we saw in routesWithPrefix // check that the peers from peerGroupIDs groups are not the same peers we saw in routesWithPrefix
for _, id := range group.Peers { for _, id := range group.Peers {
if _, ok := seenPeers[id]; ok { if _, ok := seenPeers[id]; ok {
peer := account.GetPeer(id) peer, err := am.Store.GetPeerByID(context.Background(), LockingStrengthShare, peerID, account.Id)
if peer == nil { if err != nil {
return status.Errorf(status.InvalidArgument, "peer with ID %s not found", peerID) return status.Errorf(status.InvalidArgument, "peer with ID %s not found", peerID)
} }
return status.Errorf(status.AlreadyExists, return status.Errorf(status.AlreadyExists,
"failed to add route with %s - peer %s from the group %s already has this route", "failed to add route with %s - peer %s from the group %s already has this route",
getRouteDescriptor(prefix, domains), peer.Name, group.Name) getRouteDescriptor(prefix, domains), peer.Name, group.Name)
@ -146,6 +174,15 @@ func (am *DefaultAccountManager) CreateRoute(ctx context.Context, accountID stri
unlock := am.Store.AcquireWriteLockByUID(ctx, accountID) unlock := am.Store.AcquireWriteLockByUID(ctx, accountID)
defer unlock() defer unlock()
user, err := am.Store.GetUserByUserID(ctx, LockingStrengthShare, userID)
if err != nil {
return nil, err
}
if user.AccountID != accountID {
return nil, status.Errorf(status.PermissionDenied, "the user has no permission to access account data")
}
account, err := am.Store.GetAccount(ctx, accountID) account, err := am.Store.GetAccount(ctx, accountID)
if err != nil { if err != nil {
return nil, err return nil, err
@ -181,17 +218,17 @@ func (am *DefaultAccountManager) CreateRoute(ctx context.Context, accountID stri
newRoute.ID = route.ID(xid.New().String()) newRoute.ID = route.ID(xid.New().String())
if len(peerGroupIDs) > 0 { if len(peerGroupIDs) > 0 {
err = validateGroups(peerGroupIDs, account.Groups) //err = validateGroups(peerGroupIDs, account.Groups)
if err != nil { //if err != nil {
return nil, err // return nil, err
} //}
} }
if len(accessControlGroupIDs) > 0 { if len(accessControlGroupIDs) > 0 {
err = validateGroups(accessControlGroupIDs, account.Groups) //err = validateGroups(accessControlGroupIDs, account.Groups)
if err != nil { //if err != nil {
return nil, err // return nil, err
} //}
} }
err = am.checkRoutePrefixOrDomainsExistForPeers(account, peerID, newRoute.ID, peerGroupIDs, prefix, domains) err = am.checkRoutePrefixOrDomainsExistForPeers(account, peerID, newRoute.ID, peerGroupIDs, prefix, domains)
@ -207,10 +244,10 @@ func (am *DefaultAccountManager) CreateRoute(ctx context.Context, accountID stri
return nil, status.Errorf(status.InvalidArgument, "identifier should be between 1 and %d", route.MaxNetIDChar) return nil, status.Errorf(status.InvalidArgument, "identifier should be between 1 and %d", route.MaxNetIDChar)
} }
err = validateGroups(groups, account.Groups) //err = validateGroups(groups, account.Groups)
if err != nil { //if err != nil {
return nil, err // return nil, err
} //}
newRoute.Peer = peerID newRoute.Peer = peerID
newRoute.PeerGroups = peerGroupIDs newRoute.PeerGroups = peerGroupIDs
@ -290,17 +327,17 @@ func (am *DefaultAccountManager) SaveRoute(ctx context.Context, accountID, userI
} }
if len(routeToSave.PeerGroups) > 0 { if len(routeToSave.PeerGroups) > 0 {
err = validateGroups(routeToSave.PeerGroups, account.Groups) //err = validateGroups(routeToSave.PeerGroups, account.Groups)
if err != nil { //if err != nil {
return err // return err
} //}
} }
if len(routeToSave.AccessControlGroups) > 0 { if len(routeToSave.AccessControlGroups) > 0 {
err = validateGroups(routeToSave.AccessControlGroups, account.Groups) //err = validateGroups(routeToSave.AccessControlGroups, account.Groups)
if err != nil { //if err != nil {
return err // return err
} //}
} }
err = am.checkRoutePrefixOrDomainsExistForPeers(account, routeToSave.Peer, routeToSave.ID, routeToSave.Copy().PeerGroups, routeToSave.Network, routeToSave.Domains) err = am.checkRoutePrefixOrDomainsExistForPeers(account, routeToSave.Peer, routeToSave.ID, routeToSave.Copy().PeerGroups, routeToSave.Network, routeToSave.Domains)
@ -308,10 +345,10 @@ func (am *DefaultAccountManager) SaveRoute(ctx context.Context, accountID, userI
return err return err
} }
err = validateGroups(routeToSave.Groups, account.Groups) //err = validateGroups(routeToSave.Groups, account.Groups)
if err != nil { //if err != nil {
return err // return err
} //}
account.Routes[routeToSave.ID] = routeToSave account.Routes[routeToSave.ID] = routeToSave

View File

@ -287,7 +287,7 @@ func (am *DefaultAccountManager) SaveSetupKey(ctx context.Context, accountID str
return nil, err return nil, err
} }
oldKey, err := am.Store.GetSetupKeyByID(ctx, LockingStrengthShare, keyToSave.Id, accountID) oldKey, err := am.Store.GetSetupKeyByID(ctx, LockingStrengthShare, accountID, keyToSave.Id)
if err != nil { if err != nil {
return nil, err return nil, err
} }
@ -382,7 +382,7 @@ func (am *DefaultAccountManager) GetSetupKey(ctx context.Context, accountID, use
return nil, status.Errorf(status.Unauthorized, "only users with admin power can view setup keys") return nil, status.Errorf(status.Unauthorized, "only users with admin power can view setup keys")
} }
setupKey, err := am.Store.GetSetupKeyByID(ctx, LockingStrengthShare, keyID, accountID) setupKey, err := am.Store.GetSetupKeyByID(ctx, LockingStrengthShare, accountID, keyID)
if err != nil { if err != nil {
return nil, err return nil, err
} }

View File

@ -541,9 +541,9 @@ func (s *SqlStore) GetUserByUserID(ctx context.Context, lockStrength LockingStre
return &user, nil return &user, nil
} }
func (s *SqlStore) GetAccountUsers(ctx context.Context, accountID string) ([]*User, error) { func (s *SqlStore) GetAccountUsers(ctx context.Context, lockStrength LockingStrength, accountID string) ([]*User, error) {
var users []*User var users []*User
result := s.db.Find(&users, accountIDCondition, accountID) result := s.db.WithContext(ctx).Clauses(clause.Locking{Strength: string(lockStrength)}).Find(&users, accountIDCondition, accountID)
if result.Error != nil { if result.Error != nil {
if errors.Is(result.Error, gorm.ErrRecordNotFound) { if errors.Is(result.Error, gorm.ErrRecordNotFound) {
return nil, status.Errorf(status.NotFound, "accountID not found: index lookup failed") return nil, status.Errorf(status.NotFound, "accountID not found: index lookup failed")
@ -828,7 +828,9 @@ func (s *SqlStore) GetPeerByPeerPubKey(ctx context.Context, lockStrength Locking
func (s *SqlStore) GetAccountSettings(ctx context.Context, lockStrength LockingStrength, accountID string) (*Settings, error) { func (s *SqlStore) GetAccountSettings(ctx context.Context, lockStrength LockingStrength, accountID string) (*Settings, error) {
var accountSettings AccountSettings var accountSettings AccountSettings
if err := s.db.WithContext(ctx).Clauses(clause.Locking{Strength: string(lockStrength)}).Model(&Account{}).Where(idQueryCondition, accountID).First(&accountSettings).Error; err != nil { result := s.db.WithContext(ctx).Clauses(clause.Locking{Strength: string(lockStrength)}).Model(&Account{}).
First(&accountSettings, idQueryCondition, accountID)
if err := result.Error; err != nil {
if errors.Is(err, gorm.ErrRecordNotFound) { if errors.Is(err, gorm.ErrRecordNotFound) {
return nil, status.Errorf(status.NotFound, "settings not found") return nil, status.Errorf(status.NotFound, "settings not found")
} }
@ -837,6 +839,21 @@ func (s *SqlStore) GetAccountSettings(ctx context.Context, lockStrength LockingS
return accountSettings.Settings, nil return accountSettings.Settings, nil
} }
// SaveAccountSettings stores the account settings in DB.
func (s *SqlStore) SaveAccountSettings(ctx context.Context, lockStrength LockingStrength, accountID string, settings *Settings) error {
result := s.db.WithContext(ctx).Debug().Clauses(clause.Locking{Strength: string(lockStrength)}).Model(&Account{}).
Select("*").Where(idQueryCondition, accountID).Updates(&AccountSettings{Settings: settings})
if result.Error != nil {
return status.Errorf(status.Internal, "failed to save account settings to store: %v", result.Error)
}
if result.RowsAffected == 0 {
return status.Errorf(status.NotFound, "account not found")
}
return nil
}
// SaveUserLastLogin stores the last login time for a user in DB. // SaveUserLastLogin stores the last login time for a user in DB.
func (s *SqlStore) SaveUserLastLogin(ctx context.Context, accountID, userID string, lastLogin time.Time) error { func (s *SqlStore) SaveUserLastLogin(ctx context.Context, accountID, userID string, lastLogin time.Time) error {
var user User var user User
@ -1054,9 +1071,72 @@ func (s *SqlStore) AddPeerToGroup(ctx context.Context, accountId string, peerId
return nil return nil
} }
// GetAccountPeers retrieves peers for an account.
func (s *SqlStore) GetAccountPeers(ctx context.Context, lockStrength LockingStrength, accountID string) ([]*nbpeer.Peer, error) {
var peers []*nbpeer.Peer
result := s.db.WithContext(ctx).Clauses(clause.Locking{Strength: string(lockStrength)}).Find(&peers, accountIDCondition, accountID)
if err := result.Error; err != nil {
log.WithContext(ctx).Errorf("failed to get peers from the store: %s", err)
return nil, status.Errorf(status.Internal, "failed to get peers from store")
}
return peers, nil
}
// GetUserPeers retrieves peers for a user. // GetUserPeers retrieves peers for a user.
func (s *SqlStore) GetUserPeers(ctx context.Context, lockStrength LockingStrength, accountID, userID string) ([]*nbpeer.Peer, error) { func (s *SqlStore) GetUserPeers(ctx context.Context, lockStrength LockingStrength, accountID, userID string) ([]*nbpeer.Peer, error) {
return getRecords[*nbpeer.Peer](s.db.WithContext(ctx).Where("user_id = ?", userID), lockStrength, accountID) var peers []*nbpeer.Peer
result := s.db.WithContext(ctx).Clauses(clause.Locking{Strength: string(lockStrength)}).Find(&peers, "account_id = ? AND user_id = ?", accountID, userID)
if err := result.Error; err != nil {
log.WithContext(ctx).Errorf("failed to get peers from the store: %s", err)
return nil, status.Errorf(status.Internal, "failed to get peers from store")
}
return peers, nil
}
// GetAccountPeersWithExpiration retrieves a list of peers that have login expiration enabled and added by a user.
func (s *SqlStore) GetAccountPeersWithExpiration(ctx context.Context, lockStrength LockingStrength, accountID string) ([]*nbpeer.Peer, error) {
var peers []*nbpeer.Peer
result := s.db.WithContext(ctx).Clauses(clause.Locking{Strength: string(lockStrength)}).
Where("login_expiration_enabled = ? AND user_id IS NOT NULL AND user_id != ''", true).
Find(&peers, accountIDCondition, accountID)
if err := result.Error; err != nil {
log.WithContext(ctx).Errorf("failed to get peers with expiration from the store: %s", result.Error)
return nil, status.Errorf(status.Internal, "failed to get peers with expiration from store")
}
return peers, nil
}
// GetAccountPeersWithInactivity retrieves a list of peers that have login expiration enabled and added by a user.
func (s *SqlStore) GetAccountPeersWithInactivity(ctx context.Context, lockStrength LockingStrength, accountID string) ([]*nbpeer.Peer, error) {
var peers []*nbpeer.Peer
result := s.db.WithContext(ctx).Clauses(clause.Locking{Strength: string(lockStrength)}).
Where("inactivity_expiration_enabled = ? AND user_id IS NOT NULL AND user_id != ''", true).
Find(&peers, accountIDCondition, accountID)
if err := result.Error; err != nil {
log.WithContext(ctx).Errorf("failed to get peers with inactivity from the store: %s", result.Error)
return nil, status.Errorf(status.Internal, "failed to get peers with inactivity from store")
}
return peers, nil
}
// GetPeerByID retrieves a peer by its ID and account ID.
func (s *SqlStore) GetPeerByID(ctx context.Context, lockStrength LockingStrength, accountID, peerID string) (*nbpeer.Peer, error) {
var peer *nbpeer.Peer
result := s.db.WithContext(ctx).Clauses(clause.Locking{Strength: string(lockStrength)}).
First(&peer, accountAndIDQueryCondition, accountID, peerID)
if err := result.Error; err != nil {
if errors.Is(result.Error, gorm.ErrRecordNotFound) {
return nil, status.Errorf(status.NotFound, "peer not found")
}
log.WithContext(ctx).Errorf("failed to get peer from the store: %s", result.Error)
return nil, status.Errorf(status.Internal, "failed to get peer from store")
}
return peer, nil
} }
func (s *SqlStore) AddPeerToAccount(ctx context.Context, peer *nbpeer.Peer) error { func (s *SqlStore) AddPeerToAccount(ctx context.Context, peer *nbpeer.Peer) error {
@ -1067,8 +1147,9 @@ func (s *SqlStore) AddPeerToAccount(ctx context.Context, peer *nbpeer.Peer) erro
return nil return nil
} }
func (s *SqlStore) IncrementNetworkSerial(ctx context.Context, accountId string) error { func (s *SqlStore) IncrementNetworkSerial(ctx context.Context, lockStrength LockingStrength, accountId string) error {
result := s.db.WithContext(ctx).Model(&Account{}).Where(idQueryCondition, accountId).Update("network_serial", gorm.Expr("network_serial + 1")) result := s.db.WithContext(ctx).Clauses(clause.Locking{Strength: string(lockStrength)}).Model(&Account{}).
Where(idQueryCondition, accountId).Update("network_serial", gorm.Expr("network_serial + 1"))
if result.Error != nil { if result.Error != nil {
return status.Errorf(status.Internal, "issue incrementing network serial count: %s", result.Error) return status.Errorf(status.Internal, "issue incrementing network serial count: %s", result.Error)
} }
@ -1113,6 +1194,19 @@ func (s *SqlStore) GetAccountDNSSettings(ctx context.Context, lockStrength Locki
return &accountDNSSettings.DNSSettings, nil return &accountDNSSettings.DNSSettings, nil
} }
// SaveDNSSettings saves the DNS settings to the store.
func (s *SqlStore) SaveDNSSettings(ctx context.Context, lockStrength LockingStrength, accountID string, settings *DNSSettings) error {
result := s.db.WithContext(ctx).Clauses(clause.Locking{Strength: string(lockStrength)}).Model(&Account{}).
Where(idQueryCondition, accountID).Updates(&AccountDNSSettings{DNSSettings: *settings})
if result.Error != nil {
return status.Errorf(status.Internal, "failed to save dns settings to store: %v", result.Error)
}
if result.RowsAffected == 0 {
return status.Errorf(status.NotFound, "account not found")
}
return nil
}
// AccountExists checks whether an account exists by the given ID. // AccountExists checks whether an account exists by the given ID.
func (s *SqlStore) AccountExists(ctx context.Context, lockStrength LockingStrength, id string) (bool, error) { func (s *SqlStore) AccountExists(ctx context.Context, lockStrength LockingStrength, id string) (bool, error) {
var accountID string var accountID string
@ -1146,16 +1240,24 @@ func (s *SqlStore) GetAccountDomainAndCategory(ctx context.Context, lockStrength
} }
// GetGroupByID retrieves a group by ID and account ID. // GetGroupByID retrieves a group by ID and account ID.
func (s *SqlStore) GetGroupByID(ctx context.Context, lockStrength LockingStrength, groupID, accountID string) (*nbgroup.Group, error) { func (s *SqlStore) GetGroupByID(ctx context.Context, lockStrength LockingStrength, accountID, groupID string) (*nbgroup.Group, error) {
return getRecordByID[nbgroup.Group](s.db.WithContext(ctx).Preload(clause.Associations), lockStrength, groupID, accountID) var group *nbgroup.Group
result := s.db.WithContext(ctx).Clauses(clause.Locking{Strength: string(lockStrength)}).
Find(&group, accountAndIDQueryCondition, accountID, groupID)
if err := result.Error; err != nil {
log.WithContext(ctx).Errorf("failed to get group from the store: %s", err)
return nil, status.Errorf(status.Internal, "failed to get group from store")
}
return group, nil
} }
// GetGroupByName retrieves a group by name and account ID. // GetGroupByName retrieves a group by name and account ID.
func (s *SqlStore) GetGroupByName(ctx context.Context, lockStrength LockingStrength, groupName, accountID string) (*nbgroup.Group, error) { func (s *SqlStore) GetGroupByName(ctx context.Context, lockStrength LockingStrength, accountID, groupName string) (*nbgroup.Group, error) {
var group nbgroup.Group var group nbgroup.Group
result := s.db.WithContext(ctx).Clauses(clause.Locking{Strength: string(lockStrength)}).Preload(clause.Associations). result := s.db.WithContext(ctx).Clauses(clause.Locking{Strength: string(lockStrength)}).Preload(clause.Associations).
Order("json_array_length(peers) DESC").First(&group, "name = ? and account_id = ?", groupName, accountID) Order("json_array_length(peers) DESC").First(&group, "account_id = ? AND name = ?", accountID, groupName)
if err := result.Error; err != nil { if err := result.Error; err != nil {
if errors.Is(result.Error, gorm.ErrRecordNotFound) { if errors.Is(result.Error, gorm.ErrRecordNotFound) {
return nil, status.Errorf(status.NotFound, "group not found") return nil, status.Errorf(status.NotFound, "group not found")
@ -1174,69 +1276,335 @@ func (s *SqlStore) SaveGroup(ctx context.Context, lockStrength LockingStrength,
return nil return nil
} }
// DeleteGroup deletes a group from the database.
func (s *SqlStore) DeleteGroup(ctx context.Context, lockStrength LockingStrength, accountID, groupID string) error {
result := s.db.WithContext(ctx).Clauses(clause.Locking{Strength: string(lockStrength)}).
Delete(&nbgroup.Group{}, accountAndIDQueryCondition, accountID, groupID)
if err := result.Error; err != nil {
log.WithContext(ctx).Errorf("failed to delete group from the store: %s", result.Error)
return status.Errorf(status.Internal, "failed to delete group from store")
}
if result.RowsAffected == 0 {
return status.Errorf(status.NotFound, "group not found")
}
return nil
}
// DeleteGroups deletes groups from the database.
func (s *SqlStore) DeleteGroups(ctx context.Context, strength LockingStrength, accountID string, groupIDs []string) error {
result := s.db.WithContext(ctx).Clauses(clause.Locking{Strength: string(strength)}).
Delete(&nbgroup.Group{}, " account_id = ? AND id IN ?", accountID, groupIDs)
if result.Error != nil {
log.WithContext(ctx).Errorf("failed to delete groups from the store: %v", result.Error)
return status.Errorf(status.Internal, "failed to delete groups from store: %v", result.Error)
}
return nil
}
// GetAccountPolicies retrieves policies for an account. // GetAccountPolicies retrieves policies for an account.
func (s *SqlStore) GetAccountPolicies(ctx context.Context, lockStrength LockingStrength, accountID string) ([]*Policy, error) { func (s *SqlStore) GetAccountPolicies(ctx context.Context, lockStrength LockingStrength, accountID string) ([]*Policy, error) {
return getRecords[*Policy](s.db.WithContext(ctx).Preload(clause.Associations), lockStrength, accountID) var policies []*Policy
result := s.db.WithContext(ctx).Clauses(clause.Locking{Strength: string(lockStrength)}).
Preload(clause.Associations).Find(&policies, accountIDCondition, accountID)
if err := result.Error; err != nil {
log.WithContext(ctx).Errorf("failed to get policies from the store: %s", result.Error)
return nil, status.Errorf(status.Internal, "failed to get policies from store")
}
return policies, nil
} }
// GetPolicyByID retrieves a policy by its ID and account ID. // GetPolicyByID retrieves a policy by its ID and account ID.
func (s *SqlStore) GetPolicyByID(ctx context.Context, lockStrength LockingStrength, policyID string, accountID string) (*Policy, error) { func (s *SqlStore) GetPolicyByID(ctx context.Context, lockStrength LockingStrength, accountID, policyID string) (*Policy, error) {
return getRecordByID[Policy](s.db.WithContext(ctx).Preload(clause.Associations), lockStrength, policyID, accountID) var policy *Policy
result := s.db.WithContext(ctx).Clauses(clause.Locking{Strength: string(lockStrength)}).
Preload(clause.Associations).Find(&policy, accountAndIDQueryCondition, accountID, policyID)
if err := result.Error; err != nil {
log.WithContext(ctx).Errorf("failed to get policy from the store: %s", err)
return nil, status.Errorf(status.Internal, "failed to get policy from store")
}
return policy, nil
}
// SavePolicy saves a policy to the database.
func (s *SqlStore) SavePolicy(ctx context.Context, lockStrength LockingStrength, policy *Policy) error {
result := s.db.WithContext(ctx).Session(&gorm.Session{FullSaveAssociations: true}).
Clauses(clause.Locking{Strength: string(lockStrength)}).Save(policy)
if err := result.Error; err != nil {
log.WithContext(ctx).Errorf("failed to save policy to the store: %s", err)
return status.Errorf(status.Internal, "failed to save policy to store")
}
return nil
}
func (s *SqlStore) DeletePolicy(ctx context.Context, lockStrength LockingStrength, policyID, accountID string) error {
result := s.db.WithContext(ctx).Clauses(clause.Locking{Strength: string(lockStrength)}).
Delete(&Policy{}, accountAndIDQueryCondition, accountID, policyID)
if err := result.Error; err != nil {
log.WithContext(ctx).Errorf("failed to delete policy from the store: %s", err)
return status.Errorf(status.Internal, "failed to delete policy from store")
}
if result.RowsAffected == 0 {
return status.Errorf(status.NotFound, "policy not found")
}
return nil
} }
// GetAccountPostureChecks retrieves posture checks for an account. // GetAccountPostureChecks retrieves posture checks for an account.
func (s *SqlStore) GetAccountPostureChecks(ctx context.Context, lockStrength LockingStrength, accountID string) ([]*posture.Checks, error) { func (s *SqlStore) GetAccountPostureChecks(ctx context.Context, lockStrength LockingStrength, accountID string) ([]*posture.Checks, error) {
return getRecords[*posture.Checks](s.db.WithContext(ctx), lockStrength, accountID) var postureChecks []*posture.Checks
result := s.db.WithContext(ctx).Clauses(clause.Locking{Strength: string(lockStrength)}).Find(&postureChecks, accountIDCondition, accountID)
if err := result.Error; err != nil {
log.WithContext(ctx).Errorf("failed to get posture checks from the store: %s", err)
return nil, status.Errorf(status.Internal, "failed to get posture checks from store")
}
return postureChecks, nil
} }
// GetPostureChecksByID retrieves posture checks by their ID and account ID. // GetPostureChecksByID retrieves posture checks by their ID and account ID.
func (s *SqlStore) GetPostureChecksByID(ctx context.Context, lockStrength LockingStrength, postureCheckID string, accountID string) (*posture.Checks, error) { func (s *SqlStore) GetPostureChecksByID(ctx context.Context, lockStrength LockingStrength, accountID, postureCheckID string) (*posture.Checks, error) {
return getRecordByID[posture.Checks](s.db.WithContext(ctx), lockStrength, postureCheckID, accountID) var postureCheck *posture.Checks
result := s.db.WithContext(ctx).Clauses(clause.Locking{Strength: string(lockStrength)}).
First(&postureCheck, accountAndIDQueryCondition, accountID, postureCheckID)
if err := result.Error; err != nil {
if errors.Is(err, gorm.ErrRecordNotFound) {
return nil, status.Errorf(status.NotFound, "posture check not found")
}
log.WithContext(ctx).Errorf("failed to get posture check from the store: %s", err)
return nil, status.Errorf(status.Internal, "failed to get posture check from store")
}
return postureCheck, nil
}
// SavePostureChecks saves a posture checks to the database.
func (s *SqlStore) SavePostureChecks(ctx context.Context, lockStrength LockingStrength, postureCheck *posture.Checks) error {
result := s.db.WithContext(ctx).Clauses(clause.Locking{Strength: string(lockStrength)}).Save(postureCheck)
if err := result.Error; err != nil {
if errors.Is(err, gorm.ErrDuplicatedKey) {
return status.Errorf(status.InvalidArgument, "name should be unique")
}
log.WithContext(ctx).Errorf("failed to save posture checks to the store: %s", err)
return status.Errorf(status.Internal, "failed to save posture checks to store")
}
return nil
}
// DeletePostureChecks deletes a posture checks from the database.
func (s *SqlStore) DeletePostureChecks(ctx context.Context, lockStrength LockingStrength, accountID, postureChecksID string) error {
result := s.db.WithContext(ctx).Clauses(clause.Locking{Strength: string(lockStrength)}).
Delete(&posture.Checks{}, accountAndIDQueryCondition, accountID, postureChecksID)
if err := result.Error; err != nil {
log.WithContext(ctx).Errorf("failed to delete posture checks from the store: %s", err)
return status.Errorf(status.Internal, "failed to delete posture checks from store")
}
if result.RowsAffected == 0 {
return status.Errorf(status.NotFound, "posture checks not found")
}
return nil
} }
// GetAccountRoutes retrieves network routes for an account. // GetAccountRoutes retrieves network routes for an account.
func (s *SqlStore) GetAccountRoutes(ctx context.Context, lockStrength LockingStrength, accountID string) ([]*route.Route, error) { func (s *SqlStore) GetAccountRoutes(ctx context.Context, lockStrength LockingStrength, accountID string) ([]*route.Route, error) {
return getRecords[*route.Route](s.db.WithContext(ctx), lockStrength, accountID) var routes []*route.Route
result := s.db.WithContext(ctx).Clauses(clause.Locking{Strength: string(lockStrength)}).
Find(&routes, accountIDCondition, accountID)
if err := result.Error; err != nil {
log.WithContext(ctx).Errorf("failed to get routes from the store: %s", err)
return nil, status.Errorf(status.Internal, "failed to get routes from store")
}
return routes, nil
} }
// GetRouteByID retrieves a route by its ID and account ID. // GetRouteByID retrieves a route by its ID and account ID.
func (s *SqlStore) GetRouteByID(ctx context.Context, lockStrength LockingStrength, routeID string, accountID string) (*route.Route, error) { func (s *SqlStore) GetRouteByID(ctx context.Context, lockStrength LockingStrength, accountID string, routeID string) (*route.Route, error) {
return getRecordByID[route.Route](s.db.WithContext(ctx), lockStrength, routeID, accountID) var route *route.Route
result := s.db.WithContext(ctx).Clauses(clause.Locking{Strength: string(lockStrength)}).
First(&route, accountAndIDQueryCondition, accountID, routeID)
if err := result.Error; err != nil {
if errors.Is(err, gorm.ErrRecordNotFound) {
return nil, status.Errorf(status.NotFound, "route not found")
}
log.WithContext(ctx).Errorf("failed to get route from the store: %s", err)
return nil, status.Errorf(status.Internal, "failed to get route from store")
}
return route, nil
}
// SaveRoute saves a route to the database.
func (s *SqlStore) SaveRoute(ctx context.Context, lockStrength LockingStrength, route *route.Route) error {
result := s.db.WithContext(ctx).Clauses(clause.Locking{Strength: string(lockStrength)}).Save(route)
if err := result.Error; err != nil {
log.WithContext(ctx).Errorf("failed to save route to the store: %s", err)
return status.Errorf(status.Internal, "failed to save route to store")
}
return nil
}
// DeleteRoute deletes a route from the database.
func (s *SqlStore) DeleteRoute(ctx context.Context, lockStrength LockingStrength, accountID, routeID string) error {
result := s.db.Clauses(clause.Locking{Strength: string(lockStrength)}).
Delete(&route.Route{}, accountAndIDQueryCondition, accountID, routeID)
if err := result.Error; err != nil {
log.WithContext(ctx).Errorf("failed to delete route from the store: %s", err)
return status.Errorf(status.Internal, "failed to delete route from store")
}
if result.RowsAffected == 0 {
return status.Errorf(status.NotFound, "route not found")
}
return nil
} }
// GetAccountSetupKeys retrieves setup keys for an account. // GetAccountSetupKeys retrieves setup keys for an account.
func (s *SqlStore) GetAccountSetupKeys(ctx context.Context, lockStrength LockingStrength, accountID string) ([]*SetupKey, error) { func (s *SqlStore) GetAccountSetupKeys(ctx context.Context, lockStrength LockingStrength, accountID string) ([]*SetupKey, error) {
return getRecords[*SetupKey](s.db.WithContext(ctx), lockStrength, accountID) var setupKeys []*SetupKey
result := s.db.WithContext(ctx).Clauses(clause.Locking{Strength: string(lockStrength)}).
Find(&setupKeys, accountIDCondition, accountID)
if err := result.Error; err != nil {
log.WithContext(ctx).Errorf("failed to get setup keys from the store: %s", err)
return nil, status.Errorf(status.Internal, "failed to get setup keys from store")
}
return setupKeys, nil
} }
// GetSetupKeyByID retrieves a setup key by its ID and account ID. // GetSetupKeyByID retrieves a setup key by its ID and account ID.
func (s *SqlStore) GetSetupKeyByID(ctx context.Context, lockStrength LockingStrength, setupKeyID string, accountID string) (*SetupKey, error) { func (s *SqlStore) GetSetupKeyByID(ctx context.Context, lockStrength LockingStrength, accountID, setupKeyID string) (*SetupKey, error) {
return getRecordByID[SetupKey](s.db.WithContext(ctx), lockStrength, setupKeyID, accountID) var setupKey *SetupKey
result := s.db.WithContext(ctx).Clauses(clause.Locking{Strength: string(lockStrength)}).
First(&setupKey, accountAndIDQueryCondition, accountID, setupKeyID)
if err := result.Error; err != nil {
if errors.Is(err, gorm.ErrRecordNotFound) {
return nil, status.Errorf(status.NotFound, "setup key not found")
}
log.WithContext(ctx).Errorf("failed to get setup key from the store: %s", err)
return nil, status.Errorf(status.Internal, "failed to get setup key from store")
}
return setupKey, nil
}
// SaveSetupKey saves a setup key to the database.
func (s *SqlStore) SaveSetupKey(ctx context.Context, lockStrength LockingStrength, setupKey *SetupKey) error {
result := s.db.WithContext(ctx).Session(&gorm.Session{FullSaveAssociations: true}).
Clauses(clause.Locking{Strength: string(lockStrength)}).Save(setupKey)
if err := result.Error; err != nil {
log.WithContext(ctx).Errorf("failed to save name server group to the store: %s", err)
return status.Errorf(status.Internal, "failed to save name server group to store")
}
return nil
} }
// GetAccountNameServerGroups retrieves name server groups for an account. // GetAccountNameServerGroups retrieves name server groups for an account.
func (s *SqlStore) GetAccountNameServerGroups(ctx context.Context, lockStrength LockingStrength, accountID string) ([]*nbdns.NameServerGroup, error) { func (s *SqlStore) GetAccountNameServerGroups(ctx context.Context, lockStrength LockingStrength, accountID string) ([]*nbdns.NameServerGroup, error) {
return getRecords[*nbdns.NameServerGroup](s.db.WithContext(ctx), lockStrength, accountID) var nsGroups []*nbdns.NameServerGroup
result := s.db.WithContext(ctx).Clauses(clause.Locking{Strength: string(lockStrength)}).Find(&nsGroups, accountIDCondition, accountID)
if err := result.Error; err != nil {
log.WithContext(ctx).Errorf("failed to get name server groups from the store: %s", err)
return nil, status.Errorf(status.Internal, "failed to get name server groups from store")
}
return nsGroups, nil
} }
// GetNameServerGroupByID retrieves a name server group by its ID and account ID. // GetNameServerGroupByID retrieves a name server group by its ID and account ID.
func (s *SqlStore) GetNameServerGroupByID(ctx context.Context, lockStrength LockingStrength, nsGroupID string, accountID string) (*nbdns.NameServerGroup, error) { func (s *SqlStore) GetNameServerGroupByID(ctx context.Context, lockStrength LockingStrength, accountID, nsGroupID string) (*nbdns.NameServerGroup, error) {
return getRecordByID[nbdns.NameServerGroup](s.db.WithContext(ctx), lockStrength, nsGroupID, accountID) var nsGroup *nbdns.NameServerGroup
} result := s.db.WithContext(ctx).Clauses(clause.Locking{Strength: string(lockStrength)}).
First(&nsGroup, accountAndIDQueryCondition, accountID, nsGroupID)
// getRecords retrieves records from the database based on the account ID.
func getRecords[T any](db *gorm.DB, lockStrength LockingStrength, accountID string) ([]T, error) {
var record []T
result := db.Clauses(clause.Locking{Strength: string(lockStrength)}).Find(&record, accountIDCondition, accountID)
if err := result.Error; err != nil { if err := result.Error; err != nil {
parts := strings.Split(fmt.Sprintf("%T", record), ".") if errors.Is(err, gorm.ErrRecordNotFound) {
recordType := parts[len(parts)-1] return nil, status.Errorf(status.NotFound, "name server group not found")
}
return nil, status.Errorf(status.Internal, "failed to get account %ss from store: %v", recordType, err) log.WithContext(ctx).Errorf("failed to get name server group from the store: %s", err)
return nil, status.Errorf(status.Internal, "failed to get name server group from store")
} }
return record, nil return nsGroup, nil
}
// SaveNameServerGroup saves a name server group to the database.
func (s *SqlStore) SaveNameServerGroup(ctx context.Context, lockStrength LockingStrength, nameServerGroup *nbdns.NameServerGroup) error {
result := s.db.WithContext(ctx).Clauses(clause.Locking{Strength: string(lockStrength)}).Save(nameServerGroup)
if err := result.Error; err != nil {
log.WithContext(ctx).Errorf("failed to save name server group to the store: %s", err)
return status.Errorf(status.Internal, "failed to save name server group to store")
}
return nil
}
// DeleteNameServerGroup deletes a name server group from the database.
func (s *SqlStore) DeleteNameServerGroup(ctx context.Context, lockStrength LockingStrength, accountID, nameServerGroupID string) error {
result := s.db.Clauses(clause.Locking{Strength: string(lockStrength)}).Delete(&nbdns.NameServerGroup{}, accountAndIDQueryCondition, accountID, nameServerGroupID)
if err := result.Error; err != nil {
log.WithContext(ctx).Errorf("failed to delete name server group from the store: %s", err)
return status.Errorf(status.Internal, "failed to delete name server group from store")
}
if result.RowsAffected == 0 {
return status.Errorf(status.NotFound, "name server group not found")
}
return nil
}
// GetPATByID retrieves a personal access token by its ID and user ID.
func (s *SqlStore) GetPATByID(ctx context.Context, lockStrength LockingStrength, patID string, userID string) (*PersonalAccessToken, error) {
var pat PersonalAccessToken
result := s.db.WithContext(ctx).Clauses(clause.Locking{Strength: string(lockStrength)}).
First(&pat, "id = ? AND user_id = ?", patID, userID)
if err := result.Error; err != nil {
if errors.Is(result.Error, gorm.ErrRecordNotFound) {
return nil, status.Errorf(status.NotFound, "PAT not found")
}
log.WithContext(ctx).Errorf("failed to get PAT from the store: %s", err)
return nil, status.Errorf(status.Internal, "failed to get PAT from store")
}
return &pat, nil
}
// SavePAT saves a personal access token to the database.
func (s *SqlStore) SavePAT(ctx context.Context, lockStrength LockingStrength, pat *PersonalAccessToken) error {
result := s.db.WithContext(ctx).Clauses(clause.Locking{Strength: string(lockStrength)}).Save(pat)
if err := result.Error; err != nil {
log.WithContext(ctx).Errorf("failed to save PAT to the store: %s", err)
return status.Errorf(status.Internal, "failed to save PAT to store")
}
return nil
}
// DeletePAT deletes a personal access token from the database.
func (s *SqlStore) DeletePAT(ctx context.Context, lockStrength LockingStrength, userID, patID string) error {
result := s.db.WithContext(ctx).Clauses(clause.Locking{Strength: string(lockStrength)}).
Delete(&PersonalAccessToken{}, "user_id = ? AND id = ?", userID, patID)
if err := result.Error; err != nil {
log.WithContext(ctx).Errorf("failed to delete PAT from the store: %s", err)
return status.Errorf(status.Internal, "failed to delete PAT from store")
}
if result.RowsAffected == 0 {
return status.Errorf(status.NotFound, "PAT not found")
}
return nil
} }
// getRecordByID retrieves a record by its ID and account ID from the database. // getRecordByID retrieves a record by its ID and account ID from the database.

View File

@ -1181,7 +1181,7 @@ func TestSqlite_CreateAndGetObjectInTransaction(t *testing.T) {
t.Fatal("failed to save group") t.Fatal("failed to save group")
return err return err
} }
group, err = transaction.GetGroupByID(context.Background(), LockingStrengthUpdate, group.ID, group.AccountID) group, err = transaction.GetGroupByID(context.Background(), LockingStrengthUpdate, group.AccountID, group.ID)
if err != nil { if err != nil {
t.Fatal("failed to get group") t.Fatal("failed to get group")
return err return err

View File

@ -56,13 +56,15 @@ type Store interface {
GetAccountIDByPrivateDomain(ctx context.Context, lockStrength LockingStrength, domain string) (string, error) GetAccountIDByPrivateDomain(ctx context.Context, lockStrength LockingStrength, domain string) (string, error)
GetAccountSettings(ctx context.Context, lockStrength LockingStrength, accountID string) (*Settings, error) GetAccountSettings(ctx context.Context, lockStrength LockingStrength, accountID string) (*Settings, error)
GetAccountDNSSettings(ctx context.Context, lockStrength LockingStrength, accountID string) (*DNSSettings, error) GetAccountDNSSettings(ctx context.Context, lockStrength LockingStrength, accountID string) (*DNSSettings, error)
SaveDNSSettings(ctx context.Context, lockStrength LockingStrength, accountID string, settings *DNSSettings) error
SaveAccountSettings(ctx context.Context, lockStrength LockingStrength, accountID string, settings *Settings) error
SaveAccount(ctx context.Context, account *Account) error SaveAccount(ctx context.Context, account *Account) error
DeleteAccount(ctx context.Context, account *Account) error DeleteAccount(ctx context.Context, account *Account) error
UpdateAccountDomainAttributes(ctx context.Context, accountID string, domain string, category string, isPrimaryDomain bool) error UpdateAccountDomainAttributes(ctx context.Context, accountID string, domain string, category string, isPrimaryDomain bool) error
GetUserByTokenID(ctx context.Context, tokenID string) (*User, error) GetUserByTokenID(ctx context.Context, tokenID string) (*User, error)
GetUserByUserID(ctx context.Context, lockStrength LockingStrength, userID string) (*User, error) GetUserByUserID(ctx context.Context, lockStrength LockingStrength, userID string) (*User, error)
GetAccountUsers(ctx context.Context, accountID string) ([]*User, error) GetAccountUsers(ctx context.Context, lockStrength LockingStrength, accountID string) ([]*User, error)
SaveUsers(accountID string, users map[string]*User) error SaveUsers(accountID string, users map[string]*User) error
SaveUser(ctx context.Context, lockStrength LockingStrength, user *User) error SaveUser(ctx context.Context, lockStrength LockingStrength, user *User) error
SaveUserLastLogin(ctx context.Context, accountID, userID string, lastLogin time.Time) error SaveUserLastLogin(ctx context.Context, accountID, userID string, lastLogin time.Time) error
@ -71,24 +73,34 @@ type Store interface {
DeleteTokenID2UserIDIndex(tokenID string) error DeleteTokenID2UserIDIndex(tokenID string) error
GetAccountGroups(ctx context.Context, accountID string) ([]*nbgroup.Group, error) GetAccountGroups(ctx context.Context, accountID string) ([]*nbgroup.Group, error)
GetGroupByID(ctx context.Context, lockStrength LockingStrength, groupID, accountID string) (*nbgroup.Group, error) GetGroupByID(ctx context.Context, lockStrength LockingStrength, accountID, groupID string) (*nbgroup.Group, error)
GetGroupByName(ctx context.Context, lockStrength LockingStrength, groupName, accountID string) (*nbgroup.Group, error) GetGroupByName(ctx context.Context, lockStrength LockingStrength, accountID, groupName string) (*nbgroup.Group, error)
SaveGroups(ctx context.Context, lockStrength LockingStrength, groups []*nbgroup.Group) error SaveGroups(ctx context.Context, lockStrength LockingStrength, groups []*nbgroup.Group) error
SaveGroup(ctx context.Context, lockStrength LockingStrength, group *nbgroup.Group) error SaveGroup(ctx context.Context, lockStrength LockingStrength, group *nbgroup.Group) error
DeleteGroup(ctx context.Context, lockStrength LockingStrength, accountID, groupID string) error
DeleteGroups(ctx context.Context, strength LockingStrength, accountID string, groupIDs []string) error
GetAccountPolicies(ctx context.Context, lockStrength LockingStrength, accountID string) ([]*Policy, error) GetAccountPolicies(ctx context.Context, lockStrength LockingStrength, accountID string) ([]*Policy, error)
GetPolicyByID(ctx context.Context, lockStrength LockingStrength, policyID string, accountID string) (*Policy, error) GetPolicyByID(ctx context.Context, lockStrength LockingStrength, accountID, policyID string) (*Policy, error)
SavePolicy(ctx context.Context, lockStrength LockingStrength, policy *Policy) error
DeletePolicy(ctx context.Context, lockStrength LockingStrength, accountID, policyID string) error
GetPostureCheckByChecksDefinition(accountID string, checks *posture.ChecksDefinition) (*posture.Checks, error) GetPostureCheckByChecksDefinition(accountID string, checks *posture.ChecksDefinition) (*posture.Checks, error)
GetAccountPostureChecks(ctx context.Context, lockStrength LockingStrength, accountID string) ([]*posture.Checks, error) GetAccountPostureChecks(ctx context.Context, lockStrength LockingStrength, accountID string) ([]*posture.Checks, error)
GetPostureChecksByID(ctx context.Context, lockStrength LockingStrength, postureCheckID string, accountID string) (*posture.Checks, error) GetPostureChecksByID(ctx context.Context, lockStrength LockingStrength, accountID, postureCheckID string) (*posture.Checks, error)
SavePostureChecks(ctx context.Context, lockStrength LockingStrength, postureCheck *posture.Checks) error
DeletePostureChecks(ctx context.Context, lockStrength LockingStrength, accountID, postureChecksID string) error
GetPeerLabelsInAccount(ctx context.Context, lockStrength LockingStrength, accountId string) ([]string, error) GetPeerLabelsInAccount(ctx context.Context, lockStrength LockingStrength, accountId string) ([]string, error)
AddPeerToAllGroup(ctx context.Context, accountID string, peerID string) error AddPeerToAllGroup(ctx context.Context, accountID string, peerID string) error
AddPeerToGroup(ctx context.Context, accountId string, peerId string, groupID string) error AddPeerToGroup(ctx context.Context, accountId string, peerId string, groupID string) error
AddPeerToAccount(ctx context.Context, peer *nbpeer.Peer) error AddPeerToAccount(ctx context.Context, peer *nbpeer.Peer) error
GetPeerByPeerPubKey(ctx context.Context, lockStrength LockingStrength, peerKey string) (*nbpeer.Peer, error) GetPeerByPeerPubKey(ctx context.Context, lockStrength LockingStrength, peerKey string) (*nbpeer.Peer, error)
GetAccountPeers(ctx context.Context, lockStrength LockingStrength, accountID string) ([]*nbpeer.Peer, error)
GetUserPeers(ctx context.Context, lockStrength LockingStrength, accountID, userID string) ([]*nbpeer.Peer, error) GetUserPeers(ctx context.Context, lockStrength LockingStrength, accountID, userID string) ([]*nbpeer.Peer, error)
GetAccountPeersWithExpiration(ctx context.Context, lockStrength LockingStrength, accountID string) ([]*nbpeer.Peer, error)
GetAccountPeersWithInactivity(ctx context.Context, lockStrength LockingStrength, accountID string) ([]*nbpeer.Peer, error)
GetPeerByID(ctx context.Context, lockStrength LockingStrength, accountID string, peerID string) (*nbpeer.Peer, error)
SavePeer(ctx context.Context, accountID string, peer *nbpeer.Peer) error SavePeer(ctx context.Context, accountID string, peer *nbpeer.Peer) error
SavePeerStatus(accountID, peerID string, status nbpeer.PeerStatus) error SavePeerStatus(accountID, peerID string, status nbpeer.PeerStatus) error
SavePeerLocation(accountID string, peer *nbpeer.Peer) error SavePeerLocation(accountID string, peer *nbpeer.Peer) error
@ -96,16 +108,25 @@ type Store interface {
GetSetupKeyBySecret(ctx context.Context, lockStrength LockingStrength, key string) (*SetupKey, error) GetSetupKeyBySecret(ctx context.Context, lockStrength LockingStrength, key string) (*SetupKey, error)
IncrementSetupKeyUsage(ctx context.Context, setupKeyID string) error IncrementSetupKeyUsage(ctx context.Context, setupKeyID string) error
GetAccountSetupKeys(ctx context.Context, lockStrength LockingStrength, accountID string) ([]*SetupKey, error) GetAccountSetupKeys(ctx context.Context, lockStrength LockingStrength, accountID string) ([]*SetupKey, error)
GetSetupKeyByID(ctx context.Context, lockStrength LockingStrength, setupKeyID string, accountID string) (*SetupKey, error) GetSetupKeyByID(ctx context.Context, lockStrength LockingStrength, accountID, setupKeyID string) (*SetupKey, error)
SaveSetupKey(ctx context.Context, lockStrength LockingStrength, setupKey *SetupKey) error
GetAccountRoutes(ctx context.Context, lockStrength LockingStrength, accountID string) ([]*route.Route, error) GetAccountRoutes(ctx context.Context, lockStrength LockingStrength, accountID string) ([]*route.Route, error)
GetRouteByID(ctx context.Context, lockStrength LockingStrength, routeID string, accountID string) (*route.Route, error) GetRouteByID(ctx context.Context, lockStrength LockingStrength, accountID, routeID string) (*route.Route, error)
SaveRoute(ctx context.Context, lockStrength LockingStrength, route *route.Route) error
DeleteRoute(ctx context.Context, lockStrength LockingStrength, accountID, routeID string) error
GetAccountNameServerGroups(ctx context.Context, lockStrength LockingStrength, accountID string) ([]*dns.NameServerGroup, error) GetAccountNameServerGroups(ctx context.Context, lockStrength LockingStrength, accountID string) ([]*dns.NameServerGroup, error)
GetNameServerGroupByID(ctx context.Context, lockStrength LockingStrength, nameServerGroupID string, accountID string) (*dns.NameServerGroup, error) GetNameServerGroupByID(ctx context.Context, lockStrength LockingStrength, accountID, nameServerGroupID string) (*dns.NameServerGroup, error)
SaveNameServerGroup(ctx context.Context, lockStrength LockingStrength, nameServerGroup *dns.NameServerGroup) error
DeleteNameServerGroup(ctx context.Context, lockStrength LockingStrength, accountID, nameServerGroupID string) error
GetPATByID(ctx context.Context, lockStrength LockingStrength, userID, patID string) (*PersonalAccessToken, error)
SavePAT(ctx context.Context, strength LockingStrength, pat *PersonalAccessToken) error
DeletePAT(ctx context.Context, strength LockingStrength, userID, patID string) error
GetTakenIPs(ctx context.Context, lockStrength LockingStrength, accountId string) ([]net.IP, error) GetTakenIPs(ctx context.Context, lockStrength LockingStrength, accountId string) ([]net.IP, error)
IncrementNetworkSerial(ctx context.Context, accountId string) error IncrementNetworkSerial(ctx context.Context, lockStrength LockingStrength, accountId string) error
GetAccountNetwork(ctx context.Context, lockStrength LockingStrength, accountId string) (*Network, error) GetAccountNetwork(ctx context.Context, lockStrength LockingStrength, accountId string) (*Network, error)
GetInstallationID() string GetInstallationID() string

View File

@ -546,9 +546,6 @@ func (am *DefaultAccountManager) InviteUser(ctx context.Context, accountID strin
// CreatePAT creates a new PAT for the given user // CreatePAT creates a new PAT for the given user
func (am *DefaultAccountManager) CreatePAT(ctx context.Context, accountID string, initiatorUserID string, targetUserID string, tokenName string, expiresIn int) (*PersonalAccessTokenGenerated, error) { func (am *DefaultAccountManager) CreatePAT(ctx context.Context, accountID string, initiatorUserID string, targetUserID string, tokenName string, expiresIn int) (*PersonalAccessTokenGenerated, error) {
unlock := am.Store.AcquireWriteLockByUID(ctx, accountID)
defer unlock()
if tokenName == "" { if tokenName == "" {
return nil, status.Errorf(status.InvalidArgument, "token name can't be empty") return nil, status.Errorf(status.InvalidArgument, "token name can't be empty")
} }
@ -557,35 +554,28 @@ func (am *DefaultAccountManager) CreatePAT(ctx context.Context, accountID string
return nil, status.Errorf(status.InvalidArgument, "expiration has to be between 1 and 365") return nil, status.Errorf(status.InvalidArgument, "expiration has to be between 1 and 365")
} }
account, err := am.Store.GetAccount(ctx, accountID) executingUser, err := am.Store.GetUserByUserID(ctx, LockingStrengthShare, initiatorUserID)
if err != nil { if err != nil {
return nil, err return nil, err
} }
targetUser, ok := account.Users[targetUserID] targetUser, err := am.Store.GetUserByUserID(ctx, LockingStrengthShare, targetUserID)
if !ok { if err != nil {
return nil, status.Errorf(status.NotFound, "user not found") return nil, err
} }
executingUser, ok := account.Users[initiatorUserID] if !(initiatorUserID == targetUserID || (executingUser.HasAdminPower() && targetUser.IsServiceUser)) ||
if !ok { executingUser.AccountID != accountID {
return nil, status.Errorf(status.NotFound, "user not found")
}
if !(initiatorUserID == targetUserID || (executingUser.HasAdminPower() && targetUser.IsServiceUser)) {
return nil, status.Errorf(status.PermissionDenied, "no permission to create PAT for this user") return nil, status.Errorf(status.PermissionDenied, "no permission to create PAT for this user")
} }
pat, err := CreateNewPAT(tokenName, expiresIn, executingUser.Id) pat, err := CreateNewPAT(tokenName, expiresIn, targetUser.Id, executingUser.Id)
if err != nil { if err != nil {
return nil, status.Errorf(status.Internal, "failed to create PAT: %v", err) return nil, status.Errorf(status.Internal, "failed to create PAT: %v", err)
} }
targetUser.PATs[pat.ID] = &pat.PersonalAccessToken if err = am.Store.SavePAT(ctx, LockingStrengthUpdate, &pat.PersonalAccessToken); err != nil {
return nil, fmt.Errorf("failed to save PAT: %w", err)
err = am.Store.SaveAccount(ctx, account)
if err != nil {
return nil, status.Errorf(status.Internal, "failed to save account: %v", err)
} }
meta := map[string]any{"name": pat.Name, "is_service_user": targetUser.IsServiceUser, "user_name": targetUser.ServiceUserName} meta := map[string]any{"name": pat.Name, "is_service_user": targetUser.IsServiceUser, "user_name": targetUser.ServiceUserName}
@ -596,51 +586,33 @@ func (am *DefaultAccountManager) CreatePAT(ctx context.Context, accountID string
// DeletePAT deletes a specific PAT from a user // DeletePAT deletes a specific PAT from a user
func (am *DefaultAccountManager) DeletePAT(ctx context.Context, accountID string, initiatorUserID string, targetUserID string, tokenID string) error { func (am *DefaultAccountManager) DeletePAT(ctx context.Context, accountID string, initiatorUserID string, targetUserID string, tokenID string) error {
unlock := am.Store.AcquireWriteLockByUID(ctx, accountID) executingUser, err := am.Store.GetUserByUserID(ctx, LockingStrengthShare, initiatorUserID)
defer unlock()
account, err := am.Store.GetAccount(ctx, accountID)
if err != nil { if err != nil {
return status.Errorf(status.NotFound, "account not found: %s", err) return err
} }
targetUser, ok := account.Users[targetUserID] targetUser, err := am.Store.GetUserByUserID(ctx, LockingStrengthShare, targetUserID)
if !ok { if err != nil {
return status.Errorf(status.NotFound, "user not found") return err
} }
executingUser, ok := account.Users[initiatorUserID] if !(initiatorUserID == targetUserID || (executingUser.HasAdminPower() && targetUser.IsServiceUser)) ||
if !ok { executingUser.AccountID != accountID {
return status.Errorf(status.NotFound, "user not found")
}
if !(initiatorUserID == targetUserID || (executingUser.HasAdminPower() && targetUser.IsServiceUser)) {
return status.Errorf(status.PermissionDenied, "no permission to delete PAT for this user") return status.Errorf(status.PermissionDenied, "no permission to delete PAT for this user")
} }
pat := targetUser.PATs[tokenID] pat, err := am.Store.GetPATByID(ctx, LockingStrengthShare, tokenID, targetUserID)
if pat == nil { if err != nil {
return status.Errorf(status.NotFound, "PAT not found") return err
} }
err = am.Store.DeleteTokenID2UserIDIndex(pat.ID) if err = am.Store.DeletePAT(ctx, LockingStrengthUpdate, tokenID, targetUserID); err != nil {
if err != nil { return fmt.Errorf("failed to delete PAT: %w", err)
return status.Errorf(status.Internal, "Failed to delete token id index: %s", err)
}
err = am.Store.DeleteHashedPAT2TokenIDIndex(pat.HashedToken)
if err != nil {
return status.Errorf(status.Internal, "Failed to delete hashed token index: %s", err)
} }
meta := map[string]any{"name": pat.Name, "is_service_user": targetUser.IsServiceUser, "user_name": targetUser.ServiceUserName} meta := map[string]any{"name": pat.Name, "is_service_user": targetUser.IsServiceUser, "user_name": targetUser.ServiceUserName}
am.StoreEvent(ctx, initiatorUserID, targetUserID, accountID, activity.PersonalAccessTokenDeleted, meta) am.StoreEvent(ctx, initiatorUserID, targetUserID, accountID, activity.PersonalAccessTokenDeleted, meta)
delete(targetUser.PATs, tokenID)
err = am.Store.SaveAccount(ctx, account)
if err != nil {
return status.Errorf(status.Internal, "Failed to save account: %s", err)
}
return nil return nil
} }
@ -651,22 +623,11 @@ func (am *DefaultAccountManager) GetPAT(ctx context.Context, accountID string, i
return nil, err return nil, err
} }
targetUser, err := am.Store.GetUserByUserID(ctx, LockingStrengthShare, targetUserID)
if err != nil {
return nil, err
}
if (initiatorUserID != targetUserID && !initiatorUser.IsAdminOrServiceUser()) || initiatorUser.AccountID != accountID { if (initiatorUserID != targetUserID && !initiatorUser.IsAdminOrServiceUser()) || initiatorUser.AccountID != accountID {
return nil, status.Errorf(status.PermissionDenied, "no permission to get PAT for this user") return nil, status.Errorf(status.PermissionDenied, "no permission to get PAT for this user")
} }
for _, pat := range targetUser.PATsG { return am.Store.GetPATByID(ctx, LockingStrengthShare, targetUserID, tokenID)
if pat.ID == tokenID {
return pat.Copy(), nil
}
}
return nil, status.Errorf(status.NotFound, "PAT not found")
} }
// GetAllPATs returns all PATs for a user // GetAllPATs returns all PATs for a user