mirror of
https://github.com/netbirdio/netbird.git
synced 2025-03-04 01:41:17 +01:00
Merge remote-tracking branch 'origin/peers-get-account-refactoring' into peers-get-account-refactoring
This commit is contained in:
commit
7241a16ff7
@ -45,6 +45,7 @@ import (
|
||||
const (
|
||||
CacheExpirationMax = 7 * 24 * 3600 * time.Second // 7 days
|
||||
CacheExpirationMin = 3 * 24 * 3600 * time.Second // 3 days
|
||||
peerSchedulerRetryInterval = 3 * time.Second
|
||||
emptyUserID = "empty user ID in claims"
|
||||
errorGettingDomainAccIDFmt = "error getting account ID by private domain: %v"
|
||||
)
|
||||
@ -469,7 +470,7 @@ func (am *DefaultAccountManager) peerLoginExpirationJob(ctx context.Context, acc
|
||||
|
||||
expiredPeers, err := am.getExpiredPeers(ctx, accountID)
|
||||
if err != nil {
|
||||
return 0, false
|
||||
return peerSchedulerRetryInterval, true
|
||||
}
|
||||
|
||||
var peerIDs []string
|
||||
@ -481,7 +482,7 @@ func (am *DefaultAccountManager) peerLoginExpirationJob(ctx context.Context, acc
|
||||
|
||||
if err := am.expireAndUpdatePeers(ctx, accountID, expiredPeers); err != nil {
|
||||
log.WithContext(ctx).Errorf("failed updating account peers while expiring peers for account %s", accountID)
|
||||
return 0, false
|
||||
return peerSchedulerRetryInterval, true
|
||||
}
|
||||
|
||||
return am.getNextPeerExpiration(ctx, accountID)
|
||||
@ -504,7 +505,7 @@ func (am *DefaultAccountManager) peerInactivityExpirationJob(ctx context.Context
|
||||
inactivePeers, err := am.getInactivePeers(ctx, accountID)
|
||||
if err != nil {
|
||||
log.WithContext(ctx).Errorf("failed getting inactive peers for account %s", accountID)
|
||||
return 0, false
|
||||
return peerSchedulerRetryInterval, true
|
||||
}
|
||||
|
||||
var peerIDs []string
|
||||
@ -516,7 +517,7 @@ func (am *DefaultAccountManager) peerInactivityExpirationJob(ctx context.Context
|
||||
|
||||
if err := am.expireAndUpdatePeers(ctx, accountID, inactivePeers); err != nil {
|
||||
log.Errorf("failed updating account peers while expiring peers for account %s", accountID)
|
||||
return 0, false
|
||||
return peerSchedulerRetryInterval, true
|
||||
}
|
||||
|
||||
return am.getNextInactivePeerExpiration(ctx, accountID)
|
||||
|
@ -13,7 +13,8 @@ import (
|
||||
)
|
||||
|
||||
type Manager interface {
|
||||
GetAllGroups(ctx context.Context, accountID, userID string) (map[string]*types.Group, error)
|
||||
GetAllGroups(ctx context.Context, accountID, userID string) ([]*types.Group, error)
|
||||
GetAllGroupsMap(ctx context.Context, accountID, userID string) (map[string]*types.Group, error)
|
||||
GetResourceGroupsInTransaction(ctx context.Context, transaction store.Store, lockingStrength store.LockingStrength, accountID, resourceID string) ([]*types.Group, error)
|
||||
AddResourceToGroup(ctx context.Context, accountID, userID, groupID string, resourceID *types.Resource) error
|
||||
AddResourceToGroupInTransaction(ctx context.Context, transaction store.Store, accountID, userID, groupID string, resourceID *types.Resource) (func(), error)
|
||||
@ -37,7 +38,7 @@ func NewManager(store store.Store, permissionsManager permissions.Manager, accou
|
||||
}
|
||||
}
|
||||
|
||||
func (m *managerImpl) GetAllGroups(ctx context.Context, accountID, userID string) (map[string]*types.Group, error) {
|
||||
func (m *managerImpl) GetAllGroups(ctx context.Context, accountID, userID string) ([]*types.Group, error) {
|
||||
ok, err := m.permissionsManager.ValidateUserPermissions(ctx, accountID, userID, permissions.Groups, permissions.Read)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
@ -51,6 +52,15 @@ func (m *managerImpl) GetAllGroups(ctx context.Context, accountID, userID string
|
||||
return nil, fmt.Errorf("error getting account groups: %w", err)
|
||||
}
|
||||
|
||||
return groups, nil
|
||||
}
|
||||
|
||||
func (m *managerImpl) GetAllGroupsMap(ctx context.Context, accountID, userID string) (map[string]*types.Group, error) {
|
||||
groups, err := m.GetAllGroups(ctx, accountID, userID)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
groupsMap := make(map[string]*types.Group)
|
||||
for _, group := range groups {
|
||||
groupsMap[group.ID] = group
|
||||
@ -130,7 +140,7 @@ func (m *managerImpl) GetResourceGroupsInTransaction(ctx context.Context, transa
|
||||
return transaction.GetResourceGroups(ctx, lockingStrength, accountID, resourceID)
|
||||
}
|
||||
|
||||
func ToGroupsInfo(groups map[string]*types.Group, id string) []api.GroupMinimum {
|
||||
func ToGroupsInfo(groups []*types.Group, id string) []api.GroupMinimum {
|
||||
groupsInfo := []api.GroupMinimum{}
|
||||
groupsChecked := make(map[string]struct{})
|
||||
for _, group := range groups {
|
||||
@ -167,7 +177,11 @@ func ToGroupsInfo(groups map[string]*types.Group, id string) []api.GroupMinimum
|
||||
return groupsInfo
|
||||
}
|
||||
|
||||
func (m *mockManager) GetAllGroups(ctx context.Context, accountID, userID string) (map[string]*types.Group, error) {
|
||||
func (m *mockManager) GetAllGroups(ctx context.Context, accountID, userID string) ([]*types.Group, error) {
|
||||
return []*types.Group{}, nil
|
||||
}
|
||||
|
||||
func (m *mockManager) GetAllGroupsMap(ctx context.Context, accountID, userID string) (map[string]*types.Group, error) {
|
||||
return map[string]*types.Group{}, nil
|
||||
}
|
||||
|
||||
|
@ -82,7 +82,7 @@ func (h *handler) getAllNetworks(w http.ResponseWriter, r *http.Request) {
|
||||
return
|
||||
}
|
||||
|
||||
groups, err := h.groupsManager.GetAllGroups(r.Context(), accountID, userID)
|
||||
groups, err := h.groupsManager.GetAllGroupsMap(r.Context(), accountID, userID)
|
||||
if err != nil {
|
||||
util.WriteError(r.Context(), err, w)
|
||||
return
|
||||
@ -267,7 +267,7 @@ func (h *handler) collectIDsInNetwork(ctx context.Context, accountID, userID, ne
|
||||
return nil, nil, 0, fmt.Errorf("failed to get routers in network: %w", err)
|
||||
}
|
||||
|
||||
groups, err := h.groupsManager.GetAllGroups(ctx, accountID, userID)
|
||||
groups, err := h.groupsManager.GetAllGroupsMap(ctx, accountID, userID)
|
||||
if err != nil {
|
||||
return nil, nil, 0, fmt.Errorf("failed to get groups: %w", err)
|
||||
}
|
||||
|
@ -72,13 +72,8 @@ func (h *Handler) getPeer(ctx context.Context, accountID, peerID, userID string,
|
||||
}
|
||||
dnsDomain := h.accountManager.GetDNSDomain()
|
||||
|
||||
groupsMap := map[string]*types.Group{}
|
||||
grps, _ := h.accountManager.GetAllGroups(ctx, accountID, userID)
|
||||
for _, group := range grps {
|
||||
groupsMap[group.ID] = group
|
||||
}
|
||||
|
||||
groupsInfo := groups.ToGroupsInfo(groupsMap, peerID)
|
||||
grps, _ := h.accountManager.GetPeerGroups(ctx, accountID, peerID)
|
||||
groupsInfo := groups.ToGroupsInfo(grps, peerID)
|
||||
|
||||
validPeers, err := h.accountManager.GetValidatedPeers(ctx, accountID)
|
||||
if err != nil {
|
||||
@ -128,12 +123,7 @@ func (h *Handler) updatePeer(ctx context.Context, accountID, userID, peerID stri
|
||||
return
|
||||
}
|
||||
|
||||
groupsMap := map[string]*types.Group{}
|
||||
for _, group := range peerGroups {
|
||||
groupsMap[group.ID] = group
|
||||
}
|
||||
|
||||
groupMinimumInfo := groups.ToGroupsInfo(groupsMap, peer.ID)
|
||||
groupMinimumInfo := groups.ToGroupsInfo(peerGroups, peer.ID)
|
||||
|
||||
validPeers, err := h.accountManager.GetValidatedPeers(ctx, accountID)
|
||||
if err != nil {
|
||||
@ -204,11 +194,7 @@ func (h *Handler) GetAllPeers(w http.ResponseWriter, r *http.Request) {
|
||||
|
||||
dnsDomain := h.accountManager.GetDNSDomain()
|
||||
|
||||
groupsMap := map[string]*types.Group{}
|
||||
grps, _ := h.accountManager.GetAllGroups(r.Context(), accountID, userID)
|
||||
for _, group := range grps {
|
||||
groupsMap[group.ID] = group
|
||||
}
|
||||
|
||||
respBody := make([]*api.PeerBatch, 0, len(peers))
|
||||
for _, peer := range peers {
|
||||
@ -217,7 +203,7 @@ func (h *Handler) GetAllPeers(w http.ResponseWriter, r *http.Request) {
|
||||
util.WriteError(r.Context(), err, w)
|
||||
return
|
||||
}
|
||||
groupMinimumInfo := groups.ToGroupsInfo(groupsMap, peer.ID)
|
||||
groupMinimumInfo := groups.ToGroupsInfo(grps, peer.ID)
|
||||
|
||||
respBody = append(respBody, toPeerListItemResponse(peerToReturn, groupMinimumInfo, dnsDomain, 0))
|
||||
}
|
||||
|
@ -101,7 +101,7 @@ func (am *DefaultAccountManager) GetPeers(ctx context.Context, accountID, userID
|
||||
return nil, err
|
||||
}
|
||||
|
||||
approvedPeersMap, err := am.GetValidatedPeers(ctx, accountID)
|
||||
approvedPeersMap, err := am.integratedPeerValidator.GetValidatedPeers(accountID, account.Groups, account.Peers, account.Settings.Extra)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
@ -335,6 +335,15 @@ func (am *DefaultAccountManager) DeletePeer(ctx context.Context, accountID, peer
|
||||
unlock := am.Store.AcquireWriteLockByUID(ctx, accountID)
|
||||
defer unlock()
|
||||
|
||||
user, err := am.Store.GetUserByUserID(ctx, store.LockingStrengthShare, userID)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
if user.AccountID != accountID {
|
||||
return status.NewUserNotPartOfAccountError()
|
||||
}
|
||||
|
||||
peerAccountID, err := am.Store.GetAccountIDByPeerID(ctx, store.LockingStrengthShare, peerID)
|
||||
if err != nil {
|
||||
return err
|
||||
@ -1057,12 +1066,12 @@ func (am *DefaultAccountManager) GetPeer(ctx context.Context, accountID, peerID,
|
||||
return nil, err
|
||||
}
|
||||
|
||||
approvedPeersMap, err := am.GetValidatedPeers(ctx, accountID)
|
||||
account, err := am.requestBuffer.GetAccountWithBackpressure(ctx, accountID)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
account, err := am.requestBuffer.GetAccountWithBackpressure(ctx, accountID)
|
||||
approvedPeersMap, err := am.integratedPeerValidator.GetValidatedPeers(accountID, account.Groups, account.Peers, account.Settings.Extra)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
@ -1139,6 +1148,11 @@ func (am *DefaultAccountManager) UpdateAccountPeers(ctx context.Context, account
|
||||
// UpdateAccountPeer updates a single peer that belongs to an account.
|
||||
// Should be called when changes need to be synced to a specific peer only.
|
||||
func (am *DefaultAccountManager) UpdateAccountPeer(ctx context.Context, accountId string, peerId string) {
|
||||
if !am.peersUpdateManager.HasChannel(peerId) {
|
||||
log.WithContext(ctx).Tracef("peer %s doesn't have a channel, skipping network map update", peerId)
|
||||
return
|
||||
}
|
||||
|
||||
account, err := am.requestBuffer.GetAccountWithBackpressure(ctx, accountId)
|
||||
if err != nil {
|
||||
log.WithContext(ctx).Errorf("failed to send out updates to peer %s. failed to get account: %v", peerId, err)
|
||||
@ -1151,11 +1165,6 @@ func (am *DefaultAccountManager) UpdateAccountPeer(ctx context.Context, accountI
|
||||
return
|
||||
}
|
||||
|
||||
if !am.peersUpdateManager.HasChannel(peerId) {
|
||||
log.WithContext(ctx).Tracef("peer %s doesn't have a channel, skipping network map update", peerId)
|
||||
return
|
||||
}
|
||||
|
||||
approvedPeersMap, err := am.integratedPeerValidator.GetValidatedPeers(account.Id, account.Groups, account.Peers, account.Settings.Extra)
|
||||
if err != nil {
|
||||
log.WithContext(ctx).Errorf("failed to send update to peer %s, failed to validate peers: %v", peerId, err)
|
||||
@ -1185,7 +1194,7 @@ func (am *DefaultAccountManager) getNextPeerExpiration(ctx context.Context, acco
|
||||
peersWithExpiry, err := am.Store.GetAccountPeersWithExpiration(ctx, store.LockingStrengthShare, accountID)
|
||||
if err != nil {
|
||||
log.WithContext(ctx).Errorf("failed to get peers with expiration: %v", err)
|
||||
return 0, false
|
||||
return peerSchedulerRetryInterval, true
|
||||
}
|
||||
|
||||
if len(peersWithExpiry) == 0 {
|
||||
@ -1195,7 +1204,7 @@ func (am *DefaultAccountManager) getNextPeerExpiration(ctx context.Context, acco
|
||||
settings, err := am.Store.GetAccountSettings(ctx, store.LockingStrengthShare, accountID)
|
||||
if err != nil {
|
||||
log.WithContext(ctx).Errorf("failed to get account settings: %v", err)
|
||||
return 0, false
|
||||
return peerSchedulerRetryInterval, true
|
||||
}
|
||||
|
||||
var nextExpiry *time.Duration
|
||||
@ -1229,7 +1238,7 @@ func (am *DefaultAccountManager) getNextInactivePeerExpiration(ctx context.Conte
|
||||
peersWithInactivity, err := am.Store.GetAccountPeersWithInactivity(ctx, store.LockingStrengthShare, accountID)
|
||||
if err != nil {
|
||||
log.WithContext(ctx).Errorf("failed to get peers with inactivity: %v", err)
|
||||
return 0, false
|
||||
return peerSchedulerRetryInterval, true
|
||||
}
|
||||
|
||||
if len(peersWithInactivity) == 0 {
|
||||
@ -1239,7 +1248,7 @@ func (am *DefaultAccountManager) getNextInactivePeerExpiration(ctx context.Conte
|
||||
settings, err := am.Store.GetAccountSettings(ctx, store.LockingStrengthShare, accountID)
|
||||
if err != nil {
|
||||
log.WithContext(ctx).Errorf("failed to get account settings: %v", err)
|
||||
return 0, false
|
||||
return peerSchedulerRetryInterval, true
|
||||
}
|
||||
|
||||
var nextExpiry *time.Duration
|
||||
|
@ -332,7 +332,7 @@ func (s *SqlStore) SavePeer(ctx context.Context, lockStrength LockingStrength, a
|
||||
|
||||
result = tx.Model(&nbpeer.Peer{}).Where(accountAndIDQueryCondition, accountID, peer.ID).Save(peerCopy)
|
||||
if result.Error != nil {
|
||||
return result.Error
|
||||
return status.Errorf(status.Internal, "failed to save peer to store: %v", result.Error)
|
||||
}
|
||||
|
||||
return nil
|
||||
@ -358,7 +358,7 @@ func (s *SqlStore) UpdateAccountDomainAttributes(ctx context.Context, accountID
|
||||
Where(idQueryCondition, accountID).
|
||||
Updates(&accountCopy)
|
||||
if result.Error != nil {
|
||||
return result.Error
|
||||
return status.Errorf(status.Internal, "failed to update account domain attributes to store: %v", result.Error)
|
||||
}
|
||||
|
||||
if result.RowsAffected == 0 {
|
||||
@ -381,7 +381,7 @@ func (s *SqlStore) SavePeerStatus(ctx context.Context, lockStrength LockingStren
|
||||
Where(accountAndIDQueryCondition, accountID, peerID).
|
||||
Updates(&peerCopy)
|
||||
if result.Error != nil {
|
||||
return result.Error
|
||||
return status.Errorf(status.Internal, "failed to save peer status to store: %v", result.Error)
|
||||
}
|
||||
|
||||
if result.RowsAffected == 0 {
|
||||
@ -403,7 +403,7 @@ func (s *SqlStore) SavePeerLocation(ctx context.Context, lockStrength LockingStr
|
||||
Updates(peerCopy)
|
||||
|
||||
if result.Error != nil {
|
||||
return result.Error
|
||||
return status.Errorf(status.Internal, "failed to save peer locations to store: %v", result.Error)
|
||||
}
|
||||
|
||||
if result.RowsAffected == 0 {
|
||||
|
Loading…
Reference in New Issue
Block a user