mirror of
https://github.com/netbirdio/netbird.git
synced 2025-03-03 01:11:14 +01:00
run peer ops in transaction
Signed-off-by: bcmmbaga <bethuelmbaga12@gmail.com>
This commit is contained in:
parent
f6f7260897
commit
a61e9da3e9
@ -2390,8 +2390,8 @@ func (am *DefaultAccountManager) GetAccountIDForPeerKey(ctx context.Context, pee
|
||||
return am.Store.GetAccountIDByPeerPubKey(ctx, peerKey)
|
||||
}
|
||||
|
||||
func (am *DefaultAccountManager) handleUserPeer(ctx context.Context, peer *nbpeer.Peer, settings *Settings) (bool, error) {
|
||||
user, err := am.Store.GetUserByUserID(ctx, LockingStrengthShare, peer.UserID)
|
||||
func (am *DefaultAccountManager) handleUserPeer(ctx context.Context, transaction Store, peer *nbpeer.Peer, settings *Settings) (bool, error) {
|
||||
user, err := transaction.GetUserByUserID(ctx, LockingStrengthShare, peer.UserID)
|
||||
if err != nil {
|
||||
return false, err
|
||||
}
|
||||
@ -2402,7 +2402,7 @@ func (am *DefaultAccountManager) handleUserPeer(ctx context.Context, peer *nbpee
|
||||
}
|
||||
|
||||
if peerLoginExpired(ctx, peer, settings) {
|
||||
err = am.handleExpiredPeer(ctx, user, peer)
|
||||
err = am.handleExpiredPeer(ctx, transaction, user, peer)
|
||||
if err != nil {
|
||||
return false, err
|
||||
}
|
||||
|
@ -11,6 +11,7 @@ import (
|
||||
"sync"
|
||||
"time"
|
||||
|
||||
"github.com/netbirdio/netbird/management/server/geolocation"
|
||||
nbgroup "github.com/netbirdio/netbird/management/server/group"
|
||||
"github.com/rs/xid"
|
||||
log "github.com/sirupsen/logrus"
|
||||
@ -117,17 +118,25 @@ func (am *DefaultAccountManager) GetPeers(ctx context.Context, accountID, userID
|
||||
|
||||
// MarkPeerConnected marks peer as connected (true) or disconnected (false)
|
||||
func (am *DefaultAccountManager) MarkPeerConnected(ctx context.Context, peerPubKey string, connected bool, realIP net.IP, accountID string) error {
|
||||
peer, err := am.Store.GetPeerByPeerPubKey(ctx, LockingStrengthShare, peerPubKey)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
var peer *nbpeer.Peer
|
||||
var settings *Settings
|
||||
var expired bool
|
||||
var err error
|
||||
|
||||
settings, err := am.Store.GetAccountSettings(ctx, LockingStrengthShare, accountID)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
err = am.Store.ExecuteInTransaction(ctx, func(transaction Store) error {
|
||||
peer, err = transaction.GetPeerByPeerPubKey(ctx, LockingStrengthUpdate, peerPubKey)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
expired, err := am.updatePeerStatusAndLocation(ctx, peer, connected, realIP, accountID)
|
||||
settings, err = transaction.GetAccountSettings(ctx, LockingStrengthShare, accountID)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
expired, err = updatePeerStatusAndLocation(ctx, am.geo, transaction, peer, connected, realIP, accountID)
|
||||
return err
|
||||
})
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
@ -151,7 +160,7 @@ func (am *DefaultAccountManager) MarkPeerConnected(ctx context.Context, peerPubK
|
||||
return nil
|
||||
}
|
||||
|
||||
func (am *DefaultAccountManager) updatePeerStatusAndLocation(ctx context.Context, peer *nbpeer.Peer, connected bool, realIP net.IP, accountID string) (bool, error) {
|
||||
func updatePeerStatusAndLocation(ctx context.Context, geo *geolocation.Geolocation, transaction Store, peer *nbpeer.Peer, connected bool, realIP net.IP, accountID string) (bool, error) {
|
||||
oldStatus := peer.Status.Copy()
|
||||
newStatus := oldStatus
|
||||
newStatus.LastSeen = time.Now().UTC()
|
||||
@ -162,8 +171,8 @@ func (am *DefaultAccountManager) updatePeerStatusAndLocation(ctx context.Context
|
||||
}
|
||||
peer.Status = newStatus
|
||||
|
||||
if am.geo != nil && realIP != nil {
|
||||
location, err := am.geo.Lookup(realIP)
|
||||
if geo != nil && realIP != nil {
|
||||
location, err := geo.Lookup(realIP)
|
||||
if err != nil {
|
||||
log.WithContext(ctx).Warnf("failed to get location for peer %s realip: [%s]: %v", peer.ID, realIP.String(), err)
|
||||
} else {
|
||||
@ -171,14 +180,14 @@ func (am *DefaultAccountManager) updatePeerStatusAndLocation(ctx context.Context
|
||||
peer.Location.CountryCode = location.Country.ISOCode
|
||||
peer.Location.CityName = location.City.Names.En
|
||||
peer.Location.GeoNameID = location.City.GeonameID
|
||||
err = am.Store.SavePeerLocation(ctx, LockingStrengthUpdate, accountID, peer)
|
||||
err = transaction.SavePeerLocation(ctx, LockingStrengthUpdate, accountID, peer)
|
||||
if err != nil {
|
||||
log.WithContext(ctx).Warnf("could not store location for peer %s: %s", peer.ID, err)
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
err := am.Store.SavePeerStatus(ctx, LockingStrengthUpdate, accountID, peer.ID, *newStatus)
|
||||
err := transaction.SavePeerStatus(ctx, LockingStrengthUpdate, accountID, peer.ID, *newStatus)
|
||||
if err != nil {
|
||||
return false, err
|
||||
}
|
||||
@ -200,23 +209,49 @@ func (am *DefaultAccountManager) UpdatePeer(ctx context.Context, accountID, user
|
||||
return nil, status.NewUserNotPartOfAccountError()
|
||||
}
|
||||
|
||||
peer, err := am.Store.GetPeerByID(ctx, LockingStrengthShare, accountID, update.ID)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
settings, err := am.Store.GetAccountSettings(ctx, LockingStrengthShare, accountID)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
peerGroupList, err := am.getPeerGroupIDs(ctx, accountID, update.ID)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
var peer *nbpeer.Peer
|
||||
var settings *Settings
|
||||
var peerGroupList []string
|
||||
var requiresPeerUpdates bool
|
||||
update, requiresPeerUpdates, err = am.integratedPeerValidator.ValidatePeer(ctx, update, peer, userID, accountID, am.GetDNSDomain(), peerGroupList, settings.Extra)
|
||||
var newLabel string
|
||||
|
||||
err = am.Store.ExecuteInTransaction(ctx, func(transaction Store) error {
|
||||
peer, err = transaction.GetPeerByID(ctx, LockingStrengthUpdate, accountID, update.ID)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
settings, err = transaction.GetAccountSettings(ctx, LockingStrengthShare, accountID)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
peerGroupList, err = getPeerGroupIDs(ctx, am.Store, accountID, update.ID)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
update, requiresPeerUpdates, err = am.integratedPeerValidator.ValidatePeer(ctx, update, peer, userID, accountID, am.GetDNSDomain(), peerGroupList, settings.Extra)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
if peer.Name != update.Name {
|
||||
existingLabels, err := getPeerDNSLabels(ctx, transaction, accountID)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
newLabel, err = getPeerHostLabel(update.Name, existingLabels)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
peer.DNSLabel = newLabel
|
||||
}
|
||||
|
||||
return transaction.SavePeer(ctx, LockingStrengthUpdate, accountID, peer)
|
||||
})
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
@ -231,18 +266,6 @@ func (am *DefaultAccountManager) UpdatePeer(ctx context.Context, accountID, user
|
||||
if peer.Name != update.Name {
|
||||
peer.Name = update.Name
|
||||
peerLabelChanged = true
|
||||
|
||||
existingLabels, err := am.getPeerDNSLabels(ctx, accountID)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
newLabel, err := getPeerHostLabel(peer.Name, existingLabels)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
peer.DNSLabel = newLabel
|
||||
}
|
||||
|
||||
if peer.LoginExpirationEnabled != update.LoginExpirationEnabled {
|
||||
@ -261,10 +284,6 @@ func (am *DefaultAccountManager) UpdatePeer(ctx context.Context, accountID, user
|
||||
inactivityExpirationChanged = true
|
||||
}
|
||||
|
||||
if err = am.Store.SavePeer(ctx, LockingStrengthUpdate, accountID, peer); err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
if sshChanged {
|
||||
event := activity.PeerSSHEnabled
|
||||
if !peer.SSHEnabled {
|
||||
@ -313,13 +332,18 @@ func (am *DefaultAccountManager) DeletePeer(ctx context.Context, accountID, peer
|
||||
unlock := am.Store.AcquireWriteLockByUID(ctx, accountID)
|
||||
defer unlock()
|
||||
|
||||
updateAccountPeers, err := am.isPeerInActiveGroup(ctx, accountID, peerID)
|
||||
peerAccountID, err := am.Store.GetAccountIDByPeerID(ctx, LockingStrengthShare, peerID)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
if peerAccountID != accountID {
|
||||
return status.NewPeerNotPartOfAccountError()
|
||||
}
|
||||
|
||||
var peer *nbpeer.Peer
|
||||
var addPeerRemovedEvents []func()
|
||||
var updateAccountPeers bool
|
||||
var eventsToStore []func()
|
||||
|
||||
err = am.Store.ExecuteInTransaction(ctx, func(transaction Store) error {
|
||||
peer, err = transaction.GetPeerByID(ctx, LockingStrengthUpdate, accountID, peerID)
|
||||
@ -327,16 +351,21 @@ func (am *DefaultAccountManager) DeletePeer(ctx context.Context, accountID, peer
|
||||
return err
|
||||
}
|
||||
|
||||
addPeerRemovedEvents, err = deletePeers(ctx, am, transaction, accountID, userID, []*nbpeer.Peer{peer})
|
||||
updateAccountPeers, err = isPeerInActiveGroup(ctx, transaction, accountID, peerID)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
return transaction.IncrementNetworkSerial(ctx, LockingStrengthUpdate, accountID)
|
||||
if err = transaction.IncrementNetworkSerial(ctx, LockingStrengthUpdate, accountID); err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
eventsToStore, err = deletePeers(ctx, am, transaction, accountID, userID, []*nbpeer.Peer{peer})
|
||||
return err
|
||||
})
|
||||
|
||||
for _, addPeerRemovedEvent := range addPeerRemovedEvents {
|
||||
addPeerRemovedEvent()
|
||||
for _, storeEvent := range eventsToStore {
|
||||
storeEvent()
|
||||
}
|
||||
|
||||
if updateAccountPeers {
|
||||
@ -433,6 +462,7 @@ func (am *DefaultAccountManager) AddPeer(ctx context.Context, setupKey, userID s
|
||||
}
|
||||
|
||||
var newPeer *nbpeer.Peer
|
||||
var updateAccountPeers bool
|
||||
|
||||
err = am.Store.ExecuteInTransaction(ctx, func(transaction Store) error {
|
||||
var setupKeyID string
|
||||
@ -480,7 +510,7 @@ func (am *DefaultAccountManager) AddPeer(ctx context.Context, setupKey, userID s
|
||||
return fmt.Errorf("failed to get free DNS label: %w", err)
|
||||
}
|
||||
|
||||
freeIP, err := am.getFreeIP(ctx, transaction, accountID)
|
||||
freeIP, err := getFreeIP(ctx, transaction, accountID)
|
||||
if err != nil {
|
||||
return fmt.Errorf("failed to get free IP: %w", err)
|
||||
}
|
||||
@ -564,6 +594,11 @@ func (am *DefaultAccountManager) AddPeer(ctx context.Context, setupKey, userID s
|
||||
}
|
||||
}
|
||||
|
||||
updateAccountPeers, err = isPeerInActiveGroup(ctx, transaction, accountID, newPeer.ID)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
log.WithContext(ctx).Debugf("Peer %s added to account %s", newPeer.ID, accountID)
|
||||
return nil
|
||||
})
|
||||
@ -581,11 +616,6 @@ func (am *DefaultAccountManager) AddPeer(ctx context.Context, setupKey, userID s
|
||||
unlock()
|
||||
unlock = nil
|
||||
|
||||
updateAccountPeers, err := am.isPeerInActiveGroup(ctx, accountID, newPeer.ID)
|
||||
if err != nil {
|
||||
return nil, nil, nil, err
|
||||
}
|
||||
|
||||
if updateAccountPeers {
|
||||
am.updateAccountPeers(ctx, accountID)
|
||||
}
|
||||
@ -593,13 +623,13 @@ func (am *DefaultAccountManager) AddPeer(ctx context.Context, setupKey, userID s
|
||||
return am.getValidatedPeerWithMap(ctx, false, accountID, newPeer)
|
||||
}
|
||||
|
||||
func (am *DefaultAccountManager) getFreeIP(ctx context.Context, store Store, accountID string) (net.IP, error) {
|
||||
takenIps, err := store.GetTakenIPs(ctx, LockingStrengthShare, accountID)
|
||||
func getFreeIP(ctx context.Context, transaction Store, accountID string) (net.IP, error) {
|
||||
takenIps, err := transaction.GetTakenIPs(ctx, LockingStrengthShare, accountID)
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("failed to get taken IPs: %w", err)
|
||||
}
|
||||
|
||||
network, err := store.GetAccountNetwork(ctx, LockingStrengthUpdate, accountID)
|
||||
network, err := transaction.GetAccountNetwork(ctx, LockingStrengthUpdate, accountID)
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("failed getting network: %w", err)
|
||||
}
|
||||
@ -614,48 +644,59 @@ func (am *DefaultAccountManager) getFreeIP(ctx context.Context, store Store, acc
|
||||
|
||||
// SyncPeer checks whether peer is eligible for receiving NetworkMap (authenticated) and returns its NetworkMap if eligible
|
||||
func (am *DefaultAccountManager) SyncPeer(ctx context.Context, sync PeerSync, accountID string) (*nbpeer.Peer, *NetworkMap, []*posture.Checks, error) {
|
||||
peer, err := am.Store.GetPeerByPeerPubKey(ctx, LockingStrengthShare, sync.WireGuardPubKey)
|
||||
if err != nil {
|
||||
return nil, nil, nil, status.NewPeerNotRegisteredError()
|
||||
}
|
||||
var peer *nbpeer.Peer
|
||||
var peerNotValid bool
|
||||
var isStatusChanged bool
|
||||
var updated bool
|
||||
var err error
|
||||
|
||||
if peer.UserID != "" {
|
||||
user, err := am.Store.GetUserByUserID(ctx, LockingStrengthShare, peer.UserID)
|
||||
err = am.Store.ExecuteInTransaction(ctx, func(transaction Store) error {
|
||||
peer, err = transaction.GetPeerByPeerPubKey(ctx, LockingStrengthUpdate, sync.WireGuardPubKey)
|
||||
if err != nil {
|
||||
return nil, nil, nil, err
|
||||
return status.NewPeerNotRegisteredError()
|
||||
}
|
||||
|
||||
err = checkIfPeerOwnerIsBlocked(peer, user)
|
||||
if err != nil {
|
||||
return nil, nil, nil, err
|
||||
if peer.UserID != "" {
|
||||
user, err := transaction.GetUserByUserID(ctx, LockingStrengthShare, peer.UserID)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
if err = checkIfPeerOwnerIsBlocked(peer, user); err != nil {
|
||||
return err
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
settings, err := am.Store.GetAccountSettings(ctx, LockingStrengthShare, accountID)
|
||||
if err != nil {
|
||||
return nil, nil, nil, err
|
||||
}
|
||||
|
||||
if peerLoginExpired(ctx, peer, settings) {
|
||||
return nil, nil, nil, status.NewPeerLoginExpiredError()
|
||||
}
|
||||
|
||||
peerGroupList, err := am.getPeerGroupIDs(ctx, accountID, peer.ID)
|
||||
if err != nil {
|
||||
return nil, nil, nil, err
|
||||
}
|
||||
|
||||
peerNotValid, isStatusChanged, err := am.integratedPeerValidator.IsNotValidPeer(ctx, accountID, peer, peerGroupList, settings.Extra)
|
||||
if err != nil {
|
||||
return nil, nil, nil, err
|
||||
}
|
||||
|
||||
updated := peer.UpdateMetaIfNew(sync.Meta)
|
||||
if updated {
|
||||
err = am.Store.SavePeer(ctx, LockingStrengthUpdate, accountID, peer)
|
||||
settings, err := transaction.GetAccountSettings(ctx, LockingStrengthShare, accountID)
|
||||
if err != nil {
|
||||
return nil, nil, nil, err
|
||||
return err
|
||||
}
|
||||
|
||||
if peerLoginExpired(ctx, peer, settings) {
|
||||
return status.NewPeerLoginExpiredError()
|
||||
}
|
||||
|
||||
peerGroupIDs, err := getPeerGroupIDs(ctx, transaction, accountID, peer.ID)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
peerNotValid, isStatusChanged, err = am.integratedPeerValidator.IsNotValidPeer(ctx, accountID, peer, peerGroupIDs, settings.Extra)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
updated = peer.UpdateMetaIfNew(sync.Meta)
|
||||
if updated {
|
||||
err = transaction.SavePeer(ctx, LockingStrengthUpdate, accountID, peer)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
}
|
||||
return nil
|
||||
})
|
||||
if err != nil {
|
||||
return nil, nil, nil, err
|
||||
}
|
||||
|
||||
if isStatusChanged || (updated && sync.UpdateAccountPeers) {
|
||||
@ -707,73 +748,73 @@ func (am *DefaultAccountManager) LoginPeer(ctx context.Context, login PeerLogin)
|
||||
}
|
||||
}()
|
||||
|
||||
peer, err := am.Store.GetPeerByPeerPubKey(ctx, LockingStrengthUpdate, login.WireGuardPubKey)
|
||||
if err != nil {
|
||||
return nil, nil, nil, err
|
||||
}
|
||||
var peer *nbpeer.Peer
|
||||
var updateRemotePeers bool
|
||||
var isRequiresApproval bool
|
||||
var isStatusChanged bool
|
||||
|
||||
settings, err := am.Store.GetAccountSettings(ctx, LockingStrengthShare, accountID)
|
||||
if err != nil {
|
||||
return nil, nil, nil, err
|
||||
}
|
||||
|
||||
// this flag prevents unnecessary calls to the persistent store.
|
||||
shouldStorePeer := false
|
||||
updateRemotePeers := false
|
||||
|
||||
if login.UserID != "" {
|
||||
if peer.UserID != login.UserID {
|
||||
log.Warnf("user mismatch when logging in peer %s: peer user %s, login user %s ", peer.ID, peer.UserID, login.UserID)
|
||||
return nil, nil, nil, status.Errorf(status.Unauthenticated, "invalid user")
|
||||
}
|
||||
|
||||
changed, err := am.handleUserPeer(ctx, peer, settings)
|
||||
err = am.Store.ExecuteInTransaction(ctx, func(transaction Store) error {
|
||||
peer, err = transaction.GetPeerByPeerPubKey(ctx, LockingStrengthUpdate, login.WireGuardPubKey)
|
||||
if err != nil {
|
||||
return nil, nil, nil, err
|
||||
return err
|
||||
}
|
||||
if changed {
|
||||
shouldStorePeer = true
|
||||
updateRemotePeers = true
|
||||
|
||||
settings, err := transaction.GetAccountSettings(ctx, LockingStrengthShare, accountID)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
}
|
||||
// this flag prevents unnecessary calls to the persistent store.
|
||||
shouldStorePeer := false
|
||||
|
||||
groups, err := am.Store.GetAccountGroups(ctx, LockingStrengthShare, accountID)
|
||||
if err != nil {
|
||||
return nil, nil, nil, err
|
||||
}
|
||||
if login.UserID != "" {
|
||||
if peer.UserID != login.UserID {
|
||||
log.Warnf("user mismatch when logging in peer %s: peer user %s, login user %s ", peer.ID, peer.UserID, login.UserID)
|
||||
return status.Errorf(status.Unauthenticated, "invalid user")
|
||||
}
|
||||
|
||||
var grps []string
|
||||
for _, group := range groups {
|
||||
for _, id := range group.Peers {
|
||||
if id == peer.ID {
|
||||
grps = append(grps, group.ID)
|
||||
break
|
||||
changed, err := am.handleUserPeer(ctx, transaction, peer, settings)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
if changed {
|
||||
shouldStorePeer = true
|
||||
updateRemotePeers = true
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
isRequiresApproval, isStatusChanged, err := am.integratedPeerValidator.IsNotValidPeer(ctx, accountID, peer, grps, settings.Extra)
|
||||
peerGroupIDs, err := getPeerGroupIDs(ctx, am.Store, accountID, peer.ID)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
isRequiresApproval, isStatusChanged, err = am.integratedPeerValidator.IsNotValidPeer(ctx, accountID, peer, peerGroupIDs, settings.Extra)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
updated := peer.UpdateMetaIfNew(login.Meta)
|
||||
if updated {
|
||||
shouldStorePeer = true
|
||||
}
|
||||
|
||||
if peer.SSHKey != login.SSHKey {
|
||||
peer.SSHKey = login.SSHKey
|
||||
shouldStorePeer = true
|
||||
}
|
||||
|
||||
if shouldStorePeer {
|
||||
if err = transaction.SavePeer(ctx, LockingStrengthUpdate, accountID, peer); err != nil {
|
||||
return err
|
||||
}
|
||||
}
|
||||
|
||||
return nil
|
||||
})
|
||||
if err != nil {
|
||||
return nil, nil, nil, err
|
||||
}
|
||||
|
||||
updated := peer.UpdateMetaIfNew(login.Meta)
|
||||
if updated {
|
||||
shouldStorePeer = true
|
||||
}
|
||||
|
||||
if peer.SSHKey != login.SSHKey {
|
||||
peer.SSHKey = login.SSHKey
|
||||
shouldStorePeer = true
|
||||
}
|
||||
|
||||
if shouldStorePeer {
|
||||
err = am.Store.SavePeer(ctx, LockingStrengthUpdate, accountID, peer)
|
||||
if err != nil {
|
||||
return nil, nil, nil, err
|
||||
}
|
||||
}
|
||||
|
||||
unlockPeer()
|
||||
unlockPeer = nil
|
||||
|
||||
@ -845,7 +886,7 @@ func (am *DefaultAccountManager) getValidatedPeerWithMap(ctx context.Context, is
|
||||
return peer, account.GetPeerNetworkMap(ctx, peer.ID, customZone, approvedPeersMap, am.metrics.AccountManagerMetrics()), postureChecks, nil
|
||||
}
|
||||
|
||||
func (am *DefaultAccountManager) handleExpiredPeer(ctx context.Context, user *User, peer *nbpeer.Peer) error {
|
||||
func (am *DefaultAccountManager) handleExpiredPeer(ctx context.Context, transaction Store, user *User, peer *nbpeer.Peer) error {
|
||||
err := checkAuth(ctx, user.Id, peer)
|
||||
if err != nil {
|
||||
return err
|
||||
@ -853,12 +894,12 @@ func (am *DefaultAccountManager) handleExpiredPeer(ctx context.Context, user *Us
|
||||
// If peer was expired before and if it reached this point, it is re-authenticated.
|
||||
// UserID is present, meaning that JWT validation passed successfully in the API layer.
|
||||
peer = peer.UpdateLastLogin()
|
||||
err = am.Store.SavePeer(ctx, LockingStrengthUpdate, peer.AccountID, peer)
|
||||
err = transaction.SavePeer(ctx, LockingStrengthUpdate, peer.AccountID, peer)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
err = am.Store.SaveUserLastLogin(ctx, user.AccountID, user.Id, peer.LastLogin)
|
||||
err = transaction.SaveUserLastLogin(ctx, user.AccountID, user.Id, peer.LastLogin)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
@ -1149,7 +1190,12 @@ func (am *DefaultAccountManager) getInactivePeers(ctx context.Context, accountID
|
||||
|
||||
// GetPeerGroups returns groups that the peer is part of.
|
||||
func (am *DefaultAccountManager) GetPeerGroups(ctx context.Context, accountID, peerID string) ([]*nbgroup.Group, error) {
|
||||
groups, err := am.Store.GetAccountGroups(ctx, LockingStrengthShare, accountID)
|
||||
return getPeerGroups(ctx, am.Store, accountID, peerID)
|
||||
}
|
||||
|
||||
// getPeerGroups returns the IDs of the groups that the peer is part of.
|
||||
func getPeerGroups(ctx context.Context, transaction Store, accountID, peerID string) ([]*nbgroup.Group, error) {
|
||||
groups, err := transaction.GetAccountGroups(ctx, LockingStrengthShare, accountID)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
@ -1165,8 +1211,8 @@ func (am *DefaultAccountManager) GetPeerGroups(ctx context.Context, accountID, p
|
||||
}
|
||||
|
||||
// getPeerGroupIDs returns the IDs of the groups that the peer is part of.
|
||||
func (am *DefaultAccountManager) getPeerGroupIDs(ctx context.Context, accountID string, peerID string) ([]string, error) {
|
||||
groups, err := am.GetPeerGroups(ctx, accountID, peerID)
|
||||
func getPeerGroupIDs(ctx context.Context, transaction Store, accountID string, peerID string) ([]string, error) {
|
||||
groups, err := getPeerGroups(ctx, transaction, accountID, peerID)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
@ -1179,8 +1225,8 @@ func (am *DefaultAccountManager) getPeerGroupIDs(ctx context.Context, accountID
|
||||
return groupIDs, err
|
||||
}
|
||||
|
||||
func (am *DefaultAccountManager) getPeerDNSLabels(ctx context.Context, accountID string) (lookupMap, error) {
|
||||
dnsLabels, err := am.Store.GetAccountPeerDNSLabels(ctx, LockingStrengthShare, accountID)
|
||||
func getPeerDNSLabels(ctx context.Context, transaction Store, accountID string) (lookupMap, error) {
|
||||
dnsLabels, err := transaction.GetAccountPeerDNSLabels(ctx, LockingStrengthShare, accountID)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
@ -1194,12 +1240,12 @@ func (am *DefaultAccountManager) getPeerDNSLabels(ctx context.Context, accountID
|
||||
|
||||
// IsPeerInActiveGroup checks if the given peer is part of a group that is used
|
||||
// in an active DNS, route, or ACL configuration.
|
||||
func (am *DefaultAccountManager) isPeerInActiveGroup(ctx context.Context, accountID, peerID string) (bool, error) {
|
||||
peerGroupIDs, err := am.getPeerGroupIDs(ctx, accountID, peerID)
|
||||
func isPeerInActiveGroup(ctx context.Context, transaction Store, accountID, peerID string) (bool, error) {
|
||||
peerGroupIDs, err := getPeerGroupIDs(ctx, transaction, accountID, peerID)
|
||||
if err != nil {
|
||||
return false, err
|
||||
}
|
||||
return areGroupChangesAffectPeers(ctx, am.Store, accountID, peerGroupIDs) // TODO: use transaction
|
||||
return areGroupChangesAffectPeers(ctx, transaction, accountID, peerGroupIDs) // TODO: use transaction
|
||||
}
|
||||
|
||||
// deletePeers deletes all specified peers and sends updates to the remote peers.
|
||||
|
@ -754,6 +754,20 @@ func (s *SqlStore) GetAccountIDByUserID(ctx context.Context, lockStrength Lockin
|
||||
return accountID, nil
|
||||
}
|
||||
|
||||
func (s *SqlStore) GetAccountIDByPeerID(ctx context.Context, lockStrength LockingStrength, peerID string) (string, error) {
|
||||
var accountID string
|
||||
result := s.db.Clauses(clause.Locking{Strength: string(lockStrength)}).Model(&nbpeer.Peer{}).
|
||||
Select("account_id").Where(idQueryCondition, peerID).First(&accountID)
|
||||
if result.Error != nil {
|
||||
if errors.Is(result.Error, gorm.ErrRecordNotFound) {
|
||||
return "", status.Errorf(status.NotFound, "peer %s account not found", peerID)
|
||||
}
|
||||
return "", status.NewGetAccountFromStoreError(result.Error)
|
||||
}
|
||||
|
||||
return accountID, nil
|
||||
}
|
||||
|
||||
func (s *SqlStore) GetAccountIDBySetupKey(ctx context.Context, setupKey string) (string, error) {
|
||||
var accountID string
|
||||
result := s.db.Model(&SetupKey{}).Select("account_id").Where(keyQueryCondition, setupKey).First(&accountID)
|
||||
|
@ -86,6 +86,11 @@ func NewAccountNotFoundError(accountKey string) error {
|
||||
return Errorf(NotFound, "account not found: %s", accountKey)
|
||||
}
|
||||
|
||||
// NewPeerNotPartOfAccountError creates a new Error with PermissionDenied type for a peer not being part of an account
|
||||
func NewPeerNotPartOfAccountError() error {
|
||||
return Errorf(PermissionDenied, "peer is not part of this account")
|
||||
}
|
||||
|
||||
// NewUserNotFoundError creates a new Error with NotFound type for a missing user
|
||||
func NewUserNotFoundError(userKey string) error {
|
||||
return Errorf(NotFound, "user not found: %s", userKey)
|
||||
|
@ -50,6 +50,7 @@ type Store interface {
|
||||
GetAccountIDByPeerPubKey(ctx context.Context, peerKey string) (string, error)
|
||||
GetAccountIDByUserID(ctx context.Context, lockStrength LockingStrength, userID string) (string, error)
|
||||
GetAccountIDBySetupKey(ctx context.Context, peerKey string) (string, error)
|
||||
GetAccountIDByPeerID(ctx context.Context, lockStrength LockingStrength, peerID string) (string, error)
|
||||
GetAccountByPeerID(ctx context.Context, peerID string) (*Account, error)
|
||||
GetAccountBySetupKey(ctx context.Context, setupKey string) (*Account, error) // todo use key hash later
|
||||
GetAccountByPrivateDomain(ctx context.Context, domain string) (*Account, error)
|
||||
|
@ -524,6 +524,10 @@ func (am *DefaultAccountManager) deleteUserPeers(ctx context.Context, initiatorU
|
||||
storeEvent()
|
||||
}
|
||||
|
||||
for _, peer := range peers {
|
||||
account.DeletePeer(peer.ID)
|
||||
}
|
||||
|
||||
return hadPeers, nil
|
||||
}
|
||||
|
||||
|
Loading…
Reference in New Issue
Block a user