mirror of
https://github.com/netbirdio/netbird.git
synced 2025-08-09 23:27:58 +02:00
Refactor ephemeral peers and mark PAT as used
Signed-off-by: bcmmbaga <bethuelmbaga12@gmail.com>
This commit is contained in:
@ -1865,33 +1865,7 @@ func (am *DefaultAccountManager) redeemInvite(ctx context.Context, accountID str
|
||||
|
||||
// MarkPATUsed marks a personal access token as used
|
||||
func (am *DefaultAccountManager) MarkPATUsed(ctx context.Context, tokenID string) error {
|
||||
|
||||
user, err := am.Store.GetUserByTokenID(ctx, tokenID)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
account, err := am.Store.GetAccountByUser(ctx, user.Id)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
unlock := am.Store.AcquireWriteLockByUID(ctx, account.Id)
|
||||
defer unlock()
|
||||
|
||||
account, err = am.Store.GetAccountByUser(ctx, user.Id)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
pat, ok := account.Users[user.Id].PATs[tokenID]
|
||||
if !ok {
|
||||
return fmt.Errorf("token not found")
|
||||
}
|
||||
|
||||
pat.LastUsed = time.Now().UTC()
|
||||
|
||||
return am.Store.SaveAccount(ctx, account)
|
||||
return am.Store.MarkPATUsed(ctx, LockingStrengthUpdate, tokenID)
|
||||
}
|
||||
|
||||
// GetAccount returns an account associated with this account ID.
|
||||
|
@ -20,10 +20,10 @@ var (
|
||||
)
|
||||
|
||||
type ephemeralPeer struct {
|
||||
id string
|
||||
account *Account
|
||||
deadline time.Time
|
||||
next *ephemeralPeer
|
||||
id string
|
||||
accountID string
|
||||
deadline time.Time
|
||||
next *ephemeralPeer
|
||||
}
|
||||
|
||||
// todo: consider to remove peer from ephemeral list when the peer has been deleted via API. If we do not do it
|
||||
@ -104,12 +104,6 @@ func (e *EphemeralManager) OnPeerDisconnected(ctx context.Context, peer *nbpeer.
|
||||
|
||||
log.WithContext(ctx).Tracef("add peer to ephemeral list: %s", peer.ID)
|
||||
|
||||
a, err := e.store.GetAccountByPeerID(context.Background(), peer.ID)
|
||||
if err != nil {
|
||||
log.WithContext(ctx).Errorf("failed to add peer to ephemeral list: %s", err)
|
||||
return
|
||||
}
|
||||
|
||||
e.peersLock.Lock()
|
||||
defer e.peersLock.Unlock()
|
||||
|
||||
@ -117,7 +111,7 @@ func (e *EphemeralManager) OnPeerDisconnected(ctx context.Context, peer *nbpeer.
|
||||
return
|
||||
}
|
||||
|
||||
e.addPeer(peer.ID, a, newDeadLine())
|
||||
e.addPeer(peer.AccountID, peer.ID, newDeadLine())
|
||||
if e.timer == nil {
|
||||
e.timer = time.AfterFunc(e.headPeer.deadline.Sub(timeNow()), func() {
|
||||
e.cleanup(ctx)
|
||||
@ -126,17 +120,21 @@ func (e *EphemeralManager) OnPeerDisconnected(ctx context.Context, peer *nbpeer.
|
||||
}
|
||||
|
||||
func (e *EphemeralManager) loadEphemeralPeers(ctx context.Context) {
|
||||
accounts := e.store.GetAllAccounts(context.Background())
|
||||
peers, err := e.store.GetAllEphemeralPeers(ctx, LockingStrengthShare)
|
||||
if err != nil {
|
||||
log.WithContext(ctx).Debugf("failed to load ephemeral peers: %s", err)
|
||||
return
|
||||
}
|
||||
|
||||
t := newDeadLine()
|
||||
count := 0
|
||||
for _, a := range accounts {
|
||||
for id, p := range a.Peers {
|
||||
if p.Ephemeral {
|
||||
count++
|
||||
e.addPeer(id, a, t)
|
||||
}
|
||||
for _, p := range peers {
|
||||
if p.Ephemeral {
|
||||
count++
|
||||
e.addPeer(p.AccountID, p.ID, t)
|
||||
}
|
||||
}
|
||||
|
||||
log.WithContext(ctx).Debugf("loaded ephemeral peer(s): %d", count)
|
||||
}
|
||||
|
||||
@ -170,18 +168,18 @@ func (e *EphemeralManager) cleanup(ctx context.Context) {
|
||||
|
||||
for id, p := range deletePeers {
|
||||
log.WithContext(ctx).Debugf("delete ephemeral peer: %s", id)
|
||||
err := e.accountManager.DeletePeer(ctx, p.account.Id, id, activity.SystemInitiator)
|
||||
err := e.accountManager.DeletePeer(ctx, p.accountID, id, activity.SystemInitiator)
|
||||
if err != nil {
|
||||
log.WithContext(ctx).Errorf("failed to delete ephemeral peer: %s", err)
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
func (e *EphemeralManager) addPeer(id string, account *Account, deadline time.Time) {
|
||||
func (e *EphemeralManager) addPeer(accountID string, peerID string, deadline time.Time) {
|
||||
ep := &ephemeralPeer{
|
||||
id: id,
|
||||
account: account,
|
||||
deadline: deadline,
|
||||
id: peerID,
|
||||
accountID: accountID,
|
||||
deadline: deadline,
|
||||
}
|
||||
|
||||
if e.headPeer == nil {
|
||||
|
@ -363,12 +363,12 @@ func (am *DefaultAccountManager) deletePeers(ctx context.Context, accountID stri
|
||||
|
||||
// DeletePeer removes peer from the account by its IP
|
||||
func (am *DefaultAccountManager) DeletePeer(ctx context.Context, accountID, peerID, userID string) error {
|
||||
user, err := am.Store.GetUserByUserID(ctx, LockingStrengthShare, userID)
|
||||
peerAccountID, err := am.Store.GetAccountIDByPeerID(ctx, LockingStrengthShare, peerID)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
if user.AccountID != accountID {
|
||||
if peerAccountID != accountID {
|
||||
return status.NewUserNotPartOfAccountError()
|
||||
}
|
||||
|
||||
|
@ -44,7 +44,7 @@ type Peer struct {
|
||||
// CreatedAt records the time the peer was created
|
||||
CreatedAt time.Time `diff:"-"`
|
||||
// Indicate ephemeral peer attribute
|
||||
Ephemeral bool `diff:"-"`
|
||||
Ephemeral bool `gorm:"index" diff:"-"`
|
||||
// Geo location based on connection IP
|
||||
Location Location `gorm:"embedded;embeddedPrefix:location_" diff:"-"`
|
||||
}
|
||||
|
@ -39,6 +39,7 @@ const (
|
||||
accountAndIDQueryCondition = "account_id = ? and id = ?"
|
||||
accountIDCondition = "account_id = ?"
|
||||
peerNotFoundFMT = "peer %s not found"
|
||||
batchSize = 500
|
||||
)
|
||||
|
||||
// SqlStore represents an account storage backed by a Sql DB persisted to disk
|
||||
@ -592,7 +593,7 @@ func (s *SqlStore) GetAccount(ctx context.Context, accountID string) (*Account,
|
||||
if result.Error != nil {
|
||||
log.WithContext(ctx).Errorf("error when getting account %s from the store: %s", accountID, result.Error)
|
||||
if errors.Is(result.Error, gorm.ErrRecordNotFound) {
|
||||
return nil, status.NewAccountNotFoundError(accountID)
|
||||
return nil, status.NewAccountNotFoundError()
|
||||
}
|
||||
return nil, status.NewGetAccountFromStoreError(result.Error)
|
||||
}
|
||||
@ -708,7 +709,7 @@ func (s *SqlStore) GetAccountIDByPeerPubKey(ctx context.Context, peerKey string)
|
||||
result := s.db.WithContext(ctx).Model(&peer).Select("account_id").Where(keyQueryCondition, peerKey).First(&accountID)
|
||||
if result.Error != nil {
|
||||
if errors.Is(result.Error, gorm.ErrRecordNotFound) {
|
||||
return "", status.Errorf(status.NotFound, "account not found: index lookup failed")
|
||||
return "", status.NewAccountNotFoundError()
|
||||
}
|
||||
return "", status.NewGetAccountFromStoreError(result.Error)
|
||||
}
|
||||
@ -719,6 +720,21 @@ func (s *SqlStore) GetAccountIDByPeerPubKey(ctx context.Context, peerKey string)
|
||||
func (s *SqlStore) GetAccountIDByUserID(userID string) (string, error) {
|
||||
var accountID string
|
||||
result := s.db.Model(&User{}).Select("account_id").Where(idQueryCondition, userID).First(&accountID)
|
||||
if result.Error != nil {
|
||||
if errors.Is(result.Error, gorm.ErrRecordNotFound) {
|
||||
return "", status.NewAccountNotFoundError()
|
||||
}
|
||||
|
||||
return "", status.NewGetAccountFromStoreError(result.Error)
|
||||
}
|
||||
|
||||
return accountID, nil
|
||||
}
|
||||
|
||||
func (s *SqlStore) GetAccountIDByPeerID(ctx context.Context, lockStrength LockingStrength, peerID string) (string, error) {
|
||||
var accountID string
|
||||
result := s.db.WithContext(ctx).Clauses(clause.Locking{Strength: string(lockStrength)}).Model(&nbpeer.Peer{}).
|
||||
Select("account_id").Where(idQueryCondition, peerID).First(&accountID)
|
||||
if result.Error != nil {
|
||||
if errors.Is(result.Error, gorm.ErrRecordNotFound) {
|
||||
return "", status.Errorf(status.NotFound, "account not found: index lookup failed")
|
||||
@ -798,7 +814,7 @@ func (s *SqlStore) GetAccountNetwork(ctx context.Context, lockStrength LockingSt
|
||||
if err := s.db.WithContext(ctx).Clauses(clause.Locking{Strength: string(lockStrength)}).Model(&Account{}).
|
||||
First(&accountNetwork, idQueryCondition, accountID).Error; err != nil {
|
||||
if errors.Is(err, gorm.ErrRecordNotFound) {
|
||||
return nil, status.NewAccountNotFoundError(accountID)
|
||||
return nil, status.NewAccountNotFoundError()
|
||||
}
|
||||
log.WithContext(ctx).Errorf("error when getting network from the store: %s", err)
|
||||
return nil, status.Errorf(status.Internal, "issue getting network from store: %s", err)
|
||||
@ -1132,9 +1148,27 @@ func (s *SqlStore) GetPeerByID(ctx context.Context, lockStrength LockingStrength
|
||||
return peer, nil
|
||||
}
|
||||
|
||||
// GetAllEphemeralPeers retrieves all peers with Ephemeral set to true across all accounts, optimized for batch processing.
|
||||
func (s *SqlStore) GetAllEphemeralPeers(ctx context.Context, lockStrength LockingStrength) ([]*nbpeer.Peer, error) {
|
||||
var allEphemeralPeers, batchPeers []*nbpeer.Peer
|
||||
result := s.db.WithContext(ctx).Clauses(clause.Locking{Strength: string(lockStrength)}).
|
||||
Where("ephemeral = ?", true).
|
||||
FindInBatches(&batchPeers, batchSize, func(tx *gorm.DB, batch int) error {
|
||||
allEphemeralPeers = append(allEphemeralPeers, batchPeers...)
|
||||
return nil
|
||||
})
|
||||
|
||||
if result.Error != nil {
|
||||
log.WithContext(ctx).Errorf("failed to retrieve ephemeral peers: %s", result.Error)
|
||||
return nil, fmt.Errorf("failed to retrieve ephemeral peers")
|
||||
}
|
||||
|
||||
return allEphemeralPeers, nil
|
||||
}
|
||||
|
||||
func (s *SqlStore) DeletePeer(ctx context.Context, lockStrength LockingStrength, accountID string, peerID string) error {
|
||||
result := s.db.WithContext(ctx).Clauses(clause.Locking{Strength: string(lockStrength)}).
|
||||
Delete(&Policy{}, accountAndIDQueryCondition, accountID, peerID)
|
||||
Delete(&nbpeer.Peer{}, accountAndIDQueryCondition, accountID, peerID)
|
||||
if err := result.Error; err != nil {
|
||||
log.WithContext(ctx).Errorf("failed to delete peer from the store: %s", err)
|
||||
return status.Errorf(status.Internal, "failed to delete peer from store")
|
||||
@ -1629,6 +1663,27 @@ func (s *SqlStore) GetUserPATs(ctx context.Context, lockStrength LockingStrength
|
||||
return pats, nil
|
||||
}
|
||||
|
||||
// MarkPATUsed marks a personal access token as used.
|
||||
func (s *SqlStore) MarkPATUsed(ctx context.Context, lockStrength LockingStrength, patID string) error {
|
||||
patCopy := PersonalAccessToken{
|
||||
LastUsed: time.Now().UTC(),
|
||||
}
|
||||
|
||||
fieldsToUpdate := []string{"last_used"}
|
||||
result := s.db.WithContext(ctx).Clauses(clause.Locking{Strength: string(lockStrength)}).
|
||||
Select(fieldsToUpdate).Where(idQueryCondition, patID).Updates(&patCopy)
|
||||
if result.Error != nil {
|
||||
log.WithContext(ctx).Errorf("failed to mark PAT as used: %s", result.Error)
|
||||
return status.Errorf(status.Internal, "failed to mark PAT as used")
|
||||
}
|
||||
|
||||
if result.RowsAffected == 0 {
|
||||
return status.NewPATNotFoundError()
|
||||
}
|
||||
|
||||
return nil
|
||||
}
|
||||
|
||||
// SavePAT saves a personal access token to the database.
|
||||
func (s *SqlStore) SavePAT(ctx context.Context, lockStrength LockingStrength, pat *PersonalAccessToken) error {
|
||||
result := s.db.WithContext(ctx).Clauses(clause.Locking{Strength: string(lockStrength)}).Save(pat)
|
||||
|
@ -82,8 +82,8 @@ func NewPeerNotFoundError(peerKey string) error {
|
||||
}
|
||||
|
||||
// NewAccountNotFoundError creates a new Error with NotFound type for a missing account
|
||||
func NewAccountNotFoundError(accountKey string) error {
|
||||
return Errorf(NotFound, "account not found: %s", accountKey)
|
||||
func NewAccountNotFoundError() error {
|
||||
return Errorf(NotFound, "account not found")
|
||||
}
|
||||
|
||||
// NewUserNotFoundError creates a new Error with NotFound type for a missing user
|
||||
@ -134,11 +134,20 @@ func NewUnauthorizedToViewSetupKeysError() error {
|
||||
return Errorf(PermissionDenied, "only users with admin power can view setup keys")
|
||||
}
|
||||
|
||||
func NewGroupNotFoundError() error {
|
||||
return Errorf(NotFound, "group not found")
|
||||
}
|
||||
|
||||
func NewUnauthorizedToViewGroupsError() error {
|
||||
return Errorf(PermissionDenied, "only users with admin power can view groups")
|
||||
}
|
||||
|
||||
func NewPATNotFoundError() error {
|
||||
return Errorf(NotFound, "PAT not found")
|
||||
}
|
||||
|
||||
func NewUnauthorizedToViewPATsError() error {
|
||||
return Errorf(PermissionDenied, "only users with admin power can view personal access tokens")
|
||||
return Errorf(PermissionDenied, "only users with admin power can view PATs")
|
||||
}
|
||||
|
||||
func NewUnauthorizedToViewPoliciesError() error {
|
||||
|
@ -50,6 +50,7 @@ type Store interface {
|
||||
GetAccountIDByPeerPubKey(ctx context.Context, peerKey string) (string, error)
|
||||
GetAccountIDByUserID(userID string) (string, error)
|
||||
GetAccountIDBySetupKey(ctx context.Context, peerKey string) (string, error)
|
||||
GetAccountIDByPeerID(ctx context.Context, lockStrength LockingStrength, peerID string) (string, error)
|
||||
GetAccountByPeerID(ctx context.Context, peerID string) (*Account, error)
|
||||
GetAccountBySetupKey(ctx context.Context, setupKey string) (*Account, error) // todo use key hash later
|
||||
GetAccountByPrivateDomain(ctx context.Context, domain string) (*Account, error)
|
||||
@ -99,6 +100,7 @@ type Store interface {
|
||||
GetAccountPeersWithExpiration(ctx context.Context, lockStrength LockingStrength, accountID string) ([]*nbpeer.Peer, error)
|
||||
GetAccountPeersWithInactivity(ctx context.Context, lockStrength LockingStrength, accountID string) ([]*nbpeer.Peer, error)
|
||||
GetPeerByID(ctx context.Context, lockStrength LockingStrength, accountID string, peerID string) (*nbpeer.Peer, error)
|
||||
GetAllEphemeralPeers(ctx context.Context, lockStrength LockingStrength) ([]*nbpeer.Peer, error)
|
||||
SavePeer(ctx context.Context, lockStrength LockingStrength, accountID string, peer *nbpeer.Peer) error
|
||||
SavePeerStatus(ctx context.Context, lockStrength LockingStrength, accountID, peerID string, status nbpeer.PeerStatus) error
|
||||
SavePeerLocation(ctx context.Context, lockStrength LockingStrength, accountID string, peer *nbpeer.Peer) error
|
||||
@ -123,6 +125,7 @@ type Store interface {
|
||||
|
||||
GetPATByID(ctx context.Context, lockStrength LockingStrength, userID, patID string) (*PersonalAccessToken, error)
|
||||
GetUserPATs(ctx context.Context, lockStrength LockingStrength, userID string) ([]*PersonalAccessToken, error)
|
||||
MarkPATUsed(ctx context.Context, lockStrength LockingStrength, patID string) error
|
||||
SavePAT(ctx context.Context, strength LockingStrength, pat *PersonalAccessToken) error
|
||||
DeletePAT(ctx context.Context, strength LockingStrength, userID, patID string) error
|
||||
|
||||
|
Reference in New Issue
Block a user