Refactor peer to use store methods

Signed-off-by: bcmmbaga <bethuelmbaga12@gmail.com>
This commit is contained in:
bcmmbaga 2024-11-14 19:33:57 +03:00
parent 7d849a92c0
commit c557c98390
No known key found for this signature in database
GPG Key ID: 511EED5C928AD547
6 changed files with 555 additions and 339 deletions

View File

@ -92,7 +92,7 @@ type AccountManager interface {
GetUser(ctx context.Context, claims jwtclaims.AuthorizationClaims) (*User, error) GetUser(ctx context.Context, claims jwtclaims.AuthorizationClaims) (*User, error)
ListUsers(ctx context.Context, accountID string) ([]*User, error) ListUsers(ctx context.Context, accountID string) ([]*User, error)
GetPeers(ctx context.Context, accountID, userID string) ([]*nbpeer.Peer, error) GetPeers(ctx context.Context, accountID, userID string) ([]*nbpeer.Peer, error)
MarkPeerConnected(ctx context.Context, peerKey string, connected bool, realIP net.IP, account *Account) error MarkPeerConnected(ctx context.Context, peerKey string, connected bool, realIP net.IP, accountID string) error
DeletePeer(ctx context.Context, accountID, peerID, userID string) error DeletePeer(ctx context.Context, accountID, peerID, userID string) error
UpdatePeer(ctx context.Context, accountID, userID string, peer *nbpeer.Peer) (*nbpeer.Peer, error) UpdatePeer(ctx context.Context, accountID, userID string, peer *nbpeer.Peer) (*nbpeer.Peer, error)
GetNetworkMap(ctx context.Context, peerID string) (*NetworkMap, error) GetNetworkMap(ctx context.Context, peerID string) (*NetworkMap, error)
@ -112,6 +112,7 @@ type AccountManager interface {
DeleteGroups(ctx context.Context, accountId, userId string, groupIDs []string) error DeleteGroups(ctx context.Context, accountId, userId string, groupIDs []string) error
GroupAddPeer(ctx context.Context, accountId, groupID, peerID string) error GroupAddPeer(ctx context.Context, accountId, groupID, peerID string) error
GroupDeletePeer(ctx context.Context, accountId, groupID, peerID string) error GroupDeletePeer(ctx context.Context, accountId, groupID, peerID string) error
GetPeerGroups(ctx context.Context, accountID, peerID string) ([]*nbgroup.Group, error)
GetPolicy(ctx context.Context, accountID, policyID, userID string) (*Policy, error) GetPolicy(ctx context.Context, accountID, policyID, userID string) (*Policy, error)
SavePolicy(ctx context.Context, accountID, userID string, policy *Policy) (*Policy, error) SavePolicy(ctx context.Context, accountID, userID string, policy *Policy) (*Policy, error)
DeletePolicy(ctx context.Context, accountID, policyID, userID string) error DeletePolicy(ctx context.Context, accountID, policyID, userID string) error
@ -134,7 +135,7 @@ type AccountManager interface {
GetPeer(ctx context.Context, accountID, peerID, userID string) (*nbpeer.Peer, error) GetPeer(ctx context.Context, accountID, peerID, userID string) (*nbpeer.Peer, error)
UpdateAccountSettings(ctx context.Context, accountID, userID string, newSettings *Settings) (*Account, error) UpdateAccountSettings(ctx context.Context, accountID, userID string, newSettings *Settings) (*Account, error)
LoginPeer(ctx context.Context, login PeerLogin) (*nbpeer.Peer, *NetworkMap, []*posture.Checks, error) // used by peer gRPC API LoginPeer(ctx context.Context, login PeerLogin) (*nbpeer.Peer, *NetworkMap, []*posture.Checks, error) // used by peer gRPC API
SyncPeer(ctx context.Context, sync PeerSync, account *Account) (*nbpeer.Peer, *NetworkMap, []*posture.Checks, error) // used by peer gRPC API SyncPeer(ctx context.Context, sync PeerSync, accountID string) (*nbpeer.Peer, *NetworkMap, []*posture.Checks, error) // used by peer gRPC API
GetAllConnectedPeers() (map[string]struct{}, error) GetAllConnectedPeers() (map[string]struct{}, error)
HasConnectedChannel(peerID string) bool HasConnectedChannel(peerID string) bool
GetExternalCacheManager() ExternalCacheManager GetExternalCacheManager() ExternalCacheManager
@ -145,7 +146,7 @@ type AccountManager interface {
GetIdpManager() idp.Manager GetIdpManager() idp.Manager
UpdateIntegratedValidatorGroups(ctx context.Context, accountID string, userID string, groups []string) error UpdateIntegratedValidatorGroups(ctx context.Context, accountID string, userID string, groups []string) error
GroupValidation(ctx context.Context, accountId string, groups []string) (bool, error) GroupValidation(ctx context.Context, accountId string, groups []string) (bool, error)
GetValidatedPeers(account *Account) (map[string]struct{}, error) GetValidatedPeers(ctx context.Context, accountID string) (map[string]struct{}, error)
SyncAndMarkPeer(ctx context.Context, accountID string, peerPubKey string, meta nbpeer.PeerSystemMeta, realIP net.IP) (*nbpeer.Peer, *NetworkMap, []*posture.Checks, error) SyncAndMarkPeer(ctx context.Context, accountID string, peerPubKey string, meta nbpeer.PeerSystemMeta, realIP net.IP) (*nbpeer.Peer, *NetworkMap, []*posture.Checks, error)
OnPeerDisconnected(ctx context.Context, accountID string, peerPubKey string) error OnPeerDisconnected(ctx context.Context, accountID string, peerPubKey string) error
SyncPeerMeta(ctx context.Context, peerPubKey string, meta nbpeer.PeerSystemMeta) error SyncPeerMeta(ctx context.Context, peerPubKey string, meta nbpeer.PeerSystemMeta) error
@ -1160,17 +1161,17 @@ func (am *DefaultAccountManager) UpdateAccountSettings(ctx context.Context, acco
event = activity.AccountPeerLoginExpirationDisabled event = activity.AccountPeerLoginExpirationDisabled
am.peerLoginExpiry.Cancel(ctx, []string{accountID}) am.peerLoginExpiry.Cancel(ctx, []string{accountID})
} else { } else {
am.checkAndSchedulePeerLoginExpiration(ctx, account) am.checkAndSchedulePeerLoginExpiration(ctx, accountID)
} }
am.StoreEvent(ctx, userID, accountID, accountID, event, nil) am.StoreEvent(ctx, userID, accountID, accountID, event, nil)
} }
if oldSettings.PeerLoginExpiration != newSettings.PeerLoginExpiration { if oldSettings.PeerLoginExpiration != newSettings.PeerLoginExpiration {
am.StoreEvent(ctx, userID, accountID, accountID, activity.AccountPeerLoginExpirationDurationUpdated, nil) am.StoreEvent(ctx, userID, accountID, accountID, activity.AccountPeerLoginExpirationDurationUpdated, nil)
am.checkAndSchedulePeerLoginExpiration(ctx, account) am.checkAndSchedulePeerLoginExpiration(ctx, accountID)
} }
err = am.handleInactivityExpirationSettings(ctx, account, oldSettings, newSettings, userID, accountID) err = am.handleInactivityExpirationSettings(ctx, oldSettings, newSettings, userID, accountID)
if err != nil { if err != nil {
return nil, err return nil, err
} }
@ -1185,21 +1186,21 @@ func (am *DefaultAccountManager) UpdateAccountSettings(ctx context.Context, acco
return updatedAccount, nil return updatedAccount, nil
} }
func (am *DefaultAccountManager) handleInactivityExpirationSettings(ctx context.Context, account *Account, oldSettings, newSettings *Settings, userID, accountID string) error { func (am *DefaultAccountManager) handleInactivityExpirationSettings(ctx context.Context, oldSettings, newSettings *Settings, userID, accountID string) error {
if oldSettings.PeerInactivityExpirationEnabled != newSettings.PeerInactivityExpirationEnabled { if oldSettings.PeerInactivityExpirationEnabled != newSettings.PeerInactivityExpirationEnabled {
event := activity.AccountPeerInactivityExpirationEnabled event := activity.AccountPeerInactivityExpirationEnabled
if !newSettings.PeerInactivityExpirationEnabled { if !newSettings.PeerInactivityExpirationEnabled {
event = activity.AccountPeerInactivityExpirationDisabled event = activity.AccountPeerInactivityExpirationDisabled
am.peerInactivityExpiry.Cancel(ctx, []string{accountID}) am.peerInactivityExpiry.Cancel(ctx, []string{accountID})
} else { } else {
am.checkAndSchedulePeerInactivityExpiration(ctx, account) am.checkAndSchedulePeerInactivityExpiration(ctx, accountID)
} }
am.StoreEvent(ctx, userID, accountID, accountID, event, nil) am.StoreEvent(ctx, userID, accountID, accountID, event, nil)
} }
if oldSettings.PeerInactivityExpiration != newSettings.PeerInactivityExpiration { if oldSettings.PeerInactivityExpiration != newSettings.PeerInactivityExpiration {
am.StoreEvent(ctx, userID, accountID, accountID, activity.AccountPeerInactivityExpirationDurationUpdated, nil) am.StoreEvent(ctx, userID, accountID, accountID, activity.AccountPeerInactivityExpirationDurationUpdated, nil)
am.checkAndSchedulePeerInactivityExpiration(ctx, account) am.checkAndSchedulePeerInactivityExpiration(ctx, accountID)
} }
return nil return nil
@ -1207,73 +1208,64 @@ func (am *DefaultAccountManager) handleInactivityExpirationSettings(ctx context.
func (am *DefaultAccountManager) peerLoginExpirationJob(ctx context.Context, accountID string) func() (time.Duration, bool) { func (am *DefaultAccountManager) peerLoginExpirationJob(ctx context.Context, accountID string) func() (time.Duration, bool) {
return func() (time.Duration, bool) { return func() (time.Duration, bool) {
unlock := am.Store.AcquireWriteLockByUID(ctx, accountID) expiredPeers, err := am.getExpiredPeers(ctx, accountID)
defer unlock()
account, err := am.Store.GetAccount(ctx, accountID)
if err != nil { if err != nil {
log.WithContext(ctx).Errorf("failed getting account %s expiring peers", accountID) return 0, false
return account.GetNextPeerExpiration()
} }
expiredPeers := account.GetExpiredPeers()
var peerIDs []string var peerIDs []string
for _, peer := range expiredPeers { for _, peer := range expiredPeers {
peerIDs = append(peerIDs, peer.ID) peerIDs = append(peerIDs, peer.ID)
} }
log.WithContext(ctx).Debugf("discovered %d peers to expire for account %s", len(peerIDs), account.Id) log.WithContext(ctx).Debugf("discovered %d peers to expire for account %s", len(peerIDs), accountID)
if err := am.expireAndUpdatePeers(ctx, account, expiredPeers); err != nil { if err := am.expireAndUpdatePeers(ctx, accountID, expiredPeers); err != nil {
log.WithContext(ctx).Errorf("failed updating account peers while expiring peers for account %s", account.Id) log.WithContext(ctx).Errorf("failed updating account peers while expiring peers for account %s", accountID)
return account.GetNextPeerExpiration() return 0, false
} }
return account.GetNextPeerExpiration() return am.getNextPeerExpiration(ctx, accountID)
} }
} }
func (am *DefaultAccountManager) checkAndSchedulePeerLoginExpiration(ctx context.Context, account *Account) { func (am *DefaultAccountManager) checkAndSchedulePeerLoginExpiration(ctx context.Context, accountID string) {
am.peerLoginExpiry.Cancel(ctx, []string{account.Id}) am.peerLoginExpiry.Cancel(ctx, []string{accountID})
if nextRun, ok := account.GetNextPeerExpiration(); ok { if nextRun, ok := am.getNextPeerExpiration(ctx, accountID); ok {
go am.peerLoginExpiry.Schedule(ctx, nextRun, account.Id, am.peerLoginExpirationJob(ctx, account.Id)) go am.peerLoginExpiry.Schedule(ctx, nextRun, accountID, am.peerLoginExpirationJob(ctx, accountID))
} }
} }
// peerInactivityExpirationJob marks login expired for all inactive peers and returns the minimum duration in which the next peer of the account will expire by inactivity if found // peerInactivityExpirationJob marks login expired for all inactive peers and returns the minimum duration in which the next peer of the account will expire by inactivity if found
func (am *DefaultAccountManager) peerInactivityExpirationJob(ctx context.Context, accountID string) func() (time.Duration, bool) { func (am *DefaultAccountManager) peerInactivityExpirationJob(ctx context.Context, accountID string) func() (time.Duration, bool) {
return func() (time.Duration, bool) { return func() (time.Duration, bool) {
unlock := am.Store.AcquireWriteLockByUID(ctx, accountID) inactivePeers, err := am.getInactivePeers(ctx, accountID)
defer unlock()
account, err := am.Store.GetAccount(ctx, accountID)
if err != nil { if err != nil {
log.Errorf("failed getting account %s expiring peers", accountID) log.WithContext(ctx).Errorf("failed getting inactive peers for account %s", accountID)
return account.GetNextInactivePeerExpiration() return 0, false
} }
expiredPeers := account.GetInactivePeers()
var peerIDs []string var peerIDs []string
for _, peer := range expiredPeers { for _, peer := range inactivePeers {
peerIDs = append(peerIDs, peer.ID) peerIDs = append(peerIDs, peer.ID)
} }
log.Debugf("discovered %d peers to expire for account %s", len(peerIDs), account.Id) log.Debugf("discovered %d peers to expire for account %s", len(peerIDs), accountID)
if err := am.expireAndUpdatePeers(ctx, account, expiredPeers); err != nil { if err := am.expireAndUpdatePeers(ctx, accountID, inactivePeers); err != nil {
log.Errorf("failed updating account peers while expiring peers for account %s", account.Id) log.Errorf("failed updating account peers while expiring peers for account %s", accountID)
return account.GetNextInactivePeerExpiration() return 0, false
} }
return account.GetNextInactivePeerExpiration() return am.getNextInactivePeerExpiration(ctx, accountID)
} }
} }
// checkAndSchedulePeerInactivityExpiration periodically checks for inactive peers to end their sessions // checkAndSchedulePeerInactivityExpiration periodically checks for inactive peers to end their sessions
func (am *DefaultAccountManager) checkAndSchedulePeerInactivityExpiration(ctx context.Context, account *Account) { func (am *DefaultAccountManager) checkAndSchedulePeerInactivityExpiration(ctx context.Context, accountID string) {
am.peerInactivityExpiry.Cancel(ctx, []string{account.Id}) am.peerInactivityExpiry.Cancel(ctx, []string{accountID})
if nextRun, ok := account.GetNextInactivePeerExpiration(); ok { if nextRun, ok := am.getNextInactivePeerExpiration(ctx, accountID); ok {
go am.peerInactivityExpiry.Schedule(ctx, nextRun, account.Id, am.peerInactivityExpirationJob(ctx, account.Id)) go am.peerInactivityExpiry.Schedule(ctx, nextRun, accountID, am.peerInactivityExpirationJob(ctx, accountID))
} }
} }
@ -1409,7 +1401,7 @@ func (am *DefaultAccountManager) GetAccountIDByUserID(ctx context.Context, userI
return "", status.Errorf(status.NotFound, "no valid userID provided") return "", status.Errorf(status.NotFound, "no valid userID provided")
} }
accountID, err := am.Store.GetAccountIDByUserID(userID) accountID, err := am.Store.GetAccountIDByUserID(ctx, LockingStrengthShare, userID)
if err != nil { if err != nil {
if s, ok := status.FromError(err); ok && s.Type() == status.NotFound { if s, ok := status.FromError(err); ok && s.Type() == status.NotFound {
account, err := am.GetOrCreateAccountByUser(ctx, userID, domain) account, err := am.GetOrCreateAccountByUser(ctx, userID, domain)
@ -2188,7 +2180,7 @@ func (am *DefaultAccountManager) getAccountIDWithAuthorizationClaims(ctx context
return "", err return "", err
} }
userAccountID, err := am.Store.GetAccountIDByUserID(claims.UserId) userAccountID, err := am.Store.GetAccountIDByUserID(ctx, LockingStrengthShare, claims.UserId)
if handleNotFound(err) != nil { if handleNotFound(err) != nil {
log.WithContext(ctx).Errorf("error getting account ID by user ID: %v", err) log.WithContext(ctx).Errorf("error getting account ID by user ID: %v", err)
return "", err return "", err
@ -2235,7 +2227,7 @@ func (am *DefaultAccountManager) getPrivateDomainWithGlobalLock(ctx context.Cont
} }
func (am *DefaultAccountManager) handlePrivateAccountWithIDFromClaim(ctx context.Context, claims jwtclaims.AuthorizationClaims) (string, error) { func (am *DefaultAccountManager) handlePrivateAccountWithIDFromClaim(ctx context.Context, claims jwtclaims.AuthorizationClaims) (string, error) {
userAccountID, err := am.Store.GetAccountIDByUserID(claims.UserId) userAccountID, err := am.Store.GetAccountIDByUserID(ctx, LockingStrengthShare, claims.UserId)
if err != nil { if err != nil {
log.WithContext(ctx).Errorf("error getting account ID by user ID: %v", err) log.WithContext(ctx).Errorf("error getting account ID by user ID: %v", err)
return "", err return "", err
@ -2292,17 +2284,12 @@ func (am *DefaultAccountManager) SyncAndMarkPeer(ctx context.Context, accountID
peerUnlock := am.Store.AcquireWriteLockByUID(ctx, peerPubKey) peerUnlock := am.Store.AcquireWriteLockByUID(ctx, peerPubKey)
defer peerUnlock() defer peerUnlock()
account, err := am.Store.GetAccount(ctx, accountID) peer, netMap, postureChecks, err := am.SyncPeer(ctx, PeerSync{WireGuardPubKey: peerPubKey, Meta: meta}, accountID)
if err != nil {
return nil, nil, nil, status.NewGetAccountError(err)
}
peer, netMap, postureChecks, err := am.SyncPeer(ctx, PeerSync{WireGuardPubKey: peerPubKey, Meta: meta}, account)
if err != nil { if err != nil {
return nil, nil, nil, fmt.Errorf("error syncing peer: %w", err) return nil, nil, nil, fmt.Errorf("error syncing peer: %w", err)
} }
err = am.MarkPeerConnected(ctx, peerPubKey, true, realIP, account) err = am.MarkPeerConnected(ctx, peerPubKey, true, realIP, accountID)
if err != nil { if err != nil {
log.WithContext(ctx).Warnf("failed marking peer as connected %s %v", peerPubKey, err) log.WithContext(ctx).Warnf("failed marking peer as connected %s %v", peerPubKey, err)
} }
@ -2316,12 +2303,7 @@ func (am *DefaultAccountManager) OnPeerDisconnected(ctx context.Context, account
peerUnlock := am.Store.AcquireWriteLockByUID(ctx, peerPubKey) peerUnlock := am.Store.AcquireWriteLockByUID(ctx, peerPubKey)
defer peerUnlock() defer peerUnlock()
account, err := am.Store.GetAccount(ctx, accountID) err := am.MarkPeerConnected(ctx, peerPubKey, false, nil, accountID)
if err != nil {
return status.NewGetAccountError(err)
}
err = am.MarkPeerConnected(ctx, peerPubKey, false, nil, account)
if err != nil { if err != nil {
log.WithContext(ctx).Warnf("failed marking peer as connected %s %v", peerPubKey, err) log.WithContext(ctx).Warnf("failed marking peer as connected %s %v", peerPubKey, err)
} }
@ -2339,12 +2321,7 @@ func (am *DefaultAccountManager) SyncPeerMeta(ctx context.Context, peerPubKey st
unlock := am.Store.AcquireReadLockByUID(ctx, accountID) unlock := am.Store.AcquireReadLockByUID(ctx, accountID)
defer unlock() defer unlock()
account, err := am.Store.GetAccount(ctx, accountID) _, _, _, err = am.SyncPeer(ctx, PeerSync{WireGuardPubKey: peerPubKey, Meta: meta, UpdateAccountPeers: true}, accountID)
if err != nil {
return err
}
_, _, _, err = am.SyncPeer(ctx, PeerSync{WireGuardPubKey: peerPubKey, Meta: meta, UpdateAccountPeers: true}, account)
if err != nil { if err != nil {
return mapError(ctx, err) return mapError(ctx, err)
} }

View File

@ -4,6 +4,8 @@ import (
"context" "context"
"errors" "errors"
nbgroup "github.com/netbirdio/netbird/management/server/group"
nbpeer "github.com/netbirdio/netbird/management/server/peer"
log "github.com/sirupsen/logrus" log "github.com/sirupsen/logrus"
"github.com/netbirdio/netbird/management/server/account" "github.com/netbirdio/netbird/management/server/account"
@ -73,6 +75,39 @@ func (am *DefaultAccountManager) GroupValidation(ctx context.Context, accountID
return true, nil return true, nil
} }
func (am *DefaultAccountManager) GetValidatedPeers(account *Account) (map[string]struct{}, error) { func (am *DefaultAccountManager) GetValidatedPeers(ctx context.Context, accountID string) (map[string]struct{}, error) {
return am.integratedPeerValidator.GetValidatedPeers(account.Id, account.Groups, account.Peers, account.Settings.Extra) var err error
var groups []*nbgroup.Group
var peers []*nbpeer.Peer
var settings *Settings
err = am.Store.ExecuteInTransaction(ctx, func(transaction Store) error {
groups, err = transaction.GetAccountGroups(ctx, LockingStrengthShare, accountID)
if err != nil {
return err
}
peers, err = transaction.GetAccountPeers(ctx, LockingStrengthShare, accountID)
if err != nil {
return err
}
settings, err = transaction.GetAccountSettings(ctx, LockingStrengthShare, accountID)
return err
})
if err != nil {
return nil, err
}
groupsMap := make(map[string]*nbgroup.Group, len(groups))
for _, group := range groups {
groupsMap[group.ID] = group
}
peersMap := make(map[string]*nbpeer.Peer, len(peers))
for _, peer := range peers {
peersMap[peer.ID] = peer
}
return am.integratedPeerValidator.GetValidatedPeers(accountID, groupsMap, peersMap, settings.Extra)
} }

View File

@ -47,6 +47,7 @@ type MockAccountManager struct {
DeleteGroupsFunc func(ctx context.Context, accountId, userId string, groupIDs []string) error DeleteGroupsFunc func(ctx context.Context, accountId, userId string, groupIDs []string) error
GroupAddPeerFunc func(ctx context.Context, accountID, groupID, peerID string) error GroupAddPeerFunc func(ctx context.Context, accountID, groupID, peerID string) error
GroupDeletePeerFunc func(ctx context.Context, accountID, groupID, peerID string) error GroupDeletePeerFunc func(ctx context.Context, accountID, groupID, peerID string) error
GetPeerGroupsFunc func(ctx context.Context, accountID, peerID string) ([]*group.Group, error)
DeleteRuleFunc func(ctx context.Context, accountID, ruleID, userID string) error DeleteRuleFunc func(ctx context.Context, accountID, ruleID, userID string) error
GetPolicyFunc func(ctx context.Context, accountID, policyID, userID string) (*server.Policy, error) GetPolicyFunc func(ctx context.Context, accountID, policyID, userID string) (*server.Policy, error)
SavePolicyFunc func(ctx context.Context, accountID, userID string, policy *server.Policy) (*server.Policy, error) SavePolicyFunc func(ctx context.Context, accountID, userID string, policy *server.Policy) (*server.Policy, error)
@ -90,7 +91,7 @@ type MockAccountManager struct {
GetPeerFunc func(ctx context.Context, accountID, peerID, userID string) (*nbpeer.Peer, error) GetPeerFunc func(ctx context.Context, accountID, peerID, userID string) (*nbpeer.Peer, error)
UpdateAccountSettingsFunc func(ctx context.Context, accountID, userID string, newSettings *server.Settings) (*server.Account, error) UpdateAccountSettingsFunc func(ctx context.Context, accountID, userID string, newSettings *server.Settings) (*server.Account, error)
LoginPeerFunc func(ctx context.Context, login server.PeerLogin) (*nbpeer.Peer, *server.NetworkMap, []*posture.Checks, error) LoginPeerFunc func(ctx context.Context, login server.PeerLogin) (*nbpeer.Peer, *server.NetworkMap, []*posture.Checks, error)
SyncPeerFunc func(ctx context.Context, sync server.PeerSync, account *server.Account) (*nbpeer.Peer, *server.NetworkMap, []*posture.Checks, error) SyncPeerFunc func(ctx context.Context, sync server.PeerSync, accountID string) (*nbpeer.Peer, *server.NetworkMap, []*posture.Checks, error)
InviteUserFunc func(ctx context.Context, accountID string, initiatorUserID string, targetUserEmail string) error InviteUserFunc func(ctx context.Context, accountID string, initiatorUserID string, targetUserEmail string) error
GetAllConnectedPeersFunc func() (map[string]struct{}, error) GetAllConnectedPeersFunc func() (map[string]struct{}, error)
HasConnectedChannelFunc func(peerID string) bool HasConnectedChannelFunc func(peerID string) bool
@ -130,7 +131,12 @@ func (am *MockAccountManager) OnPeerDisconnected(_ context.Context, accountID st
panic("implement me") panic("implement me")
} }
func (am *MockAccountManager) GetValidatedPeers(account *server.Account) (map[string]struct{}, error) { func (am *MockAccountManager) GetValidatedPeers(ctx context.Context, accountID string) (map[string]struct{}, error) {
account, err := am.GetAccountFunc(ctx, accountID)
if err != nil {
return nil, err
}
approvedPeers := make(map[string]struct{}) approvedPeers := make(map[string]struct{})
for id := range account.Peers { for id := range account.Peers {
approvedPeers[id] = struct{}{} approvedPeers[id] = struct{}{}
@ -221,7 +227,7 @@ func (am *MockAccountManager) GetAccountIDByUserID(ctx context.Context, userId,
} }
// MarkPeerConnected mock implementation of MarkPeerConnected from server.AccountManager interface // MarkPeerConnected mock implementation of MarkPeerConnected from server.AccountManager interface
func (am *MockAccountManager) MarkPeerConnected(ctx context.Context, peerKey string, connected bool, realIP net.IP, account *server.Account) error { func (am *MockAccountManager) MarkPeerConnected(ctx context.Context, peerKey string, connected bool, realIP net.IP, accountID string) error {
if am.MarkPeerConnectedFunc != nil { if am.MarkPeerConnectedFunc != nil {
return am.MarkPeerConnectedFunc(ctx, peerKey, connected, realIP) return am.MarkPeerConnectedFunc(ctx, peerKey, connected, realIP)
} }
@ -682,9 +688,9 @@ func (am *MockAccountManager) LoginPeer(ctx context.Context, login server.PeerLo
} }
// SyncPeer mocks SyncPeer of the AccountManager interface // SyncPeer mocks SyncPeer of the AccountManager interface
func (am *MockAccountManager) SyncPeer(ctx context.Context, sync server.PeerSync, account *server.Account) (*nbpeer.Peer, *server.NetworkMap, []*posture.Checks, error) { func (am *MockAccountManager) SyncPeer(ctx context.Context, sync server.PeerSync, accountID string) (*nbpeer.Peer, *server.NetworkMap, []*posture.Checks, error) {
if am.SyncPeerFunc != nil { if am.SyncPeerFunc != nil {
return am.SyncPeerFunc(ctx, sync, account) return am.SyncPeerFunc(ctx, sync, accountID)
} }
return nil, nil, nil, status.Errorf(codes.Unimplemented, "method SyncPeer is not implemented") return nil, nil, nil, status.Errorf(codes.Unimplemented, "method SyncPeer is not implemented")
} }
@ -831,3 +837,11 @@ func (am *MockAccountManager) GetAccount(ctx context.Context, accountID string)
} }
return nil, status.Errorf(codes.Unimplemented, "method GetAccount is not implemented") return nil, status.Errorf(codes.Unimplemented, "method GetAccount is not implemented")
} }
// GetPeerGroups mocks GetPeerGroups of the AccountManager interface
func (am *MockAccountManager) GetPeerGroups(ctx context.Context, accountID, peerID string) ([]*group.Group, error) {
if am.GetPeerGroupsFunc != nil {
return am.GetPeerGroupsFunc(ctx, accountID, peerID)
}
return nil, status.Errorf(codes.Unimplemented, "method GetPeerGroups is not implemented")
}

View File

@ -11,8 +11,10 @@ import (
"sync" "sync"
"time" "time"
nbgroup "github.com/netbirdio/netbird/management/server/group"
"github.com/rs/xid" "github.com/rs/xid"
log "github.com/sirupsen/logrus" log "github.com/sirupsen/logrus"
"golang.org/x/exp/maps"
"github.com/netbirdio/netbird/management/server/idp" "github.com/netbirdio/netbird/management/server/idp"
"github.com/netbirdio/netbird/management/server/posture" "github.com/netbirdio/netbird/management/server/posture"
@ -53,43 +55,55 @@ type PeerLogin struct {
// GetPeers returns a list of peers under the given account filtering out peers that do not belong to a user if // GetPeers returns a list of peers under the given account filtering out peers that do not belong to a user if
// the current user is not an admin. // the current user is not an admin.
func (am *DefaultAccountManager) GetPeers(ctx context.Context, accountID, userID string) ([]*nbpeer.Peer, error) { func (am *DefaultAccountManager) GetPeers(ctx context.Context, accountID, userID string) ([]*nbpeer.Peer, error) {
account, err := am.Store.GetAccount(ctx, accountID) user, err := am.Store.GetUserByUserID(ctx, LockingStrengthShare, userID)
if err != nil { if err != nil {
return nil, err return nil, err
} }
user, err := account.FindUser(userID) if user.AccountID != accountID {
return nil, status.NewUserNotPartOfAccountError()
}
settings, err := am.Store.GetAccountSettings(ctx, LockingStrengthShare, accountID)
if err != nil { if err != nil {
return nil, err return nil, err
} }
approvedPeersMap, err := am.GetValidatedPeers(account) if user.IsRegularUser() && settings.RegularUsersViewBlocked {
return []*nbpeer.Peer{}, nil
}
accountPeers, err := am.Store.GetAccountPeers(ctx, LockingStrengthShare, accountID)
if err != nil { if err != nil {
return nil, err return nil, err
} }
peers := make([]*nbpeer.Peer, 0) peers := make([]*nbpeer.Peer, 0)
peersMap := make(map[string]*nbpeer.Peer) peersMap := make(map[string]*nbpeer.Peer)
regularUser := !user.HasAdminPower() && !user.IsServiceUser for _, peer := range accountPeers {
if user.IsRegularUser() && user.Id != peer.UserID {
if regularUser && account.Settings.RegularUsersViewBlocked {
return peers, nil
}
for _, peer := range account.Peers {
if regularUser && user.Id != peer.UserID {
// only display peers that belong to the current user if the current user is not an admin // only display peers that belong to the current user if the current user is not an admin
continue continue
} }
p := peer.Copy() peers = append(peers, peer)
peers = append(peers, p) peersMap[peer.ID] = peer
peersMap[peer.ID] = p
} }
if !regularUser { if user.IsAdminOrServiceUser() {
return peers, nil return peers, nil
} }
account, err := am.requestBuffer.GetAccountWithBackpressure(ctx, accountID)
if err != nil {
return nil, err
}
approvedPeersMap, err := am.GetValidatedPeers(ctx, accountID)
if err != nil {
return nil, err
}
// fetch all the peers that have access to the user's peers // fetch all the peers that have access to the user's peers
for _, peer := range peers { for _, peer := range peers {
aclPeers, _ := account.getPeerConnectionResources(ctx, peer.ID, approvedPeersMap) aclPeers, _ := account.getPeerConnectionResources(ctx, peer.ID, approvedPeersMap)
@ -98,48 +112,46 @@ func (am *DefaultAccountManager) GetPeers(ctx context.Context, accountID, userID
} }
} }
peers = make([]*nbpeer.Peer, 0, len(peersMap)) return maps.Values(peersMap), nil
for _, peer := range peersMap {
peers = append(peers, peer)
}
return peers, nil
} }
// MarkPeerConnected marks peer as connected (true) or disconnected (false) // MarkPeerConnected marks peer as connected (true) or disconnected (false)
func (am *DefaultAccountManager) MarkPeerConnected(ctx context.Context, peerPubKey string, connected bool, realIP net.IP, account *Account) error { func (am *DefaultAccountManager) MarkPeerConnected(ctx context.Context, peerPubKey string, connected bool, realIP net.IP, accountID string) error {
peer, err := account.FindPeerByPubKey(peerPubKey) peer, err := am.Store.GetPeerByPeerPubKey(ctx, LockingStrengthShare, peerPubKey)
if err != nil { if err != nil {
return fmt.Errorf("failed to find peer by pub key: %w", err) return err
} }
expired, err := am.updatePeerStatusAndLocation(ctx, peer, connected, realIP, account) settings, err := am.Store.GetAccountSettings(ctx, LockingStrengthShare, accountID)
if err != nil { if err != nil {
return fmt.Errorf("failed to update peer status and location: %w", err) return err
} }
log.WithContext(ctx).Debugf("mark peer %s connected: %t", peer.ID, connected) expired, err := am.updatePeerStatusAndLocation(ctx, peer, connected, realIP, accountID)
if err != nil {
return err
}
if peer.AddedWithSSOLogin() { if peer.AddedWithSSOLogin() {
if peer.LoginExpirationEnabled && account.Settings.PeerLoginExpirationEnabled { if peer.LoginExpirationEnabled && settings.PeerLoginExpirationEnabled {
am.checkAndSchedulePeerLoginExpiration(ctx, account) am.checkAndSchedulePeerLoginExpiration(ctx, accountID)
} }
if peer.InactivityExpirationEnabled && account.Settings.PeerInactivityExpirationEnabled { if peer.InactivityExpirationEnabled && settings.PeerInactivityExpirationEnabled {
am.checkAndSchedulePeerInactivityExpiration(ctx, account) am.checkAndSchedulePeerInactivityExpiration(ctx, accountID)
} }
} }
if expired { if expired {
// we need to update other peers because when peer login expires all other peers are notified to disconnect from // we need to update other peers because when peer login expires all other peers are notified to disconnect from
// the expired one. Here we notify them that connection is now allowed again. // the expired one. Here we notify them that connection is now allowed again.
am.updateAccountPeers(ctx, account.Id) am.updateAccountPeers(ctx, accountID)
} }
return nil return nil
} }
func (am *DefaultAccountManager) updatePeerStatusAndLocation(ctx context.Context, peer *nbpeer.Peer, connected bool, realIP net.IP, account *Account) (bool, error) { func (am *DefaultAccountManager) updatePeerStatusAndLocation(ctx context.Context, peer *nbpeer.Peer, connected bool, realIP net.IP, accountID string) (bool, error) {
oldStatus := peer.Status.Copy() oldStatus := peer.Status.Copy()
newStatus := oldStatus newStatus := oldStatus
newStatus.LastSeen = time.Now().UTC() newStatus.LastSeen = time.Now().UTC()
@ -159,18 +171,16 @@ func (am *DefaultAccountManager) updatePeerStatusAndLocation(ctx context.Context
peer.Location.CountryCode = location.Country.ISOCode peer.Location.CountryCode = location.Country.ISOCode
peer.Location.CityName = location.City.Names.En peer.Location.CityName = location.City.Names.En
peer.Location.GeoNameID = location.City.GeonameID peer.Location.GeoNameID = location.City.GeonameID
err = am.Store.SavePeerLocation(account.Id, peer) err = am.Store.SavePeerLocation(ctx, LockingStrengthUpdate, accountID, peer)
if err != nil { if err != nil {
log.WithContext(ctx).Warnf("could not store location for peer %s: %s", peer.ID, err) log.WithContext(ctx).Warnf("could not store location for peer %s: %s", peer.ID, err)
} }
} }
} }
account.UpdatePeer(peer) err := am.Store.SavePeerStatus(ctx, LockingStrengthUpdate, accountID, peer.ID, *newStatus)
err := am.Store.SavePeerStatus(account.Id, peer.ID, *newStatus)
if err != nil { if err != nil {
return false, fmt.Errorf("failed to save peer status: %w", err) return false, err
} }
return oldStatus.LoginExpired, nil return oldStatus.LoginExpired, nil
@ -181,37 +191,51 @@ func (am *DefaultAccountManager) UpdatePeer(ctx context.Context, accountID, user
unlock := am.Store.AcquireWriteLockByUID(ctx, accountID) unlock := am.Store.AcquireWriteLockByUID(ctx, accountID)
defer unlock() defer unlock()
account, err := am.Store.GetAccount(ctx, accountID) user, err := am.Store.GetUserByUserID(ctx, LockingStrengthShare, userID)
if err != nil { if err != nil {
return nil, err return nil, err
} }
peer := account.GetPeer(update.ID) if user.AccountID != accountID {
if peer == nil { return nil, status.NewUserNotPartOfAccountError()
return nil, status.Errorf(status.NotFound, "peer %s not found", update.ID) }
peer, err := am.Store.GetPeerByID(ctx, LockingStrengthShare, accountID, update.ID)
if err != nil {
return nil, err
}
settings, err := am.Store.GetAccountSettings(ctx, LockingStrengthShare, accountID)
if err != nil {
return nil, err
}
peerGroupList, err := am.getPeerGroupIDs(ctx, accountID, update.ID)
if err != nil {
return nil, err
} }
var requiresPeerUpdates bool var requiresPeerUpdates bool
update, requiresPeerUpdates, err = am.integratedPeerValidator.ValidatePeer(ctx, update, peer, userID, accountID, am.GetDNSDomain(), account.GetPeerGroupsList(peer.ID), account.Settings.Extra) update, requiresPeerUpdates, err = am.integratedPeerValidator.ValidatePeer(ctx, update, peer, userID, accountID, am.GetDNSDomain(), peerGroupList, settings.Extra)
if err != nil { if err != nil {
return nil, err return nil, err
} }
var sshChanged, peerLabelChanged, loginExpirationChanged, inactivityExpirationChanged bool
if peer.SSHEnabled != update.SSHEnabled { if peer.SSHEnabled != update.SSHEnabled {
peer.SSHEnabled = update.SSHEnabled peer.SSHEnabled = update.SSHEnabled
event := activity.PeerSSHEnabled sshChanged = true
if !update.SSHEnabled {
event = activity.PeerSSHDisabled
}
am.StoreEvent(ctx, userID, peer.IP.String(), accountID, event, peer.EventMeta(am.GetDNSDomain()))
} }
peerLabelUpdated := peer.Name != update.Name if peer.Name != update.Name {
if peerLabelUpdated {
peer.Name = update.Name peer.Name = update.Name
peerLabelChanged = true
existingLabels := account.getPeerDNSLabels() existingLabels, err := am.getPeerDNSLabels(ctx, accountID)
if err != nil {
return nil, err
}
newLabel, err := getPeerHostLabel(peer.Name, existingLabels) newLabel, err := getPeerHostLabel(peer.Name, existingLabels)
if err != nil { if err != nil {
@ -219,133 +243,100 @@ func (am *DefaultAccountManager) UpdatePeer(ctx context.Context, accountID, user
} }
peer.DNSLabel = newLabel peer.DNSLabel = newLabel
am.StoreEvent(ctx, userID, peer.ID, accountID, activity.PeerRenamed, peer.EventMeta(am.GetDNSDomain()))
} }
if peer.LoginExpirationEnabled != update.LoginExpirationEnabled { if peer.LoginExpirationEnabled != update.LoginExpirationEnabled {
if !peer.AddedWithSSOLogin() { if !peer.AddedWithSSOLogin() {
return nil, status.Errorf(status.PreconditionFailed, "this peer hasn't been added with the SSO login, therefore the login expiration can't be updated") return nil, status.Errorf(status.PreconditionFailed, "this peer hasn't been added with the SSO login, therefore the login expiration can't be updated")
} }
peer.LoginExpirationEnabled = update.LoginExpirationEnabled peer.LoginExpirationEnabled = update.LoginExpirationEnabled
loginExpirationChanged = true
}
if peer.InactivityExpirationEnabled != update.InactivityExpirationEnabled {
if !peer.AddedWithSSOLogin() {
return nil, status.Errorf(status.PreconditionFailed, "this peer hasn't been added with the SSO login, therefore the inactivity expiration can't be updated")
}
peer.InactivityExpirationEnabled = update.InactivityExpirationEnabled
inactivityExpirationChanged = true
}
if err = am.Store.SavePeer(ctx, LockingStrengthUpdate, accountID, peer); err != nil {
return nil, err
}
if sshChanged {
event := activity.PeerSSHEnabled
if !peer.SSHEnabled {
event = activity.PeerSSHDisabled
}
am.StoreEvent(ctx, userID, peer.IP.String(), accountID, event, peer.EventMeta(am.GetDNSDomain()))
}
if peerLabelChanged {
am.StoreEvent(ctx, userID, peer.ID, accountID, activity.PeerRenamed, peer.EventMeta(am.GetDNSDomain()))
}
if loginExpirationChanged {
event := activity.PeerLoginExpirationEnabled event := activity.PeerLoginExpirationEnabled
if !update.LoginExpirationEnabled { if !peer.LoginExpirationEnabled {
event = activity.PeerLoginExpirationDisabled event = activity.PeerLoginExpirationDisabled
} }
am.StoreEvent(ctx, userID, peer.IP.String(), accountID, event, peer.EventMeta(am.GetDNSDomain())) am.StoreEvent(ctx, userID, peer.IP.String(), accountID, event, peer.EventMeta(am.GetDNSDomain()))
if peer.AddedWithSSOLogin() && peer.LoginExpirationEnabled && account.Settings.PeerLoginExpirationEnabled { if peer.AddedWithSSOLogin() && peer.LoginExpirationEnabled && settings.PeerLoginExpirationEnabled {
am.checkAndSchedulePeerLoginExpiration(ctx, account) am.checkAndSchedulePeerLoginExpiration(ctx, accountID)
} }
} }
if peer.InactivityExpirationEnabled != update.InactivityExpirationEnabled { if inactivityExpirationChanged {
if !peer.AddedWithSSOLogin() {
return nil, status.Errorf(status.PreconditionFailed, "this peer hasn't been added with the SSO login, therefore the login expiration can't be updated")
}
peer.InactivityExpirationEnabled = update.InactivityExpirationEnabled
event := activity.PeerInactivityExpirationEnabled event := activity.PeerInactivityExpirationEnabled
if !update.InactivityExpirationEnabled { if !peer.InactivityExpirationEnabled {
event = activity.PeerInactivityExpirationDisabled event = activity.PeerInactivityExpirationDisabled
} }
am.StoreEvent(ctx, userID, peer.IP.String(), accountID, event, peer.EventMeta(am.GetDNSDomain())) am.StoreEvent(ctx, userID, peer.IP.String(), accountID, event, peer.EventMeta(am.GetDNSDomain()))
if peer.AddedWithSSOLogin() && peer.InactivityExpirationEnabled && account.Settings.PeerInactivityExpirationEnabled { if peer.AddedWithSSOLogin() && peer.InactivityExpirationEnabled && settings.PeerInactivityExpirationEnabled {
am.checkAndSchedulePeerInactivityExpiration(ctx, account) am.checkAndSchedulePeerInactivityExpiration(ctx, accountID)
} }
} }
account.UpdatePeer(peer) if peerLabelChanged || requiresPeerUpdates {
err = am.Store.SaveAccount(ctx, account)
if err != nil {
return nil, err
}
if peerLabelUpdated || requiresPeerUpdates {
am.updateAccountPeers(ctx, accountID) am.updateAccountPeers(ctx, accountID)
} }
return peer, nil return peer, nil
} }
// deletePeers will delete all specified peers and send updates to the remote peers. Don't call without acquiring account lock
func (am *DefaultAccountManager) deletePeers(ctx context.Context, account *Account, peerIDs []string, userID string) error {
// the first loop is needed to ensure all peers present under the account before modifying, otherwise
// we might have some inconsistencies
peers := make([]*nbpeer.Peer, 0, len(peerIDs))
for _, peerID := range peerIDs {
peer := account.GetPeer(peerID)
if peer == nil {
return status.Errorf(status.NotFound, "peer %s not found", peerID)
}
peers = append(peers, peer)
}
// the 2nd loop performs the actual modification
for _, peer := range peers {
err := am.integratedPeerValidator.PeerDeleted(ctx, account.Id, peer.ID)
if err != nil {
return err
}
account.DeletePeer(peer.ID)
am.peersUpdateManager.SendUpdate(ctx, peer.ID,
&UpdateMessage{
Update: &proto.SyncResponse{
// fill those field for backward compatibility
RemotePeers: []*proto.RemotePeerConfig{},
RemotePeersIsEmpty: true,
// new field
NetworkMap: &proto.NetworkMap{
Serial: account.Network.CurrentSerial(),
RemotePeers: []*proto.RemotePeerConfig{},
RemotePeersIsEmpty: true,
FirewallRules: []*proto.FirewallRule{},
FirewallRulesIsEmpty: true,
},
},
NetworkMap: &NetworkMap{},
})
am.peersUpdateManager.CloseChannel(ctx, peer.ID)
am.StoreEvent(ctx, userID, peer.ID, account.Id, activity.PeerRemovedByUser, peer.EventMeta(am.GetDNSDomain()))
}
return nil
}
// DeletePeer removes peer from the account by its IP // DeletePeer removes peer from the account by its IP
func (am *DefaultAccountManager) DeletePeer(ctx context.Context, accountID, peerID, userID string) error { func (am *DefaultAccountManager) DeletePeer(ctx context.Context, accountID, peerID, userID string) error {
unlock := am.Store.AcquireWriteLockByUID(ctx, accountID) unlock := am.Store.AcquireWriteLockByUID(ctx, accountID)
defer unlock() defer unlock()
account, err := am.Store.GetAccount(ctx, accountID) updateAccountPeers, err := am.isPeerInActiveGroup(ctx, accountID, peerID)
if err != nil { if err != nil {
return err return err
} }
updateAccountPeers, err := am.isPeerInActiveGroup(ctx, account, peerID) var peer *nbpeer.Peer
if err != nil { var addPeerRemovedEvents []func()
return err
}
err = am.deletePeers(ctx, account, []string{peerID}, userID) err = am.Store.ExecuteInTransaction(ctx, func(transaction Store) error {
if err != nil { peer, err = transaction.GetPeerByID(ctx, LockingStrengthUpdate, accountID, peerID)
return err if err != nil {
} return err
}
err = am.Store.SaveAccount(ctx, account) addPeerRemovedEvents, err = deletePeers(ctx, am, transaction, accountID, userID, []*nbpeer.Peer{peer})
if err != nil { if err != nil {
return err return err
}
return transaction.IncrementNetworkSerial(ctx, LockingStrengthUpdate, accountID)
})
for _, addPeerRemovedEvent := range addPeerRemovedEvents {
addPeerRemovedEvent()
} }
if updateAccountPeers { if updateAccountPeers {
@ -411,7 +402,7 @@ func (am *DefaultAccountManager) AddPeer(ctx context.Context, setupKey, userID s
addedByUser := false addedByUser := false
if len(userID) > 0 { if len(userID) > 0 {
addedByUser = true addedByUser = true
accountID, err = am.Store.GetAccountIDByUserID(userID) accountID, err = am.Store.GetAccountIDByUserID(ctx, LockingStrengthShare, userID)
} else { } else {
accountID, err = am.Store.GetAccountIDBySetupKey(ctx, encodedHashedKey) accountID, err = am.Store.GetAccountIDBySetupKey(ctx, encodedHashedKey)
} }
@ -442,12 +433,12 @@ func (am *DefaultAccountManager) AddPeer(ctx context.Context, setupKey, userID s
} }
var newPeer *nbpeer.Peer var newPeer *nbpeer.Peer
var groupsToAdd []string
err = am.Store.ExecuteInTransaction(ctx, func(transaction Store) error { err = am.Store.ExecuteInTransaction(ctx, func(transaction Store) error {
var setupKeyID string var setupKeyID string
var setupKeyName string var setupKeyName string
var ephemeral bool var ephemeral bool
var groupsToAdd []string
if addedByUser { if addedByUser {
user, err := transaction.GetUserByUserID(ctx, LockingStrengthUpdate, userID) user, err := transaction.GetUserByUserID(ctx, LockingStrengthUpdate, userID)
if err != nil { if err != nil {
@ -590,39 +581,16 @@ func (am *DefaultAccountManager) AddPeer(ctx context.Context, setupKey, userID s
unlock() unlock()
unlock = nil unlock = nil
account, err := am.requestBuffer.GetAccountWithBackpressure(ctx, accountID) updateAccountPeers, err := am.isPeerInActiveGroup(ctx, accountID, newPeer.ID)
if err != nil {
return nil, nil, nil, status.NewGetAccountError(err)
}
allGroup, err := account.GetGroupAll()
if err != nil {
return nil, nil, nil, fmt.Errorf("error getting all group ID: %w", err)
}
groupsToAdd = append(groupsToAdd, allGroup.ID)
newGroupsAffectsPeers, err := areGroupChangesAffectPeers(ctx, am.Store, accountID, groupsToAdd)
if err != nil { if err != nil {
return nil, nil, nil, err return nil, nil, nil, err
} }
if newGroupsAffectsPeers { if updateAccountPeers {
am.updateAccountPeers(ctx, accountID) am.updateAccountPeers(ctx, accountID)
} }
approvedPeersMap, err := am.GetValidatedPeers(account) return am.getValidatedPeerWithMap(ctx, false, accountID, newPeer)
if err != nil {
return nil, nil, nil, err
}
postureChecks, err := am.getPeerPostureChecks(ctx, account.Id, newPeer.ID)
if err != nil {
return nil, nil, nil, err
}
customZone := account.GetPeersCustomZone(ctx, am.dnsDomain)
networkMap := account.GetPeerNetworkMap(ctx, newPeer.ID, customZone, approvedPeersMap, am.metrics.AccountManagerMetrics())
return newPeer, networkMap, postureChecks, nil
} }
func (am *DefaultAccountManager) getFreeIP(ctx context.Context, store Store, accountID string) (net.IP, error) { func (am *DefaultAccountManager) getFreeIP(ctx context.Context, store Store, accountID string) (net.IP, error) {
@ -645,16 +613,16 @@ func (am *DefaultAccountManager) getFreeIP(ctx context.Context, store Store, acc
} }
// SyncPeer checks whether peer is eligible for receiving NetworkMap (authenticated) and returns its NetworkMap if eligible // SyncPeer checks whether peer is eligible for receiving NetworkMap (authenticated) and returns its NetworkMap if eligible
func (am *DefaultAccountManager) SyncPeer(ctx context.Context, sync PeerSync, account *Account) (*nbpeer.Peer, *NetworkMap, []*posture.Checks, error) { func (am *DefaultAccountManager) SyncPeer(ctx context.Context, sync PeerSync, accountID string) (*nbpeer.Peer, *NetworkMap, []*posture.Checks, error) {
peer, err := account.FindPeerByPubKey(sync.WireGuardPubKey) peer, err := am.Store.GetPeerByPeerPubKey(ctx, LockingStrengthShare, sync.WireGuardPubKey)
if err != nil { if err != nil {
return nil, nil, nil, status.NewPeerNotRegisteredError() return nil, nil, nil, status.NewPeerNotRegisteredError()
} }
if peer.UserID != "" { if peer.UserID != "" {
user, err := account.FindUser(peer.UserID) user, err := am.Store.GetUserByUserID(ctx, LockingStrengthShare, peer.UserID)
if err != nil { if err != nil {
return nil, nil, nil, fmt.Errorf("failed to get user: %w", err) return nil, nil, nil, err
} }
err = checkIfPeerOwnerIsBlocked(peer, user) err = checkIfPeerOwnerIsBlocked(peer, user)
@ -663,52 +631,38 @@ func (am *DefaultAccountManager) SyncPeer(ctx context.Context, sync PeerSync, ac
} }
} }
if peerLoginExpired(ctx, peer, account.Settings) { settings, err := am.Store.GetAccountSettings(ctx, LockingStrengthShare, accountID)
return nil, nil, nil, status.NewPeerLoginExpiredError()
}
updated := peer.UpdateMetaIfNew(sync.Meta)
if updated {
err = am.Store.SavePeer(ctx, account.Id, peer)
if err != nil {
return nil, nil, nil, fmt.Errorf("failed to save peer: %w", err)
}
if sync.UpdateAccountPeers {
am.updateAccountPeers(ctx, account.Id)
}
}
peerNotValid, isStatusChanged, err := am.integratedPeerValidator.IsNotValidPeer(ctx, account.Id, peer, account.GetPeerGroupsList(peer.ID), account.Settings.Extra)
if err != nil {
return nil, nil, nil, fmt.Errorf("failed to validate peer: %w", err)
}
var postureChecks []*posture.Checks
if peerNotValid {
emptyMap := &NetworkMap{
Network: account.Network.Copy(),
}
return peer, emptyMap, postureChecks, nil
}
if isStatusChanged {
am.updateAccountPeers(ctx, account.Id)
}
validPeersMap, err := am.GetValidatedPeers(account)
if err != nil {
return nil, nil, nil, fmt.Errorf("failed to get validated peers: %w", err)
}
postureChecks, err = am.getPeerPostureChecks(ctx, account.Id, peer.ID)
if err != nil { if err != nil {
return nil, nil, nil, err return nil, nil, nil, err
} }
customZone := account.GetPeersCustomZone(ctx, am.dnsDomain) if peerLoginExpired(ctx, peer, settings) {
return peer, account.GetPeerNetworkMap(ctx, peer.ID, customZone, validPeersMap, am.metrics.AccountManagerMetrics()), postureChecks, nil return nil, nil, nil, status.NewPeerLoginExpiredError()
}
peerGroupList, err := am.getPeerGroupIDs(ctx, accountID, peer.ID)
if err != nil {
return nil, nil, nil, err
}
peerNotValid, isStatusChanged, err := am.integratedPeerValidator.IsNotValidPeer(ctx, accountID, peer, peerGroupList, settings.Extra)
if err != nil {
return nil, nil, nil, err
}
updated := peer.UpdateMetaIfNew(sync.Meta)
if updated {
err = am.Store.SavePeer(ctx, LockingStrengthUpdate, accountID, peer)
if err != nil {
return nil, nil, nil, err
}
}
if isStatusChanged || (updated && sync.UpdateAccountPeers) {
am.updateAccountPeers(ctx, accountID)
}
return am.getValidatedPeerWithMap(ctx, peerNotValid, accountID, peer)
} }
// LoginPeer logs in or registers a peer. // LoginPeer logs in or registers a peer.
@ -814,7 +768,7 @@ func (am *DefaultAccountManager) LoginPeer(ctx context.Context, login PeerLogin)
} }
if shouldStorePeer { if shouldStorePeer {
err = am.Store.SavePeer(ctx, accountID, peer) err = am.Store.SavePeer(ctx, LockingStrengthUpdate, accountID, peer)
if err != nil { if err != nil {
return nil, nil, nil, err return nil, nil, nil, err
} }
@ -823,16 +777,11 @@ func (am *DefaultAccountManager) LoginPeer(ctx context.Context, login PeerLogin)
unlockPeer() unlockPeer()
unlockPeer = nil unlockPeer = nil
account, err := am.requestBuffer.GetAccountWithBackpressure(ctx, accountID)
if err != nil {
return nil, nil, nil, err
}
if updateRemotePeers || isStatusChanged { if updateRemotePeers || isStatusChanged {
am.updateAccountPeers(ctx, accountID) am.updateAccountPeers(ctx, accountID)
} }
return am.getValidatedPeerWithMap(ctx, isRequiresApproval, account, peer) return am.getValidatedPeerWithMap(ctx, isRequiresApproval, accountID, peer)
} }
// checkIFPeerNeedsLoginWithoutLock checks if the peer needs login without acquiring the account lock. The check validate if the peer was not added via SSO // checkIFPeerNeedsLoginWithoutLock checks if the peer needs login without acquiring the account lock. The check validate if the peer was not added via SSO
@ -864,22 +813,30 @@ func (am *DefaultAccountManager) checkIFPeerNeedsLoginWithoutLock(ctx context.Co
return nil return nil
} }
func (am *DefaultAccountManager) getValidatedPeerWithMap(ctx context.Context, isRequiresApproval bool, account *Account, peer *nbpeer.Peer) (*nbpeer.Peer, *NetworkMap, []*posture.Checks, error) { func (am *DefaultAccountManager) getValidatedPeerWithMap(ctx context.Context, isRequiresApproval bool, accountID string, peer *nbpeer.Peer) (*nbpeer.Peer, *NetworkMap, []*posture.Checks, error) {
var postureChecks []*posture.Checks
if isRequiresApproval { if isRequiresApproval {
network, err := am.Store.GetAccountNetwork(ctx, LockingStrengthShare, accountID)
if err != nil {
return nil, nil, nil, err
}
emptyMap := &NetworkMap{ emptyMap := &NetworkMap{
Network: account.Network.Copy(), Network: network.Copy(),
} }
return peer, emptyMap, nil, nil return peer, emptyMap, nil, nil
} }
approvedPeersMap, err := am.GetValidatedPeers(account) account, err := am.requestBuffer.GetAccountWithBackpressure(ctx, accountID)
if err != nil { if err != nil {
return nil, nil, nil, err return nil, nil, nil, err
} }
postureChecks, err = am.getPeerPostureChecks(ctx, account.Id, peer.ID) approvedPeersMap, err := am.GetValidatedPeers(ctx, account.Id)
if err != nil {
return nil, nil, nil, err
}
postureChecks, err := am.getPeerPostureChecks(ctx, account.Id, peer.ID)
if err != nil { if err != nil {
return nil, nil, nil, err return nil, nil, nil, err
} }
@ -896,7 +853,7 @@ func (am *DefaultAccountManager) handleExpiredPeer(ctx context.Context, user *Us
// If peer was expired before and if it reached this point, it is re-authenticated. // If peer was expired before and if it reached this point, it is re-authenticated.
// UserID is present, meaning that JWT validation passed successfully in the API layer. // UserID is present, meaning that JWT validation passed successfully in the API layer.
peer = peer.UpdateLastLogin() peer = peer.UpdateLastLogin()
err = am.Store.SavePeer(ctx, peer.AccountID, peer) err = am.Store.SavePeer(ctx, LockingStrengthUpdate, peer.AccountID, peer)
if err != nil { if err != nil {
return err return err
} }
@ -943,41 +900,47 @@ func peerLoginExpired(ctx context.Context, peer *nbpeer.Peer, settings *Settings
// GetPeer for a given accountID, peerID and userID error if not found. // GetPeer for a given accountID, peerID and userID error if not found.
func (am *DefaultAccountManager) GetPeer(ctx context.Context, accountID, peerID, userID string) (*nbpeer.Peer, error) { func (am *DefaultAccountManager) GetPeer(ctx context.Context, accountID, peerID, userID string) (*nbpeer.Peer, error) {
unlock := am.Store.AcquireWriteLockByUID(ctx, accountID) user, err := am.Store.GetUserByUserID(ctx, LockingStrengthShare, userID)
defer unlock()
account, err := am.Store.GetAccount(ctx, accountID)
if err != nil { if err != nil {
return nil, err return nil, err
} }
user, err := account.FindUser(userID) if user.AccountID != accountID {
return nil, status.NewUserNotPartOfAccountError()
}
settings, err := am.Store.GetAccountSettings(ctx, LockingStrengthShare, accountID)
if err != nil { if err != nil {
return nil, err return nil, err
} }
if !user.HasAdminPower() && !user.IsServiceUser && account.Settings.RegularUsersViewBlocked { if user.IsRegularUser() && settings.RegularUsersViewBlocked {
return nil, status.Errorf(status.Internal, "user %s has no access to his own peer %s under account %s", userID, peerID, accountID) return nil, status.Errorf(status.Internal, "user %s has no access to his own peer %s under account %s", userID, peerID, accountID)
} }
peer := account.GetPeer(peerID) peer, err := am.Store.GetPeerByID(ctx, LockingStrengthShare, accountID, peerID)
if peer == nil { if err != nil {
return nil, status.Errorf(status.NotFound, "peer with %s not found under account %s", peerID, accountID) return nil, err
} }
// if admin or user owns this peer, return peer // if admin or user owns this peer, return peer
if user.HasAdminPower() || user.IsServiceUser || peer.UserID == userID { if user.IsAdminOrServiceUser() || peer.UserID == userID {
return peer, nil return peer, nil
} }
// it is also possible that user doesn't own the peer but some of his peers have access to it, // it is also possible that user doesn't own the peer but some of his peers have access to it,
// this is a valid case, show the peer as well. // this is a valid case, show the peer as well.
userPeers, err := account.FindUserPeers(userID) userPeers, err := am.Store.GetUserPeers(ctx, LockingStrengthShare, accountID, userID)
if err != nil { if err != nil {
return nil, err return nil, err
} }
approvedPeersMap, err := am.GetValidatedPeers(account) approvedPeersMap, err := am.GetValidatedPeers(ctx, accountID)
if err != nil {
return nil, err
}
account, err := am.requestBuffer.GetAccountWithBackpressure(ctx, accountID)
if err != nil { if err != nil {
return nil, err return nil, err
} }
@ -1006,12 +969,13 @@ func (am *DefaultAccountManager) updateAccountPeers(ctx context.Context, account
account, err := am.requestBuffer.GetAccountWithBackpressure(ctx, accountID) account, err := am.requestBuffer.GetAccountWithBackpressure(ctx, accountID)
if err != nil { if err != nil {
log.WithContext(ctx).Errorf("failed to send out updates to peers: %v", err) log.WithContext(ctx).Errorf("failed to send out updates to peers. failed to get account: %v", err)
return return
} }
peers := account.GetPeers() peers := account.GetPeers()
approvedPeersMap, err := am.GetValidatedPeers(account) approvedPeersMap, err := am.GetValidatedPeers(ctx, account.Id)
if err != nil { if err != nil {
log.WithContext(ctx).Errorf("failed to send out updates to peers, failed to validate peer: %v", err) log.WithContext(ctx).Errorf("failed to send out updates to peers, failed to validate peer: %v", err)
return return
@ -1037,7 +1001,7 @@ func (am *DefaultAccountManager) updateAccountPeers(ctx context.Context, account
postureChecks, err := am.getPeerPostureChecks(ctx, account.Id, p.ID) postureChecks, err := am.getPeerPostureChecks(ctx, account.Id, p.ID)
if err != nil { if err != nil {
log.WithContext(ctx).Errorf("failed to send out updates to peers, failed to get peer: %s posture checks: %v", p.ID, err) log.WithContext(ctx).Debugf("failed to get posture checks for peer %s: %v", peer.ID, err)
return return
} }
@ -1050,6 +1014,236 @@ func (am *DefaultAccountManager) updateAccountPeers(ctx context.Context, account
wg.Wait() wg.Wait()
} }
// getNextPeerExpiration returns the minimum duration in which the next peer of the account will expire if it was found.
// If there is no peer that expires this function returns false and a duration of 0.
// This function only considers peers that haven't been expired yet and that are connected.
func (am *DefaultAccountManager) getNextPeerExpiration(ctx context.Context, accountID string) (time.Duration, bool) {
peersWithExpiry, err := am.Store.GetAccountPeersWithExpiration(ctx, LockingStrengthShare, accountID)
if err != nil {
log.WithContext(ctx).Errorf("failed to get peers with expiration: %v", err)
return 0, false
}
if len(peersWithExpiry) == 0 {
return 0, false
}
settings, err := am.Store.GetAccountSettings(ctx, LockingStrengthShare, accountID)
if err != nil {
log.WithContext(ctx).Errorf("failed to get account settings: %v", err)
return 0, false
}
var nextExpiry *time.Duration
for _, peer := range peersWithExpiry {
// consider only connected peers because others will require login on connecting to the management server
if peer.Status.LoginExpired || !peer.Status.Connected {
continue
}
_, duration := peer.LoginExpired(settings.PeerLoginExpiration)
if nextExpiry == nil || duration < *nextExpiry {
// if expiration is below 1s return 1s duration
// this avoids issues with ticker that can't be set to < 0
if duration < time.Second {
return time.Second, true
}
nextExpiry = &duration
}
}
if nextExpiry == nil {
return 0, false
}
return *nextExpiry, true
}
// GetNextInactivePeerExpiration returns the minimum duration in which the next peer of the account will expire if it was found.
// If there is no peer that expires this function returns false and a duration of 0.
// This function only considers peers that haven't been expired yet and that are not connected.
func (am *DefaultAccountManager) getNextInactivePeerExpiration(ctx context.Context, accountID string) (time.Duration, bool) {
peersWithInactivity, err := am.Store.GetAccountPeersWithInactivity(ctx, LockingStrengthShare, accountID)
if err != nil {
log.WithContext(ctx).Errorf("failed to get peers with inactivity: %v", err)
return 0, false
}
if len(peersWithInactivity) == 0 {
return 0, false
}
settings, err := am.Store.GetAccountSettings(ctx, LockingStrengthShare, accountID)
if err != nil {
log.WithContext(ctx).Errorf("failed to get account settings: %v", err)
return 0, false
}
var nextExpiry *time.Duration
for _, peer := range peersWithInactivity {
if peer.Status.LoginExpired || peer.Status.Connected {
continue
}
_, duration := peer.SessionExpired(settings.PeerInactivityExpiration)
if nextExpiry == nil || duration < *nextExpiry {
// if expiration is below 1s return 1s duration
// this avoids issues with ticker that can't be set to < 0
if duration < time.Second {
return time.Second, true
}
nextExpiry = &duration
}
}
if nextExpiry == nil {
return 0, false
}
return *nextExpiry, true
}
// getExpiredPeers returns peers that have been expired.
func (am *DefaultAccountManager) getExpiredPeers(ctx context.Context, accountID string) ([]*nbpeer.Peer, error) {
peersWithExpiry, err := am.Store.GetAccountPeersWithExpiration(ctx, LockingStrengthShare, accountID)
if err != nil {
return nil, err
}
settings, err := am.Store.GetAccountSettings(ctx, LockingStrengthShare, accountID)
if err != nil {
return nil, err
}
var peers []*nbpeer.Peer
for _, peer := range peersWithExpiry {
expired, _ := peer.LoginExpired(settings.PeerLoginExpiration)
if expired {
peers = append(peers, peer)
}
}
return peers, nil
}
// getInactivePeers returns peers that have been expired by inactivity
func (am *DefaultAccountManager) getInactivePeers(ctx context.Context, accountID string) ([]*nbpeer.Peer, error) {
peersWithInactivity, err := am.Store.GetAccountPeersWithInactivity(ctx, LockingStrengthShare, accountID)
if err != nil {
return nil, err
}
settings, err := am.Store.GetAccountSettings(ctx, LockingStrengthShare, accountID)
if err != nil {
return nil, err
}
var peers []*nbpeer.Peer
for _, inactivePeer := range peersWithInactivity {
inactive, _ := inactivePeer.SessionExpired(settings.PeerInactivityExpiration)
if inactive {
peers = append(peers, inactivePeer)
}
}
return peers, nil
}
// GetPeerGroups returns groups that the peer is part of.
func (am *DefaultAccountManager) GetPeerGroups(ctx context.Context, accountID, peerID string) ([]*nbgroup.Group, error) {
groups, err := am.Store.GetAccountGroups(ctx, LockingStrengthShare, accountID)
if err != nil {
return nil, err
}
peerGroups := make([]*nbgroup.Group, 0)
for _, group := range groups {
if slices.Contains(group.Peers, peerID) {
peerGroups = append(peerGroups, group)
}
}
return peerGroups, nil
}
// getPeerGroupIDs returns the IDs of the groups that the peer is part of.
func (am *DefaultAccountManager) getPeerGroupIDs(ctx context.Context, accountID string, peerID string) ([]string, error) {
groups, err := am.GetPeerGroups(ctx, accountID, peerID)
if err != nil {
return nil, err
}
groupIDs := make([]string, 0, len(groups))
for _, group := range groups {
groupIDs = append(groupIDs, group.ID)
}
return groupIDs, err
}
func (am *DefaultAccountManager) getPeerDNSLabels(ctx context.Context, accountID string) (lookupMap, error) {
dnsLabels, err := am.Store.GetAccountPeerDNSLabels(ctx, LockingStrengthShare, accountID)
if err != nil {
return nil, err
}
existingLabels := make(lookupMap)
for _, label := range dnsLabels {
existingLabels[label] = struct{}{}
}
return existingLabels, nil
}
// IsPeerInActiveGroup checks if the given peer is part of a group that is used
// in an active DNS, route, or ACL configuration.
func (am *DefaultAccountManager) isPeerInActiveGroup(ctx context.Context, accountID, peerID string) (bool, error) {
peerGroupIDs, err := am.getPeerGroupIDs(ctx, accountID, peerID)
if err != nil {
return false, err
}
return areGroupChangesAffectPeers(ctx, am.Store, accountID, peerGroupIDs) // TODO: use transaction
}
// deletePeers deletes all specified peers and sends updates to the remote peers.
// Returns a slice of functions to save events after successful peer deletion.
func deletePeers(ctx context.Context, am *DefaultAccountManager, transaction Store, accountID, userID string, peers []*nbpeer.Peer) ([]func(), error) {
var peerDeletedEvents []func()
for _, peer := range peers {
if err := am.integratedPeerValidator.PeerDeleted(ctx, accountID, peer.ID); err != nil {
return nil, err
}
network, err := transaction.GetAccountNetwork(ctx, LockingStrengthShare, accountID)
if err != nil {
return nil, err
}
if err = transaction.DeletePeer(ctx, LockingStrengthUpdate, accountID, peer.ID); err != nil {
return nil, err
}
am.peersUpdateManager.SendUpdate(ctx, peer.ID, &UpdateMessage{
Update: &proto.SyncResponse{
RemotePeers: []*proto.RemotePeerConfig{},
RemotePeersIsEmpty: true,
NetworkMap: &proto.NetworkMap{
Serial: network.CurrentSerial(),
RemotePeers: []*proto.RemotePeerConfig{},
RemotePeersIsEmpty: true,
FirewallRules: []*proto.FirewallRule{},
FirewallRulesIsEmpty: true,
},
},
NetworkMap: &NetworkMap{},
})
am.peersUpdateManager.CloseChannel(ctx, peer.ID)
peerDeletedEvents = append(peerDeletedEvents, func() {
am.StoreEvent(ctx, userID, peer.ID, accountID, activity.PeerRemovedByUser, peer.EventMeta(am.GetDNSDomain()))
})
}
return peerDeletedEvents, nil
}
func ConvertSliceToMap(existingLabels []string) map[string]struct{} { func ConvertSliceToMap(existingLabels []string) map[string]struct{} {
labelMap := make(map[string]struct{}, len(existingLabels)) labelMap := make(map[string]struct{}, len(existingLabels))
for _, label := range existingLabels { for _, label := range existingLabels {
@ -1057,15 +1251,3 @@ func ConvertSliceToMap(existingLabels []string) map[string]struct{} {
} }
return labelMap return labelMap
} }
// IsPeerInActiveGroup checks if the given peer is part of a group that is used
// in an active DNS, route, or ACL configuration.
func (am *DefaultAccountManager) isPeerInActiveGroup(ctx context.Context, account *Account, peerID string) (bool, error) {
peerGroupIDs := make([]string, 0)
for _, group := range account.Groups {
if slices.Contains(group.Peers, peerID) {
peerGroupIDs = append(peerGroupIDs, group.ID)
}
}
return areGroupChangesAffectPeers(ctx, am.Store, account.Id, peerGroupIDs)
}

View File

@ -44,7 +44,7 @@ type Peer struct {
// CreatedAt records the time the peer was created // CreatedAt records the time the peer was created
CreatedAt time.Time CreatedAt time.Time
// Indicate ephemeral peer attribute // Indicate ephemeral peer attribute
Ephemeral bool Ephemeral bool `gorm:"index"`
// Geo location based on connection IP // Geo location based on connection IP
Location Location `gorm:"embedded;embeddedPrefix:location_"` Location Location `gorm:"embedded;embeddedPrefix:location_"`
} }

View File

@ -487,6 +487,10 @@ func (am *DefaultAccountManager) deleteRegularUser(ctx context.Context, account
} }
delete(account.Users, targetUserID) delete(account.Users, targetUserID)
if updateAccountPeers {
account.Network.IncSerial()
}
err = am.Store.SaveAccount(ctx, account) err = am.Store.SaveAccount(ctx, account)
if err != nil { if err != nil {
return err return err
@ -511,12 +515,16 @@ func (am *DefaultAccountManager) deleteUserPeers(ctx context.Context, initiatorU
return false, nil return false, nil
} }
peerIDs := make([]string, 0, len(peers)) eventsToStore, err := deletePeers(ctx, am, am.Store, account.Id, initiatorUserID, peers)
for _, peer := range peers { if err != nil {
peerIDs = append(peerIDs, peer.ID) return false, err
} }
return hadPeers, am.deletePeers(ctx, account, peerIDs, initiatorUserID) for _, storeEvent := range eventsToStore {
storeEvent()
}
return hadPeers, nil
} }
// InviteUser resend invitations to users who haven't activated their accounts prior to the expiration period. // InviteUser resend invitations to users who haven't activated their accounts prior to the expiration period.
@ -823,7 +831,7 @@ func (am *DefaultAccountManager) SaveOrAddUsers(ctx context.Context, accountID,
} }
if len(expiredPeers) > 0 { if len(expiredPeers) > 0 {
if err := am.expireAndUpdatePeers(ctx, account, expiredPeers); err != nil { if err := am.expireAndUpdatePeers(ctx, account.Id, expiredPeers); err != nil {
log.WithContext(ctx).Errorf("failed update expired peers: %s", err) log.WithContext(ctx).Errorf("failed update expired peers: %s", err)
return nil, err return nil, err
} }
@ -1104,7 +1112,7 @@ func (am *DefaultAccountManager) GetUsersFromAccount(ctx context.Context, accoun
} }
// expireAndUpdatePeers expires all peers of the given user and updates them in the account // expireAndUpdatePeers expires all peers of the given user and updates them in the account
func (am *DefaultAccountManager) expireAndUpdatePeers(ctx context.Context, account *Account, peers []*nbpeer.Peer) error { func (am *DefaultAccountManager) expireAndUpdatePeers(ctx context.Context, accountID string, peers []*nbpeer.Peer) error {
var peerIDs []string var peerIDs []string
for _, peer := range peers { for _, peer := range peers {
// nolint:staticcheck // nolint:staticcheck
@ -1115,16 +1123,13 @@ func (am *DefaultAccountManager) expireAndUpdatePeers(ctx context.Context, accou
} }
peerIDs = append(peerIDs, peer.ID) peerIDs = append(peerIDs, peer.ID)
peer.MarkLoginExpired(true) peer.MarkLoginExpired(true)
account.UpdatePeer(peer)
if err := am.Store.SavePeerStatus(account.Id, peer.ID, *peer.Status); err != nil { if err := am.Store.SavePeerStatus(ctx, LockingStrengthUpdate, accountID, peer.ID, *peer.Status); err != nil {
return fmt.Errorf("failed saving peer status for peer %s: %s", peer.ID, err) return err
} }
log.WithContext(ctx).Tracef("mark peer %s login expired", peer.ID)
am.StoreEvent( am.StoreEvent(
ctx, ctx,
peer.UserID, peer.ID, account.Id, peer.UserID, peer.ID, accountID,
activity.PeerLoginExpired, peer.EventMeta(am.GetDNSDomain()), activity.PeerLoginExpired, peer.EventMeta(am.GetDNSDomain()),
) )
} }
@ -1132,7 +1137,7 @@ func (am *DefaultAccountManager) expireAndUpdatePeers(ctx context.Context, accou
if len(peerIDs) != 0 { if len(peerIDs) != 0 {
// this will trigger peer disconnect from the management service // this will trigger peer disconnect from the management service
am.peersUpdateManager.CloseChannels(ctx, peerIDs) am.peersUpdateManager.CloseChannels(ctx, peerIDs)
am.updateAccountPeers(ctx, account.Id) am.updateAccountPeers(ctx, accountID)
} }
return nil return nil
} }
@ -1234,6 +1239,9 @@ func (am *DefaultAccountManager) DeleteRegularUsers(ctx context.Context, account
deletedUsersMeta[targetUserID] = meta deletedUsersMeta[targetUserID] = meta
} }
if updateAccountPeers {
account.Network.IncSerial()
}
err = am.Store.SaveAccount(ctx, account) err = am.Store.SaveAccount(ctx, account)
if err != nil { if err != nil {
return fmt.Errorf("failed to delete users: %w", err) return fmt.Errorf("failed to delete users: %w", err)