Refactor ephemeral peers and mark PAT as used

Signed-off-by: bcmmbaga <bethuelmbaga12@gmail.com>
This commit is contained in:
bcmmbaga
2024-10-31 21:50:05 +03:00
parent b7525d9fe8
commit 6b94f6e4e7
7 changed files with 99 additions and 60 deletions

View File

@ -1865,33 +1865,7 @@ func (am *DefaultAccountManager) redeemInvite(ctx context.Context, accountID str
// MarkPATUsed marks a personal access token as used // MarkPATUsed marks a personal access token as used
func (am *DefaultAccountManager) MarkPATUsed(ctx context.Context, tokenID string) error { func (am *DefaultAccountManager) MarkPATUsed(ctx context.Context, tokenID string) error {
return am.Store.MarkPATUsed(ctx, LockingStrengthUpdate, tokenID)
user, err := am.Store.GetUserByTokenID(ctx, tokenID)
if err != nil {
return err
}
account, err := am.Store.GetAccountByUser(ctx, user.Id)
if err != nil {
return err
}
unlock := am.Store.AcquireWriteLockByUID(ctx, account.Id)
defer unlock()
account, err = am.Store.GetAccountByUser(ctx, user.Id)
if err != nil {
return err
}
pat, ok := account.Users[user.Id].PATs[tokenID]
if !ok {
return fmt.Errorf("token not found")
}
pat.LastUsed = time.Now().UTC()
return am.Store.SaveAccount(ctx, account)
} }
// GetAccount returns an account associated with this account ID. // GetAccount returns an account associated with this account ID.

View File

@ -20,10 +20,10 @@ var (
) )
type ephemeralPeer struct { type ephemeralPeer struct {
id string id string
account *Account accountID string
deadline time.Time deadline time.Time
next *ephemeralPeer next *ephemeralPeer
} }
// todo: consider to remove peer from ephemeral list when the peer has been deleted via API. If we do not do it // todo: consider to remove peer from ephemeral list when the peer has been deleted via API. If we do not do it
@ -104,12 +104,6 @@ func (e *EphemeralManager) OnPeerDisconnected(ctx context.Context, peer *nbpeer.
log.WithContext(ctx).Tracef("add peer to ephemeral list: %s", peer.ID) log.WithContext(ctx).Tracef("add peer to ephemeral list: %s", peer.ID)
a, err := e.store.GetAccountByPeerID(context.Background(), peer.ID)
if err != nil {
log.WithContext(ctx).Errorf("failed to add peer to ephemeral list: %s", err)
return
}
e.peersLock.Lock() e.peersLock.Lock()
defer e.peersLock.Unlock() defer e.peersLock.Unlock()
@ -117,7 +111,7 @@ func (e *EphemeralManager) OnPeerDisconnected(ctx context.Context, peer *nbpeer.
return return
} }
e.addPeer(peer.ID, a, newDeadLine()) e.addPeer(peer.AccountID, peer.ID, newDeadLine())
if e.timer == nil { if e.timer == nil {
e.timer = time.AfterFunc(e.headPeer.deadline.Sub(timeNow()), func() { e.timer = time.AfterFunc(e.headPeer.deadline.Sub(timeNow()), func() {
e.cleanup(ctx) e.cleanup(ctx)
@ -126,17 +120,21 @@ func (e *EphemeralManager) OnPeerDisconnected(ctx context.Context, peer *nbpeer.
} }
func (e *EphemeralManager) loadEphemeralPeers(ctx context.Context) { func (e *EphemeralManager) loadEphemeralPeers(ctx context.Context) {
accounts := e.store.GetAllAccounts(context.Background()) peers, err := e.store.GetAllEphemeralPeers(ctx, LockingStrengthShare)
if err != nil {
log.WithContext(ctx).Debugf("failed to load ephemeral peers: %s", err)
return
}
t := newDeadLine() t := newDeadLine()
count := 0 count := 0
for _, a := range accounts { for _, p := range peers {
for id, p := range a.Peers { if p.Ephemeral {
if p.Ephemeral { count++
count++ e.addPeer(p.AccountID, p.ID, t)
e.addPeer(id, a, t)
}
} }
} }
log.WithContext(ctx).Debugf("loaded ephemeral peer(s): %d", count) log.WithContext(ctx).Debugf("loaded ephemeral peer(s): %d", count)
} }
@ -170,18 +168,18 @@ func (e *EphemeralManager) cleanup(ctx context.Context) {
for id, p := range deletePeers { for id, p := range deletePeers {
log.WithContext(ctx).Debugf("delete ephemeral peer: %s", id) log.WithContext(ctx).Debugf("delete ephemeral peer: %s", id)
err := e.accountManager.DeletePeer(ctx, p.account.Id, id, activity.SystemInitiator) err := e.accountManager.DeletePeer(ctx, p.accountID, id, activity.SystemInitiator)
if err != nil { if err != nil {
log.WithContext(ctx).Errorf("failed to delete ephemeral peer: %s", err) log.WithContext(ctx).Errorf("failed to delete ephemeral peer: %s", err)
} }
} }
} }
func (e *EphemeralManager) addPeer(id string, account *Account, deadline time.Time) { func (e *EphemeralManager) addPeer(accountID string, peerID string, deadline time.Time) {
ep := &ephemeralPeer{ ep := &ephemeralPeer{
id: id, id: peerID,
account: account, accountID: accountID,
deadline: deadline, deadline: deadline,
} }
if e.headPeer == nil { if e.headPeer == nil {

View File

@ -363,12 +363,12 @@ func (am *DefaultAccountManager) deletePeers(ctx context.Context, accountID stri
// DeletePeer removes peer from the account by its IP // DeletePeer removes peer from the account by its IP
func (am *DefaultAccountManager) DeletePeer(ctx context.Context, accountID, peerID, userID string) error { func (am *DefaultAccountManager) DeletePeer(ctx context.Context, accountID, peerID, userID string) error {
user, err := am.Store.GetUserByUserID(ctx, LockingStrengthShare, userID) peerAccountID, err := am.Store.GetAccountIDByPeerID(ctx, LockingStrengthShare, peerID)
if err != nil { if err != nil {
return err return err
} }
if user.AccountID != accountID { if peerAccountID != accountID {
return status.NewUserNotPartOfAccountError() return status.NewUserNotPartOfAccountError()
} }

View File

@ -44,7 +44,7 @@ type Peer struct {
// CreatedAt records the time the peer was created // CreatedAt records the time the peer was created
CreatedAt time.Time `diff:"-"` CreatedAt time.Time `diff:"-"`
// Indicate ephemeral peer attribute // Indicate ephemeral peer attribute
Ephemeral bool `diff:"-"` Ephemeral bool `gorm:"index" diff:"-"`
// Geo location based on connection IP // Geo location based on connection IP
Location Location `gorm:"embedded;embeddedPrefix:location_" diff:"-"` Location Location `gorm:"embedded;embeddedPrefix:location_" diff:"-"`
} }

View File

@ -39,6 +39,7 @@ const (
accountAndIDQueryCondition = "account_id = ? and id = ?" accountAndIDQueryCondition = "account_id = ? and id = ?"
accountIDCondition = "account_id = ?" accountIDCondition = "account_id = ?"
peerNotFoundFMT = "peer %s not found" peerNotFoundFMT = "peer %s not found"
batchSize = 500
) )
// SqlStore represents an account storage backed by a Sql DB persisted to disk // SqlStore represents an account storage backed by a Sql DB persisted to disk
@ -592,7 +593,7 @@ func (s *SqlStore) GetAccount(ctx context.Context, accountID string) (*Account,
if result.Error != nil { if result.Error != nil {
log.WithContext(ctx).Errorf("error when getting account %s from the store: %s", accountID, result.Error) log.WithContext(ctx).Errorf("error when getting account %s from the store: %s", accountID, result.Error)
if errors.Is(result.Error, gorm.ErrRecordNotFound) { if errors.Is(result.Error, gorm.ErrRecordNotFound) {
return nil, status.NewAccountNotFoundError(accountID) return nil, status.NewAccountNotFoundError()
} }
return nil, status.NewGetAccountFromStoreError(result.Error) return nil, status.NewGetAccountFromStoreError(result.Error)
} }
@ -708,7 +709,7 @@ func (s *SqlStore) GetAccountIDByPeerPubKey(ctx context.Context, peerKey string)
result := s.db.WithContext(ctx).Model(&peer).Select("account_id").Where(keyQueryCondition, peerKey).First(&accountID) result := s.db.WithContext(ctx).Model(&peer).Select("account_id").Where(keyQueryCondition, peerKey).First(&accountID)
if result.Error != nil { if result.Error != nil {
if errors.Is(result.Error, gorm.ErrRecordNotFound) { if errors.Is(result.Error, gorm.ErrRecordNotFound) {
return "", status.Errorf(status.NotFound, "account not found: index lookup failed") return "", status.NewAccountNotFoundError()
} }
return "", status.NewGetAccountFromStoreError(result.Error) return "", status.NewGetAccountFromStoreError(result.Error)
} }
@ -719,6 +720,21 @@ func (s *SqlStore) GetAccountIDByPeerPubKey(ctx context.Context, peerKey string)
func (s *SqlStore) GetAccountIDByUserID(userID string) (string, error) { func (s *SqlStore) GetAccountIDByUserID(userID string) (string, error) {
var accountID string var accountID string
result := s.db.Model(&User{}).Select("account_id").Where(idQueryCondition, userID).First(&accountID) result := s.db.Model(&User{}).Select("account_id").Where(idQueryCondition, userID).First(&accountID)
if result.Error != nil {
if errors.Is(result.Error, gorm.ErrRecordNotFound) {
return "", status.NewAccountNotFoundError()
}
return "", status.NewGetAccountFromStoreError(result.Error)
}
return accountID, nil
}
func (s *SqlStore) GetAccountIDByPeerID(ctx context.Context, lockStrength LockingStrength, peerID string) (string, error) {
var accountID string
result := s.db.WithContext(ctx).Clauses(clause.Locking{Strength: string(lockStrength)}).Model(&nbpeer.Peer{}).
Select("account_id").Where(idQueryCondition, peerID).First(&accountID)
if result.Error != nil { if result.Error != nil {
if errors.Is(result.Error, gorm.ErrRecordNotFound) { if errors.Is(result.Error, gorm.ErrRecordNotFound) {
return "", status.Errorf(status.NotFound, "account not found: index lookup failed") return "", status.Errorf(status.NotFound, "account not found: index lookup failed")
@ -798,7 +814,7 @@ func (s *SqlStore) GetAccountNetwork(ctx context.Context, lockStrength LockingSt
if err := s.db.WithContext(ctx).Clauses(clause.Locking{Strength: string(lockStrength)}).Model(&Account{}). if err := s.db.WithContext(ctx).Clauses(clause.Locking{Strength: string(lockStrength)}).Model(&Account{}).
First(&accountNetwork, idQueryCondition, accountID).Error; err != nil { First(&accountNetwork, idQueryCondition, accountID).Error; err != nil {
if errors.Is(err, gorm.ErrRecordNotFound) { if errors.Is(err, gorm.ErrRecordNotFound) {
return nil, status.NewAccountNotFoundError(accountID) return nil, status.NewAccountNotFoundError()
} }
log.WithContext(ctx).Errorf("error when getting network from the store: %s", err) log.WithContext(ctx).Errorf("error when getting network from the store: %s", err)
return nil, status.Errorf(status.Internal, "issue getting network from store: %s", err) return nil, status.Errorf(status.Internal, "issue getting network from store: %s", err)
@ -1132,9 +1148,27 @@ func (s *SqlStore) GetPeerByID(ctx context.Context, lockStrength LockingStrength
return peer, nil return peer, nil
} }
// GetAllEphemeralPeers retrieves all peers with Ephemeral set to true across all accounts, optimized for batch processing.
func (s *SqlStore) GetAllEphemeralPeers(ctx context.Context, lockStrength LockingStrength) ([]*nbpeer.Peer, error) {
var allEphemeralPeers, batchPeers []*nbpeer.Peer
result := s.db.WithContext(ctx).Clauses(clause.Locking{Strength: string(lockStrength)}).
Where("ephemeral = ?", true).
FindInBatches(&batchPeers, batchSize, func(tx *gorm.DB, batch int) error {
allEphemeralPeers = append(allEphemeralPeers, batchPeers...)
return nil
})
if result.Error != nil {
log.WithContext(ctx).Errorf("failed to retrieve ephemeral peers: %s", result.Error)
return nil, fmt.Errorf("failed to retrieve ephemeral peers")
}
return allEphemeralPeers, nil
}
func (s *SqlStore) DeletePeer(ctx context.Context, lockStrength LockingStrength, accountID string, peerID string) error { func (s *SqlStore) DeletePeer(ctx context.Context, lockStrength LockingStrength, accountID string, peerID string) error {
result := s.db.WithContext(ctx).Clauses(clause.Locking{Strength: string(lockStrength)}). result := s.db.WithContext(ctx).Clauses(clause.Locking{Strength: string(lockStrength)}).
Delete(&Policy{}, accountAndIDQueryCondition, accountID, peerID) Delete(&nbpeer.Peer{}, accountAndIDQueryCondition, accountID, peerID)
if err := result.Error; err != nil { if err := result.Error; err != nil {
log.WithContext(ctx).Errorf("failed to delete peer from the store: %s", err) log.WithContext(ctx).Errorf("failed to delete peer from the store: %s", err)
return status.Errorf(status.Internal, "failed to delete peer from store") return status.Errorf(status.Internal, "failed to delete peer from store")
@ -1629,6 +1663,27 @@ func (s *SqlStore) GetUserPATs(ctx context.Context, lockStrength LockingStrength
return pats, nil return pats, nil
} }
// MarkPATUsed marks a personal access token as used.
func (s *SqlStore) MarkPATUsed(ctx context.Context, lockStrength LockingStrength, patID string) error {
patCopy := PersonalAccessToken{
LastUsed: time.Now().UTC(),
}
fieldsToUpdate := []string{"last_used"}
result := s.db.WithContext(ctx).Clauses(clause.Locking{Strength: string(lockStrength)}).
Select(fieldsToUpdate).Where(idQueryCondition, patID).Updates(&patCopy)
if result.Error != nil {
log.WithContext(ctx).Errorf("failed to mark PAT as used: %s", result.Error)
return status.Errorf(status.Internal, "failed to mark PAT as used")
}
if result.RowsAffected == 0 {
return status.NewPATNotFoundError()
}
return nil
}
// SavePAT saves a personal access token to the database. // SavePAT saves a personal access token to the database.
func (s *SqlStore) SavePAT(ctx context.Context, lockStrength LockingStrength, pat *PersonalAccessToken) error { func (s *SqlStore) SavePAT(ctx context.Context, lockStrength LockingStrength, pat *PersonalAccessToken) error {
result := s.db.WithContext(ctx).Clauses(clause.Locking{Strength: string(lockStrength)}).Save(pat) result := s.db.WithContext(ctx).Clauses(clause.Locking{Strength: string(lockStrength)}).Save(pat)

View File

@ -82,8 +82,8 @@ func NewPeerNotFoundError(peerKey string) error {
} }
// NewAccountNotFoundError creates a new Error with NotFound type for a missing account // NewAccountNotFoundError creates a new Error with NotFound type for a missing account
func NewAccountNotFoundError(accountKey string) error { func NewAccountNotFoundError() error {
return Errorf(NotFound, "account not found: %s", accountKey) return Errorf(NotFound, "account not found")
} }
// NewUserNotFoundError creates a new Error with NotFound type for a missing user // NewUserNotFoundError creates a new Error with NotFound type for a missing user
@ -134,11 +134,20 @@ func NewUnauthorizedToViewSetupKeysError() error {
return Errorf(PermissionDenied, "only users with admin power can view setup keys") return Errorf(PermissionDenied, "only users with admin power can view setup keys")
} }
func NewGroupNotFoundError() error {
return Errorf(NotFound, "group not found")
}
func NewUnauthorizedToViewGroupsError() error { func NewUnauthorizedToViewGroupsError() error {
return Errorf(PermissionDenied, "only users with admin power can view groups") return Errorf(PermissionDenied, "only users with admin power can view groups")
} }
func NewPATNotFoundError() error {
return Errorf(NotFound, "PAT not found")
}
func NewUnauthorizedToViewPATsError() error { func NewUnauthorizedToViewPATsError() error {
return Errorf(PermissionDenied, "only users with admin power can view personal access tokens") return Errorf(PermissionDenied, "only users with admin power can view PATs")
} }
func NewUnauthorizedToViewPoliciesError() error { func NewUnauthorizedToViewPoliciesError() error {

View File

@ -50,6 +50,7 @@ type Store interface {
GetAccountIDByPeerPubKey(ctx context.Context, peerKey string) (string, error) GetAccountIDByPeerPubKey(ctx context.Context, peerKey string) (string, error)
GetAccountIDByUserID(userID string) (string, error) GetAccountIDByUserID(userID string) (string, error)
GetAccountIDBySetupKey(ctx context.Context, peerKey string) (string, error) GetAccountIDBySetupKey(ctx context.Context, peerKey string) (string, error)
GetAccountIDByPeerID(ctx context.Context, lockStrength LockingStrength, peerID string) (string, error)
GetAccountByPeerID(ctx context.Context, peerID string) (*Account, error) GetAccountByPeerID(ctx context.Context, peerID string) (*Account, error)
GetAccountBySetupKey(ctx context.Context, setupKey string) (*Account, error) // todo use key hash later GetAccountBySetupKey(ctx context.Context, setupKey string) (*Account, error) // todo use key hash later
GetAccountByPrivateDomain(ctx context.Context, domain string) (*Account, error) GetAccountByPrivateDomain(ctx context.Context, domain string) (*Account, error)
@ -99,6 +100,7 @@ type Store interface {
GetAccountPeersWithExpiration(ctx context.Context, lockStrength LockingStrength, accountID string) ([]*nbpeer.Peer, error) GetAccountPeersWithExpiration(ctx context.Context, lockStrength LockingStrength, accountID string) ([]*nbpeer.Peer, error)
GetAccountPeersWithInactivity(ctx context.Context, lockStrength LockingStrength, accountID string) ([]*nbpeer.Peer, error) GetAccountPeersWithInactivity(ctx context.Context, lockStrength LockingStrength, accountID string) ([]*nbpeer.Peer, error)
GetPeerByID(ctx context.Context, lockStrength LockingStrength, accountID string, peerID string) (*nbpeer.Peer, error) GetPeerByID(ctx context.Context, lockStrength LockingStrength, accountID string, peerID string) (*nbpeer.Peer, error)
GetAllEphemeralPeers(ctx context.Context, lockStrength LockingStrength) ([]*nbpeer.Peer, error)
SavePeer(ctx context.Context, lockStrength LockingStrength, accountID string, peer *nbpeer.Peer) error SavePeer(ctx context.Context, lockStrength LockingStrength, accountID string, peer *nbpeer.Peer) error
SavePeerStatus(ctx context.Context, lockStrength LockingStrength, accountID, peerID string, status nbpeer.PeerStatus) error SavePeerStatus(ctx context.Context, lockStrength LockingStrength, accountID, peerID string, status nbpeer.PeerStatus) error
SavePeerLocation(ctx context.Context, lockStrength LockingStrength, accountID string, peer *nbpeer.Peer) error SavePeerLocation(ctx context.Context, lockStrength LockingStrength, accountID string, peer *nbpeer.Peer) error
@ -123,6 +125,7 @@ type Store interface {
GetPATByID(ctx context.Context, lockStrength LockingStrength, userID, patID string) (*PersonalAccessToken, error) GetPATByID(ctx context.Context, lockStrength LockingStrength, userID, patID string) (*PersonalAccessToken, error)
GetUserPATs(ctx context.Context, lockStrength LockingStrength, userID string) ([]*PersonalAccessToken, error) GetUserPATs(ctx context.Context, lockStrength LockingStrength, userID string) ([]*PersonalAccessToken, error)
MarkPATUsed(ctx context.Context, lockStrength LockingStrength, patID string) error
SavePAT(ctx context.Context, strength LockingStrength, pat *PersonalAccessToken) error SavePAT(ctx context.Context, strength LockingStrength, pat *PersonalAccessToken) error
DeletePAT(ctx context.Context, strength LockingStrength, userID, patID string) error DeletePAT(ctx context.Context, strength LockingStrength, userID, patID string) error