fix merge

Signed-off-by: bcmmbaga <bethuelmbaga12@gmail.com>
This commit is contained in:
bcmmbaga
2024-10-15 18:17:47 +03:00
parent a8c8b77df8
commit 1123729c1c
14 changed files with 766 additions and 227 deletions

View File

@ -134,14 +134,14 @@ type AccountManager interface {
GetDNSSettings(ctx context.Context, accountID string, userID string) (*DNSSettings, error)
SaveDNSSettings(ctx context.Context, accountID string, userID string, dnsSettingsToSave *DNSSettings) error
GetPeer(ctx context.Context, accountID, peerID, userID string) (*nbpeer.Peer, error)
UpdateAccountSettings(ctx context.Context, accountID, userID string, newSettings *Settings) (*Account, error)
UpdateAccountSettings(ctx context.Context, accountID, userID string, newSettings *Settings) (*Settings, error)
LoginPeer(ctx context.Context, login PeerLogin) (*nbpeer.Peer, *NetworkMap, []*posture.Checks, error) // used by peer gRPC API
SyncPeer(ctx context.Context, sync PeerSync, account *Account) (*nbpeer.Peer, *NetworkMap, []*posture.Checks, error) // used by peer gRPC API
GetAllConnectedPeers() (map[string]struct{}, error)
HasConnectedChannel(peerID string) bool
GetExternalCacheManager() ExternalCacheManager
GetPostureChecks(ctx context.Context, accountID, postureChecksID, userID string) (*posture.Checks, error)
SavePostureChecks(ctx context.Context, accountID, userID string, postureChecks *posture.Checks) error
SavePostureChecks(ctx context.Context, accountID, userID string, postureChecks *posture.Checks, isUpdate bool) error
DeletePostureChecks(ctx context.Context, accountID, postureChecksID, userID string) error
ListPostureChecks(ctx context.Context, accountID, userID string) ([]*posture.Checks, error)
GetIdpManager() idp.Manager
@ -1122,7 +1122,16 @@ func (am *DefaultAccountManager) GetIdpManager() idp.Manager {
// Only users with role UserRoleAdmin can update the account.
// User that performs the update has to belong to the account.
// Returns an updated Account
func (am *DefaultAccountManager) UpdateAccountSettings(ctx context.Context, accountID, userID string, newSettings *Settings) (*Account, error) {
func (am *DefaultAccountManager) UpdateAccountSettings(ctx context.Context, accountID, userID string, newSettings *Settings) (*Settings, error) {
user, err := am.Store.GetUserByUserID(ctx, LockingStrengthShare, userID)
if err != nil {
return nil, err
}
if !user.HasAdminPower() || user.AccountID != accountID {
return nil, status.Errorf(status.PermissionDenied, "user is not allowed to update account")
}
halfYearLimit := 180 * 24 * time.Hour
if newSettings.PeerLoginExpiration > halfYearLimit {
return nil, status.Errorf(status.InvalidArgument, "peer login expiration can't be larger than 180 days")
@ -1132,78 +1141,89 @@ func (am *DefaultAccountManager) UpdateAccountSettings(ctx context.Context, acco
return nil, status.Errorf(status.InvalidArgument, "peer login expiration can't be smaller than one hour")
}
unlock := am.Store.AcquireWriteLockByUID(ctx, accountID)
defer unlock()
var oldSettings *Settings
account, err := am.Store.GetAccount(ctx, accountID)
err = am.Store.ExecuteInTransaction(ctx, func(transaction Store) error {
oldSettings, err = transaction.GetAccountSettings(ctx, LockingStrengthShare, accountID)
if err != nil {
return fmt.Errorf("failed to get account settings: %w", err)
}
if err = am.validateExtraSettings(ctx, newSettings, oldSettings, userID, accountID); err != nil {
return fmt.Errorf("failed to validate extra settings: %w", err)
}
if err = am.Store.SaveAccountSettings(ctx, LockingStrengthUpdate, accountID, newSettings); err != nil {
return fmt.Errorf("failed to update account settings: %w", err)
}
return nil
})
if err != nil {
return nil, err
}
user, err := account.FindUser(userID)
am.handlePeerLoginExpirationSettings(ctx, oldSettings, newSettings, userID, accountID)
am.handleInactivityExpirationSettings(ctx, oldSettings, newSettings, userID, accountID)
account, err := am.requestBuffer.GetAccountWithBackpressure(ctx, accountID)
if err != nil {
return nil, err
return nil, fmt.Errorf("error getting account: %w", err)
}
am.updateAccountPeers(ctx, account)
if !user.HasAdminPower() {
return nil, status.Errorf(status.PermissionDenied, "user is not allowed to update account")
}
return newSettings, nil
}
err = am.integratedPeerValidator.ValidateExtraSettings(ctx, newSettings.Extra, account.Settings.Extra, account.Peers, userID, accountID)
// validateExtraSettings validates the extra settings of the account.
func (am *DefaultAccountManager) validateExtraSettings(ctx context.Context, newSettings, oldSettings *Settings, userID, accountID string) error {
peers, err := am.Store.GetAccountPeers(ctx, LockingStrengthShare, accountID)
if err != nil {
return nil, err
return err
}
oldSettings := account.Settings
peerMap := make(map[string]*nbpeer.Peer, len(peers))
for _, peer := range peers {
peerMap[peer.ID] = peer
}
return am.integratedPeerValidator.ValidateExtraSettings(ctx, newSettings.Extra, oldSettings.Extra, peerMap, userID, accountID)
}
func (am *DefaultAccountManager) handlePeerLoginExpirationSettings(ctx context.Context, oldSettings, newSettings *Settings, userID, accountID string) {
if oldSettings.PeerLoginExpirationEnabled != newSettings.PeerLoginExpirationEnabled {
event := activity.AccountPeerLoginExpirationEnabled
if !newSettings.PeerLoginExpirationEnabled {
event = activity.AccountPeerLoginExpirationDisabled
am.peerLoginExpiry.Cancel(ctx, []string{accountID})
} else {
am.checkAndSchedulePeerLoginExpiration(ctx, account)
am.checkAndSchedulePeerLoginExpiration(ctx, accountID)
}
am.StoreEvent(ctx, userID, accountID, accountID, event, nil)
}
if oldSettings.PeerLoginExpiration != newSettings.PeerLoginExpiration {
am.StoreEvent(ctx, userID, accountID, accountID, activity.AccountPeerLoginExpirationDurationUpdated, nil)
am.checkAndSchedulePeerLoginExpiration(ctx, account)
am.checkAndSchedulePeerLoginExpiration(ctx, accountID)
}
err = am.handleInactivityExpirationSettings(ctx, account, oldSettings, newSettings, userID, accountID)
if err != nil {
return nil, err
}
updatedAccount := account.UpdateSettings(newSettings)
err = am.Store.SaveAccount(ctx, account)
if err != nil {
return nil, err
}
return updatedAccount, nil
}
func (am *DefaultAccountManager) handleInactivityExpirationSettings(ctx context.Context, account *Account, oldSettings, newSettings *Settings, userID, accountID string) error {
func (am *DefaultAccountManager) handleInactivityExpirationSettings(ctx context.Context, oldSettings, newSettings *Settings, userID, accountID string) {
if oldSettings.PeerInactivityExpirationEnabled != newSettings.PeerInactivityExpirationEnabled {
event := activity.AccountPeerInactivityExpirationEnabled
if !newSettings.PeerInactivityExpirationEnabled {
event = activity.AccountPeerInactivityExpirationDisabled
am.peerInactivityExpiry.Cancel(ctx, []string{accountID})
} else {
am.checkAndSchedulePeerInactivityExpiration(ctx, account)
am.checkAndSchedulePeerInactivityExpiration(ctx, accountID)
}
am.StoreEvent(ctx, userID, accountID, accountID, event, nil)
}
if oldSettings.PeerInactivityExpiration != newSettings.PeerInactivityExpiration {
am.StoreEvent(ctx, userID, accountID, accountID, activity.AccountPeerInactivityExpirationDurationUpdated, nil)
am.checkAndSchedulePeerInactivityExpiration(ctx, account)
am.checkAndSchedulePeerInactivityExpiration(ctx, accountID)
}
return nil
}
func (am *DefaultAccountManager) peerLoginExpirationJob(ctx context.Context, accountID string) func() (time.Duration, bool) {
@ -1234,10 +1254,10 @@ func (am *DefaultAccountManager) peerLoginExpirationJob(ctx context.Context, acc
}
}
func (am *DefaultAccountManager) checkAndSchedulePeerLoginExpiration(ctx context.Context, account *Account) {
am.peerLoginExpiry.Cancel(ctx, []string{account.Id})
if nextRun, ok := account.GetNextPeerExpiration(); ok {
go am.peerLoginExpiry.Schedule(ctx, nextRun, account.Id, am.peerLoginExpirationJob(ctx, account.Id))
func (am *DefaultAccountManager) checkAndSchedulePeerLoginExpiration(ctx context.Context, accountID string) {
am.peerLoginExpiry.Cancel(ctx, []string{accountID})
if nextRun, ok := am.getNextPeerExpiration(ctx, accountID); ok {
go am.peerLoginExpiry.Schedule(ctx, nextRun, accountID, am.peerLoginExpirationJob(ctx, accountID))
}
}
@ -1271,10 +1291,10 @@ func (am *DefaultAccountManager) peerInactivityExpirationJob(ctx context.Context
}
// checkAndSchedulePeerInactivityExpiration periodically checks for inactive peers to end their sessions
func (am *DefaultAccountManager) checkAndSchedulePeerInactivityExpiration(ctx context.Context, account *Account) {
am.peerInactivityExpiry.Cancel(ctx, []string{account.Id})
if nextRun, ok := account.GetNextInactivePeerExpiration(); ok {
go am.peerInactivityExpiry.Schedule(ctx, nextRun, account.Id, am.peerInactivityExpirationJob(ctx, account.Id))
func (am *DefaultAccountManager) checkAndSchedulePeerInactivityExpiration(ctx context.Context, accountID string) {
am.peerInactivityExpiry.Cancel(ctx, []string{accountID})
if nextRun, ok := am.getNextInactivePeerExpiration(ctx, accountID); ok {
go am.peerInactivityExpiry.Schedule(ctx, nextRun, accountID, am.peerInactivityExpirationJob(ctx, accountID))
}
}
@ -1435,7 +1455,7 @@ func isNil(i idp.Manager) bool {
// addAccountIDToIDPAppMeta update user's app metadata in idp manager
func (am *DefaultAccountManager) addAccountIDToIDPAppMeta(ctx context.Context, userID string, accountID string) error {
if !isNil(am.idpManager) {
accountUsers, err := am.Store.GetAccountUsers(ctx, accountID)
accountUsers, err := am.Store.GetAccountUsers(ctx, LockingStrengthShare, accountID)
if err != nil {
return err
}
@ -2083,7 +2103,7 @@ func (am *DefaultAccountManager) syncJWTGroups(ctx context.Context, accountID st
return fmt.Errorf("error saving groups: %w", err)
}
if err = transaction.IncrementNetworkSerial(ctx, accountID); err != nil {
if err = transaction.IncrementNetworkSerial(ctx, LockingStrengthUpdate, accountID); err != nil {
return fmt.Errorf("error incrementing network serial: %w", err)
}
}
@ -2101,7 +2121,7 @@ func (am *DefaultAccountManager) syncJWTGroups(ctx context.Context, accountID st
}
for _, g := range addNewGroups {
group, err := am.Store.GetGroupByID(ctx, LockingStrengthShare, g, accountID)
group, err := am.Store.GetGroupByID(ctx, LockingStrengthShare, accountID, g)
if err != nil {
log.WithContext(ctx).Debugf("group %s not found while saving user activity event of account %s", g, accountID)
} else {
@ -2114,7 +2134,7 @@ func (am *DefaultAccountManager) syncJWTGroups(ctx context.Context, accountID st
}
for _, g := range removeOldGroups {
group, err := am.Store.GetGroupByID(ctx, LockingStrengthShare, g, accountID)
group, err := am.Store.GetGroupByID(ctx, LockingStrengthShare, accountID, g)
if err != nil {
log.WithContext(ctx).Debugf("group %s not found while saving user activity event of account %s", g, accountID)
} else {

View File

@ -2553,7 +2553,7 @@ func TestAccount_SetJWTGroups(t *testing.T) {
assert.NoError(t, err, "unable to get user")
assert.Len(t, user.AutoGroups, 0)
group1, err := manager.Store.GetGroupByID(context.Background(), LockingStrengthShare, "group1", "accountID")
group1, err := manager.Store.GetGroupByID(context.Background(), LockingStrengthShare, "accountID", "group1")
assert.NoError(t, err, "unable to get group")
assert.Equal(t, group1.Issued, group.GroupIssuedAPI, "group should be api issued")
})
@ -2573,7 +2573,7 @@ func TestAccount_SetJWTGroups(t *testing.T) {
assert.NoError(t, err, "unable to get user")
assert.Len(t, user.AutoGroups, 1)
group1, err := manager.Store.GetGroupByID(context.Background(), LockingStrengthShare, "group1", "accountID")
group1, err := manager.Store.GetGroupByID(context.Background(), LockingStrengthShare, "accountID", "group1")
assert.NoError(t, err, "unable to get group")
assert.Equal(t, group1.Issued, group.GroupIssuedAPI, "group should be api issued")
})

View File

@ -50,7 +50,7 @@ func (am *DefaultAccountManager) GetGroup(ctx context.Context, accountID, groupI
return nil, err
}
return am.Store.GetGroupByID(ctx, LockingStrengthShare, groupID, accountID)
return am.Store.GetGroupByID(ctx, LockingStrengthShare, accountID, groupID)
}
// GetAllGroups returns all groups in an account
@ -64,7 +64,7 @@ func (am *DefaultAccountManager) GetAllGroups(ctx context.Context, accountID, us
// GetGroupByName filters all groups in an account by name and returns the one with the most peers
func (am *DefaultAccountManager) GetGroupByName(ctx context.Context, groupName, accountID string) (*nbgroup.Group, error) {
return am.Store.GetGroupByName(ctx, LockingStrengthShare, groupName, accountID)
return am.Store.GetGroupByName(ctx, LockingStrengthShare, accountID, groupName)
}
// SaveGroup object of the peers
@ -94,7 +94,7 @@ func (am *DefaultAccountManager) SaveGroups(ctx context.Context, accountID, user
}
if newGroup.ID == "" && newGroup.Issued == nbgroup.GroupIssuedAPI {
existingGroup, err := am.Store.GetGroupByName(ctx, LockingStrengthShare, newGroup.Name, accountID)
existingGroup, err := am.Store.GetGroupByName(ctx, LockingStrengthShare, accountID, newGroup.Name)
if err != nil {
s, ok := status.FromError(err)
if !ok || s.ErrorType != status.NotFound {
@ -112,7 +112,7 @@ func (am *DefaultAccountManager) SaveGroups(ctx context.Context, accountID, user
}
for _, peerID := range newGroup.Peers {
if _, err = am.Store.GetPeerByID(ctx, LockingStrengthShare, peerID, accountID); err != nil {
if _, err = am.Store.GetPeerByID(ctx, LockingStrengthShare, accountID, peerID); err != nil {
return status.Errorf(status.InvalidArgument, "peer with ID \"%s\" not found", peerID)
}
}
@ -158,7 +158,7 @@ func (am *DefaultAccountManager) prepareGroupEvents(ctx context.Context, userID
addedPeers := make([]string, 0)
removedPeers := make([]string, 0)
oldGroup, err := am.Store.GetGroupByID(ctx, LockingStrengthShare, newGroup.ID, accountID)
oldGroup, err := am.Store.GetGroupByID(ctx, LockingStrengthShare, accountID, newGroup.ID)
if err == nil && oldGroup != nil {
addedPeers = difference(newGroup.Peers, oldGroup.Peers)
removedPeers = difference(oldGroup.Peers, newGroup.Peers)
@ -170,7 +170,7 @@ func (am *DefaultAccountManager) prepareGroupEvents(ctx context.Context, userID
}
for _, peerID := range addedPeers {
peer, err := am.Store.GetPeerByID(ctx, LockingStrengthShare, peerID, accountID)
peer, err := am.Store.GetPeerByID(ctx, LockingStrengthShare, accountID, peerID)
if err != nil {
log.WithContext(ctx).Errorf("peer %s not found under account %s while saving group", peerID, accountID)
continue
@ -187,7 +187,7 @@ func (am *DefaultAccountManager) prepareGroupEvents(ctx context.Context, userID
}
for _, peerID := range removedPeers {
peer, err := am.Store.GetPeerByID(ctx, LockingStrengthShare, peerID, accountID)
peer, err := am.Store.GetPeerByID(ctx, LockingStrengthShare, accountID, peerID)
if err != nil {
log.WithContext(ctx).Errorf("peer %s not found under account %s while saving group", peerID, accountID)
continue
@ -232,7 +232,7 @@ func (am *DefaultAccountManager) DeleteGroup(ctx context.Context, accountID, use
return status.Errorf(status.PermissionDenied, "no permission to delete group")
}
group, err := am.Store.GetGroupByID(ctx, LockingStrengthShare, groupID, accountID)
group, err := am.Store.GetGroupByID(ctx, LockingStrengthShare, accountID, groupID)
if err != nil {
return err
}
@ -288,7 +288,7 @@ func (am *DefaultAccountManager) DeleteGroups(ctx context.Context, accountID, us
)
for _, groupID := range groupIDs {
group, err := am.Store.GetGroupByID(ctx, LockingStrengthShare, groupID, accountID)
group, err := am.Store.GetGroupByID(ctx, LockingStrengthShare, accountID, groupID)
if err != nil {
continue
}
@ -307,7 +307,7 @@ func (am *DefaultAccountManager) DeleteGroups(ctx context.Context, accountID, us
return fmt.Errorf("failed to increment network serial: %w", err)
}
if err = transaction.DeleteGroups(ctx, LockingStrengthUpdate, groupIDsToDelete, accountID); err != nil {
if err = transaction.DeleteGroups(ctx, LockingStrengthUpdate, accountID, groupIDsToDelete); err != nil {
return fmt.Errorf("failed to delete group: %w", err)
}
return nil

View File

@ -30,7 +30,7 @@ func (am *DefaultAccountManager) GetNameServerGroup(ctx context.Context, account
return nil, status.Errorf(status.PermissionDenied, "only users with admin power can view name server groups")
}
return am.Store.GetNameServerGroupByID(ctx, LockingStrengthShare, nsGroupID, accountID)
return am.Store.GetNameServerGroupByID(ctx, LockingStrengthShare, accountID, nsGroupID)
}
// CreateNameServerGroup creates and saves a new nameserver group
@ -103,7 +103,7 @@ func (am *DefaultAccountManager) SaveNameServerGroup(ctx context.Context, accoun
return status.Errorf(status.PermissionDenied, "no permission to delete nameserver for this account")
}
_, err = am.Store.GetNameServerGroupByID(ctx, LockingStrengthShare, nsGroupToSave.ID, accountID)
_, err = am.Store.GetNameServerGroupByID(ctx, LockingStrengthShare, accountID, nsGroupToSave.ID)
if err != nil {
return err
}
@ -150,7 +150,7 @@ func (am *DefaultAccountManager) DeleteNameServerGroup(ctx context.Context, acco
return status.Errorf(status.PermissionDenied, "no permission to delete nameserver for this account")
}
nsGroup, err := am.Store.GetNameServerGroupByID(ctx, LockingStrengthShare, nsGroupID, accountID)
nsGroup, err := am.Store.GetNameServerGroupByID(ctx, LockingStrengthShare, accountID, nsGroupID)
if err != nil {
return err
}
@ -160,7 +160,7 @@ func (am *DefaultAccountManager) DeleteNameServerGroup(ctx context.Context, acco
return fmt.Errorf("failed to increment network serial: %w", err)
}
if err = transaction.DeleteNameServerGroup(ctx, LockingStrengthUpdate, nsGroupID, accountID); err != nil {
if err = transaction.DeleteNameServerGroup(ctx, LockingStrengthUpdate, accountID, nsGroupID); err != nil {
return fmt.Errorf("failed to delete nameserver group: %w", err)
}

View File

@ -117,11 +117,11 @@ func (am *DefaultAccountManager) MarkPeerConnected(ctx context.Context, peerPubK
if peer.AddedWithSSOLogin() {
if peer.LoginExpirationEnabled && account.Settings.PeerLoginExpirationEnabled {
am.checkAndSchedulePeerLoginExpiration(ctx, account)
am.checkAndSchedulePeerLoginExpiration(ctx, account.Id)
}
if peer.InactivityExpirationEnabled && account.Settings.PeerInactivityExpirationEnabled {
am.checkAndSchedulePeerInactivityExpiration(ctx, account)
am.checkAndSchedulePeerInactivityExpiration(ctx, account.Id)
}
}
@ -230,7 +230,7 @@ func (am *DefaultAccountManager) UpdatePeer(ctx context.Context, accountID, user
am.StoreEvent(ctx, userID, peer.IP.String(), accountID, event, peer.EventMeta(am.GetDNSDomain()))
if peer.AddedWithSSOLogin() && peer.LoginExpirationEnabled && account.Settings.PeerLoginExpirationEnabled {
am.checkAndSchedulePeerLoginExpiration(ctx, account)
am.checkAndSchedulePeerLoginExpiration(ctx, accountID)
}
}
@ -249,7 +249,7 @@ func (am *DefaultAccountManager) UpdatePeer(ctx context.Context, accountID, user
am.StoreEvent(ctx, userID, peer.IP.String(), accountID, event, peer.EventMeta(am.GetDNSDomain()))
if peer.AddedWithSSOLogin() && peer.InactivityExpirationEnabled && account.Settings.PeerInactivityExpirationEnabled {
am.checkAndSchedulePeerInactivityExpiration(ctx, account)
am.checkAndSchedulePeerInactivityExpiration(ctx, accountID)
}
}
@ -537,7 +537,7 @@ func (am *DefaultAccountManager) AddPeer(ctx context.Context, setupKey, userID s
return fmt.Errorf("failed to add peer to account: %w", err)
}
err = transaction.IncrementNetworkSerial(ctx, accountID)
err = transaction.IncrementNetworkSerial(ctx, LockingStrengthUpdate, accountID)
if err != nil {
return fmt.Errorf("failed to increment network serial: %w", err)
}
@ -1041,6 +1041,139 @@ func (am *DefaultAccountManager) updateAccountPeers(ctx context.Context, account
wg.Wait()
}
// getNextPeerExpiration returns the minimum duration in which the next peer of the account will expire if it was found.
// If there is no peer that expires this function returns false and a duration of 0.
// This function only considers peers that haven't been expired yet and that are connected.
func (am *DefaultAccountManager) getNextPeerExpiration(ctx context.Context, accountID string) (time.Duration, bool) {
peersWithExpiry, err := am.Store.GetAccountPeersWithExpiration(ctx, LockingStrengthShare, accountID)
if err != nil {
log.WithContext(ctx).Errorf("failed to get peers with expiration: %v", err)
return 0, false
}
if len(peersWithExpiry) == 0 {
return 0, false
}
settings, err := am.Store.GetAccountSettings(ctx, LockingStrengthShare, accountID)
if err != nil {
log.WithContext(ctx).Errorf("failed to get account settings: %v", err)
return 0, false
}
var nextExpiry *time.Duration
for _, peer := range peersWithExpiry {
// consider only connected peers because others will require login on connecting to the management server
if peer.Status.LoginExpired || !peer.Status.Connected {
continue
}
_, duration := peer.LoginExpired(settings.PeerLoginExpiration)
if nextExpiry == nil || duration < *nextExpiry {
// if expiration is below 1s return 1s duration
// this avoids issues with ticker that can't be set to < 0
if duration < time.Second {
return time.Second, true
}
nextExpiry = &duration
}
}
if nextExpiry == nil {
return 0, false
}
return *nextExpiry, true
}
// GetNextInactivePeerExpiration returns the minimum duration in which the next peer of the account will expire if it was found.
// If there is no peer that expires this function returns false and a duration of 0.
// This function only considers peers that haven't been expired yet and that are not connected.
func (am *DefaultAccountManager) getNextInactivePeerExpiration(ctx context.Context, accountID string) (time.Duration, bool) {
peersWithInactivity, err := am.Store.GetAccountPeersWithInactivity(ctx, LockingStrengthShare, accountID)
if err != nil {
log.WithContext(ctx).Errorf("failed to get peers with inactivity: %v", err)
return 0, false
}
if len(peersWithInactivity) == 0 {
return 0, false
}
settings, err := am.Store.GetAccountSettings(ctx, LockingStrengthShare, accountID)
if err != nil {
log.WithContext(ctx).Errorf("failed to get account settings: %v", err)
return 0, false
}
var nextExpiry *time.Duration
for _, peer := range peersWithInactivity {
if peer.Status.LoginExpired || peer.Status.Connected {
continue
}
_, duration := peer.SessionExpired(settings.PeerInactivityExpiration)
if nextExpiry == nil || duration < *nextExpiry {
// if expiration is below 1s return 1s duration
// this avoids issues with ticker that can't be set to < 0
if duration < time.Second {
return time.Second, true
}
nextExpiry = &duration
}
}
if nextExpiry == nil {
return 0, false
}
return *nextExpiry, true
}
// getExpiredPeers returns peers that have been expired.
func (am *DefaultAccountManager) getExpiredPeers(ctx context.Context, accountID string) ([]*nbpeer.Peer, error) {
peersWithExpiry, err := am.Store.GetAccountPeersWithExpiration(ctx, LockingStrengthShare, accountID)
if err != nil {
return nil, err
}
settings, err := am.Store.GetAccountSettings(ctx, LockingStrengthShare, accountID)
if err != nil {
return nil, err
}
var peers []*nbpeer.Peer
for _, peer := range peersWithExpiry {
expired, _ := peer.LoginExpired(settings.PeerLoginExpiration)
if expired {
peers = append(peers, peer)
}
}
return peers, nil
}
// getInactivePeers returns peers that have been expired by inactivity
func (am *DefaultAccountManager) getInactivePeers(ctx context.Context, accountID string) ([]*nbpeer.Peer, error) {
peersWithInactivity, err := am.Store.GetAccountPeersWithInactivity(ctx, LockingStrengthShare, accountID)
if err != nil {
return nil, err
}
settings, err := am.Store.GetAccountSettings(ctx, LockingStrengthShare, accountID)
if err != nil {
return nil, err
}
var peers []*nbpeer.Peer
for _, inactivePeer := range peersWithInactivity {
inactive, _ := inactivePeer.SessionExpired(settings.PeerInactivityExpiration)
if inactive {
peers = append(peers, inactivePeer)
}
}
return peers, nil
}
func ConvertSliceToMap(existingLabels []string) map[string]struct{} {
labelMap := make(map[string]struct{}, len(existingLabels))
for _, label := range existingLabels {

View File

@ -8,9 +8,8 @@ import (
"time"
b "github.com/hashicorp/go-secure-stdlib/base62"
"github.com/rs/xid"
"github.com/netbirdio/netbird/base62"
"github.com/rs/xid"
)
const (

View File

@ -335,7 +335,7 @@ func (am *DefaultAccountManager) GetPolicy(ctx context.Context, accountID, polic
return nil, status.Errorf(status.PermissionDenied, "only users with admin power are allowed to view policies")
}
return am.Store.GetPolicyByID(ctx, LockingStrengthShare, policyID, accountID)
return am.Store.GetPolicyByID(ctx, LockingStrengthShare, accountID, policyID)
}
// SavePolicy in the store

View File

@ -25,7 +25,7 @@ func (am *DefaultAccountManager) GetPostureChecks(ctx context.Context, accountID
return nil, status.Errorf(status.PermissionDenied, errMsgPostureAdminOnly)
}
return am.Store.GetPostureChecksByID(ctx, LockingStrengthShare, postureChecksID, accountID)
return am.Store.GetPostureChecksByID(ctx, LockingStrengthShare, accountID, postureChecksID)
}
// SavePostureChecks saves a posture check.
@ -49,7 +49,7 @@ func (am *DefaultAccountManager) SavePostureChecks(ctx context.Context, accountI
if isUpdate {
action = activity.PostureCheckUpdated
if _, err := transaction.GetPostureChecksByID(ctx, LockingStrengthShare, postureChecks.ID, accountID); err != nil {
if _, err := transaction.GetPostureChecksByID(ctx, LockingStrengthShare, accountID, postureChecks.ID); err != nil {
return fmt.Errorf("failed to get posture checks: %w", err)
}
@ -114,7 +114,7 @@ func (am *DefaultAccountManager) DeletePostureChecks(ctx context.Context, accoun
return err
}
postureChecks, err := am.Store.GetPostureChecksByID(ctx, LockingStrengthShare, postureChecksID, accountID)
postureChecks, err := am.Store.GetPostureChecksByID(ctx, LockingStrengthShare, accountID, postureChecksID)
if err != nil {
return err
}
@ -124,7 +124,7 @@ func (am *DefaultAccountManager) DeletePostureChecks(ctx context.Context, accoun
return fmt.Errorf("failed to increment network serial: %w", err)
}
if err = transaction.DeletePostureChecks(ctx, LockingStrengthUpdate, postureChecksID, accountID); err != nil {
if err = transaction.DeletePostureChecks(ctx, LockingStrengthUpdate, accountID, postureChecksID); err != nil {
return fmt.Errorf("failed to delete posture checks: %w", err)
}
return nil

View File

@ -56,13 +56,35 @@ func (am *DefaultAccountManager) GetRoute(ctx context.Context, accountID string,
return nil, status.Errorf(status.PermissionDenied, "only users with admin power can view Network Routes")
}
return am.Store.GetRouteByID(ctx, LockingStrengthShare, string(routeID), accountID)
return am.Store.GetRouteByID(ctx, LockingStrengthShare, accountID, string(routeID))
}
// GetRoutesByPrefixOrDomains return list of routes by account and route prefix
func (am *DefaultAccountManager) GetRoutesByPrefixOrDomains(accountID string, prefix netip.Prefix, domains domain.List) ([]*route.Route, error) {
accountRoutes, err := am.Store.GetAccountRoutes(context.Background(), LockingStrengthShare, accountID)
if err != nil {
return nil, err
}
routes := make([]*route.Route, 0)
for _, r := range accountRoutes {
dynamic := r.IsDynamic()
if dynamic && r.Domains.PunycodeString() == domains.PunycodeString() ||
!dynamic && r.Network.String() == prefix.String() {
routes = append(routes, r)
}
}
return routes, nil
}
// checkRoutePrefixOrDomainsExistForPeers checks if a route with a given prefix exists for a single peer or multiple peer groups.
func (am *DefaultAccountManager) checkRoutePrefixOrDomainsExistForPeers(account *Account, peerID string, routeID route.ID, peerGroupIDs []string, prefix netip.Prefix, domains domain.List) error {
// routes can have both peer and peer_groups
routesWithPrefix := account.GetRoutesByPrefixOrDomains(prefix, domains)
routesWithPrefix, err := am.GetRoutesByPrefixOrDomains(account.Id, prefix, domains)
if err != nil {
return err
}
// lets remember all the peers and the peer groups from routesWithPrefix
seenPeers := make(map[string]bool)
@ -81,8 +103,8 @@ func (am *DefaultAccountManager) checkRoutePrefixOrDomainsExistForPeers(account
for _, groupID := range prefixRoute.PeerGroups {
seenPeerGroups[groupID] = true
group := account.GetGroup(groupID)
if group == nil {
group, err := am.Store.GetGroupByID(context.Background(), LockingStrengthShare, account.Id, groupID)
if err != nil || group == nil {
return status.Errorf(
status.InvalidArgument, "failed to add route with %s - peer group %s doesn't exist",
getRouteDescriptor(prefix, domains), groupID,
@ -97,10 +119,11 @@ func (am *DefaultAccountManager) checkRoutePrefixOrDomainsExistForPeers(account
if peerID != "" {
// check that peerID exists and is not in any route as single peer or part of the group
peer := account.GetPeer(peerID)
if peer == nil {
peer, err := am.Store.GetPeerByID(context.Background(), LockingStrengthShare, account.Id, peerID)
if err != nil || peer == nil {
return status.Errorf(status.InvalidArgument, "peer with ID %s not found", peerID)
}
if _, ok := seenPeers[peerID]; ok {
return status.Errorf(status.AlreadyExists,
"failed to add route with %s - peer %s already has this route", getRouteDescriptor(prefix, domains), peerID)
@ -109,7 +132,11 @@ func (am *DefaultAccountManager) checkRoutePrefixOrDomainsExistForPeers(account
// check that peerGroupIDs are not in any route peerGroups list
for _, groupID := range peerGroupIDs {
group := account.GetGroup(groupID) // we validated the group existence before entering this function, no need to check again.
// we validated the group existence before entering this function, no need to check again.
group, err := am.Store.GetGroupByID(context.Background(), LockingStrengthShare, groupID, account.Id)
if err != nil || group == nil {
return status.Errorf(status.InvalidArgument, "group with ID %s not found", peerID)
}
if _, ok := seenPeerGroups[groupID]; ok {
return status.Errorf(
@ -120,10 +147,11 @@ func (am *DefaultAccountManager) checkRoutePrefixOrDomainsExistForPeers(account
// check that the peers from peerGroupIDs groups are not the same peers we saw in routesWithPrefix
for _, id := range group.Peers {
if _, ok := seenPeers[id]; ok {
peer := account.GetPeer(id)
if peer == nil {
peer, err := am.Store.GetPeerByID(context.Background(), LockingStrengthShare, peerID, account.Id)
if err != nil {
return status.Errorf(status.InvalidArgument, "peer with ID %s not found", peerID)
}
return status.Errorf(status.AlreadyExists,
"failed to add route with %s - peer %s from the group %s already has this route",
getRouteDescriptor(prefix, domains), peer.Name, group.Name)
@ -146,6 +174,15 @@ func (am *DefaultAccountManager) CreateRoute(ctx context.Context, accountID stri
unlock := am.Store.AcquireWriteLockByUID(ctx, accountID)
defer unlock()
user, err := am.Store.GetUserByUserID(ctx, LockingStrengthShare, userID)
if err != nil {
return nil, err
}
if user.AccountID != accountID {
return nil, status.Errorf(status.PermissionDenied, "the user has no permission to access account data")
}
account, err := am.Store.GetAccount(ctx, accountID)
if err != nil {
return nil, err
@ -181,17 +218,17 @@ func (am *DefaultAccountManager) CreateRoute(ctx context.Context, accountID stri
newRoute.ID = route.ID(xid.New().String())
if len(peerGroupIDs) > 0 {
err = validateGroups(peerGroupIDs, account.Groups)
if err != nil {
return nil, err
}
//err = validateGroups(peerGroupIDs, account.Groups)
//if err != nil {
// return nil, err
//}
}
if len(accessControlGroupIDs) > 0 {
err = validateGroups(accessControlGroupIDs, account.Groups)
if err != nil {
return nil, err
}
//err = validateGroups(accessControlGroupIDs, account.Groups)
//if err != nil {
// return nil, err
//}
}
err = am.checkRoutePrefixOrDomainsExistForPeers(account, peerID, newRoute.ID, peerGroupIDs, prefix, domains)
@ -207,10 +244,10 @@ func (am *DefaultAccountManager) CreateRoute(ctx context.Context, accountID stri
return nil, status.Errorf(status.InvalidArgument, "identifier should be between 1 and %d", route.MaxNetIDChar)
}
err = validateGroups(groups, account.Groups)
if err != nil {
return nil, err
}
//err = validateGroups(groups, account.Groups)
//if err != nil {
// return nil, err
//}
newRoute.Peer = peerID
newRoute.PeerGroups = peerGroupIDs
@ -290,17 +327,17 @@ func (am *DefaultAccountManager) SaveRoute(ctx context.Context, accountID, userI
}
if len(routeToSave.PeerGroups) > 0 {
err = validateGroups(routeToSave.PeerGroups, account.Groups)
if err != nil {
return err
}
//err = validateGroups(routeToSave.PeerGroups, account.Groups)
//if err != nil {
// return err
//}
}
if len(routeToSave.AccessControlGroups) > 0 {
err = validateGroups(routeToSave.AccessControlGroups, account.Groups)
if err != nil {
return err
}
//err = validateGroups(routeToSave.AccessControlGroups, account.Groups)
//if err != nil {
// return err
//}
}
err = am.checkRoutePrefixOrDomainsExistForPeers(account, routeToSave.Peer, routeToSave.ID, routeToSave.Copy().PeerGroups, routeToSave.Network, routeToSave.Domains)
@ -308,10 +345,10 @@ func (am *DefaultAccountManager) SaveRoute(ctx context.Context, accountID, userI
return err
}
err = validateGroups(routeToSave.Groups, account.Groups)
if err != nil {
return err
}
//err = validateGroups(routeToSave.Groups, account.Groups)
//if err != nil {
// return err
//}
account.Routes[routeToSave.ID] = routeToSave

View File

@ -287,7 +287,7 @@ func (am *DefaultAccountManager) SaveSetupKey(ctx context.Context, accountID str
return nil, err
}
oldKey, err := am.Store.GetSetupKeyByID(ctx, LockingStrengthShare, keyToSave.Id, accountID)
oldKey, err := am.Store.GetSetupKeyByID(ctx, LockingStrengthShare, accountID, keyToSave.Id)
if err != nil {
return nil, err
}
@ -382,7 +382,7 @@ func (am *DefaultAccountManager) GetSetupKey(ctx context.Context, accountID, use
return nil, status.Errorf(status.Unauthorized, "only users with admin power can view setup keys")
}
setupKey, err := am.Store.GetSetupKeyByID(ctx, LockingStrengthShare, keyID, accountID)
setupKey, err := am.Store.GetSetupKeyByID(ctx, LockingStrengthShare, accountID, keyID)
if err != nil {
return nil, err
}

View File

@ -541,9 +541,9 @@ func (s *SqlStore) GetUserByUserID(ctx context.Context, lockStrength LockingStre
return &user, nil
}
func (s *SqlStore) GetAccountUsers(ctx context.Context, accountID string) ([]*User, error) {
func (s *SqlStore) GetAccountUsers(ctx context.Context, lockStrength LockingStrength, accountID string) ([]*User, error) {
var users []*User
result := s.db.Find(&users, accountIDCondition, accountID)
result := s.db.WithContext(ctx).Clauses(clause.Locking{Strength: string(lockStrength)}).Find(&users, accountIDCondition, accountID)
if result.Error != nil {
if errors.Is(result.Error, gorm.ErrRecordNotFound) {
return nil, status.Errorf(status.NotFound, "accountID not found: index lookup failed")
@ -828,7 +828,9 @@ func (s *SqlStore) GetPeerByPeerPubKey(ctx context.Context, lockStrength Locking
func (s *SqlStore) GetAccountSettings(ctx context.Context, lockStrength LockingStrength, accountID string) (*Settings, error) {
var accountSettings AccountSettings
if err := s.db.WithContext(ctx).Clauses(clause.Locking{Strength: string(lockStrength)}).Model(&Account{}).Where(idQueryCondition, accountID).First(&accountSettings).Error; err != nil {
result := s.db.WithContext(ctx).Clauses(clause.Locking{Strength: string(lockStrength)}).Model(&Account{}).
First(&accountSettings, idQueryCondition, accountID)
if err := result.Error; err != nil {
if errors.Is(err, gorm.ErrRecordNotFound) {
return nil, status.Errorf(status.NotFound, "settings not found")
}
@ -837,6 +839,21 @@ func (s *SqlStore) GetAccountSettings(ctx context.Context, lockStrength LockingS
return accountSettings.Settings, nil
}
// SaveAccountSettings stores the account settings in DB.
func (s *SqlStore) SaveAccountSettings(ctx context.Context, lockStrength LockingStrength, accountID string, settings *Settings) error {
result := s.db.WithContext(ctx).Debug().Clauses(clause.Locking{Strength: string(lockStrength)}).Model(&Account{}).
Select("*").Where(idQueryCondition, accountID).Updates(&AccountSettings{Settings: settings})
if result.Error != nil {
return status.Errorf(status.Internal, "failed to save account settings to store: %v", result.Error)
}
if result.RowsAffected == 0 {
return status.Errorf(status.NotFound, "account not found")
}
return nil
}
// SaveUserLastLogin stores the last login time for a user in DB.
func (s *SqlStore) SaveUserLastLogin(ctx context.Context, accountID, userID string, lastLogin time.Time) error {
var user User
@ -1054,9 +1071,72 @@ func (s *SqlStore) AddPeerToGroup(ctx context.Context, accountId string, peerId
return nil
}
// GetAccountPeers retrieves peers for an account.
func (s *SqlStore) GetAccountPeers(ctx context.Context, lockStrength LockingStrength, accountID string) ([]*nbpeer.Peer, error) {
var peers []*nbpeer.Peer
result := s.db.WithContext(ctx).Clauses(clause.Locking{Strength: string(lockStrength)}).Find(&peers, accountIDCondition, accountID)
if err := result.Error; err != nil {
log.WithContext(ctx).Errorf("failed to get peers from the store: %s", err)
return nil, status.Errorf(status.Internal, "failed to get peers from store")
}
return peers, nil
}
// GetUserPeers retrieves peers for a user.
func (s *SqlStore) GetUserPeers(ctx context.Context, lockStrength LockingStrength, accountID, userID string) ([]*nbpeer.Peer, error) {
return getRecords[*nbpeer.Peer](s.db.WithContext(ctx).Where("user_id = ?", userID), lockStrength, accountID)
var peers []*nbpeer.Peer
result := s.db.WithContext(ctx).Clauses(clause.Locking{Strength: string(lockStrength)}).Find(&peers, "account_id = ? AND user_id = ?", accountID, userID)
if err := result.Error; err != nil {
log.WithContext(ctx).Errorf("failed to get peers from the store: %s", err)
return nil, status.Errorf(status.Internal, "failed to get peers from store")
}
return peers, nil
}
// GetAccountPeersWithExpiration retrieves a list of peers that have login expiration enabled and added by a user.
func (s *SqlStore) GetAccountPeersWithExpiration(ctx context.Context, lockStrength LockingStrength, accountID string) ([]*nbpeer.Peer, error) {
var peers []*nbpeer.Peer
result := s.db.WithContext(ctx).Clauses(clause.Locking{Strength: string(lockStrength)}).
Where("login_expiration_enabled = ? AND user_id IS NOT NULL AND user_id != ''", true).
Find(&peers, accountIDCondition, accountID)
if err := result.Error; err != nil {
log.WithContext(ctx).Errorf("failed to get peers with expiration from the store: %s", result.Error)
return nil, status.Errorf(status.Internal, "failed to get peers with expiration from store")
}
return peers, nil
}
// GetAccountPeersWithInactivity retrieves a list of peers that have login expiration enabled and added by a user.
func (s *SqlStore) GetAccountPeersWithInactivity(ctx context.Context, lockStrength LockingStrength, accountID string) ([]*nbpeer.Peer, error) {
var peers []*nbpeer.Peer
result := s.db.WithContext(ctx).Clauses(clause.Locking{Strength: string(lockStrength)}).
Where("inactivity_expiration_enabled = ? AND user_id IS NOT NULL AND user_id != ''", true).
Find(&peers, accountIDCondition, accountID)
if err := result.Error; err != nil {
log.WithContext(ctx).Errorf("failed to get peers with inactivity from the store: %s", result.Error)
return nil, status.Errorf(status.Internal, "failed to get peers with inactivity from store")
}
return peers, nil
}
// GetPeerByID retrieves a peer by its ID and account ID.
func (s *SqlStore) GetPeerByID(ctx context.Context, lockStrength LockingStrength, accountID, peerID string) (*nbpeer.Peer, error) {
var peer *nbpeer.Peer
result := s.db.WithContext(ctx).Clauses(clause.Locking{Strength: string(lockStrength)}).
First(&peer, accountAndIDQueryCondition, accountID, peerID)
if err := result.Error; err != nil {
if errors.Is(result.Error, gorm.ErrRecordNotFound) {
return nil, status.Errorf(status.NotFound, "peer not found")
}
log.WithContext(ctx).Errorf("failed to get peer from the store: %s", result.Error)
return nil, status.Errorf(status.Internal, "failed to get peer from store")
}
return peer, nil
}
func (s *SqlStore) AddPeerToAccount(ctx context.Context, peer *nbpeer.Peer) error {
@ -1067,8 +1147,9 @@ func (s *SqlStore) AddPeerToAccount(ctx context.Context, peer *nbpeer.Peer) erro
return nil
}
func (s *SqlStore) IncrementNetworkSerial(ctx context.Context, accountId string) error {
result := s.db.WithContext(ctx).Model(&Account{}).Where(idQueryCondition, accountId).Update("network_serial", gorm.Expr("network_serial + 1"))
func (s *SqlStore) IncrementNetworkSerial(ctx context.Context, lockStrength LockingStrength, accountId string) error {
result := s.db.WithContext(ctx).Clauses(clause.Locking{Strength: string(lockStrength)}).Model(&Account{}).
Where(idQueryCondition, accountId).Update("network_serial", gorm.Expr("network_serial + 1"))
if result.Error != nil {
return status.Errorf(status.Internal, "issue incrementing network serial count: %s", result.Error)
}
@ -1113,6 +1194,19 @@ func (s *SqlStore) GetAccountDNSSettings(ctx context.Context, lockStrength Locki
return &accountDNSSettings.DNSSettings, nil
}
// SaveDNSSettings saves the DNS settings to the store.
func (s *SqlStore) SaveDNSSettings(ctx context.Context, lockStrength LockingStrength, accountID string, settings *DNSSettings) error {
result := s.db.WithContext(ctx).Clauses(clause.Locking{Strength: string(lockStrength)}).Model(&Account{}).
Where(idQueryCondition, accountID).Updates(&AccountDNSSettings{DNSSettings: *settings})
if result.Error != nil {
return status.Errorf(status.Internal, "failed to save dns settings to store: %v", result.Error)
}
if result.RowsAffected == 0 {
return status.Errorf(status.NotFound, "account not found")
}
return nil
}
// AccountExists checks whether an account exists by the given ID.
func (s *SqlStore) AccountExists(ctx context.Context, lockStrength LockingStrength, id string) (bool, error) {
var accountID string
@ -1146,16 +1240,24 @@ func (s *SqlStore) GetAccountDomainAndCategory(ctx context.Context, lockStrength
}
// GetGroupByID retrieves a group by ID and account ID.
func (s *SqlStore) GetGroupByID(ctx context.Context, lockStrength LockingStrength, groupID, accountID string) (*nbgroup.Group, error) {
return getRecordByID[nbgroup.Group](s.db.WithContext(ctx).Preload(clause.Associations), lockStrength, groupID, accountID)
func (s *SqlStore) GetGroupByID(ctx context.Context, lockStrength LockingStrength, accountID, groupID string) (*nbgroup.Group, error) {
var group *nbgroup.Group
result := s.db.WithContext(ctx).Clauses(clause.Locking{Strength: string(lockStrength)}).
Find(&group, accountAndIDQueryCondition, accountID, groupID)
if err := result.Error; err != nil {
log.WithContext(ctx).Errorf("failed to get group from the store: %s", err)
return nil, status.Errorf(status.Internal, "failed to get group from store")
}
return group, nil
}
// GetGroupByName retrieves a group by name and account ID.
func (s *SqlStore) GetGroupByName(ctx context.Context, lockStrength LockingStrength, groupName, accountID string) (*nbgroup.Group, error) {
func (s *SqlStore) GetGroupByName(ctx context.Context, lockStrength LockingStrength, accountID, groupName string) (*nbgroup.Group, error) {
var group nbgroup.Group
result := s.db.WithContext(ctx).Clauses(clause.Locking{Strength: string(lockStrength)}).Preload(clause.Associations).
Order("json_array_length(peers) DESC").First(&group, "name = ? and account_id = ?", groupName, accountID)
Order("json_array_length(peers) DESC").First(&group, "account_id = ? AND name = ?", accountID, groupName)
if err := result.Error; err != nil {
if errors.Is(result.Error, gorm.ErrRecordNotFound) {
return nil, status.Errorf(status.NotFound, "group not found")
@ -1174,69 +1276,335 @@ func (s *SqlStore) SaveGroup(ctx context.Context, lockStrength LockingStrength,
return nil
}
// DeleteGroup deletes a group from the database.
func (s *SqlStore) DeleteGroup(ctx context.Context, lockStrength LockingStrength, accountID, groupID string) error {
result := s.db.WithContext(ctx).Clauses(clause.Locking{Strength: string(lockStrength)}).
Delete(&nbgroup.Group{}, accountAndIDQueryCondition, accountID, groupID)
if err := result.Error; err != nil {
log.WithContext(ctx).Errorf("failed to delete group from the store: %s", result.Error)
return status.Errorf(status.Internal, "failed to delete group from store")
}
if result.RowsAffected == 0 {
return status.Errorf(status.NotFound, "group not found")
}
return nil
}
// DeleteGroups deletes groups from the database.
func (s *SqlStore) DeleteGroups(ctx context.Context, strength LockingStrength, accountID string, groupIDs []string) error {
result := s.db.WithContext(ctx).Clauses(clause.Locking{Strength: string(strength)}).
Delete(&nbgroup.Group{}, " account_id = ? AND id IN ?", accountID, groupIDs)
if result.Error != nil {
log.WithContext(ctx).Errorf("failed to delete groups from the store: %v", result.Error)
return status.Errorf(status.Internal, "failed to delete groups from store: %v", result.Error)
}
return nil
}
// GetAccountPolicies retrieves policies for an account.
func (s *SqlStore) GetAccountPolicies(ctx context.Context, lockStrength LockingStrength, accountID string) ([]*Policy, error) {
return getRecords[*Policy](s.db.WithContext(ctx).Preload(clause.Associations), lockStrength, accountID)
var policies []*Policy
result := s.db.WithContext(ctx).Clauses(clause.Locking{Strength: string(lockStrength)}).
Preload(clause.Associations).Find(&policies, accountIDCondition, accountID)
if err := result.Error; err != nil {
log.WithContext(ctx).Errorf("failed to get policies from the store: %s", result.Error)
return nil, status.Errorf(status.Internal, "failed to get policies from store")
}
return policies, nil
}
// GetPolicyByID retrieves a policy by its ID and account ID.
func (s *SqlStore) GetPolicyByID(ctx context.Context, lockStrength LockingStrength, policyID string, accountID string) (*Policy, error) {
return getRecordByID[Policy](s.db.WithContext(ctx).Preload(clause.Associations), lockStrength, policyID, accountID)
func (s *SqlStore) GetPolicyByID(ctx context.Context, lockStrength LockingStrength, accountID, policyID string) (*Policy, error) {
var policy *Policy
result := s.db.WithContext(ctx).Clauses(clause.Locking{Strength: string(lockStrength)}).
Preload(clause.Associations).Find(&policy, accountAndIDQueryCondition, accountID, policyID)
if err := result.Error; err != nil {
log.WithContext(ctx).Errorf("failed to get policy from the store: %s", err)
return nil, status.Errorf(status.Internal, "failed to get policy from store")
}
return policy, nil
}
// SavePolicy saves a policy to the database.
func (s *SqlStore) SavePolicy(ctx context.Context, lockStrength LockingStrength, policy *Policy) error {
result := s.db.WithContext(ctx).Session(&gorm.Session{FullSaveAssociations: true}).
Clauses(clause.Locking{Strength: string(lockStrength)}).Save(policy)
if err := result.Error; err != nil {
log.WithContext(ctx).Errorf("failed to save policy to the store: %s", err)
return status.Errorf(status.Internal, "failed to save policy to store")
}
return nil
}
func (s *SqlStore) DeletePolicy(ctx context.Context, lockStrength LockingStrength, policyID, accountID string) error {
result := s.db.WithContext(ctx).Clauses(clause.Locking{Strength: string(lockStrength)}).
Delete(&Policy{}, accountAndIDQueryCondition, accountID, policyID)
if err := result.Error; err != nil {
log.WithContext(ctx).Errorf("failed to delete policy from the store: %s", err)
return status.Errorf(status.Internal, "failed to delete policy from store")
}
if result.RowsAffected == 0 {
return status.Errorf(status.NotFound, "policy not found")
}
return nil
}
// GetAccountPostureChecks retrieves posture checks for an account.
func (s *SqlStore) GetAccountPostureChecks(ctx context.Context, lockStrength LockingStrength, accountID string) ([]*posture.Checks, error) {
return getRecords[*posture.Checks](s.db.WithContext(ctx), lockStrength, accountID)
var postureChecks []*posture.Checks
result := s.db.WithContext(ctx).Clauses(clause.Locking{Strength: string(lockStrength)}).Find(&postureChecks, accountIDCondition, accountID)
if err := result.Error; err != nil {
log.WithContext(ctx).Errorf("failed to get posture checks from the store: %s", err)
return nil, status.Errorf(status.Internal, "failed to get posture checks from store")
}
return postureChecks, nil
}
// GetPostureChecksByID retrieves posture checks by their ID and account ID.
func (s *SqlStore) GetPostureChecksByID(ctx context.Context, lockStrength LockingStrength, postureCheckID string, accountID string) (*posture.Checks, error) {
return getRecordByID[posture.Checks](s.db.WithContext(ctx), lockStrength, postureCheckID, accountID)
func (s *SqlStore) GetPostureChecksByID(ctx context.Context, lockStrength LockingStrength, accountID, postureCheckID string) (*posture.Checks, error) {
var postureCheck *posture.Checks
result := s.db.WithContext(ctx).Clauses(clause.Locking{Strength: string(lockStrength)}).
First(&postureCheck, accountAndIDQueryCondition, accountID, postureCheckID)
if err := result.Error; err != nil {
if errors.Is(err, gorm.ErrRecordNotFound) {
return nil, status.Errorf(status.NotFound, "posture check not found")
}
log.WithContext(ctx).Errorf("failed to get posture check from the store: %s", err)
return nil, status.Errorf(status.Internal, "failed to get posture check from store")
}
return postureCheck, nil
}
// SavePostureChecks saves a posture checks to the database.
func (s *SqlStore) SavePostureChecks(ctx context.Context, lockStrength LockingStrength, postureCheck *posture.Checks) error {
result := s.db.WithContext(ctx).Clauses(clause.Locking{Strength: string(lockStrength)}).Save(postureCheck)
if err := result.Error; err != nil {
if errors.Is(err, gorm.ErrDuplicatedKey) {
return status.Errorf(status.InvalidArgument, "name should be unique")
}
log.WithContext(ctx).Errorf("failed to save posture checks to the store: %s", err)
return status.Errorf(status.Internal, "failed to save posture checks to store")
}
return nil
}
// DeletePostureChecks deletes a posture checks from the database.
func (s *SqlStore) DeletePostureChecks(ctx context.Context, lockStrength LockingStrength, accountID, postureChecksID string) error {
result := s.db.WithContext(ctx).Clauses(clause.Locking{Strength: string(lockStrength)}).
Delete(&posture.Checks{}, accountAndIDQueryCondition, accountID, postureChecksID)
if err := result.Error; err != nil {
log.WithContext(ctx).Errorf("failed to delete posture checks from the store: %s", err)
return status.Errorf(status.Internal, "failed to delete posture checks from store")
}
if result.RowsAffected == 0 {
return status.Errorf(status.NotFound, "posture checks not found")
}
return nil
}
// GetAccountRoutes retrieves network routes for an account.
func (s *SqlStore) GetAccountRoutes(ctx context.Context, lockStrength LockingStrength, accountID string) ([]*route.Route, error) {
return getRecords[*route.Route](s.db.WithContext(ctx), lockStrength, accountID)
var routes []*route.Route
result := s.db.WithContext(ctx).Clauses(clause.Locking{Strength: string(lockStrength)}).
Find(&routes, accountIDCondition, accountID)
if err := result.Error; err != nil {
log.WithContext(ctx).Errorf("failed to get routes from the store: %s", err)
return nil, status.Errorf(status.Internal, "failed to get routes from store")
}
return routes, nil
}
// GetRouteByID retrieves a route by its ID and account ID.
func (s *SqlStore) GetRouteByID(ctx context.Context, lockStrength LockingStrength, routeID string, accountID string) (*route.Route, error) {
return getRecordByID[route.Route](s.db.WithContext(ctx), lockStrength, routeID, accountID)
func (s *SqlStore) GetRouteByID(ctx context.Context, lockStrength LockingStrength, accountID string, routeID string) (*route.Route, error) {
var route *route.Route
result := s.db.WithContext(ctx).Clauses(clause.Locking{Strength: string(lockStrength)}).
First(&route, accountAndIDQueryCondition, accountID, routeID)
if err := result.Error; err != nil {
if errors.Is(err, gorm.ErrRecordNotFound) {
return nil, status.Errorf(status.NotFound, "route not found")
}
log.WithContext(ctx).Errorf("failed to get route from the store: %s", err)
return nil, status.Errorf(status.Internal, "failed to get route from store")
}
return route, nil
}
// SaveRoute saves a route to the database.
func (s *SqlStore) SaveRoute(ctx context.Context, lockStrength LockingStrength, route *route.Route) error {
result := s.db.WithContext(ctx).Clauses(clause.Locking{Strength: string(lockStrength)}).Save(route)
if err := result.Error; err != nil {
log.WithContext(ctx).Errorf("failed to save route to the store: %s", err)
return status.Errorf(status.Internal, "failed to save route to store")
}
return nil
}
// DeleteRoute deletes a route from the database.
func (s *SqlStore) DeleteRoute(ctx context.Context, lockStrength LockingStrength, accountID, routeID string) error {
result := s.db.Clauses(clause.Locking{Strength: string(lockStrength)}).
Delete(&route.Route{}, accountAndIDQueryCondition, accountID, routeID)
if err := result.Error; err != nil {
log.WithContext(ctx).Errorf("failed to delete route from the store: %s", err)
return status.Errorf(status.Internal, "failed to delete route from store")
}
if result.RowsAffected == 0 {
return status.Errorf(status.NotFound, "route not found")
}
return nil
}
// GetAccountSetupKeys retrieves setup keys for an account.
func (s *SqlStore) GetAccountSetupKeys(ctx context.Context, lockStrength LockingStrength, accountID string) ([]*SetupKey, error) {
return getRecords[*SetupKey](s.db.WithContext(ctx), lockStrength, accountID)
var setupKeys []*SetupKey
result := s.db.WithContext(ctx).Clauses(clause.Locking{Strength: string(lockStrength)}).
Find(&setupKeys, accountIDCondition, accountID)
if err := result.Error; err != nil {
log.WithContext(ctx).Errorf("failed to get setup keys from the store: %s", err)
return nil, status.Errorf(status.Internal, "failed to get setup keys from store")
}
return setupKeys, nil
}
// GetSetupKeyByID retrieves a setup key by its ID and account ID.
func (s *SqlStore) GetSetupKeyByID(ctx context.Context, lockStrength LockingStrength, setupKeyID string, accountID string) (*SetupKey, error) {
return getRecordByID[SetupKey](s.db.WithContext(ctx), lockStrength, setupKeyID, accountID)
func (s *SqlStore) GetSetupKeyByID(ctx context.Context, lockStrength LockingStrength, accountID, setupKeyID string) (*SetupKey, error) {
var setupKey *SetupKey
result := s.db.WithContext(ctx).Clauses(clause.Locking{Strength: string(lockStrength)}).
First(&setupKey, accountAndIDQueryCondition, accountID, setupKeyID)
if err := result.Error; err != nil {
if errors.Is(err, gorm.ErrRecordNotFound) {
return nil, status.Errorf(status.NotFound, "setup key not found")
}
log.WithContext(ctx).Errorf("failed to get setup key from the store: %s", err)
return nil, status.Errorf(status.Internal, "failed to get setup key from store")
}
return setupKey, nil
}
// SaveSetupKey saves a setup key to the database.
func (s *SqlStore) SaveSetupKey(ctx context.Context, lockStrength LockingStrength, setupKey *SetupKey) error {
result := s.db.WithContext(ctx).Session(&gorm.Session{FullSaveAssociations: true}).
Clauses(clause.Locking{Strength: string(lockStrength)}).Save(setupKey)
if err := result.Error; err != nil {
log.WithContext(ctx).Errorf("failed to save name server group to the store: %s", err)
return status.Errorf(status.Internal, "failed to save name server group to store")
}
return nil
}
// GetAccountNameServerGroups retrieves name server groups for an account.
func (s *SqlStore) GetAccountNameServerGroups(ctx context.Context, lockStrength LockingStrength, accountID string) ([]*nbdns.NameServerGroup, error) {
return getRecords[*nbdns.NameServerGroup](s.db.WithContext(ctx), lockStrength, accountID)
var nsGroups []*nbdns.NameServerGroup
result := s.db.WithContext(ctx).Clauses(clause.Locking{Strength: string(lockStrength)}).Find(&nsGroups, accountIDCondition, accountID)
if err := result.Error; err != nil {
log.WithContext(ctx).Errorf("failed to get name server groups from the store: %s", err)
return nil, status.Errorf(status.Internal, "failed to get name server groups from store")
}
return nsGroups, nil
}
// GetNameServerGroupByID retrieves a name server group by its ID and account ID.
func (s *SqlStore) GetNameServerGroupByID(ctx context.Context, lockStrength LockingStrength, nsGroupID string, accountID string) (*nbdns.NameServerGroup, error) {
return getRecordByID[nbdns.NameServerGroup](s.db.WithContext(ctx), lockStrength, nsGroupID, accountID)
}
// getRecords retrieves records from the database based on the account ID.
func getRecords[T any](db *gorm.DB, lockStrength LockingStrength, accountID string) ([]T, error) {
var record []T
result := db.Clauses(clause.Locking{Strength: string(lockStrength)}).Find(&record, accountIDCondition, accountID)
func (s *SqlStore) GetNameServerGroupByID(ctx context.Context, lockStrength LockingStrength, accountID, nsGroupID string) (*nbdns.NameServerGroup, error) {
var nsGroup *nbdns.NameServerGroup
result := s.db.WithContext(ctx).Clauses(clause.Locking{Strength: string(lockStrength)}).
First(&nsGroup, accountAndIDQueryCondition, accountID, nsGroupID)
if err := result.Error; err != nil {
parts := strings.Split(fmt.Sprintf("%T", record), ".")
recordType := parts[len(parts)-1]
return nil, status.Errorf(status.Internal, "failed to get account %ss from store: %v", recordType, err)
if errors.Is(err, gorm.ErrRecordNotFound) {
return nil, status.Errorf(status.NotFound, "name server group not found")
}
log.WithContext(ctx).Errorf("failed to get name server group from the store: %s", err)
return nil, status.Errorf(status.Internal, "failed to get name server group from store")
}
return record, nil
return nsGroup, nil
}
// SaveNameServerGroup saves a name server group to the database.
func (s *SqlStore) SaveNameServerGroup(ctx context.Context, lockStrength LockingStrength, nameServerGroup *nbdns.NameServerGroup) error {
result := s.db.WithContext(ctx).Clauses(clause.Locking{Strength: string(lockStrength)}).Save(nameServerGroup)
if err := result.Error; err != nil {
log.WithContext(ctx).Errorf("failed to save name server group to the store: %s", err)
return status.Errorf(status.Internal, "failed to save name server group to store")
}
return nil
}
// DeleteNameServerGroup deletes a name server group from the database.
func (s *SqlStore) DeleteNameServerGroup(ctx context.Context, lockStrength LockingStrength, accountID, nameServerGroupID string) error {
result := s.db.Clauses(clause.Locking{Strength: string(lockStrength)}).Delete(&nbdns.NameServerGroup{}, accountAndIDQueryCondition, accountID, nameServerGroupID)
if err := result.Error; err != nil {
log.WithContext(ctx).Errorf("failed to delete name server group from the store: %s", err)
return status.Errorf(status.Internal, "failed to delete name server group from store")
}
if result.RowsAffected == 0 {
return status.Errorf(status.NotFound, "name server group not found")
}
return nil
}
// GetPATByID retrieves a personal access token by its ID and user ID.
func (s *SqlStore) GetPATByID(ctx context.Context, lockStrength LockingStrength, patID string, userID string) (*PersonalAccessToken, error) {
var pat PersonalAccessToken
result := s.db.WithContext(ctx).Clauses(clause.Locking{Strength: string(lockStrength)}).
First(&pat, "id = ? AND user_id = ?", patID, userID)
if err := result.Error; err != nil {
if errors.Is(result.Error, gorm.ErrRecordNotFound) {
return nil, status.Errorf(status.NotFound, "PAT not found")
}
log.WithContext(ctx).Errorf("failed to get PAT from the store: %s", err)
return nil, status.Errorf(status.Internal, "failed to get PAT from store")
}
return &pat, nil
}
// SavePAT saves a personal access token to the database.
func (s *SqlStore) SavePAT(ctx context.Context, lockStrength LockingStrength, pat *PersonalAccessToken) error {
result := s.db.WithContext(ctx).Clauses(clause.Locking{Strength: string(lockStrength)}).Save(pat)
if err := result.Error; err != nil {
log.WithContext(ctx).Errorf("failed to save PAT to the store: %s", err)
return status.Errorf(status.Internal, "failed to save PAT to store")
}
return nil
}
// DeletePAT deletes a personal access token from the database.
func (s *SqlStore) DeletePAT(ctx context.Context, lockStrength LockingStrength, userID, patID string) error {
result := s.db.WithContext(ctx).Clauses(clause.Locking{Strength: string(lockStrength)}).
Delete(&PersonalAccessToken{}, "user_id = ? AND id = ?", userID, patID)
if err := result.Error; err != nil {
log.WithContext(ctx).Errorf("failed to delete PAT from the store: %s", err)
return status.Errorf(status.Internal, "failed to delete PAT from store")
}
if result.RowsAffected == 0 {
return status.Errorf(status.NotFound, "PAT not found")
}
return nil
}
// getRecordByID retrieves a record by its ID and account ID from the database.

View File

@ -1181,7 +1181,7 @@ func TestSqlite_CreateAndGetObjectInTransaction(t *testing.T) {
t.Fatal("failed to save group")
return err
}
group, err = transaction.GetGroupByID(context.Background(), LockingStrengthUpdate, group.ID, group.AccountID)
group, err = transaction.GetGroupByID(context.Background(), LockingStrengthUpdate, group.AccountID, group.ID)
if err != nil {
t.Fatal("failed to get group")
return err

View File

@ -56,13 +56,15 @@ type Store interface {
GetAccountIDByPrivateDomain(ctx context.Context, lockStrength LockingStrength, domain string) (string, error)
GetAccountSettings(ctx context.Context, lockStrength LockingStrength, accountID string) (*Settings, error)
GetAccountDNSSettings(ctx context.Context, lockStrength LockingStrength, accountID string) (*DNSSettings, error)
SaveDNSSettings(ctx context.Context, lockStrength LockingStrength, accountID string, settings *DNSSettings) error
SaveAccountSettings(ctx context.Context, lockStrength LockingStrength, accountID string, settings *Settings) error
SaveAccount(ctx context.Context, account *Account) error
DeleteAccount(ctx context.Context, account *Account) error
UpdateAccountDomainAttributes(ctx context.Context, accountID string, domain string, category string, isPrimaryDomain bool) error
GetUserByTokenID(ctx context.Context, tokenID string) (*User, error)
GetUserByUserID(ctx context.Context, lockStrength LockingStrength, userID string) (*User, error)
GetAccountUsers(ctx context.Context, accountID string) ([]*User, error)
GetAccountUsers(ctx context.Context, lockStrength LockingStrength, accountID string) ([]*User, error)
SaveUsers(accountID string, users map[string]*User) error
SaveUser(ctx context.Context, lockStrength LockingStrength, user *User) error
SaveUserLastLogin(ctx context.Context, accountID, userID string, lastLogin time.Time) error
@ -71,24 +73,34 @@ type Store interface {
DeleteTokenID2UserIDIndex(tokenID string) error
GetAccountGroups(ctx context.Context, accountID string) ([]*nbgroup.Group, error)
GetGroupByID(ctx context.Context, lockStrength LockingStrength, groupID, accountID string) (*nbgroup.Group, error)
GetGroupByName(ctx context.Context, lockStrength LockingStrength, groupName, accountID string) (*nbgroup.Group, error)
GetGroupByID(ctx context.Context, lockStrength LockingStrength, accountID, groupID string) (*nbgroup.Group, error)
GetGroupByName(ctx context.Context, lockStrength LockingStrength, accountID, groupName string) (*nbgroup.Group, error)
SaveGroups(ctx context.Context, lockStrength LockingStrength, groups []*nbgroup.Group) error
SaveGroup(ctx context.Context, lockStrength LockingStrength, group *nbgroup.Group) error
DeleteGroup(ctx context.Context, lockStrength LockingStrength, accountID, groupID string) error
DeleteGroups(ctx context.Context, strength LockingStrength, accountID string, groupIDs []string) error
GetAccountPolicies(ctx context.Context, lockStrength LockingStrength, accountID string) ([]*Policy, error)
GetPolicyByID(ctx context.Context, lockStrength LockingStrength, policyID string, accountID string) (*Policy, error)
GetPolicyByID(ctx context.Context, lockStrength LockingStrength, accountID, policyID string) (*Policy, error)
SavePolicy(ctx context.Context, lockStrength LockingStrength, policy *Policy) error
DeletePolicy(ctx context.Context, lockStrength LockingStrength, accountID, policyID string) error
GetPostureCheckByChecksDefinition(accountID string, checks *posture.ChecksDefinition) (*posture.Checks, error)
GetAccountPostureChecks(ctx context.Context, lockStrength LockingStrength, accountID string) ([]*posture.Checks, error)
GetPostureChecksByID(ctx context.Context, lockStrength LockingStrength, postureCheckID string, accountID string) (*posture.Checks, error)
GetPostureChecksByID(ctx context.Context, lockStrength LockingStrength, accountID, postureCheckID string) (*posture.Checks, error)
SavePostureChecks(ctx context.Context, lockStrength LockingStrength, postureCheck *posture.Checks) error
DeletePostureChecks(ctx context.Context, lockStrength LockingStrength, accountID, postureChecksID string) error
GetPeerLabelsInAccount(ctx context.Context, lockStrength LockingStrength, accountId string) ([]string, error)
AddPeerToAllGroup(ctx context.Context, accountID string, peerID string) error
AddPeerToGroup(ctx context.Context, accountId string, peerId string, groupID string) error
AddPeerToAccount(ctx context.Context, peer *nbpeer.Peer) error
GetPeerByPeerPubKey(ctx context.Context, lockStrength LockingStrength, peerKey string) (*nbpeer.Peer, error)
GetAccountPeers(ctx context.Context, lockStrength LockingStrength, accountID string) ([]*nbpeer.Peer, error)
GetUserPeers(ctx context.Context, lockStrength LockingStrength, accountID, userID string) ([]*nbpeer.Peer, error)
GetAccountPeersWithExpiration(ctx context.Context, lockStrength LockingStrength, accountID string) ([]*nbpeer.Peer, error)
GetAccountPeersWithInactivity(ctx context.Context, lockStrength LockingStrength, accountID string) ([]*nbpeer.Peer, error)
GetPeerByID(ctx context.Context, lockStrength LockingStrength, accountID string, peerID string) (*nbpeer.Peer, error)
SavePeer(ctx context.Context, accountID string, peer *nbpeer.Peer) error
SavePeerStatus(accountID, peerID string, status nbpeer.PeerStatus) error
SavePeerLocation(accountID string, peer *nbpeer.Peer) error
@ -96,16 +108,25 @@ type Store interface {
GetSetupKeyBySecret(ctx context.Context, lockStrength LockingStrength, key string) (*SetupKey, error)
IncrementSetupKeyUsage(ctx context.Context, setupKeyID string) error
GetAccountSetupKeys(ctx context.Context, lockStrength LockingStrength, accountID string) ([]*SetupKey, error)
GetSetupKeyByID(ctx context.Context, lockStrength LockingStrength, setupKeyID string, accountID string) (*SetupKey, error)
GetSetupKeyByID(ctx context.Context, lockStrength LockingStrength, accountID, setupKeyID string) (*SetupKey, error)
SaveSetupKey(ctx context.Context, lockStrength LockingStrength, setupKey *SetupKey) error
GetAccountRoutes(ctx context.Context, lockStrength LockingStrength, accountID string) ([]*route.Route, error)
GetRouteByID(ctx context.Context, lockStrength LockingStrength, routeID string, accountID string) (*route.Route, error)
GetRouteByID(ctx context.Context, lockStrength LockingStrength, accountID, routeID string) (*route.Route, error)
SaveRoute(ctx context.Context, lockStrength LockingStrength, route *route.Route) error
DeleteRoute(ctx context.Context, lockStrength LockingStrength, accountID, routeID string) error
GetAccountNameServerGroups(ctx context.Context, lockStrength LockingStrength, accountID string) ([]*dns.NameServerGroup, error)
GetNameServerGroupByID(ctx context.Context, lockStrength LockingStrength, nameServerGroupID string, accountID string) (*dns.NameServerGroup, error)
GetNameServerGroupByID(ctx context.Context, lockStrength LockingStrength, accountID, nameServerGroupID string) (*dns.NameServerGroup, error)
SaveNameServerGroup(ctx context.Context, lockStrength LockingStrength, nameServerGroup *dns.NameServerGroup) error
DeleteNameServerGroup(ctx context.Context, lockStrength LockingStrength, accountID, nameServerGroupID string) error
GetPATByID(ctx context.Context, lockStrength LockingStrength, userID, patID string) (*PersonalAccessToken, error)
SavePAT(ctx context.Context, strength LockingStrength, pat *PersonalAccessToken) error
DeletePAT(ctx context.Context, strength LockingStrength, userID, patID string) error
GetTakenIPs(ctx context.Context, lockStrength LockingStrength, accountId string) ([]net.IP, error)
IncrementNetworkSerial(ctx context.Context, accountId string) error
IncrementNetworkSerial(ctx context.Context, lockStrength LockingStrength, accountId string) error
GetAccountNetwork(ctx context.Context, lockStrength LockingStrength, accountId string) (*Network, error)
GetInstallationID() string

View File

@ -546,9 +546,6 @@ func (am *DefaultAccountManager) InviteUser(ctx context.Context, accountID strin
// CreatePAT creates a new PAT for the given user
func (am *DefaultAccountManager) CreatePAT(ctx context.Context, accountID string, initiatorUserID string, targetUserID string, tokenName string, expiresIn int) (*PersonalAccessTokenGenerated, error) {
unlock := am.Store.AcquireWriteLockByUID(ctx, accountID)
defer unlock()
if tokenName == "" {
return nil, status.Errorf(status.InvalidArgument, "token name can't be empty")
}
@ -557,35 +554,28 @@ func (am *DefaultAccountManager) CreatePAT(ctx context.Context, accountID string
return nil, status.Errorf(status.InvalidArgument, "expiration has to be between 1 and 365")
}
account, err := am.Store.GetAccount(ctx, accountID)
executingUser, err := am.Store.GetUserByUserID(ctx, LockingStrengthShare, initiatorUserID)
if err != nil {
return nil, err
}
targetUser, ok := account.Users[targetUserID]
if !ok {
return nil, status.Errorf(status.NotFound, "user not found")
targetUser, err := am.Store.GetUserByUserID(ctx, LockingStrengthShare, targetUserID)
if err != nil {
return nil, err
}
executingUser, ok := account.Users[initiatorUserID]
if !ok {
return nil, status.Errorf(status.NotFound, "user not found")
}
if !(initiatorUserID == targetUserID || (executingUser.HasAdminPower() && targetUser.IsServiceUser)) {
if !(initiatorUserID == targetUserID || (executingUser.HasAdminPower() && targetUser.IsServiceUser)) ||
executingUser.AccountID != accountID {
return nil, status.Errorf(status.PermissionDenied, "no permission to create PAT for this user")
}
pat, err := CreateNewPAT(tokenName, expiresIn, executingUser.Id)
pat, err := CreateNewPAT(tokenName, expiresIn, targetUser.Id, executingUser.Id)
if err != nil {
return nil, status.Errorf(status.Internal, "failed to create PAT: %v", err)
}
targetUser.PATs[pat.ID] = &pat.PersonalAccessToken
err = am.Store.SaveAccount(ctx, account)
if err != nil {
return nil, status.Errorf(status.Internal, "failed to save account: %v", err)
if err = am.Store.SavePAT(ctx, LockingStrengthUpdate, &pat.PersonalAccessToken); err != nil {
return nil, fmt.Errorf("failed to save PAT: %w", err)
}
meta := map[string]any{"name": pat.Name, "is_service_user": targetUser.IsServiceUser, "user_name": targetUser.ServiceUserName}
@ -596,51 +586,33 @@ func (am *DefaultAccountManager) CreatePAT(ctx context.Context, accountID string
// DeletePAT deletes a specific PAT from a user
func (am *DefaultAccountManager) DeletePAT(ctx context.Context, accountID string, initiatorUserID string, targetUserID string, tokenID string) error {
unlock := am.Store.AcquireWriteLockByUID(ctx, accountID)
defer unlock()
account, err := am.Store.GetAccount(ctx, accountID)
executingUser, err := am.Store.GetUserByUserID(ctx, LockingStrengthShare, initiatorUserID)
if err != nil {
return status.Errorf(status.NotFound, "account not found: %s", err)
return err
}
targetUser, ok := account.Users[targetUserID]
if !ok {
return status.Errorf(status.NotFound, "user not found")
targetUser, err := am.Store.GetUserByUserID(ctx, LockingStrengthShare, targetUserID)
if err != nil {
return err
}
executingUser, ok := account.Users[initiatorUserID]
if !ok {
return status.Errorf(status.NotFound, "user not found")
}
if !(initiatorUserID == targetUserID || (executingUser.HasAdminPower() && targetUser.IsServiceUser)) {
if !(initiatorUserID == targetUserID || (executingUser.HasAdminPower() && targetUser.IsServiceUser)) ||
executingUser.AccountID != accountID {
return status.Errorf(status.PermissionDenied, "no permission to delete PAT for this user")
}
pat := targetUser.PATs[tokenID]
if pat == nil {
return status.Errorf(status.NotFound, "PAT not found")
pat, err := am.Store.GetPATByID(ctx, LockingStrengthShare, tokenID, targetUserID)
if err != nil {
return err
}
err = am.Store.DeleteTokenID2UserIDIndex(pat.ID)
if err != nil {
return status.Errorf(status.Internal, "Failed to delete token id index: %s", err)
}
err = am.Store.DeleteHashedPAT2TokenIDIndex(pat.HashedToken)
if err != nil {
return status.Errorf(status.Internal, "Failed to delete hashed token index: %s", err)
if err = am.Store.DeletePAT(ctx, LockingStrengthUpdate, tokenID, targetUserID); err != nil {
return fmt.Errorf("failed to delete PAT: %w", err)
}
meta := map[string]any{"name": pat.Name, "is_service_user": targetUser.IsServiceUser, "user_name": targetUser.ServiceUserName}
am.StoreEvent(ctx, initiatorUserID, targetUserID, accountID, activity.PersonalAccessTokenDeleted, meta)
delete(targetUser.PATs, tokenID)
err = am.Store.SaveAccount(ctx, account)
if err != nil {
return status.Errorf(status.Internal, "Failed to save account: %s", err)
}
return nil
}
@ -651,22 +623,11 @@ func (am *DefaultAccountManager) GetPAT(ctx context.Context, accountID string, i
return nil, err
}
targetUser, err := am.Store.GetUserByUserID(ctx, LockingStrengthShare, targetUserID)
if err != nil {
return nil, err
}
if (initiatorUserID != targetUserID && !initiatorUser.IsAdminOrServiceUser()) || initiatorUser.AccountID != accountID {
return nil, status.Errorf(status.PermissionDenied, "no permission to get PAT for this user")
}
for _, pat := range targetUser.PATsG {
if pat.ID == tokenID {
return pat.Copy(), nil
}
}
return nil, status.Errorf(status.NotFound, "PAT not found")
return am.Store.GetPATByID(ctx, LockingStrengthShare, targetUserID, tokenID)
}
// GetAllPATs returns all PATs for a user