run peer ops in transaction

Signed-off-by: bcmmbaga <bethuelmbaga12@gmail.com>
This commit is contained in:
bcmmbaga 2024-11-18 15:06:25 +03:00
parent f6f7260897
commit a61e9da3e9
No known key found for this signature in database
GPG Key ID: 511EED5C928AD547
6 changed files with 233 additions and 163 deletions

View File

@ -2390,8 +2390,8 @@ func (am *DefaultAccountManager) GetAccountIDForPeerKey(ctx context.Context, pee
return am.Store.GetAccountIDByPeerPubKey(ctx, peerKey) return am.Store.GetAccountIDByPeerPubKey(ctx, peerKey)
} }
func (am *DefaultAccountManager) handleUserPeer(ctx context.Context, peer *nbpeer.Peer, settings *Settings) (bool, error) { func (am *DefaultAccountManager) handleUserPeer(ctx context.Context, transaction Store, peer *nbpeer.Peer, settings *Settings) (bool, error) {
user, err := am.Store.GetUserByUserID(ctx, LockingStrengthShare, peer.UserID) user, err := transaction.GetUserByUserID(ctx, LockingStrengthShare, peer.UserID)
if err != nil { if err != nil {
return false, err return false, err
} }
@ -2402,7 +2402,7 @@ func (am *DefaultAccountManager) handleUserPeer(ctx context.Context, peer *nbpee
} }
if peerLoginExpired(ctx, peer, settings) { if peerLoginExpired(ctx, peer, settings) {
err = am.handleExpiredPeer(ctx, user, peer) err = am.handleExpiredPeer(ctx, transaction, user, peer)
if err != nil { if err != nil {
return false, err return false, err
} }

View File

@ -11,6 +11,7 @@ import (
"sync" "sync"
"time" "time"
"github.com/netbirdio/netbird/management/server/geolocation"
nbgroup "github.com/netbirdio/netbird/management/server/group" nbgroup "github.com/netbirdio/netbird/management/server/group"
"github.com/rs/xid" "github.com/rs/xid"
log "github.com/sirupsen/logrus" log "github.com/sirupsen/logrus"
@ -117,17 +118,25 @@ func (am *DefaultAccountManager) GetPeers(ctx context.Context, accountID, userID
// MarkPeerConnected marks peer as connected (true) or disconnected (false) // MarkPeerConnected marks peer as connected (true) or disconnected (false)
func (am *DefaultAccountManager) MarkPeerConnected(ctx context.Context, peerPubKey string, connected bool, realIP net.IP, accountID string) error { func (am *DefaultAccountManager) MarkPeerConnected(ctx context.Context, peerPubKey string, connected bool, realIP net.IP, accountID string) error {
peer, err := am.Store.GetPeerByPeerPubKey(ctx, LockingStrengthShare, peerPubKey) var peer *nbpeer.Peer
var settings *Settings
var expired bool
var err error
err = am.Store.ExecuteInTransaction(ctx, func(transaction Store) error {
peer, err = transaction.GetPeerByPeerPubKey(ctx, LockingStrengthUpdate, peerPubKey)
if err != nil { if err != nil {
return err return err
} }
settings, err := am.Store.GetAccountSettings(ctx, LockingStrengthShare, accountID) settings, err = transaction.GetAccountSettings(ctx, LockingStrengthShare, accountID)
if err != nil { if err != nil {
return err return err
} }
expired, err := am.updatePeerStatusAndLocation(ctx, peer, connected, realIP, accountID) expired, err = updatePeerStatusAndLocation(ctx, am.geo, transaction, peer, connected, realIP, accountID)
return err
})
if err != nil { if err != nil {
return err return err
} }
@ -151,7 +160,7 @@ func (am *DefaultAccountManager) MarkPeerConnected(ctx context.Context, peerPubK
return nil return nil
} }
func (am *DefaultAccountManager) updatePeerStatusAndLocation(ctx context.Context, peer *nbpeer.Peer, connected bool, realIP net.IP, accountID string) (bool, error) { func updatePeerStatusAndLocation(ctx context.Context, geo *geolocation.Geolocation, transaction Store, peer *nbpeer.Peer, connected bool, realIP net.IP, accountID string) (bool, error) {
oldStatus := peer.Status.Copy() oldStatus := peer.Status.Copy()
newStatus := oldStatus newStatus := oldStatus
newStatus.LastSeen = time.Now().UTC() newStatus.LastSeen = time.Now().UTC()
@ -162,8 +171,8 @@ func (am *DefaultAccountManager) updatePeerStatusAndLocation(ctx context.Context
} }
peer.Status = newStatus peer.Status = newStatus
if am.geo != nil && realIP != nil { if geo != nil && realIP != nil {
location, err := am.geo.Lookup(realIP) location, err := geo.Lookup(realIP)
if err != nil { if err != nil {
log.WithContext(ctx).Warnf("failed to get location for peer %s realip: [%s]: %v", peer.ID, realIP.String(), err) log.WithContext(ctx).Warnf("failed to get location for peer %s realip: [%s]: %v", peer.ID, realIP.String(), err)
} else { } else {
@ -171,14 +180,14 @@ func (am *DefaultAccountManager) updatePeerStatusAndLocation(ctx context.Context
peer.Location.CountryCode = location.Country.ISOCode peer.Location.CountryCode = location.Country.ISOCode
peer.Location.CityName = location.City.Names.En peer.Location.CityName = location.City.Names.En
peer.Location.GeoNameID = location.City.GeonameID peer.Location.GeoNameID = location.City.GeonameID
err = am.Store.SavePeerLocation(ctx, LockingStrengthUpdate, accountID, peer) err = transaction.SavePeerLocation(ctx, LockingStrengthUpdate, accountID, peer)
if err != nil { if err != nil {
log.WithContext(ctx).Warnf("could not store location for peer %s: %s", peer.ID, err) log.WithContext(ctx).Warnf("could not store location for peer %s: %s", peer.ID, err)
} }
} }
} }
err := am.Store.SavePeerStatus(ctx, LockingStrengthUpdate, accountID, peer.ID, *newStatus) err := transaction.SavePeerStatus(ctx, LockingStrengthUpdate, accountID, peer.ID, *newStatus)
if err != nil { if err != nil {
return false, err return false, err
} }
@ -200,23 +209,49 @@ func (am *DefaultAccountManager) UpdatePeer(ctx context.Context, accountID, user
return nil, status.NewUserNotPartOfAccountError() return nil, status.NewUserNotPartOfAccountError()
} }
peer, err := am.Store.GetPeerByID(ctx, LockingStrengthShare, accountID, update.ID) var peer *nbpeer.Peer
if err != nil { var settings *Settings
return nil, err var peerGroupList []string
}
settings, err := am.Store.GetAccountSettings(ctx, LockingStrengthShare, accountID)
if err != nil {
return nil, err
}
peerGroupList, err := am.getPeerGroupIDs(ctx, accountID, update.ID)
if err != nil {
return nil, err
}
var requiresPeerUpdates bool var requiresPeerUpdates bool
var newLabel string
err = am.Store.ExecuteInTransaction(ctx, func(transaction Store) error {
peer, err = transaction.GetPeerByID(ctx, LockingStrengthUpdate, accountID, update.ID)
if err != nil {
return err
}
settings, err = transaction.GetAccountSettings(ctx, LockingStrengthShare, accountID)
if err != nil {
return err
}
peerGroupList, err = getPeerGroupIDs(ctx, am.Store, accountID, update.ID)
if err != nil {
return err
}
update, requiresPeerUpdates, err = am.integratedPeerValidator.ValidatePeer(ctx, update, peer, userID, accountID, am.GetDNSDomain(), peerGroupList, settings.Extra) update, requiresPeerUpdates, err = am.integratedPeerValidator.ValidatePeer(ctx, update, peer, userID, accountID, am.GetDNSDomain(), peerGroupList, settings.Extra)
if err != nil {
return err
}
if peer.Name != update.Name {
existingLabels, err := getPeerDNSLabels(ctx, transaction, accountID)
if err != nil {
return err
}
newLabel, err = getPeerHostLabel(update.Name, existingLabels)
if err != nil {
return err
}
peer.DNSLabel = newLabel
}
return transaction.SavePeer(ctx, LockingStrengthUpdate, accountID, peer)
})
if err != nil { if err != nil {
return nil, err return nil, err
} }
@ -231,18 +266,6 @@ func (am *DefaultAccountManager) UpdatePeer(ctx context.Context, accountID, user
if peer.Name != update.Name { if peer.Name != update.Name {
peer.Name = update.Name peer.Name = update.Name
peerLabelChanged = true peerLabelChanged = true
existingLabels, err := am.getPeerDNSLabels(ctx, accountID)
if err != nil {
return nil, err
}
newLabel, err := getPeerHostLabel(peer.Name, existingLabels)
if err != nil {
return nil, err
}
peer.DNSLabel = newLabel
} }
if peer.LoginExpirationEnabled != update.LoginExpirationEnabled { if peer.LoginExpirationEnabled != update.LoginExpirationEnabled {
@ -261,10 +284,6 @@ func (am *DefaultAccountManager) UpdatePeer(ctx context.Context, accountID, user
inactivityExpirationChanged = true inactivityExpirationChanged = true
} }
if err = am.Store.SavePeer(ctx, LockingStrengthUpdate, accountID, peer); err != nil {
return nil, err
}
if sshChanged { if sshChanged {
event := activity.PeerSSHEnabled event := activity.PeerSSHEnabled
if !peer.SSHEnabled { if !peer.SSHEnabled {
@ -313,13 +332,18 @@ func (am *DefaultAccountManager) DeletePeer(ctx context.Context, accountID, peer
unlock := am.Store.AcquireWriteLockByUID(ctx, accountID) unlock := am.Store.AcquireWriteLockByUID(ctx, accountID)
defer unlock() defer unlock()
updateAccountPeers, err := am.isPeerInActiveGroup(ctx, accountID, peerID) peerAccountID, err := am.Store.GetAccountIDByPeerID(ctx, LockingStrengthShare, peerID)
if err != nil { if err != nil {
return err return err
} }
if peerAccountID != accountID {
return status.NewPeerNotPartOfAccountError()
}
var peer *nbpeer.Peer var peer *nbpeer.Peer
var addPeerRemovedEvents []func() var updateAccountPeers bool
var eventsToStore []func()
err = am.Store.ExecuteInTransaction(ctx, func(transaction Store) error { err = am.Store.ExecuteInTransaction(ctx, func(transaction Store) error {
peer, err = transaction.GetPeerByID(ctx, LockingStrengthUpdate, accountID, peerID) peer, err = transaction.GetPeerByID(ctx, LockingStrengthUpdate, accountID, peerID)
@ -327,16 +351,21 @@ func (am *DefaultAccountManager) DeletePeer(ctx context.Context, accountID, peer
return err return err
} }
addPeerRemovedEvents, err = deletePeers(ctx, am, transaction, accountID, userID, []*nbpeer.Peer{peer}) updateAccountPeers, err = isPeerInActiveGroup(ctx, transaction, accountID, peerID)
if err != nil { if err != nil {
return err return err
} }
return transaction.IncrementNetworkSerial(ctx, LockingStrengthUpdate, accountID) if err = transaction.IncrementNetworkSerial(ctx, LockingStrengthUpdate, accountID); err != nil {
return err
}
eventsToStore, err = deletePeers(ctx, am, transaction, accountID, userID, []*nbpeer.Peer{peer})
return err
}) })
for _, addPeerRemovedEvent := range addPeerRemovedEvents { for _, storeEvent := range eventsToStore {
addPeerRemovedEvent() storeEvent()
} }
if updateAccountPeers { if updateAccountPeers {
@ -433,6 +462,7 @@ func (am *DefaultAccountManager) AddPeer(ctx context.Context, setupKey, userID s
} }
var newPeer *nbpeer.Peer var newPeer *nbpeer.Peer
var updateAccountPeers bool
err = am.Store.ExecuteInTransaction(ctx, func(transaction Store) error { err = am.Store.ExecuteInTransaction(ctx, func(transaction Store) error {
var setupKeyID string var setupKeyID string
@ -480,7 +510,7 @@ func (am *DefaultAccountManager) AddPeer(ctx context.Context, setupKey, userID s
return fmt.Errorf("failed to get free DNS label: %w", err) return fmt.Errorf("failed to get free DNS label: %w", err)
} }
freeIP, err := am.getFreeIP(ctx, transaction, accountID) freeIP, err := getFreeIP(ctx, transaction, accountID)
if err != nil { if err != nil {
return fmt.Errorf("failed to get free IP: %w", err) return fmt.Errorf("failed to get free IP: %w", err)
} }
@ -564,6 +594,11 @@ func (am *DefaultAccountManager) AddPeer(ctx context.Context, setupKey, userID s
} }
} }
updateAccountPeers, err = isPeerInActiveGroup(ctx, transaction, accountID, newPeer.ID)
if err != nil {
return err
}
log.WithContext(ctx).Debugf("Peer %s added to account %s", newPeer.ID, accountID) log.WithContext(ctx).Debugf("Peer %s added to account %s", newPeer.ID, accountID)
return nil return nil
}) })
@ -581,11 +616,6 @@ func (am *DefaultAccountManager) AddPeer(ctx context.Context, setupKey, userID s
unlock() unlock()
unlock = nil unlock = nil
updateAccountPeers, err := am.isPeerInActiveGroup(ctx, accountID, newPeer.ID)
if err != nil {
return nil, nil, nil, err
}
if updateAccountPeers { if updateAccountPeers {
am.updateAccountPeers(ctx, accountID) am.updateAccountPeers(ctx, accountID)
} }
@ -593,13 +623,13 @@ func (am *DefaultAccountManager) AddPeer(ctx context.Context, setupKey, userID s
return am.getValidatedPeerWithMap(ctx, false, accountID, newPeer) return am.getValidatedPeerWithMap(ctx, false, accountID, newPeer)
} }
func (am *DefaultAccountManager) getFreeIP(ctx context.Context, store Store, accountID string) (net.IP, error) { func getFreeIP(ctx context.Context, transaction Store, accountID string) (net.IP, error) {
takenIps, err := store.GetTakenIPs(ctx, LockingStrengthShare, accountID) takenIps, err := transaction.GetTakenIPs(ctx, LockingStrengthShare, accountID)
if err != nil { if err != nil {
return nil, fmt.Errorf("failed to get taken IPs: %w", err) return nil, fmt.Errorf("failed to get taken IPs: %w", err)
} }
network, err := store.GetAccountNetwork(ctx, LockingStrengthUpdate, accountID) network, err := transaction.GetAccountNetwork(ctx, LockingStrengthUpdate, accountID)
if err != nil { if err != nil {
return nil, fmt.Errorf("failed getting network: %w", err) return nil, fmt.Errorf("failed getting network: %w", err)
} }
@ -614,49 +644,60 @@ func (am *DefaultAccountManager) getFreeIP(ctx context.Context, store Store, acc
// SyncPeer checks whether peer is eligible for receiving NetworkMap (authenticated) and returns its NetworkMap if eligible // SyncPeer checks whether peer is eligible for receiving NetworkMap (authenticated) and returns its NetworkMap if eligible
func (am *DefaultAccountManager) SyncPeer(ctx context.Context, sync PeerSync, accountID string) (*nbpeer.Peer, *NetworkMap, []*posture.Checks, error) { func (am *DefaultAccountManager) SyncPeer(ctx context.Context, sync PeerSync, accountID string) (*nbpeer.Peer, *NetworkMap, []*posture.Checks, error) {
peer, err := am.Store.GetPeerByPeerPubKey(ctx, LockingStrengthShare, sync.WireGuardPubKey) var peer *nbpeer.Peer
var peerNotValid bool
var isStatusChanged bool
var updated bool
var err error
err = am.Store.ExecuteInTransaction(ctx, func(transaction Store) error {
peer, err = transaction.GetPeerByPeerPubKey(ctx, LockingStrengthUpdate, sync.WireGuardPubKey)
if err != nil { if err != nil {
return nil, nil, nil, status.NewPeerNotRegisteredError() return status.NewPeerNotRegisteredError()
} }
if peer.UserID != "" { if peer.UserID != "" {
user, err := am.Store.GetUserByUserID(ctx, LockingStrengthShare, peer.UserID) user, err := transaction.GetUserByUserID(ctx, LockingStrengthShare, peer.UserID)
if err != nil { if err != nil {
return nil, nil, nil, err return err
} }
err = checkIfPeerOwnerIsBlocked(peer, user) if err = checkIfPeerOwnerIsBlocked(peer, user); err != nil {
if err != nil { return err
return nil, nil, nil, err
} }
} }
settings, err := am.Store.GetAccountSettings(ctx, LockingStrengthShare, accountID) settings, err := transaction.GetAccountSettings(ctx, LockingStrengthShare, accountID)
if err != nil { if err != nil {
return nil, nil, nil, err return err
} }
if peerLoginExpired(ctx, peer, settings) { if peerLoginExpired(ctx, peer, settings) {
return nil, nil, nil, status.NewPeerLoginExpiredError() return status.NewPeerLoginExpiredError()
} }
peerGroupList, err := am.getPeerGroupIDs(ctx, accountID, peer.ID) peerGroupIDs, err := getPeerGroupIDs(ctx, transaction, accountID, peer.ID)
if err != nil { if err != nil {
return nil, nil, nil, err return err
} }
peerNotValid, isStatusChanged, err := am.integratedPeerValidator.IsNotValidPeer(ctx, accountID, peer, peerGroupList, settings.Extra) peerNotValid, isStatusChanged, err = am.integratedPeerValidator.IsNotValidPeer(ctx, accountID, peer, peerGroupIDs, settings.Extra)
if err != nil { if err != nil {
return nil, nil, nil, err return err
} }
updated := peer.UpdateMetaIfNew(sync.Meta) updated = peer.UpdateMetaIfNew(sync.Meta)
if updated { if updated {
err = am.Store.SavePeer(ctx, LockingStrengthUpdate, accountID, peer) err = transaction.SavePeer(ctx, LockingStrengthUpdate, accountID, peer)
if err != nil {
return err
}
}
return nil
})
if err != nil { if err != nil {
return nil, nil, nil, err return nil, nil, nil, err
} }
}
if isStatusChanged || (updated && sync.UpdateAccountPeers) { if isStatusChanged || (updated && sync.UpdateAccountPeers) {
am.updateAccountPeers(ctx, accountID) am.updateAccountPeers(ctx, accountID)
@ -707,54 +748,49 @@ func (am *DefaultAccountManager) LoginPeer(ctx context.Context, login PeerLogin)
} }
}() }()
peer, err := am.Store.GetPeerByPeerPubKey(ctx, LockingStrengthUpdate, login.WireGuardPubKey) var peer *nbpeer.Peer
var updateRemotePeers bool
var isRequiresApproval bool
var isStatusChanged bool
err = am.Store.ExecuteInTransaction(ctx, func(transaction Store) error {
peer, err = transaction.GetPeerByPeerPubKey(ctx, LockingStrengthUpdate, login.WireGuardPubKey)
if err != nil { if err != nil {
return nil, nil, nil, err return err
} }
settings, err := am.Store.GetAccountSettings(ctx, LockingStrengthShare, accountID) settings, err := transaction.GetAccountSettings(ctx, LockingStrengthShare, accountID)
if err != nil { if err != nil {
return nil, nil, nil, err return err
} }
// this flag prevents unnecessary calls to the persistent store. // this flag prevents unnecessary calls to the persistent store.
shouldStorePeer := false shouldStorePeer := false
updateRemotePeers := false
if login.UserID != "" { if login.UserID != "" {
if peer.UserID != login.UserID { if peer.UserID != login.UserID {
log.Warnf("user mismatch when logging in peer %s: peer user %s, login user %s ", peer.ID, peer.UserID, login.UserID) log.Warnf("user mismatch when logging in peer %s: peer user %s, login user %s ", peer.ID, peer.UserID, login.UserID)
return nil, nil, nil, status.Errorf(status.Unauthenticated, "invalid user") return status.Errorf(status.Unauthenticated, "invalid user")
} }
changed, err := am.handleUserPeer(ctx, peer, settings) changed, err := am.handleUserPeer(ctx, transaction, peer, settings)
if err != nil { if err != nil {
return nil, nil, nil, err return err
} }
if changed { if changed {
shouldStorePeer = true shouldStorePeer = true
updateRemotePeers = true updateRemotePeers = true
} }
} }
groups, err := am.Store.GetAccountGroups(ctx, LockingStrengthShare, accountID) peerGroupIDs, err := getPeerGroupIDs(ctx, am.Store, accountID, peer.ID)
if err != nil { if err != nil {
return nil, nil, nil, err return err
} }
var grps []string isRequiresApproval, isStatusChanged, err = am.integratedPeerValidator.IsNotValidPeer(ctx, accountID, peer, peerGroupIDs, settings.Extra)
for _, group := range groups {
for _, id := range group.Peers {
if id == peer.ID {
grps = append(grps, group.ID)
break
}
}
}
isRequiresApproval, isStatusChanged, err := am.integratedPeerValidator.IsNotValidPeer(ctx, accountID, peer, grps, settings.Extra)
if err != nil { if err != nil {
return nil, nil, nil, err return err
} }
updated := peer.UpdateMetaIfNew(login.Meta) updated := peer.UpdateMetaIfNew(login.Meta)
@ -768,11 +804,16 @@ func (am *DefaultAccountManager) LoginPeer(ctx context.Context, login PeerLogin)
} }
if shouldStorePeer { if shouldStorePeer {
err = am.Store.SavePeer(ctx, LockingStrengthUpdate, accountID, peer) if err = transaction.SavePeer(ctx, LockingStrengthUpdate, accountID, peer); err != nil {
return err
}
}
return nil
})
if err != nil { if err != nil {
return nil, nil, nil, err return nil, nil, nil, err
} }
}
unlockPeer() unlockPeer()
unlockPeer = nil unlockPeer = nil
@ -845,7 +886,7 @@ func (am *DefaultAccountManager) getValidatedPeerWithMap(ctx context.Context, is
return peer, account.GetPeerNetworkMap(ctx, peer.ID, customZone, approvedPeersMap, am.metrics.AccountManagerMetrics()), postureChecks, nil return peer, account.GetPeerNetworkMap(ctx, peer.ID, customZone, approvedPeersMap, am.metrics.AccountManagerMetrics()), postureChecks, nil
} }
func (am *DefaultAccountManager) handleExpiredPeer(ctx context.Context, user *User, peer *nbpeer.Peer) error { func (am *DefaultAccountManager) handleExpiredPeer(ctx context.Context, transaction Store, user *User, peer *nbpeer.Peer) error {
err := checkAuth(ctx, user.Id, peer) err := checkAuth(ctx, user.Id, peer)
if err != nil { if err != nil {
return err return err
@ -853,12 +894,12 @@ func (am *DefaultAccountManager) handleExpiredPeer(ctx context.Context, user *Us
// If peer was expired before and if it reached this point, it is re-authenticated. // If peer was expired before and if it reached this point, it is re-authenticated.
// UserID is present, meaning that JWT validation passed successfully in the API layer. // UserID is present, meaning that JWT validation passed successfully in the API layer.
peer = peer.UpdateLastLogin() peer = peer.UpdateLastLogin()
err = am.Store.SavePeer(ctx, LockingStrengthUpdate, peer.AccountID, peer) err = transaction.SavePeer(ctx, LockingStrengthUpdate, peer.AccountID, peer)
if err != nil { if err != nil {
return err return err
} }
err = am.Store.SaveUserLastLogin(ctx, user.AccountID, user.Id, peer.LastLogin) err = transaction.SaveUserLastLogin(ctx, user.AccountID, user.Id, peer.LastLogin)
if err != nil { if err != nil {
return err return err
} }
@ -1149,7 +1190,12 @@ func (am *DefaultAccountManager) getInactivePeers(ctx context.Context, accountID
// GetPeerGroups returns groups that the peer is part of. // GetPeerGroups returns groups that the peer is part of.
func (am *DefaultAccountManager) GetPeerGroups(ctx context.Context, accountID, peerID string) ([]*nbgroup.Group, error) { func (am *DefaultAccountManager) GetPeerGroups(ctx context.Context, accountID, peerID string) ([]*nbgroup.Group, error) {
groups, err := am.Store.GetAccountGroups(ctx, LockingStrengthShare, accountID) return getPeerGroups(ctx, am.Store, accountID, peerID)
}
// getPeerGroups returns the IDs of the groups that the peer is part of.
func getPeerGroups(ctx context.Context, transaction Store, accountID, peerID string) ([]*nbgroup.Group, error) {
groups, err := transaction.GetAccountGroups(ctx, LockingStrengthShare, accountID)
if err != nil { if err != nil {
return nil, err return nil, err
} }
@ -1165,8 +1211,8 @@ func (am *DefaultAccountManager) GetPeerGroups(ctx context.Context, accountID, p
} }
// getPeerGroupIDs returns the IDs of the groups that the peer is part of. // getPeerGroupIDs returns the IDs of the groups that the peer is part of.
func (am *DefaultAccountManager) getPeerGroupIDs(ctx context.Context, accountID string, peerID string) ([]string, error) { func getPeerGroupIDs(ctx context.Context, transaction Store, accountID string, peerID string) ([]string, error) {
groups, err := am.GetPeerGroups(ctx, accountID, peerID) groups, err := getPeerGroups(ctx, transaction, accountID, peerID)
if err != nil { if err != nil {
return nil, err return nil, err
} }
@ -1179,8 +1225,8 @@ func (am *DefaultAccountManager) getPeerGroupIDs(ctx context.Context, accountID
return groupIDs, err return groupIDs, err
} }
func (am *DefaultAccountManager) getPeerDNSLabels(ctx context.Context, accountID string) (lookupMap, error) { func getPeerDNSLabels(ctx context.Context, transaction Store, accountID string) (lookupMap, error) {
dnsLabels, err := am.Store.GetAccountPeerDNSLabels(ctx, LockingStrengthShare, accountID) dnsLabels, err := transaction.GetAccountPeerDNSLabels(ctx, LockingStrengthShare, accountID)
if err != nil { if err != nil {
return nil, err return nil, err
} }
@ -1194,12 +1240,12 @@ func (am *DefaultAccountManager) getPeerDNSLabels(ctx context.Context, accountID
// IsPeerInActiveGroup checks if the given peer is part of a group that is used // IsPeerInActiveGroup checks if the given peer is part of a group that is used
// in an active DNS, route, or ACL configuration. // in an active DNS, route, or ACL configuration.
func (am *DefaultAccountManager) isPeerInActiveGroup(ctx context.Context, accountID, peerID string) (bool, error) { func isPeerInActiveGroup(ctx context.Context, transaction Store, accountID, peerID string) (bool, error) {
peerGroupIDs, err := am.getPeerGroupIDs(ctx, accountID, peerID) peerGroupIDs, err := getPeerGroupIDs(ctx, transaction, accountID, peerID)
if err != nil { if err != nil {
return false, err return false, err
} }
return areGroupChangesAffectPeers(ctx, am.Store, accountID, peerGroupIDs) // TODO: use transaction return areGroupChangesAffectPeers(ctx, transaction, accountID, peerGroupIDs) // TODO: use transaction
} }
// deletePeers deletes all specified peers and sends updates to the remote peers. // deletePeers deletes all specified peers and sends updates to the remote peers.

View File

@ -754,6 +754,20 @@ func (s *SqlStore) GetAccountIDByUserID(ctx context.Context, lockStrength Lockin
return accountID, nil return accountID, nil
} }
func (s *SqlStore) GetAccountIDByPeerID(ctx context.Context, lockStrength LockingStrength, peerID string) (string, error) {
var accountID string
result := s.db.Clauses(clause.Locking{Strength: string(lockStrength)}).Model(&nbpeer.Peer{}).
Select("account_id").Where(idQueryCondition, peerID).First(&accountID)
if result.Error != nil {
if errors.Is(result.Error, gorm.ErrRecordNotFound) {
return "", status.Errorf(status.NotFound, "peer %s account not found", peerID)
}
return "", status.NewGetAccountFromStoreError(result.Error)
}
return accountID, nil
}
func (s *SqlStore) GetAccountIDBySetupKey(ctx context.Context, setupKey string) (string, error) { func (s *SqlStore) GetAccountIDBySetupKey(ctx context.Context, setupKey string) (string, error) {
var accountID string var accountID string
result := s.db.Model(&SetupKey{}).Select("account_id").Where(keyQueryCondition, setupKey).First(&accountID) result := s.db.Model(&SetupKey{}).Select("account_id").Where(keyQueryCondition, setupKey).First(&accountID)

View File

@ -86,6 +86,11 @@ func NewAccountNotFoundError(accountKey string) error {
return Errorf(NotFound, "account not found: %s", accountKey) return Errorf(NotFound, "account not found: %s", accountKey)
} }
// NewPeerNotPartOfAccountError creates a new Error with PermissionDenied type for a peer not being part of an account
func NewPeerNotPartOfAccountError() error {
return Errorf(PermissionDenied, "peer is not part of this account")
}
// NewUserNotFoundError creates a new Error with NotFound type for a missing user // NewUserNotFoundError creates a new Error with NotFound type for a missing user
func NewUserNotFoundError(userKey string) error { func NewUserNotFoundError(userKey string) error {
return Errorf(NotFound, "user not found: %s", userKey) return Errorf(NotFound, "user not found: %s", userKey)

View File

@ -50,6 +50,7 @@ type Store interface {
GetAccountIDByPeerPubKey(ctx context.Context, peerKey string) (string, error) GetAccountIDByPeerPubKey(ctx context.Context, peerKey string) (string, error)
GetAccountIDByUserID(ctx context.Context, lockStrength LockingStrength, userID string) (string, error) GetAccountIDByUserID(ctx context.Context, lockStrength LockingStrength, userID string) (string, error)
GetAccountIDBySetupKey(ctx context.Context, peerKey string) (string, error) GetAccountIDBySetupKey(ctx context.Context, peerKey string) (string, error)
GetAccountIDByPeerID(ctx context.Context, lockStrength LockingStrength, peerID string) (string, error)
GetAccountByPeerID(ctx context.Context, peerID string) (*Account, error) GetAccountByPeerID(ctx context.Context, peerID string) (*Account, error)
GetAccountBySetupKey(ctx context.Context, setupKey string) (*Account, error) // todo use key hash later GetAccountBySetupKey(ctx context.Context, setupKey string) (*Account, error) // todo use key hash later
GetAccountByPrivateDomain(ctx context.Context, domain string) (*Account, error) GetAccountByPrivateDomain(ctx context.Context, domain string) (*Account, error)

View File

@ -524,6 +524,10 @@ func (am *DefaultAccountManager) deleteUserPeers(ctx context.Context, initiatorU
storeEvent() storeEvent()
} }
for _, peer := range peers {
account.DeletePeer(peer.ID)
}
return hadPeers, nil return hadPeers, nil
} }