mirror of
https://github.com/netbirdio/netbird.git
synced 2025-06-28 21:51:40 +02:00
run peer ops in transaction
Signed-off-by: bcmmbaga <bethuelmbaga12@gmail.com>
This commit is contained in:
parent
f6f7260897
commit
a61e9da3e9
@ -2390,8 +2390,8 @@ func (am *DefaultAccountManager) GetAccountIDForPeerKey(ctx context.Context, pee
|
|||||||
return am.Store.GetAccountIDByPeerPubKey(ctx, peerKey)
|
return am.Store.GetAccountIDByPeerPubKey(ctx, peerKey)
|
||||||
}
|
}
|
||||||
|
|
||||||
func (am *DefaultAccountManager) handleUserPeer(ctx context.Context, peer *nbpeer.Peer, settings *Settings) (bool, error) {
|
func (am *DefaultAccountManager) handleUserPeer(ctx context.Context, transaction Store, peer *nbpeer.Peer, settings *Settings) (bool, error) {
|
||||||
user, err := am.Store.GetUserByUserID(ctx, LockingStrengthShare, peer.UserID)
|
user, err := transaction.GetUserByUserID(ctx, LockingStrengthShare, peer.UserID)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return false, err
|
return false, err
|
||||||
}
|
}
|
||||||
@ -2402,7 +2402,7 @@ func (am *DefaultAccountManager) handleUserPeer(ctx context.Context, peer *nbpee
|
|||||||
}
|
}
|
||||||
|
|
||||||
if peerLoginExpired(ctx, peer, settings) {
|
if peerLoginExpired(ctx, peer, settings) {
|
||||||
err = am.handleExpiredPeer(ctx, user, peer)
|
err = am.handleExpiredPeer(ctx, transaction, user, peer)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return false, err
|
return false, err
|
||||||
}
|
}
|
||||||
|
@ -11,6 +11,7 @@ import (
|
|||||||
"sync"
|
"sync"
|
||||||
"time"
|
"time"
|
||||||
|
|
||||||
|
"github.com/netbirdio/netbird/management/server/geolocation"
|
||||||
nbgroup "github.com/netbirdio/netbird/management/server/group"
|
nbgroup "github.com/netbirdio/netbird/management/server/group"
|
||||||
"github.com/rs/xid"
|
"github.com/rs/xid"
|
||||||
log "github.com/sirupsen/logrus"
|
log "github.com/sirupsen/logrus"
|
||||||
@ -117,17 +118,25 @@ func (am *DefaultAccountManager) GetPeers(ctx context.Context, accountID, userID
|
|||||||
|
|
||||||
// MarkPeerConnected marks peer as connected (true) or disconnected (false)
|
// MarkPeerConnected marks peer as connected (true) or disconnected (false)
|
||||||
func (am *DefaultAccountManager) MarkPeerConnected(ctx context.Context, peerPubKey string, connected bool, realIP net.IP, accountID string) error {
|
func (am *DefaultAccountManager) MarkPeerConnected(ctx context.Context, peerPubKey string, connected bool, realIP net.IP, accountID string) error {
|
||||||
peer, err := am.Store.GetPeerByPeerPubKey(ctx, LockingStrengthShare, peerPubKey)
|
var peer *nbpeer.Peer
|
||||||
|
var settings *Settings
|
||||||
|
var expired bool
|
||||||
|
var err error
|
||||||
|
|
||||||
|
err = am.Store.ExecuteInTransaction(ctx, func(transaction Store) error {
|
||||||
|
peer, err = transaction.GetPeerByPeerPubKey(ctx, LockingStrengthUpdate, peerPubKey)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return err
|
return err
|
||||||
}
|
}
|
||||||
|
|
||||||
settings, err := am.Store.GetAccountSettings(ctx, LockingStrengthShare, accountID)
|
settings, err = transaction.GetAccountSettings(ctx, LockingStrengthShare, accountID)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return err
|
return err
|
||||||
}
|
}
|
||||||
|
|
||||||
expired, err := am.updatePeerStatusAndLocation(ctx, peer, connected, realIP, accountID)
|
expired, err = updatePeerStatusAndLocation(ctx, am.geo, transaction, peer, connected, realIP, accountID)
|
||||||
|
return err
|
||||||
|
})
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return err
|
return err
|
||||||
}
|
}
|
||||||
@ -151,7 +160,7 @@ func (am *DefaultAccountManager) MarkPeerConnected(ctx context.Context, peerPubK
|
|||||||
return nil
|
return nil
|
||||||
}
|
}
|
||||||
|
|
||||||
func (am *DefaultAccountManager) updatePeerStatusAndLocation(ctx context.Context, peer *nbpeer.Peer, connected bool, realIP net.IP, accountID string) (bool, error) {
|
func updatePeerStatusAndLocation(ctx context.Context, geo *geolocation.Geolocation, transaction Store, peer *nbpeer.Peer, connected bool, realIP net.IP, accountID string) (bool, error) {
|
||||||
oldStatus := peer.Status.Copy()
|
oldStatus := peer.Status.Copy()
|
||||||
newStatus := oldStatus
|
newStatus := oldStatus
|
||||||
newStatus.LastSeen = time.Now().UTC()
|
newStatus.LastSeen = time.Now().UTC()
|
||||||
@ -162,8 +171,8 @@ func (am *DefaultAccountManager) updatePeerStatusAndLocation(ctx context.Context
|
|||||||
}
|
}
|
||||||
peer.Status = newStatus
|
peer.Status = newStatus
|
||||||
|
|
||||||
if am.geo != nil && realIP != nil {
|
if geo != nil && realIP != nil {
|
||||||
location, err := am.geo.Lookup(realIP)
|
location, err := geo.Lookup(realIP)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
log.WithContext(ctx).Warnf("failed to get location for peer %s realip: [%s]: %v", peer.ID, realIP.String(), err)
|
log.WithContext(ctx).Warnf("failed to get location for peer %s realip: [%s]: %v", peer.ID, realIP.String(), err)
|
||||||
} else {
|
} else {
|
||||||
@ -171,14 +180,14 @@ func (am *DefaultAccountManager) updatePeerStatusAndLocation(ctx context.Context
|
|||||||
peer.Location.CountryCode = location.Country.ISOCode
|
peer.Location.CountryCode = location.Country.ISOCode
|
||||||
peer.Location.CityName = location.City.Names.En
|
peer.Location.CityName = location.City.Names.En
|
||||||
peer.Location.GeoNameID = location.City.GeonameID
|
peer.Location.GeoNameID = location.City.GeonameID
|
||||||
err = am.Store.SavePeerLocation(ctx, LockingStrengthUpdate, accountID, peer)
|
err = transaction.SavePeerLocation(ctx, LockingStrengthUpdate, accountID, peer)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
log.WithContext(ctx).Warnf("could not store location for peer %s: %s", peer.ID, err)
|
log.WithContext(ctx).Warnf("could not store location for peer %s: %s", peer.ID, err)
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
err := am.Store.SavePeerStatus(ctx, LockingStrengthUpdate, accountID, peer.ID, *newStatus)
|
err := transaction.SavePeerStatus(ctx, LockingStrengthUpdate, accountID, peer.ID, *newStatus)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return false, err
|
return false, err
|
||||||
}
|
}
|
||||||
@ -200,23 +209,49 @@ func (am *DefaultAccountManager) UpdatePeer(ctx context.Context, accountID, user
|
|||||||
return nil, status.NewUserNotPartOfAccountError()
|
return nil, status.NewUserNotPartOfAccountError()
|
||||||
}
|
}
|
||||||
|
|
||||||
peer, err := am.Store.GetPeerByID(ctx, LockingStrengthShare, accountID, update.ID)
|
var peer *nbpeer.Peer
|
||||||
if err != nil {
|
var settings *Settings
|
||||||
return nil, err
|
var peerGroupList []string
|
||||||
}
|
|
||||||
|
|
||||||
settings, err := am.Store.GetAccountSettings(ctx, LockingStrengthShare, accountID)
|
|
||||||
if err != nil {
|
|
||||||
return nil, err
|
|
||||||
}
|
|
||||||
|
|
||||||
peerGroupList, err := am.getPeerGroupIDs(ctx, accountID, update.ID)
|
|
||||||
if err != nil {
|
|
||||||
return nil, err
|
|
||||||
}
|
|
||||||
|
|
||||||
var requiresPeerUpdates bool
|
var requiresPeerUpdates bool
|
||||||
|
var newLabel string
|
||||||
|
|
||||||
|
err = am.Store.ExecuteInTransaction(ctx, func(transaction Store) error {
|
||||||
|
peer, err = transaction.GetPeerByID(ctx, LockingStrengthUpdate, accountID, update.ID)
|
||||||
|
if err != nil {
|
||||||
|
return err
|
||||||
|
}
|
||||||
|
|
||||||
|
settings, err = transaction.GetAccountSettings(ctx, LockingStrengthShare, accountID)
|
||||||
|
if err != nil {
|
||||||
|
return err
|
||||||
|
}
|
||||||
|
|
||||||
|
peerGroupList, err = getPeerGroupIDs(ctx, am.Store, accountID, update.ID)
|
||||||
|
if err != nil {
|
||||||
|
return err
|
||||||
|
}
|
||||||
|
|
||||||
update, requiresPeerUpdates, err = am.integratedPeerValidator.ValidatePeer(ctx, update, peer, userID, accountID, am.GetDNSDomain(), peerGroupList, settings.Extra)
|
update, requiresPeerUpdates, err = am.integratedPeerValidator.ValidatePeer(ctx, update, peer, userID, accountID, am.GetDNSDomain(), peerGroupList, settings.Extra)
|
||||||
|
if err != nil {
|
||||||
|
return err
|
||||||
|
}
|
||||||
|
|
||||||
|
if peer.Name != update.Name {
|
||||||
|
existingLabels, err := getPeerDNSLabels(ctx, transaction, accountID)
|
||||||
|
if err != nil {
|
||||||
|
return err
|
||||||
|
}
|
||||||
|
|
||||||
|
newLabel, err = getPeerHostLabel(update.Name, existingLabels)
|
||||||
|
if err != nil {
|
||||||
|
return err
|
||||||
|
}
|
||||||
|
|
||||||
|
peer.DNSLabel = newLabel
|
||||||
|
}
|
||||||
|
|
||||||
|
return transaction.SavePeer(ctx, LockingStrengthUpdate, accountID, peer)
|
||||||
|
})
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return nil, err
|
return nil, err
|
||||||
}
|
}
|
||||||
@ -231,18 +266,6 @@ func (am *DefaultAccountManager) UpdatePeer(ctx context.Context, accountID, user
|
|||||||
if peer.Name != update.Name {
|
if peer.Name != update.Name {
|
||||||
peer.Name = update.Name
|
peer.Name = update.Name
|
||||||
peerLabelChanged = true
|
peerLabelChanged = true
|
||||||
|
|
||||||
existingLabels, err := am.getPeerDNSLabels(ctx, accountID)
|
|
||||||
if err != nil {
|
|
||||||
return nil, err
|
|
||||||
}
|
|
||||||
|
|
||||||
newLabel, err := getPeerHostLabel(peer.Name, existingLabels)
|
|
||||||
if err != nil {
|
|
||||||
return nil, err
|
|
||||||
}
|
|
||||||
|
|
||||||
peer.DNSLabel = newLabel
|
|
||||||
}
|
}
|
||||||
|
|
||||||
if peer.LoginExpirationEnabled != update.LoginExpirationEnabled {
|
if peer.LoginExpirationEnabled != update.LoginExpirationEnabled {
|
||||||
@ -261,10 +284,6 @@ func (am *DefaultAccountManager) UpdatePeer(ctx context.Context, accountID, user
|
|||||||
inactivityExpirationChanged = true
|
inactivityExpirationChanged = true
|
||||||
}
|
}
|
||||||
|
|
||||||
if err = am.Store.SavePeer(ctx, LockingStrengthUpdate, accountID, peer); err != nil {
|
|
||||||
return nil, err
|
|
||||||
}
|
|
||||||
|
|
||||||
if sshChanged {
|
if sshChanged {
|
||||||
event := activity.PeerSSHEnabled
|
event := activity.PeerSSHEnabled
|
||||||
if !peer.SSHEnabled {
|
if !peer.SSHEnabled {
|
||||||
@ -313,13 +332,18 @@ func (am *DefaultAccountManager) DeletePeer(ctx context.Context, accountID, peer
|
|||||||
unlock := am.Store.AcquireWriteLockByUID(ctx, accountID)
|
unlock := am.Store.AcquireWriteLockByUID(ctx, accountID)
|
||||||
defer unlock()
|
defer unlock()
|
||||||
|
|
||||||
updateAccountPeers, err := am.isPeerInActiveGroup(ctx, accountID, peerID)
|
peerAccountID, err := am.Store.GetAccountIDByPeerID(ctx, LockingStrengthShare, peerID)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return err
|
return err
|
||||||
}
|
}
|
||||||
|
|
||||||
|
if peerAccountID != accountID {
|
||||||
|
return status.NewPeerNotPartOfAccountError()
|
||||||
|
}
|
||||||
|
|
||||||
var peer *nbpeer.Peer
|
var peer *nbpeer.Peer
|
||||||
var addPeerRemovedEvents []func()
|
var updateAccountPeers bool
|
||||||
|
var eventsToStore []func()
|
||||||
|
|
||||||
err = am.Store.ExecuteInTransaction(ctx, func(transaction Store) error {
|
err = am.Store.ExecuteInTransaction(ctx, func(transaction Store) error {
|
||||||
peer, err = transaction.GetPeerByID(ctx, LockingStrengthUpdate, accountID, peerID)
|
peer, err = transaction.GetPeerByID(ctx, LockingStrengthUpdate, accountID, peerID)
|
||||||
@ -327,16 +351,21 @@ func (am *DefaultAccountManager) DeletePeer(ctx context.Context, accountID, peer
|
|||||||
return err
|
return err
|
||||||
}
|
}
|
||||||
|
|
||||||
addPeerRemovedEvents, err = deletePeers(ctx, am, transaction, accountID, userID, []*nbpeer.Peer{peer})
|
updateAccountPeers, err = isPeerInActiveGroup(ctx, transaction, accountID, peerID)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return err
|
return err
|
||||||
}
|
}
|
||||||
|
|
||||||
return transaction.IncrementNetworkSerial(ctx, LockingStrengthUpdate, accountID)
|
if err = transaction.IncrementNetworkSerial(ctx, LockingStrengthUpdate, accountID); err != nil {
|
||||||
|
return err
|
||||||
|
}
|
||||||
|
|
||||||
|
eventsToStore, err = deletePeers(ctx, am, transaction, accountID, userID, []*nbpeer.Peer{peer})
|
||||||
|
return err
|
||||||
})
|
})
|
||||||
|
|
||||||
for _, addPeerRemovedEvent := range addPeerRemovedEvents {
|
for _, storeEvent := range eventsToStore {
|
||||||
addPeerRemovedEvent()
|
storeEvent()
|
||||||
}
|
}
|
||||||
|
|
||||||
if updateAccountPeers {
|
if updateAccountPeers {
|
||||||
@ -433,6 +462,7 @@ func (am *DefaultAccountManager) AddPeer(ctx context.Context, setupKey, userID s
|
|||||||
}
|
}
|
||||||
|
|
||||||
var newPeer *nbpeer.Peer
|
var newPeer *nbpeer.Peer
|
||||||
|
var updateAccountPeers bool
|
||||||
|
|
||||||
err = am.Store.ExecuteInTransaction(ctx, func(transaction Store) error {
|
err = am.Store.ExecuteInTransaction(ctx, func(transaction Store) error {
|
||||||
var setupKeyID string
|
var setupKeyID string
|
||||||
@ -480,7 +510,7 @@ func (am *DefaultAccountManager) AddPeer(ctx context.Context, setupKey, userID s
|
|||||||
return fmt.Errorf("failed to get free DNS label: %w", err)
|
return fmt.Errorf("failed to get free DNS label: %w", err)
|
||||||
}
|
}
|
||||||
|
|
||||||
freeIP, err := am.getFreeIP(ctx, transaction, accountID)
|
freeIP, err := getFreeIP(ctx, transaction, accountID)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return fmt.Errorf("failed to get free IP: %w", err)
|
return fmt.Errorf("failed to get free IP: %w", err)
|
||||||
}
|
}
|
||||||
@ -564,6 +594,11 @@ func (am *DefaultAccountManager) AddPeer(ctx context.Context, setupKey, userID s
|
|||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
|
updateAccountPeers, err = isPeerInActiveGroup(ctx, transaction, accountID, newPeer.ID)
|
||||||
|
if err != nil {
|
||||||
|
return err
|
||||||
|
}
|
||||||
|
|
||||||
log.WithContext(ctx).Debugf("Peer %s added to account %s", newPeer.ID, accountID)
|
log.WithContext(ctx).Debugf("Peer %s added to account %s", newPeer.ID, accountID)
|
||||||
return nil
|
return nil
|
||||||
})
|
})
|
||||||
@ -581,11 +616,6 @@ func (am *DefaultAccountManager) AddPeer(ctx context.Context, setupKey, userID s
|
|||||||
unlock()
|
unlock()
|
||||||
unlock = nil
|
unlock = nil
|
||||||
|
|
||||||
updateAccountPeers, err := am.isPeerInActiveGroup(ctx, accountID, newPeer.ID)
|
|
||||||
if err != nil {
|
|
||||||
return nil, nil, nil, err
|
|
||||||
}
|
|
||||||
|
|
||||||
if updateAccountPeers {
|
if updateAccountPeers {
|
||||||
am.updateAccountPeers(ctx, accountID)
|
am.updateAccountPeers(ctx, accountID)
|
||||||
}
|
}
|
||||||
@ -593,13 +623,13 @@ func (am *DefaultAccountManager) AddPeer(ctx context.Context, setupKey, userID s
|
|||||||
return am.getValidatedPeerWithMap(ctx, false, accountID, newPeer)
|
return am.getValidatedPeerWithMap(ctx, false, accountID, newPeer)
|
||||||
}
|
}
|
||||||
|
|
||||||
func (am *DefaultAccountManager) getFreeIP(ctx context.Context, store Store, accountID string) (net.IP, error) {
|
func getFreeIP(ctx context.Context, transaction Store, accountID string) (net.IP, error) {
|
||||||
takenIps, err := store.GetTakenIPs(ctx, LockingStrengthShare, accountID)
|
takenIps, err := transaction.GetTakenIPs(ctx, LockingStrengthShare, accountID)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return nil, fmt.Errorf("failed to get taken IPs: %w", err)
|
return nil, fmt.Errorf("failed to get taken IPs: %w", err)
|
||||||
}
|
}
|
||||||
|
|
||||||
network, err := store.GetAccountNetwork(ctx, LockingStrengthUpdate, accountID)
|
network, err := transaction.GetAccountNetwork(ctx, LockingStrengthUpdate, accountID)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return nil, fmt.Errorf("failed getting network: %w", err)
|
return nil, fmt.Errorf("failed getting network: %w", err)
|
||||||
}
|
}
|
||||||
@ -614,49 +644,60 @@ func (am *DefaultAccountManager) getFreeIP(ctx context.Context, store Store, acc
|
|||||||
|
|
||||||
// SyncPeer checks whether peer is eligible for receiving NetworkMap (authenticated) and returns its NetworkMap if eligible
|
// SyncPeer checks whether peer is eligible for receiving NetworkMap (authenticated) and returns its NetworkMap if eligible
|
||||||
func (am *DefaultAccountManager) SyncPeer(ctx context.Context, sync PeerSync, accountID string) (*nbpeer.Peer, *NetworkMap, []*posture.Checks, error) {
|
func (am *DefaultAccountManager) SyncPeer(ctx context.Context, sync PeerSync, accountID string) (*nbpeer.Peer, *NetworkMap, []*posture.Checks, error) {
|
||||||
peer, err := am.Store.GetPeerByPeerPubKey(ctx, LockingStrengthShare, sync.WireGuardPubKey)
|
var peer *nbpeer.Peer
|
||||||
|
var peerNotValid bool
|
||||||
|
var isStatusChanged bool
|
||||||
|
var updated bool
|
||||||
|
var err error
|
||||||
|
|
||||||
|
err = am.Store.ExecuteInTransaction(ctx, func(transaction Store) error {
|
||||||
|
peer, err = transaction.GetPeerByPeerPubKey(ctx, LockingStrengthUpdate, sync.WireGuardPubKey)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return nil, nil, nil, status.NewPeerNotRegisteredError()
|
return status.NewPeerNotRegisteredError()
|
||||||
}
|
}
|
||||||
|
|
||||||
if peer.UserID != "" {
|
if peer.UserID != "" {
|
||||||
user, err := am.Store.GetUserByUserID(ctx, LockingStrengthShare, peer.UserID)
|
user, err := transaction.GetUserByUserID(ctx, LockingStrengthShare, peer.UserID)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return nil, nil, nil, err
|
return err
|
||||||
}
|
}
|
||||||
|
|
||||||
err = checkIfPeerOwnerIsBlocked(peer, user)
|
if err = checkIfPeerOwnerIsBlocked(peer, user); err != nil {
|
||||||
if err != nil {
|
return err
|
||||||
return nil, nil, nil, err
|
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
settings, err := am.Store.GetAccountSettings(ctx, LockingStrengthShare, accountID)
|
settings, err := transaction.GetAccountSettings(ctx, LockingStrengthShare, accountID)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return nil, nil, nil, err
|
return err
|
||||||
}
|
}
|
||||||
|
|
||||||
if peerLoginExpired(ctx, peer, settings) {
|
if peerLoginExpired(ctx, peer, settings) {
|
||||||
return nil, nil, nil, status.NewPeerLoginExpiredError()
|
return status.NewPeerLoginExpiredError()
|
||||||
}
|
}
|
||||||
|
|
||||||
peerGroupList, err := am.getPeerGroupIDs(ctx, accountID, peer.ID)
|
peerGroupIDs, err := getPeerGroupIDs(ctx, transaction, accountID, peer.ID)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return nil, nil, nil, err
|
return err
|
||||||
}
|
}
|
||||||
|
|
||||||
peerNotValid, isStatusChanged, err := am.integratedPeerValidator.IsNotValidPeer(ctx, accountID, peer, peerGroupList, settings.Extra)
|
peerNotValid, isStatusChanged, err = am.integratedPeerValidator.IsNotValidPeer(ctx, accountID, peer, peerGroupIDs, settings.Extra)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return nil, nil, nil, err
|
return err
|
||||||
}
|
}
|
||||||
|
|
||||||
updated := peer.UpdateMetaIfNew(sync.Meta)
|
updated = peer.UpdateMetaIfNew(sync.Meta)
|
||||||
if updated {
|
if updated {
|
||||||
err = am.Store.SavePeer(ctx, LockingStrengthUpdate, accountID, peer)
|
err = transaction.SavePeer(ctx, LockingStrengthUpdate, accountID, peer)
|
||||||
|
if err != nil {
|
||||||
|
return err
|
||||||
|
}
|
||||||
|
}
|
||||||
|
return nil
|
||||||
|
})
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return nil, nil, nil, err
|
return nil, nil, nil, err
|
||||||
}
|
}
|
||||||
}
|
|
||||||
|
|
||||||
if isStatusChanged || (updated && sync.UpdateAccountPeers) {
|
if isStatusChanged || (updated && sync.UpdateAccountPeers) {
|
||||||
am.updateAccountPeers(ctx, accountID)
|
am.updateAccountPeers(ctx, accountID)
|
||||||
@ -707,54 +748,49 @@ func (am *DefaultAccountManager) LoginPeer(ctx context.Context, login PeerLogin)
|
|||||||
}
|
}
|
||||||
}()
|
}()
|
||||||
|
|
||||||
peer, err := am.Store.GetPeerByPeerPubKey(ctx, LockingStrengthUpdate, login.WireGuardPubKey)
|
var peer *nbpeer.Peer
|
||||||
|
var updateRemotePeers bool
|
||||||
|
var isRequiresApproval bool
|
||||||
|
var isStatusChanged bool
|
||||||
|
|
||||||
|
err = am.Store.ExecuteInTransaction(ctx, func(transaction Store) error {
|
||||||
|
peer, err = transaction.GetPeerByPeerPubKey(ctx, LockingStrengthUpdate, login.WireGuardPubKey)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return nil, nil, nil, err
|
return err
|
||||||
}
|
}
|
||||||
|
|
||||||
settings, err := am.Store.GetAccountSettings(ctx, LockingStrengthShare, accountID)
|
settings, err := transaction.GetAccountSettings(ctx, LockingStrengthShare, accountID)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return nil, nil, nil, err
|
return err
|
||||||
}
|
}
|
||||||
|
|
||||||
// this flag prevents unnecessary calls to the persistent store.
|
// this flag prevents unnecessary calls to the persistent store.
|
||||||
shouldStorePeer := false
|
shouldStorePeer := false
|
||||||
updateRemotePeers := false
|
|
||||||
|
|
||||||
if login.UserID != "" {
|
if login.UserID != "" {
|
||||||
if peer.UserID != login.UserID {
|
if peer.UserID != login.UserID {
|
||||||
log.Warnf("user mismatch when logging in peer %s: peer user %s, login user %s ", peer.ID, peer.UserID, login.UserID)
|
log.Warnf("user mismatch when logging in peer %s: peer user %s, login user %s ", peer.ID, peer.UserID, login.UserID)
|
||||||
return nil, nil, nil, status.Errorf(status.Unauthenticated, "invalid user")
|
return status.Errorf(status.Unauthenticated, "invalid user")
|
||||||
}
|
}
|
||||||
|
|
||||||
changed, err := am.handleUserPeer(ctx, peer, settings)
|
changed, err := am.handleUserPeer(ctx, transaction, peer, settings)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return nil, nil, nil, err
|
return err
|
||||||
}
|
}
|
||||||
|
|
||||||
if changed {
|
if changed {
|
||||||
shouldStorePeer = true
|
shouldStorePeer = true
|
||||||
updateRemotePeers = true
|
updateRemotePeers = true
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
groups, err := am.Store.GetAccountGroups(ctx, LockingStrengthShare, accountID)
|
peerGroupIDs, err := getPeerGroupIDs(ctx, am.Store, accountID, peer.ID)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return nil, nil, nil, err
|
return err
|
||||||
}
|
}
|
||||||
|
|
||||||
var grps []string
|
isRequiresApproval, isStatusChanged, err = am.integratedPeerValidator.IsNotValidPeer(ctx, accountID, peer, peerGroupIDs, settings.Extra)
|
||||||
for _, group := range groups {
|
|
||||||
for _, id := range group.Peers {
|
|
||||||
if id == peer.ID {
|
|
||||||
grps = append(grps, group.ID)
|
|
||||||
break
|
|
||||||
}
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
isRequiresApproval, isStatusChanged, err := am.integratedPeerValidator.IsNotValidPeer(ctx, accountID, peer, grps, settings.Extra)
|
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return nil, nil, nil, err
|
return err
|
||||||
}
|
}
|
||||||
|
|
||||||
updated := peer.UpdateMetaIfNew(login.Meta)
|
updated := peer.UpdateMetaIfNew(login.Meta)
|
||||||
@ -768,11 +804,16 @@ func (am *DefaultAccountManager) LoginPeer(ctx context.Context, login PeerLogin)
|
|||||||
}
|
}
|
||||||
|
|
||||||
if shouldStorePeer {
|
if shouldStorePeer {
|
||||||
err = am.Store.SavePeer(ctx, LockingStrengthUpdate, accountID, peer)
|
if err = transaction.SavePeer(ctx, LockingStrengthUpdate, accountID, peer); err != nil {
|
||||||
|
return err
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
return nil
|
||||||
|
})
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return nil, nil, nil, err
|
return nil, nil, nil, err
|
||||||
}
|
}
|
||||||
}
|
|
||||||
|
|
||||||
unlockPeer()
|
unlockPeer()
|
||||||
unlockPeer = nil
|
unlockPeer = nil
|
||||||
@ -845,7 +886,7 @@ func (am *DefaultAccountManager) getValidatedPeerWithMap(ctx context.Context, is
|
|||||||
return peer, account.GetPeerNetworkMap(ctx, peer.ID, customZone, approvedPeersMap, am.metrics.AccountManagerMetrics()), postureChecks, nil
|
return peer, account.GetPeerNetworkMap(ctx, peer.ID, customZone, approvedPeersMap, am.metrics.AccountManagerMetrics()), postureChecks, nil
|
||||||
}
|
}
|
||||||
|
|
||||||
func (am *DefaultAccountManager) handleExpiredPeer(ctx context.Context, user *User, peer *nbpeer.Peer) error {
|
func (am *DefaultAccountManager) handleExpiredPeer(ctx context.Context, transaction Store, user *User, peer *nbpeer.Peer) error {
|
||||||
err := checkAuth(ctx, user.Id, peer)
|
err := checkAuth(ctx, user.Id, peer)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return err
|
return err
|
||||||
@ -853,12 +894,12 @@ func (am *DefaultAccountManager) handleExpiredPeer(ctx context.Context, user *Us
|
|||||||
// If peer was expired before and if it reached this point, it is re-authenticated.
|
// If peer was expired before and if it reached this point, it is re-authenticated.
|
||||||
// UserID is present, meaning that JWT validation passed successfully in the API layer.
|
// UserID is present, meaning that JWT validation passed successfully in the API layer.
|
||||||
peer = peer.UpdateLastLogin()
|
peer = peer.UpdateLastLogin()
|
||||||
err = am.Store.SavePeer(ctx, LockingStrengthUpdate, peer.AccountID, peer)
|
err = transaction.SavePeer(ctx, LockingStrengthUpdate, peer.AccountID, peer)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return err
|
return err
|
||||||
}
|
}
|
||||||
|
|
||||||
err = am.Store.SaveUserLastLogin(ctx, user.AccountID, user.Id, peer.LastLogin)
|
err = transaction.SaveUserLastLogin(ctx, user.AccountID, user.Id, peer.LastLogin)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return err
|
return err
|
||||||
}
|
}
|
||||||
@ -1149,7 +1190,12 @@ func (am *DefaultAccountManager) getInactivePeers(ctx context.Context, accountID
|
|||||||
|
|
||||||
// GetPeerGroups returns groups that the peer is part of.
|
// GetPeerGroups returns groups that the peer is part of.
|
||||||
func (am *DefaultAccountManager) GetPeerGroups(ctx context.Context, accountID, peerID string) ([]*nbgroup.Group, error) {
|
func (am *DefaultAccountManager) GetPeerGroups(ctx context.Context, accountID, peerID string) ([]*nbgroup.Group, error) {
|
||||||
groups, err := am.Store.GetAccountGroups(ctx, LockingStrengthShare, accountID)
|
return getPeerGroups(ctx, am.Store, accountID, peerID)
|
||||||
|
}
|
||||||
|
|
||||||
|
// getPeerGroups returns the IDs of the groups that the peer is part of.
|
||||||
|
func getPeerGroups(ctx context.Context, transaction Store, accountID, peerID string) ([]*nbgroup.Group, error) {
|
||||||
|
groups, err := transaction.GetAccountGroups(ctx, LockingStrengthShare, accountID)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return nil, err
|
return nil, err
|
||||||
}
|
}
|
||||||
@ -1165,8 +1211,8 @@ func (am *DefaultAccountManager) GetPeerGroups(ctx context.Context, accountID, p
|
|||||||
}
|
}
|
||||||
|
|
||||||
// getPeerGroupIDs returns the IDs of the groups that the peer is part of.
|
// getPeerGroupIDs returns the IDs of the groups that the peer is part of.
|
||||||
func (am *DefaultAccountManager) getPeerGroupIDs(ctx context.Context, accountID string, peerID string) ([]string, error) {
|
func getPeerGroupIDs(ctx context.Context, transaction Store, accountID string, peerID string) ([]string, error) {
|
||||||
groups, err := am.GetPeerGroups(ctx, accountID, peerID)
|
groups, err := getPeerGroups(ctx, transaction, accountID, peerID)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return nil, err
|
return nil, err
|
||||||
}
|
}
|
||||||
@ -1179,8 +1225,8 @@ func (am *DefaultAccountManager) getPeerGroupIDs(ctx context.Context, accountID
|
|||||||
return groupIDs, err
|
return groupIDs, err
|
||||||
}
|
}
|
||||||
|
|
||||||
func (am *DefaultAccountManager) getPeerDNSLabels(ctx context.Context, accountID string) (lookupMap, error) {
|
func getPeerDNSLabels(ctx context.Context, transaction Store, accountID string) (lookupMap, error) {
|
||||||
dnsLabels, err := am.Store.GetAccountPeerDNSLabels(ctx, LockingStrengthShare, accountID)
|
dnsLabels, err := transaction.GetAccountPeerDNSLabels(ctx, LockingStrengthShare, accountID)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return nil, err
|
return nil, err
|
||||||
}
|
}
|
||||||
@ -1194,12 +1240,12 @@ func (am *DefaultAccountManager) getPeerDNSLabels(ctx context.Context, accountID
|
|||||||
|
|
||||||
// IsPeerInActiveGroup checks if the given peer is part of a group that is used
|
// IsPeerInActiveGroup checks if the given peer is part of a group that is used
|
||||||
// in an active DNS, route, or ACL configuration.
|
// in an active DNS, route, or ACL configuration.
|
||||||
func (am *DefaultAccountManager) isPeerInActiveGroup(ctx context.Context, accountID, peerID string) (bool, error) {
|
func isPeerInActiveGroup(ctx context.Context, transaction Store, accountID, peerID string) (bool, error) {
|
||||||
peerGroupIDs, err := am.getPeerGroupIDs(ctx, accountID, peerID)
|
peerGroupIDs, err := getPeerGroupIDs(ctx, transaction, accountID, peerID)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return false, err
|
return false, err
|
||||||
}
|
}
|
||||||
return areGroupChangesAffectPeers(ctx, am.Store, accountID, peerGroupIDs) // TODO: use transaction
|
return areGroupChangesAffectPeers(ctx, transaction, accountID, peerGroupIDs) // TODO: use transaction
|
||||||
}
|
}
|
||||||
|
|
||||||
// deletePeers deletes all specified peers and sends updates to the remote peers.
|
// deletePeers deletes all specified peers and sends updates to the remote peers.
|
||||||
|
@ -754,6 +754,20 @@ func (s *SqlStore) GetAccountIDByUserID(ctx context.Context, lockStrength Lockin
|
|||||||
return accountID, nil
|
return accountID, nil
|
||||||
}
|
}
|
||||||
|
|
||||||
|
func (s *SqlStore) GetAccountIDByPeerID(ctx context.Context, lockStrength LockingStrength, peerID string) (string, error) {
|
||||||
|
var accountID string
|
||||||
|
result := s.db.Clauses(clause.Locking{Strength: string(lockStrength)}).Model(&nbpeer.Peer{}).
|
||||||
|
Select("account_id").Where(idQueryCondition, peerID).First(&accountID)
|
||||||
|
if result.Error != nil {
|
||||||
|
if errors.Is(result.Error, gorm.ErrRecordNotFound) {
|
||||||
|
return "", status.Errorf(status.NotFound, "peer %s account not found", peerID)
|
||||||
|
}
|
||||||
|
return "", status.NewGetAccountFromStoreError(result.Error)
|
||||||
|
}
|
||||||
|
|
||||||
|
return accountID, nil
|
||||||
|
}
|
||||||
|
|
||||||
func (s *SqlStore) GetAccountIDBySetupKey(ctx context.Context, setupKey string) (string, error) {
|
func (s *SqlStore) GetAccountIDBySetupKey(ctx context.Context, setupKey string) (string, error) {
|
||||||
var accountID string
|
var accountID string
|
||||||
result := s.db.Model(&SetupKey{}).Select("account_id").Where(keyQueryCondition, setupKey).First(&accountID)
|
result := s.db.Model(&SetupKey{}).Select("account_id").Where(keyQueryCondition, setupKey).First(&accountID)
|
||||||
|
@ -86,6 +86,11 @@ func NewAccountNotFoundError(accountKey string) error {
|
|||||||
return Errorf(NotFound, "account not found: %s", accountKey)
|
return Errorf(NotFound, "account not found: %s", accountKey)
|
||||||
}
|
}
|
||||||
|
|
||||||
|
// NewPeerNotPartOfAccountError creates a new Error with PermissionDenied type for a peer not being part of an account
|
||||||
|
func NewPeerNotPartOfAccountError() error {
|
||||||
|
return Errorf(PermissionDenied, "peer is not part of this account")
|
||||||
|
}
|
||||||
|
|
||||||
// NewUserNotFoundError creates a new Error with NotFound type for a missing user
|
// NewUserNotFoundError creates a new Error with NotFound type for a missing user
|
||||||
func NewUserNotFoundError(userKey string) error {
|
func NewUserNotFoundError(userKey string) error {
|
||||||
return Errorf(NotFound, "user not found: %s", userKey)
|
return Errorf(NotFound, "user not found: %s", userKey)
|
||||||
|
@ -50,6 +50,7 @@ type Store interface {
|
|||||||
GetAccountIDByPeerPubKey(ctx context.Context, peerKey string) (string, error)
|
GetAccountIDByPeerPubKey(ctx context.Context, peerKey string) (string, error)
|
||||||
GetAccountIDByUserID(ctx context.Context, lockStrength LockingStrength, userID string) (string, error)
|
GetAccountIDByUserID(ctx context.Context, lockStrength LockingStrength, userID string) (string, error)
|
||||||
GetAccountIDBySetupKey(ctx context.Context, peerKey string) (string, error)
|
GetAccountIDBySetupKey(ctx context.Context, peerKey string) (string, error)
|
||||||
|
GetAccountIDByPeerID(ctx context.Context, lockStrength LockingStrength, peerID string) (string, error)
|
||||||
GetAccountByPeerID(ctx context.Context, peerID string) (*Account, error)
|
GetAccountByPeerID(ctx context.Context, peerID string) (*Account, error)
|
||||||
GetAccountBySetupKey(ctx context.Context, setupKey string) (*Account, error) // todo use key hash later
|
GetAccountBySetupKey(ctx context.Context, setupKey string) (*Account, error) // todo use key hash later
|
||||||
GetAccountByPrivateDomain(ctx context.Context, domain string) (*Account, error)
|
GetAccountByPrivateDomain(ctx context.Context, domain string) (*Account, error)
|
||||||
|
@ -524,6 +524,10 @@ func (am *DefaultAccountManager) deleteUserPeers(ctx context.Context, initiatorU
|
|||||||
storeEvent()
|
storeEvent()
|
||||||
}
|
}
|
||||||
|
|
||||||
|
for _, peer := range peers {
|
||||||
|
account.DeletePeer(peer.ID)
|
||||||
|
}
|
||||||
|
|
||||||
return hadPeers, nil
|
return hadPeers, nil
|
||||||
}
|
}
|
||||||
|
|
||||||
|
Loading…
x
Reference in New Issue
Block a user