mirror of
https://github.com/netbirdio/netbird.git
synced 2025-08-16 01:58:16 +02:00
@ -45,15 +45,15 @@ import (
|
||||
)
|
||||
|
||||
const (
|
||||
PublicCategory = "public"
|
||||
PrivateCategory = "private"
|
||||
UnknownCategory = "unknown"
|
||||
CacheExpirationMax = 7 * 24 * 3600 * time.Second // 7 days
|
||||
CacheExpirationMin = 3 * 24 * 3600 * time.Second // 3 days
|
||||
DefaultPeerLoginExpiration = 24 * time.Hour
|
||||
PublicCategory = "public"
|
||||
PrivateCategory = "private"
|
||||
UnknownCategory = "unknown"
|
||||
CacheExpirationMax = 7 * 24 * 3600 * time.Second // 7 days
|
||||
CacheExpirationMin = 3 * 24 * 3600 * time.Second // 3 days
|
||||
DefaultPeerLoginExpiration = 24 * time.Hour
|
||||
DefaultPeerInactivityExpiration = 10 * time.Minute
|
||||
emptyUserID = "empty user ID in claims"
|
||||
errorGettingDomainAccIDFmt = "error getting account ID by private domain: %v"
|
||||
emptyUserID = "empty user ID in claims"
|
||||
errorGettingDomainAccIDFmt = "error getting account ID by private domain: %v"
|
||||
)
|
||||
|
||||
type userLoggedInOnce bool
|
||||
@ -134,14 +134,14 @@ type AccountManager interface {
|
||||
GetDNSSettings(ctx context.Context, accountID string, userID string) (*DNSSettings, error)
|
||||
SaveDNSSettings(ctx context.Context, accountID string, userID string, dnsSettingsToSave *DNSSettings) error
|
||||
GetPeer(ctx context.Context, accountID, peerID, userID string) (*nbpeer.Peer, error)
|
||||
UpdateAccountSettings(ctx context.Context, accountID, userID string, newSettings *Settings) (*Account, error)
|
||||
UpdateAccountSettings(ctx context.Context, accountID, userID string, newSettings *Settings) (*Settings, error)
|
||||
LoginPeer(ctx context.Context, login PeerLogin) (*nbpeer.Peer, *NetworkMap, []*posture.Checks, error) // used by peer gRPC API
|
||||
SyncPeer(ctx context.Context, sync PeerSync, account *Account) (*nbpeer.Peer, *NetworkMap, []*posture.Checks, error) // used by peer gRPC API
|
||||
GetAllConnectedPeers() (map[string]struct{}, error)
|
||||
HasConnectedChannel(peerID string) bool
|
||||
GetExternalCacheManager() ExternalCacheManager
|
||||
GetPostureChecks(ctx context.Context, accountID, postureChecksID, userID string) (*posture.Checks, error)
|
||||
SavePostureChecks(ctx context.Context, accountID, userID string, postureChecks *posture.Checks) error
|
||||
SavePostureChecks(ctx context.Context, accountID, userID string, postureChecks *posture.Checks, isUpdate bool) error
|
||||
DeletePostureChecks(ctx context.Context, accountID, postureChecksID, userID string) error
|
||||
ListPostureChecks(ctx context.Context, accountID, userID string) ([]*posture.Checks, error)
|
||||
GetIdpManager() idp.Manager
|
||||
@ -1122,7 +1122,16 @@ func (am *DefaultAccountManager) GetIdpManager() idp.Manager {
|
||||
// Only users with role UserRoleAdmin can update the account.
|
||||
// User that performs the update has to belong to the account.
|
||||
// Returns an updated Account
|
||||
func (am *DefaultAccountManager) UpdateAccountSettings(ctx context.Context, accountID, userID string, newSettings *Settings) (*Account, error) {
|
||||
func (am *DefaultAccountManager) UpdateAccountSettings(ctx context.Context, accountID, userID string, newSettings *Settings) (*Settings, error) {
|
||||
user, err := am.Store.GetUserByUserID(ctx, LockingStrengthShare, userID)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
if !user.HasAdminPower() || user.AccountID != accountID {
|
||||
return nil, status.Errorf(status.PermissionDenied, "user is not allowed to update account")
|
||||
}
|
||||
|
||||
halfYearLimit := 180 * 24 * time.Hour
|
||||
if newSettings.PeerLoginExpiration > halfYearLimit {
|
||||
return nil, status.Errorf(status.InvalidArgument, "peer login expiration can't be larger than 180 days")
|
||||
@ -1132,78 +1141,89 @@ func (am *DefaultAccountManager) UpdateAccountSettings(ctx context.Context, acco
|
||||
return nil, status.Errorf(status.InvalidArgument, "peer login expiration can't be smaller than one hour")
|
||||
}
|
||||
|
||||
unlock := am.Store.AcquireWriteLockByUID(ctx, accountID)
|
||||
defer unlock()
|
||||
var oldSettings *Settings
|
||||
|
||||
account, err := am.Store.GetAccount(ctx, accountID)
|
||||
err = am.Store.ExecuteInTransaction(ctx, func(transaction Store) error {
|
||||
oldSettings, err = transaction.GetAccountSettings(ctx, LockingStrengthShare, accountID)
|
||||
if err != nil {
|
||||
return fmt.Errorf("failed to get account settings: %w", err)
|
||||
}
|
||||
|
||||
if err = am.validateExtraSettings(ctx, newSettings, oldSettings, userID, accountID); err != nil {
|
||||
return fmt.Errorf("failed to validate extra settings: %w", err)
|
||||
}
|
||||
|
||||
if err = am.Store.SaveAccountSettings(ctx, LockingStrengthUpdate, accountID, newSettings); err != nil {
|
||||
return fmt.Errorf("failed to update account settings: %w", err)
|
||||
}
|
||||
|
||||
return nil
|
||||
})
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
user, err := account.FindUser(userID)
|
||||
am.handlePeerLoginExpirationSettings(ctx, oldSettings, newSettings, userID, accountID)
|
||||
am.handleInactivityExpirationSettings(ctx, oldSettings, newSettings, userID, accountID)
|
||||
|
||||
account, err := am.requestBuffer.GetAccountWithBackpressure(ctx, accountID)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
return nil, fmt.Errorf("error getting account: %w", err)
|
||||
}
|
||||
am.updateAccountPeers(ctx, account)
|
||||
|
||||
if !user.HasAdminPower() {
|
||||
return nil, status.Errorf(status.PermissionDenied, "user is not allowed to update account")
|
||||
}
|
||||
return newSettings, nil
|
||||
}
|
||||
|
||||
err = am.integratedPeerValidator.ValidateExtraSettings(ctx, newSettings.Extra, account.Settings.Extra, account.Peers, userID, accountID)
|
||||
// validateExtraSettings validates the extra settings of the account.
|
||||
func (am *DefaultAccountManager) validateExtraSettings(ctx context.Context, newSettings, oldSettings *Settings, userID, accountID string) error {
|
||||
peers, err := am.Store.GetAccountPeers(ctx, LockingStrengthShare, accountID)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
return err
|
||||
}
|
||||
|
||||
oldSettings := account.Settings
|
||||
peerMap := make(map[string]*nbpeer.Peer, len(peers))
|
||||
for _, peer := range peers {
|
||||
peerMap[peer.ID] = peer
|
||||
}
|
||||
|
||||
return am.integratedPeerValidator.ValidateExtraSettings(ctx, newSettings.Extra, oldSettings.Extra, peerMap, userID, accountID)
|
||||
}
|
||||
|
||||
func (am *DefaultAccountManager) handlePeerLoginExpirationSettings(ctx context.Context, oldSettings, newSettings *Settings, userID, accountID string) {
|
||||
if oldSettings.PeerLoginExpirationEnabled != newSettings.PeerLoginExpirationEnabled {
|
||||
event := activity.AccountPeerLoginExpirationEnabled
|
||||
if !newSettings.PeerLoginExpirationEnabled {
|
||||
event = activity.AccountPeerLoginExpirationDisabled
|
||||
am.peerLoginExpiry.Cancel(ctx, []string{accountID})
|
||||
} else {
|
||||
am.checkAndSchedulePeerLoginExpiration(ctx, account)
|
||||
am.checkAndSchedulePeerLoginExpiration(ctx, accountID)
|
||||
}
|
||||
am.StoreEvent(ctx, userID, accountID, accountID, event, nil)
|
||||
}
|
||||
|
||||
if oldSettings.PeerLoginExpiration != newSettings.PeerLoginExpiration {
|
||||
am.StoreEvent(ctx, userID, accountID, accountID, activity.AccountPeerLoginExpirationDurationUpdated, nil)
|
||||
am.checkAndSchedulePeerLoginExpiration(ctx, account)
|
||||
am.checkAndSchedulePeerLoginExpiration(ctx, accountID)
|
||||
}
|
||||
|
||||
err = am.handleInactivityExpirationSettings(ctx, account, oldSettings, newSettings, userID, accountID)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
updatedAccount := account.UpdateSettings(newSettings)
|
||||
|
||||
err = am.Store.SaveAccount(ctx, account)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
return updatedAccount, nil
|
||||
}
|
||||
|
||||
func (am *DefaultAccountManager) handleInactivityExpirationSettings(ctx context.Context, account *Account, oldSettings, newSettings *Settings, userID, accountID string) error {
|
||||
func (am *DefaultAccountManager) handleInactivityExpirationSettings(ctx context.Context, oldSettings, newSettings *Settings, userID, accountID string) {
|
||||
if oldSettings.PeerInactivityExpirationEnabled != newSettings.PeerInactivityExpirationEnabled {
|
||||
event := activity.AccountPeerInactivityExpirationEnabled
|
||||
if !newSettings.PeerInactivityExpirationEnabled {
|
||||
event = activity.AccountPeerInactivityExpirationDisabled
|
||||
am.peerInactivityExpiry.Cancel(ctx, []string{accountID})
|
||||
} else {
|
||||
am.checkAndSchedulePeerInactivityExpiration(ctx, account)
|
||||
am.checkAndSchedulePeerInactivityExpiration(ctx, accountID)
|
||||
}
|
||||
am.StoreEvent(ctx, userID, accountID, accountID, event, nil)
|
||||
}
|
||||
|
||||
if oldSettings.PeerInactivityExpiration != newSettings.PeerInactivityExpiration {
|
||||
am.StoreEvent(ctx, userID, accountID, accountID, activity.AccountPeerInactivityExpirationDurationUpdated, nil)
|
||||
am.checkAndSchedulePeerInactivityExpiration(ctx, account)
|
||||
am.checkAndSchedulePeerInactivityExpiration(ctx, accountID)
|
||||
}
|
||||
|
||||
return nil
|
||||
}
|
||||
|
||||
func (am *DefaultAccountManager) peerLoginExpirationJob(ctx context.Context, accountID string) func() (time.Duration, bool) {
|
||||
@ -1234,10 +1254,10 @@ func (am *DefaultAccountManager) peerLoginExpirationJob(ctx context.Context, acc
|
||||
}
|
||||
}
|
||||
|
||||
func (am *DefaultAccountManager) checkAndSchedulePeerLoginExpiration(ctx context.Context, account *Account) {
|
||||
am.peerLoginExpiry.Cancel(ctx, []string{account.Id})
|
||||
if nextRun, ok := account.GetNextPeerExpiration(); ok {
|
||||
go am.peerLoginExpiry.Schedule(ctx, nextRun, account.Id, am.peerLoginExpirationJob(ctx, account.Id))
|
||||
func (am *DefaultAccountManager) checkAndSchedulePeerLoginExpiration(ctx context.Context, accountID string) {
|
||||
am.peerLoginExpiry.Cancel(ctx, []string{accountID})
|
||||
if nextRun, ok := am.getNextPeerExpiration(ctx, accountID); ok {
|
||||
go am.peerLoginExpiry.Schedule(ctx, nextRun, accountID, am.peerLoginExpirationJob(ctx, accountID))
|
||||
}
|
||||
}
|
||||
|
||||
@ -1271,10 +1291,10 @@ func (am *DefaultAccountManager) peerInactivityExpirationJob(ctx context.Context
|
||||
}
|
||||
|
||||
// checkAndSchedulePeerInactivityExpiration periodically checks for inactive peers to end their sessions
|
||||
func (am *DefaultAccountManager) checkAndSchedulePeerInactivityExpiration(ctx context.Context, account *Account) {
|
||||
am.peerInactivityExpiry.Cancel(ctx, []string{account.Id})
|
||||
if nextRun, ok := account.GetNextInactivePeerExpiration(); ok {
|
||||
go am.peerInactivityExpiry.Schedule(ctx, nextRun, account.Id, am.peerInactivityExpirationJob(ctx, account.Id))
|
||||
func (am *DefaultAccountManager) checkAndSchedulePeerInactivityExpiration(ctx context.Context, accountID string) {
|
||||
am.peerInactivityExpiry.Cancel(ctx, []string{accountID})
|
||||
if nextRun, ok := am.getNextInactivePeerExpiration(ctx, accountID); ok {
|
||||
go am.peerInactivityExpiry.Schedule(ctx, nextRun, accountID, am.peerInactivityExpirationJob(ctx, accountID))
|
||||
}
|
||||
}
|
||||
|
||||
@ -1435,12 +1455,12 @@ func isNil(i idp.Manager) bool {
|
||||
// addAccountIDToIDPAppMeta update user's app metadata in idp manager
|
||||
func (am *DefaultAccountManager) addAccountIDToIDPAppMeta(ctx context.Context, userID string, accountID string) error {
|
||||
if !isNil(am.idpManager) {
|
||||
accountUsers, err := am.Store.GetAccountUsers(ctx, accountID)
|
||||
accountUsers, err := am.Store.GetAccountUsers(ctx, LockingStrengthShare, accountID)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
cachedAccount := &Account{
|
||||
Id: accountID,
|
||||
Id: accountID,
|
||||
Users: make(map[string]*User),
|
||||
}
|
||||
for _, user := range accountUsers {
|
||||
@ -2083,7 +2103,7 @@ func (am *DefaultAccountManager) syncJWTGroups(ctx context.Context, accountID st
|
||||
return fmt.Errorf("error saving groups: %w", err)
|
||||
}
|
||||
|
||||
if err = transaction.IncrementNetworkSerial(ctx, accountID); err != nil {
|
||||
if err = transaction.IncrementNetworkSerial(ctx, LockingStrengthUpdate, accountID); err != nil {
|
||||
return fmt.Errorf("error incrementing network serial: %w", err)
|
||||
}
|
||||
}
|
||||
@ -2101,7 +2121,7 @@ func (am *DefaultAccountManager) syncJWTGroups(ctx context.Context, accountID st
|
||||
}
|
||||
|
||||
for _, g := range addNewGroups {
|
||||
group, err := am.Store.GetGroupByID(ctx, LockingStrengthShare, g, accountID)
|
||||
group, err := am.Store.GetGroupByID(ctx, LockingStrengthShare, accountID, g)
|
||||
if err != nil {
|
||||
log.WithContext(ctx).Debugf("group %s not found while saving user activity event of account %s", g, accountID)
|
||||
} else {
|
||||
@ -2114,7 +2134,7 @@ func (am *DefaultAccountManager) syncJWTGroups(ctx context.Context, accountID st
|
||||
}
|
||||
|
||||
for _, g := range removeOldGroups {
|
||||
group, err := am.Store.GetGroupByID(ctx, LockingStrengthShare, g, accountID)
|
||||
group, err := am.Store.GetGroupByID(ctx, LockingStrengthShare, accountID, g)
|
||||
if err != nil {
|
||||
log.WithContext(ctx).Debugf("group %s not found while saving user activity event of account %s", g, accountID)
|
||||
} else {
|
||||
|
@ -2553,7 +2553,7 @@ func TestAccount_SetJWTGroups(t *testing.T) {
|
||||
assert.NoError(t, err, "unable to get user")
|
||||
assert.Len(t, user.AutoGroups, 0)
|
||||
|
||||
group1, err := manager.Store.GetGroupByID(context.Background(), LockingStrengthShare, "group1", "accountID")
|
||||
group1, err := manager.Store.GetGroupByID(context.Background(), LockingStrengthShare, "accountID", "group1")
|
||||
assert.NoError(t, err, "unable to get group")
|
||||
assert.Equal(t, group1.Issued, group.GroupIssuedAPI, "group should be api issued")
|
||||
})
|
||||
@ -2573,7 +2573,7 @@ func TestAccount_SetJWTGroups(t *testing.T) {
|
||||
assert.NoError(t, err, "unable to get user")
|
||||
assert.Len(t, user.AutoGroups, 1)
|
||||
|
||||
group1, err := manager.Store.GetGroupByID(context.Background(), LockingStrengthShare, "group1", "accountID")
|
||||
group1, err := manager.Store.GetGroupByID(context.Background(), LockingStrengthShare, "accountID", "group1")
|
||||
assert.NoError(t, err, "unable to get group")
|
||||
assert.Equal(t, group1.Issued, group.GroupIssuedAPI, "group should be api issued")
|
||||
})
|
||||
|
@ -50,7 +50,7 @@ func (am *DefaultAccountManager) GetGroup(ctx context.Context, accountID, groupI
|
||||
return nil, err
|
||||
}
|
||||
|
||||
return am.Store.GetGroupByID(ctx, LockingStrengthShare, groupID, accountID)
|
||||
return am.Store.GetGroupByID(ctx, LockingStrengthShare, accountID, groupID)
|
||||
}
|
||||
|
||||
// GetAllGroups returns all groups in an account
|
||||
@ -64,7 +64,7 @@ func (am *DefaultAccountManager) GetAllGroups(ctx context.Context, accountID, us
|
||||
|
||||
// GetGroupByName filters all groups in an account by name and returns the one with the most peers
|
||||
func (am *DefaultAccountManager) GetGroupByName(ctx context.Context, groupName, accountID string) (*nbgroup.Group, error) {
|
||||
return am.Store.GetGroupByName(ctx, LockingStrengthShare, groupName, accountID)
|
||||
return am.Store.GetGroupByName(ctx, LockingStrengthShare, accountID, groupName)
|
||||
}
|
||||
|
||||
// SaveGroup object of the peers
|
||||
@ -94,7 +94,7 @@ func (am *DefaultAccountManager) SaveGroups(ctx context.Context, accountID, user
|
||||
}
|
||||
|
||||
if newGroup.ID == "" && newGroup.Issued == nbgroup.GroupIssuedAPI {
|
||||
existingGroup, err := am.Store.GetGroupByName(ctx, LockingStrengthShare, newGroup.Name, accountID)
|
||||
existingGroup, err := am.Store.GetGroupByName(ctx, LockingStrengthShare, accountID, newGroup.Name)
|
||||
if err != nil {
|
||||
s, ok := status.FromError(err)
|
||||
if !ok || s.ErrorType != status.NotFound {
|
||||
@ -112,7 +112,7 @@ func (am *DefaultAccountManager) SaveGroups(ctx context.Context, accountID, user
|
||||
}
|
||||
|
||||
for _, peerID := range newGroup.Peers {
|
||||
if _, err = am.Store.GetPeerByID(ctx, LockingStrengthShare, peerID, accountID); err != nil {
|
||||
if _, err = am.Store.GetPeerByID(ctx, LockingStrengthShare, accountID, peerID); err != nil {
|
||||
return status.Errorf(status.InvalidArgument, "peer with ID \"%s\" not found", peerID)
|
||||
}
|
||||
}
|
||||
@ -158,7 +158,7 @@ func (am *DefaultAccountManager) prepareGroupEvents(ctx context.Context, userID
|
||||
addedPeers := make([]string, 0)
|
||||
removedPeers := make([]string, 0)
|
||||
|
||||
oldGroup, err := am.Store.GetGroupByID(ctx, LockingStrengthShare, newGroup.ID, accountID)
|
||||
oldGroup, err := am.Store.GetGroupByID(ctx, LockingStrengthShare, accountID, newGroup.ID)
|
||||
if err == nil && oldGroup != nil {
|
||||
addedPeers = difference(newGroup.Peers, oldGroup.Peers)
|
||||
removedPeers = difference(oldGroup.Peers, newGroup.Peers)
|
||||
@ -170,7 +170,7 @@ func (am *DefaultAccountManager) prepareGroupEvents(ctx context.Context, userID
|
||||
}
|
||||
|
||||
for _, peerID := range addedPeers {
|
||||
peer, err := am.Store.GetPeerByID(ctx, LockingStrengthShare, peerID, accountID)
|
||||
peer, err := am.Store.GetPeerByID(ctx, LockingStrengthShare, accountID, peerID)
|
||||
if err != nil {
|
||||
log.WithContext(ctx).Errorf("peer %s not found under account %s while saving group", peerID, accountID)
|
||||
continue
|
||||
@ -187,7 +187,7 @@ func (am *DefaultAccountManager) prepareGroupEvents(ctx context.Context, userID
|
||||
}
|
||||
|
||||
for _, peerID := range removedPeers {
|
||||
peer, err := am.Store.GetPeerByID(ctx, LockingStrengthShare, peerID, accountID)
|
||||
peer, err := am.Store.GetPeerByID(ctx, LockingStrengthShare, accountID, peerID)
|
||||
if err != nil {
|
||||
log.WithContext(ctx).Errorf("peer %s not found under account %s while saving group", peerID, accountID)
|
||||
continue
|
||||
@ -232,7 +232,7 @@ func (am *DefaultAccountManager) DeleteGroup(ctx context.Context, accountID, use
|
||||
return status.Errorf(status.PermissionDenied, "no permission to delete group")
|
||||
}
|
||||
|
||||
group, err := am.Store.GetGroupByID(ctx, LockingStrengthShare, groupID, accountID)
|
||||
group, err := am.Store.GetGroupByID(ctx, LockingStrengthShare, accountID, groupID)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
@ -288,7 +288,7 @@ func (am *DefaultAccountManager) DeleteGroups(ctx context.Context, accountID, us
|
||||
)
|
||||
|
||||
for _, groupID := range groupIDs {
|
||||
group, err := am.Store.GetGroupByID(ctx, LockingStrengthShare, groupID, accountID)
|
||||
group, err := am.Store.GetGroupByID(ctx, LockingStrengthShare, accountID, groupID)
|
||||
if err != nil {
|
||||
continue
|
||||
}
|
||||
@ -307,7 +307,7 @@ func (am *DefaultAccountManager) DeleteGroups(ctx context.Context, accountID, us
|
||||
return fmt.Errorf("failed to increment network serial: %w", err)
|
||||
}
|
||||
|
||||
if err = transaction.DeleteGroups(ctx, LockingStrengthUpdate, groupIDsToDelete, accountID); err != nil {
|
||||
if err = transaction.DeleteGroups(ctx, LockingStrengthUpdate, accountID, groupIDsToDelete); err != nil {
|
||||
return fmt.Errorf("failed to delete group: %w", err)
|
||||
}
|
||||
return nil
|
||||
|
@ -30,7 +30,7 @@ func (am *DefaultAccountManager) GetNameServerGroup(ctx context.Context, account
|
||||
return nil, status.Errorf(status.PermissionDenied, "only users with admin power can view name server groups")
|
||||
}
|
||||
|
||||
return am.Store.GetNameServerGroupByID(ctx, LockingStrengthShare, nsGroupID, accountID)
|
||||
return am.Store.GetNameServerGroupByID(ctx, LockingStrengthShare, accountID, nsGroupID)
|
||||
}
|
||||
|
||||
// CreateNameServerGroup creates and saves a new nameserver group
|
||||
@ -103,7 +103,7 @@ func (am *DefaultAccountManager) SaveNameServerGroup(ctx context.Context, accoun
|
||||
return status.Errorf(status.PermissionDenied, "no permission to delete nameserver for this account")
|
||||
}
|
||||
|
||||
_, err = am.Store.GetNameServerGroupByID(ctx, LockingStrengthShare, nsGroupToSave.ID, accountID)
|
||||
_, err = am.Store.GetNameServerGroupByID(ctx, LockingStrengthShare, accountID, nsGroupToSave.ID)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
@ -150,7 +150,7 @@ func (am *DefaultAccountManager) DeleteNameServerGroup(ctx context.Context, acco
|
||||
return status.Errorf(status.PermissionDenied, "no permission to delete nameserver for this account")
|
||||
}
|
||||
|
||||
nsGroup, err := am.Store.GetNameServerGroupByID(ctx, LockingStrengthShare, nsGroupID, accountID)
|
||||
nsGroup, err := am.Store.GetNameServerGroupByID(ctx, LockingStrengthShare, accountID, nsGroupID)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
@ -160,7 +160,7 @@ func (am *DefaultAccountManager) DeleteNameServerGroup(ctx context.Context, acco
|
||||
return fmt.Errorf("failed to increment network serial: %w", err)
|
||||
}
|
||||
|
||||
if err = transaction.DeleteNameServerGroup(ctx, LockingStrengthUpdate, nsGroupID, accountID); err != nil {
|
||||
if err = transaction.DeleteNameServerGroup(ctx, LockingStrengthUpdate, accountID, nsGroupID); err != nil {
|
||||
return fmt.Errorf("failed to delete nameserver group: %w", err)
|
||||
}
|
||||
|
||||
|
@ -117,11 +117,11 @@ func (am *DefaultAccountManager) MarkPeerConnected(ctx context.Context, peerPubK
|
||||
|
||||
if peer.AddedWithSSOLogin() {
|
||||
if peer.LoginExpirationEnabled && account.Settings.PeerLoginExpirationEnabled {
|
||||
am.checkAndSchedulePeerLoginExpiration(ctx, account)
|
||||
am.checkAndSchedulePeerLoginExpiration(ctx, account.Id)
|
||||
}
|
||||
|
||||
if peer.InactivityExpirationEnabled && account.Settings.PeerInactivityExpirationEnabled {
|
||||
am.checkAndSchedulePeerInactivityExpiration(ctx, account)
|
||||
am.checkAndSchedulePeerInactivityExpiration(ctx, account.Id)
|
||||
}
|
||||
}
|
||||
|
||||
@ -230,7 +230,7 @@ func (am *DefaultAccountManager) UpdatePeer(ctx context.Context, accountID, user
|
||||
am.StoreEvent(ctx, userID, peer.IP.String(), accountID, event, peer.EventMeta(am.GetDNSDomain()))
|
||||
|
||||
if peer.AddedWithSSOLogin() && peer.LoginExpirationEnabled && account.Settings.PeerLoginExpirationEnabled {
|
||||
am.checkAndSchedulePeerLoginExpiration(ctx, account)
|
||||
am.checkAndSchedulePeerLoginExpiration(ctx, accountID)
|
||||
}
|
||||
}
|
||||
|
||||
@ -249,7 +249,7 @@ func (am *DefaultAccountManager) UpdatePeer(ctx context.Context, accountID, user
|
||||
am.StoreEvent(ctx, userID, peer.IP.String(), accountID, event, peer.EventMeta(am.GetDNSDomain()))
|
||||
|
||||
if peer.AddedWithSSOLogin() && peer.InactivityExpirationEnabled && account.Settings.PeerInactivityExpirationEnabled {
|
||||
am.checkAndSchedulePeerInactivityExpiration(ctx, account)
|
||||
am.checkAndSchedulePeerInactivityExpiration(ctx, accountID)
|
||||
}
|
||||
}
|
||||
|
||||
@ -537,7 +537,7 @@ func (am *DefaultAccountManager) AddPeer(ctx context.Context, setupKey, userID s
|
||||
return fmt.Errorf("failed to add peer to account: %w", err)
|
||||
}
|
||||
|
||||
err = transaction.IncrementNetworkSerial(ctx, accountID)
|
||||
err = transaction.IncrementNetworkSerial(ctx, LockingStrengthUpdate, accountID)
|
||||
if err != nil {
|
||||
return fmt.Errorf("failed to increment network serial: %w", err)
|
||||
}
|
||||
@ -1041,6 +1041,139 @@ func (am *DefaultAccountManager) updateAccountPeers(ctx context.Context, account
|
||||
wg.Wait()
|
||||
}
|
||||
|
||||
// getNextPeerExpiration returns the minimum duration in which the next peer of the account will expire if it was found.
|
||||
// If there is no peer that expires this function returns false and a duration of 0.
|
||||
// This function only considers peers that haven't been expired yet and that are connected.
|
||||
func (am *DefaultAccountManager) getNextPeerExpiration(ctx context.Context, accountID string) (time.Duration, bool) {
|
||||
peersWithExpiry, err := am.Store.GetAccountPeersWithExpiration(ctx, LockingStrengthShare, accountID)
|
||||
if err != nil {
|
||||
log.WithContext(ctx).Errorf("failed to get peers with expiration: %v", err)
|
||||
return 0, false
|
||||
}
|
||||
|
||||
if len(peersWithExpiry) == 0 {
|
||||
return 0, false
|
||||
}
|
||||
|
||||
settings, err := am.Store.GetAccountSettings(ctx, LockingStrengthShare, accountID)
|
||||
if err != nil {
|
||||
log.WithContext(ctx).Errorf("failed to get account settings: %v", err)
|
||||
return 0, false
|
||||
}
|
||||
|
||||
var nextExpiry *time.Duration
|
||||
for _, peer := range peersWithExpiry {
|
||||
// consider only connected peers because others will require login on connecting to the management server
|
||||
if peer.Status.LoginExpired || !peer.Status.Connected {
|
||||
continue
|
||||
}
|
||||
_, duration := peer.LoginExpired(settings.PeerLoginExpiration)
|
||||
if nextExpiry == nil || duration < *nextExpiry {
|
||||
// if expiration is below 1s return 1s duration
|
||||
// this avoids issues with ticker that can't be set to < 0
|
||||
if duration < time.Second {
|
||||
return time.Second, true
|
||||
}
|
||||
nextExpiry = &duration
|
||||
}
|
||||
}
|
||||
|
||||
if nextExpiry == nil {
|
||||
return 0, false
|
||||
}
|
||||
|
||||
return *nextExpiry, true
|
||||
}
|
||||
|
||||
// GetNextInactivePeerExpiration returns the minimum duration in which the next peer of the account will expire if it was found.
|
||||
// If there is no peer that expires this function returns false and a duration of 0.
|
||||
// This function only considers peers that haven't been expired yet and that are not connected.
|
||||
func (am *DefaultAccountManager) getNextInactivePeerExpiration(ctx context.Context, accountID string) (time.Duration, bool) {
|
||||
peersWithInactivity, err := am.Store.GetAccountPeersWithInactivity(ctx, LockingStrengthShare, accountID)
|
||||
if err != nil {
|
||||
log.WithContext(ctx).Errorf("failed to get peers with inactivity: %v", err)
|
||||
return 0, false
|
||||
}
|
||||
|
||||
if len(peersWithInactivity) == 0 {
|
||||
return 0, false
|
||||
}
|
||||
|
||||
settings, err := am.Store.GetAccountSettings(ctx, LockingStrengthShare, accountID)
|
||||
if err != nil {
|
||||
log.WithContext(ctx).Errorf("failed to get account settings: %v", err)
|
||||
return 0, false
|
||||
}
|
||||
|
||||
var nextExpiry *time.Duration
|
||||
for _, peer := range peersWithInactivity {
|
||||
if peer.Status.LoginExpired || peer.Status.Connected {
|
||||
continue
|
||||
}
|
||||
_, duration := peer.SessionExpired(settings.PeerInactivityExpiration)
|
||||
if nextExpiry == nil || duration < *nextExpiry {
|
||||
// if expiration is below 1s return 1s duration
|
||||
// this avoids issues with ticker that can't be set to < 0
|
||||
if duration < time.Second {
|
||||
return time.Second, true
|
||||
}
|
||||
nextExpiry = &duration
|
||||
}
|
||||
}
|
||||
|
||||
if nextExpiry == nil {
|
||||
return 0, false
|
||||
}
|
||||
|
||||
return *nextExpiry, true
|
||||
}
|
||||
|
||||
// getExpiredPeers returns peers that have been expired.
|
||||
func (am *DefaultAccountManager) getExpiredPeers(ctx context.Context, accountID string) ([]*nbpeer.Peer, error) {
|
||||
peersWithExpiry, err := am.Store.GetAccountPeersWithExpiration(ctx, LockingStrengthShare, accountID)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
settings, err := am.Store.GetAccountSettings(ctx, LockingStrengthShare, accountID)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
var peers []*nbpeer.Peer
|
||||
for _, peer := range peersWithExpiry {
|
||||
expired, _ := peer.LoginExpired(settings.PeerLoginExpiration)
|
||||
if expired {
|
||||
peers = append(peers, peer)
|
||||
}
|
||||
}
|
||||
|
||||
return peers, nil
|
||||
}
|
||||
|
||||
// getInactivePeers returns peers that have been expired by inactivity
|
||||
func (am *DefaultAccountManager) getInactivePeers(ctx context.Context, accountID string) ([]*nbpeer.Peer, error) {
|
||||
peersWithInactivity, err := am.Store.GetAccountPeersWithInactivity(ctx, LockingStrengthShare, accountID)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
settings, err := am.Store.GetAccountSettings(ctx, LockingStrengthShare, accountID)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
var peers []*nbpeer.Peer
|
||||
for _, inactivePeer := range peersWithInactivity {
|
||||
inactive, _ := inactivePeer.SessionExpired(settings.PeerInactivityExpiration)
|
||||
if inactive {
|
||||
peers = append(peers, inactivePeer)
|
||||
}
|
||||
}
|
||||
|
||||
return peers, nil
|
||||
}
|
||||
|
||||
func ConvertSliceToMap(existingLabels []string) map[string]struct{} {
|
||||
labelMap := make(map[string]struct{}, len(existingLabels))
|
||||
for _, label := range existingLabels {
|
||||
|
@ -8,9 +8,8 @@ import (
|
||||
"time"
|
||||
|
||||
b "github.com/hashicorp/go-secure-stdlib/base62"
|
||||
"github.com/rs/xid"
|
||||
|
||||
"github.com/netbirdio/netbird/base62"
|
||||
"github.com/rs/xid"
|
||||
)
|
||||
|
||||
const (
|
||||
|
@ -335,7 +335,7 @@ func (am *DefaultAccountManager) GetPolicy(ctx context.Context, accountID, polic
|
||||
return nil, status.Errorf(status.PermissionDenied, "only users with admin power are allowed to view policies")
|
||||
}
|
||||
|
||||
return am.Store.GetPolicyByID(ctx, LockingStrengthShare, policyID, accountID)
|
||||
return am.Store.GetPolicyByID(ctx, LockingStrengthShare, accountID, policyID)
|
||||
}
|
||||
|
||||
// SavePolicy in the store
|
||||
|
@ -25,7 +25,7 @@ func (am *DefaultAccountManager) GetPostureChecks(ctx context.Context, accountID
|
||||
return nil, status.Errorf(status.PermissionDenied, errMsgPostureAdminOnly)
|
||||
}
|
||||
|
||||
return am.Store.GetPostureChecksByID(ctx, LockingStrengthShare, postureChecksID, accountID)
|
||||
return am.Store.GetPostureChecksByID(ctx, LockingStrengthShare, accountID, postureChecksID)
|
||||
}
|
||||
|
||||
// SavePostureChecks saves a posture check.
|
||||
@ -49,7 +49,7 @@ func (am *DefaultAccountManager) SavePostureChecks(ctx context.Context, accountI
|
||||
if isUpdate {
|
||||
action = activity.PostureCheckUpdated
|
||||
|
||||
if _, err := transaction.GetPostureChecksByID(ctx, LockingStrengthShare, postureChecks.ID, accountID); err != nil {
|
||||
if _, err := transaction.GetPostureChecksByID(ctx, LockingStrengthShare, accountID, postureChecks.ID); err != nil {
|
||||
return fmt.Errorf("failed to get posture checks: %w", err)
|
||||
}
|
||||
|
||||
@ -114,7 +114,7 @@ func (am *DefaultAccountManager) DeletePostureChecks(ctx context.Context, accoun
|
||||
return err
|
||||
}
|
||||
|
||||
postureChecks, err := am.Store.GetPostureChecksByID(ctx, LockingStrengthShare, postureChecksID, accountID)
|
||||
postureChecks, err := am.Store.GetPostureChecksByID(ctx, LockingStrengthShare, accountID, postureChecksID)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
@ -124,7 +124,7 @@ func (am *DefaultAccountManager) DeletePostureChecks(ctx context.Context, accoun
|
||||
return fmt.Errorf("failed to increment network serial: %w", err)
|
||||
}
|
||||
|
||||
if err = transaction.DeletePostureChecks(ctx, LockingStrengthUpdate, postureChecksID, accountID); err != nil {
|
||||
if err = transaction.DeletePostureChecks(ctx, LockingStrengthUpdate, accountID, postureChecksID); err != nil {
|
||||
return fmt.Errorf("failed to delete posture checks: %w", err)
|
||||
}
|
||||
return nil
|
||||
|
@ -56,13 +56,35 @@ func (am *DefaultAccountManager) GetRoute(ctx context.Context, accountID string,
|
||||
return nil, status.Errorf(status.PermissionDenied, "only users with admin power can view Network Routes")
|
||||
}
|
||||
|
||||
return am.Store.GetRouteByID(ctx, LockingStrengthShare, string(routeID), accountID)
|
||||
return am.Store.GetRouteByID(ctx, LockingStrengthShare, accountID, string(routeID))
|
||||
}
|
||||
|
||||
// GetRoutesByPrefixOrDomains return list of routes by account and route prefix
|
||||
func (am *DefaultAccountManager) GetRoutesByPrefixOrDomains(accountID string, prefix netip.Prefix, domains domain.List) ([]*route.Route, error) {
|
||||
accountRoutes, err := am.Store.GetAccountRoutes(context.Background(), LockingStrengthShare, accountID)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
routes := make([]*route.Route, 0)
|
||||
for _, r := range accountRoutes {
|
||||
dynamic := r.IsDynamic()
|
||||
if dynamic && r.Domains.PunycodeString() == domains.PunycodeString() ||
|
||||
!dynamic && r.Network.String() == prefix.String() {
|
||||
routes = append(routes, r)
|
||||
}
|
||||
}
|
||||
|
||||
return routes, nil
|
||||
}
|
||||
|
||||
// checkRoutePrefixOrDomainsExistForPeers checks if a route with a given prefix exists for a single peer or multiple peer groups.
|
||||
func (am *DefaultAccountManager) checkRoutePrefixOrDomainsExistForPeers(account *Account, peerID string, routeID route.ID, peerGroupIDs []string, prefix netip.Prefix, domains domain.List) error {
|
||||
// routes can have both peer and peer_groups
|
||||
routesWithPrefix := account.GetRoutesByPrefixOrDomains(prefix, domains)
|
||||
routesWithPrefix, err := am.GetRoutesByPrefixOrDomains(account.Id, prefix, domains)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
// lets remember all the peers and the peer groups from routesWithPrefix
|
||||
seenPeers := make(map[string]bool)
|
||||
@ -81,8 +103,8 @@ func (am *DefaultAccountManager) checkRoutePrefixOrDomainsExistForPeers(account
|
||||
for _, groupID := range prefixRoute.PeerGroups {
|
||||
seenPeerGroups[groupID] = true
|
||||
|
||||
group := account.GetGroup(groupID)
|
||||
if group == nil {
|
||||
group, err := am.Store.GetGroupByID(context.Background(), LockingStrengthShare, account.Id, groupID)
|
||||
if err != nil || group == nil {
|
||||
return status.Errorf(
|
||||
status.InvalidArgument, "failed to add route with %s - peer group %s doesn't exist",
|
||||
getRouteDescriptor(prefix, domains), groupID,
|
||||
@ -97,10 +119,11 @@ func (am *DefaultAccountManager) checkRoutePrefixOrDomainsExistForPeers(account
|
||||
|
||||
if peerID != "" {
|
||||
// check that peerID exists and is not in any route as single peer or part of the group
|
||||
peer := account.GetPeer(peerID)
|
||||
if peer == nil {
|
||||
peer, err := am.Store.GetPeerByID(context.Background(), LockingStrengthShare, account.Id, peerID)
|
||||
if err != nil || peer == nil {
|
||||
return status.Errorf(status.InvalidArgument, "peer with ID %s not found", peerID)
|
||||
}
|
||||
|
||||
if _, ok := seenPeers[peerID]; ok {
|
||||
return status.Errorf(status.AlreadyExists,
|
||||
"failed to add route with %s - peer %s already has this route", getRouteDescriptor(prefix, domains), peerID)
|
||||
@ -109,7 +132,11 @@ func (am *DefaultAccountManager) checkRoutePrefixOrDomainsExistForPeers(account
|
||||
|
||||
// check that peerGroupIDs are not in any route peerGroups list
|
||||
for _, groupID := range peerGroupIDs {
|
||||
group := account.GetGroup(groupID) // we validated the group existence before entering this function, no need to check again.
|
||||
// we validated the group existence before entering this function, no need to check again.
|
||||
group, err := am.Store.GetGroupByID(context.Background(), LockingStrengthShare, groupID, account.Id)
|
||||
if err != nil || group == nil {
|
||||
return status.Errorf(status.InvalidArgument, "group with ID %s not found", peerID)
|
||||
}
|
||||
|
||||
if _, ok := seenPeerGroups[groupID]; ok {
|
||||
return status.Errorf(
|
||||
@ -120,10 +147,11 @@ func (am *DefaultAccountManager) checkRoutePrefixOrDomainsExistForPeers(account
|
||||
// check that the peers from peerGroupIDs groups are not the same peers we saw in routesWithPrefix
|
||||
for _, id := range group.Peers {
|
||||
if _, ok := seenPeers[id]; ok {
|
||||
peer := account.GetPeer(id)
|
||||
if peer == nil {
|
||||
peer, err := am.Store.GetPeerByID(context.Background(), LockingStrengthShare, peerID, account.Id)
|
||||
if err != nil {
|
||||
return status.Errorf(status.InvalidArgument, "peer with ID %s not found", peerID)
|
||||
}
|
||||
|
||||
return status.Errorf(status.AlreadyExists,
|
||||
"failed to add route with %s - peer %s from the group %s already has this route",
|
||||
getRouteDescriptor(prefix, domains), peer.Name, group.Name)
|
||||
@ -146,6 +174,15 @@ func (am *DefaultAccountManager) CreateRoute(ctx context.Context, accountID stri
|
||||
unlock := am.Store.AcquireWriteLockByUID(ctx, accountID)
|
||||
defer unlock()
|
||||
|
||||
user, err := am.Store.GetUserByUserID(ctx, LockingStrengthShare, userID)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
if user.AccountID != accountID {
|
||||
return nil, status.Errorf(status.PermissionDenied, "the user has no permission to access account data")
|
||||
}
|
||||
|
||||
account, err := am.Store.GetAccount(ctx, accountID)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
@ -181,17 +218,17 @@ func (am *DefaultAccountManager) CreateRoute(ctx context.Context, accountID stri
|
||||
newRoute.ID = route.ID(xid.New().String())
|
||||
|
||||
if len(peerGroupIDs) > 0 {
|
||||
err = validateGroups(peerGroupIDs, account.Groups)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
//err = validateGroups(peerGroupIDs, account.Groups)
|
||||
//if err != nil {
|
||||
// return nil, err
|
||||
//}
|
||||
}
|
||||
|
||||
if len(accessControlGroupIDs) > 0 {
|
||||
err = validateGroups(accessControlGroupIDs, account.Groups)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
//err = validateGroups(accessControlGroupIDs, account.Groups)
|
||||
//if err != nil {
|
||||
// return nil, err
|
||||
//}
|
||||
}
|
||||
|
||||
err = am.checkRoutePrefixOrDomainsExistForPeers(account, peerID, newRoute.ID, peerGroupIDs, prefix, domains)
|
||||
@ -207,10 +244,10 @@ func (am *DefaultAccountManager) CreateRoute(ctx context.Context, accountID stri
|
||||
return nil, status.Errorf(status.InvalidArgument, "identifier should be between 1 and %d", route.MaxNetIDChar)
|
||||
}
|
||||
|
||||
err = validateGroups(groups, account.Groups)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
//err = validateGroups(groups, account.Groups)
|
||||
//if err != nil {
|
||||
// return nil, err
|
||||
//}
|
||||
|
||||
newRoute.Peer = peerID
|
||||
newRoute.PeerGroups = peerGroupIDs
|
||||
@ -290,17 +327,17 @@ func (am *DefaultAccountManager) SaveRoute(ctx context.Context, accountID, userI
|
||||
}
|
||||
|
||||
if len(routeToSave.PeerGroups) > 0 {
|
||||
err = validateGroups(routeToSave.PeerGroups, account.Groups)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
//err = validateGroups(routeToSave.PeerGroups, account.Groups)
|
||||
//if err != nil {
|
||||
// return err
|
||||
//}
|
||||
}
|
||||
|
||||
if len(routeToSave.AccessControlGroups) > 0 {
|
||||
err = validateGroups(routeToSave.AccessControlGroups, account.Groups)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
//err = validateGroups(routeToSave.AccessControlGroups, account.Groups)
|
||||
//if err != nil {
|
||||
// return err
|
||||
//}
|
||||
}
|
||||
|
||||
err = am.checkRoutePrefixOrDomainsExistForPeers(account, routeToSave.Peer, routeToSave.ID, routeToSave.Copy().PeerGroups, routeToSave.Network, routeToSave.Domains)
|
||||
@ -308,10 +345,10 @@ func (am *DefaultAccountManager) SaveRoute(ctx context.Context, accountID, userI
|
||||
return err
|
||||
}
|
||||
|
||||
err = validateGroups(routeToSave.Groups, account.Groups)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
//err = validateGroups(routeToSave.Groups, account.Groups)
|
||||
//if err != nil {
|
||||
// return err
|
||||
//}
|
||||
|
||||
account.Routes[routeToSave.ID] = routeToSave
|
||||
|
||||
|
@ -287,7 +287,7 @@ func (am *DefaultAccountManager) SaveSetupKey(ctx context.Context, accountID str
|
||||
return nil, err
|
||||
}
|
||||
|
||||
oldKey, err := am.Store.GetSetupKeyByID(ctx, LockingStrengthShare, keyToSave.Id, accountID)
|
||||
oldKey, err := am.Store.GetSetupKeyByID(ctx, LockingStrengthShare, accountID, keyToSave.Id)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
@ -382,7 +382,7 @@ func (am *DefaultAccountManager) GetSetupKey(ctx context.Context, accountID, use
|
||||
return nil, status.Errorf(status.Unauthorized, "only users with admin power can view setup keys")
|
||||
}
|
||||
|
||||
setupKey, err := am.Store.GetSetupKeyByID(ctx, LockingStrengthShare, keyID, accountID)
|
||||
setupKey, err := am.Store.GetSetupKeyByID(ctx, LockingStrengthShare, accountID, keyID)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
@ -541,9 +541,9 @@ func (s *SqlStore) GetUserByUserID(ctx context.Context, lockStrength LockingStre
|
||||
return &user, nil
|
||||
}
|
||||
|
||||
func (s *SqlStore) GetAccountUsers(ctx context.Context, accountID string) ([]*User, error) {
|
||||
func (s *SqlStore) GetAccountUsers(ctx context.Context, lockStrength LockingStrength, accountID string) ([]*User, error) {
|
||||
var users []*User
|
||||
result := s.db.Find(&users, accountIDCondition, accountID)
|
||||
result := s.db.WithContext(ctx).Clauses(clause.Locking{Strength: string(lockStrength)}).Find(&users, accountIDCondition, accountID)
|
||||
if result.Error != nil {
|
||||
if errors.Is(result.Error, gorm.ErrRecordNotFound) {
|
||||
return nil, status.Errorf(status.NotFound, "accountID not found: index lookup failed")
|
||||
@ -828,7 +828,9 @@ func (s *SqlStore) GetPeerByPeerPubKey(ctx context.Context, lockStrength Locking
|
||||
|
||||
func (s *SqlStore) GetAccountSettings(ctx context.Context, lockStrength LockingStrength, accountID string) (*Settings, error) {
|
||||
var accountSettings AccountSettings
|
||||
if err := s.db.WithContext(ctx).Clauses(clause.Locking{Strength: string(lockStrength)}).Model(&Account{}).Where(idQueryCondition, accountID).First(&accountSettings).Error; err != nil {
|
||||
result := s.db.WithContext(ctx).Clauses(clause.Locking{Strength: string(lockStrength)}).Model(&Account{}).
|
||||
First(&accountSettings, idQueryCondition, accountID)
|
||||
if err := result.Error; err != nil {
|
||||
if errors.Is(err, gorm.ErrRecordNotFound) {
|
||||
return nil, status.Errorf(status.NotFound, "settings not found")
|
||||
}
|
||||
@ -837,6 +839,21 @@ func (s *SqlStore) GetAccountSettings(ctx context.Context, lockStrength LockingS
|
||||
return accountSettings.Settings, nil
|
||||
}
|
||||
|
||||
// SaveAccountSettings stores the account settings in DB.
|
||||
func (s *SqlStore) SaveAccountSettings(ctx context.Context, lockStrength LockingStrength, accountID string, settings *Settings) error {
|
||||
result := s.db.WithContext(ctx).Debug().Clauses(clause.Locking{Strength: string(lockStrength)}).Model(&Account{}).
|
||||
Select("*").Where(idQueryCondition, accountID).Updates(&AccountSettings{Settings: settings})
|
||||
if result.Error != nil {
|
||||
return status.Errorf(status.Internal, "failed to save account settings to store: %v", result.Error)
|
||||
}
|
||||
|
||||
if result.RowsAffected == 0 {
|
||||
return status.Errorf(status.NotFound, "account not found")
|
||||
}
|
||||
|
||||
return nil
|
||||
}
|
||||
|
||||
// SaveUserLastLogin stores the last login time for a user in DB.
|
||||
func (s *SqlStore) SaveUserLastLogin(ctx context.Context, accountID, userID string, lastLogin time.Time) error {
|
||||
var user User
|
||||
@ -1054,9 +1071,72 @@ func (s *SqlStore) AddPeerToGroup(ctx context.Context, accountId string, peerId
|
||||
return nil
|
||||
}
|
||||
|
||||
// GetAccountPeers retrieves peers for an account.
|
||||
func (s *SqlStore) GetAccountPeers(ctx context.Context, lockStrength LockingStrength, accountID string) ([]*nbpeer.Peer, error) {
|
||||
var peers []*nbpeer.Peer
|
||||
result := s.db.WithContext(ctx).Clauses(clause.Locking{Strength: string(lockStrength)}).Find(&peers, accountIDCondition, accountID)
|
||||
if err := result.Error; err != nil {
|
||||
log.WithContext(ctx).Errorf("failed to get peers from the store: %s", err)
|
||||
return nil, status.Errorf(status.Internal, "failed to get peers from store")
|
||||
}
|
||||
|
||||
return peers, nil
|
||||
}
|
||||
|
||||
// GetUserPeers retrieves peers for a user.
|
||||
func (s *SqlStore) GetUserPeers(ctx context.Context, lockStrength LockingStrength, accountID, userID string) ([]*nbpeer.Peer, error) {
|
||||
return getRecords[*nbpeer.Peer](s.db.WithContext(ctx).Where("user_id = ?", userID), lockStrength, accountID)
|
||||
var peers []*nbpeer.Peer
|
||||
result := s.db.WithContext(ctx).Clauses(clause.Locking{Strength: string(lockStrength)}).Find(&peers, "account_id = ? AND user_id = ?", accountID, userID)
|
||||
if err := result.Error; err != nil {
|
||||
log.WithContext(ctx).Errorf("failed to get peers from the store: %s", err)
|
||||
return nil, status.Errorf(status.Internal, "failed to get peers from store")
|
||||
}
|
||||
|
||||
return peers, nil
|
||||
}
|
||||
|
||||
// GetAccountPeersWithExpiration retrieves a list of peers that have login expiration enabled and added by a user.
|
||||
func (s *SqlStore) GetAccountPeersWithExpiration(ctx context.Context, lockStrength LockingStrength, accountID string) ([]*nbpeer.Peer, error) {
|
||||
var peers []*nbpeer.Peer
|
||||
result := s.db.WithContext(ctx).Clauses(clause.Locking{Strength: string(lockStrength)}).
|
||||
Where("login_expiration_enabled = ? AND user_id IS NOT NULL AND user_id != ''", true).
|
||||
Find(&peers, accountIDCondition, accountID)
|
||||
if err := result.Error; err != nil {
|
||||
log.WithContext(ctx).Errorf("failed to get peers with expiration from the store: %s", result.Error)
|
||||
return nil, status.Errorf(status.Internal, "failed to get peers with expiration from store")
|
||||
}
|
||||
|
||||
return peers, nil
|
||||
}
|
||||
|
||||
// GetAccountPeersWithInactivity retrieves a list of peers that have login expiration enabled and added by a user.
|
||||
func (s *SqlStore) GetAccountPeersWithInactivity(ctx context.Context, lockStrength LockingStrength, accountID string) ([]*nbpeer.Peer, error) {
|
||||
var peers []*nbpeer.Peer
|
||||
result := s.db.WithContext(ctx).Clauses(clause.Locking{Strength: string(lockStrength)}).
|
||||
Where("inactivity_expiration_enabled = ? AND user_id IS NOT NULL AND user_id != ''", true).
|
||||
Find(&peers, accountIDCondition, accountID)
|
||||
if err := result.Error; err != nil {
|
||||
log.WithContext(ctx).Errorf("failed to get peers with inactivity from the store: %s", result.Error)
|
||||
return nil, status.Errorf(status.Internal, "failed to get peers with inactivity from store")
|
||||
}
|
||||
|
||||
return peers, nil
|
||||
}
|
||||
|
||||
// GetPeerByID retrieves a peer by its ID and account ID.
|
||||
func (s *SqlStore) GetPeerByID(ctx context.Context, lockStrength LockingStrength, accountID, peerID string) (*nbpeer.Peer, error) {
|
||||
var peer *nbpeer.Peer
|
||||
result := s.db.WithContext(ctx).Clauses(clause.Locking{Strength: string(lockStrength)}).
|
||||
First(&peer, accountAndIDQueryCondition, accountID, peerID)
|
||||
if err := result.Error; err != nil {
|
||||
if errors.Is(result.Error, gorm.ErrRecordNotFound) {
|
||||
return nil, status.Errorf(status.NotFound, "peer not found")
|
||||
}
|
||||
log.WithContext(ctx).Errorf("failed to get peer from the store: %s", result.Error)
|
||||
return nil, status.Errorf(status.Internal, "failed to get peer from store")
|
||||
}
|
||||
|
||||
return peer, nil
|
||||
}
|
||||
|
||||
func (s *SqlStore) AddPeerToAccount(ctx context.Context, peer *nbpeer.Peer) error {
|
||||
@ -1067,8 +1147,9 @@ func (s *SqlStore) AddPeerToAccount(ctx context.Context, peer *nbpeer.Peer) erro
|
||||
return nil
|
||||
}
|
||||
|
||||
func (s *SqlStore) IncrementNetworkSerial(ctx context.Context, accountId string) error {
|
||||
result := s.db.WithContext(ctx).Model(&Account{}).Where(idQueryCondition, accountId).Update("network_serial", gorm.Expr("network_serial + 1"))
|
||||
func (s *SqlStore) IncrementNetworkSerial(ctx context.Context, lockStrength LockingStrength, accountId string) error {
|
||||
result := s.db.WithContext(ctx).Clauses(clause.Locking{Strength: string(lockStrength)}).Model(&Account{}).
|
||||
Where(idQueryCondition, accountId).Update("network_serial", gorm.Expr("network_serial + 1"))
|
||||
if result.Error != nil {
|
||||
return status.Errorf(status.Internal, "issue incrementing network serial count: %s", result.Error)
|
||||
}
|
||||
@ -1113,6 +1194,19 @@ func (s *SqlStore) GetAccountDNSSettings(ctx context.Context, lockStrength Locki
|
||||
return &accountDNSSettings.DNSSettings, nil
|
||||
}
|
||||
|
||||
// SaveDNSSettings saves the DNS settings to the store.
|
||||
func (s *SqlStore) SaveDNSSettings(ctx context.Context, lockStrength LockingStrength, accountID string, settings *DNSSettings) error {
|
||||
result := s.db.WithContext(ctx).Clauses(clause.Locking{Strength: string(lockStrength)}).Model(&Account{}).
|
||||
Where(idQueryCondition, accountID).Updates(&AccountDNSSettings{DNSSettings: *settings})
|
||||
if result.Error != nil {
|
||||
return status.Errorf(status.Internal, "failed to save dns settings to store: %v", result.Error)
|
||||
}
|
||||
if result.RowsAffected == 0 {
|
||||
return status.Errorf(status.NotFound, "account not found")
|
||||
}
|
||||
return nil
|
||||
}
|
||||
|
||||
// AccountExists checks whether an account exists by the given ID.
|
||||
func (s *SqlStore) AccountExists(ctx context.Context, lockStrength LockingStrength, id string) (bool, error) {
|
||||
var accountID string
|
||||
@ -1146,16 +1240,24 @@ func (s *SqlStore) GetAccountDomainAndCategory(ctx context.Context, lockStrength
|
||||
}
|
||||
|
||||
// GetGroupByID retrieves a group by ID and account ID.
|
||||
func (s *SqlStore) GetGroupByID(ctx context.Context, lockStrength LockingStrength, groupID, accountID string) (*nbgroup.Group, error) {
|
||||
return getRecordByID[nbgroup.Group](s.db.WithContext(ctx).Preload(clause.Associations), lockStrength, groupID, accountID)
|
||||
func (s *SqlStore) GetGroupByID(ctx context.Context, lockStrength LockingStrength, accountID, groupID string) (*nbgroup.Group, error) {
|
||||
var group *nbgroup.Group
|
||||
result := s.db.WithContext(ctx).Clauses(clause.Locking{Strength: string(lockStrength)}).
|
||||
Find(&group, accountAndIDQueryCondition, accountID, groupID)
|
||||
if err := result.Error; err != nil {
|
||||
log.WithContext(ctx).Errorf("failed to get group from the store: %s", err)
|
||||
return nil, status.Errorf(status.Internal, "failed to get group from store")
|
||||
}
|
||||
|
||||
return group, nil
|
||||
}
|
||||
|
||||
// GetGroupByName retrieves a group by name and account ID.
|
||||
func (s *SqlStore) GetGroupByName(ctx context.Context, lockStrength LockingStrength, groupName, accountID string) (*nbgroup.Group, error) {
|
||||
func (s *SqlStore) GetGroupByName(ctx context.Context, lockStrength LockingStrength, accountID, groupName string) (*nbgroup.Group, error) {
|
||||
var group nbgroup.Group
|
||||
|
||||
result := s.db.WithContext(ctx).Clauses(clause.Locking{Strength: string(lockStrength)}).Preload(clause.Associations).
|
||||
Order("json_array_length(peers) DESC").First(&group, "name = ? and account_id = ?", groupName, accountID)
|
||||
Order("json_array_length(peers) DESC").First(&group, "account_id = ? AND name = ?", accountID, groupName)
|
||||
if err := result.Error; err != nil {
|
||||
if errors.Is(result.Error, gorm.ErrRecordNotFound) {
|
||||
return nil, status.Errorf(status.NotFound, "group not found")
|
||||
@ -1174,69 +1276,335 @@ func (s *SqlStore) SaveGroup(ctx context.Context, lockStrength LockingStrength,
|
||||
return nil
|
||||
}
|
||||
|
||||
// DeleteGroup deletes a group from the database.
|
||||
func (s *SqlStore) DeleteGroup(ctx context.Context, lockStrength LockingStrength, accountID, groupID string) error {
|
||||
result := s.db.WithContext(ctx).Clauses(clause.Locking{Strength: string(lockStrength)}).
|
||||
Delete(&nbgroup.Group{}, accountAndIDQueryCondition, accountID, groupID)
|
||||
if err := result.Error; err != nil {
|
||||
log.WithContext(ctx).Errorf("failed to delete group from the store: %s", result.Error)
|
||||
return status.Errorf(status.Internal, "failed to delete group from store")
|
||||
}
|
||||
|
||||
if result.RowsAffected == 0 {
|
||||
return status.Errorf(status.NotFound, "group not found")
|
||||
}
|
||||
|
||||
return nil
|
||||
}
|
||||
|
||||
// DeleteGroups deletes groups from the database.
|
||||
func (s *SqlStore) DeleteGroups(ctx context.Context, strength LockingStrength, accountID string, groupIDs []string) error {
|
||||
result := s.db.WithContext(ctx).Clauses(clause.Locking{Strength: string(strength)}).
|
||||
Delete(&nbgroup.Group{}, " account_id = ? AND id IN ?", accountID, groupIDs)
|
||||
if result.Error != nil {
|
||||
log.WithContext(ctx).Errorf("failed to delete groups from the store: %v", result.Error)
|
||||
return status.Errorf(status.Internal, "failed to delete groups from store: %v", result.Error)
|
||||
}
|
||||
|
||||
return nil
|
||||
}
|
||||
|
||||
// GetAccountPolicies retrieves policies for an account.
|
||||
func (s *SqlStore) GetAccountPolicies(ctx context.Context, lockStrength LockingStrength, accountID string) ([]*Policy, error) {
|
||||
return getRecords[*Policy](s.db.WithContext(ctx).Preload(clause.Associations), lockStrength, accountID)
|
||||
var policies []*Policy
|
||||
result := s.db.WithContext(ctx).Clauses(clause.Locking{Strength: string(lockStrength)}).
|
||||
Preload(clause.Associations).Find(&policies, accountIDCondition, accountID)
|
||||
if err := result.Error; err != nil {
|
||||
log.WithContext(ctx).Errorf("failed to get policies from the store: %s", result.Error)
|
||||
return nil, status.Errorf(status.Internal, "failed to get policies from store")
|
||||
}
|
||||
|
||||
return policies, nil
|
||||
}
|
||||
|
||||
// GetPolicyByID retrieves a policy by its ID and account ID.
|
||||
func (s *SqlStore) GetPolicyByID(ctx context.Context, lockStrength LockingStrength, policyID string, accountID string) (*Policy, error) {
|
||||
return getRecordByID[Policy](s.db.WithContext(ctx).Preload(clause.Associations), lockStrength, policyID, accountID)
|
||||
func (s *SqlStore) GetPolicyByID(ctx context.Context, lockStrength LockingStrength, accountID, policyID string) (*Policy, error) {
|
||||
var policy *Policy
|
||||
result := s.db.WithContext(ctx).Clauses(clause.Locking{Strength: string(lockStrength)}).
|
||||
Preload(clause.Associations).Find(&policy, accountAndIDQueryCondition, accountID, policyID)
|
||||
if err := result.Error; err != nil {
|
||||
log.WithContext(ctx).Errorf("failed to get policy from the store: %s", err)
|
||||
return nil, status.Errorf(status.Internal, "failed to get policy from store")
|
||||
}
|
||||
|
||||
return policy, nil
|
||||
}
|
||||
|
||||
// SavePolicy saves a policy to the database.
|
||||
func (s *SqlStore) SavePolicy(ctx context.Context, lockStrength LockingStrength, policy *Policy) error {
|
||||
result := s.db.WithContext(ctx).Session(&gorm.Session{FullSaveAssociations: true}).
|
||||
Clauses(clause.Locking{Strength: string(lockStrength)}).Save(policy)
|
||||
if err := result.Error; err != nil {
|
||||
log.WithContext(ctx).Errorf("failed to save policy to the store: %s", err)
|
||||
return status.Errorf(status.Internal, "failed to save policy to store")
|
||||
}
|
||||
return nil
|
||||
}
|
||||
|
||||
func (s *SqlStore) DeletePolicy(ctx context.Context, lockStrength LockingStrength, policyID, accountID string) error {
|
||||
result := s.db.WithContext(ctx).Clauses(clause.Locking{Strength: string(lockStrength)}).
|
||||
Delete(&Policy{}, accountAndIDQueryCondition, accountID, policyID)
|
||||
if err := result.Error; err != nil {
|
||||
log.WithContext(ctx).Errorf("failed to delete policy from the store: %s", err)
|
||||
return status.Errorf(status.Internal, "failed to delete policy from store")
|
||||
}
|
||||
|
||||
if result.RowsAffected == 0 {
|
||||
return status.Errorf(status.NotFound, "policy not found")
|
||||
}
|
||||
|
||||
return nil
|
||||
}
|
||||
|
||||
// GetAccountPostureChecks retrieves posture checks for an account.
|
||||
func (s *SqlStore) GetAccountPostureChecks(ctx context.Context, lockStrength LockingStrength, accountID string) ([]*posture.Checks, error) {
|
||||
return getRecords[*posture.Checks](s.db.WithContext(ctx), lockStrength, accountID)
|
||||
var postureChecks []*posture.Checks
|
||||
result := s.db.WithContext(ctx).Clauses(clause.Locking{Strength: string(lockStrength)}).Find(&postureChecks, accountIDCondition, accountID)
|
||||
if err := result.Error; err != nil {
|
||||
log.WithContext(ctx).Errorf("failed to get posture checks from the store: %s", err)
|
||||
return nil, status.Errorf(status.Internal, "failed to get posture checks from store")
|
||||
}
|
||||
|
||||
return postureChecks, nil
|
||||
}
|
||||
|
||||
// GetPostureChecksByID retrieves posture checks by their ID and account ID.
|
||||
func (s *SqlStore) GetPostureChecksByID(ctx context.Context, lockStrength LockingStrength, postureCheckID string, accountID string) (*posture.Checks, error) {
|
||||
return getRecordByID[posture.Checks](s.db.WithContext(ctx), lockStrength, postureCheckID, accountID)
|
||||
func (s *SqlStore) GetPostureChecksByID(ctx context.Context, lockStrength LockingStrength, accountID, postureCheckID string) (*posture.Checks, error) {
|
||||
var postureCheck *posture.Checks
|
||||
result := s.db.WithContext(ctx).Clauses(clause.Locking{Strength: string(lockStrength)}).
|
||||
First(&postureCheck, accountAndIDQueryCondition, accountID, postureCheckID)
|
||||
if err := result.Error; err != nil {
|
||||
if errors.Is(err, gorm.ErrRecordNotFound) {
|
||||
return nil, status.Errorf(status.NotFound, "posture check not found")
|
||||
}
|
||||
log.WithContext(ctx).Errorf("failed to get posture check from the store: %s", err)
|
||||
return nil, status.Errorf(status.Internal, "failed to get posture check from store")
|
||||
}
|
||||
|
||||
return postureCheck, nil
|
||||
}
|
||||
|
||||
// SavePostureChecks saves a posture checks to the database.
|
||||
func (s *SqlStore) SavePostureChecks(ctx context.Context, lockStrength LockingStrength, postureCheck *posture.Checks) error {
|
||||
result := s.db.WithContext(ctx).Clauses(clause.Locking{Strength: string(lockStrength)}).Save(postureCheck)
|
||||
if err := result.Error; err != nil {
|
||||
if errors.Is(err, gorm.ErrDuplicatedKey) {
|
||||
return status.Errorf(status.InvalidArgument, "name should be unique")
|
||||
}
|
||||
log.WithContext(ctx).Errorf("failed to save posture checks to the store: %s", err)
|
||||
return status.Errorf(status.Internal, "failed to save posture checks to store")
|
||||
}
|
||||
|
||||
return nil
|
||||
}
|
||||
|
||||
// DeletePostureChecks deletes a posture checks from the database.
|
||||
func (s *SqlStore) DeletePostureChecks(ctx context.Context, lockStrength LockingStrength, accountID, postureChecksID string) error {
|
||||
result := s.db.WithContext(ctx).Clauses(clause.Locking{Strength: string(lockStrength)}).
|
||||
Delete(&posture.Checks{}, accountAndIDQueryCondition, accountID, postureChecksID)
|
||||
if err := result.Error; err != nil {
|
||||
log.WithContext(ctx).Errorf("failed to delete posture checks from the store: %s", err)
|
||||
return status.Errorf(status.Internal, "failed to delete posture checks from store")
|
||||
}
|
||||
|
||||
if result.RowsAffected == 0 {
|
||||
return status.Errorf(status.NotFound, "posture checks not found")
|
||||
}
|
||||
|
||||
return nil
|
||||
}
|
||||
|
||||
// GetAccountRoutes retrieves network routes for an account.
|
||||
func (s *SqlStore) GetAccountRoutes(ctx context.Context, lockStrength LockingStrength, accountID string) ([]*route.Route, error) {
|
||||
return getRecords[*route.Route](s.db.WithContext(ctx), lockStrength, accountID)
|
||||
var routes []*route.Route
|
||||
result := s.db.WithContext(ctx).Clauses(clause.Locking{Strength: string(lockStrength)}).
|
||||
Find(&routes, accountIDCondition, accountID)
|
||||
if err := result.Error; err != nil {
|
||||
log.WithContext(ctx).Errorf("failed to get routes from the store: %s", err)
|
||||
return nil, status.Errorf(status.Internal, "failed to get routes from store")
|
||||
}
|
||||
|
||||
return routes, nil
|
||||
}
|
||||
|
||||
// GetRouteByID retrieves a route by its ID and account ID.
|
||||
func (s *SqlStore) GetRouteByID(ctx context.Context, lockStrength LockingStrength, routeID string, accountID string) (*route.Route, error) {
|
||||
return getRecordByID[route.Route](s.db.WithContext(ctx), lockStrength, routeID, accountID)
|
||||
func (s *SqlStore) GetRouteByID(ctx context.Context, lockStrength LockingStrength, accountID string, routeID string) (*route.Route, error) {
|
||||
var route *route.Route
|
||||
result := s.db.WithContext(ctx).Clauses(clause.Locking{Strength: string(lockStrength)}).
|
||||
First(&route, accountAndIDQueryCondition, accountID, routeID)
|
||||
if err := result.Error; err != nil {
|
||||
if errors.Is(err, gorm.ErrRecordNotFound) {
|
||||
return nil, status.Errorf(status.NotFound, "route not found")
|
||||
}
|
||||
log.WithContext(ctx).Errorf("failed to get route from the store: %s", err)
|
||||
return nil, status.Errorf(status.Internal, "failed to get route from store")
|
||||
}
|
||||
|
||||
return route, nil
|
||||
}
|
||||
|
||||
// SaveRoute saves a route to the database.
|
||||
func (s *SqlStore) SaveRoute(ctx context.Context, lockStrength LockingStrength, route *route.Route) error {
|
||||
result := s.db.WithContext(ctx).Clauses(clause.Locking{Strength: string(lockStrength)}).Save(route)
|
||||
if err := result.Error; err != nil {
|
||||
log.WithContext(ctx).Errorf("failed to save route to the store: %s", err)
|
||||
return status.Errorf(status.Internal, "failed to save route to store")
|
||||
}
|
||||
|
||||
return nil
|
||||
}
|
||||
|
||||
// DeleteRoute deletes a route from the database.
|
||||
func (s *SqlStore) DeleteRoute(ctx context.Context, lockStrength LockingStrength, accountID, routeID string) error {
|
||||
result := s.db.Clauses(clause.Locking{Strength: string(lockStrength)}).
|
||||
Delete(&route.Route{}, accountAndIDQueryCondition, accountID, routeID)
|
||||
if err := result.Error; err != nil {
|
||||
log.WithContext(ctx).Errorf("failed to delete route from the store: %s", err)
|
||||
return status.Errorf(status.Internal, "failed to delete route from store")
|
||||
}
|
||||
|
||||
if result.RowsAffected == 0 {
|
||||
return status.Errorf(status.NotFound, "route not found")
|
||||
}
|
||||
|
||||
return nil
|
||||
}
|
||||
|
||||
// GetAccountSetupKeys retrieves setup keys for an account.
|
||||
func (s *SqlStore) GetAccountSetupKeys(ctx context.Context, lockStrength LockingStrength, accountID string) ([]*SetupKey, error) {
|
||||
return getRecords[*SetupKey](s.db.WithContext(ctx), lockStrength, accountID)
|
||||
var setupKeys []*SetupKey
|
||||
result := s.db.WithContext(ctx).Clauses(clause.Locking{Strength: string(lockStrength)}).
|
||||
Find(&setupKeys, accountIDCondition, accountID)
|
||||
if err := result.Error; err != nil {
|
||||
log.WithContext(ctx).Errorf("failed to get setup keys from the store: %s", err)
|
||||
return nil, status.Errorf(status.Internal, "failed to get setup keys from store")
|
||||
}
|
||||
|
||||
return setupKeys, nil
|
||||
}
|
||||
|
||||
// GetSetupKeyByID retrieves a setup key by its ID and account ID.
|
||||
func (s *SqlStore) GetSetupKeyByID(ctx context.Context, lockStrength LockingStrength, setupKeyID string, accountID string) (*SetupKey, error) {
|
||||
return getRecordByID[SetupKey](s.db.WithContext(ctx), lockStrength, setupKeyID, accountID)
|
||||
func (s *SqlStore) GetSetupKeyByID(ctx context.Context, lockStrength LockingStrength, accountID, setupKeyID string) (*SetupKey, error) {
|
||||
var setupKey *SetupKey
|
||||
result := s.db.WithContext(ctx).Clauses(clause.Locking{Strength: string(lockStrength)}).
|
||||
First(&setupKey, accountAndIDQueryCondition, accountID, setupKeyID)
|
||||
if err := result.Error; err != nil {
|
||||
if errors.Is(err, gorm.ErrRecordNotFound) {
|
||||
return nil, status.Errorf(status.NotFound, "setup key not found")
|
||||
}
|
||||
log.WithContext(ctx).Errorf("failed to get setup key from the store: %s", err)
|
||||
return nil, status.Errorf(status.Internal, "failed to get setup key from store")
|
||||
}
|
||||
|
||||
return setupKey, nil
|
||||
}
|
||||
|
||||
// SaveSetupKey saves a setup key to the database.
|
||||
func (s *SqlStore) SaveSetupKey(ctx context.Context, lockStrength LockingStrength, setupKey *SetupKey) error {
|
||||
result := s.db.WithContext(ctx).Session(&gorm.Session{FullSaveAssociations: true}).
|
||||
Clauses(clause.Locking{Strength: string(lockStrength)}).Save(setupKey)
|
||||
if err := result.Error; err != nil {
|
||||
log.WithContext(ctx).Errorf("failed to save name server group to the store: %s", err)
|
||||
return status.Errorf(status.Internal, "failed to save name server group to store")
|
||||
}
|
||||
|
||||
return nil
|
||||
}
|
||||
|
||||
// GetAccountNameServerGroups retrieves name server groups for an account.
|
||||
func (s *SqlStore) GetAccountNameServerGroups(ctx context.Context, lockStrength LockingStrength, accountID string) ([]*nbdns.NameServerGroup, error) {
|
||||
return getRecords[*nbdns.NameServerGroup](s.db.WithContext(ctx), lockStrength, accountID)
|
||||
var nsGroups []*nbdns.NameServerGroup
|
||||
result := s.db.WithContext(ctx).Clauses(clause.Locking{Strength: string(lockStrength)}).Find(&nsGroups, accountIDCondition, accountID)
|
||||
if err := result.Error; err != nil {
|
||||
log.WithContext(ctx).Errorf("failed to get name server groups from the store: %s", err)
|
||||
return nil, status.Errorf(status.Internal, "failed to get name server groups from store")
|
||||
}
|
||||
|
||||
return nsGroups, nil
|
||||
}
|
||||
|
||||
// GetNameServerGroupByID retrieves a name server group by its ID and account ID.
|
||||
func (s *SqlStore) GetNameServerGroupByID(ctx context.Context, lockStrength LockingStrength, nsGroupID string, accountID string) (*nbdns.NameServerGroup, error) {
|
||||
return getRecordByID[nbdns.NameServerGroup](s.db.WithContext(ctx), lockStrength, nsGroupID, accountID)
|
||||
}
|
||||
|
||||
// getRecords retrieves records from the database based on the account ID.
|
||||
func getRecords[T any](db *gorm.DB, lockStrength LockingStrength, accountID string) ([]T, error) {
|
||||
var record []T
|
||||
|
||||
result := db.Clauses(clause.Locking{Strength: string(lockStrength)}).Find(&record, accountIDCondition, accountID)
|
||||
func (s *SqlStore) GetNameServerGroupByID(ctx context.Context, lockStrength LockingStrength, accountID, nsGroupID string) (*nbdns.NameServerGroup, error) {
|
||||
var nsGroup *nbdns.NameServerGroup
|
||||
result := s.db.WithContext(ctx).Clauses(clause.Locking{Strength: string(lockStrength)}).
|
||||
First(&nsGroup, accountAndIDQueryCondition, accountID, nsGroupID)
|
||||
if err := result.Error; err != nil {
|
||||
parts := strings.Split(fmt.Sprintf("%T", record), ".")
|
||||
recordType := parts[len(parts)-1]
|
||||
|
||||
return nil, status.Errorf(status.Internal, "failed to get account %ss from store: %v", recordType, err)
|
||||
if errors.Is(err, gorm.ErrRecordNotFound) {
|
||||
return nil, status.Errorf(status.NotFound, "name server group not found")
|
||||
}
|
||||
log.WithContext(ctx).Errorf("failed to get name server group from the store: %s", err)
|
||||
return nil, status.Errorf(status.Internal, "failed to get name server group from store")
|
||||
}
|
||||
|
||||
return record, nil
|
||||
return nsGroup, nil
|
||||
}
|
||||
|
||||
// SaveNameServerGroup saves a name server group to the database.
|
||||
func (s *SqlStore) SaveNameServerGroup(ctx context.Context, lockStrength LockingStrength, nameServerGroup *nbdns.NameServerGroup) error {
|
||||
result := s.db.WithContext(ctx).Clauses(clause.Locking{Strength: string(lockStrength)}).Save(nameServerGroup)
|
||||
if err := result.Error; err != nil {
|
||||
log.WithContext(ctx).Errorf("failed to save name server group to the store: %s", err)
|
||||
return status.Errorf(status.Internal, "failed to save name server group to store")
|
||||
}
|
||||
return nil
|
||||
}
|
||||
|
||||
// DeleteNameServerGroup deletes a name server group from the database.
|
||||
func (s *SqlStore) DeleteNameServerGroup(ctx context.Context, lockStrength LockingStrength, accountID, nameServerGroupID string) error {
|
||||
result := s.db.Clauses(clause.Locking{Strength: string(lockStrength)}).Delete(&nbdns.NameServerGroup{}, accountAndIDQueryCondition, accountID, nameServerGroupID)
|
||||
if err := result.Error; err != nil {
|
||||
log.WithContext(ctx).Errorf("failed to delete name server group from the store: %s", err)
|
||||
return status.Errorf(status.Internal, "failed to delete name server group from store")
|
||||
}
|
||||
|
||||
if result.RowsAffected == 0 {
|
||||
return status.Errorf(status.NotFound, "name server group not found")
|
||||
}
|
||||
|
||||
return nil
|
||||
}
|
||||
|
||||
// GetPATByID retrieves a personal access token by its ID and user ID.
|
||||
func (s *SqlStore) GetPATByID(ctx context.Context, lockStrength LockingStrength, patID string, userID string) (*PersonalAccessToken, error) {
|
||||
var pat PersonalAccessToken
|
||||
result := s.db.WithContext(ctx).Clauses(clause.Locking{Strength: string(lockStrength)}).
|
||||
First(&pat, "id = ? AND user_id = ?", patID, userID)
|
||||
if err := result.Error; err != nil {
|
||||
if errors.Is(result.Error, gorm.ErrRecordNotFound) {
|
||||
return nil, status.Errorf(status.NotFound, "PAT not found")
|
||||
}
|
||||
log.WithContext(ctx).Errorf("failed to get PAT from the store: %s", err)
|
||||
return nil, status.Errorf(status.Internal, "failed to get PAT from store")
|
||||
}
|
||||
|
||||
return &pat, nil
|
||||
}
|
||||
|
||||
// SavePAT saves a personal access token to the database.
|
||||
func (s *SqlStore) SavePAT(ctx context.Context, lockStrength LockingStrength, pat *PersonalAccessToken) error {
|
||||
result := s.db.WithContext(ctx).Clauses(clause.Locking{Strength: string(lockStrength)}).Save(pat)
|
||||
if err := result.Error; err != nil {
|
||||
log.WithContext(ctx).Errorf("failed to save PAT to the store: %s", err)
|
||||
return status.Errorf(status.Internal, "failed to save PAT to store")
|
||||
}
|
||||
|
||||
return nil
|
||||
}
|
||||
|
||||
// DeletePAT deletes a personal access token from the database.
|
||||
func (s *SqlStore) DeletePAT(ctx context.Context, lockStrength LockingStrength, userID, patID string) error {
|
||||
result := s.db.WithContext(ctx).Clauses(clause.Locking{Strength: string(lockStrength)}).
|
||||
Delete(&PersonalAccessToken{}, "user_id = ? AND id = ?", userID, patID)
|
||||
if err := result.Error; err != nil {
|
||||
log.WithContext(ctx).Errorf("failed to delete PAT from the store: %s", err)
|
||||
return status.Errorf(status.Internal, "failed to delete PAT from store")
|
||||
}
|
||||
|
||||
if result.RowsAffected == 0 {
|
||||
return status.Errorf(status.NotFound, "PAT not found")
|
||||
}
|
||||
|
||||
return nil
|
||||
}
|
||||
|
||||
// getRecordByID retrieves a record by its ID and account ID from the database.
|
||||
|
@ -1181,7 +1181,7 @@ func TestSqlite_CreateAndGetObjectInTransaction(t *testing.T) {
|
||||
t.Fatal("failed to save group")
|
||||
return err
|
||||
}
|
||||
group, err = transaction.GetGroupByID(context.Background(), LockingStrengthUpdate, group.ID, group.AccountID)
|
||||
group, err = transaction.GetGroupByID(context.Background(), LockingStrengthUpdate, group.AccountID, group.ID)
|
||||
if err != nil {
|
||||
t.Fatal("failed to get group")
|
||||
return err
|
||||
|
@ -56,13 +56,15 @@ type Store interface {
|
||||
GetAccountIDByPrivateDomain(ctx context.Context, lockStrength LockingStrength, domain string) (string, error)
|
||||
GetAccountSettings(ctx context.Context, lockStrength LockingStrength, accountID string) (*Settings, error)
|
||||
GetAccountDNSSettings(ctx context.Context, lockStrength LockingStrength, accountID string) (*DNSSettings, error)
|
||||
SaveDNSSettings(ctx context.Context, lockStrength LockingStrength, accountID string, settings *DNSSettings) error
|
||||
SaveAccountSettings(ctx context.Context, lockStrength LockingStrength, accountID string, settings *Settings) error
|
||||
SaveAccount(ctx context.Context, account *Account) error
|
||||
DeleteAccount(ctx context.Context, account *Account) error
|
||||
UpdateAccountDomainAttributes(ctx context.Context, accountID string, domain string, category string, isPrimaryDomain bool) error
|
||||
|
||||
GetUserByTokenID(ctx context.Context, tokenID string) (*User, error)
|
||||
GetUserByUserID(ctx context.Context, lockStrength LockingStrength, userID string) (*User, error)
|
||||
GetAccountUsers(ctx context.Context, accountID string) ([]*User, error)
|
||||
GetAccountUsers(ctx context.Context, lockStrength LockingStrength, accountID string) ([]*User, error)
|
||||
SaveUsers(accountID string, users map[string]*User) error
|
||||
SaveUser(ctx context.Context, lockStrength LockingStrength, user *User) error
|
||||
SaveUserLastLogin(ctx context.Context, accountID, userID string, lastLogin time.Time) error
|
||||
@ -71,24 +73,34 @@ type Store interface {
|
||||
DeleteTokenID2UserIDIndex(tokenID string) error
|
||||
|
||||
GetAccountGroups(ctx context.Context, accountID string) ([]*nbgroup.Group, error)
|
||||
GetGroupByID(ctx context.Context, lockStrength LockingStrength, groupID, accountID string) (*nbgroup.Group, error)
|
||||
GetGroupByName(ctx context.Context, lockStrength LockingStrength, groupName, accountID string) (*nbgroup.Group, error)
|
||||
GetGroupByID(ctx context.Context, lockStrength LockingStrength, accountID, groupID string) (*nbgroup.Group, error)
|
||||
GetGroupByName(ctx context.Context, lockStrength LockingStrength, accountID, groupName string) (*nbgroup.Group, error)
|
||||
SaveGroups(ctx context.Context, lockStrength LockingStrength, groups []*nbgroup.Group) error
|
||||
SaveGroup(ctx context.Context, lockStrength LockingStrength, group *nbgroup.Group) error
|
||||
DeleteGroup(ctx context.Context, lockStrength LockingStrength, accountID, groupID string) error
|
||||
DeleteGroups(ctx context.Context, strength LockingStrength, accountID string, groupIDs []string) error
|
||||
|
||||
GetAccountPolicies(ctx context.Context, lockStrength LockingStrength, accountID string) ([]*Policy, error)
|
||||
GetPolicyByID(ctx context.Context, lockStrength LockingStrength, policyID string, accountID string) (*Policy, error)
|
||||
GetPolicyByID(ctx context.Context, lockStrength LockingStrength, accountID, policyID string) (*Policy, error)
|
||||
SavePolicy(ctx context.Context, lockStrength LockingStrength, policy *Policy) error
|
||||
DeletePolicy(ctx context.Context, lockStrength LockingStrength, accountID, policyID string) error
|
||||
|
||||
GetPostureCheckByChecksDefinition(accountID string, checks *posture.ChecksDefinition) (*posture.Checks, error)
|
||||
GetAccountPostureChecks(ctx context.Context, lockStrength LockingStrength, accountID string) ([]*posture.Checks, error)
|
||||
GetPostureChecksByID(ctx context.Context, lockStrength LockingStrength, postureCheckID string, accountID string) (*posture.Checks, error)
|
||||
GetPostureChecksByID(ctx context.Context, lockStrength LockingStrength, accountID, postureCheckID string) (*posture.Checks, error)
|
||||
SavePostureChecks(ctx context.Context, lockStrength LockingStrength, postureCheck *posture.Checks) error
|
||||
DeletePostureChecks(ctx context.Context, lockStrength LockingStrength, accountID, postureChecksID string) error
|
||||
|
||||
GetPeerLabelsInAccount(ctx context.Context, lockStrength LockingStrength, accountId string) ([]string, error)
|
||||
AddPeerToAllGroup(ctx context.Context, accountID string, peerID string) error
|
||||
AddPeerToGroup(ctx context.Context, accountId string, peerId string, groupID string) error
|
||||
AddPeerToAccount(ctx context.Context, peer *nbpeer.Peer) error
|
||||
GetPeerByPeerPubKey(ctx context.Context, lockStrength LockingStrength, peerKey string) (*nbpeer.Peer, error)
|
||||
GetAccountPeers(ctx context.Context, lockStrength LockingStrength, accountID string) ([]*nbpeer.Peer, error)
|
||||
GetUserPeers(ctx context.Context, lockStrength LockingStrength, accountID, userID string) ([]*nbpeer.Peer, error)
|
||||
GetAccountPeersWithExpiration(ctx context.Context, lockStrength LockingStrength, accountID string) ([]*nbpeer.Peer, error)
|
||||
GetAccountPeersWithInactivity(ctx context.Context, lockStrength LockingStrength, accountID string) ([]*nbpeer.Peer, error)
|
||||
GetPeerByID(ctx context.Context, lockStrength LockingStrength, accountID string, peerID string) (*nbpeer.Peer, error)
|
||||
SavePeer(ctx context.Context, accountID string, peer *nbpeer.Peer) error
|
||||
SavePeerStatus(accountID, peerID string, status nbpeer.PeerStatus) error
|
||||
SavePeerLocation(accountID string, peer *nbpeer.Peer) error
|
||||
@ -96,16 +108,25 @@ type Store interface {
|
||||
GetSetupKeyBySecret(ctx context.Context, lockStrength LockingStrength, key string) (*SetupKey, error)
|
||||
IncrementSetupKeyUsage(ctx context.Context, setupKeyID string) error
|
||||
GetAccountSetupKeys(ctx context.Context, lockStrength LockingStrength, accountID string) ([]*SetupKey, error)
|
||||
GetSetupKeyByID(ctx context.Context, lockStrength LockingStrength, setupKeyID string, accountID string) (*SetupKey, error)
|
||||
GetSetupKeyByID(ctx context.Context, lockStrength LockingStrength, accountID, setupKeyID string) (*SetupKey, error)
|
||||
SaveSetupKey(ctx context.Context, lockStrength LockingStrength, setupKey *SetupKey) error
|
||||
|
||||
GetAccountRoutes(ctx context.Context, lockStrength LockingStrength, accountID string) ([]*route.Route, error)
|
||||
GetRouteByID(ctx context.Context, lockStrength LockingStrength, routeID string, accountID string) (*route.Route, error)
|
||||
GetRouteByID(ctx context.Context, lockStrength LockingStrength, accountID, routeID string) (*route.Route, error)
|
||||
SaveRoute(ctx context.Context, lockStrength LockingStrength, route *route.Route) error
|
||||
DeleteRoute(ctx context.Context, lockStrength LockingStrength, accountID, routeID string) error
|
||||
|
||||
GetAccountNameServerGroups(ctx context.Context, lockStrength LockingStrength, accountID string) ([]*dns.NameServerGroup, error)
|
||||
GetNameServerGroupByID(ctx context.Context, lockStrength LockingStrength, nameServerGroupID string, accountID string) (*dns.NameServerGroup, error)
|
||||
GetNameServerGroupByID(ctx context.Context, lockStrength LockingStrength, accountID, nameServerGroupID string) (*dns.NameServerGroup, error)
|
||||
SaveNameServerGroup(ctx context.Context, lockStrength LockingStrength, nameServerGroup *dns.NameServerGroup) error
|
||||
DeleteNameServerGroup(ctx context.Context, lockStrength LockingStrength, accountID, nameServerGroupID string) error
|
||||
|
||||
GetPATByID(ctx context.Context, lockStrength LockingStrength, userID, patID string) (*PersonalAccessToken, error)
|
||||
SavePAT(ctx context.Context, strength LockingStrength, pat *PersonalAccessToken) error
|
||||
DeletePAT(ctx context.Context, strength LockingStrength, userID, patID string) error
|
||||
|
||||
GetTakenIPs(ctx context.Context, lockStrength LockingStrength, accountId string) ([]net.IP, error)
|
||||
IncrementNetworkSerial(ctx context.Context, accountId string) error
|
||||
IncrementNetworkSerial(ctx context.Context, lockStrength LockingStrength, accountId string) error
|
||||
GetAccountNetwork(ctx context.Context, lockStrength LockingStrength, accountId string) (*Network, error)
|
||||
|
||||
GetInstallationID() string
|
||||
|
@ -546,9 +546,6 @@ func (am *DefaultAccountManager) InviteUser(ctx context.Context, accountID strin
|
||||
|
||||
// CreatePAT creates a new PAT for the given user
|
||||
func (am *DefaultAccountManager) CreatePAT(ctx context.Context, accountID string, initiatorUserID string, targetUserID string, tokenName string, expiresIn int) (*PersonalAccessTokenGenerated, error) {
|
||||
unlock := am.Store.AcquireWriteLockByUID(ctx, accountID)
|
||||
defer unlock()
|
||||
|
||||
if tokenName == "" {
|
||||
return nil, status.Errorf(status.InvalidArgument, "token name can't be empty")
|
||||
}
|
||||
@ -557,35 +554,28 @@ func (am *DefaultAccountManager) CreatePAT(ctx context.Context, accountID string
|
||||
return nil, status.Errorf(status.InvalidArgument, "expiration has to be between 1 and 365")
|
||||
}
|
||||
|
||||
account, err := am.Store.GetAccount(ctx, accountID)
|
||||
executingUser, err := am.Store.GetUserByUserID(ctx, LockingStrengthShare, initiatorUserID)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
targetUser, ok := account.Users[targetUserID]
|
||||
if !ok {
|
||||
return nil, status.Errorf(status.NotFound, "user not found")
|
||||
targetUser, err := am.Store.GetUserByUserID(ctx, LockingStrengthShare, targetUserID)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
executingUser, ok := account.Users[initiatorUserID]
|
||||
if !ok {
|
||||
return nil, status.Errorf(status.NotFound, "user not found")
|
||||
}
|
||||
|
||||
if !(initiatorUserID == targetUserID || (executingUser.HasAdminPower() && targetUser.IsServiceUser)) {
|
||||
if !(initiatorUserID == targetUserID || (executingUser.HasAdminPower() && targetUser.IsServiceUser)) ||
|
||||
executingUser.AccountID != accountID {
|
||||
return nil, status.Errorf(status.PermissionDenied, "no permission to create PAT for this user")
|
||||
}
|
||||
|
||||
pat, err := CreateNewPAT(tokenName, expiresIn, executingUser.Id)
|
||||
pat, err := CreateNewPAT(tokenName, expiresIn, targetUser.Id, executingUser.Id)
|
||||
if err != nil {
|
||||
return nil, status.Errorf(status.Internal, "failed to create PAT: %v", err)
|
||||
}
|
||||
|
||||
targetUser.PATs[pat.ID] = &pat.PersonalAccessToken
|
||||
|
||||
err = am.Store.SaveAccount(ctx, account)
|
||||
if err != nil {
|
||||
return nil, status.Errorf(status.Internal, "failed to save account: %v", err)
|
||||
if err = am.Store.SavePAT(ctx, LockingStrengthUpdate, &pat.PersonalAccessToken); err != nil {
|
||||
return nil, fmt.Errorf("failed to save PAT: %w", err)
|
||||
}
|
||||
|
||||
meta := map[string]any{"name": pat.Name, "is_service_user": targetUser.IsServiceUser, "user_name": targetUser.ServiceUserName}
|
||||
@ -596,51 +586,33 @@ func (am *DefaultAccountManager) CreatePAT(ctx context.Context, accountID string
|
||||
|
||||
// DeletePAT deletes a specific PAT from a user
|
||||
func (am *DefaultAccountManager) DeletePAT(ctx context.Context, accountID string, initiatorUserID string, targetUserID string, tokenID string) error {
|
||||
unlock := am.Store.AcquireWriteLockByUID(ctx, accountID)
|
||||
defer unlock()
|
||||
|
||||
account, err := am.Store.GetAccount(ctx, accountID)
|
||||
executingUser, err := am.Store.GetUserByUserID(ctx, LockingStrengthShare, initiatorUserID)
|
||||
if err != nil {
|
||||
return status.Errorf(status.NotFound, "account not found: %s", err)
|
||||
return err
|
||||
}
|
||||
|
||||
targetUser, ok := account.Users[targetUserID]
|
||||
if !ok {
|
||||
return status.Errorf(status.NotFound, "user not found")
|
||||
targetUser, err := am.Store.GetUserByUserID(ctx, LockingStrengthShare, targetUserID)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
executingUser, ok := account.Users[initiatorUserID]
|
||||
if !ok {
|
||||
return status.Errorf(status.NotFound, "user not found")
|
||||
}
|
||||
|
||||
if !(initiatorUserID == targetUserID || (executingUser.HasAdminPower() && targetUser.IsServiceUser)) {
|
||||
if !(initiatorUserID == targetUserID || (executingUser.HasAdminPower() && targetUser.IsServiceUser)) ||
|
||||
executingUser.AccountID != accountID {
|
||||
return status.Errorf(status.PermissionDenied, "no permission to delete PAT for this user")
|
||||
}
|
||||
|
||||
pat := targetUser.PATs[tokenID]
|
||||
if pat == nil {
|
||||
return status.Errorf(status.NotFound, "PAT not found")
|
||||
pat, err := am.Store.GetPATByID(ctx, LockingStrengthShare, tokenID, targetUserID)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
err = am.Store.DeleteTokenID2UserIDIndex(pat.ID)
|
||||
if err != nil {
|
||||
return status.Errorf(status.Internal, "Failed to delete token id index: %s", err)
|
||||
}
|
||||
err = am.Store.DeleteHashedPAT2TokenIDIndex(pat.HashedToken)
|
||||
if err != nil {
|
||||
return status.Errorf(status.Internal, "Failed to delete hashed token index: %s", err)
|
||||
if err = am.Store.DeletePAT(ctx, LockingStrengthUpdate, tokenID, targetUserID); err != nil {
|
||||
return fmt.Errorf("failed to delete PAT: %w", err)
|
||||
}
|
||||
|
||||
meta := map[string]any{"name": pat.Name, "is_service_user": targetUser.IsServiceUser, "user_name": targetUser.ServiceUserName}
|
||||
am.StoreEvent(ctx, initiatorUserID, targetUserID, accountID, activity.PersonalAccessTokenDeleted, meta)
|
||||
|
||||
delete(targetUser.PATs, tokenID)
|
||||
|
||||
err = am.Store.SaveAccount(ctx, account)
|
||||
if err != nil {
|
||||
return status.Errorf(status.Internal, "Failed to save account: %s", err)
|
||||
}
|
||||
return nil
|
||||
}
|
||||
|
||||
@ -651,22 +623,11 @@ func (am *DefaultAccountManager) GetPAT(ctx context.Context, accountID string, i
|
||||
return nil, err
|
||||
}
|
||||
|
||||
targetUser, err := am.Store.GetUserByUserID(ctx, LockingStrengthShare, targetUserID)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
if (initiatorUserID != targetUserID && !initiatorUser.IsAdminOrServiceUser()) || initiatorUser.AccountID != accountID {
|
||||
return nil, status.Errorf(status.PermissionDenied, "no permission to get PAT for this user")
|
||||
}
|
||||
|
||||
for _, pat := range targetUser.PATsG {
|
||||
if pat.ID == tokenID {
|
||||
return pat.Copy(), nil
|
||||
}
|
||||
}
|
||||
|
||||
return nil, status.Errorf(status.NotFound, "PAT not found")
|
||||
return am.Store.GetPATByID(ctx, LockingStrengthShare, targetUserID, tokenID)
|
||||
}
|
||||
|
||||
// GetAllPATs returns all PATs for a user
|
||||
|
Reference in New Issue
Block a user