mirror of
https://github.com/netbirdio/netbird.git
synced 2025-08-16 10:08:12 +02:00
@ -45,15 +45,15 @@ import (
|
|||||||
)
|
)
|
||||||
|
|
||||||
const (
|
const (
|
||||||
PublicCategory = "public"
|
PublicCategory = "public"
|
||||||
PrivateCategory = "private"
|
PrivateCategory = "private"
|
||||||
UnknownCategory = "unknown"
|
UnknownCategory = "unknown"
|
||||||
CacheExpirationMax = 7 * 24 * 3600 * time.Second // 7 days
|
CacheExpirationMax = 7 * 24 * 3600 * time.Second // 7 days
|
||||||
CacheExpirationMin = 3 * 24 * 3600 * time.Second // 3 days
|
CacheExpirationMin = 3 * 24 * 3600 * time.Second // 3 days
|
||||||
DefaultPeerLoginExpiration = 24 * time.Hour
|
DefaultPeerLoginExpiration = 24 * time.Hour
|
||||||
DefaultPeerInactivityExpiration = 10 * time.Minute
|
DefaultPeerInactivityExpiration = 10 * time.Minute
|
||||||
emptyUserID = "empty user ID in claims"
|
emptyUserID = "empty user ID in claims"
|
||||||
errorGettingDomainAccIDFmt = "error getting account ID by private domain: %v"
|
errorGettingDomainAccIDFmt = "error getting account ID by private domain: %v"
|
||||||
)
|
)
|
||||||
|
|
||||||
type userLoggedInOnce bool
|
type userLoggedInOnce bool
|
||||||
@ -134,14 +134,14 @@ type AccountManager interface {
|
|||||||
GetDNSSettings(ctx context.Context, accountID string, userID string) (*DNSSettings, error)
|
GetDNSSettings(ctx context.Context, accountID string, userID string) (*DNSSettings, error)
|
||||||
SaveDNSSettings(ctx context.Context, accountID string, userID string, dnsSettingsToSave *DNSSettings) error
|
SaveDNSSettings(ctx context.Context, accountID string, userID string, dnsSettingsToSave *DNSSettings) error
|
||||||
GetPeer(ctx context.Context, accountID, peerID, userID string) (*nbpeer.Peer, error)
|
GetPeer(ctx context.Context, accountID, peerID, userID string) (*nbpeer.Peer, error)
|
||||||
UpdateAccountSettings(ctx context.Context, accountID, userID string, newSettings *Settings) (*Account, error)
|
UpdateAccountSettings(ctx context.Context, accountID, userID string, newSettings *Settings) (*Settings, error)
|
||||||
LoginPeer(ctx context.Context, login PeerLogin) (*nbpeer.Peer, *NetworkMap, []*posture.Checks, error) // used by peer gRPC API
|
LoginPeer(ctx context.Context, login PeerLogin) (*nbpeer.Peer, *NetworkMap, []*posture.Checks, error) // used by peer gRPC API
|
||||||
SyncPeer(ctx context.Context, sync PeerSync, account *Account) (*nbpeer.Peer, *NetworkMap, []*posture.Checks, error) // used by peer gRPC API
|
SyncPeer(ctx context.Context, sync PeerSync, account *Account) (*nbpeer.Peer, *NetworkMap, []*posture.Checks, error) // used by peer gRPC API
|
||||||
GetAllConnectedPeers() (map[string]struct{}, error)
|
GetAllConnectedPeers() (map[string]struct{}, error)
|
||||||
HasConnectedChannel(peerID string) bool
|
HasConnectedChannel(peerID string) bool
|
||||||
GetExternalCacheManager() ExternalCacheManager
|
GetExternalCacheManager() ExternalCacheManager
|
||||||
GetPostureChecks(ctx context.Context, accountID, postureChecksID, userID string) (*posture.Checks, error)
|
GetPostureChecks(ctx context.Context, accountID, postureChecksID, userID string) (*posture.Checks, error)
|
||||||
SavePostureChecks(ctx context.Context, accountID, userID string, postureChecks *posture.Checks) error
|
SavePostureChecks(ctx context.Context, accountID, userID string, postureChecks *posture.Checks, isUpdate bool) error
|
||||||
DeletePostureChecks(ctx context.Context, accountID, postureChecksID, userID string) error
|
DeletePostureChecks(ctx context.Context, accountID, postureChecksID, userID string) error
|
||||||
ListPostureChecks(ctx context.Context, accountID, userID string) ([]*posture.Checks, error)
|
ListPostureChecks(ctx context.Context, accountID, userID string) ([]*posture.Checks, error)
|
||||||
GetIdpManager() idp.Manager
|
GetIdpManager() idp.Manager
|
||||||
@ -1122,7 +1122,16 @@ func (am *DefaultAccountManager) GetIdpManager() idp.Manager {
|
|||||||
// Only users with role UserRoleAdmin can update the account.
|
// Only users with role UserRoleAdmin can update the account.
|
||||||
// User that performs the update has to belong to the account.
|
// User that performs the update has to belong to the account.
|
||||||
// Returns an updated Account
|
// Returns an updated Account
|
||||||
func (am *DefaultAccountManager) UpdateAccountSettings(ctx context.Context, accountID, userID string, newSettings *Settings) (*Account, error) {
|
func (am *DefaultAccountManager) UpdateAccountSettings(ctx context.Context, accountID, userID string, newSettings *Settings) (*Settings, error) {
|
||||||
|
user, err := am.Store.GetUserByUserID(ctx, LockingStrengthShare, userID)
|
||||||
|
if err != nil {
|
||||||
|
return nil, err
|
||||||
|
}
|
||||||
|
|
||||||
|
if !user.HasAdminPower() || user.AccountID != accountID {
|
||||||
|
return nil, status.Errorf(status.PermissionDenied, "user is not allowed to update account")
|
||||||
|
}
|
||||||
|
|
||||||
halfYearLimit := 180 * 24 * time.Hour
|
halfYearLimit := 180 * 24 * time.Hour
|
||||||
if newSettings.PeerLoginExpiration > halfYearLimit {
|
if newSettings.PeerLoginExpiration > halfYearLimit {
|
||||||
return nil, status.Errorf(status.InvalidArgument, "peer login expiration can't be larger than 180 days")
|
return nil, status.Errorf(status.InvalidArgument, "peer login expiration can't be larger than 180 days")
|
||||||
@ -1132,78 +1141,89 @@ func (am *DefaultAccountManager) UpdateAccountSettings(ctx context.Context, acco
|
|||||||
return nil, status.Errorf(status.InvalidArgument, "peer login expiration can't be smaller than one hour")
|
return nil, status.Errorf(status.InvalidArgument, "peer login expiration can't be smaller than one hour")
|
||||||
}
|
}
|
||||||
|
|
||||||
unlock := am.Store.AcquireWriteLockByUID(ctx, accountID)
|
var oldSettings *Settings
|
||||||
defer unlock()
|
|
||||||
|
|
||||||
account, err := am.Store.GetAccount(ctx, accountID)
|
err = am.Store.ExecuteInTransaction(ctx, func(transaction Store) error {
|
||||||
|
oldSettings, err = transaction.GetAccountSettings(ctx, LockingStrengthShare, accountID)
|
||||||
|
if err != nil {
|
||||||
|
return fmt.Errorf("failed to get account settings: %w", err)
|
||||||
|
}
|
||||||
|
|
||||||
|
if err = am.validateExtraSettings(ctx, newSettings, oldSettings, userID, accountID); err != nil {
|
||||||
|
return fmt.Errorf("failed to validate extra settings: %w", err)
|
||||||
|
}
|
||||||
|
|
||||||
|
if err = am.Store.SaveAccountSettings(ctx, LockingStrengthUpdate, accountID, newSettings); err != nil {
|
||||||
|
return fmt.Errorf("failed to update account settings: %w", err)
|
||||||
|
}
|
||||||
|
|
||||||
|
return nil
|
||||||
|
})
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return nil, err
|
return nil, err
|
||||||
}
|
}
|
||||||
|
|
||||||
user, err := account.FindUser(userID)
|
am.handlePeerLoginExpirationSettings(ctx, oldSettings, newSettings, userID, accountID)
|
||||||
|
am.handleInactivityExpirationSettings(ctx, oldSettings, newSettings, userID, accountID)
|
||||||
|
|
||||||
|
account, err := am.requestBuffer.GetAccountWithBackpressure(ctx, accountID)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return nil, err
|
return nil, fmt.Errorf("error getting account: %w", err)
|
||||||
}
|
}
|
||||||
|
am.updateAccountPeers(ctx, account)
|
||||||
|
|
||||||
if !user.HasAdminPower() {
|
return newSettings, nil
|
||||||
return nil, status.Errorf(status.PermissionDenied, "user is not allowed to update account")
|
}
|
||||||
}
|
|
||||||
|
|
||||||
err = am.integratedPeerValidator.ValidateExtraSettings(ctx, newSettings.Extra, account.Settings.Extra, account.Peers, userID, accountID)
|
// validateExtraSettings validates the extra settings of the account.
|
||||||
|
func (am *DefaultAccountManager) validateExtraSettings(ctx context.Context, newSettings, oldSettings *Settings, userID, accountID string) error {
|
||||||
|
peers, err := am.Store.GetAccountPeers(ctx, LockingStrengthShare, accountID)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return nil, err
|
return err
|
||||||
}
|
}
|
||||||
|
|
||||||
oldSettings := account.Settings
|
peerMap := make(map[string]*nbpeer.Peer, len(peers))
|
||||||
|
for _, peer := range peers {
|
||||||
|
peerMap[peer.ID] = peer
|
||||||
|
}
|
||||||
|
|
||||||
|
return am.integratedPeerValidator.ValidateExtraSettings(ctx, newSettings.Extra, oldSettings.Extra, peerMap, userID, accountID)
|
||||||
|
}
|
||||||
|
|
||||||
|
func (am *DefaultAccountManager) handlePeerLoginExpirationSettings(ctx context.Context, oldSettings, newSettings *Settings, userID, accountID string) {
|
||||||
if oldSettings.PeerLoginExpirationEnabled != newSettings.PeerLoginExpirationEnabled {
|
if oldSettings.PeerLoginExpirationEnabled != newSettings.PeerLoginExpirationEnabled {
|
||||||
event := activity.AccountPeerLoginExpirationEnabled
|
event := activity.AccountPeerLoginExpirationEnabled
|
||||||
if !newSettings.PeerLoginExpirationEnabled {
|
if !newSettings.PeerLoginExpirationEnabled {
|
||||||
event = activity.AccountPeerLoginExpirationDisabled
|
event = activity.AccountPeerLoginExpirationDisabled
|
||||||
am.peerLoginExpiry.Cancel(ctx, []string{accountID})
|
am.peerLoginExpiry.Cancel(ctx, []string{accountID})
|
||||||
} else {
|
} else {
|
||||||
am.checkAndSchedulePeerLoginExpiration(ctx, account)
|
am.checkAndSchedulePeerLoginExpiration(ctx, accountID)
|
||||||
}
|
}
|
||||||
am.StoreEvent(ctx, userID, accountID, accountID, event, nil)
|
am.StoreEvent(ctx, userID, accountID, accountID, event, nil)
|
||||||
}
|
}
|
||||||
|
|
||||||
if oldSettings.PeerLoginExpiration != newSettings.PeerLoginExpiration {
|
if oldSettings.PeerLoginExpiration != newSettings.PeerLoginExpiration {
|
||||||
am.StoreEvent(ctx, userID, accountID, accountID, activity.AccountPeerLoginExpirationDurationUpdated, nil)
|
am.StoreEvent(ctx, userID, accountID, accountID, activity.AccountPeerLoginExpirationDurationUpdated, nil)
|
||||||
am.checkAndSchedulePeerLoginExpiration(ctx, account)
|
am.checkAndSchedulePeerLoginExpiration(ctx, accountID)
|
||||||
}
|
}
|
||||||
|
|
||||||
err = am.handleInactivityExpirationSettings(ctx, account, oldSettings, newSettings, userID, accountID)
|
|
||||||
if err != nil {
|
|
||||||
return nil, err
|
|
||||||
}
|
|
||||||
|
|
||||||
updatedAccount := account.UpdateSettings(newSettings)
|
|
||||||
|
|
||||||
err = am.Store.SaveAccount(ctx, account)
|
|
||||||
if err != nil {
|
|
||||||
return nil, err
|
|
||||||
}
|
|
||||||
|
|
||||||
return updatedAccount, nil
|
|
||||||
}
|
}
|
||||||
|
|
||||||
func (am *DefaultAccountManager) handleInactivityExpirationSettings(ctx context.Context, account *Account, oldSettings, newSettings *Settings, userID, accountID string) error {
|
func (am *DefaultAccountManager) handleInactivityExpirationSettings(ctx context.Context, oldSettings, newSettings *Settings, userID, accountID string) {
|
||||||
if oldSettings.PeerInactivityExpirationEnabled != newSettings.PeerInactivityExpirationEnabled {
|
if oldSettings.PeerInactivityExpirationEnabled != newSettings.PeerInactivityExpirationEnabled {
|
||||||
event := activity.AccountPeerInactivityExpirationEnabled
|
event := activity.AccountPeerInactivityExpirationEnabled
|
||||||
if !newSettings.PeerInactivityExpirationEnabled {
|
if !newSettings.PeerInactivityExpirationEnabled {
|
||||||
event = activity.AccountPeerInactivityExpirationDisabled
|
event = activity.AccountPeerInactivityExpirationDisabled
|
||||||
am.peerInactivityExpiry.Cancel(ctx, []string{accountID})
|
am.peerInactivityExpiry.Cancel(ctx, []string{accountID})
|
||||||
} else {
|
} else {
|
||||||
am.checkAndSchedulePeerInactivityExpiration(ctx, account)
|
am.checkAndSchedulePeerInactivityExpiration(ctx, accountID)
|
||||||
}
|
}
|
||||||
am.StoreEvent(ctx, userID, accountID, accountID, event, nil)
|
am.StoreEvent(ctx, userID, accountID, accountID, event, nil)
|
||||||
}
|
}
|
||||||
|
|
||||||
if oldSettings.PeerInactivityExpiration != newSettings.PeerInactivityExpiration {
|
if oldSettings.PeerInactivityExpiration != newSettings.PeerInactivityExpiration {
|
||||||
am.StoreEvent(ctx, userID, accountID, accountID, activity.AccountPeerInactivityExpirationDurationUpdated, nil)
|
am.StoreEvent(ctx, userID, accountID, accountID, activity.AccountPeerInactivityExpirationDurationUpdated, nil)
|
||||||
am.checkAndSchedulePeerInactivityExpiration(ctx, account)
|
am.checkAndSchedulePeerInactivityExpiration(ctx, accountID)
|
||||||
}
|
}
|
||||||
|
|
||||||
return nil
|
|
||||||
}
|
}
|
||||||
|
|
||||||
func (am *DefaultAccountManager) peerLoginExpirationJob(ctx context.Context, accountID string) func() (time.Duration, bool) {
|
func (am *DefaultAccountManager) peerLoginExpirationJob(ctx context.Context, accountID string) func() (time.Duration, bool) {
|
||||||
@ -1234,10 +1254,10 @@ func (am *DefaultAccountManager) peerLoginExpirationJob(ctx context.Context, acc
|
|||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
func (am *DefaultAccountManager) checkAndSchedulePeerLoginExpiration(ctx context.Context, account *Account) {
|
func (am *DefaultAccountManager) checkAndSchedulePeerLoginExpiration(ctx context.Context, accountID string) {
|
||||||
am.peerLoginExpiry.Cancel(ctx, []string{account.Id})
|
am.peerLoginExpiry.Cancel(ctx, []string{accountID})
|
||||||
if nextRun, ok := account.GetNextPeerExpiration(); ok {
|
if nextRun, ok := am.getNextPeerExpiration(ctx, accountID); ok {
|
||||||
go am.peerLoginExpiry.Schedule(ctx, nextRun, account.Id, am.peerLoginExpirationJob(ctx, account.Id))
|
go am.peerLoginExpiry.Schedule(ctx, nextRun, accountID, am.peerLoginExpirationJob(ctx, accountID))
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
@ -1271,10 +1291,10 @@ func (am *DefaultAccountManager) peerInactivityExpirationJob(ctx context.Context
|
|||||||
}
|
}
|
||||||
|
|
||||||
// checkAndSchedulePeerInactivityExpiration periodically checks for inactive peers to end their sessions
|
// checkAndSchedulePeerInactivityExpiration periodically checks for inactive peers to end their sessions
|
||||||
func (am *DefaultAccountManager) checkAndSchedulePeerInactivityExpiration(ctx context.Context, account *Account) {
|
func (am *DefaultAccountManager) checkAndSchedulePeerInactivityExpiration(ctx context.Context, accountID string) {
|
||||||
am.peerInactivityExpiry.Cancel(ctx, []string{account.Id})
|
am.peerInactivityExpiry.Cancel(ctx, []string{accountID})
|
||||||
if nextRun, ok := account.GetNextInactivePeerExpiration(); ok {
|
if nextRun, ok := am.getNextInactivePeerExpiration(ctx, accountID); ok {
|
||||||
go am.peerInactivityExpiry.Schedule(ctx, nextRun, account.Id, am.peerInactivityExpirationJob(ctx, account.Id))
|
go am.peerInactivityExpiry.Schedule(ctx, nextRun, accountID, am.peerInactivityExpirationJob(ctx, accountID))
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
@ -1435,12 +1455,12 @@ func isNil(i idp.Manager) bool {
|
|||||||
// addAccountIDToIDPAppMeta update user's app metadata in idp manager
|
// addAccountIDToIDPAppMeta update user's app metadata in idp manager
|
||||||
func (am *DefaultAccountManager) addAccountIDToIDPAppMeta(ctx context.Context, userID string, accountID string) error {
|
func (am *DefaultAccountManager) addAccountIDToIDPAppMeta(ctx context.Context, userID string, accountID string) error {
|
||||||
if !isNil(am.idpManager) {
|
if !isNil(am.idpManager) {
|
||||||
accountUsers, err := am.Store.GetAccountUsers(ctx, accountID)
|
accountUsers, err := am.Store.GetAccountUsers(ctx, LockingStrengthShare, accountID)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return err
|
return err
|
||||||
}
|
}
|
||||||
cachedAccount := &Account{
|
cachedAccount := &Account{
|
||||||
Id: accountID,
|
Id: accountID,
|
||||||
Users: make(map[string]*User),
|
Users: make(map[string]*User),
|
||||||
}
|
}
|
||||||
for _, user := range accountUsers {
|
for _, user := range accountUsers {
|
||||||
@ -2083,7 +2103,7 @@ func (am *DefaultAccountManager) syncJWTGroups(ctx context.Context, accountID st
|
|||||||
return fmt.Errorf("error saving groups: %w", err)
|
return fmt.Errorf("error saving groups: %w", err)
|
||||||
}
|
}
|
||||||
|
|
||||||
if err = transaction.IncrementNetworkSerial(ctx, accountID); err != nil {
|
if err = transaction.IncrementNetworkSerial(ctx, LockingStrengthUpdate, accountID); err != nil {
|
||||||
return fmt.Errorf("error incrementing network serial: %w", err)
|
return fmt.Errorf("error incrementing network serial: %w", err)
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
@ -2101,7 +2121,7 @@ func (am *DefaultAccountManager) syncJWTGroups(ctx context.Context, accountID st
|
|||||||
}
|
}
|
||||||
|
|
||||||
for _, g := range addNewGroups {
|
for _, g := range addNewGroups {
|
||||||
group, err := am.Store.GetGroupByID(ctx, LockingStrengthShare, g, accountID)
|
group, err := am.Store.GetGroupByID(ctx, LockingStrengthShare, accountID, g)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
log.WithContext(ctx).Debugf("group %s not found while saving user activity event of account %s", g, accountID)
|
log.WithContext(ctx).Debugf("group %s not found while saving user activity event of account %s", g, accountID)
|
||||||
} else {
|
} else {
|
||||||
@ -2114,7 +2134,7 @@ func (am *DefaultAccountManager) syncJWTGroups(ctx context.Context, accountID st
|
|||||||
}
|
}
|
||||||
|
|
||||||
for _, g := range removeOldGroups {
|
for _, g := range removeOldGroups {
|
||||||
group, err := am.Store.GetGroupByID(ctx, LockingStrengthShare, g, accountID)
|
group, err := am.Store.GetGroupByID(ctx, LockingStrengthShare, accountID, g)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
log.WithContext(ctx).Debugf("group %s not found while saving user activity event of account %s", g, accountID)
|
log.WithContext(ctx).Debugf("group %s not found while saving user activity event of account %s", g, accountID)
|
||||||
} else {
|
} else {
|
||||||
|
@ -2553,7 +2553,7 @@ func TestAccount_SetJWTGroups(t *testing.T) {
|
|||||||
assert.NoError(t, err, "unable to get user")
|
assert.NoError(t, err, "unable to get user")
|
||||||
assert.Len(t, user.AutoGroups, 0)
|
assert.Len(t, user.AutoGroups, 0)
|
||||||
|
|
||||||
group1, err := manager.Store.GetGroupByID(context.Background(), LockingStrengthShare, "group1", "accountID")
|
group1, err := manager.Store.GetGroupByID(context.Background(), LockingStrengthShare, "accountID", "group1")
|
||||||
assert.NoError(t, err, "unable to get group")
|
assert.NoError(t, err, "unable to get group")
|
||||||
assert.Equal(t, group1.Issued, group.GroupIssuedAPI, "group should be api issued")
|
assert.Equal(t, group1.Issued, group.GroupIssuedAPI, "group should be api issued")
|
||||||
})
|
})
|
||||||
@ -2573,7 +2573,7 @@ func TestAccount_SetJWTGroups(t *testing.T) {
|
|||||||
assert.NoError(t, err, "unable to get user")
|
assert.NoError(t, err, "unable to get user")
|
||||||
assert.Len(t, user.AutoGroups, 1)
|
assert.Len(t, user.AutoGroups, 1)
|
||||||
|
|
||||||
group1, err := manager.Store.GetGroupByID(context.Background(), LockingStrengthShare, "group1", "accountID")
|
group1, err := manager.Store.GetGroupByID(context.Background(), LockingStrengthShare, "accountID", "group1")
|
||||||
assert.NoError(t, err, "unable to get group")
|
assert.NoError(t, err, "unable to get group")
|
||||||
assert.Equal(t, group1.Issued, group.GroupIssuedAPI, "group should be api issued")
|
assert.Equal(t, group1.Issued, group.GroupIssuedAPI, "group should be api issued")
|
||||||
})
|
})
|
||||||
|
@ -50,7 +50,7 @@ func (am *DefaultAccountManager) GetGroup(ctx context.Context, accountID, groupI
|
|||||||
return nil, err
|
return nil, err
|
||||||
}
|
}
|
||||||
|
|
||||||
return am.Store.GetGroupByID(ctx, LockingStrengthShare, groupID, accountID)
|
return am.Store.GetGroupByID(ctx, LockingStrengthShare, accountID, groupID)
|
||||||
}
|
}
|
||||||
|
|
||||||
// GetAllGroups returns all groups in an account
|
// GetAllGroups returns all groups in an account
|
||||||
@ -64,7 +64,7 @@ func (am *DefaultAccountManager) GetAllGroups(ctx context.Context, accountID, us
|
|||||||
|
|
||||||
// GetGroupByName filters all groups in an account by name and returns the one with the most peers
|
// GetGroupByName filters all groups in an account by name and returns the one with the most peers
|
||||||
func (am *DefaultAccountManager) GetGroupByName(ctx context.Context, groupName, accountID string) (*nbgroup.Group, error) {
|
func (am *DefaultAccountManager) GetGroupByName(ctx context.Context, groupName, accountID string) (*nbgroup.Group, error) {
|
||||||
return am.Store.GetGroupByName(ctx, LockingStrengthShare, groupName, accountID)
|
return am.Store.GetGroupByName(ctx, LockingStrengthShare, accountID, groupName)
|
||||||
}
|
}
|
||||||
|
|
||||||
// SaveGroup object of the peers
|
// SaveGroup object of the peers
|
||||||
@ -94,7 +94,7 @@ func (am *DefaultAccountManager) SaveGroups(ctx context.Context, accountID, user
|
|||||||
}
|
}
|
||||||
|
|
||||||
if newGroup.ID == "" && newGroup.Issued == nbgroup.GroupIssuedAPI {
|
if newGroup.ID == "" && newGroup.Issued == nbgroup.GroupIssuedAPI {
|
||||||
existingGroup, err := am.Store.GetGroupByName(ctx, LockingStrengthShare, newGroup.Name, accountID)
|
existingGroup, err := am.Store.GetGroupByName(ctx, LockingStrengthShare, accountID, newGroup.Name)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
s, ok := status.FromError(err)
|
s, ok := status.FromError(err)
|
||||||
if !ok || s.ErrorType != status.NotFound {
|
if !ok || s.ErrorType != status.NotFound {
|
||||||
@ -112,7 +112,7 @@ func (am *DefaultAccountManager) SaveGroups(ctx context.Context, accountID, user
|
|||||||
}
|
}
|
||||||
|
|
||||||
for _, peerID := range newGroup.Peers {
|
for _, peerID := range newGroup.Peers {
|
||||||
if _, err = am.Store.GetPeerByID(ctx, LockingStrengthShare, peerID, accountID); err != nil {
|
if _, err = am.Store.GetPeerByID(ctx, LockingStrengthShare, accountID, peerID); err != nil {
|
||||||
return status.Errorf(status.InvalidArgument, "peer with ID \"%s\" not found", peerID)
|
return status.Errorf(status.InvalidArgument, "peer with ID \"%s\" not found", peerID)
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
@ -158,7 +158,7 @@ func (am *DefaultAccountManager) prepareGroupEvents(ctx context.Context, userID
|
|||||||
addedPeers := make([]string, 0)
|
addedPeers := make([]string, 0)
|
||||||
removedPeers := make([]string, 0)
|
removedPeers := make([]string, 0)
|
||||||
|
|
||||||
oldGroup, err := am.Store.GetGroupByID(ctx, LockingStrengthShare, newGroup.ID, accountID)
|
oldGroup, err := am.Store.GetGroupByID(ctx, LockingStrengthShare, accountID, newGroup.ID)
|
||||||
if err == nil && oldGroup != nil {
|
if err == nil && oldGroup != nil {
|
||||||
addedPeers = difference(newGroup.Peers, oldGroup.Peers)
|
addedPeers = difference(newGroup.Peers, oldGroup.Peers)
|
||||||
removedPeers = difference(oldGroup.Peers, newGroup.Peers)
|
removedPeers = difference(oldGroup.Peers, newGroup.Peers)
|
||||||
@ -170,7 +170,7 @@ func (am *DefaultAccountManager) prepareGroupEvents(ctx context.Context, userID
|
|||||||
}
|
}
|
||||||
|
|
||||||
for _, peerID := range addedPeers {
|
for _, peerID := range addedPeers {
|
||||||
peer, err := am.Store.GetPeerByID(ctx, LockingStrengthShare, peerID, accountID)
|
peer, err := am.Store.GetPeerByID(ctx, LockingStrengthShare, accountID, peerID)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
log.WithContext(ctx).Errorf("peer %s not found under account %s while saving group", peerID, accountID)
|
log.WithContext(ctx).Errorf("peer %s not found under account %s while saving group", peerID, accountID)
|
||||||
continue
|
continue
|
||||||
@ -187,7 +187,7 @@ func (am *DefaultAccountManager) prepareGroupEvents(ctx context.Context, userID
|
|||||||
}
|
}
|
||||||
|
|
||||||
for _, peerID := range removedPeers {
|
for _, peerID := range removedPeers {
|
||||||
peer, err := am.Store.GetPeerByID(ctx, LockingStrengthShare, peerID, accountID)
|
peer, err := am.Store.GetPeerByID(ctx, LockingStrengthShare, accountID, peerID)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
log.WithContext(ctx).Errorf("peer %s not found under account %s while saving group", peerID, accountID)
|
log.WithContext(ctx).Errorf("peer %s not found under account %s while saving group", peerID, accountID)
|
||||||
continue
|
continue
|
||||||
@ -232,7 +232,7 @@ func (am *DefaultAccountManager) DeleteGroup(ctx context.Context, accountID, use
|
|||||||
return status.Errorf(status.PermissionDenied, "no permission to delete group")
|
return status.Errorf(status.PermissionDenied, "no permission to delete group")
|
||||||
}
|
}
|
||||||
|
|
||||||
group, err := am.Store.GetGroupByID(ctx, LockingStrengthShare, groupID, accountID)
|
group, err := am.Store.GetGroupByID(ctx, LockingStrengthShare, accountID, groupID)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return err
|
return err
|
||||||
}
|
}
|
||||||
@ -288,7 +288,7 @@ func (am *DefaultAccountManager) DeleteGroups(ctx context.Context, accountID, us
|
|||||||
)
|
)
|
||||||
|
|
||||||
for _, groupID := range groupIDs {
|
for _, groupID := range groupIDs {
|
||||||
group, err := am.Store.GetGroupByID(ctx, LockingStrengthShare, groupID, accountID)
|
group, err := am.Store.GetGroupByID(ctx, LockingStrengthShare, accountID, groupID)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
continue
|
continue
|
||||||
}
|
}
|
||||||
@ -307,7 +307,7 @@ func (am *DefaultAccountManager) DeleteGroups(ctx context.Context, accountID, us
|
|||||||
return fmt.Errorf("failed to increment network serial: %w", err)
|
return fmt.Errorf("failed to increment network serial: %w", err)
|
||||||
}
|
}
|
||||||
|
|
||||||
if err = transaction.DeleteGroups(ctx, LockingStrengthUpdate, groupIDsToDelete, accountID); err != nil {
|
if err = transaction.DeleteGroups(ctx, LockingStrengthUpdate, accountID, groupIDsToDelete); err != nil {
|
||||||
return fmt.Errorf("failed to delete group: %w", err)
|
return fmt.Errorf("failed to delete group: %w", err)
|
||||||
}
|
}
|
||||||
return nil
|
return nil
|
||||||
|
@ -30,7 +30,7 @@ func (am *DefaultAccountManager) GetNameServerGroup(ctx context.Context, account
|
|||||||
return nil, status.Errorf(status.PermissionDenied, "only users with admin power can view name server groups")
|
return nil, status.Errorf(status.PermissionDenied, "only users with admin power can view name server groups")
|
||||||
}
|
}
|
||||||
|
|
||||||
return am.Store.GetNameServerGroupByID(ctx, LockingStrengthShare, nsGroupID, accountID)
|
return am.Store.GetNameServerGroupByID(ctx, LockingStrengthShare, accountID, nsGroupID)
|
||||||
}
|
}
|
||||||
|
|
||||||
// CreateNameServerGroup creates and saves a new nameserver group
|
// CreateNameServerGroup creates and saves a new nameserver group
|
||||||
@ -103,7 +103,7 @@ func (am *DefaultAccountManager) SaveNameServerGroup(ctx context.Context, accoun
|
|||||||
return status.Errorf(status.PermissionDenied, "no permission to delete nameserver for this account")
|
return status.Errorf(status.PermissionDenied, "no permission to delete nameserver for this account")
|
||||||
}
|
}
|
||||||
|
|
||||||
_, err = am.Store.GetNameServerGroupByID(ctx, LockingStrengthShare, nsGroupToSave.ID, accountID)
|
_, err = am.Store.GetNameServerGroupByID(ctx, LockingStrengthShare, accountID, nsGroupToSave.ID)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return err
|
return err
|
||||||
}
|
}
|
||||||
@ -150,7 +150,7 @@ func (am *DefaultAccountManager) DeleteNameServerGroup(ctx context.Context, acco
|
|||||||
return status.Errorf(status.PermissionDenied, "no permission to delete nameserver for this account")
|
return status.Errorf(status.PermissionDenied, "no permission to delete nameserver for this account")
|
||||||
}
|
}
|
||||||
|
|
||||||
nsGroup, err := am.Store.GetNameServerGroupByID(ctx, LockingStrengthShare, nsGroupID, accountID)
|
nsGroup, err := am.Store.GetNameServerGroupByID(ctx, LockingStrengthShare, accountID, nsGroupID)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return err
|
return err
|
||||||
}
|
}
|
||||||
@ -160,7 +160,7 @@ func (am *DefaultAccountManager) DeleteNameServerGroup(ctx context.Context, acco
|
|||||||
return fmt.Errorf("failed to increment network serial: %w", err)
|
return fmt.Errorf("failed to increment network serial: %w", err)
|
||||||
}
|
}
|
||||||
|
|
||||||
if err = transaction.DeleteNameServerGroup(ctx, LockingStrengthUpdate, nsGroupID, accountID); err != nil {
|
if err = transaction.DeleteNameServerGroup(ctx, LockingStrengthUpdate, accountID, nsGroupID); err != nil {
|
||||||
return fmt.Errorf("failed to delete nameserver group: %w", err)
|
return fmt.Errorf("failed to delete nameserver group: %w", err)
|
||||||
}
|
}
|
||||||
|
|
||||||
|
@ -117,11 +117,11 @@ func (am *DefaultAccountManager) MarkPeerConnected(ctx context.Context, peerPubK
|
|||||||
|
|
||||||
if peer.AddedWithSSOLogin() {
|
if peer.AddedWithSSOLogin() {
|
||||||
if peer.LoginExpirationEnabled && account.Settings.PeerLoginExpirationEnabled {
|
if peer.LoginExpirationEnabled && account.Settings.PeerLoginExpirationEnabled {
|
||||||
am.checkAndSchedulePeerLoginExpiration(ctx, account)
|
am.checkAndSchedulePeerLoginExpiration(ctx, account.Id)
|
||||||
}
|
}
|
||||||
|
|
||||||
if peer.InactivityExpirationEnabled && account.Settings.PeerInactivityExpirationEnabled {
|
if peer.InactivityExpirationEnabled && account.Settings.PeerInactivityExpirationEnabled {
|
||||||
am.checkAndSchedulePeerInactivityExpiration(ctx, account)
|
am.checkAndSchedulePeerInactivityExpiration(ctx, account.Id)
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
@ -230,7 +230,7 @@ func (am *DefaultAccountManager) UpdatePeer(ctx context.Context, accountID, user
|
|||||||
am.StoreEvent(ctx, userID, peer.IP.String(), accountID, event, peer.EventMeta(am.GetDNSDomain()))
|
am.StoreEvent(ctx, userID, peer.IP.String(), accountID, event, peer.EventMeta(am.GetDNSDomain()))
|
||||||
|
|
||||||
if peer.AddedWithSSOLogin() && peer.LoginExpirationEnabled && account.Settings.PeerLoginExpirationEnabled {
|
if peer.AddedWithSSOLogin() && peer.LoginExpirationEnabled && account.Settings.PeerLoginExpirationEnabled {
|
||||||
am.checkAndSchedulePeerLoginExpiration(ctx, account)
|
am.checkAndSchedulePeerLoginExpiration(ctx, accountID)
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
@ -249,7 +249,7 @@ func (am *DefaultAccountManager) UpdatePeer(ctx context.Context, accountID, user
|
|||||||
am.StoreEvent(ctx, userID, peer.IP.String(), accountID, event, peer.EventMeta(am.GetDNSDomain()))
|
am.StoreEvent(ctx, userID, peer.IP.String(), accountID, event, peer.EventMeta(am.GetDNSDomain()))
|
||||||
|
|
||||||
if peer.AddedWithSSOLogin() && peer.InactivityExpirationEnabled && account.Settings.PeerInactivityExpirationEnabled {
|
if peer.AddedWithSSOLogin() && peer.InactivityExpirationEnabled && account.Settings.PeerInactivityExpirationEnabled {
|
||||||
am.checkAndSchedulePeerInactivityExpiration(ctx, account)
|
am.checkAndSchedulePeerInactivityExpiration(ctx, accountID)
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
@ -537,7 +537,7 @@ func (am *DefaultAccountManager) AddPeer(ctx context.Context, setupKey, userID s
|
|||||||
return fmt.Errorf("failed to add peer to account: %w", err)
|
return fmt.Errorf("failed to add peer to account: %w", err)
|
||||||
}
|
}
|
||||||
|
|
||||||
err = transaction.IncrementNetworkSerial(ctx, accountID)
|
err = transaction.IncrementNetworkSerial(ctx, LockingStrengthUpdate, accountID)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return fmt.Errorf("failed to increment network serial: %w", err)
|
return fmt.Errorf("failed to increment network serial: %w", err)
|
||||||
}
|
}
|
||||||
@ -1041,6 +1041,139 @@ func (am *DefaultAccountManager) updateAccountPeers(ctx context.Context, account
|
|||||||
wg.Wait()
|
wg.Wait()
|
||||||
}
|
}
|
||||||
|
|
||||||
|
// getNextPeerExpiration returns the minimum duration in which the next peer of the account will expire if it was found.
|
||||||
|
// If there is no peer that expires this function returns false and a duration of 0.
|
||||||
|
// This function only considers peers that haven't been expired yet and that are connected.
|
||||||
|
func (am *DefaultAccountManager) getNextPeerExpiration(ctx context.Context, accountID string) (time.Duration, bool) {
|
||||||
|
peersWithExpiry, err := am.Store.GetAccountPeersWithExpiration(ctx, LockingStrengthShare, accountID)
|
||||||
|
if err != nil {
|
||||||
|
log.WithContext(ctx).Errorf("failed to get peers with expiration: %v", err)
|
||||||
|
return 0, false
|
||||||
|
}
|
||||||
|
|
||||||
|
if len(peersWithExpiry) == 0 {
|
||||||
|
return 0, false
|
||||||
|
}
|
||||||
|
|
||||||
|
settings, err := am.Store.GetAccountSettings(ctx, LockingStrengthShare, accountID)
|
||||||
|
if err != nil {
|
||||||
|
log.WithContext(ctx).Errorf("failed to get account settings: %v", err)
|
||||||
|
return 0, false
|
||||||
|
}
|
||||||
|
|
||||||
|
var nextExpiry *time.Duration
|
||||||
|
for _, peer := range peersWithExpiry {
|
||||||
|
// consider only connected peers because others will require login on connecting to the management server
|
||||||
|
if peer.Status.LoginExpired || !peer.Status.Connected {
|
||||||
|
continue
|
||||||
|
}
|
||||||
|
_, duration := peer.LoginExpired(settings.PeerLoginExpiration)
|
||||||
|
if nextExpiry == nil || duration < *nextExpiry {
|
||||||
|
// if expiration is below 1s return 1s duration
|
||||||
|
// this avoids issues with ticker that can't be set to < 0
|
||||||
|
if duration < time.Second {
|
||||||
|
return time.Second, true
|
||||||
|
}
|
||||||
|
nextExpiry = &duration
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
if nextExpiry == nil {
|
||||||
|
return 0, false
|
||||||
|
}
|
||||||
|
|
||||||
|
return *nextExpiry, true
|
||||||
|
}
|
||||||
|
|
||||||
|
// GetNextInactivePeerExpiration returns the minimum duration in which the next peer of the account will expire if it was found.
|
||||||
|
// If there is no peer that expires this function returns false and a duration of 0.
|
||||||
|
// This function only considers peers that haven't been expired yet and that are not connected.
|
||||||
|
func (am *DefaultAccountManager) getNextInactivePeerExpiration(ctx context.Context, accountID string) (time.Duration, bool) {
|
||||||
|
peersWithInactivity, err := am.Store.GetAccountPeersWithInactivity(ctx, LockingStrengthShare, accountID)
|
||||||
|
if err != nil {
|
||||||
|
log.WithContext(ctx).Errorf("failed to get peers with inactivity: %v", err)
|
||||||
|
return 0, false
|
||||||
|
}
|
||||||
|
|
||||||
|
if len(peersWithInactivity) == 0 {
|
||||||
|
return 0, false
|
||||||
|
}
|
||||||
|
|
||||||
|
settings, err := am.Store.GetAccountSettings(ctx, LockingStrengthShare, accountID)
|
||||||
|
if err != nil {
|
||||||
|
log.WithContext(ctx).Errorf("failed to get account settings: %v", err)
|
||||||
|
return 0, false
|
||||||
|
}
|
||||||
|
|
||||||
|
var nextExpiry *time.Duration
|
||||||
|
for _, peer := range peersWithInactivity {
|
||||||
|
if peer.Status.LoginExpired || peer.Status.Connected {
|
||||||
|
continue
|
||||||
|
}
|
||||||
|
_, duration := peer.SessionExpired(settings.PeerInactivityExpiration)
|
||||||
|
if nextExpiry == nil || duration < *nextExpiry {
|
||||||
|
// if expiration is below 1s return 1s duration
|
||||||
|
// this avoids issues with ticker that can't be set to < 0
|
||||||
|
if duration < time.Second {
|
||||||
|
return time.Second, true
|
||||||
|
}
|
||||||
|
nextExpiry = &duration
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
if nextExpiry == nil {
|
||||||
|
return 0, false
|
||||||
|
}
|
||||||
|
|
||||||
|
return *nextExpiry, true
|
||||||
|
}
|
||||||
|
|
||||||
|
// getExpiredPeers returns peers that have been expired.
|
||||||
|
func (am *DefaultAccountManager) getExpiredPeers(ctx context.Context, accountID string) ([]*nbpeer.Peer, error) {
|
||||||
|
peersWithExpiry, err := am.Store.GetAccountPeersWithExpiration(ctx, LockingStrengthShare, accountID)
|
||||||
|
if err != nil {
|
||||||
|
return nil, err
|
||||||
|
}
|
||||||
|
|
||||||
|
settings, err := am.Store.GetAccountSettings(ctx, LockingStrengthShare, accountID)
|
||||||
|
if err != nil {
|
||||||
|
return nil, err
|
||||||
|
}
|
||||||
|
|
||||||
|
var peers []*nbpeer.Peer
|
||||||
|
for _, peer := range peersWithExpiry {
|
||||||
|
expired, _ := peer.LoginExpired(settings.PeerLoginExpiration)
|
||||||
|
if expired {
|
||||||
|
peers = append(peers, peer)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
return peers, nil
|
||||||
|
}
|
||||||
|
|
||||||
|
// getInactivePeers returns peers that have been expired by inactivity
|
||||||
|
func (am *DefaultAccountManager) getInactivePeers(ctx context.Context, accountID string) ([]*nbpeer.Peer, error) {
|
||||||
|
peersWithInactivity, err := am.Store.GetAccountPeersWithInactivity(ctx, LockingStrengthShare, accountID)
|
||||||
|
if err != nil {
|
||||||
|
return nil, err
|
||||||
|
}
|
||||||
|
|
||||||
|
settings, err := am.Store.GetAccountSettings(ctx, LockingStrengthShare, accountID)
|
||||||
|
if err != nil {
|
||||||
|
return nil, err
|
||||||
|
}
|
||||||
|
|
||||||
|
var peers []*nbpeer.Peer
|
||||||
|
for _, inactivePeer := range peersWithInactivity {
|
||||||
|
inactive, _ := inactivePeer.SessionExpired(settings.PeerInactivityExpiration)
|
||||||
|
if inactive {
|
||||||
|
peers = append(peers, inactivePeer)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
return peers, nil
|
||||||
|
}
|
||||||
|
|
||||||
func ConvertSliceToMap(existingLabels []string) map[string]struct{} {
|
func ConvertSliceToMap(existingLabels []string) map[string]struct{} {
|
||||||
labelMap := make(map[string]struct{}, len(existingLabels))
|
labelMap := make(map[string]struct{}, len(existingLabels))
|
||||||
for _, label := range existingLabels {
|
for _, label := range existingLabels {
|
||||||
|
@ -8,9 +8,8 @@ import (
|
|||||||
"time"
|
"time"
|
||||||
|
|
||||||
b "github.com/hashicorp/go-secure-stdlib/base62"
|
b "github.com/hashicorp/go-secure-stdlib/base62"
|
||||||
"github.com/rs/xid"
|
|
||||||
|
|
||||||
"github.com/netbirdio/netbird/base62"
|
"github.com/netbirdio/netbird/base62"
|
||||||
|
"github.com/rs/xid"
|
||||||
)
|
)
|
||||||
|
|
||||||
const (
|
const (
|
||||||
|
@ -335,7 +335,7 @@ func (am *DefaultAccountManager) GetPolicy(ctx context.Context, accountID, polic
|
|||||||
return nil, status.Errorf(status.PermissionDenied, "only users with admin power are allowed to view policies")
|
return nil, status.Errorf(status.PermissionDenied, "only users with admin power are allowed to view policies")
|
||||||
}
|
}
|
||||||
|
|
||||||
return am.Store.GetPolicyByID(ctx, LockingStrengthShare, policyID, accountID)
|
return am.Store.GetPolicyByID(ctx, LockingStrengthShare, accountID, policyID)
|
||||||
}
|
}
|
||||||
|
|
||||||
// SavePolicy in the store
|
// SavePolicy in the store
|
||||||
|
@ -25,7 +25,7 @@ func (am *DefaultAccountManager) GetPostureChecks(ctx context.Context, accountID
|
|||||||
return nil, status.Errorf(status.PermissionDenied, errMsgPostureAdminOnly)
|
return nil, status.Errorf(status.PermissionDenied, errMsgPostureAdminOnly)
|
||||||
}
|
}
|
||||||
|
|
||||||
return am.Store.GetPostureChecksByID(ctx, LockingStrengthShare, postureChecksID, accountID)
|
return am.Store.GetPostureChecksByID(ctx, LockingStrengthShare, accountID, postureChecksID)
|
||||||
}
|
}
|
||||||
|
|
||||||
// SavePostureChecks saves a posture check.
|
// SavePostureChecks saves a posture check.
|
||||||
@ -49,7 +49,7 @@ func (am *DefaultAccountManager) SavePostureChecks(ctx context.Context, accountI
|
|||||||
if isUpdate {
|
if isUpdate {
|
||||||
action = activity.PostureCheckUpdated
|
action = activity.PostureCheckUpdated
|
||||||
|
|
||||||
if _, err := transaction.GetPostureChecksByID(ctx, LockingStrengthShare, postureChecks.ID, accountID); err != nil {
|
if _, err := transaction.GetPostureChecksByID(ctx, LockingStrengthShare, accountID, postureChecks.ID); err != nil {
|
||||||
return fmt.Errorf("failed to get posture checks: %w", err)
|
return fmt.Errorf("failed to get posture checks: %w", err)
|
||||||
}
|
}
|
||||||
|
|
||||||
@ -114,7 +114,7 @@ func (am *DefaultAccountManager) DeletePostureChecks(ctx context.Context, accoun
|
|||||||
return err
|
return err
|
||||||
}
|
}
|
||||||
|
|
||||||
postureChecks, err := am.Store.GetPostureChecksByID(ctx, LockingStrengthShare, postureChecksID, accountID)
|
postureChecks, err := am.Store.GetPostureChecksByID(ctx, LockingStrengthShare, accountID, postureChecksID)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return err
|
return err
|
||||||
}
|
}
|
||||||
@ -124,7 +124,7 @@ func (am *DefaultAccountManager) DeletePostureChecks(ctx context.Context, accoun
|
|||||||
return fmt.Errorf("failed to increment network serial: %w", err)
|
return fmt.Errorf("failed to increment network serial: %w", err)
|
||||||
}
|
}
|
||||||
|
|
||||||
if err = transaction.DeletePostureChecks(ctx, LockingStrengthUpdate, postureChecksID, accountID); err != nil {
|
if err = transaction.DeletePostureChecks(ctx, LockingStrengthUpdate, accountID, postureChecksID); err != nil {
|
||||||
return fmt.Errorf("failed to delete posture checks: %w", err)
|
return fmt.Errorf("failed to delete posture checks: %w", err)
|
||||||
}
|
}
|
||||||
return nil
|
return nil
|
||||||
|
@ -56,13 +56,35 @@ func (am *DefaultAccountManager) GetRoute(ctx context.Context, accountID string,
|
|||||||
return nil, status.Errorf(status.PermissionDenied, "only users with admin power can view Network Routes")
|
return nil, status.Errorf(status.PermissionDenied, "only users with admin power can view Network Routes")
|
||||||
}
|
}
|
||||||
|
|
||||||
return am.Store.GetRouteByID(ctx, LockingStrengthShare, string(routeID), accountID)
|
return am.Store.GetRouteByID(ctx, LockingStrengthShare, accountID, string(routeID))
|
||||||
|
}
|
||||||
|
|
||||||
|
// GetRoutesByPrefixOrDomains return list of routes by account and route prefix
|
||||||
|
func (am *DefaultAccountManager) GetRoutesByPrefixOrDomains(accountID string, prefix netip.Prefix, domains domain.List) ([]*route.Route, error) {
|
||||||
|
accountRoutes, err := am.Store.GetAccountRoutes(context.Background(), LockingStrengthShare, accountID)
|
||||||
|
if err != nil {
|
||||||
|
return nil, err
|
||||||
|
}
|
||||||
|
|
||||||
|
routes := make([]*route.Route, 0)
|
||||||
|
for _, r := range accountRoutes {
|
||||||
|
dynamic := r.IsDynamic()
|
||||||
|
if dynamic && r.Domains.PunycodeString() == domains.PunycodeString() ||
|
||||||
|
!dynamic && r.Network.String() == prefix.String() {
|
||||||
|
routes = append(routes, r)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
return routes, nil
|
||||||
}
|
}
|
||||||
|
|
||||||
// checkRoutePrefixOrDomainsExistForPeers checks if a route with a given prefix exists for a single peer or multiple peer groups.
|
// checkRoutePrefixOrDomainsExistForPeers checks if a route with a given prefix exists for a single peer or multiple peer groups.
|
||||||
func (am *DefaultAccountManager) checkRoutePrefixOrDomainsExistForPeers(account *Account, peerID string, routeID route.ID, peerGroupIDs []string, prefix netip.Prefix, domains domain.List) error {
|
func (am *DefaultAccountManager) checkRoutePrefixOrDomainsExistForPeers(account *Account, peerID string, routeID route.ID, peerGroupIDs []string, prefix netip.Prefix, domains domain.List) error {
|
||||||
// routes can have both peer and peer_groups
|
// routes can have both peer and peer_groups
|
||||||
routesWithPrefix := account.GetRoutesByPrefixOrDomains(prefix, domains)
|
routesWithPrefix, err := am.GetRoutesByPrefixOrDomains(account.Id, prefix, domains)
|
||||||
|
if err != nil {
|
||||||
|
return err
|
||||||
|
}
|
||||||
|
|
||||||
// lets remember all the peers and the peer groups from routesWithPrefix
|
// lets remember all the peers and the peer groups from routesWithPrefix
|
||||||
seenPeers := make(map[string]bool)
|
seenPeers := make(map[string]bool)
|
||||||
@ -81,8 +103,8 @@ func (am *DefaultAccountManager) checkRoutePrefixOrDomainsExistForPeers(account
|
|||||||
for _, groupID := range prefixRoute.PeerGroups {
|
for _, groupID := range prefixRoute.PeerGroups {
|
||||||
seenPeerGroups[groupID] = true
|
seenPeerGroups[groupID] = true
|
||||||
|
|
||||||
group := account.GetGroup(groupID)
|
group, err := am.Store.GetGroupByID(context.Background(), LockingStrengthShare, account.Id, groupID)
|
||||||
if group == nil {
|
if err != nil || group == nil {
|
||||||
return status.Errorf(
|
return status.Errorf(
|
||||||
status.InvalidArgument, "failed to add route with %s - peer group %s doesn't exist",
|
status.InvalidArgument, "failed to add route with %s - peer group %s doesn't exist",
|
||||||
getRouteDescriptor(prefix, domains), groupID,
|
getRouteDescriptor(prefix, domains), groupID,
|
||||||
@ -97,10 +119,11 @@ func (am *DefaultAccountManager) checkRoutePrefixOrDomainsExistForPeers(account
|
|||||||
|
|
||||||
if peerID != "" {
|
if peerID != "" {
|
||||||
// check that peerID exists and is not in any route as single peer or part of the group
|
// check that peerID exists and is not in any route as single peer or part of the group
|
||||||
peer := account.GetPeer(peerID)
|
peer, err := am.Store.GetPeerByID(context.Background(), LockingStrengthShare, account.Id, peerID)
|
||||||
if peer == nil {
|
if err != nil || peer == nil {
|
||||||
return status.Errorf(status.InvalidArgument, "peer with ID %s not found", peerID)
|
return status.Errorf(status.InvalidArgument, "peer with ID %s not found", peerID)
|
||||||
}
|
}
|
||||||
|
|
||||||
if _, ok := seenPeers[peerID]; ok {
|
if _, ok := seenPeers[peerID]; ok {
|
||||||
return status.Errorf(status.AlreadyExists,
|
return status.Errorf(status.AlreadyExists,
|
||||||
"failed to add route with %s - peer %s already has this route", getRouteDescriptor(prefix, domains), peerID)
|
"failed to add route with %s - peer %s already has this route", getRouteDescriptor(prefix, domains), peerID)
|
||||||
@ -109,7 +132,11 @@ func (am *DefaultAccountManager) checkRoutePrefixOrDomainsExistForPeers(account
|
|||||||
|
|
||||||
// check that peerGroupIDs are not in any route peerGroups list
|
// check that peerGroupIDs are not in any route peerGroups list
|
||||||
for _, groupID := range peerGroupIDs {
|
for _, groupID := range peerGroupIDs {
|
||||||
group := account.GetGroup(groupID) // we validated the group existence before entering this function, no need to check again.
|
// we validated the group existence before entering this function, no need to check again.
|
||||||
|
group, err := am.Store.GetGroupByID(context.Background(), LockingStrengthShare, groupID, account.Id)
|
||||||
|
if err != nil || group == nil {
|
||||||
|
return status.Errorf(status.InvalidArgument, "group with ID %s not found", peerID)
|
||||||
|
}
|
||||||
|
|
||||||
if _, ok := seenPeerGroups[groupID]; ok {
|
if _, ok := seenPeerGroups[groupID]; ok {
|
||||||
return status.Errorf(
|
return status.Errorf(
|
||||||
@ -120,10 +147,11 @@ func (am *DefaultAccountManager) checkRoutePrefixOrDomainsExistForPeers(account
|
|||||||
// check that the peers from peerGroupIDs groups are not the same peers we saw in routesWithPrefix
|
// check that the peers from peerGroupIDs groups are not the same peers we saw in routesWithPrefix
|
||||||
for _, id := range group.Peers {
|
for _, id := range group.Peers {
|
||||||
if _, ok := seenPeers[id]; ok {
|
if _, ok := seenPeers[id]; ok {
|
||||||
peer := account.GetPeer(id)
|
peer, err := am.Store.GetPeerByID(context.Background(), LockingStrengthShare, peerID, account.Id)
|
||||||
if peer == nil {
|
if err != nil {
|
||||||
return status.Errorf(status.InvalidArgument, "peer with ID %s not found", peerID)
|
return status.Errorf(status.InvalidArgument, "peer with ID %s not found", peerID)
|
||||||
}
|
}
|
||||||
|
|
||||||
return status.Errorf(status.AlreadyExists,
|
return status.Errorf(status.AlreadyExists,
|
||||||
"failed to add route with %s - peer %s from the group %s already has this route",
|
"failed to add route with %s - peer %s from the group %s already has this route",
|
||||||
getRouteDescriptor(prefix, domains), peer.Name, group.Name)
|
getRouteDescriptor(prefix, domains), peer.Name, group.Name)
|
||||||
@ -146,6 +174,15 @@ func (am *DefaultAccountManager) CreateRoute(ctx context.Context, accountID stri
|
|||||||
unlock := am.Store.AcquireWriteLockByUID(ctx, accountID)
|
unlock := am.Store.AcquireWriteLockByUID(ctx, accountID)
|
||||||
defer unlock()
|
defer unlock()
|
||||||
|
|
||||||
|
user, err := am.Store.GetUserByUserID(ctx, LockingStrengthShare, userID)
|
||||||
|
if err != nil {
|
||||||
|
return nil, err
|
||||||
|
}
|
||||||
|
|
||||||
|
if user.AccountID != accountID {
|
||||||
|
return nil, status.Errorf(status.PermissionDenied, "the user has no permission to access account data")
|
||||||
|
}
|
||||||
|
|
||||||
account, err := am.Store.GetAccount(ctx, accountID)
|
account, err := am.Store.GetAccount(ctx, accountID)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return nil, err
|
return nil, err
|
||||||
@ -181,17 +218,17 @@ func (am *DefaultAccountManager) CreateRoute(ctx context.Context, accountID stri
|
|||||||
newRoute.ID = route.ID(xid.New().String())
|
newRoute.ID = route.ID(xid.New().String())
|
||||||
|
|
||||||
if len(peerGroupIDs) > 0 {
|
if len(peerGroupIDs) > 0 {
|
||||||
err = validateGroups(peerGroupIDs, account.Groups)
|
//err = validateGroups(peerGroupIDs, account.Groups)
|
||||||
if err != nil {
|
//if err != nil {
|
||||||
return nil, err
|
// return nil, err
|
||||||
}
|
//}
|
||||||
}
|
}
|
||||||
|
|
||||||
if len(accessControlGroupIDs) > 0 {
|
if len(accessControlGroupIDs) > 0 {
|
||||||
err = validateGroups(accessControlGroupIDs, account.Groups)
|
//err = validateGroups(accessControlGroupIDs, account.Groups)
|
||||||
if err != nil {
|
//if err != nil {
|
||||||
return nil, err
|
// return nil, err
|
||||||
}
|
//}
|
||||||
}
|
}
|
||||||
|
|
||||||
err = am.checkRoutePrefixOrDomainsExistForPeers(account, peerID, newRoute.ID, peerGroupIDs, prefix, domains)
|
err = am.checkRoutePrefixOrDomainsExistForPeers(account, peerID, newRoute.ID, peerGroupIDs, prefix, domains)
|
||||||
@ -207,10 +244,10 @@ func (am *DefaultAccountManager) CreateRoute(ctx context.Context, accountID stri
|
|||||||
return nil, status.Errorf(status.InvalidArgument, "identifier should be between 1 and %d", route.MaxNetIDChar)
|
return nil, status.Errorf(status.InvalidArgument, "identifier should be between 1 and %d", route.MaxNetIDChar)
|
||||||
}
|
}
|
||||||
|
|
||||||
err = validateGroups(groups, account.Groups)
|
//err = validateGroups(groups, account.Groups)
|
||||||
if err != nil {
|
//if err != nil {
|
||||||
return nil, err
|
// return nil, err
|
||||||
}
|
//}
|
||||||
|
|
||||||
newRoute.Peer = peerID
|
newRoute.Peer = peerID
|
||||||
newRoute.PeerGroups = peerGroupIDs
|
newRoute.PeerGroups = peerGroupIDs
|
||||||
@ -290,17 +327,17 @@ func (am *DefaultAccountManager) SaveRoute(ctx context.Context, accountID, userI
|
|||||||
}
|
}
|
||||||
|
|
||||||
if len(routeToSave.PeerGroups) > 0 {
|
if len(routeToSave.PeerGroups) > 0 {
|
||||||
err = validateGroups(routeToSave.PeerGroups, account.Groups)
|
//err = validateGroups(routeToSave.PeerGroups, account.Groups)
|
||||||
if err != nil {
|
//if err != nil {
|
||||||
return err
|
// return err
|
||||||
}
|
//}
|
||||||
}
|
}
|
||||||
|
|
||||||
if len(routeToSave.AccessControlGroups) > 0 {
|
if len(routeToSave.AccessControlGroups) > 0 {
|
||||||
err = validateGroups(routeToSave.AccessControlGroups, account.Groups)
|
//err = validateGroups(routeToSave.AccessControlGroups, account.Groups)
|
||||||
if err != nil {
|
//if err != nil {
|
||||||
return err
|
// return err
|
||||||
}
|
//}
|
||||||
}
|
}
|
||||||
|
|
||||||
err = am.checkRoutePrefixOrDomainsExistForPeers(account, routeToSave.Peer, routeToSave.ID, routeToSave.Copy().PeerGroups, routeToSave.Network, routeToSave.Domains)
|
err = am.checkRoutePrefixOrDomainsExistForPeers(account, routeToSave.Peer, routeToSave.ID, routeToSave.Copy().PeerGroups, routeToSave.Network, routeToSave.Domains)
|
||||||
@ -308,10 +345,10 @@ func (am *DefaultAccountManager) SaveRoute(ctx context.Context, accountID, userI
|
|||||||
return err
|
return err
|
||||||
}
|
}
|
||||||
|
|
||||||
err = validateGroups(routeToSave.Groups, account.Groups)
|
//err = validateGroups(routeToSave.Groups, account.Groups)
|
||||||
if err != nil {
|
//if err != nil {
|
||||||
return err
|
// return err
|
||||||
}
|
//}
|
||||||
|
|
||||||
account.Routes[routeToSave.ID] = routeToSave
|
account.Routes[routeToSave.ID] = routeToSave
|
||||||
|
|
||||||
|
@ -287,7 +287,7 @@ func (am *DefaultAccountManager) SaveSetupKey(ctx context.Context, accountID str
|
|||||||
return nil, err
|
return nil, err
|
||||||
}
|
}
|
||||||
|
|
||||||
oldKey, err := am.Store.GetSetupKeyByID(ctx, LockingStrengthShare, keyToSave.Id, accountID)
|
oldKey, err := am.Store.GetSetupKeyByID(ctx, LockingStrengthShare, accountID, keyToSave.Id)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return nil, err
|
return nil, err
|
||||||
}
|
}
|
||||||
@ -382,7 +382,7 @@ func (am *DefaultAccountManager) GetSetupKey(ctx context.Context, accountID, use
|
|||||||
return nil, status.Errorf(status.Unauthorized, "only users with admin power can view setup keys")
|
return nil, status.Errorf(status.Unauthorized, "only users with admin power can view setup keys")
|
||||||
}
|
}
|
||||||
|
|
||||||
setupKey, err := am.Store.GetSetupKeyByID(ctx, LockingStrengthShare, keyID, accountID)
|
setupKey, err := am.Store.GetSetupKeyByID(ctx, LockingStrengthShare, accountID, keyID)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return nil, err
|
return nil, err
|
||||||
}
|
}
|
||||||
|
@ -541,9 +541,9 @@ func (s *SqlStore) GetUserByUserID(ctx context.Context, lockStrength LockingStre
|
|||||||
return &user, nil
|
return &user, nil
|
||||||
}
|
}
|
||||||
|
|
||||||
func (s *SqlStore) GetAccountUsers(ctx context.Context, accountID string) ([]*User, error) {
|
func (s *SqlStore) GetAccountUsers(ctx context.Context, lockStrength LockingStrength, accountID string) ([]*User, error) {
|
||||||
var users []*User
|
var users []*User
|
||||||
result := s.db.Find(&users, accountIDCondition, accountID)
|
result := s.db.WithContext(ctx).Clauses(clause.Locking{Strength: string(lockStrength)}).Find(&users, accountIDCondition, accountID)
|
||||||
if result.Error != nil {
|
if result.Error != nil {
|
||||||
if errors.Is(result.Error, gorm.ErrRecordNotFound) {
|
if errors.Is(result.Error, gorm.ErrRecordNotFound) {
|
||||||
return nil, status.Errorf(status.NotFound, "accountID not found: index lookup failed")
|
return nil, status.Errorf(status.NotFound, "accountID not found: index lookup failed")
|
||||||
@ -828,7 +828,9 @@ func (s *SqlStore) GetPeerByPeerPubKey(ctx context.Context, lockStrength Locking
|
|||||||
|
|
||||||
func (s *SqlStore) GetAccountSettings(ctx context.Context, lockStrength LockingStrength, accountID string) (*Settings, error) {
|
func (s *SqlStore) GetAccountSettings(ctx context.Context, lockStrength LockingStrength, accountID string) (*Settings, error) {
|
||||||
var accountSettings AccountSettings
|
var accountSettings AccountSettings
|
||||||
if err := s.db.WithContext(ctx).Clauses(clause.Locking{Strength: string(lockStrength)}).Model(&Account{}).Where(idQueryCondition, accountID).First(&accountSettings).Error; err != nil {
|
result := s.db.WithContext(ctx).Clauses(clause.Locking{Strength: string(lockStrength)}).Model(&Account{}).
|
||||||
|
First(&accountSettings, idQueryCondition, accountID)
|
||||||
|
if err := result.Error; err != nil {
|
||||||
if errors.Is(err, gorm.ErrRecordNotFound) {
|
if errors.Is(err, gorm.ErrRecordNotFound) {
|
||||||
return nil, status.Errorf(status.NotFound, "settings not found")
|
return nil, status.Errorf(status.NotFound, "settings not found")
|
||||||
}
|
}
|
||||||
@ -837,6 +839,21 @@ func (s *SqlStore) GetAccountSettings(ctx context.Context, lockStrength LockingS
|
|||||||
return accountSettings.Settings, nil
|
return accountSettings.Settings, nil
|
||||||
}
|
}
|
||||||
|
|
||||||
|
// SaveAccountSettings stores the account settings in DB.
|
||||||
|
func (s *SqlStore) SaveAccountSettings(ctx context.Context, lockStrength LockingStrength, accountID string, settings *Settings) error {
|
||||||
|
result := s.db.WithContext(ctx).Debug().Clauses(clause.Locking{Strength: string(lockStrength)}).Model(&Account{}).
|
||||||
|
Select("*").Where(idQueryCondition, accountID).Updates(&AccountSettings{Settings: settings})
|
||||||
|
if result.Error != nil {
|
||||||
|
return status.Errorf(status.Internal, "failed to save account settings to store: %v", result.Error)
|
||||||
|
}
|
||||||
|
|
||||||
|
if result.RowsAffected == 0 {
|
||||||
|
return status.Errorf(status.NotFound, "account not found")
|
||||||
|
}
|
||||||
|
|
||||||
|
return nil
|
||||||
|
}
|
||||||
|
|
||||||
// SaveUserLastLogin stores the last login time for a user in DB.
|
// SaveUserLastLogin stores the last login time for a user in DB.
|
||||||
func (s *SqlStore) SaveUserLastLogin(ctx context.Context, accountID, userID string, lastLogin time.Time) error {
|
func (s *SqlStore) SaveUserLastLogin(ctx context.Context, accountID, userID string, lastLogin time.Time) error {
|
||||||
var user User
|
var user User
|
||||||
@ -1054,9 +1071,72 @@ func (s *SqlStore) AddPeerToGroup(ctx context.Context, accountId string, peerId
|
|||||||
return nil
|
return nil
|
||||||
}
|
}
|
||||||
|
|
||||||
|
// GetAccountPeers retrieves peers for an account.
|
||||||
|
func (s *SqlStore) GetAccountPeers(ctx context.Context, lockStrength LockingStrength, accountID string) ([]*nbpeer.Peer, error) {
|
||||||
|
var peers []*nbpeer.Peer
|
||||||
|
result := s.db.WithContext(ctx).Clauses(clause.Locking{Strength: string(lockStrength)}).Find(&peers, accountIDCondition, accountID)
|
||||||
|
if err := result.Error; err != nil {
|
||||||
|
log.WithContext(ctx).Errorf("failed to get peers from the store: %s", err)
|
||||||
|
return nil, status.Errorf(status.Internal, "failed to get peers from store")
|
||||||
|
}
|
||||||
|
|
||||||
|
return peers, nil
|
||||||
|
}
|
||||||
|
|
||||||
// GetUserPeers retrieves peers for a user.
|
// GetUserPeers retrieves peers for a user.
|
||||||
func (s *SqlStore) GetUserPeers(ctx context.Context, lockStrength LockingStrength, accountID, userID string) ([]*nbpeer.Peer, error) {
|
func (s *SqlStore) GetUserPeers(ctx context.Context, lockStrength LockingStrength, accountID, userID string) ([]*nbpeer.Peer, error) {
|
||||||
return getRecords[*nbpeer.Peer](s.db.WithContext(ctx).Where("user_id = ?", userID), lockStrength, accountID)
|
var peers []*nbpeer.Peer
|
||||||
|
result := s.db.WithContext(ctx).Clauses(clause.Locking{Strength: string(lockStrength)}).Find(&peers, "account_id = ? AND user_id = ?", accountID, userID)
|
||||||
|
if err := result.Error; err != nil {
|
||||||
|
log.WithContext(ctx).Errorf("failed to get peers from the store: %s", err)
|
||||||
|
return nil, status.Errorf(status.Internal, "failed to get peers from store")
|
||||||
|
}
|
||||||
|
|
||||||
|
return peers, nil
|
||||||
|
}
|
||||||
|
|
||||||
|
// GetAccountPeersWithExpiration retrieves a list of peers that have login expiration enabled and added by a user.
|
||||||
|
func (s *SqlStore) GetAccountPeersWithExpiration(ctx context.Context, lockStrength LockingStrength, accountID string) ([]*nbpeer.Peer, error) {
|
||||||
|
var peers []*nbpeer.Peer
|
||||||
|
result := s.db.WithContext(ctx).Clauses(clause.Locking{Strength: string(lockStrength)}).
|
||||||
|
Where("login_expiration_enabled = ? AND user_id IS NOT NULL AND user_id != ''", true).
|
||||||
|
Find(&peers, accountIDCondition, accountID)
|
||||||
|
if err := result.Error; err != nil {
|
||||||
|
log.WithContext(ctx).Errorf("failed to get peers with expiration from the store: %s", result.Error)
|
||||||
|
return nil, status.Errorf(status.Internal, "failed to get peers with expiration from store")
|
||||||
|
}
|
||||||
|
|
||||||
|
return peers, nil
|
||||||
|
}
|
||||||
|
|
||||||
|
// GetAccountPeersWithInactivity retrieves a list of peers that have login expiration enabled and added by a user.
|
||||||
|
func (s *SqlStore) GetAccountPeersWithInactivity(ctx context.Context, lockStrength LockingStrength, accountID string) ([]*nbpeer.Peer, error) {
|
||||||
|
var peers []*nbpeer.Peer
|
||||||
|
result := s.db.WithContext(ctx).Clauses(clause.Locking{Strength: string(lockStrength)}).
|
||||||
|
Where("inactivity_expiration_enabled = ? AND user_id IS NOT NULL AND user_id != ''", true).
|
||||||
|
Find(&peers, accountIDCondition, accountID)
|
||||||
|
if err := result.Error; err != nil {
|
||||||
|
log.WithContext(ctx).Errorf("failed to get peers with inactivity from the store: %s", result.Error)
|
||||||
|
return nil, status.Errorf(status.Internal, "failed to get peers with inactivity from store")
|
||||||
|
}
|
||||||
|
|
||||||
|
return peers, nil
|
||||||
|
}
|
||||||
|
|
||||||
|
// GetPeerByID retrieves a peer by its ID and account ID.
|
||||||
|
func (s *SqlStore) GetPeerByID(ctx context.Context, lockStrength LockingStrength, accountID, peerID string) (*nbpeer.Peer, error) {
|
||||||
|
var peer *nbpeer.Peer
|
||||||
|
result := s.db.WithContext(ctx).Clauses(clause.Locking{Strength: string(lockStrength)}).
|
||||||
|
First(&peer, accountAndIDQueryCondition, accountID, peerID)
|
||||||
|
if err := result.Error; err != nil {
|
||||||
|
if errors.Is(result.Error, gorm.ErrRecordNotFound) {
|
||||||
|
return nil, status.Errorf(status.NotFound, "peer not found")
|
||||||
|
}
|
||||||
|
log.WithContext(ctx).Errorf("failed to get peer from the store: %s", result.Error)
|
||||||
|
return nil, status.Errorf(status.Internal, "failed to get peer from store")
|
||||||
|
}
|
||||||
|
|
||||||
|
return peer, nil
|
||||||
}
|
}
|
||||||
|
|
||||||
func (s *SqlStore) AddPeerToAccount(ctx context.Context, peer *nbpeer.Peer) error {
|
func (s *SqlStore) AddPeerToAccount(ctx context.Context, peer *nbpeer.Peer) error {
|
||||||
@ -1067,8 +1147,9 @@ func (s *SqlStore) AddPeerToAccount(ctx context.Context, peer *nbpeer.Peer) erro
|
|||||||
return nil
|
return nil
|
||||||
}
|
}
|
||||||
|
|
||||||
func (s *SqlStore) IncrementNetworkSerial(ctx context.Context, accountId string) error {
|
func (s *SqlStore) IncrementNetworkSerial(ctx context.Context, lockStrength LockingStrength, accountId string) error {
|
||||||
result := s.db.WithContext(ctx).Model(&Account{}).Where(idQueryCondition, accountId).Update("network_serial", gorm.Expr("network_serial + 1"))
|
result := s.db.WithContext(ctx).Clauses(clause.Locking{Strength: string(lockStrength)}).Model(&Account{}).
|
||||||
|
Where(idQueryCondition, accountId).Update("network_serial", gorm.Expr("network_serial + 1"))
|
||||||
if result.Error != nil {
|
if result.Error != nil {
|
||||||
return status.Errorf(status.Internal, "issue incrementing network serial count: %s", result.Error)
|
return status.Errorf(status.Internal, "issue incrementing network serial count: %s", result.Error)
|
||||||
}
|
}
|
||||||
@ -1113,6 +1194,19 @@ func (s *SqlStore) GetAccountDNSSettings(ctx context.Context, lockStrength Locki
|
|||||||
return &accountDNSSettings.DNSSettings, nil
|
return &accountDNSSettings.DNSSettings, nil
|
||||||
}
|
}
|
||||||
|
|
||||||
|
// SaveDNSSettings saves the DNS settings to the store.
|
||||||
|
func (s *SqlStore) SaveDNSSettings(ctx context.Context, lockStrength LockingStrength, accountID string, settings *DNSSettings) error {
|
||||||
|
result := s.db.WithContext(ctx).Clauses(clause.Locking{Strength: string(lockStrength)}).Model(&Account{}).
|
||||||
|
Where(idQueryCondition, accountID).Updates(&AccountDNSSettings{DNSSettings: *settings})
|
||||||
|
if result.Error != nil {
|
||||||
|
return status.Errorf(status.Internal, "failed to save dns settings to store: %v", result.Error)
|
||||||
|
}
|
||||||
|
if result.RowsAffected == 0 {
|
||||||
|
return status.Errorf(status.NotFound, "account not found")
|
||||||
|
}
|
||||||
|
return nil
|
||||||
|
}
|
||||||
|
|
||||||
// AccountExists checks whether an account exists by the given ID.
|
// AccountExists checks whether an account exists by the given ID.
|
||||||
func (s *SqlStore) AccountExists(ctx context.Context, lockStrength LockingStrength, id string) (bool, error) {
|
func (s *SqlStore) AccountExists(ctx context.Context, lockStrength LockingStrength, id string) (bool, error) {
|
||||||
var accountID string
|
var accountID string
|
||||||
@ -1146,16 +1240,24 @@ func (s *SqlStore) GetAccountDomainAndCategory(ctx context.Context, lockStrength
|
|||||||
}
|
}
|
||||||
|
|
||||||
// GetGroupByID retrieves a group by ID and account ID.
|
// GetGroupByID retrieves a group by ID and account ID.
|
||||||
func (s *SqlStore) GetGroupByID(ctx context.Context, lockStrength LockingStrength, groupID, accountID string) (*nbgroup.Group, error) {
|
func (s *SqlStore) GetGroupByID(ctx context.Context, lockStrength LockingStrength, accountID, groupID string) (*nbgroup.Group, error) {
|
||||||
return getRecordByID[nbgroup.Group](s.db.WithContext(ctx).Preload(clause.Associations), lockStrength, groupID, accountID)
|
var group *nbgroup.Group
|
||||||
|
result := s.db.WithContext(ctx).Clauses(clause.Locking{Strength: string(lockStrength)}).
|
||||||
|
Find(&group, accountAndIDQueryCondition, accountID, groupID)
|
||||||
|
if err := result.Error; err != nil {
|
||||||
|
log.WithContext(ctx).Errorf("failed to get group from the store: %s", err)
|
||||||
|
return nil, status.Errorf(status.Internal, "failed to get group from store")
|
||||||
|
}
|
||||||
|
|
||||||
|
return group, nil
|
||||||
}
|
}
|
||||||
|
|
||||||
// GetGroupByName retrieves a group by name and account ID.
|
// GetGroupByName retrieves a group by name and account ID.
|
||||||
func (s *SqlStore) GetGroupByName(ctx context.Context, lockStrength LockingStrength, groupName, accountID string) (*nbgroup.Group, error) {
|
func (s *SqlStore) GetGroupByName(ctx context.Context, lockStrength LockingStrength, accountID, groupName string) (*nbgroup.Group, error) {
|
||||||
var group nbgroup.Group
|
var group nbgroup.Group
|
||||||
|
|
||||||
result := s.db.WithContext(ctx).Clauses(clause.Locking{Strength: string(lockStrength)}).Preload(clause.Associations).
|
result := s.db.WithContext(ctx).Clauses(clause.Locking{Strength: string(lockStrength)}).Preload(clause.Associations).
|
||||||
Order("json_array_length(peers) DESC").First(&group, "name = ? and account_id = ?", groupName, accountID)
|
Order("json_array_length(peers) DESC").First(&group, "account_id = ? AND name = ?", accountID, groupName)
|
||||||
if err := result.Error; err != nil {
|
if err := result.Error; err != nil {
|
||||||
if errors.Is(result.Error, gorm.ErrRecordNotFound) {
|
if errors.Is(result.Error, gorm.ErrRecordNotFound) {
|
||||||
return nil, status.Errorf(status.NotFound, "group not found")
|
return nil, status.Errorf(status.NotFound, "group not found")
|
||||||
@ -1174,69 +1276,335 @@ func (s *SqlStore) SaveGroup(ctx context.Context, lockStrength LockingStrength,
|
|||||||
return nil
|
return nil
|
||||||
}
|
}
|
||||||
|
|
||||||
|
// DeleteGroup deletes a group from the database.
|
||||||
|
func (s *SqlStore) DeleteGroup(ctx context.Context, lockStrength LockingStrength, accountID, groupID string) error {
|
||||||
|
result := s.db.WithContext(ctx).Clauses(clause.Locking{Strength: string(lockStrength)}).
|
||||||
|
Delete(&nbgroup.Group{}, accountAndIDQueryCondition, accountID, groupID)
|
||||||
|
if err := result.Error; err != nil {
|
||||||
|
log.WithContext(ctx).Errorf("failed to delete group from the store: %s", result.Error)
|
||||||
|
return status.Errorf(status.Internal, "failed to delete group from store")
|
||||||
|
}
|
||||||
|
|
||||||
|
if result.RowsAffected == 0 {
|
||||||
|
return status.Errorf(status.NotFound, "group not found")
|
||||||
|
}
|
||||||
|
|
||||||
|
return nil
|
||||||
|
}
|
||||||
|
|
||||||
|
// DeleteGroups deletes groups from the database.
|
||||||
|
func (s *SqlStore) DeleteGroups(ctx context.Context, strength LockingStrength, accountID string, groupIDs []string) error {
|
||||||
|
result := s.db.WithContext(ctx).Clauses(clause.Locking{Strength: string(strength)}).
|
||||||
|
Delete(&nbgroup.Group{}, " account_id = ? AND id IN ?", accountID, groupIDs)
|
||||||
|
if result.Error != nil {
|
||||||
|
log.WithContext(ctx).Errorf("failed to delete groups from the store: %v", result.Error)
|
||||||
|
return status.Errorf(status.Internal, "failed to delete groups from store: %v", result.Error)
|
||||||
|
}
|
||||||
|
|
||||||
|
return nil
|
||||||
|
}
|
||||||
|
|
||||||
// GetAccountPolicies retrieves policies for an account.
|
// GetAccountPolicies retrieves policies for an account.
|
||||||
func (s *SqlStore) GetAccountPolicies(ctx context.Context, lockStrength LockingStrength, accountID string) ([]*Policy, error) {
|
func (s *SqlStore) GetAccountPolicies(ctx context.Context, lockStrength LockingStrength, accountID string) ([]*Policy, error) {
|
||||||
return getRecords[*Policy](s.db.WithContext(ctx).Preload(clause.Associations), lockStrength, accountID)
|
var policies []*Policy
|
||||||
|
result := s.db.WithContext(ctx).Clauses(clause.Locking{Strength: string(lockStrength)}).
|
||||||
|
Preload(clause.Associations).Find(&policies, accountIDCondition, accountID)
|
||||||
|
if err := result.Error; err != nil {
|
||||||
|
log.WithContext(ctx).Errorf("failed to get policies from the store: %s", result.Error)
|
||||||
|
return nil, status.Errorf(status.Internal, "failed to get policies from store")
|
||||||
|
}
|
||||||
|
|
||||||
|
return policies, nil
|
||||||
}
|
}
|
||||||
|
|
||||||
// GetPolicyByID retrieves a policy by its ID and account ID.
|
// GetPolicyByID retrieves a policy by its ID and account ID.
|
||||||
func (s *SqlStore) GetPolicyByID(ctx context.Context, lockStrength LockingStrength, policyID string, accountID string) (*Policy, error) {
|
func (s *SqlStore) GetPolicyByID(ctx context.Context, lockStrength LockingStrength, accountID, policyID string) (*Policy, error) {
|
||||||
return getRecordByID[Policy](s.db.WithContext(ctx).Preload(clause.Associations), lockStrength, policyID, accountID)
|
var policy *Policy
|
||||||
|
result := s.db.WithContext(ctx).Clauses(clause.Locking{Strength: string(lockStrength)}).
|
||||||
|
Preload(clause.Associations).Find(&policy, accountAndIDQueryCondition, accountID, policyID)
|
||||||
|
if err := result.Error; err != nil {
|
||||||
|
log.WithContext(ctx).Errorf("failed to get policy from the store: %s", err)
|
||||||
|
return nil, status.Errorf(status.Internal, "failed to get policy from store")
|
||||||
|
}
|
||||||
|
|
||||||
|
return policy, nil
|
||||||
|
}
|
||||||
|
|
||||||
|
// SavePolicy saves a policy to the database.
|
||||||
|
func (s *SqlStore) SavePolicy(ctx context.Context, lockStrength LockingStrength, policy *Policy) error {
|
||||||
|
result := s.db.WithContext(ctx).Session(&gorm.Session{FullSaveAssociations: true}).
|
||||||
|
Clauses(clause.Locking{Strength: string(lockStrength)}).Save(policy)
|
||||||
|
if err := result.Error; err != nil {
|
||||||
|
log.WithContext(ctx).Errorf("failed to save policy to the store: %s", err)
|
||||||
|
return status.Errorf(status.Internal, "failed to save policy to store")
|
||||||
|
}
|
||||||
|
return nil
|
||||||
|
}
|
||||||
|
|
||||||
|
func (s *SqlStore) DeletePolicy(ctx context.Context, lockStrength LockingStrength, policyID, accountID string) error {
|
||||||
|
result := s.db.WithContext(ctx).Clauses(clause.Locking{Strength: string(lockStrength)}).
|
||||||
|
Delete(&Policy{}, accountAndIDQueryCondition, accountID, policyID)
|
||||||
|
if err := result.Error; err != nil {
|
||||||
|
log.WithContext(ctx).Errorf("failed to delete policy from the store: %s", err)
|
||||||
|
return status.Errorf(status.Internal, "failed to delete policy from store")
|
||||||
|
}
|
||||||
|
|
||||||
|
if result.RowsAffected == 0 {
|
||||||
|
return status.Errorf(status.NotFound, "policy not found")
|
||||||
|
}
|
||||||
|
|
||||||
|
return nil
|
||||||
}
|
}
|
||||||
|
|
||||||
// GetAccountPostureChecks retrieves posture checks for an account.
|
// GetAccountPostureChecks retrieves posture checks for an account.
|
||||||
func (s *SqlStore) GetAccountPostureChecks(ctx context.Context, lockStrength LockingStrength, accountID string) ([]*posture.Checks, error) {
|
func (s *SqlStore) GetAccountPostureChecks(ctx context.Context, lockStrength LockingStrength, accountID string) ([]*posture.Checks, error) {
|
||||||
return getRecords[*posture.Checks](s.db.WithContext(ctx), lockStrength, accountID)
|
var postureChecks []*posture.Checks
|
||||||
|
result := s.db.WithContext(ctx).Clauses(clause.Locking{Strength: string(lockStrength)}).Find(&postureChecks, accountIDCondition, accountID)
|
||||||
|
if err := result.Error; err != nil {
|
||||||
|
log.WithContext(ctx).Errorf("failed to get posture checks from the store: %s", err)
|
||||||
|
return nil, status.Errorf(status.Internal, "failed to get posture checks from store")
|
||||||
|
}
|
||||||
|
|
||||||
|
return postureChecks, nil
|
||||||
}
|
}
|
||||||
|
|
||||||
// GetPostureChecksByID retrieves posture checks by their ID and account ID.
|
// GetPostureChecksByID retrieves posture checks by their ID and account ID.
|
||||||
func (s *SqlStore) GetPostureChecksByID(ctx context.Context, lockStrength LockingStrength, postureCheckID string, accountID string) (*posture.Checks, error) {
|
func (s *SqlStore) GetPostureChecksByID(ctx context.Context, lockStrength LockingStrength, accountID, postureCheckID string) (*posture.Checks, error) {
|
||||||
return getRecordByID[posture.Checks](s.db.WithContext(ctx), lockStrength, postureCheckID, accountID)
|
var postureCheck *posture.Checks
|
||||||
|
result := s.db.WithContext(ctx).Clauses(clause.Locking{Strength: string(lockStrength)}).
|
||||||
|
First(&postureCheck, accountAndIDQueryCondition, accountID, postureCheckID)
|
||||||
|
if err := result.Error; err != nil {
|
||||||
|
if errors.Is(err, gorm.ErrRecordNotFound) {
|
||||||
|
return nil, status.Errorf(status.NotFound, "posture check not found")
|
||||||
|
}
|
||||||
|
log.WithContext(ctx).Errorf("failed to get posture check from the store: %s", err)
|
||||||
|
return nil, status.Errorf(status.Internal, "failed to get posture check from store")
|
||||||
|
}
|
||||||
|
|
||||||
|
return postureCheck, nil
|
||||||
|
}
|
||||||
|
|
||||||
|
// SavePostureChecks saves a posture checks to the database.
|
||||||
|
func (s *SqlStore) SavePostureChecks(ctx context.Context, lockStrength LockingStrength, postureCheck *posture.Checks) error {
|
||||||
|
result := s.db.WithContext(ctx).Clauses(clause.Locking{Strength: string(lockStrength)}).Save(postureCheck)
|
||||||
|
if err := result.Error; err != nil {
|
||||||
|
if errors.Is(err, gorm.ErrDuplicatedKey) {
|
||||||
|
return status.Errorf(status.InvalidArgument, "name should be unique")
|
||||||
|
}
|
||||||
|
log.WithContext(ctx).Errorf("failed to save posture checks to the store: %s", err)
|
||||||
|
return status.Errorf(status.Internal, "failed to save posture checks to store")
|
||||||
|
}
|
||||||
|
|
||||||
|
return nil
|
||||||
|
}
|
||||||
|
|
||||||
|
// DeletePostureChecks deletes a posture checks from the database.
|
||||||
|
func (s *SqlStore) DeletePostureChecks(ctx context.Context, lockStrength LockingStrength, accountID, postureChecksID string) error {
|
||||||
|
result := s.db.WithContext(ctx).Clauses(clause.Locking{Strength: string(lockStrength)}).
|
||||||
|
Delete(&posture.Checks{}, accountAndIDQueryCondition, accountID, postureChecksID)
|
||||||
|
if err := result.Error; err != nil {
|
||||||
|
log.WithContext(ctx).Errorf("failed to delete posture checks from the store: %s", err)
|
||||||
|
return status.Errorf(status.Internal, "failed to delete posture checks from store")
|
||||||
|
}
|
||||||
|
|
||||||
|
if result.RowsAffected == 0 {
|
||||||
|
return status.Errorf(status.NotFound, "posture checks not found")
|
||||||
|
}
|
||||||
|
|
||||||
|
return nil
|
||||||
}
|
}
|
||||||
|
|
||||||
// GetAccountRoutes retrieves network routes for an account.
|
// GetAccountRoutes retrieves network routes for an account.
|
||||||
func (s *SqlStore) GetAccountRoutes(ctx context.Context, lockStrength LockingStrength, accountID string) ([]*route.Route, error) {
|
func (s *SqlStore) GetAccountRoutes(ctx context.Context, lockStrength LockingStrength, accountID string) ([]*route.Route, error) {
|
||||||
return getRecords[*route.Route](s.db.WithContext(ctx), lockStrength, accountID)
|
var routes []*route.Route
|
||||||
|
result := s.db.WithContext(ctx).Clauses(clause.Locking{Strength: string(lockStrength)}).
|
||||||
|
Find(&routes, accountIDCondition, accountID)
|
||||||
|
if err := result.Error; err != nil {
|
||||||
|
log.WithContext(ctx).Errorf("failed to get routes from the store: %s", err)
|
||||||
|
return nil, status.Errorf(status.Internal, "failed to get routes from store")
|
||||||
|
}
|
||||||
|
|
||||||
|
return routes, nil
|
||||||
}
|
}
|
||||||
|
|
||||||
// GetRouteByID retrieves a route by its ID and account ID.
|
// GetRouteByID retrieves a route by its ID and account ID.
|
||||||
func (s *SqlStore) GetRouteByID(ctx context.Context, lockStrength LockingStrength, routeID string, accountID string) (*route.Route, error) {
|
func (s *SqlStore) GetRouteByID(ctx context.Context, lockStrength LockingStrength, accountID string, routeID string) (*route.Route, error) {
|
||||||
return getRecordByID[route.Route](s.db.WithContext(ctx), lockStrength, routeID, accountID)
|
var route *route.Route
|
||||||
|
result := s.db.WithContext(ctx).Clauses(clause.Locking{Strength: string(lockStrength)}).
|
||||||
|
First(&route, accountAndIDQueryCondition, accountID, routeID)
|
||||||
|
if err := result.Error; err != nil {
|
||||||
|
if errors.Is(err, gorm.ErrRecordNotFound) {
|
||||||
|
return nil, status.Errorf(status.NotFound, "route not found")
|
||||||
|
}
|
||||||
|
log.WithContext(ctx).Errorf("failed to get route from the store: %s", err)
|
||||||
|
return nil, status.Errorf(status.Internal, "failed to get route from store")
|
||||||
|
}
|
||||||
|
|
||||||
|
return route, nil
|
||||||
|
}
|
||||||
|
|
||||||
|
// SaveRoute saves a route to the database.
|
||||||
|
func (s *SqlStore) SaveRoute(ctx context.Context, lockStrength LockingStrength, route *route.Route) error {
|
||||||
|
result := s.db.WithContext(ctx).Clauses(clause.Locking{Strength: string(lockStrength)}).Save(route)
|
||||||
|
if err := result.Error; err != nil {
|
||||||
|
log.WithContext(ctx).Errorf("failed to save route to the store: %s", err)
|
||||||
|
return status.Errorf(status.Internal, "failed to save route to store")
|
||||||
|
}
|
||||||
|
|
||||||
|
return nil
|
||||||
|
}
|
||||||
|
|
||||||
|
// DeleteRoute deletes a route from the database.
|
||||||
|
func (s *SqlStore) DeleteRoute(ctx context.Context, lockStrength LockingStrength, accountID, routeID string) error {
|
||||||
|
result := s.db.Clauses(clause.Locking{Strength: string(lockStrength)}).
|
||||||
|
Delete(&route.Route{}, accountAndIDQueryCondition, accountID, routeID)
|
||||||
|
if err := result.Error; err != nil {
|
||||||
|
log.WithContext(ctx).Errorf("failed to delete route from the store: %s", err)
|
||||||
|
return status.Errorf(status.Internal, "failed to delete route from store")
|
||||||
|
}
|
||||||
|
|
||||||
|
if result.RowsAffected == 0 {
|
||||||
|
return status.Errorf(status.NotFound, "route not found")
|
||||||
|
}
|
||||||
|
|
||||||
|
return nil
|
||||||
}
|
}
|
||||||
|
|
||||||
// GetAccountSetupKeys retrieves setup keys for an account.
|
// GetAccountSetupKeys retrieves setup keys for an account.
|
||||||
func (s *SqlStore) GetAccountSetupKeys(ctx context.Context, lockStrength LockingStrength, accountID string) ([]*SetupKey, error) {
|
func (s *SqlStore) GetAccountSetupKeys(ctx context.Context, lockStrength LockingStrength, accountID string) ([]*SetupKey, error) {
|
||||||
return getRecords[*SetupKey](s.db.WithContext(ctx), lockStrength, accountID)
|
var setupKeys []*SetupKey
|
||||||
|
result := s.db.WithContext(ctx).Clauses(clause.Locking{Strength: string(lockStrength)}).
|
||||||
|
Find(&setupKeys, accountIDCondition, accountID)
|
||||||
|
if err := result.Error; err != nil {
|
||||||
|
log.WithContext(ctx).Errorf("failed to get setup keys from the store: %s", err)
|
||||||
|
return nil, status.Errorf(status.Internal, "failed to get setup keys from store")
|
||||||
|
}
|
||||||
|
|
||||||
|
return setupKeys, nil
|
||||||
}
|
}
|
||||||
|
|
||||||
// GetSetupKeyByID retrieves a setup key by its ID and account ID.
|
// GetSetupKeyByID retrieves a setup key by its ID and account ID.
|
||||||
func (s *SqlStore) GetSetupKeyByID(ctx context.Context, lockStrength LockingStrength, setupKeyID string, accountID string) (*SetupKey, error) {
|
func (s *SqlStore) GetSetupKeyByID(ctx context.Context, lockStrength LockingStrength, accountID, setupKeyID string) (*SetupKey, error) {
|
||||||
return getRecordByID[SetupKey](s.db.WithContext(ctx), lockStrength, setupKeyID, accountID)
|
var setupKey *SetupKey
|
||||||
|
result := s.db.WithContext(ctx).Clauses(clause.Locking{Strength: string(lockStrength)}).
|
||||||
|
First(&setupKey, accountAndIDQueryCondition, accountID, setupKeyID)
|
||||||
|
if err := result.Error; err != nil {
|
||||||
|
if errors.Is(err, gorm.ErrRecordNotFound) {
|
||||||
|
return nil, status.Errorf(status.NotFound, "setup key not found")
|
||||||
|
}
|
||||||
|
log.WithContext(ctx).Errorf("failed to get setup key from the store: %s", err)
|
||||||
|
return nil, status.Errorf(status.Internal, "failed to get setup key from store")
|
||||||
|
}
|
||||||
|
|
||||||
|
return setupKey, nil
|
||||||
|
}
|
||||||
|
|
||||||
|
// SaveSetupKey saves a setup key to the database.
|
||||||
|
func (s *SqlStore) SaveSetupKey(ctx context.Context, lockStrength LockingStrength, setupKey *SetupKey) error {
|
||||||
|
result := s.db.WithContext(ctx).Session(&gorm.Session{FullSaveAssociations: true}).
|
||||||
|
Clauses(clause.Locking{Strength: string(lockStrength)}).Save(setupKey)
|
||||||
|
if err := result.Error; err != nil {
|
||||||
|
log.WithContext(ctx).Errorf("failed to save name server group to the store: %s", err)
|
||||||
|
return status.Errorf(status.Internal, "failed to save name server group to store")
|
||||||
|
}
|
||||||
|
|
||||||
|
return nil
|
||||||
}
|
}
|
||||||
|
|
||||||
// GetAccountNameServerGroups retrieves name server groups for an account.
|
// GetAccountNameServerGroups retrieves name server groups for an account.
|
||||||
func (s *SqlStore) GetAccountNameServerGroups(ctx context.Context, lockStrength LockingStrength, accountID string) ([]*nbdns.NameServerGroup, error) {
|
func (s *SqlStore) GetAccountNameServerGroups(ctx context.Context, lockStrength LockingStrength, accountID string) ([]*nbdns.NameServerGroup, error) {
|
||||||
return getRecords[*nbdns.NameServerGroup](s.db.WithContext(ctx), lockStrength, accountID)
|
var nsGroups []*nbdns.NameServerGroup
|
||||||
|
result := s.db.WithContext(ctx).Clauses(clause.Locking{Strength: string(lockStrength)}).Find(&nsGroups, accountIDCondition, accountID)
|
||||||
|
if err := result.Error; err != nil {
|
||||||
|
log.WithContext(ctx).Errorf("failed to get name server groups from the store: %s", err)
|
||||||
|
return nil, status.Errorf(status.Internal, "failed to get name server groups from store")
|
||||||
|
}
|
||||||
|
|
||||||
|
return nsGroups, nil
|
||||||
}
|
}
|
||||||
|
|
||||||
// GetNameServerGroupByID retrieves a name server group by its ID and account ID.
|
// GetNameServerGroupByID retrieves a name server group by its ID and account ID.
|
||||||
func (s *SqlStore) GetNameServerGroupByID(ctx context.Context, lockStrength LockingStrength, nsGroupID string, accountID string) (*nbdns.NameServerGroup, error) {
|
func (s *SqlStore) GetNameServerGroupByID(ctx context.Context, lockStrength LockingStrength, accountID, nsGroupID string) (*nbdns.NameServerGroup, error) {
|
||||||
return getRecordByID[nbdns.NameServerGroup](s.db.WithContext(ctx), lockStrength, nsGroupID, accountID)
|
var nsGroup *nbdns.NameServerGroup
|
||||||
}
|
result := s.db.WithContext(ctx).Clauses(clause.Locking{Strength: string(lockStrength)}).
|
||||||
|
First(&nsGroup, accountAndIDQueryCondition, accountID, nsGroupID)
|
||||||
// getRecords retrieves records from the database based on the account ID.
|
|
||||||
func getRecords[T any](db *gorm.DB, lockStrength LockingStrength, accountID string) ([]T, error) {
|
|
||||||
var record []T
|
|
||||||
|
|
||||||
result := db.Clauses(clause.Locking{Strength: string(lockStrength)}).Find(&record, accountIDCondition, accountID)
|
|
||||||
if err := result.Error; err != nil {
|
if err := result.Error; err != nil {
|
||||||
parts := strings.Split(fmt.Sprintf("%T", record), ".")
|
if errors.Is(err, gorm.ErrRecordNotFound) {
|
||||||
recordType := parts[len(parts)-1]
|
return nil, status.Errorf(status.NotFound, "name server group not found")
|
||||||
|
}
|
||||||
return nil, status.Errorf(status.Internal, "failed to get account %ss from store: %v", recordType, err)
|
log.WithContext(ctx).Errorf("failed to get name server group from the store: %s", err)
|
||||||
|
return nil, status.Errorf(status.Internal, "failed to get name server group from store")
|
||||||
}
|
}
|
||||||
|
|
||||||
return record, nil
|
return nsGroup, nil
|
||||||
|
}
|
||||||
|
|
||||||
|
// SaveNameServerGroup saves a name server group to the database.
|
||||||
|
func (s *SqlStore) SaveNameServerGroup(ctx context.Context, lockStrength LockingStrength, nameServerGroup *nbdns.NameServerGroup) error {
|
||||||
|
result := s.db.WithContext(ctx).Clauses(clause.Locking{Strength: string(lockStrength)}).Save(nameServerGroup)
|
||||||
|
if err := result.Error; err != nil {
|
||||||
|
log.WithContext(ctx).Errorf("failed to save name server group to the store: %s", err)
|
||||||
|
return status.Errorf(status.Internal, "failed to save name server group to store")
|
||||||
|
}
|
||||||
|
return nil
|
||||||
|
}
|
||||||
|
|
||||||
|
// DeleteNameServerGroup deletes a name server group from the database.
|
||||||
|
func (s *SqlStore) DeleteNameServerGroup(ctx context.Context, lockStrength LockingStrength, accountID, nameServerGroupID string) error {
|
||||||
|
result := s.db.Clauses(clause.Locking{Strength: string(lockStrength)}).Delete(&nbdns.NameServerGroup{}, accountAndIDQueryCondition, accountID, nameServerGroupID)
|
||||||
|
if err := result.Error; err != nil {
|
||||||
|
log.WithContext(ctx).Errorf("failed to delete name server group from the store: %s", err)
|
||||||
|
return status.Errorf(status.Internal, "failed to delete name server group from store")
|
||||||
|
}
|
||||||
|
|
||||||
|
if result.RowsAffected == 0 {
|
||||||
|
return status.Errorf(status.NotFound, "name server group not found")
|
||||||
|
}
|
||||||
|
|
||||||
|
return nil
|
||||||
|
}
|
||||||
|
|
||||||
|
// GetPATByID retrieves a personal access token by its ID and user ID.
|
||||||
|
func (s *SqlStore) GetPATByID(ctx context.Context, lockStrength LockingStrength, patID string, userID string) (*PersonalAccessToken, error) {
|
||||||
|
var pat PersonalAccessToken
|
||||||
|
result := s.db.WithContext(ctx).Clauses(clause.Locking{Strength: string(lockStrength)}).
|
||||||
|
First(&pat, "id = ? AND user_id = ?", patID, userID)
|
||||||
|
if err := result.Error; err != nil {
|
||||||
|
if errors.Is(result.Error, gorm.ErrRecordNotFound) {
|
||||||
|
return nil, status.Errorf(status.NotFound, "PAT not found")
|
||||||
|
}
|
||||||
|
log.WithContext(ctx).Errorf("failed to get PAT from the store: %s", err)
|
||||||
|
return nil, status.Errorf(status.Internal, "failed to get PAT from store")
|
||||||
|
}
|
||||||
|
|
||||||
|
return &pat, nil
|
||||||
|
}
|
||||||
|
|
||||||
|
// SavePAT saves a personal access token to the database.
|
||||||
|
func (s *SqlStore) SavePAT(ctx context.Context, lockStrength LockingStrength, pat *PersonalAccessToken) error {
|
||||||
|
result := s.db.WithContext(ctx).Clauses(clause.Locking{Strength: string(lockStrength)}).Save(pat)
|
||||||
|
if err := result.Error; err != nil {
|
||||||
|
log.WithContext(ctx).Errorf("failed to save PAT to the store: %s", err)
|
||||||
|
return status.Errorf(status.Internal, "failed to save PAT to store")
|
||||||
|
}
|
||||||
|
|
||||||
|
return nil
|
||||||
|
}
|
||||||
|
|
||||||
|
// DeletePAT deletes a personal access token from the database.
|
||||||
|
func (s *SqlStore) DeletePAT(ctx context.Context, lockStrength LockingStrength, userID, patID string) error {
|
||||||
|
result := s.db.WithContext(ctx).Clauses(clause.Locking{Strength: string(lockStrength)}).
|
||||||
|
Delete(&PersonalAccessToken{}, "user_id = ? AND id = ?", userID, patID)
|
||||||
|
if err := result.Error; err != nil {
|
||||||
|
log.WithContext(ctx).Errorf("failed to delete PAT from the store: %s", err)
|
||||||
|
return status.Errorf(status.Internal, "failed to delete PAT from store")
|
||||||
|
}
|
||||||
|
|
||||||
|
if result.RowsAffected == 0 {
|
||||||
|
return status.Errorf(status.NotFound, "PAT not found")
|
||||||
|
}
|
||||||
|
|
||||||
|
return nil
|
||||||
}
|
}
|
||||||
|
|
||||||
// getRecordByID retrieves a record by its ID and account ID from the database.
|
// getRecordByID retrieves a record by its ID and account ID from the database.
|
||||||
|
@ -1181,7 +1181,7 @@ func TestSqlite_CreateAndGetObjectInTransaction(t *testing.T) {
|
|||||||
t.Fatal("failed to save group")
|
t.Fatal("failed to save group")
|
||||||
return err
|
return err
|
||||||
}
|
}
|
||||||
group, err = transaction.GetGroupByID(context.Background(), LockingStrengthUpdate, group.ID, group.AccountID)
|
group, err = transaction.GetGroupByID(context.Background(), LockingStrengthUpdate, group.AccountID, group.ID)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
t.Fatal("failed to get group")
|
t.Fatal("failed to get group")
|
||||||
return err
|
return err
|
||||||
|
@ -56,13 +56,15 @@ type Store interface {
|
|||||||
GetAccountIDByPrivateDomain(ctx context.Context, lockStrength LockingStrength, domain string) (string, error)
|
GetAccountIDByPrivateDomain(ctx context.Context, lockStrength LockingStrength, domain string) (string, error)
|
||||||
GetAccountSettings(ctx context.Context, lockStrength LockingStrength, accountID string) (*Settings, error)
|
GetAccountSettings(ctx context.Context, lockStrength LockingStrength, accountID string) (*Settings, error)
|
||||||
GetAccountDNSSettings(ctx context.Context, lockStrength LockingStrength, accountID string) (*DNSSettings, error)
|
GetAccountDNSSettings(ctx context.Context, lockStrength LockingStrength, accountID string) (*DNSSettings, error)
|
||||||
|
SaveDNSSettings(ctx context.Context, lockStrength LockingStrength, accountID string, settings *DNSSettings) error
|
||||||
|
SaveAccountSettings(ctx context.Context, lockStrength LockingStrength, accountID string, settings *Settings) error
|
||||||
SaveAccount(ctx context.Context, account *Account) error
|
SaveAccount(ctx context.Context, account *Account) error
|
||||||
DeleteAccount(ctx context.Context, account *Account) error
|
DeleteAccount(ctx context.Context, account *Account) error
|
||||||
UpdateAccountDomainAttributes(ctx context.Context, accountID string, domain string, category string, isPrimaryDomain bool) error
|
UpdateAccountDomainAttributes(ctx context.Context, accountID string, domain string, category string, isPrimaryDomain bool) error
|
||||||
|
|
||||||
GetUserByTokenID(ctx context.Context, tokenID string) (*User, error)
|
GetUserByTokenID(ctx context.Context, tokenID string) (*User, error)
|
||||||
GetUserByUserID(ctx context.Context, lockStrength LockingStrength, userID string) (*User, error)
|
GetUserByUserID(ctx context.Context, lockStrength LockingStrength, userID string) (*User, error)
|
||||||
GetAccountUsers(ctx context.Context, accountID string) ([]*User, error)
|
GetAccountUsers(ctx context.Context, lockStrength LockingStrength, accountID string) ([]*User, error)
|
||||||
SaveUsers(accountID string, users map[string]*User) error
|
SaveUsers(accountID string, users map[string]*User) error
|
||||||
SaveUser(ctx context.Context, lockStrength LockingStrength, user *User) error
|
SaveUser(ctx context.Context, lockStrength LockingStrength, user *User) error
|
||||||
SaveUserLastLogin(ctx context.Context, accountID, userID string, lastLogin time.Time) error
|
SaveUserLastLogin(ctx context.Context, accountID, userID string, lastLogin time.Time) error
|
||||||
@ -71,24 +73,34 @@ type Store interface {
|
|||||||
DeleteTokenID2UserIDIndex(tokenID string) error
|
DeleteTokenID2UserIDIndex(tokenID string) error
|
||||||
|
|
||||||
GetAccountGroups(ctx context.Context, accountID string) ([]*nbgroup.Group, error)
|
GetAccountGroups(ctx context.Context, accountID string) ([]*nbgroup.Group, error)
|
||||||
GetGroupByID(ctx context.Context, lockStrength LockingStrength, groupID, accountID string) (*nbgroup.Group, error)
|
GetGroupByID(ctx context.Context, lockStrength LockingStrength, accountID, groupID string) (*nbgroup.Group, error)
|
||||||
GetGroupByName(ctx context.Context, lockStrength LockingStrength, groupName, accountID string) (*nbgroup.Group, error)
|
GetGroupByName(ctx context.Context, lockStrength LockingStrength, accountID, groupName string) (*nbgroup.Group, error)
|
||||||
SaveGroups(ctx context.Context, lockStrength LockingStrength, groups []*nbgroup.Group) error
|
SaveGroups(ctx context.Context, lockStrength LockingStrength, groups []*nbgroup.Group) error
|
||||||
SaveGroup(ctx context.Context, lockStrength LockingStrength, group *nbgroup.Group) error
|
SaveGroup(ctx context.Context, lockStrength LockingStrength, group *nbgroup.Group) error
|
||||||
|
DeleteGroup(ctx context.Context, lockStrength LockingStrength, accountID, groupID string) error
|
||||||
|
DeleteGroups(ctx context.Context, strength LockingStrength, accountID string, groupIDs []string) error
|
||||||
|
|
||||||
GetAccountPolicies(ctx context.Context, lockStrength LockingStrength, accountID string) ([]*Policy, error)
|
GetAccountPolicies(ctx context.Context, lockStrength LockingStrength, accountID string) ([]*Policy, error)
|
||||||
GetPolicyByID(ctx context.Context, lockStrength LockingStrength, policyID string, accountID string) (*Policy, error)
|
GetPolicyByID(ctx context.Context, lockStrength LockingStrength, accountID, policyID string) (*Policy, error)
|
||||||
|
SavePolicy(ctx context.Context, lockStrength LockingStrength, policy *Policy) error
|
||||||
|
DeletePolicy(ctx context.Context, lockStrength LockingStrength, accountID, policyID string) error
|
||||||
|
|
||||||
GetPostureCheckByChecksDefinition(accountID string, checks *posture.ChecksDefinition) (*posture.Checks, error)
|
GetPostureCheckByChecksDefinition(accountID string, checks *posture.ChecksDefinition) (*posture.Checks, error)
|
||||||
GetAccountPostureChecks(ctx context.Context, lockStrength LockingStrength, accountID string) ([]*posture.Checks, error)
|
GetAccountPostureChecks(ctx context.Context, lockStrength LockingStrength, accountID string) ([]*posture.Checks, error)
|
||||||
GetPostureChecksByID(ctx context.Context, lockStrength LockingStrength, postureCheckID string, accountID string) (*posture.Checks, error)
|
GetPostureChecksByID(ctx context.Context, lockStrength LockingStrength, accountID, postureCheckID string) (*posture.Checks, error)
|
||||||
|
SavePostureChecks(ctx context.Context, lockStrength LockingStrength, postureCheck *posture.Checks) error
|
||||||
|
DeletePostureChecks(ctx context.Context, lockStrength LockingStrength, accountID, postureChecksID string) error
|
||||||
|
|
||||||
GetPeerLabelsInAccount(ctx context.Context, lockStrength LockingStrength, accountId string) ([]string, error)
|
GetPeerLabelsInAccount(ctx context.Context, lockStrength LockingStrength, accountId string) ([]string, error)
|
||||||
AddPeerToAllGroup(ctx context.Context, accountID string, peerID string) error
|
AddPeerToAllGroup(ctx context.Context, accountID string, peerID string) error
|
||||||
AddPeerToGroup(ctx context.Context, accountId string, peerId string, groupID string) error
|
AddPeerToGroup(ctx context.Context, accountId string, peerId string, groupID string) error
|
||||||
AddPeerToAccount(ctx context.Context, peer *nbpeer.Peer) error
|
AddPeerToAccount(ctx context.Context, peer *nbpeer.Peer) error
|
||||||
GetPeerByPeerPubKey(ctx context.Context, lockStrength LockingStrength, peerKey string) (*nbpeer.Peer, error)
|
GetPeerByPeerPubKey(ctx context.Context, lockStrength LockingStrength, peerKey string) (*nbpeer.Peer, error)
|
||||||
|
GetAccountPeers(ctx context.Context, lockStrength LockingStrength, accountID string) ([]*nbpeer.Peer, error)
|
||||||
GetUserPeers(ctx context.Context, lockStrength LockingStrength, accountID, userID string) ([]*nbpeer.Peer, error)
|
GetUserPeers(ctx context.Context, lockStrength LockingStrength, accountID, userID string) ([]*nbpeer.Peer, error)
|
||||||
|
GetAccountPeersWithExpiration(ctx context.Context, lockStrength LockingStrength, accountID string) ([]*nbpeer.Peer, error)
|
||||||
|
GetAccountPeersWithInactivity(ctx context.Context, lockStrength LockingStrength, accountID string) ([]*nbpeer.Peer, error)
|
||||||
|
GetPeerByID(ctx context.Context, lockStrength LockingStrength, accountID string, peerID string) (*nbpeer.Peer, error)
|
||||||
SavePeer(ctx context.Context, accountID string, peer *nbpeer.Peer) error
|
SavePeer(ctx context.Context, accountID string, peer *nbpeer.Peer) error
|
||||||
SavePeerStatus(accountID, peerID string, status nbpeer.PeerStatus) error
|
SavePeerStatus(accountID, peerID string, status nbpeer.PeerStatus) error
|
||||||
SavePeerLocation(accountID string, peer *nbpeer.Peer) error
|
SavePeerLocation(accountID string, peer *nbpeer.Peer) error
|
||||||
@ -96,16 +108,25 @@ type Store interface {
|
|||||||
GetSetupKeyBySecret(ctx context.Context, lockStrength LockingStrength, key string) (*SetupKey, error)
|
GetSetupKeyBySecret(ctx context.Context, lockStrength LockingStrength, key string) (*SetupKey, error)
|
||||||
IncrementSetupKeyUsage(ctx context.Context, setupKeyID string) error
|
IncrementSetupKeyUsage(ctx context.Context, setupKeyID string) error
|
||||||
GetAccountSetupKeys(ctx context.Context, lockStrength LockingStrength, accountID string) ([]*SetupKey, error)
|
GetAccountSetupKeys(ctx context.Context, lockStrength LockingStrength, accountID string) ([]*SetupKey, error)
|
||||||
GetSetupKeyByID(ctx context.Context, lockStrength LockingStrength, setupKeyID string, accountID string) (*SetupKey, error)
|
GetSetupKeyByID(ctx context.Context, lockStrength LockingStrength, accountID, setupKeyID string) (*SetupKey, error)
|
||||||
|
SaveSetupKey(ctx context.Context, lockStrength LockingStrength, setupKey *SetupKey) error
|
||||||
|
|
||||||
GetAccountRoutes(ctx context.Context, lockStrength LockingStrength, accountID string) ([]*route.Route, error)
|
GetAccountRoutes(ctx context.Context, lockStrength LockingStrength, accountID string) ([]*route.Route, error)
|
||||||
GetRouteByID(ctx context.Context, lockStrength LockingStrength, routeID string, accountID string) (*route.Route, error)
|
GetRouteByID(ctx context.Context, lockStrength LockingStrength, accountID, routeID string) (*route.Route, error)
|
||||||
|
SaveRoute(ctx context.Context, lockStrength LockingStrength, route *route.Route) error
|
||||||
|
DeleteRoute(ctx context.Context, lockStrength LockingStrength, accountID, routeID string) error
|
||||||
|
|
||||||
GetAccountNameServerGroups(ctx context.Context, lockStrength LockingStrength, accountID string) ([]*dns.NameServerGroup, error)
|
GetAccountNameServerGroups(ctx context.Context, lockStrength LockingStrength, accountID string) ([]*dns.NameServerGroup, error)
|
||||||
GetNameServerGroupByID(ctx context.Context, lockStrength LockingStrength, nameServerGroupID string, accountID string) (*dns.NameServerGroup, error)
|
GetNameServerGroupByID(ctx context.Context, lockStrength LockingStrength, accountID, nameServerGroupID string) (*dns.NameServerGroup, error)
|
||||||
|
SaveNameServerGroup(ctx context.Context, lockStrength LockingStrength, nameServerGroup *dns.NameServerGroup) error
|
||||||
|
DeleteNameServerGroup(ctx context.Context, lockStrength LockingStrength, accountID, nameServerGroupID string) error
|
||||||
|
|
||||||
|
GetPATByID(ctx context.Context, lockStrength LockingStrength, userID, patID string) (*PersonalAccessToken, error)
|
||||||
|
SavePAT(ctx context.Context, strength LockingStrength, pat *PersonalAccessToken) error
|
||||||
|
DeletePAT(ctx context.Context, strength LockingStrength, userID, patID string) error
|
||||||
|
|
||||||
GetTakenIPs(ctx context.Context, lockStrength LockingStrength, accountId string) ([]net.IP, error)
|
GetTakenIPs(ctx context.Context, lockStrength LockingStrength, accountId string) ([]net.IP, error)
|
||||||
IncrementNetworkSerial(ctx context.Context, accountId string) error
|
IncrementNetworkSerial(ctx context.Context, lockStrength LockingStrength, accountId string) error
|
||||||
GetAccountNetwork(ctx context.Context, lockStrength LockingStrength, accountId string) (*Network, error)
|
GetAccountNetwork(ctx context.Context, lockStrength LockingStrength, accountId string) (*Network, error)
|
||||||
|
|
||||||
GetInstallationID() string
|
GetInstallationID() string
|
||||||
|
@ -546,9 +546,6 @@ func (am *DefaultAccountManager) InviteUser(ctx context.Context, accountID strin
|
|||||||
|
|
||||||
// CreatePAT creates a new PAT for the given user
|
// CreatePAT creates a new PAT for the given user
|
||||||
func (am *DefaultAccountManager) CreatePAT(ctx context.Context, accountID string, initiatorUserID string, targetUserID string, tokenName string, expiresIn int) (*PersonalAccessTokenGenerated, error) {
|
func (am *DefaultAccountManager) CreatePAT(ctx context.Context, accountID string, initiatorUserID string, targetUserID string, tokenName string, expiresIn int) (*PersonalAccessTokenGenerated, error) {
|
||||||
unlock := am.Store.AcquireWriteLockByUID(ctx, accountID)
|
|
||||||
defer unlock()
|
|
||||||
|
|
||||||
if tokenName == "" {
|
if tokenName == "" {
|
||||||
return nil, status.Errorf(status.InvalidArgument, "token name can't be empty")
|
return nil, status.Errorf(status.InvalidArgument, "token name can't be empty")
|
||||||
}
|
}
|
||||||
@ -557,35 +554,28 @@ func (am *DefaultAccountManager) CreatePAT(ctx context.Context, accountID string
|
|||||||
return nil, status.Errorf(status.InvalidArgument, "expiration has to be between 1 and 365")
|
return nil, status.Errorf(status.InvalidArgument, "expiration has to be between 1 and 365")
|
||||||
}
|
}
|
||||||
|
|
||||||
account, err := am.Store.GetAccount(ctx, accountID)
|
executingUser, err := am.Store.GetUserByUserID(ctx, LockingStrengthShare, initiatorUserID)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return nil, err
|
return nil, err
|
||||||
}
|
}
|
||||||
|
|
||||||
targetUser, ok := account.Users[targetUserID]
|
targetUser, err := am.Store.GetUserByUserID(ctx, LockingStrengthShare, targetUserID)
|
||||||
if !ok {
|
if err != nil {
|
||||||
return nil, status.Errorf(status.NotFound, "user not found")
|
return nil, err
|
||||||
}
|
}
|
||||||
|
|
||||||
executingUser, ok := account.Users[initiatorUserID]
|
if !(initiatorUserID == targetUserID || (executingUser.HasAdminPower() && targetUser.IsServiceUser)) ||
|
||||||
if !ok {
|
executingUser.AccountID != accountID {
|
||||||
return nil, status.Errorf(status.NotFound, "user not found")
|
|
||||||
}
|
|
||||||
|
|
||||||
if !(initiatorUserID == targetUserID || (executingUser.HasAdminPower() && targetUser.IsServiceUser)) {
|
|
||||||
return nil, status.Errorf(status.PermissionDenied, "no permission to create PAT for this user")
|
return nil, status.Errorf(status.PermissionDenied, "no permission to create PAT for this user")
|
||||||
}
|
}
|
||||||
|
|
||||||
pat, err := CreateNewPAT(tokenName, expiresIn, executingUser.Id)
|
pat, err := CreateNewPAT(tokenName, expiresIn, targetUser.Id, executingUser.Id)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return nil, status.Errorf(status.Internal, "failed to create PAT: %v", err)
|
return nil, status.Errorf(status.Internal, "failed to create PAT: %v", err)
|
||||||
}
|
}
|
||||||
|
|
||||||
targetUser.PATs[pat.ID] = &pat.PersonalAccessToken
|
if err = am.Store.SavePAT(ctx, LockingStrengthUpdate, &pat.PersonalAccessToken); err != nil {
|
||||||
|
return nil, fmt.Errorf("failed to save PAT: %w", err)
|
||||||
err = am.Store.SaveAccount(ctx, account)
|
|
||||||
if err != nil {
|
|
||||||
return nil, status.Errorf(status.Internal, "failed to save account: %v", err)
|
|
||||||
}
|
}
|
||||||
|
|
||||||
meta := map[string]any{"name": pat.Name, "is_service_user": targetUser.IsServiceUser, "user_name": targetUser.ServiceUserName}
|
meta := map[string]any{"name": pat.Name, "is_service_user": targetUser.IsServiceUser, "user_name": targetUser.ServiceUserName}
|
||||||
@ -596,51 +586,33 @@ func (am *DefaultAccountManager) CreatePAT(ctx context.Context, accountID string
|
|||||||
|
|
||||||
// DeletePAT deletes a specific PAT from a user
|
// DeletePAT deletes a specific PAT from a user
|
||||||
func (am *DefaultAccountManager) DeletePAT(ctx context.Context, accountID string, initiatorUserID string, targetUserID string, tokenID string) error {
|
func (am *DefaultAccountManager) DeletePAT(ctx context.Context, accountID string, initiatorUserID string, targetUserID string, tokenID string) error {
|
||||||
unlock := am.Store.AcquireWriteLockByUID(ctx, accountID)
|
executingUser, err := am.Store.GetUserByUserID(ctx, LockingStrengthShare, initiatorUserID)
|
||||||
defer unlock()
|
|
||||||
|
|
||||||
account, err := am.Store.GetAccount(ctx, accountID)
|
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return status.Errorf(status.NotFound, "account not found: %s", err)
|
return err
|
||||||
}
|
}
|
||||||
|
|
||||||
targetUser, ok := account.Users[targetUserID]
|
targetUser, err := am.Store.GetUserByUserID(ctx, LockingStrengthShare, targetUserID)
|
||||||
if !ok {
|
if err != nil {
|
||||||
return status.Errorf(status.NotFound, "user not found")
|
return err
|
||||||
}
|
}
|
||||||
|
|
||||||
executingUser, ok := account.Users[initiatorUserID]
|
if !(initiatorUserID == targetUserID || (executingUser.HasAdminPower() && targetUser.IsServiceUser)) ||
|
||||||
if !ok {
|
executingUser.AccountID != accountID {
|
||||||
return status.Errorf(status.NotFound, "user not found")
|
|
||||||
}
|
|
||||||
|
|
||||||
if !(initiatorUserID == targetUserID || (executingUser.HasAdminPower() && targetUser.IsServiceUser)) {
|
|
||||||
return status.Errorf(status.PermissionDenied, "no permission to delete PAT for this user")
|
return status.Errorf(status.PermissionDenied, "no permission to delete PAT for this user")
|
||||||
}
|
}
|
||||||
|
|
||||||
pat := targetUser.PATs[tokenID]
|
pat, err := am.Store.GetPATByID(ctx, LockingStrengthShare, tokenID, targetUserID)
|
||||||
if pat == nil {
|
if err != nil {
|
||||||
return status.Errorf(status.NotFound, "PAT not found")
|
return err
|
||||||
}
|
}
|
||||||
|
|
||||||
err = am.Store.DeleteTokenID2UserIDIndex(pat.ID)
|
if err = am.Store.DeletePAT(ctx, LockingStrengthUpdate, tokenID, targetUserID); err != nil {
|
||||||
if err != nil {
|
return fmt.Errorf("failed to delete PAT: %w", err)
|
||||||
return status.Errorf(status.Internal, "Failed to delete token id index: %s", err)
|
|
||||||
}
|
|
||||||
err = am.Store.DeleteHashedPAT2TokenIDIndex(pat.HashedToken)
|
|
||||||
if err != nil {
|
|
||||||
return status.Errorf(status.Internal, "Failed to delete hashed token index: %s", err)
|
|
||||||
}
|
}
|
||||||
|
|
||||||
meta := map[string]any{"name": pat.Name, "is_service_user": targetUser.IsServiceUser, "user_name": targetUser.ServiceUserName}
|
meta := map[string]any{"name": pat.Name, "is_service_user": targetUser.IsServiceUser, "user_name": targetUser.ServiceUserName}
|
||||||
am.StoreEvent(ctx, initiatorUserID, targetUserID, accountID, activity.PersonalAccessTokenDeleted, meta)
|
am.StoreEvent(ctx, initiatorUserID, targetUserID, accountID, activity.PersonalAccessTokenDeleted, meta)
|
||||||
|
|
||||||
delete(targetUser.PATs, tokenID)
|
|
||||||
|
|
||||||
err = am.Store.SaveAccount(ctx, account)
|
|
||||||
if err != nil {
|
|
||||||
return status.Errorf(status.Internal, "Failed to save account: %s", err)
|
|
||||||
}
|
|
||||||
return nil
|
return nil
|
||||||
}
|
}
|
||||||
|
|
||||||
@ -651,22 +623,11 @@ func (am *DefaultAccountManager) GetPAT(ctx context.Context, accountID string, i
|
|||||||
return nil, err
|
return nil, err
|
||||||
}
|
}
|
||||||
|
|
||||||
targetUser, err := am.Store.GetUserByUserID(ctx, LockingStrengthShare, targetUserID)
|
|
||||||
if err != nil {
|
|
||||||
return nil, err
|
|
||||||
}
|
|
||||||
|
|
||||||
if (initiatorUserID != targetUserID && !initiatorUser.IsAdminOrServiceUser()) || initiatorUser.AccountID != accountID {
|
if (initiatorUserID != targetUserID && !initiatorUser.IsAdminOrServiceUser()) || initiatorUser.AccountID != accountID {
|
||||||
return nil, status.Errorf(status.PermissionDenied, "no permission to get PAT for this user")
|
return nil, status.Errorf(status.PermissionDenied, "no permission to get PAT for this user")
|
||||||
}
|
}
|
||||||
|
|
||||||
for _, pat := range targetUser.PATsG {
|
return am.Store.GetPATByID(ctx, LockingStrengthShare, targetUserID, tokenID)
|
||||||
if pat.ID == tokenID {
|
|
||||||
return pat.Copy(), nil
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
return nil, status.Errorf(status.NotFound, "PAT not found")
|
|
||||||
}
|
}
|
||||||
|
|
||||||
// GetAllPATs returns all PATs for a user
|
// GetAllPATs returns all PATs for a user
|
||||||
|
Reference in New Issue
Block a user