diff --git a/management/server/account.go b/management/server/account.go index e30e30759..676659b56 100644 --- a/management/server/account.go +++ b/management/server/account.go @@ -1865,33 +1865,7 @@ func (am *DefaultAccountManager) redeemInvite(ctx context.Context, accountID str // MarkPATUsed marks a personal access token as used func (am *DefaultAccountManager) MarkPATUsed(ctx context.Context, tokenID string) error { - - user, err := am.Store.GetUserByTokenID(ctx, tokenID) - if err != nil { - return err - } - - account, err := am.Store.GetAccountByUser(ctx, user.Id) - if err != nil { - return err - } - - unlock := am.Store.AcquireWriteLockByUID(ctx, account.Id) - defer unlock() - - account, err = am.Store.GetAccountByUser(ctx, user.Id) - if err != nil { - return err - } - - pat, ok := account.Users[user.Id].PATs[tokenID] - if !ok { - return fmt.Errorf("token not found") - } - - pat.LastUsed = time.Now().UTC() - - return am.Store.SaveAccount(ctx, account) + return am.Store.MarkPATUsed(ctx, LockingStrengthUpdate, tokenID) } // GetAccount returns an account associated with this account ID. diff --git a/management/server/ephemeral.go b/management/server/ephemeral.go index 590b1d708..6e245ec5a 100644 --- a/management/server/ephemeral.go +++ b/management/server/ephemeral.go @@ -20,10 +20,10 @@ var ( ) type ephemeralPeer struct { - id string - account *Account - deadline time.Time - next *ephemeralPeer + id string + accountID string + deadline time.Time + next *ephemeralPeer } // todo: consider to remove peer from ephemeral list when the peer has been deleted via API. If we do not do it @@ -104,12 +104,6 @@ func (e *EphemeralManager) OnPeerDisconnected(ctx context.Context, peer *nbpeer. log.WithContext(ctx).Tracef("add peer to ephemeral list: %s", peer.ID) - a, err := e.store.GetAccountByPeerID(context.Background(), peer.ID) - if err != nil { - log.WithContext(ctx).Errorf("failed to add peer to ephemeral list: %s", err) - return - } - e.peersLock.Lock() defer e.peersLock.Unlock() @@ -117,7 +111,7 @@ func (e *EphemeralManager) OnPeerDisconnected(ctx context.Context, peer *nbpeer. return } - e.addPeer(peer.ID, a, newDeadLine()) + e.addPeer(peer.AccountID, peer.ID, newDeadLine()) if e.timer == nil { e.timer = time.AfterFunc(e.headPeer.deadline.Sub(timeNow()), func() { e.cleanup(ctx) @@ -126,17 +120,21 @@ func (e *EphemeralManager) OnPeerDisconnected(ctx context.Context, peer *nbpeer. } func (e *EphemeralManager) loadEphemeralPeers(ctx context.Context) { - accounts := e.store.GetAllAccounts(context.Background()) + peers, err := e.store.GetAllEphemeralPeers(ctx, LockingStrengthShare) + if err != nil { + log.WithContext(ctx).Debugf("failed to load ephemeral peers: %s", err) + return + } + t := newDeadLine() count := 0 - for _, a := range accounts { - for id, p := range a.Peers { - if p.Ephemeral { - count++ - e.addPeer(id, a, t) - } + for _, p := range peers { + if p.Ephemeral { + count++ + e.addPeer(p.AccountID, p.ID, t) } } + log.WithContext(ctx).Debugf("loaded ephemeral peer(s): %d", count) } @@ -170,18 +168,18 @@ func (e *EphemeralManager) cleanup(ctx context.Context) { for id, p := range deletePeers { log.WithContext(ctx).Debugf("delete ephemeral peer: %s", id) - err := e.accountManager.DeletePeer(ctx, p.account.Id, id, activity.SystemInitiator) + err := e.accountManager.DeletePeer(ctx, p.accountID, id, activity.SystemInitiator) if err != nil { log.WithContext(ctx).Errorf("failed to delete ephemeral peer: %s", err) } } } -func (e *EphemeralManager) addPeer(id string, account *Account, deadline time.Time) { +func (e *EphemeralManager) addPeer(accountID string, peerID string, deadline time.Time) { ep := &ephemeralPeer{ - id: id, - account: account, - deadline: deadline, + id: peerID, + accountID: accountID, + deadline: deadline, } if e.headPeer == nil { diff --git a/management/server/peer.go b/management/server/peer.go index c58e7b225..eaa119e11 100644 --- a/management/server/peer.go +++ b/management/server/peer.go @@ -363,12 +363,12 @@ func (am *DefaultAccountManager) deletePeers(ctx context.Context, accountID stri // DeletePeer removes peer from the account by its IP func (am *DefaultAccountManager) DeletePeer(ctx context.Context, accountID, peerID, userID string) error { - user, err := am.Store.GetUserByUserID(ctx, LockingStrengthShare, userID) + peerAccountID, err := am.Store.GetAccountIDByPeerID(ctx, LockingStrengthShare, peerID) if err != nil { return err } - if user.AccountID != accountID { + if peerAccountID != accountID { return status.NewUserNotPartOfAccountError() } diff --git a/management/server/peer/peer.go b/management/server/peer/peer.go index 82e0acf3a..24a4b98ea 100644 --- a/management/server/peer/peer.go +++ b/management/server/peer/peer.go @@ -44,7 +44,7 @@ type Peer struct { // CreatedAt records the time the peer was created CreatedAt time.Time `diff:"-"` // Indicate ephemeral peer attribute - Ephemeral bool `diff:"-"` + Ephemeral bool `gorm:"index" diff:"-"` // Geo location based on connection IP Location Location `gorm:"embedded;embeddedPrefix:location_" diff:"-"` } diff --git a/management/server/sql_store.go b/management/server/sql_store.go index c68c182df..a5316d72d 100644 --- a/management/server/sql_store.go +++ b/management/server/sql_store.go @@ -39,6 +39,7 @@ const ( accountAndIDQueryCondition = "account_id = ? and id = ?" accountIDCondition = "account_id = ?" peerNotFoundFMT = "peer %s not found" + batchSize = 500 ) // SqlStore represents an account storage backed by a Sql DB persisted to disk @@ -592,7 +593,7 @@ func (s *SqlStore) GetAccount(ctx context.Context, accountID string) (*Account, if result.Error != nil { log.WithContext(ctx).Errorf("error when getting account %s from the store: %s", accountID, result.Error) if errors.Is(result.Error, gorm.ErrRecordNotFound) { - return nil, status.NewAccountNotFoundError(accountID) + return nil, status.NewAccountNotFoundError() } return nil, status.NewGetAccountFromStoreError(result.Error) } @@ -708,7 +709,7 @@ func (s *SqlStore) GetAccountIDByPeerPubKey(ctx context.Context, peerKey string) result := s.db.WithContext(ctx).Model(&peer).Select("account_id").Where(keyQueryCondition, peerKey).First(&accountID) if result.Error != nil { if errors.Is(result.Error, gorm.ErrRecordNotFound) { - return "", status.Errorf(status.NotFound, "account not found: index lookup failed") + return "", status.NewAccountNotFoundError() } return "", status.NewGetAccountFromStoreError(result.Error) } @@ -719,6 +720,21 @@ func (s *SqlStore) GetAccountIDByPeerPubKey(ctx context.Context, peerKey string) func (s *SqlStore) GetAccountIDByUserID(userID string) (string, error) { var accountID string result := s.db.Model(&User{}).Select("account_id").Where(idQueryCondition, userID).First(&accountID) + if result.Error != nil { + if errors.Is(result.Error, gorm.ErrRecordNotFound) { + return "", status.NewAccountNotFoundError() + } + + return "", status.NewGetAccountFromStoreError(result.Error) + } + + return accountID, nil +} + +func (s *SqlStore) GetAccountIDByPeerID(ctx context.Context, lockStrength LockingStrength, peerID string) (string, error) { + var accountID string + result := s.db.WithContext(ctx).Clauses(clause.Locking{Strength: string(lockStrength)}).Model(&nbpeer.Peer{}). + Select("account_id").Where(idQueryCondition, peerID).First(&accountID) if result.Error != nil { if errors.Is(result.Error, gorm.ErrRecordNotFound) { return "", status.Errorf(status.NotFound, "account not found: index lookup failed") @@ -798,7 +814,7 @@ func (s *SqlStore) GetAccountNetwork(ctx context.Context, lockStrength LockingSt if err := s.db.WithContext(ctx).Clauses(clause.Locking{Strength: string(lockStrength)}).Model(&Account{}). First(&accountNetwork, idQueryCondition, accountID).Error; err != nil { if errors.Is(err, gorm.ErrRecordNotFound) { - return nil, status.NewAccountNotFoundError(accountID) + return nil, status.NewAccountNotFoundError() } log.WithContext(ctx).Errorf("error when getting network from the store: %s", err) return nil, status.Errorf(status.Internal, "issue getting network from store: %s", err) @@ -1132,9 +1148,27 @@ func (s *SqlStore) GetPeerByID(ctx context.Context, lockStrength LockingStrength return peer, nil } +// GetAllEphemeralPeers retrieves all peers with Ephemeral set to true across all accounts, optimized for batch processing. +func (s *SqlStore) GetAllEphemeralPeers(ctx context.Context, lockStrength LockingStrength) ([]*nbpeer.Peer, error) { + var allEphemeralPeers, batchPeers []*nbpeer.Peer + result := s.db.WithContext(ctx).Clauses(clause.Locking{Strength: string(lockStrength)}). + Where("ephemeral = ?", true). + FindInBatches(&batchPeers, batchSize, func(tx *gorm.DB, batch int) error { + allEphemeralPeers = append(allEphemeralPeers, batchPeers...) + return nil + }) + + if result.Error != nil { + log.WithContext(ctx).Errorf("failed to retrieve ephemeral peers: %s", result.Error) + return nil, fmt.Errorf("failed to retrieve ephemeral peers") + } + + return allEphemeralPeers, nil +} + func (s *SqlStore) DeletePeer(ctx context.Context, lockStrength LockingStrength, accountID string, peerID string) error { result := s.db.WithContext(ctx).Clauses(clause.Locking{Strength: string(lockStrength)}). - Delete(&Policy{}, accountAndIDQueryCondition, accountID, peerID) + Delete(&nbpeer.Peer{}, accountAndIDQueryCondition, accountID, peerID) if err := result.Error; err != nil { log.WithContext(ctx).Errorf("failed to delete peer from the store: %s", err) return status.Errorf(status.Internal, "failed to delete peer from store") @@ -1629,6 +1663,27 @@ func (s *SqlStore) GetUserPATs(ctx context.Context, lockStrength LockingStrength return pats, nil } +// MarkPATUsed marks a personal access token as used. +func (s *SqlStore) MarkPATUsed(ctx context.Context, lockStrength LockingStrength, patID string) error { + patCopy := PersonalAccessToken{ + LastUsed: time.Now().UTC(), + } + + fieldsToUpdate := []string{"last_used"} + result := s.db.WithContext(ctx).Clauses(clause.Locking{Strength: string(lockStrength)}). + Select(fieldsToUpdate).Where(idQueryCondition, patID).Updates(&patCopy) + if result.Error != nil { + log.WithContext(ctx).Errorf("failed to mark PAT as used: %s", result.Error) + return status.Errorf(status.Internal, "failed to mark PAT as used") + } + + if result.RowsAffected == 0 { + return status.NewPATNotFoundError() + } + + return nil +} + // SavePAT saves a personal access token to the database. func (s *SqlStore) SavePAT(ctx context.Context, lockStrength LockingStrength, pat *PersonalAccessToken) error { result := s.db.WithContext(ctx).Clauses(clause.Locking{Strength: string(lockStrength)}).Save(pat) diff --git a/management/server/status/error.go b/management/server/status/error.go index 7a4ec3f67..0dd302dfa 100644 --- a/management/server/status/error.go +++ b/management/server/status/error.go @@ -82,8 +82,8 @@ func NewPeerNotFoundError(peerKey string) error { } // NewAccountNotFoundError creates a new Error with NotFound type for a missing account -func NewAccountNotFoundError(accountKey string) error { - return Errorf(NotFound, "account not found: %s", accountKey) +func NewAccountNotFoundError() error { + return Errorf(NotFound, "account not found") } // NewUserNotFoundError creates a new Error with NotFound type for a missing user @@ -134,11 +134,20 @@ func NewUnauthorizedToViewSetupKeysError() error { return Errorf(PermissionDenied, "only users with admin power can view setup keys") } +func NewGroupNotFoundError() error { + return Errorf(NotFound, "group not found") +} + func NewUnauthorizedToViewGroupsError() error { return Errorf(PermissionDenied, "only users with admin power can view groups") } + +func NewPATNotFoundError() error { + return Errorf(NotFound, "PAT not found") +} + func NewUnauthorizedToViewPATsError() error { - return Errorf(PermissionDenied, "only users with admin power can view personal access tokens") + return Errorf(PermissionDenied, "only users with admin power can view PATs") } func NewUnauthorizedToViewPoliciesError() error { diff --git a/management/server/store.go b/management/server/store.go index fda499e9d..ddaf37e17 100644 --- a/management/server/store.go +++ b/management/server/store.go @@ -50,6 +50,7 @@ type Store interface { GetAccountIDByPeerPubKey(ctx context.Context, peerKey string) (string, error) GetAccountIDByUserID(userID string) (string, error) GetAccountIDBySetupKey(ctx context.Context, peerKey string) (string, error) + GetAccountIDByPeerID(ctx context.Context, lockStrength LockingStrength, peerID string) (string, error) GetAccountByPeerID(ctx context.Context, peerID string) (*Account, error) GetAccountBySetupKey(ctx context.Context, setupKey string) (*Account, error) // todo use key hash later GetAccountByPrivateDomain(ctx context.Context, domain string) (*Account, error) @@ -99,6 +100,7 @@ type Store interface { GetAccountPeersWithExpiration(ctx context.Context, lockStrength LockingStrength, accountID string) ([]*nbpeer.Peer, error) GetAccountPeersWithInactivity(ctx context.Context, lockStrength LockingStrength, accountID string) ([]*nbpeer.Peer, error) GetPeerByID(ctx context.Context, lockStrength LockingStrength, accountID string, peerID string) (*nbpeer.Peer, error) + GetAllEphemeralPeers(ctx context.Context, lockStrength LockingStrength) ([]*nbpeer.Peer, error) SavePeer(ctx context.Context, lockStrength LockingStrength, accountID string, peer *nbpeer.Peer) error SavePeerStatus(ctx context.Context, lockStrength LockingStrength, accountID, peerID string, status nbpeer.PeerStatus) error SavePeerLocation(ctx context.Context, lockStrength LockingStrength, accountID string, peer *nbpeer.Peer) error @@ -123,6 +125,7 @@ type Store interface { GetPATByID(ctx context.Context, lockStrength LockingStrength, userID, patID string) (*PersonalAccessToken, error) GetUserPATs(ctx context.Context, lockStrength LockingStrength, userID string) ([]*PersonalAccessToken, error) + MarkPATUsed(ctx context.Context, lockStrength LockingStrength, patID string) error SavePAT(ctx context.Context, strength LockingStrength, pat *PersonalAccessToken) error DeletePAT(ctx context.Context, strength LockingStrength, userID, patID string) error