Refactor ephemeral peers and mark PAT as used

Signed-off-by: bcmmbaga <bethuelmbaga12@gmail.com>
This commit is contained in:
bcmmbaga
2024-10-31 21:50:05 +03:00
parent b7525d9fe8
commit 6b94f6e4e7
7 changed files with 99 additions and 60 deletions

View File

@ -1865,33 +1865,7 @@ func (am *DefaultAccountManager) redeemInvite(ctx context.Context, accountID str
// MarkPATUsed marks a personal access token as used
func (am *DefaultAccountManager) MarkPATUsed(ctx context.Context, tokenID string) error {
user, err := am.Store.GetUserByTokenID(ctx, tokenID)
if err != nil {
return err
}
account, err := am.Store.GetAccountByUser(ctx, user.Id)
if err != nil {
return err
}
unlock := am.Store.AcquireWriteLockByUID(ctx, account.Id)
defer unlock()
account, err = am.Store.GetAccountByUser(ctx, user.Id)
if err != nil {
return err
}
pat, ok := account.Users[user.Id].PATs[tokenID]
if !ok {
return fmt.Errorf("token not found")
}
pat.LastUsed = time.Now().UTC()
return am.Store.SaveAccount(ctx, account)
return am.Store.MarkPATUsed(ctx, LockingStrengthUpdate, tokenID)
}
// GetAccount returns an account associated with this account ID.

View File

@ -20,10 +20,10 @@ var (
)
type ephemeralPeer struct {
id string
account *Account
deadline time.Time
next *ephemeralPeer
id string
accountID string
deadline time.Time
next *ephemeralPeer
}
// todo: consider to remove peer from ephemeral list when the peer has been deleted via API. If we do not do it
@ -104,12 +104,6 @@ func (e *EphemeralManager) OnPeerDisconnected(ctx context.Context, peer *nbpeer.
log.WithContext(ctx).Tracef("add peer to ephemeral list: %s", peer.ID)
a, err := e.store.GetAccountByPeerID(context.Background(), peer.ID)
if err != nil {
log.WithContext(ctx).Errorf("failed to add peer to ephemeral list: %s", err)
return
}
e.peersLock.Lock()
defer e.peersLock.Unlock()
@ -117,7 +111,7 @@ func (e *EphemeralManager) OnPeerDisconnected(ctx context.Context, peer *nbpeer.
return
}
e.addPeer(peer.ID, a, newDeadLine())
e.addPeer(peer.AccountID, peer.ID, newDeadLine())
if e.timer == nil {
e.timer = time.AfterFunc(e.headPeer.deadline.Sub(timeNow()), func() {
e.cleanup(ctx)
@ -126,17 +120,21 @@ func (e *EphemeralManager) OnPeerDisconnected(ctx context.Context, peer *nbpeer.
}
func (e *EphemeralManager) loadEphemeralPeers(ctx context.Context) {
accounts := e.store.GetAllAccounts(context.Background())
peers, err := e.store.GetAllEphemeralPeers(ctx, LockingStrengthShare)
if err != nil {
log.WithContext(ctx).Debugf("failed to load ephemeral peers: %s", err)
return
}
t := newDeadLine()
count := 0
for _, a := range accounts {
for id, p := range a.Peers {
if p.Ephemeral {
count++
e.addPeer(id, a, t)
}
for _, p := range peers {
if p.Ephemeral {
count++
e.addPeer(p.AccountID, p.ID, t)
}
}
log.WithContext(ctx).Debugf("loaded ephemeral peer(s): %d", count)
}
@ -170,18 +168,18 @@ func (e *EphemeralManager) cleanup(ctx context.Context) {
for id, p := range deletePeers {
log.WithContext(ctx).Debugf("delete ephemeral peer: %s", id)
err := e.accountManager.DeletePeer(ctx, p.account.Id, id, activity.SystemInitiator)
err := e.accountManager.DeletePeer(ctx, p.accountID, id, activity.SystemInitiator)
if err != nil {
log.WithContext(ctx).Errorf("failed to delete ephemeral peer: %s", err)
}
}
}
func (e *EphemeralManager) addPeer(id string, account *Account, deadline time.Time) {
func (e *EphemeralManager) addPeer(accountID string, peerID string, deadline time.Time) {
ep := &ephemeralPeer{
id: id,
account: account,
deadline: deadline,
id: peerID,
accountID: accountID,
deadline: deadline,
}
if e.headPeer == nil {

View File

@ -363,12 +363,12 @@ func (am *DefaultAccountManager) deletePeers(ctx context.Context, accountID stri
// DeletePeer removes peer from the account by its IP
func (am *DefaultAccountManager) DeletePeer(ctx context.Context, accountID, peerID, userID string) error {
user, err := am.Store.GetUserByUserID(ctx, LockingStrengthShare, userID)
peerAccountID, err := am.Store.GetAccountIDByPeerID(ctx, LockingStrengthShare, peerID)
if err != nil {
return err
}
if user.AccountID != accountID {
if peerAccountID != accountID {
return status.NewUserNotPartOfAccountError()
}

View File

@ -44,7 +44,7 @@ type Peer struct {
// CreatedAt records the time the peer was created
CreatedAt time.Time `diff:"-"`
// Indicate ephemeral peer attribute
Ephemeral bool `diff:"-"`
Ephemeral bool `gorm:"index" diff:"-"`
// Geo location based on connection IP
Location Location `gorm:"embedded;embeddedPrefix:location_" diff:"-"`
}

View File

@ -39,6 +39,7 @@ const (
accountAndIDQueryCondition = "account_id = ? and id = ?"
accountIDCondition = "account_id = ?"
peerNotFoundFMT = "peer %s not found"
batchSize = 500
)
// SqlStore represents an account storage backed by a Sql DB persisted to disk
@ -592,7 +593,7 @@ func (s *SqlStore) GetAccount(ctx context.Context, accountID string) (*Account,
if result.Error != nil {
log.WithContext(ctx).Errorf("error when getting account %s from the store: %s", accountID, result.Error)
if errors.Is(result.Error, gorm.ErrRecordNotFound) {
return nil, status.NewAccountNotFoundError(accountID)
return nil, status.NewAccountNotFoundError()
}
return nil, status.NewGetAccountFromStoreError(result.Error)
}
@ -708,7 +709,7 @@ func (s *SqlStore) GetAccountIDByPeerPubKey(ctx context.Context, peerKey string)
result := s.db.WithContext(ctx).Model(&peer).Select("account_id").Where(keyQueryCondition, peerKey).First(&accountID)
if result.Error != nil {
if errors.Is(result.Error, gorm.ErrRecordNotFound) {
return "", status.Errorf(status.NotFound, "account not found: index lookup failed")
return "", status.NewAccountNotFoundError()
}
return "", status.NewGetAccountFromStoreError(result.Error)
}
@ -719,6 +720,21 @@ func (s *SqlStore) GetAccountIDByPeerPubKey(ctx context.Context, peerKey string)
func (s *SqlStore) GetAccountIDByUserID(userID string) (string, error) {
var accountID string
result := s.db.Model(&User{}).Select("account_id").Where(idQueryCondition, userID).First(&accountID)
if result.Error != nil {
if errors.Is(result.Error, gorm.ErrRecordNotFound) {
return "", status.NewAccountNotFoundError()
}
return "", status.NewGetAccountFromStoreError(result.Error)
}
return accountID, nil
}
func (s *SqlStore) GetAccountIDByPeerID(ctx context.Context, lockStrength LockingStrength, peerID string) (string, error) {
var accountID string
result := s.db.WithContext(ctx).Clauses(clause.Locking{Strength: string(lockStrength)}).Model(&nbpeer.Peer{}).
Select("account_id").Where(idQueryCondition, peerID).First(&accountID)
if result.Error != nil {
if errors.Is(result.Error, gorm.ErrRecordNotFound) {
return "", status.Errorf(status.NotFound, "account not found: index lookup failed")
@ -798,7 +814,7 @@ func (s *SqlStore) GetAccountNetwork(ctx context.Context, lockStrength LockingSt
if err := s.db.WithContext(ctx).Clauses(clause.Locking{Strength: string(lockStrength)}).Model(&Account{}).
First(&accountNetwork, idQueryCondition, accountID).Error; err != nil {
if errors.Is(err, gorm.ErrRecordNotFound) {
return nil, status.NewAccountNotFoundError(accountID)
return nil, status.NewAccountNotFoundError()
}
log.WithContext(ctx).Errorf("error when getting network from the store: %s", err)
return nil, status.Errorf(status.Internal, "issue getting network from store: %s", err)
@ -1132,9 +1148,27 @@ func (s *SqlStore) GetPeerByID(ctx context.Context, lockStrength LockingStrength
return peer, nil
}
// GetAllEphemeralPeers retrieves all peers with Ephemeral set to true across all accounts, optimized for batch processing.
func (s *SqlStore) GetAllEphemeralPeers(ctx context.Context, lockStrength LockingStrength) ([]*nbpeer.Peer, error) {
var allEphemeralPeers, batchPeers []*nbpeer.Peer
result := s.db.WithContext(ctx).Clauses(clause.Locking{Strength: string(lockStrength)}).
Where("ephemeral = ?", true).
FindInBatches(&batchPeers, batchSize, func(tx *gorm.DB, batch int) error {
allEphemeralPeers = append(allEphemeralPeers, batchPeers...)
return nil
})
if result.Error != nil {
log.WithContext(ctx).Errorf("failed to retrieve ephemeral peers: %s", result.Error)
return nil, fmt.Errorf("failed to retrieve ephemeral peers")
}
return allEphemeralPeers, nil
}
func (s *SqlStore) DeletePeer(ctx context.Context, lockStrength LockingStrength, accountID string, peerID string) error {
result := s.db.WithContext(ctx).Clauses(clause.Locking{Strength: string(lockStrength)}).
Delete(&Policy{}, accountAndIDQueryCondition, accountID, peerID)
Delete(&nbpeer.Peer{}, accountAndIDQueryCondition, accountID, peerID)
if err := result.Error; err != nil {
log.WithContext(ctx).Errorf("failed to delete peer from the store: %s", err)
return status.Errorf(status.Internal, "failed to delete peer from store")
@ -1629,6 +1663,27 @@ func (s *SqlStore) GetUserPATs(ctx context.Context, lockStrength LockingStrength
return pats, nil
}
// MarkPATUsed marks a personal access token as used.
func (s *SqlStore) MarkPATUsed(ctx context.Context, lockStrength LockingStrength, patID string) error {
patCopy := PersonalAccessToken{
LastUsed: time.Now().UTC(),
}
fieldsToUpdate := []string{"last_used"}
result := s.db.WithContext(ctx).Clauses(clause.Locking{Strength: string(lockStrength)}).
Select(fieldsToUpdate).Where(idQueryCondition, patID).Updates(&patCopy)
if result.Error != nil {
log.WithContext(ctx).Errorf("failed to mark PAT as used: %s", result.Error)
return status.Errorf(status.Internal, "failed to mark PAT as used")
}
if result.RowsAffected == 0 {
return status.NewPATNotFoundError()
}
return nil
}
// SavePAT saves a personal access token to the database.
func (s *SqlStore) SavePAT(ctx context.Context, lockStrength LockingStrength, pat *PersonalAccessToken) error {
result := s.db.WithContext(ctx).Clauses(clause.Locking{Strength: string(lockStrength)}).Save(pat)

View File

@ -82,8 +82,8 @@ func NewPeerNotFoundError(peerKey string) error {
}
// NewAccountNotFoundError creates a new Error with NotFound type for a missing account
func NewAccountNotFoundError(accountKey string) error {
return Errorf(NotFound, "account not found: %s", accountKey)
func NewAccountNotFoundError() error {
return Errorf(NotFound, "account not found")
}
// NewUserNotFoundError creates a new Error with NotFound type for a missing user
@ -134,11 +134,20 @@ func NewUnauthorizedToViewSetupKeysError() error {
return Errorf(PermissionDenied, "only users with admin power can view setup keys")
}
func NewGroupNotFoundError() error {
return Errorf(NotFound, "group not found")
}
func NewUnauthorizedToViewGroupsError() error {
return Errorf(PermissionDenied, "only users with admin power can view groups")
}
func NewPATNotFoundError() error {
return Errorf(NotFound, "PAT not found")
}
func NewUnauthorizedToViewPATsError() error {
return Errorf(PermissionDenied, "only users with admin power can view personal access tokens")
return Errorf(PermissionDenied, "only users with admin power can view PATs")
}
func NewUnauthorizedToViewPoliciesError() error {

View File

@ -50,6 +50,7 @@ type Store interface {
GetAccountIDByPeerPubKey(ctx context.Context, peerKey string) (string, error)
GetAccountIDByUserID(userID string) (string, error)
GetAccountIDBySetupKey(ctx context.Context, peerKey string) (string, error)
GetAccountIDByPeerID(ctx context.Context, lockStrength LockingStrength, peerID string) (string, error)
GetAccountByPeerID(ctx context.Context, peerID string) (*Account, error)
GetAccountBySetupKey(ctx context.Context, setupKey string) (*Account, error) // todo use key hash later
GetAccountByPrivateDomain(ctx context.Context, domain string) (*Account, error)
@ -99,6 +100,7 @@ type Store interface {
GetAccountPeersWithExpiration(ctx context.Context, lockStrength LockingStrength, accountID string) ([]*nbpeer.Peer, error)
GetAccountPeersWithInactivity(ctx context.Context, lockStrength LockingStrength, accountID string) ([]*nbpeer.Peer, error)
GetPeerByID(ctx context.Context, lockStrength LockingStrength, accountID string, peerID string) (*nbpeer.Peer, error)
GetAllEphemeralPeers(ctx context.Context, lockStrength LockingStrength) ([]*nbpeer.Peer, error)
SavePeer(ctx context.Context, lockStrength LockingStrength, accountID string, peer *nbpeer.Peer) error
SavePeerStatus(ctx context.Context, lockStrength LockingStrength, accountID, peerID string, status nbpeer.PeerStatus) error
SavePeerLocation(ctx context.Context, lockStrength LockingStrength, accountID string, peer *nbpeer.Peer) error
@ -123,6 +125,7 @@ type Store interface {
GetPATByID(ctx context.Context, lockStrength LockingStrength, userID, patID string) (*PersonalAccessToken, error)
GetUserPATs(ctx context.Context, lockStrength LockingStrength, userID string) ([]*PersonalAccessToken, error)
MarkPATUsed(ctx context.Context, lockStrength LockingStrength, patID string) error
SavePAT(ctx context.Context, strength LockingStrength, pat *PersonalAccessToken) error
DeletePAT(ctx context.Context, strength LockingStrength, userID, patID string) error