Merge remote-tracking branch 'origin/peers-get-account-refactoring' into peers-get-account-refactoring

This commit is contained in:
Pascal Fischer 2025-01-14 23:28:20 +01:00
commit 7241a16ff7
6 changed files with 54 additions and 44 deletions

View File

@ -45,6 +45,7 @@ import (
const (
CacheExpirationMax = 7 * 24 * 3600 * time.Second // 7 days
CacheExpirationMin = 3 * 24 * 3600 * time.Second // 3 days
peerSchedulerRetryInterval = 3 * time.Second
emptyUserID = "empty user ID in claims"
errorGettingDomainAccIDFmt = "error getting account ID by private domain: %v"
)
@ -469,7 +470,7 @@ func (am *DefaultAccountManager) peerLoginExpirationJob(ctx context.Context, acc
expiredPeers, err := am.getExpiredPeers(ctx, accountID)
if err != nil {
return 0, false
return peerSchedulerRetryInterval, true
}
var peerIDs []string
@ -481,7 +482,7 @@ func (am *DefaultAccountManager) peerLoginExpirationJob(ctx context.Context, acc
if err := am.expireAndUpdatePeers(ctx, accountID, expiredPeers); err != nil {
log.WithContext(ctx).Errorf("failed updating account peers while expiring peers for account %s", accountID)
return 0, false
return peerSchedulerRetryInterval, true
}
return am.getNextPeerExpiration(ctx, accountID)
@ -504,7 +505,7 @@ func (am *DefaultAccountManager) peerInactivityExpirationJob(ctx context.Context
inactivePeers, err := am.getInactivePeers(ctx, accountID)
if err != nil {
log.WithContext(ctx).Errorf("failed getting inactive peers for account %s", accountID)
return 0, false
return peerSchedulerRetryInterval, true
}
var peerIDs []string
@ -516,7 +517,7 @@ func (am *DefaultAccountManager) peerInactivityExpirationJob(ctx context.Context
if err := am.expireAndUpdatePeers(ctx, accountID, inactivePeers); err != nil {
log.Errorf("failed updating account peers while expiring peers for account %s", accountID)
return 0, false
return peerSchedulerRetryInterval, true
}
return am.getNextInactivePeerExpiration(ctx, accountID)

View File

@ -13,7 +13,8 @@ import (
)
type Manager interface {
GetAllGroups(ctx context.Context, accountID, userID string) (map[string]*types.Group, error)
GetAllGroups(ctx context.Context, accountID, userID string) ([]*types.Group, error)
GetAllGroupsMap(ctx context.Context, accountID, userID string) (map[string]*types.Group, error)
GetResourceGroupsInTransaction(ctx context.Context, transaction store.Store, lockingStrength store.LockingStrength, accountID, resourceID string) ([]*types.Group, error)
AddResourceToGroup(ctx context.Context, accountID, userID, groupID string, resourceID *types.Resource) error
AddResourceToGroupInTransaction(ctx context.Context, transaction store.Store, accountID, userID, groupID string, resourceID *types.Resource) (func(), error)
@ -37,7 +38,7 @@ func NewManager(store store.Store, permissionsManager permissions.Manager, accou
}
}
func (m *managerImpl) GetAllGroups(ctx context.Context, accountID, userID string) (map[string]*types.Group, error) {
func (m *managerImpl) GetAllGroups(ctx context.Context, accountID, userID string) ([]*types.Group, error) {
ok, err := m.permissionsManager.ValidateUserPermissions(ctx, accountID, userID, permissions.Groups, permissions.Read)
if err != nil {
return nil, err
@ -51,6 +52,15 @@ func (m *managerImpl) GetAllGroups(ctx context.Context, accountID, userID string
return nil, fmt.Errorf("error getting account groups: %w", err)
}
return groups, nil
}
func (m *managerImpl) GetAllGroupsMap(ctx context.Context, accountID, userID string) (map[string]*types.Group, error) {
groups, err := m.GetAllGroups(ctx, accountID, userID)
if err != nil {
return nil, err
}
groupsMap := make(map[string]*types.Group)
for _, group := range groups {
groupsMap[group.ID] = group
@ -130,7 +140,7 @@ func (m *managerImpl) GetResourceGroupsInTransaction(ctx context.Context, transa
return transaction.GetResourceGroups(ctx, lockingStrength, accountID, resourceID)
}
func ToGroupsInfo(groups map[string]*types.Group, id string) []api.GroupMinimum {
func ToGroupsInfo(groups []*types.Group, id string) []api.GroupMinimum {
groupsInfo := []api.GroupMinimum{}
groupsChecked := make(map[string]struct{})
for _, group := range groups {
@ -167,7 +177,11 @@ func ToGroupsInfo(groups map[string]*types.Group, id string) []api.GroupMinimum
return groupsInfo
}
func (m *mockManager) GetAllGroups(ctx context.Context, accountID, userID string) (map[string]*types.Group, error) {
func (m *mockManager) GetAllGroups(ctx context.Context, accountID, userID string) ([]*types.Group, error) {
return []*types.Group{}, nil
}
func (m *mockManager) GetAllGroupsMap(ctx context.Context, accountID, userID string) (map[string]*types.Group, error) {
return map[string]*types.Group{}, nil
}

View File

@ -82,7 +82,7 @@ func (h *handler) getAllNetworks(w http.ResponseWriter, r *http.Request) {
return
}
groups, err := h.groupsManager.GetAllGroups(r.Context(), accountID, userID)
groups, err := h.groupsManager.GetAllGroupsMap(r.Context(), accountID, userID)
if err != nil {
util.WriteError(r.Context(), err, w)
return
@ -267,7 +267,7 @@ func (h *handler) collectIDsInNetwork(ctx context.Context, accountID, userID, ne
return nil, nil, 0, fmt.Errorf("failed to get routers in network: %w", err)
}
groups, err := h.groupsManager.GetAllGroups(ctx, accountID, userID)
groups, err := h.groupsManager.GetAllGroupsMap(ctx, accountID, userID)
if err != nil {
return nil, nil, 0, fmt.Errorf("failed to get groups: %w", err)
}

View File

@ -72,13 +72,8 @@ func (h *Handler) getPeer(ctx context.Context, accountID, peerID, userID string,
}
dnsDomain := h.accountManager.GetDNSDomain()
groupsMap := map[string]*types.Group{}
grps, _ := h.accountManager.GetAllGroups(ctx, accountID, userID)
for _, group := range grps {
groupsMap[group.ID] = group
}
groupsInfo := groups.ToGroupsInfo(groupsMap, peerID)
grps, _ := h.accountManager.GetPeerGroups(ctx, accountID, peerID)
groupsInfo := groups.ToGroupsInfo(grps, peerID)
validPeers, err := h.accountManager.GetValidatedPeers(ctx, accountID)
if err != nil {
@ -128,12 +123,7 @@ func (h *Handler) updatePeer(ctx context.Context, accountID, userID, peerID stri
return
}
groupsMap := map[string]*types.Group{}
for _, group := range peerGroups {
groupsMap[group.ID] = group
}
groupMinimumInfo := groups.ToGroupsInfo(groupsMap, peer.ID)
groupMinimumInfo := groups.ToGroupsInfo(peerGroups, peer.ID)
validPeers, err := h.accountManager.GetValidatedPeers(ctx, accountID)
if err != nil {
@ -204,11 +194,7 @@ func (h *Handler) GetAllPeers(w http.ResponseWriter, r *http.Request) {
dnsDomain := h.accountManager.GetDNSDomain()
groupsMap := map[string]*types.Group{}
grps, _ := h.accountManager.GetAllGroups(r.Context(), accountID, userID)
for _, group := range grps {
groupsMap[group.ID] = group
}
respBody := make([]*api.PeerBatch, 0, len(peers))
for _, peer := range peers {
@ -217,7 +203,7 @@ func (h *Handler) GetAllPeers(w http.ResponseWriter, r *http.Request) {
util.WriteError(r.Context(), err, w)
return
}
groupMinimumInfo := groups.ToGroupsInfo(groupsMap, peer.ID)
groupMinimumInfo := groups.ToGroupsInfo(grps, peer.ID)
respBody = append(respBody, toPeerListItemResponse(peerToReturn, groupMinimumInfo, dnsDomain, 0))
}

View File

@ -101,7 +101,7 @@ func (am *DefaultAccountManager) GetPeers(ctx context.Context, accountID, userID
return nil, err
}
approvedPeersMap, err := am.GetValidatedPeers(ctx, accountID)
approvedPeersMap, err := am.integratedPeerValidator.GetValidatedPeers(accountID, account.Groups, account.Peers, account.Settings.Extra)
if err != nil {
return nil, err
}
@ -335,6 +335,15 @@ func (am *DefaultAccountManager) DeletePeer(ctx context.Context, accountID, peer
unlock := am.Store.AcquireWriteLockByUID(ctx, accountID)
defer unlock()
user, err := am.Store.GetUserByUserID(ctx, store.LockingStrengthShare, userID)
if err != nil {
return err
}
if user.AccountID != accountID {
return status.NewUserNotPartOfAccountError()
}
peerAccountID, err := am.Store.GetAccountIDByPeerID(ctx, store.LockingStrengthShare, peerID)
if err != nil {
return err
@ -1057,12 +1066,12 @@ func (am *DefaultAccountManager) GetPeer(ctx context.Context, accountID, peerID,
return nil, err
}
approvedPeersMap, err := am.GetValidatedPeers(ctx, accountID)
account, err := am.requestBuffer.GetAccountWithBackpressure(ctx, accountID)
if err != nil {
return nil, err
}
account, err := am.requestBuffer.GetAccountWithBackpressure(ctx, accountID)
approvedPeersMap, err := am.integratedPeerValidator.GetValidatedPeers(accountID, account.Groups, account.Peers, account.Settings.Extra)
if err != nil {
return nil, err
}
@ -1139,6 +1148,11 @@ func (am *DefaultAccountManager) UpdateAccountPeers(ctx context.Context, account
// UpdateAccountPeer updates a single peer that belongs to an account.
// Should be called when changes need to be synced to a specific peer only.
func (am *DefaultAccountManager) UpdateAccountPeer(ctx context.Context, accountId string, peerId string) {
if !am.peersUpdateManager.HasChannel(peerId) {
log.WithContext(ctx).Tracef("peer %s doesn't have a channel, skipping network map update", peerId)
return
}
account, err := am.requestBuffer.GetAccountWithBackpressure(ctx, accountId)
if err != nil {
log.WithContext(ctx).Errorf("failed to send out updates to peer %s. failed to get account: %v", peerId, err)
@ -1151,11 +1165,6 @@ func (am *DefaultAccountManager) UpdateAccountPeer(ctx context.Context, accountI
return
}
if !am.peersUpdateManager.HasChannel(peerId) {
log.WithContext(ctx).Tracef("peer %s doesn't have a channel, skipping network map update", peerId)
return
}
approvedPeersMap, err := am.integratedPeerValidator.GetValidatedPeers(account.Id, account.Groups, account.Peers, account.Settings.Extra)
if err != nil {
log.WithContext(ctx).Errorf("failed to send update to peer %s, failed to validate peers: %v", peerId, err)
@ -1185,7 +1194,7 @@ func (am *DefaultAccountManager) getNextPeerExpiration(ctx context.Context, acco
peersWithExpiry, err := am.Store.GetAccountPeersWithExpiration(ctx, store.LockingStrengthShare, accountID)
if err != nil {
log.WithContext(ctx).Errorf("failed to get peers with expiration: %v", err)
return 0, false
return peerSchedulerRetryInterval, true
}
if len(peersWithExpiry) == 0 {
@ -1195,7 +1204,7 @@ func (am *DefaultAccountManager) getNextPeerExpiration(ctx context.Context, acco
settings, err := am.Store.GetAccountSettings(ctx, store.LockingStrengthShare, accountID)
if err != nil {
log.WithContext(ctx).Errorf("failed to get account settings: %v", err)
return 0, false
return peerSchedulerRetryInterval, true
}
var nextExpiry *time.Duration
@ -1229,7 +1238,7 @@ func (am *DefaultAccountManager) getNextInactivePeerExpiration(ctx context.Conte
peersWithInactivity, err := am.Store.GetAccountPeersWithInactivity(ctx, store.LockingStrengthShare, accountID)
if err != nil {
log.WithContext(ctx).Errorf("failed to get peers with inactivity: %v", err)
return 0, false
return peerSchedulerRetryInterval, true
}
if len(peersWithInactivity) == 0 {
@ -1239,7 +1248,7 @@ func (am *DefaultAccountManager) getNextInactivePeerExpiration(ctx context.Conte
settings, err := am.Store.GetAccountSettings(ctx, store.LockingStrengthShare, accountID)
if err != nil {
log.WithContext(ctx).Errorf("failed to get account settings: %v", err)
return 0, false
return peerSchedulerRetryInterval, true
}
var nextExpiry *time.Duration

View File

@ -332,7 +332,7 @@ func (s *SqlStore) SavePeer(ctx context.Context, lockStrength LockingStrength, a
result = tx.Model(&nbpeer.Peer{}).Where(accountAndIDQueryCondition, accountID, peer.ID).Save(peerCopy)
if result.Error != nil {
return result.Error
return status.Errorf(status.Internal, "failed to save peer to store: %v", result.Error)
}
return nil
@ -358,7 +358,7 @@ func (s *SqlStore) UpdateAccountDomainAttributes(ctx context.Context, accountID
Where(idQueryCondition, accountID).
Updates(&accountCopy)
if result.Error != nil {
return result.Error
return status.Errorf(status.Internal, "failed to update account domain attributes to store: %v", result.Error)
}
if result.RowsAffected == 0 {
@ -381,7 +381,7 @@ func (s *SqlStore) SavePeerStatus(ctx context.Context, lockStrength LockingStren
Where(accountAndIDQueryCondition, accountID, peerID).
Updates(&peerCopy)
if result.Error != nil {
return result.Error
return status.Errorf(status.Internal, "failed to save peer status to store: %v", result.Error)
}
if result.RowsAffected == 0 {
@ -403,7 +403,7 @@ func (s *SqlStore) SavePeerLocation(ctx context.Context, lockStrength LockingStr
Updates(peerCopy)
if result.Error != nil {
return result.Error
return status.Errorf(status.Internal, "failed to save peer locations to store: %v", result.Error)
}
if result.RowsAffected == 0 {